
Commit 7392070: release v0.1.3
Parent: cb117f8

6 files changed: +49, -43 lines

conda/torchdrug/meta.yaml (+4, -4)
@@ -1,21 +1,21 @@
 package:
   name: torchdrug
-  version: 0.1.2
+  version: 0.1.3
 
 source:
   path: ../..
 
 requirements:
   host:
-    - python >=3.7,<3.9
+    - python >=3.7,<3.10
     - pip
   run:
-    - python >=3.7,<3.9
+    - python >=3.7,<3.10
     - pytorch >=1.8.0
     - pytorch-scatter >=2.0.8
     - decorator
     - numpy >=1.11
-    - rdkit
+    - rdkit >=2020.09
     - matplotlib
     - tqdm
     - networkx
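
Note on the recipe change: both the host and run requirements now allow Python 3.9 (the upper bound moves from <3.9 to <3.10), and rdkit gains an explicit lower bound of 2020.09. A minimal sketch for checking that an installed environment satisfies the new rdkit pin (assumes rdkit is importable; not part of this commit):

    # Print the RDKit version string to confirm it meets the `rdkit >=2020.09` pin.
    from rdkit import rdBase
    print(rdBase.rdkitVersion)  # expect "2020.09.x" or newer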

setup.py (+17, -17)
@@ -13,7 +13,7 @@
     long_description_content_type="text/markdown",
     url="https://torchdrug.ai/",
     author="TorchDrug Team",
-    version="0.1.2",
+    version="0.1.3",
     license="Apache-2.0",
     keywords=["deep-learning", "pytorch", "drug-discovery"],
     packages=setuptools.find_packages(),
@@ -24,23 +24,23 @@
         "layers/functional/extension/*.cpp",
         "layers/functional/extension/*.cu",
         "utils/extension/*.cpp",
-        "utils/template/*.html"
-    ]},
+        "utils/template/*.html",
+    ]
+    },
     test_suite="nose.collector",
-    install_requires=
-    [
-        "torch>=1.8.0",
-        "torch-scatter>=2.0.8",
-        "decorator",
-        "numpy>=1.11",
-        "rdkit-pypi",
-        "matplotlib",
-        "tqdm",
-        "networkx",
-        "ninja",
-        "jinja2",
-    ],
-    python_requires=">=3.7,<3.9",
+    install_requires=[
+        "torch>=1.8.0",
+        "torch-scatter>=2.0.8",
+        "decorator",
+        "numpy>=1.11",
+        "rdkit-pypi>=2020.9",
+        "matplotlib",
+        "tqdm",
+        "networkx",
+        "ninja",
+        "jinja2",
+    ],
+    python_requires=">=3.7,<3.10",
     classifiers=[
         "Development Status :: 4 - Beta",
         'Intended Audience :: Developers',
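
Besides the version bump, these hunks tidy the package_data and install_requires formatting, pin rdkit-pypi to >=2020.9, and widen python_requires from <3.9 to <3.10, matching the conda recipe above. A minimal sketch for verifying a local interpreter against the new constraint (the assertion bounds mirror python_requires; not part of this commit):

    # Check that the running interpreter falls inside ">=3.7,<3.10".
    import sys
    assert (3, 7) <= sys.version_info[:2] < (3, 10), "torchdrug 0.1.3 requires Python >=3.7,<3.10"
    print("Python %d.%d is supported" % sys.version_info[:2])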

test/layers/test_conv.py (+5, -5)
@@ -38,7 +38,7 @@ def test_graph_conv(self):
         adjacency /= adjacency.sum(dim=0, keepdim=True).sqrt() * adjacency.sum(dim=1, keepdim=True).sqrt()
         x = adjacency.t() @ self.input
         truth = conv.activation(conv.linear(x))
-        self.assertTrue(torch.allclose(result, truth, rtol=1e-4, atol=1e-7), "Incorrect graph convolution")
+        self.assertTrue(torch.allclose(result, truth, rtol=1e-2, atol=1e-3), "Incorrect graph convolution")
 
         num_head = 2
         conv = layers.GraphAttentionConv(self.input_dim, self.output_dim, num_head=num_head).cuda()
@@ -55,15 +55,15 @@ def test_graph_conv(self):
             outputs.append(output)
         truth = torch.cat(outputs, dim=-1)
         truth = conv.activation(truth)
-        self.assertTrue(torch.allclose(result, truth), "Incorrect graph attention convolution")
+        self.assertTrue(torch.allclose(result, truth, rtol=1e-2, atol=1e-3), "Incorrect graph attention convolution")
 
         eps = 1
         conv = layers.GraphIsomorphismConv(self.input_dim, self.output_dim, eps=eps).cuda()
         result = conv(self.graph, self.input)
         adjacency = self.graph.adjacency.to_dense().sum(dim=-1)
         x = (1 + eps) * self.input + adjacency.t() @ self.input
         truth = conv.activation(conv.mlp(x))
-        self.assertTrue(torch.allclose(result, truth, atol=1e-4, rtol=1e-7), "Incorrect graph isomorphism convolution")
+        self.assertTrue(torch.allclose(result, truth, rtol=1e-2, atol=1e-2), "Incorrect graph isomorphism convolution")
 
         conv = layers.RelationalGraphConv(self.input_dim, self.output_dim, self.num_relation).cuda()
         result = conv(self.graph, self.input)
@@ -72,7 +72,7 @@ def test_graph_conv(self):
         x = torch.einsum("htr, hd -> trd", adjacency, self.input)
         x = conv.linear(x.flatten(1)) + conv.self_loop(self.input)
         truth = conv.activation(x)
-        self.assertTrue(torch.allclose(result, truth, atol=1e-4, rtol=1e-7), "Incorrect relational graph convolution")
+        self.assertTrue(torch.allclose(result, truth, rtol=1e-2, atol=1e-3), "Incorrect relational graph convolution")
 
         conv = layers.ChebyshevConv(self.input_dim, self.output_dim, k=2).cuda()
         result = conv(self.graph, self.input)
@@ -83,7 +83,7 @@ def test_graph_conv(self):
         bases = [self.input, laplacian.t() @ self.input, (2 * laplacian.t() @ laplacian.t() - identity) @ self.input]
         x = conv.linear(torch.cat(bases, dim=-1))
         truth = conv.activation(x)
-        self.assertTrue(torch.allclose(result, truth, atol=1e-4, rtol=1e-7), "Incorrect chebyshev graph convolution")
+        self.assertTrue(torch.allclose(result, truth, rtol=1e-2, atol=1e-3), "Incorrect chebyshev graph convolution")
 
 
 if __name__ == "__main__":
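
These hunks loosen the comparison tolerances on the CUDA convolution tests to rtol=1e-2, atol=1e-3 (atol=1e-2 for graph isomorphism). torch.allclose passes when |result - truth| <= atol + rtol * |truth| holds elementwise, so the old settings demanded near-bitwise agreement that float32 CUDA kernels with reordered reductions cannot guarantee. A small illustrative sketch with made-up values (not taken from the test suite):

    import torch

    truth = torch.tensor([1.0, 100.0])
    result = truth + torch.tensor([5e-4, 5e-2])  # plausible float32 accumulation error
    # Old tolerances reject: 5e-4 > 1e-7 + 1e-4 * 1.0
    print(torch.allclose(result, truth, rtol=1e-4, atol=1e-7))  # False
    # Loosened tolerances accept: 5e-4 <= 1e-3 + 1e-2 * 1.0
    print(torch.allclose(result, truth, rtol=1e-2, atol=1e-3))  # True

The test_pool.py hunks below apply the same fix with rtol=1e-3, atol=1e-4.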

test/layers/test_pool.py (+7, -7)
@@ -43,8 +43,8 @@ def test_pool(self):
         truth_adj = torch.einsum("bna, bnm, bmc -> bac", assignment, adjacency, assignment)
         index = torch.arange(self.output_node, device=truth.device)
         truth_adj[:, index, index] = 0
-        self.assertTrue(torch.allclose(result, truth), "Incorrect diffpool node feature")
-        self.assertTrue(torch.allclose(result_adj, truth_adj), "Incorrect diffpool adjacency")
+        self.assertTrue(torch.allclose(result, truth, rtol=1e-3, atol=1e-4), "Incorrect diffpool node feature")
+        self.assertTrue(torch.allclose(result_adj, truth_adj, rtol=1e-3, atol=1e-4), "Incorrect diffpool adjacency")
 
         graph = self.graph[0]
         rng_state = torch.get_rng_state()
@@ -60,8 +60,8 @@ def test_pool(self):
         truth_adj = torch.einsum("na, nm, mc -> ac", assignment, adjacency, assignment)
         index = torch.arange(self.output_node, device=truth.device)
         truth_adj[index, index] = 0
-        self.assertTrue(torch.allclose(result, truth), "Incorrect diffpool node feature")
-        self.assertTrue(torch.allclose(result_adj, truth_adj), "Incorrect diffpool adjacency")
+        self.assertTrue(torch.allclose(result, truth, rtol=1e-3, atol=1e-4), "Incorrect diffpool node feature")
+        self.assertTrue(torch.allclose(result_adj, truth_adj, rtol=1e-3, atol=1e-4), "Incorrect diffpool adjacency")
 
         pool = layers.MinCutPool(self.input_dim, self.output_node, self.feature_layer, self.pool_layer).cuda()
         all_loss = torch.tensor(0, dtype=torch.float32, device="cuda")
@@ -89,10 +89,10 @@ def test_pool(self):
         x = x - torch.eye(self.output_node, device=x.device) / (self.output_node ** 0.5)
         regularization = x.flatten(-2).norm(dim=-1).mean()
         truth_metric = {"normalized cut loss": cut_loss, "orthogonal regularization": regularization}
-        self.assertTrue(torch.allclose(result, truth), "Incorrect min cut pool feature")
-        self.assertTrue(torch.allclose(result_adj, truth_adj), "Incorrect min cut pool adjcency")
+        self.assertTrue(torch.allclose(result, truth, rtol=1e-3, atol=1e-4), "Incorrect min cut pool feature")
+        self.assertTrue(torch.allclose(result_adj, truth_adj, rtol=1e-3, atol=1e-4), "Incorrect min cut pool adjcency")
         for key in result_metric:
-            self.assertTrue(torch.allclose(result_metric[key], truth_metric[key], atol=1e-4, rtol=1e-7),
+            self.assertTrue(torch.allclose(result_metric[key], truth_metric[key], rtol=1e-3, atol=1e-4),
                             "Incorrect min cut pool metric")
 
 

torchdrug/__init__.py (+1, -1)
@@ -12,4 +12,4 @@
 handler.setFormatter(format)
 logger.addHandler(handler)
 
-__version__ = "0.1.2"
+__version__ = "0.1.3"

torchdrug/utils/decorator.py (+15, -9)
@@ -1,5 +1,6 @@
 import inspect
 import warnings
+import functools
 
 from decorator import decorator
 
@@ -100,14 +101,19 @@ def deprecated_alias(**alias):
     Handle argument alias for a function and output deprecated warnings.
     """
 
-    def wrapper(func, *args, **kwargs):
-        for key, value in alias.items():
-            if key in kwargs:
-                if value in kwargs:
-                    raise TypeError("%s() got values for both `%s` and `%s`" % (func.__name__, value, key))
-                warnings.warn("%s(): argument `%s` is deprecated in favor of `%s`" % (func.__name__, key, value))
-                kwargs[value] = kwargs.pop(key)
+    def decorate(func):
 
-        return func(*args, **kwargs)
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            for key, value in alias.items():
+                if key in kwargs:
+                    if value in kwargs:
+                        raise TypeError("%s() got values for both `%s` and `%s`" % (func.__name__, value, key))
+                    warnings.warn("%s(): argument `%s` is deprecated in favor of `%s`" % (func.__name__, key, value))
+                    kwargs[value] = kwargs.pop(key)
 
-    return decorator(wrapper, kwsyntax=True)
+            return func(*args, **kwargs)
+
+        return wrapper
+
+    return decorate
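
The old implementation handed a single caller function to decorator(wrapper, kwsyntax=True) from the decorator package; the rewrite replaces it with an explicit two-level closure plus functools.wraps, which preserves the wrapped function's name and docstring. A hypothetical usage sketch (the function and argument names are invented for illustration):

    from torchdrug.utils.decorator import deprecated_alias

    @deprecated_alias(node_feature="atom_feature")
    def featurize(graph, atom_feature=None):
        return atom_feature

    # The old keyword triggers a warning, then its value is forwarded to `atom_feature`.
    featurize(None, node_feature="symbol")  # warns, returns "symbol"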
