Implement a pending status for tasks. #104

Open · wants to merge 8 commits into base: main
Changes from all commits
7 changes: 4 additions & 3 deletions docs/source/changes.md
@@ -19,12 +19,13 @@ releases are available on [PyPI](https://pypi.org/project/pytask-parallel) and
or processes automatically.
- {pull}`96` handles local paths with remote executors. `PathNode`s are not supported as
dependencies or products (except for return annotations).
-- {pull}`99` changes that all tasks that are ready are being scheduled. It improves
-interactions with adaptive scaling. {issue}`98` does handle the resulting issues: no
-strong adherence to priorities, no pending status.
+- {pull}`99` changes that all ready tasks are being scheduled. It improves interactions
+with adaptive scaling. {issue}`98` does handle the resulting issues: no strong
+adherence to priorities, no pending status.
- {pull}`100` adds project management with rye.
- {pull}`101` adds syncing for local paths as dependencies or products in remote
environments with the same OS.
- {pull}`102` implements a pending status for scheduled but not started tasks.
- {pull}`106` fixes {pull}`99` such that only when there are coiled functions, all ready
tasks are submitted.
- {pull}`107` removes status from `pytask_execute_task_log_start` hook call.
2 changes: 1 addition & 1 deletion docs/source/coiled.md
@@ -1,7 +1,7 @@
# coiled

```{caution}
-Currently, the coiled backend can only be used if your workflow code is organized in a
+Currently, the coiled backend can only be used if your workflow code is organized as a
package due to how pytask imports your code and dask serializes task functions
([issue](https://github.com/dask/distributed/issues/8607)).
```
74 changes: 56 additions & 18 deletions src/pytask_parallel/execute.py
@@ -4,8 +4,11 @@

import sys
import time
from contextlib import ExitStack
from multiprocessing import Manager
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable

import cloudpickle
from _pytask.node_protocols import PPathNode
@@ -16,13 +19,15 @@
from pytask import PTask
from pytask import PythonNode
from pytask import Session
from pytask import TaskExecutionStatus
from pytask import console
from pytask import get_marks
from pytask import hookimpl
from pytask.tree_util import PyTree
from pytask.tree_util import tree_map
from pytask.tree_util import tree_structure

from pytask_parallel.backends import ParallelBackend
from pytask_parallel.backends import WorkerType
from pytask_parallel.backends import registry
from pytask_parallel.typing import CarryOverPath
@@ -33,6 +38,7 @@

if TYPE_CHECKING:
from concurrent.futures import Future
from multiprocessing.managers import SyncManager

from pytask_parallel.wrappers import WrapperResult

@@ -52,16 +58,37 @@ def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR091
__tracebackhide__ = True
reports = session.execution_reports
running_tasks: dict[str, Future[Any]] = {}
sleeper = _Sleeper()

# Create a shared memory object to differentiate between running and pending
# tasks for some parallel backends.
if session.config["parallel_backend"] in (
ParallelBackend.PROCESSES,
ParallelBackend.THREADS,
ParallelBackend.LOKY,
):
manager_cls: Callable[[], SyncManager] | type[ExitStack] = Manager
start_execution_state = TaskExecutionStatus.PENDING
else:
manager_cls = ExitStack
start_execution_state = TaskExecutionStatus.RUNNING

# Get the live execution manager from the registry if it exists.
live_execution = session.config["pm"].get_plugin("live_execution")
any_coiled_task = any(is_coiled_function(task) for task in session.tasks)

# The executor can only be created after the collection to give users the
# possibility to inject their own executors.
session.config["_parallel_executor"] = registry.get_parallel_backend(
session.config["parallel_backend"], n_workers=session.config["n_workers"]
)

-with session.config["_parallel_executor"]:
-sleeper = _Sleeper()
+with session.config["_parallel_executor"], manager_cls() as manager:
if session.config["parallel_backend"] in (
ParallelBackend.PROCESSES,
ParallelBackend.THREADS,
ParallelBackend.LOKY,
):
session.config["_shared_memory"] = manager.dict() # type: ignore[union-attr]

i = 0
while session.scheduler.is_active():
@@ -88,31 +115,31 @@ def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR091
else []
)

-for task_name in ready_tasks:
-task = session.dag.nodes[task_name]["task"]
+for task_signature in ready_tasks:
+task = session.dag.nodes[task_signature]["task"]
session.hook.pytask_execute_task_log_start(
-session=session, task=task
+session=session, task=task, status=start_execution_state
)
try:
session.hook.pytask_execute_task_setup(
session=session, task=task
)
-running_tasks[task_name] = session.hook.pytask_execute_task(
-session=session, task=task
+running_tasks[task_signature] = (
+session.hook.pytask_execute_task(session=session, task=task)
)
sleeper.reset()
except Exception: # noqa: BLE001
report = ExecutionReport.from_task_and_exception(
task, sys.exc_info()
)
newly_collected_reports.append(report)
-session.scheduler.done(task_name)
+session.scheduler.done(task_signature)

if not ready_tasks:
sleeper.increment()

-for task_name in list(running_tasks):
-future = running_tasks[task_name]
+for task_signature in list(running_tasks):
+future = running_tasks[task_signature]

if future.done():
wrapper_result = parse_future_result(future)
@@ -128,17 +155,17 @@ def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR091
)

if wrapper_result.exc_info is not None:
-task = session.dag.nodes[task_name]["task"]
+task = session.dag.nodes[task_signature]["task"]
newly_collected_reports.append(
ExecutionReport.from_task_and_exception(
task,
wrapper_result.exc_info, # type: ignore[arg-type]
)
)
-running_tasks.pop(task_name)
-session.scheduler.done(task_name)
+running_tasks.pop(task_signature)
+session.scheduler.done(task_signature)
else:
-task = session.dag.nodes[task_name]["task"]
+task = session.dag.nodes[task_signature]["task"]
_update_carry_over_products(
task, wrapper_result.carry_over_products
)
@@ -154,9 +181,15 @@ def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR091
else:
report = ExecutionReport.from_task(task)

-running_tasks.pop(task_name)
+running_tasks.pop(task_signature)
newly_collected_reports.append(report)
-session.scheduler.done(task_name)
+session.scheduler.done(task_signature)

elif live_execution and "_shared_memory" in session.config:
if task_signature in session.config["_shared_memory"]:
live_execution.update_task(
task_signature, status=TaskExecutionStatus.RUNNING
)

for report in newly_collected_reports:
session.hook.pytask_execute_task_process_report(
@@ -236,6 +269,7 @@ def pytask_execute_task(session: Session, task: PTask) -> Future[WrapperResult]:
kwargs=kwargs,
remote=remote,
session_filterwarnings=session.config["filterwarnings"],
shared_memory=session.config.get("_shared_memory"),
show_locals=session.config["show_locals"],
task_filterwarnings=get_marks(task, "filterwarnings"),
)
@@ -244,7 +278,11 @@ def pytask_execute_task(session: Session, task: PTask) -> Future[WrapperResult]:
from pytask_parallel.wrappers import wrap_task_in_thread

return session.config["_parallel_executor"].submit(
-wrap_task_in_thread, task=task, remote=False, **kwargs
+wrap_task_in_thread,
+task=task,
+remote=False,
+shared_memory=session.config.get("_shared_memory"),
+**kwargs,
)
msg = f"Unknown worker type {worker_type}"
raise ValueError(msg)
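
To make the execution loop above easier to follow, here is a minimal, self-contained sketch of the same pending/running handoff outside of pytask, using only the standard library. The names `run_task`, `shared`, and `started` are illustrative only; the real plugin stores the manager dict under `session.config["_shared_memory"]` and updates the live display instead of printing.

```python
import time
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import Manager


def run_task(signature: str, shared: dict) -> str:
    """Stand-in for a task wrapper: mark as running, work, then unmark."""
    shared[signature] = True  # The worker picked the task up -> running.
    time.sleep(0.2)  # Placeholder for the real task function.
    shared.pop(signature)  # Finished -> no longer running.
    return signature


if __name__ == "__main__":
    with Manager() as manager, ProcessPoolExecutor(max_workers=2) as executor:
        shared = manager.dict()
        futures = {s: executor.submit(run_task, s, shared) for s in ("a", "b", "c")}
        pending = set(futures)
        started: set[str] = set()
        while pending:
            for signature in list(pending):
                if futures[signature].done():
                    print(f"{signature}: finished")
                    pending.discard(signature)
                elif signature not in started and signature in shared:
                    # A scheduled task was picked up by a worker: pending -> running.
                    print(f"{signature}: running")
                    started.add(signature)
            time.sleep(0.05)
```

The `ExitStack` fallback in the diff acts as a no-op stand-in for `Manager`: for the remaining backends, `manager_cls()` creates nothing, no shared dict is registered, and tasks are reported as running as soon as they are submitted.
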
23 changes: 22 additions & 1 deletion src/pytask_parallel/wrappers.py
@@ -55,7 +59,9 @@ class WrapperResult:
stderr: str


-def wrap_task_in_thread(task: PTask, *, remote: bool, **kwargs: Any) -> WrapperResult:
+def wrap_task_in_thread(
+task: PTask, *, remote: bool, shared_memory: dict[str, bool] | None, **kwargs: Any
+) -> WrapperResult:
"""Mock execution function such that it returns the same as for processes.

The function for processes returns ``warning_reports`` and an ``exception``. With
@@ -64,13 +66,23 @@ def wrap_task_in_thread(task: PTask, *, remote: bool, **kwargs: Any) -> WrapperR

"""
__tracebackhide__ = True

# Add task to shared memory to indicate that it is currently being executed.
if shared_memory is not None:
shared_memory[task.signature] = True

try:
out = task.function(**kwargs)
except Exception: # noqa: BLE001
exc_info = sys.exc_info()
else:
_handle_function_products(task, out, remote=remote)
exc_info = None # type: ignore[assignment]

# Remove task from shared memory to indicate that it is no longer being executed.
if shared_memory is not None:
shared_memory.pop(task.signature)

return WrapperResult(
carry_over_products=None, # type: ignore[arg-type]
warning_reports=[],
@@ -87,6 +99,7 @@ def wrap_task_in_process( # noqa: PLR0913
kwargs: dict[str, Any],
remote: bool,
session_filterwarnings: tuple[str, ...],
shared_memory: dict[str, bool] | None,
show_locals: bool,
task_filterwarnings: tuple[Mark, ...],
) -> WrapperResult:
@@ -99,6 +112,10 @@ def wrap_task_in_process( # noqa: PLR0913
# Hide this function from tracebacks.
__tracebackhide__ = True

# Add task to shared memory to indicate that it is currently being executed.
if shared_memory is not None:
shared_memory[task.signature] = True

# Patch set_trace and breakpoint to show a better error message.
_patch_set_trace_and_breakpoint()

@@ -156,6 +173,10 @@ def wrap_task_in_process( # noqa: PLR0913
captured_stdout_buffer.close()
captured_stderr_buffer.close()

# Remove task from shared memory to indicate that it is no longer being executed.
if shared_memory is not None:
shared_memory.pop(task.signature)

return WrapperResult(
carry_over_products=products, # type: ignore[arg-type]
warning_reports=warning_reports,
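
Condensed to its essentials, the marking pattern that both wrappers above follow looks roughly like the helper below. `execute_with_marker` and `func` are hypothetical names used only for this sketch; the real wrappers additionally capture stdout/stderr, warnings, and products and return a `WrapperResult`.

```python
from __future__ import annotations

import sys
from typing import Any, Callable


def execute_with_marker(
    signature: str,
    func: Callable[[], Any],
    shared_memory: dict[str, bool] | None,
) -> tuple[Any, Any]:
    """Run ``func`` while advertising the task as running in shared memory."""
    # Mark the task as running; the coordinating loop reads this to flip the
    # live status from pending to running.
    if shared_memory is not None:
        shared_memory[signature] = True

    try:
        result, exc_info = func(), None
    except Exception:  # noqa: BLE001 - mirror the wrappers: report, don't raise
        result, exc_info = None, sys.exc_info()

    # Because exceptions are captured instead of propagated, this cleanup is
    # always reached and the marker never leaks.
    if shared_memory is not None:
        shared_memory.pop(signature)

    return result, exc_info
```

The same proxy dict created by `Manager()` is handed to thread and process workers alike, so this marking code stays identical for both worker types.
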