Skip to content
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions backend/app/alembic/versions/9baa692f9a5d_add_threads_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""add threads table

Revision ID: 9baa692f9a5d
Revises: 543f97951bd0
Create Date: 2025-05-05 23:25:37.195415

"""
from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes


# revision identifiers, used by Alembic.
revision = "9baa692f9a5d"
down_revision = "543f97951bd0"
branch_labels = None
depends_on = None


def upgrade():
op.create_table(
"threadresponse",
sa.Column("thread_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("message", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column("question", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("created_at", sa.DateTime(), nullable=False),
sa.Column("updated_at", sa.DateTime(), nullable=False),
sa.PrimaryKeyConstraint("thread_id"),
)


def downgrade():
op.drop_table("threadresponse")
90 changes: 90 additions & 0 deletions backend/app/api/routes/threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from app.api.deps import get_current_user_org, get_db
from app.core import logging, settings
from app.models import UserOrganization
from app.crud import upsert_thread_result, get_thread_result
from app.utils import APIResponse

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -159,6 +160,32 @@ def process_run(request: dict, client: OpenAI):
send_callback(request["callback_url"], callback_response.model_dump())


def poll_run_and_prepare_response(request: dict, client: OpenAI, db: Session):
"process a run and send result to DB"
thread_id = request["thread_id"]
question = request["question"]

try:
run = client.beta.threads.runs.create_and_poll(
thread_id=thread_id,
assistant_id=request["assistant_id"],
)

if run.status == "completed":
messages = client.beta.threads.messages.list(thread_id=thread_id)
latest_message = messages.data[0]
message_content = latest_message.content[0].text.value
processed_message = process_message_content(
message_content, request.get("remove_citation", False)
)
upsert_thread_result(db, thread_id, question, processed_message)
else:
upsert_thread_result(db, thread_id, question, None)

except openai.OpenAIError:
upsert_thread_result(db, thread_id, question, None)


@router.post("/threads")
async def threads(
request: dict,
Expand Down Expand Up @@ -240,3 +267,66 @@ async def threads_sync(

except openai.OpenAIError as e:
return APIResponse.failure_response(error=handle_openai_error(e))


@router.post("/threads/start")
async def start_thread(
request: dict,
background_tasks: BackgroundTasks,
db: Session = Depends(get_db),
_current_user: UserOrganization = Depends(get_current_user_org),
):
"""
Create a new OpenAI thread for the given question and start polling in the background.

