Skip to content

Commit fd63a48

Browse files
committed
Fix errors and simplify collection.
1 parent 3e5de89 commit fd63a48

File tree

4 files changed

+84
-73
lines changed

4 files changed

+84
-73
lines changed

src/_pytask/collect_utils.py

Lines changed: 77 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""This module provides utility functions for :mod:`_pytask.collect`."""
22
from __future__ import annotations
33

4-
import inspect
54
import itertools
65
import uuid
76
from pathlib import Path
@@ -79,9 +78,15 @@ def parse_nodes(
7978
session: Session, path: Path, name: str, obj: Any, parser: Callable[..., Any]
8079
) -> Any:
8180
"""Parse nodes from object."""
81+
arg_name = parser.__name__
8282
objects = _extract_nodes_from_function_markers(obj, parser)
83-
nodes = _convert_objects_to_node_dictionary(objects, parser.__name__)
84-
nodes = tree_map(lambda x: _collect_old_dependencies(session, path, name, x), nodes)
83+
nodes = _convert_objects_to_node_dictionary(objects, arg_name)
84+
nodes = tree_map(
85+
lambda x: _collect_decorator_nodes(
86+
session, path, name, NodeInfo(arg_name, (), x)
87+
),
88+
nodes,
89+
)
8590
return nodes
8691

8792

@@ -211,10 +216,23 @@ def _merge_dictionaries(list_of_dicts: list[dict[Any, Any]]) -> dict[Any, Any]:
211216
return out
212217

213218

214-
def parse_dependencies_from_task_function(
219+
_ERROR_MULTIPLE_DEPENDENCY_DEFINITIONS = """"Dependencies are defined via \
220+
'@pytask.mark.depends_on' and as default arguments for the argument 'depends_on'. Use \
221+
only one way and not both.
222+
223+
Hint: You do not need to use 'depends_on' since pytask v0.4. Every function argument \
224+
that is not a product is treated as a dependency. Read more about dependencies in the \
225+
documentation: https://tinyurl.com/yrezszr4.
226+
"""
227+
228+
229+
def parse_dependencies_from_task_function( # noqa: C901
215230
session: Session, path: Path, name: str, obj: Any
216231
) -> dict[str, Any]:
217232
"""Parse dependencies from task function."""
233+
has_depends_on_decorator = False
234+
has_depends_on_argument = False
235+
218236
if has_mark(obj, "depends_on"):
219237
nodes = parse_nodes(session, path, name, obj, depends_on)
220238
return {"depends_on": nodes}
@@ -224,14 +242,30 @@ def parse_dependencies_from_task_function(
224242
kwargs = {**signature_defaults, **task_kwargs}
225243
kwargs.pop("produces", None)
226244

245+
dependencies = {}
246+
# Parse dependencies passed via the 'depends_on' argument (signature default or @task kwargs).
247+
if "depends_on" in kwargs:
248+
has_depends_on_argument = True
249+
dependencies["depends_on"] = tree_map(
250+
lambda x: _collect_decorator_nodes(
251+
session, path, name, NodeInfo(arg_name="depends_on", path=(), value=x)
252+
),
253+
kwargs["depends_on"],
254+
)
255+
256+
if has_depends_on_decorator and has_depends_on_argument:
257+
raise NodeNotCollectedError(_ERROR_MULTIPLE_PRODUCT_DEFINITIONS)
258+
227259
parameters_with_product_annot = _find_args_with_product_annotation(obj)
228260
parameters_with_node_annot = _find_args_with_node_annotation(obj)
229261

230-
dependencies = {}
231262
for parameter_name, value in kwargs.items():
232263
if parameter_name in parameters_with_product_annot:
233264
continue
234265

266+
if parameter_name == "depends_on":
267+
continue
268+
235269
if parameter_name in parameters_with_node_annot:
236270

237271
def _evolve(x: Any) -> Any:
@@ -316,8 +350,7 @@ def parse_products_from_task_function(
316350
317351
"""
318352
has_produces_decorator = False
319-
has_task_decorator = False
320-
has_signature_default = False
353+
has_produces_argument = False
321354
has_annotation = False
322355
out = {}
323356

