Skip to content
Draft
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
8fc1f72
chore: ignore .worktrees/ directory
jedzill4 Apr 21, 2026
bc57c25
feat(asr): add TRANSCRIBE_SSE_KEEPALIVE_SECONDS setting
jedzill4 Apr 21, 2026
7e556d0
refactor(asr): extract lines_to_paragraphs helper
jedzill4 Apr 21, 2026
fcdfbc6
feat(asr): add transcribe_audio_bytes_stream async generator
jedzill4 Apr 21, 2026
0f1b1fa
feat(asr): add SSE event formatting helpers
jedzill4 Apr 21, 2026
f0c6f32
feat(asr): add POST /asr/transcribe/stream SSE endpoint
jedzill4 Apr 21, 2026
235f3c8
test(asr): verify SSE stream emits only done event on cache hit
jedzill4 Apr 21, 2026
04f1e01
test(asr): verify deprecated flag on POST /asr/transcribe
jedzill4 Apr 21, 2026
4aec6af
test(asr): verify SSE stream emits error event on upstream failure
jedzill4 Apr 21, 2026
dd78a17
feat(asr): add lead-budget backpressure pacer to prevent 1011 keepali…
jedzill4 Apr 22, 2026
78e6174
refactor(asr): clean up dead code, unify decode, improve consistency
jedzill4 Apr 22, 2026
a2a342b
fix(asr): preserve RuntimeError messages and normalize docstrings
jedzill4 Apr 22, 2026
6b868db
refactor(asr): extract shared helpers, add 5 missing test cases
jedzill4 Apr 22, 2026
c549cee
fix(asr): simplify backpressure handling in _stream_audio_bytes function
jedzill4 Apr 22, 2026
433b0a6
refactor(asr): remove dead backpressure implementation
jedzill4 Apr 22, 2026
9896310
chore(asr): enhance audio processing settings and improve SSE event e…
jedzill4 Apr 22, 2026
af63e29
fix(asr): exclude empty text paragraphs from transcribe response
jedzill4 Apr 23, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,9 @@ resources
.venv

aymurai/version.py

# Git worktrees
.worktrees/

# Local-only superpowers specs and plans (never committed)
docs/superpowers/
179 changes: 167 additions & 12 deletions aymurai/api/endpoints/routers/asr/transcribe.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
import asyncio
import contextlib
import json
from typing import AsyncGenerator
from uuid import UUID

