Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow tasks to be pending. #609

Merged
merged 5 commits into from
May 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/changes.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and
- {pull}`603` fixes an example in the documentation about capturing warnings.
- {pull}`604` fixes some examples with `PythonNode`s in the documentation.
- {pull}`605` improves checks and CI.
- {pull}`609` allows a pending status for tasks. Useful for async backends implemented
in pytask-parallel.

## 0.4.7 - 2024-03-19

Expand Down
4 changes: 1 addition & 3 deletions src/_pytask/dag_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,7 @@ def from_dag(cls, dag: nx.DiGraph) -> TopologicalSorter:
priorities = _extract_priorities_from_tasks(tasks)

task_signatures = {task.signature for task in tasks}
task_dict = {
name: nx.ancestors(dag, name) & task_signatures for name in task_signatures
}
task_dict = {s: nx.ancestors(dag, s) & task_signatures for s in task_signatures}
task_dag = nx.DiGraph(task_dict).reverse()

return cls(dag=task_dag, priorities=priorities)
Expand Down
5 changes: 4 additions & 1 deletion src/_pytask/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from _pytask.exceptions import ExecutionError
from _pytask.exceptions import NodeLoadError
from _pytask.exceptions import NodeNotFoundError
from _pytask.logging_utils import TaskExecutionStatus
from _pytask.mark import Mark
from _pytask.mark_utils import has_mark
from _pytask.node_protocols import PNode
Expand Down Expand Up @@ -98,7 +99,9 @@ def pytask_execute_build(session: Session) -> bool | None:
@hookimpl
def pytask_execute_task_protocol(session: Session, task: PTask) -> ExecutionReport:
"""Follow the protocol to execute each task."""
session.hook.pytask_execute_task_log_start(session=session, task=task)
session.hook.pytask_execute_task_log_start(
session=session, task=task, status=TaskExecutionStatus.RUNNING
)
try:
session.hook.pytask_execute_task_setup(session=session, task=task)
session.hook.pytask_execute_task(session=session, task=task)
Expand Down
5 changes: 4 additions & 1 deletion src/_pytask/hookspecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import click
from pluggy import PluginManager

from _pytask.logging_utils import TaskExecutionStatus
from _pytask.models import NodeInfo
from _pytask.node_protocols import PNode
from _pytask.node_protocols import PProvisionalNode
Expand Down Expand Up @@ -255,7 +256,9 @@ def pytask_execute_task_protocol(session: Session, task: PTask) -> ExecutionRepo