@@ -333,6 +366,7 @@ def parse_products_from_task_function(
333366

334367
# Parse products from task decorated with @task and that uses produces.
335368
if "produces" in kwargs:
369+
has_produces_argument = True
336370
collected_products = tree_map_with_path(
337371
lambda p, x: _collect_product(
338372
session,
@@ -345,52 +379,26 @@ def parse_products_from_task_function(
345379
)
346380
out = {"produces": collected_products}
347381

348-
parameters = inspect.signature(obj).parameters
349-
350-
# Parse products from default arguments
351-
if not has_mark(obj, "task") and "produces" in parameters:
352-
parameter = parameters["produces"]
353-
if parameter.default is not parameter.empty:
354-
has_signature_default = True
355-
# Use _collect_new_node to not collect strings.
356-
collected_products = tree_map_with_path(
357-
lambda p, x: _collect_product(
358-
session,
359-
path,
360-
name,
361-
NodeInfo(arg_name="produces", path=p, value=x),
362-
is_string_allowed=False,
363-
),
364-
parameter.default,
365-
)
366-
out = {"produces": collected_products}
367-
368382
parameters_with_product_annot = _find_args_with_product_annotation(obj)
369383
if parameters_with_product_annot:
370384
has_annotation = True
371385
for parameter_name in parameters_with_product_annot:
372-
# Use _collect_new_node to not collect strings.
373-
collected_products = tree_map_with_path(
374-
lambda p, x: _collect_product(
375-
session,
376-
path,
377-
name,
378-
NodeInfo(parameter_name, p, x), # noqa: B023
379-
is_string_allowed=False,
380-
),
381-
kwargs[parameter_name],
382-
)
383-
out = {parameter_name: collected_products}
386+
if parameter_name in kwargs:
387+
# Pass is_string_allowed=False to _collect_product to not collect strings.
388+
collected_products = tree_map_with_path(
389+
lambda p, x: _collect_product(
390+
session,
391+
path,
392+
name,
393+
NodeInfo(parameter_name, p, x), # noqa: B023
394+
is_string_allowed=False,
395+
),
396+
kwargs[parameter_name],
397+
)
398+
out = {parameter_name: collected_products}
384399

385400
if (
386-
sum(
387-
(
388-
has_produces_decorator,
389-
has_task_decorator,
390-
has_signature_default,
391-
has_annotation,
392-
)
393-
)
401+
sum((has_produces_decorator, has_produces_argument, has_annotation))
394402
>= 2 # noqa: PLR2004
395403
):
396404
raise NodeNotCollectedError(_ERROR_MULTIPLE_PRODUCT_DEFINITIONS)
@@ -416,8 +424,15 @@ def _find_args_with_product_annotation(func: Callable[..., Any]) -> list[str]:
416424
return args_with_product_annot
417425

418426

419-
def _collect_old_dependencies(
420-
session: Session, path: Path, name: str, node: str | Path
427+
_ERROR_WRONG_TYPE_DECORATOR = """'@pytask.mark.depends_on', '@pytask.mark.produces', \
428+
and their function arguments can only accept values of type 'str' and 'pathlib.Path' \
429+
or the same values nested in tuples, lists, and dictionaries. Here, {node} has type \
430+
{node_type}.
431+
"""
432+
433+
434+
def _collect_decorator_nodes(
435+
session: Session, path: Path, name: str, node_info: NodeInfo
421436
) -> dict[str, MetaNode]:
422437
"""Collect nodes for a task.
423438
@@ -427,22 +442,26 @@ def _collect_old_dependencies(
427442
If the node could not be collected.
428443
429444
"""
445+
node = node_info.value
446+
430447
if not isinstance(node, (str, Path)):
431-
raise ValueError(
432-
"'@pytask.mark.depends_on' and '@pytask.mark.produces' can only accept "
433-
"values of type 'str' and 'pathlib.Path' or the same values nested in "
434-
f"tuples, lists, and dictionaries. Here, {node} has type {type(node)}."
448+
raise NodeNotCollectedError(
449+
_ERROR_WRONG_TYPE_DECORATOR.format(node=node, node_type=type(node))
435450
)
436451

437452
if isinstance(node, str):
438453
node = Path(node)
454+
node_info = node_info._replace(value=node)
439455

440456
collected_node = session.hook.pytask_collect_node(
441-
session=session, path=path, node_info=NodeInfo("produces", (), node)
457+
session=session, path=path, node_info=node_info
442458
)
443459
if collected_node is None:
460+
kind = {"depends_on": "dependency", "produces": "product"}.get(
461+
node_info.arg_name
462+
)
444463
raise NodeNotCollectedError(
445-
f"{node!r} cannot be parsed as a dependency for task {name!r} in {path!r}."
464+
f"{node!r} cannot be parsed as a {kind} for task {name!r} in {path!r}."
446465
)
447466

448467
return collected_node
@@ -462,7 +481,7 @@ def _collect_dependencies(
462481
node = node_info.value
463482

464483
collected_node = session.hook.pytask_collect_node(
465-
session=session, path=path, node_info=node_info, node=node
484+
session=session, path=path, node_info=node_info
466485
)
467486
if collected_node is None:
468487
raise NodeNotCollectedError(
@@ -509,9 +528,10 @@ def _collect_product(
509528

510529
if isinstance(node, str):
511530
node = Path(node)
531+
node_info = node_info._replace(value=node)
512532

513533
collected_node = session.hook.pytask_collect_node(
514-
session=session, path=path, node_info=node_info, node=node
534+
session=session, path=path, node_info=node_info
515535
)
516536
if collected_node is None:
517537
raise NodeNotCollectedError(

src/_pytask/task_utils.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -114,15 +114,6 @@ def parse_collected_tasks_with_task_marker(
114114
else:
115115
collected_tasks[name] = [i[1] for i in parsed_tasks if i[0] == name][0]
116116

117-
# TODO: Remove when parsing dependencies and products from all arguments is
118-
# implemented.
119-
for task in collected_tasks.values():
120-
meta = task.pytask_meta # type: ignore[attr-defined]
121-
for marker_name in ("depends_on", "produces"):
122-
if marker_name in meta.kwargs:
123-
value = meta.kwargs.pop(marker_name)
124-
meta.markers.append(Mark(marker_name, (value,), {}))
125-
126117
return collected_tasks
127118

128119

tests/test_collect.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import pytest
99
from _pytask.collect import _find_shortest_uniquely_identifiable_name_for_tasks
1010
from _pytask.collect import pytask_collect_node
11+
from _pytask.exceptions import NodeNotCollectedError
1112
from _pytask.models import NodeInfo
1213
from pytask import cli
1314
from pytask import CollectionOutcome
@@ -53,7 +54,7 @@ def task_with_non_path_dependency():
5354
assert session.exit_code == ExitCode.COLLECTION_FAILED
5455
assert session.collection_reports[0].outcome == CollectionOutcome.FAIL
5556
exc_info = session.collection_reports[0].exc_info
56-
assert isinstance(exc_info[1], ValueError)
57+
assert isinstance(exc_info[1], NodeNotCollectedError)
5758
assert "'@pytask.mark.depends_on'" in str(exc_info[1])
5859

5960

@@ -74,7 +75,7 @@ def task_with_non_path_dependency():
7475
assert session.exit_code == ExitCode.COLLECTION_FAILED
7576
assert session.collection_reports[0].outcome == CollectionOutcome.FAIL
7677
exc_info = session.collection_reports[0].exc_info
77-
assert isinstance(exc_info[1], ValueError)
78+
assert isinstance(exc_info[1], NodeNotCollectedError)
7879
assert "'@pytask.mark.depends_on'" in str(exc_info[1])
7980

8081

@@ -326,7 +327,7 @@ def task_write_text(produces="out.txt"):
326327

327328

328329
@pytest.mark.end_to_end()
329-
def test_collect_string_product_as_function_default_fails(tmp_path):
330+
def test_collect_string_product_as_function_default(tmp_path):
330331
source = """
331332
import pytask
332333
@@ -335,9 +336,8 @@ def task_write_text(produces="out.txt"):
335336
"""
336337
tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source))
337338
session = main({"paths": tmp_path})
338-
report = session.collection_reports[0]
339-
assert report.outcome == CollectionOutcome.FAIL
340-
assert "If you use 'produces'" in str(report.exc_info[1])
339+
assert session.exit_code == ExitCode.OK
340+
assert tmp_path.joinpath("out.txt").exists()
341341

342342

343343
@pytest.mark.end_to_end()

tests/test_task.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,7 @@ def task_example(
478478
path_out.write_text(path_in.read_text() + " " + arg)
479479
"""
480480
tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source))
481-
tmp_path.joinpath("input.txt").write_text("hello")
481+
tmp_path.joinpath("input.txt").write_text("Hello")
482482

483483
result = runner.invoke(cli, [tmp_path.as_posix()])
484484
assert result.exit_code == ExitCode.OK

0 commit comments

Comments
 (0)