Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Allow training with state-only datastores (no `static` features); `get_vars_names` now emits a `UserWarning` instead of raising an error [\#231](https://github.com/mllam/neural-lam/pull/231) @varunsiravuri, fixes [\#125](https://github.com/mllam/neural-lam/issues/125)

- Replace `shell=True` subprocess call in `compute_standardization_stats.py` with a safe argument list and Python-side hostname parsing to prevent command injection via `SLURM_JOB_NODELIST` [\#264](https://github.com/mllam/neural-lam/pull/264) @ashum9
- Avoid NaN when standardizing fields with zero std [#189](https://github.com/mllam/neural-lam/pull/189) @varunsiravuri

Expand Down
12 changes: 10 additions & 2 deletions neural_lam/datastore/mdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,16 @@ def get_vars_names(self, category: str) -> List[str]:
The names of the variables in the given category.

"""
if category not in self._ds and category == "forcing":
warnings.warn("no forcing data found in datastore")
if f"{category}_feature" not in self._ds:
if category == "forcing":
warnings.warn("no forcing data found in datastore")
elif category == "static":
warnings.warn(
"No static features found in the datastore. "
"Training without static features.",
UserWarning,
stacklevel=2,
)
return []
return self._ds[f"{category}_feature"].values.tolist()

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
schema_version: v0.5.0
dataset_version: v0.1.0

output:
variables:
state: [time, grid_index, state_feature]
coord_ranges:
time:
start: 2022-04-01T00:00
end: 2022-04-10T00:00
step: PT3H
x:
start: null
end: -894248.0
y:
start: null
end: -142041.0
chunking:
time: 1
splitting:
dim: time
splits:
train:
start: 2022-04-01T00:00
end: 2022-04-04T00:00
compute_statistics:
ops: [mean, std, diff_mean, diff_std]
dims: [grid_index, time]
val:
start: 2022-04-04T00:00
end: 2022-04-07T00:00
test:
start: 2022-04-07T00:00
end: 2022-04-10T00:00

inputs:
danra_height_levels:
path: https://object-store.os-api.cci1.ecmwf.int/mllam-testdata/danra_cropped/v0.2.0/height_levels.zarr
dims: [time, x, y, altitude]
variables:
u:
altitude:
values: [100,]
units: m
v:
altitude:
values: [100, ]
units: m
dim_mapping:
time:
method: rename
dim: time
state_feature:
method: stack_variables_by_var_name
dims: [altitude]
name_format: "{var_name}{altitude}m"
grid_index:
method: stack
dims: [x, y]
target_output_variable: state

danra_surface:
path: https://object-store.os-api.cci1.ecmwf.int/mllam-testdata/danra_cropped/v0.2.0/single_levels.zarr
dims: [time, x, y]
variables:
- r2m
- t2m
dim_mapping:
time:
method: rename
dim: time
grid_index:
method: stack
dims: [x, y]
state_feature:
method: stack_variables_by_var_name
name_format: "{var_name}"
target_output_variable: state

extra:
projection:
class_name: LambertConformal
kwargs:
central_longitude: 25.0
central_latitude: 56.7
standard_parallels: [56.7, 56.7]
globe:
semimajor_axis: 6367470.0
semiminor_axis: 6367470.0
48 changes: 48 additions & 0 deletions tests/test_datastore_static_only.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Third-party
import pytest

# First-party
from neural_lam.datastore.mdp import MDPDatastore
from tests.conftest import init_datastore_example
from tests.test_training import run_simple_training

STATE_ONLY_CONFIG = (
"tests/datastore_examples/mdp/danra_100m_winds/state_only.datastore.yaml"
)


def test_state_only_datastore_emits_warning():
"""Test that a state-only datastore (no static features) emits a
UserWarning instead of raising an error."""
with pytest.warns(UserWarning, match="No static features"):
datastore = MDPDatastore(config_path=STATE_ONLY_CONFIG)

assert datastore is not None


def test_state_only_datastore_static_returns_empty():
"""Test that get_vars_names returns [] for static in a state-only
datastore, consistent with how missing forcing is handled."""
with pytest.warns(UserWarning):
datastore = MDPDatastore(config_path=STATE_ONLY_CONFIG)

# Static should be empty - this is the core behaviour being fixed
assert datastore.get_vars_names("static") == []

# State variables should still be present and non-empty
assert len(datastore.get_vars_names("state")) > 0


def test_state_only_datastore_forcing_returns_empty():
"""Test that get_vars_names also returns [] for forcing in a
state-only datastore (no forcing in config)."""
with pytest.warns(UserWarning):
datastore = MDPDatastore(config_path=STATE_ONLY_CONFIG)

assert datastore.get_vars_names("forcing") == []


def test_state_only_datastore_training_setup_runs():
"""Run the shared small training setup against the MDP datastore."""
datastore = init_datastore_example(MDPDatastore.SHORT_NAME)
run_simple_training(datastore, set_output_std=False)