Skip to content

Commit 62b91cf

Browse files
committed
extract common logic
1 parent 514c197 commit 62b91cf

File tree

5 files changed

+97
-120
lines changed

5 files changed

+97
-120
lines changed

src/guardrails/checks/text/llm_base.py

Lines changed: 5 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -48,36 +48,7 @@ class MyLLMOutput(LLMOutput):
4848
from guardrails.types import CheckFn, GuardrailLLMContextProto, GuardrailResult
4949
from guardrails.utils.output import OutputSchema
5050

51-
# OpenAI safety identifier for tracking guardrails library usage
52-
# Only supported by official OpenAI API (not Azure or local/alternative providers)
53-
_SAFETY_IDENTIFIER = "oai_guardrails"
54-
55-
56-
def _supports_safety_identifier(client: AsyncOpenAI | OpenAI | AsyncAzureOpenAI | AzureOpenAI) -> bool:
57-
"""Check if the client supports the safety_identifier parameter.
58-
59-
Only the official OpenAI API supports this parameter.
60-
Azure OpenAI and local/alternative providers do not.
61-
62-
Args:
63-
client: The OpenAI client instance.
64-
65-
Returns:
66-
True if safety_identifier should be included, False otherwise.
67-
"""
68-
# Azure clients don't support it
69-
if isinstance(client, AsyncAzureOpenAI | AzureOpenAI):
70-
return False
71-
72-
# Check if using a custom base_url (local or alternative provider)
73-
base_url = getattr(client, "base_url", None)
74-
if base_url is not None:
75-
base_url_str = str(base_url)
76-
# Only official OpenAI API endpoints support safety_identifier
77-
return "api.openai.com" in base_url_str
78-
79-
# Default OpenAI client (no custom base_url) supports it
80-
return True
51+
from ...utils.safety_identifier import SAFETY_IDENTIFIER, supports_safety_identifier
8152

