diff --git a/docs/ui/vercel-ai.md b/docs/ui/vercel-ai.md index 7d94aadbb9..f0da2110d3 100644 --- a/docs/ui/vercel-ai.md +++ b/docs/ui/vercel-ai.md @@ -1,6 +1,7 @@ # Vercel AI Data Stream Protocol -Pydantic AI natively supports the [Vercel AI Data Stream Protocol](https://ai-sdk.dev/docs/ai-sdk-ui/stream-protocol#data-stream-protocol) to receive agent run input from, and stream events to, a [Vercel AI Elements](https://ai-sdk.dev/elements) frontend. +Pydantic AI natively supports the [Vercel AI Data Stream Protocol](https://ai-sdk.dev/docs/ai-sdk-ui/stream-protocol#data-stream-protocol) to receive agent run input from, and stream events to, a frontend using [AI SDK UI](https://ai-sdk.dev/docs/ai-sdk-ui/overview) hooks like [`useChat`](https://ai-sdk.dev/docs/reference/ai-sdk-ui/use-chat). You can optionally use [AI Elements](https://ai-sdk.dev/elements) for pre-built UI components. + ## Usage @@ -36,7 +37,7 @@ async def chat(request: Request) -> Response: If you're using a web framework not based on Starlette (e.g. Django or Flask) or need fine-grained control over the input or output, you can create a `VercelAIAdapter` instance and directly use its methods, which can be chained to accomplish the same thing as the `VercelAIAdapter.dispatch_request()` class method shown above: 1. The [`VercelAIAdapter.build_run_input()`][pydantic_ai.ui.vercel_ai.VercelAIAdapter.build_run_input] class method takes the request body as bytes and returns a Vercel AI [`RequestData`][pydantic_ai.ui.vercel_ai.request_types.RequestData] run input object, which you can then pass to the [`VercelAIAdapter()`][pydantic_ai.ui.vercel_ai.VercelAIAdapter] constructor along with the agent. - - You can also use the [`VercelAIAdapter.from_request()`][pydantic_ai.ui.UIAdapter.from_request] class method to build an adapter directly from a Starlette/FastAPI request. + - You can also use the [`VercelAIAdapter.from_request()`][pydantic_ai.ui.vercel_ai.VercelAIAdapter.from_request] class method to build an adapter directly from a Starlette/FastAPI request. 2. The [`VercelAIAdapter.run_stream()`][pydantic_ai.ui.UIAdapter.run_stream] method runs the agent and returns a stream of Vercel AI events. It supports the same optional arguments as [`Agent.run_stream_events()`](../agents.md#running-agents) and an optional `on_complete` callback function that receives the completed [`AgentRunResult`][pydantic_ai.agent.AgentRunResult] and can optionally yield additional Vercel AI events. - You can also use [`VercelAIAdapter.run_stream_native()`][pydantic_ai.ui.UIAdapter.run_stream_native] to run the agent and return a stream of Pydantic AI events instead, which can then be transformed into Vercel AI events using [`VercelAIAdapter.transform_stream()`][pydantic_ai.ui.UIAdapter.transform_stream]. 3. The [`VercelAIAdapter.encode_stream()`][pydantic_ai.ui.UIAdapter.encode_stream] method encodes the stream of Vercel AI events as SSE (HTTP Server-Sent Events) strings, which you can then return as a streaming response. @@ -81,3 +82,27 @@ async def chat(request: Request) -> Response: sse_event_stream = adapter.encode_stream(event_stream) return StreamingResponse(sse_event_stream, media_type=accept) ``` + +## Tool Approval + +!!! note + Tool approval requires AI SDK UI v6 or later on the frontend. + +Pydantic AI supports human-in-the-loop tool approval workflows with AI SDK UI, allowing users to approve or deny tool executions before they run. See the [deferred tool calls documentation](../deferred-tools.md) for details on setting up tools that require approval. + +To enable tool approval streaming, pass `tool_approval=True` when creating the adapter: + +```py {test="skip" lint="skip"} +@app.post('/chat') +async def chat(request: Request) -> Response: + adapter = await VercelAIAdapter.from_request(request, agent=agent, tool_approval=True) + return adapter.streaming_response(adapter.run_stream()) +``` + +When `tool_approval=True`, the adapter will: + +1. Emit `tool-approval-request` chunks when tools with `requires_approval=True` are called +2. Automatically extract approval responses from follow-up requests +3. Emit `tool-output-denied` chunks for rejected tools + +On the frontend, AI SDK UI's [`useChat`](https://ai-sdk.dev/docs/reference/ai-sdk-ui/use-chat) hook handles the approval flow. You can use the [`Confirmation`](https://ai-sdk.dev/elements/components/confirmation) component from AI Elements for a pre-built approval UI, or build your own using the hook's `addToolResult` function. diff --git a/pydantic_ai_slim/pydantic_ai/ui/_adapter.py b/pydantic_ai_slim/pydantic_ai/ui/_adapter.py index 4ac609a07e..93310c34f3 100644 --- a/pydantic_ai_slim/pydantic_ai/ui/_adapter.py +++ b/pydantic_ai_slim/pydantic_ai/ui/_adapter.py @@ -125,6 +125,12 @@ class UIAdapter(ABC, Generic[RunInputT, MessageT, EventT, AgentDepsT, OutputData accept: str | None = None """The `Accept` header value of the request, used to determine how to encode the protocol-specific events for the streaming response.""" + tool_approval: bool = False + """Whether to enable tool approval streaming for human-in-the-loop workflows.""" + + deferred_tool_results: DeferredToolResults | None = None + """Deferred tool results extracted from the request, used for tool approval workflows.""" + @classmethod async def from_request( cls, request: Request, *, agent: AbstractAgent[AgentDepsT, OutputDataT] @@ -237,6 +243,10 @@ def run_stream_native( toolsets: Optional additional toolsets for this run. builtin_tools: Optional additional builtin tools to use for this run. """ + # Use instance field as fallback if not explicitly passed + if deferred_tool_results is None: + deferred_tool_results = self.deferred_tool_results + message_history = [*(message_history or []), *self.messages] toolset = self.toolset diff --git a/pydantic_ai_slim/pydantic_ai/ui/_event_stream.py b/pydantic_ai_slim/pydantic_ai/ui/_event_stream.py index 391cf06f2f..13b862147e 100644 --- a/pydantic_ai_slim/pydantic_ai/ui/_event_stream.py +++ b/pydantic_ai_slim/pydantic_ai/ui/_event_stream.py @@ -70,6 +70,9 @@ class UIEventStream(ABC, Generic[RunInputT, EventT, AgentDepsT, OutputDataT]): accept: str | None = None """The `Accept` header value of the request, used to determine how to encode the protocol-specific events for the streaming response.""" + tool_approval: bool = False + """Whether tool approval streaming is enabled for human-in-the-loop workflows.""" + message_id: str = field(default_factory=lambda: str(uuid4())) """The message ID to use for the next event.""" diff --git a/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_adapter.py b/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_adapter.py index fa82b9255b..c63efdf26d 100644 --- a/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_adapter.py +++ b/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_adapter.py @@ -35,7 +35,7 @@ VideoUrl, ) from ...output import OutputDataT -from ...tools import AgentDepsT +from ...tools import AgentDepsT, DeferredToolApprovalResult, DeferredToolResults from .. import MessagesBuilder, UIAdapter, UIEventStream from ._event_stream import VercelAIEventStream from .request_types import ( @@ -51,6 +51,7 @@ SourceUrlUIPart, StepStartUIPart, TextUIPart, + ToolApprovalResponded, ToolInputAvailablePart, ToolOutputAvailablePart, ToolOutputErrorPart, @@ -61,7 +62,9 @@ from .response_types import BaseChunk if TYPE_CHECKING: - pass + from starlette.requests import Request + + from ...agent import AbstractAgent __all__ = ['VercelAIAdapter'] @@ -78,9 +81,51 @@ def build_run_input(cls, body: bytes) -> RequestData: """Build a Vercel AI run input object from the request body.""" return request_data_ta.validate_json(body) + @classmethod + async def from_request( + cls, + request: Request, + *, + agent: AbstractAgent[AgentDepsT, OutputDataT], + tool_approval: bool = False, + ) -> VercelAIAdapter[AgentDepsT, OutputDataT]: + """Create a Vercel AI adapter from a request. + + Args: + request: The incoming Starlette/FastAPI request. + agent: The Pydantic AI agent to run. + tool_approval: Whether to enable tool approval streaming for human-in-the-loop workflows. + """ + run_input = cls.build_run_input(await request.body()) + + # Extract deferred tool results from messages when tool_approval is enabled + deferred_tool_results: DeferredToolResults | None = None + if tool_approval: + approvals: dict[str, bool | DeferredToolApprovalResult] = {} + for msg in run_input.messages: + if msg.role != 'assistant': + continue + for part in msg.parts: + if not isinstance(part, ToolUIPart | DynamicToolUIPart): + continue + approval = part.approval + if not isinstance(approval, ToolApprovalResponded): + continue + approvals[part.tool_call_id] = approval.approved + if approvals: + deferred_tool_results = DeferredToolResults(approvals=approvals) + + return cls( + agent=agent, + run_input=run_input, + accept=request.headers.get('accept'), + tool_approval=tool_approval, + deferred_tool_results=deferred_tool_results, + ) + def build_event_stream(self) -> UIEventStream[RequestData, BaseChunk, AgentDepsT, OutputDataT]: """Build a Vercel AI event stream transformer.""" - return VercelAIEventStream(self.run_input, accept=self.accept) + return VercelAIEventStream(self.run_input, accept=self.accept, tool_approval=self.tool_approval) @cached_property def messages(self) -> list[ModelMessage]: diff --git a/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_event_stream.py b/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_event_stream.py index 070166df98..bb854e29c6 100644 --- a/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_event_stream.py +++ b/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_event_stream.py @@ -4,7 +4,9 @@ from collections.abc import AsyncIterator, Mapping from dataclasses import dataclass +from functools import cached_property from typing import Any +from uuid import uuid4 from pydantic_core import to_json @@ -25,9 +27,14 @@ ) from ...output import OutputDataT from ...run import AgentRunResultEvent -from ...tools import AgentDepsT +from ...tools import AgentDepsT, DeferredToolRequests from .. import UIEventStream -from .request_types import RequestData +from .request_types import ( + DynamicToolUIPart, + RequestData, + ToolApprovalResponded, + ToolUIPart, +) from .response_types import ( BaseChunk, DoneChunk, @@ -44,10 +51,12 @@ TextDeltaChunk, TextEndChunk, TextStartChunk, + ToolApprovalRequestChunk, ToolInputAvailableChunk, ToolInputDeltaChunk, ToolInputStartChunk, ToolOutputAvailableChunk, + ToolOutputDeniedChunk, ToolOutputErrorChunk, ) @@ -78,6 +87,20 @@ class VercelAIEventStream(UIEventStream[RequestData, BaseChunk, AgentDepsT, Outp _step_started: bool = False _finish_reason: FinishReason = None + @cached_property + def _denied_tool_ids(self) -> set[str]: + """Get the set of tool_call_ids that were denied by the user.""" + denied_ids: set[str] = set() + for msg in self.run_input.messages: + if msg.role != 'assistant': + continue + for part in msg.parts: + if not isinstance(part, ToolUIPart | DynamicToolUIPart): + continue + if isinstance(part.approval, ToolApprovalResponded) and not part.approval.approved: + denied_ids.add(part.tool_call_id) + return denied_ids + @property def response_headers(self) -> Mapping[str, str] | None: return VERCEL_AI_DSP_HEADERS @@ -104,9 +127,16 @@ async def after_stream(self) -> AsyncIterator[BaseChunk]: async def handle_run_result(self, event: AgentRunResultEvent) -> AsyncIterator[BaseChunk]: pydantic_reason = event.result.response.finish_reason if pydantic_reason: - self._finish_reason = _FINISH_REASON_MAP.get(pydantic_reason) - return - yield + self._finish_reason = _FINISH_REASON_MAP.get(pydantic_reason, 'unknown') + + # Emit tool approval requests for deferred approvals (only when tool_approval is enabled) + output = event.result.output + if self.tool_approval and isinstance(output, DeferredToolRequests): + for tool_call in output.approvals: + yield ToolApprovalRequestChunk( + approval_id=str(uuid4()), + tool_call_id=tool_call.tool_call_id, + ) async def on_error(self, error: Exception) -> AsyncIterator[BaseChunk]: self._finish_reason = 'error' @@ -203,10 +233,15 @@ async def handle_file(self, part: FilePart) -> AsyncIterator[BaseChunk]: async def handle_function_tool_result(self, event: FunctionToolResultEvent) -> AsyncIterator[BaseChunk]: part = event.result - if isinstance(part, RetryPromptPart): - yield ToolOutputErrorChunk(tool_call_id=part.tool_call_id, error_text=part.model_response()) + tool_call_id = part.tool_call_id + + # Check if this tool was denied by the user (only when tool_approval is enabled) + if self.tool_approval and tool_call_id in self._denied_tool_ids: + yield ToolOutputDeniedChunk(tool_call_id=tool_call_id) + elif isinstance(part, RetryPromptPart): + yield ToolOutputErrorChunk(tool_call_id=tool_call_id, error_text=part.model_response()) else: - yield ToolOutputAvailableChunk(tool_call_id=part.tool_call_id, output=self._tool_return_output(part)) + yield ToolOutputAvailableChunk(tool_call_id=tool_call_id, output=self._tool_return_output(part)) # ToolCallResultEvent.content may hold user parts (e.g. text, images) that Vercel AI does not currently have events for diff --git a/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/request_types.py b/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/request_types.py index 1fe9a593af..3203bb6162 100644 --- a/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/request_types.py +++ b/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/request_types.py @@ -1,7 +1,9 @@ """Vercel AI request types (UI messages). Converted to Python from: -https://github.com/vercel/ai/blob/ai%405.0.59/packages/ai/src/ui/ui-messages.ts +https://github.com/vercel/ai/blob/ai%406.0.0-beta.159/packages/ai/src/ui/ui-messages.ts + +Tool approval types (`ToolApprovalRequested`, `ToolApprovalResponded`) require AI SDK v6 or later. """ from abc import ABC @@ -110,6 +112,30 @@ class DataUIPart(BaseUIPart): data: Any +class ToolApprovalRequested(CamelBaseModel): + """Tool approval in requested state (awaiting user response).""" + + id: str + """The approval request ID.""" + + +class ToolApprovalResponded(CamelBaseModel): + """Tool approval in responded state (user has approved or denied).""" + + id: str + """The approval request ID.""" + + approved: bool + """Whether the user approved the tool call.""" + + reason: str | None = None + """Optional reason for the approval or denial.""" + + +ToolApproval = ToolApprovalRequested | ToolApprovalResponded +"""Union of tool approval states.""" + + # Tool part states as separate models class ToolInputStreamingPart(BaseUIPart): """Tool part in input-streaming state.""" @@ -119,6 +145,7 @@ class ToolInputStreamingPart(BaseUIPart): state: Literal['input-streaming'] = 'input-streaming' input: Any | None = None provider_executed: bool | None = None + approval: ToolApproval | None = None class ToolInputAvailablePart(BaseUIPart): @@ -130,6 +157,7 @@ class ToolInputAvailablePart(BaseUIPart): input: Any | None = None provider_executed: bool | None = None call_provider_metadata: ProviderMetadata | None = None + approval: ToolApproval | None = None class ToolOutputAvailablePart(BaseUIPart): @@ -143,6 +171,7 @@ class ToolOutputAvailablePart(BaseUIPart): provider_executed: bool | None = None call_provider_metadata: ProviderMetadata | None = None preliminary: bool | None = None + approval: ToolApproval | None = None class ToolOutputErrorPart(BaseUIPart): @@ -156,6 +185,7 @@ class ToolOutputErrorPart(BaseUIPart): error_text: str provider_executed: bool | None = None call_provider_metadata: ProviderMetadata | None = None + approval: ToolApproval | None = None ToolUIPart = ToolInputStreamingPart | ToolInputAvailablePart | ToolOutputAvailablePart | ToolOutputErrorPart @@ -171,6 +201,7 @@ class DynamicToolInputStreamingPart(BaseUIPart): tool_call_id: str state: Literal['input-streaming'] = 'input-streaming' input: Any | None = None + approval: ToolApproval | None = None class DynamicToolInputAvailablePart(BaseUIPart): @@ -182,6 +213,7 @@ class DynamicToolInputAvailablePart(BaseUIPart): state: Literal['input-available'] = 'input-available' input: Any call_provider_metadata: ProviderMetadata | None = None + approval: ToolApproval | None = None class DynamicToolOutputAvailablePart(BaseUIPart): @@ -195,6 +227,7 @@ class DynamicToolOutputAvailablePart(BaseUIPart): output: Any call_provider_metadata: ProviderMetadata | None = None preliminary: bool | None = None + approval: ToolApproval | None = None class DynamicToolOutputErrorPart(BaseUIPart): @@ -207,6 +240,7 @@ class DynamicToolOutputErrorPart(BaseUIPart): input: Any error_text: str call_provider_metadata: ProviderMetadata | None = None + approval: ToolApproval | None = None DynamicToolUIPart = ( diff --git a/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/response_types.py b/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/response_types.py index 6a7b98a2dc..0757eb4848 100644 --- a/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/response_types.py +++ b/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/response_types.py @@ -1,7 +1,9 @@ """Vercel AI response types (SSE chunks). Converted to Python from: -https://github.com/vercel/ai/blob/ai%405.0.59/packages/ai/src/ui-message-stream/ui-message-chunks.ts +https://github.com/vercel/ai/blob/ai%406.0.0-beta.159/packages/ai/src/ui-message-stream/ui-message-chunks.ts + +Tool approval types (`ToolApprovalRequestChunk`, `ToolOutputDeniedChunk`) require AI SDK UI v6 or later. """ from abc import ABC @@ -149,7 +151,10 @@ class ToolOutputErrorChunk(BaseChunk): class ToolApprovalRequestChunk(BaseChunk): - """Tool approval request chunk for human-in-the-loop approval.""" + """Tool approval request chunk for human-in-the-loop approval. + + Requires AI SDK UI v6 or later. + """ type: Literal['tool-approval-request'] = 'tool-approval-request' approval_id: str @@ -157,7 +162,10 @@ class ToolApprovalRequestChunk(BaseChunk): class ToolOutputDeniedChunk(BaseChunk): - """Tool output denied chunk when user denies tool execution.""" + """Tool output denied chunk when user denies tool execution. + + Requires AI SDK UI v6 or later. + """ type: Literal['tool-output-denied'] = 'tool-output-denied' tool_call_id: str diff --git a/tests/test_vercel_ai.py b/tests/test_vercel_ai.py index 0e191d445d..119f791dda 100644 --- a/tests/test_vercel_ai.py +++ b/tests/test_vercel_ai.py @@ -3,6 +3,7 @@ import json from collections.abc import AsyncIterator, MutableMapping from typing import Any, cast +from uuid import UUID import pytest from inline_snapshot import snapshot @@ -1651,6 +1652,359 @@ async def on_complete(run_result: AgentRunResult[Any]) -> AsyncIterator[BaseChun assert {'type': 'data-progress', 'data': {'percent': 100}, 'transient': True} in events +async def test_tool_approval_request_emission(): + """Test that ToolApprovalRequestChunk is emitted when tools require approval.""" + from pydantic_ai.tools import DeferredToolRequests + + async def stream_function( + messages: list[ModelMessage], agent_info: AgentInfo + ) -> AsyncIterator[DeltaToolCalls | str]: + yield { + 0: DeltaToolCall( + name='delete_file', + json_args='{"path": "test.txt"}', + tool_call_id='delete_1', + ) + } + + agent: Agent[None, str | DeferredToolRequests] = Agent( + model=FunctionModel(stream_function=stream_function), output_type=[str, DeferredToolRequests] + ) + + @agent.tool_plain(requires_approval=True) + def delete_file(path: str) -> str: + return f'Deleted {path}' # pragma: no cover + + request = SubmitMessage( + id='foo', + messages=[ + UIMessage( + id='bar', + role='user', + parts=[TextUIPart(text='Delete test.txt')], + ), + ], + ) + + adapter = VercelAIAdapter(agent, request, tool_approval=True) + events: list[str | dict[str, Any]] = [ + '[DONE]' if '[DONE]' in event else json.loads(event.removeprefix('data: ')) + async for event in adapter.encode_stream(adapter.run_stream()) + ] + + # Verify tool-approval-request chunk is emitted with UUID approval_id + approval_event: dict[str, Any] | None = next( + (e for e in events if isinstance(e, dict) and e.get('type') == 'tool-approval-request'), + None, + ) + assert approval_event is not None + assert approval_event['toolCallId'] == 'delete_1' + # Validate approval_id is a valid UUID + approval_id = approval_event.get('approvalId') + assert approval_id is not None + UUID(approval_id) # Raises ValueError if not a valid UUID + + +async def test_tool_approval_false_does_not_emit_approval_chunks(): + """Test that ToolApprovalRequestChunk is NOT emitted when tool_approval=False.""" + from pydantic_ai.tools import DeferredToolRequests + + async def stream_function( + messages: list[ModelMessage], agent_info: AgentInfo + ) -> AsyncIterator[DeltaToolCalls | str]: + yield { + 0: DeltaToolCall( + name='delete_file', + json_args='{"path": "test.txt"}', + tool_call_id='delete_1', + ) + } + + agent: Agent[None, str | DeferredToolRequests] = Agent( + model=FunctionModel(stream_function=stream_function), output_type=[str, DeferredToolRequests] + ) + + @agent.tool_plain(requires_approval=True) + def delete_file(path: str) -> str: + return f'Deleted {path}' # pragma: no cover + + request = SubmitMessage( + id='foo', + messages=[ + UIMessage( + id='bar', + role='user', + parts=[TextUIPart(text='Delete test.txt')], + ), + ], + ) + + adapter = VercelAIAdapter(agent, request, tool_approval=False) + events: list[str | dict[str, Any]] = [ + '[DONE]' if '[DONE]' in event else json.loads(event.removeprefix('data: ')) + async for event in adapter.encode_stream(adapter.run_stream()) + ] + + # Verify tool-approval-request chunk is NOT emitted when tool_approval=False + approval_events = [e for e in events if isinstance(e, dict) and e.get('type') == 'tool-approval-request'] + assert len(approval_events) == 0 + + +@pytest.mark.skipif(not starlette_import_successful, reason='Starlette is not installed') +async def test_tool_output_denied_chunk_emission(): + """Test that ToolOutputDeniedChunk is emitted when a tool call is denied. + + This test verifies the full public interface: from_request() extracts approval + data from messages, and the adapter emits tool-output-denied chunks for denied tools. + """ + from unittest.mock import AsyncMock + + from starlette.requests import Request + + from pydantic_ai.tools import DeferredToolRequests + from pydantic_ai.ui.vercel_ai.request_types import ( + DynamicToolInputAvailablePart, + ToolApprovalResponded, + ) + + async def stream_function( + messages: list[ModelMessage], agent_info: AgentInfo + ) -> AsyncIterator[DeltaToolCalls | str]: + # Model acknowledges the denial + yield 'The file deletion was cancelled as requested.' + + agent = Agent(model=FunctionModel(stream_function=stream_function), output_type=[str, DeferredToolRequests]) + + @agent.tool_plain(requires_approval=True) + def delete_file(path: str) -> str: + return f'Deleted {path}' + + # Simulate a follow-up request where the user denied the tool + request = SubmitMessage( + id='foo', + messages=[ + UIMessage( + id='user-1', + role='user', + parts=[TextUIPart(text='Delete test.txt')], + ), + UIMessage( + id='assistant-1', + role='assistant', + parts=[ + TextUIPart(text='I will delete the file for you.'), + DynamicToolInputAvailablePart( + tool_name='delete_file', + tool_call_id='delete_approved', + input={'path': 'approved.txt'}, + approval=ToolApprovalResponded(id='approval-456', approved=True), + ), + DynamicToolInputAvailablePart( + tool_name='delete_file', + tool_call_id='delete_1', + input={'path': 'test.txt'}, + approval=ToolApprovalResponded( + id='approval-123', + approved=False, + reason='User cancelled the deletion', + ), + ), + ], + ), + ], + ) + + def mock_header_get(key: str) -> str | None: + return None + + request_body = request.model_dump_json().encode() + mock_request = AsyncMock(spec=Request) + mock_request.body = AsyncMock(return_value=request_body) + mock_request.headers.get = mock_header_get + + adapter = await VercelAIAdapter[None, str | DeferredToolRequests].from_request( + mock_request, agent=agent, tool_approval=True + ) + events: list[str | dict[str, Any]] = [ + '[DONE]' if '[DONE]' in event else json.loads(event.removeprefix('data: ')) + async for event in adapter.encode_stream(adapter.run_stream()) + ] + + # Verify tool-output-denied chunk is emitted + denied_event: dict[str, Any] | None = next( + (e for e in events if isinstance(e, dict) and e.get('type') == 'tool-output-denied'), + None, + ) + assert denied_event is not None + assert denied_event['toolCallId'] == 'delete_1' + + +@pytest.mark.skipif(not starlette_import_successful, reason='Starlette is not installed') +async def test_tool_approval_extraction_with_edge_cases(): + """Test that approval extraction correctly skips non-tool parts and non-responded approvals.""" + from unittest.mock import AsyncMock + + from starlette.requests import Request + + from pydantic_ai.tools import DeferredToolRequests + from pydantic_ai.ui.vercel_ai.request_types import ( + DynamicToolInputAvailablePart, + ToolApprovalRequested, + ToolApprovalResponded, + ) + + agent = Agent(TestModel(), output_type=[str, DeferredToolRequests]) + + @agent.tool_plain(requires_approval=True) + def some_tool(x: str) -> str: + return x # pragma: no cover + + request = SubmitMessage( + id='foo', + messages=[ + UIMessage(id='user-1', role='user', parts=[TextUIPart(text='Test')]), + UIMessage( + id='assistant-1', + role='assistant', + parts=[ + TextUIPart(text='Here is my response.'), + DynamicToolInputAvailablePart( + tool_name='some_tool', + tool_call_id='pending_tool', + input={'x': 'pending'}, + approval=ToolApprovalRequested(id='pending-approval'), + ), + DynamicToolInputAvailablePart( + tool_name='some_tool', + tool_call_id='no_approval_tool', + input={'x': 'no_approval'}, + approval=None, + ), + DynamicToolInputAvailablePart( + tool_name='some_tool', + tool_call_id='approved_tool', + input={'x': 'approved'}, + approval=ToolApprovalResponded(id='approved-id', approved=True), + ), + ], + ), + ], + ) + + def mock_header_get(key: str) -> str | None: + return None + + request_body = request.model_dump_json().encode() + mock_request = AsyncMock(spec=Request) + mock_request.body = AsyncMock(return_value=request_body) + mock_request.headers.get = mock_header_get + + adapter = await VercelAIAdapter[None, str | DeferredToolRequests].from_request( + mock_request, agent=agent, tool_approval=True + ) + + # Verify that only the responded approval was extracted + assert adapter.deferred_tool_results is not None + assert adapter.deferred_tool_results.approvals == {'approved_tool': True} + + +@pytest.mark.skipif(not starlette_import_successful, reason='Starlette is not installed') +async def test_tool_approval_no_approvals_extracted(): + """Test that deferred_tool_results is None when no approvals are responded.""" + from unittest.mock import AsyncMock + + from starlette.requests import Request + + from pydantic_ai.tools import DeferredToolRequests + from pydantic_ai.ui.vercel_ai.request_types import ( + DynamicToolInputAvailablePart, + ToolApprovalRequested, + ) + + agent = Agent(TestModel(), output_type=[str, DeferredToolRequests]) + + @agent.tool_plain(requires_approval=True) + def some_tool(x: str) -> str: + return x # pragma: no cover + + request = SubmitMessage( + id='foo', + messages=[ + UIMessage(id='user-1', role='user', parts=[TextUIPart(text='Test')]), + UIMessage( + id='assistant-1', + role='assistant', + parts=[ + DynamicToolInputAvailablePart( + tool_name='some_tool', + tool_call_id='pending_tool', + input={'x': 'pending'}, + approval=ToolApprovalRequested(id='pending-approval'), + ), + ], + ), + ], + ) + + def mock_header_get(key: str) -> str | None: + return None + + request_body = request.model_dump_json().encode() + mock_request = AsyncMock(spec=Request) + mock_request.body = AsyncMock(return_value=request_body) + mock_request.headers.get = mock_header_get + + adapter = await VercelAIAdapter[None, str | DeferredToolRequests].from_request( + mock_request, agent=agent, tool_approval=True + ) + + assert adapter.deferred_tool_results is None + + +@pytest.mark.skipif(not starlette_import_successful, reason='Starlette is not installed') +async def test_run_stream_with_explicit_deferred_tool_results(): + """Test that run_stream accepts explicit deferred_tool_results parameter.""" + from unittest.mock import AsyncMock + + from starlette.requests import Request + + from pydantic_ai.tools import DeferredToolResults + + async def stream_function( + messages: list[ModelMessage], agent_info: AgentInfo + ) -> AsyncIterator[DeltaToolCalls | str]: + yield 'Done' + + agent = Agent(model=FunctionModel(stream_function=stream_function)) + + request = SubmitMessage( + id='foo', + messages=[ + UIMessage(id='user-1', role='user', parts=[TextUIPart(text='Test')]), + ], + ) + + def mock_header_get(key: str) -> str | None: + return None + + request_body = request.model_dump_json().encode() + mock_request = AsyncMock(spec=Request) + mock_request.body = AsyncMock(return_value=request_body) + mock_request.headers.get = mock_header_get + + adapter = await VercelAIAdapter[None, str].from_request(mock_request, agent=agent) + + # Pass deferred_tool_results explicitly (even though it's empty, it covers the else branch) + explicit_results = DeferredToolResults() + events: list[str | dict[str, Any]] = [ + '[DONE]' if '[DONE]' in event else json.loads(event.removeprefix('data: ')) + async for event in adapter.encode_stream(adapter.run_stream(deferred_tool_results=explicit_results)) + ] + + # Verify stream completed successfully + assert events[-1] == '[DONE]' + + @pytest.mark.skipif(not starlette_import_successful, reason='Starlette is not installed') async def test_adapter_dispatch_request(): agent = Agent(model=TestModel()) @@ -2089,6 +2443,7 @@ async def test_adapter_dump_messages_with_tools(): 'output': '{"results":["result1","result2"]}', 'call_provider_metadata': None, 'preliminary': None, + 'approval': None, }, ], }, @@ -2151,6 +2506,7 @@ async def test_adapter_dump_messages_with_builtin_tools(): 'provider_executed': True, 'call_provider_metadata': {'pydantic_ai': {'provider_name': 'openai'}}, 'preliminary': None, + 'approval': None, } ], }, @@ -2197,6 +2553,7 @@ async def test_adapter_dump_messages_with_builtin_tool_without_return(): 'input': '{"query":"orphan query"}', 'provider_executed': True, 'call_provider_metadata': {'pydantic_ai': {'provider_name': 'openai'}}, + 'approval': None, } ], }, @@ -2363,6 +2720,7 @@ async def test_adapter_dump_messages_with_retry(): Fix the errors and try again.\ """, 'call_provider_metadata': None, + 'approval': None, } ], }, @@ -2494,6 +2852,7 @@ async def test_adapter_dump_messages_text_with_interruption(): 'provider_executed': True, 'call_provider_metadata': {'pydantic_ai': {'provider_name': 'test'}}, 'preliminary': None, + 'approval': None, }, { 'type': 'text', @@ -2612,6 +2971,7 @@ async def test_adapter_dump_messages_tool_call_without_return(): 'state': 'input-available', 'input': '{"city":"New York"}', 'call_provider_metadata': None, + 'approval': None, } ], } @@ -2646,6 +3006,7 @@ async def test_adapter_dump_messages_assistant_starts_with_tool(): 'state': 'input-available', 'input': '{}', 'call_provider_metadata': None, + 'approval': None, }, { 'type': 'text',