diff --git a/neural_lam/config.py b/neural_lam/config.py index f4195ec36..84a41df84 100644 --- a/neural_lam/config.py +++ b/neural_lam/config.py @@ -103,6 +103,9 @@ class TrainingConfig: default_factory=OutputClamping ) + output_mode: str = "deterministic" + ensemble_size: int = 1 + @dataclasses.dataclass class NeuralLAMConfig(dataclass_wizard.JSONWizard, dataclass_wizard.YAMLWizard): diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index a411a3afc..fed218c8f 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -38,6 +38,12 @@ def __init__( ): super().__init__() self.save_hyperparameters(ignore=["datastore"]) + self.output_mode = getattr(config.training, "output_mode", "deterministic") + self.ensemble_size = getattr(config.training, "ensemble_size", 1) + if self.output_mode not in ["deterministic", "ensemble"]: + raise ValueError( + f"Unsupported output_mode: {self.output_mode}" + ) self.args = args self._datastore = datastore num_state_vars = datastore.get_num_data_vars(category="state") @@ -93,6 +99,13 @@ def __init__( # Double grid output dim. to also output std.-dev. self.output_std = bool(args.output_std) + if self.output_std and args.loss in ("mse", "mae", "wmse", "wmae"): + raise ValueError( + f"Prediction of standard deviation (--output_std) is " + f"incompatible with loss function '{args.loss}' as it " + f"leads to degenerate learned-variance training. Use " + f"'nll' or 'crps_gauss' instead." + ) if self.output_std: # Pred. dim. in grid cell self.grid_output_dim = 2 * num_state_vars @@ -160,6 +173,46 @@ def __init__( self._datastore.step_length ) + @property + def is_ensemble(self): + return self.output_mode == "ensemble" and self.ensemble_size > 1 + + def forward(self, preds, pred_std=None): + """ + Convert deterministic predictions to ensemble samples when requested. + + Parameters + ---------- + preds : torch.Tensor + Deterministic prediction tensor. + pred_std : torch.Tensor, optional + Predicted standard deviation tensor used to scale ensemble noise. + """ + if self.is_ensemble: + preds = self._sample_ensemble(preds, pred_std=pred_std) + return preds + + def _sample_ensemble(self, preds, pred_std=None): + if pred_std is not None: + noise = torch.randn( + self.ensemble_size, + *preds.shape, + device=preds.device, + dtype=preds.dtype, + ) * pred_std.unsqueeze(0) + else: + noise = torch.randn( + self.ensemble_size, + *preds.shape, + device=preds.device, + dtype=preds.dtype, + ) * 0.01 + preds = preds.unsqueeze(0).repeat( + self.ensemble_size, *[1] * preds.ndim + ) + preds = preds + noise + return preds + def _create_dataarray_from_tensor( self, tensor: torch.Tensor, @@ -325,11 +378,19 @@ def all_gather_cat(self, tensor_to_gather): tensor_to_gather: (d1, d2, ...), distributed over K ranks - returns: - - single-device strategies: (d1, d2, ...) - - multi-device strategies: (K*d1, d2, ...) + returns: (K*d1, d2, ...) """ - gathered = self.all_gather(tensor_to_gather) + trainer = getattr(self, "_trainer", None) + if trainer is not None and getattr(trainer, "world_size", 1) == 1: + return tensor_to_gather + + try: + gathered = self.all_gather(tensor_to_gather) + except RuntimeError: + # Lightning modules without an attached trainer cannot call + # `all_gather`; in that case there is nothing to aggregate. + return tensor_to_gather + # all_gather adds a leading dim (K,) only on multi-device runs; # on single-device it returns the tensor unchanged. if gathered.dim() > tensor_to_gather.dim(): @@ -610,6 +671,7 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name): Return: log_dict: dict with everything to log for given metric """ log_dict = {} + scalar_log_dict = {} metric_fig = vis.plot_error_map( errors=metric_tensor, datastore=self._datastore, @@ -636,9 +698,9 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name): var_name = var_names[var_i] for step in timesteps: key = f"{full_log_name}_{var_name}_step_{step}" - log_dict[key] = metric_tensor[step - 1, var_i] + scalar_log_dict[key] = metric_tensor[step - 1, var_i] - return log_dict + return log_dict, scalar_log_dict def aggregate_and_plot_metrics(self, metrics_dict, prefix): """ @@ -649,6 +711,7 @@ def aggregate_and_plot_metrics(self, metrics_dict, prefix): prefix: string, prefix to use for logging """ log_dict = {} + scalar_log_dict = {} for metric_name, metric_val_list in metrics_dict.items(): metric_tensor = self.all_gather_cat( torch.cat(metric_val_list, dim=0) @@ -666,11 +729,11 @@ def aggregate_and_plot_metrics(self, metrics_dict, prefix): # NOTE: we here assume rescaling for all metrics is linear metric_rescaled = metric_tensor_averaged * self.state_std # (pred_steps, d_f) - log_dict.update( - self.create_metric_log_dict( - metric_rescaled, prefix, metric_name - ) + figure_logs, metric_scalars = self.create_metric_log_dict( + metric_rescaled, prefix, metric_name ) + log_dict.update(figure_logs) + scalar_log_dict.update(metric_scalars) # Ensure that log_dict has structure for # logging as dict(str, plt.Figure) @@ -683,6 +746,14 @@ def aggregate_and_plot_metrics(self, metrics_dict, prefix): current_epoch = self.trainer.current_epoch + if scalar_log_dict: + self.log_dict( + scalar_log_dict, + on_step=False, + on_epoch=True, + sync_dist=False, + ) + for key, figure in log_dict.items(): # For other loggers than wandb, add epoch to key. # Wandb can log multiple images to the same key, while other diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py index fd38a2e67..9e591bc91 100644 --- a/neural_lam/models/base_graph_model.py +++ b/neural_lam/models/base_graph_model.py @@ -351,7 +351,7 @@ def predict_step(self, prev_state, prev_prev_state, forcing): # NOTE: The predicted std. is not scaled in any way here # linter for some reason does not think softplus is callable # pylint: disable-next=not-callable - pred_std = torch.nn.functional.softplus(pred_std_raw) + pred_std = torch.clamp(torch.nn.functional.softplus(pred_std_raw), min=1e-6) else: pred_delta_mean = net_output pred_std = None diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 5547fdd4e..4b278ac90 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -607,10 +607,12 @@ def _is_listlike(obj): da_datastore_state = getattr(self, f"da_{category}") da_grid_index = da_datastore_state.grid_index - da_state_feature = da_datastore_state.state_feature + da_feature = da_datastore_state[f"{category}_feature"] + if tensor.shape[-1] != da_feature.size: + da_feature = np.arange(tensor.shape[-1]) coords = { - f"{category}_feature": da_state_feature, + f"{category}_feature": da_feature, "grid_index": da_grid_index, } if add_time_as_dim: diff --git a/tests/test_ar_model.py b/tests/test_ar_model.py new file mode 100644 index 000000000..e5a7d3d1c --- /dev/null +++ b/tests/test_ar_model.py @@ -0,0 +1,100 @@ +# Third-party +import torch +import pytest + +# First-party +from neural_lam import config as nlconfig +from neural_lam.models.ar_model import ARModel +from tests.dummy_datastore import DummyDatastore + + +class DummyArgs: + output_std = False + loss = "mse" + restore_opt = False + n_example_pred = 0 + num_past_forcing_steps = 1 + num_future_forcing_steps = 1 + lr = 1.0e-3 + val_steps_to_log = [1] + metrics_watch = [] + var_leads_metrics_watch = {} + + +class DummyARModel(ARModel): + def predict_step(self, prev_state, prev_prev_state, forcing): + del prev_prev_state, forcing + return prev_state, None + + +def test_ar_model_initializes_core_training_state(): + datastore = DummyDatastore(n_grid_points=16, n_timesteps=8) + config = nlconfig.NeuralLAMConfig( + datastore=nlconfig.DatastoreSelection( + kind=datastore.SHORT_NAME, config_path="" + ) + ) + + model = DummyARModel(args=DummyArgs(), config=config, datastore=datastore) + + assert model.grid_static_features.shape == (16, 1) + assert model.num_grid_nodes == 16 + assert model.grid_dim == 17 + assert model.grid_output_dim == datastore.get_num_data_vars("state") + assert model.boundary_mask.shape == (16, 1) + assert model.interior_mask.shape == (16, 1) + assert model.feature_weights.shape == (datastore.get_num_data_vars("state"),) + + +def test_ar_model_forward_samples_ensemble_when_enabled(): + datastore = DummyDatastore(n_grid_points=16, n_timesteps=8) + config = nlconfig.NeuralLAMConfig( + datastore=nlconfig.DatastoreSelection( + kind=datastore.SHORT_NAME, config_path="" + ), + training=nlconfig.TrainingConfig( + output_mode="ensemble", + ensemble_size=3, + ), + ) + model = DummyARModel(args=DummyArgs(), config=config, datastore=datastore) + + preds = model(torch.ones(2, 4, 5)) + + assert preds.shape == (3, 2, 4, 5) + + +def test_all_gather_cat_returns_input_without_trainer(): + datastore = DummyDatastore(n_grid_points=16, n_timesteps=8) + config = nlconfig.NeuralLAMConfig( + datastore=nlconfig.DatastoreSelection( + kind=datastore.SHORT_NAME, config_path="" + ) + ) + model = DummyARModel(args=DummyArgs(), config=config, datastore=datastore) + tensor = torch.randn(2, 3, 4) + + gathered = model.all_gather_cat(tensor) + + assert torch.equal(gathered, tensor) + + +def test_ar_model_raises_error_on_incompatible_loss_and_output_std(): + datastore = DummyDatastore(n_grid_points=16, n_timesteps=8) + config = nlconfig.NeuralLAMConfig( + datastore=nlconfig.DatastoreSelection( + kind=datastore.SHORT_NAME, config_path="" + ) + ) + + class IncompatibleArgs(DummyArgs): + output_std = True + loss = "wmse" + + with pytest.raises( + ValueError, match="is incompatible with loss function 'wmse'" + ): + DummyARModel( + args=IncompatibleArgs(), config=config, datastore=datastore + ) + diff --git a/tests/test_datasets.py b/tests/test_datasets.py index dd863b657..623635ad4 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -13,7 +13,7 @@ from neural_lam.datastore import DATASTORES from neural_lam.datastore.base import BaseRegularGridDatastore from neural_lam.models.graph_lam import GraphLAM -from neural_lam.weather_dataset import WeatherDataset +from neural_lam.weather_dataset import WeatherDataModule, WeatherDataset from tests.conftest import init_datastore_example from tests.dummy_datastore import DummyDatastore, EnsembleDummyDatastore @@ -391,6 +391,7 @@ def test_ensemble_forcing_without_member_dim_is_shared(): init_states_0, _, forcing_0, target_times_0 = dataset[0] init_states_1, _, forcing_1, target_times_1 = dataset[1] + # Adjacent indices correspond to same time, different member assert torch.equal(target_times_0, target_times_1) assert not torch.equal(init_states_0, init_states_1) assert torch.equal(forcing_0, forcing_1) @@ -496,3 +497,28 @@ def get_dataarray(self, category, split, **kwargs): assert ( forcing.shape[-1] == 0 ), "Expected zero forcing features when forcing is None" + + +def test_datamodule_dataloaders_with_zero_workers(): + """`persistent_workers=True` is invalid when `num_workers=0`.""" + datastore = DummyDatastore(n_timesteps=10) + data_module = WeatherDataModule( + datastore=datastore, + ar_steps_train=2, + ar_steps_eval=2, + batch_size=2, + num_workers=0, + ) + + data_module.setup(stage=None) + + train_batch = next(iter(data_module.train_dataloader())) + val_batch = next(iter(data_module.val_dataloader())) + test_batch = next(iter(data_module.test_dataloader())) + + for batch in (train_batch, val_batch, test_batch): + init_states, target_states, forcing, target_times = batch + assert init_states.ndim == 4 + assert target_states.ndim == 4 + assert forcing.ndim == 4 + assert target_times.ndim == 2 diff --git a/tests/test_probabilistic_forecasting.py b/tests/test_probabilistic_forecasting.py new file mode 100644 index 000000000..46fdb9170 --- /dev/null +++ b/tests/test_probabilistic_forecasting.py @@ -0,0 +1,117 @@ +# Standard library +from pathlib import Path + +# Third-party +import torch +from torch.utils.data import DataLoader + +# First-party +from neural_lam import config as nlconfig +from neural_lam.create_graph import create_graph_from_datastore +from neural_lam.metrics import get_metric +from neural_lam.models.graph_lam import GraphLAM +from neural_lam.weather_dataset import WeatherDataset +from tests.dummy_datastore import DummyDatastore + + +class ProbabilisticModelArgs: + output_std = True + loss = "nll" + restore_opt = False + n_example_pred = 0 + graph = "1level" + hidden_dim = 4 + hidden_layers = 1 + processor_layers = 1 + mesh_aggr = "sum" + lr = 1.0e-3 + val_steps_to_log = [1] + metrics_watch = [] + var_leads_metrics_watch = {} + num_past_forcing_steps = 1 + num_future_forcing_steps = 1 + + +def test_graph_lam_probabilistic_step_produces_finite_positive_std(): + datastore = DummyDatastore(n_grid_points=16, n_timesteps=8) + graph_dir_path = Path(datastore.root_path) / "graph" / "1level" + + if not graph_dir_path.exists(): + create_graph_from_datastore( + datastore=datastore, + output_root_path=str(graph_dir_path), + n_max_levels=1, + ) + + config = nlconfig.NeuralLAMConfig( + datastore=nlconfig.DatastoreSelection( + kind=datastore.SHORT_NAME, config_path="" + ) + ) + model = GraphLAM( + args=ProbabilisticModelArgs(), + datastore=datastore, + config=config, + ) + + dataset = WeatherDataset(datastore=datastore, split="train", ar_steps=2) + batch = next(iter(DataLoader(dataset, batch_size=2))) + + prediction, target, pred_std, _ = model.common_step(batch) + loss = torch.mean( + model.loss( + prediction, + target, + pred_std, + mask=model.interior_mask_bool, + ) + ) + + assert prediction.shape == target.shape + assert pred_std.shape == target.shape + assert torch.all(pred_std > 0) + assert torch.isfinite(pred_std).all() + assert torch.isfinite(loss) + + +def test_probabilistic_metrics_are_available(): + assert callable(get_metric("nll")) + assert callable(get_metric("crps_gauss")) + +def test_base_graph_model_prevents_softplus_underflow_nans(): + # Because native softplus evaluates to exactly 0.0 at -100, + # it causes division by zero -> NaN loss crashes. + pred_std_raw = torch.tensor([-100.0, -200.0, -1000.0]) + pred_std = torch.clamp(torch.nn.functional.softplus(pred_std_raw), min=1e-6) + + assert torch.all(pred_std > 0) + assert not torch.any(pred_std == 0.0) + +def test_ar_model_ensemble_samples_from_pred_std(): + datastore = DummyDatastore(n_grid_points=16, n_timesteps=8) + config = nlconfig.NeuralLAMConfig( + datastore=nlconfig.DatastoreSelection(kind=datastore.SHORT_NAME, config_path=""), + training=nlconfig.TrainingConfig(output_mode="ensemble", ensemble_size=1000), + ) + + from neural_lam.models.ar_model import ARModel + class DummyArgs(ProbabilisticModelArgs): + loss = "mse" + output_std = True + + class DummyARModel(ARModel): + def predict_step(self, prev_state, prev_prev_state, forcing): + return prev_state, None + + model = DummyARModel(args=DummyArgs(), config=config, datastore=datastore) + + preds = torch.ones(2, 4, 5) + pred_std = torch.full((2, 4, 5), 5.0) + + ensemble_preds = model(preds, pred_std=pred_std) + + assert ensemble_preds.shape == (1000, 2, 4, 5) + + # Assert variance matches the model's predicted std, NOT the default 0.01 fallback + measured_std = torch.std(ensemble_preds, dim=0) + assert torch.allclose(measured_std, pred_std, atol=0.5) diff --git a/tests/test_training.py b/tests/test_training.py index 972740695..2a9a0d315 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -42,10 +42,7 @@ def run_simple_training(datastore, set_output_std): max_epochs=1, deterministic=True, accelerator=device_name, - # XXX: `devices` has to be set to 2 otherwise - # neural_lam.models.ar_model.ARModel.aggregate_and_plot_metrics fails - # because it expects to aggregate over multiple devices - devices=2, + devices=1, log_every_n_steps=1, # use `detect_anomaly` to ensure that we don't have NaNs popping up # during training @@ -76,7 +73,7 @@ def run_simple_training(datastore, set_output_std): class ModelArgs: output_std = set_output_std - loss = "mse" + loss = "nll" if set_output_std else "mse" restore_opt = False n_example_pred = 1 # XXX: this should be superfluous when we have already defined the