from fastapi import Body, Depends, UploadFile
from fastapi.responses import StreamingResponse
from fastapi.routing import APIRouter
from pydantic import UUID5
from sqlmodel import Session
Expand All @@ -9,7 +16,11 @@
NotFoundError,
UpstreamServiceError,
)
from aymurai.audio.asr_client import transcribe_audio_bytes
from aymurai.audio.asr_client import (
lines_to_paragraphs,
transcribe_audio_bytes,
transcribe_audio_bytes_stream,
)
from aymurai.database.crud.audio_transcription import (
audio_transcription_create_or_update,
audio_transcription_get,
Expand All @@ -20,6 +31,35 @@
from aymurai.meta.api_interfaces import ASRDocument, ASRParagraph, ASRParagraphRequest
from aymurai.settings import settings


def _format_transcription_event(
    document_id: UUID,
    paragraphs: list[ASRParagraph],
) -> str:
    """
    Build the SSE frame for an intermediate transcription update.

    Args:
        document_id (UUID): Identifier placed in the ASRDocument payload.
        paragraphs (list[ASRParagraph]): Paragraphs placed in the ASRDocument payload.

    Returns:
        str: An `event: transcription` frame whose data line is the JSON dump
            of the ASRDocument, terminated by a blank line.
    """
    document = ASRDocument(document_id=document_id, document=paragraphs)
    body = document.model_dump_json()
    return "event: transcription\n" + f"data: {body}\n\n"


def _format_done_event(
    document_id: UUID,
    paragraphs: list[ASRParagraph],
) -> str:
    """
    Build the SSE frame that closes a transcription stream.

    Args:
        document_id (UUID): Identifier placed in the ASRDocument payload.
        paragraphs (list[ASRParagraph]): Paragraphs placed in the ASRDocument payload.

    Returns:
        str: An `event: done` frame whose data line is the JSON dump of the
            ASRDocument, terminated by a blank line.
    """
    document = ASRDocument(document_id=document_id, document=paragraphs)
    body = document.model_dump_json()
    return "event: done\n" + f"data: {body}\n\n"


def _format_error_event(detail: str, code: str) -> str:
"""Format an error SSE event."""
payload = json.dumps({"detail": detail, "code": code})
return f"event: error\ndata: {payload}\n\n"


router = APIRouter()
logger = get_logger(__name__)

Expand Down Expand Up @@ -71,19 +111,14 @@ async def _transcribe_audio_bytes_with_error_handling(
if not status:
raise AymuraiAPIException(detail="No transcription result received")

return [
ASRParagraph(
speaker_no=line.speaker,
speaker_id=f"speaker-{line.speaker}",
start=line.start,
end=line.end,
text=line.text,
)
for line in status.lines
]
return lines_to_paragraphs(status.lines)


@router.post("/transcribe", response_model=ASRDocument)
@router.post(
"/transcribe",
response_model=ASRDocument,
deprecated=True,
)
async def transcribe(
file: UploadFile,
use_cache: bool = True,
Expand Down Expand Up @@ -130,6 +165,126 @@ async def transcribe(
return document


@router.post("/transcribe/stream")
async def transcribe_stream(
    file: UploadFile,
    use_cache: bool = True,
    ws_uri: str = Depends(get_transcribe_ws_uri),
    session: Session = Depends(get_session),
) -> StreamingResponse:
    """
    Transcribe an uploaded audio file and stream intermediate results as SSE.

    Emits one `event: transcription` frame per upstream update, a final
    `event: done` frame with the last snapshot, and a single `event: error`
    frame (without `done`) when the upstream stream fails. While the upstream
    is silent for TRANSCRIBE_SSE_KEEPALIVE_SECONDS, a `: keepalive` comment
    frame is emitted; a non-positive setting disables keepalives.

    The cache is checked first when `use_cache` is true; on a cache miss the
    final snapshot is persisted before `done` is emitted.

    Args:
        file (UploadFile): Uploaded audio file; its bytes determine the document id.
        use_cache (bool): Whether to answer from a stored transcription if one exists.
        ws_uri (str): Resolved by the `get_transcribe_ws_uri` dependency. It is
            not read in this function (the stream helper reads the setting
            itself). NOTE(review): presumably kept so the dependency can
            validate configuration up front - confirm.
        session (Session): Database session used for cache lookup and persistence.

    Returns:
        StreamingResponse: A `text/event-stream` response backed by the event generator.
    """
    data = await file.read()
    filename = file.filename
    document_id = data_to_uuid(data)

    async def _event_stream() -> AsyncGenerator[str, None]:
        # Cache check
        if use_cache:
            cached = audio_transcription_get(
                transcription_id=document_id, session=session
            )
            if cached is not None:
                stored = cached.validation or cached.transcription
                # Guard against rows where neither field is usable: iterating
                # None would raise TypeError in the middle of the response.
                if stored is not None:
                    logger.debug(f"Audio transcription DB hit for {filename}")
                    yield _format_done_event(
                        document_id,
                        [ASRParagraph.model_validate(p) for p in stored],
                    )
                    return
                logger.warning(
                    "cached transcription %s has no stored paragraphs; "
                    "transcribing again",
                    document_id,
                )

        # Live streaming path
        interval = settings.TRANSCRIBE_SSE_KEEPALIVE_SECONDS
        idle_timeout = interval if interval > 0 else None

        last_snapshot: list[ASRParagraph] = []
        # Calling an async generator function runs none of its body, so
        # configuration or connection errors surface on the first __anext__()
        # below, where they are reported as SSE error events.
        stream_iter = transcribe_audio_bytes_stream(data).__aiter__()
        next_task: asyncio.Task | None = None

        try:
            while True:
                # Keep ONE pending __anext__() alive across keepalive ticks.
                # Cancelling it would throw CancelledError into the upstream
                # generator and finalize it, ending the stream prematurely.
                if next_task is None:
                    next_task = asyncio.create_task(stream_iter.__anext__())

                # asyncio.wait() does not cancel the task when it times out.
                done, _ = await asyncio.wait({next_task}, timeout=idle_timeout)
                if not done:
                    yield ": keepalive\n\n"
                    continue

                finished, next_task = next_task, None
                try:
                    snapshot = finished.result()
                except StopAsyncIteration:
                    break
                except RuntimeError as exc:
                    logger.error("upstream error during streaming: %s", exc)
                    yield _format_error_event(
                        detail=str(exc), code="UPSTREAM_SERVICE_ERROR"
                    )
                    return
                except Exception:
                    logger.exception("unexpected error during streaming")
                    yield _format_error_event(
                        detail="Unexpected error during transcription",
                        code="INTERNAL_ERROR",
                    )
                    return
                last_snapshot = snapshot
                yield _format_transcription_event(document_id, snapshot)
        finally:
            # Runs on normal completion, on error, and when the client
            # disconnects. Stop the in-flight read first so the upstream
            # generator is no longer running when aclose() is called.
            if next_task is not None:
                if not next_task.done():
                    next_task.cancel()
                with contextlib.suppress(asyncio.CancelledError, Exception):
                    await next_task
            with contextlib.suppress(Exception):
                await stream_iter.aclose()  # type: ignore[attr-defined]

        # Persist + done event
        try:
            audio_transcription_create_or_update(
                transcription_id=document_id,
                name=filename or str(document_id),
                transcription=last_snapshot,
                session=session,
            )
            logger.debug(f"Audio transcription stored in DB for {filename}")
        except Exception:
            # Persistence is best-effort: the client still gets the result.
            logger.exception("failed to persist transcription; continuing")

        yield _format_done_event(document_id, last_snapshot)

    return StreamingResponse(
        _event_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
        },
    )


@router.get("/validation/document/{document_id}")
async def asr_read_document_validation(
document_id: UUID5,
Expand Down
102 changes: 102 additions & 0 deletions aymurai/audio/asr_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import io
import json
from pathlib import Path
from typing import AsyncGenerator

import librosa
import numpy as np
Expand All @@ -13,8 +14,10 @@
WLKMessageRawResponse,
WLKMessageReadyToStopMessage,
WLKMessageStatus,
WLKMessageTranscriptionLine,
)
from aymurai.logger import get_logger
from aymurai.meta.api_interfaces import ASRParagraph
from aymurai.settings import settings

logger = get_logger(__name__)
Expand All @@ -26,6 +29,29 @@
ASR_RAW_RESPONSE_ADAPTER = TypeAdapter(WLKMessageRawResponse)


def lines_to_paragraphs(
    lines: list[WLKMessageTranscriptionLine],
) -> list[ASRParagraph]:
    """
    Map WebSocket transcription lines to ASRParagraph objects.

    Each paragraph carries both the numeric speaker and the derived
    `speaker-<n>` identifier, matching the mapping the transcribe endpoint
    performed inline before this helper was extracted, so the response shape
    stays the same for existing consumers.

    Args:
        lines (list[WLKMessageTranscriptionLine]): Transcription lines from the ASR service.

    Returns:
        list[ASRParagraph]: One ASRParagraph per input line, in input order.
    """
    return [
        ASRParagraph(
            speaker_no=line.speaker,
            speaker_id=f"speaker-{line.speaker}",
            start=line.start,
            end=line.end,
            text=line.text,
        )
        for line in lines
    ]


async def _stream_audio_bytes(
payload: bytes,
websocket: websockets.ClientConnection,
Expand Down Expand Up @@ -171,6 +197,82 @@ async def transcribe_audio_bytes(payload: bytes) -> WLKMessageStatus | None:
return last_active_transcription


async def transcribe_audio_bytes_stream(
    payload: bytes,
) -> AsyncGenerator[list[ASRParagraph], None]:
    """
    Yield paragraph snapshots as the ASR WebSocket service reports them.

    The audio is uploaded by a background task while this generator reads
    replies. Every `active_transcription` status message produces one yielded
    list built from that message's lines. Iteration ends on a ready-to-stop
    message or when the connection closes normally.

    Because this is an async generator, nothing runs (and nothing is raised)
    until the first item is requested.

    Args:
        payload (bytes): The audio data to be transcribed.

    Raises:
        RuntimeError: If TRANSCRIBE_WS_URI is not configured, or the websocket
            fails while connecting or receiving.

    Yields:
        list[ASRParagraph]: Paragraphs for the latest transcription update.
    """
    uri = settings.TRANSCRIBE_WS_URI
    if not uri:
        raise RuntimeError("TRANSCRIBE_WS_URI is not configured")

    logger.info("streaming audio for transcription (sse)")

    sender: asyncio.Task | None = None
    try:
        async with websockets.connect(uri) as websocket:
            # The upload runs concurrently with the receive loop below.
            sender = asyncio.create_task(_stream_and_signal_end(payload, websocket))
            while True:
                try:
                    raw = await websocket.recv()
                except websockets.exceptions.ConnectionClosedOK:
                    logger.info("connection closed normally")
                    break
                except websockets.exceptions.WebSocketException as exc:
                    logger.error("websocket error while receiving: %s", exc)
                    raise RuntimeError("Transcription service websocket error") from exc

                parsed = _parse_ws_message(raw)
                if parsed is None:
                    continue
                if (
                    isinstance(parsed, WLKMessageStatus)
                    and parsed.status == "active_transcription"
                ):
                    yield lines_to_paragraphs(parsed.lines)
                elif isinstance(parsed, WLKMessageReadyToStopMessage):
                    break
                # Any other message kind is ignored.
    except websockets.exceptions.WebSocketException as exc:
        logger.error("websocket error during transcription: %s", exc)
        raise RuntimeError("Transcription service websocket error") from exc
    finally:
        if sender is not None:
            if not sender.done():
                sender.cancel()
            # Await in every case so a failure inside the upload task is logged.
            try:
                await sender
            except asyncio.CancelledError:
                pass  # expected when we cancelled it
            except Exception as exc:
                logger.error("audio streaming task failed: %s", exc)


async def _stream_and_signal_end(
    payload: bytes,
    websocket: "websockets.ClientConnection",
) -> None:
    """
    Upload the audio payload, then send the empty end-of-stream marker.

    Args:
        payload (bytes): The audio data to upload.
        websocket (websockets.ClientConnection): Open connection to the transcription service.
    """
    sent = await _stream_audio_bytes(payload, websocket)
    # An empty binary message marks the end of the audio.
    await websocket.send(b"")
    logger.info("sent %s bytes to transcription service", sent)


def transcribe_audio_path(path: Path) -> WLKMessageStatus | None:
"""
Transcribes an audio file at the given path by reading its bytes and sending them to the transcription service.
Expand Down
1 change: 1 addition & 0 deletions aymurai/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def assemble_cors_origins(cls, v) -> list[str]:
# ASR Config
##########################################################################
TRANSCRIBE_WS_URI: str | None = None
TRANSCRIBE_SSE_KEEPALIVE_SECONDS: int = 15

##########################################################################
# Disambiguation Config
Expand Down
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,10 @@ dev = [
"dspy>=3.0.4",
"playwright==1.56.0",
]
tests = ["pytest>=9.0.2"]
tests = [
"pytest>=9.0.2",
"pytest-asyncio>=1.3.0",
]

[tool.setuptools.packages.find]
include = ["aymurai"]
Expand All @@ -134,6 +137,7 @@ disable = "C0330, C0326"

[tool.pytest.ini_options]
# addopts = "-m 'not integration and not slow'"
asyncio_mode = "auto"
markers = [
"integration: marks tests that exercise real external integrations",
"slow: marks tests that are expected to take longer to run",
Expand Down
Loading