Skip to content

Commit 50adc02

Browse files
committed
Cleanup
1 parent 4d74e5d commit 50adc02

File tree

10 files changed

+32
-33
lines changed

10 files changed

+32
-33
lines changed

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@ matplotlib
77
tqdm
88
networkx
99
ninja
10-
jinja2
10+
jinja2
11+
class-resolver

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
"networkx",
4040
"ninja",
4141
"jinja2",
42+
"class-resolver",
4243
],
4344
python_requires=">=3.7,<3.9",
4445
classifiers=[

torchdrug/layers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
"MessagePassingBase", "GraphConv", "GraphAttentionConv", "RelationalGraphConv", "GraphIsomorphismConv",
2424
"NeuralFingerprintConv", "ContinuousFilterConv", "MessagePassing", "ChebyshevConv",
2525
"DiffPool", "MinCutPool",
26-
"MeanReadout", "SumReadout", "MaxReadout", "Softmax", "Set2Set", "Sort",
26+
"MeanReadout", "SumReadout", "MaxReadout", "Softmax", "Set2Set", "Sort", "readout_resolver", "Readout",
2727
"ConditionalFlow",
2828
"NodeSampler", "EdgeSampler",
2929
"distribution", "functional",

torchdrug/models/chebnet.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from collections.abc import Sequence
22

33
import torch
4+
from class_resolver import Hint
45
from torch import nn
56

67
from torchdrug import core, layers
78
from torchdrug.core import Registry as R
9+
from torchdrug.layers import Readout, readout_resolver
810

911

1012
@R.register("models.ChebNet")
@@ -25,11 +27,11 @@ class ChebyshevConvolutionalNetwork(nn.Module, core.Configurable):
2527
batch_norm (bool, optional): apply batch normalization or not
2628
activation (str or function, optional): activation function
2729
concat_hidden (bool, optional): concat hidden representations from all layers as output
28-
readout (str, optional): readout function. Available functions are ``sum`` and ``mean``.
30+
readout: readout function. Available functions are ``sum`` and ``mean``.
2931
"""
3032

3133
def __init__(self, input_dim, hidden_dims, edge_input_dim=None, k=1, short_cut=False, batch_norm=False,
32-
activation="relu", concat_hidden=False, readout="sum"):
34+
activation="relu", concat_hidden=False, readout: Hint[Readout] = "sum"):
3335
super(ChebyshevConvolutionalNetwork, self).__init__()
3436

3537
if not isinstance(hidden_dims, Sequence):
@@ -45,12 +47,7 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, k=1, short_cut=F
4547
self.layers.append(layers.ChebyshevConv(self.dims[i], self.dims[i + 1], edge_input_dim, k,
4648
batch_norm, activation))
4749

48-
if readout == "sum":
49-
self.readout = layers.SumReadout()
50-
elif readout == "mean":
51-
self.readout = layers.MeanReadout()
52-
else:
53-
raise ValueError("Unknown readout `%s`" % readout)
50+
self.readout = readout_resolver.make(readout)
5451

5552
def forward(self, graph, input, all_loss=None, metric=None):
5653
"""

torchdrug/models/gat.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from collections.abc import Sequence
22

33
import torch
4+
from class_resolver import Hint
45
from torch import nn
56

67
from torchdrug import core, layers
78
from torchdrug.core import Registry as R
9+
from torchdrug.layers import Readout, readout_resolver
810

911

1012
@R.register("models.GAT")

torchdrug/models/gcn.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
from torch import nn
66

77
from torchdrug import core, layers
8-
from torchdrug.layers import readout_resolver, Readout
98
from torchdrug.core import Registry as R
9+
from torchdrug.layers import Readout, readout_resolver
1010

1111

1212
@R.register("models.GCN")
@@ -99,11 +99,11 @@ class RelationalGraphConvolutionalNetwork(nn.Module, core.Configurable):
9999
batch_norm (bool, optional): apply batch normalization or not
100100
activation (str or function, optional): activation function
101101
concat_hidden (bool, optional): concat hidden representations from all layers as output
102-
readout (str, optional): readout function. Available functions are ``sum``, ``mean``, and ``max``.
102+
readout: readout function. Available functions are ``sum``, ``mean``, and ``max``.
103103
"""
104104

105105
def __init__(self, input_dim, hidden_dims, num_relation, edge_input_dim=None, short_cut=False, batch_norm=False,
106-
activation="relu", concat_hidden=False, readout="sum"):
106+
activation="relu", concat_hidden=False, readout: Hint[Readout] = "sum"):
107107
super(RelationalGraphConvolutionalNetwork, self).__init__()
108108

109109
if not isinstance(hidden_dims, Sequence):
@@ -120,14 +120,7 @@ def __init__(self, input_dim, hidden_dims, num_relation, edge_input_dim=None, sh
120120
self.layers.append(layers.RelationalGraphConv(self.dims[i], self.dims[i + 1], num_relation, edge_input_dim,
121121
batch_norm, activation))
122122

123-
if readout == "sum":
124-
self.readout = layers.SumReadout()
125-
elif readout == "mean":
126-
self.readout = layers.MeanReadout()
127-
elif readout == "max":
128-
self.readout = layers.MaxReadout()
129-
else:
130-
raise ValueError("Unknown readout `%s`" % readout)
123+
self.readout = readout_resolver.make(readout)
131124

132125
def forward(self, graph, input, all_loss=None, metric=None):
133126
"""

torchdrug/models/gin.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from collections.abc import Sequence
22

33
import torch
4+
from class_resolver import Hint
45
from torch import nn
56

67
from torchdrug import core, layers
78
from torchdrug.core import Registry as R
9+
from torchdrug.layers import Readout, readout_resolver
810

911

1012
@R.register("models.GIN")
@@ -26,12 +28,12 @@ class GraphIsomorphismNetwork(nn.Module, core.Configurable):
2628
batch_norm (bool, optional): apply batch normalization or not
2729
activation (str or function, optional): activation function
2830
concat_hidden (bool, optional): concat hidden representations from all layers as output
29-
readout (str, optional): readout function. Available functions are ``sum``, ``mean``, and ``max``.
31+
readout: readout function. Available functions are ``sum``, ``mean``, and ``max``.
3032
"""
3133

3234
def __init__(self, input_dim=None, hidden_dims=None, edge_input_dim=None, num_mlp_layer=2, eps=0, learn_eps=False,
3335
short_cut=False, batch_norm=False, activation="relu", concat_hidden=False,
34-
readout="sum"):
36+
readout: Hint[Readout] = "sum"):
3537
super(GraphIsomorphismNetwork, self).__init__()
3638

3739
if not isinstance(hidden_dims, Sequence):

torchdrug/models/neuralfp.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from collections.abc import Sequence
22

33
import torch
4+
from class_resolver import Hint
45
from torch import nn
56
from torch.nn import functional as F
67

78
from torchdrug import core, layers
89
from torchdrug.core import Registry as R
10+
from torchdrug.layers import Readout, readout_resolver
911

1012

1113
@R.register("models.NeuralFP")
@@ -25,11 +27,11 @@ class NeuralFingerprint(nn.Module, core.Configurable):
2527
batch_norm (bool, optional): apply batch normalization or not
2628
activation (str or function, optional): activation function
2729
concat_hidden (bool, optional): concat hidden representations from all layers as output
28-
readout (str, optional): readout function. Available functions are ``sum`` and ``mean``.
30+
readout: readout function. Available functions are ``sum``, ``mean``, and ``max``.
2931
"""
3032

3133
def __init__(self, input_dim, output_dim, hidden_dims, edge_input_dim=None, short_cut=False, batch_norm=False,
32-
activation="relu", concat_hidden=False, readout="sum"):
34+
activation="relu", concat_hidden=False, readout: Hint[Readout] = "sum"):
3335
super(NeuralFingerprint, self).__init__()
3436

