Skip to content
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions backend/app/alembic/versions/9baa692f9a5d_add_threads_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""add threads table

Revision ID: 9baa692f9a5d
Revises: 543f97951bd0
Create Date: 2025-05-05 23:25:37.195415

"""
from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes


# revision identifiers, used by Alembic.
revision = "9baa692f9a5d"
down_revision = "543f97951bd0"
branch_labels = None
depends_on = None


def upgrade():
op.create_table(
"threadresponse",
sa.Column("thread_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("message", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column("question", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("created_at", sa.DateTime(), nullable=False),
sa.Column("updated_at", sa.DateTime(), nullable=False),
sa.PrimaryKeyConstraint("thread_id"),
)


def downgrade():
op.drop_table("threadresponse")
90 changes: 90 additions & 0 deletions backend/app/api/routes/threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from app.api.deps import get_current_user_org, get_db
from app.core import logging, settings
from app.models import UserOrganization
from app.crud import upsert_thread_result, get_thread_result
from app.utils import APIResponse

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -159,6 +160,32 @@
send_callback(request["callback_url"], callback_response.model_dump())


def poll_run_and_prepare_response(request: dict, client: OpenAI, db: Session):
"process a run and send result to DB"
thread_id = request["thread_id"]
question = request["question"]

try:
run = client.beta.threads.runs.create_and_poll(
thread_id=thread_id,
assistant_id=request["assistant_id"],
)

if run.status == "completed":
messages = client.beta.threads.messages.list(thread_id=thread_id)
latest_message = messages.data[0]
message_content = latest_message.content[0].text.value
processed_message = process_message_content(
message_content, request.get("remove_citation", False)
)
upsert_thread_result(db, thread_id, question, processed_message)
else:
upsert_thread_result(db, thread_id, question, None)

except openai.OpenAIError:
upsert_thread_result(db, thread_id, question, None)


@router.post("/threads")
async def threads(
request: dict,
Expand Down Expand Up @@ -240,3 +267,66 @@

except openai.OpenAIError as e:
return APIResponse.failure_response(error=handle_openai_error(e))


@router.post("/threads/start")
async def start_thread(
request: dict,
background_tasks: BackgroundTasks,
db: Session = Depends(get_db),
_current_user: UserOrganization = Depends(get_current_user_org),
):
"""
Create a new OpenAI thread for the given question and start polling in the background.

