25
25
_NOOP_MESSAGE_REPOSITORY = NoopMessageRepository ()
26
26
27
27
28
+ class ResponseKeysMaxRecurtionReached (AirbyteTracedException ):
29
+ """
30
+ Raised when the max level of recursion is reached, when trying to
31
+ find-and-get the target key, during the `_make_handled_request`
32
+ """
33
+
34
+
28
35
class AbstractOauth2Authenticator (AuthBase ):
29
36
"""
30
37
Abstract class for an OAuth authenticators that implements the OAuth token refresh flow. The authenticator
@@ -53,15 +60,31 @@ def __call__(self, request: requests.PreparedRequest) -> requests.PreparedReques
53
60
request .headers .update (self .get_auth_header ())
54
61
return request
55
62
63
+ @property
64
+ def _is_access_token_flow (self ) -> bool :
65
+ return self .get_token_refresh_endpoint () is None and self .access_token is not None
66
+
67
+ @property
68
+ def token_expiry_is_time_of_expiration (self ) -> bool :
69
+ """
70
+ Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid.
71
+ """
72
+
73
+ return False
74
+
75
+ @property
76
+ def token_expiry_date_format (self ) -> Optional [str ]:
77
+ """
78
+ Format of the datetime; exists it if expires_in is returned as the expiration datetime instead of seconds until it expires
79
+ """
80
+
81
+ return None
82
+
56
83
def get_auth_header (self ) -> Mapping [str , Any ]:
57
84
"""HTTP header to set on the requests"""
58
85
token = self .access_token if self ._is_access_token_flow else self .get_access_token ()
59
86
return {"Authorization" : f"Bearer { token } " }
60
87
61
- @property
62
- def _is_access_token_flow (self ) -> bool :
63
- return self .get_token_refresh_endpoint () is None and self .access_token is not None
64
-
65
88
def get_access_token (self ) -> str :
66
89
"""Returns the access token"""
67
90
if self .token_has_expired ():
@@ -107,9 +130,39 @@ def build_refresh_request_headers(self) -> Mapping[str, Any] | None:
107
130
headers = self .get_refresh_request_headers ()
108
131
return headers if headers else None
109
132
133
+ def refresh_access_token (self ) -> Tuple [str , Union [str , int ]]:
134
+ """
135
+ Returns the refresh token and its expiration datetime
136
+
137
+ :return: a tuple of (access_token, token_lifespan)
138
+ """
139
+ response_json = self ._make_handled_request ()
140
+ self ._ensure_access_token_in_response (response_json )
141
+
142
+ return (
143
+ self ._extract_access_token (response_json ),
144
+ self ._extract_token_expiry_date (response_json ),
145
+ )
146
+
147
+ # ----------------
148
+ # PRIVATE METHODS
149
+ # ----------------
150
+
110
151
def _wrap_refresh_token_exception (
111
152
self , exception : requests .exceptions .RequestException
112
153
) -> bool :
154
+ """
155
+ Wraps and handles exceptions that occur during the refresh token process.
156
+
157
+ This method checks if the provided exception is related to a refresh token error
158
+ by examining the response status code and specific error content.
159
+
160
+ Args:
161
+ exception (requests.exceptions.RequestException): The exception raised during the request.
162
+
163
+ Returns:
164
+ bool: True if the exception is related to a refresh token error, False otherwise.
165
+ """
113
166
try :
114
167
if exception .response is not None :
115
168
exception_content = exception .response .json ()
@@ -131,30 +184,35 @@ def _wrap_refresh_token_exception(
131
184
),
132
185
max_time = 300 ,
133
186
)
134
- def _get_refresh_access_token_response (self ) -> Any :
187
+ def _make_handled_request (self ) -> Any :
188
+ """
189
+ Makes a handled HTTP request to refresh an OAuth token.
190
+
191
+ This method sends a POST request to the token refresh endpoint with the necessary
192
+ headers and body to obtain a new access token. It handles various exceptions that
193
+ may occur during the request and logs the response for troubleshooting purposes.
194
+
195
+ Returns:
196
+ Mapping[str, Any]: The JSON response from the token refresh endpoint.
197
+
198
+ Raises:
199
+ DefaultBackoffException: If the response status code is 429 (Too Many Requests)
200
+ or any 5xx server error.
201
+ AirbyteTracedException: If the refresh token is invalid or expired, prompting
202
+ re-authentication.
203
+ Exception: For any other exceptions that occur during the request.
204
+ """
135
205
try :
136
206
response = requests .request (
137
207
method = "POST" ,
138
208
url = self .get_token_refresh_endpoint (), # type: ignore # returns None, if not provided, but str | bytes is expected.
139
209
data = self .build_refresh_request_body (),
140
210
headers = self .build_refresh_request_headers (),
141
211
)
142
- if response .ok :
143
- response_json = response .json ()
144
- # Add the access token to the list of secrets so it is replaced before logging the response
145
- # An argument could be made to remove the prevous access key from the list of secrets, but unmasking values seems like a security incident waiting to happen...
146
- access_key = response_json .get (self .get_access_token_name ())
147
- if not access_key :
148
- raise Exception (
149
- "Token refresh API response was missing access token {self.get_access_token_name()}"
150
- )
151
- add_to_secrets (access_key )
152
- self ._log_response (response )
153
- return response_json
154
- else :
155
- # log the response even if the request failed for troubleshooting purposes
156
- self ._log_response (response )
157
- response .raise_for_status ()
212
+ # log the response even if the request failed for troubleshooting purposes
213
+ self ._log_response (response )
214
+ response .raise_for_status ()
215
+ return response .json ()
158
216
except requests .exceptions .RequestException as e :
159
217
if e .response is not None :
160
218
if e .response .status_code == 429 or e .response .status_code >= 500 :
@@ -168,17 +226,34 @@ def _get_refresh_access_token_response(self) -> Any:
168
226
except Exception as e :
169
227
raise Exception (f"Error while refreshing access token: { e } " ) from e
170
228
171
- def refresh_access_token (self ) -> Tuple [str , Union [ str , int ]] :
229
+ def _ensure_access_token_in_response (self , response_data : Mapping [str , Any ]) -> None :
172
230
"""
173
- Returns the refresh token and its expiration datetime
231
+ Ensures that the access token is present in the response data.
174
232
175
- :return: a tuple of (access_token, token_lifespan)
176
- """
177
- response_json = self ._get_refresh_access_token_response ()
233
+ This method attempts to extract the access token from the provided response data.
234
+ If the access token is not found, it raises an exception indicating that the token
235
+ refresh API response was missing the access token. If the access token is found,
236
+ it adds the token to the list of secrets to ensure it is replaced before logging
237
+ the response.
238
+
239
+ Args:
240
+ response_data (Mapping[str, Any]): The response data from which to extract the access token.
178
241
179
- return response_json [self .get_access_token_name ()], response_json [
180
- self .get_expires_in_name ()
181
- ]
242
+ Raises:
243
+ Exception: If the access token is not found in the response data.
244
+ ResponseKeysMaxRecurtionReached: If the maximum recursion depth is reached while extracting the access token.
245
+ """
246
+ try :
247
+ access_key = self ._extract_access_token (response_data )
248
+ if not access_key :
249
+ raise Exception (
250
+ "Token refresh API response was missing access token {self.get_access_token_name()}"
251
+ )
252
+ # Add the access token to the list of secrets so it is replaced before logging the response
253
+ # An argument could be made to remove the prevous access key from the list of secrets, but unmasking values seems like a security incident waiting to happen...
254
+ add_to_secrets (access_key )
255
+ except ResponseKeysMaxRecurtionReached as e :
256
+ raise e
182
257
183
258
def _parse_token_expiration_date (self , value : Union [str , int ]) -> AirbyteDateTime :
184
259
"""
@@ -206,22 +281,125 @@ def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTim
206
281
f"Invalid expires_in value: { value } . Expected number of seconds when no format specified."
207
282
)
208
283
209
- @property
210
- def token_expiry_is_time_of_expiration (self ) -> bool :
284
+ def _extract_access_token (self , response_data : Mapping [str , Any ]) -> Any :
211
285
"""
212
- Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid.
286
+ Extracts the access token from the given response data.
287
+
288
+ Args:
289
+ response_data (Mapping[str, Any]): The response data from which to extract the access token.
290
+
291
+ Returns:
292
+ str: The extracted access token.
213
293
"""
294
+ return self ._find_and_get_value_from_response (response_data , self .get_access_token_name ())
214
295
215
- return False
296
+ def _extract_refresh_token (self , response_data : Mapping [str , Any ]) -> Any :
297
+ """
298
+ Extracts the refresh token from the given response data.
216
299
217
- @property
218
- def token_expiry_date_format (self ) -> Optional [str ]:
300
+ Args:
301
+ response_data (Mapping[str, Any]): The response data from which to extract the refresh token.
302
+
303
+ Returns:
304
+ str: The extracted refresh token.
219
305
"""
220
- Format of the datetime; exists it if expires_in is returned as the expiration datetime instead of seconds until it expires
306
+ return self ._find_and_get_value_from_response (response_data , self .get_refresh_token_name ())
307
+
308
+ def _extract_token_expiry_date (self , response_data : Mapping [str , Any ]) -> Any :
309
+ """
310
+ Extracts the token_expiry_date, like `expires_in` or `expires_at`, etc from the given response data.
311
+
312
+ Args:
313
+ response_data (Mapping[str, Any]): The response data from which to extract the token_expiry_date.
314
+
315
+ Returns:
316
+ str: The extracted token_expiry_date.
221
317
"""
318
+ return self ._find_and_get_value_from_response (response_data , self .get_expires_in_name ())
319
+
320
+ def _find_and_get_value_from_response (
321
+ self ,
322
+ response_data : Mapping [str , Any ],
323
+ key_name : str ,
324
+ max_depth : int = 5 ,
325
+ current_depth : int = 0 ,
326
+ ) -> Any :
327
+ """
328
+ Recursively searches for a specified key in a nested dictionary or list and returns its value if found.
329
+
330
+ Args:
331
+ response_data (Mapping[str, Any]): The response data to search through, which can be a dictionary or a list.
332
+ key_name (str): The key to search for in the response data.
333
+ max_depth (int, optional): The maximum depth to search for the key to avoid infinite recursion. Defaults to 5.
334
+ current_depth (int, optional): The current depth of the recursion. Defaults to 0.
335
+
336
+ Returns:
337
+ Any: The value associated with the specified key if found, otherwise None.
338
+
339
+ Raises:
340
+ AirbyteTracedException: If the maximum recursion depth is reached without finding the key.
341
+ """
342
+ if current_depth > max_depth :
343
+ # this is needed to avoid an inf loop, possible with a very deep nesting observed.
344
+ message = f"The maximum level of recursion is reached. Couldn't find the speficied `{ key_name } ` in the response."
345
+ raise ResponseKeysMaxRecurtionReached (
346
+ internal_message = message , message = message , failure_type = FailureType .config_error
347
+ )
348
+
349
+ if isinstance (response_data , dict ):
350
+ # get from the root level
351
+ if key_name in response_data :
352
+ return response_data [key_name ]
353
+
354
+ # get from the nested object
355
+ for _ , value in response_data .items ():
356
+ result = self ._find_and_get_value_from_response (
357
+ value , key_name , max_depth , current_depth + 1
358
+ )
359
+ if result is not None :
360
+ return result
361
+
362
+ # get from the nested array object
363
+ elif isinstance (response_data , list ):
364
+ for item in response_data :
365
+ result = self ._find_and_get_value_from_response (
366
+ item , key_name , max_depth , current_depth + 1
367
+ )
368
+ if result is not None :
369
+ return result
222
370
223
371
return None
224
372
373
+ @property
374
+ def _message_repository (self ) -> Optional [MessageRepository ]:
375
+ """
376
+ The implementation can define a message_repository if it wants debugging logs for HTTP requests
377
+ """
378
+ return _NOOP_MESSAGE_REPOSITORY
379
+
380
+ def _log_response (self , response : requests .Response ) -> None :
381
+ """
382
+ Logs the HTTP response using the message repository if it is available.
383
+
384
+ Args:
385
+ response (requests.Response): The HTTP response to log.
386
+ """
387
+ if self ._message_repository :
388
+ self ._message_repository .log_message (
389
+ Level .DEBUG ,
390
+ lambda : format_http_message (
391
+ response ,
392
+ "Refresh token" ,
393
+ "Obtains access token" ,
394
+ self ._NO_STREAM_NAME ,
395
+ is_auxiliary = True ,
396
+ ),
397
+ )
398
+
399
+ # ----------------
400
+ # ABSTR METHODS
401
+ # ----------------
402
+
225
403
@abstractmethod
226
404
def get_token_refresh_endpoint (self ) -> Optional [str ]:
227
405
"""Returns the endpoint to refresh the access token"""
@@ -295,23 +473,3 @@ def access_token(self) -> str:
295
473
@abstractmethod
296
474
def access_token (self , value : str ) -> str :
297
475
"""Setter for the access token"""
298
-
299
- @property
300
- def _message_repository (self ) -> Optional [MessageRepository ]:
301
- """
302
- The implementation can define a message_repository if it wants debugging logs for HTTP requests
303
- """
304
- return _NOOP_MESSAGE_REPOSITORY
305
-
306
- def _log_response (self , response : requests .Response ) -> None :
307
- if self ._message_repository :
308
- self ._message_repository .log_message (
309
- Level .DEBUG ,
310
- lambda : format_http_message (
311
- response ,
312
- "Refresh token" ,
313
- "Obtains access token" ,
314
- self ._NO_STREAM_NAME ,
315
- is_auxiliary = True ,
316
- ),
317
- )
0 commit comments