
Commit 3587822

integrate retry logic with the MRD
1 parent 1080bc1 commit 3587822

File tree: 4 files changed, +419 -278 lines


google/cloud/storage/_experimental/asyncio/async_multi_range_downloader.py

Lines changed: 127 additions & 112 deletions
@@ -16,54 +16,32 @@
 import asyncio
 import google_crc32c
 from google.api_core import exceptions
-from google_crc32c import Checksum
+from google.api_core.retry_async import AsyncRetry

-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Any, Dict

 from google.cloud.storage._experimental.asyncio.async_read_object_stream import (
     _AsyncReadObjectStream,
 )
 from google.cloud.storage._experimental.asyncio.async_grpc_client import (
     AsyncGrpcClient,
 )
+from google.cloud.storage._experimental.asyncio.retry.bidi_stream_retry_manager import (
+    _BidiStreamRetryManager,
+)
+from google.cloud.storage._experimental.asyncio.retry.reads_resumption_strategy import (
+    _ReadResumptionStrategy,
+    _DownloadState,
+)

 from io import BytesIO
 from google.cloud import _storage_v2
-from google.cloud.storage.exceptions import DataCorruption
 from google.cloud.storage._helpers import generate_random_56_bit_integer


 _MAX_READ_RANGES_PER_BIDI_READ_REQUEST = 100


-class Result:
-    """An instance of this class will be populated and retured for each
-    `read_range` provided to ``download_ranges`` method.
-
-    """
-
-    def __init__(self, bytes_requested: int):
-        # only while instantiation, should not be edited later.
-        # hence there's no setter, only getter is provided.
-        self._bytes_requested: int = bytes_requested
-        self._bytes_written: int = 0
-
-    @property
-    def bytes_requested(self) -> int:
-        return self._bytes_requested
-
-    @property
-    def bytes_written(self) -> int:
-        return self._bytes_written
-
-    @bytes_written.setter
-    def bytes_written(self, value: int):
-        self._bytes_written = value
-
-    def __repr__(self):
-        return f"bytes_requested: {self._bytes_requested}, bytes_written: {self._bytes_written}"
-
-
 class AsyncMultiRangeDownloader:
     """Provides an interface for downloading multiple ranges of a GCS ``Object``
     concurrently.
@@ -104,6 +82,7 @@ async def create_mrd(
         object_name: str,
         generation_number: Optional[int] = None,
         read_handle: Optional[bytes] = None,
+        retry_policy: Optional[AsyncRetry] = None,
     ) -> AsyncMultiRangeDownloader:
         """Initializes a MultiRangeDownloader and opens the underlying bidi-gRPC
         object for reading.
