Skip to content

Commit ca4d788

Browse files
authored
Bugfix for add_alias v2 (#143)
There could be cases where some of the users of a node were the output node, which wouldn't be taken into account. This fixes it. This was caught by autoparallel/graph_utils.py when trying to set the v2 flag on by default
1 parent 273f54c commit ca4d788

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

autoparallel/graph_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def _add_alias(gm, version="v1"):
8080
"""
8181
graph = gm.graph
8282

83-
nodes = [n for n in graph.nodes if n.op == "call_function"]
83+
nodes = list(graph.nodes)
8484
node_map = {node: idx for idx, node in enumerate(nodes)}
8585

8686
def _insert_alias(node):
@@ -94,10 +94,9 @@ def delete_user_cb(n):
9494

9595
node.replace_all_uses_with(alias_node, delete_user_cb=delete_user_cb)
9696

97-
inputs = graph.find_nodes(op="placeholder")
9897
if version == "v1":
9998
# only on inputs
100-
for node in inputs:
99+
for node in graph.find_nodes(op="placeholder"):
101100
if len(node.users) == 0:
102101
# node is not used, don't add alias for it
103102
continue
@@ -110,7 +109,7 @@ def delete_user_cb(n):
110109
_insert_alias(node)
111110
elif version == "v2":
112111
# for every node that has more than one user
113-
for node in inputs + nodes:
112+
for node in nodes:
114113
if len(node.users) < 2:
115114
continue
116115
# don't add alias for ops which return tuple for now
@@ -121,6 +120,7 @@ def delete_user_cb(n):
121120
raise ValueError(f"Unknown version {version}")
122121

123122
"""
123+
nodes = [n for n in graph.nodes if n.op == "call_function"]
124124
for node in nodes:
125125
# skip ops which return tuple
126126
if not isinstance(node.meta["val"], torch.Tensor):

0 commit comments

Comments
 (0)