26
26
27
27
# PL callbacks
28
28
from pytorch_lightning .callbacks import ModelCheckpoint
29
+ from torch import Tensor
29
30
30
31
AVAIL_GPUS = min (1 , torch .cuda .device_count ())
31
32
BATCH_SIZE = 256 if AVAIL_GPUS else 64
@@ -183,7 +184,7 @@ def forward(self, node_feats, adj_matrix):
183
184
184
185
# %%
185
186
node_feats = torch .arange (8 , dtype = torch .float32 ).view (1 , 4 , 2 )
186
- adj_matrix = torch . Tensor ([[[1 , 1 , 0 , 0 ], [1 , 1 , 1 , 1 ], [0 , 1 , 1 , 1 ], [0 , 1 , 1 , 1 ]]])
187
+ adj_matrix = Tensor ([[[1 , 1 , 0 , 0 ], [1 , 1 , 1 , 1 ], [0 , 1 , 1 , 1 ], [0 , 1 , 1 , 1 ]]])
187
188
188
189
print ("Node features:\n " , node_feats )
189
190
print ("\n Adjacency matrix:\n " , adj_matrix )
@@ -195,8 +196,8 @@ def forward(self, node_feats, adj_matrix):
195
196
196
197
# %%
197
198
layer = GCNLayer (c_in = 2 , c_out = 2 )
198
- layer .projection .weight .data = torch . Tensor ([[1.0 , 0.0 ], [0.0 , 1.0 ]])
199
- layer .projection .bias .data = torch . Tensor ([0.0 , 0.0 ])
199
+ layer .projection .weight .data = Tensor ([[1.0 , 0.0 ], [0.0 , 1.0 ]])
200
+ layer .projection .bias .data = Tensor ([0.0 , 0.0 ])
200
201
201
202
with torch .no_grad ():
202
203
out_feats = layer (node_feats , adj_matrix )
@@ -308,7 +309,7 @@ def __init__(self, c_in, c_out, num_heads=1, concat_heads=True, alpha=0.2):
308
309
309
310
# Sub-modules and parameters needed in the layer
310
311
self .projection = nn .Linear (c_in , c_out * num_heads )
311
- self .a = nn .Parameter (torch . Tensor (num_heads , 2 * c_out )) # One per head
312
+ self .a = nn .Parameter (Tensor (num_heads , 2 * c_out )) # One per head
312
313
self .leakyrelu = nn .LeakyReLU (alpha )
313
314
314
315
# Initialization from the original implementation
@@ -376,9 +377,9 @@ def forward(self, node_feats, adj_matrix, print_attn_probs=False):
376
377
377
378
# %%
378
379
layer = GATLayer (2 , 2 , num_heads = 2 )
379
- layer .projection .weight .data = torch . Tensor ([[1.0 , 0.0 ], [0.0 , 1.0 ]])
380
- layer .projection .bias .data = torch . Tensor ([0.0 , 0.0 ])
381
- layer .a .data = torch . Tensor ([[- 0.2 , 0.3 ], [0.1 , - 0.1 ]])
380
+ layer .projection .weight .data = Tensor ([[1.0 , 0.0 ], [0.0 , 1.0 ]])
381
+ layer .projection .bias .data = Tensor ([0.0 , 0.0 ])
382
+ layer .a .data = Tensor ([[- 0.2 , 0.3 ], [0.1 , - 0.1 ]])
382
383
383
384
with torch .no_grad ():
384
385
out_feats = layer (node_feats , adj_matrix , print_attn_probs = True )
0 commit comments