Skip to content

Commit 4d74e5d

Browse files
committed
Add class-resolver
1 parent d6b544d commit 4d74e5d

File tree

7 files changed

+27
-40
lines changed

7 files changed

+27
-40
lines changed

torchdrug/layers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from .conv import MessagePassingBase, GraphConv, GraphAttentionConv, RelationalGraphConv, GraphIsomorphismConv, \
44
NeuralFingerprintConv, ContinuousFilterConv, MessagePassing, ChebyshevConv
55
from .pool import DiffPool, MinCutPool
6-
from .readout import MeanReadout, SumReadout, MaxReadout, Softmax, Set2Set, Sort
6+
from .readout import MeanReadout, SumReadout, MaxReadout, Softmax, Set2Set, Sort, readout_resolver, Readout
77
from .flow import ConditionalFlow
88
from .sampler import NodeSampler, EdgeSampler
99
from . import distribution, functional

torchdrug/layers/readout.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
import torch
22
from torch import nn
33
from torch_scatter import scatter_mean, scatter_add, scatter_max
4+
from class_resolver import ClassResolver
45

56

6-
class MeanReadout(nn.Module):
7+
class Readout(nn.Module):
8+
"""A base class for readouts."""
9+
10+
11+
class MeanReadout(Readout):
712
"""Mean readout operator over graphs with variadic sizes."""
813

914
def forward(self, graph, input):
@@ -21,7 +26,7 @@ def forward(self, graph, input):
2126
return output
2227

2328

24-
class SumReadout(nn.Module):
29+
class SumReadout(Readout):
2530
"""Sum readout operator over graphs with variadic sizes."""
2631

2732
def forward(self, graph, input):
@@ -39,7 +44,7 @@ def forward(self, graph, input):
3944
return output
4045

4146

42-
class MaxReadout(nn.Module):
47+
class MaxReadout(Readout):
4348
"""Max readout operator over graphs with variadic sizes."""
4449

4550
def forward(self, graph, input):
@@ -57,6 +62,12 @@ def forward(self, graph, input):
5762
return output
5863

5964

65+
readout_resolver = ClassResolver.from_subclasses(
66+
Readout,
67+
default=SumReadout,
68+
)
69+
70+
6071
class Softmax(nn.Module):
6172
"""Softmax operator over graphs with variadic sizes."""
6273

torchdrug/models/gat.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,11 @@ class GraphAttentionNetwork(nn.Module, core.Configurable):
2525
batch_norm (bool, optional): apply batch normalization or not
2626
activation (str or function, optional): activation function
2727
concat_hidden (bool, optional): concat hidden representations from all layers as output
28-
readout (str, optional): readout function. Available functions are ``sum``, ``mean``, and ``max``.
28+
readout: readout function. Available functions are ``sum``, ``mean``, and ``max``.
2929
"""
3030

3131
def __init__(self, input_dim, hidden_dims, edge_input_dim=None, num_head=1, negative_slope=0.2, short_cut=False,
32-
batch_norm=False, activation="relu", concat_hidden=False, readout="sum"):
32+
batch_norm=False, activation="relu", concat_hidden=False, readout: Hint[Readout] = "sum"):
3333
super(GraphAttentionNetwork, self).__init__()
3434

3535
if not isinstance(hidden_dims, Sequence):
@@ -45,14 +45,7 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, num_head=1, nega
4545
self.layers.append(layers.GraphAttentionConv(self.dims[i], self.dims[i + 1], edge_input_dim, num_head,
4646
negative_slope, batch_norm, activation))
4747

48-
if readout == "sum":
49-
self.readout = layers.SumReadout()
50-
elif readout == "mean":
51-
self.readout = layers.MeanReadout()
52-
elif readout == "max":
53-
self.readout = layers.MaxReadout()
54-
else:
55-
raise ValueError("Unknown readout `%s`" % readout)
48+
self.readout = readout_resolver.make(readout)
5649

5750
def forward(self, graph, input, all_loss=None, metric=None):
5851
"""

torchdrug/models/gcn.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from collections.abc import Sequence
22

33
import torch
4+
from class_resolver import Hint
45
from torch import nn
56

