Skip to content
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ DOCKER_IMAGE_FRONTEND=frontend
AWS_ACCESS_KEY_ID=
AWS_SECRET_ACCESS_KEY=
AWS_DEFAULT_REGION=ap-south-1
AWS_S3_BUCKET_PREFIX = "bucket-prefix-name"
AWS_S3_BUCKET_PREFIX="bucket-prefix-name"

# OpenAI

Expand Down
50 changes: 27 additions & 23 deletions backend/app/api/routes/collections.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,28 @@
import inspect
import logging
import time
import warnings
from uuid import UUID, uuid4
from typing import Any, List, Optional
from dataclasses import dataclass, field, fields, asdict, replace

from openai import OpenAI, OpenAIError
from openai import OpenAIError, OpenAI
from fastapi import APIRouter, HTTPException, BackgroundTasks, Query
from fastapi import Path as FastPath
from pydantic import BaseModel, Field, HttpUrl
from sqlalchemy.exc import NoResultFound, MultipleResultsFound, SQLAlchemyError
from sqlalchemy.exc import SQLAlchemyError

from app.api.deps import CurrentUser, SessionDep, CurrentUserOrgProject
from app.core.cloud import AmazonCloudStorage
from app.core.config import settings
from app.core.util import now, raise_from_unknown, post_callback
from app.crud import DocumentCrud, CollectionCrud, DocumentCollectionCrud
from app.core.util import now, post_callback
from app.crud import (
DocumentCrud,
CollectionCrud,
DocumentCollectionCrud,
)
from app.crud.rag import OpenAIVectorStoreCrud, OpenAIAssistantCrud
Comment thread
nishika26 marked this conversation as resolved.
from app.models import Collection, Document
from app.models.collection import CollectionStatus
from app.utils import APIResponse, load_description
from app.utils import APIResponse, load_description, get_openai_client

logger = logging.getLogger(__name__)
router = APIRouter(prefix="/collections", tags=["collections"])
Expand Down Expand Up @@ -180,12 +182,13 @@ def _backout(crud: OpenAIAssistantCrud, assistant_id: str):

def do_create_collection(
session: SessionDep,
current_user: CurrentUser,
current_user: CurrentUserOrgProject,
request: CreationRequest,
payload: ResponsePayload,
client: OpenAI,
):
start_time = time.time()
client = OpenAI(api_key=settings.OPENAI_API_KEY)

