|
4 | 4 | from asyncio import gather, sleep |
5 | 5 |
|
6 | 6 | import pytest |
7 | | -from smithy_core.exceptions import CallError, RetryError |
| 7 | +from smithy_core.exceptions import CallError, ClientTimeoutError, RetryError |
8 | 8 | from smithy_core.interfaces import retries as retries_interface |
9 | 9 | from smithy_core.retries import ( |
10 | 10 | ExponentialBackoffJitterType, |
|
14 | 14 | ) |
15 | 15 |
|
16 | 16 |
|
| 17 | +# TODO: Refactor this to use a smithy-testing generated client |
17 | 18 | async def retry_operation( |
18 | 19 | strategy: retries_interface.RetryStrategy, |
19 | | - status_codes: list[int], |
| 20 | + responses: list[int | Exception], |
20 | 21 | ) -> tuple[str, int]: |
21 | 22 | token = strategy.acquire_initial_retry_token() |
22 | | - responses = iter(status_codes) |
| 23 | + response_iter = iter(responses) |
23 | 24 |
|
24 | 25 | while True: |
25 | 26 | if token.retry_delay: |
26 | 27 | await sleep(token.retry_delay) |
27 | 28 |
|
28 | | - status_code = next(responses) |
| 29 | + response = next(response_iter) |
29 | 30 | attempt = token.retry_count + 1 |
30 | 31 |
|
31 | | - if status_code == 200: |
| 32 | + # Success case |
| 33 | + if response == 200: |
32 | 34 | strategy.record_success(token=token) |
33 | 35 | return "success", attempt |
34 | 36 |
|
35 | | - error = CallError( |
36 | | - fault="server" if status_code >= 500 else "client", |
37 | | - message=f"HTTP {status_code}", |
38 | | - is_retry_safe=status_code >= 500, |
39 | | - ) |
| 37 | + # Error case - either status code or exception |
| 38 | + if isinstance(response, Exception): |
| 39 | + error = response |
| 40 | + else: |
| 41 | + error = CallError( |
| 42 | + fault="server" if response >= 500 else "client", |
| 43 | + message=f"HTTP {response}", |
| 44 | + is_retry_safe=response >= 500, |
| 45 | + ) |
40 | 46 |
|
41 | 47 | try: |
42 | 48 | token = strategy.refresh_retry_token_for_retry( |
@@ -131,3 +137,17 @@ async def test_retry_quota_shared_across_concurrent_operations(): |
131 | 137 | assert result1 == ("success", 3) |
132 | 138 | assert result2 == ("success", 2) |
133 | 139 | assert quota.available_capacity == 495 |
| 140 | + |
| 141 | + |
| 142 | +async def test_retry_quota_handles_timeout_errors(): |
| 143 | + quota = StandardRetryQuota(initial_capacity=500) |
| 144 | + strategy = StandardRetryStrategy(max_attempts=3, retry_quota=quota) |
| 145 | + |
| 146 | + timeout1 = ClientTimeoutError() |
| 147 | + timeout2 = ClientTimeoutError() |
| 148 | + |
| 149 | + result, attempts = await retry_operation(strategy, [timeout1, timeout2, 200]) |
| 150 | + |
| 151 | + assert result == "success" |
| 152 | + assert attempts == 3 |
| 153 | + assert quota.available_capacity == 490 |
0 commit comments