Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 34 additions & 10 deletions pydantic_ai_slim/pydantic_ai/_parts_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def handle_text_delta(
vendor_part_id: VendorId | None,
content: str,
id: str | None = None,
provider_name: str | None = None,
provider_details: dict[str, Any] | None = None,
thinking_tags: tuple[str, str] | None = None,
ignore_leading_whitespace: bool = False,
Expand All @@ -92,6 +93,7 @@ def handle_text_delta(
a TextPart.
content: The text content to append to the appropriate TextPart.
id: An optional id for the text part.
provider_name: An optional provider name for the text part.
provider_details: An optional dictionary of provider-specific details for the text part.
thinking_tags: If provided, will handle content between the thinking tags as thinking parts.
ignore_leading_whitespace: If True, will ignore leading whitespace in the content.
Expand Down Expand Up @@ -121,7 +123,7 @@ def handle_text_delta(
self._handle_embedded_thinking_end(vendor_part_id)
return
yield from self._handle_embedded_thinking_content(
existing_part, part_index, content, provider_details
existing_part, part_index, content, provider_name, provider_details
)
return
elif isinstance(existing_part, TextPart):
Expand All @@ -131,7 +133,7 @@ def handle_text_delta(

if thinking_tags and content == thinking_tags[0]:
# When we see a thinking start tag (which is a single token), we'll build a new thinking part instead
yield from self._handle_embedded_thinking_start(vendor_part_id, provider_details)
yield from self._handle_embedded_thinking_start(vendor_part_id, provider_name, provider_details)
return

if existing_text_part_and_index is None:
Expand All @@ -141,13 +143,15 @@ def handle_text_delta(
return

# There is no existing text part that should be updated, so create a new one
part = TextPart(content=content, id=id, provider_details=provider_details)
part = TextPart(content=content, id=id, provider_name=provider_name, provider_details=provider_details)
new_part_index = self._append_part(part, vendor_part_id)
yield PartStartEvent(index=new_part_index, part=part)
else:
# Update the existing TextPart with the new content delta
existing_text_part, part_index = existing_text_part_and_index
part_delta = TextPartDelta(content_delta=content, provider_details=provider_details)
part_delta = TextPartDelta(
content_delta=content, provider_name=provider_name, provider_details=provider_details
)
self._parts[part_index] = part_delta.apply(existing_text_part)
yield PartDeltaEvent(index=part_index, delta=part_delta)

Expand Down Expand Up @@ -241,6 +245,7 @@ def handle_tool_call_delta(
tool_name: str | None = None,
args: str | dict[str, Any] | None = None,
tool_call_id: str | None = None,
provider_name: str | None = None,
provider_details: dict[str, Any] | None = None,
) -> ModelResponseStreamEvent | None:
"""Handle or update a tool call, creating or updating a `ToolCallPart`, `BuiltinToolCallPart`, or `ToolCallPartDelta`.
Expand All @@ -258,6 +263,7 @@ def handle_tool_call_delta(
a name match when `vendor_part_id` is None.
args: Arguments for the tool call, either as a string, a dictionary of key-value pairs, or None.
tool_call_id: An optional string representing an identifier for this tool call.
provider_name: An optional provider name for the tool call part.
provider_details: An optional dictionary of provider-specific details for the tool call part.

Returns:
Expand Down Expand Up @@ -293,7 +299,11 @@ def handle_tool_call_delta(
if existing_matching_part_and_index is None:
# No matching part/delta was found, so create a new ToolCallPartDelta (or ToolCallPart if fully formed)
delta = ToolCallPartDelta(
tool_name_delta=tool_name, args_delta=args, tool_call_id=tool_call_id, provider_details=provider_details
tool_name_delta=tool_name,
args_delta=args,
tool_call_id=tool_call_id,
provider_name=provider_name,
provider_details=provider_details,
)
part = delta.as_part() or delta
new_part_index = self._append_part(part, vendor_part_id)
Expand All @@ -304,7 +314,11 @@ def handle_tool_call_delta(
# Update the existing part or delta with the new information
existing_part, part_index = existing_matching_part_and_index
delta = ToolCallPartDelta(
tool_name_delta=tool_name, args_delta=args, tool_call_id=tool_call_id, provider_details=provider_details
tool_name_delta=tool_name,
args_delta=args,
tool_call_id=tool_call_id,
provider_name=provider_name,
provider_details=provider_details,
)
updated_part = delta.apply(existing_part)
self._parts[part_index] = updated_part
Expand All @@ -326,6 +340,7 @@ def handle_tool_call_part(
args: str | dict[str, Any] | None,
tool_call_id: str | None = None,
id: str | None = None,
provider_name: str | None = None,
provider_details: dict[str, Any] | None = None,
) -> ModelResponseStreamEvent:
"""Immediately create or fully-overwrite a ToolCallPart with the given information.
Expand All @@ -339,6 +354,7 @@ def handle_tool_call_part(
args: The arguments for the tool call, either as a string, a dictionary, or None.
tool_call_id: An optional string identifier for this tool call.
id: An optional identifier for this tool call part.
provider_name: An optional provider name for the tool call part.
provider_details: An optional dictionary of provider-specific details for the tool call part.

Returns:
Expand All @@ -350,6 +366,7 @@ def handle_tool_call_part(
args=args,
tool_call_id=tool_call_id or _generate_tool_call_id(),
id=id,
provider_name=provider_name,
provider_details=provider_details,
)
if vendor_part_id is None:
Expand Down Expand Up @@ -420,19 +437,26 @@ def _latest_part_if_of_type(self, *part_types: type[PartT]) -> tuple[PartT, int]
return None

def _handle_embedded_thinking_start(
self, vendor_part_id: VendorId, provider_details: dict[str, Any] | None
self, vendor_part_id: VendorId, provider_name: str | None, provider_details: dict[str, Any] | None
) -> Iterator[ModelResponseStreamEvent]:
"""Handle <think> tag - create new ThinkingPart."""
self._stop_tracking_vendor_id(vendor_part_id)
part = ThinkingPart(content='', provider_details=provider_details)
part = ThinkingPart(content='', provider_name=provider_name, provider_details=provider_details)
new_index = self._append_part(part, vendor_part_id)
yield PartStartEvent(index=new_index, part=part)

def _handle_embedded_thinking_content(
self, existing_part: ThinkingPart, part_index: int, content: str, provider_details: dict[str, Any] | None
self,
existing_part: ThinkingPart,
part_index: int,
content: str,
provider_name: str | None,
provider_details: dict[str, Any] | None,
) -> Iterator[ModelResponseStreamEvent]:
"""Handle content inside <think>...</think>."""
part_delta = ThinkingPartDelta(content_delta=content, provider_details=provider_details)
part_delta = ThinkingPartDelta(
content_delta=content, provider_name=provider_name, provider_details=provider_details
)
self._parts[part_index] = part_delta.apply(existing_part)
yield PartDeltaEvent(index=part_index, delta=part_delta)

Expand Down
30 changes: 24 additions & 6 deletions pydantic_ai_slim/pydantic_ai/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -1032,6 +1032,9 @@ class TextPart:
id: str | None = None
"""An optional identifier of the text part."""

provider_name: str | None = None
"""The name of the provider that generated the response."""

provider_details: dict[str, Any] | None = None
"""Additional data returned by the provider that can't be mapped to standard fields.

Expand Down Expand Up @@ -1148,6 +1151,12 @@ class BaseToolCallPart:

This is used by some APIs like OpenAI Responses."""

provider_name: str | None = None
"""The name of the provider that generated the response.

Tool calls are only sent back to the same provider.
"""

provider_details: dict[str, Any] | None = None
"""Additional data returned by the provider that can't be mapped to standard fields.

Expand Down Expand Up @@ -1205,12 +1214,6 @@ class BuiltinToolCallPart(BaseToolCallPart):

_: KW_ONLY

provider_name: str | None = None
"""The name of the provider that generated the response.

Built-in tool calls are only sent back to the same provider.
"""

part_kind: Literal['builtin-tool-call'] = 'builtin-tool-call'
"""Part type identifier, this is available on all parts as a discriminator."""

Expand Down Expand Up @@ -1496,6 +1499,9 @@ class TextPartDelta:

_: KW_ONLY

provider_name: str | None = None
"""The name of the provider that generated the response."""

provider_details: dict[str, Any] | None = None
"""Additional data returned by the provider that can't be mapped to standard fields.

Expand All @@ -1521,6 +1527,7 @@ def apply(self, part: ModelResponsePart) -> TextPart:
return replace(
part,
content=part.content + self.content_delta,
provider_name=self.provider_name or part.provider_name,
provider_details={**(part.provider_details or {}), **(self.provider_details or {})} or None,
)

Expand Down Expand Up @@ -1653,6 +1660,9 @@ class ToolCallPartDelta:
Note this is never treated as a delta — it can replace None, but otherwise if a
non-matching value is provided an error will be raised."""

provider_name: str | None = None
"""The name of the provider that generated the response."""

provider_details: dict[str, Any] | None = None
"""Additional data returned by the provider that can't be mapped to standard fields.

Expand All @@ -1674,6 +1684,7 @@ def as_part(self) -> ToolCallPart | None:
self.tool_name_delta,
self.args_delta,
self.tool_call_id or _generate_tool_call_id(),
provider_name=self.provider_name,
provider_details=self.provider_details,
)

Expand Down Expand Up @@ -1735,6 +1746,9 @@ def _apply_to_delta(self, delta: ToolCallPartDelta) -> ToolCallPart | BuiltinToo
if self.tool_call_id:
delta = replace(delta, tool_call_id=self.tool_call_id)

if self.provider_name:
delta = replace(delta, provider_name=self.provider_name)

if self.provider_details:
merged_provider_details = {**(delta.provider_details or {}), **self.provider_details}
delta = replace(delta, provider_details=merged_provider_details)
Expand All @@ -1745,6 +1759,7 @@ def _apply_to_delta(self, delta: ToolCallPartDelta) -> ToolCallPart | BuiltinToo
delta.tool_name_delta,
delta.args_delta,
delta.tool_call_id or _generate_tool_call_id(),
provider_name=delta.provider_name,
provider_details=delta.provider_details,
)

Expand All @@ -1771,6 +1786,9 @@ def _apply_to_part(self, part: ToolCallPart | BuiltinToolCallPart) -> ToolCallPa
if self.tool_call_id:
part = replace(part, tool_call_id=self.tool_call_id)

if self.provider_name:
part = replace(part, provider_name=self.provider_name)

if self.provider_details:
merged_provider_details = {**(part.provider_details or {}), **self.provider_details}
part = replace(part, provider_details=merged_provider_details)
Expand Down
3 changes: 3 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ def _process_response(self, response: _GeminiResponse) -> ModelResponse:
usage,
vendor_id=vendor_id,
vendor_details=vendor_details,
provider_name=self._provider.name,
provider_url=self.base_url,
)

Expand Down Expand Up @@ -719,6 +720,7 @@ def _process_response_from_parts(
model_name: GeminiModelName,
usage: usage.RequestUsage,
vendor_id: str | None,
provider_name: str,
provider_url: str,
vendor_details: dict[str, Any] | None = None,
) -> ModelResponse:
Expand All @@ -741,6 +743,7 @@ def _process_response_from_parts(
parts=items,
usage=usage,
model_name=model_name,
provider_name=provider_name,
provider_response_id=vendor_id,
provider_details=vendor_details,
provider_url=provider_url,
Expand Down
Loading
Loading