Skip to content

Commit b68cfe9

Browse files
pwwpche and copybara-github authored and committed
feat: Support retry when model invocation fails
PiperOrigin-RevId: 783974317
1 parent 33ac838 commit b68cfe9

File tree

10 files changed

+412
-6
lines changed

10 files changed

+412
-6
lines changed

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@
1717
from abc import ABC
1818
import asyncio
1919
import datetime
20+
from enum import Enum
2021
import inspect
2122
import logging
2223
from typing import AsyncGenerator
24+
from typing import Callable
2325
from typing import cast
2426
from typing import Optional
2527
from typing import TYPE_CHECKING
@@ -36,6 +38,7 @@
3638
from ...agents.run_config import StreamingMode
3739
from ...agents.transcription_entry import TranscriptionEntry
3840
from ...events.event import Event
41+
from ...models.base_llm import ModelErrorStrategy
3942
from ...models.base_llm_connection import BaseLlmConnection
4043
from ...models.llm_request import LlmRequest
4144
from ...models.llm_response import LlmResponse
@@ -521,7 +524,13 @@ async def _call_llm_async(
521524
with tracer.start_as_current_span('call_llm'):
522525
if invocation_context.run_config.support_cfc:
523526
invocation_context.live_request_queue = LiveRequestQueue()
524-
async for llm_response in self.run_live(invocation_context):
527+
responses_generator = lambda: self.run_live(invocation_context)
528+
async for llm_response in self._run_and_handle_error(
529+
responses_generator,
530+
invocation_context,
531+
llm_request,
532+
model_response_event,
533+
):
525534
# Runs after_model_callback if it exists.
526535
if altered_llm_response := await self._handle_after_model_callback(
527536
invocation_context, llm_response, model_response_event
@@ -540,10 +549,16 @@ async def _call_llm_async(
540549
# the counter beyond the max set value, then the execution is stopped
541550
# right here, and exception is thrown.
542551
invocation_context.increment_llm_call_count()
543-
async for llm_response in llm.generate_content_async(
552+
responses_generator = lambda: llm.generate_content_async(
544553
llm_request,
545554
stream=invocation_context.run_config.streaming_mode
546555
== StreamingMode.SSE,
556+
)
557+
async for llm_response in self._run_and_handle_error(
558+
responses_generator,
559+
invocation_context,
560+
llm_request,
561+
model_response_event,
547562
):
548563
trace_call_llm(
549564
invocation_context,
@@ -660,6 +675,54 @@ def _finalize_model_response_event(
660675

661676
return model_response_event
662677

678+
async def _run_and_handle_error(
    self,
    response_generator: Callable[..., AsyncGenerator[LlmResponse, None]],
    invocation_context: InvocationContext,
    llm_request: LlmRequest,
    model_response_event: Event,
) -> AsyncGenerator[LlmResponse, None]:
  """Runs the response generator and processes the error with plugins.

  Each loop iteration calls `response_generator()` again, so a plugin-driven
  retry restarts the underlying model call from scratch.

  NOTE(review): responses yielded before a mid-stream failure are not rolled
  back, so a RETRY outcome may re-yield earlier responses — confirm callers
  tolerate duplicates.

  Args:
    response_generator: The response generator to run. A zero-argument
      callable returning a fresh async generator of LlmResponses.
    invocation_context: The invocation context.
    llm_request: The LLM request.
    model_response_event: The model response event; its actions are exposed
      to plugin error callbacks via the CallbackContext.

  Yields:
    A generator of LlmResponse.

  Raises:
    Exception: re-raises the original model error when no plugin handles it
      (i.e. `run_on_model_error_callback` returns None).
  """
  while True:
    try:
      responses_generator_instance = response_generator()
      async for response in responses_generator_instance:
        yield response
      # Generator finished without raising: nothing left to do.
      break
    except Exception as model_error:
      # Let registered plugins inspect and possibly handle the failure.
      callback_context = CallbackContext(
          invocation_context, event_actions=model_response_event.actions
      )
      outcome = (
          await invocation_context.plugin_manager.run_on_model_error_callback(
              callback_context=callback_context,
              llm_request=llm_request,
              error=model_error,
          )
      )
      # Retry the LLM call if the plugin outcome is RETRY.
      if outcome == ModelErrorStrategy.RETRY:
        continue

      # If the plugin outcome is PASS, we can break the loop.
      if outcome == ModelErrorStrategy.PASS:
        break
      if outcome is not None:
        # A plugin supplied a fallback LlmResponse: surface it to the caller
        # instead of propagating the error.
        yield outcome
        break
      else:
        # No plugin handled the error: propagate it unchanged.
        raise model_error
725+
663726
def __get_llm(self, invocation_context: InvocationContext) -> BaseLlm:
664727
from ...agents.llm_agent import LlmAgent
665728

src/google/adk/flows/llm_flows/functions.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -176,9 +176,21 @@ async def handle_function_calls_async(
176176

177177
# Step 3: Otherwise, proceed calling the tool normally.
178178
if function_response is None:
179-
function_response = await __call_tool_async(
180-
tool, args=function_args, tool_context=tool_context
181-
)
179+
try:
180+
function_response = await __call_tool_async(
181+
tool, args=function_args, tool_context=tool_context
182+
)
183+
except Exception as tool_error:
184+
error_response = await invocation_context.plugin_manager.run_on_tool_error_callback(
185+
tool=tool,
186+
tool_args=function_args,
187+
tool_context=tool_context,
188+
error=tool_error,
189+
)
190+
if error_response is not None:
191+
function_response = error_response
192+
else:
193+
raise tool_error
182194

183195
# Step 4: Check if plugin after_tool_callback overrides the function
184196
# response.

src/google/adk/models/base_llm.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,11 @@
2828
from .llm_response import LlmResponse
2929

3030

31+
class ModelErrorStrategy:
  """Outcomes a plugin `on_model_error_callback` may return to steer handling.

  These are plain string constants (not an Enum), so callbacks may compare
  or return the raw string values interchangeably.
  """

  # Re-run the failed model call.
  RETRY = 'RETRY'
  # Ignore the error and continue as if the call had completed.
  PASS = 'PASS'
34+
35+
3136
class BaseLlm(BaseModel):
3237
"""The BaseLLM class.
3338

src/google/adk/plugins/base_plugin.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@
2525
from ..agents.base_agent import BaseAgent
2626
from ..agents.callback_context import CallbackContext
2727
from ..events.event import Event
28+
from ..models.base_llm import ModelErrorStrategy
2829
from ..models.llm_request import LlmRequest
2930
from ..models.llm_response import LlmResponse
3031
from ..tools.base_tool import BaseTool
31-
from ..utils.feature_decorator import working_in_progress
3232

3333
if TYPE_CHECKING:
3434
from ..agents.invocation_context import InvocationContext
@@ -265,6 +265,34 @@ async def after_model_callback(
265265
"""
266266
pass
267267

268+
async def on_model_error_callback(
    self,
    *,
    callback_context: CallbackContext,
    llm_request: LlmRequest,
    error: Exception,
) -> Optional[LlmResponse | ModelErrorStrategy]:
  """Hook invoked when a model call raises an exception.

  Override this to recover from model failures — for example by supplying a
  fallback response, requesting a retry, or suppressing the error.

  Args:
    callback_context: The context for the current agent call.
    llm_request: The request that was being sent to the model when the
      error occurred.
    error: The exception raised during model execution.

  Returns:
    One of:
      * An `LlmResponse` — used in place of the failed call's output.
      * `ModelErrorStrategy.RETRY` — the LLM call is attempted again.
      * `ModelErrorStrategy.PASS` — the error is ignored and execution
        proceeds normally.
      * `None` — the original error is propagated to the caller.
  """
  pass
295+
268296
async def before_tool_callback(
269297
self,
270298
*,
@@ -315,3 +343,29 @@ async def after_tool_callback(
315343
result.
316344
"""
317345
pass
346+
347+
async def on_tool_error_callback(
    self,
    *,
    tool: BaseTool,
    tool_args: dict[str, Any],
    tool_context: ToolContext,
    error: Exception,
) -> Optional[dict]:
  """Hook invoked when a tool call raises an exception.

  Override this to recover from tool failures, e.g. by returning a
  substitute result instead of letting the error propagate.

  Args:
    tool: The tool instance that encountered an error.
    tool_args: The arguments that were passed to the tool.
    tool_context: The context specific to the tool execution.
    error: The exception raised during tool execution.

  Returns:
    A dictionary to use as the tool response in place of the error, or
    `None` to let the original error propagate.
  """
  pass

src/google/adk/plugins/plugin_manager.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from ..agents.callback_context import CallbackContext
3131
from ..agents.invocation_context import InvocationContext
3232
from ..events.event import Event
33+
from ..models.base_llm import ModelErrorStrategy
3334
from ..models.llm_request import LlmRequest
3435
from ..models.llm_response import LlmResponse
3536
from ..tools.base_tool import BaseTool
@@ -48,6 +49,8 @@
4849
"after_tool_callback",
4950
"before_model_callback",
5051
"after_model_callback",
52+
"on_tool_error_callback",
53+
"on_model_error_callback",
5154
]
5255

5356
logger = logging.getLogger("google_adk." + __name__)
@@ -195,6 +198,21 @@ async def run_after_tool_callback(
195198
result=result,
196199
)
197200

201+
async def run_on_model_error_callback(
    self,
    *,
    callback_context: CallbackContext,
    llm_request: LlmRequest,
    error: Exception,
) -> Optional[LlmResponse | ModelErrorStrategy]:
  """Runs the `on_model_error_callback` for all plugins.

  Delegates to the shared callback runner; the first plugin returning a
  non-None value short-circuits the chain.
  """
  callback_kwargs = {
      "callback_context": callback_context,
      "llm_request": llm_request,
      "error": error,
  }
  return await self._run_callbacks(
      "on_model_error_callback", **callback_kwargs
  )
215+
198216
async def run_before_model_callback(
199217
self, *, callback_context: CallbackContext, llm_request: LlmRequest
200218
) -> Optional[LlmResponse]:
@@ -215,6 +233,23 @@ async def run_after_model_callback(
215233
llm_response=llm_response,
216234
)
217235

236+
async def run_on_tool_error_callback(
    self,
    *,
    tool: BaseTool,
    tool_args: dict[str, Any],
    tool_context: ToolContext,
    error: Exception,
) -> Optional[dict]:
  """Runs the `on_tool_error_callback` for all plugins.

  Delegates to the shared callback runner; the first plugin returning a
  non-None value short-circuits the chain.
  """
  callback_kwargs = {
      "tool": tool,
      "tool_args": tool_args,
      "tool_context": tool_context,
      "error": error,
  }
  return await self._run_callbacks(
      "on_tool_error_callback", **callback_kwargs
  )
252+
218253
async def _run_callbacks(
219254
self, callback_name: PluginCallbackName, **kwargs: Any
220255
) -> Optional[Any]:

0 commit comments

Comments
 (0)