diff --git a/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_event_stream.py b/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_event_stream.py
index ca94c5c186..070166df98 100644
--- a/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_event_stream.py
+++ b/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_event_stream.py
@@ -13,6 +13,7 @@
     BuiltinToolCallPart,
    BuiltinToolReturnPart,
     FilePart,
+    FinishReason as PydanticFinishReason,
     FunctionToolResultEvent,
     RetryPromptPart,
     TextPart,
@@ -23,6 +24,7 @@
     ToolCallPartDelta,
 )
 from ...output import OutputDataT
+from ...run import AgentRunResultEvent
 from ...tools import AgentDepsT
 from .. import UIEventStream
 from .request_types import RequestData
@@ -32,6 +34,7 @@
     ErrorChunk,
     FileChunk,
     FinishChunk,
+    FinishReason,
     FinishStepChunk,
     ReasoningDeltaChunk,
     ReasoningEndChunk,
@@ -48,6 +51,15 @@
     ToolOutputErrorChunk,
 )
 
+# Map Pydantic AI finish reasons to Vercel AI format
+_FINISH_REASON_MAP: dict[PydanticFinishReason, FinishReason] = {
+    'stop': 'stop',
+    'length': 'length',
+    'content_filter': 'content-filter',
+    'tool_call': 'tool-calls',
+    'error': 'error',
+}
+
 __all__ = ['VercelAIEventStream']
 
 # See https://ai-sdk.dev/docs/ai-sdk-ui/stream-protocol#data-stream-protocol
@@ -64,6 +76,7 @@ class VercelAIEventStream(UIEventStream[RequestData, BaseChunk, AgentDepsT, Outp
     """UI event stream transformer for the Vercel AI protocol."""
 
     _step_started: bool = False
+    _finish_reason: FinishReason = None
 
     @property
     def response_headers(self) -> Mapping[str, str] | None:
@@ -85,10 +98,18 @@ async def before_response(self) -> AsyncIterator[BaseChunk]:
 
     async def after_stream(self) -> AsyncIterator[BaseChunk]:
         yield FinishStepChunk()
-        yield FinishChunk()
+        yield FinishChunk(finish_reason=self._finish_reason)
         yield DoneChunk()
 
+    async def handle_run_result(self, event: AgentRunResultEvent) -> AsyncIterator[BaseChunk]:
+        pydantic_reason = event.result.response.finish_reason
+        if pydantic_reason:
+            self._finish_reason = _FINISH_REASON_MAP.get(pydantic_reason)
+        return
+        yield
+
     async def on_error(self, error: Exception) -> AsyncIterator[BaseChunk]:
+        self._finish_reason = 'error'
         yield ErrorChunk(error_text=str(error))
 
     async def handle_text_start(self, part: TextPart, follows_text: bool = False) -> AsyncIterator[BaseChunk]:
diff --git a/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/response_types.py b/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/response_types.py
index 1255503107..6a7b98a2dc 100644
--- a/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/response_types.py
+++ b/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/response_types.py
@@ -16,6 +16,9 @@
 ProviderMetadata = dict[str, dict[str, JSONValue]]
 """Provider metadata."""
 
+FinishReason = Literal['stop', 'length', 'content-filter', 'tool-calls', 'error', 'other', 'unknown'] | None
+"""Reason why the model finished generating."""
+
 
 class BaseChunk(CamelBaseModel, ABC):
     """Abstract base class for response SSE events."""
@@ -145,6 +148,21 @@ class ToolOutputErrorChunk(BaseChunk):
     dynamic: bool | None = None
 
 
+class ToolApprovalRequestChunk(BaseChunk):
+    """Tool approval request chunk for human-in-the-loop approval."""
+
+    type: Literal['tool-approval-request'] = 'tool-approval-request'
+    approval_id: str
+    tool_call_id: str
+
+
+class ToolOutputDeniedChunk(BaseChunk):
+    """Tool output denied chunk when user denies tool execution."""
+
+    type: Literal['tool-output-denied'] = 'tool-output-denied'
+    tool_call_id: str
+
+
 class SourceUrlChunk(BaseChunk):
     """Source URL chunk."""
 
@@ -178,7 +196,9 @@ class DataChunk(BaseChunk):
     """Data chunk with dynamic type."""
 
     type: Annotated[str, Field(pattern=r'^data-')]
+    id: str | None = None
     data: Any
+    transient: bool | None = None
 
 
 class StartStepChunk(BaseChunk):
@@ -205,6 +225,7 @@ class FinishChunk(BaseChunk):
     """Finish chunk."""
 
     type: Literal['finish'] = 'finish'
+    finish_reason: FinishReason = None
     message_metadata: Any | None = None
 
 
diff --git a/tests/test_vercel_ai.py b/tests/test_vercel_ai.py
index 12a4ea3eaa..0e191d445d 100644
--- a/tests/test_vercel_ai.py
+++ b/tests/test_vercel_ai.py
@@ -1039,7 +1039,7 @@ def client_response\
             {'type': 'text-delta', 'delta': ' bodies safely?', 'id': IsStr()},
             {'type': 'text-end', 'id': IsStr()},
             {'type': 'finish-step'},
-            {'type': 'finish'},
+            {'type': 'finish', 'finishReason': 'stop'},
             '[DONE]',
         ]
     )
