-
Notifications
You must be signed in to change notification settings - Fork 10
OpenAI: Threads #40
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
OpenAI: Threads #40
Changes from 3 commits
7e4617f
743e5d0
00ebe6f
c24b58f
15aa41c
71bc5a8
8f11fa5
ccf7ed3
2e32a38
5081383
d45984a
941a81c
ad85de8
ee09d3e
5546222
13d913a
146341e
7c79dea
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,172 @@ | ||
| import openai | ||
| import re | ||
| import requests | ||
| from openai import OpenAI | ||
| from fastapi import APIRouter, BackgroundTasks | ||
| from app.models import ( MessageRequest, AckPayload, CallbackPayload) | ||
| from ...core.config import settings | ||
| from ...core.logger import logging | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
| router = APIRouter(tags=["threads"]) | ||
|
|
||
|
|
||
| def send_callback(callback_url: str, data: dict): | ||
| """Send results to the callback URL (synchronously).""" | ||
| try: | ||
| session = requests.Session() | ||
| response = session.post(callback_url, json=data) | ||
| response.raise_for_status() | ||
| return True | ||
| except requests.RequestException as e: | ||
| logger.error(f"Callback failed: {str(e)}") | ||
| return False | ||
|
|
||
|
|
||
| def build_callback_payload(request: MessageRequest, status: str, message: str) -> CallbackPayload: | ||
| """ | ||
| Helper function to build the CallbackPayload from a MessageRequest. | ||
| """ | ||
| data = { | ||
| "status": status, | ||
| "message": message, | ||
| "thread_id": request.thread_id, | ||
| "endpoint": getattr(request, "endpoint", "some-default-endpoint"), | ||
| } | ||
|
|
||
| # Update with any additional fields from request that we haven't excluded | ||
| data.update( | ||
| request.model_dump(exclude={"question", "assistant_id", "callback_url", "thread_id"}) | ||
| ) | ||
|
|
||
| return CallbackPayload(**data) | ||
|
|
||
|
|
||
| def process_run(request: MessageRequest, client: OpenAI): | ||
| """ | ||
| Background task to run create_and_poll, then send the callback with the result. | ||
| This function is run in the background after we have already returned an initial response. | ||
| """ | ||
| try: | ||
| # Start the run | ||
| run = client.beta.threads.runs.create_and_poll( | ||
| thread_id=request.thread_id, | ||
| assistant_id=request.assistant_id, | ||
| ) | ||
|
|
||
| if run.status == "completed": | ||
| messages = client.beta.threads.messages.list(thread_id=request.thread_id) | ||
| latest_message = messages.data[0] | ||
| message_content = latest_message.content[0].text.value | ||
|
|
||
| if request.remove_citation: | ||
| message = re.sub(r"【\d+(?::\d+)?†[^】]*】", "", message_content) | ||
| else: | ||
| message = message_content | ||
| callback_response = build_callback_payload( | ||
| request=request, status="success", message=message | ||
| ) | ||
| else: | ||
| callback_response = build_callback_payload( | ||
| request=request, status="error", message=f"Run failed with status: {run.status}" | ||
| ) | ||
|
|
||
| # Send callback with results | ||
| send_callback(request.callback_url, callback_response.model_dump()) | ||
|
|
||
| except openai.OpenAIError as e: | ||
| # Handle any other OpenAI API errors | ||
| if isinstance(e.body, dict) and "message" in e.body: | ||
| error_message = e.body["message"] | ||
| else: | ||
| error_message = str(e) | ||
|
|
||
| callback_response = build_callback_payload( | ||
| request=request, status="error", message=error_message | ||
| ) | ||
|
|
||
| send_callback(request.callback_url, callback_response.model_dump()) | ||
|
|
||
|
|
||
| def validate_assistant_id(assistant_id: str, client: OpenAI): | ||
| try: | ||
| client.beta.assistants.retrieve(assistant_id=assistant_id) | ||
| except openai.NotFoundError: | ||
| return AckPayload( | ||
| status="error", | ||
| message=f"Invalid assistant ID provided {assistant_id}", | ||
| success=False, | ||
| ) | ||
| return None | ||
jerome-white marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| @router.post("/threads") | ||
AkhileshNegi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| async def threads(request: MessageRequest, background_tasks: BackgroundTasks): | ||
| """ | ||
| Accepts a question, assistant_id, callback_url, and optional thread_id from the request body. | ||
| Returns an immediate "processing" response, then continues to run create_and_poll in background. | ||
| Once completed, calls send_callback with the final result. | ||
| """ | ||
| client = OpenAI(api_key=settings.OPENAI_API_KEY) | ||
AkhileshNegi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| assistant_error = validate_assistant_id(request.assistant_id, client) | ||
| if assistant_error: | ||
| return assistant_error | ||
|
|
||
| # 1. Validate or check if there's an existing thread with an in-progress run | ||
| if request.thread_id: | ||
| try: | ||
| runs = client.beta.threads.runs.list(thread_id=request.thread_id) | ||
| # Get the most recent run (first in the list) if any | ||
| if runs.data and len(runs.data) > 0: | ||
| latest_run = runs.data[0] | ||
| if latest_run.status in ["queued", "in_progress", "requires_action"]: | ||
| return { | ||
| "status": "error", | ||
| "message": f"There is an active run on this thread (status: {latest_run.status}). Please wait for it to complete.", | ||
| } | ||
| except openai.NotFoundError: | ||
| # Handle invalid thread ID | ||
| return AckPayload( | ||
| status="error", | ||
| message=f"Invalid thread ID provided {request.thread_id}", | ||
| success=False, | ||
| ) | ||
|
|
||
| # Use existing thread | ||
| client.beta.threads.messages.create( | ||
| thread_id=request.thread_id, role="user", content=request.question | ||
| ) | ||
| else: | ||
| try: | ||
| # Create new thread | ||
| thread = client.beta.threads.create() | ||
| client.beta.threads.messages.create( | ||
| thread_id=thread.id, role="user", content=request.question | ||
| ) | ||
| request.thread_id = thread.id | ||
| except openai.OpenAIError as e: | ||
| # Handle any other OpenAI API errors | ||
| if isinstance(e.body, dict) and "message" in e.body: | ||
| error_message = e.body["message"] | ||
| else: | ||
| error_message = str(e) | ||
| return AckPayload( | ||
| status="error", | ||
| message=error_message, | ||
| success=False, | ||
| ) | ||
|
|
||
| # 2. Send immediate response to complete the API call | ||
| initial_response = AckPayload( | ||
| status="processing", | ||
| message="Run started", | ||
| thread_id=request.thread_id, | ||
| success=True, | ||
| ) | ||
|
|
||
| # 3. Schedule the background task to run create_and_poll and send callback | ||
| background_tasks.add_task(process_run, request, client) | ||
|
|
||
| # 4. Return immediately so the client knows we've accepted the request | ||
| return initial_response | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,20 @@ | ||
| import logging | ||
| import os | ||
| from logging.handlers import RotatingFileHandler | ||
|
|
||
| LOG_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs") | ||
jerome-white marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| if not os.path.exists(LOG_DIR): | ||
| os.makedirs(LOG_DIR) | ||
|
Comment on lines
+7
to
+8
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. try:
if not os.path.exists(LOG_DIR):
os.makedirs(LOG_DIR)
except (OSError, IOError) as e:
import sys
sys.stderr.write(f'Failed to create log directory: {e}\n')
raiseAdd error handling around file operations to gracefully handle permission issues or disk space problems when creating log directory and files Talk to Kody by mentioning @kody Was this suggestion helpful? React with 👍 or 👎 to help Kody learn from this interaction. |
||
|
|
||
| LOG_FILE_PATH = os.path.join(LOG_DIR, "app.log") | ||
|
|
||
| LOGGING_LEVEL = logging.INFO | ||
| LOGGING_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" | ||
|
|
||
| logging.basicConfig(level=LOGGING_LEVEL, format=LOGGING_FORMAT) | ||
|
|
||
| file_handler = RotatingFileHandler(LOG_FILE_PATH, maxBytes=10485760, backupCount=5) | ||
| file_handler.setLevel(LOGGING_LEVEL) | ||
| file_handler.setFormatter(logging.Formatter(LOGGING_FORMAT)) | ||
|
|
||
| logging.getLogger("").addHandler(file_handler) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,4 +11,5 @@ | |
| UserUpdateMe, | ||
| NewPassword, | ||
| UpdatePassword, | ||
| ) | ||
| ) | ||
| from .thread import MessageRequest, AckPayload, CallbackPayload | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,28 @@ | ||
| from pydantic import BaseModel, ConfigDict | ||
| from typing import Optional | ||
|
|
||
|
|
||
| class MessageRequest(BaseModel): | ||
| model_config = ConfigDict(extra="allow") | ||
|
|
||
| question: str | ||
| assistant_id: str | ||
| callback_url: str | ||
| thread_id: Optional[str] = None | ||
| remove_citation: Optional[bool] = False | ||
|
|
||
|
|
||
| class AckPayload(BaseModel): | ||
| status: str | ||
| message: str | ||
| success: bool | ||
| thread_id: Optional[str] = None | ||
|
|
||
|
|
||
| class CallbackPayload(BaseModel): | ||
| model_config = ConfigDict(extra="allow") | ||
|
|
||
| status: str | ||
| message: str | ||
| thread_id: str | ||
| endpoint: str |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,132 @@ | ||
| import pytest | ||
| import openai | ||
| from unittest.mock import MagicMock, patch | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Standard library imports should come first. See PEP8 |
||
| from fastapi import FastAPI | ||
| from fastapi.testclient import TestClient | ||
| from app.api.routes.threads import router, process_run, validate_assistant_id | ||
| from models.thread import MessageRequest, AckPayload | ||
|
|
||
| # Wrap the router in a FastAPI app instance. | ||
| app = FastAPI() | ||
| app.include_router(router) | ||
| client = TestClient(app) | ||
|
|
||
|
|
||
| @patch("src.app.api.v1.threads.OpenAI") | ||
| def test_threads_endpoint(mock_openai): | ||
| """ | ||
| Test the /threads endpoint when creating a new thread. | ||
| The patched OpenAI client simulates: | ||
| - A successful assistant ID validation. | ||
| - New thread creation with a dummy thread id. | ||
| - No existing runs. | ||
| The expected response should have status "processing" and include a thread_id. | ||
| """ | ||
| # Create a dummy client to simulate OpenAI API behavior. | ||
| dummy_client = MagicMock() | ||
| # Simulate a valid assistant ID by ensuring retrieve doesn't raise an error. | ||
| dummy_client.beta.assistants.retrieve.return_value = None | ||
| # Simulate thread creation. | ||
| dummy_thread = MagicMock() | ||
| dummy_thread.id = "dummy_thread_id" | ||
| dummy_client.beta.threads.create.return_value = dummy_thread | ||
| # Simulate message creation. | ||
| dummy_client.beta.threads.messages.create.return_value = None | ||
| # Simulate that no active run exists. | ||
| dummy_client.beta.threads.runs.list.return_value = MagicMock(data=[]) | ||
|
|
||
| mock_openai.return_value = dummy_client | ||
|
|
||
| request_data = { | ||
| "question": "What is Glific?", | ||
| "assistant_id": "assistant_123", | ||
| "callback_url": "http://example.com/callback", | ||
| } | ||
| response = client.post("/threads", json=request_data) | ||
| assert response.status_code == 200 | ||
| response_json = response.json() | ||
| assert response_json["status"] == "processing" | ||
| assert response_json["message"] == "Run started" | ||
| assert response_json["success"] == True | ||
| assert response_json["thread_id"] == "dummy_thread_id" | ||
|
|
||
|
|
||
| @patch("src.app.api.v1.threads.OpenAI") | ||
| @pytest.mark.parametrize( | ||
| "remove_citation, expected_message", | ||
| [ | ||
| ( | ||
| True, | ||
| "Glific is an open-source, two-way messaging platform designed for nonprofits to scale their outreach via WhatsApp", | ||
| ), | ||
| ( | ||
| False, | ||
| "Glific is an open-source, two-way messaging platform designed for nonprofits to scale their outreach via WhatsApp【1:2†citation】", | ||
| ), | ||
| ], | ||
| ) | ||
| def test_process_run_variants(mock_openai, remove_citation, expected_message): | ||
| """ | ||
| Test process_run for both remove_citation variants: | ||
| - Mocks the OpenAI client to simulate a completed run. | ||
| - Verifies that send_callback is called with the expected message based on the remove_citation flag. | ||
| """ | ||
| # Setup the mock client. | ||
| mock_client = MagicMock() | ||
| mock_openai.return_value = mock_client | ||
|
|
||
| # Create the request with the variable remove_citation flag. | ||
| request = MessageRequest( | ||
| question="What is Glific?", | ||
| assistant_id="assistant_123", | ||
| callback_url="http://example.com/callback", | ||
| thread_id="thread_123", | ||
| remove_citation=remove_citation, | ||
| ) | ||
|
|
||
| # Simulate a completed run. | ||
| mock_run = MagicMock() | ||
| mock_run.status = "completed" | ||
| mock_client.beta.threads.runs.create_and_poll.return_value = mock_run | ||
|
|
||
| # Set up the dummy message based on the remove_citation flag. | ||
| base_message = "Glific is an open-source, two-way messaging platform designed for nonprofits to scale their outreach via WhatsApp" | ||
| citation_message = base_message if remove_citation else f"{base_message}【1:2†citation】" | ||
| dummy_message = MagicMock() | ||
| dummy_message.content = [MagicMock(text=MagicMock(value=citation_message))] | ||
| mock_client.beta.threads.messages.list.return_value.data = [dummy_message] | ||
|
|
||
| # Patch send_callback and invoke process_run. | ||
| with patch("src.app.api.v1.threads.send_callback") as mock_send_callback: | ||
| process_run(request, mock_client) | ||
| mock_send_callback.assert_called_once() | ||
| callback_url, payload = mock_send_callback.call_args[0] | ||
| print(payload) | ||
| assert callback_url == request.callback_url | ||
| assert payload.get("message", "") == expected_message | ||
| assert payload.get("status", "") == "success" | ||
| assert payload.get("thread_id", "") == "thread_123" | ||
|
|
||
|
|
||
| @patch("src.app.api.v1.threads.OpenAI") | ||
| def test_validate_assistant_id(mock_openai): | ||
| """ | ||
| Test validate_assistant_id: | ||
| - For a valid assistant ID, it should return None. | ||
| - For an invalid assistant ID, it should return an AckPayload with an error. | ||
| """ | ||
| mock_client = MagicMock() | ||
| mock_openai.return_value = mock_client | ||
|
|
||
| # Simulate a valid assistant ID. | ||
| result_valid = validate_assistant_id("valid_assistant_id", mock_client) | ||
| assert result_valid is None | ||
|
|
||
| # Simulate an invalid assistant ID by raising NotFoundError with required kwargs. | ||
| mock_client.beta.assistants.retrieve.side_effect = openai.NotFoundError( | ||
| "Not found", response=MagicMock(), body={"message": "Not found"} | ||
| ) | ||
| ack_payload = validate_assistant_id("invalid_assistant_id", mock_client) | ||
| assert isinstance(ack_payload, AckPayload) | ||
| assert ack_payload.status == "error" | ||
| assert "invalid_assistant_id" in ack_payload.message | ||
Uh oh!
There was an error while loading. Please reload this page.