Commit 45597c8

Merge branch 'dev' into fix/model-specific-package

2 parents 50df357 + 0c9e492 commit 45597c8
File tree

8 files changed: +421 -24 lines changed

chebifier/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+# Note: The top-level package __init__.py runs only once,
+# even if multiple subpackages are imported later.
+
+from ._custom_cache import PerSmilesPerModelLRUCache
+
+modelwise_smiles_lru_cache = PerSmilesPerModelLRUCache(max_size=100)
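
This makes the package expose a single shared cache instance, so every predictor that imports modelwise_smiles_lru_cache reads and writes the same store. A minimal usage sketch of the cache API added in this commit (the SMILES string, model name, and result value below are hypothetical):

    from chebifier import modelwise_smiles_lru_cache

    # store and retrieve a result for one (SMILES, model_name) pair
    modelwise_smiles_lru_cache.set("CCO", "my_model", {"CHEBI:16236": 0.99})
    assert modelwise_smiles_lru_cache.get("CCO", "my_model") == {"CHEBI:16236": 0.99}
    print(modelwise_smiles_lru_cache.stats())  # {'hits': 1, 'misses': 0}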

chebifier/_custom_cache.py

Lines changed: 208 additions & 0 deletions
@@ -0,0 +1,208 @@
+import os
+import pickle
+import threading
+from collections import OrderedDict
+from collections.abc import Iterable
+from functools import wraps
+from typing import Any, Callable
+
+
+class PerSmilesPerModelLRUCache:
+    """
+    A thread-safe, optionally persistent LRU cache for storing
+    (SMILES, model_name) → result mappings.
+    """
+
+    def __init__(self, max_size: int = 100, persist_path: str | None = None):
+        """
+        Initialize the cache.
+
+        Args:
+            max_size (int): Maximum number of items to keep in the cache.
+            persist_path (str | None): Optional path to persist cache using pickle.
+        """
+        self._cache: OrderedDict[tuple[str, str], Any] = OrderedDict()
+        self._max_size = max_size
+        self._lock = threading.Lock()
+        self._persist_path = persist_path
+
+        self.hits = 0
+        self.misses = 0
+
+        if self._persist_path:
+            self._load_cache()
+
+    def get(self, smiles: str, model_name: str) -> Any | None:
+        """
+        Retrieve value from cache if present, otherwise return None.
+
+        Args:
+            smiles (str): SMILES string key.
+            model_name (str): Model identifier.
+
+        Returns:
+            Any | None: Cached value or None.
+        """
+        key = (smiles, model_name)
+        with self._lock:
+            if key in self._cache:
+                self._cache.move_to_end(key)
+                self.hits += 1
+                return self._cache[key]
+            else:
+                self.misses += 1
+                return None
+
+    def set(self, smiles: str, model_name: str, value: Any) -> None:
+        """
+        Store value in cache under (smiles, model_name) key.
+
+        Args:
+            smiles (str): SMILES string key.
+            model_name (str): Model identifier.
+            value (Any): Value to cache.
+        """
+        assert value is not None, "Value must not be None"
+        key = (smiles, model_name)
+        with self._lock:
+            if key in self._cache:
+                self._cache.move_to_end(key)
+            self._cache[key] = value
+            if len(self._cache) > self._max_size:
+                self._cache.popitem(last=False)
+
+    def clear(self) -> None:
+        """
+        Clear the cache and remove the persistence file if present.
+        """
+        self._save_cache()
+        with self._lock:
+            self._cache.clear()
+            self.hits = 0
+            self.misses = 0
+            if self._persist_path and os.path.exists(self._persist_path):
+                os.remove(self._persist_path)
+
+    def stats(self) -> dict[str, int]:
+        """
+        Return cache hit/miss statistics.
+
+        Returns:
+            dict[str, int]: Dictionary with 'hits' and 'misses' keys.
+        """
+        return {"hits": self.hits, "misses": self.misses}
+
+    def batch_decorator(self, func: Callable) -> Callable:
+        """
+        Decorator for class methods that accept a batch of SMILES as a list,
+        and cache predictions per (smiles, model_name) key.
+
+        The instance is expected to have a `model_name` attribute.
+
+        Args:
+            func (Callable): The method to decorate.
+
+        Returns:
+            Callable: The wrapped method.
+        """
+
+        @wraps(func)
+        def wrapper(instance, smiles_list: list[str]) -> list[Any]:
+            assert isinstance(smiles_list, list), "smiles_list must be a list."
+            model_name = getattr(instance, "model_name", None)
+            assert model_name is not None, "Instance must have a model_name attribute."
+
+            missing_smiles: list[str] = []
+            missing_indices: list[int] = []
+            ordered_results: list[Any] = [None] * len(smiles_list)
+
+            # First: try to fetch all from cache
+            for idx, smiles in enumerate(smiles_list):
+                prediction = self.get(smiles=smiles, model_name=model_name)
+                if prediction is not None:
+                    # For debugging purposes, you can uncomment the print statement below
+                    # print(
+                    #     f"[Cache Hit] Prediction for smiles: {smiles} and model: {model_name} are retrieved from cache."
+                    # )
+                    ordered_results[idx] = prediction
+                else:
+                    missing_smiles.append(smiles)
+                    missing_indices.append(idx)
+
+            # If some are missing, call original function
+            if missing_smiles:
+                new_results = func(instance, tuple(missing_smiles))
+                assert isinstance(
+                    new_results, Iterable
+                ), "Function must return an Iterable."
+
+                # Save to cache and append
+                for smiles, prediction, missing_idx in zip(
+                    missing_smiles, new_results, missing_indices
+                ):
+                    if prediction is not None:
+                        self.set(smiles, model_name, prediction)
+                        ordered_results[missing_idx] = prediction
+
+            return ordered_results
+
+        return wrapper
+
+    def __len__(self) -> int:
+        """
+        Return number of items in the cache.
+
+        Returns:
+            int: Number of entries in the cache.
+        """
+        with self._lock:
+            return len(self._cache)
+
+    def __repr__(self) -> str:
+        """
+        String representation of the underlying cache.
+
+        Returns:
+            str: String version of the OrderedDict.
+        """
+        return self._cache.__repr__()
+
+    def save(self) -> None:
+        """
+        Save the cache to disk, if persistence is enabled.
+        """
+        self._save_cache()
+
+    def load(self) -> None:
+        """
+        Load the cache from disk, if persistence is enabled.
+        """
+        self._load_cache()
+
+    def _save_cache(self) -> None:
+        """
+        Serialize the cache to disk using pickle.
+        """
+        if self._persist_path:
+            try:
+                with open(self._persist_path, "wb") as f:
+                    pickle.dump(self._cache, f)
+            except Exception as e:
+                print(f"[Cache Save Error] {e}")
+
+    def _load_cache(self) -> None:
+        """
+        Load the cache from disk, if the file exists and is non-empty.
+        """
+        if (
+            self._persist_path
+            and os.path.exists(self._persist_path)
+            and os.path.getsize(self._persist_path) > 0
+        ):
+            try:
+                with open(self._persist_path, "rb") as f:
+                    loaded = pickle.load(f)
+                if isinstance(loaded, OrderedDict):
+                    self._cache = loaded
+            except Exception as e:
+                print(f"[Cache Load Error] {e}")

chebifier/prediction_models/base_predictor.py

Lines changed: 3 additions & 7 deletions
@@ -1,7 +1,7 @@
 import json
 from abc import ABC
 
-from functools import lru_cache
+from chebifier import modelwise_smiles_lru_cache
 
 
 class BasePredictor(ABC):
@@ -23,17 +23,13 @@ def __init__(
 
         self._description = kwargs.get("description", None)
 
+    @modelwise_smiles_lru_cache.batch_decorator
     def predict_smiles_list(self, smiles_list: list[str]) -> dict:
-        # list is not hashable, so we convert it to a tuple (useful for caching)
-        return self.predict_smiles_tuple(tuple(smiles_list))
-
-    @lru_cache(maxsize=100)
-    def predict_smiles_tuple(self, smiles_tuple: tuple[str]) -> dict:
         raise NotImplementedError()
 
     def predict_smiles(self, smiles: str) -> dict:
         # by default, use list-based prediction
-        return self.predict_smiles_tuple((smiles,))[0]
+        return self.predict_smiles_list([smiles])[0]
 
     @property
     def info_text(self):
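
With this change, a concrete predictor only implements predict_smiles_list (re-applying the decorator, as the other files in this commit do); single-SMILES calls go through the same cached batch path. A hypothetical sketch:

    from chebifier import modelwise_smiles_lru_cache
    from chebifier.prediction_models import BasePredictor

    class ConstantPredictor(BasePredictor):
        """Hypothetical subclass, for illustration only."""

        @modelwise_smiles_lru_cache.batch_decorator
        def predict_smiles_list(self, smiles_list: list[str]) -> list:
            return [{"CHEBI:24431": 1.0} for _ in smiles_list]

    # predict_smiles delegates to the decorated batch method, so the result is cached:
    # ConstantPredictor("demo").predict_smiles("CCO")  -> {"CHEBI:24431": 1.0}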

chebifier/prediction_models/c3p_predictor.py

Lines changed: 7 additions & 6 deletions
@@ -1,7 +1,9 @@
-from functools import lru_cache
 from pathlib import Path
 from typing import List, Optional
 
+from c3p import classifier as c3p_classifier
+
+from chebifier import modelwise_smiles_lru_cache
 from chebifier.prediction_models import BasePredictor
 
 
@@ -21,11 +23,10 @@ def __init__(
         self.program_directory = program_directory
         self.chemical_classes = chemical_classes
         self.chebi_graph = kwargs.get("chebi_graph", None)
-
-    @lru_cache(maxsize=100)
-    def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list:
-        from c3p import classifier as c3p_classifier
-
+
+    @modelwise_smiles_lru_cache.batch_decorator
+    def predict_smiles_list(self, smiles_list: list[str]) -> list:
+        from c3p import classifier as c3p_classifier
         result_list = c3p_classifier.classify(
             list(smiles_list),
             self.program_directory,

chebifier/prediction_models/chebi_lookup.py

Lines changed: 6 additions & 3 deletions
@@ -1,10 +1,13 @@
 import json
 import os
-from functools import lru_cache
 from typing import Optional
 
 from rdkit import Chem
 
+from chebifier import modelwise_smiles_lru_cache
+from chebifier.prediction_models import BasePredictor
+from chebifier.utils import load_chebi_graph
+
 from chebifier.prediction_models import BasePredictor
@@ -69,7 +72,6 @@ def build_smiles_lookup(self):
         )
         return smiles_lookup
 
-    @lru_cache(maxsize=100)
     def predict_smiles(self, smiles: str) -> Optional[dict]:
         if not smiles:
             return None
@@ -96,7 +98,8 @@ def predict_smiles(self, smiles: str) -> Optional[dict]:
         else:
             return None
 
-    def predict_smiles_tuple(self, smiles_list: list[str]) -> list:
+    @modelwise_smiles_lru_cache.batch_decorator
+    def predict_smiles_list(self, smiles_list: list[str]) -> list:
         predictions = []
         for smiles in smiles_list:
             predictions.append(self.predict_smiles(smiles))

chebifier/prediction_models/chemlog_predictor.py

Lines changed: 5 additions & 4 deletions
@@ -33,14 +33,15 @@
 
 
 class ChemlogExtraPredictor(BasePredictor):
+
     def __init__(self, model_name: str, **kwargs):
         super().__init__(model_name, **kwargs)
         self.chebi_graph = kwargs.get("chebi_graph", None)
         self.classifier = None
 
-    def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list:
+    @modelwise_smiles_lru_cache.batch_decorator
+    def predict_smiles_list(self, smiles_list: list[str]) -> list:
         from chemlog.cli import _smiles_to_mol
-
         mol_list = [_smiles_to_mol(smiles) for smiles in smiles_list]
         res = self.classifier.classify(mol_list)
         if self.chebi_graph is not None:
@@ -94,7 +95,6 @@ def __init__(self, model_name: str, **kwargs):
         # fmt: on
         print(f"Initialised ChemLog model {self.model_name}")
 
-    @lru_cache(maxsize=100)
     def predict_smiles(self, smiles: str) -> Optional[dict]:
         from chemlog.cli import _smiles_to_mol, strategy_call
 
@@ -121,7 +121,8 @@ def predict_smiles(self, smiles: str) -> Optional[dict]:
             for label in self.peptide_labels + pos_labels
         }
 
-    def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list:
+    @modelwise_smiles_lru_cache.batch_decorator
+    def predict_smiles_list(self, smiles_list: list[str]) -> list:
         results = []
         for i, smiles in tqdm.tqdm(enumerate(smiles_list)):
             results.append(self.predict_smiles(smiles))

chebifier/prediction_models/nn_predictor.py

Lines changed: 4 additions & 4 deletions
@@ -1,9 +1,9 @@
-from functools import lru_cache
-
 import numpy as np
 import tqdm
 from rdkit import Chem
 
+from chebifier import modelwise_smiles_lru_cache
+
 from .base_predictor import BasePredictor
 
 
@@ -53,8 +53,8 @@ def read_smiles(self, smiles):
         d = reader.to_data(dict(features=smiles, labels=None))
         return d
 
-    @lru_cache(maxsize=100)
-    def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list:
+    @modelwise_smiles_lru_cache.batch_decorator
+    def predict_smiles_list(self, smiles_list: list[str]) -> list:
         """Returns a list with the length of smiles_list, each element is either None (=failure) or a dictionary
         Of classes and predicted values."""
         import torch
