diff --git a/examples/scaffolding/contrib/a2a/README.md b/examples/scaffolding/contrib/a2a/README.md new file mode 100644 index 000000000000..d1cda6d6d9b3 --- /dev/null +++ b/examples/scaffolding/contrib/a2a/README.md @@ -0,0 +1,57 @@ +# Scaffolding A2A (Agent2Agent) Example + +This example shows how a Scaffolding controller can delegate work to remote +agents that speak the [A2A (Agent2Agent) protocol](https://a2a-protocol.org/), +the agent-to-agent counterpart to the MCP tool-calling contrib. The generation +model decides which remote agent to call; the `A2AWorker` forwards the message +over A2A and feeds the reply back to the model for a final answer. + +This is a **client-side** integration: Scaffolding acts as an A2A client that +consumes other agents. (Exposing a Scaffolding pipeline *as* an A2A server is a +possible follow-up.) + +## Install + +```bash +pip install a2a-sdk httpx uvicorn +``` + +`a2a-sdk` is only needed when actually talking to a remote agent; the contrib +imports it lazily, and the unit tests do not require it. + +## Step 1: Start a remote A2A agent + +A minimal reference agent server is included: + +```bash +python weather_agent_server.py --port 9999 +``` + +This exposes a `weather_agent` on all interfaces (it binds to `0.0.0.0`); locally its +agent card is at `http://localhost:9999/.well-known/agent-card.json`. You can also +point the example at any other A2A-compatible agent server. + +## Step 2: Run the orchestrator + +```bash +python a2a_run.py \ + --API_KEY YOUR_API_KEY \ + --base_url https://your-openai-compatible-endpoint/v1 \ + --model your-model \ + --agent_urls http://localhost:9999 \ + --prompt "What is the weather in LA?" +``` + +The generation model receives the remote agents as callable tools, delegates to +`weather_agent`, and summarizes its reply. 
+ +## Files + +| File | Role | +|------|------| +| `a2a_run.py` | Scaffolding A2A orchestrator runner (client) | +| `weather_agent_server.py` | Minimal reference A2A agent server for local testing | + +> `a2a-sdk` server/client APIs evolve across versions. The scripts target the +> SDK's published "helloworld" pattern; adjust imports if your installed version +> differs. diff --git a/examples/scaffolding/contrib/a2a/a2a_run.py b/examples/scaffolding/contrib/a2a/a2a_run.py new file mode 100644 index 000000000000..3abd49ec427d --- /dev/null +++ b/examples/scaffolding/contrib/a2a/a2a_run.py @@ -0,0 +1,84 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Run a Scaffolding A2A orchestrator against one or more remote A2A agents. + +The generation side uses an OpenAI-compatible endpoint (any vendor, or a local +``trtllm-serve``); the orchestration side talks the Agent2Agent protocol via +``A2AWorker``. See README.md for how to start a sample remote agent server. 
import argparse
import asyncio

from openai import AsyncOpenAI

from tensorrt_llm.scaffolding import OpenaiWorker, ScaffoldingLlm
from tensorrt_llm.scaffolding.contrib.a2a import A2AController, A2AWorker
from tensorrt_llm.scaffolding.contrib.mcp.chat_handler import chat_handler
from tensorrt_llm.scaffolding.contrib.mcp.chat_task import ChatTask


