
Commit 514c197
fixing safety identifier
1 parent 45e958c

13 files changed: +259 -122 lines changed

src/guardrails/_openai_utils.py

Lines changed: 0 additions & 25 deletions
This file was deleted.

src/guardrails/agents.py

Lines changed: 2 additions & 3 deletions
@@ -18,7 +18,6 @@
 from pathlib import Path
 from typing import Any

-from ._openai_utils import prepare_openai_kwargs
 from .utils.conversation import merge_conversation_with_items, normalize_conversation

 logger = logging.getLogger(__name__)
@@ -167,7 +166,7 @@ def _create_default_tool_context() -> Any:
     class DefaultContext:
         guardrail_llm: AsyncOpenAI

-    return DefaultContext(guardrail_llm=AsyncOpenAI(**prepare_openai_kwargs({})))
+    return DefaultContext(guardrail_llm=AsyncOpenAI())


 def _create_conversation_context(
@@ -393,7 +392,7 @@ def _create_agents_guardrails_from_config(
     class DefaultContext:
         guardrail_llm: AsyncOpenAI

-    context = DefaultContext(guardrail_llm=AsyncOpenAI(**prepare_openai_kwargs({})))
+    context = DefaultContext(guardrail_llm=AsyncOpenAI())

     def _create_stage_guardrail(stage_name: str):
         async def stage_guardrail(ctx: RunContextWrapper[None], agent: Agent, input_data: str) -> GuardrailFunctionOutput:

src/guardrails/checks/text/llm_base.py

Lines changed: 43 additions & 6 deletions
@@ -48,6 +48,37 @@ class MyLLMOutput(LLMOutput):
 from guardrails.types import CheckFn, GuardrailLLMContextProto, GuardrailResult
 from guardrails.utils.output import OutputSchema

+# OpenAI safety identifier for tracking guardrails library usage
+# Only supported by official OpenAI API (not Azure or local/alternative providers)
+_SAFETY_IDENTIFIER = "oai_guardrails"
+
+
+def _supports_safety_identifier(client: AsyncOpenAI | OpenAI | AsyncAzureOpenAI | AzureOpenAI) -> bool:
+    """Check if the client supports the safety_identifier parameter.
+
+    Only the official OpenAI API supports this parameter.
+    Azure OpenAI and local/alternative providers do not.
+
+    Args:
+        client: The OpenAI client instance.
+
+    Returns:
+        True if safety_identifier should be included, False otherwise.
+    """
+    # Azure clients don't support it
+    if isinstance(client, AsyncAzureOpenAI | AzureOpenAI):
+        return False
+
+    # Check if using a custom base_url (local or alternative provider)
+    base_url = getattr(client, "base_url", None)
+    if base_url is not None:
+        base_url_str = str(base_url)
+        # Only official OpenAI API endpoints support safety_identifier
+        return "api.openai.com" in base_url_str
+
+    # Default OpenAI client (no custom base_url) supports it
+    return True
+
 if TYPE_CHECKING:
     from openai import AsyncAzureOpenAI, AzureOpenAI  # type: ignore[unused-import]
 else:
@@ -247,12 +278,18 @@ async def _request_chat_completion(
     response_format: dict[str, Any],
 ) -> Any:
     """Invoke chat.completions.create on sync or async OpenAI clients."""
-    return await _invoke_openai_callable(
-        client.chat.completions.create,
-        messages=messages,
-        model=model,
-        response_format=response_format,
-    )
+    # Only include safety_identifier for official OpenAI API
+    kwargs: dict[str, Any] = {
+        "messages": messages,
+        "model": model,
+        "response_format": response_format,
+    }
+
+    # Only official OpenAI API supports safety_identifier (not Azure or local models)
+    if _supports_safety_identifier(client):
+        kwargs["safety_identifier"] = _SAFETY_IDENTIFIER
+
+    return await _invoke_openai_callable(client.chat.completions.create, **kwargs)


 async def run_llm(
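
For orientation, a minimal sketch (not part of this commit) of how the new helper classifies clients. The import path is inferred from the file layout above, and the keys and endpoints are dummies used only to construct clients; no requests are made.

from openai import AsyncAzureOpenAI, AsyncOpenAI

# Assumed import path, based on src/guardrails/checks/text/llm_base.py above.
from guardrails.checks.text.llm_base import _supports_safety_identifier

official = AsyncOpenAI(api_key="sk-dummy")  # default base_url: https://api.openai.com/v1
local = AsyncOpenAI(api_key="ollama", base_url="http://localhost:11434/v1")
azure = AsyncAzureOpenAI(
    api_key="dummy",
    api_version="2024-06-01",  # placeholder version string
    azure_endpoint="https://example.openai.azure.com",
)

print(_supports_safety_identifier(official))  # True: base_url contains api.openai.com
print(_supports_safety_identifier(local))     # False: custom base_url
print(_supports_safety_identifier(azure))     # False: Azure client class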

src/guardrails/checks/text/moderation.py

Lines changed: 1 addition & 3 deletions
@@ -39,8 +39,6 @@
 from guardrails.spec import GuardrailSpecMetadata
 from guardrails.types import GuardrailResult

-from ..._openai_utils import prepare_openai_kwargs
-
 logger = logging.getLogger(__name__)

 __all__ = ["moderation", "Category", "ModerationCfg"]
@@ -129,7 +127,7 @@ def _get_moderation_client() -> AsyncOpenAI:
     Returns:
         AsyncOpenAI: Cached OpenAI API client for moderation checks.
     """
-    return AsyncOpenAI(**prepare_openai_kwargs({}))
+    return AsyncOpenAI()


 async def moderation(

src/guardrails/client.py

Lines changed: 2 additions & 7 deletions
@@ -26,7 +26,6 @@
     GuardrailsResponse,
     OpenAIResponseType,
 )
-from ._openai_utils import prepare_openai_kwargs
 from ._streaming import StreamingMixin
 from .exceptions import GuardrailTripwireTriggered
 from .runtime import run_guardrails
@@ -167,7 +166,6 @@ def __init__(
                 by this parameter.
             **openai_kwargs: Additional arguments passed to AsyncOpenAI constructor.
         """
-        openai_kwargs = prepare_openai_kwargs(openai_kwargs)
         # Initialize OpenAI client first
         super().__init__(**openai_kwargs)

@@ -205,7 +203,7 @@ class DefaultContext:
         default_headers = getattr(self, "default_headers", None)
         if default_headers is not None:
             guardrail_kwargs["default_headers"] = default_headers
-        guardrail_client = AsyncOpenAI(**prepare_openai_kwargs(guardrail_kwargs))
+        guardrail_client = AsyncOpenAI(**guardrail_kwargs)

         return DefaultContext(guardrail_llm=guardrail_client)

@@ -335,7 +333,6 @@ def __init__(
                 by this parameter.
             **openai_kwargs: Additional arguments passed to OpenAI constructor.
         """
-        openai_kwargs = prepare_openai_kwargs(openai_kwargs)
         # Initialize OpenAI client first
         super().__init__(**openai_kwargs)

@@ -373,7 +370,7 @@ class DefaultContext:
         default_headers = getattr(self, "default_headers", None)
         if default_headers is not None:
             guardrail_kwargs["default_headers"] = default_headers
-        guardrail_client = OpenAI(**prepare_openai_kwargs(guardrail_kwargs))
+        guardrail_client = OpenAI(**guardrail_kwargs)

         return DefaultContext(guardrail_llm=guardrail_client)

@@ -516,7 +513,6 @@ def __init__(
                 by this parameter.
             **azure_kwargs: Additional arguments passed to AsyncAzureOpenAI constructor.
         """
-        azure_kwargs = prepare_openai_kwargs(azure_kwargs)
         # Initialize Azure client first
         super().__init__(**azure_kwargs)

@@ -671,7 +667,6 @@ def __init__(
                 by this parameter.
             **azure_kwargs: Additional arguments passed to AzureOpenAI constructor.
         """
-        azure_kwargs = prepare_openai_kwargs(azure_kwargs)
         super().__init__(**azure_kwargs)

         # Store the error handling preference
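
The constructor change is easiest to see in isolation: with prepare_openai_kwargs gone, constructor kwargs reach the underlying OpenAI client unchanged. A hedged sketch; Wrapper is a stand-in for the wrapper classes in client.py, which also wire up guardrails in __init__.

from openai import AsyncOpenAI


class Wrapper(AsyncOpenAI):  # stand-in; not the library's actual wrapper class
    def __init__(self, **openai_kwargs):
        # Before this commit: super().__init__(**prepare_openai_kwargs(openai_kwargs))
        super().__init__(**openai_kwargs)


w = Wrapper(api_key="sk-dummy", base_url="http://localhost:8000/v1")
print(w.base_url)  # http://localhost:8000/v1, forwarded untouched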

src/guardrails/evals/guardrail_evals.py

Lines changed: 2 additions & 3 deletions
@@ -23,7 +23,6 @@


 from guardrails import instantiate_guardrails, load_pipeline_bundles
-from guardrails._openai_utils import prepare_openai_kwargs
 from guardrails.evals.core import (
     AsyncRunEngine,
     BenchmarkMetricsCalculator,
@@ -281,7 +280,7 @@ def _create_context(self) -> Context:
            if self.api_key:
                azure_kwargs["api_key"] = self.api_key

-            guardrail_llm = AsyncAzureOpenAI(**prepare_openai_kwargs(azure_kwargs))
+            guardrail_llm = AsyncAzureOpenAI(**azure_kwargs)
            logger.info("Created Azure OpenAI client for endpoint: %s", self.azure_endpoint)
        # OpenAI or OpenAI-compatible API
        else:
@@ -292,7 +291,7 @@ def _create_context(self) -> Context:
                openai_kwargs["base_url"] = self.base_url
                logger.info("Created OpenAI-compatible client for base_url: %s", self.base_url)

-            guardrail_llm = AsyncOpenAI(**prepare_openai_kwargs(openai_kwargs))
+            guardrail_llm = AsyncOpenAI(**openai_kwargs)

        return Context(guardrail_llm=guardrail_llm)
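Condensed, the branch above amounts to the following sketch (a free function standing in for the _create_context selection logic; the api_version value is a placeholder, not taken from the diff): Azure settings yield an AsyncAzureOpenAI, anything else an AsyncOpenAI with an optional custom base_url, which downstream decides whether safety_identifier is sent.

from openai import AsyncAzureOpenAI, AsyncOpenAI


def create_guardrail_llm(azure_endpoint=None, api_key=None, base_url=None):
    """Stand-in for the client selection in _create_context above."""
    if azure_endpoint:  # Azure OpenAI
        azure_kwargs = {"azure_endpoint": azure_endpoint, "api_version": "2024-06-01"}  # placeholder version
        if api_key:
            azure_kwargs["api_key"] = api_key
        return AsyncAzureOpenAI(**azure_kwargs)
    # OpenAI or OpenAI-compatible API
    openai_kwargs = {}
    if api_key:
        openai_kwargs["api_key"] = api_key
    if base_url:
        openai_kwargs["base_url"] = base_url
    return AsyncOpenAI(**openai_kwargs)


llm = create_guardrail_llm(api_key="ollama", base_url="http://localhost:11434/v1")
print(type(llm).__name__, llm.base_url)  # AsyncOpenAI http://localhost:11434/v1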

src/guardrails/resources/chat/chat.py

Lines changed: 53 additions & 9 deletions
@@ -7,6 +7,38 @@

 from ..._base_client import GuardrailsBaseClient

+# OpenAI safety identifier for tracking guardrails library usage
+# Only supported by official OpenAI API (not Azure or local/alternative providers)
+_SAFETY_IDENTIFIER = "oai_guardrails"
+
+
+def _supports_safety_identifier(client: Any) -> bool:
+    """Check if the client supports the safety_identifier parameter.
+
+    Only the official OpenAI API supports this parameter.
+    Azure OpenAI and local/alternative providers do not.
+
+    Args:
+        client: The OpenAI client instance.
+
+    Returns:
+        True if safety_identifier should be included, False otherwise.
+    """
+    # Azure clients don't support it
+    client_type = type(client).__name__
+    if "Azure" in client_type:
+        return False
+
+    # Check if using a custom base_url (local or alternative provider)
+    base_url = getattr(client, "base_url", None)
+    if base_url is not None:
+        base_url_str = str(base_url)
+        # Only official OpenAI API endpoints support safety_identifier
+        return "api.openai.com" in base_url_str
+
+    # Default OpenAI client (no custom base_url) supports it
+    return True
+

 class Chat:
     """Chat completions with guardrails (sync)."""
@@ -82,12 +114,19 @@ def create(self, messages: list[dict[str, str]], model: str, stream: bool = False

         # Run input guardrails and LLM call concurrently using a thread for the LLM
         with ThreadPoolExecutor(max_workers=1) as executor:
+            # Only include safety_identifier for OpenAI clients (not Azure)
+            llm_kwargs = {
+                "messages": modified_messages,
+                "model": model,
+                "stream": stream,
+                **kwargs,
+            }
+            if _supports_safety_identifier(self._client._resource_client):
+                llm_kwargs["safety_identifier"] = _SAFETY_IDENTIFIER
+
             llm_future = executor.submit(
                 self._client._resource_client.chat.completions.create,
-                messages=modified_messages,  # Use messages with any preflight modifications
-                model=model,
-                stream=stream,
-                **kwargs,
+                **llm_kwargs,
             )
             input_results = self._client._run_stage_guardrails(
                 "input",
@@ -152,12 +191,17 @@ async def create(
             conversation_history=normalized_conversation,
             suppress_tripwire=suppress_tripwire,
         )
-        llm_call = self._client._resource_client.chat.completions.create(
-            messages=modified_messages,  # Use messages with any preflight modifications
-            model=model,
-            stream=stream,
+        # Only include safety_identifier for OpenAI clients (not Azure)
+        llm_kwargs = {
+            "messages": modified_messages,
+            "model": model,
+            "stream": stream,
             **kwargs,
-        )
+        }
+        if _supports_safety_identifier(self._client._resource_client):
+            llm_kwargs["safety_identifier"] = _SAFETY_IDENTIFIER
+
+        llm_call = self._client._resource_client.chat.completions.create(**llm_kwargs)

         input_results, llm_response = await asyncio.gather(input_check, llm_call)

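Since this module types the client as Any, the helper duck-types on the class name rather than using isinstance. A self-contained sketch with stand-in classes (no openai import needed; all names are local to the example) walks the decision table and the conditional-kwargs pattern:

from typing import Any

_SAFETY_IDENTIFIER = "oai_guardrails"


def _supports_safety_identifier(client: Any) -> bool:
    """Same logic as the chat.py helper: class name first, then base_url."""
    if "Azure" in type(client).__name__:
        return False
    base_url = getattr(client, "base_url", None)
    if base_url is not None:
        return "api.openai.com" in str(base_url)
    return True


class FakeAsyncOpenAI:  # stand-in for an official-API client
    base_url = "https://api.openai.com/v1"


class FakeAsyncAzureOpenAI:  # stand-in for an Azure client
    base_url = "https://example.openai.azure.com/openai"


llm_kwargs: dict[str, Any] = {"model": "gpt-4.1", "messages": []}  # placeholder request
if _supports_safety_identifier(FakeAsyncOpenAI()):
    llm_kwargs["safety_identifier"] = _SAFETY_IDENTIFIER
print("safety_identifier" in llm_kwargs)                    # True
print(_supports_safety_identifier(FakeAsyncAzureOpenAI()))  # False: "Azure" in type name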
