Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

introducing FunctionDefinition as actual Node/Collector #12487

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
6 changes: 5 additions & 1 deletion src/_pytest/_code/code.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,11 @@

@property
def lineno(self) -> int:
return self._rawentry.tb_lineno - 1
if self._rawentry.tb_lineno is None:
# how did i trigger this 😱
return -1 # type: ignore[unreachable]

Check warning on line 215 in src/_pytest/_code/code.py

View check run for this annotation

Codecov / codecov/patch

src/_pytest/_code/code.py#L215

Added line #L215 was not covered by tests
else:
return self._rawentry.tb_lineno - 1

@property
def frame(self) -> Frame:
Expand Down
20 changes: 10 additions & 10 deletions src/_pytest/compat.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# mypy: allow-untyped-defs
"""Python version compatibility code."""

from __future__ import annotations
Expand All @@ -12,8 +11,11 @@
import os
from pathlib import Path
import sys
from types import FunctionType
from types import MethodType
from typing import Any
from typing import Callable
from typing import cast
from typing import Final
from typing import NoReturn

Expand Down Expand Up @@ -66,7 +68,8 @@ def is_async_function(func: object) -> bool:
return iscoroutinefunction(func) or inspect.isasyncgenfunction(func)


def getlocation(function, curdir: str | os.PathLike[str] | None = None) -> str:
def getlocation(function: Any, curdir: str | os.PathLike[str] | None = None) -> str:
# todo: declare a type alias for function, fixturefunction and callables/generators
function = get_real_func(function)
fn = Path(inspect.getfile(function))
lineno = function.__code__.co_firstlineno
Expand All @@ -80,7 +83,7 @@ def getlocation(function, curdir: str | os.PathLike[str] | None = None) -> str:
return "%s:%d" % (fn, lineno + 1)


def num_mock_patch_args(function) -> int:
def num_mock_patch_args(function: Callable[..., object]) -> int:
"""Return number of arguments used up by mock arguments (if any)."""
patchings = getattr(function, "patchings", None)
if not patchings:
Expand Down Expand Up @@ -222,7 +225,7 @@ class _PytestWrapper:
obj: Any


def get_real_func(obj):
def get_real_func(obj: Any) -> Any:
"""Get the real function object of the (possibly) wrapped object by
functools.wraps or functools.partial."""
start_obj = obj
Expand All @@ -249,7 +252,7 @@ def get_real_func(obj):
return obj


def get_real_method(obj, holder):
def get_real_method(obj: Any, holder: object) -> Any:
"""Attempt to obtain the real function object that might be wrapping
``obj``, while at the same time returning a bound method to ``holder`` if
the original object was a bound method."""
Expand All @@ -263,11 +266,8 @@ def get_real_method(obj, holder):
return obj


def getimfunc(func):
try:
return func.__func__
except AttributeError:
return func
def getimfunc(func: FunctionType | MethodType | Callable[..., Any]) -> FunctionType:
return cast(FunctionType, getattr(func, "__func__", func))


def safe_getattr(object: Any, name: str, default: Any) -> Any:
Expand Down
7 changes: 4 additions & 3 deletions src/_pytest/doctest.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,9 +417,10 @@ def _get_continue_on_failure(config: Config) -> bool:


class DoctestTextfile(Module):
obj = None
# todo: this shouldnt be a module
obj: None = None # type: ignore[assignment]

def collect(self) -> Iterable[DoctestItem]:
def collect(self) -> Iterable[DoctestItem]: # type: ignore[override]
import doctest

