diff --git a/backend/app/alembic/versions/8eefcfedc409_create_assistant_table.py b/backend/app/alembic/versions/8eefcfedc409_create_assistant_table.py new file mode 100644 index 000000000..6dad89272 --- /dev/null +++ b/backend/app/alembic/versions/8eefcfedc409_create_assistant_table.py @@ -0,0 +1,44 @@ +"""create assistant table + +Revision ID: 8757b005d681 +Revises: 8e7dc5eab0b0 +Create Date: 2025-06-16 13:40:10.447538 + +""" +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes + + +# revision identifiers, used by Alembic. +revision = "8757b005d681" +down_revision = "8e7dc5eab0b0" +branch_labels = None +depends_on = None + + +def upgrade(): + op.create_table( + "openai_assistant", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("assistant_id", sa.VARCHAR(length=255), nullable=False), + sa.Column("name", sa.VARCHAR(length=255), nullable=False), + sa.Column("max_num_results", sa.Integer, nullable=False), + sa.Column("model", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("instructions", sa.Text(), nullable=False), + sa.Column("temperature", sa.Float(), nullable=False), + sa.Column("vector_store_id", sa.VARCHAR(length=255), nullable=False), + sa.Column("organization_id", sa.Integer(), nullable=False), + sa.Column("project_id", sa.Integer(), nullable=False), + sa.Column("inserted_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["organization_id"], ["organization.id"], ondelete="CASCADE" + ), + sa.ForeignKeyConstraint(["project_id"], ["project.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + ) + + +def downgrade(): + op.drop_table("openai_assistant") diff --git a/backend/app/api/routes/responses.py b/backend/app/api/routes/responses.py index b6151c23a..2b66fc888 100644 --- a/backend/app/api/routes/responses.py +++ b/backend/app/api/routes/responses.py @@ -1,16 +1,21 @@ -from typing import Optional +from typing import Optional, Dict, Any +import logging import openai -from pydantic import BaseModel -from fastapi import APIRouter, Depends +from pydantic import BaseModel, Extra +from fastapi import APIRouter, Depends, BackgroundTasks, HTTPException from openai import OpenAI from sqlmodel import Session from app.api.deps import get_current_user_org, get_db +from app.api.routes.threads import send_callback from app.crud.credentials import get_provider_credential +from app.crud.assistants import get_assistant_by_id from app.models import UserOrganization from app.utils import APIResponse +logger = logging.getLogger(__name__) + router = APIRouter(tags=["responses"]) @@ -23,14 +28,25 @@ def handle_openai_error(e: openai.OpenAIError) -> str: class ResponsesAPIRequest(BaseModel): project_id: int + assistant_id: str + question: str + callback_url: Optional[str] = None + response_id: Optional[str] = None + + class Config: + extra = ( + Extra.allow + ) # This allows additional fields to be included in the request + +class ResponsesSyncAPIRequest(BaseModel): + project_id: int model: str instructions: str vector_store_ids: list[str] max_num_results: Optional[int] = 20 temperature: Optional[float] = 0.1 response_id: Optional[str] = None - question: str @@ -38,7 +54,6 @@ class Diagnostics(BaseModel): input_tokens: int output_tokens: int total_tokens: int - model: str @@ -49,11 +64,9 @@ class FileResultChunk(BaseModel): class _APIResponse(BaseModel): status: str - response_id: str message: str chunks: list[FileResultChunk] - diagnostics: Optional[Diagnostics] = None @@ -73,14 +86,165 @@ def get_file_search_results(response): return results +def get_additional_data(request: dict) -> dict: + """Extract additional data from request, excluding specific keys.""" + return { + k: v + for k, v in request.items() + if k + not in {"project_id", "assistant_id", "callback_url", "response_id", "question"} + } + + +def process_response( + request: ResponsesAPIRequest, client: OpenAI, assistant, organization_id: int +): + """Process a response and send callback with results.""" + logger.info( + f"[responses.process_response] Starting generating response for assistant_id={request.assistant_id}, project_id={request.project_id}, organization_id={organization_id}" + ) + try: + response = client.responses.create( + model=assistant.model, + previous_response_id=request.response_id, + instructions=assistant.instructions, + tools=[ + { + "type": "file_search", + "vector_store_ids": [assistant.vector_store_id], + "max_num_results": assistant.max_num_results, + } + ], + temperature=assistant.temperature, + input=[{"role": "user", "content": request.question}], + include=["file_search_call.results"], + ) + response_chunks = get_file_search_results(response) + logger.info( + f"[responses.process_response] Successfully generated response: response_id={response.id}, assistant={request.assistant_id}, project_id={request.project_id}, organization_id={organization_id}" + ) + + # Convert request to dict and include all fields + request_dict = request.model_dump() + callback_response = ResponsesAPIResponse.success_response( + data=_APIResponse( + status="success", + response_id=response.id, + message=response.output_text, + chunks=response_chunks, + diagnostics=Diagnostics( + input_tokens=response.usage.input_tokens, + output_tokens=response.usage.output_tokens, + total_tokens=response.usage.total_tokens, + model=response.model, + ), + **{ + k: v + for k, v in request_dict.items() + if k + not in { + "project_id", + "assistant_id", + "callback_url", + "response_id", + "question", + } + }, + ), + ) + except openai.OpenAIError as e: + error_message = handle_openai_error(e) + logger.error( + f"[responses.process_response] OpenAI API error during response processing: {error_message}, project_id={request.project_id}, organization_id={organization_id}" + ) + callback_response = ResponsesAPIResponse.failure_response(error=error_message) + + if request.callback_url: + logger.info( + f"[responses.process_response] Sending callback to URL: {request.callback_url}, assistant={request.assistant_id}, project_id={request.project_id}, organization_id={organization_id}" + ) + + send_callback(request.callback_url, callback_response.model_dump()) + logger.info( + f"[responses.process_response] Callback sent successfully, assistant={request.assistant_id}, project_id={request.project_id}, organization_id={organization_id}" + ) + + +@router.post("/responses", response_model=dict) +async def responses( + request: ResponsesAPIRequest, + background_tasks: BackgroundTasks, + _session: Session = Depends(get_db), + _current_user: UserOrganization = Depends(get_current_user_org), +): + """Asynchronous endpoint that processes requests in background.""" + logger.info( + f"[responses.responses] Processing response request for assistant_id={request.assistant_id}, project_id={request.project_id}, organization_id={_current_user.organization_id}" + ) + + # Get assistant details + assistant = get_assistant_by_id( + _session, request.assistant_id, _current_user.organization_id + ) + if not assistant: + logger.error( + f"[responses.responses] Assistant not found: assistant_id={request.assistant_id}, project_id={request.project_id}, organization_id={_current_user.organization_id}" + ) + raise HTTPException( + status_code=404, + detail="Assistant not found or not active", + ) + + credentials = get_provider_credential( + session=_session, + org_id=_current_user.organization_id, + provider="openai", + project_id=request.project_id, + ) + if not credentials or "api_key" not in credentials: + logger.error( + f"[responses.responses] OpenAI API key not configured for org_id={_current_user.organization_id}, project_id={request.project_id}, organization_id={_current_user.organization_id}" + ) + return { + "success": False, + "error": "OpenAI API key not configured for this organization.", + "data": None, + "metadata": None, + } + + client = OpenAI(api_key=credentials["api_key"]) + + # Send immediate response + initial_response = { + "success": True, + "data": { + "status": "processing", + "message": "Response creation started", + "success": True, + }, + "error": None, + "metadata": None, + } + + # Schedule background task + background_tasks.add_task( + process_response, request, client, assistant, _current_user.organization_id + ) + logger.info( + f"[responses.responses] Background task scheduled for response processing: assistant_id={request.assistant_id}, project_id={request.project_id}, organization_id={_current_user.organization_id}" + ) + + return initial_response + + @router.post("/responses/sync", response_model=ResponsesAPIResponse) async def responses_sync( - request: ResponsesAPIRequest, + request: ResponsesSyncAPIRequest, _session: Session = Depends(get_db), _current_user: UserOrganization = Depends(get_current_user_org), ): """ - Temp synchronous endpoint for benchmarking OpenAI responses API + Synchronous endpoint for benchmarking OpenAI responses API """ credentials = get_provider_credential( session=_session, diff --git a/backend/app/api/routes/threads.py b/backend/app/api/routes/threads.py index 4b621a458..66bc35ee2 100644 --- a/backend/app/api/routes/threads.py +++ b/backend/app/api/routes/threads.py @@ -13,7 +13,7 @@ from app.crud import upsert_thread_result, get_thread_result from app.utils import APIResponse from app.crud.credentials import get_provider_credential -from app.core.security import decrypt_credentials +from app.core.util import configure_langfuse, configure_openai logger = logging.getLogger(__name__) router = APIRouter(tags=["threads"]) @@ -59,7 +59,6 @@ def validate_thread(client: OpenAI, thread_id: str) -> tuple[bool, str]: return False, f"Invalid thread ID provided {thread_id}" -@observe(capture_input=False) def setup_thread(client: OpenAI, request: dict) -> tuple[bool, str]: """Set up thread and add message, either creating new or using existing.""" thread_id = request.get("thread_id") @@ -78,9 +77,6 @@ def setup_thread(client: OpenAI, request: dict) -> tuple[bool, str]: thread_id=thread.id, role="user", content=request["question"] ) request["thread_id"] = thread.id - langfuse_context.update_current_trace( - session_id=thread.id, name="New Thread ID created", output=thread.id - ) return True, None except openai.OpenAIError as e: return False, handle_openai_error(e) @@ -135,8 +131,8 @@ def extract_response_from_thread( @observe(as_type="generation") -def process_run(request: dict, client: OpenAI): - """Process a run and send callback with results.""" +def process_run_core(request: dict, client: OpenAI) -> tuple[dict, str]: + """Core function to process a run and return the response and message.""" try: run = client.beta.threads.runs.create_and_poll( thread_id=request["thread_id"], @@ -166,18 +162,29 @@ def process_run(request: dict, client: OpenAI): langfuse_context.update_current_trace( output=message, name="Thread Run Completed" ) + diagnostics = { + "input_tokens": run.usage.prompt_tokens, + "output_tokens": run.usage.completion_tokens, + "total_tokens": run.usage.total_tokens, + "model": run.model, + } + request = {**request, **{"diagnostics": diagnostics}} - callback_response = create_success_response(request, message) + return create_success_response(request, message).model_dump(), None else: - callback_response = APIResponse.failure_response( - error=f"Run failed with status: {run.status}" - ) - - send_callback(request["callback_url"], callback_response.model_dump()) + error_msg = f"Run failed with status: {run.status}" + return APIResponse.failure_response(error=error_msg).model_dump(), error_msg except openai.OpenAIError as e: - callback_response = APIResponse.failure_response(error=handle_openai_error(e)) - send_callback(request["callback_url"], callback_response.model_dump()) + error_msg = handle_openai_error(e) + return APIResponse.failure_response(error=error_msg).model_dump(), error_msg + + +@observe(as_type="generation") +def process_run(request: dict, client: OpenAI): + """Process a run and send callback with results.""" + response, _ = process_run_core(request, client) + send_callback(request["callback_url"], response) def poll_run_and_prepare_response(request: dict, client: OpenAI, db: Session): @@ -228,10 +235,11 @@ async def threads( provider="openai", project_id=request.get("project_id"), ) - if not credentials or "api_key" not in credentials: - raise HTTPException(404, "OpenAI API key not configured for this organization.") - - client = OpenAI(api_key=credentials["api_key"]) + client, success = configure_openai(credentials) + if not success: + return APIResponse.failure_response( + error="OpenAI API key not configured for this organization." + ) langfuse_credentials = get_provider_credential( session=_session, @@ -242,11 +250,13 @@ async def threads( if not langfuse_credentials: raise HTTPException(404, "LANGFUSE keys not configured for this organization.") - langfuse_context.configure( - secret_key=langfuse_credentials["secret_key"], - public_key=langfuse_credentials["public_key"], - host=langfuse_credentials["host"], - ) + # Configure Langfuse + _, success = configure_langfuse(langfuse_credentials) + if not success: + return APIResponse.failure_response( + error="Failed to configure Langfuse client." + ) + # Validate thread is_valid, error_message = validate_thread(client, request.get("thread_id")) if not is_valid: @@ -279,19 +289,38 @@ async def threads_sync( _current_user: UserOrganization = Depends(get_current_user_org), ): """Synchronous endpoint that processes requests immediately.""" - credentials = get_provider_credential( session=_session, org_id=_current_user.organization_id, provider="openai", project_id=request.get("project_id"), ) - if not credentials or "api_key" not in credentials: - raise HTTPException( - 404, error="OpenAI API key not configured for this organization." + + # Configure OpenAI client + client, success = configure_openai(credentials) + if not success: + return APIResponse.failure_response( + error="OpenAI API key not configured for this organization." ) - client = OpenAI(api_key=credentials["api_key"]) + # Get Langfuse credentials + langfuse_credentials = get_provider_credential( + session=_session, + org_id=_current_user.organization_id, + provider="langfuse", + project_id=request.get("project_id"), + ) + if not langfuse_credentials: + return APIResponse.failure_response( + error="LANGFUSE keys not configured for this organization." + ) + + # Configure Langfuse + _, success = configure_langfuse(langfuse_credentials) + if not success: + return APIResponse.failure_response( + error="Failed to configure Langfuse client." + ) # Validate thread is_valid, error_message = validate_thread(client, request.get("thread_id")) @@ -303,36 +332,10 @@ async def threads_sync( raise Exception(error_message) try: - # Process run - run = client.beta.threads.runs.create_and_poll( - thread_id=request["thread_id"], - assistant_id=request["assistant_id"], - ) - - if run.status == "completed": - messages = client.beta.threads.messages.list(thread_id=request["thread_id"]) - latest_message = messages.data[0] - message_content = latest_message.content[0].text.value - message = process_message_content( - message_content, request.get("remove_citation", False) - ) - - diagnostics = { - "input_tokens": run.usage.prompt_tokens, - "output_tokens": run.usage.completion_tokens, - "total_tokens": run.usage.total_tokens, - "model": run.model, - } - request = {**request, **{"diagnostics": diagnostics}} - - return create_success_response(request, message) - else: - return APIResponse.failure_response( - error=f"Run failed with status: {run.status}" - ) - - except openai.OpenAIError as e: - raise Exception(error=handle_openai_error(e)) + response, error_message = process_run_core(request, client) + return response + finally: + langfuse_context.flush() @router.post("/threads/start") @@ -346,7 +349,19 @@ async def start_thread( Create a new OpenAI thread for the given question and start polling in the background. """ prompt = request["question"] - client = OpenAI(api_key=settings.OPENAI_API_KEY) + credentials = get_provider_credential( + session=db, + org_id=_current_user.organization_id, + provider="openai", + project_id=request.get("project_id"), + ) + + # Configure OpenAI client + client, success = configure_openai(credentials) + if not success: + return APIResponse.failure_response( + error="OpenAI API key not configured for this organization." + ) is_success, error = setup_thread(client, request) if not is_success: diff --git a/backend/app/core/util.py b/backend/app/core/util.py index 6f945b9db..c3cd4c934 100644 --- a/backend/app/core/util.py +++ b/backend/app/core/util.py @@ -5,6 +5,9 @@ from fastapi import HTTPException from requests import Session, RequestException from pydantic import BaseModel, HttpUrl +from langfuse import Langfuse +from langfuse.decorators import langfuse_context +from openai import OpenAI def now(): @@ -32,3 +35,59 @@ def post_callback(url: HttpUrl, payload: BaseModel): errno += 1 return not errno + + +def configure_langfuse(credentials: dict) -> tuple[Langfuse, bool]: + """ + Configure Langfuse client and context with the provided credentials. + + Args: + credentials: Dictionary containing Langfuse credentials (public_key, secret_key, host) + + Returns: + Tuple of (Langfuse client instance, success boolean) + """ + if not credentials: + return None, False + + try: + # Configure Langfuse client + langfuse = Langfuse( + public_key=credentials["public_key"], + secret_key=credentials["secret_key"], + host=credentials.get("host", "https://cloud.langfuse.com"), + ) + + # Configure Langfuse context + langfuse_context.configure( + secret_key=credentials["secret_key"], + public_key=credentials["public_key"], + host=credentials.get("host", "https://cloud.langfuse.com"), + ) + + return langfuse, True + except Exception as e: + warnings.warn(f"Failed to configure Langfuse: {str(e)}") + return None, False + + +def configure_openai(credentials: dict) -> tuple[OpenAI, bool]: + """ + Configure OpenAI client with the provided credentials. + + Args: + credentials: Dictionary containing OpenAI credentials (api_key) + + Returns: + Tuple of (OpenAI client instance, success boolean) + """ + if not credentials or "api_key" not in credentials: + return None, False + + try: + # Configure OpenAI client + client = OpenAI(api_key=credentials["api_key"]) + return client, True + except Exception as e: + warnings.warn(f"Failed to configure OpenAI client: {str(e)}") + return None, False diff --git a/backend/app/crud/__init__.py b/backend/app/crud/__init__.py index 671cf267b..6bf4d5da5 100644 --- a/backend/app/crud/__init__.py +++ b/backend/app/crud/__init__.py @@ -40,3 +40,5 @@ ) from .thread_results import upsert_thread_result, get_thread_result + +from .assistants import get_assistant_by_id diff --git a/backend/app/crud/assistants.py b/backend/app/crud/assistants.py new file mode 100644 index 000000000..5025c617b --- /dev/null +++ b/backend/app/crud/assistants.py @@ -0,0 +1,18 @@ +from typing import Optional, List, Tuple +from sqlmodel import Session, select, and_ + +from app.core.util import now +from app.models import Assistant + + +def get_assistant_by_id( + session: Session, assistant_id: str, organization_id: int +) -> Optional[Assistant]: + """Get an assistant by its OpenAI assistant ID and organization ID.""" + statement = select(Assistant).where( + and_( + Assistant.assistant_id == assistant_id, + Assistant.organization_id == organization_id, + ) + ) + return session.exec(statement).first() diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index f88e019dc..046936371 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -53,3 +53,5 @@ ) from .threads import OpenAI_Thread, OpenAIThreadBase, OpenAIThreadCreate + +from .assistants import Assistant, AssistantBase diff --git a/backend/app/models/assistants.py b/backend/app/models/assistants.py new file mode 100644 index 000000000..4163297df --- /dev/null +++ b/backend/app/models/assistants.py @@ -0,0 +1,29 @@ +from datetime import datetime +from typing import Optional, List +from sqlmodel import Field, Relationship, SQLModel + +from app.core.util import now + + +class AssistantBase(SQLModel): + assistant_id: str = Field(index=True, unique=True) + name: str + instructions: str + model: str + vector_store_id: str + temperature: float = 0.1 + max_num_results: int = 20 + project_id: int = Field(foreign_key="project.id") + organization_id: int = Field(foreign_key="organization.id") + + +class Assistant(AssistantBase, table=True): + __tablename__ = "openai_assistant" + + id: int = Field(default=None, primary_key=True) + inserted_at: datetime = Field(default_factory=now, nullable=False) + updated_at: datetime = Field(default_factory=now, nullable=False) + + # Relationships + project: "Project" = Relationship(back_populates="assistants") + organization: "Organization" = Relationship(back_populates="assistants") diff --git a/backend/app/models/organization.py b/backend/app/models/organization.py index 41862f0ab..7e3e2cd8b 100644 --- a/backend/app/models/organization.py +++ b/backend/app/models/organization.py @@ -9,6 +9,7 @@ from .credentials import Credential from .project import Project from .api_key import APIKey + from .assistants import Assistant # Shared properties for an Organization @@ -44,6 +45,9 @@ class Organization(OrganizationBase, table=True): project: list["Project"] = Relationship( back_populates="organization", sa_relationship_kwargs={"cascade": "all, delete"} ) + assistants: list["Assistant"] = Relationship( + back_populates="organization", sa_relationship_kwargs={"cascade": "all, delete"} + ) # Properties to return via API diff --git a/backend/app/models/project.py b/backend/app/models/project.py index 0d95ab3c4..8a56ec81c 100644 --- a/backend/app/models/project.py +++ b/backend/app/models/project.py @@ -37,6 +37,9 @@ class Project(ProjectBase, table=True): creds: list["Credential"] = Relationship( back_populates="project", sa_relationship_kwargs={"cascade": "all, delete"} ) + assistants: list["Assistant"] = Relationship( + back_populates="project", sa_relationship_kwargs={"cascade": "all, delete"} + ) api_keys: list["APIKey"] = Relationship( back_populates="project", sa_relationship_kwargs={"cascade": "all, delete"} ) diff --git a/backend/app/seed_data/seed_data.json b/backend/app/seed_data/seed_data.json index 2be2ca23a..8fb40af52 100644 --- a/backend/app/seed_data/seed_data.json +++ b/backend/app/seed_data/seed_data.json @@ -52,5 +52,28 @@ "is_deleted": false, "deleted_at": null } + ], + "credentials": [ + { + "is_active": true, + "provider": "openai", + "credential": "{\"openai\": {\"api_key\": \"sk-proj-YxK21qI3i5SCxN\"}}", + "project_name": "Glific", + "organization_name": "Project Tech4dev", + "deleted_at": null + } + ], + "assistants": [ + { + "assistant_id": "assistant_123", + "name": "Test Assistant", + "instructions": "Test instructions", + "model": "gpt-4o", + "vector_store_id": "vs_123", + "temperature": 0.1, + "max_num_results": 20, + "project_name": "Glific", + "organization_name": "Project Tech4dev" + } ] } diff --git a/backend/app/seed_data/seed_data.py b/backend/app/seed_data/seed_data.py index e42aadbda..059c07943 100644 --- a/backend/app/seed_data/seed_data.py +++ b/backend/app/seed_data/seed_data.py @@ -9,7 +9,7 @@ from app.core.db import engine from app.core.security import encrypt_api_key, get_password_hash -from app.models import APIKey, Organization, Project, User +from app.models import APIKey, Organization, Project, User, Credential, Assistant # Pydantic models for data validation @@ -43,6 +43,27 @@ class APIKeyData(BaseModel): created_at: Optional[str] = None +class CredentialData(BaseModel): + is_active: bool + provider: str + credential: str + organization_name: str + project_name: str + deleted_at: Optional[str] = None + + +class AssistantData(BaseModel): + assistant_id: str + name: str + instructions: str + model: str + vector_store_id: str + temperature: float + max_num_results: int + project_name: str + organization_name: str + + def load_seed_data() -> dict: """Load seed data from JSON file.""" json_path = Path(__file__).parent / "seed_data.json" @@ -168,13 +189,101 @@ def create_api_key(session: Session, api_key_data_raw: dict) -> APIKey: raise +def create_credential(session: Session, credential_data_raw: dict) -> Credential: + """Create a credential from data.""" + try: + credential_data = CredentialData.model_validate(credential_data_raw) + logging.info(f"Creating credential for provider: {credential_data.provider}") + + # Query organization ID by name + organization = session.exec( + select(Organization).where( + Organization.name == credential_data.organization_name + ) + ).first() + if not organization: + raise ValueError( + f"Organization '{credential_data.organization_name}' not found" + ) + + # Query organization ID by name + project = session.exec( + select(Project).where(Project.name == credential_data.project_name) + ).first() + if not project: + raise ValueError(f"Project '{credential_data.project_name}' not found") + + # Encrypt the credential data + encrypted_credential = encrypt_api_key(credential_data.credential) + + credential = Credential( + is_active=credential_data.is_active, + provider=credential_data.provider, + credential=encrypted_credential, + organization_id=organization.id, + project_id=project.id, + deleted_at=credential_data.deleted_at, + ) + session.add(credential) + session.flush() # Ensure ID is assigned + return credential + except Exception as e: + logging.error(f"Error creating credential: {e}") + raise + + +def create_assistant(session: Session, assistant_data_raw: dict) -> Assistant: + """Create an assistant from data.""" + try: + assistant_data = AssistantData.model_validate(assistant_data_raw) + logging.info(f"Creating assistant: {assistant_data.name}") + + # Query organization ID by name + organization = session.exec( + select(Organization).where( + Organization.name == assistant_data.organization_name + ) + ).first() + if not organization: + raise ValueError( + f"Organization '{assistant_data.organization_name}' not found" + ) + + # Query project ID by name + project = session.exec( + select(Project).where(Project.name == assistant_data.project_name) + ).first() + if not project: + raise ValueError(f"Project '{assistant_data.project_name}' not found") + + assistant = Assistant( + assistant_id=assistant_data.assistant_id, + name=assistant_data.name, + instructions=assistant_data.instructions, + model=assistant_data.model, + vector_store_id=assistant_data.vector_store_id, + temperature=assistant_data.temperature, + max_num_results=assistant_data.max_num_results, + organization_id=organization.id, + project_id=project.id, + ) + session.add(assistant) + session.flush() # Ensure ID is assigned + return assistant + except Exception as e: + logging.error(f"Error creating assistant: {e}") + raise + + def clear_database(session: Session) -> None: """Clear all seeded data from the database.""" logging.info("Clearing existing data...") + session.exec(delete(Assistant)) session.exec(delete(APIKey)) session.exec(delete(Project)) session.exec(delete(Organization)) session.exec(delete(User)) + session.exec(delete(Credential)) session.commit() logging.info("Existing data cleared.") @@ -220,6 +329,22 @@ def seed_database(session: Session) -> None: api_keys.append(api_key) logging.info(f"Created API key (ID: {api_key.id})") + # Create credentials + credentials = [] + for credential_data in seed_data["credentials"]: + credential = create_credential(session, credential_data) + credentials.append(credential) + logging.info( + f"Created credential for provider: {credential.provider} (ID: {credential.id})" + ) + + # Create assistants + assistants = [] + for assistant_data in seed_data.get("assistants", []): + assistant = create_assistant(session, assistant_data) + assistants.append(assistant) + logging.info(f"Created assistant: {assistant.name} (ID: {assistant.id})") + logging.info("Database seeding completed successfully!") session.commit() except Exception as e: diff --git a/backend/app/tests/api/routes/test_responses.py b/backend/app/tests/api/routes/test_responses.py new file mode 100644 index 000000000..40de2d81b --- /dev/null +++ b/backend/app/tests/api/routes/test_responses.py @@ -0,0 +1,72 @@ +from unittest.mock import MagicMock, patch +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient +from sqlmodel import select + +from app.api.routes.responses import router +from app.models import Project +from app.seed_data.seed_data import seed_database + +# Wrap the router in a FastAPI app instance +app = FastAPI() +app.include_router(router) +client = TestClient(app) + + +@pytest.fixture(scope="function", autouse=True) +def load_seed_data(db): + """Load seed data before each test.""" + seed_database(db) + yield + # Cleanup is handled by the db fixture in conftest.py + + +@patch("app.api.routes.responses.OpenAI") +@patch("app.api.routes.responses.get_provider_credential") +def test_responses_endpoint_success( + mock_get_credential, + mock_openai, + db, +): + """Test the /responses endpoint for successful response creation.""" + # Setup mock credentials + mock_get_credential.return_value = {"api_key": "test_api_key"} + + # Setup mock OpenAI client + mock_client = MagicMock() + mock_openai.return_value = mock_client + + # Setup the mock response object with real values for all used fields + mock_response = MagicMock() + mock_response.id = "mock_response_id" + mock_response.output_text = "Test output" + mock_response.model = "gpt-4o" + mock_response.usage.input_tokens = 10 + mock_response.usage.output_tokens = 5 + mock_response.usage.total_tokens = 15 + mock_response.output = [] + mock_client.responses.create.return_value = mock_response + + # Get the Glific project ID (the assistant is created for this project) + glific_project = db.exec(select(Project).where(Project.name == "Glific")).first() + if not glific_project: + pytest.skip("Glific project not found in the database") + + # Use the original API key from seed data (not the encrypted one) + original_api_key = "ApiKey No3x47A5qoIGhm0kVKjQ77dhCqEdWRIQZlEPzzzh7i8" + + headers = {"X-API-KEY": original_api_key} + request_data = { + "project_id": glific_project.id, + "assistant_id": "assistant_123", + "question": "What is Glific?", + "callback_url": "http://example.com/callback", + } + + response = client.post("/responses", json=request_data, headers=headers) + assert response.status_code == 200 + response_json = response.json() + assert response_json["success"] is True + assert response_json["data"]["status"] == "processing" + assert response_json["data"]["message"] == "Response creation started" diff --git a/backend/app/tests/conftest.py b/backend/app/tests/conftest.py index 8ea01daca..a68c3eca0 100644 --- a/backend/app/tests/conftest.py +++ b/backend/app/tests/conftest.py @@ -9,6 +9,7 @@ from app.main import app from app.models import ( APIKey, + Assistant, Organization, Project, ProjectUser, @@ -27,6 +28,7 @@ def db() -> Generator[Session, None, None]: yield session # Delete data in reverse dependency order session.execute(delete(ProjectUser)) # Many-to-many relationship + session.execute(delete(Assistant)) session.execute(delete(Credential)) session.execute(delete(Project)) session.execute(delete(Organization))