Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions .github/workflows/continuous_integration.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
name: Kaapi_Guardrail CI

on:
push:
branches: [main]
pull_request:
branches: [main]

jobs:
checks:
runs-on: ubuntu-latest
services:
postgres:
image: postgres:16
env:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: ai_platform_test
ports:
- 5432:5432
options: --health-cmd "pg_isready -U postgres" --health-interval 10s --health-timeout 5s --health-retries 5

strategy:
matrix:
python-version: ["3.12"]
redis-version: [6]

steps:
- uses: actions/checkout@v6

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v6
with:
python-version: ${{ matrix.python-version }}

- name: Making env file
run: |
cp .env.test.example .env
cp .env.test.example .env.test

- name: Install uv
uses: astral-sh/setup-uv@v6
with:
version: "0.4.15"
enable-cache: true

- name: Install dependencies
run: uv sync
working-directory: backend

- name: Activate virtual environment and run Alembic migrations
run: |
source .venv/bin/activate
alembic upgrade head
working-directory: backend

- name: Run pre-commit
run: |
source .venv/bin/activate
uv run pre-commit run --all-files
working-directory: backend

- name: Run tests
run: uv run bash scripts/tests-start.sh "Coverage for ${{ github.sha }}"
working-directory: backend

- name: Upload coverage reports to codecov
uses: codecov/[email protected]
with:
token: ${{ secrets.CODECOV_TOKEN }}
fail_ci_if_error: true

- name: Check coverage percentage
run: |
source .venv/bin/activate
coverage report --fail-under=70
working-directory: backend
15 changes: 15 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
- id: check-added-large-files
- id: check-toml
- id: check-yaml
args:
- --unsafe
- repo: https://github.com/psf/black
rev: 23.3.0
hooks:
- id: black
2 changes: 2 additions & 0 deletions backend/app/alembic/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@
# my_important_option = config.get_main_option("my_important_option")
# ... etc.


def get_url():
return str(settings.SQLALCHEMY_DATABASE_URI)


def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.

Expand Down
31 changes: 18 additions & 13 deletions backend/app/alembic/versions/001_added_request_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,32 @@


# revision identifiers, used by Alembic.
revision: str = '001'
revision: str = "001"
down_revision: str | None = None
branch_labels = None
depends_on = None


