Skip to content

Commit d4ba521

Browse files
google-genai-botcopybara-github
authored andcommitted
refactor: Add diversion logic based on the auth provider resource name
PiperOrigin-RevId: 931487487
1 parent dd97e76 commit d4ba521

4 files changed

Lines changed: 162 additions & 26 deletions

File tree

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Credentials Provider using the Agent Identity service."""
16+
17+
from __future__ import annotations
18+
19+
from google.adk.agents.callback_context import CallbackContext
20+
from google.adk.auth.auth_credential import AuthCredential
21+
22+
from .gcp_auth_provider_scheme import GcpAuthProviderScheme
23+
24+
25+
class _AgentIdentityCredentialsProvider:
26+
"""Auth provider implementation using Agent Identity credentials service."""
27+
28+
async def get_auth_credential(
29+
self,
30+
auth_scheme: GcpAuthProviderScheme,
31+
context: CallbackContext | None = None,
32+
) -> AuthCredential:
33+
"""Retrieves credentials using the Agent Identity Credentials service.
34+
35+
Args:
36+
auth_scheme: The GcpAuthProviderScheme.
37+
context: Optional context for the callback.
38+
39+
Returns:
40+
An AuthCredential instance.
41+
42+
Raises:
43+
NotImplementedError: Auth provider using Agent Identity Credential service
44+
is not yet supported.
45+
"""
46+
raise NotImplementedError(
47+
"Auth provider using Agent Identity Credential service is not yet"
48+
" supported."
49+
)

src/google/adk/integrations/agent_identity/gcp_auth_provider.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,26 @@
1616

1717
from __future__ import annotations
1818

19+
20+
import re
21+
1922
from google.adk.agents.callback_context import CallbackContext
2023
from google.adk.auth.auth_credential import AuthCredential
2124
from google.adk.auth.auth_tool import AuthConfig
2225
from google.adk.auth.base_auth_provider import BaseAuthProvider
2326
from typing_extensions import override
2427

28+
from ._agent_identity_credentials_provider import _AgentIdentityCredentialsProvider
2529
from ._iam_connector_credentials_provider import _IamConnectorCredentialsProvider
2630
from .gcp_auth_provider_scheme import GcpAuthProviderScheme
2731

2832

2933
class GcpAuthProvider(BaseAuthProvider):
30-
"""An auth provider that uses the Agent Identity Credentials service to generate access tokens."""
34+
"""An auth provider that uses Credentials service to generate access tokens."""
3135

3236
def __init__(self):
3337
self._iam_connector_provider = _IamConnectorCredentialsProvider()
38+
self._agent_identity_provider = _AgentIdentityCredentialsProvider()
3439

3540
@property
3641
@override
@@ -43,7 +48,7 @@ async def get_auth_credential(
4348
auth_config: AuthConfig,
4449
context: CallbackContext | None = None,
4550
) -> AuthCredential:
46-
"""Retrieves credentials using the Agent Identity Credentials service.
51+
"""Retrieves credentials using the Credentials service.
4752
4853
Args:
4954
auth_config: The authentication configuration.
@@ -61,6 +66,13 @@ async def get_auth_credential(
6166
f"Expected GcpAuthProviderScheme, got {type(auth_scheme)}"
6267
)
6368

64-
return await self._iam_connector_provider.get_auth_credential(
69+
if re.match(
70+
r"^projects/[^/]+/locations/[^/]+/connectors/[^/]+$", auth_scheme.name
71+
):
72+
return await self._iam_connector_provider.get_auth_credential(
73+
auth_scheme=auth_scheme, context=context
74+
)
75+
76+
return await self._agent_identity_provider.get_auth_credential(
6577
auth_scheme=auth_scheme, context=context
6678
)
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from unittest.mock import Mock
16+
17+
from google.adk.agents.callback_context import CallbackContext
18+
from google.adk.integrations.agent_identity import GcpAuthProviderScheme
19+
from google.adk.integrations.agent_identity._agent_identity_credentials_provider import _AgentIdentityCredentialsProvider
20+
import pytest
21+
22+
23+
@pytest.fixture
24+
def auth_scheme():
25+
scheme = GcpAuthProviderScheme(
26+
name="projects/test-project/locations/global/connectors/test-connector",
27+
scopes=["test-scope"],
28+
continue_uri="https://example.com/continue",
29+
)
30+
return scheme
31+
32+
33+
@pytest.fixture
34+
def context():
35+
context = Mock(spec=CallbackContext)
36+
context.user_id = "user"
37+
return context
38+
39+
40+
async def test_get_auth_credential_not_implemented(auth_scheme, context):
41+
"""Verify that get_auth_credential raises NotImplementedError initially."""
42+
provider = _AgentIdentityCredentialsProvider()
43+
with pytest.raises(
44+
NotImplementedError,
45+
match=(
46+
"Auth provider using Agent Identity Credential service is not yet"
47+
" supported."
48+
),
49+
):
50+
await provider.get_auth_credential(auth_scheme, context)

tests/unittests/integrations/agent_identity/test_gcp_auth_provider.py

Lines changed: 48 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -11,28 +11,26 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
"""Unit tests for the GcpAuthProvider class."""
1415

1516
from unittest.mock import AsyncMock
1617
from unittest.mock import Mock
1718
from unittest.mock import patch
1819

20+
1921
from google.adk.agents.callback_context import CallbackContext
2022
from google.adk.auth.auth_credential import AuthCredential
2123
from google.adk.auth.auth_tool import AuthConfig
2224
from google.adk.integrations.agent_identity import GcpAuthProvider
2325
from google.adk.integrations.agent_identity import GcpAuthProviderScheme
24-
from google.adk.integrations.agent_identity._iam_connector_credentials_provider import _IamConnectorCredentialsProvider
2526
import pytest
2627

2728

2829
@pytest.fixture
2930
def auth_config():
30-
scheme = GcpAuthProviderScheme(
31-
name="projects/test-project/locations/global/connectors/test-connector",
32-
scopes=["test-scope"],
33-
continue_uri="https://example.com/continue",
34-
)
35-
return Mock(spec=AuthConfig, auth_scheme=scheme)
31+
config = Mock(spec=AuthConfig)
32+
config.auth_scheme = Mock(spec=GcpAuthProviderScheme)
33+
return config
3634

3735

3836
@pytest.fixture
@@ -43,45 +41,72 @@ def context():
4341

4442

4543
@pytest.fixture
46-
def provider():
44+
def gcp_auth_provider():
4745
return GcpAuthProvider()
4846

4947

50-
def test_supported_auth_schemes(provider):
48+
def test_supported_auth_schemes(gcp_auth_provider):
5149
"""Verify the provider supports the correct auth scheme."""
52-
assert GcpAuthProviderScheme in provider.supported_auth_schemes
50+
assert GcpAuthProviderScheme in gcp_auth_provider.supported_auth_schemes
51+
52+
53+
async def test_get_auth_credential_raises_error_for_invalid_auth_scheme(
54+
context,
55+
):
56+
"""Test get_auth_credential raises ValueError for invalid auth scheme."""
57+
provider = GcpAuthProvider()
58+
invalid_auth_config = Mock(spec=AuthConfig)
59+
invalid_auth_config.auth_scheme = Mock() # Not GcpAuthProviderScheme
60+
61+
with pytest.raises(ValueError, match="Expected GcpAuthProviderScheme, got"):
62+
await provider.get_auth_credential(invalid_auth_config, context)
5363

5464

5565
@patch(
5666
"google.adk.integrations.agent_identity.gcp_auth_provider._IamConnectorCredentialsProvider"
5767
)
58-
async def test_gcp_auth_provider_delegates_get_auth_credential(
59-
mock_provider_class, auth_config, context
68+
async def test_get_auth_credential_routes_to_iam_connector_service_provider(
69+
mock_iam_cls, auth_config, context
6070
):
61-
"""Test that get_auth_credential delegates to the internal provider."""
71+
"""Test routing to IAM Connector Credentials service for legacy auth provider resource names."""
72+
auth_config.auth_scheme.name = (
73+
"projects/test-project/locations/test-location/connectors/test-connector"
74+
)
6275
provider = GcpAuthProvider()
6376

6477
mock_credential = Mock(spec=AuthCredential)
65-
mock_provider_instance = mock_provider_class.return_value
66-
mock_provider_instance.get_auth_credential = AsyncMock(
78+
mock_iam_provider = mock_iam_cls.return_value
79+
mock_iam_provider.get_auth_credential = AsyncMock(
6780
return_value=mock_credential
6881
)
6982

7083
result = await provider.get_auth_credential(auth_config, context)
7184

7285
assert result == mock_credential
73-
mock_provider_instance.get_auth_credential.assert_awaited_once_with(
86+
mock_iam_provider.get_auth_credential.assert_awaited_once_with(
7487
auth_scheme=auth_config.auth_scheme, context=context
7588
)
7689

7790

78-
async def test_get_auth_credential_raises_error_for_invalid_auth_scheme(
79-
context,
91+
@patch(
92+
"google.adk.integrations.agent_identity.gcp_auth_provider._AgentIdentityCredentialsProvider"
93+
)
94+
async def test_get_auth_credential_routes_to_agent_identity_service_provider(
95+
mock_agent_cls, auth_config, context
8096
):
81-
"""Test get_auth_credential raises ValueError for invalid auth scheme."""
97+
"""Test routing to Agent Identity Credentials service for new auth provider resource names."""
98+
auth_config.auth_scheme.name = "projects/test-project/locations/test-location/authProviders/test-provider"
8299
provider = GcpAuthProvider()
83-
invalid_auth_config = Mock(spec=AuthConfig)
84-
invalid_auth_config.auth_scheme = Mock() # Not GcpAuthProviderScheme
85100

86-
with pytest.raises(ValueError, match="Expected GcpAuthProviderScheme, got"):
87-
await provider.get_auth_credential(invalid_auth_config, context)
101+
mock_credential = Mock(spec=AuthCredential)
102+
mock_agent_provider = mock_agent_cls.return_value
103+
mock_agent_provider.get_auth_credential = AsyncMock(
104+
return_value=mock_credential
105+
)
106+
107+
result = await provider.get_auth_credential(auth_config, context)
108+
109+
assert result == mock_credential
110+
mock_agent_provider.get_auth_credential.assert_awaited_once_with(
111+
auth_scheme=auth_config.auth_scheme, context=context
112+
)

0 commit comments

Comments
 (0)