Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 38a2a87

Browse files
committed Mar 19, 2025
dwelltime: allow parallel bootstraps
1 parent 5db41ad commit 38a2a87

File tree

2 files changed

+44
-5
lines changed

2 files changed

+44
-5
lines changed
 

‎lumicks/pylake/population/dwelltime.py

+28 -5
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
import warnings
22
from typing import Dict, Tuple, Union
33
from dataclasses import field, dataclass
4+
from multiprocessing import Pool
45

56
import numpy as np
67
import scipy
@@ -292,7 +293,7 @@ def __post_init__(self):
292293
)
293294

294295
@classmethod
295-
def _from_dwelltime_model(cls, optimized, iterations):
296+
def _from_dwelltime_model(cls, optimized, iterations, num_processes=None):
296297
"""Construct bootstrap distributions for parameters from an optimized
297298
:class:`~lumicks.pylake.DwelltimeModel`.
298299
@@ -306,9 +307,27 @@ def _from_dwelltime_model(cls, optimized, iterations):
306307
optimized model results
307308
iterations : int
308309
number of iterations (random samples) to use for the bootstrap
310+
num_processes : int | None
311+
number of processes to use for parallelization. If `None`, no parallelization is used.
309312
"""
310-
samples = DwelltimeBootstrap._sample(optimized, iterations)
311-
return cls(optimized, samples[: optimized.n_components], samples[optimized.n_components :])
313+
314+
def from_samples(samples):
315+
return cls(
316+
optimized,
317+
samples[: optimized.n_components],
318+
samples[optimized.n_components :],
319+
)
320+
321+
if num_processes is None:
322+
return from_samples(DwelltimeBootstrap._sample(optimized, iterations))
323+
324+
with Pool(num_processes) as p:
325+
result = p.starmap(
326+
DwelltimeBootstrap._sample,
327+
[(optimized, int(np.ceil(iterations / num_processes)))] * num_processes,
328+
)
329+
330+
return from_samples(np.hstack(result)[:, :iterations])
312331

313332
def extend(self, iterations):
314333
"""Extend the distribution by additional sampling iterations.
@@ -721,15 +740,19 @@ def bic(self) -> float:
721740
n = self.dwelltimes.size # number of observations
722741
return k * np.log(n) - 2 * self.log_likelihood
723742

724-
def calculate_bootstrap(self, iterations=500):
743+
def calculate_bootstrap(self, iterations=500, *, num_processes=None):
725744
"""Calculate a bootstrap distribution for the model.
726745
727746
Parameters
728747
----------
729748
iterations : int
730749
Number of iterations to sample for the distribution.
750+
num_processes : int | None
751+
Number of processes to use for parallelization. If `None`, no parallelization is used.
731752
"""
732-
bootstrap = DwelltimeBootstrap._from_dwelltime_model(self, iterations)
753+
bootstrap = DwelltimeBootstrap._from_dwelltime_model(
754+
self, iterations, num_processes=num_processes
755+
)
733756
# TODO: remove with deprecation
734757
self._bootstrap = bootstrap
735758
return bootstrap

‎lumicks/pylake/population/tests/test_dwelltimes.py

+16
Original file line number | Diff line number | Diff line change
@@ -158,6 +158,22 @@ def test_bootstrap_multi(min_obs, max_obs, ref_ci, time_step):
158158
np.testing.assert_allclose(ci, ref_ci, rtol=1e-5)
159159

160160

161+
@pytest.mark.parametrize("num_components", [1, 2])
162+
def test_bootstrap_parallel(monkeypatch, exponential_data, num_components):
163+
import multiprocessing.dummy
164+
165+
dataset = exponential_data["dataset_2exp"]
166+
fit = DwelltimeModel(
167+
dataset["data"], num_components, **dataset["parameters"].observation_limits
168+
)
169+
170+
with monkeypatch.context() as m:
171+
m.setattr("multiprocessing.Pool", multiprocessing.dummy.Pool)
172+
173+
bootstrap = fit.calculate_bootstrap(iterations=3, num_processes=2)
174+
assert bootstrap.n_samples == 3
175+
176+
161177
@pytest.mark.slow
162178
def test_bootstrap(exponential_data):
163179
# double exponential data

0 commit comments

Comments (0)
Please sign in to comment.