-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Allow typed RunContext[Deps] in TextOutput signature
#3732
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
samuelcolvin
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Docs Preview
|
| 'TextOutputFunc', | ||
| Callable[[RunContext, str], Awaitable[T_co] | T_co] | Callable[[str], Awaitable[T_co] | T_co], | ||
| type_params=(T_co,), | ||
| Callable[[RunContext[AgentDepsT], str], Awaitable[T_co] | T_co] | Callable[[str], Awaitable[T_co] | T_co], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See #3319 and #3721; we shouldn't be reusing generic type vars across different files, but instead need a new one here.
And we should think about covariant vs contravariant etc.
The easiest way to ensure it works as expected is to add some cases that should be valid to typed_agent.py, which is automatically type checked
tests/test_agent.py
Outdated
| ) | ||
|
|
||
|
|
||
| def test_output_type_text_output_function_with_deps(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm pretty sure we've already tested somewhere that it works, this PR is just for fixing the type checking, so let's remove the new tests and add some cases to typed_agent.py instead.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can remove this test
| If you provide an output function that takes a string, Pydantic AI will by default create an output tool like for any other output function. If instead you'd like the model to provide the string using plain text output, you can wrap the function in the [`TextOutput`][pydantic_ai.output.TextOutput] marker class. If desired, this marker class can be used alongside one or more [`ToolOutput`](#tool-output) marker classes (or unmarked types or functions) in a list provided to `output_type`. | ||
| If you provide an output function that takes a string, Pydantic AI will by default create an output tool like for any other output function. If instead you'd like the model to provide the string using plain text output, you can wrap the function in the [`TextOutput`][pydantic_ai.output.TextOutput] marker class. | ||
|
|
||
| ```python {title="text_output_function.py"} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry I didn't mean to remove the example entirely; just the extra RunContext-based on you'd added 😄
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lmao I was like "okay..."
tests/test_examples.py
Outdated
| 'What is the capital of the UK?': 'The capital of the UK is London.', | ||
| 'What is the capital of Mexico?': 'The capital of Mexico is Mexico City.', | ||
| 'Who was Albert Einstein?': 'Albert Einstein was a German-born theoretical physicist.', | ||
| 'Hello world': 'Hello world', |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not needed anymore
tests/typed_agent.py
Outdated
| Agent('test', tools=[Tool(foobar_plain)], deps_type=int) # pyright: ignore[reportArgumentType,reportCallIssue] | ||
|
|
||
| # TextOutput with RunContext | ||
| Agent('test', output_type=TextOutput(str_to_regex_with_ctx), deps_type=int) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To verify that contravariant=True is the right option, we should also have an example where the deps_type is a subclass of the type on the function's ctx, e.g. bool. That should work, because a function that takes an int is also able to handle a bool (which is essentially 1 or 0.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we move TextOutput(str_to_regex_with_ctx) to a variable and add an assert_type so we can verify its type is inferred correctly?
tests/typed_agent.py
Outdated
| # TextOutput with RunContext | ||
| Agent('test', output_type=TextOutput(str_to_regex_with_ctx), deps_type=int) | ||
| Agent('test', output_type=TextOutput(str_to_regex_with_ctx), deps_type=str) # pyright: ignore[reportArgumentType,reportCallIssue] | ||
| Agent('test', output_type=TextOutput(str_to_regex_with_ctx)) # pyright: ignore[reportArgumentType,reportCallIssue] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can then reuse the var here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In _OutputSpecItem, we fill in the generic param as Any, so I'm curious if that could cause any issues here because maybe it'll accept any type instead of just same one as deps_type? It'd be interesting to see the typing errors you got here before you add the pyright: ignores. (just pasting them into a GH comment here is fine)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no type errors, I added a NOTE for it
docs/output.md
Outdated
|
|
||
| _(This example is complete, it can be run "as is")_ | ||
|
|
||
| Like other output functions, text output functions can optionally take [`RunContext`][pydantic_ai.tools.RunContext] as the first argument, and can raise [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] to ask the model to try again with modified arguments (or with a different output type). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Above the example please :) I prefer to explain, then show
tests/test_agent.py
Outdated
| ) | ||
|
|
||
|
|
||
| def test_output_type_text_output_function_with_deps(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can remove this test
tests/typed_agent.py
Outdated
| Agent('test', output_type=text_output_with_ctx, deps_type=int) | ||
| Agent('test', output_type=text_output_with_ctx, deps_type=bool) # bool is subclass of int, works with contravariant | ||
| # NOTE: The following don't produce type errors because _OutputSpecItem uses TextOutput[T_co, Any] | ||
| # which erases the deps type constraint. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we can't make those give a typing error, then maybe it's not worth adding TextOutputAgentDepsT and we can just make it RunContext[Any].
Ideally, we would go a step further and add DepsT to _OutputSpecItem, OutputSpec etc, so that the user is warned if their ctx: RunContext[Foo] arg doesn't match the agent's deps_type, which would work for TextOutput(func), but unfortunately there's now way to make that work for output functions in general, e.g. output_type=func or output_type=ToolOutput(func), as those functions that arbitrary args, so if the first arg ctx: RunContext[Foo] while deps_type=Bar, the type checker will just treat it like an arbitrary arg, instead of complaining. (ping me if that doesn't make sense; typing is tricky)
So I think the best we can do is just make it RunContext[Any] and accept that this part isn't type checked.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I was able to enforce this
Co-authored-by: Douwe Maan <[email protected]>
| 'TextOutputFunc', | ||
| Callable[[RunContext, str], Awaitable[T_co] | T_co] | Callable[[str], Awaitable[T_co] | T_co], | ||
| type_params=(T_co,), | ||
| Callable[[RunContext[TextOutputAgentDepsT], str], Awaitable[T_co] | T_co] | Callable[[str], Awaitable[T_co] | T_co], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah looking at this now, since this will only work for TextOutput and not output functions in general, I don't think it's worth the extra (out of order) type var and would prefer to just have RunContext[Any] here and leave it at that
345bc95 to
9712a00
Compare
|
|
||
| @dataclass | ||
| class TextOutput(Generic[OutputDataT]): | ||
| class TextOutput(Generic[OutputDataT, TextOutputAgentDepsT]): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's drop TextOutputAgentDepsT entirely and use RunContext[Any] in TextOutputFunc, as we don't actually use the generic param and it bothers me that it's out of order compared to everything else that takes Deps followed by Output 😄
tests/typed_agent.py
Outdated
| assert_type(text_output_with_ctx, TextOutput[re.Pattern[str], int]) | ||
| Agent('test', output_type=text_output_with_ctx, deps_type=int) | ||
| Agent('test', output_type=text_output_with_ctx, deps_type=bool) # bool is subclass of int, works with contravariant | ||
| # NOTE: The following don't produce type errors because _OutputSpecItem uses TextOutput[T_co, Any] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree it's worth having a comment like this, but this'll have to be rewritten if my comment above is addressed
TextOutputsupports passing in functions that take in the agent'sRunContext[Deps], but pyright complains when theDepstype is specified in the signature. This PR adds support for properly typing theRunContextContext (badum tss): https://pydantic.slack.com/archives/C081LUCJ4KX/p1765814128383649?thread_ts=1765812544.296369&cid=C081LUCJ4KX
Note: sorry about the change in the
uv.lock, that seems to have escaped thegoogle-genai bumpPR