@@ -1488,7 +1488,7 @@ async def stream_function(
             {'type': 'tool-input-available', 'toolCallId': IsStr(), 'toolName': 'unknown_tool', 'input': {}},
             {'type': 'error', 'errorText': 'Exceeded maximum retries (1) for output validation'},
             {'type': 'finish-step'},
-            {'type': 'finish'},
+            {'type': 'finish', 'finishReason': 'error'},
             '[DONE]',
         ]
     )
@@ -1531,7 +1531,7 @@ async def tool(query: str) -> str:
             },
             {'type': 'error', 'errorText': 'Unknown tool'},
             {'type': 'finish-step'},
-            {'type': 'finish'},
+            {'type': 'finish', 'finishReason': 'error'},
             '[DONE]',
         ]
     )
@@ -1572,7 +1572,7 @@ def raise_error(run_result: AgentRunResult[Any]) -> None:
             {'type': 'text-end', 'id': IsStr()},
             {'type': 'error', 'errorText': 'Faulty on_complete'},
             {'type': 'finish-step'},
-            {'type': 'finish'},
+            {'type': 'finish', 'finishReason': 'error'},
             '[DONE]',
         ]
     )
@@ -1619,6 +1619,38 @@ async def on_complete(run_result: AgentRunResult[Any]) -> AsyncIterator[BaseChun
     )
 
 
+async def test_data_chunk_with_id_and_transient():
+    """Test DataChunk supports optional id and transient fields for AI SDK compatibility."""
+    agent = Agent(model=TestModel())
+
+    request = SubmitMessage(
+        id='foo',
+        messages=[
+            UIMessage(
+                id='bar',
+                role='user',
+                parts=[TextUIPart(text='Hello')],
+            ),
+        ],
+    )
+
+    async def on_complete(run_result: AgentRunResult[Any]) -> AsyncIterator[BaseChunk]:
+        # Yield a data chunk with id for reconciliation
+        yield DataChunk(type='data-task', id='task-123', data={'status': 'complete'})
+        # Yield a transient data chunk (not persisted to history)
+        yield DataChunk(type='data-progress', data={'percent': 100}, transient=True)
+
+    adapter = VercelAIAdapter(agent, request)
+    events = [
+        '[DONE]' if '[DONE]' in event else json.loads(event.removeprefix('data: '))
+        async for event in adapter.encode_stream(adapter.run_stream(on_complete=on_complete))
+    ]
+
+    # Verify the data chunks are present in the events with correct fields
+    assert {'type': 'data-task', 'id': 'task-123', 'data': {'status': 'complete'}} in events
+    assert {'type': 'data-progress', 'data': {'percent': 100}, 'transient': True} in events
+
+
 @pytest.mark.skipif(not starlette_import_successful, reason='Starlette is not installed')
 async def test_adapter_dispatch_request():
     agent = Agent(model=TestModel())