|
7 | 7 |
|
8 | 8 | from ..._base_client import GuardrailsBaseClient |
9 | 9 |
|
# OpenAI safety identifier used to tag requests made by the guardrails library.
# Only supported by the official OpenAI API (not Azure or local/alternative providers);
# see _supports_safety_identifier for the gating logic.
_SAFETY_IDENTIFIER = "oai_guardrails"
| 13 | + |
| 14 | + |
| 15 | +def _supports_safety_identifier(client: Any) -> bool: |
| 16 | + """Check if the client supports the safety_identifier parameter. |
| 17 | +
|
| 18 | + Only the official OpenAI API supports this parameter. |
| 19 | + Azure OpenAI and local/alternative providers do not. |
| 20 | +
|
| 21 | + Args: |
| 22 | + client: The OpenAI client instance. |
| 23 | +
|
| 24 | + Returns: |
| 25 | + True if safety_identifier should be included, False otherwise. |
| 26 | + """ |
| 27 | + # Azure clients don't support it |
| 28 | + client_type = type(client).__name__ |
| 29 | + if "Azure" in client_type: |
| 30 | + return False |
| 31 | + |
| 32 | + # Check if using a custom base_url (local or alternative provider) |
| 33 | + base_url = getattr(client, "base_url", None) |
| 34 | + if base_url is not None: |
| 35 | + base_url_str = str(base_url) |
| 36 | + # Only official OpenAI API endpoints support safety_identifier |
| 37 | + return "api.openai.com" in base_url_str |
| 38 | + |
| 39 | + # Default OpenAI client (no custom base_url) supports it |
| 40 | + return True |
| 41 | + |
10 | 42 |
|
11 | 43 | class Chat: |
12 | 44 | """Chat completions with guardrails (sync).""" |
@@ -82,12 +114,19 @@ def create(self, messages: list[dict[str, str]], model: str, stream: bool = Fals |
82 | 114 |
|
83 | 115 | # Run input guardrails and LLM call concurrently using a thread for the LLM |
84 | 116 | with ThreadPoolExecutor(max_workers=1) as executor: |
| 117 | + # Only include safety_identifier for OpenAI clients (not Azure) |
| 118 | + llm_kwargs = { |
| 119 | + "messages": modified_messages, |
| 120 | + "model": model, |
| 121 | + "stream": stream, |
| 122 | + **kwargs, |
| 123 | + } |
| 124 | + if _supports_safety_identifier(self._client._resource_client): |
| 125 | + llm_kwargs["safety_identifier"] = _SAFETY_IDENTIFIER |
| 126 | + |
85 | 127 | llm_future = executor.submit( |
86 | 128 | self._client._resource_client.chat.completions.create, |
87 | | - messages=modified_messages, # Use messages with any preflight modifications |
88 | | - model=model, |
89 | | - stream=stream, |
90 | | - **kwargs, |
| 129 | + **llm_kwargs, |
91 | 130 | ) |
92 | 131 | input_results = self._client._run_stage_guardrails( |
93 | 132 | "input", |
@@ -152,12 +191,17 @@ async def create( |
152 | 191 | conversation_history=normalized_conversation, |
153 | 192 | suppress_tripwire=suppress_tripwire, |
154 | 193 | ) |
155 | | - llm_call = self._client._resource_client.chat.completions.create( |
156 | | - messages=modified_messages, # Use messages with any preflight modifications |
157 | | - model=model, |
158 | | - stream=stream, |
| 194 | + # Only include safety_identifier for OpenAI clients (not Azure) |
| 195 | + llm_kwargs = { |
| 196 | + "messages": modified_messages, |
| 197 | + "model": model, |
| 198 | + "stream": stream, |
159 | 199 | **kwargs, |
160 | | - ) |
| 200 | + } |
| 201 | + if _supports_safety_identifier(self._client._resource_client): |
| 202 | + llm_kwargs["safety_identifier"] = _SAFETY_IDENTIFIER |
| 203 | + |
| 204 | + llm_call = self._client._resource_client.chat.completions.create(**llm_kwargs) |
161 | 205 |
|
162 | 206 | input_results, llm_response = await asyncio.gather(input_check, llm_call) |
163 | 207 |
|
|
0 commit comments