diff --git a/CHANGELOG.md b/CHANGELOG.md index fe66d2136..f00c0114d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Add `AGENTS.md` file to the repo to give agents more information about the codebase and the contribution culture.[\#416](https://github.com/mllam/neural-lam/pull/416) @sadamov +- Support global domains with no boundary mask: a datastore's `boundary_mask` may now be `None`, in which case the model treats every grid point as interior and the plotting code skips the boundary overlay and interior cropping [\#444](https://github.com/mllam/neural-lam/pull/444) @RajdeepKushwaha5 + - Enable `pin_memory` in DataLoaders when GPU is available for faster async CPU-to-GPU data transfers [\#236](https://github.com/mllam/neural-lam/pull/236) @abhaygoudannavar ### Changed diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py index f6dddb007..854c05b15 100644 --- a/neural_lam/datastore/base.py +++ b/neural_lam/datastore/base.py @@ -274,17 +274,20 @@ def get_dataarray( @cached_property @abc.abstractmethod - def boundary_mask(self) -> xr.DataArray: + def boundary_mask(self) -> Optional[xr.DataArray]: """ Return the boundary mask for the dataset, with spatial dimensions stacked. Where the value is 1, the grid point is a boundary point, and where the value is 0, the grid point is not a boundary point. + For global datastores that have no lateral boundaries, this should + return None. + Returns ------- - xr.DataArray + xr.DataArray or None The boundary mask for the dataset, with dimensions - `('grid_index',)`. + `('grid_index',)`, or None for global domains.
""" pass diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index a411a3afc..bd3131321 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -124,11 +124,17 @@ def __init__( # Instantiate loss function self.loss = metrics.get_metric(args.loss) - boundary_mask = torch.tensor( - da_boundary_mask.values, dtype=torch.float32 - ).unsqueeze( - 1 - ) # add feature dim + if da_boundary_mask is not None: + boundary_mask = torch.tensor( + da_boundary_mask.values, dtype=torch.float32 + ).unsqueeze( + 1 + ) # add feature dim + else: + # Global domain: no boundary points, all grid points are interior + boundary_mask = torch.zeros( + self.num_grid_nodes, 1, dtype=torch.float32 + ) self.register_buffer("boundary_mask", boundary_mask, persistent=False) # Pre-compute interior mask for use in loss function diff --git a/neural_lam/vis.py b/neural_lam/vis.py index 06f2f6d35..01c544c42 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -108,7 +108,7 @@ def plot_on_axis( shading="auto", ) - if boundary_alpha is not None: + if boundary_alpha is not None and datastore.boundary_mask is not None: # Overlay boundary mask mask_da = datastore.boundary_mask mask_values = mask_da.values @@ -128,7 +128,7 @@ def plot_on_axis( shading="auto", ) - if crop_to_interior: + if crop_to_interior and datastore.boundary_mask is not None: # Calculate extent of interior mask_da = datastore.boundary_mask mask_values = mask_da.values diff --git a/tests/conftest.py b/tests/conftest.py index 47237ed55..c3f6a1445 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,7 +16,7 @@ ) # Local -from .dummy_datastore import DummyDatastore +from .dummy_datastore import DummyDatastore, GlobalDummyDatastore # Disable weights and biases to avoid unnecessary logging # and to avoid having to deal with authentication @@ -103,9 +103,11 @@ def download_meps_example_reduced_dataset(): ), npyfilesmeps=None, dummydata=None, + dummydata_global=None, ) 
DATASTORES[DummyDatastore.SHORT_NAME] = DummyDatastore +DATASTORES[GlobalDummyDatastore.SHORT_NAME] = GlobalDummyDatastore def init_datastore_example(datastore_kind): diff --git a/tests/dummy_datastore.py b/tests/dummy_datastore.py index 3a844d6d9..16446afa5 100644 --- a/tests/dummy_datastore.py +++ b/tests/dummy_datastore.py @@ -758,3 +758,26 @@ def num_grid_points(self) -> int: @property def state_feature_weights_values(self) -> list[float]: return [1.0] + + +class GlobalDummyDatastore(DummyDatastore): + """ + Variant of DummyDatastore that simulates a global domain with no + lateral boundaries. A global setup should not have a boundary mask + at all, so ``boundary_mask`` returns None. + """ + + SHORT_NAME = "dummydata_global" + + @cached_property + def boundary_mask(self) -> None: + """Return None for global domains (no boundary mask). + + A global domain has no lateral boundaries, so no boundary mask + is needed. The model then treats every grid point as interior. + + Returns + ------- + None + """ + return None diff --git a/tests/test_datastores.py b/tests/test_datastores.py index 4ce3875ea..7a9c063b1 100644 --- a/tests/test_datastores.py +++ b/tests/test_datastores.py @@ -225,16 +225,22 @@ def test_get_dataarray(datastore_name): @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) def test_boundary_mask(datastore_name): """Check that the `datastore.boundary_mask` property is implemented and - that the returned object is an xarray DataArray with the correct shape.""" + that the returned object is an xarray DataArray with the correct shape, + or None for global domains.""" datastore = init_datastore_example(datastore_name) da_mask = datastore.boundary_mask + if da_mask is None: + # Global-style datastore: no boundary mask at all + return + assert isinstance(da_mask, xr.DataArray) assert set(da_mask.dims) == {"grid_index"} assert da_mask.dtype == "int" assert set(da_mask.values) == {0, 1} - assert da_mask.sum() > 0 - assert da_mask.sum() < da_mask.size + # If a
mask is present it must have both boundary and interior points + mask_sum = int(da_mask.sum()) + assert 0 < mask_sum < da_mask.size if isinstance(datastore, BaseRegularGridDatastore): grid_shape = datastore.grid_shape_state diff --git a/tests/test_training.py b/tests/test_training.py index 972740695..cc54b82cb 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -189,3 +189,16 @@ def all_gather(self, tensor, sync_grads=False): "all_gather_cat produced incorrectly ordered/combined values " "on multi-device simulation" ) + + +def test_training_global(): + """Test global-style datastore boundary behavior (no boundary mask). + Training with this datastore is already exercised via the parametrized + test_training, so this test only verifies the global boundary mask + property to avoid duplicate expensive training runs.""" + datastore = init_datastore_example("dummydata_global") + + # Verify the global property: boundary mask should be None + assert ( + datastore.boundary_mask is None + ), "GlobalDummyDatastore boundary_mask should be None for global domains"