Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions google/cloud/storage/_experimental/asyncio/async_creds.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""Async Wrapper around Google Auth Credentials"""

import asyncio
from google.auth.transport.requests import Request

try:
from google.auth.aio import credentials as aio_creds_module
BaseCredentials = aio_creds_module.Credentials
_AIO_AVAILABLE = True
except ImportError:
BaseCredentials = object
_AIO_AVAILABLE = False

class AsyncCredsWrapper(BaseCredentials):
"""Wraps synchronous Google Auth credentials to provide an asynchronous interface.

Args:
sync_creds (google.auth.credentials.Credentials): The synchronous credentials
instance to wrap.

Raises:
ImportError: If instantiated in an environment where 'google.auth.aio'
is not available.
"""

def __init__(self, sync_creds):
if not _AIO_AVAILABLE:
raise ImportError(
"Failed to import 'google.auth.aio'. This module requires a newer version "
"of 'google-auth' which supports asyncio."
)

super().__init__()
self.creds = sync_creds

async def refresh(self, request):
"""Refreshes the access token."""
loop = asyncio.get_running_loop()
await loop.run_in_executor(
None, self.creds.refresh, Request()
)

@property
def valid(self):
"""Checks the validity of the credentials."""
return self.creds.valid

async def before_request(self, request, method, url, headers):
"""Performs credential-specific before request logic."""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: sentence's meaning is not coming out clearly.

if self.valid:
self.creds.apply(headers)
return

loop = asyncio.get_running_loop()
await loop.run_in_executor(
None, self.creds.before_request, Request(), method, url, headers
)
108 changes: 108 additions & 0 deletions tests/unit/asyncio/test_async_creds.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import sys
import unittest.mock
import pytest
from google.auth import credentials as google_creds
from google.cloud.storage._experimental.asyncio import async_creds

@pytest.fixture
def mock_aio_modules():
"""Patches sys.modules to simulate google.auth.aio existence."""
mock_creds_module = unittest.mock.MagicMock()
# We must set the base class to object so our wrapper can inherit safely in tests
mock_creds_module.Credentials = object

modules = {
'google.auth.aio': unittest.mock.MagicMock(),
'google.auth.aio.credentials': mock_creds_module,
}

with unittest.mock.patch.dict(sys.modules, modules):
# We also need to manually flip the flag in the module to True for the test context
# because the module was likely already imported with the flag set to False/True
# depending on the real environment.
with unittest.mock.patch.object(async_creds, '_AIO_AVAILABLE', True):
# We also need to ensure BaseCredentials in the module points to our mock
# if we want strictly correct inheritance, though duck typing usually suffices.
with unittest.mock.patch.object(async_creds, 'BaseCredentials', object):
yield

@pytest.fixture
def mock_sync_creds():
"""Creates a mock of the synchronous Google Credentials object."""
creds = unittest.mock.create_autospec(google_creds.Credentials, instance=True)
type(creds).valid = unittest.mock.PropertyMock(return_value=True)
return creds

@pytest.fixture
def async_wrapper(mock_aio_modules, mock_sync_creds):
"""Instantiates the wrapper with the mock credentials."""
# This instantiation would raise ImportError if mock_aio_modules didn't set _AIO_AVAILABLE=True
return async_creds.AsyncCredsWrapper(mock_sync_creds)

class TestAsyncCredsWrapper:

@pytest.mark.asyncio
async def test_init_sets_attributes(self, async_wrapper, mock_sync_creds):
"""Test that the wrapper initializes correctly."""
assert async_wrapper.creds == mock_sync_creds

@pytest.mark.asyncio
async def test_valid_property_delegates(self, async_wrapper, mock_sync_creds):
"""Test that the .valid property maps to the sync creds .valid property."""
type(mock_sync_creds).valid = unittest.mock.PropertyMock(return_value=True)
assert async_wrapper.valid is True

type(mock_sync_creds).valid = unittest.mock.PropertyMock(return_value=False)
assert async_wrapper.valid is False

@pytest.mark.asyncio
async def test_refresh_offloads_to_executor(self, async_wrapper, mock_sync_creds):
"""Test that refresh() gets the running loop and calls sync refresh in executor."""
with unittest.mock.patch('asyncio.get_running_loop') as mock_get_loop:
mock_loop = unittest.mock.AsyncMock()
mock_get_loop.return_value = mock_loop

await async_wrapper.refresh(None)

mock_loop.run_in_executor.assert_called_once()
args, _ = mock_loop.run_in_executor.call_args
assert args[1] == mock_sync_creds.refresh

@pytest.mark.asyncio
async def test_before_request_valid_creds(self, async_wrapper, mock_sync_creds):
"""Test before_request when credentials are ALREADY valid."""
type(mock_sync_creds).valid = unittest.mock.PropertyMock(return_value=True)

headers = {}
await async_wrapper.before_request(None, "GET", "http://example.com", headers)

mock_sync_creds.apply.assert_called_once_with(headers)
mock_sync_creds.before_request.assert_not_called()

@pytest.mark.asyncio
async def test_before_request_invalid_creds(self, async_wrapper, mock_sync_creds):
"""Test before_request when credentials are INVALID (refresh path)."""
type(mock_sync_creds).valid = unittest.mock.PropertyMock(return_value=False)

headers = {}
method = "GET"
url = "http://example.com"

with unittest.mock.patch('asyncio.get_running_loop') as mock_get_loop:
mock_loop = unittest.mock.AsyncMock()
mock_get_loop.return_value = mock_loop

await async_wrapper.before_request(None, method, url, headers)

mock_loop.run_in_executor.assert_called_once()
args, _ = mock_loop.run_in_executor.call_args
assert args[1] == mock_sync_creds.before_request

def test_missing_aio_raises_error(self, mock_sync_creds):
"""Ensure ImportError is raised if _AIO_AVAILABLE is False."""
# We manually simulate the environment where AIO is missing
with unittest.mock.patch.object(async_creds, '_AIO_AVAILABLE', False):
with pytest.raises(ImportError) as excinfo:
async_creds.AsyncCredsWrapper(mock_sync_creds)

assert "Failed to import 'google.auth.aio'" in str(excinfo.value)