Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Collect all tasks in COLLECTED_TASKS. #529

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/changes.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and
- {pull}`524` improves some linting and formatting rules.
- {pull}`525` enables pytask to work with remote files using universal_pathlib.
- {pull}`528` improves the codecov setup and coverage.
- {pull}`529` adds a collection for all tasks cached in
{obj}`~_pytask.task_utils.COLLECTED_TASKS` regardless of location.

## 0.4.4 - 2023-12-04

Expand Down
59 changes: 27 additions & 32 deletions src/_pytask/collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from _pytask.console import create_summary_panel
from _pytask.console import get_file
from _pytask.exceptions import CollectionError
from _pytask.exceptions import NodeNotCollectedError
from _pytask.mark_utils import get_all_marks
from _pytask.mark_utils import has_mark
from _pytask.node_protocols import PNode
Expand All @@ -39,6 +38,8 @@
from _pytask.reports import CollectionReport
from _pytask.shared import find_duplicates
from _pytask.task_utils import COLLECTED_TASKS
from _pytask.task_utils import parse_collected_tasks_with_task_marker
from _pytask.task_utils import raise_error_when_task_functions_are_duplicated
from _pytask.task_utils import task as task_decorator
from _pytask.typing import is_task_function
from rich.text import Text
Expand Down Expand Up @@ -132,43 +133,37 @@ def _collect_from_tasks(session: Session) -> None:
session.collection_reports.append(report)


_FAILED_COLLECTING_TASK = """\
Failed to collect task '{name}'{path_desc}.
def _collect_not_collected_tasks(session: Session) -> None:
"""Collect tasks that are not collected yet.

This can happen when the task function is defined in another module, imported to a \
task module and wrapped with the '@task' decorator.
If task functions are imported from another module and then wrapped with ``@task``,
they would usually not be collected since their module is the imported module and
not the task module. This function collects these tasks and all other cached in
``COLLECTED_TASKS``.

To collect this task correctly, wrap the imported function in a lambda expression like
"""
for path in list(COLLECTED_TASKS):
tasks = COLLECTED_TASKS.pop(path)

task(...)(lambda **x: imported_function(**x)).
"""
# Remove tasks from the global to avoid re-collection if programmatic interface
# is used.
raise_error_when_task_functions_are_duplicated(tasks)

name_to_function = parse_collected_tasks_with_task_marker(tasks)

