diff --git a/django-stubs/db/models/fields/__init__.pyi b/django-stubs/db/models/fields/__init__.pyi index 8d45f1c35..530fab5b5 100644 --- a/django-stubs/db/models/fields/__init__.pyi +++ b/django-stubs/db/models/fields/__init__.pyi @@ -9,12 +9,13 @@ from django import forms from django.core import validators # due to weird mypy.stubtest error from django.core.checks import CheckMessage from django.db.backends.base.base import BaseDatabaseWrapper -from django.db.models import Model +from django.db.models import Choices, Model from django.db.models.expressions import Col, Combinable, Expression, Func from django.db.models.fields.reverse_related import ForeignObjectRel from django.db.models.query_utils import Q, RegisterLookupMixin from django.forms import Widget -from django.utils.choices import BlankChoiceIterator, _Choice, _ChoiceNamedGroup, _Choices, _ChoicesCallable +from django.utils.choices import BlankChoiceIterator, _Choice, _ChoiceNamedGroup, _ChoicesCallable, _ChoicesMapping +from django.utils.choices import _Choices as _ChoicesSequence from django.utils.datastructures import DictWrapper from django.utils.functional import _Getter, _StrOrPromise, cached_property from typing_extensions import Self, TypeAlias @@ -27,6 +28,9 @@ BLANK_CHOICE_DASH: list[tuple[str, str]] _ChoicesList: TypeAlias = Sequence[_Choice] | Sequence[_ChoiceNamedGroup] _LimitChoicesTo: TypeAlias = Q | dict[str, Any] _LimitChoicesToCallable: TypeAlias = Callable[[], _LimitChoicesTo] +_Choices: TypeAlias = ( + _ChoicesSequence | _ChoicesMapping | type[Choices] | Callable[[], _ChoicesSequence | _ChoicesMapping] +) _F = TypeVar("_F", bound=Field, covariant=True) diff --git a/django-stubs/utils/choices.pyi b/django-stubs/utils/choices.pyi index dbbca5490..ab4b44d70 100644 --- a/django-stubs/utils/choices.pyi +++ b/django-stubs/utils/choices.pyi @@ -1,4 +1,4 @@ -from collections.abc import Iterable, Iterator +from collections.abc import Iterable, Iterator, Mapping from typing import Any, Protocol, TypeVar, type_check_only from typing_extensions import TypeAlias @@ -6,6 +6,7 @@ from typing_extensions import TypeAlias _Choice: TypeAlias = tuple[Any, Any] _ChoiceNamedGroup: TypeAlias = tuple[str, Iterable[_Choice]] _Choices: TypeAlias = Iterable[_Choice | _ChoiceNamedGroup] +_ChoicesMapping: TypeAlias = Mapping[Any, Any] # noqa: PYI047 @type_check_only class _ChoicesCallable(Protocol): diff --git a/tests/assert_type/db/models/fields/test_choices.py b/tests/assert_type/db/models/fields/test_choices.py new file mode 100644 index 000000000..1f9f773c8 --- /dev/null +++ b/tests/assert_type/db/models/fields/test_choices.py @@ -0,0 +1,89 @@ +from collections.abc import Callable, Mapping, Sequence +from typing import TypeVar + +from django.db import models +from typing_extensions import assert_type + +_T = TypeVar("_T") + + +def to_named_seq(func: Callable[[], _T]) -> Callable[[], Sequence[tuple[str, _T]]]: + def inner() -> Sequence[tuple[str, _T]]: + return [("title", func())] + + return inner + + +def to_named_mapping(func: Callable[[], _T]) -> Callable[[], Mapping[str, _T]]: + def inner() -> Mapping[str, _T]: + return {"title": func()} + + return inner + + +def str_tuple() -> Sequence[tuple[str, str]]: + return (("foo", "bar"), ("fuzz", "bazz")) + + +def str_mapping() -> Mapping[str, str]: + return {"foo": "bar", "fuzz": "bazz"} + + +def int_tuple() -> Sequence[tuple[int, str]]: + return ((1, "bar"), (2, "bazz")) + + +def int_mapping() -> Mapping[int, str]: + return {3: "bar", 4: "bazz"} + + +class TestModel(models.Model): + class TextChoices(models.TextChoices): + FIRST = "foo", "bar" + SECOND = "foo2", "bar" + + class IntegerChoices(models.IntegerChoices): + FIRST = 1, "bar" + SECOND = 2, "bar" + + char1 = models.CharField[str, str](max_length=5, choices=TextChoices, default="foo") + char2 = models.CharField[str, str](max_length=5, choices=str_tuple, default="foo") + char3 = models.CharField[str, str](max_length=5, choices=str_mapping, default="foo") + char4 = models.CharField[str, str](max_length=5, choices=str_tuple(), default="foo") + char5 = models.CharField[str, str](max_length=5, choices=str_mapping(), default="foo") + char6 = models.CharField[str, str](max_length=5, choices=to_named_seq(str_tuple), default="foo") + char7 = models.CharField[str, str](max_length=5, choices=to_named_mapping(str_mapping), default="foo") + char8 = models.CharField[str, str](max_length=5, choices=to_named_seq(str_tuple)(), default="foo") + char9 = models.CharField[str, str](max_length=5, choices=to_named_mapping(str_mapping)(), default="foo") + + int1 = models.IntegerField[int, int](choices=IntegerChoices, default=1) + int2 = models.IntegerField[int, int](choices=int_tuple, default=1) + int3 = models.IntegerField[int, int](choices=int_mapping, default=1) + int4 = models.IntegerField[int, int](choices=int_tuple(), default=1) + int5 = models.IntegerField[int, int](choices=int_mapping(), default=1) + int6 = models.IntegerField[int, int](choices=to_named_seq(int_tuple), default=1) + int7 = models.IntegerField[int, int](choices=to_named_seq(int_mapping), default=1) + int8 = models.IntegerField[int, int](choices=to_named_seq(int_tuple)(), default=1) + int9 = models.IntegerField[int, int](choices=to_named_seq(int_mapping)(), default=1) + + +instance = TestModel() +assert_type(instance.char1, str) +assert_type(instance.char2, str) +assert_type(instance.char3, str) +assert_type(instance.char4, str) +assert_type(instance.char5, str) +assert_type(instance.char6, str) +assert_type(instance.char7, str) +assert_type(instance.char8, str) +assert_type(instance.char9, str) + +assert_type(instance.int1, int) +assert_type(instance.int2, int) +assert_type(instance.int3, int) +assert_type(instance.int4, int) +assert_type(instance.int5, int) +assert_type(instance.int6, int) +assert_type(instance.int7, int) +assert_type(instance.int8, int) +assert_type(instance.int9, int) diff --git a/tests/typecheck/db/models/test_fields.yml b/tests/typecheck/db/models/test_fields.yml new file mode 100644 index 000000000..3a8e32350 --- /dev/null +++ b/tests/typecheck/db/models/test_fields.yml @@ -0,0 +1,13 @@ +- case: db_models_fields_choices + main: | + from django.db import models + + class MyModel(models.Model): + char1 = models.CharField[str, str](max_length=200, choices='test') + out: | + main:4: error: Argument "choices" to "CharField" has incompatible type "str"; expected "Union[Iterable[Union[Tuple[Any, Any], Tuple[str, Iterable[Tuple[Any, Any]]]]], Mapping[Any, Any], Type[Choices], Callable[[], Union[Iterable[Union[Tuple[Any, Any], Tuple[str, Iterable[Tuple[Any, Any]]]]], Mapping[Any, Any]]], None]" [arg-type] + main:4: note: Following member(s) of "str" have conflicts: + main:4: note: Expected: + main:4: note: def __iter__(self) -> Iterator[Union[Tuple[Any, Any], Tuple[str, Iterable[Tuple[Any, Any]]]]] + main:4: note: Got: + main:4: note: def __iter__(self) -> Iterator[str]