Skip to content
Merged
10 changes: 10 additions & 0 deletions backend/app/alembic/versions/003_added_validator_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,17 @@ def upgrade() -> None:
op.create_index("idx_validator_project", "validator_config", ["project_id"])
op.create_index("idx_validator_type", "validator_config", ["type"])
op.create_index("idx_validator_stage", "validator_config", ["stage"])
op.create_index(
"idx_validator_on_fail_action", "validator_config", ["on_fail_action"]
)
op.create_index("idx_validator_is_enabled", "validator_config", ["is_enabled"])


def downgrade() -> None:
op.drop_index("idx_validator_is_enabled", table_name="validator_config")
op.drop_index("idx_validator_on_fail_action", table_name="validator_config")
op.drop_index("idx_validator_stage", table_name="validator_config")
op.drop_index("idx_validator_type", table_name="validator_config")
op.drop_index("idx_validator_project", table_name="validator_config")
op.drop_index("idx_validator_organization", table_name="validator_config")
op.drop_table("validator_config")
5 changes: 3 additions & 2 deletions backend/app/api/routes/validator_configs.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Optional
from uuid import UUID

from fastapi import APIRouter
from fastapi import APIRouter, Query

from app.api.deps import AuthDep, SessionDep
from app.core.enum import Stage, ValidatorType
Expand Down Expand Up @@ -40,11 +40,12 @@ def list_validators(
project_id: int,
session: SessionDep,
_: AuthDep,
ids: Optional[list[UUID]] = Query(None),
stage: Optional[Stage] = None,
type: Optional[ValidatorType] = None,
):
response_model = validator_config_crud.list(
session, organization_id, project_id, stage, type
session, organization_id, project_id, ids, stage, type
)
return APIResponse.success_response(data=response_model)

Expand Down
14 changes: 11 additions & 3 deletions backend/app/crud/validator_config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Optional
import logging
from typing import List, Optional
from uuid import UUID

from fastapi import HTTPException
Expand All @@ -10,6 +11,8 @@
from app.schemas.validator_config import ValidatorCreate
from app.utils import now, split_validator_payload

logger = logging.getLogger(__name__)


class ValidatorConfigCrud:
def create(
Expand Down Expand Up @@ -48,14 +51,18 @@ def list(
session: Session,
organization_id: int,
project_id: int,
ids: Optional[list[UUID]] = None,
stage: Optional[Stage] = None,
type: Optional[ValidatorType] = None,
) -> list[dict]:
) -> List[dict]:
query = select(ValidatorConfig).where(
ValidatorConfig.organization_id == organization_id,
ValidatorConfig.project_id == project_id,
)

if ids:
query = query.where(ValidatorConfig.id.in_(ids))
Comment on lines +63 to +64
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

if ids: treats an empty list the same as None — potentially surprising.

if ids: is falsy for both None and []. If a caller passes an empty list of IDs (e.g., the config API has no validators configured), the filter is skipped and all validators for the org/project are returned instead of an empty set. Use an explicit None check:

Proposed fix
-        if ids:
+        if ids is not None:
             query = query.where(ValidatorConfig.id.in_(ids))
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if ids:
query = query.where(ValidatorConfig.id.in_(ids))
if ids is not None:
query = query.where(ValidatorConfig.id.in_(ids))
🤖 Prompt for AI Agents
In `@backend/app/crud/validator_config.py` around lines 63 - 64, The current
conditional "if ids:" treats an empty list the same as None and skips applying
the filter, returning all validators; change the check to an explicit None test
(replace the "if ids:" guard with "if ids is not None:") so that when ids is an
empty list the code still calls ValidatorConfig.id.in_(ids) on the query
variable and returns an empty result set as intended.


if stage:
query = query.where(ValidatorConfig.stage == stage)

Expand Down Expand Up @@ -118,7 +125,8 @@ def delete(self, session: Session, obj: ValidatorConfig):

def flatten(self, row: ValidatorConfig) -> dict:
base = row.model_dump(exclude={"config"})
return {**base, **(row.config or {})}
config = row.config or {}
return {**base, **config}


validator_config_crud = ValidatorConfigCrud()
48 changes: 47 additions & 1 deletion backend/app/schemas/guardrail_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Annotated, List, Optional, Union
from uuid import UUID

