Skip to content
114 changes: 64 additions & 50 deletions centml/cli/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,26 @@
from centml.sdk.api import get_centml_client


# convert deployment type enum to a user friendly name
depl_type_to_name_map = {
DeploymentType.INFERENCE: 'inference',
DeploymentType.COMPUTE: 'compute',
DeploymentType.COMPILATION: 'compilation',
DeploymentType.INFERENCE_V2: 'inference',
DeploymentType.COMPUTE_V2: 'compute',
DeploymentType.CSERVE: 'cserve',
DeploymentType.CSERVE_V2: 'cserve',
DeploymentType.RAG: 'rag',
DeploymentType.INFERENCE: "inference",
DeploymentType.COMPUTE: "compute",
DeploymentType.COMPILATION: "compilation",
DeploymentType.INFERENCE_V2: "inference",
DeploymentType.INFERENCE_V3: "inference",
DeploymentType.COMPUTE_V2: "compute",
# For user, they are all cserve.
DeploymentType.CSERVE: "cserve",
DeploymentType.CSERVE_V2: "cserve",
DeploymentType.CSERVE_V3: "cserve",
DeploymentType.RAG: "rag",
}
# use latest type to for user requests
depl_name_to_type_map = {
'inference': DeploymentType.INFERENCE_V2,
'cserve': DeploymentType.CSERVE_V2,
'compute': DeploymentType.COMPUTE_V2,
'rag': DeploymentType.RAG,
"inference": DeploymentType.INFERENCE_V3,
"cserve": DeploymentType.CSERVE_V3,
"compute": DeploymentType.COMPUTE_V2,
"rag": DeploymentType.RAG,
}


Expand Down Expand Up @@ -56,6 +61,21 @@ def _format_ssh_key(ssh_key):
return ssh_key[:32] + "..."


def _get_replica_info(deployment):
"""Extract replica information handling V2/V3 field differences"""
# Check actual deployment object fields rather than depl_type
# since unified get_cserve() can return either V2 or V3 objects
if hasattr(deployment, 'min_replicas'):
# V3 deployment response object
return {"min": deployment.min_replicas, "max": deployment.max_replicas}
elif hasattr(deployment, 'min_scale'):
# V2 deployment response object
return {"min": deployment.min_scale, "max": deployment.max_scale}
else:
# Fallback - shouldn't happen
return {"min": "N/A", "max": "N/A"}


def _get_ready_status(cclient, deployment):
api_status = deployment.status
service_status = (
Expand Down Expand Up @@ -121,12 +141,12 @@ def get(type, id):
with get_centml_client() as cclient:
depl_type = depl_name_to_type_map[type]

if depl_type == DeploymentType.INFERENCE_V2:
deployment = cclient.get_inference(id)
if depl_type in [DeploymentType.INFERENCE_V2, DeploymentType.INFERENCE_V3]:
deployment = cclient.get_inference(id) # handles both V2 and V3
elif depl_type == DeploymentType.COMPUTE_V2:
deployment = cclient.get_compute(id)
elif depl_type == DeploymentType.CSERVE_V2:
deployment = cclient.get_cserve(id)
elif depl_type in [DeploymentType.CSERVE_V2, DeploymentType.CSERVE_V3]:
deployment = cclient.get_cserve(id) # handles both V2 and V3
else:
sys.exit("Please enter correct deployment type")

Expand All @@ -150,21 +170,18 @@ def get(type, id):
)

click.echo("Additional deployment configurations:")
if depl_type == DeploymentType.INFERENCE_V2:
click.echo(
tabulate(
[
("Image", deployment.image_url),
("Container port", deployment.container_port),
("Healthcheck", deployment.healthcheck or "/"),
("Replicas", {"min": deployment.min_scale, "max": deployment.max_scale}),
("Environment variables", deployment.env_vars or "None"),
("Max concurrency", deployment.concurrency or "None"),
],
tablefmt="rounded_outline",
disable_numparse=True,
)
)
if depl_type in [DeploymentType.INFERENCE_V2, DeploymentType.INFERENCE_V3]:
replica_info = _get_replica_info(deployment)
display_rows = [
("Image", deployment.image_url),
("Container port", deployment.container_port),
("Healthcheck", deployment.healthcheck or "/"),
("Replicas", replica_info),
("Environment variables", deployment.env_vars or "None"),
("Max concurrency", deployment.concurrency or "None"),
]

click.echo(tabulate(display_rows, tablefmt="rounded_outline", disable_numparse=True))
elif depl_type == DeploymentType.COMPUTE_V2:
click.echo(
tabulate(
Expand All @@ -173,25 +190,22 @@ def get(type, id):
disable_numparse=True,
)
)
elif depl_type == DeploymentType.CSERVE_V2:
click.echo(
tabulate(
[
("Hugging face model", deployment.recipe.model),
(
"Parallelism",
{
"tensor": deployment.recipe.additional_properties['tensor_parallel_size'],
"pipeline": deployment.recipe.additional_properties['pipeline_parallel_size'],
},
),
("Replicas", {"min": deployment.min_scale, "max": deployment.max_scale}),
("Max concurrency", deployment.concurrency or "None"),
],
tablefmt="rounded_outline",
disable_numparse=True,
)
)
elif depl_type in [DeploymentType.CSERVE_V2, DeploymentType.CSERVE_V3]:
replica_info = _get_replica_info(deployment)
display_rows = [
("Hugging face model", deployment.recipe.model),
(
"Parallelism",
{
"tensor": deployment.recipe.additional_properties.get("tensor_parallel_size", "N/A"),
"pipeline": deployment.recipe.additional_properties.get("pipeline_parallel_size", "N/A"),
},
),
("Replicas", replica_info),
("Max concurrency", deployment.concurrency or "None"),
]

