1- # Copyright 2025 Google LLC
1+ # Copyright 2026 Google LLC
22#
33# Licensed under the Apache License, Version 2.0 (the "License");
44# you may not use this file except in compliance with the License.
3030import math
3131import os
3232import random
33+ import requests
3334import ssl
3435import sys
3536import threading
4748import httpx
4849from pydantic import BaseModel
4950from pydantic import ValidationError
51+ from requests .structures import CaseInsensitiveDict
5052import tenacity
5153
5254from . import _common
5961from .types import ResourceScope
6062
6163
64+ try :
65+ from google .auth .transport .requests import AuthorizedSession
66+ from google .auth .aio .credentials import StaticCredentials
67+ from google .auth .aio .transport .sessions import AsyncAuthorizedSession
68+ except ImportError :
69+ # This try/except is for TAP
70+ StaticCredentials = None
71+ AsyncAuthorizedSession = None
72+
6273try :
6374 from websockets .asyncio .client import connect as ws_connect
6475except ModuleNotFoundError :
@@ -235,7 +246,12 @@ class HttpResponse:
235246
236247 def __init__ (
237248 self ,
238- headers : Union [dict [str , str ], httpx .Headers , 'CIMultiDictProxy[str]' ],
249+ headers : Union [
250+ dict [str , str ],
251+ httpx .Headers ,
252+ 'CIMultiDictProxy[str]' ,
253+ CaseInsensitiveDict ,
254+ ],
239255 response_stream : Union [Any , str ] = None ,
240256 byte_stream : Union [Any , bytes ] = None ,
241257 ):
@@ -245,6 +261,10 @@ def __init__(
245261 self .headers = {
246262 key : ', ' .join (headers .get_list (key )) for key in headers .keys ()
247263 }
264+ elif isinstance (headers , CaseInsensitiveDict ):
265+ self .headers = {
266+ key : value for key ,value in headers .items ()
267+ }
248268 elif type (headers ).__name__ == 'CIMultiDictProxy' :
249269 self .headers = {
250270 key : ', ' .join (headers .getall (key )) for key in headers .keys ()
@@ -321,15 +341,22 @@ def _copy_to_dict(self, response_payload: dict[str, object]) -> None:
321341
322342 def _iter_response_stream (self ) -> Iterator [str ]:
323343 """Iterates over chunks retrieved from the API."""
324- if not isinstance (self .response_stream , httpx .Response ):
344+ if not (
345+ isinstance (self .response_stream , httpx .Response )
346+ or isinstance (self .response_stream , requests .Response )
347+ ):
325348 raise TypeError (
326349 'Expected self.response_stream to be an httpx.Response object, '
327350 f'but got { type (self .response_stream ).__name__ } .'
328351 )
329352
330353 chunk = ''
331354 balance = 0
332- for line in self .response_stream .iter_lines ():
355+ if isinstance (self .response_stream , httpx .Response ):
356+ response_stream = self .response_stream .iter_lines ()
357+ else :
358+ response_stream = self .response_stream .iter_lines (decode_unicode = True )
359+ for line in response_stream :
333360 if not line :
334361 continue
335362
@@ -729,8 +756,11 @@ def __init__(
729756 self ._http_options
730757 )
731758 self ._async_httpx_client_args = async_client_args
759+ self .authorized_session : Optional [AuthorizedSession ] = None
732760
733- if self ._http_options .httpx_client :
761+ if self ._use_google_auth_sync ():
762+ self ._httpx_client = None
763+ elif self ._http_options .httpx_client :
734764 self ._httpx_client = self ._http_options .httpx_client
735765 else :
736766 self ._httpx_client = SyncHttpxClient (** client_args )
@@ -747,6 +777,7 @@ def __init__(
747777
748778 if self ._http_options .aiohttp_client :
749779 self ._aiohttp_session = self ._http_options .aiohttp_client
780+ self ._async_client_session_request_args = {}
750781 else :
751782 # Do it once at the genai.Client level. Share among all requests.
752783 self ._async_client_session_request_args = (
@@ -760,13 +791,36 @@ def __init__(
760791 self ._retry = tenacity .Retrying (** retry_kwargs )
761792 self ._async_retry = tenacity .AsyncRetrying (** retry_kwargs )
762793
763- async def _get_aiohttp_session (self ) -> 'aiohttp.ClientSession' :
794+ def _use_google_auth_sync (self ) -> bool :
795+ return self .vertexai and not (
796+ self ._http_options .httpx_client or self ._http_options .client_args
797+ )
798+
799+ def _use_google_auth_async (self ) -> bool :
800+ return (
801+ StaticCredentials
802+ and AsyncAuthorizedSession
803+ and self .vertexai
804+ and not (
805+ self ._http_options .aiohttp_client
806+ or self ._http_options .async_client_args
807+ )
808+ )
809+
810+ async def _get_aiohttp_session (
811+ self ,
812+ ) -> Union ['aiohttp.ClientSession' , 'AsyncAuthorizedSession' ]:
764813 """Returns the aiohttp client session."""
765- if (
766- self ._aiohttp_session is None
767- or self ._aiohttp_session .closed
768- or self ._aiohttp_session ._loop .is_closed () # pylint: disable=protected-access
769- ):
814+
815+ # Use aiohttp directly
816+ if self ._aiohttp_session is None or (
817+ isinstance (self ._aiohttp_session , aiohttp .ClientSession )
818+ and (
819+ self ._aiohttp_session .closed
820+ or self ._aiohttp_session ._loop .is_closed ()
821+ )
822+ ): # pylint: disable=protected-access
823+
770824 # Initialize the aiohttp client session if it's not set up or closed.
771825 class AiohttpClientSession (aiohttp .ClientSession ): # type: ignore[misc]
772826
@@ -803,6 +857,17 @@ def __del__(self, _warnings: Any = warnings) -> None:
803857 trust_env = True ,
804858 read_bufsize = READ_BUFFER_SIZE ,
805859 )
860+ # Use google.auth if available.
861+ if self ._use_google_auth_async ():
862+ token = await self ._async_access_token ()
863+ async_creds = StaticCredentials (token = token )
864+ auth_request = google .auth .aio .transport .aiohttp .Request (
865+ session = self ._aiohttp_session ,
866+ )
867+ self ._aiohttp_session = AsyncAuthorizedSession (
868+ async_creds , auth_request
869+ )
870+ return self ._aiohttp_session
806871 return self ._aiohttp_session
807872
808873 @staticmethod
@@ -1191,31 +1256,33 @@ def _request_once(
11911256 else :
11921257 data = http_request .data
11931258
1194- if stream :
1195- httpx_request = self ._httpx_client .build_request (
1196- method = http_request .method ,
1259+ if self ._use_google_auth_sync ():
1260+ if self .authorized_session is None :
1261+ self .authorized_session = AuthorizedSession (
1262+ self ._credentials ,
1263+ max_refresh_attempts = 1 ,
1264+ )
1265+ response = self .authorized_session .request (
1266+ method = http_request .method .upper (),
11971267 url = http_request .url ,
1198- content = data ,
1268+ data = data ,
11991269 headers = http_request .headers ,
12001270 timeout = http_request .timeout ,
1201- )
1202- response = self ._httpx_client .send (httpx_request , stream = stream )
1203- errors .APIError .raise_for_response (response )
1204- return HttpResponse (
1205- response .headers , response if stream else [response .text ]
1271+ stream = stream ,
12061272 )
12071273 else :
1208- response = self ._httpx_client .request (
1274+ httpx_request = self ._httpx_client .build_request (
12091275 method = http_request .method ,
12101276 url = http_request .url ,
1211- headers = http_request .headers ,
12121277 content = data ,
1278+ headers = http_request .headers ,
12131279 timeout = http_request .timeout ,
12141280 )
1215- errors .APIError .raise_for_response (response )
1216- return HttpResponse (
1217- response .headers , response if stream else [response .text ]
1218- )
1281+ response = self ._httpx_client .send (httpx_request , stream = stream )
1282+ errors .APIError .raise_for_response (response )
1283+ return HttpResponse (
1284+ response .headers , response if stream else [response .text ]
1285+ )
12191286
12201287 def _request (
12211288 self ,
@@ -1259,107 +1326,70 @@ async def _async_request_once(
12591326 else :
12601327 data = http_request .data
12611328
1262- if stream :
1263- if self ._use_aiohttp ():
1264- self ._aiohttp_session = await self ._get_aiohttp_session ()
1265- try :
1266- response = await self ._aiohttp_session .request (
1267- method = http_request .method ,
1268- url = http_request .url ,
1269- headers = http_request .headers ,
1270- data = data ,
1271- timeout = aiohttp .ClientTimeout (total = http_request .timeout ),
1272- ** self ._async_client_session_request_args ,
1273- )
1274- except (
1275- aiohttp .ClientConnectorError ,
1276- aiohttp .ClientConnectorDNSError ,
1277- aiohttp .ClientOSError ,
1278- aiohttp .ServerDisconnectedError ,
1279- ) as e :
1280- await asyncio .sleep (1 + random .randint (0 , 9 ))
1281- logger .info ('Retrying due to aiohttp error: %s' % e )
1282- # Retrieve the SSL context from the session.
1283- self ._async_client_session_request_args = (
1284- self ._ensure_aiohttp_ssl_ctx (self ._http_options )
1285- )
1286- # Instantiate a new session with the updated SSL context.
1287- self ._aiohttp_session = await self ._get_aiohttp_session ()
1288- response = await self ._aiohttp_session .request (
1289- method = http_request .method ,
1290- url = http_request .url ,
1291- headers = http_request .headers ,
1292- data = data ,
1293- timeout = aiohttp .ClientTimeout (total = http_request .timeout ),
1294- ** self ._async_client_session_request_args ,
1295- )
1296-
1297- await errors .APIError .raise_for_async_response (response )
1298- return HttpResponse (response .headers , response )
1299- else :
1300- # aiohttp is not available. Fall back to httpx.
1301- httpx_request = self ._async_httpx_client .build_request (
1329+ if self ._use_aiohttp ():
1330+ self ._aiohttp_session = await self ._get_aiohttp_session ()
1331+ if self ._use_google_auth_async ():
1332+ self ._async_client_session_request_args ['max_allowed_time' ] = float (
1333+ 'inf'
1334+ )
1335+ self ._async_client_session_request_args ['total_attempts' ] = 1
1336+ try :
1337+ response = await self ._aiohttp_session .request (
13021338 method = http_request .method ,
13031339 url = http_request .url ,
1304- content = data ,
13051340 headers = http_request .headers ,
1306- timeout = http_request .timeout ,
1341+ data = data ,
1342+ timeout = aiohttp .ClientTimeout (total = http_request .timeout ),
1343+ ** self ._async_client_session_request_args ,
13071344 )
1308- client_response = await self ._async_httpx_client .send (
1309- httpx_request ,
1310- stream = stream ,
1345+ except (
1346+ aiohttp .ClientConnectorError ,
1347+ aiohttp .ClientConnectorDNSError ,
1348+ aiohttp .ClientOSError ,
1349+ aiohttp .ServerDisconnectedError ,
1350+ ) as e :
1351+ await asyncio .sleep (1 + random .randint (0 , 9 ))
1352+ logger .info ('Retrying due to aiohttp error: %s' % e )
1353+ # Retrieve the SSL context from the session.
1354+ self ._async_client_session_request_args = self ._ensure_aiohttp_ssl_ctx (
1355+ self ._http_options
13111356 )
1312- await errors .APIError .raise_for_async_response (client_response )
1313- return HttpResponse (client_response .headers , client_response )
1314- else :
1315- if self ._use_aiohttp ():
1357+ # Instantiate a new session with the updated SSL context.
13161358 self ._aiohttp_session = await self ._get_aiohttp_session ()
1317- try :
1318- response = await self ._aiohttp_session .request (
1319- method = http_request .method ,
1320- url = http_request .url ,
1321- headers = http_request .headers ,
1322- data = data ,
1323- timeout = aiohttp .ClientTimeout (total = http_request .timeout ),
1324- ** self ._async_client_session_request_args ,
1325- )
1326- await errors .APIError .raise_for_async_response (response )
1327- return HttpResponse (response .headers , [await response .text ()])
1328- except (
1329- aiohttp .ClientConnectorError ,
1330- aiohttp .ClientConnectorDNSError ,
1331- aiohttp .ClientOSError ,
1332- aiohttp .ServerDisconnectedError ,
1333- ) as e :
1334- await asyncio .sleep (1 + random .randint (0 , 9 ))
1335- logger .info ('Retrying due to aiohttp error: %s' % e )
1336- # Retrieve the SSL context from the session.
1337- self ._async_client_session_request_args = (
1338- self ._ensure_aiohttp_ssl_ctx (self ._http_options )
1339- )
1340- # Instantiate a new session with the updated SSL context.
1341- self ._aiohttp_session = await self ._get_aiohttp_session ()
1342- response = await self ._aiohttp_session .request (
1343- method = http_request .method ,
1344- url = http_request .url ,
1345- headers = http_request .headers ,
1346- data = data ,
1347- timeout = aiohttp .ClientTimeout (total = http_request .timeout ),
1348- ** self ._async_client_session_request_args ,
1349- )
1350- await errors .APIError .raise_for_async_response (response )
1351- return HttpResponse (response .headers , [await response .text ()])
1352- else :
1353- # aiohttp is not available. Fall back to httpx.
1354- client_response = await self ._async_httpx_client .request (
1359+ response = await self ._aiohttp_session .request (
13551360 method = http_request .method ,
13561361 url = http_request .url ,
13571362 headers = http_request .headers ,
1358- content = data ,
1359- timeout = http_request .timeout ,
1363+ data = data ,
1364+ timeout = aiohttp .ClientTimeout (total = http_request .timeout ),
1365+ ** self ._async_client_session_request_args ,
13601366 )
1361- await errors .APIError .raise_for_async_response (client_response )
1362- return HttpResponse (client_response .headers , [client_response .text ])
1367+ await errors .APIError .raise_for_async_response (response )
1368+ if self ._use_google_auth_async () and response :
1369+ # Extract the underlying aiohttp.ClientResponse from the
1370+ # AsyncAuthorizedSession Response.
1371+ response = response ._response
1372+ return HttpResponse (
1373+ response .headers , response if stream else [await response .text ()]
1374+ )
1375+ else :
1376+ # aiohttp is not available. Fall back to httpx.
1377+ httpx_request = self ._async_httpx_client .build_request (
1378+ method = http_request .method ,
1379+ url = http_request .url ,
1380+ content = data ,
1381+ headers = http_request .headers ,
1382+ timeout = http_request .timeout ,
1383+ )
1384+ client_response = await self ._async_httpx_client .send (
1385+ httpx_request ,
1386+ stream = stream ,
1387+ )
1388+ await errors .APIError .raise_for_async_response (client_response )
1389+ return HttpResponse (
1390+ client_response .headers ,
1391+ client_response if stream else [client_response .text ],
1392+ )
13631393
13641394 async def _async_request (
13651395 self ,
@@ -1910,7 +1940,7 @@ def close(self) -> None:
19101940 """Closes the API client."""
19111941 # Let users close the custom client explicitly by themselves. Otherwise,
19121942 # close the client when the object is garbage collected.
1913- if not self ._http_options .httpx_client :
1943+ if not self ._http_options .httpx_client and self . _httpx_client :
19141944 self ._httpx_client .close ()
19151945
19161946 async def aclose (self ) -> None :
@@ -1952,6 +1982,7 @@ def get_token_from_credentials(
19521982 raise RuntimeError ('Could not resolve API token from the environment' )
19531983 return credentials .token # type: ignore[no-any-return]
19541984
1985+
19551986async def async_get_token_from_credentials (
19561987 client : 'BaseApiClient' ,
19571988 credentials : google .auth .credentials .Credentials
0 commit comments