Skip to content

fix(auth): correct OAuth2 token URL access in refresh flow #116

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@

"""Credential fetcher for OpenID Connect."""

from typing import Optional
from typing import Optional, Dict
from authlib.integrations.requests_client import OAuth2Session
import logging

from .....auth.auth_credential import AuthCredential
from .....auth.auth_credential import AuthCredentialTypes
Expand All @@ -24,6 +26,7 @@
from .....auth.auth_schemes import AuthSchemeType
from .base_credential_exchanger import BaseAuthCredentialExchanger

logger = logging.getLogger(__name__)

class OAuth2CredentialExchanger(BaseAuthCredentialExchanger):
"""Fetches credentials for OAuth2 and OpenID Connect."""
Expand Down Expand Up @@ -84,6 +87,51 @@ def generate_auth_token(
)
return updated_credential

def _refresh_token(
self,
auth_credential: AuthCredential,
auth_scheme: AuthScheme,
) -> Dict[str, str]:
"""Refreshes the OAuth2 access token using the refresh token.

Args:
auth_credential: The auth credential containing the refresh token.
auth_scheme: The auth scheme containing OAuth2 configuration.

Returns:
A dictionary containing the new access token and related information.

Raises:
ValueError: If refresh token is missing or refresh fails.
"""
if not auth_credential.oauth2.token or "refresh_token" not in auth_credential.oauth2.token:
raise ValueError("No refresh token available for token refresh")

# Get token URL from either OpenID Connect or OAuth2 configuration
token_url = None
if auth_scheme.type_ == AuthSchemeType.openIdConnect and auth_scheme.openIdConnect:
token_url = auth_scheme.openIdConnect.get("tokenUrl")
elif auth_scheme.type_ == AuthSchemeType.oauth2 and auth_scheme.oauth2:
token_url = auth_scheme.oauth2.tokenUrl

if not token_url:
raise ValueError("No token URL available for token refresh")

try:
client = OAuth2Session(
auth_credential.oauth2.client_id,
auth_credential.oauth2.client_secret,
token=auth_credential.oauth2.token,
)
new_token = client.refresh_token(
token_url,
refresh_token=auth_credential.oauth2.token["refresh_token"],
)
return new_token
except Exception as e:
logger.error("Failed to refresh token: %s", str(e))
raise ValueError(f"Token refresh failed: {str(e)}")

