Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add error message for not collected tasks with @task decorator. #521

Merged
merged 8 commits into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/source/changes.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and
(fixes #514). It also warns if the path is configured as a string and not a list of
strings.
- {pull}`519` raises an error when builtin functions are wrapped with
{func}`~pytask.task`. Closes {issue}`512`.pull
- {pull}`521` raises an error message when imported functions are wrapped with
{func}`@task <pytask.task>` in a task module. Fixes {issue}`513`.
{func}`~pytask.task`. Closes {issue}`512`.
- {pull}`522` improves the issue templates.
- {pull}`523` refactors `_pytask.console._get_file`.
Expand Down
45 changes: 45 additions & 0 deletions src/_pytask/collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from _pytask.console import create_summary_panel
from _pytask.console import get_file
from _pytask.exceptions import CollectionError
from _pytask.exceptions import NodeNotCollectedError
from _pytask.mark_utils import get_all_marks
from _pytask.mark_utils import has_mark
from _pytask.node_protocols import PNode
Expand All @@ -37,6 +38,7 @@
from _pytask.path import shorten_path
from _pytask.reports import CollectionReport
from _pytask.shared import find_duplicates
from _pytask.task_utils import COLLECTED_TASKS
from _pytask.task_utils import task as task_decorator
from _pytask.typing import is_task_function
from rich.text import Text
Expand All @@ -61,6 +63,7 @@ def pytask_collect(session: Session) -> bool:

_collect_from_paths(session)
_collect_from_tasks(session)
_collect_not_collected_tasks(session)

session.tasks.extend(
i.node
Expand Down Expand Up @@ -108,6 +111,9 @@ def _collect_from_tasks(session: Session) -> None:
path = get_file(raw_task)
name = raw_task.pytask_meta.name

if has_mark(raw_task, "task"):
COLLECTED_TASKS[path].remove(raw_task)

# When a task is not a callable, it can be anything or a PTask. Set arbitrary
# values and it will pass without errors and not collected.
else:
Expand All @@ -126,6 +132,45 @@ def _collect_from_tasks(session: Session) -> None:
session.collection_reports.append(report)


_FAILED_COLLECTING_TASK = """\
Failed to collect task '{name}'{path_desc}.

This can happen when the task function is defined in another module, imported to a \
task module and wrapped with the '@task' decorator.

To collect this task correctly, wrap the imported function in a lambda expression like

task(...)(lambda **x: imported_function(**x)).
"""


def _collect_not_collected_tasks(session: Session) -> None:
"""Collect tasks that are not collected yet and create failed reports."""
for path in list(COLLECTED_TASKS):
tasks = COLLECTED_TASKS.pop(path)
for task in tasks:
name = task.pytask_meta.name # type: ignore[attr-defined]
node: PTask
if path:
node = Task(base_name=name, path=path, function=task)
path_desc = f" in '{path}'"
else:
node = TaskWithoutPath(name=name, function=task)
path_desc = ""
report = CollectionReport(
outcome=CollectionOutcome.FAIL,
node=node,
exc_info=(
NodeNotCollectedError,
NodeNotCollectedError(
_FAILED_COLLECTING_TASK.format(name=name, path_desc=path_desc)
),
None,
),
)
session.collection_reports.append(report)


@hookimpl
def pytask_ignore_collect(path: Path, config: dict[str, Any]) -> bool:
"""Ignore a path during the collection."""
Expand Down
5 changes: 0 additions & 5 deletions src/_pytask/console.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,18 +225,13 @@ def get_file( # noqa: PLR0911
return get_file(function.__wrapped__)
source_file = inspect.getsourcefile(function)
if source_file:
# Handle functions defined in the REPL.
if "<stdin>" in source_file:
return None
# Handle lambda functions.
if "<string>" in source_file:
try:
return Path(function.__globals__["__file__"]).absolute().resolve()
except KeyError:
return None
# Handle functions defined in Jupyter notebooks.
if "ipykernel" in source_file or "ipython-input" in source_file:
return None
return Path(source_file).absolute().resolve()
return None

Expand Down
3 changes: 2 additions & 1 deletion src/_pytask/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def _raise_error_when_task_functions_are_duplicated(
msg = (
"There are some duplicates among the repeated tasks. It happens when you define"
"the task function outside the loop body and merely wrap in the loop body with "
f"the '@task(...)' decorator.\n\n{flat_tree}"
"the 'task(...)(func)' decorator. As a workaround, wrap the task function in "
f"a lambda expression like 'task(...)(lambda **x: func(**x))'.\n\n{flat_tree}"
)
raise ValueError(msg)
27 changes: 27 additions & 0 deletions tests/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,3 +679,30 @@ def test_raise_error_with_builtin_function_as_task(runner, tmp_path):
result = runner.invoke(cli, [tmp_path.as_posix()])
assert result.exit_code == ExitCode.COLLECTION_FAILED
assert "Builtin functions cannot be wrapped" in result.output


def test_task_function_in_another_module(runner, tmp_path):
source = """
def func():
return "Hello, World!"
"""
tmp_path.joinpath("module.py").write_text(textwrap.dedent(source))

source = """
from pytask import task
from pathlib import Path
from _pytask.path import import_path
import inspect

_ROOT_PATH = Path(__file__).parent

module = import_path(_ROOT_PATH / "module.py", _ROOT_PATH)
name_to_obj = dict(inspect.getmembers(module))

task(produces=Path("out.txt"))(name_to_obj["func"])
"""
tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source))

result = runner.invoke(cli, [tmp_path.as_posix()])
assert result.exit_code == ExitCode.COLLECTION_FAILED
assert "1 Failed" in result.output