Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions pydantic_harness/code_mode/_capability.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,19 @@

from __future__ import annotations

from dataclasses import dataclass, field
from dataclasses import dataclass, field, replace
from typing import Any

from pydantic_ai import AbstractToolset
from pydantic_ai import AbstractToolset, DeferredToolRequests, RunContext
from pydantic_ai.capabilities import AbstractCapability, CapabilityOrdering
from pydantic_ai.capabilities._tool_search import ToolSearch as _ToolSearch
from pydantic_ai.run import AgentRunResult
from pydantic_ai.tools import AgentDepsT, ToolSelector

from pydantic_harness.code_mode._toolset import CodeModeToolset

from ._toolset import _RUN_CODE_TOOL_NAME


@dataclass
class CodeMode(AbstractCapability[AgentDepsT]):
Expand Down Expand Up @@ -55,3 +59,19 @@ def get_ordering(self) -> CapabilityOrdering:
def get_wrapper_toolset(self, toolset: AbstractToolset[AgentDepsT]) -> AbstractToolset[AgentDepsT] | None:
    """Build the code-mode wrapper around the agent's assembled toolset.

    The returned :class:`CodeModeToolset` is responsible for splitting the
    wrapped tools into native and sandboxed subsets as needed.
    """
    wrapper = CodeModeToolset(
        wrapped=toolset,
        tool_selector=self.tools,
        max_retries=self.max_retries,
    )
    return wrapper

async def after_run(self, ctx: RunContext[AgentDepsT], *, result: AgentRunResult[Any]) -> AgentRunResult[Any]:
    """Rewrite deferred ``run_code`` approval requests to name the nested tool.

    When execution inside the sandbox defers for approval, the approval part
    carries the generic ``run_code`` tool name. Each such part's metadata
    (recorded at defer time) holds the real nested tool name and kwargs; this
    hook swaps them in so the caller approves the actual tool call.

    Args:
        ctx: The run context (unused here, required by the capability hook).
        result: The finished agent run result.

    Returns:
        The same ``result``, with matching approval parts rewritten in place.
    """
    output = result.output
    if not isinstance(output, DeferredToolRequests):
        # Nothing deferred — pass the result through untouched.
        return result

    for i, part in enumerate(output.approvals):
        if part.tool_name != _RUN_CODE_TOOL_NAME:
            continue
        # Fix: read metadata via the narrowed local `output` rather than
        # re-accessing `result.output`, which bypassed the isinstance
        # narrowing above (same object, but inconsistent and opaque to
        # type checkers).
        metadata = output.metadata.get(part.tool_call_id, {})
        tool_name = metadata.get('tool_name')
        kwargs = metadata.get('kwargs')
        # Only rewrite when the recorded metadata is well-formed; otherwise
        # leave the generic run_code approval part as-is.
        if isinstance(tool_name, str) and isinstance(kwargs, dict):
            output.approvals[i] = replace(part, tool_name=tool_name, args=kwargs)

    return result
