Skip to content

Commit 0fceb2a

Browse files
authored Oct 11, 2023
Fix PythonNode when used as return. (#446)
1 parent 7b6f184 commit 0fceb2a

File tree

8 files changed

+106
-63
lines changed

8 files changed

+106
-63
lines changed
 

‎docs/source/changes.md

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and
1010
- {pull}`443` ensures that `PythonNode.name` is always unique by only handling it
1111
internally.
1212
- {pull}`444` moves all content of `setup.cfg` to `pyproject.toml`.
13+
- {pull}`446` refactors `create_name_of_python_node` and fixes `PythonNode`s when used as return annotations.
1314
- {pull}`447` fixes handling multiple product annotations of a task.
1415

1516
## 0.4.0 - 2023-10-07

‎src/_pytask/collect.py

+4-18
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from typing import Iterable
1313
from typing import TYPE_CHECKING
1414

15+
from _pytask.collect_utils import create_name_of_python_node
1516
from _pytask.collect_utils import parse_dependencies_from_task_function
1617
from _pytask.collect_utils import parse_products_from_task_function
1718
from _pytask.config import hookimpl
@@ -305,7 +306,7 @@ def pytask_collect_task(
305306

306307
@hookimpl(trylast=True)
307308
def pytask_collect_node(session: Session, path: Path, node_info: NodeInfo) -> PNode:
308-
"""Collect a node of a task as a :class:`pytask.nodes.PathNode`.
309+
"""Collect a node of a task as a :class:`pytask.PNode`.
309310
310311
Strings are assumed to be paths. This might be a strict assumption, but since this
311312
hook is executed last and possible errors will be shown, it seems reasonable and
@@ -325,8 +326,7 @@ def pytask_collect_node(session: Session, path: Path, node_info: NodeInfo) -> PN
325326
node = node_info.value
326327

327328
if isinstance(node, PythonNode):
328-
node_name = _create_name_of_python_node(node_info)
329-
node.name = node_name
329+
node.name = create_name_of_python_node(node_info)
330330
return node
331331

332332
if isinstance(node, PPathNode) and not node.path.is_absolute():
@@ -354,7 +354,7 @@ def pytask_collect_node(session: Session, path: Path, node_info: NodeInfo) -> PN
354354
)
355355
return PathNode.from_path(node)
356356

357-
node_name = _create_name_of_python_node(node_info)
357+
node_name = create_name_of_python_node(node_info)
358358
return PythonNode(value=node, name=node_name)
359359

360360

@@ -494,17 +494,3 @@ def pytask_collect_log(
494494
)
495495

496496
raise CollectionError
497-
498-
499-
def _create_name_of_python_node(node_info: NodeInfo) -> str:
500-
"""Create name of PythonNode."""
501-
prefix = (
502-
node_info.task_path.as_posix() + "::" + node_info.task_name
503-
if node_info.task_path
504-
else node_info.task_name
505-
)
506-
node_name = prefix + "::" + node_info.arg_name
507-
if node_info.path:
508-
suffix = "-".join(map(str, node_info.path))
509-
node_name += "::" + suffix
510-
return node_name

‎src/_pytask/collect_utils.py

+33-3
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from typing import Iterable
1212
from typing import TYPE_CHECKING
1313

14+
import attrs
1415
from _pytask._inspect import get_annotations
1516
from _pytask.exceptions import NodeNotCollectedError
1617
from _pytask.mark_utils import has_mark
@@ -24,6 +25,7 @@
2425
from _pytask.tree_util import tree_leaves
2526
from _pytask.tree_util import tree_map
2627
from _pytask.tree_util import tree_map_with_path
28+
from _pytask.typing import no_default
2729
from _pytask.typing import ProductType
2830
from attrs import define
2931
from attrs import field
@@ -327,9 +329,15 @@ def parse_dependencies_from_task_function(
327329
isinstance(x, PythonNode) and not x.hash for x in tree_leaves(nodes)
328330
)
329331
if not isinstance(nodes, PNode) and are_all_nodes_python_nodes_without_hash:
330-
prefix = task_path.as_posix() + "::" + task_name if task_path else task_name
331-
node_name = prefix + "::" + parameter_name
332-
332+
node_name = create_name_of_python_node(
333+
NodeInfo(
334+
arg_name=parameter_name,
335+
path=(),
336+
value=value,
337+
task_path=task_path,
338+
task_name=task_name,
339+
)
340+
)
333341
dependencies[parameter_name] = PythonNode(value=value, name=node_name)
334342
else:
335343
dependencies[parameter_name] = nodes
@@ -606,6 +614,13 @@ def _collect_dependency(
606614
"""
607615
node = node_info.value
608616

617+
if isinstance(node, PythonNode) and node.value is no_default:
618+
# If a node is a dependency and its value is not set, the node is a product in
619+
# another task and the value will be set there. Thus, we wrap the original node
620+
# in another node to retrieve the value after it is set.
621+
new_node = attrs.evolve(node, value=node)
622+
node_info = node_info._replace(value=new_node)
623+
609624
collected_node = session.hook.pytask_collect_node(
610625
session=session, path=path, node_info=node_info
611626
)
@@ -653,10 +668,25 @@ def _collect_product(
653668
collected_node = session.hook.pytask_collect_node(
654669
session=session, path=path, node_info=node_info
655670
)
671+
656672
if collected_node is None:
657673
msg = (
658674
f"{node!r} can't be parsed as a product for task {task_name!r} in {path!r}."
659675
)
660676
raise NodeNotCollectedError(msg)
661677

662678
return collected_node
679+
680+
681+
def create_name_of_python_node(node_info: NodeInfo) -> str:
682+
"""Create name of PythonNode."""
683+
prefix = (
684+
node_info.task_path.as_posix() + "::" + node_info.task_name
685+
if node_info.task_path
686+
else node_info.task_name
687+
)
688+
node_name = prefix + "::" + node_info.arg_name
689+
if node_info.path:
690+
suffix = "-".join(map(str, node_info.path))
691+
node_name += "::" + suffix
692+
return node_name

‎src/_pytask/dag.py

+15-13
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from _pytask.node_protocols import PPathNode
3030
from _pytask.node_protocols import PTask
3131
from _pytask.node_protocols import PTaskWithPath
32-
from _pytask.path import find_common_ancestor_of_nodes
32+
from _pytask.nodes import PythonNode
3333
from _pytask.report import DagReport
3434
from _pytask.shared import reduce_names_of_multiple_nodes
3535
from _pytask.shared import reduce_node_name
@@ -87,6 +87,16 @@ def pytask_dag_create_dag(tasks: list[PTask]) -> nx.DiGraph:
8787
tree_map(lambda x: dag.add_node(x.name, node=x), task.produces)
8888
tree_map(lambda x: dag.add_edge(task.name, x.name), task.produces)
8989

90+
# If a node is a PythonNode wrapped in another PythonNode, it is a product from
91+
# another task that is a dependency in the current task. Thus, draw an edge
92+
# connecting the two nodes.
93+
tree_map(
94+
lambda x: dag.add_edge(x.value.name, x.name)
95+
if isinstance(x, PythonNode) and isinstance(x.value, PythonNode)
96+
else None,
97+
task.depends_on,
98+
)
99+
90100
_check_if_dag_has_cycles(dag)
91101

92102
return dag
@@ -114,7 +124,7 @@ def pytask_dag_select_execution_dag(session: Session, dag: nx.DiGraph) -> None:
114124
def pytask_dag_validate_dag(session: Session, dag: nx.DiGraph) -> None:
115125
"""Validate the DAG."""
116126
_check_if_root_nodes_are_available(dag, session.config["paths"])
117-
_check_if_tasks_have_the_same_products(dag)
127+
_check_if_tasks_have_the_same_products(dag, session.config["paths"])
118128

119129

120130
def _have_task_or_neighbors_changed(
@@ -292,7 +302,7 @@ def _format_dictionary_to_tree(dict_: dict[str, list[str]], title: str) -> str:
292302
return render_to_string(tree, console=console, strip_styles=True)
293303

294304

295-
def _check_if_tasks_have_the_same_products(dag: nx.DiGraph) -> None:
305+
def _check_if_tasks_have_the_same_products(dag: nx.DiGraph, paths: list[Path]) -> None:
296306
nodes_created_by_multiple_tasks = []
297307

298308
for node in dag.nodes:
@@ -303,19 +313,11 @@ def _check_if_tasks_have_the_same_products(dag: nx.DiGraph) -> None:
303313
nodes_created_by_multiple_tasks.append(node)
304314

305315
if nodes_created_by_multiple_tasks:
306-
all_names = nodes_created_by_multiple_tasks + [
307-
predecessor
308-
for node in nodes_created_by_multiple_tasks
309-
for predecessor in dag.predecessors(node)
310-
]
311-
common_ancestor = find_common_ancestor_of_nodes(*all_names)
312316
dictionary = {}
313317
for node in nodes_created_by_multiple_tasks:
314-
short_node_name = reduce_node_name(
315-
dag.nodes[node]["node"], [common_ancestor]
316-
)
318+
short_node_name = reduce_node_name(dag.nodes[node]["node"], paths)
317319
short_predecessors = reduce_names_of_multiple_nodes(
318-
dag.predecessors(node), dag, [common_ancestor]
320+
dag.predecessors(node), dag, paths
319321
)
320322
dictionary[short_node_name] = short_predecessors
321323
text = _format_dictionary_to_tree(dictionary, "Products from multiple tasks:")

‎src/_pytask/nodes.py

+16-13
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
from _pytask.node_protocols import PPathNode
1414
from _pytask.node_protocols import PTask
1515
from _pytask.node_protocols import PTaskWithPath
16+
from _pytask.typing import no_default
17+
from _pytask.typing import NoDefault
1618
from attrs import define
1719
from attrs import field
1820

@@ -47,9 +49,7 @@ class TaskWithoutPath(PTask):
4749
A list of markers attached to the task function.
4850
report_sections
4951
Reports with entries for when, what, and content.
50-
51-
Attributes
52-
----------
52+
attributes: dict[Any, Any]
5353
A dictionary to store additional information of the task.
5454
"""
5555

@@ -79,6 +79,8 @@ def execute(self, **kwargs: Any) -> None:
7979
class Task(PTaskWithPath):
8080
"""The class for tasks which are Python functions.
8181
82+
Attributes
83+
----------
8284
base_name
8385
The base name of the task.
8486
path
@@ -97,9 +99,7 @@ class Task(PTaskWithPath):
9799
A list of markers attached to the task function.
98100
report_sections
99101
Reports with entries for when, what, and content.
100-
101-
Attributes
102-
----------
102+
attributes: dict[Any, Any]
103103
A dictionary to store additional information of the task.
104104
105105
"""
@@ -204,11 +204,13 @@ class PythonNode(PNode):
204204
"""
205205

206206
name: str = ""
207-
value: Any = None
207+
value: Any | NoDefault = no_default
208208
hash: bool | Callable[[Any], bool] = False # noqa: A003
209209

210210
def load(self) -> Any:
211211
"""Load the value."""
212+
if isinstance(self.value, PythonNode):
213+
return self.value.load()
212214
return self.value
213215

214216
def save(self, value: Any) -> None:
@@ -234,11 +236,12 @@ def state(self) -> str | None:
234236
235237
"""
236238
if self.hash:
239+
value = self.load()
237240
if callable(self.hash):
238-
return str(self.hash(self.value))
239-
if isinstance(self.value, str):
240-
return str(hashlib.sha256(self.value.encode()).hexdigest())
241-
if isinstance(self.value, bytes):
242-
return str(hashlib.sha256(self.value).hexdigest())
243-
return str(hash(self.value))
241+
return str(self.hash(value))
242+
if isinstance(value, str):
243+
return str(hashlib.sha256(value.encode()).hexdigest())
244+
if isinstance(value, bytes):
245+
return str(hashlib.sha256(value).hexdigest())
246+
return str(hash(value))
244247
return "0"

‎src/_pytask/path.py

-6
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,6 @@ def find_closest_ancestor(path: Path, potential_ancestors: Sequence[Path]) -> Pa
6868
return sorted(potential_closest_ancestors, key=lambda x: len(x.parts))[-1]
6969

7070

71-
def find_common_ancestor_of_nodes(*names: str) -> Path:
72-
"""Find the common ancestor from task names and nodes."""
73-
cleaned_names = [Path(name.split("::")[0]) for name in names]
74-
return find_common_ancestor(*cleaned_names)
75-
76-
7771
def find_common_ancestor(*paths: Path) -> Path:
7872
"""Find a common ancestor of many paths."""
7973
return Path(os.path.commonpath(paths))

‎src/_pytask/typing.py

+33-4
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
11
from __future__ import annotations
22

33
import functools
4+
from enum import Enum
45
from typing import Any
6+
from typing import Final
7+
from typing import Literal
8+
from typing import TYPE_CHECKING
59

6-
from attr import define
10+
from attrs import define
11+
12+
if TYPE_CHECKING:
13+
from typing_extensions import TypeAlias
714

815

916
__all__ = ["Product", "ProductType"]
@@ -18,7 +25,29 @@ class ProductType:
1825
"""ProductType: A singleton to mark products in annotations."""
1926

2027

21-
def is_task_function(func: Any) -> bool:
22-
return (callable(func) and hasattr(func, "__name__")) or (
23-
isinstance(func, functools.partial) and hasattr(func.func, "__name__")
28+
def is_task_function(obj: Any) -> bool:
29+
"""Check if an object is a task function."""
30+
return (callable(obj) and hasattr(obj, "__name__")) or (
31+
isinstance(obj, functools.partial) and hasattr(obj.func, "__name__")
2432
)
33+
34+
35+
class _NoDefault(Enum):
36+
"""A singleton for no defaults.
37+
38+
We make this an Enum
39+
1) because it round-trips through pickle correctly (see pandas GH#40397)
40+
2) because mypy does not understand singletons
41+
42+
"""
43+
44+
no_default = "NO_DEFAULT"
45+
46+
def __repr__(self) -> str:
47+
return "<no_default>"
48+
49+
50+
no_default: Final = _NoDefault.no_default
51+
"""The value for missing defaults."""
52+
NoDefault: TypeAlias = Literal[_NoDefault.no_default]
53+
"""The type annotation."""

‎tests/test_execute.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -714,10 +714,8 @@ def test_execute_tasks_and_pass_values_only_by_python_nodes(runner, tmp_path):
714714
from typing_extensions import Annotated
715715
from pathlib import Path
716716
717-
718717
node_text = PythonNode(name="text")
719718
720-
721719
def task_create_text() -> Annotated[int, node_text]:
722720
return "This is the text."
723721
@@ -743,21 +741,21 @@ def test_execute_tasks_via_functional_api(tmp_path):
743741
from pathlib import Path
744742
745743
746-
node_text = PythonNode(name="text", hash=True)
744+
node_text = PythonNode()
747745
748746
def create_text() -> Annotated[int, node_text]:
749747
return "This is the text."
750748
751749
node_file = PathNode.from_path(Path(__file__).parent.joinpath("file.txt"))
752750
753-
def create_file(text: Annotated[int, node_text]) -> Annotated[str, node_file]:
754-
return text
751+
def create_file(content: Annotated[str, node_text]) -> Annotated[str, node_file]:
752+
return content
755753
756754
if __name__ == "__main__":
757755
session = pytask.build(tasks=[create_file, create_text])
758756
759757
assert len(session.tasks) == 2
760-
assert len(session.dag.nodes) == 4
758+
assert len(session.dag.nodes) == 5
761759
"""
762760
tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source))
763761
result = subprocess.run(

0 commit comments

Comments
 (0)
Please sign in to comment.