Skip to content
Merged
Show file tree
Hide file tree
Changes from 56 commits
Commits
Show all changes
69 commits
Select commit Hold shift + click to select a range
8a11527
Start looking into where xy has to be changed
Oct 2, 2024
2c7e4eb
Change shapes in docstrings
Oct 3, 2024
6f03908
Make flat mesh graphs work with new coordinate layout
Oct 3, 2024
d2a081a
Merge PR #26 into branch
Oct 13, 2024
9dbd2a5
Fix coordinate handling in multirange graph creation
Oct 13, 2024
eada6ea
Rename grid_refinement_factor to mesh_node_distance
Oct 13, 2024
01180a1
Fix existing tests to work with new coordinate format
Oct 13, 2024
d103ad6
Add test for irregularlygridded coordinates
Oct 13, 2024
7ba387e
Remove unneeded eps in mesh level calculation
Oct 14, 2024
9069c84
Change documentation to use new format and arguments for coordinates
Oct 14, 2024
d592c2b
Fix bug in coordinate order for flat graphs
Oct 14, 2024
fd695b1
Start working on allowing latlon coordinates
Oct 14, 2024
2c781ee
Introduce coords and projection
Oct 14, 2024
5cf3bbf
Merge branch 'main' into general_coordinates
Oct 14, 2024
0f17d2c
Fix linting
Oct 14, 2024
f145025
Fix tests with coords keyword argument
Oct 14, 2024
cc4cc5e
Implement lat-lon transformation through projection
Oct 16, 2024
2627e37
Add documentation page about graphs constructed using lat-lons
Oct 16, 2024
c764fd7
Adjust coords keyword arg in docs
Oct 16, 2024
f6ae35b
Add test for lat-lon coordinates
Oct 16, 2024
22caf65
Fix linting of docs
Oct 16, 2024
7ec34ff
Merge main into branch
Oct 17, 2024
f564c70
Add decode_mask for only including subset of grid nodes in m2g
Oct 21, 2024
2333d53
Add test for decode filtering
Oct 21, 2024
3ee25ae
Fix typos and clarifications as suggested from code review
joeloskarsson Oct 22, 2024
70eef3e
Change euclidean coordinates to Cartesian coordinates
Oct 23, 2024
3c4866b
Merge branch 'main' into general_coordinates
Oct 23, 2024
5f33bc5
Merge branch 'general_coordinates' into decoding_mask
Oct 23, 2024
63df482
Merge branch 'main' into decoding_mask
Nov 11, 2024
350e0c0
Sort nodes and subgraphs for saving
Nov 11, 2024
806b78f
Fix linting
Nov 11, 2024
4fb69dc
Merge branch 'main' into general_coordinates
Nov 18, 2024
8fcf182
Apply suggested documentation and code readability updates
joeloskarsson Nov 19, 2024
746966f
Update src/weather_model_graphs/create/mesh/kinds/hierarchical.py
joeloskarsson Nov 26, 2024
47b5dcc
Clarify comments and variable names around mesh level computation
Nov 26, 2024
05a0cc6
Add check for number of nodes in test with irregular coords
Nov 26, 2024
339feff
Update docs line on square meshes
Nov 26, 2024
6b281bf
Reference lat-lon notebook in coordinate section
Nov 26, 2024
a525818
Change projection spec to use pyproj crs:s
Nov 26, 2024
8733359
Adjust test to crs arguments
Nov 26, 2024
c95f023
Fix linting
Nov 26, 2024
bcaf1e1
Update docs to crs change
Nov 26, 2024
8e3c1cb
Add cartopy dependency to visualization group
Nov 26, 2024
541054e
Merge branch 'general_coordinates' into decoding_mask
Nov 26, 2024
169ea48
Sort nodes by id before pyg conversion
Nov 27, 2024
a6b6137
Introduce option to return graph components directly, used through kw…
Nov 27, 2024
7576b43
Merge branch 'main' into decoding_mask
Nov 29, 2024
ae90eb8
Add explanation of **kwargs to archetype docstrings
Nov 29, 2024
ee8601e
add test ensuring unchanged grid-indecies /w decode mask
leifdenby Dec 6, 2024
e178a0b
add test and example notebook
leifdenby Dec 9, 2024
f50a9cd
Add Iterable[bool] type hint for mask
Sep 24, 2025
a48fd9d
Merge branch 'main' into decoding_mask
Sep 24, 2025
68cf065
Complete decoder mask test
Sep 24, 2025
fe81b5d
Add comment explaining the return_components option
joeloskarsson Sep 24, 2025
92b9d41
Rename node id -> label
joeloskarsson Sep 24, 2025
127d54d
Linting
Sep 24, 2025
718b769
work on notebook
leifdenby Oct 6, 2025
958273c
Fix typo
joeloskarsson Oct 7, 2025
3a4185e
Fix g2m/m2g typo in test
Oct 10, 2025
801a6b1
Add clarifying comment
Oct 10, 2025
321eb98
Clarify comment about node ordering
Oct 10, 2025
99eb7da
Get rid of **kwarg arguments in archetypes
Oct 10, 2025
94413c7
remove saving/loading masked-decoding graphs example
leifdenby Oct 28, 2025
1ac0a94
clear all notebook outputs
leifdenby Oct 28, 2025
71ce0a4
Merge branch 'decoding_mask' of https://github.com/joeloskarsson/weat…
leifdenby Oct 28, 2025
69f7f84
Make some small tweaks to documentation notebook
joeloskarsson Oct 29, 2025
26177ed
Linting notebook
joeloskarsson Oct 29, 2025
8806ca1
Add changelog entry
joeloskarsson Oct 29, 2025
c377812
Merge branch 'main' into decoding_mask
joeloskarsson Oct 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
473 changes: 473 additions & 0 deletions docs/decoding_mask.ipynb

