Skip to content

Commit 8ac954f

Browse files
committed
Refactor: update_graph_outputs in a helper (#62)
Signed-off-by: Johansmm <[email protected]>
1 parent d19f998 commit 8ac954f

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

src/onnx_ir/_convenience/__init__.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,18 @@ def create_value_mapping(graph: _core.Graph) -> dict[str, _core.Value]:
336336
return values
337337

338338

339+
def _update_graph_or_function_outputs(
340+
graph_or_function: _core.Graph | _core.Function,
341+
old_values: Sequence[_core.Value],
342+
new_values: Sequence[_core.Value],
343+
):
344+
"""Update graph/function outputs."""
345+
replacement_mapping = dict(zip(old_values, new_values))
346+
for idx, graph_or_function_output in enumerate(graph_or_function.outputs):
347+
if graph_or_function_output in replacement_mapping:
348+
graph_or_function.outputs[idx] = replacement_mapping[graph_or_function_output]
349+
350+
339351
def replace_nodes_and_values(
340352
graph_or_function: _core.Graph | _core.Function,
341353
/,
@@ -367,10 +379,7 @@ def replace_nodes_and_values(
367379
# Reconnect the users of the deleted values to use the new values
368380
replace_all_uses_with(old_values, new_values)
369381
# Update graph/function outputs if the node generates output
370-
replacement_mapping = dict(zip(old_values, new_values))
371-
for idx, graph_or_function_output in enumerate(graph_or_function.outputs):
372-
if graph_or_function_output in replacement_mapping:
373-
graph_or_function.outputs[idx] = replacement_mapping[graph_or_function_output]
382+
_update_graph_or_function_outputs(graph_or_function, old_values, new_values)
374383

375384
# insert new nodes after the index node
376385
graph_or_function.insert_after(insertion_point, new_nodes)

0 commit comments

Comments
 (0)