def parse_arguments():
    """Parse the command-line options of the A2A orchestrator example.

    Returns:
        The ``argparse.Namespace`` with ``base_url``, ``model``, ``API_KEY``,
        ``agent_urls`` (a list) and ``prompt``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--base_url",
        type=str,
        default="https://dashscope.aliyuncs.com/compatible-mode/v1",
        help="OpenAI-compatible base URL for the generation model.",
    )
    parser.add_argument("--model", type=str, default="qwen-plus-latest")
    parser.add_argument("--API_KEY", type=str)
    parser.add_argument(
        "--agent_urls",
        type=str,
        nargs="+",
        default=["http://0.0.0.0:9999"],
        help="Base URLs of the remote A2A agents to orchestrate.",
    )
    parser.add_argument("--prompt", type=str, default="What is the weather like today in LA?")
    return parser.parse_args()


async def main():
    """Run one prompt through the A2A controller and print the final answer.

    Everything that can fail after the generation worker exists runs inside a
    ``try`` block so the workers (and the HTTP clients held by the A2A worker)
    are released even when connecting to an agent or generating raises.
    """
    args = parse_arguments()

    client = AsyncOpenAI(api_key=args.API_KEY, base_url=args.base_url)
    generation_worker = OpenaiWorker(client, args.model)
    generation_worker.register_task_handler(ChatTask, chat_handler)

    # Both are created inside the try block; ``None`` marks "nothing to close yet".
    a2a_worker = None
    llm = None
    try:
        a2a_worker = await A2AWorker.init_with_urls(args.agent_urls)

        controller = A2AController()
        llm = ScaffoldingLlm(
            controller,
            {
                A2AController.WorkerTag.GENERATION: generation_worker,
                A2AController.WorkerTag.A2A: a2a_worker,
            },
        )

        future = llm.generate_async(args.prompt)
        result = await future.aresult()
        print(f"\nresult is {result.outputs[0].text}\n")
    finally:
        print("shutting down...")
        if llm is not None:
            llm.shutdown()
        generation_worker.shutdown()
        if a2a_worker is not None:
            await a2a_worker.async_shutdown()
        print("shut down done")


if __name__ == "__main__":
    asyncio.run(main())
000000000000..b42b25d44005 --- /dev/null +++ b/examples/scaffolding/contrib/a2a/weather_agent_server.py @@ -0,0 +1,87 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A minimal reference A2A agent server used to exercise the A2A contrib. + +This follows the ``a2a-sdk`` "helloworld" server pattern and exposes a single +``weather_agent`` that returns a canned reply. Run it with:: + + pip install a2a-sdk uvicorn + python weather_agent_server.py --port 9999 + +then point ``a2a_run.py --agent_urls http://0.0.0.0:9999`` at it. + +Note: ``a2a-sdk`` server APIs evolve; this script targets the published +helloworld example. Adjust imports if your installed SDK version differs. 
import argparse

import uvicorn
from a2a.server.agent_execution import AgentExecutor, RequestContext
from a2a.server.apps import A2AStarletteApplication
from a2a.server.events import EventQueue
from a2a.server.request_handlers import DefaultRequestHandler
from a2a.server.tasks import InMemoryTaskStore
from a2a.types import AgentCapabilities, AgentCard, AgentSkill
from a2a.utils import new_agent_text_message

# Addresses that mean "listen on every interface". They are valid to bind to
# but are not addresses a client should be told to connect to.
_WILDCARD_HOSTS = frozenset({"", "0.0.0.0", "::"})


def _advertised_host(host: str) -> str:
    """Return the host name to publish in the agent card for a bind ``host``."""
    if host in _WILDCARD_HOSTS:
        return "localhost"
    # A bare IPv6 literal must be bracketed inside a URL.
    if ":" in host and not host.startswith("["):
        return f"[{host}]"
    return host


class WeatherAgentExecutor(AgentExecutor):
    """Returns a canned weather reply regardless of the incoming message."""

    async def execute(self, context: RequestContext, event_queue: EventQueue) -> None:
        """Enqueue the fixed weather text as the agent's reply."""
        await event_queue.enqueue_event(new_agent_text_message("It is sunny in LA, around 75F."))

    async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None:
        """Reject cancellation; this demo agent does not support it."""
        raise NotImplementedError("cancel is not supported by this agent")


