Skip to content

Commit 02a5cfc

Browse files
committed
enforce TextOutput deps type checking via OutputSpec
1 parent 9712a00 commit 02a5cfc

File tree

3 files changed

+17
-13
lines changed

3 files changed

+17
-13
lines changed

pydantic_ai_slim/pydantic_ai/agent/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
133133
be merged with this value, with the runtime argument taking priority.
134134
"""
135135

136-
_output_type: OutputSpec[OutputDataT]
136+
_output_type: OutputSpec[OutputDataT, AgentDepsT]
137137

138138
instrument: InstrumentationSettings | bool | None
139139
"""Options to automatically instrument with OpenTelemetry."""
@@ -170,7 +170,7 @@ def __init__(
170170
self,
171171
model: models.Model | models.KnownModelName | str | None = None,
172172
*,
173-
output_type: OutputSpec[OutputDataT] = str,
173+
output_type: OutputSpec[OutputDataT, AgentDepsT] = str,
174174
instructions: Instructions[AgentDepsT] = None,
175175
system_prompt: str | Sequence[str] = (),
176176
deps_type: type[AgentDepsT] = NoneType,
@@ -198,7 +198,7 @@ def __init__(
198198
self,
199199
model: models.Model | models.KnownModelName | str | None = None,
200200
*,
201-
output_type: OutputSpec[OutputDataT] = str,
201+
output_type: OutputSpec[OutputDataT, AgentDepsT] = str,
202202
instructions: Instructions[AgentDepsT] = None,
203203
system_prompt: str | Sequence[str] = (),
204204
deps_type: type[AgentDepsT] = NoneType,
@@ -224,7 +224,7 @@ def __init__(
224224
self,
225225
model: models.Model | models.KnownModelName | str | None = None,
226226
*,
227-
output_type: OutputSpec[OutputDataT] = str,
227+
output_type: OutputSpec[OutputDataT, AgentDepsT] = str,
228228
instructions: Instructions[AgentDepsT] = None,
229229
system_prompt: str | Sequence[str] = (),
230230
deps_type: type[AgentDepsT] = NoneType,
@@ -419,7 +419,7 @@ def deps_type(self) -> type:
419419
return self._deps_type
420420

421421
@property
422-
def output_type(self) -> OutputSpec[OutputDataT]:
422+
def output_type(self) -> OutputSpec[OutputDataT, AgentDepsT]:
423423
"""The type of data output by agent runs, used to validate the data returned by the model, defaults to `str`."""
424424
return self._output_type
425425

pydantic_ai_slim/pydantic_ai/output.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535
T = TypeVar('T')
3636
T_co = TypeVar('T_co', covariant=True)
3737
TextOutputAgentDepsT = TypeVar('TextOutputAgentDepsT', default=None, contravariant=True)
38+
# default=Any for backward compat
39+
OutputSpecDepsT = TypeVar('OutputSpecDepsT', default=Any)
3840

3941
OutputDataT = TypeVar('OutputDataT', default=str, covariant=True)
4042
"""Covariant type variable for the output data type of a run."""
@@ -356,14 +358,18 @@ def __get_pydantic_json_schema__(
356358

357359
_OutputSpecItem = TypeAliasType(
358360
'_OutputSpecItem',
359-
OutputTypeOrFunction[T_co] | ToolOutput[T_co] | NativeOutput[T_co] | PromptedOutput[T_co] | TextOutput[T_co, Any],
360-
type_params=(T_co,),
361+
OutputTypeOrFunction[T_co]
362+
| ToolOutput[T_co]
363+
| NativeOutput[T_co]
364+
| PromptedOutput[T_co]
365+
| TextOutput[T_co, OutputSpecDepsT],
366+
type_params=(T_co, OutputSpecDepsT),
361367
)
362368

363369
OutputSpec = TypeAliasType(
364370
'OutputSpec',
365-
_OutputSpecItem[T_co] | Sequence['OutputSpec[T_co]'],
366-
type_params=(T_co,),
371+
_OutputSpecItem[T_co, OutputSpecDepsT] | Sequence['OutputSpec[T_co, OutputSpecDepsT]'],
372+
type_params=(T_co, OutputSpecDepsT),
367373
)
368374
"""Specification of the agent's output data.
369375

tests/typed_agent.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -292,10 +292,8 @@ def my_method(self) -> bool:
292292
assert_type(text_output_with_ctx, TextOutput[re.Pattern[str], int])
293293
Agent('test', output_type=text_output_with_ctx, deps_type=int)
294294
Agent('test', output_type=text_output_with_ctx, deps_type=bool) # bool is subclass of int, works with contravariant
295-
# NOTE: The following don't produce type errors because _OutputSpecItem uses TextOutput[T_co, Any]
296-
# which erases the deps type constraint.
297-
Agent('test', output_type=text_output_with_ctx, deps_type=str)
298-
Agent('test', output_type=text_output_with_ctx)
295+
Agent('test', output_type=text_output_with_ctx, deps_type=str) # pyright: ignore[reportArgumentType,reportCallIssue]
296+
Agent('test', output_type=text_output_with_ctx) # pyright: ignore[reportArgumentType,reportCallIssue]
299297

300298
# prepare example from docs:
301299

0 commit comments

Comments
 (0)