diff --git a/pydantic_ai_harness/__init__.py b/pydantic_ai_harness/__init__.py index 0a60fd7..382297e 100644 --- a/pydantic_ai_harness/__init__.py +++ b/pydantic_ai_harness/__init__.py @@ -4,8 +4,36 @@ if TYPE_CHECKING: from .code_mode import CodeMode + from .guardrails import ( + GuardrailError, + InputBlocked, + InputGuard, + InputGuardFunc, + OutputBlocked, + OutputGuard, + OutputGuardFunc, + ) -__all__ = ['CodeMode'] +__all__ = [ + 'CodeMode', + 'GuardrailError', + 'InputBlocked', + 'InputGuard', + 'InputGuardFunc', + 'OutputBlocked', + 'OutputGuard', + 'OutputGuardFunc', +] + +_GUARDRAIL_EXPORTS = { + 'GuardrailError', + 'InputBlocked', + 'InputGuard', + 'InputGuardFunc', + 'OutputBlocked', + 'OutputGuard', + 'OutputGuardFunc', +} def __getattr__(name: str) -> object: @@ -13,4 +41,8 @@ def __getattr__(name: str) -> object: from .code_mode import CodeMode return CodeMode + if name in _GUARDRAIL_EXPORTS: + from . import guardrails + + return getattr(guardrails, name) raise AttributeError(f'module {__name__!r} has no attribute {name!r}') diff --git a/pydantic_ai_harness/guardrails/README.md b/pydantic_ai_harness/guardrails/README.md new file mode 100644 index 0000000..b5b24c0 --- /dev/null +++ b/pydantic_ai_harness/guardrails/README.md @@ -0,0 +1,135 @@ +# Guardrails + +Intercept unsafe user prompts before they reach the model, and unsafe model outputs before they reach the caller. + +## The problem + +Agents take unstructured input from users and return unstructured output to callers. Without a validation layer, a prompt injection attempt, PII-laden message, or off-topic question goes to the model as-is, and any output the model produces is returned verbatim. The framework does not reason about "this is unsafe to send" or "this is unsafe to show". + +## The solution + +Two capabilities, each backed by a callable you supply. + +| Capability | Checks | When a guard returns `False` | When a guard raises | +|---|---|---|---| +| `InputGuard` | The user prompt before each model request | `SkipModelRequest` — the model call is skipped and `block_message` becomes the response for that step | The exception propagates out of the run | +| `OutputGuard` | The final run output | `OutputBlocked` is raised | The exception propagates out of the run | + +The asymmetry is intentional. Blocking the input means no tokens are spent, so a graceful refusal is almost always what you want. Blocking the output means the model already generated a response you do not want exposed — raising forces the caller to decide what to do next. + +## Usage + +```python +from pydantic_ai import Agent +from pydantic_ai_harness import InputGuard, OutputGuard + + +def no_secrets(prompt: str) -> bool: + return 'api_key' not in prompt.lower() + + +def no_pii(output: object) -> bool: + return 'SSN' not in str(output) + + +agent = Agent( + 'openai:gpt-4.1', + capabilities=[ + InputGuard(guard=no_secrets), + OutputGuard(guard=no_pii), + ], +) +``` + +`OutputGuard` receives `result.output` unchanged — no automatic stringification. 
For a string output the guard reads it directly; for a typed (Pydantic model) output the guard gets the model instance, so pick the serialization that fits the check:
+
+```python
+from pydantic import BaseModel
+from pydantic_ai_harness import OutputGuard
+
+
+class Answer(BaseModel):
+    reply: str
+    sources: list[str]
+
+
+def no_internal_urls(output: object) -> bool:
+    if isinstance(output, Answer):
+        return not any('internal.example.com' in url for url in output.sources)
+    return 'internal.example.com' not in str(output)
+
+
+OutputGuard(guard=no_internal_urls)
+```
+
+This avoids the trap of `str(MyModel(...))` producing a `MyModel(field=...)` repr that hides field contents from regex-based checks. If you want JSON text, call `output.model_dump_json()` inside the guard.
+
+Both guards accept async callables too:
+
+```python
+async def check_with_moderation_api(prompt: str) -> bool:
+    response = await client.moderations.create(input=prompt)
+    return not response.results[0].flagged
+
+
+agent = Agent(
+    'openai:gpt-4.1',
+    capabilities=[InputGuard(guard=check_with_moderation_api)],
+)
+```
+
+## Parallel input guards
+
+When a guard is slow (an LLM-based classifier or a network call), running it in sequence before every model request adds latency to every turn. Set `parallel=True` to race the guard against the model call. The model call is cancelled immediately if the guard reports a violation.
+
+```python
+InputGuard(guard=slow_async_classifier, parallel=True)
+```
+
+For fast local checks (regex, keyword lookup, a small classifier) sequential is usually fine — the overhead is measured in microseconds and the wiring is simpler.
+
+## Customising the block message
+
+```python
+InputGuard(
+    guard=no_secrets,
+    block_message='This request looks like it contains credentials. Please rephrase.',
+)
+```
+
+The text is returned as the model response for that step, so the caller sees a normal completion rather than an exception. Multi-turn agents can continue the conversation from there.
+
+## Hard-fail path
+
+Returning `False` from a guard is the graceful path. If you want the caller to see an exception instead, raise from the guard:
+
+```python
+from pydantic_ai_harness import InputBlocked
+
+
+def strict_guard(prompt: str) -> bool:
+    if contains_credentials(prompt):
+        raise InputBlocked('credentials detected')
+    return True
+```
+
+Any exception raised by the guard propagates as-is — you can use `InputBlocked` / `OutputBlocked` from this module or your own exception types. A caller-side sketch that handles both paths is shown at the end of this README.
+
+## API
+
+```python
+InputGuard(
+    guard: Callable[[str], bool | Awaitable[bool]],
+    parallel: bool = False,
+    block_message: str = 'Request blocked by input guardrail.',
+)
+
+OutputGuard(
+    guard: Callable[[object], bool | Awaitable[bool]],
+    block_message: str = 'Output blocked by output guardrail.',
+)
+```
+
+## Relationship to `pydantic-ai-shields`
+
+`pydantic-ai-shields` provides opinionated implementations on top of these primitives (prompt-injection detectors, PII scrubbers, keyword blocklists, etc.). Use the guardrails here when you want to plug in your own validation logic; reach for shields when you need a batteries-included detector. 
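+
+## Handling a blocked run
+
+As a minimal caller-side sketch tying the two paths together: a tripped `InputGuard` returns as a normal completion whose text is `block_message`, while a tripped `OutputGuard` raises `OutputBlocked` for the caller to catch. The fallback strings and the `answer` helper below are illustrative, not part of this module.
+
+```python
+from pydantic_ai import Agent
+from pydantic_ai_harness import InputGuard, OutputBlocked, OutputGuard
+
+
+def no_secrets(prompt: str) -> bool:
+    return 'api_key' not in prompt.lower()
+
+
+def no_pii(output: object) -> bool:
+    return 'SSN' not in str(output)
+
+
+agent = Agent(
+    'openai:gpt-4.1',
+    capabilities=[
+        InputGuard(guard=no_secrets, block_message='Please remove credentials and try again.'),
+        OutputGuard(guard=no_pii),
+    ],
+)
+
+
+async def answer(question: str) -> str:
+    try:
+        result = await agent.run(question)
+    except OutputBlocked:
+        # The model produced output the guard rejected; decide what the caller sees.
+        return 'The response was withheld by an output guardrail.'
+    # A blocked input also lands here: the run completes normally and
+    # result.output is the InputGuard block_message.
+    return str(result.output)
+```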
diff --git a/pydantic_ai_harness/guardrails/__init__.py b/pydantic_ai_harness/guardrails/__init__.py new file mode 100644 index 0000000..6f4a17d --- /dev/null +++ b/pydantic_ai_harness/guardrails/__init__.py @@ -0,0 +1,23 @@ +"""Input and output guardrails for Pydantic AI agents.""" + +from pydantic_ai_harness.guardrails._capability import ( + InputGuard, + InputGuardFunc, + OutputGuard, + OutputGuardFunc, +) +from pydantic_ai_harness.guardrails._exceptions import ( + GuardrailError, + InputBlocked, + OutputBlocked, +) + +__all__ = [ + 'GuardrailError', + 'InputBlocked', + 'InputGuard', + 'InputGuardFunc', + 'OutputBlocked', + 'OutputGuard', + 'OutputGuardFunc', +] diff --git a/pydantic_ai_harness/guardrails/_capability.py b/pydantic_ai_harness/guardrails/_capability.py new file mode 100644 index 0000000..07f0d9b --- /dev/null +++ b/pydantic_ai_harness/guardrails/_capability.py @@ -0,0 +1,230 @@ +"""Input and output guardrail capabilities. + +`InputGuard` intercepts each model request and lets a user-supplied callable +decide whether the current user prompt is safe to send to the model. A guard +that returns `False` is treated as a graceful refusal: the LLM call is +skipped via [`SkipModelRequest`][pydantic_ai.exceptions.SkipModelRequest] and +a canned message becomes the model response for that step. A guard that raises +propagates the exception so the caller can observe a hard failure. + +`OutputGuard` runs once the run completes and validates the final output. +A guard that returns `False` raises +[`OutputBlocked`][pydantic_ai_harness.guardrails.OutputBlocked]. + +Both guards accept sync or async callables. +""" + +from __future__ import annotations + +import asyncio +import inspect +from collections.abc import Awaitable, Callable, Sequence +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +from pydantic_ai.capabilities import AbstractCapability, WrapModelRequestHandler +from pydantic_ai.exceptions import SkipModelRequest +from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart, UserPromptPart +from pydantic_ai.tools import AgentDepsT, RunContext + +from pydantic_ai_harness.guardrails._exceptions import OutputBlocked + +if TYPE_CHECKING: + from pydantic_ai.models import ModelRequestContext + from pydantic_ai.run import AgentRunResult + + +InputGuardFunc = Callable[[str], bool | Awaitable[bool]] +"""Signature of the callable passed to `InputGuard`. + +The callable receives the user prompt and returns `True` when safe. It may +be sync or async. Raising an exception is treated as a hard failure and +propagates up to the caller. +""" + +OutputGuardFunc = Callable[[object], bool | Awaitable[bool]] +"""Signature of the callable passed to `OutputGuard`. + +The callable receives `result.output` unchanged — for typed outputs this is +the Pydantic model (not a stringified form), so the guard can read fields +directly or serialize with `model_dump_json()` if it wants to match against +JSON text. Returning `True` marks the output safe. +""" + + +async def _evaluate( + guard: InputGuardFunc | OutputGuardFunc, + value: object, +) -> bool: + """Call `guard` and await it if it returned an awaitable.""" + result = guard(value) # pyright: ignore[reportArgumentType] + if inspect.isawaitable(result): + return await result + return result + + +def _extract_prompt(ctx: RunContext[AgentDepsT], messages: Sequence[ModelMessage]) -> str | None: + """Return the text of the most recent user prompt, or `None` if absent. 
+ + Prefers `ctx.prompt` (set at run start) and falls back to scanning the + message history for the last [`UserPromptPart`][pydantic_ai.messages.UserPromptPart] + so that sub-agent calls or resumed runs without a fresh prompt still work. + """ + if ctx.prompt is not None: + return ctx.prompt if isinstance(ctx.prompt, str) else str(ctx.prompt) + for message in reversed(messages): + for part in reversed(message.parts): + if isinstance(part, UserPromptPart): + return part.content if isinstance(part.content, str) else str(part.content) + return None + + +@dataclass +class InputGuard(AbstractCapability[AgentDepsT]): + """Validate the user prompt before it reaches the model. + + The `guard` callable receives the prompt text and returns `True` when + the input is safe. Returning `False` triggers a graceful refusal: the + current model request is short-circuited via + [`SkipModelRequest`][pydantic_ai.exceptions.SkipModelRequest] with + `block_message` as the response text, so the agent returns cleanly to + the caller. Raising an exception from the guard propagates it as-is. + + ```python + from pydantic_ai import Agent + from pydantic_ai_harness import InputGuard + + + def no_secrets(prompt: str) -> bool: + return 'api_key' not in prompt.lower() + + + agent = Agent('openai:gpt-4.1', capabilities=[InputGuard(guard=no_secrets)]) + ``` + + Set `parallel=True` to start the guard alongside the model call. The + handler is cancelled as soon as the guard reports a violation, which saves + tokens when the guard is slower than the provider round-trip. + + Scope: the guard runs exactly once per run — on the first model request — + and evaluates the original user prompt. Subsequent model requests in the + same run (e.g. after tool calls) are not re-checked, since the user input + has not changed. Validation of tool results or other mid-run content + belongs in a separate capability hooking `after_model_request`. + """ + + guard: InputGuardFunc + """Callable that returns `True` when the prompt is safe to send to the model.""" + + parallel: bool = False + """Run the guard concurrently with the model request and cancel the model call on failure.""" + + block_message: str = 'Request blocked by input guardrail.' 
+ """Text returned as the model response when the guard trips gracefully.""" + + def _blocked_response(self) -> ModelResponse: + return ModelResponse(parts=[TextPart(content=self.block_message)]) + + async def _run_guard(self, prompt: str) -> None: + """Evaluate the guard and raise `SkipModelRequest` on failure.""" + if not await _evaluate(self.guard, prompt): + raise SkipModelRequest(self._blocked_response()) + + async def before_model_request( + self, + ctx: RunContext[AgentDepsT], + request_context: ModelRequestContext, + ) -> ModelRequestContext: + """Check the prompt before the model call in sequential mode.""" + if self.parallel or ctx.run_step > 1: + return request_context + prompt = _extract_prompt(ctx, request_context.messages) + if prompt is None: + return request_context + await self._run_guard(prompt) + return request_context + + async def wrap_model_request( + self, + ctx: RunContext[AgentDepsT], + *, + request_context: ModelRequestContext, + handler: WrapModelRequestHandler, + ) -> ModelResponse: + """Run the guard alongside the model call when `parallel=True`.""" + if not self.parallel or ctx.run_step > 1: + return await handler(request_context) + prompt = _extract_prompt(ctx, request_context.messages) + if prompt is None: + return await handler(request_context) + async def run_handler() -> ModelResponse: + return await handler(request_context) + + guard_task: asyncio.Task[None] = asyncio.create_task(self._run_guard(prompt)) + handler_task: asyncio.Task[ModelResponse] = asyncio.create_task(run_handler()) + try: + done, _ = await asyncio.wait( + {guard_task, handler_task}, + return_when=asyncio.FIRST_COMPLETED, + ) + if guard_task in done: + await guard_task + return await handler_task + + response = await handler_task + await guard_task + return response + finally: + for task in (guard_task, handler_task): + if not task.done(): + task.cancel() + + await asyncio.gather(guard_task, handler_task, return_exceptions=True) + + +@dataclass +class OutputGuard(AbstractCapability[AgentDepsT]): + """Validate the final agent output. + + The `guard` callable receives `result.output` unchanged — no automatic + stringification — and returns `True` when the output is safe to expose. + Returning `False` raises + [`OutputBlocked`][pydantic_ai_harness.guardrails.OutputBlocked] with + `block_message`. Raising an exception from the guard propagates it. + + For string outputs the guard works directly on the text. For typed + (Pydantic model) outputs the guard receives the model instance, so + choose the serialization that fits your check: read a field directly, + or call `model_dump_json()` to match against JSON text. Defaulting to + `str(model)` would produce a `MyModel(field=...)` repr rather than JSON + and easily hide fields from regex-based checks. + + ```python + from pydantic_ai import Agent + from pydantic_ai_harness import OutputGuard + + + def no_pii(output: object) -> bool: + return 'SSN' not in str(output) + + + agent = Agent('openai:gpt-4.1', capabilities=[OutputGuard(guard=no_pii)]) + ``` + """ + + guard: OutputGuardFunc + """Callable that returns `True` when the output is safe.""" + + block_message: str = 'Output blocked by output guardrail.' 
+ """Message attached to `OutputBlocked` when the guard trips.""" + + async def after_run( + self, + ctx: RunContext[AgentDepsT], + *, + result: AgentRunResult[Any], + ) -> AgentRunResult[Any]: + """Validate `result.output` and raise `OutputBlocked` on failure.""" + if not await _evaluate(self.guard, result.output): + raise OutputBlocked(self.block_message) + return result diff --git a/pydantic_ai_harness/guardrails/_exceptions.py b/pydantic_ai_harness/guardrails/_exceptions.py new file mode 100644 index 0000000..8ae37dd --- /dev/null +++ b/pydantic_ai_harness/guardrails/_exceptions.py @@ -0,0 +1,20 @@ +"""Exceptions raised by the guardrail capabilities.""" + +from __future__ import annotations + + +class GuardrailError(Exception): + """Base exception for guardrail violations.""" + + +class InputBlocked(GuardrailError): + """Raised by a user-supplied input guard to hard-fail a run. + + Prefer returning `False` from the guard callable to trigger a graceful + refusal via [`SkipModelRequest`][pydantic_ai.exceptions.SkipModelRequest]. + Raise this explicitly when the caller should have to handle the failure. + """ + + +class OutputBlocked(GuardrailError): + """Raised by [`OutputGuard`][pydantic_ai_harness.OutputGuard] when the final output fails validation.""" diff --git a/tests/_guardrails/__init__.py b/tests/_guardrails/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/_guardrails/test_input_guard.py b/tests/_guardrails/test_input_guard.py new file mode 100644 index 0000000..e7a7989 --- /dev/null +++ b/tests/_guardrails/test_input_guard.py @@ -0,0 +1,401 @@ +"""Tests for the `InputGuard` capability.""" + +from __future__ import annotations + +import asyncio +from typing import Any + +import pytest +from pydantic_ai import Agent +from pydantic_ai.exceptions import SkipModelRequest +from pydantic_ai.messages import ( + ModelMessage, + ModelRequest, + ModelResponse, + TextPart, + UserPromptPart, +) +from pydantic_ai.models import ModelRequestContext, ModelRequestParameters +from pydantic_ai.models.test import TestModel +from pydantic_ai.tools import RunContext +from pydantic_ai.usage import RunUsage + +from pydantic_ai_harness import InputBlocked, InputGuard +from pydantic_ai_harness.guardrails._capability import _extract_prompt # pyright: ignore[reportPrivateUsage] + +pytestmark = pytest.mark.anyio + + +@pytest.fixture +def anyio_backend() -> str: + return 'asyncio' + + +def _build_ctx_and_req( + run_step: int = 1, + prompt: str | None = 'hello world', +) -> tuple[RunContext[None], ModelRequestContext]: + model = TestModel() + messages: list[ModelMessage] = ( + [ModelRequest(parts=[UserPromptPart(content=prompt)])] if prompt is not None else [] + ) + req_ctx = ModelRequestContext( + model=model, + messages=messages, + model_settings=None, + model_request_parameters=ModelRequestParameters(), + ) + run_ctx: RunContext[None] = RunContext( + deps=None, + model=model, + usage=RunUsage(), + prompt=prompt, + messages=messages, + run_step=run_step, + ) + return run_ctx, req_ctx + + +class TestInputGuard: + """Integration tests for the `InputGuard` capability driven through `Agent.run`.""" + + async def test_allows_when_safe(self): + calls: list[str] = [] + + def guard(prompt: str) -> bool: + calls.append(prompt) + return True + + agent = Agent(TestModel(custom_output_text='ok'), capabilities=[InputGuard[None](guard=guard)]) + result = await agent.run('hello') + + assert result.output == 'ok' + assert calls == ['hello'] + + async def test_block_uses_block_message(self): + agent 
= Agent( + TestModel(custom_output_text='would be model output'), + capabilities=[InputGuard[None](guard=lambda _: False, block_message='nope')], + ) + result = await agent.run('hello') + + assert result.output == 'nope' + + async def test_async_guard_awaited(self): + async def guard(prompt: str) -> bool: + await asyncio.sleep(0) + return 'safe' in prompt + + agent = Agent(TestModel(custom_output_text='ok'), capabilities=[InputGuard[None](guard=guard)]) + + assert (await agent.run('safe message')).output == 'ok' + assert (await agent.run('bad message')).output == 'Request blocked by input guardrail.' + + async def test_raising_propagates(self): + def guard(_: str) -> bool: + raise InputBlocked('policy violation') + + agent = Agent(TestModel(custom_output_text='ok'), capabilities=[InputGuard[None](guard=guard)]) + with pytest.raises(InputBlocked, match='policy violation'): + await agent.run('anything') + + async def test_sequential_wrap_model_request_is_passthrough(self): + run_ctx, req_ctx = _build_ctx_and_req() + sentinel = ModelResponse(parts=[TextPart(content='direct')]) + + async def handler(_: Any) -> ModelResponse: + return sentinel + + ig = InputGuard[None](guard=lambda _: True, parallel=False) + out = await ig.wrap_model_request(run_ctx, request_context=req_ctx, handler=handler) + assert out is sentinel + + async def test_sequential_before_request_returns_context_when_prompt_missing(self): + run_ctx, req_ctx = _build_ctx_and_req(prompt=None) + + called: list[str] = [] + + def guard(prompt: str) -> bool: # pragma: no cover — should not be called + called.append(prompt) + return True + + ig = InputGuard[None](guard=guard, parallel=False) + out = await ig.before_model_request(run_ctx, req_ctx) + assert out is req_ctx + assert called == [] + + async def test_parallel_before_request_is_noop(self): + run_ctx, req_ctx = _build_ctx_and_req() + + called: list[str] = [] + + def guard(prompt: str) -> bool: # pragma: no cover — should not run via before_model_request + called.append(prompt) + return False + + ig = InputGuard[None](guard=guard, parallel=True) + out = await ig.before_model_request(run_ctx, req_ctx) + assert out is req_ctx + assert called == [] + + async def test_runs_once_across_tool_loop(self): + """End-to-end: guard fires once even when the model makes multiple tool calls.""" + calls: list[str] = [] + + def guard(prompt: str) -> bool: + calls.append(prompt) + return True + + # TestModel(call_tools='all') calls each tool once, then returns text — two model + # requests total. 
+ model = TestModel(call_tools='all', custom_output_text='done') + agent = Agent(model, capabilities=[InputGuard[None](guard=guard)]) + + @agent.tool_plain + def ping() -> str: # pyright: ignore[reportUnusedFunction] + return 'pong' + + result = await agent.run('hello') + assert result.output == 'done' + assert calls == ['hello'] + + async def test_sequential_skips_guard_on_subsequent_steps(self): + """After the first model request, `before_model_request` must not re-run the guard.""" + run_ctx, req_ctx = _build_ctx_and_req(run_step=2) + + called: list[str] = [] + + def guard(prompt: str) -> bool: # pragma: no cover — should not be called after step 1 + called.append(prompt) + return False + + ig = InputGuard[None](guard=guard, parallel=False) + out = await ig.before_model_request(run_ctx, req_ctx) + assert out is req_ctx + assert called == [] + + +class TestInputGuardParallel: + """Tests for `InputGuard(parallel=True)` exercising the race between guard and handler.""" + + async def test_allows_handler_to_return(self): + run_ctx, req_ctx = _build_ctx_and_req() + sentinel = ModelResponse(parts=[TextPart(content='from handler')]) + + async def handler(_: Any) -> ModelResponse: + return sentinel + + guard = InputGuard[None](guard=lambda _: True, parallel=True) + out = await guard.wrap_model_request(run_ctx, request_context=req_ctx, handler=handler) + assert out is sentinel + + async def test_trips_and_cancels_handler(self): + run_ctx, req_ctx = _build_ctx_and_req() + handler_cancelled = asyncio.Event() + handler_started = asyncio.Event() + + async def slow_handler(_: Any) -> ModelResponse: + handler_started.set() + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + handler_cancelled.set() + raise + return ModelResponse(parts=[TextPart(content='should never')]) # pragma: no cover + + async def guard(_: str) -> bool: + await handler_started.wait() + return False + + ig = InputGuard[None](guard=guard, parallel=True, block_message='blocked!') + with pytest.raises(SkipModelRequest) as exc_info: + await ig.wrap_model_request(run_ctx, request_context=req_ctx, handler=slow_handler) + + assert exc_info.value.response.parts[0] == TextPart(content='blocked!') + await asyncio.sleep(0) + assert handler_cancelled.is_set() + + async def test_guard_raises_propagates(self): + run_ctx, req_ctx = _build_ctx_and_req() + + async def slow_handler(_: Any) -> ModelResponse: + await asyncio.sleep(10) + return ModelResponse(parts=[TextPart(content='never')]) # pragma: no cover + + async def guard(_: str) -> bool: + raise InputBlocked('hard policy failure') + + ig = InputGuard[None](guard=guard, parallel=True) + with pytest.raises(InputBlocked, match='hard policy failure'): + await ig.wrap_model_request(run_ctx, request_context=req_ctx, handler=slow_handler) + + async def test_handler_finishes_before_guard(self): + """Handler completes first; guard still has to be awaited for a verdict.""" + run_ctx, req_ctx = _build_ctx_and_req() + sentinel = ModelResponse(parts=[TextPart(content='from handler')]) + release_guard = asyncio.Event() + + async def fast_handler(_: Any) -> ModelResponse: + return sentinel + + async def slow_guard(_: str) -> bool: + await release_guard.wait() + return True + + async def runner() -> ModelResponse: + ig = InputGuard[None](guard=slow_guard, parallel=True) + return await ig.wrap_model_request(run_ctx, request_context=req_ctx, handler=fast_handler) + + task = asyncio.create_task(runner()) + for _ in range(3): + await asyncio.sleep(0) + release_guard.set() + assert await task is 
sentinel + + async def test_handler_finishes_then_guard_trips(self): + """Handler returns first, then the guard trips — `SkipModelRequest` still wins.""" + run_ctx, req_ctx = _build_ctx_and_req() + release_guard = asyncio.Event() + + async def fast_handler(_: Any) -> ModelResponse: + return ModelResponse(parts=[TextPart(content='from handler')]) + + async def slow_guard(_: str) -> bool: + await release_guard.wait() + return False + + async def runner() -> ModelResponse: + ig = InputGuard[None](guard=slow_guard, parallel=True, block_message='late trip') + return await ig.wrap_model_request(run_ctx, request_context=req_ctx, handler=fast_handler) + + task = asyncio.create_task(runner()) + for _ in range(3): + await asyncio.sleep(0) + release_guard.set() + with pytest.raises(SkipModelRequest) as exc_info: + await task + assert exc_info.value.response.parts[0] == TextPart(content='late trip') + + async def test_handler_raises_while_guard_runs(self): + """When the handler raises, `finally` cancels the still-running guard.""" + run_ctx, req_ctx = _build_ctx_and_req() + guard_cancelled = asyncio.Event() + + async def failing_handler(_: Any) -> ModelResponse: + raise RuntimeError('model boom') + + async def slow_guard(_: str) -> bool: + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + guard_cancelled.set() + raise + return True # pragma: no cover + + ig = InputGuard[None](guard=slow_guard, parallel=True) + with pytest.raises(RuntimeError, match='model boom'): + await ig.wrap_model_request(run_ctx, request_context=req_ctx, handler=failing_handler) + await asyncio.sleep(0) + assert guard_cancelled.is_set() + + async def test_skipped_when_prompt_missing(self): + run_ctx, req_ctx = _build_ctx_and_req(prompt=None) + sentinel = ModelResponse(parts=[TextPart(content='direct')]) + + called: list[str] = [] + + def guard(prompt: str) -> bool: # pragma: no cover — should never be called + called.append(prompt) + return False + + async def handler(_: Any) -> ModelResponse: + return sentinel + + ig = InputGuard[None](guard=guard, parallel=True) + out = await ig.wrap_model_request(run_ctx, request_context=req_ctx, handler=handler) + assert out is sentinel + assert called == [] + + async def test_skips_guard_on_subsequent_steps(self): + """`wrap_model_request` must pass the handler through without running the guard past step 1.""" + run_ctx, req_ctx = _build_ctx_and_req(run_step=2) + sentinel = ModelResponse(parts=[TextPart(content='direct')]) + called: list[str] = [] + + def guard(prompt: str) -> bool: # pragma: no cover — should not be called after step 1 + called.append(prompt) + return False + + async def handler(_: Any) -> ModelResponse: + return sentinel + + ig = InputGuard[None](guard=guard, parallel=True) + out = await ig.wrap_model_request(run_ctx, request_context=req_ctx, handler=handler) + assert out is sentinel + assert called == [] + + async def test_no_dangling_tasks_when_handler_raises(self): + """`finally` must drain cancelled tasks so they don't outlive the call. + + `task.cancel()` only schedules a `CancelledError` for the next loop tick — without + awaiting, the task stays in `asyncio.all_tasks()` after `wrap_model_request` + returns, leaks into the surrounding scope, and produces "Task exception was never + retrieved" warnings if it later raises. 
+ """ + run_ctx, req_ctx = _build_ctx_and_req() + + async def failing_handler(_: Any) -> ModelResponse: + raise RuntimeError('handler boom') + + async def slow_guard(_: str) -> bool: + await asyncio.sleep(10) + return True # pragma: no cover + + current = asyncio.current_task() + before = {t for t in asyncio.all_tasks() if t is not current} + + ig = InputGuard[None](guard=slow_guard, parallel=True) + with pytest.raises(RuntimeError, match='handler boom'): + await ig.wrap_model_request(run_ctx, request_context=req_ctx, handler=failing_handler) + + leftover = {t for t in asyncio.all_tasks() if t is not current} - before + assert leftover == set(), f'guard/handler tasks must be drained, got dangling: {leftover}' + + +class TestExtractPrompt: + """Unit tests for the `_extract_prompt` helper.""" + + def test_from_messages(self): + """Extraction falls back to the most recent `UserPromptPart`.""" + + class _Ctx: + prompt = None + + messages: list[ModelMessage] = [ + ModelRequest(parts=[UserPromptPart(content='first')]), + ModelResponse(parts=[TextPart(content='assistant')]), + ModelRequest(parts=[UserPromptPart(content='second')]), + ] + assert _extract_prompt(_Ctx(), messages) == 'second' # pyright: ignore[reportArgumentType] + + def test_stringifies_non_str_prompt(self): + class _Ctx: + prompt = ['multimodal', 'content'] + + assert _extract_prompt(_Ctx(), []) == str(['multimodal', 'content']) # pyright: ignore[reportArgumentType] + + def test_stringifies_non_str_message_part(self): + class _Ctx: + prompt = None + + messages: list[ModelMessage] = [ModelRequest(parts=[UserPromptPart(content=['multi'])])] + assert _extract_prompt(_Ctx(), messages) == str(['multi']) # pyright: ignore[reportArgumentType] + + def test_returns_none_when_no_user_prompt_part(self): + """A history containing only model responses yields `None`.""" + + class _Ctx: + prompt = None + + messages: list[ModelMessage] = [ModelResponse(parts=[TextPart(content='only model parts here')])] + assert _extract_prompt(_Ctx(), messages) is None # pyright: ignore[reportArgumentType] diff --git a/tests/_guardrails/test_output_guard.py b/tests/_guardrails/test_output_guard.py new file mode 100644 index 0000000..993fe82 --- /dev/null +++ b/tests/_guardrails/test_output_guard.py @@ -0,0 +1,105 @@ +"""Tests for the `OutputGuard` capability.""" + +from __future__ import annotations + +import asyncio + +import pytest +from pydantic import BaseModel +from pydantic_ai import Agent +from pydantic_ai.models.test import TestModel + +from pydantic_ai_harness import OutputBlocked, OutputGuard +from pydantic_ai_harness.guardrails import GuardrailError + +pytestmark = pytest.mark.anyio + + +@pytest.fixture +def anyio_backend() -> str: + return 'asyncio' + + +class TestOutputGuard: + """Integration tests for the `OutputGuard` capability driven through `Agent.run`.""" + + async def test_allows_safe_output(self): + agent = Agent( + TestModel(custom_output_text='harmless reply'), + capabilities=[OutputGuard[None](guard=lambda out: 'SSN' not in str(out))], + ) + result = await agent.run('hello') + assert result.output == 'harmless reply' + + async def test_blocks_unsafe_output(self): + agent = Agent( + TestModel(custom_output_text='leaks SSN 123-45-6789'), + capabilities=[ + OutputGuard[None](guard=lambda out: 'SSN' not in str(out), block_message='contains SSN'), + ], + ) + with pytest.raises(OutputBlocked, match='contains SSN'): + await agent.run('hello') + + async def test_async_guard_awaited(self): + async def guard(output: object) -> bool: + await 
asyncio.sleep(0) + return 'bad' not in str(output) + + agent = Agent( + TestModel(custom_output_text='ok reply'), + capabilities=[OutputGuard[None](guard=guard)], + ) + assert (await agent.run('prompt')).output == 'ok reply' + + agent_bad = Agent( + TestModel(custom_output_text='bad reply'), + capabilities=[OutputGuard[None](guard=guard)], + ) + with pytest.raises(OutputBlocked): + await agent_bad.run('prompt') + + async def test_raising_propagates(self): + def guard(_: object) -> bool: + raise RuntimeError('guard exploded') + + agent = Agent( + TestModel(custom_output_text='anything'), + capabilities=[OutputGuard[None](guard=guard)], + ) + with pytest.raises(RuntimeError, match='guard exploded'): + await agent.run('hello') + + async def test_receives_structured_output_unchanged(self): + """For typed outputs the guard gets the model instance, not a stringified form.""" + + class Answer(BaseModel): + reply: str + internal_url: str + + seen: list[object] = [] + + def guard(output: object) -> bool: + seen.append(output) + assert isinstance(output, Answer) + return 'internal.example.com' not in output.internal_url + + agent = Agent( + TestModel(custom_output_args={'reply': 'hi', 'internal_url': 'https://public.example.com/x'}), + output_type=Answer, + capabilities=[OutputGuard[None](guard=guard)], + ) + result = await agent.run('hello') + assert isinstance(result.output, Answer) + assert seen == [result.output] + + agent_bad = Agent( + TestModel(custom_output_args={'reply': 'hi', 'internal_url': 'https://internal.example.com/x'}), + output_type=Answer, + capabilities=[OutputGuard[None](guard=guard)], + ) + with pytest.raises(OutputBlocked): + await agent_bad.run('hello') + + def test_output_blocked_is_guardrail_error(self): + assert issubclass(OutputBlocked, GuardrailError)
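+
+    async def test_guard_can_match_serialized_json(self):
+        """Illustrative sketch of the README's `model_dump_json()` suggestion.
+
+        The guard serializes the typed output to JSON text and runs a plain
+        substring check against it; the `Report` model and its fields are
+        purely illustrative.
+        """
+
+        class Report(BaseModel):
+            reply: str
+            source_url: str
+
+        def guard(output: object) -> bool:
+            assert isinstance(output, Report)
+            return 'internal.example.com' not in output.model_dump_json()
+
+        agent_ok = Agent(
+            TestModel(custom_output_args={'reply': 'hi', 'source_url': 'https://public.example.com/x'}),
+            output_type=Report,
+            capabilities=[OutputGuard[None](guard=guard)],
+        )
+        assert isinstance((await agent_ok.run('hello')).output, Report)
+
+        agent_bad = Agent(
+            TestModel(custom_output_args={'reply': 'hi', 'source_url': 'https://internal.example.com/x'}),
+            output_type=Report,
+            capabilities=[OutputGuard[None](guard=guard)],
+        )
+        with pytest.raises(OutputBlocked):
+            await agent_bad.run('hello')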