diff --git a/backend/app/api/routes/collections.py b/backend/app/api/routes/collections.py
index 532bab721..4d01b5f31 100644
--- a/backend/app/api/routes/collections.py
+++ b/backend/app/api/routes/collections.py
@@ -1,5 +1,7 @@
 import inspect
 import logging
+import time
+import asyncio
 import warnings
 from uuid import UUID, uuid4
 from typing import Any, List, Optional
@@ -167,23 +169,23 @@ def _backout(crud: OpenAIAssistantCrud, assistant_id: str):
     )
 
 
 def do_create_collection(
     session: SessionDep,
     current_user: CurrentUser,
     request: CreationRequest,
     payload: ResponsePayload,
 ):
+    start_time = time.time()
     client = OpenAI(api_key=settings.OPENAI_API_KEY)
-    if request.callback_url is None:
-        callback = SilentCallback(payload)
-    else:
-        callback = WebHookCallback(request.callback_url, payload)
-    #
-    # Create the assistant and vector store
-    #
+    callback = (
+        SilentCallback(payload)
+        if request.callback_url is None
+        else WebHookCallback(request.callback_url, payload)
+    )
     vector_store_crud = OpenAIVectorStoreCrud(client)
     try:
         vector_store = vector_store_crud.create()
     except OpenAIError as err:
+        logging.error(f"OpenAI vector store creation failed: {err}")
         callback.fail(str(err))
         return
