Skip to content

Commit 412bec1

Browse files
committed
Added middleware to refresh access tokens
1 parent 3de576c commit 412bec1

File tree

7 files changed

+85
-14
lines changed

7 files changed

+85
-14
lines changed

django_auth_adfs/backend.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,11 @@ def authenticate(self, request=None, authorization_code=None, **kwargs):
396396
logger.debug("Authentication backend was called but no authorization code was received")
397397
return
398398

399+
# If there's no request object, we pass control to the next authentication backend
400+
if request is None:
401+
logger.debug("Authentication backend was called without request")
402+
return
403+
399404
# If loaded data is too old, reload it again
400405
provider_config.load_config()
401406

django_auth_adfs/middleware.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,5 +63,5 @@ def middleware(request):
6363
backend = auth.load_backend(backend_str)
6464
if isinstance(backend, AdfsAuthCodeBackend):
6565
backend.process_request(request)
66-
return get_response()
66+
return get_response(request)
6767
return middleware

tests/settings.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
'django.contrib.messages.middleware.MessageMiddleware',
3636
'django.middleware.clickjacking.XFrameOptionsMiddleware',
3737

38+
'django_auth_adfs.middleware.adfs_refresh_middleware',
3839
'django_auth_adfs.middleware.LoginRequiredMiddleware',
3940
)
4041

tests/test_authentication.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import base64
22

3+
from datetime import datetime, timedelta
4+
35
from django.urls import reverse
46

57
from django_auth_adfs.exceptions import MFARequired
@@ -12,9 +14,9 @@
1214
from copy import deepcopy
1315

1416
from django.contrib.auth.models import Group, User
15-
from django.core.exceptions import ObjectDoesNotExist, PermissionDenied
17+
from django.core.exceptions import ObjectDoesNotExist
1618
from django.db.models.signals import post_save
17-
from django.test import RequestFactory, TestCase
19+
from django.test import TestCase
1820
from mock import Mock, patch
1921

2022
from django_auth_adfs import signals
@@ -29,13 +31,12 @@ def setUp(self):
2931
Group.objects.create(name='group1')
3032
Group.objects.create(name='group2')
3133
Group.objects.create(name='group3')
32-
self.request = RequestFactory().get('/oauth2/callback')
3334
self.signal_handler = Mock()
3435
signals.post_authenticate.connect(self.signal_handler)
3536

3637
@mock_adfs("2012")
3738
def test_post_authenticate_signal_send(self):
38-
response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"})
39+
self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"})
3940
self.assertEqual(self.signal_handler.call_count, 1)
4041

4142
@mock_adfs("2012")
@@ -495,3 +496,34 @@ def test_nonexisting_user(self):
495496
patch("django_auth_adfs.backend.settings", Settings()):
496497
response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "testcode"})
497498
self.assertEqual(response.status_code, 401)
499+
500+
@mock_adfs("2016")
501+
def test_access_token_unexpired(self):
502+
response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "testcode"})
503+
self.assertFalse(response.wsgi_request.user.is_anonymous)
504+
response = self.client.get(reverse('test'))
505+
self.assertEqual(response.status_code, 200)
506+
507+
@mock_adfs("2016")
508+
def test_access_token_expired(self):
509+
response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "testcode"})
510+
self.assertFalse(response.wsgi_request.user.is_anonymous)
511+
fromisoformat = datetime.fromisoformat
512+
with patch('django_auth_adfs.backend.datetime') as dt:
513+
dt.fromisoformat = fromisoformat
514+
dt.now.return_value = datetime.now() + timedelta(hours=1)
515+
response = self.client.get(reverse('test'))
516+
self.assertEqual(response.status_code, 200)
517+
518+
@mock_adfs("2016", refresh_token_expired=True)
519+
def test_refresh_token_expired(self):
520+
response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "testcode"})
521+
self.assertFalse(response.wsgi_request.user.is_anonymous)
522+
fromisoformat = datetime.fromisoformat
523+
with patch('django_auth_adfs.backend.datetime') as dt:
524+
dt.fromisoformat = fromisoformat
525+
dt.now.return_value = datetime.now() + timedelta(hours=1)
526+
response = self.client.get(reverse('test'))
527+
self.assertEqual(response.status_code, 302)
528+
self.assertEqual(response['Location'], f"{reverse('django_auth_adfs:login')}?next=/")
529+
self.assertTrue(response.wsgi_request.user.is_anonymous)

tests/urls.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1-
from django.urls import include, re_path
1+
from django.urls import include, re_path, path
2+
3+
from tests.views import TestView
24

35
urlpatterns = [
6+
path('', TestView.as_view(), name='test'),
47
re_path(r'^oauth2/', include('django_auth_adfs.urls')),
58
re_path(r'^oauth2/', include('django_auth_adfs.drf_urls')),
69
]

tests/utils.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import time
66
from datetime import datetime, tzinfo, timedelta
77
from functools import partial
8+
from urllib.parse import parse_qs
89

