1212from typing import List
1313
1414import attr
15- from pybaum . tree_util import tree_map
15+ import cloudpickle
1616from pytask import console
1717from pytask import ExecutionReport
1818from pytask import get_marks
1919from pytask import hookimpl
2020from pytask import Mark
2121from pytask import parse_warning_filter
22+ from pytask import PTask
2223from pytask import remove_internal_traceback_frames_from_exc_info
2324from pytask import Session
2425from pytask import Task
2526from pytask import warning_record_to_str
2627from pytask import WarningReport
28+ from pytask .tree_util import PyTree
29+ from pytask .tree_util import tree_leaves
30+ from pytask .tree_util import tree_map
31+ from pytask .tree_util import tree_structure
2732from pytask_parallel .backends import PARALLEL_BACKENDS
28- from pytask_parallel .backends import ParallelBackendChoices
33+ from pytask_parallel .backends import ParallelBackend
2934from rich .console import ConsoleOptions
3035from rich .traceback import Traceback
3136
3237
@hookimpl
def pytask_post_parse(config: dict[str, Any]) -> None:
    """Register the namespace of hook implementations for the parallel backend.

    Parameters
    ----------
    config
        The pytask configuration. ``config["parallel_backend"]`` is compared
        against ``ParallelBackend.THREADS`` and ``config["pm"]`` is the plugin
        manager that receives the namespace.

    """
    # Threads get the default namespace; every other backend is handled by the
    # process-based namespace.
    if config["parallel_backend"] == ParallelBackend.THREADS:
        namespace = DefaultBackendNameSpace
    else:
        namespace = ProcessesNameSpace
    config["pm"].register(namespace)
@@ -99,12 +104,19 @@ def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR091
99104 for task_name in list (running_tasks ):
100105 future = running_tasks [task_name ]
101106 if future .done ():
102- warning_reports , task_exception = future .result ()
103- session .warnings .extend (warning_reports )
104- exc_info = (
105- _parse_future_exception (future .exception ())
106- or task_exception
107- )
107+ # An exception was thrown before the task was executed.
108+ if future .exception () is not None :
109+ exc_info = _parse_future_exception (future .exception ())
110+ warning_reports = []
111+ # A task raised an exception.
112+ else :
113+ warning_reports , task_exception = future .result ()
114+ session .warnings .extend (warning_reports )
115+ exc_info = (
116+ _parse_future_exception (future .exception ())
117+ or task_exception
118+ )
119+
108120 if exc_info is not None :
109121 task = session .dag .nodes [task_name ]["task" ]
110122 newly_collected_reports .append (
@@ -165,7 +177,7 @@ class ProcessesNameSpace:
165177
166178 @staticmethod
167179 @hookimpl (tryfirst = True )
168- def pytask_execute_task (session : Session , task : Task ) -> Future [Any ] | None :
180+ def pytask_execute_task (session : Session , task : PTask ) -> Future [Any ] | None :
169181 """Execute a task.
170182
171183 Take a task, pickle it and send the bytes over to another process.
@@ -174,27 +186,33 @@ def pytask_execute_task(session: Session, task: Task) -> Future[Any] | None:
174186 if session .config ["n_workers" ] > 1 :
175187 kwargs = _create_kwargs_for_task (task )
176188
189+ # Task modules are dynamically loaded and added to `sys.modules`. Thus,
190+ # cloudpickle believes the module of the task function is also importable in
191+ # the child process. We have to register the module as dynamic again, so
192+ # that cloudpickle will pickle it with the function. See cloudpickle#417,
193+ # pytask#373 and pytask#374.
194+ task_module = inspect .getmodule (task .function )
195+ cloudpickle .register_pickle_by_value (task_module )
196+
177197 return session .config ["_parallel_executor" ].submit (
178- _unserialize_and_execute_task ,
198+ _execute_task ,
179199 task = task ,
180200 kwargs = kwargs ,
181201 show_locals = session .config ["show_locals" ],
182202 console_options = console .options ,
183203 session_filterwarnings = session .config ["filterwarnings" ],
184204 task_filterwarnings = get_marks (task , "filterwarnings" ),
185- task_short_name = task .short_name ,
186205 )
187206 return None
188207
189208
190- def _unserialize_and_execute_task ( # noqa: PLR0913
191- task : Task ,
209+ def _execute_task ( # noqa: PLR0913
210+ task : PTask ,
192211 kwargs : dict [str , Any ],
193212 show_locals : bool ,
194213 console_options : ConsoleOptions ,
195214 session_filterwarnings : tuple [str , ...],
196215 task_filterwarnings : tuple [Mark , ...],
197- task_short_name : str ,
198216) -> tuple [list [WarningReport ], tuple [type [BaseException ], BaseException , str ] | None ]:
199217 """Unserialize and execute task.
200218
@@ -217,23 +235,41 @@ def _unserialize_and_execute_task( # noqa: PLR0913
217235 warnings .filterwarnings (* parse_warning_filter (arg , escape = False ))
218236
219237 try :
220- task .execute (** kwargs )
238+ out = task .execute (** kwargs )
221239 except Exception : # noqa: BLE001
222240 exc_info = sys .exc_info ()
223241 processed_exc_info = _process_exception (
224242 exc_info , show_locals , console_options
225243 )
226244 else :
245+ if "return" in task .produces :
246+ structure_out = tree_structure (out )
247+ structure_return = tree_structure (task .produces ["return" ])
248+ # strict must be false when none is leaf.
249+ if not structure_return .is_prefix (structure_out , strict = False ):
250+ msg = (
251+ "The structure of the return annotation is not a subtree of "
252+ "the structure of the function return.\n \n Function return: "
253+ f"{ structure_out } \n \n Return annotation: { structure_return } "
254+ )
255+ raise ValueError (msg )
256+
257+ nodes = tree_leaves (task .produces ["return" ])
258+ values = structure_return .flatten_up_to (out )
259+ for node , value in zip (nodes , values ):
260+ node .save (value ) # type: ignore[attr-defined]
261+
227262 processed_exc_info = None
228263
264+ task_display_name = getattr (task , "display_name" , task .name )
229265 warning_reports = []
230266 for warning_message in log :
231267 fs_location = warning_message .filename , warning_message .lineno
232268 warning_reports .append (
233269 WarningReport (
234270 message = warning_record_to_str (warning_message ),
235271 fs_location = fs_location ,
236- id_ = task_short_name ,
272+ id_ = task_display_name ,
237273 )
238274 )
239275
@@ -293,15 +329,17 @@ def _mock_processes_for_threads(
293329 return [], exc_info
294330
295331
296- def _create_kwargs_for_task (task : Task ) -> dict [Any , Any ]:
332+ def _create_kwargs_for_task (task : PTask ) -> dict [str , PyTree [ Any ] ]:
297333 """Create kwargs for task function."""
298- kwargs = {** task .kwargs }
334+ parameters = inspect .signature (task .function ).parameters
335+
336+ kwargs = {}
337+ for name , value in task .depends_on .items ():
338+ kwargs [name ] = tree_map (lambda x : x .load (), value )
299339
300- func_arg_names = set (inspect .signature (task .function ).parameters )
301- for arg_name in ("depends_on" , "produces" ):
302- if arg_name in func_arg_names :
303- attribute = getattr (task , arg_name )
304- kwargs [arg_name ] = tree_map (lambda x : x .value , attribute )
340+ for name , value in task .produces .items ():
341+ if name in parameters :
342+ kwargs [name ] = tree_map (lambda x : x .load (), value )
305343
306344 return kwargs
307345
0 commit comments