Skip to content

Commit 7d9ae48

Browse files
authored
feat: add conversation id header to stored interactions (#91)
1 parent 698e37c commit 7d9ae48

File tree

10 files changed

+194
-424
lines changed

10 files changed

+194
-424
lines changed

python/pyproject.toml

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,7 @@ dependencies = [
4343
"pydantic-settings>=2.0.0",
4444
"pydantic>=2.0.0",
4545
"pyinstrument>=5.1.1",
46-
"pytest-asyncio>=1.0.0",
47-
"pytest>=8.4.1",
46+
4847
"python-dotenv>=1.0.0",
4948
"python-multipart>=0.0.6",
5049
"structlog>=24.0.0",
@@ -55,20 +54,13 @@ dependencies = [
5554
"xai_sdk>=1.3.1",
5655
]
5756

58-
[project.optional-dependencies]
57+
[dependency-groups]
5958
dev = [
60-
"pytest>=8.0.0",
61-
"pytest-asyncio>=0.23.0",
62-
"pytest-cov>=5.0.0",
63-
"pytest-benchmark>=4.0.0",
64-
"pytest-mock>=3.0.0",
65-
"black>=24.0.0",
66-
"ruff>=0.4.0",
67-
"mypy>=1.0.0",
68-
"types-toml>=0.10.0",
69-
"pre-commit>=3.0.0",
70-
"testcontainers[postgres]>=4.0.0",
7159
"nest-asyncio>=1.6.0",
60+
"ty>=0.0.1a15",
61+
"testcontainers[postgres]>=4.13.0",
62+
"pytest-asyncio>=1.0.0",
63+
"pytest>=8.4.1",
7264
]
7365

7466
[project.scripts]
@@ -163,6 +155,3 @@ exclude_lines = [
163155
"raise NotImplementedError",
164156
"if TYPE_CHECKING:",
165157
]
166-
167-
[dependency-groups]
168-
dev = ["nest-asyncio>=1.6.0", "ty>=0.0.1a15"]

python/src/cairo_coder/db/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class UserInteraction(BaseModel):
2020
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
2121
agent_id: str
2222
mcp_mode: bool = False
23+
conversation_id: Optional[str] = None
2324
chat_history: Optional[list[dict[str, Any]]] = None
2425
query: str
2526
generated_answer: Optional[str] = None

python/src/cairo_coder/db/repository.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,17 +83,19 @@ async def create_user_interaction(interaction: UserInteraction) -> None:
8383
id,
8484
agent_id,
8585
mcp_mode,
86+
conversation_id,
8687
chat_history,
8788
query,
8889
generated_answer,
8990
retrieved_sources,
9091
llm_usage
9192
)
92-
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
93+
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
9394
""",
9495
interaction.id,
9596
interaction.agent_id,
9697
interaction.mcp_mode,
98+
interaction.conversation_id,
9799
_serialize_json_field(interaction.chat_history),
98100
interaction.query,
99101
interaction.generated_answer,
@@ -112,6 +114,7 @@ async def get_interactions(
112114
limit: int,
113115
offset: int,
114116
query_text: str | None = None,
117+
conversation_id: str | None = None,
115118
) -> tuple[list[dict[str, Any]], int]:
116119
"""Fetch paginated interactions matching the supplied filters.
117120
@@ -139,6 +142,10 @@ async def get_interactions(
139142
params.append(f"%{query_text}%")
140143
filters.append(f"query ILIKE ${len(params)}")
141144

145+
if conversation_id:
146+
params.append(conversation_id)
147+
filters.append(f"conversation_id = ${len(params)}")
148+
142149
where_clause = "WHERE " + " AND ".join(filters) if filters else ""
143150

144151
count_query = f"""
@@ -152,7 +159,7 @@ async def get_interactions(
152159
limit_placeholder = len(params) - 1
153160
offset_placeholder = len(params)
154161
data_query = f"""
155-
SELECT id, created_at, agent_id, query, chat_history, generated_answer
162+
SELECT id, created_at, agent_id, query, chat_history, generated_answer, conversation_id
156163
FROM user_interactions
157164
{where_clause}
158165
ORDER BY created_at DESC
@@ -192,17 +199,19 @@ async def migrate_user_interaction(interaction: UserInteraction) -> tuple[bool,
192199
created_at,
193200
agent_id,
194201
mcp_mode,
202+
conversation_id,
195203
chat_history,
196204
query,
197205
generated_answer,
198206
retrieved_sources,
199207
llm_usage
200208
)
201-
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
209+
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
202210
ON CONFLICT (id) DO UPDATE SET
203211
created_at = EXCLUDED.created_at,
204212
agent_id = EXCLUDED.agent_id,
205213
mcp_mode = EXCLUDED.mcp_mode,
214+
conversation_id = EXCLUDED.conversation_id,
206215
chat_history = EXCLUDED.chat_history,
207216
query = EXCLUDED.query,
208217
generated_answer = EXCLUDED.generated_answer,
@@ -214,6 +223,7 @@ async def migrate_user_interaction(interaction: UserInteraction) -> tuple[bool,
214223
interaction.created_at,
215224
interaction.agent_id,
216225
interaction.mcp_mode,
226+
interaction.conversation_id,
217227
_serialize_json_field(interaction.chat_history),
218228
interaction.query,
219229
interaction.generated_answer,

python/src/cairo_coder/db/session.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ async def execute_schema_scripts() -> None:
7171
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
7272
agent_id VARCHAR(50) NOT NULL,
7373
mcp_mode BOOLEAN NOT NULL DEFAULT FALSE,
74+
conversation_id VARCHAR(100),
7475
chat_history JSONB,
7576
query TEXT NOT NULL,
7677
generated_answer TEXT,
@@ -79,12 +80,21 @@ async def execute_schema_scripts() -> None:
7980
);
8081
"""
8182
)
83+
# Migration: add conversation_id column if it doesn't exist (for existing tables)
84+
await connection.execute(
85+
"""
86+
ALTER TABLE user_interactions
87+
ADD COLUMN IF NOT EXISTS conversation_id VARCHAR(100);
88+
"""
89+
)
8290
await connection.execute(
8391
"""
8492
CREATE INDEX IF NOT EXISTS idx_interactions_created_at
8593
ON user_interactions(created_at);
8694
CREATE INDEX IF NOT EXISTS idx_interactions_agent_id
8795
ON user_interactions(agent_id);
96+
CREATE INDEX IF NOT EXISTS idx_interactions_conversation_id
97+
ON user_interactions(conversation_id);
8898
"""
8999
)
90100
logger.info("Database schema initialized.")

python/src/cairo_coder/server/app.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ async def log_interaction_task(
151151
chat_history: list[Message],
152152
response: ChatCompletionResponse,
153153
agent: RagPipeline,
154+
conversation_id: str | None = None,
154155
) -> None:
155156
"""Background task that persists a user interaction."""
156157
sources_data = [
@@ -167,6 +168,7 @@ async def log_interaction_task(
167168
interaction = UserInteraction(
168169
agent_id=agent_id,
169170
mcp_mode=mcp_mode,
171+
conversation_id=conversation_id,
170172
chat_history=chat_history_dicts,
171173
query=query,
172174
generated_answer=response.choices[0].message.content if response.choices else None,
@@ -183,6 +185,7 @@ async def log_interaction_raw(
183185
chat_history: list[Message],
184186
generated_answer: str | None,
185187
agent: RagPipeline,
188+
conversation_id: str | None = None,
186189
) -> None:
187190
"""Persist a user interaction without constructing a full response object."""
188191
sources_data = [
@@ -198,6 +201,7 @@ async def log_interaction_raw(
198201
interaction = UserInteraction(
199202
agent_id=agent_id,
200203
mcp_mode=mcp_mode,
204+
conversation_id=conversation_id,
201205
chat_history=chat_history_dicts,
202206
query=query,
203207
generated_answer=generated_answer,
@@ -423,6 +427,9 @@ async def _handle_chat_completion(
423427
vector_db: SourceFilteredPgVectorRM | None = None,
424428
):
425429
"""Handle chat completion request."""
430+
# Extract conversation ID from header
431+
conversation_id = req.headers.get("x-conversation-id")
432+
426433
# Convert messages to internal format
427434
messages = []
428435
for msg in request.messages:
@@ -443,7 +450,9 @@ async def _handle_chat_completion(
443450
# Handle streaming vs non-streaming
444451
if request.stream:
445452
return StreamingResponse(
446-
self._stream_chat_completion(agent, query, messages[:-1], mcp_mode, effective_agent_id),
453+
self._stream_chat_completion(
454+
agent, query, messages[:-1], mcp_mode, effective_agent_id, conversation_id
455+
),
447456
media_type="text/event-stream",
448457
headers={
449458
"Cache-Control": "no-cache",
@@ -462,12 +471,19 @@ async def _handle_chat_completion(
462471
chat_history=chat_history,
463472
response=response,
464473
agent=agent,
474+
conversation_id=conversation_id,
465475
)
466476

467477
return response
468478

469479
async def _stream_chat_completion(
470-
self, agent: RagPipeline, query: str, history: list[Message], mcp_mode: bool, agent_id: str
480+
self,
481+
agent: RagPipeline,
482+
query: str,
483+
history: list[Message],
484+
mcp_mode: bool,
485+
agent_id: str,
486+
conversation_id: str | None = None,
471487
) -> AsyncGenerator[str, None]:
472488
"""Stream chat completion response - replicates TypeScript streaming."""
473489
response_id = str(uuid.uuid4())
@@ -580,6 +596,7 @@ async def _stream_chat_completion(
580596
chat_history=history,
581597
generated_answer=final_response,
582598
agent=agent,
599+
conversation_id=conversation_id,
583600
)
584601
except Exception as log_error:
585602
logger.error("Failed to log streaming interaction", error=str(log_error), exc_info=True)

python/src/cairo_coder/server/insights_api.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ class QueryResponse(BaseModel):
2828
query: str
2929
chat_history: list[dict[str, Any]]
3030
output: str | None
31+
conversation_id: str | None = None
3132

3233

3334
class PaginatedQueryResponse(BaseModel):
@@ -45,16 +46,19 @@ async def get_raw_queries(
4546
end_date: datetime | None = None,
4647
agent_id: str | None = None,
4748
query_text: str | None = None,
49+
conversation_id: str | None = None,
4850
limit: int = 100,
4951
offset: int = 0,
5052
) -> PaginatedQueryResponse:
5153
"""Return raw user queries.
5254
5355
If start_date and end_date are not provided, returns the last N queries
5456
ordered by creation time (where N is the limit parameter).
57+
58+
Use conversation_id to filter queries belonging to a specific conversation.
5559
"""
5660
items, total = await get_interactions(
57-
start_date, end_date, agent_id, limit, offset, query_text
61+
start_date, end_date, agent_id, limit, offset, query_text, conversation_id
5862
)
5963
# Map generated_answer to output for API response
6064
responses = [
@@ -65,6 +69,7 @@ async def get_raw_queries(
6569
query=item["query"],
6670
chat_history=item["chat_history"] or [],
6771
output=item.get("generated_answer"),
72+
conversation_id=item.get("conversation_id"),
6873
)
6974
for item in items
7075
]

python/tests/integration/conftest.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ async def test_db_pool(postgres_container):
182182
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
183183
agent_id VARCHAR(50) NOT NULL,
184184
mcp_mode BOOLEAN NOT NULL DEFAULT FALSE,
185+
conversation_id VARCHAR(100),
185186
chat_history JSONB,
186187
query TEXT NOT NULL,
187188
generated_answer TEXT,
@@ -196,6 +197,8 @@ async def test_db_pool(postgres_container):
196197
ON user_interactions(created_at);
197198
CREATE INDEX IF NOT EXISTS idx_interactions_agent_id
198199
ON user_interactions(agent_id);
200+
CREATE INDEX IF NOT EXISTS idx_interactions_conversation_id
201+
ON user_interactions(conversation_id);
199202
"""
200203
)
201204

python/tests/integration/test_insights_api.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,72 @@ def test_get_queries_without_dates_with_filters(self, client, populated_db_conne
155155
assert data["total"] >= 1
156156
assert all(item["agent_id"] == "cairo-coder" for item in data["items"])
157157

158+
def test_get_queries_with_conversation_id_filter(self, client, db_connection):
159+
"""Test that queries can be filtered by conversation_id."""
160+
import asyncio
161+
import json as _json
162+
import uuid
163+
from datetime import datetime, timedelta, timezone
164+
165+
now = datetime.now(timezone.utc)
166+
conv_id = "test-conversation-123"
167+
168+
# Seed records with and without conversation_id
169+
async def seed():
170+
await db_connection.execute(
171+
"""
172+
INSERT INTO user_interactions (id, created_at, agent_id, mcp_mode, conversation_id, chat_history, query, generated_answer)
173+
VALUES ($1, $2, $3, $4, $5, $6, $7, $8),
174+
($9, $10, $11, $12, $13, $14, $15, $16),
175+
($17, $18, $19, $20, $21, $22, $23, $24)
176+
""",
177+
uuid.uuid4(), now - timedelta(hours=2), "cairo-coder", False, conv_id, _json.dumps([]), "First msg", "Response 1",
178+
uuid.uuid4(), now - timedelta(hours=1), "cairo-coder", False, conv_id, _json.dumps([{"role": "user", "content": "First msg"}]), "Second msg", "Response 2",
179+
uuid.uuid4(), now - timedelta(minutes=30), "cairo-coder", False, None, _json.dumps([]), "Other msg", "Other response",
180+
)
181+
182+
asyncio.get_event_loop().run_until_complete(seed())
183+
184+
# Filter by conversation_id
185+
resp = client.get(
186+
"/v1/insights/queries",
187+
params={"conversation_id": conv_id, "limit": 100, "offset": 0},
188+
)
189+
assert resp.status_code == 200
190+
data = resp.json()
191+
assert data["total"] == 2
192+
assert all(item["conversation_id"] == conv_id for item in data["items"])
193+
194+
def test_get_queries_returns_conversation_id(self, client, db_connection):
195+
"""Test that conversation_id is included in the response."""
196+
import asyncio
197+
import json as _json
198+
import uuid
199+
from datetime import datetime, timezone
200+
201+
now = datetime.now(timezone.utc)
202+
conv_id = "response-test-conv-456"
203+
204+
async def seed():
205+
await db_connection.execute(
206+
"""
207+
INSERT INTO user_interactions (id, created_at, agent_id, mcp_mode, conversation_id, chat_history, query, generated_answer)
208+
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
209+
""",
210+
uuid.uuid4(), now, "cairo-coder", False, conv_id, _json.dumps([]), "Test query", "Test response",
211+
)
212+
213+
asyncio.get_event_loop().run_until_complete(seed())
214+
215+
resp = client.get("/v1/insights/queries", params={"limit": 100, "offset": 0})
216+
assert resp.status_code == 200
217+
data = resp.json()
218+
219+
# Find our seeded record and verify conversation_id is present
220+
matching = [item for item in data["items"] if item.get("conversation_id") == conv_id]
221+
assert len(matching) == 1
222+
assert matching[0]["conversation_id"] == conv_id
223+
158224

159225
class TestDataIngestion:
160226
async def test_chat_completion_logs_interaction_to_db(self, client, test_db_pool):

0 commit comments

Comments
 (0)