@@ -413,3 +413,76 @@ def __init__(self):
413413
414414 assert result == plugin_response
415415 plugin .after_model_callback .assert_called_once ()
416+
417+
418+ @pytest .mark .asyncio
419+ async def test_handle_after_model_callback_caches_canonical_tools ():
420+ """Test that canonical_tools is only called once per invocation_context."""
421+ canonical_tools_call_count = 0
422+
423+ async def mock_canonical_tools (self , readonly_context = None ):
424+ nonlocal canonical_tools_call_count
425+ canonical_tools_call_count += 1
426+ from google .adk .tools .base_tool import BaseTool
427+
428+ class MockGoogleSearchTool (BaseTool ):
429+
430+ def __init__ (self ):
431+ super ().__init__ (name = 'google_search_agent' , description = 'Mock search' )
432+
433+ async def call (self , ** kwargs ):
434+ return 'mock result'
435+
436+ return [MockGoogleSearchTool ()]
437+
438+ agent = Agent (name = 'test_agent' , tools = [google_search , dummy_tool ])
439+
440+ with mock .patch .object (
441+ type (agent ), 'canonical_tools' , new = mock_canonical_tools
442+ ):
443+ invocation_context = await testing_utils .create_invocation_context (
444+ agent = agent
445+ )
446+
447+ assert invocation_context .canonical_tools_cache is None
448+
449+ invocation_context .session .state ['temp:_adk_grounding_metadata' ] = {
450+ 'foo' : 'bar'
451+ }
452+
453+ llm_response = LlmResponse (
454+ content = types .Content (parts = [types .Part .from_text (text = 'response' )])
455+ )
456+ event = Event (
457+ id = Event .new_id (),
458+ invocation_id = invocation_context .invocation_id ,
459+ author = agent .name ,
460+ )
461+ flow = BaseLlmFlowForTesting ()
462+
463+ # Call _handle_after_model_callback multiple times with the same context
464+ result1 = await flow ._handle_after_model_callback (
465+ invocation_context , llm_response , event
466+ )
467+ result2 = await flow ._handle_after_model_callback (
468+ invocation_context , llm_response , event
469+ )
470+ result3 = await flow ._handle_after_model_callback (
471+ invocation_context , llm_response , event
472+ )
473+
474+ assert canonical_tools_call_count == 1 , (
475+ 'canonical_tools should be called once, but was called '
476+ f'{ canonical_tools_call_count } times'
477+ )
478+
479+ assert invocation_context .canonical_tools_cache is not None
480+ assert len (invocation_context .canonical_tools_cache ) == 1
481+ assert (
482+ invocation_context .canonical_tools_cache [0 ].name
483+ == 'google_search_agent'
484+ )
485+
486+ assert result1 .grounding_metadata == {'foo' : 'bar' }
487+ assert result2 .grounding_metadata == {'foo' : 'bar' }
488+ assert result3 .grounding_metadata == {'foo' : 'bar' }
0 commit comments