Skip to content

Commit 9cd4160

Browse files
committed
Fix pickle support to allow for parallelization and add parallelization tests
Fixed pickle support by removing all ctypes pointers from the state in CadetDLLRunner.__getstate__ and recreating the dll interface in CadetDLLRunner.__setstate__ . Fixed "no attribute __frozen" error by casting Cadet state into addict.Dict on Cadet.__setstate__ .
1 parent b99ef3c commit 9cd4160

File tree

4 files changed

+70
-1
lines changed

4 files changed

+70
-1
lines changed

cadet/cadet.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from typing import Optional
77
import warnings
88

9+
from addict import Dict
10+
911
from cadet.h5 import H5
1012
from cadet.runner import CadetRunnerBase, CadetCLIRunner, ReturnInformation
1113
from cadet.cadet_dll import CadetDLLRunner
@@ -525,3 +527,12 @@ def __del__(self):
525527
self.clear()
526528
del self._cadet_dll_runner
527529
del self._cadet_cli_runner
530+
531+
def __getstate__(self):
532+
state = self.__dict__.copy()
533+
return state
534+
535+
def __setstate__(self, state):
536+
# Restore the state and cast to addict.Dict() to add __frozen attributes
537+
state = Dict(state)
538+
self.__dict__.update(state)

cadet/cadet_dll.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1626,6 +1626,9 @@ def __init__(self, dll_path: os.PathLike | str) -> None:
16261626
Path to the CADET DLL.
16271627
"""
16281628
self._cadet_path = Path(dll_path)
1629+
self._initialize_dll()
1630+
1631+
def _initialize_dll(self):
16291632
self._lib = ctypes.cdll.LoadLibrary(self._cadet_path.as_posix())
16301633

16311634
# Query meta information
@@ -1693,6 +1696,18 @@ def __init__(self, dll_path: os.PathLike | str) -> None:
16931696
self._driver = self._api.createDriver()
16941697
self.res: Optional[SimulationResult] = None
16951698

1699+
def __getstate__(self):
1700+
# Exclude all non-pickleable attributes and only keep _cadet_path
1701+
state = self.__dict__.copy()
1702+
pickleable_keys = ["_cadet_path"]
1703+
state = {key: state[key] for key in pickleable_keys}
1704+
return state
1705+
1706+
def __setstate__(self, state):
1707+
# Restore the state and reinitialize the DLL
1708+
self.__dict__.update(state)
1709+
self._initialize_dll()
1710+
16961711
def clear(self) -> None:
16971712
"""
16981713
Clear the current simulation state.

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,10 @@ dependencies = [
3333
]
3434

3535
[project.optional-dependencies]
36-
testing = ["pytest"]
36+
testing = [
37+
"pytest",
38+
"joblib"
39+
]
3740

3841
[project.urls]
3942
"homepage" = "https://github.com/cadet/CADET-Python"

tests/test_parallelization.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from cadet import Cadet
2+
from joblib import Parallel, delayed
3+
from .test_dll import setup_model
4+
5+
n_jobs = 2
6+
7+
8+
def run_simulation(model):
9+
model.save()
10+
data = model.run_load()
11+
return data
12+
13+
14+
def test_parallelization_io():
15+
model1 = Cadet()
16+
model1.root.input = {'model': 1}
17+
model1.filename = "sim_1.h5"
18+
model2 = Cadet()
19+
model2.root.input = {'model': 2}
20+
model2.filename = "sim_2.h5"
21+
22+
models = [model1, model2]
23+
24+
results_sequential = [run_simulation(model) for model in models]
25+
26+
results_parallel = Parallel(n_jobs=n_jobs, verbose=0)(
27+
delayed(run_simulation)(model, ) for model in models
28+
)
29+
assert results_sequential == results_parallel
30+
31+
32+
def test_parallelization_simulation():
33+
models = [setup_model(Cadet.autodetect_cadet(), file_name=f"LWE_{i}.h5") for i in range(2)]
34+
35+
results_sequential = [run_simulation(model) for model in models]
36+
37+
results_parallel = Parallel(n_jobs=n_jobs, verbose=0)(
38+
delayed(run_simulation)(model, ) for model in models
39+
)
40+
assert results_sequential == results_parallel

0 commit comments

Comments
 (0)