diff --git a/src/google/adk/tools/openapi_tool/auth/credential_exchangers/oauth2_exchanger.py b/src/google/adk/tools/openapi_tool/auth/credential_exchangers/oauth2_exchanger.py index dafa4c29a..ee204d1bf 100644 --- a/src/google/adk/tools/openapi_tool/auth/credential_exchangers/oauth2_exchanger.py +++ b/src/google/adk/tools/openapi_tool/auth/credential_exchangers/oauth2_exchanger.py @@ -14,7 +14,9 @@ """Credential fetcher for OpenID Connect.""" -from typing import Optional +from typing import Optional, Dict +from authlib.integrations.requests_client import OAuth2Session +import logging from .....auth.auth_credential import AuthCredential from .....auth.auth_credential import AuthCredentialTypes @@ -24,6 +26,7 @@ from .....auth.auth_schemes import AuthSchemeType from .base_credential_exchanger import BaseAuthCredentialExchanger +logger = logging.getLogger(__name__) class OAuth2CredentialExchanger(BaseAuthCredentialExchanger): """Fetches credentials for OAuth2 and OpenID Connect.""" @@ -84,6 +87,51 @@ def generate_auth_token( ) return updated_credential + def _refresh_token( + self, + auth_credential: AuthCredential, + auth_scheme: AuthScheme, + ) -> Dict[str, str]: + """Refreshes the OAuth2 access token using the refresh token. + + Args: + auth_credential: The auth credential containing the refresh token. + auth_scheme: The auth scheme containing OAuth2 configuration. + + Returns: + A dictionary containing the new access token and related information. + + Raises: + ValueError: If refresh token is missing or refresh fails. + """ + if not auth_credential.oauth2.token or "refresh_token" not in auth_credential.oauth2.token: + raise ValueError("No refresh token available for token refresh") + + # Get token URL from either OpenID Connect or OAuth2 configuration + token_url = None + if auth_scheme.type_ == AuthSchemeType.openIdConnect and auth_scheme.openIdConnect: + token_url = auth_scheme.openIdConnect.get("tokenUrl") + elif auth_scheme.type_ == AuthSchemeType.oauth2 and auth_scheme.oauth2: + token_url = auth_scheme.oauth2.tokenUrl + + if not token_url: + raise ValueError("No token URL available for token refresh") + + try: + client = OAuth2Session( + auth_credential.oauth2.client_id, + auth_credential.oauth2.client_secret, + token=auth_credential.oauth2.token, + ) + new_token = client.refresh_token( + token_url, + refresh_token=auth_credential.oauth2.token["refresh_token"], + ) + return new_token + except Exception as e: + logger.error("Failed to refresh token: %s", str(e)) + raise ValueError(f"Token refresh failed: {str(e)}") + def exchange_credential( self, auth_scheme: AuthScheme, @@ -101,17 +149,29 @@ def exchange_credential( Raises: ValueError: If the auth scheme or auth credential is invalid. """ - # TODO(cheliu): Implement token refresh flow - self._check_scheme_credential_type(auth_scheme, auth_credential) - # If token is already HTTPBearer token, do nothing assuming that this token - # is valid. + # If token is already HTTPBearer token, try to refresh if needed if auth_credential.http: - return auth_credential - - # If access token is exchanged, exchange a HTTPBearer token. - if auth_credential.oauth2.access_token: + try: + # Attempt to use the current token + return auth_credential + except Exception as e: + logger.info("Token may be expired, attempting refresh: %s", str(e)) + # Continue to refresh flow + + # Try to refresh the token if we have a refresh token + if auth_credential.oauth2.token and "refresh_token" in auth_credential.oauth2.token: + try: + new_token = self._refresh_token(auth_credential, auth_scheme) + auth_credential.oauth2.token = new_token + return self.generate_auth_token(auth_credential) + except ValueError as e: + logger.error("Token refresh failed: %s", str(e)) + # Fall through to try other methods + + # If access token is available, exchange for HTTPBearer token + if auth_credential.oauth2.token: return self.generate_auth_token(auth_credential) return None diff --git a/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_oauth2_exchanger.py b/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_oauth2_exchanger.py index 5b59fae3b..85280b0e6 100644 --- a/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_oauth2_exchanger.py +++ b/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_oauth2_exchanger.py @@ -15,7 +15,7 @@ """Tests for OAuth2CredentialExchanger.""" import copy -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch from google.adk.auth.auth_credential import AuthCredential from google.adk.auth.auth_credential import AuthCredentialTypes @@ -34,16 +34,38 @@ def oauth2_exchanger(): @pytest.fixture def auth_scheme(): + """Create an OpenID Connect auth scheme for testing.""" openid_config = OpenIdConnectWithConfig( type_=AuthSchemeType.openIdConnect, authorization_endpoint="https://example.com/auth", token_endpoint="https://example.com/token", scopes=["openid", "profile"], + openIdConnect={ + "tokenUrl": "https://example.com/token", + "authorizationUrl": "https://example.com/auth", + }, ) return openid_config +@pytest.fixture +def auth_credential_with_refresh(): + """Create an auth credential with refresh token for testing.""" + return AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="test_client", + client_secret="test_secret", + redirect_uri="http://localhost:8080", + token={ + "access_token": "old_access_token", + "refresh_token": "test_refresh_token", + "expires_in": 3600 + }, + ), + ) def test_check_scheme_credential_type_success(oauth2_exchanger, auth_scheme): + """Test successful scheme and credential type check.""" auth_credential = AuthCredential( auth_type=AuthCredentialTypes.OAUTH2, oauth2=OAuth2Auth( @@ -59,17 +81,15 @@ def test_check_scheme_credential_type_success(oauth2_exchanger, auth_scheme): def test_check_scheme_credential_type_missing_credential( oauth2_exchanger, auth_scheme ): - # Test case: auth_credential is None + """Test case: auth_credential is None.""" with pytest.raises(ValueError) as exc_info: oauth2_exchanger._check_scheme_credential_type(auth_scheme, None) assert "auth_credential is empty" in str(exc_info.value) - def test_check_scheme_credential_type_invalid_scheme_type( oauth2_exchanger, auth_scheme: OpenIdConnectWithConfig ): """Test case: Invalid AuthSchemeType.""" - # Test case: Invalid AuthSchemeType invalid_scheme = copy.deepcopy(auth_scheme) invalid_scheme.type_ = AuthSchemeType.apiKey auth_credential = AuthCredential( @@ -86,10 +106,10 @@ def test_check_scheme_credential_type_invalid_scheme_type( ) assert "Invalid security scheme" in str(exc_info.value) - def test_check_scheme_credential_type_missing_openid_connect( oauth2_exchanger, auth_scheme ): + """Test case: Missing OpenID Connect configuration.""" auth_credential = AuthCredential( auth_type=AuthCredentialTypes.OAUTH2, ) @@ -97,12 +117,10 @@ def test_check_scheme_credential_type_missing_openid_connect( oauth2_exchanger._check_scheme_credential_type(auth_scheme, auth_credential) assert "auth_credential is not configured with oauth2" in str(exc_info.value) - def test_generate_auth_token_success( oauth2_exchanger, auth_scheme, monkeypatch ): """Test case: Successful generation of access token.""" - # Test case: Successful generation of access token auth_credential = AuthCredential( auth_type=AuthCredentialTypes.OAUTH2, oauth2=OAuth2Auth( @@ -119,7 +137,6 @@ def test_generate_auth_token_success( assert updated_credential.http.scheme == "bearer" assert updated_credential.http.credentials.token == "test_access_token" - def test_exchange_credential_generate_auth_token( oauth2_exchanger, auth_scheme, monkeypatch ): @@ -143,7 +160,6 @@ def test_exchange_credential_generate_auth_token( assert updated_credential.http.scheme == "bearer" assert updated_credential.http.credentials.token == "test_access_token" - def test_exchange_credential_auth_missing(oauth2_exchanger, auth_scheme): """Test exchange_credential when auth_credential is missing.""" with pytest.raises(ValueError) as exc_info: @@ -151,3 +167,76 @@ def test_exchange_credential_auth_missing(oauth2_exchanger, auth_scheme): assert "auth_credential is empty. Please create AuthCredential using" in str( exc_info.value ) + +def test_refresh_token_success(oauth2_exchanger, auth_scheme, auth_credential_with_refresh): + """Test successful token refresh.""" + mock_response = { + "access_token": "new_access_token", + "expires_in": 3600, + "token_type": "Bearer" + } + + with patch('authlib.oauth2.client.OAuth2Client.refresh_token', return_value=mock_response): + new_token = oauth2_exchanger._refresh_token( + auth_credential_with_refresh, + auth_scheme + ) + + # Verify the response + assert new_token == mock_response + +def test_refresh_token_missing_token_url(oauth2_exchanger, auth_credential_with_refresh): + """Test token refresh with missing token URL.""" + invalid_scheme = OpenIdConnectWithConfig( + type_=AuthSchemeType.openIdConnect, + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", # Required field + scopes=["openid", "profile"], + openIdConnect={ + "authorizationUrl": "https://example.com/auth", + # Intentionally omit tokenUrl to test missing URL + }, + ) + + with pytest.raises(ValueError) as exc_info: + oauth2_exchanger._refresh_token(auth_credential_with_refresh, invalid_scheme) + assert "No token URL available" in str(exc_info.value) + +def test_refresh_token_request_failure(oauth2_exchanger, auth_scheme, auth_credential_with_refresh): + """Test token refresh with request failure.""" + with patch('authlib.oauth2.client.OAuth2Client.refresh_token', side_effect=Exception("Network error")): + with pytest.raises(ValueError) as exc_info: + oauth2_exchanger._refresh_token(auth_credential_with_refresh, auth_scheme) + assert "Token refresh failed" in str(exc_info.value) + +def test_exchange_credential_with_refresh(oauth2_exchanger, auth_scheme, auth_credential_with_refresh): + """Test exchange_credential with token refresh.""" + mock_response = { + "access_token": "new_access_token", + "expires_in": 3600, + "token_type": "Bearer" + } + + with patch('authlib.oauth2.client.OAuth2Client.refresh_token', return_value=mock_response): + updated_credential = oauth2_exchanger.exchange_credential( + auth_scheme, + auth_credential_with_refresh + ) + + # Verify the updated credential + assert updated_credential.auth_type == AuthCredentialTypes.HTTP + assert updated_credential.http.scheme == "bearer" + assert updated_credential.http.credentials.token == "new_access_token" + +def test_exchange_credential_refresh_failure_fallback(oauth2_exchanger, auth_scheme, auth_credential_with_refresh): + """Test exchange_credential fallback when refresh fails.""" + with patch('authlib.oauth2.client.OAuth2Client.refresh_token', side_effect=Exception("Network error")): + # Should fall back to using existing token + updated_credential = oauth2_exchanger.exchange_credential( + auth_scheme, + auth_credential_with_refresh + ) + + assert updated_credential.auth_type == AuthCredentialTypes.HTTP + assert updated_credential.http.scheme == "bearer" + assert updated_credential.http.credentials.token == "old_access_token"