 import asyncio
 import google_crc32c
 from google.api_core import exceptions
-from google_crc32c import Checksum
+from google.api_core.retry_async import AsyncRetry

-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Any, Dict

 from google.cloud.storage._experimental.asyncio.async_read_object_stream import (
     _AsyncReadObjectStream,
 )
 from google.cloud.storage._experimental.asyncio.async_grpc_client import (
     AsyncGrpcClient,
 )
+from google.cloud.storage._experimental.asyncio.retry.bidi_stream_retry_manager import (
+    _BidiStreamRetryManager,
+)
+from google.cloud.storage._experimental.asyncio.retry.reads_resumption_strategy import (
+    _ReadResumptionStrategy,
+    _DownloadState,
+)

 from io import BytesIO
 from google.cloud import _storage_v2
-from google.cloud.storage.exceptions import DataCorruption
 from google.cloud.storage._helpers import generate_random_56_bit_integer


 _MAX_READ_RANGES_PER_BIDI_READ_REQUEST = 100


-class Result:
-    """An instance of this class will be populated and retured for each
-    `read_range` provided to ``download_ranges`` method.
-
-    """
-
-    def __init__(self, bytes_requested: int):
-        # only while instantiation, should not be edited later.
-        # hence there's no setter, only getter is provided.
-        self._bytes_requested: int = bytes_requested
-        self._bytes_written: int = 0
-
-    @property
-    def bytes_requested(self) -> int:
-        return self._bytes_requested
-
-    @property
-    def bytes_written(self) -> int:
-        return self._bytes_written
-
-    @bytes_written.setter
-    def bytes_written(self, value: int):
-        self._bytes_written = value
-
-    def __repr__(self):
-        return f"bytes_requested: {self._bytes_requested}, bytes_written: {self._bytes_written}"
-
-
 class AsyncMultiRangeDownloader:
     """Provides an interface for downloading multiple ranges of a GCS ``Object``
     concurrently.
@@ -104,6 +82,7 @@ async def create_mrd(
         object_name: str,
         generation_number: Optional[int] = None,
         read_handle: Optional[bytes] = None,
+        retry_policy: Optional[AsyncRetry] = None,
     ) -> AsyncMultiRangeDownloader:
         """Initializes a MultiRangeDownloader and opens the underlying bidi-gRPC
         object for reading.
@@ -125,11 +104,14 @@ async def create_mrd(
         :param read_handle: (Optional) An existing handle for reading the object.
             If provided, opening the bidi-gRPC connection will be faster.

+        :type retry_policy: :class:`~google.api_core.retry_async.AsyncRetry`
+        :param retry_policy: (Optional) The retry policy to use for the ``open`` operation.
+
         :rtype: :class:`~google.cloud.storage._experimental.asyncio.async_multi_range_downloader.AsyncMultiRangeDownloader`
         :returns: An initialized AsyncMultiRangeDownloader instance for reading.
         """
         mrd = cls(client, bucket_name, object_name, generation_number, read_handle)
-        await mrd.open()
+        await mrd.open(retry_policy=retry_policy)
         return mrd

     def __init__(
@@ -177,11 +159,7 @@ def __init__(
         self.read_obj_str: Optional[_AsyncReadObjectStream] = None
         self._is_stream_open: bool = False

-        self._read_id_to_writable_buffer_dict = {}
-        self._read_id_to_download_ranges_id = {}
-        self._download_ranges_id_to_pending_read_ids = {}
-
-    async def open(self) -> None:
+    async def open(self, retry_policy: Optional[AsyncRetry] = None) -> None:
         """Opens the bidi-gRPC connection to read from the object.

         This method initializes and opens an `_AsyncReadObjectStream` (bidi-gRPC stream) to
@@ -193,26 +171,40 @@ async def open(self) -> None:
         if self._is_stream_open:
             raise ValueError("Underlying bidi-gRPC stream is already open")

-        if self.read_obj_str is None:
+        if retry_policy is None:
+            # Default policy: retry generic transient errors
+            retry_policy = AsyncRetry(
+                predicate=lambda e: isinstance(e, (exceptions.ServiceUnavailable, exceptions.DeadlineExceeded))
+            )
+
+        async def _do_open():
             self.read_obj_str = _AsyncReadObjectStream(
                 client=self.client,
                 bucket_name=self.bucket_name,
                 object_name=self.object_name,
                 generation_number=self.generation_number,
                 read_handle=self.read_handle,
             )
-        await self.read_obj_str.open()
-        self._is_stream_open = True
-        if self.generation_number is None:
-            self.generation_number = self.read_obj_str.generation_number
-        self.read_handle = self.read_obj_str.read_handle
-        return
+            await self.read_obj_str.open()
+
+            if self.read_obj_str.generation_number:
+                self.generation_number = self.read_obj_str.generation_number
+            if self.read_obj_str.read_handle:
+                self.read_handle = self.read_obj_str.read_handle
+
+            self._is_stream_open = True
+
+        # Execute open with retry policy
+        await retry_policy(_do_open)()

     async def download_ranges(
-        self, read_ranges: List[Tuple[int, int, BytesIO]], lock: asyncio.Lock = None
+        self,
+        read_ranges: List[Tuple[int, int, BytesIO]],
+        lock: asyncio.Lock = None,
+        retry_policy: AsyncRetry = None
     ) -> None:
         """Downloads multiple byte ranges from the object into the buffers
-        provided by user.
+        provided by user with automatic retries.

         :type read_ranges: List[Tuple[int, int, "BytesIO"]]
         :param read_ranges: A list of tuples, where each tuple represents a
@@ -246,6 +238,8 @@ async def download_ranges(

         ```

+        :type retry_policy: :class:`~google.api_core.retry_async.AsyncRetry`
+        :param retry_policy: (Optional) The retry policy to use for the operation.

         :raises ValueError: if the underlying bidi-GRPC stream is not open.
         :raises ValueError: if the length of read_ranges is more than 1000.
@@ -264,80 +258,101 @@ async def download_ranges(
         if lock is None:
             lock = asyncio.Lock()

-        _func_id = generate_random_56_bit_integer()
-        read_ids_in_current_func = set()
-        for i in range(0, len(read_ranges), _MAX_READ_RANGES_PER_BIDI_READ_REQUEST):
-            read_ranges_segment = read_ranges[
-                i : i + _MAX_READ_RANGES_PER_BIDI_READ_REQUEST
-            ]
+        if retry_policy is None:
+            retry_policy = AsyncRetry(
+                predicate=lambda e: isinstance(e, (exceptions.ServiceUnavailable, exceptions.DeadlineExceeded))
+            )
+
+        # Initialize Global State for Retry Strategy
+        download_states = {}
+        for read_range in read_ranges:
+            read_id = generate_random_56_bit_integer()
+            download_states[read_id] = _DownloadState(
+                initial_offset=read_range[0],
+                initial_length=read_range[1],
+                user_buffer=read_range[2]
+            )
+
+        initial_state = {
+            "download_states": download_states,
+            "read_handle": self.read_handle,
+            "routing_token": None
+        }
+
+        # Track attempts to manage stream reuse
+        is_first_attempt = True
+
+        def stream_opener(requests: List[_storage_v2.ReadRange], state: Dict[str, Any]):
+
+            async def generator():
+                nonlocal is_first_attempt
+
+                async with lock:
+                    current_handle = state.get("read_handle")
+                    current_token = state.get("routing_token")
+
+                    # We reopen if it's a redirect (token exists) OR if this is a retry
+                    # (not first attempt). This prevents trying to send data on a dead
+                    # stream from a previous failed attempt.
+                    should_reopen = (not is_first_attempt) or (current_token is not None)
+
+                    if should_reopen:
+                        # Close existing stream if any
+                        if self.read_obj_str:
+                            await self.read_obj_str.close()

-            read_ranges_for_bidi_req = []
-            for j, read_range in enumerate(read_ranges_segment):
-                read_id = generate_random_56_bit_integer()
-                read_ids_in_current_func.add(read_id)
-                self._read_id_to_download_ranges_id[read_id] = _func_id
-                self._read_id_to_writable_buffer_dict[read_id] = read_range[2]
-                bytes_requested = read_range[1]
-                read_ranges_for_bidi_req.append(
-                    _storage_v2.ReadRange(
-                        read_offset=read_range[0],
-                        read_length=bytes_requested,
-                        read_id=read_id,
-                    )
-                )
-            async with lock:
-                await self.read_obj_str.send(
-                    _storage_v2.BidiReadObjectRequest(
-                        read_ranges=read_ranges_for_bidi_req
-                    )
-                )
-        self._download_ranges_id_to_pending_read_ids[
-            _func_id
-        ] = read_ids_in_current_func
-
-        while len(self._download_ranges_id_to_pending_read_ids[_func_id]) > 0:
-            async with lock:
-                response = await self.read_obj_str.recv()
-
-            if response is None:
-                raise Exception("None response received, something went wrong.")
-
-            for object_data_range in response.object_data_ranges:
-                if object_data_range.read_range is None:
-                    raise Exception("Invalid response, read_range is None")
-
-                checksummed_data = object_data_range.checksummed_data
-                data = checksummed_data.content
-                server_checksum = checksummed_data.crc32c
-
-                client_crc32c = Checksum(data).digest()
-                client_checksum = int.from_bytes(client_crc32c, "big")
-
-                if server_checksum != client_checksum:
-                    raise DataCorruption(
-                        response,
-                        f"Checksum mismatch for read_id {object_data_range.read_range.read_id}. "
-                        f"Server sent {server_checksum}, client calculated {client_checksum}.",
-                    )
-
-                read_id = object_data_range.read_range.read_id
-                buffer = self._read_id_to_writable_buffer_dict[read_id]
-                buffer.write(data)
-
-                if object_data_range.range_end:
-                    tmp_dn_ranges_id = self._read_id_to_download_ranges_id[read_id]
-                    self._download_ranges_id_to_pending_read_ids[
-                        tmp_dn_ranges_id
-                    ].remove(read_id)
-                    del self._read_id_to_download_ranges_id[read_id]
+                        # Re-initialize stream
+                        self.read_obj_str = _AsyncReadObjectStream(
+                            client=self.client,
+                            bucket_name=self.bucket_name,
+                            object_name=self.object_name,
+                            generation_number=self.generation_number,
+                            read_handle=current_handle,
+                        )
+
+                        # Inject routing_token into metadata if present
+                        metadata = []
+                        if current_token:
+                            metadata.append(("x-goog-request-params", f"routing_token={current_token}"))
+
+                        await self.read_obj_str.open(metadata=metadata if metadata else None)
+                        self._is_stream_open = True
+
+                    # Mark first attempt as done; next time this runs it will be a retry
+                    is_first_attempt = False
+
+                # Send Requests
+                for i in range(0, len(requests), _MAX_READ_RANGES_PER_BIDI_READ_REQUEST):
+                    batch = requests[i : i + _MAX_READ_RANGES_PER_BIDI_READ_REQUEST]
+                    await self.read_obj_str.send(
+                        _storage_v2.BidiReadObjectRequest(read_ranges=batch)
+                    )
+
+                while True:
+                    response = await self.read_obj_str.recv()
+                    if response is None:
+                        break
+                    yield response
+
+            return generator()
+
+        strategy = _ReadResumptionStrategy()
+        retry_manager = _BidiStreamRetryManager(strategy, stream_opener)
+
+        await retry_manager.execute(initial_state, retry_policy)
+
+        if initial_state.get("read_handle"):
+            self.read_handle = initial_state["read_handle"]

     async def close(self):
         """
         Closes the underlying bidi-gRPC connection.
         """
         if not self._is_stream_open:
             raise ValueError("Underlying bidi-gRPC stream is not open")
-        await self.read_obj_str.close()
+
+        if self.read_obj_str:
+            await self.read_obj_str.close()
         self._is_stream_open = False

     @property
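
For reference, a minimal usage sketch of the retry-aware surface shown in this diff. The bucket and object names are placeholders, and the bare `AsyncGrpcClient()` constructor call is an assumption about the experimental client's default surface; the `create_mrd`, `download_ranges`, and `close` calls follow the signatures above.

```
import asyncio
from io import BytesIO

from google.api_core import exceptions
from google.api_core.retry_async import AsyncRetry
from google.cloud.storage._experimental.asyncio.async_grpc_client import AsyncGrpcClient
from google.cloud.storage._experimental.asyncio.async_multi_range_downloader import (
    AsyncMultiRangeDownloader,
)


async def main():
    # Placeholder client; real code may need project/credentials configuration.
    client = AsyncGrpcClient()

    # Retry transient gRPC failures, mirroring the default predicate in the diff.
    retry = AsyncRetry(
        predicate=lambda e: isinstance(
            e, (exceptions.ServiceUnavailable, exceptions.DeadlineExceeded)
        )
    )

    buf_a, buf_b = BytesIO(), BytesIO()
    mrd = await AsyncMultiRangeDownloader.create_mrd(
        client, "my-bucket", "my-object", retry_policy=retry
    )
    # Each tuple is (read_offset, read_length, writable_buffer).
    await mrd.download_ranges(
        [(0, 1024, buf_a), (4096, 2048, buf_b)], retry_policy=retry
    )
    await mrd.close()


asyncio.run(main())
```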