Summary
When to_pyg() exports a NetworkX DiGraph to PyTorch tensors, all original node labels are discarded. This proposal adds a companion mapping file alongside the .pt outputs so that tensor indices can be traced back to the original wmg node identities — enabling round-trip validation, latent-space visualization, and debugging.
Problem
PyG's from_networkx() re-indexes all nodes to sequential integers [0, N), discarding the original node labels that weather-model-graphs uses as globally unique identifiers during graph construction.
This causes three concrete problems:
1. Round-trip is lossy
@leifdenby noted in neural-lam#339:
"I think the neural-lam disk format is actually in a way lossy (information that is in the networkx.DiGraph object is lost), as in the node-label is not saved I believe"
@joeloskarsson confirmed:
"the conversion between networkx and pyg throws away pretty much any structure put in place by networkx, which makes all the bookkeeping of indices in wmg not very useful when we then use the graphs in pyg and neural-lam"
This means you cannot reconstruct the original graph from the saved .pt files.
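To illustrate why the labels cannot be recovered, here is a minimal sketch that mimics the re-indexing step in plain Python. It does not call PyG; `reindex_edges` is a hypothetical stand-in, and the node labels are made up. Two graphs with the same topology but different labels collapse to the same integer edge list.

```python
def reindex_edges(nodes, edges):
    """Relabel nodes to 0..N-1 in iteration order.

    Hypothetical stand-in for the re-indexing done on export, not PyG code.
    """
    index_of = {label: i for i, label in enumerate(nodes)}
    return [(index_of[src], index_of[dst]) for src, dst in edges]


# Two graphs with the same topology but different wmg-style labels.
nodes_a = [(0, 0), (0, 1), (1, 0)]
edges_a = [((0, 0), (0, 1)), ((0, 1), (1, 0))]

nodes_b = [(5, 5), (5, 6), (6, 5)]
edges_b = [((5, 5), (5, 6)), ((5, 6), (6, 5))]

reindexed_a = reindex_edges(nodes_a, edges_a)
reindexed_b = reindex_edges(nodes_b, edges_b)

# Both collapse to the same integer edge list, so the labels are gone.
labels_lost = reindexed_a == reindexed_b
```

Nothing in the integer edge list distinguishes graph A from graph B, which is the information a saved mapping would restore.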
2. Latent-space visualization is blocked
@leifdenby stated in #20:
"I do see the merit in being able to take for example latent representations and visualise them on the graph in a downstream application in neural-lam. To facilitate that I would suggest that we need the functionality to take information derived in neural-lam and inject it back into the graph by setting new attributes on an existing graph from wmg."
@Joltsy10 elaborated in #20:
"By the time a latent representation exists inside neural-lam it has been through at least two permutations with no record of either."
Without knowing which tensor index corresponds to which wmg node, you cannot map neural-lam's latent tensors back onto the graph for visualization.
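As a rough illustration of the "two permutations" point, the sketch below assumes the labels are reordered twice before reaching the model. Which two reorderings are meant is not spelled out here, so `order_1`, `order_2`, and the labels are purely hypothetical. If both orders were recorded, composing them gives a single map from tensor row to original label.

```python
# Assumed setup: original labels are reordered twice before reaching the model.
labels = ["a", "b", "c", "d"]

# order_1[i] is the position in `labels` that ends up at index i after step 1.
order_1 = [2, 0, 3, 1]
# order_2[j] is the step-1 index that ends up at row j after step 2.
order_2 = [3, 0, 2, 1]

after_step_1 = [labels[k] for k in order_1]
after_step_2 = [after_step_1[k] for k in order_2]

# Composing the two recorded orders gives one map: tensor row -> original label.
row_to_label = [labels[order_1[k]] for k in order_2]
```

Without a record of either order, `row_to_label` cannot be built, and the rows of a latent tensor cannot be attributed to nodes.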
3. Debugging is opaque
When a tensor index in neural-lam produces an anomalous value, there is currently no way to determine which physical grid/mesh node it corresponds to in the wmg construction graph. This makes diagnosing issues like #82 (wrong KDTree mapping) much harder.
Proposal
A. Save a mapping file during export
When to_pyg() saves .pt files, also save a companion mapping:
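# In save.py, next to the from_networkx() call.
# from_networkx() assigns index i to the i-th node in graph.nodes() iteration
# order, so the map must be built from that same order on the same graph
# object. sorted() only matches if the nodes were inserted in sorted order.
node_id_map = list(graph.nodes())  # index i -> original wmg node label
torch.save(
    node_id_map,
    Path(output_directory) / f"{name}_node_id_map.pt",
)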
This produces files like g2m_node_id_map.pt, m2m_node_id_map.pt, m2g_node_id_map.pt alongside the existing outputs.
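The map is only useful if entry i names the node that actually received index i. The sketch below, in plain Python with made-up labels and no PyG call, assumes the re-indexing follows the graph's node iteration order and shows that a map built from sorted labels can then disagree with it.

```python
# Node labels in the order they were added to the graph (insertion order).
insertion_order = [(1, 0), (0, 0), (0, 1)]

# A re-indexing that follows iteration order assigns these indices.
index_by_insertion = {label: i for i, label in enumerate(insertion_order)}

# A map built from sorted labels assumes a different order.
sorted_map = sorted(insertion_order)

# Indices at which the sorted map would name the wrong node.
mismatches = [
    i for i, label in enumerate(sorted_map) if index_by_insertion[label] != i
]
```

If the graph is known to hold its nodes in sorted order, the two agree; otherwise it seems safer to derive the map from the same iteration that produced the indices.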
B. Provide a utility to inject tensor data back into the graph
def map_tensor_to_graph(graph, tensor, node_id_map, attr_name):
    """
    Set tensor values as node attributes on the original wmg NetworkX graph.

    Parameters
    ----------
    graph : networkx.DiGraph
        The original wmg graph (with original node labels).
    tensor : torch.Tensor
        Tensor of shape [N, ...] where N matches len(node_id_map).
    node_id_map : list
        Ordered list of wmg node labels, as saved by to_pyg().
    attr_name : str
        Name of the attribute to set on each node.
    """
    for idx, node_id in enumerate(node_id_map):
        graph.nodes[node_id][attr_name] = tensor[idx].detach().cpu().numpy()
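The loop above trusts that the tensor, the map, and the graph agree. A possible guard, sketched here on plain containers, could run before injection; `check_mapping` is a hypothetical helper and not part of wmg or neural-lam.

```python
def check_mapping(graph_nodes, n_rows, node_id_map):
    """Raise ValueError if the map cannot index n_rows rows onto graph_nodes.

    Hypothetical helper: graph_nodes is any container of node labels that
    supports `in`, n_rows is the first dimension of the tensor.
    """
    if len(node_id_map) != n_rows:
        raise ValueError(
            f"tensor has {n_rows} rows but map has {len(node_id_map)} entries"
        )
    if len(set(node_id_map)) != len(node_id_map):
        raise ValueError("node_id_map contains duplicate labels")
    missing = [nid for nid in node_id_map if nid not in graph_nodes]
    if missing:
        raise ValueError(f"{len(missing)} mapped labels are not in the graph")


# Plain containers standing in for graph.nodes and tensor.shape[0].
graph_nodes = {(0, 0), (0, 1), (1, 0)}
check_mapping(graph_nodes, 3, [(0, 0), (0, 1), (1, 0)])  # passes silently

try:
    check_mapping(graph_nodes, 2, [(0, 0), (0, 1), (1, 0)])
    length_error = None
except ValueError as err:
    length_error = str(err)
```

With the real objects this would presumably be called as `check_mapping(graph.nodes, tensor.shape[0], node_id_map)`, so that a map saved from a different graph fails loudly instead of writing attributes onto the wrong nodes.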
C. Provide the reverse lookup
import numpy as np


def map_graph_attr_to_tensor(graph, node_id_map, attr_name):
    """
    Extract a node attribute from the wmg graph in the tensor-index order.

    Returns a numpy array of shape [N, ...] aligned with neural-lam tensors.
    """
    return np.stack([graph.nodes[nid][attr_name] for nid in node_id_map])
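Functions B and C are meant to be inverses of each other with respect to the same `node_id_map`. The sketch below checks that property on plain Python containers: a dict stands in for `graph.nodes` and nested lists stand in for the tensor, so it illustrates the indexing logic only.

```python
# node_attrs stands in for graph.nodes: label -> attribute dict.
node_attrs = {(0, 0): {}, (0, 1): {}, (1, 0): {}}
node_id_map = [(1, 0), (0, 0), (0, 1)]  # row i -> label
rows = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]  # stand-in for an [N, 2] tensor

# Inject: same loop shape as map_tensor_to_graph, minus the tensor conversion.
for idx, node_id in enumerate(node_id_map):
    node_attrs[node_id]["latent"] = rows[idx]

# Extract: same comprehension as map_graph_attr_to_tensor, minus np.stack.
recovered = [node_attrs[nid]["latent"] for nid in node_id_map]
```

Row 0 lands on node `(1, 0)` because that is the first entry of the map, and extracting with the same map returns the rows in their original order.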
Use cases enabled
Latent-space visualization (#20)
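latent = model.get_mesh_latent(batch)  # shape [N_mesh, D]
node_id_map = torch.load("graph/m2m_node_id_map.pt")
map_tensor_to_graph(wmg_graph, latent, node_id_map, "latent")
# "latent" is a node attribute, so it is passed as the node colour here; a
# D-dimensional vector would first need reducing to one scalar per node.
# The plotting call and its keyword are illustrative.
wmg.visualise.plot_graph(wmg_graph, node_color_attr="latent")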
Round-trip validation (#88, neural-lam#339)
node_id_map = torch.load("graph/g2m_node_id_map.pt")
edge_index = torch.load("graph/g2m_edge_index.pt")
for i in range(edge_index.shape[1]):
    src_wmg = node_id_map[edge_index[0, i].item()]
    dst_wmg = node_id_map[edge_index[1, i].item()]
    assert wmg_graph.has_edge(src_wmg, dst_wmg)
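The loop above confirms that every saved edge exists in the graph, but not that every graph edge was saved. A two-way comparison could look like the sketch below, written on plain containers; `diff_edge_sets` and the labels are hypothetical.

```python
def diff_edge_sets(graph_edges, edge_index, node_id_map):
    """Return (missing_in_graph, missing_in_tensor) as sets of label pairs.

    Hypothetical helper: graph_edges is an iterable of (src, dst) labels,
    edge_index is a pair of equal-length integer sequences.
    """
    src_row, dst_row = edge_index
    tensor_edges = {
        (node_id_map[s], node_id_map[d]) for s, d in zip(src_row, dst_row)
    }
    graph_edges = set(graph_edges)
    return tensor_edges - graph_edges, graph_edges - tensor_edges


node_id_map = ["g0", "g1", "m0"]
edge_index = [[0, 1], [2, 2]]  # edges g0->m0 and g1->m0
graph_edges = [("g0", "m0"), ("g1", "m0"), ("m0", "g0")]

missing_in_graph, missing_in_tensor = diff_edge_sets(
    graph_edges, edge_index, node_id_map
)
```

With the real objects, `edge_index.tolist()` and `wmg_graph.edges()` would be the natural inputs. A round trip is lossless for edges only when both returned sets are empty.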
Debugging neural-lam anomalies
node_id_map = torch.load("graph/m2m_node_id_map.pt")
print(f"Tensor index 42 = wmg node {node_id_map[42]}")
# → e.g., "Tensor index 42 = wmg node (3, 7)"
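Debugging often runs the other way too: starting from a known wmg node and asking which tensor row holds it. A small sketch, with illustrative labels, that inverts the map once for constant-time lookups:

```python
node_id_map = [(0, 0), (0, 1), (3, 7)]  # row i -> label, illustrative values

# Invert once; each lookup is then a dict access instead of a list scan.
index_of_node = {label: i for i, label in enumerate(node_id_map)}

row = index_of_node[(3, 7)]
```

This assumes the labels are hashable and unique, which holds for NetworkX node labels within a single graph.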
Scope
- ~20 lines added to save.py to save the mapping
- ~30 lines for the two utility functions
- No breaking changes: mapping files are additional outputs; existing code works unchanged
- No new dependencies: uses only torch.save() and standard Python
Relationship to existing work
| Issue/PR | Relationship |
| --- | --- |
| neural-lam#339 | @leifdenby's RFC identifies the wmg→neural-lam format gap. This mapping makes the bridge non-lossy. |
| #20 (3D plots) | @leifdenby's requirement to inject latent representations back into the graph depends on this mapping. |
| #88 (strengthen tests) | Round-trip validation tests become possible once tensor indices can be traced back to node IDs. |
| PR #47 (DataTree format) | If the project moves to xr.DataTree, node labels would naturally be preserved as coordinates. This mapping serves the same purpose for the current .pt format. |