diff --git a/docker/entrypoint.sh b/docker/entrypoint.sh index e6d44463..f850e460 100755 --- a/docker/entrypoint.sh +++ b/docker/entrypoint.sh @@ -99,7 +99,10 @@ if [ "$1" = "--build-engines" ]; then # Build Static Engine for Dreamshaper - Landscape (704x384) python src/comfystream/scripts/build_trt.py --model /workspace/ComfyUI/models/unet/dreamshaper-8-dmd-1kstep.safetensors --out-engine /workspace/ComfyUI/output/tensorrt/static-dreamshaper8_SD15_\$stat-b-1-h-384-w-704_00001_.engine --width 704 --height 384 - + + # Build Static Engine for Dreamshaper - Square (512x512) - Batch Size 2 + python src/comfystream/scripts/build_trt.py --model /workspace/ComfyUI/models/unet/dreamshaper-8-dmd-1kstep.safetensors --out-engine /workspace/ComfyUI/output/tensorrt/static-dreamshaper8_SD15_\$stat-b-2-h-512-w-512_00001_.engine --width 512 --height 512 --batch-size 2 + # Build Dynamic Engine for Dreamshaper python src/comfystream/scripts/build_trt.py \ --model /workspace/ComfyUI/models/unet/dreamshaper-8-dmd-1kstep.safetensors \ @@ -110,6 +113,7 @@ if [ "$1" = "--build-engines" ]; then --min-height 512 \ --max-width 448 \ --max-height 704 + # Build Engine for Depth Anything V2 if [ ! -f "$DEPTH_ANYTHING_DIR/$DEPTH_ANYTHING_ENGINE" ]; then diff --git a/nodes/tensor_utils/__init__.py b/nodes/tensor_utils/__init__.py index eadb5b84..41c05d55 100644 --- a/nodes/tensor_utils/__init__.py +++ b/nodes/tensor_utils/__init__.py @@ -3,11 +3,14 @@ from .load_tensor import LoadTensor from .save_tensor import SaveTensor from .save_text_tensor import SaveTextTensor +from .performance_nodes import PerformanceTimerNode, StartPerformanceTimerNode NODE_CLASS_MAPPINGS = { "LoadTensor": LoadTensor, "SaveTensor": SaveTensor, "SaveTextTensor": SaveTextTensor, + "PerformanceTimerNode": PerformanceTimerNode, + "StartPerformanceTimerNode": StartPerformanceTimerNode, } NODE_DISPLAY_NAME_MAPPINGS = {} diff --git a/nodes/tensor_utils/load_tensor.py b/nodes/tensor_utils/load_tensor.py index a2fb5940..d0919a60 100644 --- a/nodes/tensor_utils/load_tensor.py +++ b/nodes/tensor_utils/load_tensor.py @@ -13,6 +13,9 @@ class LoadTensor: @classmethod def INPUT_TYPES(cls): return { + "required": { + "batch_size": ("INT", {"default": 1, "min": 1, "max": 8, "step": 1}), + }, "optional": { "timeout_seconds": ( "FLOAT", @@ -31,10 +34,51 @@ def INPUT_TYPES(cls): def IS_CHANGED(cls, **kwargs): return float("nan") - def execute(self, timeout_seconds: float = 1.0): - try: - frame = tensor_cache.image_inputs.get(block=True, timeout=timeout_seconds) - frame.side_data.skipped = False - return (frame.side_data.input,) - except queue.Empty: - raise ComfyStreamInputTimeoutError("video", timeout_seconds) + def execute(self, batch_size: int = 1, timeout_seconds: float = 1.0): + """ + Load tensor(s) from the tensor cache. + If batch_size > 1, loads multiple tensors and stacks them into a batch. 
+ """ + if batch_size == 1: + # Single tensor loading with timeout + try: + frame = tensor_cache.image_inputs.get(block=True, timeout=timeout_seconds) + frame.side_data.skipped = False + return (frame.side_data.input,) + except queue.Empty: + raise ComfyStreamInputTimeoutError("video", timeout_seconds) + else: + # Batch tensor loading - only process if we have enough real frames + batch_images = [] + + # Collect images up to batch_size, but only use real frames + for i in range(batch_size): + if not tensor_cache.image_inputs.empty(): + try: + frame = tensor_cache.image_inputs.get(block=True, timeout=timeout_seconds) + frame.side_data.skipped = False + batch_images.append(frame.side_data.input) + except queue.Empty: + # If timeout occurs, stop collecting and use what we have + break + else: + # If queue is empty, stop collecting + break + + # Only proceed if we have at least one real frame + if not batch_images: + # No frames available - raise timeout error instead of creating dummy + raise ComfyStreamInputTimeoutError("video", timeout_seconds) + + # If we have fewer frames than requested, pad with the last available frame + # This is better than dummy tensors as it maintains visual continuity + while len(batch_images) < batch_size: + batch_images.append(batch_images[-1]) + + # Stack images into a batch + if len(batch_images) > 1: + batch_tensor = torch.cat(batch_images, dim=0) + else: + batch_tensor = batch_images[0] + + return (batch_tensor,) diff --git a/nodes/tensor_utils/performance_nodes.py b/nodes/tensor_utils/performance_nodes.py new file mode 100644 index 00000000..74842f90 --- /dev/null +++ b/nodes/tensor_utils/performance_nodes.py @@ -0,0 +1,70 @@ +""" +Performance measurement nodes for ComfyStream batch processing. +These nodes integrate with the existing tensor_utils structure. 
+""" + +from comfystream.utils import performance_timer + + +class PerformanceTimerNode: + CATEGORY = "tensor_utils" + RETURN_TYPES = ("STRING",) + RETURN_NAMES = ("performance_summary",) + FUNCTION = "execute" + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "operation": ("STRING", {"default": "workflow_execution"}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 8, "step": 1}), + "num_images": ("INT", {"default": 1, "min": 1, "max": 100, "step": 1}), + } + } + + @classmethod + def IS_CHANGED(s): + return float("nan") + + def execute(self, operation: str, batch_size: int, num_images: int): + """Record performance metrics and return summary.""" + performance_timer.record_batch_processing(batch_size, num_images) + performance_timer.end_timing(operation) + + summary = performance_timer.get_performance_summary() + + # Format summary as readable string + summary_str = f"Performance Summary:\n" + summary_str += f"Total Images Processed: {summary['total_images_processed']}\n" + summary_str += f"Total FPS: {summary['total_fps']:.2f}\n" + summary_str += f"Average Batch Size: {summary['average_batch_size']:.2f}\n" + + for key, value in summary.items(): + if key not in ["total_images_processed", "total_fps", "average_batch_size"]: + summary_str += f"{key}: {value:.4f}\n" + + return (summary_str,) + + +class StartPerformanceTimerNode: + CATEGORY = "tensor_utils" + RETURN_TYPES = ("STRING",) + RETURN_NAMES = ("timer_started",) + FUNCTION = "execute" + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "operation": ("STRING", {"default": "workflow_execution"}), + } + } + + @classmethod + def IS_CHANGED(s): + return float("nan") + + def execute(self, operation: str): + """Start timing an operation.""" + performance_timer.start_timing(operation) + return (f"Started timing: {operation}",) diff --git a/nodes/tensor_utils/save_tensor.py b/nodes/tensor_utils/save_tensor.py index 3a021aa5..39a49b1f 100644 --- a/nodes/tensor_utils/save_tensor.py +++ b/nodes/tensor_utils/save_tensor.py @@ -14,6 +14,9 @@ def INPUT_TYPES(s): return { "required": { "images": ("IMAGE",), + }, + "optional": { + "split_batch": ("BOOLEAN", {"default": False}), } } @@ -21,6 +24,18 @@ def INPUT_TYPES(s): def IS_CHANGED(s): return float("nan") - def execute(self, images: torch.Tensor): - tensor_cache.image_outputs.put_nowait(images) + def execute(self, images: torch.Tensor, split_batch: bool = False): + """ + Save tensor(s) to the tensor cache. + If split_batch is True and images is a batch, splits it into individual images. 
+ """ + if split_batch and images.dim() == 4 and images.shape[0] > 1: + # Split batch into individual images + for i in range(images.shape[0]): + single_image = images[i:i+1] # Keep batch dimension + tensor_cache.image_outputs.put_nowait(single_image) + else: + # Save as single tensor (original behavior) + tensor_cache.image_outputs.put_nowait(images) + return images diff --git a/src/comfystream/modalities.py b/src/comfystream/modalities.py index ded16829..c5037f01 100644 --- a/src/comfystream/modalities.py +++ b/src/comfystream/modalities.py @@ -14,6 +14,8 @@ class WorkflowModality(TypedDict): video: ModalityIO audio: ModalityIO text: ModalityIO + # Batch processing information + max_batch_size: int # Centralized node type definitions @@ -101,8 +103,12 @@ def create_empty_workflow_modality() -> WorkflowModality: def _merge_workflow_modalities(base: WorkflowModality, other: WorkflowModality) -> WorkflowModality: """Merge two WorkflowModality objects using logical OR for all capabilities.""" for modality in base: - for direction in base[modality]: - base[modality][direction] = base[modality][direction] or other[modality][direction] + if modality == "max_batch_size": + # For batch size, take the maximum + base[modality] = max(base[modality], other[modality]) + elif isinstance(base[modality], dict): + for direction in base[modality]: + base[modality][direction] = base[modality][direction] or other[modality][direction] return base diff --git a/src/comfystream/pipeline.py b/src/comfystream/pipeline.py index b3f9e65c..9c86b9da 100644 --- a/src/comfystream/pipeline.py +++ b/src/comfystream/pipeline.py @@ -49,8 +49,10 @@ def __init__( self.width = width self.height = height - self.video_incoming_frames = asyncio.Queue() - self.audio_incoming_frames = asyncio.Queue() + # Initialize queues with default size (will be updated based on workflow analysis) + self._default_queue_size = 10 # Default queue size + self.video_incoming_frames = asyncio.Queue(maxsize=self._default_queue_size) + self.audio_incoming_frames = asyncio.Queue(maxsize=self._default_queue_size) self.processed_audio_buffer = np.array([], dtype=np.int16) @@ -58,6 +60,45 @@ def __init__( self._cached_modalities: Optional[Set[str]] = None self._cached_io_capabilities: Optional[WorkflowModality] = None + def _recreate_queues(self, new_queue_size: int): + """Recreate queues with new size limits to optimize for batch processing. 
+ + Args: + new_queue_size: Maximum number of frames to store in each queue + """ + # Calculate optimal queue size: batch_size * 2 + some buffer for frame skipping + optimal_size = max(new_queue_size * 2, self._default_queue_size) + + logger.info(f"Recreating queues with size {optimal_size} (batch_size: {new_queue_size})") + + # Create new queues with the calculated size + self.video_incoming_frames = asyncio.Queue(maxsize=optimal_size) + self.audio_incoming_frames = asyncio.Queue(maxsize=optimal_size) + + def _update_queue_sizes_for_batch_processing(self): + """Update queue sizes based on detected batch requirements from workflow analysis.""" + if not hasattr(self.client, 'current_prompts') or not self.client.current_prompts: + logger.debug("No prompts available for batch size analysis") + return + + try: + # Get workflow I/O capabilities (which now includes batch_size) + io_capabilities = self.get_workflow_io_capabilities() + detected_batch_size = io_capabilities.get("max_batch_size", 1) + + # Only recreate queues if batch size has changed significantly + current_queue_size = self.video_incoming_frames.maxsize + optimal_size = max(detected_batch_size * 2, self._default_queue_size) + + if optimal_size != current_queue_size: + logger.info(f"Detected batch_size {detected_batch_size}, updating queue size from {current_queue_size} to {optimal_size}") + self._recreate_queues(detected_batch_size) + else: + logger.debug(f"Queue size already optimal for batch_size {detected_batch_size}") + + except Exception as e: + logger.warning(f"Failed to update queue sizes based on batch analysis: {e}") + async def warm_video(self): """Warm up the video processing pipeline with dummy frames.""" # Only warm if workflow accepts video input @@ -120,6 +161,9 @@ async def set_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]] # Clear cached modalities and I/O capabilities when prompts change self._cached_modalities = None self._cached_io_capabilities = None + + # Update queue sizes based on detected batch requirements + self._update_queue_sizes_for_batch_processing() async def update_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]): """Update the existing processing prompts. @@ -135,6 +179,9 @@ async def update_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any # Clear cached modalities and I/O capabilities when prompts change self._cached_modalities = None self._cached_io_capabilities = None + + # Update queue sizes based on detected batch requirements + self._update_queue_sizes_for_batch_processing() async def put_video_frame(self, frame: av.VideoFrame): """Queue a video frame for processing. diff --git a/src/comfystream/utils.py b/src/comfystream/utils.py index 1a989c8d..d809bdef 100644 --- a/src/comfystream/utils.py +++ b/src/comfystream/utils.py @@ -1,5 +1,9 @@ import copy import importlib +import time +from typing import Dict, Any, List, Tuple, Optional, Union +from contextlib import contextmanager +from pytrickle.api import StreamParamsUpdateRequest import json from typing import Any, Dict @@ -11,10 +15,14 @@ get_node_counts_by_type, ) - -def create_load_tensor_node(): +def create_load_tensor_node(batch_size: int = 1): + """Create a LoadTensor node with specified batch size. 
+ + Args: + batch_size: Number of frames to process in batch (default: 1) + """ return { - "inputs": {}, + "inputs": {"batch_size": batch_size}, "class_type": "LoadTensor", "_meta": {"title": "LoadTensor"}, } @@ -139,3 +147,84 @@ def get_default_workflow() -> dict: }, "2": {"inputs": {}, "class_type": "LoadTensor", "_meta": {"title": "LoadTensor"}}, } + +class PerformanceTimer: + """Utility class for measuring performance metrics in ComfyStream workflows.""" + + def __init__(self): + self.timings: Dict[str, List[float]] = {} + self.current_timings: Dict[str, float] = {} + self.batch_sizes: List[int] = [] + self.total_images_processed = 0 + + def start_timing(self, operation: str): + """Start timing an operation.""" + self.current_timings[operation] = time.time() + + def end_timing(self, operation: str): + """End timing an operation and record the duration.""" + if operation in self.current_timings: + duration = time.time() - self.current_timings[operation] + if operation not in self.timings: + self.timings[operation] = [] + self.timings[operation].append(duration) + del self.current_timings[operation] + return duration + return 0.0 + + def record_batch_processing(self, batch_size: int, num_images: int): + """Record a batch processing event.""" + self.batch_sizes.append(batch_size) + self.total_images_processed += num_images + + def get_fps(self, operation: str = "total") -> float: + """Calculate FPS for a specific operation.""" + if operation not in self.timings or not self.timings[operation]: + return 0.0 + + total_time = sum(self.timings[operation]) + if total_time == 0: + return 0.0 + + return self.total_images_processed / total_time + + def get_average_time(self, operation: str) -> float: + """Get average time for an operation.""" + if operation not in self.timings or not self.timings[operation]: + return 0.0 + + return sum(self.timings[operation]) / len(self.timings[operation]) + + def get_performance_summary(self) -> Dict[str, float]: + """Get a comprehensive performance summary.""" + summary = { + "total_images_processed": self.total_images_processed, + "total_fps": self.get_fps("total"), + "average_batch_size": sum(self.batch_sizes) / len(self.batch_sizes) if self.batch_sizes else 0, + } + + for operation in self.timings: + summary[f"{operation}_fps"] = self.get_fps(operation) + summary[f"{operation}_avg_time"] = self.get_average_time(operation) + + return summary + + def reset(self): + """Reset all performance data.""" + self.timings.clear() + self.current_timings.clear() + self.batch_sizes.clear() + self.total_images_processed = 0 + + @contextmanager + def time_operation(self, operation: str): + """Context manager for timing operations.""" + self.start_timing(operation) + try: + yield + finally: + self.end_timing(operation) + + +# Global performance timer instance +performance_timer = PerformanceTimer() diff --git a/workflows/comfystream/sd15-tensorrt-batch2-api.json b/workflows/comfystream/sd15-tensorrt-batch2-api.json new file mode 100644 index 00000000..e9e68d8a --- /dev/null +++ b/workflows/comfystream/sd15-tensorrt-batch2-api.json @@ -0,0 +1,269 @@ +{ + "1": { + "inputs": { + "image1": "example1.png", + "image2": "example2.png", + "upload": "image" + }, + "class_type": "ImageBatch", + "_meta": { + "title": "Image Batch Input" + } + }, + "2": { + "inputs": { + "engine": "depth_anything_vitl14-fp16.engine", + "images": [ + "1", + 0 + ] + }, + "class_type": "DepthAnythingTensorrt", + "_meta": { + "title": "Depth Anything Tensorrt (Batch)" + } + }, + "3": { + "inputs": { + 
"unet_name": "static-dreamshaper8_SD15_$stat-b-2-h-512-w-512_00001_.engine", + "model_type": "SD15" + }, + "class_type": "TensorRTLoader", + "_meta": { + "title": "TensorRT Loader (Batch Size 2)" + } + }, + "5": { + "inputs": { + "text": "the hulk", + "clip": [ + "23", + 0 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Prompt)" + } + }, + "6": { + "inputs": { + "text": "", + "clip": [ + "23", + 0 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Negative Prompt)" + } + }, + "7": { + "inputs": { + "seed": 905056445574169, + "steps": 1, + "cfg": 1, + "sampler_name": "lcm", + "scheduler": "normal", + "denoise": 1, + "model": [ + "3", + 0 + ], + "positive": [ + "9", + 0 + ], + "negative": [ + "9", + 1 + ], + "latent_image": [ + "16", + 0 + ] + }, + "class_type": "KSampler", + "_meta": { + "title": "KSampler (Batch Processing)" + } + }, + "8": { + "inputs": { + "control_net_name": "control_v11f1p_sd15_depth_fp16.safetensors" + }, + "class_type": "ControlNetLoader", + "_meta": { + "title": "Load ControlNet Model" + } + }, + "9": { + "inputs": { + "strength": 1, + "start_percent": 0, + "end_percent": 1, + "positive": [ + "5", + 0 + ], + "negative": [ + "6", + 0 + ], + "control_net": [ + "10", + 0 + ], + "image": [ + "2", + 0 + ] + }, + "class_type": "ControlNetApplyAdvanced", + "_meta": { + "title": "Apply ControlNet (Batch)" + } + }, + "10": { + "inputs": { + "backend": "inductor", + "fullgraph": false, + "mode": "reduce-overhead", + "controlnet": [ + "8", + 0 + ] + }, + "class_type": "TorchCompileLoadControlNet", + "_meta": { + "title": "TorchCompileLoadControlNet" + } + }, + "11": { + "inputs": { + "vae_name": "taesd" + }, + "class_type": "VAELoader", + "_meta": { + "title": "Load VAE" + } + }, + "13": { + "inputs": { + "backend": "inductor", + "fullgraph": true, + "mode": "reduce-overhead", + "compile_encoder": true, + "compile_decoder": true, + "vae": [ + "11", + 0 + ] + }, + "class_type": "TorchCompileLoadVAE", + "_meta": { + "title": "TorchCompileLoadVAE" + } + }, + "14": { + "inputs": { + "samples": [ + "7", + 0 + ], + "vae": [ + "13", + 0 + ] + }, + "class_type": "VAEDecode", + "_meta": { + "title": "VAE Decode (Batch)" + } + }, + "15": { + "inputs": { + "images": [ + "14", + 0 + ] + }, + "class_type": "PreviewImage", + "_meta": { + "title": "Preview Image (Batch)" + } + }, + "16": { + "inputs": { + "width": 512, + "height": 512, + "batch_size": 2 + }, + "class_type": "EmptyLatentImage", + "_meta": { + "title": "Empty Latent Image (Batch Size 2)" + } + }, + "17": { + "inputs": { + "images": [ + "14", + 0 + ], + "index": 0 + }, + "class_type": "ImageFromBatch", + "_meta": { + "title": "Image From Batch (First Image)" + } + }, + "18": { + "inputs": { + "images": [ + "14", + 0 + ], + "index": 1 + }, + "class_type": "ImageFromBatch", + "_meta": { + "title": "Image From Batch (Second Image)" + } + }, + "19": { + "inputs": { + "images": [ + "17", + 0 + ] + }, + "class_type": "SaveTensor", + "_meta": { + "title": "Save First Image Tensor" + } + }, + "20": { + "inputs": { + "images": [ + "18", + 0 + ] + }, + "class_type": "SaveTensor", + "_meta": { + "title": "Save Second Image Tensor" + } + }, + "23": { + "inputs": { + "clip_name": "CLIPText/model.fp16.safetensors", + "type": "stable_diffusion", + "device": "default" + }, + "class_type": "CLIPLoader", + "_meta": { + "title": "Load CLIP" + } + } +} diff --git a/workflows/comfystream/sd15-tensorrt-batch2-performance-api.json 
b/workflows/comfystream/sd15-tensorrt-batch2-performance-api.json new file mode 100644 index 00000000..d2b9c5fa --- /dev/null +++ b/workflows/comfystream/sd15-tensorrt-batch2-performance-api.json @@ -0,0 +1,250 @@ +{ + "1": { + "inputs": { + "batch_size": 2 + }, + "class_type": "LoadTensor", + "_meta": { + "title": "Load Tensor (Batch Size 2)" + } + }, + "2": { + "inputs": { + "operation": "workflow_execution" + }, + "class_type": "StartPerformanceTimerNode", + "_meta": { + "title": "Start Performance Timer" + } + }, + "3": { + "inputs": { + "engine": "depth_anything_vitl14-fp16.engine", + "images": [ + "1", + 0 + ] + }, + "class_type": "DepthAnythingTensorrt", + "_meta": { + "title": "Depth Anything Tensorrt (Batch)" + } + }, + "4": { + "inputs": { + "unet_name": "static-dreamshaper8_SD15_$stat-b-2-h-512-w-512_00001_.engine", + "model_type": "SD15" + }, + "class_type": "TensorRTLoader", + "_meta": { + "title": "TensorRT Loader (Batch Size 2)" + } + }, + "5": { + "inputs": { + "text": "the hulk", + "clip": [ + "23", + 0 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Prompt)" + } + }, + "6": { + "inputs": { + "text": "", + "clip": [ + "23", + 0 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Negative Prompt)" + } + }, + "7": { + "inputs": { + "seed": 905056445574169, + "steps": 1, + "cfg": 1, + "sampler_name": "lcm", + "scheduler": "normal", + "denoise": 1, + "model": [ + "4", + 0 + ], + "positive": [ + "9", + 0 + ], + "negative": [ + "9", + 1 + ], + "latent_image": [ + "16", + 0 + ] + }, + "class_type": "KSampler", + "_meta": { + "title": "KSampler (Batch Processing)" + } + }, + "8": { + "inputs": { + "control_net_name": "control_v11f1p_sd15_depth_fp16.safetensors" + }, + "class_type": "ControlNetLoader", + "_meta": { + "title": "Load ControlNet Model" + } + }, + "9": { + "inputs": { + "strength": 1, + "start_percent": 0, + "end_percent": 1, + "positive": [ + "5", + 0 + ], + "negative": [ + "6", + 0 + ], + "control_net": [ + "10", + 0 + ], + "image": [ + "3", + 0 + ] + }, + "class_type": "ControlNetApplyAdvanced", + "_meta": { + "title": "Apply ControlNet (Batch)" + } + }, + "10": { + "inputs": { + "backend": "inductor", + "fullgraph": false, + "mode": "reduce-overhead", + "controlnet": [ + "8", + 0 + ] + }, + "class_type": "TorchCompileLoadControlNet", + "_meta": { + "title": "TorchCompileLoadControlNet" + } + }, + "11": { + "inputs": { + "vae_name": "taesd" + }, + "class_type": "VAELoader", + "_meta": { + "title": "Load VAE" + } + }, + "13": { + "inputs": { + "backend": "inductor", + "fullgraph": true, + "mode": "reduce-overhead", + "compile_encoder": true, + "compile_decoder": true, + "vae": [ + "11", + 0 + ] + }, + "class_type": "TorchCompileLoadVAE", + "_meta": { + "title": "TorchCompileLoadVAE" + } + }, + "14": { + "inputs": { + "samples": [ + "7", + 0 + ], + "vae": [ + "13", + 0 + ] + }, + "class_type": "VAEDecode", + "_meta": { + "title": "VAE Decode (Batch)" + } + }, + "15": { + "inputs": { + "images": [ + "14", + 0 + ] + }, + "class_type": "PreviewImage", + "_meta": { + "title": "Preview Image (Batch)" + } + }, + "16": { + "inputs": { + "width": 512, + "height": 512, + "batch_size": 2 + }, + "class_type": "EmptyLatentImage", + "_meta": { + "title": "Empty Latent Image (Batch Size 2)" + } + }, + "17": { + "inputs": { + "images": [ + "14", + 0 + ], + "split_batch": true + }, + "class_type": "SaveTensor", + "_meta": { + "title": "Save Batch Tensor Output (Split)" + } + }, + "18": { + "inputs": { + 
"operation": "workflow_execution", + "batch_size": 2, + "num_images": 2 + }, + "class_type": "PerformanceTimerNode", + "_meta": { + "title": "End Performance Timer" + } + }, + "23": { + "inputs": { + "clip_name": "CLIPText/model.fp16.safetensors", + "type": "stable_diffusion", + "device": "default" + }, + "class_type": "CLIPLoader", + "_meta": { + "title": "Load CLIP" + } + } +} diff --git a/workflows/comfystream/sd15-tensorrt-batch2-tensor-api.json b/workflows/comfystream/sd15-tensorrt-batch2-tensor-api.json new file mode 100644 index 00000000..955bcb18 --- /dev/null +++ b/workflows/comfystream/sd15-tensorrt-batch2-tensor-api.json @@ -0,0 +1,230 @@ +{ + "1": { + "inputs": { + "batch_size": 2 + }, + "class_type": "LoadTensor", + "_meta": { + "title": "Load Tensor (Batch Size 2)" + } + }, + "2": { + "inputs": { + "engine": "depth_anything_vitl14-fp16.engine", + "images": [ + "1", + 0 + ] + }, + "class_type": "DepthAnythingTensorrt", + "_meta": { + "title": "Depth Anything Tensorrt (Batch)" + } + }, + "3": { + "inputs": { + "unet_name": "static-dreamshaper8_SD15_$stat-b-2-h-512-w-512_00001_.engine", + "model_type": "SD15" + }, + "class_type": "TensorRTLoader", + "_meta": { + "title": "TensorRT Loader (Batch Size 2)" + } + }, + "5": { + "inputs": { + "text": "the hulk", + "clip": [ + "23", + 0 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Prompt)" + } + }, + "6": { + "inputs": { + "text": "", + "clip": [ + "23", + 0 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Negative Prompt)" + } + }, + "7": { + "inputs": { + "seed": 905056445574169, + "steps": 1, + "cfg": 1, + "sampler_name": "lcm", + "scheduler": "normal", + "denoise": 1, + "model": [ + "3", + 0 + ], + "positive": [ + "9", + 0 + ], + "negative": [ + "9", + 1 + ], + "latent_image": [ + "16", + 0 + ] + }, + "class_type": "KSampler", + "_meta": { + "title": "KSampler (Batch Processing)" + } + }, + "8": { + "inputs": { + "control_net_name": "control_v11f1p_sd15_depth_fp16.safetensors" + }, + "class_type": "ControlNetLoader", + "_meta": { + "title": "Load ControlNet Model" + } + }, + "9": { + "inputs": { + "strength": 1, + "start_percent": 0, + "end_percent": 1, + "positive": [ + "5", + 0 + ], + "negative": [ + "6", + 0 + ], + "control_net": [ + "10", + 0 + ], + "image": [ + "2", + 0 + ] + }, + "class_type": "ControlNetApplyAdvanced", + "_meta": { + "title": "Apply ControlNet (Batch)" + } + }, + "10": { + "inputs": { + "backend": "inductor", + "fullgraph": false, + "mode": "reduce-overhead", + "controlnet": [ + "8", + 0 + ] + }, + "class_type": "TorchCompileLoadControlNet", + "_meta": { + "title": "TorchCompileLoadControlNet" + } + }, + "11": { + "inputs": { + "vae_name": "taesd" + }, + "class_type": "VAELoader", + "_meta": { + "title": "Load VAE" + } + }, + "13": { + "inputs": { + "backend": "inductor", + "fullgraph": true, + "mode": "reduce-overhead", + "compile_encoder": true, + "compile_decoder": true, + "vae": [ + "11", + 0 + ] + }, + "class_type": "TorchCompileLoadVAE", + "_meta": { + "title": "TorchCompileLoadVAE" + } + }, + "14": { + "inputs": { + "samples": [ + "7", + 0 + ], + "vae": [ + "13", + 0 + ] + }, + "class_type": "VAEDecode", + "_meta": { + "title": "VAE Decode (Batch)" + } + }, + "15": { + "inputs": { + "images": [ + "14", + 0 + ] + }, + "class_type": "PreviewImage", + "_meta": { + "title": "Preview Image (Batch)" + } + }, + "16": { + "inputs": { + "width": 512, + "height": 512, + "batch_size": 2 + }, + "class_type": "EmptyLatentImage", + 
"_meta": { + "title": "Empty Latent Image (Batch Size 2)" + } + }, + "17": { + "inputs": { + "images": [ + "14", + 0 + ], + "split_batch": true + }, + "class_type": "SaveTensor", + "_meta": { + "title": "Save Batch Tensor Output (Split)" + } + }, + "23": { + "inputs": { + "clip_name": "CLIPText/model.fp16.safetensors", + "type": "stable_diffusion", + "device": "default" + }, + "class_type": "CLIPLoader", + "_meta": { + "title": "Load CLIP" + } + } +}