|
7 | 7 | from django.contrib.auth.models import User
|
8 | 8 | from django.core.cache import cache
|
9 | 9 | from django.core.exceptions import ImproperlyConfigured
|
| 10 | +from django.http import HttpRequest |
10 | 11 | from django.test import TestCase
|
11 | 12 |
|
| 13 | +from rest_framework.request import Request |
12 | 14 | from rest_framework.response import Response
|
13 | 15 | from rest_framework.settings import api_settings
|
14 |
| -from rest_framework.test import APIRequestFactory |
| 16 | +from rest_framework.test import APIRequestFactory, force_authenticate |
15 | 17 | from rest_framework.throttling import (
|
16 |
| - BaseThrottle, ScopedRateThrottle, SimpleRateThrottle, UserRateThrottle |
| 18 | + AnonRateThrottle, BaseThrottle, ScopedRateThrottle, SimpleRateThrottle, |
| 19 | + UserRateThrottle |
17 | 20 | )
|
18 | 21 | from rest_framework.views import APIView
|
19 | 22 |
|
@@ -414,3 +417,36 @@ def test_wait_returns_none_if_there_are_no_available_requests(self):
|
414 | 417 | throttle.now = throttle.timer()
|
415 | 418 | throttle.history = [throttle.timer() for _ in range(3)]
|
416 | 419 | assert throttle.wait() is None
|
| 420 | + |
| 421 | + |
| 422 | +class AnonRateThrottleTests(TestCase): |
| 423 | + |
| 424 | + def setUp(self): |
| 425 | + self.throttle = AnonRateThrottle() |
| 426 | + |
| 427 | + def test_authenticated_user_not_affected(self): |
| 428 | + request = Request(HttpRequest()) |
| 429 | + user = User.objects.create(username='test') |
| 430 | + force_authenticate(request, user) |
| 431 | + request.user = user |
| 432 | + assert self.throttle.get_cache_key(request, view={}) is None |
| 433 | + |
| 434 | + def test_get_cache_key_returns_correct_value(self): |
| 435 | + request = Request(HttpRequest()) |
| 436 | + cache_key = self.throttle.get_cache_key(request, view={}) |
| 437 | + assert cache_key == 'throttle_anon_None' |
| 438 | + |
| 439 | + |
| 440 | +class UserRateThrottleTests(TestCase): |
| 441 | + |
| 442 | + def setUp(self): |
| 443 | + self.throttle = UserRateThrottle() |
| 444 | + |
| 445 | + def test_get_cache_key_returns_correct_key_if_user_is_authenticated(self): |
| 446 | + request = Request(HttpRequest()) |
| 447 | + user = User.objects.create(username='test') |
| 448 | + force_authenticate(request, user) |
| 449 | + request.user = user |
| 450 | + |
| 451 | + cache_key = self.throttle.get_cache_key(request, view={}) |
| 452 | + assert cache_key == 'throttle_user_%s' % user.pk |
0 commit comments