diff --git a/server/app.py b/server/app.py
index e1e9150f..0036eae0 100644
--- a/server/app.py
+++ b/server/app.py
@@ -4,6 +4,7 @@
 import logging
 import os
 import sys
+import time
 
 import torch
 
@@ -24,9 +25,8 @@
 from aiortc.rtcrtpsender import RTCRtpSender
 from pipeline import Pipeline
 from twilio.rest import Client
-from utils import patch_loop_datagram, add_prefix_to_app_routes, FPSMeter
-from metrics import MetricsManager, StreamStatsManager
-import time
+from utils import patch_loop_datagram, add_prefix_to_app_routes
+from metrics import MetricsManager, StreamStatsManager, TrackStats
 
 logger = logging.getLogger(__name__)
 logging.getLogger("aiortc.rtcrtpsender").setLevel(logging.WARNING)
@@ -38,7 +38,7 @@
 
 
 class VideoStreamTrack(MediaStreamTrack):
-    """video stream track that processes video frames using a pipeline.
+    """Video stream track that processes video frames using a pipeline.
 
     Attributes:
         kind (str): The kind of media, which is "video" for this class.
@@ -54,17 +54,22 @@ def __init__(self, track: MediaStreamTrack, pipeline: Pipeline):
         Args:
             track: The underlying media stream track.
             pipeline: The processing pipeline to apply to each video frame.
         """
+        self._start_time = time.monotonic()
         super().__init__()
         self.track = track
         self.pipeline = pipeline
-        self.fps_meter = FPSMeter(
-            metrics_manager=app["metrics_manager"], track_id=track.id
+        self.stats = TrackStats(
+            track_id=track.id,
+            track_kind="video",
+            metrics_manager=app.get("metrics_manager", None),
         )
-        self.running = True
-        self.collect_task = asyncio.create_task(self.collect_frames())
-
-        # Add cleanup when track ends
+        self._running = True
+
+        asyncio.create_task(self.collect_frames())
+
+        # Add cleanup when track ends.
         @track.on("ended")
         async def on_ended():
             logger.info("Source video track ended, stopping collection")
@@ -75,7 +80,7 @@ async def collect_frames(self):
         the processing pipeline. Stops when track ends or connection closes.
         """
         try:
-            while self.running:
+            while self._running:
                 try:
                     frame = await self.track.recv()
                     await self.pipeline.put_video_frame(frame)
@@ -87,9 +92,9 @@
                         logger.info("Media stream ended")
                     else:
                         logger.error(f"Error collecting video frames: {str(e)}")
-                    self.running = False
+                    self._running = False
                     break
-
+
             # Perform cleanup outside the exception handler
             logger.info("Video frame collection stopped")
         except asyncio.CancelledError:
@@ -100,28 +105,57 @@
             await self.pipeline.cleanup()
 
     async def recv(self):
-        """Receive a processed video frame from the pipeline, increment the frame
-        count for FPS calculation and return the processed frame to the client.
+        """Receive a processed video frame from the pipeline and return it to the
+        client, while collecting statistics about the stream.
         """
+        if self.stats.startup_time is None:
+            self.stats.start_timestamp = time.monotonic()
+            self.stats.startup_time = self.stats.start_timestamp - self._start_time
+            self.stats.pipeline.video_warmup_time = (
+                self.pipeline.stats.video_warmup_time
+            )
+
         processed_frame = await self.pipeline.get_processed_video_frame()
 
         # Increment the frame count to calculate FPS.
-        await self.fps_meter.increment_frame_count()
+        await self.stats.fps_meter.increment_frame_count()
 
         return processed_frame
 
 
 class AudioStreamTrack(MediaStreamTrack):
+    """Audio stream track that processes audio frames using a pipeline.
+
+    Attributes:
+        kind (str): The kind of media, which is "audio" for this class.
+        track (MediaStreamTrack): The underlying media stream track.
+        pipeline (Pipeline): The processing pipeline to apply to each audio frame.
+    """
+
     kind = "audio"
 
     def __init__(self, track: MediaStreamTrack, pipeline):
+        """Initialize the AudioStreamTrack.
+
+        Args:
+            track: The underlying media stream track.
+            pipeline: The processing pipeline to apply to each audio frame.
+        """
+        self._start_time = time.monotonic()
         super().__init__()
         self.track = track
         self.pipeline = pipeline
-        self.running = True
-        self.collect_task = asyncio.create_task(self.collect_frames())
-
-        # Add cleanup when track ends
+        self.stats = TrackStats(
+            track_id=track.id,
+            track_kind="audio",
+            metrics_manager=app.get("metrics_manager", None),
+        )
+        self._running = True
+
+        asyncio.create_task(self.collect_frames())
+
+        # Add cleanup when track ends.
         @track.on("ended")
         async def on_ended():
             logger.info("Source audio track ended, stopping collection")
@@ -132,7 +166,7 @@ async def collect_frames(self):
         the processing pipeline. Stops when track ends or connection closes.
         """
         try:
