Skip to content

Commit 67b89c8

Browse files
authored
OpenAI: Threads (#40)
* getting threads up and running * added testcases and citation * removing ssl verify * using standardized APIResponse * getting rid of redundant files * refactor code after testing * refactor testcases * setting up init.py * fixing review comments * cleanup * cleanup * removed validate thread as it can be handled by default * fixing few code review suggestions * removed validation testcases for assistant ID
1 parent 261efb7 commit 67b89c8

File tree

8 files changed

+392
-2
lines changed

8 files changed

+392
-2
lines changed

backend/app/api/main.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
from fastapi import APIRouter
2-
from app.api.routes import items, login, private, users, utils,project,organization, project_user, api_keys
2+
3+
from app.api.routes import items, login, private, users, utils, project, organization, project_user, api_keys, threads
34
from app.core.config import settings
45

56
api_router = APIRouter()
67
api_router.include_router(login.router)
78
api_router.include_router(users.router)
89
api_router.include_router(utils.router)
910
api_router.include_router(items.router)
11+
api_router.include_router(threads.router)
1012
api_router.include_router(organization.router)
1113
api_router.include_router(project.router)
1214
api_router.include_router(project_user.router)

backend/app/api/routes/threads.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
import re
2+
import requests
3+
4+
import openai
5+
from openai import OpenAI
6+
from fastapi import APIRouter, BackgroundTasks
7+
8+
from app.utils import APIResponse
9+
from app.core import settings, logging
10+
11+
logger = logging.getLogger(__name__)
12+
router = APIRouter(tags=["threads"])
13+
14+
15+
def send_callback(callback_url: str, data: dict) -> bool:
    """Send results to the callback URL (synchronously).

    Args:
        callback_url: Absolute URL to POST the result payload to.
        data: JSON-serializable payload (an APIResponse dump).

    Returns:
        True if the POST succeeded with a 2xx status, False otherwise.
    """
    try:
        # FIX: context-manage the Session so its connection pool is always
        # released; the original leaked the session on every call.
        with requests.Session() as session:
            # uncomment this to run locally without SSL
            # session.verify = False
            response = session.post(callback_url, json=data)
            response.raise_for_status()
        return True
    except requests.RequestException as e:
        logger.error(f"Callback failed: {str(e)}")
        return False
27+
28+
29+
def process_run(request: dict, client: OpenAI):
    """
    Background task to run create_and_poll, then send the callback with the result.
    This function is run in the background after we have already returned an initial response.

    Expected request keys: "thread_id", "assistant_id", "callback_url";
    optional: "remove_citation" (bool) plus arbitrary passthrough fields that
    are echoed back in the callback payload.
    """
    try:
        # Start the run and block until it reaches a terminal status.
        run = client.beta.threads.runs.create_and_poll(
            thread_id=request["thread_id"],
            assistant_id=request["assistant_id"],
        )

        if run.status == "completed":
            # The newest message (the assistant's answer) is first in the list.
            messages = client.beta.threads.messages.list(
                thread_id=request["thread_id"])
            latest_message = messages.data[0]
            message_content = latest_message.content[0].text.value

            remove_citation = request.get("remove_citation", False)

            if remove_citation:
                # Strip OpenAI file-citation markers such as 【12:3†source】.
                message = re.sub(r"【\d+(?::\d+)?†[^】]*】", "", message_content)
            else:
                message = message_content

            # Update the data dictionary with additional fields from the request, excluding specific keys
            additional_data = {k: v for k, v in request.items(
            ) if k not in {"question", "assistant_id", "callback_url", "thread_id"}}
            callback_response = APIResponse.success_response(data={
                "status": "success",
                "message": message,
                "thread_id": request["thread_id"],
                # FIX: request is a plain dict, so getattr() could never find
                # an "endpoint" key and always returned the default; use .get().
                "endpoint": request.get("endpoint", "some-default-endpoint"),
                **additional_data
            })
        else:
            callback_response = APIResponse.failure_response(
                error=f"Run failed with status: {run.status}")

        # Send callback with results
        send_callback(request["callback_url"], callback_response.model_dump())

    except openai.OpenAIError as e:
        # Prefer the structured error message when the API provides one.
        if isinstance(e.body, dict) and "message" in e.body:
            error_message = e.body["message"]
        else:
            error_message = str(e)

        callback_response = APIResponse.failure_response(error=error_message)

        send_callback(request["callback_url"], callback_response.model_dump())
81+
82+
83+
@router.post("/threads")
async def threads(request: dict, background_tasks: BackgroundTasks):
    """
    Accepts a question, assistant_id, callback_url, and optional thread_id from the request body.
    Returns an immediate "processing" response, then continues to run create_and_poll in background.
    Once completed, calls send_callback with the final result.
    """
    client = OpenAI(api_key=settings.OPENAI_API_KEY)

    # Use get method to safely access thread_id
    thread_id = request.get("thread_id")

    # 1. Validate or check if there's an existing thread with an in-progress run
    if thread_id:
        try:
            runs = client.beta.threads.runs.list(thread_id=thread_id)
            # Get the most recent run (first in the list) if any
            if runs.data and len(runs.data) > 0:
                latest_run = runs.data[0]
                if latest_run.status in ["queued", "in_progress", "requires_action"]:
                    return APIResponse.failure_response(error=f"There is an active run on this thread (status: {latest_run.status}). Please wait for it to complete.")
        except openai.OpenAIError:
            # Handle invalid thread ID
            return APIResponse.failure_response(error=f"Invalid thread ID provided {thread_id}")

        try:
            # Use existing thread
            client.beta.threads.messages.create(
                thread_id=thread_id, role="user", content=request["question"]
            )
        except openai.OpenAIError as e:
            # ROBUSTNESS FIX: mirror the new-thread branch so an API failure
            # here returns a structured error instead of an unhandled 500.
            if isinstance(e.body, dict) and "message" in e.body:
                error_message = e.body["message"]
            else:
                error_message = str(e)
            return APIResponse.failure_response(error=error_message)
    else:
        try:
            # Create new thread
            thread = client.beta.threads.create()
            client.beta.threads.messages.create(
                thread_id=thread.id, role="user", content=request["question"]
            )
            request["thread_id"] = thread.id
        except openai.OpenAIError as e:
            # Handle any other OpenAI API errors
            if isinstance(e.body, dict) and "message" in e.body:
                error_message = e.body["message"]
            else:
                error_message = str(e)
            return APIResponse.failure_response(error=error_message)

    # 2. Send immediate response to complete the API call
    initial_response = APIResponse.success_response(data={
        "status": "processing",
        "message": "Run started",
        "thread_id": request.get("thread_id"),
        "success": True,
    })

    # 3. Schedule the background task to run create_and_poll and send callback
    background_tasks.add_task(process_run, request, client)

    # 4. Return immediately so the client knows we've accepted the request
    return initial_response

backend/app/core/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
"""Public surface of the app.core package.

Re-exports the application settings object and the configured logging module
so callers can write ``from app.core import settings, logging``.
"""

from .config import settings
from .logger import logging

__all__ = ['settings', 'logging']

backend/app/core/config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import secrets
22
import warnings
3+
import os
34
from typing import Annotated, Any, Literal
45

56
from pydantic import (
@@ -31,6 +32,7 @@ class Settings(BaseSettings):
3132
env_ignore_empty=True,
3233
extra="ignore",
3334
)
35+
OPENAI_API_KEY: str
3436
API_V1_STR: str = "/api/v1"
3537
SECRET_KEY: str = secrets.token_urlsafe(32)
3638
# 60 minutes * 24 hours * 1 days = 1 days
@@ -95,6 +97,9 @@ def emails_enabled(self) -> bool:
9597
FIRST_SUPERUSER: EmailStr
9698
FIRST_SUPERUSER_PASSWORD: str
9799

100+
LOG_DIR: str = os.path.join(os.path.dirname(
101+
os.path.dirname(__file__)), "logs")
102+
98103
def _check_default_secret(self, var_name: str, value: str | None) -> None:
99104
if value == "changethis":
100105
message = (

backend/app/core/logger.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import logging
import os
from logging.handlers import RotatingFileHandler
from app.core.config import settings

# Log directory comes from application settings (defaults to backend/app/logs).
LOG_DIR = settings.LOG_DIR
# FIX: exist_ok avoids the check-then-create race when several workers
# import this module at the same time (the old exists()/makedirs pair
# could raise FileExistsError).
os.makedirs(LOG_DIR, exist_ok=True)

LOG_FILE_PATH = os.path.join(LOG_DIR, "app.log")

LOGGING_LEVEL = logging.INFO
LOGGING_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"

# Console handler for every logger obtained via logging.getLogger(__name__).
logging.basicConfig(level=LOGGING_LEVEL, format=LOGGING_FORMAT)

# Rotate the file at ~10 MB, keeping the 5 most recent backups.
file_handler = RotatingFileHandler(
    LOG_FILE_PATH, maxBytes=10485760, backupCount=5)
file_handler.setLevel(LOGGING_LEVEL)
file_handler.setFormatter(logging.Formatter(LOGGING_FORMAT))

# Attach to the root logger so records from all modules propagate to the file.
logging.getLogger().addHandler(file_handler)
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
import pytest
2+
import openai
3+
4+
from unittest.mock import MagicMock, patch
5+
from fastapi import FastAPI
6+
from fastapi.testclient import TestClient
7+
8+
from app.api.routes.threads import router, process_run
9+
from app.utils import APIResponse
10+
11+
# Mount the threads router on a throwaway FastAPI app so the endpoint can be
# exercised through a TestClient without the full application.
app = FastAPI()
app.include_router(router)
client = TestClient(app)
15+
16+
17+
@patch("src.app.api.v1.threads.OpenAI")
18+
def test_threads_endpoint(mock_openai):
19+
"""
20+
Test the /threads endpoint when creating a new thread.
21+
The patched OpenAI client simulates:
22+
- A successful assistant ID validation.
23+
- New thread creation with a dummy thread id.
24+
- No existing runs.
25+
The expected response should have status "processing" and include a thread_id.
26+
"""
27+
# Create a dummy client to simulate OpenAI API behavior.
28+
dummy_client = MagicMock()
29+
# Simulate a valid assistant ID by ensuring retrieve doesn't raise an error.
30+
dummy_client.beta.assistants.retrieve.return_value = None
31+
# Simulate thread creation.
32+
dummy_thread = MagicMock()
33+
dummy_thread.id = "dummy_thread_id"
34+
dummy_client.beta.threads.create.return_value = dummy_thread
35+
# Simulate message creation.
36+
dummy_client.beta.threads.messages.create.return_value = None
37+
# Simulate that no active run exists.
38+
dummy_client.beta.threads.runs.list.return_value = MagicMock(data=[])
39+
40+
mock_openai.return_value = dummy_client
41+
42+
request_data = {
43+
"question": "What is Glific?",
44+
"assistant_id": "assistant_123",
45+
"callback_url": "http://example.com/callback",
46+
}
47+
response = client.post("/threads", json=request_data)
48+
assert response.status_code == 200
49+
response_json = response.json()
50+
assert response_json["success"] is True
51+
assert response_json["data"]["status"] == "processing"
52+
assert response_json["data"]["message"] == "Run started"
53+
assert response_json["data"]["thread_id"] == "dummy_thread_id"
54+
55+
56+
@patch("src.app.api.v1.threads.OpenAI")
57+
@pytest.mark.parametrize(
58+
"remove_citation, expected_message",
59+
[
60+
(
61+
True,
62+
"Glific is an open-source, two-way messaging platform designed for nonprofits to scale their outreach via WhatsApp",
63+
),
64+
(
65+
False,
66+
"Glific is an open-source, two-way messaging platform designed for nonprofits to scale their outreach via WhatsApp【1:2†citation】",
67+
),
68+
],
69+
)
70+
def test_process_run_variants(mock_openai, remove_citation, expected_message):
71+
"""
72+
Test process_run for both remove_citation variants:
73+
- Mocks the OpenAI client to simulate a completed run.
74+
- Verifies that send_callback is called with the expected message based on the remove_citation flag.
75+
"""
76+
# Setup the mock client.
77+
mock_client = MagicMock()
78+
mock_openai.return_value = mock_client
79+
80+
# Create the request with the variable remove_citation flag.
81+
request = {
82+
"question": "What is Glific?",
83+
"assistant_id": "assistant_123",
84+
"callback_url": "http://example.com/callback",
85+
"thread_id": "thread_123",
86+
"remove_citation": remove_citation,
87+
}
88+
89+
# Simulate a completed run.
90+
mock_run = MagicMock()
91+
mock_run.status = "completed"
92+
mock_client.beta.threads.runs.create_and_poll.return_value = mock_run
93+
94+
# Set up the dummy message based on the remove_citation flag.
95+
base_message = "Glific is an open-source, two-way messaging platform designed for nonprofits to scale their outreach via WhatsApp"
96+
citation_message = base_message if remove_citation else f"{base_message}【1:2†citation】"
97+
dummy_message = MagicMock()
98+
dummy_message.content = [MagicMock(text=MagicMock(value=citation_message))]
99+
mock_client.beta.threads.messages.list.return_value.data = [dummy_message]
100+
101+
# Patch send_callback and invoke process_run.
102+
with patch("src.app.api.v1.threads.send_callback") as mock_send_callback:
103+
process_run(request, mock_client)
104+
mock_send_callback.assert_called_once()
105+
callback_url, payload = mock_send_callback.call_args[0]
106+
print(payload)
107+
assert callback_url == request["callback_url"]
108+
assert payload["data"]["message"] == expected_message
109+
assert payload["data"]["status"] == "success"
110+
assert payload["data"]["thread_id"] == "thread_123"
111+
assert payload["success"] is True

backend/pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ dependencies = [
2121
"pydantic-settings<3.0.0,>=2.2.1",
2222
"sentry-sdk[fastapi]<2.0.0,>=1.40.6",
2323
"pyjwt<3.0.0,>=2.8.0",
24+
"openai>=1.67.0",
25+
"pytest>=7.4.4",
2426
]
2527

2628
[tool.uv]

0 commit comments

Comments
 (0)