def upgrade() -> None:
op.create_table('request_log',
sa.Column('id', sa.Uuid(), nullable=False),
sa.Column('request_id', sa.Uuid(), nullable=False),
sa.Column('response_id', sa.Uuid(), nullable=True),
sa.Column('status', sa.Enum('PROCESSING','SUCCESS', 'ERROR', 'WARNING', name='requeststatus'), nullable=False, default='PROCESSING'),
sa.Column('request_text', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('response_text', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('inserted_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.PrimaryKeyConstraint('id')
op.create_table(
"request_log",
sa.Column("id", sa.Uuid(), nullable=False),
sa.Column("request_id", sa.Uuid(), nullable=False),
sa.Column("response_id", sa.Uuid(), nullable=True),
sa.Column(
"status",
sa.Enum("PROCESSING", "SUCCESS", "ERROR", "WARNING", name="requeststatus"),
nullable=False,
default="PROCESSING",
),
sa.Column("request_text", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("response_text", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column("inserted_at", sa.DateTime(), nullable=False),
sa.Column("updated_at", sa.DateTime(), nullable=False),
sa.PrimaryKeyConstraint("id"),
)


def downgrade() -> None:
op.drop_table('request_log')
op.drop_table("request_log")
# todo : drop requeststatus enum type

37 changes: 22 additions & 15 deletions backend/app/alembic/versions/002_added_validator_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,33 @@


# revision identifiers, used by Alembic.
revision: str = '002'
down_revision: str = '001'
revision: str = "002"
down_revision: str = "001"
branch_labels = None
depends_on = None


def upgrade() -> None:
op.create_table('validator_log',
sa.Column('id', sa.Uuid(), nullable=False),
sa.Column('request_id', sa.Uuid(), nullable=False),
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('input', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('output', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('error', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('outcome', sa.Enum('PASS', 'FAIL', name='validatoroutcome'), nullable=False),
sa.Column('inserted_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.ForeignKeyConstraint(['request_id'], ['request_log.id'], ),
sa.PrimaryKeyConstraint('id')
op.create_table(
"validator_log",
sa.Column("id", sa.Uuid(), nullable=False),
sa.Column("request_id", sa.Uuid(), nullable=False),
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("input", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("output", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column("error", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column(
"outcome", sa.Enum("PASS", "FAIL", name="validatoroutcome"), nullable=False
),
sa.Column("inserted_at", sa.DateTime(), nullable=False),
sa.Column("updated_at", sa.DateTime(), nullable=False),
sa.ForeignKeyConstraint(
["request_id"],
["request_log.id"],
),
sa.PrimaryKeyConstraint("id"),
)


def downgrade() -> None:
op.drop_table('validator_log')
op.drop_table("validator_log")
42 changes: 25 additions & 17 deletions backend/app/alembic/versions/003_added_validator_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,39 +12,47 @@
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision: str = '003'
down_revision: str = '002'
revision: str = "003"
down_revision: str = "002"
branch_labels = None
depends_on = None


def upgrade() -> None:
op.create_table('validator_config',
sa.Column('id', sa.Uuid(), nullable=False),
sa.Column('organization_id', sa.Integer(), nullable=False),
sa.Column('project_id', sa.Integer(), nullable=False),
sa.Column('type', sa.String(), nullable=False),
sa.Column('stage', sa.String(), nullable=False),
sa.Column('on_fail_action', sa.String(), nullable=False),
op.create_table(
"validator_config",
sa.Column("id", sa.Uuid(), nullable=False),
sa.Column("organization_id", sa.Integer(), nullable=False),
sa.Column("project_id", sa.Integer(), nullable=False),
sa.Column("type", sa.String(), nullable=False),
sa.Column("stage", sa.String(), nullable=False),
sa.Column("on_fail_action", sa.String(), nullable=False),
sa.Column(
"config",
postgresql.JSONB(astext_type=sa.Text()),
nullable=False,
server_default=sa.text("'{}'::jsonb"),
),
sa.Column('is_enabled', sa.Boolean(), nullable=False, server_default=sa.true()),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),

sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('organization_id', 'project_id', 'type', 'stage', name='uq_validator_identity')
sa.Column("is_enabled", sa.Boolean(), nullable=False, server_default=sa.true()),
sa.Column("created_at", sa.DateTime(), nullable=False),
sa.Column("updated_at", sa.DateTime(), nullable=False),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint(
"organization_id",
"project_id",
"type",
"stage",
name="uq_validator_identity",
),
)

op.create_index("idx_validator_organization", "validator_config", ["organization_id"])
op.create_index(
"idx_validator_organization", "validator_config", ["organization_id"]
)
op.create_index("idx_validator_project", "validator_config", ["project_id"])
op.create_index("idx_validator_type", "validator_config", ["type"])
op.create_index("idx_validator_stage", "validator_config", ["stage"])


def downgrade() -> None:
op.drop_table('validator_config')
op.drop_table("validator_config")
4 changes: 4 additions & 0 deletions backend/app/api/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@
from app.core.config import settings
from app.core.db import engine


def get_db() -> Generator[Session, None, None]:
with Session(engine) as session:
yield session


SessionDep = Annotated[Session, Depends(get_db)]
security = HTTPBearer(auto_error=False)


def verify_bearer_token(
credentials: Annotated[
HTTPAuthorizationCredentials | None,
Expand All @@ -37,4 +40,5 @@ def verify_bearer_token(

return True


AuthDep = Annotated[bool, Depends(verify_bearer_token)]
43 changes: 27 additions & 16 deletions backend/app/api/routes/guardrails.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,16 @@
from app.crud.request_log import RequestLogCrud
from app.crud.validator_log import ValidatorLogCrud
from app.schemas.guardrail_config import GuardrailRequest, GuardrailResponse
from app.models.logging.request_log import RequestLogUpdate, RequestStatus
from app.models.logging.request_log import RequestLogUpdate, RequestStatus
from app.models.logging.validator_log import ValidatorLog, ValidatorOutcome
from app.utils import APIResponse

router = APIRouter(prefix="/guardrails", tags=["guardrails"])


@router.post(
"/",
response_model=APIResponse[GuardrailResponse],
response_model_exclude_none=True)
"/", response_model=APIResponse[GuardrailResponse], response_model_exclude_none=True
)
async def run_guardrails(
payload: GuardrailRequest,
session: SessionDep,
Expand All @@ -35,16 +35,17 @@ async def run_guardrails(
except ValueError:
return APIResponse.failure_response(error="Invalid request_id")

request_log = request_log_crud.create(request_id, input_text=payload.input)
request_log = request_log_crud.create(request_id, input_text=payload.input)
return await _validate_with_guard(
payload.input,
payload.validators,
request_log_crud,
request_log.id,
validator_log_crud,
suppress_pass_logs
suppress_pass_logs,
)


@router.get("/")
async def list_validators(_: AuthDep):
"""
Expand All @@ -57,10 +58,12 @@ async def list_validators(_: AuthDep):
try:
schema = model.model_json_schema()
validator_type = schema["properties"]["type"]["const"]
validators.append({
"type": validator_type,
"config": schema,
})
validators.append(
{
"type": validator_type,
"config": schema,
}
)

except (KeyError, TypeError) as e:
return APIResponse.failure_response(
Expand All @@ -69,6 +72,7 @@ async def list_validators(_: AuthDep):

return {"validators": validators}


async def _validate_with_guard(
data: str,
validators: list,
Expand All @@ -84,7 +88,7 @@ async def _validate_with_guard(
This function treats validation failures as first-class outcomes (not exceptions),
while still safely handling unexpected runtime errors.
"""
response_id = uuid.uuid4()
response_id = uuid.uuid4()
guard: Guard | None = None

def _finalize(
Expand Down Expand Up @@ -115,11 +119,12 @@ def _finalize(
)

if guard is not None:
add_validator_logs(guard, request_log_id, validator_log_crud, suppress_pass_logs)
add_validator_logs(
guard, request_log_id, validator_log_crud, suppress_pass_logs
)

rephrase_needed = (
validated_output is not None
and validated_output.startswith(REPHRASE_ON_FAIL_PREFIX)
rephrase_needed = validated_output is not None and validated_output.startswith(
REPHRASE_ON_FAIL_PREFIX
)

response_model = GuardrailResponse(
Expand Down Expand Up @@ -160,7 +165,13 @@ def _finalize(
error_message=str(exc),
)

def add_validator_logs(guard: Guard, request_log_id: UUID, validator_log_crud: ValidatorLogCrud, suppress_pass_logs: bool = False):

def add_validator_logs(
guard: Guard,
request_log_id: UUID,
validator_log_crud: ValidatorLogCrud,
suppress_pass_logs: bool = False,
):
history = getattr(guard, "history", None)
if not history:
return
Expand Down
Loading