Skip to content

Commit e86f555

Browse files
committed
Add GuardrailAgent prompt param support
1 parent 12c4add commit e86f555

File tree

2 files changed

+91
-4
lines changed

2 files changed

+91
-4
lines changed

src/guardrails/agents.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,34 @@ async def stage_guardrail(ctx: RunContextWrapper[None], agent: Agent, input_data
451451
return guardrail_functions
452452

453453

454+
def _resolve_agent_instructions(instructions: str | None, prompt: Any | None) -> str | None:
455+
"""Derive instructions from explicit input or prompt.
456+
457+
Args:
458+
instructions: Explicit instructions provided by the caller.
459+
prompt: Optional prompt object or string supplied to the agent.
460+
461+
Returns:
462+
A string containing the agent instructions when available, otherwise ``None``.
463+
"""
464+
465+
if instructions is not None:
466+
return instructions
467+
468+
if prompt is None:
469+
return None
470+
471+
if isinstance(prompt, str):
472+
return prompt
473+
474+
for attr_name in ("instructions", "text", "content"):
475+
candidate = getattr(prompt, attr_name, None)
476+
if isinstance(candidate, str):
477+
return candidate
478+
479+
return None
480+
481+
454482
class GuardrailAgent:
455483
"""Drop-in replacement for Agents SDK Agent with automatic guardrails integration.
456484
@@ -492,7 +520,7 @@ def __new__(
492520
cls,
493521
config: str | Path | dict[str, Any],
494522
name: str,
495-
instructions: str,
523+
instructions: str | None = None,
496524
raise_guardrail_errors: bool = False,
497525
block_on_tool_violations: bool = False,
498526
**agent_kwargs: Any,
@@ -511,7 +539,7 @@ def __new__(
511539
Args:
512540
config: Pipeline configuration (file path, dict, or JSON string)
513541
name: Agent name
514-
instructions: Agent instructions
542+
instructions: Agent instructions. When omitted, a ``prompt`` argument must be provided.
515543
raise_guardrail_errors: If True, raise exceptions when guardrails fail to execute.
516544
If False (default), treat guardrail errors as safe and continue execution.
517545
block_on_tool_violations: If True, tool guardrail violations raise exceptions (halt execution).
@@ -614,5 +642,19 @@ def __new__(
614642
)
615643
_attach_guardrail_to_tools(tools, tool_output_gr, "output")
616644

645+
prompt_arg: Any | None = agent_kwargs.get("prompt")
646+
resolved_instructions = _resolve_agent_instructions(instructions, prompt_arg)
647+
648+
if resolved_instructions is None and prompt_arg is None:
649+
raise ValueError(
650+
"GuardrailAgent requires either 'instructions' or 'prompt' to initialize the underlying Agent."
651+
)
652+
617653
# Create and return a regular Agent instance with guardrails configured
618-
return Agent(name=name, instructions=instructions, input_guardrails=input_guardrails, output_guardrails=output_guardrails, **agent_kwargs)
654+
return Agent(
655+
name=name,
656+
instructions=resolved_instructions,
657+
input_guardrails=input_guardrails,
658+
output_guardrails=output_guardrails,
659+
**agent_kwargs,
660+
)

tests/unit/test_agents.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,16 @@
66
import types
77
from collections.abc import Callable
88
from dataclasses import dataclass
9+
from pathlib import Path
910
from types import SimpleNamespace
1011
from typing import Any
1112

1213
import pytest
1314

15+
guardrails_pkg = types.ModuleType("guardrails")
16+
guardrails_pkg.__path__ = [str(Path(__file__).resolve().parents[2] / "src" / "guardrails")]
17+
sys.modules.setdefault("guardrails", guardrails_pkg)
18+
1419
from guardrails._openai_utils import SAFETY_IDENTIFIER_HEADER, SAFETY_IDENTIFIER_VALUE
1520
from guardrails.types import GuardrailResult
1621

@@ -94,7 +99,8 @@ class Agent:
9499
"""Trivial Agent stub storing initialization args for assertions."""
95100

96101
name: str
97-
instructions: str
102+
instructions: str | None = None
103+
prompt: Any | None = None
98104
input_guardrails: list[Callable] | None = None
99105
output_guardrails: list[Callable] | None = None
100106
tools: list[Any] | None = None
@@ -597,3 +603,42 @@ def test_guardrail_agent_without_tools(monkeypatch: pytest.MonkeyPatch) -> None:
597603
agent_instance = agents.GuardrailAgent(config={}, name="NoTools", instructions="None")
598604

599605
assert getattr(agent_instance, "input_guardrails", []) == [] # noqa: S101
606+
607+
608+
def test_guardrail_agent_allows_prompt_without_instructions(monkeypatch: pytest.MonkeyPatch) -> None:
609+
"""Prompt attribute text becomes derived instructions."""
610+
pipeline = SimpleNamespace(pre_flight=None, input=None, output=None)
611+
612+
monkeypatch.setattr(runtime_module, "load_pipeline_bundles", lambda config: pipeline, raising=False)
613+
monkeypatch.setattr(runtime_module, "instantiate_guardrails", lambda *args, **kwargs: [], raising=False)
614+
615+
prompt = SimpleNamespace(text="Serve customers helpfully.")
616+
617+
agent_instance = agents.GuardrailAgent(config={}, name="PromptOnly", prompt=prompt)
618+
619+
assert agent_instance.prompt is prompt # noqa: S101
620+
assert agent_instance.instructions == "Serve customers helpfully." # noqa: S101
621+
622+
623+
def test_guardrail_agent_accepts_string_prompt(monkeypatch: pytest.MonkeyPatch) -> None:
624+
"""String prompts populate missing instructions automatically."""
625+
pipeline = SimpleNamespace(pre_flight=None, input=None, output=None)
626+
627+
monkeypatch.setattr(runtime_module, "load_pipeline_bundles", lambda config: pipeline, raising=False)
628+
monkeypatch.setattr(runtime_module, "instantiate_guardrails", lambda *args, **kwargs: [], raising=False)
629+
630+
agent_instance = agents.GuardrailAgent(config={}, name="PromptStr", prompt="Be concise.")
631+
632+
assert agent_instance.prompt == "Be concise." # noqa: S101
633+
assert agent_instance.instructions == "Be concise." # noqa: S101
634+
635+
636+
def test_guardrail_agent_requires_instructions_or_prompt(monkeypatch: pytest.MonkeyPatch) -> None:
637+
"""GuardrailAgent requires instructions or prompt for construction."""
638+
pipeline = SimpleNamespace(pre_flight=None, input=None, output=None)
639+
640+
monkeypatch.setattr(runtime_module, "load_pipeline_bundles", lambda config: pipeline, raising=False)
641+
monkeypatch.setattr(runtime_module, "instantiate_guardrails", lambda *args, **kwargs: [], raising=False)
642+
643+
with pytest.raises(ValueError):
644+
agents.GuardrailAgent(config={}, name="Missing")

0 commit comments

Comments
 (0)