Skip to content

Commit 95228e0

Browse files
committed
Init Submit
1 parent 6c6aa87 commit 95228e0

7 files changed

+1238
-0
lines changed

datasets.zip

11.5 MB
Binary file not shown.

layers.py

+344
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,344 @@
1+
import torch
2+
import torch.nn as nn
3+
from torch_geometric.nn import MessagePassing
4+
import torch.nn.functional as F
5+
from torch.nn.parameter import Parameter
6+
from torch_geometric.nn import global_add_pool, global_mean_pool, HypergraphConv
7+
from torch_geometric.nn.pool.topk_pool import topk
8+
from torch_geometric.utils import dense_to_sparse
9+
from torch_scatter import scatter_add
10+
from torch_scatter import scatter
11+
from torch_geometric.utils import degree
12+
from utils import zeros, glorot, hyperedge_representation
13+
14+
15+
class HypergraphConvolution(MessagePassing):
    """Hypergraph convolution layer.

    Projects node features with a learnable weight and runs two normalized
    propagation passes over the incidence structure: nodes are gathered into
    hyperedges (scaled by B, the inverse hyperedge degree) and the hyperedge
    aggregates are scattered back to nodes (scaled by D, the inverse node
    degree).
    """

    def __init__(self, in_channels, out_channels, bias=True, **kwargs):
        super(HypergraphConvolution, self).__init__(aggr='add', node_dim=0, **kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.weight = Parameter(torch.Tensor(in_channels, out_channels))
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the weight (Glorot) and the bias (zeros)."""
        glorot(self.weight)
        zeros(self.bias)

    def message(self, x_j, edge_index_i, norm):
        # Scale each incoming message by the normalization factor of its target.
        return norm[edge_index_i].view(-1, 1) * x_j.view(-1, self.out_channels)

    def forward(self, x, hyperedge_index, hyperedge_weight=None):
        """Run one hypergraph convolution.

        :param x: node feature matrix of shape [num_nodes, in_channels]
        :param hyperedge_index: incidence pairs; row 0 holds node ids,
            row 1 holds hyperedge ids
        :param hyperedge_weight: optional per-hyperedge weights
        :return: updated node features [num_nodes, out_channels]
        """
        x = torch.matmul(x, self.weight)

        # D: inverse (weighted) node degree; isolated nodes get 0 instead of inf.
        if hyperedge_weight is None:
            node_degree = degree(hyperedge_index[0], x.size(0), x.dtype)
        else:
            node_degree = scatter_add(hyperedge_weight[hyperedge_index[1]],
                                      hyperedge_index[0], dim=0, dim_size=x.size(0))
        D = 1.0 / node_degree
        D[D == float("inf")] = 0

        # B: inverse hyperedge degree; an empty incidence yields zero hyperedges.
        num_edges = 0 if hyperedge_index.numel() == 0 else hyperedge_index[1].max().item() + 1
        B = 1.0 / degree(hyperedge_index[1], num_edges, x.dtype)
        B[B == float("inf")] = 0
        if hyperedge_weight is not None:
            B = B * hyperedge_weight

        # Pass 1: nodes -> hyperedges; pass 2: hyperedges -> nodes.
        self.flow = 'source_to_target'
        edge_feat = self.propagate(hyperedge_index, x=x, norm=B)
        self.flow = 'target_to_source'
        out = self.propagate(hyperedge_index, x=edge_feat, norm=D)

        return out if self.bias is None else out + self.bias

    def __repr__(self):
        return "{}({}, {})".format(self.__class__.__name__, self.in_channels, self.out_channels)
70+
71+
72+
class HyperedgeConv(MessagePassing):
    """Hyperedge-to-node convolution.

    Projects hyperedge features with a learnable weight, pre-scales them by
    B (inverse hyperedge degree), and scatters them to nodes with D (inverse
    node degree) in a single hyperedge -> node pass.
    """

    def __init__(self, in_channels, out_channels, bias=True, **kwargs):
        super(HyperedgeConv, self).__init__(aggr='add', node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.weight = Parameter(torch.Tensor(in_channels, out_channels))

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the weight (Glorot) and the bias (zeros)."""
        glorot(self.weight)
        zeros(self.bias)

    def message(self, x_j, edge_index_i, norm):
        # Scale each incoming message by the normalization factor of its target.
        out = norm[edge_index_i].view(-1, 1) * x_j.view(-1, self.out_channels)

        return out

    def forward(self, x, hyperedge_index, hyperedge_weight=None):
        """Scatter projected hyperedge features back to nodes.

        :param x: hyperedge feature matrix — assumed one row per hyperedge,
            since it is multiplied by B below (TODO confirm against callers)
        :param hyperedge_index: incidence pairs; row 0 holds node ids,
            row 1 holds hyperedge ids
        :param hyperedge_weight: optional per-hyperedge weights
        :return: node features [num_nodes, out_channels]
        """
        x = torch.matmul(x, self.weight)

        # BUG FIX: the original computed hyperedge_index[0].max() before its
        # own numel() == 0 guard, so an empty incidence matrix raised on
        # .max() even though num_edges was explicitly guarded. Guard both
        # counts together instead.
        if hyperedge_index.numel() == 0:
            num_nodes = 0
            num_edges = 0
        else:
            num_nodes = hyperedge_index[0].max().item() + 1
            num_edges = hyperedge_index[1].max().item() + 1

        # D: inverse (weighted) node degree; isolated nodes get 0 instead of inf.
        if hyperedge_weight is None:
            D = degree(hyperedge_index[0], num_nodes, x.dtype)
        else:
            D = scatter_add(hyperedge_weight[hyperedge_index[1]],
                            hyperedge_index[0], dim=0, dim_size=num_nodes)
        D = 1.0 / D
        D[D == float("inf")] = 0

        # B: inverse hyperedge degree.
        B = 1.0 / degree(hyperedge_index[1], num_edges, x.dtype)
        B[B == float("inf")] = 0
        if hyperedge_weight is not None:
            B = B * hyperedge_weight

        # Pre-scale hyperedge features by B, then aggregate into nodes with D.
        out = B.view(-1, 1) * x
        self.flow = 'target_to_source'
        out = self.propagate(hyperedge_index, x=out, norm=D, size=(num_edges, num_nodes))

        if self.bias is not None:
            out = out + self.bias

        return out

    def __repr__(self):
        return "{}({}, {})".format(self.__class__.__name__, self.in_channels, self.out_channels)
128+
129+
130+
class HyperedgePool(MessagePassing):
    """Hyperedge pooling layer.

    Scores hyperedges via K steps of a PageRank-style propagation with
    restart probability ``alpha``, keeps the top ``ratio`` hyperedges per
    graph with ``topk``, and rebuilds a compact incidence matrix over the
    survivors.
    """

    def __init__(self, nhid, ratio):
        super(HyperedgePool, self).__init__()
        # Fraction (or count) of hyperedges kept by topk pooling.
        self.ratio = ratio
        self.nhid = nhid
        # Restart probability of the PageRank-style iteration.
        self.alpha = 0.1
        # Number of propagation steps.
        self.K = 10
        # 1-channel hypergraph convolution producing the initial score.
        self.hypergnn = HypergraphConv(self.nhid, 1)

    def message(self, x_j, edge_index_i, norm):
        # Scale each scalar message by its target's normalization factor.
        out = norm[edge_index_i].view(-1, 1) * x_j.view(-1, 1)

        return out

    def forward(self, x, batch, edge_index, edge_weight):
        """Pool hyperedges and return the reduced hypergraph.

        :param x: node feature matrix
        :param batch: assignment vector passed to ``topk`` over per-hyperedge
            scores — so it appears to assign hyperedges (not nodes) to
            graphs; TODO confirm against callers
        :param edge_index: incidence pairs (row 0 node ids, row 1 hyperedge ids)
        :param edge_weight: optional per-hyperedge weights
        :return: (pooled hyperedge features, new incidence, new weights, new batch)
        """
        # Initial PageRank values from a sigmoid-squashed 1-channel conv.
        pr = torch.sigmoid(self.hypergnn(x, edge_index, edge_weight))

        # D: inverse (weighted) node degree; isolated nodes get 0 instead of inf.
        if edge_weight is None:
            D = degree(edge_index[0], x.size(0), x.dtype)
        else:
            D = scatter_add(edge_weight[edge_index[1]], edge_index[0], dim=0, dim_size=x.size(0))
        D = 1.0 / D
        D[D == float("inf")] = 0

        # B: inverse hyperedge degree; empty incidence yields zero hyperedges.
        if edge_index.numel() == 0:
            num_edges = 0
        else:
            num_edges = edge_index[1].max().item() + 1
        B = 1.0 / degree(edge_index[1], num_edges, x.dtype)
        B[B == float("inf")] = 0
        if edge_weight is not None:
            B = B * edge_weight

        # K propagation steps with restart to the initial scores (hidden).
        hidden = pr
        for k in range(self.K):
            self.flow = 'source_to_target'
            out = self.propagate(edge_index, x=pr, norm=B)
            self.flow = 'target_to_source'
            pr = self.propagate(edge_index, x=out, norm=D)
            pr = pr * (1 - self.alpha)
            pr += self.alpha * hidden

        # Per-hyperedge score = mean of its nodes' PageRank values.
        score = self.calc_hyperedge_score(pr, edge_index)
        score = score.view(-1)
        perm = topk(score, self.ratio, batch)

        # Keep surviving hyperedges' representations, gated by their scores.
        x_hyperedge = hyperedge_representation(x, edge_index)
        x_hyperedge = x_hyperedge[perm] * score[perm].view(-1, 1)
        batch = batch[perm]
        edge_index, edge_attr = self.filter_hyperedge(edge_index, edge_weight, perm, num_nodes=score.size(0))

        return x_hyperedge, edge_index, edge_attr, batch

    def calc_hyperedge_score(self, x, edge_index):
        """Average the per-node values of each hyperedge's members."""
        x = x[edge_index[0]]
        score = scatter(x, edge_index[1], dim=0, reduce='mean')

        return score

    def filter_hyperedge(self, edge_index, edge_attr, perm, num_nodes):
        """Drop incidences whose hyperedge was not kept by ``perm`` and
        re-map the remaining ids to a contiguous range.

        NOTE(review): row (node ids) and col (hyperedge ids) are re-mapped
        through one shared lookup table, which assumes the two id spaces use
        a single common numbering — verify against callers.
        NOTE(review): ``edge_attr[mask]`` indexes with a per-incidence mask;
        if edge_attr is per-hyperedge (as in the conv layers) the lengths
        differ — looks only safe when edge_attr is None; confirm.
        """
        # -1 marks dropped hyperedges; kept ones get their rank within perm.
        mask = perm.new_full((num_nodes, ), -1)
        i = torch.arange(perm.size(0), dtype=torch.long, device=perm.device)
        mask[perm] = i

        # Keep only incidence entries whose hyperedge (col) survived.
        row, col = edge_index
        mask = (mask[col] >= 0)
        row, col = row[mask], col[mask]

        # ID re-mapping operation, which makes the ids become continuous
        unique_row = torch.unique(row)
        unique_col = torch.unique(col)
        combined = torch.cat((unique_row, unique_col))
        uniques, counts = combined.unique(return_counts=True)
        # Ids appearing in exactly one of the two sets (symmetric difference).
        difference = uniques[counts == 1]

        # Surviving hyperedge ids first, then the remaining ids.
        new_perm = torch.cat((unique_col, difference))
        max_id = new_perm.max().item() + 1
        new_mask = new_perm.new_full((max_id,), -1)
        j = torch.arange(new_perm.size(0), dtype=torch.long, device=new_perm.device)
        new_mask[new_perm] = j

        row, col = new_mask[row], new_mask[col]

        if edge_attr is not None:
            edge_attr = edge_attr[mask]

        return torch.stack([row, col], dim=0), edge_attr
218+
219+
220+
class CrossGraphConvolutionOperator(MessagePassing):
    """Cross-graph message-passing operator.

    Aggregates source-graph features into each target node with
    cosine-similarity attention (see ``message``), then returns a
    channel-weighted cosine similarity between each target node and its
    aggregated cross-graph context.
    """

    def __init__(self, out_nhid, in_nhid):
        super(CrossGraphConvolutionOperator, self).__init__('add')
        self.out_nhid = out_nhid
        self.in_nhid = in_nhid
        # Learnable per-channel weighting applied to both operands of the
        # final similarity computation.
        self.weight = torch.nn.Parameter(torch.Tensor(self.out_nhid, self.in_nhid))
        nn.init.xavier_uniform_(self.weight.data)

    def forward(self, x, assign_index, N, M):
        """
        :param x: pair (x_source, x_target) of node feature matrices
        :param assign_index: bipartite edges from source to target nodes
        :param N: number of source nodes
        :param M: number of target nodes
        :return: weighted cosine similarity of each target node against its
            aggregated cross-graph context
        """
        # Attention-weighted aggregate of source features per target node.
        global_x = self.propagate(assign_index, size=(N, M), x=x)
        target_x = x[1]
        # Insert a broadcast dim so the (out_nhid, in_nhid) weight applies
        # row-wise to every node.
        target_x = torch.unsqueeze(target_x, dim=1)
        global_x = torch.unsqueeze(global_x, dim=1)
        weight = torch.unsqueeze(self.weight, dim=0)
        target_x = target_x * weight
        global_x = global_x * weight
        numerator = torch.sum(target_x * global_x, dim=-1)
        # Epsilons guard the norms and the quotient against zero vectors.
        target_x_denominator = torch.sqrt(torch.sum(torch.square(target_x), dim=-1) + 1e-6)
        global_x_denominator = torch.sqrt(torch.sum(torch.square(global_x), dim=-1) + 1e-6)
        denominator = torch.clamp(target_x_denominator * global_x_denominator, min=1e-6)

        return numerator / denominator

    def message(self, x_i, x_j, edge_index):
        """Each neighbor j contributes x_j scaled by its ReLU-clipped cosine
        similarity with target i, normalized over i's incoming edges."""
        x_i_norm = torch.norm(x_i, dim=-1, keepdim=True)
        x_j_norm = torch.norm(x_j, dim=-1, keepdim=True)
        x_norm = torch.clamp(x_i_norm * x_j_norm, min=1e-6)
        x_product = torch.sum(x_i * x_j, dim=1, keepdim=True)
        coef = F.relu(x_product / x_norm)
        # The 1e-6 keeps the per-target normalizer non-zero even when every
        # coefficient is clipped to zero.
        coef_sum = scatter(coef + 1e-6, edge_index[1], dim=0, reduce='sum')
        normalized_coef = coef / coef_sum[edge_index[1]]

        return normalized_coef * x_j
253+
254+
255+
class CrossGraphConvolution(torch.nn.Module):
    """Bidirectional cross-graph convolution between two batched graphs.

    Builds a batched fully-connected bipartite graph (graph i on the left
    is wired only to graph i on the right) and applies the cross-graph
    operator in both directions.
    """

    def __init__(self, out_nhid, in_nhid):
        super(CrossGraphConvolution, self).__init__()
        self.out_nhid = out_nhid
        self.in_nhid = in_nhid
        self.cross_conv = CrossGraphConvolutionOperator(self.out_nhid, self.in_nhid)

    @staticmethod
    def _segment_bounds(x, batch):
        # Per-graph [start, end) node offsets derived from the batch vector.
        counts = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)
        ends = counts.cumsum(dim=0)
        starts = torch.cat([counts.new_zeros(1), ends[:-1]], dim=0)
        return starts, ends

    def forward(self, x_left, batch_left, x_right, batch_right):
        """Cross-convolve the two node sets; returns (out_left, out_right)."""
        left_starts, left_ends = self._segment_bounds(x_left, batch_left)
        right_starts, right_ends = self._segment_bounds(x_right, batch_right)

        # Construct the batch fully connected graph in block-diagonal
        # matrix format, one all-ones block per (left graph, right graph) pair.
        adj = torch.zeros((x_left.size(0), x_right.size(0)), dtype=torch.float, device=x_left.device)
        for r0, r1, c0, c1 in zip(left_starts, left_ends, right_starts, right_ends):
            adj[r0:r1, c0:c1] = 1.0
        new_edge_index, _ = dense_to_sparse(adj)
        row, col = new_edge_index

        # Right -> left direction.
        out_left = self.cross_conv((x_right, x_left), torch.stack([col, row], dim=0),
                                   N=x_right.size(0), M=x_left.size(0))
        # Left -> right direction.
        out_right = self.cross_conv((x_left, x_right), torch.stack([row, col], dim=0),
                                    N=x_left.size(0), M=x_right.size(0))

        return out_left, out_right
