Skip to content

Commit f0d5994

Browse files
authored
Add error message for not collected tasks with @task decorator. (#521)
1 parent 88a5fb1 commit f0d5994

File tree

5 files changed

+77
-6
lines changed

5 files changed

+77
-6
lines changed

docs/source/changes.md

+3
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and
1212
(fixes #514). It also warns if the path is configured as a string and not a list of
1313
strings.
1414
- {pull}`519` raises an error when builtin functions are wrapped with
15+
{func}`~pytask.task`. Closes {issue}`512`.pull
16+
- {pull}`521` raises an error message when imported functions are wrapped with
17+
{func}`@task <pytask.task>` in a task module. Fixes {issue}`513`.
1518
{func}`~pytask.task`. Closes {issue}`512`.
1619
- {pull}`522` improves the issue templates.
1720
- {pull}`523` refactors `_pytask.console._get_file`.

src/_pytask/collect.py

+45
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from _pytask.console import create_summary_panel
2222
from _pytask.console import get_file
2323
from _pytask.exceptions import CollectionError
24+
from _pytask.exceptions import NodeNotCollectedError
2425
from _pytask.mark_utils import get_all_marks
2526
from _pytask.mark_utils import has_mark
2627
from _pytask.node_protocols import PNode
@@ -37,6 +38,7 @@
3738
from _pytask.path import shorten_path
3839
from _pytask.reports import CollectionReport
3940
from _pytask.shared import find_duplicates
41+
from _pytask.task_utils import COLLECTED_TASKS
4042
from _pytask.task_utils import task as task_decorator
4143
from _pytask.typing import is_task_function
4244
from rich.text import Text
@@ -61,6 +63,7 @@ def pytask_collect(session: Session) -> bool:
6163

6264
_collect_from_paths(session)
6365
_collect_from_tasks(session)
66+
_collect_not_collected_tasks(session)
6467

6568
session.tasks.extend(
6669
i.node
@@ -108,6 +111,9 @@ def _collect_from_tasks(session: Session) -> None:
108111
path = get_file(raw_task)
109112
name = raw_task.pytask_meta.name
110113

114+
if has_mark(raw_task, "task"):
115+
COLLECTED_TASKS[path].remove(raw_task)
116+
111117
# When a task is not a callable, it can be anything or a PTask. Set arbitrary
112118
# values and it will pass without errors and not collected.
113119
else:
@@ -126,6 +132,45 @@ def _collect_from_tasks(session: Session) -> None:
126132
session.collection_reports.append(report)
127133

128134

135+
_FAILED_COLLECTING_TASK = """\
136+
Failed to collect task '{name}'{path_desc}.
137+
138+
This can happen when the task function is defined in another module, imported to a \
139+
task module and wrapped with the '@task' decorator.
140+
141+
To collect this task correctly, wrap the imported function in a lambda expression like
142+
143+
task(...)(lambda **x: imported_function(**x)).
144+
"""
145+
146+
147+
def _collect_not_collected_tasks(session: Session) -> None:
148+
"""Collect tasks that are not collected yet and create failed reports."""
149+
for path in list(COLLECTED_TASKS):
150+
tasks = COLLECTED_TASKS.pop(path)
151+
for task in tasks:
152+
name = task.pytask_meta.name # type: ignore[attr-defined]
153+
node: PTask
154+
if path:
155+
node = Task(base_name=name, path=path, function=task)
156+
path_desc = f" in '{path}'"
157+
else:
158+
node = TaskWithoutPath(name=name, function=task)
159+
path_desc = ""
160+
report = CollectionReport(
161+
outcome=CollectionOutcome.FAIL,
162+
node=node,
163+
exc_info=(
164+
NodeNotCollectedError,
165+
NodeNotCollectedError(
166+
_FAILED_COLLECTING_TASK.format(name=name, path_desc=path_desc)
167+
),
168+
None,
169+
),
170+
)
171+
session.collection_reports.append(report)
172+
173+
129174
@hookimpl
130175
def pytask_ignore_collect(path: Path, config: dict[str, Any]) -> bool:
131176
"""Ignore a path during the collection."""

src/_pytask/console.py

-5
Original file line numberDiff line numberDiff line change
@@ -225,18 +225,13 @@ def get_file( # noqa: PLR0911
225225
return get_file(function.__wrapped__)
226226
source_file = inspect.getsourcefile(function)
227227
if source_file:
228-
# Handle functions defined in the REPL.
229228
if "<stdin>" in source_file:
230229
return None
231-
# Handle lambda functions.
232230
if "<string>" in source_file:
233231
try:
234232
return Path(function.__globals__["__file__"]).absolute().resolve()
235233
except KeyError:
236234
return None
237-
# Handle functions defined in Jupyter notebooks.
238-
if "ipykernel" in source_file or "ipython-input" in source_file:
239-
return None
240235
return Path(source_file).absolute().resolve()
241236
return None
242237

src/_pytask/task.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def _raise_error_when_task_functions_are_duplicated(
7676
msg = (
7777
"There are some duplicates among the repeated tasks. It happens when you define"
7878
"the task function outside the loop body and merely wrap in the loop body with "
79-
f"the '@task(...)' decorator.\n\n{flat_tree}"
79+
"the 'task(...)(func)' decorator. As a workaround, wrap the task function in "
80+
f"a lambda expression like 'task(...)(lambda **x: func(**x))'.\n\n{flat_tree}"
8081
)
8182
raise ValueError(msg)

tests/test_task.py

+27
Original file line numberDiff line numberDiff line change
@@ -679,3 +679,30 @@ def test_raise_error_with_builtin_function_as_task(runner, tmp_path):
679679
result = runner.invoke(cli, [tmp_path.as_posix()])
680680
assert result.exit_code == ExitCode.COLLECTION_FAILED
681681
assert "Builtin functions cannot be wrapped" in result.output
682+
683+
684+
def test_task_function_in_another_module(runner, tmp_path):
685+
source = """
686+
def func():
687+
return "Hello, World!"
688+
"""
689+
tmp_path.joinpath("module.py").write_text(textwrap.dedent(source))
690+
691+
source = """
692+
from pytask import task
693+
from pathlib import Path
694+
from _pytask.path import import_path
695+
import inspect
696+
697+
_ROOT_PATH = Path(__file__).parent
698+
699+
module = import_path(_ROOT_PATH / "module.py", _ROOT_PATH)
700+
name_to_obj = dict(inspect.getmembers(module))
701+
702+
task(produces=Path("out.txt"))(name_to_obj["func"])
703+
"""
704+
tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source))
705+
706+
result = runner.invoke(cli, [tmp_path.as_posix()])
707+
assert result.exit_code == ExitCode.COLLECTION_FAILED
708+
assert "1 Failed" in result.output

0 commit comments

Comments
 (0)