Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

"""

from typing import Optional
from typing import List, Optional, Tuple
from google.cloud import _storage_v2
from google.cloud.storage._experimental.asyncio.async_grpc_client import AsyncGrpcClient
from google.cloud.storage._experimental.asyncio.async_abstract_object_stream import (
Expand Down Expand Up @@ -93,23 +93,54 @@ def __init__(
self.socket_like_rpc: Optional[AsyncBidiRpc] = None
self._is_stream_open: bool = False

async def open(self) -> None:
async def open(self, metadata: Optional[List[Tuple[str, str]]] = None) -> None:
"""Opens the bidi-gRPC connection to read from the object.

This method sends an initial request to start the stream and receives
the first response containing metadata and a read handle.

Args:
metadata (Optional[List[Tuple[str, str]]]): Additional metadata
to send with the initial stream request, e.g., for routing tokens.
"""
if self._is_stream_open:
raise ValueError("Stream is already open")

read_object_spec = _storage_v2.BidiReadObjectSpec(
bucket=self._full_bucket_name,
object=self.object_name,
generation=self.generation_number if self.generation_number else None,
read_handle=self.read_handle if self.read_handle else None,
)
initial_request = _storage_v2.BidiReadObjectRequest(
read_object_spec=read_object_spec
)

# Build the x-goog-request-params header
request_params = [f"bucket={self._full_bucket_name}"]
other_metadata = []
if metadata:
for key, value in metadata:
if key == "x-goog-request-params":
request_params.append(value)
else:
other_metadata.append((key, value))

current_metadata = other_metadata
current_metadata.append(("x-goog-request-params", ",".join(request_params)))

self.socket_like_rpc = AsyncBidiRpc(
self.rpc, initial_request=self.first_bidi_read_req, metadata=self.metadata
self.rpc, initial_request=initial_request, metadata=current_metadata
)
await self.socket_like_rpc.open() # this is actually 1 send
await self.socket_like_rpc.open()
response = await self.socket_like_rpc.recv()
if self.generation_number is None:
self.generation_number = response.metadata.generation

self.read_handle = response.read_handle
if response and response.metadata:
if self.generation_number is None and response.metadata.generation:
self.generation_number = response.metadata.generation

if response and response.read_handle and response.read_handle.handle:
self.read_handle = response.read_handle.handle

self._is_stream_open = True

Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
from typing import Any, Iterable

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, AsyncIterator, Callable

from google.cloud.storage._experimental.asyncio.retry.base_strategy import (
_BaseResumptionStrategy,
)


class _BidiStreamRetryManager:
    """Runs a bidi streaming operation under a caller-supplied retry policy.

    The manager is operation-agnostic: building requests, consuming
    responses and repairing state after a failure are all delegated to the
    resumption strategy, while opening the stream is delegated to the
    stream opener.
    """

    def __init__(
        self,
        strategy: _BaseResumptionStrategy,
        stream_opener: Callable[..., AsyncIterator[Any]],
    ):
        """Stores the collaborators used by every attempt.

        Args:
            strategy: The strategy for managing the state of a specific
                bidi operation (e.g., reads or writes).
            stream_opener: Callable invoked as
                ``stream_opener(requests, state)``; its result is consumed
                with ``async for``.
        """
        self._stream_opener = stream_opener
        self._strategy = strategy

    async def execute(self, initial_state: Any, retry_policy):
        """Runs the operation, retrying as dictated by ``retry_policy``.

        Args:
            initial_state: An object containing all state for the operation.
                The same object is handed to the strategy and the stream
                opener on every attempt.
            retry_policy: The `google.api_core.retry.AsyncRetry` object to
                govern the retry behavior for this specific operation. It is
                called with the attempt coroutine function to wrap it, and
                its ``_predicate`` decides whether state recovery runs.

        Raises:
            Exception: Whatever the wrapped attempt ultimately raises.
        """

        async def _single_attempt():
            # Requests are regenerated from the current state on each attempt.
            pending = self._strategy.generate_requests(initial_state)
            responses = self._stream_opener(pending, initial_state)
            try:
                async for item in responses:
                    self._strategy.update_state_from_response(item, initial_state)
            except Exception as exc:
                # Only errors the policy accepts trigger state recovery; the
                # error is re-raised either way so the policy can decide
                # whether another attempt happens.
                should_recover = retry_policy._predicate(exc)
                if should_recover:
                    await self._strategy.recover_state_on_failure(
                        exc, initial_state
                    )
                raise exc

        await retry_policy(_single_attempt)()
Original file line number Diff line number Diff line change
@@ -1,5 +1,20 @@
from typing import Any, List, IO
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, IO

from google_crc32c import Checksum
from google.cloud import _storage_v2 as storage_v2
from google.cloud.storage.exceptions import DataCorruption
from google.cloud.storage._experimental.asyncio.retry.base_strategy import (
Expand All @@ -25,18 +40,25 @@ def __init__(
class _ReadResumptionStrategy(_BaseResumptionStrategy):
"""The concrete resumption strategy for bidi reads."""

def generate_requests(self, state: dict) -> List[storage_v2.ReadRange]:
def generate_requests(self, state: Dict[str, Any]) -> List[storage_v2.ReadRange]:
"""Generates new ReadRange requests for all incomplete downloads.

