Skip to content

Commit fb5aee5

Browse files
committed
Handle sync vs async and oai vs azure clients
1 parent 4960ae0 commit fb5aee5

File tree

2 files changed

+72
-7
lines changed

2 files changed

+72
-7
lines changed

src/guardrails/checks/text/moderation.py

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,21 @@
2727

2828
from __future__ import annotations
2929

30+
import asyncio
3031
import logging
3132
from enum import Enum
3233
from functools import cache
3334
from typing import Any
3435

35-
from openai import AsyncOpenAI, NotFoundError
36+
from openai import AsyncOpenAI, NotFoundError, OpenAI
3637
from pydantic import BaseModel, ConfigDict, Field
3738

39+
try:
40+
from openai import AsyncAzureOpenAI, AzureOpenAI # type: ignore
41+
except Exception: # pragma: no cover - optional dependency
42+
AsyncAzureOpenAI = object # type: ignore
43+
AzureOpenAI = object # type: ignore
44+
3845
from guardrails.registry import default_spec_registry
3946
from guardrails.spec import GuardrailSpecMetadata
4047
from guardrails.types import GuardrailResult
@@ -130,11 +137,14 @@ def _get_moderation_client() -> AsyncOpenAI:
130137
return AsyncOpenAI()
131138

132139

133-
async def _call_moderation_api(client: AsyncOpenAI, data: str) -> Any:
134-
"""Call the OpenAI moderation API.
140+
async def _call_moderation_api_async(
141+
client: AsyncOpenAI | AsyncAzureOpenAI,
142+
data: str, # type: ignore
143+
) -> Any:
144+
"""Call the OpenAI moderation API asynchronously.
135145
136146
Args:
137-
client: The OpenAI client to use.
147+
client: The async OpenAI or Azure OpenAI client to use.
138148
data: The text to analyze.
139149
140150
Returns:
@@ -146,6 +156,22 @@ async def _call_moderation_api(client: AsyncOpenAI, data: str) -> Any:
146156
)
147157

148158

159+
def _call_moderation_api_sync(client: OpenAI | AzureOpenAI, data: str) -> Any: # type: ignore
160+
"""Call the OpenAI moderation API synchronously.
161+
162+
Args:
163+
client: The sync OpenAI or Azure OpenAI client to use.
164+
data: The text to analyze.
165+
166+
Returns:
167+
The moderation API response.
168+
"""
169+
return client.moderations.create(
170+
model="omni-moderation-latest",
171+
input=data,
172+
)
173+
174+
149175
async def moderation(
150176
ctx: Any,
151177
data: str,
@@ -169,8 +195,15 @@ async def moderation(
169195
client = getattr(ctx, "guardrail_llm", None) if ctx is not None else None
170196

171197
if client is not None:
198+
# Determine if client is async or sync and call appropriately
199+
is_async_client = isinstance(client, AsyncOpenAI | AsyncAzureOpenAI)
200+
172201
try:
173-
resp = await _call_moderation_api(client, data)
202+
if is_async_client:
203+
resp = await _call_moderation_api_async(client, data)
204+
else:
205+
# Sync client - run in thread pool to avoid blocking event loop
206+
resp = await asyncio.to_thread(_call_moderation_api_sync, client, data)
174207
except NotFoundError as e:
175208
# Moderation endpoint doesn't exist on this provider (e.g., third-party)
176209
# Fall back to the OpenAI client
@@ -179,11 +212,11 @@ async def moderation(
179212
e,
180213
)
181214
client = _get_moderation_client()
182-
resp = await _call_moderation_api(client, data)
215+
resp = await _call_moderation_api_async(client, data)
183216
else:
184217
# No context client, use fallback
185218
client = _get_moderation_client()
186-
resp = await _call_moderation_api(client, data)
219+
resp = await _call_moderation_api_async(client, data)
187220
results = resp.results or []
188221
if not results:
189222
return GuardrailResult(

tests/unit/checks/test_moderation.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,35 @@ async def raise_not_found(**_: Any) -> Any:
134134
# Verify the fallback client was used (not the third-party one)
135135
assert fallback_used is True # noqa: S101
136136
assert result.tripwire_triggered is False # noqa: S101
137+
138+
139+
@pytest.mark.asyncio
140+
async def test_moderation_uses_sync_context_client() -> None:
141+
"""Moderation should support synchronous OpenAI clients from context."""
142+
from openai import OpenAI
143+
144+
# Track whether sync context client was used
145+
sync_client_used = False
146+
147+
def track_sync_create(**_: Any) -> Any:
148+
nonlocal sync_client_used
149+
sync_client_used = True
150+
151+
class _Result:
152+
def model_dump(self) -> dict[str, Any]:
153+
return {"categories": {"hate": False, "violence": False}}
154+
155+
return SimpleNamespace(results=[_Result()])
156+
157+
# Create a sync context client
158+
sync_client = OpenAI(api_key="test-sync-key", base_url="https://api.openai.com/v1")
159+
sync_client.moderations = SimpleNamespace(create=track_sync_create) # type: ignore[assignment]
160+
161+
ctx = SimpleNamespace(guardrail_llm=sync_client)
162+
163+
cfg = ModerationCfg(categories=[Category.HATE, Category.VIOLENCE])
164+
result = await moderation(ctx, "test text", cfg)
165+
166+
# Verify the sync context client was used (via asyncio.to_thread)
167+
assert sync_client_used is True # noqa: S101
168+
assert result.tripwire_triggered is False # noqa: S101

0 commit comments

Comments (0)