Skip to content

Commit ce6a825

Browse files
authored
Simplify the tree_map code for generating the DAG. (#447)
1 parent 0fceb2a commit ce6a825

File tree

2 files changed

+22
-7
lines changed

2 files changed

+22
-7
lines changed

docs/source/changes.md

+3-2
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@ chronological order. Releases follow [semantic versioning](https://semver.org/)
55
releases are available on [PyPI](https://pypi.org/project/pytask) and
66
[Anaconda.org](https://anaconda.org/conda-forge/pytask).
77

8-
## 0.4.1 - 2023-10-xx
8+
## 0.4.1 - 2023-10-11
99

1010
- {pull}`443` ensures that `PythonNode.name` is always unique by only handling it
1111
internally.
1212
- {pull}`444` moves all content of `setup.cfg` to `pyproject.toml`.
1313
- {pull}`446` refactors `create_name_of_python_node` and fixes `PythonNode`s as returns.
14-
- {pull}`447` fixes handling multiple product annotations of a task.
14+
- {pull}`447` simplifies the `tree_map` code while generating the DAG.
15+
- {pull}`448` fixes handling multiple product annotations of a task.
1516

1617
## 0.4.0 - 2023-10-07
1718

src/_pytask/dag.py

+19-5
Original file line numberDiff line numberDiff line change
@@ -76,16 +76,30 @@ def pytask_dag(session: Session) -> bool | None:
7676
@hookimpl
7777
def pytask_dag_create_dag(tasks: list[PTask]) -> nx.DiGraph:
7878
"""Create the DAG from tasks, dependencies and products."""
79+
80+
def _add_dependency(dag: nx.DiGraph, task: PTask, node: PNode) -> None:
81+
"""Add a dependency to the DAG."""
82+
dag.add_node(node.name, node=node)
83+
dag.add_edge(node.name, task.name)
84+
85+
# If a node is a PythonNode wrapped in another PythonNode, it is a product from
86+
# another task that is a dependency in the current task. Thus, draw an edge
87+
# connecting the two nodes.
88+
if isinstance(node, PythonNode) and isinstance(node.value, PythonNode):
89+
dag.add_edge(node.value.name, node.name)
90+
91+
def _add_product(dag: nx.DiGraph, task: PTask, node: PNode) -> None:
92+
"""Add a product to the DAG."""
93+
dag.add_node(node.name, node=node)
94+
dag.add_edge(task.name, node.name)
95+
7996
dag = nx.DiGraph()
8097

8198
for task in tasks:
8299
dag.add_node(task.name, task=task)
83100

84-
tree_map(lambda x: dag.add_node(x.name, node=x), task.depends_on)
85-
tree_map(lambda x: dag.add_edge(x.name, task.name), task.depends_on)
86-
87-
tree_map(lambda x: dag.add_node(x.name, node=x), task.produces)
88-
tree_map(lambda x: dag.add_edge(task.name, x.name), task.produces)
101+
tree_map(lambda x: _add_dependency(dag, task, x), task.depends_on)
102+
tree_map(lambda x: _add_product(dag, task, x), task.produces)
89103

90104
# If a node is a PythonNode wrapped in another PythonNode, it is a product from
91105
# another task that is a dependency in the current task. Thus, draw an edge

0 commit comments

Comments
 (0)