diff --git a/google/cloud/storage/_experimental/asyncio/async_multi_range_downloader.py b/google/cloud/storage/_experimental/asyncio/async_multi_range_downloader.py index 16dce5025..e9b9414ee 100644 --- a/google/cloud/storage/_experimental/asyncio/async_multi_range_downloader.py +++ b/google/cloud/storage/_experimental/asyncio/async_multi_range_downloader.py @@ -15,10 +15,14 @@ from __future__ import annotations import asyncio import google_crc32c +import grpc from google.api_core import exceptions -from google_crc32c import Checksum +from google.api_core.retry_async import AsyncRetry +from google.cloud._storage_v2.types.storage import BidiReadObjectRedirectedError +from google.rpc import status_pb2 +from google.protobuf.any_pb2 import Any as AnyProto -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Any, Dict from google.cloud.storage._experimental.asyncio.async_read_object_stream import ( _AsyncReadObjectStream, @@ -26,43 +30,71 @@ from google.cloud.storage._experimental.asyncio.async_grpc_client import ( AsyncGrpcClient, ) +from google.cloud.storage._experimental.asyncio.retry.bidi_stream_retry_manager import ( + _BidiStreamRetryManager, +) +from google.cloud.storage._experimental.asyncio.retry.reads_resumption_strategy import ( + _ReadResumptionStrategy, + _DownloadState, +) from io import BytesIO from google.cloud import _storage_v2 -from google.cloud.storage.exceptions import DataCorruption from google.cloud.storage._helpers import generate_random_56_bit_integer _MAX_READ_RANGES_PER_BIDI_READ_REQUEST = 100 - - -class Result: - """An instance of this class will be populated and retured for each - `read_range` provided to ``download_ranges`` method. - - """ - - def __init__(self, bytes_requested: int): - # only while instantiation, should not be edited later. - # hence there's no setter, only getter is provided. - self._bytes_requested: int = bytes_requested - self._bytes_written: int = 0 - - @property - def bytes_requested(self) -> int: - return self._bytes_requested - - @property - def bytes_written(self) -> int: - return self._bytes_written - - @bytes_written.setter - def bytes_written(self, value: int): - self._bytes_written = value - - def __repr__(self): - return f"bytes_requested: {self._bytes_requested}, bytes_written: {self._bytes_written}" - +_BIDI_READ_REDIRECTED_TYPE_URL = "type.googleapis.com/google.storage.v2.BidiReadObjectRedirectedError" + + +def _is_read_retryable(exc): + """Predicate to determine if a read operation should be retried.""" + print(f"--- Checking if retryable: {type(exc)}: {exc}") + if isinstance(exc, (exceptions.ServiceUnavailable, exceptions.DeadlineExceeded, exceptions.TooManyRequests)): + return True + + grpc_error = None + if isinstance(exc, exceptions.GoogleAPICallError) and exc.errors: + if isinstance(exc.errors[0], grpc.aio.AioRpcError): + grpc_error = exc.errors[0] + + if grpc_error: + print(f"--- Wrapped grpc.aio.AioRpcError code: {grpc_error.code()}") + if grpc_error.code() in ( + grpc.StatusCode.UNAVAILABLE, + grpc.StatusCode.INTERNAL, + grpc.StatusCode.DEADLINE_EXCEEDED, + grpc.StatusCode.RESOURCE_EXHAUSTED, + ): + return True + if grpc_error.code() == grpc.StatusCode.ABORTED: + trailers = grpc_error.trailing_metadata() + if not trailers: + print("--- No trailers") + return False + + status_details_bin = None + # *** CORRECTED TRAILER ACCESS *** + for key, value in trailers: + if key == 'grpc-status-details-bin': + status_details_bin = value + break + + if status_details_bin: + status_proto = status_pb2.Status() + try: + status_proto.ParseFromString(status_details_bin) + for detail in status_proto.details: + if detail.type_url == _BIDI_READ_REDIRECTED_TYPE_URL: + print("--- Found BidiReadObjectRedirectedError, is retryable") + return True + print("--- BidiReadObjectRedirectedError type URL not found in details") + except Exception as e: + print(f"--- Error parsing status_details_bin: {e}") + return False + else: + print("--- No grpc-status-details-bin in trailers") + return False class AsyncMultiRangeDownloader: """Provides an interface for downloading multiple ranges of a GCS ``Object`` @@ -104,6 +136,7 @@ async def create_mrd( object_name: str, generation_number: Optional[int] = None, read_handle: Optional[bytes] = None, + retry_policy: Optional[AsyncRetry] = None, ) -> AsyncMultiRangeDownloader: """Initializes a MultiRangeDownloader and opens the underlying bidi-gRPC object for reading. @@ -125,11 +158,14 @@ async def create_mrd( :param read_handle: (Optional) An existing handle for reading the object. If provided, opening the bidi-gRPC connection will be faster. + :type retry_policy: :class:`~google.api_core.retry_async.AsyncRetry` + :param retry_policy: (Optional) The retry policy to use for the ``open`` operation. + :rtype: :class:`~google.cloud.storage._experimental.asyncio.async_multi_range_downloader.AsyncMultiRangeDownloader` :returns: An initialized AsyncMultiRangeDownloader instance for reading. """ mrd = cls(client, bucket_name, object_name, generation_number, read_handle) - await mrd.open() + await mrd.open(retry_policy=retry_policy) return mrd def __init__( @@ -176,24 +212,71 @@ def __init__( self.read_handle = read_handle self.read_obj_str: Optional[_AsyncReadObjectStream] = None self._is_stream_open: bool = False + self._routing_token: Optional[str] = None + + async def _on_open_error(self, exc): + """Extracts routing token and read handle on redirect error during open.""" + print(f"--- _on_open_error called with {type(exc)}: {exc}") + grpc_error = None + if isinstance(exc, exceptions.GoogleAPICallError) and exc.errors: + if isinstance(exc.errors[0], grpc.aio.AioRpcError): + grpc_error = exc.errors[0] + + if grpc_error and grpc_error.code() == grpc.StatusCode.ABORTED: + trailers = grpc_error.trailing_metadata() + if not trailers: return + + status_details_bin = None + # *** CORRECTED TRAILER ACCESS *** + for key, value in trailers: + if key == 'grpc-status-details-bin': + status_details_bin = value + break + + if status_details_bin: + status_proto = status_pb2.Status() + try: + status_proto.ParseFromString(status_details_bin) + for detail in status_proto.details: + if detail.type_url == _BIDI_READ_REDIRECTED_TYPE_URL: + redirect_proto = BidiReadObjectRedirectedError() + detail.Unpack(redirect_proto) + if redirect_proto.routing_token: + self._routing_token = redirect_proto.routing_token + if redirect_proto.read_handle and redirect_proto.read_handle.handle: + self.read_handle = redirect_proto.read_handle.handle + print(f"--- BidiReadObjectRedirectedError caught in open, new token: {self._routing_token}, handle: {self.read_handle}") + break + except Exception as e: + print(f"--- Error unpacking redirect in _on_open_error: {e}") + + if self.read_obj_str and self.read_obj_str._is_open: + try: + await self.read_obj_str.close() + except Exception: + pass + self._is_stream_open = False - self._read_id_to_writable_buffer_dict = {} - self._read_id_to_download_ranges_id = {} - self._download_ranges_id_to_pending_read_ids = {} - - async def open(self) -> None: - """Opens the bidi-gRPC connection to read from the object. - - This method initializes and opens an `_AsyncReadObjectStream` (bidi-gRPC stream) to - for downloading ranges of data from GCS ``Object``. - - "Opening" constitutes fetching object metadata such as generation number - and read handle and sets them as attributes if not already set. - """ + async def open(self, retry_policy: Optional[AsyncRetry] = None) -> None: + """Opens the bidi-gRPC connection to read from the object.""" if self._is_stream_open: raise ValueError("Underlying bidi-gRPC stream is already open") - if self.read_obj_str is None: + if retry_policy is None: + retry_policy = AsyncRetry(predicate=_is_read_retryable, on_error=self._on_open_error) + else: + original_on_error = retry_policy._on_error + async def combined_on_error(exc): + await self._on_open_error(exc) + if original_on_error: + await original_on_error(exc) + retry_policy = retry_policy.with_predicate(_is_read_retryable).with_on_error(combined_on_error) + + async def _do_open(): + print("--- Attempting _do_open") + if self._is_stream_open: + self._is_stream_open = False + self.read_obj_str = _AsyncReadObjectStream( client=self.client, bucket_name=self.bucket_name, @@ -201,18 +284,33 @@ async def open(self) -> None: generation_number=self.generation_number, read_handle=self.read_handle, ) - await self.read_obj_str.open() - self._is_stream_open = True - if self.generation_number is None: - self.generation_number = self.read_obj_str.generation_number - self.read_handle = self.read_obj_str.read_handle - return + + metadata = [] + if self._routing_token: + metadata.append(("x-goog-request-params", f"routing_token={self._routing_token}")) + print(f"--- Using routing_token for open: {self._routing_token}") + self._routing_token = None + + await self.read_obj_str.open(metadata=metadata if metadata else None) + + if self.read_obj_str.generation_number: + self.generation_number = self.read_obj_str.generation_number + if self.read_obj_str.read_handle: + self.read_handle = self.read_obj_str.read_handle + + self._is_stream_open = True + print("--- Stream opened successfully") + + await retry_policy(_do_open)() async def download_ranges( - self, read_ranges: List[Tuple[int, int, BytesIO]], lock: asyncio.Lock = None + self, + read_ranges: List[Tuple[int, int, BytesIO]], + lock: asyncio.Lock = None, + retry_policy: AsyncRetry = None ) -> None: """Downloads multiple byte ranges from the object into the buffers - provided by user. + provided by user with automatic retries. :type read_ranges: List[Tuple[int, int, "BytesIO"]] :param read_ranges: A list of tuples, where each tuple represents a @@ -246,6 +344,8 @@ async def download_ranges( ``` + :type retry_policy: :class:`~google.api_core.retry_async.AsyncRetry` + :param retry_policy: (Optional) The retry policy to use for the operation. :raises ValueError: if the underlying bidi-GRPC stream is not open. :raises ValueError: if the length of read_ranges is more than 1000. @@ -264,72 +364,89 @@ async def download_ranges( if lock is None: lock = asyncio.Lock() - _func_id = generate_random_56_bit_integer() - read_ids_in_current_func = set() - for i in range(0, len(read_ranges), _MAX_READ_RANGES_PER_BIDI_READ_REQUEST): - read_ranges_segment = read_ranges[ - i : i + _MAX_READ_RANGES_PER_BIDI_READ_REQUEST - ] + if retry_policy is None: + retry_policy = AsyncRetry(predicate=_is_read_retryable) + + # Initialize Global State for Retry Strategy + download_states = {} + for read_range in read_ranges: + read_id = generate_random_56_bit_integer() + download_states[read_id] = _DownloadState( + initial_offset=read_range[0], + initial_length=read_range[1], + user_buffer=read_range[2] + ) + + initial_state = { + "download_states": download_states, + "read_handle": self.read_handle, + "routing_token": None + } + + # Track attempts to manage stream reuse + is_first_attempt = True + + def stream_opener(requests: List[_storage_v2.ReadRange], state: Dict[str, Any]): + + async def generator(): + nonlocal is_first_attempt + + async with lock: + current_handle = state.get("read_handle") + current_token = state.get("routing_token") + + # We reopen if it's a redirect (token exists) OR if this is a retry + # (not first attempt). This prevents trying to send data on a dead + # stream from a previous failed attempt. + should_reopen = (not is_first_attempt) or (current_token is not None) - read_ranges_for_bidi_req = [] - for j, read_range in enumerate(read_ranges_segment): - read_id = generate_random_56_bit_integer() - read_ids_in_current_func.add(read_id) - self._read_id_to_download_ranges_id[read_id] = _func_id - self._read_id_to_writable_buffer_dict[read_id] = read_range[2] - bytes_requested = read_range[1] - read_ranges_for_bidi_req.append( - _storage_v2.ReadRange( - read_offset=read_range[0], - read_length=bytes_requested, - read_id=read_id, - ) - ) - async with lock: - await self.read_obj_str.send( - _storage_v2.BidiReadObjectRequest( - read_ranges=read_ranges_for_bidi_req - ) - ) - self._download_ranges_id_to_pending_read_ids[ - _func_id - ] = read_ids_in_current_func - - while len(self._download_ranges_id_to_pending_read_ids[_func_id]) > 0: - async with lock: - response = await self.read_obj_str.recv() - - if response is None: - raise Exception("None response received, something went wrong.") - - for object_data_range in response.object_data_ranges: - if object_data_range.read_range is None: - raise Exception("Invalid response, read_range is None") - - checksummed_data = object_data_range.checksummed_data - data = checksummed_data.content - server_checksum = checksummed_data.crc32c - - client_crc32c = Checksum(data).digest() - client_checksum = int.from_bytes(client_crc32c, "big") - - if server_checksum != client_checksum: - raise DataCorruption( - response, - f"Checksum mismatch for read_id {object_data_range.read_range.read_id}. " - f"Server sent {server_checksum}, client calculated {client_checksum}.", - ) - - read_id = object_data_range.read_range.read_id - buffer = self._read_id_to_writable_buffer_dict[read_id] - buffer.write(data) - - if object_data_range.range_end: - tmp_dn_ranges_id = self._read_id_to_download_ranges_id[read_id] - self._download_ranges_id_to_pending_read_ids[ - tmp_dn_ranges_id - ].remove(read_id) - del self._read_id_to_download_ranges_id[read_id] + if should_reopen: + # Close existing stream if any + if self.read_obj_str: + await self.read_obj_str.close() + + # Re-initialize stream + self.read_obj_str = _AsyncReadObjectStream( + client=self.client, + bucket_name=self.bucket_name, + object_name=self.object_name, + generation_number=self.generation_number, + read_handle=current_handle, + ) + + # Inject routing_token into metadata if present + metadata = [] + if current_token: + metadata.append(("x-goog-request-params", f"routing_token={current_token}")) + + await self.read_obj_str.open(metadata=metadata if metadata else None) + self._is_stream_open = True + + # Mark first attempt as done; next time this runs it will be a retry + is_first_attempt = False + + # Send Requests + for i in range(0, len(requests), _MAX_READ_RANGES_PER_BIDI_READ_REQUEST): + batch = requests[i : i + _MAX_READ_RANGES_PER_BIDI_READ_REQUEST] + await self.read_obj_str.send( + _storage_v2.BidiReadObjectRequest(read_ranges=batch) + ) + + while True: + response = await self.read_obj_str.recv() + if response is None: + break + yield response + + return generator() + + strategy = _ReadResumptionStrategy() + retry_manager = _BidiStreamRetryManager(strategy, stream_opener) + + await retry_manager.execute(initial_state, retry_policy) + + if initial_state.get("read_handle"): + self.read_handle = initial_state["read_handle"] async def close(self): """ @@ -337,7 +454,9 @@ async def close(self): """ if not self._is_stream_open: raise ValueError("Underlying bidi-gRPC stream is not open") - await self.read_obj_str.close() + + if self.read_obj_str: + await self.read_obj_str.close() self._is_stream_open = False @property diff --git a/google/cloud/storage/_experimental/asyncio/async_read_object_stream.py b/google/cloud/storage/_experimental/asyncio/async_read_object_stream.py index ddaaf9a54..2055d8b26 100644 --- a/google/cloud/storage/_experimental/asyncio/async_read_object_stream.py +++ b/google/cloud/storage/_experimental/asyncio/async_read_object_stream.py @@ -22,7 +22,7 @@ """ -from typing import Optional +from typing import List, Optional, Tuple from google.cloud import _storage_v2 from google.cloud.storage._experimental.asyncio.async_grpc_client import AsyncGrpcClient from google.cloud.storage._experimental.asyncio.async_abstract_object_stream import ( @@ -93,23 +93,54 @@ def __init__( self.socket_like_rpc: Optional[AsyncBidiRpc] = None self._is_stream_open: bool = False - async def open(self) -> None: + async def open(self, metadata: Optional[List[Tuple[str, str]]] = None) -> None: """Opens the bidi-gRPC connection to read from the object. This method sends an initial request to start the stream and receives the first response containing metadata and a read handle. + + Args: + metadata (Optional[List[Tuple[str, str]]]): Additional metadata + to send with the initial stream request, e.g., for routing tokens. """ if self._is_stream_open: raise ValueError("Stream is already open") + + read_object_spec = _storage_v2.BidiReadObjectSpec( + bucket=self._full_bucket_name, + object=self.object_name, + generation=self.generation_number if self.generation_number else None, + read_handle=self.read_handle if self.read_handle else None, + ) + initial_request = _storage_v2.BidiReadObjectRequest( + read_object_spec=read_object_spec + ) + + # Build the x-goog-request-params header + request_params = [f"bucket={self._full_bucket_name}"] + other_metadata = [] + if metadata: + for key, value in metadata: + if key == "x-goog-request-params": + request_params.append(value) + else: + other_metadata.append((key, value)) + + current_metadata = other_metadata + current_metadata.append(("x-goog-request-params", ",".join(request_params))) + self.socket_like_rpc = AsyncBidiRpc( - self.rpc, initial_request=self.first_bidi_read_req, metadata=self.metadata + self.rpc, initial_request=initial_request, metadata=current_metadata ) - await self.socket_like_rpc.open() # this is actually 1 send + await self.socket_like_rpc.open() response = await self.socket_like_rpc.recv() - if self.generation_number is None: - self.generation_number = response.metadata.generation - self.read_handle = response.read_handle + if response and response.metadata: + if self.generation_number is None and response.metadata.generation: + self.generation_number = response.metadata.generation + + if response and response.read_handle and response.read_handle.handle: + self.read_handle = response.read_handle.handle self._is_stream_open = True diff --git a/google/cloud/storage/_experimental/asyncio/retry/base_strategy.py b/google/cloud/storage/_experimental/asyncio/retry/base_strategy.py index e32125069..ff193f109 100644 --- a/google/cloud/storage/_experimental/asyncio/retry/base_strategy.py +++ b/google/cloud/storage/_experimental/asyncio/retry/base_strategy.py @@ -1,3 +1,17 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import abc from typing import Any, Iterable diff --git a/google/cloud/storage/_experimental/asyncio/retry/bidi_stream_retry_manager.py b/google/cloud/storage/_experimental/asyncio/retry/bidi_stream_retry_manager.py new file mode 100644 index 000000000..68abd1b21 --- /dev/null +++ b/google/cloud/storage/_experimental/asyncio/retry/bidi_stream_retry_manager.py @@ -0,0 +1,63 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, AsyncIterator, Callable + +from google.cloud.storage._experimental.asyncio.retry.base_strategy import ( + _BaseResumptionStrategy, +) + + +class _BidiStreamRetryManager: + """Manages the generic retry loop for a bidi streaming operation.""" + + def __init__( + self, + strategy: _BaseResumptionStrategy, + stream_opener: Callable[..., AsyncIterator[Any]], + ): + """Initializes the retry manager. + Args: + strategy: The strategy for managing the state of a specific + bidi operation (e.g., reads or writes). + stream_opener: An async callable that opens a new gRPC stream. + """ + self._strategy = strategy + self._stream_opener = stream_opener + + async def execute(self, initial_state: Any, retry_policy): + """ + Executes the bidi operation with the configured retry policy. + Args: + initial_state: An object containing all state for the operation. + retry_policy: The `google.api_core.retry.AsyncRetry` object to + govern the retry behavior for this specific operation. + """ + state = initial_state + + async def attempt(): + requests = self._strategy.generate_requests(state) + stream = self._stream_opener(requests, state) + try: + async for response in stream: + self._strategy.update_state_from_response(response, state) + return + except Exception as e: + if retry_policy._predicate(e): + await self._strategy.recover_state_on_failure(e, state) + raise e + + wrapped_attempt = retry_policy(attempt) + + await wrapped_attempt() diff --git a/google/cloud/storage/_experimental/asyncio/retry/reads_resumption_strategy.py b/google/cloud/storage/_experimental/asyncio/retry/reads_resumption_strategy.py index d5d080358..550d96368 100644 --- a/google/cloud/storage/_experimental/asyncio/retry/reads_resumption_strategy.py +++ b/google/cloud/storage/_experimental/asyncio/retry/reads_resumption_strategy.py @@ -1,5 +1,20 @@ -from typing import Any, List, IO +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, List, IO + +from google_crc32c import Checksum from google.cloud import _storage_v2 as storage_v2 from google.cloud.storage.exceptions import DataCorruption from google.cloud.storage._experimental.asyncio.retry.base_strategy import ( @@ -25,7 +40,7 @@ def __init__( class _ReadResumptionStrategy(_BaseResumptionStrategy): """The concrete resumption strategy for bidi reads.""" - def generate_requests(self, state: dict) -> List[storage_v2.ReadRange]: + def generate_requests(self, state: Dict[str, Any]) -> List[storage_v2.ReadRange]: """Generates new ReadRange requests for all incomplete downloads. :type state: dict @@ -33,10 +48,17 @@ def generate_requests(self, state: dict) -> List[storage_v2.ReadRange]: _DownloadState object. """ pending_requests = [] - for read_id, read_state in state.items(): + download_states: Dict[int, _DownloadState] = state["download_states"] + + for read_id, read_state in download_states.items(): if not read_state.is_complete: new_offset = read_state.initial_offset + read_state.bytes_written - new_length = read_state.initial_length - read_state.bytes_written + + # Calculate remaining length. If initial_length is 0 (read to end), + # it stays 0. Otherwise, subtract bytes_written. + new_length = 0 + if read_state.initial_length > 0: + new_length = read_state.initial_length - read_state.bytes_written new_request = storage_v2.ReadRange( read_offset=new_offset, @@ -47,19 +69,52 @@ def generate_requests(self, state: dict) -> List[storage_v2.ReadRange]: return pending_requests def update_state_from_response( - self, response: storage_v2.BidiReadObjectResponse, state: dict + self, response: storage_v2.BidiReadObjectResponse, state: Dict[str, Any] ) -> None: """Processes a server response, performs integrity checks, and updates state.""" + + # Capture read_handle if provided. + if response.read_handle and response.read_handle.handle: + state["read_handle"] = response.read_handle.handle + + download_states = state["download_states"] + for object_data_range in response.object_data_ranges: + # Ignore empty ranges or ranges for IDs not in our state + # (e.g., from a previously cancelled request on the same stream). + if not object_data_range.read_range: + continue + read_id = object_data_range.read_range.read_id - read_state = state[read_id] + if read_id not in download_states: + continue + + read_state = download_states[read_id] # Offset Verification chunk_offset = object_data_range.read_range.read_offset if chunk_offset != read_state.next_expected_offset: - raise DataCorruption(response, f"Offset mismatch for read_id {read_id}") + raise DataCorruption( + response, + f"Offset mismatch for read_id {read_id}. " + f"Expected {read_state.next_expected_offset}, got {chunk_offset}" + ) + # Checksum Verification + # We must validate data before updating state or writing to buffer. data = object_data_range.checksummed_data.content + server_checksum = object_data_range.checksummed_data.crc32c + + if server_checksum is not None: + client_checksum = int.from_bytes(Checksum(data).digest(), "big") + if server_checksum != client_checksum: + raise DataCorruption( + response, + f"Checksum mismatch for read_id {read_id}. " + f"Server sent {server_checksum}, client calculated {client_checksum}." + ) + + # Update State & Write Data chunk_size = len(data) read_state.bytes_written += chunk_size read_state.next_expected_offset += chunk_size @@ -73,7 +128,9 @@ def update_state_from_response( and read_state.bytes_written != read_state.initial_length ): raise DataCorruption( - response, f"Byte count mismatch for read_id {read_id}" + response, + f"Byte count mismatch for read_id {read_id}. " + f"Expected {read_state.initial_length}, got {read_state.bytes_written}" ) async def recover_state_on_failure(self, error: Exception, state: Any) -> None: @@ -83,3 +140,6 @@ async def recover_state_on_failure(self, error: Exception, state: Any) -> None: cause = getattr(error, "cause", error) if isinstance(cause, BidiReadObjectRedirectedError): state["routing_token"] = cause.routing_token + if cause.read_handle and cause.read_handle.handle: + state["read_handle"] = cause.read_handle.handle + print(f"Recover state: Updated read_handle from redirect: {state['read_handle']}") diff --git a/run_bidi_reads_integration_test.py b/run_bidi_reads_integration_test.py new file mode 100644 index 000000000..8bd2d68da --- /dev/null +++ b/run_bidi_reads_integration_test.py @@ -0,0 +1,223 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import hashlib +import logging +import os +import random +import subprocess +import time +import requests +import grpc +from io import BytesIO + +# Configure Logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logger = logging.getLogger("bidi_integration_test") + +# --- Configuration --- +TESTBENCH_PORT = 9000 +TESTBENCH_HOST = f"localhost:{TESTBENCH_PORT}" +BUCKET_NAME = f"bidi-retry-bucket-{random.randint(1000, 9999)}" +OBJECT_NAME = "test-blob-10mb" +OBJECT_SIZE = 10 * 1024 * 1024 # 10 MiB + +# --- Imports from SDK --- +from google.cloud.storage._experimental.asyncio.async_grpc_client import AsyncGrpcClient +from google.cloud.storage._experimental.asyncio.async_multi_range_downloader import AsyncMultiRangeDownloader +from google.cloud.storage._experimental.asyncio.async_read_object_stream import _AsyncReadObjectStream + +# --- Infrastructure Management --- + +def start_testbench(): + """Starts the storage-testbench using Docker.""" + logger.info("Starting Storage Testbench container...") + try: + # Check if already running + requests.get(f"http://{TESTBENCH_HOST}/") + logger.info("Testbench is already running.") + return None + except requests.ConnectionError: + pass + + cmd = [ + "docker", "run", "-d", "--rm", + "-p", f"{TESTBENCH_PORT}:{TESTBENCH_PORT}", + "gcr.io/google.com/cloudsdktool/cloud-sdk:latest", + "gcloud", "beta", "emulators", "storage", "start", + f"--host-port=0.0.0.0:{TESTBENCH_PORT}" + ] + process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + + # Wait for it to be ready + for _ in range(20): + try: + requests.get(f"http://{TESTBENCH_HOST}/") + logger.info("Testbench started successfully.") + return process + except requests.ConnectionError: + time.sleep(1) + + raise RuntimeError("Timed out waiting for Testbench to start.") + +def stop_testbench(process): + if process: + logger.info("Stopping Testbench container...") + subprocess.run(["docker", "stop", process.args[2]]) # Stop container ID (not robust, assumes simple run) + # Better: Since we used --rm, killing the python process might not kill docker immediately + # without capturing container ID. + # For simplicity in this script, we assume the user might manually clean up if this fails, + # or we just rely on standard docker commands. + # Actually, let's just kill the container by image name or port if needed later. + pass + +# --- Test Data Setup --- + +def setup_resources(): + """Creates bucket and object via HTTP.""" + logger.info(f"Creating resources on {TESTBENCH_HOST}...") + + # 1. Create Bucket + resp = requests.post( + f"http://{TESTBENCH_HOST}/storage/v1/b?project=test-project", + json={"name": BUCKET_NAME} + ) + if resp.status_code not in (200, 409): + raise RuntimeError(f"Bucket creation failed: {resp.text}") + + # 2. Upload Object + data = os.urandom(OBJECT_SIZE) + resp = requests.post( + f"http://{TESTBENCH_HOST}/upload/storage/v1/b/{BUCKET_NAME}/o?uploadType=media&name={OBJECT_NAME}", + data=data, + headers={"Content-Type": "application/octet-stream"} + ) + if resp.status_code != 200: + raise RuntimeError(f"Object upload failed: {resp.text}") + + return data + +# --- Fault Injection Logic --- + +def inject_failure_instruction(test_case): + """ + Monkeypatches _AsyncReadObjectStream.open to inject x-goog-testbench-instructions. + + Supported test_cases: + - 'broken-stream': Aborts stream mid-way. + - 'stall-always': Stalls immediately (timeout simulation). + - 'transient-error': Returns an error status code. + """ + real_open = _AsyncReadObjectStream.open + attempt_counter = 0 + + async def monkeypatched_open(self, metadata=None): + nonlocal attempt_counter + attempt_counter += 1 + + if metadata is None: + metadata = [] + else: + metadata = list(metadata) + + # Inject fault only on the first attempt + if attempt_counter == 1: + instruction = "" + if test_case == 'broken-stream': + instruction = "return-broken-stream" + elif test_case == 'transient-error': + instruction = "return-503-after-256K" # Simulate Service Unavailable later + + if instruction: + logger.info(f">>> INJECTING FAULT: '{instruction}' <<<") + metadata.append(("x-goog-testbench-instructions", instruction)) + else: + logger.info(f">>> Attempt {attempt_counter}: Clean retry <<<") + + await real_open(self, metadata=metadata) + + _AsyncReadObjectStream.open = monkeypatched_open + return real_open + +# --- Main Test Runner --- + +async def run_tests(): + # 1. Start Infrastructure + tb_process = start_testbench() + + try: + # 2. Setup Data + original_data = setup_resources() + + # 3. Setup Client + channel = grpc.aio.insecure_channel(TESTBENCH_HOST) + client = AsyncGrpcClient(channel=channel) + + # Test Scenarios + scenarios = ['broken-stream', 'transient-error'] + + for scenario in scenarios: + logger.info(f"\n--- Running Scenario: {scenario} ---") + + # Reset MRD state + mrd = await AsyncMultiRangeDownloader.create_mrd( + client=client.grpc_client, + bucket_name=BUCKET_NAME, + object_name=OBJECT_NAME + ) + + # Apply Fault Injection + original_open_method = inject_failure_instruction(scenario) + + # Buffers + b1 = BytesIO() + b2 = BytesIO() + + # Split ranges + mid = OBJECT_SIZE // 2 + ranges = [(0, mid, b1), (mid, OBJECT_SIZE - mid, b2)] + + try: + await mrd.download_ranges(ranges) + logger.info(f"Scenario {scenario}: Download call returned successfully.") + + # Verify Content + downloaded = b1.getvalue() + b2.getvalue() + if downloaded == original_data: + logger.info(f"Scenario {scenario}: PASSED - Data integrity verified.") + else: + logger.error(f"Scenario {scenario}: FAILED - Data mismatch.") + + except Exception as e: + logger.error(f"Scenario {scenario}: FAILED with exception: {e}") + finally: + # Cleanup and Restore + _AsyncReadObjectStream.open = original_open_method + await mrd.close() + + finally: + # Stop Infrastructure (if we started it) + # Note: In a real script, we'd be more rigorous about finding the PID/Container ID + if tb_process: + logger.info("Killing Testbench process...") + tb_process.kill() + +if __name__ == "__main__": + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete(run_tests()) + except KeyboardInterrupt: + pass diff --git a/test_bidi_reads.py b/test_bidi_reads.py new file mode 100644 index 000000000..b094c6be7 --- /dev/null +++ b/test_bidi_reads.py @@ -0,0 +1,57 @@ +import asyncio +from google.cloud.storage._experimental.asyncio.async_appendable_object_writer import ( + AsyncAppendableObjectWriter, +) +from google.cloud.storage._experimental.asyncio.async_multi_range_downloader import ( + AsyncMultiRangeDownloader, +) +from google.cloud.storage._experimental.asyncio.async_grpc_client import ( + AsyncGrpcClient, +) +from io import BytesIO +import os +import time +import uuid + + +async def write_appendable_object_and_read_using_mrd(): + + client = AsyncGrpcClient().grpc_client + bucket_name = "chandrasiri-rs" + object_name = f"11Dec.100.3" + # data_to_append = os.urandom(10 * 1024 * 1024 + 1) # 10 MiB + 1 of random data + + # # 1. Write to an appendable object + # writer = AsyncAppendableObjectWriter(client, bucket_name, object_name) + # await writer.open() + # print(f"Opened writer for object: {object_name}, generation: {writer.generation}") + + # start_write_time = time.monotonic_ns() + # await writer.append(data_to_append) + # end_write_time = time.monotonic_ns() + # print( + # f"Appended {len(data_to_append)} bytes in " + # f"{(end_write_time - start_write_time) / 1_000_000:.2f} ms" + # ) + + # await writer.close(finalize_on_close=False) + + # 2. Read the object using AsyncMultiRangeDownloader + mrd = AsyncMultiRangeDownloader(client, bucket_name, object_name) + await mrd.open() + print(f"Opened downloader for object: {object_name}") + + # Define a single range to download the entire object + output_buffer = BytesIO() + download_ranges = [(0, 100*1000*1000, output_buffer)] + + await mrd.download_ranges(download_ranges) + for _, buffer in mrd._read_id_to_writable_buffer_dict.items(): + print("*" * 80) + print(buffer.getbuffer().nbytes) + print("*" * 80) + await mrd.close() + + +if __name__ == "__main__": + asyncio.run(write_appendable_object_and_read_using_mrd()) diff --git a/test_retry.py b/test_retry.py new file mode 100644 index 000000000..cfef6cd50 --- /dev/null +++ b/test_retry.py @@ -0,0 +1,95 @@ +# test_retry.py (Minimal Diagnostic Version) + +import asyncio +import docker +import time +import uuid + +from google.api_core import exceptions +from google.cloud import _storage_v2 as storage_v2 +from google.cloud.storage._experimental.asyncio.async_grpc_client import AsyncGrpcClient + +# --- Configuration --- +TESTBENCH_IMAGE = "gcr.io/cloud-devrel-public-resources/storage-testbench:latest" +PROJECT_NUMBER = "30215529953" + +async def main(): + docker_client = docker.from_env() + container = None + bucket_name = f"minimal-test-bucket-{uuid.uuid4().hex[:8]}" + object_name = "minimal-object" + + print("--- Minimal Write/Read Integration Test ---") + + try: + # 1. Start Testbench + print("Starting storage-testbench container...") + container = docker_client.containers.run( + TESTBENCH_IMAGE, detach=True, ports={"9000/tcp": 9000} + ) + time.sleep(3) + print(f"Testbench container {container.short_id} is running.") + + # 2. Create Client + client_options = {"api_endpoint": "localhost:9000"} + grpc_client = AsyncGrpcClient(client_options=client_options) + gapic_client = grpc_client._grpc_client + + # 3. Create Bucket + print(f"Creating test bucket gs://{bucket_name}...") + bucket_resource = storage_v2.Bucket(project=f"projects/{PROJECT_NUMBER}") + create_bucket_request = storage_v2.CreateBucketRequest( + parent="projects/_", bucket_id=bucket_name, bucket=bucket_resource + ) + await gapic_client.create_bucket(request=create_bucket_request) + print("Bucket created successfully.") + + # 4. Write Object + print(f"Creating test object gs://{bucket_name}/{object_name}...") + write_spec = storage_v2.WriteObjectSpec( + resource=storage_v2.Object(bucket=f"projects/_/buckets/{bucket_name}", name=object_name) + ) + + async def write_request_generator(): + yield storage_v2.WriteObjectRequest(write_object_spec=write_spec) + yield storage_v2.WriteObjectRequest( + checksummed_data={"content": b"test data"}, + finish_write=True + ) + + # CRITICAL: Capture and inspect the response from the write operation. + write_response = await gapic_client.write_object(requests=write_request_generator()) + print(f"Write operation completed. Response from server: {write_response}") + + # The `write_object` RPC only returns a resource on the *final* message of a stream. + # If this is not present, the object was not finalized correctly. + if not write_response.resource: + print("\n!!! CRITICAL FAILURE: The write response did not contain a finalized resource. The object may not have been created correctly. !!!") + raise ValueError("Object creation failed silently on the server.") + + print("Test object appears to be finalized successfully.") + + # 5. Attempt to Read the Object Metadata + print("\nAttempting to read the object's metadata back immediately...") + get_object_request = storage_v2.GetObjectRequest( + bucket=f"projects/_/buckets/{bucket_name}", + object=object_name, + ) + read_object = await gapic_client.get_object(request=get_object_request) + print("--- SUCCESS: Object read back successfully. ---") + print(f"Read object metadata: {read_object}") + + except Exception as e: + import traceback + print("\n!!! TEST FAILED. The original error is below: !!!") + traceback.print_exc() + finally: + # 6. Cleanup + if container: + print("Stopping and removing testbench container...") + container.stop() + container.remove() + print("Cleanup complete.") + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tests/unit/asyncio/retry/test_bidi_stream_retry_manager.py b/tests/unit/asyncio/retry/test_bidi_stream_retry_manager.py new file mode 100644 index 000000000..56737f8a5 --- /dev/null +++ b/tests/unit/asyncio/retry/test_bidi_stream_retry_manager.py @@ -0,0 +1,156 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import mock + +import pytest +from google.api_core import exceptions +from google.api_core.retry_async import AsyncRetry + +from google.cloud.storage._experimental.asyncio.retry import ( + bidi_stream_retry_manager as manager, +) +from google.cloud.storage._experimental.asyncio.retry import base_strategy + + +def _is_retriable(exc): + return isinstance(exc, exceptions.ServiceUnavailable) + + +DEFAULT_TEST_RETRY = AsyncRetry(predicate=_is_retriable, deadline=1) + + +class TestBidiStreamRetryManager: + @pytest.mark.asyncio + async def test_execute_success_on_first_try(self): + mock_strategy = mock.AsyncMock(spec=base_strategy._BaseResumptionStrategy) + + async def mock_stream_opener(*args, **kwargs): + yield "response_1" + + retry_manager = manager._BidiStreamRetryManager( + strategy=mock_strategy, stream_opener=mock_stream_opener + ) + await retry_manager.execute(initial_state={}, retry_policy=DEFAULT_TEST_RETRY) + mock_strategy.generate_requests.assert_called_once() + mock_strategy.update_state_from_response.assert_called_once_with( + "response_1", {} + ) + mock_strategy.recover_state_on_failure.assert_not_called() + + @pytest.mark.asyncio + async def test_execute_success_on_empty_stream(self): + mock_strategy = mock.AsyncMock(spec=base_strategy._BaseResumptionStrategy) + + async def mock_stream_opener(*args, **kwargs): + if False: + yield + + retry_manager = manager._BidiStreamRetryManager( + strategy=mock_strategy, stream_opener=mock_stream_opener + ) + await retry_manager.execute(initial_state={}, retry_policy=DEFAULT_TEST_RETRY) + + mock_strategy.generate_requests.assert_called_once() + mock_strategy.update_state_from_response.assert_not_called() + mock_strategy.recover_state_on_failure.assert_not_called() + + @pytest.mark.asyncio + async def test_execute_retries_on_initial_failure_and_succeeds(self): + mock_strategy = mock.AsyncMock(spec=base_strategy._BaseResumptionStrategy) + attempt_count = 0 + + async def mock_stream_opener(*args, **kwargs): + nonlocal attempt_count + attempt_count += 1 + if attempt_count == 1: + raise exceptions.ServiceUnavailable("Service is down") + else: + yield "response_2" + + retry_manager = manager._BidiStreamRetryManager( + strategy=mock_strategy, stream_opener=mock_stream_opener + ) + retry_policy = AsyncRetry(predicate=_is_retriable, initial=0.01) + + with mock.patch("asyncio.sleep", new_callable=mock.AsyncMock): + await retry_manager.execute(initial_state={}, retry_policy=retry_policy) + + assert attempt_count == 2 + assert mock_strategy.generate_requests.call_count == 2 + mock_strategy.recover_state_on_failure.assert_called_once() + mock_strategy.update_state_from_response.assert_called_once_with( + "response_2", {} + ) + + @pytest.mark.asyncio + async def test_execute_retries_and_succeeds_mid_stream(self): + """Test retry logic for a stream that fails after yielding some data.""" + mock_strategy = mock.AsyncMock(spec=base_strategy._BaseResumptionStrategy) + attempt_count = 0 + # Use a list to simulate stream content for each attempt + stream_content = [ + ["response_1", exceptions.ServiceUnavailable("Service is down")], + ["response_2"], + ] + + async def mock_stream_opener(*args, **kwargs): + nonlocal attempt_count + content = stream_content[attempt_count] + attempt_count += 1 + for item in content: + if isinstance(item, Exception): + raise item + else: + yield item + + retry_manager = manager._BidiStreamRetryManager( + strategy=mock_strategy, stream_opener=mock_stream_opener + ) + retry_policy = AsyncRetry(predicate=_is_retriable, initial=0.01) + + with mock.patch("asyncio.sleep", new_callable=mock.AsyncMock) as mock_sleep: + await retry_manager.execute(initial_state={}, retry_policy=retry_policy) + + assert attempt_count == 2 + mock_sleep.assert_called_once() + + assert mock_strategy.generate_requests.call_count == 2 + mock_strategy.recover_state_on_failure.assert_called_once() + assert mock_strategy.update_state_from_response.call_count == 2 + mock_strategy.update_state_from_response.assert_has_calls( + [ + mock.call("response_1", {}), + mock.call("response_2", {}), + ] + ) + + @pytest.mark.asyncio + async def test_execute_fails_immediately_on_non_retriable_error(self): + mock_strategy = mock.AsyncMock(spec=base_strategy._BaseResumptionStrategy) + + async def mock_stream_opener(*args, **kwargs): + if False: + yield + raise exceptions.PermissionDenied("Auth error") + + retry_manager = manager._BidiStreamRetryManager( + strategy=mock_strategy, stream_opener=mock_stream_opener + ) + with pytest.raises(exceptions.PermissionDenied): + await retry_manager.execute( + initial_state={}, retry_policy=DEFAULT_TEST_RETRY + ) + + mock_strategy.recover_state_on_failure.assert_not_called() diff --git a/tests/unit/asyncio/retry/test_reads_resumption_strategy.py b/tests/unit/asyncio/retry/test_reads_resumption_strategy.py index e6b343f86..1b2649527 100644 --- a/tests/unit/asyncio/retry/test_reads_resumption_strategy.py +++ b/tests/unit/asyncio/retry/test_reads_resumption_strategy.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import io import unittest -import pytest +from google_crc32c import Checksum from google.cloud.storage.exceptions import DataCorruption from google.api_core import exceptions @@ -45,14 +46,63 @@ def test_initialization(self): class TestReadResumptionStrategy(unittest.TestCase): + + def setUp(self): + self.strategy = _ReadResumptionStrategy() + + self.state = { + "download_states": {}, + "read_handle": None, + "routing_token": None + } + + def _add_download(self, read_id, offset=0, length=100, buffer=None): + """Helper to inject a download state into the correct nested location.""" + if buffer is None: + buffer = io.BytesIO() + state = _DownloadState( + initial_offset=offset, initial_length=length, user_buffer=buffer + ) + self.state["download_states"][read_id] = state + return state + + def _create_response(self, content, read_id, offset, crc=None, range_end=False, handle=None, has_read_range=True): + """Helper to create a response object.""" + checksummed_data = None + if content is not None: + if crc is None: + c = Checksum(content) + crc = int.from_bytes(c.digest(), "big") + checksummed_data = storage_v2.ChecksummedData(content=content, crc32c=crc) + + read_range = None + if has_read_range: + read_range = storage_v2.ReadRange(read_id=read_id, read_offset=offset) + + read_handle_message = None + if handle: + read_handle_message = storage_v2.BidiReadHandle(handle=handle) + self.state["read_handle"] = handle + + return storage_v2.BidiReadObjectResponse( + object_data_ranges=[ + storage_v2.ObjectRangeData( + checksummed_data=checksummed_data, + read_range=read_range, + range_end=range_end, + ) + ], + read_handle=read_handle_message, + ) + + # --- Request Generation Tests --- + def test_generate_requests_single_incomplete(self): """Test generating a request for a single incomplete download.""" - read_state = _DownloadState(0, 100, io.BytesIO()) + read_state = self._add_download(_READ_ID, offset=0, length=100) read_state.bytes_written = 20 - state = {_READ_ID: read_state} - read_strategy = _ReadResumptionStrategy() - requests = read_strategy.generate_requests(state) + requests = self.strategy.generate_requests(self.state) self.assertEqual(len(requests), 1) self.assertEqual(requests[0].read_offset, 20) @@ -62,173 +112,215 @@ def test_generate_requests_single_incomplete(self): def test_generate_requests_multiple_incomplete(self): """Test generating requests for multiple incomplete downloads.""" read_id2 = 2 - read_state1 = _DownloadState(0, 100, io.BytesIO()) - read_state1.bytes_written = 50 - read_state2 = _DownloadState(200, 100, io.BytesIO()) - state = {_READ_ID: read_state1, read_id2: read_state2} + rs1 = self._add_download(_READ_ID, offset=0, length=100) + rs1.bytes_written = 50 + + self._add_download(read_id2, offset=200, length=100) - read_strategy = _ReadResumptionStrategy() - requests = read_strategy.generate_requests(state) + requests = self.strategy.generate_requests(self.state) self.assertEqual(len(requests), 2) - req1 = next(request for request in requests if request.read_id == _READ_ID) - req2 = next(request for request in requests if request.read_id == read_id2) + requests.sort(key=lambda r: r.read_id) + req1 = requests[0] + req2 = requests[1] + + self.assertEqual(req1.read_id, _READ_ID) self.assertEqual(req1.read_offset, 50) self.assertEqual(req1.read_length, 50) + + self.assertEqual(req2.read_id, read_id2) self.assertEqual(req2.read_offset, 200) self.assertEqual(req2.read_length, 100) + def test_generate_requests_read_to_end_resumption(self): + """Test resumption for 'read to end' (length=0) requests.""" + read_state = self._add_download(_READ_ID, offset=0, length=0) + read_state.bytes_written = 500 + + requests = self.strategy.generate_requests(self.state) + + self.assertEqual(len(requests), 1) + self.assertEqual(requests[0].read_offset, 500) + self.assertEqual(requests[0].read_length, 0) + def test_generate_requests_with_complete(self): """Test that no request is generated for a completed download.""" - read_state = _DownloadState(0, 100, io.BytesIO()) + read_state = self._add_download(_READ_ID) read_state.is_complete = True - state = {_READ_ID: read_state} - - read_strategy = _ReadResumptionStrategy() - requests = read_strategy.generate_requests(state) + requests = self.strategy.generate_requests(self.state) self.assertEqual(len(requests), 0) + def test_generate_requests_multiple_mixed_states(self): + """Test generating requests with mixed complete, partial, and fresh states.""" + s1 = self._add_download(1, length=100) + s1.is_complete = True + + s2 = self._add_download(2, offset=0, length=100) + s2.bytes_written = 50 + + s3 = self._add_download(3, offset=200, length=100) + s3.bytes_written = 0 + + requests = self.strategy.generate_requests(self.state) + + self.assertEqual(len(requests), 2) + requests.sort(key=lambda r: r.read_id) + + self.assertEqual(requests[0].read_id, 2) + self.assertEqual(requests[1].read_id, 3) + def test_generate_requests_empty_state(self): """Test generating requests with an empty state.""" - read_strategy = _ReadResumptionStrategy() - requests = read_strategy.generate_requests({}) + requests = self.strategy.generate_requests(self.state) self.assertEqual(len(requests), 0) + # --- Update State and response processing Tests --- + def test_update_state_processes_single_chunk_successfully(self): """Test updating state from a successful response.""" - buffer = io.BytesIO() - read_state = _DownloadState(0, 100, buffer) - state = {_READ_ID: read_state} + read_state = self._add_download(_READ_ID, offset=0, length=100) data = b"test_data" - read_strategy = _ReadResumptionStrategy() - response = storage_v2.BidiReadObjectResponse( - object_data_ranges=[ - storage_v2.types.ObjectRangeData( - read_range=storage_v2.ReadRange( - read_id=_READ_ID, read_offset=0, read_length=len(data) - ), - checksummed_data=storage_v2.ChecksummedData(content=data), - ) - ] - ) + response = self._create_response(data, _READ_ID, offset=0) - read_strategy.update_state_from_response(response, state) + self.strategy.update_state_from_response(response, self.state) self.assertEqual(read_state.bytes_written, len(data)) self.assertEqual(read_state.next_expected_offset, len(data)) self.assertFalse(read_state.is_complete) - self.assertEqual(buffer.getvalue(), data) + self.assertEqual(read_state.user_buffer.getvalue(), data) + + def test_update_state_accumulates_chunks(self): + """Verify that state updates correctly over multiple chunks.""" + read_state = self._add_download(_READ_ID, offset=0, length=8) + + resp1 = self._create_response(b"test", _READ_ID, offset=0) + self.strategy.update_state_from_response(resp1, self.state) + + self.assertEqual(read_state.bytes_written, 4) + self.assertEqual(read_state.user_buffer.getvalue(), b"test") + + resp2 = self._create_response(b"data", _READ_ID, offset=4, range_end=True) + self.strategy.update_state_from_response(resp2, self.state) + + self.assertEqual(read_state.bytes_written, 8) + self.assertTrue(read_state.is_complete) + self.assertEqual(read_state.user_buffer.getvalue(), b"testdata") - def test_update_state_from_response_offset_mismatch(self): + def test_update_state_captures_read_handle(self): + """Verify read_handle is extracted from the response.""" + self._add_download(_READ_ID) + + new_handle = b"optimized_handle" + response = self._create_response(b"data", _READ_ID, 0, handle=new_handle) + + self.strategy.update_state_from_response(response, self.state) + self.assertEqual(self.state["read_handle"], new_handle) + + def test_update_state_unknown_id(self): + """Verify we ignore data for IDs not in our tracking state.""" + self._add_download(_READ_ID) + response = self._create_response(b"ghost", read_id=999, offset=0) + + self.strategy.update_state_from_response(response, self.state) + self.assertEqual(self.state["download_states"][_READ_ID].bytes_written, 0) + + def test_update_state_missing_read_range(self): + """Verify we ignore ranges without read_range metadata.""" + response = self._create_response(b"data", _READ_ID, 0, has_read_range=False) + self.strategy.update_state_from_response(response, self.state) + + def test_update_state_offset_mismatch(self): """Test that an offset mismatch raises DataCorruption.""" - read_state = _DownloadState(0, 100, io.BytesIO()) + read_state = self._add_download(_READ_ID, offset=0) read_state.next_expected_offset = 10 - state = {_READ_ID: read_state} - read_strategy = _ReadResumptionStrategy() - response = storage_v2.BidiReadObjectResponse( - object_data_ranges=[ - storage_v2.types.ObjectRangeData( - read_range=storage_v2.ReadRange( - read_id=_READ_ID, read_offset=0, read_length=4 - ), - checksummed_data=storage_v2.ChecksummedData(content=b"data"), - ) - ] - ) + response = self._create_response(b"data", _READ_ID, offset=0) - with pytest.raises(DataCorruption) as exc_info: - read_strategy.update_state_from_response(response, state) - assert "Offset mismatch" in str(exc_info.value) + with self.assertRaisesRegex(DataCorruption, "Offset mismatch"): + self.strategy.update_state_from_response(response, self.state) - def test_update_state_from_response_final_byte_count_mismatch(self): - """Test that a final byte count mismatch raises DataCorruption.""" - read_state = _DownloadState(0, 100, io.BytesIO()) - state = {_READ_ID: read_state} - read_strategy = _ReadResumptionStrategy() + def test_update_state_checksum_mismatch(self): + """Test that a CRC32C mismatch raises DataCorruption.""" + self._add_download(_READ_ID) + response = self._create_response(b"data", _READ_ID, offset=0, crc=999999) - response = storage_v2.BidiReadObjectResponse( - object_data_ranges=[ - storage_v2.types.ObjectRangeData( - read_range=storage_v2.ReadRange( - read_id=_READ_ID, read_offset=0, read_length=4 - ), - checksummed_data=storage_v2.ChecksummedData(content=b"data"), - range_end=True, - ) - ] - ) + with self.assertRaisesRegex(DataCorruption, "Checksum mismatch"): + self.strategy.update_state_from_response(response, self.state) - with pytest.raises(DataCorruption) as exc_info: - read_strategy.update_state_from_response(response, state) - assert "Byte count mismatch" in str(exc_info.value) + def test_update_state_final_byte_count_mismatch(self): + """Test mismatch between expected length and actual bytes written on completion.""" + self._add_download(_READ_ID, length=100) - def test_update_state_from_response_completes_download(self): + response = self._create_response(b"data", _READ_ID, offset=0, range_end=True) + + with self.assertRaisesRegex(DataCorruption, "Byte count mismatch"): + self.strategy.update_state_from_response(response, self.state) + + def test_update_state_completes_download(self): """Test that the download is marked complete on range_end.""" - buffer = io.BytesIO() data = b"test_data" - read_state = _DownloadState(0, len(data), buffer) - state = {_READ_ID: read_state} - read_strategy = _ReadResumptionStrategy() + read_state = self._add_download(_READ_ID, length=len(data)) - response = storage_v2.BidiReadObjectResponse( - object_data_ranges=[ - storage_v2.types.ObjectRangeData( - read_range=storage_v2.ReadRange( - read_id=_READ_ID, read_offset=0, read_length=len(data) - ), - checksummed_data=storage_v2.ChecksummedData(content=data), - range_end=True, - ) - ] - ) + response = self._create_response(data, _READ_ID, offset=0, range_end=True) - read_strategy.update_state_from_response(response, state) + self.strategy.update_state_from_response(response, self.state) self.assertTrue(read_state.is_complete) self.assertEqual(read_state.bytes_written, len(data)) - self.assertEqual(buffer.getvalue(), data) - def test_update_state_from_response_completes_download_zero_length(self): + def test_update_state_completes_download_zero_length(self): """Test completion for a download with initial_length of 0.""" - buffer = io.BytesIO() + read_state = self._add_download(_READ_ID, length=0) data = b"test_data" - read_state = _DownloadState(0, 0, buffer) - state = {_READ_ID: read_state} - read_strategy = _ReadResumptionStrategy() - response = storage_v2.BidiReadObjectResponse( - object_data_ranges=[ - storage_v2.types.ObjectRangeData( - read_range=storage_v2.ReadRange( - read_id=_READ_ID, read_offset=0, read_length=len(data) - ), - checksummed_data=storage_v2.ChecksummedData(content=data), - range_end=True, - ) - ] - ) + response = self._create_response(data, _READ_ID, offset=0, range_end=True) - read_strategy.update_state_from_response(response, state) + self.strategy.update_state_from_response(response, self.state) self.assertTrue(read_state.is_complete) self.assertEqual(read_state.bytes_written, len(data)) - async def test_recover_state_on_failure_handles_redirect(self): - """Verify recover_state_on_failure correctly extracts routing_token.""" - strategy = _ReadResumptionStrategy() + def test_update_state_zero_byte_file(self): + """Test downloading a completely empty file.""" + read_state = self._add_download(_READ_ID, length=0) + + response = self._create_response(b"", _READ_ID, offset=0, range_end=True) - state = {} - self.assertIsNone(state.get("routing_token")) + self.strategy.update_state_from_response(response, self.state) + + self.assertTrue(read_state.is_complete) + self.assertEqual(read_state.bytes_written, 0) + self.assertEqual(read_state.user_buffer.getvalue(), b"") - dummy_token = "dummy-routing-token" - redirect_error = BidiReadObjectRedirectedError(routing_token=dummy_token) + # --- Recovery Tests --- + def test_recover_state_on_failure_handles_redirect(self): + """Verify recover_state_on_failure correctly extracts routing_token.""" + token = "dummy-routing-token" + redirect_error = BidiReadObjectRedirectedError(routing_token=token) final_error = exceptions.RetryError("Retry failed", cause=redirect_error) - await strategy.recover_state_on_failure(final_error, state) + async def run(): + await self.strategy.recover_state_on_failure(final_error, self.state) + + asyncio.new_event_loop().run_until_complete(run()) + + self.assertEqual(self.state["routing_token"], token) + + def test_recover_state_ignores_standard_errors(self): + """Verify that non-redirect errors do not corrupt the routing token.""" + self.state["routing_token"] = "existing-token" + + std_error = exceptions.ServiceUnavailable("Maintenance") + final_error = exceptions.RetryError("Retry failed", cause=std_error) + + async def run(): + await self.strategy.recover_state_on_failure(final_error, self.state) + + asyncio.new_event_loop().run_until_complete(run()) - self.assertEqual(state.get("routing_token"), dummy_token) + # Token should remain unchanged + self.assertEqual(self.state["routing_token"], "existing-token") diff --git a/tests/unit/asyncio/test_async_multi_range_downloader.py b/tests/unit/asyncio/test_async_multi_range_downloader.py index 668006627..b16a1a64b 100644 --- a/tests/unit/asyncio/test_async_multi_range_downloader.py +++ b/tests/unit/asyncio/test_async_multi_range_downloader.py @@ -35,11 +35,12 @@ class TestAsyncMultiRangeDownloader: + def create_read_ranges(self, num_ranges): ranges = [] for i in range(num_ranges): ranges.append( - _storage_v2.ReadRange(read_offset=i, read_length=1, read_id=i) + (i, 1, BytesIO()) ) return ranges @@ -89,16 +90,6 @@ async def test_create_mrd( read_handle=_TEST_READ_HANDLE, ) - mrd.read_obj_str.open.assert_called_once() - # Assert - mock_cls_async_read_object_stream.assert_called_once_with( - client=mock_grpc_client, - bucket_name=_TEST_BUCKET_NAME, - object_name=_TEST_OBJECT_NAME, - generation_number=_TEST_GENERATION_NUMBER, - read_handle=_TEST_READ_HANDLE, - ) - mrd.read_obj_str.open.assert_called_once() assert mrd.client == mock_grpc_client @@ -132,7 +123,9 @@ async def test_download_ranges_via_async_gather( mock_mrd = await self._make_mock_mrd( mock_grpc_client, mock_cls_async_read_object_stream ) - mock_random_int.side_effect = [123, 456, 789, 91011] # for _func_id and read_id + + mock_random_int.side_effect = [456, 91011] + mock_mrd.read_obj_str.send = AsyncMock() mock_mrd.read_obj_str.recv = AsyncMock() @@ -150,6 +143,7 @@ async def test_download_ranges_via_async_gather( ) ] ), + None, _storage_v2.BidiReadObjectResponse( object_data_ranges=[ _storage_v2.ObjectRangeData( @@ -164,12 +158,14 @@ async def test_download_ranges_via_async_gather( ) ], ), + None, ] # Act buffer = BytesIO() second_buffer = BytesIO() lock = asyncio.Lock() + task1 = asyncio.create_task(mock_mrd.download_ranges([(0, 18, buffer)], lock)) task2 = asyncio.create_task( mock_mrd.download_ranges([(10, 6, second_buffer)], lock) @@ -177,18 +173,6 @@ async def test_download_ranges_via_async_gather( await asyncio.gather(task1, task2) # Assert - mock_mrd.read_obj_str.send.side_effect = [ - _storage_v2.BidiReadObjectRequest( - read_ranges=[ - _storage_v2.ReadRange(read_offset=0, read_length=18, read_id=456) - ] - ), - _storage_v2.BidiReadObjectRequest( - read_ranges=[ - _storage_v2.ReadRange(read_offset=10, read_length=6, read_id=91011) - ] - ), - ] assert buffer.getvalue() == data assert second_buffer.getvalue() == data[10:16] @@ -213,22 +197,27 @@ async def test_download_ranges( mock_mrd = await self._make_mock_mrd( mock_grpc_client, mock_cls_async_read_object_stream ) - mock_random_int.side_effect = [123, 456] # for _func_id and read_id + + mock_random_int.side_effect = [456] + mock_mrd.read_obj_str.send = AsyncMock() mock_mrd.read_obj_str.recv = AsyncMock() - mock_mrd.read_obj_str.recv.return_value = _storage_v2.BidiReadObjectResponse( - object_data_ranges=[ - _storage_v2.ObjectRangeData( - checksummed_data=_storage_v2.ChecksummedData( - content=data, crc32c=crc32c_int - ), - range_end=True, - read_range=_storage_v2.ReadRange( - read_offset=0, read_length=18, read_id=456 - ), - ) - ], - ) + mock_mrd.read_obj_str.recv.side_effect = [ + _storage_v2.BidiReadObjectResponse( + object_data_ranges=[ + _storage_v2.ObjectRangeData( + checksummed_data=_storage_v2.ChecksummedData( + content=data, crc32c=crc32c_int + ), + range_end=True, + read_range=_storage_v2.ReadRange( + read_offset=0, read_length=18, read_id=456 + ), + ) + ], + ), + None + ] # Act buffer = BytesIO() @@ -317,7 +306,6 @@ async def test_close_mrd_not_opened_should_throw_error(self, mock_grpc_client): mrd = AsyncMultiRangeDownloader( mock_grpc_client, _TEST_BUCKET_NAME, _TEST_OBJECT_NAME ) - # Act + Assert with pytest.raises(ValueError) as exc: await mrd.close() @@ -366,7 +354,7 @@ def test_init_raises_if_crc32c_c_extension_is_missing( @pytest.mark.asyncio @mock.patch( - "google.cloud.storage._experimental.asyncio.async_multi_range_downloader.Checksum" + "google.cloud.storage._experimental.asyncio.retry.reads_resumption_strategy.Checksum" ) @mock.patch( "google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client" @@ -374,6 +362,8 @@ def test_init_raises_if_crc32c_c_extension_is_missing( async def test_download_ranges_raises_on_checksum_mismatch( self, mock_client, mock_checksum_class ): + from google.cloud.storage._experimental.asyncio.async_multi_range_downloader import AsyncMultiRangeDownloader + mock_stream = mock.AsyncMock( spec=async_read_object_stream._AsyncReadObjectStream ) @@ -389,7 +379,7 @@ async def test_download_ranges_raises_on_checksum_mismatch( checksummed_data=_storage_v2.ChecksummedData( content=test_data, crc32c=server_checksum ), - read_range=_storage_v2.ReadRange(read_id=0), + read_range=_storage_v2.ReadRange(read_id=0, read_offset=0, read_length=len(test_data)), range_end=True, ) ] @@ -402,7 +392,8 @@ async def test_download_ranges_raises_on_checksum_mismatch( mrd._is_stream_open = True with pytest.raises(DataCorruption) as exc_info: - await mrd.download_ranges([(0, len(test_data), BytesIO())]) + with mock.patch("google.cloud.storage._experimental.asyncio.async_multi_range_downloader.generate_random_56_bit_integer", return_value=0): + await mrd.download_ranges([(0, len(test_data), BytesIO())]) assert "Checksum mismatch" in str(exc_info.value) mock_checksum_class.assert_called_once_with(test_data)