diff --git a/CHANGELOG.md b/CHANGELOG.md
index fe66d2136..46a9051b2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Added
 
+- Added support for 4D ensemble tensors in `WeatherDataset.create_dataarray_from_tensor()` with `ensemble_member` coordinate binding for probabilistic Zarr exports [\#520](https://github.com/mllam/neural-lam/issues/520)
+- Updated `ARModel._create_dataarray_from_tensor()` to support ensemble/probabilistic (4D) tensor conversion for external verification [\#520](https://github.com/mllam/neural-lam/issues/520)
+- Updated `ARModel.unroll_prediction()` to properly handle ensemble dimension stacking for ensemble/probabilistic forecasts [\#520](https://github.com/mllam/neural-lam/issues/520)
+- Added test `test_dataset_item_create_dataarray_from_tensor_4d_ensemble()` for 4D ensemble DataArray creation [\#520](https://github.com/mllam/neural-lam/issues/520)
+
 - Add `AGENTS.md` file to the repo to give agents more information about the codebase and the contribution culture.[\#416](https://github.com/mllam/neural-lam/pull/416) @sadamov
 
 - Enable `pin_memory` in DataLoaders when GPU is available for faster async CPU-to-GPU data transfers [\#236](https://github.com/mllam/neural-lam/pull/236) @abhaygoudannavar
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index a411a3afc..b6be5c6fa 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -170,23 +170,35 @@ def _create_dataarray_from_tensor(
         """
         Create an `xr.DataArray` from a tensor, with the correct dimensions
         and coordinates to match the datastore used by the model. This function in
-        in effect is the inverse of what is returned by
-        `WeatherDataset.__getitem__`.
+        effect is the inverse of what is returned by `WeatherDataset.__getitem__`.
+
+        Supports both deterministic (3D) and ensemble/probabilistic (4D) forecasts
+        for Zarr export and external verification.
 
         Parameters
         ----------
         tensor : torch.Tensor
-            The tensor to convert to a `xr.DataArray` with dimensions [time,
-            grid_index, feature]. The tensor will be copied to the CPU if it is
-            not already there.
+            The tensor to convert to a `xr.DataArray`. Supported shapes:
+            - 3D: [time, grid_index, feature] for deterministic predictions
+            - 4D: [ensemble_member, time, grid_index, feature] for
+              ensemble/probabilistic predictions
+            The tensor will be copied to the CPU if it is not already there.
         time : torch.Tensor
             The time index or indices for the data, given as tensor representing
-            epoch time in nanoseconds. The tensor will be
-            copied to the CPU memory if they are not already there.
+            epoch time in nanoseconds. For 3D tensors, can be a single value or
+            list. For 4D tensors, must be a list matching the time dimension.
+            The tensor will be copied to the CPU memory if not already there.
         split : str
             The split of the data, either 'train', 'val', or 'test'
         category : str
             The category of the data, either 'state' or 'forcing'
+
+        Returns
+        -------
+        xr.DataArray
+            DataArray with proper dimensions and coordinates including
+            'time', 'grid_index', '{category}_feature', and optionally
+            'ensemble_member' for ensemble predictions. Ready for Zarr export.
         """
         # TODO: creating an instance of WeatherDataset here on every call is
         # not how this should be done but whether WeatherDataset should be
@@ -229,11 +241,45 @@ def predict_step(self, prev_state, prev_prev_state, forcing):
 
     def unroll_prediction(self, init_states, forcing_features, true_states):
         """
-        Roll out prediction taking multiple autoregressive steps with model
-        init_states: (B, 2, num_grid_nodes, d_f) forcing_features: (B,
-        pred_steps, num_grid_nodes, d_static_f) true_states: (B, pred_steps,
-        num_grid_nodes, d_f)
-        """
+        Roll out prediction taking multiple autoregressive steps with model.
+
+        Supports both deterministic (4D) and ensemble/probabilistic (5D) tensor
+        inputs for Zarr export of probabilistic forecasts.
+
+        Parameters
+        ----------
+        init_states : torch.Tensor
+            Initial states. Shape:
+            - Deterministic: (B, 2, num_grid_nodes, d_f)
+            - Ensemble: (B, S, 2, num_grid_nodes, d_f) where S is ensemble members
+        forcing_features : torch.Tensor
+            Forcing features. Shape:
+            - Deterministic: (B, pred_steps, num_grid_nodes, d_forcing_f)
+            - Ensemble: (B, S, pred_steps, num_grid_nodes, d_forcing_f)
+        true_states : torch.Tensor
+            True boundary states. Shape:
+            - Deterministic: (B, pred_steps, num_grid_nodes, d_f)
+            - Ensemble: (B, S, pred_steps, num_grid_nodes, d_f)
+
+        Returns
+        -------
+        prediction : torch.Tensor
+            Predictions with shape:
+            - Deterministic: (B, pred_steps, num_grid_nodes, d_f)
+            - Ensemble: (B, S, pred_steps, num_grid_nodes, d_f)
+        pred_std : torch.Tensor
+            If `self.output_std`, the stacked per-step std with the same shape
+            as `prediction`; otherwise `self.per_var_std` of shape (d_f,).
+        """
+        # Handle ensemble dimension if present (5D vs 4D)
+        is_ensemble = init_states.ndim == 5
+        if is_ensemble:
+            B, S = init_states.shape[:2]
+            # Fold the ensemble axis into the batch axis: (B*S, ...)
+            init_states = init_states.reshape(B * S, *init_states.shape[2:])
+            forcing_features = forcing_features.reshape(B * S, *forcing_features.shape[2:])
+            true_states = true_states.reshape(B * S, *true_states.shape[2:])
+
         prev_prev_state = init_states[:, 0]
         prev_state = init_states[:, 1]
         prediction_list = []
@@ -247,7 +293,7 @@ def unroll_prediction(self, init_states, forcing_features, true_states):
             pred_state, pred_std = self.predict_step(
                 prev_state, prev_prev_state, forcing
             )
-            # state: (B, num_grid_nodes, d_f) pred_std: (B, num_grid_nodes,
+            # state: (B*S, num_grid_nodes, d_f) pred_std: (B*S, num_grid_nodes,
             # d_f) or None
 
             # Overwrite border with true state
@@ -266,13 +312,19 @@ def unroll_prediction(self, init_states, forcing_features, true_states):
 
         prediction = torch.stack(
             prediction_list, dim=1
-        )  # (B, pred_steps, num_grid_nodes, d_f)
+        )  # (B*S, pred_steps, num_grid_nodes, d_f)
         if self.output_std:
             pred_std = torch.stack(
                 pred_std_list, dim=1
-            )  # (B, pred_steps, num_grid_nodes, d_f)
+            )  # (B*S, pred_steps, num_grid_nodes, d_f)
         else:
             pred_std = self.per_var_std  # (d_f,)
+
+        # Restore the ensemble axis when the inputs carried one
+        if is_ensemble:
+            prediction = prediction.reshape(B, S, *prediction.shape[1:])
+            if self.output_std:
+                pred_std = pred_std.reshape(B, S, *pred_std.shape[1:])
 
         return prediction, pred_std
 
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index 5547fdd4e..ba47d7894 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -553,28 +553,39 @@ def create_dataarray_from_tensor(
         and number of times provided and add the x/y coordinates from the
         datastore.
 
-        The number if times provided is expected to match the shape of the
+        Supports both deterministic (2D/3D) and ensemble/probabilistic (4D)
+        tensors.
+
+        The number of times provided is expected to match the shape of the
         tensor. For a 2D tensor, the dimensions are assumed to be (grid_index,
         {category}_feature) and only a single time should be provided. For a 3D
         tensor, the dimensions are assumed to be (time, grid_index,
-        {category}_feature) and a list of times should be provided.
+        {category}_feature) and a list of times should be provided. For a 4D
+        tensor, the dimensions are assumed to be (ensemble_member, time,
+        grid_index, {category}_feature) and a list of times should be provided.
 
         Parameters
         ----------
         tensor : torch.Tensor
-            The tensor to construct the DataArray from, this assumed to have
-            the same dimension ordering as returned by the __getitem__ method
-            (i.e. time, grid_index, {category}_feature). The tensor will be
-            copied to the CPU before constructing the DataArray.
+            The tensor to construct the DataArray from. Supported shapes:
+            - 2D: (grid_index, {category}_feature)
+            - 3D: (time, grid_index, {category}_feature) for deterministic
+            - 4D: (ensemble_member, time, grid_index, {category}_feature) for
+              ensemble/probabilistic predictions
+            The tensor will be copied to the CPU before constructing the
+            DataArray.
         time : datetime.datetime or list[datetime.datetime]
-            The time or times of the tensor.
+            The time or times of the tensor. For 2D tensors, a single value.
+            For 3D/4D tensors, a list of times matching the time dimension.
         category : str
             The category of the tensor, either "state", "forcing" or "static".
 
         Returns
         -------
         da : xr.DataArray
-            The constructed DataArray.
+            The constructed DataArray with proper dimensions and coordinates
+            including 'time', 'grid_index', '{category}_feature', and
+            optionally 'ensemble_member' for ensemble predictions.
         """
 
         def _is_listlike(obj):
@@ -582,6 +593,7 @@ def _is_listlike(obj):
             return hasattr(obj, "__iter__") and not isinstance(obj, str)
 
         add_time_as_dim = False
+        add_ensemble_as_dim = False
         if len(tensor.shape) == 2:
             dims = ["grid_index", f"{category}_feature"]
             if _is_listlike(time):
@@ -596,12 +608,22 @@ def _is_listlike(obj):
             if not _is_listlike(time):
                 raise ValueError(
                     "Expected a list of times for a 3D tensor with assumed "
-                    "dimensions (time, grid_index, {category}_feature), but "
+                    f"dimensions (time, grid_index, {category}_feature), but "
                     "got a single time"
+                )
+        elif len(tensor.shape) == 4:
+            add_time_as_dim = True
+            add_ensemble_as_dim = True
+            dims = ["ensemble_member", "time", "grid_index", f"{category}_feature"]
+            if not _is_listlike(time):
+                raise ValueError(
+                    "Expected a list of times for a 4D tensor with assumed "
+                    "dimensions (ensemble_member, time, grid_index, "
+                    f"{category}_feature), but got a single time"
                 )
         else:
             raise ValueError(
-                "Expected tensor to have 2 or 3 dimensions, but got "
+                "Expected tensor to have 2, 3 or 4 dimensions, but got "
                 f"{len(tensor.shape)}"
             )
 
@@ -615,6 +637,14 @@ def _is_listlike(obj):
         }
         if add_time_as_dim:
             coords["time"] = time
+        if add_ensemble_as_dim:
+            if (
+                "ensemble_member" in da_datastore_state.coords
+                and da_datastore_state.ensemble_member.size == tensor.shape[0]
+            ):
+                coords["ensemble_member"] = da_datastore_state.ensemble_member
+            else:
+                coords["ensemble_member"] = np.arange(tensor.shape[0])
 
         da = xr.DataArray(
             tensor.cpu().numpy(),
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index dd863b657..aa67116d9 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -156,6 +156,72 @@ def test_dataset_item_create_dataarray_from_tensor(datastore_name):
     )
 
 
+def test_dataset_item_create_dataarray_from_tensor_4d_ensemble():
+    """Test creating DataArray from 4D ensemble tensor (S, T, N, F).
+
+    Verify that 4D tensors representing ensemble/probabilistic predictions
+    are properly converted to xarray DataArrays with ensemble_member coordinate.
+    """
+    n_timesteps = 15
+    n_ensemble_members = 3
+    datastore = DummyDatastore(n_timesteps=n_timesteps)
+
+    N_pred_steps = 4
+    num_past_forcing_steps = 1
+    num_future_forcing_steps = 1
+    dataset = WeatherDataset(
+        datastore=datastore,
+        split="train",
+        ar_steps=N_pred_steps,
+        num_past_forcing_steps=num_past_forcing_steps,
+        num_future_forcing_steps=num_future_forcing_steps,
+    )
+
+    N_grid = datastore.num_grid_points
+    N_features = datastore.get_num_data_vars(category="state")
+
+    # Create synthetic 4D ensemble tensor: (S, T, N, F)
+    tensor_4d = torch.randn(n_ensemble_members, N_pred_steps, N_grid, N_features)
+    times = [np.datetime64(f"2020-01-{i+1:02d}") for i in range(N_pred_steps)]
+
+    da_ensemble = dataset.create_dataarray_from_tensor(
+        tensor=tensor_4d,
+        time=times,
+        category="state"
+    )
+
+    # Verify shape and dimensions
+    assert da_ensemble.shape == (n_ensemble_members, N_pred_steps, N_grid, N_features)
+    assert da_ensemble.dims == ("ensemble_member", "time", "grid_index", "state_feature")
+
+    # Verify ensemble_member coordinate exists and has correct size
+    assert "ensemble_member" in da_ensemble.coords
+    assert da_ensemble.ensemble_member.size == n_ensemble_members
+    np.testing.assert_array_equal(
+        da_ensemble.ensemble_member.values,
+        np.arange(n_ensemble_members)
+    )
+
+    # Verify time coordinate matches input
+    assert "time" in da_ensemble.coords
+    assert len(da_ensemble.time) == N_pred_steps
+    np.testing.assert_array_equal(
+        da_ensemble.time.values,
+        np.array(times, dtype="datetime64[ns]")
+    )
+
+    # Verify grid_index and state_feature are properly mapped
+    assert "grid_index" in da_ensemble.coords
+    assert "state_feature" in da_ensemble.coords
+
+    # Verify values match input tensor
+    np.testing.assert_allclose(
+        da_ensemble.values,
+        tensor_4d.cpu().numpy(),
+        rtol=1e-6
+    )
+
+
 @pytest.mark.parametrize("split", ["train", "val", "test"])
 @pytest.mark.parametrize("datastore_name", DATASTORES.keys())
 def test_single_batch(datastore_name, split):