@@ -38,7 +38,7 @@ def test_graph_conv(self):
3838 adjacency /= adjacency .sum (dim = 0 , keepdim = True ).sqrt () * adjacency .sum (dim = 1 , keepdim = True ).sqrt ()
3939 x = adjacency .t () @ self .input
4040 truth = conv .activation (conv .linear (x ))
41- self .assertTrue (torch .allclose (result , truth , rtol = 1e-4 , atol = 1e-7 ), "Incorrect graph convolution" )
41+ self .assertTrue (torch .allclose (result , truth , rtol = 1e-2 , atol = 1e-3 ), "Incorrect graph convolution" )
4242
4343 num_head = 2
4444 conv = layers .GraphAttentionConv (self .input_dim , self .output_dim , num_head = num_head ).cuda ()
@@ -55,15 +55,15 @@ def test_graph_conv(self):
5555 outputs .append (output )
5656 truth = torch .cat (outputs , dim = - 1 )
5757 truth = conv .activation (truth )
58- self .assertTrue (torch .allclose (result , truth ), "Incorrect graph attention convolution" )
58+ self .assertTrue (torch .allclose (result , truth , rtol = 1e-2 , atol = 1e-3 ), "Incorrect graph attention convolution" )
5959
6060 eps = 1
6161 conv = layers .GraphIsomorphismConv (self .input_dim , self .output_dim , eps = eps ).cuda ()
6262 result = conv (self .graph , self .input )
6363 adjacency = self .graph .adjacency .to_dense ().sum (dim = - 1 )
6464 x = (1 + eps ) * self .input + adjacency .t () @ self .input
6565 truth = conv .activation (conv .mlp (x ))
66- self .assertTrue (torch .allclose (result , truth , atol = 1e-4 , rtol = 1e-7 ), "Incorrect graph isomorphism convolution" )
66+ self .assertTrue (torch .allclose (result , truth , rtol = 1e-2 , atol = 1e-2 ), "Incorrect graph isomorphism convolution" )
6767
6868 conv = layers .RelationalGraphConv (self .input_dim , self .output_dim , self .num_relation ).cuda ()
6969 result = conv (self .graph , self .input )
@@ -72,7 +72,7 @@ def test_graph_conv(self):
7272 x = torch .einsum ("htr, hd -> trd" , adjacency , self .input )
7373 x = conv .linear (x .flatten (1 )) + conv .self_loop (self .input )
7474 truth = conv .activation (x )
75- self .assertTrue (torch .allclose (result , truth , atol = 1e-4 , rtol = 1e-7 ), "Incorrect relational graph convolution" )
75+ self .assertTrue (torch .allclose (result , truth , rtol = 1e-2 , atol = 1e-3 ), "Incorrect relational graph convolution" )
7676
7777 conv = layers .ChebyshevConv (self .input_dim , self .output_dim , k = 2 ).cuda ()
7878 result = conv (self .graph , self .input )
@@ -83,7 +83,7 @@ def test_graph_conv(self):
8383 bases = [self .input , laplacian .t () @ self .input , (2 * laplacian .t () @ laplacian .t () - identity ) @ self .input ]
8484 x = conv .linear (torch .cat (bases , dim = - 1 ))
8585 truth = conv .activation (x )
86- self .assertTrue (torch .allclose (result , truth , atol = 1e-4 , rtol = 1e-7 ), "Incorrect chebyshev graph convolution" )
86+ self .assertTrue (torch .allclose (result , truth , rtol = 1e-2 , atol = 1e-3 ), "Incorrect chebyshev graph convolution" )
8787
8888
8989if __name__ == "__main__" :
0 commit comments