Skip to content

Commit 64a6fe4

Browse files
authored
Add AI SDK data chunk ID and tool approval types (#3760)
1 parent 4a5da9b commit 64a6fe4

File tree

3 files changed

+79
-5
lines changed

3 files changed

+79
-5
lines changed

pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_event_stream.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
BuiltinToolCallPart,
1414
BuiltinToolReturnPart,
1515
FilePart,
16+
FinishReason as PydanticFinishReason,
1617
FunctionToolResultEvent,
1718
RetryPromptPart,
1819
TextPart,
@@ -23,6 +24,7 @@
2324
ToolCallPartDelta,
2425
)
2526
from ...output import OutputDataT
27+
from ...run import AgentRunResultEvent
2628
from ...tools import AgentDepsT
2729
from .. import UIEventStream
2830
from .request_types import RequestData
@@ -32,6 +34,7 @@
3234
ErrorChunk,
3335
FileChunk,
3436
FinishChunk,
37+
FinishReason,
3538
FinishStepChunk,
3639
ReasoningDeltaChunk,
3740
ReasoningEndChunk,
@@ -48,6 +51,15 @@
4851
ToolOutputErrorChunk,
4952
)
5053

54+
# Map Pydantic AI finish reasons to Vercel AI format
55+
_FINISH_REASON_MAP: dict[PydanticFinishReason, FinishReason] = {
56+
'stop': 'stop',
57+
'length': 'length',
58+
'content_filter': 'content-filter',
59+
'tool_call': 'tool-calls',
60+
'error': 'error',
61+
}
62+
5163
__all__ = ['VercelAIEventStream']
5264

5365
# See https://ai-sdk.dev/docs/ai-sdk-ui/stream-protocol#data-stream-protocol
@@ -64,6 +76,7 @@ class VercelAIEventStream(UIEventStream[RequestData, BaseChunk, AgentDepsT, Outp
6476
"""UI event stream transformer for the Vercel AI protocol."""
6577

6678
_step_started: bool = False
79+
_finish_reason: FinishReason = None
6780

6881
@property
6982
def response_headers(self) -> Mapping[str, str] | None:
@@ -85,10 +98,18 @@ async def before_response(self) -> AsyncIterator[BaseChunk]:
8598
async def after_stream(self) -> AsyncIterator[BaseChunk]:
8699
yield FinishStepChunk()
87100

88-
yield FinishChunk()
101+
yield FinishChunk(finish_reason=self._finish_reason)
89102
yield DoneChunk()
90103

104+
async def handle_run_result(self, event: AgentRunResultEvent) -> AsyncIterator[BaseChunk]:
105+
pydantic_reason = event.result.response.finish_reason
106+
if pydantic_reason:
107+
self._finish_reason = _FINISH_REASON_MAP.get(pydantic_reason)
108+
return
109+
yield
110+
91111
async def on_error(self, error: Exception) -> AsyncIterator[BaseChunk]:
112+
self._finish_reason = 'error'
92113
yield ErrorChunk(error_text=str(error))
93114

94115
async def handle_text_start(self, part: TextPart, follows_text: bool = False) -> AsyncIterator[BaseChunk]:

pydantic_ai_slim/pydantic_ai/ui/vercel_ai/response_types.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
ProviderMetadata = dict[str, dict[str, JSONValue]]
1717
"""Provider metadata."""
1818

19+
FinishReason = Literal['stop', 'length', 'content-filter', 'tool-calls', 'error', 'other', 'unknown'] | None
20+
"""Reason why the model finished generating."""
21+
1922

2023
class BaseChunk(CamelBaseModel, ABC):
2124
"""Abstract base class for response SSE events."""
@@ -145,6 +148,21 @@ class ToolOutputErrorChunk(BaseChunk):
145148
dynamic: bool | None = None
146149

147150

151+
class ToolApprovalRequestChunk(BaseChunk):
152+
"""Tool approval request chunk for human-in-the-loop approval."""
153+
154+
type: Literal['tool-approval-request'] = 'tool-approval-request'
155+
approval_id: str
156+
tool_call_id: str
157+
158+
159+
class ToolOutputDeniedChunk(BaseChunk):
160+
"""Tool output denied chunk when user denies tool execution."""
161+
162+
type: Literal['tool-output-denied'] = 'tool-output-denied'
163+
tool_call_id: str
164+
165+
148166
class SourceUrlChunk(BaseChunk):
149167
"""Source URL chunk."""
150168

@@ -178,7 +196,9 @@ class DataChunk(BaseChunk):
178196
"""Data chunk with dynamic type."""
179197

180198
type: Annotated[str, Field(pattern=r'^data-')]
199+
id: str | None = None
181200
data: Any
201+
transient: bool | None = None
182202

183203

184204
class StartStepChunk(BaseChunk):
@@ -205,6 +225,7 @@ class FinishChunk(BaseChunk):
205225
"""Finish chunk."""
206226

207227
type: Literal['finish'] = 'finish'
228+
finish_reason: FinishReason = None
208229
message_metadata: Any | None = None
209230

210231

tests/test_vercel_ai.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1039,7 +1039,7 @@ def client_response\
10391039
{'type': 'text-delta', 'delta': ' bodies safely?', 'id': IsStr()},
10401040
{'type': 'text-end', 'id': IsStr()},
10411041
{'type': 'finish-step'},
1042-
{'type': 'finish'},
1042+
{'type': 'finish', 'finishReason': 'stop'},
10431043
'[DONE]',
10441044
]
10451045
)
@@ -1488,7 +1488,7 @@ async def stream_function(
14881488
{'type': 'tool-input-available', 'toolCallId': IsStr(), 'toolName': 'unknown_tool', 'input': {}},
14891489
{'type': 'error', 'errorText': 'Exceeded maximum retries (1) for output validation'},
14901490
{'type': 'finish-step'},
1491-
{'type': 'finish'},
1491+
{'type': 'finish', 'finishReason': 'error'},
14921492
'[DONE]',
14931493
]
14941494
)
@@ -1531,7 +1531,7 @@ async def tool(query: str) -> str:
15311531
},
15321532
{'type': 'error', 'errorText': 'Unknown tool'},
15331533
{'type': 'finish-step'},
1534-
{'type': 'finish'},
1534+
{'type': 'finish', 'finishReason': 'error'},
15351535
'[DONE]',
15361536
]
15371537
)
@@ -1572,7 +1572,7 @@ def raise_error(run_result: AgentRunResult[Any]) -> None:
15721572
{'type': 'text-end', 'id': IsStr()},
15731573
{'type': 'error', 'errorText': 'Faulty on_complete'},
15741574
{'type': 'finish-step'},
1575-
{'type': 'finish'},
1575+
{'type': 'finish', 'finishReason': 'error'},
15761576
'[DONE]',
15771577
]
15781578
)
@@ -1619,6 +1619,38 @@ async def on_complete(run_result: AgentRunResult[Any]) -> AsyncIterator[BaseChun
16191619
)
16201620

