Skip to content

Commit 1d43068

Browse files
committed
change selections to be indexed.
1 parent 4963ce5 commit 1d43068

File tree

1 file changed

+77
-81
lines changed

1 file changed

+77
-81
lines changed

centml/cli/cluster.py

Lines changed: 77 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ def ls(type):
9494
)
9595
)
9696

97+
9798
# TODO: Status for Cserve seems to be broken
9899
@click.command(help="Get deployment details")
99100
@click.argument("name", type=str)
@@ -188,12 +189,16 @@ def create():
188189
# Prompt for general fields
189190
name = click.prompt("Enter a name for the deployment")
190191

191-
dtype_str = click.prompt(
192-
"Select a deployment type",
193-
type=click.Choice(list(depl_name_to_type_map.keys())),
194-
show_choices=True,
195-
default=list(depl_name_to_type_map.keys())[0],
196-
)
192+
# --- Deployment Type Selection (Indexed) ---
193+
deploy_types = list(depl_name_to_type_map.keys())
194+
click.echo("Select a deployment type:")
195+
for idx, dtype in enumerate(deploy_types, start=1):
196+
click.echo(f"{idx}. {dtype}")
197+
dtype_index = click.prompt("Enter the deployment type number", type=int, default=1)
198+
if dtype_index < 1 or dtype_index > len(deploy_types):
199+
click.echo("Invalid selection.")
200+
return
201+
dtype_str = deploy_types[dtype_index - 1]
197202
depl_type = depl_name_to_type_map[dtype_str]
198203

199204
if depl_type == DeploymentType.INFERENCE_V2:
@@ -226,19 +231,17 @@ def create():
226231

227232
# Retrieve prebuilt images for inference deployments
228233
prebuilt_images = cclient.get_prebuilt_images(depl_type=depl_type)
229-
230-
# Build list of image labels
231234
image_choices = [img.label for img in prebuilt_images.results] if prebuilt_images.results else []
232-
233-
# Right now we disable this other option to get a MVP out quickly.
234-
#image_choices.append("Other")
235-
236-
chosen_label = click.prompt(
237-
"Select a prebuilt image label or choose 'Other' to provide a custom image URL",
238-
type=click.Choice(image_choices),
239-
show_choices=True,
240-
default=image_choices[0],
241-
)
235+
# Enable custom image selection by adding "Other" to the list.
236+
image_choices.append("Other")
237+
click.echo("Available prebuilt image labels:")
238+
for idx, label in enumerate(image_choices, start=1):
239+
click.echo(f"{idx}. {label}")
240+
choice_index = click.prompt("Select a prebuilt image label by number", type=int, default=1)
241+
if choice_index < 1 or choice_index > len(image_choices):
242+
click.echo("Invalid selection.")
243+
return
244+
chosen_label = image_choices[choice_index - 1]
242245

243246
if chosen_label == "Other":
244247
image = click.prompt("Enter the custom image URL")
@@ -249,19 +252,20 @@ def create():
249252
else:
250253
# Find the prebuilt image with the matching label
251254
selected_prebuilt = next(img for img in prebuilt_images.results if img.label == chosen_label)
252-
# Prompt the user to select a tag from the available tags
253-
tag = click.prompt(
254-
"Select a tag for the image",
255-
type=click.Choice(selected_prebuilt.tags),
256-
show_choices=True,
257-
default=selected_prebuilt.tags[0],
258-
)
255+
# Prompt the user to select a tag from the available tags (indexed)
256+
click.echo("Available tags for the selected image:")
257+
for idx, tag in enumerate(selected_prebuilt.tags, start=1):
258+
click.echo(f"{idx}. {tag}")
259+
tag_index = click.prompt("Select a tag for the image by number", type=int, default=1)
260+
if tag_index < 1 or tag_index > len(selected_prebuilt.tags):
261+
click.echo("Invalid tag selection.")
262+
return
263+
tag = selected_prebuilt.tags[tag_index - 1]
259264
# Combine the image URL with the chosen tag
260265
image = f"{selected_prebuilt.image_name}:{tag}"
261266
port = selected_prebuilt.port
262267
healthcheck = selected_prebuilt.healthcheck if selected_prebuilt.healthcheck else "/"
263268

