Skip to content
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Add `GlobalDummyDatastore` to support testing and development of global domain models without requiring large datasets. [\#453](https://github.com/mllam/neural-lam/pull/453) @sohampatil01-svg

- Enable `pin_memory` in DataLoaders when GPU is available for faster async CPU-to-GPU data transfers [\#236](https://github.com/mllam/neural-lam/pull/236) @abhaygoudannavar

### Changed
Expand All @@ -18,6 +20,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Handle single-device runs in `all_gather_cat` by only flattening the gathered tensor when `all_gather` adds a process dimension, returning it unchanged otherwise. [\#452](https://github.com/mllam/neural-lam/pull/452) @sohampatil01-svg

- Initialize `da_forcing_mean` and `da_forcing_std` to `None` when forcing data is absent, fixing `AttributeError` in `WeatherDataset` with `standardize=True` [\#369](https://github.com/mllam/neural-lam/issues/369) @Sir-Sloth-The-Lazy

- Ensure proper sorting of `analysis_time` in `NpyFilesDatastoreMEPS._get_analysis_times` independent of the order in which files are processed with glob [\#386](https://github.com/mllam/neural-lam/pull/386) @Gopisokk
Expand Down Expand Up @@ -57,6 +61,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Maintenance

- Fix reST formatting errors in `ARModel` methods to ensure documentation builds correctly. [\#458](https://github.com/mllam/neural-lam/pull/458) @sohampatil01-svg

- Update PR template to clarify milestone/roadmap requirement and maintenance changes [\#186](https://github.com/mllam/neural-lam/pull/186) @joeloskarsson

- Update CI/CD to use python 3.13 for testing and full range of current python versions for linting (3.10 - 3.14) [\#173](https://github.com/mllam/neural-lam/pull/173) @observingClouds
Expand Down
18 changes: 12 additions & 6 deletions neural_lam/models/ar_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,11 +278,14 @@ def unroll_prediction(self, init_states, forcing_features, true_states):

def common_step(self, batch):
"""
Predict on single batch batch consists of: init_states: (B, 2,
num_grid_nodes, d_features) target_states: (B, pred_steps,
num_grid_nodes, d_features) forcing_features: (B, pred_steps,
num_grid_nodes, d_forcing),
where index 0 corresponds to index 1 of init_states
Predict on single batch.

Batch consists of:
init_states: (B, 2, num_grid_nodes, d_features)
target_states: (B, pred_steps, num_grid_nodes, d_features)
forcing_features: (B, pred_steps, num_grid_nodes, d_forcing)

Note: index 0 corresponds to index 1 of init_states.
"""
(init_states, target_states, forcing_features, batch_times) = batch

Expand Down Expand Up @@ -327,7 +330,10 @@ def all_gather_cat(self, tensor_to_gather):

returns: (K*d1, d2, ...)
"""
return self.all_gather(tensor_to_gather).flatten(0, 1)
gathered = self.all_gather(tensor_to_gather)
if gathered.dim() > tensor_to_gather.dim():
return gathered.flatten(0, 1)
return gathered

# newer lightning versions requires batch_idx argument, even if unused
# pylint: disable-next=unused-argument
Expand Down
4 changes: 3 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)

# Local
from .dummy_datastore import DummyDatastore
from .dummy_datastore import DummyDatastore, GlobalDummyDatastore

# Disable weights and biases to avoid unnecessary logging
# and to avoid having to deal with authentication
Expand Down Expand Up @@ -103,9 +103,11 @@ def download_meps_example_reduced_dataset():
),
npyfilesmeps=None,
dummydata=None,
dummydata_global=None,
)

DATASTORES[DummyDatastore.SHORT_NAME] = DummyDatastore
DATASTORES[GlobalDummyDatastore.SHORT_NAME] = GlobalDummyDatastore


def init_datastore_example(datastore_kind):
Expand Down
15 changes: 15 additions & 0 deletions tests/dummy_datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,3 +466,18 @@ def grid_shape_state(self) -> CartesianGridShape:

n_points_1d = int(np.sqrt(self.num_grid_points))
return CartesianGridShape(x=n_points_1d, y=n_points_1d)


class GlobalDummyDatastore(DummyDatastore):
    """
    Dummy datastore representing a global domain, i.e. a domain without
    any lateral boundaries.
    """

    SHORT_NAME = "dummydata_global"

    @cached_property
    def boundary_mask(self) -> xr.DataArray:
        """
        Boundary mask for a global domain: every grid point is interior,
        so the mask is all zeros, with the same shape/coords as the
        stored parent mask.
        """
        template = self.ds["boundary_mask"]
        return xr.zeros_like(template)
12 changes: 9 additions & 3 deletions tests/test_datastores.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,9 +229,15 @@ def test_boundary_mask(datastore_name):
assert isinstance(da_mask, xr.DataArray)
assert set(da_mask.dims) == {"grid_index"}
assert da_mask.dtype == "int"
assert set(da_mask.values) == {0, 1}
assert da_mask.sum() > 0
assert da_mask.sum() < da_mask.size
assert set(da_mask.values).issubset({0, 1})

if "global" in datastore_name:
# global datastores have no lateral boundaries
assert da_mask.sum() == 0
else:
# LAM datastores must have a boundary
assert da_mask.sum() > 0
assert da_mask.sum() < da_mask.size

if isinstance(datastore, BaseRegularGridDatastore):
grid_shape = datastore.grid_shape_state
Expand Down
13 changes: 9 additions & 4 deletions tests/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,7 @@ def run_simple_training(datastore, set_output_std):
max_epochs=1,
deterministic=True,
accelerator=device_name,
# XXX: `devices` has to be set to 2 otherwise
# neural_lam.models.ar_model.ARModel.aggregate_and_plot_metrics fails
# because it expects to aggregate over multiple devices
devices=2,
devices=1,
log_every_n_steps=1,
# use `detect_anomaly` to ensure that we don't have NaNs popping up
# during training
Expand Down Expand Up @@ -124,3 +121,11 @@ def test_training(datastore_name):
def test_training_output_std():
    """Run a short training with `set_output_std=True`."""
    # Only the mdp datastore is exercised for the output-std path
    mdp_datastore = init_datastore_example("mdp")
    run_simple_training(mdp_datastore, set_output_std=True)


def test_training_global():
    """
    Run a short training on the global dummy datastore, which has no
    lateral boundaries (all-zero boundary mask).
    """
    global_datastore = init_datastore_example("dummydata_global")
    run_simple_training(global_datastore, set_output_std=False)