60 changes: 37 additions & 23 deletions pydantic_harness/code_mode/_toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
MontySyntaxError,
MontyTypingError,
NameLookupSnapshot,
load_repl_snapshot,
)
except ImportError as _import_error: # pragma: no cover
raise ImportError(
Expand All @@ -46,7 +47,7 @@
from typing_extensions import NotRequired, TypedDict

# Type alias for the dispatch callback passed to _execution_loop.
_DispatchFn = Callable[[str, dict[str, Any]], Coroutine[Any, Any, Any]]
_DispatchFn = Callable[[str, dict[str, Any], int], Coroutine[Any, Any, Any]]


class _RunCodeArguments(TypedDict):
Expand Down Expand Up @@ -312,20 +313,17 @@ async def call_tool(
# can be attached as metadata on the run_code ToolReturnPart.
nested_calls: dict[str, ToolCallPart] = {}
nested_returns: dict[str, ToolReturnPart] = {}
call_counter = 0

async def dispatch_tool_call(original_name: str, kwargs: dict[str, Any]) -> Any:
async def dispatch_tool_call(original_name: str, kwargs: dict[str, Any], call_id: int) -> Any:
"""Dispatch a single tool call from inside the sandbox.

Returns the serialized tool result on success. On failure, the
exception propagates — the execution loop passes it back into
Monty via `ExternalException` so the sandbox sees it at the
`await` site.
"""
nonlocal call_counter
call_counter += 1
parent_id = ctx.tool_call_id or 'pyd_ai_code_mode'
tool_call_id = f'{parent_id}__{call_counter}'
tool_call_id = f'{parent_id}__{call_id}'
call_part = ToolCallPart(tool_name=original_name, args=kwargs, tool_call_id=tool_call_id)
nested_calls[tool_call_id] = call_part

Expand Down Expand Up @@ -374,16 +372,24 @@ async def dispatch_tool_call(original_name: str, kwargs: dict[str, Any]) -> Any:
assert self._repl is not None

capture = _PrintCapture()
approved_tool: tuple[str, Any] | None = None

try:
monty_state = self._repl.feed_start(code, print_callback=capture)
if ctx.tool_call_approved:
metadata = ctx.tool_call_metadata
snapshot = metadata.get('snapshot')
approved_tool = (metadata.get('tool_name'), metadata.get('kwargs'))
monty_state, self._repl = load_repl_snapshot(data=snapshot)
else:
monty_state = self._repl.feed_start(code, print_callback=capture)
completed = await _execution_loop(
monty_state,
dispatch=dispatch_tool_call,
callable_defs=callable_defs,
sanitized_to_original=sanitized_to_original,
sequential_tools=sequential_tools,
global_sequential=global_sequential,
approved_tool=approved_tool,
)
except MontySyntaxError as e:
raise ModelRetry(f'Syntax error in code:\n{_prepend_prints(e.display(), capture)}') from e
Expand All @@ -400,6 +406,10 @@ async def dispatch_tool_call(original_name: str, kwargs: dict[str, Any]) -> Any:
# (ModelRetry → MontyRuntimeError → ModelRetry), but the retry
# semantics are the same — the model gets another chance.
raise ModelRetry(f'Runtime error:\n{_prepend_prints(e.display(), capture)}') from e
except ValueError:
raise UserError('Snapshot is corrupted')
except (CallDeferred, ApprovalRequired):
raise

result = completed.output
printed = capture.joined
Expand Down Expand Up @@ -453,19 +463,6 @@ def _partition_callable_tools(
native_fallbacks: set[str] = set()
for name, tool in wrapped_tools.items():
td = tool.tool_def
if td.defer:
if name not in self._warned_deferred:
self._warned_deferred.add(name)
warnings.warn(
f'CodeMode: tool {name!r} requires deferred execution '
f'(kind={td.kind!r}) and cannot be called from inside the '
f'sandbox; it will be exposed as a native tool instead.',
UserWarning,
stacklevel=2,
)
native_fallbacks.add(name)
continue

safe_name = _sanitize_tool_name(name)
if safe_name == _RUN_CODE_TOOL_NAME:
raise UserError(
Expand Down Expand Up @@ -576,6 +573,7 @@ async def _execution_loop(
sanitized_to_original: dict[str, str],
sequential_tools: set[str],
global_sequential: bool,
approved_tool: tuple[str, Any] | None,
) -> MontyComplete:
"""Drive the Monty REPL via the synchronous snapshot API until completion.

Expand All @@ -600,6 +598,7 @@ async def _execution_loop(
# barrier) but whose FutureSnapshot hasn't been reached yet.
pre_resolved: dict[int, ExternalResult] = {}
try:
            # TODO: collect all approval/deferral requests in a single pass
            # instead of raising them one at a time, which forces a separate
            # round trip for each.
while not isinstance(monty_state, MontyComplete):
if isinstance(monty_state, NameLookupSnapshot):
monty_state = monty_state.resume()
Expand All @@ -613,6 +612,7 @@ async def _execution_loop(
global_sequential=global_sequential,
pending=pending,
pre_resolved=pre_resolved,
approved_tool=approved_tool,
)
else:
monty_state = await _resolve_future_snapshot(
Expand Down Expand Up @@ -641,6 +641,7 @@ async def _handle_function_snapshot(
global_sequential: bool,
pending: dict[int, asyncio.Task[Any] | Coroutine[Any, Any, Any]],
pre_resolved: dict[int, ExternalResult],
approved_tool: tuple[str, Any] | None,
) -> FunctionSnapshot | FutureSnapshot | NameLookupSnapshot | MontyComplete:
"""Handle a single FunctionSnapshot from the Monty execution loop."""
fn_name = snapshot.function_name
Expand All @@ -655,24 +656,37 @@ async def _handle_function_snapshot(

original_name = sanitized_to_original.get(fn_name, fn_name)

td = callable_defs[fn_name]

approved = approved_tool and (original_name == approved_tool[0] and snapshot.kwargs == approved_tool[1])
if not approved:
if td.kind == 'unapproved':
raise ApprovalRequired(
metadata={'snapshot': snapshot.dump(), 'tool_name': original_name, 'kwargs': snapshot.kwargs}
)
elif td.kind == 'external':
raise CallDeferred(
metadata={'snapshot': snapshot.dump(), 'tool_name': original_name, 'kwargs': snapshot.kwargs}
)

if fn_name in sequential_tools:
# Per-tool sequential: rendered as `def` (sync), so must resolve inline —
# the sandbox code doesn't `await` the result. Await pending parallel
# tasks first (barrier) to maintain ordering.
for cid in list(pending):
pre_resolved[cid] = await _resolve_coro(pending.pop(cid))
outcome = await _resolve_coro(dispatch(original_name, snapshot.kwargs))
outcome = await _resolve_coro(dispatch(original_name, snapshot.kwargs, snapshot.call_id))
if 'return_value' in outcome:
return snapshot.resume(return_value=outcome['return_value'])
return snapshot.resume(exception=outcome['exception'])

# Deferred execution — store for later resolution at FutureSnapshot.
if global_sequential:
# Bare coroutine — don't schedule on the event loop yet.
pending[snapshot.call_id] = dispatch(original_name, snapshot.kwargs)
pending[snapshot.call_id] = dispatch(original_name, snapshot.kwargs, snapshot.call_id)
else:
# Eagerly schedule as a Task for concurrent execution.
pending[snapshot.call_id] = asyncio.ensure_future(dispatch(original_name, snapshot.kwargs))
pending[snapshot.call_id] = asyncio.ensure_future(dispatch(original_name, snapshot.kwargs, snapshot.call_id))
return snapshot.resume(future=...)


Expand Down
Loading
Loading