-
Notifications
You must be signed in to change notification settings - Fork 0
feat(asr): add SSE streaming endpoint with lead-budget backpressure #78
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: release/v2.0.0
Are you sure you want to change the base?
Changes from 9 commits
8fc1f72
bc57c25
7e556d0
fcdfbc6
0f1b1fa
f0c6f32
235f3c8
04f1e01
4aec6af
dd78a17
78e6174
a2a342b
6b868db
c549cee
433b0a6
9896310
af63e29
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,4 +1,11 @@ | ||||||||||||||||||||||||||||||
| import asyncio | ||||||||||||||||||||||||||||||
| import contextlib | ||||||||||||||||||||||||||||||
| import json | ||||||||||||||||||||||||||||||
| from typing import AsyncGenerator | ||||||||||||||||||||||||||||||
| from uuid import UUID | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| from fastapi import Body, Depends, UploadFile | ||||||||||||||||||||||||||||||
| from fastapi.responses import StreamingResponse | ||||||||||||||||||||||||||||||
| from fastapi.routing import APIRouter | ||||||||||||||||||||||||||||||
| from pydantic import UUID5 | ||||||||||||||||||||||||||||||
| from sqlmodel import Session | ||||||||||||||||||||||||||||||
|
|
@@ -9,7 +16,11 @@ | |||||||||||||||||||||||||||||
| NotFoundError, | ||||||||||||||||||||||||||||||
| UpstreamServiceError, | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| from aymurai.audio.asr_client import transcribe_audio_bytes | ||||||||||||||||||||||||||||||
| from aymurai.audio.asr_client import ( | ||||||||||||||||||||||||||||||
| lines_to_paragraphs, | ||||||||||||||||||||||||||||||
| transcribe_audio_bytes, | ||||||||||||||||||||||||||||||
| transcribe_audio_bytes_stream, | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| from aymurai.database.crud.audio_transcription import ( | ||||||||||||||||||||||||||||||
| audio_transcription_create_or_update, | ||||||||||||||||||||||||||||||
| audio_transcription_get, | ||||||||||||||||||||||||||||||
|
|
@@ -20,6 +31,35 @@ | |||||||||||||||||||||||||||||
| from aymurai.meta.api_interfaces import ASRDocument, ASRParagraph, ASRParagraphRequest | ||||||||||||||||||||||||||||||
| from aymurai.settings import settings | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def _format_transcription_event( | ||||||||||||||||||||||||||||||
| document_id: UUID, | ||||||||||||||||||||||||||||||
| paragraphs: list[ASRParagraph], | ||||||||||||||||||||||||||||||
| ) -> str: | ||||||||||||||||||||||||||||||
| """Format an active_transcription SSE event.""" | ||||||||||||||||||||||||||||||
| payload = ASRDocument( | ||||||||||||||||||||||||||||||
| document_id=document_id, document=paragraphs | ||||||||||||||||||||||||||||||
| ).model_dump_json() | ||||||||||||||||||||||||||||||
| return f"event: transcription\ndata: {payload}\n\n" | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def _format_done_event( | ||||||||||||||||||||||||||||||
| document_id: UUID, | ||||||||||||||||||||||||||||||
| paragraphs: list[ASRParagraph], | ||||||||||||||||||||||||||||||
| ) -> str: | ||||||||||||||||||||||||||||||
| """Format the final 'done' SSE event.""" | ||||||||||||||||||||||||||||||
| payload = ASRDocument( | ||||||||||||||||||||||||||||||
| document_id=document_id, document=paragraphs | ||||||||||||||||||||||||||||||
| ).model_dump_json() | ||||||||||||||||||||||||||||||
| return f"event: done\ndata: {payload}\n\n" | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def _format_error_event(detail: str, code: str) -> str: | ||||||||||||||||||||||||||||||
| """Format an error SSE event.""" | ||||||||||||||||||||||||||||||
| payload = json.dumps({"detail": detail, "code": code}) | ||||||||||||||||||||||||||||||
| return f"event: error\ndata: {payload}\n\n" | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| router = APIRouter() | ||||||||||||||||||||||||||||||
| logger = get_logger(__name__) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
@@ -71,19 +111,14 @@ async def _transcribe_audio_bytes_with_error_handling( | |||||||||||||||||||||||||||||
| if not status: | ||||||||||||||||||||||||||||||
| raise AymuraiAPIException(detail="No transcription result received") | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| return [ | ||||||||||||||||||||||||||||||
| ASRParagraph( | ||||||||||||||||||||||||||||||
| speaker_no=line.speaker, | ||||||||||||||||||||||||||||||
| speaker_id=f"speaker-{line.speaker}", | ||||||||||||||||||||||||||||||
| start=line.start, | ||||||||||||||||||||||||||||||
| end=line.end, | ||||||||||||||||||||||||||||||
| text=line.text, | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| for line in status.lines | ||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||
| return lines_to_paragraphs(status.lines) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| @router.post("/transcribe", response_model=ASRDocument) | ||||||||||||||||||||||||||||||
| @router.post( | ||||||||||||||||||||||||||||||
| "/transcribe", | ||||||||||||||||||||||||||||||
| response_model=ASRDocument, | ||||||||||||||||||||||||||||||
| deprecated=True, | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| async def transcribe( | ||||||||||||||||||||||||||||||
| file: UploadFile, | ||||||||||||||||||||||||||||||
| use_cache: bool = True, | ||||||||||||||||||||||||||||||
|
|
@@ -130,6 +165,126 @@ async def transcribe( | |||||||||||||||||||||||||||||
| return document | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| @router.post("/transcribe/stream") | ||||||||||||||||||||||||||||||
| async def transcribe_stream( | ||||||||||||||||||||||||||||||
| file: UploadFile, | ||||||||||||||||||||||||||||||
| use_cache: bool = True, | ||||||||||||||||||||||||||||||
| ws_uri: str = Depends(get_transcribe_ws_uri), | ||||||||||||||||||||||||||||||
| session: Session = Depends(get_session), | ||||||||||||||||||||||||||||||
| ) -> StreamingResponse: | ||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||
| Transcribe an uploaded audio file and stream intermediate results as SSE. | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| Emits `event: transcription` frames per upstream update (cumulative snapshots), | ||||||||||||||||||||||||||||||
| a final `event: done` frame with the complete ASRDocument, and `event: error` | ||||||||||||||||||||||||||||||
| on upstream failure. Checks the cache first; on cache miss, persists the final | ||||||||||||||||||||||||||||||
| result. | ||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||
| data = await file.read() | ||||||||||||||||||||||||||||||
| filename = file.filename | ||||||||||||||||||||||||||||||
| document_id = data_to_uuid(data) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| async def _event_stream() -> AsyncGenerator[str, None]: | ||||||||||||||||||||||||||||||
| # Cache check | ||||||||||||||||||||||||||||||
| if use_cache: | ||||||||||||||||||||||||||||||
| cached = audio_transcription_get( | ||||||||||||||||||||||||||||||
| transcription_id=document_id, session=session | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| if cached is not None: | ||||||||||||||||||||||||||||||
| logger.debug(f"Audio transcription DB hit for {filename}") | ||||||||||||||||||||||||||||||
| cached_paragraphs = [ | ||||||||||||||||||||||||||||||
| ASRParagraph.model_validate(p) | ||||||||||||||||||||||||||||||
| for p in (cached.validation or cached.transcription) | ||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||
| yield _format_done_event(document_id, cached_paragraphs) | ||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| # Live streaming path | ||||||||||||||||||||||||||||||
| keepalive_task: asyncio.Task | None = None | ||||||||||||||||||||||||||||||
| keepalive_queue: asyncio.Queue[str] = asyncio.Queue() | ||||||||||||||||||||||||||||||
| interval = settings.TRANSCRIBE_SSE_KEEPALIVE_SECONDS | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| async def _keepalive_pump() -> None: | ||||||||||||||||||||||||||||||
| while True: | ||||||||||||||||||||||||||||||
| await asyncio.sleep(interval) | ||||||||||||||||||||||||||||||
| await keepalive_queue.put(": keepalive\n\n") | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| if interval > 0: | ||||||||||||||||||||||||||||||
| keepalive_task = asyncio.create_task(_keepalive_pump()) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| last_snapshot: list[ASRParagraph] = [] | ||||||||||||||||||||||||||||||
| stream_iter = transcribe_audio_bytes_stream(data).__aiter__() | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||
|
Comment on lines
+211
to
+216
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suggestion: Iterator creation errors from transcribe_audio_bytes_stream bypass SSE-style error reporting If
Suggested change
|
||||||||||||||||||||||||||||||
| while True: | ||||||||||||||||||||||||||||||
| next_task = asyncio.create_task(stream_iter.__anext__()) | ||||||||||||||||||||||||||||||
| keepalive_get = asyncio.create_task(keepalive_queue.get()) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| done, pending = await asyncio.wait( | ||||||||||||||||||||||||||||||
| {next_task, keepalive_get}, | ||||||||||||||||||||||||||||||
| return_when=asyncio.FIRST_COMPLETED, | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| for task in pending: | ||||||||||||||||||||||||||||||
| task.cancel() | ||||||||||||||||||||||||||||||
| with contextlib.suppress(asyncio.CancelledError): | ||||||||||||||||||||||||||||||
| await task | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| if next_task in done: | ||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||
| snapshot = next_task.result() | ||||||||||||||||||||||||||||||
| except StopAsyncIteration: | ||||||||||||||||||||||||||||||
| break | ||||||||||||||||||||||||||||||
| except RuntimeError as exc: | ||||||||||||||||||||||||||||||
| logger.error("upstream error during streaming: %s", exc) | ||||||||||||||||||||||||||||||
| yield _format_error_event( | ||||||||||||||||||||||||||||||
| detail=str(exc), code="UPSTREAM_SERVICE_ERROR" | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||
| except Exception: | ||||||||||||||||||||||||||||||
| logger.exception("unexpected error during streaming") | ||||||||||||||||||||||||||||||
| yield _format_error_event( | ||||||||||||||||||||||||||||||
| detail="Unexpected error during transcription", | ||||||||||||||||||||||||||||||
| code="INTERNAL_ERROR", | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||
| last_snapshot = snapshot | ||||||||||||||||||||||||||||||
| yield _format_transcription_event(document_id, snapshot) | ||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||
| # keepalive fired | ||||||||||||||||||||||||||||||
| yield keepalive_get.result() | ||||||||||||||||||||||||||||||
| finally: | ||||||||||||||||||||||||||||||
| if keepalive_task is not None and not keepalive_task.done(): | ||||||||||||||||||||||||||||||
| keepalive_task.cancel() | ||||||||||||||||||||||||||||||
| with contextlib.suppress(asyncio.CancelledError): | ||||||||||||||||||||||||||||||
| await keepalive_task | ||||||||||||||||||||||||||||||
| with contextlib.suppress(Exception): | ||||||||||||||||||||||||||||||
| await stream_iter.aclose() # type: ignore[attr-defined] | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| # Persist + done event | ||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||
| audio_transcription_create_or_update( | ||||||||||||||||||||||||||||||
| transcription_id=document_id, | ||||||||||||||||||||||||||||||
| name=filename or str(document_id), | ||||||||||||||||||||||||||||||
| transcription=last_snapshot, | ||||||||||||||||||||||||||||||
| session=session, | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| logger.debug(f"Audio transcription stored in DB for {filename}") | ||||||||||||||||||||||||||||||
| except Exception: | ||||||||||||||||||||||||||||||
| logger.exception("failed to persist transcription; continuing") | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| yield _format_done_event(document_id, last_snapshot) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| return StreamingResponse( | ||||||||||||||||||||||||||||||
| _event_stream(), | ||||||||||||||||||||||||||||||
| media_type="text/event-stream", | ||||||||||||||||||||||||||||||
| headers={ | ||||||||||||||||||||||||||||||
| "Cache-Control": "no-cache", | ||||||||||||||||||||||||||||||
| "X-Accel-Buffering": "no", | ||||||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| @router.get("/validation/document/{document_id}") | ||||||||||||||||||||||||||||||
| async def asr_read_document_validation( | ||||||||||||||||||||||||||||||
| document_id: UUID5, | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,6 +3,7 @@ | |
| import io | ||
| import json | ||
| from pathlib import Path | ||
| from typing import AsyncGenerator | ||
|
|
||
| import librosa | ||
| import numpy as np | ||
|
|
@@ -13,8 +14,10 @@ | |
| WLKMessageRawResponse, | ||
| WLKMessageReadyToStopMessage, | ||
| WLKMessageStatus, | ||
| WLKMessageTranscriptionLine, | ||
| ) | ||
| from aymurai.logger import get_logger | ||
| from aymurai.meta.api_interfaces import ASRParagraph | ||
| from aymurai.settings import settings | ||
|
|
||
| logger = get_logger(__name__) | ||
|
|
@@ -26,6 +29,29 @@ | |
| ASR_RAW_RESPONSE_ADAPTER = TypeAdapter(WLKMessageRawResponse) | ||
|
|
||
|
|
||
| def lines_to_paragraphs( | ||
| lines: list[WLKMessageTranscriptionLine], | ||
| ) -> list[ASRParagraph]: | ||
| """ | ||
| Map WebSocket transcription lines to ASRParagraph objects. | ||
|
|
||
| Args: | ||
| lines (list[WLKMessageTranscriptionLine]): Transcription lines from the ASR service. | ||
|
|
||
| Returns: | ||
| list[ASRParagraph]: ASRParagraph objects ready for serialization or storage. | ||
| """ | ||
| return [ | ||
| ASRParagraph( | ||
| speaker_no=line.speaker, | ||
| start=line.start, | ||
| end=line.end, | ||
| text=line.text, | ||
| ) | ||
| for line in lines | ||
|
Comment on lines
+107
to
+114
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. issue (bug_risk): lines_to_paragraphs drops the speaker_id field that the previous mapping included The old implementation populated both speaker_no and a derived speaker_id (e.g., f"speaker-{line.speaker}"). The new helper only sets speaker_no, start, end, and text, which changes the response shape. If any callers rely on speaker_id (e.g., for labeling or grouping), this could break them. If that field is still required, please add it back in lines_to_paragraphs or generate it consistently elsewhere. |
||
| ] | ||
|
|
||
|
|
||
| async def _stream_audio_bytes( | ||
| payload: bytes, | ||
| websocket: websockets.ClientConnection, | ||
|
|
@@ -171,6 +197,82 @@ async def transcribe_audio_bytes(payload: bytes) -> WLKMessageStatus | None: | |
| return last_active_transcription | ||
|
|
||
|
|
||
| async def transcribe_audio_bytes_stream( | ||
| payload: bytes, | ||
| ) -> AsyncGenerator[list[ASRParagraph], None]: | ||
| """ | ||
| Stream transcription updates from the ASR WebSocket service. | ||
|
|
||
| Yields a list[ASRParagraph] snapshot for each intermediate active_transcription | ||
| update received from the upstream service. The generator terminates when the | ||
| service sends a ready_to_stop message or the connection closes normally. | ||
|
|
||
| Args: | ||
| payload (bytes): The audio data to be transcribed. | ||
|
|
||
| Raises: | ||
| RuntimeError: If TRANSCRIBE_WS_URI is not configured or the upstream | ||
| websocket service errors out mid-stream. | ||
|
|
||
| Yields: | ||
| list[ASRParagraph]: Cumulative snapshot of paragraphs seen so far. | ||
| """ | ||
| ws_uri = settings.TRANSCRIBE_WS_URI | ||
|
|
||
| if not ws_uri: | ||
| raise RuntimeError("TRANSCRIBE_WS_URI is not configured") | ||
|
|
||
| logger.info("streaming audio for transcription (sse)") | ||
|
|
||
| streaming_task: asyncio.Task | None = None | ||
| try: | ||
| async with websockets.connect(ws_uri) as websocket: | ||
| streaming_task = asyncio.create_task( | ||
| _stream_and_signal_end(payload, websocket) | ||
| ) | ||
| while True: | ||
| try: | ||
| msg = await websocket.recv() | ||
| except websockets.exceptions.ConnectionClosedOK: | ||
| logger.info("connection closed normally") | ||
| break | ||
| except websockets.exceptions.WebSocketException as exc: | ||
| logger.error("websocket error while receiving: %s", exc) | ||
| raise RuntimeError("Transcription service websocket error") from exc | ||
|
|
||
| parsed = _parse_ws_message(msg) | ||
| match parsed: | ||
| case None: | ||
| continue | ||
| case WLKMessageStatus(status="active_transcription") as message: | ||
| yield lines_to_paragraphs(message.lines) | ||
| case WLKMessageReadyToStopMessage(): | ||
| break | ||
| except websockets.exceptions.WebSocketException as exc: | ||
| logger.error("websocket error during transcription: %s", exc) | ||
| raise RuntimeError("Transcription service websocket error") from exc | ||
| finally: | ||
| if streaming_task is not None: | ||
| if not streaming_task.done(): | ||
| streaming_task.cancel() | ||
| try: | ||
| await streaming_task | ||
| except asyncio.CancelledError: | ||
| pass # expected when we cancelled it | ||
| except Exception as exc: | ||
| logger.error("audio streaming task failed: %s", exc) | ||
|
|
||
|
|
||
| async def _stream_and_signal_end( | ||
| payload: bytes, | ||
| websocket: "websockets.ClientConnection", | ||
| ) -> None: | ||
| """Stream audio bytes then send the empty end-of-stream marker.""" | ||
| total_bytes = await _stream_audio_bytes(payload, websocket) | ||
| await websocket.send(b"") | ||
| logger.info("sent %s bytes to transcription service", total_bytes) | ||
|
|
||
|
|
||
| def transcribe_audio_path(path: Path) -> WLKMessageStatus | None: | ||
| """ | ||
| Transcribes an audio file at the given path by reading its bytes and sending them to the transcription service. | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
issue (bug_risk): Potential crash if both cached.validation and cached.transcription are None
In the cache-hit branch,
(cached.validation or cached.transcription)can evaluate toNoneif both are unset/false-y, causing aTypeErrorwhen iterated. If this state is possible (e.g., partially written rows), add a guard (e.g., default to[]or validate that at least one is populated) and handle the invalid cache state (skip cache or emit an error SSE).