Skip to content
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dace/codegen/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ def generate_code(sdfg: SDFG, validate=True) -> List[CodeObject]:
}

# NOTE: THE SDFG IS ASSUMED TO BE FROZEN (not change) FROM THIS POINT ONWARDS
sdfg.sort_sdfg_alphabetically()

# Generate frame code (and the rest of the code)
(global_code, frame_code, used_targets, used_environments) = frame.generate_code(sdfg, None)
Expand Down
39 changes: 39 additions & 0 deletions dace/sdfg/sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3006,3 +3006,42 @@ def recheck_using_explicit_control_flow(self) -> bool:
break
self.root_sdfg.using_explicit_control_flow = found_explicit_cf_block
return found_explicit_cf_block

def sort_sdfg_alphabetically(self, visited: Optional[Set[int]] = None) -> None:
    """
    Force all internal dictionaries and graph structures into a
    deterministic, lexicographical order so that code generation is stable.

    Operates in-place, visiting every dataflow state and recursing into
    nested SDFGs exactly once.

    :param visited: Set of ``id()`` values of SDFGs already processed.
                    Used internally to guard against infinite recursion in
                    the event of cyclic nested-SDFG references.
    """
    visited = set() if visited is None else visited
    if id(self) in visited:
        return
    visited.add(id(self))

    # Imported locally to avoid a circular import with dace.sdfg.utils.
    from dace.sdfg.utils import sort_graph_dicts_alphabetically

    # Reorder the SDFG-level mappings (data descriptors, symbols, constants)
    # by popping and reinserting entries in sorted key order; this relies on
    # dicts preserving insertion order.
    for attr_name in ('_arrays', 'symbols', 'constants_prop'):
        mapping = getattr(self, attr_name, None)
        if mapping and hasattr(mapping, 'keys') and hasattr(mapping, 'pop'):
            for key in sorted(mapping.keys()):
                mapping[key] = mapping.pop(key)

    # Canonicalize the SDFG's own graph structure, then each state's.
    sort_graph_dicts_alphabetically(self)

    for state in self.nodes():
        sort_graph_dicts_alphabetically(state)

        # Recurse into any nested SDFGs contained in this state.
        for node in state.nodes():
            nested = getattr(node, 'sdfg', None)
            if nested is not None:
                nested.sort_sdfg_alphabetically(visited)
90 changes: 90 additions & 0 deletions dace/sdfg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import warnings
import networkx as nx
import time
import re

import dace.sdfg.nodes
from dace.codegen import compiled_sdfg as csdfg
Expand Down Expand Up @@ -2754,3 +2755,92 @@ def expand_nodes(sdfg: SDFG, predicate: Callable[[nd.Node], bool]):

if expanded_something:
states.append(state)


def get_deterministic_node_key(node: Any) -> str:
    """
    Generates a stable string key for Graph nodes to ensure deterministic sorting.
    Strips memory addresses, sequential IDs, partial hashes, and UUIDs.

    :param node: The DaCe graph node object to be evaluated.
    :return: A stable string representation of the node.
    """
    node_type = type(node).__name__

    # Pick the first identifying attribute that is actually set. Falling
    # through on None (rather than a plain nested getattr) prevents a
    # TypeError in re.sub below when e.g. a node has `label = None`.
    raw_identifier = getattr(node, 'label', None)
    if raw_identifier is None:
        raw_identifier = getattr(node, 'name', None)
    if raw_identifier is None:
        raw_identifier = getattr(node, 'data', None)
    if raw_identifier is None:
        raw_identifier = str(node)

    # Coerce non-string identifiers (e.g. symbolic `data` attributes) so the
    # regex passes below never fail.
    identifier = str(raw_identifier)

    # 1. Strip memory addresses (e.g., <... object at 0x...>)
    identifier = re.sub(r' at 0x[0-9a-fA-F]+', '', identifier)

    # 2. Strip full UUIDs (supports hyphens, underscores, or flat 32-char hex)
    # Catches: 550e8400-e29b-41d4-a716-446655440000, 550e8400_e29b... or 550e8400e29b...
    identifier = re.sub(
        r'_?[0-9a-fA-F]{8}[-_]?[0-9a-fA-F]{4}[-_]?[0-9a-fA-F]{4}[-_]?[0-9a-fA-F]{4}[-_]?[0-9a-fA-F]{12}', '',
        identifier)

    return f"{node_type}_{identifier}"


def get_deterministic_edge_key(edge: Any) -> str:
    """
    Build a stable string key for a graph edge so edges can be ordered
    deterministically.

    :param edge: The DaCe graph edge object (or InterstateEdge) to be evaluated.
    :return: A stable string representation of the edge.
    """
    # Endpoint keys reuse the node canonicalizer so edge ordering is stable
    # whenever node ordering is.
    source_key = get_deterministic_node_key(edge.src)
    target_key = get_deterministic_node_key(edge.dst)

    # Connector names and payload default to empty strings when absent.
    source_conn = getattr(edge, 'src_conn', '')
    target_conn = getattr(edge, 'dst_conn', '')
    payload = str(getattr(edge, 'data', ''))

    return f"{source_key}:{source_conn}->{target_key}:{target_conn}_{payload}"


def sort_graph_dicts_alphabetically(graph: Any) -> None:
    """
    Sorts internal graph nodes, edge dictionaries, and NetworkX backends in-place.

    This function performs three critical phases:
    1. Alphabetizes the master `_nodes` dictionary and its nested adjacency lists.
    2. Alphabetizes the master `_edges` dictionary.
    3. Tears down and sequentially rebuilds the underlying NetworkX graph (`_nx`)
    using the newly sorted nodes and edges.

    :param graph: The DaCe graph structure (SDFG, SDFGState, or generic Graph)
                  whose internal structures need to be stabilized.
    """

    if hasattr(graph, '_nodes'):
        # Pop-and-reinsert in sorted order; relies on Python dicts preserving
        # insertion order (guaranteed since 3.7). sorted() snapshots the keys,
        # so mutating the dict inside the loop is safe.
        for k in sorted(list(graph._nodes.keys()), key=get_deterministic_node_key):
            graph._nodes[k] = graph._nodes.pop(k)

        # Sort the nested edge dictionaries inside _nodes in-place
        # NOTE(review): assumes each _nodes value is an (in_edges, out_edges)
        # pair of dicts keyed by edge identifiers — confirm against the Graph
        # class this is applied to.
        for node, (in_edges, out_edges) in graph._nodes.items():
            for e_key in sorted(list(in_edges.keys()), key=lambda k: get_deterministic_edge_key(in_edges[k])):
                in_edges[e_key] = in_edges.pop(e_key)
            for e_key in sorted(list(out_edges.keys()), key=lambda k: get_deterministic_edge_key(out_edges[k])):
                out_edges[e_key] = out_edges.pop(e_key)

    if hasattr(graph, '_edges'):
        # Edge dict is keyed by an opaque identifier; sort by the edge
        # object's canonical key, not the dict key itself.
        for e_key in sorted(list(graph._edges.keys()), key=lambda k: get_deterministic_edge_key(graph._edges[k])):
            graph._edges[e_key] = graph._edges.pop(e_key)

    if hasattr(graph, '_nx'):
        # Rebuild the NetworkX backend from scratch so its internal insertion
        # order matches the freshly sorted _nodes/_edges dicts.
        # NOTE(review): the rebuilt graph drops graph-level attributes of the
        # old _nx object and any edge attributes other than
        # data/src_conn/dst_conn/key — confirm nothing downstream relies on
        # them.
        old_nx = graph._nx
        graph._nx = type(old_nx)()

        for n in graph._nodes.keys():
            # NodeView is a Mapping, so .get returns the node's attribute
            # dict ({} if the node was absent from the old backend).
            graph._nx.add_node(n, **old_nx.nodes.get(n, {}))

        for e_obj in graph._edges.values():
            edge_attrs = {'data': e_obj.data}

            # presumably an edge with src_conn always carries dst_conn too
            # (DaCe multigraph edges) — TODO confirm
            if hasattr(e_obj, 'src_conn'):
                edge_attrs['src_conn'] = e_obj.src_conn
                edge_attrs['dst_conn'] = e_obj.dst_conn

            if hasattr(e_obj, 'key'):
                # Preserve multigraph edge keys (MultiDiGraph backends).
                graph._nx.add_edge(e_obj.src, e_obj.dst, key=e_obj.key, **edge_attrs)
            else:
                graph._nx.add_edge(e_obj.src, e_obj.dst, **edge_attrs)
53 changes: 53 additions & 0 deletions tests/sdfg/deterministic_sort_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import dace
import random
from dace.sdfg.utils import get_deterministic_node_key, get_deterministic_edge_key


def test_sdfg_alphabetical_sorting():
    """
    Tests that the SDFG and its internal states can be forced into a strictly
    deterministic topological order, regardless of dictionary insertion history.
    """
    # 1. Create a simple SDFG
    sdfg = dace.SDFG('deterministic_test')
    sdfg.add_array('A', [10], dace.float64)
    sdfg.add_array('B', [10], dace.float64)

    state = sdfg.add_state('state0')
    a = state.add_read('A')
    b = state.add_write('B')
    tasklet = state.add_tasklet('compute', {'a'}, {'b'}, 'b = a + 1')

    state.add_edge(a, None, tasklet, 'a', dace.Memlet.from_array('A', sdfg.arrays['A']))
    state.add_edge(tasklet, 'b', b, None, dace.Memlet.from_array('B', sdfg.arrays['B']))

    # 2. Intentionally scramble the internal dictionaries to simulate non-determinism
    # Scramble top-level arrays. The shuffled items must be re-inserted into
    # sdfg._arrays — shuffling the extracted list alone leaves the container
    # untouched. Pop-and-re-add preserves the container's type.
    array_items = list(sdfg._arrays.items())
    random.shuffle(array_items)
    for key, _ in array_items:
        del sdfg._arrays[key]
    for key, desc in array_items:
        sdfg._arrays[key] = desc

    # Scramble state nodes
    node_items = list(state._nodes.items())
    random.shuffle(node_items)
    state._nodes = dict(node_items)

    # 3. Apply the canonicalizer
    sdfg.sort_sdfg_alphabetically()

    # 4. Assert that the underlying dictionaries are now strictly ordered
    assert list(sdfg._arrays.keys()) == sorted(sdfg._arrays.keys()), \
        "Array descriptors were not deterministically sorted!"

    node_keys = list(state._nodes.keys())
    expected_node_keys = sorted(node_keys, key=get_deterministic_node_key)

    edge_keys = list(state._edges.keys())
    expected_edge_keys = sorted(edge_keys, key=lambda k: get_deterministic_edge_key(state._edges[k]))

    assert node_keys == expected_node_keys, "Graph nodes were not deterministically sorted!"
    assert edge_keys == expected_edge_keys, "Graph edges were not deterministically sorted!"

    # Ensure networkx backend was also rebuilt deterministically
    nx_nodes = list(state._nx.nodes())
    assert nx_nodes == expected_node_keys, "NetworkX nodes do not match DaCe dict order!"


if __name__ == "__main__":
test_sdfg_alphabetical_sorting()
Loading