Skip to content

Commit a7cd9de

Browse files
feat: Let the user add their own system prompts
Related: #454 This PR is not ready yet. For the moment it adds the system prompts to DB and associates it to a workspace. It's missing to use the system prompt and actually send it to the LLM
1 parent 9f20ec0 commit a7cd9de

File tree

5 files changed

+82
-4
lines changed

5 files changed

+82
-4
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
"""add_workspace_system_prompt
2+
3+
Revision ID: a692c8b52308
4+
Revises: 5c2f3eee5f90
5+
Create Date: 2025-01-17 16:33:58.464223
6+
7+
"""
8+
from typing import Sequence, Union
9+
10+
from alembic import op
11+
import sqlalchemy as sa
12+
13+
14+
# revision identifiers, used by Alembic.
15+
revision: str = 'a692c8b52308'
16+
down_revision: Union[str, None] = '5c2f3eee5f90'
17+
branch_labels: Union[str, Sequence[str], None] = None
18+
depends_on: Union[str, Sequence[str], None] = None
19+
20+
21+
def upgrade() -> None:
22+
# Add column to workspaces table
23+
op.execute("ALTER TABLE workspaces ADD COLUMN system_prompt TEXT DEFAULT NULL;")
24+
25+
26+
def downgrade() -> None:
27+
op.execute("ALTER TABLE workspaces DROP COLUMN system_prompt;")

src/codegate/db/connection.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,20 @@ async def add_workspace(self, workspace_name: str) -> Optional[Workspace]:
272272
raise AlreadyExistsError(f"Workspace {workspace_name} already exists.")
273273
return added_workspace
274274

275+
async def update_workspace(self, workspace: Workspace) -> Optional[Workspace]:
276+
sql = text(
277+
"""
278+
UPDATE workspaces SET
279+
name = :name,
280+
system_prompt = :system_prompt
281+
WHERE id = :id
282+
RETURNING *
283+
"""
284+
)
285+
# We only pass an object to respect the signature of the function
286+
updated_workspace = await self._execute_update_pydantic_model(workspace, sql)
287+
return updated_workspace
288+
275289
async def update_session(self, session: Session) -> Optional[Session]:
276290
sql = text(
277291
"""
@@ -382,11 +396,11 @@ async def get_workspaces(self) -> List[WorkspaceActive]:
382396
workspaces = await self._execute_select_pydantic_model(WorkspaceActive, sql)
383397
return workspaces
384398

385-
async def get_workspace_by_name(self, name: str) -> List[Workspace]:
399+
async def get_workspace_by_name(self, name: str) -> Optional[Workspace]:
386400
sql = text(
387401
"""
388402
SELECT
389-
id, name
403+
id, name, system_prompt
390404
FROM workspaces
391405
WHERE name = :name
392406
"""

src/codegate/db/models.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,18 @@ class Setting(BaseModel):
4343
class Workspace(BaseModel):
4444
id: str
4545
name: str
46+
system_prompt: Optional[str]
4647

4748
@field_validator("name", mode="plain")
4849
@classmethod
49-
def name_must_be_alphanumeric(cls, value):
50+
def validate_name(cls, value):
5051
if not re.match(r"^[a-zA-Z0-9_-]+$", value):
5152
raise ValueError("name must be alphanumeric and can only contain _ and -")
53+
# Avoid workspace names that are the same as commands that way we can do stuff like
54+
# `codegate workspace list` and
55+
# `codegate workspace my-ws system-prompt` without any conflicts
56+
elif value in ["list", "add", "activate", "system-prompt"]:
57+
raise ValueError("name cannot be the same as a command")
5258
return value
5359

5460

src/codegate/pipeline/cli/commands.py

+14
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,18 @@ async def _activate_workspace(self, args: List[str]) -> str:
102102
)
103103
return f"Workspace **{workspace_name}** has been activated"
104104

105+
async def _add_system_prompt(self, workspace_name: str, sys_prompt_lst: List[str]):
106+
updated_worksapce = await self.workspace_crud.update_workspace_system_prompt(workspace_name, sys_prompt_lst)
107+
if not updated_worksapce:
108+
return (
109+
f"Workspace system prompt not updated. "
110+
f"Check if the workspace **{workspace_name}** exists"
111+
)
112+
return (
113+
f"Workspace **{updated_worksapce.name}** system prompt "
114+
f"updated to:\n\n```{updated_worksapce.system_prompt}```"
115+
)
116+
105117
async def run(self, args: List[str]) -> str:
106118
if not args:
107119
return "Please provide a command. Use `codegate workspace -h` to see available commands"
@@ -110,6 +122,8 @@ async def run(self, args: List[str]) -> str:
110122
if command_to_execute is not None:
111123
return await command_to_execute(args[1:])
112124
else:
125+
if len(args) >= 2 and args[1] == "system-prompt":
126+
return await self._add_system_prompt(args[0], args[2:])
113127
return "Command not found. Use `codegate workspace -h` to see available commands"
114128

115129
@property

src/codegate/workspaces/crud.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ async def _is_workspace_active_or_not_exist(
5353
sessions = await self._db_reader.get_sessions()
5454
# The current implementation expects only one active session
5555
if len(sessions) != 1:
56-
raise RuntimeError("Something went wrong. No active session found.")
56+
raise WorkspaceCrudError("Something went wrong. More than one session found.")
5757

5858
session = sessions[0]
5959
if session.active_workspace_id == selected_workspace.id:
@@ -77,3 +77,20 @@ async def activate_workspace(self, workspace_name: str) -> bool:
7777
db_recorder = DbRecorder()
7878
await db_recorder.update_session(session)
7979
return True
80+
81+
async def update_workspace_system_prompt(
82+
self, workspace_name: str, sys_prompt_lst: List[str]
83+
) -> Optional[Workspace]:
84+
selected_workspace = await self._db_reader.get_workspace_by_name(workspace_name)
85+
if not selected_workspace:
86+
return None
87+
88+
system_prompt = " ".join(sys_prompt_lst)
89+
workspace_update = Workspace(
90+
id=selected_workspace.id,
91+
name=selected_workspace.name,
92+
system_prompt=system_prompt,
93+
)
94+
db_recorder = DbRecorder()
95+
updated_workspace = await db_recorder.update_workspace(workspace_update)
96+
return updated_workspace

0 commit comments

Comments
 (0)