Skip to content

Commit 9cd23e4

Browse files
authored
Fix passing repeated tasks to the functional interface. (#719)
1 parent bdc9120 commit 9cd23e4

File tree

5 files changed

+424
-25
lines changed

5 files changed

+424
-25
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and
1515
- {pull}`709` add uv pre-commit check.
1616
- {pull}`713` removes uv as a test dependency. Closes {issue}`712`. Thanks to {user}`erooke`!
1717
- {pull}`718` fixes {issue}`717` by properly parsing the `pdbcls` configuration option from config files. Thanks to {user}`MImmesberger` for the report!
18+
- {pull}`719` fixes repeated tasks with the same function name in the programmatic interface to ensure all tasks execute correctly.
1819

1920
## 0.5.5 - 2025-07-25
2021

docs/source/how_to_guides/functional_interface.ipynb

Lines changed: 252 additions & 17 deletions
Large diffs are not rendered by default.

src/_pytask/collect.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
from _pytask.shared import to_list
5151
from _pytask.shared import unwrap_task_function
5252
from _pytask.task_utils import COLLECTED_TASKS
53+
from _pytask.task_utils import parse_collected_tasks_with_task_marker
5354
from _pytask.task_utils import task as task_decorator
5455
from _pytask.typing import is_task_function
5556

@@ -108,6 +109,10 @@ def _collect_from_paths(session: Session) -> None:
108109

109110
def _collect_from_tasks(session: Session) -> None:
110111
"""Collect tasks from user provided tasks via the functional interface."""
112+
# First pass: collect and group tasks by path
113+
tasks_by_path: dict[Path | None, list[Any]] = {}
114+
non_task_objects = []
115+
111116
for raw_task in to_list(session.config.get("tasks", ())):
112117
if is_task_function(raw_task):
113118
if not hasattr(raw_task, "pytask_meta"):
@@ -117,18 +122,31 @@ def _collect_from_tasks(session: Session) -> None:
117122
name = raw_task.pytask_meta.name
118123

119124
if has_mark(raw_task, "task"):
120-
# When tasks with @task are passed to the programmatic interface multiple
121-
# times, they are deleted from ``COLLECTED_TASKS`` in the first iteration
122-
# and are missing in the later. See #625.
125+
# When tasks with @task are passed to the programmatic interface
126+
# multiple times, they are deleted from ``COLLECTED_TASKS`` in the first
127+
# iteration and are missing in the later. See #625.
123128
with suppress(ValueError):
124129
COLLECTED_TASKS[path].remove(raw_task)
125130

126-
# When a task is not a callable, it can be anything or a PTask. Set arbitrary
127-
# values and it will pass without errors and not collected.
131+
# Group tasks by path for parametrization
132+
if path not in tasks_by_path:
133+
tasks_by_path[path] = []
134+
tasks_by_path[path].append(raw_task)
128135
else:
129-
name = ""
130-
path = None
131-
136+
# When a task is not a callable, it can be anything or a PTask. Set
137+
# arbitrary values and it will pass without errors and not collected.
138+
non_task_objects.append((raw_task, None, ""))
139+
140+
# Second pass: apply parametrization to grouped tasks
141+
parametrized_tasks = []
142+
for path, tasks in tasks_by_path.items():
143+
# Apply the same parametrization logic as file-based collection
144+
name_to_function = parse_collected_tasks_with_task_marker(tasks)
145+
for name, function in name_to_function.items():
146+
parametrized_tasks.append((function, path, name))
147+
148+
# Third pass: collect all tasks
149+
for raw_task, path, name in parametrized_tasks + non_task_objects:
132150
report = session.hook.pytask_collect_task_protocol(
133151
session=session,
134152
reports=session.collection_reports,

tests/test_execute.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -660,6 +660,64 @@ def test_pass_non_task_to_functional_api_that_are_ignored():
660660
assert len(session.tasks) == 0
661661

662662

663+
@pytest.mark.skipif(
664+
sys.platform == "win32" and os.environ.get("CI") == "true",
665+
reason="Windows does not pick up the right Python interpreter.",
666+
)
667+
def test_repeated_tasks_via_functional_interface(tmp_path):
668+
"""Test that repeated tasks with the same function name work correctly.
669+
670+
This test ensures that when multiple tasks with the same function name are passed
671+
to pytask.build(), they all get unique IDs and execute correctly, similar to how
672+
file-based collection handles repeated tasks.
673+
"""
674+
source = """
675+
from pathlib import Path
676+
from typing import Annotated
677+
from pytask import Product, task, build, ExitCode
678+
import sys
679+
680+
# Create repeated tasks with the same function name
681+
tasks = []
682+
for i in range(3):
683+
def create_data(
684+
value: int = i * 10,
685+
produces: Annotated[Path, Product] = Path(f"output_{i}.txt")
686+
) -> None:
687+
'''Generate data based on a value.'''
688+
produces.write_text(str(value))
689+
690+
tasks.append(create_data)
691+
692+
if __name__ == "__main__":
693+
session = build(tasks=tasks)
694+
695+
# Verify all tasks were collected and executed
696+
assert session.exit_code == ExitCode.OK, f"Exit code: {session.exit_code}"
697+
assert len(session.tasks) == 3, f"Expected 3 tasks, got {len(session.tasks)}"
698+
assert len(session.execution_reports) == 3
699+
700+
# Verify each task executed and produced the correct output
701+
assert Path("output_0.txt").read_text() == "0"
702+
assert Path("output_1.txt").read_text() == "10"
703+
assert Path("output_2.txt").read_text() == "20"
704+
705+
# Verify tasks have unique names with repeated task IDs
706+
task_names = [task.name for task in session.tasks]
707+
assert len(task_names) == len(set(task_names)), "Task names should be unique"
708+
assert all("create_data[" in name for name in task_names), \\
709+
f"Task names should contain repeated task IDs: {task_names}"
710+
711+
sys.exit(session.exit_code)
712+
"""
713+
tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source))
714+
result = run_in_subprocess(
715+
(sys.executable, tmp_path.joinpath("task_module.py").as_posix()),
716+
cwd=tmp_path,
717+
)
718+
assert result.exit_code == ExitCode.OK
719+
720+
663721
def test_multiple_product_annotations(runner, tmp_path):
664722
source = """
665723
from pytask import Product
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"id": "0",
7+
"metadata": {},
8+
"outputs": [],
9+
"source": [
10+
"from pathlib import Path\n",
11+
"from typing import Annotated\n",
12+
"\n",
13+
"import pytask\n",
14+
"from pytask import ExitCode\n",
15+
"from pytask import Product"
16+
]
17+
},
18+
{
19+
"cell_type": "code",
20+
"execution_count": null,
21+
"id": "1",
22+
"metadata": {},
23+
"outputs": [],
24+
"source": [
25+
"# Create repeated tasks with the same function name\n",
26+
"tasks = []\n",
27+
"for i in range(3):\n",
28+
"\n",
29+
" def create_data(\n",
30+
" value: int = i * 10,\n",
31+
" produces: Annotated[Path, Product] = Path(f\"data_{i}.txt\"),\n",
32+
" ):\n",
33+
" \"\"\"Generate data based on a value.\"\"\"\n",
34+
" produces.write_text(str(value))\n",
35+
"\n",
36+
" tasks.append(create_data)"
37+
]
38+
},
39+
{
40+
"cell_type": "code",
41+
"execution_count": null,
42+
"id": "2",
43+
"metadata": {},
44+
"outputs": [],
45+
"source": [
46+
"# Test that all tasks execute correctly\n",
47+
"session = pytask.build(tasks=tasks)\n",
48+
"assert session.exit_code == ExitCode.OK\n",
49+
"assert len(session.tasks) == 3, f\"Expected 3 tasks, got {len(session.tasks)}\"\n",
50+
"assert len(session.execution_reports) == 3, (\n",
51+
" f\"Expected 3 execution reports, got {len(session.execution_reports)}\"\n",
52+
")\n",
53+
"\n",
54+
"# Verify each file was created with the correct content\n",
55+
"assert Path(\"data_0.txt\").read_text() == \"0\"\n",
56+
"assert Path(\"data_1.txt\").read_text() == \"10\"\n",
57+
"assert Path(\"data_2.txt\").read_text() == \"20\"\n",
58+
"\n",
59+
"# Clean up\n",
60+
"Path(\"data_0.txt\").unlink()\n",
61+
"Path(\"data_1.txt\").unlink()\n",
62+
"Path(\"data_2.txt\").unlink()"
63+
]
64+
}
65+
],
66+
"metadata": {
67+
"kernelspec": {
68+
"display_name": ".venv",
69+
"language": "python",
70+
"name": "python3"
71+
},
72+
"language_info": {
73+
"codemirror_mode": {
74+
"name": "ipython",
75+
"version": 3
76+
},
77+
"file_extension": ".py",
78+
"mimetype": "text/x-python",
79+
"name": "python",
80+
"nbconvert_exporter": "python",
81+
"pygments_lexer": "ipython3",
82+
"version": "3.12.12"
83+
}
84+
},
85+
"nbformat": 4,
86+
"nbformat_minor": 5
87+
}

0 commit comments

Comments
 (0)