Skip to content

Commit 2ba38cb

Browse files
committed
Remove Azure
1 parent fb5aee5 commit 2ba38cb

File tree

2 files changed

+63
-12
lines changed

2 files changed

+63
-12
lines changed

src/guardrails/checks/text/moderation.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -137,10 +137,7 @@ def _get_moderation_client() -> AsyncOpenAI:
137137
return AsyncOpenAI()
138138

139139

140-
async def _call_moderation_api_async(
141-
client: AsyncOpenAI | AsyncAzureOpenAI,
142-
data: str, # type: ignore
143-
) -> Any:
140+
async def _call_moderation_api_async(client: Any, data: str) -> Any:
144141
"""Call the OpenAI moderation API asynchronously.
145142
146143
Args:
@@ -156,7 +153,7 @@ async def _call_moderation_api_async(
156153
)
157154

158155

159-
def _call_moderation_api_sync(client: OpenAI | AzureOpenAI, data: str) -> Any: # type: ignore
156+
def _call_moderation_api_sync(client: Any, data: str) -> Any:
160157
"""Call the OpenAI moderation API synchronously.
161158
162159
Args:
@@ -191,30 +188,30 @@ async def moderation(
191188
Returns:
192189
GuardrailResult: Indicates if tripwire was triggered, and details of flagged categories.
193190
"""
194-
# Try the context client first, fall back if moderation endpoint doesn't exist
191+
# Try context client first (if provided), fall back on 404
195192
client = getattr(ctx, "guardrail_llm", None) if ctx is not None else None
196193

197194
if client is not None:
198-
# Determine if client is async or sync and call appropriately
199-
is_async_client = isinstance(client, AsyncOpenAI | AsyncAzureOpenAI)
195+
# Determine if client is async or sync
196+
is_async = isinstance(client, AsyncOpenAI)
200197

201198
try:
202-
if is_async_client:
199+
if is_async:
203200
resp = await _call_moderation_api_async(client, data)
204201
else:
205202
# Sync client - run in thread pool to avoid blocking event loop
206203
resp = await asyncio.to_thread(_call_moderation_api_sync, client, data)
207204
except NotFoundError as e:
208-
# Moderation endpoint doesn't exist on this provider (e.g., third-party)
209-
# Fall back to the OpenAI client
205+
# Moderation endpoint doesn't exist (e.g., Azure, third-party)
206+
# Fall back to OpenAI client with OPENAI_API_KEY env var
210207
logger.debug(
211208
"Moderation endpoint not available on context client, falling back to OpenAI: %s",
212209
e,
213210
)
214211
client = _get_moderation_client()
215212
resp = await _call_moderation_api_async(client, data)
216213
else:
217-
# No context client, use fallback
214+
# No context client - use fallback OpenAI client
218215
client = _get_moderation_client()
219216
resp = await _call_moderation_api_async(client, data)
220217
results = resp.results or []

tests/unit/checks/test_moderation.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,3 +166,57 @@ def model_dump(self) -> dict[str, Any]:
166166
# Verify the sync context client was used (via asyncio.to_thread)
167167
assert sync_client_used is True # noqa: S101
168168
assert result.tripwire_triggered is False # noqa: S101
169+
170+
171+
@pytest.mark.asyncio
172+
async def test_moderation_falls_back_for_azure_clients(monkeypatch: pytest.MonkeyPatch) -> None:
173+
"""Moderation should fall back to OpenAI client for Azure clients (no moderation endpoint)."""
174+
try:
175+
from openai import AsyncAzureOpenAI, NotFoundError
176+
except ImportError:
177+
pytest.skip("Azure OpenAI not available")
178+
179+
# Track whether fallback was used
180+
fallback_used = False
181+
182+
async def track_fallback_create(**_: Any) -> Any:
183+
nonlocal fallback_used
184+
fallback_used = True
185+
186+
class _Result:
187+
def model_dump(self) -> dict[str, Any]:
188+
return {"categories": {"hate": False, "violence": False}}
189+
190+
return SimpleNamespace(results=[_Result()])
191+
192+
# Mock the fallback client
193+
fallback_client = SimpleNamespace(moderations=SimpleNamespace(create=track_fallback_create))
194+
monkeypatch.setattr("guardrails.checks.text.moderation._get_moderation_client", lambda: fallback_client)
195+
196+
# Create a mock httpx.Response for NotFoundError
197+
mock_response = SimpleNamespace(
198+
status_code=404,
199+
headers={},
200+
text="404 page not found",
201+
json=lambda: {"error": {"message": "Not found", "type": "invalid_request_error"}},
202+
)
203+
204+
# Create an Azure context client that raises NotFoundError for moderation
205+
async def raise_not_found(**_: Any) -> Any:
206+
raise NotFoundError("404 page not found", response=mock_response, body=None) # type: ignore[arg-type]
207+
208+
azure_client = AsyncAzureOpenAI(
209+
api_key="test-azure-key",
210+
api_version="2024-02-01",
211+
azure_endpoint="https://test.openai.azure.com",
212+
)
213+
azure_client.moderations = SimpleNamespace(create=raise_not_found) # type: ignore[assignment]
214+
215+
ctx = SimpleNamespace(guardrail_llm=azure_client)
216+
217+
cfg = ModerationCfg(categories=[Category.HATE, Category.VIOLENCE])
218+
result = await moderation(ctx, "test text", cfg)
219+
220+
# Verify the fallback client was used (not the Azure one)
221+
assert fallback_used is True # noqa: S101
222+
assert result.tripwire_triggered is False # noqa: S101

0 commit comments

Comments
 (0)