Commit 3bd5480

Merge branch 'main' of github.com:openai/openai-guardrails-python into dev/steven/safety_header
2 parents: 1da2ecc + b2d7a81

22 files changed: +3805 additions, −3 deletions

pyproject.toml

Lines changed: 17 additions & 0 deletions
@@ -58,6 +58,7 @@ dev = [
     "pymdown-extensions>=10.0.0",
     "coverage>=7.8.0",
     "hypothesis>=6.131.20",
+    "pytest-cov>=6.3.0",
 ]
 
 [tool.uv.workspace]
@@ -103,8 +104,24 @@ convention = "google"
 [tool.ruff.format]
 docstring-code-format = true
 
+[tool.coverage.run]
+source = ["guardrails"]
+omit = [
+    "src/guardrails/evals/*",
+]
+
 [tool.mypy]
 strict = true
 disallow_incomplete_defs = false
 disallow_untyped_defs = false
 disallow_untyped_calls = false
+exclude = [
+    "examples",
+    "src/guardrails/evals",
+]
+
+[tool.pyright]
+ignore = [
+    "examples",
+    "src/guardrails/evals",
+]

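For orientation, the new [tool.coverage.run] table is standard coverage.py configuration, which pytest-cov reads when coverage is requested. A minimal sketch of the equivalent programmatic setup using the upstream coverage.py API (this script is illustrative, not part of the commit):

import coverage

# Mirrors the [tool.coverage.run] table above: measure the guardrails package
# and skip the evals tree.
cov = coverage.Coverage(source=["guardrails"], omit=["src/guardrails/evals/*"])
cov.start()
import guardrails  # code imported/executed while coverage is active gets measured
cov.stop()
cov.report()
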
tests/conftest.py

Lines changed: 129 additions & 0 deletions
@@ -0,0 +1,129 @@
"""Shared pytest fixtures for guardrails tests.

These fixtures provide deterministic test environments by stubbing the OpenAI
client library, seeding environment variables, and preventing accidental live
network activity during the suite.
"""

from __future__ import annotations

import logging
import sys
import types
from collections.abc import Iterator
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Any

import pytest


class _StubOpenAIBase:
    """Base stub with attribute bag behaviour for OpenAI client classes."""

    def __init__(self, **kwargs: Any) -> None:
        self._client_kwargs = kwargs
        self.chat = SimpleNamespace()
        self.responses = SimpleNamespace()
        self.api_key = kwargs.get("api_key", "test-key")
        self.base_url = kwargs.get("base_url")
        self.organization = kwargs.get("organization")
        self.timeout = kwargs.get("timeout")
        self.max_retries = kwargs.get("max_retries")

    def __getattr__(self, item: str) -> Any:
        """Return None for unknown attributes to emulate real client laziness."""
        return None


class _StubAsyncOpenAI(_StubOpenAIBase):
    """Stub asynchronous OpenAI client."""


class _StubSyncOpenAI(_StubOpenAIBase):
    """Stub synchronous OpenAI client."""


@dataclass(frozen=True, slots=True)
class _DummyResponse:
    """Minimal response type with choices and output."""

    choices: list[Any] | None = None
    output: list[Any] | None = None
    output_text: str | None = None
    type: str | None = None
    delta: str | None = None


_STUB_OPENAI_MODULE = types.ModuleType("openai")
_STUB_OPENAI_MODULE.AsyncOpenAI = _StubAsyncOpenAI
_STUB_OPENAI_MODULE.OpenAI = _StubSyncOpenAI
_STUB_OPENAI_MODULE.AsyncAzureOpenAI = _StubAsyncOpenAI
_STUB_OPENAI_MODULE.AzureOpenAI = _StubSyncOpenAI
_STUB_OPENAI_MODULE.NOT_GIVEN = object()


class APITimeoutError(Exception):
    """Stub API timeout error."""


_STUB_OPENAI_MODULE.APITimeoutError = APITimeoutError

_OPENAI_TYPES_MODULE = types.ModuleType("openai.types")
_OPENAI_TYPES_MODULE.Completion = _DummyResponse
_OPENAI_TYPES_MODULE.Response = _DummyResponse

_OPENAI_CHAT_MODULE = types.ModuleType("openai.types.chat")
_OPENAI_CHAT_MODULE.ChatCompletion = _DummyResponse
_OPENAI_CHAT_MODULE.ChatCompletionChunk = _DummyResponse

_OPENAI_RESPONSES_MODULE = types.ModuleType("openai.types.responses")
_OPENAI_RESPONSES_MODULE.Response = _DummyResponse
_OPENAI_RESPONSES_MODULE.ResponseInputItemParam = dict  # type: ignore[attr-defined]
_OPENAI_RESPONSES_MODULE.ResponseOutputItem = dict  # type: ignore[attr-defined]
_OPENAI_RESPONSES_MODULE.ResponseStreamEvent = dict  # type: ignore[attr-defined]


_OPENAI_RESPONSES_RESPONSE_MODULE = types.ModuleType("openai.types.responses.response")
_OPENAI_RESPONSES_RESPONSE_MODULE.Response = _DummyResponse


class _ResponseTextConfigParam(dict):
    """Stub config param used for response formatting."""


_OPENAI_RESPONSES_MODULE.ResponseTextConfigParam = _ResponseTextConfigParam

sys.modules["openai"] = _STUB_OPENAI_MODULE
sys.modules["openai.types"] = _OPENAI_TYPES_MODULE
sys.modules["openai.types.chat"] = _OPENAI_CHAT_MODULE
sys.modules["openai.types.responses"] = _OPENAI_RESPONSES_MODULE
sys.modules["openai.types.responses.response"] = _OPENAI_RESPONSES_RESPONSE_MODULE


@pytest.fixture(autouse=True)
def stub_openai_module(monkeypatch: pytest.MonkeyPatch) -> Iterator[types.ModuleType]:
    """Provide stub OpenAI module so tests avoid real network-bound clients."""
    # Patch imported symbols in guardrails modules
    from guardrails import _base_client, client, types as guardrail_types  # type: ignore

    monkeypatch.setattr(_base_client, "AsyncOpenAI", _StubAsyncOpenAI, raising=False)
    monkeypatch.setattr(_base_client, "OpenAI", _StubSyncOpenAI, raising=False)
    monkeypatch.setattr(client, "AsyncOpenAI", _StubAsyncOpenAI, raising=False)
    monkeypatch.setattr(client, "OpenAI", _StubSyncOpenAI, raising=False)
    monkeypatch.setattr(client, "AsyncAzureOpenAI", _StubAsyncOpenAI, raising=False)
    monkeypatch.setattr(client, "AzureOpenAI", _StubSyncOpenAI, raising=False)
    monkeypatch.setattr(guardrail_types, "AsyncOpenAI", _StubAsyncOpenAI, raising=False)
    monkeypatch.setattr(guardrail_types, "OpenAI", _StubSyncOpenAI, raising=False)
    monkeypatch.setattr(guardrail_types, "AsyncAzureOpenAI", _StubAsyncOpenAI, raising=False)
    monkeypatch.setattr(guardrail_types, "AzureOpenAI", _StubSyncOpenAI, raising=False)

    monkeypatch.setenv("OPENAI_API_KEY", "test-key")

    yield _STUB_OPENAI_MODULE


@pytest.fixture(autouse=True)
def configure_logging() -> None:
    """Ensure logging defaults to DEBUG for deterministic assertions."""
    logging.basicConfig(level=logging.DEBUG)

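A minimal sketch (a hypothetical test file, not one added by this commit) of what these autouse fixtures give every test for free: `import openai` resolves to the injected stub module, client construction stays offline, and OPENAI_API_KEY is pre-seeded:

import os

import openai  # resolves to _STUB_OPENAI_MODULE via the sys.modules injection above


def test_stub_client_never_touches_the_network() -> None:
    client = openai.AsyncOpenAI(api_key="anything")
    # The stub only records constructor kwargs instead of opening connections.
    assert client.api_key == "anything"
    assert os.environ["OPENAI_API_KEY"] == "test-key"
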
tests/unit/checks/test_keywords.py

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
"""Tests for keyword-based guardrail helpers."""

from __future__ import annotations

import pytest
from pydantic import ValidationError

from guardrails.checks.text.competitors import CompetitorCfg, competitors
from guardrails.checks.text.keywords import KeywordCfg, keywords, match_keywords
from guardrails.types import GuardrailResult


def test_match_keywords_sanitizes_trailing_punctuation() -> None:
    """Ensure keyword sanitization strips trailing punctuation before matching."""
    config = KeywordCfg(keywords=["token.", "secret!", "KEY?"])
    result = match_keywords("Leaked token appears here.", config, guardrail_name="Test Guardrail")

    assert result.tripwire_triggered is True  # noqa: S101
    assert result.info["sanitized_keywords"] == ["token", "secret", "KEY"]  # noqa: S101
    assert result.info["matched"] == ["token"]  # noqa: S101
    assert result.info["guardrail_name"] == "Test Guardrail"  # noqa: S101
    assert result.info["checked_text"] == "Leaked token appears here."  # noqa: S101


def test_match_keywords_deduplicates_case_insensitive_matches() -> None:
    """Repeated matches differing by case should be deduplicated."""
    config = KeywordCfg(keywords=["Alert"])
    result = match_keywords("alert ALERT Alert", config, guardrail_name="Keyword Filter")

    assert result.tripwire_triggered is True  # noqa: S101
    assert result.info["matched"] == ["alert"]  # noqa: S101


@pytest.mark.asyncio
async def test_keywords_guardrail_wraps_match_keywords() -> None:
    """Async guardrail should mirror match_keywords behaviour."""
    config = KeywordCfg(keywords=["breach"])
    result = await keywords(ctx=None, data="Potential breach detected", config=config)

    assert isinstance(result, GuardrailResult)  # noqa: S101
    assert result.tripwire_triggered is True  # noqa: S101
    assert result.info["guardrail_name"] == "Keyword Filter"  # noqa: S101


@pytest.mark.asyncio
async def test_competitors_uses_keyword_matching() -> None:
    """Competitors guardrail delegates to keyword matching with a distinct name."""
    config = CompetitorCfg(keywords=["ACME Corp"])
    result = await competitors(ctx=None, data="Comparing against ACME Corp today", config=config)

    assert result.tripwire_triggered is True  # noqa: S101
    assert result.info["guardrail_name"] == "Competitors"  # noqa: S101
    assert result.info["matched"] == ["ACME Corp"]  # noqa: S101


def test_keyword_cfg_requires_non_empty_keywords() -> None:
    """KeywordCfg should enforce at least one keyword."""
    with pytest.raises(ValidationError):
        KeywordCfg(keywords=[])


@pytest.mark.asyncio
async def test_keywords_does_not_trigger_on_benign_text() -> None:
    """Guardrail should not trigger when no keywords are present."""
    config = KeywordCfg(keywords=["restricted"])
    result = await keywords(ctx=None, data="Safe content", config=config)

    assert result.tripwire_triggered is False  # noqa: S101

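Based only on the call patterns exercised above, a usage sketch of the keyword-matching API under test; the keyword list, text, and "Demo Filter" name are illustrative:

from guardrails.checks.text.keywords import KeywordCfg, match_keywords

config = KeywordCfg(keywords=["token.", "secret!"])  # trailing punctuation is stripped before matching
result = match_keywords("Leaked token appears here.", config, guardrail_name="Demo Filter")

if result.tripwire_triggered:
    print(result.info["matched"])  # e.g. ["token"], deduplicated case-insensitively
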
tests/unit/checks/test_llm_base.py

Lines changed: 158 additions & 0 deletions
@@ -0,0 +1,158 @@
"""Tests for LLM-based guardrail helpers."""

from __future__ import annotations

from types import SimpleNamespace
from typing import Any

import pytest

from guardrails.checks.text import llm_base
from guardrails.checks.text.llm_base import (
    LLMConfig,
    LLMErrorOutput,
    LLMOutput,
    _build_full_prompt,
    _strip_json_code_fence,
    create_llm_check_fn,
    run_llm,
)
from guardrails.types import GuardrailResult


class _FakeCompletions:
    def __init__(self, content: str | None) -> None:
        self._content = content

    async def create(self, **kwargs: Any) -> Any:
        _ = kwargs
        return SimpleNamespace(choices=[SimpleNamespace(message=SimpleNamespace(content=self._content))])


class _FakeAsyncClient:
    def __init__(self, content: str | None) -> None:
        self.chat = SimpleNamespace(completions=_FakeCompletions(content))


def test_strip_json_code_fence_removes_wrapping() -> None:
    """Valid JSON code fences should be removed."""
    fenced = """```json
{"flagged": false, "confidence": 0.2}
```"""
    assert _strip_json_code_fence(fenced) == '{"flagged": false, "confidence": 0.2}'  # noqa: S101


def test_build_full_prompt_includes_instructions() -> None:
    """Generated prompt should embed system instructions and schema guidance."""
    prompt = _build_full_prompt("Analyze text")
    assert "Analyze text" in prompt  # noqa: S101
    assert "Respond with a json object" in prompt  # noqa: S101


@pytest.mark.asyncio
async def test_run_llm_returns_valid_output() -> None:
    """run_llm should parse the JSON response into the provided output model."""
    client = _FakeAsyncClient('{"flagged": true, "confidence": 0.9}')
    result = await run_llm(
        text="Sensitive text",
        system_prompt="Detect problems.",
        client=client,  # type: ignore[arg-type]
        model="gpt-test",
        output_model=LLMOutput,
    )
    assert isinstance(result, LLMOutput)  # noqa: S101
    assert result.flagged is True and result.confidence == 0.9  # noqa: S101


@pytest.mark.asyncio
async def test_run_llm_handles_content_filter_error(monkeypatch: pytest.MonkeyPatch) -> None:
    """Content filter errors should return LLMErrorOutput with flagged=True."""

    class _FailingClient:
        class _Chat:
            class _Completions:
                async def create(self, **kwargs: Any) -> Any:
                    raise RuntimeError("content_filter triggered by provider")

            completions = _Completions()

        chat = _Chat()

    result = await run_llm(
        text="Sensitive",
        system_prompt="Detect.",
        client=_FailingClient(),  # type: ignore[arg-type]
        model="gpt-test",
        output_model=LLMOutput,
    )

    assert isinstance(result, LLMErrorOutput)  # noqa: S101
    assert result.flagged is True  # noqa: S101
    assert result.info["third_party_filter"] is True  # noqa: S101


@pytest.mark.asyncio
async def test_create_llm_check_fn_triggers_on_confident_flag(monkeypatch: pytest.MonkeyPatch) -> None:
    """Generated guardrail function should trip when confidence exceeds the threshold."""

    async def fake_run_llm(
        text: str,
        system_prompt: str,
        client: Any,
        model: str,
        output_model: type[LLMOutput],
    ) -> LLMOutput:
        assert system_prompt == "Check with details"  # noqa: S101
        return LLMOutput(flagged=True, confidence=0.95)

    monkeypatch.setattr(llm_base, "run_llm", fake_run_llm)

    class DetailedConfig(LLMConfig):
        system_prompt_details: str = "details"

    guardrail_fn = create_llm_check_fn(
        name="HighConfidence",
        description="Test guardrail",
        system_prompt="Check with {system_prompt_details}",
        output_model=LLMOutput,
        config_model=DetailedConfig,
    )

    config = DetailedConfig(model="gpt-test", confidence_threshold=0.9)
    context = SimpleNamespace(guardrail_llm="fake-client")

    result = await guardrail_fn(context, "content", config)

    assert isinstance(result, GuardrailResult)  # noqa: S101
    assert result.tripwire_triggered is True  # noqa: S101
    assert result.info["threshold"] == 0.9  # noqa: S101


@pytest.mark.asyncio
async def test_create_llm_check_fn_handles_llm_error(monkeypatch: pytest.MonkeyPatch) -> None:
    """LLM error results should mark execution_failed without triggering the tripwire."""

    async def fake_run_llm(
        text: str,
        system_prompt: str,
        client: Any,
        model: str,
        output_model: type[LLMOutput],
    ) -> LLMErrorOutput:
        return LLMErrorOutput(flagged=False, confidence=0.0, info={"error_message": "timeout"})

    monkeypatch.setattr(llm_base, "run_llm", fake_run_llm)

    guardrail_fn = create_llm_check_fn(
        name="Resilient",
        description="Test guardrail",
        system_prompt="Prompt",
    )

    config = LLMConfig(model="gpt-test", confidence_threshold=0.5)
    context = SimpleNamespace(guardrail_llm="fake-client")
    result = await guardrail_fn(context, "text", config)

    assert result.tripwire_triggered is False  # noqa: S101
    assert result.execution_failed is True  # noqa: S101
    assert "timeout" in str(result.original_exception)  # noqa: S101

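Finally, a sketch of how a custom LLM-backed check could be assembled with create_llm_check_fn, using only the arguments these tests exercise; the "Profanity" name, prompt text, and client placeholder are illustrative assumptions, not code from the repo:

from types import SimpleNamespace

from guardrails.checks.text.llm_base import LLMConfig, LLMOutput, create_llm_check_fn

# Build an async guardrail function from a name, description, and system prompt.
profanity_check = create_llm_check_fn(
    name="Profanity",
    description="Flags profane or abusive wording",
    system_prompt="Flag profanity or abuse in the provided text.",
    output_model=LLMOutput,
)

config = LLMConfig(model="gpt-test", confidence_threshold=0.7)
context = SimpleNamespace(guardrail_llm=None)  # replace None with a real async OpenAI client

# Inside an async caller:
#     result = await profanity_check(context, "text to screen", config)
# result.tripwire_triggered should be True only when the model flags the text
# with confidence that clears the configured threshold.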