2
2
3
3
from __future__ import annotations
4
4
5
+ import warnings
5
6
from concurrent .futures import Executor
6
7
from concurrent .futures import Future
7
8
from concurrent .futures import ProcessPoolExecutor
12
13
from typing import ClassVar
13
14
14
15
import cloudpickle
16
+ from attrs import define
15
17
from loky import get_reusable_executor
16
18
17
19
__all__ = ["ParallelBackend" , "ParallelBackendRegistry" , "registry" ]
@@ -27,7 +29,7 @@ def _deserialize_and_run_with_cloudpickle(fn: bytes, kwargs: bytes) -> Any:
27
29
class _CloudpickleProcessPoolExecutor (ProcessPoolExecutor ):
28
30
"""Patches the standard executor to serialize functions with cloudpickle."""
29
31
30
- # The type signature is wrong for version above Py3.7 . Fix when 3.7 is deprecated .
32
+ # The type signature is wrong for Python >3.8 . Fix when support is dropped .
31
33
def submit ( # type: ignore[override]
32
34
self ,
33
35
fn : Callable [..., Any ],
@@ -42,15 +44,54 @@ def submit( # type: ignore[override]
42
44
)
43
45
44
46
47
def _get_dask_executor(n_workers: int) -> Executor:
    """Return an executor obtained from a dask ``distributed`` client.

    The client returned by ``distributed.Client.current()`` is reused. If that call
    raises ``ValueError``, a new client is created on top of a ``LocalCluster``
    started with ``n_workers`` workers.

    When an existing client is reused and its cluster reports a number of workers
    different from ``n_workers``, a warning is emitted and the requested number of
    workers is ignored.

    """
    # Frame-local marker read by rich's traceback rendering; presumably it hides
    # this frame from rendered tracebacks -- confirm against the rich docs.
    _rich_traceback_omit = True
    from pytask import import_optional_dependency

    distributed = import_optional_dependency("distributed")

    created_new_client = False
    try:
        client = distributed.Client.current()
    except ValueError:
        client = distributed.Client(distributed.LocalCluster(n_workers=n_workers))
        created_new_client = True

    # A freshly created client already has the requested number of workers.
    if created_new_client:
        return client.get_executor()

    cluster = client.cluster
    if cluster:
        n_available = len(cluster.workers)
        if n_available != n_workers:
            message = (
                "The number of workers in the dask cluster "
                f"({n_available}) does not match the number of workers "
                f"requested ({n_workers}). The requested number of workers will be "
                "ignored."
            )
            warnings.warn(message, stacklevel=1)
    return client.get_executor()
67
+
68
+
69
def _get_loky_executor(n_workers: int) -> Executor:
    """Return loky's reusable executor, requested with ``n_workers`` workers."""
    executor = get_reusable_executor(max_workers=n_workers)
    return executor
72
+
73
+
74
def _get_process_pool_executor(n_workers: int) -> Executor:
    """Return a cloudpickle-patched process pool with ``n_workers`` workers."""
    executor = _CloudpickleProcessPoolExecutor(max_workers=n_workers)
    return executor
77
+
78
+
79
def _get_thread_pool_executor(n_workers: int) -> Executor:
    """Return a ``ThreadPoolExecutor`` limited to ``n_workers`` threads."""
    executor = ThreadPoolExecutor(max_workers=n_workers)
    return executor
82
+
83
+
45
84
class ParallelBackend(Enum):
    """Enumeration of the selectable parallel backends.

    Each member's value is the lowercase string that identifies the backend.

    """

    # Backend without a builder in this module.
    CUSTOM = "custom"

    # Executor obtained from a dask ``distributed`` client.
    DASK = "dask"

    # loky's reusable executor.
    LOKY = "loky"

    # Process pool that serializes functions with cloudpickle.
    PROCESSES = "processes"

    # Thread pool from ``concurrent.futures``.
    THREADS = "threads"
52
92
53
93
94
+ @define
54
95
class ParallelBackendRegistry :
55
96
"""Registry for parallel backends."""
56
97
@@ -68,23 +109,19 @@ def get_parallel_backend(self, kind: ParallelBackend, n_workers: int) -> Executo
68
109
try :
69
110
return self .registry [kind ](n_workers = n_workers )
70
111
except KeyError :
71
- msg = f"No registered parallel backend found for kind { kind } ."
112
+ msg = f"No registered parallel backend found for kind { kind . value !r } ."
72
113
raise ValueError (msg ) from None
73
114
except Exception as e : # noqa: BLE001
74
- msg = f"Could not instantiate parallel backend { kind .value } ."
115
+ msg = f"Could not instantiate parallel backend { kind .value !r } ."
75
116
raise ValueError (msg ) from e
76
117
77
118
78
119
# Module-level registry instance; exported through ``__all__``.
registry = ParallelBackendRegistry()

# Register the builders defined in this module, in the same order as before:
# dask, loky, processes, threads.
for _kind, _builder in (
    (ParallelBackend.DASK, _get_dask_executor),
    (ParallelBackend.LOKY, _get_loky_executor),
    (ParallelBackend.PROCESSES, _get_process_pool_executor),
    (ParallelBackend.THREADS, _get_thread_pool_executor),
):
    registry.register_parallel_backend(_kind, _builder)
# Do not leave the loop variables behind as module attributes.
del _kind, _builder
0 commit comments