From ccc452b63035c76e0933a522870634a31de406a5 Mon Sep 17 00:00:00 2001 From: "anthropic-code-agent[bot]" <242468646+Claude@users.noreply.github.com> Date: Tue, 24 Mar 2026 13:24:25 +0000 Subject: [PATCH] refactor: model class hierarchy into Forecaster/StepPredictor layers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Refactors the monolithic ARModel class into a composable hierarchy: - ForecasterModule (pl.LightningModule): Training loop, metrics, plotting - ARForecaster (nn.Module): Auto-regressive unrolling - StepPredictor (nn.Module): Single-step prediction interface - BaseGraphModel inherits StepPredictor instead of ARModel This separation enables: - Non-autoregressive forecasters - New step predictor architectures (e.g. Vision Transformers) - Ensemble strategies without modifying training infrastructure Also fixes two pre-existing bugs: - interior_mask_bool shape (1,) → (N,) for correct loss masking - all_gather_cat dimension collapse on single-device runs Refs #49 Co-authored-by: Claude Sonnet 4.5 --- .github/workflows/ci-pypi-deploy.yml | 9 +- .gitignore | 12 +- AGENTS.md | 97 --- CHANGELOG.md | 54 +- README.md | 2 +- neural_lam/create_graph.py | 8 +- neural_lam/datastore/base.py | 90 +- neural_lam/datastore/mdp.py | 87 +- .../compute_standardization_stats.py | 18 +- neural_lam/datastore/npyfilesmeps/store.py | 7 +- neural_lam/interaction_net.py | 24 +- neural_lam/models/__init__.py | 9 +- neural_lam/models/ar_forecaster.py | 84 ++ neural_lam/models/ar_model.py | 779 ------------------ neural_lam/models/base_graph_model.py | 267 ++---- neural_lam/models/base_hi_graph_model.py | 43 +- neural_lam/models/forecaster.py | 36 + neural_lam/models/forecaster_module.py | 604 ++++++++++++++ neural_lam/models/graph_lam.py | 37 +- neural_lam/models/hi_lam.py | 53 +- neural_lam/models/hi_lam_parallel.py | 35 +- neural_lam/models/step_predictor.py | 286 +++++++ neural_lam/plot_graph.py | 192 ++--- neural_lam/train_model.py | 90 
+- neural_lam/utils.py | 242 +----- neural_lam/vis.py | 281 ++----- neural_lam/weather_dataset.py | 120 +-- pyproject.toml | 30 +- tests/conftest.py | 20 +- tests/dummy_datastore.py | 304 +------ tests/test_clamping.py | 11 +- tests/test_cli.py | 107 --- tests/test_datasets.py | 278 +------ tests/test_datastores.py | 34 +- tests/test_graph_creation.py | 2 +- tests/test_plotting.py | 229 ++--- tests/test_prediction_model_classes.py | 328 ++++++++ tests/test_training.py | 136 +-- 38 files changed, 2013 insertions(+), 3032 deletions(-) delete mode 100644 AGENTS.md create mode 100644 neural_lam/models/ar_forecaster.py delete mode 100644 neural_lam/models/ar_model.py create mode 100644 neural_lam/models/forecaster.py create mode 100644 neural_lam/models/forecaster_module.py create mode 100644 neural_lam/models/step_predictor.py create mode 100644 tests/test_prediction_model_classes.py diff --git a/.github/workflows/ci-pypi-deploy.yml b/.github/workflows/ci-pypi-deploy.yml index b8c358140..926f8d6ff 100644 --- a/.github/workflows/ci-pypi-deploy.yml +++ b/.github/workflows/ci-pypi-deploy.yml @@ -16,14 +16,9 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - uses: actions/setup-python@v6 - with: - python-version: '3.14' - - name: Install uv - uses: astral-sh/setup-uv@v7 - - name: Build with uv - run: uv build + - uses: actions/setup-python@v5 - uses: casperdcl/deploy-pypi@v2 with: password: ${{ secrets.PYPI_TOKEN }} + pip: wheel -w dist/ --no-deps . 
upload: ${{ github.event_name == 'release' && github.event.action == 'published' }} diff --git a/.gitignore b/.gitignore index 4f0d3301f..358df4c25 100644 --- a/.gitignore +++ b/.gitignore @@ -81,14 +81,10 @@ tags .DS_Store __MACOSX -# Virtual environments +# pdm (https://pdm-project.org/en/stable/) +.pdm-python .venv -venv -activate - -# Build artifacts -dist/ -build/ -*.egg-info/ +# exclude pdm.lock file so that both cpu and gpu versions of torch will be accepted by pdm +pdm.lock tests/test_outputs/ diff --git a/AGENTS.md b/AGENTS.md deleted file mode 100644 index dcf05d563..000000000 --- a/AGENTS.md +++ /dev/null @@ -1,97 +0,0 @@ -# AGENTS.md - -Mandatory rules for AI coding agents. Violations will result in rejected PRs. - ---- - -## Codebase - -Neural-LAM: graph-based neural weather prediction for Limited Area Modeling. Models: `GraphLAM`, -`HiLAM`, `HiLAMParallel`. - -**Data flow:** Raw zarr/numpy → `Datastore` → `WeatherDataset` → `WeatherDataModule` → Model → -Predictions - -**Key modules:** -- `datastore/` — `BaseDatastore` (abstract), `MDPDatastore` (zarr via mllam-data-prep) -- `models/` — `ARModel` (autoregressive base, Lightning) → `BaseGraphModel` (encode-process-decode) - → `GraphLAM` / `HiLAM` / `HiLAMParallel` -- `weather_dataset.py` — `WeatherDataset` + `WeatherDataModule` -- `config.py` — YAML config via dataclass-wizard -- `create_graph.py` — builds mesh graphs (must run before training) -- `interaction_net.py` — `InteractionNet` GNN layer (PyG `MessagePassing`) -- `utils.py` — `make_mlp`, normalization helpers - -Config examples: `tests/datastore_examples/` - -## Commands - -These commands need to be prepended with `uv run` or the virtual env activated with `source .venv/bin/activate` first: - -```bash -# Install (PyTorch must be installed first for CUDA variant) -uv pip install --group dev -e . 
- -# Lint -pre-commit run --all-files # black, isort, flake8, mypy, codespell - -# Test -pytest -vv -s --doctest-modules # all -pytest tests/test_training.py -vv -s # single file -pytest tests/test_training.py::test_fn -vv # single function - -# Run -python -m neural_lam.create_graph --config_path --name -python -m neural_lam.train_model --config_path --model graph_lam --graph -python -m neural_lam.train_model --eval test --config_path --load -``` - -W&B auto-disabled in tests. `DummyDatastore` used; example data downloaded from S3 on first run. - ---- - -## Rules - -### Issues - -1. **Search before creating.** Use any of: GitHub UI search, `gh issue list --state all --search ""`, or `curl "https://api.github.com/search/issues?q=+repo:mllam/neural-lam+type:issue"`. Duplicate issues will be closed. -2. **Every PR requires an issue.** No exceptions. Open one first if none exists. -3. **Include minimal example.** Each issue should include a minimal, reproducible example on how to easily recreate a bug, including all necessary module imports and data. Include full traceback if it is a bug-report. - -### Pull Requests - -1. **Search before creating.** Use any of: GitHub UI search, `gh pr list --state all --search ""`, or `curl "https://api.github.com/search/issues?q=+repo:mllam/neural-lam+type:pr"`. If a PR exists for the same issue, contribute there. -2. **Link the issue.** PR body must contain `closes #` or `refs #`. Unlinked PRs will be - rejected. -3. **Use the PR template.** Fill in every section of `.github/pull_request_template.md`. Do not - delete or skip sections. -4. **Read the full issue thread before writing code.** Rejected approaches and prior decisions are - there. Ignoring them wastes everyone's time. -5. **Run pre-commit hooks locally.** Linting needs to be done locally before each new commit with e.g. `uvx pre-commit run --all` -6. 
**Testing Mandate.** Run `pytests tests/` before opening a PR and if tests fail do not open the PR , fix the failure first. - -### Communication - -- **Terse.** One sentence per point. No preamble. No summaries of visible diffs. -- **No filler.** Ban list: "Great question", "As mentioned above", "I hope this helps", "Let me know - if you have questions", "Happy to help". -- **No obvious narration.** Do not explain what self-explanatory code does. -- **PR descriptions: what changed and why.** Nothing else. -- **One question at a time.** No shotgun lists of open-ended questions. - -### Context - -- **Re-read the entire thread** before every comment and every push. No exceptions. -- **After a context gap**, reload the full thread (GitHub UI, `gh issue view ` / `gh pr view `, or `curl "https://api.github.com/repos/mllam/neural-lam/issues/"`) before acting. -- **Never repeat** a question already answered or an approach already rejected in the thread. - -### Commits - -- Imperative form, matching existing `git log` style. -- One concern per PR. No unrelated changes. -- AI attribution of tool names is mandatory if used and should be mentioned in the commit message trailer as `Co-authored-by ` - -### Changelog - -Every PR must add a line to `CHANGELOG.md` in the section matching the change type (`Added` / `Changed` / `Fixed` / `Maintenance`). -`maintenance`). 
diff --git a/CHANGELOG.md b/CHANGELOG.md index fe66d2136..fd8b0bb43 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,60 +7,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [unreleased](https://github.com/mllam/neural-lam/compare/v0.5.0...HEAD) -### Added - -- Add `AGENTS.md` file to the repo to give agents more information about the codebase and the contribution culture.[\#416](https://github.com/mllam/neural-lam/pull/416) @sadamov - -- Enable `pin_memory` in DataLoaders when GPU is available for faster async CPU-to-GPU data transfers [\#236](https://github.com/mllam/neural-lam/pull/236) @abhaygoudannavar - -### Changed - -- Change the default ensemble-loading behavior in `WeatherDataset` / `WeatherDataModule` to use all ensemble members as independent samples for ensemble datastores (with matching ensemble-member selection for forcing when available); single-member behavior now requires explicitly opting in via `--load_single_member` [\#332](https://github.com/mllam/neural-lam/pull/332) @kshirajahere -- Refactor graph loading: move zero-indexing out of the model and update plotting to prepare using the research-branch graph I/O [\#184](https://github.com/mllam/neural-lam/pull/184) @zweihuehner -- Replace `print()`-based `rank_zero_print` with `loguru` `logger.info()` for structured log-level control ([#33](https://github.com/mllam/neural-lam/issues/33)) - ### Fixed -- Initialize `da_forcing_mean` and `da_forcing_std` to `None` when forcing data is absent, fixing `AttributeError` in `WeatherDataset` with `standardize=True` [\#369](https://github.com/mllam/neural-lam/issues/369) @Sir-Sloth-The-Lazy - -- Ensure proper sorting of `analysis_time` in `NpyFilesDatastoreMEPS._get_analysis_times` independent of the order in which files are processed with glob [\#386](https://github.com/mllam/neural-lam/pull/386) @Gopisokk - -- Switch to lat/lon-based plotting with `pcolormesh` and `cartopy` for accurate spatial visualisation 
regardless of underlying projection. [\#168](https://github.com/mllam/neural-lam/pull/168) @sadamov - -- Replace `shell=True` subprocess call in `compute_standardization_stats.py` with a safe argument list and Python-side hostname parsing to prevent command injection via `SLURM_JOB_NODELIST` [\#264](https://github.com/mllam/neural-lam/pull/264) @ashum9 - -- Avoid NaN when standardizing fields with zero std [#189](https://github.com/mllam/neural-lam/pull/189) @varunsiravuri -- Replaces multiple `assert` statements used for runtime input validation with explicit `ValueError` [\#279](https://github.com/mllam/neural-lam/pull/279) @Sir-Sloth-The-Lazy - - Fix README image paths to use absolute GitHub URLs so images display correctly on PyPI [\#188](https://github.com/mllam/neural-lam/pull/188) @bk-simon -- Fix typo in `ar_model.py` that causes `AttributeError` during evaluation [\#204](https://github.com/mllam/neural-lam/pull/204) @ritinikhil - - Changed the hardcoded True to a conditional check "persistent_workers=self.num_workers > 0" [\#235](https://github.com/mllam/neural-lam/pull/235) @santhil-cyber -- Avoid eager download of the MEPS example dataset during pytest collection by lazily initializing it in `tests/conftest.py`, allowing tests to run without triggering a dataset download at import time. [#391](https://github.com/mllam/neural-lam/pull/391) @Saptami191 - -- `fractional_plot_bundle` now correctly multiplies by fraction instead of dividing -[\#222](https://github.com/mllam/neural-lam/pull/222) @santhil-cyber - -- Fix `all_gather_cat` producing wrong shapes on single-device runs by only flattening when `all_gather` actually introduces a new leading dimension [\#424](https://github.com/mllam/neural-lam/pull/424) @RajdeepKushwaha5 - -### Added - -- Expose `--wandb_id` CLI argument to allow resuming an existing W&B run by - ID. 
When provided, `resume="allow"` is set automatically so the same job - script works for both the initial submission and all resubmissions, making - it suitable for HPC systems with limited job runtimes or that may crash. - [\#197](https://github.com/mllam/neural-lam/pull/197) @Mani212005 - - - -- Fix Slack domain link [\#288](https://github.com/mllam/neural-lam/pull/288) @sadamov - -### Fixed - -- Infer spatial coordinate names for MDPDatastore (rather than assuming names `x` and `y`), allows for e.g. lat/lon regular grids [\#169](https://github.com/mllam/neural-lam/pull/169) @leifdenby - +- `fractional_plot_bundle` now correctly multiplies by fraction instead of dividing [ +\#222](https://github.com/mllam/neural-lam/pull/222) @santhil-cyber ### Maintenance @@ -74,10 +28,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fix caching of MEPS example data in CI/CD [\#181](https://github.com/mllam/neural-lam/pull/181) @observingClouds -- Migrated build backend from PDM to Hatchling with hatch-vcs and added uv build in deploy CI - -- Warn when running with `--eval` without `--load` to avoid accidentally evaluating randomly initialized weights [#190](https://github.com/mllam/neural-lam/pull/190) @varunsiravuri - ## [v0.5.0](https://github.com/mllam/neural-lam/releases/tag/v0.5.0) This release contains maintenance and fixes, preventing some unexpected crashes and improving CICD and testing. 
diff --git a/README.md b/README.md index 46a077b79..9056f0006 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -[![slack](https://img.shields.io/badge/slack-join-brightgreen.svg?logo=slack)](https://kutt.to/mllam) +[![slack](https://img.shields.io/badge/slack-join-brightgreen.svg?logo=slack)](https://kutt.it/mllam) [![Linting](https://github.com/mllam/neural-lam/actions/workflows/pre-commit.yml/badge.svg?branch=main)](https://github.com/mllam/neural-lam/actions/workflows/pre-commit.yml) [![CPU+GPU testing](https://github.com/mllam/neural-lam/actions/workflows/install-and-test.yml/badge.svg?branch=main)](https://github.com/mllam/neural-lam/actions/workflows/install-and-test.yml) diff --git a/neural_lam/create_graph.py b/neural_lam/create_graph.py index b24b96c23..e0d81ead5 100644 --- a/neural_lam/create_graph.py +++ b/neural_lam/create_graph.py @@ -11,7 +11,6 @@ import scipy.spatial import torch import torch_geometric as pyg -from loguru import logger from torch_geometric.utils.convert import from_networkx # Local @@ -231,7 +230,7 @@ def create_graph( """ os.makedirs(graph_dir_path, exist_ok=True) - logger.info(f"Writing graph components to {graph_dir_path}") + print(f"Writing graph components to {graph_dir_path}") grid_xy = torch.tensor(xy) pos_max = torch.max(torch.abs(grid_xy)) @@ -594,8 +593,9 @@ def cli(input_args=None): ) args = parser.parse_args(input_args) - if args.config_path is None: - raise ValueError("Specify your config with --config_path") + assert ( + args.config_path is not None + ), "Specify your config with --config_path" # Load neural-lam configuration and datastore to use _, datastore = load_config_and_datastore(config_path=args.config_path) diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py index f6dddb007..fc096595c 100644 --- a/neural_lam/datastore/base.py +++ b/neural_lam/datastore/base.py @@ -35,14 +35,9 @@ class BaseDatastore(abc.ABC): dimensions (rather than just `time`). 
# Ensemble vs deterministic data - If the datastore is used to present an ensemble of state realisations, for - example for forecast ensembles, then the `is_ensemble` attribute should be - set to `True` and returned state data from `get_dataarray` is expected to - have an `ensemble_member` dimension. If each ensemble member has its own - forcing values, then `has_ensemble_forcing` should be set to `True`, and - returned forcing data from `get_dataarray` is expected to have an - `ensemble_member` dimension; otherwise forcing data is expected not to have - one. + If the datastore is used to represent ensemble data, then the `is_ensemble` + attribute should be set to True, and returned data from `get_dataarray` is + assumed to have an `ensemble_member` dimension. # Grid index All methods that return data specific to a grid point (like @@ -54,7 +49,6 @@ class BaseDatastore(abc.ABC): """ is_ensemble: bool = False - has_ensemble_forcing: bool = False is_forecast: bool = False @property @@ -219,8 +213,7 @@ def _standardize_datarray( mean = standard_da[f"{category}_mean"] std = standard_da[f"{category}_std"] - eps = np.finfo(std.dtype).eps - return (da - mean) / std.where(std > eps, other=eps) + return (da - mean) / std @abc.abstractmethod def get_dataarray( @@ -249,11 +242,8 @@ def get_dataarray( elapsed_forecast_duration)` dimensions if `is_forecast` is True, or `(time)` if `is_forecast` is False. - If we have multiple ensemble members of state data, the returned state - dataarray is expected to have an additional `ensemble_member` - dimension. If `has_ensemble_forcing=True`, the returned forcing - dataarray is expected to have an additional `ensemble_member` - dimension; otherwise it is expected not to have one. + If the data is ensemble data, the dataarray is expected to have an + additional `ensemble_member` dimension. 
Parameters ---------- @@ -351,23 +341,6 @@ def get_xy_extent(self, category: str) -> List[float]: ] return [float(v) for v in extent] - @functools.lru_cache - def get_lat_lon(self, category: str) -> np.ndarray: - """ - Return stacked longitude/latitude pairs for the requested category. - """ - - xy = self.get_xy(category=category, stacked=True) - if xy.size == 0: - return xy - - lon_lat = ccrs.PlateCarree().transform_points( - self.coords_projection, - xy[:, 0], - xy[:, 1], - )[:, :2] - return lon_lat - @property @abc.abstractmethod def num_grid_points(self) -> int: @@ -400,8 +373,7 @@ def state_feature_weights_values(self) -> List[float]: @functools.lru_cache def expected_dim_order( - self, - category: Optional[str] = None, + self, category: Optional[str] = None ) -> tuple[str, ...]: """ Return the expected dimension order for the dataarray or dataset @@ -427,6 +399,7 @@ def expected_dim_order( ---------- category : str The category of the dataset (state/forcing/static). + Returns ------- List[str] @@ -445,9 +418,8 @@ def expected_dim_order( elif not self.is_forecast: dim_order.append("time") - if category == "state" and self.is_ensemble: - dim_order.append("ensemble_member") - elif category == "forcing" and self.has_ensemble_forcing: + if self.is_ensemble and category == "state": + # XXX: for now we only assume ensemble data for state variables dim_order.append("ensemble_member") dim_order.append("grid_index") @@ -475,9 +447,7 @@ class BaseRegularGridDatastore(BaseDatastore): `BaseDatastore`) for regular-gridded source data each `grid_index` coordinate value is assumed to be associated with `x` and `y`-values that allow the processed data-arrays can be reshaped back into into 2D - xy-gridded arrays (to change the name of the spatial coordinates the - `spatial_coordinates` value should be changed from its default value of - `("x", "y")`). + xy-gridded arrays. 
The following methods and attributes must be implemented for datastore that represents regular-gridded data: @@ -486,8 +456,6 @@ class BaseRegularGridDatastore(BaseDatastore): - `get_xy` (method): Return the x, y coordinates of the dataset, with the option to not stack the coordinates (so that they are returned as a 2D grid). - - `get_lat_lon` (method): Return the latitude/longitude coordinates of - the dataset for convenience when plotting. The operation of going from (x,y)-indexed regular grid to `grid_index`-indexed data-array is called "stacking" and the reverse @@ -496,7 +464,7 @@ class BaseRegularGridDatastore(BaseDatastore): `stack_grid_coords` and `unstack_grid_coords` respectively). """ - spatial_coordinates = ("x", "y") + CARTESIAN_COORDS = ["x", "y"] @cached_property @abc.abstractmethod @@ -539,9 +507,8 @@ def unstack_grid_coords( ) -> Union[xr.DataArray, xr.Dataset]: """ Unstack the spatial grid coordinates from `grid_index` into separate `x` - and `y` dimensions to create a 2D grid (if the spatial coordinates have - different names, those are used instead). Only performs unstacking if - the data is currently stacked (has grid_index dimension). + and `y` dimensions to create a 2D grid. Only performs unstacking if the + data is currently stacked (has grid_index dimension). 
Parameters ---------- @@ -559,33 +526,16 @@ def unstack_grid_coords( # Check whether `grid_index` is a multi-index if not isinstance(da_or_ds.indexes.get("grid_index"), MultiIndex): - da_or_ds = da_or_ds.set_index(grid_index=self.spatial_coordinates) + da_or_ds = da_or_ds.set_index(grid_index=self.CARTESIAN_COORDS) da_or_ds_unstacked = da_or_ds.unstack("grid_index") # Ensure that the x, y dimensions are in the correct order dims = da_or_ds_unstacked.dims - xy_dim_order = [d for d in dims if d in self.spatial_coordinates] - - if xy_dim_order != self.spatial_coordinates: - # work out where the first spatial coordinate is located - # so that we can insert the second spatial coordinate next to it in - # the correct order. Although this looks verbose, it ensures that - # we don't change the order of any other dimensions. - first_xy_dim_index = min( - dims.index(self.spatial_coordinates[0]), - dims.index(self.spatial_coordinates[1]), - ) - new_dim_order = list(dims) - new_dim_order.remove(self.spatial_coordinates[0]) - new_dim_order.remove(self.spatial_coordinates[1]) - new_dim_order.insert( - first_xy_dim_index, self.spatial_coordinates[0] - ) - new_dim_order.insert( - first_xy_dim_index + 1, self.spatial_coordinates[1] - ) - da_or_ds_unstacked = da_or_ds_unstacked.transpose(*new_dim_order) + xy_dim_order = [d for d in dims if d in self.CARTESIAN_COORDS] + + if xy_dim_order != self.CARTESIAN_COORDS: + da_or_ds_unstacked = da_or_ds_unstacked.transpose("x", "y") return da_or_ds_unstacked @@ -611,7 +561,7 @@ def stack_grid_coords( if "grid_index" in da_or_ds.dims: return da_or_ds - da_or_ds_stacked = da_or_ds.stack(grid_index=self.spatial_coordinates) + da_or_ds_stacked = da_or_ds.stack(grid_index=self.CARTESIAN_COORDS) # infer what category of data the array represents by finding the # dimension named in the format `{category}_feature` diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py index 114550e6d..ad5150118 100644 --- a/neural_lam/datastore/mdp.py 
+++ b/neural_lam/datastore/mdp.py @@ -1,6 +1,5 @@ # Standard library import copy -import functools import warnings from datetime import timedelta from functools import cached_property @@ -10,13 +9,12 @@ # Third-party import cartopy.crs as ccrs import mllam_data_prep as mdp -import numpy as np import xarray as xr from loguru import logger from numpy import ndarray # Local -from ..utils import log_on_rank_zero +from ..utils import rank_zero_print from .base import BaseRegularGridDatastore, CartesianGridShape @@ -75,19 +73,12 @@ def __init__(self, config_path, n_boundary_points=30, reuse_existing=True): self._ds = mdp.create_dataset(config=self._config) self._ds.to_zarr(fp_ds) self._n_boundary_points = n_boundary_points - self.is_ensemble = "ensemble_member" in self._ds["state"].dims - self.has_ensemble_forcing = ( - "forcing" in self._ds - and "ensemble_member" in self._ds["forcing"].dims - ) - log_on_rank_zero( - "The loaded datastore contains the following features:" - ) + rank_zero_print("The loaded datastore contains the following features:") for category in ["state", "forcing", "static"]: if len(self.get_vars_names(category)) > 0: var_names = self.get_vars_names(category) - log_on_rank_zero(f" {category:<8s}: {' '.join(var_names)}") + rank_zero_print(f" {category:<8s}: {' '.join(var_names)}") # check that all three train/val/test splits are available required_splits = ["train", "val", "test"] @@ -98,14 +89,12 @@ def __init__(self, config_path, n_boundary_points=30, reuse_existing=True): f"splits: {available_splits}" ) - log_on_rank_zero("With the following splits (over time):") + rank_zero_print("With the following splits (over time):") for split in required_splits: da_split = self._ds.splits.sel(split_name=split) da_split_start = da_split.sel(split_part="start").load().item() da_split_end = da_split.sel(split_part="end").load().item() - log_on_rank_zero( - f" {split:<8s}: {da_split_start} to {da_split_end}" - ) + rank_zero_print(f" {split:<8s}: {da_split_start} 
to {da_split_end}") # find out the dimension order for the stacking to grid-index dim_order = None @@ -118,7 +107,7 @@ def __init__(self, config_path, n_boundary_points=30, reuse_existing=True): dim_order == dim_order_ ), "all inputs must have the same dimension order" - self.spatial_coordinates = dim_order + self.CARTESIAN_COORDS = dim_order @property def root_path(self) -> Path: @@ -256,10 +245,8 @@ def get_dataarray( elapsed_forecast_duration)` dimensions if `is_forecast` is True, or `(time)` if `is_forecast` is False. - If we have multiple ensemble members of state data, the returned state - dataarray will have an additional `ensemble_member` dimension. If - `has_ensemble_forcing=True`, the returned forcing dataarray will also - have an additional `ensemble_member` dimension. + If the data is ensemble data, the dataarray will have an additional + `ensemble_member` dimension. Parameters ---------- @@ -288,7 +275,7 @@ def get_dataarray( da_category[coord].attrs["units"] = "m" # set multi-index for grid-index - da_category = da_category.set_index(grid_index=self.spatial_coordinates) + da_category = da_category.set_index(grid_index=self.CARTESIAN_COORDS) if "time" in da_category.dims: t_start = ( @@ -454,15 +441,12 @@ def grid_shape_state(self): """ ds_state = self.unstack_grid_coords(self._ds["state"]) - xdim, ydim = self.spatial_coordinates - da_x, da_y = ds_state[xdim], ds_state[ydim] + da_x, da_y = ds_state.x, ds_state.y assert da_x.ndim == da_y.ndim == 1 return CartesianGridShape(x=da_x.size, y=da_y.size) def get_xy(self, category: str, stacked: bool) -> ndarray: - """ - Return the x, y coordinates of the dataset (i.e. the Cartesian - coordinates of the regular gridded dataset). + """Return the x, y coordinates of the dataset. Parameters ---------- @@ -484,62 +468,25 @@ def get_xy(self, category: str, stacked: bool) -> ndarray: # assume variables are stored in dimensions [grid_index, ...] 
ds_category = self.unstack_grid_coords(da_or_ds=self._ds[category]) - xdim, ydim = self.spatial_coordinates - - da_xs = ds_category[xdim] - da_ys = ds_category[ydim] + da_xs = ds_category.x + da_ys = ds_category.y - assert ( - da_xs.ndim == da_ys.ndim == 1 - ), f"{xdim} and {ydim} coordinates must be 1D" + assert da_xs.ndim == da_ys.ndim == 1, "x and y coordinates must be 1D" da_x, da_y = xr.broadcast(da_xs, da_ys) da_xy = xr.concat([da_x, da_y], dim="grid_coord") if stacked: - da_xy = da_xy.stack(grid_index=self.spatial_coordinates).transpose( + da_xy = da_xy.stack(grid_index=self.CARTESIAN_COORDS).transpose( "grid_index", "grid_coord", ) else: dims = [ - xdim, - ydim, + "x", + "y", "grid_coord", ] da_xy = da_xy.transpose(*dims) return da_xy.values - - @functools.lru_cache - def get_lat_lon(self, category: str) -> np.ndarray: - """ - Return the longitude, latitude coordinates of the dataset as numpy - array for a given category of data. - Override in MDP to use lat/lons directly from xr.Dataset, if available. - - Parameters - ---------- - category : str - The category of the dataset (state/forcing/static). - - Returns - ------- - np.ndarray - The longitude, latitude coordinates of the dataset - with shape `[n_grid_points, 2]`. 
- """ - # Check first if lat/lon saved in ds - lookup_ds = self._ds - if "latitude" in lookup_ds.coords and "longitude" in lookup_ds.coords: - lon = lookup_ds.longitude - lat = lookup_ds.latitude - elif "lat" in lookup_ds.coords and "lon" in lookup_ds.coords: - lon = lookup_ds.lon - lat = lookup_ds.lat - else: - # Not saved, use method from BaseDatastore to derive from x/y - return super().get_lat_lon(category) - - coords = np.stack((lon.values, lat.values), axis=1) - return coords diff --git a/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py b/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py index 7dbcc7ef8..813d7b8e0 100644 --- a/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py +++ b/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py @@ -65,18 +65,14 @@ def get_world_size(): def setup(rank, world_size): # pylint: disable=redefined-outer-name """Initialize the distributed group.""" if "SLURM_JOB_NODELIST" in os.environ: - nodelist = os.environ["SLURM_JOB_NODELIST"] - hostnames = subprocess.check_output( - ["scontrol", "show", "hostnames", nodelist], - ) - hostname_lines = hostnames.decode("utf-8").splitlines() - if not hostname_lines: - raise RuntimeError( - f"SLURM_JOB_NODELIST is set to {repr(nodelist)}, but " - "'scontrol show hostnames' returned no hostnames. " - "Please check your SLURM job configuration." 
+ master_node = ( + subprocess.check_output( + "scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1", + shell=True, ) - master_node = hostname_lines[0].strip() + .strip() + .decode("utf-8") + ) else: print( "\033[91mCareful, you are running this script with --distributed " diff --git a/neural_lam/datastore/npyfilesmeps/store.py b/neural_lam/datastore/npyfilesmeps/store.py index f65fd8aef..26214e30c 100644 --- a/neural_lam/datastore/npyfilesmeps/store.py +++ b/neural_lam/datastore/npyfilesmeps/store.py @@ -142,6 +142,7 @@ class NpyFilesDatastoreMEPS(BaseRegularGridDatastore): """ SHORT_NAME = "npyfilesmeps" + is_ensemble = True is_forecast = True def __init__( @@ -172,8 +173,6 @@ def __init__( self._remove_state_features_with_index = ( self.config.dataset.remove_state_features_with_index ) - self.is_ensemble = self._num_ensemble_members > 1 - self.has_ensemble_forcing = False @property def root_path(self) -> Path: @@ -530,7 +529,7 @@ def _get_analysis_times(self, split) -> List[np.datetime64]: Returns ------- List[dt.datetime] - The analysis times for the given split, sorted in ascending order. + The analysis times for the given split. 
""" pattern = re.sub(r"{analysis_time:[^}]*}", "*", STATE_FILENAME_FORMAT) @@ -548,7 +547,7 @@ def _get_analysis_times(self, split) -> List[np.datetime64]: f"No files found in {sample_dir} with pattern {pattern}" ) - return sorted(times) + return times def _calc_datetime_forcing_features(self, da_time: xr.DataArray): da_hour_angle = da_time.dt.hour / 12 * np.pi diff --git a/neural_lam/interaction_net.py b/neural_lam/interaction_net.py index 9a0177ff3..2f45b03fa 100644 --- a/neural_lam/interaction_net.py +++ b/neural_lam/interaction_net.py @@ -30,8 +30,7 @@ def __init__( """ Create a new InteractionNet - edge_index: (2,M), Edges in pyg format, with both sender and receiver - node indices starting at 0 + edge_index: (2,M), Edges in pyg format input_dim: Dimensionality of input representations, for both nodes and edges update_edges: If new edge representations should be computed @@ -46,27 +45,20 @@ def __init__( (None = no chunking, same MLP) aggr: Message aggregation method (sum/mean) """ - if aggr not in ("sum", "mean"): - raise ValueError(f"Unknown aggregation method: {aggr}") + assert aggr in ("sum", "mean"), f"Unknown aggregation method: {aggr}" super().__init__(aggr=aggr) if hidden_dim is None: # Default to input dim if not explicitly given hidden_dim = input_dim + # Make both sender and receiver indices of edge_index start at 0 + edge_index = edge_index - edge_index.min(dim=1, keepdim=True)[0] + # Store number of receiver nodes according to edge_index self.num_rec = edge_index[1].max() + 1 - # edge_index is expected to be zero-based and local: - # edge_index[0]: sender indices in [0 .. num_snd-1] - # edge_index[1]: receiver indices in [0 .. num_rec-1] - # The edge indices used in this GNN layer are defined as: - # receivers → [0 .. num_rec-1] - # senders → [num_rec .. num_rec+num_snd-1] - # Hence, sender indices from the input edge_index are offset - # by num_rec to obtain the indices used in this layer. 
- edge_index = torch.stack( - (edge_index[0] + self.num_rec, edge_index[1]), dim=0 - ) - + edge_index[0] = ( + edge_index[0] + self.num_rec + ) # Make sender indices after rec self.register_buffer("edge_index", edge_index, persistent=False) # Create MLPs diff --git a/neural_lam/models/__init__.py b/neural_lam/models/__init__.py index f65387ab6..ad5897f8a 100644 --- a/neural_lam/models/__init__.py +++ b/neural_lam/models/__init__.py @@ -1,6 +1,11 @@ # Local -from .base_graph_model import BaseGraphModel -from .base_hi_graph_model import BaseHiGraphModel +from .forecaster_module import ForecasterModule from .graph_lam import GraphLAM from .hi_lam import HiLAM from .hi_lam_parallel import HiLAMParallel + +MODELS = { + "graph_lam": GraphLAM, + "hi_lam": HiLAM, + "hi_lam_parallel": HiLAMParallel, +} diff --git a/neural_lam/models/ar_forecaster.py b/neural_lam/models/ar_forecaster.py new file mode 100644 index 000000000..9a18d054a --- /dev/null +++ b/neural_lam/models/ar_forecaster.py @@ -0,0 +1,84 @@ +# Third-party +import torch + +# Local +from ..datastore import BaseDatastore +from .forecaster import Forecaster +from .step_predictor import StepPredictor + + +class ARForecaster(Forecaster): + """ + Subclass of Forecaster that uses an auto-regressive strategy to + unroll a forecast. Makes use of a StepPredictor at each AR step. 
+ """ + + def __init__(self, predictor: StepPredictor, datastore: BaseDatastore): + super().__init__() + self.predictor = predictor + + # Register boundary/interior masks on the forecaster, not the predictor + boundary_mask = ( + torch.tensor(datastore.boundary_mask.values, dtype=torch.float32) + .unsqueeze(0) + .unsqueeze(-1) + ) + self.register_buffer("boundary_mask", boundary_mask, persistent=False) + self.register_buffer( + "interior_mask", 1.0 - self.boundary_mask, persistent=False + ) + + @property + def predicts_std(self) -> bool: + return self.predictor.predicts_std + + def forward( + self, + init_states: torch.Tensor, + forcing_features: torch.Tensor, + boundary_states: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Unroll the autoregressive model. + boundary_states is used ONLY to overwrite boundary nodes at each step. + The interior prediction at step i must not depend on + boundary_states[:, i] in any other way. + """ + + prev_prev_state = init_states[:, 0] + prev_state = init_states[:, 1] + prediction_list = [] + pred_std_list = [] + pred_steps = forcing_features.shape[1] + + for i in range(pred_steps): + forcing = forcing_features[:, i] + boundary_state = boundary_states[:, i] + + pred_state, pred_std = self.predictor( + prev_state, prev_prev_state, forcing + ) + + # Overwrite boundary with true state using ARForecaster's mask + new_state = ( + self.boundary_mask * boundary_state + + self.interior_mask * pred_state + ) + + prediction_list.append(new_state) + if pred_std is not None: + pred_std_list.append(pred_std) + + # Update conditioning states + prev_prev_state = prev_state + prev_state = new_state + + prediction = torch.stack(prediction_list, dim=1) + # If predictor outputs std, stack it; otherwise return None so + # ForecasterModule can substitute the constant per_var_std + if pred_std_list: + pred_std = torch.stack(pred_std_list, dim=1) + else: + pred_std = None + + return prediction, pred_std diff --git 
a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py deleted file mode 100644 index a411a3afc..000000000 --- a/neural_lam/models/ar_model.py +++ /dev/null @@ -1,779 +0,0 @@ -# Standard library -import os -import warnings -from typing import Any, Dict, List - -# Third-party -import matplotlib.pyplot as plt -import numpy as np -import pytorch_lightning as pl -import torch -import xarray as xr - -# First-party -from neural_lam.utils import get_integer_time - -# Local -from .. import metrics, vis -from ..config import NeuralLAMConfig -from ..datastore import BaseDatastore -from ..loss_weighting import get_state_feature_weighting -from ..weather_dataset import WeatherDataset - - -class ARModel(pl.LightningModule): - """ - Generic auto-regressive weather model. - Abstract class that can be extended. - """ - - # pylint: disable=arguments-differ - # Disable to override args/kwargs from superclass - - def __init__( - self, - args, - config: NeuralLAMConfig, - datastore: BaseDatastore, - ): - super().__init__() - self.save_hyperparameters(ignore=["datastore"]) - self.args = args - self._datastore = datastore - num_state_vars = datastore.get_num_data_vars(category="state") - num_forcing_vars = datastore.get_num_data_vars(category="forcing") - # Load static features standardized - da_static_features = datastore.get_dataarray( - category="static", split=None, standardize=True - ) - if da_static_features is None: - raise ValueError("Static features are required for ARModel") - da_state_stats = datastore.get_standardization_dataarray( - category="state" - ) - da_boundary_mask = datastore.boundary_mask - num_past_forcing_steps = args.num_past_forcing_steps - num_future_forcing_steps = args.num_future_forcing_steps - - # Load static features for grid/data, - self.register_buffer( - "grid_static_features", - torch.tensor(da_static_features.values, dtype=torch.float32), - persistent=False, - ) - - state_stats = { - "state_mean": torch.tensor( - 
da_state_stats.state_mean.values, dtype=torch.float32 - ), - "state_std": torch.tensor( - da_state_stats.state_std.values, dtype=torch.float32 - ), - # Note that the one-step-diff stats (diff_mean and diff_std) are - # for differences computed on standardized data - "diff_mean": torch.tensor( - da_state_stats.state_diff_mean_standardized.values, - dtype=torch.float32, - ), - "diff_std": torch.tensor( - da_state_stats.state_diff_std_standardized.values, - dtype=torch.float32, - ), - } - - for key, val in state_stats.items(): - self.register_buffer(key, val, persistent=False) - - state_feature_weights = get_state_feature_weighting( - config=config, datastore=datastore - ) - self.feature_weights = torch.tensor( - state_feature_weights, dtype=torch.float32 - ) - - # Double grid output dim. to also output std.-dev. - self.output_std = bool(args.output_std) - if self.output_std: - # Pred. dim. in grid cell - self.grid_output_dim = 2 * num_state_vars - else: - # Pred. dim. in grid cell - self.grid_output_dim = num_state_vars - # Store constant per-variable std.-dev. 
weighting - # NOTE that this is the inverse of the multiplicative weighting - # in wMSE/wMAE - self.register_buffer( - "per_var_std", - self.diff_std / torch.sqrt(self.feature_weights), - persistent=False, - ) - - # grid_dim from data + static - ( - self.num_grid_nodes, - grid_static_dim, - ) = self.grid_static_features.shape - - self.grid_dim = ( - 2 * num_state_vars - + grid_static_dim - + num_forcing_vars - * (num_past_forcing_steps + num_future_forcing_steps + 1) - ) - - # Instantiate loss function - self.loss = metrics.get_metric(args.loss) - - boundary_mask = torch.tensor( - da_boundary_mask.values, dtype=torch.float32 - ).unsqueeze( - 1 - ) # add feature dim - - self.register_buffer("boundary_mask", boundary_mask, persistent=False) - # Pre-compute interior mask for use in loss function - self.register_buffer( - "interior_mask", 1.0 - self.boundary_mask, persistent=False - ) # (num_grid_nodes, 1), 1 for non-border - - self.val_metrics: Dict[str, List] = { - "mse": [], - } - self.test_metrics: Dict[str, List] = { - "mse": [], - "mae": [], - } - if self.output_std: - self.test_metrics["output_std"] = [] # Treat as metric - - # For making restoring of optimizer state optional - self.restore_opt = args.restore_opt - - # For example plotting - self.n_example_pred = args.n_example_pred - self.plotted_examples = 0 - - # For storing spatial loss maps during evaluation - self.spatial_loss_maps: List[Any] = [] - - self.time_step_int, self.time_step_unit = get_integer_time( - self._datastore.step_length - ) - - def _create_dataarray_from_tensor( - self, - tensor: torch.Tensor, - time: torch.Tensor, - split: str, - category: str, - ) -> xr.DataArray: - """ - Create an `xr.DataArray` from a tensor, with the correct dimensions and - coordinates to match the datastore used by the model. This function in - in effect is the inverse of what is returned by - `WeatherDataset.__getitem__`. 
- - Parameters - ---------- - tensor : torch.Tensor - The tensor to convert to a `xr.DataArray` with dimensions [time, - grid_index, feature]. The tensor will be copied to the CPU if it is - not already there. - time : torch.Tensor - The time index or indices for the data, given as tensor representing - epoch time in nanoseconds. The tensor will be - copied to the CPU memory if they are not already there. - split : str - The split of the data, either 'train', 'val', or 'test' - category : str - The category of the data, either 'state' or 'forcing' - """ - # TODO: creating an instance of WeatherDataset here on every call is - # not how this should be done but whether WeatherDataset should be - # provided to ARModel or where to put plotting still needs discussion - weather_dataset = WeatherDataset(datastore=self._datastore, split=split) - time = np.array(time.cpu(), dtype="datetime64[ns]") - da = weather_dataset.create_dataarray_from_tensor( - tensor=tensor, time=time, category=category - ) - return da - - def configure_optimizers(self): - opt = torch.optim.AdamW( - self.parameters(), lr=self.args.lr, betas=(0.9, 0.95) - ) - return opt - - @property - def interior_mask_bool(self): - """ - Get the interior mask as a boolean (N,) mask. 
- """ - return self.interior_mask[:, 0].to(torch.bool) - - @staticmethod - def expand_to_batch(x, batch_size): - """ - Expand tensor with initial batch dimension - """ - return x.unsqueeze(0).expand(batch_size, -1, -1) - - def predict_step(self, prev_state, prev_prev_state, forcing): - """ - Step state one step ahead using prediction model, X_{t-1}, X_t -> X_t+1 - prev_state: (B, num_grid_nodes, feature_dim), X_t prev_prev_state: (B, - num_grid_nodes, feature_dim), X_{t-1} forcing: (B, num_grid_nodes, - forcing_dim) - """ - raise NotImplementedError("No prediction step implemented") - - def unroll_prediction(self, init_states, forcing_features, true_states): - """ - Roll out prediction taking multiple autoregressive steps with model - init_states: (B, 2, num_grid_nodes, d_f) forcing_features: (B, - pred_steps, num_grid_nodes, d_static_f) true_states: (B, pred_steps, - num_grid_nodes, d_f) - """ - prev_prev_state = init_states[:, 0] - prev_state = init_states[:, 1] - prediction_list = [] - pred_std_list = [] - pred_steps = forcing_features.shape[1] - - for i in range(pred_steps): - forcing = forcing_features[:, i] - border_state = true_states[:, i] - - pred_state, pred_std = self.predict_step( - prev_state, prev_prev_state, forcing - ) - # state: (B, num_grid_nodes, d_f) pred_std: (B, num_grid_nodes, - # d_f) or None - - # Overwrite border with true state - new_state = ( - self.boundary_mask * border_state - + self.interior_mask * pred_state - ) - - prediction_list.append(new_state) - if self.output_std: - pred_std_list.append(pred_std) - - # Update conditioning states - prev_prev_state = prev_state - prev_state = new_state - - prediction = torch.stack( - prediction_list, dim=1 - ) # (B, pred_steps, num_grid_nodes, d_f) - if self.output_std: - pred_std = torch.stack( - pred_std_list, dim=1 - ) # (B, pred_steps, num_grid_nodes, d_f) - else: - pred_std = self.per_var_std # (d_f,) - - return prediction, pred_std - - def common_step(self, batch): - """ - Predict on 
single batch batch consists of: init_states: (B, 2, - num_grid_nodes, d_features) target_states: (B, pred_steps, - num_grid_nodes, d_features) forcing_features: (B, pred_steps, - num_grid_nodes, d_forcing), - where index 0 corresponds to index 1 of init_states - """ - (init_states, target_states, forcing_features, batch_times) = batch - - prediction, pred_std = self.unroll_prediction( - init_states, forcing_features, target_states - ) # (B, pred_steps, num_grid_nodes, d_f) - # prediction: (B, pred_steps, num_grid_nodes, d_f) pred_std: (B, - # pred_steps, num_grid_nodes, d_f) or (d_f,) - - return prediction, target_states, pred_std, batch_times - - def training_step(self, batch): - """ - Train on single batch - """ - prediction, target, pred_std, _ = self.common_step(batch) - - # Compute loss - batch_loss = torch.mean( - self.loss( - prediction, target, pred_std, mask=self.interior_mask_bool - ) - ) # mean over unrolled times and batch - - log_dict = {"train_loss": batch_loss} - self.log_dict( - log_dict, - prog_bar=True, - on_step=True, - on_epoch=True, - sync_dist=True, - batch_size=batch[0].shape[0], - ) - return batch_loss - - def all_gather_cat(self, tensor_to_gather): - """ - Gather tensors across all ranks, and concatenate across dim. 0 (instead - of stacking in new dim. 0) - - tensor_to_gather: (d1, d2, ...), distributed over K ranks - - returns: - - single-device strategies: (d1, d2, ...) - - multi-device strategies: (K*d1, d2, ...) - """ - gathered = self.all_gather(tensor_to_gather) - # all_gather adds a leading dim (K,) only on multi-device runs; - # on single-device it returns the tensor unchanged. 
- if gathered.dim() > tensor_to_gather.dim(): - return gathered.flatten(0, 1) - return gathered - - # newer lightning versions requires batch_idx argument, even if unused - # pylint: disable-next=unused-argument - def validation_step(self, batch, batch_idx): - """ - Run validation on single batch - """ - prediction, target, pred_std, _ = self.common_step(batch) - - time_step_loss = torch.mean( - self.loss( - prediction, target, pred_std, mask=self.interior_mask_bool - ), - dim=0, - ) # (time_steps-1) - mean_loss = torch.mean(time_step_loss) - - # Log loss per time step forward and mean - val_log_dict = { - f"val_loss_unroll{step}": time_step_loss[step - 1] - for step in self.args.val_steps_to_log - if step <= len(time_step_loss) - } - val_log_dict["val_mean_loss"] = mean_loss - self.log_dict( - val_log_dict, - on_step=False, - on_epoch=True, - sync_dist=True, - batch_size=batch[0].shape[0], - ) - - # Store MSEs - entry_mses = metrics.mse( - prediction, - target, - pred_std, - mask=self.interior_mask_bool, - sum_vars=False, - ) # (B, pred_steps, d_f) - self.val_metrics["mse"].append(entry_mses) - - def on_validation_epoch_end(self): - """ - Compute val metrics at the end of val epoch - """ - # Create error maps for all test metrics - self.aggregate_and_plot_metrics(self.val_metrics, prefix="val") - - # Clear lists with validation metrics values - for metric_list in self.val_metrics.values(): - metric_list.clear() - - # pylint: disable-next=unused-argument - def test_step(self, batch, batch_idx): - """ - Run test on single batch - """ - # TODO Here batch_times can be used for plotting routines - prediction, target, pred_std, batch_times = self.common_step(batch) - # prediction: (B, pred_steps, num_grid_nodes, d_f) pred_std: (B, - # pred_steps, num_grid_nodes, d_f) or (d_f,) - - time_step_loss = torch.mean( - self.loss( - prediction, target, pred_std, mask=self.interior_mask_bool - ), - dim=0, - ) # (time_steps-1,) - mean_loss = torch.mean(time_step_loss) - - # Log 
loss per time step forward and mean - test_log_dict = { - f"test_loss_unroll{step}": time_step_loss[step - 1] - for step in self.args.val_steps_to_log - } - test_log_dict["test_mean_loss"] = mean_loss - - self.log_dict( - test_log_dict, - on_step=False, - on_epoch=True, - sync_dist=True, - batch_size=batch[0].shape[0], - ) - - # Compute all evaluation metrics for error maps Note: explicitly list - # metrics here, as test_metrics can contain additional ones, computed - # differently, but that should be aggregated on_test_epoch_end - for metric_name in ("mse", "mae"): - metric_func = metrics.get_metric(metric_name) - batch_metric_vals = metric_func( - prediction, - target, - pred_std, - mask=self.interior_mask_bool, - sum_vars=False, - ) # (B, pred_steps, d_f) - self.test_metrics[metric_name].append(batch_metric_vals) - - if self.output_std: - # Store output std. per variable, spatially averaged - mean_pred_std = torch.mean( - pred_std[..., self.interior_mask_bool, :], dim=-2 - ) # (B, pred_steps, d_f) - self.test_metrics["output_std"].append(mean_pred_std) - - # Save per-sample spatial loss for specific times - spatial_loss = self.loss( - prediction, target, pred_std, average_grid=False - ) # (B, pred_steps, num_grid_nodes) - log_spatial_losses = spatial_loss[ - :, [step - 1 for step in self.args.val_steps_to_log] - ] - self.spatial_loss_maps.append(log_spatial_losses) - # (B, N_log, num_grid_nodes) - - # Plot example predictions (on rank 0 only) - if ( - self.trainer.is_global_zero - and self.plotted_examples < self.n_example_pred - ): - # Need to plot more example predictions - n_additional_examples = min( - prediction.shape[0], - self.n_example_pred - self.plotted_examples, - ) - - self.plot_examples( - batch, - n_additional_examples, - prediction=prediction, - split="test", - ) - - def plot_examples(self, batch, n_examples, split, prediction=None): - """ - Plot the first n_examples forecasts from batch - - batch: batch with data to plot corresponding forecasts 
for n_examples: - number of forecasts to plot prediction: (B, pred_steps, num_grid_nodes, - d_f), existing prediction. - Generate if None. - """ - if prediction is None: - prediction, target, _, _ = self.common_step(batch) - - target = batch[1] - time = batch[3] - - # Rescale to original data scale - prediction_rescaled = prediction * self.state_std + self.state_mean - target_rescaled = target * self.state_std + self.state_mean - - # Iterate over the examples - for pred_slice, target_slice, time_slice in zip( - prediction_rescaled[:n_examples], - target_rescaled[:n_examples], - time[:n_examples], - ): - # Each slice is (pred_steps, num_grid_nodes, d_f) - self.plotted_examples += 1 # Increment already here - - da_prediction = self._create_dataarray_from_tensor( - tensor=pred_slice, - time=time_slice, - split=split, - category="state", - ).unstack("grid_index") - da_target = self._create_dataarray_from_tensor( - tensor=target_slice, - time=time_slice, - split=split, - category="state", - ).unstack("grid_index") - - var_vmin = ( - torch.minimum( - pred_slice.flatten(0, 1).min(dim=0)[0], - target_slice.flatten(0, 1).min(dim=0)[0], - ) - .cpu() - .numpy() - ) # (d_f,) - var_vmax = ( - torch.maximum( - pred_slice.flatten(0, 1).max(dim=0)[0], - target_slice.flatten(0, 1).max(dim=0)[0], - ) - .cpu() - .numpy() - ) # (d_f,) - var_vranges = list(zip(var_vmin, var_vmax)) - - # Iterate over prediction horizon time steps - for t_i, _ in enumerate(zip(pred_slice, target_slice), start=1): - # Create one figure per variable at this time step - var_figs = [ - vis.plot_prediction( - datastore=self._datastore, - title=f"{var_name}, t={t_i}" - f" ({self.time_step_int * t_i}" - f"{self.time_step_unit})", - colorbar_label=var_unit, - vrange=var_vrange, - da_prediction=da_prediction.isel( - state_feature=var_i, time=t_i - 1 - ).squeeze(), - da_target=da_target.isel( - state_feature=var_i, time=t_i - 1 - ).squeeze(), - ) - for var_i, (var_name, var_unit, var_vrange) in enumerate( - zip( - 
self._datastore.get_vars_names("state"), - self._datastore.get_vars_units("state"), - var_vranges, - ) - ) - ] - - example_i = self.plotted_examples - - for var_name, fig in zip( - self._datastore.get_vars_names("state"), var_figs - ): - - # We need treat logging images differently for different - # loggers. WANDB can log multiple images to the same key, - # while other loggers, as MLFlow, need unique keys for - # each image. - if isinstance(self.logger, pl.loggers.WandbLogger): - key = f"{var_name}_example_{example_i}" - else: - key = f"{var_name}_example" - - if hasattr(self.logger, "log_image"): - self.logger.log_image(key=key, images=[fig], step=t_i) - else: - warnings.warn( - f"{self.logger} does not support image logging." - ) - - plt.close( - "all" - ) # Close all figs for this time step, saves memory - - # Save pred and target as .pt files - torch.save( - pred_slice.cpu(), - os.path.join( - self.logger.save_dir, - f"example_pred_{self.plotted_examples}.pt", - ), - ) - torch.save( - target_slice.cpu(), - os.path.join( - self.logger.save_dir, - f"example_target_{self.plotted_examples}.pt", - ), - ) - - def create_metric_log_dict(self, metric_tensor, prefix, metric_name): - """ - Put together a dict with everything to log for one metric. Also saves - plots as pdf and csv if using test prefix. 
- - metric_tensor: (pred_steps, d_f), metric values per time and variable - prefix: string, prefix to use for logging metric_name: string, name of - the metric - - Return: log_dict: dict with everything to log for given metric - """ - log_dict = {} - metric_fig = vis.plot_error_map( - errors=metric_tensor, - datastore=self._datastore, - ) - full_log_name = f"{prefix}_{metric_name}" - log_dict[full_log_name] = metric_fig - - if prefix == "test": - # Save pdf - metric_fig.savefig( - os.path.join(self.logger.save_dir, f"{full_log_name}.pdf") - ) - # Save errors also as csv - np.savetxt( - os.path.join(self.logger.save_dir, f"{full_log_name}.csv"), - metric_tensor.cpu().numpy(), - delimiter=",", - ) - - # Check if metrics are watched, log exact values for specific vars - var_names = self._datastore.get_vars_names(category="state") - if full_log_name in self.args.metrics_watch: - for var_i, timesteps in self.args.var_leads_metrics_watch.items(): - var_name = var_names[var_i] - for step in timesteps: - key = f"{full_log_name}_{var_name}_step_{step}" - log_dict[key] = metric_tensor[step - 1, var_i] - - return log_dict - - def aggregate_and_plot_metrics(self, metrics_dict, prefix): - """ - Aggregate and create error map plots for all metrics in metrics_dict - - metrics_dict: dictionary with metric_names and list of tensors - with step-evals. 
- prefix: string, prefix to use for logging - """ - log_dict = {} - for metric_name, metric_val_list in metrics_dict.items(): - metric_tensor = self.all_gather_cat( - torch.cat(metric_val_list, dim=0) - ) # (N_eval, pred_steps, d_f) - - if self.trainer.is_global_zero: - metric_tensor_averaged = torch.mean(metric_tensor, dim=0) - # (pred_steps, d_f) - - # Take square root after all averaging to change MSE to RMSE - if "mse" in metric_name: - metric_tensor_averaged = torch.sqrt(metric_tensor_averaged) - metric_name = metric_name.replace("mse", "rmse") - - # NOTE: we here assume rescaling for all metrics is linear - metric_rescaled = metric_tensor_averaged * self.state_std - # (pred_steps, d_f) - log_dict.update( - self.create_metric_log_dict( - metric_rescaled, prefix, metric_name - ) - ) - - # Ensure that log_dict has structure for - # logging as dict(str, plt.Figure) - assert all( - isinstance(key, str) and isinstance(value, plt.Figure) - for key, value in log_dict.items() - ) - - if self.trainer.is_global_zero and not self.trainer.sanity_checking: - - current_epoch = self.trainer.current_epoch - - for key, figure in log_dict.items(): - # For other loggers than wandb, add epoch to key. - # Wandb can log multiple images to the same key, while other - # loggers, such as MLFlow need unique keys for each image. - if not isinstance(self.logger, pl.loggers.WandbLogger): - key = f"{key}-{current_epoch}" - - if hasattr(self.logger, "log_image"): - self.logger.log_image(key=key, images=[figure]) - - plt.close("all") # Close all figs - - def on_test_epoch_end(self): - """ - Compute test metrics and make plots at the end of test epoch. Will - gather stored tensors and perform plotting and logging on rank 0. 
- """ - # Create error maps for all test metrics - self.aggregate_and_plot_metrics(self.test_metrics, prefix="test") - - # Plot spatial loss maps - spatial_loss_tensor = self.all_gather_cat( - torch.cat(self.spatial_loss_maps, dim=0) - ) # (N_test, N_log, num_grid_nodes) - if self.trainer.is_global_zero: - mean_spatial_loss = torch.mean( - spatial_loss_tensor, dim=0 - ) # (N_log, num_grid_nodes) - - loss_map_figs = [ - vis.plot_spatial_error( - error=loss_map, - datastore=self._datastore, - title=f"Test loss, t={t_i} " - f"({(self.time_step_int * t_i)} {self.time_step_unit})", - ) - for t_i, loss_map in zip( - self.args.val_steps_to_log, mean_spatial_loss - ) - ] - - # log all to same key, sequentially - for i, fig in enumerate(loss_map_figs): - key = "test_loss" - if not isinstance(self.logger, pl.loggers.WandbLogger): - key = f"{key}_{i}" - if hasattr(self.logger, "log_image"): - self.logger.log_image(key=key, images=[fig]) - - # also make without title and save as pdf - pdf_loss_map_figs = [ - vis.plot_spatial_error( - error=loss_map, datastore=self._datastore - ) - for loss_map in mean_spatial_loss - ] - pdf_loss_maps_dir = os.path.join( - self.logger.save_dir, "spatial_loss_maps" - ) - os.makedirs(pdf_loss_maps_dir, exist_ok=True) - for t_i, fig in zip(self.args.val_steps_to_log, pdf_loss_map_figs): - fig.savefig(os.path.join(pdf_loss_maps_dir, f"loss_t{t_i}.pdf")) - # save mean spatial loss as .pt file also - torch.save( - mean_spatial_loss.cpu(), - os.path.join(self.logger.save_dir, "mean_spatial_loss.pt"), - ) - - self.spatial_loss_maps.clear() - - def on_load_checkpoint(self, checkpoint): - """ - Perform any changes to state dict before loading checkpoint - """ - loaded_state_dict = checkpoint["state_dict"] - - # Fix for loading older models after IneractionNet refactoring, where - # the grid MLP was moved outside the encoder InteractionNet class - if "g2m_gnn.grid_mlp.0.weight" in loaded_state_dict: - replace_keys = list( - filter( - lambda key: 
key.startswith("g2m_gnn.grid_mlp"), - loaded_state_dict.keys(), - ) - ) - for old_key in replace_keys: - new_key = old_key.replace( - "g2m_gnn.grid_mlp", "encoding_grid_mlp" - ) - loaded_state_dict[new_key] = loaded_state_dict[old_key] - del loaded_state_dict[old_key] - if not self.restore_opt: - opt = self.configure_optimizers() - checkpoint["optimizer_states"] = [opt.state_dict()] diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py index fd38a2e67..9348aa6fb 100644 --- a/neural_lam/models/base_graph_model.py +++ b/neural_lam/models/base_graph_model.py @@ -6,22 +6,63 @@ from ..config import NeuralLAMConfig from ..datastore import BaseDatastore from ..interaction_net import InteractionNet -from .ar_model import ARModel +from .step_predictor import StepPredictor -class BaseGraphModel(ARModel): +class BaseGraphModel(StepPredictor): """ Base (abstract) class for graph-based models building on the encode-process-decode idea. """ - def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore): - super().__init__(args, config=config, datastore=datastore) + def __init__( + self, + config: NeuralLAMConfig, + datastore: BaseDatastore, + graph_name: str = "multiscale", + hidden_dim: int = 64, + hidden_layers: int = 1, + processor_layers: int = 4, + mesh_aggr: str = "sum", + num_past_forcing_steps: int = 1, + num_future_forcing_steps: int = 1, + output_std: bool = False, + ): + super().__init__( + config=config, + datastore=datastore, + output_std=output_std, + ) + + # Retrieve difference statistics for rescaling in forward pass + da_state_stats = datastore.get_standardization_dataarray("state") + self.register_buffer( + "diff_mean", + torch.tensor( + da_state_stats.state_diff_mean_standardized.values, + dtype=torch.float32, + ), + persistent=False, + ) + self.register_buffer( + "diff_std", + torch.tensor( + da_state_stats.state_diff_std_standardized.values, + dtype=torch.float32, + ), + persistent=False, + ) + + # Store 
architecture hyperparameters for subclass use + self.hidden_dim = hidden_dim + self.hidden_layers = hidden_layers + self.processor_layers = processor_layers + self.mesh_aggr = mesh_aggr # Load graph with static features # NOTE: (IMPORTANT!) mesh nodes MUST have the first # num_mesh_nodes indices, - graph_dir_path = datastore.root_path / "graph" / args.graph + graph_dir_path = datastore.root_path / "graph" / graph_name self.hierarchical, graph_ldict = utils.load_graph( graph_dir_path=graph_dir_path ) @@ -34,20 +75,30 @@ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore): # Specify dimensions of data self.num_mesh_nodes, _ = self.get_num_mesh() - utils.log_on_rank_zero( + utils.rank_zero_print( f"Loaded graph with {self.num_grid_nodes + self.num_mesh_nodes} " f"nodes ({self.num_grid_nodes} grid, {self.num_mesh_nodes} mesh)" ) - # grid_dim from data + static + # Compute grid_input_dim: total input dimensionality on the grid + num_state_vars = datastore.get_num_data_vars(category="state") + num_forcing_vars = datastore.get_num_data_vars(category="forcing") + grid_static_dim = self.grid_static_features.shape[1] + self.grid_input_dim = ( + 2 * num_state_vars + + grid_static_dim + + num_forcing_vars + * (num_past_forcing_steps + num_future_forcing_steps + 1) + ) + self.g2m_edges, g2m_dim = self.g2m_features.shape self.m2g_edges, m2g_dim = self.m2g_features.shape # Define sub-models # Feature embedders for grid - self.mlp_blueprint_end = [args.hidden_dim] * (args.hidden_layers + 1) + self.mlp_blueprint_end = [hidden_dim] * (hidden_layers + 1) self.grid_embedder = utils.make_mlp( - [self.grid_dim] + self.mlp_blueprint_end + [self.grid_input_dim] + self.mlp_blueprint_end ) self.g2m_embedder = utils.make_mlp([g2m_dim] + self.mlp_blueprint_end) self.m2g_embedder = utils.make_mlp([m2g_dim] + self.mlp_blueprint_end) @@ -56,215 +107,31 @@ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore): # encoder self.g2m_gnn = InteractionNet( 
self.g2m_edge_index, - args.hidden_dim, - hidden_layers=args.hidden_layers, + hidden_dim, + hidden_layers=hidden_layers, update_edges=False, ) self.encoding_grid_mlp = utils.make_mlp( - [args.hidden_dim] + self.mlp_blueprint_end + [hidden_dim] + self.mlp_blueprint_end ) # decoder self.m2g_gnn = InteractionNet( self.m2g_edge_index, - args.hidden_dim, - hidden_layers=args.hidden_layers, + hidden_dim, + hidden_layers=hidden_layers, update_edges=False, ) # Output mapping (hidden_dim -> output_dim) self.output_map = utils.make_mlp( - [args.hidden_dim] * (args.hidden_layers + 1) - + [self.grid_output_dim], + [hidden_dim] * (hidden_layers + 1) + [self.grid_output_dim], layer_norm=False, ) # No layer norm on this one # Compute indices and define clamping functions self.prepare_clamping_params(config, datastore) - def prepare_clamping_params( - self, config: NeuralLAMConfig, datastore: BaseDatastore - ): - """ - Prepare parameters for clamping predicted values to valid range - """ - - # Read configs - state_feature_names = datastore.get_vars_names(category="state") - lower_lims = config.training.output_clamping.lower - upper_lims = config.training.output_clamping.upper - - # Check that limits in config are for valid features - unknown_features_lower = set(lower_lims.keys()) - set( - state_feature_names - ) - unknown_features_upper = set(upper_lims.keys()) - set( - state_feature_names - ) - if unknown_features_lower or unknown_features_upper: - raise ValueError( - "State feature limits were provided for unknown features: " - f"{unknown_features_lower.union(unknown_features_upper)}" - ) - - # Constant parameters for clamping - sigmoid_sharpness = 1 - softplus_sharpness = 1 - sigmoid_center = 0 - softplus_center = 0 - - normalize_clamping_lim = ( - lambda x, feature_idx: (x - self.state_mean[feature_idx]) - / self.state_std[feature_idx] - ) - - # Check which clamping functions to use for each feature - sigmoid_lower_upper_idx = [] - sigmoid_lower_lims = [] - sigmoid_upper_lims 
= [] - - softplus_lower_idx = [] - softplus_lower_lims = [] - - softplus_upper_idx = [] - softplus_upper_lims = [] - - for feature_idx, feature in enumerate(state_feature_names): - if feature in lower_lims and feature in upper_lims: - assert ( - lower_lims[feature] < upper_lims[feature] - ), f'Invalid clamping limits for feature "{feature}",\ - lower: {lower_lims[feature]}, larger than\ - upper: {upper_lims[feature]}' - sigmoid_lower_upper_idx.append(feature_idx) - sigmoid_lower_lims.append( - normalize_clamping_lim(lower_lims[feature], feature_idx) - ) - sigmoid_upper_lims.append( - normalize_clamping_lim(upper_lims[feature], feature_idx) - ) - elif feature in lower_lims and feature not in upper_lims: - softplus_lower_idx.append(feature_idx) - softplus_lower_lims.append( - normalize_clamping_lim(lower_lims[feature], feature_idx) - ) - elif feature not in lower_lims and feature in upper_lims: - softplus_upper_idx.append(feature_idx) - softplus_upper_lims.append( - normalize_clamping_lim(upper_lims[feature], feature_idx) - ) - - self.register_buffer( - "sigmoid_lower_lims", torch.tensor(sigmoid_lower_lims) - ) - self.register_buffer( - "sigmoid_upper_lims", torch.tensor(sigmoid_upper_lims) - ) - self.register_buffer( - "softplus_lower_lims", torch.tensor(softplus_lower_lims) - ) - self.register_buffer( - "softplus_upper_lims", torch.tensor(softplus_upper_lims) - ) - - self.register_buffer( - "clamp_lower_upper_idx", torch.tensor(sigmoid_lower_upper_idx) - ) - self.register_buffer( - "clamp_lower_idx", torch.tensor(softplus_lower_idx) - ) - self.register_buffer( - "clamp_upper_idx", torch.tensor(softplus_upper_idx) - ) - - # Define clamping functions - self.clamp_lower_upper = lambda x: ( - self.sigmoid_lower_lims - + (self.sigmoid_upper_lims - self.sigmoid_lower_lims) - * torch.sigmoid(sigmoid_sharpness * (x - sigmoid_center)) - ) - self.clamp_lower = lambda x: ( - self.softplus_lower_lims - + torch.nn.functional.softplus( - x - softplus_center, 
beta=softplus_sharpness - ) - ) - self.clamp_upper = lambda x: ( - self.softplus_upper_lims - - torch.nn.functional.softplus( - softplus_center - x, beta=softplus_sharpness - ) - ) - - self.inverse_clamp_lower_upper = lambda x: ( - sigmoid_center - + utils.inverse_sigmoid( - (x - self.sigmoid_lower_lims) - / (self.sigmoid_upper_lims - self.sigmoid_lower_lims) - ) - / sigmoid_sharpness - ) - self.inverse_clamp_lower = lambda x: ( - utils.inverse_softplus( - x - self.softplus_lower_lims, beta=softplus_sharpness - ) - + softplus_center - ) - self.inverse_clamp_upper = lambda x: ( - -utils.inverse_softplus( - self.softplus_upper_lims - x, beta=softplus_sharpness - ) - + softplus_center - ) - - def get_clamped_new_state(self, state_delta, prev_state): - """ - Clamp prediction to valid range supplied in config - Returns the clamped new state after adding delta to original state - - Instead of the new state being computed as - $X_{t+1} = X_t + \\delta = X_t + model(\\{X_t,X_{t-1},...\\}, forcing)$ - The clamped values will be - $f(f^{-1}(X_t) + model(\\{X_t, X_{t-1},... 
\\}, forcing))$ - Which means the model will learn to output values in the range of the - inverse clamping function - - state_delta: (B, num_grid_nodes, feature_dim) - prev_state: (B, num_grid_nodes, feature_dim) - """ - - # Assign new state, but overwrite clamped values of each type later - new_state = prev_state + state_delta - - # Sigmoid/logistic clamps between ]a,b[ - if self.clamp_lower_upper_idx.numel() > 0: - idx = self.clamp_lower_upper_idx - - new_state[:, :, idx] = self.clamp_lower_upper( - self.inverse_clamp_lower_upper(prev_state[:, :, idx]) - + state_delta[:, :, idx] - ) - - # Softplus clamps between ]a,infty[ - if self.clamp_lower_idx.numel() > 0: - idx = self.clamp_lower_idx - - new_state[:, :, idx] = self.clamp_lower( - self.inverse_clamp_lower(prev_state[:, :, idx]) - + state_delta[:, :, idx] - ) - - # Softplus clamps between ]-infty,b[ - if self.clamp_upper_idx.numel() > 0: - idx = self.clamp_upper_idx - - new_state[:, :, idx] = self.clamp_upper( - self.inverse_clamp_upper(prev_state[:, :, idx]) - + state_delta[:, :, idx] - ) - - return new_state - def get_num_mesh(self): """ Compute number of mesh nodes from loaded features, @@ -289,7 +156,7 @@ def process_step(self, mesh_rep): """ raise NotImplementedError("process_step not implemented") - def predict_step(self, prev_state, prev_prev_state, forcing): + def forward(self, prev_state, prev_prev_state, forcing): """ Step state one step ahead using prediction model, X_{t-1}, X_t -> X_t+1 prev_state: (B, num_grid_nodes, feature_dim), X_t diff --git a/neural_lam/models/base_hi_graph_model.py b/neural_lam/models/base_hi_graph_model.py index b556cbd73..a15370147 100644 --- a/neural_lam/models/base_hi_graph_model.py +++ b/neural_lam/models/base_hi_graph_model.py @@ -14,8 +14,31 @@ class BaseHiGraphModel(BaseGraphModel): Base class for hierarchical graph models. 
""" - def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore): - super().__init__(args, config=config, datastore=datastore) + def __init__( + self, + config: NeuralLAMConfig, + datastore: BaseDatastore, + graph_name: str = "multiscale", + hidden_dim: int = 64, + hidden_layers: int = 1, + processor_layers: int = 4, + mesh_aggr: str = "sum", + num_past_forcing_steps: int = 1, + num_future_forcing_steps: int = 1, + output_std: bool = False, + ): + super().__init__( + config=config, + datastore=datastore, + graph_name=graph_name, + hidden_dim=hidden_dim, + hidden_layers=hidden_layers, + processor_layers=processor_layers, + mesh_aggr=mesh_aggr, + num_past_forcing_steps=num_past_forcing_steps, + num_future_forcing_steps=num_future_forcing_steps, + output_std=output_std, + ) # Track number of nodes, edges on each level # Flatten lists for efficient embedding @@ -27,10 +50,10 @@ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore): ] # Needs as python list for later # Print some useful info - utils.log_on_rank_zero("Loaded hierarchical graph with structure:") + utils.rank_zero_print("Loaded hierarchical graph with structure:") for level_index, level_mesh_size in enumerate(self.level_mesh_sizes): same_level_edges = self.m2m_features[level_index].shape[0] - utils.log_on_rank_zero( + utils.rank_zero_print( f"level {level_index} - {level_mesh_size} nodes, " f"{same_level_edges} same-level edges" ) @@ -38,8 +61,8 @@ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore): if level_index < (self.num_levels - 1): up_edges = self.mesh_up_features[level_index].shape[0] down_edges = self.mesh_down_features[level_index].shape[0] - utils.log_on_rank_zero(f" {level_index}<->{level_index + 1}") - utils.log_on_rank_zero( + utils.rank_zero_print(f" {level_index}<->{level_index + 1}") + utils.rank_zero_print( f" - {up_edges} up edges, {down_edges} down edges" ) # Embedders @@ -81,8 +104,8 @@ def __init__(self, args, config: 
NeuralLAMConfig, datastore: BaseDatastore): [ InteractionNet( edge_index, - args.hidden_dim, - hidden_layers=args.hidden_layers, + hidden_dim, + hidden_layers=hidden_layers, ) for edge_index in self.mesh_up_edge_index ] @@ -93,8 +116,8 @@ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore): [ InteractionNet( edge_index, - args.hidden_dim, - hidden_layers=args.hidden_layers, + hidden_dim, + hidden_layers=hidden_layers, update_edges=False, ) for edge_index in self.mesh_down_edge_index diff --git a/neural_lam/models/forecaster.py b/neural_lam/models/forecaster.py new file mode 100644 index 000000000..a2f268c3a --- /dev/null +++ b/neural_lam/models/forecaster.py @@ -0,0 +1,36 @@ +# Standard library +from abc import ABC, abstractmethod + +# Third-party +import torch +from torch import nn + + +class Forecaster(nn.Module, ABC): + """ + Generic forecaster capable of mapping from a set of initial states, + forcing features and boundary states into a full forecast of the + requested length. 
+ """ + + @property + @abstractmethod + def predicts_std(self) -> bool: + """Whether this forecaster outputs a predicted standard deviation.""" + + @abstractmethod + def forward( + self, + init_states: torch.Tensor, + forcing_features: torch.Tensor, + boundary_states: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + init_states: (B, 2, num_grid_nodes, d_f) + forcing_features: (B, pred_steps, num_grid_nodes, d_static_f) + boundary_states: (B, pred_steps, num_grid_nodes, d_f) + Returns: + prediction: (B, pred_steps, num_grid_nodes, d_f) + pred_std: (B, pred_steps, num_grid_nodes, d_f) or (d_f,) + """ + pass diff --git a/neural_lam/models/forecaster_module.py b/neural_lam/models/forecaster_module.py new file mode 100644 index 000000000..1d92b8082 --- /dev/null +++ b/neural_lam/models/forecaster_module.py @@ -0,0 +1,604 @@ +# Standard library +import os +import warnings +from typing import Any, Dict, List, Optional + +# Third-party +import matplotlib.pyplot as plt +import numpy as np +import pytorch_lightning as pl +import torch +import xarray as xr + +# First-party +from neural_lam.utils import get_integer_time + +# Local +from .. import metrics, vis +from ..config import NeuralLAMConfig +from ..datastore import BaseDatastore +from ..loss_weighting import get_state_feature_weighting +from ..weather_dataset import WeatherDataset +from .forecaster import Forecaster + + +class ForecasterModule(pl.LightningModule): + """ + Lightning module handling training, validation and testing loops. + Wraps a Forecaster instance which performs the actual prediction. 
+ """ + + # pylint: disable=arguments-differ + + def __init__( + self, + forecaster: Forecaster, + config: NeuralLAMConfig, + datastore: BaseDatastore, + loss: str = "wmse", + lr: float = 1e-3, + restore_opt: bool = False, + n_example_pred: int = 1, + val_steps_to_log: Optional[List[int]] = None, + metrics_watch: Optional[List[str]] = None, + var_leads_metrics_watch: Optional[Dict[int, List[int]]] = None, + ): + super().__init__() + # Resolve mutable defaults + if val_steps_to_log is None: + val_steps_to_log = [ + 1, + ] + if metrics_watch is None: + metrics_watch = [] + if var_leads_metrics_watch is None: + var_leads_metrics_watch = {} + + # Note: datastore is excluded from saved hparams and must be provided + # explicitly when calling load_from_checkpoint(path, + # datastore=datastore) + self.save_hyperparameters(ignore=["datastore", "forecaster"]) + self.datastore = datastore + self.forecaster = forecaster + + # Compute interior_mask_bool directly from datastore + boundary_mask = ( + torch.tensor(datastore.boundary_mask.values, dtype=torch.float32) + .unsqueeze(0) + .unsqueeze(-1) + ) # (1, num_grid_nodes, 1) + interior_mask = 1.0 - boundary_mask + self.register_buffer( + "interior_mask_bool", + interior_mask[0, :, 0].to(torch.bool), + persistent=False, + ) + + # Store per_var_std here if predictor does not output std + if not self.forecaster.predicts_std: + da_state_stats = datastore.get_standardization_dataarray( + category="state" + ) + state_feature_weights = get_state_feature_weighting( + config=config, datastore=datastore + ) + diff_std = torch.tensor( + da_state_stats.state_diff_std_standardized.values, + dtype=torch.float32, + ) + feature_weights_t = torch.tensor( + state_feature_weights, dtype=torch.float32 + ) + self.register_buffer( + "per_var_std", + diff_std / torch.sqrt(feature_weights_t), + persistent=False, + ) + else: + self.per_var_std = None + + # Instantiate loss function + self.loss = metrics.get_metric(loss) + + self.val_metrics: Dict[str, 
List] = { + "mse": [], + } + self.test_metrics: Dict[str, List] = { + "mse": [], + "mae": [], + } + if self.forecaster.predicts_std: + self.test_metrics["output_std"] = [] # Treat as metric + + # For making restoring of optimizer state optional + self.restore_opt = restore_opt + + # For example plotting + self.n_example_pred = n_example_pred + self.plotted_examples = 0 + + # For storing spatial loss maps during evaluation + self.spatial_loss_maps: List[Any] = [] + + self.time_step_int, self.time_step_unit = get_integer_time( + self.datastore.step_length + ) + + def _create_dataarray_from_tensor( + self, + tensor: torch.Tensor, + time: torch.Tensor, + split: str, + category: str, + ) -> xr.DataArray: + weather_dataset = WeatherDataset(datastore=self.datastore, split=split) + time = np.array(time.cpu(), dtype="datetime64[ns]") + da = weather_dataset.create_dataarray_from_tensor( + tensor=tensor, time=time, category=category + ) + return da + + def configure_optimizers(self): + opt = torch.optim.AdamW( + self.parameters(), lr=self.hparams.lr, betas=(0.9, 0.95) + ) + return opt + + def training_step(self, batch): + (init_states, target_states, forcing_features, _batch_times) = batch + prediction, pred_std = self.forecaster( + init_states, forcing_features, target_states + ) + if pred_std is None: + pred_std = self.per_var_std + + batch_loss = torch.mean( + self.loss( + prediction, + target_states, + pred_std, + mask=self.interior_mask_bool, + ) + ) + + log_dict = {"train_loss": batch_loss} + self.log_dict( + log_dict, + prog_bar=True, + on_step=True, + on_epoch=True, + sync_dist=True, + batch_size=batch[0].shape[0], + ) + return batch_loss + + def all_gather_cat(self, tensor_to_gather): + gathered = self.all_gather(tensor_to_gather) + # all_gather adds dim 0 only on multi-device; on single + # device it returns the same tensor unchanged. 
+ if gathered.dim() > tensor_to_gather.dim(): + gathered = gathered.flatten(0, 1) + return gathered + + # pylint: disable-next=unused-argument + def validation_step(self, batch, batch_idx): + (init_states, target_states, forcing_features, _batch_times) = batch + prediction, pred_std = self.forecaster( + init_states, forcing_features, target_states + ) + if pred_std is None: + pred_std = self.per_var_std + + time_step_loss = torch.mean( + self.loss( + prediction, + target_states, + pred_std, + mask=self.interior_mask_bool, + ), + dim=0, + ) + mean_loss = torch.mean(time_step_loss) + + val_log_dict = { + f"val_loss_unroll{step}": time_step_loss[step - 1] + for step in self.hparams.val_steps_to_log + if step <= len(time_step_loss) + } + val_log_dict["val_mean_loss"] = mean_loss + self.log_dict( + val_log_dict, + on_step=False, + on_epoch=True, + sync_dist=True, + batch_size=batch[0].shape[0], + ) + + entry_mses = metrics.mse( + prediction, + target_states, + pred_std, + mask=self.interior_mask_bool, + sum_vars=False, + ) + self.val_metrics["mse"].append(entry_mses) + + def on_validation_epoch_end(self): + self.aggregate_and_plot_metrics(self.val_metrics, prefix="val") + for metric_list in self.val_metrics.values(): + metric_list.clear() + + # pylint: disable-next=unused-argument + def test_step(self, batch, batch_idx): + (init_states, target_states, forcing_features, _batch_times) = batch + prediction, pred_std = self.forecaster( + init_states, forcing_features, target_states + ) + + if pred_std is not None: + mean_pred_std = torch.mean( + pred_std[..., self.interior_mask_bool, :], dim=-2 + ) + self.test_metrics["output_std"].append(mean_pred_std) + + if pred_std is None: + pred_std = self.per_var_std + + time_step_loss = torch.mean( + self.loss( + prediction, + target_states, + pred_std, + mask=self.interior_mask_bool, + ), + dim=0, + ) + mean_loss = torch.mean(time_step_loss) + + test_log_dict = { + f"test_loss_unroll{step}": time_step_loss[step - 1] + for step in 
self.hparams.val_steps_to_log + if step <= len(time_step_loss) + } + test_log_dict["test_mean_loss"] = mean_loss + + self.log_dict( + test_log_dict, + on_step=False, + on_epoch=True, + sync_dist=True, + batch_size=batch[0].shape[0], + ) + + for metric_name in ("mse", "mae"): + metric_func = metrics.get_metric(metric_name) + batch_metric_vals = metric_func( + prediction, + target_states, + pred_std, + mask=self.interior_mask_bool, + sum_vars=False, + ) + self.test_metrics[metric_name].append(batch_metric_vals) + + spatial_loss = self.loss( + prediction, target_states, pred_std, average_grid=False + ) + log_spatial_losses = spatial_loss[ + :, + [ + step - 1 + for step in self.hparams.val_steps_to_log + if step <= spatial_loss.shape[1] + ], + ] + self.spatial_loss_maps.append(log_spatial_losses) + + if ( + self.trainer.is_global_zero + and self.plotted_examples < self.n_example_pred + ): + n_additional_examples = min( + prediction.shape[0], + self.n_example_pred - self.plotted_examples, + ) + + self.plot_examples( + batch, + n_additional_examples, + prediction=prediction, + split="test", + ) + + def plot_examples(self, batch, n_examples, split, prediction): + + target = batch[1] + time = batch[3] + + da_state_stats = self.datastore.get_standardization_dataarray("state") + state_std = torch.tensor( + da_state_stats.state_std.values, + dtype=torch.float32, + device=prediction.device, + ) + state_mean = torch.tensor( + da_state_stats.state_mean.values, + dtype=torch.float32, + device=prediction.device, + ) + + prediction_rescaled = prediction * state_std + state_mean + target_rescaled = target * state_std + state_mean + + for pred_slice, target_slice, time_slice in zip( + prediction_rescaled[:n_examples], + target_rescaled[:n_examples], + time[:n_examples], + ): + self.plotted_examples += 1 + + da_prediction = self._create_dataarray_from_tensor( + tensor=pred_slice, + time=time_slice, + split=split, + category="state", + ).unstack("grid_index") + da_target = 
self._create_dataarray_from_tensor( + tensor=target_slice, + time=time_slice, + split=split, + category="state", + ).unstack("grid_index") + + var_vmin = ( + torch.minimum( + pred_slice.flatten(0, 1).min(dim=0)[0], + target_slice.flatten(0, 1).min(dim=0)[0], + ) + .cpu() + .numpy() + ) + var_vmax = ( + torch.maximum( + pred_slice.flatten(0, 1).max(dim=0)[0], + target_slice.flatten(0, 1).max(dim=0)[0], + ) + .cpu() + .numpy() + ) + var_vranges = list(zip(var_vmin, var_vmax)) + + for t_i, _ in enumerate(zip(pred_slice, target_slice), start=1): + var_figs = [ + vis.plot_prediction( + datastore=self.datastore, + title=f"{var_name} ({var_unit}), " + f"t={t_i} ({(self.time_step_int * t_i)}" + f"{self.time_step_unit})", + vrange=var_vrange, + da_prediction=da_prediction.isel( + state_feature=var_i, time=t_i - 1 + ).squeeze(), + da_target=da_target.isel( + state_feature=var_i, time=t_i - 1 + ).squeeze(), + ) + for var_i, (var_name, var_unit, var_vrange) in enumerate( + zip( + self.datastore.get_vars_names("state"), + self.datastore.get_vars_units("state"), + var_vranges, + ) + ) + ] + + example_i = self.plotted_examples + + for var_name, fig in zip( + self.datastore.get_vars_names("state"), var_figs + ): + if isinstance(self.logger, pl.loggers.WandbLogger): + key = f"{var_name}_example_{example_i}" + else: + key = f"{var_name}_example" + + if hasattr(self.logger, "log_image"): + self.logger.log_image(key=key, images=[fig], step=t_i) + else: + warnings.warn( + f"{self.logger} does not support image logging." 
+ ) + + plt.close("all") + + torch.save( + pred_slice.cpu(), + os.path.join( + self.logger.save_dir, + f"example_pred_{self.plotted_examples}.pt", + ), + ) + torch.save( + target_slice.cpu(), + os.path.join( + self.logger.save_dir, + f"example_target_{self.plotted_examples}.pt", + ), + ) + + def create_metric_log_dict(self, metric_tensor, prefix, metric_name): + log_dict = {} + metric_fig = vis.plot_error_map( + errors=metric_tensor, + datastore=self.datastore, + ) + full_log_name = f"{prefix}_{metric_name}" + log_dict[full_log_name] = metric_fig + + if prefix == "test": + metric_fig.savefig( + os.path.join(self.logger.save_dir, f"{full_log_name}.pdf") + ) + np.savetxt( + os.path.join(self.logger.save_dir, f"{full_log_name}.csv"), + metric_tensor.cpu().numpy(), + delimiter=",", + ) + + var_names = self.datastore.get_vars_names(category="state") + if full_log_name in self.hparams.metrics_watch: + for ( + var_i, + timesteps, + ) in self.hparams.var_leads_metrics_watch.items(): + var_name = var_names[var_i] + for step in timesteps: + key = f"{full_log_name}_{var_name}_step_{step}" + log_dict[key] = metric_tensor[step - 1, var_i] + + return log_dict + + def aggregate_and_plot_metrics(self, metrics_dict, prefix): + log_dict = {} + for metric_name, metric_val_list in metrics_dict.items(): + metric_tensor = self.all_gather_cat( + torch.cat(metric_val_list, dim=0) + ) + + if self.trainer.is_global_zero: + metric_tensor_averaged = torch.mean(metric_tensor, dim=0) + + if "mse" in metric_name: + metric_tensor_averaged = torch.sqrt(metric_tensor_averaged) + metric_name = metric_name.replace("mse", "rmse") + + da_state_stats = self.datastore.get_standardization_dataarray( + "state" + ) + state_std = torch.tensor( + da_state_stats.state_std.values, + dtype=torch.float32, + device=metric_tensor_averaged.device, + ) + metric_rescaled = metric_tensor_averaged * state_std + + log_dict.update( + self.create_metric_log_dict( + metric_rescaled, prefix, metric_name + ) + ) + + 
figure_dict = { + k: v for k, v in log_dict.items() if isinstance(v, plt.Figure) + } + scalar_dict = { + k: v for k, v in log_dict.items() if not isinstance(v, plt.Figure) + } + + if self.trainer.is_global_zero and not self.trainer.sanity_checking: + # Log scalars via Lightning's built-in mechanism + if scalar_dict: + self.log_dict(scalar_dict, sync_dist=True) + + current_epoch = self.trainer.current_epoch + + for key, figure in figure_dict.items(): + if not isinstance(self.logger, pl.loggers.WandbLogger): + key = f"{key}-{current_epoch}" + + if hasattr(self.logger, "log_image"): + self.logger.log_image(key=key, images=[figure]) + + plt.close("all") + + def on_test_epoch_end(self): + self.aggregate_and_plot_metrics(self.test_metrics, prefix="test") + + spatial_loss_tensor = self.all_gather_cat( + torch.cat(self.spatial_loss_maps, dim=0) + ) + if self.trainer.is_global_zero: + mean_spatial_loss = torch.mean(spatial_loss_tensor, dim=0) + + loss_map_figs = [ + vis.plot_spatial_error( + error=loss_map, + datastore=self.datastore, + title=f"Test loss, t={t_i} " + f"({(self.time_step_int * t_i)} {self.time_step_unit})", + ) + for t_i, loss_map in zip( + self.hparams.val_steps_to_log, mean_spatial_loss + ) + ] + + for i, fig in enumerate(loss_map_figs): + key = "test_loss" + if not isinstance(self.logger, pl.loggers.WandbLogger): + key = f"{key}_{i}" + if hasattr(self.logger, "log_image"): + self.logger.log_image(key=key, images=[fig]) + + pdf_loss_map_figs = [ + vis.plot_spatial_error( + error=loss_map, datastore=self.datastore + ) + for loss_map in mean_spatial_loss + ] + pdf_loss_maps_dir = os.path.join( + self.logger.save_dir, "spatial_loss_maps" + ) + os.makedirs(pdf_loss_maps_dir, exist_ok=True) + for t_i, fig in zip( + self.hparams.val_steps_to_log, pdf_loss_map_figs + ): + fig.savefig(os.path.join(pdf_loss_maps_dir, f"loss_t{t_i}.pdf")) + + torch.save( + mean_spatial_loss.cpu(), + os.path.join(self.logger.save_dir, "mean_spatial_loss.pt"), + ) + + 
self.spatial_loss_maps.clear() + + def on_load_checkpoint(self, checkpoint): + loaded_state_dict = checkpoint["state_dict"] + + # 1. Broad namespace remap: for pre-refactor checkpoints + # The old ARModel was a flat LightningModule. Everything that belonged + # to the predictor needs to be moved to 'forecaster.predictor.' + old_keys = list(loaded_state_dict.keys()) + for key in old_keys: + if not key.startswith("forecaster.") and key not in ( + "interior_mask_bool", + "per_var_std", + ): + new_key = f"forecaster.predictor.{key}" + loaded_state_dict[new_key] = loaded_state_dict.pop(key) + + # 2. Specific rename from g2m_gnn.grid_mlp -> encoding_grid_mlp + # Will be under forecaster.predictor due to the remap above, or + # already there if from a recent checkpoint before this rename. + if ( + "forecaster.predictor.g2m_gnn.grid_mlp.0.weight" + in loaded_state_dict + ): + replace_keys = list( + filter( + lambda key: key.startswith( + "forecaster.predictor.g2m_gnn.grid_mlp" + ), + loaded_state_dict.keys(), + ) + ) + for old_key in replace_keys: + new_key = old_key.replace( + "forecaster.predictor.g2m_gnn.grid_mlp", + "forecaster.predictor.encoding_grid_mlp", + ) + loaded_state_dict[new_key] = loaded_state_dict[old_key] + del loaded_state_dict[old_key] + + if not self.restore_opt: + opt = self.configure_optimizers() + checkpoint["optimizer_states"] = [opt.state_dict()] diff --git a/neural_lam/models/graph_lam.py b/neural_lam/models/graph_lam.py index 1b48dd94f..81b77f1da 100644 --- a/neural_lam/models/graph_lam.py +++ b/neural_lam/models/graph_lam.py @@ -17,8 +17,31 @@ class GraphLAM(BaseGraphModel): Oskarsson et al. (2023). 
""" - def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore): - super().__init__(args, config=config, datastore=datastore) + def __init__( + self, + config: NeuralLAMConfig, + datastore: BaseDatastore, + graph_name: str = "multiscale", + hidden_dim: int = 64, + hidden_layers: int = 1, + processor_layers: int = 4, + mesh_aggr: str = "sum", + num_past_forcing_steps: int = 1, + num_future_forcing_steps: int = 1, + output_std: bool = False, + ): + super().__init__( + config=config, + datastore=datastore, + graph_name=graph_name, + hidden_dim=hidden_dim, + hidden_layers=hidden_layers, + processor_layers=processor_layers, + mesh_aggr=mesh_aggr, + num_past_forcing_steps=num_past_forcing_steps, + num_future_forcing_steps=num_future_forcing_steps, + output_std=output_std, + ) assert ( not self.hierarchical @@ -27,7 +50,7 @@ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore): # grid_dim from data + static + batch_static mesh_dim = self.mesh_static_features.shape[1] m2m_edges, m2m_dim = self.m2m_features.shape - utils.log_on_rank_zero( + utils.rank_zero_print( f"Edges in subgraphs: m2m={m2m_edges}, g2m={self.g2m_edges}, " f"m2g={self.m2g_edges}" ) @@ -42,11 +65,11 @@ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore): processor_nets = [ InteractionNet( self.m2m_edge_index, - args.hidden_dim, - hidden_layers=args.hidden_layers, - aggr=args.mesh_aggr, + hidden_dim, + hidden_layers=hidden_layers, + aggr=mesh_aggr, ) - for _ in range(args.processor_layers) + for _ in range(processor_layers) ] self.processor = pyg.nn.Sequential( "mesh_rep, edge_rep", diff --git a/neural_lam/models/hi_lam.py b/neural_lam/models/hi_lam.py index c340c95da..46a2cece3 100644 --- a/neural_lam/models/hi_lam.py +++ b/neural_lam/models/hi_lam.py @@ -15,26 +15,49 @@ class HiLAM(BaseHiGraphModel): The Hi-LAM model from Oskarsson et al. 
(2023) """ - def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore): - super().__init__(args, config=config, datastore=datastore) + def __init__( + self, + config: NeuralLAMConfig, + datastore: BaseDatastore, + graph_name: str = "multiscale", + hidden_dim: int = 64, + hidden_layers: int = 1, + processor_layers: int = 4, + mesh_aggr: str = "sum", + num_past_forcing_steps: int = 1, + num_future_forcing_steps: int = 1, + output_std: bool = False, + ): + super().__init__( + config=config, + datastore=datastore, + graph_name=graph_name, + hidden_dim=hidden_dim, + hidden_layers=hidden_layers, + processor_layers=processor_layers, + mesh_aggr=mesh_aggr, + num_past_forcing_steps=num_past_forcing_steps, + num_future_forcing_steps=num_future_forcing_steps, + output_std=output_std, + ) # Make down GNNs, both for down edges and same level self.mesh_down_gnns = nn.ModuleList( - [self.make_down_gnns(args) for _ in range(args.processor_layers)] + [self.make_down_gnns() for _ in range(processor_layers)] ) # Nested lists (proc_steps, num_levels-1) self.mesh_down_same_gnns = nn.ModuleList( - [self.make_same_gnns(args) for _ in range(args.processor_layers)] + [self.make_same_gnns() for _ in range(processor_layers)] ) # Nested lists (proc_steps, num_levels) # Make up GNNs, both for up edges and same level self.mesh_up_gnns = nn.ModuleList( - [self.make_up_gnns(args) for _ in range(args.processor_layers)] + [self.make_up_gnns() for _ in range(processor_layers)] ) # Nested lists (proc_steps, num_levels-1) self.mesh_up_same_gnns = nn.ModuleList( - [self.make_same_gnns(args) for _ in range(args.processor_layers)] + [self.make_same_gnns() for _ in range(processor_layers)] ) # Nested lists (proc_steps, num_levels) - def make_same_gnns(self, args): + def make_same_gnns(self): """ Make intra-level GNNs. 
""" @@ -42,14 +65,14 @@ def make_same_gnns(self, args): [ InteractionNet( edge_index, - args.hidden_dim, - hidden_layers=args.hidden_layers, + self.hidden_dim, + hidden_layers=self.hidden_layers, ) for edge_index in self.m2m_edge_index ] ) - def make_up_gnns(self, args): + def make_up_gnns(self): """ Make GNNs for processing steps up through the hierarchy. """ @@ -57,14 +80,14 @@ def make_up_gnns(self, args): [ InteractionNet( edge_index, - args.hidden_dim, - hidden_layers=args.hidden_layers, + self.hidden_dim, + hidden_layers=self.hidden_layers, ) for edge_index in self.mesh_up_edge_index ] ) - def make_down_gnns(self, args): + def make_down_gnns(self): """ Make GNNs for processing steps down through the hierarchy. """ @@ -72,8 +95,8 @@ def make_down_gnns(self, args): [ InteractionNet( edge_index, - args.hidden_dim, - hidden_layers=args.hidden_layers, + self.hidden_dim, + hidden_layers=self.hidden_layers, ) for edge_index in self.mesh_down_edge_index ] diff --git a/neural_lam/models/hi_lam_parallel.py b/neural_lam/models/hi_lam_parallel.py index a0a84d293..da7d50df2 100644 --- a/neural_lam/models/hi_lam_parallel.py +++ b/neural_lam/models/hi_lam_parallel.py @@ -18,8 +18,31 @@ class HiLAMParallel(BaseHiGraphModel): of Hi-LAM. 
""" - def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore): - super().__init__(args, config=config, datastore=datastore) + def __init__( + self, + config: NeuralLAMConfig, + datastore: BaseDatastore, + graph_name: str = "multiscale", + hidden_dim: int = 64, + hidden_layers: int = 1, + processor_layers: int = 4, + mesh_aggr: str = "sum", + num_past_forcing_steps: int = 1, + num_future_forcing_steps: int = 1, + output_std: bool = False, + ): + super().__init__( + config=config, + datastore=datastore, + graph_name=graph_name, + hidden_dim=hidden_dim, + hidden_layers=hidden_layers, + processor_layers=processor_layers, + mesh_aggr=mesh_aggr, + num_past_forcing_steps=num_past_forcing_steps, + num_future_forcing_steps=num_future_forcing_steps, + output_std=output_std, + ) # Processor GNNs # Create the complete edge_index combining all edges for processing @@ -31,18 +54,18 @@ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore): total_edge_index = torch.cat(total_edge_index_list, dim=1) self.edge_split_sections = [ei.shape[1] for ei in total_edge_index_list] - if args.processor_layers == 0: + if processor_layers == 0: self.processor = lambda x, edge_attr: (x, edge_attr) else: processor_nets = [ InteractionNet( total_edge_index, - args.hidden_dim, - hidden_layers=args.hidden_layers, + hidden_dim, + hidden_layers=hidden_layers, edge_chunk_sizes=self.edge_split_sections, aggr_chunk_sizes=self.level_mesh_sizes, ) - for _ in range(args.processor_layers) + for _ in range(processor_layers) ] self.processor = pyg.nn.Sequential( "mesh_rep, edge_rep", diff --git a/neural_lam/models/step_predictor.py b/neural_lam/models/step_predictor.py new file mode 100644 index 000000000..4ba7a8393 --- /dev/null +++ b/neural_lam/models/step_predictor.py @@ -0,0 +1,286 @@ +# Standard library +from abc import ABC, abstractmethod +from typing import Optional + +# Third-party +import torch +from torch import nn + +# Local +from .. 
import utils +from ..config import NeuralLAMConfig +from ..datastore import BaseDatastore + + +class StepPredictor(nn.Module, ABC): + """ + Abstract base class for step predictors mapping from the two previous + time steps plus forcing into a prediction of the next state. + """ + + def __init__( + self, + config: NeuralLAMConfig, + datastore: BaseDatastore, + output_std: bool = False, + ): + super().__init__() + + num_state_vars = datastore.get_num_data_vars(category="state") + + # Load static features standardized + da_static_features = datastore.get_dataarray( + category="static", split=None, standardize=True + ) + if da_static_features is None: + # Create empty static features of the correct shape + num_grid_nodes = datastore.num_grid_points + grid_static_features = torch.empty( + (num_grid_nodes, 0), dtype=torch.float32 + ) + else: + grid_static_features = torch.tensor( + da_static_features.values, dtype=torch.float32 + ) + + self.register_buffer( + "grid_static_features", + grid_static_features, + persistent=False, + ) + + da_state_stats = datastore.get_standardization_dataarray( + category="state" + ) + + self.register_buffer( + "state_mean", + torch.tensor(da_state_stats.state_mean.values, dtype=torch.float32), + persistent=False, + ) + self.register_buffer( + "state_std", + torch.tensor(da_state_stats.state_std.values, dtype=torch.float32), + persistent=False, + ) + + self.output_std = bool(output_std) + if self.output_std: + self.grid_output_dim = 2 * num_state_vars + else: + self.grid_output_dim = num_state_vars + + (self.num_grid_nodes, _) = self.grid_static_features.shape + + @property + def predicts_std(self) -> bool: + """Whether this predictor outputs a predicted standard deviation.""" + return self.output_std + + def expand_to_batch(self, x: torch.Tensor, batch_size: int) -> torch.Tensor: + """ + Expand tensor with shape (N, d) to (B, N, d) + """ + return x.unsqueeze(0).expand(batch_size, -1, -1) + + @abstractmethod + def forward( + self, + 
prev_state: torch.Tensor, + prev_prev_state: torch.Tensor, + forcing: torch.Tensor, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Step state one step ahead using prediction model, X_{t-1}, X_t -> X_t+1 + prev_state: (B, num_grid_nodes, feature_dim), X_t + prev_prev_state: (B, num_grid_nodes, feature_dim), X_{t-1} + forcing: (B, num_grid_nodes, forcing_dim) + + Returns: + pred_state: (B, num_grid_nodes, d_f) + pred_std: (B, num_grid_nodes, d_f) or None + """ + pass + + def prepare_clamping_params( + self, config: NeuralLAMConfig, datastore: BaseDatastore + ): + """ + Prepare parameters for clamping predicted values to valid range + """ + + # Read configs + state_feature_names = datastore.get_vars_names(category="state") + lower_lims = config.training.output_clamping.lower + upper_lims = config.training.output_clamping.upper + + # Check that limits in config are for valid features + unknown_features_lower = set(lower_lims.keys()) - set( + state_feature_names + ) + unknown_features_upper = set(upper_lims.keys()) - set( + state_feature_names + ) + if unknown_features_lower or unknown_features_upper: + raise ValueError( + "State feature limits were provided for unknown features: " + f"{unknown_features_lower.union(unknown_features_upper)}" + ) + + # Constant parameters for clamping + sigmoid_sharpness = 1 + softplus_sharpness = 1 + sigmoid_center = 0 + softplus_center = 0 + + normalize_clamping_lim = ( + lambda x, feature_idx: (x - self.state_mean[feature_idx]) + / self.state_std[feature_idx] + ) + + # Check which clamping functions to use for each feature + sigmoid_lower_upper_idx = [] + sigmoid_lower_lims = [] + sigmoid_upper_lims = [] + + softplus_lower_idx = [] + softplus_lower_lims = [] + + softplus_upper_idx = [] + softplus_upper_lims = [] + + for feature_idx, feature in enumerate(state_feature_names): + if feature in lower_lims and feature in upper_lims: + assert ( + lower_lims[feature] < upper_lims[feature] + ), f'Invalid clamping limits for 
feature "{feature}",\ + lower: {lower_lims[feature]}, larger than\ + upper: {upper_lims[feature]}' + sigmoid_lower_upper_idx.append(feature_idx) + sigmoid_lower_lims.append( + normalize_clamping_lim(lower_lims[feature], feature_idx) + ) + sigmoid_upper_lims.append( + normalize_clamping_lim(upper_lims[feature], feature_idx) + ) + elif feature in lower_lims and feature not in upper_lims: + softplus_lower_idx.append(feature_idx) + softplus_lower_lims.append( + normalize_clamping_lim(lower_lims[feature], feature_idx) + ) + elif feature not in lower_lims and feature in upper_lims: + softplus_upper_idx.append(feature_idx) + softplus_upper_lims.append( + normalize_clamping_lim(upper_lims[feature], feature_idx) + ) + + self.register_buffer( + "sigmoid_lower_lims", torch.tensor(sigmoid_lower_lims) + ) + self.register_buffer( + "sigmoid_upper_lims", torch.tensor(sigmoid_upper_lims) + ) + self.register_buffer( + "softplus_lower_lims", torch.tensor(softplus_lower_lims) + ) + self.register_buffer( + "softplus_upper_lims", torch.tensor(softplus_upper_lims) + ) + + self.register_buffer( + "clamp_lower_upper_idx", torch.tensor(sigmoid_lower_upper_idx) + ) + self.register_buffer( + "clamp_lower_idx", torch.tensor(softplus_lower_idx) + ) + self.register_buffer( + "clamp_upper_idx", torch.tensor(softplus_upper_idx) + ) + + # Define clamping functions + self.clamp_lower_upper = lambda x: ( + self.sigmoid_lower_lims + + (self.sigmoid_upper_lims - self.sigmoid_lower_lims) + * torch.sigmoid(sigmoid_sharpness * (x - sigmoid_center)) + ) + self.clamp_lower = lambda x: ( + self.softplus_lower_lims + + torch.nn.functional.softplus( + x - softplus_center, beta=softplus_sharpness + ) + ) + self.clamp_upper = lambda x: ( + self.softplus_upper_lims + - torch.nn.functional.softplus( + softplus_center - x, beta=softplus_sharpness + ) + ) + + self.inverse_clamp_lower_upper = lambda x: ( + sigmoid_center + + utils.inverse_sigmoid( + (x - self.sigmoid_lower_lims) + / (self.sigmoid_upper_lims - 
self.sigmoid_lower_lims) + ) + / sigmoid_sharpness + ) + self.inverse_clamp_lower = lambda x: ( + utils.inverse_softplus( + x - self.softplus_lower_lims, beta=softplus_sharpness + ) + + softplus_center + ) + self.inverse_clamp_upper = lambda x: ( + -utils.inverse_softplus( + self.softplus_upper_lims - x, beta=softplus_sharpness + ) + + softplus_center + ) + + def get_clamped_new_state(self, state_delta, prev_state): + """ + Clamp prediction to valid range supplied in config + Returns the clamped new state after adding delta to original state + + Instead of the new state being computed as + $X_{t+1} = X_t + \\delta = X_t + model(\\{X_t,X_{t-1},...\\}, forcing)$ + The clamped values will be + $f(f^{-1}(X_t) + model(\\{X_t, X_{t-1},... \\}, forcing))$ + Which means the model will learn to output values in the range of the + inverse clamping function + + state_delta: (B, num_grid_nodes, feature_dim) + prev_state: (B, num_grid_nodes, feature_dim) + """ + + # Assign new state, but overwrite clamped values of each type later + new_state = prev_state + state_delta + + # Sigmoid/logistic clamps between ]a,b[ + if self.clamp_lower_upper_idx.numel() > 0: + idx = self.clamp_lower_upper_idx + + new_state[:, :, idx] = self.clamp_lower_upper( + self.inverse_clamp_lower_upper(prev_state[:, :, idx]) + + state_delta[:, :, idx] + ) + + # Softplus clamps between ]a,infty[ + if self.clamp_lower_idx.numel() > 0: + idx = self.clamp_lower_idx + + new_state[:, :, idx] = self.clamp_lower( + self.inverse_clamp_lower(prev_state[:, :, idx]) + + state_delta[:, :, idx] + ) + + # Softplus clamps between ]-infty,b[ + if self.clamp_upper_idx.numel() > 0: + idx = self.clamp_upper_idx + + new_state[:, :, idx] = self.clamp_upper( + self.inverse_clamp_upper(prev_state[:, :, idx]) + + state_delta[:, :, idx] + ) + + return new_state diff --git a/neural_lam/plot_graph.py b/neural_lam/plot_graph.py index 6ed7d0268..39b8639f3 100644 --- a/neural_lam/plot_graph.py +++ b/neural_lam/plot_graph.py @@ -16,33 
+16,47 @@ GRID_HEIGHT = 0 -def plot_graph( - grid_pos, - hierarchical, - graph_ldict, - show_axis=False, - save=None, -): - """Build a 3D plotly figure of the graph structure. +def main(): + """Plot graph structure in 3D using plotly.""" + parser = ArgumentParser( + description="Plot graph", + formatter_class=ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--datastore_config_path", + type=str, + default="tests/datastore_examples/mdp/config.yaml", + help="Path for the datastore config", + ) + parser.add_argument( + "--graph", + type=str, + default="multiscale", + help="Graph to plot", + ) + parser.add_argument( + "--save", + type=str, + help="Name of .html file to save interactive plot to", + ) + parser.add_argument( + "--show_axis", + action="store_true", + help="If the axis should be displayed", + ) - Parameters - ---------- - grid_pos : np.ndarray - Grid node positions, shape (N_grid, 2). - hierarchical : bool - Whether the loaded graph is hierarchical. - graph_ldict : dict - Graph dict as returned by ``utils.load_graph``. - show_axis : bool - If True, show the 3D axis. - save : str or None - If given, save the figure as an HTML file at this path. + args = parser.parse_args() + _, datastore = load_config_and_datastore( + config_path=args.datastore_config_path + ) - Returns - ------- - go.Figure - The plotly figure object. 
- """ + xy = datastore.get_xy("state", stacked=True) # (N_grid, 2) + pos_max = np.max(np.abs(xy)) + grid_pos = xy / pos_max # Divide by maximum coordinate + + # Load graph data + graph_dir_path = os.path.join(datastore.root_path, "graph", args.graph) + hierarchical, graph_ldict = utils.load_graph(graph_dir_path=graph_dir_path) ( g2m_edge_index, m2g_edge_index, @@ -64,22 +78,6 @@ def plot_graph( (grid_pos, np.expand_dims(z_grid, axis=1)), axis=1 ) - # Normalize mesh_static_features to a list of tensors for zero_index - # functions: hierarchical -> BufferList, non-hierarchical -> single tensor - msf_as_list = ( - list(mesh_static_features) if hierarchical else [mesh_static_features] - ) - - # The plotting requires the edges to be non-zero-indexed - # with grid indices following mesh indices - m2g_edge_index = utils.zero_index_m2g( - m2g_edge_index, msf_as_list, mesh_first=True, restore=True - ) - - g2m_edge_index = utils.zero_index_g2m( - g2m_edge_index, msf_as_list, mesh_first=True, restore=True - ) - # List of edges to plot, (edge_index, color, line_width, label) edge_plot_list = [ (m2g_edge_index.numpy(), "black", 0.4, "M2G"), @@ -105,37 +103,22 @@ def plot_graph( ] mesh_pos = np.concatenate(mesh_level_pos, axis=0) - # Compute cumulative node offsets per level (in the concatenated - # mesh_pos array, level-l nodes start at level_offsets[l]) - # This is needed as the zero-indexing is applied to each level in - # in load_graph() - level_sizes = [msf.shape[0] for msf in mesh_static_features] - level_offsets = np.cumsum([0] + level_sizes[:-1]) - - # Add intra-level mesh edges (m2m per level) - # Edge indices are zero-indexed per level, so shift by level offset - for level, level_ei in enumerate(m2m_edge_index): - ei_shifted = level_ei.numpy() + level_offsets[level] - edge_plot_list.append((ei_shifted, "blue", 1, f"M2M Level {level}")) - - # Add inter-level mesh edges (up/down connect adjacent levels) - for level, level_up_ei in enumerate(mesh_up_edge_index): - ei_up = 
level_up_ei.numpy().copy() - # Row 0: source in level l, Row 1: target in level l+1 - ei_up[0] += level_offsets[level] - ei_up[1] += level_offsets[level + 1] - edge_plot_list.append( - (ei_up, "green", 1, f"Mesh up {level}->{level + 1}") - ) + # Add inter-level mesh edges + edge_plot_list += [ + (level_ei.numpy(), "blue", 1, f"M2M Level {level}") + for level, level_ei in enumerate(m2m_edge_index) + ] - for level, level_down_ei in enumerate(mesh_down_edge_index): - ei_down = level_down_ei.numpy().copy() - # Row 0: source in level l+1, Row 1: target in level l - ei_down[0] += level_offsets[level + 1] - ei_down[1] += level_offsets[level] - edge_plot_list.append( - (ei_down, "green", 1, f"Mesh down {level + 1}->{level}") - ) + # Add intra-level mesh edges + up_edges_ei = np.concatenate( + [level_up_ei.numpy() for level_up_ei in mesh_up_edge_index], axis=1 + ) + down_edges_ei = np.concatenate( + [level_down_ei.numpy() for level_down_ei in mesh_down_edge_index], + axis=1, + ) + edge_plot_list.append((up_edges_ei, "green", 1, "Mesh up")) + edge_plot_list.append((down_edges_ei, "green", 1, "Mesh down")) mesh_node_size = 2.5 else: @@ -162,8 +145,8 @@ def plot_graph( width, label, ) in edge_plot_list: - edge_start = node_pos[ei[0]] # (M, 3) - edge_end = node_pos[ei[1]] # (M, 3) + edge_start = node_pos[ei[0]] # (M, 2) + edge_end = node_pos[ei[1]] # (M, 2) n_edges = edge_start.shape[0] x_edges = np.stack( @@ -214,7 +197,8 @@ def plot_graph( fig.update_layout(scene_aspectmode="data") fig.update_traces(connectgaps=False) - if not show_axis: + if not args.show_axis: + # Hide axis fig.update_layout( scene={ "xaxis": {"visible": False}, @@ -223,63 +207,9 @@ def plot_graph( } ) - if save: - fig.write_html(save, include_plotlyjs="cdn") - - return fig - - -def main(): - """Plot graph structure in 3D using plotly.""" - parser = ArgumentParser( - description="Plot graph", - formatter_class=ArgumentDefaultsHelpFormatter, - ) - parser.add_argument( - "--datastore_config_path", - type=str, 
- default="tests/datastore_examples/mdp/config.yaml", - help="Path for the datastore config", - ) - parser.add_argument( - "--graph", - type=str, - default="multiscale", - help="Graph to plot", - ) - parser.add_argument( - "--save", - type=str, - help="Name of .html file to save interactive plot to", - ) - parser.add_argument( - "--show_axis", - action="store_true", - help="If the axis should be displayed", - ) - - args = parser.parse_args() - _, datastore = load_config_and_datastore( - config_path=args.datastore_config_path - ) - - xy = datastore.get_xy("state", stacked=True) # (N_grid, 2) - pos_max = np.max(np.abs(xy)) - grid_pos = xy / pos_max # Divide by maximum coordinate - - # Load graph data - graph_dir_path = os.path.join(datastore.root_path, "graph", args.graph) - hierarchical, graph_ldict = utils.load_graph(graph_dir_path=graph_dir_path) - - fig = plot_graph( - grid_pos=grid_pos, - hierarchical=hierarchical, - graph_ldict=graph_ldict, - show_axis=args.show_axis, - save=args.save, - ) - - if not args.save: + if args.save: + fig.write_html(args.save, include_plotlyjs="cdn") + else: fig.show() diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py index 06cd608a6..ca42ce96a 100644 --- a/neural_lam/train_model.py +++ b/neural_lam/train_model.py @@ -14,15 +14,10 @@ # Local from . import utils from .config import load_config_and_datastore -from .models import GraphLAM, HiLAM, HiLAMParallel +from .models import MODELS, ForecasterModule +from .models.ar_forecaster import ARForecaster from .weather_dataset import WeatherDataModule -MODELS = { - "graph_lam": GraphLAM, - "hi_lam": HiLAM, - "hi_lam_parallel": HiLAMParallel, -} - @logger.catch def main(input_args=None): @@ -192,18 +187,6 @@ def main(input_args=None): help="""Logger run name, for e.g. MLFlow (with default value `None` neural-lam's default format string is used)""", ) - - # Wandb-specific settings - parser.add_argument( - "--wandb_id", - type=str, - default=None, - help="Wandb run ID to use. 
If the run ID already exists in the " - "project, W&B resumes that run. If it does not exist, W&B creates " - "a new run with that ID. Useful on HPC systems with limited job " - "runtimes or that may crash, allowing training to be continued " - "across multiple job submissions.", - ) parser.add_argument( "--val_steps_to_log", nargs="+", @@ -236,14 +219,6 @@ def main(input_args=None): default=1, help="Number of future time steps to use as input for forcing data", ) - parser.add_argument( - "--load_single_member", - action="store_true", - help=( - "If set, only use ensemble member 0 instead of treating all " - "ensemble members as independent samples." - ), - ) args = parser.parse_args(input_args) args.var_leads_metrics_watch = { int(k): v for k, v in json.loads(args.var_leads_metrics_watch).items() @@ -252,27 +227,19 @@ def main(input_args=None): # Check that config only specifies logging for lead times that exist # Check --val_steps_to_log for step in args.val_steps_to_log: - if step > args.ar_steps_eval: - raise ValueError( - f"Can not log validation step {step} when validation is " - f"only unrolled {args.ar_steps_eval} steps. Adjust " - "--val_steps_to_log." - ) + assert 0 < step <= args.ar_steps_eval, ( + f"Can not log validation step {step} when validation is " + f"only unrolled {args.ar_steps_eval} steps. Adjust " + "--val_steps_to_log." + ) # Check --var_leads_metric_watch for var_i, leads in args.var_leads_metrics_watch.items(): for step in leads: - if step > args.ar_steps_eval: - raise ValueError( - f"Can not log validation step {step} for variable " - f"{var_i} when validation is only unrolled " - f"{args.ar_steps_eval} steps. Adjust " - "--var_leads_metric_watch." 
- ) - - if args.eval and not args.load: - logger.warning( - "Evaluation (--eval) without --load: no checkpoint will be loaded.", - ) + assert 0 < step <= args.ar_steps_eval, ( + f"Can not log validation step {step} for variable {var_i} when " + f"validation is only unrolled {args.ar_steps_eval} steps. " + "Adjust --var_leads_metric_watch." + ) # Get an (actual) random run id as a unique identifier random_run_id = random.randint(0, 9999) @@ -291,7 +258,6 @@ def main(input_args=None): standardize=True, num_past_forcing_steps=args.num_past_forcing_steps, num_future_forcing_steps=args.num_future_forcing_steps, - load_single_member=args.load_single_member, batch_size=args.batch_size, num_workers=args.num_workers, eval_split=args.eval or "test", @@ -315,9 +281,35 @@ def main(input_args=None): except ValueError: raise ValueError("devices should be 'auto' or a list of integers") - # Load model parameters Use new args for model - ModelClass = MODELS[args.model] - model = ModelClass(args, config=config, datastore=datastore) + # Build predictor and forecaster externally, then inject into + # ForecasterModule + predictor_class = MODELS[args.model] + predictor = predictor_class( + config=config, + datastore=datastore, + graph=args.graph, + hidden_dim=args.hidden_dim, + hidden_layers=args.hidden_layers, + processor_layers=args.processor_layers, + mesh_aggr=args.mesh_aggr, + num_past_forcing_steps=args.num_past_forcing_steps, + num_future_forcing_steps=args.num_future_forcing_steps, + output_std=args.output_std, + ) + forecaster = ARForecaster(predictor, datastore) + + model = ForecasterModule( + forecaster=forecaster, + config=config, + datastore=datastore, + loss=args.loss, + lr=args.lr, + restore_opt=args.restore_opt, + n_example_pred=args.n_example_pred, + val_steps_to_log=args.val_steps_to_log, + metrics_watch=args.metrics_watch, + var_leads_metrics_watch=args.var_leads_metrics_watch, + ) if args.eval: prefix = f"eval-{args.eval}-" diff --git a/neural_lam/utils.py 
b/neural_lam/utils.py index 742ef9823..6e6aa6a9b 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -10,7 +10,6 @@ # Third-party import pytorch_lightning as pl import torch -from loguru import logger from pytorch_lightning.loggers import MLFlowLogger, WandbLogger from pytorch_lightning.utilities import rank_zero_only from torch import nn @@ -44,128 +43,6 @@ def __len__(self): def __iter__(self): return (self[i] for i in range(len(self))) - def __itruediv__(self, other): - """Divide each element in list with other""" - return self.__imul__(1.0 / other) - - def __imul__(self, other): - """Multiply each element in list with other""" - for buffer_tensor in self: - buffer_tensor *= other - - return self - - -def zero_index_edge_index(edge_index): - """ - Make both sender and receiver indices of edge_index start at 0 - """ - return edge_index - edge_index.min(dim=1, keepdim=True)[0] - - -def zero_index_m2g( - m2g_edge_index: torch.Tensor, - mesh_static_features: list[torch.Tensor], - mesh_first: bool, - restore: bool = False, -) -> torch.Tensor: - """ - Zero-index the m2g (mesh-to-grid) edge index, or undo this operation. - - Special handling is needed since not all mesh nodes may be present. - - Parameters - ---------- - m2g_edge_index : torch.Tensor - Edge index tensor of shape (2, N_edges). - mesh_static_features : list of torch.Tensor - Mesh node feature tensors. - mesh_first : bool - If True, mesh nodes are indexed before grid nodes. - restore : bool - If True, undo zero-indexing (restore original indices). - - Returns - ------- - torch.Tensor - Edge index tensor with zero-based or restored indices. 
- """ - - sign = 1 if restore else -1 - - if mesh_first: - # Mesh has the first indices, adjust grid indices (row 1) - num_mesh_nodes = mesh_static_features[0].shape[0] - return torch.stack( - ( - m2g_edge_index[0], - m2g_edge_index[1] + sign * num_mesh_nodes, - ), - dim=0, - ) - else: - # Grid (interior) has the first indices, adjust mesh indices (row 0) - num_interior_nodes = m2g_edge_index[1].max() + 1 - return torch.stack( - ( - m2g_edge_index[0] + sign * num_interior_nodes, - m2g_edge_index[1], - ), - dim=0, - ) - - -def zero_index_g2m( - g2m_edge_index: torch.Tensor, - mesh_static_features: list[torch.Tensor], - mesh_first: bool, - restore: bool = False, -) -> torch.Tensor: - """ - Zero-index the g2m (grid-to-mesh) edge index, or undo this operation. - - Special handling is needed since not all mesh nodes may be present. - - Parameters - ---------- - g2m_edge_index : torch.Tensor - Edge index tensor of shape (2, N_edges). - mesh_static_features : list of torch.Tensor - Mesh node feature tensors. - mesh_first : bool - If True, mesh nodes are indexed before grid nodes. - restore : bool - If True, undo zero-indexing (restore original indices). - - Returns - ------- - torch.Tensor - Edge index tensor with zero-based or restored indices. - """ - - sign = 1 if restore else -1 - - if mesh_first: - # Mesh has the first indices, adjust grid indices (row 0) - num_mesh_nodes = mesh_static_features[0].shape[0] - return torch.stack( - ( - g2m_edge_index[0] + sign * num_mesh_nodes, - g2m_edge_index[1], - ), - dim=0, - ) - else: - # Grid has the first indices, adjust mesh indices (row 1) - num_grid_nodes = g2m_edge_index[0].max() + 1 - return torch.stack( - ( - g2m_edge_index[0], - g2m_edge_index[1] + sign * num_grid_nodes, - ), - dim=0, - ) - def load_graph(graph_dir_path, device="cpu"): """Load all tensors representing the graph from `graph_dir_path`. 
@@ -219,34 +96,13 @@ def loads_file(fn): weights_only=True, ) - # Load static node features - mesh_static_features = loads_file( - "mesh_features.pt" - ) # List of (N_mesh[l], d_mesh_static) - # Load edges (edge_index) m2m_edge_index = BufferList( - [zero_index_edge_index(ei) for ei in loads_file("m2m_edge_index.pt")], - persistent=False, + loads_file("m2m_edge_index.pt"), persistent=False ) # List of (2, M_m2m[l]) g2m_edge_index = loads_file("g2m_edge_index.pt") # (2, M_g2m) m2g_edge_index = loads_file("m2g_edge_index.pt") # (2, M_m2g) - # Change first indices to 0 - # m2g and g2m has to be handled specially as not all mesh nodes - # might be indexed - m2g_min_indices = m2g_edge_index.min(dim=1, keepdim=True)[0] - mesh_first = m2g_min_indices[0] < m2g_min_indices[1] - g2m_edge_index = zero_index_g2m( - g2m_edge_index, mesh_static_features, mesh_first=mesh_first - ) - m2g_edge_index = zero_index_m2g( - m2g_edge_index, mesh_static_features, mesh_first=mesh_first - ) - - assert m2g_edge_index.min() >= 0, "Negative node index in m2g" - assert g2m_edge_index.min() >= 0, "Negative node index in g2m" - n_levels = len(m2m_edge_index) hierarchical = n_levels > 1 # Nor just single level mesh graph @@ -260,12 +116,18 @@ def loads_file(fn): longest_edge = max( torch.max(level_features[:, 0]) for level_features in m2m_features ) # Col. 
0 is length - - m2m_features = BufferList(m2m_features, persistent=False) - m2m_features /= longest_edge + m2m_features = BufferList( + [level_features / longest_edge for level_features in m2m_features], + persistent=False, + ) g2m_features = g2m_features / longest_edge m2g_features = m2g_features / longest_edge + # Load static node features + mesh_static_features = loads_file( + "mesh_features.pt" + ) # List of (N_mesh[l], d_mesh_static) + # Some checks for consistency assert ( len(m2m_features) == n_levels @@ -277,18 +139,10 @@ def loads_file(fn): if hierarchical: # Load up and down edges and features mesh_up_edge_index = BufferList( - [ - zero_index_edge_index(ei) - for ei in loads_file("mesh_up_edge_index.pt") - ], - persistent=False, + loads_file("mesh_up_edge_index.pt"), persistent=False ) # List of (2, M_up[l]) mesh_down_edge_index = BufferList( - [ - zero_index_edge_index(ei) - for ei in loads_file("mesh_down_edge_index.pt") - ], - persistent=False, + loads_file("mesh_down_edge_index.pt"), persistent=False ) # List of (2, M_down[l]) mesh_up_features = loads_file( @@ -299,10 +153,20 @@ def loads_file(fn): ) # List of (M_down[l], d_edge_f) # Rescale - mesh_up_features = BufferList(mesh_up_features, persistent=False) - mesh_up_features /= longest_edge - mesh_down_features = BufferList(mesh_down_features, persistent=False) - mesh_down_features /= longest_edge + mesh_up_features = BufferList( + [ + edge_features / longest_edge + for edge_features in mesh_up_features + ], + persistent=False, + ) + mesh_down_features = BufferList( + [ + edge_features / longest_edge + for edge_features in mesh_down_features + ], + persistent=False, + ) mesh_static_features = BufferList( mesh_static_features, persistent=False @@ -442,19 +306,9 @@ def fractional_plot_bundle(fraction): @rank_zero_only -def log_on_rank_zero(msg: str, level: str = "info", *args, **kwargs): - """Log a message only on rank zero using loguru logger. 
- - Parameters - ---------- - msg : str - The message to log. - level : str, optional - The logging level (e.g. "info", "warning", "error"). Default is "info". - """ - if rank_zero_only.rank == 0: - log_fn = getattr(logger, level, logger.info) - log_fn(msg, *args, **kwargs) +def rank_zero_print(*args, **kwargs): + """Print only from rank 0 process""" + print(*args, **kwargs) def init_training_logger_metrics(training_logger, val_steps): @@ -477,7 +331,7 @@ def init_training_logger_metrics(training_logger, val_steps): @rank_zero_only def setup_training_logger(datastore, args, run_name): - """Set up the training logger (WandB or MLFlow). + """ Parameters ---------- @@ -492,56 +346,32 @@ def setup_training_logger(datastore, args, run_name): Returns ------- - training_logger : pytorch_lightning.loggers.base + logger : pytorch_lightning.loggers.base Logger object. - - Notes - ----- - When ``--wandb_id`` is given, ``resume="allow"`` is set automatically: - W&B resumes the run if it exists, or creates it with that ID otherwise. - This allows the same job script to be safely resubmitted on HPC systems. - The run name is set to ``None`` when resuming to preserve the existing name. """ - if args.wandb_id and args.logger != "wandb": - logger.warning( - f"--wandb_id is set but logger is {args.logger!r}; " - "the wandb_id will have no effect." - ) - if args.logger == "wandb": - wandb_resume = "allow" if args.wandb_id else None - logger.info( - f"Wandb resume mode: {wandb_resume!r} (id: {args.wandb_id!r})" - ) - return pl.loggers.WandbLogger( + logger = pl.loggers.WandbLogger( project=args.logger_project, - name=None if args.wandb_id else run_name, + name=run_name, config=dict(training=vars(args), datastore=datastore._config), - resume=wandb_resume, - id=args.wandb_id, ) elif args.logger == "mlflow": - if args.wandb_id is not None: - warnings.warn( - "--wandb_id is only used with --logger=wandb and will be " - "ignored." 
- ) url = os.getenv("MLFLOW_TRACKING_URI") if url is None: raise ValueError( "MLFlow logger requires setting MLFLOW_TRACKING_URI in env." ) - training_logger = CustomMLFlowLogger( + logger = CustomMLFlowLogger( experiment_name=args.logger_project, tracking_uri=url, run_name=run_name, ) - training_logger.log_hyperparams( + logger.log_hyperparams( dict(training=vars(args), datastore=datastore._config) ) - return training_logger + return logger def inverse_softplus(x, beta=1, threshold=20): diff --git a/neural_lam/vis.py b/neural_lam/vis.py index 06f2f6d35..3db4365a4 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -1,158 +1,13 @@ # Third-party -import cartopy.crs as ccrs -import cartopy.feature as cfeature import matplotlib -import matplotlib.colors import matplotlib.pyplot as plt import numpy as np -import torch import xarray as xr # Local from . import utils from .datastore.base import BaseRegularGridDatastore -# Font sizes shared across all plot functions for visual consistency. -_TITLE_SIZE = 13 # suptitle and per-axes titles -_LABEL_SIZE = 11 # axis / colorbar labels -_TICK_SIZE = 11 # tick labels - - -def _tex_safe(s: str) -> str: - """Escape TeX special characters in s if TeX rendering is currently active. - - Needed because % is a TeX comment character; without escaping it would - silently truncate any text that follows it (e.g. the title for r2m (%)). - """ - if plt.rcParams.get("text.usetex", False): - s = s.replace("%", r"\%") - return s - - -def plot_on_axis( - ax, - da, - datastore, - vmin=None, - vmax=None, - ax_title=None, - cmap="plasma", - boundary_alpha=None, - crop_to_interior=False, -): - """Plot weather state on given axis using datastore metadata. - - Parameters - ---------- - ax : matplotlib.axes.Axes - The axis to plot on. Should have a cartopy projection. - da : xarray.DataArray - The data to plot. Should have shape (N_grid,). - datastore : BaseRegularGridDatastore - The datastore containing metadata about the grid. 
- vmin : float, optional - Minimum value for color scale. - vmax : float, optional - Maximum value for color scale. - ax_title : str, optional - Title for the axis. - cmap : str or matplotlib.colors.Colormap, optional - Colormap to use for plotting. - boundary_alpha : float, optional - If provided, overlay boundary mask with given alpha transparency. - crop_to_interior : bool, optional - If True, crop the plot to the interior region. - - Returns - ------- - matplotlib.collections.QuadMesh - The mesh object created by pcolormesh. - - """ - - ax.coastlines(resolution="50m") - ax.add_feature(cfeature.BORDERS, linestyle="-", alpha=0.5) - - gl = ax.gridlines( - draw_labels=True, - dms=True, - x_inline=False, - y_inline=False, - ) - gl.top_labels = False - gl.right_labels = False - gl.xlabel_style = {"size": _TICK_SIZE} - gl.ylabel_style = {"size": _TICK_SIZE} - - lats_lons = datastore.get_lat_lon("state") - grid_shape = ( - datastore.grid_shape_state.x, - datastore.grid_shape_state.y, - ) - lons = lats_lons[:, 0].reshape(grid_shape) - lats = lats_lons[:, 1].reshape(grid_shape) - - if isinstance(da, xr.DataArray) and "x" in da.dims and "y" in da.dims: - da = da.transpose("x", "y") - - values = da.values.reshape(grid_shape) - - mesh = ax.pcolormesh( - lons, - lats, - values, - transform=ccrs.PlateCarree(), - vmin=vmin, - vmax=vmax, - cmap=cmap, - shading="auto", - ) - - if boundary_alpha is not None: - # Overlay boundary mask - mask_da = datastore.boundary_mask - mask_values = mask_da.values - if mask_values.ndim == 2 and mask_values.shape[1] == 1: - mask_values = mask_values[:, 0] - mask_2d = mask_values.reshape(grid_shape) - - # Create overlay: 1 where boundary, NaN where interior - overlay = np.where(mask_2d == 1, 1.0, np.nan) - - ax.pcolormesh( - lons, - lats, - overlay, - transform=ccrs.PlateCarree(), - cmap=matplotlib.colors.ListedColormap([(1, 1, 1, boundary_alpha)]), - shading="auto", - ) - - if crop_to_interior: - # Calculate extent of interior - mask_da = 
datastore.boundary_mask - mask_values = mask_da.values - if mask_values.ndim == 2 and mask_values.shape[1] == 1: - mask_values = mask_values[:, 0] - mask_2d = mask_values.reshape(grid_shape) - - interior_points = mask_2d == 0 - if np.any(interior_points): - interior_lons = lons[interior_points] - interior_lats = lats[interior_points] - - min_lon, max_lon = interior_lons.min(), interior_lons.max() - min_lat, max_lat = interior_lats.min(), interior_lats.max() - - ax.set_extent( - [min_lon, max_lon, min_lat, max_lat], crs=ccrs.PlateCarree() - ) - - if ax_title: - ax.set_title(ax_title, size=_TITLE_SIZE) - - return mesh - @matplotlib.rc_context(utils.fractional_plot_bundle(1)) def plot_error_map(errors, datastore: BaseRegularGridDatastore, title=None): @@ -190,23 +45,23 @@ def plot_error_map(errors, datastore: BaseRegularGridDatastore, title=None): ax.text(i, j, formatted_error, ha="center", va="center", usetex=False) # Ticks and labels + label_size = 15 ax.set_xticks(np.arange(pred_steps)) pred_hor_i = np.arange(pred_steps) + 1 pred_hor_h = time_step_int * pred_hor_i - ax.set_xticklabels(pred_hor_h, size=_TICK_SIZE) - ax.set_xlabel(f"Lead time ({time_step_unit[0]})", size=_LABEL_SIZE) + ax.set_xticklabels(pred_hor_h, size=label_size) + ax.set_xlabel(f"Lead time ({time_step_unit[0]})", size=label_size) ax.set_yticks(np.arange(d_f)) var_names = datastore.get_vars_names(category="state") var_units = datastore.get_vars_units(category="state") y_ticklabels = [ - _tex_safe(f"{name} ({unit})") - for name, unit in zip(var_names, var_units) + f"{name} ({unit})" for name, unit in zip(var_names, var_units) ] - ax.set_yticklabels(y_ticklabels, rotation=30, size=_TICK_SIZE) + ax.set_yticklabels(y_ticklabels, rotation=30, size=label_size) if title: - ax.set_title(title, size=_TITLE_SIZE) + ax.set_title(title, size=15) return fig @@ -218,9 +73,6 @@ def plot_prediction( da_target: xr.DataArray, title=None, vrange=None, - boundary_alpha=0.7, - crop_to_interior=True, - colorbar_label: 
str = "", ): """ Plot example prediction and grond truth. @@ -228,102 +80,105 @@ def plot_prediction( Each has shape (N_grid,) """ + # Get common scale for values if vrange is None: - vmin = float(min(da_prediction.min(), da_target.min())) - vmax = float(max(da_prediction.max(), da_target.max())) - else: + vmin = min(da_prediction.min(), da_target.min()) + vmax = max(da_prediction.max(), da_target.max()) + elif vrange is not None: vmin, vmax = vrange + extent = datastore.get_xy_extent("state") + + # Set up masking of border region + da_mask = datastore.unstack_grid_coords(datastore.boundary_mask) + mask_values = np.invert(da_mask.values.astype(bool)).astype(float) + pixel_alpha = mask_values.clip(0.7, 1) # Faded border region + fig, axes = plt.subplots( 1, 2, - figsize=(13, 6), + figsize=(13, 7), subplot_kw={"projection": datastore.coords_projection}, ) - for ax, da, subtitle in zip( - axes, (da_target, da_prediction), ("Ground Truth", "Prediction") - ): - plot_on_axis( + # Plot pred and target + for ax, da in zip(axes, (da_target, da_prediction)): + ax.coastlines() # Add coastline outlines + da.plot.imshow( ax=ax, - da=da, - datastore=datastore, + origin="lower", + x="x", + extent=extent, + alpha=pixel_alpha.T, vmin=vmin, vmax=vmax, - ax_title=subtitle, - cmap="viridis", - boundary_alpha=boundary_alpha, - crop_to_interior=crop_to_interior, + cmap="plasma", + transform=datastore.coords_projection, ) + # Ticks and labels + axes[0].set_title("Ground Truth", size=15) + axes[1].set_title("Prediction", size=15) + if title: - fig.suptitle(title, size=_TITLE_SIZE) - - cbar = fig.colorbar( - axes[0].collections[0], - ax=axes, - orientation="horizontal", - location="bottom", - shrink=0.6, - pad=0.02, - ) - cbar.ax.tick_params(labelsize=_TICK_SIZE) - if colorbar_label: - cbar.set_label(_tex_safe(colorbar_label), size=_LABEL_SIZE) + fig.suptitle(title, size=20) return fig @matplotlib.rc_context(utils.fractional_plot_bundle(1)) def plot_spatial_error( - error: torch.Tensor, - 
datastore: BaseRegularGridDatastore, - title=None, - vrange=None, - boundary_alpha=0.7, - crop_to_interior=True, - colorbar_label: str = "", + error, datastore: BaseRegularGridDatastore, title=None, vrange=None ): - """Plot spatial error with projection-aware axes.""" - - error_np = error.detach().cpu().numpy() - + """ + Plot errors over spatial map + Error and obs_mask has shape (N_grid,) + """ + # Get common scale for values if vrange is None: - vmin = float(np.nanmin(error_np)) - vmax = float(np.nanmax(error_np)) + vmin = error.min().cpu().item() + vmax = error.max().cpu().item() else: vmin, vmax = vrange + extent = datastore.get_xy_extent("state") + + # Set up masking of border region + da_mask = datastore.unstack_grid_coords(datastore.boundary_mask) + mask_reshaped = da_mask.values + pixel_alpha = mask_reshaped.clip(0.7, 1) # Faded border region + fig, ax = plt.subplots( - figsize=(6.5, 6), + figsize=(5, 4.8), subplot_kw={"projection": datastore.coords_projection}, ) - mesh = plot_on_axis( - ax=ax, - da=xr.DataArray(error_np), - datastore=datastore, + ax.coastlines() # Add coastline outlines + error_grid = ( + error.reshape( + [datastore.grid_shape_state.x, datastore.grid_shape_state.y] + ) + .T.cpu() + .numpy() + ) + + im = ax.imshow( + error_grid, + origin="lower", + extent=extent, + alpha=pixel_alpha, vmin=vmin, vmax=vmax, cmap="OrRd", - boundary_alpha=boundary_alpha, - crop_to_interior=crop_to_interior, ) - cbar = fig.colorbar( - mesh, - ax=ax, - orientation="horizontal", - location="bottom", - shrink=0.8, - pad=0.02, - ) - cbar.ax.tick_params(labelsize=_TICK_SIZE) + # Ticks and labels + cbar = fig.colorbar(im, aspect=30) + cbar.ax.tick_params(labelsize=10) + cbar.ax.yaxis.get_offset_text().set_fontsize(10) cbar.formatter.set_powerlimits((-3, 3)) - if colorbar_label: - cbar.set_label(_tex_safe(colorbar_label), size=_LABEL_SIZE) if title: - fig.suptitle(title, size=_TITLE_SIZE) + fig.suptitle(title, size=10) return fig diff --git 
a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 5547fdd4e..0ec85a31e 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -8,7 +8,6 @@ import pytorch_lightning as pl import torch import xarray as xr -from loguru import logger # First-party from neural_lam.datastore.base import BaseDatastore @@ -37,11 +36,6 @@ class WeatherDataset(torch.utils.data.Dataset): forcing from times t, t+1, ..., t+j-1, t+j (and potentially times before t, given num_past_forcing_steps) are included as forcing inputs at time t. Default is 1. - load_single_member : bool, optional - If `False` and the datastore returns an ensemble of state - realisations, treat each state ensemble member as an independent - sample. If `True`, only ensemble member 0 is used. Default is False, - so all members are used when available. standardize : bool, optional Whether to standardize the data. Default is True. """ @@ -53,7 +47,6 @@ def __init__( ar_steps: int = 3, num_past_forcing_steps: int = 1, num_future_forcing_steps: int = 1, - load_single_member: bool = False, standardize: bool = True, ): super().__init__() @@ -63,7 +56,6 @@ def __init__( self.datastore = datastore self.num_past_forcing_steps = num_past_forcing_steps self.num_future_forcing_steps = num_future_forcing_steps - self.load_single_member = load_single_member self.da_state = self.datastore.get_dataarray( category="state", split=self.split @@ -71,19 +63,6 @@ def __init__( self.da_forcing = self.datastore.get_dataarray( category="forcing", split=self.split ) - if self.da_state is None: - raise ValueError( - "The datastore must provide state data for the WeatherDataset." 
- ) - - if self.datastore.is_ensemble and self.load_single_member: - warnings.warn( - "only using first ensemble member, so dataset size is " - "effectively reduced by the number of ensemble members " - f"({self.da_state.ensemble_member.size})", - UserWarning, - stacklevel=2, - ) # check that with the provided data-arrays and ar_steps that we have a # non-zero amount of samples @@ -104,10 +83,10 @@ def __init__( parts["forcing"] = self.da_forcing for part, da in parts.items(): + expected_dim_order = self.datastore.expected_dim_order( + category=part + ) if da is not None: - expected_dim_order = self.datastore.expected_dim_order( - category=part - ) if da.dims != expected_dim_order: raise ValueError( f"The dimension order of the `{part}` data ({da.dims}) " @@ -135,37 +114,22 @@ def __init__( ) self.da_forcing_mean = self.ds_forcing_stats.forcing_mean self.da_forcing_std = self.ds_forcing_stats.forcing_std - else: - self.da_forcing_mean = None - self.da_forcing_std = None - - self.state_std_safe = self._compute_std_safe( - self.da_state_std, "state" - ) - - if self.da_forcing_std is not None: - self.forcing_std_safe = self._compute_std_safe( - self.da_forcing_std, "forcing" - ) - else: - self.forcing_std_safe = None - - def _compute_std_safe(self, std: xr.DataArray, feature: str): - eps = np.finfo(std.dtype).eps - if bool((std <= eps).any()): - logger.warning( - f"Some {feature} features have near-zero std and will be " - "standardized using machine epsilon to avoid NaN." - ) - return std.where(std > eps, other=eps) def __len__(self): if self.datastore.is_forecast: # for now we simply create a single sample for each analysis time - # and then take the first (2 + ar_steps) forecast times. - # If the datastore returns an ensemble of state realisations and - # `load_single_member=False`, each ensemble member is exposed as an - # independent sample by scaling the base dataset length below. + # and then take the first (2 + ar_steps) forecast times. 
In + # addition we only use the first ensemble member (if ensemble data + # has been provided). + # This means that for each analysis time we get a single sample + + if self.datastore.is_ensemble: + warnings.warn( + "only using first ensemble member, so dataset size is " + " effectively reduced by the number of ensemble members " + f"({self.da_state.ensemble_member.size})", + UserWarning, + ) # check that there are enough forecast steps available to create # samples given the number of autoregressive steps requested @@ -178,7 +142,7 @@ def __len__(self): "creating a sample with initial and target states." ) - base_len = self.da_state.analysis_time.size + return self.da_state.analysis_time.size else: # Calculate the number of samples in the dataset n_samples = total # time steps - (autoregressive steps + past forcing + future @@ -190,15 +154,12 @@ def __len__(self): # - past forcing: max(2, self.num_past_forcing_steps) (at least 2 # time steps are required for the initial state) # - future forcing: self.num_future_forcing_steps - base_len = ( + return ( len(self.da_state.time) - self.ar_steps - max(2, self.num_past_forcing_steps) - self.num_future_forcing_steps ) - if self.datastore.is_ensemble and not self.load_single_member: - return base_len * self.da_state.ensemble_member.size - return base_len def _slice_state_time(self, da_state, idx, n_steps: int): """ @@ -391,35 +352,40 @@ def _build_item_dataarrays(self, idx): da_target_times : xr.DataArray The dataarray for the target times. """ - # Handle indexing over state ensemble members. If forcing data also - # has an ensemble dimension, we select the same member below. 
- sample_idx = idx - i_ensemble = 0 - + # handling ensemble data if self.datastore.is_ensemble: - n_ensemble_members = self.da_state.ensemble_member.size - if not self.load_single_member: - sample_idx, i_ensemble = divmod(idx, n_ensemble_members) + # for the now the strategy is to only include the first ensemble + # member + # XXX: this could be changed to include all ensemble members by + # splitting `idx` into two parts, one for the analysis time and one + # for the ensemble member and then increasing self.__len__ to + # include all ensemble members + warnings.warn( + "only use of ensemble member 0 (the first member) is " + "implemented for ensemble data" + ) + i_ensemble = 0 da_state = self.da_state.isel(ensemble_member=i_ensemble) else: da_state = self.da_state if self.da_forcing is not None: - if self.datastore.has_ensemble_forcing: - da_forcing = self.da_forcing.isel(ensemble_member=i_ensemble) - else: - da_forcing = self.da_forcing + if "ensemble_member" in self.da_forcing.dims: + raise NotImplementedError( + "Ensemble member not yet supported for forcing data" + ) + da_forcing = self.da_forcing else: da_forcing = None # handle time sampling in a way that is compatible with both analysis # and forecast data da_state = self._slice_state_time( - da_state=da_state, idx=sample_idx, n_steps=self.ar_steps + da_state=da_state, idx=idx, n_steps=self.ar_steps ) if da_forcing is not None: da_forcing_windowed = self._slice_forcing_time( - da_forcing=da_forcing, idx=sample_idx, n_steps=self.ar_steps + da_forcing=da_forcing, idx=idx, n_steps=self.ar_steps ) # load the data into memory @@ -434,10 +400,10 @@ def _build_item_dataarrays(self, idx): if self.standardize: da_init_states = ( da_init_states - self.da_state_mean - ) / self.state_std_safe + ) / self.da_state_std da_target_states = ( da_target_states - self.da_state_mean - ) / self.state_std_safe + ) / self.da_state_std if da_forcing is not None: # XXX: Here we implicitly assume that the last dimension of the @@ 
-446,7 +412,7 @@ def _build_item_dataarrays(self, idx): # tensor with repeated means and stds for each "windowed" time.) da_forcing_windowed = ( da_forcing_windowed - self.da_forcing_mean - ) / self.forcing_std_safe + ) / self.da_forcing_std if da_forcing is not None: # stack the `forcing_feature` and `window_sample` dimensions into a @@ -646,7 +612,6 @@ def __init__( standardize: bool = True, num_past_forcing_steps: int = 1, num_future_forcing_steps: int = 1, - load_single_member: bool = False, batch_size: int = 4, num_workers: int = 16, eval_split: str = "test", @@ -658,7 +623,6 @@ def __init__( self.ar_steps_train = ar_steps_train self.ar_steps_eval = ar_steps_eval self.standardize = standardize - self.load_single_member = load_single_member self.batch_size = batch_size self.num_workers: int = num_workers self.train_dataset = None @@ -680,7 +644,6 @@ def setup(self, stage=None): standardize=self.standardize, num_past_forcing_steps=self.num_past_forcing_steps, num_future_forcing_steps=self.num_future_forcing_steps, - load_single_member=self.load_single_member, ) self.val_dataset = WeatherDataset( datastore=self._datastore, @@ -689,7 +652,6 @@ def setup(self, stage=None): standardize=self.standardize, num_past_forcing_steps=self.num_past_forcing_steps, num_future_forcing_steps=self.num_future_forcing_steps, - load_single_member=self.load_single_member, ) if stage == "test" or stage is None: @@ -700,7 +662,6 @@ def setup(self, stage=None): standardize=self.standardize, num_past_forcing_steps=self.num_past_forcing_steps, num_future_forcing_steps=self.num_future_forcing_steps, - load_single_member=self.load_single_member, ) def train_dataloader(self): @@ -712,7 +673,6 @@ def train_dataloader(self): shuffle=True, multiprocessing_context=self.multiprocessing_context, persistent_workers=self.num_workers > 0, - pin_memory=torch.cuda.is_available(), ) def val_dataloader(self): @@ -724,7 +684,6 @@ def val_dataloader(self): shuffle=False, 
multiprocessing_context=self.multiprocessing_context, persistent_workers=self.num_workers > 0, - pin_memory=torch.cuda.is_available(), ) def test_dataloader(self): @@ -736,5 +695,4 @@ def test_dataloader(self): shuffle=False, multiprocessing_context=self.multiprocessing_context, persistent_workers=self.num_workers > 0, - pin_memory=torch.cuda.is_available(), ) diff --git a/pyproject.toml b/pyproject.toml index 0a7e8ec1a..55cd7642f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,6 @@ authors = [ { name = "Daniel Holmberg", email = "daniel.holmberg@helsinki.fi" }, ] readme = "README.md" -license = { text = "MIT" } # PEP 621 project metadata # See https://www.python.org/dev/peps/pep-0621/ @@ -43,6 +42,9 @@ requires-python = ">=3.10" [dependency-groups] dev = ["pre-commit>=3.8.0", "pytest>=8.3.2", "pooch>=1.8.2"] +[tool.setuptools] +py-modules = ["neural_lam"] + [tool.black] line-length = 80 @@ -107,25 +109,11 @@ allow-any-import-level = "neural_lam" [tool.pylint.SIMILARITIES] min-similarity-lines = 10 -[build-system] -requires = ["hatchling>=1.27.0", "hatch-vcs"] -build-backend = "hatchling.build" - -[tool.hatch.metadata] -core-metadata-version = "2.4" -[tool.hatch.version] -source = "vcs" -fallback-version = "0.0.0" +[tool.pdm.version] +source = "scm" +fallback_version = "0.0.0" -[tool.hatch.build.targets.sdist] -exclude = [ - ".venv/", - "venv/", -] - -[tool.hatch.build.targets.wheel] -exclude = [ - ".venv/", - "venv/", -] +[build-system] +requires = ["pdm-backend"] +build-backend = "pdm.backend" diff --git a/tests/conftest.py b/tests/conftest.py index 47237ed55..3edd8a862 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,9 +5,7 @@ # Third-party import pooch -import pytest import yaml -from pytorch_lightning.utilities import rank_zero_only # First-party from neural_lam.datastore import DATASTORES, init_datastore @@ -22,14 +20,6 @@ # and to avoid having to deal with authentication os.environ["WANDB_MODE"] = "disabled" - 
-@pytest.fixture(autouse=True) -def ensure_rank_zero(monkeypatch): - """Ensure rank_zero_only.rank == 0 so @rank_zero_only-decorated functions - execute their body regardless of state left by prior training tests.""" - monkeypatch.setattr(rank_zero_only, "rank", 0, raising=False) - - DATASTORE_EXAMPLES_ROOT_PATH = Path("tests/datastore_examples") # Initializing variables for the s3 client @@ -101,7 +91,7 @@ def download_meps_example_reduced_dataset(): / "danra_100m_winds" / "danra.datastore.yaml" ), - npyfilesmeps=None, + npyfilesmeps=download_meps_example_reduced_dataset(), dummydata=None, ) @@ -109,14 +99,6 @@ def download_meps_example_reduced_dataset(): def init_datastore_example(datastore_kind): - if ( - datastore_kind == "npyfilesmeps" - and DATASTORES_EXAMPLES["npyfilesmeps"] is None - ): - DATASTORES_EXAMPLES["npyfilesmeps"] = ( - download_meps_example_reduced_dataset() - ) - datastore = init_datastore( datastore_kind=datastore_kind, config_path=DATASTORES_EXAMPLES[datastore_kind], diff --git a/tests/dummy_datastore.py b/tests/dummy_datastore.py index 3a844d6d9..b29e4c217 100644 --- a/tests/dummy_datastore.py +++ b/tests/dummy_datastore.py @@ -14,7 +14,6 @@ # First-party from neural_lam.datastore.base import ( - BaseDatastore, BaseRegularGridDatastore, CartesianGridShape, ) @@ -31,7 +30,7 @@ class DummyDatastore(BaseRegularGridDatastore): SHORT_NAME = "dummydata" T0 = isodate.parse_datetime("2021-01-01T00:00:00") N_FEATURES = dict(state=5, forcing=2, static=1) - spatial_coordinates = ("x", "y") + CARTESIAN_COORDS = ["x", "y"] # center the domain on Denmark latlon_center = [56, 10] # latitude, longitude @@ -166,9 +165,7 @@ def __init__( ) # Stack the spatial dimensions into grid_index - self.ds = self.ds.stack(grid_index=self.spatial_coordinates) - self.is_ensemble = "ensemble_member" in self.ds["state"].dims - self.has_ensemble_forcing = "ensemble_member" in self.ds["forcing"].dims + self.ds = self.ds.stack(grid_index=self.CARTESIAN_COORDS) # Create 
temporary directory for storing derived files self._tempdir = tempfile.TemporaryDirectory() @@ -337,11 +334,8 @@ def get_dataarray( elapsed_forecast_duration)` dimensions if `is_forecast` is True, or `(time)` if `is_forecast` is False. - If we have multiple ensemble members of state data, the returned state - dataarray is expected to have an additional `ensemble_member` - dimension. If `has_ensemble_forcing=True`, the returned forcing - dataarray is expected to have an additional `ensemble_member` - dimension. + If the data is ensemble data, the dataarray is expected to have an + additional `ensemble_member` dimension. Parameters ---------- @@ -359,6 +353,7 @@ def get_dataarray( """ dim_order = self.expected_dim_order(category=category) + da_category = self.ds[category].transpose(*dim_order) if standardize: @@ -414,7 +409,7 @@ def get_xy(self, category: str, stacked: bool) -> ndarray: da_xy = xr.concat([da_x, da_y], dim="grid_coord") if stacked: - da_xy = da_xy.stack(grid_index=self.spatial_coordinates).transpose( + da_xy = da_xy.stack(grid_index=self.CARTESIAN_COORDS).transpose( "grid_index", "grid_coord", ) @@ -471,290 +466,3 @@ def grid_shape_state(self) -> CartesianGridShape: n_points_1d = int(np.sqrt(self.num_grid_points)) return CartesianGridShape(x=n_points_1d, y=n_points_1d) - - -class EnsembleDummyDatastore(BaseDatastore): - """Small offline datastore for ensemble WeatherDataset tests. - - Generates synthetic ensemble data (state + forcing) for both analysis - and forecast modes, with configurable ensemble-member count. Values - are deterministic functions of the axis indices so that tests can - assert exact numeric expectations. 
- """ - - T0 = np.datetime64("2021-01-01T00:00:00") - - def __init__( - self, - *, - is_forecast: bool = False, - forcing_has_ensemble: bool = False, - n_ensemble_members: int = 3, - n_timesteps: int = 10, - n_analysis_times: int = 4, - n_forecast_steps: int = 6, - ): - self.is_forecast = is_forecast - self._forcing_has_ensemble = forcing_has_ensemble - self._step_length = timedelta(hours=1) - self._root_path = Path(".") - - self._state_feature = np.array(["state_feat_0"], dtype=object) - self._forcing_feature = np.array(["forcing_feat_0"], dtype=object) - self._grid_index = np.array([0], dtype=int) - self._ensemble_member = np.arange(n_ensemble_members, dtype=int) - - step_ns = np.timedelta64(int(self._step_length.total_seconds()), "s") - - if is_forecast: - self._init_forecast_data( - n_analysis_times, - n_forecast_steps, - n_ensemble_members, - step_ns, - forcing_has_ensemble, - ) - else: - self._init_analysis_data( - n_timesteps, - n_ensemble_members, - step_ns, - forcing_has_ensemble, - ) - self.is_ensemble = "ensemble_member" in self._da_state.dims - self.has_ensemble_forcing = "ensemble_member" in self._da_forcing.dims - - # ---- data initialisation helpers ---------------------------------------- - - def _init_forecast_data( - self, - n_analysis_times, - n_forecast_steps, - n_ensemble_members, - step_ns, - forcing_has_ensemble, - ): - analysis_time = ( - self.T0 + np.arange(n_analysis_times) * step_ns - ).astype("datetime64[ns]") - elapsed = (np.arange(n_forecast_steps) * step_ns).astype( - "timedelta64[ns]" - ) - - analysis_axis = np.arange(n_analysis_times).reshape(-1, 1, 1, 1, 1) - forecast_axis = np.arange(n_forecast_steps).reshape(1, -1, 1, 1, 1) - ensemble_axis = np.arange(n_ensemble_members).reshape(1, 1, -1, 1, 1) - - state_values = ( - analysis_axis * 1000 + forecast_axis * 10 + ensemble_axis - ).astype(np.float32) - self._da_state = xr.DataArray( - state_values, - dims=( - "analysis_time", - "elapsed_forecast_duration", - "ensemble_member", - 
"grid_index", - "state_feature", - ), - coords={ - "analysis_time": analysis_time, - "elapsed_forecast_duration": elapsed, - "ensemble_member": self._ensemble_member, - "grid_index": self._grid_index, - "state_feature": self._state_feature, - }, - ) - - if forcing_has_ensemble: - forcing_values = ( - 10000 - + analysis_axis * 1000 - + forecast_axis * 10 - + ensemble_axis - ).astype(np.float32) - self._da_forcing = xr.DataArray( - forcing_values, - dims=( - "analysis_time", - "elapsed_forecast_duration", - "ensemble_member", - "grid_index", - "forcing_feature", - ), - coords={ - "analysis_time": analysis_time, - "elapsed_forecast_duration": elapsed, - "ensemble_member": self._ensemble_member, - "grid_index": self._grid_index, - "forcing_feature": self._forcing_feature, - }, - ) - else: - analysis_axis_ne = np.arange(n_analysis_times).reshape(-1, 1, 1, 1) - forecast_axis_ne = np.arange(n_forecast_steps).reshape(1, -1, 1, 1) - forcing_values = ( - 20000 + analysis_axis_ne * 1000 + forecast_axis_ne * 10 - ).astype(np.float32) - self._da_forcing = xr.DataArray( - forcing_values, - dims=( - "analysis_time", - "elapsed_forecast_duration", - "grid_index", - "forcing_feature", - ), - coords={ - "analysis_time": analysis_time, - "elapsed_forecast_duration": elapsed, - "grid_index": self._grid_index, - "forcing_feature": self._forcing_feature, - }, - ) - - def _init_analysis_data( - self, - n_timesteps, - n_ensemble_members, - step_ns, - forcing_has_ensemble, - ): - time = (self.T0 + np.arange(n_timesteps) * step_ns).astype( - "datetime64[ns]" - ) - time_axis = np.arange(n_timesteps).reshape(-1, 1, 1, 1) - ensemble_axis = np.arange(n_ensemble_members).reshape(1, -1, 1, 1) - - state_values = (time_axis * 100 + ensemble_axis).astype(np.float32) - self._da_state = xr.DataArray( - state_values, - dims=("time", "ensemble_member", "grid_index", "state_feature"), - coords={ - "time": time, - "ensemble_member": self._ensemble_member, - "grid_index": self._grid_index, - 
"state_feature": self._state_feature, - }, - ) - - if forcing_has_ensemble: - forcing_values = (10000 + time_axis * 100 + ensemble_axis).astype( - np.float32 - ) - self._da_forcing = xr.DataArray( - forcing_values, - dims=( - "time", - "ensemble_member", - "grid_index", - "forcing_feature", - ), - coords={ - "time": time, - "ensemble_member": self._ensemble_member, - "grid_index": self._grid_index, - "forcing_feature": self._forcing_feature, - }, - ) - else: - time_axis_ne = np.arange(n_timesteps).reshape(-1, 1, 1) - forcing_values = (20000 + time_axis_ne * 100).astype(np.float32) - self._da_forcing = xr.DataArray( - forcing_values, - dims=("time", "grid_index", "forcing_feature"), - coords={ - "time": time, - "grid_index": self._grid_index, - "forcing_feature": self._forcing_feature, - }, - ) - - # ---- BaseDatastore interface -------------------------------------------- - - @property - def root_path(self) -> Path: - return self._root_path - - @property - def config(self) -> dict: - return {} - - @property - def step_length(self) -> timedelta: - return self._step_length - - def get_vars_units(self, category: str) -> list[str]: - return ["-"] - - def get_vars_names(self, category: str) -> list[str]: - if category == "state": - return self._state_feature.tolist() - if category == "forcing": - return self._forcing_feature.tolist() - if category == "static": - return ["static_feat_0"] - raise NotImplementedError(category) - - def get_vars_long_names(self, category: str) -> list[str]: - return self.get_vars_names(category=category) - - def get_num_data_vars(self, category: str) -> int: - return len(self.get_vars_names(category=category)) - - def get_standardization_dataarray(self, category: str) -> xr.Dataset: - ds = xr.Dataset() - feat_name = f"{category}_feature" - coords = {feat_name: self.get_vars_names(category=category)} - ds[f"{category}_mean"] = xr.DataArray( - [0.0], dims=[feat_name], coords=coords - ) - ds[f"{category}_std"] = xr.DataArray( - [1.0], 
dims=[feat_name], coords=coords - ) - if category == "state": - ds["state_diff_mean_standardized"] = xr.DataArray( - [0.0], - dims=["state_feature"], - coords={"state_feature": self._state_feature}, - ) - ds["state_diff_std_standardized"] = xr.DataArray( - [1.0], - dims=["state_feature"], - coords={"state_feature": self._state_feature}, - ) - return ds - - def get_dataarray( - self, category: str, split: Optional[str], standardize: bool = False - ) -> Union[xr.DataArray, None]: - if category == "state": - da = self._da_state - elif category == "forcing": - da = self._da_forcing - else: - return None - - if standardize: - return self._standardize_datarray(da=da, category=category) - return da - - @property - def boundary_mask(self) -> xr.DataArray: - return xr.DataArray( - [0], dims=("grid_index",), coords={"grid_index": [0]} - ) - - def get_xy(self, category: str, stacked: bool) -> np.ndarray: - return np.array([[0.0, 0.0]]) - - @property - def coords_projection(self): - return None - - @property - def num_grid_points(self) -> int: - return 1 - - @property - def state_feature_weights_values(self) -> list[float]: - return [1.0] diff --git a/tests/test_clamping.py b/tests/test_clamping.py index f3f9365d0..09a46ffdf 100644 --- a/tests/test_clamping.py +++ b/tests/test_clamping.py @@ -57,9 +57,16 @@ class ModelArgs: ) model = GraphLAM( - args=model_args, - datastore=datastore, config=config, + datastore=datastore, + graph_name=model_args.graph, + hidden_dim=model_args.hidden_dim, + hidden_layers=model_args.hidden_layers, + processor_layers=model_args.processor_layers, + mesh_aggr=model_args.mesh_aggr, + num_past_forcing_steps=model_args.num_past_forcing_steps, + num_future_forcing_steps=model_args.num_future_forcing_steps, + output_std=model_args.output_std, ) features = datastore.get_vars_names(category="state") diff --git a/tests/test_cli.py b/tests/test_cli.py index 1be4db149..0dbd04a11 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,10 +1,3 @@ -# 
Standard library -from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser -from unittest.mock import MagicMock, patch - -# Third-party -import pytest - # First-party import neural_lam import neural_lam.create_graph @@ -17,103 +10,3 @@ def test_import(): assert neural_lam is not None assert neural_lam.create_graph is not None assert neural_lam.train_model is not None - - -# --- Argument parsing tests --------------------------------------------------- - - -def _make_parser(): - """Minimal parser mirroring train_model's --config_path and --wandb_id.""" - parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) - parser.add_argument("--config_path", type=str, required=True) - parser.add_argument("--wandb_id", type=str, default=None) - return parser - - -@pytest.mark.parametrize( - "extra_args, expected_wandb_id", - [ - ([], None), - (["--wandb_id", "abc123xyz"], "abc123xyz"), - ], -) -def test_wandb_id_parsed(extra_args, expected_wandb_id): - """--wandb_id defaults to None and is parsed correctly when provided.""" - args = _make_parser().parse_args( - ["--config_path", "dummy.yaml"] + extra_args - ) - assert args.wandb_id == expected_wandb_id - - -def test_wandb_resume_not_exposed(): - """--wandb_resume must not exist as a CLI argument in train_model.""" - with pytest.raises(SystemExit): - _make_parser().parse_args( - ["--config_path", "dummy.yaml", "--wandb_resume", "allow"] - ) - - -# --- setup_training_logger tests ---------------------------------------------- - - -def _make_args(wandb_id=None): - args = MagicMock() - args.logger = "wandb" - args.logger_project = "neural_lam" - args.wandb_id = wandb_id - return args - - -@pytest.mark.parametrize( - "wandb_id, expected_resume, expected_id, expected_name", - [ - (None, None, None, "my-run"), - ("abc123", "allow", "abc123", None), - ], -) -@patch("neural_lam.utils.pl.loggers.WandbLogger") -def test_wandb_logger_kwargs( - mock_wandb, wandb_id, expected_resume, expected_id, expected_name -): - 
"""WandbLogger is called with the correct resume, id, and name kwargs.""" - # First-party - from neural_lam.utils import setup_training_logger - - args = _make_args(wandb_id=wandb_id) - datastore = MagicMock() - datastore._config = {} - - setup_training_logger(datastore, args, run_name="my-run") - - _, kwargs = mock_wandb.call_args - assert kwargs["resume"] == expected_resume - assert kwargs["id"] == expected_id - assert kwargs["name"] == expected_name - - -def test_wandb_id_ignored_with_mlflow_warns(): - """--wandb_id is ignored when logger=mlflow and a warning is emitted.""" - # First-party - from neural_lam.utils import setup_training_logger - - args = MagicMock() - args.logger = "mlflow" - args.logger_project = "neural_lam" - args.wandb_id = "abc123" - - datastore = MagicMock() - datastore._config = {} - - with ( - patch("neural_lam.utils.CustomMLFlowLogger"), - patch.dict( - "os.environ", {"MLFLOW_TRACKING_URI": "http://localhost:5000"} - ), - patch("neural_lam.utils.logger") as mock_log, - ): - setup_training_logger(datastore, args, run_name="my-run") - - mock_log.warning.assert_called_once() - warning_msg = mock_log.warning.call_args[0][0] - assert "--wandb_id is set but logger is" in warning_msg - assert "mlflow" in warning_msg diff --git a/tests/test_datasets.py b/tests/test_datasets.py index dd863b657..e4f3ad113 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -12,10 +12,10 @@ from neural_lam.create_graph import create_graph_from_datastore from neural_lam.datastore import DATASTORES from neural_lam.datastore.base import BaseRegularGridDatastore -from neural_lam.models.graph_lam import GraphLAM +from neural_lam.models.forecaster_module import ForecasterModule from neural_lam.weather_dataset import WeatherDataset from tests.conftest import init_datastore_example -from tests.dummy_datastore import DummyDatastore, EnsembleDummyDatastore +from tests.dummy_datastore import DummyDatastore @pytest.mark.parametrize("datastore_name", 
DATASTORES.keys()) @@ -182,6 +182,10 @@ class ModelArgs: hidden_layers = 1 processor_layers = 2 mesh_aggr = "sum" + lr = 1.0e-3 + val_steps_to_log = [1] + metrics_watch = [] + var_leads_metrics_watch = {} num_past_forcing_steps = 1 num_future_forcing_steps = 1 @@ -212,13 +216,42 @@ def _create_graph(): dataset = WeatherDataset(datastore=datastore, split=split, ar_steps=2) - model = GraphLAM(args=args, datastore=datastore, config=config) # noqa + # First-party + from neural_lam.models import MODELS + from neural_lam.models.ar_forecaster import ARForecaster + + predictor_class = MODELS["graph_lam"] + predictor = predictor_class( + config=config, + datastore=datastore, + graph_name=args.graph, + hidden_dim=args.hidden_dim, + hidden_layers=args.hidden_layers, + processor_layers=args.processor_layers, + mesh_aggr=args.mesh_aggr, + num_past_forcing_steps=args.num_past_forcing_steps, + num_future_forcing_steps=args.num_future_forcing_steps, + output_std=args.output_std, + ) + forecaster = ARForecaster(predictor, datastore=datastore) + + model = ForecasterModule( + forecaster=forecaster, + config=config, + datastore=datastore, + loss=args.loss, + restore_opt=args.restore_opt, + n_example_pred=args.n_example_pred, + val_steps_to_log=args.val_steps_to_log, + metrics_watch=args.metrics_watch, + var_leads_metrics_watch=args.var_leads_metrics_watch, + lr=args.lr, + ) model_device = model.to(device_name) data_loader = DataLoader(dataset, batch_size=2) batch = next(iter(data_loader)) batch_device = [part.to(device_name) for part in batch] - model_device.common_step(batch_device) model_device.training_step(batch_device) @@ -259,240 +292,3 @@ def test_dataset_length(dataset_config): # Check that we can actually get last and first sample dataset[0] dataset[expected_len - 1] - - -def test_ensemble_len_scales_with_default_all_members(): - datastore = EnsembleDummyDatastore( - is_forecast=False, - forcing_has_ensemble=False, - n_ensemble_members=3, - n_timesteps=10, - ) - - dataset_all 
= WeatherDataset( - datastore=datastore, - split="train", - ar_steps=2, - num_past_forcing_steps=1, - num_future_forcing_steps=1, - standardize=False, - ) - - dataset_single = WeatherDataset( - datastore=datastore, - split="train", - ar_steps=2, - num_past_forcing_steps=1, - num_future_forcing_steps=1, - load_single_member=True, - standardize=False, - ) - - assert len(dataset_all) == len(dataset_single) * 3 - - -def test_expected_dim_order_handles_optional_ensemble_forcing(): - datastore_with_ensemble_forcing = EnsembleDummyDatastore( - is_forecast=False, - forcing_has_ensemble=True, - n_ensemble_members=3, - n_timesteps=10, - ) - - datastore_without_ensemble_forcing = EnsembleDummyDatastore( - is_forecast=False, - forcing_has_ensemble=False, - n_ensemble_members=3, - n_timesteps=10, - ) - - assert datastore_with_ensemble_forcing.is_ensemble is True - assert datastore_with_ensemble_forcing.has_ensemble_forcing is True - assert datastore_without_ensemble_forcing.is_ensemble is True - assert datastore_without_ensemble_forcing.has_ensemble_forcing is False - - assert datastore_with_ensemble_forcing.expected_dim_order( - category="forcing" - ) == ("time", "ensemble_member", "grid_index", "forcing_feature") - assert datastore_without_ensemble_forcing.expected_dim_order( - category="forcing" - ) == ("time", "grid_index", "forcing_feature") - assert datastore_with_ensemble_forcing.expected_dim_order( - category="static" - ) == ("grid_index", "static_feature") - - -def test_ensemble_index_mapping_is_time_major(): - datastore = EnsembleDummyDatastore( - is_forecast=False, - forcing_has_ensemble=False, - n_ensemble_members=3, - n_timesteps=10, - ) - dataset = WeatherDataset( - datastore=datastore, - split="train", - ar_steps=2, - num_past_forcing_steps=1, - num_future_forcing_steps=1, - load_single_member=False, - standardize=False, - ) - - init_states_0, _, _, target_times_0 = dataset[0] - init_states_1, _, _, target_times_1 = dataset[1] - - # Adjacent flat indices 
correspond to same sample_idx and different member. - assert torch.equal(target_times_0, target_times_1) - assert not torch.equal(init_states_0, init_states_1) - - -def test_ensemble_forcing_uses_same_member_when_available(): - datastore = EnsembleDummyDatastore( - is_forecast=False, - forcing_has_ensemble=True, - n_ensemble_members=3, - n_timesteps=10, - ) - dataset = WeatherDataset( - datastore=datastore, - split="train", - ar_steps=2, - num_past_forcing_steps=1, - num_future_forcing_steps=1, - load_single_member=False, - standardize=False, - ) - - _, _, forcing_0, target_times_0 = dataset[0] - _, _, forcing_1, target_times_1 = dataset[1] - - assert torch.equal(target_times_0, target_times_1) - assert not torch.equal(forcing_0, forcing_1) - - -def test_ensemble_forcing_without_member_dim_is_shared(): - datastore = EnsembleDummyDatastore( - is_forecast=False, - forcing_has_ensemble=False, - n_ensemble_members=3, - n_timesteps=10, - ) - dataset = WeatherDataset( - datastore=datastore, - split="train", - ar_steps=2, - num_past_forcing_steps=1, - num_future_forcing_steps=1, - load_single_member=False, - standardize=False, - ) - - init_states_0, _, forcing_0, target_times_0 = dataset[0] - init_states_1, _, forcing_1, target_times_1 = dataset[1] - - assert torch.equal(target_times_0, target_times_1) - assert not torch.equal(init_states_0, init_states_1) - assert torch.equal(forcing_0, forcing_1) - - -def test_forecast_ensemble_len_scales_with_default_all_members(): - datastore = EnsembleDummyDatastore( - is_forecast=True, - forcing_has_ensemble=True, - n_ensemble_members=3, - n_analysis_times=4, - n_forecast_steps=6, - ) - - dataset_all = WeatherDataset( - datastore=datastore, - split="train", - ar_steps=2, - num_past_forcing_steps=1, - num_future_forcing_steps=1, - standardize=False, - ) - - with pytest.warns(UserWarning, match="only using first ensemble member"): - dataset_single = WeatherDataset( - datastore=datastore, - split="train", - ar_steps=2, - 
num_past_forcing_steps=1, - num_future_forcing_steps=1, - load_single_member=True, - standardize=False, - ) - - assert len(dataset_all) == len(dataset_single) * 3 - - -def test_standardization_with_zero_std(): - """Regression test for https://github.com/mllam/neural-lam/issues/136 - - When all values of a field are identical (std = 0), WeatherDataset - must not produce NaN via division-by-zero during standardization. - """ - # Third-party - import xarray as xr - - std_da = xr.DataArray( - np.array([0.0, 1.0, 2.0], dtype=np.float32), dims=["feature"] - ) - - dataset = WeatherDataset.__new__(WeatherDataset) - result = dataset._compute_std_safe(std_da, "state") - - eps = np.finfo(std_da.dtype).eps - - assert ( - float(result[0]) == eps - ), "Zero std was not clamped to machine epsilon" - assert float(result[1]) == 1.0 - assert float(result[2]) == 2.0 - assert not np.isnan( - result.values - ).any(), "NaN found after _compute_std_safe" - - -def test_weather_dataset_no_forcing_standardize(): - """Regression test: WeatherDataset must not raise AttributeError when the - datastore has no forcing data and standardize=True (the default). 
- - Before the fix, self.da_forcing_std was accessed at line 123 of - weather_dataset.py without ever being assigned when da_forcing is None, - causing: - AttributeError: 'WeatherDataset' object has no attribute - 'da_forcing_std' - """ - - class NoForcingDatastore(DummyDatastore): - """DummyDatastore that returns None for the forcing category.""" - - def get_dataarray(self, category, split, **kwargs): - if category == "forcing": - return None - return super().get_dataarray( - category=category, split=split, **kwargs - ) - - datastore = NoForcingDatastore(n_grid_points=100, n_timesteps=20) - - # Should not raise AttributeError - dataset = WeatherDataset( - datastore=datastore, - split="train", - ar_steps=3, - standardize=True, - ) - - assert dataset.forcing_std_safe is None - assert dataset.da_forcing_mean is None - assert dataset.da_forcing_std is None - - # Ensure we can still retrieve a sample (forcing tensor should be empty) - init_states, target_states, forcing, target_times = dataset[0] - assert ( - forcing.shape[-1] == 0 - ), "Expected zero forcing features when forcing is None" diff --git a/tests/test_datastores.py b/tests/test_datastores.py index 4ce3875ea..5d85c4d73 100644 --- a/tests/test_datastores.py +++ b/tests/test_datastores.py @@ -198,10 +198,7 @@ def test_get_dataarray(datastore_name): ] if datastore.is_ensemble and category == "state": - expected_dims.append("ensemble_member") - elif category == "forcing" and getattr( - datastore, "has_ensemble_forcing", False - ): + # assume that only state variables change with ensemble members expected_dims.append("ensemble_member") # XXX: for now we only have a single attribute to get the shape of @@ -292,35 +289,6 @@ def test_get_xy(datastore_name): assert xy_unstacked.shape[1] == ny assert xy_unstacked.shape[2] == 2 - @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) - def test_get_lat_lon(datastore_name): - """Check that the `datastore.get_lat_lon` helper returns valid - shapes and values.""" - 
datastore = init_datastore_example(datastore_name) - - if not isinstance(datastore, BaseRegularGridDatastore): - pytest.skip( - "Datastore does not implement `BaseCartesianDatastore`" - ) - - nx, ny = datastore.grid_shape_state.x, datastore.grid_shape_state.y - - lonlat = datastore.get_lat_lon(category="state") - lonlat_grid = lonlat.reshape(nx, ny, 2) - - assert lonlat.shape == (nx * ny, 2) - assert lonlat_grid.shape == (nx, ny, 2) - assert np.isfinite(lonlat).all() - assert np.isfinite(lonlat_grid).all() - - lon = lonlat[:, 0] - lat = lonlat[:, 1] - - assert np.all((lat >= -90.0) & (lat <= 90.0)) - lon_in_180 = (lon >= -180.0) & (lon <= 180.0) - lon_in_360 = (lon >= 0.0) & (lon <= 360.0) - assert np.all(lon_in_180 | lon_in_360) - @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) def test_get_projection(datastore_name): diff --git a/tests/test_graph_creation.py b/tests/test_graph_creation.py index 93a7a55f4..db0597aec 100644 --- a/tests/test_graph_creation.py +++ b/tests/test_graph_creation.py @@ -84,7 +84,7 @@ def test_graph_creation(datastore_name, graph_name): # try to load each and ensure they have the right shape for file_name in required_graph_files: file_id = Path(file_name).stem # remove the extension - result = torch.load(graph_dir_path / file_name) + result = torch.load(graph_dir_path / file_name, weights_only=True) if file_id.startswith("g2m") or file_id.startswith("m2g"): assert isinstance(result, torch.Tensor) diff --git a/tests/test_plotting.py b/tests/test_plotting.py index 84b35a308..c4549566e 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -1,25 +1,19 @@ # Standard library from datetime import timedelta from pathlib import Path -from typing import Iterator -from unittest.mock import patch # Third-party -import matplotlib.figure import matplotlib.pyplot as plt import numpy as np import pytest import torch -import xarray as xr -from cartopy import crs as ccrs # First-party from neural_lam import config as nlconfig 
from neural_lam import vis from neural_lam.create_graph import create_graph_from_datastore -from neural_lam.models.graph_lam import GraphLAM +from neural_lam.models.forecaster_module import ForecasterModule from neural_lam.weather_dataset import WeatherDataset -from tests.conftest import init_datastore_example from tests.dummy_datastore import DummyDatastore # Create output directory for test figures @@ -27,171 +21,6 @@ TEST_OUTPUT_DIR.mkdir(parents=True, exist_ok=True) -@pytest.fixture(autouse=True) -def mock_cartopy_downloads(monkeypatch: pytest.MonkeyPatch) -> None: - """ - Prevent cartopy from downloading Natural Earth map data during tests. - Monkeypatches the GeoAxes methods used in vis.plot_on_axis. - """ - # Third-party - from cartopy.mpl.geoaxes import GeoAxes - - monkeypatch.setattr(GeoAxes, "coastlines", lambda *args, **kwargs: None) - monkeypatch.setattr(GeoAxes, "add_feature", lambda *args, **kwargs: None) - - -@pytest.fixture(autouse=True) -def close_all_figures_after_test() -> Iterator[None]: - """Ensure test-created matplotlib figures are always cleaned up.""" - yield - plt.close("all") - - -def test_plot_prediction() -> None: - """Check prediction plot structure, titles and shared color scaling.""" - datastore = init_datastore_example("dummydata") - n_grid = datastore.num_grid_points - - da_pred = xr.DataArray(np.linspace(0.0, 1.0, n_grid)) - da_target = xr.DataArray(np.linspace(1.0, 2.0, n_grid)) - - expected_vmin = float(np.nanmin([da_pred.values, da_target.values])) - expected_vmax = float(np.nanmax([da_pred.values, da_target.values])) - - fig = vis.plot_prediction( - datastore=datastore, - da_prediction=da_pred, - da_target=da_target, - title="Test Prediction", - vrange=(expected_vmin, expected_vmax), - boundary_alpha=None, - crop_to_interior=False, - ) - - assert isinstance(fig, matplotlib.figure.Figure) - assert len(fig.axes) == 3 - - ground_truth_ax, prediction_ax, _ = fig.axes - assert ground_truth_ax.get_title() == "Ground Truth" - assert 
prediction_ax.get_title() == "Prediction" - assert fig._suptitle.get_text() == "Test Prediction" - - assert len(ground_truth_ax.collections) == 1 - assert len(prediction_ax.collections) == 1 - - assert ground_truth_ax.collections[0].norm.vmin == expected_vmin - assert ground_truth_ax.collections[0].norm.vmax == expected_vmax - assert prediction_ax.collections[0].norm.vmin == expected_vmin - assert prediction_ax.collections[0].norm.vmax == expected_vmax - - -def test_plot_error_map() -> None: - """Check error heatmap content, labels and annotations.""" - datastore = init_datastore_example("dummydata") - d_f = len(datastore.get_vars_names(category="state")) - pred_steps = 4 - - errors = torch.arange(1, pred_steps * d_f + 1, dtype=torch.float32).reshape( - pred_steps, d_f - ) - - fig = vis.plot_error_map( - errors=errors, - datastore=datastore, - title="Test Error Map", - ) - - assert isinstance(fig, matplotlib.figure.Figure) - assert len(fig.axes) == 1 - - ax = fig.axes[0] - assert len(ax.images) == 1 - assert ax.images[0].get_array().shape == (d_f, pred_steps) - assert ax.get_xlabel() == "Lead time (h)" - assert ax.get_title() == "Test Error Map" - - expected_x_ticklabels = [str(step) for step in range(1, pred_steps + 1)] - actual_x_ticklabels = [tick.get_text() for tick in ax.get_xticklabels()] - assert actual_x_ticklabels == expected_x_ticklabels - - var_names = datastore.get_vars_names(category="state") - var_units = datastore.get_vars_units(category="state") - expected_y_ticklabels = [ - f"{name} ({unit})" for name, unit in zip(var_names, var_units) - ] - actual_y_ticklabels = [tick.get_text() for tick in ax.get_yticklabels()] - assert actual_y_ticklabels == expected_y_ticklabels - - assert len(ax.texts) == pred_steps * d_f - - -def test_plot_spatial_error() -> None: - """Check that plot_spatial_error runs without error and returns a Figure.""" - datastore = init_datastore_example("dummydata") - n_grid = datastore.num_grid_points - - error = torch.linspace(0.0, 
1.0, n_grid) - - fig = vis.plot_spatial_error( - error=error, - datastore=datastore, - title="Test Spatial Error", - boundary_alpha=None, - crop_to_interior=False, - ) - - assert isinstance(fig, matplotlib.figure.Figure) - # GeoAxes + colorbar axes - assert len(fig.axes) == 2 - assert fig.texts[0].get_text() == "Test Spatial Error" - - -def test_plot_spatial_error_crop_to_interior_changes_extent() -> None: - """Check interior cropping forwards interior lon/lat bounds to - set_extent.""" - datastore = init_datastore_example("dummydata") - n_grid = datastore.num_grid_points - grid_shape = (datastore.grid_shape_state.x, datastore.grid_shape_state.y) - - boundary_mask = np.ones(grid_shape, dtype=int) - boundary_mask[2:-2, 2:-2] = 0 - datastore.ds["boundary_mask"] = xr.DataArray( - boundary_mask.reshape(n_grid), dims=["grid_index"] - ) - datastore.__dict__.pop("boundary_mask", None) - - lats_lons = datastore.get_lat_lon("state") - lons = lats_lons[:, 0].reshape(grid_shape) - lats = lats_lons[:, 1].reshape(grid_shape) - interior = boundary_mask == 0 - - expected_min_lon = float(lons[interior].min()) - expected_max_lon = float(lons[interior].max()) - expected_min_lat = float(lats[interior].min()) - expected_max_lat = float(lats[interior].max()) - - error = torch.linspace(0.0, 1.0, n_grid) - with patch( - "cartopy.mpl.geoaxes.GeoAxes.set_extent", autospec=True - ) as set_extent_mock: - vis.plot_spatial_error( - error=error, - datastore=datastore, - boundary_alpha=None, - crop_to_interior=True, - ) - - assert set_extent_mock.call_count == 1 - called_extent = set_extent_mock.call_args.args[1] - called_crs = set_extent_mock.call_args.kwargs["crs"] - - assert called_extent[0] == pytest.approx(expected_min_lon) - assert called_extent[1] == pytest.approx(expected_max_lon) - assert called_extent[2] == pytest.approx(expected_min_lat) - assert called_extent[3] == pytest.approx(expected_max_lat) - assert isinstance(called_crs, ccrs.PlateCarree) - - @pytest.fixture def 
model_and_batch(tmp_path, time_step, time_unit): """Setup a model and dataset for testing plot_examples""" @@ -237,10 +66,37 @@ class ModelArgs: ) # Create model - model = GraphLAM( - args=ModelArgs(), + # First-party + from neural_lam.models import MODELS + from neural_lam.models.ar_forecaster import ARForecaster + + args = ModelArgs() + predictor_class = MODELS["graph_lam"] + predictor = predictor_class( config=config, datastore=datastore, + graph_name=args.graph, + hidden_dim=args.hidden_dim, + hidden_layers=args.hidden_layers, + processor_layers=args.processor_layers, + mesh_aggr=args.mesh_aggr, + num_past_forcing_steps=args.num_past_forcing_steps, + num_future_forcing_steps=args.num_future_forcing_steps, + output_std=args.output_std, + ) + forecaster = ARForecaster(predictor, datastore=datastore) + + model = ForecasterModule( + forecaster=forecaster, + config=config, + datastore=datastore, + loss=args.loss, + restore_opt=args.restore_opt, + n_example_pred=args.n_example_pred, + val_steps_to_log=args.val_steps_to_log, + metrics_watch=args.metrics_watch, + var_leads_metrics_watch=args.var_leads_metrics_watch, + lr=args.lr, ) # Create dataset to get a sample batch @@ -287,11 +143,26 @@ def test_plot_examples_integration_saves_figure( ), f"Expected time_step_unit={time_unit}, got {model.time_step_unit}" # Generate prediction - prediction, target, _, _ = model.common_step(batch) + (init_states, target, forcing_features, _batch_times) = batch + prediction, _ = model.forecaster( + init_states, forcing_features, target + ) # Rescale to original data scale - prediction_rescaled = prediction * model.state_std + model.state_mean - target_rescaled = target * model.state_std + model.state_mean + da_state_stats = datastore.get_standardization_dataarray("state") + state_std = torch.tensor( + da_state_stats.state_std.values, + dtype=torch.float32, + device=prediction.device, + ) + state_mean = torch.tensor( + da_state_stats.state_mean.values, + dtype=torch.float32, + 
device=prediction.device, + ) + + prediction_rescaled = prediction * state_std + state_mean + target_rescaled = target * state_std + state_mean # Get first example pred_slice = prediction_rescaled[0].detach() # Detach from graph @@ -336,8 +207,8 @@ def test_plot_examples_integration_saves_figure( fig = vis.plot_prediction( datastore=datastore, - title=f"{var_names[0]}, t={t_i} ({time_step * t_i} {time_unit})", - colorbar_label=var_units[0], + title=f"{var_names[0]} ({var_units[0]}), t={t_i}" + f"({(time_step * t_i)} {time_unit})", vrange=var_vranges[0], da_prediction=da_prediction.isel( state_feature=0, time=t_i - 1 diff --git a/tests/test_prediction_model_classes.py b/tests/test_prediction_model_classes.py new file mode 100644 index 000000000..4e7860e7a --- /dev/null +++ b/tests/test_prediction_model_classes.py @@ -0,0 +1,328 @@ +# Third-party +import pytorch_lightning as pl +import torch + +# First-party +from neural_lam import config as nlconfig +from neural_lam.models.ar_forecaster import ARForecaster +from neural_lam.models.forecaster_module import ForecasterModule +from neural_lam.models.step_predictor import StepPredictor +from tests.conftest import init_datastore_example +from tests.dummy_datastore import DummyDatastore + + +class NoStaticDummyDatastore(DummyDatastore): + """DummyDatastore variant that returns None for static features.""" + + def get_dataarray(self, category, split, standardize=False): + if category == "static": + return None + return super().get_dataarray(category, split, standardize=standardize) + + +class MockStepPredictor(StepPredictor): + def __init__(self, config, datastore, **kwargs): + super().__init__(config, datastore, **kwargs) + + def forward(self, prev_state, prev_prev_state, forcing): + # Return zeros for state + # The true state will be mixed in at boundaries + pred_state = torch.zeros_like(prev_state) + pred_std = torch.zeros_like(prev_state) if self.output_std else None + return pred_state, pred_std + + +def 
test_ar_forecaster_unroll(): + datastore = init_datastore_example("mdp") + config = nlconfig.NeuralLAMConfig( + datastore=nlconfig.DatastoreSelection( + kind=datastore.SHORT_NAME, config_path=datastore.root_path + ) + ) + predictor = MockStepPredictor( + config=config, + datastore=datastore, + output_std=False, + ) + + forecaster = ARForecaster(predictor, datastore) + + # Override masks to test boundary masking behaviour + forecaster.interior_mask = torch.zeros_like(forecaster.interior_mask) + forecaster.interior_mask[0, 0] = 1 # One node is interior + forecaster.boundary_mask = 1 - forecaster.interior_mask + + B, num_grid_nodes = 2, predictor.num_grid_nodes + d_state = datastore.get_num_data_vars(category="state") + num_past_forcing_steps = 1 + num_future_forcing_steps = 1 + d_forcing = datastore.get_num_data_vars(category="forcing") * ( + num_past_forcing_steps + num_future_forcing_steps + 1 + ) + pred_steps = 3 + init_states = torch.ones(B, 2, num_grid_nodes, d_state) + forcing_features = torch.ones(B, pred_steps, num_grid_nodes, d_forcing) + true_states = torch.ones(B, pred_steps, num_grid_nodes, d_state) * 5.0 + + prediction, pred_std = forecaster( + init_states, forcing_features, true_states + ) + + assert prediction.shape == (B, pred_steps, num_grid_nodes, d_state) + + # Boundary (where interior_mask == 0) should equal true_state (5.0) + # Interior (where interior_mask == 1) should equal predictor output (0.0) + assert torch.all(prediction[:, :, 0, :] == 0.0) + assert torch.all(prediction[:, :, 1:, :] == 5.0) + + +def test_forecaster_module_checkpoint(tmp_path): + datastore = init_datastore_example("mdp") + + config = nlconfig.NeuralLAMConfig( + datastore=nlconfig.DatastoreSelection( + kind=datastore.SHORT_NAME, config_path=datastore.root_path + ) + ) + + # Build predictor and forecaster externally, then inject into + # ForecasterModule + # First-party + from neural_lam.models import MODELS + + predictor_class = MODELS["graph_lam"] + predictor = 
predictor_class( + config=config, + datastore=datastore, + graph_name="1level", + hidden_dim=4, + hidden_layers=1, + processor_layers=1, + mesh_aggr="sum", + num_past_forcing_steps=1, + num_future_forcing_steps=1, + output_std=False, + ) + forecaster = ARForecaster(predictor, datastore) + + model = ForecasterModule( + forecaster=forecaster, + config=config, + datastore=datastore, + loss="mse", + lr=1e-3, + restore_opt=False, + n_example_pred=1, + val_steps_to_log=[1], + metrics_watch=[], + ) + + ckpt_path = tmp_path / "test.ckpt" + trainer = pl.Trainer( + max_epochs=1, + accelerator="cpu", + logger=False, + enable_checkpointing=False, + ) + trainer.strategy.connect(model) + trainer.save_checkpoint(ckpt_path) + + # Build a fresh forecaster structure for loading weights into + load_predictor = predictor_class( + config=config, + datastore=datastore, + graph_name="1level", + hidden_dim=4, + hidden_layers=1, + processor_layers=1, + mesh_aggr="sum", + num_past_forcing_steps=1, + num_future_forcing_steps=1, + output_std=False, + ) + load_forecaster = ARForecaster(load_predictor, datastore) + + # Load from checkpoint + loaded_model = ForecasterModule.load_from_checkpoint( + ckpt_path, + datastore=datastore, + forecaster=load_forecaster, + weights_only=False, + ) + + # Validate the correct internal hierarchy has been constructed + assert loaded_model.forecaster.predictor.__class__.__name__ == "GraphLAM" + + # Verify that outputs match (checkpoint successfully restored weights) + B, num_grid_nodes = 2, model.forecaster.predictor.num_grid_nodes + d_state = datastore.get_num_data_vars(category="state") + num_past_forcing_steps = 1 + num_future_forcing_steps = 1 + d_forcing = datastore.get_num_data_vars(category="forcing") * ( + num_past_forcing_steps + num_future_forcing_steps + 1 + ) + init_states = torch.ones(B, 2, num_grid_nodes, d_state) + forcing_features = torch.ones(B, 1, num_grid_nodes, d_forcing) + boundary_states = torch.ones(B, 1, num_grid_nodes, d_state) * 5.0 + + 
with torch.no_grad(): + out_before = model.forecaster( + init_states, forcing_features, boundary_states + ) + out_after = loaded_model.forecaster( + init_states, forcing_features, boundary_states + ) + + assert torch.allclose(out_before[0], out_after[0]) + + +def test_forecaster_module_old_checkpoint(tmp_path): + datastore = init_datastore_example("mdp") + + config = nlconfig.NeuralLAMConfig( + datastore=nlconfig.DatastoreSelection( + kind=datastore.SHORT_NAME, config_path=datastore.root_path + ) + ) + + # First-party + from neural_lam.models import MODELS + + predictor_class = MODELS["graph_lam"] + predictor = predictor_class( + config=config, + datastore=datastore, + graph_name="1level", + hidden_dim=4, + hidden_layers=1, + processor_layers=1, + mesh_aggr="sum", + num_past_forcing_steps=1, + num_future_forcing_steps=1, + output_std=False, + ) + forecaster = ARForecaster(predictor, datastore) + + model = ForecasterModule( + forecaster=forecaster, + config=config, + datastore=datastore, + loss="mse", + lr=1e-3, + restore_opt=False, + n_example_pred=1, + val_steps_to_log=[1], + metrics_watch=[], + ) + + ckpt_path = tmp_path / "test_old.ckpt" + trainer = pl.Trainer( + max_epochs=1, + accelerator="cpu", + logger=False, + enable_checkpointing=False, + ) + trainer.strategy.connect(model) + trainer.save_checkpoint(ckpt_path) + + # Manually hack the checkpoint to emulate a pre-refactor state dict + ckpt = torch.load(ckpt_path, weights_only=False) + old_state_dict = {} + for k, v in ckpt["state_dict"].items(): + if k.startswith("forecaster.predictor."): + # Revert structural rename to emulate old flat keys + new_k = k.replace("forecaster.predictor.", "") + if "encoding_grid_mlp" in new_k: + new_k = new_k.replace("encoding_grid_mlp", "g2m_gnn.grid_mlp") + old_state_dict[new_k] = v + else: + old_state_dict[k] = v + + ckpt["state_dict"] = old_state_dict + torch.save(ckpt, ckpt_path) + + # Build a fresh forecaster structure for loading weights into + load_predictor = 
predictor_class( + config=config, + datastore=datastore, + graph_name="1level", + hidden_dim=4, + hidden_layers=1, + processor_layers=1, + mesh_aggr="sum", + num_past_forcing_steps=1, + num_future_forcing_steps=1, + output_std=False, + ) + load_forecaster = ARForecaster(load_predictor, datastore) + + # Load from hacked old checkpoint + loaded_model = ForecasterModule.load_from_checkpoint( + ckpt_path, + datastore=datastore, + forecaster=load_forecaster, + weights_only=False, + ) + + # Validate the correct internal hierarchy has been constructed + assert loaded_model.forecaster.predictor.__class__.__name__ == "GraphLAM" + + # Verify that outputs match (checkpoint successfully restored weights) + B, num_grid_nodes = 2, model.forecaster.predictor.num_grid_nodes + d_state = datastore.get_num_data_vars(category="state") + num_past_forcing_steps = 1 + num_future_forcing_steps = 1 + d_forcing = datastore.get_num_data_vars(category="forcing") * ( + num_past_forcing_steps + num_future_forcing_steps + 1 + ) + init_states = torch.ones(B, 2, num_grid_nodes, d_state) + forcing_features = torch.ones(B, 1, num_grid_nodes, d_forcing) + boundary_states = torch.ones(B, 1, num_grid_nodes, d_state) * 5.0 + + with torch.no_grad(): + out_before = model.forecaster( + init_states, forcing_features, boundary_states + ) + out_after = loaded_model.forecaster( + init_states, forcing_features, boundary_states + ) + + assert torch.allclose(out_before[0], out_after[0]) + + +def test_step_predictor_no_static_features(): + """Model should run correctly when the datastore has no static features, + using an empty (N, 0) tensor in place of static features.""" + datastore = NoStaticDummyDatastore() + config = nlconfig.NeuralLAMConfig( + datastore=nlconfig.DatastoreSelection( + kind=datastore.SHORT_NAME, config_path=datastore.root_path + ) + ) + + predictor = MockStepPredictor( + config=config, + datastore=datastore, + output_std=False, + ) + + # Static features buffer should exist but be empty (zero 
width) + assert predictor.grid_static_features.shape == ( + datastore.num_grid_points, + 0, + ) + + # Verify a forward pass works end-to-end via ARForecaster + forecaster = ARForecaster(predictor, datastore) + B, num_grid_nodes = 2, predictor.num_grid_nodes + d_state = datastore.get_num_data_vars(category="state") + d_forcing = datastore.get_num_data_vars(category="forcing") + init_states = torch.zeros(B, 2, num_grid_nodes, d_state) + forcing_features = torch.zeros(B, 1, num_grid_nodes, d_forcing) + boundary_states = torch.zeros(B, 1, num_grid_nodes, d_state) + + prediction, pred_std = forecaster( + init_states, forcing_features, boundary_states + ) + assert prediction.shape == (B, 1, num_grid_nodes, d_state) + assert pred_std is None diff --git a/tests/test_training.py b/tests/test_training.py index 972740695..bfab37e75 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -12,11 +12,19 @@ from neural_lam.create_graph import create_graph_from_datastore from neural_lam.datastore import DATASTORES from neural_lam.datastore.base import BaseRegularGridDatastore -from neural_lam.models.ar_model import ARModel -from neural_lam.models.graph_lam import GraphLAM +from neural_lam.models.forecaster_module import ForecasterModule from neural_lam.weather_dataset import WeatherDataModule from tests.conftest import init_datastore_example +# Model architecture defaults for tests +GRAPH = "1level" +HIDDEN_DIM = 4 +HIDDEN_LAYERS = 1 +PROCESSOR_LAYERS = 2 +MESH_AGGR = "sum" +NUM_PAST_FORCING_STEPS = 1 +NUM_FUTURE_FORCING_STEPS = 1 + def run_simple_training(datastore, set_output_std): """ @@ -42,9 +50,6 @@ def run_simple_training(datastore, set_output_std): max_epochs=1, deterministic=True, accelerator=device_name, - # XXX: `devices` has to be set to 2 otherwise - # neural_lam.models.ar_model.ARModel.aggregate_and_plot_metrics fails - # because it expects to aggregate over multiple devices devices=2, log_every_n_steps=1, # use `detect_anomaly` to ensure that we don't 
have NaNs popping up @@ -74,39 +79,50 @@ def run_simple_training(datastore, set_output_std): num_future_forcing_steps=1, ) - class ModelArgs: - output_std = set_output_std - loss = "mse" - restore_opt = False - n_example_pred = 1 - # XXX: this should be superfluous when we have already defined the - # model object no? - graph = graph_name - hidden_dim = 4 - hidden_layers = 1 - processor_layers = 2 - mesh_aggr = "sum" - lr = 1.0e-3 - val_steps_to_log = [1, 3] - metrics_watch = [] - num_past_forcing_steps = 1 - num_future_forcing_steps = 1 - - model_args = ModelArgs() - config = nlconfig.NeuralLAMConfig( datastore=nlconfig.DatastoreSelection( kind=datastore.SHORT_NAME, config_path=datastore.root_path ) ) - model = GraphLAM( # noqa - args=model_args, + # Build predictor and forecaster externally, then inject into + # ForecasterModule + # First-party + from neural_lam.models import MODELS + from neural_lam.models.ar_forecaster import ARForecaster + + predictor_class = MODELS["graph_lam"] + predictor = predictor_class( + config=config, datastore=datastore, + graph_name=graph_name, + hidden_dim=HIDDEN_DIM, + hidden_layers=HIDDEN_LAYERS, + processor_layers=PROCESSOR_LAYERS, + mesh_aggr=MESH_AGGR, + num_past_forcing_steps=NUM_PAST_FORCING_STEPS, + num_future_forcing_steps=NUM_FUTURE_FORCING_STEPS, + output_std=set_output_std, + ) + forecaster = ARForecaster(predictor, datastore) + + model = ForecasterModule( + forecaster=forecaster, config=config, + datastore=datastore, + loss="mse", + lr=1.0e-3, + restore_opt=False, + n_example_pred=1, + val_steps_to_log=[1, 3], + metrics_watch=[], + var_leads_metrics_watch={}, ) wandb.init() - trainer.fit(model=model, datamodule=data_module) + try: + trainer.fit(model=model, datamodule=data_module) + finally: + wandb.finish() @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) @@ -125,67 +141,3 @@ def test_training(datastore_name): def test_training_output_std(): datastore = init_datastore_example("mdp") # Test only with mdp 
datastore run_simple_training(datastore, set_output_std=True) - - -def test_all_gather_cat_single_device(): - """ - Test that all_gather_cat preserves tensor shape on single-device runs. - On a single device, all_gather returns the tensor unchanged (no new - leading dim), so all_gather_cat should not flatten any existing dims. - """ - - class MockModule: - """Minimal object with mocked single-device all_gather.""" - - def all_gather(self, tensor_to_gather, sync_grads=False): - # Single-device behavior: return tensor unchanged - return tensor_to_gather - - module = MockModule() - # Bind the real ARModel.all_gather_cat to our mock - module.all_gather_cat = ARModel.all_gather_cat.__get__(module, MockModule) - - # Simulate a 3D metric tensor: (N_eval, pred_steps, d_f) - tensor = torch.randn(4, 3, 5) - result = module.all_gather_cat(tensor) - - # On single device, shape must be preserved - assert result.shape == tensor.shape, ( - f"all_gather_cat changed shape on single device: " - f"{tensor.shape} -> {result.shape}" - ) - assert torch.equal(result, tensor) - - -def test_all_gather_cat_multi_device_simulation(): - """ - Test that all_gather_cat correctly flattens when all_gather adds a - leading dimension (simulating multi-device behavior). 
- """ - - class MockModule: - """Object with mocked multi-device all_gather.""" - - def all_gather(self, tensor, sync_grads=False): - # Simulate 2-GPU all_gather: prepend a dim of size 2 - return torch.stack([tensor, tensor], dim=0) - - module = MockModule() - # Bind the real ARModel.all_gather_cat to our mock - module.all_gather_cat = ARModel.all_gather_cat.__get__(module, MockModule) - - tensor = torch.randn(4, 3, 5) # (N_eval, pred_steps, d_f) - result = module.all_gather_cat(tensor) - - # Should flatten (2, 4, 3, 5) -> (8, 3, 5) - assert result.shape == ( - 8, - 3, - 5, - ), f"all_gather_cat wrong shape on multi-device: {result.shape}" - # Validate values match expected concatenation along dim 0 - expected = torch.cat([tensor, tensor], dim=0) - assert torch.equal(result, expected), ( - "all_gather_cat produced incorrectly ordered/combined values " - "on multi-device simulation" - )