Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion backend/app/api/main.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from fastapi import APIRouter

from app.api.routes import items, login, private, users, utils
from app.api.routes import items, login, private, users, utils, threads
from app.core.config import settings

# Aggregate all feature routers into one router that the main app mounts
# under the API prefix; threads is the newly added assistant-run route.
api_router = APIRouter()
api_router.include_router(login.router)
api_router.include_router(users.router)
api_router.include_router(utils.router)
api_router.include_router(items.router)
api_router.include_router(threads.router)


if settings.ENVIRONMENT == "local":
Expand Down
172 changes: 172 additions & 0 deletions backend/app/api/routes/threads.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
import openai
import re
import requests
from openai import OpenAI
from fastapi import APIRouter, BackgroundTasks
from app.models import ( MessageRequest, AckPayload, CallbackPayload)
from ...core.config import settings
from ...core.logger import logging

logger = logging.getLogger(__name__)
router = APIRouter(tags=["threads"])


def send_callback(callback_url: str, data: dict) -> bool:
    """Send results to the callback URL (synchronously).

    Args:
        callback_url: Destination URL for the POST request.
        data: JSON-serialisable payload to deliver.

    Returns:
        True when the callback endpoint acknowledged with a 2xx response,
        False on any network or HTTP error (the error is logged).
    """
    try:
        # A timeout is essential: this runs inside a background task, and a
        # hung POST with no timeout would block the worker indefinitely.
        # (Also drops the previous requests.Session that was never closed.)
        response = requests.post(callback_url, json=data, timeout=30)
        response.raise_for_status()
        return True
    except requests.RequestException as e:
        logger.error(f"Callback failed: {str(e)}")
        return False


def build_callback_payload(request: MessageRequest, status: str, message: str) -> CallbackPayload:
    """
    Build the CallbackPayload for a MessageRequest.

    Carries over any extra fields the caller supplied on the request
    (MessageRequest allows extras), but guarantees that status, message
    and thread_id reflect this run's actual outcome rather than whatever
    the caller happened to send.
    """
    # Start from the request's extra/optional fields...
    data = request.model_dump(
        exclude={"question", "assistant_id", "callback_url", "thread_id"}
    )

    # ...then set the authoritative fields LAST, so a stray "status" or
    # "message" key in the incoming request cannot clobber the real result.
    # (Previously data.update() ran after these were set, allowing exactly
    # that overwrite.)
    data["status"] = status
    data["message"] = message
    data["thread_id"] = request.thread_id
    # Keep a request-provided endpoint when present; otherwise fall back.
    data.setdefault("endpoint", getattr(request, "endpoint", "some-default-endpoint"))

    return CallbackPayload(**data)


# Matches OpenAI file-citation markers such as 【12:3†source】; hoisted so the
# pattern is compiled once rather than on every run.
_CITATION_PATTERN = re.compile(r"【\d+(?::\d+)?†[^】]*】")


def _openai_error_message(e: openai.OpenAIError) -> str:
    """Return a human-readable message for an OpenAI API error."""
    # getattr: not every OpenAIError subclass carries a `body` attribute.
    body = getattr(e, "body", None)
    if isinstance(body, dict) and "message" in body:
        return body["message"]
    return str(e)


def process_run(request: MessageRequest, client: OpenAI):
    """
    Background task: run create_and_poll on the thread, then deliver the
    result (or the failure reason) to the request's callback URL.

    This runs after the endpoint has already returned its "processing"
    acknowledgement, so EVERY outcome — success, failed run, or an
    unexpected exception — must end with a callback; the client has no
    other way to learn the result.
    """
    try:
        # Start the run and block until it reaches a terminal status.
        run = client.beta.threads.runs.create_and_poll(
            thread_id=request.thread_id,
            assistant_id=request.assistant_id,
        )

        if run.status == "completed":
            messages = client.beta.threads.messages.list(thread_id=request.thread_id)
            # data[0] — presumably newest-first ordering; TODO confirm against
            # the OpenAI SDK's default sort for messages.list.
            message_content = messages.data[0].content[0].text.value

            if request.remove_citation:
                message = _CITATION_PATTERN.sub("", message_content)
            else:
                message = message_content
            callback_response = build_callback_payload(
                request=request, status="success", message=message
            )
        else:
            callback_response = build_callback_payload(
                request=request, status="error", message=f"Run failed with status: {run.status}"
            )
    except openai.OpenAIError as e:
        callback_response = build_callback_payload(
            request=request, status="error", message=_openai_error_message(e)
        )
    except Exception as e:
        # Last resort: previously e.g. an IndexError on an empty message list
        # killed the task silently and the caller never got a callback.
        logger.exception("Unexpected error while processing run")
        callback_response = build_callback_payload(
            request=request, status="error", message=str(e)
        )

    # Always deliver the outcome, whatever it was.
    send_callback(request.callback_url, callback_response.model_dump())


