diff --git a/src/weather_model_graphs/create/base.py b/src/weather_model_graphs/create/base.py index f922b2d..f2adfe4 100644 --- a/src/weather_model_graphs/create/base.py +++ b/src/weather_model_graphs/create/base.py @@ -260,7 +260,12 @@ def connect_nodes_across_graphs( target_nodes_list = list(G_target.nodes) # build kd tree for source nodes (e.g. the mesh nodes when constructing m2g) - xy_source = np.array([G_source.nodes[node]["pos"] for node in G_source.nodes]) + source_nodes_list = list(G_source.nodes) + + xy_source = np.array( + [G_source.nodes[node]["pos"] for node in source_nodes_list] + ) + kdt_s = scipy.spatial.KDTree(xy_source) # Determine method and perform checks once @@ -402,7 +407,7 @@ def _find_neighbour_node_idxs_in_source_mesh(xy_target): G_connect.add_nodes_from(sorted(G_target.nodes(data=True))) # sort nodes by index - source_nodes_list = sorted(G_source.nodes) + # add edges for target_node in target_nodes_list: diff --git a/tests/test_connect_nodes_across_graphs.py b/tests/test_connect_nodes_across_graphs.py new file mode 100644 index 0000000..7dad6e5 --- /dev/null +++ b/tests/test_connect_nodes_across_graphs.py @@ -0,0 +1,35 @@ +import numpy as np +import networkx as nx + +from weather_model_graphs.create.base import connect_nodes_across_graphs + + +def test_kdtree_node_mapping_order(): + G_source = nx.Graph() + + # intentionally unsorted insertion order + G_source.add_node("z", pos=np.array([0.0, 0.0])) + G_source.add_node("a", pos=np.array([10.0, 0.0])) + G_source.add_node("m", pos=np.array([20.0, 0.0])) + + G_target = nx.Graph() + + G_target.add_node("t0", pos=np.array([0.2, 0.0])) + G_target.add_node("t1", pos=np.array([9.9, 0.0])) + G_target.add_node("t2", pos=np.array([19.7, 0.0])) + + G = connect_nodes_across_graphs( + G_source, + G_target, + method="nearest_neighbour" + ) + + edges = sorted(G.edges()) + + expected = [ + ("z", "t0"), + ("a", "t1"), + ("m", "t2") + ] + + assert edges == expected \ No newline at end of file