:type state: dict
:param state: A dictionary mapping a read_id to its corresponding
_DownloadState object.
"""
pending_requests = []
for read_id, read_state in state.items():
download_states: Dict[int, _DownloadState] = state["download_states"]

for read_id, read_state in download_states.items():
if not read_state.is_complete:
new_offset = read_state.initial_offset + read_state.bytes_written
new_length = read_state.initial_length - read_state.bytes_written

# Calculate remaining length. If initial_length is 0 (read to end),
# it stays 0. Otherwise, subtract bytes_written.
new_length = 0
if read_state.initial_length > 0:
new_length = read_state.initial_length - read_state.bytes_written

new_request = storage_v2.ReadRange(
read_offset=new_offset,
Expand All @@ -47,19 +69,52 @@ def generate_requests(self, state: dict) -> List[storage_v2.ReadRange]:
return pending_requests

def update_state_from_response(
self, response: storage_v2.BidiReadObjectResponse, state: dict
self, response: storage_v2.BidiReadObjectResponse, state: Dict[str, Any]
) -> None:
"""Processes a server response, performs integrity checks, and updates state."""

# Capture read_handle if provided.
if response.read_handle and response.read_handle.handle:
state["read_handle"] = response.read_handle.handle

download_states = state["download_states"]

for object_data_range in response.object_data_ranges:
# Ignore empty ranges or ranges for IDs not in our state
# (e.g., from a previously cancelled request on the same stream).
if not object_data_range.read_range:
continue

read_id = object_data_range.read_range.read_id
read_state = state[read_id]
if read_id not in download_states:
continue

read_state = download_states[read_id]

# Offset Verification
chunk_offset = object_data_range.read_range.read_offset
if chunk_offset != read_state.next_expected_offset:
raise DataCorruption(response, f"Offset mismatch for read_id {read_id}")
raise DataCorruption(
response,
f"Offset mismatch for read_id {read_id}. "
f"Expected {read_state.next_expected_offset}, got {chunk_offset}"
)

# Checksum Verification
# We must validate data before updating state or writing to buffer.
data = object_data_range.checksummed_data.content
server_checksum = object_data_range.checksummed_data.crc32c

if server_checksum is not None:
client_checksum = int.from_bytes(Checksum(data).digest(), "big")
if server_checksum != client_checksum:
raise DataCorruption(
response,
f"Checksum mismatch for read_id {read_id}. "
f"Server sent {server_checksum}, client calculated {client_checksum}."
)

# Update State & Write Data
chunk_size = len(data)
read_state.bytes_written += chunk_size
read_state.next_expected_offset += chunk_size
Expand All @@ -73,7 +128,9 @@ def update_state_from_response(
and read_state.bytes_written != read_state.initial_length
):
raise DataCorruption(
response, f"Byte count mismatch for read_id {read_id}"
response,
f"Byte count mismatch for read_id {read_id}. "
f"Expected {read_state.initial_length}, got {read_state.bytes_written}"
)

async def recover_state_on_failure(self, error: Exception, state: Any) -> None:
Expand All @@ -83,3 +140,6 @@ async def recover_state_on_failure(self, error: Exception, state: Any) -> None:
cause = getattr(error, "cause", error)
if isinstance(cause, BidiReadObjectRedirectedError):
state["routing_token"] = cause.routing_token
if cause.read_handle and cause.read_handle.handle:
state["read_handle"] = cause.read_handle.handle
print(f"Recover state: Updated read_handle from redirect: {state['read_handle']}")
Loading