Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
8fc1f72
chore: ignore .worktrees/ directory
jedzill4 Apr 21, 2026
bc57c25
feat(asr): add TRANSCRIBE_SSE_KEEPALIVE_SECONDS setting
jedzill4 Apr 21, 2026
7e556d0
refactor(asr): extract lines_to_paragraphs helper
jedzill4 Apr 21, 2026
fcdfbc6
feat(asr): add transcribe_audio_bytes_stream async generator
jedzill4 Apr 21, 2026
0f1b1fa
feat(asr): add SSE event formatting helpers
jedzill4 Apr 21, 2026
f0c6f32
feat(asr): add POST /asr/transcribe/stream SSE endpoint
jedzill4 Apr 21, 2026
235f3c8
test(asr): verify SSE stream emits only done event on cache hit
jedzill4 Apr 21, 2026
04f1e01
test(asr): verify deprecated flag on POST /asr/transcribe
jedzill4 Apr 21, 2026
4aec6af
test(asr): verify SSE stream emits error event on upstream failure
jedzill4 Apr 21, 2026
dd78a17
feat(asr): add lead-budget backpressure pacer to prevent 1011 keepali…
jedzill4 Apr 22, 2026
78e6174
refactor(asr): clean up dead code, unify decode, improve consistency
jedzill4 Apr 22, 2026
a2a342b
fix(asr): preserve RuntimeError messages and normalize docstrings
jedzill4 Apr 22, 2026
6b868db
refactor(asr): extract shared helpers, add 5 missing test cases
jedzill4 Apr 22, 2026
c549cee
fix(asr): simplify backpressure handling in _stream_audio_bytes function
jedzill4 Apr 22, 2026
433b0a6
refactor(asr): remove dead backpressure implementation
jedzill4 Apr 22, 2026
9896310
chore(asr): enhance audio processing settings and improve SSE event e…
jedzill4 Apr 22, 2026
af63e29
fix(asr): exclude empty text paragraphs from transcribe response
jedzill4 Apr 23, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,9 @@ resources
.venv

aymurai/version.py

# Git worktrees
.worktrees/

# Local-only superpowers specs and plans (never committed)
docs/superpowers/
2 changes: 1 addition & 1 deletion .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
"subProcess": false,
"envFile": "${workspaceFolder}/.env",
"python": "${workspaceFolder}/.venv/bin/python",
"preLaunchTask": "Start Ollama service"
// "preLaunchTask": "Start Ollama service"
},
]
}
265 changes: 224 additions & 41 deletions aymurai/api/endpoints/routers/asr/transcribe.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
import asyncio
import contextlib
import json
from typing import AsyncGenerator
from uuid import UUID

