diff --git a/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_adapter.py b/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_adapter.py index fa82b9255b..2745ccff1d 100644 --- a/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_adapter.py +++ b/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_adapter.py @@ -35,7 +35,7 @@ VideoUrl, ) from ...output import OutputDataT -from ...tools import AgentDepsT +from ...tools import AgentDepsT, DeferredToolResults, ToolApproved, ToolDenied from .. import MessagesBuilder, UIAdapter, UIEventStream from ._event_stream import VercelAIEventStream from .request_types import ( @@ -51,6 +51,7 @@ SourceUrlUIPart, StepStartUIPart, TextUIPart, + ToolApprovalResponded, ToolInputAvailablePart, ToolOutputAvailablePart, ToolOutputErrorPart, @@ -87,6 +88,56 @@ def messages(self) -> list[ModelMessage]: """Pydantic AI messages from the Vercel AI run input.""" return self.load_messages(self.run_input.messages) + @cached_property + def deferred_tool_results(self) -> DeferredToolResults | None: + """Extract deferred tool results from tool parts with approval responses. + + When the Vercel AI SDK client responds to a tool-approval-request, it sends + the approval decision in the tool part's `approval` field. This method extracts + those responses and converts them to Pydantic AI's `DeferredToolResults` format. + + Returns: + DeferredToolResults if any tool parts have approval responses, None otherwise. + """ + return self.extract_deferred_tool_results(self.run_input.messages) + + @classmethod + def extract_deferred_tool_results(cls, messages: Sequence[UIMessage]) -> DeferredToolResults | None: + """Extract deferred tool results from UI messages. + + Args: + messages: The UI messages to scan for approval responses. + + Returns: + DeferredToolResults if any tool parts have approval responses, None otherwise. + """ + approvals: dict[str, bool | ToolApproved | ToolDenied] = {} + + for msg in messages: + if msg.role != 'assistant': + continue + + for part in msg.parts: + if not isinstance(part, ToolUIPart | DynamicToolUIPart): + continue + + approval = part.approval + if approval is None or not isinstance(approval, ToolApprovalResponded): + continue + + tool_call_id = part.tool_call_id + if approval.approved: + approvals[tool_call_id] = ToolApproved() + else: + approvals[tool_call_id] = ToolDenied( + message=approval.reason or 'The tool call was denied.' + ) + + if not approvals: + return None + + return DeferredToolResults(approvals=approvals) + @classmethod def load_messages(cls, messages: Sequence[UIMessage]) -> list[ModelMessage]: # noqa: C901 """Transform Vercel AI messages into Pydantic AI messages.""" diff --git a/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_event_stream.py b/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_event_stream.py index ca94c5c186..ec89b515bc 100644 --- a/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_event_stream.py +++ b/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_event_stream.py @@ -5,6 +5,7 @@ from collections.abc import AsyncIterator, Mapping from dataclasses import dataclass from typing import Any +from uuid import uuid4 from pydantic_core import to_json @@ -13,6 +14,7 @@ BuiltinToolCallPart, BuiltinToolReturnPart, FilePart, + FinishReason as PydanticFinishReason, FunctionToolResultEvent, RetryPromptPart, TextPart, @@ -23,7 +25,8 @@ ToolCallPartDelta, ) from ...output import OutputDataT -from ...tools import AgentDepsT +from ...run import AgentRunResultEvent +from ...tools import AgentDepsT, DeferredToolRequests from .. import UIEventStream from .request_types import RequestData from .response_types import ( @@ -32,6 +35,7 @@ ErrorChunk, FileChunk, FinishChunk, + FinishReason, FinishStepChunk, ReasoningDeltaChunk, ReasoningEndChunk, @@ -41,6 +45,7 @@ TextDeltaChunk, TextEndChunk, TextStartChunk, + ToolApprovalRequestChunk, ToolInputAvailableChunk, ToolInputDeltaChunk, ToolInputStartChunk, @@ -48,6 +53,15 @@ ToolOutputErrorChunk, ) +# Map Pydantic AI finish reasons to Vercel AI format +_FINISH_REASON_MAP: dict[PydanticFinishReason, FinishReason] = { + 'stop': 'stop', + 'length': 'length', + 'content_filter': 'content-filter', + 'tool_call': 'tool-calls', + 'error': 'error', +} + __all__ = ['VercelAIEventStream'] # See https://ai-sdk.dev/docs/ai-sdk-ui/stream-protocol#data-stream-protocol @@ -64,6 +78,7 @@ class VercelAIEventStream(UIEventStream[RequestData, BaseChunk, AgentDepsT, Outp """UI event stream transformer for the Vercel AI protocol.""" _step_started: bool = False + _finish_reason: FinishReason = None @property def response_headers(self) -> Mapping[str, str] | None: @@ -85,10 +100,25 @@ async def before_response(self) -> AsyncIterator[BaseChunk]: async def after_stream(self) -> AsyncIterator[BaseChunk]: yield FinishStepChunk() - yield FinishChunk() + yield FinishChunk(finish_reason=self._finish_reason) yield DoneChunk() + async def handle_run_result(self, event: AgentRunResultEvent) -> AsyncIterator[BaseChunk]: + pydantic_reason = event.result.response.finish_reason + if pydantic_reason: + self._finish_reason = _FINISH_REASON_MAP.get(pydantic_reason) + + # Emit tool approval requests for deferred approvals + output = event.result.output + if isinstance(output, DeferredToolRequests): + for tool_call in output.approvals: + yield ToolApprovalRequestChunk( + approval_id=str(uuid4()), + tool_call_id=tool_call.tool_call_id, + ) + async def on_error(self, error: Exception) -> AsyncIterator[BaseChunk]: + self._finish_reason = 'error' yield ErrorChunk(error_text=str(error)) async def handle_text_start(self, part: TextPart, follows_text: bool = False) -> AsyncIterator[BaseChunk]: diff --git a/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/request_types.py b/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/request_types.py index 1fe9a593af..7d9a8e8e54 100644 --- a/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/request_types.py +++ b/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/request_types.py @@ -110,6 +110,30 @@ class DataUIPart(BaseUIPart): data: Any +class ToolApprovalRequested(CamelBaseModel): + """Tool approval in requested state (awaiting user response).""" + + id: str + """The approval request ID.""" + + +class ToolApprovalResponded(CamelBaseModel): + """Tool approval in responded state (user has approved or denied).""" + + id: str + """The approval request ID.""" + + approved: bool + """Whether the user approved the tool call.""" + + reason: str | None = None + """Optional reason for the approval or denial.""" + + +ToolApproval = ToolApprovalRequested | ToolApprovalResponded +"""Union of tool approval states.""" + + # Tool part states as separate models class ToolInputStreamingPart(BaseUIPart): """Tool part in input-streaming state.""" @@ -119,6 +143,7 @@ class ToolInputStreamingPart(BaseUIPart): state: Literal['input-streaming'] = 'input-streaming' input: Any | None = None provider_executed: bool | None = None + approval: ToolApproval | None = None class ToolInputAvailablePart(BaseUIPart): @@ -130,6 +155,7 @@ class ToolInputAvailablePart(BaseUIPart): input: Any | None = None provider_executed: bool | None = None call_provider_metadata: ProviderMetadata | None = None + approval: ToolApproval | None = None class ToolOutputAvailablePart(BaseUIPart): @@ -143,6 +169,7 @@ class ToolOutputAvailablePart(BaseUIPart): provider_executed: bool | None = None call_provider_metadata: ProviderMetadata | None = None preliminary: bool | None = None + approval: ToolApproval | None = None class ToolOutputErrorPart(BaseUIPart): @@ -156,6 +183,7 @@ class ToolOutputErrorPart(BaseUIPart): error_text: str provider_executed: bool | None = None call_provider_metadata: ProviderMetadata | None = None + approval: ToolApproval | None = None ToolUIPart = ToolInputStreamingPart | ToolInputAvailablePart | ToolOutputAvailablePart | ToolOutputErrorPart @@ -171,6 +199,7 @@ class DynamicToolInputStreamingPart(BaseUIPart): tool_call_id: str state: Literal['input-streaming'] = 'input-streaming' input: Any | None = None + approval: ToolApproval | None = None class DynamicToolInputAvailablePart(BaseUIPart): @@ -182,6 +211,7 @@ class DynamicToolInputAvailablePart(BaseUIPart): state: Literal['input-available'] = 'input-available' input: Any call_provider_metadata: ProviderMetadata | None = None + approval: ToolApproval | None = None class DynamicToolOutputAvailablePart(BaseUIPart): @@ -195,6 +225,7 @@ class DynamicToolOutputAvailablePart(BaseUIPart): output: Any call_provider_metadata: ProviderMetadata | None = None preliminary: bool | None = None + approval: ToolApproval | None = None class DynamicToolOutputErrorPart(BaseUIPart): @@ -207,6 +238,7 @@ class DynamicToolOutputErrorPart(BaseUIPart): input: Any error_text: str call_provider_metadata: ProviderMetadata | None = None + approval: ToolApproval | None = None DynamicToolUIPart = ( diff --git a/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/response_types.py b/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/response_types.py index 1255503107..6a7b98a2dc 100644 --- a/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/response_types.py +++ b/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/response_types.py @@ -16,6 +16,9 @@ ProviderMetadata = dict[str, dict[str, JSONValue]] """Provider metadata.""" +FinishReason = Literal['stop', 'length', 'content-filter', 'tool-calls', 'error', 'other', 'unknown'] | None +"""Reason why the model finished generating.""" + class BaseChunk(CamelBaseModel, ABC): """Abstract base class for response SSE events.""" @@ -145,6 +148,21 @@ class ToolOutputErrorChunk(BaseChunk): dynamic: bool | None = None +class ToolApprovalRequestChunk(BaseChunk): + """Tool approval request chunk for human-in-the-loop approval.""" + + type: Literal['tool-approval-request'] = 'tool-approval-request' + approval_id: str + tool_call_id: str + + +class ToolOutputDeniedChunk(BaseChunk): + """Tool output denied chunk when user denies tool execution.""" + + type: Literal['tool-output-denied'] = 'tool-output-denied' + tool_call_id: str + + class SourceUrlChunk(BaseChunk): """Source URL chunk.""" @@ -178,7 +196,9 @@ class DataChunk(BaseChunk): """Data chunk with dynamic type.""" type: Annotated[str, Field(pattern=r'^data-')] + id: str | None = None data: Any + transient: bool | None = None class StartStepChunk(BaseChunk): @@ -205,6 +225,7 @@ class FinishChunk(BaseChunk): """Finish chunk.""" type: Literal['finish'] = 'finish' + finish_reason: FinishReason = None message_metadata: Any | None = None diff --git a/tests/test_vercel_ai.py b/tests/test_vercel_ai.py index 12a4ea3eaa..dad1e9c91e 100644 --- a/tests/test_vercel_ai.py +++ b/tests/test_vercel_ai.py @@ -1039,7 +1039,7 @@ def client_response\ {'type': 'text-delta', 'delta': ' bodies safely?', 'id': IsStr()}, {'type': 'text-end', 'id': IsStr()}, {'type': 'finish-step'}, - {'type': 'finish'}, + {'type': 'finish', 'finishReason': 'stop'}, '[DONE]', ] ) @@ -1488,7 +1488,7 @@ async def stream_function( {'type': 'tool-input-available', 'toolCallId': IsStr(), 'toolName': 'unknown_tool', 'input': {}}, {'type': 'error', 'errorText': 'Exceeded maximum retries (1) for output validation'}, {'type': 'finish-step'}, - {'type': 'finish'}, + {'type': 'finish', 'finishReason': 'error'}, '[DONE]', ] ) @@ -1531,7 +1531,7 @@ async def tool(query: str) -> str: }, {'type': 'error', 'errorText': 'Unknown tool'}, {'type': 'finish-step'}, - {'type': 'finish'}, + {'type': 'finish', 'finishReason': 'error'}, '[DONE]', ] ) @@ -1572,7 +1572,7 @@ def raise_error(run_result: AgentRunResult[Any]) -> None: {'type': 'text-end', 'id': IsStr()}, {'type': 'error', 'errorText': 'Faulty on_complete'}, {'type': 'finish-step'}, - {'type': 'finish'}, + {'type': 'finish', 'finishReason': 'error'}, '[DONE]', ] ) @@ -1619,6 +1619,166 @@ async def on_complete(run_result: AgentRunResult[Any]) -> AsyncIterator[BaseChun ) +async def test_data_chunk_with_id_and_transient(): + """Test DataChunk supports optional id and transient fields for AI SDK compatibility.""" + agent = Agent(model=TestModel()) + + request = SubmitMessage( + id='foo', + messages=[ + UIMessage( + id='bar', + role='user', + parts=[TextUIPart(text='Hello')], + ), + ], + ) + + async def on_complete(run_result: AgentRunResult[Any]) -> AsyncIterator[BaseChunk]: + # Yield a data chunk with id for reconciliation + yield DataChunk(type='data-task', id='task-123', data={'status': 'complete'}) + # Yield a transient data chunk (not persisted to history) + yield DataChunk(type='data-progress', data={'percent': 100}, transient=True) + + adapter = VercelAIAdapter(agent, request) + events = [ + '[DONE]' if '[DONE]' in event else json.loads(event.removeprefix('data: ')) + async for event in adapter.encode_stream(adapter.run_stream(on_complete=on_complete)) + ] + + # Verify the data chunks are present in the events with correct fields + assert {'type': 'data-task', 'id': 'task-123', 'data': {'status': 'complete'}} in events + assert {'type': 'data-progress', 'data': {'percent': 100}, 'transient': True} in events + + +async def test_tool_approval_request_emission(): + """Test that ToolApprovalRequestChunk is emitted when tools require approval.""" + from pydantic_ai.tools import DeferredToolRequests + + async def stream_function( + messages: list[ModelMessage], agent_info: AgentInfo + ) -> AsyncIterator[DeltaToolCalls | str]: + yield { + 0: DeltaToolCall( + name='delete_file', + json_args='{"path": "test.txt"}', + tool_call_id='delete_1', + ) + } + + agent: Agent[None, str | DeferredToolRequests] = Agent( + model=FunctionModel(stream_function=stream_function), output_type=[str, DeferredToolRequests] + ) + + @agent.tool_plain(requires_approval=True) + def delete_file(path: str) -> str: + return f'Deleted {path}' + + request = SubmitMessage( + id='foo', + messages=[ + UIMessage( + id='bar', + role='user', + parts=[TextUIPart(text='Delete test.txt')], + ), + ], + ) + + adapter = VercelAIAdapter(agent, request) + events: list[str | dict[str, Any]] = [ + '[DONE]' if '[DONE]' in event else json.loads(event.removeprefix('data: ')) + async for event in adapter.encode_stream(adapter.run_stream()) + ] + + # Verify tool-approval-request chunk is emitted with UUID approval_id + approval_event: dict[str, Any] | None = next( + (e for e in events if isinstance(e, dict) and e.get('type') == 'tool-approval-request'), + None, + ) + assert approval_event is not None + assert approval_event['toolCallId'] == 'delete_1' + assert 'approvalId' in approval_event + + +def test_extract_deferred_tool_results_approved(): + """Test that approved tool calls are correctly extracted from UI messages.""" + from pydantic_ai.tools import ToolApproved + + from pydantic_ai.ui.vercel_ai.request_types import ( + DynamicToolInputAvailablePart, + ToolApprovalResponded, + ) + + messages = [ + UIMessage( + id='msg-1', + role='assistant', + parts=[ + DynamicToolInputAvailablePart( + tool_name='delete_file', + tool_call_id='delete_1', + input={'path': 'test.txt'}, + approval=ToolApprovalResponded(id='approval-123', approved=True), + ), + ], + ), + ] + + result = VercelAIAdapter.extract_deferred_tool_results(messages) + assert result is not None + assert 'delete_1' in result.approvals + assert isinstance(result.approvals['delete_1'], ToolApproved) + + +def test_extract_deferred_tool_results_denied(): + """Test that denied tool calls are correctly extracted from UI messages.""" + from pydantic_ai.tools import ToolDenied + + from pydantic_ai.ui.vercel_ai.request_types import ( + DynamicToolInputAvailablePart, + ToolApprovalResponded, + ) + + messages = [ + UIMessage( + id='msg-1', + role='assistant', + parts=[ + DynamicToolInputAvailablePart( + tool_name='delete_file', + tool_call_id='delete_1', + input={'path': 'test.txt'}, + approval=ToolApprovalResponded( + id='approval-123', approved=False, reason='User rejected deletion' + ), + ), + ], + ), + ] + + result = VercelAIAdapter.extract_deferred_tool_results(messages) + assert result is not None + assert 'delete_1' in result.approvals + denial = result.approvals['delete_1'] + assert isinstance(denial, ToolDenied) + assert denial.message == 'User rejected deletion' + + +def test_extract_deferred_tool_results_no_approvals(): + """Test that None is returned when no approval responses exist.""" + messages = [ + UIMessage( + id='msg-1', + role='user', + parts=[TextUIPart(text='Hello')], + ), + ] + + result = VercelAIAdapter.extract_deferred_tool_results(messages) + assert result is None + + @pytest.mark.skipif(not starlette_import_successful, reason='Starlette is not installed') async def test_adapter_dispatch_request(): agent = Agent(model=TestModel())