From c8df67600b5c03c61a385a915461b8dec930ac36 Mon Sep 17 00:00:00 2001 From: Sylwester Kardziejonek Date: Sun, 24 Nov 2024 15:59:25 +0100 Subject: [PATCH 1/3] Allow field overrides via `Annotated` --- src/cattrs/gen/__init__.py | 8 +++- src/cattrs/gen/_shared.py | 29 ++++++++++++- tests/test_annotated_overrides.py | 72 +++++++++++++++++++++++++++++++ 3 files changed, 106 insertions(+), 3 deletions(-) create mode 100644 tests/test_annotated_overrides.py diff --git a/src/cattrs/gen/__init__.py b/src/cattrs/gen/__init__.py index 7a562c47..a81e2561 100644 --- a/src/cattrs/gen/__init__.py +++ b/src/cattrs/gen/__init__.py @@ -33,7 +33,7 @@ from ._consts import AttributeOverride, already_generating, neutral from ._generics import generate_mapping from ._lc import generate_unique_filename -from ._shared import find_structure_handler +from ._shared import find_structure_handler, get_fields_annotated_by if TYPE_CHECKING: from ..converters import BaseConverter @@ -260,6 +260,10 @@ def make_dict_unstructure_fn( working_set.add(cl) + # Merge overrides provided via Annotated with kwargs + annotated_overrides = get_fields_annotated_by(cl, AttributeOverride) + annotated_overrides.update(kwargs) + try: return make_dict_unstructure_fn_from_attrs( attrs, @@ -270,7 +274,7 @@ def make_dict_unstructure_fn( _cattrs_use_linecache=_cattrs_use_linecache, _cattrs_use_alias=_cattrs_use_alias, _cattrs_include_init_false=_cattrs_include_init_false, - **kwargs, + **annotated_overrides, ) finally: working_set.remove(cl) diff --git a/src/cattrs/gen/_shared.py b/src/cattrs/gen/_shared.py index 904c7744..fca7c87a 100644 --- a/src/cattrs/gen/_shared.py +++ b/src/cattrs/gen/_shared.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, TypeVar, get_type_hints from attrs import NOTHING, Attribute, Factory @@ -10,8 +10,10 @@ from ..fns import raise_error if TYPE_CHECKING: + from collections.abc import Mapping from ..converters import BaseConverter +T = TypeVar("T") def find_structure_handler( a: Attribute, type: Any, c: BaseConverter, prefer_attrs_converters: bool = False @@ -62,3 +64,28 @@ def handler(v, _, _h=handler): except RecursionError: # This means we're dealing with a reference cycle, so use late binding. return c.structure + + +def get_fields_annotated_by(cls: type, annotation_type: type[T] | T) -> dict[str, T]: + type_hints = get_type_hints(cls, include_extras=True) + # Support for both AttributeOverride and AttributeOverride() + annotation_type_ = annotation_type if isinstance(annotation_type, type) else type(annotation_type) + + # First pass of filtering to get only fields with annotations + fields_with_annotations = ( + (field_name, param_spec.__metadata__) + for field_name, param_spec in type_hints.items() + if hasattr(param_spec, "__metadata__") + ) + + # Now that we have fields with ANY annotations, we need to remove unwanted annotations. + fields_with_specific_annotation = ( + ( + field_name, + next((a for a in annotations if isinstance(a, annotation_type_)), None), + ) + for field_name, annotations in fields_with_annotations + ) + + # We still might have some `None` values from previous filtering. + return {field_name: annotation for field_name, annotation in fields_with_specific_annotation if annotation} diff --git a/tests/test_annotated_overrides.py b/tests/test_annotated_overrides.py new file mode 100644 index 00000000..bbe4c195 --- /dev/null +++ b/tests/test_annotated_overrides.py @@ -0,0 +1,72 @@ +from typing import Annotated + +import attrs +import pytest + +from cattrs.gen._shared import get_fields_annotated_by + + +class NotThere: ... + + +class IgnoreMe: + def __init__(self, why: str | None = None): + self.why = why + + +class FindMe: + def __init__(self, taint: str): + self.taint = taint + + +class EmptyClassExample: + pass + + +class PureClassExample: + id: Annotated[int, FindMe("red")] + name: Annotated[str, FindMe] + + +class MultipleAnnotationsExample: + id: Annotated[int, FindMe("red"), IgnoreMe()] + name: Annotated[str, IgnoreMe()] + surface: Annotated[str, IgnoreMe("sorry"), FindMe("shiny")] + + +@attrs.define +class AttrsClassExample: + id: int = attrs.field(default=0) + color: Annotated[str, FindMe("blue")] = attrs.field(default="red") + config: Annotated[dict, FindMe("required")] = attrs.field(factory=dict) + + +class PureClassInheritanceExample(PureClassExample): + include: dict + exclude: Annotated[dict, FindMe("boring things")] + extras: Annotated[dict, FindMe] + + +@pytest.mark.parametrize( + "klass,expected", + [ + (EmptyClassExample, {}), + (PureClassExample, {"id": isinstance}), + (AttrsClassExample, {"color": isinstance, "config": isinstance}), + (MultipleAnnotationsExample, {"id": isinstance, "surface": isinstance}), + (PureClassInheritanceExample, {"id": isinstance, "exclude": isinstance}), + ], +) +@pytest.mark.parametrize("instantiate", [True, False]) +def test_gets_annotated_types(klass, expected, instantiate: bool): + annotated = get_fields_annotated_by(klass, FindMe("irrelevant") if instantiate else FindMe) + + assert set(annotated.keys()) == set(expected.keys()), "Too many or too few annotations" + assert all( + assertion_func(annotated[field_name], FindMe) for field_name, assertion_func in expected.items() + ), "Unexpected type of annotation" + + +def test_empty_result_for_missing_annotation(): + annotated = get_fields_annotated_by(MultipleAnnotationsExample, NotThere) + assert not annotated, "No annotation should be found." From 089a3d14782548674bdc2b430a1d14d8f9fca056 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tin=20Tvrtkovi=C4=87?= Date: Sun, 12 Jan 2025 22:43:48 +0100 Subject: [PATCH 2/3] fix test --- tests/test_annotated_overrides.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/tests/test_annotated_overrides.py b/tests/test_annotated_overrides.py index bbe4c195..7a9846e4 100644 --- a/tests/test_annotated_overrides.py +++ b/tests/test_annotated_overrides.py @@ -1,4 +1,4 @@ -from typing import Annotated +from typing import Annotated, Union import attrs import pytest @@ -10,7 +10,7 @@ class NotThere: ... class IgnoreMe: - def __init__(self, why: str | None = None): + def __init__(self, why: Union[str, None] = None): self.why = why @@ -59,11 +59,16 @@ class PureClassInheritanceExample(PureClassExample): ) @pytest.mark.parametrize("instantiate", [True, False]) def test_gets_annotated_types(klass, expected, instantiate: bool): - annotated = get_fields_annotated_by(klass, FindMe("irrelevant") if instantiate else FindMe) + annotated = get_fields_annotated_by( + klass, FindMe("irrelevant") if instantiate else FindMe + ) - assert set(annotated.keys()) == set(expected.keys()), "Too many or too few annotations" + assert set(annotated.keys()) == set( + expected.keys() + ), "Too many or too few annotations" assert all( - assertion_func(annotated[field_name], FindMe) for field_name, assertion_func in expected.items() + assertion_func(annotated[field_name], FindMe) + for field_name, assertion_func in expected.items() ), "Unexpected type of annotation" From 972e6cbde40e1f3508b57e3a550e2286c2079331 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tin=20Tvrtkovi=C4=87?= Date: Sun, 12 Jan 2025 22:51:48 +0100 Subject: [PATCH 3/3] Fix lint --- src/cattrs/gen/_shared.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/cattrs/gen/_shared.py b/src/cattrs/gen/_shared.py index fca7c87a..809e9b45 100644 --- a/src/cattrs/gen/_shared.py +++ b/src/cattrs/gen/_shared.py @@ -10,11 +10,11 @@ from ..fns import raise_error if TYPE_CHECKING: - from collections.abc import Mapping from ..converters import BaseConverter T = TypeVar("T") + def find_structure_handler( a: Attribute, type: Any, c: BaseConverter, prefer_attrs_converters: bool = False ) -> StructureHook | None: @@ -69,7 +69,9 @@ def handler(v, _, _h=handler): def get_fields_annotated_by(cls: type, annotation_type: type[T] | T) -> dict[str, T]: type_hints = get_type_hints(cls, include_extras=True) # Support for both AttributeOverride and AttributeOverride() - annotation_type_ = annotation_type if isinstance(annotation_type, type) else type(annotation_type) + annotation_type_ = ( + annotation_type if isinstance(annotation_type, type) else type(annotation_type) + ) # First pass of filtering to get only fields with annotations fields_with_annotations = ( @@ -88,4 +90,8 @@ def get_fields_annotated_by(cls: type, annotation_type: type[T] | T) -> dict[str ) # We still might have some `None` values from previous filtering. - return {field_name: annotation for field_name, annotation in fields_with_specific_annotation if annotation} + return { + field_name: annotation + for field_name, annotation in fields_with_specific_annotation + if annotation + }