Skip to content

Commit 2a68268

Browse files
authored
Simplify code since loky is a dependency. (#85)
1 parent 359dd18 commit 2a68268

File tree

5 files changed

+53
-63
lines changed

5 files changed

+53
-63
lines changed

CHANGES.md

+4
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@ chronological order. Releases follow [semantic versioning](https://semver.org/)
55
releases are available on [PyPI](https://pypi.org/project/pytask-parallel) and
66
[Anaconda.org](https://anaconda.org/conda-forge/pytask-parallel).
77

8+
## 0.4.2 - 2024-xx-xx
9+
10+
- {pull}`85` simplifies code since loky is a dependency.
11+
812
## 0.4.1 - 2024-01-12
913

1014
- {pull}`72` moves the project to `pyproject.toml`.

src/pytask_parallel/backends.py

+12-32
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22

33
from __future__ import annotations
44

5-
import enum
65
from concurrent.futures import Future
76
from concurrent.futures import ProcessPoolExecutor
87
from concurrent.futures import ThreadPoolExecutor
8+
from enum import Enum
99
from typing import Any
1010
from typing import Callable
1111

1212
import cloudpickle
13+
from loky import get_reusable_executor
1314

1415

1516
def deserialize_and_run_with_cloudpickle(fn: bytes, kwargs: bytes) -> Any:
@@ -37,37 +38,16 @@ def submit( # type: ignore[override]
3738
)
3839

3940

40-
try:
41-
from loky import get_reusable_executor
41+
class ParallelBackend(Enum):
42+
"""Choices for parallel backends."""
4243

43-
except ImportError:
44+
PROCESSES = "processes"
45+
THREADS = "threads"
46+
LOKY = "loky"
4447

45-
class ParallelBackend(enum.Enum):
46-
"""Choices for parallel backends."""
4748

48-
PROCESSES = "processes"
49-
THREADS = "threads"
50-
51-
PARALLEL_BACKENDS_DEFAULT = ParallelBackend.PROCESSES
52-
53-
PARALLEL_BACKENDS = {
54-
ParallelBackend.PROCESSES: CloudpickleProcessPoolExecutor,
55-
ParallelBackend.THREADS: ThreadPoolExecutor,
56-
}
57-
58-
else:
59-
60-
class ParallelBackend(enum.Enum): # type: ignore[no-redef]
61-
"""Choices for parallel backends."""
62-
63-
PROCESSES = "processes"
64-
THREADS = "threads"
65-
LOKY = "loky"
66-
67-
PARALLEL_BACKENDS_DEFAULT = ParallelBackend.LOKY # type: ignore[attr-defined]
68-
69-
PARALLEL_BACKENDS = {
70-
ParallelBackend.PROCESSES: CloudpickleProcessPoolExecutor,
71-
ParallelBackend.THREADS: ThreadPoolExecutor,
72-
ParallelBackend.LOKY: get_reusable_executor, # type: ignore[attr-defined]
73-
}
49+
PARALLEL_BACKEND_BUILDER = {
50+
ParallelBackend.PROCESSES: lambda: CloudpickleProcessPoolExecutor,
51+
ParallelBackend.THREADS: lambda: ThreadPoolExecutor,
52+
ParallelBackend.LOKY: lambda: get_reusable_executor,
53+
}

src/pytask_parallel/build.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from pytask import EnumChoice
77
from pytask import hookimpl
88

9-
from pytask_parallel.backends import PARALLEL_BACKENDS_DEFAULT
109
from pytask_parallel.backends import ParallelBackend
1110

1211

@@ -27,7 +26,7 @@ def pytask_extend_command_line_interface(cli: click.Group) -> None:
2726
["--parallel-backend"],
2827
type=EnumChoice(ParallelBackend),
2928
help="Backend for the parallelization.",
30-
default=PARALLEL_BACKENDS_DEFAULT,
29+
default=ParallelBackend.LOKY,
3130
),
3231
]
3332
cli.commands["build"].params.extend(additional_parameters)

src/pytask_parallel/execute.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from pytask.tree_util import tree_structure
3535
from rich.traceback import Traceback
3636

37-
from pytask_parallel.backends import PARALLEL_BACKENDS
37+
from pytask_parallel.backends import PARALLEL_BACKEND_BUILDER
3838
from pytask_parallel.backends import ParallelBackend
3939

4040
if TYPE_CHECKING:
@@ -54,6 +54,12 @@ def pytask_post_parse(config: dict[str, Any]) -> None:
5454
else:
5555
config["pm"].register(ProcessesNameSpace)
5656

57+
if PARALLEL_BACKEND_BUILDER[config["parallel_backend"]] is None:
58+
raise
59+
config["_parallel_executor"] = PARALLEL_BACKEND_BUILDER[
60+
config["parallel_backend"]
61+
]()
62+
5763

5864
@hookimpl(tryfirst=True)
5965
def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR0915
@@ -73,7 +79,9 @@ def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR091
7379
reports = session.execution_reports
7480
running_tasks: dict[str, Future[Any]] = {}
7581

76-
parallel_backend = PARALLEL_BACKENDS[session.config["parallel_backend"]]
82+
parallel_backend = PARALLEL_BACKEND_BUILDER[
83+
session.config["parallel_backend"]
84+
]()
7785

7886
with parallel_backend(max_workers=session.config["n_workers"]) as executor:
7987
session.config["_parallel_executor"] = executor

tests/test_execute.py

+26-27
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from pytask import ExitCode
88
from pytask import build
99
from pytask import cli
10-
from pytask_parallel.backends import PARALLEL_BACKENDS
1110
from pytask_parallel.backends import ParallelBackend
1211
from pytask_parallel.execute import _Sleeper
1312

@@ -19,18 +18,18 @@ class Session:
1918

2019

2120
@pytest.mark.end_to_end()
22-
@pytest.mark.parametrize("parallel_backend", PARALLEL_BACKENDS)
21+
@pytest.mark.parametrize("parallel_backend", ParallelBackend)
2322
def test_parallel_execution(tmp_path, parallel_backend):
2423
source = """
25-
import pytask
24+
from pytask import Product
25+
from pathlib import Path
26+
from typing_extensions import Annotated
2627
27-
@pytask.mark.produces("out_1.txt")
28-
def task_1(produces):
29-
produces.write_text("1")
28+
def task_1(path: Annotated[Path, Product] = Path("out_1.txt")):
29+
path.write_text("1")
3030
31-
@pytask.mark.produces("out_2.txt")
32-
def task_2(produces):
33-
produces.write_text("2")
31+
def task_2(path: Annotated[Path, Product] = Path("out_2.txt")):
32+
path.write_text("2")
3433
"""
3534
tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source))
3635
session = build(paths=tmp_path, n_workers=2, parallel_backend=parallel_backend)
@@ -41,18 +40,18 @@ def task_2(produces):
4140

