-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Add docs and tests for RunContext.partial_output in output tools
#3726
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
6d19cf5
e68a9de
e2c0f23
d08191e
a8d51f2
4e69d4d
f368bfc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -258,6 +258,90 @@ print(result.output) | |
|
|
||
| _(This example is complete, it can be run "as is")_ | ||
|
|
||
| #### Handling partial output in output functions | ||
|
|
||
| !!! warning "Output functions are called multiple times during streaming" | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd prefer for this to be top-level text rather than the entire docs section to be an indented warning block |
||
| When using streaming mode (`run_stream()`), output functions are called **multiple times** — once for each partial output received from the model, and once for the final complete output. | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should say this in the "Handling partial output in output validators" section as well. Maybe the 2 sections can be merged as output validators and functions are very similar? |
||
|
|
||
| For output functions with **side effects** (e.g., sending notifications, logging, database updates), you should check the [`RunContext.partial_output`][pydantic_ai.tools.RunContext.partial_output] flag to avoid executing side effects on partial data. | ||
|
|
||
| **How `partial_output` works:** | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No "fake" subheadings please, I'd rather have paragraphs |
||
|
|
||
| - **In sync mode** (`run_sync()`): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As I mentioned on the other PR, it's just specific to sync mode or
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd rather this be a paragraph than a list btw |
||
| - `partial_output=False` always (function called once) | ||
| - **In streaming mode** (`run_stream()`): | ||
| - `partial_output=True` for each partial call | ||
| - `partial_output=False` for the final complete call | ||
|
|
||
| **Example with side effects:** | ||
|
|
||
| ```python {title="output_function_with_side_effects.py"} | ||
| from pydantic import BaseModel | ||
|
|
||
| from pydantic_ai import Agent, RunContext | ||
|
|
||
|
|
||
| class DatabaseRecord(BaseModel): | ||
| name: str | ||
| value: int | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is actually not a great example because both of these fields are required, meaning the partial data wouldn't validate anyway until the output is complete, so there would not be any partial output in this case. |
||
|
|
||
|
|
||
| def save_to_database(ctx: RunContext, record: DatabaseRecord) -> DatabaseRecord: | ||
| """Output function with side effect - only save final output to database.""" | ||
| if ctx.partial_output: | ||
| # Skip side effects for partial outputs | ||
| return record | ||
|
|
||
| # Only execute side effect for the final output | ||
| print(f'Saving to database: {record.name} = {record.value}') | ||
| #> Saving to database: test = 42 | ||
| return record | ||
|
|
||
|
|
||
| agent = Agent('openai:gpt-5', output_type=save_to_database) | ||
|
|
||
| result = agent.run_sync('Create a record with name "test" and value 42') | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We're not using This example needs some work to actually show the problem + use case |
||
| print(result.output) | ||
| #> name='test' value=42 | ||
| ``` | ||
|
|
||
| _(This example is complete, it can be run "as is")_ | ||
|
|
||
| **Example without side effects (transformation only):** | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We don't need this example as it's the "base case" already covered above |
||
|
|
||
| ```python {title="output_function_transformation.py"} | ||
| from pydantic import BaseModel | ||
|
|
||
| from pydantic_ai import Agent | ||
|
|
||
|
|
||
| class UserData(BaseModel): | ||
| username: str | ||
| email: str | ||
|
|
||
|
|
||
| def normalize_user_data(user: UserData) -> UserData: | ||
| """Output function without side effects - safe to call multiple times.""" | ||
| # Pure transformation is safe for multiple calls | ||
| user.username = user.username.lower() | ||
| user.email = user.email.lower() | ||
| return user | ||
|
|
||
|
|
||
| agent = Agent('openai:gpt-5', output_type=normalize_user_data) | ||
|
|
||
| result = agent.run_sync('Create user with username "JohnDoe" and email "[email protected]"') | ||
| print(result.output) | ||
| #> username='johndoe' email='[email protected]' | ||
| ``` | ||
|
|
||
| _(This example is complete, it can be run "as is")_ | ||
|
|
||
| **Best practices:** | ||
|
|
||
| - If your output function **has** side effects (database writes, API calls, notifications) → use `ctx.partial_output` to guard them | ||
| - If your output function only **transforms** data (formatting, validation, normalization) → no need to check the flag | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is duplicative with |
||
|
|
||
| ### Output modes | ||
|
|
||
| Pydantic AI implements three different methods to get a model to output structured data: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -540,6 +540,16 @@ async def call_tool( | |
| 'What do I have on my calendar today?': "You're going to spend all day playing with Pydantic AI.", | ||
| 'Write a long story about a cat': 'Once upon a time, there was a curious cat named Whiskers who loved to explore the world around him...', | ||
| 'What is the first sentence on https://ai.pydantic.dev?': 'Pydantic AI is a Python agent framework designed to make it less painful to build production grade applications with Generative AI.', | ||
| 'Create a record with name "test" and value 42': ToolCallPart( | ||
| tool_name='final_result', | ||
| args={'name': 'test', 'value': 42}, | ||
| tool_call_id='pyd_ai_tool_call_id', | ||
| ), | ||
| 'Create user with username "JohnDoe" and email "[email protected]"': ToolCallPart( | ||
| tool_name='final_result', | ||
| args={'username': 'JohnDoe', 'email': '[email protected]'}, | ||
| tool_call_id='pyd_ai_tool_call_id', | ||
| ), | ||
| } | ||
|
|
||
| tool_responses: dict[tuple[str, str], str] = { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -56,6 +56,11 @@ | |
| pytestmark = pytest.mark.anyio | ||
|
|
||
|
|
||
| class Foo(BaseModel): | ||
| a: int | ||
| b: str | ||
|
|
||
|
|
||
| async def test_streamed_text_response(): | ||
| m = TestModel() | ||
|
|
||
|
|
@@ -747,6 +752,100 @@ async def ret_a(x: str) -> str: # pragma: no cover | |
| ) | ||
|
|
||
|
|
||
| class TestPartialOutput: | ||
| """Tests for `ctx.partial_output` flag in output validators and output functions.""" | ||
|
|
||
| # NOTE: When changing these tests: | ||
| # 1. Follow the existing order | ||
| # 2. Update tests in `tests/test_agent.py::TestPartialOutput` as well | ||
|
|
||
| async def test_output_validator_text(self): | ||
| """Test that output validators receive correct value for `partial_output` with text output.""" | ||
| call_log: list[tuple[str, bool]] = [] | ||
|
|
||
| async def sf(_: list[ModelMessage], info: AgentInfo) -> AsyncIterator[str]: | ||
| for chunk in ['Hello', ' ', 'world', '!']: | ||
| yield chunk | ||
|
|
||
| agent = Agent(FunctionModel(stream_function=sf)) | ||
|
|
||
| @agent.output_validator | ||
| def validate_output(ctx: RunContext[None], output: str) -> str: | ||
| call_log.append((output, ctx.partial_output)) | ||
| return output | ||
|
|
||
| async with agent.run_stream('test') as result: | ||
| text_parts = [text_part async for text_part in result.stream_text(debounce_by=None)] | ||
|
|
||
| assert text_parts[-1] == 'Hello world!' | ||
| assert call_log == snapshot( | ||
| [ | ||
| ('Hello', True), | ||
| ('Hello ', True), | ||
| ('Hello world', True), | ||
| ('Hello world!', True), | ||
| ('Hello world!', False), | ||
| ] | ||
| ) | ||
|
|
||
| async def test_output_validator_structured(self): | ||
| """Test that output validators receive correct value for `partial_output` with structured output.""" | ||
| call_log: list[tuple[Foo, bool]] = [] | ||
|
|
||
| async def sf(_: list[ModelMessage], info: AgentInfo) -> AsyncIterator[DeltaToolCalls]: | ||
| assert info.output_tools is not None | ||
| yield {0: DeltaToolCall(name=info.output_tools[0].name, json_args='{"a": 42')} | ||
| yield {0: DeltaToolCall(json_args=', "b": "f')} | ||
| yield {0: DeltaToolCall(json_args='oo"}')} | ||
|
|
||
| agent = Agent(FunctionModel(stream_function=sf), output_type=Foo) | ||
|
|
||
| @agent.output_validator | ||
| def validate_output(ctx: RunContext[None], output: Foo) -> Foo: | ||
| call_log.append((output, ctx.partial_output)) | ||
| return output | ||
|
|
||
| async with agent.run_stream('test') as result: | ||
| outputs = [output async for output in result.stream_output(debounce_by=None)] | ||
|
|
||
| assert outputs[-1] == Foo(a=42, b='foo') | ||
| assert call_log == snapshot( | ||
| [ | ||
| (Foo(a=42, b='f'), True), | ||
| (Foo(a=42, b='foo'), True), | ||
| (Foo(a=42, b='foo'), False), | ||
| ] | ||
| ) | ||
|
|
||
| async def test_output_function_structured(self): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's add that |
||
| """Test that output functions receive correct value for `partial_output` with structured output.""" | ||
| call_log: list[tuple[Foo, bool]] = [] | ||
|
|
||
| def process_foo(ctx: RunContext[None], foo: Foo) -> Foo: | ||
| call_log.append((foo, ctx.partial_output)) | ||
| return Foo(a=foo.a * 2, b=foo.b.upper()) | ||
|
|
||
| async def sf(_: list[ModelMessage], info: AgentInfo) -> AsyncIterator[DeltaToolCalls]: | ||
| assert info.output_tools is not None | ||
| yield {0: DeltaToolCall(name=info.output_tools[0].name, json_args='{"a": 21')} | ||
| yield {0: DeltaToolCall(json_args=', "b": "f')} | ||
| yield {0: DeltaToolCall(json_args='oo"}')} | ||
|
|
||
| agent = Agent(FunctionModel(stream_function=sf), output_type=process_foo) | ||
|
|
||
| async with agent.run_stream('test') as result: | ||
| outputs = [output async for output in result.stream_output(debounce_by=None)] | ||
|
|
||
| assert outputs[-1] == Foo(a=42, b='FOO') | ||
| assert call_log == snapshot( | ||
| [ | ||
| (Foo(a=21, b='f'), True), | ||
| (Foo(a=21, b='foo'), True), | ||
| (Foo(a=21, b='foo'), False), | ||
| ] | ||
| ) | ||
|
|
||
|
|
||
| class OutputType(BaseModel): | ||
| """Result type used by multiple tests.""" | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's rename this
Handling partial outputand also dropin output validatorsfrom the other section as it's already under "Output validators"