click.echo(tabulate(display_rows, tablefmt="rounded_outline", disable_numparse=True))


@click.command(help="Delete a deployment")
Expand Down
168 changes: 159 additions & 9 deletions centml/sdk/api.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from contextlib import contextmanager
from typing import Union

import platform_api_python_client
from platform_api_python_client import (
DeploymentType,
DeploymentStatus,
CreateInferenceDeploymentRequest,
CreateInferenceV3DeploymentRequest,
CreateComputeDeploymentRequest,
CreateCServeV2DeploymentRequest,
CreateCServeV3DeploymentRequest,
ApiException,
Metric,
)

Expand All @@ -27,31 +31,163 @@ def get_status(self, id):
return self._api.get_deployment_status_deployments_status_deployment_id_get(id)

def get_inference(self, id):
return self._api.get_inference_deployment_deployments_inference_deployment_id_get(id)
"""Get Inference deployment details - automatically handles both V2 and V3 deployments"""
# Try V3 first (recommended), fallback to V2 if deployment is V2
try:
return self._api.get_inference_v3_deployment_deployments_inference_v3_deployment_id_get(id)
except ApiException as e:
# If V3 fails with 404 or similar, try V2
if e.status in [404, 400]: # Deployment might be V2 or endpoint not found
try:
return self._api.get_inference_deployment_deployments_inference_deployment_id_get(id)
except ApiException as v2_error:
# If both fail, raise the original V3 error as it's more likely to be the real issue
raise e from v2_error
else:
# For other errors (auth, network, etc.), raise immediately
raise

def get_compute(self, id):
return self._api.get_compute_deployment_deployments_compute_deployment_id_get(id)

def get_cserve(self, id):
return self._api.get_cserve_v2_deployment_deployments_cserve_v2_deployment_id_get(id)

def create_inference(self, request: CreateInferenceDeploymentRequest):
"""Get CServe deployment details - automatically handles both V2 and V3 deployments"""
# Try V3 first (recommended), fallback to V2 if deployment is V2
try:
return self._api.get_cserve_v3_deployment_deployments_cserve_v3_deployment_id_get(id)
except ApiException as e:
# If V3 fails with 404 or similar, try V2
if e.status in [404, 400]: # Deployment might be V2 or endpoint not found
try:
return self._api.get_cserve_v2_deployment_deployments_cserve_v2_deployment_id_get(id)
except ApiException as v2_error:
# If both fail, raise the original V3 error as it's more likely to be the real issue
raise e from v2_error
else:
# For other errors (auth, network, etc.), raise immediately
raise

def create_inference(self, request: CreateInferenceV3DeploymentRequest):
return self._api.create_inference_v3_deployment_deployments_inference_v3_post(request)

def create_inference_v2(self, request: CreateInferenceDeploymentRequest):
return self._api.create_inference_deployment_deployments_inference_post(request)

def create_inference_v3(self, request: CreateInferenceV3DeploymentRequest):
return self._api.create_inference_v3_deployment_deployments_inference_v3_post(request)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we don't need these two. we should only allow creating v3 deployments

def create_compute(self, request: CreateComputeDeploymentRequest):
return self._api.create_compute_deployment_deployments_compute_post(request)

def create_cserve(self, request: CreateCServeV2DeploymentRequest):
def create_cserve(self, request: CreateCServeV3DeploymentRequest):
return self._api.create_cserve_v3_deployment_deployments_cserve_v3_post(request)

def create_cserve_v2(self, request: CreateCServeV2DeploymentRequest):
return self._api.create_cserve_v2_deployment_deployments_cserve_v2_post(request)

def update_inference(self, deployment_id: int, request: CreateInferenceDeploymentRequest):
return self._api.update_inference_deployment_deployments_inference_put(deployment_id, request)
def create_cserve_v3(self, request: CreateCServeV3DeploymentRequest):
return self._api.create_cserve_v3_deployment_deployments_cserve_v3_post(request)

def detect_inference_deployment_version(self, deployment_id: int) -> str:
"""Detect if an inference deployment is V2 or V3 by testing the specific API endpoints"""
try:
# Try V3 endpoint first
self._api.get_inference_v3_deployment_deployments_inference_v3_deployment_id_get(deployment_id)
return 'v3'
except ApiException as e:
if e.status in [404, 400]: # V3 endpoint doesn't exist for this deployment
try:
# Try V2 endpoint
self._api.get_inference_deployment_deployments_inference_deployment_id_get(deployment_id)
return 'v2'
except ApiException as exc:
# If both fail, it might not be an inference deployment or doesn't exist
raise ValueError(
f"Deployment {deployment_id} is not a valid inference deployment or does not exist"
) from exc
else:
# Other error (auth, network, etc.)
raise

