Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix PythonNode when used as return. #446

Merged
merged 7 commits into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/changes.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and
- {pull}`443` ensures that `PythonNode.name` is always unique by only handling it
internally.
- {pull}`444` moves all content of `setup.cfg` to `pyproject.toml`.
- {pull}`446` refactors `create_name_of_python_node` and fixes `PythonNode`s as returns.
- {pull}`447` fixes handling multiple product annotations of a task.

## 0.4.0 - 2023-10-07
Expand Down
22 changes: 4 additions & 18 deletions src/_pytask/collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import Iterable
from typing import TYPE_CHECKING

from _pytask.collect_utils import create_name_of_python_node
from _pytask.collect_utils import parse_dependencies_from_task_function
from _pytask.collect_utils import parse_products_from_task_function
from _pytask.config import hookimpl
Expand Down Expand Up @@ -305,7 +306,7 @@ def pytask_collect_task(

@hookimpl(trylast=True)
def pytask_collect_node(session: Session, path: Path, node_info: NodeInfo) -> PNode:
"""Collect a node of a task as a :class:`pytask.nodes.PathNode`.
"""Collect a node of a task as a :class:`pytask.PNode`.

Strings are assumed to be paths. This might be a strict assumption, but since this
hook is executed at last and possible errors will be shown, it seems reasonable and
Expand All @@ -325,8 +326,7 @@ def pytask_collect_node(session: Session, path: Path, node_info: NodeInfo) -> PN
node = node_info.value

if isinstance(node, PythonNode):
node_name = _create_name_of_python_node(node_info)
node.name = node_name
node.name = create_name_of_python_node(node_info)
return node

if isinstance(node, PPathNode) and not node.path.is_absolute():
Expand Down Expand Up @@ -354,7 +354,7 @@ def pytask_collect_node(session: Session, path: Path, node_info: NodeInfo) -> PN
)
return PathNode.from_path(node)

node_name = _create_name_of_python_node(node_info)
node_name = create_name_of_python_node(node_info)
return PythonNode(value=node, name=node_name)


Expand Down Expand Up @@ -494,17 +494,3 @@ def pytask_collect_log(
)

raise CollectionError


def _create_name_of_python_node(node_info: NodeInfo) -> str:
    """Assemble the name of a ``PythonNode`` from the given node info.

    The name consists of ``::``-separated parts: the POSIX form of
    ``node_info.task_path`` (only when it is truthy), the task name, the argument
    name and, if ``node_info.path`` is non-empty, its entries joined by ``-``.

    """
    if node_info.task_path:
        task_ref = node_info.task_path.as_posix() + "::" + node_info.task_name
    else:
        task_ref = node_info.task_name

    name = task_ref + "::" + node_info.arg_name
    if not node_info.path:
        return name

    # Entries of the path may be non-strings, so they are converted before joining.
    return name + "::" + "-".join([str(entry) for entry in node_info.path])
36 changes: 33 additions & 3 deletions src/_pytask/collect_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing import Iterable
from typing import TYPE_CHECKING

import attrs
from _pytask._inspect import get_annotations
from _pytask.exceptions import NodeNotCollectedError
from _pytask.mark_utils import has_mark
Expand All @@ -24,6 +25,7 @@
from _pytask.tree_util import tree_leaves
from _pytask.tree_util import tree_map
from _pytask.tree_util import tree_map_with_path
from _pytask.typing import no_default
from _pytask.typing import ProductType
from attrs import define
from attrs import field
Expand Down Expand Up @@ -327,9 +329,15 @@ def parse_dependencies_from_task_function(
isinstance(x, PythonNode) and not x.hash for x in tree_leaves(nodes)
)
if not isinstance(nodes, PNode) and are_all_nodes_python_nodes_without_hash:
prefix = task_path.as_posix() + "::" + task_name if task_path else task_name
node_name = prefix + "::" + parameter_name

