Skip to content

Commit ea8a41e

Browse files
committed
update layers.py
1 parent 6a45aa1 commit ea8a41e

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

README.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,12 @@ More detailed information can be found [here](https://github.com/runningoat/hgmn
3636
## Run
3737
Just execuate the following command for graph-graph classification task:
3838
```
39-
python main_classification.py --datasets openssl_min50
39+
python main_classification.py --dataset openssl_min50
4040
```
4141

4242
Similarly, execuate the following command for graph-graph regression task:
4343
```
44-
python main_regression.py --datasets AIDS700nef
44+
python main_regression.py --dataset AIDS700nef
4545
```
4646

4747
## Citing

layers.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from torch.nn.parameter import Parameter
66
from torch_geometric.nn import global_add_pool, global_mean_pool, HypergraphConv
77
from torch_geometric.nn.pool.topk_pool import topk
8-
from torch_geometric.utils import dense_to_sparse
98
from torch_scatter import scatter_add
109
from torch_scatter import scatter
1110
from torch_geometric.utils import degree
@@ -272,7 +271,7 @@ def forward(self, x_left, batch_left, x_right, batch_right):
272271
# Construct batch fully connected graph in block diagonal matirx format
273272
for idx_i, idx_j, idx_x, idx_y in zip(shift_cum_num_nodes_x_left, cum_num_nodes_x_left, shift_cum_num_nodes_x_right, cum_num_nodes_x_right):
274273
adj[idx_i:idx_j, idx_x:idx_y] = 1.0
275-
new_edge_index, _ = dense_to_sparse(adj)
274+
new_edge_index, _ = self.dense_to_sparse(adj)
276275
row, col = new_edge_index
277276

278277
assign_index1 = torch.stack([col, row], dim=0)
@@ -283,6 +282,12 @@ def forward(self, x_left, batch_left, x_right, batch_right):
283282

284283
return out1, out2
285284

285+
def dense_to_sparse(self, adj):
286+
assert adj.dim() == 2
287+
index = adj.nonzero(as_tuple=False).t().contiguous()
288+
value = adj[index[0], index[1]]
289+
return index, value
290+
286291

287292
class ReadoutModule(torch.nn.Module):
288293
def __init__(self, args):

0 commit comments

Comments
 (0)