- If successful, returns the thread ID and a 'processing' status.
- Stores the thread and question in the database.
"""

question = request["question"]

client = OpenAI(api_key=settings.OPENAI_API_KEY)

is_success, error = setup_thread(client, request)
if not is_success:
return APIResponse.failure_response(error=error)

thread_id = request["thread_id"]
upsert_thread_result(db, thread_id, question, None)

background_tasks.add_task(poll_run_and_prepare_response, request, client, db)

return APIResponse.success_response(
data={
"thread_id": thread_id,
"question": question,
"status": "processing",
"message": "Thread created and polling started in background.",
}
)


@router.get("/threads/result/{thread_id}")
async def get_thread_result_by_id(
thread_id: str,
db: Session = Depends(get_db),
_current_user: UserOrganization = Depends(get_current_user_org),
):
"""
Retrieve the result of a previously started OpenAI thread using its thread ID.
"""
result = get_thread_result(db, thread_id)

if not result:
return APIResponse.failure_response(error="Thread not found.")

status = "success" if result.message else "processing"

return APIResponse.success_response(
data={
"thread_id": result.thread_id,
"question": result.question,
"status": status,
"message": result.message,
}
)
2 changes: 2 additions & 0 deletions backend/app/crud/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,5 @@
get_api_keys_by_organization,
delete_api_key,
)

from .thread_results import upsert_thread_result, get_thread_result
21 changes: 21 additions & 0 deletions backend/app/crud/thread_results.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from sqlmodel import Session
from datetime import datetime
from app.models import ThreadResponse


def upsert_thread_result(
session: Session, thread_id: str, question: str, message: str | None
):
session.merge(
ThreadResponse(
thread_id=thread_id,
question=question,
message=message,
updated_at=datetime.utcnow(),
)
)
session.commit()


def get_thread_result(session: Session, thread_id: str) -> ThreadResponse | None:
return session.get(ThreadResponse, thread_id)
2 changes: 2 additions & 0 deletions backend/app/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,5 @@
CredsPublic,
CredsUpdate,
)

from .threads import ThreadResponse
11 changes: 11 additions & 0 deletions backend/app/models/threads.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from sqlmodel import SQLModel, Field
from typing import Optional
from datetime import datetime


class ThreadResponse(SQLModel, table=True):
thread_id: str = Field(primary_key=True)
message: Optional[str]
question: str
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
174 changes: 172 additions & 2 deletions backend/app/tests/api/routes/test_threads.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from unittest.mock import MagicMock, patch

import pytest
import pytest, uuid
from fastapi import FastAPI
from fastapi.testclient import TestClient
from sqlmodel import select
Expand All @@ -12,9 +12,11 @@
setup_thread,
process_message_content,
handle_openai_error,
poll_run_and_prepare_response,
)
from app.models import APIKey
from app.models import APIKey, ThreadResponse
import openai
from openai import OpenAIError

# Wrap the router in a FastAPI app instance.
app = FastAPI()
Expand Down Expand Up @@ -386,3 +388,171 @@ def test_handle_openai_error_with_none_body():
error.__str__.return_value = "None body error"
result = handle_openai_error(error)
assert result == "None body error"


@patch("app.api.routes.threads.OpenAI")
def test_poll_run_and_prepare_response_completed(mock_openai, db):
mock_client = MagicMock()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now this is okay, but consider using the OpenAI mock library: https://mharrisb1.github.io/openai-responses-python/ In addition to their examples, the see the collections CRUD tests for more

mock_run = MagicMock()
mock_run.status = "completed"
mock_client.beta.threads.runs.create_and_poll.return_value = mock_run

mock_message = MagicMock()
mock_message.content = [MagicMock(text=MagicMock(value="Answer "))]
mock_client.beta.threads.messages.list.return_value.data = [mock_message]
mock_openai.return_value = mock_client

request = {
"question": "What is Glific?",
"assistant_id": "assist_123",
"thread_id": "test_thread_001",
"remove_citation": True,
}

poll_run_and_prepare_response(request, mock_client, db)

result = db.get(ThreadResponse, "test_thread_001")
assert result.message.strip() == "Answer"


@patch("app.api.routes.threads.OpenAI")
def test_poll_run_and_prepare_response_openai_error_handling(mock_openai, db):
mock_client = MagicMock()
mock_error = OpenAIError("Simulated OpenAI error")
mock_client.beta.threads.runs.create_and_poll.side_effect = mock_error
mock_openai.return_value = mock_client

request = {
"question": "Failing run",
"assistant_id": "assist_123",
"thread_id": "test_openai_error",
}

poll_run_and_prepare_response(request, mock_client, db)
result = db.get(ThreadResponse, "test_openai_error")
assert result.message is None


@patch("app.api.routes.threads.OpenAI")
def test_poll_run_and_prepare_response_non_completed(mock_openai, db):
mock_client = MagicMock()
mock_run = MagicMock(status="failed")
mock_client.beta.threads.runs.create_and_poll.return_value = mock_run
mock_openai.return_value = mock_client

request = {
"question": "Incomplete run",
"assistant_id": "assist_123",
"thread_id": "test_non_complete",
}

poll_run_and_prepare_response(request, mock_client, db)
result = db.get(ThreadResponse, "test_non_complete")
assert result.message is None


@patch("app.api.routes.threads.OpenAI")
def test_threads_start_endpoint_creates_thread(mock_openai, db):
"""Test /threads/start creates thread and schedules background task."""
mock_client = MagicMock()
mock_thread = MagicMock()
mock_thread.id = "mock_thread_001"
mock_client.beta.threads.create.return_value = mock_thread
mock_client.beta.threads.messages.create.return_value = None
mock_openai.return_value = mock_client

api_key_record = db.exec(select(APIKey).where(APIKey.is_deleted is False)).first()
if not api_key_record:
pytest.skip("No API key found in the database for testing")

headers = {"X-API-KEY": api_key_record.key}
data = {"question": "What's 2+2?", "assistant_id": "assist_123"}

response = client.post("/threads/start", json=data, headers=headers)
assert response.status_code == 200
res_json = response.json()
assert res_json["success"]
assert res_json["data"]["thread_id"] == "mock_thread_001"
assert res_json["data"]["status"] == "processing"
assert res_json["data"]["question"] == "What's 2+2?"


def test_threads_result_endpoint_success(db):
"""Test /threads/result/{thread_id} returns completed thread."""
thread_id = f"test_processing_{uuid.uuid4()}"
question = "Capital of France?"
message = "Paris."

db.add(ThreadResponse(thread_id=thread_id, question=question, message=message))
db.commit()

api_key_record = db.exec(select(APIKey).where(APIKey.is_deleted is False)).first()
if not api_key_record:
pytest.skip("No API key found in the database for testing")

headers = {"X-API-KEY": api_key_record.key}
response = client.get(f"/threads/result/{thread_id}", headers=headers)

assert response.status_code == 200
data = response.json()["data"]
assert data["status"] == "success"
assert data["message"] == "Paris."
assert data["thread_id"] == thread_id
assert data["question"] == question


def test_threads_result_endpoint_processing(db):
"""Test /threads/result/{thread_id} returns processing status if no message yet."""
thread_id = f"test_processing_{uuid.uuid4()}"
question = "What is Glific?"

db.add(ThreadResponse(thread_id=thread_id, question=question, message=None))
db.commit()

api_key_record = db.exec(select(APIKey).where(APIKey.is_deleted is False)).first()
if not api_key_record:
pytest.skip("No API key found in the database for testing")

headers = {"X-API-KEY": api_key_record.key}
response = client.get(f"/threads/result/{thread_id}", headers=headers)

assert response.status_code == 200
data = response.json()["data"]
assert data["status"] == "processing"
assert data["message"] is None
assert data["thread_id"] == thread_id
assert data["question"] == question


def test_threads_result_not_found(db):
"""Test /threads/result/{thread_id} returns error for nonexistent thread."""
api_key_record = db.exec(select(APIKey).where(APIKey.is_deleted is False)).first()
if not api_key_record:
pytest.skip("No API key found in the database for testing")

headers = {"X-API-KEY": api_key_record.key}
response = client.get("/threads/result/nonexistent_thread", headers=headers)

assert response.status_code == 200
assert response.json()["success"] is False
assert "not found" in response.json()["error"].lower()


@patch("app.api.routes.threads.OpenAI")
def test_threads_start_missing_question(mock_openai, db):
"""Test /threads/start with missing 'question' key in request."""
mock_openai.return_value = MagicMock()

api_key_record = db.exec(select(APIKey).where(APIKey.is_deleted is False)).first()
if not api_key_record:
pytest.skip("No API key found in the database for testing")

headers = {"X-API-KEY": api_key_record.key}

bad_data = {"assistant_id": "assist_123"} # no "question" key

response = client.post("/threads/start", json=bad_data, headers=headers)

assert response.status_code == 422 # Unprocessable Entity (FastAPI will raise 422)
error_response = response.json()
assert "detail" in error_response
Loading
Loading