-
Notifications
You must be signed in to change notification settings - Fork 10
Dalgo Migration: Add threads endpoints (thread creation+polling) #171
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 16 commits
Commits
Show all changes
25 commits
Select commit
Hold shift + click to select a range
df5fefd
endpoints and test cases
nishika26 7040fc5
Merge branch 'main' into feature/dalgo_migration
nishika26 02b3277
Merge branch 'main' into feature/dalgo_migration
nishika26 b5a3a6c
test cases
nishika26 1c929df
test cases
nishika26 58b0f47
initial test cases
nishika26 3d2f669
Merge branch 'main' into feature/dalgo_migration
nishika26 8954022
added test cases
nishika26 521dfe0
added test cases
nishika26 1d9013b
added test cases
nishika26 c4703e3
added test cases
nishika26 6bfe8bc
added test cases
nishika26 63234bc
added test cases
nishika26 6b61868
added test cases
nishika26 b72abf3
added test cases
nishika26 d2ae0af
added test cases
nishika26 6c93999
Merge branch 'main' into feature/dalgo_migration
nishika26 e2f6733
alembic fix
nishika26 9db16b6
changes
nishika26 79c6191
test cases failure
nishika26 3b52721
clean db after test
nishika26 512c2ea
clean db after test
nishika26 a74aada
test cases
nishika26 7bb4498
removing sqlite session
nishika26 d3e283a
Merge branch 'main' into feature/dalgo_migration
nishika26 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
33 changes: 33 additions & 0 deletions
33
backend/app/alembic/versions/9baa692f9a5d_add_threads_table.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,33 @@ | ||
| """add threads table | ||
|
|
||
| Revision ID: 9baa692f9a5d | ||
| Revises: 543f97951bd0 | ||
| Create Date: 2025-05-05 23:25:37.195415 | ||
|
|
||
| """ | ||
| from alembic import op | ||
| import sqlalchemy as sa | ||
| import sqlmodel.sql.sqltypes | ||
|
|
||
|
|
||
| # revision identifiers, used by Alembic. | ||
| revision = "9baa692f9a5d" | ||
| down_revision = "543f97951bd0" | ||
| branch_labels = None | ||
| depends_on = None | ||
|
|
||
|
|
||
| def upgrade(): | ||
| op.create_table( | ||
| "threadresponse", | ||
| sa.Column("thread_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False), | ||
nishika26 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| sa.Column("message", sqlmodel.sql.sqltypes.AutoString(), nullable=True), | ||
nishika26 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| sa.Column("question", sqlmodel.sql.sqltypes.AutoString(), nullable=False), | ||
| sa.Column("created_at", sa.DateTime(), nullable=False), | ||
nishika26 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| sa.Column("updated_at", sa.DateTime(), nullable=False), | ||
| sa.PrimaryKeyConstraint("thread_id"), | ||
nishika26 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| ) | ||
|
|
||
|
|
||
| def downgrade(): | ||
| op.drop_table("threadresponse") | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,21 @@ | ||
| from sqlmodel import Session | ||
| from datetime import datetime | ||
| from app.models import ThreadResponse | ||
|
|
||
|
|
||
| def upsert_thread_result( | ||
| session: Session, thread_id: str, question: str, message: str | None | ||
| ): | ||
| session.merge( | ||
| ThreadResponse( | ||
| thread_id=thread_id, | ||
| question=question, | ||
| message=message, | ||
| updated_at=datetime.utcnow(), | ||
| ) | ||
| ) | ||
| session.commit() | ||
|
|
||
|
|
||
| def get_thread_result(session: Session, thread_id: str) -> ThreadResponse | None: | ||
| return session.get(ThreadResponse, thread_id) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -50,3 +50,5 @@ | |
| CredsPublic, | ||
| CredsUpdate, | ||
| ) | ||
|
|
||
| from .threads import ThreadResponse | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,11 @@ | ||
| from sqlmodel import SQLModel, Field | ||
| from typing import Optional | ||
| from datetime import datetime | ||
|
|
||
|
|
||
| class ThreadResponse(SQLModel, table=True): | ||
| thread_id: str = Field(primary_key=True) | ||
| message: Optional[str] | ||
| question: str | ||
nishika26 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| created_at: datetime = Field(default_factory=datetime.utcnow) | ||
| updated_at: datetime = Field(default_factory=datetime.utcnow) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,6 @@ | ||
| from unittest.mock import MagicMock, patch | ||
|
|
||
| import pytest | ||
| import pytest, uuid | ||
| from fastapi import FastAPI | ||
| from fastapi.testclient import TestClient | ||
| from sqlmodel import select | ||
|
|
@@ -12,9 +12,11 @@ | |
| setup_thread, | ||
| process_message_content, | ||
| handle_openai_error, | ||
| poll_run_and_prepare_response, | ||
| ) | ||
| from app.models import APIKey | ||
| from app.models import APIKey, ThreadResponse | ||
| import openai | ||
| from openai import OpenAIError | ||
|
|
||
| # Wrap the router in a FastAPI app instance. | ||
| app = FastAPI() | ||
|
|
@@ -386,3 +388,171 @@ | |
| error.__str__.return_value = "None body error" | ||
| result = handle_openai_error(error) | ||
| assert result == "None body error" | ||
|
|
||
|
|
||
| @patch("app.api.routes.threads.OpenAI") | ||
| def test_poll_run_and_prepare_response_completed(mock_openai, db): | ||
| mock_client = MagicMock() | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For now this is okay, but consider using the OpenAI mock library: https://mharrisb1.github.io/openai-responses-python/ In addition to their examples, the see the collections CRUD tests for more |
||
| mock_run = MagicMock() | ||
| mock_run.status = "completed" | ||
| mock_client.beta.threads.runs.create_and_poll.return_value = mock_run | ||
|
|
||
| mock_message = MagicMock() | ||
| mock_message.content = [MagicMock(text=MagicMock(value="Answer "))] | ||
| mock_client.beta.threads.messages.list.return_value.data = [mock_message] | ||
| mock_openai.return_value = mock_client | ||
|
|
||
| request = { | ||
| "question": "What is Glific?", | ||
| "assistant_id": "assist_123", | ||
| "thread_id": "test_thread_001", | ||
| "remove_citation": True, | ||
| } | ||
|
|
||
| poll_run_and_prepare_response(request, mock_client, db) | ||
|
|
||
| result = db.get(ThreadResponse, "test_thread_001") | ||
| assert result.message.strip() == "Answer" | ||
|
|
||
|
|
||
| @patch("app.api.routes.threads.OpenAI") | ||
| def test_poll_run_and_prepare_response_openai_error_handling(mock_openai, db): | ||
| mock_client = MagicMock() | ||
| mock_error = OpenAIError("Simulated OpenAI error") | ||
| mock_client.beta.threads.runs.create_and_poll.side_effect = mock_error | ||
| mock_openai.return_value = mock_client | ||
|
|
||
| request = { | ||
| "question": "Failing run", | ||
| "assistant_id": "assist_123", | ||
| "thread_id": "test_openai_error", | ||
| } | ||
|
|
||
| poll_run_and_prepare_response(request, mock_client, db) | ||
| result = db.get(ThreadResponse, "test_openai_error") | ||
| assert result.message is None | ||
|
|
||
|
|
||
| @patch("app.api.routes.threads.OpenAI") | ||
| def test_poll_run_and_prepare_response_non_completed(mock_openai, db): | ||
| mock_client = MagicMock() | ||
| mock_run = MagicMock(status="failed") | ||
| mock_client.beta.threads.runs.create_and_poll.return_value = mock_run | ||
| mock_openai.return_value = mock_client | ||
|
|
||
| request = { | ||
| "question": "Incomplete run", | ||
| "assistant_id": "assist_123", | ||
| "thread_id": "test_non_complete", | ||
| } | ||
|
|
||
| poll_run_and_prepare_response(request, mock_client, db) | ||
| result = db.get(ThreadResponse, "test_non_complete") | ||
| assert result.message is None | ||
|
|
||
|
|
||
| @patch("app.api.routes.threads.OpenAI") | ||
| def test_threads_start_endpoint_creates_thread(mock_openai, db): | ||
| """Test /threads/start creates thread and schedules background task.""" | ||
| mock_client = MagicMock() | ||
| mock_thread = MagicMock() | ||
| mock_thread.id = "mock_thread_001" | ||
| mock_client.beta.threads.create.return_value = mock_thread | ||
| mock_client.beta.threads.messages.create.return_value = None | ||
| mock_openai.return_value = mock_client | ||
|
|
||
| api_key_record = db.exec(select(APIKey).where(APIKey.is_deleted is False)).first() | ||
| if not api_key_record: | ||
| pytest.skip("No API key found in the database for testing") | ||
|
|
||
| headers = {"X-API-KEY": api_key_record.key} | ||
| data = {"question": "What's 2+2?", "assistant_id": "assist_123"} | ||
|
|
||
| response = client.post("/threads/start", json=data, headers=headers) | ||
| assert response.status_code == 200 | ||
| res_json = response.json() | ||
| assert res_json["success"] | ||
| assert res_json["data"]["thread_id"] == "mock_thread_001" | ||
| assert res_json["data"]["status"] == "processing" | ||
| assert res_json["data"]["question"] == "What's 2+2?" | ||
|
|
||
|
|
||
| def test_threads_result_endpoint_success(db): | ||
| """Test /threads/result/{thread_id} returns completed thread.""" | ||
| thread_id = f"test_processing_{uuid.uuid4()}" | ||
| question = "Capital of France?" | ||
| message = "Paris." | ||
|
|
||
| db.add(ThreadResponse(thread_id=thread_id, question=question, message=message)) | ||
| db.commit() | ||
|
|
||
| api_key_record = db.exec(select(APIKey).where(APIKey.is_deleted is False)).first() | ||
| if not api_key_record: | ||
| pytest.skip("No API key found in the database for testing") | ||
|
|
||
| headers = {"X-API-KEY": api_key_record.key} | ||
| response = client.get(f"/threads/result/{thread_id}", headers=headers) | ||
|
|
||
| assert response.status_code == 200 | ||
| data = response.json()["data"] | ||
| assert data["status"] == "success" | ||
| assert data["message"] == "Paris." | ||
| assert data["thread_id"] == thread_id | ||
| assert data["question"] == question | ||
|
|
||
|
|
||
| def test_threads_result_endpoint_processing(db): | ||
| """Test /threads/result/{thread_id} returns processing status if no message yet.""" | ||
| thread_id = f"test_processing_{uuid.uuid4()}" | ||
| question = "What is Glific?" | ||
|
|
||
| db.add(ThreadResponse(thread_id=thread_id, question=question, message=None)) | ||
| db.commit() | ||
|
|
||
| api_key_record = db.exec(select(APIKey).where(APIKey.is_deleted is False)).first() | ||
| if not api_key_record: | ||
| pytest.skip("No API key found in the database for testing") | ||
|
|
||
| headers = {"X-API-KEY": api_key_record.key} | ||
| response = client.get(f"/threads/result/{thread_id}", headers=headers) | ||
|
|
||
| assert response.status_code == 200 | ||
| data = response.json()["data"] | ||
| assert data["status"] == "processing" | ||
| assert data["message"] is None | ||
| assert data["thread_id"] == thread_id | ||
| assert data["question"] == question | ||
|
|
||
|
|
||
| def test_threads_result_not_found(db): | ||
| """Test /threads/result/{thread_id} returns error for nonexistent thread.""" | ||
| api_key_record = db.exec(select(APIKey).where(APIKey.is_deleted is False)).first() | ||
| if not api_key_record: | ||
| pytest.skip("No API key found in the database for testing") | ||
|
|
||
| headers = {"X-API-KEY": api_key_record.key} | ||
| response = client.get("/threads/result/nonexistent_thread", headers=headers) | ||
|
|
||
| assert response.status_code == 200 | ||
| assert response.json()["success"] is False | ||
| assert "not found" in response.json()["error"].lower() | ||
|
|
||
|
|
||
| @patch("app.api.routes.threads.OpenAI") | ||
| def test_threads_start_missing_question(mock_openai, db): | ||
| """Test /threads/start with missing 'question' key in request.""" | ||
| mock_openai.return_value = MagicMock() | ||
|
|
||
| api_key_record = db.exec(select(APIKey).where(APIKey.is_deleted is False)).first() | ||
| if not api_key_record: | ||
| pytest.skip("No API key found in the database for testing") | ||
|
|
||
| headers = {"X-API-KEY": api_key_record.key} | ||
|
|
||
| bad_data = {"assistant_id": "assist_123"} # no "question" key | ||
|
|
||
| response = client.post("/threads/start", json=bad_data, headers=headers) | ||
|
|
||
| assert response.status_code == 422 # Unprocessable Entity (FastAPI will raise 422) | ||
| error_response = response.json() | ||
| assert "detail" in error_response | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.