Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ repos:
- id: end-of-file-fixer
- id: mixed-line-ending
- id: trailing-whitespace
- id: debug-statements
exclude: tests/runners/test_model_signatures.py
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.11.4
Expand Down
33 changes: 21 additions & 12 deletions clarifai/cli/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,8 @@ def predict(
raise ValueError(
"Either --model_id & --user_id & --app_id or --model_url must be provided."
)

model_kwargs = {}
if compute_cluster_id or nodepool_id or deployment_id:
if (
sum(
Expand All @@ -684,15 +686,31 @@ def predict(
raise ValueError(
"Either --compute_cluster_id & --nodepool_id or --deployment_id must be provided."
)
if deployment_id:
model_kwargs = {
'deployment_id': deployment_id,
}
else:
model_kwargs = {
'compute_cluster_id': compute_cluster_id,
'nodepool_id': nodepool_id,
}

if model_url:
model = Model(url=model_url, pat=ctx.obj['pat'], base_url=ctx.obj['base_url'])
model = Model(
url=model_url,
pat=ctx.obj.current.pat,
base_url=ctx.obj.current.api_base,
**model_kwargs,
)
else:
model = Model(
model_id=model_id,
user_id=user_id,
app_id=app_id,
pat=ctx.obj['pat'],
base_url=ctx.obj['base_url'],
pat=ctx.obj.current.pat,
base_url=ctx.obj.current.api_base,
**model_kwargs,
Comment on lines 699 to +713
Copy link

Copilot AI May 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] The instantiation of Model in both branches (model_url and non-model_url) uses identical parameters (pat, base_url, and model_kwargs); consider refactoring to reduce duplication.

Copilot uses AI. Check for mistakes.
)

if inference_params:
Expand All @@ -704,19 +722,13 @@ def predict(
model_prediction = model.predict_by_filepath(
filepath=file_path,
input_type=input_type,
compute_cluster_id=compute_cluster_id,
nodepool_id=nodepool_id,
deployment_id=deployment_id,
inference_params=inference_params,
output_config=output_config,
)
elif url:
model_prediction = model.predict_by_url(
url=url,
input_type=input_type,
compute_cluster_id=compute_cluster_id,
nodepool_id=nodepool_id,
deployment_id=deployment_id,
inference_params=inference_params,
output_config=output_config,
)
Expand All @@ -725,9 +737,6 @@ def predict(
model_prediction = model.predict_by_bytes(
input_bytes=bytes,
input_type=input_type,
compute_cluster_id=compute_cluster_id,
nodepool_id=nodepool_id,
deployment_id=deployment_id,
inference_params=inference_params,
output_config=output_config,
) ## TO DO: Add support for input_id
Expand Down