[feat] Add strength to flux_fill pipeline (denoising strength for FluxFill) #10603

Open

wants to merge 6 commits into base: main
Changes from 1 commit
150 changes: 115 additions & 35 deletions src/diffusers/pipelines/flux/pipeline_flux_fill.py
@@ -225,10 +225,9 @@ def __init__(
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor * 2,
vae_latent_channels=latent_channels,
vae_latent_channels=self.vae.config.latent_channels,
Member:
We need to use latent_channels here, not the config directly. This allows pipelines to be used without the component, e.g. FluxFillPipeline(vae=None, ...).

Author:
I didn't know that, thanks! I changed it. :)

do_normalize=False,
do_binarize=True,
do_convert_grayscale=True,
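
To make the reviewer's point above concrete, here is a minimal, hypothetical sketch of the guarded fallback (build_mask_processor is an illustrative helper, not code from the PR; the 16-channel default mirrors the pipeline's fallback):

from diffusers.image_processor import VaeImageProcessor

def build_mask_processor(vae, vae_scale_factor):
    # Fall back to 16 latent channels so this still works when vae is None,
    # e.g. FluxFillPipeline(vae=None, ...).
    latent_channels = vae.config.latent_channels if vae is not None else 16
    return VaeImageProcessor(
        vae_scale_factor=vae_scale_factor * 2,
        vae_latent_channels=latent_channels,
        do_normalize=False,
        do_binarize=True,
        do_convert_grayscale=True,
    )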
@@ -493,10 +492,40 @@ def encode_prompt(

return prompt_embeds, pooled_prompt_embeds, text_ids

# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
if isinstance(generator, list):
image_latents = [
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(self.vae.encode(image), generator=generator)

image_latents = (
image_latents - self.vae.config.shift_factor
) * self.vae.config.scaling_factor

return image_latents

# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep
init_timestep = min(num_inference_steps * strength, num_inference_steps)

t_start = int(max(num_inference_steps - init_timestep, 0))
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)

return timesteps, num_inference_steps - t_start
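
For orientation, a worked example of what this copied get_timesteps does for one illustrative choice of values (the numbers below are examples, not defaults from the PR; the flow-match Euler scheduler used by Flux is first order):

num_inference_steps = 50
strength = 0.6
scheduler_order = 1  # FlowMatchEulerDiscreteScheduler.order

init_timestep = min(num_inference_steps * strength, num_inference_steps)  # 30.0
t_start = int(max(num_inference_steps - init_timestep, 0))                # 20
steps_actually_run = num_inference_steps - t_start                        # 30

# timesteps = scheduler.timesteps[t_start * scheduler_order:] keeps the last 30
# of the 50 scheduled steps; strength = 1.0 keeps all 50, strength = 0.0 keeps none.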

def check_inputs(
self,
prompt,
prompt_2,
strength,
height,
width,
prompt_embeds=None,
@@ -507,6 +536,9 @@ def check_inputs(
mask_image=None,
masked_image_latents=None,
):
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")

if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
logger.warning(
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
@@ -627,6 +659,8 @@ def disable_vae_tiling(self):
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
Member:
Can we copy from FluxImg2ImgPipeline.prepare_latents or FluxInpaintPipeline.prepare_latents?

Author:
Sure, that would be cleaner. Thanks for the review :)

def prepare_latents(
self,
image,
timestep,
batch_size,
num_channels_latents,
height,
@@ -643,22 +677,37 @@ def prepare_latents(

shape = (batch_size, num_channels_latents, height, width)

if latents is not None:
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
return latents.to(device=device, dtype=dtype), latent_image_ids
# if latents is not None:
image = image.to(device=device, dtype=dtype)
image_latents = self._encode_vae_image(image=image, generator=generator)

if isinstance(generator, list) and len(generator) != batch_size:
latent_image_ids = self._prepare_latent_image_ids(
batch_size, height // 2, width // 2, device, dtype
)
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
# expand init_latents for batch_size
additional_image_per_prompt = batch_size // image_latents.shape[0]
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
)
else:
image_latents = torch.cat([image_latents], dim=0)

latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)

latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
if latents is None:
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self.scheduler.scale_noise(image_latents, timestep, noise)
else:
noise = latents.to(device)
latents = noise

return latents, latent_image_ids
noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width)
image_latents = self._pack_latents(
image_latents, batch_size, num_channels_latents, height, width
)
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
return latents, noise, image_latents, latent_image_ids
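
For context on the new noising step above: with the flow-matching scheduler used by Flux, scheduler.scale_noise at the strength-selected timestep is, in effect, a linear interpolation between the encoded image latents and fresh noise. A simplified, assumption-level sketch of that behaviour (not the scheduler's actual implementation) shows why strength < 1.0 preserves part of the input image:

import torch

def scale_noise_sketch(sample: torch.Tensor, sigma: float, noise: torch.Tensor) -> torch.Tensor:
    # sigma close to 1.0 -> mostly noise (high strength);
    # sigma close to 0.0 -> mostly the original image latents (low strength).
    return sigma * noise + (1.0 - sigma) * sample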

@property
def guidance_scale(self):
@@ -687,6 +736,7 @@ def __call__(
masked_image_latents: Optional[torch.FloatTensor] = None,
height: Optional[int] = None,
width: Optional[int] = None,
strength: float = 1.0,
num_inference_steps: int = 50,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 30.0,
@@ -731,6 +781,12 @@ def __call__(
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
strength (`float`, *optional*, defaults to 1.0):
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
starting point and more noise is added the higher the `strength`. The number of denoising steps depends
on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
essentially ignores `image`.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
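
Putting the new argument together, a hypothetical end-to-end call could look like the sketch below (model id, image URLs, and all values are illustrative; strength is the only argument this PR adds):

import torch
from diffusers import FluxFillPipeline
from diffusers.utils import load_image

pipe = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
).to("cuda")

image = load_image("https://example.com/room.png")  # placeholder URL
mask = load_image("https://example.com/mask.png")   # placeholder URL

result = pipe(
    prompt="a red leather sofa",
    image=image,
    mask_image=mask,
    height=1024,
    width=1024,
    strength=0.85,            # < 1.0 starts denoising from a partially noised input image
    num_inference_steps=50,
    guidance_scale=30.0,
).images[0]
result.save("filled.png")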
@@ -794,6 +850,7 @@ def __call__(
self.check_inputs(
prompt,
prompt_2,
strength,
height,
width,
prompt_embeds=prompt_embeds,
@@ -809,6 +866,10 @@ def __call__(
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False

original_image = image
Member:
Unused?

Author:
Yep :) Maybe that came from the SDXL inpaint pipeline, but it is not used in this pipeline.

init_image = self.image_processor.preprocess(image, height=height, width=width)
init_image = init_image.to(dtype=torch.float32)

# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -821,7 +882,9 @@ def __call__(

# 3. Prepare prompt embeddings
lora_scale = (
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
self.joint_attention_kwargs.get("scale", None)
if self.joint_attention_kwargs is not None
else None
)
(
prompt_embeds,
@@ -838,9 +901,43 @@ def __call__(
lora_scale=lora_scale,
)

# 6. Prepare timesteps
sigmas = (
np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
if sigmas is None
else sigmas
)
image_seq_len = (int(height) // self.vae_scale_factor // 2) * (
int(width) // self.vae_scale_factor // 2
)
mu = calculate_shift(
image_seq_len,
self.scheduler.config.base_image_seq_len,
self.scheduler.config.max_image_seq_len,
self.scheduler.config.base_shift,
self.scheduler.config.max_shift,
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
sigmas=sigmas,
mu=mu,
)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)

if num_inference_steps < 1:
raise ValueError(
f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

# 4. Prepare latent variables
num_channels_latents = self.vae.config.latent_channels
latents, latent_image_ids = self.prepare_latents(
latents, noise, image_latents, latent_image_ids = self.prepare_latents(
init_image,
latent_timestep,
batch_size * num_images_per_prompt,
num_channels_latents,
height,
@@ -855,13 +952,13 @@ def __call__(
if masked_image_latents is not None:
Member:
Above # 6. Prepare mask and masked image latents.

masked_image_latents = masked_image_latents.to(latents.device)
else:
image = self.image_processor.preprocess(image, height=height, width=width)
# image = self.image_processor.preprocess(image, height=height, width=width)
mask_image = self.mask_processor.preprocess(mask_image, height=height, width=width)

masked_image = image * (1 - mask_image)
masked_image = init_image * (1 - mask_image)
masked_image = masked_image.to(device=device, dtype=prompt_embeds.dtype)

height, width = image.shape[-2:]
height, width = init_image.shape[-2:]
mask, masked_image_latents = self.prepare_mask_latents(
mask_image,
masked_image,
@@ -876,23 +973,6 @@ def __call__(
)
masked_image_latents = torch.cat((masked_image_latents, mask), dim=-1)

# 6. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
image_seq_len = latents.shape[1]
mu = calculate_shift(
image_seq_len,
self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.16),
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
sigmas=sigmas,
mu=mu,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
