Skip to content
89 changes: 56 additions & 33 deletions centml/cli/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,23 @@


# Maps each DeploymentType member to its short type-name string.
# Several members deliberately share a name (e.g. INFERENCE and INFERENCE_V2
# both map to "inference"); CSERVE_V2 is the only one with a versioned name.
# Each key is listed exactly once: a repeated key in a dict literal is silently
# overridden by its last occurrence, which hides stale values.
depl_type_to_name_map = {
    DeploymentType.INFERENCE: "inference",
    DeploymentType.COMPUTE: "compute",
    DeploymentType.COMPILATION: "compilation",
    DeploymentType.INFERENCE_V2: "inference",
    DeploymentType.COMPUTE_V2: "compute",
    DeploymentType.CSERVE: "cserve",
    DeploymentType.CSERVE_V2: "cserve-v2",
    DeploymentType.CSERVE_V3: "cserve",
    DeploymentType.RAG: "rag",
}
# Maps a type-name string to the DeploymentType it selects.
# The bare "cserve" name resolves to CSERVE_V3; "cserve-v2" and "cserve-v3"
# select a specific version explicitly.
# Each key is listed exactly once: a repeated key in a dict literal is silently
# overridden by its last occurrence, which hides stale values.
depl_name_to_type_map = {
    "inference": DeploymentType.INFERENCE_V2,
    "cserve": DeploymentType.CSERVE_V3,
    "cserve-v2": DeploymentType.CSERVE_V2,
    "cserve-v3": DeploymentType.CSERVE_V3,
    "compute": DeploymentType.COMPUTE_V2,
    "rag": DeploymentType.RAG,
}


Expand Down Expand Up @@ -56,6 +59,17 @@ def _format_ssh_key(ssh_key):
return ssh_key[:32] + "..."


def _get_replica_info(deployment, depl_type):
    """Return the deployment's replica bounds as a ``{"min": ..., "max": ...}`` dict.

    For ``DeploymentType.CSERVE_V3`` the bounds are read from ``min_replicas`` /
    ``max_replicas``, falling back to ``min_scale`` / ``max_scale`` and finally
    to ``None`` when an attribute is absent. For every other deployment type
    they are read directly from ``min_scale`` / ``max_scale``.
    """
    if depl_type != DeploymentType.CSERVE_V3:
        # Non-V3 deployments are read straight from the scale fields.
        return {"min": deployment.min_scale, "max": deployment.max_scale}

    lower = getattr(deployment, "min_replicas", getattr(deployment, "min_scale", None))
    upper = getattr(deployment, "max_replicas", getattr(deployment, "max_scale", None))
    return {"min": lower, "max": upper}


def _get_ready_status(cclient, deployment):
api_status = deployment.status
service_status = (
Expand Down Expand Up @@ -126,7 +140,9 @@ def get(type, id):
elif depl_type == DeploymentType.COMPUTE_V2:
deployment = cclient.get_compute(id)
elif depl_type == DeploymentType.CSERVE_V2:
deployment = cclient.get_cserve(id)
deployment = cclient.get_cserve_v2(id)
elif depl_type == DeploymentType.CSERVE_V3:
deployment = cclient.get_cserve_v3(id)
else:
sys.exit("Please enter correct deployment type")

Expand Down Expand Up @@ -157,7 +173,7 @@ def get(type, id):
("Image", deployment.image_url),
("Container port", deployment.container_port),
("Healthcheck", deployment.healthcheck or "/"),
("Replicas", {"min": deployment.min_scale, "max": deployment.max_scale}),
("Replicas", _get_replica_info(deployment, depl_type)),
("Environment variables", deployment.env_vars or "None"),
("Max concurrency", deployment.concurrency or "None"),
],
Expand All @@ -173,25 +189,32 @@ def get(type, id):
disable_numparse=True,
)
)
elif depl_type == DeploymentType.CSERVE_V2:
click.echo(
tabulate(
[
("Hugging face model", deployment.recipe.model),
(
"Parallelism",
{
"tensor": deployment.recipe.additional_properties['tensor_parallel_size'],
"pipeline": deployment.recipe.additional_properties['pipeline_parallel_size'],
},
),
("Replicas", {"min": deployment.min_scale, "max": deployment.max_scale}),
("Max concurrency", deployment.concurrency or "None"),
],
tablefmt="rounded_outline",
disable_numparse=True,
)
)
elif depl_type in [DeploymentType.CSERVE_V2, DeploymentType.CSERVE_V3]:
replica_info = _get_replica_info(deployment, depl_type)
display_rows = [
("Hugging face model", deployment.recipe.model),
(
"Parallelism",
{
"tensor": deployment.recipe.additional_properties.get("tensor_parallel_size", "N/A"),
"pipeline": deployment.recipe.additional_properties.get("pipeline_parallel_size", "N/A"),
},
),
("Replicas", replica_info),
("Max concurrency", deployment.concurrency or "None"),
]

