Skip to content

Commit a836157

Browse files
Finished functionality to add wrkspace system prompt
1 parent dd9a65e commit a836157

File tree

8 files changed

+119
-61
lines changed

8 files changed

+119
-61
lines changed

migrations/versions/a692c8b52308_add_workspace_system_prompt.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,14 @@
55
Create Date: 2025-01-17 16:33:58.464223
66
77
"""
8+
89
from typing import Sequence, Union
910

1011
from alembic import op
11-
import sqlalchemy as sa
12-
1312

1413
# revision identifiers, used by Alembic.
15-
revision: str = 'a692c8b52308'
16-
down_revision: Union[str, None] = '5c2f3eee5f90'
14+
revision: str = "a692c8b52308"
15+
down_revision: Union[str, None] = "5c2f3eee5f90"
1716
branch_labels: Union[str, Sequence[str], None] = None
1817
depends_on: Union[str, Sequence[str], None] = None
1918

src/codegate/db/connection.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,7 @@ async def get_active_workspace(self) -> Optional[ActiveWorkspace]:
436436
sql = text(
437437
"""
438438
SELECT
439-
w.id, w.name, s.id as session_id, s.last_update
439+
w.id, w.name, w.system_prompt, s.id as session_id, s.last_update
440440
FROM sessions s
441441
INNER JOIN workspaces w ON w.id = s.active_workspace_id
442442
"""

src/codegate/db/models.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,6 @@ class Workspace(BaseModel):
5050
def validate_name(cls, value):
5151
if not re.match(r"^[a-zA-Z0-9_-]+$", value):
5252
raise ValueError("name must be alphanumeric and can only contain _ and -")
53-
# Avoid workspace names that are the same as commands that way we can do stuff like
54-
# `codegate workspace list` and
55-
# `codegate workspace my-ws system-prompt` without any conflicts
56-
elif value in ["list", "add", "activate", "system-prompt"]:
57-
raise ValueError("name cannot be the same as a command")
5853
return value
5954

6055

@@ -104,5 +99,6 @@ class WorkspaceActive(BaseModel):
10499
class ActiveWorkspace(BaseModel):
105100
id: str
106101
name: str
102+
system_prompt: Optional[str]
107103
session_id: str
108104
last_update: datetime.datetime

src/codegate/pipeline/cli/commands.py

+43-25
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@ async def run(self, args: List[str]) -> str:
3131
@property
3232
def help(self) -> str:
3333
return (
34-
"### CodeGate Version\n\n"
34+
"### CodeGate Version\n"
3535
"Prints the version of CodeGate.\n\n"
36+
"*args*: None\n\n"
3637
"**Usage**: `codegate version`\n\n"
37-
"*args*: None"
3838
)
3939

4040

@@ -46,6 +46,7 @@ def __init__(self):
4646
"list": self._list_workspaces,
4747
"add": self._add_workspace,
4848
"activate": self._activate_workspace,
49+
"system-prompt": self._add_system_prompt,
4950
}
5051

5152
async def _list_workspaces(self, *args: List[str]) -> str:
@@ -66,33 +67,33 @@ async def _add_workspace(self, args: List[str]) -> str:
6667
Add a workspace
6768
"""
6869
if args is None or len(args) == 0:
69-
return "Please provide a name. Use `codegate workspace add your_workspace_name`"
70+
return "Please provide a name. Use `codegate workspace add <workspace_name>`"
7071

7172
new_workspace_name = args[0]
7273
if not new_workspace_name:
73-
return "Please provide a name. Use `codegate workspace add your_workspace_name`"
74+
return "Please provide a name. Use `codegate workspace add <workspace_name>`"
7475

7576
try:
7677
_ = await self.workspace_crud.add_workspace(new_workspace_name)
7778
except ValidationError:
7879
return "Invalid workspace name: It should be alphanumeric and dashes"
7980
except AlreadyExistsError:
80-
return f"Workspace **{new_workspace_name}** already exists"
81+
return f"Workspace `{new_workspace_name}` already exists"
8182
except Exception:
8283
return "An error occurred while adding the workspace"
8384

84-
return f"Workspace **{new_workspace_name}** has been added"
85+
return f"Workspace `{new_workspace_name}` has been added"
8586

