@@ -336,6 +336,18 @@ def create_value_mapping(graph: _core.Graph) -> dict[str, _core.Value]:
336
336
return values
337
337
338
338
339
+ def _update_graph_or_function_outputs (
340
+ graph_or_function : _core .Graph | _core .Function ,
341
+ old_values : Sequence [_core .Value ],
342
+ new_values : Sequence [_core .Value ],
343
+ ):
344
+ """Update graph/function outputs."""
345
+ replacement_mapping = dict (zip (old_values , new_values ))
346
+ for idx , graph_or_function_output in enumerate (graph_or_function .outputs ):
347
+ if graph_or_function_output in replacement_mapping :
348
+ graph_or_function .outputs [idx ] = replacement_mapping [graph_or_function_output ]
349
+
350
+
339
351
def replace_nodes_and_values (
340
352
graph_or_function : _core .Graph | _core .Function ,
341
353
/ ,
@@ -367,10 +379,7 @@ def replace_nodes_and_values(
367
379
# Reconnect the users of the deleted values to use the new values
368
380
replace_all_uses_with (old_values , new_values )
369
381
# Update graph/function outputs if the node generates output
370
- replacement_mapping = dict (zip (old_values , new_values ))
371
- for idx , graph_or_function_output in enumerate (graph_or_function .outputs ):
372
- if graph_or_function_output in replacement_mapping :
373
- graph_or_function .outputs [idx ] = replacement_mapping [graph_or_function_output ]
382
+ _update_graph_or_function_outputs (graph_or_function , old_values , new_values )
374
383
375
384
# insert new nodes after the index node
376
385
graph_or_function .insert_after (insertion_point , new_nodes )
0 commit comments