22
33from __future__ import annotations
44
5+ import warnings
56from concurrent .futures import Executor
67from concurrent .futures import Future
78from concurrent .futures import ProcessPoolExecutor
1213from typing import ClassVar
1314
1415import cloudpickle
16+ from attrs import define
1517from loky import get_reusable_executor
1618
1719__all__ = ["ParallelBackend" , "ParallelBackendRegistry" , "registry" ]
@@ -27,7 +29,7 @@ def _deserialize_and_run_with_cloudpickle(fn: bytes, kwargs: bytes) -> Any:
2729class _CloudpickleProcessPoolExecutor (ProcessPoolExecutor ):
2830 """Patches the standard executor to serialize functions with cloudpickle."""
2931
30- # The type signature is wrong for version above Py3.7 . Fix when 3.7 is deprecated .
32+ # The type signature is wrong for Python >3.8 . Fix when support is dropped .
3133 def submit ( # type: ignore[override]
3234 self ,
3335 fn : Callable [..., Any ],
@@ -42,15 +44,54 @@ def submit( # type: ignore[override]
4244 )
4345
4446
def _get_dask_executor(n_workers: int) -> Executor:
    """Return an executor backed by a dask client.

    Reuses the currently active dask client when one exists; otherwise spins up
    a local cluster sized to ``n_workers``. ``distributed`` is an optional
    dependency and is imported lazily via pytask's helper.
    """
    _rich_traceback_omit = True
    from pytask import import_optional_dependency

    distributed = import_optional_dependency("distributed")
    try:
        dask_client = distributed.Client.current()
    except ValueError:
        # No active client yet: create a fresh local cluster for the request.
        local_cluster = distributed.LocalCluster(n_workers=n_workers)
        dask_client = distributed.Client(local_cluster)
    else:
        # An already-running cluster wins over the requested worker count;
        # emit a warning when the two disagree.
        cluster = dask_client.cluster
        if cluster and len(cluster.workers) != n_workers:
            warnings.warn(
                "The number of workers in the dask cluster "
                f"({len(cluster.workers)}) does not match the number of workers "
                f"requested ({n_workers}). The requested number of workers will be "
                "ignored.",
                stacklevel=1,
            )
    return dask_client.get_executor()
68+
def _get_loky_executor(n_workers: int) -> Executor:
    """Return loky's reusable executor capped at ``n_workers`` workers."""
    executor = get_reusable_executor(max_workers=n_workers)
    return executor
73+
def _get_process_pool_executor(n_workers: int) -> Executor:
    """Return a process pool whose callables are serialized with cloudpickle."""
    pool = _CloudpickleProcessPoolExecutor(max_workers=n_workers)
    return pool
78+
79+ def _get_thread_pool_executor (n_workers : int ) -> Executor :
80+ """Get a thread pool executor."""
81+ return ThreadPoolExecutor (max_workers = n_workers )
82+
83+
class ParallelBackend(Enum):
    """Selectable parallel backends.

    Each member's value is the string a user spells (e.g. on the command line
    or in configuration) to pick that backend.
    """

    CUSTOM = "custom"  # a user-registered executor
    DASK = "dask"
    LOKY = "loky"
    PROCESSES = "processes"
    THREADS = "threads"
5393
94+ @define
5495class ParallelBackendRegistry :
5596 """Registry for parallel backends."""
5697
@@ -68,23 +109,19 @@ def get_parallel_backend(self, kind: ParallelBackend, n_workers: int) -> Executo
68109 try :
69110 return self .registry [kind ](n_workers = n_workers )
70111 except KeyError :
71- msg = f"No registered parallel backend found for kind { kind } ."
112+ msg = f"No registered parallel backend found for kind { kind . value !r } ."
72113 raise ValueError (msg ) from None
73114 except Exception as e : # noqa: BLE001
74- msg = f"Could not instantiate parallel backend { kind .value } ."
115+ msg = f"Could not instantiate parallel backend { kind .value !r } ."
75116 raise ValueError (msg ) from e
76117
77118
# Module-level default registry, pre-populated with the built-in backends.
registry = ParallelBackendRegistry()


for _kind, _builder in (
    (ParallelBackend.DASK, _get_dask_executor),
    (ParallelBackend.LOKY, _get_loky_executor),
    (ParallelBackend.PROCESSES, _get_process_pool_executor),
    (ParallelBackend.THREADS, _get_thread_pool_executor),
):
    registry.register_parallel_backend(_kind, _builder)
del _kind, _builder