8687
async def _activate_workspace(self, args: List[str]) -> str:
8788
"""
8889
Activate a workspace
8990
"""
9091
if args is None or len(args) == 0:
91-
return "Please provide a name. Use `codegate workspace activate workspace_name`"
92+
return "Please provide a name. Use `codegate workspace activate <workspace_name>`"
9293

9394
workspace_name = args[0]
9495
if not workspace_name:
95-
return "Please provide a name. Use `codegate workspace activate workspace_name`"
96+
return "Please provide a name. Use `codegate workspace activate <workspace_name>`"
9697

9798
try:
9899
await self.workspace_crud.activate_workspace(workspace_name)
@@ -104,16 +105,27 @@ async def _activate_workspace(self, args: List[str]) -> str:
104105
return "An error occurred while activating the workspace"
105106
return f"Workspace **{workspace_name}** has been activated"
106107

107-
async def _add_system_prompt(self, workspace_name: str, sys_prompt_lst: List[str]):
108-
updated_worksapce = await self.workspace_crud.update_workspace_system_prompt(workspace_name, sys_prompt_lst)
108+
async def _add_system_prompt(self, args: List[str]):
109+
if len(args) < 2:
110+
return (
111+
"Please provide a workspace name and a system prompt. "
112+
"Use `codegate workspace system-prompt <workspace_name> <system_prompt>`"
113+
)
114+
115+
workspace_name = args[0]
116+
sys_prompt_lst = args[1:]
117+
118+
updated_worksapce = await self.workspace_crud.update_workspace_system_prompt(
119+
workspace_name, sys_prompt_lst
120+
)
109121
if not updated_worksapce:
110122
return (
111123
f"Workspace system prompt not updated. "
112-
f"Check if the workspace **{workspace_name}** exists"
124+
f"Check if the workspace `{workspace_name}` exists"
113125
)
114126
return (
115-
f"Workspace **{updated_worksapce.name}** system prompt "
116-
f"updated to:\n\n```{updated_worksapce.system_prompt}```"
127+
f"Workspace `{updated_worksapce.name}` system prompt "
128+
f"updated to:\n```\n{updated_worksapce.system_prompt}\n```"
117129
)
118130

119131
async def run(self, args: List[str]) -> str:
@@ -124,23 +136,29 @@ async def run(self, args: List[str]) -> str:
124136
if command_to_execute is not None:
125137
return await command_to_execute(args[1:])
126138
else:
127-
if len(args) >= 2 and args[1] == "system-prompt":
128-
return await self._add_system_prompt(args[0], args[2:])
129139
return "Command not found. Use `codegate workspace -h` to see available commands"
130140

131141
@property
132142
def help(self) -> str:
133143
return (
134-
"### CodeGate Workspace\n\n"
144+
"### CodeGate Workspace\n"
135145
"Manage workspaces.\n\n"
136146
"**Usage**: `codegate workspace <command> [args]`\n\n"
137-
"Available commands:\n\n"
138-
"- `list`: List all workspaces\n\n"
139-
" - *args*: None\n\n"
140-
"- `add`: Add a workspace\n\n"
141-
" - *args*:\n\n"
142-
" - `workspace_name`\n\n"
143-
"- `activate`: Activate a workspace\n\n"
144-
" - *args*:\n\n"
145-
" - `workspace_name`"
147+
"Available commands:\n"
148+
"- `list`: List all workspaces\n"
149+
" - *args*: None\n"
150+
" - **Usage**: `codegate workspace list`\n"
151+
"- `add`: Add a workspace\n"
152+
" - *args*:\n"
153+
" - `workspace_name`\n"
154+
" - **Usage**: `codegate workspace add <workspace_name>`\n"
155+
"- `activate`: Activate a workspace\n"
156+
" - *args*:\n"
157+
" - `workspace_name`\n"
158+
" - **Usage**: `codegate workspace activate <workspace_name>`\n"
159+
"- `system-prompt`: Modify the system-prompt of a workspace\n"
160+
" - *args*:\n"
161+
" - `workspace_name`\n"
162+
" - `system_prompt`\n"
163+
" - **Usage**: `codegate workspace system-prompt <workspace_name> <system_prompt>`\n"
146164
)
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import json
1+
from typing import Optional
22

33
from litellm import ChatCompletionRequest, ChatCompletionSystemMessage
44

