diff --git a/CHANGELOG.md b/CHANGELOG.md index fe66d2136..65ba93161 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,11 +13,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Enable `pin_memory` in DataLoaders when GPU is available for faster async CPU-to-GPU data transfers [\#236](https://github.com/mllam/neural-lam/pull/236) @abhaygoudannavar +- Add `GlobalDummyDatastore` test variant for global forecasting support with all-zero boundary mask [\#445](https://github.com/mllam/neural-lam/issues/445) + ### Changed +- Refactor model class hierarchy: Extract `NormalizationManager`, `MetricTracker`, and `ModelVisualizer` from `ARModel` to improve separation of concerns and code maintainability ([#49](https://github.com/mllam/neural-lam/issues/49)) - Change the default ensemble-loading behavior in `WeatherDataset` / `WeatherDataModule` to use all ensemble members as independent samples for ensemble datastores (with matching ensemble-member selection for forcing when available); single-member behavior now requires explicitly opting in via `--load_single_member` [\#332](https://github.com/mllam/neural-lam/pull/332) @kshirajahere - Refactor graph loading: move zero-indexing out of the model and update plotting to prepare using the research-branch graph I/O [\#184](https://github.com/mllam/neural-lam/pull/184) @zweihuehner - Replace `print()`-based `rank_zero_print` with `loguru` `logger.info()` for structured log-level control ([#33](https://github.com/mllam/neural-lam/issues/33)) +- Relax `test_boundary_mask` to accept all-zero boundary masks for global domains in support of global forecasting capabilities [\#445](https://github.com/mllam/neural-lam/issues/445) ### Fixed diff --git a/neural_lam/models/__init__.py b/neural_lam/models/__init__.py index f65387ab6..f8148af57 100644 --- a/neural_lam/models/__init__.py +++ b/neural_lam/models/__init__.py @@ -4,3 +4,6 @@ from .graph_lam import GraphLAM from .hi_lam import HiLAM from 
.hi_lam_parallel import HiLAMParallel +from .metric_tracker import MetricTracker +from .model_visualizer import ModelVisualizer +from .normalization_manager import NormalizationManager diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index a411a3afc..3c02ba9d8 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -1,7 +1,7 @@ # Standard library import os import warnings -from typing import Any, Dict, List +from typing import Any # Third-party import matplotlib.pyplot as plt @@ -19,6 +19,9 @@ from ..datastore import BaseDatastore from ..loss_weighting import get_state_feature_weighting from ..weather_dataset import WeatherDataset +from .metric_tracker import MetricTracker +from .model_visualizer import ModelVisualizer +from .normalization_manager import NormalizationManager class ARModel(pl.LightningModule): @@ -42,47 +45,41 @@ def __init__( self._datastore = datastore num_state_vars = datastore.get_num_data_vars(category="state") num_forcing_vars = datastore.get_num_data_vars(category="forcing") - # Load static features standardized - da_static_features = datastore.get_dataarray( - category="static", split=None, standardize=True + + # Initialize helper modules + self.normalization_manager = NormalizationManager(datastore) + self.metric_tracker = MetricTracker(output_std=bool(args.output_std)) + + # Get time step info for visualizer + self.time_step_int, self.time_step_unit = get_integer_time( + self._datastore.step_length ) - if da_static_features is None: - raise ValueError("Static features are required for ARModel") - da_state_stats = datastore.get_standardization_dataarray( - category="state" + self.visualizer = ModelVisualizer( + datastore=datastore, + args=args, + time_step_int=self.time_step_int, + time_step_unit=self.time_step_unit, ) - da_boundary_mask = datastore.boundary_mask - num_past_forcing_steps = args.num_past_forcing_steps - num_future_forcing_steps = args.num_future_forcing_steps - # Load 
static features for grid/data, + # Get references to normalization buffers for compatibility + # These are now managed by normalization_manager but exposed as properties + # for backward compatibility with existing code + for key in ["state_mean", "state_std", "diff_mean", "diff_std"]: + self.register_buffer( + key, + getattr(self.normalization_manager, key), + persistent=False, + ) + self.register_buffer( "grid_static_features", - torch.tensor(da_static_features.values, dtype=torch.float32), + self.normalization_manager.grid_static_features, persistent=False, ) - state_stats = { - "state_mean": torch.tensor( - da_state_stats.state_mean.values, dtype=torch.float32 - ), - "state_std": torch.tensor( - da_state_stats.state_std.values, dtype=torch.float32 - ), - # Note that the one-step-diff stats (diff_mean and diff_std) are - # for differences computed on standardized data - "diff_mean": torch.tensor( - da_state_stats.state_diff_mean_standardized.values, - dtype=torch.float32, - ), - "diff_std": torch.tensor( - da_state_stats.state_diff_std_standardized.values, - dtype=torch.float32, - ), - } - - for key, val in state_stats.items(): - self.register_buffer(key, val, persistent=False) + da_boundary_mask = datastore.boundary_mask + num_past_forcing_steps = args.num_past_forcing_steps + num_future_forcing_steps = args.num_future_forcing_steps state_feature_weights = get_state_feature_weighting( config=config, datastore=datastore @@ -136,16 +133,6 @@ def __init__( "interior_mask", 1.0 - self.boundary_mask, persistent=False ) # (num_grid_nodes, 1), 1 for non-border - self.val_metrics: Dict[str, List] = { - "mse": [], - } - self.test_metrics: Dict[str, List] = { - "mse": [], - "mae": [], - } - if self.output_std: - self.test_metrics["output_std"] = [] # Treat as metric - # For making restoring of optimizer state optional self.restore_opt = args.restore_opt @@ -153,13 +140,6 @@ def __init__( self.n_example_pred = args.n_example_pred self.plotted_examples = 0 - # For storing 
spatial loss maps during evaluation - self.spatial_loss_maps: List[Any] = [] - - self.time_step_int, self.time_step_unit = get_integer_time( - self._datastore.step_length - ) - def _create_dataarray_from_tensor( self, tensor: torch.Tensor, @@ -367,7 +347,7 @@ def validation_step(self, batch, batch_idx): batch_size=batch[0].shape[0], ) - # Store MSEs + # Store MSEs using metric tracker entry_mses = metrics.mse( prediction, target, @@ -375,18 +355,29 @@ def validation_step(self, batch, batch_idx): mask=self.interior_mask_bool, sum_vars=False, ) # (B, pred_steps, d_f) - self.val_metrics["mse"].append(entry_mses) + self.metric_tracker.add_val_metric("mse", entry_mses) def on_validation_epoch_end(self): """ Compute val metrics at the end of val epoch """ - # Create error maps for all test metrics - self.aggregate_and_plot_metrics(self.val_metrics, prefix="val") + # Get metrics from tracker + val_metrics = self.metric_tracker.get_val_metrics() + + # Gather and plot metrics + for metric_name, metric_val_list in val_metrics.items(): + metric_tensor = self.all_gather_cat(torch.cat(metric_val_list, dim=0)) + # Pass the gathered tensor back for visualization + self.visualizer.aggregate_and_plot_metrics( + {metric_name: [metric_tensor]}, + prefix="val", + state_std=self.state_std, + logger=self.logger, + trainer=self.trainer, + ) - # Clear lists with validation metrics values - for metric_list in self.val_metrics.values(): - metric_list.clear() + # Clear metrics + self.metric_tracker.clear_val_metrics() # pylint: disable-next=unused-argument def test_step(self, batch, batch_idx): @@ -421,9 +412,7 @@ def test_step(self, batch, batch_idx): batch_size=batch[0].shape[0], ) - # Compute all evaluation metrics for error maps Note: explicitly list - # metrics here, as test_metrics can contain additional ones, computed - # differently, but that should be aggregated on_test_epoch_end + # Compute all evaluation metrics for error maps for metric_name in ("mse", "mae"): metric_func = 
metrics.get_metric(metric_name) batch_metric_vals = metric_func( @@ -433,14 +422,14 @@ def test_step(self, batch, batch_idx): mask=self.interior_mask_bool, sum_vars=False, ) # (B, pred_steps, d_f) - self.test_metrics[metric_name].append(batch_metric_vals) + self.metric_tracker.add_test_metric(metric_name, batch_metric_vals) if self.output_std: # Store output std. per variable, spatially averaged mean_pred_std = torch.mean( pred_std[..., self.interior_mask_bool, :], dim=-2 ) # (B, pred_steps, d_f) - self.test_metrics["output_std"].append(mean_pred_std) + self.metric_tracker.add_test_metric("output_std", mean_pred_std) # Save per-sample spatial loss for specific times spatial_loss = self.loss( @@ -449,7 +438,7 @@ def test_step(self, batch, batch_idx): log_spatial_losses = spatial_loss[ :, [step - 1 for step in self.args.val_steps_to_log] ] - self.spatial_loss_maps.append(log_spatial_losses) + self.metric_tracker.add_spatial_loss_map(log_spatial_losses) # (B, N_log, num_grid_nodes) # Plot example predictions (on rank 0 only) @@ -598,160 +587,41 @@ def plot_examples(self, batch, n_examples, split, prediction=None): ), ) - def create_metric_log_dict(self, metric_tensor, prefix, metric_name): - """ - Put together a dict with everything to log for one metric. Also saves - plots as pdf and csv if using test prefix. 
- - metric_tensor: (pred_steps, d_f), metric values per time and variable - prefix: string, prefix to use for logging metric_name: string, name of - the metric - - Return: log_dict: dict with everything to log for given metric - """ - log_dict = {} - metric_fig = vis.plot_error_map( - errors=metric_tensor, - datastore=self._datastore, - ) - full_log_name = f"{prefix}_{metric_name}" - log_dict[full_log_name] = metric_fig - - if prefix == "test": - # Save pdf - metric_fig.savefig( - os.path.join(self.logger.save_dir, f"{full_log_name}.pdf") - ) - # Save errors also as csv - np.savetxt( - os.path.join(self.logger.save_dir, f"{full_log_name}.csv"), - metric_tensor.cpu().numpy(), - delimiter=",", - ) - - # Check if metrics are watched, log exact values for specific vars - var_names = self._datastore.get_vars_names(category="state") - if full_log_name in self.args.metrics_watch: - for var_i, timesteps in self.args.var_leads_metrics_watch.items(): - var_name = var_names[var_i] - for step in timesteps: - key = f"{full_log_name}_{var_name}_step_{step}" - log_dict[key] = metric_tensor[step - 1, var_i] - - return log_dict - - def aggregate_and_plot_metrics(self, metrics_dict, prefix): - """ - Aggregate and create error map plots for all metrics in metrics_dict - - metrics_dict: dictionary with metric_names and list of tensors - with step-evals. 
- prefix: string, prefix to use for logging - """ - log_dict = {} - for metric_name, metric_val_list in metrics_dict.items(): - metric_tensor = self.all_gather_cat( - torch.cat(metric_val_list, dim=0) - ) # (N_eval, pred_steps, d_f) - - if self.trainer.is_global_zero: - metric_tensor_averaged = torch.mean(metric_tensor, dim=0) - # (pred_steps, d_f) - - # Take square root after all averaging to change MSE to RMSE - if "mse" in metric_name: - metric_tensor_averaged = torch.sqrt(metric_tensor_averaged) - metric_name = metric_name.replace("mse", "rmse") - - # NOTE: we here assume rescaling for all metrics is linear - metric_rescaled = metric_tensor_averaged * self.state_std - # (pred_steps, d_f) - log_dict.update( - self.create_metric_log_dict( - metric_rescaled, prefix, metric_name - ) - ) - - # Ensure that log_dict has structure for - # logging as dict(str, plt.Figure) - assert all( - isinstance(key, str) and isinstance(value, plt.Figure) - for key, value in log_dict.items() - ) - - if self.trainer.is_global_zero and not self.trainer.sanity_checking: - - current_epoch = self.trainer.current_epoch - - for key, figure in log_dict.items(): - # For other loggers than wandb, add epoch to key. - # Wandb can log multiple images to the same key, while other - # loggers, such as MLFlow need unique keys for each image. - if not isinstance(self.logger, pl.loggers.WandbLogger): - key = f"{key}-{current_epoch}" - - if hasattr(self.logger, "log_image"): - self.logger.log_image(key=key, images=[figure]) - - plt.close("all") # Close all figs - def on_test_epoch_end(self): """ Compute test metrics and make plots at the end of test epoch. Will gather stored tensors and perform plotting and logging on rank 0. 
""" - # Create error maps for all test metrics - self.aggregate_and_plot_metrics(self.test_metrics, prefix="test") + # Get metrics from tracker + test_metrics = self.metric_tracker.get_test_metrics() + + # Gather and plot metrics + for metric_name, metric_val_list in test_metrics.items(): + metric_tensor = self.all_gather_cat(torch.cat(metric_val_list, dim=0)) + # Pass the gathered tensor back for visualization + self.visualizer.aggregate_and_plot_metrics( + {metric_name: [metric_tensor]}, + prefix="test", + state_std=self.state_std, + logger=self.logger, + trainer=self.trainer, + ) # Plot spatial loss maps + spatial_loss_maps = self.metric_tracker.get_spatial_loss_maps() spatial_loss_tensor = self.all_gather_cat( - torch.cat(self.spatial_loss_maps, dim=0) + torch.cat(spatial_loss_maps, dim=0) ) # (N_test, N_log, num_grid_nodes) - if self.trainer.is_global_zero: - mean_spatial_loss = torch.mean( - spatial_loss_tensor, dim=0 - ) # (N_log, num_grid_nodes) - - loss_map_figs = [ - vis.plot_spatial_error( - error=loss_map, - datastore=self._datastore, - title=f"Test loss, t={t_i} " - f"({(self.time_step_int * t_i)} {self.time_step_unit})", - ) - for t_i, loss_map in zip( - self.args.val_steps_to_log, mean_spatial_loss - ) - ] - - # log all to same key, sequentially - for i, fig in enumerate(loss_map_figs): - key = "test_loss" - if not isinstance(self.logger, pl.loggers.WandbLogger): - key = f"{key}_{i}" - if hasattr(self.logger, "log_image"): - self.logger.log_image(key=key, images=[fig]) - - # also make without title and save as pdf - pdf_loss_map_figs = [ - vis.plot_spatial_error( - error=loss_map, datastore=self._datastore - ) - for loss_map in mean_spatial_loss - ] - pdf_loss_maps_dir = os.path.join( - self.logger.save_dir, "spatial_loss_maps" - ) - os.makedirs(pdf_loss_maps_dir, exist_ok=True) - for t_i, fig in zip(self.args.val_steps_to_log, pdf_loss_map_figs): - fig.savefig(os.path.join(pdf_loss_maps_dir, f"loss_t{t_i}.pdf")) - # save mean spatial loss as .pt 
file also - torch.save( - mean_spatial_loss.cpu(), - os.path.join(self.logger.save_dir, "mean_spatial_loss.pt"), - ) - self.spatial_loss_maps.clear() + self.visualizer.plot_spatial_loss( + spatial_loss_tensor=spatial_loss_tensor, + logger=self.logger, + trainer=self.trainer, + ) + + # Clear metrics + self.metric_tracker.clear_test_metrics() + self.metric_tracker.clear_spatial_loss_maps() def on_load_checkpoint(self, checkpoint): """ diff --git a/neural_lam/models/ar_model_backup.py b/neural_lam/models/ar_model_backup.py new file mode 100644 index 000000000..a411a3afc --- /dev/null +++ b/neural_lam/models/ar_model_backup.py @@ -0,0 +1,779 @@ +# Standard library +import os +import warnings +from typing import Any, Dict, List + +# Third-party +import matplotlib.pyplot as plt +import numpy as np +import pytorch_lightning as pl +import torch +import xarray as xr + +# First-party +from neural_lam.utils import get_integer_time + +# Local +from .. import metrics, vis +from ..config import NeuralLAMConfig +from ..datastore import BaseDatastore +from ..loss_weighting import get_state_feature_weighting +from ..weather_dataset import WeatherDataset + + +class ARModel(pl.LightningModule): + """ + Generic auto-regressive weather model. + Abstract class that can be extended. 
+ """ + + # pylint: disable=arguments-differ + # Disable to override args/kwargs from superclass + + def __init__( + self, + args, + config: NeuralLAMConfig, + datastore: BaseDatastore, + ): + super().__init__() + self.save_hyperparameters(ignore=["datastore"]) + self.args = args + self._datastore = datastore + num_state_vars = datastore.get_num_data_vars(category="state") + num_forcing_vars = datastore.get_num_data_vars(category="forcing") + # Load static features standardized + da_static_features = datastore.get_dataarray( + category="static", split=None, standardize=True + ) + if da_static_features is None: + raise ValueError("Static features are required for ARModel") + da_state_stats = datastore.get_standardization_dataarray( + category="state" + ) + da_boundary_mask = datastore.boundary_mask + num_past_forcing_steps = args.num_past_forcing_steps + num_future_forcing_steps = args.num_future_forcing_steps + + # Load static features for grid/data, + self.register_buffer( + "grid_static_features", + torch.tensor(da_static_features.values, dtype=torch.float32), + persistent=False, + ) + + state_stats = { + "state_mean": torch.tensor( + da_state_stats.state_mean.values, dtype=torch.float32 + ), + "state_std": torch.tensor( + da_state_stats.state_std.values, dtype=torch.float32 + ), + # Note that the one-step-diff stats (diff_mean and diff_std) are + # for differences computed on standardized data + "diff_mean": torch.tensor( + da_state_stats.state_diff_mean_standardized.values, + dtype=torch.float32, + ), + "diff_std": torch.tensor( + da_state_stats.state_diff_std_standardized.values, + dtype=torch.float32, + ), + } + + for key, val in state_stats.items(): + self.register_buffer(key, val, persistent=False) + + state_feature_weights = get_state_feature_weighting( + config=config, datastore=datastore + ) + self.feature_weights = torch.tensor( + state_feature_weights, dtype=torch.float32 + ) + + # Double grid output dim. to also output std.-dev. 
+ self.output_std = bool(args.output_std) + if self.output_std: + # Pred. dim. in grid cell + self.grid_output_dim = 2 * num_state_vars + else: + # Pred. dim. in grid cell + self.grid_output_dim = num_state_vars + # Store constant per-variable std.-dev. weighting + # NOTE that this is the inverse of the multiplicative weighting + # in wMSE/wMAE + self.register_buffer( + "per_var_std", + self.diff_std / torch.sqrt(self.feature_weights), + persistent=False, + ) + + # grid_dim from data + static + ( + self.num_grid_nodes, + grid_static_dim, + ) = self.grid_static_features.shape + + self.grid_dim = ( + 2 * num_state_vars + + grid_static_dim + + num_forcing_vars + * (num_past_forcing_steps + num_future_forcing_steps + 1) + ) + + # Instantiate loss function + self.loss = metrics.get_metric(args.loss) + + boundary_mask = torch.tensor( + da_boundary_mask.values, dtype=torch.float32 + ).unsqueeze( + 1 + ) # add feature dim + + self.register_buffer("boundary_mask", boundary_mask, persistent=False) + # Pre-compute interior mask for use in loss function + self.register_buffer( + "interior_mask", 1.0 - self.boundary_mask, persistent=False + ) # (num_grid_nodes, 1), 1 for non-border + + self.val_metrics: Dict[str, List] = { + "mse": [], + } + self.test_metrics: Dict[str, List] = { + "mse": [], + "mae": [], + } + if self.output_std: + self.test_metrics["output_std"] = [] # Treat as metric + + # For making restoring of optimizer state optional + self.restore_opt = args.restore_opt + + # For example plotting + self.n_example_pred = args.n_example_pred + self.plotted_examples = 0 + + # For storing spatial loss maps during evaluation + self.spatial_loss_maps: List[Any] = [] + + self.time_step_int, self.time_step_unit = get_integer_time( + self._datastore.step_length + ) + + def _create_dataarray_from_tensor( + self, + tensor: torch.Tensor, + time: torch.Tensor, + split: str, + category: str, + ) -> xr.DataArray: + """ + Create an `xr.DataArray` from a tensor, with the correct 
dimensions and + coordinates to match the datastore used by the model. This function in + in effect is the inverse of what is returned by + `WeatherDataset.__getitem__`. + + Parameters + ---------- + tensor : torch.Tensor + The tensor to convert to a `xr.DataArray` with dimensions [time, + grid_index, feature]. The tensor will be copied to the CPU if it is + not already there. + time : torch.Tensor + The time index or indices for the data, given as tensor representing + epoch time in nanoseconds. The tensor will be + copied to the CPU memory if they are not already there. + split : str + The split of the data, either 'train', 'val', or 'test' + category : str + The category of the data, either 'state' or 'forcing' + """ + # TODO: creating an instance of WeatherDataset here on every call is + # not how this should be done but whether WeatherDataset should be + # provided to ARModel or where to put plotting still needs discussion + weather_dataset = WeatherDataset(datastore=self._datastore, split=split) + time = np.array(time.cpu(), dtype="datetime64[ns]") + da = weather_dataset.create_dataarray_from_tensor( + tensor=tensor, time=time, category=category + ) + return da + + def configure_optimizers(self): + opt = torch.optim.AdamW( + self.parameters(), lr=self.args.lr, betas=(0.9, 0.95) + ) + return opt + + @property + def interior_mask_bool(self): + """ + Get the interior mask as a boolean (N,) mask. 
+ """ + return self.interior_mask[:, 0].to(torch.bool) + + @staticmethod + def expand_to_batch(x, batch_size): + """ + Expand tensor with initial batch dimension + """ + return x.unsqueeze(0).expand(batch_size, -1, -1) + + def predict_step(self, prev_state, prev_prev_state, forcing): + """ + Step state one step ahead using prediction model, X_{t-1}, X_t -> X_t+1 + prev_state: (B, num_grid_nodes, feature_dim), X_t prev_prev_state: (B, + num_grid_nodes, feature_dim), X_{t-1} forcing: (B, num_grid_nodes, + forcing_dim) + """ + raise NotImplementedError("No prediction step implemented") + + def unroll_prediction(self, init_states, forcing_features, true_states): + """ + Roll out prediction taking multiple autoregressive steps with model + init_states: (B, 2, num_grid_nodes, d_f) forcing_features: (B, + pred_steps, num_grid_nodes, d_static_f) true_states: (B, pred_steps, + num_grid_nodes, d_f) + """ + prev_prev_state = init_states[:, 0] + prev_state = init_states[:, 1] + prediction_list = [] + pred_std_list = [] + pred_steps = forcing_features.shape[1] + + for i in range(pred_steps): + forcing = forcing_features[:, i] + border_state = true_states[:, i] + + pred_state, pred_std = self.predict_step( + prev_state, prev_prev_state, forcing + ) + # state: (B, num_grid_nodes, d_f) pred_std: (B, num_grid_nodes, + # d_f) or None + + # Overwrite border with true state + new_state = ( + self.boundary_mask * border_state + + self.interior_mask * pred_state + ) + + prediction_list.append(new_state) + if self.output_std: + pred_std_list.append(pred_std) + + # Update conditioning states + prev_prev_state = prev_state + prev_state = new_state + + prediction = torch.stack( + prediction_list, dim=1 + ) # (B, pred_steps, num_grid_nodes, d_f) + if self.output_std: + pred_std = torch.stack( + pred_std_list, dim=1 + ) # (B, pred_steps, num_grid_nodes, d_f) + else: + pred_std = self.per_var_std # (d_f,) + + return prediction, pred_std + + def common_step(self, batch): + """ + Predict on 
single batch batch consists of: init_states: (B, 2, + num_grid_nodes, d_features) target_states: (B, pred_steps, + num_grid_nodes, d_features) forcing_features: (B, pred_steps, + num_grid_nodes, d_forcing), + where index 0 corresponds to index 1 of init_states + """ + (init_states, target_states, forcing_features, batch_times) = batch + + prediction, pred_std = self.unroll_prediction( + init_states, forcing_features, target_states + ) # (B, pred_steps, num_grid_nodes, d_f) + # prediction: (B, pred_steps, num_grid_nodes, d_f) pred_std: (B, + # pred_steps, num_grid_nodes, d_f) or (d_f,) + + return prediction, target_states, pred_std, batch_times + + def training_step(self, batch): + """ + Train on single batch + """ + prediction, target, pred_std, _ = self.common_step(batch) + + # Compute loss + batch_loss = torch.mean( + self.loss( + prediction, target, pred_std, mask=self.interior_mask_bool + ) + ) # mean over unrolled times and batch + + log_dict = {"train_loss": batch_loss} + self.log_dict( + log_dict, + prog_bar=True, + on_step=True, + on_epoch=True, + sync_dist=True, + batch_size=batch[0].shape[0], + ) + return batch_loss + + def all_gather_cat(self, tensor_to_gather): + """ + Gather tensors across all ranks, and concatenate across dim. 0 (instead + of stacking in new dim. 0) + + tensor_to_gather: (d1, d2, ...), distributed over K ranks + + returns: + - single-device strategies: (d1, d2, ...) + - multi-device strategies: (K*d1, d2, ...) + """ + gathered = self.all_gather(tensor_to_gather) + # all_gather adds a leading dim (K,) only on multi-device runs; + # on single-device it returns the tensor unchanged. 
+ if gathered.dim() > tensor_to_gather.dim(): + return gathered.flatten(0, 1) + return gathered + + # newer lightning versions requires batch_idx argument, even if unused + # pylint: disable-next=unused-argument + def validation_step(self, batch, batch_idx): + """ + Run validation on single batch + """ + prediction, target, pred_std, _ = self.common_step(batch) + + time_step_loss = torch.mean( + self.loss( + prediction, target, pred_std, mask=self.interior_mask_bool + ), + dim=0, + ) # (time_steps-1) + mean_loss = torch.mean(time_step_loss) + + # Log loss per time step forward and mean + val_log_dict = { + f"val_loss_unroll{step}": time_step_loss[step - 1] + for step in self.args.val_steps_to_log + if step <= len(time_step_loss) + } + val_log_dict["val_mean_loss"] = mean_loss + self.log_dict( + val_log_dict, + on_step=False, + on_epoch=True, + sync_dist=True, + batch_size=batch[0].shape[0], + ) + + # Store MSEs + entry_mses = metrics.mse( + prediction, + target, + pred_std, + mask=self.interior_mask_bool, + sum_vars=False, + ) # (B, pred_steps, d_f) + self.val_metrics["mse"].append(entry_mses) + + def on_validation_epoch_end(self): + """ + Compute val metrics at the end of val epoch + """ + # Create error maps for all test metrics + self.aggregate_and_plot_metrics(self.val_metrics, prefix="val") + + # Clear lists with validation metrics values + for metric_list in self.val_metrics.values(): + metric_list.clear() + + # pylint: disable-next=unused-argument + def test_step(self, batch, batch_idx): + """ + Run test on single batch + """ + # TODO Here batch_times can be used for plotting routines + prediction, target, pred_std, batch_times = self.common_step(batch) + # prediction: (B, pred_steps, num_grid_nodes, d_f) pred_std: (B, + # pred_steps, num_grid_nodes, d_f) or (d_f,) + + time_step_loss = torch.mean( + self.loss( + prediction, target, pred_std, mask=self.interior_mask_bool + ), + dim=0, + ) # (time_steps-1,) + mean_loss = torch.mean(time_step_loss) + + # Log 
loss per time step forward and mean + test_log_dict = { + f"test_loss_unroll{step}": time_step_loss[step - 1] + for step in self.args.val_steps_to_log + } + test_log_dict["test_mean_loss"] = mean_loss + + self.log_dict( + test_log_dict, + on_step=False, + on_epoch=True, + sync_dist=True, + batch_size=batch[0].shape[0], + ) + + # Compute all evaluation metrics for error maps Note: explicitly list + # metrics here, as test_metrics can contain additional ones, computed + # differently, but that should be aggregated on_test_epoch_end + for metric_name in ("mse", "mae"): + metric_func = metrics.get_metric(metric_name) + batch_metric_vals = metric_func( + prediction, + target, + pred_std, + mask=self.interior_mask_bool, + sum_vars=False, + ) # (B, pred_steps, d_f) + self.test_metrics[metric_name].append(batch_metric_vals) + + if self.output_std: + # Store output std. per variable, spatially averaged + mean_pred_std = torch.mean( + pred_std[..., self.interior_mask_bool, :], dim=-2 + ) # (B, pred_steps, d_f) + self.test_metrics["output_std"].append(mean_pred_std) + + # Save per-sample spatial loss for specific times + spatial_loss = self.loss( + prediction, target, pred_std, average_grid=False + ) # (B, pred_steps, num_grid_nodes) + log_spatial_losses = spatial_loss[ + :, [step - 1 for step in self.args.val_steps_to_log] + ] + self.spatial_loss_maps.append(log_spatial_losses) + # (B, N_log, num_grid_nodes) + + # Plot example predictions (on rank 0 only) + if ( + self.trainer.is_global_zero + and self.plotted_examples < self.n_example_pred + ): + # Need to plot more example predictions + n_additional_examples = min( + prediction.shape[0], + self.n_example_pred - self.plotted_examples, + ) + + self.plot_examples( + batch, + n_additional_examples, + prediction=prediction, + split="test", + ) + + def plot_examples(self, batch, n_examples, split, prediction=None): + """ + Plot the first n_examples forecasts from batch + + batch: batch with data to plot corresponding forecasts 
for n_examples: + number of forecasts to plot prediction: (B, pred_steps, num_grid_nodes, + d_f), existing prediction. + Generate if None. + """ + if prediction is None: + prediction, target, _, _ = self.common_step(batch) + + target = batch[1] + time = batch[3] + + # Rescale to original data scale + prediction_rescaled = prediction * self.state_std + self.state_mean + target_rescaled = target * self.state_std + self.state_mean + + # Iterate over the examples + for pred_slice, target_slice, time_slice in zip( + prediction_rescaled[:n_examples], + target_rescaled[:n_examples], + time[:n_examples], + ): + # Each slice is (pred_steps, num_grid_nodes, d_f) + self.plotted_examples += 1 # Increment already here + + da_prediction = self._create_dataarray_from_tensor( + tensor=pred_slice, + time=time_slice, + split=split, + category="state", + ).unstack("grid_index") + da_target = self._create_dataarray_from_tensor( + tensor=target_slice, + time=time_slice, + split=split, + category="state", + ).unstack("grid_index") + + var_vmin = ( + torch.minimum( + pred_slice.flatten(0, 1).min(dim=0)[0], + target_slice.flatten(0, 1).min(dim=0)[0], + ) + .cpu() + .numpy() + ) # (d_f,) + var_vmax = ( + torch.maximum( + pred_slice.flatten(0, 1).max(dim=0)[0], + target_slice.flatten(0, 1).max(dim=0)[0], + ) + .cpu() + .numpy() + ) # (d_f,) + var_vranges = list(zip(var_vmin, var_vmax)) + + # Iterate over prediction horizon time steps + for t_i, _ in enumerate(zip(pred_slice, target_slice), start=1): + # Create one figure per variable at this time step + var_figs = [ + vis.plot_prediction( + datastore=self._datastore, + title=f"{var_name}, t={t_i}" + f" ({self.time_step_int * t_i}" + f"{self.time_step_unit})", + colorbar_label=var_unit, + vrange=var_vrange, + da_prediction=da_prediction.isel( + state_feature=var_i, time=t_i - 1 + ).squeeze(), + da_target=da_target.isel( + state_feature=var_i, time=t_i - 1 + ).squeeze(), + ) + for var_i, (var_name, var_unit, var_vrange) in enumerate( + zip( + 
self._datastore.get_vars_names("state"), + self._datastore.get_vars_units("state"), + var_vranges, + ) + ) + ] + + example_i = self.plotted_examples + + for var_name, fig in zip( + self._datastore.get_vars_names("state"), var_figs + ): + + # We need treat logging images differently for different + # loggers. WANDB can log multiple images to the same key, + # while other loggers, as MLFlow, need unique keys for + # each image. + if isinstance(self.logger, pl.loggers.WandbLogger): + key = f"{var_name}_example_{example_i}" + else: + key = f"{var_name}_example" + + if hasattr(self.logger, "log_image"): + self.logger.log_image(key=key, images=[fig], step=t_i) + else: + warnings.warn( + f"{self.logger} does not support image logging." + ) + + plt.close( + "all" + ) # Close all figs for this time step, saves memory + + # Save pred and target as .pt files + torch.save( + pred_slice.cpu(), + os.path.join( + self.logger.save_dir, + f"example_pred_{self.plotted_examples}.pt", + ), + ) + torch.save( + target_slice.cpu(), + os.path.join( + self.logger.save_dir, + f"example_target_{self.plotted_examples}.pt", + ), + ) + + def create_metric_log_dict(self, metric_tensor, prefix, metric_name): + """ + Put together a dict with everything to log for one metric. Also saves + plots as pdf and csv if using test prefix. 
+ + metric_tensor: (pred_steps, d_f), metric values per time and variable + prefix: string, prefix to use for logging metric_name: string, name of + the metric + + Return: log_dict: dict with everything to log for given metric + """ + log_dict = {} + metric_fig = vis.plot_error_map( + errors=metric_tensor, + datastore=self._datastore, + ) + full_log_name = f"{prefix}_{metric_name}" + log_dict[full_log_name] = metric_fig + + if prefix == "test": + # Save pdf + metric_fig.savefig( + os.path.join(self.logger.save_dir, f"{full_log_name}.pdf") + ) + # Save errors also as csv + np.savetxt( + os.path.join(self.logger.save_dir, f"{full_log_name}.csv"), + metric_tensor.cpu().numpy(), + delimiter=",", + ) + + # Check if metrics are watched, log exact values for specific vars + var_names = self._datastore.get_vars_names(category="state") + if full_log_name in self.args.metrics_watch: + for var_i, timesteps in self.args.var_leads_metrics_watch.items(): + var_name = var_names[var_i] + for step in timesteps: + key = f"{full_log_name}_{var_name}_step_{step}" + log_dict[key] = metric_tensor[step - 1, var_i] + + return log_dict + + def aggregate_and_plot_metrics(self, metrics_dict, prefix): + """ + Aggregate and create error map plots for all metrics in metrics_dict + + metrics_dict: dictionary with metric_names and list of tensors + with step-evals. 
+ prefix: string, prefix to use for logging + """ + log_dict = {} + for metric_name, metric_val_list in metrics_dict.items(): + metric_tensor = self.all_gather_cat( + torch.cat(metric_val_list, dim=0) + ) # (N_eval, pred_steps, d_f) + + if self.trainer.is_global_zero: + metric_tensor_averaged = torch.mean(metric_tensor, dim=0) + # (pred_steps, d_f) + + # Take square root after all averaging to change MSE to RMSE + if "mse" in metric_name: + metric_tensor_averaged = torch.sqrt(metric_tensor_averaged) + metric_name = metric_name.replace("mse", "rmse") + + # NOTE: we here assume rescaling for all metrics is linear + metric_rescaled = metric_tensor_averaged * self.state_std + # (pred_steps, d_f) + log_dict.update( + self.create_metric_log_dict( + metric_rescaled, prefix, metric_name + ) + ) + + # Ensure that log_dict has structure for + # logging as dict(str, plt.Figure) + assert all( + isinstance(key, str) and isinstance(value, plt.Figure) + for key, value in log_dict.items() + ) + + if self.trainer.is_global_zero and not self.trainer.sanity_checking: + + current_epoch = self.trainer.current_epoch + + for key, figure in log_dict.items(): + # For other loggers than wandb, add epoch to key. + # Wandb can log multiple images to the same key, while other + # loggers, such as MLFlow need unique keys for each image. + if not isinstance(self.logger, pl.loggers.WandbLogger): + key = f"{key}-{current_epoch}" + + if hasattr(self.logger, "log_image"): + self.logger.log_image(key=key, images=[figure]) + + plt.close("all") # Close all figs + + def on_test_epoch_end(self): + """ + Compute test metrics and make plots at the end of test epoch. Will + gather stored tensors and perform plotting and logging on rank 0. 
+ """ + # Create error maps for all test metrics + self.aggregate_and_plot_metrics(self.test_metrics, prefix="test") + + # Plot spatial loss maps + spatial_loss_tensor = self.all_gather_cat( + torch.cat(self.spatial_loss_maps, dim=0) + ) # (N_test, N_log, num_grid_nodes) + if self.trainer.is_global_zero: + mean_spatial_loss = torch.mean( + spatial_loss_tensor, dim=0 + ) # (N_log, num_grid_nodes) + + loss_map_figs = [ + vis.plot_spatial_error( + error=loss_map, + datastore=self._datastore, + title=f"Test loss, t={t_i} " + f"({(self.time_step_int * t_i)} {self.time_step_unit})", + ) + for t_i, loss_map in zip( + self.args.val_steps_to_log, mean_spatial_loss + ) + ] + + # log all to same key, sequentially + for i, fig in enumerate(loss_map_figs): + key = "test_loss" + if not isinstance(self.logger, pl.loggers.WandbLogger): + key = f"{key}_{i}" + if hasattr(self.logger, "log_image"): + self.logger.log_image(key=key, images=[fig]) + + # also make without title and save as pdf + pdf_loss_map_figs = [ + vis.plot_spatial_error( + error=loss_map, datastore=self._datastore + ) + for loss_map in mean_spatial_loss + ] + pdf_loss_maps_dir = os.path.join( + self.logger.save_dir, "spatial_loss_maps" + ) + os.makedirs(pdf_loss_maps_dir, exist_ok=True) + for t_i, fig in zip(self.args.val_steps_to_log, pdf_loss_map_figs): + fig.savefig(os.path.join(pdf_loss_maps_dir, f"loss_t{t_i}.pdf")) + # save mean spatial loss as .pt file also + torch.save( + mean_spatial_loss.cpu(), + os.path.join(self.logger.save_dir, "mean_spatial_loss.pt"), + ) + + self.spatial_loss_maps.clear() + + def on_load_checkpoint(self, checkpoint): + """ + Perform any changes to state dict before loading checkpoint + """ + loaded_state_dict = checkpoint["state_dict"] + + # Fix for loading older models after InteractionNet refactoring, where + the grid MLP was moved outside the encoder InteractionNet class + if "g2m_gnn.grid_mlp.0.weight" in loaded_state_dict: + replace_keys = list( + filter( + lambda key: 
key.startswith("g2m_gnn.grid_mlp"), + loaded_state_dict.keys(), + ) + ) + for old_key in replace_keys: + new_key = old_key.replace( + "g2m_gnn.grid_mlp", "encoding_grid_mlp" + ) + loaded_state_dict[new_key] = loaded_state_dict[old_key] + del loaded_state_dict[old_key] + if not self.restore_opt: + opt = self.configure_optimizers() + checkpoint["optimizer_states"] = [opt.state_dict()] diff --git a/neural_lam/models/metric_tracker.py b/neural_lam/models/metric_tracker.py new file mode 100644 index 000000000..fe50fc27d --- /dev/null +++ b/neural_lam/models/metric_tracker.py @@ -0,0 +1,125 @@ +# Standard library +from typing import Any, Dict, List + +# Third-party +import torch + + +class MetricTracker: + """ + Tracks validation and test metrics during model evaluation. + + This class manages the collection and storage of metrics across + multiple batches during validation and testing phases. + """ + + def __init__(self, output_std: bool = False): + """ + Initialize the MetricTracker. + + Parameters + ---------- + output_std : bool, optional + Whether the model outputs standard deviation, by default False + """ + self.output_std = output_std + self.val_metrics: Dict[str, List] = { + "mse": [], + } + self.test_metrics: Dict[str, List] = { + "mse": [], + "mae": [], + } + if self.output_std: + self.test_metrics["output_std"] = [] # Treat as metric + + # For storing spatial loss maps during evaluation + self.spatial_loss_maps: List[Any] = [] + + def add_val_metric(self, metric_name: str, metric_value: torch.Tensor): + """ + Add a validation metric value. + + Parameters + ---------- + metric_name : str + Name of the metric (e.g., 'mse', 'mae') + metric_value : torch.Tensor + The metric value tensor to store + """ + if metric_name not in self.val_metrics: + self.val_metrics[metric_name] = [] + self.val_metrics[metric_name].append(metric_value) + + def add_test_metric(self, metric_name: str, metric_value: torch.Tensor): + """ + Add a test metric value. 
+ + Parameters + ---------- + metric_name : str + Name of the metric (e.g., 'mse', 'mae', 'output_std') + metric_value : torch.Tensor + The metric value tensor to store + """ + if metric_name not in self.test_metrics: + self.test_metrics[metric_name] = [] + self.test_metrics[metric_name].append(metric_value) + + def add_spatial_loss_map(self, spatial_loss: torch.Tensor): + """ + Add a spatial loss map. + + Parameters + ---------- + spatial_loss : torch.Tensor + Spatial loss tensor, typically (B, N_log, num_grid_nodes) + """ + self.spatial_loss_maps.append(spatial_loss) + + def clear_val_metrics(self): + """Clear all validation metrics.""" + for metric_list in self.val_metrics.values(): + metric_list.clear() + + def clear_test_metrics(self): + """Clear all test metrics.""" + for metric_list in self.test_metrics.values(): + metric_list.clear() + + def clear_spatial_loss_maps(self): + """Clear all spatial loss maps.""" + self.spatial_loss_maps.clear() + + def get_val_metrics(self) -> Dict[str, List]: + """ + Get all validation metrics. + + Returns + ------- + Dict[str, List] + Dictionary mapping metric names to lists of tensors + """ + return self.val_metrics + + def get_test_metrics(self) -> Dict[str, List]: + """ + Get all test metrics. + + Returns + ------- + Dict[str, List] + Dictionary mapping metric names to lists of tensors + """ + return self.test_metrics + + def get_spatial_loss_maps(self) -> List[Any]: + """ + Get all spatial loss maps. 
+ + Returns + ------- + List[Any] + List of spatial loss map tensors + """ + return self.spatial_loss_maps diff --git a/neural_lam/models/model_visualizer.py b/neural_lam/models/model_visualizer.py new file mode 100644 index 000000000..4664f812a --- /dev/null +++ b/neural_lam/models/model_visualizer.py @@ -0,0 +1,237 @@ +# Standard library +import os +import warnings +from typing import Any, Dict + +# Third-party +import matplotlib.pyplot as plt +import numpy as np +import pytorch_lightning as pl +import torch + +# First-party +from .. import vis +from ..datastore import BaseDatastore + + +class ModelVisualizer: + """ + Handles visualization of model predictions and metrics. + + This class centralizes all plotting and visualization logic for weather models, + including prediction plots, error maps, and spatial loss visualizations. + """ + + def __init__( + self, + datastore: BaseDatastore, + args: Any, + time_step_int: int, + time_step_unit: str, + ): + """ + Initialize the ModelVisualizer. + + Parameters + ---------- + datastore : BaseDatastore + The datastore for accessing data metadata + args : Any + Arguments object containing configuration + time_step_int : int + Integer part of time step + time_step_unit : str + Unit of time step (e.g., 'h' for hours) + """ + self.datastore = datastore + self.args = args + self.time_step_int = time_step_int + self.time_step_unit = time_step_unit + + def create_metric_log_dict( + self, + metric_tensor: torch.Tensor, + prefix: str, + metric_name: str, + logger: Any, + ) -> Dict[str, Any]: + """ + Put together a dict with everything to log for one metric. + + Also saves plots as pdf and csv if using test prefix. 
+ + Parameters + ---------- + metric_tensor : torch.Tensor + Metric values per time and variable, shape (pred_steps, d_f) + prefix : str + Prefix to use for logging (e.g., 'val', 'test') + metric_name : str + Name of the metric + logger : Any + Logger instance for saving files + + Returns + ------- + Dict[str, Any] + Dictionary with everything to log for given metric + """ + log_dict = {} + metric_fig = vis.plot_error_map( + errors=metric_tensor, + datastore=self.datastore, + ) + full_log_name = f"{prefix}_{metric_name}" + log_dict[full_log_name] = metric_fig + + if prefix == "test": + # Save pdf + metric_fig.savefig( + os.path.join(logger.save_dir, f"{full_log_name}.pdf") + ) + # Save errors also as csv + np.savetxt( + os.path.join(logger.save_dir, f"{full_log_name}.csv"), + metric_tensor.cpu().numpy(), + delimiter=",", + ) + + # Check if metrics are watched, log exact values for specific vars + var_names = self.datastore.get_vars_names(category="state") + if full_log_name in self.args.metrics_watch: + for var_i, timesteps in self.args.var_leads_metrics_watch.items(): + var_name = var_names[var_i] + for step in timesteps: + key = f"{full_log_name}_{var_name}_step_{step}" + log_dict[key] = metric_tensor[step - 1, var_i] + + return log_dict + + def aggregate_and_plot_metrics( + self, + metrics_dict: Dict[str, Any], + prefix: str, + state_std: torch.Tensor, + logger: Any, + trainer: Any, + ): + """ + Aggregate and create error map plots for all metrics in metrics_dict. 
+ + Parameters + ---------- + metrics_dict : dict + Dictionary with metric_names and list of tensors with step-evals + prefix : str + Prefix to use for logging + state_std : torch.Tensor + State standard deviation for rescaling + logger : Any + Logger instance + trainer : Any + Trainer instance for rank and sanity check info + """ + log_dict = {} + for metric_name, metric_val_list in metrics_dict.items(): + # Gather metrics across devices + metric_list_cat = torch.cat(metric_val_list, dim=0) + # Note: gathering is handled by the caller (ARModel.all_gather_cat) + + if trainer.is_global_zero: + metric_tensor_averaged = torch.mean( + metric_list_cat, dim=0 + ) # (pred_steps, d_f) + + # Take square root after all averaging to change MSE to RMSE + if "mse" in metric_name: + metric_tensor_averaged = torch.sqrt(metric_tensor_averaged) + metric_name = metric_name.replace("mse", "rmse") + + # NOTE: we here assume rescaling for all metrics is linear + metric_rescaled = ( + metric_tensor_averaged * state_std + ) # (pred_steps, d_f) + log_dict.update( + self.create_metric_log_dict( + metric_rescaled, prefix, metric_name, logger + ) + ) + + # Ensure that log_dict has structure for logging as dict(str, plt.Figure) + assert all( + isinstance(key, str) and isinstance(value, plt.Figure) + for key, value in log_dict.items() + ) + + if trainer.is_global_zero and not trainer.sanity_checking: + current_epoch = trainer.current_epoch + + for key, figure in log_dict.items(): + # For other loggers than wandb, add epoch to key. + # Wandb can log multiple images to the same key, while other + # loggers, such as MLFlow need unique keys for each image. + if not isinstance(logger, pl.loggers.WandbLogger): + key = f"{key}-{current_epoch}" + + if hasattr(logger, "log_image"): + logger.log_image(key=key, images=[figure]) + + plt.close("all") # Close all figs + + def plot_spatial_loss( + self, + spatial_loss_tensor: torch.Tensor, + logger: Any, + trainer: Any, + ): + """ + Plot spatial loss maps. 
+ + Parameters + ---------- + spatial_loss_tensor : torch.Tensor + Spatial loss tensor, shape (N_test, N_log, num_grid_nodes) + logger : Any + Logger instance + trainer : Any + Trainer instance + """ + if trainer.is_global_zero: + mean_spatial_loss = torch.mean( + spatial_loss_tensor, dim=0 + ) # (N_log, num_grid_nodes) + + loss_map_figs = [ + vis.plot_spatial_error( + error=loss_map, + datastore=self.datastore, + title=f"Test loss, t={t_i} " + f"({(self.time_step_int * t_i)} {self.time_step_unit})", + ) + for t_i, loss_map in zip( + self.args.val_steps_to_log, mean_spatial_loss + ) + ] + + # log all to same key, sequentially + for i, fig in enumerate(loss_map_figs): + key = "test_loss" + if not isinstance(logger, pl.loggers.WandbLogger): + key = f"{key}_{i}" + if hasattr(logger, "log_image"): + logger.log_image(key=key, images=[fig]) + + # also make without title and save as pdf + pdf_loss_map_figs = [ + vis.plot_spatial_error(error=loss_map, datastore=self.datastore) + for loss_map in mean_spatial_loss + ] + pdf_loss_maps_dir = os.path.join(logger.save_dir, "spatial_loss_maps") + os.makedirs(pdf_loss_maps_dir, exist_ok=True) + for t_i, fig in zip(self.args.val_steps_to_log, pdf_loss_map_figs): + fig.savefig(os.path.join(pdf_loss_maps_dir, f"loss_t{t_i}.pdf")) + # save mean spatial loss as .pt file also + torch.save( + mean_spatial_loss.cpu(), + os.path.join(logger.save_dir, "mean_spatial_loss.pt"), + ) diff --git a/neural_lam/models/normalization_manager.py b/neural_lam/models/normalization_manager.py new file mode 100644 index 000000000..219302c9e --- /dev/null +++ b/neural_lam/models/normalization_manager.py @@ -0,0 +1,129 @@ +# Standard library +from typing import Dict + +# Third-party +import torch +import torch.nn as nn + +# First-party +from ..datastore import BaseDatastore + + +class NormalizationManager(nn.Module): + """ + Manages state normalization and statistical buffers for weather models. 
+ + This class centralizes the handling of normalization statistics including + state mean/std and difference mean/std, as well as static grid features. + All statistics are registered as buffers to ensure proper device handling + and state dict serialization. + """ + + def __init__(self, datastore: BaseDatastore): + """ + Initialize the NormalizationManager. + + Parameters + ---------- + datastore : BaseDatastore + The datastore containing the data statistics and static features. + """ + super().__init__() + + # Load static features standardized + da_static_features = datastore.get_dataarray( + category="static", split=None, standardize=True + ) + if da_static_features is None: + raise ValueError("Static features are required for NormalizationManager") + + da_state_stats = datastore.get_standardization_dataarray(category="state") + + # Load static features for grid/data + self.register_buffer( + "grid_static_features", + torch.tensor(da_static_features.values, dtype=torch.float32), + persistent=False, + ) + + # Register state statistics as buffers + state_stats = { + "state_mean": torch.tensor( + da_state_stats.state_mean.values, dtype=torch.float32 + ), + "state_std": torch.tensor( + da_state_stats.state_std.values, dtype=torch.float32 + ), + # Note that the one-step-diff stats (diff_mean and diff_std) are + # for differences computed on standardized data + "diff_mean": torch.tensor( + da_state_stats.state_diff_mean_standardized.values, + dtype=torch.float32, + ), + "diff_std": torch.tensor( + da_state_stats.state_diff_std_standardized.values, + dtype=torch.float32, + ), + } + + for key, val in state_stats.items(): + self.register_buffer(key, val, persistent=False) + + def normalize_state(self, state: torch.Tensor) -> torch.Tensor: + """ + Normalize state using mean and standard deviation. 
+ + Parameters + ---------- + state : torch.Tensor + The state tensor to normalize, shape (..., d_f) + + Returns + ------- + torch.Tensor + Normalized state with same shape as input + """ + return (state - self.state_mean) / self.state_std + + def denormalize_state(self, normalized_state: torch.Tensor) -> torch.Tensor: + """ + Convert normalized state back to original scale. + + Parameters + ---------- + normalized_state : torch.Tensor + The normalized state tensor, shape (..., d_f) + + Returns + ------- + torch.Tensor + State in original scale with same shape as input + """ + return normalized_state * self.state_std + self.state_mean + + def get_grid_static_features(self) -> torch.Tensor: + """ + Get the grid static features tensor. + + Returns + ------- + torch.Tensor + Grid static features, shape (num_grid_nodes, static_feature_dim) + """ + return self.grid_static_features + + def get_state_stats(self) -> Dict[str, torch.Tensor]: + """ + Get all state statistics as a dictionary. + + Returns + ------- + Dict[str, torch.Tensor] + Dictionary containing 'state_mean', 'state_std', 'diff_mean', 'diff_std' + """ + return { + "state_mean": self.state_mean, + "state_std": self.state_std, + "diff_mean": self.diff_mean, + "diff_std": self.diff_std, + } diff --git a/tests/conftest.py b/tests/conftest.py index 47237ed55..fb56b4acd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,7 +16,7 @@ ) # Local -from .dummy_datastore import DummyDatastore +from .dummy_datastore import DummyDatastore, GlobalDummyDatastore # Disable weights and biases to avoid unnecessary logging # and to avoid having to deal with authentication @@ -103,9 +103,11 @@ def download_meps_example_reduced_dataset(): ), npyfilesmeps=None, dummydata=None, + globaldummydata=None, ) DATASTORES[DummyDatastore.SHORT_NAME] = DummyDatastore +DATASTORES[GlobalDummyDatastore.SHORT_NAME] = GlobalDummyDatastore def init_datastore_example(datastore_kind): diff --git a/tests/dummy_datastore.py 
b/tests/dummy_datastore.py index 3a844d6d9..9b80b1c39 100644 --- a/tests/dummy_datastore.py +++ b/tests/dummy_datastore.py @@ -758,3 +758,56 @@ def num_grid_points(self) -> int: @property def state_feature_weights_values(self) -> list[float]: return [1.0] + + +class GlobalDummyDatastore(DummyDatastore): + """ + Datastore variant for testing global forecasting support. Returns an + all-zero boundary mask, representing a global domain with no lateral + boundaries. All model predictions should be used everywhere without + boundary forcing. + """ + + SHORT_NAME = "globaldummydata" + + def __init__( + self, + config_path=None, + n_grid_points=10000, + n_timesteps=10, + step_length=None, + ) -> None: + """ + Create a global dummy datastore with random data and no boundaries. + + Parameters + ---------- + config_path : None + No config file is needed for the dummy datastore. This argument is + only present to match the signature of the other datastores. + n_grid_points : int + The number of grid points in the dataset. Must be a perfect square. + n_timesteps : int + The number of timesteps in the dataset. + step_length : timedelta, optional + The step length between timesteps. Defaults to timedelta(hours=1). 
+ """ + # Initialize parent class + super().__init__( + config_path=config_path, + n_grid_points=n_grid_points, + n_timesteps=n_timesteps, + step_length=step_length, + ) + + # Override boundary mask to be all zeros (global domain) + n_points_1d = int(np.sqrt(n_grid_points)) + self.ds["boundary_mask"] = xr.DataArray( + np.zeros((n_points_1d, n_points_1d), dtype=int), + dims=["x", "y"], + ) + + # Re-stack the spatial dimensions after overriding boundary_mask + self.ds = self.ds.drop_vars("grid_index").stack( + grid_index=self.spatial_coordinates + ) diff --git a/tests/test_datastores.py b/tests/test_datastores.py index 4ce3875ea..9c2717c56 100644 --- a/tests/test_datastores.py +++ b/tests/test_datastores.py @@ -232,9 +232,15 @@ def test_boundary_mask(datastore_name): assert isinstance(da_mask, xr.DataArray) assert set(da_mask.dims) == {"grid_index"} assert da_mask.dtype == "int" - assert set(da_mask.values) == {0, 1} - assert da_mask.sum() > 0 - assert da_mask.sum() < da_mask.size + + # Allow for global domains (all zeros) or regional domains (mix of 0 and 1) + assert set(da_mask.values).issubset({0, 1}) + + # For regional domains (non-global), ensure we have both boundary and interior + if da_mask.sum() > 0: + # Has boundary points - this is a regional domain + assert da_mask.sum() < da_mask.size # Must also have interior points + # For global domains, sum can be 0 (all interior, no boundaries) if isinstance(datastore, BaseRegularGridDatastore): grid_shape = datastore.grid_shape_state diff --git a/tests/test_training.py b/tests/test_training.py index 972740695..106794f68 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -189,3 +189,20 @@ def all_gather(self, tensor, sync_grads=False): "all_gather_cat produced incorrectly ordered/combined values " "on multi-device simulation" ) + + +def test_training_global(): + """ + Test that training works end-to-end with a global domain (no boundary + forcing). 
This verifies that the model can handle all-zero boundary masks + correctly, which is essential for global forecasting support. + """ + datastore = init_datastore_example("globaldummydata") + + # Verify this is actually a global domain + assert ( + datastore.boundary_mask.sum() == 0 + ), "GlobalDummyDatastore should have all-zero boundary mask" + + # Run training - should complete without errors + run_simple_training(datastore, set_output_std=False)