Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
e7eb060
refactor: add Forecaster and StepPredictor abstract base classes
Sir-Sloth-The-Lazy Feb 21, 2026
4aa3112
refactor: add ARForecaster implementing auto-regressive unrolling
Sir-Sloth-The-Lazy Feb 21, 2026
ea5f185
refactor: add ForecasterModule as pl.LightningModule wrapper
Sir-Sloth-The-Lazy Feb 21, 2026
6892997
refactor: change BaseGraphModel base class to StepPredictor
Sir-Sloth-The-Lazy Feb 21, 2026
0bd2f5e
refactor: update train_model.py and __init__.py for new hierarchy
Sir-Sloth-The-Lazy Feb 21, 2026
1781aba
test: add tests for refactored model hierarchy and checkpoint loading
Sir-Sloth-The-Lazy Feb 21, 2026
3e1011e
refactor: archive legacy monolithic ar_model.py
Sir-Sloth-The-Lazy Feb 21, 2026
4e2b722
fix: update tests for new model hierarchy and fix pre-existing bugs
Sir-Sloth-The-Lazy Feb 21, 2026
5305a65
Removed ar_model.py
Sir-Sloth-The-Lazy Feb 24, 2026
b013643
Merge branch 'main' into refactor/model-class-hierarchy-issue-49
Sir-Sloth-The-Lazy Feb 27, 2026
3ba80f9
refactor: replace args namespace with explicit params in StepPredictor
Sir-Sloth-The-Lazy Mar 3, 2026
b086ad1
refactor: address remaining review comments on PR #208
Sir-Sloth-The-Lazy Mar 3, 2026
211a2ce
StepPredictor now allows static features
Sir-Sloth-The-Lazy Mar 6, 2026
be58b66
ForecasterModule now queries state_mean and state_std
Sir-Sloth-The-Lazy Mar 6, 2026
d02ca1d
StepPredictor and moved the datastore retrieval directly into BaseGra…
Sir-Sloth-The-Lazy Mar 6, 2026
00083c5
Broad namespace remap and added regression test
Sir-Sloth-The-Lazy Mar 7, 2026
c81958c
Update neural_lam/models/forecaster_module.py
Sir-Sloth-The-Lazy Mar 18, 2026
a65488b
refactor: address PR review comments on model class hierarchy
Sir-Sloth-The-Lazy Mar 18, 2026
e659f7a
refactor: fold extra leading dims into effective batch in forecast_fo…
Sir-Sloth-The-Lazy Mar 19, 2026
f78b3c4
fix: return folded target_in to match folded prediction in forecast_f…
Sir-Sloth-The-Lazy Mar 19, 2026
97b0cfc
Add PropagationNet GNN layer with optional integration into determini…
Sir-Sloth-The-Lazy Mar 22, 2026
b29a862
integration test for the propagation_net for deterministic models
Sir-Sloth-The-Lazy Mar 22, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions neural_lam/interaction_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,83 @@ def aggregate(self, inputs, index, ptr, dim_size):
return aggr, inputs


class PropagationNet(InteractionNet):
    """
    Alternative version of InteractionNet that incentivizes the propagation
    of information from sender nodes to receivers.

    Compared to InteractionNet:
    * each message carries a residual connection to the sender node
      (see message()), so sender state is pushed towards receivers
    * the receiver residual is applied to the aggregated messages rather
      than to the previous receiver representation (see forward())
    * aggregation is always "mean" — the ``aggr`` constructor argument is
      accepted only for signature compatibility and is silently overridden
    """

    # pylint: disable=arguments-differ
    # Disable to override args/kwargs from superclass

    def __init__(
        self,
        edge_index,
        input_dim,
        update_edges=True,
        hidden_layers=1,
        hidden_dim=None,
        edge_chunk_sizes=None,
        aggr_chunk_sizes=None,
        aggr="sum",  # NOTE: ignored — "mean" is always passed to super()
    ):
        # Use mean aggregation in propagation version to avoid instability
        super().__init__(
            edge_index,
            input_dim,
            update_edges=update_edges,
            hidden_layers=hidden_layers,
            hidden_dim=hidden_dim,
            edge_chunk_sizes=edge_chunk_sizes,
            aggr_chunk_sizes=aggr_chunk_sizes,
            aggr="mean",
        )

