def validate_config(config: "NeuralLAMConfig", config_path: str) -> None:
    """Validate a loaded NeuralLAMConfig and raise descriptive errors.

    This runs after YAML parsing so dataclass defaults are already applied.
    It catches issues that would otherwise produce cryptic runtime
    tracebacks, such as a datastore config path that points to a
    non-existent file.

    Parameters
    ----------
    config : NeuralLAMConfig
        The fully loaded config object.
    config_path : str
        Path to the neural-lam YAML config file. Used to resolve the
        datastore config path, which is relative to this file.

    Raises
    ------
    InvalidConfigError
        If any required field is missing, invalid, or points to a
        non-existent file.
    """
    # Collect every problem before raising so users can fix all of them in
    # one pass instead of replaying the error-fix cycle per field.
    errors = []

    # datastore.config_path must resolve to an existing file. The path is
    # relative to the directory containing the neural-lam config file.
    resolved = Path(config_path).parent / config.datastore.config_path
    if not resolved.exists():
        # NOTE: the field itself is present (dataclass loading would have
        # failed otherwise) -- the problem is the file it points to, so
        # the message says exactly that rather than "missing field".
        errors.append(
            f"Invalid config field: 'datastore.config_path'.\n"
            f"  Resolved path does not exist: {resolved}\n"
            f"  This path is resolved relative to your neural-lam config "
            f"file at: {config_path}\n"
            f"  Check that 'config_path' in the 'datastore' section is "
            f"correct."
        )

    # ManualStateFeatureWeighting with no weights would silently weight
    # nothing; require at least one entry up front.
    weighting = config.training.state_feature_weighting
    if isinstance(weighting, ManualStateFeatureWeighting) and not weighting.weights:
        errors.append(
            "Invalid config field: 'training.state_feature_weighting.weights'.\n"
            "  ManualStateFeatureWeighting requires at least one weight "
            "entry.\n"
            "  Example:\n"
            "    training:\n"
            "      state_feature_weighting:\n"
            "        __config_class__: ManualStateFeatureWeighting\n"
            "        weights:\n"
            "          u100m: 1.0\n"
            "          v100m: 1.0"
        )

    if errors:
        error_list = "\n\n".join(
            f"  [{i + 1}] {e}" for i, e in enumerate(errors)
        )
        raise InvalidConfigError(
            f"neural-lam config validation failed "
            f"({len(errors)} error(s)):\n\n{error_list}\n\n"
            f"Refer to the config documentation for correct usage."
        )
" ) from ex + + validate_config(config, config_path) + # datastore config is assumed to be relative to the config file datastore_config_path = ( Path(config_path).parent / config.datastore.config_path diff --git a/tests/test_config.py b/tests/test_config.py index 1ff40bc6..2ca6b910 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -70,3 +70,99 @@ def test_config_serialization(state_weighting_config): def test_config_load_from_yaml(yaml_str, config_expected): c = nlconfig.NeuralLAMConfig.from_yaml(yaml_str) assert c == config_expected + + +# Tests for validate_config +def test_validate_config_passes_with_existing_datastore_path(tmp_path): + """validate_config should not raise when datastore config_path exists.""" + datastore_file = tmp_path / "datastore.yaml" + datastore_file.write_text("dummy: true\n") + + nlam_config_path = str(tmp_path / "nlam_config.yaml") + + config = nlconfig.NeuralLAMConfig( + datastore=nlconfig.DatastoreSelection( + kind="mdp", config_path="datastore.yaml" + ), + training=nlconfig.TrainingConfig(), + ) + + nlconfig.validate_config(config, nlam_config_path) + + +def test_validate_config_raises_on_missing_datastore_file(tmp_path): + """validate_config raises InvalidConfigError when the resolved + datastore config path does not exist on disk.""" + nlam_config_path = str(tmp_path / "nlam_config.yaml") + + config = nlconfig.NeuralLAMConfig( + datastore=nlconfig.DatastoreSelection( + kind="mdp", config_path="does_not_exist.yaml" + ), + training=nlconfig.TrainingConfig(), + ) + + with pytest.raises(nlconfig.InvalidConfigError, match="datastore.config_path"): + nlconfig.validate_config(config, nlam_config_path) + + +def test_validate_config_error_message_contains_resolved_path(tmp_path): + """The error message must contain the resolved path so users + know exactly what file is missing.""" + nlam_config_path = str(tmp_path / "nlam_config.yaml") + + config = nlconfig.NeuralLAMConfig( + datastore=nlconfig.DatastoreSelection( + kind="mdp", 
config_path="missing.yaml" + ), + training=nlconfig.TrainingConfig(), + ) + + with pytest.raises(nlconfig.InvalidConfigError) as exc_info: + nlconfig.validate_config(config, nlam_config_path) + + assert "missing.yaml" in str(exc_info.value) + + +def test_validate_config_raises_on_empty_manual_weights(tmp_path): + """ManualStateFeatureWeighting with an empty weights dict is invalid + and should raise InvalidConfigError at startup.""" + datastore_file = tmp_path / "datastore.yaml" + datastore_file.write_text("dummy: true\n") + nlam_config_path = str(tmp_path / "nlam_config.yaml") + + config = nlconfig.NeuralLAMConfig( + datastore=nlconfig.DatastoreSelection( + kind="mdp", config_path="datastore.yaml" + ), + training=nlconfig.TrainingConfig( + state_feature_weighting=nlconfig.ManualStateFeatureWeighting( + weights={} + ) + ), + ) + + with pytest.raises( + nlconfig.InvalidConfigError, match="state_feature_weighting" + ): + nlconfig.validate_config(config, nlam_config_path) + + +def test_validate_config_passes_with_manual_weights(tmp_path): + """ManualStateFeatureWeighting with actual weights should pass.""" + datastore_file = tmp_path / "datastore.yaml" + datastore_file.write_text("dummy: true\n") + nlam_config_path = str(tmp_path / "nlam_config.yaml") + + config = nlconfig.NeuralLAMConfig( + datastore=nlconfig.DatastoreSelection( + kind="mdp", config_path="datastore.yaml" + ), + training=nlconfig.TrainingConfig( + state_feature_weighting=nlconfig.ManualStateFeatureWeighting( + weights={"u100m": 1.0, "v100m": 0.5} + ) + ), + ) + + nlconfig.validate_config(config, nlam_config_path)