def build_app(host: str, port: int) -> A2AStarletteApplication:
    """Build the A2A application exposing the single ``weather_agent``.

    Args:
        host: Address the server will bind to. When it is a wildcard address
            the agent card advertises ``localhost`` instead, because the card
            URL is what clients use to reach the agent.
        port: TCP port the server will listen on.

    Returns:
        The ``A2AStarletteApplication`` wired to :class:`WeatherAgentExecutor`.
    """
    skill = AgentSkill(
        id="weather",
        name="weather",
        description="Returns the current weather for a location.",
        tags=["weather"],
        examples=["What is the weather in LA?"],
    )
    agent_card = AgentCard(
        name="weather_agent",
        description="A demo agent that reports the weather.",
        url=f"http://{_advertised_host(host)}:{port}/",
        version="1.0.0",
        default_input_modes=["text"],
        default_output_modes=["text"],
        capabilities=AgentCapabilities(streaming=False),
        skills=[skill],
    )
    request_handler = DefaultRequestHandler(
        agent_executor=WeatherAgentExecutor(),
        task_store=InMemoryTaskStore(),
    )
    return A2AStarletteApplication(agent_card=agent_card, http_handler=request_handler)


def main():
    """Parse ``--host``/``--port`` and serve the weather agent with uvicorn."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=9999)
    args = parser.parse_args()

    app = build_app(args.host, args.port)
    uvicorn.run(app.build(), host=args.host, port=args.port)


if __name__ == "__main__":
    main()
a/tensorrt_llm/scaffolding/contrib/a2a/__init__.py b/tensorrt_llm/scaffolding/contrib/a2a/__init__.py new file mode 100644 index 000000000000..2f3111fc3df0 --- /dev/null +++ b/tensorrt_llm/scaffolding/contrib/a2a/__init__.py @@ -0,0 +1,27 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .a2a_controller import A2AController +from .a2a_task import A2AListTask, A2ASendTask +from .a2a_utils import A2AAgentConnection, AgentInfo +from .a2a_worker import A2AWorker + +__all__ = [ + "A2AController", + "A2AWorker", + "A2ASendTask", + "A2AListTask", + "A2AAgentConnection", + "AgentInfo", +] diff --git a/tensorrt_llm/scaffolding/contrib/a2a/a2a_controller.py b/tensorrt_llm/scaffolding/contrib/a2a/a2a_controller.py new file mode 100644 index 000000000000..0253a962482e --- /dev/null +++ b/tensorrt_llm/scaffolding/contrib/a2a/a2a_controller.py @@ -0,0 +1,150 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
import copy
import json
from enum import Enum
from typing import List, Optional

from tensorrt_llm.scaffolding import Controller, Task

# Reuse the MCP contrib's portable chat task/handler (see contrib/README.md:
# "a project can import Controller/Task/Worker from other projects"). We import
# from the submodules directly to avoid pulling in the `mcp` package via the
# mcp package __init__.
from tensorrt_llm.scaffolding.contrib.mcp.chat_task import ChatTask

from .a2a_task import A2AListTask, A2ASendTask


def _agent_to_tool(agent) -> dict:
    """Represent a remote agent as an OpenAI-style function the LLM can call.

    The LLM delegates to an agent by emitting a tool call whose name is the
    agent name and whose single argument is the ``message`` to forward.

    Args:
        agent: Object exposing ``name``, ``description`` and ``skills``.

    Returns:
        A ``{"type": "function", "function": {...}}`` tool description.
    """
    description = agent.description or f"Remote agent '{agent.name}'."
    if agent.skills:
        description = f"{description} Skills: {', '.join(agent.skills)}."
    return {
        "type": "function",
        "function": {
            "name": agent.name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": {
                    "message": {
                        "type": "string",
                        "description": "The natural-language request to send to this agent.",
                    }
                },
                "required": ["message"],
            },
        },
    }


def _delegation_message(raw_args) -> str:
    """Turn the raw tool-call arguments into the text to forward to an agent.

    Models occasionally emit malformed JSON, a non-object payload, or a
    non-string ``message``. This always returns a string so the send task is
    never handed something the remote agent cannot accept as text.
    """
    raw_text = raw_args if isinstance(raw_args, str) else ""
    try:
        parsed = json.loads(raw_args)
    except (ValueError, TypeError):
        # Forward the raw text so the agent still gets something actionable.
        return raw_text

    if isinstance(parsed, dict):
        message = parsed.get("message", "")
        if isinstance(message, str):
            return message
        if message is None:
            return ""
        # Structured value where text was expected: keep it, but as text.
        return json.dumps(message)
    if isinstance(parsed, str):
        # The model sent a bare JSON string; use its content, not the quoted form.
        return parsed
    return raw_text


class A2AController(Controller):
    """Route a user request to remote A2A agents, then summarize the replies.

    Flow (mirrors ``MCPController``):
      1. Discover reachable agents via an ``A2AListTask``.
      2. Ask the generation LLM which agent(s) to delegate to, exposing each
         agent as a callable tool.
      3. Dispatch the chosen ``A2ASendTask`` requests to the remote agents.
      4. Feed the agents' replies back to the LLM for a final answer.
    """

    class WorkerTag(Enum):
        GENERATION = "generation"
        A2A = "a2a"

    def __init__(self, custom_sampling_params: Optional[dict] = None):
        """Store an optional, privately copied set of sampling overrides.

        Args:
            custom_sampling_params: Attribute name -> value pairs applied to
                every chat task where that attribute is still ``None``.
        """
        super().__init__()
        self.custom_sampling_params = (
            copy.deepcopy(custom_sampling_params) if custom_sampling_params else None
        )

    def _apply_sampling_params(self, task):
        """Fill unset (``None``) attributes of ``task`` from the overrides."""
        if not self.custom_sampling_params:
            return
        for key, value in self.custom_sampling_params.items():
            if hasattr(task, key) and getattr(task, key) is None:
                setattr(task, key, value)

    def process(self, tasks: List[Task], **kwargs):
        """Generator driving discovery, delegation and summarization.

        Each ``yield`` hands a list of tasks back for execution; their result
        fields are read after the generator is resumed. The final answer is
        written to ``tasks[0].output_str``.
        """
        assert len(tasks) == 1, "A2AController handles a single task at a time."
        result_task = tasks[0]

        # 1. Discover the remote agents reachable through the A2A worker.
        list_task = A2AListTask.create_a2a_task(self.WorkerTag.A2A)
        yield [list_task]
        agents = list_task.result_agents or []
        available_tools = [_agent_to_tool(agent) for agent in agents]

        # 2. Let the LLM decide which agent (if any) should handle the request.
        system_message = (
            "You are an orchestrator agent with access to the remote agents "
            "exposed as tools below. When a remote agent is better suited to "
            "answer, delegate by calling it with a concise `message`. "
            "After receiving an agent's response, turn the raw result into a "
            "clear, concise answer for the user. Only call agents that are "
            "explicitly defined above."
        )
        messages = [{"role": "system", "content": system_message}]
        chat_task = ChatTask.create_from_prompt(messages, result_task.input_str, available_tools)
        chat_task.worker_tag = self.WorkerTag.GENERATION
        self._apply_sampling_params(chat_task)
        yield [chat_task]

        # 3. If the LLM did not delegate, return its direct answer. A
        #    "tool_calls" finish reason with no actual calls is treated the
        #    same way instead of failing on an empty/None list.
        tool_calls = (chat_task.tool_calls or []) if chat_task.finish_reason == "tool_calls" else []
        if not tool_calls:
            result_task.output_str = chat_task.output_str
            return

        # 4. Dispatch each requested delegation to the remote agents.
        send_tasks = [
            A2ASendTask.create_a2a_task(
                tool_call.function.name,
                _delegation_message(tool_call.function.arguments),
                self.WorkerTag.A2A,
            )
            for tool_call in tool_calls
        ]
        yield send_tasks

        agent_results = "\n".join(task.output_str for task in send_tasks if task.output_str)

        # 5. Summarize the agents' replies into a final answer.
        messages.append({"role": "assistant", "content": chat_task.output_str or ""})
        final_chat_task = ChatTask.create_from_prompt(messages, agent_results, available_tools)
        final_chat_task.worker_tag = self.WorkerTag.GENERATION
        self._apply_sampling_params(final_chat_task)
        yield [final_chat_task]
        result_task.output_str = final_chat_task.output_str
        return
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, List, Optional, Union

from tensorrt_llm.scaffolding.task import Task

if TYPE_CHECKING:
    from tensorrt_llm.scaffolding.controller import Controller


@dataclass
class A2ASendTask(Task):
    """Task asking a worker to forward one text message to a named remote agent."""

    # Routing key: the remote agent's name as published in its agent card.
    agent_name: Optional[str] = field(default=None)
    # Natural-language text delivered to that agent.
    message: Optional[str] = field(default=None)

    worker_tag: Union[str, "Controller.WorkerTag"] = None

    # Output slot: the worker stores the agent's textual reply here.
    output_str: Optional[str] = None

    @staticmethod
    def create_a2a_task(
        agent_name: str, message: str, worker_tag: Union[str, "Controller.WorkerTag"] = None
    ) -> "A2ASendTask":
        """Build a send task addressed to ``agent_name`` carrying ``message``."""
        return A2ASendTask(agent_name=agent_name, message=message, worker_tag=worker_tag)


@dataclass
class A2AListTask(Task):
    """Task asking a worker to enumerate the remote agents it can reach."""

    worker_tag: Union[str, "Controller.WorkerTag"] = None

    # Output slot: the worker stores a list of AgentInfo objects here.
    result_agents: Optional[List[Any]] = None

    @staticmethod
    def create_a2a_task(worker_tag: Union[str, "Controller.WorkerTag"] = None) -> "A2AListTask":
        """Build a discovery task routed to ``worker_tag``."""
        return A2AListTask(worker_tag=worker_tag)
"""Thin client wrapper around the official ``a2a-sdk`` (Agent2Agent protocol).

