Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 67 additions & 3 deletions comfy/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,19 +130,68 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_mod
self.manual_cast_dtype = model_config.manual_cast_dtype
self.device = device
self.current_patcher: 'ModelPatcher' = None

self.enable_trt = True
if not unet_config.get("disable_unet_model_creation", False):
if model_config.custom_operations is None:
fp8 = model_config.optimizations.get("fp8", False)
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config)
else:
operations = model_config.custom_operations
breakpoint()
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
dtype = self.diffusion_model.dtype

# trt_rtx cannot handle bfloat16, so convert to float16, float32 always convert to float16 for trt_rtx to save memory
if self.enable_trt and dtype in (torch.float32, torch.bfloat16, torch.float16):
self.diffusion_model = self.diffusion_model.half()
unet_config["dtype"] = torch.float16
self.diffusion_model.dtype = torch.float16
logging.debug(f"converted diffusion modelfrom {dtype} to float16 for trt_rtx")
else:
self.enable_trt = False
logging.warning("trt_rtx cannot handle ${dtype}, so disabling trt_rtx")

self.diffusion_model.eval()
if comfy.model_management.force_channels_last():
self.diffusion_model.to(memory_format=torch.channels_last)
logging.debug("using channels last mode for diffusion model")
logging.info("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))

if self.enable_trt:
import torch_tensorrt
settings = {
"use_python_runtime": False,
"immutable_weights": False,
"offload_to_cpu": True,
}
self.trt_compiled_diffusion_model = torch_tensorrt.MutableTorchTensorRTModule(self.diffusion_model, **settings)
# TODO: INVESTIGATE WHY DYNAMIC SHAPE IS NOT WORKING
enable_trt_dynamic_shape = False
if enable_trt_dynamic_shape:
# if batch size is 2, then sigmas-batch is 2, dim_batch is 4
sigmas_batch = torch.export.Dim("sigmas_batch", min=1, max=20)
dim_batch = torch.export.Dim("batch", min=1, max=40)

#dim_width = torch.export.Dim("width", min=3, max=64)
#dim_height = torch.export.Dim("height", min=5, max=64)
# args: xc, t
args_dynamic_shapes=({0: dim_batch}, {0: dim_batch},)
#args_dynamic_shapes=({0: dim_batch, 2: dim_width*4, 3: dim_height*4}, {0: dim_batch},)
# kwargs: context, transformer_options, y
kwargs_dynamic_shape = {
'context': {0: dim_batch},
'transformer_options': {
'wrappers': {},
'callbacks': {},
'sample_sigmas': {},
#'cond_or_uncond': {},
'sigmas': {0:sigmas_batch},
},
'y': {0: dim_batch,},
}
self.trt_compiled_diffusion_model.set_expected_dynamic_shape_range(args_dynamic_shapes, kwargs_dynamic_shape)
logging.debug("lan added ********** trt_model: trt_compiled_diffusion_model is created")

self.model_type = model_type
self.model_sampling = model_sampling(model_config, model_type)

Expand Down Expand Up @@ -199,8 +248,23 @@ def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, tran
t = self.process_timestep(t, x=x, **extra_conds)
if "latent_shapes" in extra_conds:
xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes"))

model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)

logging.debug(f"lan added *** {xc.shape=} {t.shape=} {context.shape=} {extra_conds['y'].shape=} {transformer_options['sample_sigmas'].shape=} {transformer_options['sigmas'].shape=}")
logging.debug(f"lan added *** {transformer_options=}")
if control is not None:
logging.debug(f"lan added ***{control.shape=}")
else:
logging.debug("lan added ***control is None")
logging.debug(f"lan added ***{extra_conds=}")
if self.enable_trt:
transformer_options.pop("uuids", None)
transformer_options.pop("cond_or_uncond", None)
with torch.no_grad():
model_output = self.trt_compiled_diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
else:
transformer_options.pop("uuids", None)
transformer_options.pop("cond_or_uncond", None)
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
if len(model_output) > 1 and not torch.is_tensor(model_output):
model_output, _ = utils.pack_latents(model_output)

Expand Down