diff --git a/torchx/specs/api.py b/torchx/specs/api.py index 5e4afa188..999905e73 100644 --- a/torchx/specs/api.py +++ b/torchx/specs/api.py @@ -891,10 +891,14 @@ class runopt: Represents the metadata about the specific run option """ + class alias(str): + pass + default: CfgVal opt_type: Type[CfgVal] is_required: bool help: str + aliases: list[alias] | None = None @property def is_type_list_of_str(self) -> bool: @@ -986,6 +990,7 @@ class runopts: def __init__(self) -> None: self._opts: Dict[str, runopt] = {} + self._alias_to_key: dict[runopt.alias, str] = {} def __iter__(self) -> Iterator[Tuple[str, runopt]]: return self._opts.items().__iter__() @@ -1013,9 +1018,16 @@ def is_type(obj: CfgVal, tp: Type[CfgVal]) -> bool: def get(self, name: str) -> Optional[runopt]: """ - Returns option if any was registered, or None otherwise + Returns option if any was registered, or None otherwise. + First searches for the option by ``name``, then falls-back to matching ``name`` with any + registered aliases. + """ - return self._opts.get(name, None) + if name in self._opts: + return self._opts[name] + if name in self._alias_to_key: + return self._opts[self._alias_to_key[name]] + return None def resolve(self, cfg: Mapping[str, CfgVal]) -> Dict[str, CfgVal]: """ @@ -1030,6 +1042,24 @@ def resolve(self, cfg: Mapping[str, CfgVal]) -> Dict[str, CfgVal]: for cfg_key, runopt in self._opts.items(): val = resolved_cfg.get(cfg_key) + resolved_name = None + aliases = runopt.aliases or [] + if val is None: + for alias in aliases: + val = resolved_cfg.get(alias) + if alias in cfg or val is not None: + resolved_name = alias + break + else: + resolved_name = cfg_key + for alias in aliases: + duplicate_val = resolved_cfg.get(alias) + if alias in cfg or duplicate_val is not None: + raise InvalidRunConfigException( + f"Duplicate opt name. runopt: `{resolved_name}``, is an alias of runopt: `{alias}`", + resolved_name, + cfg, + ) # check required opt if runopt.is_required and val is None: @@ -1049,7 +1079,7 @@ def resolve(self, cfg: Mapping[str, CfgVal]) -> Dict[str, CfgVal]: ) # not required and not set, set to default - if val is None: + if val is None and resolved_name is None: resolved_cfg[cfg_key] = runopt.default return resolved_cfg @@ -1142,9 +1172,38 @@ def cfg_from_json_repr(self, json_repr: str) -> Dict[str, CfgVal]: cfg[key] = val return cfg + def _get_primary_key_and_aliases( + self, + cfg_key: list[str] | str, + ) -> tuple[str, list[runopt.alias]]: + """ + Returns the primary key and aliases for the given cfg_key. + """ + if isinstance(cfg_key, str): + return cfg_key, [] + + if len(cfg_key) == 0: + raise ValueError("cfg_key must be a non-empty list") + primary_key = None + aliases = list[runopt.alias]() + for name in cfg_key: + if isinstance(name, runopt.alias): + aliases.append(name) + else: + if primary_key is not None: + raise ValueError( + f" Given more than one primary key: {primary_key}, {name}. Please use runopt.alias type for aliases. " + ) + primary_key = name + if primary_key is None or primary_key == "": + raise ValueError( + "Missing cfg_key. Please provide one other than the aliases." + ) + return primary_key, aliases + def add( self, - cfg_key: str, + cfg_key: str | list[str], type_: Type[CfgVal], help: str, default: CfgVal = None, @@ -1155,6 +1214,7 @@ def add( value (if any). If the ``default`` is not specified then this option is a required option. """ + primary_key, aliases = self._get_primary_key_and_aliases(cfg_key) if required and default is not None: raise ValueError( f"Required option: {cfg_key} must not specify default value. Given: {default}" @@ -1165,8 +1225,10 @@ def add( f"Option: {cfg_key}, must be of type: {type_}." f" Given: {default} ({type(default).__name__})" ) - - self._opts[cfg_key] = runopt(default, type_, required, help) + opt = runopt(default, type_, required, help, aliases) + for alias in aliases: + self._alias_to_key[alias] = primary_key + self._opts[primary_key] = opt def update(self, other: "runopts") -> None: self._opts.update(other._opts) diff --git a/torchx/specs/test/api_test.py b/torchx/specs/test/api_test.py index 6bbacd5ee..9f7f8aa1b 100644 --- a/torchx/specs/test/api_test.py +++ b/torchx/specs/test/api_test.py @@ -578,6 +578,49 @@ def test_runopts_add(self) -> None: # this print is intentional (demonstrates the intended usecase) print(opts) + def test_runopts_add_with_aliases(self) -> None: + opts = runopts() + opts.add( + ["job_priority", runopt.alias("jobPriority")], + type_=str, + help="priority for the job", + ) + self.assertEqual(1, len(opts._opts)) + self.assertIsNotNone(opts.get("job_priority")) + self.assertIsNotNone(opts.get("jobPriority")) + + def test_runopts_resolve_with_aliases(self) -> None: + opts = runopts() + opts.add( + ["job_priority", runopt.alias("jobPriority")], + type_=str, + help="priority for the job", + ) + opts.resolve({"job_priority": "high"}) + opts.resolve({"jobPriority": "low"}) + with self.assertRaises(InvalidRunConfigException): + opts.resolve({"job_priority": "high", "jobPriority": "low"}) + + def test_runopts_resolve_with_none_valued_aliases(self) -> None: + opts = runopts() + opts.add( + ["job_priority", runopt.alias("jobPriority")], + type_=str, + help="priority for the job", + ) + opts.add( + ["modelTypeName", runopt.alias("model_type_name")], + type_=Union[str, None], + help="ML Hub Model Type to attribute resource utilization for job", + ) + resolved_opts = opts.resolve({"model_type_name": None, "jobPriority": "low"}) + self.assertEqual(resolved_opts.get("model_type_name"), None) + self.assertEqual(resolved_opts.get("jobPriority"), "low") + self.assertEqual(resolved_opts, {"model_type_name": None, "jobPriority": "low"}) + + with self.assertRaises(InvalidRunConfigException): + opts.resolve({"model_type_name": None, "modelTypeName": "low"}) + def get_runopts(self) -> runopts: opts = runopts() opts.add("run_as", type_=str, help="run as user", required=True)