Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 52 additions & 1 deletion pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
VideoUrl,
)
from ...output import OutputDataT
from ...tools import AgentDepsT
from ...tools import AgentDepsT, DeferredToolResults, ToolApproved, ToolDenied
from .. import MessagesBuilder, UIAdapter, UIEventStream
from ._event_stream import VercelAIEventStream
from .request_types import (
Expand All @@ -51,6 +51,7 @@
SourceUrlUIPart,
StepStartUIPart,
TextUIPart,
ToolApprovalResponded,
ToolInputAvailablePart,
ToolOutputAvailablePart,
ToolOutputErrorPart,
Expand Down Expand Up @@ -87,6 +88,56 @@ def messages(self) -> list[ModelMessage]:
"""Pydantic AI messages from the Vercel AI run input."""
return self.load_messages(self.run_input.messages)

@cached_property
def deferred_tool_results(self) -> DeferredToolResults | None:
"""Extract deferred tool results from tool parts with approval responses.

When the Vercel AI SDK client responds to a tool-approval-request, it sends
the approval decision in the tool part's `approval` field. This method extracts
those responses and converts them to Pydantic AI's `DeferredToolResults` format.

Returns:
DeferredToolResults if any tool parts have approval responses, None otherwise.
"""
return self.extract_deferred_tool_results(self.run_input.messages)

@classmethod
def extract_deferred_tool_results(cls, messages: Sequence[UIMessage]) -> DeferredToolResults | None:
"""Extract deferred tool results from UI messages.

Args:
messages: The UI messages to scan for approval responses.

Returns:
DeferredToolResults if any tool parts have approval responses, None otherwise.
"""
approvals: dict[str, bool | ToolApproved | ToolDenied] = {}

for msg in messages:
if msg.role != 'assistant':
continue

for part in msg.parts:
if not isinstance(part, ToolUIPart | DynamicToolUIPart):
continue

approval = part.approval
if approval is None or not isinstance(approval, ToolApprovalResponded):
continue

tool_call_id = part.tool_call_id
if approval.approved:
approvals[tool_call_id] = ToolApproved()
else:
approvals[tool_call_id] = ToolDenied(
message=approval.reason or 'The tool call was denied.'
)

if not approvals:
return None

return DeferredToolResults(approvals=approvals)

@classmethod
def load_messages(cls, messages: Sequence[UIMessage]) -> list[ModelMessage]: # noqa: C901
"""Transform Vercel AI messages into Pydantic AI messages."""
Expand Down
34 changes: 32 additions & 2 deletions pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_event_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from collections.abc import AsyncIterator, Mapping
from dataclasses import dataclass
from typing import Any
from uuid import uuid4

from pydantic_core import to_json

Expand All @@ -13,6 +14,7 @@
BuiltinToolCallPart,
BuiltinToolReturnPart,
FilePart,
FinishReason as PydanticFinishReason,
FunctionToolResultEvent,
RetryPromptPart,
TextPart,
Expand All @@ -23,7 +25,8 @@
ToolCallPartDelta,
)
from ...output import OutputDataT
from ...tools import AgentDepsT
from ...run import AgentRunResultEvent
from ...tools import AgentDepsT, DeferredToolRequests
from .. import UIEventStream
from .request_types import RequestData
from .response_types import (
Expand All @@ -32,6 +35,7 @@
ErrorChunk,
FileChunk,
FinishChunk,
FinishReason,
FinishStepChunk,
ReasoningDeltaChunk,
ReasoningEndChunk,
Expand All @@ -41,13 +45,23 @@
TextDeltaChunk,
TextEndChunk,
TextStartChunk,
ToolApprovalRequestChunk,
ToolInputAvailableChunk,
ToolInputDeltaChunk,
ToolInputStartChunk,
ToolOutputAvailableChunk,
ToolOutputErrorChunk,
)

# Map Pydantic AI finish reasons to Vercel AI format
_FINISH_REASON_MAP: dict[PydanticFinishReason, FinishReason] = {
'stop': 'stop',
'length': 'length',
'content_filter': 'content-filter',
'tool_call': 'tool-calls',
'error': 'error',
}

__all__ = ['VercelAIEventStream']

