diff --git a/pydantic_ai_harness/__init__.py b/pydantic_ai_harness/__init__.py index 0a60fd7..a1ab8de 100644 --- a/pydantic_ai_harness/__init__.py +++ b/pydantic_ai_harness/__init__.py @@ -3,12 +3,17 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: + from .background_tools import BackgroundTools from .code_mode import CodeMode -__all__ = ['CodeMode'] +__all__ = ['BackgroundTools', 'CodeMode'] def __getattr__(name: str) -> object: + if name == 'BackgroundTools': + from .background_tools import BackgroundTools + + return BackgroundTools if name == 'CodeMode': from .code_mode import CodeMode diff --git a/pydantic_ai_harness/background_tools/README.md b/pydantic_ai_harness/background_tools/README.md new file mode 100644 index 0000000..c8314b9 --- /dev/null +++ b/pydantic_ai_harness/background_tools/README.md @@ -0,0 +1,111 @@ +# Background Tools + +Run selected tools as fire-and-forget asyncio tasks, so the agent can keep working while they finish. + +## The problem + +Some tools take seconds to minutes -- deep research, big aggregations, sub-agent delegation. With normal tool calls the agent is blocked: it makes the call, waits, then plans its next step. Over a long task the conversation effectively serializes. + +## The solution + +`BackgroundTools` spawns the matching tool calls as `asyncio.Task`s. The agent receives an immediate acknowledgment string and continues planning. When the task finishes, its result is enqueued as a follow-up message via [`RunContext.enqueue`][pydantic_ai.tools.RunContext.enqueue]; Pydantic AI's pending message queue redirects the agent into a fresh `ModelRequest` instead of ending, so the model sees the result and can use it. + +## Usage + +```python +from pydantic_ai import Agent +from pydantic_ai_harness import BackgroundTools + +agent = Agent('openai:gpt-5', capabilities=[BackgroundTools()]) + +@agent.tool_plain(metadata={'background': True}) +async def slow_research(query: str) -> str: + """Research a topic thoroughly. Runs in the background.""" + return await do_expensive_research(query) +``` + +By default any tool with `metadata={'background': True}` runs in the background. The agent's instructions are augmented automatically so the model knows it shouldn't block waiting for the result. + +## Selecting which tools run in the background + +`BackgroundTools(tools=...)` accepts the standard [`ToolSelector`][pydantic_ai.tools.ToolSelector]: + +```python +# By metadata key (default) +BackgroundTools() # tools with metadata={'background': True} +BackgroundTools(tools={'background': True}) # explicit form +BackgroundTools(tools={'kind': 'research'}) # custom metadata key + +# By name +BackgroundTools(tools=['slow_research', 'deep_dig']) + +# By predicate +BackgroundTools(tools=lambda ctx, td: td.name.startswith('research_')) +``` + +### Marking a whole MCP server or toolset + +Combine with [`SetToolMetadata`][pydantic_ai.capabilities.SetToolMetadata] or `FunctionToolset.with_metadata(...)` to mark every tool from a source as background, without touching individual definitions: + +```python +from pydantic_ai import Agent +from pydantic_ai.capabilities import MCP, SetToolMetadata +from pydantic_ai_harness import BackgroundTools + +agent = Agent('openai:gpt-5', capabilities=[ + MCP('https://research.example/mcp/'), + SetToolMetadata(predicate=lambda td: td.name.startswith('mcp_'), background=True), + BackgroundTools(), +]) +``` + +## Result delivery + +Results are enqueued as `'follow_up'` priority messages on Pydantic AI's pending message queue. 
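+Concretely, when a background task finishes, the capability enqueues the follow-up like this (mirroring the `_run` closure in `_capability.py` below; `ctx`, `tool_name`, `task_id` and `result` are its local variables):
+
+```python
+ctx.enqueue(
+    SystemPromptPart(f"Background tool '{tool_name}' (task {task_id}) completed.\nResult: {result}"),
+    priority='follow_up',
+)
+```
+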
+When the agent would otherwise produce a final result, the queue is drained and the agent continues with a fresh `ModelRequest` containing all completed background results.
+
+The follow-up message format is a `SystemPromptPart` containing:
+
+- On success: `Background tool 'X' (task <id>) completed.\nResult: <result>`
+- On failure: `Background tool 'X' (task <id>) failed: <error>`
+
+The model sees the task ID alongside the result so it can correlate against the ack string it received earlier.
+
+## Lifecycle and cancellation
+
+- Each agent run gets fresh task state via the capability's `for_run` hook -- concurrent runs do not share tasks
+- If the surrounding agent run is cancelled (e.g. via `asyncio.wait_for` timeout), all live background tasks are cancelled in the capability's `wrap_run` cleanup
+- `asyncio.CancelledError` from a cancelled task does not produce a follow-up; it propagates as a normal task cancellation
+
+## Limitations
+
+- **Streaming**: follow-up delivery requires `agent.run()` or explicit `agent_run.next()` driving. A bare `async for node in agent_run:` loop does not run `after_node_run`, so background results won't be delivered.
+- **Temporal / DBOS**: tools run inside durable activities and don't share state with the surrounding workflow. Tool-side `ctx.enqueue` calls do not currently propagate back, so background results from durable tools are lost. If you need this, file an issue.
+
+## API
+
+```python
+BackgroundTools(
+    tools: ToolSelector = {'background': True},
+)
+```
+
+## Agent spec (YAML/JSON)
+
+```yaml
+# agent.yaml
+model: openai:gpt-5
+capabilities:
+  - BackgroundTools: {}
+```
+
+```python
+from pydantic_ai import Agent
+from pydantic_ai_harness import BackgroundTools
+
+agent = Agent.from_file('agent.yaml', custom_capability_types=[BackgroundTools])
+```
+
+## Further reading
+
+- [Pydantic AI message history -- injecting messages mid-run](https://ai.pydantic.dev/message-history/#injecting-messages-mid-run) -- the underlying primitive
+- [Pydantic AI capabilities](https://ai.pydantic.dev/capabilities/)
diff --git a/pydantic_ai_harness/background_tools/__init__.py b/pydantic_ai_harness/background_tools/__init__.py
new file mode 100644
index 0000000..7af4c7b
--- /dev/null
+++ b/pydantic_ai_harness/background_tools/__init__.py
@@ -0,0 +1,5 @@
+"""Run selected tools as background asyncio tasks with async result delivery."""
+
+from ._capability import BackgroundTools
+
+__all__ = ['BackgroundTools']
diff --git a/pydantic_ai_harness/background_tools/_capability.py b/pydantic_ai_harness/background_tools/_capability.py
new file mode 100644
index 0000000..58fcdc7
--- /dev/null
+++ b/pydantic_ai_harness/background_tools/_capability.py
@@ -0,0 +1,166 @@
+"""Background tools capability that spawns selected tools as fire-and-forget tasks."""
+
+from __future__ import annotations
+
+import asyncio
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Any
+
+from pydantic_ai.capabilities import AbstractCapability, CapabilityOrdering
+from pydantic_ai.capabilities._pending_messages import PendingMessageDrainCapability
+from pydantic_ai.messages import SystemPromptPart, ToolCallPart
+from pydantic_ai.tools import (
+    AgentDepsT,
+    RunContext,
+    ToolDefinition,
+    ToolSelector,
+    matches_tool_selector,
+)
+
+if TYPE_CHECKING:
+    from pydantic_ai import _agent_graph
+    from pydantic_ai.capabilities.abstract import WrapToolExecuteHandler
+    from pydantic_ai.result import FinalResult
+    from pydantic_graph import End
+
+
+_DEFAULT_SELECTOR: dict[str, Any] =
{'background': True} + +_INSTRUCTIONS = """\ +Some tools run in the background: when you call them you'll get an immediate \ +acknowledgment, and the real result will be delivered automatically as a follow-up \ +message when the task completes. Continue working on other things in the meantime; \ +do not block waiting for the result.\ +""" + + +@dataclass +class BackgroundTools(AbstractCapability[AgentDepsT]): + """Run selected tools as fire-and-forget asyncio tasks. + + When the model calls a tool that matches the selector, the capability spawns the + tool's handler in an `asyncio.Task` and immediately returns an acknowledgment + string to the agent. When the task completes, its result (or error) is enqueued + via [`RunContext.enqueue`][pydantic_ai.tools.RunContext.enqueue] as a `'follow_up'` + message — Pydantic AI's pending message queue redirects the agent to a fresh + `ModelRequest` instead of ending, so the model receives the result and can act on it. + + ```python + from pydantic_ai import Agent + from pydantic_ai_harness import BackgroundTools + + # Default: any tool with `metadata={'background': True}` runs in the background. + agent = Agent('openai:gpt-5', capabilities=[BackgroundTools()]) + + @agent.tool_plain(metadata={'background': True}) + async def slow_research(query: str) -> str: + return await do_expensive_research(query) + ``` + + Combine with [`SetToolMetadata`][pydantic_ai.capabilities.SetToolMetadata] to mark + every tool from a specific MCP server, or with `FunctionToolset.with_metadata(...)` + to mark a whole toolset. Or pass a name list / predicate via `tools=...` to ignore + metadata entirely. + """ + + tools: ToolSelector[AgentDepsT] = field(default_factory=lambda: dict(_DEFAULT_SELECTOR)) + """Which tools should run in the background. + + - `dict[str, Any]` (default `{'background': True}`): tools whose metadata deeply + includes the given key-value pairs. + - `'all'`: every tool in the agent's toolset (rarely what you want). + - `Sequence[str]`: tools with matching names. + - Callable `(ctx, tool_def) -> bool | Awaitable[bool]`: custom predicate. + """ + + _tasks: dict[str, asyncio.Task[None]] = field( + default_factory=dict[str, 'asyncio.Task[None]'], init=False, repr=False + ) + _completion_event: asyncio.Event = field(default_factory=asyncio.Event, init=False, repr=False) + + def get_ordering(self) -> CapabilityOrdering: + # `after_node_run` runs in reverse order (outermost runs last). We need to + # wait for at least one background task BEFORE the core + # `PendingMessageDrainCapability` checks the queue for follow-ups, so + # drain must be outermost relative to us. + return CapabilityOrdering(wrapped_by=[PendingMessageDrainCapability]) + + def get_instructions(self) -> str: + return _INSTRUCTIONS + + async def for_run(self, ctx: RunContext[AgentDepsT]) -> BackgroundTools[AgentDepsT]: + # Fresh per-run state so concurrent runs don't share tasks. 
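+        # The dataclass `field(default_factory=...)` defaults give the new instance an
+        # empty `_tasks` dict and a fresh `_completion_event`, so runs never share them.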
+ return BackgroundTools(tools=self.tools) + + async def wrap_tool_execute( + self, + ctx: RunContext[AgentDepsT], + *, + call: ToolCallPart, + tool_def: ToolDefinition, + args: dict[str, Any], + handler: WrapToolExecuteHandler, + ) -> Any: + if not await matches_tool_selector(self.tools, ctx, tool_def): + return await handler(args) + + task_id = call.tool_call_id + tool_name = call.tool_name + + async def _run() -> None: + try: + result = await handler(args) + ctx.enqueue( + SystemPromptPart(f"Background tool '{tool_name}' (task {task_id}) completed.\nResult: {result}"), + priority='follow_up', + ) + except asyncio.CancelledError: + # Run cleanup cancelled us; don't enqueue a spurious failure follow-up. + raise + except Exception as e: + ctx.enqueue( + SystemPromptPart(f"Background tool '{tool_name}' (task {task_id}) failed: {e}"), + priority='follow_up', + ) + finally: + self._tasks.pop(task_id, None) + self._completion_event.set() + + self._tasks[task_id] = asyncio.create_task(_run()) + return ( + f"Tool '{tool_name}' is running in background (task {task_id}). " + f'You will receive the result automatically when it completes. ' + f'Continue with other work in the meantime.' + ) + + async def after_node_run( + self, + ctx: RunContext[AgentDepsT], + *, + node: _agent_graph.AgentNode[AgentDepsT, Any], + result: _agent_graph.AgentNode[AgentDepsT, Any] | End[FinalResult[Any]], + ) -> _agent_graph.AgentNode[AgentDepsT, Any] | End[FinalResult[Any]]: + from pydantic_graph import End + + if not isinstance(result, End) or not self._tasks: + return result + + # Hold End until at least one task completes so the drain capability + # (which runs after us in reverse order) has a follow-up to deliver. + self._completion_event.clear() + await self._completion_event.wait() + return result + + async def wrap_run( + self, + ctx: RunContext[AgentDepsT], + *, + handler: Any, + ) -> Any: + try: + return await handler() + finally: + for task in self._tasks.values(): + task.cancel() + if self._tasks: + await asyncio.gather(*self._tasks.values(), return_exceptions=True) diff --git a/pyproject.toml b/pyproject.toml index f26a661..728706e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,7 +62,7 @@ lint = [ ] [tool.uv.sources] -pydantic-ai-slim = { git = 'https://github.com/pydantic/pydantic-ai.git', branch = 'main', subdirectory = 'pydantic_ai_slim' } +pydantic-ai-slim = { git = 'https://github.com/pydantic/pydantic-ai.git', branch = 'background-tools', subdirectory = 'pydantic_ai_slim' } [tool.hatch.version] source = 'uv-dynamic-versioning' diff --git a/tests/_background_tools/__init__.py b/tests/_background_tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/_background_tools/test_background_tools.py b/tests/_background_tools/test_background_tools.py new file mode 100644 index 0000000..820e98f --- /dev/null +++ b/tests/_background_tools/test_background_tools.py @@ -0,0 +1,252 @@ +"""Tests for the `BackgroundTools` capability.""" + +from __future__ import annotations + +import asyncio + +import pytest +from pydantic_ai import Agent +from pydantic_ai.messages import ( + ModelMessage, + ModelRequest, + ModelResponse, + SystemPromptPart, + TextPart, + ToolCallPart, + ToolReturnPart, +) +from pydantic_ai.models.function import AgentInfo, FunctionModel +from pydantic_ai.usage import RequestUsage + +from pydantic_ai_harness import BackgroundTools + +pytestmark = pytest.mark.anyio + + +@pytest.fixture +def anyio_backend() -> str: + return 'asyncio' + + +def _ack_seen(messages: 
list[ModelMessage]) -> bool: + """True if any tool return in the history is a background-execution ack.""" + return any( + isinstance(part, ToolReturnPart) and 'running in background' in str(part.content) + for msg in messages + if isinstance(msg, ModelRequest) + for part in msg.parts + ) + + +def _follow_up_seen(messages: list[ModelMessage], needle: str) -> bool: + """True if any system prompt in the history contains *needle* (e.g. 'completed' / 'failed').""" + return any( + isinstance(part, SystemPromptPart) and needle in part.content + for msg in messages + if isinstance(msg, ModelRequest) + for part in msg.parts + ) + + +class TestBackgroundTools: + """Cover the metadata-default selector path: spawn, ack, deliver, error, cancel.""" + + async def test_metadata_marked_tool_runs_in_background(self) -> None: + """A tool with `metadata={'background': True}` returns an ack and delivers result as follow-up.""" + call_count = 0 + + def model_fn(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + nonlocal call_count + call_count += 1 + if call_count == 1: + return ModelResponse( + parts=[ToolCallPart(tool_name='slow_research', args='{"query": "topic"}')], + usage=RequestUsage(input_tokens=10, output_tokens=5), + ) + if call_count == 2: + # Agent saw the ack; produce a placeholder, drain holds it back. + return ModelResponse( + parts=[TextPart(content='waiting')], + usage=RequestUsage(input_tokens=10, output_tokens=5), + ) + # Third call: the follow-up has been delivered. + assert _follow_up_seen(messages, 'completed') + return ModelResponse( + parts=[TextPart(content='got result')], + usage=RequestUsage(input_tokens=10, output_tokens=5), + ) + + agent = Agent(FunctionModel(model_fn), capabilities=[BackgroundTools()]) + + @agent.tool_plain(metadata={'background': True}) + async def slow_research(query: str) -> str: # pyright: ignore[reportUnusedFunction] + await asyncio.sleep(0.01) + return f'researched {query}' + + result = await agent.run('research X') + assert result.output == 'got result' + assert _ack_seen(result.all_messages()) + + async def test_failure_delivered_as_follow_up(self) -> None: + """A background tool that raises produces a 'failed' follow-up message.""" + call_count = 0 + + def model_fn(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + nonlocal call_count + call_count += 1 + if call_count == 1: + return ModelResponse( + parts=[ToolCallPart(tool_name='broken', args='{}')], + usage=RequestUsage(input_tokens=10, output_tokens=5), + ) + if call_count == 2: + return ModelResponse( + parts=[TextPart(content='waiting')], + usage=RequestUsage(input_tokens=10, output_tokens=5), + ) + assert _follow_up_seen(messages, 'failed') + return ModelResponse( + parts=[TextPart(content='handled error')], + usage=RequestUsage(input_tokens=10, output_tokens=5), + ) + + agent = Agent(FunctionModel(model_fn), capabilities=[BackgroundTools()]) + + @agent.tool_plain(metadata={'background': True}) + async def broken() -> str: # pyright: ignore[reportUnusedFunction] + await asyncio.sleep(0.01) + raise RuntimeError('boom') + + result = await agent.run('go') + assert result.output == 'handled error' + + async def test_unmarked_tool_runs_synchronously(self) -> None: + """A tool without the metadata flag is executed normally; no ack, no follow-up.""" + + def model_fn(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + for msg in messages: + if isinstance(msg, ModelRequest): + for part in msg.parts: + if isinstance(part, ToolReturnPart) and part.content == 'sync 
result': + return ModelResponse( + parts=[TextPart(content='done')], + usage=RequestUsage(input_tokens=10, output_tokens=5), + ) + return ModelResponse( + parts=[ToolCallPart(tool_name='plain', args='{}')], + usage=RequestUsage(input_tokens=10, output_tokens=5), + ) + + agent = Agent(FunctionModel(model_fn), capabilities=[BackgroundTools()]) + + @agent.tool_plain + def plain() -> str: # pyright: ignore[reportUnusedFunction] + return 'sync result' + + result = await agent.run('go') + assert result.output == 'done' + assert not _ack_seen(result.all_messages()) + + async def test_run_abort_cancels_live_tasks(self) -> None: + """When the surrounding run is cancelled (e.g. timeout), live background tasks are cancelled too.""" + cancel_seen = asyncio.Event() + + def model_fn(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + if _ack_seen(messages): + return ModelResponse( + parts=[TextPart(content='done')], + usage=RequestUsage(input_tokens=10, output_tokens=5), + ) + return ModelResponse( + parts=[ToolCallPart(tool_name='slow', args='{}')], + usage=RequestUsage(input_tokens=10, output_tokens=5), + ) + + agent = Agent(FunctionModel(model_fn), capabilities=[BackgroundTools()]) + + @agent.tool_plain(metadata={'background': True}) + async def slow() -> str: # pyright: ignore[reportUnusedFunction] + try: + await asyncio.sleep(60) + except asyncio.CancelledError: + cancel_seen.set() + raise + return 'never' # pragma: no cover -- task is cancelled before completing + + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(agent.run('go'), timeout=0.5) + + await asyncio.wait_for(cancel_seen.wait(), timeout=1) + + +class TestSelectors: + """Cover the non-default `tools=...` selectors: name list, predicate, custom dict.""" + + async def test_name_list_selector(self) -> None: + """`tools=['name']` selects without needing metadata.""" + call_count = 0 + + def model_fn(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + nonlocal call_count + call_count += 1 + if call_count == 1: + return ModelResponse( + parts=[ToolCallPart(tool_name='by_name', args='{}')], + usage=RequestUsage(input_tokens=10, output_tokens=5), + ) + if call_count == 2: + return ModelResponse( + parts=[TextPart(content='waiting')], + usage=RequestUsage(input_tokens=10, output_tokens=5), + ) + return ModelResponse( + parts=[TextPart(content='done')], + usage=RequestUsage(input_tokens=10, output_tokens=5), + ) + + agent = Agent(FunctionModel(model_fn), capabilities=[BackgroundTools(tools=['by_name'])]) + + @agent.tool_plain + async def by_name() -> str: # pyright: ignore[reportUnusedFunction] + await asyncio.sleep(0.01) + return 'value' + + result = await agent.run('go') + assert result.output == 'done' + assert _ack_seen(result.all_messages()) + + async def test_custom_metadata_key_selector(self) -> None: + """`tools={'async': True}` matches any other metadata key.""" + call_count = 0 + + def model_fn(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + nonlocal call_count + call_count += 1 + if call_count == 1: + return ModelResponse( + parts=[ToolCallPart(tool_name='custom', args='{}')], + usage=RequestUsage(input_tokens=10, output_tokens=5), + ) + if call_count == 2: + return ModelResponse( + parts=[TextPart(content='waiting')], + usage=RequestUsage(input_tokens=10, output_tokens=5), + ) + return ModelResponse( + parts=[TextPart(content='done')], + usage=RequestUsage(input_tokens=10, output_tokens=5), + ) + + agent = Agent( + FunctionModel(model_fn), + 
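+            # A dict selector matches tools whose metadata includes these key-value pairs, not by name.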
capabilities=[BackgroundTools(tools={'async': True})], + ) + + @agent.tool_plain(metadata={'async': True}) + async def custom() -> str: # pyright: ignore[reportUnusedFunction] + await asyncio.sleep(0.01) + return 'value' + + result = await agent.run('go') + assert result.output == 'done' + assert _ack_seen(result.all_messages()) diff --git a/uv.lock b/uv.lock index f585649..50dc4af 100644 --- a/uv.lock +++ b/uv.lock @@ -897,9 +897,9 @@ lint = [ [package.metadata] requires-dist = [ - { name = "pydantic-ai-slim", git = "https://github.com/pydantic/pydantic-ai.git?subdirectory=pydantic_ai_slim&branch=main" }, - { name = "pydantic-ai-slim", extras = ["dbos"], marker = "extra == 'dbos'", git = "https://github.com/pydantic/pydantic-ai.git?subdirectory=pydantic_ai_slim&branch=main" }, - { name = "pydantic-ai-slim", extras = ["temporal"], marker = "extra == 'temporal'", git = "https://github.com/pydantic/pydantic-ai.git?subdirectory=pydantic_ai_slim&branch=main" }, + { name = "pydantic-ai-slim", git = "https://github.com/pydantic/pydantic-ai.git?subdirectory=pydantic_ai_slim&branch=background-tools" }, + { name = "pydantic-ai-slim", extras = ["dbos"], marker = "extra == 'dbos'", git = "https://github.com/pydantic/pydantic-ai.git?subdirectory=pydantic_ai_slim&branch=background-tools" }, + { name = "pydantic-ai-slim", extras = ["temporal"], marker = "extra == 'temporal'", git = "https://github.com/pydantic/pydantic-ai.git?subdirectory=pydantic_ai_slim&branch=background-tools" }, { name = "pydantic-monty", marker = "extra == 'code-mode'", specifier = ">=0.0.16" }, ] provides-extras = ["code-mode", "dbos", "temporal"] @@ -920,8 +920,8 @@ lint = [ [[package]] name = "pydantic-ai-slim" -version = "1.80.0" -source = { git = "https://github.com/pydantic/pydantic-ai.git?subdirectory=pydantic_ai_slim&branch=main#4dcfb01e74867485fdb41f8eb3dc535e6e057a90" } +version = "1.86.2.dev12+cb91f8cf" +source = { git = "https://github.com/pydantic/pydantic-ai.git?subdirectory=pydantic_ai_slim&branch=background-tools#cb91f8cfba7046218fbef08aca1fd627dce69ff9" } dependencies = [ { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, { name = "genai-prices" }, @@ -1061,8 +1061,8 @@ wheels = [ [[package]] name = "pydantic-graph" -version = "1.80.0" -source = { git = "https://github.com/pydantic/pydantic-ai.git?subdirectory=pydantic_graph&branch=main#4dcfb01e74867485fdb41f8eb3dc535e6e057a90" } +version = "1.86.2.dev12+cb91f8cf" +source = { git = "https://github.com/pydantic/pydantic-ai.git?subdirectory=pydantic_graph&branch=background-tools#cb91f8cfba7046218fbef08aca1fd627dce69ff9" } dependencies = [ { name = "httpx" }, { name = "logfire-api" },