Skip to content
12 changes: 11 additions & 1 deletion backend/app/alembic/versions/003_added_validator_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
def upgrade() -> None:
op.create_table(
"validator_config",
sa.Column("id", sa.Uuid(), nullable=False),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("organization_id", sa.Integer(), nullable=False),
sa.Column("project_id", sa.Integer(), nullable=False),
sa.Column("type", sa.String(), nullable=False),
Expand Down Expand Up @@ -52,7 +52,17 @@ def upgrade() -> None:
op.create_index("idx_validator_project", "validator_config", ["project_id"])
op.create_index("idx_validator_type", "validator_config", ["type"])
op.create_index("idx_validator_stage", "validator_config", ["stage"])
op.create_index(
"idx_validator_on_fail_action", "validator_config", ["on_fail_action"]
)
op.create_index("idx_validator_is_enabled", "validator_config", ["is_enabled"])


def downgrade() -> None:
op.drop_index("idx_validator_is_enabled", table_name="validator_config")
op.drop_index("idx_validator_on_fail_action", table_name="validator_config")
op.drop_index("idx_validator_stage", table_name="validator_config")
op.drop_index("idx_validator_type", table_name="validator_config")
op.drop_index("idx_validator_project", table_name="validator_config")
op.drop_index("idx_validator_organization", table_name="validator_config")
op.drop_table("validator_config")
37 changes: 33 additions & 4 deletions backend/app/api/routes/validator_configs.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Optional
from uuid import UUID

from fastapi import APIRouter

from app.api.deps import AuthDep, SessionDep
from app.core.enum import Stage, ValidatorType
from app.schemas.validator_config import (
ValidatorBatchCreate,
ValidatorBatchFetchItem,
ValidatorCreate,
ValidatorResponse,
ValidatorUpdate,
Expand Down Expand Up @@ -34,6 +35,20 @@ def create_validator(
return APIResponse.success_response(data=response_model)


@router.post("/batch", response_model=APIResponse[list[ValidatorResponse]])
def create_validators_batch(
payload: ValidatorBatchCreate,
session: SessionDep,
organization_id: int,
project_id: int,
_: AuthDep,
):
response_model = validator_config_crud.create_many(
session, organization_id, project_id, payload
)
return APIResponse.success_response(data=response_model)


@router.get("/", response_model=APIResponse[list[ValidatorResponse]])
def list_validators(
organization_id: int,
Expand All @@ -49,9 +64,23 @@ def list_validators(
return APIResponse.success_response(data=response_model)


@router.post("/batch/fetch", response_model=APIResponse[list[ValidatorResponse]])
def fetch_validators_batch(
payload: list[ValidatorBatchFetchItem],
organization_id: int,
project_id: int,
session: SessionDep,
_: AuthDep,
):
response_model = validator_config_crud.list_by_batch_items(
session, organization_id, project_id, payload
)
return APIResponse.success_response(data=response_model)


@router.get("/{id}", response_model=APIResponse[ValidatorResponse])
def get_validator(
id: UUID,
id: int,
organization_id: int,
project_id: int,
session: SessionDep,
Expand All @@ -63,7 +92,7 @@ def get_validator(

@router.patch("/{id}", response_model=APIResponse[ValidatorResponse])
def update_validator(
id: UUID,
id: int,
organization_id: int,
project_id: int,
payload: ValidatorUpdate,
Expand All @@ -79,7 +108,7 @@ def update_validator(

@router.delete("/{id}", response_model=APIResponse[dict])
def delete_validator(
id: UUID,
id: int,
organization_id: int,
project_id: int,
session: SessionDep,
Expand Down
93 changes: 87 additions & 6 deletions backend/app/crud/validator_config.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
from typing import Optional
from uuid import UUID
import logging
from typing import Any, List, Optional

from fastapi import HTTPException
from sqlalchemy.exc import IntegrityError
from sqlmodel import Session, select

from app.core.enum import Stage, ValidatorType
from app.models.config.validator_config import ValidatorConfig
from app.schemas.validator_config import ValidatorCreate
from app.schemas.validator_config import (
ValidatorBatchCreate,
ValidatorBatchFetchItem,
ValidatorCreate,
)
from app.utils import now, split_validator_payload

logger = logging.getLogger(__name__)


class ValidatorConfigCrud:
def create(
Expand Down Expand Up @@ -43,14 +49,57 @@ def create(
session.refresh(obj)
return self.flatten(obj)

def create_many(
self,
session: Session,
organization_id: int,
project_id: int,
payloads: ValidatorBatchCreate,
) -> list[dict]:
objs = []

try:
for payload in payloads.validators:
data = payload.model_dump()
model_fields, config_fields = split_validator_payload(data)
obj = ValidatorConfig(
organization_id=organization_id,
project_id=project_id,
config=config_fields,
**model_fields,
)
objs.append(obj)

session.add_all(objs)
except Exception:
session.rollback()
logger.exception("Failed to construct/add validator batch")
raise

try:
session.commit()
except IntegrityError:
session.rollback()
raise HTTPException(
400,
"Validator batch creation failed",
)
except Exception:
session.rollback()
raise

for obj in objs:
session.refresh(obj)
return [self.flatten(r) for r in objs]

def list(
self,
session: Session,
organization_id: int,
project_id: int,
stage: Optional[Stage] = None,
type: Optional[ValidatorType] = None,
) -> list[dict]:
) -> List[dict]:
query = select(ValidatorConfig).where(
ValidatorConfig.organization_id == organization_id,
ValidatorConfig.project_id == project_id,
Expand All @@ -65,10 +114,41 @@ def list(
rows = session.exec(query).all()
return [self.flatten(r) for r in rows]

def list_by_batch_items(
self,
session: Session,
organization_id: int,
project_id: int,
payload: List[ValidatorBatchFetchItem],
) -> List[dict]:
if not payload:
return []

ids = list({item.validator_config for item in payload})
query = select(ValidatorConfig).where(
ValidatorConfig.organization_id == organization_id,
ValidatorConfig.project_id == project_id,
ValidatorConfig.id.in_(ids),
)
rows = session.exec(query).all()

flattened_rows = {}
for row in rows:
row_type = row.type.value if hasattr(row.type, "value") else str(row.type)
flattened_rows[(row.id, row_type)] = self.flatten(row)

response: List[dict] = []
for item in payload:
maybe_row = flattened_rows.get((item.validator_config, item.validator_type))
if maybe_row:
response.append(maybe_row)

return response

def get(
self,
session: Session,
id: UUID,
id: int,
organization_id: int,
project_id: int,
) -> ValidatorConfig:
Expand Down Expand Up @@ -118,7 +198,8 @@ def delete(self, session: Session, obj: ValidatorConfig):

def flatten(self, row: ValidatorConfig) -> dict:
base = row.model_dump(exclude={"config"})
return {**base, **(row.config or {})}
config = row.config or {}
return {**base, **config}


validator_config_crud = ValidatorConfigCrud()
36 changes: 28 additions & 8 deletions backend/app/models/config/validator_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
class ValidatorConfig(SQLModel, table=True):
__tablename__ = "validator_config"

id: UUID = Field(
default_factory=uuid4,
id: int = Field(
primary_key=True,
sa_column_kwargs={
"comment": "Unique identifier for the validator configuration"
Expand All @@ -34,19 +33,40 @@ class ValidatorConfig(SQLModel, table=True):
)

type: ValidatorType = Field(
nullable=False,
sa_column_kwargs={"comment": "Type of the validator"},
sa_column=Column(
sa.Enum(
ValidatorType,
native_enum=False,
create_constraint=False,
),
nullable=False,
comment="Type of the validator",
),
)

stage: Stage = Field(
nullable=False,
sa_column_kwargs={"comment": "Stage at which the validator is applied"},
sa_column=Column(
sa.Enum(
Stage,
native_enum=False,
create_constraint=False,
),
nullable=False,
comment="Stage at which the validator is applied",
),
)

on_fail_action: GuardrailOnFail = Field(
default=GuardrailOnFail.Fix,
nullable=False,
sa_column_kwargs={"comment": "Action to take when the validator fails"},
sa_column=Column(
sa.Enum(
GuardrailOnFail,
native_enum=False,
create_constraint=False,
),
nullable=False,
comment="Action to take when the validator fails",
),
)

config: dict[str, Any] = SQLField(
Expand Down
48 changes: 47 additions & 1 deletion backend/app/schemas/guardrail_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Annotated, List, Optional, Union
from uuid import UUID

from pydantic import ConfigDict
from pydantic import ConfigDict, model_validator
from sqlmodel import Field, SQLModel

# todo this could be improved by having some auto-discovery mechanism inside
Expand Down Expand Up @@ -37,6 +37,52 @@ class GuardrailRequest(SQLModel):
input: str
validators: List[ValidatorConfigItem]

@model_validator(mode="before")
@classmethod
def normalize_validators_from_config_api(cls, data):
"""
Accept validator payloads coming from validator-config endpoints and
map them into runtime validator-config shape expected by Guardrails.
"""
if not isinstance(data, dict):
return data

validators = data.get("validators")
if not isinstance(validators, list):
return data

normalized_payload = dict(data)
normalized_validators = []

drop_fields = {
"id",
"organization_id",
"project_id",
"stage",
"is_enabled",
"created_at",
"updated_at",
}

for validator in validators:
if not isinstance(validator, dict):
normalized_validators.append(validator)
continue

normalized_validator = {
key: value
for key, value in validator.items()
if key not in drop_fields and key != "on_fail_action"
}

if "on_fail" not in normalized_validator and "on_fail_action" in validator:
normalized_validator["on_fail"] = validator["on_fail_action"]

normalized_validators.append(normalized_validator)

normalized_payload["validators"] = normalized_validators
return normalized_payload


class GuardrailResponse(SQLModel):
response_id: UUID
Expand Down
Loading