Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/google/adk/agents/invocation_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from ..plugins.plugin_manager import PluginManager
from ..sessions.base_session_service import BaseSessionService
from ..sessions.session import Session
from ..tools.base_tool import BaseTool
from .active_streaming_tool import ActiveStreamingTool
from .base_agent import BaseAgent
from .base_agent import BaseAgentState
Expand Down Expand Up @@ -202,6 +203,9 @@ class InvocationContext(BaseModel):
plugin_manager: PluginManager = Field(default_factory=PluginManager)
"""The manager for keeping track of plugins in this invocation."""

canonical_tools_cache: Optional[list[BaseTool]] = None
"""The cache of canonical tools for this invocation."""

_invocation_cost_manager: _InvocationCostManager = PrivateAttr(
default_factory=_InvocationCostManager
)
Expand Down
5 changes: 4 additions & 1 deletion src/google/adk/flows/llm_flows/base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -842,7 +842,10 @@ async def _maybe_add_grounding_metadata(
response: Optional[LlmResponse] = None,
) -> Optional[LlmResponse]:
readonly_context = ReadonlyContext(invocation_context)
tools = await agent.canonical_tools(readonly_context)
if (tools := invocation_context.canonical_tools_cache) is None:
tools = await agent.canonical_tools(readonly_context)
invocation_context.canonical_tools_cache = tools

if not any(tool.name == 'google_search_agent' for tool in tools):
return response
ground_metadata = invocation_context.session.state.get(
Expand Down
64 changes: 64 additions & 0 deletions tests/unittests/flows/llm_flows/test_base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,3 +413,67 @@ def __init__(self):

assert result == plugin_response
plugin.after_model_callback.assert_called_once()

@pytest.mark.asyncio
async def test_handle_after_model_callback_caches_canonical_tools():
"""Test that canonical_tools is only called once per invocation_context."""
canonical_tools_call_count = 0

async def mock_canonical_tools(self, readonly_context=None):
nonlocal canonical_tools_call_count
canonical_tools_call_count += 1
from google.adk.tools.base_tool import BaseTool

class MockGoogleSearchTool(BaseTool):
def __init__(self):
super().__init__(name="google_search_agent", description="Mock search")

async def call(self, **kwargs):
return "mock result"

return [MockGoogleSearchTool()]

agent = Agent(name="test_agent", tools=[google_search, dummy_tool])

with mock.patch.object(type(agent), "canonical_tools", new=mock_canonical_tools):
invocation_context = await testing_utils.create_invocation_context(agent=agent)

assert invocation_context.canonical_tools_cache is None

invocation_context.session.state["temp:_adk_grounding_metadata"] = {
"foo": "bar"
}

llm_response = LlmResponse(
content=types.Content(parts=[types.Part.from_text(text="response")])
)
event = Event(
id=Event.new_id(),
invocation_id=invocation_context.invocation_id,
author=agent.name,
)
flow = BaseLlmFlowForTesting()

# Call _handle_after_model_callback multiple times with the same context
result1 = await flow._handle_after_model_callback(
invocation_context, llm_response, event
)
result2 = await flow._handle_after_model_callback(
invocation_context, llm_response, event
)
result3 = await flow._handle_after_model_callback(
invocation_context, llm_response, event
)

assert canonical_tools_call_count == 1, (
f"canonical_tools should be called once, but was called "
f"{canonical_tools_call_count} times"
)

assert invocation_context.canonical_tools_cache is not None
assert len(invocation_context.canonical_tools_cache) == 1
assert invocation_context.canonical_tools_cache[0].name == "google_search_agent"

assert result1.grounding_metadata == {"foo": "bar"}
assert result2.grounding_metadata == {"foo": "bar"}
assert result3.grounding_metadata == {"foo": "bar"}
Loading