1313# limitations under the License.
1414
1515import asyncio
16- from typing import Any , AsyncIterator , Callable
16+ from typing import Any , AsyncIterator , Callable , Iterable , TYPE_CHECKING
1717
1818from google .api_core import exceptions
1919from google .cloud .storage ._experimental .asyncio .retry .base_strategy import (
2020 _BaseResumptionStrategy ,
2121)
2222
23+ if TYPE_CHECKING :
24+ from google .api_core .retry_async import AsyncRetry
25+
2326
2427class _BidiStreamRetryManager :
2528 """Manages the generic retry loop for a bidi streaming operation."""
2629
2730 def __init__ (
2831 self ,
2932 strategy : _BaseResumptionStrategy ,
30- stream_opener : Callable [... , AsyncIterator [Any ]],
33+ stream_opener : Callable [[ Iterable [ Any ], Any ] , AsyncIterator [Any ]],
3134 ):
3235 """Initializes the retry manager.
3336
@@ -39,13 +42,13 @@ def __init__(
3942 self ._strategy = strategy
4043 self ._stream_opener = stream_opener
4144
42- async def execute (self , initial_state : Any , retry_policy ):
45+ async def execute (self , initial_state : Any , retry_policy : "AsyncRetry" ):
4346 """
4447 Executes the bidi operation with the configured retry policy.
4548
4649 Args:
4750 initial_state: An object containing all state for the operation.
48- retry_policy: The `google.api_core.retry .AsyncRetry` object to
51+ retry_policy: The `google.api_core.retry_async .AsyncRetry` object to
4952 govern the retry behavior for this specific operation.
5053 """
5154 state = initial_state
@@ -56,12 +59,14 @@ async def attempt():
5659 try :
5760 async for response in stream :
5861 self ._strategy .update_state_from_response (response , state )
59- return
62+ return # Successful completion of the stream.
6063 except Exception as e :
6164 if retry_policy ._predicate (e ):
6265 await self ._strategy .recover_state_on_failure (e , state )
6366 raise e
6467
68+ # Wrap the attempt function with the retry policy.
6569 wrapped_attempt = retry_policy (attempt )
6670
71+ # Execute the operation with retry.
6772 await wrapped_attempt ()
0 commit comments