def validate_assistant_id(assistant_id: str, client: OpenAI):
    """Check that the assistant exists on the OpenAI account.

    Returns:
        None when the assistant ID is valid; otherwise an error AckPayload
        that the endpoint can return to the caller directly.
    """
    try:
        client.beta.assistants.retrieve(assistant_id=assistant_id)
        return None
    except openai.NotFoundError:
        return AckPayload(
            status="error",
            message=f"Invalid assistant ID provided {assistant_id}",
            success=False,
        )


@router.post("/threads")
async def threads(request: MessageRequest, background_tasks: BackgroundTasks):
    """
    Accepts a question, assistant_id, callback_url, and optional thread_id from the request body.
    Returns an immediate "processing" response, then continues to run create_and_poll in background.
    Once completed, calls send_callback with the final result.
    """
    client = OpenAI(api_key=settings.OPENAI_API_KEY)

    assistant_error = validate_assistant_id(request.assistant_id, client)
    if assistant_error:
        return assistant_error

    def _error_message(e: openai.OpenAIError) -> str:
        # Prefer the structured message from the API error body when present.
        body = getattr(e, "body", None)
        if isinstance(body, dict) and "message" in body:
            return body["message"]
        return str(e)

    # 1. Validate or check if there's an existing thread with an in-progress run
    if request.thread_id:
        try:
            runs = client.beta.threads.runs.list(thread_id=request.thread_id)
            # Get the most recent run (first in the list) if any
            if runs.data:
                latest_run = runs.data[0]
                if latest_run.status in ["queued", "in_progress", "requires_action"]:
                    # Return an AckPayload so every error response has the
                    # same shape (this branch previously returned a bare dict
                    # without the `success` field).
                    return AckPayload(
                        status="error",
                        message=f"There is an active run on this thread (status: {latest_run.status}). Please wait for it to complete.",
                        success=False,
                        thread_id=request.thread_id,
                    )
        except openai.NotFoundError:
            # Handle invalid thread ID
            return AckPayload(
                status="error",
                message=f"Invalid thread ID provided {request.thread_id}",
                success=False,
            )

        try:
            # Use existing thread
            client.beta.threads.messages.create(
                thread_id=request.thread_id, role="user", content=request.question
            )
        except openai.OpenAIError as e:
            # Previously unguarded: an API failure here surfaced as a 500.
            return AckPayload(status="error", message=_error_message(e), success=False)
    else:
        try:
            # Create new thread and post the user's question to it.
            thread = client.beta.threads.create()
            client.beta.threads.messages.create(
                thread_id=thread.id, role="user", content=request.question
            )
            request.thread_id = thread.id
        except openai.OpenAIError as e:
            return AckPayload(status="error", message=_error_message(e), success=False)

    # 2. Send immediate response to complete the API call
    initial_response = AckPayload(
        status="processing",
        message="Run started",
        thread_id=request.thread_id,
        success=True,
    )

    # 3. Schedule the background task to run create_and_poll and send callback
    background_tasks.add_task(process_run, request, client)

    # 4. Return immediately so the client knows we've accepted the request
    return initial_response
1 change: 1 addition & 0 deletions backend/app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class Settings(BaseSettings):
env_ignore_empty=True,
extra="ignore",
)
OPENAI_API_KEY: str
API_V1_STR: str = "/api/v1"
SECRET_KEY: str = secrets.token_urlsafe(32)
# 60 minutes * 24 hours * 8 days = 8 days
Expand Down
20 changes: 20 additions & 0 deletions backend/app/core/logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import logging
import os
from logging.handlers import RotatingFileHandler

# Directory for application log files: <backend/app>/logs (two levels up from
# this module, i.e. a sibling of core/).
LOG_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs")
# exist_ok=True makes creation idempotent and removes the check-then-create
# race when several workers import this module concurrently; a genuine
# permission/disk failure still raises loudly at import time.
os.makedirs(LOG_DIR, exist_ok=True)
Comment on lines +7 to +8
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