# Add V3-specific rollout information
if depl_type == DeploymentType.CSERVE_V3:
rollout_info = {}
if hasattr(deployment, "max_surge") and deployment.max_surge is not None:
rollout_info["max_surge"] = deployment.max_surge
if hasattr(deployment, "max_unavailable") and deployment.max_unavailable is not None:
rollout_info["max_unavailable"] = deployment.max_unavailable
if rollout_info:
display_rows.append(("Rollout strategy", rollout_info))

click.echo(tabulate(display_rows, tablefmt="rounded_outline", disable_numparse=True))


@click.command(help="Delete a deployment")
Expand Down
23 changes: 21 additions & 2 deletions centml/sdk/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
CreateInferenceDeploymentRequest,
CreateComputeDeploymentRequest,
CreateCServeV2DeploymentRequest,
CreateCServeV3DeploymentRequest,
Metric,
)

Expand All @@ -33,26 +34,44 @@ def get_compute(self, id):
return self._api.get_compute_deployment_deployments_compute_deployment_id_get(id)

def get_cserve(self, id):
    """Look up deployment ``id`` through the API client's CServe V3 getter.

    Returns whatever the underlying API call returns, unchanged.
    """
    fetch_v3 = self._api.get_cserve_v3_deployment_deployments_cserve_v3_deployment_id_get
    return fetch_v3(id)

def get_cserve_v2(self, id):
    """Look up deployment ``id`` through the API client's CServe V2 getter.

    Returns whatever the underlying API call returns, unchanged.
    """
    fetch_v2 = self._api.get_cserve_v2_deployment_deployments_cserve_v2_deployment_id_get
    return fetch_v2(id)

def get_cserve_v3(self, id):
    """Look up deployment ``id`` through the API client's CServe V3 getter.

    Returns whatever the underlying API call returns, unchanged.
    """
    fetch_v3 = self._api.get_cserve_v3_deployment_deployments_cserve_v3_deployment_id_get
    return fetch_v3(id)

def create_inference(self, request: CreateInferenceDeploymentRequest):
    """Submit ``request`` to the API client's inference-deployment create call.

    Returns whatever the underlying API call returns, unchanged.
    """
    submit = self._api.create_inference_deployment_deployments_inference_post
    return submit(request)

def create_compute(self, request: CreateComputeDeploymentRequest):
    """Submit ``request`` to the API client's compute-deployment create call.

    Returns whatever the underlying API call returns, unchanged.
    """
    submit = self._api.create_compute_deployment_deployments_compute_post
    return submit(request)

def create_cserve(self, request: CreateCServeV3DeploymentRequest):
    """Submit ``request`` to the API client's CServe V3 create call.

    The unversioned name targets the V3 endpoint, so the request is annotated
    as a V3 request; use ``create_cserve_v2`` for the V2 endpoint.

    Returns whatever the underlying API call returns, unchanged.
    """
    return self._api.create_cserve_v3_deployment_deployments_cserve_v3_post(request)

def create_cserve_v2(self, request: CreateCServeV2DeploymentRequest):
    """Submit ``request`` to the API client's CServe V2 create call.

    Returns whatever the underlying API call returns, unchanged.
    """
    submit_v2 = self._api.create_cserve_v2_deployment_deployments_cserve_v2_post
    return submit_v2(request)

def create_cserve_v3(self, request: CreateCServeV3DeploymentRequest):
    """Submit ``request`` to the API client's CServe V3 create call.

    Returns whatever the underlying API call returns, unchanged.
    """
    submit_v3 = self._api.create_cserve_v3_deployment_deployments_cserve_v3_post
    return submit_v3(request)

def update_inference(self, deployment_id: int, request: CreateInferenceDeploymentRequest):
    """Send ``request`` for ``deployment_id`` to the API client's inference update call.

    Returns whatever the underlying API call returns, unchanged.
    """
    apply_update = self._api.update_inference_deployment_deployments_inference_put
    return apply_update(deployment_id, request)

def update_compute(self, deployment_id: int, request: CreateComputeDeploymentRequest):
    """Send ``request`` for ``deployment_id`` to the API client's compute update call.

    Returns whatever the underlying API call returns, unchanged.
    """
    apply_update = self._api.update_compute_deployment_deployments_compute_put
    return apply_update(deployment_id, request)

def update_cserve(self, deployment_id: int, request: CreateCServeV3DeploymentRequest):
    """Send ``request`` for ``deployment_id`` to the API client's CServe V3 update call.

    The unversioned name targets the V3 endpoint, so the request is annotated
    as a V3 request; use ``update_cserve_v2`` for the V2 endpoint.

    Returns whatever the underlying API call returns, unchanged.
    """
    return self._api.update_cserve_v3_deployment_deployments_cserve_v3_put(deployment_id, request)

