Skip to content

Commit 518670e

Browse files
committed
Introduce insert_nodes_before_value (#62)
Convenience function to insert a set of nodes in value(s). Signed-off-by: Johansmm <[email protected]>
1 parent 8ac954f commit 518670e

File tree

3 files changed

+254
-0
lines changed

3 files changed

+254
-0
lines changed

src/onnx_ir/_convenience/__init__.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"replace_all_uses_with",
1515
"create_value_mapping",
1616
"replace_nodes_and_values",
17+
"insert_nodes_in_value",
1718
]
1819

1920
from collections.abc import Mapping, Sequence
@@ -384,3 +385,108 @@ def replace_nodes_and_values(
384385
# insert new nodes after the index node
385386
graph_or_function.insert_after(insertion_point, new_nodes)
386387
graph_or_function.remove(old_nodes, safe=True)
388+
389+
390+
def _find_inputs_outputs(
391+
nodes: Sequence[_core.Node],
392+
) -> tuple[Sequence[_core.Value], Sequence[_core.Value]]:
393+
"""Find the values that are considered as inputs and outputs in a sequence of nodes."""
394+
# Search the unique inputs/outputs in new_nodes, keeping the order.
395+
all_inputs = dict.fromkeys(sum((node.inputs for node in nodes), ()))
396+
all_outputs = dict.fromkeys(sum((node.outputs for node in nodes), ()))
397+
# A value is considered as input if it is not any output.
398+
inputs = tuple(val for val in all_inputs if val not in all_outputs)
399+
# A value is considered as output if it is not any input.
400+
outputs = tuple(val for val in all_outputs if val not in all_inputs)
401+
return inputs, outputs
402+
403+
404+
def insert_nodes_in_value(
405+
values: _core.Value | Sequence[_core.Value], new_nodes: Sequence[_core.Node]
406+
) -> None:
407+
"""Inserts a sequence of nodes into the provided value(s).
408+
409+
This allows to insert a list of LINKED nodes (over the same context) at
410+
a specific point in the graph.
411+
412+
For example, suppose we have the following graph::
413+
414+
input -> A := node_A(input) -> B := node_B(A) -> C := node_C(B) -> output
415+
416+
We want to insert [node_M, node_N] at B value::
417+
418+
>>> import onnx_ir as ir
419+
>>> input = ir.Input("input")
420+
>>> node_A = ir.node("op_A", [input])
421+
>>> B = ir.Value(name="B")
422+
>>> node_B = ir.node("op_B", node_A.outputs, outputs=[B])
423+
>>> node_C = ir.node("op_C", node_B.outputs)
424+
>>> # Create a new sequence to insert
425+
>>> input_2 = ir.Input("input_2")
426+
>>> node_M = ir.node("op_M", [input_2])
427+
>>> node_N = ir.node("op_N", node_M.outputs)
428+
>>> # Insert nodes in B
429+
>>> insert_nodes_before_value(node_B.outputs, [node_M, node_N])
430+
>>> len(node_B.outputs)
431+
1
432+
>>> node_B.outputs[0].consumers()[0].op_type
433+
'op_M'
434+
>>> len(node_C.inputs)
435+
1
436+
>>> node_C.inputs[0].producer().op_type
437+
'op_N'
438+
>>> node_C.inputs[0].name
439+
'B'
440+
441+
When values is a sequence, the set of nodes must have the same number
442+
of inputs and outputs, then they are zipped into pairs: first value is
443+
replaced with the first input/output, and so on.
444+
445+
Args:
446+
values: The value(s) where to insert the nodes.
447+
new_nodes: The nodes to insert in the graph.
448+
"""
449+
if not isinstance(values, Sequence):
450+
values = (values,)
451+
452+
# Search the unique inputs/outputs in new_nodes, keeping the order.
453+
inputs, outputs = _find_inputs_outputs(new_nodes)
454+
455+
# Sanity check.
456+
if len(values) != len(inputs):
457+
raise ValueError(
458+
f"The number of values and inputs ({inputs}) in new_nodes must match."
459+
)
460+
if len(values) != len(outputs):
461+
raise ValueError(
462+
f"The number of values and outputs ({outputs}) in new_nodes must match."
463+
)
464+
465+
# Propagate relevant info.
466+
for val, in_val, out_val in zip(values, inputs, outputs):
467+
# Propagate relevant info from value to out_value.
468+
# TODO(Rama): Perhaps this should be a separate utility function.
469+
out_val.type = val.type
470+
out_val.shape = val.shape
471+
out_val.name = val.name
472+
# Propagate relevant info from value to in_value.
473+
# TODO(Rama): Perhaps this should be a separate utility function.
474+
in_val.type = val.type
475+
in_val.shape = val.shape
476+
# Rename each value, following each input.
477+
val.name = in_val.name
478+
479+
# Insert the new nodes in two steps:
480+
# 1. Reconnect the users of values to the outputs
481+
replace_all_uses_with(values, outputs)
482+
# 2. Reconnect the users of inputs to values
483+
replace_all_uses_with(inputs, values)
484+
485+
# Update graph if there is one:
486+
if (graph := values[-1].graph) is not None:
487+
# Update graph/function outputs if the node generates output
488+
_update_graph_or_function_outputs(graph, values, outputs)
489+
490+
# Insert new nodes if there is a graph
491+
graph.extend(new_nodes)
492+
graph.sort()
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
# Copyright (c) ONNX Project Contributors
2+
# SPDX-License-Identifier: Apache-2.0
3+
"""Unit tests for the _convenience module."""
4+
5+
import onnx
6+
7+
import unittest
8+
9+
import onnx_ir as ir
10+
from onnx_ir._convenience import insert_nodes_in_value
11+
12+
13+
def _create_model(model_text: str) -> ir.Model:
14+
model = onnx.parser.parse_model(model_text)
15+
return ir.serde.deserialize_model(model)
16+
17+
18+
class ConvenienceTest(unittest.TestCase):
19+
def test_insert_nodes_in_value(self):
20+
# Main graph
21+
input = ir.Input("input")
22+
node_A = ir.node("op_A", [input])
23+
node_B = ir.node("op_B", node_A.outputs, outputs=[ir.Value(name="B")])
24+
node_C = ir.node("op_C", node_B.outputs)
25+
26+
# New sequence to insert
27+
input_2 = ir.Input("input_2")
28+
node_M = ir.node("op_M", [input_2])
29+
node_N = ir.node("op_N", node_M.outputs)
30+
31+
# Insert nodes in B
32+
insert_nodes_in_value(node_B.outputs[0], [node_M, node_N])
33+
self.assertEqual(len(node_B.outputs), 1)
34+
self.assertEqual(node_B.outputs[0].consumers()[0].op_type, "op_M")
35+
self.assertEqual(len(node_C.inputs), 1)
36+
self.assertEqual(node_C.inputs[0].producer().op_type, "op_N")
37+
self.assertEqual(node_C.inputs[0].name, "B")
38+
39+
def test_insert_nodes_in_value_in_graph(self):
40+
ir_model = _create_model(
41+
"""
42+
<ir_version: 10, opset_import: [ "" : 17]>
43+
agraph (float[N] x) => (float[N] z) {
44+
two = Constant<value_float=2.0>()
45+
a, b = SplitNode(x)
46+
z = MergeNode(a, b, two)
47+
}
48+
"""
49+
)
50+
51+
# Sequence to insert.
52+
# Note inputs = [i1, i2] and outputs = [b.outputs[1], c.outputs[0]].
53+
i1, i2 = ir.Input("i1"), ir.Input("i2")
54+
a = ir.node("op_1", [i1, i2])
55+
b = ir.node("op_2", [a.outputs[0], i1], num_outputs=2)
56+
c = ir.node("op_3", [i2, b.outputs[0]])
57+
58+
# Insert nodes in SplitNode.outputs
59+
target_node = ir_model.graph[1]
60+
insert_nodes_in_value(target_node.outputs, [a, b, c])
61+
62+
# Check target_node outputs have been renamed
63+
new_i1, new_i2 = target_node.outputs
64+
self.assertEqual(new_i1.name, "i1")
65+
self.assertEqual(new_i2.name, "i2")
66+
67+
# Check i1 and i2 have new users
68+
self.assertEqual(tuple(node.op_type for node in new_i1.consumers()), ("op_1", "op_2"))
69+
self.assertEqual(tuple(node.op_type for node in new_i2.consumers()), ("op_1", "op_3"))
70+
71+
# Check outputs have been correctly renamed as previous values
72+
self.assertEqual(b.outputs[1].name, "a")
73+
self.assertEqual(c.outputs[0].name, "b")
74+
75+
# Check nodes have been inserted in the graph
76+
self.assertEqual(len(ir_model.graph), 6)
77+
78+
def test_insert_nodes_in_input(self):
79+
ir_model = _create_model(
80+
"""
81+
<ir_version: 10, opset_import: [ "" : 17]>
82+
agraph (float[N] x) => (float[N] z) {
83+
two = Constant<value_float=2.0>()
84+
z = Add(x, two)
85+
}
86+
"""
87+
)
88+
89+
# Sequence to insert.
90+
x = ir.Input("new_x")
91+
node = ir.node("Mul", [x, x])
92+
93+
# Insert nodes in graph.inputs
94+
insert_nodes_in_value(ir_model.graph[1].inputs[0], [node])
95+
self.assertEqual(node.outputs[0].name, "x")
96+
97+
# Check input has been renamed
98+
self.assertEqual(ir_model.graph.inputs[0].name, "new_x")
99+
100+
# Finally, check new graph is valid
101+
proto = ir.to_proto(ir_model)
102+
onnx.checker.check_model(proto, full_check=True)
103+
104+
def test_insert_nodes_in_output(self):
105+
ir_model = _create_model(
106+
"""
107+
<ir_version: 10, opset_import: [ "" : 17]>
108+
agraph (float[N] x) => (float[N] z) {
109+
two = Constant<value_float=2.0>()
110+
z = Add(x, two)
111+
}
112+
"""
113+
)
114+
115+
# Sequence to insert.
116+
x = ir.Input("new_z")
117+
node = ir.node("Mul", [x, x])
118+
119+
# Insert nodes in graph.inputs
120+
insert_nodes_in_value(ir_model.graph.outputs[0], [node])
121+
self.assertEqual(ir_model.graph[1].outputs[0].name, "new_z")
122+
123+
# Check output name is preserved
124+
self.assertEqual(ir_model.graph.outputs[0].name, "z")
125+
126+
def test_value_error_for_wrong_number_of_points(self):
127+
ir_model = _create_model(
128+
"""
129+
<ir_version: 10, opset_import: [ "" : 17]>
130+
agraph (float[N] x) => (float[N] z) {
131+
two = Constant<value_float=2.0>()
132+
a, b = SplitNode(x)
133+
z = MergeNode(a, b, two)
134+
}
135+
"""
136+
)
137+
node = ir.node("op_M", [ir.Input("new_x"), ir.Input("new_y")])
138+
with self.assertRaisesRegex(ValueError, "The number of values and inputs"):
139+
insert_nodes_in_value(ir_model.graph[0].outputs, [node])
140+
141+
with self.assertRaisesRegex(ValueError, "The number of values and outputs"):
142+
insert_nodes_in_value(ir_model.graph[1].outputs, [node])
143+
144+
145+
if __name__ == "__main__":
146+
unittest.main()

src/onnx_ir/convenience.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
"replace_all_uses_with",
1111
"replace_nodes_and_values",
1212
"create_value_mapping",
13+
"insert_nodes_in_value",
1314
]
1415

1516
from onnx_ir._convenience import (
@@ -18,6 +19,7 @@
1819
create_value_mapping,
1920
replace_all_uses_with,
2021
replace_nodes_and_values,
22+
insert_nodes_in_value,
2123
)
2224

2325
# NOTE: Do not implement any other functions in this module.

0 commit comments

Comments
 (0)