|
| 1 | +import pytest |
| 2 | +from uuid import uuid4 |
1 | 3 | from unittest.mock import patch |
2 | 4 |
|
| 5 | +from sqlmodel import Session |
3 | 6 | from fastapi.testclient import TestClient |
4 | 7 |
|
5 | | -from app.models import LLMCallRequest |
| 8 | +from app.crud import JobCrud |
| 9 | +from app.crud.llm import create_llm_call, update_llm_call_response |
| 10 | +from app.models import JobType, LLMCallRequest, Job, JobStatus, JobUpdate |
| 11 | +from app.models.llm.response import LLMCallResponse |
6 | 12 | from app.models.llm.request import ( |
7 | 13 | QueryParams, |
8 | 14 | LLMCallConfig, |
|
12 | 18 | ) |
13 | 19 |
|
14 | 20 |
|
@pytest.fixture
def llm_job(db: Session) -> Job:
    """Create and return a fresh job of type LLM_API for tests to poll."""
    return JobCrud(db).create(job_type=JobType.LLM_API)
| 25 | + |
| 26 | + |
@pytest.fixture
def llm_response_in_db(db: Session, llm_job, user_api_key) -> LLMCallResponse:
    """Persist a completed LLM call tied to ``llm_job``.

    Creates the call row via ``create_llm_call`` and then fills in the
    provider response and token usage via ``update_llm_call_response``,
    so endpoints that read a finished call have real data to serve.
    Returns the row created by ``create_llm_call``.
    """
    completion = KaapiCompletionConfig(
        provider="openai",
        params={
            "model": "gpt-4o",
            "instructions": "You are helpful.",
            "temperature": 0.7,
        },
        type="text",
    )
    blob = ConfigBlob(completion=completion)

    call = create_llm_call(
        db,
        request=LLMCallRequest(
            query=QueryParams(input="What is the capital of France?"),
            config=LLMCallConfig(blob=blob),
        ),
        job_id=llm_job.id,
        project_id=user_api_key.project_id,
        organization_id=user_api_key.organization_id,
        resolved_config=blob,
        original_provider="openai",
    )

    update_llm_call_response(
        db,
        llm_call_id=call.id,
        provider_response_id="resp_abc123",
        content={"type": "text", "content": {"format": "text", "value": "Paris"}},
        usage={
            "input_tokens": 10,
            "output_tokens": 5,
            "total_tokens": 15,
            "reasoning_tokens": None,
        },
    )
    return call
| 65 | + |
| 66 | + |
15 | 67 | def test_llm_call_success( |
16 | 68 | client: TestClient, user_api_key_header: dict[str, str] |
17 | 69 | ) -> None: |
@@ -247,3 +299,89 @@ def test_llm_call_guardrails_bypassed_still_succeeds( |
247 | 299 | assert "response is being generated" in body["data"]["message"] |
248 | 300 |
|
249 | 301 | mock_start_job.assert_called_once() |
| 302 | + |
| 303 | + |
def test_get_llm_call_pending(
    client: TestClient,
    user_api_key_header: dict[str, str],
    llm_job,
) -> None:
    """A job still in PENDING state reports its status with no llm_response."""
    resp = client.get(
        f"/api/v1/llm/call/{llm_job.id}",
        headers=user_api_key_header,
    )

    assert resp.status_code == 200
    payload = resp.json()
    assert payload["success"] is True
    data = payload["data"]
    assert data["job_id"] == str(llm_job.id)
    assert data["status"] == "PENDING"
    assert data["llm_response"] is None
| 321 | + |
| 322 | + |
def test_get_llm_call_success(
    client: TestClient,
    db: Session,
    user_api_key_header: dict[str, str],
    llm_job,
    llm_response_in_db,
) -> None:
    """A job marked SUCCESS returns the full llm_response, usage included."""
    # Flip the job to SUCCESS so the endpoint exposes the stored response.
    JobCrud(db).update(llm_job.id, JobUpdate(status=JobStatus.SUCCESS))

    resp = client.get(
        f"/api/v1/llm/call/{llm_job.id}",
        headers=user_api_key_header,
    )

    assert resp.status_code == 200
    payload = resp.json()
    assert payload["success"] is True

    data = payload["data"]
    assert data["status"] == "SUCCESS"
    llm_response = data["llm_response"]
    assert llm_response is not None
    assert llm_response["response"]["provider_response_id"] == "resp_abc123"
    assert llm_response["response"]["provider"] == "openai"
    usage = llm_response["usage"]
    assert usage["input_tokens"] == 10
    assert usage["output_tokens"] == 5
    assert usage["total_tokens"] == 15
| 350 | + |
| 351 | + |
def test_get_llm_call_failed(
    client: TestClient,
    db: Session,
    user_api_key_header: dict[str, str],
    llm_job,
) -> None:
    """A job marked FAILED surfaces its error_message and no llm_response."""
    JobCrud(db).update(
        llm_job.id,
        JobUpdate(status=JobStatus.FAILED, error_message="Provider timeout"),
    )

    resp = client.get(
        f"/api/v1/llm/call/{llm_job.id}",
        headers=user_api_key_header,
    )

    assert resp.status_code == 200
    payload = resp.json()
    assert payload["success"] is True
    data = payload["data"]
    assert data["status"] == "FAILED"
    assert data["error_message"] == "Provider timeout"
    assert data["llm_response"] is None
| 374 | + |
| 375 | + |
def test_get_llm_call_not_found(
    client: TestClient,
    user_api_key_header: dict[str, str],
) -> None:
    """Looking up a job_id that was never created yields a 404."""
    missing_id = uuid4()

    resp = client.get(
        f"/api/v1/llm/call/{missing_id}",
        headers=user_api_key_header,
    )

    assert resp.status_code == 404
0 commit comments