Adds PropagationNet GNN layer and makes it optionally usable in existing deterministic models#507

Open
Sir-Sloth-The-Lazy wants to merge 22 commits into mllam:main from Sir-Sloth-The-Lazy:refactor/batch-fold-ensemble-prep

Conversation


@Sir-Sloth-The-Lazy (Contributor) commented Mar 24, 2026

Describe your changes

Adds PropagationNet GNN layer and makes it optionally usable in existing deterministic models, as outlined in #62.

It is integrated into the existing model hierarchy from #208 and can be enabled via the vertical_propnets flag.

Depends on #208.
For changes on top of #208 only, see:
Sir-Sloth-The-Lazy/neural-lam@refactor/model-class-hierarchy-issue-49...refactor/batch-fold-ensemble-prep

Issue Link

Contributes to #62

Type of change

  • 🐛 Bug fix (non-breaking change that fixes an issue)
  • ✨ New feature (non-breaking change that adds functionality)
  • 💥 Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • 📖 Documentation (Addition or improvements to documentation)

Checklist before requesting a review

  • My branch is up-to-date with the target branch - if not update your fork with the changes from the target branch (use pull with --rebase option if possible).
  • I have performed a self-review of my code
  • For any new/modified functions/classes I have added docstrings that clearly describe their purpose, expected inputs and returned values
  • I have placed in-line comments to clarify the intent of any hard-to-understand passages of my code
  • I have updated the README to cover introduced code changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have given the PR a name that clearly describes the change, written in imperative form (context).
  • I have requested a reviewer and an assignee (assignee is responsible for merging). This applies only if you have write access to the repo, otherwise feel free to tag a maintainer to add a reviewer and assignee.

Checklist for reviewers

Each PR comes with its own improvements and flaws. The reviewer should check the following:

  • the code is readable
  • the code is well tested
  • the code is documented (including return types and parameters)
  • the code is easy to maintain

Author checklist after completed review

  • I have added a line to the CHANGELOG describing this change, in a section
    reflecting type of change (add section where missing):
    • added: when you have added new functionality
    • changed: when default behaviour of the code has been changed
    • fixes: when your contribution fixes a bug
    • maintenance: when your contribution relates to repo maintenance, e.g. CI/CD or documentation

Checklist for assignee

  • PR is up to date with the base branch
  • the tests pass
  • (if the PR is not just maintenance/bugfix) the PR is assigned to the next milestone. If it is not, propose it for a future milestone.
  • author has added an entry to the changelog (and designated the change as added, changed, fixed or maintenance)
  • Once the PR is ready to be merged, squash commits and merge the PR.

Sir-Sloth-The-Lazy and others added 22 commits February 21, 2026 17:42
- Update test_datasets.py to use ForecasterModule instead of GraphLAM
- Update test_plotting.py to use ForecasterModule instead of GraphLAM
- Fix interior_mask_bool property shape (1,) -> (N,) for correct loss masking
- Fix all_gather_cat to handle single-device runs without incorrect dim collapse
…r hierarchy

- Replace opaque argparse.Namespace with explicit keyword arguments in
  StepPredictor, BaseGraphModel, BaseHiGraphModel, GraphLAM, HiLAM,
  and HiLAMParallel __init__ methods
- Reorder methods in step_predictor.py: forward/expand_to_batch now
  appear before clamping methods
- Update all instantiation sites (train_model.py, test_training.py,
  test_prediction_model_classes.py) to pass explicit kwargs
- HiLAM helper methods (make_same/up/down_gnns) now use self.hidden_dim
  and self.hidden_layers instead of args parameter

Addresses review comments on PR mllam#208.
- Rename border to boundary in Forecaster
- Pass Forecaster object to ForecasterModule init instead of Predictor
- Remove inline imports in ForecasterModule
- Move loss-related pred_std logic fully into ForecasterModule
- Delete obsolete test_refactored_hierarchy.py
Co-authored-by: Joel Oskarsson <joel.oskarsson@outlook.com>
- Add predicts_std property to StepPredictor, Forecaster and ARForecaster
  so ForecasterModule can query the forecaster instead of taking output_std
  as a separate constructor argument
- Remove output_std parameter from ForecasterModule; use
  self._forecaster.predicts_std throughout
- Move fallback per_var_std logic out of forecast_for_batch into each
  step method so pred_std is None before fallback, enabling direct None
  checks instead of hparam checks
- Replace len(datastore.boundary_mask) with datastore.num_grid_points in
  StepPredictor to avoid relying on boundary_mask
- Move get_state_feature_weighting and ARForecaster inline imports to
  module-level imports in forecaster_module.py and train_model.py
- Fix statement ordering in StepPredictor.__init__ so register_buffer for
  grid_static_features appears directly after building the tensor
- Replace dict+loop pattern for registering state_mean/state_std buffers
  with two direct register_buffer calls
- Remove all internal Item N checklist references from comments
- Remove TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD env var hack; pass
  weights_only=False explicitly to load_from_checkpoint calls and
  weights_only=True to torch.load in test_graph_creation.py
- Add test_step_predictor_no_static_features to verify models initialise
  and run correctly when the datastore returns None for static features
- Fix graph= -> graph_name= and model.forecaster -> model._forecaster in
  tests to match current API
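
The predicts_std refactor described above can be sketched as follows. This is a minimal, torch-free illustration under stated assumptions: only the class names StepPredictor/ARForecaster and the predicts_std property come from the PR; the constructor argument and the placeholder step body are hypothetical.

```python
class StepPredictor:
    """Predicts one step; may or may not also output a per-variable std."""

    def __init__(self, output_std: bool = False):
        # Hypothetical internal flag standing in for the real model config
        self._output_std = output_std

    @property
    def predicts_std(self) -> bool:
        """Whether this predictor outputs an uncertainty estimate."""
        return self._output_std


class ARForecaster:
    """Wraps a StepPredictor and forwards its predicts_std to callers."""

    def __init__(self, predictor: StepPredictor):
        self._predictor = predictor

    @property
    def predicts_std(self) -> bool:
        return self._predictor.predicts_std

    def step(self, state):
        # pred_std stays None unless the predictor actually outputs one,
        # enabling a direct `if pred_std is None` fallback downstream
        # instead of an hparam check.
        pred_mean = state  # placeholder for the real one-step prediction
        pred_std = [1.0] * len(state) if self.predicts_std else None
        return pred_mean, pred_std
```

With this shape, ForecasterModule can query `forecaster.predicts_std` rather than receiving output_std as a separate constructor argument, which is the change the commit describes.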
…r_batch

Makes the forecasting path tolerant to batch-folded execution so that
future ensemble generation can fold (S, B) into (S*B) before calling
ARForecaster, without any changes to ARForecaster or StepPredictor.

Prediction is kept folded through the existing deterministic logging and
aggregation paths so all dim assumptions in training_step, validation_step,
and test_step remain correct. Unfolding to (*leading, T, N, F) is deferred
to ensemble-specific subclasses (e.g. EnsForecasterModule).

Adds test_fold_unfold_equivalence to confirm ARForecaster's rollout is
rank-transparent under a pre-entry fold.
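
The fold/unfold idea can be illustrated with a torch-free sketch; the real code operates on torch tensors via reshape, and the function names here are illustrative, not from the PR. Nested lists stand in for tensors with leading (S, B) dims.

```python
def fold_sample_batch(x):
    """Fold the leading (S, B) dims of a nested-list 'tensor' into (S*B,)."""
    return [batch_elem for sample in x for batch_elem in sample]


def unfold_sample_batch(x, n_samples):
    """Invert the fold: (S*B, ...) back to (S, B, ...)."""
    batch_size = len(x) // n_samples
    return [x[i * batch_size:(i + 1) * batch_size] for i in range(n_samples)]


def rollout(x, n_steps):
    """Stand-in for a rank-transparent rollout: acts per leading element,
    so it cannot tell whether the leading dim is B or S*B."""
    return [[v + n_steps for v in elem] for elem in x]
```

Because the rollout only touches the leading dim element-wise, running it on the folded input and unfolding afterwards gives the same result as running it per sample, which is what test_fold_unfold_equivalence checks in spirit.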
…stic models

- Port PropagationNet as InteractionNet subclass (mean aggr, sender residual
  in messages, aggregation residual in forward)
- Add --vertical_propnets CLI flag to select PropagationNet for grid-mesh
  and vertical message passing edges
- Wire flag through model hierarchy: BaseGraphModel (g2m/m2g),
  BaseHiGraphModel (mesh init), HiLAM (up GNNs)
- Add 13 tests covering unit behavior and backward compatibility
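
The three distinguishing features listed above (mean aggregation, sender residual in messages, aggregation residual in the node update) can be sketched torch-free; the real layer is an InteractionNet subclass with edge and node MLPs, all of which are replaced by identity stand-ins here.

```python
def propagation_step(node_feats, edges):
    """One PropagationNet-style message-passing step.

    node_feats: dict node_id -> scalar feature (scalars for simplicity).
    edges: list of (sender, receiver) pairs.
    """
    # Messages: edge-MLP output (identity stand-in) plus a sender residual.
    messages = {}
    for s, r in edges:
        edge_out = node_feats[s]  # stand-in for edge_mlp(h_s, h_r)
        messages.setdefault(r, []).append(edge_out + node_feats[s])

    # Node update: MEAN aggregation (not sum), then an aggregation
    # residual added to the receiver's current feature.
    new_feats = dict(node_feats)
    for r, msgs in messages.items():
        aggr = sum(msgs) / len(msgs)
        new_feats[r] = node_feats[r] + aggr  # node MLP omitted in this sketch
    return new_feats
```

Nodes that receive no messages keep their features unchanged, which matches the residual formulation.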

Sir-Sloth-The-Lazy commented Mar 24, 2026

@joeloskarsson @sadamov @observingClouds please have a look at whether this qualifies as the next step in ensemble prep 😄. I would be grateful for your feedback!

@Debadri-das

@joeloskarsson @sadamov requesting your further review on this PR!