kody code-review Error Handling high

try:
    if not os.path.exists(LOG_DIR):
        os.makedirs(LOG_DIR)
except (OSError, IOError) as e:
    import sys
    sys.stderr.write(f'Failed to create log directory: {e}\n')
    raise

Add error handling around file operations to gracefully handle permission issues or disk space problems when creating log directory and files

Talk to Kody by mentioning @kody

Was this suggestion helpful? React with 👍 or 👎 to help Kody learn from this interaction.


LOG_FILE_PATH = os.path.join(LOG_DIR, "app.log")

LOGGING_LEVEL = logging.INFO
LOGGING_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"

# Console configuration for the root logger.
logging.basicConfig(level=LOGGING_LEVEL, format=LOGGING_FORMAT)

# Rotating file log: 10 MiB per file, 5 backups (~60 MiB max on disk).
# Explicit UTF-8 encoding keeps non-ASCII log messages from raising on
# platforms whose default file encoding is not UTF-8.
file_handler = RotatingFileHandler(
    LOG_FILE_PATH, maxBytes=10485760, backupCount=5, encoding="utf-8"
)
file_handler.setLevel(LOGGING_LEVEL)
file_handler.setFormatter(logging.Formatter(LOGGING_FORMAT))

# Attach to the root logger so every module-level logger inherits the handler.
logging.getLogger("").addHandler(file_handler)
3 changes: 2 additions & 1 deletion backend/app/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@
UserUpdateMe,
NewPassword,
UpdatePassword,
)
)
from .thread import MessageRequest, AckPayload, CallbackPayload
28 changes: 28 additions & 0 deletions backend/app/models/thread.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from pydantic import BaseModel, ConfigDict
from typing import Optional


class MessageRequest(BaseModel):
    """Incoming request body for the threads endpoint."""

    # extra="allow": unknown fields are accepted and preserved on the model
    # (they are later echoed back in the callback payload).
    model_config = ConfigDict(extra="allow")

    question: str                            # user question posted to the thread
    assistant_id: str                        # OpenAI assistant to run
    callback_url: str                        # URL that receives the final result
    thread_id: Optional[str] = None          # reuse an existing thread when set
    remove_citation: Optional[bool] = False  # strip 【…†…】 citation markers from the answer


class AckPayload(BaseModel):
    """Immediate acknowledgement returned by the threads endpoint."""

    status: str                      # "processing" on accept, "error" on validation failure
    message: str                     # human-readable detail
    success: bool                    # False for validation failures
    thread_id: Optional[str] = None  # set once a thread exists for the request


class CallbackPayload(BaseModel):
    """Final result delivered via POST to the caller's callback_url."""

    # extra="allow": extra fields copied from the original request pass through.
    model_config = ConfigDict(extra="allow")

    status: str     # "success" or "error"
    message: str    # assistant answer, or an error description
    thread_id: str  # thread the run executed on
    endpoint: str   # routing hint echoed from the request (or a default)
132 changes: 132 additions & 0 deletions backend/app/tests/api/routes/test_threads.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import pytest
import openai
from unittest.mock import MagicMock, patch
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Standard library imports should come first. See PEP8

from fastapi import FastAPI
from fastapi.testclient import TestClient

from app.api.routes.threads import router, process_run, validate_assistant_id
# Fixed import path: `from models.thread import ...` does not match the
# project's absolute-import layout (everything else uses `app.models...`).
from app.models.thread import MessageRequest, AckPayload

# Wrap the router in a FastAPI app instance so the routes can be exercised
# through Starlette's TestClient.
app = FastAPI()
app.include_router(router)
client = TestClient(app)