node_name = create_name_of_python_node(
NodeInfo(
arg_name=parameter_name,
path=(),
value=value,
task_path=task_path,
task_name=task_name,
)
)
dependencies[parameter_name] = PythonNode(value=value, name=node_name)
else:
dependencies[parameter_name] = nodes
Expand Down Expand Up @@ -606,6 +614,13 @@ def _collect_dependency(
"""
node = node_info.value

if isinstance(node, PythonNode) and node.value is no_default:
# If a node is a dependency and its value is not set, the node is a product in
# another task and the value will be set there. Thus, we wrap the original node
# in another node to retrieve the value after it is set.
new_node = attrs.evolve(node, value=node)
node_info = node_info._replace(value=new_node)

collected_node = session.hook.pytask_collect_node(
session=session, path=path, node_info=node_info
)
Expand Down Expand Up @@ -653,10 +668,25 @@ def _collect_product(
collected_node = session.hook.pytask_collect_node(
session=session, path=path, node_info=node_info
)

if collected_node is None:
msg = (
f"{node!r} can't be parsed as a product for task {task_name!r} in {path!r}."
)
raise NodeNotCollectedError(msg)

return collected_node


def create_name_of_python_node(node_info: NodeInfo) -> str:
    """Build the name of a ``PythonNode`` from its node info.

    The name is made of ``::``-separated segments in this order:

    1. the POSIX form of ``node_info.task_path``, only if it is truthy,
    2. ``node_info.task_name``,
    3. ``node_info.arg_name``,
    4. the entries of ``node_info.path`` joined by ``-``, only if it is non-empty.

    """
    segments = []
    if node_info.task_path:
        segments.append(node_info.task_path.as_posix())
    segments.append(node_info.task_name)
    segments.append(node_info.arg_name)
    if node_info.path:
        segments.append("-".join(str(entry) for entry in node_info.path))
    return "::".join(segments)
28 changes: 15 additions & 13 deletions src/_pytask/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from _pytask.node_protocols import PPathNode
from _pytask.node_protocols import PTask
from _pytask.node_protocols import PTaskWithPath
from _pytask.path import find_common_ancestor_of_nodes
from _pytask.nodes import PythonNode
from _pytask.report import DagReport
from _pytask.shared import reduce_names_of_multiple_nodes
from _pytask.shared import reduce_node_name
Expand Down Expand Up @@ -87,6 +87,16 @@ def pytask_dag_create_dag(tasks: list[PTask]) -> nx.DiGraph:
tree_map(lambda x: dag.add_node(x.name, node=x), task.produces)
tree_map(lambda x: dag.add_edge(task.name, x.name), task.produces)

# If a node is a PythonNode wrapped in another PythonNode, it is a product from
# another task that is a dependency in the current task. Thus, draw an edge
# connecting the two nodes.
tree_map(
lambda x: dag.add_edge(x.value.name, x.name)
if isinstance(x, PythonNode) and isinstance(x.value, PythonNode)
else None,
task.depends_on,
)

_check_if_dag_has_cycles(dag)

return dag
Expand Down Expand Up @@ -114,7 +124,7 @@ def pytask_dag_select_execution_dag(session: Session, dag: nx.DiGraph) -> None:
def pytask_dag_validate_dag(session: Session, dag: nx.DiGraph) -> None:
"""Validate the DAG."""
_check_if_root_nodes_are_available(dag, session.config["paths"])
_check_if_tasks_have_the_same_products(dag)
_check_if_tasks_have_the_same_products(dag, session.config["paths"])


def _have_task_or_neighbors_changed(
Expand Down Expand Up @@ -292,7 +302,7 @@ def _format_dictionary_to_tree(dict_: dict[str, list[str]], title: str) -> str:
return render_to_string(tree, console=console, strip_styles=True)


def _check_if_tasks_have_the_same_products(dag: nx.DiGraph) -> None:
def _check_if_tasks_have_the_same_products(dag: nx.DiGraph, paths: list[Path]) -> None:
nodes_created_by_multiple_tasks = []

for node in dag.nodes:
Expand All @@ -303,19 +313,11 @@ def _check_if_tasks_have_the_same_products(dag: nx.DiGraph) -> None:
nodes_created_by_multiple_tasks.append(node)

if nodes_created_by_multiple_tasks:
all_names = nodes_created_by_multiple_tasks + [
predecessor
for node in nodes_created_by_multiple_tasks
for predecessor in dag.predecessors(node)
]
common_ancestor = find_common_ancestor_of_nodes(*all_names)
dictionary = {}
for node in nodes_created_by_multiple_tasks:
short_node_name = reduce_node_name(
dag.nodes[node]["node"], [common_ancestor]
)
short_node_name = reduce_node_name(dag.nodes[node]["node"], paths)
short_predecessors = reduce_names_of_multiple_nodes(
dag.predecessors(node), dag, [common_ancestor]
dag.predecessors(node), dag, paths
)
dictionary[short_node_name] = short_predecessors
text = _format_dictionary_to_tree(dictionary, "Products from multiple tasks:")
Expand Down
29 changes: 16 additions & 13 deletions src/_pytask/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from _pytask.node_protocols import PPathNode
from _pytask.node_protocols import PTask
from _pytask.node_protocols import PTaskWithPath
from _pytask.typing import no_default
from _pytask.typing import NoDefault
from attrs import define
from attrs import field

Expand Down Expand Up @@ -47,9 +49,7 @@ class TaskWithoutPath(PTask):
A list of markers attached to the task function.
report_sections
Reports with entries for when, what, and content.

Attributes
----------
attributes: dict[Any, Any]
A dictionary to store additional information of the task.
"""

Expand Down Expand Up @@ -79,6 +79,8 @@ def execute(self, **kwargs: Any) -> None:
class Task(PTaskWithPath):
"""The class for tasks which are Python functions.

Attributes
----------
base_name
The base name of the task.
path
Expand All @@ -97,9 +99,7 @@ class Task(PTaskWithPath):
A list of markers attached to the task function.
report_sections
Reports with entries for when, what, and content.

Attributes
----------
attributes: dict[Any, Any]
A dictionary to store additional information of the task.

"""
Expand Down Expand Up @@ -204,11 +204,13 @@ class PythonNode(PNode):
"""

name: str = ""
value: Any = None
value: Any | NoDefault = no_default
hash: bool | Callable[[Any], bool] = False # noqa: A003

def load(self) -> Any:
"""Load the value."""
if isinstance(self.value, PythonNode):
return self.value.load()
return self.value

def save(self, value: Any) -> None:
Expand All @@ -234,11 +236,12 @@ def state(self) -> str | None:

"""
if self.hash:
value = self.load()
if callable(self.hash):
return str(self.hash(self.value))
if isinstance(self.value, str):
return str(hashlib.sha256(self.value.encode()).hexdigest())
if isinstance(self.value, bytes):
return str(hashlib.sha256(self.value).hexdigest())
return str(hash(self.value))
return str(self.hash(value))
if isinstance(value, str):
return str(hashlib.sha256(value.encode()).hexdigest())
if isinstance(value, bytes):
return str(hashlib.sha256(value).hexdigest())
return str(hash(value))
return "0"
6 changes: 0 additions & 6 deletions src/_pytask/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,6 @@ def find_closest_ancestor(path: Path, potential_ancestors: Sequence[Path]) -> Pa
return sorted(potential_closest_ancestors, key=lambda x: len(x.parts))[-1]


def find_common_ancestor_of_nodes(*names: str) -> Path:
    """Return the common ancestor of the paths embedded in task and node names.

    Only the part of each name before the first ``::`` is used. It is converted to
    a path and the resulting paths are passed to :func:`find_common_ancestor`.

    """
    leading_parts = (name.split("::")[0] for name in names)
    return find_common_ancestor(*map(Path, leading_parts))


def find_common_ancestor(*paths: Path) -> Path:
    """Return the longest common sub-path of all given paths.

    The computation is delegated to :func:`os.path.commonpath`; its result is
    wrapped in a :class:`~pathlib.Path`.

    """
    shared = os.path.commonpath(paths)
    return Path(shared)
Expand Down
37 changes: 33 additions & 4 deletions src/_pytask/typing.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
from __future__ import annotations

import functools
from enum import Enum
from typing import Any
from typing import Final
from typing import Literal
from typing import TYPE_CHECKING

from attr import define
from attrs import define

if TYPE_CHECKING:
from typing_extensions import TypeAlias


__all__ = ["Product", "ProductType"]
Expand All @@ -18,7 +25,29 @@ class ProductType:
"""ProductType: A singleton to mark products in annotations."""


def is_task_function(obj: Any) -> bool:
    """Check if an object is a task function.

    An object qualifies if it is callable and has a ``__name__`` attribute, or if
    it is a :class:`functools.partial` whose wrapped function has a ``__name__``.

    Parameters
    ----------
    obj
        The object to inspect.

    Returns
    -------
    bool
        ``True`` if the object can be treated as a task function.

    """
    if callable(obj) and hasattr(obj, "__name__"):
        return True
    # A partial object may lack its own ``__name__``, so the wrapped function is
    # inspected instead.
    return isinstance(obj, functools.partial) and hasattr(obj.func, "__name__")


class _NoDefault(Enum):
    """Single-member enum whose member serves as the "no default" sentinel.

    The sentinel is implemented as an ``Enum`` for two reasons:

    1. it round-trips through pickle correctly (see GH#40397), and
    2. mypy does not understand singletons.

    NOTE(review): GH#40397 looks like an issue number from another project's
    tracker -- confirm which repository it refers to.

    """

    no_default = "NO_DEFAULT"

    def __repr__(self) -> str:
        return "<no_default>"


no_default: Final = _NoDefault.no_default
"""The value for missing defaults."""
NoDefault: TypeAlias = Literal[_NoDefault.no_default]
"""The type annotation."""
10 changes: 4 additions & 6 deletions tests/test_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,10 +714,8 @@ def test_execute_tasks_and_pass_values_only_by_python_nodes(runner, tmp_path):
from typing_extensions import Annotated
from pathlib import Path


node_text = PythonNode(name="text")


def task_create_text() -> Annotated[int, node_text]:
return "This is the text."

Expand All @@ -743,21 +741,21 @@ def test_execute_tasks_via_functional_api(tmp_path):
from pathlib import Path


node_text = PythonNode(name="text", hash=True)
node_text = PythonNode()

def create_text() -> Annotated[int, node_text]:
return "This is the text."

node_file = PathNode.from_path(Path(__file__).parent.joinpath("file.txt"))

def create_file(text: Annotated[int, node_text]) -> Annotated[str, node_file]:
return text
def create_file(content: Annotated[str, node_text]) -> Annotated[str, node_file]:
return content

if __name__ == "__main__":
session = pytask.build(tasks=[create_file, create_text])

assert len(session.tasks) == 2
assert len(session.dag.nodes) == 4
assert len(session.dag.nodes) == 5
"""
tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source))
result = subprocess.run(
Expand Down