diff --git a/docs/output.md b/docs/output.md index a39556ead5..dbbb84c6e9 100644 --- a/docs/output.md +++ b/docs/output.md @@ -237,7 +237,11 @@ RouterFailure(explanation='I am not equipped to provide travel information, such #### Text output -If you provide an output function that takes a string, Pydantic AI will by default create an output tool like for any other output function. If instead you'd like the model to provide the string using plain text output, you can wrap the function in the [`TextOutput`][pydantic_ai.output.TextOutput] marker class. If desired, this marker class can be used alongside one or more [`ToolOutput`](#tool-output) marker classes (or unmarked types or functions) in a list provided to `output_type`. +If you provide an output function that takes a string, Pydantic AI will by default create an output tool like for any other output function. If instead you'd like the model to provide the string using plain text output, you can wrap the function in the [`TextOutput`][pydantic_ai.output.TextOutput] marker class. + +If desired, this marker class can be used alongside one or more [`ToolOutput`](#tool-output) marker classes (or unmarked types or functions) in a list provided to `output_type`. + +Like other output functions, text output functions can optionally take [`RunContext`][pydantic_ai.tools.RunContext] as the first argument, and can raise [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] to ask the model to try again with modified arguments (or with a different output type). ```python {title="text_output_function.py"} from pydantic_ai import Agent, TextOutput diff --git a/pydantic_ai_slim/pydantic_ai/output.py b/pydantic_ai_slim/pydantic_ai/output.py index cd5e5865a6..f2135358ab 100644 --- a/pydantic_ai_slim/pydantic_ai/output.py +++ b/pydantic_ai_slim/pydantic_ai/output.py @@ -11,8 +11,9 @@ from . import _utils, exceptions from ._json_schema import InlineDefsJsonSchemaTransformer +from ._run_context import RunContext from .messages import ToolCallPart -from .tools import DeferredToolRequests, ObjectJsonSchema, RunContext, ToolDefinition +from .tools import DeferredToolRequests, ObjectJsonSchema, ToolDefinition __all__ = ( # classes @@ -60,7 +61,7 @@ TextOutputFunc = TypeAliasType( 'TextOutputFunc', - Callable[[RunContext, str], Awaitable[T_co] | T_co] | Callable[[str], Awaitable[T_co] | T_co], + Callable[[RunContext[Any], str], Awaitable[T_co] | T_co] | Callable[[str], Awaitable[T_co] | T_co], type_params=(T_co,), ) """Definition of a function that will be called to process the model's plain text output. The function must take a single string argument. diff --git a/tests/typed_agent.py b/tests/typed_agent.py index d8ae40711e..7920698b64 100644 --- a/tests/typed_agent.py +++ b/tests/typed_agent.py @@ -196,6 +196,10 @@ def str_to_regex(text: str) -> re.Pattern[str]: return re.compile(text) +def str_to_regex_with_ctx(ctx: RunContext[int], text: str) -> re.Pattern[str]: + return re.compile(text) + + class MyClass: def my_method(self) -> bool: return True @@ -283,6 +287,16 @@ def my_method(self) -> bool: # since deps are not set, they default to `None`, so can't be `int` Agent('test', tools=[Tool(foobar_plain)], deps_type=int) # pyright: ignore[reportArgumentType,reportCallIssue] +# TextOutput with RunContext uses RunContext[Any], so deps_type is not checked. +# This is intentional: type checking deps in output functions isn't feasible because +# ToolOutput and plain output functions take arbitrary args, so the type checker +# treats RunContext as just another arg rather than enforcing deps_type compatibility. +text_output_with_ctx = TextOutput(str_to_regex_with_ctx) +assert_type(text_output_with_ctx, TextOutput[re.Pattern[str]]) +Agent('test', output_type=text_output_with_ctx, deps_type=int) +Agent('test', output_type=text_output_with_ctx, deps_type=str) +Agent('test', output_type=text_output_with_ctx) + # prepare example from docs: