Skip to content

feat: Initial migration for Workspaces and pipeline step #600

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions migrations/versions/5c2f3eee5f90_introduce_workspaces.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""introduce workspaces

Revision ID: 5c2f3eee5f90
Revises: 30d0144e1a50
Create Date: 2025-01-15 19:27:08.230296

"""

from typing import Sequence, Union

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "5c2f3eee5f90"
down_revision: Union[str, None] = "30d0144e1a50"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Create the workspaces and sessions tables and attach prompts to a workspace."""
    # Workspaces table
    op.execute(
        """
        CREATE TABLE workspaces (
            id TEXT PRIMARY KEY,  -- UUID stored as TEXT
            name TEXT NOT NULL,
            UNIQUE (name)
        );
        """
    )
    # Seed the default workspace with the well-known id '1'; existing prompts
    # are backfilled into it below.
    op.execute("INSERT INTO workspaces (id, name) VALUES ('1', 'default');")
    # Sessions table: each session points at exactly one active workspace.
    op.execute(
        """
        CREATE TABLE sessions (
            id TEXT PRIMARY KEY,  -- UUID stored as TEXT
            active_workspace_id TEXT NOT NULL,
            last_update DATETIME NOT NULL,
            FOREIGN KEY (active_workspace_id) REFERENCES workspaces(id)
        );
        """
    )
    # Alter table prompts: add a (nullable) workspace_id referencing workspaces.
    op.execute("ALTER TABLE prompts ADD COLUMN workspace_id TEXT REFERENCES workspaces(id);")
    # Backfill every pre-existing prompt into the default workspace.
    op.execute("UPDATE prompts SET workspace_id = '1';")
    # Create index for prompts.workspace_id
    op.execute("CREATE INDEX idx_prompts_workspace_id ON prompts (workspace_id);")
    # Create index for sessions.active_workspace_id
    op.execute("CREATE INDEX idx_sessions_workspace_id ON sessions (active_workspace_id);")


def downgrade() -> None:
    """Revert the workspaces migration: drop the indexes, the prompts column, and both tables."""
    # Drop the index for workspace_id
    op.execute("DROP INDEX IF EXISTS idx_prompts_workspace_id;")
    op.execute("DROP INDEX IF EXISTS idx_sessions_workspace_id;")
    # Remove the workspace_id column from prompts table
    # NOTE(review): ALTER TABLE ... DROP COLUMN needs SQLite >= 3.35 — confirm
    # the minimum SQLite version this project supports.
    op.execute("ALTER TABLE prompts DROP COLUMN workspace_id;")
    # Drop the sessions table (first — it references workspaces)
    op.execute("DROP TABLE IF EXISTS sessions;")
    # Drop the workspaces table
    op.execute("DROP TABLE IF EXISTS workspaces;")
2 changes: 2 additions & 0 deletions src/codegate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
_VERSION = "dev"
_DESC = "CodeGate - A Generative AI security gateway."


def __get_version_and_description() -> tuple[str, str]:
try:
version = metadata.version("codegate")
Expand All @@ -19,6 +20,7 @@ def __get_version_and_description() -> tuple[str, str]:
description = _DESC
return version, description


__version__, __description__ = __get_version_and_description()

