Skip to content

Commit 43486d7

Browse files
committed
Fix.
1 parent eacecf4 commit 43486d7

File tree

3 files changed

+28
-24
lines changed

3 files changed

+28
-24
lines changed

src/_pytask/dag.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Contains code related to resolving dependencies."""
1+
"""Contains code related to the DAG."""
22

33
from __future__ import annotations
44

@@ -34,18 +34,13 @@
3434
from _pytask.session import Session
3535

3636

37-
__all__ = ["create_dag"]
37+
__all__ = ["create_dag", "create_dag_from_session"]
3838

3939

4040
def create_dag(session: Session) -> nx.DiGraph:
4141
"""Create a directed acyclic graph (DAG) for the workflow."""
4242
try:
43-
dag = _create_dag(tasks=session.tasks)
44-
_check_if_dag_has_cycles(dag)
45-
_check_if_tasks_have_the_same_products(dag, session.config["paths"])
46-
_modify_dag(session=session, dag=dag)
47-
select_tasks_by_marks_and_expressions(session=session, dag=dag)
48-
43+
dag = create_dag_from_session(session)
4944
except Exception: # noqa: BLE001
5045
report = DagReport.from_exception(sys.exc_info())
5146
_log_dag(report=report)
@@ -55,7 +50,17 @@ def create_dag(session: Session) -> nx.DiGraph:
5550
return dag
5651

5752

58-
def _create_dag(tasks: list[PTask]) -> nx.DiGraph:
53+
def create_dag_from_session(session: Session) -> nx.DiGraph:
54+
"""Create a DAG from a session."""
55+
dag = _create_dag_from_tasks(tasks=session.tasks)
56+
_check_if_dag_has_cycles(dag)
57+
_check_if_tasks_have_the_same_products(dag, session.config["paths"])
58+
dag = _modify_dag(session=session, dag=dag)
59+
select_tasks_by_marks_and_expressions(session=session, dag=dag)
60+
return dag
61+
62+
63+
def _create_dag_from_tasks(tasks: list[PTask]) -> nx.DiGraph:
5964
"""Create the DAG from tasks, dependencies and products."""
6065

6166
def _add_dependency(
@@ -98,7 +103,7 @@ def _add_product(
98103
return dag
99104

100105

101-
def _modify_dag(session: Session, dag: nx.DiGraph) -> None:
106+
def _modify_dag(session: Session, dag: nx.DiGraph) -> nx.DiGraph:
102107
"""Create dependencies between tasks when using ``@task(after=...)``."""
103108
temporary_id_to_task = {
104109
task.attributes["collection_id"]: task
@@ -119,6 +124,7 @@ def _modify_dag(session: Session, dag: nx.DiGraph) -> None:
119124
for signature in signatures:
120125
for successor in dag.successors(signature):
121126
dag.add_edge(successor, task.signature)
127+
return dag
122128

123129

124130
def _check_if_dag_has_cycles(dag: nx.DiGraph) -> None:

src/_pytask/provisional_utils.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
"""Contains utilities related to provisional nodes and task generators."""
2+
13
from __future__ import annotations
24

35
import sys
@@ -6,12 +8,8 @@
68
from typing import Any
79

810
from _pytask.collect_utils import collect_dependency
9-
from _pytask.dag import _check_if_dag_has_cycles
10-
from _pytask.dag import _check_if_tasks_have_the_same_products
11-
from _pytask.dag import _create_dag
12-
from _pytask.dag import _modify_dag
11+
from _pytask.dag import create_dag_from_session
1312
from _pytask.dag_utils import TopologicalSorter
14-
from _pytask.mark import select_tasks_by_marks_and_expressions
1513
from _pytask.models import NodeInfo
1614
from _pytask.node_protocols import PNode
1715
from _pytask.node_protocols import PProvisionalNode
@@ -72,14 +70,14 @@ def collect_provisional_nodes(
7270

7371

7472
def recreate_dag(session: Session, task: PTask) -> None:
75-
"""Recreate the DAG."""
73+
"""Recreate the DAG when provisional nodes are resolved.
74+
75+
If the DAG resolution fails, the error is attached as an execution report since
76+
there is not better mechanic yet to display the error.
77+
78+
"""
7679
try:
77-
dag = _create_dag(tasks=session.tasks)
78-
_check_if_dag_has_cycles(dag)
79-
_check_if_tasks_have_the_same_products(dag, session.config["paths"])
80-
_modify_dag(session=session, dag=dag)
81-
select_tasks_by_marks_and_expressions(session=session, dag=dag)
82-
session.dag = dag
80+
session.dag = create_dag_from_session(session)
8381
session.scheduler = TopologicalSorter.from_dag_and_sorter(
8482
session.dag, session.scheduler
8583
)

tests/test_dag.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from pathlib import Path
66

77
import pytest
8-
from _pytask.dag import _create_dag
8+
from _pytask.dag import _create_dag_from_tasks
99
from pytask import ExitCode
1010
from pytask import PathNode
1111
from pytask import Task
@@ -26,7 +26,7 @@ def test_create_dag():
2626
1: PathNode.from_path(root / "node_2"),
2727
},
2828
)
29-
dag = _create_dag(tasks=[task])
29+
dag = _create_dag_from_tasks(tasks=[task])
3030

3131
for signature in (
3232
"90bb899a1b60da28ff70352cfb9f34a8bed485597c7f40eed9bd4c6449147525",

0 commit comments

Comments
 (0)