from fastapi import Body, Depends, UploadFile
from fastapi.responses import StreamingResponse
from fastapi.routing import APIRouter
from pydantic import UUID5
from sqlmodel import Session
Expand All @@ -9,7 +16,12 @@
NotFoundError,
UpstreamServiceError,
)
from aymurai.audio.asr_client import transcribe_audio_bytes
from aymurai.audio.asr_client import (
ASRStreamChunk,
lines_to_paragraphs,
transcribe_audio_bytes,
transcribe_audio_bytes_stream,
)
from aymurai.database.crud.audio_transcription import (
audio_transcription_create_or_update,
audio_transcription_get,
Expand All @@ -20,19 +32,42 @@
from aymurai.meta.api_interfaces import ASRDocument, ASRParagraph, ASRParagraphRequest
from aymurai.settings import settings


def _format_sse_event(
    event_name: str,
    document_id: UUID,
    paragraphs: list[ASRParagraph],
    current_time: float | None = None,
    total_time: float | None = None,
) -> str:
    """Build a server-sent event frame carrying an ASRDocument payload.

    Args:
        event_name: SSE event name (e.g. ``transcription`` or ``done``).
        document_id: UUID identifying the audio document.
        paragraphs: Transcribed paragraphs embedded in the payload.
        current_time: Progress timestamp to include, when known.
        total_time: Total duration to include, when known.

    Returns:
        A complete ``event:``/``data:`` frame terminated by a blank line.
    """
    document = ASRDocument(
        document_id=document_id,
        document=paragraphs,
        current_time=current_time,
        total_time=total_time,
    )
    return "event: " + event_name + "\ndata: " + document.model_dump_json() + "\n\n"


def _format_error_event(detail: str, code: str) -> str:
"""Format an error SSE event."""
payload = json.dumps({"detail": detail, "code": code})
return f"event: error\ndata: {payload}\n\n"


# Module-level router and logger shared by the ASR transcription endpoints below.
router = APIRouter()
logger = get_logger(__name__)


def get_transcribe_ws_uri() -> str:
"""
Get the WebSocket URI for the transcription service from settings.
"""Get the WebSocket URI for the transcription service from settings.

Raises:
ConfigurationError: If the WebSocket URI is not configured in settings.
ConfigurationError: If the WebSocket URI is not configured.

Returns:
str: The WebSocket URI for the transcription service.
The WebSocket URI for the transcription service.
"""
ws_uri = settings.TRANSCRIBE_WS_URI
if not ws_uri:
Expand All @@ -43,18 +78,17 @@ def get_transcribe_ws_uri() -> str:
async def _transcribe_audio_bytes_with_error_handling(
data: bytes,
) -> list[ASRParagraph]:
"""
Transcribes audio bytes into a list of ASRParagraph objects.
"""Transcribe audio bytes into a list of ASRParagraph objects.

Args:
data (bytes): The audio data to be transcribed.
data: The audio data to be transcribed.

Raises:
UpstreamServiceError: If there is an error with the upstream transcription service.
UpstreamServiceError: If there is an error with the upstream service.
AymuraiAPIException: If there is an unexpected error during transcription.

Returns:
list[ASRParagraph]: A list of ASRParagraph objects representing the transcribed audio.
A list of ASRParagraph objects representing the transcribed audio.
"""
try:
status = await transcribe_audio_bytes(data)
Expand All @@ -71,36 +105,32 @@ async def _transcribe_audio_bytes_with_error_handling(
if not status:
raise AymuraiAPIException(detail="No transcription result received")

return [
ASRParagraph(
speaker_no=line.speaker,
speaker_id=f"speaker-{line.speaker}",
start=line.start,
end=line.end,
text=line.text,
)
for line in status.lines
]
return lines_to_paragraphs(status.lines)


@router.post("/transcribe", response_model=ASRDocument)
@router.post(
"/transcribe",
response_model=ASRDocument,
deprecated=True,
)
async def transcribe(
file: UploadFile,
use_cache: bool = True,
ws_uri: str = Depends(get_transcribe_ws_uri),
session: Session = Depends(get_session),
) -> ASRDocument:
"""
Transcribes an uploaded audio file and returns the transcribed document.
"""Transcribe an uploaded audio file and return the transcribed document.

Args:
file (UploadFile): The audio file to be transcribed.
use_cache (bool, optional): Whether to use cached transcription results. Defaults to True.
ws_uri (str, optional): The WebSocket URI for the transcription service. Defaults to Depends(get_transcribe_ws_uri).
session (Session, optional): The database session. Defaults to Depends(get_session).
file: The audio file to be transcribed.
use_cache: Whether to use cached transcription results.
ws_uri: The WebSocket URI for the transcription service (injected via
``Depends`` for configuration validation — the value itself is read
from settings by the ASR client).
session: The database session.

Returns:
ASRDocument: The transcribed audio document.
The transcribed audio document.
"""
data = await file.read()
document_id = data_to_uuid(data)
Expand All @@ -110,44 +140,198 @@ async def transcribe(
transcription_id=document_id, session=session
)
if cached_record is not None:
logger.debug(f"Audio transcription DB hit for {file.filename}")
logger.debug("Audio transcription DB hit for %s", file.filename)
cached_document = ASRDocument(
document_id=document_id,
document=cached_record.validation or cached_record.transcription,
)
return cached_document

transcription_items = await _transcribe_audio_bytes_with_error_handling(data)
transcription_items = [p for p in transcription_items if p.text.strip()]
document = ASRDocument(document_id=document_id, document=transcription_items)
audio_transcription_create_or_update(
transcription_id=document_id,
name=file.filename or str(document_id),
transcription=document.document,
session=session,
)
logger.debug(f"Audio transcription stored in DB for {file.filename}")
logger.debug("Audio transcription stored in DB for %s", file.filename)

return document


@router.post("/transcribe/stream")
async def transcribe_stream(
file: UploadFile,
use_cache: bool = True,
ws_uri: str = Depends(get_transcribe_ws_uri),
session: Session = Depends(get_session),
) -> StreamingResponse:
"""
Transcribe an uploaded audio file and stream intermediate results as SSE.

Emits `event: transcription` frames per upstream update (cumulative snapshots),
a final `event: done` frame with the complete ASRDocument, and `event: error`
on upstream failure. Checks the cache first; on cache miss, persists the final
result.
"""
data = await file.read()
filename = file.filename
document_id = data_to_uuid(data)

async def _event_stream() -> AsyncGenerator[str, None]:
# Cache check
if use_cache:
cached = audio_transcription_get(
transcription_id=document_id, session=session
)
if cached is not None:
logger.debug("Audio transcription DB hit for %s", filename)
cached_paragraphs = [
ASRParagraph.model_validate(p)
for p in (cached.validation or cached.transcription)
Comment on lines +191 to +193
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (bug_risk): Potential crash if both cached.validation and cached.transcription are None

In the cache-hit branch, (cached.validation or cached.transcription) can evaluate to None if both are unset/false-y, causing a TypeError when iterated. If this state is possible (e.g., partially written rows), add a guard (e.g., default to [] or validate that at least one is populated) and handle the invalid cache state (skip cache or emit an error SSE).

]
yield _format_sse_event("done", document_id, cached_paragraphs)
return

# Live streaming path
keepalive_task: asyncio.Task | None = None
keepalive_queue: asyncio.Queue[str] = asyncio.Queue()
interval = settings.TRANSCRIBE_SSE_KEEPALIVE_SECONDS

async def _keepalive_pump() -> None:
while True:
await asyncio.sleep(interval)
await keepalive_queue.put(": keepalive\n\n")

if interval > 0:
keepalive_task = asyncio.create_task(_keepalive_pump())

last_snapshot: list[ASRParagraph] = []
last_current_time: float | None = None
last_total_time: float | None = None
stream_iter = transcribe_audio_bytes_stream(data).__aiter__()

try:
Comment on lines +211 to +216
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: Iterator creation errors from transcribe_audio_bytes_stream bypass SSE-style error reporting

If transcribe_audio_bytes_stream fails during iterator creation (e.g., bad TRANSCRIBE_WS_URI), the exception escapes before the try/finally, so the client gets a plain 500 instead of an SSE error event. To keep SSE behavior consistent, wrap stream_iter construction in a try/except that yields an error via _format_error_event and then returns, mirroring the error handling used inside the loop.

Suggested change
last_snapshot: list[ASRParagraph] = []
stream_iter = transcribe_audio_bytes_stream(data).__aiter__()
try:
last_snapshot: list[ASRParagraph] = []
try:
stream_iter = transcribe_audio_bytes_stream(data).__aiter__()
except Exception as exc:
    yield _format_error_event(
        detail=str(exc), code="UPSTREAM_SERVICE_ERROR"
    )
    return
try:

# next_task is created once and reused across keepalive interruptions so
# that cancelling the keepalive_get task never aborts the in-flight
# __anext__() call.
next_task: asyncio.Task[ASRStreamChunk] = asyncio.create_task(
stream_iter.__anext__() # pyrefly: ignore[bad-argument-type]
)
while True:
keepalive_get = asyncio.create_task(keepalive_queue.get())

done, _ = await asyncio.wait(
{next_task, keepalive_get},
return_when=asyncio.FIRST_COMPLETED,
)

if keepalive_get in done:
yield keepalive_get.result()
else:
# keepalive_get lost the race - discard it cleanly
keepalive_get.cancel()
with contextlib.suppress(asyncio.CancelledError):
await keepalive_get

if next_task in done:
try:
chunk: ASRStreamChunk = next_task.result()
except StopAsyncIteration:
break
except RuntimeError as exc:
logger.error("upstream error during streaming: %s", exc)
yield _format_error_event(
detail=str(exc), code="UPSTREAM_SERVICE_ERROR"
)
return
except Exception:
logger.exception("unexpected error during streaming")
yield _format_error_event(
detail="Unexpected error during transcription",
code="INTERNAL_ERROR",
)
return

last_snapshot = chunk.paragraphs
last_current_time = chunk.current_time
last_total_time = chunk.total_time

# emit an SSE event only if transcription content is present in the chunk
paragraphs = [
paragraph
for paragraph in chunk.paragraphs
if paragraph.text.strip()
]
if paragraphs:
yield _format_sse_event(
"transcription",
document_id,
paragraphs,
current_time=chunk.current_time,
total_time=chunk.total_time,
)

# Advance to the next chunk only after the current one is consumed
next_task = asyncio.create_task(
stream_iter.__anext__() # pyrefly: ignore[bad-argument-type]
)
finally:
if keepalive_task is not None and not keepalive_task.done():
keepalive_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await keepalive_task
with contextlib.suppress(Exception):
await stream_iter.aclose() # type: ignore[attr-defined]

# Persist + done event
try:
audio_transcription_create_or_update(
transcription_id=document_id,
name=filename or str(document_id),
transcription=last_snapshot,
session=session,
)
logger.debug("Audio transcription stored in DB for %s", filename)
except Exception:
logger.exception("failed to persist transcription; continuing")

yield _format_sse_event(
"done",
document_id,
paragraphs=[p for p in last_snapshot if p.text.strip()],
current_time=last_current_time,
total_time=last_total_time,
)

return StreamingResponse(
_event_stream(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"X-Accel-Buffering": "no",
},
)


@router.get("/validation/document/{document_id}")
async def asr_read_document_validation(
document_id: UUID5,
session: Session = Depends(get_session),
) -> ASRDocument | None:
"""
Retrieves the validation document for a given document ID.
"""Retrieve the validation document for a given document ID.

Args:
document_id (UUID5): The ID of the document to retrieve.
session (Session, optional): The database session. Defaults to Depends(get_session).

document_id: The ID of the document to retrieve.
session: The database session.

Raises:
NotFoundError: If the document with the given ID is not found.

Returns:
ASRDocument | None: The validation document if found, otherwise None.
The validation document if found, otherwise None.
"""
record = audio_transcription_get(transcription_id=document_id, session=session)
if not record:
Expand All @@ -165,13 +349,12 @@ async def asr_save_document_validation(
annotations: list[ASRParagraphRequest] = Body(...),
session: Session = Depends(get_session),
) -> None:
"""
Saves the validation annotations for a given document ID.
"""Save validation annotations for a given document ID.

Args:
document_id (UUID5): The ID of the document to validate.
annotations (list[ASRParagraphRequest], optional): The list of annotations for the document. Defaults to Body(...).
session (Session, optional): The database session. Defaults to Depends(get_session).
document_id: The ID of the document to validate.
annotations: The list of annotations for the document.
session: The database session.

Raises:
NotFoundError: If the document with the given ID is not found.
Expand Down
Loading