Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 58 additions & 48 deletions src/ducktools/classbuilder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,19 @@
import os
import sys

from .annotations import get_ns_annotations, is_classvar, make_annotate_func
try:
# Use the internal C module if it is available
from _types import ( # type: ignore
MemberDescriptorType as _MemberDescriptorType,
MappingProxyType as _MappingProxyType
)
except ImportError:
from types import (
MemberDescriptorType as _MemberDescriptorType,
MappingProxyType as _MappingProxyType,
)

from .annotations import get_ns_annotations, is_classvar, make_annotate_func, evaluate_forwardref
from ._version import __version__, __version_tuple__ # noqa: F401

# Change this name if you make heavy modifications
Expand All @@ -45,13 +57,6 @@
# overwritten. When running this is a performance penalty so it is not required.
_UNDER_TESTING = os.environ.get("PYTEST_VERSION") is not None

# Obtain types the same way types.py does in pypy
# See: https://github.com/pypy/pypy/blob/19d9fa6be11165116dd0839b9144d969ab426ae7/lib-python/3/types.py#L61-L73
class _C: __slots__ = 's' # noqa
_MemberDescriptorType = type(_C.s) # type: ignore
_MappingProxyType = type(type.__dict__)
del _C


def get_fields(cls, *, local=False):
"""
Expand Down Expand Up @@ -132,22 +137,18 @@ class GeneratedCode:
This class provides a return value for the generated output from source code
generators.
"""
__slots__ = ("source_code", "globs", "annotations", "extra_annotation_func")
__slots__ = ("source_code", "globs", "annotations")

def __init__(self, source_code, globs, annotations=None, extra_annotation_func=None):
def __init__(self, source_code, globs, annotations=None):
self.source_code = source_code
self.globs = globs
self.annotations = annotations

# extra annotation function to evaluate if needed, required for post_init
self.extra_annotation_func = extra_annotation_func

def __repr__(self):
first_source_line = self.source_code.split("\n")[0]
return (
f"GeneratorOutput(source_code='{first_source_line} ...', "
f"globs={self.globs!r}, annotations={self.annotations!r}, "
f"extra_annotation_func={self.extra_annotation_func!r})"
f"globs={self.globs!r}, annotations={self.annotations!r})"
)

def __eq__(self, other):
Expand All @@ -156,12 +157,10 @@ def __eq__(self, other):
self.source_code,
self.globs,
self.annotations,
self.extra_annotation_func
) == (
other.source_code,
other.globs,
other.annotations,
other.extra_annotation_func
)
return NotImplemented

Expand Down Expand Up @@ -227,11 +226,7 @@ def __get__(self, inst, cls):
if "__annotations__" in gen_cls.__dict__:
method.__annotations__ = gen.annotations
else:
anno_func = make_annotate_func(
gen_cls,
gen.annotations,
gen.extra_annotation_func,
)
anno_func = make_annotate_func(gen.annotations)
anno_func.__qualname__ = f"{gen_cls.__qualname__}.{self.funcname}.__annotate__"
method.__annotate__ = anno_func
else:
Expand Down Expand Up @@ -532,6 +527,10 @@ def builder(cls=None, /, *, gatherer, methods, flags=None, fix_signature=True):
"""
The main builder for class generation

If the GATHERED_DATA attribute exists on the class it will be used instead of
the provided gatherer and 3.14 annotations will be updated with links to
the class.

:param cls: Class to be analysed and have methods generated
:param gatherer: Function to gather field information
:type gatherer: Callable[[type], tuple[dict[str, Field], dict[str, Any]]]
Expand All @@ -552,12 +551,26 @@ def builder(cls=None, /, *, gatherer, methods, flags=None, fix_signature=True):
gatherer=gatherer,
methods=methods,
flags=flags,
fix_signature=fix_signature,
)

internals = {}
setattr(cls, INTERNALS_DICT, internals)

cls_fields, modifications = gatherer(cls)
cls_gathered = cls.__dict__.get(GATHERED_DATA)

if cls_gathered:
cls_fields, modifications = cls_gathered
# Reconnect the forwardrefs in types to the class so they can evaluate.
# If there are forwardrefs then annotationlib should be in modules
# No need to do this if __future__ annotations are used
if sys.version_info >= (3, 14) and sys.modules.get("annotationlib"):
annos = annotations.get_ns_annotations(cls.__dict__, cls=cls)
for k, v in cls_fields.items():
if annotations.is_forwardref(v.type):
cls_fields[k] = type(v).from_field(v, type=annos[k])
else:
cls_fields, modifications = gatherer(cls)

for name, value in modifications.items():
if value is NOTHING:
Expand Down Expand Up @@ -802,6 +815,12 @@ def from_field(cls, fld, /, **kwargs):

return cls(**argument_dict)

@property
def type_eval(self):
if sys.version_info >= (3, 14):
return annotations.evaluate_forwardref(self.type)
return self.type


def _build_field():
# Complete the construction of the Field class
Expand All @@ -828,7 +847,7 @@ def _build_field():
"init": Field(default=True, doc=field_docs["init"]),
"repr": Field(default=True, doc=field_docs["repr"]),
"compare": Field(default=True, doc=field_docs["compare"]),
"kw_only": Field(default=False, doc=field_docs["kw_only"])
"kw_only": Field(default=False, doc=field_docs["kw_only"]),
}
modifications = {"__slots__": field_docs}

Expand All @@ -848,21 +867,6 @@ def _build_field():
del _build_field


def pre_gathered_gatherer(cls_or_ns):
"""
Retrieve fields previously gathered by SlotMakerMeta