# See https://ai-sdk.dev/docs/ai-sdk-ui/stream-protocol#data-stream-protocol
Expand All @@ -64,6 +78,7 @@ class VercelAIEventStream(UIEventStream[RequestData, BaseChunk, AgentDepsT, Outp
"""UI event stream transformer for the Vercel AI protocol."""

_step_started: bool = False
_finish_reason: FinishReason = None

@property
def response_headers(self) -> Mapping[str, str] | None:
Expand All @@ -85,10 +100,25 @@ async def before_response(self) -> AsyncIterator[BaseChunk]:
async def after_stream(self) -> AsyncIterator[BaseChunk]:
yield FinishStepChunk()

yield FinishChunk()
yield FinishChunk(finish_reason=self._finish_reason)
yield DoneChunk()

async def handle_run_result(self, event: AgentRunResultEvent) -> AsyncIterator[BaseChunk]:
pydantic_reason = event.result.response.finish_reason
if pydantic_reason:
self._finish_reason = _FINISH_REASON_MAP.get(pydantic_reason)

# Emit tool approval requests for deferred approvals
output = event.result.output
if isinstance(output, DeferredToolRequests):
for tool_call in output.approvals:
yield ToolApprovalRequestChunk(
approval_id=str(uuid4()),
tool_call_id=tool_call.tool_call_id,
)

async def on_error(self, error: Exception) -> AsyncIterator[BaseChunk]:
self._finish_reason = 'error'
yield ErrorChunk(error_text=str(error))

async def handle_text_start(self, part: TextPart, follows_text: bool = False) -> AsyncIterator[BaseChunk]:
Expand Down
32 changes: 32 additions & 0 deletions pydantic_ai_slim/pydantic_ai/ui/vercel_ai/request_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,30 @@ class DataUIPart(BaseUIPart):
data: Any


class ToolApprovalRequested(CamelBaseModel):
"""Tool approval in requested state (awaiting user response)."""

id: str
"""The approval request ID."""


class ToolApprovalResponded(CamelBaseModel):
"""Tool approval in responded state (user has approved or denied)."""

id: str
"""The approval request ID."""

approved: bool
"""Whether the user approved the tool call."""

reason: str | None = None
"""Optional reason for the approval or denial."""


ToolApproval = ToolApprovalRequested | ToolApprovalResponded
"""Union of tool approval states."""


# Tool part states as separate models
class ToolInputStreamingPart(BaseUIPart):
"""Tool part in input-streaming state."""
Expand All @@ -119,6 +143,7 @@ class ToolInputStreamingPart(BaseUIPart):
state: Literal['input-streaming'] = 'input-streaming'
input: Any | None = None
provider_executed: bool | None = None
approval: ToolApproval | None = None


class ToolInputAvailablePart(BaseUIPart):
Expand All @@ -130,6 +155,7 @@ class ToolInputAvailablePart(BaseUIPart):
input: Any | None = None
provider_executed: bool | None = None
call_provider_metadata: ProviderMetadata | None = None
approval: ToolApproval | None = None


class ToolOutputAvailablePart(BaseUIPart):
Expand All @@ -143,6 +169,7 @@ class ToolOutputAvailablePart(BaseUIPart):
provider_executed: bool | None = None
call_provider_metadata: ProviderMetadata | None = None
preliminary: bool | None = None
approval: ToolApproval | None = None


class ToolOutputErrorPart(BaseUIPart):
Expand All @@ -156,6 +183,7 @@ class ToolOutputErrorPart(BaseUIPart):
error_text: str
provider_executed: bool | None = None
call_provider_metadata: ProviderMetadata | None = None
approval: ToolApproval | None = None


ToolUIPart = ToolInputStreamingPart | ToolInputAvailablePart | ToolOutputAvailablePart | ToolOutputErrorPart
Expand All @@ -171,6 +199,7 @@ class DynamicToolInputStreamingPart(BaseUIPart):
tool_call_id: str
state: Literal['input-streaming'] = 'input-streaming'
input: Any | None = None
approval: ToolApproval | None = None


class DynamicToolInputAvailablePart(BaseUIPart):
Expand All @@ -182,6 +211,7 @@ class DynamicToolInputAvailablePart(BaseUIPart):
state: Literal['input-available'] = 'input-available'
input: Any
call_provider_metadata: ProviderMetadata | None = None
approval: ToolApproval | None = None


class DynamicToolOutputAvailablePart(BaseUIPart):
Expand All @@ -195,6 +225,7 @@ class DynamicToolOutputAvailablePart(BaseUIPart):
output: Any
call_provider_metadata: ProviderMetadata | None = None
preliminary: bool | None = None
approval: ToolApproval | None = None


class DynamicToolOutputErrorPart(BaseUIPart):
Expand All @@ -207,6 +238,7 @@ class DynamicToolOutputErrorPart(BaseUIPart):
input: Any
error_text: str
call_provider_metadata: ProviderMetadata | None = None
approval: ToolApproval | None = None


DynamicToolUIPart = (
Expand Down
21 changes: 21 additions & 0 deletions pydantic_ai_slim/pydantic_ai/ui/vercel_ai/response_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
ProviderMetadata = dict[str, dict[str, JSONValue]]
"""Provider metadata."""

FinishReason = Literal['stop', 'length', 'content-filter', 'tool-calls', 'error', 'other', 'unknown'] | None
"""Reason why the model finished generating."""


class BaseChunk(CamelBaseModel, ABC):
"""Abstract base class for response SSE events."""
Expand Down Expand Up @@ -145,6 +148,21 @@ class ToolOutputErrorChunk(BaseChunk):
dynamic: bool | None = None


class ToolApprovalRequestChunk(BaseChunk):
"""Tool approval request chunk for human-in-the-loop approval."""

type: Literal['tool-approval-request'] = 'tool-approval-request'
approval_id: str
tool_call_id: str


class ToolOutputDeniedChunk(BaseChunk):
"""Tool output denied chunk when user denies tool execution."""

type: Literal['tool-output-denied'] = 'tool-output-denied'
tool_call_id: str


class SourceUrlChunk(BaseChunk):
"""Source URL chunk."""

Expand Down Expand Up @@ -178,7 +196,9 @@ class DataChunk(BaseChunk):
"""Data chunk with dynamic type."""

type: Annotated[str, Field(pattern=r'^data-')]
id: str | None = None
data: Any
transient: bool | None = None


class StartStepChunk(BaseChunk):
Expand All @@ -205,6 +225,7 @@ class FinishChunk(BaseChunk):
"""Finish chunk."""

type: Literal['finish'] = 'finish'
finish_reason: FinishReason = None
message_metadata: Any | None = None


Expand Down
Loading
Loading