|
| 1 | +diff --git a/conda/torchdrug/meta.yaml b/conda/torchdrug/meta.yaml |
| 2 | +index b366902..55604f0 100644 |
| 3 | +--- a/conda/torchdrug/meta.yaml |
| 4 | ++++ b/conda/torchdrug/meta.yaml |
| 5 | +@@ -1,6 +1,6 @@ |
| 6 | + package: |
| 7 | + name: torchdrug |
| 8 | +- version: 0.1.1 |
| 9 | ++ version: 0.1.2 |
| 10 | + |
| 11 | + source: |
| 12 | + path: ../.. |
| 13 | +diff --git a/doc/source/paper.rst b/doc/source/paper.rst |
| 14 | +index ab8489c..c22a7ae 100644 |
| 15 | +--- a/doc/source/paper.rst |
| 16 | ++++ b/doc/source/paper.rst |
| 17 | +@@ -86,9 +86,9 @@ Readout Layers |
| 18 | + |
| 19 | + 1. `Order Matters: Sequence to sequence for sets <Set2Set_>`_ |
| 20 | + |
| 21 | +- Oriol Vinyals, Samy Bengio, Manjunath Kudlur |
| 22 | ++ Oriol Vinyals, Samy Bengio, Manjunath Kudlur |
| 23 | + |
| 24 | +- :class:`Set2Set <torchdrug.layers.Set2Set>` |
| 25 | ++ :class:`Set2Set <torchdrug.layers.Set2Set>` |
| 26 | + |
| 27 | + Normalization Layers |
| 28 | + ^^^^^^^^^^^^^^^^^^^^ |
| 29 | +diff --git a/setup.py b/setup.py |
| 30 | +index ddf3cdb..9da3d27 100644 |
| 31 | +--- a/setup.py |
| 32 | ++++ b/setup.py |
| 33 | +@@ -13,7 +13,7 @@ if __name__ == "__main__": |
| 34 | + long_description_content_type="text/markdown", |
| 35 | + url="https://torchdrug.ai/", |
| 36 | + author="TorchDrug Team", |
| 37 | +- version="0.1.1", |
| 38 | ++ version="0.1.2", |
| 39 | + license="Apache-2.0", |
| 40 | + keywords=["deep-learning", "pytorch", "drug-discovery"], |
| 41 | + packages=setuptools.find_packages(), |
| 42 | +diff --git a/torchdrug/__init__.py b/torchdrug/__init__.py |
| 43 | +index 7058780..7dca7a0 100644 |
| 44 | +--- a/torchdrug/__init__.py |
| 45 | ++++ b/torchdrug/__init__.py |
| 46 | +@@ -12,4 +12,4 @@ handler = logging.StreamHandler(sys.stdout) |
| 47 | + handler.setFormatter(format) |
| 48 | + logger.addHandler(handler) |
| 49 | + |
| 50 | +-__version__ = "0.1.1" |
| 51 | +\ No newline at end of file |
| 52 | ++__version__ = "0.1.2" |
| 53 | +\ No newline at end of file |
| 54 | +diff --git a/torchdrug/core/core.py b/torchdrug/core/core.py |
| 55 | +index 1de312b..4c6ea18 100644 |
| 56 | +--- a/torchdrug/core/core.py |
| 57 | ++++ b/torchdrug/core/core.py |
| 58 | +@@ -355,4 +355,4 @@ def make_configurable(cls, module=None, ignore_args=()): |
| 59 | + MetaClass = type(_Configurable.__name__, (Metaclass, _Configurable), {}) |
| 60 | + else: |
| 61 | + MetaClass = _Configurable |
| 62 | +- return MetaClass(cls.__name__, (cls,), {"_ignore_args": ignore_args, "__module__": module}) |
| 63 | ++ return MetaClass(cls.__name__, (cls,), {"_ignore_args": ignore_args, "__module__": module}) |
| 64 | +\ No newline at end of file |
| 65 | +diff --git a/torchdrug/data/__init__.py b/torchdrug/data/__init__.py |
| 66 | +index bfacd4c..66131f6 100644 |
| 67 | +--- a/torchdrug/data/__init__.py |
| 68 | ++++ b/torchdrug/data/__init__.py |
| 69 | +@@ -1,3 +1,4 @@ |
| 70 | ++from .dictionary import PerfectHash, Dictionary |
| 71 | + from .graph import Graph, PackedGraph, cat |
| 72 | + from .molecule import Molecule, PackedMolecule |
| 73 | + from .dataset import MoleculeDataset, ReactionDataset, NodeClassificationDataset, KnowledgeGraphDataset, \ |
| 74 | +@@ -7,7 +8,7 @@ from . import constant |
| 75 | + from . import feature |
| 76 | + |
| 77 | + __all__ = [ |
| 78 | +- "Graph", "PackedGraph", "Molecule", "PackedMolecule", |
| 79 | ++ "Graph", "PackedGraph", "Molecule", "PackedMolecule", "PerfectHash", "Dictionary", |
| 80 | + "MoleculeDataset", "ReactionDataset", "NodeClassificationDataset", "KnowledgeGraphDataset", "SemiSupervised", |
| 81 | + "semisupervised", "key_split", "scaffold_split", "ordered_scaffold_split", |
| 82 | + "DataLoader", "graph_collate", "feature", "constant", |
| 83 | +diff --git a/torchdrug/data/dataset.py b/torchdrug/data/dataset.py |
| 84 | +index 6285244..8db50df 100644 |
| 85 | +--- a/torchdrug/data/dataset.py |
| 86 | ++++ b/torchdrug/data/dataset.py |
| 87 | +@@ -171,16 +171,32 @@ class MoleculeDataset(torch_data.Dataset, core.Configurable): |
| 88 | + def atom_types(self): |
| 89 | + """All atom types.""" |
| 90 | + atom_types = set() |
| 91 | +- for i in range(len(self.data)): |
| 92 | +- atom_types.update(self.get_item(i)["graph"].atom_type.tolist()) |
| 93 | ++ |
| 94 | ++ if getattr(self, "lazy", False): |
| 95 | ++ warnings.warn("Calling this function for dataset with lazy=True may take a large amount of time.") |
| 96 | ++ for smiles in self.smiles_list: |
| 97 | ++ graph = data.Molecule.from_smiles(smiles, **self.kwargs) |
| 98 | ++ atom_types.update(graph.atom_type.tolist()) |
| 99 | ++ else: |
| 100 | ++ for graph in self.data: |
| 101 | ++ atom_types.update(graph.atom_type.tolist()) |
| 102 | ++ |
| 103 | + return sorted(atom_types) |
| 104 | + |
| 105 | + @utils.cached_property |
| 106 | + def bond_types(self): |
| 107 | + """All bond types.""" |
| 108 | + bond_types = set() |
| 109 | +- for i in range(len(self.data)): |
| 110 | +- bond_types.update(self.get_item(i)["graph"].edge_list[:, 2].tolist()) |
| 111 | ++ |
| 112 | ++ if getattr(self, "lazy", False): |
| 113 | ++ warnings.warn("Calling this function for dataset with lazy=True may take a large amount of time.") |
| 114 | ++ for smiles in self.smiles_list: |
| 115 | ++ graph = data.Molecule.from_smiles(smiles, **self.kwargs) |
| 116 | ++ bond_types.update(graph.edge_list[:, 2].tolist()) |
| 117 | ++ else: |
| 118 | ++ for graph in self.data: |
| 119 | ++ bond_types.update(graph.edge_list[:, 2].tolist()) |
| 120 | ++ |
| 121 | + return sorted(bond_types) |
| 122 | + |
| 123 | + def __len__(self): |
| 124 | +diff --git a/torchdrug/models/neurallp.py b/torchdrug/models/neurallp.py |
| 125 | +index db16f7d..ef78c67 100644 |
| 126 | +--- a/torchdrug/models/neurallp.py |
| 127 | ++++ b/torchdrug/models/neurallp.py |
| 128 | +@@ -104,7 +104,7 @@ class NeuralLogicProgramming(nn.Module, core.Configurable): |
| 129 | + |
| 130 | + h_index, t_index, r_index = self.negative_sample_to_tail(h_index, t_index, r_index) |
| 131 | + hr_index = h_index * graph.num_relation + r_index |
| 132 | +- hr_index_set, hr_inverse = torch.unique(hr_index, return_inverse=True) |
| 133 | ++ hr_index_set, hr_inverse = hr_index.unique(return_inverse=True) |
| 134 | + h_index_set = hr_index_set // graph.num_relation |
| 135 | + r_index_set = hr_index_set % graph.num_relation |
| 136 | + |
| 137 | +diff --git a/torchdrug/tasks/generation.py b/torchdrug/tasks/generation.py |
| 138 | +index bb7ddc0..942e8e3 100644 |
| 139 | +--- a/torchdrug/tasks/generation.py |
| 140 | ++++ b/torchdrug/tasks/generation.py |
| 141 | +@@ -803,7 +803,7 @@ class GCPNGeneration(tasks.Task, core.Configurable): |
| 142 | + self.batch_id += 1 |
| 143 | + |
| 144 | + # generation takes less time when early_stop=True |
| 145 | +- graph = self.generate(len(batch["graph"]), max_resample=5, off_policy=True, max_step=40 * 2, verbose=1) |
| 146 | ++ graph = self.generate(len(batch["graph"]), max_resample=20, off_policy=True, max_step=40 * 2, verbose=1) |
| 147 | + if graph.num_nodes.max() == 1: |
| 148 | + raise ValueError("Generation results collapse to singleton molecules") |
| 149 | + |
| 150 | +@@ -1338,7 +1338,7 @@ class GCPNGeneration(tasks.Task, core.Configurable): |
| 151 | + self.best_results[task] = best_results |
| 152 | + |
| 153 | + @torch.no_grad() |
| 154 | +- def generate(self, num_sample, max_resample=10, off_policy=False, max_step=30 * 2, initial_smiles="C", verbose=0): |
| 155 | ++ def generate(self, num_sample, max_resample=20, off_policy=False, max_step=30 * 2, initial_smiles="C", verbose=0): |
| 156 | + is_training = self.training |
| 157 | + self.eval() |
| 158 | + |
| 159 | +diff --git a/torchdrug/utils/comm.py b/torchdrug/utils/comm.py |
| 160 | +index 0980131..817c281 100644 |
| 161 | +--- a/torchdrug/utils/comm.py |
| 162 | ++++ b/torchdrug/utils/comm.py |
| 163 | +@@ -147,7 +147,7 @@ def reduce(obj, op="sum", dst=None): |
| 164 | + Available operators are ``sum``, ``mean``, ``min``, ``max``, ``product``. |
| 165 | + dst (int, optional): rank of destination worker. If not specified, broadcast the result to all workers. |
| 166 | + |
| 167 | +- Examples:: |
| 168 | ++ Example:: |
| 169 | + |
| 170 | + >>> # assume 4 workers |
| 171 | + >>> rank = comm.get_rank() |
| 172 | +@@ -190,7 +190,7 @@ def stack(obj, dst=None): |
| 173 | + obj (Object): any container object. Can be nested list, tuple or dict. |
| 174 | + dst (int, optional): rank of destination worker. If not specified, broadcast the result to all workers. |
| 175 | + |
| 176 | +- Examples:: |
| 177 | ++ Example:: |
| 178 | + |
| 179 | + >>> # assume 4 workers |
| 180 | + >>> rank = comm.get_rank() |
| 181 | +@@ -229,7 +229,7 @@ def cat(obj, dst=None): |
| 182 | + obj (Object): any container object. Can be nested list, tuple or dict. |
| 183 | + dst (int, optional): rank of destination worker. If not specified, broadcast the result to all workers. |
| 184 | + |
| 185 | +- Examples:: |
| 186 | ++ Example:: |
| 187 | + |
| 188 | + >>> # assume 4 workers |
| 189 | + >>> rank = comm.get_rank() |
| 190 | +diff --git a/torchdrug/utils/io.py b/torchdrug/utils/io.py |
| 191 | +index 29659cf..d573cde 100644 |
| 192 | +--- a/torchdrug/utils/io.py |
| 193 | ++++ b/torchdrug/utils/io.py |
| 194 | +@@ -77,7 +77,7 @@ def capture_rdkit_log(): |
| 195 | + """ |
| 196 | + Context manager to capture all rdkit loggings. |
| 197 | + |
| 198 | +- Examples:: |
| 199 | ++ Example:: |
| 200 | + |
| 201 | + >>> with utils.capture_rdkit_log() as log: |
| 202 | + >>> ... |
0 commit comments