Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 47 additions & 26 deletions backend/app/api/routes/collections.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import inspect
import logging
import time
import asyncio
import warnings
from uuid import UUID, uuid4
from typing import Any, List, Optional
Expand Down Expand Up @@ -167,49 +169,56 @@
)


def do_create_collection(
async def do_create_collection(
session: SessionDep,
current_user: CurrentUser,
request: CreationRequest,
payload: ResponsePayload,
):
start_time = time.time()
client = OpenAI(api_key=settings.OPENAI_API_KEY)
if request.callback_url is None:
callback = SilentCallback(payload)
else:
callback = WebHookCallback(request.callback_url, payload)

#
# Create the assistant and vector store
#
callback = (
SilentCallback(payload)
if request.callback_url is None
else WebHookCallback(request.callback_url, payload)
)

vector_store_crud = OpenAIVectorStoreCrud(client)
try:
vector_store = vector_store_crud.create()
except OpenAIError as err:
logging.error(f"OpenAI vector store creation failed: {err}")

Check warning on line 191 in backend/app/api/routes/collections.py

View check run for this annotation

Codecov / codecov/patch

backend/app/api/routes/collections.py#L191

Added line #L191 was not covered by tests
callback.fail(str(err))
return

storage = AmazonCloudStorage(current_user)
document_crud = DocumentCrud(session, current_user.id)
assistant_crud = OpenAIAssistantCrud(client)

docs = request(document_crud)
docs = list(request(document_crud))
doc_count = len(docs)
flat_docs = [doc for sublist in docs for doc in sublist]
file_exts = list(
{doc.fname.split(".")[-1] for doc in flat_docs if "." in doc.fname}
)

file_sizes_kb = []
for doc in flat_docs:
size_kb = storage.get_file_size_kb(doc.object_store_url)
file_sizes_kb.append(size_kb)

kwargs = dict(request.extract_super_type(AssistantOptions))
try:
updates = vector_store_crud.update(vector_store.id, storage, docs)
documents = list(updates)
assistant = assistant_crud.create(vector_store.id, **kwargs)
except Exception as err: # blanket to handle SQL and OpenAI errors
except Exception as err:

Check warning on line 216 in backend/app/api/routes/collections.py

View check run for this annotation

Codecov / codecov/patch

backend/app/api/routes/collections.py#L216

Added line #L216 was not covered by tests
logging.error(f"File Search setup error: {err} ({type(err).__name__})")
vector_store_crud.delete(vector_store.id)
callback.fail(str(err))
return

#
# Store the results
#

collection_crud = CollectionCrud(session, current_user.id)
collection = Collection(
id=UUID(payload.key),
Expand All @@ -219,13 +228,16 @@
try:
collection_crud.create(collection, documents)
except SQLAlchemyError as err:
logging.error(f"DB insert failed for collection: {err}")

Check warning on line 231 in backend/app/api/routes/collections.py

View check run for this annotation

Codecov / codecov/patch

backend/app/api/routes/collections.py#L231

Added line #L231 was not covered by tests
_backout(assistant_crud, assistant.id)
callback.fail(str(err))
return

#
# Send back successful response
#
elapsed = time.time() - start_time
logging.info(
f"Collection created: {collection.id} | "
f"Time: {elapsed}s | Files: {doc_count} |Sizes:{file_sizes_kb} KB |Types: {file_exts}"
)

callback.success(collection.model_dump(mode="json"))

Expand All @@ -234,7 +246,7 @@
"/create",
description=load_description("collections/create.md"),
)
def create_collection(
async def create_collection(
session: SessionDep,
current_user: CurrentUser,
request: CreationRequest,
Expand All @@ -244,15 +256,24 @@
route = router.url_path_for(this.f_code.co_name)
payload = ResponsePayload("processing", route)

background_tasks.add_task(
do_create_collection,
session,
current_user,
request,
payload,
)
# Start the background task asynchronously
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason it is here?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Leaving it for future reference. It is self-explanatory, but I am keeping it just in case.

# background_tasks.add_task(
# do_create_collection,
# session,
# current_user,
# request,
# payload,
# )

return APIResponse.success_response(data=None, metadata=asdict(payload))
try:
await asyncio.wait_for(
do_create_collection(session, current_user, request, payload),
timeout=settings.COLLECTION_CREATION_TIMEOUT_SECONDS,
)
return APIResponse.success_response(data=None, metadata=asdict(payload))
except asyncio.TimeoutError:
logging.error(f"Timeout while creating collection for org: {current_user}")
raise HTTPException(status_code=408, detail="The task timed out.")


def do_delete_collection(
Expand Down
7 changes: 7 additions & 0 deletions backend/app/core/cloud/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,10 @@
return self.aws.client.get_object(**kwargs).get("Body")
except ClientError as err:
raise CloudStorageError(f'AWS Error: "{err}" ({url})') from err

def get_file_size_kb(self, url: str) -> float:
    """Return the size, in kilobytes, of the stored object at *url*.

    The URL is parsed with ``SimpleStorageName.from_url`` and the resulting
    fields are passed as keyword arguments to the S3 client's
    ``head_object`` call; the size is read from the ``ContentLength`` entry
    of that response.

    Args:
        url: Location of the object in the object store.

    Returns:
        ``ContentLength`` divided by 1024, rounded to two decimal places.

    Raises:
        CloudStorageError: If the S3 client raises ``ClientError``.
    """
    name = SimpleStorageName.from_url(url)
    kwargs = asdict(name)
    try:
        response = self.aws.client.head_object(**kwargs)
    except ClientError as err:
        # Same wrapping as the get_object call in this module, so callers
        # only ever have to handle one storage error type.
        raise CloudStorageError(f'AWS Error: "{err}" ({url})') from err
    return round(response["ContentLength"] / 1024, 2)

Check warning on line 133 in backend/app/core/cloud/storage.py

View check run for this annotation

Codecov / codecov/patch

backend/app/core/cloud/storage.py#L129-L133

Added lines #L129 - L133 were not covered by tests
1 change: 1 addition & 0 deletions backend/app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class Settings(BaseSettings):
env_ignore_empty=True,
extra="ignore",
)
COLLECTION_CREATION_TIMEOUT_SECONDS: int = 15
LANGFUSE_PUBLIC_KEY: str
LANGFUSE_SECRET_KEY: str
LANGFUSE_HOST: str # 🇪🇺 EU region
Expand Down
112 changes: 112 additions & 0 deletions backend/app/tests/api/routes/collections/test_create_collections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import pytest
import asyncio
import io
from openai import OpenAIError
import openai_responses
from uuid import UUID
from httpx import AsyncClient
from sqlmodel import Session
from app.core.config import settings
from app.tests.utils.document import DocumentStore
from app.tests.utils.utils import openai_credentials


@pytest.fixture(autouse=True)
def mock_s3(monkeypatch):
class FakeStorage:
def __init__(self, *args, **kwargs):
pass

def upload(self, file_obj, path: str, **kwargs):
return f"s3://fake-bucket/{path or 'mock-file.txt'}"

Check warning on line 21 in backend/app/tests/api/routes/collections/test_create_collections.py

View check run for this annotation

Codecov / codecov/patch

backend/app/tests/api/routes/collections/test_create_collections.py#L21

Added line #L21 was not covered by tests

def stream(self, file_obj):
fake_file = io.BytesIO(b"dummy content")
fake_file.name = "fake.txt"
return fake_file

def get_file_size_kb(self, url: str) -> float:
return 1.0 # Simulate 1KB files

class FakeS3Client:
def head_object(self, Bucket, Key):
return {"ContentLength": 1024}

Check warning on line 33 in backend/app/tests/api/routes/collections/test_create_collections.py

View check run for this annotation

Codecov / codecov/patch

backend/app/tests/api/routes/collections/test_create_collections.py#L33

Added line #L33 was not covered by tests

monkeypatch.setattr("app.api.routes.collections.AmazonCloudStorage", FakeStorage)
monkeypatch.setattr("boto3.client", lambda service: FakeS3Client())


@pytest.mark.usefixtures("openai_credentials")
class TestCollectionRouteCreate:
_n_documents = 5

@pytest.mark.asyncio
@openai_responses.mock()
async def test_create_collection_success(
    self,
    async_client: AsyncClient,
    db: Session,
    superuser_token_headers: dict[str, str],
):
    """A valid creation request answers 200 with a "processing" payload."""
    # Seed the database with documents the request can refer to by id.
    seeded = DocumentStore(db).fill(self._n_documents)

    request_body = {
        "documents": [str(doc.id) for doc in seeded],
        "batch_size": 2,
        "model": "gpt-4o",
        "instructions": "Test collection assistant.",
        "temperature": 0.1,
    }
    endpoint = f"{settings.API_V1_STR}/collections/create"

    response = await async_client.post(
        endpoint,
        json=request_body,
        headers=superuser_token_headers,
    )

    assert response.status_code == 200

    parsed = response.json()
    assert parsed["success"] is True

    meta = parsed.get("metadata", {})
    assert meta["status"] == "processing"
    # The key must be parseable as a UUID.
    assert UUID(meta["key"])

@pytest.mark.asyncio
async def test_create_collection_timeout(
self,
async_client: AsyncClient,
db: Session,
superuser_token_headers: dict[str, str],
monkeypatch,
):
async def long_task(*args, **kwargs):
await asyncio.sleep(30) # exceed timeout
return None

Check warning on line 86 in backend/app/tests/api/routes/collections/test_create_collections.py

View check run for this annotation

Codecov / codecov/patch

backend/app/tests/api/routes/collections/test_create_collections.py#L86

Added line #L86 was not covered by tests

monkeypatch.setattr(
"app.api.routes.collections.do_create_collection", # adjust if necessary
long_task,
)

body = {
"documents": [],
"batch_size": 1,
"model": "gpt-4o",
"instructions": "Slow task",
"temperature": 0.2,
}

response = await async_client.post(
f"{settings.API_V1_STR}/collections/create",
json=body,
headers=superuser_token_headers,
)

assert response.status_code == 408
json = response.json()
assert json["success"] is False
assert json["data"] is None
assert json["error"] == "The task timed out."
assert json["metadata"] is None
8 changes: 8 additions & 0 deletions backend/app/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections.abc import Generator

import pytest
import pytest_asyncio
from fastapi.testclient import TestClient
from sqlmodel import Session, delete

Expand All @@ -19,6 +20,7 @@
)
from app.tests.utils.user import authentication_token_from_email
from app.tests.utils.utils import get_superuser_token_headers
from httpx import AsyncClient


@pytest.fixture(scope="session", autouse=True)
Expand All @@ -44,6 +46,12 @@ def client() -> Generator[TestClient, None, None]:
yield c


@pytest_asyncio.fixture
async def async_client():
    """Yield an ``AsyncClient`` wired to ``app`` for async route tests.

    The client is entered as an async context manager, so it is closed when
    the fixture is torn down.
    """
    # NOTE(review): newer httpx releases deprecate the ``app=`` shortcut in
    # favour of ``transport=ASGITransport(app=app)`` - confirm against the
    # pinned httpx version before upgrading.
    http_client = AsyncClient(app=app, base_url="http://test")
    async with http_client as client:
        yield client


@pytest.fixture(scope="module")
def superuser_token_headers(client: TestClient) -> dict[str, str]:
return get_superuser_token_headers(client)
Expand Down
1 change: 1 addition & 0 deletions backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ dependencies = [
[tool.uv]
dev-dependencies = [
"pytest<8.0.0,>=7.4.3",
"pytest-asyncio>=0.21.1,<0.23.0",
"mypy<2.0.0,>=1.8.0",
"ruff<1.0.0,>=0.2.2",
"pre-commit<4.0.0,>=3.6.2",
Expand Down
14 changes: 14 additions & 0 deletions backend/uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.