callback = (
SilentCallback(payload)
if request.callback_url is None
Expand Down Expand Up @@ -226,7 +229,7 @@ def do_create_collection(
collection_crud._update(collection)

elapsed = time.time() - start_time
logging.info(
logger.info(
f"[do_create_collection] Collection created: {collection.id} | Time: {elapsed:.2f}s | "
f"Files: {len(flat_docs)} | Sizes: {file_sizes_kb} KB | Types: {list(file_exts)}"
)
Expand Down Expand Up @@ -261,6 +264,10 @@ def create_collection(
request: CreationRequest,
background_tasks: BackgroundTasks,
):
client = get_openai_client(
session, current_user.organization_id, current_user.project_id
)

this = inspect.currentframe()
route = router.url_path_for(this.f_code.co_name)
payload = ResponsePayload("processing", route)
Expand All @@ -278,11 +285,7 @@ def create_collection(

# 2. Launch background task
background_tasks.add_task(
do_create_collection,
session,
current_user,
request,
payload,
do_create_collection, session, current_user, request, payload, client
)

logger.info(
Expand All @@ -294,9 +297,10 @@ def create_collection(

def do_delete_collection(
session: SessionDep,
current_user: CurrentUser,
current_user: CurrentUserOrgProject,
request: DeletionRequest,
payload: ResponsePayload,
client: OpenAI,
):
if request.callback_url is None:
callback = SilentCallback(payload)
Expand All @@ -306,7 +310,7 @@ def do_delete_collection(
collection_crud = CollectionCrud(session, current_user.id)
try:
collection = collection_crud.read_one(request.collection_id)
assistant = OpenAIAssistantCrud()
assistant = OpenAIAssistantCrud(client)
data = collection_crud.delete(collection, assistant)
logger.info(
f"[do_delete_collection] Collection deleted successfully | {{'collection_id': '{collection.id}'}}"
Expand All @@ -332,20 +336,20 @@ def do_delete_collection(
)
def delete_collection(
session: SessionDep,
current_user: CurrentUser,
current_user: CurrentUserOrgProject,
request: DeletionRequest,
background_tasks: BackgroundTasks,
):
client = get_openai_client(
session, current_user.organization_id, current_user.project_id
)

this = inspect.currentframe()
route = router.url_path_for(this.f_code.co_name)
payload = ResponsePayload("processing", route)

background_tasks.add_task(
do_delete_collection,
session,
current_user,
request,
payload,
do_delete_collection, session, current_user, request, payload, client
)

logger.info(
Expand Down
20 changes: 14 additions & 6 deletions backend/app/api/routes/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

from app.crud import DocumentCrud, CollectionCrud
from app.models import Document
from app.utils import APIResponse, load_description
from app.api.deps import CurrentUser, SessionDep
from app.utils import APIResponse, load_description, get_openai_client
from app.api.deps import CurrentUser, SessionDep, CurrentUserOrgProject
from app.core.cloud import AmazonCloudStorage
from app.crud.rag import OpenAIAssistantCrud

Expand Down Expand Up @@ -65,10 +65,14 @@ def upload_doc(
)
def remove_doc(
session: SessionDep,
current_user: CurrentUser,
current_user: CurrentUserOrgProject,
doc_id: UUID = FastPath(description="Document to delete"),
):
a_crud = OpenAIAssistantCrud()
client = get_openai_client(
session, current_user.organization_id, current_user.project_id
)

a_crud = OpenAIAssistantCrud(client)
d_crud = DocumentCrud(session, current_user.id)
c_crud = CollectionCrud(session, current_user.id)

Expand All @@ -84,10 +88,14 @@ def remove_doc(
)
def permanent_delete_doc(
session: SessionDep,
current_user: CurrentUser,
current_user: CurrentUserOrgProject,
doc_id: UUID = FastPath(description="Document to permanently delete"),
):
a_crud = OpenAIAssistantCrud()
client = get_openai_client(
session, current_user.organization_id, current_user.project_id
)
Comment thread
nishika26 marked this conversation as resolved.

a_crud = OpenAIAssistantCrud(client)
d_crud = DocumentCrud(session, current_user.id)
c_crud = CollectionCrud(session, current_user.id)
storage = AmazonCloudStorage(current_user)
Expand Down
8 changes: 6 additions & 2 deletions backend/app/crud/rag/open_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,12 @@ def clean(self, resource):


class OpenAICrud:
def __init__(self, client=None):
self.client = client or OpenAI(api_key=settings.OPENAI_API_KEY)
def __init__(self, client):
Comment thread
nishika26 marked this conversation as resolved.
if client is None:
logger.error("[OpenAICrud] OpenAI client is not configured")
raise ValueError("OpenAI client is not configured")
Comment thread
avirajsingh7 marked this conversation as resolved.

self.client = client


class OpenAIVectorStoreCrud(OpenAICrud):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,9 @@
from sqlmodel import Session
from app.core.config import settings
from app.models import Collection
from app.main import app
from app.tests.utils.utils import get_user_from_api_key
from app.models.collection import CollectionStatus

client = TestClient(app)


def create_collection(
db,
Expand Down
25 changes: 15 additions & 10 deletions backend/app/tests/api/routes/collections/test_create_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,16 @@
from uuid import UUID
import io

import openai_responses
from sqlmodel import Session
from fastapi.testclient import TestClient
from unittest.mock import patch

from app.core.config import settings
from app.tests.utils.document import DocumentStore
from app.tests.utils.utils import openai_credentials, get_user_from_api_key
from app.main import app
from app.tests.utils.utils import get_user_from_api_key
from app.crud.collection import CollectionCrud
from app.models.collection import CollectionStatus

client = TestClient(app)
from app.tests.utils.collections_openai_mock import get_mock_openai_client


@pytest.fixture(autouse=True)
Expand All @@ -31,7 +29,7 @@ def stream(self, file_obj):
return fake_file

def get_file_size_kb(self, url: str) -> float:
return 1.0 # Simulate 1KB files
return 1.0

class FakeS3Client:
def head_object(self, Bucket, Key):
Expand All @@ -41,13 +39,16 @@ def head_object(self, Bucket, Key):
monkeypatch.setattr("boto3.client", lambda service: FakeS3Client())


@pytest.mark.usefixtures("openai_credentials")
class TestCollectionRouteCreate:
_n_documents = 5

@openai_responses.mock()
@patch("app.api.routes.collections.get_openai_client")
def test_create_collection_success(
self, client: TestClient, db: Session, user_api_key_header
self,
mock_get_openai_client,
client: TestClient,
Comment thread
nishika26 marked this conversation as resolved.
db: Session,
user_api_key_header,
):
store = DocumentStore(db)
documents = store.fill(self._n_documents)
Expand All @@ -60,8 +61,12 @@ def test_create_collection_success(
"instructions": "Test collection assistant.",
"temperature": 0.1,
}

headers = user_api_key_header

mock_openai_client = get_mock_openai_client()
mock_get_openai_client.return_value = mock_openai_client

response = client.post(
f"{settings.API_V1_STR}/collections/create", json=body, headers=headers
)
Expand All @@ -73,8 +78,8 @@ def test_create_collection_success(
assert metadata["status"] == CollectionStatus.processing.value
assert UUID(metadata["key"])

# Confirm collection metadata in DB
collection_id = UUID(metadata["key"])

user = get_user_from_api_key(db, headers)
collection = CollectionCrud(db, user.user_id).read_one(collection_id)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,7 @@ def test_info_reflects_database(
assert source == target.data

def test_cannot_info_unknown_document(
self,
db: Session,
route: Route,
crawler: Route,
self, db: Session, route: Route, crawler: Route
):
DocumentStore.clear(db)
maker = DocumentMaker(db)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@
from urllib.parse import urlparse

import pytest
from unittest.mock import patch
from botocore.exceptions import ClientError
from moto import mock_aws
from sqlmodel import Session, select

from openai import OpenAI
import openai_responses
from openai_responses import OpenAIMock

from app.core.cloud import AmazonCloudStorageClient
from app.core.config import settings
Expand All @@ -19,7 +22,6 @@
WebCrawler,
crawler,
)
from app.tests.utils.utils import openai_credentials


@pytest.fixture
Expand All @@ -36,16 +38,23 @@ def aws_credentials():
os.environ["AWS_DEFAULT_REGION"] = settings.AWS_DEFAULT_REGION


@pytest.mark.usefixtures("openai_credentials", "aws_credentials")
@pytest.mark.usefixtures("aws_credentials")
@mock_aws
class TestDocumentRoutePermanentRemove:
@openai_responses.mock()
@patch("app.api.routes.documents.get_openai_client")
def test_permanent_delete_document_from_s3(
self,
mock_get_openai_client,
db: Session,
route: Route,
crawler: WebCrawler,
):
openai_mock = OpenAIMock()
with openai_mock.router:
client = OpenAI(api_key="sk-test-key")
mock_get_openai_client.return_value = client

# Setup AWS
aws = AmazonCloudStorageClient()
aws.create()
Expand Down
Loading