Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,16 @@ repos:
- id: flake8
description: Check Python code for correctness, consistency and adherence to best practices
additional_dependencies: [Flake8-pyproject]
- repo: https://github.com/econchick/interrogate
rev: 1.7.0
hooks:
- id: interrogate
description: Ensure documentation coverage stays perfect
pass_filenames: false
args: ["--fail-under=100", "neural_lam"]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.19.0
hooks:
- id: mypy
- id: mypy
additional_dependencies: [types-PyYAML, types-Pillow, types-tqdm]
description: Check for type errors
17 changes: 17 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,23 @@ In addition, hierarchical mesh graphs (`L > 1`) feature a few additional files w
These files have the same list format as the ones above, but each list has length `L-1` (as these edges describe connections between levels).
Entry 0 in these lists describes the edges between the two lowest levels, 1 and 2.

## Dimension Glossary

Canonical dimension names used in tensor shape annotations throughout the codebase:

- `B` — batch size
- `pred_steps` — number of autoregressive prediction steps
- `num_grid_nodes` — number of nodes in the flattened spatial grid
- `num_mesh_nodes` — number of mesh nodes; indexed as `num_mesh_nodes[l]` for hierarchical level `l`
- `num_state_vars` — number of atmospheric state variables
- `num_forcing_vars` — number of forcing input variables
- `num_variables` — generic variable dimension used in metric functions
- `hidden_dim` — internal hidden representation size in GNN layers and MLPs
- `input_dim` — input feature dimensionality to a layer before transformation
- `num_edges` — number of edges in a graph (g2m, m2g, same-level, up, down)
- `num_send` — number of sender nodes in a message-passing step
- `num_rec` — number of receiver nodes in a message-passing step

# Development and Contributing
Any push or Pull-Request to the main branch will trigger a selection of pre-commit hooks.
These hooks will run a series of checks on the code, like formatting and linting.
Expand Down
2 changes: 2 additions & 0 deletions neural_lam/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Neural-LAM: graph-based neural weather prediction models."""

# Standard library
import importlib.metadata

Expand Down
61 changes: 47 additions & 14 deletions neural_lam/config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
"""
Configuration utilities and dataclasses for neural-lam.

This module defines configuration structures used to initialize
datastores, training settings, and model-related parameters.
"""

# Standard library
import dataclasses
from pathlib import Path
Expand All @@ -18,21 +25,26 @@
@dataclasses.dataclass
class DatastoreSelection:
"""
Configuration for selecting a datastore to use with neural-lam.
Configuration for selecting a datastore to use with Neural-LAM.

Attributes
----------
kind : str
The kind of datastore to use, currently `mdp` or `npyfilesmeps` are
implemented.
Identifier of the datastore to use (e.g. "mdp", "npyfilesmeps").
config_path : str
The path to the configuration file for the selected datastore, this is
assumed to be relative to the configuration file for neural-lam.
Path to the datastore-specific configuration file, relative to the
main Neural-LAM configuration file.
"""

kind: str

def __post_init__(self):
    """
    Run post-initialization validation for the dataclass.

    Confirms the configured datastore kind is one of the
    implemented backends registered in ``DATASTORES``.
    """
    is_known_kind = self.kind in DATASTORES
    if not is_known_kind:
        raise ValueError(f"Datastore kind {self.kind} is not implemented")

Expand Down Expand Up @@ -84,15 +96,21 @@ class OutputClamping:
@dataclasses.dataclass
class TrainingConfig:
"""
Configuration related to training neural-lam
Configuration related to training Neural-LAM.

This includes:
- how state features are weighted in the loss function
- optional output clamping behaviour

Attributes
----------
state_feature_weighting : Union[ManualStateFeatureWeighting,
UnformFeatureWeighting]
The method to use for weighting the state features in the loss
function. Defaults to uniform weighting (`UnformFeatureWeighting`, i.e.
all features are weighted equally).
UniformFeatureWeighting]
Strategy used to weight state features in the loss function.
Defaults to uniform weighting (all features contribute equally).

