diff --git a/backend/app/alembic/versions/79e47bc3aac6_add_threads_table.py b/backend/app/alembic/versions/79e47bc3aac6_add_threads_table.py new file mode 100644 index 000000000..ea1fd6c19 --- /dev/null +++ b/backend/app/alembic/versions/79e47bc3aac6_add_threads_table.py @@ -0,0 +1,70 @@ +"""add threads table + +Revision ID: 79e47bc3aac6 +Revises: f23675767ed2 +Create Date: 2025-05-12 15:49:39.142806 + +""" +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes + + +# revision identifiers, used by Alembic. +revision = "79e47bc3aac6" +down_revision = "f23675767ed2" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "openai_thread", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("thread_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("prompt", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("response", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("status", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("error", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("inserted_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_openai_thread_thread_id"), "openai_thread", ["thread_id"], unique=True + ) + op.drop_constraint( + "credential_organization_id_fkey", "credential", type_="foreignkey" + ) + op.create_foreign_key( + None, "credential", "organization", ["organization_id"], ["id"] + ) + op.drop_constraint("project_organization_id_fkey", "project", type_="foreignkey") + op.create_foreign_key(None, "project", "organization", ["organization_id"], ["id"]) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_constraint(None, "project", type_="foreignkey") + op.create_foreign_key( + "project_organization_id_fkey", + "project", + "organization", + ["organization_id"], + ["id"], + ondelete="CASCADE", + ) + op.drop_constraint(None, "credential", type_="foreignkey") + op.create_foreign_key( + "credential_organization_id_fkey", + "credential", + "organization", + ["organization_id"], + ["id"], + ondelete="CASCADE", + ) + op.drop_index(op.f("ix_openai_thread_thread_id"), table_name="openai_thread") + op.drop_table("openai_thread") + # ### end Alembic commands ### diff --git a/backend/app/api/routes/threads.py b/backend/app/api/routes/threads.py index 8a49ef3ce..d27b578e7 100644 --- a/backend/app/api/routes/threads.py +++ b/backend/app/api/routes/threads.py @@ -9,7 +9,8 @@ from app.api.deps import get_current_user_org, get_db from app.core import logging, settings -from app.models import UserOrganization +from app.models import UserOrganization, OpenAIThreadCreate +from app.crud import upsert_thread_result, get_thread_result from app.utils import APIResponse logger = logging.getLogger(__name__) @@ -113,6 +114,24 @@ def create_success_response(request: dict, message: str) -> APIResponse: ) +def run_and_poll_thread(client: OpenAI, thread_id: str, assistant_id: str): + """Runs and polls a thread with the specified assistant using the OpenAI client.""" + return client.beta.threads.runs.create_and_poll( + thread_id=thread_id, + assistant_id=assistant_id, + ) + + +def extract_response_from_thread( + client: OpenAI, thread_id: str, remove_citation: bool = False +) -> str: + """Fetches and processes the latest message from a thread.""" + messages = client.beta.threads.messages.list(thread_id=thread_id) + latest_message = messages.data[0] + message_content = latest_message.content[0].text.value + return process_message_content(message_content, remove_citation) + + @observe(as_type="generation") def process_run(request: dict, client: OpenAI): """Process a run and send 
callback with results.""" @@ -159,6 +178,40 @@ def process_run(request: dict, client: OpenAI): send_callback(request["callback_url"], callback_response.model_dump()) +def poll_run_and_prepare_response(request: dict, client: OpenAI, db: Session): + """Handles a thread run, processes the response, and upserts the result to the database.""" + thread_id = request["thread_id"] + prompt = request["question"] + + try: + run = run_and_poll_thread(client, thread_id, request["assistant_id"]) + + status = run.status or "unknown" + response = None + error = None + + if status == "completed": + response = extract_response_from_thread( + client, thread_id, request.get("remove_citation", False) + ) + + except openai.OpenAIError as e: + status = "failed" + error = str(e) + response = None + + upsert_thread_result( + db, + OpenAIThreadCreate( + thread_id=thread_id, + prompt=prompt, + response=response, + status=status, + error=error, + ), + ) + + @router.post("/threads") async def threads( request: dict, @@ -240,3 +293,72 @@ async def threads_sync( except openai.OpenAIError as e: return APIResponse.failure_response(error=handle_openai_error(e)) + + +@router.post("/threads/start") +async def start_thread( + request: dict, + background_tasks: BackgroundTasks, + db: Session = Depends(get_db), + _current_user: UserOrganization = Depends(get_current_user_org), +): + """ + Create a new OpenAI thread for the given question and start polling in the background. 
+ """ + prompt = request["question"] + client = OpenAI(api_key=settings.OPENAI_API_KEY) + + is_success, error = setup_thread(client, request) + if not is_success: + return APIResponse.failure_response(error=error) + + thread_id = request["thread_id"] + + upsert_thread_result( + db, + OpenAIThreadCreate( + thread_id=thread_id, + prompt=prompt, + response=None, + status="processing", + error=None, + ), + ) + + background_tasks.add_task(poll_run_and_prepare_response, request, client, db) + + return APIResponse.success_response( + data={ + "thread_id": thread_id, + "prompt": prompt, + "status": "processing", + "message": "Thread created and polling started in background.", + } + ) + + +@router.get("/threads/result/{thread_id}") +async def get_thread( + thread_id: str, + db: Session = Depends(get_db), + _current_user: UserOrganization = Depends(get_current_user_org), +): + """ + Retrieve the result of a previously started OpenAI thread using its thread ID. + """ + result = get_thread_result(db, thread_id) + + if not result: + return APIResponse.failure_response(error="Thread not found.") + + status = result.status or ("success" if result.response else "processing") + + return APIResponse.success_response( + data={ + "thread_id": result.thread_id, + "prompt": result.prompt, + "status": status, + "response": result.response, + "error": result.error, + } + ) diff --git a/backend/app/crud/__init__.py b/backend/app/crud/__init__.py index a58c5f277..963b0bf1f 100644 --- a/backend/app/crud/__init__.py +++ b/backend/app/crud/__init__.py @@ -29,3 +29,5 @@ get_api_keys_by_organization, delete_api_key, ) + +from .thread_results import upsert_thread_result, get_thread_result diff --git a/backend/app/crud/thread_results.py b/backend/app/crud/thread_results.py new file mode 100644 index 000000000..cd72ef188 --- /dev/null +++ b/backend/app/crud/thread_results.py @@ -0,0 +1,25 @@ +from sqlmodel import Session, select +from datetime import datetime +from app.models import 
OpenAIThreadCreate, OpenAI_Thread + + +def upsert_thread_result(session: Session, data: OpenAIThreadCreate): + statement = select(OpenAI_Thread).where(OpenAI_Thread.thread_id == data.thread_id) + existing = session.exec(statement).first() + + if existing: + existing.prompt = data.prompt + existing.response = data.response + existing.status = data.status + existing.error = data.error + existing.updated_at = datetime.utcnow() + else: + new_thread = OpenAI_Thread(**data.dict()) + session.add(new_thread) + + session.commit() + + +def get_thread_result(session: Session, thread_id: str) -> OpenAI_Thread | None: + statement = select(OpenAI_Thread).where(OpenAI_Thread.thread_id == thread_id) + return session.exec(statement).first() diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 693e73bbf..f88e019dc 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -51,3 +51,5 @@ CredsPublic, CredsUpdate, ) + +from .threads import OpenAI_Thread, OpenAIThreadBase, OpenAIThreadCreate diff --git a/backend/app/models/threads.py b/backend/app/models/threads.py new file mode 100644 index 000000000..e353c6760 --- /dev/null +++ b/backend/app/models/threads.py @@ -0,0 +1,21 @@ +from sqlmodel import SQLModel, Field +from typing import Optional +from datetime import datetime + + +class OpenAIThreadBase(SQLModel): + thread_id: str = Field(index=True, unique=True) + prompt: str + response: Optional[str] = None + status: Optional[str] = None + error: Optional[str] = None + + +class OpenAIThreadCreate(OpenAIThreadBase): + pass # Used for requests, no `id` or timestamps + + +class OpenAI_Thread(OpenAIThreadBase, table=True): + id: int = Field(default=None, primary_key=True) + inserted_at: datetime = Field(default_factory=datetime.utcnow) + updated_at: datetime = Field(default_factory=datetime.utcnow) diff --git a/backend/app/tests/api/routes/test_creds.py b/backend/app/tests/api/routes/test_creds.py index bdba0d8e0..2f1a0b78e 100644 --- 
a/backend/app/tests/api/routes/test_creds.py +++ b/backend/app/tests/api/routes/test_creds.py @@ -43,21 +43,16 @@ def create_organization_and_creds(db: Session, superuser_token_headers: dict[str def test_set_creds_for_org(db: Session, superuser_token_headers: dict[str, str]): - unique_org_id = 2 - existing_org = ( - db.query(Organization).filter(Organization.id == unique_org_id).first() - ) + unique_name = "Test Organization " + generate_random_string(5) - if not existing_org: - new_org = Organization( - id=unique_org_id, name="Test Organization", is_active=True - ) - db.add(new_org) - db.commit() + new_org = Organization(name=unique_name, is_active=True) + db.add(new_org) + db.commit() + db.refresh(new_org) api_key = "sk-" + generate_random_string(10) creds_data = { - "organization_id": unique_org_id, + "organization_id": new_org.id, "is_active": True, "credential": {"openai": {"api_key": api_key}}, } @@ -69,10 +64,9 @@ def test_set_creds_for_org(db: Session, superuser_token_headers: dict[str, str]) ) assert response.status_code == 200 - created_creds = response.json() assert "data" in created_creds - assert created_creds["data"]["organization_id"] == unique_org_id + assert created_creds["data"]["organization_id"] == new_org.id assert created_creds["data"]["credential"]["openai"]["api_key"] == api_key diff --git a/backend/app/tests/api/routes/test_threads.py b/backend/app/tests/api/routes/test_threads.py index 8a652a9a6..04c7daddc 100644 --- a/backend/app/tests/api/routes/test_threads.py +++ b/backend/app/tests/api/routes/test_threads.py @@ -1,6 +1,6 @@ from unittest.mock import MagicMock, patch -import pytest +import pytest, uuid from fastapi import FastAPI from fastapi.testclient import TestClient from sqlmodel import select @@ -12,9 +12,12 @@ setup_thread, process_message_content, handle_openai_error, + poll_run_and_prepare_response, ) -from app.models import APIKey +from app.models import APIKey, OpenAI_Thread +from app.crud import get_thread_result import 
openai +from openai import OpenAIError # Wrap the router in a FastAPI app instance. app = FastAPI() @@ -386,3 +389,188 @@ def test_handle_openai_error_with_none_body(): error.__str__.return_value = "None body error" result = handle_openai_error(error) assert result == "None body error" + + +@patch("app.api.routes.threads.OpenAI") +def test_poll_run_and_prepare_response_completed(mock_openai, db): + mock_client = MagicMock() + mock_run = MagicMock() + mock_run.status = "completed" + mock_client.beta.threads.runs.create_and_poll.return_value = mock_run + + mock_message = MagicMock() + mock_message.content = [MagicMock(text=MagicMock(value="Answer"))] + mock_client.beta.threads.messages.list.return_value.data = [mock_message] + mock_openai.return_value = mock_client + + request = { + "question": "What is Glific?", + "assistant_id": "assist_123", + "thread_id": "test_thread_001", + "remove_citation": True, + } + + poll_run_and_prepare_response(request, mock_client, db) + + result = get_thread_result(db, "test_thread_001") + assert result.response.strip() == "Answer" + + +@patch("app.api.routes.threads.OpenAI") +def test_poll_run_and_prepare_response_openai_error_handling(mock_openai, db): + mock_client = MagicMock() + mock_error = OpenAIError("Simulated OpenAI error") + mock_client.beta.threads.runs.create_and_poll.side_effect = mock_error + mock_openai.return_value = mock_client + + request = { + "question": "Failing run", + "assistant_id": "assist_123", + "thread_id": "test_openai_error", + } + + poll_run_and_prepare_response(request, mock_client, db) + + # Since thread_id is not the primary key, use select query + statement = select(OpenAI_Thread).where( + OpenAI_Thread.thread_id == "test_openai_error" + ) + result = db.exec(statement).first() + + assert result is not None + assert result.response is None + assert result.status == "failed" + assert "Simulated OpenAI error" in (result.error or "") + + +@patch("app.api.routes.threads.OpenAI") +def 
test_poll_run_and_prepare_response_non_completed(mock_openai, db): + mock_client = MagicMock() + mock_run = MagicMock(status="failed") + mock_client.beta.threads.runs.create_and_poll.return_value = mock_run + mock_openai.return_value = mock_client + + request = { + "question": "Incomplete run", + "assistant_id": "assist_123", + "thread_id": "test_non_complete", + } + + poll_run_and_prepare_response(request, mock_client, db) + + # thread_id is not the primary key, so we query using SELECT + statement = select(OpenAI_Thread).where( + OpenAI_Thread.thread_id == "test_non_complete" + ) + result = db.exec(statement).first() + + assert result is not None + assert result.response is None + assert result.status == "failed" + + +@patch("app.api.routes.threads.OpenAI") +def test_threads_start_endpoint_creates_thread(mock_openai, db): + """Test /threads/start creates thread and schedules background task.""" + mock_client = MagicMock() + mock_thread = MagicMock() + mock_thread.id = "mock_thread_001" + mock_client.beta.threads.create.return_value = mock_thread + mock_client.beta.threads.messages.create.return_value = None + mock_openai.return_value = mock_client + + api_key_record = db.exec(select(APIKey).where(APIKey.is_deleted == False)).first() + if not api_key_record: + pytest.skip("No API key found in the database for testing") + + headers = {"X-API-KEY": api_key_record.key} + data = {"question": "What's 2+2?", "assistant_id": "assist_123"} + + response = client.post("/threads/start", json=data, headers=headers) + assert response.status_code == 200 + res_json = response.json() + assert res_json["success"] + assert res_json["data"]["thread_id"] == "mock_thread_001" + assert res_json["data"]["status"] == "processing" + assert res_json["data"]["prompt"] == "What's 2+2?" + + +def test_threads_result_endpoint_success(db): + """Test /threads/result/{thread_id} returns completed thread.""" + thread_id = f"test_processing_{uuid.uuid4()}" + question = "Capital of France?" 
+ message = "Paris." + + db.add(OpenAI_Thread(thread_id=thread_id, prompt=question, response=message)) + db.commit() + + api_key_record = db.exec(select(APIKey).where(APIKey.is_deleted == False)).first() + if not api_key_record: + pytest.skip("No API key found in the database for testing") + + headers = {"X-API-KEY": api_key_record.key} + response = client.get(f"/threads/result/{thread_id}", headers=headers) + + assert response.status_code == 200 + data = response.json()["data"] + assert data["status"] == "success" + assert data["response"] == "Paris." + assert data["thread_id"] == thread_id + assert data["prompt"] == question + + +def test_threads_result_endpoint_processing(db): + """Test /threads/result/{thread_id} returns processing status if no message yet.""" + thread_id = f"test_processing_{uuid.uuid4()}" + question = "What is Glific?" + + db.add(OpenAI_Thread(thread_id=thread_id, prompt=question, response=None)) + db.commit() + + api_key_record = db.exec(select(APIKey).where(APIKey.is_deleted == False)).first() + if not api_key_record: + pytest.skip("No API key found in the database for testing") + + headers = {"X-API-KEY": api_key_record.key} + response = client.get(f"/threads/result/{thread_id}", headers=headers) + + assert response.status_code == 200 + data = response.json()["data"] + assert data["status"] == "processing" + assert data["response"] is None + assert data["thread_id"] == thread_id + assert data["prompt"] == question + + +def test_threads_result_not_found(db): + """Test /threads/result/{thread_id} returns error for nonexistent thread.""" + api_key_record = db.exec(select(APIKey).where(APIKey.is_deleted == False)).first() + if not api_key_record: + pytest.skip("No API key found in the database for testing") + + headers = {"X-API-KEY": api_key_record.key} + response = client.get("/threads/result/nonexistent_thread", headers=headers) + + assert response.status_code == 200 + assert response.json()["success"] is False + assert "not found" in 
response.json()["error"].lower() + + +@patch("app.api.routes.threads.OpenAI") +def test_threads_start_missing_question(mock_openai, db): + """Test /threads/start with missing 'question' key in request.""" + mock_openai.return_value = MagicMock() + + api_key_record = db.exec(select(APIKey).where(APIKey.is_deleted == False)).first() + if not api_key_record: + pytest.skip("No API key found in the database for testing") + + headers = {"X-API-KEY": api_key_record.key} + + bad_data = {"assistant_id": "assist_123"} # no "question" key + + response = client.post("/threads/start", json=bad_data, headers=headers) + + assert response.status_code == 422 # Unprocessable Entity (FastAPI will raise 422) + error_response = response.json() + assert "detail" in error_response diff --git a/backend/app/tests/conftest.py b/backend/app/tests/conftest.py index 9cd6c4971..fa36ddf0d 100644 --- a/backend/app/tests/conftest.py +++ b/backend/app/tests/conftest.py @@ -13,6 +13,8 @@ Project, ProjectUser, User, + OpenAI_Thread, + Credential, ) from app.tests.utils.user import authentication_token_from_email from app.tests.utils.utils import get_superuser_token_headers @@ -26,9 +28,11 @@ def db() -> Generator[Session, None, None]: # Delete data in reverse dependency order session.execute(delete(ProjectUser)) # Many-to-many relationship session.execute(delete(Project)) + session.execute(delete(Credential)) session.execute(delete(Organization)) session.execute(delete(APIKey)) session.execute(delete(User)) + session.execute(delete(OpenAI_Thread)) session.commit() diff --git a/backend/app/tests/crud/test_thread_result.py b/backend/app/tests/crud/test_thread_result.py new file mode 100644 index 000000000..00c581dea --- /dev/null +++ b/backend/app/tests/crud/test_thread_result.py @@ -0,0 +1,47 @@ +import pytest +from sqlmodel import SQLModel, Session, create_engine + +from app.models import OpenAI_Thread, OpenAIThreadCreate +from app.crud import upsert_thread_result, get_thread_result + + +def 
test_upsert_and_get_thread_result(db: Session): + thread_id = "thread_test_123" + prompt = "What is the capital of Spain?" + response = "Madrid is the capital of Spain." + + # Insert + upsert_thread_result( + db, + OpenAIThreadCreate( + thread_id=thread_id, + prompt=prompt, + response=response, + status="completed", + error=None, + ), + ) + + # Retrieve + result = get_thread_result(db, thread_id) + + assert result is not None + assert result.thread_id == thread_id + assert result.prompt == prompt + assert result.response == response + + # Update with new response + updated_response = "Madrid." + upsert_thread_result( + db, + OpenAIThreadCreate( + thread_id=thread_id, + prompt=prompt, + response=updated_response, + status="completed", + error=None, + ), + ) + + result_updated = get_thread_result(db, thread_id) + assert result_updated.response == updated_response