def forward(self, send_rep, rec_rep, edge_rep):
    """
    Apply propagation network to update the representations of receiver
    nodes, and optionally the edge representations.

    send_rep: (N_send, d_h), vector representations of sender nodes
    rec_rep: (N_rec, d_h), vector representations of receiver nodes
    edge_rep: (M, d_h), vector representations of edges used

    Returns:
    rec_rep: (N_rec, d_h), updated vector representations of receiver
    nodes
    (optionally) edge_rep: (M, d_h), updated vector representations
    of edges
    """
    # Always concatenate to [rec_nodes, send_nodes] for propagation,
    # but only aggregate to rec_nodes
    # NOTE(review): this ordering assumes self.edge_index addresses
    # senders at an offset of N_rec — confirm against how the superclass
    # constructs/remaps edge_index
    node_reps = torch.cat((rec_rep, send_rep), dim=-2)
    # propagate() yields a pair here because the overridden aggregate()
    # returns (aggr, inputs): per-receiver aggregated messages and the
    # per-edge updates
    edge_rep_aggr, edge_diff = self.propagate(
        self.edge_index, x=node_reps, edge_attr=edge_rep
    )
    rec_diff = self.aggr_mlp(
        torch.cat((rec_rep, edge_rep_aggr), dim=-1)
    )

    # Residual connections: the receiver update is anchored to the
    # aggregated messages rather than to the previous receiver state,
    # so sender information (carried in the messages) flows through
    # even when rec_diff is small
    rec_rep = edge_rep_aggr + rec_diff  # residual is to aggregation

    if self.update_edges:
        edge_rep = edge_rep + edge_diff
        return rec_rep, edge_rep

    return rec_rep
Comment on lines +181 to +198
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is quite a lot of code repeated here from InteractionNet's forward. Could this be refactored to avoid as much repetition, while keeping some clarity of the differences between the two classes?


def message(self, x_j, x_i, edge_attr):
    """
    Compute messages from node j to node i.

    Each message is the sender representation plus a learned correction
    from the edge MLP, so sender state is actively propagated to edges.
    """
    mlp_input = torch.cat((edge_attr, x_j, x_i), dim=-1)
    edge_update = self.edge_mlp(mlp_input)
    # Residual connection is to the sender node, propagating information
    # to the edge
    return x_j + edge_update


class SplitMLPs(nn.Module):
"""
Module that feeds chunks of input through different MLPs.
Expand Down
9 changes: 7 additions & 2 deletions neural_lam/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
# Local
from .base_graph_model import BaseGraphModel
from .base_hi_graph_model import BaseHiGraphModel
from .forecaster_module import ForecasterModule
from .graph_lam import GraphLAM
from .hi_lam import HiLAM
from .hi_lam_parallel import HiLAMParallel

# Registry mapping a model-name string to the model class.
# BaseGraphModel, BaseHiGraphModel and ForecasterModule are re-exported
# above for importers, but only concrete models are registered here.
# NOTE(review): presumably keyed on a CLI argument in train_model.py —
# verify against the caller
MODELS = {
    "graph_lam": GraphLAM,
    "hi_lam": HiLAM,
    "hi_lam_parallel": HiLAMParallel,
}
84 changes: 84 additions & 0 deletions neural_lam/models/ar_forecaster.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Third-party
import torch

# Local
from ..datastore import BaseDatastore
from .forecaster import Forecaster
from .step_predictor import StepPredictor


class ARForecaster(Forecaster):
    """
    Forecaster that produces a forecast by auto-regressive unrolling,
    applying a StepPredictor once per forecast step.
    """

    def __init__(self, predictor: StepPredictor, datastore: BaseDatastore):
        """
        predictor: StepPredictor used for each single-step prediction
        datastore: BaseDatastore providing the boundary mask
        """
        super().__init__()
        self.predictor = predictor

        # Boundary/interior masks belong to the forecaster, not the
        # predictor; unsqueezed front and back for broadcasting over
        # batch and feature dimensions
        mask_values = torch.tensor(
            datastore.boundary_mask.values, dtype=torch.float32
        )
        boundary_mask = mask_values.unsqueeze(0).unsqueeze(-1)
        self.register_buffer(
            "boundary_mask", boundary_mask, persistent=False
        )
        self.register_buffer(
            "interior_mask", 1.0 - self.boundary_mask, persistent=False
        )

    @property
    def predicts_std(self) -> bool:
        """Whether the wrapped predictor outputs per-step uncertainties."""
        return self.predictor.predicts_std

    def forward(
        self,
        init_states: torch.Tensor,
        forcing_features: torch.Tensor,
        boundary_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Unroll the autoregressive model.
        boundary_states is used ONLY to overwrite boundary nodes at each step.
        The interior prediction at step i must not depend on
        boundary_states[:, i] in any other way.
        """
        # Two most recent states condition each single-step prediction
        state_older = init_states[:, 0]
        state_newer = init_states[:, 1]
        num_steps = forcing_features.shape[1]

        step_predictions = []
        step_stds = []

        for step_idx in range(num_steps):
            pred_state, pred_std = self.predictor(
                state_newer, state_older, forcing_features[:, step_idx]
            )

            # Boundary nodes take the supplied boundary state; interior
            # nodes take the prediction
            new_state = (
                self.boundary_mask * boundary_states[:, step_idx]
                + self.interior_mask * pred_state
            )

            step_predictions.append(new_state)
            if pred_std is not None:
                step_stds.append(pred_std)

            # Shift the conditioning window forward one step
            state_older, state_newer = state_newer, new_state

        prediction = torch.stack(step_predictions, dim=1)
        # If the predictor outputs std, stack it; otherwise return None so
        # ForecasterModule can substitute the constant per_var_std
        pred_std = torch.stack(step_stds, dim=1) if step_stds else None

        return prediction, pred_std
Loading