diff --git a/CHANGELOG.md b/CHANGELOG.md index f62f3b676..40941487b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Allow training with state-only datastores (no `static` features); `get_vars_names` now emits a `UserWarning` instead of raising an error [\#231](https://github.com/mllam/neural-lam/pull/231) @varunsiravuri, fixes [\#125](https://github.com/mllam/neural-lam/issues/125) + - Replace `shell=True` subprocess call in `compute_standardization_stats.py` with a safe argument list and Python-side hostname parsing to prevent command injection via `SLURM_JOB_NODELIST` [\#264](https://github.com/mllam/neural-lam/pull/264) @ashum9 - Avoid NaN when standardizing fields with zero std [#189](https://github.com/mllam/neural-lam/pull/189) @varunsiravuri diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py index ed37a4003..7ef11d075 100644 --- a/neural_lam/datastore/mdp.py +++ b/neural_lam/datastore/mdp.py @@ -184,8 +184,16 @@ def get_vars_names(self, category: str) -> List[str]: The names of the variables in the given category. """ - if category not in self._ds and category == "forcing": - warnings.warn("no forcing data found in datastore") + if f"{category}_feature" not in self._ds: + if category == "forcing": + warnings.warn("no forcing data found in datastore") + elif category == "static": + warnings.warn( + "No static features found in the datastore. " + "Training without static features.", + UserWarning, + stacklevel=2, + ) return [] return self._ds[f"{category}_feature"].values.tolist() diff --git a/tests/datastore_examples/mdp/danra_100m_winds/state_only.datastore.yaml b/tests/datastore_examples/mdp/danra_100m_winds/state_only.datastore.yaml new file mode 100644 index 000000000..87810ad1d --- /dev/null +++ b/tests/datastore_examples/mdp/danra_100m_winds/state_only.datastore.yaml @@ -0,0 +1,89 @@ +schema_version: v0.5.0 +dataset_version: v0.1.0 + +output: + variables: + state: [time, grid_index, state_feature] + coord_ranges: + time: + start: 2022-04-01T00:00 + end: 2022-04-10T00:00 + step: PT3H + x: + start: null + end: -894248.0 + y: + start: null + end: -142041.0 + chunking: + time: 1 + splitting: + dim: time + splits: + train: + start: 2022-04-01T00:00 + end: 2022-04-04T00:00 + compute_statistics: + ops: [mean, std, diff_mean, diff_std] + dims: [grid_index, time] + val: + start: 2022-04-04T00:00 + end: 2022-04-07T00:00 + test: + start: 2022-04-07T00:00 + end: 2022-04-10T00:00 + +inputs: + danra_height_levels: + path: https://object-store.os-api.cci1.ecmwf.int/mllam-testdata/danra_cropped/v0.2.0/height_levels.zarr + dims: [time, x, y, altitude] + variables: + u: + altitude: + values: [100,] + units: m + v: + altitude: + values: [100, ] + units: m + dim_mapping: + time: + method: rename + dim: time + state_feature: + method: stack_variables_by_var_name + dims: [altitude] + name_format: "{var_name}{altitude}m" + grid_index: + method: stack + dims: [x, y] + target_output_variable: state + + danra_surface: + path: https://object-store.os-api.cci1.ecmwf.int/mllam-testdata/danra_cropped/v0.2.0/single_levels.zarr + dims: [time, x, y] + variables: + - r2m + - t2m + dim_mapping: + time: + method: rename + dim: time + grid_index: + method: stack + dims: [x, y] + state_feature: + method: stack_variables_by_var_name + name_format: "{var_name}" + target_output_variable: state + +extra: + projection: + class_name: LambertConformal + kwargs: + central_longitude: 25.0 + central_latitude: 56.7 + standard_parallels: [56.7, 56.7] + globe: + semimajor_axis: 6367470.0 + semiminor_axis: 6367470.0 diff --git a/tests/test_datastore_static_only.py b/tests/test_datastore_static_only.py new file mode 100644 index 000000000..3d5836fae --- /dev/null +++ b/tests/test_datastore_static_only.py @@ -0,0 +1,48 @@ +# Third-party +import pytest + +# First-party +from neural_lam.datastore.mdp import MDPDatastore +from tests.conftest import init_datastore_example +from tests.test_training import run_simple_training + +STATE_ONLY_CONFIG = ( + "tests/datastore_examples/mdp/danra_100m_winds/state_only.datastore.yaml" +) + + +def test_state_only_datastore_emits_warning(): + """Test that a state-only datastore (no static features) emits a + UserWarning instead of raising an error.""" + with pytest.warns(UserWarning, match="No static features"): + datastore = MDPDatastore(config_path=STATE_ONLY_CONFIG) + + assert datastore is not None + + +def test_state_only_datastore_static_returns_empty(): + """Test that get_vars_names returns [] for static in a state-only + datastore, consistent with how missing forcing is handled.""" + with pytest.warns(UserWarning): + datastore = MDPDatastore(config_path=STATE_ONLY_CONFIG) + + # Static should be empty - this is the core behaviour being fixed + assert datastore.get_vars_names("static") == [] + + # State variables should still be present and non-empty + assert len(datastore.get_vars_names("state")) > 0 + + +def test_state_only_datastore_forcing_returns_empty(): + """Test that get_vars_names also returns [] for forcing in a + state-only datastore (no forcing in config).""" + with pytest.warns(UserWarning): + datastore = MDPDatastore(config_path=STATE_ONLY_CONFIG) + + assert datastore.get_vars_names("forcing") == [] + + +def test_state_only_datastore_training_setup_runs(): + """Run the shared small training setup against the MDP datastore.""" + datastore = init_datastore_example(MDPDatastore.SHORT_NAME) + run_simple_training(datastore, set_output_std=False)