Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
6f82d43
feat: introduce mesh_layout argument with two-step coordinate/connect…
prajwal-tech07 Mar 2, 2026
63af996
test: add comprehensive tests for mesh_layout two-step architecture
prajwal-tech07 Mar 2, 2026
f417962
refactor: align API with Leif's final table — rename params, simplify…
prajwal-tech07 Mar 2, 2026
30ddb19
docs: add CHANGELOG entry for mesh_layout feature
prajwal-tech07 Mar 2, 2026
802b66a
Address all review feedback from PR #81
prajwal-tech07 Mar 4, 2026
21d33e8
refactor: add coords module with two-step mesh creation (primitives +…
prajwal-tech07 Mar 8, 2026
8d18530
refactor: use explicit keyword argument names in flat mesh functions
prajwal-tech07 Mar 8, 2026
4f5a8de
refactor: put default values in call signature for create_hierarchica…
prajwal-tech07 Mar 8, 2026
e7e3c03
refactor: use loguru logger.warning() and implement two-step mesh cre…
prajwal-tech07 Mar 8, 2026
ef478a2
refactor: update archetypes and tests to use mesh_layout parameter
prajwal-tech07 Mar 8, 2026
305e66a
test: update backward-compat tests to capture loguru warnings instead…
prajwal-tech07 Mar 8, 2026
b38fd64
Merge branch 'main' into feat/mesh-layout-rectilinear
prajwal-tech07 Mar 10, 2026
4372a3f
style: fix linting — run black, isort, remove empty f-strings
prajwal-tech07 Mar 10, 2026
4f48ed3
refactor: remove inline defaults and simplify m2m_connectivity_kwargs…
prajwal-tech07 Mar 20, 2026
0b6ff4a
refactor: rename mesh.kinds to mesh.connectivity, use **kwargs, updat…
prajwal-tech07 Mar 23, 2026
c7bdeac
Merge branch 'main' into feat/mesh-layout-rectilinear
leifdenby Mar 23, 2026
51ea189
fix: pre-commit linting + notebook max_num_refinement_levels
prajwal-tech07 Mar 23, 2026
220fa44
Merge branch 'main' into feat/mesh-layout-rectilinear
leifdenby Mar 24, 2026
edbce82
refactor: separate coordinate creation (Step 1) from connectivity cre…
prajwal-tech07 Mar 25, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,15 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [unreleased](https://github.com/mllam/weather-model-graphs/compare/v0.3.0...HEAD)

### Added

- Add `mesh_layout` argument to mesh graph creation functions, with `rectilinear`
as the first supported layout. Uses a two-step architecture separating coordinate
creation from connectivity creation, enabling future alternative layouts (e.g. triangular).
[\#78](https://github.com/mllam/weather-model-graphs/issues/78), @prajwal-tech07

## [v0.3.0](https://github.com/mllam/weather-model-graphs/releases/tag/v0.3.0)

### Added
Expand Down
25 changes: 18 additions & 7 deletions src/weather_model_graphs/create/archetype.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ def create_keisler_graph(
return create_all_graph_components(
coords=coords,
m2m_connectivity="flat",
m2m_connectivity_kwargs=dict(mesh_node_distance=mesh_node_distance),
mesh_layout="rectilinear",
mesh_layout_kwargs=dict(grid_spacing=mesh_node_distance),
m2m_connectivity_kwargs=dict(pattern="8-star"),
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

based on discussions in the last dev meeting there was preference for m2m connectivity not being given explicitly here, but instead we should just let the m2m_connectivity method define its own default (which for flat would be 8-star) and then we don't have to pass in m2m_connectivity_kwargs here

g2m_connectivity="within_radius",
m2g_connectivity="nearest_neighbours",
g2m_connectivity_kwargs=dict(
Expand Down Expand Up @@ -133,10 +135,14 @@ def create_graphcast_graph(
return create_all_graph_components(
coords=coords,
m2m_connectivity="flat_multiscale",
mesh_layout="rectilinear",
mesh_layout_kwargs=dict(
grid_spacing=mesh_node_distance,
refinement_factor=level_refinement_factor,
max_num_refinement_levels=max_num_levels,
),
m2m_connectivity_kwargs=dict(
mesh_node_distance=mesh_node_distance,
level_refinement_factor=level_refinement_factor,
max_num_levels=max_num_levels,
pattern="8-star",
),
g2m_connectivity="within_radius",
m2g_connectivity="nearest_neighbours",
Expand Down Expand Up @@ -217,10 +223,15 @@ def create_oskarsson_hierarchical_graph(
return create_all_graph_components(
coords=coords,
m2m_connectivity="hierarchical",
mesh_layout="rectilinear",
mesh_layout_kwargs=dict(
grid_spacing=mesh_node_distance,
refinement_factor=level_refinement_factor,
max_num_refinement_levels=max_num_levels,
),
m2m_connectivity_kwargs=dict(
mesh_node_distance=mesh_node_distance,
level_refinement_factor=level_refinement_factor,
max_num_levels=max_num_levels,
intra_level=dict(pattern="8-star"),
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

again suggested that inter-level and intra-level connectivity parameters defined here should be the default assumed for hierarchical similar to what I suggested for flat here: https://github.com/mllam/weather-model-graphs/pull/81/changes#r2953045634

inter_level=dict(pattern="nearest", k=1),
),
g2m_connectivity="within_radius",
m2g_connectivity="nearest_neighbours",
Expand Down
206 changes: 183 additions & 23 deletions src/weather_model_graphs/create/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
function uses `connect_nodes_across_graphs` to connect nodes across the component graphs.
"""


import warnings
from typing import Iterable

import networkx
Expand All @@ -25,18 +25,26 @@
)
from .grid import create_grid_graph_nodes
from .mesh.kinds.flat import (
create_flat_multiscale_mesh_graph,
create_flat_singlescale_mesh_graph,
create_flat_multiscale_from_coordinates,
create_flat_singlescale_from_coordinates,
)
from .mesh.kinds.hierarchical import (
create_hierarchical_from_coordinates,
)
from .mesh.mesh import (
create_multirange_2d_mesh_coordinates,
create_single_level_2d_mesh_coordinates,
)
from .mesh.kinds.hierarchical import create_hierarchical_multiscale_mesh_graph


def create_all_graph_components(
coords: np.ndarray,
m2m_connectivity: str,
m2g_connectivity: str,
g2m_connectivity: str,
m2m_connectivity_kwargs={},
mesh_layout: str = "rectilinear",
mesh_layout_kwargs: dict = None,
m2m_connectivity_kwargs: dict = None,
m2g_connectivity_kwargs={},
g2m_connectivity_kwargs={},
Copy link

Copilot AI Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

m2g_connectivity_kwargs and g2m_connectivity_kwargs are still using {} as default values. These are mutable defaults and can lead to cross-call state leakage if they are ever mutated inside the function or by callees. Use None defaults and initialize/copy inside the function (as is already done for mesh_layout_kwargs and m2m_connectivity_kwargs).

Suggested change
m2g_connectivity_kwargs={},
g2m_connectivity_kwargs={},
m2g_connectivity_kwargs=None,
g2m_connectivity_kwargs=None,

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with this, also I don't think we should default to rectilinear for the mesh_layout argument. We want people to be explicit about the method they want to use

coords_crs: pyproj.crs.CRS | None = None,
Expand All @@ -49,6 +57,14 @@ def create_all_graph_components(
grid-to-mesh (g2m), mesh-to-mesh (m2m) and mesh-to-grid (m2g),
representing the encode-process-decode respectively.

The mesh graph creation follows a two-step process:
1. **Coordinate creation** (controlled by `mesh_layout` + `mesh_layout_kwargs`):
Creates an undirected graph (nx.Graph) with node positions and spatial
adjacency edges annotated with adjacency types.
2. **Connectivity creation** (controlled by `m2m_connectivity` + `m2m_connectivity_kwargs`):
Converts the coordinate graph to directed connectivity (nx.DiGraph)
based on the specified pattern and connectivity method.

For each graph component, the method for connecting nodes across graphs
should be specified (with the `*_connectivity` arguments, e.g. `m2g_connectivity`).
And the method-specific arguments should be passed as keyword arguments using
Expand All @@ -62,15 +78,23 @@ def create_all_graph_components(
- "within_radius": Find all neighbours in grid within an absolute distance
of `max_dist` or relative distance of `rel_max_dist` from each node in mesh

mesh_layout:
- "rectilinear": Regular rectilinear grid (default). Uses grid_spacing to
determine mesh node placement. Produces nodes with 4-star (cardinal) and
8-star (cardinal + diagonal) spatial adjacency edges.

mesh_layout_kwargs (for mesh_layout="rectilinear"):
- grid_spacing: float, distance between mesh nodes in coordinate units
- refinement_factor: int, refinement factor between levels (for multi-level)
- max_num_refinement_levels: int, maximum number of mesh levels (for multi-level)

m2m_connectivity:
- "flat": Create a single-level 2D mesh graph with `mesh_node_distance`,
similar to Keisler et al. (2022)
- "flat_multiscale": Create a flat multiscale mesh graph with `max_num_levels`,
`mesh_node_distance` and `level_refinement_factor`,
similar to GraphCast, Lam et al. (2023)
- "hierarchical": Create a hierarchical mesh graph with `max_num_levels`,
`mesh_node_distance` and `level_refinement_factor`,
similar to Oskarsson et al. (2023)
- "flat": Create a single-level directed mesh graph.
m2m_connectivity_kwargs: pattern="4-star" or "8-star" (default: "8-star")
- "flat_multiscale": Create a flat multiscale mesh graph.
m2m_connectivity_kwargs: pattern="4-star" or "8-star" (default: "8-star")
- "hierarchical": Create a hierarchical mesh graph with up/down connections.
m2m_connectivity_kwargs: intra_level=dict(pattern=...), inter_level=dict(pattern=..., k=...)

m2g_connectivity:
- "nearest_neighbour": Find the nearest neighbour in mesh for each node in grid
Expand All @@ -97,6 +121,50 @@ def create_all_graph_components(
"""
graph_components: dict[networkx.DiGraph] = {}

# Initialize mutable default arguments (and copy to avoid mutating caller's dicts)
if mesh_layout_kwargs is None:
mesh_layout_kwargs = {}
else:
mesh_layout_kwargs = dict(mesh_layout_kwargs)
if m2m_connectivity_kwargs is None:
m2m_connectivity_kwargs = {}
else:
m2m_connectivity_kwargs = dict(m2m_connectivity_kwargs)

# Backward compatibility: migrate old-style kwargs where mesh_node_distance,
# level_refinement_factor, and max_num_levels were passed via
# m2m_connectivity_kwargs. In the new design these belong in mesh_layout_kwargs.
if "mesh_node_distance" in m2m_connectivity_kwargs and "grid_spacing" not in mesh_layout_kwargs:
warnings.warn(
"Passing 'mesh_node_distance' in m2m_connectivity_kwargs is deprecated. "
"Use mesh_layout_kwargs=dict(grid_spacing=...) instead.",
DeprecationWarning,
stacklevel=2,
)
mesh_layout_kwargs["grid_spacing"] = m2m_connectivity_kwargs.pop(
"mesh_node_distance"
)
if "level_refinement_factor" in m2m_connectivity_kwargs and "refinement_factor" not in mesh_layout_kwargs:
warnings.warn(
"Passing 'level_refinement_factor' in m2m_connectivity_kwargs is deprecated. "
"Use mesh_layout_kwargs=dict(refinement_factor=...) instead.",
DeprecationWarning,
stacklevel=2,
)
mesh_layout_kwargs["refinement_factor"] = (
m2m_connectivity_kwargs.pop("level_refinement_factor")
)
if "max_num_levels" in m2m_connectivity_kwargs and "max_num_refinement_levels" not in mesh_layout_kwargs:
warnings.warn(
"Passing 'max_num_levels' in m2m_connectivity_kwargs is deprecated. "
"Use mesh_layout_kwargs=dict(max_num_refinement_levels=...) instead.",
DeprecationWarning,
stacklevel=2,
)
mesh_layout_kwargs["max_num_refinement_levels"] = m2m_connectivity_kwargs.pop(
"max_num_levels"
)

assert (
len(coords.shape) == 2 and coords.shape[1] == 2
), "Grid node coordinates should be given as an array of shape [num_grid_nodes, 2]."
Expand Down Expand Up @@ -126,26 +194,118 @@ def create_all_graph_components(
xy = np.stack(xy_tuple, axis=1)

if m2m_connectivity == "flat":
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because all m2m_connectivity all now follow the same pattern I think we can simplify the code here. I.e. 1) create the coordinates based on the mesh_layout selected and then 2) create the connectivity between the mesh nodes calling the appropriate function between create_hierarchical_from_coordinates(), create_flat_multiscale_from_coordinates() and create_flat_singlescale_from_coordinates(). What do you think?

graph_components["m2m"] = create_flat_singlescale_mesh_graph(
xy,
**m2m_connectivity_kwargs,
# --- Step 1: Coordinate creation based on mesh_layout ---
if mesh_layout == "rectilinear":
grid_spacing = mesh_layout_kwargs.get("grid_spacing")
if grid_spacing is None:
raise ValueError(
"mesh_layout='rectilinear' requires 'grid_spacing' in "
"mesh_layout_kwargs (or 'mesh_node_distance' in "
"m2m_connectivity_kwargs for backward compatibility)."
)
# Compute number of mesh nodes from grid_spacing
Copy link

Copilot AI Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The rectilinear mesh_layout handling is duplicated across the flat, hierarchical, and flat_multiscale branches (grid_spacing extraction/validation + coordinate graph construction). This repetition makes it easy for defaults/validation to diverge between modes (e.g. the multi-level branches currently differ from the single-level branch). Consider factoring the layout dispatch into a small helper that returns either a single coordinate graph or a list of coordinate graphs, so behavior stays consistent as new layouts are added.

Copilot uses AI. Check for mistakes.
range_x, range_y = np.ptp(xy, axis=0)
nx_mesh = int(range_x / grid_spacing)
ny_mesh = int(range_y / grid_spacing)
if nx_mesh == 0 or ny_mesh == 0:
raise ValueError(
"The given `grid_spacing` is too large for the provided "
f"coordinates. Got grid_spacing={grid_spacing}, but the "
f"x-range is {range_x} and y-range is {range_y}. Maybe you "
"want to decrease the `grid_spacing` so that the mesh nodes "
"are spaced closer together?"
)
G_mesh_coords = create_single_level_2d_mesh_coordinates(
xy, nx_mesh, ny_mesh
)
else:
raise NotImplementedError(
f"mesh_layout='{mesh_layout}' is not yet supported. "
"Currently only 'rectilinear' is implemented."
)

# --- Step 2: Connectivity creation ---
pattern = m2m_connectivity_kwargs.get("pattern", "8-star")
graph_components["m2m"] = create_flat_singlescale_from_coordinates(
G_mesh_coords, pattern=pattern
)
grid_connect_graph = graph_components["m2m"]

elif m2m_connectivity == "hierarchical":
# --- Step 1: Coordinate creation based on mesh_layout ---
if mesh_layout == "rectilinear":
grid_spacing = mesh_layout_kwargs.get("grid_spacing")
refinement_factor = mesh_layout_kwargs.get("refinement_factor")
max_num_refinement_levels = mesh_layout_kwargs.get(
"max_num_refinement_levels"
Copy link

Copilot AI Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the hierarchical path, refinement_factor (and max_num_refinement_levels) are pulled from mesh_layout_kwargs without defaults/validation, but are then passed directly to create_multirange_2d_mesh_coordinates. If the caller omits refinement_factor, this will pass None as interlevel_refinement_factor and fail inside np.log(...). Either set explicit defaults here (e.g. refinement_factor=3, max_num_refinement_levels=None per the API spec) or raise a clear ValueError when they’re missing/invalid.

Suggested change
grid_spacing = mesh_layout_kwargs.get("grid_spacing")
refinement_factor = mesh_layout_kwargs.get("refinement_factor")
max_num_refinement_levels = mesh_layout_kwargs.get(
"max_num_refinement_levels"
# Ensure we have a dict to read from, even if the caller passed None.
mesh_layout_kwargs = mesh_layout_kwargs or {}
grid_spacing = mesh_layout_kwargs.get("grid_spacing")
# Default refinement_factor and max_num_refinement_levels per API spec.
refinement_factor = mesh_layout_kwargs.get("refinement_factor", 3)
max_num_refinement_levels = mesh_layout_kwargs.get(
"max_num_refinement_levels", None

Copilot uses AI. Check for mistakes.
)
if grid_spacing is None:
raise ValueError(
"mesh_layout='rectilinear' with m2m_connectivity='hierarchical' "
"requires 'grid_spacing' in mesh_layout_kwargs."
)
G_coords_list = create_multirange_2d_mesh_coordinates(
max_num_levels=max_num_refinement_levels,
xy=xy,
grid_spacing=grid_spacing,
interlevel_refinement_factor=refinement_factor,
)
else:
raise NotImplementedError(
f"mesh_layout='{mesh_layout}' is not yet supported. "
"Currently only 'rectilinear' is implemented."
)

# --- Step 2: Connectivity creation ---
intra_level = m2m_connectivity_kwargs.get(
"intra_level", {"pattern": "8-star"}
)
inter_level = m2m_connectivity_kwargs.get(
"inter_level", {"pattern": "nearest", "k": 1}
)
# hierarchical mesh graph have three sub-graphs:
# `m2m` (mesh-to-mesh), `mesh_up` (up edge connections) and `mesh_down` (down edge connections)
graph_components["m2m"] = create_hierarchical_multiscale_mesh_graph(
xy=xy,
**m2m_connectivity_kwargs,
# `m2m` (mesh-to-mesh), `mesh_up` (up edge connections) and
# `mesh_down` (down edge connections)
graph_components["m2m"] = create_hierarchical_from_coordinates(
G_coords_list,
intra_level=intra_level,
inter_level=inter_level,
)
# Only connect grid to bottom level of hierarchy
grid_connect_graph = split_graph_by_edge_attribute(
graph_components["m2m"], "level"
)[0]

elif m2m_connectivity == "flat_multiscale":
graph_components["m2m"] = create_flat_multiscale_mesh_graph(
xy=xy,
**m2m_connectivity_kwargs,
# --- Step 1: Coordinate creation based on mesh_layout ---
if mesh_layout == "rectilinear":
grid_spacing = mesh_layout_kwargs.get("grid_spacing")
refinement_factor = mesh_layout_kwargs.get("refinement_factor")
max_num_refinement_levels = mesh_layout_kwargs.get(
"max_num_refinement_levels"
)
if grid_spacing is None:
Copy link

Copilot AI Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the flat_multiscale path, refinement_factor = mesh_layout_kwargs.get("refinement_factor") can be None, but it’s passed as interlevel_refinement_factor into create_multirange_2d_mesh_coordinates, which will error. Add defaults (and/or validation) for refinement_factor and max_num_refinement_levels before calling the coordinate creation step.

Copilot uses AI. Check for mistakes.
raise ValueError(
"mesh_layout='rectilinear' with m2m_connectivity='flat_multiscale' "
"requires 'grid_spacing' in mesh_layout_kwargs."
)
G_coords_list = create_multirange_2d_mesh_coordinates(
max_num_levels=max_num_refinement_levels,
xy=xy,
grid_spacing=grid_spacing,
interlevel_refinement_factor=refinement_factor,
)
else:
raise NotImplementedError(
f"mesh_layout='{mesh_layout}' is not yet supported. "
"Currently only 'rectilinear' is implemented."
)

# --- Step 2: Connectivity creation ---
pattern = m2m_connectivity_kwargs.get("pattern", "8-star")
graph_components["m2m"] = create_flat_multiscale_from_coordinates(
G_coords_list,
pattern=pattern,
)
grid_connect_graph = graph_components["m2m"]
else:
Expand Down
7 changes: 6 additions & 1 deletion src/weather_model_graphs/create/mesh/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,6 @@
from .mesh import create_single_level_2d_mesh_graph
from .mesh import (
create_directed_mesh_graph,
create_multirange_2d_mesh_coordinates,
create_single_level_2d_mesh_coordinates,
create_single_level_2d_mesh_graph,
)
Loading
Loading