diff --git a/CHANGELOG.md b/CHANGELOG.md
index d976ad4b9..41266d5f0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Added
 
+- Add `GlobalDummyDatastore` to support testing and development of global domain models without requiring large datasets. [\#453](https://github.com/mllam/neural-lam/pull/453) @sohampatil01-svg
+
 - Enable `pin_memory` in DataLoaders when GPU is available for faster async CPU-to-GPU data transfers [\#236](https://github.com/mllam/neural-lam/pull/236) @abhaygoudannavar
 
 ### Changed
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index f1bcb461d..e8f7ca8eb 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -327,7 +327,10 @@ def all_gather_cat(self, tensor_to_gather):
 
         returns: (K*d1, d2, ...)
         """
-        return self.all_gather(tensor_to_gather).flatten(0, 1)
+        gathered = self.all_gather(tensor_to_gather)
+        if gathered.dim() > tensor_to_gather.dim():
+            return gathered.flatten(0, 1)
+        return gathered
 
     # newer lightning versions requires batch_idx argument, even if unused
     # pylint: disable-next=unused-argument
diff --git a/tests/conftest.py b/tests/conftest.py
index 47237ed55..c3f6a1445 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -16,7 +16,7 @@
 )
 
 # Local
-from .dummy_datastore import DummyDatastore
+from .dummy_datastore import DummyDatastore, GlobalDummyDatastore
 
 # Disable weights and biases to avoid unnecessary logging
 # and to avoid having to deal with authentication
@@ -103,9 +103,11 @@ def download_meps_example_reduced_dataset():
     ),
     npyfilesmeps=None,
     dummydata=None,
+    dummydata_global=None,
 )
 
 DATASTORES[DummyDatastore.SHORT_NAME] = DummyDatastore
+DATASTORES[GlobalDummyDatastore.SHORT_NAME] = GlobalDummyDatastore
 
 
 def init_datastore_example(datastore_kind):
diff --git a/tests/dummy_datastore.py b/tests/dummy_datastore.py
index e3de83ca4..3108bb0cd 100644
--- a/tests/dummy_datastore.py
+++ b/tests/dummy_datastore.py
@@ -466,3 +466,18 @@ def grid_shape_state(self) -> CartesianGridShape:
 
         n_points_1d = int(np.sqrt(self.num_grid_points))
         return CartesianGridShape(x=n_points_1d, y=n_points_1d)
+
+
+class GlobalDummyDatastore(DummyDatastore):
+    """
+    Dummy datastore that represents a global domain (no lateral boundaries).
+    """
+
+    SHORT_NAME = "dummydata_global"
+
+    @cached_property
+    def boundary_mask(self) -> xr.DataArray:
+        """
+        Return an all-zero boundary mask.
+        """
+        return xr.zeros_like(self.ds["boundary_mask"])
diff --git a/tests/test_datastores.py b/tests/test_datastores.py
index c719c69df..885d8f2a8 100644
--- a/tests/test_datastores.py
+++ b/tests/test_datastores.py
@@ -229,9 +229,15 @@ def test_boundary_mask(datastore_name):
     assert isinstance(da_mask, xr.DataArray)
     assert set(da_mask.dims) == {"grid_index"}
     assert da_mask.dtype == "int"
-    assert set(da_mask.values) == {0, 1}
-    assert da_mask.sum() > 0
-    assert da_mask.sum() < da_mask.size
+    assert set(da_mask.values).issubset({0, 1})
+
+    if "global" in datastore_name:
+        # global datastores have no lateral boundaries
+        assert da_mask.sum() == 0
+    else:
+        # LAM datastores must have a boundary
+        assert da_mask.sum() > 0
+        assert da_mask.sum() < da_mask.size
 
     if isinstance(datastore, BaseRegularGridDatastore):
         grid_shape = datastore.grid_shape_state
diff --git a/tests/test_training.py b/tests/test_training.py
index e8c131572..930742c41 100644
--- a/tests/test_training.py
+++ b/tests/test_training.py
@@ -41,10 +41,7 @@ def run_simple_training(datastore, set_output_std):
         max_epochs=1,
         deterministic=True,
         accelerator=device_name,
-        # XXX: `devices` has to be set to 2 otherwise
-        # neural_lam.models.ar_model.ARModel.aggregate_and_plot_metrics fails
-        # because it expects to aggregate over multiple devices
-        devices=2,
+        devices=1,
         log_every_n_steps=1,
         # use `detect_anomaly` to ensure that we don't have NaNs popping up
         # during training
@@ -124,3 +121,11 @@ def test_training(datastore_name):
 def test_training_output_std():
     datastore = init_datastore_example("mdp")  # Test only with mdp datastore
     run_simple_training(datastore, set_output_std=True)
+
+
+def test_training_global():
+    """
+    Test training with a global datastore (no boundaries).
+    """
+    datastore = init_datastore_example("dummydata_global")
+    run_simple_training(datastore, set_output_std=False)