The wrapper keeps the rest of Scaffolding (worker/controller) free of any
hard dependency on ``a2a-sdk``: the SDK is imported lazily inside the methods
that actually talk to a remote agent, and the worker only ever sees the
normalized :class:`AgentInfo` view defined below. This mirrors how the MCP
contrib isolates the ``mcp`` package inside ``mcp_utils.MCPClient``.
"""

import uuid
from dataclasses import dataclass, field
from typing import List


@dataclass
class AgentInfo:
    """Protocol-agnostic view of a remote agent, derived from its agent card.

    Decoupling the worker/controller from ``a2a-sdk`` types means tests can
    inject fake connections without installing the SDK.
    """

    name: str
    description: str = ""
    skills: List[str] = field(default_factory=list)


def _extract_text_from_response(response) -> str:
    """Best-effort extraction of textual content from an A2A send-message response.

    The ``a2a-sdk`` response schema has shifted across versions, so the
    structure is walked defensively via ``getattr``. Both result shapes are
    covered: a ``Message`` (text under ``parts``) and a ``Task`` (text under
    ``artifacts[*].parts`` and ``status.message.parts``).

    Args:
        response: The object returned by the client's ``send_message``.

    Returns:
        All non-empty text parts joined with newlines, or ``str(response)``
        when no text part is found.
    """
    texts: List[str] = []
    seen = set()

    def _walk(obj):
        # Guard against objects that reference themselves (directly or not).
        if obj is None or id(obj) in seen:
            return
        seen.add(id(obj))

        # A ``TextPart`` exposes a ``text`` attribute directly.
        text = getattr(obj, "text", None)
        if isinstance(text, str) and text:
            texts.append(text)
        # ``Part`` wraps the concrete part under ``root`` in recent SDKs.
        _walk(getattr(obj, "root", None))
        # A ``Message`` or ``Artifact`` carries a list of ``parts``.
        for part in getattr(obj, "parts", None) or []:
            _walk(part)
        # A ``Task`` result keeps its output under ``artifacts`` ...
        for artifact in getattr(obj, "artifacts", None) or []:
            _walk(artifact)
        # ... and may attach a message to its ``status``.
        _walk(getattr(getattr(obj, "status", None), "message", None))

    result = getattr(response, "root", response)
    result = getattr(result, "result", result)
    _walk(result)

    return "\n".join(texts) if texts else str(response)


class A2AAgentConnection:
    """A connection to a single remote agent that speaks the A2A protocol."""

    def __init__(self):
        self._httpx_client = None
        self._client = None
        self._agent_card = None

    async def connect(self, base_url: str):
        """Resolve the remote agent card and create an A2A client for it.

        Args:
            base_url: Base URL of the remote agent; a trailing ``/`` is ignored.

        Raises:
            ImportError: If ``a2a-sdk`` or ``httpx`` is not installed.
        """
        try:
            import httpx
            from a2a.client import A2ACardResolver, A2AClient
        except ImportError as e:
            raise ImportError(
                "The A2A contrib requires the 'a2a-sdk' and 'httpx' packages. "
                "Install them with `pip install a2a-sdk httpx`."
            ) from e

        # Reconnecting must not orphan a previously opened HTTP client.
        await self.cleanup()

        httpx_client = httpx.AsyncClient()
        try:
            resolver = A2ACardResolver(httpx_client=httpx_client, base_url=base_url.rstrip("/"))
            agent_card = await resolver.get_agent_card()
            client = A2AClient(httpx_client=httpx_client, agent_card=agent_card)
        except BaseException:
            # Includes cancellation: never leave the HTTP client open when the
            # connection could not be established.
            await httpx_client.aclose()
            raise

        # Publish the state only once every step succeeded.
        self._httpx_client = httpx_client
        self._agent_card = agent_card
        self._client = client

    def get_agent_info(self) -> AgentInfo:
        """Normalize the resolved agent card into an :class:`AgentInfo`.

        Raises:
            RuntimeError: If no agent card has been resolved yet.
        """
        card = self._agent_card
        if card is None:
            raise RuntimeError("A2AAgentConnection is not connected; call connect() first.")
        skills = []
        for skill in getattr(card, "skills", None) or []:
            # Prefer the human-readable name, fall back to the skill id.
            name = getattr(skill, "name", None) or getattr(skill, "id", "")
            if name:
                skills.append(name)
        return AgentInfo(
            name=card.name, description=getattr(card, "description", "") or "", skills=skills
        )

    async def send_message(self, message: str) -> str:
        """Send a text message to the remote agent and return its text reply.

        Raises:
            RuntimeError: If the connection is not open.
        """
        if self._client is None:
            raise RuntimeError("A2AAgentConnection is not connected; call connect() first.")

        from a2a.types import MessageSendParams, SendMessageRequest

        request = SendMessageRequest(
            id=str(uuid.uuid4()),
            params=MessageSendParams(
                message={
                    "role": "user",
                    "parts": [{"kind": "text", "text": message}],
                    "messageId": uuid.uuid4().hex,
                }
            ),
        )
        response = await self._client.send_message(request)
        return _extract_text_from_response(response)

    async def cleanup(self):
        """Close the underlying HTTP client. Safe to call more than once."""
        # Detach before awaiting so a second call finds nothing left to close.
        httpx_client, self._httpx_client = self._httpx_client, None
        # The A2A client wraps the HTTP client and is unusable once it closes.
        self._client = None
        if httpx_client is not None:
            await httpx_client.aclose()
import asyncio
from typing import List

from tensorrt_llm.scaffolding import TaskStatus, Worker

from .a2a_task import A2AListTask, A2ASendTask
from .a2a_utils import A2AAgentConnection


class A2AWorker(Worker):
    """A Scaffolding worker that delegates work to remote A2A-protocol agents.

    The worker holds one connection per remote agent. It mirrors ``MCPWorker``:
    the ``A2AListTask`` discovers the reachable agents (so a controller can let
    the LLM decide where to route), and the ``A2ASendTask`` dispatches a message
    to a named agent and collects its reply.

    ``connections`` is any object exposing ``get_agent_info()`` and an async
    ``send_message(text)`` (plus an optional async ``cleanup()``). Production
    code uses :class:`A2AAgentConnection`; tests can inject fakes so they need
    neither ``a2a-sdk`` nor network access.
    """

    def __init__(self, connections: List):
        super().__init__()
        self.connections = connections
        # Keeps the fire-and-forget task created by shutdown() alive.
        self._shutdown_task = None

    @classmethod
    async def init_with_urls(cls, urls: List[str]) -> "A2AWorker":
        """Connect to every URL in ``urls`` and build a worker from the results.

        If any connection fails, the connections opened so far are closed
        before the error is re-raised, so no HTTP client is left behind.
        """
        connections = []
        try:
            for url in urls:
                connection = A2AAgentConnection()
                # Track it before connecting so a half-open one is cleaned up too.
                connections.append(connection)
                await connection.connect(url)
        except BaseException:
            for connection in connections:
                try:
                    await connection.cleanup()
                except Exception:
                    # Best effort: the original connect error is what matters.
                    pass
            raise
        return cls(connections)

    async def list_handler(self, task: A2AListTask) -> TaskStatus:
        """Store the ``AgentInfo`` of every connection in ``task.result_agents``."""
        task.result_agents = [connection.get_agent_info() for connection in self.connections]
        return TaskStatus.SUCCESS

    async def send_handler(self, task: A2ASendTask) -> TaskStatus:
        """Forward ``task.message`` to the first connection named ``task.agent_name``."""
        for connection in self.connections:
            if connection.get_agent_info().name != task.agent_name:
                continue
            task.output_str = await connection.send_message(task.message)
            return TaskStatus.SUCCESS

        # No remote agent matched the requested name. Surface a clear message
        # so the controller can still produce a final answer.
        task.output_str = f"No remote A2A agent named '{task.agent_name}' is available."
        return TaskStatus.WORKER_NOT_SUPPORTED

    async def async_shutdown(self):
        """Close all remote-agent connections.

        Every connection is attempted even if an earlier one fails; the first
        failure is re-raised once all have been tried.
        """
        first_error = None
        for connection in self.connections:
            cleanup = getattr(connection, "cleanup", None)
            if cleanup is None:
                continue
            try:
                await cleanup()
            except Exception as e:
                if first_error is None:
                    first_error = e
        if first_error is not None:
            raise first_error

    def shutdown(self):
        """Best-effort synchronous shutdown.

        Prefer :meth:`async_shutdown` from inside an event loop; this fallback
        is provided to satisfy the :class:`Worker` interface. With a running
        loop in the current thread the cleanup is only scheduled, not awaited.
        """
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running in this thread: drive the cleanup to completion.
            asyncio.run(self.async_shutdown())
            return
        # The event loop only holds weak references to tasks; keep a strong one
        # so the cleanup cannot be garbage-collected before it runs.
        self._shutdown_task = loop.create_task(self.async_shutdown())

    task_handlers = {A2AListTask: list_handler, A2ASendTask: send_handler}
import json

import pytest

from tensorrt_llm.scaffolding import ScaffoldingLlm, TaskStatus, Worker
from tensorrt_llm.scaffolding.contrib.a2a import (
    A2AController,
    A2AListTask,
    A2ASendTask,
    A2AWorker,
    AgentInfo,
)
from tensorrt_llm.scaffolding.contrib.mcp.chat_task import ChatTask

# ============================================================
# Fakes
# ============================================================


class FakeA2AConnection:
    """Stand-in for A2AAgentConnection that echoes messages without network."""

    def __init__(self, name: str, description: str = "", skills=None):
        self._info = AgentInfo(name=name, description=description, skills=skills or [])
        # Flipped by cleanup() so tests can assert the worker closed this connection.
        self.cleaned_up = False

    def get_agent_info(self) -> AgentInfo:
        """Return the fixed AgentInfo built in the constructor."""
        return self._info

    async def send_message(self, message: str) -> str:
        """Echo ``message`` back, prefixed with this agent's name."""
        return f"[{self._info.name}] handled: {message}"

    async def cleanup(self):
        """Record that cleanup was requested; nothing real is closed."""
        self.cleaned_up = True


