Skip to content

Commit 4f1fea2

Browse files
committed
Added prebuilt image list for all deployment types, but only inference is working for now
1 parent 1886028 commit 4f1fea2

File tree

2 files changed

+38
-10
lines changed

2 files changed

+38
-10
lines changed

centml/cli/cluster.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -226,11 +226,38 @@ def create():
226226
concurrency = click.prompt("Max concurrency (or leave blank)", default="", show_default=False)
227227
concurrency = int(concurrency) if concurrency else None
228228

229-
# Depending on type:
230229
if depl_type == DeploymentType.INFERENCE_V2:
231-
image = click.prompt("Enter the image URL")
232-
port = click.prompt("Enter the container port", default=8080, type=int)
233-
healthcheck = click.prompt("Enter healthcheck endpoint (default '/')", default="/", show_default=True)
230+
# Retrieve prebuilt images for inference deployments
231+
prebuilt_images = cclient.get_prebuilt_images(depl_type=depl_type)
232+
image_choices = [img.image_name for img in prebuilt_images.results] if prebuilt_images.results else []
233+
234+
chosen_image = click.prompt(
235+
"Select a prebuilt image or provide a custom image URL",
236+
type=click.Choice(image_choices),
237+
show_choices=True
238+
)
239+
240+
if chosen_image == "Other":
241+
image = click.prompt("Enter the image URL")
242+
port = click.prompt("Enter the container port", default=8080, type=int)
243+
healthcheck = click.prompt("Enter healthcheck endpoint (default '/')", default="/", show_default=True)
244+
else:
245+
# Find the selected prebuilt image details
246+
selected_prebuilt = next(img for img in prebuilt_images.results if img.image_name == chosen_image)
247+
image = selected_prebuilt.image_name
248+
# Use the prebuilt image port and healthcheck as defaults
249+
port = click.prompt(
250+
"Enter the container port",
251+
default=selected_prebuilt.port,
252+
type=int
253+
)
254+
default_healthcheck = selected_prebuilt.healthcheck if selected_prebuilt.healthcheck else "/"
255+
healthcheck = click.prompt(
256+
"Enter healthcheck endpoint (default '/')",
257+
default=default_healthcheck,
258+
show_default=True
259+
)
260+
234261
env_vars_str = click.prompt("Enter environment variables in KEY=VALUE format (comma separated) or leave blank", default="", show_default=False)
235262
env_vars = {}
236263
if env_vars_str.strip():
@@ -260,10 +287,15 @@ def create():
260287
ssh_key = click.prompt("Enter your public SSH key", default="", show_default=False)
261288

262289
from platform_api_python_client import CreateComputeDeploymentRequest
290+
# If compute deployments also use prebuilt images and require image_url,
291+
# we could similarly fetch them and prompt just like inference above.
292+
# For now, if the schema doesn't require image_url for compute:
263293
req = CreateComputeDeploymentRequest(
264294
name=name,
265295
cluster_id=cluster_id,
266296
hardware_instance_id=hw_id,
297+
# If needed, you can do similar logic for prebuilt images here:
298+
# image_url = ...
267299
ssh_public_key=ssh_key if ssh_key.strip() else None
268300
)
269301
created = cclient.create_compute(req)
@@ -274,9 +306,9 @@ def create():
274306
model = click.prompt("Enter the Hugging Face model", default="facebook/opt-1.3b")
275307
tensor_parallel_size = click.prompt("Tensor parallel size", default=1, type=int)
276308
pipeline_parallel_size = click.prompt("Pipeline parallel size", default=1, type=int)
277-
# concurrency asked above
278309

279310
from platform_api_python_client import CreateCServeDeploymentRequest
311+
# If cserve deployments also require images, we could do similar logic here.
280312
req = CreateCServeDeploymentRequest(
281313
name=name,
282314
cluster_id=cluster_id,

centml/sdk/api.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import platform_api_python_client
44
from platform_api_python_client import (
5+
DeploymentType,
56
DeploymentStatus,
67
CreateInferenceDeploymentRequest,
78
CreateComputeDeploymentRequest,
@@ -62,11 +63,6 @@ def get_hardware_instances(self, cluster_id):
6263
return self._api.get_hardware_instances_hardware_instances_get(cluster_id).results
6364

6465
def get_prebuilt_images(self, depl_type: DeploymentType = None):
65-
"""Get Prebuilt Images
66-
67-
:param depl_type: DeploymentType, optional
68-
:return: ListPrebuiltImageResponse
69-
"""
7066
return self._api.get_prebuilt_images_prebuilt_images_get(type=depl_type)
7167

7268

0 commit comments

Comments
 (0)