910
import jwt
1011
import responses
@@ -98,9 +99,14 @@ def build_access_token_azure_groups_in_claim_source(request):
9899
return do_build_access_token(request, issuer, groups_in_claim_names=True)
99100

100101

102+
def build_access_token_adfs_expired(request):
103+
issuer = "http://adfs.example.com/adfs/services/trust"
104+
return do_build_access_token(request, issuer, refresh_token_expired=True)
105+
106+
101107
def do_build_mfa_error(request):
102108
response = {'error_description': 'AADSTS50076'}
103-
return 400, [], json.dumps(response)
109+
return 400, {}, json.dumps(response)
104110

105111

106112
def do_build_graph_response(request):
@@ -111,7 +117,11 @@ def do_build_graph_response_no_group_perm(request):
111117
return do_build_ms_graph_groups(request, missing_group_names=True)
112118

113119

114-
def do_build_access_token(request, issuer, schema=None, no_upn=False, idp=None, groups_in_claim_names=False):
120+
def do_build_access_token(request, issuer, schema=None, no_upn=False, idp=None, groups_in_claim_names=False,
121+
refresh_token_expired=False):
122+
data = parse_qs(request.body)
123+
if data.get('grant_type') == ['refresh_token'] and data.get('refresh_token') == ['expired_refresh_token']:
124+
return 401, {}, None
115125
issued_at = int(time.time())
116126
expires = issued_at + 3600
117127
auth_time = datetime.utcnow()
@@ -159,16 +169,20 @@ def do_build_access_token(request, issuer, schema=None, no_upn=False, idp=None,
159169
}
160170
}
161171
token = jwt.encode(claims, signing_key_b, algorithm="RS256")
172+
if refresh_token_expired:
173+
refresh_token = 'expired_refresh_token'
174+
else:
175+
refresh_token = 'random_refresh_token'
162176
response = {
163177
'resource': 'django_website.adfs.relying_party_id',
164178
'token_type': 'bearer',
165179
'refresh_token_expires_in': 28799,
166-
'refresh_token': 'random_refresh_token',
180+
'refresh_token': refresh_token,
167181
'expires_in': 3600,
168182
'id_token': 'not_used',
169183
'access_token': token.decode() if isinstance(token, bytes) else token # PyJWT>=2 returns a str instead of bytes
170184
}
171-
return 200, [], json.dumps(response)
185+
return 200, {}, json.dumps(response)
172186

173187

174188
def do_build_obo_access_token(request):
@@ -228,7 +242,7 @@ def do_build_obo_access_token(request):
228242
'refresh_token': 'not_used',
229243
'access_token': token.decode() if isinstance(token, bytes) else token # PyJWT>=2 returns a str instead of bytes
230244
}
231-
return 200, [], json.dumps(response)
245+
return 200, {}, json.dumps(response)
232246

233247

234248
def do_build_ms_graph_groups(request, missing_group_names=False):
@@ -308,7 +322,7 @@ def do_build_ms_graph_groups(request, missing_group_names=False):
308322
if missing_group_names:
309323
for group in response["value"]:
310324
group["displayName"] = None
311-
return 200, [], json.dumps(response)
325+
return 200, {}, json.dumps(response)
312326

313327

314328
def build_openid_keys(request, empty_keys=False):
@@ -337,15 +351,15 @@ def build_openid_keys(request, empty_keys=False):
337351
},
338352
]
339353
}
340-
return 200, [], json.dumps(keys)
354+
return 200, {}, json.dumps(keys)
341355

342356

343357
def build_adfs_meta(request):
344358
with open(os.path.join(os.path.dirname(__file__), "mock_files/FederationMetadata.xml"), mode="r") as f:
345359
data = "".join(f.readlines())
346360
data = data.replace("REPLACE_WITH_CERT_A", base64.b64encode(signing_cert_a).decode())
347361
data = data.replace("REPLACE_WITH_CERT_B", base64.b64encode(signing_cert_b).decode())
348-
return 200, [], data
362+
return 200, {}, data
349363

350364

351365
def mock_adfs(
@@ -356,6 +370,7 @@ def mock_adfs(
356370
version=None,
357371
requires_obo=False,
358372
missing_graph_group_perm=False,
373+
refresh_token_expired=False,
359374
):
360375
if adfs_version not in ["2012", "2016", "azure"]:
361376
raise NotImplementedError("This version of ADFS is not implemented")
@@ -465,6 +480,12 @@ def wrapper(*original_args, **original_kwargs):
465480
callback=do_build_mfa_error,
466481
content_type='application/json',
467482
)
483+
elif refresh_token_expired:
484+
rsps.add_callback(
485+
rsps.POST, token_endpoint,
486+
callback=build_access_token_adfs_expired,
487+
content_type='application/json',
488+
)
468489
else:
469490
rsps.add_callback(
470491
rsps.POST, token_endpoint,

tests/views.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,11 @@
1+
from django.http import HttpResponse
2+
from django.views import View
3+
4+
15
def test_failed_response(request, error_message, status):
26
pass
7+
8+
9+
class TestView(View):
10+
def get(self, request):
11+
return HttpResponse('okay')

0 commit comments

Comments
 (0)