class FunctionCall:
    """Minimal object exposing the ``name``/``arguments`` a tool call's function carries."""

    def __init__(self, name: str, arguments: str):
        self.name = name
        self.arguments = arguments


class ToolCall:
    """Minimal tool call: wraps a FunctionCall under ``.function``."""

    def __init__(self, name: str, arguments: str):
        self.function = FunctionCall(name, arguments)


class DummyGenerationWorker(Worker):
    """Deterministic generation worker: first delegates, then summarizes."""

    async def dummy_handler(self, task: ChatTask):
        """Answer a ChatTask based on how many messages it already holds."""
        # The controller seeds [system, user]; after delegation it appends an
        # assistant message and a user message carrying the agent reply.
        if len(task.messages) == 2:
            task.output_str = "delegating to weather_agent"
            task.tool_calls = [ToolCall("weather_agent", json.dumps({"message": "weather in LA?"}))]
            task.finish_reason = "tool_calls"
        else:
            task.output_str = "It is sunny in LA."
            task.tool_calls = None
            task.finish_reason = "stop"
        return TaskStatus.SUCCESS

    task_handlers = {ChatTask: dummy_handler}


# ============================================================
# Worker-level tests
# ============================================================


@pytest.mark.asyncio
async def test_a2a_worker_list_agents():
    """The list task reports every connection's AgentInfo, in connection order."""
    worker = A2AWorker(
        [
            FakeA2AConnection("weather_agent", "Provides weather", ["forecast"]),
            FakeA2AConnection("math_agent", "Does arithmetic"),
        ]
    )
    task = A2AListTask.create_a2a_task()
    status = await worker.run_task(task)

    assert status == TaskStatus.SUCCESS
    names = [agent.name for agent in task.result_agents]
    assert names == ["weather_agent", "math_agent"]
    assert task.result_agents[0].skills == ["forecast"]


