Skip to content
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 21 additions & 10 deletions backend/app/api/routes/collections.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import inspect
import logging
import asyncio
import warnings
from uuid import UUID, uuid4
from typing import Any, List, Optional
Expand Down Expand Up @@ -167,7 +168,8 @@ def _backout(crud: OpenAIAssistantCrud, assistant_id: str):
)


def do_create_collection(
# Async function to create the collection and perform operations
async def do_create_collection(
session: SessionDep,
current_user: CurrentUser,
request: CreationRequest,
Expand Down Expand Up @@ -234,7 +236,7 @@ def do_create_collection(
"/create",
description=load_description("collections/create.md"),
)
def create_collection(
async def create_collection(
session: SessionDep,
current_user: CurrentUser,
request: CreationRequest,
Expand All @@ -244,15 +246,24 @@ def create_collection(
route = router.url_path_for(this.f_code.co_name)
payload = ResponsePayload("processing", route)

background_tasks.add_task(
do_create_collection,
session,
current_user,
request,
payload,
)
# Start the background task asynchronously
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason it is here?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Leaving it for any future reference, although it is self explanatory but still just in case

# background_tasks.add_task(
# do_create_collection,
# session,
# current_user,
# request,
# payload,
# )

return APIResponse.success_response(data=None, metadata=asdict(payload))
timeout_duration = 15
try:
await asyncio.wait_for(
do_create_collection(session, current_user, request, payload),
timeout=timeout_duration,
)
return APIResponse.success_response(data=None, metadata=asdict(payload))
except asyncio.TimeoutError:
raise HTTPException(status_code=408, detail="The task timed out.")


def do_delete_collection(
Expand Down
107 changes: 107 additions & 0 deletions backend/app/tests/api/routes/collections/test_create_collections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import pytest
import asyncio
import io
from openai import OpenAIError
import openai_responses
from uuid import UUID
from httpx import AsyncClient
from sqlmodel import Session
from app.core.config import settings
from app.tests.utils.document import DocumentStore
from app.tests.utils.utils import openai_credentials


# Automatically mock AmazonCloudStorage for all tests
@pytest.fixture(autouse=True)
def mock_s3(monkeypatch):
class FakeStorage:
def __init__(self, *args, **kwargs):
pass

def upload(self, file_obj, path: str, **kwargs):
# Return a dummy path (this is fine)
return f"s3://fake-bucket/{path or 'mock-file'}"

Check warning on line 23 in backend/app/tests/api/routes/collections/test_create_collections.py

View check run for this annotation

Codecov / codecov/patch

backend/app/tests/api/routes/collections/test_create_collections.py#L23

Added line #L23 was not covered by tests

def stream(self, file_obj):
# Wrap in a file-like object that has a `.name` attribute
fake_file = io.BytesIO(b"dummy content")
fake_file.name = "fake.txt"
return fake_file

monkeypatch.setattr("app.api.routes.collections.AmazonCloudStorage", FakeStorage)


@pytest.mark.usefixtures("openai_credentials")
class TestCollectionRouteCreate:
_n_documents = 5

@pytest.mark.asyncio
@openai_responses.mock()
async def test_create_collection_success(
self,
async_client: AsyncClient,
db: Session,
superuser_token_headers: dict[str, str],
):
store = DocumentStore(db)
documents = store.fill(self._n_documents)
doc_ids = [str(doc.id) for doc in documents]

body = {
"documents": doc_ids,
"batch_size": 2,
"model": "gpt-4o",
"instructions": "Test collection assistant.",
"temperature": 0.1,
}

response = await async_client.post(
f"{settings.API_V1_STR}/collections/create",
json=body,
headers=superuser_token_headers,
)

assert response.status_code == 200
json = response.json()
assert json["success"] is True
metadata = json.get("metadata", {})
assert metadata["status"] == "processing"
assert UUID(metadata["key"])

@pytest.mark.asyncio
async def test_create_collection_timeout(
self,
async_client: AsyncClient,
db: Session,
superuser_token_headers: dict[str, str],
monkeypatch,
):
async def long_task(*args, **kwargs):
await asyncio.sleep(30) # exceed timeout
return None

Check warning on line 81 in backend/app/tests/api/routes/collections/test_create_collections.py

View check run for this annotation

Codecov / codecov/patch

backend/app/tests/api/routes/collections/test_create_collections.py#L81

Added line #L81 was not covered by tests

monkeypatch.setattr(
"app.api.routes.collections.do_create_collection", # adjust if necessary
long_task,
)

body = {
"documents": [],
"batch_size": 1,
"model": "gpt-4o",
"instructions": "Slow task",
"temperature": 0.2,
}

response = await async_client.post(
f"{settings.API_V1_STR}/collections/create",
json=body,
headers=superuser_token_headers,
)

assert response.status_code == 408
json = response.json()
assert json["success"] is False
assert json["data"] is None
assert json["error"] == "The task timed out."
assert json["metadata"] is None
8 changes: 8 additions & 0 deletions backend/app/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections.abc import Generator

import pytest
import pytest_asyncio
from fastapi.testclient import TestClient
from sqlmodel import Session, delete

Expand All @@ -18,6 +19,7 @@
)
from app.tests.utils.user import authentication_token_from_email
from app.tests.utils.utils import get_superuser_token_headers
from httpx import AsyncClient


@pytest.fixture(scope="session", autouse=True)
Expand All @@ -42,6 +44,12 @@ def client() -> Generator[TestClient, None, None]:
yield c


@pytest_asyncio.fixture
async def async_client():
async with AsyncClient(app=app, base_url="http://test") as client:
yield client


@pytest.fixture(scope="module")
def superuser_token_headers(client: TestClient) -> dict[str, str]:
return get_superuser_token_headers(client)
Expand Down
1 change: 1 addition & 0 deletions backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ dependencies = [
[tool.uv]
dev-dependencies = [
"pytest<8.0.0,>=7.4.3",
"pytest-asyncio>=0.21.1,<0.23.0",
"mypy<2.0.0,>=1.8.0",
"ruff<1.0.0,>=0.2.2",
"pre-commit<4.0.0,>=3.6.2",
Expand Down
14 changes: 14 additions & 0 deletions backend/uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.