diff --git a/pydantic_ai_slim/pydantic_ai/history_processors.py b/pydantic_ai_slim/pydantic_ai/history_processors.py
new file mode 100644
index 0000000000..abb8a3253a
--- /dev/null
+++ b/pydantic_ai_slim/pydantic_ai/history_processors.py
@@ -0,0 +1,164 @@
+"""Built-in history processor functions for common message history repair tasks.
+
+These functions can be passed directly to `Agent(history_processors=[...])` or
+used with `capabilities.HistoryProcessor(processor=...)`.
+"""
+
+from __future__ import annotations
+
+import logging
+from dataclasses import replace
+
+from pydantic_ai import messages as _messages
+
+__all__ = ('repair_orphaned_tool_parts',)
+
+logger = logging.getLogger(__name__)
+
+
+def repair_orphaned_tool_parts(
+    messages: list[_messages.ModelMessage],
+) -> list[_messages.ModelMessage]:
+    """Remove orphaned tool call/return parts from message history.
+
+    Multi-turn agent conversations can accumulate structurally invalid history
+    when tool calls and their corresponding results become mismatched. Common
+    causes include streaming timeouts, deferred tool result drops, and history
+    trimming by other processors.
+
+    Providers like Anthropic strictly enforce that every `ToolCallPart` has a
+    matching `ToolReturnPart` (or `RetryPromptPart`) and vice versa; orphaned
+    entries cause 400 errors.
+
+    This processor performs a two-pass repair:
+
+    1. **Orphaned returns/retries**: any `ToolReturnPart` or `RetryPromptPart`
+       whose `tool_call_id` matches no `ToolCallPart` in the history is
+       removed. Retry prompts without a `tool_name` (output validation
+       retries) are always kept.
+    2. **Orphaned calls**: any `ToolCallPart` whose `tool_call_id` matches no
+       `ToolReturnPart` or `RetryPromptPart` in the history is removed.
+
+    Empty messages (all parts removed) are dropped entirely.
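+
+    For illustration, a minimal sketch of the repair on a malformed history
+    (the tool names and IDs below are hypothetical):
+
+    ```python
+    from pydantic_ai.history_processors import repair_orphaned_tool_parts
+    from pydantic_ai.messages import (
+        ModelRequest,
+        ModelResponse,
+        ToolCallPart,
+        ToolReturnPart,
+    )
+
+    history = [
+        # A call that never produced a result (e.g. a streaming timeout):
+        ModelResponse(parts=[ToolCallPart(tool_name='search', tool_call_id='call_a')]),
+        # A return whose call was trimmed away by another processor:
+        ModelRequest(parts=[ToolReturnPart(tool_name='fetch', content='...', tool_call_id='call_b')]),
+    ]
+    # Both parts are orphaned, so both are removed and the emptied
+    # messages are dropped.
+    assert repair_orphaned_tool_parts(history) == []
+    ```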
+
+    Example:
+        ```python
+        from pydantic_ai import Agent
+        from pydantic_ai.history_processors import repair_orphaned_tool_parts
+
+        agent = Agent('openai:gpt-5.2', history_processors=[repair_orphaned_tool_parts])
+        ```
+    """
+    call_ids = _collect_tool_call_ids(messages)
+    return_ids = _collect_tool_return_ids(messages)
+
+    repaired: list[_messages.ModelMessage] = []
+    for message in messages:
+        if isinstance(message, _messages.ModelRequest):
+            result = _repair_request(message, call_ids)
+        else:
+            result = _repair_response(message, return_ids)
+        if result is not None:
+            repaired.append(result)
+
+    return repaired
+
+
+def _collect_tool_call_ids(messages: list[_messages.ModelMessage]) -> set[str]:
+    """Collect all tool_call_ids from ToolCallPart in ModelResponse messages."""
+    ids: set[str] = set()
+    for message in messages:
+        if isinstance(message, _messages.ModelResponse):
+            for part in message.parts:
+                if isinstance(part, _messages.ToolCallPart) and part.tool_call_id:
+                    ids.add(part.tool_call_id)
+    return ids
+
+
+def _collect_tool_return_ids(messages: list[_messages.ModelMessage]) -> set[str]:
+    """Collect all tool_call_ids from ToolReturnPart/RetryPromptPart in ModelRequest messages."""
+    ids: set[str] = set()
+    for message in messages:
+        if isinstance(message, _messages.ModelRequest):
+            for part in message.parts:
+                if isinstance(part, (_messages.ToolReturnPart, _messages.RetryPromptPart)) and part.tool_call_id:
+                    ids.add(part.tool_call_id)
+    return ids
+
+
+def _is_orphaned_request_part(part: _messages.ModelRequestPart, call_ids: set[str]) -> bool:
+    """Check if a request part is orphaned (no matching tool call)."""
+    if isinstance(part, _messages.ToolReturnPart):
+        return part.tool_call_id not in call_ids
+    if isinstance(part, _messages.RetryPromptPart):
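+        # Only tool-bound retries can be orphaned; a retry without a
+        # tool_name (e.g. an output validation retry) references no call.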
+        return part.tool_name is not None and part.tool_call_id not in call_ids
+    return False
+
+
+def _repair_request(
+    message: _messages.ModelRequest,
+    call_ids: set[str],
+) -> _messages.ModelRequest | None:
+    """Remove orphaned ToolReturnPart/RetryPromptPart from a ModelRequest."""
+    kept: list[_messages.ModelRequestPart] = []
+    for part in message.parts:
+        if _is_orphaned_request_part(part, call_ids):
+            logger.debug(
+                'Removing orphaned %s with tool_call_id=%r (no matching ToolCallPart)',
+                type(part).__name__,
+                getattr(part, 'tool_call_id', None),
+            )
+            continue
+        kept.append(part)
+    if not kept:
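+        # Every part was orphaned; signal the caller to drop this message.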
+        return None
+    if len(kept) != len(message.parts):
+        return replace(message, parts=kept)
+    return message
+
+
+def _repair_response(
+    message: _messages.ModelResponse,
+    return_ids: set[str],
+) -> _messages.ModelResponse | None:
+    """Remove orphaned ToolCallPart from a ModelResponse."""
+    kept: list[_messages.ModelResponsePart] = []
+    for part in message.parts:
+        if isinstance(part, _messages.ToolCallPart) and part.tool_call_id not in return_ids:
+            logger.debug(
+                'Removing orphaned ToolCallPart with tool_call_id=%r (no matching return)',
+                part.tool_call_id,
+            )
+            continue
+        kept.append(part)
+    if not kept:
+        return None
+    if len(kept) != len(message.parts):
+        return replace(message, parts=kept)
+    return message
diff --git a/pydantic_evals/pydantic_evals/evaluators/llm_as_a_judge.py b/pydantic_evals/pydantic_evals/evaluators/llm_as_a_judge.py
index 0b6fed34e9..f5158ae14d 100644
--- a/pydantic_evals/pydantic_evals/evaluators/llm_as_a_judge.py
+++ b/pydantic_evals/pydantic_evals/evaluators/llm_as_a_judge.py
@@ -27,26 +27,39 @@
 class GradingOutput(BaseModel, populate_by_name=True):
     """The output of a grading operation."""
 
-    reason: str
+    reason: str = Field(
+        description='A concise 1-2 sentence explanation of why the output passed or failed.',
+    )
     pass_: bool = Field(validation_alias='pass', serialization_alias='pass')
     score: float
 
 
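+# Shared by all four judge system prompts below; this steers the model away
+# from dumping its chain-of-thought into the `reason` field.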
+_JUDGE_REASON_INSTRUCTION = (
+    'The "reason" field must be a concise 1-2 sentence summary of your verdict. '
+    'Do not include your reasoning process, self-corrections, or re-checking in the reason. '
+    'State only the final conclusion.'
+)
+
+
 _judge_output_agent = Agent(
     name='judge_output',
     system_prompt=dedent(
-        """
-        You are grading output according to a user-specified rubric. If the statement in the rubric is true, then the output passes the test. You respond with a JSON object with this structure: {reason: string, pass: boolean, score: number}
+        f"""
+        You are grading output according to a user-specified rubric. If the statement in the rubric is true, then the output passes the test. You respond with a JSON object with this structure: {{reason: string, pass: boolean, score: number}}
+
+        {_JUDGE_REASON_INSTRUCTION}
 
         Examples:
 
         <Output>Hello world</Output>
         <Rubric>Content contains a greeting</Rubric>
-        {"reason": "the content contains the word 'Hello'", "pass": true, "score": 1.0}
+        {{"reason": "the content contains the word 'Hello'", "pass": true, "score": 1.0}}
 
         <Output>Avast ye swabs, repel the invaders!</Output>
         <Rubric>Does not speak like a pirate</Rubric>
-        {"reason": "'avast ye' is a common pirate term", "pass": false, "score": 0.0}
+        {{"reason": "'avast ye' is a common pirate term", "pass": false, "score": 0.0}}
     """
     ),
     output_type=GradingOutput,
@@ -73,20 +86,22 @@ async def judge_output(
 _judge_input_output_agent = Agent(
     name='judge_input_output',
     system_prompt=dedent(
-        """
-        You are grading output according to a user-specified rubric. If the statement in the rubric is true for the provided input and output, then the output passes the test. You respond with a JSON object with this structure: {reason: string, pass: boolean, score: number}
+        f"""
+        You are grading output according to a user-specified rubric. If the statement in the rubric is true for the provided input and output, then the output passes the test. You respond with a JSON object with this structure: {{reason: string, pass: boolean, score: number}}
+
+        {_JUDGE_REASON_INSTRUCTION}
 
         Examples:
 
         <Input>Hello world</Input>
         <Output>Hello</Output>
         <Rubric>Content contains a greeting word which is present in the input</Rubric>
-        {"reason": "the content contains the word 'Hello'", "pass": true, "score": 1.0}
+        {{"reason": "the content contains the word 'Hello'", "pass": true, "score": 1.0}}
 
         <Input>Pirate</Input>
         <Output>Avast ye swabs, repel the invaders!</Output>
         <Rubric>Does not speak in the style described by the input</Rubric>
-        {"reason": "'avast ye' is a common pirate term", "pass": false, "score": 0.0}
+        {{"reason": "'avast ye' is a common pirate term", "pass": false, "score": 0.0}}
     """
     ),
     output_type=GradingOutput,
@@ -115,8 +130,10 @@ async def judge_input_output(
 _judge_input_output_expected_agent = Agent(
     name='judge_input_output_expected',
     system_prompt=dedent(
-        """
-        You are grading output according to a user-specified rubric. If the statement in the rubric is true for the provided input, expected output, and output, then the output passes the test. You respond with a JSON object with this structure: {reason: string, pass: boolean, score: number}
+        f"""
+        You are grading output according to a user-specified rubric. If the statement in the rubric is true for the provided input, expected output, and output, then the output passes the test. You respond with a JSON object with this structure: {{reason: string, pass: boolean, score: number}}
+
+        {_JUDGE_REASON_INSTRUCTION}
 
         Examples:
 
@@ -124,13 +141,13 @@ async def judge_input_output(
         <ExpectedOutput>Blue</ExpectedOutput>
         <Output>Cerulean</Output>
         <Rubric>The output is consistent with the expected output but doesn't have to match exactly</Rubric>
-        {"reason": "'Cerulean' is a shade of blue", "pass": true, "score": 1.0}
+        {{"reason": "'Cerulean' is a shade of blue", "pass": true, "score": 1.0}}
 
         <Input>How many legs does a spider have?</Input>
         <ExpectedOutput>8</ExpectedOutput>
         <Output>Six</Output>
         <Rubric>The output is factually consistent with the expected output</Rubric>
-        {"reason": "Spiders have 8 legs", "pass": false, "score": 0.0}
+        {{"reason": "Spiders have 8 legs", "pass": false, "score": 0.0}}
     """
     ),
     output_type=GradingOutput,
@@ -162,20 +179,22 @@ async def judge_input_output_expected(
 _judge_output_expected_agent = Agent(
     name='judge_output_expected',
     system_prompt=dedent(
-        """
-        You are grading output according to a user-specified rubric. If the statement in the rubric is true for the provided expected output and output, then the output passes the test. You respond with a JSON object with this structure: {reason: string, pass: boolean, score: number}
+        f"""
+        You are grading output according to a user-specified rubric. If the statement in the rubric is true for the provided expected output and output, then the output passes the test. You respond with a JSON object with this structure: {{reason: string, pass: boolean, score: number}}
+
+        {_JUDGE_REASON_INSTRUCTION}
 
         Examples:
 
         <ExpectedOutput>Blue</ExpectedOutput>
         <Output>Cerulean</Output>
         <Rubric>The output should be a shade of the expected output color</Rubric>
-        {"reason": "'Cerulean' is a shade of blue", "pass": true, "score": 1.0}
+        {{"reason": "'Cerulean' is a shade of blue", "pass": true, "score": 1.0}}
 
         <ExpectedOutput>8</ExpectedOutput>
         <Output>Six</Output>
         <Rubric>The output should be a number written in words which matches the number written in digits in the expected output</Rubric>
-        {"reason": "The output is 'Six' which is a different number than 8", "pass": false, "score": 0.0}
+        {{"reason": "The output is 'Six' which is a different number than 8", "pass": false, "score": 0.0}}
     """
     ),
     output_type=GradingOutput,
diff --git a/tests/test_history_processors.py b/tests/test_history_processors.py
new file mode 100644
index 0000000000..6dcdd2515b
--- /dev/null
+++ b/tests/test_history_processors.py
@@ -0,0 +1,199 @@
+"""Tests for built-in history processor functions."""
+
+from __future__ import annotations
+
+from pydantic_ai.history_processors import repair_orphaned_tool_parts
+from pydantic_ai.messages import (
+    ModelRequest,
+    ModelResponse,
+    RetryPromptPart,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
+    UserPromptPart,
+)
+
+
+def test_no_changes_needed():
+    """Matched pairs pass through untouched."""
+    messages = [
+        ModelRequest(parts=[UserPromptPart(content='hello')]),
+        ModelResponse(parts=[ToolCallPart(tool_name='get_data', tool_call_id='call_1')]),
+        ModelRequest(parts=[ToolReturnPart(tool_name='get_data', content='result', tool_call_id='call_1')]),
+        ModelResponse(parts=[TextPart(content='done')]),
+    ]
+    result = repair_orphaned_tool_parts(messages)
+    assert len(result) == 4
+    assert result[0] == messages[0]
+    assert result[1] == messages[1]
+    assert result[2] == messages[2]
+    assert result[3] == messages[3]
+
+
+def test_orphaned_tool_return_removed():
+    """ToolReturnPart with no matching ToolCallPart is removed."""
+    messages = [
+        ModelRequest(parts=[UserPromptPart(content='hello')]),
+        ModelRequest(
+            parts=[
+                ToolReturnPart(tool_name='unknown', content='orphan', tool_call_id='call_missing'),
+                UserPromptPart(content='continue'),
+            ]
+        ),
+        ModelResponse(parts=[TextPart(content='ok')]),
+    ]
+    result = repair_orphaned_tool_parts(messages)
+    assert len(result) == 3
+    assert len(result[1].parts) == 1
+    assert isinstance(result[1].parts[0], UserPromptPart)
+
+
+def test_orphaned_retry_prompt_removed():
+    """RetryPromptPart with no matching ToolCallPart is removed."""
+    messages = [
+        ModelRequest(parts=[UserPromptPart(content='hello')]),
+        ModelRequest(
+            parts=[
+                RetryPromptPart(content='try again', tool_name='missing_tool', tool_call_id='call_gone'),
+            ]
+        ),
+        ModelResponse(parts=[TextPart(content='ok')]),
+    ]
+    result = repair_orphaned_tool_parts(messages)
+    assert len(result) == 2
+
+
+def test_orphaned_tool_call_removed():
+    """ToolCallPart with no matching ToolReturnPart or RetryPromptPart is removed."""
+    messages = [
+        ModelRequest(parts=[UserPromptPart(content='hello')]),
+        ModelResponse(
+            parts=[
+                TextPart(content='Let me call a tool'),
+                ToolCallPart(tool_name='timed_out', tool_call_id='call_orphan'),
+            ]
+        ),
+        ModelRequest(parts=[UserPromptPart(content='what happened?')]),
+        ModelResponse(parts=[TextPart(content='sorry about that')]),
+    ]
+    result = repair_orphaned_tool_parts(messages)
+    assert len(result) == 4
+    response = result[1]
+    assert isinstance(response, ModelResponse)
+    assert len(response.parts) == 1
+    assert isinstance(response.parts[0], TextPart)
+
+
+def test_empty_message_removed():
+    """Messages with all parts removed are dropped entirely."""
+    messages = [
+        ModelRequest(parts=[UserPromptPart(content='hello')]),
+        ModelResponse(parts=[ToolCallPart(tool_name='lost', tool_call_id='call_lost')]),
+        ModelRequest(parts=[ToolReturnPart(tool_name='ghost', content='data', tool_call_id='call_ghost')]),
+        ModelResponse(parts=[TextPart(content='end')]),
+    ]
+    result = repair_orphaned_tool_parts(messages)
+    assert len(result) == 2
+    assert isinstance(result[0].parts[0], UserPromptPart)
+    assert isinstance(result[1], ModelResponse)
+    assert isinstance(result[1].parts[0], TextPart)
+
+
+def test_multiple_matched_pairs():
+    """Multiple valid tool call/return pairs are preserved."""
+    messages = [
+        ModelRequest(parts=[UserPromptPart(content='do work')]),
+        ModelResponse(
+            parts=[
+                ToolCallPart(tool_name='a', tool_call_id='id_a'),
+                ToolCallPart(tool_name='b', tool_call_id='id_b'),
+            ]
+        ),
+        ModelRequest(
+            parts=[
+                ToolReturnPart(tool_name='a', content='result_a', tool_call_id='id_a'),
+                ToolReturnPart(tool_name='b', content='result_b', tool_call_id='id_b'),
+            ]
+        ),
+        ModelResponse(parts=[TextPart(content='all done')]),
+    ]
+    result = repair_orphaned_tool_parts(messages)
+    assert len(result) == 4
+    assert result == messages
+
+
+def test_mixed_orphans_and_valid():
+    """Only orphaned parts are removed; valid pairs remain."""
+    messages = [
+        ModelRequest(parts=[UserPromptPart(content='go')]),
+        ModelResponse(
+            parts=[
+                ToolCallPart(tool_name='valid', tool_call_id='id_ok'),
+                ToolCallPart(tool_name='orphan_call', tool_call_id='id_orphan'),
+            ]
+        ),
+        ModelRequest(
+            parts=[
+                ToolReturnPart(tool_name='valid', content='good', tool_call_id='id_ok'),
+                ToolReturnPart(tool_name='orphan_return', content='bad', tool_call_id='id_no_call'),
+            ]
+        ),
+        ModelResponse(parts=[TextPart(content='done')]),
+    ]
+    result = repair_orphaned_tool_parts(messages)
+    assert len(result) == 4
+
+    response = result[1]
+    assert isinstance(response, ModelResponse)
+    assert len(response.parts) == 1
+    assert isinstance(response.parts[0], ToolCallPart)
+    assert response.parts[0].tool_call_id == 'id_ok'
+
+    request = result[2]
+    assert isinstance(request, ModelRequest)
+    assert len(request.parts) == 1
+    assert isinstance(request.parts[0], ToolReturnPart)
+    assert request.parts[0].tool_call_id == 'id_ok'
+
+
+def test_retry_prompt_matches_call():
+    """RetryPromptPart with matching ToolCallPart is preserved."""
+    messages = [
+        ModelRequest(parts=[UserPromptPart(content='try')]),
+        ModelResponse(parts=[ToolCallPart(tool_name='flaky', tool_call_id='id_retry')]),
+        ModelRequest(parts=[RetryPromptPart(content='bad args', tool_name='flaky', tool_call_id='id_retry')]),
+        ModelResponse(parts=[TextPart(content='ok')]),
+    ]
+    result = repair_orphaned_tool_parts(messages)
+    assert len(result) == 4
+    assert result == messages
+
+
+def test_empty_history():
+    """Empty input returns empty output."""
+    assert repair_orphaned_tool_parts([]) == []
+
+
+def test_text_only_history():
+    """History with no tool parts passes through unchanged."""
+    messages = [
+        ModelRequest(parts=[UserPromptPart(content='hello')]),
+        ModelResponse(parts=[TextPart(content='hi')]),
+        ModelRequest(parts=[UserPromptPart(content='bye')]),
+        ModelResponse(parts=[TextPart(content='goodbye')]),
+    ]
+    result = repair_orphaned_tool_parts(messages)
+    assert result == messages
+
+
+def test_retry_prompt_without_tool_name_preserved():
+    """RetryPromptPart without tool_name (output validation retry) is kept."""
+    messages = [
+        ModelRequest(parts=[UserPromptPart(content='generate')]),
+        ModelResponse(parts=[TextPart(content='bad output')]),
+        ModelRequest(parts=[RetryPromptPart(content='validation failed')]),
+        ModelResponse(parts=[TextPart(content='better output')]),
+    ]
+    result = repair_orphaned_tool_parts(messages)
+    assert len(result) == 4
+    assert result == messages