Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/weather_model_graphs/create/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,12 @@ def connect_nodes_across_graphs(
target_nodes_list = list(G_target.nodes)

# build kd tree for source nodes (e.g. the mesh nodes when constructing m2g)
xy_source = np.array([G_source.nodes[node]["pos"] for node in G_source.nodes])
source_nodes_list = list(G_source.nodes)

xy_source = np.array(
[G_source.nodes[node]["pos"] for node in source_nodes_list]
)

kdt_s = scipy.spatial.KDTree(xy_source)

# Determine method and perform checks once
Expand Down Expand Up @@ -402,7 +407,7 @@ def _find_neighbour_node_idxs_in_source_mesh(xy_target):
G_connect.add_nodes_from(sorted(G_target.nodes(data=True)))

# sort nodes by index
source_nodes_list = sorted(G_source.nodes)


# add edges
for target_node in target_nodes_list:
Expand Down
35 changes: 35 additions & 0 deletions tests/test_connect_nodes_across_graphs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import numpy as np
import networkx as nx

from weather_model_graphs.create.base import connect_nodes_across_graphs


def test_kdtree_node_mapping_order():
G_source = nx.Graph()

# intentionally unsorted insertion order
G_source.add_node("z", pos=np.array([0.0, 0.0]))
G_source.add_node("a", pos=np.array([10.0, 0.0]))
G_source.add_node("m", pos=np.array([20.0, 0.0]))

G_target = nx.Graph()

G_target.add_node("t0", pos=np.array([0.2, 0.0]))
G_target.add_node("t1", pos=np.array([9.9, 0.0]))
G_target.add_node("t2", pos=np.array([19.7, 0.0]))

G = connect_nodes_across_graphs(
G_source,
G_target,
method="nearest_neighbour"
)

edges = sorted(G.edges())

expected = [
("z", "t0"),
("a", "t1"),
("m", "t2")
]

assert edges == expected