Skip to content

Commit

Permalink
chore: rework save_to_oci_registry to incorporate user-provided provi…
Browse files Browse the repository at this point in the history
…ders

Signed-off-by: Eric Dobroveanu <[email protected]>
  • Loading branch information
Crazyglue committed Feb 17, 2025
1 parent 2678e81 commit 4554ce2
Showing 1 changed file with 63 additions and 24 deletions.
87 changes: 63 additions & 24 deletions clients/python/src/model_registry/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from __future__ import annotations

import os
import pathlib

from typing_extensions import overload
from typing import Optional, TypedDict, Callable, Dict
from pathlib import Path

from ._utils import required_args
from .exceptions import MissingMetadata, StoreError
Expand Down Expand Up @@ -92,13 +93,61 @@ def s3_uri_from(
# FIXME: is this safe?
return f"s3://{bucket}/{path}?endpoint={endpoint}&defaultRegion={region}"


class BackendDefinition(TypedDict):
"""
Holds the 3 core callables for a backend:
- is_available() -> bool
- pull(base_image: str, dest_dir: Path) -> None
- push(local_image_path: Path, oci_ref: str) -> None
"""
available: Callable[[], bool]
pull: Callable[[str, Path], None]
push: Callable[[Path, str], None]

# A dict mapping backend names to their definitions
BackendDict = Dict[str, Callable[[], BackendDefinition]]


def get_skopeo_backend() -> BackendDefinition:
try:
from olot.backend.skopeo import is_skopeo, skopeo_pull, skopeo_push
except ImportError as e:
msg = "Could not import 'olot.backend.skopeo'. Ensure that 'olot' is installed if you want to use the 'skopeo' backend."
raise ImportError(msg) from e

return {
"is_available": is_skopeo,
"pull": skopeo_pull,
"push": skopeo_push
}

def get_oras_backend() -> BackendDefinition:
try:
from olot.backend.oras_cp import is_oras, oras_pull, oras_push
except ImportError as e:
msg = "Could not import 'olot.backend.oras_cp'. Ensure that 'olot' is installed if you want to use the 'oras_cp' backend."
raise ImportError(msg) from e

return {
"is_available": is_oras,
"pull": oras_pull,
"push": oras_push,
}

DEFAULT_BACKENDS = {
"skopeo": get_skopeo_backend,
"oras": get_oras_backend,
}

def save_to_oci_registry(
base_image: str,
dest_dir: str | os.PathLike,
oci_ref: str,
model_files: list[os.PathLike],
backend: str = "skopeo",
modelcard: os.PathLike | None = None,
backend_registry: Optional[BackendDict] = DEFAULT_BACKENDS,
):
"""Appends a list of files to an OCI-based image.
Expand Down Expand Up @@ -133,29 +182,19 @@ def save_to_oci_registry(
"""
raise StoreError(msg) from e

local_image_path = pathlib.Path(dest_dir)

if backend == "skopeo":
from olot.backend.skopeo import is_skopeo, skopeo_pull, skopeo_push

if not is_skopeo():
msg = "skopeo is selected, but it is not present on the machine. Please validate the skopeo cli is installed and available in the PATH"
raise ValueError(msg)

skopeo_pull(base_image, local_image_path)
oci_layers_on_top(local_image_path, model_files, modelcard)
skopeo_push(dest_dir, oci_ref)

elif backend == "oras":
from olot.backend.oras_cp import is_oras, oras_pull, oras_push
if not is_oras():
msg = "oras is selected, but it is not present on the machine. Please validate the oras cli is installed and available in the PATH"
raise ValueError(msg)
if backend not in backend_registry:
msg = f"'{backend}' is not an available backend to use. Available backends: {backend_registry.keys()}"
raise ValueError(msg)

# Fetching the backend definition can throw an error, but it should bubble up as it has the appropriate messaging
backend_def = backend_registry[backend]()

oras_pull(base_image, local_image_path)
oci_layers_on_top(local_image_path, model_files, modelcard)
oras_push(local_image_path, oci_ref)
if not backend_def["available"]():
msg = f"Backend '{backend}' is selected, but not available on the system. Ensure the dependencies for '{backend}' are installed in your environment."
raise ValueError(msg)

else:
msg = f"Invalid backend chosen: '{backend}'"
raise StoreError(msg)
local_image_path = Path(dest_dir)
backend_def["pull"](base_image, local_image_path)
oci_layers_on_top(local_image_path, model_files, modelcard)
backend_def["push"](local_image_path, oci_ref)

0 comments on commit 4554ce2

Please sign in to comment.