264-
265269
env_vars_str = click.prompt(
266270
"Enter environment variables in KEY=VALUE format (comma separated) or leave blank",
267271
default="",
@@ -336,46 +340,40 @@ def create():
336340

337341
# Retrieve prebuilt images for compute deployments
338342
prebuilt_images = cclient.get_prebuilt_images(depl_type=depl_type)
339-
# Build list of image labels
340343
image_choices = [img.label for img in prebuilt_images.results] if prebuilt_images.results else []
344+
click.echo("Available prebuilt image labels:")
345+
for idx, label in enumerate(image_choices, start=1):
346+
click.echo(f"{idx}. {label}")
347+
choice_index = click.prompt("Select a prebuilt image label by number", type=int, default=1)
348+
if choice_index < 1 or choice_index > len(image_choices):
349+
click.echo("Invalid selection.")
350+
return
351+
chosen_label = image_choices[choice_index - 1]
341352

342-
chosen_label = click.prompt(
343-
"Select a prebuilt image label",
344-
type=click.Choice(image_choices),
345-
show_choices=True,
346-
default=image_choices[0],
347-
)
348-
349-
selected_prebuilt = next(img for img in prebuilt_images.results if img.label == chosen_label)
350-
351-
# Find the prebuilt image with the matching label
352353
selected_prebuilt = next(img for img in prebuilt_images.results if img.label == chosen_label)
353-
# Prompt the user to select a tag from the available tags
354-
tag = click.prompt(
355-
"Select a tag for the image",
356-
type=click.Choice(selected_prebuilt.tags),
357-
show_choices=True,
358-
default=selected_prebuilt.tags[0],
359-
)
360-
# Combine the image URL with the chosen tag
354+
# Prompt the user to select a tag from the available tags (indexed)
355+
click.echo("Available tags for the selected image:")
356+
for idx, tag in enumerate(selected_prebuilt.tags, start=1):
357+
click.echo(f"{idx}. {tag}")
358+
tag_index = click.prompt("Select a tag for the image by number", type=int, default=1)
359+
if tag_index < 1 or tag_index > len(selected_prebuilt.tags):
360+
click.echo("Invalid tag selection.")
361+
return
362+
tag = selected_prebuilt.tags[tag_index - 1]
361363
image_url = f"{selected_prebuilt.image_name}:{tag}"
362364

363365
# For compute deployments, we might ask for a public SSH key
364366
ssh_key = click.prompt("Enter your public SSH key")
365367

366-
# Right now we not support this on prod platform, just unify the feature
367-
#jupyter = click.prompt("Enable Jupyter Notebook on this compute deployment?", type=bool,default=False, show_default=False)
368-
369368
from platform_api_python_client import CreateComputeDeploymentRequest
370369

371370
req = CreateComputeDeploymentRequest(
372371
name=name,
373372
cluster_id=cluster_id,
374373
hardware_instance_id=hw_id,
375374
image_url=image_url,
376-
ssh_public_key=ssh_key, # we require this
377-
#enable_jupyter=jupyter,
378-
)
375+
ssh_public_key=ssh_key,
376+
)
379377

380378
created = cclient.create_compute(req)
381379
click.echo(f"Compute deployment {name} created with ID: {created.id}")
@@ -431,7 +429,6 @@ def create():
431429
sys.exit(1)
432430

433431
# Display the hardware instance information to the user.
434-
435432
credits = selected_hw.cost_per_hr / 100.0 # e.g., 360 -> 3.60 credits per hour
436433
vram_gib = selected_hw.accelerator_memory / 1024 # e.g., 81920 MB -> 80 GiB VRAM
437434
memory_gib = selected_hw.memory / 1024 # e.g., 239616 MB -> 234 GiB memory
@@ -453,34 +450,34 @@ def create():
453450
recipe_dict.pop("additional_properties", None)
454451