-            while self.running:
+            while self._running:
                 try:
                     frame = await self.track.recv()
                     await self.pipeline.put_audio_frame(frame)
@@ -144,9 +178,9 @@
                         logger.info("Media stream ended")
                     else:
                         logger.error(f"Error collecting audio frames: {str(e)}")
-                    self.running = False
+                    self._running = False
                     break
-
+
             # Perform cleanup outside the exception handler
             logger.info("Audio frame collection stopped")
         except asyncio.CancelledError:
@@ -157,7 +191,22 @@
             await self.pipeline.cleanup()
 
     async def recv(self):
-        return await self.pipeline.get_processed_audio_frame()
+        """Receive a processed audio frame from the pipeline and return it to the
+        client, while collecting statistics about the stream.
+        """
+        if self.stats.startup_time is None:
+            self.stats.start_timestamp = time.monotonic()
+            self.stats.startup_time = self.stats.start_timestamp - self._start_time
+            self.stats.pipeline.audio_warmup_time = (
+                self.pipeline.stats.audio_warmup_time
+            )
+
+        processed_frame = await self.pipeline.get_processed_audio_frame()
+
+        # Increment the frame count to calculate FPS.
+        await self.stats.fps_meter.increment_frame_count()
+
+        return processed_frame
 
 
 def force_codec(pc, sender, forced_codec):
@@ -276,8 +325,8 @@ def on_track(track):
             sender = pc.addTrack(videoTrack)
 
             # Store video track in app for stats.
-            stream_id = track.id
-            request.app["video_tracks"][stream_id] = videoTrack
+            track_id = track.id
+            request.app["video_tracks"][track_id] = videoTrack
 
             codec = "video/H264"
             force_codec(pc, sender, codec)
@@ -286,10 +335,15 @@
             tracks["audio"] = audioTrack
             pc.addTrack(audioTrack)
 
+            # Store audio track in app for stats.
+            track_id = track.id
+            request.app["audio_tracks"][track_id] = audioTrack
+
         @track.on("ended")
         async def on_ended():
             logger.info(f"{track.kind} track ended")
             request.app["video_tracks"].pop(track.id, None)
+            request.app["audio_tracks"].pop(track.id, None)
 
     @pc.on("connectionstatechange")
     async def on_connectionstatechange():
@@ -318,15 +372,17 @@ async def on_connectionstatechange():
         ),
     )
 
+
 async def cancel_collect_frames(track):
     track.running = False
-    if hasattr(track, 'collect_task') is not None and not track.collect_task.done():
+    if hasattr(track, "collect_task") and not track.collect_task.done():
         try:
             track.collect_task.cancel()
             await track.collect_task
-        except (asyncio.CancelledError):
+        except asyncio.CancelledError:
             pass
 
+
 async def set_prompt(request):
     pipeline = request.app["pipeline"]
@@ -345,10 +401,14 @@ async def on_startup(app: web.Application):
         patch_loop_datagram(app["media_ports"])
 
     app["pipeline"] = Pipeline(
-        cwd=app["workspace"], disable_cuda_malloc=True, gpu_only=True, preview_method='none'
+        cwd=app["workspace"],
+        disable_cuda_malloc=True,
+        gpu_only=True,
+        preview_method="none",
     )
     app["pcs"] = set()
     app["video_tracks"] = {}
+    app["audio_tracks"] = {}
 
 
 async def on_shutdown(app: web.Application):
@@ -381,10 +441,16 @@
         help="Start a Prometheus metrics endpoint for monitoring.",
     )
     parser.add_argument(
-        "--stream-id-label",
+        "--track-id-label",
         default=False,
         action="store_true",
-        help="Include stream ID as a label in Prometheus metrics.",
+        help="Include track ID in Prometheus metrics.",
+    )
+    parser.add_argument(
+        "--track-kind-label",
+        default=False,
+        action="store_true",
+        help="Include track kind in Prometheus metrics.",
     )
 
     args = parser.parse_args()
