Skip to content

Commit 65e6a0d

Browse files
authored
fix: (OAuthAuthenticator) - get the access_token, refresh_token, expires_in recursively from response (#285)
1 parent ee537af commit 65e6a0d

File tree

4 files changed

+442
-133
lines changed

4 files changed

+442
-133
lines changed

Diff for: airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py

+214-56
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,13 @@
2525
_NOOP_MESSAGE_REPOSITORY = NoopMessageRepository()
2626

2727

28+
class ResponseKeysMaxRecurtionReached(AirbyteTracedException):
29+
"""
30+
Raised when the max level of recursion is reached, when trying to
31+
find-and-get the target key, during the `_make_handled_request`
32+
"""
33+
34+
2835
class AbstractOauth2Authenticator(AuthBase):
2936
"""
3037
Abstract class for an OAuth authenticators that implements the OAuth token refresh flow. The authenticator
@@ -53,15 +60,31 @@ def __call__(self, request: requests.PreparedRequest) -> requests.PreparedReques
5360
request.headers.update(self.get_auth_header())
5461
return request
5562

63+
@property
64+
def _is_access_token_flow(self) -> bool:
65+
return self.get_token_refresh_endpoint() is None and self.access_token is not None
66+
67+
@property
68+
def token_expiry_is_time_of_expiration(self) -> bool:
69+
"""
70+
Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid.
71+
"""
72+
73+
return False
74+
75+
@property
76+
def token_expiry_date_format(self) -> Optional[str]:
77+
"""
78+
Format of the datetime; exists it if expires_in is returned as the expiration datetime instead of seconds until it expires
79+
"""
80+
81+
return None
82+
5683
def get_auth_header(self) -> Mapping[str, Any]:
5784
"""HTTP header to set on the requests"""
5885
token = self.access_token if self._is_access_token_flow else self.get_access_token()
5986
return {"Authorization": f"Bearer {token}"}
6087

61-
@property
62-
def _is_access_token_flow(self) -> bool:
63-
return self.get_token_refresh_endpoint() is None and self.access_token is not None
64-
6588
def get_access_token(self) -> str:
6689
"""Returns the access token"""
6790
if self.token_has_expired():
@@ -107,9 +130,39 @@ def build_refresh_request_headers(self) -> Mapping[str, Any] | None:
107130
headers = self.get_refresh_request_headers()
108131
return headers if headers else None
109132

133+
def refresh_access_token(self) -> Tuple[str, Union[str, int]]:
134+
"""
135+
Returns the refresh token and its expiration datetime
136+
137+
:return: a tuple of (access_token, token_lifespan)
138+
"""
139+
response_json = self._make_handled_request()
140+
self._ensure_access_token_in_response(response_json)
141+
142+
return (
143+
self._extract_access_token(response_json),
144+
self._extract_token_expiry_date(response_json),
145+
)
146+
147+
# ----------------
148+
# PRIVATE METHODS
149+
# ----------------
150+
110151
def _wrap_refresh_token_exception(
111152
self, exception: requests.exceptions.RequestException
112153
) -> bool:
154+
"""
155+
Wraps and handles exceptions that occur during the refresh token process.
156+
157+
This method checks if the provided exception is related to a refresh token error
158+
by examining the response status code and specific error content.
159+
160+
Args:
161+
exception (requests.exceptions.RequestException): The exception raised during the request.
162+
163+
Returns:
164+
bool: True if the exception is related to a refresh token error, False otherwise.
165+
"""
113166
try:
114167
if exception.response is not None:
115168
exception_content = exception.response.json()
@@ -131,30 +184,35 @@ def _wrap_refresh_token_exception(
131184
),
132185
max_time=300,
133186
)
134-
def _get_refresh_access_token_response(self) -> Any:
187+
def _make_handled_request(self) -> Any:
188+
"""
189+
Makes a handled HTTP request to refresh an OAuth token.
190+
191+
This method sends a POST request to the token refresh endpoint with the necessary
192+
headers and body to obtain a new access token. It handles various exceptions that
193+
may occur during the request and logs the response for troubleshooting purposes.
194+
195+
Returns:
196+
Mapping[str, Any]: The JSON response from the token refresh endpoint.
197+
198+
Raises:
199+
DefaultBackoffException: If the response status code is 429 (Too Many Requests)
200+
or any 5xx server error.
201+
AirbyteTracedException: If the refresh token is invalid or expired, prompting
202+
re-authentication.
203+
Exception: For any other exceptions that occur during the request.
204+
"""
135205
try:
136206
response = requests.request(
137207
method="POST",
138208
url=self.get_token_refresh_endpoint(), # type: ignore # returns None, if not provided, but str | bytes is expected.
139209
data=self.build_refresh_request_body(),
140210
headers=self.build_refresh_request_headers(),
141211
)
142-
if response.ok:
143-
response_json = response.json()
144-
# Add the access token to the list of secrets so it is replaced before logging the response
145-
# An argument could be made to remove the prevous access key from the list of secrets, but unmasking values seems like a security incident waiting to happen...
146-
access_key = response_json.get(self.get_access_token_name())
147-
if not access_key:
148-
raise Exception(
149-
"Token refresh API response was missing access token {self.get_access_token_name()}"
150-
)
151-
add_to_secrets(access_key)
152-
self._log_response(response)
153-
return response_json
154-
else:
155-
# log the response even if the request failed for troubleshooting purposes
156-
self._log_response(response)
157-
response.raise_for_status()
212+
# log the response even if the request failed for troubleshooting purposes
213+
self._log_response(response)
214+
response.raise_for_status()
215+
return response.json()
158216
except requests.exceptions.RequestException as e:
159217
if e.response is not None:
160218
if e.response.status_code == 429 or e.response.status_code >= 500:
@@ -168,17 +226,34 @@ def _get_refresh_access_token_response(self) -> Any:
168226
except Exception as e:
169227
raise Exception(f"Error while refreshing access token: {e}") from e
170228

171-
def refresh_access_token(self) -> Tuple[str, Union[str, int]]:
229+
def _ensure_access_token_in_response(self, response_data: Mapping[str, Any]) -> None:
172230
"""
173-
Returns the refresh token and its expiration datetime
231+
Ensures that the access token is present in the response data.
174232
175-
:return: a tuple of (access_token, token_lifespan)
176-
"""
177-
response_json = self._get_refresh_access_token_response()
233+
This method attempts to extract the access token from the provided response data.
234+
If the access token is not found, it raises an exception indicating that the token
235+
refresh API response was missing the access token. If the access token is found,
236+
it adds the token to the list of secrets to ensure it is replaced before logging
237+
the response.
238+
239+
Args:
240+
response_data (Mapping[str, Any]): The response data from which to extract the access token.
178241
179-
return response_json[self.get_access_token_name()], response_json[
180-
self.get_expires_in_name()
181-
]
242+
Raises:
243+
Exception: If the access token is not found in the response data.
244+
ResponseKeysMaxRecurtionReached: If the maximum recursion depth is reached while extracting the access token.
245+
"""
246+
try:
247+
access_key = self._extract_access_token(response_data)
248+
if not access_key:
249+
raise Exception(
250+
"Token refresh API response was missing access token {self.get_access_token_name()}"
251+
)
252+
# Add the access token to the list of secrets so it is replaced before logging the response
253+
# An argument could be made to remove the prevous access key from the list of secrets, but unmasking values seems like a security incident waiting to happen...
254+
add_to_secrets(access_key)
255+
except ResponseKeysMaxRecurtionReached as e:
256+
raise e
182257

183258
def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTime:
184259
"""
@@ -206,22 +281,125 @@ def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTim
206281
f"Invalid expires_in value: {value}. Expected number of seconds when no format specified."
207282
)
208283

209-
@property
210-
def token_expiry_is_time_of_expiration(self) -> bool:
284+
def _extract_access_token(self, response_data: Mapping[str, Any]) -> Any:
211285
"""
212-
Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid.
286+
Extracts the access token from the given response data.
287+
288+
Args:
289+
response_data (Mapping[str, Any]): The response data from which to extract the access token.
290+
291+
Returns:
292+
str: The extracted access token.
213293
"""
294+
return self._find_and_get_value_from_response(response_data, self.get_access_token_name())
214295

215-
return False
296+
def _extract_refresh_token(self, response_data: Mapping[str, Any]) -> Any:
297+
"""
298+
Extracts the refresh token from the given response data.
216299
217-
@property
218-
def token_expiry_date_format(self) -> Optional[str]:
300+
Args:
301+
response_data (Mapping[str, Any]): The response data from which to extract the refresh token.
302+
303+
Returns:
304+
str: The extracted refresh token.
219305
"""
220-
Format of the datetime; exists it if expires_in is returned as the expiration datetime instead of seconds until it expires
306+
return self._find_and_get_value_from_response(response_data, self.get_refresh_token_name())
307+
308+
def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> Any:
309+
"""
310+
Extracts the token_expiry_date, like `expires_in` or `expires_at`, etc from the given response data.
311+
312+
Args:
313+
response_data (Mapping[str, Any]): The response data from which to extract the token_expiry_date.
314+
315+
Returns:
316+
str: The extracted token_expiry_date.
221317
"""
318+
return self._find_and_get_value_from_response(response_data, self.get_expires_in_name())
319+
320+
def _find_and_get_value_from_response(
321+
self,
322+
response_data: Mapping[str, Any],
323+
key_name: str,
324+
max_depth: int = 5,
325+
current_depth: int = 0,
326+
) -> Any:
327+
"""
328+
Recursively searches for a specified key in a nested dictionary or list and returns its value if found.
329+
330+
Args:
331+
response_data (Mapping[str, Any]): The response data to search through, which can be a dictionary or a list.
332+
key_name (str): The key to search for in the response data.
333+
max_depth (int, optional): The maximum depth to search for the key to avoid infinite recursion. Defaults to 5.
334+
current_depth (int, optional): The current depth of the recursion. Defaults to 0.
335+
336+
Returns:
337+
Any: The value associated with the specified key if found, otherwise None.
338+
339+
Raises:
340+
AirbyteTracedException: If the maximum recursion depth is reached without finding the key.
341+
"""
342+
if current_depth > max_depth:
343+
# this is needed to avoid an inf loop, possible with a very deep nesting observed.
344+
message = f"The maximum level of recursion is reached. Couldn't find the speficied `{key_name}` in the response."
345+
raise ResponseKeysMaxRecurtionReached(
346+
internal_message=message, message=message, failure_type=FailureType.config_error
347+
)
348+
349+
if isinstance(response_data, dict):
350+
# get from the root level
351+
if key_name in response_data:
352+
return response_data[key_name]
353+
354+
# get from the nested object
355+
for _, value in response_data.items():
356+
result = self._find_and_get_value_from_response(
357+
value, key_name, max_depth, current_depth + 1
358+
)
359+
if result is not None:
360+
return result
361+
362+
# get from the nested array object
363+
elif isinstance(response_data, list):
364+
for item in response_data:
365+
result = self._find_and_get_value_from_response(
366+
item, key_name, max_depth, current_depth + 1
367+
)
368+
if result is not None:
369+
return result
222370

223371
return None
224372

373+
@property
374+
def _message_repository(self) -> Optional[MessageRepository]:
375+
"""
376+
The implementation can define a message_repository if it wants debugging logs for HTTP requests
377+
"""
378+
return _NOOP_MESSAGE_REPOSITORY
379+
380+
def _log_response(self, response: requests.Response) -> None:
381+
"""
382+
Logs the HTTP response using the message repository if it is available.
383+
384+
Args:
385+
response (requests.Response): The HTTP response to log.
386+
"""
387+
if self._message_repository:
388+
self._message_repository.log_message(
389+
Level.DEBUG,
390+
lambda: format_http_message(
391+
response,
392+
"Refresh token",
393+
"Obtains access token",
394+
self._NO_STREAM_NAME,
395+
is_auxiliary=True,
396+
),
397+
)
398+
399+
# ----------------
400+
# ABSTR METHODS
401+
# ----------------
402+
225403
@abstractmethod
226404
def get_token_refresh_endpoint(self) -> Optional[str]:
227405
"""Returns the endpoint to refresh the access token"""
@@ -295,23 +473,3 @@ def access_token(self) -> str:
295473
@abstractmethod
296474
def access_token(self, value: str) -> str:
297475
"""Setter for the access token"""
298-
299-
@property
300-
def _message_repository(self) -> Optional[MessageRepository]:
301-
"""
302-
The implementation can define a message_repository if it wants debugging logs for HTTP requests
303-
"""
304-
return _NOOP_MESSAGE_REPOSITORY
305-
306-
def _log_response(self, response: requests.Response) -> None:
307-
if self._message_repository:
308-
self._message_repository.log_message(
309-
Level.DEBUG,
310-
lambda: format_http_message(
311-
response,
312-
"Refresh token",
313-
"Obtains access token",
314-
self._NO_STREAM_NAME,
315-
is_auxiliary=True,
316-
),
317-
)

0 commit comments

Comments
 (0)