455452
recipe_payload = {
456-
"model": recipe_dict.get("model"),
457-
"is_embedding_model": recipe_dict.get("is_embedding_model"),
458-
"dtype": recipe_dict.get("dtype"),
459-
"tokenizer": recipe_dict.get("tokenizer"),
460-
"block_size": recipe_dict.get("block_size"),
461-
"swap_space": recipe_dict.get("swap_space"),
462-
"cache_dtype": recipe_dict.get("cache_dtype"),
463-
"spec_tokens": recipe_dict.get("spec_tokens"),
464-
"gpu_mem_util": recipe_dict.get("gpu_mem_util"),
465-
"max_num_seqs": recipe_dict.get("max_num_seqs"),
466-
"quantization": recipe_dict.get("quantization"),
467-
"max_model_len": recipe_dict.get("max_model_len"),
468-
"offloading_num": int(recipe_dict.get("offloading_num")),
469-
"use_flashinfer": recipe_dict.get("use_flashinfer"),
470-
"eager_execution": recipe_dict.get("eager_execution"),
471-
"spec_draft_model": recipe_dict.get("spec_draft_model"),
472-
"spec_max_seq_len": recipe_dict.get("spec_max_seq_len"),
473-
"use_prefix_caching": recipe_dict.get("use_prefix_caching"),
474-
"num_scheduler_steps": recipe_dict.get("num_scheduler_steps"),
475-
"spec_max_batch_size": recipe_dict.get("spec_max_batch_size"),
476-
"use_chunked_prefill": recipe_dict.get("use_chunked_prefill"),
477-
"chunked_prefill_size": recipe_dict.get("chunked_prefill_size"),
478-
"tensor_parallel_size": recipe_dict.get("tensor_parallel_size"),
479-
"max_seq_len_to_capture": recipe_dict.get("max_seq_len_to_capture"),
480-
"pipeline_parallel_size": recipe_dict.get("pipeline_parallel_size"),
481-
"spec_prompt_lookup_max": recipe_dict.get("spec_prompt_lookup_max"),
482-
"spec_prompt_lookup_min": recipe_dict.get("spec_prompt_lookup_min"),
483-
"distributed_executor_backend": recipe_dict.get("distributed_executor_backend"),
453+
"model": recipe_dict.get("model"),
454+
"is_embedding_model": recipe_dict.get("is_embedding_model"),
455+
"dtype": recipe_dict.get("dtype"),
456+
"tokenizer": recipe_dict.get("tokenizer"),
457+
"block_size": recipe_dict.get("block_size"),
458+
"swap_space": recipe_dict.get("swap_space"),
459+
"cache_dtype": recipe_dict.get("cache_dtype"),
460+
"spec_tokens": recipe_dict.get("spec_tokens"),
461+
"gpu_mem_util": recipe_dict.get("gpu_mem_util"),
462+
"max_num_seqs": recipe_dict.get("max_num_seqs"),
463+
"quantization": recipe_dict.get("quantization"),
464+
"max_model_len": recipe_dict.get("max_model_len"),
465+
"offloading_num": int(recipe_dict.get("offloading_num")),
466+
"use_flashinfer": recipe_dict.get("use_flashinfer"),
467+
"eager_execution": recipe_dict.get("eager_execution"),
468+
"spec_draft_model": recipe_dict.get("spec_draft_model"),
469+
"spec_max_seq_len": recipe_dict.get("spec_max_seq_len"),
470+
"use_prefix_caching": recipe_dict.get("use_prefix_caching"),
471+
"num_scheduler_steps": recipe_dict.get("num_scheduler_steps"),
472+
"spec_max_batch_size": recipe_dict.get("spec_max_batch_size"),
473+
"use_chunked_prefill": recipe_dict.get("use_chunked_prefill"),
474+
"chunked_prefill_size": recipe_dict.get("chunked_prefill_size"),
475+
"tensor_parallel_size": recipe_dict.get("tensor_parallel_size"),
476+
"max_seq_len_to_capture": recipe_dict.get("max_seq_len_to_capture"),
477+
"pipeline_parallel_size": recipe_dict.get("pipeline_parallel_size"),
478+
"spec_prompt_lookup_max": recipe_dict.get("spec_prompt_lookup_max"),
479+
"spec_prompt_lookup_min": recipe_dict.get("spec_prompt_lookup_min"),
480+
"distributed_executor_backend": recipe_dict.get("distributed_executor_backend"),
484481
}
485482

486483
# --- Additional Prompts ---
@@ -530,7 +527,6 @@ def create():
530527
created = cclient.create_cserve(req)
531528
click.echo(f"CServe deployment {name} created with ID: {created.id}")
532529

533-
534530
else:
535531
click.echo("Unknown deployment type.")
536532

0 commit comments

Comments
 (0)