Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/mcp/server.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ async def sampling_callback(
annotations=None,
meta=None,
),
meta=None,
)
]
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ class Step(BaseModel):
class Plan(BaseModel):
"""Represents a plan with multiple steps."""

steps: list[Step] = Field(default_factory=list, description='The steps in the plan')
steps: list[Step] = Field(
default_factory=list[Step], description='The steps in the plan'
)


class JSONPatchOp(BaseModel):
Expand Down
4 changes: 2 additions & 2 deletions examples/pydantic_ai_examples/ag_ui/api/shared_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,14 @@ class Recipe(BaseModel):
description='The skill level required for the recipe',
)
special_preferences: list[SpecialPreferences] = Field(
default_factory=list,
default_factory=list[SpecialPreferences],
description='Any special preferences for the recipe',
)
cooking_time: CookingTime = Field(
default=CookingTime.FIVE_MIN, description='The cooking time of the recipe'
)
ingredients: list[Ingredient] = Field(
default_factory=list,
default_factory=list[Ingredient],
description='Ingredients for the recipe',
)
instructions: list[str] = Field(
Expand Down
6 changes: 3 additions & 3 deletions examples/pydantic_ai_examples/data_analyst.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

@dataclass
class AnalystAgentDeps:
output: dict[str, pd.DataFrame] = field(default_factory=dict)
output: dict[str, pd.DataFrame] = field(default_factory=dict[str, pd.DataFrame])

def store(self, value: pd.DataFrame) -> str:
"""Store the output in deps and return the reference such as Out[1] to be used by the LLM."""
Expand Down Expand Up @@ -47,7 +47,7 @@ def load_dataset(
"""
# begin load data from hf
builder = datasets.load_dataset_builder(path) # pyright: ignore[reportUnknownMemberType]
splits: dict[str, datasets.SplitInfo] = builder.info.splits or {} # pyright: ignore[reportUnknownMemberType]
splits: dict[str, datasets.SplitInfo] = builder.info.splits or {}
if split not in splits:
raise ModelRetry(
f'{split} is not valid for dataset {path}. Valid splits are {",".join(splits.keys())}'
Expand Down Expand Up @@ -87,7 +87,7 @@ def run_duckdb(ctx: RunContext[AnalystAgentDeps], dataset: str, sql: str) -> str
data = ctx.deps.get(dataset)
result = duckdb.query_df(df=data, virtual_table_name='dataset', sql_query=sql)
# pass the result as ref (because DuckDB SQL can select many rows, creating another huge dataframe)
ref = ctx.deps.store(result.df()) # pyright: ignore[reportUnknownMemberType]
ref = ctx.deps.store(result.df())
return f'Executed SQL, result is `{ref}`'


Expand Down
6 changes: 4 additions & 2 deletions examples/pydantic_ai_examples/question_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@
@dataclass
class QuestionState:
question: str | None = None
ask_agent_messages: list[ModelMessage] = field(default_factory=list)
evaluate_agent_messages: list[ModelMessage] = field(default_factory=list)
ask_agent_messages: list[ModelMessage] = field(default_factory=list[ModelMessage])
evaluate_agent_messages: list[ModelMessage] = field(
default_factory=list[ModelMessage]
)


@dataclass
Expand Down
12 changes: 6 additions & 6 deletions examples/pydantic_ai_examples/slack_lead_qualifier/modal.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
name='slack-lead-qualifier',
image=image,
secrets=[
modal.Secret.from_name('logfire'),
modal.Secret.from_name('openai'),
modal.Secret.from_name('slack'),
modal.Secret.from_name('logfire'), # pyright: ignore[reportUnknownMemberType]
modal.Secret.from_name('openai'), # pyright: ignore[reportUnknownMemberType]
modal.Secret.from_name('slack'), # pyright: ignore[reportUnknownMemberType]
],
) ### [/setup_modal]

Expand All @@ -31,7 +31,7 @@ def setup_logfire():


### [web_app]
@app.function(min_containers=1)
@app.function(min_containers=1) # pyright: ignore[reportUnknownMemberType]
@modal.asgi_app() # type: ignore
def web_app():
setup_logfire()
Expand All @@ -42,7 +42,7 @@ def web_app():


### [process_slack_member]
@app.function()
@app.function() # pyright: ignore[reportUnknownMemberType]
async def process_slack_member(profile_raw: dict[str, Any], logfire_ctx: Any):
setup_logfire()

Expand All @@ -57,7 +57,7 @@ async def process_slack_member(profile_raw: dict[str, Any], logfire_ctx: Any):


### [send_daily_summary]
@app.function(schedule=modal.Cron('0 8 * * *')) # Every day at 8am UTC
@app.function(schedule=modal.Cron('0 8 * * *')) # pyright: ignore[reportUnknownMemberType] # Every day at 8am UTC
async def send_daily_summary():
setup_logfire()

Expand Down
24 changes: 17 additions & 7 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@
class GraphAgentState:
"""State kept across the execution of the agent graph."""

message_history: list[_messages.ModelMessage] = dataclasses.field(default_factory=list)
message_history: list[_messages.ModelMessage] = dataclasses.field(default_factory=list[_messages.ModelMessage])
usage: _usage.RunUsage = dataclasses.field(default_factory=_usage.RunUsage)
retries: int = 0
run_step: int = 0
Expand Down Expand Up @@ -186,12 +186,16 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
deferred_tool_results: DeferredToolResults | None = None

instructions: str | None = None
instructions_functions: list[_system_prompt.SystemPromptRunner[DepsT]] = dataclasses.field(default_factory=list)
instructions_functions: list[_system_prompt.SystemPromptRunner[DepsT]] = dataclasses.field(
default_factory=list[_system_prompt.SystemPromptRunner[DepsT]]
)

system_prompts: tuple[str, ...] = dataclasses.field(default_factory=tuple)
system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]] = dataclasses.field(default_factory=list)
system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]] = dataclasses.field(
default_factory=list[_system_prompt.SystemPromptRunner[DepsT]]
)
system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[DepsT]] = dataclasses.field(
default_factory=dict
default_factory=dict[str, _system_prompt.SystemPromptRunner[DepsT]]
)

async def run( # noqa: C901
Expand Down Expand Up @@ -1101,7 +1105,14 @@ async def handle_call_or_result(
for call in tool_calls
]

pending = tasks
pending: set[
asyncio.Task[
tuple[
_messages.ToolReturnPart | _messages.RetryPromptPart,
str | Sequence[_messages.UserContent] | None,
]
]
] = set(tasks)
while pending:
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
for task in done:
Expand Down Expand Up @@ -1319,8 +1330,7 @@ async def _process_message_history(
if takes_ctx:
messages = await processor(run_context, messages)
else:
async_processor = cast(_HistoryProcessorAsync, processor)
messages = await async_processor(messages)
messages = await processor(messages)
else:
if takes_ctx:
sync_processor_with_ctx = cast(_HistoryProcessorSyncWithCtx[DepsT], processor)
Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/_function_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class FunctionSchema:
takes_ctx: bool
is_async: bool
single_arg_name: str | None = None
positional_fields: list[str] = field(default_factory=list)
positional_fields: list[str] = field(default_factory=list[str])
var_positional_field: str | None = None

async def call(self, args_dict: dict[str, Any], ctx: RunContext[Any]) -> Any:
Expand Down
10 changes: 7 additions & 3 deletions pydantic_ai_slim/pydantic_ai/_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,12 @@ def map_from_mcp_params(params: mcp_types.CreateMessageRequestParams) -> list[me
# TODO(Marcelo): We can reuse the `_map_tool_result_part` from the mcp module here.
if isinstance(content, mcp_types.TextContent):
user_part_content: str | Sequence[messages.UserContent] = content.text
else:
# image content
elif isinstance(content, mcp_types.ImageContent):
user_part_content = [
messages.BinaryContent(data=base64.b64decode(content.data), media_type=content.mimeType)
]
else:
raise NotImplementedError(f'Unsupported user content type: {type(content).__name__}')

request_parts.append(messages.UserPromptPart(content=user_part_content))
else:
Expand All @@ -47,7 +48,10 @@ def map_from_mcp_params(params: mcp_types.CreateMessageRequestParams) -> list[me
pai_messages.append(messages.ModelRequest(parts=request_parts))
request_parts = []

response_parts.append(map_from_sampling_content(content))
if isinstance(content, (mcp_types.TextContent, mcp_types.ImageContent, mcp_types.AudioContent)):
response_parts.append(map_from_sampling_content(content))
else:
raise NotImplementedError(f'Unsupported assistant content type: {type(content).__name__}')

if response_parts:
pai_messages.append(messages.ModelResponse(parts=response_parts))
Expand Down
25 changes: 15 additions & 10 deletions pydantic_ai_slim/pydantic_ai/_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def build( # noqa: C901
if allows_image:
outputs = [output for output in outputs if output is not _messages.BinaryImage]

if output := next((output for output in outputs if isinstance(output, NativeOutput)), None):
if output := next((cast(NativeOutput[OutputDataT], o) for o in outputs if isinstance(o, NativeOutput)), None):
if len(outputs) > 1:
raise UserError('`NativeOutput` must be the only output type.') # pragma: no cover

Expand All @@ -275,7 +275,9 @@ def build( # noqa: C901
allows_deferred_tools=allows_deferred_tools,
allows_image=allows_image,
)
elif output := next((output for output in outputs if isinstance(output, PromptedOutput)), None):
elif output := next(
(cast(PromptedOutput[OutputDataT], o) for o in outputs if isinstance(o, PromptedOutput)), None
):
if len(outputs) > 1:
raise UserError('`PromptedOutput` must be the only output type.') # pragma: no cover

Expand Down Expand Up @@ -308,9 +310,9 @@ def build( # noqa: C901
if output is str:
text_outputs.append(cast(type[str], output))
elif isinstance(output, TextOutput):
text_outputs.append(output)
text_outputs.append(cast(TextOutput[OutputDataT], output))
elif isinstance(output, ToolOutput):
tool_outputs.append(output)
tool_outputs.append(cast(ToolOutput[OutputDataT], output))
elif isinstance(output, NativeOutput):
# We can never get here because this is checked for above.
raise UserError('`NativeOutput` must be the only output type.') # pragma: no cover
Expand Down Expand Up @@ -897,7 +899,7 @@ def build(
description = output.description
strict = output.strict

output = output.output
output = cast(OutputTypeOrFunction[OutputDataT], output.output)

description = description or default_description
if strict is None:
Expand Down Expand Up @@ -991,7 +993,7 @@ def _flatten_output_spec(output_spec: OutputSpec[T]) -> Sequence[_OutputSpecItem
def _flatten_output_spec(output_spec: OutputSpec[T]) -> Sequence[_OutputSpecItem[T]]:
outputs: Sequence[OutputSpec[T]]
if isinstance(output_spec, Sequence):
outputs = output_spec
outputs = cast(Sequence[OutputSpec[T]], output_spec)
else:
outputs = (output_spec,)

Expand All @@ -1009,20 +1011,23 @@ def _flatten_output_spec(output_spec: OutputSpec[T]) -> Sequence[_OutputSpecItem
def types_from_output_spec(output_spec: OutputSpec[T]) -> Sequence[T | type[str]]:
outputs: Sequence[OutputSpec[T]]
if isinstance(output_spec, Sequence):
outputs = output_spec
outputs = cast(Sequence[OutputSpec[T]], output_spec)
else:
outputs = (output_spec,)

outputs_flat: list[T | type[str]] = []
for output in outputs:
if isinstance(output, NativeOutput):
outputs_flat.extend(types_from_output_spec(output.outputs))
native_outputs = cast(OutputSpec[T], output.outputs) # pyright: ignore[reportUnknownMemberType]
outputs_flat.extend(types_from_output_spec(native_outputs))
elif isinstance(output, PromptedOutput):
outputs_flat.extend(types_from_output_spec(output.outputs))
prompted_outputs = cast(OutputSpec[T], output.outputs) # pyright: ignore[reportUnknownMemberType]
outputs_flat.extend(types_from_output_spec(prompted_outputs))
elif isinstance(output, TextOutput):
outputs_flat.append(str)
elif isinstance(output, ToolOutput):
outputs_flat.extend(types_from_output_spec(output.output))
tool_output = cast(OutputSpec[T], output.output) # pyright: ignore[reportUnknownMemberType]
outputs_flat.extend(types_from_output_spec(tool_output))
elif union_types := _utils.get_union_args(output):
outputs_flat.extend(union_types)
elif inspect.isfunction(output) or inspect.ismethod(output):
Expand Down
4 changes: 2 additions & 2 deletions pydantic_ai_slim/pydantic_ai/_parts_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ class ModelResponsePartsManager:
Parts are generally added and/or updated by providing deltas, which are tracked by vendor-specific IDs.
"""

_parts: list[ManagedPart] = field(default_factory=list, init=False)
_parts: list[ManagedPart] = field(default_factory=list[ManagedPart], init=False)
"""A list of parts (text or tool calls) that make up the current state of the model's response."""
_vendor_id_to_part_index: dict[VendorId, int] = field(default_factory=dict, init=False)
_vendor_id_to_part_index: dict[VendorId, int] = field(default_factory=dict[VendorId, int], init=False)
"""Maps a vendor's "part" ID (if provided) to the index in `_parts` where that part resides."""

def get_parts(self) -> list[ModelResponsePart]:
Expand Down
4 changes: 2 additions & 2 deletions pydantic_ai_slim/pydantic_ai/_run_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class RunContext(Generic[RunContextAgentDepsT]):
"""LLM usage associated with the run."""
prompt: str | Sequence[_messages.UserContent] | None = None
"""The original user prompt passed to the run."""
messages: list[_messages.ModelMessage] = field(default_factory=list)
messages: list[_messages.ModelMessage] = field(default_factory=list[_messages.ModelMessage])
"""Messages exchanged in the conversation so far."""
validation_context: Any = None
"""Pydantic [validation context](https://docs.pydantic.dev/latest/concepts/validators/#validation-context) for tool args and run outputs."""
Expand All @@ -48,7 +48,7 @@ class RunContext(Generic[RunContextAgentDepsT]):
"""Whether to include the content of the messages in the trace."""
instrumentation_version: int = DEFAULT_INSTRUMENTATION_VERSION
"""Instrumentation settings version, if instrumentation is enabled."""
retries: dict[str, int] = field(default_factory=dict)
retries: dict[str, int] = field(default_factory=dict[str, int])
"""Number of retries for each tool so far."""
tool_call_id: str | None = None
"""The ID of the tool call."""
Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/_tool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class ToolManager(Generic[AgentDepsT]):
"""The agent run context for a specific run step."""
tools: dict[str, ToolsetTool[AgentDepsT]] | None = None
"""The cached tools for this run step."""
failed_tools: set[str] = field(default_factory=set)
failed_tools: set[str] = field(default_factory=set[str])
"""Names of tools that failed in this run step."""

@classmethod
Expand Down
7 changes: 4 additions & 3 deletions pydantic_ai_slim/pydantic_ai/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
TypeAlias,
TypeGuard,
TypeVar,
cast,
get_args,
get_origin,
overload,
Expand Down Expand Up @@ -220,7 +221,7 @@ async def async_iter_groups() -> AsyncIterator[list[T]]:
if task is None:
# anext(aiter) returns an Awaitable[T], not a Coroutine which asyncio.create_task expects
# so far, this doesn't seem to be a problem
task = asyncio.create_task(anext(aiterator)) # pyright: ignore[reportArgumentType]
task = asyncio.create_task(anext(aiterator)) # pyright: ignore[reportArgumentType, reportUnknownVariableType]

# we use asyncio.wait to avoid cancelling the coroutine if it's not done
done, _ = await asyncio.wait((task,), timeout=wait_time)
Expand Down Expand Up @@ -411,7 +412,7 @@ def is_async_callable(obj: Any) -> Any:
while isinstance(obj, functools.partial):
obj = obj.func

return inspect.iscoroutinefunction(obj) or (callable(obj) and inspect.iscoroutinefunction(obj.__call__)) # type: ignore
return inspect.iscoroutinefunction(obj) or (callable(obj) and inspect.iscoroutinefunction(obj.__call__))


def _update_mapped_json_schema_refs(s: dict[str, Any], name_mapping: dict[str, str]) -> None:
Expand All @@ -431,7 +432,7 @@ def _update_mapped_json_schema_refs(s: dict[str, Any], name_mapping: dict[str, s

# Handle arrays
if 'items' in s and isinstance(s['items'], dict):
items: dict[str, Any] = s['items']
items = cast(dict[str, Any], s['items'])
_update_mapped_json_schema_refs(items, name_mapping)
if 'prefixItems' in s:
prefix_items: list[dict[str, Any]] = s['prefixItems']
Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/direct.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ class StreamedResponseSync:

_async_stream_cm: AbstractAsyncContextManager[StreamedResponse]
_queue: queue.Queue[messages.ModelResponseStreamEvent | Exception | None] = field(
default_factory=queue.Queue, init=False
default_factory=queue.Queue[messages.ModelResponseStreamEvent | Exception | None], init=False
)
_thread: threading.Thread | None = field(default=None, init=False)
_stream_response: StreamedResponse | None = field(default=None, init=False)
Expand Down
8 changes: 5 additions & 3 deletions pydantic_ai_slim/pydantic_ai/format_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,13 @@ class _ToXml:
include_field_info: Literal['once'] | bool
# a map of Pydantic and dataclasses Field paths to their metadata:
# a field unique string representation and its class
_fields_info: dict[str, tuple[str, FieldInfo | ComputedFieldInfo]] = field(default_factory=dict)
_fields_info: dict[str, tuple[str, FieldInfo | ComputedFieldInfo]] = field(
default_factory=dict[str, tuple[str, FieldInfo | ComputedFieldInfo]]
)
# keep track of fields we have extracted attributes from
_included_fields: set[str] = field(default_factory=set)
_included_fields: set[str] = field(default_factory=set[str])
# keep track of class names for dataclasses and Pydantic models, that occur in lists
_element_names: dict[str, str] = field(default_factory=dict)
_element_names: dict[str, str] = field(default_factory=dict[str, str])
# flag for parsing dataclasses and Pydantic models once
_is_info_extracted: bool = False
_FIELD_ATTRIBUTES = ('title', 'description')
Expand Down
Loading
Loading