285+
286+
287+
class ReadoutModule(torch.nn.Module):
    """Attention-weighted graph readout.

    The global mean of each graph's node embeddings is transformed into a
    gating vector; every node is scored against its own graph's gate with a
    sigmoid, and the scored node features are sum-pooled per graph.
    """

    def __init__(self, args):
        """
        :param args: Arguments object providing ``nhid``.
        """
        super(ReadoutModule, self).__init__()
        self.args = args

        self.weight = torch.nn.Parameter(torch.Tensor(self.args.nhid, self.args.nhid))
        nn.init.xavier_uniform_(self.weight.data)

    def forward(self, x, batch):
        """Create one graph-level representation per graph.

        :param x: node embeddings produced by the GNN
        :param batch: vector assigning each node to its graph
        :return: pooled representation matrix, one row per graph
        """
        gate = torch.tanh(global_mean_pool(x, batch).mm(self.weight))
        node_scores = torch.sigmoid(torch.sum(x * gate[batch], dim=1))
        return global_add_pool(node_scores.unsqueeze(-1) * x, batch)
312+
313+
314+
class MLPModule(torch.nn.Module):
    """Four-layer scoring MLP.

    Maps a (nhid * 8)-dim similarity vector down to a scalar in (0, 1):
    three ReLU+dropout hidden layers followed by a sigmoid output unit.
    """

    def __init__(self, args):
        """
        :param args: Arguments object providing ``nhid`` and ``dropout``.
        """
        super(MLPModule, self).__init__()
        self.args = args

        nhid = self.args.nhid
        shapes = [
            (nhid * 2 * 4, nhid * 2),
            (nhid * 2, nhid),
            (nhid, nhid // 2),
            (nhid // 2, 1),
        ]
        # Build and initialize each layer in turn (construction and Xavier
        # init are interleaved exactly as in a layer-by-layer definition).
        layers = []
        for fan_in, fan_out in shapes:
            layer = torch.nn.Linear(fan_in, fan_out)
            nn.init.xavier_uniform_(layer.weight.data)
            nn.init.zeros_(layer.bias.data)
            layers.append(layer)
        self.lin0, self.lin1, self.lin2, self.lin3 = layers

    def forward(self, scores):
        """Return a flat tensor of sigmoid scores, one per input row."""
        for layer in (self.lin0, self.lin1, self.lin2):
            scores = F.relu(layer(scores))
            scores = F.dropout(scores, p=self.args.dropout, training=self.training)
        return torch.sigmoid(self.lin3(scores)).view(-1)

0 commit comments

Comments
 (0)