Skip to content

Commit 7f4b4c7

Browse files
yinghsienwucopybara-github
authored andcommitted
No public description
PiperOrigin-RevId: 868902411
1 parent af94e53 commit 7f4b4c7

File tree

6 files changed

+219
-137
lines changed

6 files changed

+219
-137
lines changed

google/genai/_api_client.py

Lines changed: 149 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 Google LLC
1+
# Copyright 2026 Google LLC
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -30,6 +30,7 @@
3030
import math
3131
import os
3232
import random
33+
import requests
3334
import ssl
3435
import sys
3536
import threading
@@ -47,6 +48,7 @@
4748
import httpx
4849
from pydantic import BaseModel
4950
from pydantic import ValidationError
51+
from requests.structures import CaseInsensitiveDict
5052
import tenacity
5153

5254
from . import _common
@@ -59,6 +61,15 @@
5961
from .types import ResourceScope
6062

6163

64+
try:
65+
from google.auth.transport.requests import AuthorizedSession
66+
from google.auth.aio.credentials import StaticCredentials
67+
from google.auth.aio.transport.sessions import AsyncAuthorizedSession
68+
except ImportError:
69+
# This try/except is for TAP
70+
StaticCredentials = None
71+
AsyncAuthorizedSession = None
72+
6273
try:
6374
from websockets.asyncio.client import connect as ws_connect
6475
except ModuleNotFoundError:
@@ -235,7 +246,12 @@ class HttpResponse:
235246

