123 changes: 123 additions & 0 deletions docs/models/overview.md
Collaborator

Note to self to review docs later

@@ -235,3 +235,126 @@ passing a custom `fallback_on` argument to the `FallbackModel` constructor.

!!! note
Validation errors (from [structured output](../output.md#structured-output) or [tool parameters](../tools.md)) do **not** trigger fallback. These errors use the [retry mechanism](../agents.md#reflection-and-self-correction) instead, which re-prompts the same model to try again. This is intentional: validation errors stem from the non-deterministic nature of LLMs and may succeed on retry, whereas API errors (4xx/5xx) generally indicate issues that won't resolve by retrying the same request.

!!! note "Streaming limitation"
For streaming requests, exception-based fallback only catches errors during stream **initialization** (e.g., connection errors, authentication failures). If an exception occurs mid-stream after events have started flowing, it will propagate to the caller without triggering fallback. To handle mid-stream failures, use [`fallback_on_part`](#part-based-fallback-streaming) which buffers the stream and can cleanly switch to a fallback model.
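
For reference, exception-based fallback can itself be narrowed by passing a custom `fallback_on` callable. The sketch below is illustrative only (it assumes `ModelHTTPError` exposes a `status_code` attribute) and restricts fallback to transient HTTP failures; for streaming requests this still only applies during stream initialization, as noted above.

```python {test="skip" lint="skip"}
from pydantic_ai.exceptions import ModelHTTPError
from pydantic_ai.models.anthropic import AnthropicModel
from pydantic_ai.models.fallback import FallbackModel
from pydantic_ai.models.google import GoogleModel


def transient_http_error(exc: Exception) -> bool:
    """Fall back only on rate limits and server-side errors."""
    return isinstance(exc, ModelHTTPError) and exc.status_code in (429, 500, 502, 503)


fallback_model = FallbackModel(
    GoogleModel('gemini-2.5-flash'),
    AnthropicModel('claude-sonnet-4-5'),
    fallback_on=transient_http_error,
)
```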

### Response-Based Fallback

In addition to exception-based fallback, you can also trigger fallback based on the **content** of a model's response using the `fallback_on_response` parameter. This is useful when a model returns a successful HTTP response (no exception), but the response content indicates a semantic failure.

A common use case involves built-in tools like web search or URL fetching. For example, Google's `WebFetchTool` may return a successful response with a status indicating the URL fetch failed:

```python {title="fallback_on_response.py" test="skip" lint="skip"}
from typing import Any

from pydantic_ai import Agent
from pydantic_ai.messages import BuiltinToolCallPart, BuiltinToolReturnPart, ModelMessage, ModelResponse
from pydantic_ai.models.anthropic import AnthropicModel
from pydantic_ai.models.fallback import FallbackModel
from pydantic_ai.models.google import GoogleModel


def web_fetch_failed(response: ModelResponse, messages: list[ModelMessage]) -> bool:
"""Check if a web_fetch built-in tool failed to retrieve content."""
call: BuiltinToolCallPart
result: BuiltinToolReturnPart
for call, result in response.builtin_tool_calls:
if call.tool_name != 'web_fetch':
continue
if not isinstance(result.content, dict):
continue
content: dict[str, Any] = result.content
status = content.get('url_retrieval_status', '')
if status and status != 'URL_RETRIEVAL_STATUS_SUCCESS':
return True
return False


google_model = GoogleModel('gemini-2.5-flash')
anthropic_model = AnthropicModel('claude-sonnet-4-5')

fallback_model = FallbackModel(
google_model,
anthropic_model,
fallback_on_response=web_fetch_failed,
)

agent = Agent(fallback_model)

# If Google's web_fetch fails, automatically falls back to Anthropic
result = agent.run_sync('Summarize https://example.com')
print(result.output)
```

The `fallback_on_response` callback receives two arguments:

- `response`: The [`ModelResponse`][pydantic_ai.messages.ModelResponse] returned by the model
- `messages`: The list of [`ModelMessage`][pydantic_ai.messages.ModelMessage] that were sent to the model

The callback should return `True` to trigger fallback to the next model, or `False` to accept the response.
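
Content checks are not limited to built-in tool results. As another minimal sketch (assuming the provider populates `ModelResponse.finish_reason`), you could treat truncated or filtered responses as failures:

```python {test="skip" lint="skip"}
from pydantic_ai.messages import ModelMessage, ModelResponse


def truncated_or_filtered(response: ModelResponse, messages: list[ModelMessage]) -> bool:
    """Trigger fallback when the response was cut off or blocked rather than completed normally."""
    return response.finish_reason in ('length', 'content_filter')
```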

!!! note
When using `fallback_on_response` with streaming (`run_stream`), the entire response is buffered before being returned. This means the caller won't receive partial results until the full response is ready and the fallback condition has been evaluated. This is necessary because the response content must be fully available to evaluate the fallback condition.

### Part-Based Fallback (Streaming)

For streaming requests, you can use `fallback_on_part` to check each response part as it arrives from the model. This enables **early abort** when failure conditions are detectable before the full response completes—saving tokens and reducing latency by starting the fallback sooner.

This is particularly useful when built-in tool results (like `web_fetch`) arrive early in the stream:

```python {title="fallback_on_part.py" test="skip" lint="skip"}
from pydantic_ai import Agent
from pydantic_ai.messages import BuiltinToolReturnPart, ModelMessage, ModelResponsePart
from pydantic_ai.models.anthropic import AnthropicModel
from pydantic_ai.models.fallback import FallbackModel
from pydantic_ai.models.google import GoogleModel


def web_fetch_failed_part(part: ModelResponsePart, messages: list[ModelMessage]) -> bool:
"""Check if a web_fetch built-in tool part indicates failure."""
if not isinstance(part, BuiltinToolReturnPart) or part.tool_name != 'web_fetch':
return False
if not isinstance(part.content, dict):
return False
status = part.content.get('url_retrieval_status', '')
return bool(status) and status != 'URL_RETRIEVAL_STATUS_SUCCESS'


google_model = GoogleModel('gemini-2.5-flash')
anthropic_model = AnthropicModel('claude-sonnet-4-5')

fallback_model = FallbackModel(
google_model,
anthropic_model,
fallback_on_part=web_fetch_failed_part,
)

agent = Agent(fallback_model)

async def main():
    # With streaming, fallback can occur as soon as the failed tool result arrives
    async with agent.run_stream('Summarize https://example.com') as result:
        output = await result.get_output()
        print(output)
```

The `fallback_on_part` callback receives:

- `part`: A [`ModelResponsePart`][pydantic_ai.messages.ModelResponsePart] that has completed streaming
- `messages`: The list of [`ModelMessage`][pydantic_ai.messages.ModelMessage] that were sent to the model

You can use both `fallback_on_part` and `fallback_on_response` together. Parts are checked during streaming, and if the stream completes without part rejection, the full response is checked with `fallback_on_response`.
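
For example, reusing the callbacks defined in the snippets above (a sketch):

```python {test="skip" lint="skip"}
fallback_model = FallbackModel(
    google_model,
    anthropic_model,
    fallback_on_part=web_fetch_failed_part,  # checked per part while streaming
    fallback_on_response=web_fetch_failed,  # checked on the fully assembled response
)
```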

!!! warning "Buffering trade-off"
Like `fallback_on_response`, using `fallback_on_part` buffers the response before returning it to the caller. This means the caller won't receive events progressively—they'll receive the complete response after all parts have been validated.

The benefit of `fallback_on_part` over `fallback_on_response` is **not** live streaming to the caller, but rather:

- **Token savings**: Stop consuming a response as soon as a failure is detected, rather than waiting for it to complete
- **Faster fallback**: Start the next model immediately instead of waiting for a doomed response to finish
- **Cost reduction**: Pay only for the tokens consumed before the failure was detected

If you need progressive streaming to the caller and only want fallback for connection/initialization errors, you can omit `fallback_on_part` and `fallback_on_response`. However, be aware that exception-based fallback (`fallback_on`) only catches errors during stream **initialization**—if an exception occurs mid-stream after events have started flowing, it will propagate to the caller without triggering fallback.

!!! note
`fallback_on_part` only applies to streaming requests (`run_stream`). For non-streaming requests, use `fallback_on_response` instead.
136 changes: 131 additions & 5 deletions pydantic_ai_slim/pydantic_ai/models/fallback.py
@@ -3,22 +3,27 @@
from collections.abc import AsyncIterator, Callable
from contextlib import AsyncExitStack, asynccontextmanager, suppress
from dataclasses import dataclass, field
from datetime import datetime
from functools import cached_property
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, NoReturn

from opentelemetry.trace import get_current_span

from pydantic_ai._run_context import RunContext
from pydantic_ai.messages import ModelResponseStreamEvent, PartEndEvent
from pydantic_ai.models.instrumented import InstrumentedModel

from ..exceptions import FallbackExceptionGroup, ModelAPIError
from ..profiles import ModelProfile
from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model

if TYPE_CHECKING:
from ..messages import ModelMessage, ModelResponse
from ..messages import ModelMessage, ModelResponse, ModelResponsePart
from ..settings import ModelSettings

FallbackOnResponse = Callable[['ModelResponse', list['ModelMessage']], bool]
FallbackOnPart = Callable[['ModelResponsePart', list['ModelMessage']], bool]


@dataclass(init=False)
class FallbackModel(Model):
@@ -31,19 +36,36 @@ class FallbackModel(Model):

_model_name: str = field(repr=False)
_fallback_on: Callable[[Exception], bool]
_fallback_on_response: FallbackOnResponse | None
_fallback_on_part: FallbackOnPart | None

def __init__(
self,
default_model: Model | KnownModelName | str,
*fallback_models: Model | KnownModelName | str,
fallback_on: Callable[[Exception], bool] | tuple[type[Exception], ...] = (ModelAPIError,),
fallback_on_response: FallbackOnResponse | None = None,
fallback_on_part: FallbackOnPart | None = None,
):
"""Initialize a fallback model instance.

Args:
default_model: The name or instance of the default model to use.
fallback_models: The names or instances of the fallback models to use upon failure.
fallback_on: A callable or tuple of exceptions that should trigger a fallback.
For streaming requests, this only catches exceptions during stream initialization
(e.g., connection errors, authentication failures). Exceptions that occur mid-stream
after events have started flowing will propagate to the caller without triggering
fallback. Use `fallback_on_part` if you need to handle mid-stream failures.
fallback_on_response: A callable that inspects the model response and message history,
returning `True` if fallback should be triggered. This enables fallback based on
response content (e.g., a builtin tool indicating failure) rather than exceptions.
Collaborator

Let's also mention the finish reason as an example

fallback_on_part: A callable that inspects each model response part during streaming,
Contributor Author

I do wonder if this is the best API, or if we should somehow make fallback_on support exceptions, responses, and parts, all in one.

Collaborator

That would be ideal, if we can figure it out in a backward compatible way :)

If we change Callable[[Exception], bool] to Callable[[Exception | ModelResponse], bool], users' existing handlers would likely start raising errors, so we'd have to check the function signature type hint. But that wouldn't work with lambdas, and type hints are of course not required, so basically we would only send in ModelResponse if it is an explicitly hinted function.

That's not the end of the world, but a little awkward, also because the user would still need to check isinstance(...) etc.

So perhaps it'd be better to support a list of exception types, exception handler functions, and response handler functions? And if the type of a function is not specified, we assume it's an exception handler function?

That way it'd also be easier to reuse checks like your builtin_tool_error.

returning `True` if fallback should be triggered. This enables early abort when
a failure condition is detected (e.g., a builtin tool failure), saving tokens by
not consuming the rest of a doomed response. Only applies to streaming requests.
Note: The response is buffered until validation completes, so the caller receives
events after the full response is validated, not progressively.
Contributor Author

fallback_on_part isn't relevant to my use case but I felt it would be useful for users looking to save tokens when using response-based fallback with large responses. However, if it's undesirable I can remove it 👍

Collaborator

The use case is valid but I'm a lot more hesitant to introduce 2 new fields than 1 😅

Note that the ModelResponsePartsManager.get_parts() method can already be used to build a ModelResponse at any stage of the stream, so it should be possible to make the same fallback_on_response work with streaming as well. We would want to add an incomplete field or something in that case, so that the handler can tell the difference.

That seems like enough of a rabbit hole that it shouldn't block the release of the main feature, but feel free to look into it in a separate issue/PR.

"""
super().__init__()
self.models = [infer_model(default_model), *[infer_model(m) for m in fallback_models]]
@@ -53,6 +75,9 @@ def __init__(
else:
self._fallback_on = fallback_on

self._fallback_on_response = fallback_on_response
self._fallback_on_part = fallback_on_part

@property
def model_name(self) -> str:
"""The model name."""
@@ -77,6 +102,7 @@ async def request(
In case of failure, raise a FallbackExceptionGroup with all exceptions.
"""
exceptions: list[Exception] = []
response_rejections: int = 0

for model in self.models:
try:
@@ -88,10 +114,14 @@
continue
raise exc

if self._fallback_on_response is not None and self._fallback_on_response(response, messages):
response_rejections += 1
continue

self._set_span_attributes(model, prepared_parameters)
return response

raise FallbackExceptionGroup('All models from FallbackModel failed', exceptions)
_raise_fallback_exception_group(exceptions, response_rejections)

@asynccontextmanager
async def request_stream(
@@ -103,6 +133,8 @@
) -> AsyncIterator[StreamedResponse]:
"""Try each model in sequence until one succeeds."""
exceptions: list[Exception] = []
response_rejections: int = 0
part_rejections: int = 0

for model in self.models:
async with AsyncExitStack() as stack:
@@ -117,11 +149,46 @@
continue
raise exc # pragma: no cover

if self._fallback_on_part is not None:
buffered_events: list[ModelResponseStreamEvent] = []
should_fallback = False

async for event in response:
buffered_events.append(event)
if isinstance(event, PartEndEvent) and self._fallback_on_part(event.part, messages):
should_fallback = True
break

if should_fallback:
part_rejections += 1
continue

if self._fallback_on_response is not None and self._fallback_on_response(response.get(), messages):
response_rejections += 1
continue

self._set_span_attributes(model, prepared_parameters)
yield BufferedStreamedResponse(_wrapped=response, _buffered_events=buffered_events)
return

elif self._fallback_on_response is not None:
buffered_events = []
async for event in response:
buffered_events.append(event)

if self._fallback_on_response(response.get(), messages):
response_rejections += 1
continue

self._set_span_attributes(model, prepared_parameters)
yield BufferedStreamedResponse(_wrapped=response, _buffered_events=buffered_events)
return

self._set_span_attributes(model, prepared_parameters)
yield response
return

raise FallbackExceptionGroup('All models from FallbackModel failed', exceptions)
_raise_fallback_exception_group(exceptions, response_rejections, part_rejections)

@cached_property
def profile(self) -> ModelProfile:
@@ -135,7 +202,7 @@ def prepare_request(
) -> tuple[ModelSettings | None, ModelRequestParameters]:
return model_settings, model_request_parameters

def _set_span_attributes(self, model: Model, model_request_parameters: ModelRequestParameters):
def _set_span_attributes(self, model: Model, model_request_parameters: ModelRequestParameters) -> None:
with suppress(Exception):
span = get_current_span()
if span.is_recording():
@@ -156,3 +223,62 @@ def fallback_condition(exception: Exception) -> bool:
return isinstance(exception, exceptions)

return fallback_condition


def _raise_fallback_exception_group(
exceptions: list[Exception], response_rejections: int, part_rejections: int = 0
) -> NoReturn:
"""Raise a FallbackExceptionGroup combining exceptions and rejections."""
all_errors: list[Exception] = list(exceptions)
if part_rejections > 0:
all_errors.append(RuntimeError(f'{part_rejections} model(s) rejected by fallback_on_part during streaming'))
Contributor Author

should we add more detail here on which part failed?

if response_rejections > 0:
all_errors.append(RuntimeError(f'{response_rejections} model response(s) rejected by fallback_on_response'))

if all_errors:
raise FallbackExceptionGroup('All models from FallbackModel failed', all_errors)
else:
raise FallbackExceptionGroup(
'All models from FallbackModel failed',
[RuntimeError('No models available')],
) # pragma: no cover


@dataclass
class BufferedStreamedResponse(StreamedResponse):
"""A StreamedResponse wrapper that replays buffered events."""

_wrapped: StreamedResponse
_buffered_events: list[ModelResponseStreamEvent]

model_request_parameters: ModelRequestParameters = field(init=False)

def __post_init__(self) -> None:
self.model_request_parameters = self._wrapped.model_request_parameters
self._parts_manager = self._wrapped._parts_manager
self._usage = self._wrapped._usage
self.final_result_event = self._wrapped.final_result_event
self.provider_response_id = self._wrapped.provider_response_id
self.provider_details = self._wrapped.provider_details
self.finish_reason = self._wrapped.finish_reason
self._event_iterator = None # reset so __aiter__ uses _get_event_iterator()

async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
for event in self._buffered_events:
yield event

@property
def model_name(self) -> str:
return self._wrapped.model_name

@property
def provider_name(self) -> str | None:
return self._wrapped.provider_name

@property
def provider_url(self) -> str | None:
return self._wrapped.provider_url

@property
def timestamp(self) -> datetime:
return self._wrapped.timestamp