Skip to content

Commit b62c14f

Browse files
committed
revert to naive branching resolution
1 parent 50adc02 commit b62c14f

File tree

11 files changed

+79
-50
lines changed

11 files changed

+79
-50
lines changed

requirements.txt

+1-2
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,4 @@ matplotlib
77
tqdm
88
networkx
99
ninja
10-
jinja2
11-
class-resolver
10+
jinja2

setup.py

-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
"networkx",
4040
"ninja",
4141
"jinja2",
42-
"class-resolver",
4342
],
4443
python_requires=">=3.7,<3.9",
4544
classifiers=[

torchdrug/layers/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from .conv import MessagePassingBase, GraphConv, GraphAttentionConv, RelationalGraphConv, GraphIsomorphismConv, \
44
NeuralFingerprintConv, ContinuousFilterConv, MessagePassing, ChebyshevConv
55
from .pool import DiffPool, MinCutPool
6-
from .readout import MeanReadout, SumReadout, MaxReadout, Softmax, Set2Set, Sort, readout_resolver, Readout
6+
from .readout import MeanReadout, SumReadout, MaxReadout, Softmax, Set2Set, Sort
77
from .flow import ConditionalFlow
88
from .sampler import NodeSampler, EdgeSampler
99
from . import distribution, functional
@@ -23,7 +23,7 @@
2323
"MessagePassingBase", "GraphConv", "GraphAttentionConv", "RelationalGraphConv", "GraphIsomorphismConv",
2424
"NeuralFingerprintConv", "ContinuousFilterConv", "MessagePassing", "ChebyshevConv",
2525
"DiffPool", "MinCutPool",
26-
"MeanReadout", "SumReadout", "MaxReadout", "Softmax", "Set2Set", "Sort", "readout_resolver", "Readout",
26+
"MeanReadout", "SumReadout", "MaxReadout", "Softmax", "Set2Set", "Sort",
2727
"ConditionalFlow",
2828
"NodeSampler", "EdgeSampler",
2929
"distribution", "functional",

torchdrug/layers/readout.py

+3-14
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,9 @@
11
import torch
22
from torch import nn
33
from torch_scatter import scatter_mean, scatter_add, scatter_max
4-
from class_resolver import ClassResolver
54

65

7-
class Readout(nn.Module):
8-
"""A base class for readouts."""
9-
10-
11-
class MeanReadout(Readout):
6+
class MeanReadout(nn.Module):
127
"""Mean readout operator over graphs with variadic sizes."""
138

149
def forward(self, graph, input):
@@ -26,7 +21,7 @@ def forward(self, graph, input):
2621
return output
2722

2823

29-
class SumReadout(Readout):
24+
class SumReadout(nn.Module):
3025
"""Sum readout operator over graphs with variadic sizes."""
3126

3227
def forward(self, graph, input):
@@ -44,7 +39,7 @@ def forward(self, graph, input):
4439
return output
4540

4641

47-
class MaxReadout(Readout):
42+
class MaxReadout(nn.Module):
4843
"""Max readout operator over graphs with variadic sizes."""
4944

5045
def forward(self, graph, input):
@@ -62,12 +57,6 @@ def forward(self, graph, input):
6257
return output
6358

6459

65-
readout_resolver = ClassResolver.from_subclasses(
66-
Readout,
67-
default=SumReadout,
68-
)
69-
70-
7160
class Softmax(nn.Module):
7261
"""Softmax operator over graphs with variadic sizes."""
7362

torchdrug/models/chebnet.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
from collections.abc import Sequence
22

33
import torch
4-
from class_resolver import Hint
54
from torch import nn
65

76
from torchdrug import core, layers
87
from torchdrug.core import Registry as R
9-
from torchdrug.layers import Readout, readout_resolver
108

119

1210
@R.register("models.ChebNet")
@@ -31,7 +29,7 @@ class ChebyshevConvolutionalNetwork(nn.Module, core.Configurable):
3129
"""
3230

3331
def __init__(self, input_dim, hidden_dims, edge_input_dim=None, k=1, short_cut=False, batch_norm=False,
34-
activation="relu", concat_hidden=False, readout: Hint[Readout] = "sum"):
32+
activation="relu", concat_hidden=False, readout="sum"):
3533
super(ChebyshevConvolutionalNetwork, self).__init__()
3634

3735
if not isinstance(hidden_dims, Sequence):
@@ -47,7 +45,14 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, k=1, short_cut=F
4745
self.layers.append(layers.ChebyshevConv(self.dims[i], self.dims[i + 1], edge_input_dim, k,
4846
batch_norm, activation))
4947

50-
self.readout = readout_resolver.make(readout)
48+
if readout == "sum":
49+
self.readout = layers.SumReadout()
50+
elif readout == "mean":
51+
self.readout = layers.MeanReadout()
52+
elif readout == "max":
53+
self.readout = layers.MaxReadout()
54+
else:
55+
raise ValueError("Unknown readout `%s`" % readout)
5156

5257
def forward(self, graph, input, all_loss=None, metric=None):
5358
"""

torchdrug/models/gat.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
from collections.abc import Sequence
22

33
import torch
4-
from class_resolver import Hint
54
from torch import nn
65

76
from torchdrug import core, layers
87
from torchdrug.core import Registry as R
9-
from torchdrug.layers import Readout, readout_resolver
108

119

1210
@R.register("models.GAT")
@@ -31,7 +29,7 @@ class GraphAttentionNetwork(nn.Module, core.Configurable):
3129
"""
3230

3331
def __init__(self, input_dim, hidden_dims, edge_input_dim=None, num_head=1, negative_slope=0.2, short_cut=False,
34-
batch_norm=False, activation="relu", concat_hidden=False, readout: Hint[Readout] = "sum"):
32+
batch_norm=False, activation="relu", concat_hidden=False, readout="sum"):
3533
super(GraphAttentionNetwork, self).__init__()
3634

3735
if not isinstance(hidden_dims, Sequence):
@@ -47,7 +45,14 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, num_head=1, nega
4745
self.layers.append(layers.GraphAttentionConv(self.dims[i], self.dims[i + 1], edge_input_dim, num_head,
4846
negative_slope, batch_norm, activation))
4947

50-
self.readout = readout_resolver.make(readout)
48+
if readout == "sum":
49+
self.readout = layers.SumReadout()
50+
elif readout == "mean":
51+
self.readout = layers.MeanReadout()
52+
elif readout == "max":
53+
self.readout = layers.MaxReadout()
54+
else:
55+
raise ValueError("Unknown readout `%s`" % readout)
5156

5257
def forward(self, graph, input, all_loss=None, metric=None):
5358
"""

torchdrug/models/gcn.py

+18-6
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
from collections.abc import Sequence
22

33
import torch
4-
from class_resolver import Hint
54
from torch import nn
65

76
from torchdrug import core, layers
87
from torchdrug.core import Registry as R
9-
from torchdrug.layers import Readout, readout_resolver
108

119

1210
@R.register("models.GCN")
@@ -29,7 +27,7 @@ class GraphConvolutionalNetwork(nn.Module, core.Configurable):
2927
"""
3028

3129
def __init__(self, input_dim, hidden_dims, edge_input_dim=None, short_cut=False, batch_norm=False,
32-
activation="relu", concat_hidden=False, readout: Hint[Readout] = "sum"):
30+
activation="relu", concat_hidden=False, readout="sum"):
3331
super(GraphConvolutionalNetwork, self).__init__()
3432

3533
if not isinstance(hidden_dims, Sequence):
@@ -44,7 +42,14 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, short_cut=False,
4442
for i in range(len(self.dims) - 1):
4543
self.layers.append(layers.GraphConv(self.dims[i], self.dims[i + 1], edge_input_dim, batch_norm, activation))
4644

47-
self.readout = readout_resolver.make(readout)
45+
if readout == "sum":
46+
self.readout = layers.SumReadout()
47+
elif readout == "mean":
48+
self.readout = layers.MeanReadout()
49+
elif readout == "max":
50+
self.readout = layers.MaxReadout()
51+
else:
52+
raise ValueError("Unknown readout `%s`" % readout)
4853

4954
def forward(self, graph, input, all_loss=None, metric=None):
5055
"""
@@ -103,7 +108,7 @@ class RelationalGraphConvolutionalNetwork(nn.Module, core.Configurable):
103108
"""
104109

105110
def __init__(self, input_dim, hidden_dims, num_relation, edge_input_dim=None, short_cut=False, batch_norm=False,
106-
activation="relu", concat_hidden=False, readout: Hint[Readout] = "sum"):
111+
activation="relu", concat_hidden=False, readout="sum"):
107112
super(RelationalGraphConvolutionalNetwork, self).__init__()
108113

109114
if not isinstance(hidden_dims, Sequence):
@@ -120,7 +125,14 @@ def __init__(self, input_dim, hidden_dims, num_relation, edge_input_dim=None, sh
120125
self.layers.append(layers.RelationalGraphConv(self.dims[i], self.dims[i + 1], num_relation, edge_input_dim,
121126
batch_norm, activation))
122127

123-
self.readout = readout_resolver.make(readout)
128+
if readout == "sum":
129+
self.readout = layers.SumReadout()
130+
elif readout == "mean":
131+
self.readout = layers.MeanReadout()
132+
elif readout == "max":
133+
self.readout = layers.MaxReadout()
134+
else:
135+
raise ValueError("Unknown readout `%s`" % readout)
124136

125137
def forward(self, graph, input, all_loss=None, metric=None):
126138
"""

torchdrug/models/gin.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
from collections.abc import Sequence
22

33
import torch
4-
from class_resolver import Hint
54
from torch import nn
65

76
from torchdrug import core, layers
87
from torchdrug.core import Registry as R
9-
from torchdrug.layers import Readout, readout_resolver
108

119

1210
@R.register("models.GIN")
@@ -32,8 +30,7 @@ class GraphIsomorphismNetwork(nn.Module, core.Configurable):
3230
"""
3331

3432
def __init__(self, input_dim=None, hidden_dims=None, edge_input_dim=None, num_mlp_layer=2, eps=0, learn_eps=False,
35-
short_cut=False, batch_norm=False, activation="relu", concat_hidden=False,
36-
readout: Hint[Readout] = "sum"):
33+
short_cut=False, batch_norm=False, activation="relu", concat_hidden=False, readout="sum"):
3734
super(GraphIsomorphismNetwork, self).__init__()
3835

3936
if not isinstance(hidden_dims, Sequence):
@@ -50,7 +47,14 @@ def __init__(self, input_dim=None, hidden_dims=None, edge_input_dim=None, num_ml
5047
self.layers.append(layers.GraphIsomorphismConv(self.dims[i], self.dims[i + 1], edge_input_dim,
5148
layer_hidden_dims, eps, learn_eps, batch_norm, activation))
5249

53-
self.readout = readout_resolver.make(readout)
50+
if readout == "sum":
51+
self.readout = layers.SumReadout()
52+
elif readout == "mean":
53+
self.readout = layers.MeanReadout()
54+
elif readout == "max":
55+
self.readout = layers.MaxReadout()
56+
else:
57+
raise ValueError("Unknown readout `%s`" % readout)
5458

5559
def forward(self, graph, input, all_loss=None, metric=None):
5660
"""

torchdrug/models/neuralfp.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
from collections.abc import Sequence
22

33
import torch
4-
from class_resolver import Hint
54
from torch import nn
65
from torch.nn import functional as F
76

87
from torchdrug import core, layers
98
from torchdrug.core import Registry as R
10-
from torchdrug.layers import Readout, readout_resolver
119

1210

1311
@R.register("models.NeuralFP")
@@ -31,7 +29,7 @@ class NeuralFingerprint(nn.Module, core.Configurable):
3129
"""
3230

3331
def __init__(self, input_dim, output_dim, hidden_dims, edge_input_dim=None, short_cut=False, batch_norm=False,
34-
activation="relu", concat_hidden=False, readout: Hint[Readout] = "sum"):
32+
activation="relu", concat_hidden=False, readout="sum"):
3533
super(NeuralFingerprint, self).__init__()
3634

3735
if not isinstance(hidden_dims, Sequence):
@@ -49,7 +47,14 @@ def __init__(self, input_dim, output_dim, hidden_dims, edge_input_dim=None, shor
4947
batch_norm, activation))
5048
self.linears.append(nn.Linear(self.dims[i + 1], output_dim))
5149

52-
self.readout = readout_resolver.make(readout)
50+
if readout == "sum":
51+
self.readout = layers.SumReadout()
52+
elif readout == "mean":
53+
self.readout = layers.MeanReadout()
54+
elif readout == "max":
55+
self.readout = layers.MaxReadout()
56+
else:
57+
raise ValueError("Unknown readout `%s`" % readout)
5358

5459
def forward(self, graph, input, all_loss=None, metric=None):
5560
"""

torchdrug/models/schnet.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
from collections.abc import Sequence
22

33
import torch
4-
from class_resolver import Hint
54
from torch import nn
65

76
from torchdrug import core, layers
87
from torchdrug.core import Registry as R
9-
from torchdrug.layers import Readout, readout_resolver
108

119

1210
@R.register("models.SchNet")
@@ -31,7 +29,7 @@ class SchNet(nn.Module, core.Configurable):
3129
"""
3230

3331
def __init__(self, input_dim, hidden_dims, edge_input_dim=None, cutoff=5, num_gaussian=100, short_cut=True,
34-
batch_norm=False, activation="shifted_softplus", concat_hidden=False, readout: Hint[Readout] = "sum"):
32+
batch_norm=False, activation="shifted_softplus", concat_hidden=False, readout="sum"):
3533
super(SchNet, self).__init__()
3634

3735
if not isinstance(hidden_dims, Sequence):
@@ -47,7 +45,14 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, cutoff=5, num_ga
4745
self.layers.append(layers.ContinuousFilterConv(self.dims[i], self.dims[i + 1], edge_input_dim, None, cutoff,
4846
num_gaussian, batch_norm, activation))
4947

50-
self.readout = readout_resolver.make(readout)
48+
if readout == "sum":
49+
self.readout = layers.SumReadout()
50+
elif readout == "mean":
51+
self.readout = layers.MeanReadout()
52+
elif readout == "max":
53+
self.readout = layers.MaxReadout()
54+
else:
55+
raise ValueError("Unknown readout `%s`" % readout)
5156

5257
def forward(self, graph, input, all_loss=None, metric=None):
5358
"""

torchdrug/tasks/pretrain.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
import copy
22

33
import torch
4-
from class_resolver import Hint
54
from torch import nn
65
from torch.nn import functional as F
76
from torch_scatter import scatter_max, scatter_min
87

98
from torchdrug import core, tasks, layers
109
from torchdrug.data import constant
11-
from torchdrug.layers import functional, readout_resolver, Readout
10+
from torchdrug.layers import functional
1211
from torchdrug.core import Registry as R
1312

1413

@@ -173,7 +172,7 @@ class ContextPrediction(tasks.Task, core.Configurable):
173172
readout: readout function. Available functions are ``sum``, ``mean``, and ``max``.
174173
"""
175174

176-
def __init__(self, model, context_model=None, k=5, r1=4, r2=7, readout: Hint[Readout] = "mean", num_negative=1):
175+
def __init__(self, model, context_model=None, k=5, r1=4, r2=7, readout="mean", num_negative=1):
177176
super(ContextPrediction, self).__init__()
178177
self.model = model
179178
self.k = k
@@ -187,7 +186,14 @@ def __init__(self, model, context_model=None, k=5, r1=4, r2=7, readout: Hint[Rea
187186
else:
188187
self.context_model = context_model
189188

190-
self.readout = readout_resolver.make(readout)
189+
if readout == "sum":
190+
self.readout = layers.SumReadout()
191+
elif readout == "mean":
192+
self.readout = layers.MeanReadout()
193+
elif readout == "max":
194+
self.readout = layers.MaxReadout()
195+
else:
196+
raise ValueError("Unknown readout `%s`" % readout)
191197

192198
def substruct_and_context(self, graph):
193199
center_index = (torch.rand(len(graph), device=self.device) * graph.num_nodes).long()

0 commit comments

Comments
 (0)