diff --git a/latch_cli/services/get_params.py b/latch_cli/services/get_params.py index 7c0584f3..729ed77c 100644 --- a/latch_cli/services/get_params.py +++ b/latch_cli/services/get_params.py @@ -8,6 +8,7 @@ import keyword import typing from typing import Optional +from slugify import slugify import google.protobuf.json_format as gpjson from flyteidl.core.literals_pb2 import Literal as _Literal @@ -122,6 +123,7 @@ def get_params(wf_name: str, wf_version: Optional[str] = None): param_map_str += f'\n "_name": "{wf_name}", # Don\'t edit this value.' for param_name, value in params.items(): python_type, python_val, default = value + enum_dict = {} # Check for imports. @@ -142,8 +144,9 @@ def _handle_enum(python_type: typing.T): if variant in keyword.kwlist: variant_name = f"_{variant}" else: - variant_name = variant + variant_name = _handle_format_enum_name(variant) _enum_literal += f"\n {variant_name} = '{variant}'" + enum_dict[variant_name] = variant enum_literals.append(_enum_literal) # Parse collection, union types for potential imports and dependent @@ -162,6 +165,15 @@ def _handle_enum(python_type: typing.T): python_val, python_type = _get_code_literal(python_val, python_type) + # modify python_val with correct enum name + if isinstance(python_type, str) and ('enum' in python_type): + python_val = python_val.replace(param_name + ".", "") + if default is True: + enum_name = next((key for key, value in enum_dict.items() if value == python_val), None) + else: + enum_name = next(iter(enum_dict.keys())) + python_val = param_name + "." + enum_name + if default is True: default = "DEFAULT. " else: @@ -373,3 +385,10 @@ def _best_effort_default_val(t: typing.T): raise NotImplementedError( f"Unable to produce a best-effort value for the python type {t}" ) + + +def _handle_format_enum_name(enum_name_input): + slugified_string = slugify(enum_name_input, separator='_') + if slugified_string[0].isdigit(): + slugified_string = '_' + slugified_string + return slugified_string.upper() \ No newline at end of file