diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index b87a083ac..a4317de07 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -42,6 +42,7 @@ from ...telemetry import trace_call_llm from ...telemetry import trace_send_data from ...telemetry import tracer +from ...tools.base_toolset import BaseToolset from ...tools.tool_context import ToolContext if TYPE_CHECKING: @@ -341,13 +342,25 @@ async def _preprocess_async( yield event # Run processors for tools. - for tool in await agent.canonical_tools( - ReadonlyContext(invocation_context) - ): + for tool_union in agent.tools: tool_context = ToolContext(invocation_context) - await tool.process_llm_request( - tool_context=tool_context, llm_request=llm_request + + # If it's a toolset, process it first + if isinstance(tool_union, BaseToolset): + await tool_union.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) + + from ...agents.llm_agent import _convert_tool_union_to_tools + + # Then process all tools from this tool union + tools = await _convert_tool_union_to_tools( + tool_union, ReadonlyContext(invocation_context) ) + for tool in tools: + await tool.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) async def _postprocess_async( self, diff --git a/src/google/adk/tools/base_toolset.py b/src/google/adk/tools/base_toolset.py index 6003f560b..706b4c42c 100644 --- a/src/google/adk/tools/base_toolset.py +++ b/src/google/adk/tools/base_toolset.py @@ -20,11 +20,16 @@ from typing import Optional from typing import Protocol from typing import runtime_checkable +from typing import TYPE_CHECKING from typing import Union from ..agents.readonly_context import ReadonlyContext from .base_tool import BaseTool +if TYPE_CHECKING: + from ..models.llm_request import LlmRequest + from .tool_context import ToolContext + @runtime_checkable class ToolPredicate(Protocol): @@ -96,3 +101,20 @@ def _is_tool_selected( return tool.name in self.tool_filter return False + + async def process_llm_request( + self, *, tool_context: ToolContext, llm_request: LlmRequest + ) -> None: + """Processes the outgoing LLM request for this toolset. This method will be + called before each tool processes the llm request. + + Use cases: + - Instead of let each tool process the llm request, we can let the toolset + process the llm request. e.g. ComputerUseToolset can add computer use + tool to the llm request. + + Args: + tool_context: The context of the tool. + llm_request: The outgoing LLM request, mutable this method. + """ + pass diff --git a/tests/unittests/flows/llm_flows/test_base_llm_flow.py b/tests/unittests/flows/llm_flows/test_base_llm_flow.py new file mode 100644 index 000000000..82333c45a --- /dev/null +++ b/tests/unittests/flows/llm_flows/test_base_llm_flow.py @@ -0,0 +1,150 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for BaseLlmFlow toolset integration.""" + +from unittest.mock import AsyncMock + +from google.adk.agents import Agent +from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow +from google.adk.models.llm_request import LlmRequest +from google.adk.models.llm_response import LlmResponse +from google.adk.tools.base_toolset import BaseToolset +from google.genai import types +import pytest + +from ... import testing_utils + + +class BaseLlmFlowForTesting(BaseLlmFlow): + """Test implementation of BaseLlmFlow for testing purposes.""" + + pass + + +@pytest.mark.asyncio +async def test_preprocess_calls_toolset_process_llm_request(): + """Test that _preprocess_async calls process_llm_request on toolsets.""" + + # Create a mock toolset that tracks if process_llm_request was called + class _MockToolset(BaseToolset): + + def __init__(self): + super().__init__() + self.process_llm_request_called = False + self.process_llm_request = AsyncMock(side_effect=self._track_call) + + async def _track_call(self, **kwargs): + self.process_llm_request_called = True + + async def get_tools(self, readonly_context=None): + return [] + + async def close(self): + pass + + mock_toolset = _MockToolset() + + # Create a mock model that returns a simple response + mock_response = LlmResponse( + content=types.Content( + role='model', parts=[types.Part.from_text(text='Test response')] + ), + partial=False, + ) + + mock_model = testing_utils.MockModel.create(responses=[mock_response]) + + # Create agent with the mock toolset + agent = Agent(name='test_agent', model=mock_model, tools=[mock_toolset]) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='test message' + ) + + flow = BaseLlmFlowForTesting() + + # Call _preprocess_async + llm_request = LlmRequest() + events = [] + async for event in flow._preprocess_async(invocation_context, llm_request): + events.append(event) + + # Verify that process_llm_request was called on the toolset + assert mock_toolset.process_llm_request_called + + +@pytest.mark.asyncio +async def test_preprocess_handles_mixed_tools_and_toolsets(): + """Test that _preprocess_async properly handles both tools and toolsets.""" + from google.adk.tools.base_tool import BaseTool + from google.adk.tools.function_tool import FunctionTool + + # Create a mock tool + class _MockTool(BaseTool): + + def __init__(self): + super().__init__(name='mock_tool', description='Mock tool') + self.process_llm_request_called = False + self.process_llm_request = AsyncMock(side_effect=self._track_call) + + async def _track_call(self, **kwargs): + self.process_llm_request_called = True + + async def call(self, **kwargs): + return 'mock result' + + # Create a mock toolset + class _MockToolset(BaseToolset): + + def __init__(self): + super().__init__() + self.process_llm_request_called = False + self.process_llm_request = AsyncMock(side_effect=self._track_call) + + async def _track_call(self, **kwargs): + self.process_llm_request_called = True + + async def get_tools(self, readonly_context=None): + return [] + + async def close(self): + pass + + def _test_function(): + """Test function tool.""" + return 'function result' + + mock_tool = _MockTool() + mock_toolset = _MockToolset() + + # Create agent with mixed tools and toolsets + agent = Agent( + name='test_agent', tools=[mock_tool, _test_function, mock_toolset] + ) + + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='test message' + ) + + flow = BaseLlmFlowForTesting() + + # Call _preprocess_async + llm_request = LlmRequest() + events = [] + async for event in flow._preprocess_async(invocation_context, llm_request): + events.append(event) + + # Verify that process_llm_request was called on both tools and toolsets + assert mock_tool.process_llm_request_called + assert mock_toolset.process_llm_request_called diff --git a/tests/unittests/tools/test_base_toolset.py b/tests/unittests/tools/test_base_toolset.py new file mode 100644 index 000000000..5414bb3c8 --- /dev/null +++ b/tests/unittests/tools/test_base_toolset.py @@ -0,0 +1,109 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for BaseToolset.""" + +from typing import Optional + +from google.adk.agents.invocation_context import InvocationContext +from google.adk.agents.readonly_context import ReadonlyContext +from google.adk.agents.sequential_agent import SequentialAgent +from google.adk.models.llm_request import LlmRequest +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.adk.tools.base_tool import BaseTool +from google.adk.tools.base_toolset import BaseToolset +from google.adk.tools.tool_context import ToolContext +import pytest + + +class _TestingToolset(BaseToolset): + """A test implementation of BaseToolset.""" + + async def get_tools( + self, readonly_context: Optional[ReadonlyContext] = None + ) -> list[BaseTool]: + return [] + + async def close(self) -> None: + pass + + +@pytest.mark.asyncio +async def test_process_llm_request_default_implementation(): + """Test that the default process_llm_request implementation does nothing.""" + toolset = _TestingToolset() + + # Create test objects + session_service = InMemorySessionService() + session = await session_service.create_session( + app_name='test_app', user_id='test_user' + ) + agent = SequentialAgent(name='test_agent') + invocation_context = InvocationContext( + invocation_id='test_id', + agent=agent, + session=session, + session_service=session_service, + ) + tool_context = ToolContext(invocation_context) + llm_request = LlmRequest() + + # The default implementation should not modify the request + original_request = LlmRequest.model_validate(llm_request.model_dump()) + + await toolset.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) + + # Verify the request was not modified + assert llm_request.model_dump() == original_request.model_dump() + + +@pytest.mark.asyncio +async def test_process_llm_request_can_be_overridden(): + """Test that process_llm_request can be overridden by subclasses.""" + + class _CustomToolset(_TestingToolset): + + async def process_llm_request( + self, *, tool_context: ToolContext, llm_request: LlmRequest + ) -> None: + # Add some custom processing + if not llm_request.contents: + llm_request.contents = [] + llm_request.contents.append('Custom processing applied') + + toolset = _CustomToolset() + + # Create test objects + session_service = InMemorySessionService() + session = await session_service.create_session( + app_name='test_app', user_id='test_user' + ) + agent = SequentialAgent(name='test_agent') + invocation_context = InvocationContext( + invocation_id='test_id', + agent=agent, + session=session, + session_service=session_service, + ) + tool_context = ToolContext(invocation_context) + llm_request = LlmRequest() + + await toolset.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) + + # Verify the custom processing was applied + assert llm_request.contents == ['Custom processing applied']