Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions neural_lam/datastore/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,17 +265,20 @@ def get_dataarray(

@cached_property
@abc.abstractmethod
def boundary_mask(self) -> xr.DataArray:
def boundary_mask(self) -> Optional[xr.DataArray]:
"""
Return the boundary mask for the dataset, with spatial dimensions
stacked. Where the value is 1, the grid point is a boundary point, and
where the value is 0, the grid point is not a boundary point.

For global datastores that have no lateral boundaries, this should
return None.

Returns
-------
xr.DataArray
xr.DataArray or None
The boundary mask for the dataset, with dimensions
`('grid_index',)`.
`('grid_index',)`, or None for global domains.

"""
pass
Expand Down
16 changes: 11 additions & 5 deletions neural_lam/models/ar_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,17 @@ def __init__(
# Instantiate loss function
self.loss = metrics.get_metric(args.loss)

boundary_mask = torch.tensor(
da_boundary_mask.values, dtype=torch.float32
).unsqueeze(
1
) # add feature dim
if da_boundary_mask is not None:
boundary_mask = torch.tensor(
da_boundary_mask.values, dtype=torch.float32
).unsqueeze(
1
) # add feature dim
else:
# Global domain: no boundary points, all grid points are interior
boundary_mask = torch.zeros(
self.num_grid_nodes, 1, dtype=torch.float32
)

self.register_buffer("boundary_mask", boundary_mask, persistent=False)
# Pre-compute interior mask for use in loss function
Expand Down
4 changes: 2 additions & 2 deletions neural_lam/vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def plot_on_axis(
shading="auto",
)

if boundary_alpha is not None:
if boundary_alpha is not None and datastore.boundary_mask is not None:
# Overlay boundary mask
mask_da = datastore.boundary_mask
mask_values = mask_da.values
Expand All @@ -128,7 +128,7 @@ def plot_on_axis(
shading="auto",
)

if crop_to_interior:
if crop_to_interior and datastore.boundary_mask is not None:
# Calculate extent of interior
mask_da = datastore.boundary_mask
mask_values = mask_da.values
Expand Down
4 changes: 3 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)

# Local
from .dummy_datastore import DummyDatastore
from .dummy_datastore import DummyDatastore, GlobalDummyDatastore

# Disable weights and biases to avoid unnecessary logging
# and to avoid having to deal with authentication
Expand Down Expand Up @@ -103,9 +103,11 @@ def download_meps_example_reduced_dataset():
),
npyfilesmeps=None,
dummydata=None,
dummydata_global=None,
)

DATASTORES[DummyDatastore.SHORT_NAME] = DummyDatastore
DATASTORES[GlobalDummyDatastore.SHORT_NAME] = GlobalDummyDatastore


def init_datastore_example(datastore_kind):
Expand Down
23 changes: 23 additions & 0 deletions tests/dummy_datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,3 +466,26 @@ def grid_shape_state(self) -> CartesianGridShape:

n_points_1d = int(np.sqrt(self.num_grid_points))
return CartesianGridShape(x=n_points_1d, y=n_points_1d)


class GlobalDummyDatastore(DummyDatastore):
"""
Variant of DummyDatastore that simulates a global domain with no
lateral boundaries. A global setup should not have a boundary mask
at all, so ``boundary_mask`` returns None.
"""

SHORT_NAME = "dummydata_global"

@cached_property
def boundary_mask(self) -> None:
"""Return None for global domains (no boundary mask).

A global domain has no lateral boundaries, so no boundary mask
is needed. The model will use predictions everywhere.

Returns
-------
None
"""
return None
12 changes: 9 additions & 3 deletions tests/test_datastores.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,16 +222,22 @@ def test_get_dataarray(datastore_name):
@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_boundary_mask(datastore_name):
"""Check that the `datastore.boundary_mask` property is implemented and
that the returned object is an xarray DataArray with the correct shape."""
that the returned object is an xarray DataArray with the correct shape,
or None for global domains."""
datastore = init_datastore_example(datastore_name)
da_mask = datastore.boundary_mask

if da_mask is None:
# Global-style datastore: no boundary mask at all
return

assert isinstance(da_mask, xr.DataArray)
assert set(da_mask.dims) == {"grid_index"}
assert da_mask.dtype == "int"
assert set(da_mask.values) == {0, 1}
assert da_mask.sum() > 0
assert da_mask.sum() < da_mask.size
# If a mask is present it must have both boundary and interior points
mask_sum = int(da_mask.sum())
assert 0 < mask_sum < da_mask.size

if isinstance(datastore, BaseRegularGridDatastore):
grid_shape = datastore.grid_shape_state
Expand Down
13 changes: 13 additions & 0 deletions tests/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,16 @@ def test_training(datastore_name):
def test_training_output_std():
datastore = init_datastore_example("mdp") # Test only with mdp datastore
run_simple_training(datastore, set_output_std=True)


def test_training_global():
"""Test global-style datastore boundary behavior (no boundary mask).
Training with this datastore is already exercised via the parametrized
test_training, so this test only verifies the global boundary mask
property to avoid duplicate expensive training runs."""
datastore = init_datastore_example("dummydata_global")

# Verify the global property: boundary mask should be None
assert datastore.boundary_mask is None, (
"GlobalDummyDatastore boundary_mask should be None for global domains"
)