6
6
import time
7
7
import warnings
8
8
from concurrent .futures import Future
9
+ from functools import partial
9
10
from pathlib import Path
10
11
from types import ModuleType
11
12
from types import TracebackType
@@ -296,23 +297,7 @@ def _execute_task( # noqa: PLR0913
296
297
exc_info , show_locals , console_options
297
298
)
298
299
else :
299
- if "return" in task .produces :
300
- structure_out = tree_structure (out )
301
- structure_return = tree_structure (task .produces ["return" ])
302
- # strict must be false when none is leaf.
303
- if not structure_return .is_prefix (structure_out , strict = False ):
304
- msg = (
305
- "The structure of the return annotation is not a subtree of "
306
- "the structure of the function return.\n \n Function return: "
307
- f"{ structure_out } \n \n Return annotation: { structure_return } "
308
- )
309
- raise ValueError (msg )
310
-
311
- nodes = tree_leaves (task .produces ["return" ])
312
- values = structure_return .flatten_up_to (out )
313
- for node , value in zip (nodes , values ):
314
- node .save (value )
315
-
300
+ _handle_task_function_return (task , out )
316
301
processed_exc_info = None
317
302
318
303
task_display_name = getattr (task , "display_name" , task .name )
@@ -347,6 +332,27 @@ def _process_exception(
347
332
return (* exc_info [:2 ], text )
348
333
349
334
335
+ def _handle_task_function_return (task : PTask , out : Any ) -> None :
336
+ if "return" not in task .produces :
337
+ return
338
+
339
+ structure_out = tree_structure (out )
340
+ structure_return = tree_structure (task .produces ["return" ])
341
+ # strict must be false when none is leaf.
342
+ if not structure_return .is_prefix (structure_out , strict = False ):
343
+ msg = (
344
+ "The structure of the return annotation is not a subtree of "
345
+ "the structure of the function return.\n \n Function return: "
346
+ f"{ structure_out } \n \n Return annotation: { structure_return } "
347
+ )
348
+ raise ValueError (msg )
349
+
350
+ nodes = tree_leaves (task .produces ["return" ])
351
+ values = structure_return .flatten_up_to (out )
352
+ for node , value in zip (nodes , values ):
353
+ node .save (value )
354
+
355
+
350
356
class DefaultBackendNameSpace :
351
357
"""The name space for hooks related to threads."""
352
358
@@ -362,13 +368,13 @@ def pytask_execute_task(session: Session, task: Task) -> Future[Any] | None:
362
368
if session .config ["n_workers" ] > 1 :
363
369
kwargs = _create_kwargs_for_task (task )
364
370
return session .config ["_parallel_executor" ].submit (
365
- _mock_processes_for_threads , func = task . execute , ** kwargs
371
+ _mock_processes_for_threads , task = task , ** kwargs
366
372
)
367
373
return None
368
374
369
375
370
376
def _mock_processes_for_threads (
371
- func : Callable [..., Any ] , ** kwargs : Any
377
+ task : PTask , ** kwargs : Any
372
378
) -> tuple [
373
379
None , list [Any ], tuple [type [BaseException ], BaseException , TracebackType ] | None
374
380
]:
@@ -381,10 +387,11 @@ def _mock_processes_for_threads(
381
387
"""
382
388
__tracebackhide__ = True
383
389
try :
384
- func (** kwargs )
390
+ out = task . function (** kwargs )
385
391
except Exception : # noqa: BLE001
386
392
exc_info = sys .exc_info ()
387
393
else :
394
+ _handle_task_function_return (task , out )
388
395
exc_info = None
389
396
return None , [], exc_info
390
397
@@ -430,18 +437,17 @@ def sleep(self) -> None:
430
437
def _get_module (func : Callable [..., Any ], path : Path | None ) -> ModuleType :
431
438
"""Get the module of a python function.
432
439
433
- For Python <3.10, functools.partial does not set a `__module__` attribute which is
434
- why ``inspect.getmodule`` returns ``None`` and ``cloudpickle.pickle_by_value``
435
- fails. In later versions, ``functools`` is returned and everything seems to work
436
- fine.
440
+ ``functools.partial`` obfuscates the module of the function and
441
+ ``inspect.getmodule`` returns :mod`functools`. Therefore, we recover the original
442
+ function.
437
443
438
- Therefore, we use the path from the task module to aid the search which works for
439
- Python <3.10.
440
-
441
- We do not unwrap the partialed function with ``func.func``, since pytask in general
442
- does not really support ``functools.partial``. Instead, use ``@task(kwargs=...)``.
444
+ We use the path from the task module to aid the search although it is not clear
445
+ whether it helps.
443
446
444
447
"""
448
+ if isinstance (func , partial ):
449
+ func = func .func
450
+
445
451
if path :
446
452
return inspect .getmodule (func , path .as_posix ())
447
453
return inspect .getmodule (func )
0 commit comments