@@ -861,6 +861,117 @@ def test_domain_normalizes_ai_onnx(self):
861861 node .domain = "ai.onnx"
862862 self .assertEqual (node .domain , "" )
863863
864+ def test_attributes_add (self ):
865+ node = _core .Node ("ai.onnx" , "TestOp" , inputs = ())
866+ node .attributes .add (_core .AttrInt64 ("test_attr" , 1 ))
867+ self .assertIn ("test_attr" , node .attributes )
868+ self .assertEqual (node .attributes ["test_attr" ].value , 1 )
869+
870+ def test_attributes_set_raise_with_type_error (self ):
871+ node = _core .Node ("ai.onnx" , "TestOp" , inputs = ())
872+ with self .assertRaises (TypeError ):
873+ node .attributes ["test_attr" ] = 1
874+ with self .assertRaises (TypeError ):
875+ node .attributes [1 ] = _core .AttrInt64 ("test_attr" , 1 )
876+
877+ def test_init_accepts_attribute_mapping (self ):
878+ node = _core .Node (
879+ "ai.onnx" , "TestOp" , inputs = (), attributes = [_core .AttrInt64 ("test_attr" , 1 )]
880+ )
881+ new_node = _core .Node ("" , "OtherOp" , inputs = (), attributes = node .attributes )
882+ self .assertEqual (new_node .attributes , node .attributes )
883+
884+ def test_attributes_get_int (self ):
885+ node = _core .Node (
886+ "ai.onnx" , "TestOp" , inputs = (), attributes = [_core .AttrInt64 ("test_attr" , 1 )]
887+ )
888+ self .assertEqual (node .attributes .get_int ("test_attr" ), 1 )
889+ self .assertIsNone (node .attributes .get_int ("non_existent_attr" ))
890+ self .assertEqual (node .attributes .get_int ("non_existent_attr" , 42 ), 42 )
891+
892+ def test_attributes_get_float (self ):
893+ node = _core .Node (
894+ "ai.onnx" , "TestOp" , inputs = (), attributes = [_core .AttrFloat32 ("test_attr" , 1.0 )]
895+ )
896+ self .assertEqual (node .attributes .get_float ("test_attr" ), 1.0 )
897+ self .assertIsNone (node .attributes .get_float ("non_existent_attr" ))
898+ self .assertEqual (node .attributes .get_float ("non_existent_attr" , 42.0 ), 42.0 )
899+
900+ def test_attributes_get_string (self ):
901+ node = _core .Node (
902+ "ai.onnx" , "TestOp" , inputs = (), attributes = [_core .AttrString ("test_attr" , "value" )]
903+ )
904+ self .assertEqual (node .attributes .get_string ("test_attr" ), "value" )
905+ self .assertIsNone (node .attributes .get_string ("non_existent_attr" ))
906+ self .assertEqual (node .attributes .get_string ("non_existent_attr" , "default" ), "default" )
907+
908+ def test_attributes_get_tensor (self ):
909+ tensor = ir .Tensor (np .array ([1.0 , 2.0 , 3.0 ], dtype = np .float32 ))
910+ node = _core .Node (
911+ "ai.onnx" , "TestOp" , inputs = (), attributes = [_core .AttrTensor ("test_attr" , tensor )]
912+ )
913+ np .testing .assert_equal (
914+ node .attributes .get_tensor ("test_attr" ).numpy (), tensor .numpy ()
915+ )
916+ self .assertIsNone (node .attributes .get_tensor ("non_existent_attr" ))
917+ np .testing .assert_equal (
918+ node .attributes .get_tensor ("non_existent_attr" , tensor ).numpy (), tensor .numpy ()
919+ )
920+
921+ def test_attributes_get_ints (self ):
922+ node = _core .Node (
923+ "ai.onnx" ,
924+ "TestOp" ,
925+ inputs = (),
926+ attributes = [_core .AttrInt64s ("test_attr" , [1 , 2 , 3 ])],
927+ )
928+ self .assertEqual (node .attributes .get_ints ("test_attr" ), [1 , 2 , 3 ])
929+ self .assertIsNone (node .attributes .get_ints ("non_existent_attr" ))
930+ self .assertEqual (node .attributes .get_ints ("non_existent_attr" , [42 ]), [42 ])
931+
932+ def test_attributes_get_floats (self ):
933+ node = _core .Node (
934+ "ai.onnx" ,
935+ "TestOp" ,
936+ inputs = (),
937+ attributes = [_core .AttrFloat32s ("test_attr" , [1.0 , 2.0 , 3.0 ])],
938+ )
939+ self .assertEqual (node .attributes .get_floats ("test_attr" ), [1.0 , 2.0 , 3.0 ])
940+ self .assertIsNone (node .attributes .get_floats ("non_existent_attr" ))
941+ self .assertEqual (node .attributes .get_floats ("non_existent_attr" , [42.0 ]), [42.0 ])
942+
943+ def test_attributes_get_strings (self ):
944+ node = _core .Node (
945+ "ai.onnx" ,
946+ "TestOp" ,
947+ inputs = (),
948+ attributes = [_core .AttrStrings ("test_attr" , ["a" , "b" , "c" ])],
949+ )
950+ self .assertEqual (node .attributes .get_strings ("test_attr" ), ["a" , "b" , "c" ])
951+ self .assertIsNone (node .attributes .get_strings ("non_existent_attr" ))
952+ self .assertEqual (
953+ node .attributes .get_strings ("non_existent_attr" , ["default" ]), ["default" ]
954+ )
955+
956+ def test_attributes_get_tensors (self ):
957+ tensor1 = ir .Tensor (np .array ([1.0 , 2.0 ], dtype = np .float32 ))
958+ tensor2 = ir .Tensor (np .array ([3.0 , 4.0 ], dtype = np .float32 ))
959+ node = _core .Node (
960+ "ai.onnx" ,
961+ "TestOp" ,
962+ inputs = (),
963+ attributes = [_core .AttrTensors ("test_attr" , [tensor1 , tensor2 ])],
964+ )
965+ tensors = node .attributes .get_tensors ("test_attr" )
966+ self .assertIsNotNone (tensors )
967+ self .assertEqual (len (tensors ), 2 )
968+ np .testing .assert_equal (tensors [0 ].numpy (), tensor1 .numpy ())
969+ np .testing .assert_equal (tensors [1 ].numpy (), tensor2 .numpy ())
970+ self .assertIsNone (node .attributes .get_tensors ("non_existent_attr" ))
971+ np .testing .assert_equal (
972+ node .attributes .get_tensors ("non_existent_attr" , [tensor1 ]), [tensor1 ]
973+ )
974+
864975 # TODO(justinchuby): Test all methods
865976
866977
@@ -1453,7 +1564,7 @@ def test_outputs_copy(self):
14531564 self .assertNotIn (self .value3 , self .graph .outputs )
14541565 self .assertIn (self .value3 , outputs_copy )
14551566
1456- def test_set_initializers (self ):
1567+ def test_initializers_setitem (self ):
14571568 self .graph .initializers ["initializer1" ] = self .value3
14581569 self .assertIn ("initializer1" , self .graph .initializers )
14591570 self .assertTrue (self .value3 .is_initializer ())
@@ -1467,11 +1578,11 @@ def test_set_initializers(self):
14671578 self .assertFalse (self .value3 .is_initializer ())
14681579 self .assertIsNone (self .value3 .graph )
14691580
1470- def test_set_initializers_raises_when_key_does_not_match (self ):
1581+ def test_initializers_setitem_raises_when_key_does_not_match (self ):
14711582 with self .assertRaisesRegex (ValueError , "does not match the name of the value" ):
14721583 self .graph .initializers ["some_key" ] = self .value3
14731584
1474- def test_set_initializers_raises_when_it_belongs_to_another_graph (self ):
1585+ def test_initializers_setitem_raises_when_it_belongs_to_another_graph (self ):
14751586 other_graph = _core .Graph (inputs = (), outputs = (), nodes = ())
14761587 other_graph .initializers ["initializer1" ] = self .value3
14771588 with self .assertRaisesRegex (
@@ -1485,11 +1596,51 @@ def test_set_initializers_raises_when_it_belongs_to_another_graph(self):
14851596 self .assertTrue (self .value3 .is_initializer ())
14861597 self .assertIs (self .value3 .graph , self .graph )
14871598
1488- def test_set_initializers_raises_when_value_does_not_have_a_name (self ):
1599+ def test_initializers_setitem_raises_when_value_does_not_have_a_name (self ):
14891600 self .value3 .name = None
14901601 with self .assertRaises (TypeError ):
14911602 self .graph .initializers [None ] = self .value3
14921603
1604+ with self .assertRaisesRegex (ValueError , "cannot be an empty string" ):
1605+ self .graph .initializers ["" ] = _core .Value (name = "" )
1606+
1607+ def test_initializers_setitem_checks_value_name_match (self ):
1608+ with self .assertRaisesRegex (ValueError , "does not match" ):
1609+ self .graph .initializers ["some_name" ] = _core .Value (name = "some_other_name" )
1610+
1611+ def test_initializers_setitem_assigns_key_to_value_name_if_not_set (self ):
1612+ value = _core .Value (name = None )
1613+ self .graph .initializers ["some_name" ] = value
1614+ self .assertEqual (value .name , "some_name" )
1615+ self .assertIs (value , self .graph .initializers ["some_name" ])
1616+
1617+ value = _core .Value (name = "" )
1618+ self .graph .initializers ["some_other_name" ] = value
1619+ self .assertEqual (value .name , "some_other_name" )
1620+ self .assertIs (value , self .graph .initializers ["some_other_name" ])
1621+
1622+ def test_initializers_setitem_checks_value_type (self ):
1623+ with self .assertRaisesRegex (TypeError , "must be a Value object" ):
1624+ self .graph .initializers ["some_name" ] = ir .tensor ([1 , 2 , 3 ], name = "some_tensor" )
1625+
1626+ def test_initializers_setitem_raises_when_value_is_node_output (self ):
1627+ node = ir .node ("SomeOp" , inputs = [])
1628+ with self .assertRaisesRegex (ValueError , "produced by a node" ):
1629+ self .graph .initializers ["some_name" ] = node .outputs [0 ]
1630+
1631+ def test_initializers_add_checks_value_name (self ):
1632+ # Initializers should always have a name
1633+ with self .assertRaisesRegex (ValueError , "cannot be an empty string" ):
1634+ self .graph .initializers .add (_core .Value (name = "" ))
1635+
1636+ with self .assertRaisesRegex (TypeError , "must be a string" ):
1637+ self .graph .initializers .add (_core .Value (name = None ))
1638+
1639+ def test_initializers_add_checks_value_type (self ):
1640+ # Initializers should be of type Value
1641+ with self .assertRaisesRegex (TypeError , "must be a Value object" ):
1642+ self .graph .initializers .add (ir .tensor ([1 , 2 , 3 ], name = "some_tensor" ))
1643+
14931644 def test_delete_initializer (self ):
14941645 self .graph .initializers ["initializer1" ] = self .value3
14951646 del self .graph .initializers ["initializer1" ]
0 commit comments