Skip to content

Commit d6b544d

Browse files
authored
Update gin.py
1 parent 2c5d0f7 commit d6b544d

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

torchdrug/models/gin.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class GraphIsomorphismNetwork(nn.Module, core.Configurable):
2626
batch_norm (bool, optional): apply batch normalization or not
2727
activation (str or function, optional): activation function
2828
concat_hidden (bool, optional): concat hidden representations from all layers as output
29-
readout (str, optional): readout function. Available functions are ``sum`` and ``mean``.
29+
readout (str, optional): readout function. Available functions are ``sum``, ``mean``, and ``max``.
3030
"""
3131

3232
def __init__(self, input_dim=None, hidden_dims=None, edge_input_dim=None, num_mlp_layer=2, eps=0, learn_eps=False,
@@ -52,6 +52,8 @@ def __init__(self, input_dim=None, hidden_dims=None, edge_input_dim=None, num_ml
5252
self.readout = layers.SumReadout()
5353
elif readout == "mean":
5454
self.readout = layers.MeanReadout()
55+
elif readout == "max":
56+
self.readout = layers.MaxReadout()
5557
else:
5658
raise ValueError("Unknown readout `%s`" % readout)
5759

@@ -88,4 +90,4 @@ def forward(self, graph, input, all_loss=None, metric=None):
8890
return {
8991
"graph_feature": graph_feature,
9092
"node_feature": node_feature
91-
}
93+
}

0 commit comments

Comments
 (0)