8253
if TYPE_CHECKING:
8354
from openai import AsyncAzureOpenAI, AzureOpenAI # type: ignore[unused-import]
@@ -93,10 +64,10 @@ def _supports_safety_identifier(client: AsyncOpenAI | OpenAI | AsyncAzureOpenAI
9364

9465
__all__ = [
9566
"LLMConfig",
96-
"LLMOutput",
9767
"LLMErrorOutput",
98-
"create_llm_check_fn",
68+
"LLMOutput",
9969
"create_error_result",
70+
"create_llm_check_fn",
10071
]
10172

10273

@@ -286,8 +257,8 @@ async def _request_chat_completion(
286257
}
287258

288259
# Only official OpenAI API supports safety_identifier (not Azure or local models)
289-
if _supports_safety_identifier(client):
290-
kwargs["safety_identifier"] = _SAFETY_IDENTIFIER
260+
if supports_safety_identifier(client):
261+
kwargs["safety_identifier"] = SAFETY_IDENTIFIER
291262

292263
return await _invoke_openai_callable(client.chat.completions.create, **kwargs)
293264

src/guardrails/resources/chat/chat.py

Lines changed: 5 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -6,38 +6,7 @@
66
from typing import Any
77

88
from ..._base_client import GuardrailsBaseClient
9-
10-
# OpenAI safety identifier for tracking guardrails library usage
11-
# Only supported by official OpenAI API (not Azure or local/alternative providers)
12-
_SAFETY_IDENTIFIER = "oai_guardrails"
13-
14-
15-
def _supports_safety_identifier(client: Any) -> bool:
16-
"""Check if the client supports the safety_identifier parameter.
17-
18-
Only the official OpenAI API supports this parameter.
19-
Azure OpenAI and local/alternative providers do not.
20-
21-
Args:
22-
client: The OpenAI client instance.
23-
24-
Returns:
25-
True if safety_identifier should be included, False otherwise.
26-
"""
27-
# Azure clients don't support it
28-
client_type = type(client).__name__
29-
if "Azure" in client_type:
30-
return False
31-
32-
# Check if using a custom base_url (local or alternative provider)
33-
base_url = getattr(client, "base_url", None)
34-
if base_url is not None:
35-
base_url_str = str(base_url)
36-
# Only official OpenAI API endpoints support safety_identifier
37-
return "api.openai.com" in base_url_str
38-
39-
# Default OpenAI client (no custom base_url) supports it
40-
return True
9+
from ...utils.safety_identifier import SAFETY_IDENTIFIER, supports_safety_identifier
4110

4211

4312
class Chat:
@@ -121,8 +90,8 @@ def create(self, messages: list[dict[str, str]], model: str, stream: bool = Fals
12190
"stream": stream,
12291
**kwargs,
12392
}
124-
if _supports_safety_identifier(self._client._resource_client):
125-
llm_kwargs["safety_identifier"] = _SAFETY_IDENTIFIER
93+
if supports_safety_identifier(self._client._resource_client):
94+
llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER
12695

12796
llm_future = executor.submit(
12897
self._client._resource_client.chat.completions.create,
@@ -198,8 +167,8 @@ async def create(
198167
"stream": stream,
199168
**kwargs,
200169
}
201-
if _supports_safety_identifier(self._client._resource_client):
202-
llm_kwargs["safety_identifier"] = _SAFETY_IDENTIFIER
170+
if supports_safety_identifier(self._client._resource_client):
171+
llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER
203172

204173
llm_call = self._client._resource_client.chat.completions.create(**llm_kwargs)
205174

src/guardrails/resources/responses/responses.py

Lines changed: 9 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -8,38 +8,7 @@
88
from pydantic import BaseModel
99

1010
from ..._base_client import GuardrailsBaseClient
11-
12-
# OpenAI safety identifier for tracking guardrails library usage
13-
# Only supported by official OpenAI API (not Azure or local/alternative providers)
14-
_SAFETY_IDENTIFIER = "oai_guardrails"
15-
16-
17-
def _supports_safety_identifier(client: Any) -> bool:
18-
"""Check if the client supports the safety_identifier parameter.
19-
20-
Only the official OpenAI API supports this parameter.
21-
Azure OpenAI and local/alternative providers do not.
22-
23-
Args:
24-
client: The OpenAI client instance.
25-
26-
Returns:
27-
True if safety_identifier should be included, False otherwise.
28-
"""
29-
# Azure clients don't support it
30-
client_type = type(client).__name__
31-
if "Azure" in client_type:
32-
return False
33-
34-
# Check if using a custom base_url (local or alternative provider)
35-
base_url = getattr(client, "base_url", None)
36-
if base_url is not None:
37-
base_url_str = str(base_url)
38-
# Only official OpenAI API endpoints support safety_identifier
39-
return "api.openai.com" in base_url_str
40-
41-
# Default OpenAI client (no custom base_url) supports it
42-
return True
11+
from ...utils.safety_identifier import SAFETY_IDENTIFIER, supports_safety_identifier
4312

4413

4514
class Responses:
@@ -103,8 +72,8 @@ def create(
10372
"tools": tools,
10473
**kwargs,
10574
}
106-
if _supports_safety_identifier(self._client._resource_client):
107-
llm_kwargs["safety_identifier"] = _SAFETY_IDENTIFIER
75+
if supports_safety_identifier(self._client._resource_client):
76+
llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER
10877

10978
llm_future = executor.submit(
11079
self._client._resource_client.responses.create,
@@ -169,8 +138,8 @@ def parse(self, input: list[dict[str, str]], model: str, text_format: type[BaseM
169138
"text_format": text_format,
170139
**kwargs,
171140
}
172-
if _supports_safety_identifier(self._client._resource_client):
173-
llm_kwargs["safety_identifier"] = _SAFETY_IDENTIFIER
141+
if supports_safety_identifier(self._client._resource_client):
142+
llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER
174143

175144
llm_future = executor.submit(
176145
self._client._resource_client.responses.parse,
@@ -273,8 +242,8 @@ async def create(
273242
"tools": tools,
274243
**kwargs,
275244
}
276-
if _supports_safety_identifier(self._client._resource_client):
277-
llm_kwargs["safety_identifier"] = _SAFETY_IDENTIFIER
245+
if supports_safety_identifier(self._client._resource_client):
246+
llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER
278247

279248
llm_call = self._client._resource_client.responses.create(**llm_kwargs)
280249

@@ -339,8 +308,8 @@ async def parse(
339308
"stream": stream,
340309
**kwargs,
341310
}
342-
if _supports_safety_identifier(self._client._resource_client):
343-
llm_kwargs["safety_identifier"] = _SAFETY_IDENTIFIER
311+
if supports_safety_identifier(self._client._resource_client):
312+
llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER
344313

345314
llm_call = self._client._resource_client.responses.parse(**llm_kwargs)
346315

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
"""OpenAI safety identifier utilities.
2+
3+
This module provides utilities for handling the OpenAI safety_identifier parameter,
4+
which is used to track guardrails library usage for monitoring and abuse detection.
5+
6+
The safety identifier is only supported by the official OpenAI API and should not
7+
be sent to Azure OpenAI or other OpenAI-compatible providers.
8+
"""
9+
10+
from __future__ import annotations
11+
12+
from typing import TYPE_CHECKING, Any
13+
14+
if TYPE_CHECKING:
15+
from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
16+
else:
17+
try:
18+
from openai import AsyncAzureOpenAI, AzureOpenAI
19+
except ImportError:
20+
AsyncAzureOpenAI = None # type: ignore[assignment, misc]
21+
AzureOpenAI = None # type: ignore[assignment, misc]
22+
23+
__all__ = ["SAFETY_IDENTIFIER", "supports_safety_identifier"]
24+
25+
# OpenAI safety identifier for tracking guardrails library usage
26+
SAFETY_IDENTIFIER = "oai_guardrails"
27+
28+
29+
def supports_safety_identifier(
30+
client: AsyncOpenAI | OpenAI | AsyncAzureOpenAI | AzureOpenAI | Any,
31+
) -> bool:
32+
"""Check if the client supports the safety_identifier parameter.
33+
34+
Only the official OpenAI API supports this parameter.
35+
Azure OpenAI and local/alternative providers (Ollama, vLLM, etc.) do not.
36+
37+
Args:
38+
client: The OpenAI client instance to check.
39+
40+
Returns:
41+
True if safety_identifier should be included in API calls, False otherwise.
42+
43+
Examples:
44+
>>> from openai import AsyncOpenAI
45+
>>> client = AsyncOpenAI()
46+
>>> supports_safety_identifier(client)
47+
True
48+
49+
>>> from openai import AsyncOpenAI
50+
>>> local_client = AsyncOpenAI(base_url="http://localhost:11434")
51+
>>> supports_safety_identifier(local_client)
52+
False
53+
"""
54+
# Azure clients don't support it
55+
if AsyncAzureOpenAI is not None and AzureOpenAI is not None:
56+
if isinstance(client, AsyncAzureOpenAI | AzureOpenAI):
57+
return False
58+
59+
# Check if using a custom base_url (local or alternative provider)
60+
base_url = getattr(client, "base_url", None)
61+
if base_url is not None:
62+
base_url_str = str(base_url)
63+
# Only official OpenAI API endpoints support safety_identifier
64+
return "api.openai.com" in base_url_str
65+
66+
# Default OpenAI client (no custom base_url) supports it
67+
return True
68+

tests/unit/test_safety_identifier.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,29 +7,29 @@
77

88
def test_supports_safety_identifier_for_openai_client() -> None:
99
"""Official OpenAI client with default base_url should support safety_identifier."""
10-
from guardrails.checks.text.llm_base import _supports_safety_identifier
10+
from guardrails.utils.safety_identifier import supports_safety_identifier
1111

1212
mock_client = Mock()
1313
mock_client.base_url = None
1414
mock_client.__class__.__name__ = "AsyncOpenAI"
1515

16-
assert _supports_safety_identifier(mock_client) is True # noqa: S101
16+
assert supports_safety_identifier(mock_client) is True # noqa: S101
1717

1818

1919
def test_supports_safety_identifier_for_openai_with_official_url() -> None:
2020
"""OpenAI client with explicit api.openai.com base_url should support safety_identifier."""
21-
from guardrails.checks.text.llm_base import _supports_safety_identifier
21+
from guardrails.utils.safety_identifier import supports_safety_identifier
2222

2323
mock_client = Mock()
2424
mock_client.base_url = "https://api.openai.com/v1"
2525
mock_client.__class__.__name__ = "AsyncOpenAI"
2626

27-
assert _supports_safety_identifier(mock_client) is True # noqa: S101
27+
assert supports_safety_identifier(mock_client) is True # noqa: S101
2828

2929

3030
def test_does_not_support_safety_identifier_for_azure() -> None:
3131
"""Azure OpenAI client should not support safety_identifier."""
32-
from guardrails.checks.text.llm_base import _supports_safety_identifier
32+
from guardrails.utils.safety_identifier import supports_safety_identifier
3333

3434
mock_client = Mock()
3535
mock_client.base_url = "https://example.openai.azure.com/v1"
@@ -44,30 +44,30 @@ def test_does_not_support_safety_identifier_for_azure() -> None:
4444
azure_endpoint="https://example.openai.azure.com",
4545
api_version="2024-02-01",
4646
)
47-
assert _supports_safety_identifier(azure_client) is False # noqa: S101
47+
assert supports_safety_identifier(azure_client) is False # noqa: S101
4848
except Exception:
4949
# If we can't create a real Azure client in tests, that's okay
5050
pytest.skip("Could not create Azure client for testing")
5151

5252

5353
def test_does_not_support_safety_identifier_for_local_model() -> None:
5454
"""Local model with custom base_url should not support safety_identifier."""
55-
from guardrails.checks.text.llm_base import _supports_safety_identifier
55+
from guardrails.utils.safety_identifier import supports_safety_identifier
5656

5757
mock_client = Mock()
5858
mock_client.base_url = "http://localhost:11434/v1" # Ollama
5959
mock_client.__class__.__name__ = "AsyncOpenAI"
6060

61-
assert _supports_safety_identifier(mock_client) is False # noqa: S101
61+
assert supports_safety_identifier(mock_client) is False # noqa: S101
6262

6363

6464
def test_does_not_support_safety_identifier_for_alternative_provider() -> None:
6565
"""Alternative OpenAI-compatible provider should not support safety_identifier."""
66-
from guardrails.checks.text.llm_base import _supports_safety_identifier
66+
from guardrails.utils.safety_identifier import supports_safety_identifier
6767

6868
mock_client = Mock()
6969
mock_client.base_url = "https://api.together.xyz/v1"
7070
mock_client.__class__.__name__ = "AsyncOpenAI"
7171

72-
assert _supports_safety_identifier(mock_client) is False # noqa: S101
72+
assert supports_safety_identifier(mock_client) is False # noqa: S101
7373

0 commit comments

Comments
 (0)