diff --git a/server/http_streaming/routes.py b/server/http_streaming/routes.py
index ac309bae..d5715c1e 100644
--- a/server/http_streaming/routes.py
+++ b/server/http_streaming/routes.py
@@ -5,12 +5,326 @@
 """
 import asyncio
 import logging
+import io
+import json
+import av
 from aiohttp import web
 from frame_buffer import FrameBuffer
 from .tokens import cleanup_expired_sessions, validate_token, create_stream_token

 logger = logging.getLogger(__name__)

+async def process_segment(request):
+    """Process a video segment using PyAV and the ComfyUI pipeline
+
+    Extracts frames from the uploaded video segment, processes them through
+    the pipeline, and returns a new video segment with the processed frames.
+    """
+    import time
+    start_time = time.time()
+
+    pipeline = request.app['pipeline']
+
+    try:
+        # Get the multipart data
+        reader = await request.multipart()
+
+        segment_data = None
+        segment_index = None
+        timestamp = None
+        prompts = None
+        resolution = None
+
+        # Process multipart fields
+        async for field in reader:
+            if field.name == 'segment':
+                segment_data = await field.read()
+            elif field.name == 'segmentIndex':
+                segment_index = int(await field.text())
+            elif field.name == 'timestamp':
+                timestamp = int(await field.text())
+            elif field.name == 'prompts':
+                prompts_text = await field.text()
+                try:
+                    prompts = json.loads(prompts_text)
+                except json.JSONDecodeError:
+                    logger.warning("Failed to parse prompts JSON")
+            elif field.name == 'resolution':
+                resolution_text = await field.text()
+                try:
+                    resolution = json.loads(resolution_text)
+                except json.JSONDecodeError:
+                    logger.warning("Failed to parse resolution JSON")
+
+        if not segment_data:
+            return web.Response(status=400, text="No segment data provided")
+
+        logger.info(f"Processing segment {segment_index} ({len(segment_data)} bytes)")
+
+        # Update pipeline with prompts if provided
+        if prompts:
+            await pipeline.update_prompts(prompts)
+
+        # Create input container from segment data
+        input_container = av.open(io.BytesIO(segment_data))
+
+        # Find the first video and audio streams
+        video_stream = None
+        audio_stream = None
+
+        for stream in input_container.streams:
+            if stream.type == 'video' and video_stream is None:
+                video_stream = stream
+            elif stream.type == 'audio' and audio_stream is None:
+                audio_stream = stream
+
+        if not video_stream:
+            return web.Response(status=400, text="No video stream found in segment")
+
+        logger.info(f"Input video: {video_stream.width}x{video_stream.height}, "
+                    f"codec: {video_stream.codec_context.name}, "
+                    f"fps: {video_stream.average_rate}")
+
+        # Determine output format based on input codec
+        input_video_codec = video_stream.codec_context.name
+        input_format = input_container.format.name
+
+        # Map input codec to an appropriate output encoder
+        video_codec_map = {
+            'h264': 'libx264',
+            'h265': 'libx265',
+            'hevc': 'libx265',
+            'vp8': 'libvpx',
+            'vp9': 'libvpx-vp9',
+            'av1': 'libaom-av1'
+        }
+
+        # Use the mapped encoder, or fall back to the input codec name unchanged
+        output_video_codec = video_codec_map.get(input_video_codec, input_video_codec)
+
+        # Verify the codec is available; fall back to a safe default otherwise
+        try:
+            # Probe codec availability by adding it to a throwaway container
+            test_container = av.open(io.BytesIO(), mode='w', format='null')
+            test_container.add_stream(output_video_codec)
+            test_container.close()
+        except Exception as e:
+            logger.warning(f"Codec {output_video_codec} not available, falling back to libx264: {e}")
+            output_video_codec = 'libx264'
+
+        # Determine output format - prefer the input format if supported, otherwise webm
+        supported_formats = ['webm', 'mp4', 'mkv', 'avi']
+        output_format = input_format if input_format in supported_formats else 'webm'
+
+        # If using mp4, ensure codec compatibility
+        if output_format == 'mp4' and output_video_codec in ['libvpx', 'libvpx-vp9']:
+            output_video_codec = 'libx264'  # VP8/VP9 not widely supported in MP4
+
+        logger.info(f"Using output format: {output_format}, video codec: {output_video_codec}")
+
+        # Create output container in memory
+        output_buffer = io.BytesIO()
+        output_container = av.open(output_buffer, mode='w', format=output_format)
+
+        # Create output video stream matching input properties
+        output_video_stream = output_container.add_stream(output_video_codec, rate=video_stream.average_rate)
+        output_video_stream.width = video_stream.width
+        output_video_stream.height = video_stream.height
+
+        # Use input pixel format if available, otherwise default
+        if hasattr(video_stream.codec_context, 'pix_fmt') and video_stream.codec_context.pix_fmt:
+            output_video_stream.pix_fmt = video_stream.codec_context.pix_fmt
+        else:
+            output_video_stream.pix_fmt = 'yuv420p'
+
+        # Copy bitrate and other encoding parameters
+        if hasattr(video_stream.codec_context, 'bit_rate') and video_stream.codec_context.bit_rate:
+            output_video_stream.bit_rate = video_stream.codec_context.bit_rate
+        else:
+            output_video_stream.bit_rate = 2500000  # fallback bitrate
+
+        # Create output audio stream if input has audio
+        output_audio_stream = None
+        if audio_stream:
+            input_audio_codec = audio_stream.codec_context.name
+
+            # Map input audio codec to an appropriate output encoder
+            audio_codec_map = {
+                'aac': 'aac',
+                'mp3': 'libmp3lame',
+                'opus': 'libopus',
+                'vorbis': 'libvorbis',
+                'flac': 'flac',
+                'pcm_s16le': 'pcm_s16le'
+            }
+
+            # Use the mapped encoder, or fall back to the input codec name
+            output_audio_codec = audio_codec_map.get(input_audio_codec, input_audio_codec)
+
+            # Container-specific codec adjustments
+            if output_format == 'webm' and output_audio_codec not in ['libopus', 'libvorbis']:
+                output_audio_codec = 'libopus'  # WebM prefers Opus
+            elif output_format == 'mp4' and output_audio_codec not in ['aac', 'libmp3lame']:
+                output_audio_codec = 'aac'  # MP4 prefers AAC
+
+            # Verify the audio codec is available
+            try:
+                test_container = av.open(io.BytesIO(), mode='w', format='null')
+                test_container.add_stream(output_audio_codec)
+                test_container.close()
+            except Exception as e:
+                logger.warning(f"Audio codec {output_audio_codec} not available, falling back to aac: {e}")
+                output_audio_codec = 'aac'
+
+            logger.info(f"Input audio: {audio_stream.rate}Hz, "
+                        f"channels: {audio_stream.channels}, "
+                        f"codec: {input_audio_codec} -> {output_audio_codec}")
+
+            output_audio_stream = output_container.add_stream(output_audio_codec, rate=audio_stream.rate)
+            output_audio_stream.channels = audio_stream.channels
+            output_audio_stream.layout = audio_stream.layout
+
+            # Copy audio encoding parameters
+            if hasattr(audio_stream.codec_context, 'bit_rate') and audio_stream.codec_context.bit_rate:
+                output_audio_stream.bit_rate = audio_stream.codec_context.bit_rate
+            else:
+                output_audio_stream.bit_rate = 128000  # fallback audio bitrate
+
+        processed_frames = []
+        audio_frames = []
+        frame_count = 0
+
+        # Process video frames through pipeline and count them
+        logger.info("Extracting and processing video frames...")
+        for packet in input_container.demux(video_stream):
+            for frame in packet.decode():
+                frame_count += 1
+                # Put frame in pipeline for processing
+                await pipeline.put_video_frame(frame)
+
+        # Process audio frames if present
+        if audio_stream:
+            logger.info("Extracting audio frames...")
+            # Rewind the container: the video demux above has already read it to
+            # EOF, so a second demux pass would otherwise yield no packets
+            input_container.seek(0)
+            for packet in input_container.demux(audio_stream):
+                for frame in packet.decode():
+                    audio_frames.append(frame)
+                    # Optionally process audio through pipeline
+                    # await pipeline.put_audio_frame(frame)
+
+        input_container.close()
+
+        # Collect processed video frames from pipeline
+        logger.info(f"Collecting {frame_count} processed frames from pipeline...")
+
+        # Add timeout protection for pipeline processing
+        timeout_seconds = 30
+
+        for i in range(frame_count):
+            try:
+                # Use asyncio.wait_for to add timeout protection
+                processed_frame = await asyncio.wait_for(
+                    pipeline.get_processed_video_frame(),
+                    timeout=timeout_seconds
+                )
+                processed_frames.append(processed_frame)
+
+                if (i + 1) % 10 == 0:  # Log every 10 frames
+                    logger.info(f"Processed {i + 1}/{frame_count} frames")
+
+            except asyncio.TimeoutError:
+                logger.error(f"Timeout waiting for processed frame {i}")
+                break
+            except Exception as e:
+                logger.error(f"Error getting processed frame {i}: {e}")
+                break
+
+        logger.info(f"Collected {len(processed_frames)} processed frames")
+
+        # If we didn't get any processed frames, return an error
+        if not processed_frames:
+            return web.Response(
+                status=500,
+                text="No frames were successfully processed"
+            )
+
+        # Encode processed frames to output container
+        logger.info("Encoding processed frames to output...")
+        for i, frame in enumerate(processed_frames):
+            try:
+                for packet in output_video_stream.encode(frame):
+                    output_container.mux(packet)
+            except Exception as e:
+                logger.error(f"Error encoding frame {i}: {e}")
+                # Continue with remaining frames
+
+        # Flush video encoder
+        for packet in output_video_stream.encode():
+            output_container.mux(packet)
+
+        # Encode audio frames if present
+        if output_audio_stream and audio_frames:
+            for frame in audio_frames:
+                for packet in output_audio_stream.encode(frame):
+                    output_container.mux(packet)
+
+            # Flush audio encoder (only when an audio stream was created)
+            for packet in output_audio_stream.encode():
+                output_container.mux(packet)
+
+        # Finalize output
+        output_container.close()
+
+        # Get the processed segment data
+        output_data = output_buffer.getvalue()
+        output_buffer.close()
+
+        processing_time = time.time() - start_time
+        fps = len(processed_frames) / processing_time if processing_time > 0 else 0
+
+        logger.info(f"Processed segment {segment_index}: "
+                    f"input {len(segment_data)} bytes -> output {len(output_data)} bytes, "
+                    f"processed {len(processed_frames)}/{frame_count} frames, "
+                    f"processing time: {processing_time:.2f}s ({fps:.1f} fps)")
+
+        # Determine Content-Type based on output format
+        content_type_map = {
+            'webm': 'video/webm',
+            'mp4': 'video/mp4',
+            'mkv': 'video/x-matroska',
+            'avi': 'video/x-msvideo'
+        }
+        content_type = content_type_map.get(output_format, 'video/webm')
+
+        # Codec names for response headers (both are always bound by this point)
+        video_codec_name = output_video_codec
+        audio_codec_name = output_audio_codec if output_audio_stream else 'none'
+
+        # Return the processed segment
+        return web.Response(
+            body=output_data,
+            headers={
+                'Content-Type': content_type,
+                'Content-Length': str(len(output_data)),
+                'X-Segment-Index': str(segment_index),
+                'X-Timestamp': str(timestamp),
+                'X-Processed-Frames': str(len(processed_frames)),
+                'X-Total-Frames': str(frame_count),
+                'X-Processing-Time': f"{processing_time:.2f}",
+                'X-Processing-FPS': f"{fps:.1f}",
+                'X-Output-Format': output_format,
+                'X-Video-Codec': video_codec_name,
+                'X-Audio-Codec': audio_codec_name
+            }
+        )
+
+    except Exception as e:
+        logger.error(f"Error processing segment: {e}", exc_info=True)
+        return web.Response(status=500, text=f"Error processing segment: {str(e)}")
+
+async def segment(request):
+    """Legacy single-segment endpoint; superseded by POST /api/segments."""
+    return web.Response(status=501, text="Use /api/segments for segment processing")
+
 async def stream_mjpeg(request):
     """Serve an MJPEG stream with token validation"""
     # Clean up expired sessions
@@ -67,3 +381,6 @@ def setup_routes(app, cors):

     # Stream endpoint with token validation
     cors.add(app.router.add_get("/api/stream", stream_mjpeg))
+
+    # Segment processing endpoint
+    cors.add(app.router.add_post("/api/segments", process_segment))
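Note on the two demux passes above: `Container.demux()` reads packets sequentially from the current file position, which is why the audio pass needs the `seek(0)` rewind once the video pass has hit EOF. A single interleaved pass avoids the rewind entirely. A minimal sketch, assuming the same `pipeline.put_video_frame()` API used by the handler:

```python
import io
import av

async def demux_one_pass(segment_data: bytes, pipeline):
    """Single demux pass: packets arrive interleaved, so no rewind is needed."""
    container = av.open(io.BytesIO(segment_data))
    audio_frames = []
    frame_count = 0
    for packet in container.demux():  # all streams, in file order
        if packet.stream.type == 'video':
            for frame in packet.decode():
                frame_count += 1
                await pipeline.put_video_frame(frame)
        elif packet.stream.type == 'audio':
            for frame in packet.decode():
                audio_frames.append(frame)
    container.close()
    return frame_count, audio_frames
```

The two-pass version in the handler is arguably easier to read; the one-pass version touches each packet exactly once and never depends on the container being seekable.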
f"{processing_time:.2f}", + 'X-Processing-FPS': f"{fps:.1f}", + 'X-Output-Format': output_format, + 'X-Video-Codec': video_codec_name, + 'X-Audio-Codec': audio_codec_name + } + ) + + except Exception as e: + logger.error(f"Error processing segment: {e}", exc_info=True) + return web.Response(status=500, text=f"Error processing segment: {str(e)}") + +async def segment(request): + """Serve a single video segment (legacy endpoint)""" + return web.Response(status=501, text="Use /api/segments for segment processing") + async def stream_mjpeg(request): """Serve an MJPEG stream with token validation""" # Clean up expired sessions @@ -67,3 +381,6 @@ def setup_routes(app, cors): # Stream endpoint with token validation cors.add(app.router.add_get("/api/stream", stream_mjpeg)) + + # Segment processing endpoint + cors.add(app.router.add_post("/api/segments", process_segment)) diff --git a/ui/package-lock.json b/ui/package-lock.json index 7420ec99..12bd897d 100644 --- a/ui/package-lock.json +++ b/ui/package-lock.json @@ -1,12 +1,12 @@ { "name": "ui", - "version": "0.1.1", + "version": "0.1.2", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "ui", - "version": "0.1.1", + "version": "0.1.2", "dependencies": { "@hookform/resolvers": "^3.9.1", "@radix-ui/react-dialog": "^1.1.6", diff --git a/ui/src/components/control-panel.tsx b/ui/src/components/control-panel.tsx index 3d5f458c..187ebe00 100644 --- a/ui/src/components/control-panel.tsx +++ b/ui/src/components/control-panel.tsx @@ -409,6 +409,7 @@ export const ControlPanel = ({ return (
+            {/* ...existing controls... */}
+            <FormField
+              control={form.control}
+              name="segmentTime"
+              render={({ field }) => (
+                <FormItem>
+                  <FormLabel>Segment Time</FormLabel>
+                  <FormControl>
+                    <Input type="number" min={1} {...field} />
+                  </FormControl>
+                  <FormDescription>seconds</FormDescription>
+                </FormItem>
+              )}
+            />
           {/* TODO: Temporary fix to Warn if no camera or mic; improve later */}
diff --git a/ui/src/context/segments-context.ts b/ui/src/context/segments-context.ts
new file mode 100644
index 00000000..0a247fef
--- /dev/null
+++ b/ui/src/context/segments-context.ts
@@ -0,0 +1,12 @@
+import { Segments } from "@/lib/segments";
+import * as React from "react";
+
+export const SegmentsContext = React.createContext<Segments | undefined>(undefined);
+
+export function useSegmentsContext() {
+  const ctx = React.useContext(SegmentsContext);
+  if (!ctx) {
+    throw new Error("useSegmentsContext must be used within a SegmentsContext.Provider");
+  }
+  return ctx;
+}
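`SegmentsContext` is declared here, but no provider appears in this diff. A sketch of the wiring it implies, with a hypothetical `SegmentsProvider` component (the name and the callback bodies are placeholders, not part of this change):

```tsx
import * as React from "react";
import { SegmentsContext } from "@/context/segments-context";
import { useSegments } from "@/hooks/use-segments";
import { Prompt } from "@/types";

interface SegmentsProviderProps {
  url: string;
  prompts: Prompt[] | null;
  connect: boolean;
  localStream: MediaStream | null;
  segmentTime: number;
  children: React.ReactNode;
}

// Runs the recording hook once and exposes its state to the subtree.
export function SegmentsProvider({ children, ...props }: SegmentsProviderProps) {
  const segments = useSegments({
    ...props,
    onConnected: () => console.log("[segments] connected"),
    onDisconnected: () => console.log("[segments] disconnected"),
  });
  return (
    <SegmentsContext.Provider value={segments}>
      {children}
    </SegmentsContext.Provider>
  );
}
```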
diff --git a/ui/src/hooks/use-segments.ts b/ui/src/hooks/use-segments.ts
new file mode 100644
index 00000000..04c1bd74
--- /dev/null
+++ b/ui/src/hooks/use-segments.ts
@@ -0,0 +1,254 @@
+import { useState, useEffect, useRef, useCallback } from "react";
+import { Prompt } from "@/types";
+import { Segments } from "@/lib/segments";
+
+interface SegmentsProps {
+  url: string;
+  prompts: Prompt[] | null;
+  connect: boolean;
+  onConnected: () => void;
+  onDisconnected: () => void;
+  localStream: MediaStream | null;
+  segmentTime: number;
+  resolution?: {
+    width: number;
+    height: number;
+  };
+}
+
+const MAX_SEND_RETRIES = 3;
+const SEND_RETRY_INTERVAL = 1000;
+
+export function useSegments(props: SegmentsProps): Segments {
+  const {
+    url,
+    prompts,
+    connect,
+    onConnected,
+    onDisconnected,
+    localStream,
+    segmentTime,
+    resolution,
+  } = props;
+
+  const [isRecording, setIsRecording] = useState(false);
+  const [isConnected, setIsConnected] = useState(false);
+  const [lastSegmentTime, setLastSegmentTime] = useState<number | null>(null);
+  const [error, setError] = useState<string | null>(null);
+
+  const mediaRecorderRef = useRef<MediaRecorder | null>(null);
+  const segmentIntervalRef = useRef<ReturnType<typeof setInterval> | null>(null);
+  const chunksRef = useRef<Blob[]>([]);
+  const segmentCountRef = useRef(0);
+  // Mirror isRecording in a ref so MediaRecorder callbacks and timers read the
+  // current value instead of the stale closure captured when they were created
+  const isRecordingRef = useRef(false);
+
+  const sendSegment = useCallback(
+    async (blob: Blob, segmentIndex: number, retry: number = 0): Promise<void> => {
+      try {
+        const formData = new FormData();
+        formData.append('segment', blob, `segment_${segmentIndex}.webm`);
+        formData.append('segmentIndex', segmentIndex.toString());
+        formData.append('timestamp', Date.now().toString());
+
+        if (prompts) {
+          formData.append('prompts', JSON.stringify(prompts));
+        }
+
+        if (resolution) {
+          formData.append('resolution', JSON.stringify(resolution));
+        }
+
+        const response = await fetch(`${url}/api/segments`, {
+          method: 'POST',
+          body: formData,
+        });
+
+        if (!response.ok) {
+          throw new Error(`Segment send HTTP error: ${response.status}`);
+        }
+
+        // The server responds with the processed segment as binary video data
+        const processedBlob = await response.blob();
+        console.log(`[useSegments] Segment ${segmentIndex} sent successfully, received ${processedBlob.size} bytes`);
+        setLastSegmentTime(Date.now());
+        setError(null);
+      } catch (error) {
+        console.error(`[useSegments] Error sending segment ${segmentIndex}:`, error);
+
+        if (retry < MAX_SEND_RETRIES) {
+          console.log(`[useSegments] Retrying segment ${segmentIndex}, attempt ${retry + 1}`);
+          await new Promise((resolve) => setTimeout(resolve, SEND_RETRY_INTERVAL));
+          return sendSegment(blob, segmentIndex, retry + 1);
+        }
+
+        setError(`Failed to send segment after ${MAX_SEND_RETRIES} retries`);
+        throw error;
+      }
+    },
+    [url, prompts, resolution],
+  );
+
+  const startRecording = useCallback(() => {
+    if (!localStream || isRecordingRef.current) return;
+
+    try {
+      // Get supported MIME type
+      const mimeTypes = [
+        'video/webm;codecs=vp9,opus',
+        'video/webm;codecs=vp8,opus',
+        'video/webm;codecs=h264,opus',
+        'video/webm',
+        'video/mp4',
+      ];
+
+      let mimeType = '';
+      for (const type of mimeTypes) {
+        if (MediaRecorder.isTypeSupported(type)) {
+          mimeType = type;
+          break;
+        }
+      }
+
+      const mediaRecorder = new MediaRecorder(localStream, {
+        mimeType: mimeType || undefined,
+        videoBitsPerSecond: 2500000, // 2.5 Mbps
+        audioBitsPerSecond: 128000, // 128 kbps
+      });
+
+      mediaRecorderRef.current = mediaRecorder;
+      chunksRef.current = [];
+      segmentCountRef.current = 0;
+
+      mediaRecorder.ondataavailable = (event) => {
+        if (event.data.size > 0) {
+          chunksRef.current.push(event.data);
+        }
+      };
+
+      mediaRecorder.onstop = () => {
+        if (chunksRef.current.length > 0) {
+          const blob = new Blob(chunksRef.current, { type: mimeType });
+          const segmentIndex = segmentCountRef.current++;
+
+          // Send the segment
+          sendSegment(blob, segmentIndex).catch((error) => {
+            console.error(`[useSegments] Failed to send segment ${segmentIndex}:`, error);
+          });
+
+          chunksRef.current = [];
+        }
+
+        // Auto-restart recording if we're still supposed to be recording
+        setTimeout(() => {
+          if (mediaRecorderRef.current &&
+              mediaRecorderRef.current.state === 'inactive' &&
+              isRecordingRef.current &&
+              localStream) {
+            try {
+              mediaRecorderRef.current.start();
+            } catch (error) {
+              console.error('[useSegments] Error restarting recording:', error);
+            }
+          }
+        }, 100);
+      };
+
+      mediaRecorder.onerror = (event) => {
+        console.error('[useSegments] MediaRecorder error:', event);
+        setError('Recording error occurred');
+      };
+
+      // Start recording
+      mediaRecorder.start();
+      isRecordingRef.current = true;
+      setIsRecording(true);
+      setIsConnected(true);
+      onConnected();
+
+      console.log(`[useSegments] Started recording with ${segmentTime}s segments`);
+
+      // Set up interval to capture segments
+      segmentIntervalRef.current = setInterval(() => {
+        if (mediaRecorderRef.current &&
+            mediaRecorderRef.current.state === 'recording' &&
+            isRecordingRef.current) {
+          // Stop current recording to get the segment
+          try {
+            mediaRecorderRef.current.stop();
+          } catch (error) {
+            console.error('[useSegments] Error stopping recording for segment:', error);
+          }
+        }
+      }, segmentTime * 1000);
+
+    } catch (error) {
+      console.error('[useSegments] Error starting recording:', error);
+      setError('Failed to start recording');
+    }
+  }, [localStream, segmentTime, onConnected, sendSegment]);
+
+  const stopRecording = useCallback(() => {
+    // Clear the ref first so the onstop handler does not auto-restart
+    isRecordingRef.current = false;
+
+    if (segmentIntervalRef.current) {
+      clearInterval(segmentIntervalRef.current);
+      segmentIntervalRef.current = null;
+    }
+
+    if (mediaRecorderRef.current && mediaRecorderRef.current.state !== 'inactive') {
+      mediaRecorderRef.current.stop();
+    }
+
+    setIsRecording(false);
+    setIsConnected(false);
+    onDisconnected();
+
+    console.log('[useSegments] Stopped recording');
+  }, [onDisconnected]);
+
+  // Main effect to handle connection state
+  useEffect(() => {
+    if (connect && localStream) {
+      startRecording();
+    } else {
+      stopRecording();
+    }
+
+    // Cleanup on unmount
+    return () => {
+      stopRecording();
+    };
+  }, [connect, localStream, startRecording, stopRecording]);
+
+  // Effect to handle segment time changes during recording
+  useEffect(() => {
+    if (isRecording && segmentIntervalRef.current) {
+      // Restart with new segment time
+      clearInterval(segmentIntervalRef.current);
+
+      segmentIntervalRef.current = setInterval(() => {
+        if (mediaRecorderRef.current &&
+            mediaRecorderRef.current.state === 'recording' &&
+            isRecordingRef.current) {
+          try {
+            mediaRecorderRef.current.stop();
+          } catch (error) {
+            console.error('[useSegments] Error stopping recording for segment time change:', error);
+          }
+        }
+      }, segmentTime * 1000);
+
+      console.log(`[useSegments] Updated segment time to ${segmentTime}s`);
+    }
+  }, [segmentTime, isRecording]);
+
+  // Effect to handle resolution changes
+  useEffect(() => {
+    // Resolution changes will be included in the next segment automatically
+    if (resolution && isConnected) {
+      console.log('[useSegments] Resolution updated:', resolution);
+    }
+  }, [resolution, isConnected]);
+
+  return {
+    isRecording,
+    isConnected,
+    lastSegmentTime,
+    error,
+  };
+}
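A note on the recorder design above: the hook stops and restarts the `MediaRecorder` on every interval tick rather than passing a timeslice to `start()`. That is what makes each uploaded blob a complete, standalone WebM file that the server can hand straight to `av.open()`. For contrast, a timeslice-based sketch (hypothetical, not used by the hook) and why it would not work here:

```ts
function startTimesliceCapture(stream: MediaStream, segmentTimeMs: number) {
  const recorder = new MediaRecorder(stream, { mimeType: "video/webm" });
  recorder.ondataavailable = (e) => {
    // e.data is a fragment of one continuous recording, not a standalone file:
    // only the first chunk carries the WebM initialization segment, so the
    // server-side av.open() would fail to parse every later chunk on its own.
    void e.data;
  };
  recorder.start(segmentTimeMs); // emit a chunk every segmentTimeMs
  return recorder;
}
```

The trade-off of the stop/restart approach is a brief capture gap (about 100 ms here) between segments.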
diff --git a/ui/src/lib/segments.ts b/ui/src/lib/segments.ts
new file mode 100644
index 00000000..ae0bada6
--- /dev/null
+++ b/ui/src/lib/segments.ts
@@ -0,0 +1,6 @@
+export interface Segments {
+  isRecording: boolean;
+  isConnected: boolean;
+  lastSegmentTime: number | null;
+  error: string | null;
+}
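Any component under the provider can consume this `Segments` shape through `useSegmentsContext()`. A hypothetical status readout (not part of this diff):

```tsx
import { useSegmentsContext } from "@/context/segments-context";

export function SegmentStatus() {
  const { isRecording, lastSegmentTime, error } = useSegmentsContext();
  if (error) return <span>Error: {error}</span>;
  if (!isRecording) return <span>Idle</span>;
  return (
    <span>
      Recording
      {lastSegmentTime
        ? ` (last segment at ${new Date(lastSegmentTime).toLocaleTimeString()})`
        : ""}
    </span>
  );
}
```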