From 259b11914782c2745b705d77a42c7b36419301df Mon Sep 17 00:00:00 2001 From: Jeff Dairiki Date: Sat, 24 Sep 2022 11:25:38 -0700 Subject: [PATCH 01/32] Test for memory leaks as described in #198 --- tests/test_memory_leak.py | 111 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 111 insertions(+) create mode 100644 tests/test_memory_leak.py diff --git a/tests/test_memory_leak.py b/tests/test_memory_leak.py new file mode 100644 index 0000000..64aef41 --- /dev/null +++ b/tests/test_memory_leak.py @@ -0,0 +1,111 @@ +import gc +import inspect +import sys +import unittest +import weakref +from dataclasses import dataclass + +import marshmallow +import marshmallow_dataclass as md + + +class Referenceable: + pass + + +class TestMemoryLeak(unittest.TestCase): + """Test for memory leaks as decribed in `#198`_. + + .. _#198: https://github.com/lovasoa/marshmallow_dataclass/issues/198 + """ + + def setUp(self): + gc.collect() + gc.disable() + self.frame_collected = False + + def tearDown(self): + gc.enable() + + def trackFrame(self): + """Create a tracked local variable in the callers frame. + + We track these locals in the WeakSet self.livingLocals. + + When the callers frame is freed, the locals will be GCed as well. + In this way we can check that the callers frame has been collected. + """ + local = Referenceable() + weakref.finalize(local, self._set_frame_collected) + try: + frame = inspect.currentframe() + frame.f_back.f_locals["local_variable"] = local + finally: + del frame + + def _set_frame_collected(self): + self.frame_collected = True + + def assertFrameCollected(self): + """Check that all locals created by makeLocal have been GCed""" + if not hasattr(sys, "getrefcount"): + # pypy does not do reference counting + gc.collect(0) + self.assertTrue(self.frame_collected) + + def test_sanity(self): + """Test that our scheme for detecting leaked frames works.""" + frames = [] + + def f(): + frames.append(inspect.currentframe()) + self.trackFrame() + + f() + + gc.collect(0) + self.assertFalse( + self.frame_collected + ) # with frame leaked, f's locals are still alive + frames.clear() + self.assertFrameCollected() + + def test_class_schema(self): + def f(): + @dataclass + class Foo: + value: int + + md.class_schema(Foo) + + self.trackFrame() + + f() + self.assertFrameCollected() + + def test_md_dataclass_lazy_schema(self): + def f(): + @md.dataclass + class Foo: + value: int + + self.trackFrame() + + f() + # NB: The "lazy" Foo.Schema attribute descriptor holds a reference to f's frame, + # which, in turn, holds a reference to class Foo, thereby creating ref cycle. + # So, a gc pass is required to clean that up. + gc.collect(0) + self.assertFrameCollected() + + def test_md_dataclass(self): + def f(): + @md.dataclass + class Foo: + value: int + + self.assertIsInstance(Foo.Schema(), marshmallow.Schema) + self.trackFrame() + + f() + self.assertFrameCollected() From 387afb3c48118c743d96daa00622a771cfc5a81d Mon Sep 17 00:00:00 2001 From: Jeff Dairiki Date: Sat, 24 Sep 2022 11:30:22 -0700 Subject: [PATCH 02/32] Possible fix for #198: memory leak --- marshmallow_dataclass/__init__.py | 141 +++++++++++++++++++----------- 1 file changed, 91 insertions(+), 50 deletions(-) diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index 106cde2..e856888 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -46,6 +46,7 @@ class User: Any, Callable, Dict, + Generic, List, Mapping, NewType as typing_NewType, @@ -79,9 +80,6 @@ class User: # Max number of generated schemas that class_schema keeps of generated schemas. Removes duplicates. MAX_CLASS_SCHEMA_CACHE_SIZE = 1024 -# Recursion guard for class_schema() -_RECURSION_GUARD = threading.local() - @overload def dataclass( @@ -352,20 +350,61 @@ def class_schema( clazz_frame = current_frame.f_back # Per https://docs.python.org/3/library/inspect.html#the-interpreter-stack del current_frame - _RECURSION_GUARD.seen_classes = {} - try: - return _internal_class_schema(clazz, base_schema, clazz_frame) - finally: - _RECURSION_GUARD.seen_classes.clear() + + with _SchemaContext(clazz_frame): + return _internal_class_schema(clazz, base_schema) + + +class _SchemaContext: + """Global context for an invocation of class_schema.""" + + def __init__(self, frame: Optional[types.FrameType]): + self.seen_classes: Dict[type, str] = {} + self.frame = frame + + def get_type_hints(self, cls: Type) -> Dict[str, Any]: + frame = self.frame + localns = frame.f_locals if frame is not None else None + return get_type_hints(cls, localns=localns) + + def __enter__(self) -> "_SchemaContext": + _schema_ctx_stack.push(self) + return self + + def __exit__( + self, + _typ: Optional[Type[BaseException]], + _value: Optional[BaseException], + _tb: Optional[types.TracebackType], + ) -> None: + _schema_ctx_stack.pop() + + +class _LocalStack(threading.local, Generic[_U]): + def __init__(self) -> None: + self.stack: List[_U] = [] + + def push(self, value: _U) -> None: + self.stack.append(value) + + def pop(self) -> None: + self.stack.pop() + + @property + def top(self) -> _U: + return self.stack[-1] + + +_schema_ctx_stack = _LocalStack[_SchemaContext]() @lru_cache(maxsize=MAX_CLASS_SCHEMA_CACHE_SIZE) def _internal_class_schema( clazz: type, base_schema: Optional[Type[marshmallow.Schema]] = None, - clazz_frame: Optional[types.FrameType] = None, ) -> Type[marshmallow.Schema]: - _RECURSION_GUARD.seen_classes[clazz] = clazz.__name__ + schema_ctx = _schema_ctx_stack.top + schema_ctx.seen_classes[clazz] = clazz.__name__ try: # noinspection PyDataclass fields: Tuple[dataclasses.Field, ...] = dataclasses.fields(clazz) @@ -383,7 +422,7 @@ def _internal_class_schema( "****** WARNING ******" ) created_dataclass: type = dataclasses.dataclass(clazz) - return _internal_class_schema(created_dataclass, base_schema, clazz_frame) + return _internal_class_schema(created_dataclass, base_schema) except Exception as exc: raise TypeError( f"{getattr(clazz, '__name__', repr(clazz))} is not a dataclass and cannot be turned into one." @@ -397,18 +436,15 @@ def _internal_class_schema( } # Update the schema members to contain marshmallow fields instead of dataclass fields - type_hints = get_type_hints( - clazz, localns=clazz_frame.f_locals if clazz_frame else None - ) + type_hints = schema_ctx.get_type_hints(clazz) attributes.update( ( field.name, - field_for_schema( + _field_for_schema( type_hints[field.name], _get_field_default(field), field.metadata, base_schema, - clazz_frame, ), ) for field in fields @@ -433,7 +469,6 @@ def _field_by_supertype( newtype_supertype: Type, metadata: dict, base_schema: Optional[Type[marshmallow.Schema]], - typ_frame: Optional[types.FrameType], ) -> marshmallow.fields.Field: """ Return a new field for fields based on a super field. (Usually spawned from NewType) @@ -459,12 +494,11 @@ def _field_by_supertype( if field: return field(**metadata) else: - return field_for_schema( + return _field_for_schema( newtype_supertype, metadata=metadata, default=default, base_schema=base_schema, - typ_frame=typ_frame, ) @@ -488,7 +522,6 @@ def _generic_type_add_any(typ: type) -> type: def _field_for_generic_type( typ: type, base_schema: Optional[Type[marshmallow.Schema]], - typ_frame: Optional[types.FrameType], **metadata: Any, ) -> Optional[marshmallow.fields.Field]: """ @@ -501,9 +534,7 @@ def _field_for_generic_type( type_mapping = base_schema.TYPE_MAPPING if base_schema else {} if origin in (list, List): - child_type = field_for_schema( - arguments[0], base_schema=base_schema, typ_frame=typ_frame - ) + child_type = _field_for_schema(arguments[0], base_schema=base_schema) list_type = cast( Type[marshmallow.fields.List], type_mapping.get(List, marshmallow.fields.List), @@ -516,32 +547,25 @@ def _field_for_generic_type( ): from . import collection_field - child_type = field_for_schema( - arguments[0], base_schema=base_schema, typ_frame=typ_frame - ) + child_type = _field_for_schema(arguments[0], base_schema=base_schema) return collection_field.Sequence(cls_or_instance=child_type, **metadata) if origin in (set, Set): from . import collection_field - child_type = field_for_schema( - arguments[0], base_schema=base_schema, typ_frame=typ_frame - ) + child_type = _field_for_schema(arguments[0], base_schema=base_schema) return collection_field.Set( cls_or_instance=child_type, frozen=False, **metadata ) if origin in (frozenset, FrozenSet): from . import collection_field - child_type = field_for_schema( - arguments[0], base_schema=base_schema, typ_frame=typ_frame - ) + child_type = _field_for_schema(arguments[0], base_schema=base_schema) return collection_field.Set( cls_or_instance=child_type, frozen=True, **metadata ) if origin in (tuple, Tuple): children = tuple( - field_for_schema(arg, base_schema=base_schema, typ_frame=typ_frame) - for arg in arguments + _field_for_schema(arg, base_schema=base_schema) for arg in arguments ) tuple_type = cast( Type[marshmallow.fields.Tuple], @@ -553,14 +577,11 @@ def _field_for_generic_type( elif origin in (dict, Dict, collections.abc.Mapping, Mapping): dict_type = type_mapping.get(Dict, marshmallow.fields.Dict) return dict_type( - keys=field_for_schema( - arguments[0], base_schema=base_schema, typ_frame=typ_frame - ), - values=field_for_schema( - arguments[1], base_schema=base_schema, typ_frame=typ_frame - ), + keys=_field_for_schema(arguments[0], base_schema=base_schema), + values=_field_for_schema(arguments[1], base_schema=base_schema), **metadata, ) + if typing_inspect.is_union_type(typ): if typing_inspect.is_optional_type(typ): metadata["allow_none"] = metadata.get("allow_none", True) @@ -570,11 +591,10 @@ def _field_for_generic_type( metadata.setdefault("required", False) subtypes = [t for t in arguments if t is not NoneType] # type: ignore if len(subtypes) == 1: - return field_for_schema( + return _field_for_schema( subtypes[0], metadata=metadata, base_schema=base_schema, - typ_frame=typ_frame, ) from . import union_field @@ -582,11 +602,10 @@ def _field_for_generic_type( [ ( subtyp, - field_for_schema( + _field_for_schema( subtyp, metadata={"required": True}, base_schema=base_schema, - typ_frame=typ_frame, ), ) for subtyp in subtypes @@ -598,7 +617,7 @@ def _field_for_generic_type( def field_for_schema( typ: type, - default=marshmallow.missing, + default: Any = marshmallow.missing, metadata: Optional[Mapping[str, Any]] = None, base_schema: Optional[Type[marshmallow.Schema]] = None, typ_frame: Optional[types.FrameType] = None, @@ -622,6 +641,29 @@ def field_for_schema( >>> field_for_schema(str, metadata={"marshmallow_field": marshmallow.fields.Url()}).__class__ + """ + with _SchemaContext(typ_frame): + return _field_for_schema(typ, default, metadata, base_schema) + + +def _field_for_schema( + typ: type, + default: Any = marshmallow.missing, + metadata: Optional[Mapping[str, Any]] = None, + base_schema: Optional[Type[marshmallow.Schema]] = None, +) -> marshmallow.fields.Field: + """ + Get a marshmallow Field corresponding to the given python type. + The metadata of the dataclass field is used as arguments to the marshmallow Field. + + This is an internal version of field_for_schema. It assumes a _SchemaContext + has been pushed onto the local stack. + + :param typ: The type for which a field should be generated + :param default: value to use for (de)serialization when the field is missing + :param metadata: Additional parameters to pass to the marshmallow field constructor + :param base_schema: marshmallow schema used as a base class when deriving dataclass schema + """ metadata = {} if metadata is None else dict(metadata) @@ -694,10 +736,10 @@ def field_for_schema( ) else: subtyp = Any - return field_for_schema(subtyp, default, metadata, base_schema, typ_frame) + return _field_for_schema(subtyp, default, metadata, base_schema) # Generic types - generic_field = _field_for_generic_type(typ, base_schema, typ_frame, **metadata) + generic_field = _field_for_generic_type(typ, base_schema, **metadata) if generic_field: return generic_field @@ -711,7 +753,6 @@ def field_for_schema( newtype_supertype=newtype_supertype, metadata=metadata, base_schema=base_schema, - typ_frame=typ_frame, ) # enumerations @@ -734,8 +775,8 @@ def field_for_schema( nested = ( nested_schema or forward_reference - or _RECURSION_GUARD.seen_classes.get(typ) - or _internal_class_schema(typ, base_schema, typ_frame) # type: ignore [arg-type] + or _schema_ctx_stack.top.seen_classes.get(typ) + or _internal_class_schema(typ, base_schema) # type: ignore[arg-type] # FIXME ) return marshmallow.fields.Nested(nested, **metadata) From 6416dc911ebe4c8366a54f1caca997429fe99bf7 Mon Sep 17 00:00:00 2001 From: Jeff Dairiki Date: Sat, 24 Sep 2022 21:00:00 -0700 Subject: [PATCH 03/32] Optimization: avoid holding frame reference when locals == globals --- marshmallow_dataclass/__init__.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index e856888..2015534 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -37,6 +37,7 @@ class User: import collections.abc import dataclasses import inspect +import sys import threading import types import warnings @@ -208,9 +209,24 @@ def add_schema(_cls=None, base_schema=None, cls_frame=None): """ def decorator(clazz: Type[_U]) -> Type[_U]: + cls_frame_ = cls_frame + if cls_frame is not None: + cls_globals = getattr(sys.modules.get(clazz.__module__), "__dict__", None) + if cls_frame.f_locals is cls_globals: + # Memory optimization: + # If the caller's locals are the same as the class + # module globals, we don't need the locals. (This is + # typically the case for dataclasses defined at the + # module top-level.) (Typing.get_type_hints() knows + # how to check the class module globals on its own.) + # Not holding a reference to the frame in our our lazy + # class attribute which is a significant win, + # memory-wise. + cls_frame_ = None + # noinspection PyTypeHints clazz.Schema = lazy_class_attribute( # type: ignore - partial(class_schema, clazz, base_schema, cls_frame), + partial(class_schema, clazz, base_schema, cls_frame_), "Schema", clazz.__name__, ) From fd04f8c8dbefb230decafaa25addee8d59e454b0 Mon Sep 17 00:00:00 2001 From: Jeff Dairiki Date: Sun, 25 Sep 2022 13:03:27 -0700 Subject: [PATCH 04/32] Get caller frame at decoration-time MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Here we are more careful about which caller's locals we use to resolve forward type references. We want the callers locals at decoration-time — not at decorator-construction time. Consider: ```py frozen_dataclass = marshmallow_dataclass.dataclass(frozen=True) def f(): @custom_dataclass class A: b: "B" @custom_dataclass class B: x: int ``` The locals we want in this case are the one from where the custom_dataclass decorator is called, not from where marshmallow_dataclass.dataclass is called. --- marshmallow_dataclass/__init__.py | 100 ++++++++++++++++++++---------- tests/test_forward_references.py | 16 +++++ 2 files changed, 83 insertions(+), 33 deletions(-) diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index 2015534..fa12546 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -82,6 +82,51 @@ class User: MAX_CLASS_SCHEMA_CACHE_SIZE = 1024 +def _maybe_get_callers_frame( + cls: type, stacklevel: int = 1 +) -> Optional[types.FrameType]: + """Return the caller's frame, but only if it will help resolve forward type references. + + We sometimes need the caller's frame to get access to the caller's + local namespace in order to be able to resolve forward type + references in dataclasses. + + Notes + ----- + + If the caller's locals are the same as the dataclass' module + globals — this is the case for the common case of dataclasses + defined at the module top-level — we don't need the locals. + (Typing.get_type_hints() knows how to check the class module + globals on its own.) + + In that case, we don't need the caller's frame. Not holding a + reference to the frame in our our lazy ``.Scheme`` class attribute + is a significant win, memory-wise. + + """ + try: + frame = inspect.currentframe() + for _ in range(stacklevel + 1): + if frame is None: + return None + frame = frame.f_back + + if frame is None: + return None + + globalns = getattr(sys.modules.get(cls.__module__), "__dict__", None) + if frame.f_locals is globalns: + # Locals are the globals + return None + + return frame + + finally: + # Paranoia, per https://docs.python.org/3/library/inspect.html#the-interpreter-stack + del frame + + @overload def dataclass( _cls: Type[_U], @@ -124,6 +169,7 @@ def dataclass( frozen: bool = False, base_schema: Optional[Type[marshmallow.Schema]] = None, cls_frame: Optional[types.FrameType] = None, + stacklevel: int = 1, ) -> Union[Type[_U], Callable[[Type[_U]], Type[_U]]]: """ This decorator does the same as dataclasses.dataclass, but also applies :func:`add_schema`. @@ -150,19 +196,18 @@ def dataclass( >>> Point.Schema().load({'x':0, 'y':0}) # This line can be statically type checked Point(x=0.0, y=0.0) """ - # dataclass's typing doesn't expect it to be called as a function, so ignore type check - dc = dataclasses.dataclass( # type: ignore - _cls, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen + dc = dataclasses.dataclass( + repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen ) - if not cls_frame: - current_frame = inspect.currentframe() - if current_frame: - cls_frame = current_frame.f_back - # Per https://docs.python.org/3/library/inspect.html#the-interpreter-stack - del current_frame + + def decorator(cls: Type[_U], stacklevel: int = 1) -> Type[_U]: + return add_schema( + dc(cls), base_schema, cls_frame=cls_frame, stacklevel=stacklevel + 1 + ) + if _cls is None: - return lambda cls: add_schema(dc(cls), base_schema, cls_frame=cls_frame) - return add_schema(dc, base_schema, cls_frame=cls_frame) + return decorator + return decorator(_cls, stacklevel=stacklevel + 1) @overload @@ -182,11 +227,12 @@ def add_schema( _cls: Type[_U], base_schema: Optional[Type[marshmallow.Schema]] = None, cls_frame: Optional[types.FrameType] = None, + stacklevel: int = 1, ) -> Type[_U]: ... -def add_schema(_cls=None, base_schema=None, cls_frame=None): +def add_schema(_cls=None, base_schema=None, cls_frame=None, stacklevel=1): """ This decorator adds a marshmallow schema as the 'Schema' attribute in a dataclass. It uses :func:`class_schema` internally. @@ -208,31 +254,23 @@ def add_schema(_cls=None, base_schema=None, cls_frame=None): Artist(names=('Martin', 'Ramirez')) """ - def decorator(clazz: Type[_U]) -> Type[_U]: - cls_frame_ = cls_frame + def decorator(clazz: Type[_U], stacklevel: int = stacklevel) -> Type[_U]: if cls_frame is not None: - cls_globals = getattr(sys.modules.get(clazz.__module__), "__dict__", None) - if cls_frame.f_locals is cls_globals: - # Memory optimization: - # If the caller's locals are the same as the class - # module globals, we don't need the locals. (This is - # typically the case for dataclasses defined at the - # module top-level.) (Typing.get_type_hints() knows - # how to check the class module globals on its own.) - # Not holding a reference to the frame in our our lazy - # class attribute which is a significant win, - # memory-wise. - cls_frame_ = None + frame = cls_frame + else: + frame = _maybe_get_callers_frame(clazz, stacklevel=stacklevel) # noinspection PyTypeHints clazz.Schema = lazy_class_attribute( # type: ignore - partial(class_schema, clazz, base_schema, cls_frame_), + partial(class_schema, clazz, base_schema, frame), "Schema", clazz.__name__, ) return clazz - return decorator(_cls) if _cls else decorator + if _cls is None: + return decorator + return decorator(_cls, stacklevel=stacklevel + 1) def class_schema( @@ -361,11 +399,7 @@ def class_schema( if not dataclasses.is_dataclass(clazz): clazz = dataclasses.dataclass(clazz) if not clazz_frame: - current_frame = inspect.currentframe() - if current_frame: - clazz_frame = current_frame.f_back - # Per https://docs.python.org/3/library/inspect.html#the-interpreter-stack - del current_frame + clazz_frame = _maybe_get_callers_frame(clazz) with _SchemaContext(clazz_frame): return _internal_class_schema(clazz, base_schema) diff --git a/tests/test_forward_references.py b/tests/test_forward_references.py index fc05b12..2a2fa96 100644 --- a/tests/test_forward_references.py +++ b/tests/test_forward_references.py @@ -133,3 +133,19 @@ class B: B.Schema().load(dict(a=dict(c=1))) # marshmallow.exceptions.ValidationError: # {'a': {'d': ['Missing data for required field.'], 'c': ['Unknown field.']}} + + def test_locals_from_decoration_ns(self): + # Test that locals are picked-up at decoration-time rather + # than when the decorator is constructed. + @frozen_dataclass + class A: + b: "B" + + @frozen_dataclass + class B: + x: int + + assert A.Schema().load({"b": {"x": 42}}) == A(b=B(x=42)) + + +frozen_dataclass = dataclass(frozen=True) From 9093446437c448b0ebb09c960994c162efc999f2 Mon Sep 17 00:00:00 2001 From: Jeff Dairiki Date: Sun, 25 Sep 2022 13:44:28 -0700 Subject: [PATCH 05/32] Add ability to pass explicit localns (and globalns) to class_schema MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When class_schema is called, it doesn't need the caller's whole stack frame. What it really wants is a `localns` to pass to `typing.get_type_hints` to be used to resolve type references. Here we add the ability to pass an explicit `localns` parameter to `class_schema`. We also add the ability to pass an explicit `globalns`, because ... might as well — it might come in useful. (Since we need these only to pass to `get_type_hints`, we might as well match `get_type_hints` API as closely as possible.) --- marshmallow_dataclass/__init__.py | 57 ++++++++++++++++++++++++------- 1 file changed, 44 insertions(+), 13 deletions(-) diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index fa12546..cf94172 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -273,10 +273,36 @@ def decorator(clazz: Type[_U], stacklevel: int = stacklevel) -> Type[_U]: return decorator(_cls, stacklevel=stacklevel + 1) +@overload +def class_schema( + clazz: type, + base_schema: Optional[Type[marshmallow.Schema]] = None, + *, + globalns: Optional[Dict[str, Any]] = None, + localns: Optional[Dict[str, Any]] = None, +) -> Type[marshmallow.Schema]: + ... + + +@overload +def class_schema( + clazz: type, + base_schema: Optional[Type[marshmallow.Schema]] = None, + clazz_frame: Optional[types.FrameType] = None, + *, + globalns: Optional[Dict[str, Any]] = None, +) -> Type[marshmallow.Schema]: + ... + + def class_schema( clazz: type, base_schema: Optional[Type[marshmallow.Schema]] = None, + # FIXME: delete clazz_frame from API? clazz_frame: Optional[types.FrameType] = None, + *, + globalns: Optional[Dict[str, Any]] = None, + localns: Optional[Dict[str, Any]] = None, ) -> Type[marshmallow.Schema]: """ Convert a class to a marshmallow schema @@ -398,24 +424,26 @@ def class_schema( """ if not dataclasses.is_dataclass(clazz): clazz = dataclasses.dataclass(clazz) - if not clazz_frame: - clazz_frame = _maybe_get_callers_frame(clazz) - - with _SchemaContext(clazz_frame): + if localns is None: + if clazz_frame is None: + clazz_frame = _maybe_get_callers_frame(clazz) + if clazz_frame is not None: + localns = clazz_frame.f_locals + with _SchemaContext(globalns, localns): return _internal_class_schema(clazz, base_schema) class _SchemaContext: """Global context for an invocation of class_schema.""" - def __init__(self, frame: Optional[types.FrameType]): + def __init__( + self, + globalns: Optional[Dict[str, Any]] = None, + localns: Optional[Dict[str, Any]] = None, + ): self.seen_classes: Dict[type, str] = {} - self.frame = frame - - def get_type_hints(self, cls: Type) -> Dict[str, Any]: - frame = self.frame - localns = frame.f_locals if frame is not None else None - return get_type_hints(cls, localns=localns) + self.globalns = globalns + self.localns = localns def __enter__(self) -> "_SchemaContext": _schema_ctx_stack.push(self) @@ -486,7 +514,9 @@ def _internal_class_schema( } # Update the schema members to contain marshmallow fields instead of dataclass fields - type_hints = schema_ctx.get_type_hints(clazz) + type_hints = get_type_hints( + clazz, globalns=schema_ctx.globalns, localns=schema_ctx.localns + ) attributes.update( ( field.name, @@ -670,6 +700,7 @@ def field_for_schema( default: Any = marshmallow.missing, metadata: Optional[Mapping[str, Any]] = None, base_schema: Optional[Type[marshmallow.Schema]] = None, + # FIXME: delete typ_frame from API? typ_frame: Optional[types.FrameType] = None, ) -> marshmallow.fields.Field: """ @@ -692,7 +723,7 @@ def field_for_schema( >>> field_for_schema(str, metadata={"marshmallow_field": marshmallow.fields.Url()}).__class__ """ - with _SchemaContext(typ_frame): + with _SchemaContext(localns=typ_frame.f_locals if typ_frame is not None else None): return _field_for_schema(typ, default, metadata, base_schema) From e693bb0463979babcad4452b003d055ced0d750a Mon Sep 17 00:00:00 2001 From: Jeff Dairiki Date: Thu, 19 Jan 2023 12:39:39 -0800 Subject: [PATCH 06/32] test(mypy): move mypy config from setup.cfg to pyproject.toml --- .pre-commit-config.yaml | 6 +++++- pyproject.toml | 28 ++++++++++++++++++++++++++++ setup.cfg | 3 --- 3 files changed, 33 insertions(+), 4 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index dca4a3b..2cc2293 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,7 +19,11 @@ repos: rev: v0.991 hooks: - id: mypy - additional_dependencies: [marshmallow-enum,typeguard,marshmallow] + additional_dependencies: + - marshmallow + - marshmallow-enum + - typeguard + - types-setuptools args: [--show-error-codes] - repo: https://github.com/asottile/blacken-docs rev: v1.12.1 diff --git a/pyproject.toml b/pyproject.toml index 4112bd1..54631c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,3 +6,31 @@ target-version = ['py36', 'py37', 'py38', 'py39', 'py310', 'py310'] filterwarnings = [ "error:::marshmallow_dataclass|test", ] + +[tool.mypy] +packages = [ + "marshmallow_dataclass", + "tests", +] +# XXX: Can not load our plugin when running mypy from pre-commit. +# Pre-commit runs mypy from its own venv (into which we do not want +# to install marshmallow_dataclass). +# The fact that our plugin is in a file named "mypy.py" causes issues +# (I think) if we try to load it by path. In that case mypy adds +# the containing directory to sys.path then calls import_module("mypy"), +# which, in turn, finds the already imported sys.modules['mypy']. +# +# plugins = "marshmallow_dataclass.mypy" + +warn_redundant_casts = true +warn_unused_configs = true +disable_error_code = "annotation-unchecked" + +[[tool.mypy.overrides]] +# dependencies without type hints +module = [ + "marshmallow_enum", + "typing_inspect", +] +ignore_missing_imports = true + diff --git a/setup.cfg b/setup.cfg index 1216d89..d5a376b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -3,6 +3,3 @@ ignore = E203, E266, E501, W503 max-line-length = 100 max-complexity = 18 select = B,C,E,F,W,T4,B9 - -[mypy] -ignore_missing_imports = true From 075b3a7e58965d0037c7e167491cb535e05cc1c1 Mon Sep 17 00:00:00 2001 From: Jeff Dairiki Date: Thu, 19 Jan 2023 17:38:10 -0800 Subject: [PATCH 07/32] fix: mypy plugin loading under pre-commit --- .pre-commit-config.yaml | 1 + mypy_plugin.py | 64 +++++++++++++++++++++++++++++++++++++++++ pyproject.toml | 13 +++------ 3 files changed, 69 insertions(+), 9 deletions(-) create mode 100644 mypy_plugin.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2cc2293..cc19292 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -24,6 +24,7 @@ repos: - marshmallow-enum - typeguard - types-setuptools + - typing-inspect args: [--show-error-codes] - repo: https://github.com/asottile/blacken-docs rev: v1.12.1 diff --git a/mypy_plugin.py b/mypy_plugin.py new file mode 100644 index 0000000..f76a199 --- /dev/null +++ b/mypy_plugin.py @@ -0,0 +1,64 @@ +"""Shim to load the marshmallow_dataclass.mypy plugin. + +This shim is needed when running mypy from pre-commit. + +Pre-commit runs mypy from its own venv (into which we do not want +to install marshmallow_dataclass). Because of this, loading the plugin +by module name, e.g. + + [tool.mypy] + plugins = "marshmallow_dataclass.mypy" + +does not work. Mypy also supports specifying a path to the plugin +module source, which would normally get us out of this bind, however, +the fact that our plugin is in a file named "mypy.py" causes issues. + +If we set + + [tool.mypy] + plugins = "marshmallow_dataclass/mypy.py" + +mypy `attempts to load`__ the plugin module by temporarily prepending + ``marshmallow_dataclass`` to ``sys.path`` then importing the ``mypy`` +module. Sadly, mypy's ``mypy`` module has already been imported, +so this doesn't end well. + +__ https://github.com/python/mypy/blob/914901f14e0e6223077a8433388c367138717451/mypy/build.py#L450 + + +Our solution, here, is to manually load the plugin module (with a better +``sys.path``, and import the ``plugin`` from the real plugin module into this one. + +Now we can configure mypy to load this file, by path. + + [tool.mypy] + plugins = "mypy_plugin.py" + +""" +import importlib +import sys +from os import fspath +from pathlib import Path +from typing import Type +from warnings import warn + +from mypy.plugin import Plugin + + +def null_plugin(version: str) -> Type[Plugin]: + """A fallback do-nothing plugin hook""" + return Plugin + + +module_name = "marshmallow_dataclass.mypy" + +src = fspath(Path(__file__).parent) +sys.path.insert(0, src) +try: + plugin_module = importlib.import_module(module_name) + plugin = plugin_module.plugin +except Exception as exc: + warn(f"can not load {module_name} plugin: {exc}") + plugin = null_plugin +finally: + del sys.path[0] diff --git a/pyproject.toml b/pyproject.toml index 54631c9..5bbb9ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,15 +12,10 @@ packages = [ "marshmallow_dataclass", "tests", ] -# XXX: Can not load our plugin when running mypy from pre-commit. -# Pre-commit runs mypy from its own venv (into which we do not want -# to install marshmallow_dataclass). -# The fact that our plugin is in a file named "mypy.py" causes issues -# (I think) if we try to load it by path. In that case mypy adds -# the containing directory to sys.path then calls import_module("mypy"), -# which, in turn, finds the already imported sys.modules['mypy']. -# -# plugins = "marshmallow_dataclass.mypy" +# XXX: Specifying the marshmallow_dataclass.mypy plugin directly by +# module name or by path does not work when running mypy from pre-commit. +# (See the docstring in mypy_plugin.py for more.) +plugins = "mypy_plugin.py" warn_redundant_casts = true warn_unused_configs = true From 4312042b042ea56ef2ef828306b7a2c189e5a4cd Mon Sep 17 00:00:00 2001 From: Jeff Dairiki Date: Mon, 16 Jan 2023 10:05:29 -0800 Subject: [PATCH 08/32] refactor: convert _SchemaContext to dataclass --- marshmallow_dataclass/__init__.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index cf94172..7c99456 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -433,17 +433,13 @@ def class_schema( return _internal_class_schema(clazz, base_schema) +@dataclasses.dataclass class _SchemaContext: """Global context for an invocation of class_schema.""" - def __init__( - self, - globalns: Optional[Dict[str, Any]] = None, - localns: Optional[Dict[str, Any]] = None, - ): - self.seen_classes: Dict[type, str] = {} - self.globalns = globalns - self.localns = localns + globalns: Optional[Dict[str, Any]] = None + localns: Optional[Dict[str, Any]] = None + seen_classes: Dict[type, str] = dataclasses.field(default_factory=dict) def __enter__(self) -> "_SchemaContext": _schema_ctx_stack.push(self) From 0f113c5d9d3dacc37b34ccd16e0223f10238dc3c Mon Sep 17 00:00:00 2001 From: Jeff Dairiki Date: Mon, 16 Jan 2023 10:29:24 -0800 Subject: [PATCH 09/32] refactor: move base_schema into _SchemaContext --- marshmallow_dataclass/__init__.py | 68 ++++++++++++++----------------- 1 file changed, 31 insertions(+), 37 deletions(-) diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index 7c99456..b872896 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -439,6 +439,7 @@ class _SchemaContext: globalns: Optional[Dict[str, Any]] = None localns: Optional[Dict[str, Any]] = None + base_schema: Optional[Type[marshmallow.Schema]] = None seen_classes: Dict[type, str] = dataclasses.field(default_factory=dict) def __enter__(self) -> "_SchemaContext": @@ -513,27 +514,27 @@ def _internal_class_schema( type_hints = get_type_hints( clazz, globalns=schema_ctx.globalns, localns=schema_ctx.localns ) - attributes.update( - ( - field.name, - _field_for_schema( - type_hints[field.name], - _get_field_default(field), - field.metadata, - base_schema, - ), + with dataclasses.replace(schema_ctx, base_schema=base_schema): + attributes.update( + ( + field.name, + _field_for_schema( + type_hints[field.name], + _get_field_default(field), + field.metadata, + ), + ) + for field in fields + if field.init ) - for field in fields - if field.init - ) schema_class = type(clazz.__name__, (_base_schema(clazz, base_schema),), attributes) return cast(Type[marshmallow.Schema], schema_class) -def _field_by_type( - typ: Union[type, Any], base_schema: Optional[Type[marshmallow.Schema]] -) -> Optional[Type[marshmallow.fields.Field]]: +def _field_by_type(typ: Union[type, Any]) -> Optional[Type[marshmallow.fields.Field]]: + # FIXME: remove this function + base_schema = _schema_ctx_stack.top.base_schema return ( base_schema and base_schema.TYPE_MAPPING.get(typ) ) or marshmallow.Schema.TYPE_MAPPING.get(typ) @@ -544,7 +545,6 @@ def _field_by_supertype( default: Any, newtype_supertype: Type, metadata: dict, - base_schema: Optional[Type[marshmallow.Schema]], ) -> marshmallow.fields.Field: """ Return a new field for fields based on a super field. (Usually spawned from NewType) @@ -574,7 +574,6 @@ def _field_by_supertype( newtype_supertype, metadata=metadata, default=default, - base_schema=base_schema, ) @@ -597,7 +596,6 @@ def _generic_type_add_any(typ: type) -> type: def _field_for_generic_type( typ: type, - base_schema: Optional[Type[marshmallow.Schema]], **metadata: Any, ) -> Optional[marshmallow.fields.Field]: """ @@ -607,10 +605,11 @@ def _field_for_generic_type( arguments = typing_inspect.get_args(typ, True) if origin: # Override base_schema.TYPE_MAPPING to change the class used for generic types below + base_schema = _schema_ctx_stack.top.base_schema type_mapping = base_schema.TYPE_MAPPING if base_schema else {} if origin in (list, List): - child_type = _field_for_schema(arguments[0], base_schema=base_schema) + child_type = _field_for_schema(arguments[0]) list_type = cast( Type[marshmallow.fields.List], type_mapping.get(List, marshmallow.fields.List), @@ -623,26 +622,24 @@ def _field_for_generic_type( ): from . import collection_field - child_type = _field_for_schema(arguments[0], base_schema=base_schema) + child_type = _field_for_schema(arguments[0]) return collection_field.Sequence(cls_or_instance=child_type, **metadata) if origin in (set, Set): from . import collection_field - child_type = _field_for_schema(arguments[0], base_schema=base_schema) + child_type = _field_for_schema(arguments[0]) return collection_field.Set( cls_or_instance=child_type, frozen=False, **metadata ) if origin in (frozenset, FrozenSet): from . import collection_field - child_type = _field_for_schema(arguments[0], base_schema=base_schema) + child_type = _field_for_schema(arguments[0]) return collection_field.Set( cls_or_instance=child_type, frozen=True, **metadata ) if origin in (tuple, Tuple): - children = tuple( - _field_for_schema(arg, base_schema=base_schema) for arg in arguments - ) + children = tuple(_field_for_schema(arg) for arg in arguments) tuple_type = cast( Type[marshmallow.fields.Tuple], type_mapping.get( # type:ignore[call-overload] @@ -653,8 +650,8 @@ def _field_for_generic_type( elif origin in (dict, Dict, collections.abc.Mapping, Mapping): dict_type = type_mapping.get(Dict, marshmallow.fields.Dict) return dict_type( - keys=_field_for_schema(arguments[0], base_schema=base_schema), - values=_field_for_schema(arguments[1], base_schema=base_schema), + keys=_field_for_schema(arguments[0]), + values=_field_for_schema(arguments[1]), **metadata, ) @@ -670,7 +667,6 @@ def _field_for_generic_type( return _field_for_schema( subtypes[0], metadata=metadata, - base_schema=base_schema, ) from . import union_field @@ -681,7 +677,6 @@ def _field_for_generic_type( _field_for_schema( subtyp, metadata={"required": True}, - base_schema=base_schema, ), ) for subtyp in subtypes @@ -719,15 +714,15 @@ def field_for_schema( >>> field_for_schema(str, metadata={"marshmallow_field": marshmallow.fields.Url()}).__class__ """ - with _SchemaContext(localns=typ_frame.f_locals if typ_frame is not None else None): - return _field_for_schema(typ, default, metadata, base_schema) + localns = typ_frame.f_locals if typ_frame is not None else None + with _SchemaContext(localns=localns, base_schema=base_schema): + return _field_for_schema(typ, default, metadata) def _field_for_schema( typ: type, default: Any = marshmallow.missing, metadata: Optional[Mapping[str, Any]] = None, - base_schema: Optional[Type[marshmallow.Schema]] = None, ) -> marshmallow.fields.Field: """ Get a marshmallow Field corresponding to the given python type. @@ -739,7 +734,6 @@ def _field_for_schema( :param typ: The type for which a field should be generated :param default: value to use for (de)serialization when the field is missing :param metadata: Additional parameters to pass to the marshmallow field constructor - :param base_schema: marshmallow schema used as a base class when deriving dataclass schema """ @@ -762,7 +756,7 @@ def _field_for_schema( typ = _generic_type_add_any(typ) # Base types - field = _field_by_type(typ, base_schema) + field = _field_by_type(typ) if field: return field(**metadata) @@ -813,10 +807,10 @@ def _field_for_schema( ) else: subtyp = Any - return _field_for_schema(subtyp, default, metadata, base_schema) + return _field_for_schema(subtyp, default, metadata) # Generic types - generic_field = _field_for_generic_type(typ, base_schema, **metadata) + generic_field = _field_for_generic_type(typ, **metadata) if generic_field: return generic_field @@ -829,7 +823,6 @@ def _field_for_schema( default=default, newtype_supertype=newtype_supertype, metadata=metadata, - base_schema=base_schema, ) # enumerations @@ -849,6 +842,7 @@ def _field_for_schema( # Nested dataclasses forward_reference = getattr(typ, "__forward_arg__", None) + base_schema = _schema_ctx_stack.top.base_schema nested = ( nested_schema or forward_reference From d95e171a849b51b1d96b86da2abc0452d4d92032 Mon Sep 17 00:00:00 2001 From: Jeff Dairiki Date: Mon, 16 Jan 2023 10:35:50 -0800 Subject: [PATCH 10/32] refactor: _field_for_generic_type: pass metadata without splat --- marshmallow_dataclass/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index b872896..e57e7ba 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -544,7 +544,7 @@ def _field_by_supertype( typ: Type, default: Any, newtype_supertype: Type, - metadata: dict, + metadata: Dict[str, Any], ) -> marshmallow.fields.Field: """ Return a new field for fields based on a super field. (Usually spawned from NewType) @@ -596,7 +596,7 @@ def _generic_type_add_any(typ: type) -> type: def _field_for_generic_type( typ: type, - **metadata: Any, + metadata: Dict[str, Any], ) -> Optional[marshmallow.fields.Field]: """ If the type is a generic interface, resolve the arguments and construct the appropriate Field. @@ -810,7 +810,7 @@ def _field_for_schema( return _field_for_schema(subtyp, default, metadata) # Generic types - generic_field = _field_for_generic_type(typ, **metadata) + generic_field = _field_for_generic_type(typ, metadata) if generic_field: return generic_field From 01c71daab1585e346cd7f79ed5bbd8eef5a35f56 Mon Sep 17 00:00:00 2001 From: Jeff Dairiki Date: Wed, 18 Jan 2023 09:45:55 -0800 Subject: [PATCH 11/32] refactor: add _SchemaContext.get_type_mapping --- marshmallow_dataclass/__init__.py | 48 ++++++++++++++++++++----------- 1 file changed, 31 insertions(+), 17 deletions(-) diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index e57e7ba..54666d4 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -46,6 +46,7 @@ class User: from typing import ( Any, Callable, + ChainMap, Dict, Generic, List, @@ -74,6 +75,8 @@ class User: NoneType = type(None) _U = TypeVar("_U") +_Field = TypeVar("_Field", bound=marshmallow.fields.Field) + # Whitelist of dataclass members that will be copied to generated schema. MEMBERS_WHITELIST: Set[str] = {"Meta"} @@ -442,6 +445,23 @@ class _SchemaContext: base_schema: Optional[Type[marshmallow.Schema]] = None seen_classes: Dict[type, str] = dataclasses.field(default_factory=dict) + def get_type_mapping( + self, use_mro: bool = False + ) -> Mapping[Any, Type[marshmallow.fields.Field]]: + """Get base_schema.TYPE_MAPPING. + + If use_mro is true, then merges the TYPE_MAPPINGs from + all bases in base_schema's MRO. + """ + base_schema = self.base_schema + if base_schema is None: + base_schema = marshmallow.Schema + if use_mro: + return ChainMap( + *(getattr(cls, "TYPE_MAPPING", {}) for cls in base_schema.__mro__) + ) + return getattr(base_schema, "TYPE_MAPPING", {}) + def __enter__(self) -> "_SchemaContext": _schema_ctx_stack.push(self) return self @@ -534,10 +554,9 @@ def _internal_class_schema( def _field_by_type(typ: Union[type, Any]) -> Optional[Type[marshmallow.fields.Field]]: # FIXME: remove this function - base_schema = _schema_ctx_stack.top.base_schema - return ( - base_schema and base_schema.TYPE_MAPPING.get(typ) - ) or marshmallow.Schema.TYPE_MAPPING.get(typ) + schema_ctx = _schema_ctx_stack.top + type_mapping = schema_ctx.get_type_mapping(use_mro=True) + return type_mapping.get(typ) def _field_by_supertype( @@ -605,15 +624,15 @@ def _field_for_generic_type( arguments = typing_inspect.get_args(typ, True) if origin: # Override base_schema.TYPE_MAPPING to change the class used for generic types below - base_schema = _schema_ctx_stack.top.base_schema - type_mapping = base_schema.TYPE_MAPPING if base_schema else {} + schema_ctx = _schema_ctx_stack.top + + def get_field_type(type_spec: Any, default: Type[_Field]) -> Type[_Field]: + type_mapping = schema_ctx.get_type_mapping() + return type_mapping.get(type_spec, default) # type: ignore[return-value] if origin in (list, List): child_type = _field_for_schema(arguments[0]) - list_type = cast( - Type[marshmallow.fields.List], - type_mapping.get(List, marshmallow.fields.List), - ) + list_type = get_field_type(List, default=marshmallow.fields.List) return list_type(child_type, **metadata) if origin in (collections.abc.Sequence, Sequence) or ( origin in (tuple, Tuple) @@ -640,15 +659,10 @@ def _field_for_generic_type( ) if origin in (tuple, Tuple): children = tuple(_field_for_schema(arg) for arg in arguments) - tuple_type = cast( - Type[marshmallow.fields.Tuple], - type_mapping.get( # type:ignore[call-overload] - Tuple, marshmallow.fields.Tuple - ), - ) + tuple_type = get_field_type(Tuple, default=marshmallow.fields.Tuple) return tuple_type(children, **metadata) elif origin in (dict, Dict, collections.abc.Mapping, Mapping): - dict_type = type_mapping.get(Dict, marshmallow.fields.Dict) + dict_type = get_field_type(Dict, default=marshmallow.fields.Dict) return dict_type( keys=_field_for_schema(arguments[0]), values=_field_for_schema(arguments[1]), From f5cadbde14a7359cc8d0a1a9184f024940cae42d Mon Sep 17 00:00:00 2001 From: Jeff Dairiki Date: Mon, 16 Jan 2023 10:56:41 -0800 Subject: [PATCH 12/32] refactor: delete _field_by_type --- marshmallow_dataclass/__init__.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index 54666d4..4bf43e0 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -552,13 +552,6 @@ def _internal_class_schema( return cast(Type[marshmallow.Schema], schema_class) -def _field_by_type(typ: Union[type, Any]) -> Optional[Type[marshmallow.fields.Field]]: - # FIXME: remove this function - schema_ctx = _schema_ctx_stack.top - type_mapping = schema_ctx.get_type_mapping(use_mro=True) - return type_mapping.get(typ) - - def _field_by_supertype( typ: Type, default: Any, @@ -769,9 +762,12 @@ def _field_for_schema( # Generic types specified without type arguments typ = _generic_type_add_any(typ) + schema_ctx = _schema_ctx_stack.top + # Base types - field = _field_by_type(typ) - if field: + type_mapping = schema_ctx.get_type_mapping(use_mro=True) + field = type_mapping.get(typ) + if field is not None: return field(**metadata) if typ is Any: @@ -856,12 +852,13 @@ def _field_for_schema( # Nested dataclasses forward_reference = getattr(typ, "__forward_arg__", None) - base_schema = _schema_ctx_stack.top.base_schema nested = ( nested_schema or forward_reference - or _schema_ctx_stack.top.seen_classes.get(typ) - or _internal_class_schema(typ, base_schema) # type: ignore[arg-type] # FIXME + or schema_ctx.seen_classes.get(typ) + or _internal_class_schema( + typ, schema_ctx.base_schema # type: ignore[arg-type] # FIXME + ) ) return marshmallow.fields.Nested(nested, **metadata) From 449ef1ab0d6cfea82549d8b8b71bf510b1022ad4 Mon Sep 17 00:00:00 2001 From: Jeff Dairiki Date: Mon, 16 Jan 2023 13:36:02 -0800 Subject: [PATCH 13/32] feat: support generic dataclasses Manually merge work in PR #172 by @onursatici to support generic dataclasses Refactor _field_for_schema to reduce complexity. --- marshmallow_dataclass/__init__.py | 639 ++++++++++++++++++++---------- setup.py | 5 +- tests/test_class_schema.py | 125 +++++- 3 files changed, 544 insertions(+), 225 deletions(-) diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index 4bf43e0..acabd21 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -49,6 +49,9 @@ class User: ChainMap, Dict, Generic, + Hashable, + Iterable, + Iterator, List, Mapping, NewType as typing_NewType, @@ -58,7 +61,6 @@ class User: Type, TypeVar, Union, - cast, get_type_hints, overload, Sequence, @@ -70,9 +72,51 @@ class User: from marshmallow_dataclass.lazy_class_attribute import lazy_class_attribute - __all__ = ["dataclass", "add_schema", "class_schema", "field_for_schema", "NewType"] + +if sys.version_info >= (3, 8): + from typing import get_args + from typing import get_origin +elif sys.version_info >= (3, 7): + from typing_extensions import get_args + from typing_extensions import get_origin +else: + + def get_args(tp): + return typing_inspect.get_args(tp, evaluate=True) + + def get_origin(tp): + TYPE_MAP = { + List: list, + Sequence: collections.abc.Sequence, + Set: set, + FrozenSet: frozenset, + Tuple: tuple, + Dict: dict, + Mapping: collections.abc.Mapping, + Generic: Generic, + } + + origin = typing_inspect.get_origin(tp) + if origin in TYPE_MAP: + return TYPE_MAP[origin] + elif origin is not tp: + return origin + return None + + +if sys.version_info >= (3, 7): + TypeVar_ = TypeVar +else: + TypeVar_ = type + +if sys.version_info >= (3, 10): + from typing import TypeGuard +else: + from typing_extensions import TypeGuard + + NoneType = type(None) _U = TypeVar("_U") _Field = TypeVar("_Field", bound=marshmallow.fields.Field) @@ -204,8 +248,9 @@ def dataclass( ) def decorator(cls: Type[_U], stacklevel: int = 1) -> Type[_U]: + dc(cls) return add_schema( - dc(cls), base_schema, cls_frame=cls_frame, stacklevel=stacklevel + 1 + cls, base_schema, cls_frame=cls_frame, stacklevel=stacklevel + 1 ) if _cls is None: @@ -299,7 +344,7 @@ def class_schema( def class_schema( - clazz: type, + clazz: type, # FIXME: type | _GenericAlias base_schema: Optional[Type[marshmallow.Schema]] = None, # FIXME: delete clazz_frame from API? clazz_frame: Optional[types.FrameType] = None, @@ -425,15 +470,110 @@ def class_schema( >>> class_schema(Custom)().load({}) Custom(name=None) """ - if not dataclasses.is_dataclass(clazz): - clazz = dataclasses.dataclass(clazz) if localns is None: if clazz_frame is None: clazz_frame = _maybe_get_callers_frame(clazz) if clazz_frame is not None: localns = clazz_frame.f_locals with _SchemaContext(globalns, localns): - return _internal_class_schema(clazz, base_schema) + schema = _internal_class_schema(clazz, base_schema) + + assert not isinstance(schema, _Future) + return schema + + +class InvalidStateError(Exception): + """Raised when an operation is performed on a future that is not + allowed in the current state. + """ + + +class _Future(Generic[_U]): + """The _Future class allows deferred access to a result that is not + yet available. + """ + + _done: bool + _result: _U + + def __init__(self) -> None: + self._done = False + + def done(self) -> bool: + """Return ``True`` if the value is available""" + return self._done + + def result(self) -> _U: + """Return the deferred value. + + Raises ``InvalidStateError`` if the value has not been set. + """ + if self.done(): + return self._result + raise InvalidStateError("result has not been set") + + def set_result(self, result: _U) -> None: + if self.done(): + raise InvalidStateError("result has already been set") + self._result = result + self._done = True + + +TypeSpec = typing_NewType("TypeSpec", object) +GenericAliasOfDataclass = typing_NewType("GenericAliasOfDataclass", object) + + +def _is_generic_alias_of_dataclass( + cls: object, +) -> TypeGuard[GenericAliasOfDataclass]: + """ + Check if given class is a generic alias of a dataclass, if the dataclass is + defined as `class A(Generic[T])`, this method will return true if `A[int]` is passed + """ + return dataclasses.is_dataclass(get_origin(cls)) + + +class _GenericArgs(Mapping[TypeVar_, TypeSpec], collections.abc.Hashable): + """A mapping of TypeVars to type specs""" + + def __init__( + self, + generic_alias: GenericAliasOfDataclass, + binding: Optional["_GenericArgs"] = None, + ): + origin = typing_inspect.get_origin(generic_alias) + parameters: Iterable[TypeVar_] = typing_inspect.get_parameters(origin) + arguments: Iterable[TypeSpec] = get_args(generic_alias) + if binding is not None: + arguments = map(binding.resolve, arguments) + + self._args = dict(zip(parameters, arguments)) + self._hashvalue = hash(tuple(self._args.items())) + + _args: Mapping[TypeVar_, TypeSpec] + _hashvalue: int + + def resolve(self, spec: Union[TypeVar_, TypeSpec]) -> TypeSpec: + if isinstance(spec, TypeVar): + try: + return self._args[spec] + except KeyError as exc: + raise TypeError( + f"generic type variable {spec.__name__} is not bound" + ) from exc + return spec + + def __getitem__(self, param: TypeVar_) -> TypeSpec: + return self._args[param] + + def __iter__(self) -> Iterator[TypeVar_]: + return iter(self._args.keys()) + + def __len__(self) -> int: + return len(self._args) + + def __hash__(self) -> int: + return self._hashvalue @dataclasses.dataclass @@ -443,7 +583,10 @@ class _SchemaContext: globalns: Optional[Dict[str, Any]] = None localns: Optional[Dict[str, Any]] = None base_schema: Optional[Type[marshmallow.Schema]] = None - seen_classes: Dict[type, str] = dataclasses.field(default_factory=dict) + generic_args: Optional[_GenericArgs] = None + seen_classes: Dict[type, _Future[Type[marshmallow.Schema]]] = dataclasses.field( + default_factory=dict + ) def get_type_mapping( self, use_mro: bool = False @@ -497,13 +640,20 @@ def top(self) -> _U: def _internal_class_schema( clazz: type, base_schema: Optional[Type[marshmallow.Schema]] = None, -) -> Type[marshmallow.Schema]: +) -> Union[Type[marshmallow.Schema], _Future[Type[marshmallow.Schema]]]: schema_ctx = _schema_ctx_stack.top - schema_ctx.seen_classes[clazz] = clazz.__name__ - try: - # noinspection PyDataclass - fields: Tuple[dataclasses.Field, ...] = dataclasses.fields(clazz) - except TypeError: # Not a dataclass + if clazz in schema_ctx.seen_classes: + return schema_ctx.seen_classes[clazz] + + future: _Future[Type[marshmallow.Schema]] = _Future() + schema_ctx.seen_classes[clazz] = future + + generic_args = schema_ctx.generic_args + + if _is_generic_alias_of_dataclass(clazz): + generic_args = _GenericArgs(clazz, generic_args) + clazz = typing_inspect.get_origin(clazz) + elif not dataclasses.is_dataclass(clazz): try: warnings.warn( "****** WARNING ****** " @@ -516,13 +666,14 @@ def _internal_class_schema( "https://github.com/lovasoa/marshmallow_dataclass/issues/51 " "****** WARNING ******" ) - created_dataclass: type = dataclasses.dataclass(clazz) - return _internal_class_schema(created_dataclass, base_schema) + dataclasses.dataclass(clazz) except Exception as exc: raise TypeError( f"{getattr(clazz, '__name__', repr(clazz))} is not a dataclass and cannot be turned into one." ) from exc + fields = dataclasses.fields(clazz) + # Copy all marshmallow hooks and whitelisted members of the dataclass to the schema. attributes = { k: v @@ -534,7 +685,9 @@ def _internal_class_schema( type_hints = get_type_hints( clazz, globalns=schema_ctx.globalns, localns=schema_ctx.localns ) - with dataclasses.replace(schema_ctx, base_schema=base_schema): + with dataclasses.replace( + schema_ctx, base_schema=base_schema, generic_args=generic_args + ): attributes.update( ( field.name, @@ -548,48 +701,14 @@ def _internal_class_schema( if field.init ) - schema_class = type(clazz.__name__, (_base_schema(clazz, base_schema),), attributes) - return cast(Type[marshmallow.Schema], schema_class) - - -def _field_by_supertype( - typ: Type, - default: Any, - newtype_supertype: Type, - metadata: Dict[str, Any], -) -> marshmallow.fields.Field: - """ - Return a new field for fields based on a super field. (Usually spawned from NewType) - """ - # Add the information coming our custom NewType implementation - - typ_args = getattr(typ, "_marshmallow_args", {}) - - # Handle multiple validators from both `typ` and `metadata`. - # See https://github.com/lovasoa/marshmallow_dataclass/issues/91 - new_validators: List[Callable] = [] - for meta_dict in (typ_args, metadata): - if "validate" in meta_dict: - if marshmallow.utils.is_iterable_but_not_string(meta_dict["validate"]): - new_validators.extend(meta_dict["validate"]) - elif callable(meta_dict["validate"]): - new_validators.append(meta_dict["validate"]) - metadata["validate"] = new_validators if new_validators else None - - metadata = {**typ_args, **metadata} - metadata.setdefault("metadata", {}).setdefault("description", typ.__name__) - field = getattr(typ, "_marshmallow_field", None) - if field: - return field(**metadata) - else: - return _field_for_schema( - newtype_supertype, - metadata=metadata, - default=default, - ) + schema_class: Type[marshmallow.Schema] = type( + clazz.__name__, (_base_schema(clazz, base_schema),), attributes + ) + future.set_result(schema_class) + return schema_class -def _generic_type_add_any(typ: type) -> type: +def _generic_type_add_any(typ: type) -> type: # FIXME: signature is wrong """if typ is generic type without arguments, replace them by Any.""" if typ is list or typ is List: typ = List[Any] @@ -606,91 +725,231 @@ def _generic_type_add_any(typ: type) -> type: return typ -def _field_for_generic_type( - typ: type, - metadata: Dict[str, Any], -) -> Optional[marshmallow.fields.Field]: +def _is_builtin_collection_type(typ: object) -> bool: + return get_origin(typ) in { + list, + collections.abc.Sequence, + set, + frozenset, + tuple, + dict, + collections.abc.Mapping, + } + + +def _field_for_builtin_collection_type( + typ: object, metadata: Dict[str, Any] +) -> marshmallow.fields.Field: """ - If the type is a generic interface, resolve the arguments and construct the appropriate Field. + Handle builtin container types like list, tuple, set, etc. """ - origin = typing_inspect.get_origin(typ) - arguments = typing_inspect.get_args(typ, True) - if origin: - # Override base_schema.TYPE_MAPPING to change the class used for generic types below - schema_ctx = _schema_ctx_stack.top + origin = get_origin(typ) + assert origin is not None + assert not typing_inspect.is_union_type(typ) + + arguments = get_args(typ) + # if len(arguments) == 0: + # if issubclass(origin, (collections.abc.Sequence, collections.abc.Set)): + # arguments = (Any,) + # elif issubclass(origin, collections.abc.Mapping): + # arguments = (Any, Any) + # else: + # print(repr(origin)) + # raise TypeError(f"{typ!r} requires generic arguments") + + if origin is tuple and len(arguments) == 2 and arguments[1] is Ellipsis: + origin = collections.abc.Sequence + arguments = (arguments[0],) + + fields = tuple(map(_field_for_schema, arguments)) - def get_field_type(type_spec: Any, default: Type[_Field]) -> Type[_Field]: - type_mapping = schema_ctx.get_type_mapping() - return type_mapping.get(type_spec, default) # type: ignore[return-value] - - if origin in (list, List): - child_type = _field_for_schema(arguments[0]) - list_type = get_field_type(List, default=marshmallow.fields.List) - return list_type(child_type, **metadata) - if origin in (collections.abc.Sequence, Sequence) or ( - origin in (tuple, Tuple) - and len(arguments) == 2 - and arguments[1] is Ellipsis - ): - from . import collection_field - - child_type = _field_for_schema(arguments[0]) - return collection_field.Sequence(cls_or_instance=child_type, **metadata) - if origin in (set, Set): - from . import collection_field - - child_type = _field_for_schema(arguments[0]) - return collection_field.Set( - cls_or_instance=child_type, frozen=False, **metadata - ) - if origin in (frozenset, FrozenSet): - from . import collection_field + schema_ctx = _schema_ctx_stack.top - child_type = _field_for_schema(arguments[0]) - return collection_field.Set( - cls_or_instance=child_type, frozen=True, **metadata - ) - if origin in (tuple, Tuple): - children = tuple(_field_for_schema(arg) for arg in arguments) - tuple_type = get_field_type(Tuple, default=marshmallow.fields.Tuple) - return tuple_type(children, **metadata) - elif origin in (dict, Dict, collections.abc.Mapping, Mapping): - dict_type = get_field_type(Dict, default=marshmallow.fields.Dict) - return dict_type( - keys=_field_for_schema(arguments[0]), - values=_field_for_schema(arguments[1]), - **metadata, - ) + # Override base_schema.TYPE_MAPPING to change the class used for generic types below + def get_field_type(type_spec: Any, default: Type[_Field]) -> Type[_Field]: + type_mapping = schema_ctx.get_type_mapping() + return type_mapping.get(type_spec, default) # type: ignore[return-value] - if typing_inspect.is_union_type(typ): - if typing_inspect.is_optional_type(typ): - metadata["allow_none"] = metadata.get("allow_none", True) - metadata["dump_default"] = metadata.get("dump_default", None) - if not metadata.get("required"): - metadata["load_default"] = metadata.get("load_default", None) - metadata.setdefault("required", False) - subtypes = [t for t in arguments if t is not NoneType] # type: ignore - if len(subtypes) == 1: - return _field_for_schema( - subtypes[0], - metadata=metadata, - ) - from . import union_field - - return union_field.Union( - [ - ( - subtyp, - _field_for_schema( - subtyp, - metadata={"required": True}, - ), - ) - for subtyp in subtypes - ], + if origin is list: + assert len(fields) == 1 + list_type = get_field_type(List, default=marshmallow.fields.List) + return list_type(fields[0], **metadata) + + if origin is collections.abc.Sequence: + from . import collection_field + + assert len(fields) == 1 + return collection_field.Sequence(fields[0], **metadata) + + if origin in (set, frozenset): + from . import collection_field + + assert len(fields) == 1 + frozen = origin is frozenset + return collection_field.Set(fields[0], frozen=frozen, **metadata) + + if origin is tuple: + tuple_type = get_field_type(Tuple, default=marshmallow.fields.Tuple) + return tuple_type(fields, **metadata) + + assert origin in (dict, collections.abc.Mapping) + dict_type = get_field_type(Dict, default=marshmallow.fields.Dict) + return dict_type(keys=fields[0], values=fields[1], **metadata) + + +def _field_for_union_type( + typ: type, metadata: Dict[str, Any] +) -> marshmallow.fields.Field: + """ + Construct the appropriate Field for a union or optional type. + """ + assert typing_inspect.is_union_type(typ) + subtypes = [t for t in get_args(typ) if t is not NoneType] + + if typing_inspect.is_optional_type(typ): + metadata = { + "allow_none": True, + "dump_default": None, **metadata, + } + if not metadata.setdefault("required", False): + metadata.setdefault("load_default", None) + + if len(subtypes) == 1: + return _field_for_schema(subtypes[0], metadata=metadata) + + from . import union_field + + return union_field.Union( + [ + (typ, _field_for_schema(typ, metadata={"required": True})) + for typ in subtypes + ], + **metadata, + ) + + +def _field_for_literal_type( + typ: type, metadata: Dict[str, Any] +) -> marshmallow.fields.Field: + """ + Construct the appropriate Field for a Literal type. + """ + validate: marshmallow.validate.Validator + + assert typing_inspect.is_literal_type(typ) + arguments = typing_inspect.get_args(typ) + if len(arguments) == 1: + validate = marshmallow.validate.Equal(arguments[0]) + else: + validate = marshmallow.validate.OneOf(arguments) + return marshmallow.fields.Raw(validate=validate, **metadata) + + +def _get_subtype_for_final_type(typ: type, default: Any) -> Any: + """ + Construct the appropriate Field for a Final type. + """ + assert typing_inspect.is_final_type(typ) + arguments = typing_inspect.get_args(typ) + if arguments: + return arguments[0] + elif default is marshmallow.missing: + return Any + elif callable(default): + warnings.warn( + "****** WARNING ****** " + "marshmallow_dataclass was called on a dataclass with an " + 'attribute that is type-annotated with "Final" and uses ' + "dataclasses.field for specifying a default value using a " + "factory. The Marshmallow field type cannot be inferred from the " + "factory and will fall back to a raw field which is equivalent to " + 'the type annotation "Any" and will result in no validation. ' + "Provide a type to Final[...] to ensure accurate validation. " + "****** WARNING ******" ) - return None + return Any + warnings.warn( + "****** WARNING ****** " + "marshmallow_dataclass was called on a dataclass with an " + 'attribute that is type-annotated with "Final" with a default ' + "value from which the Marshmallow field type is inferred. " + "Support for type inference from a default value is limited and " + "may result in inaccurate validation. Provide a type to " + "Final[...] to ensure accurate validation. " + "****** WARNING ******" + ) + return type(default) + + +def _field_for_new_type( + typ: Type, default: Any, metadata: Dict[str, Any] +) -> marshmallow.fields.Field: + """ + Return a new field for fields based on a NewType. + """ + # Add the information coming our custom NewType implementation + typ_args = getattr(typ, "_marshmallow_args", {}) + + # Handle multiple validators from both `typ` and `metadata`. + # See https://github.com/lovasoa/marshmallow_dataclass/issues/91 + validators: List[Callable[[Any], Any]] = [] + for args in (typ_args, metadata): + validate = args.get("validate") + if marshmallow.utils.is_iterable_but_not_string(validate): + validators.extend(validate) # type: ignore[arg-type] + elif validate is not None: + validators.append(validate) + + metadata = { + **typ_args, + **metadata, + "validate": validators if validators else None, + } + metadata.setdefault("metadata", {}).setdefault("description", typ.__name__) + + field: Optional[Type[marshmallow.fields.Field]] = getattr( + typ, "_marshmallow_field", None + ) + if field is not None: + return field(**metadata) + return _field_for_schema( + typ.__supertype__, # type: ignore[attr-defined] + default=default, + metadata=metadata, + ) + + +def _field_for_enum(typ: type, metadata: Dict[str, Any]) -> marshmallow.fields.Field: + """ + Return a new field for an Enum field. + """ + if sys.version_info >= (3, 7): + return marshmallow.fields.Enum(typ, **metadata) + else: + # Remove this once support for python 3.6 is dropped. + import marshmallow_enum + + return marshmallow_enum.EnumField(typ, **metadata) + + +def _field_for_dataclass( + typ: Union[Type, object], metadata: Dict[str, Any] +) -> marshmallow.fields.Field: + """ + Return a new field for a nested dataclass field. + """ + if isinstance(typ, type) and hasattr(typ, "Schema"): + # marshmallow_dataclass.dataclass + nested = typ.Schema + else: + assert isinstance(typ, Hashable) + schema_ctx = _schema_ctx_stack.top + nested = _internal_class_schema(typ, schema_ctx.base_schema) + if isinstance(nested, _Future): + nested = nested.result + + return marshmallow.fields.Nested(nested, **metadata) def field_for_schema( @@ -746,6 +1005,11 @@ def _field_for_schema( metadata = {} if metadata is None else dict(metadata) + # If the field was already defined by the user + predefined_field = metadata.get("marshmallow_field") + if predefined_field: + return predefined_field + if default is not marshmallow.missing: metadata.setdefault("dump_default", default) # 'missing' must not be set for required fields. @@ -754,16 +1018,14 @@ def _field_for_schema( else: metadata.setdefault("required", not typing_inspect.is_optional_type(typ)) - # If the field was already defined by the user - predefined_field = metadata.get("marshmallow_field") - if predefined_field: - return predefined_field + schema_ctx = _schema_ctx_stack.top + + if schema_ctx.generic_args is not None and isinstance(typ, TypeVar): + typ = schema_ctx.generic_args.resolve(typ) # Generic types specified without type arguments typ = _generic_type_add_any(typ) - schema_ctx = _schema_ctx_stack.top - # Base types type_mapping = schema_ctx.get_type_mapping(use_mro=True) field = type_mapping.get(typ) @@ -775,93 +1037,30 @@ def _field_for_schema( return marshmallow.fields.Raw(**metadata) if typing_inspect.is_literal_type(typ): - arguments = typing_inspect.get_args(typ) - return marshmallow.fields.Raw( - validate=( - marshmallow.validate.Equal(arguments[0]) - if len(arguments) == 1 - else marshmallow.validate.OneOf(arguments) - ), - **metadata, - ) + return _field_for_literal_type(typ, metadata) if typing_inspect.is_final_type(typ): - arguments = typing_inspect.get_args(typ) - if arguments: - subtyp = arguments[0] - elif default is not marshmallow.missing: - if callable(default): - subtyp = Any - warnings.warn( - "****** WARNING ****** " - "marshmallow_dataclass was called on a dataclass with an " - 'attribute that is type-annotated with "Final" and uses ' - "dataclasses.field for specifying a default value using a " - "factory. The Marshmallow field type cannot be inferred from the " - "factory and will fall back to a raw field which is equivalent to " - 'the type annotation "Any" and will result in no validation. ' - "Provide a type to Final[...] to ensure accurate validation. " - "****** WARNING ******" - ) - else: - subtyp = type(default) - warnings.warn( - "****** WARNING ****** " - "marshmallow_dataclass was called on a dataclass with an " - 'attribute that is type-annotated with "Final" with a default ' - "value from which the Marshmallow field type is inferred. " - "Support for type inference from a default value is limited and " - "may result in inaccurate validation. Provide a type to " - "Final[...] to ensure accurate validation. " - "****** WARNING ******" - ) - else: - subtyp = Any - return _field_for_schema(subtyp, default, metadata) - - # Generic types - generic_field = _field_for_generic_type(typ, metadata) - if generic_field: - return generic_field - - # typing.NewType returns a function (in python <= 3.9) or a class (python >= 3.10) with a - # __supertype__ attribute - newtype_supertype = getattr(typ, "__supertype__", None) - if typing_inspect.is_new_type(typ) and newtype_supertype is not None: - return _field_by_supertype( - typ=typ, + return _field_for_schema( + _get_subtype_for_final_type(typ, default), default=default, - newtype_supertype=newtype_supertype, metadata=metadata, ) + if _is_builtin_collection_type(typ): + return _field_for_builtin_collection_type(typ, metadata) + + if typing_inspect.is_union_type(typ): + return _field_for_union_type(typ, metadata) + + if typing_inspect.is_new_type(typ): + return _field_for_new_type(typ, default, metadata) + # enumerations - if issubclass(typ, Enum): - try: - return marshmallow.fields.Enum(typ, **metadata) - except AttributeError: - # Remove this once support for python 3.6 is dropped. - import marshmallow_enum - - return marshmallow_enum.EnumField(typ, **metadata) - - # Nested marshmallow dataclass - # it would be just a class name instead of actual schema util the schema is not ready yet - nested_schema = getattr(typ, "Schema", None) - - # Nested dataclasses - forward_reference = getattr(typ, "__forward_arg__", None) - - nested = ( - nested_schema - or forward_reference - or schema_ctx.seen_classes.get(typ) - or _internal_class_schema( - typ, schema_ctx.base_schema # type: ignore[arg-type] # FIXME - ) - ) + if isinstance(typ, type) and issubclass(typ, Enum): + return _field_for_enum(typ, metadata) - return marshmallow.fields.Nested(nested, **metadata) + # Assume nested marshmallow dataclass (and hope for the best) + return _field_for_dataclass(typ, metadata) def _base_schema( diff --git a/setup.py b/setup.py index 325b350..6c00a6b 100644 --- a/setup.py +++ b/setup.py @@ -31,10 +31,6 @@ # re: pypy: typed-ast (a dependency of mypy) fails to install on pypy # https://github.com/python/typed_ast/issues/111 "pytest-mypy-plugins>=1.2.0; implementation_name != 'pypy'", - # `Literal` was introduced in: - # - Python 3.8 (https://www.python.org/dev/peps/pep-0586) - # - typing-extensions 3.7.2 (https://github.com/python/typing/pull/591) - "typing-extensions>=3.7.2; python_version < '3.8'", ], } EXTRAS_REQUIRE["dev"] = ( @@ -64,6 +60,7 @@ install_requires=[ "marshmallow>=3.13.0,<4.0", "typing-inspect>=0.8.0", + "typing-extensions>=3.10; python_version < '3.8'", ], extras_require=EXTRAS_REQUIRE, package_data={"marshmallow_dataclass": ["py.typed"]}, diff --git a/tests/test_class_schema.py b/tests/test_class_schema.py index aa82975..59f4f6c 100644 --- a/tests/test_class_schema.py +++ b/tests/test_class_schema.py @@ -14,7 +14,7 @@ from marshmallow.fields import Field, UUID as UUIDField, List as ListField, Integer from marshmallow.validate import Validator -from marshmallow_dataclass import class_schema, NewType +from marshmallow_dataclass import class_schema, NewType, _is_generic_alias_of_dataclass class TestClassSchema(unittest.TestCase): @@ -401,6 +401,129 @@ class J: [validator_a, validator_b, validator_c, validator_d], ) + def test_generic_dataclass(self): + T = typing.TypeVar("T") + + @dataclasses.dataclass + class SimpleGeneric(typing.Generic[T]): + data: T + + @dataclasses.dataclass + class NestedFixed: + data: SimpleGeneric[int] + + @dataclasses.dataclass + class NestedGeneric(typing.Generic[T]): + data: SimpleGeneric[T] + + self.assertTrue(_is_generic_alias_of_dataclass(SimpleGeneric[int])) + self.assertFalse(_is_generic_alias_of_dataclass(SimpleGeneric)) + + schema_s = class_schema(SimpleGeneric[str])() + self.assertEqual(SimpleGeneric(data="a"), schema_s.load({"data": "a"})) + self.assertEqual(schema_s.dump(SimpleGeneric(data="a")), {"data": "a"}) + with self.assertRaises(ValidationError): + schema_s.load({"data": 2}) + + schema_nested = class_schema(NestedFixed)() + self.assertEqual( + NestedFixed(data=SimpleGeneric(1)), + schema_nested.load({"data": {"data": 1}}), + ) + self.assertEqual( + schema_nested.dump(NestedFixed(data=SimpleGeneric(data=1))), + {"data": {"data": 1}}, + ) + with self.assertRaises(ValidationError): + schema_nested.load({"data": {"data": "str"}}) + + schema_nested_generic = class_schema(NestedGeneric[int])() + self.assertEqual( + NestedGeneric(data=SimpleGeneric(1)), + schema_nested_generic.load({"data": {"data": 1}}), + ) + self.assertEqual( + schema_nested_generic.dump(NestedGeneric(data=SimpleGeneric(data=1))), + {"data": {"data": 1}}, + ) + with self.assertRaises(ValidationError): + schema_nested_generic.load({"data": {"data": "str"}}) + + def test_generic_dataclass_repeated_fields(self): + T = typing.TypeVar("T") + + @dataclasses.dataclass + class AA: + a: int + + @dataclasses.dataclass + class BB(typing.Generic[T]): + b: T + + @dataclasses.dataclass + class Nested: + y: BB[AA] + x: BB[float] + z: BB[float] + # if y is the first field in this class, deserialisation will fail. + # see https://github.com/lovasoa/marshmallow_dataclass/pull/172#issuecomment-1334024027 + + schema_nested = class_schema(Nested)() + self.assertEqual( + Nested(x=BB(b=1), z=BB(b=1), y=BB(b=AA(1))), + schema_nested.load({"x": {"b": 1}, "z": {"b": 1}, "y": {"b": {"a": 1}}}), + ) + + def test_deep_generic(self): + T = typing.TypeVar("T") + U = typing.TypeVar("U") + + @dataclasses.dataclass + class TestClass(typing.Generic[T, U]): + pairs: typing.List[typing.Tuple[T, U]] + + test_schema = class_schema(TestClass[str, int])() + + self.assertEqual( + test_schema.load({"pairs": [("first", "1")]}), TestClass([("first", 1)]) + ) + + def test_generic_bases(self) -> None: + T = typing.TypeVar("T") + + @dataclasses.dataclass + class Base1(typing.Generic[T]): + answer: T + + @dataclasses.dataclass + class TestClass(Base1[T]): + pass + + test_schema = class_schema(TestClass[int])() + + self.assertEqual(test_schema.load({"answer": "42"}), TestClass(42)) + + @unittest.expectedFailure + def test_broken_generic_bases(self) -> None: + # When a different TypeVar is used when declaring the base GenericAlias + # than when declaring that generic base class, things currently don't work. + # TestClass.__orig_bases__ (see PEP 560) might be of some help, but isn't + # the full answer. + T = typing.TypeVar("T") + U = typing.TypeVar("U") + + @dataclasses.dataclass + class Base1(typing.Generic[T]): + answer: T + + @dataclasses.dataclass + class TestClass(Base1[U]): + pass + + test_schema = class_schema(TestClass[int])() + + self.assertEqual(test_schema.load({"answer": "42"}), TestClass(42)) + def test_recursive_reference(self): @dataclasses.dataclass class Tree: From 84e7736ea2d1ad8f55565789443c27b1974374c7 Mon Sep 17 00:00:00 2001 From: Jeff Dairiki Date: Fri, 20 Jan 2023 20:48:54 -0800 Subject: [PATCH 14/32] test(mypy): fix mypy errors in python <= 3.9 --- marshmallow_dataclass/__init__.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index acabd21..4aeef28 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -708,7 +708,7 @@ def _internal_class_schema( return schema_class -def _generic_type_add_any(typ: type) -> type: # FIXME: signature is wrong +def _generic_type_add_any(typ: object) -> object: """if typ is generic type without arguments, replace them by Any.""" if typ is list or typ is List: typ = List[Any] @@ -798,7 +798,7 @@ def get_field_type(type_spec: Any, default: Type[_Field]) -> Type[_Field]: def _field_for_union_type( - typ: type, metadata: Dict[str, Any] + typ: object, metadata: Dict[str, Any] ) -> marshmallow.fields.Field: """ Construct the appropriate Field for a union or optional type. @@ -830,7 +830,7 @@ def _field_for_union_type( def _field_for_literal_type( - typ: type, metadata: Dict[str, Any] + typ: object, metadata: Dict[str, Any] ) -> marshmallow.fields.Field: """ Construct the appropriate Field for a Literal type. @@ -846,7 +846,7 @@ def _field_for_literal_type( return marshmallow.fields.Raw(validate=validate, **metadata) -def _get_subtype_for_final_type(typ: type, default: Any) -> Any: +def _get_subtype_for_final_type(typ: object, default: Any) -> object: """ Construct the appropriate Field for a Final type. """ @@ -883,7 +883,7 @@ def _get_subtype_for_final_type(typ: type, default: Any) -> Any: def _field_for_new_type( - typ: Type, default: Any, metadata: Dict[str, Any] + typ: object, default: Any, metadata: Dict[str, Any] ) -> marshmallow.fields.Field: """ Return a new field for fields based on a NewType. @@ -906,7 +906,8 @@ def _field_for_new_type( **metadata, "validate": validators if validators else None, } - metadata.setdefault("metadata", {}).setdefault("description", typ.__name__) + if hasattr(typ, "__name__"): + metadata.setdefault("metadata", {}).setdefault("description", typ.__name__) field: Optional[Type[marshmallow.fields.Field]] = getattr( typ, "_marshmallow_field", None @@ -986,7 +987,7 @@ def field_for_schema( def _field_for_schema( - typ: type, + typ: object, default: Any = marshmallow.missing, metadata: Optional[Mapping[str, Any]] = None, ) -> marshmallow.fields.Field: From d25261ac8b4103c94cb7d388ddd5db3145aa2b19 Mon Sep 17 00:00:00 2001 From: Jeff Dairiki Date: Thu, 19 Jan 2023 08:14:06 -0800 Subject: [PATCH 15/32] fix: raise exception if generic class passed to decorator --- marshmallow_dataclass/__init__.py | 15 ++++++++++++++ tests/test_class_schema.py | 33 ++++++++++++++++++++++++++++++- 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index 4aeef28..b8218e1 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -174,6 +174,18 @@ def _maybe_get_callers_frame( del frame +def _check_decorated_type(cls: object) -> None: + if typing_inspect.is_generic_type(cls): + # A .Schema attribute doesn't make sense on a generic type — there's + # no way for it to know the generic parameters at run time. + raise TypeError( + "decorator does not support generic types " + "(hint: use class_schema directly instead)" + ) + if not isinstance(cls, type): + raise TypeError(f"expected a class not {cls!r}") + + @overload def dataclass( _cls: Type[_U], @@ -248,6 +260,7 @@ def dataclass( ) def decorator(cls: Type[_U], stacklevel: int = 1) -> Type[_U]: + _check_decorated_type(cls) dc(cls) return add_schema( cls, base_schema, cls_frame=cls_frame, stacklevel=stacklevel + 1 @@ -303,6 +316,8 @@ def add_schema(_cls=None, base_schema=None, cls_frame=None, stacklevel=1): """ def decorator(clazz: Type[_U], stacklevel: int = stacklevel) -> Type[_U]: + _check_decorated_type(clazz) + if cls_frame is not None: frame = cls_frame else: diff --git a/tests/test_class_schema.py b/tests/test_class_schema.py index 59f4f6c..9c346a0 100644 --- a/tests/test_class_schema.py +++ b/tests/test_class_schema.py @@ -14,7 +14,12 @@ from marshmallow.fields import Field, UUID as UUIDField, List as ListField, Integer from marshmallow.validate import Validator -from marshmallow_dataclass import class_schema, NewType, _is_generic_alias_of_dataclass +from marshmallow_dataclass import ( + add_schema, + class_schema, + NewType, + _is_generic_alias_of_dataclass, +) class TestClassSchema(unittest.TestCase): @@ -474,6 +479,32 @@ class Nested: schema_nested.load({"x": {"b": 1}, "z": {"b": 1}, "y": {"b": {"a": 1}}}), ) + def test_marshmallow_dataclass_decorator_raises_on_generics(self): + import marshmallow_dataclass + + T = typing.TypeVar("T") + + class GenClass(typing.Generic[T]): + pass + + with self.assertRaisesRegex(TypeError, "generic"): + marshmallow_dataclass.dataclass(GenClass) + + with self.assertRaisesRegex(TypeError, "generic"): + marshmallow_dataclass.dataclass(GenClass[int]) + + def test_add_schema_raises_on_generics(self): + T = typing.TypeVar("T") + + class GenClass(typing.Generic[T]): + pass + + with self.assertRaisesRegex(TypeError, "generic"): + add_schema(GenClass) + + with self.assertRaisesRegex(TypeError, "generic"): + add_schema(GenClass[int]) + def test_deep_generic(self): T = typing.TypeVar("T") U = typing.TypeVar("U") From 8491d0c46feb34af69ff976c159c6d7132313ff6 Mon Sep 17 00:00:00 2001 From: Jeff Dairiki Date: Thu, 19 Jan 2023 09:28:59 -0800 Subject: [PATCH 16/32] test: check for frame leakage when decorators throw exceptions --- tests/test_memory_leak.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tests/test_memory_leak.py b/tests/test_memory_leak.py index 64aef41..306430e 100644 --- a/tests/test_memory_leak.py +++ b/tests/test_memory_leak.py @@ -4,6 +4,7 @@ import unittest import weakref from dataclasses import dataclass +from unittest import mock import marshmallow import marshmallow_dataclass as md @@ -109,3 +110,31 @@ class Foo: f() self.assertFrameCollected() + + def assertDecoratorDoesNotLeakFrame(self, decorator): + def f() -> None: + class Foo: + value: int + + self.trackFrame() + with self.assertRaisesRegex(Exception, "forced exception"): + decorator(Foo) + + with mock.patch( + "marshmallow_dataclass.lazy_class_attribute", + side_effect=Exception("forced exception"), + ) as m: + f() + + assert m.mock_calls == [mock.call(mock.ANY, "Schema", mock.ANY)] + # NB: The Mock holds a reference to its arguments, one of which is the + # lazy_class_attribute which holds a reference to the caller's frame + m.reset_mock() + + self.assertFrameCollected() + + def test_exception_in_dataclass(self): + self.assertDecoratorDoesNotLeakFrame(md.dataclass) + + def test_exception_in_add_schema(self): + self.assertDecoratorDoesNotLeakFrame(md.add_schema) From e0aad11807e0b9ed2cc22e02d758f8b60cd4734e Mon Sep 17 00:00:00 2001 From: Jeff Dairiki Date: Mon, 16 Jan 2023 17:40:14 -0800 Subject: [PATCH 17/32] refactor: delete _generic_type_add_any Handle default generic params values in _field_for_builtin_collection_type. --- marshmallow_dataclass/__init__.py | 84 ++++++++++++------------------- 1 file changed, 31 insertions(+), 53 deletions(-) diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index b8218e1..af08cff 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -723,25 +723,12 @@ def _internal_class_schema( return schema_class -def _generic_type_add_any(typ: object) -> object: - """if typ is generic type without arguments, replace them by Any.""" - if typ is list or typ is List: - typ = List[Any] - elif typ is dict or typ is Dict: - typ = Dict[Any, Any] - elif typ is Mapping: - typ = Mapping[Any, Any] - elif typ is Sequence: - typ = Sequence[Any] - elif typ is set or typ is Set: - typ = Set[Any] - elif typ is frozenset or typ is FrozenSet: - typ = FrozenSet[Any] - return typ - - def _is_builtin_collection_type(typ: object) -> bool: - return get_origin(typ) in { + origin = get_origin(typ) + if origin is None: + origin = typ + + return origin in { list, collections.abc.Sequence, set, @@ -759,24 +746,11 @@ def _field_for_builtin_collection_type( Handle builtin container types like list, tuple, set, etc. """ origin = get_origin(typ) - assert origin is not None - assert not typing_inspect.is_union_type(typ) - - arguments = get_args(typ) - # if len(arguments) == 0: - # if issubclass(origin, (collections.abc.Sequence, collections.abc.Set)): - # arguments = (Any,) - # elif issubclass(origin, collections.abc.Mapping): - # arguments = (Any, Any) - # else: - # print(repr(origin)) - # raise TypeError(f"{typ!r} requires generic arguments") - - if origin is tuple and len(arguments) == 2 and arguments[1] is Ellipsis: - origin = collections.abc.Sequence - arguments = (arguments[0],) + if origin is None: + origin = typ + assert len(get_args(typ)) == 0 - fields = tuple(map(_field_for_schema, arguments)) + args = get_args(typ) schema_ctx = _schema_ctx_stack.top @@ -785,31 +759,38 @@ def get_field_type(type_spec: Any, default: Type[_Field]) -> Type[_Field]: type_mapping = schema_ctx.get_type_mapping() return type_mapping.get(type_spec, default) # type: ignore[return-value] + if origin is tuple and (len(args) == 0 or (len(args) == 2 and args[1] is Ellipsis)): + # Special case: homogeneous tuple — treat as Sequence + origin = collections.abc.Sequence + args = args[:1] + + if origin is tuple: + tuple_type = get_field_type(Tuple, default=marshmallow.fields.Tuple) + return tuple_type(tuple(map(_field_for_schema, args)), **metadata) + + def get_field(i: int) -> marshmallow.fields.Field: + return _field_for_schema(args[i] if args else Any) + + if origin in (dict, collections.abc.Mapping): + dict_type = get_field_type(Dict, default=marshmallow.fields.Dict) + return dict_type(keys=get_field(0), values=get_field(1), **metadata) + if origin is list: - assert len(fields) == 1 list_type = get_field_type(List, default=marshmallow.fields.List) - return list_type(fields[0], **metadata) + return list_type(get_field(0), **metadata) if origin is collections.abc.Sequence: from . import collection_field - assert len(fields) == 1 - return collection_field.Sequence(fields[0], **metadata) + return collection_field.Sequence(get_field(0), **metadata) if origin in (set, frozenset): from . import collection_field - assert len(fields) == 1 frozen = origin is frozenset - return collection_field.Set(fields[0], frozen=frozen, **metadata) - - if origin is tuple: - tuple_type = get_field_type(Tuple, default=marshmallow.fields.Tuple) - return tuple_type(fields, **metadata) + return collection_field.Set(get_field(0), frozen=frozen, **metadata) - assert origin in (dict, collections.abc.Mapping) - dict_type = get_field_type(Dict, default=marshmallow.fields.Dict) - return dict_type(keys=fields[0], values=fields[1], **metadata) + raise ValueError(f"{typ} is not a builtin collection type") def _field_for_union_type( @@ -1039,8 +1020,8 @@ def _field_for_schema( if schema_ctx.generic_args is not None and isinstance(typ, TypeVar): typ = schema_ctx.generic_args.resolve(typ) - # Generic types specified without type arguments - typ = _generic_type_add_any(typ) + if _is_builtin_collection_type(typ): + return _field_for_builtin_collection_type(typ, metadata) # Base types type_mapping = schema_ctx.get_type_mapping(use_mro=True) @@ -1062,9 +1043,6 @@ def _field_for_schema( metadata=metadata, ) - if _is_builtin_collection_type(typ): - return _field_for_builtin_collection_type(typ, metadata) - if typing_inspect.is_union_type(typ): return _field_for_union_type(typ, metadata) From 39cfa30b1146ce083b4d03d0ad872b790aaeddfc Mon Sep 17 00:00:00 2001 From: Jeff Dairiki Date: Mon, 16 Jan 2023 21:46:05 -0800 Subject: [PATCH 18/32] refactor: disuse lru_cache in favor of an _LRUDict --- marshmallow_dataclass/__init__.py | 82 +++++++++++++++++++++++++------ tests/test_lrudict.py | 35 +++++++++++++ 2 files changed, 102 insertions(+), 15 deletions(-) create mode 100644 tests/test_lrudict.py diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index af08cff..79dc1af 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -42,7 +42,7 @@ class User: import types import warnings from enum import Enum -from functools import lru_cache, partial +from functools import partial from typing import ( Any, Callable, @@ -107,8 +107,12 @@ def get_origin(tp): if sys.version_info >= (3, 7): + from typing import OrderedDict + TypeVar_ = TypeVar else: + from typing_extensions import OrderedDict + TypeVar_ = type if sys.version_info >= (3, 10): @@ -119,9 +123,9 @@ def get_origin(tp): NoneType = type(None) _U = TypeVar("_U") +_V = TypeVar("_V") _Field = TypeVar("_Field", bound=marshmallow.fields.Field) - # Whitelist of dataclass members that will be copied to generated schema. MEMBERS_WHITELIST: Set[str] = {"Meta"} @@ -490,13 +494,59 @@ def class_schema( clazz_frame = _maybe_get_callers_frame(clazz) if clazz_frame is not None: localns = clazz_frame.f_locals - with _SchemaContext(globalns, localns): - schema = _internal_class_schema(clazz, base_schema) + + if base_schema is None: + base_schema = marshmallow.Schema + + with _SchemaContext(globalns, localns, base_schema): + schema = _internal_class_schema(clazz) assert not isinstance(schema, _Future) return schema +class _LRUDict(OrderedDict[_U, _V]): + """Limited-length dict which discards LRU entries.""" + + def __init__(self, maxsize: int = 128): + self.maxsize = maxsize + super().__init__() + + def __setitem__(self, key: _U, value: _V) -> None: + super().__setitem__(key, value) + super().move_to_end(key) + + while len(self) > self.maxsize: + oldkey = next(iter(self)) + super().__delitem__(oldkey) + + def __getitem__(self, key: _U) -> _V: + val = super().__getitem__(key) + super().move_to_end(key) + return val + + _T = TypeVar("_T") + + @overload + def get(self, key: _U) -> Optional[_V]: + ... + + @overload + def get(self, key: _U, default: _T) -> Union[_V, _T]: + ... + + def get(self, key: _U, default: Any = None) -> Any: + try: + return self.__getitem__(key) + except KeyError: + return default + + +_schema_cache = _LRUDict[Hashable, Type[marshmallow.Schema]]( + MAX_CLASS_SCHEMA_CACHE_SIZE +) + + class InvalidStateError(Exception): """Raised when an operation is performed on a future that is not allowed in the current state. @@ -597,7 +647,7 @@ class _SchemaContext: globalns: Optional[Dict[str, Any]] = None localns: Optional[Dict[str, Any]] = None - base_schema: Optional[Type[marshmallow.Schema]] = None + base_schema: Type[marshmallow.Schema] = marshmallow.Schema generic_args: Optional[_GenericArgs] = None seen_classes: Dict[type, _Future[Type[marshmallow.Schema]]] = dataclasses.field( default_factory=dict @@ -612,8 +662,6 @@ def get_type_mapping( all bases in base_schema's MRO. """ base_schema = self.base_schema - if base_schema is None: - base_schema = marshmallow.Schema if use_mro: return ChainMap( *(getattr(cls, "TYPE_MAPPING", {}) for cls in base_schema.__mro__) @@ -651,15 +699,19 @@ def top(self) -> _U: _schema_ctx_stack = _LocalStack[_SchemaContext]() -@lru_cache(maxsize=MAX_CLASS_SCHEMA_CACHE_SIZE) def _internal_class_schema( clazz: type, - base_schema: Optional[Type[marshmallow.Schema]] = None, ) -> Union[Type[marshmallow.Schema], _Future[Type[marshmallow.Schema]]]: schema_ctx = _schema_ctx_stack.top if clazz in schema_ctx.seen_classes: return schema_ctx.seen_classes[clazz] + cache_key = clazz, schema_ctx.base_schema + try: + return _schema_cache[cache_key] + except KeyError: + pass + future: _Future[Type[marshmallow.Schema]] = _Future() schema_ctx.seen_classes[clazz] = future @@ -700,9 +752,7 @@ def _internal_class_schema( type_hints = get_type_hints( clazz, globalns=schema_ctx.globalns, localns=schema_ctx.localns ) - with dataclasses.replace( - schema_ctx, base_schema=base_schema, generic_args=generic_args - ): + with dataclasses.replace(schema_ctx, generic_args=generic_args): attributes.update( ( field.name, @@ -717,9 +767,10 @@ def _internal_class_schema( ) schema_class: Type[marshmallow.Schema] = type( - clazz.__name__, (_base_schema(clazz, base_schema),), attributes + clazz.__name__, (_base_schema(clazz, schema_ctx.base_schema),), attributes ) future.set_result(schema_class) + _schema_cache[cache_key] = schema_class return schema_class @@ -941,8 +992,7 @@ def _field_for_dataclass( nested = typ.Schema else: assert isinstance(typ, Hashable) - schema_ctx = _schema_ctx_stack.top - nested = _internal_class_schema(typ, schema_ctx.base_schema) + nested = _internal_class_schema(typ) # type: ignore[arg-type] # FIXME if isinstance(nested, _Future): nested = nested.result @@ -977,6 +1027,8 @@ def field_for_schema( >>> field_for_schema(str, metadata={"marshmallow_field": marshmallow.fields.Url()}).__class__ """ + if base_schema is None: + base_schema = marshmallow.Schema localns = typ_frame.f_locals if typ_frame is not None else None with _SchemaContext(localns=localns, base_schema=base_schema): return _field_for_schema(typ, default, metadata) diff --git a/tests/test_lrudict.py b/tests/test_lrudict.py new file mode 100644 index 0000000..e323fe5 --- /dev/null +++ b/tests/test_lrudict.py @@ -0,0 +1,35 @@ +from marshmallow_dataclass import _LRUDict + + +def test_LRUDict_getitem_moves_to_end() -> None: + d = _LRUDict[str, str]() + d["a"] = "aval" + d["b"] = "bval" + assert list(d.items()) == [("a", "aval"), ("b", "bval")] + assert d["a"] == "aval" + assert list(d.items()) == [("b", "bval"), ("a", "aval")] + + +def test_LRUDict_get_moves_to_end() -> None: + d = _LRUDict[str, str]() + d["a"] = "aval" + d["b"] = "bval" + assert list(d.items()) == [("a", "aval"), ("b", "bval")] + assert d.get("a") == "aval" + assert list(d.items()) == [("b", "bval"), ("a", "aval")] + + +def test_LRUDict_setitem_moves_to_end() -> None: + d = _LRUDict[str, str]() + d["a"] = "aval" + d["b"] = "bval" + assert list(d.items()) == [("a", "aval"), ("b", "bval")] + d["a"] = "newval" + assert list(d.items()) == [("b", "bval"), ("a", "newval")] + + +def test_LRUDict_discards_oldest() -> None: + d = _LRUDict[str, str](maxsize=1) + d["a"] = "aval" + d["b"] = "bval" + assert list(d.items()) == [("b", "bval")] From 769ebc6d1f0a5412685703f427e440b3402bfc99 Mon Sep 17 00:00:00 2001 From: Jeff Dairiki Date: Mon, 16 Jan 2023 17:59:21 -0800 Subject: [PATCH 19/32] feat: Do not auto-dataclassify non-dataclasses This implement the suggestions made in #51. See https://github.com/lovasoa/marshmallow_dataclass/issues/51#issuecomment-1383208927 --- marshmallow_dataclass/__init__.py | 247 +++++++++++++++++------------- tests/test_class_schema.py | 13 ++ tests/test_field_for_schema.py | 24 +-- 3 files changed, 171 insertions(+), 113 deletions(-) diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index 79dc1af..418d811 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -653,6 +653,9 @@ class _SchemaContext: default_factory=dict ) + def replace(self, generic_args: Optional[_GenericArgs]) -> "_SchemaContext": + return dataclasses.replace(self, generic_args=generic_args) + def get_type_mapping( self, use_mro: bool = False ) -> Mapping[Any, Type[marshmallow.fields.Field]]: @@ -717,63 +720,109 @@ def _internal_class_schema( generic_args = schema_ctx.generic_args - if _is_generic_alias_of_dataclass(clazz): - generic_args = _GenericArgs(clazz, generic_args) - clazz = typing_inspect.get_origin(clazz) - elif not dataclasses.is_dataclass(clazz): - try: - warnings.warn( - "****** WARNING ****** " - f"marshmallow_dataclass was called on the class {clazz}, which is not a dataclass. " - "It is going to try and convert the class into a dataclass, which may have " - "undesirable side effects. To avoid this message, make sure all your classes and " - "all the classes of their fields are either explicitly supported by " - "marshmallow_dataclass, or define the schema explicitly using " - "field(metadata=dict(marshmallow_field=...)). For more information, see " - "https://github.com/lovasoa/marshmallow_dataclass/issues/51 " - "****** WARNING ******" - ) - dataclasses.dataclass(clazz) - except Exception as exc: - raise TypeError( - f"{getattr(clazz, '__name__', repr(clazz))} is not a dataclass and cannot be turned into one." - ) from exc - - fields = dataclasses.fields(clazz) - - # Copy all marshmallow hooks and whitelisted members of the dataclass to the schema. - attributes = { - k: v - for k, v in inspect.getmembers(clazz) - if hasattr(v, "__marshmallow_hook__") or k in MEMBERS_WHITELIST - } + constructor: Callable[..., object] + + if _is_simple_annotated_class(clazz): + class_name = clazz.__name__ + constructor = _simple_class_constructor(clazz) + attributes = _schema_attrs_for_simple_class(clazz) + elif _is_generic_alias_of_dataclass(clazz): + origin = get_origin(clazz) + assert isinstance(origin, type) + class_name = origin.__name__ + constructor = origin + with schema_ctx.replace(generic_args=_GenericArgs(clazz, generic_args)): + attributes = _schema_attrs_for_dataclass(origin) + elif dataclasses.is_dataclass(clazz): + class_name = clazz.__name__ + constructor = clazz + attributes = _schema_attrs_for_dataclass(clazz) + else: + raise TypeError(f"{clazz} is not a dataclass or a simple annotated class") + + base_schema = marshmallow.Schema + if schema_ctx.base_schema is not None: + base_schema = schema_ctx.base_schema + + load_to_dict = base_schema.load + + def load( + self: marshmallow.Schema, + data: Union[Mapping[str, Any], Iterable[Mapping[str, Any]]], + *, + many: Optional[bool] = None, + unknown: Optional[str] = None, + **kwargs: Any, + ) -> Any: + many = self.many if many is None else bool(many) + loaded = load_to_dict(self, data, many=many, unknown=unknown, **kwargs) + if many: + return [constructor(**item) for item in loaded] + else: + return constructor(**loaded) - # Update the schema members to contain marshmallow fields instead of dataclass fields - type_hints = get_type_hints( - clazz, globalns=schema_ctx.globalns, localns=schema_ctx.localns - ) - with dataclasses.replace(schema_ctx, generic_args=generic_args): - attributes.update( - ( - field.name, - _field_for_schema( - type_hints[field.name], - _get_field_default(field), - field.metadata, - ), - ) - for field in fields - if field.init - ) + attributes["load"] = load schema_class: Type[marshmallow.Schema] = type( - clazz.__name__, (_base_schema(clazz, schema_ctx.base_schema),), attributes + f"{class_name}Schema", (base_schema,), attributes ) + future.set_result(schema_class) _schema_cache[cache_key] = schema_class return schema_class +def _marshmallow_hooks(clazz: type) -> Iterator[Tuple[str, Any]]: + for name, attr in inspect.getmembers(clazz): + if hasattr(attr, "__marshmallow_hook__") or name in MEMBERS_WHITELIST: + yield name, attr + + +def _schema_attrs_for_dataclass(clazz: type) -> Dict[str, Any]: + schema_ctx = _schema_ctx_stack.top + type_hints = get_type_hints( + clazz, globalns=schema_ctx.globalns, localns=schema_ctx.localns + ) + + attrs = dict(_marshmallow_hooks(clazz)) + for field in dataclasses.fields(clazz): + if field.init: + typ = type_hints[field.name] + default = ( + field.default_factory + if field.default_factory is not dataclasses.MISSING + else field.default + if field.default is not dataclasses.MISSING + else marshmallow.missing + ) + attrs[field.name] = _field_for_schema(typ, default, field.metadata) + return attrs + + +def _schema_attrs_for_simple_class(clazz: type) -> Dict[str, Any]: + schema_ctx = _schema_ctx_stack.top + type_hints = get_type_hints( + clazz, globalns=schema_ctx.globalns, localns=schema_ctx.localns + ) + + attrs = dict(_marshmallow_hooks(clazz)) + for field_name, typ in type_hints.items(): + if not typing_inspect.is_classvar(typ): + default = getattr(clazz, field_name, marshmallow.missing) + attrs[field_name] = _field_for_schema(typ, default) + return attrs + + +def _simple_class_constructor(clazz: Type[_U]) -> Callable[..., _U]: + def constructor(**kwargs: Any) -> _U: + obj = clazz.__new__(clazz) + for k, v in kwargs.items(): + setattr(obj, k, v) + return obj + + return constructor + + def _is_builtin_collection_type(typ: object) -> bool: origin = get_origin(typ) if origin is None: @@ -953,8 +1002,8 @@ def _field_for_new_type( **metadata, "validate": validators if validators else None, } - if hasattr(typ, "__name__"): - metadata.setdefault("metadata", {}).setdefault("description", typ.__name__) + type_name = getattr(typ, "__name__", repr(typ)) + metadata.setdefault("metadata", {}).setdefault("description", type_name) field: Optional[Type[marshmallow.fields.Field]] = getattr( typ, "_marshmallow_field", None @@ -981,22 +1030,41 @@ def _field_for_enum(typ: type, metadata: Dict[str, Any]) -> marshmallow.fields.F return marshmallow_enum.EnumField(typ, **metadata) -def _field_for_dataclass( - typ: Union[Type, object], metadata: Dict[str, Any] -) -> marshmallow.fields.Field: +def _schema_for_nested( + typ: object, +) -> Union[Type[marshmallow.Schema], Callable[[], Type[marshmallow.Schema]]]: """ - Return a new field for a nested dataclass field. + Return a marshmallow.Schema for a nested dataclass (or simple annotated class) """ if isinstance(typ, type) and hasattr(typ, "Schema"): # marshmallow_dataclass.dataclass - nested = typ.Schema - else: - assert isinstance(typ, Hashable) - nested = _internal_class_schema(typ) # type: ignore[arg-type] # FIXME - if isinstance(nested, _Future): - nested = nested.result + # Defer evaluation of .Schema attribute, to avoid forward reference issues + return partial(getattr, typ, "Schema") - return marshmallow.fields.Nested(nested, **metadata) + class_schema = _internal_class_schema(typ) # type: ignore[arg-type] # FIXME + if isinstance(class_schema, _Future): + return class_schema.result + return class_schema + + +def _is_simple_annotated_class(obj: object) -> bool: + """Determine whether obj is a "simple annotated class". + + The ```class_schema``` function can generate schemas for + simple annotated classes (as well as for dataclasses). + """ + if not isinstance(obj, type): + return False + if getattr(obj, "__init__", None) is not object.__init__: + return False + if getattr(obj, "__new__", None) is not object.__new__: + return False + + schema_ctx = _schema_ctx_stack.top + type_hints = get_type_hints( + obj, globalns=schema_ctx.globalns, localns=schema_ctx.localns + ) + return any(not typing_inspect.is_classvar(th) for th in type_hints.values()) def field_for_schema( @@ -1105,54 +1173,25 @@ def _field_for_schema( if isinstance(typ, type) and issubclass(typ, Enum): return _field_for_enum(typ, metadata) - # Assume nested marshmallow dataclass (and hope for the best) - return _field_for_dataclass(typ, metadata) - - -def _base_schema( - clazz: type, base_schema: Optional[Type[marshmallow.Schema]] = None -) -> Type[marshmallow.Schema]: - """ - Base schema factory that creates a schema for `clazz` derived either from `base_schema` - or `BaseSchema` - """ - - # Remove `type: ignore` when mypy handles dynamic base classes - # https://github.com/python/mypy/issues/2813 - class BaseSchema(base_schema or marshmallow.Schema): # type: ignore - def load(self, data: Mapping, *, many: Optional[bool] = None, **kwargs): - all_loaded = super().load(data, many=many, **kwargs) - many = self.many if many is None else bool(many) - if many: - return [clazz(**loaded) for loaded in all_loaded] - else: - return clazz(**all_loaded) - - return BaseSchema - - -def _get_field_default(field: dataclasses.Field): - """ - Return a marshmallow default value given a dataclass default value + # nested dataclasses + if ( + dataclasses.is_dataclass(typ) + or _is_generic_alias_of_dataclass(typ) + or _is_simple_annotated_class(typ) + ): + nested = _schema_for_nested(typ) + # type spec for Nested.__init__ is not correct + return marshmallow.fields.Nested(nested, **metadata) # type: ignore[arg-type] - >>> _get_field_default(dataclasses.field()) - - """ - # Remove `type: ignore` when https://github.com/python/mypy/issues/6910 is fixed - default_factory = field.default_factory # type: ignore - if default_factory is not dataclasses.MISSING: - return default_factory - elif field.default is dataclasses.MISSING: - return marshmallow.missing - return field.default + raise TypeError(f"can not deduce field type for {typ}") def NewType( name: str, typ: Type[_U], field: Optional[Type[marshmallow.fields.Field]] = None, - **kwargs, -) -> Callable[[_U], _U]: + **kwargs: Any, +) -> type: """NewType creates simple unique types to which you can attach custom marshmallow attributes. All the keyword arguments passed to this function will be transmitted @@ -1185,9 +1224,9 @@ def NewType( # noinspection PyTypeHints new_type = typing_NewType(name, typ) # type: ignore # noinspection PyTypeHints - new_type._marshmallow_field = field # type: ignore + new_type._marshmallow_field = field # noinspection PyTypeHints - new_type._marshmallow_args = kwargs # type: ignore + new_type._marshmallow_args = kwargs return new_type diff --git a/tests/test_class_schema.py b/tests/test_class_schema.py index 9c346a0..9ff4980 100644 --- a/tests/test_class_schema.py +++ b/tests/test_class_schema.py @@ -406,6 +406,19 @@ class J: [validator_a, validator_b, validator_c, validator_d], ) + def test_simple_annotated_class(self): + class Child: + x: int + + @dataclasses.dataclass + class Container: + child: Child + + schema = class_schema(Container)() + + loaded = schema.load({"child": {"x": "42"}}) + self.assertEqual(loaded.child.x, 42) + def test_generic_dataclass(self): T = typing.TypeVar("T") diff --git a/tests/test_field_for_schema.py b/tests/test_field_for_schema.py index 0e60f0b..e4bea21 100644 --- a/tests/test_field_for_schema.py +++ b/tests/test_field_for_schema.py @@ -3,12 +3,12 @@ import typing import unittest from enum import Enum -from typing import Dict, Optional, Union, Any, List, Tuple +from typing import Dict, Optional, Union, Any, List, Tuple, Iterable -try: - from typing import Final, Literal # type: ignore[attr-defined] -except ImportError: - from typing_extensions import Final, Literal # type: ignore[assignment] +if sys.version_info >= (3, 8): + from typing import Final, Literal +else: + from typing_extensions import Final, Literal from marshmallow import fields, Schema, validate @@ -21,14 +21,18 @@ class TestFieldForSchema(unittest.TestCase): - def assertFieldsEqual(self, a: fields.Field, b: fields.Field): + def assertFieldsEqual( + self, a: fields.Field, b: fields.Field, *, ignore_attrs: Iterable[str] = () + ) -> None: + ignored = set(ignore_attrs) + self.assertEqual(a.__class__, b.__class__, "field class") def attrs(x): return { k: f"{v!r} ({v.__mro__!r})" if inspect.isclass(v) else repr(v) for k, v in x.__dict__.items() - if not k.startswith("_") + if not (k in ignored or k.startswith("_")) } self.assertEqual(attrs(a), attrs(b)) @@ -213,10 +217,12 @@ class NewSchema(Schema): class NewDataclass: pass + field = field_for_schema(NewDataclass, metadata=dict(required=False)) + self.assertFieldsEqual( - field_for_schema(NewDataclass, metadata=dict(required=False)), - fields.Nested(NewDataclass.Schema), + field, fields.Nested(NewDataclass.Schema), ignore_attrs=["nested"] ) + self.assertIs(type(field.schema), NewDataclass.Schema) def test_override_container_type_with_type_mapping(self): type_mapping = [ From d1d52491a68936ea10d436a3e95df94f53bcc90e Mon Sep 17 00:00:00 2001 From: Jeff Dairiki Date: Wed, 18 Jan 2023 21:00:55 -0800 Subject: [PATCH 20/32] fix: check for generic base classes Class_schema does not currently support generating schemas for classes which have generic base classes. Typing.get_type_hints doesn't properly dereference the generic parameters. (I think it can be done, but I don't think it's simple.) --- marshmallow_dataclass/__init__.py | 9 +++++++++ tests/test_class_schema.py | 20 ++++++-------------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index 418d811..d5e9d33 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -598,6 +598,11 @@ def _is_generic_alias_of_dataclass( return dataclasses.is_dataclass(get_origin(cls)) +def _has_generic_base(cls: type) -> bool: + """Return True if cls has any generic base classes.""" + return any(typing_inspect.get_parameters(base) for base in cls.__mro__[1:]) + + class _GenericArgs(Mapping[TypeVar_, TypeSpec], collections.abc.Hashable): """A mapping of TypeVars to type specs""" @@ -779,6 +784,10 @@ def _marshmallow_hooks(clazz: type) -> Iterator[Tuple[str, Any]]: def _schema_attrs_for_dataclass(clazz: type) -> Dict[str, Any]: + if _has_generic_base(clazz): + raise TypeError( + "class_schema does not support dataclasses with generic base classes" + ) schema_ctx = _schema_ctx_stack.top type_hints = get_type_hints( clazz, globalns=schema_ctx.globalns, localns=schema_ctx.localns diff --git a/tests/test_class_schema.py b/tests/test_class_schema.py index 9ff4980..5f14331 100644 --- a/tests/test_class_schema.py +++ b/tests/test_class_schema.py @@ -543,30 +543,22 @@ class Base1(typing.Generic[T]): class TestClass(Base1[T]): pass - test_schema = class_schema(TestClass[int])() + with self.assertRaisesRegex(TypeError, "generic base class"): + class_schema(TestClass[int]) - self.assertEqual(test_schema.load({"answer": "42"}), TestClass(42)) - - @unittest.expectedFailure - def test_broken_generic_bases(self) -> None: - # When a different TypeVar is used when declaring the base GenericAlias - # than when declaring that generic base class, things currently don't work. - # TestClass.__orig_bases__ (see PEP 560) might be of some help, but isn't - # the full answer. + def test_bound_generic_base(self) -> None: T = typing.TypeVar("T") - U = typing.TypeVar("U") @dataclasses.dataclass class Base1(typing.Generic[T]): answer: T @dataclasses.dataclass - class TestClass(Base1[U]): + class TestClass(Base1[int]): pass - test_schema = class_schema(TestClass[int])() - - self.assertEqual(test_schema.load({"answer": "42"}), TestClass(42)) + with self.assertRaisesRegex(TypeError, "generic base class"): + class_schema(TestClass) def test_recursive_reference(self): @dataclasses.dataclass From b4330b250eb9560afd33ae58617530d5f5d66d4b Mon Sep 17 00:00:00 2001 From: Jeff Dairiki Date: Tue, 17 Jan 2023 16:49:12 -0800 Subject: [PATCH 21/32] chore: remove workaround for bug in mypy < 0.990 --- tests/test_class_schema.py | 35 ++++++++++++----------------------- 1 file changed, 12 insertions(+), 23 deletions(-) diff --git a/tests/test_class_schema.py b/tests/test_class_schema.py index 5f14331..02a1ba9 100644 --- a/tests/test_class_schema.py +++ b/tests/test_class_schema.py @@ -1,13 +1,14 @@ import inspect +import sys import typing import unittest -from typing import Any, cast, TYPE_CHECKING +from typing import Any, cast from uuid import UUID -try: - from typing import Final, Literal # type: ignore[attr-defined] -except ImportError: - from typing_extensions import Final, Literal # type: ignore[assignment] +if sys.version_info >= (3, 8): + from typing import Final, Literal +else: + from typing_extensions import Final, Literal import dataclasses from marshmallow import Schema, ValidationError @@ -231,21 +232,14 @@ class A: with self.assertRaises(ValidationError): schema.load({"data": data}) - def test_final_infers_type_from_default(self): - # @dataclasses.dataclass(frozen=True) + def test_final_infers_type_from_default(self) -> None: + @dataclasses.dataclass(frozen=True) class A: data: Final = "a" - # @dataclasses.dataclass + @dataclasses.dataclass class B: - data: Final = A() - - # NOTE: This workaround is needed to avoid a Mypy crash. - # See: https://github.com/python/mypy/issues/10090#issuecomment-865971891 - if not TYPE_CHECKING: - frozen_dataclass = dataclasses.dataclass(frozen=True) - A = frozen_dataclass(A) - B = dataclasses.dataclass(B) + data: Final = A() # type: ignore[misc] with self.assertWarns(Warning): schema_a = class_schema(A)() @@ -274,14 +268,9 @@ class B: schema_b.load({"data": data}) def test_final_infers_type_any_from_field_default_factory(self): - # @dataclasses.dataclass + @dataclasses.dataclass class A: - data: Final = dataclasses.field(default_factory=lambda: []) - - # NOTE: This workaround is needed to avoid a Mypy crash. - # See: https://github.com/python/mypy/issues/10090#issuecomment-866686096 - if not TYPE_CHECKING: - A = dataclasses.dataclass(A) + data: Final = dataclasses.field(default_factory=lambda: []) # type: ignore[misc] with self.assertWarns(Warning): schema = class_schema(A)() From 3e2667c62dba17a07174aa7544fbe1f736f4f4f8 Mon Sep 17 00:00:00 2001 From: Jeff Dairiki Date: Wed, 18 Jan 2023 13:53:25 -0800 Subject: [PATCH 22/32] chore(typing): fix type annotations for add_schema --- marshmallow_dataclass/__init__.py | 43 +++++++++++++++++++------------ 1 file changed, 27 insertions(+), 16 deletions(-) diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index d5e9d33..4e28bfd 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -120,6 +120,11 @@ def get_origin(tp): else: from typing_extensions import TypeGuard +if sys.version_info >= (3, 8): + from typing import Protocol +else: + from typing_extensions import Protocol + NoneType = type(None) _U = TypeVar("_U") @@ -275,15 +280,18 @@ def decorator(cls: Type[_U], stacklevel: int = 1) -> Type[_U]: return decorator(_cls, stacklevel=stacklevel + 1) -@overload -def add_schema(_cls: Type[_U]) -> Type[_U]: - ... +class ClassDecorator(Protocol): + def __call__(self, cls: Type[_U], stacklevel: int = 1) -> Type[_U]: + ... @overload def add_schema( + *, base_schema: Optional[Type[marshmallow.Schema]] = None, -) -> Callable[[Type[_U]], Type[_U]]: + cls_frame: Optional[types.FrameType] = None, + stacklevel: int = 1, +) -> ClassDecorator: ... @@ -297,7 +305,12 @@ def add_schema( ... -def add_schema(_cls=None, base_schema=None, cls_frame=None, stacklevel=1): +def add_schema( + _cls: Optional[Type[_U]] = None, + base_schema: Optional[Type[marshmallow.Schema]] = None, + cls_frame: Optional[types.FrameType] = None, + stacklevel: int = 1, +) -> Union[Type[_U], ClassDecorator]: """ This decorator adds a marshmallow schema as the 'Schema' attribute in a dataclass. It uses :func:`class_schema` internally. @@ -319,21 +332,19 @@ def add_schema(_cls=None, base_schema=None, cls_frame=None, stacklevel=1): Artist(names=('Martin', 'Ramirez')) """ - def decorator(clazz: Type[_U], stacklevel: int = stacklevel) -> Type[_U]: - _check_decorated_type(clazz) - - if cls_frame is not None: - frame = cls_frame - else: - frame = _maybe_get_callers_frame(clazz, stacklevel=stacklevel) + def decorator(cls: Type[_V], stacklevel: int = stacklevel) -> Type[_V]: + _check_decorated_type(cls) + frame = cls_frame + if frame is None: + frame = _maybe_get_callers_frame(cls, stacklevel=stacklevel) # noinspection PyTypeHints - clazz.Schema = lazy_class_attribute( # type: ignore - partial(class_schema, clazz, base_schema, frame), + cls.Schema = lazy_class_attribute( # type: ignore[attr-defined] + partial(class_schema, cls, base_schema, frame), "Schema", - clazz.__name__, + cls.__name__, ) - return clazz + return cls if _cls is None: return decorator From 00a9718322d0df9c9393ec831f819de0e74ed0e5 Mon Sep 17 00:00:00 2001 From: Jeff Dairiki Date: Thu, 19 Jan 2023 11:46:31 -0800 Subject: [PATCH 23/32] fix: lazy_class_attribute This fixes thread-safety issues in lazy_class_attribute. Here we also specialize to our use case, thus cleaning up the type annotations --- .pre-commit-config.yaml | 1 + marshmallow_dataclass/__init__.py | 16 ++--- marshmallow_dataclass/lazy_class_attribute.py | 56 ++++++++---------- tests/test_lazy_class_attribute.py | 59 +++++++++++++++++++ tests/test_memory_leak.py | 3 +- 5 files changed, 92 insertions(+), 43 deletions(-) create mode 100644 tests/test_lazy_class_attribute.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cc19292..06433cf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,6 +22,7 @@ repos: additional_dependencies: - marshmallow - marshmallow-enum + - pytest - typeguard - types-setuptools - typing-inspect diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index 4e28bfd..e98b31b 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -310,6 +310,7 @@ def add_schema( base_schema: Optional[Type[marshmallow.Schema]] = None, cls_frame: Optional[types.FrameType] = None, stacklevel: int = 1, + attr_name: str = "Schema", ) -> Union[Type[_U], ClassDecorator]: """ This decorator adds a marshmallow schema as the 'Schema' attribute in a dataclass. @@ -333,17 +334,12 @@ def add_schema( """ def decorator(cls: Type[_V], stacklevel: int = stacklevel) -> Type[_V]: + nonlocal cls_frame _check_decorated_type(cls) - frame = cls_frame - if frame is None: - frame = _maybe_get_callers_frame(cls, stacklevel=stacklevel) - - # noinspection PyTypeHints - cls.Schema = lazy_class_attribute( # type: ignore[attr-defined] - partial(class_schema, cls, base_schema, frame), - "Schema", - cls.__name__, - ) + if cls_frame is None: + cls_frame = _maybe_get_callers_frame(cls, stacklevel=stacklevel) + fget = partial(class_schema, cls, base_schema, cls_frame) + setattr(cls, attr_name, lazy_class_attribute(fget, attr_name)) return cls if _cls is None: diff --git a/marshmallow_dataclass/lazy_class_attribute.py b/marshmallow_dataclass/lazy_class_attribute.py index 2dbe4a4..0555d7d 100644 --- a/marshmallow_dataclass/lazy_class_attribute.py +++ b/marshmallow_dataclass/lazy_class_attribute.py @@ -1,45 +1,39 @@ -from typing import Any, Callable, Optional +import threading +from typing import Callable, Generic, Optional, TypeVar __all__ = ("lazy_class_attribute",) -class LazyClassAttribute: - """Descriptor decorator implementing a class-level, read-only - property, which caches its results on the class(es) on which it - operates. - """ +_T_co = TypeVar("_T_co", covariant=True) - __slots__ = ("func", "name", "called", "forward_value") - def __init__( - self, - func: Callable[..., Any], - name: Optional[str] = None, - forward_value: Any = None, - ): - self.func = func - self.name = name - self.called = False - self.forward_value = forward_value +class LazyClassAttribute(Generic[_T_co]): + """Descriptor implementing a cached class property.""" - def __get__(self, instance, cls=None): - if not cls: - cls = type(instance) - - # avoid recursion - if self.called: - return self.forward_value + __slots__ = ("fget", "attr_name", "rlock", "called_from") - self.called = True + def __init__(self, fget: Callable[[], _T_co], attr_name: str): + self.fget = fget + self.attr_name = attr_name + self.rlock = threading.RLock() + self.called_from: Optional[threading.Thread] = None - setattr(cls, self.name, self.func()) - - # "getattr" is used to handle bounded methods - return getattr(cls, self.name) + def __get__(self, instance: object, cls: Optional[type] = None) -> _T_co: + if not cls: + cls = type(instance) - def __set_name__(self, owner, name): - self.name = self.name or name + with self.rlock: + if self.called_from is not None: + if self.called_from is not threading.current_thread(): + return getattr(cls, self.attr_name) # type: ignore[no-any-return] + raise AttributeError( + f"recursive evaluation of {cls.__name__}.{self.attr_name}" + ) + self.called_from = threading.current_thread() + value = self.fget() + setattr(cls, self.attr_name, value) + return value lazy_class_attribute = LazyClassAttribute diff --git a/tests/test_lazy_class_attribute.py b/tests/test_lazy_class_attribute.py new file mode 100644 index 0000000..c6dc1fe --- /dev/null +++ b/tests/test_lazy_class_attribute.py @@ -0,0 +1,59 @@ +import threading +import time +from itertools import count + +import pytest + +from marshmallow_dataclass.lazy_class_attribute import lazy_class_attribute + + +def test_caching() -> None: + counter = count() + + def fget() -> str: + return f"value-{next(counter)}" + + class A: + x = lazy_class_attribute(fget, "x") + + assert A.x == "value-0" + assert A.x == "value-0" + + +def test_recursive_evaluation() -> None: + def fget() -> str: + return A.x + + class A: + x: str = lazy_class_attribute(fget, "x") # type: ignore[assignment] + + with pytest.raises(AttributeError, match="recursive evaluation of A.x"): + A.x + + +def test_threading() -> None: + counter = count() + lock = threading.Lock() + + def fget() -> str: + time.sleep(0.05) + with lock: + return f"value-{next(counter)}" + + class A: + x = lazy_class_attribute(fget, "x") + + n_threads = 4 + barrier = threading.Barrier(n_threads) + values = set() + + def run(): + barrier.wait() + values.add(A.x) + + threads = [threading.Thread(target=run) for _ in range(n_threads)] + for t in threads: + t.start() + for t in threads: + t.join() + assert values == {"value-0"} diff --git a/tests/test_memory_leak.py b/tests/test_memory_leak.py index 306430e..31f56e0 100644 --- a/tests/test_memory_leak.py +++ b/tests/test_memory_leak.py @@ -121,8 +121,7 @@ class Foo: decorator(Foo) with mock.patch( - "marshmallow_dataclass.lazy_class_attribute", - side_effect=Exception("forced exception"), + "marshmallow_dataclass.setattr", side_effect=Exception("forced exception") ) as m: f() From ae08afda5b9d76c34065b6acf7f311b3ef190119 Mon Sep 17 00:00:00 2001 From: Jeff Dairiki Date: Wed, 18 Jan 2023 15:15:28 -0800 Subject: [PATCH 24/32] feat: check type of metadata["marshmallow_field"] FIXME: may should warn instead of raising exception --- marshmallow_dataclass/__init__.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index e98b31b..dc6993d 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -1141,6 +1141,11 @@ def _field_for_schema( # If the field was already defined by the user predefined_field = metadata.get("marshmallow_field") if predefined_field: + if not isinstance(predefined_field, marshmallow.fields.Field): + raise TypeError( + "metadata['marshmallow_field'] must be set to a Field instance, " + f"not {predefined_field}" + ) return predefined_field if default is not marshmallow.missing: From 484c2b3f13e1e22f2bca44c9856aca216325a592 Mon Sep 17 00:00:00 2001 From: Jeff Dairiki Date: Wed, 18 Jan 2023 16:40:39 -0800 Subject: [PATCH 25/32] feat(mypy plugin): add type annotation for Schema attribute to dataclasses --- marshmallow_dataclass/mypy.py | 29 +++++++++++++++++++++++++---- tests/test_mypy.yml | 18 +++++++++++++++++- 2 files changed, 42 insertions(+), 5 deletions(-) diff --git a/marshmallow_dataclass/mypy.py b/marshmallow_dataclass/mypy.py index d33a5ad..a45cb19 100644 --- a/marshmallow_dataclass/mypy.py +++ b/marshmallow_dataclass/mypy.py @@ -2,8 +2,10 @@ from typing import Callable, Optional, Type from mypy import nodes -from mypy.plugin import DynamicClassDefContext, Plugin +from mypy.plugin import ClassDefContext, DynamicClassDefContext, Plugin from mypy.plugins import dataclasses +from mypy.plugins.common import add_attribute_to_class +from mypy.types import AnyType, TypeOfAny, TypeType import marshmallow_dataclass @@ -22,11 +24,30 @@ def get_dynamic_class_hook( return new_type_hook return None - def get_class_decorator_hook(self, fullname: str): + def get_class_decorator_hook( + self, fullname: str + ) -> Optional[Callable[[ClassDefContext], None]]: if fullname == "marshmallow_dataclass.dataclass": - return dataclasses.dataclass_class_maker_callback + return dataclasses.dataclass_tag_callback return None + def get_class_decorator_hook_2( + self, fullname: str + ) -> Optional[Callable[[ClassDefContext], bool]]: + if fullname == "marshmallow_dataclass.dataclass": + return class_decorator_hook + return None + + +def class_decorator_hook(ctx: ClassDefContext) -> bool: + if not dataclasses.dataclass_class_maker_callback(ctx): + return False + any_type = AnyType(TypeOfAny.explicit) + schema_type = ctx.api.named_type_or_none("marshmallow.Schema") or any_type + schema_type_type = TypeType.make_normalized(schema_type) + add_attribute_to_class(ctx.api, ctx.cls, "Schema", schema_type_type) + return True + def new_type_hook(ctx: DynamicClassDefContext) -> None: """ @@ -66,6 +87,6 @@ def _get_arg_by_name( except TypeError: return None try: - return bound_args.arguments[name] + return bound_args.arguments[name] # type: ignore[no-any-return] except KeyError: return None diff --git a/tests/test_mypy.yml b/tests/test_mypy.yml index 55e5cb2..3b99545 100644 --- a/tests/test_mypy.yml +++ b/tests/test_mypy.yml @@ -42,6 +42,22 @@ name: str user = User(id=4, name='Johny') + +- case: dataclass_Schema_attribute + mypy_config: | + follow_imports = silent + plugins = marshmallow_dataclass.mypy + env: + - PYTHONPATH=. + main: | + from marshmallow_dataclass import dataclass + + @dataclass + class Test: + child: "Test" + + reveal_type(Test.Schema) # N: Revealed type is "Type[marshmallow.schema.Schema]" + - case: public_custom_types mypy_config: | follow_imports = silent @@ -63,5 +79,5 @@ website = Website(url="http://www.example.org", email="admin@example.org") reveal_type(website.url) # N: Revealed type is "builtins.str" reveal_type(website.email) # N: Revealed type is "builtins.str" - + Website(url=42, email="user@email.com") # E: Argument "url" to "Website" has incompatible type "int"; expected "str" [arg-type] From 6a0b05a273ef5d72c91fc4131e709f929f4c1388 Mon Sep 17 00:00:00 2001 From: Jeff Dairiki Date: Wed, 18 Jan 2023 18:08:22 -0800 Subject: [PATCH 26/32] chore: clean up type annotations --- marshmallow_dataclass/collection_field.py | 10 ++++++---- marshmallow_dataclass/union_field.py | 20 ++++++++++++++------ 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/marshmallow_dataclass/collection_field.py b/marshmallow_dataclass/collection_field.py index 6823b72..0f9d4c2 100644 --- a/marshmallow_dataclass/collection_field.py +++ b/marshmallow_dataclass/collection_field.py @@ -29,16 +29,18 @@ class Set(marshmallow.fields.List): will be random. So if the order matters, use a List or Sequence ! """ + set_type: typing.Type[ + typing.Union[typing.FrozenSet[typing.Any], typing.Set[typing.Any]] + ] + def __init__( self, cls_or_instance: typing.Union[marshmallow.fields.Field, type], frozen: bool = False, - **kwargs, + **kwargs: typing.Any, ): super().__init__(cls_or_instance, **kwargs) - self.set_type: typing.Type[typing.Union[frozenset, set]] = ( - frozenset if frozen else set - ) + self.set_type = frozenset if frozen else set def _deserialize( # type: ignore[override] self, diff --git a/marshmallow_dataclass/union_field.py b/marshmallow_dataclass/union_field.py index 6e87e29..c7875ec 100644 --- a/marshmallow_dataclass/union_field.py +++ b/marshmallow_dataclass/union_field.py @@ -1,5 +1,5 @@ import copy -from typing import List, Tuple, Any, Optional +from typing import Any, List, Mapping, Optional, Tuple import typeguard from marshmallow import fields, Schema, ValidationError @@ -26,21 +26,23 @@ class Union(fields.Field): :param kwargs: The same keyword arguments that :class:`Field` receives. """ - def __init__(self, union_fields: List[Tuple[type, fields.Field]], **kwargs): + def __init__(self, union_fields: List[Tuple[type, fields.Field]], **kwargs: Any): super().__init__(**kwargs) self.union_fields = union_fields def _bind_to_schema(self, field_name: str, schema: Schema) -> None: - super()._bind_to_schema(field_name, schema) + super()._bind_to_schema(field_name, schema) # type: ignore[no-untyped-call] new_union_fields = [] for typ, field in self.union_fields: field = copy.deepcopy(field) - field._bind_to_schema(field_name, self) + field._bind_to_schema(field_name, self) # type: ignore[no-untyped-call] new_union_fields.append((typ, field)) self.union_fields = new_union_fields - def _serialize(self, value: Any, attr: Optional[str], obj, **kwargs) -> Any: + def _serialize( + self, value: Any, attr: Optional[str], obj: Any, **kwargs: Any + ) -> Any: errors = [] if value is None: return value @@ -56,7 +58,13 @@ def _serialize(self, value: Any, attr: Optional[str], obj, **kwargs) -> Any: f"Unable to serialize value with any of the fields in the union: {errors}" ) - def _deserialize(self, value: Any, attr: Optional[str], data, **kwargs) -> Any: + def _deserialize( + self, + value: Any, + attr: Optional[str], + data: Optional[Mapping[str, Any]], + **kwargs: Any, + ) -> Any: errors = [] for typ, field in self.union_fields: try: From 76ccff5add31046a3f8baf1c343a3f982e247f54 Mon Sep 17 00:00:00 2001 From: Jeff Dairiki Date: Thu, 19 Jan 2023 12:36:11 -0800 Subject: [PATCH 27/32] test(mypy): turn on strict checking of marshmallow_dataclass Retain relaxed checking of tests for now --- pyproject.toml | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5bbb9ad..8503045 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,9 +17,8 @@ packages = [ # (See the docstring in mypy_plugin.py for more.) plugins = "mypy_plugin.py" -warn_redundant_casts = true -warn_unused_configs = true -disable_error_code = "annotation-unchecked" +strict = true +warn_unreachable = true [[tool.mypy.overrides]] # dependencies without type hints @@ -29,3 +28,13 @@ module = [ ] ignore_missing_imports = true +[[tool.mypy.overrides]] +module = [ + "tests.*", +] +disable_error_code = "annotation-unchecked" +check_untyped_defs = false +disallow_untyped_calls = false +disallow_untyped_defs = false +disallow_incomplete_defs = false + From 60d8f95e6f927b06dab0392dcf04fc83cb0c4367 Mon Sep 17 00:00:00 2001 From: Jeff Dairiki Date: Thu, 19 Jan 2023 15:33:09 -0800 Subject: [PATCH 28/32] refactor: make context-dependent functions methods of the context This allows us to cleanly eliminate the thread-local context stack. --- marshmallow_dataclass/__init__.py | 793 +++++++++++++++--------------- 1 file changed, 387 insertions(+), 406 deletions(-) diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index dc6993d..f7c1f9b 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -38,7 +38,6 @@ class User: import dataclasses import inspect import sys -import threading import types import warnings from enum import Enum @@ -138,6 +137,44 @@ def get_origin(tp): MAX_CLASS_SCHEMA_CACHE_SIZE = 1024 +class Error(TypeError): + """Class passed ``class_schema`` can not be converted to a Marshmallow schema. + + FIXME: Currently this inherits from TypeError for backward compatibility with + older versions of marshmallow_dataclass which always raised + TypeError(f"{name} is not a dataclass and cannot be turned into one.") + + """ + + +class InvalidClassError(ValueError, Error): + """Argument to ``class_schema`` can not be converted to a Marshmallow schema. + + This exception is raised when, while generating a Marshmallow schema for a + dataclass, a class is encountered for which a Marshmallow Schema can not + be generated. + + """ + + +class UnrecognizedFieldTypeError(Error): + """An unrecognized field type spec was encountered. + + This exception is raised when, while generating a Marshmallow schema for a + dataclass, a field is encountered for which a Marshmallow Field can not + be generated. + + """ + + +class UnboundTypeVarError(Error): + """TypeVar instance can not be resolved to a type spec. + + This exception is raised when an unbound TypeVar is encountered. + + """ + + def _maybe_get_callers_frame( cls: type, stacklevel: int = 1 ) -> Optional[types.FrameType]: @@ -505,9 +542,8 @@ def class_schema( if base_schema is None: base_schema = marshmallow.Schema - with _SchemaContext(globalns, localns, base_schema): - schema = _internal_class_schema(clazz) - + schema_ctx = _SchemaContext(globalns, localns, base_schema) + schema = schema_ctx.class_schema(clazz) assert not isinstance(schema, _Future) return schema @@ -635,7 +671,7 @@ def resolve(self, spec: Union[TypeVar_, TypeSpec]) -> TypeSpec: try: return self._args[spec] except KeyError as exc: - raise TypeError( + raise UnboundTypeVarError( f"generic type variable {spec.__name__} is not bound" ) from exc return spec @@ -683,279 +719,402 @@ def get_type_mapping( ) return getattr(base_schema, "TYPE_MAPPING", {}) - def __enter__(self) -> "_SchemaContext": - _schema_ctx_stack.push(self) - return self - - def __exit__( - self, - _typ: Optional[Type[BaseException]], - _value: Optional[BaseException], - _tb: Optional[types.TracebackType], - ) -> None: - _schema_ctx_stack.pop() + def class_schema( + self, clazz: type + ) -> Union[Type[marshmallow.Schema], _Future[Type[marshmallow.Schema]]]: + if clazz in self.seen_classes: + return self.seen_classes[clazz] + cache_key = clazz, self.base_schema + try: + return _schema_cache[cache_key] + except KeyError: + pass + + future: _Future[Type[marshmallow.Schema]] = _Future() + self.seen_classes[clazz] = future + + constructor: Callable[..., object] + + if self.is_simple_annotated_class(clazz): + class_name = clazz.__name__ + constructor = _simple_class_constructor(clazz) + attributes = self.schema_attrs_for_simple_class(clazz) + elif _is_generic_alias_of_dataclass(clazz): + origin = get_origin(clazz) + assert isinstance(origin, type) + class_name = origin.__name__ + constructor = origin + ctx = self.replace(generic_args=_GenericArgs(clazz, self.generic_args)) + attributes = ctx.schema_attrs_for_dataclass(origin) + elif dataclasses.is_dataclass(clazz): + class_name = clazz.__name__ + constructor = clazz + attributes = self.schema_attrs_for_dataclass(clazz) + else: + raise InvalidClassError( + f"{clazz} is not a dataclass or a simple annotated class" + ) -class _LocalStack(threading.local, Generic[_U]): - def __init__(self) -> None: - self.stack: List[_U] = [] + load_to_dict = self.base_schema.load + + def load( + self: marshmallow.Schema, + data: Union[Mapping[str, Any], Iterable[Mapping[str, Any]]], + *, + many: Optional[bool] = None, + unknown: Optional[str] = None, + **kwargs: Any, + ) -> Any: + many = self.many if many is None else bool(many) + loaded = load_to_dict(self, data, many=many, unknown=unknown, **kwargs) + if many: + return [constructor(**item) for item in loaded] + else: + return constructor(**loaded) + + attributes["load"] = load + + schema_class: Type[marshmallow.Schema] = type( + f"{class_name}Schema", (self.base_schema,), attributes + ) - def push(self, value: _U) -> None: - self.stack.append(value) + future.set_result(schema_class) + _schema_cache[cache_key] = schema_class + return schema_class - def pop(self) -> None: - self.stack.pop() + def schema_attrs_for_dataclass(self, clazz: type) -> Dict[str, Any]: + if _has_generic_base(clazz): + raise InvalidClassError( + "class_schema does not support dataclasses with generic base classes" + ) - @property - def top(self) -> _U: - return self.stack[-1] + type_hints = get_type_hints(clazz, globalns=self.globalns, localns=self.localns) + attrs = dict(_marshmallow_hooks(clazz)) + for field in dataclasses.fields(clazz): + if field.init: + typ = type_hints[field.name] + default = ( + field.default_factory + if field.default_factory is not dataclasses.MISSING + else field.default + if field.default is not dataclasses.MISSING + else marshmallow.missing + ) + attrs[field.name] = self.field_for_schema(typ, default, field.metadata) + return attrs + + def is_simple_annotated_class(self, obj: object) -> bool: + """Determine whether obj is a "simple annotated class". + + The ```class_schema``` function can generate schemas for + simple annotated classes (as well as for dataclasses). + """ + if not isinstance(obj, type): + return False + if getattr(obj, "__init__", None) is not object.__init__: + return False + if getattr(obj, "__new__", None) is not object.__new__: + return False + + type_hints = get_type_hints(obj, globalns=self.globalns, localns=self.localns) + return any(not typing_inspect.is_classvar(th) for th in type_hints.values()) + + def schema_attrs_for_simple_class(self, clazz: type) -> Dict[str, Any]: + type_hints = get_type_hints(clazz, globalns=self.globalns, localns=self.localns) + + attrs = dict(_marshmallow_hooks(clazz)) + for field_name, typ in type_hints.items(): + if not typing_inspect.is_classvar(typ): + default = getattr(clazz, field_name, marshmallow.missing) + attrs[field_name] = self.field_for_schema(typ, default) + return attrs + + def field_for_schema( + self, + typ: Union[type, object], + default: Any = marshmallow.missing, + metadata: Optional[Mapping[str, Any]] = None, + ) -> marshmallow.fields.Field: + """ + Get a marshmallow Field corresponding to the given python type. + The metadata of the dataclass field is used as arguments to the marshmallow Field. + This is an internal version of field_for_schema. -_schema_ctx_stack = _LocalStack[_SchemaContext]() + :param typ: The type for which a field should be generated + :param default: value to use for (de)serialization when the field is missing + :param metadata: Additional parameters to pass to the marshmallow field constructor + """ -def _internal_class_schema( - clazz: type, -) -> Union[Type[marshmallow.Schema], _Future[Type[marshmallow.Schema]]]: - schema_ctx = _schema_ctx_stack.top - if clazz in schema_ctx.seen_classes: - return schema_ctx.seen_classes[clazz] + metadata = {} if metadata is None else dict(metadata) - cache_key = clazz, schema_ctx.base_schema - try: - return _schema_cache[cache_key] - except KeyError: - pass - - future: _Future[Type[marshmallow.Schema]] = _Future() - schema_ctx.seen_classes[clazz] = future - - generic_args = schema_ctx.generic_args - - constructor: Callable[..., object] - - if _is_simple_annotated_class(clazz): - class_name = clazz.__name__ - constructor = _simple_class_constructor(clazz) - attributes = _schema_attrs_for_simple_class(clazz) - elif _is_generic_alias_of_dataclass(clazz): - origin = get_origin(clazz) - assert isinstance(origin, type) - class_name = origin.__name__ - constructor = origin - with schema_ctx.replace(generic_args=_GenericArgs(clazz, generic_args)): - attributes = _schema_attrs_for_dataclass(origin) - elif dataclasses.is_dataclass(clazz): - class_name = clazz.__name__ - constructor = clazz - attributes = _schema_attrs_for_dataclass(clazz) - else: - raise TypeError(f"{clazz} is not a dataclass or a simple annotated class") - - base_schema = marshmallow.Schema - if schema_ctx.base_schema is not None: - base_schema = schema_ctx.base_schema - - load_to_dict = base_schema.load - - def load( - self: marshmallow.Schema, - data: Union[Mapping[str, Any], Iterable[Mapping[str, Any]]], - *, - many: Optional[bool] = None, - unknown: Optional[str] = None, - **kwargs: Any, - ) -> Any: - many = self.many if many is None else bool(many) - loaded = load_to_dict(self, data, many=many, unknown=unknown, **kwargs) - if many: - return [constructor(**item) for item in loaded] + # If the field was already defined by the user + predefined_field = metadata.get("marshmallow_field") + if predefined_field: + if not isinstance(predefined_field, marshmallow.fields.Field): + raise TypeError( + "metadata['marshmallow_field'] must be set to a Field instance, " + f"not {predefined_field}" + ) + return predefined_field + + if default is not marshmallow.missing: + metadata.setdefault("dump_default", default) + # 'missing' must not be set for required fields. + if not metadata.get("required"): + metadata.setdefault("load_default", default) else: - return constructor(**loaded) + metadata.setdefault("required", not typing_inspect.is_optional_type(typ)) - attributes["load"] = load + if self.generic_args is not None and isinstance(typ, TypeVar): + typ = self.generic_args.resolve(typ) - schema_class: Type[marshmallow.Schema] = type( - f"{class_name}Schema", (base_schema,), attributes - ) + if _is_builtin_collection_type(typ): + return self.field_for_builtin_collection_type(typ, metadata) - future.set_result(schema_class) - _schema_cache[cache_key] = schema_class - return schema_class + # Base types + type_mapping = self.get_type_mapping(use_mro=True) + field = type_mapping.get(typ) + if field is not None: + return field(**metadata) + if typ is Any: + metadata.setdefault("allow_none", True) + return marshmallow.fields.Raw(**metadata) -def _marshmallow_hooks(clazz: type) -> Iterator[Tuple[str, Any]]: - for name, attr in inspect.getmembers(clazz): - if hasattr(attr, "__marshmallow_hook__") or name in MEMBERS_WHITELIST: - yield name, attr + if typing_inspect.is_literal_type(typ): + return self.field_for_literal_type(typ, metadata) + if typing_inspect.is_final_type(typ): + return self.field_for_schema( + _get_subtype_for_final_type(typ, default), + default=default, + metadata=metadata, + ) -def _schema_attrs_for_dataclass(clazz: type) -> Dict[str, Any]: - if _has_generic_base(clazz): - raise TypeError( - "class_schema does not support dataclasses with generic base classes" - ) - schema_ctx = _schema_ctx_stack.top - type_hints = get_type_hints( - clazz, globalns=schema_ctx.globalns, localns=schema_ctx.localns - ) + if typing_inspect.is_union_type(typ): + return self.field_for_union_type(typ, metadata) - attrs = dict(_marshmallow_hooks(clazz)) - for field in dataclasses.fields(clazz): - if field.init: - typ = type_hints[field.name] - default = ( - field.default_factory - if field.default_factory is not dataclasses.MISSING - else field.default - if field.default is not dataclasses.MISSING - else marshmallow.missing - ) - attrs[field.name] = _field_for_schema(typ, default, field.metadata) - return attrs + if typing_inspect.is_new_type(typ): + return self.field_for_new_type(typ, default, metadata) + # enumerations + if isinstance(typ, type) and issubclass(typ, Enum): + return self.field_for_enum(typ, metadata) -def _schema_attrs_for_simple_class(clazz: type) -> Dict[str, Any]: - schema_ctx = _schema_ctx_stack.top - type_hints = get_type_hints( - clazz, globalns=schema_ctx.globalns, localns=schema_ctx.localns - ) + # nested dataclasses + if ( + dataclasses.is_dataclass(typ) + or _is_generic_alias_of_dataclass(typ) + or self.is_simple_annotated_class(typ) + ): + nested = self.schema_for_nested(typ) + # type spec for Nested.__init__ is not correct + return marshmallow.fields.Nested(nested, **metadata) # type: ignore[arg-type] - attrs = dict(_marshmallow_hooks(clazz)) - for field_name, typ in type_hints.items(): - if not typing_inspect.is_classvar(typ): - default = getattr(clazz, field_name, marshmallow.missing) - attrs[field_name] = _field_for_schema(typ, default) - return attrs + raise UnrecognizedFieldTypeError(f"can not deduce field type for {typ}") + def field_for_builtin_collection_type( + self, typ: object, metadata: Dict[str, Any] + ) -> marshmallow.fields.Field: + """ + Handle builtin container types like list, tuple, set, etc. + """ + origin = get_origin(typ) + if origin is None: + origin = typ + assert len(get_args(typ)) == 0 -def _simple_class_constructor(clazz: Type[_U]) -> Callable[..., _U]: - def constructor(**kwargs: Any) -> _U: - obj = clazz.__new__(clazz) - for k, v in kwargs.items(): - setattr(obj, k, v) - return obj + args = get_args(typ) - return constructor + if origin is tuple and ( + len(args) == 0 or (len(args) == 2 and args[1] is Ellipsis) + ): + # Special case: homogeneous tuple — treat as Sequence + origin = collections.abc.Sequence + args = args[:1] + # Override base_schema.TYPE_MAPPING to change the class used for generic types below + def get_field_type(type_spec: Any, default: Type[_Field]) -> Type[_Field]: + type_mapping = self.get_type_mapping(use_mro=False) + return type_mapping.get(type_spec, default) # type: ignore[return-value] -def _is_builtin_collection_type(typ: object) -> bool: - origin = get_origin(typ) - if origin is None: - origin = typ + def get_field(i: int) -> marshmallow.fields.Field: + return self.field_for_schema(args[i] if args else Any) - return origin in { - list, - collections.abc.Sequence, - set, - frozenset, - tuple, - dict, - collections.abc.Mapping, - } + if origin is tuple: + tuple_type = get_field_type(Tuple, default=marshmallow.fields.Tuple) + return tuple_type(tuple(map(self.field_for_schema, args)), **metadata) + if origin in (dict, collections.abc.Mapping): + dict_type = get_field_type(Dict, default=marshmallow.fields.Dict) + return dict_type(keys=get_field(0), values=get_field(1), **metadata) -def _field_for_builtin_collection_type( - typ: object, metadata: Dict[str, Any] -) -> marshmallow.fields.Field: - """ - Handle builtin container types like list, tuple, set, etc. - """ - origin = get_origin(typ) - if origin is None: - origin = typ - assert len(get_args(typ)) == 0 + if origin is list: + list_type = get_field_type(List, default=marshmallow.fields.List) + return list_type(get_field(0), **metadata) - args = get_args(typ) + if origin is collections.abc.Sequence: + from . import collection_field - schema_ctx = _schema_ctx_stack.top + return collection_field.Sequence(get_field(0), **metadata) - # Override base_schema.TYPE_MAPPING to change the class used for generic types below - def get_field_type(type_spec: Any, default: Type[_Field]) -> Type[_Field]: - type_mapping = schema_ctx.get_type_mapping() - return type_mapping.get(type_spec, default) # type: ignore[return-value] + if origin in (set, frozenset): + from . import collection_field - if origin is tuple and (len(args) == 0 or (len(args) == 2 and args[1] is Ellipsis)): - # Special case: homogeneous tuple — treat as Sequence - origin = collections.abc.Sequence - args = args[:1] + frozen = origin is frozenset + return collection_field.Set(get_field(0), frozen=frozen, **metadata) - if origin is tuple: - tuple_type = get_field_type(Tuple, default=marshmallow.fields.Tuple) - return tuple_type(tuple(map(_field_for_schema, args)), **metadata) + raise ValueError(f"{typ} is not a builtin collection type") - def get_field(i: int) -> marshmallow.fields.Field: - return _field_for_schema(args[i] if args else Any) + def field_for_union_type( + self, typ: object, metadata: Dict[str, Any] + ) -> marshmallow.fields.Field: + """ + Construct the appropriate Field for a union or optional type. + """ + assert typing_inspect.is_union_type(typ) + subtypes = [t for t in get_args(typ) if t is not NoneType] + + if typing_inspect.is_optional_type(typ): + metadata = { + "allow_none": True, + "dump_default": None, + **metadata, + } + if not metadata.setdefault("required", False): + metadata.setdefault("load_default", None) + + if len(subtypes) == 1: + return self.field_for_schema(subtypes[0], metadata=metadata) + + from . import union_field + + return union_field.Union( + [ + (typ, self.field_for_schema(typ, metadata={"required": True})) + for typ in subtypes + ], + **metadata, + ) - if origin in (dict, collections.abc.Mapping): - dict_type = get_field_type(Dict, default=marshmallow.fields.Dict) - return dict_type(keys=get_field(0), values=get_field(1), **metadata) + @staticmethod + def field_for_literal_type( + typ: object, metadata: Dict[str, Any] + ) -> marshmallow.fields.Field: + """ + Construct the appropriate Field for a Literal type. + """ + validate: marshmallow.validate.Validator - if origin is list: - list_type = get_field_type(List, default=marshmallow.fields.List) - return list_type(get_field(0), **metadata) + assert typing_inspect.is_literal_type(typ) + arguments = typing_inspect.get_args(typ) + if len(arguments) == 1: + validate = marshmallow.validate.Equal(arguments[0]) + else: + validate = marshmallow.validate.OneOf(arguments) + return marshmallow.fields.Raw(validate=validate, **metadata) - if origin is collections.abc.Sequence: - from . import collection_field + def field_for_new_type( + self, typ: object, default: Any, metadata: Dict[str, Any] + ) -> marshmallow.fields.Field: + """ + Return a new field for fields based on a NewType. + """ + # Add the information coming our custom NewType implementation + typ_args = getattr(typ, "_marshmallow_args", {}) + + # Handle multiple validators from both `typ` and `metadata`. + # See https://github.com/lovasoa/marshmallow_dataclass/issues/91 + validators: List[Callable[[Any], Any]] = [] + for args in (typ_args, metadata): + validate = args.get("validate") + if marshmallow.utils.is_iterable_but_not_string(validate): + validators.extend(validate) # type: ignore[arg-type] + elif validate is not None: + validators.append(validate) - return collection_field.Sequence(get_field(0), **metadata) + metadata = { + **typ_args, + **metadata, + "validate": validators if validators else None, + } + type_name = getattr(typ, "__name__", repr(typ)) + metadata.setdefault("metadata", {}).setdefault("description", type_name) - if origin in (set, frozenset): - from . import collection_field + field: Optional[Type[marshmallow.fields.Field]] = getattr( + typ, "_marshmallow_field", None + ) + if field is not None: + return field(**metadata) + return self.field_for_schema( + typ.__supertype__, # type: ignore[attr-defined] + default=default, + metadata=metadata, + ) - frozen = origin is frozenset - return collection_field.Set(get_field(0), frozen=frozen, **metadata) + @staticmethod + def field_for_enum(typ: type, metadata: Dict[str, Any]) -> marshmallow.fields.Field: + """ + Return a new field for an Enum field. + """ + if sys.version_info >= (3, 7): + return marshmallow.fields.Enum(typ, **metadata) + else: + # Remove this once support for python 3.6 is dropped. + import marshmallow_enum - raise ValueError(f"{typ} is not a builtin collection type") + return marshmallow_enum.EnumField(typ, **metadata) + def schema_for_nested( + self, typ: object + ) -> Union[Type[marshmallow.Schema], Callable[[], Type[marshmallow.Schema]]]: + """ + Return a marshmallow.Schema for a nested dataclass (or simple annotated class) + """ + if isinstance(typ, type) and hasattr(typ, "Schema"): + # marshmallow_dataclass.dataclass + # Defer evaluation of .Schema attribute, to avoid forward reference issues + return partial(getattr, typ, "Schema") -def _field_for_union_type( - typ: object, metadata: Dict[str, Any] -) -> marshmallow.fields.Field: - """ - Construct the appropriate Field for a union or optional type. - """ - assert typing_inspect.is_union_type(typ) - subtypes = [t for t in get_args(typ) if t is not NoneType] + class_schema = self.class_schema(typ) # type: ignore[arg-type] # FIXME + if isinstance(class_schema, _Future): + return class_schema.result + return class_schema - if typing_inspect.is_optional_type(typ): - metadata = { - "allow_none": True, - "dump_default": None, - **metadata, - } - if not metadata.setdefault("required", False): - metadata.setdefault("load_default", None) - if len(subtypes) == 1: - return _field_for_schema(subtypes[0], metadata=metadata) +def _marshmallow_hooks(clazz: type) -> Iterator[Tuple[str, Any]]: + for name, attr in inspect.getmembers(clazz): + if hasattr(attr, "__marshmallow_hook__") or name in MEMBERS_WHITELIST: + yield name, attr - from . import union_field - return union_field.Union( - [ - (typ, _field_for_schema(typ, metadata={"required": True})) - for typ in subtypes - ], - **metadata, - ) +def _simple_class_constructor(clazz: Type[_U]) -> Callable[..., _U]: + def constructor(**kwargs: Any) -> _U: + obj = clazz.__new__(clazz) + for k, v in kwargs.items(): + setattr(obj, k, v) + return obj + return constructor -def _field_for_literal_type( - typ: object, metadata: Dict[str, Any] -) -> marshmallow.fields.Field: - """ - Construct the appropriate Field for a Literal type. - """ - validate: marshmallow.validate.Validator - assert typing_inspect.is_literal_type(typ) - arguments = typing_inspect.get_args(typ) - if len(arguments) == 1: - validate = marshmallow.validate.Equal(arguments[0]) - else: - validate = marshmallow.validate.OneOf(arguments) - return marshmallow.fields.Raw(validate=validate, **metadata) +def _is_builtin_collection_type(typ: object) -> bool: + origin = get_origin(typ) + if origin is None: + origin = typ + + return origin in { + list, + collections.abc.Sequence, + set, + frozenset, + tuple, + dict, + collections.abc.Mapping, + } def _get_subtype_for_final_type(typ: object, default: Any) -> object: @@ -994,95 +1153,6 @@ def _get_subtype_for_final_type(typ: object, default: Any) -> object: return type(default) -def _field_for_new_type( - typ: object, default: Any, metadata: Dict[str, Any] -) -> marshmallow.fields.Field: - """ - Return a new field for fields based on a NewType. - """ - # Add the information coming our custom NewType implementation - typ_args = getattr(typ, "_marshmallow_args", {}) - - # Handle multiple validators from both `typ` and `metadata`. - # See https://github.com/lovasoa/marshmallow_dataclass/issues/91 - validators: List[Callable[[Any], Any]] = [] - for args in (typ_args, metadata): - validate = args.get("validate") - if marshmallow.utils.is_iterable_but_not_string(validate): - validators.extend(validate) # type: ignore[arg-type] - elif validate is not None: - validators.append(validate) - - metadata = { - **typ_args, - **metadata, - "validate": validators if validators else None, - } - type_name = getattr(typ, "__name__", repr(typ)) - metadata.setdefault("metadata", {}).setdefault("description", type_name) - - field: Optional[Type[marshmallow.fields.Field]] = getattr( - typ, "_marshmallow_field", None - ) - if field is not None: - return field(**metadata) - return _field_for_schema( - typ.__supertype__, # type: ignore[attr-defined] - default=default, - metadata=metadata, - ) - - -def _field_for_enum(typ: type, metadata: Dict[str, Any]) -> marshmallow.fields.Field: - """ - Return a new field for an Enum field. - """ - if sys.version_info >= (3, 7): - return marshmallow.fields.Enum(typ, **metadata) - else: - # Remove this once support for python 3.6 is dropped. - import marshmallow_enum - - return marshmallow_enum.EnumField(typ, **metadata) - - -def _schema_for_nested( - typ: object, -) -> Union[Type[marshmallow.Schema], Callable[[], Type[marshmallow.Schema]]]: - """ - Return a marshmallow.Schema for a nested dataclass (or simple annotated class) - """ - if isinstance(typ, type) and hasattr(typ, "Schema"): - # marshmallow_dataclass.dataclass - # Defer evaluation of .Schema attribute, to avoid forward reference issues - return partial(getattr, typ, "Schema") - - class_schema = _internal_class_schema(typ) # type: ignore[arg-type] # FIXME - if isinstance(class_schema, _Future): - return class_schema.result - return class_schema - - -def _is_simple_annotated_class(obj: object) -> bool: - """Determine whether obj is a "simple annotated class". - - The ```class_schema``` function can generate schemas for - simple annotated classes (as well as for dataclasses). - """ - if not isinstance(obj, type): - return False - if getattr(obj, "__init__", None) is not object.__init__: - return False - if getattr(obj, "__new__", None) is not object.__new__: - return False - - schema_ctx = _schema_ctx_stack.top - type_hints = get_type_hints( - obj, globalns=schema_ctx.globalns, localns=schema_ctx.localns - ) - return any(not typing_inspect.is_classvar(th) for th in type_hints.values()) - - def field_for_schema( typ: type, default: Any = marshmallow.missing, @@ -1114,97 +1184,8 @@ def field_for_schema( if base_schema is None: base_schema = marshmallow.Schema localns = typ_frame.f_locals if typ_frame is not None else None - with _SchemaContext(localns=localns, base_schema=base_schema): - return _field_for_schema(typ, default, metadata) - - -def _field_for_schema( - typ: object, - default: Any = marshmallow.missing, - metadata: Optional[Mapping[str, Any]] = None, -) -> marshmallow.fields.Field: - """ - Get a marshmallow Field corresponding to the given python type. - The metadata of the dataclass field is used as arguments to the marshmallow Field. - - This is an internal version of field_for_schema. It assumes a _SchemaContext - has been pushed onto the local stack. - - :param typ: The type for which a field should be generated - :param default: value to use for (de)serialization when the field is missing - :param metadata: Additional parameters to pass to the marshmallow field constructor - - """ - - metadata = {} if metadata is None else dict(metadata) - - # If the field was already defined by the user - predefined_field = metadata.get("marshmallow_field") - if predefined_field: - if not isinstance(predefined_field, marshmallow.fields.Field): - raise TypeError( - "metadata['marshmallow_field'] must be set to a Field instance, " - f"not {predefined_field}" - ) - return predefined_field - - if default is not marshmallow.missing: - metadata.setdefault("dump_default", default) - # 'missing' must not be set for required fields. - if not metadata.get("required"): - metadata.setdefault("load_default", default) - else: - metadata.setdefault("required", not typing_inspect.is_optional_type(typ)) - - schema_ctx = _schema_ctx_stack.top - - if schema_ctx.generic_args is not None and isinstance(typ, TypeVar): - typ = schema_ctx.generic_args.resolve(typ) - - if _is_builtin_collection_type(typ): - return _field_for_builtin_collection_type(typ, metadata) - - # Base types - type_mapping = schema_ctx.get_type_mapping(use_mro=True) - field = type_mapping.get(typ) - if field is not None: - return field(**metadata) - - if typ is Any: - metadata.setdefault("allow_none", True) - return marshmallow.fields.Raw(**metadata) - - if typing_inspect.is_literal_type(typ): - return _field_for_literal_type(typ, metadata) - - if typing_inspect.is_final_type(typ): - return _field_for_schema( - _get_subtype_for_final_type(typ, default), - default=default, - metadata=metadata, - ) - - if typing_inspect.is_union_type(typ): - return _field_for_union_type(typ, metadata) - - if typing_inspect.is_new_type(typ): - return _field_for_new_type(typ, default, metadata) - - # enumerations - if isinstance(typ, type) and issubclass(typ, Enum): - return _field_for_enum(typ, metadata) - - # nested dataclasses - if ( - dataclasses.is_dataclass(typ) - or _is_generic_alias_of_dataclass(typ) - or _is_simple_annotated_class(typ) - ): - nested = _schema_for_nested(typ) - # type spec for Nested.__init__ is not correct - return marshmallow.fields.Nested(nested, **metadata) # type: ignore[arg-type] - - raise TypeError(f"can not deduce field type for {typ}") + schema_ctx = _SchemaContext(localns=localns, base_schema=base_schema) + return schema_ctx.field_for_schema(typ, default, metadata) def NewType( From 2a26a24ddfcdcd469815ead695a00b1d94d4b3e8 Mon Sep 17 00:00:00 2001 From: Jeff Dairiki Date: Thu, 19 Jan 2023 22:09:18 -0800 Subject: [PATCH 29/32] fix: type annotations for class_schema and field_for_schema --- marshmallow_dataclass/__init__.py | 189 ++++++++++++++++++++---------- 1 file changed, 129 insertions(+), 60 deletions(-) diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index f7c1f9b..86c60ee 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -114,16 +114,16 @@ def get_origin(tp): TypeVar_ = type -if sys.version_info >= (3, 10): - from typing import TypeGuard -else: - from typing_extensions import TypeGuard - if sys.version_info >= (3, 8): from typing import Protocol else: from typing_extensions import Protocol +if sys.version_info >= (3, 10): + from typing import TypeGuard +else: + from typing_extensions import TypeGuard + NoneType = type(None) _U = TypeVar("_U") @@ -175,8 +175,62 @@ class UnboundTypeVarError(Error): """ +################################################################ +# Type aliases and type guards (FIXME: move these) + +if sys.version_info >= (3, 7): + _TypeVarType = TypeVar +else: + # py36: type.TypeVar does not work as a type annotation + # (⇒ "AttributeError: type object 'TypeVar' has no attribute '_gorg'") + _TypeVarType = typing_NewType("_TypeVarType", type) + + +def _is_type_var(obj: object) -> TypeGuard[_TypeVarType]: + return isinstance(obj, TypeVar) + + +TypeSpec = object +GenericAlias = typing_NewType("GenericAlias", object) +GenericAliasOfDataclass = typing_NewType("GenericAliasOfDataclass", GenericAlias) + + +def _is_generic_alias_of_dataclass( + cls: object, +) -> TypeGuard[GenericAliasOfDataclass]: + """ + Check if given class is a generic alias of a dataclass, if the dataclass is + defined as `class A(Generic[T])`, this method will return true if `A[int]` is passed + """ + return _is_dataclass_type(get_origin(cls)) + + +_DataclassType = typing_NewType("_DataclassType", type) + + +def _is_dataclass_type(obj: object) -> TypeGuard[_DataclassType]: + return isinstance(obj, type) and dataclasses.is_dataclass(obj) + + +class _NewType(Protocol): + def __call__(self, obj: _U) -> _U: + ... + + @property + def __name__(self) -> str: + ... + + @property + def __supertype__(self) -> type: + ... + + +def _is_new_type(obj: object) -> TypeGuard[_NewType]: + return bool(typing_inspect.is_new_type(obj)) + + def _maybe_get_callers_frame( - cls: type, stacklevel: int = 1 + cls: Union[type, GenericAliasOfDataclass], stacklevel: int = 1 ) -> Optional[types.FrameType]: """Return the caller's frame, but only if it will help resolve forward type references. @@ -317,7 +371,7 @@ def decorator(cls: Type[_U], stacklevel: int = 1) -> Type[_U]: return decorator(_cls, stacklevel=stacklevel + 1) -class ClassDecorator(Protocol): +class _ClassDecorator(Protocol): def __call__(self, cls: Type[_U], stacklevel: int = 1) -> Type[_U]: ... @@ -328,7 +382,7 @@ def add_schema( base_schema: Optional[Type[marshmallow.Schema]] = None, cls_frame: Optional[types.FrameType] = None, stacklevel: int = 1, -) -> ClassDecorator: +) -> _ClassDecorator: ... @@ -348,7 +402,7 @@ def add_schema( cls_frame: Optional[types.FrameType] = None, stacklevel: int = 1, attr_name: str = "Schema", -) -> Union[Type[_U], ClassDecorator]: +) -> Union[Type[_U], _ClassDecorator]: """ This decorator adds a marshmallow schema as the 'Schema' attribute in a dataclass. It uses :func:`class_schema` internally. @@ -407,7 +461,7 @@ def class_schema( def class_schema( - clazz: type, # FIXME: type | _GenericAlias + clazz: object, base_schema: Optional[Type[marshmallow.Schema]] = None, # FIXME: delete clazz_frame from API? clazz_frame: Optional[types.FrameType] = None, @@ -533,6 +587,9 @@ def class_schema( >>> class_schema(Custom)().load({}) Custom(name=None) """ + if not (_is_dataclass_type(clazz) or _is_generic_alias_of_dataclass(clazz)): + raise InvalidClassError(f"{clazz} is not a dataclass") + if localns is None: if clazz_frame is None: clazz_frame = _maybe_get_callers_frame(clazz) @@ -627,20 +684,6 @@ def set_result(self, result: _U) -> None: self._done = True -TypeSpec = typing_NewType("TypeSpec", object) -GenericAliasOfDataclass = typing_NewType("GenericAliasOfDataclass", object) - - -def _is_generic_alias_of_dataclass( - cls: object, -) -> TypeGuard[GenericAliasOfDataclass]: - """ - Check if given class is a generic alias of a dataclass, if the dataclass is - defined as `class A(Generic[T])`, this method will return true if `A[int]` is passed - """ - return dataclasses.is_dataclass(get_origin(cls)) - - def _has_generic_base(cls: type) -> bool: """Return True if cls has any generic base classes.""" return any(typing_inspect.get_parameters(base) for base in cls.__mro__[1:]) @@ -697,7 +740,7 @@ class _SchemaContext: localns: Optional[Dict[str, Any]] = None base_schema: Type[marshmallow.Schema] = marshmallow.Schema generic_args: Optional[_GenericArgs] = None - seen_classes: Dict[type, _Future[Type[marshmallow.Schema]]] = dataclasses.field( + seen_classes: Dict[Hashable, _Future[Type[marshmallow.Schema]]] = dataclasses.field( default_factory=dict ) @@ -706,7 +749,7 @@ def replace(self, generic_args: Optional[_GenericArgs]) -> "_SchemaContext": def get_type_mapping( self, use_mro: bool = False - ) -> Mapping[Any, Type[marshmallow.fields.Field]]: + ) -> Mapping[TypeSpec, Type[marshmallow.fields.Field]]: """Get base_schema.TYPE_MAPPING. If use_mro is true, then merges the TYPE_MAPPINGs from @@ -720,7 +763,7 @@ def get_type_mapping( return getattr(base_schema, "TYPE_MAPPING", {}) def class_schema( - self, clazz: type + self, clazz: Hashable ) -> Union[Type[marshmallow.Schema], _Future[Type[marshmallow.Schema]]]: if clazz in self.seen_classes: return self.seen_classes[clazz] @@ -742,12 +785,12 @@ def class_schema( attributes = self.schema_attrs_for_simple_class(clazz) elif _is_generic_alias_of_dataclass(clazz): origin = get_origin(clazz) - assert isinstance(origin, type) + assert _is_dataclass_type(origin) class_name = origin.__name__ constructor = origin ctx = self.replace(generic_args=_GenericArgs(clazz, self.generic_args)) attributes = ctx.schema_attrs_for_dataclass(origin) - elif dataclasses.is_dataclass(clazz): + elif _is_dataclass_type(clazz): class_name = clazz.__name__ constructor = clazz attributes = self.schema_attrs_for_dataclass(clazz) @@ -783,7 +826,7 @@ def load( _schema_cache[cache_key] = schema_class return schema_class - def schema_attrs_for_dataclass(self, clazz: type) -> Dict[str, Any]: + def schema_attrs_for_dataclass(self, clazz: _DataclassType) -> Dict[str, Any]: if _has_generic_base(clazz): raise InvalidClassError( "class_schema does not support dataclasses with generic base classes" @@ -804,7 +847,9 @@ def schema_attrs_for_dataclass(self, clazz: type) -> Dict[str, Any]: attrs[field.name] = self.field_for_schema(typ, default, field.metadata) return attrs - def is_simple_annotated_class(self, obj: object) -> bool: + _SimpleClass = typing_NewType("_SimpleClass", type) + + def is_simple_annotated_class(self, obj: object) -> TypeGuard[_SimpleClass]: """Determine whether obj is a "simple annotated class". The ```class_schema``` function can generate schemas for @@ -820,7 +865,7 @@ def is_simple_annotated_class(self, obj: object) -> bool: type_hints = get_type_hints(obj, globalns=self.globalns, localns=self.localns) return any(not typing_inspect.is_classvar(th) for th in type_hints.values()) - def schema_attrs_for_simple_class(self, clazz: type) -> Dict[str, Any]: + def schema_attrs_for_simple_class(self, clazz: _SimpleClass) -> Dict[str, Any]: type_hints = get_type_hints(clazz, globalns=self.globalns, localns=self.localns) attrs = dict(_marshmallow_hooks(clazz)) @@ -832,7 +877,7 @@ def schema_attrs_for_simple_class(self, clazz: type) -> Dict[str, Any]: def field_for_schema( self, - typ: Union[type, object], + typ: object, default: Any = marshmallow.missing, metadata: Optional[Mapping[str, Any]] = None, ) -> marshmallow.fields.Field: @@ -897,7 +942,7 @@ def field_for_schema( if typing_inspect.is_union_type(typ): return self.field_for_union_type(typ, metadata) - if typing_inspect.is_new_type(typ): + if _is_new_type(typ): return self.field_for_new_type(typ, default, metadata) # enumerations @@ -906,7 +951,7 @@ def field_for_schema( # nested dataclasses if ( - dataclasses.is_dataclass(typ) + _is_dataclass_type(typ) or _is_generic_alias_of_dataclass(typ) or self.is_simple_annotated_class(typ) ): @@ -937,7 +982,7 @@ def field_for_builtin_collection_type( args = args[:1] # Override base_schema.TYPE_MAPPING to change the class used for generic types below - def get_field_type(type_spec: Any, default: Type[_Field]) -> Type[_Field]: + def get_field_type(type_spec: TypeSpec, default: Type[_Field]) -> Type[_Field]: type_mapping = self.get_type_mapping(use_mro=False) return type_mapping.get(type_spec, default) # type: ignore[return-value] @@ -945,8 +990,9 @@ def get_field(i: int) -> marshmallow.fields.Field: return self.field_for_schema(args[i] if args else Any) if origin is tuple: + tuple_fields = tuple(self.field_for_schema(arg) for arg in args) tuple_type = get_field_type(Tuple, default=marshmallow.fields.Tuple) - return tuple_type(tuple(map(self.field_for_schema, args)), **metadata) + return tuple_type(tuple_fields, **metadata) if origin in (dict, collections.abc.Mapping): dict_type = get_field_type(Dict, default=marshmallow.fields.Dict) @@ -1018,41 +1064,33 @@ def field_for_literal_type( return marshmallow.fields.Raw(validate=validate, **metadata) def field_for_new_type( - self, typ: object, default: Any, metadata: Dict[str, Any] + self, new_type: _NewType, default: Any, metadata: Mapping[str, Any] ) -> marshmallow.fields.Field: """ Return a new field for fields based on a NewType. """ # Add the information coming our custom NewType implementation - typ_args = getattr(typ, "_marshmallow_args", {}) # Handle multiple validators from both `typ` and `metadata`. # See https://github.com/lovasoa/marshmallow_dataclass/issues/91 - validators: List[Callable[[Any], Any]] = [] - for args in (typ_args, metadata): - validate = args.get("validate") - if marshmallow.utils.is_iterable_but_not_string(validate): - validators.extend(validate) # type: ignore[arg-type] - elif validate is not None: - validators.append(validate) - - metadata = { - **typ_args, - **metadata, - "validate": validators if validators else None, - } - type_name = getattr(typ, "__name__", repr(typ)) - metadata.setdefault("metadata", {}).setdefault("description", type_name) + merged_metadata = _merge_metadata( + getattr(new_type, "_marshmallow_args", {}), + metadata, + ) + merged_metadata.setdefault("metadata", {}).setdefault( + "description", new_type.__name__ + ) field: Optional[Type[marshmallow.fields.Field]] = getattr( - typ, "_marshmallow_field", None + new_type, "_marshmallow_field", None ) if field is not None: - return field(**metadata) + return field(**merged_metadata) + return self.field_for_schema( - typ.__supertype__, # type: ignore[attr-defined] + new_type.__supertype__, default=default, - metadata=metadata, + metadata=merged_metadata, ) @staticmethod @@ -1079,12 +1117,43 @@ def schema_for_nested( # Defer evaluation of .Schema attribute, to avoid forward reference issues return partial(getattr, typ, "Schema") - class_schema = self.class_schema(typ) # type: ignore[arg-type] # FIXME + class_schema = self.class_schema(typ) if isinstance(class_schema, _Future): return class_schema.result return class_schema +def _merge_metadata(*args: Mapping[str, Any]) -> Dict[str, Any]: + """Merge mutiple metadata mappings into a single dict. + + This is a standard dict merge, except that the "validate" field + is handled specially: validators specified in any of the args + are combined. + + """ + merged: Dict[str, Any] = {} + validators: List[Callable[[Any], Any]] = [] + + for metadata in args: + merged.update(metadata) + validate = metadata.get("validate") + if callable(validate): + validators.append(validate) + elif marshmallow.utils.is_iterable_but_not_string(validate): + assert isinstance(validate, Iterable) + validators.extend(validate) + elif validate is not None: + validators.append(validate) + + if not all(callable(validate) for validate in validators): + raise ValueError( + "the 'validate' parameter must be a callable or a collection of callables." + ) + + merged["validate"] = validators if validators else None + return merged + + def _marshmallow_hooks(clazz: type) -> Iterator[Tuple[str, Any]]: for name, attr in inspect.getmembers(clazz): if hasattr(attr, "__marshmallow_hook__") or name in MEMBERS_WHITELIST: @@ -1154,7 +1223,7 @@ def _get_subtype_for_final_type(typ: object, default: Any) -> object: def field_for_schema( - typ: type, + typ: object, default: Any = marshmallow.missing, metadata: Optional[Mapping[str, Any]] = None, base_schema: Optional[Type[marshmallow.Schema]] = None, From 5642fe2a582ba184124b5f1405942ef4b539ee2f Mon Sep 17 00:00:00 2001 From: Jeff Dairiki Date: Fri, 20 Jan 2023 02:05:57 -0800 Subject: [PATCH 30/32] refactor: clean up generic alias parameter mapping --- marshmallow_dataclass/__init__.py | 120 ++++++++++++++++++------------ tests/test_class_schema.py | 14 ++++ tests/test_typevar_bindings.py | 60 +++++++++++++++ 3 files changed, 146 insertions(+), 48 deletions(-) create mode 100644 tests/test_typevar_bindings.py diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index 86c60ee..260c820 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -40,6 +40,7 @@ class User: import sys import types import warnings +from contextlib import contextmanager from enum import Enum from functools import partial from typing import ( @@ -107,17 +108,15 @@ def get_origin(tp): if sys.version_info >= (3, 7): from typing import OrderedDict - - TypeVar_ = TypeVar else: from typing_extensions import OrderedDict - TypeVar_ = type - if sys.version_info >= (3, 8): from typing import Protocol + from typing import final else: from typing_extensions import Protocol + from typing_extensions import final if sys.version_info >= (3, 10): from typing import TypeGuard @@ -689,47 +688,57 @@ def _has_generic_base(cls: type) -> bool: return any(typing_inspect.get_parameters(base) for base in cls.__mro__[1:]) -class _GenericArgs(Mapping[TypeVar_, TypeSpec], collections.abc.Hashable): - """A mapping of TypeVars to type specs""" +@final +@dataclasses.dataclass(frozen=True) +class _TypeVarBindings(Mapping[TypeSpec, TypeSpec]): + """A mapping of bindings of TypeVars to type specs.""" - def __init__( - self, - generic_alias: GenericAliasOfDataclass, - binding: Optional["_GenericArgs"] = None, - ): - origin = typing_inspect.get_origin(generic_alias) - parameters: Iterable[TypeVar_] = typing_inspect.get_parameters(origin) - arguments: Iterable[TypeSpec] = get_args(generic_alias) - if binding is not None: - arguments = map(binding.resolve, arguments) - - self._args = dict(zip(parameters, arguments)) - self._hashvalue = hash(tuple(self._args.items())) - - _args: Mapping[TypeVar_, TypeSpec] - _hashvalue: int - - def resolve(self, spec: Union[TypeVar_, TypeSpec]) -> TypeSpec: - if isinstance(spec, TypeVar): - try: - return self._args[spec] - except KeyError as exc: - raise UnboundTypeVarError( - f"generic type variable {spec.__name__} is not bound" - ) from exc - return spec + parameters: Sequence[_TypeVarType] = () + args: Sequence[TypeSpec] = () + + def __post_init__(self) -> None: + if len(self.parameters) != len(self.args): + raise ValueError("the 'parameters' and 'args' must be of the same length") + + @classmethod + def from_generic_alias(cls, generic_alias: GenericAlias) -> "_TypeVarBindings": + origin = get_origin(generic_alias) + parameters = typing_inspect.get_parameters(origin) + args = get_args(generic_alias) + return cls(parameters, args) - def __getitem__(self, param: TypeVar_) -> TypeSpec: - return self._args[param] + def __getitem__(self, key: TypeSpec) -> TypeSpec: + try: + i = self.parameters.index(key) + except ValueError: + raise KeyError(key) from None + return self.args[i] - def __iter__(self) -> Iterator[TypeVar_]: - return iter(self._args.keys()) + def __iter__(self) -> Iterator[_TypeVarType]: + return iter(self.parameters) def __len__(self) -> int: - return len(self._args) + return len(self.parameters) + + def compose(self, other: "_TypeVarBindings") -> "_TypeVarBindings": + """Compose TypeVar bindings. + + Given: - def __hash__(self) -> int: - return self._hashvalue + def map(bindings, spec): + return bindings.get(spec, spec) + + composed = outer.compose(inner) + + Then, for all values of spec: + + map(composed, spec) == map(outer, map(inner, spec)) + + """ + mapped_args = tuple( + self.get(arg, arg) if _is_type_var(arg) else arg for arg in other.args + ) + return _TypeVarBindings(other.parameters, mapped_args) @dataclasses.dataclass @@ -739,13 +748,23 @@ class _SchemaContext: globalns: Optional[Dict[str, Any]] = None localns: Optional[Dict[str, Any]] = None base_schema: Type[marshmallow.Schema] = marshmallow.Schema - generic_args: Optional[_GenericArgs] = None + + typevar_bindings: _TypeVarBindings = dataclasses.field( + init=False, default_factory=_TypeVarBindings + ) + seen_classes: Dict[Hashable, _Future[Type[marshmallow.Schema]]] = dataclasses.field( - default_factory=dict + init=False, default_factory=dict ) - def replace(self, generic_args: Optional[_GenericArgs]) -> "_SchemaContext": - return dataclasses.replace(self, generic_args=generic_args) + @contextmanager + def bind_type_vars(self, bindings: _TypeVarBindings) -> Iterator[None]: + outer_bindings = self.typevar_bindings + try: + self.typevar_bindings = outer_bindings.compose(bindings) + yield + finally: + self.typevar_bindings = outer_bindings def get_type_mapping( self, use_mro: bool = False @@ -788,8 +807,8 @@ def class_schema( assert _is_dataclass_type(origin) class_name = origin.__name__ constructor = origin - ctx = self.replace(generic_args=_GenericArgs(clazz, self.generic_args)) - attributes = ctx.schema_attrs_for_dataclass(origin) + with self.bind_type_vars(_TypeVarBindings.from_generic_alias(clazz)): + attributes = self.schema_attrs_for_dataclass(origin) elif _is_dataclass_type(clazz): class_name = clazz.__name__ constructor = clazz @@ -893,6 +912,14 @@ def field_for_schema( """ + if _is_type_var(typ): + type_spec = self.typevar_bindings.get(typ, typ) + if _is_type_var(type_spec): + raise UnboundTypeVarError( + f"can not resolve type variable {type_spec.__name__}" + ) + return self.field_for_schema(type_spec, default, metadata) + metadata = {} if metadata is None else dict(metadata) # If the field was already defined by the user @@ -913,9 +940,6 @@ def field_for_schema( else: metadata.setdefault("required", not typing_inspect.is_optional_type(typ)) - if self.generic_args is not None and isinstance(typ, TypeVar): - typ = self.generic_args.resolve(typ) - if _is_builtin_collection_type(typ): return self.field_for_builtin_collection_type(typ, metadata) diff --git a/tests/test_class_schema.py b/tests/test_class_schema.py index 02a1ba9..3ac8efb 100644 --- a/tests/test_class_schema.py +++ b/tests/test_class_schema.py @@ -19,6 +19,7 @@ add_schema, class_schema, NewType, + UnboundTypeVarError, _is_generic_alias_of_dataclass, ) @@ -549,6 +550,19 @@ class TestClass(Base1[int]): with self.assertRaisesRegex(TypeError, "generic base class"): class_schema(TestClass) + def test_unbound_type_var(self) -> None: + T = typing.TypeVar("T") + + @dataclasses.dataclass + class Base: + answer: T # type: ignore[valid-type] + + with self.assertRaises(UnboundTypeVarError): + class_schema(Base) + + with self.assertRaises(TypeError): + class_schema(Base) + def test_recursive_reference(self): @dataclasses.dataclass class Tree: diff --git a/tests/test_typevar_bindings.py b/tests/test_typevar_bindings.py new file mode 100644 index 0000000..7516bd6 --- /dev/null +++ b/tests/test_typevar_bindings.py @@ -0,0 +1,60 @@ +""" Tests for _TypeVarBindings """ +from dataclasses import dataclass +from typing import Generic +from typing import TypeVar + +import pytest + +from marshmallow_dataclass import _is_generic_alias_of_dataclass +from marshmallow_dataclass import _TypeVarBindings + + +T = TypeVar("T") +U = TypeVar("U") +V = TypeVar("V") +W = TypeVar("W") + + +def test_default_init() -> None: + bindings = _TypeVarBindings() + assert len(bindings) == 0 + assert list(bindings) == [] + + +def test_init_raises_on_mismatched_args(): + with pytest.raises(ValueError): + _TypeVarBindings((T, U), (int, str, bool)) + + +def test_from_generic_alias() -> None: + @dataclass + class Gen(Generic[T, U]): + a: T + b: U + + generic_alias = Gen[str, int] + assert _is_generic_alias_of_dataclass(generic_alias) + bindings = _TypeVarBindings.from_generic_alias(generic_alias) + assert dict(bindings) == {T: str, U: int} + + +def test_getitem(): + bindings = _TypeVarBindings((T, U), (int, str)) + assert bindings[U] is str + + with pytest.raises(KeyError): + bindings[V] + with pytest.raises(KeyError): + bindings[str] + with pytest.raises(KeyError): + bindings[0] + + +def test_compose(): + b1 = _TypeVarBindings((T, U), (int, V)) + b2 = _TypeVarBindings((V, W), (U, T)) + + assert dict(b1.compose(b2)) == {V: V, W: int} + assert dict(b2.compose(b1)) == {T: int, U: U} + assert dict(b1.compose(b1)) == {T: int, U: V} + assert dict(b2.compose(b2)) == {V: U, W: T} From 948c145e1e6eb8409e1a507beca61ed438e7329e Mon Sep 17 00:00:00 2001 From: Jeff Dairiki Date: Fri, 20 Jan 2023 23:30:49 -0800 Subject: [PATCH 31/32] refactor: disuse homegrown _LRUDict in favor of cachetools --- .pre-commit-config.yaml | 1 + marshmallow_dataclass/__init__.py | 117 +++++++++++------------------- setup.py | 2 + tests/test_lrudict.py | 35 --------- 4 files changed, 47 insertions(+), 108 deletions(-) delete mode 100644 tests/test_lrudict.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 06433cf..65f7a8d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -24,6 +24,7 @@ repos: - marshmallow-enum - pytest - typeguard + - types-cachetools - types-setuptools - typing-inspect args: [--show-error-codes] diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index 260c820..17b94a6 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -38,6 +38,7 @@ class User: import dataclasses import inspect import sys +import threading import types import warnings from contextlib import contextmanager @@ -47,6 +48,7 @@ class User: Any, Callable, ChainMap, + ClassVar, Dict, Generic, Hashable, @@ -54,6 +56,7 @@ class User: Iterator, List, Mapping, + MutableMapping, NewType as typing_NewType, Optional, Set, @@ -67,6 +70,7 @@ class User: FrozenSet, ) +import cachetools import marshmallow import typing_inspect @@ -106,11 +110,6 @@ def get_origin(tp): return None -if sys.version_info >= (3, 7): - from typing import OrderedDict -else: - from typing_extensions import OrderedDict - if sys.version_info >= (3, 8): from typing import Protocol from typing import final @@ -599,51 +598,7 @@ def class_schema( base_schema = marshmallow.Schema schema_ctx = _SchemaContext(globalns, localns, base_schema) - schema = schema_ctx.class_schema(clazz) - assert not isinstance(schema, _Future) - return schema - - -class _LRUDict(OrderedDict[_U, _V]): - """Limited-length dict which discards LRU entries.""" - - def __init__(self, maxsize: int = 128): - self.maxsize = maxsize - super().__init__() - - def __setitem__(self, key: _U, value: _V) -> None: - super().__setitem__(key, value) - super().move_to_end(key) - - while len(self) > self.maxsize: - oldkey = next(iter(self)) - super().__delitem__(oldkey) - - def __getitem__(self, key: _U) -> _V: - val = super().__getitem__(key) - super().move_to_end(key) - return val - - _T = TypeVar("_T") - - @overload - def get(self, key: _U) -> Optional[_V]: - ... - - @overload - def get(self, key: _U, default: _T) -> Union[_V, _T]: - ... - - def get(self, key: _U, default: Any = None) -> Any: - try: - return self.__getitem__(key) - except KeyError: - return default - - -_schema_cache = _LRUDict[Hashable, Type[marshmallow.Schema]]( - MAX_CLASS_SCHEMA_CACHE_SIZE -) + return schema_ctx.class_schema(clazz).result() class InvalidStateError(Exception): @@ -743,7 +698,13 @@ def map(bindings, spec): @dataclasses.dataclass class _SchemaContext: - """Global context for an invocation of class_schema.""" + """Global context for an invocation of class_schema. + + The _SchemaContext is not thread-safe — methods on a given _SchemaContext + instance should only be invoked from a single thread. (Other threads + can safely work with their own _SchemaContext instances.) + + """ globalns: Optional[Dict[str, Any]] = None localns: Optional[Dict[str, Any]] = None @@ -753,10 +714,6 @@ class _SchemaContext: init=False, default_factory=_TypeVarBindings ) - seen_classes: Dict[Hashable, _Future[Type[marshmallow.Schema]]] = dataclasses.field( - init=False, default_factory=dict - ) - @contextmanager def bind_type_vars(self, bindings: _TypeVarBindings) -> Iterator[None]: outer_bindings = self.typevar_bindings @@ -781,20 +738,36 @@ def get_type_mapping( ) return getattr(base_schema, "TYPE_MAPPING", {}) - def class_schema( - self, clazz: Hashable - ) -> Union[Type[marshmallow.Schema], _Future[Type[marshmallow.Schema]]]: - if clazz in self.seen_classes: - return self.seen_classes[clazz] + # We use two caches: + # + # 1. A global LRU cache. This cache is solely for the sake of efficiency + # + # 2. A context-local cache. Note that a new context is created for each + # call to the public marshmallow_dataclass.class_schema function. + # This context-local cache exists in order to avoid infinite + # recursion when working on a cyclic dataclass. + # + _global_cache: ClassVar[MutableMapping[Hashable, Any]] + _global_cache = cachetools.LRUCache(MAX_CLASS_SCHEMA_CACHE_SIZE) + + def _global_cache_key(self, clazz: Hashable) -> Hashable: + return clazz, self.base_schema + + _local_cache: MutableMapping[Hashable, Any] = dataclasses.field( + init=False, default_factory=dict + ) - cache_key = clazz, self.base_schema - try: - return _schema_cache[cache_key] - except KeyError: - pass + def _get_local_cache(self) -> MutableMapping[Hashable, Any]: + return self._local_cache - future: _Future[Type[marshmallow.Schema]] = _Future() - self.seen_classes[clazz] = future + @cachetools.cached( + cache=_global_cache, key=_global_cache_key, lock=threading.Lock() + ) + @cachetools.cachedmethod(cache=_get_local_cache) + def class_schema(self, clazz: Hashable) -> _Future[Type[marshmallow.Schema]]: + # insert future result into cache to prevent recursion + future: _Future[Type[marshmallow.Schema]] + future = self._local_cache.setdefault((clazz,), _Future()) constructor: Callable[..., object] @@ -842,8 +815,7 @@ def load( ) future.set_result(schema_class) - _schema_cache[cache_key] = schema_class - return schema_class + return future def schema_attrs_for_dataclass(self, clazz: _DataclassType) -> Dict[str, Any]: if _has_generic_base(clazz): @@ -1141,10 +1113,9 @@ def schema_for_nested( # Defer evaluation of .Schema attribute, to avoid forward reference issues return partial(getattr, typ, "Schema") - class_schema = self.class_schema(typ) - if isinstance(class_schema, _Future): - return class_schema.result - return class_schema + future = self.class_schema(typ) + deferred = future.result + return deferred() if future.done() else deferred def _merge_metadata(*args: Mapping[str, Any]) -> Dict[str, Any]: diff --git a/setup.py b/setup.py index 6c00a6b..5f7a1bd 100644 --- a/setup.py +++ b/setup.py @@ -28,6 +28,7 @@ "docs": ["sphinx"], "tests": [ "pytest>=5.4", + "types-cachetools", # re: pypy: typed-ast (a dependency of mypy) fails to install on pypy # https://github.com/python/typed_ast/issues/111 "pytest-mypy-plugins>=1.2.0; implementation_name != 'pypy'", @@ -58,6 +59,7 @@ license="MIT", python_requires=">=3.6", install_requires=[ + "cachetools>=4.2.4,<6.0", "marshmallow>=3.13.0,<4.0", "typing-inspect>=0.8.0", "typing-extensions>=3.10; python_version < '3.8'", diff --git a/tests/test_lrudict.py b/tests/test_lrudict.py deleted file mode 100644 index e323fe5..0000000 --- a/tests/test_lrudict.py +++ /dev/null @@ -1,35 +0,0 @@ -from marshmallow_dataclass import _LRUDict - - -def test_LRUDict_getitem_moves_to_end() -> None: - d = _LRUDict[str, str]() - d["a"] = "aval" - d["b"] = "bval" - assert list(d.items()) == [("a", "aval"), ("b", "bval")] - assert d["a"] == "aval" - assert list(d.items()) == [("b", "bval"), ("a", "aval")] - - -def test_LRUDict_get_moves_to_end() -> None: - d = _LRUDict[str, str]() - d["a"] = "aval" - d["b"] = "bval" - assert list(d.items()) == [("a", "aval"), ("b", "bval")] - assert d.get("a") == "aval" - assert list(d.items()) == [("b", "bval"), ("a", "aval")] - - -def test_LRUDict_setitem_moves_to_end() -> None: - d = _LRUDict[str, str]() - d["a"] = "aval" - d["b"] = "bval" - assert list(d.items()) == [("a", "aval"), ("b", "bval")] - d["a"] = "newval" - assert list(d.items()) == [("b", "bval"), ("a", "newval")] - - -def test_LRUDict_discards_oldest() -> None: - d = _LRUDict[str, str](maxsize=1) - d["a"] = "aval" - d["b"] = "bval" - assert list(d.items()) == [("b", "bval")] From bf5975c939aa2ad73fdfa9451fcc4cde5e2d5118 Mon Sep 17 00:00:00 2001 From: Jeff Dairiki Date: Sat, 21 Jan 2023 09:24:45 -0800 Subject: [PATCH 32/32] test(coverage): adjust exclusions from coverage reports --- pyproject.toml | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 8503045..9364c51 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,6 +7,17 @@ filterwarnings = [ "error:::marshmallow_dataclass|test", ] +[tool.coverage.report] +omit = [ + # pytest-mypy-plugins run mypy plugin tests get run in a subprocess, + # so we don't get coverage data + "marshmallow_dataclass/mypy.py", +] +exclude_lines = [ + "pragma: no cover", + '^\s*\.\.\.\s*$', +] + [tool.mypy] packages = [ "marshmallow_dataclass",