diff --git a/PLAN.md b/PLAN.md new file mode 100644 index 0000000..52fec09 --- /dev/null +++ b/PLAN.md @@ -0,0 +1,57 @@ +# Plan: ToolOutputManagement capability + +Closes #82 + +## Context + +When tools return large outputs (file contents, search results, command output), they can consume most of the context window. A single `grep -r` or verbose test output can crowd out all useful conversation history. + +Issue #82 originally proposed this as harness infrastructure baked into the agent loop, but per maintainer feedback, this is better implemented as a configurable capability that uses the existing `after_tool_execute` hook on `AbstractCapability`. + +## Design + +A `ToolOutputManagement` capability (dataclass extending `AbstractCapability`) that: + +1. **Intercepts tool results** via `after_tool_execute` -- the standard hook that fires after every tool execution, before the result enters the model's context +2. **Measures output size** by converting the result to its string representation +3. **Truncates when over limit** using one of three strategies: + - `head` -- keep first N chars + - `tail` -- keep last N chars (good for build/test output) + - `head_tail` (default) -- keep first 40% + last 60% with middle elided (tail-heavy, since errors and summaries usually appear at the end of build/test output) +4. **Supports per-tool overrides** via `per_tool_limits` and `per_tool_strategies` dicts +5. **Supports custom summarization** via an optional `summarize_fn(tool_name, output) -> str` (sync or async), with truncation as a safety net if the summary still exceeds the limit + +### What it does NOT do (deliberately) + +- **Spill-to-file by default**: Writing full output to files and returning paths requires filesystem access and assumptions about the execution environment, so this is opt-in via the `spill_to_file` flag (default off) rather than default behavior; the `summarize_fn` hook remains available for custom handling. +- **Token-based limits**: Character limits are a simple, model-independent proxy. Token counting requires model-specific tokenizers and adds complexity. Can be added later. 
+- **Model-aware scaling**: Adjusting limits based on `ModelProfile.context_window` is a good idea but depends on #35 (ContextWindowTracker) which doesn't exist yet. + +### Key property: original preserved upstream + +The `after_tool_execute` hook modifies only what the model sees. The original full result is already captured in telemetry/trajectory before this hook fires, so no data is lost. + +## Files + +- `src/pydantic_harness/tool_output_management.py` -- the capability +- `src/pydantic_harness/__init__.py` -- re-exports `ToolOutputManagement` and `TruncationStrategy` +- `tests/test_tool_output_management.py` -- unit + integration tests (27 tests) + +## Usage + +```python +from pydantic_ai import Agent +from pydantic_harness import ToolOutputManagement, TruncationStrategy + +agent = Agent( + 'openai:gpt-4o', + capabilities=[ + ToolOutputManagement( + max_output_chars=8000, + strategy=TruncationStrategy.head_tail, + per_tool_limits={'bash': 4000}, + per_tool_strategies={'bash': TruncationStrategy.tail}, + ), + ], +) +``` diff --git a/src/pydantic_harness/__init__.py b/src/pydantic_harness/__init__.py index 9d728b6..adce785 100644 --- a/src/pydantic_harness/__init__.py +++ b/src/pydantic_harness/__init__.py @@ -7,4 +7,9 @@ # Each capability module is imported and re-exported here. # Capabilities are listed alphabetically. -__all__: list[str] = [] +from pydantic_harness.tool_output_management import ToolOutputManagement, TruncationStrategy + +__all__: list[str] = [ + 'ToolOutputManagement', + 'TruncationStrategy', +] diff --git a/src/pydantic_harness/tool_output_management.py b/src/pydantic_harness/tool_output_management.py new file mode 100644 index 0000000..acbc9c3 --- /dev/null +++ b/src/pydantic_harness/tool_output_management.py @@ -0,0 +1,331 @@ +"""Tool output management capability. + +Intercepts tool return values and truncates or summarizes large outputs +to prevent context window blowup. 
Uses the `after_tool_execute` hook so +that the original tool result is preserved in telemetry / trajectory logs, +while only the LLM sees the truncated version. +""" + +from __future__ import annotations + +import os +import re +import tempfile +from collections.abc import Awaitable, Callable +from dataclasses import dataclass, field +from enum import Enum +from pathlib import Path +from typing import Any + +from pydantic_ai.capabilities.abstract import AbstractCapability, ValidatedToolArgs +from pydantic_ai.messages import ToolCallPart +from pydantic_ai.tools import AgentDepsT, RunContext, ToolDefinition + + +class TruncationStrategy(str, Enum): + """Strategy for truncating oversized tool output.""" + + head = 'head' + """Keep only the first characters.""" + + tail = 'tail' + """Keep only the last characters.""" + + head_tail = 'head_tail' + """Keep the first and last characters, eliding the middle.""" + + +SummarizeFn = Callable[[str, str], str | Awaitable[str]] +"""A function `(tool_name, output) -> summarized_output`. + +May be sync or async. +""" + + +# Regex matching ANSI escape sequences (CSI sequences, OSC sequences, and simple escapes). +# Terminal/bash tool output is full of color codes that waste tokens and confuse models. +# Both Mastra and Hermes strip ANSI before sending output to the model. +_ANSI_ESCAPE_RE = re.compile(r'\x1b\[[0-9;]*[a-zA-Z]|\x1b\].*?\x07|\x1b[^[\]()]') + + +def _strip_ansi(text: str) -> str: + """Remove ANSI escape sequences from *text*.""" + return _ANSI_ESCAPE_RE.sub('', text) + + +def _head_tail_default_split(limit: int) -> tuple[int, int]: + """Split a character limit into head and tail portions (40/60). + + The split is tail-heavy because for the most common large-output + scenarios (build logs, test output, command stderr) the actionable + information — errors, summaries, exit codes — tends to appear at the + end. This matches the convention used by Hermes (40/60) and Mastra + (10/90). 
Per-tool strategy overrides can still be used when the + beginning matters more (e.g. file reads). + """ + head = int(limit * 0.4) + tail = limit - head + return head, tail + + +def _truncate(text: str, limit: int, strategy: TruncationStrategy) -> str: + """Apply a truncation strategy to *text* that exceeds *limit* chars. + + Note: truncation is character-level and structure-unaware. If the tool + returned JSON, the truncated result will be invalid JSON. A future + improvement could detect structured formats and truncate more + intelligently (e.g. elide large array elements while preserving the + schema), but no framework we've surveyed does this today. + """ + total = len(text) + if total <= limit: + return text + + if strategy is TruncationStrategy.head: + kept = text[:limit] + return f'{kept}\n\n[Truncated: showing first {limit:,} of {total:,} chars]' + + if strategy is TruncationStrategy.tail: + kept = text[-limit:] + return f'[Truncated: showing last {limit:,} of {total:,} chars]\n\n{kept}' + + # head_tail + head_chars, tail_chars = _head_tail_default_split(limit) + head_part = text[:head_chars] + tail_part = text[-tail_chars:] + omitted = total - head_chars - tail_chars + return ( + f'{head_part}\n\n' + f'[Truncated: {omitted:,} chars omitted from middle; showing first {head_chars:,} + last {tail_chars:,} of {total:,} chars]\n\n' + f'{tail_part}' + ) + + +def _truncate_by_lines(text: str, limit: int, strategy: TruncationStrategy) -> str: + """Apply a truncation strategy to *text* that exceeds *limit* lines.""" + lines = text.splitlines(keepends=True) + total = len(lines) + if total <= limit: + return text + + if strategy is TruncationStrategy.head: + kept = ''.join(lines[:limit]) + return f'{kept}\n\n[Truncated: showing first {limit:,} of {total:,} lines]' + + if strategy is TruncationStrategy.tail: + kept = ''.join(lines[-limit:]) + return f'[Truncated: showing last {limit:,} of {total:,} lines]\n\n{kept}' + + # head_tail + head_lines, tail_lines = 
_head_tail_default_split(limit) + head_part = ''.join(lines[:head_lines]) + tail_part = ''.join(lines[-tail_lines:]) + omitted = total - head_lines - tail_lines + return ( + f'{head_part}\n\n' + f'[Truncated: {omitted:,} lines omitted from middle; showing first {head_lines:,} + last {tail_lines:,} of {total:,} lines]\n\n' + f'{tail_part}' + ) + + +def _is_binary(value: Any) -> bool: + """Return True if *value* is binary data that should not be truncated.""" + return isinstance(value, (bytes, bytearray, memoryview)) + + +def _stringify(value: Any) -> str: + """Convert an arbitrary tool return value to a string for size measurement.""" + if isinstance(value, str): + return value + return str(value) + + +@dataclass +class ToolOutputManagement(AbstractCapability[AgentDepsT]): + """Manage large tool outputs to prevent context window blowup. + + Intercepts tool return values via the `after_tool_execute` hook and + truncates or summarizes them when they exceed a configurable character + limit. The original (full) result is preserved upstream (telemetry, + `FunctionToolResultEvent.content`); only the value forwarded to the + model is modified. + + Example: + ```python + from pydantic_ai import Agent + from pydantic_harness import ToolOutputManagement + + agent = Agent( + 'openai:gpt-4o', + capabilities=[ + ToolOutputManagement(max_output_chars=8000), + ], + ) + ``` + """ + + max_output_chars: int = 10_000 + """Default character limit for tool outputs. Outputs exceeding this + are truncated according to `strategy`. + + Note: character limits are a simple, model-independent proxy. A future + ``max_output_tokens`` option using model-specific tokenizers would give + more accurate budget control (characters are roughly a 4x overestimate + for English text). This depends on token-counting infrastructure that + does not yet exist in pydantic-ai (see ContextWindowTracker / #35). + """ + + max_output_lines: int | None = None + """Optional line-count limit for tool outputs. 
+ + When set, output is also checked against this line limit. If both + `max_output_chars` and `max_output_lines` are set, the limit that + triggers first wins. + """ + + strategy: TruncationStrategy = TruncationStrategy.head_tail + """Default truncation strategy applied when output exceeds + `max_output_chars`.""" + + per_tool_limits: dict[str, int] = field(default_factory=lambda: {}) + """Per-tool character limits. Keys are tool names; values override + `max_output_chars` for that tool.""" + + per_tool_line_limits: dict[str, int] = field(default_factory=lambda: {}) + """Per-tool line-count limits. Keys are tool names; values override + `max_output_lines` for that tool.""" + + per_tool_strategies: dict[str, TruncationStrategy] = field(default_factory=lambda: {}) + """Per-tool truncation strategies. Keys are tool names; values + override `strategy` for that tool.""" + + summarize_fn: SummarizeFn | None = None + """Optional summarization function called *instead of* truncation. + + Receives `(tool_name, full_output_str)` and must return a + (potentially shorter) string. If the returned string still exceeds + the limit, it is truncated as a safety net. + + May be sync or async. + + Warning: if the callable wraps an LLM call, be aware that this + capability provides no timeout, retry, or cost guardrails — only a + size safety net. Callers are responsible for adding their own + timeout / error handling inside the function. Hermes, for example, + dedicates a cheap model (Gemini Flash) specifically for this purpose. + """ + + spill_to_file: bool = False + """When True, oversized output is written to a temporary file and + the model receives a pointer to that file plus a truncated preview. + + The file path is embedded in the returned string (e.g. + ``[Full output (N chars) saved to /tmp/...]``). Pi-mono takes an + alternative approach, returning structured metadata + (``details.truncation``) which is more machine-parseable if another + capability needs to act on it. 
A structured return would require + changes to the ``after_tool_execute`` contract, so for now we use + the simpler string-embedded pointer. + """ + + spill_dir: Path | None = None + """Directory for spill files. Defaults to the system temp directory + when `spill_to_file` is True and this is None. + """ + + strip_ansi: bool = True + """Strip ANSI escape sequences from tool output before measuring and + truncating. ANSI color/formatting codes from terminal output waste + tokens and can confuse models. Enabled by default. + """ + + def _exceeds_limits(self, text: str, char_limit: int, line_limit: int | None) -> bool: + """Return True if *text* exceeds either the char or line limit.""" + if len(text) > char_limit: + return True + if line_limit is not None and text.count('\n') + 1 > line_limit: + return True + return False + + def _apply_truncation( + self, text: str, char_limit: int, line_limit: int | None, strategy: TruncationStrategy + ) -> str: + """Truncate *text* by whichever limit fires first (lines or chars).""" + # Check which limit fires first + lines_exceed = line_limit is not None and text.count('\n') + 1 > line_limit + chars_exceed = len(text) > char_limit + + if lines_exceed and line_limit is not None: + # If both exceed, apply line truncation first, then char truncation + # if still needed; if only lines exceed, just truncate by lines. 
+ truncated = _truncate_by_lines(text, line_limit, strategy) + if chars_exceed and len(truncated) > char_limit: + return _truncate(truncated, char_limit, strategy) + return truncated + + # Only chars exceed (or neither, but caller already checked) + return _truncate(text, char_limit, strategy) + + def _spill(self, text: str, char_limit: int, line_limit: int | None, strategy: TruncationStrategy) -> str: + """Write *text* to a temp file and return a pointer with a truncated preview.""" + total_chars = len(text) + dir_ = self.spill_dir or Path(tempfile.gettempdir()) + dir_.mkdir(parents=True, exist_ok=True) + + fd, path_str = tempfile.mkstemp(suffix='.txt', dir=str(dir_), prefix='tool_output_') + path = Path(path_str) + # Close the fd opened by mkstemp and write via Path + os.close(fd) + path.write_text(text, encoding='utf-8') + + preview = self._apply_truncation(text, char_limit, line_limit, strategy) + return f'[Full output ({total_chars:,} chars) saved to {path}]\n{preview}' + + async def after_tool_execute( + self, + ctx: RunContext[AgentDepsT], + *, + call: ToolCallPart, + tool_def: ToolDefinition, + args: ValidatedToolArgs, + result: Any, + ) -> Any: + """Truncate or summarize the tool result if it exceeds the configured limit.""" + # Binary detection: skip truncation entirely + if _is_binary(result): + size = len(result) if isinstance(result, (bytes, bytearray)) else result.nbytes + return f'[Binary data, {size:,} bytes]' + + text = _stringify(result) + stripped = self.strip_ansi + if stripped: + text = _strip_ansi(text) + char_limit = self.per_tool_limits.get(call.tool_name, self.max_output_chars) + line_limit = self.per_tool_line_limits.get(call.tool_name, self.max_output_lines) + + if not self._exceeds_limits(text, char_limit, line_limit): + # If we stripped ANSI, return the cleaned text so the model + # never sees escape codes. Otherwise return the original value. 
+ return text if stripped else result + + strategy = self.per_tool_strategies.get(call.tool_name, self.strategy) + + # Summarize path + if self.summarize_fn is not None: + summary = self.summarize_fn(call.tool_name, text) + if isinstance(summary, Awaitable): + summary = await summary + assert isinstance(summary, str) + # Safety net: if the summary itself is still too long, truncate it + if self._exceeds_limits(summary, char_limit, line_limit): + if self.spill_to_file: + return self._spill(summary, char_limit, line_limit, strategy) + return self._apply_truncation(summary, char_limit, line_limit, strategy) + return summary + + # Spill-to-file path + if self.spill_to_file: + return self._spill(text, char_limit, line_limit, strategy) + + # Truncation path + return self._apply_truncation(text, char_limit, line_limit, strategy) diff --git a/tests/test_tool_output_management.py b/tests/test_tool_output_management.py new file mode 100644 index 0000000..1cb5c55 --- /dev/null +++ b/tests/test_tool_output_management.py @@ -0,0 +1,733 @@ +"""Tests for ToolOutputManagement capability.""" + +from __future__ import annotations + +from pathlib import Path + +import pytest +from pydantic_ai.messages import ToolCallPart +from pydantic_ai.tools import ToolDefinition + +from pydantic_harness.tool_output_management import ( + ToolOutputManagement, + TruncationStrategy, +) + +# Re-export private helpers for unit testing; pyright: ignore for private usage. 
+from pydantic_harness.tool_output_management import ( + _head_tail_default_split as head_tail_default_split, # pyright: ignore[reportPrivateUsage] +) +from pydantic_harness.tool_output_management import ( + _is_binary as is_binary, # pyright: ignore[reportPrivateUsage] +) +from pydantic_harness.tool_output_management import ( + _stringify as stringify, # pyright: ignore[reportPrivateUsage] +) +from pydantic_harness.tool_output_management import ( + _strip_ansi as strip_ansi, # pyright: ignore[reportPrivateUsage] +) +from pydantic_harness.tool_output_management import ( + _truncate as truncate, # pyright: ignore[reportPrivateUsage] +) +from pydantic_harness.tool_output_management import ( + _truncate_by_lines as truncate_by_lines, # pyright: ignore[reportPrivateUsage] +) + +CALL = ToolCallPart(tool_name='my_tool', args={}) +TOOL_DEF = ToolDefinition(name='my_tool', description='test tool', parameters_json_schema={}) + + +# --------------------------------------------------------------------------- +# Unit tests: stringify +# --------------------------------------------------------------------------- + + +class TestStringify: + def test_string_passthrough(self) -> None: + assert stringify('hello') == 'hello' + + def test_int(self) -> None: + assert stringify(42) == '42' + + def test_dict(self) -> None: + result = stringify({'key': 'value'}) + assert 'key' in result + assert 'value' in result + + def test_list(self) -> None: + assert stringify([1, 2, 3]) == '[1, 2, 3]' + + def test_none(self) -> None: + assert stringify(None) == 'None' + + +# --------------------------------------------------------------------------- +# Unit tests: head_tail_default_split +# --------------------------------------------------------------------------- + + +class TestHeadTailSplit: + def test_split_100(self) -> None: + head, tail = head_tail_default_split(100) + assert head == 40 + assert tail == 60 + + def test_split_sums_to_limit(self) -> None: + for limit in (1, 10, 77, 1000, 9999): 
+ head, tail = head_tail_default_split(limit) + assert head + tail == limit + + +# --------------------------------------------------------------------------- +# Unit tests: truncate +# --------------------------------------------------------------------------- + + +class TestTruncate: + def test_no_truncation_needed(self) -> None: + text = 'short' + assert truncate(text, 100, TruncationStrategy.head) == text + + def test_head_strategy(self) -> None: + text = 'a' * 200 + result = truncate(text, 50, TruncationStrategy.head) + assert result.startswith('a' * 50) + assert '[Truncated: showing first 50 of 200 chars]' in result + + def test_tail_strategy(self) -> None: + text = 'a' * 200 + result = truncate(text, 50, TruncationStrategy.tail) + assert result.endswith('a' * 50) + assert '[Truncated: showing last 50 of 200 chars]' in result + + def test_head_tail_strategy(self) -> None: + text = 'H' * 100 + 'M' * 800 + 'T' * 100 + result = truncate(text, 100, TruncationStrategy.head_tail) + # head=40 chars, tail=60 chars (tail-heavy split) + assert result.startswith('H' * 40) + assert result.endswith('T' * 60) + assert 'omitted from middle' in result + assert '900' in result # 1000 - 40 - 60 = 900 omitted + + def test_head_tail_exact_boundary(self) -> None: + text = 'x' * 100 + # Exactly at limit -> no truncation + assert truncate(text, 100, TruncationStrategy.head_tail) == text + + def test_comma_formatting(self) -> None: + text = 'a' * 100_000 + result = truncate(text, 1000, TruncationStrategy.head) + assert '100,000' in result + + +# --------------------------------------------------------------------------- +# Integration tests: ToolOutputManagement.after_tool_execute +# --------------------------------------------------------------------------- + + +class TestAfterToolExecute: + @pytest.mark.anyio + async def test_short_output_unchanged(self) -> None: + cap: ToolOutputManagement[None] = ToolOutputManagement(max_output_chars=100) + result = await cap.after_tool_execute( 
+ None, # type: ignore[arg-type] + call=CALL, + tool_def=TOOL_DEF, + args={}, + result='short output', + ) + assert result == 'short output' + + @pytest.mark.anyio + async def test_long_string_truncated(self) -> None: + cap: ToolOutputManagement[None] = ToolOutputManagement(max_output_chars=50) + long_text = 'x' * 200 + result = await cap.after_tool_execute( + None, # type: ignore[arg-type] + call=CALL, + tool_def=TOOL_DEF, + args={}, + result=long_text, + ) + assert isinstance(result, str) + assert len(result) < len(long_text) + assert 'Truncated' in result + + @pytest.mark.anyio + async def test_non_string_result_truncated(self) -> None: + cap: ToolOutputManagement[None] = ToolOutputManagement(max_output_chars=20) + big_dict = {'key': 'v' * 200} + result = await cap.after_tool_execute( + None, # type: ignore[arg-type] + call=CALL, + tool_def=TOOL_DEF, + args={}, + result=big_dict, + ) + assert isinstance(result, str) + assert 'Truncated' in result + + @pytest.mark.anyio + async def test_per_tool_limit(self) -> None: + cap: ToolOutputManagement[None] = ToolOutputManagement( + max_output_chars=10_000, + per_tool_limits={'special_tool': 20}, + ) + call = ToolCallPart(tool_name='special_tool', args={}) + result = await cap.after_tool_execute( + None, # type: ignore[arg-type] + call=call, + tool_def=TOOL_DEF, + args={}, + result='a' * 100, + ) + assert 'Truncated' in result + + @pytest.mark.anyio + async def test_per_tool_limit_does_not_affect_others(self) -> None: + cap: ToolOutputManagement[None] = ToolOutputManagement( + max_output_chars=10_000, + per_tool_limits={'other_tool': 5}, + ) + result = await cap.after_tool_execute( + None, # type: ignore[arg-type] + call=CALL, + tool_def=TOOL_DEF, + args={}, + result='a' * 100, # under 10_000 + ) + assert result == 'a' * 100 + + @pytest.mark.anyio + async def test_per_tool_strategy(self) -> None: + cap: ToolOutputManagement[None] = ToolOutputManagement( + max_output_chars=50, + per_tool_strategies={'tail_tool': 
TruncationStrategy.tail}, + ) + call = ToolCallPart(tool_name='tail_tool', args={}) + result = await cap.after_tool_execute( + None, # type: ignore[arg-type] + call=call, + tool_def=TOOL_DEF, + args={}, + result='a' * 200, + ) + assert 'showing last' in result + + @pytest.mark.anyio + async def test_head_strategy_via_config(self) -> None: + cap: ToolOutputManagement[None] = ToolOutputManagement( + max_output_chars=50, + strategy=TruncationStrategy.head, + ) + result = await cap.after_tool_execute( + None, # type: ignore[arg-type] + call=CALL, + tool_def=TOOL_DEF, + args={}, + result='a' * 200, + ) + assert 'showing first' in result + + @pytest.mark.anyio + async def test_sync_summarize_fn(self) -> None: + def summarize(tool_name: str, output: str) -> str: + return f'Summary of {tool_name}: {len(output)} chars' + + cap: ToolOutputManagement[None] = ToolOutputManagement( + max_output_chars=100, + summarize_fn=summarize, + ) + result = await cap.after_tool_execute( + None, # type: ignore[arg-type] + call=CALL, + tool_def=TOOL_DEF, + args={}, + result='x' * 200, + ) + assert result == 'Summary of my_tool: 200 chars' + + @pytest.mark.anyio + async def test_async_summarize_fn(self) -> None: + async def summarize(tool_name: str, output: str) -> str: + return f'Async summary: {len(output)} chars' + + cap: ToolOutputManagement[None] = ToolOutputManagement( + max_output_chars=100, + summarize_fn=summarize, + ) + result = await cap.after_tool_execute( + None, # type: ignore[arg-type] + call=CALL, + tool_def=TOOL_DEF, + args={}, + result='x' * 200, + ) + assert result == 'Async summary: 200 chars' + + @pytest.mark.anyio + async def test_summarize_fn_safety_net(self) -> None: + """If summarize_fn returns something still too long, truncation kicks in.""" + + def bad_summarize(tool_name: str, output: str) -> str: + return output # returns full output, still too long + + cap: ToolOutputManagement[None] = ToolOutputManagement( + max_output_chars=50, + summarize_fn=bad_summarize, + 
) + result = await cap.after_tool_execute( + None, # type: ignore[arg-type] + call=CALL, + tool_def=TOOL_DEF, + args={}, + result='x' * 200, + ) + assert 'Truncated' in result + + @pytest.mark.anyio + async def test_summarize_fn_not_called_under_limit(self) -> None: + calls: list[str] = [] + + def summarize(tool_name: str, output: str) -> str: # pragma: no cover + calls.append(tool_name) + return 'summarized' + + cap: ToolOutputManagement[None] = ToolOutputManagement( + max_output_chars=100, + summarize_fn=summarize, + ) + result = await cap.after_tool_execute( + None, # type: ignore[arg-type] + call=CALL, + tool_def=TOOL_DEF, + args={}, + result='short', + ) + assert result == 'short' + assert calls == [] + + @pytest.mark.anyio + async def test_original_returned_when_exactly_at_limit(self) -> None: + cap: ToolOutputManagement[None] = ToolOutputManagement(max_output_chars=10) + result = await cap.after_tool_execute( + None, # type: ignore[arg-type] + call=CALL, + tool_def=TOOL_DEF, + args={}, + result='a' * 10, + ) + assert result == 'a' * 10 + + +# --------------------------------------------------------------------------- +# Test public exports +# --------------------------------------------------------------------------- + + +class TestExports: + def test_import_from_package(self) -> None: + from pydantic_harness import ToolOutputManagement, TruncationStrategy + + assert ToolOutputManagement is not None + assert TruncationStrategy is not None + + +# --------------------------------------------------------------------------- +# Unit tests: is_binary +# --------------------------------------------------------------------------- + + +class TestIsBinary: + def test_bytes(self) -> None: + assert is_binary(b'\x00\x01\x02') is True + + def test_bytearray(self) -> None: + assert is_binary(bytearray(b'\xff')) is True + + def test_memoryview(self) -> None: + assert is_binary(memoryview(b'hello')) is True + + def test_str_not_binary(self) -> None: + assert 
is_binary('hello') is False + + def test_int_not_binary(self) -> None: + assert is_binary(42) is False + + def test_none_not_binary(self) -> None: + assert is_binary(None) is False + + +# --------------------------------------------------------------------------- +# Unit tests: strip_ansi +# --------------------------------------------------------------------------- + + +class TestStripAnsi: + def test_plain_text_unchanged(self) -> None: + assert strip_ansi('hello world') == 'hello world' + + def test_strips_color_codes(self) -> None: + text = '\x1b[31mERROR\x1b[0m: something failed' + assert strip_ansi(text) == 'ERROR: something failed' + + def test_strips_bold_and_reset(self) -> None: + text = '\x1b[1mBold\x1b[0m Normal' + assert strip_ansi(text) == 'Bold Normal' + + def test_strips_multiple_sequences(self) -> None: + text = '\x1b[32m✓\x1b[0m test1\n\x1b[31m✗\x1b[0m test2' + assert strip_ansi(text) == '✓ test1\n✗ test2' + + def test_empty_string(self) -> None: + assert strip_ansi('') == '' + + +class TestStripAnsiIntegration: + @pytest.mark.anyio + async def test_ansi_stripped_before_measurement(self) -> None: + """ANSI codes should be stripped before size check so they don't count toward limit.""" + cap: ToolOutputManagement[None] = ToolOutputManagement(max_output_chars=50) + # 30 visible chars wrapped in many ANSI codes pushes raw length over 50 + ansi_text = '\x1b[1m\x1b[31m\x1b[4m' + 'x' * 30 + '\x1b[0m\x1b[0m\x1b[0m' + assert len(ansi_text) > 50 # raw is over limit + assert len(strip_ansi(ansi_text)) == 30 # stripped is under + result = await cap.after_tool_execute( + None, # type: ignore[arg-type] + call=CALL, + tool_def=TOOL_DEF, + args={}, + result=ansi_text, + ) + # The stripped text (30 chars) is under the limit, so no truncation. + # The cleaned (ANSI-free) text is returned so the model never sees escape codes. 
+ assert result == 'x' * 30 + assert '\x1b[' not in result + + @pytest.mark.anyio + async def test_ansi_stripped_in_truncated_output(self) -> None: + """When output is truncated, ANSI codes should be stripped from the result.""" + cap: ToolOutputManagement[None] = ToolOutputManagement(max_output_chars=20) + ansi_text = '\x1b[31m' + 'x' * 100 + '\x1b[0m' + result = await cap.after_tool_execute( + None, # type: ignore[arg-type] + call=CALL, + tool_def=TOOL_DEF, + args={}, + result=ansi_text, + ) + assert '\x1b[' not in result + assert 'Truncated' in result + + @pytest.mark.anyio + async def test_strip_ansi_disabled(self) -> None: + """When strip_ansi=False, ANSI codes are preserved.""" + cap: ToolOutputManagement[None] = ToolOutputManagement( + max_output_chars=20, + strip_ansi=False, + ) + ansi_text = '\x1b[31m' + 'x' * 100 + '\x1b[0m' + result = await cap.after_tool_execute( + None, # type: ignore[arg-type] + call=CALL, + tool_def=TOOL_DEF, + args={}, + result=ansi_text, + ) + assert 'Truncated' in result + # ANSI codes should still be present in the truncated portion + assert '\x1b[' in result + + +# --------------------------------------------------------------------------- +# Unit tests: truncate_by_lines +# --------------------------------------------------------------------------- + + +class TestTruncateByLines: + def test_no_truncation_needed(self) -> None: + text = 'line1\nline2\nline3' + assert truncate_by_lines(text, 5, TruncationStrategy.head) == text + + def test_head_strategy(self) -> None: + text = '\n'.join(f'line{i}' for i in range(20)) + result = truncate_by_lines(text, 5, TruncationStrategy.head) + assert result.startswith('line0\n') + assert '[Truncated: showing first 5 of 20 lines]' in result + + def test_tail_strategy(self) -> None: + text = '\n'.join(f'line{i}' for i in range(20)) + result = truncate_by_lines(text, 5, TruncationStrategy.tail) + assert 'line19' in result + assert '[Truncated: showing last 5 of 20 lines]' in result + + def 
test_head_tail_strategy(self) -> None: + text = '\n'.join(f'line{i}' for i in range(100)) + result = truncate_by_lines(text, 10, TruncationStrategy.head_tail) + # head=4 lines, tail=6 lines (tail-heavy split) + assert 'line0' in result + assert 'line99' in result + assert 'omitted from middle' in result + assert '90' in result # 100 - 4 - 6 = 90 omitted + + def test_exact_boundary(self) -> None: + text = 'line1\nline2\nline3' + assert truncate_by_lines(text, 3, TruncationStrategy.head) == text + + +# --------------------------------------------------------------------------- +# Integration tests: line-count limits +# --------------------------------------------------------------------------- + + +class TestLineCountLimits: + @pytest.mark.anyio + async def test_line_limit_triggers_before_char_limit(self) -> None: + cap: ToolOutputManagement[None] = ToolOutputManagement( + max_output_chars=100_000, + max_output_lines=5, + strategy=TruncationStrategy.head, + ) + text = '\n'.join(f'line{i}' for i in range(20)) + result = await cap.after_tool_execute( + None, # type: ignore[arg-type] + call=CALL, + tool_def=TOOL_DEF, + args={}, + result=text, + ) + assert 'showing first 5 of 20 lines' in result + + @pytest.mark.anyio + async def test_char_limit_triggers_before_line_limit(self) -> None: + cap: ToolOutputManagement[None] = ToolOutputManagement( + max_output_chars=50, + max_output_lines=1000, + strategy=TruncationStrategy.head, + ) + text = 'x' * 200 + result = await cap.after_tool_execute( + None, # type: ignore[arg-type] + call=CALL, + tool_def=TOOL_DEF, + args={}, + result=text, + ) + assert 'showing first 50 of 200 chars' in result + + @pytest.mark.anyio + async def test_under_both_limits(self) -> None: + cap: ToolOutputManagement[None] = ToolOutputManagement( + max_output_chars=1000, + max_output_lines=10, + ) + text = 'short\ntext' + result = await cap.after_tool_execute( + None, # type: ignore[arg-type] + call=CALL, + tool_def=TOOL_DEF, + args={}, + result=text, + ) 
+ assert result == text + + @pytest.mark.anyio + async def test_per_tool_line_limit(self) -> None: + cap: ToolOutputManagement[None] = ToolOutputManagement( + max_output_chars=100_000, + max_output_lines=100, + per_tool_line_limits={'my_tool': 3}, + strategy=TruncationStrategy.head, + ) + text = '\n'.join(f'line{i}' for i in range(10)) + result = await cap.after_tool_execute( + None, # type: ignore[arg-type] + call=CALL, + tool_def=TOOL_DEF, + args={}, + result=text, + ) + assert 'showing first 3 of 10 lines' in result + + @pytest.mark.anyio + async def test_line_limit_only(self) -> None: + """When max_output_lines is set but char limit is very high, line limit alone fires.""" + cap: ToolOutputManagement[None] = ToolOutputManagement( + max_output_chars=1_000_000, + max_output_lines=3, + strategy=TruncationStrategy.tail, + ) + text = '\n'.join(f'line{i}' for i in range(10)) + result = await cap.after_tool_execute( + None, # type: ignore[arg-type] + call=CALL, + tool_def=TOOL_DEF, + args={}, + result=text, + ) + assert 'showing last 3 of 10 lines' in result + + @pytest.mark.anyio + async def test_both_lines_and_chars_exceed_double_truncation(self) -> None: + """When both limits fire, line truncation is applied first, then char truncation.""" + cap: ToolOutputManagement[None] = ToolOutputManagement( + max_output_chars=50, + max_output_lines=5, + strategy=TruncationStrategy.head, + ) + # 20 lines of 100 chars each — after line truncation to 5 lines the + # result is still well over 50 chars, so char truncation kicks in too. + text = '\n'.join('x' * 100 for _ in range(20)) + result = await cap.after_tool_execute( + None, # type: ignore[arg-type] + call=CALL, + tool_def=TOOL_DEF, + args={}, + result=text, + ) + # Char-level truncation marker should appear in the final output. 
+ assert 'chars' in result + + +# --------------------------------------------------------------------------- +# Integration tests: spill-to-file +# --------------------------------------------------------------------------- + + +class TestSpillToFile: + @pytest.mark.anyio + async def test_spill_creates_file(self, tmp_path: Path) -> None: + cap: ToolOutputManagement[None] = ToolOutputManagement( + max_output_chars=50, + spill_to_file=True, + spill_dir=tmp_path, + ) + long_text = 'x' * 200 + result = await cap.after_tool_execute( + None, # type: ignore[arg-type] + call=CALL, + tool_def=TOOL_DEF, + args={}, + result=long_text, + ) + assert isinstance(result, str) + assert '[Full output (200 chars) saved to' in result + assert 'Truncated' in result + + # Verify the file actually exists and contains the full output + spill_files = list(tmp_path.glob('tool_output_*.txt')) + assert len(spill_files) == 1 + assert spill_files[0].read_text(encoding='utf-8') == long_text + + @pytest.mark.anyio + async def test_spill_default_dir(self) -> None: + cap: ToolOutputManagement[None] = ToolOutputManagement( + max_output_chars=50, + spill_to_file=True, + ) + long_text = 'x' * 200 + result = await cap.after_tool_execute( + None, # type: ignore[arg-type] + call=CALL, + tool_def=TOOL_DEF, + args={}, + result=long_text, + ) + assert '[Full output (200 chars) saved to' in result + + @pytest.mark.anyio + async def test_spill_not_triggered_under_limit(self, tmp_path: Path) -> None: + cap: ToolOutputManagement[None] = ToolOutputManagement( + max_output_chars=1000, + spill_to_file=True, + spill_dir=tmp_path, + ) + result = await cap.after_tool_execute( + None, # type: ignore[arg-type] + call=CALL, + tool_def=TOOL_DEF, + args={}, + result='short', + ) + assert result == 'short' + assert list(tmp_path.glob('tool_output_*.txt')) == [] + + @pytest.mark.anyio + async def test_spill_with_summarize_fn_safety_net(self, tmp_path: Path) -> None: + """When summarize_fn still produces oversized output, 
spill kicks in.""" + + def bad_summarize(tool_name: str, output: str) -> str: + return output # returns full output, still too long + + cap: ToolOutputManagement[None] = ToolOutputManagement( + max_output_chars=50, + summarize_fn=bad_summarize, + spill_to_file=True, + spill_dir=tmp_path, + ) + result = await cap.after_tool_execute( + None, # type: ignore[arg-type] + call=CALL, + tool_def=TOOL_DEF, + args={}, + result='x' * 200, + ) + assert '[Full output' in result + assert len(list(tmp_path.glob('tool_output_*.txt'))) == 1 + + +# --------------------------------------------------------------------------- +# Integration tests: binary detection +# --------------------------------------------------------------------------- + + +class TestBinaryDetection: + @pytest.mark.anyio + async def test_bytes_result(self) -> None: + cap: ToolOutputManagement[None] = ToolOutputManagement(max_output_chars=50) + result = await cap.after_tool_execute( + None, # type: ignore[arg-type] + call=CALL, + tool_def=TOOL_DEF, + args={}, + result=b'\x89PNG\r\n\x1a\n' + b'\x00' * 1000, + ) + assert result == '[Binary data, 1,008 bytes]' + + @pytest.mark.anyio + async def test_bytearray_result(self) -> None: + cap: ToolOutputManagement[None] = ToolOutputManagement(max_output_chars=50) + result = await cap.after_tool_execute( + None, # type: ignore[arg-type] + call=CALL, + tool_def=TOOL_DEF, + args={}, + result=bytearray(b'\xff' * 256), + ) + assert result == '[Binary data, 256 bytes]' + + @pytest.mark.anyio + async def test_memoryview_result(self) -> None: + cap: ToolOutputManagement[None] = ToolOutputManagement(max_output_chars=50) + data = b'hello world' + result = await cap.after_tool_execute( + None, # type: ignore[arg-type] + call=CALL, + tool_def=TOOL_DEF, + args={}, + result=memoryview(data), + ) + assert result == '[Binary data, 11 bytes]' + + @pytest.mark.anyio + async def test_small_bytes_still_detected(self) -> None: + """Binary detection applies regardless of size -- even small 
bytes are replaced.""" + cap: ToolOutputManagement[None] = ToolOutputManagement(max_output_chars=10_000) + result = await cap.after_tool_execute( + None, # type: ignore[arg-type] + call=CALL, + tool_def=TOOL_DEF, + args={}, + result=b'hi', + ) + assert result == '[Binary data, 2 bytes]'