Skip to content

Add a test for missing generics in stubs #2659

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ jobs:

# Must match `shard` definition in the test matrix:
- name: Run pytest tests
run: PYTHONPATH='.' pytest --num-shards=4 --shard-id=${{ matrix.shard }} -n auto tests
run: PYTHONPATH='.' pytest --num-shards=4 --shard-id=${{ matrix.shard }} -n auto tests --durations=0
- name: Run mypy on the test cases
run: mypy --strict tests/assert_type

Expand Down
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,11 @@ This happens because these Django classes do not support [`__class_getitem__`](h

You can add extra types to patch with `django_stubs_ext.monkeypatch(extra_classes=[YourDesiredType])`

**If you use generic symbols in `django.contrib.auth.forms`**, you will have to do the monkeypatching
again in your first [`AppConfig.ready`](https://docs.djangoproject.com/en/5.2/ref/applications/#django.apps.AppConfig.ready).
This is currently required because `django.contrib.auth.forms` cannot be imported until django is initialized.


2. You can use strings instead: `'QuerySet[MyModel]'` and `'Manager[MyModel]'`, this way it will work as a type for `mypy` and as a regular `str` in runtime.

### How can I create a HttpRequest that's guaranteed to have an authenticated user?
Expand Down
47 changes: 41 additions & 6 deletions ext/django_stubs_ext/patch.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import builtins
import logging
from collections.abc import Iterable
from typing import Any, Generic, TypeVar

Expand All @@ -8,24 +9,33 @@
from django.contrib.messages.views import SuccessMessageMixin
from django.contrib.sitemaps import Sitemap
from django.contrib.syndication.views import Feed
from django.core.exceptions import AppRegistryNotReady, ImproperlyConfigured
from django.core.files.utils import FileProxyMixin
from django.core.paginator import Paginator
from django.db.models.expressions import ExpressionWrapper
from django.db.models.fields import Field
from django.db.models.fields.related import ForeignKey
from django.db.models.fields.related_descriptors import ReverseManyToOneDescriptor
from django.db.models.fields.related_descriptors import (
ForwardManyToOneDescriptor,
ReverseManyToOneDescriptor,
ReverseOneToOneDescriptor,
)
from django.db.models.lookups import Lookup
from django.db.models.manager import BaseManager
from django.db.models.query import ModelIterable, QuerySet, RawQuerySet
from django.db.models.options import Options
from django.db.models.query import BaseIterable, ModelIterable, QuerySet, RawQuerySet
from django.forms.formsets import BaseFormSet
from django.forms.models import BaseModelForm, BaseModelFormSet, ModelChoiceField
from django.utils.connection import BaseConnectionHandler
from django.forms.models import BaseModelForm, BaseModelFormSet, ModelChoiceField, ModelFormOptions
from django.utils.connection import BaseConnectionHandler, ConnectionProxy
from django.utils.functional import classproperty
from django.views.generic.detail import SingleObjectMixin
from django.views.generic.edit import DeletionMixin, FormMixin
from django.views.generic.list import MultipleObjectMixin

__all__ = ["monkeypatch"]

logger = logging.getLogger(__name__)

_T = TypeVar("_T")
_VersionSpec = tuple[int, int]

Expand Down Expand Up @@ -81,16 +91,41 @@ def __repr__(self) -> str:
# These types do have native `__class_getitem__` method since django 4.1:
MPGeneric(ForeignKey, (4, 1)),
MPGeneric(RawQuerySet),
MPGeneric(classproperty),
MPGeneric(ConnectionProxy),
MPGeneric(ModelFormOptions),
MPGeneric(Options),
MPGeneric(BaseIterable),
MPGeneric(ForwardManyToOneDescriptor),
MPGeneric(ReverseOneToOneDescriptor),
]


def _get_need_generic() -> list[MPGeneric[Any]]:
try:
if VERSION >= (5, 1):
from django.contrib.auth.forms import SetPasswordMixin, SetUnusablePasswordMixin

return [MPGeneric(SetPasswordMixin), MPGeneric(SetUnusablePasswordMixin), *_need_generic]
else:
from django.contrib.auth.forms import AdminPasswordChangeForm, SetPasswordForm

return [MPGeneric(SetPasswordForm), MPGeneric(AdminPasswordChangeForm), *_need_generic]

except (ImproperlyConfigured, AppRegistryNotReady):
# We cannot patch symbols in `django.contrib.auth.forms` if the `monkeypatch()` call
# is in the settings file because django is not initialized yet.
# To solve this, you'll have to call `monkeypatch()` again later, in an `AppConfig.ready` for ex.
# See https://docs.djangoproject.com/en/5.2/ref/applications/#django.apps.AppConfig.ready
return _need_generic


def monkeypatch(extra_classes: Iterable[type] | None = None, include_builtins: bool = True) -> None:
"""Monkey patch django as necessary to work properly with mypy."""

# Add the __class_getitem__ dunder.
suited_for_this_version = filter(
lambda spec: spec.version is None or VERSION[:2] <= spec.version,
_need_generic,
_get_need_generic(),
)
for el in suited_for_this_version:
el.cls.__class_getitem__ = classmethod(lambda cls, *args, **kwargs: cls)
Expand Down
90 changes: 90 additions & 0 deletions tests/test_generic_consistency.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import ast
import glob
import importlib
import os
from typing import final
from unittest import mock

import django

# The root directory of the django-stubs package
STUBS_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "django-stubs"))