@hookspec(firstresult=True)
def pytask_execute_task_log_start(session: Session, task: PTask) -> None:
def pytask_execute_task_log_start(
session: Session, task: PTask, status: TaskExecutionStatus
) -> None:
"""Start logging of task execution.

This hook can be used to provide more verbose output during the execution.
Expand Down
44 changes: 31 additions & 13 deletions src/_pytask/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING
from typing import Any
from typing import Generator
Expand All @@ -24,6 +25,7 @@
from _pytask.pluginmanager import hookimpl

if TYPE_CHECKING:
from _pytask.logging_utils import TaskExecutionStatus
from _pytask.node_protocols import PTask
from _pytask.reports import CollectionReport
from _pytask.reports import ExecutionReport
Expand Down Expand Up @@ -129,6 +131,12 @@ def is_started(self) -> bool:
return self._live.is_started


@dataclass
class _TaskEntry:
task: PTask
status: TaskExecutionStatus


class _ReportEntry(NamedTuple):
name: str
outcome: TaskOutcome
Expand All @@ -146,7 +154,7 @@ class LiveExecution:
sort_final_table: bool = False
n_tasks: int | str = "x"
_reports: list[_ReportEntry] = field(factory=list)
_running_tasks: dict[str, PTask] = field(factory=dict)
_running_tasks: dict[str, _TaskEntry] = field(factory=dict)

@hookimpl(wrapper=True)
def pytask_execute_build(self) -> Generator[None, None, None]:
Expand All @@ -162,15 +170,17 @@ def pytask_execute_build(self) -> Generator[None, None, None]:
return result

@hookimpl(tryfirst=True)
def pytask_execute_task_log_start(self, task: PTask) -> bool:
def pytask_execute_task_log_start(
self, task: PTask, status: TaskExecutionStatus
) -> bool:
"""Mark a new task as running."""
self.update_running_tasks(task)
self.add_task(new_running_task=task, status=status)
return True

@hookimpl
def pytask_execute_task_log_end(self, report: ExecutionReport) -> bool:
"""Mark a task as being finished and update outcome."""
self.update_reports(report)
self.update_report(report)
return True

def _generate_table(
Expand Down Expand Up @@ -232,16 +242,17 @@ def _generate_table(
format_task_name(report.task, editor_url_scheme=self.editor_url_scheme),
Text(report.outcome.symbol, style=report.outcome.style),
)
for task in self._running_tasks.values():
for task_entry in self._running_tasks.values():
table.add_row(
format_task_name(task, editor_url_scheme=self.editor_url_scheme),
"running",
format_task_name(
task_entry.task, editor_url_scheme=self.editor_url_scheme
),
task_entry.status.value,
)

# If the table is empty, do not display anything.
if table.rows == []:
table = None

return None
return table

def _update_table(
Expand All @@ -256,14 +267,21 @@ def _update_table(
)
self.live_manager.update(table)

def update_running_tasks(self, new_running_task: PTask) -> None:
def add_task(self, new_running_task: PTask, status: TaskExecutionStatus) -> None:
"""Add a new running task."""
self._running_tasks[new_running_task.name] = new_running_task
self._running_tasks[new_running_task.signature] = _TaskEntry(
task=new_running_task, status=status
)
self._update_table()

def update_task(self, signature: str, status: TaskExecutionStatus) -> None:
"""Update the status of a running task."""
self._running_tasks[signature].status = status
self._update_table()

def update_reports(self, new_report: ExecutionReport) -> None:
def update_report(self, new_report: ExecutionReport) -> None:
"""Update the status of a running task by adding its report."""
self._running_tasks.pop(new_report.task.name)
self._running_tasks.pop(new_report.task.signature)
self._reports.append(
_ReportEntry(
name=new_report.task.name,
Expand Down
6 changes: 6 additions & 0 deletions src/_pytask/logging_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from enum import Enum


class TaskExecutionStatus(Enum):
PENDING = "pending"
RUNNING = "running"
2 changes: 2 additions & 0 deletions src/pytask/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from _pytask.build import build
from _pytask.capture_utils import CaptureMethod
from _pytask.capture_utils import ShowCapture
from _pytask.logging_utils import TaskExecutionStatus


from _pytask.click import ColoredCommand
Expand Down Expand Up @@ -129,6 +130,7 @@
"SkippedUnchanged",
"State",
"Task",
"TaskExecutionStatus",
"TaskOutcome",
"TaskWithoutPath",
"Traceback",
Expand Down
25 changes: 13 additions & 12 deletions tests/test_live.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest
from _pytask.live import LiveExecution
from _pytask.live import LiveManager
from _pytask.logging_utils import TaskExecutionStatus
from pytask import ExecutionReport
from pytask import ExitCode
from pytask import Task
Expand Down Expand Up @@ -41,7 +42,7 @@ def test_live_execution_sequentially(capsys, tmp_path):
)

live_manager.start()
live.update_running_tasks(task)
live.add_task(task, status=TaskExecutionStatus.RUNNING)
live_manager.pause()

# Test pause removes the table.
Expand Down Expand Up @@ -69,7 +70,7 @@ def test_live_execution_sequentially(capsys, tmp_path):
report = ExecutionReport(task=task, outcome=TaskOutcome.SUCCESS, exc_info=None)

live_manager.resume()
live.update_reports(report)
live.update_report(report)
live_manager.stop()

# Test final table with reported outcome.
Expand Down Expand Up @@ -99,13 +100,13 @@ def test_live_execution_displays_skips_and_persists(capsys, tmp_path, verbose, o
)

live_manager.start()
live.update_running_tasks(task)
live.add_task(task, status=TaskExecutionStatus.RUNNING)
live_manager.pause()

report = ExecutionReport(task=task, outcome=outcome, exc_info=None)

live_manager.resume()
live.update_reports(report)
live.update_report(report)
live_manager.stop()

# Test final table with reported outcome.
Expand Down Expand Up @@ -149,7 +150,7 @@ def test_live_execution_displays_subset_of_table(capsys, tmp_path, n_entries_in_
)

live_manager.start()
live.update_running_tasks(running_task)
live.add_task(running_task, status=TaskExecutionStatus.RUNNING)
live_manager.stop(transient=False)

captured = capsys.readouterr()
Expand All @@ -161,13 +162,13 @@ def test_live_execution_displays_subset_of_table(capsys, tmp_path, n_entries_in_

completed_task = Task(base_name="task_completed", path=path, function=lambda x: x)
completed_task.name = "task_module.py::task_completed"
live.update_running_tasks(completed_task)
live.add_task(completed_task, status=TaskExecutionStatus.RUNNING)
report = ExecutionReport(
task=completed_task, outcome=TaskOutcome.SUCCESS, exc_info=None
)

live_manager.resume()
live.update_reports(report)
live.update_report(report)
live_manager.stop()

# Test that report is or is not included.
Expand Down Expand Up @@ -202,7 +203,7 @@ def test_live_execution_skips_do_not_crowd_out_displayed_tasks(capsys, tmp_path)
)

live_manager.start()
live.update_running_tasks(task)
live.add_task(task, status=TaskExecutionStatus.RUNNING)
live_manager.stop()

# Test table with running task.
Expand All @@ -224,9 +225,9 @@ def test_live_execution_skips_do_not_crowd_out_displayed_tasks(capsys, tmp_path)
tasks.append(skipped_task)

live_manager.start()
live.update_running_tasks(successful_task)
live.add_task(successful_task, status=TaskExecutionStatus.RUNNING)
for task in tasks:
live.update_running_tasks(task)
live.add_task(task, status=TaskExecutionStatus.RUNNING)
live_manager.stop()

captured = capsys.readouterr()
Expand All @@ -239,10 +240,10 @@ def test_live_execution_skips_do_not_crowd_out_displayed_tasks(capsys, tmp_path)
report = ExecutionReport(
task=successful_task, outcome=TaskOutcome.SUCCESS, exc_info=None
)
live.update_reports(report)
live.update_report(report)
for task in tasks:
report = ExecutionReport(task=task, outcome=TaskOutcome.SKIP, exc_info=None)
live.update_reports(report)
live.update_report(report)
live_manager.stop()

# Test final table with reported outcome.
Expand Down
Loading