diff --git a/docs/source/changes.md b/docs/source/changes.md index 25415bb1..554ff23f 100644 --- a/docs/source/changes.md +++ b/docs/source/changes.md @@ -14,6 +14,7 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and - {pull}`519` raises an error when builtin functions are wrapped with {func}`~pytask.task`. Closes {issue}`512`. - {pull}`522` improves the issue templates. +- {pull}`523` refactors `_pytask.console._get_file`. ## 0.4.4 - 2023-12-04 diff --git a/src/_pytask/collect.py b/src/_pytask/collect.py index 786fbf63..4cfe156d 100644 --- a/src/_pytask/collect.py +++ b/src/_pytask/collect.py @@ -20,7 +20,6 @@ from _pytask.console import console from _pytask.console import create_summary_panel from _pytask.console import get_file -from _pytask.console import is_jupyter from _pytask.exceptions import CollectionError from _pytask.mark_utils import get_all_marks from _pytask.mark_utils import has_mark @@ -98,23 +97,7 @@ def _collect_from_tasks(session: Session) -> None: if not hasattr(raw_task, "pytask_meta"): raw_task = task_decorator()(raw_task) # noqa: PLW2901 - try: - path = get_file(raw_task) - except (TypeError, OSError): - path = None - else: - if path and path.name == "": - path = None # pragma: no cover - - # Detect whether a path is defined in a Jupyter notebook. - if ( - is_jupyter() - and path - and "ipykernel" in path.as_posix() - and path.suffix == ".py" - ): - path = None # pragma: no cover - + path = get_file(raw_task) name = raw_task.pytask_meta.name # When a task is not a callable, it can be anything or a PTask. Set arbitrary diff --git a/src/_pytask/console.py b/src/_pytask/console.py index b637b3ec..a0bf5b01 100644 --- a/src/_pytask/console.py +++ b/src/_pytask/console.py @@ -45,7 +45,6 @@ "format_strings_as_flat_tree", "format_task_name", "get_file", - "is_jupyter", "render_to_string", "unify_styles", ] @@ -200,7 +199,7 @@ def create_url_style_for_path(path: Path, edtior_url_scheme: str) -> Style: ) -def get_file( +def get_file( # noqa: PLR0911 function: Callable[..., Any], skipped_paths: list[Path] | None = None ) -> Path | None: """Get path to module where the function is defined. @@ -209,6 +208,11 @@ def get_file( a decorator which we need to skip to get to the underlying task function. Thus, the special case. + Raises + ------ + TypeError + If the object is a builtin module, class, or function. + """ if skipped_paths is None: skipped_paths = _SKIPPED_PATHS @@ -221,7 +225,19 @@ def get_file( return get_file(function.__wrapped__) source_file = inspect.getsourcefile(function) if source_file: - return Path(source_file) + # Handle functions defined in the REPL. + if "" in source_file: + return None + # Handle lambda functions. + if "" in source_file: + try: + return Path(function.__globals__["__file__"]).absolute().resolve() + except KeyError: + return None + # Handle functions defined in Jupyter notebooks. + if "ipykernel" in source_file or "ipython-input" in source_file: + return None + return Path(source_file).absolute().resolve() return None @@ -287,26 +303,3 @@ def create_summary_panel( if counts[outcome_enum.FAIL] else outcome_enum.SUCCESS.style, ) - - -def is_jupyter() -> bool: # pragma: no cover - """Check if we're running in a Jupyter notebook. - - Copied from rich. - - """ - try: - get_ipython # type: ignore[name-defined] # noqa: B018 - except NameError: - return False - ipython = get_ipython() # type: ignore[name-defined] # noqa: F821 - shell = ipython.__class__.__name__ - if ( - "google.colab" in str(ipython.__class__) - or os.getenv("DATABRICKS_RUNTIME_VERSION") - or shell == "ZMQInteractiveShell" - ): - return True # Jupyter notebook or qtconsole - if shell == "TerminalInteractiveShell": - return False # Terminal running IPython - return False # Other type (?) diff --git a/src/_pytask/task_utils.py b/src/_pytask/task_utils.py index 45ec3848..8d3638fd 100644 --- a/src/_pytask/task_utils.py +++ b/src/_pytask/task_utils.py @@ -3,17 +3,21 @@ import inspect from collections import defaultdict -from pathlib import Path from types import BuiltinFunctionType from typing import Any from typing import Callable +from typing import TYPE_CHECKING import attrs +from _pytask.console import get_file from _pytask.mark import Mark from _pytask.models import CollectionMetadata from _pytask.shared import find_duplicates from _pytask.typing import is_task_function +if TYPE_CHECKING: + from pathlib import Path + __all__ = [ "COLLECTED_TASKS", @@ -23,7 +27,7 @@ ] -COLLECTED_TASKS: dict[Path, list[Callable[..., Any]]] = defaultdict(list) +COLLECTED_TASKS: dict[Path | None, list[Callable[..., Any]]] = defaultdict(list) """A container for collecting tasks. Tasks marked by the ``@pytask.mark.task`` decorator can be generated in a loop where one @@ -108,11 +112,7 @@ def wrapper(func: Callable[..., Any]) -> Callable[..., Any]: ) raise NotImplementedError(msg) - raw_path = inspect.getfile(unwrapped) - if "" in raw_path: - path = Path(unwrapped.__globals__["__file__"]).absolute().resolve() - else: - path = Path(raw_path).absolute().resolve() + path = get_file(unwrapped) parsed_kwargs = {} if kwargs is None else kwargs parsed_name = name if isinstance(name, str) else func.__name__ diff --git a/tests/test_console.py b/tests/test_console.py index a0a09476..b2d4dd8d 100644 --- a/tests/test_console.py +++ b/tests/test_console.py @@ -186,14 +186,17 @@ def test_reduce_node_name(node, paths, expectation, expected): assert result == expected +exec("__unknown_lambda = lambda x: x") # noqa: S102 + + @pytest.mark.unit() @pytest.mark.parametrize( ("task_func", "skipped_paths", "expected"), [ - (task_func, [], _THIS_FILE), + (task_func, None, _THIS_FILE), ( empty_decorator(task_func), - [], + None, _THIS_FILE.parent.joinpath("_test_console_helpers.py"), ), ( @@ -201,6 +204,8 @@ def test_reduce_node_name(node, paths, expectation, expected): [_THIS_FILE.parent.joinpath("_test_console_helpers.py")], _THIS_FILE, ), + (lambda x: x, None, Path(__file__)), + (__unknown_lambda, None, Path(__file__)), # noqa: F821 ], ) def test_get_file(task_func, skipped_paths, expected):