diff --git a/src/_pytask/task_utils.py b/src/_pytask/task_utils.py index 59af53e3..586f4eca 100644 --- a/src/_pytask/task_utils.py +++ b/src/_pytask/task_utils.py @@ -3,7 +3,6 @@ import inspect from collections import defaultdict -from pathlib import Path from typing import Any from typing import Callable @@ -12,6 +11,8 @@ from _pytask.models import CollectionMetadata from _pytask.shared import find_duplicates from _pytask.typing import is_task_function +from pathlib import Path + __all__ = [ @@ -92,18 +93,11 @@ def wrapper(func: Callable[..., Any]) -> Callable[..., Any]: ) raise ValueError(msg) - unwrapped = inspect.unwrap(func) - - raw_path = inspect.getfile(unwrapped) - if "<string>" in raw_path: - path = Path(unwrapped.__globals__["__file__"]).absolute().resolve() - else: - path = Path(raw_path).absolute().resolve() - parsed_kwargs = {} if kwargs is None else kwargs parsed_name = name if isinstance(name, str) else func.__name__ parsed_after = _parse_after(after) + unwrapped = inspect.unwrap(func) if hasattr(unwrapped, "pytask_meta"): unwrapped.pytask_meta.name = parsed_name unwrapped.pytask_meta.kwargs = parsed_kwargs @@ -123,7 +117,11 @@ def wrapper(func: Callable[..., Any]) -> Callable[..., Any]: # Store it in the global variable ``COLLECTED_TASKS`` to avoid garbage # collection when the function definition is overwritten in a loop. - COLLECTED_TASKS[path].append(unwrapped) + # Based on https://stackoverflow.com/questions/1095543/get-name-of-calling-functions-module-in-python # noqa: E501 + frm = inspect.stack()[1] + task_module = inspect.getmodule(frm.frame) + task_path = Path(task_module.__file__) + COLLECTED_TASKS[task_path].append(unwrapped) return unwrapped