def _collect_not_collected_tasks(session: Session) -> None:
"""Collect tasks that are not collected yet and create failed reports."""
for path in list(COLLECTED_TASKS):
tasks = COLLECTED_TASKS.pop(path)
for task in tasks:
name = task.pytask_meta.name # type: ignore[attr-defined]
node: PTask
if path:
node = Task(base_name=name, path=path, function=task)
path_desc = f" in '{path}'"
else:
node = TaskWithoutPath(name=name, function=task)
path_desc = ""
report = CollectionReport(
outcome=CollectionOutcome.FAIL,
node=node,
exc_info=(
NodeNotCollectedError,
NodeNotCollectedError(
_FAILED_COLLECTING_TASK.format(name=name, path_desc=path_desc)
),
None,
),
collected_reports = []
for name, function in name_to_function.items():
report = session.hook.pytask_collect_task_protocol(
session=session,
reports=session.collection_reports,
path=path,
name=name,
obj=function,
)
session.collection_reports.append(report)
if report is not None:
collected_reports.append(report)

session.collection_reports.extend(collected_reports)


@hookimpl
Expand Down
3 changes: 2 additions & 1 deletion src/_pytask/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ class Task(PTaskWithPath):
base_name
The base name of the task.
path
Path to the file where the task was defined.
Path to the file where the task was defined. It is used to collect the path and
for displaying information.
function
The task function.
name
Expand Down
32 changes: 2 additions & 30 deletions src/_pytask/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@
from __future__ import annotations

from typing import Any
from typing import Callable
from typing import TYPE_CHECKING

from _pytask.config import hookimpl
from _pytask.console import format_strings_as_flat_tree
from _pytask.shared import find_duplicates
from _pytask.task_utils import COLLECTED_TASKS
from _pytask.task_utils import parse_collected_tasks_with_task_marker
from _pytask.task_utils import raise_error_when_task_functions_are_duplicated

if TYPE_CHECKING:
from _pytask.reports import CollectionReport
Expand Down Expand Up @@ -40,7 +38,7 @@ def pytask_collect_file(
# is used.
tasks = COLLECTED_TASKS.pop(path)

_raise_error_when_task_functions_are_duplicated(tasks)
raise_error_when_task_functions_are_duplicated(tasks)

name_to_function = parse_collected_tasks_with_task_marker(tasks)

Expand All @@ -54,29 +52,3 @@ def pytask_collect_file(

return collected_reports
return None


def _raise_error_when_task_functions_are_duplicated(
tasks: list[Callable[..., Any]],
) -> None:
"""Raise error when task functions are duplicated.

When task functions are created outside the loop body, every wrapped version of the

"""
duplicates = find_duplicates(tasks)
if not duplicates:
return

strings = [
f"function_name={func.pytask_meta.name}, id={func.pytask_meta.id_}"
for func in duplicates
]
flat_tree = format_strings_as_flat_tree(strings, "Duplicated tasks")
msg = (
"There are some duplicates among the repeated tasks. It happens when you define"
"the task function outside the loop body and merely wrap in the loop body with "
"the 'task(...)(func)' decorator. As a workaround, wrap the task function in "
f"a lambda expression like 'task(...)(lambda **x: func(**x))'.\n\n{flat_tree}"
)
raise ValueError(msg)
80 changes: 69 additions & 11 deletions src/_pytask/task_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,46 +3,46 @@

import inspect
from collections import defaultdict
from pathlib import Path
from types import BuiltinFunctionType
from typing import Any
from typing import Callable
from typing import TYPE_CHECKING

import attrs
from _pytask.console import format_strings_as_flat_tree
from _pytask.console import get_file
from _pytask.mark import Mark
from _pytask.models import CollectionMetadata
from _pytask.shared import find_duplicates
from _pytask.typing import is_task_function

if TYPE_CHECKING:
from pathlib import Path


__all__ = [
"COLLECTED_TASKS",
"parse_collected_tasks_with_task_marker",
"parse_keyword_arguments_from_signature_defaults",
"raise_error_when_task_functions_are_duplicated",
"task",
]


COLLECTED_TASKS: dict[Path | None, list[Callable[..., Any]]] = defaultdict(list)
"""A container for collecting tasks.

Tasks marked by the ``@pytask.mark.task`` decorator can be generated in a loop where one
iteration overwrites the previous task. To retrieve the tasks later, use this dictionary
mapping from paths of modules to a list of tasks per module.
Tasks marked by the ``@task`` decorator can be generated in a loop where one iteration
overwrites the previous task. To retrieve the tasks later, use this dictionary mapping
from paths of modules to a list of tasks per module.

"""


def task(
def task( # noqa: PLR0913
name: str | None = None,
*,
after: str | Callable[..., Any] | list[Callable[..., Any]] | None = None,
id: str | None = None, # noqa: A002
kwargs: dict[Any, Any] | None = None,
module: Path | str | None = None,
produces: Any | None = None,
) -> Callable[..., Callable[..., Any]]:
"""Decorate a task function.
Expand All @@ -69,6 +69,27 @@ def task(
Use a dictionary to pass any keyword arguments to the task function which can be
dependencies or products of the task. Read :ref:`task-kwargs` for more
information.
module
An experimental and cosmetic feature.

By default, the module is the location where the task function is defined. When
a task function is imported in a task module and wrapped with
:func:`@task <pytask.task>`, this argument allows to set the path to the task
module instead of the imported module.

Relative paths are resolved relative to the current working directory.

.. code-block:: python

from pytask import task
from module import function

# Location will be 'module.py'.
@task()(function)

# Location will be this module, e.g., 'task_module.py'.
@task(module=__file__)(function)

produces
Use this argument if you want to parse the return of the task function as a
product, but you cannot annotate the return of the function. See :doc:`this
Expand All @@ -81,9 +102,11 @@ def task(

.. code-block:: python

from typing import Annotated from pytask import task
from typing import Annotated
from pytask import task

@task def create_text_file() -> Annotated[str, Path("file.txt")]:
@task
def create_text_file() -> Annotated[str, Path("file.txt")]:
return "Hello, World!"

"""
Expand Down Expand Up @@ -112,7 +135,16 @@ def wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
)
raise NotImplementedError(msg)

path = get_file(unwrapped)
if not module:
path = get_file(unwrapped)
else:
path = Path(module).resolve()
if not path.exists():
msg = (
f"Module '{path}' does not exist, but is set as "
"'@task(module=...)'."
)
raise ValueError(msg)

parsed_kwargs = {} if kwargs is None else kwargs
parsed_name = name if isinstance(name, str) else func.__name__
Expand Down Expand Up @@ -332,3 +364,29 @@ def _arg_value_to_id_component(
else:
id_component = arg_name + str(i)
return id_component


def raise_error_when_task_functions_are_duplicated(
tasks: list[Callable[..., Any]],
) -> None:
"""Raise error when task functions are duplicated.

When task functions are created outside the loop body, every wrapped version of the

"""
duplicates = find_duplicates(tasks)
if not duplicates:
return

strings = [
f"function_name={func.pytask_meta.name}, id={func.pytask_meta.id_}"
for func in duplicates
]
flat_tree = format_strings_as_flat_tree(strings, "Duplicated tasks")
msg = (
"There are some duplicates among the repeated tasks. It happens when you define"
"the task function outside the loop body and merely wrap in the loop body with "
"the 'task(...)(func)' decorator. As a workaround, wrap the task function in "
f"a lambda expression like 'task(...)(lambda **x: func(**x))'.\n\n{flat_tree}"
)
raise ValueError(msg)
24 changes: 19 additions & 5 deletions tests/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,14 +697,15 @@ def test_raise_error_with_builtin_function_as_task(runner, tmp_path):
assert "Builtin functions cannot be wrapped" in result.output


def test_task_function_in_another_module(runner, tmp_path):
@pytest.mark.parametrize("module", [None, "__file__", "'a'"])
def test_task_function_in_another_module(runner, tmp_path, module):
source = """
def func():
return "Hello, World!"
"""
tmp_path.joinpath("module.py").write_text(textwrap.dedent(source))

source = """
source = f"""
from pytask import task
from pathlib import Path
from _pytask.path import import_path
Expand All @@ -715,10 +716,23 @@ def func():
module = import_path(_ROOT_PATH / "module.py", _ROOT_PATH)
name_to_obj = dict(inspect.getmembers(module))

task(produces=Path("out.txt"))(name_to_obj["func"])
task(produces=Path("out.txt"), module={module})(name_to_obj["func"])
"""
tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source))

result = runner.invoke(cli, [tmp_path.as_posix()])
assert result.exit_code == ExitCode.COLLECTION_FAILED
assert "1 Failed" in result.output

if module == "'a'":
assert result.exit_code == ExitCode.COLLECTION_FAILED
assert "ValueError: Module" in result.output

else:
assert result.exit_code == ExitCode.OK
assert "1 Succeeded" in result.output
assert tmp_path.joinpath("out.txt").read_text() == "Hello, World!"

# Check whether the module is overwritten or not.
if module:
assert "task_example.py" in result.output
else:
assert "module.py" in result.output