Skip to content

Commit 8d2b219

Browse files
committed
feat(storage): add async credential wrapper
1 parent 4e91c54 commit 8d2b219

File tree

2 files changed

+132
-0
lines changed

2 files changed

+132
-0
lines changed
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
"""Async Wrapper around Google Auth Credentials"""
2+
3+
import asyncio
4+
from google.auth.aio import credentials as async_creds
5+
from google.auth.transport.requests import Request
6+
7+
class AsyncCredsWrapper(async_creds.Credentials):
8+
"""Wraps synchronous Google Auth credentials to provide an asynchronous interface.
9+
10+
This class adapts standard synchronous `google.auth.credentials.Credentials` for use
11+
in asynchronous contexts. It offloads blocking operations, such as token refreshes,
12+
to a separate thread using `asyncio.loop.run_in_executor`.
13+
14+
Args:
15+
sync_creds (google.auth.credentials.Credentials): The synchronous credentials
16+
instance to wrap.
17+
"""
18+
19+
def __init__(self, sync_creds):
20+
super().__init__()
21+
self.creds = sync_creds
22+
23+
async def refresh(self, _request):
24+
"""Refreshes the access token."""
25+
loop = asyncio.get_running_loop()
26+
await loop.run_in_executor(
27+
None, self.creds.refresh, Request()
28+
)
29+
30+
@property
31+
def valid(self):
32+
"""Checks the validity of the credentials."""
33+
return self.creds.valid
34+
35+
async def before_request(self, _request, method, url, headers):
36+
"""Performs credential-specific before request logic."""
37+
if self.valid:
38+
self.creds.apply(headers)
39+
return
40+
41+
loop = asyncio.get_running_loop()
42+
await loop.run_in_executor(
43+
None, self.creds.before_request, Request(), method, url, headers
44+
)
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import unittest.mock
2+
import pytest
3+
from google.auth import credentials as google_creds
4+
from google.auth.transport.requests import Request
5+
from google.cloud.storage._experimental.asyncio.async_creds import AsyncCredsWrapper
6+
7+
8+
@pytest.fixture
9+
def mock_sync_creds():
10+
"""Creates a mock of the synchronous Google Credentials object."""
11+
creds = unittest.mock.create_autospec(google_creds.Credentials, instance=True)
12+
type(creds).valid = unittest.mock.PropertyMock(return_value=True)
13+
return creds
14+
15+
@pytest.fixture
16+
def async_wrapper(mock_sync_creds):
17+
"""Instantiates the wrapper with the mock credentials."""
18+
return AsyncCredsWrapper(mock_sync_creds)
19+
20+
21+
class TestAsyncCredsWrapper:
22+
23+
@pytest.mark.asyncio
24+
async def test_init_sets_attributes(self, async_wrapper, mock_sync_creds):
25+
"""Test that the wrapper initializes correctly."""
26+
assert async_wrapper.creds == mock_sync_creds
27+
28+
@pytest.mark.asyncio
29+
async def test_valid_property_delegates(self, async_wrapper, mock_sync_creds):
30+
"""Test that the .valid property maps to the sync creds .valid property."""
31+
type(mock_sync_creds).valid = unittest.mock.PropertyMock(return_value=True)
32+
assert async_wrapper.valid is True
33+
34+
type(mock_sync_creds).valid = unittest.mock.PropertyMock(return_value=False)
35+
assert async_wrapper.valid is False
36+
37+
@pytest.mark.asyncio
38+
async def test_refresh_offloads_to_executor(self, async_wrapper, mock_sync_creds):
39+
"""Test that refresh() gets the running loop and calls sync refresh in executor."""
40+
with unittest.mock.patch('asyncio.get_running_loop') as mock_get_loop:
41+
mock_loop = unittest.mock.AsyncMock()
42+
mock_get_loop.return_value = mock_loop
43+
44+
await async_wrapper.refresh(None)
45+
46+
mock_loop.run_in_executor.assert_called_once()
47+
48+
args, _ = mock_loop.run_in_executor.call_args
49+
assert args[1] == mock_sync_creds.refresh
50+
assert isinstance(args[2], Request)
51+
52+
@pytest.mark.asyncio
53+
async def test_before_request_valid_creds(self, async_wrapper, mock_sync_creds):
54+
"""Test before_request when credentials are ALREADY valid (fast path)."""
55+
type(mock_sync_creds).valid = unittest.mock.PropertyMock(return_value=True)
56+
57+
headers = {}
58+
await async_wrapper.before_request(None, "GET", "http://example.com", headers)
59+
60+
# Should call apply() directly on sync creds
61+
mock_sync_creds.apply.assert_called_once_with(headers)
62+
63+
# Should NOT call before_request on sync creds
64+
mock_sync_creds.before_request.assert_not_called()
65+
66+
@pytest.mark.asyncio
67+
async def test_before_request_invalid_creds(self, async_wrapper, mock_sync_creds):
68+
"""Test before_request when credentials are INVALID (refresh path)."""
69+
type(mock_sync_creds).valid = unittest.mock.PropertyMock(return_value=False)
70+
71+
headers = {}
72+
method = "GET"
73+
url = "http://example.com"
74+
75+
with unittest.mock.patch('asyncio.get_running_loop') as mock_get_loop:
76+
mock_loop = unittest.mock.AsyncMock()
77+
mock_get_loop.return_value = mock_loop
78+
79+
await async_wrapper.before_request(None, method, url, headers)
80+
81+
mock_loop.run_in_executor.assert_called_once()
82+
83+
args, _ = mock_loop.run_in_executor.call_args
84+
assert args[1] == mock_sync_creds.before_request
85+
assert isinstance(args[2], Request)
86+
assert args[3] == method
87+
assert args[4] == url
88+
assert args[5] == headers

0 commit comments

Comments
 (0)