67
from torchdrug import core, layers
8+
from torchdrug.layers import readout_resolver, Readout
79
from torchdrug.core import Registry as R
810

911

@@ -23,11 +25,11 @@ class GraphConvolutionalNetwork(nn.Module, core.Configurable):
2325
batch_norm (bool, optional): apply batch normalization or not
2426
activation (str or function, optional): activation function
2527
concat_hidden (bool, optional): concat hidden representations from all layers as output
26-
readout (str, optional): readout function. Available functions are ``sum``, ``mean``, and ``max``.
28+
readout: readout function. Available functions are ``sum``, ``mean``, and ``max``.
2729
"""
2830

2931
def __init__(self, input_dim, hidden_dims, edge_input_dim=None, short_cut=False, batch_norm=False,
30-
activation="relu", concat_hidden=False, readout="sum"):
32+
activation="relu", concat_hidden=False, readout: Hint[Readout] = "sum"):
3133
super(GraphConvolutionalNetwork, self).__init__()
3234

3335
if not isinstance(hidden_dims, Sequence):
@@ -42,14 +44,7 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, short_cut=False,
4244
for i in range(len(self.dims) - 1):
4345
self.layers.append(layers.GraphConv(self.dims[i], self.dims[i + 1], edge_input_dim, batch_norm, activation))
4446

45-
if readout == "sum":
46-
self.readout = layers.SumReadout()
47-
elif readout == "mean":
48-
self.readout = layers.MeanReadout()
49-
elif readout == "max":
50-
self.readout = layers.MaxReadout()
51-
else:
52-
raise ValueError("Unknown readout `%s`" % readout)
47+
self.readout = readout_resolver.make(readout)
5348

5449
def forward(self, graph, input, all_loss=None, metric=None):
5550
"""

torchdrug/models/gin.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,7 @@ def __init__(self, input_dim=None, hidden_dims=None, edge_input_dim=None, num_ml
4848
self.layers.append(layers.GraphIsomorphismConv(self.dims[i], self.dims[i + 1], edge_input_dim,
4949
layer_hidden_dims, eps, learn_eps, batch_norm, activation))
5050

51-
if readout == "sum":
52-
self.readout = layers.SumReadout()
53-
elif readout == "mean":
54-
self.readout = layers.MeanReadout()
55-
elif readout == "max":
56-
self.readout = layers.MaxReadout()
57-
else:
58-
raise ValueError("Unknown readout `%s`" % readout)
51+
self.readout = readout_resolver.make(readout)
5952

6053
def forward(self, graph, input, all_loss=None, metric=None):
6154
"""

torchdrug/models/neuralfp.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,7 @@ def __init__(self, input_dim, output_dim, hidden_dims, edge_input_dim=None, shor
4747
batch_norm, activation))
4848
self.linears.append(nn.Linear(self.dims[i + 1], output_dim))
4949

50-
if readout == "sum":
51-
self.readout = layers.SumReadout()
52-
elif readout == "mean":
53-
self.readout = layers.MeanReadout()
54-
else:
55-
raise ValueError("Unknown readout `%s`" % readout)
50+
self.readout = readout_resolver.make(readout)
5651

5752
def forward(self, graph, input, all_loss=None, metric=None):
5853
"""

torchdrug/models/schnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class SchNet(nn.Module, core.Configurable):
2828
"""
2929

3030
def __init__(self, input_dim, hidden_dims, edge_input_dim=None, cutoff=5, num_gaussian=100, short_cut=True,
31-
batch_norm=False, activation="shifted_softplus", concat_hidden=False):
31+
batch_norm=False, activation="shifted_softplus", concat_hidden=False, readout: Hint[Readout] = "sum"):
3232
super(SchNet, self).__init__()
3333

3434
if not isinstance(hidden_dims, Sequence):
@@ -44,7 +44,7 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, cutoff=5, num_ga
4444
self.layers.append(layers.ContinuousFilterConv(self.dims[i], self.dims[i + 1], edge_input_dim, None, cutoff,
4545
num_gaussian, batch_norm, activation))
4646

47-
self.readout = layers.SumReadout()
47+
self.readout = readout_resolver.make(readout)
4848

4949
def forward(self, graph, input, all_loss=None, metric=None):
5050
"""

0 commit comments

Comments (0)