16211621

1622+
async def test_data_chunk_with_id_and_transient():
1623+
"""Test DataChunk supports optional id and transient fields for AI SDK compatibility."""
1624+
agent = Agent(model=TestModel())
1625+
1626+
request = SubmitMessage(
1627+
id='foo',
1628+
messages=[
1629+
UIMessage(
1630+
id='bar',
1631+
role='user',
1632+
parts=[TextUIPart(text='Hello')],
1633+
),
1634+
],
1635+
)
1636+
1637+
async def on_complete(run_result: AgentRunResult[Any]) -> AsyncIterator[BaseChunk]:
1638+
# Yield a data chunk with id for reconciliation
1639+
yield DataChunk(type='data-task', id='task-123', data={'status': 'complete'})
1640+
# Yield a transient data chunk (not persisted to history)
1641+
yield DataChunk(type='data-progress', data={'percent': 100}, transient=True)
1642+
1643+
adapter = VercelAIAdapter(agent, request)
1644+
events = [
1645+
'[DONE]' if '[DONE]' in event else json.loads(event.removeprefix('data: '))
1646+
async for event in adapter.encode_stream(adapter.run_stream(on_complete=on_complete))
1647+
]
1648+
1649+
# Verify the data chunks are present in the events with correct fields
1650+
assert {'type': 'data-task', 'id': 'task-123', 'data': {'status': 'complete'}} in events
1651+
assert {'type': 'data-progress', 'data': {'percent': 100}, 'transient': True} in events
1652+
1653+
16221654
@pytest.mark.skipif(not starlette_import_successful, reason='Starlette is not installed')
16231655
async def test_adapter_dispatch_request():
16241656
agent = Agent(model=TestModel())

0 commit comments

Comments
 (0)