def exchange_credential(
self,
auth_scheme: AuthScheme,
Expand All @@ -101,17 +149,29 @@ def exchange_credential(
Raises:
ValueError: If the auth scheme or auth credential is invalid.
"""
# TODO(cheliu): Implement token refresh flow

self._check_scheme_credential_type(auth_scheme, auth_credential)

# If token is already HTTPBearer token, do nothing assuming that this token
# is valid.
# If token is already HTTPBearer token, try to refresh if needed
if auth_credential.http:
return auth_credential

# If access token is exchanged, exchange a HTTPBearer token.
if auth_credential.oauth2.access_token:
try:
# Attempt to use the current token
return auth_credential
except Exception as e:
logger.info("Token may be expired, attempting refresh: %s", str(e))
# Continue to refresh flow

# Try to refresh the token if we have a refresh token
if auth_credential.oauth2.token and "refresh_token" in auth_credential.oauth2.token:
try:
new_token = self._refresh_token(auth_credential, auth_scheme)
auth_credential.oauth2.token = new_token
return self.generate_auth_token(auth_credential)
except ValueError as e:
logger.error("Token refresh failed: %s", str(e))
# Fall through to try other methods

# If access token is available, exchange for HTTPBearer token
if auth_credential.oauth2.token:
return self.generate_auth_token(auth_credential)

return None
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""Tests for OAuth2CredentialExchanger."""

import copy
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

from google.adk.auth.auth_credential import AuthCredential
from google.adk.auth.auth_credential import AuthCredentialTypes
Expand All @@ -34,16 +34,38 @@ def oauth2_exchanger():

@pytest.fixture
def auth_scheme():
"""Create an OpenID Connect auth scheme for testing."""
openid_config = OpenIdConnectWithConfig(
type_=AuthSchemeType.openIdConnect,
authorization_endpoint="https://example.com/auth",
token_endpoint="https://example.com/token",
scopes=["openid", "profile"],
openIdConnect={
"tokenUrl": "https://example.com/token",
"authorizationUrl": "https://example.com/auth",
},
)
return openid_config

@pytest.fixture
def auth_credential_with_refresh():
"""Create an auth credential with refresh token for testing."""
return AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(
client_id="test_client",
client_secret="test_secret",
redirect_uri="http://localhost:8080",
token={
"access_token": "old_access_token",
"refresh_token": "test_refresh_token",
"expires_in": 3600
},
),
)

def test_check_scheme_credential_type_success(oauth2_exchanger, auth_scheme):
"""Test successful scheme and credential type check."""
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(
Expand All @@ -59,17 +81,15 @@ def test_check_scheme_credential_type_success(oauth2_exchanger, auth_scheme):
def test_check_scheme_credential_type_missing_credential(
oauth2_exchanger, auth_scheme
):
# Test case: auth_credential is None
"""Test case: auth_credential is None."""
with pytest.raises(ValueError) as exc_info:
oauth2_exchanger._check_scheme_credential_type(auth_scheme, None)
assert "auth_credential is empty" in str(exc_info.value)


def test_check_scheme_credential_type_invalid_scheme_type(
oauth2_exchanger, auth_scheme: OpenIdConnectWithConfig
):
"""Test case: Invalid AuthSchemeType."""
# Test case: Invalid AuthSchemeType
invalid_scheme = copy.deepcopy(auth_scheme)
invalid_scheme.type_ = AuthSchemeType.apiKey
auth_credential = AuthCredential(
Expand All @@ -86,23 +106,21 @@ def test_check_scheme_credential_type_invalid_scheme_type(
)
assert "Invalid security scheme" in str(exc_info.value)


def test_check_scheme_credential_type_missing_openid_connect(
oauth2_exchanger, auth_scheme
):
"""Test case: Missing OpenID Connect configuration."""
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
)
with pytest.raises(ValueError) as exc_info:
oauth2_exchanger._check_scheme_credential_type(auth_scheme, auth_credential)
assert "auth_credential is not configured with oauth2" in str(exc_info.value)


def test_generate_auth_token_success(
oauth2_exchanger, auth_scheme, monkeypatch
):
"""Test case: Successful generation of access token."""
# Test case: Successful generation of access token
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(
Expand All @@ -119,7 +137,6 @@ def test_generate_auth_token_success(
assert updated_credential.http.scheme == "bearer"
assert updated_credential.http.credentials.token == "test_access_token"


def test_exchange_credential_generate_auth_token(
oauth2_exchanger, auth_scheme, monkeypatch
):
Expand All @@ -143,11 +160,83 @@ def test_exchange_credential_generate_auth_token(
assert updated_credential.http.scheme == "bearer"
assert updated_credential.http.credentials.token == "test_access_token"


def test_exchange_credential_auth_missing(oauth2_exchanger, auth_scheme):
"""Test exchange_credential when auth_credential is missing."""
with pytest.raises(ValueError) as exc_info:
oauth2_exchanger.exchange_credential(auth_scheme, None)
assert "auth_credential is empty. Please create AuthCredential using" in str(
exc_info.value
)

def test_refresh_token_success(oauth2_exchanger, auth_scheme, auth_credential_with_refresh):
"""Test successful token refresh."""
mock_response = {
"access_token": "new_access_token",
"expires_in": 3600,
"token_type": "Bearer"
}

with patch('authlib.oauth2.client.OAuth2Client.refresh_token', return_value=mock_response):
new_token = oauth2_exchanger._refresh_token(
auth_credential_with_refresh,
auth_scheme
)

# Verify the response
assert new_token == mock_response

def test_refresh_token_missing_token_url(oauth2_exchanger, auth_credential_with_refresh):
"""Test token refresh with missing token URL."""
invalid_scheme = OpenIdConnectWithConfig(
type_=AuthSchemeType.openIdConnect,
authorization_endpoint="https://example.com/auth",
token_endpoint="https://example.com/token", # Required field
scopes=["openid", "profile"],
openIdConnect={
"authorizationUrl": "https://example.com/auth",
# Intentionally omit tokenUrl to test missing URL
},
)

with pytest.raises(ValueError) as exc_info:
oauth2_exchanger._refresh_token(auth_credential_with_refresh, invalid_scheme)
assert "No token URL available" in str(exc_info.value)

def test_refresh_token_request_failure(oauth2_exchanger, auth_scheme, auth_credential_with_refresh):
"""Test token refresh with request failure."""
with patch('authlib.oauth2.client.OAuth2Client.refresh_token', side_effect=Exception("Network error")):
with pytest.raises(ValueError) as exc_info:
oauth2_exchanger._refresh_token(auth_credential_with_refresh, auth_scheme)
assert "Token refresh failed" in str(exc_info.value)

def test_exchange_credential_with_refresh(oauth2_exchanger, auth_scheme, auth_credential_with_refresh):
"""Test exchange_credential with token refresh."""
mock_response = {
"access_token": "new_access_token",
"expires_in": 3600,
"token_type": "Bearer"
}

with patch('authlib.oauth2.client.OAuth2Client.refresh_token', return_value=mock_response):
updated_credential = oauth2_exchanger.exchange_credential(
auth_scheme,
auth_credential_with_refresh
)

# Verify the updated credential
assert updated_credential.auth_type == AuthCredentialTypes.HTTP
assert updated_credential.http.scheme == "bearer"
assert updated_credential.http.credentials.token == "new_access_token"

def test_exchange_credential_refresh_failure_fallback(oauth2_exchanger, auth_scheme, auth_credential_with_refresh):
"""Test exchange_credential fallback when refresh fails."""
with patch('authlib.oauth2.client.OAuth2Client.refresh_token', side_effect=Exception("Network error")):
# Should fall back to using existing token
updated_credential = oauth2_exchanger.exchange_credential(
auth_scheme,
auth_credential_with_refresh
)

assert updated_credential.auth_type == AuthCredentialTypes.HTTP
assert updated_credential.http.scheme == "bearer"
assert updated_credential.http.credentials.token == "old_access_token"