Skip to content

Commit a06d948

Browse files
authored
Catch objects pretending to be PTask. (#508)
1 parent d846c5c commit a06d948

File tree

6 files changed

+44
-17
lines changed

6 files changed

+44
-17
lines changed

docs/source/changes.md

+4-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ chronological order. Releases follow [semantic versioning](https://semver.org/)
55
releases are available on [PyPI](https://pypi.org/project/pytask) and
66
[Anaconda.org](https://anaconda.org/conda-forge/pytask).
77

8-
## 0.4.3 - 2023-11-xx
8+
## 0.4.3 - 2023-12-01
99

1010
- {pull}`483` simplifies the teardown of a task.
1111
- {pull}`484` raises more informative error when directories instead of files are used
@@ -26,6 +26,9 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and
2626
- {pull}`498` fixes an error when using {class}`~pytask.Task` and
2727
{class}`~pytask.TaskWithoutPath` in task modules.
2828
- {pull}`500` refactors the dependencies for tests.
29+
- {pull}`501` removes `MetaNode`.
30+
- {pull}`508` catches objects that pretend to be a {class}`~pytask.PTask`. Fixes
31+
{issue}`507`.
2932

3033
## 0.4.2 - 2023-11-08
3134

src/_pytask/collect.py

+22-6
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from _pytask.console import get_file
2323
from _pytask.console import is_jupyter
2424
from _pytask.exceptions import CollectionError
25-
from _pytask.mark import MarkGenerator
2625
from _pytask.mark_utils import get_all_marks
2726
from _pytask.mark_utils import has_mark
2827
from _pytask.node_protocols import PNode
@@ -176,10 +175,7 @@ def pytask_collect_file(
176175

177176
collected_reports = []
178177
for name, obj in inspect.getmembers(mod):
179-
# Skip mark generator since it overrides __getattr__ and seems like any
180-
# object. Happens when people do ``from pytask import mark`` and
181-
# ``@mark.x``.
182-
if isinstance(obj, MarkGenerator):
178+
if _is_filtered_object(obj):
183179
continue
184180

185181
# Ensures that tasks with this decorator are only collected once.
@@ -196,6 +192,26 @@ def pytask_collect_file(
196192
return None
197193

198194

195+
def _is_filtered_object(obj: Any) -> bool:
196+
"""Filter some objects that are only causing harm later on.
197+
198+
See :issue:`507`.
199+
200+
"""
201+
# Filter :class:`pytask.Task` and :class:`pytask.TaskWithoutPath` objects.
202+
if isinstance(obj, PTask) and inspect.isclass(obj):
203+
return True
204+
205+
# Filter objects overwriting the ``__getattr__`` method like :class:`pytask.mark` or
206+
# ``from ibis import _``.
207+
attr_name = "attr_that_definitely_does_not_exist"
208+
if hasattr(obj, attr_name) and not bool(
209+
inspect.getattr_static(obj, attr_name, False)
210+
):
211+
return True
212+
return False
213+
214+
199215
@hookimpl
200216
def pytask_collect_task_protocol(
201217
session: Session, path: Path | None, name: str, obj: Any
@@ -279,7 +295,7 @@ def pytask_collect_task(
279295
markers=markers,
280296
attributes={"collection_id": collection_id, "after": after},
281297
)
282-
if isinstance(obj, PTask) and not inspect.isclass(obj):
298+
if isinstance(obj, PTask):
283299
return obj
284300
return None
285301

src/_pytask/mark_utils.py

+4-7
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
"""
66
from __future__ import annotations
77

8-
import inspect
98
from typing import Any
109
from typing import TYPE_CHECKING
1110

@@ -19,12 +18,10 @@
1918

2019
def get_all_marks(obj_or_task: Any | PTask) -> list[Mark]:
2120
"""Get all marks from a callable or task."""
22-
if isinstance(obj_or_task, PTask) and not inspect.isclass(obj_or_task):
23-
marks = obj_or_task.markers
24-
else:
25-
obj = obj_or_task
26-
marks = obj.pytask_meta.markers if hasattr(obj, "pytask_meta") else []
27-
return marks
21+
if isinstance(obj_or_task, PTask):
22+
return obj_or_task.markers
23+
obj = obj_or_task
24+
return obj.pytask_meta.markers if hasattr(obj, "pytask_meta") else []
2825

2926

3027
def set_marks(obj_or_task: Any | PTask, marks: list[Mark]) -> Any | PTask:

src/_pytask/nodes.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from _pytask.mark import Mark
2828

2929

30-
__all__ = ["PathNode", "PythonNode", "Task", "TaskWithoutPath"]
30+
__all__ = ["PathNode", "PickleNode", "PythonNode", "Task", "TaskWithoutPath"]
3131

3232

3333
@define(kw_only=True)

src/_pytask/typing.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,13 @@
1313
from typing_extensions import TypeAlias
1414

1515

16-
__all__ = ["Product", "ProductType"]
16+
__all__ = [
17+
"NoDefault",
18+
"Product",
19+
"ProductType",
20+
"is_task_function",
21+
"no_default",
22+
]
1723

1824

1925
@define(frozen=True)

tests/test_collect.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -684,7 +684,12 @@ def task_mixed(): pass
684684
@pytest.mark.end_to_end()
685685
def test_module_can_be_collected(runner, tmp_path):
686686
source = """
687-
from pytask import Task, TaskWithoutPath
687+
from pytask import Task, TaskWithoutPath, mark
688+
689+
class C:
690+
def __getattr__(self, name):
691+
return C()
692+
c = C()
688693
"""
689694
tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source))
690695

0 commit comments

Comments
 (0)