diff --git a/inference/core/entities/types.py b/inference/core/entities/types.py index 16611627d7..b3d9c95f65 100644 --- a/inference/core/entities/types.py +++ b/inference/core/entities/types.py @@ -3,3 +3,4 @@ TaskType = str ModelType = str WorkspaceID = str +ModelID = str \ No newline at end of file diff --git a/inference/core/models/roboflow.py b/inference/core/models/roboflow.py index a6e9dfc8de..74cc4bd920 100644 --- a/inference/core/models/roboflow.py +++ b/inference/core/models/roboflow.py @@ -55,6 +55,7 @@ ModelEndpointType, get_from_url, get_roboflow_model_data, + get_roboflow_workspace, ) from inference.core.utils.image_utils import load_image from inference.core.utils.onnx import get_onnxruntime_execution_providers @@ -116,7 +117,13 @@ def __init__( self.metrics = {"num_inferences": 0, "avg_inference_time": 0.0} self.api_key = api_key if api_key else API_KEY model_id = resolve_roboflow_model_alias(model_id=model_id) - self.dataset_id, self.version_id = model_id.split("/") + # TODO: + # Is this really all we had to do here? I think we may not even need it. + if "/" in model_id: + self.dataset_id, self.version_id = model_id.split("/") + else: + self.model_id = model_id + # Model ID is only unique for a workspace self.endpoint = model_id self.device_id = GLOBAL_DEVICE_ID self.cache_dir = os.path.join(cache_dir_root, self.endpoint) @@ -274,10 +281,19 @@ def download_model_artifacts_from_roboflow_api(self) -> None: "Could not find `model` key in roboflow API model description response." ) if "environment" not in api_data: - raise ModelArtefactError( - "Could not find `environment` key in roboflow API model description response."
- ) - environment = get_from_url(api_data["environment"]) + # Create default environment if not provided + environment = { + "PREPROCESSING": api_data.get("preprocessing", {}), + "MULTICLASS": api_data.get("multilabel", False), + # don't think we actually need this + "MODEL_NAME": api_data.get("modelName", ""), + + # CLASS_MAP might be the only other thing that we would need + # "CLASS_MAP": api_data.get("classes", {}), + } + else: + # TODO: do we need to load the environment from the url or can we safely remove? + environment = get_from_url(api_data["environment"]) model_weights_response = get_from_url(api_data["model"], json_response=False) save_bytes_in_cache( content=model_weights_response.content, @@ -308,6 +324,7 @@ def load_model_artifacts_from_cache(self) -> None: model_id=self.endpoint, object_pairs_hook=OrderedDict, ) + if "class_names.txt" in infer_bucket_files: self.class_names = load_text_file_from_cache( file="class_names.txt", diff --git a/inference/core/registries/roboflow.py b/inference/core/registries/roboflow.py index b1cb05c350..43c071ac17 100644 --- a/inference/core/registries/roboflow.py +++ b/inference/core/registries/roboflow.py @@ -3,7 +3,7 @@ from inference.core.cache import cache from inference.core.devices.utils import GLOBAL_DEVICE_ID -from inference.core.entities.types import DatasetID, ModelType, TaskType, VersionID +from inference.core.entities.types import ModelType, TaskType from inference.core.env import LAMBDA, MODEL_CACHE_DIR from inference.core.exceptions import ( MissingApiKeyError, @@ -90,41 +90,51 @@ def get_model_type( """ model_id = resolve_roboflow_model_alias(model_id=model_id) dataset_id, version_id = get_model_id_chunks(model_id=model_id) + lock_key, cache_path = determine_cache_paths(dataset_or_model_id=dataset_id, version_id=version_id) + if dataset_id in GENERIC_MODELS: logger.debug(f"Loading generic model: {dataset_id}.") return GENERIC_MODELS[dataset_id] + cached_metadata = get_model_metadata_from_cache( -
dataset_id=dataset_id, version_id=version_id + cache_path=cache_path, lock_key=lock_key ) if cached_metadata is not None: return cached_metadata[0], cached_metadata[1] + + + # This path will never be executed for a model ID if version_id == STUB_VERSION_ID: if api_key is None: raise MissingApiKeyError( "Stub model version provided but no API key was provided. API key is required to load stub models." ) workspace_id = get_roboflow_workspace(api_key=api_key) + project_task_type = get_roboflow_dataset_type( api_key=api_key, workspace_id=workspace_id, dataset_id=dataset_id ) model_type = "stub" save_model_metadata_in_cache( - dataset_id=dataset_id, - version_id=version_id, + cache_path=cache_path, + lock_key=lock_key, project_task_type=project_task_type, model_type=model_type, + # TODO: do we need to save the workspace_id here/for the cache path to be unique? ) return project_task_type, model_type + api_data = get_roboflow_model_data( api_key=api_key, model_id=model_id, endpoint_type=ModelEndpointType.ORT, device_id=GLOBAL_DEVICE_ID, ).get("ort") + if api_data is None: raise ModelArtefactError("Error loading model artifacts from Roboflow API.") # some older projects do not have type field - hence defaulting - project_task_type = api_data.get("type", "object-detection") + project_task_type = api_data.get("taskType", "object-detection") model_type = api_data.get("modelType") if model_type is None or model_type == "ort": # some very old model versions do not have modelType reported - and API respond in a generic way - @@ -133,46 +143,46 @@ def get_model_type( if model_type is None or project_task_type is None: raise ModelArtefactError("Error loading model artifacts from Roboflow API.") save_model_metadata_in_cache( - dataset_id=dataset_id, - version_id=version_id, + cache_path=cache_path, + lock_key=lock_key, project_task_type=project_task_type, model_type=model_type, ) - return project_task_type, model_type +def determine_cache_paths(dataset_or_model_id: str, version_id:
Optional[str]) -> Tuple[str, str]: + if dataset_or_model_id and version_id: + # It's a dataset/version ID + lock_key = f"lock:metadata:dataset:{dataset_or_model_id}:{version_id}" + cache_path = construct_dataset_version_cache_path(dataset_or_model_id, version_id) + else: + # It's a model ID + lock_key = f"lock:metadata:model:{dataset_or_model_id}" + cache_path = construct_model_id_cache_path(dataset_or_model_id) + + return lock_key, cache_path def get_model_metadata_from_cache( - dataset_id: str, version_id: str + cache_path: str, + lock_key: str ) -> Optional[Tuple[TaskType, ModelType]]: if LAMBDA: - return _get_model_metadata_from_cache( - dataset_id=dataset_id, version_id=version_id - ) - with cache.lock( - f"lock:metadata:{dataset_id}:{version_id}", expire=CACHE_METADATA_LOCK_TIMEOUT - ): - return _get_model_metadata_from_cache( - dataset_id=dataset_id, version_id=version_id - ) + return _get_model_metadata_from_cache(cache_path=cache_path) + + with cache.lock(lock_key, expire=CACHE_METADATA_LOCK_TIMEOUT): + return _get_model_metadata_from_cache(cache_path=cache_path) - -def _get_model_metadata_from_cache( - dataset_id: str, version_id: str -) -> Optional[Tuple[TaskType, ModelType]]: - model_type_cache_path = construct_model_type_cache_path( - dataset_id=dataset_id, version_id=version_id - ) - if not os.path.isfile(model_type_cache_path): +def _get_model_metadata_from_cache(cache_path: str) -> Optional[Tuple[TaskType, ModelType]]: + if not os.path.isfile(cache_path): return None try: - model_metadata = read_json(path=model_type_cache_path) + model_metadata = read_json(path=cache_path) if model_metadata_content_is_invalid(content=model_metadata): return None return model_metadata[PROJECT_TASK_TYPE_KEY], model_metadata[MODEL_TYPE_KEY] except ValueError as e: logger.warning( - f"Could not load model description from cache under path: {model_type_cache_path} - decoding issue: {e}." 
+ f"Could not load model description from cache under path: {cache_path} - decoding issue: {e}." ) return None @@ -193,49 +203,44 @@ def model_metadata_content_is_invalid(content: Optional[Union[list, dict]]) -> b def save_model_metadata_in_cache( - dataset_id: DatasetID, - version_id: VersionID, + cache_path: str, + lock_key: str, project_task_type: TaskType, model_type: ModelType, ) -> None: if LAMBDA: _save_model_metadata_in_cache( - dataset_id=dataset_id, - version_id=version_id, + cache_path=cache_path, project_task_type=project_task_type, model_type=model_type, ) return None - with cache.lock( - f"lock:metadata:{dataset_id}:{version_id}", expire=CACHE_METADATA_LOCK_TIMEOUT - ): + + with cache.lock(lock_key, expire=CACHE_METADATA_LOCK_TIMEOUT): _save_model_metadata_in_cache( - dataset_id=dataset_id, - version_id=version_id, + cache_path=cache_path, project_task_type=project_task_type, model_type=model_type, ) return None - def _save_model_metadata_in_cache( - dataset_id: DatasetID, - version_id: VersionID, + cache_path: str, project_task_type: TaskType, model_type: ModelType, ) -> None: - model_type_cache_path = construct_model_type_cache_path( - dataset_id=dataset_id, version_id=version_id - ) metadata = { PROJECT_TASK_TYPE_KEY: project_task_type, MODEL_TYPE_KEY: model_type, } dump_json( - path=model_type_cache_path, content=metadata, allow_override=True, indent=4 + path=cache_path, content=metadata, allow_override=True, indent=4 ) +def construct_model_id_cache_path(model_id: str) -> str: + """Constructs the cache path for a given model ID.""" + return os.path.join(MODEL_CACHE_DIR, "models", model_id, "model_type.json") -def construct_model_type_cache_path(dataset_id: str, version_id: str) -> str: - cache_dir = os.path.join(MODEL_CACHE_DIR, dataset_id, version_id) - return os.path.join(cache_dir, "model_type.json") +def construct_dataset_version_cache_path(dataset_id: str, version_id: str) -> str: + """Constructs the cache path for a given dataset ID and 
version ID.""" + return os.path.join(MODEL_CACHE_DIR, dataset_id, version_id, "model_type.json") \ No newline at end of file diff --git a/inference/core/roboflow_api.py b/inference/core/roboflow_api.py index 41afdbdd72..669578aac0 100644 --- a/inference/core/roboflow_api.py +++ b/inference/core/roboflow_api.py @@ -225,13 +225,17 @@ def get_roboflow_model_data( ("nocache", "true"), ("device", device_id), ("dynamic", "true"), + ("type", endpoint_type.value), + ("model", model_id), ] if api_key is not None: params.append(("api_key", api_key)) api_url = _add_params_to_url( - url=f"{API_BASE_URL}/{endpoint_type.value}/{model_id}", + url=f"{API_BASE_URL}/getWeights", params=params, ) + print("api_url", api_url) + api_data = _get_from_url(url=api_url) cache.set( api_data_cache_key, @@ -598,7 +602,9 @@ def get_from_url( def _get_from_url(url: str, json_response: bool = True) -> Union[Response, dict]: response = requests.get(wrap_url(url)) + api_key_safe_raise_for_status(response=response) + if json_response: return response.json() return response diff --git a/inference/core/utils/roboflow.py b/inference/core/utils/roboflow.py index 0fe05c2f35..b643b64762 100644 --- a/inference/core/utils/roboflow.py +++ b/inference/core/utils/roboflow.py @@ -1,11 +1,35 @@ -from typing import Tuple +from typing import Optional, Tuple, Union from inference.core.entities.types import DatasetID, VersionID from inference.core.exceptions import InvalidModelIDError -def get_model_id_chunks(model_id: str) -> Tuple[DatasetID, VersionID]: +def get_model_id_chunks( + model_id: str, +) -> Union[Tuple[DatasetID, VersionID], Tuple[str, None]]: + """Parse a model ID into its components. 
+ + Args: + model_id (str): The model identifier, either in format "dataset/version" + or a plain string for the new model IDs + + Returns: + Union[Tuple[DatasetID, VersionID], Tuple[str, None]]: + For traditional IDs: (dataset_id, version_id) + For new string IDs: (model_id, None) + + Raises: + InvalidModelIDError: If traditional model ID format is invalid + """ + if "/" not in model_id: + # Handle new style model IDs that are just strings + return model_id, None + + # Handle traditional dataset/version model IDs model_id_chunks = model_id.split("/") if len(model_id_chunks) != 2: - raise InvalidModelIDError(f"Model ID: `{model_id}` is invalid.") + raise InvalidModelIDError( + f"Model ID: `{model_id}` is invalid. Expected format: 'dataset/version' or 'model_name'" + ) + return model_id_chunks[0], model_id_chunks[1] diff --git a/inference/models/transformers/transformers.py b/inference/models/transformers/transformers.py index 5aff8fc231..5b1119d106 100644 --- a/inference/models/transformers/transformers.py +++ b/inference/models/transformers/transformers.py @@ -40,6 +40,7 @@ get_from_url, get_roboflow_base_lora, get_roboflow_model_data, + get_roboflow_workspace, ) from inference.core.utils.image_utils import load_image_rgb diff --git a/tests/inference/unit_tests/core/registries/test_roboflow.py b/tests/inference/unit_tests/core/registries/test_roboflow.py index c7c2684e2a..33ea30fd46 100644 --- a/tests/inference/unit_tests/core/registries/test_roboflow.py +++ b/tests/inference/unit_tests/core/registries/test_roboflow.py @@ -21,98 +21,78 @@ @pytest.mark.parametrize("is_lambda", [False, True]) -@mock.patch.object(roboflow, "construct_model_type_cache_path") def test_get_model_metadata_from_cache_when_metadata_file_does_not_exist( - construct_model_type_cache_path_mock: MagicMock, - empty_local_dir: str, is_lambda: bool, ) -> None: - # given - construct_model_type_cache_path_mock.return_value = os.path.join( - empty_local_dir, "model_type.json" - ) - # when with 
mock.patch.object(roboflow, "LAMBDA", is_lambda): - result = get_model_metadata_from_cache(dataset_id="some", version_id="1") + result = get_model_metadata_from_cache(cache_path="model-id", lock_key="model:id") # then assert result is None @pytest.mark.parametrize("is_lambda", [False, True]) -@mock.patch.object(roboflow, "construct_model_type_cache_path") def test_get_model_metadata_from_cache_when_metadata_file_is_not_json( - construct_model_type_cache_path_mock: MagicMock, empty_local_dir: str, is_lambda: bool, ) -> None: # given metadata_path = os.path.join(empty_local_dir, "model_type.json") - construct_model_type_cache_path_mock.return_value = metadata_path with open(metadata_path, "w") as f: f.write("FOR SURE NOT JSON :)") # when with mock.patch.object(roboflow, "LAMBDA", is_lambda): - result = get_model_metadata_from_cache(dataset_id="some", version_id="1") + result = get_model_metadata_from_cache(cache_path=metadata_path, lock_key="model:id") # then assert result is None @pytest.mark.parametrize("is_lambda", [False, True]) -@mock.patch.object(roboflow, "construct_model_type_cache_path") def test_get_model_metadata_from_cache_when_metadata_file_is_empty( - construct_model_type_cache_path_mock: MagicMock, empty_local_dir: str, is_lambda: bool, ) -> None: # given metadata_path = os.path.join(empty_local_dir, "model_type.json") - construct_model_type_cache_path_mock.return_value = metadata_path with open(metadata_path, "w") as f: f.write("") # when with mock.patch.object(roboflow, "LAMBDA", is_lambda): - result = get_model_metadata_from_cache(dataset_id="some", version_id="1") + result = get_model_metadata_from_cache(cache_path=metadata_path, lock_key="model:id") # then assert result is None @pytest.mark.parametrize("is_lambda", [False, True]) -@mock.patch.object(roboflow, "construct_model_type_cache_path") def test_get_model_metadata_from_cache_when_metadata_is_invalid( - construct_model_type_cache_path_mock: MagicMock, empty_local_dir: str, is_lambda: bool, ) 
-> None: # given metadata_path = os.path.join(empty_local_dir, "model_type.json") - construct_model_type_cache_path_mock.return_value = metadata_path with open(metadata_path, "w") as f: f.write(json.dumps({"some": "key"})) # when with mock.patch.object(roboflow, "LAMBDA", is_lambda): - result = get_model_metadata_from_cache(dataset_id="some", version_id="1") + result = get_model_metadata_from_cache(cache_path=metadata_path, lock_key="model:id") # then assert result is None @pytest.mark.parametrize("is_lambda", [False, True]) -@mock.patch.object(roboflow, "construct_model_type_cache_path") def test_get_model_metadata_from_cache_when_metadata_is_valid( - construct_model_type_cache_path_mock: MagicMock, empty_local_dir: str, is_lambda: bool, ) -> None: # given metadata_path = os.path.join(empty_local_dir, "model_type.json") - construct_model_type_cache_path_mock.return_value = metadata_path with open(metadata_path, "w") as f: f.write( json.dumps( @@ -125,7 +105,7 @@ def test_get_model_metadata_from_cache_when_metadata_is_valid( # when with mock.patch.object(roboflow, "LAMBDA", is_lambda): - result = get_model_metadata_from_cache(dataset_id="some", version_id="1") + result = get_model_metadata_from_cache(cache_path=metadata_path, lock_key="model:id") # then assert result == ("object-detection", "yolov8n") @@ -172,21 +152,19 @@ def test_model_metadata_content_is_invalid_when_task_type_is_missing() -> None: @pytest.mark.parametrize("is_lambda", [False, True]) -@mock.patch.object(roboflow, "construct_model_type_cache_path") def test_save_model_metadata_in_cache( - construct_model_type_cache_path_mock: MagicMock, empty_local_dir: str, is_lambda: bool, ) -> None: # given metadata_path = os.path.join(empty_local_dir, "model_type.json") - construct_model_type_cache_path_mock.return_value = metadata_path + lock_key = "lock:metadata:test" # when with mock.patch.object(roboflow, "LAMBDA", is_lambda): save_model_metadata_in_cache( - dataset_id="some", - version_id="1", + 
cache_path=metadata_path, + lock_key=lock_key, project_task_type="instance-segmentation", model_type="yolov8l", ) @@ -196,19 +174,11 @@ def test_save_model_metadata_in_cache( # then assert result["model_type"] == "yolov8l" assert result["project_task_type"] == "instance-segmentation" - construct_model_type_cache_path_mock.assert_called_once_with( - dataset_id="some", version_id="1" - ) -@mock.patch.object(roboflow, "construct_model_type_cache_path") -def test_get_model_type_when_cache_is_utilised( - construct_model_type_cache_path_mock: MagicMock, - empty_local_dir: str, -) -> None: +def test_get_model_type_when_cache_is_utilised(empty_local_dir: str) -> None: # given metadata_path = os.path.join(empty_local_dir, "model_type.json") - construct_model_type_cache_path_mock.return_value = metadata_path with open(metadata_path, "w") as f: f.write( json.dumps( @@ -220,12 +190,10 @@ def test_get_model_type_when_cache_is_utilised( ) # when - result = get_model_type(model_id="some/1", api_key="my_api_key") + with mock.patch.object(roboflow, "determine_cache_paths", return_value=("test_lock", metadata_path)): + result = get_model_type(model_id="some/1", api_key="my_api_key") # then - construct_model_type_cache_path_mock.assert_called_once_with( - dataset_id="some", version_id="1" - ) assert result == ("object-detection", "yolov8n") @@ -249,15 +217,23 @@ def test_get_model_type_when_generic_model_is_utilised( @mock.patch.object(roboflow, "get_roboflow_model_data") -@mock.patch.object(roboflow, "construct_model_type_cache_path") +@mock.patch.object(roboflow, "get_roboflow_workspace") +@pytest.mark.parametrize( + "model_id", + [ + "some/1", + "model-id1", + ], +) def test_get_model_type_when_roboflow_api_is_called_for_specific_model( - construct_model_type_cache_path_mock: MagicMock, + get_roboflow_workspace_mock: MagicMock, get_roboflow_model_data_mock: MagicMock, empty_local_dir: str, + model_id: str, ) -> None: # given metadata_path = os.path.join(empty_local_dir, 
"model_type.json") - construct_model_type_cache_path_mock.return_value = metadata_path + get_roboflow_workspace_mock.return_value = "my_workspace" get_roboflow_model_data_mock.return_value = { "ort": { "type": "object-detection", @@ -266,10 +242,11 @@ def test_get_model_type_when_roboflow_api_is_called_for_specific_model( } # when - result = get_model_type( - model_id="some/1", - api_key="my_api_key", - ) + with mock.patch.object(roboflow, "determine_cache_paths", return_value=("test_lock", metadata_path)): + result = get_model_type( + model_id=model_id, + api_key="my_api_key", + ) # then assert result == ("object-detection", "yolov8n") @@ -279,34 +256,36 @@ def test_get_model_type_when_roboflow_api_is_called_for_specific_model( assert persisted_metadata["project_task_type"] == "object-detection" get_roboflow_model_data_mock.assert_called_once_with( api_key="my_api_key", - model_id="some/1", + model_id=model_id, endpoint_type=ModelEndpointType.ORT, device_id=GLOBAL_DEVICE_ID, + workspace_id="my_workspace", ) @mock.patch.object(roboflow, "get_roboflow_model_data") -@mock.patch.object(roboflow, "construct_model_type_cache_path") +@mock.patch.object(roboflow, "get_roboflow_workspace") def test_get_model_type_when_roboflow_api_is_called_for_specific_model_and_model_type_specified_as_ort( - construct_model_type_cache_path_mock: MagicMock, + get_roboflow_workspace_mock: MagicMock, get_roboflow_model_data_mock: MagicMock, empty_local_dir: str, ) -> None: # given metadata_path = os.path.join(empty_local_dir, "model_type.json") - construct_model_type_cache_path_mock.return_value = metadata_path + get_roboflow_workspace_mock.return_value = "my_workspace" get_roboflow_model_data_mock.return_value = { "ort": { - "type": "object-detection", + "taskType": "object-detection", "modelType": "ort", } } # when - result = get_model_type( - model_id="some/1", - api_key="my_api_key", - ) + with mock.patch.object(roboflow, "determine_cache_paths", return_value=("test_lock", 
metadata_path)): + result = get_model_type( + model_id="some/1", + api_key="my_api_key", + ) # then assert result == ("object-detection", "yolov5v2s") @@ -319,24 +298,28 @@ def test_get_model_type_when_roboflow_api_is_called_for_specific_model_and_model model_id="some/1", endpoint_type=ModelEndpointType.ORT, device_id=GLOBAL_DEVICE_ID, + workspace_id="my_workspace", ) +@mock.patch.object(roboflow, "determine_cache_paths") @mock.patch.object(roboflow, "get_roboflow_model_data") -@mock.patch.object(roboflow, "construct_model_type_cache_path") +@mock.patch.object(roboflow, "get_roboflow_workspace") def test_get_model_type_when_roboflow_api_is_called_for_specific_model_and_model_type_not_specified( - construct_model_type_cache_path_mock: MagicMock, + get_roboflow_workspace_mock: MagicMock, get_roboflow_model_data_mock: MagicMock, + determine_cache_paths_mock: MagicMock, empty_local_dir: str, ) -> None: # given metadata_path = os.path.join(empty_local_dir, "model_type.json") - construct_model_type_cache_path_mock.return_value = metadata_path + get_roboflow_workspace_mock.return_value = "my_workspace" get_roboflow_model_data_mock.return_value = { "ort": { - "type": "object-detection", + "taskType": "object-detection", } } + determine_cache_paths_mock.return_value = ("test_lock", metadata_path) # when result = get_model_type( @@ -355,20 +338,24 @@ def test_get_model_type_when_roboflow_api_is_called_for_specific_model_and_model model_id="some/1", endpoint_type=ModelEndpointType.ORT, device_id=GLOBAL_DEVICE_ID, + workspace_id="my_workspace", ) +@mock.patch.object(roboflow, "determine_cache_paths") @mock.patch.object(roboflow, "get_roboflow_model_data") -@mock.patch.object(roboflow, "construct_model_type_cache_path") +@mock.patch.object(roboflow, "get_roboflow_workspace") def test_get_model_type_when_roboflow_api_is_called_for_specific_model_and_project_type_not_specified( - construct_model_type_cache_path_mock: MagicMock, + get_roboflow_workspace_mock: MagicMock, 
get_roboflow_model_data_mock: MagicMock, + determine_cache_paths_mock: MagicMock, empty_local_dir: str, ) -> None: # given metadata_path = os.path.join(empty_local_dir, "model_type.json") - construct_model_type_cache_path_mock.return_value = metadata_path + get_roboflow_workspace_mock.return_value = "my_workspace" get_roboflow_model_data_mock.return_value = {"ort": {}} + determine_cache_paths_mock.return_value = ("test_lock", metadata_path) # when result = get_model_type( @@ -387,25 +374,29 @@ def test_get_model_type_when_roboflow_api_is_called_for_specific_model_and_proje model_id="some/1", endpoint_type=ModelEndpointType.ORT, device_id=GLOBAL_DEVICE_ID, + workspace_id="my_workspace", ) +@mock.patch.object(roboflow, "determine_cache_paths") @mock.patch.object(roboflow, "get_roboflow_model_data") -@mock.patch.object(roboflow, "construct_model_type_cache_path") +@mock.patch.object(roboflow, "get_roboflow_workspace") def test_get_model_type_when_roboflow_api_is_called_for_specific_model_without_api_key_for_public_model( - construct_model_type_cache_path_mock: MagicMock, + get_roboflow_workspace_mock: MagicMock, get_roboflow_model_data_mock: MagicMock, + determine_cache_paths_mock: MagicMock, empty_local_dir: str, ) -> None: # given metadata_path = os.path.join(empty_local_dir, "model_type.json") - construct_model_type_cache_path_mock.return_value = metadata_path + get_roboflow_workspace_mock.return_value = "my_workspace" get_roboflow_model_data_mock.return_value = { "ort": { - "type": "object-detection", + "taskType": "object-detection", # Updated to match new API response structure "modelType": "yolov8n", } } + determine_cache_paths_mock.return_value = ("test_lock", metadata_path) # when result = get_model_type( @@ -424,21 +415,21 @@ def test_get_model_type_when_roboflow_api_is_called_for_specific_model_without_a model_id="some/1", endpoint_type=ModelEndpointType.ORT, device_id=GLOBAL_DEVICE_ID, + workspace_id="my_workspace", ) - -@mock.patch.object(roboflow, 
"get_roboflow_workspace") +@mock.patch.object(roboflow, "determine_cache_paths") @mock.patch.object(roboflow, "get_roboflow_dataset_type") -@mock.patch.object(roboflow, "construct_model_type_cache_path") +@mock.patch.object(roboflow, "get_roboflow_workspace") def test_get_model_type_when_roboflow_api_is_called_for_mock( - construct_model_type_cache_path_mock: MagicMock, - get_roboflow_dataset_type_mock: MagicMock, get_roboflow_workspace_mock: MagicMock, + get_roboflow_dataset_type_mock: MagicMock, + determine_cache_paths_mock: MagicMock, empty_local_dir: str, ) -> None: # given metadata_path = os.path.join(empty_local_dir, "model_type.json") - construct_model_type_cache_path_mock.return_value = metadata_path + determine_cache_paths_mock.return_value = ("test_lock", metadata_path) get_roboflow_dataset_type_mock.return_value = "object-detection" get_roboflow_workspace_mock.return_value = "my_workspace" @@ -461,8 +452,17 @@ def test_get_model_type_when_roboflow_api_is_called_for_mock( ) get_roboflow_workspace_mock.assert_called_once_with(api_key="my_api_key") +@mock.patch.object(roboflow, "get_model_id_chunks") +@mock.patch.object(roboflow, "get_roboflow_workspace") +def test_get_model_type_when_roboflow_api_is_called_for_mock_without_api_key( + get_roboflow_workspace_mock: MagicMock, + get_model_id_chunks_mock: MagicMock, +) -> None: + # given + get_model_id_chunks_mock.return_value = ("some", "0") + get_roboflow_workspace_mock.return_value = "workspace-id" -def test_get_model_type_when_roboflow_api_is_called_for_mock_without_api_key() -> None: + # when with pytest.raises(MissingApiKeyError): _ = get_model_type( model_id="some/0", diff --git a/tests/inference/unit_tests/core/utils/test_roboflow.py b/tests/inference/unit_tests/core/utils/test_roboflow.py index 791c01c986..f4376d7643 100644 --- a/tests/inference/unit_tests/core/utils/test_roboflow.py +++ b/tests/inference/unit_tests/core/utils/test_roboflow.py @@ -6,16 +6,21 @@ from inference.core.utils.roboflow import 
get_model_id_chunks -@pytest.mark.parametrize("value", ["some", "some/2/invalid", "another-2"]) +@pytest.mark.parametrize("value", ["contains/2/slashes", "some/model/id/with/many/slashes"]) def test_get_model_id_chunks_when_invalid_input_given(value: Any) -> None: # when with pytest.raises(InvalidModelIDError): _ = get_model_id_chunks(model_id=value) -def test_get_model_id_chunks_when_valid_input_given() -> None: +@pytest.mark.parametrize("model_id, expected", [ + ("some/1", ("some", "1")), + ("modelid123", ("modelid123", None)), + ("model-id-dashes", ("model-id-dashes", None)), +]) +def test_get_model_id_chunks_with_various_valid_inputs(model_id: str, expected: tuple) -> None: # when - result = get_model_id_chunks("some/1") + result = get_model_id_chunks(model_id) # then - assert result == ("some", "1") + assert result == expected