From ac1e1e94aa750b7c2a45df5748c01bd614b9c4fc Mon Sep 17 00:00:00 2001 From: hyunjunji Date: Thu, 20 Jun 2024 13:46:12 +0900 Subject: [PATCH 1/5] refactor: use rdFingerprinGenerator --- molfeat/calc/fingerprints.py | 190 ++++++++++++++++++----------------- 1 file changed, 98 insertions(+), 92 deletions(-) diff --git a/molfeat/calc/fingerprints.py b/molfeat/calc/fingerprints.py index 7144436..167abec 100644 --- a/molfeat/calc/fingerprints.py +++ b/molfeat/calc/fingerprints.py @@ -6,6 +6,7 @@ import copy import datamol as dm from rdkit.Avalon import pyAvalonTools +from rdkit.Chem import rdFingerprintGenerator from rdkit.Chem import rdMolDescriptors from rdkit.Chem import rdReducedGraphs from rdkit.Chem import rdmolops @@ -19,14 +20,22 @@ from molfeat.utils.commons import fold_count_fp +FP_GENERATORS = { + "ecfp": rdFingerprintGenerator.GetMorganGenerator, + "fcfp": rdFingerprintGenerator.GetMorganGenerator, + "topological": rdFingerprintGenerator.GetTopologicalTorsionGenerator, + "atompair": rdFingerprintGenerator.GetAtomPairGenerator, + "rdkit": rdFingerprintGenerator.GetRDKitFPGenerator, + "ecfp-count": rdFingerprintGenerator.GetMorganGenerator, + "fcfp-count": rdFingerprintGenerator.GetMorganGenerator, + "topological-count": rdFingerprintGenerator.GetTopologicalTorsionGenerator, + "atompair-count": rdFingerprintGenerator.GetAtomPairGenerator, + "rdkit-count": rdFingerprintGenerator.GetRDKitFPGenerator, +} + FP_FUNCS = { "maccs": rdMolDescriptors.GetMACCSKeysFingerprint, "avalon": pyAvalonTools.GetAvalonFP, - "ecfp": rdMolDescriptors.GetMorganFingerprintAsBitVect, - "fcfp": partial(rdMolDescriptors.GetMorganFingerprintAsBitVect, useFeatures=True), - "topological": rdMolDescriptors.GetHashedTopologicalTorsionFingerprintAsBitVect, - "atompair": rdMolDescriptors.GetHashedAtomPairFingerprintAsBitVect, - "rdkit": rdmolops.RDKFingerprint, "pattern": rdmolops.PatternFingerprint, "layered": rdmolops.LayeredFingerprint, "map4": MAP4, @@ -34,11 +43,7 @@ "erg": rdReducedGraphs.GetErGFingerprint, "estate": lambda x, **params: EStateFingerprinter.FingerprintMol(x)[0], "avalon-count": pyAvalonTools.GetAvalonCountFP, - "rdkit-count": rdmolops.UnfoldedRDKFingerprintCountBased, - "ecfp-count": rdMolDescriptors.GetHashedMorganFingerprint, - "fcfp-count": rdMolDescriptors.GetHashedMorganFingerprint, - "topological-count": rdMolDescriptors.GetHashedTopologicalTorsionFingerprint, - "atompair-count": rdMolDescriptors.GetHashedAtomPairFingerprint, + **FP_GENERATORS, } @@ -52,59 +57,60 @@ }, "ecfp": { "radius": 2, # ECFP4 - "nBits": 2048, - "invariants": [], - "fromAtoms": [], - "useChirality": False, + "fpSize": 2048, + "includeChirality": False, "useBondTypes": True, - "useFeatures": False, + "countSimulation": False, + "countBounds": None, + "atomInvariantsGenerator": None, + "bondInvariantsGenerator": None, }, "fcfp": { - "radius": 2, # FCFP4 - "nBits": 2048, - "invariants": [], # you may want to provide features invariance - "fromAtoms": [], - "useChirality": False, + "radius": 2, + "fpSize": 2048, + "includeChirality": False, "useBondTypes": True, - "useFeatures": True, + "countSimulation": False, + "countBounds": None, + "atomInvariantsGenerator": rdFingerprintGenerator.GetMorganFeatureAtomInvGen(), + "bondInvariantsGenerator": None, }, "topological": { - "nBits": 2048, - "targetSize": 4, - "fromAtoms": 0, - "ignoreAtoms": 0, - "atomInvariants": 0, - "nBitsPerEntry": 4, "includeChirality": False, + "torsionAtomCount": 4, + "countSimulation": True, + "countBounds": None, + "fpSize": 2048, + "atomInvariantsGenerator": None, }, "atompair": { - "nBits": 2048, - "minLength": 1, - "maxLength": 30, - "fromAtoms": 0, - "ignoreAtoms": 0, - "atomInvariants": 0, - "nBitsPerEntry": 4, + "minDistance": 1, + "maxDistance": 30, "includeChirality": False, "use2D": True, - "confId": -1, + "countSimulation": True, + "countBounds": None, + "fpSize": 2048, + "atomInvariantsGenerator": None, }, "rdkit": { "minPath": 1, "maxPath": 7, - "fpSize": 2048, - "nBitsPerHash": 2, "useHs": True, - "tgtDensity": 0.0, - "minSize": 128, "branchedPaths": True, "useBondOrder": True, - "atomInvariants": 0, - "fromAtoms": 0, - "atomBits": None, - "bitInfo": None, + "countSimulation": False, + "countBounds": None, + "fpSize": 2048, + "numBitsPerFeature": 2, + "atomInvariantsGenerator": None, + }, + "pattern": { + "fpSize": 2048, + "atomCounts": [], + "setOnlyBits": None, + "tautomerFingerprints": False, }, - "pattern": {"fpSize": 2048, "atomCounts": [], "setOnlyBits": None}, "layered": { "fpSize": 2048, "minPath": 1, @@ -139,36 +145,41 @@ # COUNTING FP "ecfp-count": { "radius": 2, # ECFP4 - "nBits": 2048, - "invariants": [], - "fromAtoms": [], - "useChirality": False, + "fpSize": 2048, + "includeChirality": False, "useBondTypes": True, - "useFeatures": False, - "includeRedundantEnvironments": False, + "countSimulation": False, + "countBounds": None, + "atomInvariantsGenerator": None, + "bondInvariantsGenerator": None, }, "fcfp-count": { - "radius": 2, # FCFP4 - "nBits": 2048, - "invariants": [], # you may want to provide features invariance - "fromAtoms": [], - "useChirality": False, + "radius": 2, + "fpSize": 2048, + "includeChirality": False, "useBondTypes": True, - "useFeatures": True, - "includeRedundantEnvironments": False, + "countSimulation": False, + "countBounds": None, + "atomInvariantsGenerator": rdFingerprintGenerator.GetMorganFeatureAtomInvGen(), + "bondInvariantsGenerator": None, }, "topological-count": { - "nBits": 2048, - "targetSize": 4, - "fromAtoms": 0, - "ignoreAtoms": 0, - "atomInvariants": 0, "includeChirality": False, + "torsionAtomCount": 4, + "countSimulation": True, + "countBounds": None, + "fpSize": 2048, + "atomInvariantsGenerator": None, }, - "avalon-count": { - "nBits": 512, - "isQuery": False, - "bitFlags": pyAvalonTools.avalonSimilarityBits, + "atompair-count": { + "minDistance": 1, + "maxDistance": 30, + "includeChirality": False, + "use2D": True, + "countSimulation": True, + "countBounds": None, + "fpSize": 2048, + "atomInvariantsGenerator": None, }, "rdkit-count": { "minPath": 1, @@ -176,21 +187,11 @@ "useHs": True, "branchedPaths": True, "useBondOrder": True, - "atomInvariants": 0, - "fromAtoms": 0, - "atomBits": None, - "bitInfo": None, - }, - "atompair-count": { - "nBits": 2048, - "minLength": 1, - "maxLength": 30, - "fromAtoms": 0, - "ignoreAtoms": 0, - "atomInvariants": 0, - "includeChirality": False, - "use2D": True, - "confId": -1, + "countSimulation": False, + "countBounds": None, + "fpSize": 2048, + "numBitsPerFeature": 1, + "atomInvariantsGenerator": None, }, } @@ -231,9 +232,7 @@ def __init__( if unknown_params: logger.error(f"Params: {unknown_params} are not valid for {method}") self.params = default_params - self.params.update( - {k: method_params[k] for k in method_params if k in default_params.keys()} - ) + self.params.update({k: method_params[k] for k in method_params if k in default_params.keys()}) self._length = self._set_length(length) @staticmethod @@ -303,7 +302,16 @@ def __call__(self, mol: Union[dm.Mol, str], raw: bool = False): props (np.ndarray): list of computed rdkit molecular descriptors """ mol = dm.to_mol(mol) - fp_val = FP_FUNCS[self.method](mol, **self.params) + + fp_func = FP_FUNCS[self.method] + if self.method in FP_GENERATORS: + fp_func = fp_func(**self.params) + if self.method.endswith("-count"): + fp_val = fp_func.GetCountFingerprint(mol) + else: + fp_val = fp_func.GetFingerprint(mol) + else: + fp_val = fp_func(mol, **self.params) if self.counting: fp_val = fold_count_fp(fp_val, self._length) if not raw: @@ -334,12 +342,10 @@ def to_state_dict(self): state_dict = super().to_state_dict() cur_params = self.params default_params = copy.deepcopy(FP_DEF_PARAMS[state_dict["args"]["method"]]) - state_dict["args"].update( - { - k: cur_params[k] - for k in cur_params - if (cur_params[k] != default_params[k] and cur_params[k] is not None) - } - ) + state_dict["args"].update({ + k: cur_params[k] + for k in cur_params + if (cur_params[k] != default_params[k] and cur_params[k] is not None) + }) # we want to keep all the additional parameters in the state dict return state_dict From 639462de3341ecf7bc0ea558f4d87389adeea9a1 Mon Sep 17 00:00:00 2001 From: hyunjunji Date: Mon, 24 Jun 2024 10:37:55 +0900 Subject: [PATCH 2/5] fix: register serializable classes and change tests etc --- molfeat/calc/_serializable_classes.py | 40 +++++++++++++++++++++++++++ molfeat/calc/fingerprints.py | 40 +++++++++++++++++++-------- tests/test_state.py | 5 ++-- 3 files changed, 72 insertions(+), 13 deletions(-) create mode 100644 molfeat/calc/_serializable_classes.py diff --git a/molfeat/calc/_serializable_classes.py b/molfeat/calc/_serializable_classes.py new file mode 100644 index 0000000..40c8b0e --- /dev/null +++ b/molfeat/calc/_serializable_classes.py @@ -0,0 +1,40 @@ +from typing import Optional +from typing import Dict +from typing import Any + +from rdkit.Chem import rdFingerprintGenerator + +SERIALIZABLE_CLASSES = {} + + +def register_custom_class(cls: type): + SERIALIZABLE_CLASSES[cls.__name__] = cls + return cls + + +@register_custom_class +class SerializableMorganFeatureAtomInvGen: + def __init__(self): + self._generator = rdFingerprintGenerator.GetMorganFeatureAtomInvGen() + + def __getstate__(self): + return None + + def __setstate__(self, state: Optional[None]): + self._generator = rdFingerprintGenerator.GetMorganFeatureAtomInvGen() + + def __deepcopy__(self, memo: Dict[int, Any]): + new_instance = SerializableMorganFeatureAtomInvGen() + memo[id(self)] = new_instance + return new_instance + + def __getattr__(self, name: str): + try: + generator = object.__getattribute__(self, "_generator") + except AttributeError: + raise AttributeError("'_generator' is not initialized") + + try: + return getattr(generator, name) + except AttributeError: + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") diff --git a/molfeat/calc/fingerprints.py b/molfeat/calc/fingerprints.py index 167abec..8490abb 100644 --- a/molfeat/calc/fingerprints.py +++ b/molfeat/calc/fingerprints.py @@ -1,8 +1,6 @@ from typing import Union from typing import Optional -from functools import partial - import copy import datamol as dm from rdkit.Avalon import pyAvalonTools @@ -15,6 +13,10 @@ from loguru import logger from molfeat.calc._mhfp import SECFP from molfeat.calc._map4 import MAP4 +from molfeat.calc._serializable_classes import ( + SerializableMorganFeatureAtomInvGen, + SERIALIZABLE_CLASSES, +) from molfeat.calc.base import SerializableCalculator from molfeat.utils.datatype import to_numpy, to_fp from molfeat.utils.commons import fold_count_fp @@ -72,7 +74,7 @@ "useBondTypes": True, "countSimulation": False, "countBounds": None, - "atomInvariantsGenerator": rdFingerprintGenerator.GetMorganFeatureAtomInvGen(), + "atomInvariantsGenerator": SerializableMorganFeatureAtomInvGen(), "bondInvariantsGenerator": None, }, "topological": { @@ -160,7 +162,7 @@ "useBondTypes": True, "countSimulation": False, "countBounds": None, - "atomInvariantsGenerator": rdFingerprintGenerator.GetMorganFeatureAtomInvGen(), + "atomInvariantsGenerator": SerializableMorganFeatureAtomInvGen(), "bondInvariantsGenerator": None, }, "topological-count": { @@ -232,7 +234,9 @@ def __init__( if unknown_params: logger.error(f"Params: {unknown_params} are not valid for {method}") self.params = default_params - self.params.update({k: method_params[k] for k in method_params if k in default_params.keys()}) + self.params.update( + {k: method_params[k] for k in method_params if k in default_params.keys()} + ) self._length = self._set_length(length) @staticmethod @@ -329,12 +333,19 @@ def __getstate__(self): state["input_length"] = self.input_length state["method"] = self.method state["counting"] = self.counting - state["params"] = self.params + state["params"] = { + k: (v if v.__class__.__name__ not in SERIALIZABLE_CLASSES else v.__class__.__name__) + for k, v in self.params.items() + } return state def __setstate__(self, state: dict): """Set the state of the featurizer""" self.__dict__.update(state) + self.params = { + k: (v if v not in SERIALIZABLE_CLASSES else SERIALIZABLE_CLASSES[v]()) + for k, v in self.params.items() + } self._length = self._set_length(self.input_length) def to_state_dict(self): @@ -342,10 +353,17 @@ def to_state_dict(self): state_dict = super().to_state_dict() cur_params = self.params default_params = copy.deepcopy(FP_DEF_PARAMS[state_dict["args"]["method"]]) - state_dict["args"].update({ - k: cur_params[k] - for k in cur_params - if (cur_params[k] != default_params[k] and cur_params[k] is not None) - }) + + state_dict["args"].update( + { + k: ( + cur_params[k] + if cur_params[k].__class__.__name__ not in SERIALIZABLE_CLASSES + else cur_params[k].__class__.__name__ + ) + for k in cur_params + if (cur_params[k] != default_params[k] and cur_params[k] is not None) + } + ) # we want to keep all the additional parameters in the state dict return state_dict diff --git a/tests/test_state.py b/tests/test_state.py index 95a2458..f6199b4 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -259,7 +259,7 @@ def test_fp_state(): { "name": "FPCalculator", "module": "molfeat.calc.fingerprints", - "args": {"length": 512, "method": "ecfp", "counting": False, "nBits": 512}, + "args": {"length": 512, "method": "ecfp", "counting": False, "fpSize": 512}, "_molfeat_version": MOLFEAT_VERSION, }, { @@ -269,7 +269,8 @@ def test_fp_state(): "length": 241, "method": "fcfp-count", "counting": True, - "nBits": 241, + "fpSize": 241, + "atomInvariantsGenerator": "SerializableMorganFeatureAtomInvGen", }, "_molfeat_version": MOLFEAT_VERSION, }, From 620e782217d22ff2ab27f5b0d893d8581658c43b Mon Sep 17 00:00:00 2001 From: hyunjunji Date: Thu, 18 Jul 2024 13:39:12 +0900 Subject: [PATCH 3/5] fix: resolve comments --- molfeat/calc/_serializable_classes.py | 36 +++++++++++++++++++++++++-- molfeat/calc/fingerprints.py | 7 +++--- tests/test_fp.py | 2 ++ 3 files changed, 39 insertions(+), 6 deletions(-) diff --git a/molfeat/calc/_serializable_classes.py b/molfeat/calc/_serializable_classes.py index 40c8b0e..2a43b22 100644 --- a/molfeat/calc/_serializable_classes.py +++ b/molfeat/calc/_serializable_classes.py @@ -7,13 +7,15 @@ SERIALIZABLE_CLASSES = {} -def register_custom_class(cls: type): +def register_custom_serializable_class(cls: type): SERIALIZABLE_CLASSES[cls.__name__] = cls return cls -@register_custom_class +@register_custom_serializable_class class SerializableMorganFeatureAtomInvGen: + """A serializable wrapper class for `rdFingerprintGenerator.GetMorganFeatureAtomInvGen()`""" + def __init__(self): self._generator = rdFingerprintGenerator.GetMorganFeatureAtomInvGen() @@ -38,3 +40,33 @@ def __getattr__(self, name: str): return getattr(generator, name) except AttributeError: raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + + +@register_custom_serializable_class +class SerializableMorganFeatureBondInvGen: + """A serializable wrapper class for `rdFingerprintGenerator.GetMorganFeatureBondInvGen()`""" + + def __init__(self): + self._generator = rdFingerprintGenerator.GetMorganFeatureBondInvGen() + + def __getstate__(self): + return None + + def __setstate__(self, state: Optional[None]): + self._generator = rdFingerprintGenerator.GetMorganFeatureBondInvGen() + + def __deepcopy__(self, memo: Dict[int, Any]): + new_instance = SerializableMorganFeatureBondInvGen() + memo[id(self)] = new_instance + return new_instance + + def __getattr__(self, name: str): + try: + generator = object.__getattribute__(self, "_generator") + except AttributeError: + raise AttributeError("'_generator' is not initialized") + + try: + return getattr(generator, name) + except AttributeError: + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") diff --git a/molfeat/calc/fingerprints.py b/molfeat/calc/fingerprints.py index 8490abb..3130b50 100644 --- a/molfeat/calc/fingerprints.py +++ b/molfeat/calc/fingerprints.py @@ -150,7 +150,7 @@ "fpSize": 2048, "includeChirality": False, "useBondTypes": True, - "countSimulation": False, + "includeRedundantEnvironments": False, "countBounds": None, "atomInvariantsGenerator": None, "bondInvariantsGenerator": None, @@ -160,8 +160,7 @@ "fpSize": 2048, "includeChirality": False, "useBondTypes": True, - "countSimulation": False, - "countBounds": None, + "includeRedundantEnvironments": False, "atomInvariantsGenerator": SerializableMorganFeatureAtomInvGen(), "bondInvariantsGenerator": None, }, @@ -310,7 +309,7 @@ def __call__(self, mol: Union[dm.Mol, str], raw: bool = False): fp_func = FP_FUNCS[self.method] if self.method in FP_GENERATORS: fp_func = fp_func(**self.params) - if self.method.endswith("-count"): + if self.counting: fp_val = fp_func.GetCountFingerprint(mol) else: fp_val = fp_func.GetFingerprint(mol) diff --git a/tests/test_fp.py b/tests/test_fp.py index d432fd1..13ad27a 100644 --- a/tests/test_fp.py +++ b/tests/test_fp.py @@ -40,6 +40,7 @@ def batch_compute(self, mols, **kwargs): class TestMolTransformer(ut.TestCase): r"""Test cases for FingerprintsTransformer""" + smiles = [ "CCOc1c(OC)cc(CCN)cc1OC", "COc1cc(CCN)cc(OC)c1OC", @@ -51,6 +52,7 @@ class TestMolTransformer(ut.TestCase): "avalon", "rdkit", "ecfp", + "ecfp-count", "pharm2D", "desc2D", ] From d9c459ddead80e316ec746e116e2f2764508c013 Mon Sep 17 00:00:00 2001 From: hyunjunji Date: Thu, 18 Jul 2024 13:46:14 +0900 Subject: [PATCH 4/5] misc: reformat --- molfeat/plugins/factories.py | 24 ++++++++-------------- molfeat/trans/base.py | 1 - molfeat/trans/pretrained/dgl_pretrained.py | 1 + molfeat/utils/cache.py | 10 +++------ molfeat/utils/commons.py | 1 + tests/test_atom_bond_calculator.py | 1 + tests/test_descriptors.py | 1 + tests/test_graphs.py | 2 ++ tests/test_pretrained.py | 3 +++ 9 files changed, 20 insertions(+), 24 deletions(-) diff --git a/molfeat/plugins/factories.py b/molfeat/plugins/factories.py index f264b7f..4f91d52 100644 --- a/molfeat/plugins/factories.py +++ b/molfeat/plugins/factories.py @@ -84,15 +84,13 @@ def CalculatorFactory( entry_point_name: str, load: Literal[True] = True, entry_point_group: Optional[str] = None, -) -> Union[Type["SerializableCalculator"], Callable]: - ... +) -> Union[Type["SerializableCalculator"], Callable]: ... @overload def CalculatorFactory( entry_point_name: str, load: Literal[False], entry_point_group: Optional[str] = None -) -> EntryPoint: - ... +) -> EntryPoint: ... def CalculatorFactory( @@ -134,15 +132,13 @@ def TransformerFactory( entry_point_name: str, load: Literal[True] = True, entry_point_group: Optional[str] = None, -) -> Union[Type["MoleculeTransformer"], Callable]: - ... +) -> Union[Type["MoleculeTransformer"], Callable]: ... @overload def TransformerFactory( entry_point_name: str, load: Literal[False], entry_point_group: Optional[str] = None -) -> EntryPoint: - ... +) -> EntryPoint: ... def TransformerFactory( @@ -188,13 +184,11 @@ def PretrainedTransformerFactory( entry_point_name: str, load: Literal[True] = True, entry_point_group: Optional[str] = None, -) -> Union[Type["PretrainedMolTransformer"], Callable]: - ... +) -> Union[Type["PretrainedMolTransformer"], Callable]: ... @overload -def PretrainedTransformerFactory(entry_point_name: str, load: Literal[False]) -> EntryPoint: - ... +def PretrainedTransformerFactory(entry_point_name: str, load: Literal[False]) -> EntryPoint: ... def PretrainedTransformerFactory( @@ -239,15 +233,13 @@ def DefaultFactory( entry_point_name: str, load: Literal[True] = True, entry_point_group: str = None, -) -> Union[Type["PretrainedMolTransformer"], Callable]: - ... +) -> Union[Type["PretrainedMolTransformer"], Callable]: ... @overload def DefaultFactory( entry_point_name: str, load: Literal[False], entry_point_group: str = None -) -> EntryPoint: - ... +) -> EntryPoint: ... def DefaultFactory( diff --git a/molfeat/trans/base.py b/molfeat/trans/base.py index fc053e5..037b622 100644 --- a/molfeat/trans/base.py +++ b/molfeat/trans/base.py @@ -146,7 +146,6 @@ def get_collate_fn(self, *args, **kwargs): class MoleculeTransformer(TransformerMixin, BaseFeaturizer, metaclass=_TransformerMeta): - """ Base class for molecular data transformer such as Fingerprinter etc. If you create a subclass of this featurizer, you will need to make sure that the diff --git a/molfeat/trans/pretrained/dgl_pretrained.py b/molfeat/trans/pretrained/dgl_pretrained.py index 02ca594..41eb630 100644 --- a/molfeat/trans/pretrained/dgl_pretrained.py +++ b/molfeat/trans/pretrained/dgl_pretrained.py @@ -37,6 +37,7 @@ class DGLModel(PretrainedStoreModel): r""" Load one of the pretrained DGL models for molecular embedding: """ + AVAILABLE_MODELS = [ "gin_supervised_contextpred", "gin_supervised_infomax", diff --git a/molfeat/utils/cache.py b/molfeat/utils/cache.py index fb6e8fd..1ab1ba2 100644 --- a/molfeat/utils/cache.py +++ b/molfeat/utils/cache.py @@ -97,7 +97,6 @@ def from_state_dict(state: dict) -> "MolToKey": class _Cache(abc.ABC): - """Implementation of a cache interface""" def __init__( @@ -209,12 +208,10 @@ def __call__( self._sync_cache() return self.fetch(mols) - def clear(self, *args, **kwargs): - ... + def clear(self, *args, **kwargs): ... @abc.abstractmethod - def update(self, new_cache: Mapping[Any, Any]): - ... + def update(self, new_cache: Mapping[Any, Any]): ... def get(self, key, default: Optional[Any] = None): """Get the cached value for a specific key @@ -241,8 +238,7 @@ def to_dict(self): """Convert current cache to a dictionary""" return dict(self.items()) - def _sync_cache(self): - ... + def _sync_cache(self): ... def fetch( self, diff --git a/molfeat/utils/commons.py b/molfeat/utils/commons.py index dc5ab64..9fef652 100644 --- a/molfeat/utils/commons.py +++ b/molfeat/utils/commons.py @@ -1,4 +1,5 @@ """Common utility functions""" + from typing import Type from typing import Callable from typing import Iterable diff --git a/tests/test_atom_bond_calculator.py b/tests/test_atom_bond_calculator.py index 12f6aa6..115e601 100644 --- a/tests/test_atom_bond_calculator.py +++ b/tests/test_atom_bond_calculator.py @@ -72,6 +72,7 @@ def test_to_from_state(calculator_builder): @pytest.mark.xfail(not requires.check("dgllife"), reason="3rd party module dgllife is missing") class TestGraphCalculator(ut.TestCase): r"""Test cases for basic graph featurizer vs dgl generation""" + smiles = [ "CCOc1c(OC)cc(CCN)cc1OC", "COc1cc(CCN)cc(OC)c1OC", diff --git a/tests/test_descriptors.py b/tests/test_descriptors.py index 878a63c..7823f96 100644 --- a/tests/test_descriptors.py +++ b/tests/test_descriptors.py @@ -22,6 +22,7 @@ class TestDescPharm(ut.TestCase): r"""Test cases for descriptors and pharmacophore generation""" + smiles = [ "CCOc1c(OC)cc(CCN)cc1OC", "COc1cc(CCN)cc(OC)c1OC", diff --git a/tests/test_graphs.py b/tests/test_graphs.py index f0ddcf9..ab17265 100644 --- a/tests/test_graphs.py +++ b/tests/test_graphs.py @@ -12,6 +12,7 @@ class TestMolTreeDecomposition(ut.TestCase): r"""Test cases for Tree decomposition""" + smiles = [ "CCOc1c(OC)cc(CCN)cc1OC", "COc1cc(CCN)cc(OC)c1OC", @@ -49,6 +50,7 @@ def test_moltree_transformer(self): @pytest.mark.xfail(not requires.check("dgllife"), reason="3rd party module dgllife is missing") class TestGraphTransformer(ut.TestCase): r"""Test cases for AdjGraphTransformer""" + smiles = [ "CCOc1c(OC)cc(CCN)cc1OC", "COc1cc(CCN)cc(OC)c1OC", diff --git a/tests/test_pretrained.py b/tests/test_pretrained.py index 6e75701..18fda65 100644 --- a/tests/test_pretrained.py +++ b/tests/test_pretrained.py @@ -16,6 +16,7 @@ ) class TestGraphormerTransformer(ut.TestCase): r"""Test cases for FingerprintsTransformer""" + smiles = [ "CCOc1c(OC)cc(CCN)cc1OC", "COc1cc(CCN)cc(OC)c1OC", @@ -86,6 +87,7 @@ def test_graphormer_cache(self): class TestDGLTransformer(ut.TestCase): r"""Test cases for FingerprintsTransformer""" + smiles = [ "CCOc1c(OC)cc(CCN)cc1OC", "COc1cc(CCN)cc(OC)c1OC", @@ -130,6 +132,7 @@ def test_dgl_pretrained_cache(self): ) class TestHGFTransformer(ut.TestCase): r"""Test cases for FingerprintsTransformer""" + smiles = [ "CCOc1c(OC)cc(CCN)cc1OC", "COc1cc(CCN)cc(OC)c1OC", From 066b6fd1dfeb8f20b43da555d316742e9b37b1e9 Mon Sep 17 00:00:00 2001 From: hyunjunji Date: Tue, 13 Aug 2024 09:01:13 +0900 Subject: [PATCH 5/5] merge molfeat/main to fix/fp --- molfeat/trans/fp.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/molfeat/trans/fp.py b/molfeat/trans/fp.py index f1e6a17..256b714 100644 --- a/molfeat/trans/fp.py +++ b/molfeat/trans/fp.py @@ -1,14 +1,11 @@ -from typing import Callable -from typing import List -from typing import Optional -from typing import Union - -import re import copy -import numpy as np +import re +from typing import Callable, List, Optional, Union + import datamol as dm +import numpy as np -from molfeat.calc import get_calculator, FP_FUNCS +from molfeat.calc import FP_FUNCS, get_calculator from molfeat.trans.base import MoleculeTransformer from molfeat.utils import datatype from molfeat.utils.commons import _parse_to_evaluable_str