|
8 | 8 | from pydantic import BaseModel |
9 | 9 |
|
10 | 10 | from ..._base_client import GuardrailsBaseClient |
11 | | - |
12 | | -# OpenAI safety identifier for tracking guardrails library usage |
13 | | -# Only supported by official OpenAI API (not Azure or local/alternative providers) |
14 | | -_SAFETY_IDENTIFIER = "oai_guardrails" |
15 | | - |
16 | | - |
17 | | -def _supports_safety_identifier(client: Any) -> bool: |
18 | | - """Check if the client supports the safety_identifier parameter. |
19 | | -
|
20 | | - Only the official OpenAI API supports this parameter. |
21 | | - Azure OpenAI and local/alternative providers do not. |
22 | | -
|
23 | | - Args: |
24 | | - client: The OpenAI client instance. |
25 | | -
|
26 | | - Returns: |
27 | | - True if safety_identifier should be included, False otherwise. |
28 | | - """ |
29 | | - # Azure clients don't support it |
30 | | - client_type = type(client).__name__ |
31 | | - if "Azure" in client_type: |
32 | | - return False |
33 | | - |
34 | | - # Check if using a custom base_url (local or alternative provider) |
35 | | - base_url = getattr(client, "base_url", None) |
36 | | - if base_url is not None: |
37 | | - base_url_str = str(base_url) |
38 | | - # Only official OpenAI API endpoints support safety_identifier |
39 | | - return "api.openai.com" in base_url_str |
40 | | - |
41 | | - # Default OpenAI client (no custom base_url) supports it |
42 | | - return True |
| 11 | +from ...utils.safety_identifier import SAFETY_IDENTIFIER, supports_safety_identifier |
43 | 12 |
|
44 | 13 |
|
45 | 14 | class Responses: |
@@ -103,8 +72,8 @@ def create( |
103 | 72 | "tools": tools, |
104 | 73 | **kwargs, |
105 | 74 | } |
106 | | - if _supports_safety_identifier(self._client._resource_client): |
107 | | - llm_kwargs["safety_identifier"] = _SAFETY_IDENTIFIER |
| 75 | + if supports_safety_identifier(self._client._resource_client): |
| 76 | + llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER |
108 | 77 |
|
109 | 78 | llm_future = executor.submit( |
110 | 79 | self._client._resource_client.responses.create, |
@@ -169,8 +138,8 @@ def parse(self, input: list[dict[str, str]], model: str, text_format: type[BaseM |
169 | 138 | "text_format": text_format, |
170 | 139 | **kwargs, |
171 | 140 | } |
172 | | - if _supports_safety_identifier(self._client._resource_client): |
173 | | - llm_kwargs["safety_identifier"] = _SAFETY_IDENTIFIER |
| 141 | + if supports_safety_identifier(self._client._resource_client): |
| 142 | + llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER |
174 | 143 |
|
175 | 144 | llm_future = executor.submit( |
176 | 145 | self._client._resource_client.responses.parse, |
@@ -273,8 +242,8 @@ async def create( |
273 | 242 | "tools": tools, |
274 | 243 | **kwargs, |
275 | 244 | } |
276 | | - if _supports_safety_identifier(self._client._resource_client): |
277 | | - llm_kwargs["safety_identifier"] = _SAFETY_IDENTIFIER |
| 245 | + if supports_safety_identifier(self._client._resource_client): |
| 246 | + llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER |
278 | 247 |
|
279 | 248 | llm_call = self._client._resource_client.responses.create(**llm_kwargs) |
280 | 249 |
|
@@ -339,8 +308,8 @@ async def parse( |
339 | 308 | "stream": stream, |
340 | 309 | **kwargs, |
341 | 310 | } |
342 | | - if _supports_safety_identifier(self._client._resource_client): |
343 | | - llm_kwargs["safety_identifier"] = _SAFETY_IDENTIFIER |
| 311 | + if supports_safety_identifier(self._client._resource_client): |
| 312 | + llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER |
344 | 313 |
|
345 | 314 | llm_call = self._client._resource_client.responses.parse(**llm_kwargs) |
346 | 315 |
|
|
0 commit comments