@@ -194,22 +196,30 @@ def do_create_collection(
     document_crud = DocumentCrud(session, current_user.id)
     assistant_crud = OpenAIAssistantCrud(client)
 
-    docs = request(document_crud)
+    docs = list(request(document_crud))
+    # "docs" is a list of batches; flatten before counting so the log reports files, not batches.
+    flat_docs = [doc for sublist in docs for doc in sublist]
+    doc_count = len(flat_docs)
+    file_exts = list(
+        {doc.fname.split(".")[-1] for doc in flat_docs if "." in doc.fname}
+    )
+
+    file_sizes_kb = []
+    for doc in flat_docs:
+        size_kb = storage.get_file_size_kb(doc.object_store_url)
+        file_sizes_kb.append(size_kb)
+
     kwargs = dict(request.extract_super_type(AssistantOptions))
     try:
         updates = vector_store_crud.update(vector_store.id, storage, docs)
         documents = list(updates)
         assistant = assistant_crud.create(vector_store.id, **kwargs)
-    except Exception as err:  # blanket to handle SQL and OpenAI errors
+    except Exception as err:
         logging.error(f"File Search setup error: {err} ({type(err).__name__})")
         vector_store_crud.delete(vector_store.id)
         callback.fail(str(err))
         return
 
-    #
-    # Store the results
-    #
-
     collection_crud = CollectionCrud(session, current_user.id)
     collection = Collection(
         id=UUID(payload.key),
@@ -219,13 +229,16 @@
     try:
         collection_crud.create(collection, documents)
     except SQLAlchemyError as err:
+        logging.error(f"DB insert failed for collection: {err}")
         _backout(assistant_crud, assistant.id)
         callback.fail(str(err))
         return
 
-    #
-    # Send back successful response
-    #
+    elapsed = time.time() - start_time
+    logging.info(
+        f"Collection created: {collection.id} | "
+        f"Time: {elapsed:.2f}s | Files: {doc_count} | Sizes: {file_sizes_kb} KB | Types: {file_exts}"
+    )
     callback.success(collection.model_dump(mode="json"))
 
 
@@ -234,7 +247,7 @@
     "/create",
     description=load_description("collections/create.md"),
 )
-def create_collection(
+async def create_collection(
     session: SessionDep,
     current_user: CurrentUser,
     request: CreationRequest,
@@ -244,15 +257,18 @@
     route = router.url_path_for(this.f_code.co_name)
     payload = ResponsePayload("processing", route)
 
-    background_tasks.add_task(
-        do_create_collection,
-        session,
-        current_user,
-        request,
-        payload,
-    )
-
-    return APIResponse.success_response(data=None, metadata=asdict(payload))
+    # Run the blocking creation pipeline in a worker thread: it performs
+    # synchronous OpenAI/S3/DB calls, so awaiting it directly on the event
+    # loop would block it and the timeout below could never fire.
+    try:
+        await asyncio.wait_for(
+            asyncio.to_thread(do_create_collection, session, current_user, request, payload),
+            timeout=settings.COLLECTION_CREATION_TIMEOUT_SECONDS,
+        )
+        return APIResponse.success_response(data=None, metadata=asdict(payload))
+    except asyncio.TimeoutError:
+        logging.error(f"Timeout while creating collection for org: {current_user}")
+        raise HTTPException(status_code=408, detail="The task timed out.")
 
 
 def do_delete_collection(
diff --git a/backend/app/core/cloud/storage.py b/backend/app/core/cloud/storage.py
index 341abad59..78a44fed2 100644
--- a/backend/app/core/cloud/storage.py
+++ b/backend/app/core/cloud/storage.py
@@ -124,3 +124,12 @@ def stream(self, url: str) -> StreamingBody:
             return self.aws.client.get_object(**kwargs).get("Body")
         except ClientError as err:
             raise CloudStorageError(f'AWS Error: "{err}" ({url})') from err
+
+    def get_file_size_kb(self, url: str) -> float:
+        name = SimpleStorageName.from_url(url)
+        kwargs = asdict(name)
+        try:
+            response = self.aws.client.head_object(**kwargs)
+        except ClientError as err:
+            raise CloudStorageError(f'AWS Error: "{err}" ({url})') from err
+        return round(response["ContentLength"] / 1024, 2)
diff --git a/backend/app/core/config.py b/backend/app/core/config.py
index 24779bf33..0cd7181d2 100644
--- a/backend/app/core/config.py
+++ b/backend/app/core/config.py
@@ -30,6 +30,7 @@ class Settings(BaseSettings):
         env_ignore_empty=True,
         extra="ignore",
     )
+    COLLECTION_CREATION_TIMEOUT_SECONDS: int = 15
     LANGFUSE_PUBLIC_KEY: str
     LANGFUSE_SECRET_KEY: str
     LANGFUSE_HOST: str  # 🇪🇺 EU region
diff --git a/backend/app/tests/api/routes/collections/test_create_collections.py b/backend/app/tests/api/routes/collections/test_create_collections.py
new file mode 100644
index 000000000..3d3128a8b
--- /dev/null
+++ b/backend/app/tests/api/routes/collections/test_create_collections.py
@@ -0,0 +1,112 @@
+import pytest
+import time
+import io
+from openai import OpenAIError
+import openai_responses
+from uuid import UUID
+from httpx import AsyncClient
+from sqlmodel import Session
+from app.core.config import settings
+from app.tests.utils.document import DocumentStore
+from app.tests.utils.utils import openai_credentials
+
+
+@pytest.fixture(autouse=True)
+def mock_s3(monkeypatch):
+    class FakeStorage:
+        def __init__(self, *args, **kwargs):
+            pass
+
+        def upload(self, file_obj, path: str, **kwargs):
+            return f"s3://fake-bucket/{path or 'mock-file.txt'}"
+
+        def stream(self, file_obj):
+            fake_file = io.BytesIO(b"dummy content")
+            fake_file.name = "fake.txt"
+            return fake_file
+
+        def get_file_size_kb(self, url: str) -> float:
+            return 1.0  # Simulate 1KB files
+
+    class FakeS3Client:
+        def head_object(self, Bucket, Key):
+            return {"ContentLength": 1024}
+
+    monkeypatch.setattr("app.api.routes.collections.AmazonCloudStorage", FakeStorage)
+    monkeypatch.setattr("boto3.client", lambda service: FakeS3Client())
+
+
+@pytest.mark.usefixtures("openai_credentials")
+class TestCollectionRouteCreate:
+    _n_documents = 5
+
+    @pytest.mark.asyncio
+    @openai_responses.mock()
+    async def test_create_collection_success(
+        self,
+        async_client: AsyncClient,
+        db: Session,
+        superuser_token_headers: dict[str, str],
+    ):
+        store = DocumentStore(db)
+        documents = store.fill(self._n_documents)
+        doc_ids = [str(doc.id) for doc in documents]
+
+        body = {
+            "documents": doc_ids,
+            "batch_size": 2,
+            "model": "gpt-4o",
+            "instructions": "Test collection assistant.",
+            "temperature": 0.1,
+        }
+
+        response = await async_client.post(
+            f"{settings.API_V1_STR}/collections/create",
+            json=body,
+            headers=superuser_token_headers,
+        )
+
+        assert response.status_code == 200
+        json = response.json()
+        assert json["success"] is True
+        metadata = json.get("metadata", {})
+        assert metadata["status"] == "processing"
+        assert UUID(metadata["key"])
+
+    @pytest.mark.asyncio
+    async def test_create_collection_timeout(
+        self,
+        async_client: AsyncClient,
+        db: Session,
+        superuser_token_headers: dict[str, str],
+        monkeypatch,
+    ):
+        def long_task(*args, **kwargs):
+            time.sleep(30)  # sync stub: route runs it via asyncio.to_thread
+            return None
+
+        monkeypatch.setattr(
+            "app.api.routes.collections.do_create_collection",  # adjust if necessary
+            long_task,
+        )
+
+        body = {
+            "documents": [],
+            "batch_size": 1,
+            "model": "gpt-4o",
+            "instructions": "Slow task",
+            "temperature": 0.2,
+        }
+
+        response = await async_client.post(
+            f"{settings.API_V1_STR}/collections/create",
+            json=body,
+            headers=superuser_token_headers,
+        )
+
+        assert response.status_code == 408
+        json = response.json()
+        assert json["success"] is False
+        assert json["data"] is None
+        assert json["error"] == "The task timed out."
+        assert json["metadata"] is None
diff --git a/backend/app/tests/conftest.py b/backend/app/tests/conftest.py
index a68c3eca0..155c44708 100644
--- a/backend/app/tests/conftest.py
+++ b/backend/app/tests/conftest.py
@@ -1,6 +1,7 @@
 from collections.abc import Generator
 
 import pytest
+import pytest_asyncio
 from fastapi.testclient import TestClient
 from sqlmodel import Session, delete
 
@@ -19,6 +20,7 @@
 )
 from app.tests.utils.user import authentication_token_from_email
 from app.tests.utils.utils import get_superuser_token_headers
+from httpx import ASGITransport, AsyncClient
 
 
 @pytest.fixture(scope="session", autouse=True)
@@ -44,6 +46,15 @@ def client() -> Generator[TestClient, None, None]:
         yield c
 
 
+@pytest_asyncio.fixture
+async def async_client():
+    # Build the ASGI transport explicitly: the AsyncClient(app=...) shortcut
+    # is deprecated and removed in newer httpx releases.
+    transport = ASGITransport(app=app)
+    async with AsyncClient(transport=transport, base_url="http://test") as client:
+        yield client
+
+
 @pytest.fixture(scope="module")
 def superuser_token_headers(client: TestClient) -> dict[str, str]:
     return get_superuser_token_headers(client)
diff --git a/backend/pyproject.toml b/backend/pyproject.toml
index 009073ec3..4428badf3 100644
--- a/backend/pyproject.toml
+++ b/backend/pyproject.toml
@@ -33,6 +33,7 @@ dependencies = [
 [tool.uv]
 dev-dependencies = [
     "pytest<8.0.0,>=7.4.3",
+    "pytest-asyncio>=0.21.1,<0.23.0",
     "mypy<2.0.0,>=1.8.0",
     "ruff<1.0.0,>=0.2.2",
     "pre-commit<4.0.0,>=3.6.2",
diff --git a/backend/uv.lock b/backend/uv.lock
index d7cb213bd..3f78f3e42 100644
--- a/backend/uv.lock
+++ b/backend/uv.lock
@@ -80,6 +80,7 @@ dev = [
     { name = "mypy" },
     { name = "pre-commit" },
     { name = "pytest" },
+    { name = "pytest-asyncio" },
     { name = "ruff" },
     { name = "types-passlib" },
 ]
@@ -117,6 +118,7 @@ dev = [
     { name = "mypy", specifier = ">=1.8.0,<2.0.0" },
     { name = "pre-commit", specifier = ">=3.6.2,<4.0.0" },
     { name = "pytest", specifier = ">=7.4.3,<8.0.0" },
+    { name = "pytest-asyncio", specifier = ">=0.21.1,<0.23.0" },
     { name = "ruff", specifier = ">=0.2.2,<1.0.0" },
     { name = "types-passlib", specifier = ">=1.7.7.20240106,<2.0.0.0" },
 ]
@@ -1388,6 +1390,18 @@
     { url = "https://files.pythonhosted.org/packages/51/ff/f6e8b8f39e08547faece4bd80f89d5a8de68a38b2d179cc1c4490ffa3286/pytest-7.4.4-py3-none-any.whl", hash = "sha256:b090cdf5ed60bf4c45261be03239c2c1c22df034fbffe691abe93cd80cea01d8", size = 325287 },
 ]
 
+[[package]]
+name = "pytest-asyncio"
+version = "0.21.2"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "pytest" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/ae/53/57663d99acaac2fcdafdc697e52a9b1b7d6fcf36616281ff9768a44e7ff3/pytest_asyncio-0.21.2.tar.gz", hash = "sha256:d67738fc232b94b326b9d060750beb16e0074210b98dd8b58a5239fa2a154f45", size = 30656 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/9c/ce/1e4b53c213dce25d6e8b163697fbce2d43799d76fa08eea6ad270451c370/pytest_asyncio-0.21.2-py3-none-any.whl", hash = "sha256:ab664c88bb7998f711d8039cacd4884da6430886ae8bbd4eded552ed2004f16b", size = 13368 },
+]
+
 [[package]]
 name = "python-dateutil"
 version = "2.9.0.post0"