Skip to content

Commit 4c51b60

Browse files
committed
Allow generic requets, responses, fields, views
Allow Request, Response, Field, and GenericAPIView to be subscriptable. This allows the classes to be made generic for type checking. This is especially useful since monkey patching DRF can be problematic as seen in this [issue][1]. [1]: typeddjango/djangorestframework-stubs#299
1 parent 48a21aa commit 4c51b60

File tree

8 files changed

+73
-0
lines changed

8 files changed

+73
-0
lines changed

rest_framework/fields.py

+4
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,10 @@ def __init__(self, *, read_only=False, write_only=False,
356356
messages.update(error_messages or {})
357357
self.error_messages = messages
358358

359+
# Allow generic typing checking for fields.
360+
def __class_getitem__(cls, *args, **kwargs):
361+
return cls
362+
359363
def bind(self, field_name, parent):
360364
"""
361365
Initializes the field name and parent for the field instance.

rest_framework/generics.py

+4
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,10 @@ class GenericAPIView(views.APIView):
4545
# The style to use for queryset pagination.
4646
pagination_class = api_settings.DEFAULT_PAGINATION_CLASS
4747

48+
# Allow generic typing checking for generic views.
49+
def __class_getitem__(cls, *args, **kwargs):
50+
return cls
51+
4852
def get_queryset(self):
4953
"""
5054
Get the list of items for this view.

rest_framework/request.py

+4
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,10 @@ def __repr__(self):
186186
self.method,
187187
self.get_full_path())
188188

189+
# Allow generic typing checking for requests.
190+
def __class_getitem__(cls, *args, **kwargs):
191+
return cls
192+
189193
def _default_negotiator(self):
190194
return api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS()
191195

rest_framework/response.py

+4
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ def __init__(self, data=None, status=None,
4646
for name, value in headers.items():
4747
self[name] = value
4848

49+
# Allow generic typing checking for responses.
50+
def __class_getitem__(cls, *args, **kwargs):
51+
return cls
52+
4953
@property
5054
def rendered_content(self):
5155
renderer = getattr(self, 'accepted_renderer', None)

tests/test_fields.py

+10
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import math
33
import os
44
import re
5+
import sys
56
import uuid
67
from decimal import ROUND_DOWN, ROUND_UP, Decimal
78

@@ -624,6 +625,15 @@ def test_parent_binding(self):
624625
assert field.root is parent
625626

626627

628+
class TestTyping(TestCase):
629+
@pytest.mark.skipif(
630+
sys.version_info < (3, 7),
631+
reason="subscriptable classes requires Python 3.7 or higher",
632+
)
633+
def test_field_is_subscriptable(self):
634+
assert serializers.Field is serializers.Field["foo"]
635+
636+
627637
# Tests for field input and output values.
628638
# ----------------------------------------
629639

tests/test_generics.py

+25
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import sys
2+
13
import pytest
24
from django.db import models
35
from django.http import Http404
@@ -698,3 +700,26 @@ def list(self, request):
698700
serializer = response.serializer
699701

700702
assert serializer.context is context
703+
704+
705+
class TestTyping(TestCase):
706+
@pytest.mark.skipif(
707+
sys.version_info < (3, 7),
708+
reason="subscriptable classes requires Python 3.7 or higher",
709+
)
710+
def test_genericview_is_subscriptable(self):
711+
assert generics.GenericAPIView is generics.GenericAPIView["foo"]
712+
713+
@pytest.mark.skipif(
714+
sys.version_info < (3, 7),
715+
reason="subscriptable classes requires Python 3.7 or higher",
716+
)
717+
def test_listview_is_subscriptable(self):
718+
assert generics.ListAPIView is generics.ListAPIView["foo"]
719+
720+
@pytest.mark.skipif(
721+
sys.version_info < (3, 7),
722+
reason="subscriptable classes requires Python 3.7 or higher",
723+
)
724+
def test_instanceview_is_subscriptable(self):
725+
assert generics.RetrieveAPIView is generics.RetrieveAPIView["foo"]

tests/test_request.py

+10
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44
import copy
55
import os.path
6+
import sys
67
import tempfile
78

89
import pytest
@@ -352,3 +353,12 @@ class TestDeepcopy(TestCase):
352353
def test_deepcopy_works(self):
353354
request = Request(factory.get('/', secure=False))
354355
copy.deepcopy(request)
356+
357+
358+
class TestTyping(TestCase):
359+
@pytest.mark.skipif(
360+
sys.version_info < (3, 7),
361+
reason="subscriptable classes requires Python 3.7 or higher",
362+
)
363+
def test_request_is_subscriptable(self):
364+
assert Request is Request["foo"]

tests/test_response.py

+12
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import sys
2+
3+
import pytest
14
from django.test import TestCase, override_settings
25
from django.urls import include, path, re_path
36

@@ -283,3 +286,12 @@ def test_form_has_label_and_help_text(self):
283286
self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
284287
# self.assertContains(resp, 'Text comes here')
285288
# self.assertContains(resp, 'Text description.')
289+
290+
291+
class TestTyping(TestCase):
292+
@pytest.mark.skipif(
293+
sys.version_info < (3, 7),
294+
reason="subscriptable classes requires Python 3.7 or higher",
295+
)
296+
def test_response_is_subscriptable(self):
297+
assert Response is Response["foo"]

0 commit comments

Comments
 (0)