4241

4342
@pytest.mark.end_to_end()
44-
@pytest.mark.parametrize("parallel_backend", PARALLEL_BACKENDS)
43+
@pytest.mark.parametrize("parallel_backend", ParallelBackend)
4544
def test_parallel_execution_w_cli(runner, tmp_path, parallel_backend):
4645
source = """
47-
import pytask
46+
from pytask import Product
47+
from pathlib import Path
48+
from typing_extensions import Annotated
4849
49-
@pytask.mark.produces("out_1.txt")
50-
def task_1(produces):
51-
produces.write_text("1")
50+
def task_1(path: Annotated[Path, Product] = Path("out_1.txt")):
51+
path.write_text("1")
5252
53-
@pytask.mark.produces("out_2.txt")
54-
def task_2(produces):
55-
produces.write_text("2")
53+
def task_2(path: Annotated[Path, Product] = Path("out_2.txt")):
54+
path.write_text("2")
5655
"""
5756
tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source))
5857
result = runner.invoke(
@@ -71,7 +70,7 @@ def task_2(produces):
7170

7271

7372
@pytest.mark.end_to_end()
74-
@pytest.mark.parametrize("parallel_backend", PARALLEL_BACKENDS)
73+
@pytest.mark.parametrize("parallel_backend", ParallelBackend)
7574
def test_stop_execution_when_max_failures_is_reached(tmp_path, parallel_backend):
7675
source = """
7776
import time
@@ -99,7 +98,7 @@ def task_3(): time.sleep(3)
9998

10099

101100
@pytest.mark.end_to_end()
102-
@pytest.mark.parametrize("parallel_backend", PARALLEL_BACKENDS)
101+
@pytest.mark.parametrize("parallel_backend", ParallelBackend)
103102
def test_task_priorities(tmp_path, parallel_backend):
104103
source = """
105104
import pytask
@@ -140,7 +139,7 @@ def task_5():
140139

141140

142141
@pytest.mark.end_to_end()
143-
@pytest.mark.parametrize("parallel_backend", PARALLEL_BACKENDS)
142+
@pytest.mark.parametrize("parallel_backend", ParallelBackend)
144143
@pytest.mark.parametrize("show_locals", [True, False])
145144
def test_rendering_of_tracebacks_with_rich(
146145
runner, tmp_path, parallel_backend, show_locals
@@ -173,12 +172,12 @@ def task_raising_error():
173172
)
174173
def test_collect_warnings_from_parallelized_tasks(runner, tmp_path, parallel_backend):
175174
source = """
176-
import pytask
175+
from pytask import task
177176
import warnings
178177
179178
for i in range(2):
180179
181-
@pytask.mark.task(id=str(i), kwargs={"produces": f"{i}.txt"})
180+
@task(id=str(i), kwargs={"produces": f"{i}.txt"})
182181
def task_example(produces):
183182
warnings.warn("This is a warning.")
184183
produces.touch()
@@ -222,7 +221,7 @@ def test_sleeper():
222221

223222

224223
@pytest.mark.end_to_end()
225-
@pytest.mark.parametrize("parallel_backend", PARALLEL_BACKENDS)
224+
@pytest.mark.parametrize("parallel_backend", ParallelBackend)
226225
def test_task_that_return(runner, tmp_path, parallel_backend):
227226
source = """
228227
from pathlib import Path
@@ -242,7 +241,7 @@ def task_example() -> Annotated[str, Path("file.txt")]:
242241

243242

244243
@pytest.mark.end_to_end()
245-
@pytest.mark.parametrize("parallel_backend", PARALLEL_BACKENDS)
244+
@pytest.mark.parametrize("parallel_backend", ParallelBackend)
246245
def test_task_without_path_that_return(runner, tmp_path, parallel_backend):
247246
source = """
248247
from pathlib import Path
@@ -264,7 +263,7 @@ def test_task_without_path_that_return(runner, tmp_path, parallel_backend):
264263

265264
@pytest.mark.end_to_end()
266265
@pytest.mark.parametrize("flag", ["--pdb", "--trace", "--dry-run"])
267-
@pytest.mark.parametrize("parallel_backend", PARALLEL_BACKENDS)
266+
@pytest.mark.parametrize("parallel_backend", ParallelBackend)
268267
def test_parallel_execution_is_deactivated(runner, tmp_path, flag, parallel_backend):
269268
tmp_path.joinpath("task_example.py").write_text("def task_example(): pass")
270269
result = runner.invoke(
@@ -278,7 +277,7 @@ def test_parallel_execution_is_deactivated(runner, tmp_path, flag, parallel_back
278277
@pytest.mark.end_to_end()
279278
@pytest.mark.parametrize("code", ["breakpoint()", "import pdb; pdb.set_trace()"])
280279
@pytest.mark.parametrize(
281-
"parallel_backend", [i for i in PARALLEL_BACKENDS if i != ParallelBackend.THREADS]
280+
"parallel_backend", [i for i in ParallelBackend if i != ParallelBackend.THREADS]
282281
)
283282
def test_raise_error_on_breakpoint(runner, tmp_path, code, parallel_backend):
284283
tmp_path.joinpath("task_example.py").write_text(f"def task_example(): {code}")
@@ -290,7 +289,7 @@ def test_raise_error_on_breakpoint(runner, tmp_path, code, parallel_backend):
290289

291290

292291
@pytest.mark.end_to_end()
293-
@pytest.mark.parametrize("parallel_backend", PARALLEL_BACKENDS)
292+
@pytest.mark.parametrize("parallel_backend", ParallelBackend)
294293
def test_task_partialed(runner, tmp_path, parallel_backend):
295294
source = """
296295
from pathlib import Path

0 commit comments

Comments
 (0)