2 changes: 1 addition & 1 deletion docker/Dockerfile
@@ -1,4 +1,4 @@
-ARG BASE_IMAGE=livepeer/comfyui-base:latest
+ARG BASE_IMAGE=livepeer/comfyui-base:pr-499

FROM ${BASE_IMAGE}

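Since the default base image is now pinned to a PR build, a different base can still be supplied at build time by overriding the ARG (standard Docker behavior; the tag and the repo-root build context below are illustrative):

    docker build --build-arg BASE_IMAGE=livepeer/comfyui-base:latest -f docker/Dockerfile .
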
43 changes: 43 additions & 0 deletions server/byoc.py
@@ -187,6 +187,49 @@ async def register_orchestrator_startup(app):
# Add registration to startup hooks
processor.server.app.on_startup.append(register_orchestrator_startup)

# Add warmup endpoint: accepts same body as prompts update
async def warmup_handler(request):
    try:
        body = await request.json()
    except Exception as e:
        logger.error(f"Invalid JSON in warmup request: {e}")
        return web.json_response({"error": "Invalid JSON"}, status=400)
    try:
        # Inject sentinel to trigger warmup inside update_params on the model thread
        if isinstance(body, dict):
            body["warmup"] = True
        else:
            body = {"warmup": True}
        # Fire-and-forget: do not await warmup; update_params will schedule it
        asyncio.get_running_loop().create_task(frame_processor.update_params(body))
        return web.json_response({"status": "accepted"})
    except Exception as e:
        logger.error(f"Warmup failed: {e}")
        return web.json_response({"error": str(e)}, status=500)

# Add pause endpoint
async def pause_handler(request):
    try:
        # Fire-and-forget: do not await pause
        asyncio.get_running_loop().create_task(frame_processor.pause_prompts())
        return web.json_response({"status": "paused"})
    except Exception as e:
        logger.error(f"Pause failed: {e}")
        return web.json_response({"error": str(e)}, status=500)

# Add resume endpoint
async def resume_handler(request):
    try:
        # Fire-and-forget: do not await resume
        asyncio.get_running_loop().create_task(frame_processor.resume_prompts())
        return web.json_response({"status": "resumed"})
    except Exception as e:
        logger.error(f"Resume failed: {e}")
        return web.json_response({"error": str(e)}, status=500)

# Mount at same API namespace as StreamProcessor defaults
processor.server.add_route("POST", "/api/stream/warmup", warmup_handler)
# Register pause/resume alongside warmup; without these routes the two
# handlers above are dead code (paths assumed to follow the warmup pattern)
processor.server.add_route("POST", "/api/stream/pause", pause_handler)
processor.server.add_route("POST", "/api/stream/resume", resume_handler)

# Run the processor
processor.run()

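For reviewers who want to exercise the new endpoints, a minimal client sketch (the host/port and the prompts payload shape are assumptions; all three calls return immediately because the handlers fire-and-forget):

    import asyncio
    import aiohttp

    BASE = "http://localhost:8000"  # assumed; depends on how the processor is launched

    async def main():
        async with aiohttp.ClientSession() as session:
            # Warmup accepts the same body as a prompts update; the server
            # injects the {"warmup": True} sentinel before calling update_params.
            async with session.post(f"{BASE}/api/stream/warmup", json={"prompts": {}}) as resp:
                print(await resp.json())  # {"status": "accepted"}
            async with session.post(f"{BASE}/api/stream/pause") as resp:
                print(await resp.json())  # {"status": "paused"}
            async with session.post(f"{BASE}/api/stream/resume") as resp:
                print(await resp.json())  # {"status": "resumed"}

    asyncio.run(main())
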
184 changes: 175 additions & 9 deletions server/frame_processor.py
@@ -2,8 +2,11 @@
import json
import logging
import os
from fractions import Fraction
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
from pytrickle.frame_processor import FrameProcessor
from pytrickle.frames import AudioFrame, VideoFrame
from pytrickle.stream_processor import VideoProcessingResult
@@ -40,6 +43,10 @@ def __init__(self, text_poll_interval: float = 0.25, **load_params):
        self._text_forward_task = None
        self._background_tasks = []
        self._stop_event = asyncio.Event()
        self._generative_video_task = None
        self._generative_audio_task = None
        self._generative_video_pts = 0
        self._generative_audio_pts = 0

    async def _apply_stream_start_prompt(self, prompt_value: Any) -> bool:
        if not self.pipeline:
@@ -164,6 +171,165 @@ async def _stop_text_forwarder(self) -> None:
logger.debug("Error while awaiting text forwarder cancellation", exc_info=True)
self._text_forward_task = None

async def _stop_generative_video_forwarder(self) -> None:
"""Stop the background generative video task if running."""
task = self._generative_video_task
if task and not task.done():
try:
task.cancel()
await task
except asyncio.CancelledError:
pass
except Exception:
logger.debug("Error while awaiting generative video cancellation", exc_info=True)
self._generative_video_task = None
try:
self._background_tasks.remove(task)
except (ValueError, TypeError):
pass

async def _stop_generative_audio_forwarder(self) -> None:
"""Stop the background generative audio task if running."""
task = self._generative_audio_task
if task and not task.done():
try:
task.cancel()
await task
except asyncio.CancelledError:
pass
except Exception:
logger.debug("Error while awaiting generative audio cancellation", exc_info=True)
self._generative_audio_task = None
try:
self._background_tasks.remove(task)
except (ValueError, TypeError):
pass

async def _stop_generative_forwarders(self) -> None:
"""Stop all generative forwarder tasks."""
await self._stop_generative_video_forwarder()
await self._stop_generative_audio_forwarder()

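The two stop methods above are identical except for the attribute they clear; if more forwarders appear, the cancel/await/remove pattern could live in one helper (a refactoring sketch, not part of this diff; only the attribute names are taken from the PR):

    async def _stop_forwarder(self, attr: str) -> None:
        # Cancel the task stored on `attr`, await its shutdown, and drop it
        # from the background task list.
        task = getattr(self, attr)
        if task and not task.done():
            try:
                task.cancel()
                await task
            except asyncio.CancelledError:
                pass
            except Exception:
                logger.debug(f"Error while awaiting {attr} cancellation", exc_info=True)
        setattr(self, attr, None)
        try:
            self._background_tasks.remove(task)
        except (ValueError, TypeError):
            pass

Call sites would then read `await self._stop_forwarder("_generative_video_task")` and `await self._stop_forwarder("_generative_audio_task")`.
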
    def _start_generative_video_forwarder(self) -> None:
        """Start the generative video forwarder if needed."""
        if self._generative_video_task and not self._generative_video_task.done():
            return
        if not self.pipeline or not self._stream_processor:
            return

        async def _generative_video_loop():
            logger.info("Starting generative video forwarder task")
            fps = getattr(self.pipeline, "frame_rate", None) or 30
            time_base = Fraction(1, int(fps))
            pts = self._generative_video_pts
            try:
                while not self._stop_event.is_set():
                    try:
                        out_tensor = await self.pipeline.client.get_video_output()
                    except asyncio.CancelledError:
                        raise
                    except Exception as exc:
                        logger.error(f"Failed to retrieve generative video output: {exc}")
                        await asyncio.sleep(0.1)
                        continue

                    if out_tensor is None:
                        await asyncio.sleep(0.01)
                        continue

                    processed_frame = self.pipeline.video_postprocess(out_tensor)
                    processed_frame.pts = pts
                    processed_frame.time_base = time_base

                    frame_np = processed_frame.to_ndarray(format="rgb24").astype(np.float32) / 255.0
                    tensor = torch.from_numpy(frame_np)
                    video_frame = VideoFrame.from_tensor(tensor, timestamp=pts)
                    video_frame.time_base = time_base

                    success = await self._stream_processor.send_input_frame(video_frame)
                    if not success:
                        await asyncio.sleep(0.05)
                    pts += 1
                    self._generative_video_pts = pts
            except asyncio.CancelledError:
                logger.debug("Generative video forwarder cancelled")
                raise
            except Exception as exc:
                logger.error(f"Generative video forwarder encountered an error: {exc}")
            finally:
                logger.info("Generative video forwarder task exiting")

        self._generative_video_task = asyncio.create_task(_generative_video_loop())
        self._background_tasks.append(self._generative_video_task)

    def _start_generative_audio_forwarder(self) -> None:
        """Start the generative audio forwarder if needed."""
        if self._generative_audio_task and not self._generative_audio_task.done():
            return
        if not self.pipeline or not self._stream_processor:
            return

        async def _generative_audio_loop():
            logger.info("Starting generative audio forwarder task")
            sample_rate = 48000
            time_base = Fraction(1, sample_rate)
            pts = self._generative_audio_pts
            try:
                while not self._stop_event.is_set():
                    try:
                        out_audio = await self.pipeline.client.get_audio_output()
                    except asyncio.CancelledError:
                        raise
                    except Exception as exc:
                        logger.error(f"Failed to retrieve generative audio output: {exc}")
                        await asyncio.sleep(0.1)
                        continue

                    if out_audio is None:
                        await asyncio.sleep(0.01)
                        continue

                    processed_frame = self.audio_postprocess(out_audio)
                    processed_frame.pts = pts
                    processed_frame.time_base = time_base
                    processed_frame.sample_rate = sample_rate

                    audio_frame = AudioFrame.from_av_audio(processed_frame)
                    success = await self._stream_processor.send_input_frame(audio_frame)
                    if not success:
                        await asyncio.sleep(0.05)
                    pts += audio_frame.nb_samples
                    self._generative_audio_pts = pts
            except asyncio.CancelledError:
                logger.debug("Generative audio forwarder cancelled")
                raise
            except Exception as exc:
                logger.error(f"Generative audio forwarder encountered an error: {exc}")
            finally:
                logger.info("Generative audio forwarder task exiting")

        self._generative_audio_task = asyncio.create_task(_generative_audio_loop())
        self._background_tasks.append(self._generative_audio_task)

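A note on the timestamp scheme used by the two loops: video pts counts frames against a time base of 1/fps, while audio pts counts samples against a time base of 1/sample_rate (hence the `pts += audio_frame.nb_samples` increment), so both convert to seconds the same way:

    from fractions import Fraction

    fps, sample_rate = 30, 48000
    video_pts, audio_pts = 90, 144000  # 90 frames vs. 144000 samples

    # pts * time_base = presentation time in seconds
    print(float(video_pts * Fraction(1, fps)))          # 3.0
    print(float(audio_pts * Fraction(1, sample_rate)))  # 3.0
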
    async def _update_generative_forwarders(self) -> None:
        """Start or stop generative forwarders based on workflow capabilities."""
        if not self.pipeline or not self._stream_processor:
            return

        capabilities = self.pipeline.get_workflow_io_capabilities()
        video_only_output = capabilities["video"]["output"] and not capabilities["video"]["input"]
        audio_only_output = capabilities["audio"]["output"] and not capabilities["audio"]["input"]

        if video_only_output:
            self._start_generative_video_forwarder()
        else:
            await self._stop_generative_video_forwarder()

        if audio_only_output:
            self._start_generative_audio_forwarder()
        else:
            await self._stop_generative_audio_forwarder()

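For context, the checks above only rely on nested input/output booleans per track, i.e. a shape like the following ("generative" means the workflow produces a track without consuming one; the values here are hypothetical, the real dict comes from the pipeline's workflow inspection):

    capabilities = {
        "video": {"input": False, "output": True},  # generative video: output, no input
        "audio": {"input": True, "output": True},   # audio pass-through: not generative
    }
    video_only_output = capabilities["video"]["output"] and not capabilities["video"]["input"]
    audio_only_output = capabilities["audio"]["output"] and not capabilities["audio"]["input"]
    # -> video forwarder started, audio forwarder stopped
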
    async def on_stream_stop(self):
        """Called when stream stops - cleanup background tasks."""
        logger.info("Stream stopped, cleaning up background tasks")
@@ -179,6 +345,9 @@ async def on_stream_stop(self):
        except Exception as e:
            logger.error(f"Error stopping ComfyStream client: {e}")

        # Stop generative forwarders
        await self._stop_generative_forwarders()

        # Stop text forwarder
        await self._stop_text_forwarder()

@@ -406,21 +575,18 @@ async def _process_prompts(self, prompts, *, skip_warmup: bool = False):
        try:
            converted = convert_prompt(prompts, return_dict=True)

-            await self.pipeline.apply_prompts(
-                [converted],
-                skip_warmup=skip_warmup,
-            )
-
-            if self.pipeline.state_manager.can_stream():
-                await self.pipeline.start_streaming()
-
-            logger.info(f"Prompts applied successfully: {list(prompts.keys())}")
+            # Set prompts in pipeline
+            await self.pipeline.set_prompts([converted])
+            await self.pipeline.resume_prompts()
+            logger.info(f"Prompts set successfully: {list(prompts.keys())}")

            if self.pipeline.produces_text_output():
                self._setup_text_monitoring()
            else:
                await self._stop_text_forwarder()

+            await self._update_generative_forwarders()
+
        except Exception as e:
            logger.error(f"Failed to process prompts: {e}")
            raise
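
On the warmup sentinel injected by the byoc.py handler: update_params itself is outside this diff, so the following is only a hedged sketch of how the flag might be consumed on the model thread (the method body and `self.pipeline.warmup()` are hypothetical, not the PR's actual implementation):

    async def update_params(self, params: dict):
        # Pop the sentinel so it never reaches the prompt-conversion path.
        warmup = bool(params.pop("warmup", False))
        if "prompts" in params:
            await self._process_prompts(params["prompts"], skip_warmup=warmup)
        if warmup and self.pipeline:
            # Fire-and-forget, matching the handler's expectations.
            asyncio.create_task(self.pipeline.warmup())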