diff --git a/docs/source/changes.md b/docs/source/changes.md
index 0adcc30..de6b917 100644
--- a/docs/source/changes.md
+++ b/docs/source/changes.md
@@ -19,12 +19,13 @@ releases are available on [PyPI](https://pypi.org/project/pytask-parallel) and
   or processes automatically.
 - {pull}`96` handles local paths with remote executors. `PathNode`s are not supported
   as dependencies or products (except for return annotations).
-- {pull}`99` changes that all tasks that are ready are being scheduled. It improves
-  interactions with adaptive scaling. {issue}`98` does handle the resulting issues: no
-  strong adherence to priorities, no pending status.
+- {pull}`99` changes the scheduling so that all ready tasks are submitted. It
+  improves interactions with adaptive scaling. {issue}`98` tracks the resulting
+  issues: no strong adherence to priorities, no pending status.
 - {pull}`100` adds project management with rye.
 - {pull}`101` adds syncing for local paths as dependencies or products in remote
   environments with the same OS.
+- {pull}`102` implements a pending status for scheduled but not yet started tasks.
 - {pull}`106` fixes {pull}`99` such that only when there are coiled functions, all
   ready tasks are submitted.
 - {pull}`107` removes status from `pytask_execute_task_log_start` hook call.
diff --git a/docs/source/coiled.md b/docs/source/coiled.md
index 5463f7b..f53912e 100644
--- a/docs/source/coiled.md
+++ b/docs/source/coiled.md
@@ -1,7 +1,7 @@
 # coiled
 
 ```{caution}
-Currently, the coiled backend can only be used if your workflow code is organized in a
+Currently, the coiled backend can only be used if your workflow code is organized as a
 package due to how pytask imports your code and dask serializes task functions
 ([issue](https://github.com/dask/distributed/issues/8607)).
 ```
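The pending status added in {pull}`102` rests on a simple mechanism: a task that has been submitted to an executor stays *pending* until a worker actually picks it up, and only the worker knows when that happens, so workers report back through a dictionary that lives in shared memory. Below is a minimal, self-contained sketch of that idea, assuming nothing about pytask's internals; the `simulate_task` helper and the one-letter signatures are purely illustrative.

```python
from __future__ import annotations

import time
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import Manager


def simulate_task(signature: str, shared) -> str:  # shared is a Manager dict proxy.
    shared[signature] = True  # The worker flags itself once it really starts.
    time.sleep(0.2)  # Stand-in for the actual work.
    shared.pop(signature)
    return signature


if __name__ == "__main__":
    with Manager() as manager, ProcessPoolExecutor(max_workers=2) as pool:
        shared = manager.dict()  # Proxy object; picklable across processes.
        futures = {s: pool.submit(simulate_task, s, shared) for s in "abcd"}
        while not all(f.done() for f in futures.values()):
            # A submitted but unstarted task has a pending future and no
            # entry in the shared dictionary yet.
            running = set(shared.keys())
            pending = {s for s, f in futures.items() if not f.done()} - running
            print(f"running={sorted(running)} pending={sorted(pending)}")
            time.sleep(0.1)
```

The main process never blocks on the workers; it only polls the futures and the shared dictionary, which is what lets it tell the two states apart. The hunks in `execute.py` below implement this pattern inside the scheduling loop.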
+ if session.config["parallel_backend"] in ( + ParallelBackend.PROCESSES, + ParallelBackend.THREADS, + ParallelBackend.LOKY, + ): + manager_cls: Callable[[], SyncManager] | type[ExitStack] = Manager + start_execution_state = TaskExecutionStatus.PENDING + else: + manager_cls = ExitStack + start_execution_state = TaskExecutionStatus.RUNNING + + # Get the live execution manager from the registry if it exists. + live_execution = session.config["pm"].get_plugin("live_execution") any_coiled_task = any(is_coiled_function(task) for task in session.tasks) # The executor can only be created after the collection to give users the @@ -59,9 +82,13 @@ def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR091 session.config["_parallel_executor"] = registry.get_parallel_backend( session.config["parallel_backend"], n_workers=session.config["n_workers"] ) - - with session.config["_parallel_executor"]: - sleeper = _Sleeper() + with session.config["_parallel_executor"], manager_cls() as manager: + if session.config["parallel_backend"] in ( + ParallelBackend.PROCESSES, + ParallelBackend.THREADS, + ParallelBackend.LOKY, + ): + session.config["_shared_memory"] = manager.dict() # type: ignore[union-attr] i = 0 while session.scheduler.is_active(): @@ -88,17 +115,17 @@ def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR091 else [] ) - for task_name in ready_tasks: - task = session.dag.nodes[task_name]["task"] + for task_signature in ready_tasks: + task = session.dag.nodes[task_signature]["task"] session.hook.pytask_execute_task_log_start( - session=session, task=task + session=session, task=task, status=start_execution_state ) try: session.hook.pytask_execute_task_setup( session=session, task=task ) - running_tasks[task_name] = session.hook.pytask_execute_task( - session=session, task=task + running_tasks[task_signature] = ( + session.hook.pytask_execute_task(session=session, task=task) ) sleeper.reset() except Exception: # noqa: BLE001 @@ -106,13 +133,13 @@ def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR091 task, sys.exc_info() ) newly_collected_reports.append(report) - session.scheduler.done(task_name) + session.scheduler.done(task_signature) if not ready_tasks: sleeper.increment() - for task_name in list(running_tasks): - future = running_tasks[task_name] + for task_signature in list(running_tasks): + future = running_tasks[task_signature] if future.done(): wrapper_result = parse_future_result(future) @@ -128,17 +155,17 @@ def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR091 ) if wrapper_result.exc_info is not None: - task = session.dag.nodes[task_name]["task"] + task = session.dag.nodes[task_signature]["task"] newly_collected_reports.append( ExecutionReport.from_task_and_exception( task, wrapper_result.exc_info, # type: ignore[arg-type] ) ) - running_tasks.pop(task_name) - session.scheduler.done(task_name) + running_tasks.pop(task_signature) + session.scheduler.done(task_signature) else: - task = session.dag.nodes[task_name]["task"] + task = session.dag.nodes[task_signature]["task"] _update_carry_over_products( task, wrapper_result.carry_over_products ) @@ -154,9 +181,15 @@ def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR091 else: report = ExecutionReport.from_task(task) - running_tasks.pop(task_name) + running_tasks.pop(task_signature) newly_collected_reports.append(report) - session.scheduler.done(task_name) + session.scheduler.done(task_signature) + + elif live_execution and 
"_shared_memory" in session.config: + if task_signature in session.config["_shared_memory"]: + live_execution.update_task( + task_signature, status=TaskExecutionStatus.RUNNING + ) for report in newly_collected_reports: session.hook.pytask_execute_task_process_report( @@ -236,6 +269,7 @@ def pytask_execute_task(session: Session, task: PTask) -> Future[WrapperResult]: kwargs=kwargs, remote=remote, session_filterwarnings=session.config["filterwarnings"], + shared_memory=session.config.get("_shared_memory"), show_locals=session.config["show_locals"], task_filterwarnings=get_marks(task, "filterwarnings"), ) @@ -244,7 +278,11 @@ def pytask_execute_task(session: Session, task: PTask) -> Future[WrapperResult]: from pytask_parallel.wrappers import wrap_task_in_thread return session.config["_parallel_executor"].submit( - wrap_task_in_thread, task=task, remote=False, **kwargs + wrap_task_in_thread, + task=task, + remote=False, + shared_memory=session.config.get("_shared_memory"), + **kwargs, ) msg = f"Unknown worker type {worker_type}" raise ValueError(msg) diff --git a/src/pytask_parallel/wrappers.py b/src/pytask_parallel/wrappers.py index 6157873..ad8c2ed 100644 --- a/src/pytask_parallel/wrappers.py +++ b/src/pytask_parallel/wrappers.py @@ -55,7 +55,9 @@ class WrapperResult: stderr: str -def wrap_task_in_thread(task: PTask, *, remote: bool, **kwargs: Any) -> WrapperResult: +def wrap_task_in_thread( + task: PTask, *, remote: bool, shared_memory: dict[str, bool] | None, **kwargs: Any +) -> WrapperResult: """Mock execution function such that it returns the same as for processes. The function for processes returns ``warning_reports`` and an ``exception``. With @@ -64,6 +66,11 @@ def wrap_task_in_thread(task: PTask, *, remote: bool, **kwargs: Any) -> WrapperR """ __tracebackhide__ = True + + # Add task to shared memory to indicate that it is currently being executed. + if shared_memory is not None: + shared_memory[task.signature] = True + try: out = task.function(**kwargs) except Exception: # noqa: BLE001 @@ -71,6 +78,11 @@ def wrap_task_in_thread(task: PTask, *, remote: bool, **kwargs: Any) -> WrapperR else: _handle_function_products(task, out, remote=remote) exc_info = None # type: ignore[assignment] + + # Remove task from shared memory to indicate that it is no longer being executed. + if shared_memory is not None: + shared_memory.pop(task.signature) + return WrapperResult( carry_over_products=None, # type: ignore[arg-type] warning_reports=[], @@ -87,6 +99,7 @@ def wrap_task_in_process( # noqa: PLR0913 kwargs: dict[str, Any], remote: bool, session_filterwarnings: tuple[str, ...], + shared_memory: dict[str, bool] | None, show_locals: bool, task_filterwarnings: tuple[Mark, ...], ) -> WrapperResult: @@ -99,6 +112,10 @@ def wrap_task_in_process( # noqa: PLR0913 # Hide this function from tracebacks. __tracebackhide__ = True + # Add task to shared memory to indicate that it is currently being executed. + if shared_memory is not None: + shared_memory[task.signature] = True + # Patch set_trace and breakpoint to show a better error message. _patch_set_trace_and_breakpoint() @@ -156,6 +173,10 @@ def wrap_task_in_process( # noqa: PLR0913 captured_stdout_buffer.close() captured_stderr_buffer.close() + # Remove task from shared memory to indicate that it is no longer being executed. + if shared_memory is not None: + shared_memory.pop(task.signature) + return WrapperResult( carry_over_products=products, # type: ignore[arg-type] warning_reports=warning_reports,