|
8 | 8 | from pydantic import BaseModel |
9 | 9 |
|
10 | 10 | from ..._base_client import GuardrailsBaseClient |
| 11 | +from ...utils.safety_identifier import SAFETY_IDENTIFIER, supports_safety_identifier |
11 | 12 |
|
12 | 13 |
|
13 | 14 | class Responses: |
@@ -63,13 +64,20 @@ def create( |
63 | 64 |
|
64 | 65 | # Input guardrails and LLM call concurrently |
65 | 66 | with ThreadPoolExecutor(max_workers=1) as executor: |
| 67 | + # Only include safety_identifier for OpenAI clients (not Azure or local models) |
| 68 | + llm_kwargs = { |
| 69 | + "input": modified_input, |
| 70 | + "model": model, |
| 71 | + "stream": stream, |
| 72 | + "tools": tools, |
| 73 | + **kwargs, |
| 74 | + } |
| 75 | + if supports_safety_identifier(self._client._resource_client): |
| 76 | + llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER |
| 77 | + |
66 | 78 | llm_future = executor.submit( |
67 | 79 | self._client._resource_client.responses.create, |
68 | | - input=modified_input, # Use preflight-modified input |
69 | | - model=model, |
70 | | - stream=stream, |
71 | | - tools=tools, |
72 | | - **kwargs, |
| 80 | + **llm_kwargs, |
73 | 81 | ) |
74 | 82 | input_results = self._client._run_stage_guardrails( |
75 | 83 | "input", |
@@ -123,12 +131,19 @@ def parse(self, input: list[dict[str, str]], model: str, text_format: type[BaseM |
123 | 131 |
|
124 | 132 | # Input guardrails and LLM call concurrently |
125 | 133 | with ThreadPoolExecutor(max_workers=1) as executor: |
| 134 | + # Only include safety_identifier for OpenAI clients (not Azure or local models) |
| 135 | + llm_kwargs = { |
| 136 | + "input": modified_input, |
| 137 | + "model": model, |
| 138 | + "text_format": text_format, |
| 139 | + **kwargs, |
| 140 | + } |
| 141 | + if supports_safety_identifier(self._client._resource_client): |
| 142 | + llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER |
| 143 | + |
126 | 144 | llm_future = executor.submit( |
127 | 145 | self._client._resource_client.responses.parse, |
128 | | - input=modified_input, # Use modified input with preflight changes |
129 | | - model=model, |
130 | | - text_format=text_format, |
131 | | - **kwargs, |
| 146 | + **llm_kwargs, |
132 | 147 | ) |
133 | 148 | input_results = self._client._run_stage_guardrails( |
134 | 149 | "input", |
@@ -218,13 +233,19 @@ async def create( |
218 | 233 | conversation_history=normalized_conversation, |
219 | 234 | suppress_tripwire=suppress_tripwire, |
220 | 235 | ) |
221 | | - llm_call = self._client._resource_client.responses.create( |
222 | | - input=modified_input, # Use preflight-modified input |
223 | | - model=model, |
224 | | - stream=stream, |
225 | | - tools=tools, |
| 236 | + |
| 237 | + # Only include safety_identifier for OpenAI clients (not Azure or local models) |
| 238 | + llm_kwargs = { |
| 239 | + "input": modified_input, |
| 240 | + "model": model, |
| 241 | + "stream": stream, |
| 242 | + "tools": tools, |
226 | 243 | **kwargs, |
227 | | - ) |
| 244 | + } |
| 245 | + if supports_safety_identifier(self._client._resource_client): |
| 246 | + llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER |
| 247 | + |
| 248 | + llm_call = self._client._resource_client.responses.create(**llm_kwargs) |
228 | 249 |
|
229 | 250 | input_results, llm_response = await asyncio.gather(input_check, llm_call) |
230 | 251 |
|
@@ -278,13 +299,19 @@ async def parse( |
278 | 299 | conversation_history=normalized_conversation, |
279 | 300 | suppress_tripwire=suppress_tripwire, |
280 | 301 | ) |
281 | | - llm_call = self._client._resource_client.responses.parse( |
282 | | - input=modified_input, # Use modified input with preflight changes |
283 | | - model=model, |
284 | | - text_format=text_format, |
285 | | - stream=stream, |
| 302 | + |
| 303 | + # Only include safety_identifier for OpenAI clients (not Azure or local models) |
| 304 | + llm_kwargs = { |
| 305 | + "input": modified_input, |
| 306 | + "model": model, |
| 307 | + "text_format": text_format, |
| 308 | + "stream": stream, |
286 | 309 | **kwargs, |
287 | | - ) |
| 310 | + } |
| 311 | + if supports_safety_identifier(self._client._resource_client): |
| 312 | + llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER |
| 313 | + |
| 314 | + llm_call = self._client._resource_client.responses.parse(**llm_kwargs) |
288 | 315 |
|
289 | 316 | input_results, llm_response = await asyncio.gather(input_check, llm_call) |
290 | 317 |
|
|
0 commit comments