55import time
66from datetime import datetime , tzinfo , timedelta
77from functools import partial
8+ from urllib .parse import parse_qs
89
910import jwt
1011import responses
@@ -98,9 +99,14 @@ def build_access_token_azure_groups_in_claim_source(request):
9899 return do_build_access_token (request , issuer , groups_in_claim_names = True )
99100
100101
102+ def build_access_token_adfs_expired (request ):
103+ issuer = "http://adfs.example.com/adfs/services/trust"
104+ return do_build_access_token (request , issuer , refresh_token_expired = True )
105+
106+
101107def do_build_mfa_error (request ):
102108 response = {'error_description' : 'AADSTS50076' }
103- return 400 , [] , json .dumps (response )
109+ return 400 , {} , json .dumps (response )
104110
105111
106112def do_build_graph_response (request ):
@@ -111,7 +117,11 @@ def do_build_graph_response_no_group_perm(request):
111117 return do_build_ms_graph_groups (request , missing_group_names = True )
112118
113119
114- def do_build_access_token (request , issuer , schema = None , no_upn = False , idp = None , groups_in_claim_names = False ):
120+ def do_build_access_token (request , issuer , schema = None , no_upn = False , idp = None , groups_in_claim_names = False ,
121+ refresh_token_expired = False ):
122+ data = parse_qs (request .body )
123+ if data .get ('grant_type' ) == ['refresh_token' ] and data .get ('refresh_token' ) == ['expired_refresh_token' ]:
124+ return 401 , {}, None
115125 issued_at = int (time .time ())
116126 expires = issued_at + 3600
117127 auth_time = datetime .utcnow ()
@@ -159,16 +169,20 @@ def do_build_access_token(request, issuer, schema=None, no_upn=False, idp=None,
159169 }
160170 }
161171 token = jwt .encode (claims , signing_key_b , algorithm = "RS256" )
172+ if refresh_token_expired :
173+ refresh_token = 'expired_refresh_token'
174+ else :
175+ refresh_token = 'random_refresh_token'
162176 response = {
163177 'resource' : 'django_website.adfs.relying_party_id' ,
164178 'token_type' : 'bearer' ,
165179 'refresh_token_expires_in' : 28799 ,
166- 'refresh_token' : 'random_refresh_token' ,
180+ 'refresh_token' : refresh_token ,
167181 'expires_in' : 3600 ,
168182 'id_token' : 'not_used' ,
169183 'access_token' : token .decode () if isinstance (token , bytes ) else token # PyJWT>=2 returns a str instead of bytes
170184 }
171- return 200 , [] , json .dumps (response )
185+ return 200 , {} , json .dumps (response )
172186
173187
174188def do_build_obo_access_token (request ):
@@ -228,7 +242,7 @@ def do_build_obo_access_token(request):
228242 'refresh_token' : 'not_used' ,
229243 'access_token' : token .decode () if isinstance (token , bytes ) else token # PyJWT>=2 returns a str instead of bytes
230244 }
231- return 200 , [] , json .dumps (response )
245+ return 200 , {} , json .dumps (response )
232246
233247
234248def do_build_ms_graph_groups (request , missing_group_names = False ):
@@ -308,7 +322,7 @@ def do_build_ms_graph_groups(request, missing_group_names=False):
308322 if missing_group_names :
309323 for group in response ["value" ]:
310324 group ["displayName" ] = None
311- return 200 , [] , json .dumps (response )
325+ return 200 , {} , json .dumps (response )
312326
313327
314328def build_openid_keys (request , empty_keys = False ):
@@ -337,15 +351,15 @@ def build_openid_keys(request, empty_keys=False):
337351 },
338352 ]
339353 }
340- return 200 , [] , json .dumps (keys )
354+ return 200 , {} , json .dumps (keys )
341355
342356
343357def build_adfs_meta (request ):
344358 with open (os .path .join (os .path .dirname (__file__ ), "mock_files/FederationMetadata.xml" ), mode = "r" ) as f :
345359 data = "" .join (f .readlines ())
346360 data = data .replace ("REPLACE_WITH_CERT_A" , base64 .b64encode (signing_cert_a ).decode ())
347361 data = data .replace ("REPLACE_WITH_CERT_B" , base64 .b64encode (signing_cert_b ).decode ())
348- return 200 , [] , data
362+ return 200 , {} , data
349363
350364
351365def mock_adfs (
@@ -356,6 +370,7 @@ def mock_adfs(
356370 version = None ,
357371 requires_obo = False ,
358372 missing_graph_group_perm = False ,
373+ refresh_token_expired = False ,
359374):
360375 if adfs_version not in ["2012" , "2016" , "azure" ]:
361376 raise NotImplementedError ("This version of ADFS is not implemented" )
@@ -465,6 +480,12 @@ def wrapper(*original_args, **original_kwargs):
465480 callback = do_build_mfa_error ,
466481 content_type = 'application/json' ,
467482 )
483+ elif refresh_token_expired :
484+ rsps .add_callback (
485+ rsps .POST , token_endpoint ,
486+ callback = build_access_token_adfs_expired ,
487+ content_type = 'application/json' ,
488+ )
468489 else :
469490 rsps .add_callback (
470491 rsps .POST , token_endpoint ,
0 commit comments