diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index aa47ab0..923cb4d 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -356,7 +356,9 @@ def class_schema( >>> class_schema(Custom)().load({}) Custom(name=None) """ - if not dataclasses.is_dataclass(clazz): + if not dataclasses.is_dataclass(clazz) and not _is_generic_alias_of_dataclass( + clazz + ): clazz = dataclasses.dataclass(clazz) if not clazz_frame: current_frame = inspect.currentframe() @@ -366,21 +368,30 @@ def class_schema( del current_frame _RECURSION_GUARD.seen_classes = {} try: - return _internal_class_schema(clazz, base_schema, clazz_frame) + return _internal_class_schema(clazz, base_schema, clazz_frame, None) finally: _RECURSION_GUARD.seen_classes.clear() +def _dataclass_fields(clazz: type) -> Tuple[dataclasses.Field, ...]: + if _is_generic_alias_of_dataclass(clazz): + clazz = typing_inspect.get_origin(clazz) + return dataclasses.fields(clazz) + + @lru_cache(maxsize=MAX_CLASS_SCHEMA_CACHE_SIZE) def _internal_class_schema( clazz: type, base_schema: Optional[Type[marshmallow.Schema]] = None, clazz_frame: Optional[types.FrameType] = None, + generic_params_to_args: Optional[Tuple[Tuple[type, type], ...]] = None, ) -> Type[marshmallow.Schema]: - _RECURSION_GUARD.seen_classes[clazz] = clazz.__name__ + # generic aliases do not have a __name__ prior to Python 3.10 + _name = getattr(clazz, "__name__", repr(clazz)) + + _RECURSION_GUARD.seen_classes[clazz] = _name try: - # noinspection PyDataclass - fields: Tuple[dataclasses.Field, ...] 
= dataclasses.fields(clazz) + fields = _dataclass_fields(clazz) except TypeError: # Not a dataclass try: warnings.warn( @@ -395,7 +406,9 @@ def _internal_class_schema( "****** WARNING ******" ) created_dataclass: type = dataclasses.dataclass(clazz) - return _internal_class_schema(created_dataclass, base_schema, clazz_frame) + return _internal_class_schema( + created_dataclass, base_schema, clazz_frame, generic_params_to_args + ) except Exception as exc: raise TypeError( f"{getattr(clazz, '__name__', repr(clazz))} is not a dataclass and cannot be turned into one." @@ -408,10 +421,11 @@ def _internal_class_schema( if hasattr(v, "__marshmallow_hook__") or k in MEMBERS_WHITELIST } + if _is_generic_alias_of_dataclass(clazz) and generic_params_to_args is None: + generic_params_to_args = _generic_params_to_args(clazz) + + type_hints = _dataclass_type_hints(clazz, clazz_frame, generic_params_to_args) # Update the schema members to contain marshmallow fields instead of dataclass fields - type_hints = get_type_hints( - clazz, localns=clazz_frame.f_locals if clazz_frame else None - ) attributes.update( ( field.name, @@ -421,13 +435,14 @@ def _internal_class_schema( field.metadata, base_schema, clazz_frame, + generic_params_to_args, ), ) for field in fields if field.init ) - schema_class = type(clazz.__name__, (_base_schema(clazz, base_schema),), attributes) + schema_class = type(_name, (_base_schema(clazz, base_schema),), attributes) return cast(Type[marshmallow.Schema], schema_class) @@ -446,6 +461,7 @@ def _field_by_supertype( metadata: dict, base_schema: Optional[Type[marshmallow.Schema]], typ_frame: Optional[types.FrameType], + generic_params_to_args: Optional[Tuple[Tuple[type, type], ...]] = None, ) -> marshmallow.fields.Field: """ Return a new field for fields based on a super field. 
(Usually spawned from NewType) @@ -477,6 +493,7 @@ def _field_by_supertype( default=default, base_schema=base_schema, typ_frame=typ_frame, + generic_params_to_args=generic_params_to_args, ) @@ -501,6 +518,7 @@ def _field_for_generic_type( typ: type, base_schema: Optional[Type[marshmallow.Schema]], typ_frame: Optional[types.FrameType], + generic_params_to_args: Optional[Tuple[Tuple[type, type], ...]] = None, **metadata: Any, ) -> Optional[marshmallow.fields.Field]: """ @@ -514,7 +532,10 @@ def _field_for_generic_type( if origin in (list, List): child_type = field_for_schema( - arguments[0], base_schema=base_schema, typ_frame=typ_frame + arguments[0], + base_schema=base_schema, + typ_frame=typ_frame, + generic_params_to_args=generic_params_to_args, ) list_type = cast( Type[marshmallow.fields.List], @@ -529,14 +550,20 @@ def _field_for_generic_type( from . import collection_field child_type = field_for_schema( - arguments[0], base_schema=base_schema, typ_frame=typ_frame + arguments[0], + base_schema=base_schema, + typ_frame=typ_frame, + generic_params_to_args=generic_params_to_args, ) return collection_field.Sequence(cls_or_instance=child_type, **metadata) if origin in (set, Set): from . import collection_field child_type = field_for_schema( - arguments[0], base_schema=base_schema, typ_frame=typ_frame + arguments[0], + base_schema=base_schema, + typ_frame=typ_frame, + generic_params_to_args=generic_params_to_args, ) return collection_field.Set( cls_or_instance=child_type, frozen=False, **metadata @@ -545,14 +572,22 @@ def _field_for_generic_type( from . 
import collection_field child_type = field_for_schema( - arguments[0], base_schema=base_schema, typ_frame=typ_frame + arguments[0], + base_schema=base_schema, + typ_frame=typ_frame, + generic_params_to_args=generic_params_to_args, ) return collection_field.Set( cls_or_instance=child_type, frozen=True, **metadata ) if origin in (tuple, Tuple): children = tuple( - field_for_schema(arg, base_schema=base_schema, typ_frame=typ_frame) + field_for_schema( + arg, + base_schema=base_schema, + typ_frame=typ_frame, + generic_params_to_args=generic_params_to_args, + ) for arg in arguments ) tuple_type = cast( @@ -562,14 +597,20 @@ def _field_for_generic_type( ), ) return tuple_type(children, **metadata) - elif origin in (dict, Dict, collections.abc.Mapping, Mapping): + if origin in (dict, Dict, collections.abc.Mapping, Mapping): dict_type = type_mapping.get(Dict, marshmallow.fields.Dict) return dict_type( keys=field_for_schema( - arguments[0], base_schema=base_schema, typ_frame=typ_frame + arguments[0], + base_schema=base_schema, + typ_frame=typ_frame, + generic_params_to_args=generic_params_to_args, ), values=field_for_schema( - arguments[1], base_schema=base_schema, typ_frame=typ_frame + arguments[1], + base_schema=base_schema, + typ_frame=typ_frame, + generic_params_to_args=generic_params_to_args, ), **metadata, ) @@ -587,6 +628,7 @@ def _field_for_generic_type( metadata=metadata, base_schema=base_schema, typ_frame=typ_frame, + generic_params_to_args=generic_params_to_args, ) from . 
import union_field @@ -599,6 +641,7 @@ def _field_for_generic_type( metadata={"required": True}, base_schema=base_schema, typ_frame=typ_frame, + generic_params_to_args=generic_params_to_args, ), ) for subtyp in subtypes @@ -614,6 +657,7 @@ def field_for_schema( metadata: Optional[Mapping[str, Any]] = None, base_schema: Optional[Type[marshmallow.Schema]] = None, typ_frame: Optional[types.FrameType] = None, + generic_params_to_args: Optional[Tuple[Tuple[type, type], ...]] = None, ) -> marshmallow.fields.Field: """ Get a marshmallow Field corresponding to the given python type. @@ -706,7 +750,9 @@ def field_for_schema( ) else: subtyp = Any - return field_for_schema(subtyp, default, metadata, base_schema, typ_frame) + return field_for_schema( + subtyp, default, metadata, base_schema, typ_frame, generic_params_to_args + ) # Generic types generic_field = _field_for_generic_type(typ, base_schema, typ_frame, **metadata) @@ -724,6 +770,7 @@ def field_for_schema( metadata=metadata, base_schema=base_schema, typ_frame=typ_frame, + generic_params_to_args=generic_params_to_args, ) # enumerations @@ -746,8 +793,7 @@ def field_for_schema( nested = ( nested_schema or forward_reference - or _RECURSION_GUARD.seen_classes.get(typ) - or _internal_class_schema(typ, base_schema, typ_frame) # type: ignore [arg-type] + or _internal_class_schema(typ, base_schema, typ_frame, generic_params_to_args) # type: ignore[arg-type] ) return marshmallow.fields.Nested(nested, **metadata) @@ -791,6 +837,50 @@ def _get_field_default(field: dataclasses.Field): return field.default +def _is_generic_alias_of_dataclass(clazz: type) -> bool: + """ + Check if the given class is a generic alias of a dataclass. If the dataclass is + defined as `class A(Generic[T])`, this function returns True when `A[int]` is passed. + """ + is_generic = typing_inspect.is_generic_type(clazz) + type_arguments = typing_inspect.get_args(clazz) + origin_class = typing_inspect.get_origin(clazz) + return ( + is_generic + and 
len(type_arguments) > 0 + and dataclasses.is_dataclass(origin_class) + ) + + +def _generic_params_to_args(clazz: type) -> Tuple[Tuple[type, type], ...]: + base_dataclass = typing_inspect.get_origin(clazz) + base_parameters = typing_inspect.get_parameters(base_dataclass) + type_arguments = typing_inspect.get_args(clazz) + return tuple(zip(base_parameters, type_arguments)) + + +def _dataclass_type_hints( + clazz: type, + clazz_frame: Optional[types.FrameType] = None, + generic_params_to_args: Optional[Tuple[Tuple[type, type], ...]] = None, +) -> Mapping[str, type]: + localns = clazz_frame.f_locals if clazz_frame else None + if not _is_generic_alias_of_dataclass(clazz): + return get_type_hints(clazz, localns=localns) + # dataclass is generic + generic_type_hints = get_type_hints(typing_inspect.get_origin(clazz), localns) + generic_params_map = dict(generic_params_to_args if generic_params_to_args else {}) + + def _get_hint(_t: type) -> type: + if isinstance(_t, TypeVar): + return generic_params_map[_t] + return _t + + return { + field_name: _get_hint(typ) for field_name, typ in generic_type_hints.items() + } + + def NewType( name: str, typ: Type[_U], diff --git a/tests/test_class_schema.py b/tests/test_class_schema.py index aa82975..10487e1 100644 --- a/tests/test_class_schema.py +++ b/tests/test_class_schema.py @@ -14,7 +14,7 @@ from marshmallow.fields import Field, UUID as UUIDField, List as ListField, Integer from marshmallow.validate import Validator -from marshmallow_dataclass import class_schema, NewType +from marshmallow_dataclass import class_schema, NewType, _is_generic_alias_of_dataclass class TestClassSchema(unittest.TestCase): @@ -401,6 +401,79 @@ class J: [validator_a, validator_b, validator_c, validator_d], ) + def test_generic_dataclass(self): + T = typing.TypeVar("T") + + @dataclasses.dataclass + class SimpleGeneric(typing.Generic[T]): + data: T + + @dataclasses.dataclass + class NestedFixed: + data: SimpleGeneric[int] + + @dataclasses.dataclass + 
class NestedGeneric(typing.Generic[T]): + data: SimpleGeneric[T] + + self.assertTrue(_is_generic_alias_of_dataclass(SimpleGeneric[int])) + self.assertFalse(_is_generic_alias_of_dataclass(SimpleGeneric)) + + schema_s = class_schema(SimpleGeneric[str])() + self.assertEqual(SimpleGeneric(data="a"), schema_s.load({"data": "a"})) + self.assertEqual(schema_s.dump(SimpleGeneric(data="a")), {"data": "a"}) + with self.assertRaises(ValidationError): + schema_s.load({"data": 2}) + + schema_nested = class_schema(NestedFixed)() + self.assertEqual( + NestedFixed(data=SimpleGeneric(1)), + schema_nested.load({"data": {"data": 1}}), + ) + self.assertEqual( + schema_nested.dump(NestedFixed(data=SimpleGeneric(data=1))), + {"data": {"data": 1}}, + ) + with self.assertRaises(ValidationError): + schema_nested.load({"data": {"data": "str"}}) + + schema_nested_generic = class_schema(NestedGeneric[int])() + self.assertEqual( + NestedGeneric(data=SimpleGeneric(1)), + schema_nested_generic.load({"data": {"data": 1}}), + ) + self.assertEqual( + schema_nested_generic.dump(NestedGeneric(data=SimpleGeneric(data=1))), + {"data": {"data": 1}}, + ) + with self.assertRaises(ValidationError): + schema_nested_generic.load({"data": {"data": "str"}}) + + def test_generic_dataclass_repeated_fields(self): + T = typing.TypeVar("T") + + @dataclasses.dataclass + class AA: + a: int + + @dataclasses.dataclass + class BB(typing.Generic[T]): + b: T + + @dataclasses.dataclass + class Nested: + x: BB[float] + z: BB[float] + # if y is the first field in this class, deserialisation will fail. + # see https://github.com/lovasoa/marshmallow_dataclass/pull/172#issuecomment-1334024027 + y: BB[AA] + + schema_nested = class_schema(Nested)() + self.assertEqual( + Nested(x=BB(b=1), z=BB(b=1), y=BB(b=AA(1))), + schema_nested.load({"x": {"b": 1}, "z": {"b": 1}, "y": {"b": {"a": 1}}}), + ) + def test_recursive_reference(self): @dataclasses.dataclass class Tree: