From ce047971fb4ca2831fe7e61f9fa730233f0f821e Mon Sep 17 00:00:00 2001
From: Douwe Maan
Date: Thu, 2 Apr 2026 05:11:11 +0000
Subject: [PATCH 1/2] Add AdaptiveReasoning capability for per-step thinking
 effort selection

Implements a capability that dynamically adjusts model thinking effort
based on task complexity signals via `get_model_settings()` returning a
callable. The built-in heuristic uses high effort on the first step and
after tool errors, and low effort on simple follow-ups. Supports custom
effort functions for domain-specific logic.

Closes #84

Co-Authored-By: Claude Opus 4.6 (1M context)
---
 PLAN.md                                    |  84 +++++++++++
 src/pydantic_harness/__init__.py           |   6 +-
 src/pydantic_harness/adaptive_reasoning.py |  99 ++++++++++++
 tests/test_adaptive_reasoning.py           | 178 +++++++++++++++++++++
 4 files changed, 366 insertions(+), 1 deletion(-)
 create mode 100644 PLAN.md
 create mode 100644 src/pydantic_harness/adaptive_reasoning.py
 create mode 100644 tests/test_adaptive_reasoning.py

diff --git a/PLAN.md b/PLAN.md
new file mode 100644
index 0000000..e3a9a9e
--- /dev/null
+++ b/PLAN.md
@@ -0,0 +1,84 @@
+# AdaptiveReasoning Capability
+
+Closes #84
+
+## Summary
+
+A capability that dynamically adjusts model thinking effort per agent step,
+using the `get_model_settings()` callable mechanism from Pydantic AI's
+capabilities abstraction. This reduces token usage on simple steps (file reads,
+straightforward follow-ups) while preserving deep reasoning for complex
+decisions (first-step task understanding, error recovery).
+
+## Design
+
+### Approach: capability with dynamic `get_model_settings`
+
+As noted by @DouweM in #84, this is cleanly implemented as a capability whose
+`get_model_settings()` returns a callable receiving `RunContext`. The callable
+inspects `ctx.run_step` and `ctx.messages` to select an effort level, then
+returns `ModelSettings(thinking=...)`.
+
+This leverages two existing Pydantic AI primitives:
+1. **Dynamic model settings** (callable `get_model_settings`) -- resolved per
+   model request with the current `RunContext`
+2. **Unified `thinking` setting** -- maps to provider-specific parameters
+   (Claude thinking budget, OpenAI reasoning_effort, etc.)
+
+### Effort levels
+
+Three coarse levels (`'low'`, `'medium'`, `'high'`) mapped directly to
+`ThinkingEffort` values. These are a subset of the full `ThinkingEffort` scale
+(`'minimal'`/`'low'`/`'medium'`/`'high'`/`'xhigh'`) chosen to match the
+research literature (Ares uses three tiers) and keep the API simple.
+
+### Built-in heuristic (`default_effort_fn`)
+
+Rules evaluated in order:
+1. First step (`run_step <= 1`): `'high'` -- understand the task
+2. After tool errors (retry prompts in the latest request): `'high'` -- reason about failures
+3. Later steps without errors: `'low'` -- simple follow-ups incorporating tool results
+
+### Custom effort function
+
+Users can supply `effort_fn: Callable[[RunContext], Literal['low', 'medium', 'high']]`
+to override the built-in heuristic with domain-specific logic.
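+
+For example, a minimal sketch of domain-specific wiring (the escalation rule
+and step thresholds below are illustrative; the `Agent(..., capabilities=[...])`
+hookup mirrors the capability's docstring example):
+
+```python
+from typing import Any, Literal
+
+from pydantic_ai import Agent
+from pydantic_ai._run_context import RunContext
+
+from pydantic_harness import AdaptiveReasoning
+
+
+def escalate_when_stuck(ctx: RunContext[Any]) -> Literal['low', 'medium', 'high']:
+    # Illustrative rule: a long run suggests the agent is stuck, so think hard
+    # from step 8 onwards; otherwise think hard only on the first step and
+    # stay cheap on follow-ups.
+    if ctx.run_step <= 1 or ctx.run_step >= 8:
+        return 'high'
+    return 'low'
+
+
+agent = Agent(..., capabilities=[AdaptiveReasoning(effort_fn=escalate_when_stuck)])
+```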
+
+## Files
+
+| File | Purpose |
+|------|---------|
+| `src/pydantic_harness/adaptive_reasoning.py` | `AdaptiveReasoning` capability, `default_effort_fn`, `EffortLevel` type alias |
+| `src/pydantic_harness/__init__.py` | Re-export `AdaptiveReasoning` |
+| `tests/test_adaptive_reasoning.py` | 18 tests covering helper, heuristic, capability, and custom fn |
+
+## Not included (future work)
+
+- **`ModelRoutedEffort`**: Small model (Haiku-class) predicting effort from history (the Ares approach). This is a natural follow-up but requires model call infrastructure.
+- **`PhaseBasedEffort`**: High for planning, medium for execution, high for verification. Requires a phase detection mechanism.
+- **Provider-specific token budgets**: Mapping effort levels to concrete `budget_tokens` values per provider. The current implementation uses the unified `thinking` setting, which handles this portably.
diff --git a/src/pydantic_harness/__init__.py b/src/pydantic_harness/__init__.py
index 9d728b6..196d809 100644
--- a/src/pydantic_harness/__init__.py
+++ b/src/pydantic_harness/__init__.py
@@ -7,4 +7,8 @@
 # Each capability module is imported and re-exported here.
 # Capabilities are listed alphabetically.
 
-__all__: list[str] = []
+from pydantic_harness.adaptive_reasoning import AdaptiveReasoning
+
+__all__: list[str] = [
+    'AdaptiveReasoning',
+]
diff --git a/src/pydantic_harness/adaptive_reasoning.py b/src/pydantic_harness/adaptive_reasoning.py
new file mode 100644
index 0000000..53fb285
--- /dev/null
+++ b/src/pydantic_harness/adaptive_reasoning.py
@@ -0,0 +1,99 @@
+"""Adaptive reasoning effort capability.
+
+Dynamically adjusts the model's thinking effort level per step based on
+task complexity signals, reducing token usage on simple steps while
+preserving deep reasoning for complex decisions.
+"""
+
+from __future__ import annotations
+
+from collections.abc import Callable, Sequence
+from dataclasses import dataclass, field
+from typing import Any, Literal, TypeAlias
+
+from pydantic_ai._run_context import RunContext
+from pydantic_ai.capabilities.abstract import AbstractCapability
+from pydantic_ai.messages import ModelMessage, ModelRequest, RetryPromptPart
+from pydantic_ai.settings import ModelSettings, ThinkingEffort
+
+EffortLevel: TypeAlias = Literal['low', 'medium', 'high']
+"""The coarse effort levels used by adaptive reasoning.
+
+Mapped to the full ``ThinkingEffort`` scale when applied to model settings.
+"""
+
+_EFFORT_TO_THINKING: dict[EffortLevel, ThinkingEffort] = {
+    'low': 'low',
+    'medium': 'medium',
+    'high': 'high',
+}
+
+
+def _has_tool_errors(messages: Sequence[ModelMessage]) -> bool:
+    """Check whether the most recent request message contains retry prompts (tool errors)."""
+    for msg in reversed(messages):
+        if isinstance(msg, ModelRequest):
+            return any(isinstance(part, RetryPromptPart) for part in msg.parts)
+    return False
+
+
+def default_effort_fn(ctx: RunContext[Any]) -> Literal['low', 'medium', 'high']:
+    """Built-in heuristic effort selector.
+
+    Rules (evaluated in order):
+    1. First step (``run_step <= 1``): ``'high'`` -- understand the task.
+    2. After tool errors (retry prompts in the latest request): ``'high'`` -- reason about failures.
+    3. Steps 2+ with no errors: ``'low'`` -- simple follow-ups incorporating tool results.
+    4. Default: ``'medium'``.
+    """
+    if ctx.run_step <= 1:
+        return 'high'
+
+    if _has_tool_errors(ctx.messages):
+        return 'high'
+
+    # Later steps without errors are typically straightforward follow-ups.
+    if ctx.run_step >= 2:
+        return 'low'
+
+    return 'medium'  # pragma: no cover
+
+
+@dataclass
+class AdaptiveReasoning(AbstractCapability[Any]):
+    """Dynamically adjusts model thinking effort per step.
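+
+    The selection takes effect through Pydantic AI's dynamic model settings:
+    ``get_model_settings()`` returns a callable that is resolved with the
+    current ``RunContext`` before each model request, and the chosen level is
+    applied via the unified ``thinking`` setting, which maps to
+    provider-specific parameters (Claude thinking budget, OpenAI
+    reasoning_effort, etc.).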
+ + By default a built-in heuristic is used: + + * **First step** -> ``'high'`` (understand the task) + * **After tool errors** -> ``'high'`` (reason about what went wrong) + * **Simple follow-ups** -> ``'low'`` (just incorporating tool results) + + Supply a custom ``effort_fn`` to override these rules:: + + def my_effort(ctx: RunContext[MyDeps]) -> Literal['low', 'medium', 'high']: + if ctx.run_step > 5: + return 'high' # wrap-up needs careful thought + return 'medium' + + agent = Agent(..., capabilities=[AdaptiveReasoning(effort_fn=my_effort)]) + """ + + effort_fn: Callable[[RunContext[Any]], Literal['low', 'medium', 'high']] = field(default=default_effort_fn) + """Callable that receives the current ``RunContext`` and returns an effort level.""" + + def get_model_settings(self) -> Callable[[RunContext[Any]], ModelSettings]: + """Return a dynamic model-settings callable that sets ``thinking`` per step.""" + + def _resolve(ctx: RunContext[Any]) -> ModelSettings: + effort = self.effort_fn(ctx) + return ModelSettings(thinking=_EFFORT_TO_THINKING[effort]) + + return _resolve diff --git a/tests/test_adaptive_reasoning.py b/tests/test_adaptive_reasoning.py new file mode 100644 index 0000000..757850b --- /dev/null +++ b/tests/test_adaptive_reasoning.py @@ -0,0 +1,178 @@ +from __future__ import annotations + +# pyright: reportPrivateUsage=false +from typing import Any, Literal +from unittest.mock import MagicMock + +from pydantic_ai._run_context import RunContext +from pydantic_ai.messages import ( + ModelMessage, + ModelRequest, + ModelResponse, + RetryPromptPart, + TextPart, + ToolReturnPart, + UserPromptPart, +) +from pydantic_ai.settings import ModelSettings + +from pydantic_harness import AdaptiveReasoning +from pydantic_harness.adaptive_reasoning import _has_tool_errors, default_effort_fn + + +def _make_ctx( + *, + run_step: int = 0, + messages: list[Any] | None = None, +) -> RunContext[None]: + """Build a minimal RunContext for testing.""" + model = MagicMock() + model.system = 'test' + ctx = RunContext[None]( + deps=None, + model=model, + usage=MagicMock(), + messages=messages or [], + run_step=run_step, + ) + return ctx + + +# --- _has_tool_errors --- + + +class TestHasToolErrors: + def test_no_messages(self) -> None: + assert _has_tool_errors([]) is False + + def test_no_retry_parts(self) -> None: + messages: list[ModelMessage] = [ModelRequest(parts=[UserPromptPart(content='hello')])] + assert _has_tool_errors(messages) is False + + def test_with_retry_part(self) -> None: + messages: list[ModelMessage] = [ + ModelRequest( + parts=[ + RetryPromptPart(content='validation failed', tool_name='my_tool'), + ] + ), + ] + assert _has_tool_errors(messages) is True + + def test_checks_latest_request(self) -> None: + """Only the most recent ModelRequest is inspected.""" + old_request = ModelRequest(parts=[RetryPromptPart(content='old error', tool_name='my_tool')]) + new_request = ModelRequest(parts=[ToolReturnPart(tool_name='my_tool', content='ok')]) + messages: list[ModelMessage] = [old_request, new_request] + # Most recent message is the one without errors. 
+ assert _has_tool_errors(messages) is False + + def test_skips_model_responses(self) -> None: + """ModelResponse objects are skipped when searching for the latest request.""" + request_with_error = ModelRequest(parts=[RetryPromptPart(content='error', tool_name='my_tool')]) + response = ModelResponse(parts=[TextPart(content='ok')]) + messages: list[ModelMessage] = [request_with_error, response] + assert _has_tool_errors(messages) is True + + +# --- default_effort_fn --- + + +class TestDefaultEffortFn: + def test_first_step_high(self) -> None: + ctx = _make_ctx(run_step=1) + assert default_effort_fn(ctx) == 'high' + + def test_step_zero_high(self) -> None: + ctx = _make_ctx(run_step=0) + assert default_effort_fn(ctx) == 'high' + + def test_after_tool_error_high(self) -> None: + messages = [ + ModelRequest(parts=[RetryPromptPart(content='bad args', tool_name='t')]), + ] + ctx = _make_ctx(run_step=3, messages=messages) + assert default_effort_fn(ctx) == 'high' + + def test_simple_followup_low(self) -> None: + messages = [ + ModelRequest(parts=[ToolReturnPart(tool_name='t', content='result')]), + ] + ctx = _make_ctx(run_step=2, messages=messages) + assert default_effort_fn(ctx) == 'low' + + def test_later_step_no_errors_low(self) -> None: + ctx = _make_ctx(run_step=5, messages=[]) + assert default_effort_fn(ctx) == 'low' + + +# --- AdaptiveReasoning capability --- + + +class TestAdaptiveReasoning: + def test_default_construction(self) -> None: + cap = AdaptiveReasoning() + assert cap.effort_fn is default_effort_fn + + def test_get_model_settings_returns_callable(self) -> None: + cap = AdaptiveReasoning() + settings_fn = cap.get_model_settings() + assert callable(settings_fn) + + def test_dynamic_settings_first_step(self) -> None: + cap = AdaptiveReasoning() + settings_fn = cap.get_model_settings() + ctx = _make_ctx(run_step=1) + result = settings_fn(ctx) + assert result == ModelSettings(thinking='high') + + def test_dynamic_settings_followup(self) -> None: + cap = AdaptiveReasoning() + settings_fn = cap.get_model_settings() + messages = [ + ModelRequest(parts=[ToolReturnPart(tool_name='t', content='ok')]), + ] + ctx = _make_ctx(run_step=3, messages=messages) + result = settings_fn(ctx) + assert result == ModelSettings(thinking='low') + + def test_dynamic_settings_after_error(self) -> None: + cap = AdaptiveReasoning() + settings_fn = cap.get_model_settings() + messages = [ + ModelRequest(parts=[RetryPromptPart(content='err', tool_name='t')]), + ] + ctx = _make_ctx(run_step=4, messages=messages) + result = settings_fn(ctx) + assert result == ModelSettings(thinking='high') + + def test_custom_effort_fn(self) -> None: + def always_medium(ctx: RunContext[Any]) -> Literal['low', 'medium', 'high']: + return 'medium' + + cap = AdaptiveReasoning(effort_fn=always_medium) + settings_fn = cap.get_model_settings() + ctx = _make_ctx(run_step=1) + result = settings_fn(ctx) + assert result == ModelSettings(thinking='medium') + + def test_custom_effort_fn_context_aware(self) -> None: + def step_based(ctx: RunContext[Any]) -> Literal['low', 'medium', 'high']: + if ctx.run_step > 10: + return 'high' + return 'low' + + cap = AdaptiveReasoning(effort_fn=step_based) + settings_fn = cap.get_model_settings() + + ctx_early = _make_ctx(run_step=2) + assert settings_fn(ctx_early) == ModelSettings(thinking='low') + + ctx_late = _make_ctx(run_step=11) + assert settings_fn(ctx_late) == ModelSettings(thinking='high') + + def test_is_abstract_capability(self) -> None: + from pydantic_ai.capabilities.abstract import 
AbstractCapability
+
+        cap = AdaptiveReasoning()
+        assert isinstance(cap, AbstractCapability)

From 72baecfad0d31a636d2c2f4aed60152a936fb1fb Mon Sep 17 00:00:00 2001
From: Douwe Maan
Date: Thu, 2 Apr 2026 05:51:40 +0000
Subject: [PATCH 2/2] Fix effort heuristic and add tool-count signal

- Step 2 now returns medium (was incorrectly going straight to low)
- Step 3+ returns low as intended
- Removed unreachable default branch
- Added many-tool-calls signal: if the last ModelResponse had 3+
  ToolCallParts, use medium effort regardless of step number

Co-Authored-By: Claude Opus 4.6 (1M context)
---
 src/pydantic_harness/adaptive_reasoning.py |  35 ++++--
 tests/test_adaptive_reasoning.py           | 120 ++++++++++++++++++++-
 2 files changed, 144 insertions(+), 11 deletions(-)

diff --git a/src/pydantic_harness/adaptive_reasoning.py b/src/pydantic_harness/adaptive_reasoning.py
index 53fb285..771b0b7 100644
--- a/src/pydantic_harness/adaptive_reasoning.py
+++ b/src/pydantic_harness/adaptive_reasoning.py
@@ -13,7 +13,7 @@
 
 from pydantic_ai._run_context import RunContext
 from pydantic_ai.capabilities.abstract import AbstractCapability
-from pydantic_ai.messages import ModelMessage, ModelRequest, RetryPromptPart
+from pydantic_ai.messages import ModelMessage, ModelRequest, ModelResponse, RetryPromptPart, ToolCallPart
 from pydantic_ai.settings import ModelSettings, ThinkingEffort
 
 EffortLevel: TypeAlias = Literal['low', 'medium', 'high']
@@ -29,6 +29,10 @@
 }
 
 
+_MANY_TOOL_CALLS_THRESHOLD = 3
+"""Minimum number of ``ToolCallPart`` parts in a single response that signals medium effort."""
+
+
 def _has_tool_errors(messages: Sequence[ModelMessage]) -> bool:
     """Check whether the most recent request message contains retry prompts (tool errors)."""
     for msg in reversed(messages):
@@ -37,14 +41,29 @@ def _has_tool_errors(messages: Sequence[ModelMessage]) -> bool:
     return False
 
 
+def _last_response_had_many_tool_calls(messages: Sequence[ModelMessage]) -> bool:
+    """Check whether the most recent ``ModelResponse`` had many tool call parts.
+
+    Returns ``True`` if the latest response contained at least
+    :data:`_MANY_TOOL_CALLS_THRESHOLD` ``ToolCallPart`` instances, signalling
+    a complex multi-tool orchestration step that deserves medium effort.
+    """
+    for msg in reversed(messages):
+        if isinstance(msg, ModelResponse):
+            tool_call_count = sum(1 for part in msg.parts if isinstance(part, ToolCallPart))
+            return tool_call_count >= _MANY_TOOL_CALLS_THRESHOLD
+    return False
+
+
 def default_effort_fn(ctx: RunContext[Any]) -> Literal['low', 'medium', 'high']:
     """Built-in heuristic effort selector.
 
     Rules (evaluated in order):
     1. First step (``run_step <= 1``): ``'high'`` -- understand the task.
     2. After tool errors (retry prompts in the latest request): ``'high'`` -- reason about failures.
-    3. Steps 2+ with no errors: ``'low'`` -- simple follow-ups incorporating tool results.
-    4. Default: ``'medium'``.
+    3. Many tool calls (3+ ``ToolCallPart`` in last response): ``'medium'`` -- complex orchestration.
+    4. Second step (``run_step == 2``): ``'medium'`` -- still building context.
+    5. Later steps (``run_step >= 3``): ``'low'`` -- simple follow-ups.
     """
     if ctx.run_step <= 1:
         return 'high'
 
@@ -52,11 +71,13 @@ def default_effort_fn(ctx: RunContext[Any]) -> Literal['low', 'medium', 'high']:
     if _has_tool_errors(ctx.messages):
         return 'high'
 
-    # Later steps without errors are typically straightforward follow-ups.
- if ctx.run_step >= 2: - return 'low' + if _last_response_had_many_tool_calls(ctx.messages): + return 'medium' + + if ctx.run_step == 2: + return 'medium' - return 'medium' # pragma: no cover + return 'low' @dataclass diff --git a/tests/test_adaptive_reasoning.py b/tests/test_adaptive_reasoning.py index 757850b..221864b 100644 --- a/tests/test_adaptive_reasoning.py +++ b/tests/test_adaptive_reasoning.py @@ -11,13 +11,18 @@ ModelResponse, RetryPromptPart, TextPart, + ToolCallPart, ToolReturnPart, UserPromptPart, ) from pydantic_ai.settings import ModelSettings from pydantic_harness import AdaptiveReasoning -from pydantic_harness.adaptive_reasoning import _has_tool_errors, default_effort_fn +from pydantic_harness.adaptive_reasoning import ( + _has_tool_errors, + _last_response_had_many_tool_calls, + default_effort_fn, +) def _make_ctx( @@ -75,6 +80,67 @@ def test_skips_model_responses(self) -> None: assert _has_tool_errors(messages) is True +# --- _last_response_had_many_tool_calls --- + + +class TestLastResponseHadManyToolCalls: + def test_no_messages(self) -> None: + assert _last_response_had_many_tool_calls([]) is False + + def test_no_response(self) -> None: + messages: list[ModelMessage] = [ModelRequest(parts=[UserPromptPart(content='hi')])] + assert _last_response_had_many_tool_calls(messages) is False + + def test_below_threshold(self) -> None: + messages: list[ModelMessage] = [ + ModelResponse( + parts=[ + ToolCallPart(tool_name='t1', args='{}'), + ToolCallPart(tool_name='t2', args='{}'), + ] + ), + ] + assert _last_response_had_many_tool_calls(messages) is False + + def test_at_threshold(self) -> None: + messages: list[ModelMessage] = [ + ModelResponse( + parts=[ + ToolCallPart(tool_name='t1', args='{}'), + ToolCallPart(tool_name='t2', args='{}'), + ToolCallPart(tool_name='t3', args='{}'), + ] + ), + ] + assert _last_response_had_many_tool_calls(messages) is True + + def test_checks_latest_response(self) -> None: + """Only the most recent ModelResponse is inspected.""" + old_response = ModelResponse( + parts=[ + ToolCallPart(tool_name='t1', args='{}'), + ToolCallPart(tool_name='t2', args='{}'), + ToolCallPart(tool_name='t3', args='{}'), + ] + ) + new_response = ModelResponse(parts=[TextPart(content='done')]) + messages: list[ModelMessage] = [old_response, new_response] + assert _last_response_had_many_tool_calls(messages) is False + + def test_skips_requests(self) -> None: + """ModelRequest objects are skipped when searching.""" + response = ModelResponse( + parts=[ + ToolCallPart(tool_name='t1', args='{}'), + ToolCallPart(tool_name='t2', args='{}'), + ToolCallPart(tool_name='t3', args='{}'), + ] + ) + request = ModelRequest(parts=[ToolReturnPart(tool_name='t1', content='ok')]) + messages: list[ModelMessage] = [response, request] + assert _last_response_had_many_tool_calls(messages) is True + + # --- default_effort_fn --- @@ -94,17 +160,53 @@ def test_after_tool_error_high(self) -> None: ctx = _make_ctx(run_step=3, messages=messages) assert default_effort_fn(ctx) == 'high' - def test_simple_followup_low(self) -> None: + def test_step_two_medium(self) -> None: messages = [ ModelRequest(parts=[ToolReturnPart(tool_name='t', content='result')]), ] ctx = _make_ctx(run_step=2, messages=messages) - assert default_effort_fn(ctx) == 'low' + assert default_effort_fn(ctx) == 'medium' def test_later_step_no_errors_low(self) -> None: ctx = _make_ctx(run_step=5, messages=[]) assert default_effort_fn(ctx) == 'low' + def test_step_three_low(self) -> None: + messages = [ + 
ModelRequest(parts=[ToolReturnPart(tool_name='t', content='result')]), + ] + ctx = _make_ctx(run_step=3, messages=messages) + assert default_effort_fn(ctx) == 'low' + + def test_many_tool_calls_medium(self) -> None: + """A response with 3+ ToolCallParts should trigger medium effort.""" + messages: list[Any] = [ + ModelResponse( + parts=[ + ToolCallPart(tool_name='t1', args='{}'), + ToolCallPart(tool_name='t2', args='{}'), + ToolCallPart(tool_name='t3', args='{}'), + ] + ), + ModelRequest(parts=[ToolReturnPart(tool_name='t1', content='ok')]), + ] + ctx = _make_ctx(run_step=5, messages=messages) + assert default_effort_fn(ctx) == 'medium' + + def test_few_tool_calls_not_medium(self) -> None: + """A response with <3 ToolCallParts should not trigger medium effort at step 3+.""" + messages: list[Any] = [ + ModelResponse( + parts=[ + ToolCallPart(tool_name='t1', args='{}'), + ToolCallPart(tool_name='t2', args='{}'), + ] + ), + ModelRequest(parts=[ToolReturnPart(tool_name='t1', content='ok')]), + ] + ctx = _make_ctx(run_step=5, messages=messages) + assert default_effort_fn(ctx) == 'low' + # --- AdaptiveReasoning capability --- @@ -126,7 +228,17 @@ def test_dynamic_settings_first_step(self) -> None: result = settings_fn(ctx) assert result == ModelSettings(thinking='high') - def test_dynamic_settings_followup(self) -> None: + def test_dynamic_settings_step_two(self) -> None: + cap = AdaptiveReasoning() + settings_fn = cap.get_model_settings() + messages = [ + ModelRequest(parts=[ToolReturnPart(tool_name='t', content='ok')]), + ] + ctx = _make_ctx(run_step=2, messages=messages) + result = settings_fn(ctx) + assert result == ModelSettings(thinking='medium') + + def test_dynamic_settings_later_step_low(self) -> None: cap = AdaptiveReasoning() settings_fn = cap.get_model_settings() messages = [