Large diffs are not rendered by default.

45 changes: 14 additions & 31 deletions src/weather_model_graphs/create/archetype.py
Comment thread
leifdenby marked this conversation as resolved.
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
from .base import create_all_graph_components


def create_keisler_graph(
coords,
mesh_node_distance=3,
coords_crs=None,
graph_crs=None,
):
def create_keisler_graph(coords, mesh_node_distance=3, **kwargs):
"""
Create a flat LAM graph from Oskarsson et al (2023, https://arxiv.org/abs/2309.17370)
This graph setup is inspired by the global graph used by Keisler (2022, https://arxiv.org/abs/2202.07575).
Expand All @@ -28,11 +23,8 @@ def create_keisler_graph(
mesh_node_distance: float
Distance (in x- and y-direction) between created mesh nodes,
in coordinate system of coords
coords_crs: pyproj.crs.CRS or None
CRS of the given coordinates
graph_crs:
CRS to build graph in. If given, coords will be transformed from
coords_crs to graph_crs before graph construction
**kwargs:
Additional keyword arguments passed on to create_all_graph_components.

Returns
-------
Expand All @@ -51,8 +43,7 @@ def create_keisler_graph(
m2g_connectivity_kwargs=dict(
max_num_neighbours=4,
),
coords_crs=coords_crs,
graph_crs=graph_crs,
**kwargs,
)


Expand All @@ -61,8 +52,7 @@ def create_graphcast_graph(
mesh_node_distance=3,
level_refinement_factor=3,
max_num_levels=None,
coords_crs=None,
graph_crs=None,
**kwargs,
):
"""
Create a multiscale LAM graph from Oskarsson et al (2023, https://arxiv.org/abs/2309.17370)
Expand All @@ -89,11 +79,8 @@ def create_graphcast_graph(
NOTE: Must be an odd integer >1 to create proper multiscale graph
max_num_levels: int
The number of levels of longer-range connections in the mesh graph.
coords_crs: pyproj.crs.CRS or None
CRS of the given coordinates
graph_crs:
CRS to build graph in. If given, coords will be transformed from
coords_crs to graph_crs before graph construction
**kwargs:
Additional keyword arguments passed on to create_all_graph_components.

Returns
-------
Expand All @@ -116,8 +103,7 @@ def create_graphcast_graph(
m2g_connectivity_kwargs=dict(
max_num_neighbours=4,
),
coords_crs=coords_crs,
graph_crs=graph_crs,
**kwargs,
)


Expand All @@ -126,8 +112,7 @@ def create_oskarsson_hierarchical_graph(
mesh_node_distance=3,
level_refinement_factor=3,
max_num_levels=None,
coords_crs=None,
graph_crs=None,
**kwargs,
):
"""
Create a LAM graph following Oskarsson et al (2023, https://arxiv.org/abs/2309.17370)
Expand Down Expand Up @@ -157,11 +142,10 @@ def create_oskarsson_hierarchical_graph(
in coordinate system of coords
level_refinement_factor: float
Refinement factor between grid points and bottom level of mesh hierarchy
coords_crs: pyproj.crs.CRS or None
CRS of the given coordinates
graph_crs:
CRS to build graph in. If given, coords will be transformed from
coords_crs to graph_crs before graph construction
max_num_levels: int
The number of levels of longer-range connections in the mesh graph.
**kwargs:
Additional keyword arguments passed on to create_all_graph_components.

Returns
-------
Expand All @@ -184,6 +168,5 @@ def create_oskarsson_hierarchical_graph(
m2g_connectivity_kwargs=dict(
max_num_neighbours=4,
),
coords_crs=coords_crs,
graph_crs=graph_crs,
**kwargs,
)
30 changes: 29 additions & 1 deletion src/weather_model_graphs/create/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
"""


from typing import Iterable

import networkx
import networkx as nx
import numpy as np
Expand Down Expand Up @@ -39,6 +41,8 @@ def create_all_graph_components(
g2m_connectivity_kwargs={},
coords_crs: pyproj.crs.CRS | None = None,
graph_crs: pyproj.crs.CRS | None = None,
decode_mask: Iterable[bool] | None = None,
return_components: bool = False,
):
"""
Create all graph components used in creating the message-passing graph,
Expand Down Expand Up @@ -82,6 +86,11 @@ def create_all_graph_components(
will be transformed from their original Coordinate Reference System (`coords_crs`)
to the CRS where the graph creation should take place (`graph_crs`).
If any one of them is None the graph creation is carried out using the original coords.

`decode_mask` should be an Iterable of booleans, masking which grid positions should be
decoded to (included in the m2g subgraph), i.e. which positions should be output. It should have the same length as the number of
grid position coordinates given in `coords`. The mask being set to True means that corresponding
grid nodes should be included in g2m. If `decode_mask=None` (default), all grid nodes are included.
"""
graph_components: dict[networkx.DiGraph] = {}

Expand Down Expand Up @@ -149,9 +158,19 @@ def create_all_graph_components(
)
graph_components["g2m"] = G_g2m

if decode_mask is None:
# decode to all grid nodes
decode_grid = G_grid
else:
# Select subset of grid nodes to decode to, where m2g should connect
filter_nodes = [
n for n, include in zip(G_grid.nodes, decode_mask, strict=True) if include
]
decode_grid = G_grid.subgraph(filter_nodes)

G_m2g = connect_nodes_across_graphs(
G_source=grid_connect_graph,
G_target=G_grid,
G_target=decode_grid,
method=m2g_connectivity,
**m2g_connectivity_kwargs,
)
Expand All @@ -162,6 +181,15 @@ def create_all_graph_components(
for edge in graph.edges:
graph.edges[edge]["component"] = name

if return_components:
Comment thread
joeloskarsson marked this conversation as resolved.
# Because merging to a single graph and then splitting again leads to changes in node indexing when converting to `pyg.Data` objects (this in part is due to the to `m2g` and `g2m` having a different set of grid nodes) the ability to return the graph components (`g2m`, `m2m` and `m2g`) has been added here. See https://github.com/mllam/weather-model-graphs/pull/34#issuecomment-2507980752 for details
# Give each component unique ids
graph_components = {
comp_name: replace_node_labels_with_unique_ids(subgraph)
for comp_name, subgraph in graph_components.items()
}
return graph_components

# merge to single graph
G_tot = networkx.compose_all(graph_components.values())
# only keep graph attributes that are the same for all components
Expand Down
49 changes: 19 additions & 30 deletions src/weather_model_graphs/networkx_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,40 +98,29 @@ def split_graph_by_edge_attribute(graph, attr):
f"No subgraphs were created. Check the edge attribute '{attr}'."
)

# copy node attributes
Comment thread
joeloskarsson marked this conversation as resolved.
for subgraph in subgraphs.values():
for node in subgraph.nodes:
subgraph.nodes[node].update(graph.nodes[node])

# check that at least one subgraph was created
if len(subgraphs) == 0:
raise ValueError(
f"No subgraphs were created. Check the edge attribute '{attr}'."
)
return subgraphs

# copy node attributes
for subgraph in subgraphs.values():
for node in subgraph.nodes:
subgraph.nodes[node].update(graph.nodes[node])

# check that at least one subgraph was created
if len(subgraphs) == 0:
raise ValueError(
f"No subgraphs were created. Check the edge attribute '{attr}'."
)
def sort_nodes_in_graph(graph):
"""
Creates a new networkx.DiGraph that is a copy of input, but with nodes
sorted according to their label value

# copy node attributes
for subgraph in subgraphs.values():
for node in subgraph.nodes:
subgraph.nodes[node].update(graph.nodes[node])
Parameters
----------
graph : networkx.DiGraph
Graph to sort nodes from

# check that at least one subgraph was created
if len(subgraphs) == 0:
raise ValueError(
f"No subgraphs were created. Check the edge attribute '{attr}'."
)
Returns
-------
networkx.DiGraph
Graph with sorted nodes
"""
sorted_graph = networkx.DiGraph()
sorted_graph.add_nodes_from(sorted(graph.nodes(data=True)))
sorted_graph.add_edges_from(graph.edges(data=True))

return subgraphs
return sorted_graph


def replace_node_labels_with_unique_ids(graph):
Comment thread
joeloskarsson marked this conversation as resolved.
Expand All @@ -149,7 +138,7 @@ def replace_node_labels_with_unique_ids(graph):
Graph with node labels renamed
"""
return networkx.relabel_nodes(
graph, {node: i for i, node in enumerate(graph.nodes)}, copy=True
graph, {node: i for i, node in enumerate(sorted(graph.nodes))}, copy=True
)


Expand Down
27 changes: 19 additions & 8 deletions src/weather_model_graphs/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
import networkx
from loguru import logger

from .networkx_utils import MissingEdgeAttributeError, split_graph_by_edge_attribute
from .networkx_utils import (
MissingEdgeAttributeError,
sort_nodes_in_graph,
split_graph_by_edge_attribute,
)

try:
import torch
Expand Down Expand Up @@ -106,18 +110,25 @@ def _concat_pyg_features(
if list_from_attribute is not None:
# create a list of graph objects by splitting the graph by the list_from_attribute
try:
sub_graphs = list(
split_graph_by_edge_attribute(
graph=graph, attr=list_from_attribute
).values()
)
sub_graphs = [
value
for key, value in sorted(
split_graph_by_edge_attribute(
graph=graph, attr=list_from_attribute
).items()
)
]
Comment thread
joeloskarsson marked this conversation as resolved.
except MissingEdgeAttributeError:
# neural-lam still expects a list of graphs, so if the attribute is missing
# we just return the original graph as a list
sub_graphs = [graph]
pyg_graphs = [pyg_convert.from_networkx(g) for g in sub_graphs]
# Nodes must be sorted if we want to preserve any ordering
# when converted to pyg
Comment thread
leifdenby marked this conversation as resolved.
Outdated
pyg_graphs = [
pyg_convert.from_networkx(sort_nodes_in_graph(g)) for g in sub_graphs
]
else:
pyg_graphs = [pyg_convert.from_networkx(graph)]
pyg_graphs = [pyg_convert.from_networkx(sort_nodes_in_graph(graph))]

edge_features_values = [
_concat_pyg_features(pyg_g, features=edge_features) for pyg_g in pyg_graphs
Expand Down
3 changes: 2 additions & 1 deletion src/weather_model_graphs/visualise/plot_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def nx_draw_with_pos_and_attr(
node_zorder_attr=None,
node_size=100,
connectionstyle="arc3, rad=0.1",
with_labels=False,
**kwargs,
):
"""
Expand Down Expand Up @@ -171,7 +172,7 @@ def nx_draw_with_pos_and_attr(
graph,
ax=ax,
arrows=True,
with_labels=False,
with_labels=with_labels,
node_size=node_size,
connectionstyle=connectionstyle,
**kwargs,
Expand Down
24 changes: 24 additions & 0 deletions tests/test_graph_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,30 @@ def test_create_lat_lon(kind):
)


@pytest.mark.parametrize("kind", ["graphcast", "keisler", "oskarsson_hierarchical"])
def test_create_decode_mask(kind):
"""
Tests that the decode mask for m2g works, resulting in less edges than
no filtering.
"""
xy = test_utils.create_fake_irregular_coords(100)
fn_name = f"create_{kind}_graph"
fn = getattr(wmg.create.archetype, fn_name)
# ~= 20 mesh nodes in bottom layer in each direction
mesh_node_distance = 0.05

unfiltered_graph = fn(coords=xy, mesh_node_distance=mesh_node_distance)

# Filter to only 20 / 100 grid nodes
decode_mask = np.concatenate((np.ones(20), np.zeros(80))).astype(bool)
filtered_graph = fn(
coords=xy, mesh_node_distance=mesh_node_distance, decode_mask=decode_mask
)

# Check that some filtering has been performed
assert len(filtered_graph.edges) < len(unfiltered_graph.edges)


@pytest.mark.parametrize("kind", ["graphcast", "oskarsson_hierarchical"])
def test_create_many_levels(kind):
"""Test that mesh graph creation methods that work with many levels
Expand Down
Loading
Loading