@@ -7,6 +7,7 @@
77
PipelineResult,
88
PipelineStep,
99
)
10+
from codegate.workspaces.crud import WorkspaceCrud
1011

1112

1213
class SystemPrompt(PipelineStep):
@@ -16,7 +17,7 @@ class SystemPrompt(PipelineStep):
1617
"""
1718

1819
def __init__(self, system_prompt: str):
19-
self._system_message = ChatCompletionSystemMessage(content=system_prompt, role="system")
20+
self.codegate_system_prompt = system_prompt
2021

2122
@property
2223
def name(self) -> str:
@@ -25,6 +26,44 @@ def name(self) -> str:
2526
"""
2627
return "system-prompt"
2728

29+
async def _get_workspace_system_prompt(self) -> str:
30+
wksp_crud = WorkspaceCrud()
31+
workspace = await wksp_crud.get_active_workspace()
32+
if not workspace:
33+
return ""
34+
35+
return workspace.system_prompt
36+
37+
async def _construct_system_prompt(
38+
self,
39+
wrksp_sys_prompt: str,
40+
req_sys_prompt: Optional[str],
41+
should_add_codegate_sys_prompt: bool,
42+
) -> ChatCompletionSystemMessage:
43+
44+
def _start_or_append(existing_prompt: str, new_prompt: str) -> str:
45+
if existing_prompt:
46+
return existing_prompt + "\n\nHere are additional instructions:\n\n" + new_prompt
47+
return new_prompt
48+
49+
system_prompt = ""
50+
# Add codegate system prompt if secrets or bad packages are found at the beginning
51+
if should_add_codegate_sys_prompt:
52+
system_prompt = _start_or_append(system_prompt, self.codegate_system_prompt)
53+
54+
# Add workspace system prompt if present
55+
if wrksp_sys_prompt:
56+
system_prompt = _start_or_append(system_prompt, wrksp_sys_prompt)
57+
58+
# Add request system prompt if present
59+
if req_sys_prompt and "codegate" not in req_sys_prompt.lower():
60+
system_prompt = _start_or_append(system_prompt, req_sys_prompt)
61+
62+
return system_prompt
63+
64+
async def _should_add_codegate_system_prompt(self, context: PipelineContext) -> bool:
65+
return context.secrets_found or context.bad_packages_found
66+
2867
async def process(
2968
self, request: ChatCompletionRequest, context: PipelineContext
3069
) -> PipelineResult:
@@ -33,32 +72,35 @@ async def process(
3372
to the existing system prompt
3473
"""
3574

36-
# Nothing to do if no secrets or bad_packages are found
37-
if not (context.secrets_found or context.bad_packages_found):
75+
wrksp_sys_prompt = await self._get_workspace_system_prompt()
76+
should_add_codegate_sys_prompt = await self._should_add_codegate_system_prompt(context)
77+
78+
# Nothing to do if no secrets or bad_packages are found and we don't have a workspace
79+
# system prompt
80+
if not should_add_codegate_sys_prompt and not wrksp_sys_prompt:
3881
return PipelineResult(request=request, context=context)
3982

4083
new_request = request.copy()
4184

4285
if "messages" not in new_request:
4386
new_request["messages"] = []
4487

45-
request_system_message = None
88+
request_system_message = {}
4689
for message in new_request["messages"]:
4790
if message["role"] == "system":
4891
request_system_message = message
92+
req_sys_prompt = request_system_message.get("content")
4993

50-
if request_system_message is None:
51-
# Add system message
52-
context.add_alert(self.name, trigger_string=json.dumps(self._system_message))
53-
new_request["messages"].insert(0, self._system_message)
54-
elif "codegate" not in request_system_message["content"].lower():
55-
# Prepend to the system message
56-
prepended_message = (
57-
self._system_message["content"]
58-
+ "\n Here are additional instructions. \n "
59-
+ request_system_message["content"]
60-
)
61-
context.add_alert(self.name, trigger_string=prepended_message)
62-
request_system_message["content"] = prepended_message
94+
system_prompt = await self._construct_system_prompt(
95+
wrksp_sys_prompt, req_sys_prompt, should_add_codegate_sys_prompt
96+
)
97+
context.add_alert(self.name, trigger_string=system_prompt)
98+
if not request_system_message:
99+
# Insert the system prompt at the beginning of the messages
100+
sytem_message = ChatCompletionSystemMessage(content=system_prompt, role="system")
101+
new_request["messages"].insert(0, sytem_message)
102+
else:
103+
# Update the existing system prompt
104+
request_system_message["content"] = system_prompt
63105

64106
return PipelineResult(request=new_request, context=context)

src/codegate/workspaces/crud.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ async def activate_workspace(self, workspace_name: str):
8484
return
8585

8686
async def update_workspace_system_prompt(
87-
self, workspace_name: str, sys_prompt_lst: List[str]
88-
) -> Optional[Workspace]:
87+
self, workspace_name: str, sys_prompt_lst: List[str]
88+
) -> Optional[Workspace]:
8989
selected_workspace = await self._db_reader.get_workspace_by_name(workspace_name)
9090
if not selected_workspace:
9191
return None

tests/pipeline/system_prompt/test_system_prompt.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from unittest.mock import Mock
1+
from unittest.mock import AsyncMock, Mock
22

33
import pytest
44
from litellm.types.llms.openai import ChatCompletionRequest
@@ -14,7 +14,7 @@ def test_init_with_system_message(self):
1414
"""
1515
test_message = "Test system prompt"
1616
step = SystemPrompt(system_prompt=test_message)
17-
assert step._system_message["content"] == test_message
17+
assert step.codegate_system_prompt == test_message
1818

1919
@pytest.mark.asyncio
2020
async def test_process_system_prompt_insertion(self):
@@ -29,6 +29,7 @@ async def test_process_system_prompt_insertion(self):
2929
# Create system prompt step
3030
system_prompt = "Security analysis system prompt"
3131
step = SystemPrompt(system_prompt=system_prompt)
32+
step._get_workspace_system_prompt = AsyncMock(return_value="")
3233

3334
# Mock the get_last_user_message method
3435
step.get_last_user_message = Mock(return_value=(user_message, 0))
@@ -62,6 +63,7 @@ async def test_process_system_prompt_update(self):
6263
# Create system prompt step
6364
system_prompt = "Security analysis system prompt"
6465
step = SystemPrompt(system_prompt=system_prompt)
66+
step._get_workspace_system_prompt = AsyncMock(return_value="")
6567

6668
# Mock the get_last_user_message method
6769
step.get_last_user_message = Mock(return_value=(user_message, 0))
@@ -74,7 +76,7 @@ async def test_process_system_prompt_update(self):
7476
assert result.request["messages"][0]["role"] == "system"
7577
assert (
7678
result.request["messages"][0]["content"]
77-
== system_prompt + "\n Here are additional instructions. \n " + request_system_message
79+
== system_prompt + "\n\nHere are additional instructions:\n\n" + request_system_message
7880
)
7981
assert result.request["messages"][1]["role"] == "user"
8082
assert result.request["messages"][1]["content"] == user_message
@@ -96,6 +98,7 @@ async def test_edge_cases(self, edge_case):
9698

9799
system_prompt = "Security edge case prompt"
98100
step = SystemPrompt(system_prompt=system_prompt)
101+
step._get_workspace_system_prompt = AsyncMock(return_value="")
99102

100103
# Mock get_last_user_message to return None
101104
step.get_last_user_message = Mock(return_value=None)

tests/pipeline/workspace/test_workspace.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,11 @@ async def test_list_workspaces(mock_workspaces, expected_output):
5454
"args, existing_workspaces, expected_message",
5555
[
5656
# Case 1: No workspace name provided
57-
([], [], "Please provide a name. Use `codegate workspace add your_workspace_name`"),
57+
([], [], "Please provide a name. Use `codegate workspace add <workspace_name>`"),
5858
# Case 2: Workspace name is empty string
59-
([""], [], "Please provide a name. Use `codegate workspace add your_workspace_name`"),
59+
([""], [], "Please provide a name. Use `codegate workspace add <workspace_name>`"),
6060
# Case 3: Successful add
61-
(["myworkspace"], [], "Workspace **myworkspace** has been added"),
61+
(["myworkspace"], [], "Workspace `myworkspace` has been added"),
6262
],
6363
)
6464
async def test_add_workspaces(args, existing_workspaces, expected_message):

0 commit comments

Comments
 (0)