def update_inference(
self, deployment_id: int, request: Union[CreateInferenceDeploymentRequest, CreateInferenceV3DeploymentRequest]
):
"""Update Inference deployment - validates request type matches deployment version"""
# Detect the deployment version
deployment_version = self.detect_inference_deployment_version(deployment_id)

# Validate request type matches deployment version
if isinstance(request, CreateInferenceV3DeploymentRequest):
if deployment_version != 'v3':
raise ValueError(
f"Deployment {deployment_id} is Inference {deployment_version.upper()}, "
f"but you provided a V3 request. Please use CreateInferenceDeploymentRequest instead."
)
return self._api.update_inference_v3_deployment_deployments_inference_v3_put(deployment_id, request)
elif isinstance(request, CreateInferenceDeploymentRequest):
if deployment_version != 'v2':
raise ValueError(
f"Deployment {deployment_id} is Inference {deployment_version.upper()}, "
f"but you provided a V2 request. Please use CreateInferenceV3DeploymentRequest instead."
)
return self._api.update_inference_deployment_deployments_inference_put(deployment_id, request)
else:
raise ValueError(
f"Unsupported request type: {type(request)}. "
f"Expected CreateInferenceDeploymentRequest or CreateInferenceV3DeploymentRequest."
)

def update_compute(self, deployment_id: int, request: CreateComputeDeploymentRequest):
return self._api.update_compute_deployment_deployments_compute_put(deployment_id, request)

def update_cserve(self, deployment_id: int, request: CreateCServeV2DeploymentRequest):
return self._api.update_cserve_v2_deployment_deployments_cserve_v2_put(deployment_id, request)
def detect_deployment_version(self, deployment_id: int) -> str:
"""Detect if a deployment is V2 or V3 by testing the specific API endpoints"""
try:
# Try V3 endpoint first
self._api.get_cserve_v3_deployment_deployments_cserve_v3_deployment_id_get(deployment_id)
return 'v3'
except ApiException as e:
if e.status in [404, 400]: # V3 endpoint doesn't exist for this deployment
try:
# Try V2 endpoint
self._api.get_cserve_v2_deployment_deployments_cserve_v2_deployment_id_get(deployment_id)
return 'v2'
except ApiException as exc:
# If both fail, it might not be a CServe deployment or doesn't exist
raise ValueError(
f"Deployment {deployment_id} is not a valid CServe deployment or does not exist"
) from exc
else:
# Other error (auth, network, etc.)
raise

def update_cserve(
self, deployment_id: int, request: Union[CreateCServeV2DeploymentRequest, CreateCServeV3DeploymentRequest]
):
"""Update CServe deployment - validates request type matches deployment version"""
# Detect the deployment version
deployment_version = self.detect_deployment_version(deployment_id)

# Validate request type matches deployment version
if isinstance(request, CreateCServeV3DeploymentRequest):
if deployment_version != 'v3':
raise ValueError(
f"Deployment {deployment_id} is CServe {deployment_version.upper()}, "
f"but you provided a V3 request. Please use CreateCServeV2DeploymentRequest instead."
)
return self._api.update_cserve_v3_deployment_deployments_cserve_v3_put(deployment_id, request)
elif isinstance(request, CreateCServeV2DeploymentRequest):
if deployment_version != 'v2':
raise ValueError(
f"Deployment {deployment_id} is CServe {deployment_version.upper()}, "
f"but you provided a V2 request. Please use CreateCServeV3DeploymentRequest instead."
)
return self._api.update_cserve_v2_deployment_deployments_cserve_v2_put(deployment_id, request)
else:
raise ValueError(
f"Unsupported request type: {type(request)}. "
f"Expected CreateCServeV2DeploymentRequest or CreateCServeV3DeploymentRequest."
)

def _update_status(self, id, new_status):
status_req = platform_api_python_client.DeploymentStatusRequest(status=new_status)
Expand Down Expand Up @@ -93,6 +229,20 @@ def get_user_vault(self, type):

return {i.key: i.value for i in items}

def detect_cserve_deployment_version(self, deployment_response):
"""Detect if a CServe deployment is V2 or V3 based on response fields"""
# Check for V3-specific fields
if hasattr(deployment_response, 'max_surge') or hasattr(deployment_response, 'max_unavailable'):
return 'v3'
# Check for V3 field names (min_replicas vs min_scale)
if hasattr(deployment_response, 'min_replicas'):
return 'v3'
# Check for V2 field names
if hasattr(deployment_response, 'min_scale'):
return 'v2'
# Default to V2 for backward compatibility
return 'v2'

# pylint: disable=R0917
def get_deployment_usage(
self, id: int, metric: Metric, start_time_in_seconds: int, end_time_in_seconds: int, step: int
Expand Down
Loading
Loading