- If successful, returns the thread ID and a 'processing' status.
- Stores the thread and question in the database.
"""

question = request["question"]

Check warning on line 286 in backend/app/api/routes/threads.py

View check run for this annotation

Codecov / codecov/patch

backend/app/api/routes/threads.py#L286

Added line #L286 was not covered by tests

client = OpenAI(api_key=settings.OPENAI_API_KEY)

Check warning on line 288 in backend/app/api/routes/threads.py

View check run for this annotation

Codecov / codecov/patch

backend/app/api/routes/threads.py#L288

Added line #L288 was not covered by tests

is_success, error = setup_thread(client, request)
if not is_success:
return APIResponse.failure_response(error=error)

Check warning on line 292 in backend/app/api/routes/threads.py

View check run for this annotation

Codecov / codecov/patch

backend/app/api/routes/threads.py#L290-L292

Added lines #L290 - L292 were not covered by tests

thread_id = request["thread_id"]
upsert_thread_result(db, thread_id, question, None)

Check warning on line 295 in backend/app/api/routes/threads.py

View check run for this annotation

Codecov / codecov/patch

backend/app/api/routes/threads.py#L294-L295

Added lines #L294 - L295 were not covered by tests

background_tasks.add_task(poll_run_and_prepare_response, request, client, db)

Check warning on line 297 in backend/app/api/routes/threads.py

View check run for this annotation

Codecov / codecov/patch

backend/app/api/routes/threads.py#L297

Added line #L297 was not covered by tests

return APIResponse.success_response(

Check warning on line 299 in backend/app/api/routes/threads.py

View check run for this annotation

Codecov / codecov/patch

backend/app/api/routes/threads.py#L299

Added line #L299 was not covered by tests
data={
"thread_id": thread_id,
"question": question,
"status": "processing",
"message": "Thread created and polling started in background.",
}
)


@router.get("/threads/result/{thread_id}")
async def get_thread_result_by_id(
thread_id: str,
db: Session = Depends(get_db),
_current_user: UserOrganization = Depends(get_current_user_org),
):
"""
Retrieve the result of a previously started OpenAI thread using its thread ID.
"""
result = get_thread_result(db, thread_id)

Check warning on line 318 in backend/app/api/routes/threads.py

View check run for this annotation

Codecov / codecov/patch

backend/app/api/routes/threads.py#L318

Added line #L318 was not covered by tests

if not result:
return APIResponse.failure_response(error="Thread not found.")

Check warning on line 321 in backend/app/api/routes/threads.py

View check run for this annotation

Codecov / codecov/patch

backend/app/api/routes/threads.py#L320-L321

Added lines #L320 - L321 were not covered by tests

status = "success" if result.message else "processing"

Check warning on line 323 in backend/app/api/routes/threads.py

View check run for this annotation

Codecov / codecov/patch

backend/app/api/routes/threads.py#L323

Added line #L323 was not covered by tests

return APIResponse.success_response(

Check warning on line 325 in backend/app/api/routes/threads.py

View check run for this annotation

Codecov / codecov/patch

backend/app/api/routes/threads.py#L325

Added line #L325 was not covered by tests
data={
"thread_id": result.thread_id,
"question": result.question,
"status": status,
"message": result.message,
}
)
2 changes: 2 additions & 0 deletions backend/app/crud/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,5 @@
get_api_keys_by_organization,
delete_api_key,
)

from .thread_results import upsert_thread_result, get_thread_result
21 changes: 21 additions & 0 deletions backend/app/crud/thread_results.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from sqlmodel import Session
from datetime import datetime
from app.models import ThreadResponse


def upsert_thread_result(
session: Session, thread_id: str, question: str, message: str | None
):
session.merge(
ThreadResponse(
thread_id=thread_id,
question=question,
message=message,
updated_at=datetime.utcnow(),
)
)
session.commit()


def get_thread_result(session: Session, thread_id: str) -> ThreadResponse | None:
return session.get(ThreadResponse, thread_id)
2 changes: 2 additions & 0 deletions backend/app/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,5 @@
CredsPublic,
CredsUpdate,
)

from .threads import ThreadResponse
11 changes: 11 additions & 0 deletions backend/app/models/threads.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from sqlmodel import SQLModel, Field
from typing import Optional
from datetime import datetime


class ThreadResponse(SQLModel, table=True):
thread_id: str = Field(primary_key=True)
message: Optional[str]
question: str
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
174 changes: 172 additions & 2 deletions backend/app/tests/api/routes/test_threads.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from unittest.mock import MagicMock, patch

import pytest
import pytest, uuid
from fastapi import FastAPI
from fastapi.testclient import TestClient
from sqlmodel import select
Expand All @@ -12,9 +12,11 @@
setup_thread,
process_message_content,
handle_openai_error,
poll_run_and_prepare_response,
)
from app.models import APIKey
from app.models import APIKey, ThreadResponse
import openai
from openai import OpenAIError

# Wrap the router in a FastAPI app instance.
app = FastAPI()
Expand Down Expand Up @@ -386,3 +388,171 @@
error.__str__.return_value = "None body error"
result = handle_openai_error(error)
assert result == "None body error"


@patch("app.api.routes.threads.OpenAI")
def test_poll_run_and_prepare_response_completed(mock_openai, db):
mock_client = MagicMock()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now this is okay, but consider using the OpenAI mock library: https://mharrisb1.github.io/openai-responses-python/ In addition to their examples, the see the collections CRUD tests for more

mock_run = MagicMock()
mock_run.status = "completed"
mock_client.beta.threads.runs.create_and_poll.return_value = mock_run

mock_message = MagicMock()
mock_message.content = [MagicMock(text=MagicMock(value="Answer "))]
mock_client.beta.threads.messages.list.return_value.data = [mock_message]
mock_openai.return_value = mock_client

request = {
"question": "What is Glific?",
"assistant_id": "assist_123",
"thread_id": "test_thread_001",
"remove_citation": True,
}

poll_run_and_prepare_response(request, mock_client, db)

result = db.get(ThreadResponse, "test_thread_001")
assert result.message.strip() == "Answer"


@patch("app.api.routes.threads.OpenAI")
def test_poll_run_and_prepare_response_openai_error_handling(mock_openai, db):
mock_client = MagicMock()
mock_error = OpenAIError("Simulated OpenAI error")
mock_client.beta.threads.runs.create_and_poll.side_effect = mock_error
mock_openai.return_value = mock_client

request = {
"question": "Failing run",
"assistant_id": "assist_123",
"thread_id": "test_openai_error",
}

poll_run_and_prepare_response(request, mock_client, db)
result = db.get(ThreadResponse, "test_openai_error")
assert result.message is None


@patch("app.api.routes.threads.OpenAI")
def test_poll_run_and_prepare_response_non_completed(mock_openai, db):
mock_client = MagicMock()
mock_run = MagicMock(status="failed")
mock_client.beta.threads.runs.create_and_poll.return_value = mock_run
mock_openai.return_value = mock_client

request = {
"question": "Incomplete run",
"assistant_id": "assist_123",
"thread_id": "test_non_complete",
}

poll_run_and_prepare_response(request, mock_client, db)
result = db.get(ThreadResponse, "test_non_complete")
assert result.message is None


@patch("app.api.routes.threads.OpenAI")
def test_threads_start_endpoint_creates_thread(mock_openai, db):
"""Test /threads/start creates thread and schedules background task."""
mock_client = MagicMock()
mock_thread = MagicMock()
mock_thread.id = "mock_thread_001"
mock_client.beta.threads.create.return_value = mock_thread
mock_client.beta.threads.messages.create.return_value = None
mock_openai.return_value = mock_client

api_key_record = db.exec(select(APIKey).where(APIKey.is_deleted is False)).first()
if not api_key_record:
pytest.skip("No API key found in the database for testing")

headers = {"X-API-KEY": api_key_record.key}
data = {"question": "What's 2+2?", "assistant_id": "assist_123"}

Check warning on line 469 in backend/app/tests/api/routes/test_threads.py

View check run for this annotation

Codecov / codecov/patch

backend/app/tests/api/routes/test_threads.py#L468-L469

Added lines #L468 - L469 were not covered by tests

response = client.post("/threads/start", json=data, headers=headers)
assert response.status_code == 200
res_json = response.json()
assert res_json["success"]
assert res_json["data"]["thread_id"] == "mock_thread_001"
assert res_json["data"]["status"] == "processing"
assert res_json["data"]["question"] == "What's 2+2?"

Check warning on line 477 in backend/app/tests/api/routes/test_threads.py

View check run for this annotation

Codecov / codecov/patch

backend/app/tests/api/routes/test_threads.py#L471-L477

Added lines #L471 - L477 were not covered by tests


def test_threads_result_endpoint_success(db):
"""Test /threads/result/{thread_id} returns completed thread."""
thread_id = f"test_processing_{uuid.uuid4()}"
question = "Capital of France?"
message = "Paris."

db.add(ThreadResponse(thread_id=thread_id, question=question, message=message))
db.commit()

api_key_record = db.exec(select(APIKey).where(APIKey.is_deleted is False)).first()
if not api_key_record:
pytest.skip("No API key found in the database for testing")

headers = {"X-API-KEY": api_key_record.key}
response = client.get(f"/threads/result/{thread_id}", headers=headers)

Check warning on line 494 in backend/app/tests/api/routes/test_threads.py

View check run for this annotation

Codecov / codecov/patch

backend/app/tests/api/routes/test_threads.py#L493-L494

Added lines #L493 - L494 were not covered by tests

assert response.status_code == 200
data = response.json()["data"]
assert data["status"] == "success"
assert data["message"] == "Paris."
assert data["thread_id"] == thread_id
assert data["question"] == question

Check warning on line 501 in backend/app/tests/api/routes/test_threads.py

View check run for this annotation

Codecov / codecov/patch

backend/app/tests/api/routes/test_threads.py#L496-L501

Added lines #L496 - L501 were not covered by tests


def test_threads_result_endpoint_processing(db):
"""Test /threads/result/{thread_id} returns processing status if no message yet."""
thread_id = f"test_processing_{uuid.uuid4()}"
question = "What is Glific?"

db.add(ThreadResponse(thread_id=thread_id, question=question, message=None))
db.commit()

api_key_record = db.exec(select(APIKey).where(APIKey.is_deleted is False)).first()
if not api_key_record:
pytest.skip("No API key found in the database for testing")

headers = {"X-API-KEY": api_key_record.key}
response = client.get(f"/threads/result/{thread_id}", headers=headers)

Check warning on line 517 in backend/app/tests/api/routes/test_threads.py

View check run for this annotation

Codecov / codecov/patch

backend/app/tests/api/routes/test_threads.py#L516-L517

Added lines #L516 - L517 were not covered by tests

assert response.status_code == 200
data = response.json()["data"]
assert data["status"] == "processing"
assert data["message"] is None
assert data["thread_id"] == thread_id
assert data["question"] == question

Check warning on line 524 in backend/app/tests/api/routes/test_threads.py

View check run for this annotation

Codecov / codecov/patch

backend/app/tests/api/routes/test_threads.py#L519-L524

Added lines #L519 - L524 were not covered by tests


def test_threads_result_not_found(db):
"""Test /threads/result/{thread_id} returns error for nonexistent thread."""
api_key_record = db.exec(select(APIKey).where(APIKey.is_deleted is False)).first()
if not api_key_record:
pytest.skip("No API key found in the database for testing")

headers = {"X-API-KEY": api_key_record.key}
response = client.get("/threads/result/nonexistent_thread", headers=headers)

Check warning on line 534 in backend/app/tests/api/routes/test_threads.py

View check run for this annotation

Codecov / codecov/patch

backend/app/tests/api/routes/test_threads.py#L533-L534

Added lines #L533 - L534 were not covered by tests

assert response.status_code == 200
assert response.json()["success"] is False
assert "not found" in response.json()["error"].lower()

Check warning on line 538 in backend/app/tests/api/routes/test_threads.py

View check run for this annotation

Codecov / codecov/patch

backend/app/tests/api/routes/test_threads.py#L536-L538

Added lines #L536 - L538 were not covered by tests


@patch("app.api.routes.threads.OpenAI")
def test_threads_start_missing_question(mock_openai, db):
"""Test /threads/start with missing 'question' key in request."""
mock_openai.return_value = MagicMock()

api_key_record = db.exec(select(APIKey).where(APIKey.is_deleted is False)).first()
if not api_key_record:
pytest.skip("No API key found in the database for testing")

headers = {"X-API-KEY": api_key_record.key}

Check warning on line 550 in backend/app/tests/api/routes/test_threads.py

View check run for this annotation

Codecov / codecov/patch

backend/app/tests/api/routes/test_threads.py#L550

Added line #L550 was not covered by tests

bad_data = {"assistant_id": "assist_123"} # no "question" key

Check warning on line 552 in backend/app/tests/api/routes/test_threads.py

View check run for this annotation

Codecov / codecov/patch

backend/app/tests/api/routes/test_threads.py#L552

Added line #L552 was not covered by tests

response = client.post("/threads/start", json=bad_data, headers=headers)

Check warning on line 554 in backend/app/tests/api/routes/test_threads.py

View check run for this annotation

Codecov / codecov/patch

backend/app/tests/api/routes/test_threads.py#L554

Added line #L554 was not covered by tests

assert response.status_code == 422 # Unprocessable Entity (FastAPI will raise 422)
error_response = response.json()
assert "detail" in error_response

Check warning on line 558 in backend/app/tests/api/routes/test_threads.py

View check run for this annotation

Codecov / codecov/patch

backend/app/tests/api/routes/test_threads.py#L556-L558

Added lines #L556 - L558 were not covered by tests
Loading