Skip to content

Commit 06fa983

Browse files
authored
Set safety_identifier to openai-guardrails-python (#37)
* extract common logic * change id value
1 parent 1bfd82b commit 06fa983

File tree

14 files changed

+238
-124
lines changed

14 files changed

+238
-124
lines changed

src/guardrails/_openai_utils.py

Lines changed: 0 additions & 25 deletions
This file was deleted.

src/guardrails/agents.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from pathlib import Path
1919
from typing import Any
2020

21-
from ._openai_utils import prepare_openai_kwargs
2221
from .utils.conversation import merge_conversation_with_items, normalize_conversation
2322

2423
logger = logging.getLogger(__name__)
@@ -167,7 +166,7 @@ def _create_default_tool_context() -> Any:
167166
class DefaultContext:
168167
guardrail_llm: AsyncOpenAI
169168

170-
return DefaultContext(guardrail_llm=AsyncOpenAI(**prepare_openai_kwargs({})))
169+
return DefaultContext(guardrail_llm=AsyncOpenAI())
171170

172171

173172
def _create_conversation_context(
@@ -393,7 +392,7 @@ def _create_agents_guardrails_from_config(
393392
class DefaultContext:
394393
guardrail_llm: AsyncOpenAI
395394

396-
context = DefaultContext(guardrail_llm=AsyncOpenAI(**prepare_openai_kwargs({})))
395+
context = DefaultContext(guardrail_llm=AsyncOpenAI())
397396

398397
def _create_stage_guardrail(stage_name: str):
399398
async def stage_guardrail(ctx: RunContextWrapper[None], agent: Agent, input_data: str) -> GuardrailFunctionOutput:

src/guardrails/checks/text/llm_base.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ class MyLLMOutput(LLMOutput):
4848
from guardrails.types import CheckFn, GuardrailLLMContextProto, GuardrailResult
4949
from guardrails.utils.output import OutputSchema
5050

51+
from ...utils.safety_identifier import SAFETY_IDENTIFIER, supports_safety_identifier
52+
5153
if TYPE_CHECKING:
5254
from openai import AsyncAzureOpenAI, AzureOpenAI # type: ignore[unused-import]
5355
else:
@@ -62,10 +64,10 @@ class MyLLMOutput(LLMOutput):
6264

6365
__all__ = [
6466
"LLMConfig",
65-
"LLMOutput",
6667
"LLMErrorOutput",
67-
"create_llm_check_fn",
68+
"LLMOutput",
6869
"create_error_result",
70+
"create_llm_check_fn",
6971
]
7072

7173

@@ -247,12 +249,18 @@ async def _request_chat_completion(
247249
response_format: dict[str, Any],
248250
) -> Any:
249251
"""Invoke chat.completions.create on sync or async OpenAI clients."""
250-
return await _invoke_openai_callable(
251-
client.chat.completions.create,
252-
messages=messages,
253-
model=model,
254-
response_format=response_format,
255-
)
252+
# Only include safety_identifier for official OpenAI API
253+
kwargs: dict[str, Any] = {
254+
"messages": messages,
255+
"model": model,
256+
"response_format": response_format,
257+
}
258+
259+
# Only official OpenAI API supports safety_identifier (not Azure or local models)
260+
if supports_safety_identifier(client):
261+
kwargs["safety_identifier"] = SAFETY_IDENTIFIER
262+
263+
return await _invoke_openai_callable(client.chat.completions.create, **kwargs)
256264

257265

258266
async def run_llm(

src/guardrails/checks/text/moderation.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,6 @@
3939
from guardrails.spec import GuardrailSpecMetadata
4040
from guardrails.types import GuardrailResult
4141

42-
from ..._openai_utils import prepare_openai_kwargs
43-
4442
logger = logging.getLogger(__name__)
4543

4644
__all__ = ["moderation", "Category", "ModerationCfg"]
@@ -129,7 +127,7 @@ def _get_moderation_client() -> AsyncOpenAI:
129127
Returns:
130128
AsyncOpenAI: Cached OpenAI API client for moderation checks.
131129
"""
132-
return AsyncOpenAI(**prepare_openai_kwargs({}))
130+
return AsyncOpenAI()
133131

134132

135133
async def _call_moderation_api(client: AsyncOpenAI, data: str) -> Any:

src/guardrails/client.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
GuardrailsResponse,
2727
OpenAIResponseType,
2828
)
29-
from ._openai_utils import prepare_openai_kwargs
3029
from ._streaming import StreamingMixin
3130
from .exceptions import GuardrailTripwireTriggered
3231
from .runtime import run_guardrails
@@ -167,7 +166,6 @@ def __init__(
167166
by this parameter.
168167
**openai_kwargs: Additional arguments passed to AsyncOpenAI constructor.
169168
"""
170-
openai_kwargs = prepare_openai_kwargs(openai_kwargs)
171169
# Initialize OpenAI client first
172170
super().__init__(**openai_kwargs)
173171

@@ -205,7 +203,7 @@ class DefaultContext:
205203
default_headers = getattr(self, "default_headers", None)
206204
if default_headers is not None:
207205
guardrail_kwargs["default_headers"] = default_headers
208-
guardrail_client = AsyncOpenAI(**prepare_openai_kwargs(guardrail_kwargs))
206+
guardrail_client = AsyncOpenAI(**guardrail_kwargs)
209207

210208
return DefaultContext(guardrail_llm=guardrail_client)
211209

@@ -335,7 +333,6 @@ def __init__(
335333
by this parameter.
336334
**openai_kwargs: Additional arguments passed to OpenAI constructor.
337335
"""
338-
openai_kwargs = prepare_openai_kwargs(openai_kwargs)
339336
# Initialize OpenAI client first
340337
super().__init__(**openai_kwargs)
341338

@@ -373,7 +370,7 @@ class DefaultContext:
373370
default_headers = getattr(self, "default_headers", None)
374371
if default_headers is not None:
375372
guardrail_kwargs["default_headers"] = default_headers
376-
guardrail_client = OpenAI(**prepare_openai_kwargs(guardrail_kwargs))
373+
guardrail_client = OpenAI(**guardrail_kwargs)
377374

378375
return DefaultContext(guardrail_llm=guardrail_client)
379376

@@ -516,7 +513,6 @@ def __init__(
516513
by this parameter.
517514
**azure_kwargs: Additional arguments passed to AsyncAzureOpenAI constructor.
518515
"""
519-
azure_kwargs = prepare_openai_kwargs(azure_kwargs)
520516
# Initialize Azure client first
521517
super().__init__(**azure_kwargs)
522518

@@ -671,7 +667,6 @@ def __init__(
671667
by this parameter.
672668
**azure_kwargs: Additional arguments passed to AzureOpenAI constructor.
673669
"""
674-
azure_kwargs = prepare_openai_kwargs(azure_kwargs)
675670
super().__init__(**azure_kwargs)
676671

677672
# Store the error handling preference

src/guardrails/evals/guardrail_evals.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323

2424

2525
from guardrails import instantiate_guardrails, load_pipeline_bundles
26-
from guardrails._openai_utils import prepare_openai_kwargs
2726
from guardrails.evals.core import (
2827
AsyncRunEngine,
2928
BenchmarkMetricsCalculator,
@@ -281,7 +280,7 @@ def _create_context(self) -> Context:
281280
if self.api_key:
282281
azure_kwargs["api_key"] = self.api_key
283282

284-
guardrail_llm = AsyncAzureOpenAI(**prepare_openai_kwargs(azure_kwargs))
283+
guardrail_llm = AsyncAzureOpenAI(**azure_kwargs)
285284
logger.info("Created Azure OpenAI client for endpoint: %s", self.azure_endpoint)
286285
# OpenAI or OpenAI-compatible API
287286
else:
@@ -292,7 +291,7 @@ def _create_context(self) -> Context:
292291
openai_kwargs["base_url"] = self.base_url
293292
logger.info("Created OpenAI-compatible client for base_url: %s", self.base_url)
294293

295-
guardrail_llm = AsyncOpenAI(**prepare_openai_kwargs(openai_kwargs))
294+
guardrail_llm = AsyncOpenAI(**openai_kwargs)
296295

297296
return Context(guardrail_llm=guardrail_llm)
298297

src/guardrails/resources/chat/chat.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import Any
77

88
from ..._base_client import GuardrailsBaseClient
9+
from ...utils.safety_identifier import SAFETY_IDENTIFIER, supports_safety_identifier
910

1011

1112
class Chat:
@@ -82,12 +83,19 @@ def create(self, messages: list[dict[str, str]], model: str, stream: bool = Fals
8283

8384
# Run input guardrails and LLM call concurrently using a thread for the LLM
8485
with ThreadPoolExecutor(max_workers=1) as executor:
86+
# Only include safety_identifier for OpenAI clients (not Azure)
87+
llm_kwargs = {
88+
"messages": modified_messages,
89+
"model": model,
90+
"stream": stream,
91+
**kwargs,
92+
}
93+
if supports_safety_identifier(self._client._resource_client):
94+
llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER
95+
8596
llm_future = executor.submit(
8697
self._client._resource_client.chat.completions.create,
87-
messages=modified_messages, # Use messages with any preflight modifications
88-
model=model,
89-
stream=stream,
90-
**kwargs,
98+
**llm_kwargs,
9199
)
92100
input_results = self._client._run_stage_guardrails(
93101
"input",
@@ -152,12 +160,17 @@ async def create(
152160
conversation_history=normalized_conversation,
153161
suppress_tripwire=suppress_tripwire,
154162
)
155-
llm_call = self._client._resource_client.chat.completions.create(
156-
messages=modified_messages, # Use messages with any preflight modifications
157-
model=model,
158-
stream=stream,
163+
# Only include safety_identifier for OpenAI clients (not Azure)
164+
llm_kwargs = {
165+
"messages": modified_messages,
166+
"model": model,
167+
"stream": stream,
159168
**kwargs,
160-
)
169+
}
170+
if supports_safety_identifier(self._client._resource_client):
171+
llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER
172+
173+
llm_call = self._client._resource_client.chat.completions.create(**llm_kwargs)
161174

162175
input_results, llm_response = await asyncio.gather(input_check, llm_call)
163176

src/guardrails/resources/responses/responses.py

Lines changed: 48 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from pydantic import BaseModel
99

1010
from ..._base_client import GuardrailsBaseClient
11+
from ...utils.safety_identifier import SAFETY_IDENTIFIER, supports_safety_identifier
1112

1213

1314
class Responses:
@@ -63,13 +64,20 @@ def create(
6364

6465
# Input guardrails and LLM call concurrently
6566
with ThreadPoolExecutor(max_workers=1) as executor:
67+
# Only include safety_identifier for OpenAI clients (not Azure or local models)
68+
llm_kwargs = {
69+
"input": modified_input,
70+
"model": model,
71+
"stream": stream,
72+
"tools": tools,
73+
**kwargs,
74+
}
75+
if supports_safety_identifier(self._client._resource_client):
76+
llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER
77+
6678
llm_future = executor.submit(
6779
self._client._resource_client.responses.create,
68-
input=modified_input, # Use preflight-modified input
69-
model=model,
70-
stream=stream,
71-
tools=tools,
72-
**kwargs,
80+
**llm_kwargs,
7381
)
7482
input_results = self._client._run_stage_guardrails(
7583
"input",
@@ -123,12 +131,19 @@ def parse(self, input: list[dict[str, str]], model: str, text_format: type[BaseM
123131

124132
# Input guardrails and LLM call concurrently
125133
with ThreadPoolExecutor(max_workers=1) as executor:
134+
# Only include safety_identifier for OpenAI clients (not Azure or local models)
135+
llm_kwargs = {
136+
"input": modified_input,
137+
"model": model,
138+
"text_format": text_format,
139+
**kwargs,
140+
}
141+
if supports_safety_identifier(self._client._resource_client):
142+
llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER
143+
126144
llm_future = executor.submit(
127145
self._client._resource_client.responses.parse,
128-
input=modified_input, # Use modified input with preflight changes
129-
model=model,
130-
text_format=text_format,
131-
**kwargs,
146+
**llm_kwargs,
132147
)
133148
input_results = self._client._run_stage_guardrails(
134149
"input",
@@ -218,13 +233,19 @@ async def create(
218233
conversation_history=normalized_conversation,
219234
suppress_tripwire=suppress_tripwire,
220235
)
221-
llm_call = self._client._resource_client.responses.create(
222-
input=modified_input, # Use preflight-modified input
223-
model=model,
224-
stream=stream,
225-
tools=tools,
236+
237+
# Only include safety_identifier for OpenAI clients (not Azure or local models)
238+
llm_kwargs = {
239+
"input": modified_input,
240+
"model": model,
241+
"stream": stream,
242+
"tools": tools,
226243
**kwargs,
227-
)
244+
}
245+
if supports_safety_identifier(self._client._resource_client):
246+
llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER
247+
248+
llm_call = self._client._resource_client.responses.create(**llm_kwargs)
228249

229250
input_results, llm_response = await asyncio.gather(input_check, llm_call)
230251

@@ -278,13 +299,19 @@ async def parse(
278299
conversation_history=normalized_conversation,
279300
suppress_tripwire=suppress_tripwire,
280301
)
281-
llm_call = self._client._resource_client.responses.parse(
282-
input=modified_input, # Use modified input with preflight changes
283-
model=model,
284-
text_format=text_format,
285-
stream=stream,
302+
303+
# Only include safety_identifier for OpenAI clients (not Azure or local models)
304+
llm_kwargs = {
305+
"input": modified_input,
306+
"model": model,
307+
"text_format": text_format,
308+
"stream": stream,
286309
**kwargs,
287-
)
310+
}
311+
if supports_safety_identifier(self._client._resource_client):
312+
llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER
313+
314+
llm_call = self._client._resource_client.responses.parse(**llm_kwargs)
288315

289316
input_results, llm_response = await asyncio.gather(input_check, llm_call)
290317

src/guardrails/runtime.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from openai import AsyncOpenAI
2222
from pydantic import BaseModel, ConfigDict
2323

24-
from ._openai_utils import prepare_openai_kwargs
2524
from .exceptions import ConfigError, GuardrailTripwireTriggered
2625
from .registry import GuardrailRegistry, default_spec_registry
2726
from .spec import GuardrailSpec
@@ -495,7 +494,7 @@ def _get_default_ctx():
495494
class DefaultCtx:
496495
guardrail_llm: AsyncOpenAI
497496

498-
return DefaultCtx(guardrail_llm=AsyncOpenAI(**prepare_openai_kwargs({})))
497+
return DefaultCtx(guardrail_llm=AsyncOpenAI())
499498

500499

501500
async def check_plain_text(

0 commit comments

Comments
 (0)