# Inspired by doctest.testfile; ideally we would use it directly,
Expand Down Expand Up @@ -497,7 +498,7 @@ def _mock_aware_unwrap(


class DoctestModule(Module):
def collect(self) -> Iterable[DoctestItem]:
def collect(self) -> Iterable[DoctestItem]: # type: ignore[override]
import doctest

class MockAwareDocTestFinder(doctest.DocTestFinder):
Expand Down
7 changes: 4 additions & 3 deletions src/_pytest/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -1100,7 +1100,8 @@ def resolve_fixture_function(
) -> _FixtureFunc[FixtureValue]:
"""Get the actual callable that can be called to obtain the fixture
value."""
fixturefunc = fixturedef.func
# absuing any for the differences between FunctionTpye and Callable
fixturefunc: Any = fixturedef.func
# The fixture function needs to be bound to the actual
# request.instance so that code working with "fixturedef" behaves
# as expected.
Expand All @@ -1112,11 +1113,11 @@ def resolve_fixture_function(
instance,
fixturefunc.__self__.__class__,
):
return fixturefunc
return cast(_FixtureFunc[FixtureValue], fixturefunc)
fixturefunc = getimfunc(fixturedef.func)
if fixturefunc != fixturedef.func:
fixturefunc = fixturefunc.__get__(instance)
return fixturefunc
return cast(_FixtureFunc[FixtureValue], fixturefunc)


def pytest_fixture_setup(
Expand Down
1 change: 1 addition & 0 deletions src/_pytest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,7 @@ class Session(nodes.Collector):
``Session`` collects the initial paths given as arguments to pytest.
"""

parent: None
Interrupted = Interrupted
Failed = Failed
# Set on the session by runner.pytest_sessionstart.
Expand Down
96 changes: 54 additions & 42 deletions src/_pytest/nodes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# mypy: allow-untyped-defs
from __future__ import annotations

import abc
Expand All @@ -20,6 +19,7 @@
import warnings

import pluggy
from typing_extensions import Self

import _pytest._code
from _pytest._code import getfslineno
Expand Down Expand Up @@ -54,9 +54,6 @@
tracebackcutdir = Path(_pytest.__file__).parent


_T = TypeVar("_T")


def _imply_path(
node_type: type[Node],
path: Path | None,
Expand Down Expand Up @@ -96,7 +93,7 @@ class NodeMeta(abc.ABCMeta):
progress on detangling the :class:`Node` classes.
"""

def __call__(cls, *k, **kw) -> NoReturn:
def __call__(cls, *k: object, **kw: object) -> NoReturn:
msg = (
"Direct construction of {name} has been deprecated, please use {name}.from_parent.\n"
"See "
Expand All @@ -105,25 +102,6 @@ def __call__(cls, *k, **kw) -> NoReturn:
).format(name=f"{cls.__module__}.{cls.__name__}")
fail(msg, pytrace=False)

def _create(cls: type[_T], *k, **kw) -> _T:
try:
return super().__call__(*k, **kw) # type: ignore[no-any-return,misc]
except TypeError:
sig = signature(getattr(cls, "__init__"))
known_kw = {k: v for k, v in kw.items() if k in sig.parameters}
from .warning_types import PytestDeprecationWarning

warnings.warn(
PytestDeprecationWarning(
f"{cls} is not using a cooperative constructor and only takes {set(known_kw)}.\n"
"See https://docs.pytest.org/en/stable/deprecations.html"
"#constructors-of-custom-pytest-node-subclasses-should-take-kwargs "
"for more details."
)
)

return super().__call__(*k, **known_kw) # type: ignore[no-any-return,misc]


class Node(abc.ABC, metaclass=NodeMeta):
r"""Base class of :class:`Collector` and :class:`Item`, the components of
Expand All @@ -138,8 +116,13 @@ class Node(abc.ABC, metaclass=NodeMeta):
#: for methods not migrated to ``pathlib.Path`` yet, such as
#: :meth:`Item.reportinfo <pytest.Item.reportinfo>`. Will be deprecated in
#: a future release, prefer using :attr:`path` instead.
name: str
parent: Node | None
config: Config
session: Session
fspath: LEGACY_PATH

_nodeid: str
# Use __slots__ to make attribute access faster.
# Note that __dict__ is still available.
__slots__ = (
Expand All @@ -156,7 +139,7 @@ class Node(abc.ABC, metaclass=NodeMeta):
def __init__(
self,
name: str,
parent: Node | None = None,
parent: Node | None,
config: Config | None = None,
session: Session | None = None,
fspath: LEGACY_PATH | None = None,
Expand Down Expand Up @@ -200,13 +183,9 @@ def __init__(
#: Allow adding of extra keywords to use for matching.
self.extra_keyword_matches: set[str] = set()

if nodeid is not None:
assert "::()" not in nodeid
self._nodeid = nodeid
else:
if not self.parent:
raise TypeError("nodeid or parent must be provided")
self._nodeid = self.parent.nodeid + "::" + self.name
self._nodeid = self._make_nodeid(
name=self.name, parent=self.parent, given=nodeid
)

#: A place where plugins can store information on the node for their
#: own use.
Expand All @@ -215,7 +194,38 @@ def __init__(
self._store = self.stash

@classmethod
def from_parent(cls, parent: Node, **kw) -> Self:
def _make_nodeid(cls, name: str, parent: Node | None, given: str | None) -> str:
if given is not None:
assert "::()" not in given
return given
else:
assert parent is not None
return f"{parent.nodeid}::{name}"

@classmethod
def _create(cls, *k: object, **kw: object) -> Self:
callit = super(type(cls), NodeMeta).__call__ # type: ignore[misc]
try:
return cast(Self, callit(cls, *k, **kw))
except TypeError as e:
sig = signature(getattr(cls, "__init__"))
known_kw = {k: v for k, v in kw.items() if k in sig.parameters}
from .warning_types import PytestDeprecationWarning

warnings.warn(
PytestDeprecationWarning(
f"{cls} is not using a cooperative constructor and only takes {set(known_kw)}.\n"
f"Exception: {e}\n"
"See https://docs.pytest.org/en/stable/deprecations.html"
"#constructors-of-custom-pytest-node-subclasses-should-take-kwargs "
"for more details."
)
)

return cast(Self, callit(cls, *k, **known_kw))

@classmethod
def from_parent(cls, parent: Node, **kw: Any) -> Self:
"""Public constructor for Nodes.

This indirection got introduced in order to enable removing
Expand All @@ -238,7 +248,7 @@ def ihook(self) -> pluggy.HookRelay:
return self.session.gethookproxy(self.path)

def __repr__(self) -> str:
return "<{} {}>".format(self.__class__.__name__, getattr(self, "name", None))
return f'<{self.__class__.__name__} { getattr(self, "name", None)}>'

def warn(self, warning: Warning) -> None:
"""Issue a warning for this Node.
Expand Down Expand Up @@ -598,7 +608,6 @@ def __init__(

if nodeid and os.sep != SEP:
nodeid = nodeid.replace(os.sep, SEP)

super().__init__(
name=name,
parent=parent,
Expand All @@ -611,11 +620,11 @@ def __init__(
@classmethod
def from_parent(
cls,
parent,
parent: Node,
*,
fspath: LEGACY_PATH | None = None,
path: Path | None = None,
**kw,
**kw: Any,
) -> Self:
"""The public constructor."""
return super().from_parent(parent=parent, fspath=fspath, path=path, **kw)
Expand Down Expand Up @@ -646,22 +655,25 @@ class Directory(FSCollector, abc.ABC):
"""


class Definition(Collector, abc.ABC):
@abc.abstractmethod
def collect(self) -> Iterable[Item]: ...


class Item(Node, abc.ABC):
"""Base class of all test invocation items.

Note that for a single function there might be multiple test invocation items.
"""

nextitem = None

def __init__(
self,
name,
parent=None,
name: str,
parent: Node | None = None,
config: Config | None = None,
session: Session | None = None,
nodeid: str | None = None,
**kw,
**kw: Any,
) -> None:
# The first two arguments are intentionally passed positionally,
# to keep plugins who define a node type which inherits from
Expand Down
Loading
Loading