from typing import List

import attr
-from pybaum.tree_util import tree_map
+import cloudpickle
from pytask import console
from pytask import ExecutionReport
from pytask import get_marks
from pytask import hookimpl
from pytask import Mark
from pytask import parse_warning_filter
+from pytask import PTask
from pytask import remove_internal_traceback_frames_from_exc_info
from pytask import Session
from pytask import Task
from pytask import warning_record_to_str
from pytask import WarningReport
+from pytask.tree_util import PyTree
+from pytask.tree_util import tree_leaves
+from pytask.tree_util import tree_map
+from pytask.tree_util import tree_structure
from pytask_parallel.backends import PARALLEL_BACKENDS
-from pytask_parallel.backends import ParallelBackendChoices
+from pytask_parallel.backends import ParallelBackend
from rich.console import ConsoleOptions
from rich.traceback import Traceback


@hookimpl
def pytask_post_parse(config: dict[str, Any]) -> None:
    """Register the parallel backend."""
-    if config["parallel_backend"] == ParallelBackendChoices.THREADS:
+    if config["parallel_backend"] == ParallelBackend.THREADS:
        config["pm"].register(DefaultBackendNameSpace)
    else:
        config["pm"].register(ProcessesNameSpace)
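The hook above picks the execution namespace from the configured backend: the thread
backend can call task functions directly in the main process, while every other backend
spawns worker processes that only accept pickled work. A minimal sketch of that
distinction with the standard library, using a made-up Backend enum rather than
pytask-parallel's actual ParallelBackend:

    from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
    from enum import Enum


    class Backend(Enum):  # hypothetical stand-in for ParallelBackend
        THREADS = "threads"
        PROCESSES = "processes"


    # Threads share the interpreter, so tasks and their arguments never leave the
    # process; any process-based backend needs them to be picklable.
    EXECUTORS = {
        Backend.THREADS: ThreadPoolExecutor,
        Backend.PROCESSES: ProcessPoolExecutor,
    }
    executor_cls = EXECUTORS[Backend.PROCESSES]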
@@ -99,12 +104,19 @@ def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR091
                    for task_name in list(running_tasks):
                        future = running_tasks[task_name]
                        if future.done():
-                            warning_reports, task_exception = future.result()
-                            session.warnings.extend(warning_reports)
-                            exc_info = (
-                                _parse_future_exception(future.exception())
-                                or task_exception
-                            )
+                            # An exception was thrown before the task was executed.
+                            if future.exception() is not None:
+                                exc_info = _parse_future_exception(future.exception())
+                                warning_reports = []
+                            # A task raised an exception.
+                            else:
+                                warning_reports, task_exception = future.result()
+                                session.warnings.extend(warning_reports)
+                                exc_info = (
+                                    _parse_future_exception(future.exception())
+                                    or task_exception
+                                )
+
                            if exc_info is not None:
                                task = session.dag.nodes[task_name]["task"]
                                newly_collected_reports.append(
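The rewritten polling branch asks future.exception() before touching future.result():
if something failed before the task body ran (for example the task could not be
unpickled in the worker), result() would re-raise instead of returning the
(warning_reports, task_exception) tuple. A sketch of the same pattern with plain
concurrent.futures; the helper name is invented:

    from __future__ import annotations

    from concurrent.futures import Future


    def collect_outcome(future: Future) -> tuple[list, object | None]:
        """Return (warning_reports, exception) for a finished future."""
        # Errors raised outside the task body only surface via exception(), and
        # calling result() on such a future would re-raise them.
        if future.exception() is not None:
            return [], future.exception()
        # Otherwise the worker returned the (warning_reports, task_exception) pair.
        warning_reports, task_exception = future.result()
        return warning_reports, task_exception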
@@ -165,7 +177,7 @@ class ProcessesNameSpace:

    @staticmethod
    @hookimpl(tryfirst=True)
-    def pytask_execute_task(session: Session, task: Task) -> Future[Any] | None:
+    def pytask_execute_task(session: Session, task: PTask) -> Future[Any] | None:
        """Execute a task.

        Take a task, pickle it and send the bytes over to another process.
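As the docstring says, the task is pickled and the bytes are sent to another process. A
rough sketch of that round trip, assuming cloudpickle and a plain ProcessPoolExecutor;
the helper names are invented, and the real executor wrapper lives in
pytask_parallel.backends:

    from concurrent.futures import Future, ProcessPoolExecutor

    import cloudpickle


    def _run_pickled(payload: bytes) -> object:
        """Runs in the worker: rebuild the callable and its kwargs, then execute."""
        fn, kwargs = cloudpickle.loads(payload)
        return fn(**kwargs)


    def submit_pickled(executor: ProcessPoolExecutor, fn, kwargs: dict) -> Future:
        """Serialize the callable by value so the worker need not import it."""
        payload = cloudpickle.dumps((fn, kwargs))
        return executor.submit(_run_pickled, payload)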
@@ -174,27 +186,33 @@ def pytask_execute_task(session: Session, task: Task) -> Future[Any] | None:
        if session.config["n_workers"] > 1:
            kwargs = _create_kwargs_for_task(task)

+            # Task modules are dynamically loaded and added to `sys.modules`. Thus,
+            # cloudpickle believes the module of the task function is also importable in
+            # the child process. We have to register the module as dynamic again, so
+            # that cloudpickle will pickle it with the function. See cloudpickle#417,
+            # pytask#373 and pytask#374.
+            task_module = inspect.getmodule(task.function)
+            cloudpickle.register_pickle_by_value(task_module)
+
            return session.config["_parallel_executor"].submit(
-                _unserialize_and_execute_task,
+                _execute_task,
                task=task,
                kwargs=kwargs,
                show_locals=session.config["show_locals"],
                console_options=console.options,
                session_filterwarnings=session.config["filterwarnings"],
                task_filterwarnings=get_marks(task, "filterwarnings"),
-                task_short_name=task.short_name,
            )
        return None


-def _unserialize_and_execute_task(  # noqa: PLR0913
-    task: Task,
+def _execute_task(  # noqa: PLR0913
+    task: PTask,
    kwargs: dict[str, Any],
    show_locals: bool,
    console_options: ConsoleOptions,
    session_filterwarnings: tuple[str, ...],
    task_filterwarnings: tuple[Mark, ...],
-    task_short_name: str,
) -> tuple[list[WarningReport], tuple[type[BaseException], BaseException, str] | None]:
    """Unserialize and execute task.

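The comment added in this hunk explains the cloudpickle registration: task modules are
loaded dynamically, so they exist in sys.modules of the parent process but cannot be
imported by a child. register_pickle_by_value makes cloudpickle ship the module along
with the function. A small self-contained sketch with an invented module name:

    import sys
    import types

    import cloudpickle

    # A module created at runtime, standing in for a task module loaded from a path.
    mod = types.ModuleType("my_dynamic_task_module")
    exec("def task_example():\n    return 1", mod.__dict__)
    sys.modules[mod.__name__] = mod  # dynamically loaded task modules end up here too

    # Without this call cloudpickle pickles task_example by reference, and a child
    # process fails to unpickle it because it cannot import the module.
    cloudpickle.register_pickle_by_value(mod)

    payload = cloudpickle.dumps(mod.task_example)  # function travels with its module
    assert cloudpickle.loads(payload)() == 1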
@@ -217,23 +235,41 @@ def _unserialize_and_execute_task( # noqa: PLR0913
                warnings.filterwarnings(*parse_warning_filter(arg, escape=False))

        try:
-            task.execute(**kwargs)
+            out = task.execute(**kwargs)
        except Exception:  # noqa: BLE001
            exc_info = sys.exc_info()
            processed_exc_info = _process_exception(
                exc_info, show_locals, console_options
            )
        else:
+            if "return" in task.produces:
+                structure_out = tree_structure(out)
+                structure_return = tree_structure(task.produces["return"])
+                # strict must be false when none is leaf.
+                if not structure_return.is_prefix(structure_out, strict=False):
+                    msg = (
+                        "The structure of the return annotation is not a subtree of "
+                        "the structure of the function return.\n\nFunction return: "
+                        f"{structure_out}\n\nReturn annotation: {structure_return}"
+                    )
+                    raise ValueError(msg)
+
+                nodes = tree_leaves(task.produces["return"])
+                values = structure_return.flatten_up_to(out)
+                for node, value in zip(nodes, values):
+                    node.save(value)  # type: ignore[attr-defined]
+
            processed_exc_info = None

+        task_display_name = getattr(task, "display_name", task.name)
        warning_reports = []
        for warning_message in log:
            fs_location = warning_message.filename, warning_message.lineno
            warning_reports.append(
                WarningReport(
                    message=warning_record_to_str(warning_message),
                    fs_location=fs_location,
-                    id_=task_short_name,
+                    id_=task_display_name,
                )
            )

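The new else branch maps a possibly nested function return onto the nodes stored under
the "return" key of task.produces: the annotation's tree structure must be a prefix of
the return's structure, and flatten_up_to unpacks the return exactly as deep as the
annotation before each value is saved into its node. A toy illustration with made-up
values, reusing the same pytask.tree_util calls:

    from pytask.tree_util import tree_structure

    out = {"a": 1, "b": (2, 3)}                   # what the task function returned
    annotation = {"a": "n1", "b": ("n2", "n3")}   # stand-in for task.produces["return"]

    structure_out = tree_structure(out)
    structure_return = tree_structure(annotation)

    # Every annotated node must receive a value, so the annotation's structure has to
    # be a (non-strict) prefix of the return's structure.
    assert structure_return.is_prefix(structure_out, strict=False)

    # Unpack the return only as deep as the annotation requires; each value is then
    # stored with node.save(value) as in the hunk above.
    values = structure_return.flatten_up_to(out)
    assert values == [1, 2, 3]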
@@ -293,15 +329,17 @@ def _mock_processes_for_threads(
    return [], exc_info


-def _create_kwargs_for_task(task: Task) -> dict[Any, Any]:
+def _create_kwargs_for_task(task: PTask) -> dict[str, PyTree[Any]]:
    """Create kwargs for task function."""
-    kwargs = {**task.kwargs}
+    parameters = inspect.signature(task.function).parameters
+
+    kwargs = {}
+    for name, value in task.depends_on.items():
+        kwargs[name] = tree_map(lambda x: x.load(), value)

-    func_arg_names = set(inspect.signature(task.function).parameters)
-    for arg_name in ("depends_on", "produces"):
-        if arg_name in func_arg_names:
-            attribute = getattr(task, arg_name)
-            kwargs[arg_name] = tree_map(lambda x: x.value, attribute)
+    for name, value in task.produces.items():
+        if name in parameters:
+            kwargs[name] = tree_map(lambda x: x.load(), value)

    return kwargs
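_create_kwargs_for_task now builds the keyword arguments by loading every dependency
node, plus any product whose parameter name appears in the task function's signature,
instead of reading a .value attribute. A toy run of the same tree_map call with a
hypothetical node class that only provides load():

    from pytask.tree_util import tree_map


    class DummyNode:
        """Hypothetical node exposing only the load() used above."""

        def __init__(self, value):
            self.value = value

        def load(self):
            return self.value


    depends_on = {"data": {"raw": DummyNode(1), "clean": DummyNode(2)}}

    kwargs = {
        name: tree_map(lambda x: x.load(), value)
        for name, value in depends_on.items()
    }
    assert kwargs == {"data": {"raw": 1, "clean": 2}}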