from torch.nn.parameter import Parameter
from torch_geometric.nn import global_add_pool, global_mean_pool, HypergraphConv
from torch_geometric.nn.pool.topk_pool import topk
- from torch_geometric.utils import dense_to_sparse
from torch_scatter import scatter_add
from torch_scatter import scatter
from torch_geometric.utils import degree
@@ -272,7 +271,7 @@ def forward(self, x_left, batch_left, x_right, batch_right):
        # Construct batch fully connected graph in block diagonal matrix format
        for idx_i, idx_j, idx_x, idx_y in zip(shift_cum_num_nodes_x_left, cum_num_nodes_x_left, shift_cum_num_nodes_x_right, cum_num_nodes_x_right):
            adj[idx_i:idx_j, idx_x:idx_y] = 1.0
-        new_edge_index, _ = dense_to_sparse(adj)
+        new_edge_index, _ = self.dense_to_sparse(adj)
        row, col = new_edge_index

        assign_index1 = torch.stack([col, row], dim=0)
@@ -283,6 +282,12 @@ def forward(self, x_left, batch_left, x_right, batch_right):

        return out1, out2

+    def dense_to_sparse(self, adj):
+        assert adj.dim() == 2
+        index = adj.nonzero(as_tuple=False).t().contiguous()
+        value = adj[index[0], index[1]]
+        return index, value
+

class ReadoutModule(torch.nn.Module):
    def __init__(self, args):
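For reference, the new helper mirrors what torch_geometric.utils.dense_to_sparse returns for a single 2-D adjacency matrix: the indices of the nonzero entries as an edge_index tensor plus the corresponding edge weights. A minimal standalone sketch of that behavior (the toy adjacency matrix below is illustrative, not taken from the repository):

import torch

adj = torch.tensor([[0., 1.],
                    [1., 0.]])                         # toy dense adjacency matrix
index = adj.nonzero(as_tuple=False).t().contiguous()   # edge_index of shape [2, num_edges]
value = adj[index[0], index[1]]                         # weights of the retained edges
# index == tensor([[0, 1], [1, 0]]), value == tensor([1., 1.])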