Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions neural_lam/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ class TrainingConfig:
default_factory=OutputClamping
)

output_mode: str = "deterministic"
ensemble_size: int = 1


@dataclasses.dataclass
class NeuralLAMConfig(dataclass_wizard.JSONWizard, dataclass_wizard.YAMLWizard):
Expand Down
35 changes: 35 additions & 0 deletions neural_lam/models/ar_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ def __init__(
):
super().__init__()
self.save_hyperparameters(ignore=["datastore"])
self.output_mode = getattr(config.training, "output_mode", "deterministic")
self.ensemble_size = getattr(config.training, "ensemble_size", 1)
if self.output_mode not in ["deterministic", "ensemble"]:
raise ValueError(
f"Unsupported output_mode: {self.output_mode}"
)
self.args = args
self._datastore = datastore
num_state_vars = datastore.get_num_data_vars(category="state")
Expand Down Expand Up @@ -93,6 +99,14 @@ def __init__(

# Double grid output dim. to also output std.-dev.
self.output_std = bool(args.output_std)
if self.output_std and args.loss in ("mse", "mae", "wmse", "wmae"):
raise ValueError(
f"Prediction of standard deviation (--output_std) is "
f"incompatible with loss function '{args.loss}' as it "
f"leads to degenerate learned-variance training. Use "
f"'nll' or 'crps_gauss' instead."
)

if self.output_std:
# Pred. dim. in grid cell
self.grid_output_dim = 2 * num_state_vars
Expand Down Expand Up @@ -160,6 +174,27 @@ def __init__(
self._datastore.step_length
)

def _sample_ensemble(self, preds, pred_std=None):
if pred_std is not None:
noise = torch.randn(
self.ensemble_size,
*preds.shape,
device=preds.device,
dtype=preds.dtype,
) * pred_std.unsqueeze(0)
else:
noise = torch.randn(
self.ensemble_size,
*preds.shape,
device=preds.device,
dtype=preds.dtype,
) * 0.01
preds = preds.unsqueeze(0).repeat(
self.ensemble_size, *[1] * preds.ndim
)
preds = preds + noise
return preds

def _create_dataarray_from_tensor(
self,
tensor: torch.Tensor,
Expand Down
2 changes: 1 addition & 1 deletion neural_lam/models/base_graph_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def predict_step(self, prev_state, prev_prev_state, forcing):
# NOTE: The predicted std. is not scaled in any way here
# linter for some reason does not think softplus is callable
# pylint: disable-next=not-callable
pred_std = torch.nn.functional.softplus(pred_std_raw)
pred_std = torch.clamp(torch.nn.functional.softplus(pred_std_raw), min=1e-6)
else:
pred_delta_mean = net_output
pred_std = None
Expand Down
8 changes: 8 additions & 0 deletions tests/test_ar_model_ensemble.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import torch
from neural_lam.models.ar_model import ARModel

def test_ar_model_ensemble():
    """Ensemble sampling returns the configured number of members.

    The previous version called
    ``ARModel(None, None, None, output_mode="ensemble", ensemble_size=3)``,
    but ``ARModel.__init__`` takes ``(args, config, datastore)`` and reads
    ``output_mode``/``ensemble_size`` from ``config.training`` — the keyword
    arguments would raise a TypeError before anything was tested. Bypass
    ``__init__`` and exercise the sampling helper directly instead.
    """
    model = ARModel.__new__(ARModel)
    model.output_mode = "ensemble"
    model.ensemble_size = 3
    x = torch.randn(2, 10, 32)
    out = model._sample_ensemble(x)
    # One leading ensemble dimension, original shape preserved behind it
    assert out.shape == (3, 2, 10, 32)
117 changes: 117 additions & 0 deletions tests/test_probabilistic_forecasting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Standard library
from pathlib import Path

# Third-party
import torch
from torch.utils.data import DataLoader

# First-party
from neural_lam import config as nlconfig
from neural_lam.create_graph import create_graph_from_datastore
from neural_lam.metrics import get_metric
from neural_lam.models.graph_lam import GraphLAM
from neural_lam.weather_dataset import WeatherDataset
from tests.dummy_datastore import DummyDatastore


class ProbabilisticModelArgs:
    """Minimal stand-in for the parsed CLI arguments needed to build a
    small probabilistic (learned-variance) GraphLAM in tests."""

    # Predict a per-variable standard deviation alongside the mean; this
    # requires a probabilistic loss (ARModel.__init__ rejects mse/mae
    # losses when output_std is enabled).
    output_std = True
    loss = "nll"
    restore_opt = False
    # Disable example-prediction plotting during the test run.
    n_example_pred = 0
    # Smallest graph/model configuration that still exercises the path.
    graph = "1level"
    hidden_dim = 4
    hidden_layers = 1
    processor_layers = 1
    mesh_aggr = "sum"
    lr = 1.0e-3
    val_steps_to_log = [1]
    # No extra metric logging needed for these tests.
    metrics_watch = []
    var_leads_metrics_watch = {}
    num_past_forcing_steps = 1
    num_future_forcing_steps = 1


def test_graph_lam_probabilistic_step_produces_finite_positive_std():
    """Run one common_step of a learned-variance GraphLAM and verify the
    predicted std. dev. is strictly positive and finite, and that the
    NLL loss evaluates to a finite scalar."""
    store = DummyDatastore(n_grid_points=16, n_timesteps=8)
    graph_path = Path(store.root_path) / "graph" / "1level"

    # Build the 1-level graph on first use; reuse it afterwards.
    if not graph_path.exists():
        create_graph_from_datastore(
            datastore=store,
            output_root_path=str(graph_path),
            n_max_levels=1,
        )

    nl_config = nlconfig.NeuralLAMConfig(
        datastore=nlconfig.DatastoreSelection(
            kind=store.SHORT_NAME, config_path=""
        )
    )
    model = GraphLAM(
        args=ProbabilisticModelArgs(),
        datastore=store,
        config=nl_config,
    )

    loader = DataLoader(
        WeatherDataset(datastore=store, split="train", ar_steps=2),
        batch_size=2,
    )
    batch = next(iter(loader))

    prediction, target, pred_std, _ = model.common_step(batch)
    loss_value = torch.mean(
        model.loss(
            prediction,
            target,
            pred_std,
            mask=model.interior_mask_bool,
        )
    )

    assert prediction.shape == target.shape
    assert pred_std.shape == target.shape
    assert (pred_std > 0).all()
    assert torch.isfinite(pred_std).all()
    assert torch.isfinite(loss_value)


def test_probabilistic_metrics_are_available():
    """Both probabilistic losses must be resolvable via the metric lookup."""
    for metric_name in ("nll", "crps_gauss"):
        assert callable(get_metric(metric_name))

def test_base_graph_model_prevents_softplus_underflow_nans():
    """softplus underflows to exactly 0.0 for very negative inputs, which
    later leads to division by zero and NaN losses; the clamp applied in
    base_graph_model must keep the std strictly positive."""
    raw_values = torch.tensor([-100.0, -200.0, -1000.0])
    clamped_std = torch.clamp(
        torch.nn.functional.softplus(raw_values), min=1e-6
    )

    assert bool((clamped_std > 0).all())
    assert not bool((clamped_std == 0.0).any())

def test_ar_model_ensemble_samples_from_pred_std():
    """Ensemble sampling must be driven by the model's predicted std.

    Builds a minimal ARModel subclass whose predict_step is a no-op and
    checks that, when a pred_std tensor is supplied, the spread of the
    sampled ensemble matches it rather than the 0.01 fallback scale.
    """
    datastore = DummyDatastore(n_grid_points=16, n_timesteps=8)
    config = nlconfig.NeuralLAMConfig(
        datastore=nlconfig.DatastoreSelection(
            kind=datastore.SHORT_NAME, config_path=""
        ),
        training=nlconfig.TrainingConfig(
            output_mode="ensemble", ensemble_size=1000
        ),
    )

    # Imported locally to match the original test's late import.
    from neural_lam.models.ar_model import ARModel

    # Stub out the network forward pass; only the sampling path is tested.
    class DummyARModel(ARModel):
        def predict_step(self, prev_state, prev_prev_state, forcing):
            return prev_state, None

    # NOTE(review): the previous version passed a DummyArgs with
    # loss="mse" and output_std=True — exactly the combination
    # ARModel.__init__ now rejects with a ValueError, so the model could
    # never be constructed. Use the compatible probabilistic defaults
    # (loss="nll", output_std=True) instead.
    model = DummyARModel(
        args=ProbabilisticModelArgs(), config=config, datastore=datastore
    )

    preds = torch.ones(2, 4, 5)
    pred_std = torch.full((2, 4, 5), 5.0)

    ensemble_preds = model(preds, pred_std=pred_std)

    assert ensemble_preds.shape == (1000, 2, 4, 5)

    # Assert spread matches the model's predicted std, NOT the default
    # 0.01 fallback; with 1000 members the sample std of N(mean, 5.0)
    # draws lies well within atol=0.5.
    measured_std = torch.std(ensemble_preds, dim=0)
    assert torch.allclose(measured_std, pred_std, atol=0.5)