diff --git a/docker/Dockerfile b/docker/Dockerfile
index 8eb54e61..3a5915a6 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -1,4 +1,4 @@
-ARG BASE_IMAGE=livepeer/comfyui-base:latest
+ARG BASE_IMAGE=livepeer/comfyui-base:pr-499
 
 FROM ${BASE_IMAGE}
 
diff --git a/server/byoc.py b/server/byoc.py
index 3f8f3470..ef8f7eb9 100644
--- a/server/byoc.py
+++ b/server/byoc.py
@@ -187,6 +187,51 @@ async def register_orchestrator_startup(app):
 
     # Add registration to startup hooks
     processor.server.app.on_startup.append(register_orchestrator_startup)
 
+    # Add warmup endpoint: accepts the same body as a prompts update
+    async def warmup_handler(request):
+        try:
+            body = await request.json()
+        except Exception as e:
+            logger.error(f"Invalid JSON in warmup request: {e}")
+            return web.json_response({"error": "Invalid JSON"}, status=400)
+        try:
+            # Inject sentinel to trigger warmup inside update_params on the model thread
+            if isinstance(body, dict):
+                body["warmup"] = True
+            else:
+                body = {"warmup": True}
+            # Fire-and-forget: do not await warmup; update_params will schedule it
+            asyncio.get_running_loop().create_task(frame_processor.update_params(body))
+            return web.json_response({"status": "accepted"})
+        except Exception as e:
+            logger.error(f"Warmup failed: {e}")
+            return web.json_response({"error": str(e)}, status=500)
+
+    # Add pause endpoint
+    async def pause_handler(request):
+        try:
+            # Fire-and-forget: do not await pause
+            asyncio.get_running_loop().create_task(frame_processor.pause_prompts())
+            return web.json_response({"status": "paused"})
+        except Exception as e:
+            logger.error(f"Pause failed: {e}")
+            return web.json_response({"error": str(e)}, status=500)
+
+    # Add resume endpoint
+    async def resume_handler(request):
+        try:
+            # Fire-and-forget: do not await resume
+            asyncio.get_running_loop().create_task(frame_processor.resume_prompts())
+            return web.json_response({"status": "resumed"})
+        except Exception as e:
+            logger.error(f"Resume failed: {e}")
+            return web.json_response({"error": str(e)}, status=500)
+
+    # Mount in the same API namespace as the StreamProcessor defaults.
+    # Register all three handlers; pause/resume were defined above and must
+    # be mounted as well, or they are dead code.
+    processor.server.add_route("POST", "/api/stream/warmup", warmup_handler)
+    processor.server.add_route("POST", "/api/stream/pause", pause_handler)
+    processor.server.add_route("POST", "/api/stream/resume", resume_handler)
+
     # Run the processor
     processor.run()
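All three handlers above are fire-and-forget by design: each schedules work on the event loop and returns immediately, so a 2xx response means "accepted", not "completed". A minimal client-side sketch of exercising them; the host/port and the shape of the prompts-update body are assumptions, as neither is shown in this diff:

```python
import asyncio

import aiohttp

BASE = "http://localhost:8889"  # hypothetical; depends on how StreamProcessor is configured


async def main():
    async with aiohttp.ClientSession() as session:
        # Warmup takes the same JSON body as a prompts update; the handler
        # injects {"warmup": True} before forwarding to update_params.
        # The {"prompts": ...} wrapper is an assumed body shape.
        async with session.post(f"{BASE}/api/stream/warmup", json={"prompts": {}}) as resp:
            print(resp.status, await resp.json())  # expect {"status": "accepted"}

        # Pause/resume ignore the request body entirely.
        async with session.post(f"{BASE}/api/stream/pause") as resp:
            print(await resp.json())  # {"status": "paused"}
        async with session.post(f"{BASE}/api/stream/resume") as resp:
            print(await resp.json())  # {"status": "resumed"}


asyncio.run(main())
```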
diff --git a/server/frame_processor.py b/server/frame_processor.py
index af927935..1f03c0d4 100644
--- a/server/frame_processor.py
+++ b/server/frame_processor.py
@@ -2,8 +2,11 @@
 import json
 import logging
 import os
+from fractions import Fraction
 from typing import Any, Dict, List, Optional, Union
 
+import numpy as np
+import torch
 from pytrickle.frame_processor import FrameProcessor
 from pytrickle.frames import AudioFrame, VideoFrame
 from pytrickle.stream_processor import VideoProcessingResult
@@ -40,6 +43,10 @@ def __init__(self, text_poll_interval: float = 0.25, **load_params):
         self._text_forward_task = None
         self._background_tasks = []
         self._stop_event = asyncio.Event()
+        self._generative_video_task = None
+        self._generative_audio_task = None
+        self._generative_video_pts = 0
+        self._generative_audio_pts = 0
 
     async def _apply_stream_start_prompt(self, prompt_value: Any) -> bool:
         if not self.pipeline:
@@ -164,6 +171,165 @@ async def _stop_text_forwarder(self) -> None:
                 logger.debug("Error while awaiting text forwarder cancellation", exc_info=True)
             self._text_forward_task = None
 
+    async def _stop_generative_video_forwarder(self) -> None:
+        """Stop the background generative video task if running."""
+        task = self._generative_video_task
+        if task and not task.done():
+            try:
+                task.cancel()
+                await task
+            except asyncio.CancelledError:
+                pass
+            except Exception:
+                logger.debug("Error while awaiting generative video cancellation", exc_info=True)
+        self._generative_video_task = None
+        try:
+            self._background_tasks.remove(task)
+        except (ValueError, TypeError):
+            pass
+
+    async def _stop_generative_audio_forwarder(self) -> None:
+        """Stop the background generative audio task if running."""
+        task = self._generative_audio_task
+        if task and not task.done():
+            try:
+                task.cancel()
+                await task
+            except asyncio.CancelledError:
+                pass
+            except Exception:
+                logger.debug("Error while awaiting generative audio cancellation", exc_info=True)
+        self._generative_audio_task = None
+        try:
+            self._background_tasks.remove(task)
+        except (ValueError, TypeError):
+            pass
+
+    async def _stop_generative_forwarders(self) -> None:
+        """Stop all generative forwarder tasks."""
+        await self._stop_generative_video_forwarder()
+        await self._stop_generative_audio_forwarder()
+
+    def _start_generative_video_forwarder(self) -> None:
+        """Start the generative video forwarder if needed."""
+        if self._generative_video_task and not self._generative_video_task.done():
+            return
+        if not self.pipeline or not self._stream_processor:
+            return
+
+        async def _generative_video_loop():
+            logger.info("Starting generative video forwarder task")
+            fps = getattr(self.pipeline, "frame_rate", None) or 30
+            time_base = Fraction(1, int(fps))
+            pts = self._generative_video_pts
+            try:
+                while not self._stop_event.is_set():
+                    try:
+                        out_tensor = await self.pipeline.client.get_video_output()
+                    except asyncio.CancelledError:
+                        raise
+                    except Exception as exc:
+                        logger.error(f"Failed to retrieve generative video output: {exc}")
+                        await asyncio.sleep(0.1)
+                        continue
+
+                    if out_tensor is None:
+                        await asyncio.sleep(0.01)
+                        continue
+
+                    processed_frame = self.pipeline.video_postprocess(out_tensor)
+                    processed_frame.pts = pts
+                    processed_frame.time_base = time_base
+
+                    frame_np = processed_frame.to_ndarray(format="rgb24").astype(np.float32) / 255.0
+                    tensor = torch.from_numpy(frame_np)
+                    video_frame = VideoFrame.from_tensor(tensor, timestamp=pts)
+                    video_frame.time_base = time_base
+
+                    success = await self._stream_processor.send_input_frame(video_frame)
+                    if not success:
+                        await asyncio.sleep(0.05)
+                    pts += 1
+                    self._generative_video_pts = pts
+            except asyncio.CancelledError:
+                logger.debug("Generative video forwarder cancelled")
+                raise
+            except Exception as exc:
+                logger.error(f"Generative video forwarder encountered an error: {exc}")
+            finally:
+                logger.info("Generative video forwarder task exiting")
+
+        self._generative_video_task = asyncio.create_task(_generative_video_loop())
+        self._background_tasks.append(self._generative_video_task)
+
+    def _start_generative_audio_forwarder(self) -> None:
+        """Start the generative audio forwarder if needed."""
+        if self._generative_audio_task and not self._generative_audio_task.done():
+            return
+        if not self.pipeline or not self._stream_processor:
+            return
+
+        async def _generative_audio_loop():
+            logger.info("Starting generative audio forwarder task")
+            sample_rate = 48000
+            time_base = Fraction(1, sample_rate)
+            pts = self._generative_audio_pts
+            try:
+                while not self._stop_event.is_set():
+                    try:
+                        out_audio = await self.pipeline.client.get_audio_output()
+                    except asyncio.CancelledError:
+                        raise
+                    except Exception as exc:
+                        logger.error(f"Failed to retrieve generative audio output: {exc}")
+                        await asyncio.sleep(0.1)
+                        continue
+
+                    if out_audio is None:
+                        await asyncio.sleep(0.01)
+                        continue
+
+                    processed_frame = self.pipeline.audio_postprocess(out_audio)
+                    processed_frame.pts = pts
+                    processed_frame.time_base = time_base
+                    processed_frame.sample_rate = sample_rate
+
+                    audio_frame = AudioFrame.from_av_audio(processed_frame)
+                    success = await self._stream_processor.send_input_frame(audio_frame)
+                    if not success:
+                        await asyncio.sleep(0.05)
+                    pts += audio_frame.nb_samples
+                    self._generative_audio_pts = pts
+            except asyncio.CancelledError:
+                logger.debug("Generative audio forwarder cancelled")
+                raise
+            except Exception as exc:
+                logger.error(f"Generative audio forwarder encountered an error: {exc}")
+            finally:
+                logger.info("Generative audio forwarder task exiting")
+
+        self._generative_audio_task = asyncio.create_task(_generative_audio_loop())
+        self._background_tasks.append(self._generative_audio_task)
+
+    async def _update_generative_forwarders(self) -> None:
+        """Start or stop generative forwarders based on workflow capabilities."""
+        if not self.pipeline or not self._stream_processor:
+            return
+
+        capabilities = self.pipeline.get_workflow_io_capabilities()
+        video_only_output = capabilities["video"]["output"] and not capabilities["video"]["input"]
+        audio_only_output = capabilities["audio"]["output"] and not capabilities["audio"]["input"]
+
+        if video_only_output:
+            self._start_generative_video_forwarder()
+        else:
+            await self._stop_generative_video_forwarder()
+
+        if audio_only_output:
+            self._start_generative_audio_forwarder()
+        else:
+            await self._stop_generative_audio_forwarder()
+
     async def on_stream_stop(self):
         """Called when stream stops - cleanup background tasks."""
         logger.info("Stream stopped, cleaning up background tasks")
@@ -179,6 +345,9 @@ async def on_stream_stop(self):
         except Exception as e:
             logger.error(f"Error stopping ComfyStream client: {e}")
 
+        # Stop generative forwarders
+        await self._stop_generative_forwarders()
+
         # Stop text forwarder
         await self._stop_text_forwarder()
 
@@ -406,21 +575,18 @@ async def _process_prompts(self, prompts, *, skip_warmup: bool = False):
 
         try:
             converted = convert_prompt(prompts, return_dict=True)
-            await self.pipeline.apply_prompts(
-                [converted],
-                skip_warmup=skip_warmup,
-            )
-
-            if self.pipeline.state_manager.can_stream():
-                await self.pipeline.start_streaming()
-
-            logger.info(f"Prompts applied successfully: {list(prompts.keys())}")
+            # Set prompts in pipeline
+            await self.pipeline.set_prompts([converted])
+            await self.pipeline.resume_prompts()
+            logger.info(f"Prompts set successfully: {list(prompts.keys())}")
 
             if self.pipeline.produces_text_output():
                 self._setup_text_monitoring()
             else:
                 await self._stop_text_forwarder()
 
+            await self._update_generative_forwarders()
+
         except Exception as e:
             logger.error(f"Failed to process prompts: {e}")
             raise
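The forwarder gating in `_update_generative_forwarders` hinges on `Pipeline.get_workflow_io_capabilities()`, whose return shape can be inferred from the lookups above: a per-modality dict of input/output flags. A sketch of the decision with an assumed value for the text-to-image workflow added below; the exact dict, and whether PreviewImage counts as video output, are assumptions not shown in this diff:

```python
# Shape inferred from the capabilities["video"]["output"] lookups above; the
# real return value of get_workflow_io_capabilities() is defined elsewhere.
capabilities = {
    "video": {"input": False, "output": True},   # assumed: PreviewImage registers as video output
    "audio": {"input": False, "output": False},
    "text": {"input": True, "output": False},    # PrimitiveString feeds the renderer
}

# "Generative" here means output with no matching input: nothing drives the
# workflow per incoming frame, so a background loop must pull outputs instead.
video_only_output = capabilities["video"]["output"] and not capabilities["video"]["input"]
audio_only_output = capabilities["audio"]["output"] and not capabilities["audio"]["input"]

print(video_only_output, audio_only_output)  # True False -> start only the video forwarder
```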
diff --git a/src/comfystream/client.py b/src/comfystream/client.py
index de543244..d18ddbb1 100644
--- a/src/comfystream/client.py
+++ b/src/comfystream/client.py
@@ -21,11 +21,9 @@ def __init__(self, max_workers: int = 1, **kwargs):
         self.current_prompts = []
         self._cleanup_lock = asyncio.Lock()
         self._prompt_update_lock = asyncio.Lock()
-
-        # PromptRunner state
-        self._shutdown_event = asyncio.Event()
-        self._run_enabled_event = asyncio.Event()
-        self._runner_task = None
+        self._stop_event = asyncio.Event()
+        self._pause_event = asyncio.Event()  # Separate event for pause/resume
+        self._pause_event.set()  # Start in "not paused" state (event is set)
 
     async def set_prompts(self, prompts: List[PromptDictInput]):
         """Set new prompts, replacing any existing ones.
@@ -40,9 +38,12 @@ async def set_prompts(self, prompts: List[PromptDictInput]):
         if not prompts:
             raise ValueError("Cannot set empty prompts list")
 
-        # Pause runner while swapping prompts to avoid interleaving
-        was_running = self._run_enabled_event.is_set()
-        self._run_enabled_event.clear()
+        # Cancel existing prompts first to avoid conflicts
+        await self.cancel_running_prompts()
+        # Reset stop event for new prompts
+        self._stop_event.clear()
+        # Ensure prompts start unpaused
+        self._pause_event.set()
         self.current_prompts = [convert_prompt(prompt) for prompt in prompts]
         logger.info(f"Configured {len(self.current_prompts)} prompt(s)")
         # Ensure runner exists (IDLE until resumed)
@@ -67,42 +68,29 @@ async def update_prompts(self, prompts: List[PromptDictInput]):
         except Exception as e:
             raise Exception(f"Prompt update failed: {str(e)}") from e
 
-    async def ensure_prompt_tasks_running(self):
-        # Ensure the single runner task exists (does not force running)
-        if self._runner_task and not self._runner_task.done():
-            return
-        if not self.current_prompts:
-            return
-        self._shutdown_event.clear()
-        self._runner_task = asyncio.create_task(self._runner_loop())
-
-    async def _runner_loop(self):
-        try:
-            while not self._shutdown_event.is_set():
-                # IDLE until running is enabled
-                await self._run_enabled_event.wait()
-                # Snapshot prompts without holding the lock during network I/O
-                async with self._prompt_update_lock:
-                    prompts_snapshot = list(self.current_prompts)
-                for prompt_index, prompt in enumerate(prompts_snapshot):
-                    if self._shutdown_event.is_set() or not self._run_enabled_event.is_set():
-                        break
-                    try:
-                        await self.comfy_client.queue_prompt(prompt)
-                    except asyncio.CancelledError:
-                        raise
-                    except ComfyStreamInputTimeoutError:
-                        logger.info(f"Input for prompt {prompt_index} timed out, continuing")
-                        continue
-                    except Exception as e:
-                        logger.error(f"Error running prompt: {str(e)}")
-                        logger.info("Stopping prompt execution and returning to passthrough mode")
-
-                        # Stop running and switch to default passthrough workflow
-                        await self._fallback_to_passthrough()
-                        break
-        except asyncio.CancelledError:
-            pass
+    async def run_prompt(self, prompt_index: int):
+        """Continuously queue the prompt at prompt_index, honoring the pause and stop events."""
+        while not self._stop_event.is_set():
+            # Wait for unpause before continuing
+            await self._pause_event.wait()
+
+            # Check stop event again after waking from pause
+            if self._stop_event.is_set():
+                break
+
+            async with self._prompt_update_lock:
+                try:
+                    await self.comfy_client.queue_prompt(self.current_prompts[prompt_index])
+                except asyncio.CancelledError:
+                    raise
+                except ComfyStreamInputTimeoutError:
+                    # Timeout errors are expected during stream switching - just continue
+                    logger.info(f"Input for prompt {prompt_index} timed out, continuing")
+                    continue
+                except Exception as e:
+                    # Log the failure before cleanup, so the root cause is not
+                    # obscured if cleanup itself raises
+                    logger.error(f"Error running prompt: {str(e)}")
+                    await self.cleanup()
+                    raise
 
     async def cleanup(self):
         # Signal runner to shutdown
@@ -126,32 +114,66 @@ async def cleanup(self):
         await self.cleanup_queues()
         logger.info("Client cleanup complete")
 
-    def pause_prompts(self):
-        """Pause prompt execution loops without canceling underlying tasks."""
-        self._run_enabled_event.clear()
-        logger.debug("Prompt execution paused")
+    async def pause_prompts(self):
+        """Pause prompt execution loops without canceling tasks.
+
+        Clears the pause event, causing prompt loops to wait. Prompts remain
+        in memory and tasks keep running; they simply wait before queueing
+        new prompts. Can be resumed with resume_prompts().
+        """
+        self._pause_event.clear()
+        logger.info("Prompts paused")
 
     async def resume_prompts(self):
-        """Resume prompt execution loops."""
-        await self.ensure_prompt_tasks_running()
-        self._run_enabled_event.set()
-        logger.debug("Prompt execution resumed")
+        """Resume paused prompt execution loops.
+
+        Sets the pause event, allowing prompt loops to continue running.
+        If prompts are not currently running, this has no effect until
+        prompts are set via set_prompts().
+        """
+        self._pause_event.set()
+        logger.info("Prompts resumed")
 
     async def stop_prompts(self, cleanup: bool = False):
         """Stop running prompts by canceling their tasks.
-        
+
         Args:
-            cleanup: If True, perform full cleanup including queue clearing and
-                client shutdown. If False, only cancel prompt tasks.
+            cleanup: If True, perform full cleanup including queue clearing
+                and client shutdown. If False, only cancel prompt tasks.
         """
-        await self.stop_prompts_immediately()
-
+        # Set stop event to signal prompt loops to exit
+        self._stop_event.set()
+
+        # Cancel running prompt tasks
+        await self.cancel_running_prompts()
+
         if cleanup:
-            await self.cleanup()
-            logger.info("Prompts stopped with full cleanup")
+            # Perform full cleanup
+            async with self._cleanup_lock:
+                if self.comfy_client.is_running:
+                    try:
+                        await self.comfy_client.__aexit__()
+                    except Exception as e:
+                        logger.error(f"Error during ComfyClient cleanup: {e}")
+
+            await self.cleanup_queues()
+            logger.info("Prompts stopped with full cleanup")
         else:
             logger.debug("Prompts stopped (tasks cancelled)")
 
+    async def cancel_running_prompts(self):
+        """Cancel all tracked prompt tasks and clear the running-task registry."""
+        async with self._cleanup_lock:
+            tasks_to_cancel = list(self.running_prompts.values())
+            for task in tasks_to_cancel:
+                task.cancel()
+                try:
+                    await task
+                except asyncio.CancelledError:
+                    pass
+            self.running_prompts.clear()
+
     async def cleanup_queues(self):
         while not tensor_cache.image_inputs.empty():
             tensor_cache.image_inputs.get()
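The client replaces the old single-runner design with two `asyncio.Event` flags: `_stop_event` ("exit the loop") and `_pause_event` (set means "running", cleared means "paused"), with a stop re-check after every wakeup. This is standard library behavior and can be demonstrated in isolation; the sketch below mirrors the `run_prompt` structure without any ComfyStream dependencies:

```python
import asyncio


class Runner:
    def __init__(self):
        self._stop_event = asyncio.Event()
        self._pause_event = asyncio.Event()
        self._pause_event.set()  # start unpaused, mirroring ComfyStreamClient.__init__

    async def run(self) -> int:
        ticks = 0
        while not self._stop_event.is_set():
            await self._pause_event.wait()    # blocks while paused
            if self._stop_event.is_set():     # re-check after waking, as run_prompt does
                break
            ticks += 1                        # stands in for queue_prompt(...)
            await asyncio.sleep(0.01)
        return ticks


async def main():
    runner = Runner()
    task = asyncio.create_task(runner.run())
    await asyncio.sleep(0.05)
    runner._pause_event.clear()  # pause_prompts(): loop parks on wait()
    await asyncio.sleep(0.05)    # no ticks accumulate while paused
    runner._pause_event.set()    # resume_prompts(): loop continues
    await asyncio.sleep(0.05)
    runner._stop_event.set()     # stop_prompts(): loop exits at the next check
    print(await task)


asyncio.run(main())
```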
diff --git a/src/comfystream/modalities.py b/src/comfystream/modalities.py
index ded16829..9bd9095b 100644
--- a/src/comfystream/modalities.py
+++ b/src/comfystream/modalities.py
@@ -25,7 +25,7 @@ class WorkflowModality(TypedDict):
     "audio_input": {"LoadAudioTensor"},
     "audio_output": {"SaveAudioTensor"},
     # Text nodes
-    "text_input": set(),  # No text input nodes currently
+    "text_input": {"PrimitiveString"},
     "text_output": {"SaveTextTensor"},
 }
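With this one-line change, any workflow containing a `PrimitiveString` node is treated as accepting text input. The detection code is not part of this diff, but for API-format workflows it plausibly reduces to a `class_type` scan over the node map; a hypothetical sketch under that assumption:

```python
# Subset of the mapping above, for illustration only.
NODE_TYPES = {
    "text_input": {"PrimitiveString"},
    "text_output": {"SaveTextTensor"},
}


def has_modality(workflow: dict, key: str) -> bool:
    """Hypothetical helper: true if any node's class_type is in the given set."""
    return any(node.get("class_type") in NODE_TYPES[key] for node in workflow.values())


# The text-tensor-utils example below contains both a PrimitiveString ("1")
# and a SaveTextTensor ("2"), so it now reports text input *and* text output.
```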
+ """ + self._pause_event.clear() + logger.info("Prompts paused") async def resume_prompts(self): - """Resume prompt execution loops.""" - await self.ensure_prompt_tasks_running() - self._run_enabled_event.set() - logger.debug("Prompt execution resumed") + """Resume paused prompt execution loops. + + Sets the pause event, allowing prompt loops to continue running. + If prompts are not currently running, this will have no effect until + prompts are set via set_prompts(). + """ + self._pause_event.set() + logger.info("Prompts resumed") async def stop_prompts(self, cleanup: bool = False): """Stop running prompts by canceling their tasks. - + Args: - cleanup: If True, perform full cleanup including queue clearing and - client shutdown. If False, only cancel prompt tasks. + cleanup: If True, perform full cleanup including queue clearing + and client shutdown. If False, only cancel prompt tasks. """ - await self.stop_prompts_immediately() - + # Set stop event to signal prompt loops to exit + self._stop_event.set() + + # Cancel running prompt tasks + await self.cancel_running_prompts() + if cleanup: - await self.cleanup() - logger.info("Prompts stopped with full cleanup") + # Perform full cleanup + async with self._cleanup_lock: + if self.comfy_client.is_running: + try: + await self.comfy_client.__aexit__() + except Exception as e: + logger.error(f"Error during ComfyClient cleanup: {e}") + + await self.cleanup_queues() + logger.info("Prompts stopped with full cleanup") else: logger.debug("Prompts stopped (tasks cancelled)") + async def cancel_running_prompts(self): + async with self._cleanup_lock: + tasks_to_cancel = list(self.running_prompts.values()) + for task in tasks_to_cancel: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + self.running_prompts.clear() + async def cleanup_queues(self): while not tensor_cache.image_inputs.empty(): tensor_cache.image_inputs.get() diff --git a/src/comfystream/modalities.py b/src/comfystream/modalities.py index ded16829..9bd9095b 100644 --- a/src/comfystream/modalities.py +++ b/src/comfystream/modalities.py @@ -25,7 +25,7 @@ class WorkflowModality(TypedDict): "audio_input": {"LoadAudioTensor"}, "audio_output": {"SaveAudioTensor"}, # Text nodes - "text_input": set(), # No text input nodes currently + "text_input": {"PrimitiveString"}, "text_output": {"SaveTextTensor"}, } diff --git a/src/comfystream/pipeline.py b/src/comfystream/pipeline.py index cf2302de..81749989 100644 --- a/src/comfystream/pipeline.py +++ b/src/comfystream/pipeline.py @@ -489,6 +489,37 @@ async def stop_prompts_immediately(self): except Exception: logger.exception("Failed to ensure READY state during immediate stop") + async def pause_prompts(self): + """Pause prompt execution loops without canceling tasks. + + Prompts remain in memory and can be resumed with resume_prompts(). + """ + await self.client.pause_prompts() + + async def resume_prompts(self): + """Resume paused prompt execution loops. + + If prompts are not currently running, this will have no effect until + prompts are set via set_prompts(). + """ + await self.client.resume_prompts() + + async def stop_prompts(self, cleanup: bool = False): + """Stop running prompts by canceling their tasks. + + Args: + cleanup: If True, perform full cleanup including queue clearing + and client shutdown. If False, only cancel prompt tasks. 
+ """ + await self.client.stop_prompts(cleanup=cleanup) + + # Clear cached modalities and I/O capabilities when prompts are stopped + if cleanup: + self._cached_modalities = None + self._cached_io_capabilities = None + # Clear pipeline queues for full cleanup + await self._clear_pipeline_queues() + async def put_video_frame(self, frame: av.VideoFrame): """Queue a video frame for processing. diff --git a/workflows/comfystream/test generative workflow-2-api-t2i.json b/workflows/comfystream/test generative workflow-2-api-t2i.json new file mode 100644 index 00000000..943b8b5f --- /dev/null +++ b/workflows/comfystream/test generative workflow-2-api-t2i.json @@ -0,0 +1,44 @@ +{ + "1": { + "inputs": { + "value": "gen stream success!" + }, + "class_type": "PrimitiveString", + "_meta": { + "title": "String" + } + }, + "7": { + "inputs": { + "width": 512, + "height": 512, + "font_size": 48, + "font_color": "white", + "background_color": "black", + "x_offset": 0, + "y_offset": 0, + "align": "center", + "wrap_width": 0, + "any": [ + "1", + 0 + ] + }, + "class_type": "TextRenderer", + "_meta": { + "title": "Text Renderer 🕒🅡🅣🅝" + } + }, + "10": { + "inputs": { + "images": [ + "7", + 0 + ] + }, + "class_type": "PreviewImage", + "_meta": { + "title": "Preview Image" + } + } +} \ No newline at end of file diff --git a/workflows/comfystream/text-tensor-utils-example-api.json b/workflows/comfystream/text-tensor-utils-example-api.json new file mode 100644 index 00000000..22d77ef4 --- /dev/null +++ b/workflows/comfystream/text-tensor-utils-example-api.json @@ -0,0 +1,25 @@ +{ + "1": { + "inputs": { + "value": "Hello from ComfyStream!" + }, + "class_type": "PrimitiveString", + "_meta": { + "title": "String Input" + } + }, + "2": { + "inputs": { + "data": [ + "1", + 0 + ], + "remove_linebreaks": true + }, + "class_type": "SaveTextTensor", + "_meta": { + "title": "Save Text Tensor" + } + } +} +