236247
def __init__(
237248
self,
238-
headers: Union[dict[str, str], httpx.Headers, 'CIMultiDictProxy[str]'],
249+
headers: Union[
250+
dict[str, str],
251+
httpx.Headers,
252+
'CIMultiDictProxy[str]',
253+
CaseInsensitiveDict,
254+
],
239255
response_stream: Union[Any, str] = None,
240256
byte_stream: Union[Any, bytes] = None,
241257
):
@@ -245,6 +261,10 @@ def __init__(
245261
self.headers = {
246262
key: ', '.join(headers.get_list(key)) for key in headers.keys()
247263
}
264+
elif isinstance(headers, CaseInsensitiveDict):
265+
self.headers = {
266+
key: value for key,value in headers.items()
267+
}
248268
elif type(headers).__name__ == 'CIMultiDictProxy':
249269
self.headers = {
250270
key: ', '.join(headers.getall(key)) for key in headers.keys()
@@ -321,15 +341,22 @@ def _copy_to_dict(self, response_payload: dict[str, object]) -> None:
321341

322342
def _iter_response_stream(self) -> Iterator[str]:
323343
"""Iterates over chunks retrieved from the API."""
324-
if not isinstance(self.response_stream, httpx.Response):
344+
if not (
345+
isinstance(self.response_stream, httpx.Response)
346+
or isinstance(self.response_stream, requests.Response)
347+
):
325348
raise TypeError(
326349
'Expected self.response_stream to be an httpx.Response object, '
327350
f'but got {type(self.response_stream).__name__}.'
328351
)
329352

330353
chunk = ''
331354
balance = 0
332-
for line in self.response_stream.iter_lines():
355+
if isinstance(self.response_stream, httpx.Response):
356+
response_stream = self.response_stream.iter_lines()
357+
else:
358+
response_stream = self.response_stream.iter_lines(decode_unicode=True)
359+
for line in response_stream:
333360
if not line:
334361
continue
335362

@@ -729,8 +756,11 @@ def __init__(
729756
self._http_options
730757
)
731758
self._async_httpx_client_args = async_client_args
759+
self.authorized_session: Optional[AuthorizedSession] = None
732760

733-
if self._http_options.httpx_client:
761+
if self._use_google_auth_sync():
762+
self._httpx_client = None
763+
elif self._http_options.httpx_client:
734764
self._httpx_client = self._http_options.httpx_client
735765
else:
736766
self._httpx_client = SyncHttpxClient(**client_args)
@@ -747,6 +777,7 @@ def __init__(
747777

748778
if self._http_options.aiohttp_client:
749779
self._aiohttp_session = self._http_options.aiohttp_client
780+
self._async_client_session_request_args = {}
750781
else:
751782
# Do it once at the genai.Client level. Share among all requests.
752783
self._async_client_session_request_args = (
@@ -760,13 +791,36 @@ def __init__(
760791
self._retry = tenacity.Retrying(**retry_kwargs)
761792
self._async_retry = tenacity.AsyncRetrying(**retry_kwargs)
762793

763-
async def _get_aiohttp_session(self) -> 'aiohttp.ClientSession':
794+
def _use_google_auth_sync(self) -> bool:
795+
return self.vertexai and not (
796+
self._http_options.httpx_client or self._http_options.client_args
797+
)
798+
799+
def _use_google_auth_async(self) -> bool:
800+
return (
801+
StaticCredentials
802+
and AsyncAuthorizedSession
803+
and self.vertexai
804+
and not (
805+
self._http_options.aiohttp_client
806+
or self._http_options.async_client_args
807+
)
808+
)
809+
810+
async def _get_aiohttp_session(
811+
self,
812+
) -> Union['aiohttp.ClientSession', 'AsyncAuthorizedSession']:
764813
"""Returns the aiohttp client session."""
765-
if (
766-
self._aiohttp_session is None
767-
or self._aiohttp_session.closed
768-
or self._aiohttp_session._loop.is_closed() # pylint: disable=protected-access
769-
):
814+
815+
# Use aiohttp directly
816+
if self._aiohttp_session is None or (
817+
isinstance(self._aiohttp_session, aiohttp.ClientSession)
818+
and (
819+
self._aiohttp_session.closed
820+
or self._aiohttp_session._loop.is_closed()
821+
)
822+
): # pylint: disable=protected-access
823+
770824
# Initialize the aiohttp client session if it's not set up or closed.
771825
class AiohttpClientSession(aiohttp.ClientSession): # type: ignore[misc]
772826

@@ -803,6 +857,17 @@ def __del__(self, _warnings: Any = warnings) -> None:
803857
trust_env=True,
804858
read_bufsize=READ_BUFFER_SIZE,
805859
)
860+
# Use google.auth if available.
861+
if self._use_google_auth_async():
862+
token = await self._async_access_token()
863+
async_creds = StaticCredentials(token=token)
864+
auth_request = google.auth.aio.transport.aiohttp.Request(
865+
session=self._aiohttp_session,
866+
)
867+
self._aiohttp_session = AsyncAuthorizedSession(
868+
async_creds, auth_request
869+
)
870+
return self._aiohttp_session
806871
return self._aiohttp_session
807872

808873
@staticmethod
@@ -1191,31 +1256,33 @@ def _request_once(
11911256
else:
11921257
data = http_request.data
11931258

1194-
if stream:
1195-
httpx_request = self._httpx_client.build_request(
1196-
method=http_request.method,
1259+
if self._use_google_auth_sync():
1260+
if self.authorized_session is None:
1261+
self.authorized_session = AuthorizedSession(
1262+
self._credentials,
1263+
max_refresh_attempts=1,
1264+
)
1265+
response = self.authorized_session.request(
1266+
method=http_request.method.upper(),
11971267
url=http_request.url,
1198-
content=data,
1268+
data=data,
11991269
headers=http_request.headers,
12001270
timeout=http_request.timeout,
1201-
)
1202-
response = self._httpx_client.send(httpx_request, stream=stream)
1203-
errors.APIError.raise_for_response(response)
1204-
return HttpResponse(
1205-
response.headers, response if stream else [response.text]
1271+
stream=stream,
12061272
)
12071273
else:
1208-
response = self._httpx_client.request(
1274+
httpx_request = self._httpx_client.build_request(
12091275
method=http_request.method,
12101276
url=http_request.url,
1211-
headers=http_request.headers,
12121277
content=data,
1278+
headers=http_request.headers,
12131279
timeout=http_request.timeout,
12141280
)
1215-
errors.APIError.raise_for_response(response)
1216-
return HttpResponse(
1217-
response.headers, response if stream else [response.text]
1218-
)
1281+
response = self._httpx_client.send(httpx_request, stream=stream)
1282+
errors.APIError.raise_for_response(response)
1283+
return HttpResponse(
1284+
response.headers, response if stream else [response.text]
1285+
)
12191286

12201287
def _request(
12211288
self,
@@ -1259,107 +1326,70 @@ async def _async_request_once(
12591326
else:
12601327
data = http_request.data
12611328

1262-
if stream:
1263-
if self._use_aiohttp():
1264-
self._aiohttp_session = await self._get_aiohttp_session()
1265-
try:
1266-
response = await self._aiohttp_session.request(
1267-
method=http_request.method,
1268-
url=http_request.url,
1269-
headers=http_request.headers,
1270-
data=data,
1271-
timeout=aiohttp.ClientTimeout(total=http_request.timeout),
1272-
**self._async_client_session_request_args,
1273-
)
1274-
except (
1275-
aiohttp.ClientConnectorError,
1276-
aiohttp.ClientConnectorDNSError,
1277-
aiohttp.ClientOSError,
1278-
aiohttp.ServerDisconnectedError,
1279-
) as e:
1280-
await asyncio.sleep(1 + random.randint(0, 9))
1281-
logger.info('Retrying due to aiohttp error: %s' % e)
1282-
# Retrieve the SSL context from the session.
1283-
self._async_client_session_request_args = (
1284-
self._ensure_aiohttp_ssl_ctx(self._http_options)
1285-
)
1286-
# Instantiate a new session with the updated SSL context.
1287-
self._aiohttp_session = await self._get_aiohttp_session()
1288-
response = await self._aiohttp_session.request(
1289-
method=http_request.method,
1290-
url=http_request.url,
1291-
headers=http_request.headers,
1292-
data=data,
1293-
timeout=aiohttp.ClientTimeout(total=http_request.timeout),
1294-
**self._async_client_session_request_args,
1295-
)
1296-
1297-
await errors.APIError.raise_for_async_response(response)
1298-
return HttpResponse(response.headers, response)
1299-
else:
1300-
# aiohttp is not available. Fall back to httpx.
1301-
httpx_request = self._async_httpx_client.build_request(
1329+
if self._use_aiohttp():
1330+
self._aiohttp_session = await self._get_aiohttp_session()
1331+
if self._use_google_auth_async():
1332+
self._async_client_session_request_args['max_allowed_time'] = float(
1333+
'inf'
1334+
)
1335+
self._async_client_session_request_args['total_attempts'] = 1
1336+
try:
1337+
response = await self._aiohttp_session.request(
13021338
method=http_request.method,
13031339
url=http_request.url,
1304-
content=data,
13051340
headers=http_request.headers,
1306-
timeout=http_request.timeout,
1341+
data=data,
1342+
timeout=aiohttp.ClientTimeout(total=http_request.timeout),
1343+
**self._async_client_session_request_args,
13071344
)
1308-
client_response = await self._async_httpx_client.send(
1309-
httpx_request,
1310-
stream=stream,
1345+
except (
1346+
aiohttp.ClientConnectorError,
1347+
aiohttp.ClientConnectorDNSError,
1348+
aiohttp.ClientOSError,
1349+
aiohttp.ServerDisconnectedError,
1350+
) as e:
1351+
await asyncio.sleep(1 + random.randint(0, 9))
1352+
logger.info('Retrying due to aiohttp error: %s' % e)
1353+
# Retrieve the SSL context from the session.
1354+
self._async_client_session_request_args = self._ensure_aiohttp_ssl_ctx(
1355+
self._http_options
13111356
)
1312-
await errors.APIError.raise_for_async_response(client_response)
1313-
return HttpResponse(client_response.headers, client_response)
1314-
else:
1315-
if self._use_aiohttp():
1357+
# Instantiate a new session with the updated SSL context.
13161358
self._aiohttp_session = await self._get_aiohttp_session()
1317-
try:
1318-
response = await self._aiohttp_session.request(
1319-
method=http_request.method,
1320-
url=http_request.url,
1321-
headers=http_request.headers,
1322-
data=data,
1323-
timeout=aiohttp.ClientTimeout(total=http_request.timeout),
1324-
**self._async_client_session_request_args,
1325-
)
1326-
await errors.APIError.raise_for_async_response(response)
1327-
return HttpResponse(response.headers, [await response.text()])
1328-
except (
1329-
aiohttp.ClientConnectorError,
1330-
aiohttp.ClientConnectorDNSError,
1331-
aiohttp.ClientOSError,
1332-
aiohttp.ServerDisconnectedError,
1333-
) as e:
1334-
await asyncio.sleep(1 + random.randint(0, 9))
1335-
logger.info('Retrying due to aiohttp error: %s' % e)
1336-
# Retrieve the SSL context from the session.
1337-
self._async_client_session_request_args = (
1338-
self._ensure_aiohttp_ssl_ctx(self._http_options)
1339-
)
1340-
# Instantiate a new session with the updated SSL context.
1341-
self._aiohttp_session = await self._get_aiohttp_session()
1342-
response = await self._aiohttp_session.request(
1343-
method=http_request.method,
1344-
url=http_request.url,
1345-
headers=http_request.headers,
1346-
data=data,
1347-
timeout=aiohttp.ClientTimeout(total=http_request.timeout),
1348-
**self._async_client_session_request_args,
1349-
)
1350-
await errors.APIError.raise_for_async_response(response)
1351-
return HttpResponse(response.headers, [await response.text()])
1352-
else:
1353-
# aiohttp is not available. Fall back to httpx.
1354-
client_response = await self._async_httpx_client.request(
1359+
response = await self._aiohttp_session.request(
13551360
method=http_request.method,
13561361
url=http_request.url,
13571362
headers=http_request.headers,
1358-
content=data,
1359-
timeout=http_request.timeout,
1363+
data=data,
1364+
timeout=aiohttp.ClientTimeout(total=http_request.timeout),
1365+
**self._async_client_session_request_args,
13601366
)
1361-
await errors.APIError.raise_for_async_response(client_response)
1362-
return HttpResponse(client_response.headers, [client_response.text])
1367+
await errors.APIError.raise_for_async_response(response)
1368+
if self._use_google_auth_async() and response:
1369+
# Extract the underlying aiohttp.ClientResponse from the
1370+
# AsyncAuthorizedSession Response.
1371+
response = response._response
1372+
return HttpResponse(
1373+
response.headers, response if stream else [await response.text()]
1374+
)
1375+
else:
1376+
# aiohttp is not available. Fall back to httpx.
1377+
httpx_request = self._async_httpx_client.build_request(
1378+
method=http_request.method,
1379+
url=http_request.url,
1380+
content=data,
1381+
headers=http_request.headers,
1382+
timeout=http_request.timeout,
1383+
)
1384+
client_response = await self._async_httpx_client.send(
1385+
httpx_request,
1386+
stream=stream,
1387+
)
1388+
await errors.APIError.raise_for_async_response(client_response)
1389+
return HttpResponse(
1390+
client_response.headers,
1391+
client_response if stream else [client_response.text],
1392+
)
13631393

13641394
async def _async_request(
13651395
self,
@@ -1910,7 +1940,7 @@ def close(self) -> None:
19101940
"""Closes the API client."""
19111941
# Let users close the custom client explicitly by themselves. Otherwise,
19121942
# close the client when the object is garbage collected.
1913-
if not self._http_options.httpx_client:
1943+
if not self._http_options.httpx_client and self._httpx_client:
19141944
self._httpx_client.close()
19151945

19161946
async def aclose(self) -> None:
@@ -1952,6 +1982,7 @@ def get_token_from_credentials(
19521982
raise RuntimeError('Could not resolve API token from the environment')
19531983
return credentials.token # type: ignore[no-any-return]
19541984

1985+
19551986
async def async_get_token_from_credentials(
19561987
client: 'BaseApiClient',
19571988
credentials: google.auth.credentials.Credentials

0 commit comments

Comments
 (0)