|
37 | 37 | create_langfuse_dataset_run, |
38 | 38 | update_traces_with_cosine_scores, |
39 | 39 | ) |
40 | | -from app.crud.job import get_batch_job |
| 40 | +from app.crud.job import get_batch_job, update_batch_job |
41 | 41 | from app.models import EvaluationRun |
| 42 | +from app.models.batch_job import BatchJob, BatchJobUpdate |
42 | 43 | from app.utils import get_langfuse_client, get_openai_client |
43 | 44 |
|
44 | 45 | logger = logging.getLogger(__name__) |
45 | 46 |
|
46 | 47 |
|
def _extract_batch_error_message(
    provider: OpenAIBatchProvider,
    error_file_id: str,
    batch_job: BatchJob,
    session: Session,
) -> str:
    """
    Download the error file from OpenAI, parse JSONL entries, and extract
    the most common error message. Updates batch_job.error_message.

    Args:
        provider: OpenAI batch provider instance
        error_file_id: OpenAI error file ID
        batch_job: BatchJob to update with error message
        session: Database session

    Returns:
        Human-readable error message with the top error and counts
    """
    try:
        error_content = provider.download_file(error_file_id)
        # splitlines() handles both \n and \r\n endings; split("\n") would
        # leave a trailing \r that corrupts json.loads on CRLF content.
        lines = error_content.splitlines()

        error_counts: dict[str, int] = {}
        for line in lines:
            if not line.strip():
                continue  # skip blank lines instead of forcing a decode error
            try:
                entry = json.loads(line)
                message = (
                    entry.get("response", {})
                    .get("body", {})
                    .get("error", {})
                    .get("message", "Unknown error")
                )
                error_counts[message] = error_counts.get(message, 0) + 1
            except (json.JSONDecodeError, AttributeError):
                # AttributeError: a line that parses to a non-dict (e.g. a
                # bare string or array) would otherwise escape this handler
                # and discard all counts gathered so far.
                continue

        if error_counts:
            # Report the single most frequent error alongside how many of
            # the failed requests it accounts for.
            top_error = max(error_counts, key=error_counts.get)
            top_count = error_counts[top_error]
            total = sum(error_counts.values())
            error_msg = f"{top_error} ({top_count}/{total} requests)"
        else:
            error_msg = "Batch completed with errors but could not parse error file"

    except Exception as e:
        # Best-effort: if the download or parse fails entirely, fall back to
        # a generic message that still carries the error_file_id for triage.
        logger.error(
            f"[_extract_batch_error_message] Failed to extract errors | batch_job_id={batch_job.id} | {e}",
            exc_info=True,
        )
        error_msg = (
            f"Batch completed with all requests failed (error_file_id: {error_file_id})"
        )

    # Update batch_job with extracted error message (outside try/except
    # so persistence failures propagate to the caller)
    batch_job_update = BatchJobUpdate(error_message=error_msg)
    update_batch_job(
        session=session, batch_job=batch_job, batch_job_update=batch_job_update
    )

    logger.info(
        f"[_extract_batch_error_message] Extracted error | batch_job_id={batch_job.id} | {error_msg}"
    )

    return error_msg
| 114 | + |
| 115 | + |
47 | 116 | def parse_evaluation_output( |
48 | 117 | raw_results: list[dict[str, Any]], dataset_items: list[dict[str, Any]] |
49 | 118 | ) -> list[dict[str, Any]]: |
@@ -560,14 +629,49 @@ async def check_and_process_evaluation( |
560 | 629 |
|
561 | 630 | # IMPORTANT: Poll OpenAI to get the latest status before checking |
562 | 631 | provider = OpenAIBatchProvider(client=openai_client) |
563 | | - poll_batch_status(session=session, provider=provider, batch_job=batch_job) |
| 632 | + status_result = poll_batch_status( |
| 633 | + session=session, provider=provider, batch_job=batch_job |
| 634 | + ) |
564 | 635 |
|
565 | 636 | # Refresh batch_job to get the updated provider_status |
566 | 637 | session.refresh(batch_job) |
567 | 638 | provider_status = batch_job.provider_status |
568 | 639 |
|
569 | 640 | # Handle different provider statuses |
570 | 641 | if provider_status == "completed": |
| 642 | + # Check if batch completed but all requests failed |
| 643 | + # (output_file_id is absent, error_file_id is present) |
| 644 | + if not status_result.get( |
| 645 | + "provider_output_file_id", batch_job.provider_output_file_id |
| 646 | + ) and status_result.get("error_file_id"): |
| 647 | + error_msg = _extract_batch_error_message( |
| 648 | + provider=provider, |
| 649 | + error_file_id=status_result["error_file_id"], |
| 650 | + batch_job=batch_job, |
| 651 | + session=session, |
| 652 | + ) |
| 653 | + |
| 654 | + eval_run = update_evaluation_run( |
| 655 | + session=session, |
| 656 | + eval_run=eval_run, |
| 657 | + status="failed", |
| 658 | + error_message=error_msg, |
| 659 | + ) |
| 660 | + |
| 661 | + logger.error( |
| 662 | + f"[check_and_process_evaluation] {log_prefix} Batch completed with all requests failed | {error_msg}" |
| 663 | + ) |
| 664 | + |
| 665 | + return { |
| 666 | + "run_id": eval_run.id, |
| 667 | + "run_name": eval_run.run_name, |
| 668 | + "previous_status": previous_status, |
| 669 | + "current_status": "failed", |
| 670 | + "provider_status": provider_status, |
| 671 | + "action": "failed", |
| 672 | + "error": error_msg, |
| 673 | + } |
| 674 | + |
571 | 675 | # Process the completed evaluation |
572 | 676 | await process_completed_evaluation( |
573 | 677 | eval_run=eval_run, |
|
0 commit comments