@@ -125,11 +104,14 @@ async def create_mrd(
         :param read_handle: (Optional) An existing handle for reading the object.
             If provided, opening the bidi-gRPC connection will be faster.

+        :type retry_policy: :class:`~google.api_core.retry_async.AsyncRetry`
+        :param retry_policy: (Optional) The retry policy to use for the ``open`` operation.
+
         :rtype: :class:`~google.cloud.storage._experimental.asyncio.async_multi_range_downloader.AsyncMultiRangeDownloader`
         :returns: An initialized AsyncMultiRangeDownloader instance for reading.
         """
         mrd = cls(client, bucket_name, object_name, generation_number, read_handle)
-        await mrd.open()
+        await mrd.open(retry_policy=retry_policy)
         return mrd

     def __init__(
@@ -177,11 +159,7 @@ def __init__(
         self.read_obj_str: Optional[_AsyncReadObjectStream] = None
         self._is_stream_open: bool = False

-        self._read_id_to_writable_buffer_dict = {}
-        self._read_id_to_download_ranges_id = {}
-        self._download_ranges_id_to_pending_read_ids = {}
-
-    async def open(self) -> None:
+    async def open(self, retry_policy: Optional[AsyncRetry] = None) -> None:
         """Opens the bidi-gRPC connection to read from the object.

         This method initializes and opens an `_AsyncReadObjectStream` (bidi-gRPC stream) to
@@ -193,26 +171,40 @@ async def open(self) -> None:
         if self._is_stream_open:
             raise ValueError("Underlying bidi-gRPC stream is already open")

-        if self.read_obj_str is None:
+        if retry_policy is None:
+            # Default policy: retry generic transient errors
+            retry_policy = AsyncRetry(
+                predicate=lambda e: isinstance(e, (exceptions.ServiceUnavailable, exceptions.DeadlineExceeded))
+            )
+
+        async def _do_open():
             self.read_obj_str = _AsyncReadObjectStream(
                 client=self.client,
                 bucket_name=self.bucket_name,
                 object_name=self.object_name,
                 generation_number=self.generation_number,
                 read_handle=self.read_handle,
             )
-            await self.read_obj_str.open()
-            self._is_stream_open = True
-            if self.generation_number is None:
-                self.generation_number = self.read_obj_str.generation_number
-            self.read_handle = self.read_obj_str.read_handle
-            return
+            await self.read_obj_str.open()
+
+            if self.read_obj_str.generation_number:
+                self.generation_number = self.read_obj_str.generation_number
+            if self.read_obj_str.read_handle:
+                self.read_handle = self.read_obj_str.read_handle
+
+            self._is_stream_open = True
+
+        # Execute open with retry policy
+        await retry_policy(_do_open)()

     async def download_ranges(
-        self, read_ranges: List[Tuple[int, int, BytesIO]], lock: asyncio.Lock = None
+        self,
+        read_ranges: List[Tuple[int, int, BytesIO]],
+        lock: asyncio.Lock = None,
+        retry_policy: AsyncRetry = None
     ) -> None:
         """Downloads multiple byte ranges from the object into the buffers
-        provided by user.
+        provided by user with automatic retries.

         :type read_ranges: List[Tuple[int, int, "BytesIO"]]
         :param read_ranges: A list of tuples, where each tuple represents a
@@ -246,6 +238,8 @@ async def download_ranges(

         ```

+        :type retry_policy: :class:`~google.api_core.retry_async.AsyncRetry`
+        :param retry_policy: (Optional) The retry policy to use for the operation.

         :raises ValueError: if the underlying bidi-GRPC stream is not open.
         :raises ValueError: if the length of read_ranges is more than 1000.
@@ -264,80 +258,101 @@ async def download_ranges(
         if lock is None:
             lock = asyncio.Lock()

-        _func_id = generate_random_56_bit_integer()
-        read_ids_in_current_func = set()
-        for i in range(0, len(read_ranges), _MAX_READ_RANGES_PER_BIDI_READ_REQUEST):
-            read_ranges_segment = read_ranges[
-                i : i + _MAX_READ_RANGES_PER_BIDI_READ_REQUEST
-            ]
+        if retry_policy is None:
+            retry_policy = AsyncRetry(
+                predicate=lambda e: isinstance(e, (exceptions.ServiceUnavailable, exceptions.DeadlineExceeded))
+            )
+
+        # Initialize Global State for Retry Strategy
+        download_states = {}
+        for read_range in read_ranges:
+            read_id = generate_random_56_bit_integer()
+            download_states[read_id] = _DownloadState(
+                initial_offset=read_range[0],
+                initial_length=read_range[1],
+                user_buffer=read_range[2]
+            )
+
+        initial_state = {
+            "download_states": download_states,
+            "read_handle": self.read_handle,
+            "routing_token": None
+        }
+
+        # Track attempts to manage stream reuse
+        is_first_attempt = True
+
+        def stream_opener(requests: List[_storage_v2.ReadRange], state: Dict[str, Any]):
+
+            async def generator():
+                nonlocal is_first_attempt
+
+                async with lock:
+                    current_handle = state.get("read_handle")
+                    current_token = state.get("routing_token")
+
+                    # We reopen if it's a redirect (token exists) OR if this is a retry
+                    # (not first attempt). This prevents trying to send data on a dead
+                    # stream from a previous failed attempt.
+                    should_reopen = (not is_first_attempt) or (current_token is not None)
+
+                    if should_reopen:
+                        # Close existing stream if any
+                        if self.read_obj_str:
+                            await self.read_obj_str.close()

-            read_ranges_for_bidi_req = []
-            for j, read_range in enumerate(read_ranges_segment):
-                read_id = generate_random_56_bit_integer()
-                read_ids_in_current_func.add(read_id)
-                self._read_id_to_download_ranges_id[read_id] = _func_id
-                self._read_id_to_writable_buffer_dict[read_id] = read_range[2]
-                bytes_requested = read_range[1]
-                read_ranges_for_bidi_req.append(
-                    _storage_v2.ReadRange(
-                        read_offset=read_range[0],
-                        read_length=bytes_requested,
-                        read_id=read_id,
-                    )
-                )
-            async with lock:
-                await self.read_obj_str.send(
-                    _storage_v2.BidiReadObjectRequest(
-                        read_ranges=read_ranges_for_bidi_req
-                    )
-                )
-        self._download_ranges_id_to_pending_read_ids[
-            _func_id
-        ] = read_ids_in_current_func
-
-        while len(self._download_ranges_id_to_pending_read_ids[_func_id]) > 0:
-            async with lock:
-                response = await self.read_obj_str.recv()
-
-            if response is None:
-                raise Exception("None response received, something went wrong.")
-
-            for object_data_range in response.object_data_ranges:
-                if object_data_range.read_range is None:
-                    raise Exception("Invalid response, read_range is None")
-
-                checksummed_data = object_data_range.checksummed_data
-                data = checksummed_data.content
-                server_checksum = checksummed_data.crc32c
-
-                client_crc32c = Checksum(data).digest()
-                client_checksum = int.from_bytes(client_crc32c, "big")
-
-                if server_checksum != client_checksum:
-                    raise DataCorruption(
-                        response,
-                        f"Checksum mismatch for read_id {object_data_range.read_range.read_id}. "
-                        f"Server sent {server_checksum}, client calculated {client_checksum}.",
-                    )
-
-                read_id = object_data_range.read_range.read_id
-                buffer = self._read_id_to_writable_buffer_dict[read_id]
-                buffer.write(data)
-
-                if object_data_range.range_end:
-                    tmp_dn_ranges_id = self._read_id_to_download_ranges_id[read_id]
-                    self._download_ranges_id_to_pending_read_ids[
-                        tmp_dn_ranges_id
-                    ].remove(read_id)
-                    del self._read_id_to_download_ranges_id[read_id]
+                        # Re-initialize stream
+                        self.read_obj_str = _AsyncReadObjectStream(
+                            client=self.client,
+                            bucket_name=self.bucket_name,
+                            object_name=self.object_name,
+                            generation_number=self.generation_number,
+                            read_handle=current_handle,
+                        )
+
+                        # Inject routing_token into metadata if present
+                        metadata = []
+                        if current_token:
+                            metadata.append(("x-goog-request-params", f"routing_token={current_token}"))
+
+                        await self.read_obj_str.open(metadata=metadata if metadata else None)
+                        self._is_stream_open = True
+
+                    # Mark first attempt as done; next time this runs it will be a retry
+                    is_first_attempt = False
+
+                    # Send Requests
+                    for i in range(0, len(requests), _MAX_READ_RANGES_PER_BIDI_READ_REQUEST):
+                        batch = requests[i : i + _MAX_READ_RANGES_PER_BIDI_READ_REQUEST]
+                        await self.read_obj_str.send(
+                            _storage_v2.BidiReadObjectRequest(read_ranges=batch)
+                        )
+
+                    while True:
+                        response = await self.read_obj_str.recv()
+                        if response is None:
+                            break
+                        yield response
+
+            return generator()
+
+        strategy = _ReadResumptionStrategy()
+        retry_manager = _BidiStreamRetryManager(strategy, stream_opener)
+
+        await retry_manager.execute(initial_state, retry_policy)
+
+        if initial_state.get("read_handle"):
+            self.read_handle = initial_state["read_handle"]

     async def close(self):
         """
         Closes the underlying bidi-gRPC connection.
         """
         if not self._is_stream_open:
             raise ValueError("Underlying bidi-gRPC stream is not open")
-        await self.read_obj_str.close()
+
+        if self.read_obj_str:
+            await self.read_obj_str.close()
         self._is_stream_open = False

     @property