Skip to content

Commit 6de2cae

Browse files
authored
Support partialed functions. (#80)
1 parent b9918b1 commit 6de2cae

File tree

6 files changed

+75
-40
lines changed

6 files changed

+75
-40
lines changed

.github/workflows/main.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ jobs:
4242

4343
- name: Run unit tests and doctests.
4444
shell: bash -l {0}
45-
run: tox -e test -- tests -m "unit or (not integration and not end_to_end)" --cov=./ --cov-report=xml
45+
run: tox -e test -- tests -m "unit or (not integration and not end_to_end)" --cov=src --cov=tests --cov-report=xml
4646

4747
- name: Upload coverage report for unit tests and doctests.
4848
if: runner.os == 'Linux' && matrix.python-version == '3.10'
@@ -51,7 +51,7 @@ jobs:
5151

5252
- name: Run end-to-end tests.
5353
shell: bash -l {0}
54-
run: tox -e test -- tests -m end_to_end --cov=./ --cov-report=xml
54+
run: tox -e test -- tests -m end_to_end --cov=src --cov=tests --cov-report=xml
5555

5656
- name: Upload coverage reports of end-to-end tests.
5757
if: runner.os == 'Linux' && matrix.python-version == '3.10'

.gitignore

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@ __pycache__
1616
build
1717
dist
1818
src/pytask_parallel/_version.py
19-
tests/test_jupyter/file.txt
19+
tests/test_jupyter/*.txt

CHANGES.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@ chronological order. Releases follow [semantic versioning](https://semver.org/)
55
releases are available on [PyPI](https://pypi.org/project/pytask-parallel) and
66
[Anaconda.org](https://anaconda.org/conda-forge/pytask-parallel).
77

8-
## 0.4.1 - 2023-12-xx
8+
## 0.4.1 - 2024-01-12
99

1010
- {pull}`72` moves the project to `pyproject.toml`.
1111
- {pull}`75` updates the release strategy.
1212
- {pull}`79` add tests for Jupyter and fix parallelization with `PythonNode`s.
13+
- {pull}`80` adds support for partialed functions.
1314
- {pull}`82` fixes testing with pytask v0.4.5.
1415

1516
## 0.4.0 - 2023-10-07

src/pytask_parallel/execute.py

+35-29
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import time
77
import warnings
88
from concurrent.futures import Future
9+
from functools import partial
910
from pathlib import Path
1011
from types import ModuleType
1112
from types import TracebackType
@@ -296,23 +297,7 @@ def _execute_task( # noqa: PLR0913
296297
exc_info, show_locals, console_options
297298
)
298299
else:
299-
if "return" in task.produces:
300-
structure_out = tree_structure(out)
301-
structure_return = tree_structure(task.produces["return"])
302-
# strict must be false when none is leaf.
303-
if not structure_return.is_prefix(structure_out, strict=False):
304-
msg = (
305-
"The structure of the return annotation is not a subtree of "
306-
"the structure of the function return.\n\nFunction return: "
307-
f"{structure_out}\n\nReturn annotation: {structure_return}"
308-
)
309-
raise ValueError(msg)
310-
311-
nodes = tree_leaves(task.produces["return"])
312-
values = structure_return.flatten_up_to(out)
313-
for node, value in zip(nodes, values):
314-
node.save(value)
315-
300+
_handle_task_function_return(task, out)
316301
processed_exc_info = None
317302

318303
task_display_name = getattr(task, "display_name", task.name)
@@ -347,6 +332,27 @@ def _process_exception(
347332
return (*exc_info[:2], text)
348333

349334

335+
def _handle_task_function_return(task: PTask, out: Any) -> None:
336+
if "return" not in task.produces:
337+
return
338+
339+
structure_out = tree_structure(out)
340+
structure_return = tree_structure(task.produces["return"])
341+
# strict must be false when none is leaf.
342+
if not structure_return.is_prefix(structure_out, strict=False):
343+
msg = (
344+
"The structure of the return annotation is not a subtree of "
345+
"the structure of the function return.\n\nFunction return: "
346+
f"{structure_out}\n\nReturn annotation: {structure_return}"
347+
)
348+
raise ValueError(msg)
349+
350+
nodes = tree_leaves(task.produces["return"])
351+
values = structure_return.flatten_up_to(out)
352+
for node, value in zip(nodes, values):
353+
node.save(value)
354+
355+
350356
class DefaultBackendNameSpace:
351357
"""The name space for hooks related to threads."""
352358

@@ -362,13 +368,13 @@ def pytask_execute_task(session: Session, task: Task) -> Future[Any] | None:
362368
if session.config["n_workers"] > 1:
363369
kwargs = _create_kwargs_for_task(task)
364370
return session.config["_parallel_executor"].submit(
365-
_mock_processes_for_threads, func=task.execute, **kwargs
371+
_mock_processes_for_threads, task=task, **kwargs
366372
)
367373
return None
368374

369375

370376
def _mock_processes_for_threads(
371-
func: Callable[..., Any], **kwargs: Any
377+
task: PTask, **kwargs: Any
372378
) -> tuple[
373379
None, list[Any], tuple[type[BaseException], BaseException, TracebackType] | None
374380
]:
@@ -381,10 +387,11 @@ def _mock_processes_for_threads(
381387
"""
382388
__tracebackhide__ = True
383389
try:
384-
func(**kwargs)
390+
out = task.function(**kwargs)
385391
except Exception: # noqa: BLE001
386392
exc_info = sys.exc_info()
387393
else:
394+
_handle_task_function_return(task, out)
388395
exc_info = None
389396
return None, [], exc_info
390397

@@ -430,18 +437,17 @@ def sleep(self) -> None:
430437
def _get_module(func: Callable[..., Any], path: Path | None) -> ModuleType:
431438
"""Get the module of a python function.
432439
433-
For Python <3.10, functools.partial does not set a `__module__` attribute which is
434-
why ``inspect.getmodule`` returns ``None`` and ``cloudpickle.pickle_by_value``
435-
fails. In later versions, ``functools`` is returned and everything seems to work
436-
fine.
440+
``functools.partial`` obfuscates the module of the function and
441+
``inspect.getmodule`` returns :mod`functools`. Therefore, we recover the original
442+
function.
437443
438-
Therefore, we use the path from the task module to aid the search which works for
439-
Python <3.10.
440-
441-
We do not unwrap the partialed function with ``func.func``, since pytask in general
442-
does not really support ``functools.partial``. Instead, use ``@task(kwargs=...)``.
444+
We use the path from the task module to aid the search although it is not clear
445+
whether it helps.
443446
444447
"""
448+
if isinstance(func, partial):
449+
func = func.func
450+
445451
if path:
446452
return inspect.getmodule(func, path.as_posix())
447453
return inspect.getmodule(func)

tests/test_execute.py

+34-6
Original file line numberDiff line numberDiff line change
@@ -233,10 +233,12 @@ def task_example() -> Annotated[str, Path("file.txt")]:
233233
"""
234234
tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source))
235235
result = runner.invoke(
236-
cli, [tmp_path.as_posix(), "--parallel-backend", parallel_backend]
236+
cli, [tmp_path.as_posix(), "-n", "2", "--parallel-backend", parallel_backend]
237237
)
238238
assert result.exit_code == ExitCode.OK
239-
assert tmp_path.joinpath("file.txt").exists()
239+
assert (
240+
tmp_path.joinpath("file.txt").read_text() == "Hello, Darkness, my old friend."
241+
)
240242

241243

242244
@pytest.mark.end_to_end()
@@ -252,10 +254,12 @@ def test_task_without_path_that_return(runner, tmp_path, parallel_backend):
252254
"""
253255
tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source))
254256
result = runner.invoke(
255-
cli, [tmp_path.as_posix(), "--parallel-backend", parallel_backend]
257+
cli, [tmp_path.as_posix(), "-n", "2", "--parallel-backend", parallel_backend]
256258
)
257259
assert result.exit_code == ExitCode.OK
258-
assert tmp_path.joinpath("file.txt").exists()
260+
assert (
261+
tmp_path.joinpath("file.txt").read_text() == "Hello, Darkness, my old friend."
262+
)
259263

260264

261265
@pytest.mark.end_to_end()
@@ -264,7 +268,8 @@ def test_task_without_path_that_return(runner, tmp_path, parallel_backend):
264268
def test_parallel_execution_is_deactivated(runner, tmp_path, flag, parallel_backend):
265269
tmp_path.joinpath("task_example.py").write_text("def task_example(): pass")
266270
result = runner.invoke(
267-
cli, [tmp_path.as_posix(), "-n 2", "--parallel-backend", parallel_backend, flag]
271+
cli,
272+
[tmp_path.as_posix(), "-n", "2", "--parallel-backend", parallel_backend, flag],
268273
)
269274
assert result.exit_code == ExitCode.OK
270275
assert "Started 2 workers" not in result.output
@@ -278,7 +283,30 @@ def test_parallel_execution_is_deactivated(runner, tmp_path, flag, parallel_back
278283
def test_raise_error_on_breakpoint(runner, tmp_path, code, parallel_backend):
279284
tmp_path.joinpath("task_example.py").write_text(f"def task_example(): {code}")
280285
result = runner.invoke(
281-
cli, [tmp_path.as_posix(), "-n 2", "--parallel-backend", parallel_backend]
286+
cli, [tmp_path.as_posix(), "-n", "2", "--parallel-backend", parallel_backend]
282287
)
283288
assert result.exit_code == ExitCode.FAILED
284289
assert "You cannot use 'breakpoint()'" in result.output
290+
291+
292+
@pytest.mark.end_to_end()
293+
@pytest.mark.parametrize("parallel_backend", PARALLEL_BACKENDS)
294+
def test_task_partialed(runner, tmp_path, parallel_backend):
295+
source = """
296+
from pathlib import Path
297+
from pytask import task
298+
from functools import partial
299+
300+
def create_text(text):
301+
return text
302+
303+
task_example = task(
304+
produces=Path("file.txt")
305+
)(partial(create_text, text="Hello, Darkness, my old friend."))
306+
"""
307+
tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source))
308+
result = runner.invoke(
309+
cli, [tmp_path.as_posix(), "-n", "2", "--parallel-backend", parallel_backend]
310+
)
311+
assert result.exit_code == ExitCode.OK
312+
assert tmp_path.joinpath("file.txt").exists()

tox.ini

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ requires = tox>=4
33
envlist = test
44

55
[testenv]
6-
package = wheel
6+
package = editable
77

88
[testenv:test]
99
extras = test

0 commit comments

Comments
 (0)