Skip to content

Commit c139242

Browse files
authored
Create add methods on initializers and attributes (#33)
1. It's now possible to do `node.attributes.add(ir.Attr(...))` to avoid having to get the name as a redundant step. 2. Also allow node attributes to be initialized with a dictionary - This way users do not have to use ```py ir.Node(attributes=node.attributes.values()) ``` when copying the node. Instead, it is possible to ```py ir.Node(attributes=node.attributes) ``` 3. Implement get_* methods on attributes to make getting default attributes easier. Before ```py attr = node.attributes.get("attr") if attr is not None: value = attr.as_int() else: value = 42 ``` now ```py value = node.attributes.get_int("attr", 42) ``` --------- Signed-off-by: Justin Chu <[email protected]>
1 parent 5506fe6 commit c139242

File tree

4 files changed

+299
-53
lines changed

4 files changed

+299
-53
lines changed

src/onnx_ir/_core.py

Lines changed: 31 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,12 @@
2222
import sys
2323
import textwrap
2424
import typing
25-
from collections import OrderedDict
2625
from collections.abc import (
2726
Collection,
2827
Hashable,
2928
Iterable,
3029
Iterator,
31-
MutableMapping,
30+
Mapping,
3231
MutableSequence,
3332
Sequence,
3433
)
@@ -1325,7 +1324,7 @@ def __init__(
13251324
domain: str,
13261325
op_type: str,
13271326
inputs: Iterable[Value | None],
1328-
attributes: Iterable[Attr] = (),
1327+
attributes: Iterable[Attr] | Mapping[str, Attr] = (),
13291328
*,
13301329
overload: str = "",
13311330
num_outputs: int | None = None,
@@ -1371,15 +1370,10 @@ def __init__(
13711370
self._inputs: tuple[Value | None, ...] = tuple(inputs)
13721371
# Values belong to their defining nodes. The values list is immutable
13731372
self._outputs: tuple[Value, ...] = self._create_outputs(num_outputs, outputs)
1374-
attributes = tuple(attributes)
1375-
if attributes and not isinstance(attributes[0], Attr):
1376-
raise TypeError(
1377-
f"Expected the attributes to be Attr, got {type(attributes[0])}. "
1378-
"If you are copying the attributes from another node, make sure you call "
1379-
"node.attributes.values() because it is a dictionary."
1380-
)
1381-
self._attributes: OrderedDict[str, Attr] = OrderedDict(
1382-
(attr.name, attr) for attr in attributes
1373+
if isinstance(attributes, Mapping):
1374+
attributes = tuple(attributes.values())
1375+
self._attributes: _graph_containers.Attributes = _graph_containers.Attributes(
1376+
attributes
13831377
)
13841378
self._overload: str = overload
13851379
# TODO(justinchuby): Potentially support a version range
@@ -1637,7 +1631,7 @@ def outputs(self, _: Sequence[Value]) -> None:
16371631
raise AttributeError("outputs is immutable. Please create a new node instead.")
16381632

16391633
@property
1640-
def attributes(self) -> OrderedDict[str, Attr]:
1634+
def attributes(self) -> _graph_containers.Attributes:
16411635
"""The attributes of the node."""
16421636
return self._attributes
16431637

@@ -2201,17 +2195,9 @@ def __init__(
22012195
# Private fields that are not to be accessed by any other classes
22022196
self._inputs = _graph_containers.GraphInputs(self, inputs)
22032197
self._outputs = _graph_containers.GraphOutputs(self, outputs)
2204-
self._initializers = _graph_containers.GraphInitializers(self)
2205-
for initializer in initializers:
2206-
if isinstance(initializer, str):
2207-
raise TypeError(
2208-
"Initializer must be a Value, not a string. "
2209-
"If you are copying the initializers from another graph, "
2210-
"make sure you call graph.initializers.values() because it is a dictionary."
2211-
)
2212-
if initializer.name is None:
2213-
raise ValueError(f"Initializer must have a name: {initializer}")
2214-
self._initializers[initializer.name] = initializer
2198+
self._initializers = _graph_containers.GraphInitializers(
2199+
self, {initializer.name: initializer for initializer in initializers}
2200+
)
22152201
self._doc_string = doc_string
22162202
self._opset_imports = opset_imports or {}
22172203
self._metadata: _metadata.MetadataStore | None = None
@@ -2234,7 +2220,19 @@ def outputs(self) -> MutableSequence[Value]:
22342220
return self._outputs
22352221

22362222
@property
2237-
def initializers(self) -> MutableMapping[str, Value]:
2223+
def initializers(self) -> _graph_containers.GraphInitializers:
2224+
"""The initializers of the graph as a ``MutableMapping[str, Value]``.
2225+
2226+
The keys are the names of the initializers. The values are the :class:`Value` objects.
2227+
2228+
This property additionally supports the ``add`` method, which takes a :class:`Value`
2229+
and adds it to the initializers if it is not already present.
2230+
2231+
.. note::
2232+
When setting an initializer with ``graph.initializers[key] = value``,
2233+
if the value does not have a name, it will be assigned ``key`` as its name.
2234+
2235+
"""
22382236
return self._initializers
22392237

22402238
def register_initializer(self, value: Value) -> None:
@@ -2263,15 +2261,11 @@ def register_initializer(self, value: Value) -> None:
22632261
" it is not the same object: existing={self._initializers[value.name]!r},"
22642262
f" new={value!r}"
22652263
)
2266-
if value.producer() is not None:
2267-
raise ValueError(
2268-
f"Value '{value!r}' is produced by a node and cannot be an initializer."
2269-
)
22702264
if value.const_value is None:
22712265
raise ValueError(
22722266
f"Value '{value!r}' must have its const_value set to be an initializer."
22732267
)
2274-
self._initializers[value.name] = value
2268+
self._initializers.add(value)
22752269

22762270
@property
22772271
def doc_string(self) -> str | None:
@@ -2701,7 +2695,7 @@ def __init__(
27012695
outputs: Sequence[Value],
27022696
*,
27032697
nodes: Iterable[Node],
2704-
initializers: Sequence[_protocols.ValueProtocol] = (),
2698+
initializers: Sequence[Value] = (),
27052699
doc_string: str | None = None,
27062700
opset_imports: dict[str, int] | None = None,
27072701
name: str | None = None,
@@ -2710,10 +2704,7 @@ def __init__(
27102704
self.name = name
27112705
self.inputs = tuple(inputs)
27122706
self.outputs = tuple(outputs)
2713-
for initializer in initializers:
2714-
if initializer.name is None:
2715-
raise ValueError(f"Initializer must have a name: {initializer}")
2716-
self.initializers = {tensor.name: tensor for tensor in initializers}
2707+
self.initializers = {initializer.name: initializer for initializer in initializers}
27172708
self.doc_string = doc_string
27182709
self.opset_imports = opset_imports or {}
27192710
self._metadata: _metadata.MetadataStore | None = None
@@ -2927,13 +2918,15 @@ def __init__(
29272918
# Ensure the inputs and outputs of the function belong to a graph
29282919
# and not from an outer scope
29292920
graph: Graph,
2930-
attributes: Sequence[Attr],
2921+
attributes: Iterable[Attr] | Mapping[str, Attr],
29312922
) -> None:
29322923
self._domain = domain
29332924
self._name = name
29342925
self._overload = overload
29352926
self._graph = graph
2936-
self._attributes = OrderedDict((attr.name, attr) for attr in attributes)
2927+
if isinstance(attributes, Mapping):
2928+
attributes = tuple(attributes.values())
2929+
self._attributes = _graph_containers.Attributes(attributes)
29372930

29382931
def identifier(self) -> _protocols.OperatorIdentifier:
29392932
return self.domain, self.name, self.overload
@@ -2971,7 +2964,7 @@ def outputs(self) -> MutableSequence[Value]:
29712964
return self._graph.outputs
29722965

29732966
@property
2974-
def attributes(self) -> OrderedDict[str, Attr]:
2967+
def attributes(self) -> _graph_containers.Attributes:
29752968
return self._attributes
29762969

29772970
@typing.overload

src/onnx_ir/_core_test.py

Lines changed: 155 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -861,6 +861,117 @@ def test_domain_normalizes_ai_onnx(self):
861861
node.domain = "ai.onnx"
862862
self.assertEqual(node.domain, "")
863863

864+
def test_attributes_add(self):
865+
node = _core.Node("ai.onnx", "TestOp", inputs=())
866+
node.attributes.add(_core.AttrInt64("test_attr", 1))
867+
self.assertIn("test_attr", node.attributes)
868+
self.assertEqual(node.attributes["test_attr"].value, 1)
869+
870+
def test_attributes_set_raise_with_type_error(self):
871+
node = _core.Node("ai.onnx", "TestOp", inputs=())
872+
with self.assertRaises(TypeError):
873+
node.attributes["test_attr"] = 1
874+
with self.assertRaises(TypeError):
875+
node.attributes[1] = _core.AttrInt64("test_attr", 1)
876+
877+
def test_init_accepts_attribute_mapping(self):
878+
node = _core.Node(
879+
"ai.onnx", "TestOp", inputs=(), attributes=[_core.AttrInt64("test_attr", 1)]
880+
)
881+
new_node = _core.Node("", "OtherOp", inputs=(), attributes=node.attributes)
882+
self.assertEqual(new_node.attributes, node.attributes)
883+
884+
def test_attributes_get_int(self):
885+
node = _core.Node(
886+
"ai.onnx", "TestOp", inputs=(), attributes=[_core.AttrInt64("test_attr", 1)]
887+
)
888+
self.assertEqual(node.attributes.get_int("test_attr"), 1)
889+
self.assertIsNone(node.attributes.get_int("non_existent_attr"))
890+
self.assertEqual(node.attributes.get_int("non_existent_attr", 42), 42)
891+
892+
def test_attributes_get_float(self):
893+
node = _core.Node(
894+
"ai.onnx", "TestOp", inputs=(), attributes=[_core.AttrFloat32("test_attr", 1.0)]
895+
)
896+
self.assertEqual(node.attributes.get_float("test_attr"), 1.0)
897+
self.assertIsNone(node.attributes.get_float("non_existent_attr"))
898+
self.assertEqual(node.attributes.get_float("non_existent_attr", 42.0), 42.0)
899+
900+
def test_attributes_get_string(self):
901+
node = _core.Node(
902+
"ai.onnx", "TestOp", inputs=(), attributes=[_core.AttrString("test_attr", "value")]
903+
)
904+
self.assertEqual(node.attributes.get_string("test_attr"), "value")
905+
self.assertIsNone(node.attributes.get_string("non_existent_attr"))
906+
self.assertEqual(node.attributes.get_string("non_existent_attr", "default"), "default")
907+
908+
def test_attributes_get_tensor(self):
909+
tensor = ir.Tensor(np.array([1.0, 2.0, 3.0], dtype=np.float32))
910+
node = _core.Node(
911+
"ai.onnx", "TestOp", inputs=(), attributes=[_core.AttrTensor("test_attr", tensor)]
912+
)
913+
np.testing.assert_equal(
914+
node.attributes.get_tensor("test_attr").numpy(), tensor.numpy()
915+
)
916+
self.assertIsNone(node.attributes.get_tensor("non_existent_attr"))
917+
np.testing.assert_equal(
918+
node.attributes.get_tensor("non_existent_attr", tensor).numpy(), tensor.numpy()
919+
)
920+
921+
def test_attributes_get_ints(self):
922+
node = _core.Node(
923+
"ai.onnx",
924+
"TestOp",
925+
inputs=(),
926+
attributes=[_core.AttrInt64s("test_attr", [1, 2, 3])],
927+
)
928+
self.assertEqual(node.attributes.get_ints("test_attr"), [1, 2, 3])
929+
self.assertIsNone(node.attributes.get_ints("non_existent_attr"))
930+
self.assertEqual(node.attributes.get_ints("non_existent_attr", [42]), [42])
931+
932+
def test_attributes_get_floats(self):
933+
node = _core.Node(
934+
"ai.onnx",
935+
"TestOp",
936+
inputs=(),
937+
attributes=[_core.AttrFloat32s("test_attr", [1.0, 2.0, 3.0])],
938+
)
939+
self.assertEqual(node.attributes.get_floats("test_attr"), [1.0, 2.0, 3.0])
940+
self.assertIsNone(node.attributes.get_floats("non_existent_attr"))
941+
self.assertEqual(node.attributes.get_floats("non_existent_attr", [42.0]), [42.0])
942+
943+
def test_attributes_get_strings(self):
944+
node = _core.Node(
945+
"ai.onnx",
946+
"TestOp",
947+
inputs=(),
948+
attributes=[_core.AttrStrings("test_attr", ["a", "b", "c"])],
949+
)
950+
self.assertEqual(node.attributes.get_strings("test_attr"), ["a", "b", "c"])
951+
self.assertIsNone(node.attributes.get_strings("non_existent_attr"))
952+
self.assertEqual(
953+
node.attributes.get_strings("non_existent_attr", ["default"]), ["default"]
954+
)
955+
956+
def test_attributes_get_tensors(self):
957+
tensor1 = ir.Tensor(np.array([1.0, 2.0], dtype=np.float32))
958+
tensor2 = ir.Tensor(np.array([3.0, 4.0], dtype=np.float32))
959+
node = _core.Node(
960+
"ai.onnx",
961+
"TestOp",
962+
inputs=(),
963+
attributes=[_core.AttrTensors("test_attr", [tensor1, tensor2])],
964+
)
965+
tensors = node.attributes.get_tensors("test_attr")
966+
self.assertIsNotNone(tensors)
967+
self.assertEqual(len(tensors), 2)
968+
np.testing.assert_equal(tensors[0].numpy(), tensor1.numpy())
969+
np.testing.assert_equal(tensors[1].numpy(), tensor2.numpy())
970+
self.assertIsNone(node.attributes.get_tensors("non_existent_attr"))
971+
np.testing.assert_equal(
972+
node.attributes.get_tensors("non_existent_attr", [tensor1]), [tensor1]
973+
)
974+
864975
# TODO(justinchuby): Test all methods
865976

866977

@@ -1453,7 +1564,7 @@ def test_outputs_copy(self):
14531564
self.assertNotIn(self.value3, self.graph.outputs)
14541565
self.assertIn(self.value3, outputs_copy)
14551566

1456-
def test_set_initializers(self):
1567+
def test_initializers_setitem(self):
14571568
self.graph.initializers["initializer1"] = self.value3
14581569
self.assertIn("initializer1", self.graph.initializers)
14591570
self.assertTrue(self.value3.is_initializer())
@@ -1467,11 +1578,11 @@ def test_set_initializers(self):
14671578
self.assertFalse(self.value3.is_initializer())
14681579
self.assertIsNone(self.value3.graph)
14691580

1470-
def test_set_initializers_raises_when_key_does_not_match(self):
1581+
def test_initializers_setitem_raises_when_key_does_not_match(self):
14711582
with self.assertRaisesRegex(ValueError, "does not match the name of the value"):
14721583
self.graph.initializers["some_key"] = self.value3
14731584

1474-
def test_set_initializers_raises_when_it_belongs_to_another_graph(self):
1585+
def test_initializers_setitem_raises_when_it_belongs_to_another_graph(self):
14751586
other_graph = _core.Graph(inputs=(), outputs=(), nodes=())
14761587
other_graph.initializers["initializer1"] = self.value3
14771588
with self.assertRaisesRegex(
@@ -1485,11 +1596,51 @@ def test_set_initializers_raises_when_it_belongs_to_another_graph(self):
14851596
self.assertTrue(self.value3.is_initializer())
14861597
self.assertIs(self.value3.graph, self.graph)
14871598

1488-
def test_set_initializers_raises_when_value_does_not_have_a_name(self):
1599+
def test_initializers_setitem_raises_when_value_does_not_have_a_name(self):
14891600
self.value3.name = None
14901601
with self.assertRaises(TypeError):
14911602
self.graph.initializers[None] = self.value3
14921603

1604+
with self.assertRaisesRegex(ValueError, "cannot be an empty string"):
1605+
self.graph.initializers[""] = _core.Value(name="")
1606+
1607+
def test_initializers_setitem_checks_value_name_match(self):
1608+
with self.assertRaisesRegex(ValueError, "does not match"):
1609+
self.graph.initializers["some_name"] = _core.Value(name="some_other_name")
1610+
1611+
def test_initializers_setitem_assigns_key_to_value_name_if_not_set(self):
1612+
value = _core.Value(name=None)
1613+
self.graph.initializers["some_name"] = value
1614+
self.assertEqual(value.name, "some_name")
1615+
self.assertIs(value, self.graph.initializers["some_name"])
1616+
1617+
value = _core.Value(name="")
1618+
self.graph.initializers["some_other_name"] = value
1619+
self.assertEqual(value.name, "some_other_name")
1620+
self.assertIs(value, self.graph.initializers["some_other_name"])
1621+
1622+
def test_initializers_setitem_checks_value_type(self):
1623+
with self.assertRaisesRegex(TypeError, "must be a Value object"):
1624+
self.graph.initializers["some_name"] = ir.tensor([1, 2, 3], name="some_tensor")
1625+
1626+
def test_initializers_setitem_raises_when_value_is_node_output(self):
1627+
node = ir.node("SomeOp", inputs=[])
1628+
with self.assertRaisesRegex(ValueError, "produced by a node"):
1629+
self.graph.initializers["some_name"] = node.outputs[0]
1630+
1631+
def test_initializers_add_checks_value_name(self):
1632+
# Initializers should always have a name
1633+
with self.assertRaisesRegex(ValueError, "cannot be an empty string"):
1634+
self.graph.initializers.add(_core.Value(name=""))
1635+
1636+
with self.assertRaisesRegex(TypeError, "must be a string"):
1637+
self.graph.initializers.add(_core.Value(name=None))
1638+
1639+
def test_initializers_add_checks_value_type(self):
1640+
# Initializers should be of type Value
1641+
with self.assertRaisesRegex(TypeError, "must be a Value object"):
1642+
self.graph.initializers.add(ir.tensor([1, 2, 3], name="some_tensor"))
1643+
14931644
def test_delete_initializer(self):
14941645
self.graph.initializers["initializer1"] = self.value3
14951646
del self.graph.initializers["initializer1"]

0 commit comments

Comments
 (0)