|
| 1 | +import json |
1 | 2 | import os |
2 | 3 | import platform |
3 | 4 | import random |
@@ -145,6 +146,7 @@ def __init__(self, model: str, model_store_path: str): |
145 | 146 | self._model_type: str |
146 | 147 | self._model_name, self._model_tag, self._model_organization = self.extract_model_identifiers() |
147 | 148 | self._model_type = type(self).__name__.lower() |
| 149 | + self.artifact = False |
148 | 150 |
|
149 | 151 | self._model_store_path: str = model_store_path |
150 | 152 | self._model_store: Optional[ModelStore] = None |
@@ -200,6 +202,8 @@ def _get_entry_model_path(self, use_container: bool, should_generate: bool, dry_ |
200 | 202 |
|
201 | 203 | if self.model_type == 'oci': |
202 | 204 | if use_container or should_generate: |
| 205 | + if getattr(self, "artifact", False): |
| 206 | + return os.path.join(MNT_DIR, self.artifact_name()) |
203 | 207 | return os.path.join(MNT_DIR, 'model.file') |
204 | 208 | else: |
205 | 209 | return f"oci://{self.model}" |
@@ -345,9 +349,10 @@ def exec_model_in_container(self, cmd_args, args): |
345 | 349 | def setup_mounts(self, args): |
346 | 350 | if args.dryrun: |
347 | 351 | return |
| 352 | + |
348 | 353 | if self.model_type == 'oci': |
349 | 354 | if self.engine.use_podman: |
350 | | - mount_cmd = f"--mount=type=image,src={self.model},destination={MNT_DIR},subpath=/models,rw=false" |
| 355 | + mount_cmd = self.mount_cmd() |
351 | 356 | elif self.engine.use_docker: |
352 | 357 | output_filename = self._get_entry_model_path(args.container, True, args.dryrun) |
353 | 358 | volume = populate_volume_from_image(self, os.path.basename(output_filename)) |
@@ -651,40 +656,48 @@ def inspect( |
651 | 656 | as_json: bool = False, |
652 | 657 | dryrun: bool = False, |
653 | 658 | ) -> None: |
| 659 | + print(self.get_inspect(show_all, show_all_metadata, get_field, dryrun, as_json=as_json)) |
| 660 | + |
| 661 | + def get_inspect( |
| 662 | + self, |
| 663 | + show_all: bool = False, |
| 664 | + show_all_metadata: bool = False, |
| 665 | + get_field: str = "", |
| 666 | + dryrun: bool = False, |
| 667 | + as_json: bool = False, |
| 668 | + ) -> Any: |
654 | 669 | model_name = self.filename |
655 | 670 | model_registry = self.type.lower() |
656 | | - model_path = self._get_inspect_model_path(dryrun) |
657 | | - |
| 671 | + model_path = self._get_entry_model_path(False, False, dryrun) |
658 | 672 | if GGUFInfoParser.is_model_gguf(model_path): |
659 | 673 | if not show_all_metadata and get_field == "": |
660 | 674 | gguf_info: GGUFModelInfo = GGUFInfoParser.parse(model_name, model_registry, model_path) |
661 | | - print(gguf_info.serialize(json=as_json, all=show_all)) |
662 | | - return |
| 675 | + return gguf_info.serialize(json=as_json, all=show_all) |
663 | 676 |
|
664 | 677 | metadata = GGUFInfoParser.parse_metadata(model_path) |
665 | 678 | if show_all_metadata: |
666 | | - print(metadata.serialize(json=as_json)) |
667 | | - return |
| 679 | + return metadata.serialize(json=as_json) |
668 | 680 | elif get_field != "": # If a specific field is requested, print only that field |
669 | 681 | field_value = metadata.get(get_field) |
670 | 682 | if field_value is None: |
671 | 683 | raise KeyError(f"Field '{get_field}' not found in GGUF model metadata") |
672 | | - print(field_value) |
673 | | - return |
| 684 | + return field_value |
674 | 685 |
|
675 | 686 | if SafetensorInfoParser.is_model_safetensor(model_name): |
676 | 687 | safetensor_info: SafetensorModelInfo = SafetensorInfoParser.parse(model_name, model_registry, model_path) |
677 | | - print(safetensor_info.serialize(json=as_json, all=show_all)) |
678 | | - return |
| 688 | + return safetensor_info.serialize(json=as_json, all=show_all) |
679 | 689 |
|
680 | | - print(ModelInfoBase(model_name, model_registry, model_path).serialize(json=as_json)) |
| 690 | + return ModelInfoBase(model_name, model_registry, model_path).serialize(json=as_json) |
681 | 691 |
|
682 | | - def print_pull_message(self, model_name): |
| 692 | + def print_pull_message(self, model_name) -> None: |
683 | 693 | model_name = trim_model_name(model_name) |
684 | 694 | # Write messages to stderr |
685 | 695 | perror(f"Downloading {model_name} ...") |
686 | 696 | perror(f"Trying to pull {model_name} ...") |
687 | 697 |
|
| 698 | + def is_artifact(self) -> bool: |
| 699 | + return False |
| 700 | + |
688 | 701 |
|
689 | 702 | def compute_ports(exclude: list[str] | None = None) -> list[int]: |
690 | 703 | exclude = exclude and set(map(int, exclude)) or set() |
|
0 commit comments