|
19 | 19 | from _pytask.console import render_to_string
|
20 | 20 | from _pytask.exceptions import ResolvingDependenciesError
|
21 | 21 | from _pytask.mark import select_by_after_keyword
|
| 22 | +from _pytask.mark import select_tasks_by_marks_and_expressions |
22 | 23 | from _pytask.node_protocols import PNode
|
23 | 24 | from _pytask.node_protocols import PTask
|
24 | 25 | from _pytask.nodes import PythonNode
|
25 |
| -from _pytask.pluginmanager import hookimpl |
26 | 26 | from _pytask.reports import DagReport
|
27 | 27 | from _pytask.shared import reduce_names_of_multiple_nodes
|
28 | 28 | from _pytask.tree_util import tree_map
|
|
33 | 33 | from _pytask.session import Session
|
34 | 34 |
|
35 | 35 |
|
36 |
| -@hookimpl |
37 |
| -def pytask_dag(session: Session) -> bool | None: |
| 36 | +__all__ = ["create_dag"] |
| 37 | + |
| 38 | + |
| 39 | +def create_dag(session: Session) -> nx.DiGraph: |
38 | 40 | """Create a directed acyclic graph (DAG) for the workflow."""
|
39 | 41 | try:
|
40 |
| - session.dag = session.hook.pytask_dag_create_dag( |
41 |
| - session=session, tasks=session.tasks |
42 |
| - ) |
43 |
| - session.hook.pytask_dag_modify_dag(session=session, dag=session.dag) |
| 42 | + dag = _create_dag(tasks=session.tasks) |
| 43 | + _check_if_dag_has_cycles(dag) |
| 44 | + _check_if_tasks_have_the_same_products(dag, session.config["paths"]) |
| 45 | + _modify_dag(session=session, dag=dag) |
| 46 | + select_tasks_by_marks_and_expressions(session=session, dag=dag) |
44 | 47 |
|
45 | 48 | except Exception: # noqa: BLE001
|
46 | 49 | report = DagReport.from_exception(sys.exc_info())
|
47 |
| - session.hook.pytask_dag_log(session=session, report=report) |
| 50 | + _log_dag(report=report) |
48 | 51 | session.dag_report = report
|
49 | 52 |
|
50 | 53 | raise ResolvingDependenciesError from None
|
51 |
| - |
52 |
| - else: |
53 |
| - return True |
| 54 | + return dag |
54 | 55 |
|
55 | 56 |
|
56 |
| -@hookimpl |
57 |
| -def pytask_dag_create_dag(session: Session, tasks: list[PTask]) -> nx.DiGraph: |
| 57 | +def _create_dag(tasks: list[PTask]) -> nx.DiGraph: |
58 | 58 | """Create the DAG from tasks, dependencies and products."""
|
59 | 59 |
|
60 | 60 | def _add_dependency(dag: nx.DiGraph, task: PTask, node: PNode) -> None:
|
@@ -90,15 +90,10 @@ def _add_product(dag: nx.DiGraph, task: PTask, node: PNode) -> None:
|
90 | 90 | else None,
|
91 | 91 | task.depends_on,
|
92 | 92 | )
|
93 |
| - |
94 |
| - _check_if_dag_has_cycles(dag) |
95 |
| - _check_if_tasks_have_the_same_products(dag, session.config["paths"]) |
96 |
| - |
97 | 93 | return dag
|
98 | 94 |
|
99 | 95 |
|
100 |
| -@hookimpl |
101 |
| -def pytask_dag_modify_dag(session: Session, dag: nx.DiGraph) -> None: |
| 96 | +def _modify_dag(session: Session, dag: nx.DiGraph) -> None: |
102 | 97 | """Create dependencies between tasks when using ``@task(after=...)``."""
|
103 | 98 | temporary_id_to_task = {
|
104 | 99 | task.attributes["collection_id"]: task
|
@@ -194,8 +189,7 @@ def _check_if_tasks_have_the_same_products(dag: nx.DiGraph, paths: list[Path]) -
|
194 | 189 | raise ResolvingDependenciesError(msg)
|
195 | 190 |
|
196 | 191 |
|
197 |
| -@hookimpl |
198 |
| -def pytask_dag_log(report: DagReport) -> None: |
| 192 | +def _log_dag(report: DagReport) -> None: |
199 | 193 | """Log errors which happened while resolving dependencies."""
|
200 | 194 | console.print()
|
201 | 195 | console.rule(
|
|
0 commit comments