From 619b692983fd6ef7d23f6d5d84ecebec6b8be533 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kacper=20W=C5=82odarczyk?=
Date: Fri, 24 Apr 2026 12:05:28 +0200
Subject: [PATCH 1/9] feat: add input and output guardrails

---
 pydantic_ai_harness/__init__.py               |  22 +-
 pydantic_ai_harness/guardrails/README.md      | 112 ++++++
 pydantic_ai_harness/guardrails/__init__.py    |  21 ++
 pydantic_ai_harness/guardrails/_capability.py | 215 +++++++++++
 pydantic_ai_harness/guardrails/_exceptions.py |  20 +
 tests/_guardrails/__init__.py                 |   0
 tests/_guardrails/test_input_guard.py         | 347 ++++++++++++++++++
 tests/_guardrails/test_output_guard.py        |  72 ++++
 8 files changed, 808 insertions(+), 1 deletion(-)
 create mode 100644 pydantic_ai_harness/guardrails/README.md
 create mode 100644 pydantic_ai_harness/guardrails/__init__.py
 create mode 100644 pydantic_ai_harness/guardrails/_capability.py
 create mode 100644 pydantic_ai_harness/guardrails/_exceptions.py
 create mode 100644 tests/_guardrails/__init__.py
 create mode 100644 tests/_guardrails/test_input_guard.py
 create mode 100644 tests/_guardrails/test_output_guard.py

diff --git a/pydantic_ai_harness/__init__.py b/pydantic_ai_harness/__init__.py
index 0a60fd7..6222b5a 100644
--- a/pydantic_ai_harness/__init__.py
+++ b/pydantic_ai_harness/__init__.py
@@ -4,8 +4,24 @@ if TYPE_CHECKING:
     from .code_mode import CodeMode
+    from .guardrails import (
+        GuardrailError,
+        GuardrailFunc,
+        InputBlocked,
+        InputGuard,
+        OutputBlocked,
+        OutputGuard,
+    )

-__all__ = ['CodeMode']
+__all__ = [
+    'CodeMode',
+    'GuardrailError',
+    'GuardrailFunc',
+    'InputBlocked',
+    'InputGuard',
+    'OutputBlocked',
+    'OutputGuard',
+]


 def __getattr__(name: str) -> object:
@@ -13,4 +29,8 @@ def __getattr__(name: str) -> object:
     if name == 'CodeMode':
         from .code_mode import CodeMode

         return CodeMode
+    if name in {'GuardrailError', 'GuardrailFunc', 'InputBlocked', 'InputGuard', 'OutputBlocked', 'OutputGuard'}:
+        from . import guardrails
+
+        return getattr(guardrails, name)
     raise AttributeError(f'module {__name__!r} has no attribute {name!r}')
diff --git a/pydantic_ai_harness/guardrails/README.md b/pydantic_ai_harness/guardrails/README.md
new file mode 100644
index 0000000..12025da
--- /dev/null
+++ b/pydantic_ai_harness/guardrails/README.md
@@ -0,0 +1,112 @@
# Guardrails

Intercept unsafe user prompts before they reach the model, and unsafe model outputs before they reach the caller.

## The problem

Agents take unstructured input from users and return unstructured output to callers. Without a validation layer, a prompt injection attempt, PII-laden message, or off-topic question goes to the model as-is, and any output the model produces is returned verbatim. The framework does not reason about "this is unsafe to send" or "this is unsafe to show".

## The solution

Two capabilities, each backed by a callable you supply.

| Capability | Checks | When a guard returns `False` | When a guard raises |
|---|---|---|---|
| `InputGuard` | The user prompt before each model request | `SkipModelRequest` — the model call is skipped and `block_message` becomes the response for that step | The exception propagates out of the run |
| `OutputGuard` | The final run output | `OutputBlocked` is raised | The exception propagates out of the run |

The asymmetry is intentional. Blocking the input means no tokens are spent, so a graceful refusal is almost always what you want. Blocking the output means the model already generated a response you do not want exposed — raising forces the caller to decide what to do next.
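
Concretely, with the two guards from the Usage section below in place (a sketch; the prompts are illustrative):

```python
from pydantic_ai_harness import OutputBlocked

result = await agent.run('here is my api_key: sk-123')  # input guard trips
print(result.output)  # 'Request blocked by input guardrail.' (a normal completion, no exception)

try:
    await agent.run('read back the SSN on file')  # model reply mentions 'SSN'
except OutputBlocked:
    ...  # output was already generated; the caller decides what to expose
```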
+ +## Usage + +```python +from pydantic_ai import Agent +from pydantic_ai_harness import InputGuard, OutputGuard + + +def no_secrets(prompt: str) -> bool: + return 'api_key' not in prompt.lower() + + +def no_pii(output: str) -> bool: + return 'SSN' not in output + + +agent = Agent( + 'openai:gpt-4.1', + capabilities=[ + InputGuard(guard=no_secrets), + OutputGuard(guard=no_pii), + ], +) +``` + +Both guards accept async callables too: + +```python +async def check_with_moderation_api(prompt: str) -> bool: + response = await client.moderations.create(input=prompt) + return not response.results[0].flagged + + +agent = Agent( + 'openai:gpt-4.1', + capabilities=[InputGuard(guard=check_with_moderation_api)], +) +``` + +## Parallel input guards + +When a guard is slow (an LLM-based classifier or a network call), running it in sequence before every model request adds latency to every turn. Set `parallel=True` to race the guard against the model call. The model call is cancelled immediately if the guard reports a violation. + +```python +InputGuard(guard=slow_async_classifier, parallel=True) +``` + +For fast local checks (regex, keyword lookup, a small classifier) sequential is usually fine — the overhead is measured in microseconds and the wiring is simpler. + +## Customising the block message + +```python +InputGuard( + guard=no_secrets, + block_message='This request looks like it contains credentials. Please rephrase.', +) +``` + +The text is returned as the model response for that step, so the caller sees a normal completion rather than an exception. Multi-turn agents can continue the conversation from there. + +## Hard-fail path + +Returning `False` from a guard is the graceful path. If you want the caller to see an exception instead, raise from the guard: + +```python +from pydantic_ai_harness import InputBlocked + + +def strict_guard(prompt: str) -> bool: + if contains_credentials(prompt): + raise InputBlocked('credentials detected') + return True +``` + +Any exception raised by the guard propagates as-is — you can use `InputBlocked` / `OutputBlocked` from this module or your own exception types. + +## API + +```python +InputGuard( + guard: Callable[[str], bool | Awaitable[bool]], + parallel: bool = False, + block_message: str = 'Request blocked by input guardrail.', +) + +OutputGuard( + guard: Callable[[str], bool | Awaitable[bool]], + block_message: str = 'Output blocked by output guardrail.', +) +``` + +## Relationship to `pydantic-ai-shields` + +`pydantic-ai-shields` provides opinionated implementations on top of these primitives (prompt-injection detectors, PII scrubbers, keyword blocklists, etc.). Use the guardrails here when you want to plug in your own validation logic; reach for shields when you need a batteries-included detector. 
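
## Example: an LLM-based parallel guard

A sketch of using a second agent as a slow classifier guard; the model name and the instructions text are illustrative, not recommendations:

```python
from pydantic_ai import Agent
from pydantic_ai_harness import InputGuard

classifier = Agent(
    'openai:gpt-4.1-mini',
    output_type=bool,
    instructions='Reply true only if the user message is safe to answer.',
)


async def llm_guard(prompt: str) -> bool:
    # A full agent run per check: too slow to put in front of every turn sequentially.
    result = await classifier.run(prompt)
    return result.output


guard = InputGuard(guard=llm_guard, parallel=True)
```

Because the classifier round-trip takes about as long as the main model call, `parallel=True` hides its latency; the main call is cancelled only when the classifier flags the prompt.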
diff --git a/pydantic_ai_harness/guardrails/__init__.py b/pydantic_ai_harness/guardrails/__init__.py new file mode 100644 index 0000000..a21d122 --- /dev/null +++ b/pydantic_ai_harness/guardrails/__init__.py @@ -0,0 +1,21 @@ +"""Input and output guardrails for Pydantic AI agents.""" + +from pydantic_ai_harness.guardrails._capability import ( + GuardrailFunc, + InputGuard, + OutputGuard, +) +from pydantic_ai_harness.guardrails._exceptions import ( + GuardrailError, + InputBlocked, + OutputBlocked, +) + +__all__ = [ + 'GuardrailError', + 'GuardrailFunc', + 'InputBlocked', + 'InputGuard', + 'OutputBlocked', + 'OutputGuard', +] diff --git a/pydantic_ai_harness/guardrails/_capability.py b/pydantic_ai_harness/guardrails/_capability.py new file mode 100644 index 0000000..0b08fb7 --- /dev/null +++ b/pydantic_ai_harness/guardrails/_capability.py @@ -0,0 +1,215 @@ +"""Input and output guardrail capabilities. + +`InputGuard` intercepts each model request and lets a user-supplied callable +decide whether the current user prompt is safe to send to the model. A guard +that returns `False` is treated as a graceful refusal: the LLM call is +skipped via [`SkipModelRequest`][pydantic_ai.exceptions.SkipModelRequest] and +a canned message becomes the model response for that step. A guard that raises +propagates the exception so the caller can observe a hard failure. + +`OutputGuard` runs once the run completes and validates the final output. +A guard that returns `False` raises +[`OutputBlocked`][pydantic_ai_harness.guardrails.OutputBlocked]. + +Both guards accept sync or async callables. +""" + +from __future__ import annotations + +import asyncio +import inspect +from collections.abc import Awaitable, Callable +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +from pydantic_ai.capabilities import AbstractCapability, WrapModelRequestHandler +from pydantic_ai.exceptions import SkipModelRequest +from pydantic_ai.messages import ModelResponse, TextPart, UserPromptPart +from pydantic_ai.tools import AgentDepsT, RunContext + +from pydantic_ai_harness.guardrails._exceptions import OutputBlocked + +if TYPE_CHECKING: + from pydantic_ai.models import ModelRequestContext + from pydantic_ai.run import AgentRunResult + + +GuardrailFunc = Callable[[str], bool | Awaitable[bool]] +"""Signature of the callable passed to `InputGuard` / `OutputGuard`. + +The callable receives the text to validate and returns `True` when safe. +It may be sync or async. Raising an exception is treated as a hard failure +and propagates up to the caller. +""" + + +async def _evaluate(guard: GuardrailFunc, value: str) -> bool: + """Call `guard` and await it if it returned an awaitable.""" + result = guard(value) + if inspect.isawaitable(result): + return await result + return result + + +def _extract_prompt(ctx: RunContext[AgentDepsT], messages: list[Any]) -> str | None: + """Return the text of the most recent user prompt, or `None` if absent. + + Prefers `ctx.prompt` (set at run start) and falls back to scanning the + message history for the last [`UserPromptPart`][pydantic_ai.messages.UserPromptPart] + so that sub-agent calls or resumed runs without a fresh prompt still work. 
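
    For example, with `ctx.prompt` unset, the history scan picks the latest
    user part (illustrative; mirrors the unit tests):

    ```python
    messages = [
        ModelRequest(parts=[UserPromptPart(content='first')]),
        ModelResponse(parts=[TextPart(content='assistant')]),
        ModelRequest(parts=[UserPromptPart(content='second')]),
    ]
    assert _extract_prompt(ctx, messages) == 'second'
    ```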
+ """ + if ctx.prompt is not None: + return ctx.prompt if isinstance(ctx.prompt, str) else str(ctx.prompt) + for message in reversed(messages): + parts = getattr(message, 'parts', None) + if not parts: + continue + for part in reversed(parts): + if isinstance(part, UserPromptPart): + return part.content if isinstance(part.content, str) else str(part.content) + return None + + +@dataclass +class InputGuard(AbstractCapability[AgentDepsT]): + """Validate the user prompt before it reaches the model. + + The `guard` callable receives the prompt text and returns `True` when + the input is safe. Returning `False` triggers a graceful refusal: the + current model request is short-circuited via + [`SkipModelRequest`][pydantic_ai.exceptions.SkipModelRequest] with + `block_message` as the response text, so the agent returns cleanly to + the caller. Raising an exception from the guard propagates it as-is. + + ```python + from pydantic_ai import Agent + from pydantic_ai_harness import InputGuard + + + def no_secrets(prompt: str) -> bool: + return 'api_key' not in prompt.lower() + + + agent = Agent('openai:gpt-4.1', capabilities=[InputGuard(guard=no_secrets)]) + ``` + + Set `parallel=True` to start the guard alongside the model call. The + handler is cancelled as soon as the guard reports a violation, which saves + tokens when the guard is slower than the provider round-trip. + """ + + guard: GuardrailFunc + """Callable that returns `True` when the prompt is safe to send to the model.""" + + parallel: bool = False + """Run the guard concurrently with the model request and cancel the model call on failure.""" + + block_message: str = 'Request blocked by input guardrail.' + """Text returned as the model response when the guard trips gracefully.""" + + def _blocked_response(self) -> ModelResponse: + return ModelResponse(parts=[TextPart(content=self.block_message)]) + + async def _run_guard(self, prompt: str) -> None: + """Evaluate the guard and raise `SkipModelRequest` on failure.""" + if not await _evaluate(self.guard, prompt): + raise SkipModelRequest(self._blocked_response()) + + async def before_model_request( + self, + ctx: RunContext[AgentDepsT], + request_context: ModelRequestContext, + ) -> ModelRequestContext: + """Check the prompt before the model call in sequential mode.""" + if self.parallel: + return request_context + prompt = _extract_prompt(ctx, list(request_context.messages)) + if prompt is None: + return request_context + await self._run_guard(prompt) + return request_context + + async def wrap_model_request( + self, + ctx: RunContext[AgentDepsT], + *, + request_context: ModelRequestContext, + handler: WrapModelRequestHandler, + ) -> ModelResponse: + """Run the guard alongside the model call when `parallel=True`.""" + if not self.parallel: + return await handler(request_context) + prompt = _extract_prompt(ctx, list(request_context.messages)) + if prompt is None: + return await handler(request_context) + async def run_handler() -> ModelResponse: + return await handler(request_context) + + guard_task: asyncio.Task[None] = asyncio.create_task(self._run_guard(prompt)) + handler_task: asyncio.Task[ModelResponse] = asyncio.create_task(run_handler()) + try: + done, _ = await asyncio.wait( + [guard_task, handler_task], + return_when=asyncio.FIRST_COMPLETED, + ) + if guard_task in done: + guard_exc = guard_task.exception() + if guard_exc is not None: + handler_task.cancel() + raise guard_exc + return await handler_task + # Handler finished first: if it raised, propagate and cancel the guard. 
+ handler_exc = handler_task.exception() + if handler_exc is not None: + guard_task.cancel() + raise handler_exc + # Handler succeeded; still need the guard verdict before committing the response. + await guard_task + return handler_task.result() + finally: + if not guard_task.done(): + guard_task.cancel() + if not handler_task.done(): + handler_task.cancel() + + +@dataclass +class OutputGuard(AbstractCapability[AgentDepsT]): + """Validate the final agent output. + + The `guard` callable receives the stringified run output and returns + `True` when the output is safe to expose to the caller. Returning + `False` raises + [`OutputBlocked`][pydantic_ai_harness.guardrails.OutputBlocked] with + `block_message`. Raising an exception from the guard propagates it. + + ```python + from pydantic_ai import Agent + from pydantic_ai_harness import OutputGuard + + + def no_pii(output: str) -> bool: + return 'SSN' not in output + + + agent = Agent('openai:gpt-4.1', capabilities=[OutputGuard(guard=no_pii)]) + ``` + """ + + guard: GuardrailFunc + """Callable that returns `True` when the output is safe.""" + + block_message: str = 'Output blocked by output guardrail.' + """Message attached to `OutputBlocked` when the guard trips.""" + + async def after_run( + self, + ctx: RunContext[AgentDepsT], + *, + result: AgentRunResult[Any], + ) -> AgentRunResult[Any]: + """Validate `result.output` and raise `OutputBlocked` on failure.""" + output = str(result.output) + if not await _evaluate(self.guard, output): + raise OutputBlocked(self.block_message) + return result diff --git a/pydantic_ai_harness/guardrails/_exceptions.py b/pydantic_ai_harness/guardrails/_exceptions.py new file mode 100644 index 0000000..eacad6e --- /dev/null +++ b/pydantic_ai_harness/guardrails/_exceptions.py @@ -0,0 +1,20 @@ +"""Exceptions raised by the guardrail capabilities.""" + +from __future__ import annotations + + +class GuardrailError(Exception): + """Base exception for guardrail violations.""" + + +class InputBlocked(GuardrailError): + """Raised by a user-supplied input guard to hard-fail a run. + + Prefer returning ``False`` from the guard callable to trigger a graceful + refusal via [`SkipModelRequest`][pydantic_ai.exceptions.SkipModelRequest]. + Raise this explicitly when the caller should have to handle the failure. 
+ """ + + +class OutputBlocked(GuardrailError): + """Raised by [`OutputGuard`][pydantic_ai_harness.OutputGuard] when the final output fails validation.""" diff --git a/tests/_guardrails/__init__.py b/tests/_guardrails/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/_guardrails/test_input_guard.py b/tests/_guardrails/test_input_guard.py new file mode 100644 index 0000000..9e2b15a --- /dev/null +++ b/tests/_guardrails/test_input_guard.py @@ -0,0 +1,347 @@ +"""Tests for the `InputGuard` capability.""" + +from __future__ import annotations + +import asyncio +from typing import Any + +import pytest +from pydantic_ai import Agent +from pydantic_ai.exceptions import SkipModelRequest +from pydantic_ai.messages import ( + ModelRequest, + ModelResponse, + TextPart, + UserPromptPart, +) +from pydantic_ai.models.test import TestModel + +from pydantic_ai_harness import InputBlocked, InputGuard +from pydantic_ai_harness.guardrails._capability import _extract_prompt # pyright: ignore[reportPrivateUsage] + +pytestmark = pytest.mark.anyio + + +@pytest.fixture +def anyio_backend() -> str: + return 'asyncio' + + +async def test_guard_allows_when_safe(): + calls: list[str] = [] + + def guard(prompt: str) -> bool: + calls.append(prompt) + return True + + agent = Agent(TestModel(custom_output_text='ok'), capabilities=[InputGuard[None](guard=guard)]) + result = await agent.run('hello') + + assert result.output == 'ok' + assert calls == ['hello'] + + +async def test_guard_block_uses_block_message(): + agent = Agent( + TestModel(custom_output_text='would be model output'), + capabilities=[InputGuard[None](guard=lambda _: False, block_message='nope')], + ) + result = await agent.run('hello') + + assert result.output == 'nope' + + +async def test_async_guard_awaited(): + async def guard(prompt: str) -> bool: + await asyncio.sleep(0) + return 'safe' in prompt + + agent = Agent(TestModel(custom_output_text='ok'), capabilities=[InputGuard[None](guard=guard)]) + + assert (await agent.run('safe message')).output == 'ok' + assert (await agent.run('bad message')).output == 'Request blocked by input guardrail.' 
+ + +async def test_guard_raising_propagates(): + def guard(_: str) -> bool: + raise InputBlocked('policy violation') + + agent = Agent(TestModel(custom_output_text='ok'), capabilities=[InputGuard[None](guard=guard)]) + with pytest.raises(InputBlocked, match='policy violation'): + await agent.run('anything') + + +def test_extract_prompt_from_messages(): + """Extraction falls back to the most recent `UserPromptPart`.""" + + class _Ctx: + prompt = None + + messages: list[Any] = [ + ModelRequest(parts=[UserPromptPart(content='first')]), + ModelResponse(parts=[TextPart(content='assistant')]), + ModelRequest(parts=[UserPromptPart(content='second')]), + ] + assert _extract_prompt(_Ctx(), messages) == 'second' # pyright: ignore[reportArgumentType] + + +def test_extract_prompt_stringifies_non_str_prompt(): + class _Ctx: + prompt = ['multimodal', 'content'] + + assert _extract_prompt(_Ctx(), []) == str(['multimodal', 'content']) # pyright: ignore[reportArgumentType] + + +def test_extract_prompt_stringifies_non_str_message_part(): + class _Ctx: + prompt = None + + messages: list[Any] = [ModelRequest(parts=[UserPromptPart(content=['multi'])])] + assert _extract_prompt(_Ctx(), messages) == str(['multi']) # pyright: ignore[reportArgumentType] + + +def test_extract_prompt_skips_messages_without_parts(): + class _Ctx: + prompt = None + + class _EmptyMessage: + pass + + messages: list[Any] = [_EmptyMessage(), ModelResponse(parts=[TextPart(content='only model parts here')])] + assert _extract_prompt(_Ctx(), messages) is None # pyright: ignore[reportArgumentType] + +async def _build_ctx_and_req() -> tuple[Any, Any]: + from pydantic_ai.models import ModelRequestContext, ModelRequestParameters + from pydantic_ai.models.test import TestModel as _TestModel + from pydantic_ai.tools import RunContext + from pydantic_ai.usage import RunUsage + + model = _TestModel() + messages: list[Any] = [ModelRequest(parts=[UserPromptPart(content='hello world')])] + req_ctx = ModelRequestContext( + model=model, + messages=messages, + model_settings=None, + model_request_parameters=ModelRequestParameters(), + ) + run_ctx: RunContext[None] = RunContext( + deps=None, + model=model, + usage=RunUsage(), + prompt='hello world', + messages=messages, + ) + return run_ctx, req_ctx + + +async def test_parallel_guard_allows_handler_to_return(): + run_ctx, req_ctx = await _build_ctx_and_req() + sentinel = ModelResponse(parts=[TextPart(content='from handler')]) + + async def handler(_: Any) -> ModelResponse: + return sentinel + + guard = InputGuard[None](guard=lambda _: True, parallel=True) + out = await guard.wrap_model_request(run_ctx, request_context=req_ctx, handler=handler) + assert out is sentinel + + +async def test_parallel_guard_trips_and_cancels_handler(): + run_ctx, req_ctx = await _build_ctx_and_req() + handler_cancelled = asyncio.Event() + handler_started = asyncio.Event() + + async def slow_handler(_: Any) -> ModelResponse: + handler_started.set() + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + handler_cancelled.set() + raise + return ModelResponse(parts=[TextPart(content='should never')]) # pragma: no cover + + async def guard(_: str) -> bool: + await handler_started.wait() + return False + + ig = InputGuard[None](guard=guard, parallel=True, block_message='blocked!') + with pytest.raises(SkipModelRequest) as exc_info: + await ig.wrap_model_request(run_ctx, request_context=req_ctx, handler=slow_handler) + + assert exc_info.value.response.parts[0] == TextPart(content='blocked!') + # Give the cancellation a 
chance to propagate. + await asyncio.sleep(0) + assert handler_cancelled.is_set() + + +async def test_parallel_guard_raises_propagates(): + run_ctx, req_ctx = await _build_ctx_and_req() + + async def slow_handler(_: Any) -> ModelResponse: + await asyncio.sleep(10) + return ModelResponse(parts=[TextPart(content='never')]) # pragma: no cover + + async def guard(_: str) -> bool: + raise InputBlocked('hard policy failure') + + ig = InputGuard[None](guard=guard, parallel=True) + with pytest.raises(InputBlocked, match='hard policy failure'): + await ig.wrap_model_request(run_ctx, request_context=req_ctx, handler=slow_handler) + + +async def test_parallel_handler_finishes_before_guard(): + """Handler completes first; guard still has to be awaited for a verdict.""" + run_ctx, req_ctx = await _build_ctx_and_req() + sentinel = ModelResponse(parts=[TextPart(content='from handler')]) + release_guard = asyncio.Event() + + async def fast_handler(_: Any) -> ModelResponse: + return sentinel + + async def slow_guard(_: str) -> bool: + await release_guard.wait() + return True + + async def runner() -> ModelResponse: + ig = InputGuard[None](guard=slow_guard, parallel=True) + return await ig.wrap_model_request(run_ctx, request_context=req_ctx, handler=fast_handler) + + task = asyncio.create_task(runner()) + # Yield enough times for handler_task to complete while guard_task is still waiting on the event. + for _ in range(3): + await asyncio.sleep(0) + release_guard.set() + assert await task is sentinel + + +async def test_parallel_handler_finishes_then_guard_trips(): + """Handler returns first, then the guard trips — `SkipModelRequest` still wins.""" + run_ctx, req_ctx = await _build_ctx_and_req() + release_guard = asyncio.Event() + + async def fast_handler(_: Any) -> ModelResponse: + return ModelResponse(parts=[TextPart(content='from handler')]) + + async def slow_guard(_: str) -> bool: + await release_guard.wait() + return False + + async def runner() -> ModelResponse: + ig = InputGuard[None](guard=slow_guard, parallel=True, block_message='late trip') + return await ig.wrap_model_request(run_ctx, request_context=req_ctx, handler=fast_handler) + + task = asyncio.create_task(runner()) + for _ in range(3): + await asyncio.sleep(0) + release_guard.set() + with pytest.raises(SkipModelRequest) as exc_info: + await task + assert exc_info.value.response.parts[0] == TextPart(content='late trip') + + +async def test_parallel_handler_raises_while_guard_runs(): + """When the handler raises, `finally` cancels the still-running guard.""" + run_ctx, req_ctx = await _build_ctx_and_req() + guard_cancelled = asyncio.Event() + + async def failing_handler(_: Any) -> ModelResponse: + raise RuntimeError('model boom') + + async def slow_guard(_: str) -> bool: + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + guard_cancelled.set() + raise + return True # pragma: no cover + + ig = InputGuard[None](guard=slow_guard, parallel=True) + with pytest.raises(RuntimeError, match='model boom'): + await ig.wrap_model_request(run_ctx, request_context=req_ctx, handler=failing_handler) + await asyncio.sleep(0) + assert guard_cancelled.is_set() + + +async def test_parallel_skipped_when_prompt_missing(): + from pydantic_ai.models import ModelRequestContext, ModelRequestParameters + from pydantic_ai.models.test import TestModel as _TestModel + from pydantic_ai.tools import RunContext + from pydantic_ai.usage import RunUsage + + model = _TestModel() + req_ctx = ModelRequestContext( + model=model, + messages=[], + 
model_settings=None, + model_request_parameters=ModelRequestParameters(), + ) + run_ctx: RunContext[None] = RunContext(deps=None, model=model, usage=RunUsage(), prompt=None, messages=[]) + sentinel = ModelResponse(parts=[TextPart(content='direct')]) + + called: list[str] = [] + + def guard(prompt: str) -> bool: # pragma: no cover — should never be called + called.append(prompt) + return False + + async def handler(_: Any) -> ModelResponse: + return sentinel + + ig = InputGuard[None](guard=guard, parallel=True) + out = await ig.wrap_model_request(run_ctx, request_context=req_ctx, handler=handler) + assert out is sentinel + assert called == [] + + +async def test_sequential_wrap_model_request_is_passthrough(): + run_ctx, req_ctx = await _build_ctx_and_req() + sentinel = ModelResponse(parts=[TextPart(content='direct')]) + + async def handler(_: Any) -> ModelResponse: + return sentinel + + ig = InputGuard[None](guard=lambda _: True, parallel=False) + out = await ig.wrap_model_request(run_ctx, request_context=req_ctx, handler=handler) + assert out is sentinel + + +async def test_sequential_before_request_returns_context_when_prompt_missing(): + from pydantic_ai.models import ModelRequestContext, ModelRequestParameters + from pydantic_ai.models.test import TestModel as _TestModel + from pydantic_ai.tools import RunContext + from pydantic_ai.usage import RunUsage + + model = _TestModel() + req_ctx = ModelRequestContext( + model=model, + messages=[], + model_settings=None, + model_request_parameters=ModelRequestParameters(), + ) + run_ctx: RunContext[None] = RunContext(deps=None, model=model, usage=RunUsage(), prompt=None, messages=[]) + + called: list[str] = [] + + def guard(prompt: str) -> bool: # pragma: no cover — should not be called + called.append(prompt) + return True + + ig = InputGuard[None](guard=guard, parallel=False) + out = await ig.before_model_request(run_ctx, req_ctx) + assert out is req_ctx + assert called == [] + + +async def test_parallel_mode_before_request_is_noop(): + run_ctx, req_ctx = await _build_ctx_and_req() + + called: list[str] = [] + + def guard(prompt: str) -> bool: # pragma: no cover — should not be called via before_model_request + called.append(prompt) + return False + + ig = InputGuard[None](guard=guard, parallel=True) + out = await ig.before_model_request(run_ctx, req_ctx) + assert out is req_ctx + assert called == [] diff --git a/tests/_guardrails/test_output_guard.py b/tests/_guardrails/test_output_guard.py new file mode 100644 index 0000000..d815eb2 --- /dev/null +++ b/tests/_guardrails/test_output_guard.py @@ -0,0 +1,72 @@ +"""Tests for the `OutputGuard` capability.""" + +from __future__ import annotations + +import asyncio + +import pytest +from pydantic_ai import Agent +from pydantic_ai.models.test import TestModel + +from pydantic_ai_harness import OutputBlocked, OutputGuard +from pydantic_ai_harness.guardrails import GuardrailError + +pytestmark = pytest.mark.anyio + + +@pytest.fixture +def anyio_backend() -> str: + return 'asyncio' + + +async def test_guard_allows_safe_output(): + agent = Agent( + TestModel(custom_output_text='harmless reply'), + capabilities=[OutputGuard[None](guard=lambda out: 'SSN' not in out)], + ) + result = await agent.run('hello') + assert result.output == 'harmless reply' + + +async def test_guard_blocks_unsafe_output(): + agent = Agent( + TestModel(custom_output_text='leaks SSN 123-45-6789'), + capabilities=[OutputGuard[None](guard=lambda out: 'SSN' not in out, block_message='contains SSN')], + ) + with 
pytest.raises(OutputBlocked, match='contains SSN'): + await agent.run('hello') + + +async def test_async_guard_awaited(): + async def guard(output: str) -> bool: + await asyncio.sleep(0) + return 'bad' not in output + + agent = Agent( + TestModel(custom_output_text='ok reply'), + capabilities=[OutputGuard[None](guard=guard)], + ) + assert (await agent.run('prompt')).output == 'ok reply' + + agent_bad = Agent( + TestModel(custom_output_text='bad reply'), + capabilities=[OutputGuard[None](guard=guard)], + ) + with pytest.raises(OutputBlocked): + await agent_bad.run('prompt') + + +async def test_guard_raising_propagates(): + def guard(_: str) -> bool: + raise RuntimeError('guard exploded') + + agent = Agent( + TestModel(custom_output_text='anything'), + capabilities=[OutputGuard[None](guard=guard)], + ) + with pytest.raises(RuntimeError, match='guard exploded'): + await agent.run('hello') + + +def test_output_blocked_is_guardrail_error(): + assert issubclass(OutputBlocked, GuardrailError) From fcbe3aee1d5cd563b7bca472d837c11b59731865 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kacper=20W=C5=82odarczyk?= Date: Fri, 24 Apr 2026 12:17:37 +0200 Subject: [PATCH 2/9] chore: use single backticks --- pydantic_ai_harness/guardrails/_exceptions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pydantic_ai_harness/guardrails/_exceptions.py b/pydantic_ai_harness/guardrails/_exceptions.py index eacad6e..8ae37dd 100644 --- a/pydantic_ai_harness/guardrails/_exceptions.py +++ b/pydantic_ai_harness/guardrails/_exceptions.py @@ -10,7 +10,7 @@ class GuardrailError(Exception): class InputBlocked(GuardrailError): """Raised by a user-supplied input guard to hard-fail a run. - Prefer returning ``False`` from the guard callable to trigger a graceful + Prefer returning `False` from the guard callable to trigger a graceful refusal via [`SkipModelRequest`][pydantic_ai.exceptions.SkipModelRequest]. Raise this explicitly when the caller should have to handle the failure. """ From 22f49998b422953ea3f08fd400b33e44119e9a82 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kacper=20W=C5=82odarczyk?= Date: Fri, 24 Apr 2026 12:23:50 +0200 Subject: [PATCH 3/9] fix: run InputGuard only on the first model request --- pydantic_ai_harness/guardrails/_capability.py | 10 ++- tests/_guardrails/test_input_guard.py | 65 ++++++++++++++++++- 2 files changed, 72 insertions(+), 3 deletions(-) diff --git a/pydantic_ai_harness/guardrails/_capability.py b/pydantic_ai_harness/guardrails/_capability.py index 0b08fb7..13c2efc 100644 --- a/pydantic_ai_harness/guardrails/_capability.py +++ b/pydantic_ai_harness/guardrails/_capability.py @@ -96,6 +96,12 @@ def no_secrets(prompt: str) -> bool: Set `parallel=True` to start the guard alongside the model call. The handler is cancelled as soon as the guard reports a violation, which saves tokens when the guard is slower than the provider round-trip. + + Scope: the guard runs exactly once per run — on the first model request — + and evaluates the original user prompt. Subsequent model requests in the + same run (e.g. after tool calls) are not re-checked, since the user input + has not changed. Validation of tool results or other mid-run content + belongs in a separate capability hooking `after_model_request`. 
""" guard: GuardrailFunc @@ -121,7 +127,7 @@ async def before_model_request( request_context: ModelRequestContext, ) -> ModelRequestContext: """Check the prompt before the model call in sequential mode.""" - if self.parallel: + if self.parallel or ctx.run_step > 1: return request_context prompt = _extract_prompt(ctx, list(request_context.messages)) if prompt is None: @@ -137,7 +143,7 @@ async def wrap_model_request( handler: WrapModelRequestHandler, ) -> ModelResponse: """Run the guard alongside the model call when `parallel=True`.""" - if not self.parallel: + if not self.parallel or ctx.run_step > 1: return await handler(request_context) prompt = _extract_prompt(ctx, list(request_context.messages)) if prompt is None: diff --git a/tests/_guardrails/test_input_guard.py b/tests/_guardrails/test_input_guard.py index 9e2b15a..30171c6 100644 --- a/tests/_guardrails/test_input_guard.py +++ b/tests/_guardrails/test_input_guard.py @@ -110,7 +110,7 @@ class _EmptyMessage: messages: list[Any] = [_EmptyMessage(), ModelResponse(parts=[TextPart(content='only model parts here')])] assert _extract_prompt(_Ctx(), messages) is None # pyright: ignore[reportArgumentType] -async def _build_ctx_and_req() -> tuple[Any, Any]: +async def _build_ctx_and_req(run_step: int = 1) -> tuple[Any, Any]: from pydantic_ai.models import ModelRequestContext, ModelRequestParameters from pydantic_ai.models.test import TestModel as _TestModel from pydantic_ai.tools import RunContext @@ -130,6 +130,7 @@ async def _build_ctx_and_req() -> tuple[Any, Any]: usage=RunUsage(), prompt='hello world', messages=messages, + run_step=run_step, ) return run_ctx, req_ctx @@ -345,3 +346,65 @@ def guard(prompt: str) -> bool: # pragma: no cover — should not be called via out = await ig.before_model_request(run_ctx, req_ctx) assert out is req_ctx assert called == [] + + +# --------------------------------------------------------------------------- +# Re-entry protection — the guard must only fire on the first model request +# --------------------------------------------------------------------------- + + +async def test_sequential_skips_guard_on_subsequent_steps(): + """After the first model request, `before_model_request` must not re-run the guard.""" + run_ctx, req_ctx = await _build_ctx_and_req(run_step=2) + + called: list[str] = [] + + def guard(prompt: str) -> bool: # pragma: no cover — should not be called after step 1 + called.append(prompt) + return False + + ig = InputGuard[None](guard=guard, parallel=False) + out = await ig.before_model_request(run_ctx, req_ctx) + assert out is req_ctx + assert called == [] + + +async def test_parallel_skips_guard_on_subsequent_steps(): + """`wrap_model_request` must pass the handler through without running the guard past step 1.""" + run_ctx, req_ctx = await _build_ctx_and_req(run_step=2) + sentinel = ModelResponse(parts=[TextPart(content='direct')]) + called: list[str] = [] + + def guard(prompt: str) -> bool: # pragma: no cover — should not be called after step 1 + called.append(prompt) + return False + + async def handler(_: Any) -> ModelResponse: + return sentinel + + ig = InputGuard[None](guard=guard, parallel=True) + out = await ig.wrap_model_request(run_ctx, request_context=req_ctx, handler=handler) + assert out is sentinel + assert called == [] + + +async def test_guard_runs_once_across_tool_loop(): + """End-to-end: guard fires once even when the model makes multiple tool calls.""" + calls: list[str] = [] + + def guard(prompt: str) -> bool: + calls.append(prompt) + return True + + # 
TestModel(call_tools='all') calls each tool once, then returns a text response — two + # model requests total. + model = TestModel(call_tools='all', custom_output_text='done') + agent = Agent(model, capabilities=[InputGuard[None](guard=guard)]) + + @agent.tool_plain + def ping() -> str: + return 'pong' + + result = await agent.run('hello') + assert result.output == 'done' + assert calls == ['hello'] From b9b804016be4fe2f826c5140596ba00b36b80dc1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kacper=20W=C5=82odarczyk?= Date: Fri, 24 Apr 2026 12:30:44 +0200 Subject: [PATCH 4/9] fix: replace list[Any] with Sequence[ModelMessage] --- pydantic_ai_harness/guardrails/_capability.py | 15 ++++++--------- tests/_guardrails/test_input_guard.py | 16 ++++++++-------- 2 files changed, 14 insertions(+), 17 deletions(-) diff --git a/pydantic_ai_harness/guardrails/_capability.py b/pydantic_ai_harness/guardrails/_capability.py index 13c2efc..399356b 100644 --- a/pydantic_ai_harness/guardrails/_capability.py +++ b/pydantic_ai_harness/guardrails/_capability.py @@ -18,13 +18,13 @@ import asyncio import inspect -from collections.abc import Awaitable, Callable +from collections.abc import Awaitable, Callable, Sequence from dataclasses import dataclass from typing import TYPE_CHECKING, Any from pydantic_ai.capabilities import AbstractCapability, WrapModelRequestHandler from pydantic_ai.exceptions import SkipModelRequest -from pydantic_ai.messages import ModelResponse, TextPart, UserPromptPart +from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart, UserPromptPart from pydantic_ai.tools import AgentDepsT, RunContext from pydantic_ai_harness.guardrails._exceptions import OutputBlocked @@ -51,7 +51,7 @@ async def _evaluate(guard: GuardrailFunc, value: str) -> bool: return result -def _extract_prompt(ctx: RunContext[AgentDepsT], messages: list[Any]) -> str | None: +def _extract_prompt(ctx: RunContext[AgentDepsT], messages: Sequence[ModelMessage]) -> str | None: """Return the text of the most recent user prompt, or `None` if absent. 
Prefers `ctx.prompt` (set at run start) and falls back to scanning the @@ -61,10 +61,7 @@ def _extract_prompt(ctx: RunContext[AgentDepsT], messages: list[Any]) -> str | N if ctx.prompt is not None: return ctx.prompt if isinstance(ctx.prompt, str) else str(ctx.prompt) for message in reversed(messages): - parts = getattr(message, 'parts', None) - if not parts: - continue - for part in reversed(parts): + for part in reversed(message.parts): if isinstance(part, UserPromptPart): return part.content if isinstance(part.content, str) else str(part.content) return None @@ -129,7 +126,7 @@ async def before_model_request( """Check the prompt before the model call in sequential mode.""" if self.parallel or ctx.run_step > 1: return request_context - prompt = _extract_prompt(ctx, list(request_context.messages)) + prompt = _extract_prompt(ctx, request_context.messages) if prompt is None: return request_context await self._run_guard(prompt) @@ -145,7 +142,7 @@ async def wrap_model_request( """Run the guard alongside the model call when `parallel=True`.""" if not self.parallel or ctx.run_step > 1: return await handler(request_context) - prompt = _extract_prompt(ctx, list(request_context.messages)) + prompt = _extract_prompt(ctx, request_context.messages) if prompt is None: return await handler(request_context) async def run_handler() -> ModelResponse: diff --git a/tests/_guardrails/test_input_guard.py b/tests/_guardrails/test_input_guard.py index 30171c6..2d258f0 100644 --- a/tests/_guardrails/test_input_guard.py +++ b/tests/_guardrails/test_input_guard.py @@ -9,6 +9,7 @@ from pydantic_ai import Agent from pydantic_ai.exceptions import SkipModelRequest from pydantic_ai.messages import ( + ModelMessage, ModelRequest, ModelResponse, TextPart, @@ -77,7 +78,7 @@ def test_extract_prompt_from_messages(): class _Ctx: prompt = None - messages: list[Any] = [ + messages: list[ModelMessage] = [ ModelRequest(parts=[UserPromptPart(content='first')]), ModelResponse(parts=[TextPart(content='assistant')]), ModelRequest(parts=[UserPromptPart(content='second')]), @@ -96,18 +97,17 @@ def test_extract_prompt_stringifies_non_str_message_part(): class _Ctx: prompt = None - messages: list[Any] = [ModelRequest(parts=[UserPromptPart(content=['multi'])])] + messages: list[ModelMessage] = [ModelRequest(parts=[UserPromptPart(content=['multi'])])] assert _extract_prompt(_Ctx(), messages) == str(['multi']) # pyright: ignore[reportArgumentType] -def test_extract_prompt_skips_messages_without_parts(): +def test_extract_prompt_returns_none_when_no_user_prompt_part(): + """A history containing only model responses yields `None`.""" + class _Ctx: prompt = None - class _EmptyMessage: - pass - - messages: list[Any] = [_EmptyMessage(), ModelResponse(parts=[TextPart(content='only model parts here')])] + messages: list[ModelMessage] = [ModelResponse(parts=[TextPart(content='only model parts here')])] assert _extract_prompt(_Ctx(), messages) is None # pyright: ignore[reportArgumentType] async def _build_ctx_and_req(run_step: int = 1) -> tuple[Any, Any]: @@ -402,7 +402,7 @@ def guard(prompt: str) -> bool: agent = Agent(model, capabilities=[InputGuard[None](guard=guard)]) @agent.tool_plain - def ping() -> str: + def ping() -> str: # pyright: ignore[reportUnusedFunction] return 'pong' result = await agent.run('hello') From b27c6a330ee26adc4f05dd542022f8662d572092 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kacper=20W=C5=82odarczyk?= Date: Fri, 24 Apr 2026 12:37:34 +0200 Subject: [PATCH 5/9] fix: pass raw output to OutputGuard, not str(result.output) 
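
`str(result.output)` on a typed output yields the model's repr
(`Answer(reply='hi', ...)`), not JSON, so substring and regex guards can
silently miss field contents. Pass `result.output` through unchanged and
let the guard pick its own serialization. Illustrative sketch, using the
`Answer` model from the README example added in this patch:

    def no_internal_urls(output: object) -> bool:
        if isinstance(output, Answer):
            return not any('internal.example.com' in url for url in output.sources)
        return 'internal.example.com' not in str(output)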
--- pydantic_ai_harness/__init__.py | 18 +++++-- pydantic_ai_harness/guardrails/README.md | 27 +++++++++- pydantic_ai_harness/guardrails/__init__.py | 6 ++- pydantic_ai_harness/guardrails/_capability.py | 50 +++++++++++++------ tests/_guardrails/test_output_guard.py | 45 +++++++++++++++-- 5 files changed, 118 insertions(+), 28 deletions(-) diff --git a/pydantic_ai_harness/__init__.py b/pydantic_ai_harness/__init__.py index 6222b5a..382297e 100644 --- a/pydantic_ai_harness/__init__.py +++ b/pydantic_ai_harness/__init__.py @@ -6,30 +6,42 @@ from .code_mode import CodeMode from .guardrails import ( GuardrailError, - GuardrailFunc, InputBlocked, InputGuard, + InputGuardFunc, OutputBlocked, OutputGuard, + OutputGuardFunc, ) __all__ = [ 'CodeMode', 'GuardrailError', - 'GuardrailFunc', 'InputBlocked', 'InputGuard', + 'InputGuardFunc', 'OutputBlocked', 'OutputGuard', + 'OutputGuardFunc', ] +_GUARDRAIL_EXPORTS = { + 'GuardrailError', + 'InputBlocked', + 'InputGuard', + 'InputGuardFunc', + 'OutputBlocked', + 'OutputGuard', + 'OutputGuardFunc', +} + def __getattr__(name: str) -> object: if name == 'CodeMode': from .code_mode import CodeMode return CodeMode - if name in {'GuardrailError', 'GuardrailFunc', 'InputBlocked', 'InputGuard', 'OutputBlocked', 'OutputGuard'}: + if name in _GUARDRAIL_EXPORTS: from . import guardrails return getattr(guardrails, name) diff --git a/pydantic_ai_harness/guardrails/README.md b/pydantic_ai_harness/guardrails/README.md index 12025da..b5b24c0 100644 --- a/pydantic_ai_harness/guardrails/README.md +++ b/pydantic_ai_harness/guardrails/README.md @@ -28,8 +28,8 @@ def no_secrets(prompt: str) -> bool: return 'api_key' not in prompt.lower() -def no_pii(output: str) -> bool: - return 'SSN' not in output +def no_pii(output: object) -> bool: + return 'SSN' not in str(output) agent = Agent( @@ -41,6 +41,29 @@ agent = Agent( ) ``` +`OutputGuard` receives `result.output` unchanged — no automatic stringification. For a string output the guard reads it directly; for a typed (Pydantic model) output the guard gets the model instance, so pick the serialization that fits the check: + +```python +from pydantic import BaseModel +from pydantic_ai_harness import OutputGuard + + +class Answer(BaseModel): + reply: str + sources: list[str] + + +def no_internal_urls(output: object) -> bool: + if isinstance(output, Answer): + return not any('internal.example.com' in url for url in output.sources) + return 'internal.example.com' not in str(output) + + +OutputGuard(guard=no_internal_urls) +``` + +This avoids the trap of `str(MyModel(...))` producing a `MyModel(field=...)` repr that hides field contents from regex-based checks. If you want JSON text, call `output.model_dump_json()` inside the guard. 
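
A JSON-matching variant might look like this (a sketch; `no_pii_json` is an illustrative name):

```python
from pydantic import BaseModel


def no_pii_json(output: object) -> bool:
    # Serialize typed outputs to JSON so regex/substring checks see field contents.
    text = output.model_dump_json() if isinstance(output, BaseModel) else str(output)
    return 'SSN' not in text
```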
+ Both guards accept async callables too: ```python diff --git a/pydantic_ai_harness/guardrails/__init__.py b/pydantic_ai_harness/guardrails/__init__.py index a21d122..6f4a17d 100644 --- a/pydantic_ai_harness/guardrails/__init__.py +++ b/pydantic_ai_harness/guardrails/__init__.py @@ -1,9 +1,10 @@ """Input and output guardrails for Pydantic AI agents.""" from pydantic_ai_harness.guardrails._capability import ( - GuardrailFunc, InputGuard, + InputGuardFunc, OutputGuard, + OutputGuardFunc, ) from pydantic_ai_harness.guardrails._exceptions import ( GuardrailError, @@ -13,9 +14,10 @@ __all__ = [ 'GuardrailError', - 'GuardrailFunc', 'InputBlocked', 'InputGuard', + 'InputGuardFunc', 'OutputBlocked', 'OutputGuard', + 'OutputGuardFunc', ] diff --git a/pydantic_ai_harness/guardrails/_capability.py b/pydantic_ai_harness/guardrails/_capability.py index 399356b..068ff24 100644 --- a/pydantic_ai_harness/guardrails/_capability.py +++ b/pydantic_ai_harness/guardrails/_capability.py @@ -34,18 +34,30 @@ from pydantic_ai.run import AgentRunResult -GuardrailFunc = Callable[[str], bool | Awaitable[bool]] -"""Signature of the callable passed to `InputGuard` / `OutputGuard`. +InputGuardFunc = Callable[[str], bool | Awaitable[bool]] +"""Signature of the callable passed to `InputGuard`. -The callable receives the text to validate and returns `True` when safe. -It may be sync or async. Raising an exception is treated as a hard failure -and propagates up to the caller. +The callable receives the user prompt and returns `True` when safe. It may +be sync or async. Raising an exception is treated as a hard failure and +propagates up to the caller. """ +OutputGuardFunc = Callable[[object], bool | Awaitable[bool]] +"""Signature of the callable passed to `OutputGuard`. -async def _evaluate(guard: GuardrailFunc, value: str) -> bool: +The callable receives `result.output` unchanged — for typed outputs this is +the Pydantic model (not a stringified form), so the guard can read fields +directly or serialize with `model_dump_json()` if it wants to match against +JSON text. Returning `True` marks the output safe. +""" + + +async def _evaluate( + guard: InputGuardFunc | OutputGuardFunc, + value: object, +) -> bool: """Call `guard` and await it if it returned an awaitable.""" - result = guard(value) + result = guard(value) # pyright: ignore[reportArgumentType] if inspect.isawaitable(result): return await result return result @@ -101,7 +113,7 @@ def no_secrets(prompt: str) -> bool: belongs in a separate capability hooking `after_model_request`. """ - guard: GuardrailFunc + guard: InputGuardFunc """Callable that returns `True` when the prompt is safe to send to the model.""" parallel: bool = False @@ -180,26 +192,33 @@ async def run_handler() -> ModelResponse: class OutputGuard(AbstractCapability[AgentDepsT]): """Validate the final agent output. - The `guard` callable receives the stringified run output and returns - `True` when the output is safe to expose to the caller. Returning - `False` raises + The `guard` callable receives `result.output` unchanged — no automatic + stringification — and returns `True` when the output is safe to expose. + Returning `False` raises [`OutputBlocked`][pydantic_ai_harness.guardrails.OutputBlocked] with `block_message`. Raising an exception from the guard propagates it. + For string outputs the guard works directly on the text. 
For typed + (Pydantic model) outputs the guard receives the model instance, so + choose the serialization that fits your check: read a field directly, + or call `model_dump_json()` to match against JSON text. Defaulting to + `str(model)` would produce a `MyModel(field=...)` repr rather than JSON + and easily hide fields from regex-based checks. + ```python from pydantic_ai import Agent from pydantic_ai_harness import OutputGuard - def no_pii(output: str) -> bool: - return 'SSN' not in output + def no_pii(output: object) -> bool: + return 'SSN' not in str(output) agent = Agent('openai:gpt-4.1', capabilities=[OutputGuard(guard=no_pii)]) ``` """ - guard: GuardrailFunc + guard: OutputGuardFunc """Callable that returns `True` when the output is safe.""" block_message: str = 'Output blocked by output guardrail.' @@ -212,7 +231,6 @@ async def after_run( result: AgentRunResult[Any], ) -> AgentRunResult[Any]: """Validate `result.output` and raise `OutputBlocked` on failure.""" - output = str(result.output) - if not await _evaluate(self.guard, output): + if not await _evaluate(self.guard, result.output): raise OutputBlocked(self.block_message) return result diff --git a/tests/_guardrails/test_output_guard.py b/tests/_guardrails/test_output_guard.py index d815eb2..60593b1 100644 --- a/tests/_guardrails/test_output_guard.py +++ b/tests/_guardrails/test_output_guard.py @@ -5,6 +5,7 @@ import asyncio import pytest +from pydantic import BaseModel from pydantic_ai import Agent from pydantic_ai.models.test import TestModel @@ -22,7 +23,7 @@ def anyio_backend() -> str: async def test_guard_allows_safe_output(): agent = Agent( TestModel(custom_output_text='harmless reply'), - capabilities=[OutputGuard[None](guard=lambda out: 'SSN' not in out)], + capabilities=[OutputGuard[None](guard=lambda out: 'SSN' not in str(out))], ) result = await agent.run('hello') assert result.output == 'harmless reply' @@ -31,16 +32,18 @@ async def test_guard_allows_safe_output(): async def test_guard_blocks_unsafe_output(): agent = Agent( TestModel(custom_output_text='leaks SSN 123-45-6789'), - capabilities=[OutputGuard[None](guard=lambda out: 'SSN' not in out, block_message='contains SSN')], + capabilities=[ + OutputGuard[None](guard=lambda out: 'SSN' not in str(out), block_message='contains SSN'), + ], ) with pytest.raises(OutputBlocked, match='contains SSN'): await agent.run('hello') async def test_async_guard_awaited(): - async def guard(output: str) -> bool: + async def guard(output: object) -> bool: await asyncio.sleep(0) - return 'bad' not in output + return 'bad' not in str(output) agent = Agent( TestModel(custom_output_text='ok reply'), @@ -57,7 +60,7 @@ async def guard(output: str) -> bool: async def test_guard_raising_propagates(): - def guard(_: str) -> bool: + def guard(_: object) -> bool: raise RuntimeError('guard exploded') agent = Agent( @@ -68,5 +71,37 @@ def guard(_: str) -> bool: await agent.run('hello') +async def test_guard_receives_structured_output_unchanged(): + """For typed outputs the guard gets the model instance, not a stringified form.""" + + class Answer(BaseModel): + reply: str + internal_url: str + + seen: list[object] = [] + + def guard(output: object) -> bool: + seen.append(output) + assert isinstance(output, Answer) + return 'internal.example.com' not in output.internal_url + + agent = Agent( + TestModel(custom_output_args={'reply': 'hi', 'internal_url': 'https://public.example.com/x'}), + output_type=Answer, + capabilities=[OutputGuard[None](guard=guard)], + ) + result = await agent.run('hello') 
+ assert isinstance(result.output, Answer) + assert seen == [result.output] + + agent_bad = Agent( + TestModel(custom_output_args={'reply': 'hi', 'internal_url': 'https://internal.example.com/x'}), + output_type=Answer, + capabilities=[OutputGuard[None](guard=guard)], + ) + with pytest.raises(OutputBlocked): + await agent_bad.run('hello') + + def test_output_blocked_is_guardrail_error(): assert issubclass(OutputBlocked, GuardrailError) From abf5259ef33f95f5b46b3a013d8961151361b229 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kacper=20W=C5=82odarczyk?= Date: Fri, 24 Apr 2026 12:41:42 +0200 Subject: [PATCH 6/9] refactor: organize tests into TestCapabilityName classes --- tests/_guardrails/test_input_guard.py | 560 ++++++++++++------------- tests/_guardrails/test_output_guard.py | 168 ++++---- 2 files changed, 345 insertions(+), 383 deletions(-) diff --git a/tests/_guardrails/test_input_guard.py b/tests/_guardrails/test_input_guard.py index 2d258f0..8046425 100644 --- a/tests/_guardrails/test_input_guard.py +++ b/tests/_guardrails/test_input_guard.py @@ -15,7 +15,10 @@ TextPart, UserPromptPart, ) +from pydantic_ai.models import ModelRequestContext, ModelRequestParameters from pydantic_ai.models.test import TestModel +from pydantic_ai.tools import RunContext +from pydantic_ai.usage import RunUsage from pydantic_ai_harness import InputBlocked, InputGuard from pydantic_ai_harness.guardrails._capability import _extract_prompt # pyright: ignore[reportPrivateUsage] @@ -28,96 +31,14 @@ def anyio_backend() -> str: return 'asyncio' -async def test_guard_allows_when_safe(): - calls: list[str] = [] - - def guard(prompt: str) -> bool: - calls.append(prompt) - return True - - agent = Agent(TestModel(custom_output_text='ok'), capabilities=[InputGuard[None](guard=guard)]) - result = await agent.run('hello') - - assert result.output == 'ok' - assert calls == ['hello'] - - -async def test_guard_block_uses_block_message(): - agent = Agent( - TestModel(custom_output_text='would be model output'), - capabilities=[InputGuard[None](guard=lambda _: False, block_message='nope')], +def _build_ctx_and_req( + run_step: int = 1, + prompt: str | None = 'hello world', +) -> tuple[RunContext[None], ModelRequestContext]: + model = TestModel() + messages: list[ModelMessage] = ( + [ModelRequest(parts=[UserPromptPart(content=prompt)])] if prompt is not None else [] ) - result = await agent.run('hello') - - assert result.output == 'nope' - - -async def test_async_guard_awaited(): - async def guard(prompt: str) -> bool: - await asyncio.sleep(0) - return 'safe' in prompt - - agent = Agent(TestModel(custom_output_text='ok'), capabilities=[InputGuard[None](guard=guard)]) - - assert (await agent.run('safe message')).output == 'ok' - assert (await agent.run('bad message')).output == 'Request blocked by input guardrail.' 
- - -async def test_guard_raising_propagates(): - def guard(_: str) -> bool: - raise InputBlocked('policy violation') - - agent = Agent(TestModel(custom_output_text='ok'), capabilities=[InputGuard[None](guard=guard)]) - with pytest.raises(InputBlocked, match='policy violation'): - await agent.run('anything') - - -def test_extract_prompt_from_messages(): - """Extraction falls back to the most recent `UserPromptPart`.""" - - class _Ctx: - prompt = None - - messages: list[ModelMessage] = [ - ModelRequest(parts=[UserPromptPart(content='first')]), - ModelResponse(parts=[TextPart(content='assistant')]), - ModelRequest(parts=[UserPromptPart(content='second')]), - ] - assert _extract_prompt(_Ctx(), messages) == 'second' # pyright: ignore[reportArgumentType] - - -def test_extract_prompt_stringifies_non_str_prompt(): - class _Ctx: - prompt = ['multimodal', 'content'] - - assert _extract_prompt(_Ctx(), []) == str(['multimodal', 'content']) # pyright: ignore[reportArgumentType] - - -def test_extract_prompt_stringifies_non_str_message_part(): - class _Ctx: - prompt = None - - messages: list[ModelMessage] = [ModelRequest(parts=[UserPromptPart(content=['multi'])])] - assert _extract_prompt(_Ctx(), messages) == str(['multi']) # pyright: ignore[reportArgumentType] - - -def test_extract_prompt_returns_none_when_no_user_prompt_part(): - """A history containing only model responses yields `None`.""" - - class _Ctx: - prompt = None - - messages: list[ModelMessage] = [ModelResponse(parts=[TextPart(content='only model parts here')])] - assert _extract_prompt(_Ctx(), messages) is None # pyright: ignore[reportArgumentType] - -async def _build_ctx_and_req(run_step: int = 1) -> tuple[Any, Any]: - from pydantic_ai.models import ModelRequestContext, ModelRequestParameters - from pydantic_ai.models.test import TestModel as _TestModel - from pydantic_ai.tools import RunContext - from pydantic_ai.usage import RunUsage - - model = _TestModel() - messages: list[Any] = [ModelRequest(parts=[UserPromptPart(content='hello world')])] req_ctx = ModelRequestContext( model=model, messages=messages, @@ -128,283 +49,326 @@ async def _build_ctx_and_req(run_step: int = 1) -> tuple[Any, Any]: deps=None, model=model, usage=RunUsage(), - prompt='hello world', + prompt=prompt, messages=messages, run_step=run_step, ) return run_ctx, req_ctx -async def test_parallel_guard_allows_handler_to_return(): - run_ctx, req_ctx = await _build_ctx_and_req() - sentinel = ModelResponse(parts=[TextPart(content='from handler')]) +class TestInputGuard: + """Integration tests for the `InputGuard` capability driven through `Agent.run`.""" - async def handler(_: Any) -> ModelResponse: - return sentinel + async def test_allows_when_safe(self): + calls: list[str] = [] - guard = InputGuard[None](guard=lambda _: True, parallel=True) - out = await guard.wrap_model_request(run_ctx, request_context=req_ctx, handler=handler) - assert out is sentinel + def guard(prompt: str) -> bool: + calls.append(prompt) + return True + agent = Agent(TestModel(custom_output_text='ok'), capabilities=[InputGuard[None](guard=guard)]) + result = await agent.run('hello') -async def test_parallel_guard_trips_and_cancels_handler(): - run_ctx, req_ctx = await _build_ctx_and_req() - handler_cancelled = asyncio.Event() - handler_started = asyncio.Event() + assert result.output == 'ok' + assert calls == ['hello'] - async def slow_handler(_: Any) -> ModelResponse: - handler_started.set() - try: - await asyncio.sleep(10) - except asyncio.CancelledError: - handler_cancelled.set() - raise - return 
ModelResponse(parts=[TextPart(content='should never')]) # pragma: no cover + async def test_block_uses_block_message(self): + agent = Agent( + TestModel(custom_output_text='would be model output'), + capabilities=[InputGuard[None](guard=lambda _: False, block_message='nope')], + ) + result = await agent.run('hello') - async def guard(_: str) -> bool: - await handler_started.wait() - return False + assert result.output == 'nope' - ig = InputGuard[None](guard=guard, parallel=True, block_message='blocked!') - with pytest.raises(SkipModelRequest) as exc_info: - await ig.wrap_model_request(run_ctx, request_context=req_ctx, handler=slow_handler) + async def test_async_guard_awaited(self): + async def guard(prompt: str) -> bool: + await asyncio.sleep(0) + return 'safe' in prompt - assert exc_info.value.response.parts[0] == TextPart(content='blocked!') - # Give the cancellation a chance to propagate. - await asyncio.sleep(0) - assert handler_cancelled.is_set() + agent = Agent(TestModel(custom_output_text='ok'), capabilities=[InputGuard[None](guard=guard)]) + assert (await agent.run('safe message')).output == 'ok' + assert (await agent.run('bad message')).output == 'Request blocked by input guardrail.' -async def test_parallel_guard_raises_propagates(): - run_ctx, req_ctx = await _build_ctx_and_req() + async def test_raising_propagates(self): + def guard(_: str) -> bool: + raise InputBlocked('policy violation') - async def slow_handler(_: Any) -> ModelResponse: - await asyncio.sleep(10) - return ModelResponse(parts=[TextPart(content='never')]) # pragma: no cover + agent = Agent(TestModel(custom_output_text='ok'), capabilities=[InputGuard[None](guard=guard)]) + with pytest.raises(InputBlocked, match='policy violation'): + await agent.run('anything') - async def guard(_: str) -> bool: - raise InputBlocked('hard policy failure') + async def test_sequential_wrap_model_request_is_passthrough(self): + run_ctx, req_ctx = _build_ctx_and_req() + sentinel = ModelResponse(parts=[TextPart(content='direct')]) - ig = InputGuard[None](guard=guard, parallel=True) - with pytest.raises(InputBlocked, match='hard policy failure'): - await ig.wrap_model_request(run_ctx, request_context=req_ctx, handler=slow_handler) + async def handler(_: Any) -> ModelResponse: + return sentinel + ig = InputGuard[None](guard=lambda _: True, parallel=False) + out = await ig.wrap_model_request(run_ctx, request_context=req_ctx, handler=handler) + assert out is sentinel -async def test_parallel_handler_finishes_before_guard(): - """Handler completes first; guard still has to be awaited for a verdict.""" - run_ctx, req_ctx = await _build_ctx_and_req() - sentinel = ModelResponse(parts=[TextPart(content='from handler')]) - release_guard = asyncio.Event() + async def test_sequential_before_request_returns_context_when_prompt_missing(self): + run_ctx, req_ctx = _build_ctx_and_req(prompt=None) - async def fast_handler(_: Any) -> ModelResponse: - return sentinel + called: list[str] = [] - async def slow_guard(_: str) -> bool: - await release_guard.wait() - return True + def guard(prompt: str) -> bool: # pragma: no cover — should not be called + called.append(prompt) + return True - async def runner() -> ModelResponse: - ig = InputGuard[None](guard=slow_guard, parallel=True) - return await ig.wrap_model_request(run_ctx, request_context=req_ctx, handler=fast_handler) - - task = asyncio.create_task(runner()) - # Yield enough times for handler_task to complete while guard_task is still waiting on the event. 
- for _ in range(3): - await asyncio.sleep(0) - release_guard.set() - assert await task is sentinel + ig = InputGuard[None](guard=guard, parallel=False) + out = await ig.before_model_request(run_ctx, req_ctx) + assert out is req_ctx + assert called == [] + async def test_parallel_before_request_is_noop(self): + run_ctx, req_ctx = _build_ctx_and_req() -async def test_parallel_handler_finishes_then_guard_trips(): - """Handler returns first, then the guard trips — `SkipModelRequest` still wins.""" - run_ctx, req_ctx = await _build_ctx_and_req() - release_guard = asyncio.Event() + called: list[str] = [] - async def fast_handler(_: Any) -> ModelResponse: - return ModelResponse(parts=[TextPart(content='from handler')]) + def guard(prompt: str) -> bool: # pragma: no cover — should not run via before_model_request + called.append(prompt) + return False - async def slow_guard(_: str) -> bool: - await release_guard.wait() - return False + ig = InputGuard[None](guard=guard, parallel=True) + out = await ig.before_model_request(run_ctx, req_ctx) + assert out is req_ctx + assert called == [] - async def runner() -> ModelResponse: - ig = InputGuard[None](guard=slow_guard, parallel=True, block_message='late trip') - return await ig.wrap_model_request(run_ctx, request_context=req_ctx, handler=fast_handler) + async def test_runs_once_across_tool_loop(self): + """End-to-end: guard fires once even when the model makes multiple tool calls.""" + calls: list[str] = [] - task = asyncio.create_task(runner()) - for _ in range(3): - await asyncio.sleep(0) - release_guard.set() - with pytest.raises(SkipModelRequest) as exc_info: - await task - assert exc_info.value.response.parts[0] == TextPart(content='late trip') + def guard(prompt: str) -> bool: + calls.append(prompt) + return True + # TestModel(call_tools='all') calls each tool once, then returns text — two model + # requests total. 
+ model = TestModel(call_tools='all', custom_output_text='done') + agent = Agent(model, capabilities=[InputGuard[None](guard=guard)]) -async def test_parallel_handler_raises_while_guard_runs(): - """When the handler raises, `finally` cancels the still-running guard.""" - run_ctx, req_ctx = await _build_ctx_and_req() - guard_cancelled = asyncio.Event() + @agent.tool_plain + def ping() -> str: # pyright: ignore[reportUnusedFunction] + return 'pong' - async def failing_handler(_: Any) -> ModelResponse: - raise RuntimeError('model boom') + result = await agent.run('hello') + assert result.output == 'done' + assert calls == ['hello'] - async def slow_guard(_: str) -> bool: - try: - await asyncio.sleep(10) - except asyncio.CancelledError: - guard_cancelled.set() - raise - return True # pragma: no cover + async def test_sequential_skips_guard_on_subsequent_steps(self): + """After the first model request, `before_model_request` must not re-run the guard.""" + run_ctx, req_ctx = _build_ctx_and_req(run_step=2) - ig = InputGuard[None](guard=slow_guard, parallel=True) - with pytest.raises(RuntimeError, match='model boom'): - await ig.wrap_model_request(run_ctx, request_context=req_ctx, handler=failing_handler) - await asyncio.sleep(0) - assert guard_cancelled.is_set() + called: list[str] = [] + def guard(prompt: str) -> bool: # pragma: no cover — should not be called after step 1 + called.append(prompt) + return False -async def test_parallel_skipped_when_prompt_missing(): - from pydantic_ai.models import ModelRequestContext, ModelRequestParameters - from pydantic_ai.models.test import TestModel as _TestModel - from pydantic_ai.tools import RunContext - from pydantic_ai.usage import RunUsage + ig = InputGuard[None](guard=guard, parallel=False) + out = await ig.before_model_request(run_ctx, req_ctx) + assert out is req_ctx + assert called == [] - model = _TestModel() - req_ctx = ModelRequestContext( - model=model, - messages=[], - model_settings=None, - model_request_parameters=ModelRequestParameters(), - ) - run_ctx: RunContext[None] = RunContext(deps=None, model=model, usage=RunUsage(), prompt=None, messages=[]) - sentinel = ModelResponse(parts=[TextPart(content='direct')]) - called: list[str] = [] +class TestInputGuardParallel: + """Tests for `InputGuard(parallel=True)` exercising the race between guard and handler.""" - def guard(prompt: str) -> bool: # pragma: no cover — should never be called - called.append(prompt) - return False + async def test_allows_handler_to_return(self): + run_ctx, req_ctx = _build_ctx_and_req() + sentinel = ModelResponse(parts=[TextPart(content='from handler')]) - async def handler(_: Any) -> ModelResponse: - return sentinel + async def handler(_: Any) -> ModelResponse: + return sentinel - ig = InputGuard[None](guard=guard, parallel=True) - out = await ig.wrap_model_request(run_ctx, request_context=req_ctx, handler=handler) - assert out is sentinel - assert called == [] + guard = InputGuard[None](guard=lambda _: True, parallel=True) + out = await guard.wrap_model_request(run_ctx, request_context=req_ctx, handler=handler) + assert out is sentinel + async def test_trips_and_cancels_handler(self): + run_ctx, req_ctx = _build_ctx_and_req() + handler_cancelled = asyncio.Event() + handler_started = asyncio.Event() -async def test_sequential_wrap_model_request_is_passthrough(): - run_ctx, req_ctx = await _build_ctx_and_req() - sentinel = ModelResponse(parts=[TextPart(content='direct')]) + async def slow_handler(_: Any) -> ModelResponse: + handler_started.set() + try: + await 
asyncio.sleep(10) + except asyncio.CancelledError: + handler_cancelled.set() + raise + return ModelResponse(parts=[TextPart(content='should never')]) # pragma: no cover - async def handler(_: Any) -> ModelResponse: - return sentinel + async def guard(_: str) -> bool: + await handler_started.wait() + return False - ig = InputGuard[None](guard=lambda _: True, parallel=False) - out = await ig.wrap_model_request(run_ctx, request_context=req_ctx, handler=handler) - assert out is sentinel + ig = InputGuard[None](guard=guard, parallel=True, block_message='blocked!') + with pytest.raises(SkipModelRequest) as exc_info: + await ig.wrap_model_request(run_ctx, request_context=req_ctx, handler=slow_handler) + assert exc_info.value.response.parts[0] == TextPart(content='blocked!') + await asyncio.sleep(0) + assert handler_cancelled.is_set() -async def test_sequential_before_request_returns_context_when_prompt_missing(): - from pydantic_ai.models import ModelRequestContext, ModelRequestParameters - from pydantic_ai.models.test import TestModel as _TestModel - from pydantic_ai.tools import RunContext - from pydantic_ai.usage import RunUsage - - model = _TestModel() - req_ctx = ModelRequestContext( - model=model, - messages=[], - model_settings=None, - model_request_parameters=ModelRequestParameters(), - ) - run_ctx: RunContext[None] = RunContext(deps=None, model=model, usage=RunUsage(), prompt=None, messages=[]) - - called: list[str] = [] - - def guard(prompt: str) -> bool: # pragma: no cover — should not be called - called.append(prompt) - return True - - ig = InputGuard[None](guard=guard, parallel=False) - out = await ig.before_model_request(run_ctx, req_ctx) - assert out is req_ctx - assert called == [] + async def test_guard_raises_propagates(self): + run_ctx, req_ctx = _build_ctx_and_req() + async def slow_handler(_: Any) -> ModelResponse: + await asyncio.sleep(10) + return ModelResponse(parts=[TextPart(content='never')]) # pragma: no cover + + async def guard(_: str) -> bool: + raise InputBlocked('hard policy failure') + + ig = InputGuard[None](guard=guard, parallel=True) + with pytest.raises(InputBlocked, match='hard policy failure'): + await ig.wrap_model_request(run_ctx, request_context=req_ctx, handler=slow_handler) + + async def test_handler_finishes_before_guard(self): + """Handler completes first; guard still has to be awaited for a verdict.""" + run_ctx, req_ctx = _build_ctx_and_req() + sentinel = ModelResponse(parts=[TextPart(content='from handler')]) + release_guard = asyncio.Event() + + async def fast_handler(_: Any) -> ModelResponse: + return sentinel + + async def slow_guard(_: str) -> bool: + await release_guard.wait() + return True + + async def runner() -> ModelResponse: + ig = InputGuard[None](guard=slow_guard, parallel=True) + return await ig.wrap_model_request(run_ctx, request_context=req_ctx, handler=fast_handler) + + task = asyncio.create_task(runner()) + for _ in range(3): + await asyncio.sleep(0) + release_guard.set() + assert await task is sentinel + + async def test_handler_finishes_then_guard_trips(self): + """Handler returns first, then the guard trips — `SkipModelRequest` still wins.""" + run_ctx, req_ctx = _build_ctx_and_req() + release_guard = asyncio.Event() + + async def fast_handler(_: Any) -> ModelResponse: + return ModelResponse(parts=[TextPart(content='from handler')]) + + async def slow_guard(_: str) -> bool: + await release_guard.wait() + return False + + async def runner() -> ModelResponse: + ig = InputGuard[None](guard=slow_guard, parallel=True, 
block_message='late trip') + return await ig.wrap_model_request(run_ctx, request_context=req_ctx, handler=fast_handler) + + task = asyncio.create_task(runner()) + for _ in range(3): + await asyncio.sleep(0) + release_guard.set() + with pytest.raises(SkipModelRequest) as exc_info: + await task + assert exc_info.value.response.parts[0] == TextPart(content='late trip') + + async def test_handler_raises_while_guard_runs(self): + """When the handler raises, `finally` cancels the still-running guard.""" + run_ctx, req_ctx = _build_ctx_and_req() + guard_cancelled = asyncio.Event() + + async def failing_handler(_: Any) -> ModelResponse: + raise RuntimeError('model boom') + + async def slow_guard(_: str) -> bool: + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + guard_cancelled.set() + raise + return True # pragma: no cover -async def test_parallel_mode_before_request_is_noop(): - run_ctx, req_ctx = await _build_ctx_and_req() + ig = InputGuard[None](guard=slow_guard, parallel=True) + with pytest.raises(RuntimeError, match='model boom'): + await ig.wrap_model_request(run_ctx, request_context=req_ctx, handler=failing_handler) + await asyncio.sleep(0) + assert guard_cancelled.is_set() - called: list[str] = [] + async def test_skipped_when_prompt_missing(self): + run_ctx, req_ctx = _build_ctx_and_req(prompt=None) + sentinel = ModelResponse(parts=[TextPart(content='direct')]) - def guard(prompt: str) -> bool: # pragma: no cover — should not be called via before_model_request - called.append(prompt) - return False + called: list[str] = [] - ig = InputGuard[None](guard=guard, parallel=True) - out = await ig.before_model_request(run_ctx, req_ctx) - assert out is req_ctx - assert called == [] + def guard(prompt: str) -> bool: # pragma: no cover — should never be called + called.append(prompt) + return False + async def handler(_: Any) -> ModelResponse: + return sentinel -# --------------------------------------------------------------------------- -# Re-entry protection — the guard must only fire on the first model request -# --------------------------------------------------------------------------- + ig = InputGuard[None](guard=guard, parallel=True) + out = await ig.wrap_model_request(run_ctx, request_context=req_ctx, handler=handler) + assert out is sentinel + assert called == [] + async def test_skips_guard_on_subsequent_steps(self): + """`wrap_model_request` must pass the handler through without running the guard past step 1.""" + run_ctx, req_ctx = _build_ctx_and_req(run_step=2) + sentinel = ModelResponse(parts=[TextPart(content='direct')]) + called: list[str] = [] -async def test_sequential_skips_guard_on_subsequent_steps(): - """After the first model request, `before_model_request` must not re-run the guard.""" - run_ctx, req_ctx = await _build_ctx_and_req(run_step=2) + def guard(prompt: str) -> bool: # pragma: no cover — should not be called after step 1 + called.append(prompt) + return False - called: list[str] = [] + async def handler(_: Any) -> ModelResponse: + return sentinel - def guard(prompt: str) -> bool: # pragma: no cover — should not be called after step 1 - called.append(prompt) - return False + ig = InputGuard[None](guard=guard, parallel=True) + out = await ig.wrap_model_request(run_ctx, request_context=req_ctx, handler=handler) + assert out is sentinel + assert called == [] - ig = InputGuard[None](guard=guard, parallel=False) - out = await ig.before_model_request(run_ctx, req_ctx) - assert out is req_ctx - assert called == [] +class TestExtractPrompt: + """Unit tests 
for the `_extract_prompt` helper.""" -async def test_parallel_skips_guard_on_subsequent_steps(): - """`wrap_model_request` must pass the handler through without running the guard past step 1.""" - run_ctx, req_ctx = await _build_ctx_and_req(run_step=2) - sentinel = ModelResponse(parts=[TextPart(content='direct')]) - called: list[str] = [] + def test_from_messages(self): + """Extraction falls back to the most recent `UserPromptPart`.""" - def guard(prompt: str) -> bool: # pragma: no cover — should not be called after step 1 - called.append(prompt) - return False + class _Ctx: + prompt = None - async def handler(_: Any) -> ModelResponse: - return sentinel + messages: list[ModelMessage] = [ + ModelRequest(parts=[UserPromptPart(content='first')]), + ModelResponse(parts=[TextPart(content='assistant')]), + ModelRequest(parts=[UserPromptPart(content='second')]), + ] + assert _extract_prompt(_Ctx(), messages) == 'second' # pyright: ignore[reportArgumentType] - ig = InputGuard[None](guard=guard, parallel=True) - out = await ig.wrap_model_request(run_ctx, request_context=req_ctx, handler=handler) - assert out is sentinel - assert called == [] + def test_stringifies_non_str_prompt(self): + class _Ctx: + prompt = ['multimodal', 'content'] + assert _extract_prompt(_Ctx(), []) == str(['multimodal', 'content']) # pyright: ignore[reportArgumentType] -async def test_guard_runs_once_across_tool_loop(): - """End-to-end: guard fires once even when the model makes multiple tool calls.""" - calls: list[str] = [] + def test_stringifies_non_str_message_part(self): + class _Ctx: + prompt = None - def guard(prompt: str) -> bool: - calls.append(prompt) - return True + messages: list[ModelMessage] = [ModelRequest(parts=[UserPromptPart(content=['multi'])])] + assert _extract_prompt(_Ctx(), messages) == str(['multi']) # pyright: ignore[reportArgumentType] - # TestModel(call_tools='all') calls each tool once, then returns a text response — two - # model requests total. 
- model = TestModel(call_tools='all', custom_output_text='done') - agent = Agent(model, capabilities=[InputGuard[None](guard=guard)]) + def test_returns_none_when_no_user_prompt_part(self): + """A history containing only model responses yields `None`.""" - @agent.tool_plain - def ping() -> str: # pyright: ignore[reportUnusedFunction] - return 'pong' + class _Ctx: + prompt = None - result = await agent.run('hello') - assert result.output == 'done' - assert calls == ['hello'] + messages: list[ModelMessage] = [ModelResponse(parts=[TextPart(content='only model parts here')])] + assert _extract_prompt(_Ctx(), messages) is None # pyright: ignore[reportArgumentType] diff --git a/tests/_guardrails/test_output_guard.py b/tests/_guardrails/test_output_guard.py index 60593b1..993fe82 100644 --- a/tests/_guardrails/test_output_guard.py +++ b/tests/_guardrails/test_output_guard.py @@ -20,88 +20,86 @@ def anyio_backend() -> str: return 'asyncio' -async def test_guard_allows_safe_output(): - agent = Agent( - TestModel(custom_output_text='harmless reply'), - capabilities=[OutputGuard[None](guard=lambda out: 'SSN' not in str(out))], - ) - result = await agent.run('hello') - assert result.output == 'harmless reply' - - -async def test_guard_blocks_unsafe_output(): - agent = Agent( - TestModel(custom_output_text='leaks SSN 123-45-6789'), - capabilities=[ - OutputGuard[None](guard=lambda out: 'SSN' not in str(out), block_message='contains SSN'), - ], - ) - with pytest.raises(OutputBlocked, match='contains SSN'): - await agent.run('hello') - - -async def test_async_guard_awaited(): - async def guard(output: object) -> bool: - await asyncio.sleep(0) - return 'bad' not in str(output) - - agent = Agent( - TestModel(custom_output_text='ok reply'), - capabilities=[OutputGuard[None](guard=guard)], - ) - assert (await agent.run('prompt')).output == 'ok reply' - - agent_bad = Agent( - TestModel(custom_output_text='bad reply'), - capabilities=[OutputGuard[None](guard=guard)], - ) - with pytest.raises(OutputBlocked): - await agent_bad.run('prompt') - - -async def test_guard_raising_propagates(): - def guard(_: object) -> bool: - raise RuntimeError('guard exploded') - - agent = Agent( - TestModel(custom_output_text='anything'), - capabilities=[OutputGuard[None](guard=guard)], - ) - with pytest.raises(RuntimeError, match='guard exploded'): - await agent.run('hello') - - -async def test_guard_receives_structured_output_unchanged(): - """For typed outputs the guard gets the model instance, not a stringified form.""" - - class Answer(BaseModel): - reply: str - internal_url: str - - seen: list[object] = [] - - def guard(output: object) -> bool: - seen.append(output) - assert isinstance(output, Answer) - return 'internal.example.com' not in output.internal_url - - agent = Agent( - TestModel(custom_output_args={'reply': 'hi', 'internal_url': 'https://public.example.com/x'}), - output_type=Answer, - capabilities=[OutputGuard[None](guard=guard)], - ) - result = await agent.run('hello') - assert isinstance(result.output, Answer) - assert seen == [result.output] - - agent_bad = Agent( - TestModel(custom_output_args={'reply': 'hi', 'internal_url': 'https://internal.example.com/x'}), - output_type=Answer, - capabilities=[OutputGuard[None](guard=guard)], - ) - with pytest.raises(OutputBlocked): - await agent_bad.run('hello') - - -def test_output_blocked_is_guardrail_error(): - assert issubclass(OutputBlocked, GuardrailError) +class TestOutputGuard: + """Integration tests for the `OutputGuard` capability driven through `Agent.run`.""" 
+ + async def test_allows_safe_output(self): + agent = Agent( + TestModel(custom_output_text='harmless reply'), + capabilities=[OutputGuard[None](guard=lambda out: 'SSN' not in str(out))], + ) + result = await agent.run('hello') + assert result.output == 'harmless reply' + + async def test_blocks_unsafe_output(self): + agent = Agent( + TestModel(custom_output_text='leaks SSN 123-45-6789'), + capabilities=[ + OutputGuard[None](guard=lambda out: 'SSN' not in str(out), block_message='contains SSN'), + ], + ) + with pytest.raises(OutputBlocked, match='contains SSN'): + await agent.run('hello') + + async def test_async_guard_awaited(self): + async def guard(output: object) -> bool: + await asyncio.sleep(0) + return 'bad' not in str(output) + + agent = Agent( + TestModel(custom_output_text='ok reply'), + capabilities=[OutputGuard[None](guard=guard)], + ) + assert (await agent.run('prompt')).output == 'ok reply' + + agent_bad = Agent( + TestModel(custom_output_text='bad reply'), + capabilities=[OutputGuard[None](guard=guard)], + ) + with pytest.raises(OutputBlocked): + await agent_bad.run('prompt') + + async def test_raising_propagates(self): + def guard(_: object) -> bool: + raise RuntimeError('guard exploded') + + agent = Agent( + TestModel(custom_output_text='anything'), + capabilities=[OutputGuard[None](guard=guard)], + ) + with pytest.raises(RuntimeError, match='guard exploded'): + await agent.run('hello') + + async def test_receives_structured_output_unchanged(self): + """For typed outputs the guard gets the model instance, not a stringified form.""" + + class Answer(BaseModel): + reply: str + internal_url: str + + seen: list[object] = [] + + def guard(output: object) -> bool: + seen.append(output) + assert isinstance(output, Answer) + return 'internal.example.com' not in output.internal_url + + agent = Agent( + TestModel(custom_output_args={'reply': 'hi', 'internal_url': 'https://public.example.com/x'}), + output_type=Answer, + capabilities=[OutputGuard[None](guard=guard)], + ) + result = await agent.run('hello') + assert isinstance(result.output, Answer) + assert seen == [result.output] + + agent_bad = Agent( + TestModel(custom_output_args={'reply': 'hi', 'internal_url': 'https://internal.example.com/x'}), + output_type=Answer, + capabilities=[OutputGuard[None](guard=guard)], + ) + with pytest.raises(OutputBlocked): + await agent_bad.run('hello') + + def test_output_blocked_is_guardrail_error(self): + assert issubclass(OutputBlocked, GuardrailError) From dd4c1b8b9b1dde1870993a01e248dbc4a2d9edc0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kacper=20W=C5=82odarczyk?= Date: Sat, 25 Apr 2026 17:16:46 +0200 Subject: [PATCH 7/9] fix: drain cancelled tasks in InputGuard parallel finally --- pydantic_ai_harness/guardrails/_capability.py | 5 ++++ tests/_guardrails/test_input_guard.py | 27 +++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/pydantic_ai_harness/guardrails/_capability.py b/pydantic_ai_harness/guardrails/_capability.py index 068ff24..c14c674 100644 --- a/pydantic_ai_harness/guardrails/_capability.py +++ b/pydantic_ai_harness/guardrails/_capability.py @@ -17,6 +17,7 @@ from __future__ import annotations import asyncio +import contextlib import inspect from collections.abc import Awaitable, Callable, Sequence from dataclasses import dataclass @@ -187,6 +188,10 @@ async def run_handler() -> ModelResponse: if not handler_task.done(): handler_task.cancel() + for task in (guard_task, handler_task): + with contextlib.suppress(asyncio.CancelledError, Exception): + await task + 
@dataclass class OutputGuard(AbstractCapability[AgentDepsT]): diff --git a/tests/_guardrails/test_input_guard.py b/tests/_guardrails/test_input_guard.py index 8046425..e7a7989 100644 --- a/tests/_guardrails/test_input_guard.py +++ b/tests/_guardrails/test_input_guard.py @@ -334,6 +334,33 @@ async def handler(_: Any) -> ModelResponse: assert out is sentinel assert called == [] + async def test_no_dangling_tasks_when_handler_raises(self): + """`finally` must drain cancelled tasks so they don't outlive the call. + + `task.cancel()` only schedules a `CancelledError` for the next loop tick — without + awaiting, the task stays in `asyncio.all_tasks()` after `wrap_model_request` + returns, leaks into the surrounding scope, and produces "Task exception was never + retrieved" warnings if it later raises. + """ + run_ctx, req_ctx = _build_ctx_and_req() + + async def failing_handler(_: Any) -> ModelResponse: + raise RuntimeError('handler boom') + + async def slow_guard(_: str) -> bool: + await asyncio.sleep(10) + return True # pragma: no cover + + current = asyncio.current_task() + before = {t for t in asyncio.all_tasks() if t is not current} + + ig = InputGuard[None](guard=slow_guard, parallel=True) + with pytest.raises(RuntimeError, match='handler boom'): + await ig.wrap_model_request(run_ctx, request_context=req_ctx, handler=failing_handler) + + leftover = {t for t in asyncio.all_tasks() if t is not current} - before + assert leftover == set(), f'guard/handler tasks must be drained, got dangling: {leftover}' + class TestExtractPrompt: """Unit tests for the `_extract_prompt` helper.""" From 57d8e2cb3009a641c761b01f0f9efab60d557b09 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kacper=20W=C5=82odarczyk?= Date: Sat, 25 Apr 2026 17:19:36 +0200 Subject: [PATCH 8/9] fix: re-raise task exceptions via await instead of .exception() --- pydantic_ai_harness/guardrails/_capability.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/pydantic_ai_harness/guardrails/_capability.py b/pydantic_ai_harness/guardrails/_capability.py index c14c674..394cc0a 100644 --- a/pydantic_ai_harness/guardrails/_capability.py +++ b/pydantic_ai_harness/guardrails/_capability.py @@ -169,16 +169,14 @@ async def run_handler() -> ModelResponse: return_when=asyncio.FIRST_COMPLETED, ) if guard_task in done: - guard_exc = guard_task.exception() - if guard_exc is not None: + if guard_task.exception() is not None: handler_task.cancel() - raise guard_exc + await guard_task # re-raises the guard's exception return await handler_task - # Handler finished first: if it raised, propagate and cancel the guard. - handler_exc = handler_task.exception() - if handler_exc is not None: + # Handler finished first: if it raised, cancel the guard and propagate. + if handler_task.exception() is not None: guard_task.cancel() - raise handler_exc + await handler_task # re-raises the handler's exception # Handler succeeded; still need the guard verdict before committing the response. 
await guard_task
         return handler_task.result()

From 38c31851614699cfb02ecfdbbf9f1691d20ec095 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kacper=20W=C5=82odarczyk?=
Date: Sat, 25 Apr 2026 17:36:26 +0200
Subject: [PATCH 9/9] refactor: consolidate parallel cancel/drain into single finally

---
 pydantic_ai_harness/guardrails/_capability.py | 28 ++++++++------------
 1 file changed, 10 insertions(+), 18 deletions(-)

diff --git a/pydantic_ai_harness/guardrails/_capability.py b/pydantic_ai_harness/guardrails/_capability.py
index 394cc0a..07f0d9b 100644
--- a/pydantic_ai_harness/guardrails/_capability.py
+++ b/pydantic_ai_harness/guardrails/_capability.py
@@ -17,7 +17,6 @@ from __future__ import annotations
 import asyncio
-import contextlib
 import inspect
 from collections.abc import Awaitable, Callable, Sequence
 from dataclasses import dataclass
@@ -165,30 +164,23 @@ async def run_handler() -> ModelResponse:
     handler_task: asyncio.Task[ModelResponse] = asyncio.create_task(run_handler())
     try:
         done, _ = await asyncio.wait(
-            [guard_task, handler_task],
+            {guard_task, handler_task},
             return_when=asyncio.FIRST_COMPLETED,
         )
         if guard_task in done:
-            if guard_task.exception() is not None:
-                handler_task.cancel()
-                await guard_task  # re-raises the guard's exception
+            await guard_task
             return await handler_task
-        # Handler finished first: if it raised, cancel the guard and propagate.
-        if handler_task.exception() is not None:
-            guard_task.cancel()
-            await handler_task  # re-raises the handler's exception
-        # Handler succeeded; still need the guard verdict before committing the response.
+
+        response = await handler_task
         await guard_task
-        return handler_task.result()
+        return response
     finally:
-        if not guard_task.done():
-            guard_task.cancel()
-        if not handler_task.done():
-            handler_task.cancel()
-        for task in (guard_task, handler_task):
-            with contextlib.suppress(asyncio.CancelledError, Exception):
-                await task
+        for task in (guard_task, handler_task):
+            if not task.done():
+                task.cancel()
+
+        await asyncio.gather(guard_task, handler_task, return_exceptions=True)


 @dataclass
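
Taken together, the last three patches leave the parallel input-guard path with a single shape: start both tasks, wait for whichever finishes first, surface failures by awaiting the finished task, and drain everything in `finally`. The sketch below is a minimal standalone reconstruction of that pattern, not the harness's actual code; `guard()` and `call_model()` are illustrative stand-ins for the wrapped guard coroutine (which raises on a violation) and the model-request handler.

import asyncio


async def guarded_request(guard, call_model):
    guard_task = asyncio.ensure_future(guard())
    handler_task = asyncio.ensure_future(call_model())
    try:
        await asyncio.wait({guard_task, handler_task}, return_when=asyncio.FIRST_COMPLETED)
        if guard_task.done():
            # Awaiting a finished task re-raises its exception with the original
            # traceback, which is why patch 8 prefers this over task.exception().
            await guard_task
            return await handler_task
        response = await handler_task  # handler won the race; re-raises on failure
        await guard_task  # the guard verdict is still required before committing
        return response
    finally:
        # cancel() only schedules a CancelledError for the next loop tick;
        # gather(return_exceptions=True) actually drains both tasks (patch 7),
        # so nothing outlives the call or warns about unretrieved exceptions.
        for task in (guard_task, handler_task):
            if not task.done():
                task.cancel()
        await asyncio.gather(guard_task, handler_task, return_exceptions=True)

Because the verdict is always taken from `await guard_task`, a tripped guard wins even when the handler has already produced a response, which is the behaviour pinned down by the `test_handler_finishes_then_guard_trips` case above.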