@@ -409,16 +475,17 @@
     app.router.add_post("/prompt", set_prompt)
 
     # Add routes for getting stream statistics.
+    # TODO: Tracks are currently treated as streams (track_id = stream_id).
     stream_stats_manager = StreamStatsManager(app)
+    app.router.add_get("/streams/stats", stream_stats_manager.collect_all_stream_stats)
     app.router.add_get(
-        "/streams/stats", stream_stats_manager.collect_all_stream_metrics
-    )
-    app.router.add_get(
-        "/stream/{stream_id}/stats", stream_stats_manager.collect_stream_metrics_by_id
+        "/stream/{track_id}/stats", stream_stats_manager.collect_stream_stats_by_id
     )
 
     # Add Prometheus metrics endpoint.
-    app["metrics_manager"] = MetricsManager(include_stream_id=args.stream_id_label)
+    app["metrics_manager"] = MetricsManager(
+        include_track_id=args.track_id_label, include_track_kind=args.track_kind_label
+    )
     if args.monitor:
         app["metrics_manager"].enable()
         logger.info(
diff --git a/server/metrics/__init__.py b/server/metrics/__init__.py
index 5fb1a2ba..2b432d06 100644
--- a/server/metrics/__init__.py
+++ b/server/metrics/__init__.py
@@ -1,2 +1,4 @@
+from .pipeline_stats import PipelineStats
 from .prometheus_metrics import MetricsManager
-from .stream_stats import StreamStatsManager
+from .stream_stats_manager import StreamStatsManager
+from .track_stats import TrackStats
diff --git a/server/metrics/pipeline_stats.py b/server/metrics/pipeline_stats.py
new file mode 100644
index 00000000..bc44a70e
--- /dev/null
+++ b/server/metrics/pipeline_stats.py
@@ -0,0 +1,61 @@
+"""Module for handling real-time media pipeline statistics."""
+
+from typing import Optional, Dict, Any
+from .prometheus_metrics import MetricsManager
+
+
+class PipelineStats:
+    """Tracks real-time statistics of the media pipeline.
+
+    Attributes:
+        metrics_manager: The Prometheus metrics manager instance.
+        track_id: The ID of the associated media track.
+    """
+
+    def __init__(
+        self,
+        metrics_manager: Optional[MetricsManager] = None,
+        track_id: Optional[str] = None,
+    ):
+        """Initializes the PipelineStats class.
+
+        Args:
+            metrics_manager: The Prometheus metrics manager instance.
+            track_id: The ID of the associated media track.
+        """
+        self.metrics_manager = metrics_manager
+        self.track_id = track_id
+
+        self._video_warmup_time = None
+        self._audio_warmup_time = None
+
+    @property
+    def video_warmup_time(self) -> float:
+        """Time taken to warm up the video pipeline."""
+        return self._video_warmup_time
+
+    @video_warmup_time.setter
+    def video_warmup_time(self, value: float):
+        """Sets the time taken to warm up the video pipeline."""
+        self._video_warmup_time = value
+        if self.metrics_manager:
+            self.metrics_manager.update_warmup_time(value, self.track_id)
+
+    @property
+    def audio_warmup_time(self) -> float:
+        """Time taken to warm up the audio pipeline."""
+        return self._audio_warmup_time
+
+    @audio_warmup_time.setter
+    def audio_warmup_time(self, value: float):
+        """Sets the time taken to warm up the audio pipeline."""
+        self._audio_warmup_time = value
+        if self.metrics_manager:
+            self.metrics_manager.update_warmup_time(value, self.track_id)
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert stats to a dictionary for easy JSON serialization."""
+        return {
+            "video_warmup_time": self._video_warmup_time,
+            "audio_warmup_time": self._audio_warmup_time,
+        }
diff --git a/server/metrics/prometheus_metrics.py b/server/metrics/prometheus_metrics.py
index 080bc294..6cb2e627 100644
--- a/server/metrics/prometheus_metrics.py
+++ b/server/metrics/prometheus_metrics.py
@@ -1,4 +1,4 @@
-"""Prometheus metrics utilities."""
+"""Module for handling Prometheus metrics for media tracks."""
 
 from prometheus_client import Gauge, generate_latest
 from aiohttp import web
@@ -6,38 +6,136 @@
 class MetricsManager:
-    """Manages Prometheus metrics collection."""
+    """Manages Prometheus metrics collection for media tracks."""
 
-    def __init__(self, include_stream_id: bool = False):
+    def __init__(
+        self, include_track_id: bool = False, include_track_kind: bool = False
+    ):
         """Initializes the MetricsManager class.
 
         Args:
-            include_stream_id: Whether to include the stream ID as a label in the metrics.
+            include_track_id: Whether to include the track ID as a label.
+            include_track_kind: Whether to include the track kind as a label.
         """
         self._enabled = False
-        self._include_stream_id = include_stream_id
+        self._include_track_id = include_track_id
+        self._include_track_kind = include_track_kind
 
-        base_labels = ["stream_id"] if include_stream_id else []
-        self._fps_gauge = Gauge(
-            "stream_fps", "Frames per second of the stream", base_labels
-        )
+        base_labels = []
+        if include_track_id:
+            base_labels.append("track_id")
+        if include_track_kind:
+            base_labels.append("track_kind")
+
+        self._gauges = {
+            "fps": Gauge(
+                "stream_fps",
+                "Frames per second for the stream. Defaults to all tracks; specific "
+                "tracks when labels are applied.",
+                base_labels,
+            ),
+            "startup_time": Gauge(
+                "stream_startup_time",
+                "Startup time for the stream. Defaults to all tracks; specific tracks "
+                "when labels are applied.",
+                base_labels,
+            ),
+            "warmup_time": Gauge(
+                "stream_warmup_time",
+                "Warmup time for the stream pipeline. Defaults to all tracks; specific "
+                "tracks when labels are applied.",
+                base_labels,
+            ),
+        }
+
+    def _set_gauge(
+        self,
+        gauge: Gauge,
+        value: float,
+        track_id: Optional[str] = None,
+        track_kind: Optional[str] = None,
+    ):
+        """Set the value of a gauge metric with dynamic labels.
+
+        Args:
+            gauge: The Prometheus gauge to update.
+            value: The value to set for the gauge.
+            track_id: The ID of the track.
+            track_kind: The kind of the track (e.g., "video", "audio").
+        """
+        if not self._enabled:
+            return
+
+        labels = {}
+        if self._include_track_id and track_id:
+            labels["track_id"] = track_id
+        if self._include_track_kind and track_kind:
+            labels["track_kind"] = track_kind
+
+        if labels:
+            gauge.labels(**labels).set(value)
+        else:
+            gauge.set(value)
 
     def enable(self):
         """Enable Prometheus metrics collection."""
         self._enabled = True
 
-    def update_fps_metrics(self, fps: float, stream_id: Optional[str] = None):
-        """Update Prometheus metrics for a given stream.
+    def update_fps(
+        self,
+        fps: float,
+        track_id: Optional[str] = None,
+        track_kind: Optional[str] = None,
+    ):
+        """Update FPS metrics for a given track.
 
         Args:
             fps: The current frames per second.
-            stream_id: The ID of the stream.
+            track_id: The ID of the track.
+            track_kind: The kind of the track (e.g., "video", "audio").
         """
-        if self._enabled:
-            if self._include_stream_id:
-                self._fps_gauge.labels(stream_id=stream_id or "").set(fps)
-            else:
-                self._fps_gauge.set(fps)
+        self._set_gauge(
+            self._gauges["fps"], fps, track_id=track_id, track_kind=track_kind
+        )
+
+    def update_startup_time(
+        self,
+        startup_time: float,
+        track_id: Optional[str] = None,
+        track_kind: Optional[str] = None,
+    ):
+        """Update startup time metrics for a given track.
+
+        Args:
+            startup_time: The time taken to start the track.
+            track_id: The ID of the track.
+            track_kind: The kind of the track (e.g., "video", "audio").
+        """
+        self._set_gauge(
+            self._gauges["startup_time"],
+            startup_time,
+            track_id=track_id,
+            track_kind=track_kind,
+        )
+
+    def update_warmup_time(
+        self,
+        warmup_time: float,
+        track_id: Optional[str] = None,
+        track_kind: Optional[str] = None,
+    ):
+        """Update warmup time metrics for a given track.
+
+        Args:
+            warmup_time: The time taken to warm up the track.
+            track_id: The ID of the track.
+            track_kind: The kind of the track (e.g., "video", "audio").
+        """
+        self._set_gauge(
+            self._gauges["warmup_time"],
+            warmup_time,
+            track_id=track_id,
+            track_kind=track_kind,
+        )
 
     async def metrics_handler(self, _):
         """Handle Prometheus metrics endpoint."""
diff --git a/server/metrics/stream_stats.py b/server/metrics/stream_stats.py
deleted file mode 100644
index 8dc2ab19..00000000
--- a/server/metrics/stream_stats.py
+++ /dev/null
@@ -1,76 +0,0 @@
-"""Handles real-time video stream statistics (non-Prometheus, JSON API)."""
-
-from typing import Any, Dict
-import json
-from aiohttp import web
-from aiortc import MediaStreamTrack
-
-
-class StreamStatsManager:
-    """Handles real-time video stream statistics collection."""
-
-    def __init__(self, app: web.Application):
-        """Initializes the StreamMetrics class.
-
-        Args:
-            app: The web application instance storing stream tracks.
-        """
-        self._app = app
-
-    async def collect_video_metrics(
-        self, video_track: MediaStreamTrack
-    ) -> Dict[str, Any]:
-        """Collects real-time statistics for a video track.
-
-        Args:
-            video_track: The video stream track instance.
-
-        Returns:
-            A dictionary containing FPS-related statistics.
-        """
-        return {
-            "timestamp": await video_track.fps_meter.last_fps_calculation_time,
-            "fps": await video_track.fps_meter.fps,
-            "minute_avg_fps": await video_track.fps_meter.average_fps,
-            "minute_fps_array": await video_track.fps_meter.fps_measurements,
-        }
-
-    async def collect_all_stream_metrics(self, _) -> web.Response:
-        """Retrieves real-time metrics for all active video streams.
-
-        Returns:
-            A JSON response containing FPS statistics for all streams.
-        """
-        video_tracks = self._app.get("video_tracks", {})
-        all_stats = {
-            stream_id: await self.collect_video_metrics(track)
-            for stream_id, track in video_tracks.items()
-        }
-
-        return web.Response(
-            content_type="application/json",
-            text=json.dumps(all_stats),
-        )
-
-    async def collect_stream_metrics_by_id(self, request: web.Request) -> web.Response:
-        """Retrieves real-time metrics for a specific video stream by ID.
-
-        Args:
-            request: The HTTP request containing the stream ID.
-
-        Returns:
-            A JSON response with stream metrics or an error message.
-        """
-        stream_id = request.match_info.get("stream_id")
-        video_tracks = self._app.get("video_tracks", {})
-        video_track = video_tracks.get(stream_id)
-
-        if video_track:
-            stats = await self.collect_video_metrics(video_track)
-        else:
-            stats = {"error": "Stream not found"}
-
-        return web.Response(
-            content_type="application/json",
-            text=json.dumps(stats),
-        )
diff --git a/server/metrics/stream_stats_manager.py b/server/metrics/stream_stats_manager.py
new file mode 100644
index 00000000..c10ae4e7
--- /dev/null
+++ b/server/metrics/stream_stats_manager.py
@@ -0,0 +1,109 @@
+"""Module for handling real-time media track statistics for JSON API publishing."""
+
+from typing import Any, Dict
+import json
+from aiohttp import web
+from aiortc import MediaStreamTrack
+
+
+class StreamStatsManager:
+    """Handles real-time statistics collection for media tracks.
+
+    Note:
+        This class currently uses `track_id` to identify individual media tracks
+        (e.g., video or audio) instead of `stream_id`. In the future, this may
+        be extended to support stream-level statistics where a `stream_id` can
+        represent multiple tracks (e.g., video and audio tracks for the same stream).
+    """
+
+    def __init__(self, app: web.Application):
+        """Initializes the StreamStatsManager class.
+
+        Args:
+            app: The web application instance storing media tracks.
+        """
+        self._app = app
+
+    async def collect_video_stats(
+        self, video_track: MediaStreamTrack
+    ) -> Dict[str, Any]:
+        """Collects real-time statistics for a video track.
+
+        Args:
+            video_track: The video track instance.
+
+        Returns:
+            A dictionary containing real-time statistics for the video track.
+        """
+        return await video_track.stats.to_dict()
+
+    async def collect_audio_stats(
+        self, audio_track: MediaStreamTrack
+    ) -> Dict[str, Any]:
+        """Collects real-time statistics for an audio track.
+
+        Args:
+            audio_track: The audio track instance.
+
+        Returns:
+            A dictionary containing real-time statistics for the audio track.
+        """
+        return await audio_track.stats.to_dict()
+
+    async def collect_all_stream_stats(self, _) -> web.Response:
+        """Retrieves real-time statistics for all active video and audio tracks.
+
+        Returns:
+            A JSON response containing statistics for all tracks.
+        """
+        tracks = {
+            **self._app.get("video_tracks", {}),
+            **self._app.get("audio_tracks", {}),
+        }
+        all_stats = {
+            track_id: await (
+                self.collect_video_stats(track)
+                if track.kind == "video"
+                else self.collect_audio_stats(track)
+            )
+            for track_id, track in tracks.items()
+        }
+
+        return web.Response(
+            content_type="application/json",
+            text=json.dumps(all_stats),
+        )
+
+    async def collect_stream_stats_by_id(self, request: web.Request) -> web.Response:
+        """Retrieves real-time statistics for a specific video or audio track by ID.
+
+        Args:
+            request: The HTTP request containing the track ID.
+
+        Returns:
+            A JSON response with track statistics or an error message.
+        """
+        track_id = request.match_info.get("track_id")
+        tracks = {
+            **self._app.get("video_tracks", {}),
+            **self._app.get("audio_tracks", {}),
+        }
+        track = tracks.get(track_id)
+
+        if not track:
+            error_response = {"error": "Track not found"}
+            return web.Response(
+                status=404,
+                content_type="application/json",
+                text=json.dumps(error_response),
+            )
+
+        stats = await (
+            self.collect_video_stats(track)
+            if track.kind == "video"
+            else self.collect_audio_stats(track)
+        )
+        return web.Response(
+            content_type="application/json",
+            text=json.dumps(stats),
+        )
diff --git a/server/metrics/track_stats.py b/server/metrics/track_stats.py
new file mode 100644
index 00000000..e8c3535b
--- /dev/null
+++ b/server/metrics/track_stats.py
@@ -0,0 +1,123 @@
+"""Module for tracking real-time statistics of individual media tracks."""
+
+from typing import Any, Dict, Optional
+from utils.fps_meter import FPSMeter
+from .prometheus_metrics import MetricsManager
+from .pipeline_stats import PipelineStats
+import time
+
+
+class TrackStats:
+    """Tracks real-time statistics for an individual media track.
+
+    Attributes:
+        fps_meter: The FPSMeter instance for tracking frame rate.
+        start_timestamp: The timestamp when the track started.
+        pipeline: The PipelineStats instance for tracking pipeline-related metrics.
+    """
+
+    def __init__(
+        self,
+        track_id: str,
+        track_kind: str,
+        metrics_manager: Optional[MetricsManager] = None,
+    ):
+        """Initializes the TrackStats class.
+
+        Args:
+            track_id: The unique identifier for the media track.
+            track_kind: The kind of the track (e.g., "video" or "audio").
+            metrics_manager: An optional Prometheus metrics manager instance for
+                updating metrics related to the track.
+        """
+        update_metrics_callback = (
+            metrics_manager.update_fps if metrics_manager else None
+        )
+        self.fps_meter = FPSMeter(
+            track_id=track_id,
+            track_kind=track_kind,
+            update_metrics_callback=update_metrics_callback,
+        )
+        self.pipeline = PipelineStats(
+            metrics_manager=metrics_manager, track_id=track_id
+        )
+
+        self.start_timestamp = None
+
+        self._track_id = track_id
+        self._track_kind = track_kind
+        self._metrics_manager = metrics_manager
+        self._startup_time = None
+
+    @property
+    def startup_time(self) -> float:
+        """Time taken to start the track."""
+        return self._startup_time
+
+    @startup_time.setter
+    def startup_time(self, value: float):
+        """Sets the time taken to start the track.
+
+        Updates the Prometheus metrics if a metrics manager is available.
+        """
+        if self._metrics_manager:
+            self._metrics_manager.update_startup_time(
+                value, self._track_id, self._track_kind
+            )
+        self._startup_time = value
+
+    async def get_fps(self) -> float:
+        """Current frames per second (FPS) of the track.
+
+        Alias for the FPSMeter's `get_fps` method.
+        """
+        return await self.fps_meter.get_fps()
+
+    async def get_fps_measurements(self) -> list:
+        """List of FPS measurements over time for the track.
+
+        Alias for the FPSMeter's `get_fps_measurements` method.
+        """
+        return await self.fps_meter.get_fps_measurements()
+
+    async def get_average_fps(self) -> float:
+        """Average FPS over the last minute for the track.
+
+        Alias for the FPSMeter's `get_average_fps` method.
+        """
+        return await self.fps_meter.get_average_fps()
+
+    async def get_last_fps_calculation_time(self) -> float:
+        """Timestamp of the last FPS calculation for the track.
+
+        Alias for the FPSMeter's `get_last_fps_calculation_time` method.
+        """
+        return await self.fps_meter.get_last_fps_calculation_time()
+
+    @property
+    def time(self) -> float:
+        """Elapsed time since the track started."""
+        return (
+            0.0
+            if self.start_timestamp is None
+            else time.monotonic() - self.start_timestamp
+        )
+
+    async def to_dict(self) -> Dict[str, Any]:
+        """Converts track statistics to a dictionary for JSON serialization.
+
+        Returns:
+            A dictionary containing statistics about the media track.
+        """
+        pipeline_stats = {
+            "warmup": getattr(self.pipeline, f"{self._track_kind}_warmup_time", None)
+        }
+        return {
+            "type": self._track_kind,
+            "timestamp": self.time,
+            "startup_time": self.startup_time,
+            "pipeline": pipeline_stats,
+            "fps": await self.get_fps(),
+            "minute_avg_fps": await self.get_average_fps(),
+            "minute_fps_array": await self.get_fps_measurements(),
+        }
diff --git a/server/pipeline.py b/server/pipeline.py
index 26270923..55295f5b 100644
--- a/server/pipeline.py
+++ b/server/pipeline.py
@@ -2,7 +2,9 @@
 import torch
 import numpy as np
 import asyncio
+import time
+
+from metrics import PipelineStats
 from typing import Any, Dict, Union, List
 
 from comfystream.client import ComfyStreamClient
@@ -11,13 +13,27 @@ class Pipeline:
     def __init__(self, **kwargs):
+        """Initialize the pipeline with the given configuration.
+
+        Attributes:
+            client: The client to communicate with the ComfyStream server.
+            video_incoming_frames: Queue to store incoming video frames.
+            audio_incoming_frames: Queue to store incoming audio frames.
+            processed_audio_buffer: Buffer to store processed audio frames.
+            stats: Statistics about the audio and video streams processed by the
+                pipeline.
+        """
         self.client = ComfyStreamClient(**kwargs)
         self.video_incoming_frames = asyncio.Queue()
         self.audio_incoming_frames = asyncio.Queue()
 
         self.processed_audio_buffer = np.array([], dtype=np.int16)
 
+        self.stats = PipelineStats()
+
     async def warm_video(self):
+        start_time = time.monotonic()
+
         dummy_frame = av.VideoFrame()
         dummy_frame.side_data.input = torch.randn(1, 512, 512, 3)
@@ -25,7 +41,11 @@
         self.client.put_video_input(dummy_frame)
         await self.client.get_video_output()
 
+        self.stats.video_warmup_time = time.monotonic() - start_time
+
     async def warm_audio(self):
+        start_time = time.monotonic()
+
         dummy_frame = av.AudioFrame()
         dummy_frame.side_data.input = np.random.randint(-32768, 32767, int(48000 * 0.5), dtype=np.int16)  # TODO: adds a lot of delay if it doesn't match the buffer size, is warmup needed?
         dummy_frame.sample_rate = 48000
@@ -34,6 +54,8 @@
         self.client.put_audio_input(dummy_frame)
         await self.client.get_audio_output()
 
+        self.stats.audio_warmup_time = time.monotonic() - start_time
+
     async def set_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]):
         if isinstance(prompts, list):
             await self.client.set_prompts(prompts)
