diff --git a/docs/models/overview.md b/docs/models/overview.md index c7fe46993b..b67534fd53 100644 --- a/docs/models/overview.md +++ b/docs/models/overview.md @@ -235,3 +235,126 @@ passing a custom `fallback_on` argument to the `FallbackModel` constructor. !!! note Validation errors (from [structured output](../output.md#structured-output) or [tool parameters](../tools.md)) do **not** trigger fallback. These errors use the [retry mechanism](../agents.md#reflection-and-self-correction) instead, which re-prompts the same model to try again. This is intentional: validation errors stem from the non-deterministic nature of LLMs and may succeed on retry, whereas API errors (4xx/5xx) generally indicate issues that won't resolve by retrying the same request. + +!!! note "Streaming limitation" + For streaming requests, exception-based fallback only catches errors during stream **initialization** (e.g., connection errors, authentication failures). If an exception occurs mid-stream after events have started flowing, it will propagate to the caller without triggering fallback. To handle mid-stream failures, use [`fallback_on_part`](#part-based-fallback-streaming) which buffers the stream and can cleanly switch to a fallback model. + +### Response-Based Fallback + +In addition to exception-based fallback, you can also trigger fallback based on the **content** of a model's response using the `fallback_on_response` parameter. This is useful when a model returns a successful HTTP response (no exception), but the response content indicates a semantic failure. + +A common use case is when using built-in tools like web search or URL fetching. For example, Google's `WebFetchTool` may return a successful response with a status indicating the URL fetch failed: + +```python {title="fallback_on_response.py" test="skip" lint="skip"} +from typing import Any + +from pydantic_ai import Agent +from pydantic_ai.messages import BuiltinToolCallPart, BuiltinToolReturnPart, ModelMessage, ModelResponse +from pydantic_ai.models.anthropic import AnthropicModel +from pydantic_ai.models.fallback import FallbackModel +from pydantic_ai.models.google import GoogleModel + + +def web_fetch_failed(response: ModelResponse, messages: list[ModelMessage]) -> bool: + """Check if a web_fetch built-in tool failed to retrieve content.""" + call: BuiltinToolCallPart + result: BuiltinToolReturnPart + for call, result in response.builtin_tool_calls: + if call.tool_name != 'web_fetch': + continue + if not isinstance(result.content, dict): + continue + content: dict[str, Any] = result.content + status = content.get('url_retrieval_status', '') + if status and status != 'URL_RETRIEVAL_STATUS_SUCCESS': + return True + return False + + +google_model = GoogleModel('gemini-2.5-flash') +anthropic_model = AnthropicModel('claude-sonnet-4-5') + +fallback_model = FallbackModel( + google_model, + anthropic_model, + fallback_on_response=web_fetch_failed, +) + +agent = Agent(fallback_model) + +# If Google's web_fetch fails, automatically falls back to Anthropic +result = agent.run_sync('Summarize https://example.com') +print(result.output) +``` + +The `fallback_on_response` callback receives two arguments: + +- `response`: The [`ModelResponse`][pydantic_ai.messages.ModelResponse] returned by the model +- `messages`: The list of [`ModelMessage`][pydantic_ai.messages.ModelMessage] that were sent to the model + +The callback should return `True` to trigger fallback to the next model, or `False` to accept the response. + +!!! 
note
+    When using `fallback_on_response` with streaming (`run_stream`), the entire response is buffered before being returned, since the full response content must be available to evaluate the fallback condition. The caller therefore won't receive partial results until the response is complete and has been accepted.
+
+### Part-Based Fallback (Streaming)
+
+For streaming requests, you can use `fallback_on_part` to check each response part as it arrives from the model. This enables **early abort** when failure conditions are detectable before the full response completes—saving tokens and reducing latency by starting the fallback sooner.
+
+This is particularly useful when built-in tool results (like `web_fetch`) arrive early in the stream:
+
+```python {title="fallback_on_part.py" test="skip" lint="skip"}
+from pydantic_ai import Agent
+from pydantic_ai.messages import BuiltinToolReturnPart, ModelMessage, ModelResponsePart
+from pydantic_ai.models.anthropic import AnthropicModel
+from pydantic_ai.models.fallback import FallbackModel
+from pydantic_ai.models.google import GoogleModel
+
+
+def web_fetch_failed_part(part: ModelResponsePart, messages: list[ModelMessage]) -> bool:
+    """Check if a web_fetch built-in tool part indicates failure."""
+    if not isinstance(part, BuiltinToolReturnPart) or part.tool_name != 'web_fetch':
+        return False
+    if not isinstance(part.content, dict):
+        return False
+    status = part.content.get('url_retrieval_status', '')
+    return bool(status) and status != 'URL_RETRIEVAL_STATUS_SUCCESS'
+
+
+google_model = GoogleModel('gemini-2.5-flash')
+anthropic_model = AnthropicModel('claude-sonnet-4-5')
+
+fallback_model = FallbackModel(
+    google_model,
+    anthropic_model,
+    fallback_on_part=web_fetch_failed_part,
+)
+
+agent = Agent(fallback_model)
+
+
+async def main():
+    # With streaming, fallback can occur as soon as the failed tool result arrives
+    async with agent.run_stream('Summarize https://example.com') as result:
+        output = await result.get_output()
+    print(output)
+```
+
+The `fallback_on_part` callback receives two arguments:
+
+- `part`: A [`ModelResponsePart`][pydantic_ai.messages.ModelResponsePart] that has completed streaming
+- `messages`: The list of [`ModelMessage`][pydantic_ai.messages.ModelMessage] that were sent to the model
+
+You can use both `fallback_on_part` and `fallback_on_response` together. Parts are checked during streaming, and if the stream completes without part rejection, the full response is checked with `fallback_on_response`.
+
+!!! warning "Buffering trade-off"
+    Like `fallback_on_response`, using `fallback_on_part` buffers the response before returning it to the caller. This means the caller won't receive events progressively—they'll receive the complete response after all parts have been validated.
+
+    The benefit of `fallback_on_part` over `fallback_on_response` is **not** live streaming to the caller, but rather:
+
+    - **Token savings**: Stop consuming a response as soon as a failure is detected, rather than waiting for it to complete
+    - **Faster fallback**: Start the next model immediately instead of waiting for a doomed response to finish
+    - **Cost reduction**: Pay only for the tokens consumed before the failure was detected
+
+    If you need progressive streaming to the caller and only want fallback for connection/initialization errors, you can omit `fallback_on_part` and `fallback_on_response`. 
However, be aware that exception-based fallback (`fallback_on`) only catches errors during stream **initialization**—if an exception occurs mid-stream after events have started flowing, it will propagate to the caller without triggering fallback. + +!!! note + `fallback_on_part` only applies to streaming requests (`run_stream`). For non-streaming requests, use `fallback_on_response` instead. diff --git a/pydantic_ai_slim/pydantic_ai/models/fallback.py b/pydantic_ai_slim/pydantic_ai/models/fallback.py index 67151a07b9..a2ff0ce1e3 100644 --- a/pydantic_ai_slim/pydantic_ai/models/fallback.py +++ b/pydantic_ai_slim/pydantic_ai/models/fallback.py @@ -3,12 +3,14 @@ from collections.abc import AsyncIterator, Callable from contextlib import AsyncExitStack, asynccontextmanager, suppress from dataclasses import dataclass, field +from datetime import datetime from functools import cached_property -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, NoReturn from opentelemetry.trace import get_current_span from pydantic_ai._run_context import RunContext +from pydantic_ai.messages import ModelResponseStreamEvent, PartEndEvent from pydantic_ai.models.instrumented import InstrumentedModel from ..exceptions import FallbackExceptionGroup, ModelAPIError @@ -16,9 +18,12 @@ from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model if TYPE_CHECKING: - from ..messages import ModelMessage, ModelResponse + from ..messages import ModelMessage, ModelResponse, ModelResponsePart from ..settings import ModelSettings +FallbackOnResponse = Callable[['ModelResponse', list['ModelMessage']], bool] +FallbackOnPart = Callable[['ModelResponsePart', list['ModelMessage']], bool] + @dataclass(init=False) class FallbackModel(Model): @@ -31,12 +36,16 @@ class FallbackModel(Model): _model_name: str = field(repr=False) _fallback_on: Callable[[Exception], bool] + _fallback_on_response: FallbackOnResponse | None + _fallback_on_part: FallbackOnPart | None def __init__( self, default_model: Model | KnownModelName | str, *fallback_models: Model | KnownModelName | str, fallback_on: Callable[[Exception], bool] | tuple[type[Exception], ...] = (ModelAPIError,), + fallback_on_response: FallbackOnResponse | None = None, + fallback_on_part: FallbackOnPart | None = None, ): """Initialize a fallback model instance. @@ -44,6 +53,19 @@ def __init__( default_model: The name or instance of the default model to use. fallback_models: The names or instances of the fallback models to use upon failure. fallback_on: A callable or tuple of exceptions that should trigger a fallback. + For streaming requests, this only catches exceptions during stream initialization + (e.g., connection errors, authentication failures). Exceptions that occur mid-stream + after events have started flowing will propagate to the caller without triggering + fallback. Use `fallback_on_part` if you need to handle mid-stream failures. + fallback_on_response: A callable that inspects the model response and message history, + returning `True` if fallback should be triggered. This enables fallback based on + response content (e.g., a builtin tool indicating failure) rather than exceptions. + fallback_on_part: A callable that inspects each model response part during streaming, + returning `True` if fallback should be triggered. This enables early abort when + a failure condition is detected (e.g., a builtin tool failure), saving tokens by + not consuming the rest of a doomed response. Only applies to streaming requests. 
+ Note: The response is buffered until validation completes, so the caller receives + events after the full response is validated, not progressively. """ super().__init__() self.models = [infer_model(default_model), *[infer_model(m) for m in fallback_models]] @@ -53,6 +75,9 @@ def __init__( else: self._fallback_on = fallback_on + self._fallback_on_response = fallback_on_response + self._fallback_on_part = fallback_on_part + @property def model_name(self) -> str: """The model name.""" @@ -77,6 +102,7 @@ async def request( In case of failure, raise a FallbackExceptionGroup with all exceptions. """ exceptions: list[Exception] = [] + response_rejections: int = 0 for model in self.models: try: @@ -88,10 +114,14 @@ async def request( continue raise exc + if self._fallback_on_response is not None and self._fallback_on_response(response, messages): + response_rejections += 1 + continue + self._set_span_attributes(model, prepared_parameters) return response - raise FallbackExceptionGroup('All models from FallbackModel failed', exceptions) + _raise_fallback_exception_group(exceptions, response_rejections) @asynccontextmanager async def request_stream( @@ -103,6 +133,8 @@ async def request_stream( ) -> AsyncIterator[StreamedResponse]: """Try each model in sequence until one succeeds.""" exceptions: list[Exception] = [] + response_rejections: int = 0 + part_rejections: int = 0 for model in self.models: async with AsyncExitStack() as stack: @@ -117,11 +149,46 @@ async def request_stream( continue raise exc # pragma: no cover + if self._fallback_on_part is not None: + buffered_events: list[ModelResponseStreamEvent] = [] + should_fallback = False + + async for event in response: + buffered_events.append(event) + if isinstance(event, PartEndEvent) and self._fallback_on_part(event.part, messages): + should_fallback = True + break + + if should_fallback: + part_rejections += 1 + continue + + if self._fallback_on_response is not None and self._fallback_on_response(response.get(), messages): + response_rejections += 1 + continue + + self._set_span_attributes(model, prepared_parameters) + yield BufferedStreamedResponse(_wrapped=response, _buffered_events=buffered_events) + return + + elif self._fallback_on_response is not None: + buffered_events = [] + async for event in response: + buffered_events.append(event) + + if self._fallback_on_response(response.get(), messages): + response_rejections += 1 + continue + + self._set_span_attributes(model, prepared_parameters) + yield BufferedStreamedResponse(_wrapped=response, _buffered_events=buffered_events) + return + self._set_span_attributes(model, prepared_parameters) yield response return - raise FallbackExceptionGroup('All models from FallbackModel failed', exceptions) + _raise_fallback_exception_group(exceptions, response_rejections, part_rejections) @cached_property def profile(self) -> ModelProfile: @@ -135,7 +202,7 @@ def prepare_request( ) -> tuple[ModelSettings | None, ModelRequestParameters]: return model_settings, model_request_parameters - def _set_span_attributes(self, model: Model, model_request_parameters: ModelRequestParameters): + def _set_span_attributes(self, model: Model, model_request_parameters: ModelRequestParameters) -> None: with suppress(Exception): span = get_current_span() if span.is_recording(): @@ -156,3 +223,62 @@ def fallback_condition(exception: Exception) -> bool: return isinstance(exception, exceptions) return fallback_condition + + +def _raise_fallback_exception_group( + exceptions: list[Exception], response_rejections: int, 
part_rejections: int = 0 +) -> NoReturn: + """Raise a FallbackExceptionGroup combining exceptions and rejections.""" + all_errors: list[Exception] = list(exceptions) + if part_rejections > 0: + all_errors.append(RuntimeError(f'{part_rejections} model(s) rejected by fallback_on_part during streaming')) + if response_rejections > 0: + all_errors.append(RuntimeError(f'{response_rejections} model response(s) rejected by fallback_on_response')) + + if all_errors: + raise FallbackExceptionGroup('All models from FallbackModel failed', all_errors) + else: + raise FallbackExceptionGroup( + 'All models from FallbackModel failed', + [RuntimeError('No models available')], + ) # pragma: no cover + + +@dataclass +class BufferedStreamedResponse(StreamedResponse): + """A StreamedResponse wrapper that replays buffered events.""" + + _wrapped: StreamedResponse + _buffered_events: list[ModelResponseStreamEvent] + + model_request_parameters: ModelRequestParameters = field(init=False) + + def __post_init__(self) -> None: + self.model_request_parameters = self._wrapped.model_request_parameters + self._parts_manager = self._wrapped._parts_manager + self._usage = self._wrapped._usage + self.final_result_event = self._wrapped.final_result_event + self.provider_response_id = self._wrapped.provider_response_id + self.provider_details = self._wrapped.provider_details + self.finish_reason = self._wrapped.finish_reason + self._event_iterator = None # reset so __aiter__ uses _get_event_iterator() + + async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: + for event in self._buffered_events: + yield event + + @property + def model_name(self) -> str: + return self._wrapped.model_name + + @property + def provider_name(self) -> str | None: + return self._wrapped.provider_name + + @property + def provider_url(self) -> str | None: + return self._wrapped.provider_url + + @property + def timestamp(self) -> datetime: + return self._wrapped.timestamp diff --git a/tests/models/test_fallback.py b/tests/models/test_fallback.py index 41c20d1e46..2d2c8c0dce 100644 --- a/tests/models/test_fallback.py +++ b/tests/models/test_fallback.py @@ -21,11 +21,13 @@ ModelProfile, ModelRequest, ModelResponse, + ModelResponsePart, TextPart, ToolCallPart, ToolDefinition, UserPromptPart, ) +from pydantic_ai.messages import BuiltinToolCallPart, BuiltinToolReturnPart from pydantic_ai.models import ModelRequestParameters from pydantic_ai.models.fallback import FallbackModel from pydantic_ai.models.function import AgentInfo, FunctionModel @@ -920,3 +922,593 @@ def prompted_output_func(_: list[ModelMessage], info: AgentInfo) -> ModelRespons }, ] ) + + +def primary_response(_model_messages: list[ModelMessage], _agent_info: AgentInfo) -> ModelResponse: + return ModelResponse(parts=[TextPart('primary response')]) + + +def fallback_response(_model_messages: list[ModelMessage], _agent_info: AgentInfo) -> ModelResponse: + return ModelResponse(parts=[TextPart('fallback response')]) + + +primary_model = FunctionModel(primary_response) +fallback_model_impl = FunctionModel(fallback_response) + + +async def test_fallback_on_response_triggered() -> None: + def should_fallback_on_primary(response: ModelResponse, messages: list[ModelMessage]) -> bool: + part = response.parts[0] if response.parts else None + return isinstance(part, TextPart) and 'primary' in part.content + + fallback = FallbackModel( + primary_model, + fallback_model_impl, + fallback_on_response=should_fallback_on_primary, + ) + agent = Agent(model=fallback) + + result = await 
agent.run('hello') + assert result.output == snapshot('fallback response') + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc)), + ], + run_id=IsStr(), + ), + ModelResponse( + parts=[TextPart(content='fallback response')], + usage=RequestUsage(input_tokens=51, output_tokens=2), + model_name='function:fallback_response:', + timestamp=IsNow(tz=timezone.utc), + run_id=IsStr(), + ), + ] + ) + + +async def test_fallback_on_response_not_triggered() -> None: + def never_fallback(response: ModelResponse, messages: list[ModelMessage]) -> bool: + return False + + fallback = FallbackModel( + primary_model, + fallback_model_impl, + fallback_on_response=never_fallback, + ) + agent = Agent(model=fallback) + + result = await agent.run('hello') + assert result.output == snapshot('primary response') + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc)), + ], + run_id=IsStr(), + ), + ModelResponse( + parts=[TextPart(content='primary response')], + usage=RequestUsage(input_tokens=51, output_tokens=2), + model_name='function:primary_response:', + timestamp=IsNow(tz=timezone.utc), + run_id=IsStr(), + ), + ] + ) + + +async def test_fallback_on_response_all_fail() -> None: + def always_fallback(response: ModelResponse, messages: list[ModelMessage]) -> bool: + return True + + fallback = FallbackModel( + primary_model, + fallback_model_impl, + fallback_on_response=always_fallback, + ) + agent = Agent(model=fallback) + + with cast(RaisesContext[ExceptionGroup[Any]], pytest.raises(ExceptionGroup)) as exc_info: + await agent.run('hello') + assert 'All models from FallbackModel failed' in exc_info.value.args[0] + assert len(exc_info.value.exceptions) == 1 + assert isinstance(exc_info.value.exceptions[0], RuntimeError) + assert 'rejected by fallback_on_response' in str(exc_info.value.exceptions[0]) + + +async def test_fallback_on_response_with_message_inspection() -> None: + inspected_messages: list[list[ModelMessage]] = [] + + def inspect_messages(response: ModelResponse, messages: list[ModelMessage]) -> bool: + inspected_messages.append(messages) + return False + + fallback = FallbackModel( + primary_model, + fallback_model_impl, + fallback_on_response=inspect_messages, + ) + agent = Agent(model=fallback) + + await agent.run('hello') + + assert len(inspected_messages) == 1 + assert len(inspected_messages[0]) == 1 + + +async def test_fallback_on_response_combined_with_exception_fallback() -> None: + call_order: list[str] = [] + + def first_fails_with_exception(_: list[ModelMessage], __: AgentInfo) -> ModelResponse: + call_order.append('first') + raise ModelHTTPError(status_code=500, model_name='first', body=None) + + def second_fails_response_check(_: list[ModelMessage], __: AgentInfo) -> ModelResponse: + call_order.append('second') + return ModelResponse(parts=[TextPart('bad response')]) + + def third_succeeds(_: list[ModelMessage], __: AgentInfo) -> ModelResponse: + call_order.append('third') + return ModelResponse(parts=[TextPart('good response')]) + + def reject_bad_response(response: ModelResponse, messages: list[ModelMessage]) -> bool: + part = response.parts[0] if response.parts else None + return isinstance(part, TextPart) and 'bad' in part.content + + first_model = FunctionModel(first_fails_with_exception) + second_model = FunctionModel(second_fails_response_check) + third_model = FunctionModel(third_succeeds) + + fallback = 
FallbackModel( + first_model, + second_model, + third_model, + fallback_on_response=reject_bad_response, + ) + agent = Agent(model=fallback) + + result = await agent.run('hello') + + assert result.output == snapshot('good response') + assert call_order == snapshot(['first', 'second', 'third']) + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc)), + ], + run_id=IsStr(), + ), + ModelResponse( + parts=[TextPart(content='good response')], + usage=RequestUsage(input_tokens=51, output_tokens=2), + model_name='function:third_succeeds:', + timestamp=IsNow(tz=timezone.utc), + run_id=IsStr(), + ), + ] + ) + + +async def test_fallback_on_response_mixed_failures_all_fail() -> None: + call_order: list[str] = [] + + def first_fails_with_exception(_: list[ModelMessage], __: AgentInfo) -> ModelResponse: + call_order.append('first') + raise ModelHTTPError(status_code=500, model_name='first', body=None) + + def second_fails_response_check(_: list[ModelMessage], __: AgentInfo) -> ModelResponse: + call_order.append('second') + return ModelResponse(parts=[TextPart('bad response')]) + + def reject_bad_response(response: ModelResponse, messages: list[ModelMessage]) -> bool: + part = response.parts[0] if response.parts else None + return isinstance(part, TextPart) and 'bad' in part.content + + first_model = FunctionModel(first_fails_with_exception) + second_model = FunctionModel(second_fails_response_check) + + fallback = FallbackModel( + first_model, + second_model, + fallback_on_response=reject_bad_response, + ) + agent = Agent(model=fallback) + + with cast(RaisesContext[ExceptionGroup[Any]], pytest.raises(ExceptionGroup)) as exc_info: + await agent.run('hello') + + assert 'All models from FallbackModel failed' in exc_info.value.args[0] + assert len(exc_info.value.exceptions) == 2 + assert isinstance(exc_info.value.exceptions[0], ModelHTTPError) + assert isinstance(exc_info.value.exceptions[1], RuntimeError) + assert 'rejected by fallback_on_response' in str(exc_info.value.exceptions[1]) + assert call_order == ['first', 'second'] + + +async def test_fallback_on_response_web_fetch_scenario() -> None: + def google_web_fetch_fails(_: list[ModelMessage], __: AgentInfo) -> ModelResponse: + return ModelResponse( + parts=[ + BuiltinToolCallPart(tool_name='web_fetch', args={'url': 'https://example.com'}, tool_call_id='1'), + BuiltinToolReturnPart( + tool_name='web_fetch', + tool_call_id='1', + content={'uri': 'https://example.com', 'url_retrieval_status': 'URL_RETRIEVAL_STATUS_FAILED'}, + ), + TextPart('Could not fetch URL'), + ] + ) + + def anthropic_succeeds(_: list[ModelMessage], __: AgentInfo) -> ModelResponse: + return ModelResponse(parts=[TextPart('Successfully fetched and summarized the content')]) + + def web_fetch_failed(response: ModelResponse, messages: list[ModelMessage]) -> bool: + for call, result in response.builtin_tool_calls: + if call.tool_name != 'web_fetch': + continue # pragma: no cover + content = result.content + if not isinstance(content, dict): + continue # pragma: no cover + content_dict = cast(dict[str, Any], content) + status = content_dict.get('url_retrieval_status', '') + if status and status != 'URL_RETRIEVAL_STATUS_SUCCESS': # pragma: no branch + return True + return False + + google_model = FunctionModel(google_web_fetch_fails) + anthropic_model = FunctionModel(anthropic_succeeds) + + fallback = FallbackModel( + google_model, + anthropic_model, + fallback_on_response=web_fetch_failed, + ) + agent = 
Agent(model=fallback) + + result = await agent.run('Summarize https://example.com') + assert result.output == 'Successfully fetched and summarized the content' + + +def test_fallback_on_response_sync() -> None: + def should_fallback(response: ModelResponse, messages: list[ModelMessage]) -> bool: + part = response.parts[0] if response.parts else None + return isinstance(part, TextPart) and 'primary' in part.content + + fallback = FallbackModel( + primary_model, + fallback_model_impl, + fallback_on_response=should_fallback, + ) + agent = Agent(model=fallback) + + result = agent.run_sync('hello') + assert result.output == 'fallback response' + + +async def primary_response_stream(_model_messages: list[ModelMessage], _agent_info: AgentInfo) -> AsyncIterator[str]: + yield 'primary ' + yield 'response' + + +async def fallback_response_stream(_model_messages: list[ModelMessage], _agent_info: AgentInfo) -> AsyncIterator[str]: + yield 'fallback ' + yield 'response' + + +primary_model_stream = FunctionModel(stream_function=primary_response_stream) +fallback_model_stream_impl = FunctionModel(stream_function=fallback_response_stream) + + +async def test_fallback_on_response_streaming_triggered() -> None: + def should_fallback_on_primary(response: ModelResponse, messages: list[ModelMessage]) -> bool: + part = response.parts[0] if response.parts else None + return isinstance(part, TextPart) and 'primary' in part.content + + fallback = FallbackModel( + primary_model_stream, + fallback_model_stream_impl, + fallback_on_response=should_fallback_on_primary, + ) + agent = Agent(model=fallback) + + async with agent.run_stream('hello') as result: + output = await result.get_output() + + assert output == 'fallback response' + + +async def test_fallback_on_response_streaming_not_triggered() -> None: + def never_fallback(response: ModelResponse, messages: list[ModelMessage]) -> bool: + return False + + fallback = FallbackModel( + primary_model_stream, + fallback_model_stream_impl, + fallback_on_response=never_fallback, + ) + agent = Agent(model=fallback) + + async with agent.run_stream('hello') as result: + output = await result.get_output() + + assert output == 'primary response' + + +async def test_fallback_on_response_streaming_all_fail() -> None: + def always_fallback(response: ModelResponse, messages: list[ModelMessage]) -> bool: + return True + + fallback = FallbackModel( + primary_model_stream, + fallback_model_stream_impl, + fallback_on_response=always_fallback, + ) + agent = Agent(model=fallback) + + with cast(RaisesContext[ExceptionGroup[Any]], pytest.raises(ExceptionGroup)) as exc_info: + async with agent.run_stream('hello') as result: + await result.get_output() # pragma: no cover + + assert 'All models from FallbackModel failed' in exc_info.value.args[0] + assert len(exc_info.value.exceptions) == 1 + assert isinstance(exc_info.value.exceptions[0], RuntimeError) + assert 'rejected by fallback_on_response' in str(exc_info.value.exceptions[0]) + + +async def test_fallback_on_response_streaming_combined_with_exception() -> None: + call_order: list[str] = [] + + async def first_fails_with_exception(_: list[ModelMessage], __: AgentInfo) -> AsyncIterator[str]: + call_order.append('first') + raise ModelHTTPError(status_code=500, model_name='first', body=None) + yield 'never' # pragma: no cover + + async def second_fails_response_check(_: list[ModelMessage], __: AgentInfo) -> AsyncIterator[str]: + call_order.append('second') + yield 'bad ' + yield 'response' + + async def third_succeeds(_: list[ModelMessage], 
__: AgentInfo) -> AsyncIterator[str]: + call_order.append('third') + yield 'good ' + yield 'response' + + def reject_bad_response(response: ModelResponse, messages: list[ModelMessage]) -> bool: + part = response.parts[0] if response.parts else None + return isinstance(part, TextPart) and 'bad' in part.content + + first_model = FunctionModel(stream_function=first_fails_with_exception) + second_model = FunctionModel(stream_function=second_fails_response_check) + third_model = FunctionModel(stream_function=third_succeeds) + + fallback = FallbackModel( + first_model, + second_model, + third_model, + fallback_on_response=reject_bad_response, + ) + agent = Agent(model=fallback) + + async with agent.run_stream('hello') as result: + output = await result.get_output() + + assert output == 'good response' + assert call_order == ['first', 'second', 'third'] + + +async def test_fallback_on_response_streaming_replays_events() -> None: + def never_fallback(response: ModelResponse, messages: list[ModelMessage]) -> bool: + return False + + fallback = FallbackModel( + primary_model_stream, + fallback_model_stream_impl, + fallback_on_response=never_fallback, + ) + agent = Agent(model=fallback) + + async with agent.run_stream('hello') as result: + responses = [c async for c, _is_last in result.stream_responses(debounce_by=None)] + + assert len(responses) >= 2 + part = responses[-1].parts[0] + assert isinstance(part, TextPart) + assert part.content == 'primary response' + + +async def test_fallback_on_part_streaming_triggered() -> None: + models_tried: list[str] = [] + + async def bad_response_stream(_: list[ModelMessage], __: AgentInfo) -> AsyncIterator[str]: + models_tried.append('bad_model') + yield 'bad content' + + async def good_response_stream(_: list[ModelMessage], __: AgentInfo) -> AsyncIterator[str]: + models_tried.append('good_model') + yield 'good content' + + def reject_bad_part(part: ModelResponsePart, messages: list[ModelMessage]) -> bool: + return isinstance(part, TextPart) and 'bad' in part.content + + bad_model = FunctionModel(stream_function=bad_response_stream) + good_model = FunctionModel(stream_function=good_response_stream) + + fallback = FallbackModel( + bad_model, + good_model, + fallback_on_part=reject_bad_part, + ) + agent = Agent(model=fallback) + + async with agent.run_stream('hello') as result: + output = await result.get_output() + + assert output == 'good content' + assert models_tried == ['bad_model', 'good_model'] + + +async def test_fallback_on_part_streaming_not_triggered() -> None: + async def ok_response_stream(_: list[ModelMessage], __: AgentInfo) -> AsyncIterator[str]: + yield 'ok content' + + def never_reject(part: ModelResponsePart, messages: list[ModelMessage]) -> bool: + return False + + ok_model = FunctionModel(stream_function=ok_response_stream) + fallback_model_not_used = FunctionModel(stream_function=ok_response_stream) + + fallback = FallbackModel( + ok_model, + fallback_model_not_used, + fallback_on_part=never_reject, + ) + agent = Agent(model=fallback) + + async with agent.run_stream('hello') as result: + output = await result.get_output() + + assert output == 'ok content' + + +async def test_fallback_on_part_streaming_all_fail() -> None: + async def bad_response_stream(_: list[ModelMessage], __: AgentInfo) -> AsyncIterator[str]: + yield 'bad content' + + def always_reject(part: ModelResponsePart, messages: list[ModelMessage]) -> bool: + return True + + bad_model1 = FunctionModel(stream_function=bad_response_stream) + bad_model2 = 
FunctionModel(stream_function=bad_response_stream) + + fallback = FallbackModel( + bad_model1, + bad_model2, + fallback_on_part=always_reject, + ) + agent = Agent(model=fallback) + + with cast(RaisesContext[ExceptionGroup[Any]], pytest.raises(ExceptionGroup)) as exc_info: + async with agent.run_stream('hello') as result: + await result.get_output() # pragma: no cover + + assert 'All models from FallbackModel failed' in exc_info.value.args[0] + assert len(exc_info.value.exceptions) == 1 + assert isinstance(exc_info.value.exceptions[0], RuntimeError) + assert 'rejected by fallback_on_part' in str(exc_info.value.exceptions[0]) + + +async def test_fallback_on_part_streaming_combined_with_fallback_on_response() -> None: + call_order: list[str] = [] + + async def part_rejected_stream(_: list[ModelMessage], __: AgentInfo) -> AsyncIterator[str]: + call_order.append('first_part_rejected') + yield 'reject_part' + + async def response_rejected_stream(_: list[ModelMessage], __: AgentInfo) -> AsyncIterator[str]: + call_order.append('second_response_rejected') + yield 'reject_response' + + async def success_stream(_: list[ModelMessage], __: AgentInfo) -> AsyncIterator[str]: + call_order.append('third_success') + yield 'success' + + def reject_part_with_keyword(part: ModelResponsePart, messages: list[ModelMessage]) -> bool: + return isinstance(part, TextPart) and 'reject_part' in part.content + + def reject_response_with_keyword(response: ModelResponse, messages: list[ModelMessage]) -> bool: + part = response.parts[0] if response.parts else None + return isinstance(part, TextPart) and 'reject_response' in part.content + + first_model = FunctionModel(stream_function=part_rejected_stream) + second_model = FunctionModel(stream_function=response_rejected_stream) + third_model = FunctionModel(stream_function=success_stream) + + fallback = FallbackModel( + first_model, + second_model, + third_model, + fallback_on_part=reject_part_with_keyword, + fallback_on_response=reject_response_with_keyword, + ) + agent = Agent(model=fallback) + + async with agent.run_stream('hello') as result: + output = await result.get_output() + + assert output == 'success' + assert call_order == ['first_part_rejected', 'second_response_rejected', 'third_success'] + + +async def test_fallback_on_part_streaming_with_exception_fallback() -> None: + call_order: list[str] = [] + + async def first_exception_stream(_: list[ModelMessage], __: AgentInfo) -> AsyncIterator[str]: + call_order.append('first_exception') + raise ModelHTTPError(status_code=500, model_name='first', body=None) + yield 'never' # pragma: no cover + + async def second_part_rejected(_: list[ModelMessage], __: AgentInfo) -> AsyncIterator[str]: + call_order.append('second_part_rejected') + yield 'bad_part' + + async def third_success(_: list[ModelMessage], __: AgentInfo) -> AsyncIterator[str]: + call_order.append('third_success') + yield 'good' + + def reject_bad_part(part: ModelResponsePart, messages: list[ModelMessage]) -> bool: + return isinstance(part, TextPart) and 'bad' in part.content + + first_model = FunctionModel(stream_function=first_exception_stream) + second_model = FunctionModel(stream_function=second_part_rejected) + third_model = FunctionModel(stream_function=third_success) + + fallback = FallbackModel( + first_model, + second_model, + third_model, + fallback_on_part=reject_bad_part, + ) + agent = Agent(model=fallback) + + async with agent.run_stream('hello') as result: + output = await result.get_output() + + assert output == 'good' + assert call_order == 
['first_exception', 'second_part_rejected', 'third_success'] + + +async def test_fallback_on_part_streaming_mixed_failures_all_fail() -> None: + async def exception_stream(_: list[ModelMessage], __: AgentInfo) -> AsyncIterator[str]: + raise ModelHTTPError(status_code=500, model_name='exception', body=None) + yield 'never' # pragma: no cover + + async def bad_part_stream(_: list[ModelMessage], __: AgentInfo) -> AsyncIterator[str]: + yield 'bad_part' + + def always_reject(part: ModelResponsePart, messages: list[ModelMessage]) -> bool: + return True + + first_model = FunctionModel(stream_function=exception_stream) + second_model = FunctionModel(stream_function=bad_part_stream) + + fallback = FallbackModel( + first_model, + second_model, + fallback_on_part=always_reject, + ) + agent = Agent(model=fallback) + + with cast(RaisesContext[ExceptionGroup[Any]], pytest.raises(ExceptionGroup)) as exc_info: + async with agent.run_stream('hello') as result: + await result.get_output() # pragma: no cover + + assert 'All models from FallbackModel failed' in exc_info.value.args[0] + assert len(exc_info.value.exceptions) == 2 + assert isinstance(exc_info.value.exceptions[0], ModelHTTPError) + assert isinstance(exc_info.value.exceptions[1], RuntimeError) + assert 'rejected by fallback_on_part' in str(exc_info.value.exceptions[1])
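+
+
+# A small extra check, sketched on the assumption documented above: `fallback_on_part`
+# only applies to streaming requests, and `FallbackModel.request()` never consults the
+# part callback, so a non-streaming run should return the primary model's response even
+# with an always-rejecting callback. The test name is new, not part of the suite above.
+async def test_fallback_on_part_ignored_for_non_streaming() -> None:
+    def always_reject(part: ModelResponsePart, messages: list[ModelMessage]) -> bool:
+        return True
+
+    fallback = FallbackModel(
+        primary_model,
+        fallback_model_impl,
+        fallback_on_part=always_reject,
+    )
+    agent = Agent(model=fallback)
+
+    # Non-streaming request: the part callback is never invoked, so the primary wins.
+    result = await agent.run('hello')
+    assert result.output == 'primary response'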