Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions neural_lam/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,69 @@ class InvalidConfigError(Exception):
pass


def validate_config(config: "NeuralLAMConfig", config_path: str) -> None:
    """Validate a loaded NeuralLAMConfig and raise descriptive errors.

    This runs after YAML parsing so dataclass defaults are already applied.
    It catches issues that would otherwise produce cryptic runtime tracebacks,
    such as a missing or non-existent datastore config file.

    Parameters
    ----------
    config : NeuralLAMConfig
        The fully loaded config object.
    config_path : str
        Path to the neural-lam YAML config file. Used to resolve the
        datastore config path, which is relative to this file.

    Raises
    ------
    InvalidConfigError
        If any required field is missing, invalid, or points to a
        non-existent file. All problems found are reported together in a
        single numbered error message rather than one at a time.
    """
    errors = []

    # datastore.config_path must be set, and must resolve to an existing
    # file. Check for an unset/empty value first: joining None onto a Path
    # would raise a cryptic TypeError, which is exactly the kind of failure
    # this validator exists to replace with a readable message.
    ds_config_path = config.datastore.config_path
    if not ds_config_path:
        errors.append(
            "Missing required config field: 'datastore.config_path'.\n"
            " The 'datastore' section must set 'config_path' to the "
            "datastore config file (relative to the neural-lam config "
            f"file at: {config_path})."
        )
    else:
        # The datastore config path is resolved relative to the directory
        # containing the neural-lam config file.
        resolved = Path(config_path).parent / ds_config_path
        if not resolved.exists():
            errors.append(
                f"Missing required config field: 'datastore.config_path'.\n"
                f" Resolved path does not exist: {resolved}\n"
                f" This path is resolved relative to your neural-lam config "
                f"file at: {config_path}\n"
                f" Check that 'config_path' in the 'datastore' section is "
                f"correct."
            )

    # if ManualStateFeatureWeighting, weights must not be empty -- an empty
    # mapping would silently weight nothing, so fail loudly with an example.
    weighting = config.training.state_feature_weighting
    if isinstance(weighting, ManualStateFeatureWeighting):
        if not weighting.weights:
            errors.append(
                "Invalid config field: 'training.state_feature_weighting.weights'.\n"
                " ManualStateFeatureWeighting requires at least one weight "
                "entry.\n"
                " Example:\n"
                " training:\n"
                " state_feature_weighting:\n"
                " __config_class__: ManualStateFeatureWeighting\n"
                " weights:\n"
                " u100m: 1.0\n"
                " v100m: 1.0"
            )

    # Report every problem found in one exception so the user can fix the
    # whole config in a single pass.
    if errors:
        error_list = "\n\n".join(
            f" [{i + 1}] {e}" for i, e in enumerate(errors)
        )
        raise InvalidConfigError(
            f"neural-lam config validation failed "
            f"({len(errors)} error(s)):\n\n{error_list}\n\n"
            f"Refer to the config documentation for correct usage."
        )


def load_config_and_datastore(
config_path: str,
) -> tuple[NeuralLAMConfig, Union[MDPDatastore, NpyFilesDatastoreMEPS]]:
Expand All @@ -177,6 +240,9 @@ def load_config_and_datastore(
"There was an error loading the configuration file at "
f"{config_path}. "
) from ex

validate_config(config, config_path)

# datastore config is assumed to be relative to the config file
datastore_config_path = (
Path(config_path).parent / config.datastore.config_path
Expand Down
96 changes: 96 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,99 @@ def test_config_serialization(state_weighting_config):
def test_config_load_from_yaml(yaml_str, config_expected):
c = nlconfig.NeuralLAMConfig.from_yaml(yaml_str)
assert c == config_expected


# Tests for validate_config
def test_validate_config_passes_with_existing_datastore_path(tmp_path):
    """A datastore config_path resolving to a real file must be accepted."""
    # Create the datastore file the config will point at.
    (tmp_path / "datastore.yaml").write_text("dummy: true\n")

    cfg = nlconfig.NeuralLAMConfig(
        datastore=nlconfig.DatastoreSelection(
            kind="mdp", config_path="datastore.yaml"
        ),
        training=nlconfig.TrainingConfig(),
    )

    # Must not raise.
    nlconfig.validate_config(cfg, str(tmp_path / "nlam_config.yaml"))


def test_validate_config_raises_on_missing_datastore_file(tmp_path):
    """A datastore config_path pointing at a file that does not exist on
    disk must trigger InvalidConfigError."""
    cfg = nlconfig.NeuralLAMConfig(
        datastore=nlconfig.DatastoreSelection(
            kind="mdp", config_path="does_not_exist.yaml"
        ),
        training=nlconfig.TrainingConfig(),
    )

    config_file = str(tmp_path / "nlam_config.yaml")
    with pytest.raises(
        nlconfig.InvalidConfigError, match="datastore.config_path"
    ):
        nlconfig.validate_config(cfg, config_file)


def test_validate_config_error_message_contains_resolved_path(tmp_path):
    """The raised error must name the resolved path so users can see
    exactly which file was looked up and found missing."""
    cfg = nlconfig.NeuralLAMConfig(
        datastore=nlconfig.DatastoreSelection(
            kind="mdp", config_path="missing.yaml"
        ),
        training=nlconfig.TrainingConfig(),
    )

    with pytest.raises(nlconfig.InvalidConfigError) as exc_info:
        nlconfig.validate_config(cfg, str(tmp_path / "nlam_config.yaml"))

    # The missing file's name must appear in the message.
    assert "missing.yaml" in str(exc_info.value)


def test_validate_config_raises_on_empty_manual_weights(tmp_path):
    """An empty weights mapping under ManualStateFeatureWeighting is
    invalid and must be rejected at startup."""
    # Provide a valid datastore file so only the weighting error can fire.
    (tmp_path / "datastore.yaml").write_text("dummy: true\n")

    cfg = nlconfig.NeuralLAMConfig(
        datastore=nlconfig.DatastoreSelection(
            kind="mdp", config_path="datastore.yaml"
        ),
        training=nlconfig.TrainingConfig(
            state_feature_weighting=nlconfig.ManualStateFeatureWeighting(
                weights={}
            )
        ),
    )

    with pytest.raises(
        nlconfig.InvalidConfigError, match="state_feature_weighting"
    ):
        nlconfig.validate_config(cfg, str(tmp_path / "nlam_config.yaml"))


def test_validate_config_passes_with_manual_weights(tmp_path):
    """ManualStateFeatureWeighting with a non-empty weights mapping passes."""
    # Valid datastore file so the path check cannot interfere.
    (tmp_path / "datastore.yaml").write_text("dummy: true\n")

    cfg = nlconfig.NeuralLAMConfig(
        datastore=nlconfig.DatastoreSelection(
            kind="mdp", config_path="datastore.yaml"
        ),
        training=nlconfig.TrainingConfig(
            state_feature_weighting=nlconfig.ManualStateFeatureWeighting(
                weights={"u100m": 1.0, "v100m": 0.5}
            )
        ),
    )

    # Must not raise.
    nlconfig.validate_config(cfg, str(tmp_path / "nlam_config.yaml"))