Skip to content

Commit 175ffd1

Browse files
committed
Allow generic requets, responses, fields, views
Allow Request, Response, Field, and GenericAPIView to be subscriptable. This allows the classes to be made generic for type checking. This is especially useful since monkey patching DRF can be problematic as seen in this [issue][1]. [1]: typeddjango/djangorestframework-stubs#299
1 parent 48a21aa commit 175ffd1

File tree

8 files changed

+73
-0
lines changed

8 files changed

+73
-0
lines changed

rest_framework/fields.py

+4
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,10 @@ def __init__(self, *, read_only=False, write_only=False,
355355
messages.update(getattr(cls, 'default_error_messages', {}))
356356
messages.update(error_messages or {})
357357
self.error_messages = messages
358+
359+
# Allow generic typing checking for fields.
360+
def __class_getitem__(cls, *args, **kwargs):
361+
return cls
358362

359363
def bind(self, field_name, parent):
360364
"""

rest_framework/generics.py

+4
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,10 @@ class GenericAPIView(views.APIView):
4545
# The style to use for queryset pagination.
4646
pagination_class = api_settings.DEFAULT_PAGINATION_CLASS
4747

48+
# Allow generic typing checking for generic views.
49+
def __class_getitem__(cls, *args, **kwargs):
50+
return cls
51+
4852
def get_queryset(self):
4953
"""
5054
Get the list of items for this view.

rest_framework/request.py

+4
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,10 @@ def __repr__(self):
185185
self.__class__.__name__,
186186
self.method,
187187
self.get_full_path())
188+
189+
# Allow generic typing checking for requests.
190+
def __class_getitem__(cls, *args, **kwargs):
191+
return cls
188192

189193
def _default_negotiator(self):
190194
return api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS()

rest_framework/response.py

+4
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,10 @@ def __init__(self, data=None, status=None,
4545
if headers:
4646
for name, value in headers.items():
4747
self[name] = value
48+
49+
# Allow generic typing checking for responses.
50+
def __class_getitem__(cls, *args, **kwargs):
51+
return cls
4852

4953
@property
5054
def rendered_content(self):

tests/test_fields.py

+10
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import math
33
import os
44
import re
5+
import sys
56
import uuid
67
from decimal import ROUND_DOWN, ROUND_UP, Decimal
78

@@ -624,6 +625,15 @@ def test_parent_binding(self):
624625
assert field.root is parent
625626

626627

628+
class TestTyping(TestCase):
629+
@pytest.mark.skipif(
630+
sys.version_info < (3, 7),
631+
reason="subscriptable classes requires Python 3.7 or higher",
632+
)
633+
def test_field_is_subscriptable(self):
634+
assert serializers.Field is serializers.Field["foo"]
635+
636+
627637
# Tests for field input and output values.
628638
# ----------------------------------------
629639

tests/test_generics.py

+25
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import sys
2+
13
import pytest
24
from django.db import models
35
from django.http import Http404
@@ -698,3 +700,26 @@ def list(self, request):
698700
serializer = response.serializer
699701

700702
assert serializer.context is context
703+
704+
705+
class TestTyping(TestCase):
706+
@pytest.mark.skipif(
707+
sys.version_info < (3, 7),
708+
reason="subscriptable classes requires Python 3.7 or higher",
709+
)
710+
def test_genericview_is_subscriptable(self):
711+
assert generics.GenericAPIView is generics.GenericAPIView["foo"]
712+
713+
@pytest.mark.skipif(
714+
sys.version_info < (3, 7),
715+
reason="subscriptable classes requires Python 3.7 or higher",
716+
)
717+
def test_listview_is_subscriptable(self):
718+
assert generics.ListAPIView is generics.ListAPIView["foo"]
719+
720+
@pytest.mark.skipif(
721+
sys.version_info < (3, 7),
722+
reason="subscriptable classes requires Python 3.7 or higher",
723+
)
724+
def test_instanceview_is_subscriptable(self):
725+
assert generics.RetrieveAPIView is generics.RetrieveAPIView["foo"]

tests/test_request.py

+10
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44
import copy
55
import os.path
6+
import sys
67
import tempfile
78

89
import pytest
@@ -352,3 +353,12 @@ class TestDeepcopy(TestCase):
352353
def test_deepcopy_works(self):
353354
request = Request(factory.get('/', secure=False))
354355
copy.deepcopy(request)
356+
357+
358+
class TestTyping(TestCase):
359+
@pytest.mark.skipif(
360+
sys.version_info < (3, 7),
361+
reason="subscriptable classes requires Python 3.7 or higher",
362+
)
363+
def test_request_is_subscriptable(self):
364+
assert Request is Request["foo"]

tests/test_response.py

+12
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import sys
2+
3+
import pytest
14
from django.test import TestCase, override_settings
25
from django.urls import include, path, re_path
36

@@ -283,3 +286,12 @@ def test_form_has_label_and_help_text(self):
283286
self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
284287
# self.assertContains(resp, 'Text comes here')
285288
# self.assertContains(resp, 'Text description.')
289+
290+
291+
class TestTyping(TestCase):
292+
@pytest.mark.skipif(
293+
sys.version_info < (3, 7),
294+
reason="subscriptable classes requires Python 3.7 or higher",
295+
)
296+
def test_response_is_subscriptable(self):
297+
assert Response is Response["foo"]

0 commit comments

Comments
 (0)