@pytest.mark.asyncio
async def test_a2a_worker_send_routes_to_named_agent():
    """A send task reaches the connection whose agent name matches, not the first one."""
    worker = A2AWorker(
        [
            FakeA2AConnection("weather_agent"),
            FakeA2AConnection("math_agent"),
        ]
    )
    task = A2ASendTask.create_a2a_task("math_agent", "1 + 1")
    status = await worker.run_task(task)

    assert status == TaskStatus.SUCCESS
    assert task.output_str == "[math_agent] handled: 1 + 1"


@pytest.mark.asyncio
async def test_a2a_worker_send_unknown_agent():
    """An unknown agent name yields WORKER_NOT_SUPPORTED and a message naming it."""
    worker = A2AWorker([FakeA2AConnection("weather_agent")])
    task = A2ASendTask.create_a2a_task("missing_agent", "hello")
    status = await worker.run_task(task)

    assert status == TaskStatus.WORKER_NOT_SUPPORTED
    assert "missing_agent" in task.output_str


@pytest.mark.asyncio
async def test_a2a_worker_async_shutdown():
    """async_shutdown() calls cleanup() on every connection."""
    connections = [FakeA2AConnection("a"), FakeA2AConnection("b")]
    worker = A2AWorker(connections)
    await worker.async_shutdown()
    assert all(connection.cleaned_up for connection in connections)


# ============================================================
# End-to-end controller test through ScaffoldingLlm
# ============================================================


@pytest.mark.asyncio
async def test_scaffolding_with_a2a_controller():
    """Full flow: the dummy LLM delegates to the fake agent, then summarizes."""
    a2a_worker = A2AWorker(
        [
            FakeA2AConnection("weather_agent", "Provides weather", ["forecast"]),
        ]
    )
    generation_worker = DummyGenerationWorker()

    controller = A2AController()
    scaffolding_llm = ScaffoldingLlm(
        controller,
        {
            A2AController.WorkerTag.GENERATION: generation_worker,
            A2AController.WorkerTag.A2A: a2a_worker,
        },
    )

    try:
        future = scaffolding_llm.generate_async("What's the weather in LA?")
        result = await future.aresult()
        assert isinstance(result.outputs[0].text, str)
        # This is the text DummyGenerationWorker emits on its second (summary) call.
        assert result.outputs[0].text == "It is sunny in LA."
    finally:
        # The A2A worker is closed explicitly here, then the LLM is shut down
        # with shutdown_workers=False.
        await a2a_worker.async_shutdown()
        scaffolding_llm.shutdown(shutdown_workers=False)