forked from wengong-jin/icml18-jtnn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnnutils.py
35 lines (28 loc) · 982 Bytes
/
nnutils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
import torch.nn as nn
from torch.autograd import Variable
def create_var(tensor, requires_grad=None):
    """Wrap *tensor* in an autograd ``Variable`` and place it on the GPU if one exists.

    Args:
        tensor: The tensor to wrap.
        requires_grad: Forwarded to ``Variable`` when not ``None``; when
            ``None`` the ``Variable`` constructor's own default is used.

    Returns:
        The wrapped tensor, moved to the current CUDA device when CUDA is
        available and left on its original device otherwise.
    """
    if requires_grad is None:
        var = Variable(tensor)
    else:
        var = Variable(tensor, requires_grad=requires_grad)
    # An unconditional .cuda() raises on CPU-only machines; only move the
    # data when a CUDA device is actually usable.
    if torch.cuda.is_available():
        var = var.cuda()
    return var
def index_select_ND(source, dim, index):
    """Gather slices of *source* along *dim* using an index of any shape.

    The index is flattened, passed to ``source.index_select`` and the result
    is reshaped so that the single axis ``dim`` of *source* is replaced by
    the full shape of *index*.

    Args:
        source: Tensor to select from.
        dim: Axis of *source* to index into; negative values count from the
            last axis.
        index: Integer index tensor of arbitrary shape.

    Returns:
        A tensor of shape
        ``source.shape[:dim] + index.shape + source.shape[dim + 1:]``.
        For ``dim == 0`` this is ``index.shape + source.shape[1:]``.
    """
    if dim < 0:
        dim += source.dim()
    source_size = tuple(source.size())
    # Keep the axes before ``dim`` as well as those after it; dropping only
    # the first axis is correct solely when dim == 0.
    final_size = source_size[:dim] + tuple(index.size()) + source_size[dim + 1:]
    # contiguous() lets a transposed/sliced index be flattened with view().
    flat_index = index.contiguous().view(-1)
    target = source.index_select(dim, flat_index)
    return target.view(final_size)
def GRU(x, h_nei, W_z, W_r, U_r, W_h):
    """Apply one gated update that aggregates a set of neighbor hidden states.

    Args:
        x: 2-D input features, one row per item (it is concatenated with the
            neighbor sum along dim 1).
        h_nei: 3-D neighbor hidden states; dim 1 enumerates the neighbors
            and is summed out.
        W_z: Callable applied to ``cat([x, sum_h])`` to produce the update
            gate pre-activation (presumably an ``nn.Linear`` -- confirm
            against the caller).
        W_r: Callable applied to ``x`` for the reset gate.
        U_r: Callable applied to ``h_nei`` for the reset gate.
        W_h: Callable applied to ``cat([x, sum_gated_h])`` to produce the
            candidate state pre-activation.

    Returns:
        ``(1 - z) * sum_h + z * pre_h``, where ``sum_h`` is the neighbor sum,
        ``z`` the update gate and ``pre_h`` the tanh candidate state.
    """
    sum_h = h_nei.sum(dim=1)

    # Update gate, computed from the input and the summed neighbor states.
    z = torch.sigmoid(W_z(torch.cat([x, sum_h], dim=1)))

    # Reset gate, one per neighbor. Insert the neighbor axis with unsqueeze
    # so the shape follows W_r's output; reshaping with x's feature size
    # breaks (or silently mixes batch rows) when the input and hidden sizes
    # differ.
    r_1 = W_r(x).unsqueeze(1)
    r_2 = U_r(h_nei)
    r = torch.sigmoid(r_1 + r_2)

    # Candidate state from the input and the reset-gated neighbor sum.
    sum_gated_h = (r * h_nei).sum(dim=1)
    pre_h = torch.tanh(W_h(torch.cat([x, sum_gated_h], dim=1)))

    new_h = (1.0 - z) * sum_h + z * pre_h
    return new_h