@@ -61,10 +83,10 @@ async def put_audio_frame(self, frame: av.AudioFrame):
     def video_preprocess(self, frame: av.VideoFrame) -> Union[torch.Tensor, np.ndarray]:
         frame_np = frame.to_ndarray(format="rgb24").astype(np.float32) / 255.0
         return torch.from_numpy(frame_np).unsqueeze(0)
-
+
     def audio_preprocess(self, frame: av.AudioFrame) -> Union[torch.Tensor, np.ndarray]:
         return frame.to_ndarray().ravel().reshape(-1, 2).mean(axis=1).astype(np.int16)
-
+
     def video_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.VideoFrame:
         return av.VideoFrame.from_ndarray(
             (output * 255.0).clamp(0, 255).to(dtype=torch.uint8).squeeze(0).cpu().numpy()
@@ -72,7 +94,7 @@
     def audio_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.AudioFrame:
         return av.AudioFrame.from_ndarray(np.repeat(output, 2).reshape(1, -1))
-
+
     async def get_processed_video_frame(self):
         # TODO: make it generic to support purely generative video cases
         out_tensor = await self.client.get_video_output()
@@ -83,7 +105,7 @@
         processed_frame = self.video_postprocess(out_tensor)
         processed_frame.pts = frame.pts
         processed_frame.time_base = frame.time_base
-
+
         return processed_frame
 
     async def get_processed_audio_frame(self):
@@ -99,13 +121,13 @@
         processed_frame.pts = frame.pts
         processed_frame.time_base = frame.time_base
         processed_frame.sample_rate = frame.sample_rate
-
+
         return processed_frame
-
+
     async def get_nodes_info(self) -> Dict[str, Any]:
         """Get information about all nodes in the current prompt including metadata."""
         nodes_info = await self.client.get_available_nodes()
         return nodes_info
-
+
     async def cleanup(self):
-        await self.client.cleanup()
\ No newline at end of file
+        await self.client.cleanup()
diff --git a/server/utils/fps_meter.py b/server/utils/fps_meter.py
index ce94317b..0e82d51a 100644
--- a/server/utils/fps_meter.py
+++ b/server/utils/fps_meter.py
@@ -4,16 +4,34 @@
 import logging
 import time
 from collections import deque
-from metrics import MetricsManager
+from typing import Callable, Optional
 
 logger = logging.getLogger(__name__)
 
 
 class FPSMeter:
-    """Class to calculate and store the framerate of a stream by counting frames."""
-
-    def __init__(self, metrics_manager: MetricsManager, track_id: str):
-        """Initializes the FPSMeter class."""
+    """Class to calculate and store the framerate of a stream by counting frames.
+
+    Attributes:
+        track_id: The ID of the track.
+    """
+
+    def __init__(
+        self,
+        track_id: str,
+        track_kind: str,
+        update_metrics_callback: Optional[Callable[[float, str, str], None]] = None,
+    ):
+        """Initializes the FPSMeter class.
+
+        Args:
+            track_id: The ID of the track.
+            track_kind: The kind of the track (e.g., "video" or "audio").
+            update_metrics_callback: An optional callback function to update Prometheus
+                metrics with FPS data.
+        """
+        self.track_id = track_id
+        self.track_kind = track_kind
         self._lock = asyncio.Lock()
         self._fps_interval_frame_count = 0
         self._last_fps_calculation_time = None
@@ -21,8 +39,7 @@
         self._fps = 0.0
         self._fps_measurements = deque(maxlen=60)
         self._running_event = asyncio.Event()
-        self._metrics_manager = metrics_manager
-        self.track_id = track_id
+        self._update_metrics_callback = update_metrics_callback
 
         asyncio.create_task(self._calculate_fps_loop())
@@ -51,8 +68,9 @@
                 self._last_fps_calculation_time = current_time
                 self._fps_interval_frame_count = 0
 
-                # Update Prometheus metrics if enabled.
-                self._metrics_manager.update_fps_metrics(self._fps, self.track_id)
+                # Update Prometheus metrics using the callback if provided.
+                if self._update_metrics_callback:
+                    self._update_metrics_callback(self._fps, self.track_id, self.track_kind)
 
             await asyncio.sleep(1)  # Calculate FPS every second.
@@ -63,8 +81,7 @@ async def increment_frame_count(self):
         if not self._running_event.is_set():
             self._running_event.set()
 
-    @property
-    async def fps(self) -> float:
+    async def get_fps(self) -> float:
         """Get the current output frames per second (FPS).
 
         Returns:
@@ -73,8 +90,7 @@ async def fps(self) -> float:
         async with self._lock:
             return self._fps
 
-    @property
-    async def fps_measurements(self) -> list:
+    async def get_fps_measurements(self) -> list:
         """Get the array of FPS measurements for the last minute.
 
         Returns:
@@ -83,8 +99,7 @@ async def fps_measurements(self) -> list:
         async with self._lock:
             return list(self._fps_measurements)
 
-    @property
-    async def average_fps(self) -> float:
+    async def get_average_fps(self) -> float:
         """Calculate the average FPS from the measurements taken in the last minute.
 
         Returns:
@@ -98,8 +113,7 @@ async def average_fps(self) -> float:
                 else self._fps
             )
 
-    @property
-    async def last_fps_calculation_time(self) -> float:
+    async def get_last_fps_calculation_time(self) -> float:
         """Get the elapsed time since the last FPS calculation.
 
         Returns: