Skip to content

Commit 4f9424f

Browse files
committed
Stop unwrapping coiled functions.
1 parent 3e2d415 commit 4f9424f

File tree

5 files changed

+64
-5
lines changed

5 files changed

+64
-5
lines changed

docs/source/changes.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and
3939
- {pull}`593` recreate `PythonNode`s every run since they carry the `_NoDefault` enum as
4040
the value whose state is `None`.
4141
- {pull}`594` publishes `NodeLoadError`.
42+
- {pull}`595` stops unwrapping task functions until a `coiled.function.Function`.
4243

4344
## 0.4.7 - 2024-03-19
4445

src/_pytask/collect.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from _pytask.reports import CollectionReport
4747
from _pytask.shared import find_duplicates
4848
from _pytask.shared import to_list
49+
from _pytask.shared import unwrap_task_function
4950
from _pytask.task_utils import COLLECTED_TASKS
5051
from _pytask.task_utils import task as task_decorator
5152
from _pytask.typing import is_task_function
@@ -317,9 +318,7 @@ def pytask_collect_task(
317318
obj.pytask_meta.is_generator if hasattr(obj, "pytask_meta") else False
318319
)
319320

320-
# Get the underlying function to avoid having different states of the function,
321-
# e.g. due to pytask_meta, in different layers of the wrapping.
322-
unwrapped = inspect.unwrap(obj)
321+
unwrapped = unwrap_task_function(obj)
323322

324323
if path is None:
325324
return TaskWithoutPath(

src/_pytask/shared.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,16 @@
33
from __future__ import annotations
44

55
import glob
6+
import inspect
67
from pathlib import Path
78
from typing import TYPE_CHECKING
89
from typing import Any
10+
from typing import Callable
911
from typing import Iterable
1012
from typing import Sequence
1113

1214
import click
15+
from attrs import define
1316

1417
from _pytask.console import format_node_name
1518
from _pytask.console import format_task_name
@@ -23,13 +26,26 @@
2326
import networkx as nx
2427

2528

29+
try:
30+
from coiled.function import Function as CoiledFunction
31+
except ImportError:
32+
33+
@define
34+
class CoiledFunction: # type: ignore[no-redef]
35+
cluster_kwargs: dict[str, Any]
36+
environ: dict[str, Any]
37+
function: Callable[..., Any] | None
38+
keepalive: str | None
39+
40+
2641
__all__ = [
2742
"convert_to_enum",
2843
"find_duplicates",
2944
"parse_markers",
3045
"parse_paths",
3146
"reduce_names_of_multiple_nodes",
3247
"to_list",
48+
"unwrap_task_function",
3349
]
3450

3551

@@ -146,3 +162,13 @@ def convert_to_enum(value: Any, enum: type[Enum]) -> Enum:
146162
values = [e.value for e in enum]
147163
msg = f"Value {value!r} is not a valid {enum!r}. Valid values are {values}."
148164
raise ValueError(msg) from None
165+
166+
167+
def unwrap_task_function(obj: Any) -> Callable[..., Any]:
168+
"""Unwrap a task function.
169+
170+
Get the underlying function to avoid having different states of the function, e.g.
171+
due to pytask_meta, in different layers of the wrapping.
172+
173+
"""
174+
return inspect.unwrap(obj, stop=lambda x: isinstance(x, CoiledFunction))

src/_pytask/task_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from _pytask.mark import Mark
1717
from _pytask.models import CollectionMetadata
1818
from _pytask.shared import find_duplicates
19+
from _pytask.shared import unwrap_task_function
1920
from _pytask.typing import is_task_function
2021

2122
if TYPE_CHECKING:
@@ -117,7 +118,7 @@ def wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
117118
)
118119
raise ValueError(msg)
119120

120-
unwrapped = inspect.unwrap(func)
121+
unwrapped = unwrap_task_function(func)
121122

122123
# We do not allow builtins as functions because we would need to use
123124
# ``inspect.stack`` to infer their caller location and they are unable to carry
@@ -145,7 +146,7 @@ def wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
145146
unwrapped.pytask_meta.produces = produces
146147
unwrapped.pytask_meta.after = parsed_after
147148
else:
148-
unwrapped.pytask_meta = CollectionMetadata(
149+
unwrapped.pytask_meta = CollectionMetadata( # type: ignore[attr-defined]
149150
after=parsed_after,
150151
is_generator=is_generator,
151152
id_=id,

tests/test_shared.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from __future__ import annotations
22

3+
import functools
34
import textwrap
45
from contextlib import ExitStack as does_not_raise # noqa: N813
56

67
import pytest
78
from _pytask.shared import convert_to_enum
89
from _pytask.shared import find_duplicates
10+
from _pytask.shared import unwrap_task_function
911
from pytask import ExitCode
1012
from pytask import ShowCapture
1113
from pytask import build
@@ -49,3 +51,33 @@ def test_convert_to_enum(value, enum, expectation, expected):
4951
with expectation:
5052
result = convert_to_enum(value, enum)
5153
assert result == expected
54+
55+
56+
@pytest.mark.unit()
57+
def test_unwrap_task_function():
58+
def task():
59+
pass
60+
61+
# partialed functions are only unwrapped after wraps.
62+
partialed = functools.wraps(task)(functools.partial(task))
63+
assert unwrap_task_function(partialed) is task
64+
65+
partialed = functools.partial(task)
66+
assert unwrap_task_function(partialed) is partialed
67+
68+
def decorator(func):
69+
@functools.wraps(func)
70+
def wrapper():
71+
return func()
72+
73+
return wrapper
74+
75+
decorated = decorator(task)
76+
assert unwrap_task_function(decorated) is task
77+
78+
from _pytask.shared import CoiledFunction
79+
80+
coiled_function = functools.wraps(task)(
81+
CoiledFunction(function=task, cluster_kwargs={}, environ={}, keepalive=None)
82+
)
83+
assert unwrap_task_function(coiled_function) is coiled_function

0 commit comments

Comments
 (0)