1
1
from collections .abc import Sequence
2
2
3
3
import torch
4
- from class_resolver import Hint
5
4
from torch import nn
6
5
7
6
from torchdrug import core , layers
8
7
from torchdrug .core import Registry as R
9
- from torchdrug .layers import Readout , readout_resolver
10
8
11
9
12
10
@R .register ("models.GCN" )
@@ -29,7 +27,7 @@ class GraphConvolutionalNetwork(nn.Module, core.Configurable):
29
27
"""
30
28
31
29
def __init__ (self , input_dim , hidden_dims , edge_input_dim = None , short_cut = False , batch_norm = False ,
32
- activation = "relu" , concat_hidden = False , readout : Hint [ Readout ] = "sum" ):
30
+ activation = "relu" , concat_hidden = False , readout = "sum" ):
33
31
super (GraphConvolutionalNetwork , self ).__init__ ()
34
32
35
33
if not isinstance (hidden_dims , Sequence ):
@@ -44,7 +42,14 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, short_cut=False,
44
42
for i in range (len (self .dims ) - 1 ):
45
43
self .layers .append (layers .GraphConv (self .dims [i ], self .dims [i + 1 ], edge_input_dim , batch_norm , activation ))
46
44
47
- self .readout = readout_resolver .make (readout )
45
+ if readout == "sum" :
46
+ self .readout = layers .SumReadout ()
47
+ elif readout == "mean" :
48
+ self .readout = layers .MeanReadout ()
49
+ elif readout == "max" :
50
+ self .readout = layers .MaxReadout ()
51
+ else :
52
+ raise ValueError ("Unknown readout `%s`" % readout )
48
53
49
54
def forward (self , graph , input , all_loss = None , metric = None ):
50
55
"""
@@ -103,7 +108,7 @@ class RelationalGraphConvolutionalNetwork(nn.Module, core.Configurable):
103
108
"""
104
109
105
110
def __init__ (self , input_dim , hidden_dims , num_relation , edge_input_dim = None , short_cut = False , batch_norm = False ,
106
- activation = "relu" , concat_hidden = False , readout : Hint [ Readout ] = "sum" ):
111
+ activation = "relu" , concat_hidden = False , readout = "sum" ):
107
112
super (RelationalGraphConvolutionalNetwork , self ).__init__ ()
108
113
109
114
if not isinstance (hidden_dims , Sequence ):
@@ -120,7 +125,14 @@ def __init__(self, input_dim, hidden_dims, num_relation, edge_input_dim=None, sh
120
125
self .layers .append (layers .RelationalGraphConv (self .dims [i ], self .dims [i + 1 ], num_relation , edge_input_dim ,
121
126
batch_norm , activation ))
122
127
123
- self .readout = readout_resolver .make (readout )
128
+ if readout == "sum" :
129
+ self .readout = layers .SumReadout ()
130
+ elif readout == "mean" :
131
+ self .readout = layers .MeanReadout ()
132
+ elif readout == "max" :
133
+ self .readout = layers .MaxReadout ()
134
+ else :
135
+ raise ValueError ("Unknown readout `%s`" % readout )
124
136
125
137
def forward (self , graph , input , all_loss = None , metric = None ):
126
138
"""
0 commit comments