diff --git a/PLAN.md b/PLAN.md new file mode 100644 index 0000000..85f5f73 --- /dev/null +++ b/PLAN.md @@ -0,0 +1,50 @@ +# ToolOrphanRepair Capability + +## Problem + +Multi-turn conversations with tools accumulate structurally invalid message history: + +- **Orphaned tool calls**: A `ToolCallPart` in a `ModelResponse` whose result was never + recorded (streaming timeout, deferred tool dropped). The next `ModelRequest` lacks a + matching `ToolReturnPart`. +- **Orphaned builtin tool calls**: A `BuiltinToolCallPart` without a matching + `BuiltinToolReturnPart` in the same `ModelResponse`. +- **Orphaned tool returns**: A `ToolReturnPart` or `RetryPromptPart` whose + `tool_call_id` does not match any call in the preceding `ModelResponse` + (frontend-generated IDs, mismatched call IDs from deferred tools). + +Providers (especially Anthropic) reject structurally invalid history with 400 errors. +Once a conversation is poisoned, every subsequent run fails on the same history. + +## Solution + +A `ToolOrphanRepair` capability that hooks into `before_model_request` to sanitize +`request_context.messages` before each model call. + +### Repair logic (single forward pass) + +For each `ModelResponse` paired with the `ModelRequest` that follows it: + +1. **Builtin call repair**: Inject synthetic `BuiltinToolReturnPart` for any + `BuiltinToolCallPart` without a matching return in the same response. +2. **Regular call matching**: Collect `tool_call_id` values from `ToolCallPart` parts. +3. **Orphaned return stripping**: Remove `ToolReturnPart` / `RetryPromptPart` from the + request whose `tool_call_id` is not in the call set. +4. **Orphaned call patching**: Inject synthetic `ToolReturnPart` for call IDs with no + matching return or retry in the request. +5. **Empty request guard**: If stripping leaves only `SystemPromptPart` parts, insert a + placeholder `UserPromptPart("Continue.")` to maintain alternation. + +For a trailing `ModelResponse` with no following request: +- If it contains only unmatched tool calls, drop it entirely. +- If it has other content (text, builtin results), keep it but strip the tool calls. + +### Configuration + +- `orphan_call_content: str` -- content for synthetic returns (default: `"Tool call was not completed."`) +- `warn: bool` -- emit a `UserWarning` when repairs are made (default: `True`) + +## References + +- pydantic-ai #4728: "Built-in HistoryProcessor for orphaned tool call/result repair" +- pydantic-harness #82: "Tool Output Management" diff --git a/src/pydantic_harness/__init__.py b/src/pydantic_harness/__init__.py index 9d728b6..f4090b1 100644 --- a/src/pydantic_harness/__init__.py +++ b/src/pydantic_harness/__init__.py @@ -7,4 +7,8 @@ # Each capability module is imported and re-exported here. # Capabilities are listed alphabetically. -__all__: list[str] = [] +from pydantic_harness.tool_orphan_repair import ToolOrphanRepair + +__all__: list[str] = [ + 'ToolOrphanRepair', +] diff --git a/src/pydantic_harness/tool_orphan_repair.py b/src/pydantic_harness/tool_orphan_repair.py new file mode 100644 index 0000000..62821e3 --- /dev/null +++ b/src/pydantic_harness/tool_orphan_repair.py @@ -0,0 +1,300 @@ +"""Capability that sanitizes message history to fix orphaned tool calls and results. + +Multi-turn conversations with tools can accumulate structurally invalid message +history -- tool calls without matching results, or results referencing calls that +no longer exist. Providers (especially Anthropic) reject such history with a 400, +and once a conversation is "poisoned" it stays broken for every subsequent run. + +This capability hooks into ``before_model_request`` to repair the history before +each model call, so conversations self-heal instead of permanently breaking. +""" + +from __future__ import annotations + +import logging +import warnings +from dataclasses import dataclass, field, replace +from typing import TYPE_CHECKING, Any + +from pydantic_ai.capabilities.abstract import AbstractCapability +from pydantic_ai.messages import ( + BuiltinToolCallPart, + BuiltinToolReturnPart, + ModelRequest, + ModelResponse, + RetryPromptPart, + SystemPromptPart, + ToolCallPart, + ToolReturnPart, + UserPromptPart, +) + +if TYPE_CHECKING: + from pydantic_ai.messages import ModelMessage, ModelRequestPart, ModelResponsePart + from pydantic_ai.models import ModelRequestContext + from pydantic_ai.tools import RunContext + +logger = logging.getLogger(__name__) + + +_ORPHAN_CALL_CONTENT = 'Tool call was not completed.' +"""Synthetic content injected for orphaned tool calls.""" + + +@dataclass +class ToolOrphanRepair(AbstractCapability[Any]): + """Sanitizes message history to fix orphaned tool calls and results. + + Repairs three classes of structural defects: + + 1. **Orphaned tool calls** -- a ``ToolCallPart`` in a ``ModelResponse`` + with no matching ``ToolReturnPart`` or ``RetryPromptPart`` in the + following ``ModelRequest``. A synthetic return is injected. + 2. **Orphaned builtin tool calls** -- a ``BuiltinToolCallPart`` in a + ``ModelResponse`` with no matching ``BuiltinToolReturnPart`` in the + same response. A synthetic return is appended to the response. + 3. **Orphaned tool returns** -- a ``ToolReturnPart`` or ``RetryPromptPart`` + in a ``ModelRequest`` whose ``tool_call_id`` does not match any call + in the preceding ``ModelResponse``. The orphaned part is stripped. + + Additionally, trailing ``ModelResponse`` messages whose *only* actionable + parts are unmatched tool calls (no text, no builtin results) are removed + entirely, since there is no following request to receive synthetic returns. + + When stripping parts empties a ``ModelRequest``, a placeholder + ``UserPromptPart`` is inserted to maintain user/assistant message + alternation. + + Usage:: + + from pydantic_harness import ToolOrphanRepair + + agent = Agent('anthropic:claude-sonnet', capabilities=[ToolOrphanRepair()]) + """ + + orphan_call_content: str = _ORPHAN_CALL_CONTENT + """Content used for synthetic tool return parts injected for orphaned calls.""" + + warn: bool = field(default=True, kw_only=True) + """Whether to emit a warning when orphans are detected and repaired.""" + + async def before_model_request( + self, + ctx: RunContext[Any], + request_context: ModelRequestContext, + ) -> ModelRequestContext: + """Sanitize ``request_context.messages`` before each model request.""" + request_context.messages[:] = _repair_messages( + request_context.messages, + orphan_call_content=self.orphan_call_content, + warn=self.warn, + ) + return request_context + + +def _repair_messages( + messages: list[ModelMessage], + *, + orphan_call_content: str = _ORPHAN_CALL_CONTENT, + warn: bool = True, +) -> list[ModelMessage]: + """Return a repaired copy of *messages* with orphaned tool calls/results fixed. + + The algorithm makes a single forward pass pairing each ``ModelResponse`` + with the ``ModelRequest`` that follows it. Within each pair it: + + * collects the set of ``tool_call_id`` values from regular ``ToolCallPart`` + parts in the response, + * strips any ``ToolReturnPart`` / ``RetryPromptPart`` in the request whose + ``tool_call_id`` is not in that set, + * injects synthetic ``ToolReturnPart`` for any call id that has no matching + return or retry in the request, + * collects ``BuiltinToolCallPart`` ids from the response and injects + synthetic ``BuiltinToolReturnPart`` for any that lack a matching + ``BuiltinToolReturnPart`` in the same response. + + A trailing ``ModelResponse`` (no following request) that contains + unmatched regular tool calls is stripped. If stripping empties a + ``ModelRequest`` of meaningful parts, a placeholder ``UserPromptPart`` + is inserted. + """ + if not messages: + return messages + + repaired: list[ModelMessage] = [] + repairs_made = 0 + + i = 0 + while i < len(messages): + msg = messages[i] + + if isinstance(msg, ModelResponse): + next_request: ModelRequest | None = None + if i + 1 < len(messages): + next_msg = messages[i + 1] + if isinstance(next_msg, ModelRequest): + next_request = next_msg + + repaired_response, repaired_request, n_repairs = _repair_response_request_pair( + msg, + next_request, + orphan_call_content=orphan_call_content, + ) + repairs_made += n_repairs + + if repaired_response is not None: + repaired.append(repaired_response) + if repaired_request is not None: + repaired.append(repaired_request) + # Skip the original request since we already processed it. + i += 2 + continue + + i += 1 + else: + # ModelRequest not preceded by a ModelResponse -- pass through. + repaired.append(msg) + i += 1 + + if warn and repairs_made: + warnings.warn( + f'ToolOrphanRepair: repaired {repairs_made} orphaned tool call/result part(s) in message history.', + UserWarning, + stacklevel=2, + ) + + return repaired + + +def _repair_response_request_pair( + response: ModelResponse, + request: ModelRequest | None, + *, + orphan_call_content: str, +) -> tuple[ModelResponse | None, ModelRequest | None, int]: + """Repair a (response, request) pair, returning the repaired versions. + + Returns ``(repaired_response, repaired_request, repair_count)``. + Either element may be ``None`` if the message was dropped entirely. + """ + repairs = 0 + + # --- Phase 1: Repair orphaned builtin tool calls within the response --- + response, builtin_repairs = _repair_builtin_tool_calls(response, orphan_call_content) + repairs += builtin_repairs + + # --- Phase 2: Collect regular tool call ids from the response --- + call_ids: set[str] = set() + call_id_to_name: dict[str, str] = {} + for part in response.parts: + if isinstance(part, ToolCallPart): + call_ids.add(part.tool_call_id) + call_id_to_name[part.tool_call_id] = part.tool_name + + # If no regular tool calls, nothing else to repair. + if not call_ids: + return response, request, repairs + + # --- Phase 3: Handle trailing response with no following request --- + if request is None: + has_non_call_content = any(not isinstance(p, ToolCallPart) for p in response.parts) + if has_non_call_content: + # Keep the response but strip the dangling tool call parts. + new_resp_parts: list[ModelResponsePart] = [p for p in response.parts if not isinstance(p, ToolCallPart)] + for cid in sorted(call_ids): + logger.debug('Stripped orphaned tool call %r from trailing response (text content kept)', cid) + repairs += len(call_ids) + return replace(response, parts=new_resp_parts), None, repairs + else: + # Response is nothing but unmatched tool calls -- drop it entirely. + logger.debug( + 'Dropped trailing response containing only orphaned tool calls: %s', + ', '.join(sorted(call_ids)), + ) + repairs += len(call_ids) + return None, None, repairs + + # --- Phase 4: Strip orphaned returns from the request --- + matched_ids: set[str] = set() + kept_parts: list[ModelRequestPart] = [] + + for part in request.parts: + if isinstance(part, ToolReturnPart | RetryPromptPart): + if part.tool_call_id in call_ids: + matched_ids.add(part.tool_call_id) + kept_parts.append(part) + else: + part_type = 'RetryPromptPart' if isinstance(part, RetryPromptPart) else 'ToolReturnPart' + logger.debug( + 'Stripped orphaned %s for tool_call_id %r (no matching call in preceding response)', + part_type, + part.tool_call_id, + ) + repairs += 1 + else: + kept_parts.append(part) + + # --- Phase 5: Inject synthetic returns for orphaned calls --- + orphaned_call_ids = call_ids - matched_ids + for call_id in sorted(orphaned_call_ids): + logger.debug( + 'Injected synthetic ToolReturnPart for orphaned call %r (tool %r)', + call_id, + call_id_to_name[call_id], + ) + kept_parts.append( + ToolReturnPart( + tool_name=call_id_to_name[call_id], + content=orphan_call_content, + tool_call_id=call_id, + ) + ) + repairs += 1 + + # --- Phase 6: Ensure the request has non-system parts --- + non_system_parts = [p for p in kept_parts if not isinstance(p, SystemPromptPart)] + if not non_system_parts: # pragma: no cover – defensive; Phase 5 always injects non-system parts + logger.debug('Inserted placeholder UserPromptPart to maintain message alternation') + kept_parts.append(UserPromptPart(content='Continue.')) + repairs += 1 + + return response, replace(request, parts=kept_parts), repairs + + +def _repair_builtin_tool_calls( + response: ModelResponse, + orphan_call_content: str, +) -> tuple[ModelResponse, int]: + """Inject synthetic ``BuiltinToolReturnPart`` for orphaned ``BuiltinToolCallPart`` parts. + + Builtin tool calls and returns both live inside the same ``ModelResponse``. + """ + builtin_call_ids: dict[str, str] = {} # call_id -> tool_name + builtin_return_ids: set[str] = set() + + for part in response.parts: + if isinstance(part, BuiltinToolCallPart): + builtin_call_ids[part.tool_call_id] = part.tool_name + elif isinstance(part, BuiltinToolReturnPart): + builtin_return_ids.add(part.tool_call_id) + + orphaned = set(builtin_call_ids) - builtin_return_ids + if not orphaned: + return response, 0 + + new_parts: list[ModelResponsePart] = list(response.parts) + for call_id in sorted(orphaned): + logger.debug( + 'Injected synthetic BuiltinToolReturnPart for orphaned builtin call %r (tool %r)', + call_id, + builtin_call_ids[call_id], + ) + new_parts.append( + BuiltinToolReturnPart( + tool_name=builtin_call_ids[call_id], + content=orphan_call_content, + tool_call_id=call_id, + ) + ) + + return replace(response, parts=new_parts), len(orphaned) diff --git a/tests/test_tool_orphan_repair.py b/tests/test_tool_orphan_repair.py new file mode 100644 index 0000000..c9f118a --- /dev/null +++ b/tests/test_tool_orphan_repair.py @@ -0,0 +1,658 @@ +"""Tests for ToolOrphanRepair capability.""" + +from __future__ import annotations + +import logging +import warnings + +import pytest +from pydantic_ai.messages import ( + BuiltinToolCallPart, + BuiltinToolReturnPart, + ModelMessage, + ModelRequest, + ModelResponse, + RetryPromptPart, + SystemPromptPart, + TextPart, + ToolCallPart, + ToolReturnPart, + UserPromptPart, +) + +from pydantic_harness.tool_orphan_repair import ( + ToolOrphanRepair, + _repair_messages, # pyright: ignore[reportPrivateUsage] +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +ModelRequestPart = SystemPromptPart | UserPromptPart | ToolReturnPart | RetryPromptPart + + +def _user_request(text: str = 'hello') -> ModelRequest: + return ModelRequest(parts=[UserPromptPart(content=text)]) + + +def _tool_call_response(*calls: tuple[str, str]) -> ModelResponse: + """Create a response with ToolCallParts: (tool_name, tool_call_id).""" + return ModelResponse(parts=[ToolCallPart(tool_name=n, args='{}', tool_call_id=cid) for n, cid in calls]) + + +def _tool_return_request(*returns: tuple[str, str], extra_parts: list[ModelRequestPart] | None = None) -> ModelRequest: + """Create a request with ToolReturnParts: (tool_name, tool_call_id).""" + parts: list[ModelRequestPart] = [ToolReturnPart(tool_name=n, content='ok', tool_call_id=cid) for n, cid in returns] + if extra_parts: # pragma: no cover – convenience parameter unused so far + parts.extend(extra_parts) + return ModelRequest(parts=parts) + + +# --------------------------------------------------------------------------- +# No-op / passthrough +# --------------------------------------------------------------------------- + + +class TestNoRepairsNeeded: + def test_empty_messages(self) -> None: + result: list[ModelMessage] = _repair_messages([], warn=False) + assert result == [] + + def test_single_user_request(self) -> None: + msgs: list[ModelMessage] = [_user_request()] + assert _repair_messages(msgs, warn=False) == msgs + + def test_clean_tool_round_trip(self) -> None: + """Response with tool call followed by request with matching return -- no repair.""" + msgs: list[ModelMessage] = [ + _user_request(), + _tool_call_response(('get_weather', 'call_1')), + _tool_return_request(('get_weather', 'call_1')), + ] + result = _repair_messages(msgs, warn=False) + assert len(result) == 3 + # All messages pass through unchanged. + assert result[0] is msgs[0] + assert result[1] is msgs[1] + + def test_response_without_tool_calls(self) -> None: + """A plain text response needs no repair.""" + msgs: list[ModelMessage] = [ + _user_request(), + ModelResponse(parts=[TextPart(content='Sure!')]), + ] + result = _repair_messages(msgs, warn=False) + assert len(result) == 2 + + def test_clean_builtin_tool_round_trip(self) -> None: + """BuiltinToolCallPart with matching BuiltinToolReturnPart in same response.""" + msgs: list[ModelMessage] = [ + _user_request(), + ModelResponse( + parts=[ + BuiltinToolCallPart(tool_name='code_exec', args='{}', tool_call_id='bc_1'), + BuiltinToolReturnPart(tool_name='code_exec', content='result', tool_call_id='bc_1'), + ] + ), + ] + result = _repair_messages(msgs, warn=False) + assert len(result) == 2 + + +# --------------------------------------------------------------------------- +# Orphaned tool calls (call without matching return) +# --------------------------------------------------------------------------- + + +class TestOrphanedToolCalls: + def test_injects_synthetic_return(self) -> None: + """Orphaned ToolCallPart gets a synthetic ToolReturnPart in the next request.""" + msgs: list[ModelMessage] = [ + _user_request(), + _tool_call_response(('get_weather', 'call_1')), + # Request has no return for call_1. + _user_request('what now?'), + ] + result = _repair_messages(msgs, warn=False) + assert len(result) == 3 + + repaired_request = result[2] + assert isinstance(repaired_request, ModelRequest) + return_parts = [p for p in repaired_request.parts if isinstance(p, ToolReturnPart)] + assert len(return_parts) == 1 + assert return_parts[0].tool_call_id == 'call_1' + assert return_parts[0].tool_name == 'get_weather' + assert return_parts[0].content == 'Tool call was not completed.' + + def test_injects_return_for_multiple_orphaned_calls(self) -> None: + msgs: list[ModelMessage] = [ + _user_request(), + _tool_call_response(('tool_a', 'ca'), ('tool_b', 'cb')), + _user_request('continue'), + ] + result = _repair_messages(msgs, warn=False) + repaired_request = result[2] + assert isinstance(repaired_request, ModelRequest) + return_parts = [p for p in repaired_request.parts if isinstance(p, ToolReturnPart)] + assert len(return_parts) == 2 + assert {p.tool_call_id for p in return_parts} == {'ca', 'cb'} + + def test_partial_match_injects_only_missing(self) -> None: + """One call matched, one orphaned -- only the orphan gets a synthetic return.""" + msgs: list[ModelMessage] = [ + _user_request(), + _tool_call_response(('tool_a', 'ca'), ('tool_b', 'cb')), + _tool_return_request(('tool_a', 'ca')), + ] + result = _repair_messages(msgs, warn=False) + repaired_request = result[2] + assert isinstance(repaired_request, ModelRequest) + return_parts = [p for p in repaired_request.parts if isinstance(p, ToolReturnPart)] + assert len(return_parts) == 2 + ids = {p.tool_call_id for p in return_parts} + assert ids == {'ca', 'cb'} + + def test_retry_prompt_counts_as_match(self) -> None: + """A RetryPromptPart with matching tool_call_id counts as a match.""" + msgs: list[ModelMessage] = [ + _user_request(), + _tool_call_response(('get_weather', 'call_1')), + ModelRequest( + parts=[ + RetryPromptPart(content='bad args', tool_name='get_weather', tool_call_id='call_1'), + ] + ), + ] + result = _repair_messages(msgs, warn=False) + assert len(result) == 3 + repaired_request = result[2] + assert isinstance(repaired_request, ModelRequest) + # No synthetic return injected -- RetryPromptPart matched the call. + return_parts = [p for p in repaired_request.parts if isinstance(p, ToolReturnPart)] + assert len(return_parts) == 0 + + def test_custom_orphan_content(self) -> None: + msgs: list[ModelMessage] = [ + _user_request(), + _tool_call_response(('fetch', 'c1')), + _user_request('next'), + ] + result = _repair_messages(msgs, orphan_call_content='timed out', warn=False) + repaired_request = result[2] + assert isinstance(repaired_request, ModelRequest) + return_parts = [p for p in repaired_request.parts if isinstance(p, ToolReturnPart)] + assert return_parts[0].content == 'timed out' + + +# --------------------------------------------------------------------------- +# Orphaned builtin tool calls +# --------------------------------------------------------------------------- + + +class TestOrphanedBuiltinToolCalls: + def test_injects_builtin_return_in_same_response(self) -> None: + """Orphaned BuiltinToolCallPart gets a BuiltinToolReturnPart in the same response.""" + msgs: list[ModelMessage] = [ + _user_request(), + ModelResponse( + parts=[ + BuiltinToolCallPart(tool_name='code_exec', args='print(1)', tool_call_id='bc_1'), + ] + ), + ] + result = _repair_messages(msgs, warn=False) + assert len(result) == 2 + repaired_response = result[1] + assert isinstance(repaired_response, ModelResponse) + builtin_returns = [p for p in repaired_response.parts if isinstance(p, BuiltinToolReturnPart)] + assert len(builtin_returns) == 1 + assert builtin_returns[0].tool_call_id == 'bc_1' + assert builtin_returns[0].tool_name == 'code_exec' + assert builtin_returns[0].content == 'Tool call was not completed.' + + +# --------------------------------------------------------------------------- +# Orphaned tool returns (return without matching call) +# --------------------------------------------------------------------------- + + +class TestOrphanedToolReturns: + def test_strips_return_with_no_matching_call(self) -> None: + """ToolReturnPart whose call_id doesn't match any call is stripped.""" + msgs: list[ModelMessage] = [ + _user_request(), + _tool_call_response(('get_weather', 'call_1')), + ModelRequest( + parts=[ + ToolReturnPart(tool_name='get_weather', content='ok', tool_call_id='call_1'), + ToolReturnPart(tool_name='ghost', content='orphaned', tool_call_id='no_match'), + ] + ), + ] + result = _repair_messages(msgs, warn=False) + repaired_request = result[2] + assert isinstance(repaired_request, ModelRequest) + return_parts = [p for p in repaired_request.parts if isinstance(p, ToolReturnPart)] + assert len(return_parts) == 1 + assert return_parts[0].tool_call_id == 'call_1' + + def test_strips_retry_prompt_with_no_matching_call(self) -> None: + """RetryPromptPart whose call_id doesn't match any call is stripped.""" + msgs: list[ModelMessage] = [ + _user_request(), + _tool_call_response(('get_weather', 'call_1')), + ModelRequest( + parts=[ + ToolReturnPart(tool_name='get_weather', content='ok', tool_call_id='call_1'), + RetryPromptPart(content='retry me', tool_name='phantom', tool_call_id='no_match'), + ] + ), + ] + result = _repair_messages(msgs, warn=False) + repaired_request = result[2] + assert isinstance(repaired_request, ModelRequest) + retry_parts = [p for p in repaired_request.parts if isinstance(p, RetryPromptPart)] + assert len(retry_parts) == 0 + + +# --------------------------------------------------------------------------- +# Trailing response with unmatched tool calls +# --------------------------------------------------------------------------- + + +class TestTrailingResponse: + def test_drops_trailing_response_with_only_tool_calls(self) -> None: + """A trailing response containing only unmatched tool calls is dropped entirely.""" + msgs: list[ModelMessage] = [ + _user_request(), + _tool_call_response(('fetch', 'c1')), + ] + result = _repair_messages(msgs, warn=False) + assert len(result) == 1 + assert isinstance(result[0], ModelRequest) + + def test_keeps_trailing_response_with_text_strips_calls(self) -> None: + """Trailing response with text + tool calls: keep text, strip calls.""" + msgs: list[ModelMessage] = [ + _user_request(), + ModelResponse( + parts=[ + TextPart(content='Let me check...'), + ToolCallPart(tool_name='fetch', args='{}', tool_call_id='c1'), + ] + ), + ] + result = _repair_messages(msgs, warn=False) + assert len(result) == 2 + repaired_response = result[1] + assert isinstance(repaired_response, ModelResponse) + assert len(repaired_response.parts) == 1 + assert isinstance(repaired_response.parts[0], TextPart) + + +# --------------------------------------------------------------------------- +# Empty request after stripping +# --------------------------------------------------------------------------- + + +class TestEmptyRequestPlaceholder: + def test_orphaned_return_replaced_by_synthetic(self) -> None: + """Request with only an orphaned return: the orphan is stripped and a synthetic return injected.""" + msgs: list[ModelMessage] = [ + _user_request(), + _tool_call_response(('get_weather', 'call_1')), + # Request has a return for the wrong id -- orphaned. + ModelRequest( + parts=[ + ToolReturnPart(tool_name='ghost', content='orphaned', tool_call_id='wrong_id'), + ] + ), + ] + result = _repair_messages(msgs, warn=False) + repaired_request = result[2] + assert isinstance(repaired_request, ModelRequest) + # The orphaned return was stripped, but a synthetic return for call_1 was injected. + return_parts = [p for p in repaired_request.parts if isinstance(p, ToolReturnPart)] + assert len(return_parts) == 1 + assert return_parts[0].tool_call_id == 'call_1' + assert return_parts[0].content == 'Tool call was not completed.' + + def test_system_prompt_only_request_gets_synthetic_return(self) -> None: + """A request with SystemPromptPart + orphaned return: synthetic return injected, system prompt kept.""" + msgs: list[ModelMessage] = [ + _user_request(), + _tool_call_response(('fetch', 'c1')), + ModelRequest( + parts=[ + SystemPromptPart(content='You are helpful.'), + ToolReturnPart(tool_name='ghost', content='orphaned', tool_call_id='wrong_id'), + ] + ), + ] + result = _repair_messages(msgs, warn=False) + repaired_request = result[2] + assert isinstance(repaired_request, ModelRequest) + # System prompt kept, orphaned return stripped, synthetic return injected. + system_parts = [p for p in repaired_request.parts if isinstance(p, SystemPromptPart)] + assert len(system_parts) == 1 + return_parts = [p for p in repaired_request.parts if isinstance(p, ToolReturnPart)] + assert len(return_parts) == 1 + assert return_parts[0].tool_call_id == 'c1' + + +# --------------------------------------------------------------------------- +# Warning behavior +# --------------------------------------------------------------------------- + + +class TestWarnings: + def test_emits_warning_when_repairs_made(self) -> None: + msgs: list[ModelMessage] = [ + _user_request(), + _tool_call_response(('fetch', 'c1')), + _user_request('next'), + ] + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + _repair_messages(msgs, warn=True) + + assert len(w) == 1 + assert 'ToolOrphanRepair' in str(w[0].message) + assert '1 orphaned' in str(w[0].message) + + def test_no_warning_when_clean(self) -> None: + msgs: list[ModelMessage] = [ + _user_request(), + ModelResponse(parts=[TextPart(content='hi')]), + ] + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + _repair_messages(msgs, warn=True) + + assert len(w) == 0 + + def test_no_warning_when_disabled(self) -> None: + msgs: list[ModelMessage] = [ + _user_request(), + _tool_call_response(('fetch', 'c1')), + _user_request('next'), + ] + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + _repair_messages(msgs, warn=False) + + assert len(w) == 0 + + +# --------------------------------------------------------------------------- +# Complex / multi-turn scenarios +# --------------------------------------------------------------------------- + + +class TestMultiTurnScenarios: + def test_multiple_response_request_pairs(self) -> None: + """Two consecutive tool round-trips, second one orphaned.""" + msgs: list[ModelMessage] = [ + _user_request(), + # First round-trip: clean. + _tool_call_response(('tool_a', 'ca')), + _tool_return_request(('tool_a', 'ca')), + # Second round-trip: orphaned. + _tool_call_response(('tool_b', 'cb')), + _user_request('done'), + ] + result = _repair_messages(msgs, warn=False) + assert len(result) == 5 + + # The second request should have a synthetic return. + repaired = result[4] + assert isinstance(repaired, ModelRequest) + return_parts = [p for p in repaired.parts if isinstance(p, ToolReturnPart)] + assert len(return_parts) == 1 + assert return_parts[0].tool_call_id == 'cb' + + def test_mixed_orphaned_and_clean_in_same_request(self) -> None: + """Request has one valid return and one orphaned return.""" + msgs: list[ModelMessage] = [ + _user_request(), + _tool_call_response(('tool_a', 'ca')), + ModelRequest( + parts=[ + ToolReturnPart(tool_name='tool_a', content='ok', tool_call_id='ca'), + ToolReturnPart(tool_name='phantom', content='bad', tool_call_id='no_match'), + UserPromptPart(content='next step'), + ] + ), + ] + result = _repair_messages(msgs, warn=False) + repaired = result[2] + assert isinstance(repaired, ModelRequest) + return_parts = [p for p in repaired.parts if isinstance(p, ToolReturnPart)] + assert len(return_parts) == 1 + assert return_parts[0].tool_call_id == 'ca' + user_parts = [p for p in repaired.parts if isinstance(p, UserPromptPart)] + assert len(user_parts) == 1 + + def test_request_not_following_response_passes_through(self) -> None: + """A ModelRequest at position 0 (not preceded by a response) passes through.""" + msgs: list[ModelMessage] = [ + _user_request('first'), + ModelResponse(parts=[TextPart(content='hi')]), + _user_request('second'), + ] + result = _repair_messages(msgs, warn=False) + assert len(result) == 3 + + def test_builtin_and_regular_orphans_in_same_response(self) -> None: + """Response has both an orphaned builtin call and an orphaned regular call.""" + msgs: list[ModelMessage] = [ + _user_request(), + ModelResponse( + parts=[ + BuiltinToolCallPart(tool_name='code_exec', args='x=1', tool_call_id='bc_1'), + ToolCallPart(tool_name='get_weather', args='{}', tool_call_id='tc_1'), + ] + ), + _user_request('continue'), + ] + result = _repair_messages(msgs, warn=False) + assert len(result) == 3 + + # Response should have the builtin call + synthetic builtin return. + repaired_response = result[1] + assert isinstance(repaired_response, ModelResponse) + builtin_returns = [p for p in repaired_response.parts if isinstance(p, BuiltinToolReturnPart)] + assert len(builtin_returns) == 1 + assert builtin_returns[0].tool_call_id == 'bc_1' + + # Request should have a synthetic regular return. + repaired_request = result[2] + assert isinstance(repaired_request, ModelRequest) + return_parts = [p for p in repaired_request.parts if isinstance(p, ToolReturnPart)] + assert len(return_parts) == 1 + assert return_parts[0].tool_call_id == 'tc_1' + + def test_preserves_existing_user_prompt_parts(self) -> None: + """Existing UserPromptPart in a request is preserved alongside injected returns.""" + msgs: list[ModelMessage] = [ + _user_request(), + _tool_call_response(('fetch', 'c1')), + ModelRequest( + parts=[ + UserPromptPart(content='user text'), + ] + ), + ] + result = _repair_messages(msgs, warn=False) + repaired = result[2] + assert isinstance(repaired, ModelRequest) + user_parts = [p for p in repaired.parts if isinstance(p, UserPromptPart)] + assert len(user_parts) == 1 + assert user_parts[0].content == 'user text' + return_parts = [p for p in repaired.parts if isinstance(p, ToolReturnPart)] + assert len(return_parts) == 1 + + +# --------------------------------------------------------------------------- +# Debug logging +# --------------------------------------------------------------------------- + + +class TestDebugLogging: + def test_logs_injected_synthetic_return(self, caplog: pytest.LogCaptureFixture) -> None: + msgs: list[ModelMessage] = [ + _user_request(), + _tool_call_response(('get_weather', 'call_1')), + _user_request('next'), + ] + with caplog.at_level(logging.DEBUG, logger='pydantic_harness.tool_orphan_repair'): + _repair_messages(msgs, warn=False) + assert any('Injected synthetic ToolReturnPart' in r.message and 'call_1' in r.message for r in caplog.records) + + def test_logs_stripped_orphaned_return(self, caplog: pytest.LogCaptureFixture) -> None: + msgs: list[ModelMessage] = [ + _user_request(), + _tool_call_response(('get_weather', 'call_1')), + ModelRequest( + parts=[ + ToolReturnPart(tool_name='get_weather', content='ok', tool_call_id='call_1'), + ToolReturnPart(tool_name='ghost', content='orphaned', tool_call_id='no_match'), + ] + ), + ] + with caplog.at_level(logging.DEBUG, logger='pydantic_harness.tool_orphan_repair'): + _repair_messages(msgs, warn=False) + assert any('Stripped orphaned ToolReturnPart' in r.message and 'no_match' in r.message for r in caplog.records) + + def test_logs_stripped_orphaned_retry_prompt(self, caplog: pytest.LogCaptureFixture) -> None: + msgs: list[ModelMessage] = [ + _user_request(), + _tool_call_response(('get_weather', 'call_1')), + ModelRequest( + parts=[ + ToolReturnPart(tool_name='get_weather', content='ok', tool_call_id='call_1'), + RetryPromptPart(content='retry', tool_name='phantom', tool_call_id='no_match'), + ] + ), + ] + with caplog.at_level(logging.DEBUG, logger='pydantic_harness.tool_orphan_repair'): + _repair_messages(msgs, warn=False) + assert any('Stripped orphaned RetryPromptPart' in r.message and 'no_match' in r.message for r in caplog.records) + + def test_logs_dropped_trailing_response(self, caplog: pytest.LogCaptureFixture) -> None: + msgs: list[ModelMessage] = [ + _user_request(), + _tool_call_response(('fetch', 'c1')), + ] + with caplog.at_level(logging.DEBUG, logger='pydantic_harness.tool_orphan_repair'): + _repair_messages(msgs, warn=False) + assert any('Dropped trailing response' in r.message and 'c1' in r.message for r in caplog.records) + + def test_logs_stripped_trailing_tool_calls(self, caplog: pytest.LogCaptureFixture) -> None: + msgs: list[ModelMessage] = [ + _user_request(), + ModelResponse( + parts=[ + TextPart(content='Let me check...'), + ToolCallPart(tool_name='fetch', args='{}', tool_call_id='c1'), + ] + ), + ] + with caplog.at_level(logging.DEBUG, logger='pydantic_harness.tool_orphan_repair'): + _repair_messages(msgs, warn=False) + assert any('Stripped orphaned tool call' in r.message and 'c1' in r.message for r in caplog.records) + + def test_logs_builtin_tool_call_repair(self, caplog: pytest.LogCaptureFixture) -> None: + msgs: list[ModelMessage] = [ + _user_request(), + ModelResponse( + parts=[ + BuiltinToolCallPart(tool_name='code_exec', args='print(1)', tool_call_id='bc_1'), + ] + ), + ] + with caplog.at_level(logging.DEBUG, logger='pydantic_harness.tool_orphan_repair'): + _repair_messages(msgs, warn=False) + assert any( + 'Injected synthetic BuiltinToolReturnPart' in r.message and 'bc_1' in r.message for r in caplog.records + ) + + def test_logs_placeholder_insertion(self, caplog: pytest.LogCaptureFixture) -> None: + """When all parts are stripped and only system prompt remains, a placeholder is logged.""" + msgs: list[ModelMessage] = [ + _user_request(), + _tool_call_response(('fetch', 'c1')), + ModelRequest( + parts=[ + SystemPromptPart(content='You are helpful.'), + ToolReturnPart(tool_name='ghost', content='orphaned', tool_call_id='wrong_id'), + ] + ), + ] + with caplog.at_level(logging.DEBUG, logger='pydantic_harness.tool_orphan_repair'): + _repair_messages(msgs, warn=False) + # The synthetic return for c1 provides a non-system part, so the placeholder + # is NOT needed here. Instead, verify the orphaned return stripping was logged. + assert any('Stripped orphaned ToolReturnPart' in r.message for r in caplog.records) + assert any('Injected synthetic ToolReturnPart' in r.message for r in caplog.records) + + +# --------------------------------------------------------------------------- +# before_model_request integration +# --------------------------------------------------------------------------- + + +class TestBeforeModelRequest: + @pytest.mark.anyio + async def test_before_model_request_repairs_messages(self) -> None: + """The capability's ``before_model_request`` hook delegates to ``_repair_messages``.""" + from unittest.mock import MagicMock + + from pydantic_ai.models import ModelRequestContext + + cap: ToolOrphanRepair = ToolOrphanRepair(warn=False) + + msgs: list[ModelMessage] = [ + _user_request(), + _tool_call_response(('fetch', 'c1')), + # Request is missing the return for c1 — should be injected. + _user_request('follow-up'), + ] + + mock_ctx = MagicMock() + request_context = MagicMock(spec=ModelRequestContext) + request_context.messages = list(msgs) + + result = await cap.before_model_request(mock_ctx, request_context) + assert result is request_context + # A synthetic ToolReturnPart for c1 should have been injected. + repaired = request_context.messages + assert any( + isinstance(p, ToolReturnPart) and p.tool_call_id == 'c1' + for msg in repaired + if isinstance(msg, ModelRequest) + for p in msg.parts + ) + + +# --------------------------------------------------------------------------- +# Consecutive ModelResponse messages (branch coverage for line 135->138) +# --------------------------------------------------------------------------- + + +class TestConsecutiveResponses: + def test_consecutive_model_responses(self) -> None: + """Two consecutive ModelResponses (no interleaved request) are handled.""" + msgs: list[ModelMessage] = [ + _user_request(), + # First response with an orphaned tool call. + _tool_call_response(('alpha', 'a1')), + # Second response immediately follows (no request in between). + ModelResponse(parts=[TextPart(content='some text')]), + _user_request('next'), + ] + repaired = _repair_messages(msgs, warn=False) + # The first response's orphaned call should be dropped since + # the next message is a ModelResponse, not a ModelRequest. + assert len(repaired) == 3 # user, text-response, user