1
- """Contains code related to resolving dependencies ."""
1
+ """Contains code related to the DAG ."""
2
2
3
3
from __future__ import annotations
4
4
34
34
from _pytask .session import Session
35
35
36
36
37
- __all__ = ["create_dag" ]
37
+ __all__ = ["create_dag" , "create_dag_from_session" ]
38
38
39
39
40
40
def create_dag (session : Session ) -> nx .DiGraph :
41
41
"""Create a directed acyclic graph (DAG) for the workflow."""
42
42
try :
43
- dag = _create_dag (tasks = session .tasks )
44
- _check_if_dag_has_cycles (dag )
45
- _check_if_tasks_have_the_same_products (dag , session .config ["paths" ])
46
- _modify_dag (session = session , dag = dag )
47
- select_tasks_by_marks_and_expressions (session = session , dag = dag )
48
-
43
+ dag = create_dag_from_session (session )
49
44
except Exception : # noqa: BLE001
50
45
report = DagReport .from_exception (sys .exc_info ())
51
46
_log_dag (report = report )
@@ -55,7 +50,17 @@ def create_dag(session: Session) -> nx.DiGraph:
55
50
return dag
56
51
57
52
58
- def _create_dag (tasks : list [PTask ]) -> nx .DiGraph :
53
+ def create_dag_from_session (session : Session ) -> nx .DiGraph :
54
+ """Create a DAG from a session."""
55
+ dag = _create_dag_from_tasks (tasks = session .tasks )
56
+ _check_if_dag_has_cycles (dag )
57
+ _check_if_tasks_have_the_same_products (dag , session .config ["paths" ])
58
+ dag = _modify_dag (session = session , dag = dag )
59
+ select_tasks_by_marks_and_expressions (session = session , dag = dag )
60
+ return dag
61
+
62
+
63
+ def _create_dag_from_tasks (tasks : list [PTask ]) -> nx .DiGraph :
59
64
"""Create the DAG from tasks, dependencies and products."""
60
65
61
66
def _add_dependency (
@@ -98,7 +103,7 @@ def _add_product(
98
103
return dag
99
104
100
105
101
- def _modify_dag (session : Session , dag : nx .DiGraph ) -> None :
106
+ def _modify_dag (session : Session , dag : nx .DiGraph ) -> nx . DiGraph :
102
107
"""Create dependencies between tasks when using ``@task(after=...)``."""
103
108
temporary_id_to_task = {
104
109
task .attributes ["collection_id" ]: task
@@ -119,6 +124,7 @@ def _modify_dag(session: Session, dag: nx.DiGraph) -> None:
119
124
for signature in signatures :
120
125
for successor in dag .successors (signature ):
121
126
dag .add_edge (successor , task .signature )
127
+ return dag
122
128
123
129
124
130
def _check_if_dag_has_cycles (dag : nx .DiGraph ) -> None :
0 commit comments