# Patch target must be the module where OpenAI is *looked up*
# (app.api.routes.threads); the previous "src.app.api.v1.threads" path does
# not exist and made @patch raise ModuleNotFoundError.
@patch("app.api.routes.threads.OpenAI")
def test_threads_endpoint(mock_openai):
    """
    Test the /threads endpoint when creating a new thread.
    The patched OpenAI client simulates:
        - A successful assistant ID validation.
        - New thread creation with a dummy thread id.
        - No existing runs.
    The expected response should have status "processing" and include a thread_id.
    """
    # Create a dummy client to simulate OpenAI API behavior.
    dummy_client = MagicMock()
    # Simulate a valid assistant ID by ensuring retrieve doesn't raise an error.
    dummy_client.beta.assistants.retrieve.return_value = None
    # Simulate thread creation.
    dummy_thread = MagicMock()
    dummy_thread.id = "dummy_thread_id"
    dummy_client.beta.threads.create.return_value = dummy_thread
    # Simulate message creation.
    dummy_client.beta.threads.messages.create.return_value = None
    # Simulate that no active run exists.
    dummy_client.beta.threads.runs.list.return_value = MagicMock(data=[])

    mock_openai.return_value = dummy_client

    request_data = {
        "question": "What is Glific?",
        "assistant_id": "assistant_123",
        "callback_url": "http://example.com/callback",
    }
    response = client.post("/threads", json=request_data)
    assert response.status_code == 200
    response_json = response.json()
    assert response_json["status"] == "processing"
    assert response_json["message"] == "Run started"
    assert response_json["success"] is True
    assert response_json["thread_id"] == "dummy_thread_id"


# Patch target corrected to the real module path (app.api.routes.threads);
# "src.app.api.v1.threads" does not exist and would raise ModuleNotFoundError.
@patch("app.api.routes.threads.OpenAI")
@pytest.mark.parametrize(
    "remove_citation, expected_message",
    [
        (
            True,
            "Glific is an open-source, two-way messaging platform designed for nonprofits to scale their outreach via WhatsApp",
        ),
        (
            False,
            "Glific is an open-source, two-way messaging platform designed for nonprofits to scale their outreach via WhatsApp【1:2†citation】",
        ),
    ],
)
def test_process_run_variants(mock_openai, remove_citation, expected_message):
    """
    Test process_run for both remove_citation variants:
      - Mocks the OpenAI client to simulate a completed run.
      - Verifies that send_callback is called with the expected message based on the remove_citation flag.
    """
    # Setup the mock client.
    mock_client = MagicMock()
    mock_openai.return_value = mock_client

    # Create the request with the variable remove_citation flag.
    request = MessageRequest(
        question="What is Glific?",
        assistant_id="assistant_123",
        callback_url="http://example.com/callback",
        thread_id="thread_123",
        remove_citation=remove_citation,
    )

    # Simulate a completed run.
    mock_run = MagicMock()
    mock_run.status = "completed"
    mock_client.beta.threads.runs.create_and_poll.return_value = mock_run

    # Set up the dummy message based on the remove_citation flag.
    base_message = "Glific is an open-source, two-way messaging platform designed for nonprofits to scale their outreach via WhatsApp"
    citation_message = base_message if remove_citation else f"{base_message}【1:2†citation】"
    dummy_message = MagicMock()
    dummy_message.content = [MagicMock(text=MagicMock(value=citation_message))]
    mock_client.beta.threads.messages.list.return_value.data = [dummy_message]

    # Patch send_callback where process_run looks it up, and invoke process_run.
    with patch("app.api.routes.threads.send_callback") as mock_send_callback:
        process_run(request, mock_client)
        mock_send_callback.assert_called_once()
        callback_url, payload = mock_send_callback.call_args[0]
        assert callback_url == request.callback_url
        assert payload.get("message", "") == expected_message
        assert payload.get("status", "") == "success"
        assert payload.get("thread_id", "") == "thread_123"


# Patch target corrected to the real module path; the old "src.app.api.v1"
# path would make @patch fail before the test body ran.
@patch("app.api.routes.threads.OpenAI")
def test_validate_assistant_id(mock_openai):
    """
    Test validate_assistant_id:
      - For a valid assistant ID, it should return None.
      - For an invalid assistant ID, it should return an AckPayload with an error.
    """
    mock_client = MagicMock()
    mock_openai.return_value = mock_client

    # Simulate a valid assistant ID.
    result_valid = validate_assistant_id("valid_assistant_id", mock_client)
    assert result_valid is None

    # Simulate an invalid assistant ID by raising NotFoundError with required kwargs.
    mock_client.beta.assistants.retrieve.side_effect = openai.NotFoundError(
        "Not found", response=MagicMock(), body={"message": "Not found"}
    )
    ack_payload = validate_assistant_id("invalid_assistant_id", mock_client)
    assert isinstance(ack_payload, AckPayload)
    assert ack_payload.status == "error"
    assert "invalid_assistant_id" in ack_payload.message
Loading