Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Added support for 4D ensemble tensors in `WeatherDataset.create_dataarray_from_tensor()` with `ensemble_member` coordinate binding for probabilistic Zarr exports [\#520](https://github.com/mllam/neural-lam/issues/520)
- Updated `ARModel._create_dataarray_from_tensor()` to support ensemble/probabilistic (4D) tensor conversion for external verification [\#520](https://github.com/mllam/neural-lam/issues/520)
- Updated `ARModel.unroll_prediction()` to properly handle ensemble dimension stacking for ensemble/probabilistic forecasts [\#520](https://github.com/mllam/neural-lam/issues/520)
- Added test `test_dataset_item_create_dataarray_from_tensor_4d_ensemble()` for 4D ensemble DataArray creation [\#520](https://github.com/mllam/neural-lam/issues/520)

- Add `AGENTS.md` file to the repo to give agents more information about the codebase and the contribution culture. [\#416](https://github.com/mllam/neural-lam/pull/416) @sadamov

- Enable `pin_memory` in DataLoaders when GPU is available for faster async CPU-to-GPU data transfers [\#236](https://github.com/mllam/neural-lam/pull/236) @abhaygoudannavar
Expand Down
95 changes: 76 additions & 19 deletions neural_lam/models/ar_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,23 +170,35 @@ def _create_dataarray_from_tensor(
"""
Create an `xr.DataArray` from a tensor, with the correct dimensions and
coordinates to match the datastore used by the model. This function in
in effect is the inverse of what is returned by
`WeatherDataset.__getitem__`.
effect is the inverse of what is returned by `WeatherDataset.__getitem__`.

Supports both deterministic (3D) and ensemble/probabilistic (4D) forecasts
for Zarr export and external verification.

Parameters
----------
tensor : torch.Tensor
The tensor to convert to a `xr.DataArray` with dimensions [time,
grid_index, feature]. The tensor will be copied to the CPU if it is
not already there.
The tensor to convert to a `xr.DataArray`. Supported shapes:
- 3D: [time, grid_index, feature] for deterministic predictions
- 4D: [ensemble_member, time, grid_index, feature] for
ensemble/probabilistic predictions
The tensor will be copied to the CPU if it is not already there.
time : torch.Tensor
The time index or indices for the data, given as tensor representing
epoch time in nanoseconds. The tensor will be
copied to the CPU memory if they are not already there.
epoch time in nanoseconds. For 3D tensors, can be a single value or
list. For 4D tensors, must be a list matching the time dimension.
The tensor will be copied to the CPU memory if not already there.
split : str
The split of the data, either 'train', 'val', or 'test'
category : str
The category of the data, either 'state' or 'forcing'

Returns
-------
xr.DataArray
DataArray with proper dimensions and coordinates including
'time', 'grid_index', '{category}_feature', and optionally
'ensemble_member' for ensemble predictions. Ready for Zarr export.
"""
# TODO: creating an instance of WeatherDataset here on every call is
# not how this should be done but whether WeatherDataset should be
Expand Down Expand Up @@ -229,11 +241,44 @@ def predict_step(self, prev_state, prev_prev_state, forcing):

def unroll_prediction(self, init_states, forcing_features, true_states):
"""
Roll out prediction taking multiple autoregressive steps with model
init_states: (B, 2, num_grid_nodes, d_f) forcing_features: (B,
pred_steps, num_grid_nodes, d_static_f) true_states: (B, pred_steps,
num_grid_nodes, d_f)
"""
Roll out prediction taking multiple autoregressive steps with model.

Supports both deterministic (4D) and ensemble/probabilistic (5D) tensor
inputs for Zarr export of probabilistic forecasts.

Parameters
----------
init_states : torch.Tensor
Initial states. Shape:
- Deterministic: (B, 2, num_grid_nodes, d_f)
- Ensemble: (B, S, 2, num_grid_nodes, d_f) where S is ensemble members
forcing_features : torch.Tensor
Forcing features. Shape:
- Deterministic: (B, pred_steps, num_grid_nodes, d_forcing_f)
- Ensemble: (B, S, pred_steps, num_grid_nodes, d_forcing_f)
true_states : torch.Tensor
True boundary states. Shape:
- Deterministic: (B, pred_steps, num_grid_nodes, d_f)
- Ensemble: (B, S, pred_steps, num_grid_nodes, d_f)

Returns
-------
prediction : torch.Tensor
Predictions with shape:
- Deterministic: (B, pred_steps, num_grid_nodes, d_f)
- Ensemble: (B, S, pred_steps, num_grid_nodes, d_f)
pred_std : torch.Tensor
Prediction std or per-variable std depending on output_std setting.
"""
# Handle ensemble dimension if present (5D vs 4D)
is_ensemble = init_states.ndim == 5
if is_ensemble:
B, S, _, N, F = init_states.shape
# Reshape to (B*S, 2, N, F) for processing
init_states = init_states.reshape(B * S, *init_states.shape[2:])
forcing_features = forcing_features.reshape(B * S, *forcing_features.shape[2:])
true_states = true_states.reshape(B * S, *true_states.shape[2:])

prev_prev_state = init_states[:, 0]
prev_state = init_states[:, 1]
prediction_list = []
Expand All @@ -247,7 +292,7 @@ def unroll_prediction(self, init_states, forcing_features, true_states):
pred_state, pred_std = self.predict_step(
prev_state, prev_prev_state, forcing
)
# state: (B, num_grid_nodes, d_f) pred_std: (B, num_grid_nodes,
# state: (B*S, num_grid_nodes, d_f) pred_std: (B*S, num_grid_nodes,
# d_f) or None

# Overwrite border with true state
Expand All @@ -266,13 +311,25 @@ def unroll_prediction(self, init_states, forcing_features, true_states):

prediction = torch.stack(
prediction_list, dim=1
) # (B, pred_steps, num_grid_nodes, d_f)
if self.output_std:
pred_std = torch.stack(
pred_std_list, dim=1
) # (B, pred_steps, num_grid_nodes, d_f)
) # (B*S, pred_steps, num_grid_nodes, d_f)

# Reshape back to ensemble form if needed
if is_ensemble:
prediction = prediction.reshape(B, S, *prediction.shape[1:])
if self.output_std:
pred_std = torch.stack(
pred_std_list, dim=1
) # (B*S, pred_steps, num_grid_nodes, d_f)
pred_std = pred_std.reshape(B, S, *pred_std.shape[1:])
else:
pred_std = self.per_var_std # (d_f,)
else:
pred_std = self.per_var_std # (d_f,)
if self.output_std:
pred_std = torch.stack(
pred_std_list, dim=1
) # (B, pred_steps, num_grid_nodes, d_f)
else:
pred_std = self.per_var_std # (d_f,)

return prediction, pred_std

Expand Down
52 changes: 41 additions & 11 deletions neural_lam/weather_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,35 +553,47 @@ def create_dataarray_from_tensor(
and number of times provided and add the x/y coordinates from the
datastore.

The number if times provided is expected to match the shape of the
Supports both deterministic (2D/3D) and ensemble/probabilistic (4D)
tensors.

The number of times provided is expected to match the shape of the
tensor. For a 2D tensor, the dimensions are assumed to be (grid_index,
{category}_feature) and only a single time should be provided. For a 3D
tensor, the dimensions are assumed to be (time, grid_index,
{category}_feature) and a list of times should be provided.
{category}_feature) and a list of times should be provided. For a 4D
tensor, the dimensions are assumed to be (ensemble_member, time,
grid_index, {category}_feature) and a list of times should be provided.

Parameters
----------
tensor : torch.Tensor
The tensor to construct the DataArray from, this assumed to have
the same dimension ordering as returned by the __getitem__ method
(i.e. time, grid_index, {category}_feature). The tensor will be
copied to the CPU before constructing the DataArray.
The tensor to construct the DataArray from. Supported shapes:
- 2D: (grid_index, {category}_feature)
- 3D: (time, grid_index, {category}_feature) for deterministic
- 4D: (ensemble_member, time, grid_index, {category}_feature) for
ensemble/probabilistic predictions
The tensor will be copied to the CPU before constructing the
DataArray.
time : datetime.datetime or list[datetime.datetime]
The time or times of the tensor.
The time or times of the tensor. For 2D tensors, a single value.
For 3D/4D tensors, a list of times matching the time dimension.
category : str
The category of the tensor, either "state", "forcing" or "static".

Returns
-------
da : xr.DataArray
The constructed DataArray.
The constructed DataArray with proper dimensions and coordinates
including 'time', 'grid_index', '{category}_feature', and
optionally 'ensemble_member' for ensemble predictions.
"""

def _is_listlike(obj):
# match list, tuple, numpy array
return hasattr(obj, "__iter__") and not isinstance(obj, str)

add_time_as_dim = False
add_ensemble_as_dim = False
if len(tensor.shape) == 2:
dims = ["grid_index", f"{category}_feature"]
if _is_listlike(time):
Expand All @@ -596,12 +608,22 @@ def _is_listlike(obj):
if not _is_listlike(time):
raise ValueError(
"Expected a list of times for a 3D tensor with assumed "
"dimensions (time, grid_index, {category}_feature), but "
"got a single time"
"dimensions (time, grid_index, {category}_feature), but got "
"a single time"
)
elif len(tensor.shape) == 4:
add_time_as_dim = True
add_ensemble_as_dim = True
dims = ["ensemble_member", "time", "grid_index", f"{category}_feature"]
if not _is_listlike(time):
raise ValueError(
"Expected a list of times for a 4D tensor with assumed "
"dimensions (ensemble_member, time, grid_index, "
f"{category}_feature), but got a single time"
)
else:
raise ValueError(
"Expected tensor to have 2 or 3 dimensions, but got "
"Expected tensor to have 2, 3 or 4 dimensions, but got "
f"{len(tensor.shape)}"
)

Expand All @@ -615,6 +637,14 @@ def _is_listlike(obj):
}
if add_time_as_dim:
coords["time"] = time
if add_ensemble_as_dim:
if (
"ensemble_member" in da_datastore_state.coords
and da_datastore_state.ensemble_member.size == tensor.shape[0]
):
coords["ensemble_member"] = da_datastore_state.ensemble_member
else:
coords["ensemble_member"] = np.arange(tensor.shape[0])

da = xr.DataArray(
tensor.cpu().numpy(),
Expand Down
66 changes: 66 additions & 0 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,72 @@ def test_dataset_item_create_dataarray_from_tensor(datastore_name):
)


def test_dataset_item_create_dataarray_from_tensor_4d_ensemble():
    """Test creating DataArray from 4D ensemble tensor (S, T, N, F).

    Checks that a tensor of ensemble/probabilistic predictions is turned
    into an xarray DataArray carrying an ``ensemble_member`` coordinate in
    addition to the usual time/grid/feature coordinates.
    """
    num_timesteps = 15
    num_members = 3
    datastore = DummyDatastore(n_timesteps=num_timesteps)

    n_pred_steps = 4
    dataset = WeatherDataset(
        datastore=datastore,
        split="train",
        ar_steps=n_pred_steps,
        num_past_forcing_steps=1,
        num_future_forcing_steps=1,
    )

    n_grid = datastore.num_grid_points
    n_state_vars = datastore.get_num_data_vars(category="state")

    # Synthetic ensemble predictions with dims (S, T, N, F)
    ensemble_tensor = torch.randn(
        num_members, n_pred_steps, n_grid, n_state_vars
    )
    forecast_times = [
        np.datetime64(f"2020-01-{i+1:02d}") for i in range(n_pred_steps)
    ]

    da = dataset.create_dataarray_from_tensor(
        tensor=ensemble_tensor, time=forecast_times, category="state"
    )

    # Dimension ordering and sizes match the input tensor
    assert da.shape == (num_members, n_pred_steps, n_grid, n_state_vars)
    assert da.dims == (
        "ensemble_member",
        "time",
        "grid_index",
        "state_feature",
    )

    # ensemble_member coordinate present with values 0..S-1
    assert "ensemble_member" in da.coords
    assert da.ensemble_member.size == num_members
    np.testing.assert_array_equal(
        da.ensemble_member.values, np.arange(num_members)
    )

    # time coordinate round-trips the provided timestamps
    assert "time" in da.coords
    assert len(da.time) == n_pred_steps
    np.testing.assert_array_equal(
        da.time.values, np.array(forecast_times, dtype="datetime64[ns]")
    )

    # spatial/feature coordinates were mapped in from the datastore
    assert "grid_index" in da.coords
    assert "state_feature" in da.coords

    # values are an exact copy of the input tensor
    np.testing.assert_allclose(
        da.values, ensemble_tensor.cpu().numpy(), rtol=1e-6
    )


@pytest.mark.parametrize("split", ["train", "val", "test"])
@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_single_batch(datastore_name, split):
Expand Down