diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 759c00969..2c5426cea 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -68,7 +68,7 @@ jobs: # Must match `shard` definition in the test matrix: - name: Run pytest tests - run: PYTHONPATH='.' pytest --num-shards=4 --shard-id=${{ matrix.shard }} -n auto tests + run: PYTHONPATH='.' pytest --num-shards=4 --shard-id=${{ matrix.shard }} -n auto tests --durations=0 - name: Run mypy on the test cases run: mypy --strict tests/assert_type diff --git a/README.md b/README.md index dbdab0b42..06b86085f 100644 --- a/README.md +++ b/README.md @@ -164,6 +164,11 @@ This happens because these Django classes do not support [`__class_getitem__`](h You can add extra types to patch with `django_stubs_ext.monkeypatch(extra_classes=[YourDesiredType])` + **If you use generic symbols in `django.contrib.auth.forms`**, you will have to do the monkeypatching + again in your first [`AppConfig.ready`](https://docs.djangoproject.com/en/5.2/ref/applications/#django.apps.AppConfig.ready). + This is currently required because `django.contrib.auth.forms` cannot be imported until django is initialized. + + 2. You can use strings instead: `'QuerySet[MyModel]'` and `'Manager[MyModel]'`, this way it will work as a type for `mypy` and as a regular `str` in runtime. ### How can I create a HttpRequest that's guaranteed to have an authenticated user? diff --git a/ext/django_stubs_ext/patch.py b/ext/django_stubs_ext/patch.py index e8e85ebd7..f5e63e360 100644 --- a/ext/django_stubs_ext/patch.py +++ b/ext/django_stubs_ext/patch.py @@ -1,4 +1,5 @@ import builtins +import logging from collections.abc import Iterable from typing import Any, Generic, TypeVar @@ -8,24 +9,33 @@ from django.contrib.messages.views import SuccessMessageMixin from django.contrib.sitemaps import Sitemap from django.contrib.syndication.views import Feed +from django.core.exceptions import AppRegistryNotReady, ImproperlyConfigured from django.core.files.utils import FileProxyMixin from django.core.paginator import Paginator from django.db.models.expressions import ExpressionWrapper from django.db.models.fields import Field from django.db.models.fields.related import ForeignKey -from django.db.models.fields.related_descriptors import ReverseManyToOneDescriptor +from django.db.models.fields.related_descriptors import ( + ForwardManyToOneDescriptor, + ReverseManyToOneDescriptor, + ReverseOneToOneDescriptor, +) from django.db.models.lookups import Lookup from django.db.models.manager import BaseManager -from django.db.models.query import ModelIterable, QuerySet, RawQuerySet +from django.db.models.options import Options +from django.db.models.query import BaseIterable, ModelIterable, QuerySet, RawQuerySet from django.forms.formsets import BaseFormSet -from django.forms.models import BaseModelForm, BaseModelFormSet, ModelChoiceField -from django.utils.connection import BaseConnectionHandler +from django.forms.models import BaseModelForm, BaseModelFormSet, ModelChoiceField, ModelFormOptions +from django.utils.connection import BaseConnectionHandler, ConnectionProxy +from django.utils.functional import classproperty from django.views.generic.detail import SingleObjectMixin from django.views.generic.edit import DeletionMixin, FormMixin from django.views.generic.list import MultipleObjectMixin __all__ = ["monkeypatch"] +logger = logging.getLogger(__name__) + _T = TypeVar("_T") _VersionSpec = tuple[int, int] @@ -81,16 +91,41 @@ def __repr__(self) -> str: # These types do have native `__class_getitem__` method since django 4.1: MPGeneric(ForeignKey, (4, 1)), MPGeneric(RawQuerySet), + MPGeneric(classproperty), + MPGeneric(ConnectionProxy), + MPGeneric(ModelFormOptions), + MPGeneric(Options), + MPGeneric(BaseIterable), + MPGeneric(ForwardManyToOneDescriptor), + MPGeneric(ReverseOneToOneDescriptor), ] +def _get_need_generic() -> list[MPGeneric[Any]]: + try: + if VERSION >= (5, 1): + from django.contrib.auth.forms import SetPasswordMixin, SetUnusablePasswordMixin + + return [MPGeneric(SetPasswordMixin), MPGeneric(SetUnusablePasswordMixin), *_need_generic] + else: + from django.contrib.auth.forms import AdminPasswordChangeForm, SetPasswordForm + + return [MPGeneric(SetPasswordForm), MPGeneric(AdminPasswordChangeForm), *_need_generic] + + except (ImproperlyConfigured, AppRegistryNotReady): + # We cannot patch symbols in `django.contrib.auth.forms` if the `monkeypatch()` call + # is in the settings file because django is not initialized yet. + # To solve this, you'll have to call `monkeypatch()` again later, in an `AppConfig.ready` for ex. + # See https://docs.djangoproject.com/en/5.2/ref/applications/#django.apps.AppConfig.ready + return _need_generic + + def monkeypatch(extra_classes: Iterable[type] | None = None, include_builtins: bool = True) -> None: """Monkey patch django as necessary to work properly with mypy.""" - # Add the __class_getitem__ dunder. suited_for_this_version = filter( lambda spec: spec.version is None or VERSION[:2] <= spec.version, - _need_generic, + _get_need_generic(), ) for el in suited_for_this_version: el.cls.__class_getitem__ = classmethod(lambda cls, *args, **kwargs: cls) diff --git a/tests/test_generic_consistency.py b/tests/test_generic_consistency.py new file mode 100644 index 000000000..5fa3c5bbf --- /dev/null +++ b/tests/test_generic_consistency.py @@ -0,0 +1,90 @@ +import ast +import glob +import importlib +import os +from typing import final +from unittest import mock + +import django + +# The root directory of the django-stubs package +STUBS_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "django-stubs")) + + +@final +class GenericInheritanceVisitor(ast.NodeVisitor): + """AST visitor to find classes inheriting from `typing.Generic` in stubs.""" + + def __init__(self) -> None: + self.generic_classes: set[str] = set() + + def visit_ClassDef(self, node: ast.ClassDef) -> None: + for base in node.bases: + if ( + isinstance(base, ast.Subscript) + and isinstance(base.value, ast.Name) + and base.value.id == "Generic" + and not any(dec.id == "type_check_only" for dec in node.decorator_list if isinstance(dec, ast.Name)) + ): + self.generic_classes.add(node.name) + break + self.generic_visit(node) + + +def test_find_classes_inheriting_from_generic() -> None: + """ + This test ensures that the `ext/django_stubs_ext/patch.py` stays up-to-date with the stubs. + It works as follows: + 1. Parse the ast of each .pyi file, and collects classes inheriting from Generic. + 2. For each Generic in the stubs, import the associated module and capture every class in the MRO + 3. Ensure that at least one class in the mro is patched in `ext/django_stubs_ext/patch.py`. + """ + with mock.patch.dict(os.environ, {"DJANGO_SETTINGS_MODULE": "scripts.django_tests_settings"}): + # We need this to be able to do django import + django.setup() + + # A dict of class_name -> [subclasses names] for each Generic in the stubs. + all_generic_classes: dict[str, list[str]] = {} + + print(f"Searching for classes inheriting from Generic in: {STUBS_ROOT}") + pyi_files = glob.glob("**/*.pyi", root_dir=STUBS_ROOT, recursive=True) + for file_path in pyi_files: + with open(os.path.join(STUBS_ROOT, file_path)) as f: + source = f.read() + + tree = ast.parse(source) + generic_visitor = GenericInheritanceVisitor() + generic_visitor.visit(tree) + + # For each Generic in the stubs, import the associated module and capture every class in the MRO + if generic_visitor.generic_classes: + module_name = _get_module_from_pyi(file_path) + django_module = importlib.import_module(module_name) + all_generic_classes.update( + { + cls: [subcls.__name__ for subcls in getattr(django_module, cls).mro()[1:-1]] + for cls in generic_visitor.generic_classes + } + ) + + print(f"Processed {len(pyi_files)} .pyi files.") + print(f"Found {len(all_generic_classes)} unique classes inheriting from Generic in stubs") + + # Class patched in `ext/django_stubs_ext/patch.py` + import django_stubs_ext + + patched_classes = {mp_generic.cls.__name__ for mp_generic in django_stubs_ext.patch._get_need_generic()} + + # Pretty-print missing patch in `ext/django_stubs_ext/patch.py` + errors = [] + for cls_name, subcls_names in all_generic_classes.items(): + if not any(name in patched_classes for name in [*subcls_names, cls_name]): + bases = f"({', '.join(subcls_names)})" if subcls_names else "" + errors.append(f"{cls_name}{bases} is not patched in `ext/django_stubs_ext/patch.py`") + + assert not errors, "\n".join(errors) + + +def _get_module_from_pyi(pyi_path: str) -> str: + py_module = "django." + pyi_path.replace(".pyi", "").replace("/", ".") + return py_module.removesuffix(".__init__")