
Commit ac2fda1

Add support for dask. (#86)
1 parent: d3a7bdc

14 files changed, +325 -45 lines

CHANGES.md (+1)
@@ -8,6 +8,7 @@ releases are available on [PyPI](https://pypi.org/project/pytask-parallel) and
 ## 0.5.0 - 2024-xx-xx

 - {pull}`85` simplifies code since loky is a dependency.
+- {pull}`86` adds support for dask.
 - {pull}`88` updates handling `Traceback`.
 - {pull}`89` restructures the package.
 - {pull}`92` redirects stdout and stderr from processes and loky and shows them in error

environment.yml (+6)
@@ -16,6 +16,12 @@ dependencies:
   - loky
   - optree

+  # Additional dependencies
+  - universal_pathlib <0.2
+  - s3fs>=2023.4.0
+  - coiled
+  - distributed
+
   # Misc
   - tox
   - ipywidgets

pyproject.toml (+6 -3)
@@ -29,10 +29,12 @@ name = "Tobias Raabe"

 [project.optional-dependencies]
+dask = ["dask[complete]", "distributed"]
 test = [
-    "nbmake",
-    "pytest",
-    "pytest-cov",
+    "pytask-parallel[all]",
+    "nbmake",
+    "pytest",
+    "pytest-cov",
 ]

 [project.readme]

@@ -112,6 +114,7 @@ force-single-line = true
 convention = "numpy"

 [tool.pytest.ini_options]
+addopts = ["--nbmake"]
 # Do not add src since it messes with the loading of pytask-parallel as a plugin.
 testpaths = ["tests"]
 markers = [
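
With the new optional-dependency group, the dask backend becomes opt-in and can be installed via the standard extras syntax, e.g. `pip install pytask-parallel[dask]`. The `test` extra now pulls in `pytask-parallel[all]` (an aggregate extra presumably defined elsewhere in the file), and the new `--nbmake` addopt makes pytest execute the documentation notebooks through the nbmake plugin already listed in the test dependencies.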

src/pytask_parallel/backends.py (+48 -11)
@@ -2,6 +2,7 @@

 from __future__ import annotations

+import warnings
 from concurrent.futures import Executor
 from concurrent.futures import Future
 from concurrent.futures import ProcessPoolExecutor

@@ -12,6 +13,7 @@
 from typing import ClassVar

 import cloudpickle
+from attrs import define
 from loky import get_reusable_executor

 __all__ = ["ParallelBackend", "ParallelBackendRegistry", "registry"]

@@ -27,7 +29,7 @@ def _deserialize_and_run_with_cloudpickle(fn: bytes, kwargs: bytes) -> Any:
 class _CloudpickleProcessPoolExecutor(ProcessPoolExecutor):
     """Patches the standard executor to serialize functions with cloudpickle."""

-    # The type signature is wrong for version above Py3.7. Fix when 3.7 is deprecated.
+    # The type signature is wrong for Python >3.8. Fix when support is dropped.
     def submit(  # type: ignore[override]
         self,
         fn: Callable[..., Any],
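
For background, the patched `submit` serializes the function and its kwargs with cloudpickle so that lambdas and closures survive the trip to a worker process, where `_deserialize_and_run_with_cloudpickle` (named in the hunk header above) unpickles and calls them. A minimal sketch of that round-trip, with an illustrative lambda:

    # Sketch of the cloudpickle round-trip the patched executor relies on.
    import cloudpickle

    payload = cloudpickle.dumps(lambda x: x + 1)  # the stdlib pickle would fail here
    fn = cloudpickle.loads(payload)
    assert fn(1) == 2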
@@ -42,15 +44,54 @@ def submit(  # type: ignore[override]
         )


+def _get_dask_executor(n_workers: int) -> Executor:
+    """Get an executor from a dask client."""
+    _rich_traceback_omit = True
+    from pytask import import_optional_dependency
+
+    distributed = import_optional_dependency("distributed")
+    try:
+        client = distributed.Client.current()
+    except ValueError:
+        client = distributed.Client(distributed.LocalCluster(n_workers=n_workers))
+    else:
+        if client.cluster and len(client.cluster.workers) != n_workers:
+            warnings.warn(
+                "The number of workers in the dask cluster "
+                f"({len(client.cluster.workers)}) does not match the number of workers "
+                f"requested ({n_workers}). The requested number of workers will be "
+                "ignored.",
+                stacklevel=1,
+            )
+    return client.get_executor()
+
+
+def _get_loky_executor(n_workers: int) -> Executor:
+    """Get a loky executor."""
+    return get_reusable_executor(max_workers=n_workers)
+
+
+def _get_process_pool_executor(n_workers: int) -> Executor:
+    """Get a process pool executor."""
+    return _CloudpickleProcessPoolExecutor(max_workers=n_workers)
+
+
+def _get_thread_pool_executor(n_workers: int) -> Executor:
+    """Get a thread pool executor."""
+    return ThreadPoolExecutor(max_workers=n_workers)
+
+
 class ParallelBackend(Enum):
     """Choices for parallel backends."""

     CUSTOM = "custom"
+    DASK = "dask"
     LOKY = "loky"
     PROCESSES = "processes"
     THREADS = "threads"


+@define
 class ParallelBackendRegistry:
     """Registry for parallel backends."""
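
A minimal sketch of the client discovery that `_get_dask_executor` performs: `distributed.Client.current()` raises a `ValueError` when no client is connected, and only then is a `LocalCluster` spun up. This assumes `distributed` is installed; the worker count and the submitted task are illustrative.

    import distributed

    try:
        client = distributed.Client.current()  # ValueError if no client is connected
    except ValueError:
        client = distributed.Client(distributed.LocalCluster(n_workers=2))

    executor = client.get_executor()  # a concurrent.futures-compatible executor
    assert executor.submit(sum, [1, 2, 3]).result() == 6
    client.close()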

@@ -68,23 +109,19 @@ def get_parallel_backend(self, kind: ParallelBackend, n_workers: int) -> Executor:
         try:
             return self.registry[kind](n_workers=n_workers)
         except KeyError:
-            msg = f"No registered parallel backend found for kind {kind}."
+            msg = f"No registered parallel backend found for kind {kind.value!r}."
             raise ValueError(msg) from None
         except Exception as e:  # noqa: BLE001
-            msg = f"Could not instantiate parallel backend {kind.value}."
+            msg = f"Could not instantiate parallel backend {kind.value!r}."
             raise ValueError(msg) from e


 registry = ParallelBackendRegistry()


+registry.register_parallel_backend(ParallelBackend.DASK, _get_dask_executor)
+registry.register_parallel_backend(ParallelBackend.LOKY, _get_loky_executor)
 registry.register_parallel_backend(
-    ParallelBackend.PROCESSES,
-    lambda n_workers: _CloudpickleProcessPoolExecutor(max_workers=n_workers),
-)
-registry.register_parallel_backend(
-    ParallelBackend.THREADS, lambda n_workers: ThreadPoolExecutor(max_workers=n_workers)
-)
-registry.register_parallel_backend(
-    ParallelBackend.LOKY, lambda n_workers: get_reusable_executor(max_workers=n_workers)
+    ParallelBackend.PROCESSES, _get_process_pool_executor
 )
+registry.register_parallel_backend(ParallelBackend.THREADS, _get_thread_pool_executor)
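
The module-level registrations now use the named helpers instead of lambdas, and `ParallelBackend.DASK` slots into the same registry as the existing backends. The `CUSTOM` member uses the identical mechanism; a hypothetical user registration might look like this (the executor choice is purely illustrative):

    from concurrent.futures import ThreadPoolExecutor

    from pytask_parallel.backends import ParallelBackend, registry

    # Any callable mapping n_workers to a concurrent.futures.Executor works here.
    registry.register_parallel_backend(
        ParallelBackend.CUSTOM,
        lambda n_workers: ThreadPoolExecutor(max_workers=n_workers),
    )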

src/pytask_parallel/config.py (+14 -11)
@@ -8,7 +8,9 @@
 from pytask import hookimpl

 from pytask_parallel import custom
+from pytask_parallel import dask
 from pytask_parallel import execute
+from pytask_parallel import logging
 from pytask_parallel import processes
 from pytask_parallel import threads
 from pytask_parallel.backends import ParallelBackend
@@ -37,19 +39,20 @@ def pytask_parse_config(config: dict[str, Any]) -> None:
 @hookimpl(trylast=True)
 def pytask_post_parse(config: dict[str, Any]) -> None:
     """Register the parallel backend if debugging is not enabled."""
+    # Deactivate parallel execution if debugger, trace or dry-run is used.
     if config["pdb"] or config["trace"] or config["dry_run"]:
-        config["n_workers"] = 1
+        return

-    # Register parallel execute hook.
-    if config["n_workers"] > 1 or config["parallel_backend"] == ParallelBackend.CUSTOM:
-        config["pm"].register(execute)
+    # Register parallel execute and logging hook.
+    config["pm"].register(logging)
+    config["pm"].register(execute)

     # Register parallel backends.
-    if config["n_workers"] > 1:
-        if config["parallel_backend"] == ParallelBackend.THREADS:
-            config["pm"].register(threads)
-        else:
-            config["pm"].register(processes)
-
-    if config["parallel_backend"] == ParallelBackend.CUSTOM:
+    if config["parallel_backend"] == ParallelBackend.THREADS:
+        config["pm"].register(threads)
+    elif config["parallel_backend"] == ParallelBackend.DASK:
+        config["pm"].register(dask)
+    elif config["parallel_backend"] == ParallelBackend.CUSTOM:
         config["pm"].register(custom)
+    else:
+        config["pm"].register(processes)
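
Two behavioral changes are worth noting: debugging (`--pdb`, `--trace`, `--dry-run`) now returns early instead of forcing `n_workers = 1`, and the execute and logging hooks plus a backend module are always registered, with the `parallel_backend` value alone selecting the module (`threads`, the new `dask`, `custom`, or the process-based default for loky and processes). From the command line the new backend would presumably be selected via pytask-parallel's existing options, e.g. `pytask --parallel-backend dask -n 2`.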