3537
if not isinstance(hidden_dims, Sequence):

torchdrug/models/schnet.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from collections.abc import Sequence
22

33
import torch
4+
from class_resolver import Hint
45
from torch import nn
56

67
from torchdrug import core, layers
78
from torchdrug.core import Registry as R
9+
from torchdrug.layers import Readout, readout_resolver
810

911

1012
@R.register("models.SchNet")
@@ -25,6 +27,7 @@ class SchNet(nn.Module, core.Configurable):
2527
batch_norm (bool, optional): apply batch normalization or not
2628
activation (str or function, optional): activation function
2729
concat_hidden (bool, optional): concat hidden representations from all layers as output
30+
readout: readout function. Available functions are ``sum``, ``mean``, and ``max``.
2831
"""
2932

3033
def __init__(self, input_dim, hidden_dims, edge_input_dim=None, cutoff=5, num_gaussian=100, short_cut=True,

torchdrug/tasks/pretrain.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import copy
22

33
import torch
4+
from class_resolver import Hint
45
from torch import nn
56
from torch.nn import functional as F
67
from torch_scatter import scatter_max, scatter_min
78

89
from torchdrug import core, tasks, layers
910
from torchdrug.data import constant
10-
from torchdrug.layers import functional
11+
from torchdrug.layers import functional, readout_resolver, Readout
1112
from torchdrug.core import Registry as R
1213

1314

@@ -169,9 +170,10 @@ class ContextPrediction(tasks.Task, core.Configurable):
169170
r2 (int, optional): outer radius for context graphs
170171
readout (nn.Module, optional): readout function over context anchor nodes
171172
num_negative (int, optional): number of negative samples per positive sample
173+
readout: readout function. Available functions are ``sum``, ``mean``, and ``max``.
172174
"""
173175

174-
def __init__(self, model, context_model=None, k=5, r1=4, r2=7, readout="mean", num_negative=1):
176+
def __init__(self, model, context_model=None, k=5, r1=4, r2=7, readout: Hint[Readout] = "mean", num_negative=1):
175177
super(ContextPrediction, self).__init__()
176178
self.model = model
177179
self.k = k
@@ -184,12 +186,8 @@ def __init__(self, model, context_model=None, k=5, r1=4, r2=7, readout="mean", n
184186
self.context_model = copy.deepcopy(model)
185187
else:
186188
self.context_model = context_model
187-
if readout == "sum":
188-
self.readout = layers.SumReadout()
189-
elif readout == "mean":
190-
self.readout = layers.MeanReadout()
191-
else:
192-
raise ValueError("Unknown readout `%s`" % readout)
189+
190+
self.readout = readout_resolver.make(readout)
193191

194192
def substruct_and_context(self, graph):
195193
center_index = (torch.rand(len(graph), device=self.device) * graph.num_nodes).long()

0 commit comments

Comments
 (0)