diff --git a/neural_lam/interaction_net.py b/neural_lam/interaction_net.py index 2f45b03fa..c1e169922 100644 --- a/neural_lam/interaction_net.py +++ b/neural_lam/interaction_net.py @@ -131,6 +131,83 @@ def aggregate(self, inputs, index, ptr, dim_size): return aggr, inputs +class PropagationNet(InteractionNet): + """ + Alternative version of InteractionNet that incentivizes the propagation + of information from sender nodes to receivers. + """ + + # pylint: disable=arguments-differ + # Disable to override args/kwargs from superclass + + def __init__( + self, + edge_index, + input_dim, + update_edges=True, + hidden_layers=1, + hidden_dim=None, + edge_chunk_sizes=None, + aggr_chunk_sizes=None, + aggr="sum", + ): + # Use mean aggregation in propagation version to avoid instability + super().__init__( + edge_index, + input_dim, + update_edges=update_edges, + hidden_layers=hidden_layers, + hidden_dim=hidden_dim, + edge_chunk_sizes=edge_chunk_sizes, + aggr_chunk_sizes=aggr_chunk_sizes, + aggr="mean", + ) + + def forward(self, send_rep, rec_rep, edge_rep): + """ + Apply propagation network to update the representations of receiver + nodes, and optionally the edge representations. + + send_rep: (N_send, d_h), vector representations of sender nodes + rec_rep: (N_rec, d_h), vector representations of receiver nodes + edge_rep: (M, d_h), vector representations of edges used + + Returns: + rec_rep: (N_rec, d_h), updated vector representations of receiver + nodes + (optionally) edge_rep: (M, d_h), updated vector representations + of edges + """ + # Always concatenate to [rec_nodes, send_nodes] for propagation, + # but only aggregate to rec_nodes + node_reps = torch.cat((rec_rep, send_rep), dim=-2) + edge_rep_aggr, edge_diff = self.propagate( + self.edge_index, x=node_reps, edge_attr=edge_rep + ) + rec_diff = self.aggr_mlp( + torch.cat((rec_rep, edge_rep_aggr), dim=-1) + ) + + # Residual connections + rec_rep = edge_rep_aggr + rec_diff # residual is to aggregation + + if self.update_edges: + edge_rep = edge_rep + edge_diff + return rec_rep, edge_rep + + return rec_rep + + def message(self, x_j, x_i, edge_attr): + """ + Compute messages from node j to node i. + """ + # Residual connection is to sender node, propagating information + # to edge + return x_j + self.edge_mlp( + torch.cat((edge_attr, x_j, x_i), dim=-1) + ) + + class SplitMLPs(nn.Module): """ Module that feeds chunks of input through different MLPs. diff --git a/neural_lam/models/__init__.py b/neural_lam/models/__init__.py index f65387ab6..ad5897f8a 100644 --- a/neural_lam/models/__init__.py +++ b/neural_lam/models/__init__.py @@ -1,6 +1,11 @@ # Local -from .base_graph_model import BaseGraphModel -from .base_hi_graph_model import BaseHiGraphModel +from .forecaster_module import ForecasterModule from .graph_lam import GraphLAM from .hi_lam import HiLAM from .hi_lam_parallel import HiLAMParallel + +MODELS = { + "graph_lam": GraphLAM, + "hi_lam": HiLAM, + "hi_lam_parallel": HiLAMParallel, +} diff --git a/neural_lam/models/ar_forecaster.py b/neural_lam/models/ar_forecaster.py new file mode 100644 index 000000000..9a18d054a --- /dev/null +++ b/neural_lam/models/ar_forecaster.py @@ -0,0 +1,84 @@ +# Third-party +import torch + +# Local +from ..datastore import BaseDatastore +from .forecaster import Forecaster +from .step_predictor import StepPredictor + + +class ARForecaster(Forecaster): + """ + Subclass of Forecaster that uses an auto-regressive strategy to + unroll a forecast. Makes use of a StepPredictor at each AR step. 
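+
+    At each AR step the StepPredictor maps the two previous states and the
+    current forcing to an interior prediction; boundary nodes are then
+    overwritten with the supplied boundary states before the next step.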
+ """ + + def __init__(self, predictor: StepPredictor, datastore: BaseDatastore): + super().__init__() + self.predictor = predictor + + # Register boundary/interior masks on the forecaster, not the predictor + boundary_mask = ( + torch.tensor(datastore.boundary_mask.values, dtype=torch.float32) + .unsqueeze(0) + .unsqueeze(-1) + ) + self.register_buffer("boundary_mask", boundary_mask, persistent=False) + self.register_buffer( + "interior_mask", 1.0 - self.boundary_mask, persistent=False + ) + + @property + def predicts_std(self) -> bool: + return self.predictor.predicts_std + + def forward( + self, + init_states: torch.Tensor, + forcing_features: torch.Tensor, + boundary_states: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Unroll the autoregressive model. + boundary_states is used ONLY to overwrite boundary nodes at each step. + The interior prediction at step i must not depend on + boundary_states[:, i] in any other way. + """ + + prev_prev_state = init_states[:, 0] + prev_state = init_states[:, 1] + prediction_list = [] + pred_std_list = [] + pred_steps = forcing_features.shape[1] + + for i in range(pred_steps): + forcing = forcing_features[:, i] + boundary_state = boundary_states[:, i] + + pred_state, pred_std = self.predictor( + prev_state, prev_prev_state, forcing + ) + + # Overwrite boundary with true state using ARForecaster's mask + new_state = ( + self.boundary_mask * boundary_state + + self.interior_mask * pred_state + ) + + prediction_list.append(new_state) + if pred_std is not None: + pred_std_list.append(pred_std) + + # Update conditioning states + prev_prev_state = prev_state + prev_state = new_state + + prediction = torch.stack(prediction_list, dim=1) + # If predictor outputs std, stack it; otherwise return None so + # ForecasterModule can substitute the constant per_var_std + if pred_std_list: + pred_std = torch.stack(pred_std_list, dim=1) + else: + pred_std = None + + return prediction, pred_std diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py deleted file mode 100644 index 10fe64190..000000000 --- a/neural_lam/models/ar_model.py +++ /dev/null @@ -1,771 +0,0 @@ -# Standard library -import os -import warnings -from typing import Any, Dict, List - -# Third-party -import matplotlib.pyplot as plt -import numpy as np -import pytorch_lightning as pl -import torch -import xarray as xr - -# First-party -from neural_lam.utils import get_integer_time - -# Local -from .. import metrics, vis -from ..config import NeuralLAMConfig -from ..datastore import BaseDatastore -from ..loss_weighting import get_state_feature_weighting -from ..weather_dataset import WeatherDataset - - -class ARModel(pl.LightningModule): - """ - Generic auto-regressive weather model. - Abstract class that can be extended. 
- """ - - # pylint: disable=arguments-differ - # Disable to override args/kwargs from superclass - - def __init__( - self, - args, - config: NeuralLAMConfig, - datastore: BaseDatastore, - ): - super().__init__() - self.save_hyperparameters(ignore=["datastore"]) - self.args = args - self._datastore = datastore - num_state_vars = datastore.get_num_data_vars(category="state") - num_forcing_vars = datastore.get_num_data_vars(category="forcing") - # Load static features standardized - da_static_features = datastore.get_dataarray( - category="static", split=None, standardize=True - ) - if da_static_features is None: - raise ValueError("Static features are required for ARModel") - da_state_stats = datastore.get_standardization_dataarray( - category="state" - ) - da_boundary_mask = datastore.boundary_mask - num_past_forcing_steps = args.num_past_forcing_steps - num_future_forcing_steps = args.num_future_forcing_steps - - # Load static features for grid/data, - self.register_buffer( - "grid_static_features", - torch.tensor(da_static_features.values, dtype=torch.float32), - persistent=False, - ) - - state_stats = { - "state_mean": torch.tensor( - da_state_stats.state_mean.values, dtype=torch.float32 - ), - "state_std": torch.tensor( - da_state_stats.state_std.values, dtype=torch.float32 - ), - # Note that the one-step-diff stats (diff_mean and diff_std) are - # for differences computed on standardized data - "diff_mean": torch.tensor( - da_state_stats.state_diff_mean_standardized.values, - dtype=torch.float32, - ), - "diff_std": torch.tensor( - da_state_stats.state_diff_std_standardized.values, - dtype=torch.float32, - ), - } - - for key, val in state_stats.items(): - self.register_buffer(key, val, persistent=False) - - state_feature_weights = get_state_feature_weighting( - config=config, datastore=datastore - ) - self.feature_weights = torch.tensor( - state_feature_weights, dtype=torch.float32 - ) - - # Double grid output dim. to also output std.-dev. - self.output_std = bool(args.output_std) - if self.output_std: - # Pred. dim. in grid cell - self.grid_output_dim = 2 * num_state_vars - else: - # Pred. dim. in grid cell - self.grid_output_dim = num_state_vars - # Store constant per-variable std.-dev. 
weighting - # NOTE that this is the inverse of the multiplicative weighting - # in wMSE/wMAE - self.register_buffer( - "per_var_std", - self.diff_std / torch.sqrt(self.feature_weights), - persistent=False, - ) - - # grid_dim from data + static - ( - self.num_grid_nodes, - grid_static_dim, - ) = self.grid_static_features.shape - - self.grid_dim = ( - 2 * num_state_vars - + grid_static_dim - + num_forcing_vars - * (num_past_forcing_steps + num_future_forcing_steps + 1) - ) - - # Instantiate loss function - self.loss = metrics.get_metric(args.loss) - - boundary_mask = torch.tensor( - da_boundary_mask.values, dtype=torch.float32 - ).unsqueeze( - 1 - ) # add feature dim - - self.register_buffer("boundary_mask", boundary_mask, persistent=False) - # Pre-compute interior mask for use in loss function - self.register_buffer( - "interior_mask", 1.0 - self.boundary_mask, persistent=False - ) # (num_grid_nodes, 1), 1 for non-border - - self.val_metrics: Dict[str, List] = { - "mse": [], - } - self.test_metrics: Dict[str, List] = { - "mse": [], - "mae": [], - } - if self.output_std: - self.test_metrics["output_std"] = [] # Treat as metric - - # For making restoring of optimizer state optional - self.restore_opt = args.restore_opt - - # For example plotting - self.n_example_pred = args.n_example_pred - self.plotted_examples = 0 - - # For storing spatial loss maps during evaluation - self.spatial_loss_maps: List[Any] = [] - - self.time_step_int, self.time_step_unit = get_integer_time( - self._datastore.step_length - ) - - def _create_dataarray_from_tensor( - self, - tensor: torch.Tensor, - time: torch.Tensor, - split: str, - category: str, - ) -> xr.DataArray: - """ - Create an `xr.DataArray` from a tensor, with the correct dimensions and - coordinates to match the datastore used by the model. This function in - in effect is the inverse of what is returned by - `WeatherDataset.__getitem__`. - - Parameters - ---------- - tensor : torch.Tensor - The tensor to convert to a `xr.DataArray` with dimensions [time, - grid_index, feature]. The tensor will be copied to the CPU if it is - not already there. - time : torch.Tensor - The time index or indices for the data, given as tensor representing - epoch time in nanoseconds. The tensor will be - copied to the CPU memory if they are not already there. - split : str - The split of the data, either 'train', 'val', or 'test' - category : str - The category of the data, either 'state' or 'forcing' - """ - # TODO: creating an instance of WeatherDataset here on every call is - # not how this should be done but whether WeatherDataset should be - # provided to ARModel or where to put plotting still needs discussion - weather_dataset = WeatherDataset(datastore=self._datastore, split=split) - time = np.array(time.cpu(), dtype="datetime64[ns]") - da = weather_dataset.create_dataarray_from_tensor( - tensor=tensor, time=time, category=category - ) - return da - - def configure_optimizers(self): - opt = torch.optim.AdamW( - self.parameters(), lr=self.args.lr, betas=(0.9, 0.95) - ) - return opt - - @property - def interior_mask_bool(self): - """ - Get the interior mask as a boolean (N,) mask. 
- """ - return self.interior_mask[:, 0].to(torch.bool) - - @staticmethod - def expand_to_batch(x, batch_size): - """ - Expand tensor with initial batch dimension - """ - return x.unsqueeze(0).expand(batch_size, -1, -1) - - def predict_step(self, prev_state, prev_prev_state, forcing): - """ - Step state one step ahead using prediction model, X_{t-1}, X_t -> X_t+1 - prev_state: (B, num_grid_nodes, feature_dim), X_t prev_prev_state: (B, - num_grid_nodes, feature_dim), X_{t-1} forcing: (B, num_grid_nodes, - forcing_dim) - """ - raise NotImplementedError("No prediction step implemented") - - def unroll_prediction(self, init_states, forcing_features, true_states): - """ - Roll out prediction taking multiple autoregressive steps with model - init_states: (B, 2, num_grid_nodes, d_f) forcing_features: (B, - pred_steps, num_grid_nodes, d_static_f) true_states: (B, pred_steps, - num_grid_nodes, d_f) - """ - prev_prev_state = init_states[:, 0] - prev_state = init_states[:, 1] - prediction_list = [] - pred_std_list = [] - pred_steps = forcing_features.shape[1] - - for i in range(pred_steps): - forcing = forcing_features[:, i] - border_state = true_states[:, i] - - pred_state, pred_std = self.predict_step( - prev_state, prev_prev_state, forcing - ) - # state: (B, num_grid_nodes, d_f) pred_std: (B, num_grid_nodes, - # d_f) or None - - # Overwrite border with true state - new_state = ( - self.boundary_mask * border_state - + self.interior_mask * pred_state - ) - - prediction_list.append(new_state) - if self.output_std: - pred_std_list.append(pred_std) - - # Update conditioning states - prev_prev_state = prev_state - prev_state = new_state - - prediction = torch.stack( - prediction_list, dim=1 - ) # (B, pred_steps, num_grid_nodes, d_f) - if self.output_std: - pred_std = torch.stack( - pred_std_list, dim=1 - ) # (B, pred_steps, num_grid_nodes, d_f) - else: - pred_std = self.per_var_std # (d_f,) - - return prediction, pred_std - - def common_step(self, batch): - """ - Predict on single batch batch consists of: init_states: (B, 2, - num_grid_nodes, d_features) target_states: (B, pred_steps, - num_grid_nodes, d_features) forcing_features: (B, pred_steps, - num_grid_nodes, d_forcing), - where index 0 corresponds to index 1 of init_states - """ - (init_states, target_states, forcing_features, batch_times) = batch - - prediction, pred_std = self.unroll_prediction( - init_states, forcing_features, target_states - ) # (B, pred_steps, num_grid_nodes, d_f) - # prediction: (B, pred_steps, num_grid_nodes, d_f) pred_std: (B, - # pred_steps, num_grid_nodes, d_f) or (d_f,) - - return prediction, target_states, pred_std, batch_times - - def training_step(self, batch): - """ - Train on single batch - """ - prediction, target, pred_std, _ = self.common_step(batch) - - # Compute loss - batch_loss = torch.mean( - self.loss( - prediction, target, pred_std, mask=self.interior_mask_bool - ) - ) # mean over unrolled times and batch - - log_dict = {"train_loss": batch_loss} - self.log_dict( - log_dict, - prog_bar=True, - on_step=True, - on_epoch=True, - sync_dist=True, - batch_size=batch[0].shape[0], - ) - return batch_loss - - def all_gather_cat(self, tensor_to_gather): - """ - Gather tensors across all ranks, and concatenate across dim. 0 (instead - of stacking in new dim. 0) - - tensor_to_gather: (d1, d2, ...), distributed over K ranks - - returns: (K*d1, d2, ...) 
- """ - return self.all_gather(tensor_to_gather).flatten(0, 1) - - # newer lightning versions requires batch_idx argument, even if unused - # pylint: disable-next=unused-argument - def validation_step(self, batch, batch_idx): - """ - Run validation on single batch - """ - prediction, target, pred_std, _ = self.common_step(batch) - - time_step_loss = torch.mean( - self.loss( - prediction, target, pred_std, mask=self.interior_mask_bool - ), - dim=0, - ) # (time_steps-1) - mean_loss = torch.mean(time_step_loss) - - # Log loss per time step forward and mean - val_log_dict = { - f"val_loss_unroll{step}": time_step_loss[step - 1] - for step in self.args.val_steps_to_log - if step <= len(time_step_loss) - } - val_log_dict["val_mean_loss"] = mean_loss - self.log_dict( - val_log_dict, - on_step=False, - on_epoch=True, - sync_dist=True, - batch_size=batch[0].shape[0], - ) - - # Store MSEs - entry_mses = metrics.mse( - prediction, - target, - pred_std, - mask=self.interior_mask_bool, - sum_vars=False, - ) # (B, pred_steps, d_f) - self.val_metrics["mse"].append(entry_mses) - - def on_validation_epoch_end(self): - """ - Compute val metrics at the end of val epoch - """ - # Create error maps for all test metrics - self.aggregate_and_plot_metrics(self.val_metrics, prefix="val") - - # Clear lists with validation metrics values - for metric_list in self.val_metrics.values(): - metric_list.clear() - - # pylint: disable-next=unused-argument - def test_step(self, batch, batch_idx): - """ - Run test on single batch - """ - # TODO Here batch_times can be used for plotting routines - prediction, target, pred_std, batch_times = self.common_step(batch) - # prediction: (B, pred_steps, num_grid_nodes, d_f) pred_std: (B, - # pred_steps, num_grid_nodes, d_f) or (d_f,) - - time_step_loss = torch.mean( - self.loss( - prediction, target, pred_std, mask=self.interior_mask_bool - ), - dim=0, - ) # (time_steps-1,) - mean_loss = torch.mean(time_step_loss) - - # Log loss per time step forward and mean - test_log_dict = { - f"test_loss_unroll{step}": time_step_loss[step - 1] - for step in self.args.val_steps_to_log - } - test_log_dict["test_mean_loss"] = mean_loss - - self.log_dict( - test_log_dict, - on_step=False, - on_epoch=True, - sync_dist=True, - batch_size=batch[0].shape[0], - ) - - # Compute all evaluation metrics for error maps Note: explicitly list - # metrics here, as test_metrics can contain additional ones, computed - # differently, but that should be aggregated on_test_epoch_end - for metric_name in ("mse", "mae"): - metric_func = metrics.get_metric(metric_name) - batch_metric_vals = metric_func( - prediction, - target, - pred_std, - mask=self.interior_mask_bool, - sum_vars=False, - ) # (B, pred_steps, d_f) - self.test_metrics[metric_name].append(batch_metric_vals) - - if self.output_std: - # Store output std. 
per variable, spatially averaged - mean_pred_std = torch.mean( - pred_std[..., self.interior_mask_bool, :], dim=-2 - ) # (B, pred_steps, d_f) - self.test_metrics["output_std"].append(mean_pred_std) - - # Save per-sample spatial loss for specific times - spatial_loss = self.loss( - prediction, target, pred_std, average_grid=False - ) # (B, pred_steps, num_grid_nodes) - log_spatial_losses = spatial_loss[ - :, [step - 1 for step in self.args.val_steps_to_log] - ] - self.spatial_loss_maps.append(log_spatial_losses) - # (B, N_log, num_grid_nodes) - - # Plot example predictions (on rank 0 only) - if ( - self.trainer.is_global_zero - and self.plotted_examples < self.n_example_pred - ): - # Need to plot more example predictions - n_additional_examples = min( - prediction.shape[0], - self.n_example_pred - self.plotted_examples, - ) - - self.plot_examples( - batch, - n_additional_examples, - prediction=prediction, - split="test", - ) - - def plot_examples(self, batch, n_examples, split, prediction=None): - """ - Plot the first n_examples forecasts from batch - - batch: batch with data to plot corresponding forecasts for n_examples: - number of forecasts to plot prediction: (B, pred_steps, num_grid_nodes, - d_f), existing prediction. - Generate if None. - """ - if prediction is None: - prediction, target, _, _ = self.common_step(batch) - - target = batch[1] - time = batch[3] - - # Rescale to original data scale - prediction_rescaled = prediction * self.state_std + self.state_mean - target_rescaled = target * self.state_std + self.state_mean - - # Iterate over the examples - for pred_slice, target_slice, time_slice in zip( - prediction_rescaled[:n_examples], - target_rescaled[:n_examples], - time[:n_examples], - ): - # Each slice is (pred_steps, num_grid_nodes, d_f) - self.plotted_examples += 1 # Increment already here - - da_prediction = self._create_dataarray_from_tensor( - tensor=pred_slice, - time=time_slice, - split=split, - category="state", - ).unstack("grid_index") - da_target = self._create_dataarray_from_tensor( - tensor=target_slice, - time=time_slice, - split=split, - category="state", - ).unstack("grid_index") - - var_vmin = ( - torch.minimum( - pred_slice.flatten(0, 1).min(dim=0)[0], - target_slice.flatten(0, 1).min(dim=0)[0], - ) - .cpu() - .numpy() - ) # (d_f,) - var_vmax = ( - torch.maximum( - pred_slice.flatten(0, 1).max(dim=0)[0], - target_slice.flatten(0, 1).max(dim=0)[0], - ) - .cpu() - .numpy() - ) # (d_f,) - var_vranges = list(zip(var_vmin, var_vmax)) - - # Iterate over prediction horizon time steps - for t_i, _ in enumerate(zip(pred_slice, target_slice), start=1): - # Create one figure per variable at this time step - var_figs = [ - vis.plot_prediction( - datastore=self._datastore, - title=f"{var_name} ({var_unit}), " - f"t={t_i} ({(self.time_step_int * t_i)}" - f"{self.time_step_unit})", - vrange=var_vrange, - da_prediction=da_prediction.isel( - state_feature=var_i, time=t_i - 1 - ).squeeze(), - da_target=da_target.isel( - state_feature=var_i, time=t_i - 1 - ).squeeze(), - ) - for var_i, (var_name, var_unit, var_vrange) in enumerate( - zip( - self._datastore.get_vars_names("state"), - self._datastore.get_vars_units("state"), - var_vranges, - ) - ) - ] - - example_i = self.plotted_examples - - for var_name, fig in zip( - self._datastore.get_vars_names("state"), var_figs - ): - - # We need treat logging images differently for different - # loggers. WANDB can log multiple images to the same key, - # while other loggers, as MLFlow, need unique keys for - # each image. 
- if isinstance(self.logger, pl.loggers.WandbLogger): - key = f"{var_name}_example_{example_i}" - else: - key = f"{var_name}_example" - - if hasattr(self.logger, "log_image"): - self.logger.log_image(key=key, images=[fig], step=t_i) - else: - warnings.warn( - f"{self.logger} does not support image logging." - ) - - plt.close( - "all" - ) # Close all figs for this time step, saves memory - - # Save pred and target as .pt files - torch.save( - pred_slice.cpu(), - os.path.join( - self.logger.save_dir, - f"example_pred_{self.plotted_examples}.pt", - ), - ) - torch.save( - target_slice.cpu(), - os.path.join( - self.logger.save_dir, - f"example_target_{self.plotted_examples}.pt", - ), - ) - - def create_metric_log_dict(self, metric_tensor, prefix, metric_name): - """ - Put together a dict with everything to log for one metric. Also saves - plots as pdf and csv if using test prefix. - - metric_tensor: (pred_steps, d_f), metric values per time and variable - prefix: string, prefix to use for logging metric_name: string, name of - the metric - - Return: log_dict: dict with everything to log for given metric - """ - log_dict = {} - metric_fig = vis.plot_error_map( - errors=metric_tensor, - datastore=self._datastore, - ) - full_log_name = f"{prefix}_{metric_name}" - log_dict[full_log_name] = metric_fig - - if prefix == "test": - # Save pdf - metric_fig.savefig( - os.path.join(self.logger.save_dir, f"{full_log_name}.pdf") - ) - # Save errors also as csv - np.savetxt( - os.path.join(self.logger.save_dir, f"{full_log_name}.csv"), - metric_tensor.cpu().numpy(), - delimiter=",", - ) - - # Check if metrics are watched, log exact values for specific vars - var_names = self._datastore.get_vars_names(category="state") - if full_log_name in self.args.metrics_watch: - for var_i, timesteps in self.args.var_leads_metrics_watch.items(): - var_name = var_names[var_i] - for step in timesteps: - key = f"{full_log_name}_{var_name}_step_{step}" - log_dict[key] = metric_tensor[step - 1, var_i] - - return log_dict - - def aggregate_and_plot_metrics(self, metrics_dict, prefix): - """ - Aggregate and create error map plots for all metrics in metrics_dict - - metrics_dict: dictionary with metric_names and list of tensors - with step-evals. - prefix: string, prefix to use for logging - """ - log_dict = {} - for metric_name, metric_val_list in metrics_dict.items(): - metric_tensor = self.all_gather_cat( - torch.cat(metric_val_list, dim=0) - ) # (N_eval, pred_steps, d_f) - - if self.trainer.is_global_zero: - metric_tensor_averaged = torch.mean(metric_tensor, dim=0) - # (pred_steps, d_f) - - # Take square root after all averaging to change MSE to RMSE - if "mse" in metric_name: - metric_tensor_averaged = torch.sqrt(metric_tensor_averaged) - metric_name = metric_name.replace("mse", "rmse") - - # NOTE: we here assume rescaling for all metrics is linear - metric_rescaled = metric_tensor_averaged * self.state_std - # (pred_steps, d_f) - log_dict.update( - self.create_metric_log_dict( - metric_rescaled, prefix, metric_name - ) - ) - - # Ensure that log_dict has structure for - # logging as dict(str, plt.Figure) - assert all( - isinstance(key, str) and isinstance(value, plt.Figure) - for key, value in log_dict.items() - ) - - if self.trainer.is_global_zero and not self.trainer.sanity_checking: - - current_epoch = self.trainer.current_epoch - - for key, figure in log_dict.items(): - # For other loggers than wandb, add epoch to key. 
- # Wandb can log multiple images to the same key, while other - # loggers, such as MLFlow need unique keys for each image. - if not isinstance(self.logger, pl.loggers.WandbLogger): - key = f"{key}-{current_epoch}" - - if hasattr(self.logger, "log_image"): - self.logger.log_image(key=key, images=[figure]) - - plt.close("all") # Close all figs - - def on_test_epoch_end(self): - """ - Compute test metrics and make plots at the end of test epoch. Will - gather stored tensors and perform plotting and logging on rank 0. - """ - # Create error maps for all test metrics - self.aggregate_and_plot_metrics(self.test_metrics, prefix="test") - - # Plot spatial loss maps - spatial_loss_tensor = self.all_gather_cat( - torch.cat(self.spatial_loss_maps, dim=0) - ) # (N_test, N_log, num_grid_nodes) - if self.trainer.is_global_zero: - mean_spatial_loss = torch.mean( - spatial_loss_tensor, dim=0 - ) # (N_log, num_grid_nodes) - - loss_map_figs = [ - vis.plot_spatial_error( - error=loss_map, - datastore=self._datastore, - title=f"Test loss, t={t_i} " - f"({(self.time_step_int * t_i)} {self.time_step_int_unit})", - ) - for t_i, loss_map in zip( - self.args.val_steps_to_log, mean_spatial_loss - ) - ] - - # log all to same key, sequentially - for i, fig in enumerate(loss_map_figs): - key = "test_loss" - if not isinstance(self.logger, pl.loggers.WandbLogger): - key = f"{key}_{i}" - if hasattr(self.logger, "log_image"): - self.logger.log_image(key=key, images=[fig]) - - # also make without title and save as pdf - pdf_loss_map_figs = [ - vis.plot_spatial_error( - error=loss_map, datastore=self._datastore - ) - for loss_map in mean_spatial_loss - ] - pdf_loss_maps_dir = os.path.join( - self.logger.save_dir, "spatial_loss_maps" - ) - os.makedirs(pdf_loss_maps_dir, exist_ok=True) - for t_i, fig in zip(self.args.val_steps_to_log, pdf_loss_map_figs): - fig.savefig(os.path.join(pdf_loss_maps_dir, f"loss_t{t_i}.pdf")) - # save mean spatial loss as .pt file also - torch.save( - mean_spatial_loss.cpu(), - os.path.join(self.logger.save_dir, "mean_spatial_loss.pt"), - ) - - self.spatial_loss_maps.clear() - - def on_load_checkpoint(self, checkpoint): - """ - Perform any changes to state dict before loading checkpoint - """ - loaded_state_dict = checkpoint["state_dict"] - - # Fix for loading older models after IneractionNet refactoring, where - # the grid MLP was moved outside the encoder InteractionNet class - if "g2m_gnn.grid_mlp.0.weight" in loaded_state_dict: - replace_keys = list( - filter( - lambda key: key.startswith("g2m_gnn.grid_mlp"), - loaded_state_dict.keys(), - ) - ) - for old_key in replace_keys: - new_key = old_key.replace( - "g2m_gnn.grid_mlp", "encoding_grid_mlp" - ) - loaded_state_dict[new_key] = loaded_state_dict[old_key] - del loaded_state_dict[old_key] - if not self.restore_opt: - opt = self.configure_optimizers() - checkpoint["optimizer_states"] = [opt.state_dict()] diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py index 35b1ab126..7f4099b3a 100644 --- a/neural_lam/models/base_graph_model.py +++ b/neural_lam/models/base_graph_model.py @@ -5,23 +5,66 @@ from .. 
import utils from ..config import NeuralLAMConfig from ..datastore import BaseDatastore -from ..interaction_net import InteractionNet -from .ar_model import ARModel +from ..interaction_net import InteractionNet, PropagationNet +from .step_predictor import StepPredictor -class BaseGraphModel(ARModel): +class BaseGraphModel(StepPredictor): """ Base (abstract) class for graph-based models building on the encode-process-decode idea. """ - def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore): - super().__init__(args, config=config, datastore=datastore) + def __init__( + self, + config: NeuralLAMConfig, + datastore: BaseDatastore, + graph_name: str = "multiscale", + hidden_dim: int = 64, + hidden_layers: int = 1, + processor_layers: int = 4, + mesh_aggr: str = "sum", + num_past_forcing_steps: int = 1, + num_future_forcing_steps: int = 1, + output_std: bool = False, + vertical_propnets: bool = False, + ): + super().__init__( + config=config, + datastore=datastore, + output_std=output_std, + ) + self.vertical_propnets = vertical_propnets + + # Retrieve difference statistics for rescaling in forward pass + da_state_stats = datastore.get_standardization_dataarray("state") + self.register_buffer( + "diff_mean", + torch.tensor( + da_state_stats.state_diff_mean_standardized.values, + dtype=torch.float32, + ), + persistent=False, + ) + self.register_buffer( + "diff_std", + torch.tensor( + da_state_stats.state_diff_std_standardized.values, + dtype=torch.float32, + ), + persistent=False, + ) + + # Store architecture hyperparameters for subclass use + self.hidden_dim = hidden_dim + self.hidden_layers = hidden_layers + self.processor_layers = processor_layers + self.mesh_aggr = mesh_aggr # Load graph with static features # NOTE: (IMPORTANT!) mesh nodes MUST have the first # num_mesh_nodes indices, - graph_dir_path = datastore.root_path / "graph" / args.graph + graph_dir_path = datastore.root_path / "graph" / graph_name self.hierarchical, graph_ldict = utils.load_graph( graph_dir_path=graph_dir_path ) @@ -39,232 +82,61 @@ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore): f"nodes ({self.num_grid_nodes} grid, {self.num_mesh_nodes} mesh)" ) - # grid_dim from data + static + # Compute grid_input_dim: total input dimensionality on the grid + num_state_vars = datastore.get_num_data_vars(category="state") + num_forcing_vars = datastore.get_num_data_vars(category="forcing") + grid_static_dim = self.grid_static_features.shape[1] + self.grid_input_dim = ( + 2 * num_state_vars + + grid_static_dim + + num_forcing_vars + * (num_past_forcing_steps + num_future_forcing_steps + 1) + ) + self.g2m_edges, g2m_dim = self.g2m_features.shape self.m2g_edges, m2g_dim = self.m2g_features.shape # Define sub-models # Feature embedders for grid - self.mlp_blueprint_end = [args.hidden_dim] * (args.hidden_layers + 1) + self.mlp_blueprint_end = [hidden_dim] * (hidden_layers + 1) self.grid_embedder = utils.make_mlp( - [self.grid_dim] + self.mlp_blueprint_end + [self.grid_input_dim] + self.mlp_blueprint_end ) self.g2m_embedder = utils.make_mlp([g2m_dim] + self.mlp_blueprint_end) self.m2g_embedder = utils.make_mlp([m2g_dim] + self.mlp_blueprint_end) # GNNs + gnn_class = ( + PropagationNet if vertical_propnets else InteractionNet + ) # encoder - self.g2m_gnn = InteractionNet( + self.g2m_gnn = gnn_class( self.g2m_edge_index, - args.hidden_dim, - hidden_layers=args.hidden_layers, + hidden_dim, + hidden_layers=hidden_layers, update_edges=False, ) self.encoding_grid_mlp = utils.make_mlp( - 
[args.hidden_dim] + self.mlp_blueprint_end + [hidden_dim] + self.mlp_blueprint_end ) # decoder - self.m2g_gnn = InteractionNet( + self.m2g_gnn = gnn_class( self.m2g_edge_index, - args.hidden_dim, - hidden_layers=args.hidden_layers, + hidden_dim, + hidden_layers=hidden_layers, update_edges=False, ) # Output mapping (hidden_dim -> output_dim) self.output_map = utils.make_mlp( - [args.hidden_dim] * (args.hidden_layers + 1) - + [self.grid_output_dim], + [hidden_dim] * (hidden_layers + 1) + [self.grid_output_dim], layer_norm=False, ) # No layer norm on this one # Compute indices and define clamping functions self.prepare_clamping_params(config, datastore) - def prepare_clamping_params( - self, config: NeuralLAMConfig, datastore: BaseDatastore - ): - """ - Prepare parameters for clamping predicted values to valid range - """ - - # Read configs - state_feature_names = datastore.get_vars_names(category="state") - lower_lims = config.training.output_clamping.lower - upper_lims = config.training.output_clamping.upper - - # Check that limits in config are for valid features - unknown_features_lower = set(lower_lims.keys()) - set( - state_feature_names - ) - unknown_features_upper = set(upper_lims.keys()) - set( - state_feature_names - ) - if unknown_features_lower or unknown_features_upper: - raise ValueError( - "State feature limits were provided for unknown features: " - f"{unknown_features_lower.union(unknown_features_upper)}" - ) - - # Constant parameters for clamping - sigmoid_sharpness = 1 - softplus_sharpness = 1 - sigmoid_center = 0 - softplus_center = 0 - - normalize_clamping_lim = ( - lambda x, feature_idx: (x - self.state_mean[feature_idx]) - / self.state_std[feature_idx] - ) - - # Check which clamping functions to use for each feature - sigmoid_lower_upper_idx = [] - sigmoid_lower_lims = [] - sigmoid_upper_lims = [] - - softplus_lower_idx = [] - softplus_lower_lims = [] - - softplus_upper_idx = [] - softplus_upper_lims = [] - - for feature_idx, feature in enumerate(state_feature_names): - if feature in lower_lims and feature in upper_lims: - assert ( - lower_lims[feature] < upper_lims[feature] - ), f'Invalid clamping limits for feature "{feature}",\ - lower: {lower_lims[feature]}, larger than\ - upper: {upper_lims[feature]}' - sigmoid_lower_upper_idx.append(feature_idx) - sigmoid_lower_lims.append( - normalize_clamping_lim(lower_lims[feature], feature_idx) - ) - sigmoid_upper_lims.append( - normalize_clamping_lim(upper_lims[feature], feature_idx) - ) - elif feature in lower_lims and feature not in upper_lims: - softplus_lower_idx.append(feature_idx) - softplus_lower_lims.append( - normalize_clamping_lim(lower_lims[feature], feature_idx) - ) - elif feature not in lower_lims and feature in upper_lims: - softplus_upper_idx.append(feature_idx) - softplus_upper_lims.append( - normalize_clamping_lim(upper_lims[feature], feature_idx) - ) - - self.register_buffer( - "sigmoid_lower_lims", torch.tensor(sigmoid_lower_lims) - ) - self.register_buffer( - "sigmoid_upper_lims", torch.tensor(sigmoid_upper_lims) - ) - self.register_buffer( - "softplus_lower_lims", torch.tensor(softplus_lower_lims) - ) - self.register_buffer( - "softplus_upper_lims", torch.tensor(softplus_upper_lims) - ) - - self.register_buffer( - "clamp_lower_upper_idx", torch.tensor(sigmoid_lower_upper_idx) - ) - self.register_buffer( - "clamp_lower_idx", torch.tensor(softplus_lower_idx) - ) - self.register_buffer( - "clamp_upper_idx", torch.tensor(softplus_upper_idx) - ) - - # Define clamping functions - self.clamp_lower_upper = lambda 
x: ( - self.sigmoid_lower_lims - + (self.sigmoid_upper_lims - self.sigmoid_lower_lims) - * torch.sigmoid(sigmoid_sharpness * (x - sigmoid_center)) - ) - self.clamp_lower = lambda x: ( - self.softplus_lower_lims - + torch.nn.functional.softplus( - x - softplus_center, beta=softplus_sharpness - ) - ) - self.clamp_upper = lambda x: ( - self.softplus_upper_lims - - torch.nn.functional.softplus( - softplus_center - x, beta=softplus_sharpness - ) - ) - - self.inverse_clamp_lower_upper = lambda x: ( - sigmoid_center - + utils.inverse_sigmoid( - (x - self.sigmoid_lower_lims) - / (self.sigmoid_upper_lims - self.sigmoid_lower_lims) - ) - / sigmoid_sharpness - ) - self.inverse_clamp_lower = lambda x: ( - utils.inverse_softplus( - x - self.softplus_lower_lims, beta=softplus_sharpness - ) - + softplus_center - ) - self.inverse_clamp_upper = lambda x: ( - -utils.inverse_softplus( - self.softplus_upper_lims - x, beta=softplus_sharpness - ) - + softplus_center - ) - - def get_clamped_new_state(self, state_delta, prev_state): - """ - Clamp prediction to valid range supplied in config - Returns the clamped new state after adding delta to original state - - Instead of the new state being computed as - $X_{t+1} = X_t + \\delta = X_t + model(\\{X_t,X_{t-1},...\\}, forcing)$ - The clamped values will be - $f(f^{-1}(X_t) + model(\\{X_t, X_{t-1},... \\}, forcing))$ - Which means the model will learn to output values in the range of the - inverse clamping function - - state_delta: (B, num_grid_nodes, feature_dim) - prev_state: (B, num_grid_nodes, feature_dim) - """ - - # Assign new state, but overwrite clamped values of each type later - new_state = prev_state + state_delta - - # Sigmoid/logistic clamps between ]a,b[ - if self.clamp_lower_upper_idx.numel() > 0: - idx = self.clamp_lower_upper_idx - - new_state[:, :, idx] = self.clamp_lower_upper( - self.inverse_clamp_lower_upper(prev_state[:, :, idx]) - + state_delta[:, :, idx] - ) - - # Softplus clamps between ]a,infty[ - if self.clamp_lower_idx.numel() > 0: - idx = self.clamp_lower_idx - - new_state[:, :, idx] = self.clamp_lower( - self.inverse_clamp_lower(prev_state[:, :, idx]) - + state_delta[:, :, idx] - ) - - # Softplus clamps between ]-infty,b[ - if self.clamp_upper_idx.numel() > 0: - idx = self.clamp_upper_idx - - new_state[:, :, idx] = self.clamp_upper( - self.inverse_clamp_upper(prev_state[:, :, idx]) - + state_delta[:, :, idx] - ) - - return new_state - def get_num_mesh(self): """ Compute number of mesh nodes from loaded features, @@ -289,7 +161,7 @@ def process_step(self, mesh_rep): """ raise NotImplementedError("process_step not implemented") - def predict_step(self, prev_state, prev_prev_state, forcing): + def forward(self, prev_state, prev_prev_state, forcing): """ Step state one step ahead using prediction model, X_{t-1}, X_t -> X_t+1 prev_state: (B, num_grid_nodes, feature_dim), X_t diff --git a/neural_lam/models/base_hi_graph_model.py b/neural_lam/models/base_hi_graph_model.py index 882dbf4da..bf639e1d6 100644 --- a/neural_lam/models/base_hi_graph_model.py +++ b/neural_lam/models/base_hi_graph_model.py @@ -5,7 +5,7 @@ from .. import utils from ..config import NeuralLAMConfig from ..datastore import BaseDatastore -from ..interaction_net import InteractionNet +from ..interaction_net import InteractionNet, PropagationNet from .base_graph_model import BaseGraphModel @@ -14,8 +14,33 @@ class BaseHiGraphModel(BaseGraphModel): Base class for hierarchical graph models. 
""" - def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore): - super().__init__(args, config=config, datastore=datastore) + def __init__( + self, + config: NeuralLAMConfig, + datastore: BaseDatastore, + graph_name: str = "multiscale", + hidden_dim: int = 64, + hidden_layers: int = 1, + processor_layers: int = 4, + mesh_aggr: str = "sum", + num_past_forcing_steps: int = 1, + num_future_forcing_steps: int = 1, + output_std: bool = False, + vertical_propnets: bool = False, + ): + super().__init__( + config=config, + datastore=datastore, + graph_name=graph_name, + hidden_dim=hidden_dim, + hidden_layers=hidden_layers, + processor_layers=processor_layers, + mesh_aggr=mesh_aggr, + num_past_forcing_steps=num_past_forcing_steps, + num_future_forcing_steps=num_future_forcing_steps, + output_std=output_std, + vertical_propnets=vertical_propnets, + ) # Track number of nodes, edges on each level # Flatten lists for efficient embedding @@ -77,12 +102,15 @@ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore): # Instantiate GNNs # Init GNNs + init_gnn_class = ( + PropagationNet if self.vertical_propnets else InteractionNet + ) self.mesh_init_gnns = nn.ModuleList( [ - InteractionNet( + init_gnn_class( edge_index, - args.hidden_dim, - hidden_layers=args.hidden_layers, + hidden_dim, + hidden_layers=hidden_layers, ) for edge_index in self.mesh_up_edge_index ] @@ -93,8 +121,8 @@ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore): [ InteractionNet( edge_index, - args.hidden_dim, - hidden_layers=args.hidden_layers, + hidden_dim, + hidden_layers=hidden_layers, update_edges=False, ) for edge_index in self.mesh_down_edge_index diff --git a/neural_lam/models/forecaster.py b/neural_lam/models/forecaster.py new file mode 100644 index 000000000..a2f268c3a --- /dev/null +++ b/neural_lam/models/forecaster.py @@ -0,0 +1,36 @@ +# Standard library +from abc import ABC, abstractmethod + +# Third-party +import torch +from torch import nn + + +class Forecaster(nn.Module, ABC): + """ + Generic forecaster capable of mapping from a set of initial states, + forcing and forces and previous states into a full forecast of the + requested length. + """ + + @property + @abstractmethod + def predicts_std(self) -> bool: + """Whether this forecaster outputs a predicted standard deviation.""" + + @abstractmethod + def forward( + self, + init_states: torch.Tensor, + forcing_features: torch.Tensor, + boundary_states: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + init_states: (B, 2, num_grid_nodes, d_f) + forcing_features: (B, pred_steps, num_grid_nodes, d_static_f) + boundary_states: (B, pred_steps, num_grid_nodes, d_f) + Returns: + prediction: (B, pred_steps, num_grid_nodes, d_f) + pred_std: (B, pred_steps, num_grid_nodes, d_f) or (d_f,) + """ + pass diff --git a/neural_lam/models/forecaster_module.py b/neural_lam/models/forecaster_module.py new file mode 100644 index 000000000..1d92b8082 --- /dev/null +++ b/neural_lam/models/forecaster_module.py @@ -0,0 +1,604 @@ +# Standard library +import os +import warnings +from typing import Any, Dict, List, Optional + +# Third-party +import matplotlib.pyplot as plt +import numpy as np +import pytorch_lightning as pl +import torch +import xarray as xr + +# First-party +from neural_lam.utils import get_integer_time + +# Local +from .. 
import metrics, vis +from ..config import NeuralLAMConfig +from ..datastore import BaseDatastore +from ..loss_weighting import get_state_feature_weighting +from ..weather_dataset import WeatherDataset +from .forecaster import Forecaster + + +class ForecasterModule(pl.LightningModule): + """ + Lightning module handling training, validation and testing loops. + Wraps a Forecaster instance which performs the actual prediction. + """ + + # pylint: disable=arguments-differ + + def __init__( + self, + forecaster: Forecaster, + config: NeuralLAMConfig, + datastore: BaseDatastore, + loss: str = "wmse", + lr: float = 1e-3, + restore_opt: bool = False, + n_example_pred: int = 1, + val_steps_to_log: Optional[List[int]] = None, + metrics_watch: Optional[List[str]] = None, + var_leads_metrics_watch: Optional[Dict[int, List[int]]] = None, + ): + super().__init__() + # Resolve mutable defaults + if val_steps_to_log is None: + val_steps_to_log = [ + 1, + ] + if metrics_watch is None: + metrics_watch = [] + if var_leads_metrics_watch is None: + var_leads_metrics_watch = {} + + # Note: datastore is excluded from saved hparams and must be provided + # explicitly when calling load_from_checkpoint(path, + # datastore=datastore) + self.save_hyperparameters(ignore=["datastore", "forecaster"]) + self.datastore = datastore + self.forecaster = forecaster + + # Compute interior_mask_bool directly from datastore + boundary_mask = ( + torch.tensor(datastore.boundary_mask.values, dtype=torch.float32) + .unsqueeze(0) + .unsqueeze(-1) + ) # (1, num_grid_nodes, 1) + interior_mask = 1.0 - boundary_mask + self.register_buffer( + "interior_mask_bool", + interior_mask[0, :, 0].to(torch.bool), + persistent=False, + ) + + # Store per_var_std here if predictor does not output std + if not self.forecaster.predicts_std: + da_state_stats = datastore.get_standardization_dataarray( + category="state" + ) + state_feature_weights = get_state_feature_weighting( + config=config, datastore=datastore + ) + diff_std = torch.tensor( + da_state_stats.state_diff_std_standardized.values, + dtype=torch.float32, + ) + feature_weights_t = torch.tensor( + state_feature_weights, dtype=torch.float32 + ) + self.register_buffer( + "per_var_std", + diff_std / torch.sqrt(feature_weights_t), + persistent=False, + ) + else: + self.per_var_std = None + + # Instantiate loss function + self.loss = metrics.get_metric(loss) + + self.val_metrics: Dict[str, List] = { + "mse": [], + } + self.test_metrics: Dict[str, List] = { + "mse": [], + "mae": [], + } + if self.forecaster.predicts_std: + self.test_metrics["output_std"] = [] # Treat as metric + + # For making restoring of optimizer state optional + self.restore_opt = restore_opt + + # For example plotting + self.n_example_pred = n_example_pred + self.plotted_examples = 0 + + # For storing spatial loss maps during evaluation + self.spatial_loss_maps: List[Any] = [] + + self.time_step_int, self.time_step_unit = get_integer_time( + self.datastore.step_length + ) + + def _create_dataarray_from_tensor( + self, + tensor: torch.Tensor, + time: torch.Tensor, + split: str, + category: str, + ) -> xr.DataArray: + weather_dataset = WeatherDataset(datastore=self.datastore, split=split) + time = np.array(time.cpu(), dtype="datetime64[ns]") + da = weather_dataset.create_dataarray_from_tensor( + tensor=tensor, time=time, category=category + ) + return da + + def configure_optimizers(self): + opt = torch.optim.AdamW( + self.parameters(), lr=self.hparams.lr, betas=(0.9, 0.95) + ) + return opt + + def training_step(self, batch): 
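+        """
+        Train on a single batch: unroll the forecaster over the full target
+        window and compute the configured loss over interior grid nodes.
+        """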
+ (init_states, target_states, forcing_features, _batch_times) = batch + prediction, pred_std = self.forecaster( + init_states, forcing_features, target_states + ) + if pred_std is None: + pred_std = self.per_var_std + + batch_loss = torch.mean( + self.loss( + prediction, + target_states, + pred_std, + mask=self.interior_mask_bool, + ) + ) + + log_dict = {"train_loss": batch_loss} + self.log_dict( + log_dict, + prog_bar=True, + on_step=True, + on_epoch=True, + sync_dist=True, + batch_size=batch[0].shape[0], + ) + return batch_loss + + def all_gather_cat(self, tensor_to_gather): + gathered = self.all_gather(tensor_to_gather) + # all_gather adds dim 0 only on multi-device; on single + # device it returns the same tensor unchanged. + if gathered.dim() > tensor_to_gather.dim(): + gathered = gathered.flatten(0, 1) + return gathered + + # pylint: disable-next=unused-argument + def validation_step(self, batch, batch_idx): + (init_states, target_states, forcing_features, _batch_times) = batch + prediction, pred_std = self.forecaster( + init_states, forcing_features, target_states + ) + if pred_std is None: + pred_std = self.per_var_std + + time_step_loss = torch.mean( + self.loss( + prediction, + target_states, + pred_std, + mask=self.interior_mask_bool, + ), + dim=0, + ) + mean_loss = torch.mean(time_step_loss) + + val_log_dict = { + f"val_loss_unroll{step}": time_step_loss[step - 1] + for step in self.hparams.val_steps_to_log + if step <= len(time_step_loss) + } + val_log_dict["val_mean_loss"] = mean_loss + self.log_dict( + val_log_dict, + on_step=False, + on_epoch=True, + sync_dist=True, + batch_size=batch[0].shape[0], + ) + + entry_mses = metrics.mse( + prediction, + target_states, + pred_std, + mask=self.interior_mask_bool, + sum_vars=False, + ) + self.val_metrics["mse"].append(entry_mses) + + def on_validation_epoch_end(self): + self.aggregate_and_plot_metrics(self.val_metrics, prefix="val") + for metric_list in self.val_metrics.values(): + metric_list.clear() + + # pylint: disable-next=unused-argument + def test_step(self, batch, batch_idx): + (init_states, target_states, forcing_features, _batch_times) = batch + prediction, pred_std = self.forecaster( + init_states, forcing_features, target_states + ) + + if pred_std is not None: + mean_pred_std = torch.mean( + pred_std[..., self.interior_mask_bool, :], dim=-2 + ) + self.test_metrics["output_std"].append(mean_pred_std) + + if pred_std is None: + pred_std = self.per_var_std + + time_step_loss = torch.mean( + self.loss( + prediction, + target_states, + pred_std, + mask=self.interior_mask_bool, + ), + dim=0, + ) + mean_loss = torch.mean(time_step_loss) + + test_log_dict = { + f"test_loss_unroll{step}": time_step_loss[step - 1] + for step in self.hparams.val_steps_to_log + if step <= len(time_step_loss) + } + test_log_dict["test_mean_loss"] = mean_loss + + self.log_dict( + test_log_dict, + on_step=False, + on_epoch=True, + sync_dist=True, + batch_size=batch[0].shape[0], + ) + + for metric_name in ("mse", "mae"): + metric_func = metrics.get_metric(metric_name) + batch_metric_vals = metric_func( + prediction, + target_states, + pred_std, + mask=self.interior_mask_bool, + sum_vars=False, + ) + self.test_metrics[metric_name].append(batch_metric_vals) + + spatial_loss = self.loss( + prediction, target_states, pred_std, average_grid=False + ) + log_spatial_losses = spatial_loss[ + :, + [ + step - 1 + for step in self.hparams.val_steps_to_log + if step <= spatial_loss.shape[1] + ], + ] + self.spatial_loss_maps.append(log_spatial_losses) + + if ( + 
self.trainer.is_global_zero + and self.plotted_examples < self.n_example_pred + ): + n_additional_examples = min( + prediction.shape[0], + self.n_example_pred - self.plotted_examples, + ) + + self.plot_examples( + batch, + n_additional_examples, + prediction=prediction, + split="test", + ) + + def plot_examples(self, batch, n_examples, split, prediction): + + target = batch[1] + time = batch[3] + + da_state_stats = self.datastore.get_standardization_dataarray("state") + state_std = torch.tensor( + da_state_stats.state_std.values, + dtype=torch.float32, + device=prediction.device, + ) + state_mean = torch.tensor( + da_state_stats.state_mean.values, + dtype=torch.float32, + device=prediction.device, + ) + + prediction_rescaled = prediction * state_std + state_mean + target_rescaled = target * state_std + state_mean + + for pred_slice, target_slice, time_slice in zip( + prediction_rescaled[:n_examples], + target_rescaled[:n_examples], + time[:n_examples], + ): + self.plotted_examples += 1 + + da_prediction = self._create_dataarray_from_tensor( + tensor=pred_slice, + time=time_slice, + split=split, + category="state", + ).unstack("grid_index") + da_target = self._create_dataarray_from_tensor( + tensor=target_slice, + time=time_slice, + split=split, + category="state", + ).unstack("grid_index") + + var_vmin = ( + torch.minimum( + pred_slice.flatten(0, 1).min(dim=0)[0], + target_slice.flatten(0, 1).min(dim=0)[0], + ) + .cpu() + .numpy() + ) + var_vmax = ( + torch.maximum( + pred_slice.flatten(0, 1).max(dim=0)[0], + target_slice.flatten(0, 1).max(dim=0)[0], + ) + .cpu() + .numpy() + ) + var_vranges = list(zip(var_vmin, var_vmax)) + + for t_i, _ in enumerate(zip(pred_slice, target_slice), start=1): + var_figs = [ + vis.plot_prediction( + datastore=self.datastore, + title=f"{var_name} ({var_unit}), " + f"t={t_i} ({(self.time_step_int * t_i)}" + f"{self.time_step_unit})", + vrange=var_vrange, + da_prediction=da_prediction.isel( + state_feature=var_i, time=t_i - 1 + ).squeeze(), + da_target=da_target.isel( + state_feature=var_i, time=t_i - 1 + ).squeeze(), + ) + for var_i, (var_name, var_unit, var_vrange) in enumerate( + zip( + self.datastore.get_vars_names("state"), + self.datastore.get_vars_units("state"), + var_vranges, + ) + ) + ] + + example_i = self.plotted_examples + + for var_name, fig in zip( + self.datastore.get_vars_names("state"), var_figs + ): + if isinstance(self.logger, pl.loggers.WandbLogger): + key = f"{var_name}_example_{example_i}" + else: + key = f"{var_name}_example" + + if hasattr(self.logger, "log_image"): + self.logger.log_image(key=key, images=[fig], step=t_i) + else: + warnings.warn( + f"{self.logger} does not support image logging." 
+ ) + + plt.close("all") + + torch.save( + pred_slice.cpu(), + os.path.join( + self.logger.save_dir, + f"example_pred_{self.plotted_examples}.pt", + ), + ) + torch.save( + target_slice.cpu(), + os.path.join( + self.logger.save_dir, + f"example_target_{self.plotted_examples}.pt", + ), + ) + + def create_metric_log_dict(self, metric_tensor, prefix, metric_name): + log_dict = {} + metric_fig = vis.plot_error_map( + errors=metric_tensor, + datastore=self.datastore, + ) + full_log_name = f"{prefix}_{metric_name}" + log_dict[full_log_name] = metric_fig + + if prefix == "test": + metric_fig.savefig( + os.path.join(self.logger.save_dir, f"{full_log_name}.pdf") + ) + np.savetxt( + os.path.join(self.logger.save_dir, f"{full_log_name}.csv"), + metric_tensor.cpu().numpy(), + delimiter=",", + ) + + var_names = self.datastore.get_vars_names(category="state") + if full_log_name in self.hparams.metrics_watch: + for ( + var_i, + timesteps, + ) in self.hparams.var_leads_metrics_watch.items(): + var_name = var_names[var_i] + for step in timesteps: + key = f"{full_log_name}_{var_name}_step_{step}" + log_dict[key] = metric_tensor[step - 1, var_i] + + return log_dict + + def aggregate_and_plot_metrics(self, metrics_dict, prefix): + log_dict = {} + for metric_name, metric_val_list in metrics_dict.items(): + metric_tensor = self.all_gather_cat( + torch.cat(metric_val_list, dim=0) + ) + + if self.trainer.is_global_zero: + metric_tensor_averaged = torch.mean(metric_tensor, dim=0) + + if "mse" in metric_name: + metric_tensor_averaged = torch.sqrt(metric_tensor_averaged) + metric_name = metric_name.replace("mse", "rmse") + + da_state_stats = self.datastore.get_standardization_dataarray( + "state" + ) + state_std = torch.tensor( + da_state_stats.state_std.values, + dtype=torch.float32, + device=metric_tensor_averaged.device, + ) + metric_rescaled = metric_tensor_averaged * state_std + + log_dict.update( + self.create_metric_log_dict( + metric_rescaled, prefix, metric_name + ) + ) + + figure_dict = { + k: v for k, v in log_dict.items() if isinstance(v, plt.Figure) + } + scalar_dict = { + k: v for k, v in log_dict.items() if not isinstance(v, plt.Figure) + } + + if self.trainer.is_global_zero and not self.trainer.sanity_checking: + # Log scalars via Lightning's built-in mechanism + if scalar_dict: + self.log_dict(scalar_dict, sync_dist=True) + + current_epoch = self.trainer.current_epoch + + for key, figure in figure_dict.items(): + if not isinstance(self.logger, pl.loggers.WandbLogger): + key = f"{key}-{current_epoch}" + + if hasattr(self.logger, "log_image"): + self.logger.log_image(key=key, images=[figure]) + + plt.close("all") + + def on_test_epoch_end(self): + self.aggregate_and_plot_metrics(self.test_metrics, prefix="test") + + spatial_loss_tensor = self.all_gather_cat( + torch.cat(self.spatial_loss_maps, dim=0) + ) + if self.trainer.is_global_zero: + mean_spatial_loss = torch.mean(spatial_loss_tensor, dim=0) + + loss_map_figs = [ + vis.plot_spatial_error( + error=loss_map, + datastore=self.datastore, + title=f"Test loss, t={t_i} " + f"({(self.time_step_int * t_i)} {self.time_step_unit})", + ) + for t_i, loss_map in zip( + self.hparams.val_steps_to_log, mean_spatial_loss + ) + ] + + for i, fig in enumerate(loss_map_figs): + key = "test_loss" + if not isinstance(self.logger, pl.loggers.WandbLogger): + key = f"{key}_{i}" + if hasattr(self.logger, "log_image"): + self.logger.log_image(key=key, images=[fig]) + + pdf_loss_map_figs = [ + vis.plot_spatial_error( + error=loss_map, datastore=self.datastore + ) + for 
loss_map in mean_spatial_loss + ] + pdf_loss_maps_dir = os.path.join( + self.logger.save_dir, "spatial_loss_maps" + ) + os.makedirs(pdf_loss_maps_dir, exist_ok=True) + for t_i, fig in zip( + self.hparams.val_steps_to_log, pdf_loss_map_figs + ): + fig.savefig(os.path.join(pdf_loss_maps_dir, f"loss_t{t_i}.pdf")) + + torch.save( + mean_spatial_loss.cpu(), + os.path.join(self.logger.save_dir, "mean_spatial_loss.pt"), + ) + + self.spatial_loss_maps.clear() + + def on_load_checkpoint(self, checkpoint): + loaded_state_dict = checkpoint["state_dict"] + + # 1. Broad namespace remap: for pre-refactor checkpoints + # The old ARModel was a flat LightningModule. Everything that belonged + # to the predictor needs to be moved to 'forecaster.predictor.' + old_keys = list(loaded_state_dict.keys()) + for key in old_keys: + if not key.startswith("forecaster.") and key not in ( + "interior_mask_bool", + "per_var_std", + ): + new_key = f"forecaster.predictor.{key}" + loaded_state_dict[new_key] = loaded_state_dict.pop(key) + + # 2. Specific rename from g2m_gnn.grid_mlp -> encoding_grid_mlp + # Will be under forecaster.predictor due to the remap above, or + # already there if from a recent checkpoint before this rename. + if ( + "forecaster.predictor.g2m_gnn.grid_mlp.0.weight" + in loaded_state_dict + ): + replace_keys = list( + filter( + lambda key: key.startswith( + "forecaster.predictor.g2m_gnn.grid_mlp" + ), + loaded_state_dict.keys(), + ) + ) + for old_key in replace_keys: + new_key = old_key.replace( + "forecaster.predictor.g2m_gnn.grid_mlp", + "forecaster.predictor.encoding_grid_mlp", + ) + loaded_state_dict[new_key] = loaded_state_dict[old_key] + del loaded_state_dict[old_key] + + if not self.restore_opt: + opt = self.configure_optimizers() + checkpoint["optimizer_states"] = [opt.state_dict()] diff --git a/neural_lam/models/graph_lam.py b/neural_lam/models/graph_lam.py index 0a5b6b574..06fcc7a2a 100644 --- a/neural_lam/models/graph_lam.py +++ b/neural_lam/models/graph_lam.py @@ -17,8 +17,33 @@ class GraphLAM(BaseGraphModel): Oskarsson et al. (2023). 
""" - def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore): - super().__init__(args, config=config, datastore=datastore) + def __init__( + self, + config: NeuralLAMConfig, + datastore: BaseDatastore, + graph_name: str = "multiscale", + hidden_dim: int = 64, + hidden_layers: int = 1, + processor_layers: int = 4, + mesh_aggr: str = "sum", + num_past_forcing_steps: int = 1, + num_future_forcing_steps: int = 1, + output_std: bool = False, + vertical_propnets: bool = False, + ): + super().__init__( + config=config, + datastore=datastore, + graph_name=graph_name, + hidden_dim=hidden_dim, + hidden_layers=hidden_layers, + processor_layers=processor_layers, + mesh_aggr=mesh_aggr, + num_past_forcing_steps=num_past_forcing_steps, + num_future_forcing_steps=num_future_forcing_steps, + output_std=output_std, + vertical_propnets=vertical_propnets, + ) assert ( not self.hierarchical @@ -42,11 +67,11 @@ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore): processor_nets = [ InteractionNet( self.m2m_edge_index, - args.hidden_dim, - hidden_layers=args.hidden_layers, - aggr=args.mesh_aggr, + hidden_dim, + hidden_layers=hidden_layers, + aggr=mesh_aggr, ) - for _ in range(args.processor_layers) + for _ in range(processor_layers) ] self.processor = pyg.nn.Sequential( "mesh_rep, edge_rep", diff --git a/neural_lam/models/hi_lam.py b/neural_lam/models/hi_lam.py index c340c95da..66a7b5d1b 100644 --- a/neural_lam/models/hi_lam.py +++ b/neural_lam/models/hi_lam.py @@ -4,7 +4,7 @@ # Local from ..config import NeuralLAMConfig from ..datastore import BaseDatastore -from ..interaction_net import InteractionNet +from ..interaction_net import InteractionNet, PropagationNet from .base_hi_graph_model import BaseHiGraphModel @@ -15,26 +15,51 @@ class HiLAM(BaseHiGraphModel): The Hi-LAM model from Oskarsson et al. 
(2023) """ - def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore): - super().__init__(args, config=config, datastore=datastore) + def __init__( + self, + config: NeuralLAMConfig, + datastore: BaseDatastore, + graph_name: str = "multiscale", + hidden_dim: int = 64, + hidden_layers: int = 1, + processor_layers: int = 4, + mesh_aggr: str = "sum", + num_past_forcing_steps: int = 1, + num_future_forcing_steps: int = 1, + output_std: bool = False, + vertical_propnets: bool = False, + ): + super().__init__( + config=config, + datastore=datastore, + graph_name=graph_name, + hidden_dim=hidden_dim, + hidden_layers=hidden_layers, + processor_layers=processor_layers, + mesh_aggr=mesh_aggr, + num_past_forcing_steps=num_past_forcing_steps, + num_future_forcing_steps=num_future_forcing_steps, + output_std=output_std, + vertical_propnets=vertical_propnets, + ) # Make down GNNs, both for down edges and same level self.mesh_down_gnns = nn.ModuleList( - [self.make_down_gnns(args) for _ in range(args.processor_layers)] + [self.make_down_gnns() for _ in range(processor_layers)] ) # Nested lists (proc_steps, num_levels-1) self.mesh_down_same_gnns = nn.ModuleList( - [self.make_same_gnns(args) for _ in range(args.processor_layers)] + [self.make_same_gnns() for _ in range(processor_layers)] ) # Nested lists (proc_steps, num_levels) # Make up GNNs, both for up edges and same level self.mesh_up_gnns = nn.ModuleList( - [self.make_up_gnns(args) for _ in range(args.processor_layers)] + [self.make_up_gnns() for _ in range(processor_layers)] ) # Nested lists (proc_steps, num_levels-1) self.mesh_up_same_gnns = nn.ModuleList( - [self.make_same_gnns(args) for _ in range(args.processor_layers)] + [self.make_same_gnns() for _ in range(processor_layers)] ) # Nested lists (proc_steps, num_levels) - def make_same_gnns(self, args): + def make_same_gnns(self): """ Make intra-level GNNs. """ @@ -42,29 +67,34 @@ def make_same_gnns(self, args): [ InteractionNet( edge_index, - args.hidden_dim, - hidden_layers=args.hidden_layers, + self.hidden_dim, + hidden_layers=self.hidden_layers, ) for edge_index in self.m2m_edge_index ] ) - def make_up_gnns(self, args): + def make_up_gnns(self): """ Make GNNs for processing steps up through the hierarchy. """ + gnn_class = ( + PropagationNet + if self.vertical_propnets + else InteractionNet + ) return nn.ModuleList( [ - InteractionNet( + gnn_class( edge_index, - args.hidden_dim, - hidden_layers=args.hidden_layers, + self.hidden_dim, + hidden_layers=self.hidden_layers, ) for edge_index in self.mesh_up_edge_index ] ) - def make_down_gnns(self, args): + def make_down_gnns(self): """ Make GNNs for processing steps down through the hierarchy. """ @@ -72,8 +102,8 @@ def make_down_gnns(self, args): [ InteractionNet( edge_index, - args.hidden_dim, - hidden_layers=args.hidden_layers, + self.hidden_dim, + hidden_layers=self.hidden_layers, ) for edge_index in self.mesh_down_edge_index ] diff --git a/neural_lam/models/hi_lam_parallel.py b/neural_lam/models/hi_lam_parallel.py index a0a84d293..d4cbd1ecf 100644 --- a/neural_lam/models/hi_lam_parallel.py +++ b/neural_lam/models/hi_lam_parallel.py @@ -18,8 +18,33 @@ class HiLAMParallel(BaseHiGraphModel): of Hi-LAM. 
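The make_up_gnns change above is where the vertical_propnets flag takes effect: upward message passing can use PropagationNet, while same-level and downward passes stay on InteractionNet. The practical difference, which the new tests at the end of this diff verify directly, is which tensor the receiver residual targets. A toy, model-free sketch of the two receiver-update rules, with plain tensors standing in for the aggregation and the MLP output:

import torch

torch.manual_seed(0)
d_h = 4
rec_old = torch.randn(d_h)    # receiver representation before the update
agg_msgs = torch.randn(d_h)   # mean-aggregated incoming messages
mlp_out = torch.randn(d_h)    # stand-in for aggr_mlp(cat(rec_old, agg_msgs))

rec_interaction = rec_old + mlp_out    # InteractionNet: residual to the old receiver state
rec_propagation = agg_msgs + mlp_out   # PropagationNet: residual to the aggregated messages

# Anchoring the receiver output to the incoming aggregation keeps information
# flowing from senders to receivers; the two rules generally disagree.
assert not torch.allclose(rec_interaction, rec_propagation)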
""" - def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore): - super().__init__(args, config=config, datastore=datastore) + def __init__( + self, + config: NeuralLAMConfig, + datastore: BaseDatastore, + graph_name: str = "multiscale", + hidden_dim: int = 64, + hidden_layers: int = 1, + processor_layers: int = 4, + mesh_aggr: str = "sum", + num_past_forcing_steps: int = 1, + num_future_forcing_steps: int = 1, + output_std: bool = False, + vertical_propnets: bool = False, + ): + super().__init__( + config=config, + datastore=datastore, + graph_name=graph_name, + hidden_dim=hidden_dim, + hidden_layers=hidden_layers, + processor_layers=processor_layers, + mesh_aggr=mesh_aggr, + num_past_forcing_steps=num_past_forcing_steps, + num_future_forcing_steps=num_future_forcing_steps, + output_std=output_std, + vertical_propnets=vertical_propnets, + ) # Processor GNNs # Create the complete edge_index combining all edges for processing @@ -31,18 +56,18 @@ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore): total_edge_index = torch.cat(total_edge_index_list, dim=1) self.edge_split_sections = [ei.shape[1] for ei in total_edge_index_list] - if args.processor_layers == 0: + if processor_layers == 0: self.processor = lambda x, edge_attr: (x, edge_attr) else: processor_nets = [ InteractionNet( total_edge_index, - args.hidden_dim, - hidden_layers=args.hidden_layers, + hidden_dim, + hidden_layers=hidden_layers, edge_chunk_sizes=self.edge_split_sections, aggr_chunk_sizes=self.level_mesh_sizes, ) - for _ in range(args.processor_layers) + for _ in range(processor_layers) ] self.processor = pyg.nn.Sequential( "mesh_rep, edge_rep", diff --git a/neural_lam/models/step_predictor.py b/neural_lam/models/step_predictor.py new file mode 100644 index 000000000..4ba7a8393 --- /dev/null +++ b/neural_lam/models/step_predictor.py @@ -0,0 +1,286 @@ +# Standard library +from abc import ABC, abstractmethod +from typing import Optional + +# Third-party +import torch +from torch import nn + +# Local +from .. import utils +from ..config import NeuralLAMConfig +from ..datastore import BaseDatastore + + +class StepPredictor(nn.Module, ABC): + """ + Abstract base class for step predictors mapping from the two previous + time steps plus forcing into a prediction of the next state. 
+ """ + + def __init__( + self, + config: NeuralLAMConfig, + datastore: BaseDatastore, + output_std: bool = False, + ): + super().__init__() + + num_state_vars = datastore.get_num_data_vars(category="state") + + # Load static features standardized + da_static_features = datastore.get_dataarray( + category="static", split=None, standardize=True + ) + if da_static_features is None: + # Create empty static features of the correct shape + num_grid_nodes = datastore.num_grid_points + grid_static_features = torch.empty( + (num_grid_nodes, 0), dtype=torch.float32 + ) + else: + grid_static_features = torch.tensor( + da_static_features.values, dtype=torch.float32 + ) + + self.register_buffer( + "grid_static_features", + grid_static_features, + persistent=False, + ) + + da_state_stats = datastore.get_standardization_dataarray( + category="state" + ) + + self.register_buffer( + "state_mean", + torch.tensor(da_state_stats.state_mean.values, dtype=torch.float32), + persistent=False, + ) + self.register_buffer( + "state_std", + torch.tensor(da_state_stats.state_std.values, dtype=torch.float32), + persistent=False, + ) + + self.output_std = bool(output_std) + if self.output_std: + self.grid_output_dim = 2 * num_state_vars + else: + self.grid_output_dim = num_state_vars + + (self.num_grid_nodes, _) = self.grid_static_features.shape + + @property + def predicts_std(self) -> bool: + """Whether this predictor outputs a predicted standard deviation.""" + return self.output_std + + def expand_to_batch(self, x: torch.Tensor, batch_size: int) -> torch.Tensor: + """ + Expand tensor with shape (N, d) to (B, N, d) + """ + return x.unsqueeze(0).expand(batch_size, -1, -1) + + @abstractmethod + def forward( + self, + prev_state: torch.Tensor, + prev_prev_state: torch.Tensor, + forcing: torch.Tensor, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Step state one step ahead using prediction model, X_{t-1}, X_t -> X_t+1 + prev_state: (B, num_grid_nodes, feature_dim), X_t + prev_prev_state: (B, num_grid_nodes, feature_dim), X_{t-1} + forcing: (B, num_grid_nodes, forcing_dim) + + Returns: + pred_state: (B, num_grid_nodes, d_f) + pred_std: (B, num_grid_nodes, d_f) or None + """ + pass + + def prepare_clamping_params( + self, config: NeuralLAMConfig, datastore: BaseDatastore + ): + """ + Prepare parameters for clamping predicted values to valid range + """ + + # Read configs + state_feature_names = datastore.get_vars_names(category="state") + lower_lims = config.training.output_clamping.lower + upper_lims = config.training.output_clamping.upper + + # Check that limits in config are for valid features + unknown_features_lower = set(lower_lims.keys()) - set( + state_feature_names + ) + unknown_features_upper = set(upper_lims.keys()) - set( + state_feature_names + ) + if unknown_features_lower or unknown_features_upper: + raise ValueError( + "State feature limits were provided for unknown features: " + f"{unknown_features_lower.union(unknown_features_upper)}" + ) + + # Constant parameters for clamping + sigmoid_sharpness = 1 + softplus_sharpness = 1 + sigmoid_center = 0 + softplus_center = 0 + + normalize_clamping_lim = ( + lambda x, feature_idx: (x - self.state_mean[feature_idx]) + / self.state_std[feature_idx] + ) + + # Check which clamping functions to use for each feature + sigmoid_lower_upper_idx = [] + sigmoid_lower_lims = [] + sigmoid_upper_lims = [] + + softplus_lower_idx = [] + softplus_lower_lims = [] + + softplus_upper_idx = [] + softplus_upper_lims = [] + + for feature_idx, feature in 
enumerate(state_feature_names): + if feature in lower_lims and feature in upper_lims: + assert ( + lower_lims[feature] < upper_lims[feature] + ), f'Invalid clamping limits for feature "{feature}",\ + lower: {lower_lims[feature]}, larger than\ + upper: {upper_lims[feature]}' + sigmoid_lower_upper_idx.append(feature_idx) + sigmoid_lower_lims.append( + normalize_clamping_lim(lower_lims[feature], feature_idx) + ) + sigmoid_upper_lims.append( + normalize_clamping_lim(upper_lims[feature], feature_idx) + ) + elif feature in lower_lims and feature not in upper_lims: + softplus_lower_idx.append(feature_idx) + softplus_lower_lims.append( + normalize_clamping_lim(lower_lims[feature], feature_idx) + ) + elif feature not in lower_lims and feature in upper_lims: + softplus_upper_idx.append(feature_idx) + softplus_upper_lims.append( + normalize_clamping_lim(upper_lims[feature], feature_idx) + ) + + self.register_buffer( + "sigmoid_lower_lims", torch.tensor(sigmoid_lower_lims) + ) + self.register_buffer( + "sigmoid_upper_lims", torch.tensor(sigmoid_upper_lims) + ) + self.register_buffer( + "softplus_lower_lims", torch.tensor(softplus_lower_lims) + ) + self.register_buffer( + "softplus_upper_lims", torch.tensor(softplus_upper_lims) + ) + + self.register_buffer( + "clamp_lower_upper_idx", torch.tensor(sigmoid_lower_upper_idx) + ) + self.register_buffer( + "clamp_lower_idx", torch.tensor(softplus_lower_idx) + ) + self.register_buffer( + "clamp_upper_idx", torch.tensor(softplus_upper_idx) + ) + + # Define clamping functions + self.clamp_lower_upper = lambda x: ( + self.sigmoid_lower_lims + + (self.sigmoid_upper_lims - self.sigmoid_lower_lims) + * torch.sigmoid(sigmoid_sharpness * (x - sigmoid_center)) + ) + self.clamp_lower = lambda x: ( + self.softplus_lower_lims + + torch.nn.functional.softplus( + x - softplus_center, beta=softplus_sharpness + ) + ) + self.clamp_upper = lambda x: ( + self.softplus_upper_lims + - torch.nn.functional.softplus( + softplus_center - x, beta=softplus_sharpness + ) + ) + + self.inverse_clamp_lower_upper = lambda x: ( + sigmoid_center + + utils.inverse_sigmoid( + (x - self.sigmoid_lower_lims) + / (self.sigmoid_upper_lims - self.sigmoid_lower_lims) + ) + / sigmoid_sharpness + ) + self.inverse_clamp_lower = lambda x: ( + utils.inverse_softplus( + x - self.softplus_lower_lims, beta=softplus_sharpness + ) + + softplus_center + ) + self.inverse_clamp_upper = lambda x: ( + -utils.inverse_softplus( + self.softplus_upper_lims - x, beta=softplus_sharpness + ) + + softplus_center + ) + + def get_clamped_new_state(self, state_delta, prev_state): + """ + Clamp prediction to valid range supplied in config + Returns the clamped new state after adding delta to original state + + Instead of the new state being computed as + $X_{t+1} = X_t + \\delta = X_t + model(\\{X_t,X_{t-1},...\\}, forcing)$ + The clamped values will be + $f(f^{-1}(X_t) + model(\\{X_t, X_{t-1},... 
\\}, forcing))$ + Which means the model will learn to output values in the range of the + inverse clamping function + + state_delta: (B, num_grid_nodes, feature_dim) + prev_state: (B, num_grid_nodes, feature_dim) + """ + + # Assign new state, but overwrite clamped values of each type later + new_state = prev_state + state_delta + + # Sigmoid/logistic clamps between ]a,b[ + if self.clamp_lower_upper_idx.numel() > 0: + idx = self.clamp_lower_upper_idx + + new_state[:, :, idx] = self.clamp_lower_upper( + self.inverse_clamp_lower_upper(prev_state[:, :, idx]) + + state_delta[:, :, idx] + ) + + # Softplus clamps between ]a,infty[ + if self.clamp_lower_idx.numel() > 0: + idx = self.clamp_lower_idx + + new_state[:, :, idx] = self.clamp_lower( + self.inverse_clamp_lower(prev_state[:, :, idx]) + + state_delta[:, :, idx] + ) + + # Softplus clamps between ]-infty,b[ + if self.clamp_upper_idx.numel() > 0: + idx = self.clamp_upper_idx + + new_state[:, :, idx] = self.clamp_upper( + self.inverse_clamp_upper(prev_state[:, :, idx]) + + state_delta[:, :, idx] + ) + + return new_state diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py index 7c7a4eefe..0cb0197f1 100644 --- a/neural_lam/train_model.py +++ b/neural_lam/train_model.py @@ -14,15 +14,10 @@ # Local from . import utils from .config import load_config_and_datastore -from .models import GraphLAM, HiLAM, HiLAMParallel +from .models import MODELS, ForecasterModule +from .models.ar_forecaster import ARForecaster from .weather_dataset import WeatherDataModule -MODELS = { - "graph_lam": GraphLAM, - "hi_lam": HiLAM, - "hi_lam_parallel": HiLAMParallel, -} - @logger.catch def main(input_args=None): @@ -128,6 +123,12 @@ def main(input_args=None): help="If models should additionally output std.-dev. per " "output dimensions", ) + parser.add_argument( + "--vertical_propnets", + action="store_true", + help="If PropagationNets should be used for all vertical " + "(grid-mesh and up/down) message passing", + ) # Training options parser.add_argument( @@ -286,9 +287,36 @@ def main(input_args=None): except ValueError: raise ValueError("devices should be 'auto' or a list of integers") - # Load model parameters Use new args for model - ModelClass = MODELS[args.model] - model = ModelClass(args, config=config, datastore=datastore) + # Build predictor and forecaster externally, then inject into + # ForecasterModule + predictor_class = MODELS[args.model] + predictor = predictor_class( + config=config, + datastore=datastore, + graph=args.graph, + hidden_dim=args.hidden_dim, + hidden_layers=args.hidden_layers, + processor_layers=args.processor_layers, + mesh_aggr=args.mesh_aggr, + num_past_forcing_steps=args.num_past_forcing_steps, + num_future_forcing_steps=args.num_future_forcing_steps, + output_std=args.output_std, + vertical_propnets=args.vertical_propnets, + ) + forecaster = ARForecaster(predictor, datastore) + + model = ForecasterModule( + forecaster=forecaster, + config=config, + datastore=datastore, + loss=args.loss, + lr=args.lr, + restore_opt=args.restore_opt, + n_example_pred=args.n_example_pred, + val_steps_to_log=args.val_steps_to_log, + metrics_watch=args.metrics_watch, + var_leads_metrics_watch=args.var_leads_metrics_watch, + ) if args.eval: prefix = f"eval-{args.eval}-" diff --git a/tests/test_clamping.py b/tests/test_clamping.py index f3f9365d0..09a46ffdf 100644 --- a/tests/test_clamping.py +++ b/tests/test_clamping.py @@ -57,9 +57,16 @@ class ModelArgs: ) model = GraphLAM( - args=model_args, - datastore=datastore, config=config, + 
datastore=datastore, + graph_name=model_args.graph, + hidden_dim=model_args.hidden_dim, + hidden_layers=model_args.hidden_layers, + processor_layers=model_args.processor_layers, + mesh_aggr=model_args.mesh_aggr, + num_past_forcing_steps=model_args.num_past_forcing_steps, + num_future_forcing_steps=model_args.num_future_forcing_steps, + output_std=model_args.output_std, ) features = datastore.get_vars_names(category="state") diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 419aece09..e4f3ad113 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -12,7 +12,7 @@ from neural_lam.create_graph import create_graph_from_datastore from neural_lam.datastore import DATASTORES from neural_lam.datastore.base import BaseRegularGridDatastore -from neural_lam.models.graph_lam import GraphLAM +from neural_lam.models.forecaster_module import ForecasterModule from neural_lam.weather_dataset import WeatherDataset from tests.conftest import init_datastore_example from tests.dummy_datastore import DummyDatastore @@ -182,6 +182,10 @@ class ModelArgs: hidden_layers = 1 processor_layers = 2 mesh_aggr = "sum" + lr = 1.0e-3 + val_steps_to_log = [1] + metrics_watch = [] + var_leads_metrics_watch = {} num_past_forcing_steps = 1 num_future_forcing_steps = 1 @@ -212,13 +216,42 @@ def _create_graph(): dataset = WeatherDataset(datastore=datastore, split=split, ar_steps=2) - model = GraphLAM(args=args, datastore=datastore, config=config) # noqa + # First-party + from neural_lam.models import MODELS + from neural_lam.models.ar_forecaster import ARForecaster + + predictor_class = MODELS["graph_lam"] + predictor = predictor_class( + config=config, + datastore=datastore, + graph_name=args.graph, + hidden_dim=args.hidden_dim, + hidden_layers=args.hidden_layers, + processor_layers=args.processor_layers, + mesh_aggr=args.mesh_aggr, + num_past_forcing_steps=args.num_past_forcing_steps, + num_future_forcing_steps=args.num_future_forcing_steps, + output_std=args.output_std, + ) + forecaster = ARForecaster(predictor, datastore=datastore) + + model = ForecasterModule( + forecaster=forecaster, + config=config, + datastore=datastore, + loss=args.loss, + restore_opt=args.restore_opt, + n_example_pred=args.n_example_pred, + val_steps_to_log=args.val_steps_to_log, + metrics_watch=args.metrics_watch, + var_leads_metrics_watch=args.var_leads_metrics_watch, + lr=args.lr, + ) model_device = model.to(device_name) data_loader = DataLoader(dataset, batch_size=2) batch = next(iter(data_loader)) batch_device = [part.to(device_name) for part in batch] - model_device.common_step(batch_device) model_device.training_step(batch_device) diff --git a/tests/test_graph_creation.py b/tests/test_graph_creation.py index 93a7a55f4..db0597aec 100644 --- a/tests/test_graph_creation.py +++ b/tests/test_graph_creation.py @@ -84,7 +84,7 @@ def test_graph_creation(datastore_name, graph_name): # try to load each and ensure they have the right shape for file_name in required_graph_files: file_id = Path(file_name).stem # remove the extension - result = torch.load(graph_dir_path / file_name) + result = torch.load(graph_dir_path / file_name, weights_only=True) if file_id.startswith("g2m") or file_id.startswith("m2g"): assert isinstance(result, torch.Tensor) diff --git a/tests/test_plotting.py b/tests/test_plotting.py index 2ae332172..c4549566e 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -12,7 +12,7 @@ from neural_lam import config as nlconfig from neural_lam import vis from neural_lam.create_graph import 
create_graph_from_datastore -from neural_lam.models.graph_lam import GraphLAM +from neural_lam.models.forecaster_module import ForecasterModule from neural_lam.weather_dataset import WeatherDataset from tests.dummy_datastore import DummyDatastore @@ -66,10 +66,37 @@ class ModelArgs: ) # Create model - model = GraphLAM( - args=ModelArgs(), + # First-party + from neural_lam.models import MODELS + from neural_lam.models.ar_forecaster import ARForecaster + + args = ModelArgs() + predictor_class = MODELS["graph_lam"] + predictor = predictor_class( + config=config, + datastore=datastore, + graph_name=args.graph, + hidden_dim=args.hidden_dim, + hidden_layers=args.hidden_layers, + processor_layers=args.processor_layers, + mesh_aggr=args.mesh_aggr, + num_past_forcing_steps=args.num_past_forcing_steps, + num_future_forcing_steps=args.num_future_forcing_steps, + output_std=args.output_std, + ) + forecaster = ARForecaster(predictor, datastore=datastore) + + model = ForecasterModule( + forecaster=forecaster, config=config, datastore=datastore, + loss=args.loss, + restore_opt=args.restore_opt, + n_example_pred=args.n_example_pred, + val_steps_to_log=args.val_steps_to_log, + metrics_watch=args.metrics_watch, + var_leads_metrics_watch=args.var_leads_metrics_watch, + lr=args.lr, ) # Create dataset to get a sample batch @@ -116,11 +143,26 @@ def test_plot_examples_integration_saves_figure( ), f"Expected time_step_unit={time_unit}, got {model.time_step_unit}" # Generate prediction - prediction, target, _, _ = model.common_step(batch) + (init_states, target, forcing_features, _batch_times) = batch + prediction, _ = model.forecaster( + init_states, forcing_features, target + ) # Rescale to original data scale - prediction_rescaled = prediction * model.state_std + model.state_mean - target_rescaled = target * model.state_std + model.state_mean + da_state_stats = datastore.get_standardization_dataarray("state") + state_std = torch.tensor( + da_state_stats.state_std.values, + dtype=torch.float32, + device=prediction.device, + ) + state_mean = torch.tensor( + da_state_stats.state_mean.values, + dtype=torch.float32, + device=prediction.device, + ) + + prediction_rescaled = prediction * state_std + state_mean + target_rescaled = target * state_std + state_mean # Get first example pred_slice = prediction_rescaled[0].detach() # Detach from graph diff --git a/tests/test_prediction_model_classes.py b/tests/test_prediction_model_classes.py new file mode 100644 index 000000000..4528f6f7c --- /dev/null +++ b/tests/test_prediction_model_classes.py @@ -0,0 +1,377 @@ +# Third-party +import pytorch_lightning as pl +import torch + +# First-party +from neural_lam import config as nlconfig +from neural_lam.models.ar_forecaster import ARForecaster +from neural_lam.models.forecaster_module import ForecasterModule +from neural_lam.models.step_predictor import StepPredictor +from tests.conftest import init_datastore_example +from tests.dummy_datastore import DummyDatastore + + +class NoStaticDummyDatastore(DummyDatastore): + """DummyDatastore variant that returns None for static features.""" + + def get_dataarray(self, category, split, standardize=False): + if category == "static": + return None + return super().get_dataarray(category, split, standardize=standardize) + + +class MockStepPredictor(StepPredictor): + def __init__(self, config, datastore, **kwargs): + super().__init__(config, datastore, **kwargs) + + def forward(self, prev_state, prev_prev_state, forcing): + # Return zeros for state + # The true state will be mixed in at 
boundaries + pred_state = torch.zeros_like(prev_state) + pred_std = torch.zeros_like(prev_state) if self.output_std else None + return pred_state, pred_std + + +def test_ar_forecaster_unroll(): + datastore = init_datastore_example("mdp") + config = nlconfig.NeuralLAMConfig( + datastore=nlconfig.DatastoreSelection( + kind=datastore.SHORT_NAME, config_path=datastore.root_path + ) + ) + predictor = MockStepPredictor( + config=config, + datastore=datastore, + output_std=False, + ) + + forecaster = ARForecaster(predictor, datastore) + + # Override masks to test boundary masking behaviour + forecaster.interior_mask = torch.zeros_like(forecaster.interior_mask) + forecaster.interior_mask[0, 0] = 1 # One node is interior + forecaster.boundary_mask = 1 - forecaster.interior_mask + + B, num_grid_nodes = 2, predictor.num_grid_nodes + d_state = datastore.get_num_data_vars(category="state") + num_past_forcing_steps = 1 + num_future_forcing_steps = 1 + d_forcing = datastore.get_num_data_vars(category="forcing") * ( + num_past_forcing_steps + num_future_forcing_steps + 1 + ) + pred_steps = 3 + init_states = torch.ones(B, 2, num_grid_nodes, d_state) + forcing_features = torch.ones(B, pred_steps, num_grid_nodes, d_forcing) + true_states = torch.ones(B, pred_steps, num_grid_nodes, d_state) * 5.0 + + prediction, pred_std = forecaster( + init_states, forcing_features, true_states + ) + + assert prediction.shape == (B, pred_steps, num_grid_nodes, d_state) + + # Boundary (where interior_mask == 0) should equal true_state (5.0) + # Interior (where interior_mask == 1) should equal predictor output (0.0) + assert torch.all(prediction[:, :, 0, :] == 0.0) + assert torch.all(prediction[:, :, 1:, :] == 5.0) + + +def test_forecaster_module_checkpoint(tmp_path): + datastore = init_datastore_example("mdp") + + config = nlconfig.NeuralLAMConfig( + datastore=nlconfig.DatastoreSelection( + kind=datastore.SHORT_NAME, config_path=datastore.root_path + ) + ) + + # Build predictor and forecaster externally, then inject into + # ForecasterModule + # First-party + from neural_lam.models import MODELS + + predictor_class = MODELS["graph_lam"] + predictor = predictor_class( + config=config, + datastore=datastore, + graph_name="1level", + hidden_dim=4, + hidden_layers=1, + processor_layers=1, + mesh_aggr="sum", + num_past_forcing_steps=1, + num_future_forcing_steps=1, + output_std=False, + ) + forecaster = ARForecaster(predictor, datastore) + + model = ForecasterModule( + forecaster=forecaster, + config=config, + datastore=datastore, + loss="mse", + lr=1e-3, + restore_opt=False, + n_example_pred=1, + val_steps_to_log=[1], + metrics_watch=[], + ) + + ckpt_path = tmp_path / "test.ckpt" + trainer = pl.Trainer( + max_epochs=1, + accelerator="cpu", + logger=False, + enable_checkpointing=False, + ) + trainer.strategy.connect(model) + trainer.save_checkpoint(ckpt_path) + + # Build a fresh forecaster structure for loading weights into + load_predictor = predictor_class( + config=config, + datastore=datastore, + graph_name="1level", + hidden_dim=4, + hidden_layers=1, + processor_layers=1, + mesh_aggr="sum", + num_past_forcing_steps=1, + num_future_forcing_steps=1, + output_std=False, + ) + load_forecaster = ARForecaster(load_predictor, datastore) + + # Load from checkpoint + loaded_model = ForecasterModule.load_from_checkpoint( + ckpt_path, + datastore=datastore, + forecaster=load_forecaster, + weights_only=False, + ) + + # Validate the correct internal hierarchy has been constructed + assert 
loaded_model.forecaster.predictor.__class__.__name__ == "GraphLAM" + + # Verify that outputs match (checkpoint successfully restored weights) + B, num_grid_nodes = 2, model.forecaster.predictor.num_grid_nodes + d_state = datastore.get_num_data_vars(category="state") + num_past_forcing_steps = 1 + num_future_forcing_steps = 1 + d_forcing = datastore.get_num_data_vars(category="forcing") * ( + num_past_forcing_steps + num_future_forcing_steps + 1 + ) + init_states = torch.ones(B, 2, num_grid_nodes, d_state) + forcing_features = torch.ones(B, 1, num_grid_nodes, d_forcing) + boundary_states = torch.ones(B, 1, num_grid_nodes, d_state) * 5.0 + + with torch.no_grad(): + out_before = model.forecaster( + init_states, forcing_features, boundary_states + ) + out_after = loaded_model.forecaster( + init_states, forcing_features, boundary_states + ) + + assert torch.allclose(out_before[0], out_after[0]) + + +def test_forecaster_module_old_checkpoint(tmp_path): + datastore = init_datastore_example("mdp") + + config = nlconfig.NeuralLAMConfig( + datastore=nlconfig.DatastoreSelection( + kind=datastore.SHORT_NAME, config_path=datastore.root_path + ) + ) + + # First-party + from neural_lam.models import MODELS + + predictor_class = MODELS["graph_lam"] + predictor = predictor_class( + config=config, + datastore=datastore, + graph_name="1level", + hidden_dim=4, + hidden_layers=1, + processor_layers=1, + mesh_aggr="sum", + num_past_forcing_steps=1, + num_future_forcing_steps=1, + output_std=False, + ) + forecaster = ARForecaster(predictor, datastore) + + model = ForecasterModule( + forecaster=forecaster, + config=config, + datastore=datastore, + loss="mse", + lr=1e-3, + restore_opt=False, + n_example_pred=1, + val_steps_to_log=[1], + metrics_watch=[], + ) + + ckpt_path = tmp_path / "test_old.ckpt" + trainer = pl.Trainer( + max_epochs=1, + accelerator="cpu", + logger=False, + enable_checkpointing=False, + ) + trainer.strategy.connect(model) + trainer.save_checkpoint(ckpt_path) + + # Manually hack the checkpoint to emulate a pre-refactor state dict + ckpt = torch.load(ckpt_path, weights_only=False) + old_state_dict = {} + for k, v in ckpt["state_dict"].items(): + if k.startswith("forecaster.predictor."): + # Revert structural rename to emulate old flat keys + new_k = k.replace("forecaster.predictor.", "") + if "encoding_grid_mlp" in new_k: + new_k = new_k.replace("encoding_grid_mlp", "g2m_gnn.grid_mlp") + old_state_dict[new_k] = v + else: + old_state_dict[k] = v + + ckpt["state_dict"] = old_state_dict + torch.save(ckpt, ckpt_path) + + # Build a fresh forecaster structure for loading weights into + load_predictor = predictor_class( + config=config, + datastore=datastore, + graph_name="1level", + hidden_dim=4, + hidden_layers=1, + processor_layers=1, + mesh_aggr="sum", + num_past_forcing_steps=1, + num_future_forcing_steps=1, + output_std=False, + ) + load_forecaster = ARForecaster(load_predictor, datastore) + + # Load from hacked old checkpoint + loaded_model = ForecasterModule.load_from_checkpoint( + ckpt_path, + datastore=datastore, + forecaster=load_forecaster, + weights_only=False, + ) + + # Validate the correct internal hierarchy has been constructed + assert loaded_model.forecaster.predictor.__class__.__name__ == "GraphLAM" + + # Verify that outputs match (checkpoint successfully restored weights) + B, num_grid_nodes = 2, model.forecaster.predictor.num_grid_nodes + d_state = datastore.get_num_data_vars(category="state") + num_past_forcing_steps = 1 + num_future_forcing_steps = 1 + d_forcing = 
datastore.get_num_data_vars(category="forcing") * ( + num_past_forcing_steps + num_future_forcing_steps + 1 + ) + init_states = torch.ones(B, 2, num_grid_nodes, d_state) + forcing_features = torch.ones(B, 1, num_grid_nodes, d_forcing) + boundary_states = torch.ones(B, 1, num_grid_nodes, d_state) * 5.0 + + with torch.no_grad(): + out_before = model.forecaster( + init_states, forcing_features, boundary_states + ) + out_after = loaded_model.forecaster( + init_states, forcing_features, boundary_states + ) + + assert torch.allclose(out_before[0], out_after[0]) + + +def test_step_predictor_no_static_features(): + """Model should run correctly when the datastore has no static features, + using an empty (N, 0) tensor in place of static features.""" + datastore = NoStaticDummyDatastore() + config = nlconfig.NeuralLAMConfig( + datastore=nlconfig.DatastoreSelection( + kind=datastore.SHORT_NAME, config_path=datastore.root_path + ) + ) + + predictor = MockStepPredictor( + config=config, + datastore=datastore, + output_std=False, + ) + + # Static features buffer should exist but be empty (zero width) + assert predictor.grid_static_features.shape == ( + datastore.num_grid_points, + 0, + ) + + # Verify a forward pass works end-to-end via ARForecaster + forecaster = ARForecaster(predictor, datastore) + B, num_grid_nodes = 2, predictor.num_grid_nodes + d_state = datastore.get_num_data_vars(category="state") + d_forcing = datastore.get_num_data_vars(category="forcing") + init_states = torch.zeros(B, 2, num_grid_nodes, d_state) + forcing_features = torch.zeros(B, 1, num_grid_nodes, d_forcing) + boundary_states = torch.zeros(B, 1, num_grid_nodes, d_state) + + prediction, pred_std = forecaster( + init_states, forcing_features, boundary_states + ) + assert prediction.shape == (B, 1, num_grid_nodes, d_state) + assert pred_std is None + + +def test_fold_unfold_equivalence(): + """Folding (S, B, ...) into (S*B, ...) before ARForecaster and keeping + prediction folded must match running each sample independently. + + This confirms ARForecaster's internal indexing is rank-transparent: + init_states[:, 0/1] selects along the conditioning-step axis (dim 1), + not the batch axis, so a folded leading dim is invisible to the rollout. + The test also confirms that no unfold is needed inside forecast_for_batch + for the existing logging/aggregation paths to remain correct. 
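The rank-transparency argument above can be checked on plain tensors, independent of the forecaster: flattening (S, B) into one leading dimension leaves indexing along the conditioning-step axis untouched. A self-contained illustration with arbitrarily chosen shapes:

import torch

S, B, N, d = 3, 2, 5, 4
x = torch.randn(S, B, 2, N, d)   # (samples, batch, two init steps, grid nodes, features)

folded = x.flatten(0, 1)         # (S*B, 2, N, d): one effective batch dimension
first_init = folded[:, 0]        # selects along the init-step axis, not the batch axis
restored = first_init.unflatten(0, (S, B))

assert torch.equal(restored, x[:, :, 0])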
+ """ + datastore = init_datastore_example("mdp") + config = nlconfig.NeuralLAMConfig( + datastore=nlconfig.DatastoreSelection( + kind=datastore.SHORT_NAME, config_path=datastore.root_path + ) + ) + predictor = MockStepPredictor(config=config, datastore=datastore) + forecaster = ARForecaster(predictor, datastore) + + S, B = 3, 2 + N = predictor.num_grid_nodes + d_state = datastore.get_num_data_vars(category="state") + d_forcing = datastore.get_num_data_vars(category="forcing") + pred_steps = 4 + + torch.manual_seed(42) + init_states = torch.randn(S, B, 2, N, d_state) + forcing = torch.randn(S, B, pred_steps, N, d_forcing) + boundary = torch.randn(S, B, pred_steps, N, d_state) + + # Run with (S, B) folded into one effective batch dim + with torch.no_grad(): + pred_folded, _ = forecaster( + init_states.flatten(0, 1), + forcing.flatten(0, 1), + boundary.flatten(0, 1), + ) + pred_folded = pred_folded.unflatten(0, (S, B)) + + # Run each sample independently and stack + with torch.no_grad(): + pred_explicit = torch.stack( + [forecaster(init_states[s], forcing[s], boundary[s])[0] + for s in range(S)] + ) + + assert torch.allclose(pred_folded, pred_explicit) diff --git a/tests/test_propagation_net.py b/tests/test_propagation_net.py new file mode 100644 index 000000000..da58287e9 --- /dev/null +++ b/tests/test_propagation_net.py @@ -0,0 +1,981 @@ +# Standard library +from pathlib import Path + +# Third-party +import torch + +# First-party +from neural_lam import config as nlconfig +from neural_lam.create_graph import create_graph_from_datastore +from neural_lam.interaction_net import InteractionNet, PropagationNet +from neural_lam.models import MODELS +from neural_lam.models.ar_forecaster import ARForecaster +from tests.conftest import init_datastore_example + + +def _make_edge_index(n_send, n_rec, n_edges): + """Create a random edge_index for testing.""" + torch.manual_seed(0) + senders = torch.randint(0, n_send, (n_edges,)) + receivers = torch.randint(0, n_rec, (n_edges,)) + return torch.stack([senders, receivers]) + + +def _make_fully_connected_edge_index(n_send, n_rec): + """Create a fully-connected edge_index (every sender to every receiver).""" + senders = ( + torch.arange(n_send).unsqueeze(1).expand(n_send, n_rec).reshape(-1) + ) + receivers = ( + torch.arange(n_rec).unsqueeze(0).expand(n_send, n_rec).reshape(-1) + ) + return torch.stack([senders, receivers]) + + +def _build_model_and_data(datastore, config, model_name, graph_name, + vertical_propnets=False): + """Helper to build a model and matching random input tensors.""" + num_past_forcing_steps = 1 + num_future_forcing_steps = 1 + + predictor = MODELS[model_name]( + config=config, + datastore=datastore, + graph_name=graph_name, + hidden_dim=4, + hidden_layers=1, + processor_layers=1, + mesh_aggr="sum", + num_past_forcing_steps=num_past_forcing_steps, + num_future_forcing_steps=num_future_forcing_steps, + output_std=False, + vertical_propnets=vertical_propnets, + ) + forecaster = ARForecaster(predictor, datastore) + + B = 2 + num_grid_nodes = predictor.num_grid_nodes + d_state = datastore.get_num_data_vars(category="state") + d_forcing = datastore.get_num_data_vars(category="forcing") * ( + num_past_forcing_steps + num_future_forcing_steps + 1 + ) + + torch.manual_seed(123) + init_states = torch.randn(B, 2, num_grid_nodes, d_state) + forcing = torch.randn(B, 1, num_grid_nodes, d_forcing) + boundary = torch.randn(B, 1, num_grid_nodes, d_state) + + return forecaster, predictor, init_states, forcing, boundary + + +def 
_get_datastore_and_config(graph_name): + """Create a datastore with graph already built.""" + datastore = init_datastore_example("mdp") + config = nlconfig.NeuralLAMConfig( + datastore=nlconfig.DatastoreSelection( + kind=datastore.SHORT_NAME, + config_path=datastore.root_path, + ) + ) + + # Ensure graph exists + if graph_name == "hierarchical": + hierarchical = True + n_max_levels = 3 + elif graph_name == "multiscale": + hierarchical = False + n_max_levels = 3 + else: + hierarchical = False + n_max_levels = 1 + + graph_dir_path = Path(datastore.root_path) / "graph" / graph_name + if not graph_dir_path.exists(): + create_graph_from_datastore( + datastore=datastore, + output_root_path=str(graph_dir_path), + hierarchical=hierarchical, + n_max_levels=n_max_levels, + ) + + return datastore, config + + +# +# Section A: Structural Tests +# + + +class TestPropagationNetStructure: + """Tests for PropagationNet class structure and constructor behavior.""" + + def test_propagation_net_is_subclass(self): + """PropagationNet should be a subclass of InteractionNet.""" + assert issubclass(PropagationNet, InteractionNet) + + def test_forced_mean_aggregation(self): + """PropagationNet should always use mean aggregation, + regardless of what is passed.""" + edge_index = _make_edge_index(5, 4, 10) + net = PropagationNet(edge_index, input_dim=8, aggr="sum") + assert net.aggr == "mean" + + def test_interaction_net_respects_aggr(self): + """InteractionNet should use whatever aggregation is passed.""" + edge_index = _make_edge_index(5, 4, 10) + net_sum = InteractionNet(edge_index.clone(), input_dim=8, aggr="sum") + net_mean = InteractionNet( + edge_index.clone(), input_dim=8, aggr="mean" + ) + assert net_sum.aggr == "sum" + assert net_mean.aggr == "mean" + + def test_mlp_input_dimensions(self): + """Edge MLP should accept 3*input_dim, aggr MLP should accept + 2*input_dim.""" + d_h = 16 + edge_index = _make_edge_index(5, 4, 10) + pnet = PropagationNet(edge_index, input_dim=d_h) + + # Edge MLP: first layer input should be 3 * d_h + edge_mlp_first = pnet.edge_mlp[0] + assert edge_mlp_first.in_features == 3 * d_h + + # Aggr MLP: first layer input should be 2 * d_h + aggr_mlp_first = pnet.aggr_mlp[0] + assert aggr_mlp_first.in_features == 2 * d_h + + def test_node_index_offset_convention(self): + """After construction, sender indices in edge_index should be + offset by num_rec so that receivers are [0, num_rec) and senders + are [num_rec, num_rec + num_send).""" + n_send, n_rec, n_edges = 5, 4, 10 + edge_index = _make_edge_index(n_send, n_rec, n_edges) + pnet = PropagationNet(edge_index, input_dim=8) + + stored_ei = pnet.edge_index + # Receiver indices should be in [0, num_rec) + assert stored_ei[1].min() >= 0 + assert stored_ei[1].max() < pnet.num_rec + # Sender indices should be in [num_rec, num_rec + num_send) + assert stored_ei[0].min() >= pnet.num_rec + + +# +# Section B: Forward Pass Correctness +# + + +class TestPropagationNetForwardPass: + """Tests for PropagationNet forward pass mechanics.""" + + def test_output_shapes_match_interaction_net(self): + """PropagationNet output shapes should match InteractionNet + for both update_edges=True and update_edges=False.""" + n_send, n_rec, n_edges, d_h = 5, 4, 10, 8 + edge_index = _make_edge_index(n_send, n_rec, n_edges) + + for update_edges in [True, False]: + inet = InteractionNet( + edge_index.clone(), + input_dim=d_h, + update_edges=update_edges, + ) + pnet = PropagationNet( + edge_index.clone(), + input_dim=d_h, + update_edges=update_edges, + ) + + 
torch.manual_seed(42) + send_rep = torch.randn(n_send, d_h) + rec_rep = torch.randn(n_rec, d_h) + edge_rep = torch.randn(n_edges, d_h) + + i_out = inet(send_rep, rec_rep, edge_rep) + p_out = pnet(send_rep, rec_rep, edge_rep) + + if update_edges: + assert isinstance(i_out, tuple) and len(i_out) == 2 + assert isinstance(p_out, tuple) and len(p_out) == 2 + assert i_out[0].shape == p_out[0].shape == (n_rec, d_h) + assert i_out[1].shape == p_out[1].shape == (n_edges, d_h) + else: + assert i_out.shape == p_out.shape == (n_rec, d_h) + + def test_sender_residual_in_message(self): + """PropagationNet message should be x_j + edge_mlp(...), verified + by zeroing edge_mlp weights so message reduces to x_j.""" + n_send, n_rec, d_h = 3, 2, 4 + # Fully connected: every sender connects to every receiver + edge_index = _make_fully_connected_edge_index(n_send, n_rec) + n_edges = edge_index.shape[1] + + pnet = PropagationNet( + edge_index, input_dim=d_h, update_edges=True + ) + + # Zero out edge MLP so edge_mlp(...) = 0 + # Then message = x_j + 0 = x_j + with torch.no_grad(): + for param in pnet.edge_mlp.parameters(): + param.zero_() + + torch.manual_seed(42) + send_rep = torch.randn(n_send, d_h) + rec_rep = torch.randn(n_rec, d_h) + edge_rep = torch.randn(n_edges, d_h) + + # Call propagate directly to get raw messages + node_reps = torch.cat((rec_rep, send_rep), dim=-2) + edge_rep_aggr, edge_diff = pnet.propagate( + pnet.edge_index, x=node_reps, edge_attr=edge_rep + ) + + # With zeroed edge_mlp, messages = x_j (sender reps) + # edge_diff should equal the sender reps for each edge + # Each edge_diff[i] should be send_rep[sender_of_edge_i] + sender_indices = pnet.edge_index[0] - pnet.num_rec + expected_messages = send_rep[sender_indices] + assert torch.allclose(edge_diff, expected_messages, atol=1e-6) + + def test_receiver_residual_targets_aggregated_messages(self): + """PropagationNet receiver update should be: + rec_new = agg_msgs + aggr_mlp(cat(rec_old, agg_msgs)) + NOT rec_new = rec_old + aggr_mlp(...)""" + n_send, n_rec, n_edges, d_h = 3, 2, 6, 4 + edge_index = _make_edge_index(n_send, n_rec, n_edges) + + pnet = PropagationNet( + edge_index, input_dim=d_h, update_edges=False + ) + + # Zero out aggr_mlp so aggr_mlp(...) 
= 0 + # Then rec_new = agg_msgs + 0 = agg_msgs (for PropagationNet) + # But for InteractionNet it would be rec_old + 0 = rec_old + with torch.no_grad(): + for param in pnet.aggr_mlp.parameters(): + param.zero_() + + torch.manual_seed(42) + send_rep = torch.randn(n_send, d_h) + rec_rep = torch.randn(n_rec, d_h) + edge_rep = torch.randn(n_edges, d_h) + + rec_out = pnet(send_rep, rec_rep, edge_rep) + + # With zeroed aggr_mlp, output should NOT equal rec_rep + # (it would equal rec_rep if residual targeted rec_rep like INet) + assert not torch.allclose(rec_out, rec_rep, atol=1e-6) + + # Verify by also computing what agg_msgs would be: + # Run propagate to get edge_rep_aggr + node_reps = torch.cat((rec_rep, send_rep), dim=-2) + edge_rep_aggr, _ = pnet.propagate( + pnet.edge_index, x=node_reps, edge_attr=edge_rep + ) + + # rec_out should equal edge_rep_aggr (since aggr_mlp output = 0) + assert torch.allclose(rec_out, edge_rep_aggr, atol=1e-6) + + def test_numerical_divergence_from_interaction_net(self): + """PropagationNet should produce different outputs than InteractionNet + given the same weights.""" + n_send, n_rec, n_edges, d_h = 5, 4, 10, 8 + edge_index = _make_edge_index(n_send, n_rec, n_edges) + + inet = InteractionNet( + edge_index.clone(), input_dim=d_h, update_edges=True + ) + pnet = PropagationNet( + edge_index.clone(), input_dim=d_h, update_edges=True + ) + pnet.load_state_dict(inet.state_dict()) + + torch.manual_seed(42) + send_rep = torch.randn(n_send, d_h) + rec_rep = torch.randn(n_rec, d_h) + edge_rep = torch.randn(n_edges, d_h) + + i_rec, i_edge = inet(send_rep, rec_rep, edge_rep) + p_rec, p_edge = pnet(send_rep, rec_rep, edge_rep) + + assert not torch.allclose(i_rec, p_rec) + + +# +# Section C: Edge Update Behavior +# + + +class TestEdgeUpdateBehavior: + """Tests for edge update mechanics.""" + + def test_update_edges_true_returns_tuple(self): + """update_edges=True should return (rec_rep, edge_rep) tuple.""" + n_send, n_rec, n_edges, d_h = 5, 4, 10, 8 + edge_index = _make_edge_index(n_send, n_rec, n_edges) + pnet = PropagationNet( + edge_index, input_dim=d_h, update_edges=True + ) + + send_rep = torch.randn(n_send, d_h) + rec_rep = torch.randn(n_rec, d_h) + edge_rep = torch.randn(n_edges, d_h) + + result = pnet(send_rep, rec_rep, edge_rep) + assert isinstance(result, tuple) and len(result) == 2 + assert result[0].shape == (n_rec, d_h) + assert result[1].shape == (n_edges, d_h) + + def test_update_edges_false_returns_tensor(self): + """update_edges=False should return only rec_rep tensor.""" + n_send, n_rec, n_edges, d_h = 5, 4, 10, 8 + edge_index = _make_edge_index(n_send, n_rec, n_edges) + pnet = PropagationNet( + edge_index, input_dim=d_h, update_edges=False + ) + + send_rep = torch.randn(n_send, d_h) + rec_rep = torch.randn(n_rec, d_h) + edge_rep = torch.randn(n_edges, d_h) + + result = pnet(send_rep, rec_rep, edge_rep) + assert isinstance(result, torch.Tensor) + assert result.shape == (n_rec, d_h) + + def test_edge_residual_connection(self): + """Edge update should use edge_rep + edge_diff residual, + same as InteractionNet.""" + n_send, n_rec, n_edges, d_h = 5, 4, 10, 8 + edge_index = _make_edge_index(n_send, n_rec, n_edges) + + pnet = PropagationNet( + edge_index, input_dim=d_h, update_edges=True + ) + + torch.manual_seed(42) + send_rep = torch.randn(n_send, d_h) + rec_rep = torch.randn(n_rec, d_h) + edge_rep = torch.randn(n_edges, d_h) + + _, edge_out = pnet(send_rep, rec_rep, edge_rep) + + # Edge output should differ from original (the MLP adds something) + assert not 
torch.allclose(edge_out, edge_rep) + # But the difference should be the MLP output (residual structure) + # Verify: edge_diff = edge_out - edge_rep should be the raw message + # from propagate. Recompute to verify. + node_reps = torch.cat((rec_rep, send_rep), dim=-2) + _, edge_diff = pnet.propagate( + pnet.edge_index, x=node_reps, edge_attr=edge_rep + ) + expected_edge_out = edge_rep + edge_diff + assert torch.allclose(edge_out, expected_edge_out, atol=1e-5) + + +# +# Section D: Batched Processing +# + + +class TestBatchedProcessing: + """Tests for batch dimension handling.""" + + def test_output_shapes_batched(self): + """PropagationNet should work with batched inputs (B, N, d_h).""" + n_send, n_rec, n_edges, d_h, B = 5, 4, 10, 8, 3 + edge_index = _make_edge_index(n_send, n_rec, n_edges) + + pnet = PropagationNet( + edge_index, input_dim=d_h, update_edges=True + ) + + torch.manual_seed(42) + send_rep = torch.randn(B, n_send, d_h) + rec_rep = torch.randn(B, n_rec, d_h) + edge_rep = torch.randn(B, n_edges, d_h) + + rec_out, edge_out = pnet(send_rep, rec_rep, edge_rep) + assert rec_out.shape == (B, n_rec, d_h) + assert edge_out.shape == (B, n_edges, d_h) + + def test_batch_independence(self): + """Different samples in a batch should not influence each other.""" + n_send, n_rec, n_edges, d_h = 5, 4, 10, 8 + edge_index = _make_edge_index(n_send, n_rec, n_edges) + + pnet = PropagationNet( + edge_index, input_dim=d_h, update_edges=False + ) + + torch.manual_seed(42) + send_rep_0 = torch.randn(1, n_send, d_h) + rec_rep_0 = torch.randn(1, n_rec, d_h) + edge_rep_0 = torch.randn(1, n_edges, d_h) + + torch.manual_seed(99) + send_rep_1 = torch.randn(1, n_send, d_h) + rec_rep_1 = torch.randn(1, n_rec, d_h) + edge_rep_1 = torch.randn(1, n_edges, d_h) + + # Run individually + out_0 = pnet(send_rep_0, rec_rep_0, edge_rep_0) + out_1 = pnet(send_rep_1, rec_rep_1, edge_rep_1) + + # Run as batch + send_batch = torch.cat([send_rep_0, send_rep_1], dim=0) + rec_batch = torch.cat([rec_rep_0, rec_rep_1], dim=0) + edge_batch = torch.cat([edge_rep_0, edge_rep_1], dim=0) + out_batch = pnet(send_batch, rec_batch, edge_batch) + + assert torch.allclose(out_batch[0], out_0[0], atol=1e-6) + assert torch.allclose(out_batch[1], out_1[0], atol=1e-6) + + +# +# Section E: Chunk/Split MLP Support +# + + +class TestChunkSupport: + """Tests for edge_chunk_sizes and aggr_chunk_sizes.""" + + def test_edge_and_aggr_chunk_sizes(self): + """PropagationNet should work with edge_chunk_sizes and + aggr_chunk_sizes, using separate SplitMLPs.""" + n_send, n_rec, d_h = 6, 4, 8 + n_edges_a, n_edges_b = 5, 7 + n_edges = n_edges_a + n_edges_b + + edge_index = _make_edge_index(n_send, n_rec, n_edges) + + pnet = PropagationNet( + edge_index, + input_dim=d_h, + update_edges=True, + edge_chunk_sizes=[n_edges_a, n_edges_b], + aggr_chunk_sizes=[n_rec // 2, n_rec - n_rec // 2], + ) + + torch.manual_seed(42) + send_rep = torch.randn(n_send, d_h) + rec_rep = torch.randn(n_rec, d_h) + edge_rep = torch.randn(n_edges, d_h) + + rec_out, edge_out = pnet(send_rep, rec_rep, edge_rep) + assert rec_out.shape == (n_rec, d_h) + assert edge_out.shape == (n_edges, d_h) + + def test_chunked_differs_from_unchunked(self): + """Chunked MLPs should produce different outputs than a single MLP + (they use independent weights for each chunk).""" + n_send, n_rec, d_h = 6, 4, 8 + n_edges = 12 + edge_index = _make_edge_index(n_send, n_rec, n_edges) + + pnet_plain = PropagationNet( + edge_index.clone(), input_dim=d_h, update_edges=False + ) + pnet_chunked = PropagationNet( + 
edge_index.clone(), + input_dim=d_h, + update_edges=False, + edge_chunk_sizes=[6, 6], + ) + + torch.manual_seed(42) + send_rep = torch.randn(n_send, d_h) + rec_rep = torch.randn(n_rec, d_h) + edge_rep = torch.randn(n_edges, d_h) + + out_plain = pnet_plain(send_rep, rec_rep, edge_rep) + out_chunked = pnet_chunked(send_rep, rec_rep, edge_rep) + + # Different MLP architectures -> different results + assert not torch.allclose(out_plain, out_chunked) + + +# +# Section F: Gradient Flow +# + + +class TestGradientFlow: + """Tests for backpropagation through PropagationNet.""" + + def test_gradient_flow_all_inputs(self): + """Gradients should flow to send_rep, rec_rep, and edge_rep.""" + n_send, n_rec, n_edges, d_h = 5, 4, 10, 8 + edge_index = _make_edge_index(n_send, n_rec, n_edges) + + pnet = PropagationNet( + edge_index, input_dim=d_h, update_edges=False + ) + + send_rep = torch.randn(n_send, d_h, requires_grad=True) + rec_rep = torch.randn(n_rec, d_h, requires_grad=True) + edge_rep = torch.randn(n_edges, d_h, requires_grad=True) + + out = pnet(send_rep, rec_rep, edge_rep) + loss = out.sum() + loss.backward() + + assert send_rep.grad is not None + assert rec_rep.grad is not None + assert edge_rep.grad is not None + + def test_gradient_through_sender_residual(self): + """Gradient should flow through both MLP path AND direct x_j + residual in the message function.""" + n_send, n_rec, n_edges, d_h = 3, 2, 6, 4 + edge_index = _make_edge_index(n_send, n_rec, n_edges) + + pnet = PropagationNet( + edge_index, input_dim=d_h, update_edges=False + ) + + send_rep = torch.randn(n_send, d_h, requires_grad=True) + rec_rep = torch.randn(n_rec, d_h, requires_grad=True) + edge_rep = torch.randn(n_edges, d_h, requires_grad=True) + + out = pnet(send_rep, rec_rep, edge_rep) + loss = out.sum() + loss.backward() + + # send_rep gradient should be non-zero (flows through x_j residual) + assert send_rep.grad is not None + assert send_rep.grad.abs().sum() > 0 + + # Zero out edge MLP and recheck: gradient should STILL flow via + # direct x_j path + with torch.no_grad(): + for param in pnet.edge_mlp.parameters(): + param.zero_() + + send_rep2 = send_rep.detach().clone().requires_grad_(True) + rec_rep2 = rec_rep.detach().clone() + edge_rep2 = edge_rep.detach().clone() + + out2 = pnet(send_rep2, rec_rep2, edge_rep2) + out2.sum().backward() + + assert send_rep2.grad is not None + assert send_rep2.grad.abs().sum() > 0 + + def test_gradient_through_edge_update(self): + """When update_edges=True, gradients should flow to edge outputs.""" + n_send, n_rec, n_edges, d_h = 5, 4, 10, 8 + edge_index = _make_edge_index(n_send, n_rec, n_edges) + + pnet = PropagationNet( + edge_index, input_dim=d_h, update_edges=True + ) + + send_rep = torch.randn(n_send, d_h, requires_grad=True) + rec_rep = torch.randn(n_rec, d_h, requires_grad=True) + edge_rep = torch.randn(n_edges, d_h, requires_grad=True) + + rec_out, edge_out = pnet(send_rep, rec_rep, edge_rep) + + # Backprop through edge output only + edge_out.sum().backward() + + assert edge_rep.grad is not None + assert edge_rep.grad.abs().sum() > 0 + + +# +# Section G: Graph Structure Compatibility +# + + +class TestGraphStructureCompatibility: + """Tests for various graph topologies.""" + + def test_asymmetric_graph(self): + """Should handle graphs where n_send != n_rec (e.g. 
grid->mesh).""" + # Large ratio like grid(100) -> mesh(10) + n_send, n_rec, d_h = 100, 10, 8 + n_edges = 200 + edge_index = _make_edge_index(n_send, n_rec, n_edges) + + pnet = PropagationNet( + edge_index, input_dim=d_h, update_edges=False + ) + + send_rep = torch.randn(n_send, d_h) + rec_rep = torch.randn(n_rec, d_h) + edge_rep = torch.randn(n_edges, d_h) + + out = pnet(send_rep, rec_rep, edge_rep) + assert out.shape == (n_rec, d_h) + assert torch.isfinite(out).all() + + def test_single_sender_single_receiver(self): + """Degenerate graph with 1 node on each side.""" + d_h = 8 + edge_index = torch.tensor([[0], [0]]) + + pnet = PropagationNet( + edge_index, input_dim=d_h, update_edges=False + ) + + send_rep = torch.randn(1, d_h) + rec_rep = torch.randn(1, d_h) + edge_rep = torch.randn(1, d_h) + + out = pnet(send_rep, rec_rep, edge_rep) + assert out.shape == (1, d_h) + assert torch.isfinite(out).all() + + def test_disconnected_receiver(self): + """A receiver with no incoming edges should not produce NaN. + PyG fills aggregation with 0 for disconnected nodes. + For PropagationNet: rec_new = 0 + aggr_mlp(cat(rec_rep, 0)), + meaning the receiver loses its direct residual (unlike INet + which would give rec_rep + aggr_mlp(cat(rec_rep, 0))).""" + d_h = 4 + # Receivers 0 and 2 get edges, receiver 1 is disconnected + edge_index = torch.tensor([[0, 1], [0, 2]]) + + pnet = PropagationNet( + edge_index, input_dim=d_h, update_edges=False + ) + + torch.manual_seed(42) + send_rep = torch.randn(2, d_h) + rec_rep = torch.randn(3, d_h) + edge_rep = torch.randn(2, d_h) + + out = pnet(send_rep, rec_rep, edge_rep) + + # No NaN for disconnected receiver + assert out.shape == (3, d_h) + assert torch.isfinite(out).all() + + # Disconnected receiver: agg_msgs = 0, so + # rec_new = 0 + aggr_mlp(cat(rec_rep[1], 0)) + zeros = torch.zeros(d_h) + expected = zeros + pnet.aggr_mlp( + torch.cat((rec_rep[1], zeros), dim=-1) + ) + assert torch.allclose(out[1], expected, atol=1e-6) + + def test_self_loops(self): + """Self-loops (where sender == receiver index in original graph) + should compute correctly.""" + d_h = 8 + n_nodes = 4 + # Self-loop edges: each node connects to itself + indices = torch.arange(n_nodes) + edge_index = torch.stack([indices, indices]) + + pnet = PropagationNet( + edge_index, input_dim=d_h, update_edges=False + ) + + send_rep = torch.randn(n_nodes, d_h) + rec_rep = torch.randn(n_nodes, d_h) + edge_rep = torch.randn(n_nodes, d_h) + + out = pnet(send_rep, rec_rep, edge_rep) + assert out.shape == (n_nodes, d_h) + assert torch.isfinite(out).all() + + +# +# Section H: Numerical Stability +# + + +class TestNumericalStability: + """Tests for numerical stability under stress conditions.""" + + def test_deep_stacking(self): + """Multiple PropagationNet layers in sequence should not cause + numerical blow-up.""" + n_send, n_rec, n_edges, d_h = 10, 10, 30, 16 + edge_index = _make_edge_index(n_send, n_rec, n_edges) + + # Stack 8 layers + layers = [] + for _ in range(8): + layers.append( + PropagationNet( + edge_index.clone(), + input_dim=d_h, + update_edges=True, + ) + ) + + torch.manual_seed(42) + send_rep = torch.randn(n_send, d_h) + rec_rep = torch.randn(n_rec, d_h) + edge_rep = torch.randn(n_edges, d_h) + + current_rec = rec_rep + current_edge = edge_rep + for layer in layers: + current_rec, current_edge = layer( + send_rep, current_rec, current_edge + ) + + assert torch.isfinite(current_rec).all(), ( + "Receiver reps contain non-finite values after deep stacking" + ) + assert 
+        assert torch.isfinite(current_edge).all(), (
+            "Edge reps contain non-finite values after deep stacking"
+        )
+
+    def test_high_degree_stability(self):
+        """With many incoming edges per receiver, mean aggregation should
+        keep outputs stable."""
+        n_send, n_rec, d_h = 50, 3, 8
+        # Many edges to few receivers
+        n_edges = 500
+        edge_index = _make_edge_index(n_send, n_rec, n_edges)
+
+        pnet = PropagationNet(
+            edge_index, input_dim=d_h, update_edges=False
+        )
+
+        send_rep = torch.randn(n_send, d_h)
+        rec_rep = torch.randn(n_rec, d_h)
+        edge_rep = torch.randn(n_edges, d_h)
+
+        out = pnet(send_rep, rec_rep, edge_rep)
+        assert torch.isfinite(out).all()
+        # Mean aggregation should keep magnitude reasonable
+        assert out.abs().max() < 1000
+
+
+#
+# Section I: Model-Level Integration (Deterministic Models)
+#
+
+
+class TestDefaultBehaviorUnchanged:
+    """Tests that without vertical_propnets, models use InteractionNet
+    (backward compatibility). All tests use deterministic models only."""
+
+    def test_base_graph_model_default_uses_interaction_net(self):
+        """BaseGraphModel should use InteractionNet for g2m/m2g by default."""
+        datastore, config = _get_datastore_and_config("1level")
+
+        _, predictor, _, _, _ = _build_model_and_data(
+            datastore, config, "graph_lam", "1level"
+        )
+
+        assert isinstance(predictor.g2m_gnn, InteractionNet)
+        assert not isinstance(predictor.g2m_gnn, PropagationNet)
+        assert isinstance(predictor.m2g_gnn, InteractionNet)
+        assert not isinstance(predictor.m2g_gnn, PropagationNet)
+
+    def test_base_graph_model_propnet_flag(self):
+        """With vertical_propnets=True, g2m/m2g should be PropagationNet."""
+        datastore, config = _get_datastore_and_config("1level")
+
+        _, predictor, _, _, _ = _build_model_and_data(
+            datastore, config, "graph_lam", "1level",
+            vertical_propnets=True,
+        )
+
+        assert isinstance(predictor.g2m_gnn, PropagationNet)
+        assert isinstance(predictor.m2g_gnn, PropagationNet)
+
+    def test_graph_lam_processor_always_interaction_net(self):
+        """GraphLAM processor GNNs should always be InteractionNet,
+        even with vertical_propnets=True."""
+        datastore, config = _get_datastore_and_config("1level")
+
+        _, predictor, _, _, _ = _build_model_and_data(
+            datastore, config, "graph_lam", "1level",
+            vertical_propnets=True,
+        )
+
+        # Check processor GNNs are InteractionNet (not PropagationNet)
+        for module in predictor.processor.modules():
+            if isinstance(module, InteractionNet):
+                assert not isinstance(module, PropagationNet)
+
+    def test_default_forward_pass_unchanged(self):
+        """A forward pass with default settings (no vertical_propnets)
+        should be reproducible: two identically seeded model builds must
+        give the same output, guarding against unintended changes to the
+        default behaviour from the PropagationNet addition."""
+        datastore, config = _get_datastore_and_config("1level")
+
+        torch.manual_seed(42)
+        forecaster_a, _, init_states, forcing, boundary = (
+            _build_model_and_data(
+                datastore, config, "graph_lam", "1level"
+            )
+        )
+
+        torch.manual_seed(42)
+        forecaster_b, _, _, _, _ = _build_model_and_data(
+            datastore, config, "graph_lam", "1level"
+        )
+
+        with torch.no_grad():
+            out_a, _ = forecaster_a(init_states, forcing, boundary)
+            out_b, _ = forecaster_b(init_states, forcing, boundary)
+
+        assert torch.allclose(out_a, out_b)
+
+    def test_propnet_forward_pass_differs(self):
+        """A forward pass with vertical_propnets=True should produce
+        different outputs than the default (InteractionNet)."""
+        datastore, config = _get_datastore_and_config("1level")
+
+        torch.manual_seed(42)
+        forecaster_default, _, init_states, forcing, boundary = (
+            _build_model_and_data(
+                datastore, config, "graph_lam", "1level"
+            )
+        )
+
+        torch.manual_seed(42)
+        forecaster_prop, _, _, _, _ = _build_model_and_data(
+            datastore, config, "graph_lam", "1level",
+            vertical_propnets=True,
+        )
+
+        with torch.no_grad():
+            out_default, _ = forecaster_default(
+                init_states, forcing, boundary
+            )
+            out_prop, _ = forecaster_prop(init_states, forcing, boundary)
+
+        assert not torch.allclose(out_default, out_prop)
+
+
+#
+# Section J: Hierarchical Model Integration
+#
+
+
+class TestHierarchicalIntegration:
+    """Tests for PropagationNet in hierarchical deterministic models."""
+
+    def test_hilam_default_uses_interaction_net(self):
+        """HiLAM should use InteractionNet for all GNNs by default."""
+        datastore, config = _get_datastore_and_config("hierarchical")
+
+        _, predictor, _, _, _ = _build_model_and_data(
+            datastore, config, "hi_lam", "hierarchical"
+        )
+
+        # g2m and m2g should be InteractionNet
+        assert isinstance(predictor.g2m_gnn, InteractionNet)
+        assert not isinstance(predictor.g2m_gnn, PropagationNet)
+        assert isinstance(predictor.m2g_gnn, InteractionNet)
+        assert not isinstance(predictor.m2g_gnn, PropagationNet)
+
+        # mesh_init_gnns should all be InteractionNet
+        for gnn in predictor.mesh_init_gnns:
+            assert isinstance(gnn, InteractionNet)
+            assert not isinstance(gnn, PropagationNet)
+
+        # mesh_up_gnns (nested) should all be InteractionNet
+        for up_gnn_list in predictor.mesh_up_gnns:
+            for gnn in up_gnn_list:
+                assert isinstance(gnn, InteractionNet)
+                assert not isinstance(gnn, PropagationNet)
+
+    def test_hilam_propnet_flag_affects_vertical_gnns(self):
+        """With vertical_propnets=True, HiLAM should use PropagationNet
+        for mesh_init_gnns and mesh_up_gnns, but InteractionNet for
+        mesh_down_gnns and same-level GNNs."""
+        datastore, config = _get_datastore_and_config("hierarchical")
+
+        _, predictor, _, _, _ = _build_model_and_data(
+            datastore, config, "hi_lam", "hierarchical",
+            vertical_propnets=True,
+        )
+
+        # g2m and m2g should be PropagationNet
+        assert isinstance(predictor.g2m_gnn, PropagationNet)
+        assert isinstance(predictor.m2g_gnn, PropagationNet)
+
+        # mesh_init_gnns should be PropagationNet
+        for gnn in predictor.mesh_init_gnns:
+            assert isinstance(gnn, PropagationNet)
+
+        # mesh_up_gnns should be PropagationNet
+        for up_gnn_list in predictor.mesh_up_gnns:
+            for gnn in up_gnn_list:
+                assert isinstance(gnn, PropagationNet)
+
+        # mesh_down_gnns should ALWAYS be InteractionNet (not PropagationNet)
+        for down_gnn_list in predictor.mesh_down_gnns:
+            for gnn in down_gnn_list:
+                assert isinstance(gnn, InteractionNet)
+                assert not isinstance(gnn, PropagationNet)
+
+        # same-level GNNs should ALWAYS be InteractionNet
+        for same_gnn_list in predictor.mesh_down_same_gnns:
+            for gnn in same_gnn_list:
+                assert isinstance(gnn, InteractionNet)
+                assert not isinstance(gnn, PropagationNet)
+        for same_gnn_list in predictor.mesh_up_same_gnns:
+            for gnn in same_gnn_list:
+                assert isinstance(gnn, InteractionNet)
+                assert not isinstance(gnn, PropagationNet)
+
+    def test_hilam_propnet_forward_pass_differs(self):
+        """HiLAM with vertical_propnets=True should produce different
+        outputs than the default."""
+        datastore, config = _get_datastore_and_config("hierarchical")
+
+        torch.manual_seed(42)
+        forecaster_default, _, init_states, forcing, boundary = (
+            _build_model_and_data(
+                datastore, config, "hi_lam", "hierarchical"
+            )
+        )
+
+        torch.manual_seed(42)
+        forecaster_prop, _, _, _, _ = _build_model_and_data(
+            datastore, config, "hi_lam", "hierarchical",
+            vertical_propnets=True,
+        )
+
+        with torch.no_grad():
+            out_default, _ = forecaster_default(
+                init_states, forcing, boundary
+            )
+            out_prop, _ = forecaster_prop(init_states, forcing, boundary)
+
+        assert not torch.allclose(out_default, out_prop)
+
+    def test_hilam_parallel_propnet_flag(self):
+        """HiLAMParallel should also support vertical_propnets flag."""
+        datastore, config = _get_datastore_and_config("hierarchical")
+
+        _, predictor, _, _, _ = _build_model_and_data(
+            datastore, config, "hi_lam_parallel", "hierarchical",
+            vertical_propnets=True,
+        )
+
+        # g2m and m2g should be PropagationNet
+        assert isinstance(predictor.g2m_gnn, PropagationNet)
+        assert isinstance(predictor.m2g_gnn, PropagationNet)
+
+        # mesh_init_gnns should be PropagationNet
+        for gnn in predictor.mesh_init_gnns:
+            assert isinstance(gnn, PropagationNet)
+
+    def test_hilam_read_gnns_always_interaction_net(self):
+        """mesh_read_gnns should always be InteractionNet, even with
+        vertical_propnets=True."""
+        datastore, config = _get_datastore_and_config("hierarchical")
+
+        _, predictor, _, _, _ = _build_model_and_data(
+            datastore, config, "hi_lam", "hierarchical",
+            vertical_propnets=True,
+        )
+
+        for gnn in predictor.mesh_read_gnns:
+            assert isinstance(gnn, InteractionNet)
+            assert not isinstance(gnn, PropagationNet)
diff --git a/tests/test_training.py b/tests/test_training.py
index e8c131572..b6c3fed92 100644
--- a/tests/test_training.py
+++ b/tests/test_training.py
@@ -12,10 +12,19 @@
 from neural_lam.create_graph import create_graph_from_datastore
 from neural_lam.datastore import DATASTORES
 from neural_lam.datastore.base import BaseRegularGridDatastore
-from neural_lam.models.graph_lam import GraphLAM
+from neural_lam.models.forecaster_module import ForecasterModule
 from neural_lam.weather_dataset import WeatherDataModule
 from tests.conftest import init_datastore_example
 
+# Model architecture defaults for tests
+GRAPH = "1level"
+HIDDEN_DIM = 4
+HIDDEN_LAYERS = 1
+PROCESSOR_LAYERS = 2
+MESH_AGGR = "sum"
+NUM_PAST_FORCING_STEPS = 1
+NUM_FUTURE_FORCING_STEPS = 1
+
 
 def run_simple_training(datastore, set_output_std):
     """
@@ -41,10 +50,7 @@ def run_simple_training(datastore, set_output_std):
         max_epochs=1,
         deterministic=True,
         accelerator=device_name,
-        # XXX: `devices` has to be set to 2 otherwise
-        # neural_lam.models.ar_model.ARModel.aggregate_and_plot_metrics fails
-        # because it expects to aggregate over multiple devices
-        devices=2,
+        devices=1,
         log_every_n_steps=1,
         # use `detect_anomaly` to ensure that we don't have NaNs popping up
         # during training
@@ -73,39 +79,50 @@ def run_simple_training(datastore, set_output_std):
         num_future_forcing_steps=1,
     )
 
-    class ModelArgs:
-        output_std = set_output_std
-        loss = "mse"
-        restore_opt = False
-        n_example_pred = 1
-        # XXX: this should be superfluous when we have already defined the
-        # model object no?
-        graph = graph_name
-        hidden_dim = 4
-        hidden_layers = 1
-        processor_layers = 2
-        mesh_aggr = "sum"
-        lr = 1.0e-3
-        val_steps_to_log = [1, 3]
-        metrics_watch = []
-        num_past_forcing_steps = 1
-        num_future_forcing_steps = 1
-
-    model_args = ModelArgs()
-
     config = nlconfig.NeuralLAMConfig(
         datastore=nlconfig.DatastoreSelection(
             kind=datastore.SHORT_NAME, config_path=datastore.root_path
         )
     )
 
-    model = GraphLAM(  # noqa
-        args=model_args,
+    # Build predictor and forecaster externally, then inject into
+    # ForecasterModule
+    # First-party
+    from neural_lam.models import MODELS
+    from neural_lam.models.ar_forecaster import ARForecaster
+
+    predictor_class = MODELS["graph_lam"]
+    predictor = predictor_class(
+        config=config,
+        datastore=datastore,
+        graph_name=graph_name,
+        hidden_dim=HIDDEN_DIM,
+        hidden_layers=HIDDEN_LAYERS,
+        processor_layers=PROCESSOR_LAYERS,
+        mesh_aggr=MESH_AGGR,
+        num_past_forcing_steps=NUM_PAST_FORCING_STEPS,
+        num_future_forcing_steps=NUM_FUTURE_FORCING_STEPS,
+        output_std=set_output_std,
+    )
+    forecaster = ARForecaster(predictor, datastore)
+
+    model = ForecasterModule(
+        forecaster=forecaster,
         config=config,
         datastore=datastore,
+        loss="mse",
+        lr=1.0e-3,
+        restore_opt=False,
+        n_example_pred=1,
+        val_steps_to_log=[1, 3],
+        metrics_watch=[],
+        var_leads_metrics_watch={},
     )
     wandb.init()
-    trainer.fit(model=model, datamodule=data_module)
+    try:
+        trainer.fit(model=model, datamodule=data_module)
+    finally:
+        wandb.finish()
 
 
 @pytest.mark.parametrize("datastore_name", DATASTORES.keys())