diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 376081aa5..fb7c27817 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -12,6 +12,8 @@ repos:
       - id: end-of-file-fixer
       - id: mixed-line-ending
       - id: trailing-whitespace
+      - id: debug-statements
+        exclude: tests/runners/test_model_signatures.py
   - repo: https://github.com/astral-sh/ruff-pre-commit
     # Ruff version.
     rev: v0.11.4
diff --git a/clarifai/cli/model.py b/clarifai/cli/model.py
index 9a2303b10..019234bf9 100644
--- a/clarifai/cli/model.py
+++ b/clarifai/cli/model.py
@@ -670,6 +670,8 @@ def predict(
         raise ValueError(
             "Either --model_id & --user_id & --app_id or --model_url must be provided."
         )
+
+    model_kwargs = {}
     if compute_cluster_id or nodepool_id or deployment_id:
         if (
             sum(
@@ -684,15 +686,31 @@ def predict(
             raise ValueError(
                 "Either --compute_cluster_id & --nodepool_id or --deployment_id must be provided."
             )
+        if deployment_id:
+            model_kwargs = {
+                'deployment_id': deployment_id,
+            }
+        else:
+            model_kwargs = {
+                'compute_cluster_id': compute_cluster_id,
+                'nodepool_id': nodepool_id,
+            }
+
     if model_url:
-        model = Model(url=model_url, pat=ctx.obj['pat'], base_url=ctx.obj['base_url'])
+        model = Model(
+            url=model_url,
+            pat=ctx.obj.current.pat,
+            base_url=ctx.obj.current.api_base,
+            **model_kwargs,
+        )
     else:
         model = Model(
             model_id=model_id,
             user_id=user_id,
             app_id=app_id,
-            pat=ctx.obj['pat'],
-            base_url=ctx.obj['base_url'],
+            pat=ctx.obj.current.pat,
+            base_url=ctx.obj.current.api_base,
+            **model_kwargs,
         )
 
     if inference_params:
@@ -704,9 +722,6 @@ def predict(
         model_prediction = model.predict_by_filepath(
             filepath=file_path,
             input_type=input_type,
-            compute_cluster_id=compute_cluster_id,
-            nodepool_id=nodepool_id,
-            deployment_id=deployment_id,
             inference_params=inference_params,
             output_config=output_config,
         )
@@ -714,9 +729,6 @@ def predict(
         model_prediction = model.predict_by_url(
             url=url,
             input_type=input_type,
-            compute_cluster_id=compute_cluster_id,
-            nodepool_id=nodepool_id,
-            deployment_id=deployment_id,
             inference_params=inference_params,
             output_config=output_config,
         )
@@ -725,9 +737,6 @@ def predict(
         model_prediction = model.predict_by_bytes(
             input_bytes=bytes,
             input_type=input_type,
-            compute_cluster_id=compute_cluster_id,
-            nodepool_id=nodepool_id,
-            deployment_id=deployment_id,
             inference_params=inference_params,
             output_config=output_config,
         )  ## TO DO: Add support for input_id