-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Add response-based fallback support for FallbackModel
#3786
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
c317be6
327b11f
8a86efe
01953fa
c2ea827
c140148
2315a54
623e6f5
874d503
84e2c24
4de4c9d
43a7015
67da9dd
21468ed
6291863
bf50264
3d78fd5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,22 +3,27 @@ | |
| from collections.abc import AsyncIterator, Callable | ||
| from contextlib import AsyncExitStack, asynccontextmanager, suppress | ||
| from dataclasses import dataclass, field | ||
| from datetime import datetime | ||
| from functools import cached_property | ||
| from typing import TYPE_CHECKING, Any | ||
| from typing import TYPE_CHECKING, Any, NoReturn | ||
|
|
||
| from opentelemetry.trace import get_current_span | ||
|
|
||
| from pydantic_ai._run_context import RunContext | ||
| from pydantic_ai.messages import ModelResponseStreamEvent, PartEndEvent | ||
| from pydantic_ai.models.instrumented import InstrumentedModel | ||
|
|
||
| from ..exceptions import FallbackExceptionGroup, ModelAPIError | ||
| from ..profiles import ModelProfile | ||
| from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model | ||
|
|
||
| if TYPE_CHECKING: | ||
| from ..messages import ModelMessage, ModelResponse | ||
| from ..messages import ModelMessage, ModelResponse, ModelResponsePart | ||
| from ..settings import ModelSettings | ||
|
|
||
| FallbackOnResponse = Callable[['ModelResponse', list['ModelMessage']], bool] | ||
| FallbackOnPart = Callable[['ModelResponsePart', list['ModelMessage']], bool] | ||
|
|
||
|
|
||
| @dataclass(init=False) | ||
| class FallbackModel(Model): | ||
|
|
@@ -31,19 +36,36 @@ class FallbackModel(Model): | |
|
|
||
| _model_name: str = field(repr=False) | ||
| _fallback_on: Callable[[Exception], bool] | ||
| _fallback_on_response: FallbackOnResponse | None | ||
| _fallback_on_part: FallbackOnPart | None | ||
|
|
||
| def __init__( | ||
| self, | ||
| default_model: Model | KnownModelName | str, | ||
| *fallback_models: Model | KnownModelName | str, | ||
| fallback_on: Callable[[Exception], bool] | tuple[type[Exception], ...] = (ModelAPIError,), | ||
| fallback_on_response: FallbackOnResponse | None = None, | ||
| fallback_on_part: FallbackOnPart | None = None, | ||
| ): | ||
| """Initialize a fallback model instance. | ||
|
|
||
| Args: | ||
| default_model: The name or instance of the default model to use. | ||
| fallback_models: The names or instances of the fallback models to use upon failure. | ||
| fallback_on: A callable or tuple of exceptions that should trigger a fallback. | ||
| For streaming requests, this only catches exceptions during stream initialization | ||
| (e.g., connection errors, authentication failures). Exceptions that occur mid-stream | ||
| after events have started flowing will propagate to the caller without triggering | ||
| fallback. Use `fallback_on_part` if you need to handle mid-stream failures. | ||
| fallback_on_response: A callable that inspects the model response and message history, | ||
| returning `True` if fallback should be triggered. This enables fallback based on | ||
| response content (e.g., a builtin tool indicating failure) rather than exceptions. | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let's also mention the finish reason as an example |
||
| fallback_on_part: A callable that inspects each model response part during streaming, | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I do wonder if this is the best API, or if we should somehow make
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. That would be ideal, if we can figure it out in a backward compatible way :) If we change That's not the end of the world, but a little awkward, also because the user would still need to check So perhaps it'd be better to support a list of exception types, exception handler functions, and response handler functions? And if the type of a function is not specified, we assume it's an exception handler function? That way it'd also be easier to reuse checks like your |
||
| returning `True` if fallback should be triggered. This enables early abort when | ||
| a failure condition is detected (e.g., a builtin tool failure), saving tokens by | ||
| not consuming the rest of a doomed response. Only applies to streaming requests. | ||
| Note: The response is buffered until validation completes, so the caller receives | ||
| events after the full response is validated, not progressively. | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The use case is valid but I'm a lot more hesitant to introduce 2 new fields than 1 😅 Note that the That seems like enough of a rabbit hole that it shouldn't block the release of the main feature, but feel free to look into it in a separate issue/PR. |
||
| """ | ||
| super().__init__() | ||
| self.models = [infer_model(default_model), *[infer_model(m) for m in fallback_models]] | ||
|
|
@@ -53,6 +75,9 @@ def __init__( | |
| else: | ||
| self._fallback_on = fallback_on | ||
|
|
||
| self._fallback_on_response = fallback_on_response | ||
| self._fallback_on_part = fallback_on_part | ||
|
|
||
| @property | ||
| def model_name(self) -> str: | ||
| """The model name.""" | ||
|
|
@@ -77,6 +102,7 @@ async def request( | |
| In case of failure, raise a FallbackExceptionGroup with all exceptions. | ||
| """ | ||
| exceptions: list[Exception] = [] | ||
| response_rejections: int = 0 | ||
|
|
||
| for model in self.models: | ||
| try: | ||
|
|
@@ -88,10 +114,14 @@ async def request( | |
| continue | ||
| raise exc | ||
|
|
||
| if self._fallback_on_response is not None and self._fallback_on_response(response, messages): | ||
| response_rejections += 1 | ||
| continue | ||
|
|
||
| self._set_span_attributes(model, prepared_parameters) | ||
| return response | ||
|
|
||
| raise FallbackExceptionGroup('All models from FallbackModel failed', exceptions) | ||
| _raise_fallback_exception_group(exceptions, response_rejections) | ||
|
|
||
| @asynccontextmanager | ||
| async def request_stream( | ||
|
|
@@ -103,6 +133,8 @@ async def request_stream( | |
| ) -> AsyncIterator[StreamedResponse]: | ||
| """Try each model in sequence until one succeeds.""" | ||
| exceptions: list[Exception] = [] | ||
| response_rejections: int = 0 | ||
| part_rejections: int = 0 | ||
|
|
||
| for model in self.models: | ||
| async with AsyncExitStack() as stack: | ||
|
|
@@ -117,11 +149,46 @@ async def request_stream( | |
| continue | ||
| raise exc # pragma: no cover | ||
|
|
||
| if self._fallback_on_part is not None: | ||
| buffered_events: list[ModelResponseStreamEvent] = [] | ||
| should_fallback = False | ||
|
|
||
| async for event in response: | ||
| buffered_events.append(event) | ||
| if isinstance(event, PartEndEvent) and self._fallback_on_part(event.part, messages): | ||
| should_fallback = True | ||
| break | ||
|
|
||
| if should_fallback: | ||
| part_rejections += 1 | ||
| continue | ||
|
|
||
| if self._fallback_on_response is not None and self._fallback_on_response(response.get(), messages): | ||
| response_rejections += 1 | ||
| continue | ||
|
|
||
| self._set_span_attributes(model, prepared_parameters) | ||
| yield BufferedStreamedResponse(_wrapped=response, _buffered_events=buffered_events) | ||
| return | ||
|
|
||
| elif self._fallback_on_response is not None: | ||
| buffered_events = [] | ||
| async for event in response: | ||
| buffered_events.append(event) | ||
|
|
||
| if self._fallback_on_response(response.get(), messages): | ||
| response_rejections += 1 | ||
| continue | ||
|
|
||
| self._set_span_attributes(model, prepared_parameters) | ||
| yield BufferedStreamedResponse(_wrapped=response, _buffered_events=buffered_events) | ||
| return | ||
|
|
||
| self._set_span_attributes(model, prepared_parameters) | ||
| yield response | ||
| return | ||
|
|
||
| raise FallbackExceptionGroup('All models from FallbackModel failed', exceptions) | ||
| _raise_fallback_exception_group(exceptions, response_rejections, part_rejections) | ||
|
|
||
| @cached_property | ||
| def profile(self) -> ModelProfile: | ||
|
|
@@ -135,7 +202,7 @@ def prepare_request( | |
| ) -> tuple[ModelSettings | None, ModelRequestParameters]: | ||
| return model_settings, model_request_parameters | ||
|
|
||
| def _set_span_attributes(self, model: Model, model_request_parameters: ModelRequestParameters): | ||
| def _set_span_attributes(self, model: Model, model_request_parameters: ModelRequestParameters) -> None: | ||
| with suppress(Exception): | ||
| span = get_current_span() | ||
| if span.is_recording(): | ||
|
|
@@ -156,3 +223,62 @@ def fallback_condition(exception: Exception) -> bool: | |
| return isinstance(exception, exceptions) | ||
|
|
||
| return fallback_condition | ||
|
|
||
|
|
||
| def _raise_fallback_exception_group( | ||
| exceptions: list[Exception], response_rejections: int, part_rejections: int = 0 | ||
| ) -> NoReturn: | ||
| """Raise a FallbackExceptionGroup combining exceptions and rejections.""" | ||
| all_errors: list[Exception] = list(exceptions) | ||
| if part_rejections > 0: | ||
| all_errors.append(RuntimeError(f'{part_rejections} model(s) rejected by fallback_on_part during streaming')) | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should we add more detail here on which part failed? |
||
| if response_rejections > 0: | ||
| all_errors.append(RuntimeError(f'{response_rejections} model response(s) rejected by fallback_on_response')) | ||
|
|
||
| if all_errors: | ||
| raise FallbackExceptionGroup('All models from FallbackModel failed', all_errors) | ||
| else: | ||
| raise FallbackExceptionGroup( | ||
| 'All models from FallbackModel failed', | ||
| [RuntimeError('No models available')], | ||
| ) # pragma: no cover | ||
|
|
||
|
|
||
| @dataclass | ||
| class BufferedStreamedResponse(StreamedResponse): | ||
| """A StreamedResponse wrapper that replays buffered events.""" | ||
|
|
||
| _wrapped: StreamedResponse | ||
| _buffered_events: list[ModelResponseStreamEvent] | ||
|
|
||
| model_request_parameters: ModelRequestParameters = field(init=False) | ||
|
|
||
| def __post_init__(self) -> None: | ||
| self.model_request_parameters = self._wrapped.model_request_parameters | ||
| self._parts_manager = self._wrapped._parts_manager | ||
| self._usage = self._wrapped._usage | ||
| self.final_result_event = self._wrapped.final_result_event | ||
| self.provider_response_id = self._wrapped.provider_response_id | ||
| self.provider_details = self._wrapped.provider_details | ||
| self.finish_reason = self._wrapped.finish_reason | ||
| self._event_iterator = None # reset so __aiter__ uses _get_event_iterator() | ||
|
|
||
| async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: | ||
| for event in self._buffered_events: | ||
| yield event | ||
|
|
||
| @property | ||
| def model_name(self) -> str: | ||
| return self._wrapped.model_name | ||
|
|
||
| @property | ||
| def provider_name(self) -> str | None: | ||
| return self._wrapped.provider_name | ||
|
|
||
| @property | ||
| def provider_url(self) -> str | None: | ||
| return self._wrapped.provider_url | ||
|
|
||
| @property | ||
| def timestamp(self) -> datetime: | ||
| return self._wrapped.timestamp | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note to self to review docs later