output_clamping : OutputClamping
Optional configuration to clamp model outputs within specified bounds.
"""

state_feature_weighting: Union[
Expand Down Expand Up @@ -150,15 +168,25 @@ class _(dataclass_wizard.JSONWizard.Meta):


class InvalidConfigError(Exception):
    """Raised when the configuration file is invalid or cannot be parsed."""
    # No extra body needed: the docstring alone is a valid class body, so the
    # redundant `pass` statement has been removed (idiomatic Python).


def load_config_and_datastore(
config_path: str,
) -> tuple[NeuralLAMConfig, Union[MDPDatastore, NpyFilesDatastoreMEPS]]:
) -> tuple[
NeuralLAMConfig,
Union[MDPDatastore, NpyFilesDatastoreMEPS],
]:
"""
Load the neural-lam configuration and the datastore specified in the
configuration.
Load the Neural-LAM configuration file and initialize
the corresponding datastore.

This function:
- parses the configuration file into a ``NeuralLAMConfig`` object
- resolves the datastore configuration path relative to the config file
- initializes the datastore using the selected backend

Parameters
----------
Expand All @@ -168,7 +196,12 @@ def load_config_and_datastore(
Returns
-------
tuple[NeuralLAMConfig, Union[MDPDatastore, NpyFilesDatastoreMEPS]]
The Neural-LAM configuration and the loaded datastore.
Loaded configuration object and initialized datastore instance.

Raises
------
InvalidConfigError
If the configuration file contains unknown or invalid fields.
"""
try:
config = NeuralLAMConfig.from_yaml_file(config_path)
Expand Down
48 changes: 48 additions & 0 deletions neural_lam/create_graph.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Graph construction utilities for Neural-LAM meshes and grids."""

# Standard library
import os
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
Expand All @@ -20,6 +22,21 @@


def plot_graph(graph, title=None):
"""
Render a PyTorch Geometric graph using stored node coordinates.

Parameters
----------
graph : torch_geometric.data.Data
Graph containing ``edge_index`` and ``pos`` attributes.
title : str or None, optional
Optional subplot title.

Returns
-------
tuple[matplotlib.figure.Figure, matplotlib.axes.Axes]
Figure and axis handles for further customization.
"""
fig, axis = plt.subplots(figsize=(8, 8), dpi=200) # W,H
edge_index = graph.edge_index
pos = graph.pos
Expand Down Expand Up @@ -71,6 +88,7 @@ def plot_graph(graph, title=None):


def sort_nodes_internally(nx_graph):
"""Return a copy of ``nx_graph`` with deterministically ordered nodes."""
# For some reason the networkx .nodes() return list can not be sorted,
# but this is the ordering used by pyg when converting.
# This function fixes this.
Expand All @@ -81,6 +99,7 @@ def sort_nodes_internally(nx_graph):


def save_edges(graph, name, base_path):
"""Persist edge indices/features for a PyG graph under ``base_path``."""
torch.save(
graph.edge_index, os.path.join(base_path, f"{name}_edge_index.pt")
)
Expand All @@ -91,6 +110,7 @@ def save_edges(graph, name, base_path):


def save_edges_list(graphs, name, base_path):
"""Persist edge indices/features for a list of graphs."""
torch.save(
[graph.edge_index for graph in graphs],
os.path.join(base_path, f"{name}_edge_index.pt"),
Expand All @@ -105,12 +125,14 @@ def save_edges_list(graphs, name, base_path):


def from_networkx_with_start_index(nx_graph, start_index):
    """Convert a NetworkX graph to PyG, offsetting node indices by ``start_index``."""
    converted = from_networkx(nx_graph)
    # Shift every node index in the edge list so this graph can be
    # concatenated after graphs whose nodes occupy indices [0, start_index).
    converted.edge_index += start_index
    return converted


def mk_2d_graph(xy, nx, ny):
"""Create a diagonal 2-D grid graph over the ``xy`` positions."""
xm, xM = np.amin(xy[:, :, 0][:, 0]), np.amax(xy[:, :, 0][:, 0])
ym, yM = np.amin(xy[:, :, 1][0, :]), np.amax(xy[:, :, 1][0, :])

Expand Down Expand Up @@ -150,6 +172,7 @@ def mk_2d_graph(xy, nx, ny):


def prepend_node_index(graph, new_index):
"""Relabel each node by prepending ``new_index`` to its tuple identifier."""
# Relabel node indices in graph, insert (graph_level, i, j)
ijk = [tuple((new_index,) + x) for x in graph.nodes]
to_mapping = dict(zip(graph.nodes, ijk))
Expand Down Expand Up @@ -544,6 +567,22 @@ def create_graph_from_datastore(
hierarchical: bool = False,
create_plot: bool = False,
):
"""
Generate graph components for ``datastore`` and persist them on disk.

Parameters
----------
datastore : BaseRegularGridDatastore
Datastore providing ``get_xy`` for state nodes.
output_root_path : str
Directory where the resulting ``*.pt`` graph files are stored.
n_max_levels : int or None, optional
Optional limit of hierarchical mesh levels to build.
hierarchical : bool, optional
If ``True``, create multi-level hierarchical graphs. Default ``False``.
create_plot : bool, optional
If ``True``, display matplotlib previews of the generated graphs.
"""
if isinstance(datastore, BaseRegularGridDatastore):
xy = datastore.get_xy(category="state", stacked=False)
else:
Expand All @@ -561,6 +600,15 @@ def create_graph_from_datastore(


def cli(input_args=None):
"""
Parse CLI arguments and call :func:`create_graph_from_datastore`.

Parameters
----------
input_args : list[str] or None, optional
Argument list forwarded to :class:`argparse.ArgumentParser`. When
``None``, ``sys.argv`` is used.
"""
parser = ArgumentParser(
description="Graph generation for neural-lam",
formatter_class=ArgumentDefaultsHelpFormatter,
Expand Down
35 changes: 28 additions & 7 deletions neural_lam/custom_loggers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Custom logging utilities (e.g., MLFlow wrappers) used in Neural-LAM."""

# Standard library
import sys

Expand All @@ -16,6 +18,18 @@ class CustomMLFlowLogger(pl.loggers.MLFlowLogger):
"""

def __init__(self, experiment_name, tracking_uri, run_name):
"""
Initialize the logger and start an MLflow run.

Parameters
----------
experiment_name : str
Target MLflow experiment.
tracking_uri : str
MLflow tracking server URI.
run_name : str
Human-readable run name stored as ``mlflow.runName``.
"""
super().__init__(
experiment_name=experiment_name, tracking_uri=tracking_uri
)
Expand All @@ -39,14 +53,21 @@ def save_dir(self):

def log_image(self, key, images, step=None):
"""
Log a matplotlib figure as an image to MLFlow
Log one or more Matplotlib figures as images in MLflow.

Parameters
----------
key : str
Identifier under which to log the image.
images : Sequence[matplotlib.figure.Figure]
Figures to export; only the first element is logged.
step : int or None, optional
Optional training step index appended to ``key``.

key: str
Key to log the image under
images: list
List of matplotlib figures to log
step: Union[int, None]
Step to log the image under. If None, logs under the key directly
Raises
------
SystemExit
If AWS credentials for the MLflow artifact store are missing.
"""
# Third-party
from botocore.exceptions import NoCredentialsError
Expand Down
22 changes: 22 additions & 0 deletions neural_lam/datastore/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Datastore backends for loading and serving weather model data."""

# Local
from .base import BaseDatastore # noqa
from .mdp import MDPDatastore # noqa
Expand All @@ -15,6 +17,26 @@


def init_datastore(datastore_kind, config_path):
"""
Instantiate a datastore based on its short-name identifier.

Parameters
----------
datastore_kind : str
Key corresponding to one of :data:`DATASTORES`.
config_path : str | pathlib.Path
Path to the datastore-specific configuration file.

Returns
-------
BaseDatastore
Concrete datastore instance configured for ``config_path``.

Raises
------
NotImplementedError
If ``datastore_kind`` is not registered.
"""
DatastoreClass = DATASTORES.get(datastore_kind)

if DatastoreClass is None:
Expand Down
14 changes: 10 additions & 4 deletions neural_lam/datastore/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Abstract base classes describing Neural-LAM datastore APIs."""

# Standard library
import abc
import collections
Expand Down Expand Up @@ -85,8 +87,10 @@ def config(self) -> collections.abc.Mapping:
def step_length(self) -> timedelta:
"""The step length of the dataset as a time interval.

Returns:
timedelta: The step length as a datetime.timedelta object.
Returns
-------
datetime.timedelta
The step length as a ``datetime.timedelta`` object.

"""
pass
Expand Down Expand Up @@ -384,8 +388,10 @@ def state_feature_weights_values(self) -> List[float]:
the loss function for each state variable (e.g. via the standard
deviation of the 1-step differences of the state variables).

Returns:
List[float]: The weights for each state feature.
Returns
-------
List[float]
The weights for each state feature.
"""
pass

Expand Down
2 changes: 2 additions & 0 deletions neural_lam/datastore/mdp.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Datastore implementation wrapping ``mllam-data-prep`` outputs."""

# Standard library
import copy
import functools
Expand Down
Loading