Skip to content

Commit 8f3c3bf

Browse files
hcadiolixuanyang15
authored andcommitted
fix: cache canonical tools to avoid multiple calls when streaming
Merge #3299 Fixes #3237 Co-authored-by: Xuan Yang <[email protected]> COPYBARA_INTEGRATE_REVIEW=#3299 from hcadioli:fix/cache-tools de02bd3 PiperOrigin-RevId: 829499299
1 parent 9761fc6 commit 8f3c3bf

File tree

3 files changed

+81
-1
lines changed

3 files changed

+81
-1
lines changed

src/google/adk/agents/invocation_context.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from ..plugins.plugin_manager import PluginManager
3333
from ..sessions.base_session_service import BaseSessionService
3434
from ..sessions.session import Session
35+
from ..tools.base_tool import BaseTool
3536
from .active_streaming_tool import ActiveStreamingTool
3637
from .base_agent import BaseAgent
3738
from .base_agent import BaseAgentState
@@ -202,6 +203,9 @@ class InvocationContext(BaseModel):
202203
plugin_manager: PluginManager = Field(default_factory=PluginManager)
203204
"""The manager for keeping track of plugins in this invocation."""
204205

206+
canonical_tools_cache: Optional[list[BaseTool]] = None
207+
"""The cache of canonical tools for this invocation."""
208+
205209
_invocation_cost_manager: _InvocationCostManager = PrivateAttr(
206210
default_factory=_InvocationCostManager
207211
)

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -855,7 +855,10 @@ async def _maybe_add_grounding_metadata(
855855
response: Optional[LlmResponse] = None,
856856
) -> Optional[LlmResponse]:
857857
readonly_context = ReadonlyContext(invocation_context)
858-
tools = await agent.canonical_tools(readonly_context)
858+
if (tools := invocation_context.canonical_tools_cache) is None:
859+
tools = await agent.canonical_tools(readonly_context)
860+
invocation_context.canonical_tools_cache = tools
861+
859862
if not any(tool.name == 'google_search_agent' for tool in tools):
860863
return response
861864
ground_metadata = invocation_context.session.state.get(

tests/unittests/flows/llm_flows/test_base_llm_flow.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,3 +413,76 @@ def __init__(self):
413413

414414
assert result == plugin_response
415415
plugin.after_model_callback.assert_called_once()
416+
417+
418+
@pytest.mark.asyncio
419+
async def test_handle_after_model_callback_caches_canonical_tools():
420+
"""Test that canonical_tools is only called once per invocation_context."""
421+
canonical_tools_call_count = 0
422+
423+
async def mock_canonical_tools(self, readonly_context=None):
424+
nonlocal canonical_tools_call_count
425+
canonical_tools_call_count += 1
426+
from google.adk.tools.base_tool import BaseTool
427+
428+
class MockGoogleSearchTool(BaseTool):
429+
430+
def __init__(self):
431+
super().__init__(name='google_search_agent', description='Mock search')
432+
433+
async def call(self, **kwargs):
434+
return 'mock result'
435+
436+
return [MockGoogleSearchTool()]
437+
438+
agent = Agent(name='test_agent', tools=[google_search, dummy_tool])
439+
440+
with mock.patch.object(
441+
type(agent), 'canonical_tools', new=mock_canonical_tools
442+
):
443+
invocation_context = await testing_utils.create_invocation_context(
444+
agent=agent
445+
)
446+
447+
assert invocation_context.canonical_tools_cache is None
448+
449+
invocation_context.session.state['temp:_adk_grounding_metadata'] = {
450+
'foo': 'bar'
451+
}
452+
453+
llm_response = LlmResponse(
454+
content=types.Content(parts=[types.Part.from_text(text='response')])
455+
)
456+
event = Event(
457+
id=Event.new_id(),
458+
invocation_id=invocation_context.invocation_id,
459+
author=agent.name,
460+
)
461+
flow = BaseLlmFlowForTesting()
462+
463+
# Call _handle_after_model_callback multiple times with the same context
464+
result1 = await flow._handle_after_model_callback(
465+
invocation_context, llm_response, event
466+
)
467+
result2 = await flow._handle_after_model_callback(
468+
invocation_context, llm_response, event
469+
)
470+
result3 = await flow._handle_after_model_callback(
471+
invocation_context, llm_response, event
472+
)
473+
474+
assert canonical_tools_call_count == 1, (
475+
'canonical_tools should be called once, but was called '
476+
f'{canonical_tools_call_count} times'
477+
)
478+
479+
assert invocation_context.canonical_tools_cache is not None
480+
assert len(invocation_context.canonical_tools_cache) == 1
481+
assert (
482+
invocation_context.canonical_tools_cache[0].name
483+
== 'google_search_agent'
484+
)
485+
486+
assert result1.grounding_metadata == {'foo': 'bar'}
487+
assert result2.grounding_metadata == {'foo': 'bar'}
488+
assert result3.grounding_metadata == {'foo': 'bar'}

0 commit comments

Comments
 (0)