def update_cserve_v2(self, deployment_id: int, request: CreateCServeV2DeploymentRequest):
    """Send ``request`` for ``deployment_id`` to the API client's CServe V2 update call.

    Returns whatever the underlying API call returns, unchanged.
    """
    apply_update_v2 = self._api.update_cserve_v2_deployment_deployments_cserve_v2_put
    return apply_update_v2(deployment_id, request)

def update_cserve_v3(self, deployment_id: int, request: CreateCServeV3DeploymentRequest):
    """Send ``request`` for ``deployment_id`` to the API client's CServe V3 update call.

    Returns whatever the underlying API call returns, unchanged.
    """
    apply_update_v3 = self._api.update_cserve_v3_deployment_deployments_cserve_v3_put
    return apply_update_v3(deployment_id, request)

def _update_status(self, id, new_status):
status_req = platform_api_python_client.DeploymentStatusRequest(status=new_status)
self._api.update_deployment_status_deployments_status_deployment_id_put(id, status_req)
Expand Down
2 changes: 1 addition & 1 deletion centml/sdk/utils/client_certs.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def save_pem_file(service_name, client_private_key, client_certificate):

try:
# Save the combined PEM file
with open(ca_file_path, 'w') as combined_pem_file:
with open(ca_file_path, "w") as combined_pem_file:
combined_pem_file.write(client_private_key + client_certificate)
click.echo(f"Combined PEM file for accessing the private endpoint has been saved to {ca_file_path}")

Expand Down
29 changes: 16 additions & 13 deletions examples/sdk/create_cserve.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
import centml
from centml.sdk.api import get_centml_client
from centml.sdk import DeploymentType, CreateCServeV2DeploymentRequest, CServeV2Recipe
from centml.sdk import DeploymentType, CreateCServeV3DeploymentRequest, CServeV2Recipe


def get_fastest_cserve_config(cclient, name, model):
    """Build a CServe V3 deployment request from the fastest recipe for ``model``.

    Takes the ``fastest`` entry of the first result returned by
    ``cclient.get_cserve_recipe(model=model)`` and reuses its hardware instance
    and recipe for a deployment named ``name`` with exactly one replica
    (``min_replicas == max_replicas == 1``) and no environment variables.
    """
    fastest = cclient.get_cserve_recipe(model=model)[0].fastest

    # The V3 request takes replica bounds as min_replicas/max_replicas
    # (the V2 request used min_scale/max_scale).
    return CreateCServeV3DeploymentRequest(
        name=name,
        cluster_id=cclient.get_cluster_id(fastest.hardware_instance_id),
        hardware_instance_id=fastest.hardware_instance_id,
        recipe=fastest.recipe,
        min_replicas=1,
        max_replicas=1,
        env_vars={},
    )

Expand All @@ -22,41 +22,44 @@ def get_default_cserve_config(cclient, name, model):

hardware_instance = cclient.get_hardware_instances(cluster_id=1001)[0]

return CreateCServeV2DeploymentRequest(
return CreateCServeV3DeploymentRequest(
name=name,
cluster_id=hardware_instance.cluster_id,
hardware_instance_id=hardware_instance.id,
recipe=default_recipe,
min_scale=1,
max_scale=1,
min_replicas=1,
max_replicas=1,
env_vars={},
)


def main():
    """Create a CServe V3 deployment for a Qwen model and print its details.

    Builds the request from the fastest available recipe, raises
    ``max_num_seqs`` on the recipe, creates the deployment, then fetches it
    back by id and prints both responses.
    """
    with get_centml_client() as cclient:
        ### Get the configurations for the Qwen model
        qwen_config = get_fastest_cserve_config(
            cclient, name="qwen-fastest", model="Qwen/Qwen2-VL-7B-Instruct"
        )
        # qwen_config = get_default_cserve_config(cclient, name="qwen-default", model="Qwen/Qwen2-VL-7B-Instruct")

        ### Modify the recipe if necessary
        qwen_config.recipe.additional_properties["max_num_seqs"] = 512

        # Create CServeV3 deployment
        response = cclient.create_cserve(qwen_config)
        print("Create deployment response: ", response)

        ### Get deployment details
        deployment = cclient.get_cserve_v3(response.id)
        print("Deployment details: ", deployment)

        # The block below is an inert string, not executed; move the calls out
        # of it to pause and delete the deployment created above.
        """
        ### Pause the deployment
        cclient.pause(deployment.id)

        ### Delete the deployment
        cclient.delete(deployment.id)
        """


# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ cryptography==44.0.1
prometheus-client>=0.20.0
scipy>=1.6.0
scikit-learn>=1.5.1
platform-api-python-client==4.0.12
platform-api-python-client==4.1.9
Loading