-
Notifications
You must be signed in to change notification settings - Fork 255
Adds PropagationNet GNN layer and makes it optionally usable in existing deterministic models
#507
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
Sir-Sloth-The-Lazy
wants to merge
22
commits into
mllam:main
Choose a base branch
from
Sir-Sloth-The-Lazy:refactor/batch-fold-ensemble-prep
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
22 commits
Select commit
Hold shift + click to select a range
e7eb060
refactor: add Forecaster and StepPredictor abstract base classes
Sir-Sloth-The-Lazy 4aa3112
refactor: add ARForecaster implementing auto-regressive unrolling
Sir-Sloth-The-Lazy ea5f185
refactor: add ForecasterModule as pl.LightningModule wrapper
Sir-Sloth-The-Lazy 6892997
refactor: change BaseGraphModel base class to StepPredictor
Sir-Sloth-The-Lazy 0bd2f5e
refactor: update train_model.py and __init__.py for new hierarchy
Sir-Sloth-The-Lazy 1781aba
test: add tests for refactored model hierarchy and checkpoint loading
Sir-Sloth-The-Lazy 3e1011e
refactor: archive legacy monolithic ar_model.py
Sir-Sloth-The-Lazy 4e2b722
fix: update tests for new model hierarchy and fix pre-existing bugs
Sir-Sloth-The-Lazy 5305a65
Removed ar_model.py
Sir-Sloth-The-Lazy b013643
Merge branch 'main' into refactor/model-class-hierarchy-issue-49
Sir-Sloth-The-Lazy 3ba80f9
refactor: replace args namespace with explicit params in StepPredicto…
Sir-Sloth-The-Lazy b086ad1
refactor: address remaining review comments on PR #208
Sir-Sloth-The-Lazy 211a2ce
StepPredictor now allows static features
Sir-Sloth-The-Lazy be58b66
ForecasterModule now queries state_mean and state_std
Sir-Sloth-The-Lazy d02ca1d
StepPredictor and moved the datastore retrieval directly into BaseGra…
Sir-Sloth-The-Lazy 00083c5
Broad namespace remap and added regression test
Sir-Sloth-The-Lazy c81958c
Update neural_lam/models/forecaster_module.py
Sir-Sloth-The-Lazy a65488b
refactor: address PR review comments on model class hierarchy
Sir-Sloth-The-Lazy e659f7a
refactor: fold extra leading dims into effective batch in forecast_fo…
Sir-Sloth-The-Lazy f78b3c4
fix: return folded target_in to match folded prediction in forecast_f…
Sir-Sloth-The-Lazy 97b0cfc
Add PropagationNet GNN layer with optional integration into determini…
Sir-Sloth-The-Lazy b29a862
integration test for the propagation_net for deterministic models
Sir-Sloth-The-Lazy File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,11 @@ | ||
| # Local | ||
| from .base_graph_model import BaseGraphModel | ||
| from .base_hi_graph_model import BaseHiGraphModel | ||
| from .forecaster_module import ForecasterModule | ||
| from .graph_lam import GraphLAM | ||
| from .hi_lam import HiLAM | ||
| from .hi_lam_parallel import HiLAMParallel | ||
|
|
||
| MODELS = { | ||
| "graph_lam": GraphLAM, | ||
| "hi_lam": HiLAM, | ||
| "hi_lam_parallel": HiLAMParallel, | ||
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,84 @@ | ||
| # Third-party | ||
| import torch | ||
|
|
||
| # Local | ||
| from ..datastore import BaseDatastore | ||
| from .forecaster import Forecaster | ||
| from .step_predictor import StepPredictor | ||
|
|
||
|
|
||
| class ARForecaster(Forecaster): | ||
| """ | ||
| Subclass of Forecaster that uses an auto-regressive strategy to | ||
| unroll a forecast. Makes use of a StepPredictor at each AR step. | ||
| """ | ||
|
|
||
| def __init__(self, predictor: StepPredictor, datastore: BaseDatastore): | ||
| super().__init__() | ||
| self.predictor = predictor | ||
|
|
||
| # Register boundary/interior masks on the forecaster, not the predictor | ||
| boundary_mask = ( | ||
| torch.tensor(datastore.boundary_mask.values, dtype=torch.float32) | ||
| .unsqueeze(0) | ||
| .unsqueeze(-1) | ||
| ) | ||
| self.register_buffer("boundary_mask", boundary_mask, persistent=False) | ||
| self.register_buffer( | ||
| "interior_mask", 1.0 - self.boundary_mask, persistent=False | ||
| ) | ||
|
|
||
| @property | ||
| def predicts_std(self) -> bool: | ||
| return self.predictor.predicts_std | ||
|
|
||
| def forward( | ||
| self, | ||
| init_states: torch.Tensor, | ||
| forcing_features: torch.Tensor, | ||
| boundary_states: torch.Tensor, | ||
| ) -> tuple[torch.Tensor, torch.Tensor]: | ||
| """ | ||
| Unroll the autoregressive model. | ||
| boundary_states is used ONLY to overwrite boundary nodes at each step. | ||
| The interior prediction at step i must not depend on | ||
| boundary_states[:, i] in any other way. | ||
| """ | ||
|
|
||
| prev_prev_state = init_states[:, 0] | ||
| prev_state = init_states[:, 1] | ||
| prediction_list = [] | ||
| pred_std_list = [] | ||
| pred_steps = forcing_features.shape[1] | ||
|
|
||
| for i in range(pred_steps): | ||
| forcing = forcing_features[:, i] | ||
| boundary_state = boundary_states[:, i] | ||
|
|
||
| pred_state, pred_std = self.predictor( | ||
| prev_state, prev_prev_state, forcing | ||
| ) | ||
|
|
||
| # Overwrite boundary with true state using ARForecaster's mask | ||
| new_state = ( | ||
| self.boundary_mask * boundary_state | ||
| + self.interior_mask * pred_state | ||
| ) | ||
|
|
||
| prediction_list.append(new_state) | ||
| if pred_std is not None: | ||
| pred_std_list.append(pred_std) | ||
|
|
||
| # Update conditioning states | ||
| prev_prev_state = prev_state | ||
| prev_state = new_state | ||
|
|
||
| prediction = torch.stack(prediction_list, dim=1) | ||
| # If predictor outputs std, stack it; otherwise return None so | ||
| # ForecasterModule can substitute the constant per_var_std | ||
| if pred_std_list: | ||
| pred_std = torch.stack(pred_std_list, dim=1) | ||
| else: | ||
| pred_std = None | ||
|
|
||
| return prediction, pred_std |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is quite a lot of repeated code from the InteractionNets forward here. Could this be refactored to avoid as much repetition, while keeping some clarity of the differences between the two classes?