from pydantic import ConfigDict
from pydantic import ConfigDict, model_validator
from sqlmodel import Field, SQLModel

# todo this could be improved by having some auto-discovery mechanism inside
Expand Down Expand Up @@ -37,6 +37,52 @@ class GuardrailRequest(SQLModel):
input: str
validators: List[ValidatorConfigItem]

@model_validator(mode="before")
@classmethod
def normalize_validators_from_config_api(cls, data):
"""
Accept validator payloads coming from validator-config endpoints and
map them into runtime validator-config shape expected by Guardrails.
"""
if not isinstance(data, dict):
return data

validators = data.get("validators")
if not isinstance(validators, list):
return data

normalized_payload = dict(data)
normalized_validators = []

drop_fields = {
"id",
"organization_id",
"project_id",
"stage",
"is_enabled",
"created_at",
"updated_at",
}

for validator in validators:
if not isinstance(validator, dict):
normalized_validators.append(validator)
continue

normalized_validator = {
key: value
for key, value in validator.items()
if key not in drop_fields and key != "on_fail_action"
}

if "on_fail" not in normalized_validator and "on_fail_action" in validator:
normalized_validator["on_fail"] = validator["on_fail_action"]

normalized_validators.append(normalized_validator)

normalized_payload["validators"] = normalized_validators
return normalized_payload


class GuardrailResponse(SQLModel):
response_id: UUID
Expand Down
2 changes: 1 addition & 1 deletion backend/app/schemas/validator_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class ValidatorBase(SQLModel):

type: ValidatorType
stage: Stage
on_fail_action: GuardrailOnFail
on_fail_action: GuardrailOnFail = GuardrailOnFail.Fix
is_enabled: bool = True


Expand Down
56 changes: 56 additions & 0 deletions backend/app/tests/test_validator_configs_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,52 @@ def test_list_validators_filter_by_type(self, integration_client, clear_database
assert len(data) == 1
assert data[0]["type"] == "pii_remover"

def test_list_validators_filter_by_ids(self, integration_client, clear_database):
"""Test filtering validators by ids query parameter."""
first = self.create_validator(integration_client, "lexical_slur")
second = self.create_validator(integration_client, "pii_remover_input")
first_id = first.json()["data"]["id"]
second_id = second.json()["data"]["id"]

response = integration_client.get(
f"{BASE_URL}{DEFAULT_QUERY_PARAMS}&ids={first_id}",
)

assert response.status_code == 200
data = response.json()["data"]
assert len(data) == 1
assert data[0]["id"] == first_id
assert data[0]["id"] != second_id

def test_list_validators_filter_by_multiple_ids(
self, integration_client, clear_database
):
"""Test filtering validators by multiple ids query parameters."""
first = self.create_validator(integration_client, "lexical_slur")
second = self.create_validator(integration_client, "pii_remover_input")
first_id = first.json()["data"]["id"]
second_id = second.json()["data"]["id"]

response = integration_client.get(
f"{BASE_URL}{DEFAULT_QUERY_PARAMS}&ids={first_id}&ids={second_id}",
)

assert response.status_code == 200
data = response.json()["data"]
assert len(data) == 2
returned_ids = {item["id"] for item in data}
assert returned_ids == {first_id, second_id}

def test_list_validators_invalid_ids_query_returns_422(
self, integration_client, clear_database
):
"""Test invalid UUID in ids query returns validation error."""
response = integration_client.get(
f"{BASE_URL}{DEFAULT_QUERY_PARAMS}&ids=not-a-uuid",
)

assert response.status_code == 422

def test_list_validators_empty(self, integration_client, clear_database):
"""Test listing validators when none exist."""
response = integration_client.get(
Expand Down Expand Up @@ -203,6 +249,16 @@ def test_get_validator_not_found(self, integration_client, clear_database):

assert response.status_code == 404

def test_get_validator_invalid_id_returns_422(
self, integration_client, clear_database
):
"""Test invalid UUID path param returns validation error."""
response = integration_client.get(
f"{BASE_URL}not-a-uuid/{DEFAULT_QUERY_PARAMS}",
)

assert response.status_code == 422

def test_get_validator_wrong_org(self, integration_client, clear_database):
"""Test that accessing validator from different org returns 404."""
# Create a validator for org 1
Expand Down