diff --git a/PLAN.md b/PLAN.md new file mode 100644 index 0000000..2dcd617 --- /dev/null +++ b/PLAN.md @@ -0,0 +1,55 @@ +# StuckLoopDetection Capability + +Closes #71. + +## Summary + +A `StuckLoopDetection` capability that monitors agent tool-call patterns via +capability hooks and detects when the agent is stuck in a repetitive loop. + +## Detection scenarios + +1. **Repeated calls** -- the same tool is called with the same arguments N times + consecutively (tracked in `after_model_request`). +2. **Alternating calls** -- two distinct tool+args pairs alternate A-B-A-B for N + full cycles (tracked in `after_model_request`). +3. **No-op calls** -- the same tool returns the same result N times consecutively, + even if the arguments differ (tracked in `after_tool_execute`). + +N is configurable via `max_repeated_calls` (default 3). + +## Recovery actions + +| `action` | Behavior | +|----------|----------| +| `'warn'` (default) | Raises `ModelRetry` with a descriptive message so the model receives a retry prompt asking it to change approach. | +| `'error'` | Raises `StuckLoopError` to abort the run. | + +## Per-run state + +Uses `for_run()` to return a fresh instance with empty history lists, ensuring +concurrent runs don't interfere. + +## API + +```python +from pydantic_ai import Agent +from pydantic_harness.stuck_loop_detection import StuckLoopDetection + +agent = Agent( + 'openai:gpt-4o', + capabilities=[ + StuckLoopDetection( + max_repeated_calls=3, + action='warn', + warning_message='You appear to be stuck. 
Try something else.', + ), + ], +) +``` + +## Files + +- `src/pydantic_harness/stuck_loop_detection.py` -- capability implementation +- `src/pydantic_harness/__init__.py` -- re-exports `StuckLoopDetection` and `StuckLoopError` +- `tests/test_stuck_loop_detection.py` -- 32 tests, 100% coverage diff --git a/src/pydantic_harness/__init__.py b/src/pydantic_harness/__init__.py index 9d728b6..f190c13 100644 --- a/src/pydantic_harness/__init__.py +++ b/src/pydantic_harness/__init__.py @@ -7,4 +7,9 @@ # Each capability module is imported and re-exported here. # Capabilities are listed alphabetically. -__all__: list[str] = [] +from pydantic_harness.stuck_loop_detection import StuckLoopDetection, StuckLoopError + +__all__: list[str] = [ + 'StuckLoopDetection', + 'StuckLoopError', +] diff --git a/src/pydantic_harness/stuck_loop_detection.py b/src/pydantic_harness/stuck_loop_detection.py new file mode 100644 index 0000000..83f901a --- /dev/null +++ b/src/pydantic_harness/stuck_loop_detection.py @@ -0,0 +1,241 @@ +"""Stuck loop detection capability for PydanticAI agents. + +Detects when an agent is stuck repeating the same actions and either warns the +model via a retry prompt or raises an error to abort the run. + +Detection scenarios: + 1. **Repeated calls**: The same tool is called with the same arguments + `max_repeated_calls` times consecutively. + 2. **Alternating calls**: Two distinct tool calls alternate back and forth + for `max_repeated_calls` full cycles (i.e. `max_repeated_calls * 2` + consecutive tool calls forming an A-B-A-B pattern). + 3. **No-op calls**: The same tool returns the same result + `max_repeated_calls` times consecutively, regardless of whether the + arguments differ. 
+""" + +from __future__ import annotations + +import json +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Literal + +from pydantic_ai.capabilities.abstract import AbstractCapability +from pydantic_ai.exceptions import ModelRetry +from pydantic_ai.messages import ModelResponse, ToolCallPart + +if TYPE_CHECKING: + from pydantic_ai.models import ModelRequestContext + from pydantic_ai.tools import RunContext + + +class StuckLoopError(Exception): + """Raised when the agent is detected to be stuck in a loop. + + Attributes: + reason: A human-readable description of why the loop was detected. + """ + + reason: str + + def __init__(self, reason: str) -> None: + """Initialize with a human-readable description of the detected loop.""" + self.reason = reason + super().__init__(reason) + + +def _normalize_args(args: str | dict[str, Any] | None) -> str: + """Produce a stable string representation of tool call arguments for comparison.""" + if args is None: + return '' + if isinstance(args, str): + # Try to parse and re-serialize for canonical ordering. + try: + return json.dumps(json.loads(args), sort_keys=True) + except (json.JSONDecodeError, ValueError): + return args + return json.dumps(args, sort_keys=True) + + +def _tool_call_key(part: ToolCallPart) -> str: + """Return a hashable key representing the tool name + normalized arguments.""" + return f'{part.tool_name}::{_normalize_args(part.args)}' + + +def _detect_repeated(history: list[str], threshold: int) -> str | None: + """Detect if the last *threshold* entries are all identical.""" + if len(history) < threshold: + return None + tail = history[-threshold:] + if len(set(tail)) == 1: + return tail[0] + return None + + +def _detect_alternating(history: list[str], threshold: int) -> tuple[str, str] | None: + """Detect an A-B-A-B pattern in the tail of *history*. + + Returns the two alternating keys if found, otherwise ``None``. + A full "cycle" is A-B, so we need ``threshold * 2`` entries. 
+ """ + needed = threshold * 2 + if len(history) < needed: + return None + tail = history[-needed:] + a, b = tail[0], tail[1] + if a == b: + return None + for i, key in enumerate(tail): + expected = a if i % 2 == 0 else b + if key != expected: + return None + return (a, b) + + +DEFAULT_WARNING_MESSAGE = 'You appear to be stuck in a loop, repeating the same action(s). Try a different approach.' + + +@dataclass +class StuckLoopDetection(AbstractCapability[Any]): + """Detects when an agent is stuck repeating the same tool calls. + + Monitors model responses for repetitive tool-call patterns and either + sends a retry prompt asking the model to change strategy (``action='warn'``) + or raises :class:`StuckLoopError` to abort the run (``action='error'``). + + Example:: + + from pydantic_ai import Agent + from pydantic_harness.stuck_loop_detection import StuckLoopDetection + + agent = Agent( + 'openai:gpt-4o', + capabilities=[StuckLoopDetection(max_repeated_calls=3)], + ) + """ + + max_repeated_calls: int = 3 + """Number of consecutive repetitions before detection triggers.""" + + action: Literal['warn', 'error'] = 'warn' + """What to do when a loop is detected. + + - ``'warn'``: Raise :class:`~pydantic_ai.exceptions.ModelRetry` so the model + receives a retry prompt asking it to try a different approach. + - ``'error'``: Raise :class:`StuckLoopError` to abort the run. + """ + + warning_message: str = DEFAULT_WARNING_MESSAGE + """The message sent to the model (or included in the error) when a loop is detected.""" + + max_history_length: int = 50 + """Maximum number of entries to keep in the call and result history lists. + + Older entries are discarded (from the left) when this limit is exceeded, + preventing unbounded memory growth during long agent runs. 
+ """ + + # --- Per-run state (populated by ``for_run``) --- + + _call_history: list[str] = field(default_factory=lambda: list[str](), repr=False) + """Keys of recent tool calls (tool_name::normalized_args).""" + + _result_history: list[tuple[str, str]] = field(default_factory=lambda: list[tuple[str, str]](), repr=False) + """Pairs of (tool_name, repr(result)) for no-op detection.""" + + async def for_run(self, ctx: RunContext[Any]) -> StuckLoopDetection: + """Return a fresh instance with empty history for each agent run.""" + return StuckLoopDetection( + max_repeated_calls=self.max_repeated_calls, + action=self.action, + warning_message=self.warning_message, + max_history_length=self.max_history_length, + ) + + async def after_model_request( + self, + ctx: RunContext[Any], + *, + request_context: ModelRequestContext, + response: ModelResponse, + ) -> ModelResponse: + """Track tool calls from the model response and check for loops.""" + tool_calls = [p for p in response.parts if isinstance(p, ToolCallPart)] + if not tool_calls: + return response + + for tc in tool_calls: + self._call_history.append(_tool_call_key(tc)) + + self._trim_history(self._call_history) + + # --- Check for repeated identical calls --- + reason = self._check_repeated() + if reason is None: + reason = self._check_alternating() + + if reason is not None: + self._trigger(reason) + + return response + + async def after_tool_execute( + self, + ctx: RunContext[Any], + *, + call: ToolCallPart, + tool_def: Any, + args: dict[str, Any], + result: Any, + ) -> Any: + """Track tool results for no-op detection.""" + result_repr = repr(result) + self._result_history.append((call.tool_name, result_repr)) + self._trim_history(self._result_history) + + reason = self._check_noop() + if reason is not None: + self._trigger(reason) + + return result + + # --- History management --- + + def _trim_history(self, history: list[Any]) -> None: + """Remove oldest entries when *history* exceeds 
:attr:`max_history_length`.""" + while len(history) > self.max_history_length: + history.pop(0) + + # --- Detection helpers --- + + def _check_repeated(self) -> str | None: + match = _detect_repeated(self._call_history, self.max_repeated_calls) + if match is not None: + name = match.split('::')[0] + return f'Tool `{name}` called {self.max_repeated_calls} times with identical arguments.' + return None + + def _check_alternating(self) -> str | None: + match = _detect_alternating(self._call_history, self.max_repeated_calls) + if match is not None: + a_name = match[0].split('::')[0] + b_name = match[1].split('::')[0] + return f'Alternating between `{a_name}` and `{b_name}` for {self.max_repeated_calls} cycles.' + return None + + def _check_noop(self) -> str | None: + if len(self._result_history) < self.max_repeated_calls: + return None + tail = self._result_history[-self.max_repeated_calls :] + names = {t[0] for t in tail} + results = {t[1] for t in tail} + if len(names) == 1 and len(results) == 1: + return f'Tool `{next(iter(names))}` returned the same result {self.max_repeated_calls} times.' 
+ return None + + def _trigger(self, reason: str) -> None: + """Trigger the configured action.""" + message = f'{self.warning_message}\n\nDetected: {reason}' + if self.action == 'error': + raise StuckLoopError(message) + raise ModelRetry(message) diff --git a/tests/test_stuck_loop_detection.py b/tests/test_stuck_loop_detection.py new file mode 100644 index 0000000..12a4c58 --- /dev/null +++ b/tests/test_stuck_loop_detection.py @@ -0,0 +1,406 @@ +"""Tests for the StuckLoopDetection capability.""" +# pyright: reportPrivateUsage=false, reportArgumentType=false + +from __future__ import annotations + +import pytest +from pydantic_ai.exceptions import ModelRetry +from pydantic_ai.messages import ModelResponse, TextPart, ToolCallPart + +from pydantic_harness.stuck_loop_detection import ( + DEFAULT_WARNING_MESSAGE, + StuckLoopDetection, + StuckLoopError, + _detect_alternating, + _detect_repeated, + _normalize_args, + _tool_call_key, +) + +# --------------------------------------------------------------------------- +# Unit tests for helper functions +# --------------------------------------------------------------------------- + + +class TestNormalizeArgs: + def test_none(self): + assert _normalize_args(None) == '' + + def test_dict(self): + assert _normalize_args({'b': 2, 'a': 1}) == '{"a": 1, "b": 2}' + + def test_json_string(self): + assert _normalize_args('{"b": 2, "a": 1}') == '{"a": 1, "b": 2}' + + def test_non_json_string(self): + assert _normalize_args('not json') == 'not json' + + def test_empty_dict(self): + assert _normalize_args({}) == '{}' + + +class TestToolCallKey: + def test_basic(self): + part = ToolCallPart(tool_name='read_file', args={'path': '/foo'}) + assert _tool_call_key(part) == 'read_file::{"path": "/foo"}' + + def test_no_args(self): + part = ToolCallPart(tool_name='get_time', args=None) + assert _tool_call_key(part) == 'get_time::' + + +class TestDetectRepeated: + def test_below_threshold(self): + assert _detect_repeated(['a', 'a'], 3) is None + 
+ def test_at_threshold(self): + assert _detect_repeated(['a', 'a', 'a'], 3) == 'a' + + def test_above_threshold(self): + assert _detect_repeated(['b', 'a', 'a', 'a'], 3) == 'a' + + def test_no_repeat(self): + assert _detect_repeated(['a', 'b', 'a'], 3) is None + + def test_mixed_then_repeat(self): + assert _detect_repeated(['x', 'y', 'z', 'z', 'z'], 3) == 'z' + + +class TestDetectAlternating: + def test_below_threshold(self): + assert _detect_alternating(['a', 'b', 'a'], 2) is None + + def test_at_threshold(self): + assert _detect_alternating(['a', 'b', 'a', 'b'], 2) == ('a', 'b') + + def test_not_alternating(self): + assert _detect_alternating(['a', 'b', 'c', 'b'], 2) is None + + def test_same_keys(self): + # a == b should return None (that's a repeat, not alternation) + assert _detect_alternating(['a', 'a', 'a', 'a'], 2) is None + + def test_longer_pattern(self): + assert _detect_alternating(['a', 'b', 'a', 'b', 'a', 'b'], 3) == ('a', 'b') + + +# --------------------------------------------------------------------------- +# Integration-style tests for the capability hooks +# --------------------------------------------------------------------------- + + +def _make_response(*tool_calls: ToolCallPart) -> ModelResponse: + return ModelResponse(parts=list(tool_calls)) + + +def _make_text_response(text: str = 'hello') -> ModelResponse: + return ModelResponse(parts=[TextPart(content=text)]) + + +def _make_tc(name: str, args: dict[str, object] | None = None) -> ToolCallPart: + return ToolCallPart(tool_name=name, args=args) + + +class _FakeCtx: + """Minimal stand-in for RunContext — the capability only receives it, never inspects it.""" + + +class _FakeRequestContext: + """Minimal stand-in for ModelRequestContext.""" + + +class _FakeToolDef: + """Minimal stand-in for ToolDefinition.""" + + +@pytest.fixture() +def cap_warn() -> StuckLoopDetection: + """A fresh warn-mode capability with threshold 3.""" + return StuckLoopDetection(max_repeated_calls=3, action='warn') + + 
@pytest.fixture()
def cap_error() -> StuckLoopDetection:
    """An error-mode capability instance with a detection threshold of 3."""
    return StuckLoopDetection(max_repeated_calls=3, action='error')


# --- for_run isolation ---


@pytest.mark.anyio()
async def test_for_run_returns_fresh_instance(cap_warn: StuckLoopDetection):
    cap_warn._call_history.append('something')
    clone = await cap_warn.for_run(_FakeCtx())  # type: ignore[arg-type]
    assert clone is not cap_warn
    assert clone._call_history == []
    assert clone.max_repeated_calls == cap_warn.max_repeated_calls
    assert clone.action == cap_warn.action
    assert clone.warning_message == cap_warn.warning_message


# --- Repeated call detection ---


@pytest.mark.anyio()
async def test_repeated_calls_warn(cap_warn: StuckLoopDetection):
    run_ctx: object = _FakeCtx()
    req_ctx: object = _FakeRequestContext()
    response = _make_response(_make_tc('read_file', {'path': '/a'}))

    # Two identical calls stay below the threshold of three.
    for _ in range(2):
        returned = await cap_warn.after_model_request(run_ctx, request_context=req_ctx, response=response)  # type: ignore[arg-type]
        assert returned is response

    # The third identical call raises ModelRetry.
    with pytest.raises(ModelRetry, match='read_file.*identical arguments'):
        await cap_warn.after_model_request(run_ctx, request_context=req_ctx, response=response)  # type: ignore[arg-type]


@pytest.mark.anyio()
async def test_repeated_calls_error(cap_error: StuckLoopDetection):
    run_ctx: object = _FakeCtx()
    req_ctx: object = _FakeRequestContext()
    response = _make_response(_make_tc('bash', {'cmd': 'ls'}))

    await cap_error.after_model_request(run_ctx, request_context=req_ctx, response=response)  # type: ignore[arg-type]
    await cap_error.after_model_request(run_ctx, request_context=req_ctx, response=response)  # type: ignore[arg-type]

    with pytest.raises(StuckLoopError, match='bash.*identical arguments'):
        await cap_error.after_model_request(run_ctx, request_context=req_ctx, response=response)  # type: ignore[arg-type]


@pytest.mark.anyio()
async def test_different_calls_do_not_trigger(cap_warn: StuckLoopDetection):
    run_ctx: object = _FakeCtx()
    req_ctx: object = _FakeRequestContext()

    # Five calls with distinct arguments never look like a loop.
    for idx in range(5):
        response = _make_response(_make_tc('read_file', {'path': f'/file_{idx}'}))
        returned = await cap_warn.after_model_request(run_ctx, request_context=req_ctx, response=response)  # type: ignore[arg-type]
        assert returned is response


# --- Alternating call detection ---


@pytest.mark.anyio()
async def test_alternating_calls_warn(cap_warn: StuckLoopDetection):
    run_ctx: object = _FakeCtx()
    req_ctx: object = _FakeRequestContext()
    call_a = _make_tc('read_file', {'path': '/a'})
    call_b = _make_tc('write_file', {'path': '/b'})

    sequence = [call_a, call_b] * 3
    # The first five calls pass; detection needs 6 entries (= 3 cycles * 2).
    for call in sequence[:-1]:
        await cap_warn.after_model_request(run_ctx, request_context=req_ctx, response=_make_response(call))  # type: ignore[arg-type]

    with pytest.raises(ModelRetry, match='Alternating'):
        await cap_warn.after_model_request(run_ctx, request_context=req_ctx, response=_make_response(sequence[-1]))  # type: ignore[arg-type]


# --- No-op call detection ---


@pytest.mark.anyio()
async def test_noop_detection_warn(cap_warn: StuckLoopDetection):
    run_ctx: object = _FakeCtx()
    call = _make_tc('search', {'query': 'foo'})
    tool_def: object = _FakeToolDef()

    for _ in range(2):
        outcome = await cap_warn.after_tool_execute(
            run_ctx, call=call, tool_def=tool_def, args={'query': 'foo'}, result='same result'
        )  # type: ignore[arg-type]
        assert outcome == 'same result'

    with pytest.raises(ModelRetry, match='same result'):
        await cap_warn.after_tool_execute(run_ctx, call=call, tool_def=tool_def, args={'query': 'foo'}, result='same result')  # type: ignore[arg-type]


@pytest.mark.anyio()
async def test_noop_different_results_do_not_trigger(cap_warn: StuckLoopDetection):
    run_ctx: object = _FakeCtx()
    call = _make_tc('search', {'query': 'foo'})
    tool_def: object = _FakeToolDef()

    for idx in range(5):
        outcome = await cap_warn.after_tool_execute(
            run_ctx, call=call, tool_def=tool_def, args={'query': 'foo'}, result=f'result_{idx}'
        )  # type: ignore[arg-type]
        assert outcome == f'result_{idx}'


@pytest.mark.anyio()
async def test_noop_different_tools_do_not_trigger(cap_warn: StuckLoopDetection):
    """The same result coming from differently named tools is not a no-op loop."""
    run_ctx: object = _FakeCtx()
    tool_def: object = _FakeToolDef()

    for idx in range(5):
        call = _make_tc(f'tool_{idx}', {'x': 1})
        outcome = await cap_warn.after_tool_execute(run_ctx, call=call, tool_def=tool_def, args={'x': 1}, result='same')  # type: ignore[arg-type]
        assert outcome == 'same'


# --- Text-only responses are ignored ---


@pytest.mark.anyio()
async def test_text_response_ignored(cap_warn: StuckLoopDetection):
    run_ctx: object = _FakeCtx()
    req_ctx: object = _FakeRequestContext()
    response = _make_text_response('hello')

    for _ in range(5):
        returned = await cap_warn.after_model_request(run_ctx, request_context=req_ctx, response=response)  # type: ignore[arg-type]
        assert returned is response


# --- Custom warning message ---


@pytest.mark.anyio()
async def test_custom_warning_message():
    detector = StuckLoopDetection(max_repeated_calls=2, action='warn', warning_message='Stop looping!')
    run_ctx: object = _FakeCtx()
    req_ctx: object = _FakeRequestContext()
    response = _make_response(_make_tc('x'))

    await detector.after_model_request(run_ctx, request_context=req_ctx, response=response)  # type: ignore[arg-type]
    with pytest.raises(ModelRetry, match='Stop looping!'):
        await detector.after_model_request(run_ctx, request_context=req_ctx, response=response)  # type: ignore[arg-type]


# --- Default values ---


def test_defaults():
    detector = StuckLoopDetection()
    assert detector.max_repeated_calls == 3
    assert detector.action == 'warn'
    assert detector.warning_message == DEFAULT_WARNING_MESSAGE
    assert detector._call_history == []
    assert detector._result_history == []


# --- StuckLoopError attributes ---


def test_stuck_loop_error():
    exc = StuckLoopError('test reason')
    assert exc.reason == 'test reason'
    assert str(exc) == 'test reason'


# --- Multiple tool calls in a single response ---


@pytest.mark.anyio()
async def test_multiple_tool_calls_per_response(cap_warn: StuckLoopDetection):
    """Every tool call inside a single model response is tracked."""
    run_ctx: object = _FakeCtx()
    req_ctx: object = _FakeRequestContext()
    call = _make_tc('do_thing', {'a': 1})

    # Three identical calls in one response hit the threshold at once.
    with pytest.raises(ModelRetry, match='do_thing.*identical arguments'):
        await cap_warn.after_model_request(run_ctx, request_context=req_ctx, response=_make_response(call, call, call))  # type: ignore[arg-type]


# --- Threshold of 1 ---


@pytest.mark.anyio()
async def test_threshold_one():
    """max_repeated_calls=1 means the very first call already triggers."""
    detector = StuckLoopDetection(max_repeated_calls=1, action='error')
    run_ctx: object = _FakeCtx()
    req_ctx: object = _FakeRequestContext()
    response = _make_response(_make_tc('any_tool'))

    with pytest.raises(StuckLoopError):
        await detector.after_model_request(run_ctx, request_context=req_ctx, response=response)  # type: ignore[arg-type]


# --- History length capping ---


def test_max_history_length_default():
    assert StuckLoopDetection().max_history_length == 50


@pytest.mark.anyio()
async def test_call_history_capped():
    """_call_history never grows past max_history_length."""
    detector = StuckLoopDetection(max_repeated_calls=100, max_history_length=5)
    run_ctx: object = _FakeCtx()
    req_ctx: object = _FakeRequestContext()

    for idx in range(10):
        response = _make_response(_make_tc('tool', {'i': idx}))
        await detector.after_model_request(run_ctx, request_context=req_ctx, response=response)  # type: ignore[arg-type]

    assert len(detector._call_history) == 5
    # Only the newest entries survive trimming.
    assert detector._call_history[0] == _tool_call_key(_make_tc('tool', {'i': 5}))
    assert detector._call_history[-1] == _tool_call_key(_make_tc('tool', {'i': 9}))


@pytest.mark.anyio()
async def test_result_history_capped():
    """_result_history never grows past max_history_length."""
    detector = StuckLoopDetection(max_repeated_calls=100, max_history_length=5)
    run_ctx: object = _FakeCtx()
    tool_def: object = _FakeToolDef()

    for idx in range(10):
        call = _make_tc('tool', {'i': idx})
        await detector.after_tool_execute(run_ctx, call=call, tool_def=tool_def, args={'i': idx}, result=f'r{idx}')  # type: ignore[arg-type]

    assert len(detector._result_history) == 5
    assert detector._result_history[0] == ('tool', repr('r5'))
    assert detector._result_history[-1] == ('tool', repr('r9'))


@pytest.mark.anyio()
async def test_for_run_preserves_max_history_length():
    """for_run carries the max_history_length setting over to the fresh copy."""
    clone = await StuckLoopDetection(max_history_length=10).for_run(_FakeCtx())  # type: ignore[arg-type]
    assert clone.max_history_length == 10


@pytest.mark.anyio()
async def test_detection_still_works_with_capped_history():
    """Trimming the history must not prevent detection from firing."""
    detector = StuckLoopDetection(max_repeated_calls=3, max_history_length=5)
    run_ctx: object = _FakeCtx()
    req_ctx: object = _FakeRequestContext()

    # Seed the history with three distinct calls.
    for idx in range(3):
        response = _make_response(_make_tc('other', {'i': idx}))
        await detector.after_model_request(run_ctx, request_context=req_ctx, response=response)  # type: ignore[arg-type]

    # Then repeat one call: the third repetition must still be caught.
    response = _make_response(_make_tc('stuck', {'x': 1}))
    await detector.after_model_request(run_ctx, request_context=req_ctx, response=response)  # type: ignore[arg-type]
    await detector.after_model_request(run_ctx, request_context=req_ctx, response=response)  # type: ignore[arg-type]

    with pytest.raises(ModelRetry, match='stuck.*identical arguments'):
        await detector.after_model_request(run_ctx, request_context=req_ctx, response=response)  # type: ignore[arg-type]