# Add AdaptiveReasoning capability #174
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
# AdaptiveReasoning Capability

Closes #84

## Summary

A capability that dynamically adjusts model thinking effort per agent step,
using the `get_model_settings()` callable mechanism from Pydantic AI's
capabilities abstraction. This reduces token usage on simple steps (file reads,
straightforward follow-ups) while preserving deep reasoning for complex
decisions (first-step task understanding, error recovery).
## Design

### Approach: capability with dynamic `get_model_settings`

As noted by @DouweM in #84, this is cleanly implemented as a capability whose
`get_model_settings()` returns a callable receiving `RunContext`. The callable
inspects `ctx.run_step` and `ctx.messages` to select an effort level, then
returns `ModelSettings(thinking=...)`.

This leverages two existing Pydantic AI primitives:

1. **Dynamic model settings** (callable `get_model_settings`) -- resolved per
   model request with the current `RunContext`
2. **Unified `thinking` setting** -- maps to provider-specific parameters
   (Claude thinking budget, OpenAI `reasoning_effort`, etc.)
### Effort levels

Three coarse levels (`'low'`, `'medium'`, `'high'`) mapped directly to
`ThinkingEffort` values. These are a subset of the full `ThinkingEffort` scale
(`'minimal'`/`'low'`/`'medium'`/`'high'`/`'xhigh'`), chosen to match the
research literature (Ares uses three tiers) and keep the API simple.

### Built-in heuristic (`default_effort_fn`)

Rules evaluated in order:

1. First step (`run_step <= 1`): `'high'` -- understand the task
2. After tool errors (retry prompts in the latest request): `'high'` -- reason about failures
3. Many tool calls (3+ `ToolCallPart` in the last response): `'medium'` -- complex orchestration
4. Second step (`run_step == 2`): `'medium'` -- still building context
5. Later error-free steps (`run_step >= 3`): `'low'` -- simple follow-ups incorporating tool results
### Custom effort function

Users can supply `effort_fn: Callable[[RunContext], Literal['low', 'medium', 'high']]`
to override the built-in heuristic with domain-specific logic.
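As a sketch, a custom `effort_fn` only needs the context fields it actually inspects. The `FakeRunContext` class below is an illustrative stand-in for Pydantic AI's `RunContext`, not part of the library:

```python
from dataclasses import dataclass, field
from typing import Any, Literal


@dataclass
class FakeRunContext:
    """Minimal stand-in for pydantic_ai's RunContext, for illustration only."""

    run_step: int = 1
    messages: list[Any] = field(default_factory=list)


def my_effort(ctx: FakeRunContext) -> Literal['low', 'medium', 'high']:
    # Domain-specific rule: think hard when starting out and when
    # wrapping up a long run; coast in the middle.
    if ctx.run_step <= 1 or ctx.run_step > 5:
        return 'high'
    return 'medium'


print(my_effort(FakeRunContext(run_step=1)))  # high
print(my_effort(FakeRunContext(run_step=3)))  # medium
print(my_effort(FakeRunContext(run_step=7)))  # high
```

In real usage the function would receive the library's `RunContext` and could also inspect `ctx.messages`, as the built-in heuristic does.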
## Files

| File | Purpose |
|------|---------|
| `src/pydantic_harness/adaptive_reasoning.py` | `AdaptiveReasoning` capability, `default_effort_fn`, `EffortLevel` type alias |
| `src/pydantic_harness/__init__.py` | Re-export `AdaptiveReasoning` |
| `tests/test_adaptive_reasoning.py` | 18 tests covering helper, heuristic, capability, and custom fn |
## Not included (future work)

- **`ModelRoutedEffort`**: a small model (Haiku-class) predicting effort from history (the Ares approach). A natural follow-up, but it requires model-call infrastructure.
- **`PhaseBasedEffort`**: high for planning, medium for execution, high for verification. Requires a phase-detection mechanism.
- **Provider-specific token budgets**: mapping effort levels to concrete `budget_tokens` values per provider. The current implementation uses the unified `thinking` setting, which handles this portably.
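A provider-specific budget mapping could start as a simple lookup table. The `budget_for` helper and all token counts below are hypothetical illustrations, not part of this PR:

```python
# Hypothetical per-provider thinking-token budgets for each effort level.
# These numbers are illustrative placeholders, not tuned values.
BUDGETS: dict[str, dict[str, int]] = {
    'anthropic': {'low': 1024, 'medium': 4096, 'high': 16384},
    'openai': {'low': 512, 'medium': 2048, 'high': 8192},
}


def budget_for(provider: str, effort: str) -> int:
    """Look up a concrete thinking-token budget for a provider/effort pair."""
    return BUDGETS[provider][effort]


print(budget_for('anthropic', 'high'))  # 16384
print(budget_for('openai', 'low'))  # 512
```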
### `src/pydantic_harness/adaptive_reasoning.py`
| """Adaptive reasoning effort capability. | ||
|
|
||
| Dynamically adjusts the model's thinking effort level per step based on | ||
| task complexity signals, reducing token usage on simple steps while | ||
| preserving deep reasoning for complex decisions. | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from collections.abc import Callable, Sequence | ||
| from dataclasses import dataclass, field | ||
| from typing import Any, Literal, TypeAlias | ||
|
|
||
| from pydantic_ai._run_context import RunContext | ||
| from pydantic_ai.capabilities.abstract import AbstractCapability | ||
| from pydantic_ai.messages import ModelMessage, ModelRequest, ModelResponse, RetryPromptPart, ToolCallPart | ||
| from pydantic_ai.settings import ModelSettings, ThinkingEffort | ||
|
|
||
| EffortLevel: TypeAlias = Literal['low', 'medium', 'high'] | ||
| """The coarse effort levels used by adaptive reasoning. | ||
|
|
||
| Mapped to the full ``ThinkingEffort`` scale when applied to model settings. | ||
| """ | ||
|
|
||
| _EFFORT_TO_THINKING: dict[str, ThinkingEffort] = { | ||
| 'low': 'low', | ||
| 'medium': 'medium', | ||
| 'high': 'high', | ||
| } | ||
|
|
||
|
|
||
| _MANY_TOOL_CALLS_THRESHOLD = 3 | ||
| """Number of ``ToolCallPart`` parts in a single response that signals medium effort.""" | ||
|
|
||
|
|
||
| def _has_tool_errors(messages: Sequence[ModelMessage]) -> bool: | ||
| """Check whether the most recent request message contains retry prompts (tool errors).""" | ||
| for msg in reversed(messages): | ||
| if isinstance(msg, ModelRequest): | ||
| return any(isinstance(part, RetryPromptPart) for part in msg.parts) | ||
| return False | ||
|
|
||
|
|
||
| def _last_response_had_many_tool_calls(messages: Sequence[ModelMessage]) -> bool: | ||
| """Check whether the most recent ``ModelResponse`` had many tool call parts. | ||
|
|
||
| Returns ``True`` if the latest response contained at least | ||
| :data:`_MANY_TOOL_CALLS_THRESHOLD` ``ToolCallPart`` instances, signalling | ||
| a complex multi-tool orchestration step that deserves medium effort. | ||
| """ | ||
| for msg in reversed(messages): | ||
| if isinstance(msg, ModelResponse): | ||
| tool_call_count = sum(1 for part in msg.parts if isinstance(part, ToolCallPart)) | ||
| return tool_call_count >= _MANY_TOOL_CALLS_THRESHOLD | ||
| return False | ||
|
|
||
|
|
||
| def default_effort_fn(ctx: RunContext[Any]) -> Literal['low', 'medium', 'high']: | ||
| """Built-in heuristic effort selector. | ||
|
|
||
| Rules (evaluated in order): | ||
| 1. First step (``run_step == 1``): ``'high'`` -- understand the task. | ||
| 2. After tool errors (retry prompts in the latest request): ``'high'`` -- reason about failures. | ||
| 3. Many tool calls (3+ ``ToolCallPart`` in last response): ``'medium'`` -- complex orchestration. | ||
| 4. Second step (``run_step == 2``): ``'medium'`` -- still building context. | ||
| 5. Later steps (``run_step >= 3``): ``'low'`` -- simple follow-ups. | ||
| """ | ||
| if ctx.run_step <= 1: | ||
| return 'high' | ||
|
|
||
| if _has_tool_errors(ctx.messages): | ||
| return 'high' | ||
|
|
||
| if _last_response_had_many_tool_calls(ctx.messages): | ||
| return 'medium' | ||
|
|
||
| if ctx.run_step == 2: | ||
| return 'medium' | ||
|
|
||
| return 'low' | ||
|
|
||
|
|
||
| @dataclass | ||
| class AdaptiveReasoning(AbstractCapability[Any]): | ||
| """Dynamically adjusts model thinking effort per step. | ||
|
|
||
| By default a built-in heuristic is used: | ||
|
|
||
| * **First step** -> ``'high'`` (understand the task) | ||
| * **After tool errors** -> ``'high'`` (reason about what went wrong) | ||
| * **Simple follow-ups** -> ``'low'`` (just incorporating tool results) | ||
|
|
||
| Supply a custom ``effort_fn`` to override these rules:: | ||
|
|
||
| def my_effort(ctx: RunContext[MyDeps]) -> Literal['low', 'medium', 'high']: | ||
| if ctx.run_step > 5: | ||
| return 'high' # wrap-up needs careful thought | ||
| return 'medium' | ||
|
|
||
| agent = Agent(..., capabilities=[AdaptiveReasoning(effort_fn=my_effort)]) | ||
| """ | ||
|
|
||
| effort_fn: Callable[[RunContext[Any]], Literal['low', 'medium', 'high']] = field(default=default_effort_fn) | ||
| """Callable that receives the current ``RunContext`` and returns an effort level.""" | ||
|
|
||
| def get_model_settings(self) -> Callable[[RunContext[Any]], ModelSettings]: | ||
| """Return a dynamic model-settings callable that sets ``thinking`` per step.""" | ||
|
|
||
| def _resolve(ctx: RunContext[Any]) -> ModelSettings: | ||
| effort = self.effort_fn(ctx) | ||
| return ModelSettings(thinking=_EFFORT_TO_THINKING[effort]) | ||
|
|
||
| return _resolve | ||
|
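With the default heuristic, an error-free run without large tool fan-outs yields a fixed per-step effort schedule. The sketch below inlines a copy of the rules (rather than importing the module) so it runs standalone; the `had_errors` and `many_tool_calls` flags stand in for the message-inspection helpers:

```python
def effort_for(run_step: int, had_errors: bool = False, many_tool_calls: bool = False) -> str:
    """Inlined copy of default_effort_fn's rules, for illustration only."""
    if run_step <= 1:
        return 'high'    # rule 1: understand the task
    if had_errors:
        return 'high'    # rule 2: reason about failures
    if many_tool_calls:
        return 'medium'  # rule 3: complex orchestration
    if run_step == 2:
        return 'medium'  # rule 4: still building context
    return 'low'         # rule 5: simple follow-ups


# Effort schedule for a clean five-step run.
schedule = [effort_for(step) for step in range(1, 6)]
print(schedule)  # ['high', 'medium', 'low', 'low', 'low']

# A tool error at any later step bumps effort back up.
print(effort_for(4, had_errors=True))  # high
```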
Comment on lines +83 to +113:

🚩 Capability only overrides `get_model_settings`
🟡 Docstring says `run_step == 1` but code uses `run_step <= 1`

The docstring for `default_effort_fn` documents rule 1 as "First step (`run_step == 1`)", but the actual implementation at line 68 uses `ctx.run_step <= 1`, which also matches `run_step == 0`. The test `test_step_zero_high` (`tests/test_adaptive_reasoning.py:152-154`) confirms that `run_step=0` returns `'high'`, matching the code but not the docstring. `PLAN.md:38` correctly documents this as `run_step <= 1`.
default_effort_fndocuments rule 1 as "First step (run_step == 1)" but the actual implementation at line 68 usesctx.run_step <= 1, which also matchesrun_step == 0. The testtest_step_zero_high(tests/test_adaptive_reasoning.py:152-154) confirms thatrun_step=0returns'high', matching the code but not the docstring. ThePLAN.md:38correctly documents this asrun_step <= 1.Was this helpful? React with 👍 or 👎 to provide feedback.