Skip to content

Commit 489677c

Browse files
committed
release v0.1.2
1 parent 3ab63a7 commit 489677c

File tree

13 files changed

+246
-21
lines changed

13 files changed

+246
-21
lines changed

conda/torchdrug/meta.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
package:
22
name: torchdrug
3-
version: 0.1.1
3+
version: 0.1.2
44

55
source:
66
path: ../..

diff.txt

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
diff --git a/conda/torchdrug/meta.yaml b/conda/torchdrug/meta.yaml
2+
index b366902..55604f0 100644
3+
--- a/conda/torchdrug/meta.yaml
4+
+++ b/conda/torchdrug/meta.yaml
5+
@@ -1,6 +1,6 @@
6+
package:
7+
name: torchdrug
8+
- version: 0.1.1
9+
+ version: 0.1.2
10+
11+
source:
12+
path: ../..
13+
diff --git a/doc/source/paper.rst b/doc/source/paper.rst
14+
index ab8489c..c22a7ae 100644
15+
--- a/doc/source/paper.rst
16+
+++ b/doc/source/paper.rst
17+
@@ -86,9 +86,9 @@ Readout Layers
18+
19+
1. `Order Matters: Sequence to sequence for sets <Set2Set_>`_
20+
21+
- Oriol Vinyals, Samy Bengio, Manjunath Kudlur
22+
+ Oriol Vinyals, Samy Bengio, Manjunath Kudlur
23+
24+
- :class:`Set2Set <torchdrug.layers.Set2Set>`
25+
+ :class:`Set2Set <torchdrug.layers.Set2Set>`
26+
27+
Normalization Layers
28+
^^^^^^^^^^^^^^^^^^^^
29+
diff --git a/setup.py b/setup.py
30+
index ddf3cdb..9da3d27 100644
31+
--- a/setup.py
32+
+++ b/setup.py
33+
@@ -13,7 +13,7 @@ if __name__ == "__main__":
34+
long_description_content_type="text/markdown",
35+
url="https://torchdrug.ai/",
36+
author="TorchDrug Team",
37+
- version="0.1.1",
38+
+ version="0.1.2",
39+
license="Apache-2.0",
40+
keywords=["deep-learning", "pytorch", "drug-discovery"],
41+
packages=setuptools.find_packages(),
42+
diff --git a/torchdrug/__init__.py b/torchdrug/__init__.py
43+
index 7058780..7dca7a0 100644
44+
--- a/torchdrug/__init__.py
45+
+++ b/torchdrug/__init__.py
46+
@@ -12,4 +12,4 @@ handler = logging.StreamHandler(sys.stdout)
47+
handler.setFormatter(format)
48+
logger.addHandler(handler)
49+
50+
-__version__ = "0.1.1"
51+
\ No newline at end of file
52+
+__version__ = "0.1.2"
53+
\ No newline at end of file
54+
diff --git a/torchdrug/core/core.py b/torchdrug/core/core.py
55+
index 1de312b..4c6ea18 100644
56+
--- a/torchdrug/core/core.py
57+
+++ b/torchdrug/core/core.py
58+
@@ -355,4 +355,4 @@ def make_configurable(cls, module=None, ignore_args=()):
59+
MetaClass = type(_Configurable.__name__, (Metaclass, _Configurable), {})
60+
else:
61+
MetaClass = _Configurable
62+
- return MetaClass(cls.__name__, (cls,), {"_ignore_args": ignore_args, "__module__": module})
63+
+ return MetaClass(cls.__name__, (cls,), {"_ignore_args": ignore_args, "__module__": module})
64+
\ No newline at end of file
65+
diff --git a/torchdrug/data/__init__.py b/torchdrug/data/__init__.py
66+
index bfacd4c..66131f6 100644
67+
--- a/torchdrug/data/__init__.py
68+
+++ b/torchdrug/data/__init__.py
69+
@@ -1,3 +1,4 @@
70+
+from .dictionary import PerfectHash, Dictionary
71+
from .graph import Graph, PackedGraph, cat
72+
from .molecule import Molecule, PackedMolecule
73+
from .dataset import MoleculeDataset, ReactionDataset, NodeClassificationDataset, KnowledgeGraphDataset, \
74+
@@ -7,7 +8,7 @@ from . import constant
75+
from . import feature
76+
77+
__all__ = [
78+
- "Graph", "PackedGraph", "Molecule", "PackedMolecule",
79+
+ "Graph", "PackedGraph", "Molecule", "PackedMolecule", "PerfectHash", "Dictionary",
80+
"MoleculeDataset", "ReactionDataset", "NodeClassificationDataset", "KnowledgeGraphDataset", "SemiSupervised",
81+
"semisupervised", "key_split", "scaffold_split", "ordered_scaffold_split",
82+
"DataLoader", "graph_collate", "feature", "constant",
83+
diff --git a/torchdrug/data/dataset.py b/torchdrug/data/dataset.py
84+
index 6285244..8db50df 100644
85+
--- a/torchdrug/data/dataset.py
86+
+++ b/torchdrug/data/dataset.py
87+
@@ -171,16 +171,32 @@ class MoleculeDataset(torch_data.Dataset, core.Configurable):
88+
def atom_types(self):
89+
"""All atom types."""
90+
atom_types = set()
91+
- for i in range(len(self.data)):
92+
- atom_types.update(self.get_item(i)["graph"].atom_type.tolist())
93+
+
94+
+ if getattr(self, "lazy", False):
95+
+ warnings.warn("Calling this function for dataset with lazy=True may take a large amount of time.")
96+
+ for smiles in self.smiles_list:
97+
+ graph = data.Molecule.from_smiles(smiles, **self.kwargs)
98+
+ atom_types.update(graph.atom_type.tolist())
99+
+ else:
100+
+ for graph in self.data:
101+
+ atom_types.update(graph.atom_type.tolist())
102+
+
103+
return sorted(atom_types)
104+
105+
@utils.cached_property
106+
def bond_types(self):
107+
"""All bond types."""
108+
bond_types = set()
109+
- for i in range(len(self.data)):
110+
- bond_types.update(self.get_item(i)["graph"].edge_list[:, 2].tolist())
111+
+
112+
+ if getattr(self, "lazy", False):
113+
+ warnings.warn("Calling this function for dataset with lazy=True may take a large amount of time.")
114+
+ for smiles in self.smiles_list:
115+
+ graph = data.Molecule.from_smiles(smiles, **self.kwargs)
116+
+ bond_types.update(graph.edge_list[:, 2].tolist())
117+
+ else:
118+
+ for graph in self.data:
119+
+ bond_types.update(graph.edge_list[:, 2].tolist())
120+
+
121+
return sorted(bond_types)
122+
123+
def __len__(self):
124+
diff --git a/torchdrug/models/neurallp.py b/torchdrug/models/neurallp.py
125+
index db16f7d..ef78c67 100644
126+
--- a/torchdrug/models/neurallp.py
127+
+++ b/torchdrug/models/neurallp.py
128+
@@ -104,7 +104,7 @@ class NeuralLogicProgramming(nn.Module, core.Configurable):
129+
130+
h_index, t_index, r_index = self.negative_sample_to_tail(h_index, t_index, r_index)
131+
hr_index = h_index * graph.num_relation + r_index
132+
- hr_index_set, hr_inverse = torch.unique(hr_index, return_inverse=True)
133+
+ hr_index_set, hr_inverse = hr_index.unique(return_inverse=True)
134+
h_index_set = hr_index_set // graph.num_relation
135+
r_index_set = hr_index_set % graph.num_relation
136+
137+
diff --git a/torchdrug/tasks/generation.py b/torchdrug/tasks/generation.py
138+
index bb7ddc0..942e8e3 100644
139+
--- a/torchdrug/tasks/generation.py
140+
+++ b/torchdrug/tasks/generation.py
141+
@@ -803,7 +803,7 @@ class GCPNGeneration(tasks.Task, core.Configurable):
142+
self.batch_id += 1
143+
144+
# generation takes less time when early_stop=True
145+
- graph = self.generate(len(batch["graph"]), max_resample=5, off_policy=True, max_step=40 * 2, verbose=1)
146+
+ graph = self.generate(len(batch["graph"]), max_resample=20, off_policy=True, max_step=40 * 2, verbose=1)
147+
if graph.num_nodes.max() == 1:
148+
raise ValueError("Generation results collapse to singleton molecules")
149+
150+
@@ -1338,7 +1338,7 @@ class GCPNGeneration(tasks.Task, core.Configurable):
151+
self.best_results[task] = best_results
152+
153+
@torch.no_grad()
154+
- def generate(self, num_sample, max_resample=10, off_policy=False, max_step=30 * 2, initial_smiles="C", verbose=0):
155+
+ def generate(self, num_sample, max_resample=20, off_policy=False, max_step=30 * 2, initial_smiles="C", verbose=0):
156+
is_training = self.training
157+
self.eval()
158+
159+
diff --git a/torchdrug/utils/comm.py b/torchdrug/utils/comm.py
160+
index 0980131..817c281 100644
161+
--- a/torchdrug/utils/comm.py
162+
+++ b/torchdrug/utils/comm.py
163+
@@ -147,7 +147,7 @@ def reduce(obj, op="sum", dst=None):
164+
Available operators are ``sum``, ``mean``, ``min``, ``max``, ``product``.
165+
dst (int, optional): rank of destination worker. If not specified, broadcast the result to all workers.
166+
167+
- Examples::
168+
+ Example::
169+
170+
>>> # assume 4 workers
171+
>>> rank = comm.get_rank()
172+
@@ -190,7 +190,7 @@ def stack(obj, dst=None):
173+
obj (Object): any container object. Can be nested list, tuple or dict.
174+
dst (int, optional): rank of destination worker. If not specified, broadcast the result to all workers.
175+
176+
- Examples::
177+
+ Example::
178+
179+
>>> # assume 4 workers
180+
>>> rank = comm.get_rank()
181+
@@ -229,7 +229,7 @@ def cat(obj, dst=None):
182+
obj (Object): any container object. Can be nested list, tuple or dict.
183+
dst (int, optional): rank of destination worker. If not specified, broadcast the result to all workers.
184+
185+
- Examples::
186+
+ Example::
187+
188+
>>> # assume 4 workers
189+
>>> rank = comm.get_rank()
190+
diff --git a/torchdrug/utils/io.py b/torchdrug/utils/io.py
191+
index 29659cf..d573cde 100644
192+
--- a/torchdrug/utils/io.py
193+
+++ b/torchdrug/utils/io.py
194+
@@ -77,7 +77,7 @@ def capture_rdkit_log():
195+
"""
196+
Context manager to capture all rdkit loggings.
197+
198+
- Examples::
199+
+ Example::
200+
201+
>>> with utils.capture_rdkit_log() as log:
202+
>>> ...

doc/source/paper.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,9 @@ Readout Layers
8686

8787
1. `Order Matters: Sequence to sequence for sets <Set2Set_>`_
8888

89-
Oriol Vinyals, Samy Bengio, Manjunath Kudlur
89+
Oriol Vinyals, Samy Bengio, Manjunath Kudlur
9090

91-
:class:`Set2Set <torchdrug.layers.Set2Set>`
91+
:class:`Set2Set <torchdrug.layers.Set2Set>`
9292

9393
Normalization Layers
9494
^^^^^^^^^^^^^^^^^^^^

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
long_description_content_type="text/markdown",
1414
url="https://torchdrug.ai/",
1515
author="TorchDrug Team",
16-
version="0.1.1",
16+
version="0.1.2",
1717
license="Apache-2.0",
1818
keywords=["deep-learning", "pytorch", "drug-discovery"],
1919
packages=setuptools.find_packages(),

torchdrug/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,4 @@
1212
handler.setFormatter(format)
1313
logger.addHandler(handler)
1414

15-
__version__ = "0.1.1"
15+
__version__ = "0.1.2"

torchdrug/core/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -355,4 +355,4 @@ def make_configurable(cls, module=None, ignore_args=()):
355355
MetaClass = type(_Configurable.__name__, (Metaclass, _Configurable), {})
356356
else:
357357
MetaClass = _Configurable
358-
return MetaClass(cls.__name__, (cls,), {"_ignore_args": ignore_args, "__module__": module})
358+
return MetaClass(cls.__name__, (cls,), {"_ignore_args": ignore_args, "__module__": module})

torchdrug/data/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from .dictionary import PerfectHash, Dictionary
12
from .graph import Graph, PackedGraph, cat
23
from .molecule import Molecule, PackedMolecule
34
from .dataset import MoleculeDataset, ReactionDataset, NodeClassificationDataset, KnowledgeGraphDataset, \
@@ -7,7 +8,7 @@
78
from . import feature
89

910
__all__ = [
10-
"Graph", "PackedGraph", "Molecule", "PackedMolecule",
11+
"Graph", "PackedGraph", "Molecule", "PackedMolecule", "PerfectHash", "Dictionary",
1112
"MoleculeDataset", "ReactionDataset", "NodeClassificationDataset", "KnowledgeGraphDataset", "SemiSupervised",
1213
"semisupervised", "key_split", "scaffold_split", "ordered_scaffold_split",
1314
"DataLoader", "graph_collate", "feature", "constant",

torchdrug/data/dataset.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -171,16 +171,32 @@ def num_bond_type(self):
171171
def atom_types(self):
172172
"""All atom types."""
173173
atom_types = set()
174-
for i in range(len(self.data)):
175-
atom_types.update(self.get_item(i)["graph"].atom_type.tolist())
174+
175+
if getattr(self, "lazy", False):
176+
warnings.warn("Calling this function for dataset with lazy=True may take a large amount of time.")
177+
for smiles in self.smiles_list:
178+
graph = data.Molecule.from_smiles(smiles, **self.kwargs)
179+
atom_types.update(graph.atom_type.tolist())
180+
else:
181+
for graph in self.data:
182+
atom_types.update(graph.atom_type.tolist())
183+
176184
return sorted(atom_types)
177185

178186
@utils.cached_property
179187
def bond_types(self):
180188
"""All bond types."""
181189
bond_types = set()
182-
for i in range(len(self.data)):
183-
bond_types.update(self.get_item(i)["graph"].edge_list[:, 2].tolist())
190+
191+
if getattr(self, "lazy", False):
192+
warnings.warn("Calling this function for dataset with lazy=True may take a large amount of time.")
193+
for smiles in self.smiles_list:
194+
graph = data.Molecule.from_smiles(smiles, **self.kwargs)
195+
bond_types.update(graph.edge_list[:, 2].tolist())
196+
else:
197+
for graph in self.data:
198+
bond_types.update(graph.edge_list[:, 2].tolist())
199+
184200
return sorted(bond_types)
185201

186202
def __len__(self):

torchdrug/datasets/uspto50k.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ def _get_difference(self, reactant, product):
103103

104104
# check edges in the product
105105
product = product.directed()
106+
# O(n^2) brute-force match is faster than O(nlogn) data.Graph.match for small molecules
106107
mapped_edge = product.edge_list.clone()
107108
mapped_edge[:, :2] = prod2react[mapped_edge[:, :2]]
108109
is_same_index = mapped_edge.unsqueeze(0) == reactant.edge_list.unsqueeze(1)
@@ -123,8 +124,10 @@ def _get_reaction_center(self, reactant, product):
123124

124125
if len(edge_added) > 0:
125126
if len(edge_added) == 1: # add a single edge
126-
index = product.index(edge_added[0])
127-
assert len(index) == 1
127+
any = -torch.ones(1, 1, dtype=torch.long)
128+
pattern = torch.cat([edge_added, any], dim=-1)
129+
index, num_match = product.match(pattern)
130+
assert num_match.item() == 1
128131
edge_label[index] = 1
129132
h, t = edge_added[0]
130133
reaction_center = torch.tensor([product.atom_map[h], product.atom_map[t]])
@@ -172,7 +175,10 @@ def _get_synthon(self, reactant, product):
172175
if len(edge_added) == 1: # add a single edge
173176
edge = edge_added[0]
174177
reverse_edge = edge.flip(0)
175-
index = torch.cat([product.index(edge), product.index(reverse_edge)])
178+
any = -torch.ones(2, 1, dtype=torch.long)
179+
pattern = torch.cat([edge, reverse_edge])
180+
pattern = torch.cat([pattern, any], dim=-1)
181+
index, num_match = product.match(pattern)
176182
edge_mask = torch.ones(product.num_edge, dtype=torch.bool)
177183
edge_mask[index] = 0
178184
product = product.edge_mask(edge_mask)

torchdrug/models/neurallp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def forward(self, graph, h_index, t_index, r_index, all_loss=None, metric=None):
104104

105105
h_index, t_index, r_index = self.negative_sample_to_tail(h_index, t_index, r_index)
106106
hr_index = h_index * graph.num_relation + r_index
107-
hr_index_set, hr_inverse = torch.unique(hr_index, return_inverse=True)
107+
hr_index_set, hr_inverse = hr_index.unique(return_inverse=True)
108108
h_index_set = hr_index_set // graph.num_relation
109109
r_index_set = hr_index_set % graph.num_relation
110110

0 commit comments

Comments
 (0)