:param cls_or_ns: Class to gather field information from (or class namespace)
:return: dict of field_name: Field(...) and modifications to be performed by the builder
"""
if isinstance(cls_or_ns, (_MappingProxyType, dict)):
cls_dict = cls_or_ns
else:
cls_dict = cls_or_ns.__dict__

return cls_dict[GATHERED_DATA]


def make_slot_gatherer(field_type=Field):
"""
Create a new annotation gatherer that will work with `Field` instances
Expand Down Expand Up @@ -945,16 +949,18 @@ def make_annotation_gatherer(
"""
def field_annotation_gatherer(cls_or_ns):
if isinstance(cls_or_ns, (_MappingProxyType, dict)):
cls = None
cls_dict = cls_or_ns
else:
cls = cls_or_ns
cls_dict = cls_or_ns.__dict__

# This should really be dict[str, field_type] but static analysis
# doesn't understand this.
cls_fields: dict[str, Field] = {}
modifications = {}

cls_annotations = get_ns_annotations(cls_dict)
cls_annotations = get_ns_annotations(cls_dict, cls=cls)

kw_flag = False

Expand All @@ -963,7 +969,9 @@ def field_annotation_gatherer(cls_or_ns):
if is_classvar(v):
continue

if v is KW_ONLY or (isinstance(v, str) and v == "KW_ONLY"):
v_eval = evaluate_forwardref(v)

if v_eval is KW_ONLY or (isinstance(v, str) and v == "KW_ONLY"):
if kw_flag:
raise SyntaxError("KW_ONLY sentinel may only appear once.")
kw_flag = True
Expand Down Expand Up @@ -1006,8 +1014,10 @@ def make_field_gatherer(
def field_attribute_gatherer(cls_or_ns):
if isinstance(cls_or_ns, (_MappingProxyType, dict)):
cls_dict = cls_or_ns
cls = None
else:
cls_dict = cls_or_ns.__dict__
cls = cls_or_ns

cls_attributes = {
k: v
Expand All @@ -1016,7 +1026,7 @@ def field_attribute_gatherer(cls_or_ns):
}

if assign_types:
cls_annotations = get_ns_annotations(cls_dict)
cls_annotations = get_ns_annotations(cls_dict, cls=cls)
else:
cls_annotations = {}

Expand Down Expand Up @@ -1058,12 +1068,10 @@ def make_unified_gatherer(
def field_unified_gatherer(cls_or_ns):
if isinstance(cls_or_ns, (_MappingProxyType, dict)):
cls_dict = cls_or_ns
cls = None
else:
cls_dict = cls_or_ns.__dict__

cls_gathered = cls_dict.get(GATHERED_DATA)
if cls_gathered:
return pre_gathered_gatherer(cls_dict)
cls = cls_or_ns

cls_slots = cls_dict.get("__slots__")

Expand All @@ -1076,7 +1084,7 @@ def field_unified_gatherer(cls_or_ns):
# To choose between annotation and attribute gatherers
# compare sets of names.
# Don't bother evaluating string annotations, as we only need names
cls_annotations = get_ns_annotations(cls_dict)
cls_annotations = get_ns_annotations(cls_dict, cls=cls)
cls_attributes = {
k: v for k, v in cls_dict.items() if isinstance(v, field_type)
}
Expand All @@ -1086,7 +1094,9 @@ def field_unified_gatherer(cls_or_ns):

if set(cls_annotation_names).issuperset(set(cls_attribute_names)):
# All `Field` values have annotations, so use annotation gatherer
return anno_g(cls_dict)
# Pass the original cls_or_ns object

return anno_g(cls_or_ns)

return attrib_g(cls_dict)

Expand Down
22 changes: 10 additions & 12 deletions src/ducktools/classbuilder/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
import sys
import types
import typing
import typing_extensions

import inspect

from collections.abc import Callable
from types import MappingProxyType

_py_type = type | str # Alias for type hint values
if sys.version_info >= (3, 14):
import annotationlib

_py_type = annotationlib.ForwardRef | type | str
else:
_py_type = type | str

_CopiableMappings = dict[str, typing.Any] | MappingProxyType[str, typing.Any]

__version__: str
Expand Down Expand Up @@ -45,14 +51,12 @@ class GeneratedCode:
source_code: str
globs: dict[str, typing.Any]
annotations: dict[str, typing.Any]
extra_annotation_func: None | types.FunctionType

def __init__(
self,
source_code: str,
globs: dict[str, typing.Any],
annotations: dict[str, typing.Any] | None = ...,
extra_annotation_func: None | types.FunctionType = ...,
) -> None: ...
def __repr__(self) -> str: ...

Expand Down Expand Up @@ -169,17 +173,14 @@ class Field(metaclass=SlotMakerMeta):
def validate_field(self) -> None: ...
@classmethod
def from_field(cls, fld: Field, /, **kwargs: typing.Any) -> Field: ...

@property
def type_eval(self) -> typing.Any: ...

# type[Field] doesn't work due to metaclass
# This is not really precise enough because isinstance is used
_ReturnsField = Callable[..., Field]
_FieldType = typing.TypeVar("_FieldType", bound=Field)

def pre_gathered_gatherer(
cls_or_ns: type | _CopiableMappings
) -> tuple[dict[str, Field | _FieldType], dict[str, typing.Any]]: ...

@typing.overload
def make_slot_gatherer(
field_type: type[_FieldType]
Expand Down Expand Up @@ -265,9 +266,6 @@ class GatheredFields:
fields: dict[str, Field]
modifications: dict[str, typing.Any]

__classbuilder_internals__: dict
__signature__: inspect.Signature

def __init__(
self,
fields: dict[str, Field],
Expand Down
Loading