@final
class GenericInheritanceVisitor(ast.NodeVisitor):
"""AST visitor to find classes inheriting from `typing.Generic` in stubs."""

def __init__(self) -> None:
self.generic_classes: set[str] = set()

def visit_ClassDef(self, node: ast.ClassDef) -> None:
for base in node.bases:
if (
isinstance(base, ast.Subscript)
and isinstance(base.value, ast.Name)
and base.value.id == "Generic"
and not any(dec.id == "type_check_only" for dec in node.decorator_list if isinstance(dec, ast.Name))
):
self.generic_classes.add(node.name)
break
self.generic_visit(node)


def test_find_classes_inheriting_from_generic() -> None:
"""
This test ensures that the `ext/django_stubs_ext/patch.py` stays up-to-date with the stubs.
It works as follows:
1. Parse the ast of each .pyi file, and collects classes inheriting from Generic.
2. For each Generic in the stubs, import the associated module and capture every class in the MRO
3. Ensure that at least one class in the mro is patched in `ext/django_stubs_ext/patch.py`.
"""
with mock.patch.dict(os.environ, {"DJANGO_SETTINGS_MODULE": "scripts.django_tests_settings"}):
# We need this to be able to do django import
django.setup()

# A dict of class_name -> [subclasses names] for each Generic in the stubs.
all_generic_classes: dict[str, list[str]] = {}

print(f"Searching for classes inheriting from Generic in: {STUBS_ROOT}")
pyi_files = glob.glob("**/*.pyi", root_dir=STUBS_ROOT, recursive=True)
for file_path in pyi_files:
with open(os.path.join(STUBS_ROOT, file_path)) as f:
source = f.read()

tree = ast.parse(source)
generic_visitor = GenericInheritanceVisitor()
generic_visitor.visit(tree)

# For each Generic in the stubs, import the associated module and capture every class in the MRO
if generic_visitor.generic_classes:
module_name = _get_module_from_pyi(file_path)
django_module = importlib.import_module(module_name)
all_generic_classes.update(
{
cls: [subcls.__name__ for subcls in getattr(django_module, cls).mro()[1:-1]]
for cls in generic_visitor.generic_classes
}
)

print(f"Processed {len(pyi_files)} .pyi files.")
print(f"Found {len(all_generic_classes)} unique classes inheriting from Generic in stubs")

# Class patched in `ext/django_stubs_ext/patch.py`
import django_stubs_ext

patched_classes = {mp_generic.cls.__name__ for mp_generic in django_stubs_ext.patch._get_need_generic()}

# Pretty-print missing patch in `ext/django_stubs_ext/patch.py`
errors = []
for cls_name, subcls_names in all_generic_classes.items():
if not any(name in patched_classes for name in [*subcls_names, cls_name]):
bases = f"({', '.join(subcls_names)})" if subcls_names else ""
errors.append(f"{cls_name}{bases} is not patched in `ext/django_stubs_ext/patch.py`")

assert not errors, "\n".join(errors)


def _get_module_from_pyi(pyi_path: str) -> str:
py_module = "django." + pyi_path.replace(".pyi", "").replace("/", ".")
return py_module.removesuffix(".__init__")
Loading