From 251dbeadcf86e48d4621d2b8eeec25774b47afe9 Mon Sep 17 00:00:00 2001
From: Ziheng Zhang
Date: Sat, 6 Sep 2025 06:01:01 +0800
Subject: [PATCH] refactor: improve handling when only masked_image_latents are provided

---
 src/diffusers/pipelines/flux/pipeline_flux_fill.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/src/diffusers/pipelines/flux/pipeline_flux_fill.py b/src/diffusers/pipelines/flux/pipeline_flux_fill.py
index 956f6fb10652..86d14833f006 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_fill.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_fill.py
@@ -681,7 +681,6 @@ def prepare_latents(
         # latent height and width to be divisible by 2.
         height = 2 * (int(height) // (self.vae_scale_factor * 2))
         width = 2 * (int(width) // (self.vae_scale_factor * 2))
-        shape = (batch_size, num_channels_latents, height, width)
         latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)

         if latents is not None:
@@ -703,6 +702,7 @@ def prepare_latents(
         else:
             image_latents = torch.cat([image_latents], dim=0)

+        shape = (batch_size, num_channels_latents, height, width)
         noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         latents = self.scheduler.scale_noise(image_latents, timestep, noise)
         latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
@@ -773,7 +773,7 @@ def __call__(
                 color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`,
                 `(B, H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`,
                 `(H, W, 1)`, or `(H, W)`.
-            mask_image_latent (`torch.Tensor`, `List[torch.Tensor]`):
+            masked_image_latents (`torch.Tensor`, `List[torch.Tensor]`):
                 `Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask
                 latents tensor will be generated by `mask_image`.
             height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -865,8 +865,10 @@ def __call__(
         self._joint_attention_kwargs = joint_attention_kwargs
         self._interrupt = False

-        init_image = self.image_processor.preprocess(image, height=height, width=width)
-        init_image = init_image.to(dtype=torch.float32)
+        init_image = None
+        if image is not None:
+            init_image = self.image_processor.preprocess(image, height=height, width=width)
+            init_image = init_image.to(dtype=torch.float32)

         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
@@ -924,7 +926,7 @@ def __call__(
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

         # 5. Prepare latent variables
-        num_channels_latents = self.vae.config.latent_channels
+        num_channels_latents = self.vae.config.latent_channels if init_image is not None else None
         latents, latent_image_ids = self.prepare_latents(
             init_image,
             latent_timestep,
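
Note for reviewers: a minimal sketch of the call pattern this refactor is meant to
support, namely passing precomputed `latents` and `masked_image_latents` while
omitting `image` entirely, so that `init_image` stays None and preprocessing is
skipped. The checkpoint name, tensor shapes, prompt, and device below are
illustrative assumptions, not part of this patch; in practice the latents would
come from a VAE encode plus the pipeline's own packing rather than `torch.randn`.

    import torch
    from diffusers import FluxFillPipeline

    pipe = FluxFillPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
    ).to("cuda")

    # Hypothetical precomputed inputs for a 1024x1024 generation. With a VAE
    # scale factor of 8 and 2x2 packing, the packed sequence length is
    # (1024 // 16) * (1024 // 16) = 4096; the channel counts are placeholder
    # assumptions for illustration only.
    latents = torch.randn(1, 4096, 64, device="cuda", dtype=torch.bfloat16)
    masked_image_latents = torch.randn(1, 4096, 320, device="cuda", dtype=torch.bfloat16)

    # With this patch, omitting `image` no longer fails inside
    # image_processor.preprocess(); prepare_latents() returns the provided
    # `latents` early, so num_channels_latents=None is never used.
    result = pipe(
        prompt="a red sofa in a bright living room",
        latents=latents,
        masked_image_latents=masked_image_latents,
        height=1024,
        width=1024,
        num_inference_steps=30,
    ).images[0]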