12 changes: 7 additions & 5 deletions src/diffusers/pipelines/flux/pipeline_flux_fill.py
@@ -681,7 +681,6 @@ def prepare_latents(
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))
shape = (batch_size, num_channels_latents, height, width)
Collaborator:

Do you have a script showing how to run this with masked_image_latents? It seems that if image is None, it won't work here.

Collaborator:

Why move this code?

Contributor Author:

My changes are based on Sayak Paul's gist. Because of the RTX 3090's 24 GB VRAM limit, I preprocess the input image and mask into masked_image_latents before feeding them to the transformer.
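For reference, a rough sketch of that precompute step (not part of this PR): the model ID, resolutions, and file names below are placeholders, and the exact normalization/packing that masked_image_latents expects should be checked against prepare_mask_latents in this pipeline.

import torch
from diffusers import FluxFillPipeline
from diffusers.utils import load_image

pipe = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
)
pipe.vae.to("cuda")  # only the VAE is needed for this step

image = load_image("input.png")  # placeholder paths
mask = load_image("mask.png")

# Preprocess roughly the way the pipeline does internally, then zero out the
# inpaint region. `mask_processor` is assumed to be the pipeline's mask
# VaeImageProcessor (grayscale, binarized, unnormalized).
init_image = pipe.image_processor.preprocess(image, height=1024, width=1024)
mask_t = pipe.mask_processor.preprocess(mask, height=1024, width=1024)
masked_image = init_image * (1 - mask_t)

# VAE-encode the masked image and apply the Flux latent normalization.
with torch.no_grad():
    latent_dist = pipe.vae.encode(
        masked_image.to("cuda", dtype=pipe.vae.dtype)
    ).latent_dist
    masked_image_latents = latent_dist.sample()
masked_image_latents = (
    masked_image_latents - pipe.vae.config.shift_factor
) * pipe.vae.config.scaling_factor

# Free the VAE before denoising to stay within 24 GB.
pipe.vae.to("cpu")
torch.cuda.empty_cache()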

latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)

if latents is not None:
@@ -703,6 +702,7 @@ def prepare_latents(
else:
    image_latents = torch.cat([image_latents], dim=0)

shape = (batch_size, num_channels_latents, height, width)
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self.scheduler.scale_noise(image_latents, timestep, noise)
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
@@ -773,7 +773,7 @@ def __call__(
color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B,
H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W,
1)`, or `(H, W)`.
mask_image_latent (`torch.Tensor`, `List[torch.Tensor]`):
masked_image_latents (`torch.Tensor`, `List[torch.Tensor]`):
`Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask
latents tensor will be generated by `mask_image`.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -865,8 +865,10 @@ def __call__(
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False

init_image = self.image_processor.preprocess(image, height=height, width=width)
init_image = init_image.to(dtype=torch.float32)
init_image = None
if image is not None:
    init_image = self.image_processor.preprocess(image, height=height, width=width)
    init_image = init_image.to(dtype=torch.float32)

# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
@@ -924,7 +926,7 @@ def __call__(
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

# 5. Prepare latent variables
num_channels_latents = self.vae.config.latent_channels
num_channels_latents = self.vae.config.latent_channels if init_image is not None else None
Collaborator:

Why set it to None here?

Contributor Author (@Men1scus, Sep 13, 2025):
When the latents argument passed to self.prepare_latents is not None:

if latents is not None:
    return latents.to(device=device, dtype=dtype), latent_image_ids

it returns early and num_channels_latents is never used. By the time denoising runs, my VAE has already been deleted, so I cannot read it from self.vae.config.latent_channels.
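For illustration, a rough sketch of the call pattern this enables (placeholder values; it assumes packed_latents and masked_image_latents were precomputed and packed the same way prepare_latents / prepare_mask_latents would produce them, and that the PR allows image and mask_image to be omitted in that case):

# image=None skips pixel-space preprocessing, and passing `latents` makes
# prepare_latents() return early, so num_channels_latents (and the VAE) are
# never touched during denoising.
result = pipe(
    prompt="a red sofa in a bright living room",   # placeholder prompt
    image=None,
    mask_image=None,
    latents=packed_latents,                         # precomputed, packed noisy latents
    masked_image_latents=masked_image_latents,      # precomputed, packed masked-image latents
    height=1024,
    width=1024,
    num_inference_steps=30,
    guidance_scale=30.0,
    output_type="latent",   # decode later, once the VAE is back on the GPU
)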

latents, latent_image_ids = self.prepare_latents(
init_image,
latent_timestep,