From f9b48bac516f71d952a7be24b547280ddee04d6d Mon Sep 17 00:00:00 2001 From: Etienne Pot Date: Fri, 15 Nov 2024 06:36:28 -0800 Subject: [PATCH] Simplify sweep flags Sweeps are now passed directly as `--cfg.model="{'__qualname__': 'kd.nn:Transformer', 'num_layers': 16}"`, no need for an additional special flag PiperOrigin-RevId: 696869108 --- ml_collections/config_flags/config_flags.py | 27 ++++++++++++++++++- .../tests/config_overriding_test.py | 15 +++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/ml_collections/config_flags/config_flags.py b/ml_collections/config_flags/config_flags.py index 8569668..e8514a7 100644 --- a/ml_collections/config_flags/config_flags.py +++ b/ml_collections/config_flags/config_flags.py @@ -83,12 +83,32 @@ def flag_type(self): return 'config_literal' +class _ConfigDictParser(flags.ArgumentParser): + """Parser for ConfigDict values.""" + + def parse(self, argument: str) -> config_dict.ConfigDict: + try: + value = ast.literal_eval(argument) + except (SyntaxError, ValueError) as e: + raise ValueError( + f'Failed to parse {argument!r} as a ConfigDict: {e!r}' + ) from None + + if not isinstance(value, dict): + raise ValueError( + f'Failed to parse {argument!r} as a ConfigDict: `{value!r}` is not a' + ' dict.' + ) + return config_dict.ConfigDict(value) + + _FIELD_TYPE_TO_PARSER = { float: flags.FloatParser(), bool: flags.BooleanParser(), tuple: tuple_parser.TupleParser(), int: flags.IntegerParser(), str: flags.ArgumentParser(), + config_dict.ConfigDict: _ConfigDictParser(), object: _LiteralParser(), } @@ -816,6 +836,11 @@ def _parse(self, argument): parser = None if field_type in _FIELD_TYPE_TO_PARSER: parser = _FIELD_TYPE_TO_PARSER[field_type] + elif isinstance(field_type, type) and issubclass( + field_type, config_dict.ConfigDict + ): + # Supports ConfigDict sub-classes. + parser = _FIELD_TYPE_TO_PARSER[config_dict.ConfigDict] elif field_type_origin and field_type_origin in _FIELD_TYPE_TO_PARSER: parser = _FIELD_TYPE_TO_PARSER[field_type_origin] elif issubclass(field_type, enum.Enum): @@ -830,7 +855,7 @@ def _parse(self, argument): if parser: if not isinstance(parser, tuple_parser.TupleParser): - if isinstance(parser, _LiteralParser): + if isinstance(parser, (_LiteralParser, _ConfigDictParser)): # We do not pass the default to `_ConfigFieldFlag`, otherwise # `_LiteralParser.parse(default)` is called with `default`, # which would try to parse string. diff --git a/ml_collections/config_flags/tests/config_overriding_test.py b/ml_collections/config_flags/tests/config_overriding_test.py index 2cc0277..68a1478 100644 --- a/ml_collections/config_flags/tests/config_overriding_test.py +++ b/ml_collections/config_flags/tests/config_overriding_test.py @@ -493,6 +493,21 @@ def testLoadingLockedConfigDict(self): self.assertFalse(values.test_config.is_locked) self.assertFalse(values.test_config.nested_configdict.is_locked) + def testOverridingNestedConfigDict(self): + """Tests overriding of ConfigDict fields.""" + + config_flag = f'--test_config={_CONFIGDICT_CONFIG_FILE}' + values = _parse_flags( + f'./program {config_flag}' + ' --test_config.nested_configdict="{\\"a\\": True, \\"b\\": 123}"' + ) + self.assertEqual(values.test_config.nested_configdict.a, True) + self.assertEqual(values.test_config.nested_configdict.b, 123) + self.assertEqual( + dict(values.test_config.nested_configdict.items()), + {'a': True, 'b': 123}, + ) + @parameterized.named_parameters( ('WithTwoDashesAndEqual', '--test_config={}'.format(_TEST_DIRECTORY)), ('WithTwoDashes', '--test_config {}'.format(_TEST_DIRECTORY)),