Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 13 additions & 16 deletions src/weather_model_graphs/networkx_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@ def prepend_node_index(graph, new_index):


def sort_nodes_internally(nx_graph, node_attr=None, edge_attr=None):
# For some reason the networkx .nodes() return list can not be sorted,
# but this is the ordering used by pyg when converting.
# This function fixes this.
H = networkx.DiGraph()
# Dynamically use the same class (Graph, DiGraph, etc.) instead of hardcoding DiGraph
H = type(nx_graph)()

# Preserve graph-level metadata (like CRS)
H.graph.update(nx_graph.graph)

if node_attr is not None:
H.add_nodes_from(
sorted(nx_graph.nodes(data=True), key=lambda x: x[1][node_attr])
Expand Down Expand Up @@ -103,20 +105,15 @@ def split_graph_by_edge_attribute(graph, attr):

def sort_nodes_in_graph(graph):
"""
Creates a new networkx.DiGraph that is a copy of input, but with nodes
sorted according to their label value
Creates a new networkx graph that is a copy of input, but with nodes
sorted according to their label value. Preserves graph type and metadata.
"""
# Dynamically use the same class (Graph, DiGraph, etc.)
sorted_graph = type(graph)()

Parameters
----------
graph : networkx.DiGraph
Graph to sort nodes from
# Preserve graph-level metadata (like CRS)
sorted_graph.graph.update(graph.graph)

Returns
-------
networkx.DiGraph
Graph with sorted nodes
"""
sorted_graph = networkx.DiGraph()
sorted_graph.add_nodes_from(sorted(graph.nodes(data=True)))
sorted_graph.add_edges_from(graph.edges(data=True))

Expand Down
Loading