__all__ = ["Config", "ConfigurationError", "LogFormat", "LogLevel", "setup_logging"]
Expand Down
3 changes: 2 additions & 1 deletion src/codegate/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from codegate.ca.codegate_ca import CertificateAuthority
from codegate.codegate_logging import LogFormat, LogLevel, setup_logging
from codegate.config import Config, ConfigurationError
from codegate.db.connection import init_db_sync
from codegate.db.connection import init_db_sync, init_session_if_not_exists
from codegate.pipeline.factory import PipelineFactory
from codegate.pipeline.secrets.manager import SecretsManager
from codegate.providers.copilot.provider import CopilotProvider
Expand Down Expand Up @@ -307,6 +307,7 @@ def serve(
logger = structlog.get_logger("codegate").bind(origin="cli")

init_db_sync(cfg.db_path)
init_session_if_not_exists(cfg.db_path)

# Check certificates and create CA if necessary
logger.info("Checking certificates and creating CA if needed")
Expand Down
145 changes: 136 additions & 9 deletions src/codegate/db/connection.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,28 @@
import asyncio
import json
import uuid
from pathlib import Path
from typing import List, Optional, Type

import structlog
from alembic import command as alembic_command
from alembic.config import Config as AlembicConfig
from pydantic import BaseModel
from sqlalchemy import TextClause, text
from pydantic import BaseModel, ValidationError
from sqlalchemy import CursorResult, TextClause, text
from sqlalchemy.exc import OperationalError
from sqlalchemy.ext.asyncio import create_async_engine

from codegate.db.fim_cache import FimCache
from codegate.db.models import (
ActiveWorkspace,
Alert,
GetAlertsWithPromptAndOutputRow,
GetPromptWithOutputsRow,
Output,
Prompt,
Session,
Workspace,
WorkspaceActive,
)
from codegate.pipeline.base import PipelineContext

Expand Down Expand Up @@ -75,10 +80,14 @@ async def _execute_update_pydantic_model(
async def record_request(self, prompt_params: Optional[Prompt] = None) -> Optional[Prompt]:
if prompt_params is None:
return None
# Get the active workspace to store the request
active_workspace = await DbReader().get_active_workspace()
workspace_id = active_workspace.id if active_workspace else "1"
prompt_params.workspace_id = workspace_id
sql = text(
"""
INSERT INTO prompts (id, timestamp, provider, request, type)
VALUES (:id, :timestamp, :provider, :request, :type)
INSERT INTO prompts (id, timestamp, provider, request, type, workspace_id)
VALUES (:id, :timestamp, :provider, :request, :type, :workspace_id)
RETURNING *
"""
)
Expand Down Expand Up @@ -223,26 +232,78 @@ async def record_context(self, context: Optional[PipelineContext]) -> None:
except Exception as e:
logger.error(f"Failed to record context: {context}.", error=str(e))

async def add_workspace(self, workspace_name: str) -> Optional[Workspace]:
    """Insert a new workspace row named *workspace_name* and return it.

    The name is validated by the Workspace pydantic model; a validation
    failure is logged and None is returned instead of raising.
    """
    try:
        new_workspace = Workspace(id=str(uuid.uuid4()), name=workspace_name)
    except ValidationError as e:
        logger.error(f"Failed to create workspace with name: {workspace_name}: {str(e)}")
        return None

    insert_stmt = text(
        """
        INSERT INTO workspaces (id, name)
        VALUES (:id, :name)
        RETURNING *
        """
    )
    return await self._execute_update_pydantic_model(new_workspace, insert_stmt)

async def update_session(self, session: Session) -> Optional[Session]:
    """Upsert the given session row and return the stored record.

    An existing row with the same id has its active workspace and
    last_update timestamp overwritten.
    """
    upsert_stmt = text(
        """
        INSERT INTO sessions (id, active_workspace_id, last_update)
        VALUES (:id, :active_workspace_id, :last_update)
        ON CONFLICT (id) DO UPDATE SET
        active_workspace_id = excluded.active_workspace_id, last_update = excluded.last_update
        WHERE id = excluded.id
        RETURNING *
        """
    )
    # We only pass an object to respect the signature of the function
    return await self._execute_update_pydantic_model(session, upsert_stmt)


class DbReader(DbCodeGate):

def __init__(self, sqlite_path: Optional[str] = None):
    # All engine/connection setup lives in DbCodeGate; DbReader only adds
    # read-side helpers on top of it.
    super().__init__(sqlite_path)

async def _dump_result_to_pydantic_model(
    self, model_type: Type[BaseModel], result: CursorResult
) -> Optional[List[BaseModel]]:
    """Materialize every row of *result* as an instance of *model_type*.

    Returns None when the cursor is falsy or when mapping any row fails;
    failures are logged rather than raised.
    """
    try:
        if not result:
            return None
        models: List[BaseModel] = []
        for row in result.fetchall():
            if row:
                models.append(model_type(**row._asdict()))
        return models
    except Exception as e:
        logger.error(f"Failed to dump to pydantic model: {model_type}.", error=str(e))
        return None

async def _execute_select_pydantic_model(
    self, model_type: Type[BaseModel], sql_command: TextClause
) -> Optional[List[BaseModel]]:
    """Execute *sql_command* and map the result rows to *model_type* instances.

    Row mapping is delegated to _dump_result_to_pydantic_model so selects
    with and without bind parameters share one code path. Returns None
    (and logs) when execution fails.

    The block previously contained leftover duplicate lines from a merged
    diff (two return annotations and a dead old body); this is the cleaned,
    intended implementation.
    """
    async with self._async_db_engine.begin() as conn:
        try:
            result = await conn.execute(sql_command)
            return await self._dump_result_to_pydantic_model(model_type, result)
        except Exception as e:
            logger.error(f"Failed to select model: {model_type}.", error=str(e))
            return None

async def _exec_select_conditions_to_pydantic(
    self, model_type: Type[BaseModel], sql_command: TextClause, conditions: dict
) -> Optional[List[BaseModel]]:
    """Run a parameterized SELECT with *conditions* bound, mapping rows to *model_type*.

    Mirrors _execute_select_pydantic_model but forwards the bind-parameter
    dict to the driver. Returns None (and logs) when execution fails.
    """
    async with self._async_db_engine.begin() as conn:
        try:
            cursor = await conn.execute(sql_command, conditions)
            return await self._dump_result_to_pydantic_model(model_type, cursor)
        except Exception as e:
            logger.error(f"Failed to select model with conditions: {model_type}.", error=str(e))
            return None

async def get_prompts_with_output(self) -> List[GetPromptWithOutputsRow]:
sql = text(
"""
Expand Down Expand Up @@ -286,6 +347,54 @@ async def get_alerts_with_prompt_and_output(self) -> List[GetAlertsWithPromptAnd
prompts = await self._execute_select_pydantic_model(GetAlertsWithPromptAndOutputRow, sql)
return prompts

async def get_workspaces(self) -> List[WorkspaceActive]:
    """List all workspaces, each annotated with the session id that has it active.

    The LEFT JOIN means active_workspace_id is NULL for workspaces no
    session currently points at.
    """
    query = text(
        """
        SELECT
            w.id, w.name, s.active_workspace_id
        FROM workspaces w
        LEFT JOIN sessions s ON w.id = s.active_workspace_id
        """
    )
    return await self._execute_select_pydantic_model(WorkspaceActive, query)

async def get_workspace_by_name(self, name: str) -> Optional[Workspace]:
    """Fetch the workspace with the given (unique) name.

    Returns None when no workspace matches. The return annotation was
    previously List[Workspace], which did not match the actual returned
    value (`workspaces[0] if workspaces else None`).
    """
    sql = text(
        """
        SELECT
            id, name
        FROM workspaces
        WHERE name = :name
        """
    )
    conditions = {"name": name}
    workspaces = await self._exec_select_conditions_to_pydantic(Workspace, sql, conditions)
    # Name is UNIQUE in the schema, so at most one row comes back.
    return workspaces[0] if workspaces else None

async def get_sessions(self) -> List[Session]:
    """Return every row of the sessions table as Session models."""
    query = text(
        """
        SELECT
            id, active_workspace_id, last_update
        FROM sessions
        """
    )
    return await self._execute_select_pydantic_model(Session, query)

async def get_active_workspace(self) -> Optional[ActiveWorkspace]:
    """Return the workspace the current session points at, or None if no session exists."""
    query = text(
        """
        SELECT
            w.id, w.name, s.id as session_id, s.last_update
        FROM sessions s
        INNER JOIN workspaces w ON w.id = s.active_workspace_id
        """
    )
    rows = await self._execute_select_pydantic_model(ActiveWorkspace, query)
    if not rows:
        return None
    # NOTE(review): no LIMIT/ORDER BY — with more than one session the row
    # picked is arbitrary; today a single session is assumed. Confirm.
    return rows[0]


def init_db_sync(db_path: Optional[str] = None):
"""DB will be initialized in the constructor in case it doesn't exist."""
Expand All @@ -307,5 +416,23 @@ def init_db_sync(db_path: Optional[str] = None):
logger.info("DB initialized successfully.")


def init_session_if_not_exists(db_path: Optional[str] = None):
    """Ensure a session row exists, creating one bound to the default workspace.

    Called once at startup (see cli.serve). Runs the read and the optional
    write inside a single asyncio.run instead of the previous two calls, so
    only one event loop is created.
    """
    import datetime

    async def _ensure_session() -> bool:
        """Return True when a new session had to be created."""
        sessions = await DbReader(db_path).get_sessions()
        # TODO: For the moment there's a single session. If it already
        # exists, we don't create a new one.
        if sessions:
            return False
        session = Session(
            id=str(uuid.uuid4()),
            # '1' is the well-known id of the default workspace seeded by
            # the introduce-workspaces migration.
            active_workspace_id="1",
            last_update=datetime.datetime.now(datetime.timezone.utc),
        )
        await DbRecorder(db_path).update_session(session)
        return True

    if asyncio.run(_ensure_session()):
        logger.info("Session in DB initialized successfully.")


if __name__ == "__main__":
    # Allow running this module directly to (re)initialize the database.
    init_db_sync()
Loading
Loading