139 changes: 103 additions & 36 deletions server/app.py
@@ -4,6 +4,7 @@
import logging
import os
import sys
import time

import torch

@@ -24,9 +25,8 @@
from aiortc.rtcrtpsender import RTCRtpSender
from pipeline import Pipeline
from twilio.rest import Client
from utils import patch_loop_datagram, add_prefix_to_app_routes, FPSMeter
from metrics import MetricsManager, StreamStatsManager
import time
from utils import patch_loop_datagram, add_prefix_to_app_routes
from metrics import MetricsManager, StreamStatsManager, TrackStats

logger = logging.getLogger(__name__)
logging.getLogger("aiortc.rtcrtpsender").setLevel(logging.WARNING)
@@ -38,7 +38,7 @@


class VideoStreamTrack(MediaStreamTrack):
"""video stream track that processes video frames using a pipeline.
"""Video stream track that processes video frames using a pipeline.

Attributes:
kind (str): The kind of media, which is "video" for this class.
@@ -54,17 +54,22 @@ def __init__(self, track: MediaStreamTrack, pipeline: Pipeline):
Args:
track: The underlying media stream track.
pipeline: The processing pipeline to apply to each video frame.
"""
self._start_time = time.monotonic()
super().__init__()
self.track = track
self.pipeline = pipeline
self.fps_meter = FPSMeter(
metrics_manager=app["metrics_manager"], track_id=track.id
self.stats = TrackStats(
track_id=track.id,
track_kind="video",
metrics_manager=app.get("metrics_manager", None),
)
self.running = True
self.collect_task = asyncio.create_task(self.collect_frames())

# Add cleanup when track ends
self._running = True

asyncio.create_task(self.collect_frames())

# Add cleanup when track ends.
@track.on("ended")
async def on_ended():
logger.info("Source video track ended, stopping collection")
@@ -75,7 +80,7 @@ async def collect_frames(self):
the processing pipeline. Stops when track ends or connection closes.
"""
try:
while self.running:
while self._running:
try:
frame = await self.track.recv()
await self.pipeline.put_video_frame(frame)
@@ -87,9 +92,9 @@ async def collect_frames(self):
logger.info("Media stream ended")
else:
logger.error(f"Error collecting video frames: {str(e)}")
self.running = False
self._running = False
break

# Perform cleanup outside the exception handler
logger.info("Video frame collection stopped")
except asyncio.CancelledError:
@@ -100,28 +105,57 @@
await self.pipeline.cleanup()

async def recv(self):
"""Receive a processed video frame from the pipeline, increment the frame
count for FPS calculation and return the processed frame to the client.
"""Receive a processed video frame from the pipeline and return it to the
client, while collecting statistics about the stream.
"""
if self.stats.startup_time is None:
self.stats.start_timestamp = time.monotonic()
self.stats.startup_time = self.stats.start_timestamp - self._start_time
self.stats.pipeline.video_warmup_time = (
self.pipeline.stats.video_warmup_time
)

processed_frame = await self.pipeline.get_processed_video_frame()

# Increment the frame count to calculate FPS.
await self.fps_meter.increment_frame_count()
await self.stats.fps_meter.increment_frame_count()

return processed_frame


class AudioStreamTrack(MediaStreamTrack):
"""Audio stream track that processes audio frames using a pipeline.

Attributes:
kind (str): The kind of media, which is "audio" for this class.
track (MediaStreamTrack): The underlying media stream track.
pipeline (Pipeline): The processing pipeline to apply to each audio frame.
"""

kind = "audio"

def __init__(self, track: MediaStreamTrack, pipeline):
"""Initialize the AudioStreamTrack.

Args:
track: The underlying media stream track.
pipeline: The processing pipeline to apply to each audio frame.
"""
self._start_time = time.monotonic()
super().__init__()
self.track = track
self.pipeline = pipeline
self.running = True
self.collect_task = asyncio.create_task(self.collect_frames())

# Add cleanup when track ends
self.stats = TrackStats(
track_id=track.id,
track_kind="audio",
metrics_manager=app.get("metrics_manager", None),
)
self._running = True

asyncio.create_task(self.collect_frames())

# Add cleanup when track ends.
@track.on("ended")
async def on_ended():
logger.info("Source audio track ended, stopping collection")
@@ -132,7 +166,7 @@ async def collect_frames(self):
the processing pipeline. Stops when track ends or connection closes.
"""
try:
while self.running:
while self._running:
try:
frame = await self.track.recv()
await self.pipeline.put_audio_frame(frame)
@@ -144,9 +178,9 @@ async def collect_frames(self):
logger.info("Media stream ended")
else:
logger.error(f"Error collecting audio frames: {str(e)}")
self.running = False
self._running = False
break

# Perform cleanup outside the exception handler
logger.info("Audio frame collection stopped")
except asyncio.CancelledError:
@@ -157,7 +191,22 @@ async def collect_frames(self):
await self.pipeline.cleanup()

async def recv(self):
return await self.pipeline.get_processed_audio_frame()
"""Receive a processed audio frame from the pipeline and return it to the
client, while collecting statistics about the stream.
"""
if self.stats.startup_time is None:
self.stats.start_timestamp = time.monotonic()
self.stats.startup_time = self.stats.start_timestamp - self._start_time
self.stats.pipeline.audio_warmup_time = (
self.pipeline.stats.audio_warmup_time
)

processed_frame = await self.pipeline.get_processed_audio_frame()

# Increment the frame count to calculate FPS.
await self.stats.fps_meter.increment_frame_count()

return processed_frame


def force_codec(pc, sender, forced_codec):
@@ -276,8 +325,8 @@ def on_track(track):
sender = pc.addTrack(videoTrack)

# Store video track in app for stats.
stream_id = track.id
request.app["video_tracks"][stream_id] = videoTrack
track_id = track.id
request.app["video_tracks"][track_id] = videoTrack

codec = "video/H264"
force_codec(pc, sender, codec)
@@ -286,10 +335,15 @@ def on_track(track):
tracks["audio"] = audioTrack
pc.addTrack(audioTrack)

# Store audio track in app for stats.
track_id = track.id
request.app["audio_tracks"][track_id] = audioTrack

@track.on("ended")
async def on_ended():
logger.info(f"{track.kind} track ended")
request.app["video_tracks"].pop(track.id, None)
request.app["audio_tracks"].pop(track.id, None)

@pc.on("connectionstatechange")
async def on_connectionstatechange():
@@ -318,15 +372,17 @@ async def on_connectionstatechange():
),
)


async def cancel_collect_frames(track):
track.running = False
if hasattr(track, 'collect_task') is not None and not track.collect_task.done():
if hasattr(track, "collect_task") is not None and not track.collect_task.done():
try:
track.collect_task.cancel()
await track.collect_task
except (asyncio.CancelledError):
except asyncio.CancelledError:
pass


async def set_prompt(request):
pipeline = request.app["pipeline"]

@@ -345,10 +401,14 @@ async def on_startup(app: web.Application):
patch_loop_datagram(app["media_ports"])

app["pipeline"] = Pipeline(
cwd=app["workspace"], disable_cuda_malloc=True, gpu_only=True, preview_method='none'
cwd=app["workspace"],
disable_cuda_malloc=True,
gpu_only=True,
preview_method="none",
)
app["pcs"] = set()
app["video_tracks"] = {}
app["audio_tracks"] = {}


async def on_shutdown(app: web.Application):
Expand Down Expand Up @@ -381,10 +441,16 @@ async def on_shutdown(app: web.Application):
help="Start a Prometheus metrics endpoint for monitoring.",
)
parser.add_argument(
"--stream-id-label",
"--track-id-label",
default=False,
action="store_true",
help="Include stream ID as a label in Prometheus metrics.",
help="Include track ID in Prometheus metrics.",
)
parser.add_argument(
"--track-kind-label",
default=False,
action="store_true",
help="Include track kind in Prometheus metrics.",
)
args = parser.parse_args()

@@ -409,16 +475,17 @@ async def on_shutdown(app: web.Application):
app.router.add_post("/prompt", set_prompt)

# Add routes for getting stream statistics.
# TODO: Tracks are currently treated as streams (track_id = stream_id).
stream_stats_manager = StreamStatsManager(app)
app.router.add_get("/streams/stats", stream_stats_manager.collect_all_stream_stats)
app.router.add_get(
"/streams/stats", stream_stats_manager.collect_all_stream_metrics
)
app.router.add_get(
"/stream/{stream_id}/stats", stream_stats_manager.collect_stream_metrics_by_id
"/stream/{track_id}/stats", stream_stats_manager.collect_stream_stats_by_id
)

# Add Prometheus metrics endpoint.
app["metrics_manager"] = MetricsManager(include_stream_id=args.stream_id_label)
app["metrics_manager"] = MetricsManager(
include_track_id=args.track_id_label, include_track_kind=args.track_kind_label
)
if args.monitor:
app["metrics_manager"].enable()
logger.info(
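The two routes registered above expose per-track statistics over HTTP. A minimal client sketch of how they might be queried, assuming the server listens on localhost:8889 (host and port are assumptions, not taken from this diff) and that both endpoints return JSON:

import asyncio
import aiohttp

async def fetch_stats(base_url: str = "http://localhost:8889") -> dict:
    # /streams/stats aggregates every active track; per-track stats live at
    # /stream/{track_id}/stats (paths taken from the routes registered above).
    async with aiohttp.ClientSession() as session:
        async with session.get(f"{base_url}/streams/stats") as resp:
            return await resp.json()

if __name__ == "__main__":
    print(asyncio.run(fetch_stats()))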
4 changes: 3 additions & 1 deletion server/metrics/__init__.py
@@ -1,2 +1,4 @@
from .pipeline_stats import PipelineStats
from .prometheus_metrics import MetricsManager
from .stream_stats import StreamStatsManager
from .stream_stats_manager import StreamStatsManager
from .track_stats import TrackStats
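server/metrics/track_stats.py is added by this PR but not expanded in this view. A hypothetical sketch of TrackStats, reconstructed only from how the class is used in server/app.py above; every field name and default here is an assumption, not the real definition:

from typing import Optional

from .pipeline_stats import PipelineStats


class TrackStats:
    # Sketch only: the real class also wires up an async FPS meter
    # (stats.fps_meter.increment_frame_count()) and Prometheus reporting.
    def __init__(self, track_id: str, track_kind: str, metrics_manager=None):
        self.track_id = track_id
        self.track_kind = track_kind
        self.metrics_manager = metrics_manager
        self.startup_time: Optional[float] = None      # set on the first processed frame
        self.start_timestamp: Optional[float] = None   # monotonic time of that first frame
        self.pipeline = PipelineStats(
            metrics_manager=metrics_manager, track_id=track_id
        )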
61 changes: 61 additions & 0 deletions server/metrics/pipeline_stats.py
@@ -0,0 +1,61 @@
"""Module for handling real-time media pipeline statistics."""

from typing import Optional, Dict, Any
from .prometheus_metrics import MetricsManager


class PipelineStats:
"""Tracks real-time statistics of the media pipeline.

Attributes:
metrics_manager: The Prometheus metrics manager instance.
track_id: The ID of the associated media track.
"""

def __init__(
self,
metrics_manager: Optional[MetricsManager] = None,
track_id: Optional[str] = None,
):
"""Initializes the PipelineStats class.

Args:
metrics_manager: The Prometheus metrics manager instance.
track_id: The ID of the associated media track.
"""
self.metrics_manager = metrics_manager
self.track_id = track_id

self._video_warmup_time = None
self._audio_warmup_time = None

@property
def video_warmup_time(self) -> float:
"""Time taken to warm up the video pipeline."""
return self._video_warmup_time

@video_warmup_time.setter
def video_warmup_time(self, value: float):
"""Sets the time taken to warm up the video pipeline."""
self._video_warmup_time = value
if self.metrics_manager:
self.metrics_manager.update_warmup_time(value, self.track_id)

@property
def audio_warmup_time(self) -> float:
"""Time taken to warm up the audio pipeline."""
return self._audio_warmup_time

@audio_warmup_time.setter
def audio_warmup_time(self, value: float):
"""Sets the time taken to warm up the audio pipeline."""
self._audio_warmup_time = value
if self.metrics_manager:
self.metrics_manager.update_warmup_time(value, self.track_id)

def to_dict(self) -> Dict[str, Any]:
"""Convert stats to a dictionary for easy JSON serialization."""
return {
"video_warmup_time": self._video_warmup_time,
"audio_warmup_time": self._audio_warmup_time,
}
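As a quick illustration of the new class, a minimal usage sketch with no MetricsManager attached (so the setters only store the values locally; the numbers are made up):

from metrics import PipelineStats

stats = PipelineStats(metrics_manager=None, track_id="example-track")
stats.video_warmup_time = 1.25  # seconds; forwarded to Prometheus only when a metrics manager is set
stats.audio_warmup_time = 0.80
print(stats.to_dict())
# -> {'video_warmup_time': 1.25, 'audio_warmup_time': 0.8}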