Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Add `GlobalDummyDatastore` to support testing and development of global domain models without requiring large datasets. [\#453](https://github.com/mllam/neural-lam/pull/453) @sohampatil01-svg

- Enable `pin_memory` in DataLoaders when GPU is available for faster async CPU-to-GPU data transfers [\#236](https://github.com/mllam/neural-lam/pull/236) @abhaygoudannavar

### Changed
Expand Down
5 changes: 4 additions & 1 deletion neural_lam/models/ar_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,10 @@ def all_gather_cat(self, tensor_to_gather):

returns: (K*d1, d2, ...)
"""
return self.all_gather(tensor_to_gather).flatten(0, 1)
gathered = self.all_gather(tensor_to_gather)
if gathered.dim() > tensor_to_gather.dim():
return gathered.flatten(0, 1)
return gathered

# newer lightning versions requires batch_idx argument, even if unused
# pylint: disable-next=unused-argument
Expand Down
4 changes: 3 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)

# Local
from .dummy_datastore import DummyDatastore
from .dummy_datastore import DummyDatastore, GlobalDummyDatastore

# Disable weights and biases to avoid unnecessary logging
# and to avoid having to deal with authentication
Expand Down Expand Up @@ -103,9 +103,11 @@ def download_meps_example_reduced_dataset():
),
npyfilesmeps=None,
dummydata=None,
dummydata_global=None,
)

DATASTORES[DummyDatastore.SHORT_NAME] = DummyDatastore
DATASTORES[GlobalDummyDatastore.SHORT_NAME] = GlobalDummyDatastore


def init_datastore_example(datastore_kind):
Expand Down
15 changes: 15 additions & 0 deletions tests/dummy_datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,3 +466,18 @@ def grid_shape_state(self) -> CartesianGridShape:

n_points_1d = int(np.sqrt(self.num_grid_points))
return CartesianGridShape(x=n_points_1d, y=n_points_1d)


class GlobalDummyDatastore(DummyDatastore):
    """Dummy datastore covering a global domain, i.e. one with no
    lateral boundaries (every grid point is interior)."""

    # Identifier used to register this datastore in test fixtures.
    SHORT_NAME = "dummydata_global"

    @cached_property
    def boundary_mask(self) -> xr.DataArray:
        """Boundary mask with every grid point marked as interior.

        Mirrors the parent's mask in shape/coords/dtype, but is all
        zeros because a global domain has no boundary.
        """
        parent_mask = self.ds["boundary_mask"]
        return xr.zeros_like(parent_mask)
12 changes: 9 additions & 3 deletions tests/test_datastores.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,9 +229,15 @@ def test_boundary_mask(datastore_name):
assert isinstance(da_mask, xr.DataArray)
assert set(da_mask.dims) == {"grid_index"}
assert da_mask.dtype == "int"
assert set(da_mask.values) == {0, 1}
assert da_mask.sum() > 0
assert da_mask.sum() < da_mask.size
assert set(da_mask.values).issubset({0, 1})

if "global" in datastore_name:
# global datastores have no lateral boundaries
assert da_mask.sum() == 0
else:
# LAM datastores must have a boundary
assert da_mask.sum() > 0
assert da_mask.sum() < da_mask.size

if isinstance(datastore, BaseRegularGridDatastore):
grid_shape = datastore.grid_shape_state
Expand Down
13 changes: 9 additions & 4 deletions tests/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,7 @@ def run_simple_training(datastore, set_output_std):
max_epochs=1,
deterministic=True,
accelerator=device_name,
# XXX: `devices` has to be set to 2 otherwise
# neural_lam.models.ar_model.ARModel.aggregate_and_plot_metrics fails
# because it expects to aggregate over multiple devices
devices=2,
devices=1,
log_every_n_steps=1,
# use `detect_anomaly` to ensure that we don't have NaNs popping up
# during training
Expand Down Expand Up @@ -124,3 +121,11 @@ def test_training(datastore_name):
def test_training_output_std():
    """Smoke-test training with learned output std enabled.

    Only the mdp datastore is exercised here; the full datastore matrix
    is covered by the parametrized training test.
    """
    mdp_datastore = init_datastore_example("mdp")
    run_simple_training(mdp_datastore, set_output_std=True)


def test_training_global():
    """Smoke-test training on a global datastore.

    A global domain has no lateral boundaries (all-zero boundary mask),
    so this verifies the training loop handles that case.
    """
    global_datastore = init_datastore_example("dummydata_global")
    run_simple_training(global_datastore, set_output_std=False)