6 changes: 5 additions & 1 deletion docker/entrypoint.sh
@@ -99,7 +99,10 @@ if [ "$1" = "--build-engines" ]; then

# Build Static Engine for Dreamshaper - Landscape (704x384)
python src/comfystream/scripts/build_trt.py --model /workspace/ComfyUI/models/unet/dreamshaper-8-dmd-1kstep.safetensors --out-engine /workspace/ComfyUI/output/tensorrt/static-dreamshaper8_SD15_\$stat-b-1-h-384-w-704_00001_.engine --width 704 --height 384


# Build Static Engine for Dreamshaper - Square (512x512) - Batch Size 2
python src/comfystream/scripts/build_trt.py --model /workspace/ComfyUI/models/unet/dreamshaper-8-dmd-1kstep.safetensors --out-engine /workspace/ComfyUI/output/tensorrt/static-dreamshaper8_SD15_\$stat-b-2-h-512-w-512_00001_.engine --width 512 --height 512 --batch-size 2

# Build Dynamic Engine for Dreamshaper
python src/comfystream/scripts/build_trt.py \
--model /workspace/ComfyUI/models/unet/dreamshaper-8-dmd-1kstep.safetensors \
@@ -110,6 +113,7 @@ if [ "$1" = "--build-engines" ]; then
--min-height 512 \
--max-width 448 \
--max-height 704


# Build Engine for Depth Anything V2
if [ ! -f "$DEPTH_ANYTHING_DIR/$DEPTH_ANYTHING_ENGINE" ]; then
3 changes: 3 additions & 0 deletions nodes/tensor_utils/__init__.py
@@ -3,11 +3,14 @@
from .load_tensor import LoadTensor
from .save_tensor import SaveTensor
from .save_text_tensor import SaveTextTensor
from .performance_nodes import PerformanceTimerNode, StartPerformanceTimerNode

NODE_CLASS_MAPPINGS = {
"LoadTensor": LoadTensor,
"SaveTensor": SaveTensor,
"SaveTextTensor": SaveTextTensor,
"PerformanceTimerNode": PerformanceTimerNode,
"StartPerformanceTimerNode": StartPerformanceTimerNode,
}
NODE_DISPLAY_NAME_MAPPINGS = {}

58 changes: 51 additions & 7 deletions nodes/tensor_utils/load_tensor.py
@@ -13,6 +13,9 @@ class LoadTensor:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"batch_size": ("INT", {"default": 1, "min": 1, "max": 8, "step": 1}),
},
"optional": {
"timeout_seconds": (
"FLOAT",
@@ -31,10 +34,51 @@ def INPUT_TYPES(cls):
def IS_CHANGED(cls, **kwargs):
return float("nan")

def execute(self, timeout_seconds: float = 1.0):
try:
frame = tensor_cache.image_inputs.get(block=True, timeout=timeout_seconds)
frame.side_data.skipped = False
return (frame.side_data.input,)
except queue.Empty:
raise ComfyStreamInputTimeoutError("video", timeout_seconds)
def execute(self, batch_size: int = 1, timeout_seconds: float = 1.0):
"""
Load tensor(s) from the tensor cache.
If batch_size > 1, loads multiple tensors and stacks them into a batch.
"""
if batch_size == 1:
# Single tensor loading with timeout
try:
frame = tensor_cache.image_inputs.get(block=True, timeout=timeout_seconds)
frame.side_data.skipped = False
return (frame.side_data.input,)
except queue.Empty:
raise ComfyStreamInputTimeoutError("video", timeout_seconds)
else:
# Batch tensor loading - collect whatever real frames are available and pad the rest
batch_images = []

# Collect images up to batch_size, but only use real frames
for i in range(batch_size):
if not tensor_cache.image_inputs.empty():
try:
frame = tensor_cache.image_inputs.get(block=True, timeout=timeout_seconds)
frame.side_data.skipped = False
batch_images.append(frame.side_data.input)
except queue.Empty:
# If timeout occurs, stop collecting and use what we have
break
else:
# If queue is empty, stop collecting
break

# Only proceed if we have at least one real frame
if not batch_images:
# No frames available - raise timeout error instead of creating dummy
raise ComfyStreamInputTimeoutError("video", timeout_seconds)

# If we have fewer frames than requested, pad with the last available frame
# This is better than dummy tensors as it maintains visual continuity
while len(batch_images) < batch_size:
batch_images.append(batch_images[-1])

# Stack images into a batch
if len(batch_images) > 1:
batch_tensor = torch.cat(batch_images, dim=0)
else:
batch_tensor = batch_images[0]

return (batch_tensor,)
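
For reference, a minimal standalone sketch of the pad-and-stack behavior above, assuming NHWC IMAGE tensors of shape [1, H, W, C] and that torch is imported at the top of load_tensor.py (the import sits outside this hunk). pad_and_stack is a hypothetical helper, not part of the PR:

import torch

def pad_and_stack(batch_images: list, batch_size: int) -> torch.Tensor:
    """Mirror LoadTensor's batching: pad with the last frame, then concat."""
    if not batch_images:
        raise RuntimeError("no frames available")  # the node raises ComfyStreamInputTimeoutError here
    while len(batch_images) < batch_size:
        batch_images.append(batch_images[-1])  # repeat the last frame for visual continuity
    return torch.cat(batch_images, dim=0) if len(batch_images) > 1 else batch_images[0]

# Two real frames requested as a batch of 4 -> the last frame is repeated twice.
frames = [torch.rand(1, 384, 704, 3), torch.rand(1, 384, 704, 3)]
print(pad_and_stack(frames, 4).shape)  # torch.Size([4, 384, 704, 3])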
70 changes: 70 additions & 0 deletions nodes/tensor_utils/performance_nodes.py
@@ -0,0 +1,70 @@
"""
Performance measurement nodes for ComfyStream batch processing.
These nodes integrate with the existing tensor_utils structure.
"""

from comfystream.utils import performance_timer


class PerformanceTimerNode:
CATEGORY = "tensor_utils"
RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("performance_summary",)
FUNCTION = "execute"

@classmethod
def INPUT_TYPES(s):
return {
"required": {
"operation": ("STRING", {"default": "workflow_execution"}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 8, "step": 1}),
"num_images": ("INT", {"default": 1, "min": 1, "max": 100, "step": 1}),
}
}

@classmethod
def IS_CHANGED(s, **kwargs):
return float("nan")

def execute(self, operation: str, batch_size: int, num_images: int):
"""Record performance metrics and return summary."""
performance_timer.record_batch_processing(batch_size, num_images)
performance_timer.end_timing(operation)

summary = performance_timer.get_performance_summary()

# Format summary as readable string
summary_str = f"Performance Summary:\n"
summary_str += f"Total Images Processed: {summary['total_images_processed']}\n"
summary_str += f"Total FPS: {summary['total_fps']:.2f}\n"
summary_str += f"Average Batch Size: {summary['average_batch_size']:.2f}\n"

for key, value in summary.items():
if key not in ["total_images_processed", "total_fps", "average_batch_size"]:
summary_str += f"{key}: {value:.4f}\n"

return (summary_str,)


class StartPerformanceTimerNode:
CATEGORY = "tensor_utils"
RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("timer_started",)
FUNCTION = "execute"

@classmethod
def INPUT_TYPES(s):
return {
"required": {
"operation": ("STRING", {"default": "workflow_execution"}),
}
}

@classmethod
def IS_CHANGED(s, **kwargs):
return float("nan")

def execute(self, operation: str):
"""Start timing an operation."""
performance_timer.start_timing(operation)
return (f"Started timing: {operation}",)
19 changes: 17 additions & 2 deletions nodes/tensor_utils/save_tensor.py
@@ -14,13 +14,28 @@ def INPUT_TYPES(s):
return {
"required": {
"images": ("IMAGE",),
},
"optional": {
"split_batch": ("BOOLEAN", {"default": False}),
}
}

@classmethod
def IS_CHANGED(s):
return float("nan")

def execute(self, images: torch.Tensor):
tensor_cache.image_outputs.put_nowait(images)
def execute(self, images: torch.Tensor, split_batch: bool = False):
"""
Save tensor(s) to the tensor cache.
If split_batch is True and images is a batch, splits it into individual images.
"""
if split_batch and images.dim() == 4 and images.shape[0] > 1:
# Split batch into individual images
for i in range(images.shape[0]):
single_image = images[i:i+1] # Keep batch dimension
tensor_cache.image_outputs.put_nowait(single_image)
else:
# Save as single tensor (original behavior)
tensor_cache.image_outputs.put_nowait(images)

return images
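
A quick illustration of the images[i:i+1] slicing used above — unlike images[i], it keeps the leading batch dimension, so each queued output stays a [1, H, W, C] IMAGE tensor:

import torch

batch = torch.rand(2, 512, 512, 3)  # a batch of two IMAGE tensors
print(batch[0].shape)               # torch.Size([512, 512, 3]) -- batch dim dropped
print(batch[0:1].shape)             # torch.Size([1, 512, 512, 3]) -- batch dim kept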
10 changes: 8 additions & 2 deletions src/comfystream/modalities.py
@@ -14,6 +14,8 @@ class WorkflowModality(TypedDict):
video: ModalityIO
audio: ModalityIO
text: ModalityIO
# Batch processing information
max_batch_size: int


# Centralized node type definitions
@@ -101,8 +103,12 @@ def create_empty_workflow_modality() -> WorkflowModality:
def _merge_workflow_modalities(base: WorkflowModality, other: WorkflowModality) -> WorkflowModality:
"""Merge two WorkflowModality objects using logical OR for all capabilities."""
for modality in base:
for direction in base[modality]:
base[modality][direction] = base[modality][direction] or other[modality][direction]
if modality == "max_batch_size":
# For batch size, take the maximum
base[modality] = max(base[modality], other[modality])
elif isinstance(base[modality], dict):
for direction in base[modality]:
base[modality][direction] = base[modality][direction] or other[modality][direction]
return base


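A small sketch of how the updated merge behaves. The direction keys ("input"/"output") and the starting max_batch_size of 1 are assumptions based on ModalityIO and create_empty_workflow_modality(), both outside this hunk:

# Hypothetical inputs; the real dicts come from per-node workflow analysis.
base = {"video": {"input": True, "output": False}, "max_batch_size": 2}
other = {"video": {"input": False, "output": True}, "max_batch_size": 4}

for modality in base:
    if modality == "max_batch_size":
        base[modality] = max(base[modality], other[modality])
    elif isinstance(base[modality], dict):
        for direction in base[modality]:
            base[modality][direction] = base[modality][direction] or other[modality][direction]

print(base)  # {'video': {'input': True, 'output': True}, 'max_batch_size': 4}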
51 changes: 49 additions & 2 deletions src/comfystream/pipeline.py
@@ -49,15 +49,56 @@ def __init__(
self.width = width
self.height = height

self.video_incoming_frames = asyncio.Queue()
self.audio_incoming_frames = asyncio.Queue()
# Initialize queues with default size (will be updated based on workflow analysis)
self._default_queue_size = 10 # Default queue size
self.video_incoming_frames = asyncio.Queue(maxsize=self._default_queue_size)
self.audio_incoming_frames = asyncio.Queue(maxsize=self._default_queue_size)

self.processed_audio_buffer = np.array([], dtype=np.int16)

self._comfyui_inference_log_level = comfyui_inference_log_level
self._cached_modalities: Optional[Set[str]] = None
self._cached_io_capabilities: Optional[WorkflowModality] = None

def _recreate_queues(self, batch_size: int):
"""Recreate queues with new size limits to optimize for batch processing.

Args:
batch_size: Detected workflow batch size; queue capacity is derived from it
"""
# Size queues at twice the batch size, but never below the default, to leave headroom for frame skipping
optimal_size = max(batch_size * 2, self._default_queue_size)

logger.info(f"Recreating queues with size {optimal_size} (batch_size: {batch_size})")

# Create new queues with the calculated size
self.video_incoming_frames = asyncio.Queue(maxsize=optimal_size)
self.audio_incoming_frames = asyncio.Queue(maxsize=optimal_size)

def _update_queue_sizes_for_batch_processing(self):
"""Update queue sizes based on detected batch requirements from workflow analysis."""
if not hasattr(self.client, 'current_prompts') or not self.client.current_prompts:
logger.debug("No prompts available for batch size analysis")
return

try:
# Get workflow I/O capabilities (which now includes batch_size)
io_capabilities = self.get_workflow_io_capabilities()
detected_batch_size = io_capabilities.get("max_batch_size", 1)

# Only recreate queues if the derived optimal size differs from the current one
current_queue_size = self.video_incoming_frames.maxsize
optimal_size = max(detected_batch_size * 2, self._default_queue_size)

if optimal_size != current_queue_size:
logger.info(f"Detected batch_size {detected_batch_size}, updating queue size from {current_queue_size} to {optimal_size}")
self._recreate_queues(detected_batch_size)
else:
logger.debug(f"Queue size already optimal for batch_size {detected_batch_size}")

except Exception as e:
logger.warning(f"Failed to update queue sizes based on batch analysis: {e}")

async def warm_video(self):
"""Warm up the video processing pipeline with dummy frames."""
# Only warm if workflow accepts video input
@@ -120,6 +161,9 @@ async def set_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]
# Clear cached modalities and I/O capabilities when prompts change
self._cached_modalities = None
self._cached_io_capabilities = None

# Update queue sizes based on detected batch requirements
self._update_queue_sizes_for_batch_processing()

async def update_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]):
"""Update the existing processing prompts.
@@ -135,6 +179,9 @@ async def update_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any
# Clear cached modalities and I/O capabilities when prompts change
self._cached_modalities = None
self._cached_io_capabilities = None

# Update queue sizes based on detected batch requirements
self._update_queue_sizes_for_batch_processing()

async def put_video_frame(self, frame: av.VideoFrame):
"""Queue a video frame for processing.
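One behavioral note on _recreate_queues: replacing an asyncio.Queue discards any frames still buffered in the old queue. A hedged sketch of a drain-and-carry-over variant — purely illustrative, not what this PR implements:

import asyncio

def recreate_queue_preserving_frames(old_q: asyncio.Queue, batch_size: int,
                                     default_size: int = 10) -> asyncio.Queue:
    """Size the new queue the way _recreate_queues does, but move any
    already-buffered frames across instead of dropping them."""
    optimal = max(batch_size * 2, default_size)
    new_q = asyncio.Queue(maxsize=optimal)
    while not old_q.empty() and not new_q.full():
        new_q.put_nowait(old_q.get_nowait())  # carry buffered frames over
    return new_q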