[feat] Add strength in flux_fill pipeline (denoising strength for FluxFill) #10603
base: main
Changes from 1 commit
Commits: 14c452a, a7e1501, cf60e52, 25fa97c, 5d6b78c, 3a1ea2e
@@ -225,10 +225,9 @@ def __init__(
        # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
        # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
-       latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
        self.mask_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor * 2,
-           vae_latent_channels=latent_channels,
+           vae_latent_channels=self.vae.config.latent_channels,
            do_normalize=False,
            do_binarize=True,
            do_convert_grayscale=True,
@@ -493,10 +492,40 @@ def encode_prompt(

        return prompt_embeds, pooled_prompt_embeds, text_ids

+   # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
+   def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+       if isinstance(generator, list):
+           image_latents = [
+               retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+               for i in range(image.shape[0])
+           ]
+           image_latents = torch.cat(image_latents, dim=0)
+       else:
+           image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+       image_latents = (
+           image_latents - self.vae.config.shift_factor
+       ) * self.vae.config.scaling_factor
+
+       return image_latents
+
+   # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
+   def get_timesteps(self, num_inference_steps, strength, device):
+       # get the original timestep using init_timestep
+       init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
+       t_start = int(max(num_inference_steps - init_timestep, 0))
+       timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+       if hasattr(self.scheduler, "set_begin_index"):
+           self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+       return timesteps, num_inference_steps - t_start
+
    def check_inputs(
        self,
        prompt,
        prompt_2,
+       strength,
        height,
        width,
        prompt_embeds=None,
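For illustration (not part of the diff): `strength` only decides how many of the scheduler's timesteps are skipped from the front. A worked example of the `get_timesteps` arithmetic above, assuming 50 inference steps and `strength=0.6`:

```python
# Worked example of the get_timesteps arithmetic above (assumed values, not from the PR).
num_inference_steps = 50
strength = 0.6

init_timestep = min(num_inference_steps * strength, num_inference_steps)  # 30.0
t_start = int(max(num_inference_steps - init_timestep, 0))                # 20

# The pipeline keeps scheduler.timesteps[t_start * order:], i.e. the last 30 of the 50
# timesteps, so lower strength skips the noisiest steps and stays closer to the input image.
```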
@@ -507,6 +536,9 @@ def check_inputs(
        mask_image=None,
        masked_image_latents=None,
    ):
+       if strength < 0 or strength > 1:
+           raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
        if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
            logger.warning(
                f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
@@ -627,6 +659,8 @@ def disable_vae_tiling(self):
    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents

Reviewer: Can we copy from …?
Author: Sure, that would be cleaner. Thanks for the review :)

    def prepare_latents(
        self,
+       image,
+       timestep,
        batch_size,
        num_channels_latents,
        height,
@@ -643,22 +677,37 @@ def prepare_latents(

        shape = (batch_size, num_channels_latents, height, width)

-       if latents is not None:
-           latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
-           return latents.to(device=device, dtype=dtype), latent_image_ids
+       # if latents is not None:
+       image = image.to(device=device, dtype=dtype)
+       image_latents = self._encode_vae_image(image=image, generator=generator)

-       if isinstance(generator, list) and len(generator) != batch_size:
+       latent_image_ids = self._prepare_latent_image_ids(
+           batch_size, height // 2, width // 2, device, dtype
+       )
+       if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+           # expand init_latents for batch_size
+           additional_image_per_prompt = batch_size // image_latents.shape[0]
+           image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+       elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
            raise ValueError(
-               f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-               f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+               f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
            )
+       else:
+           image_latents = torch.cat([image_latents], dim=0)

-       latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-       latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
-
-       latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+       if latents is None:
+           noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+           latents = self.scheduler.scale_noise(image_latents, timestep, noise)
+       else:
+           noise = latents.to(device)
+           latents = noise

-       return latents, latent_image_ids
+       noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width)
+       image_latents = self._pack_latents(
+           image_latents, batch_size, num_channels_latents, height, width
+       )
+       latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+       return latents, noise, image_latents, latent_image_ids

    @property
    def guidance_scale(self):
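For context (not part of the diff): `prepare_latents` now seeds the latents by partially noising the encoded image with `scheduler.scale_noise` at the starting timestep, which is what makes `strength` behave like img2img. A minimal sketch of the idea, assuming the flow-matching scheduler interpolates linearly between image latents and noise at sigma(t):

```python
import torch


def scale_noise_sketch(image_latents: torch.Tensor, noise: torch.Tensor, sigma_t: float) -> torch.Tensor:
    # Assumed linear flow-matching interpolation: sigma_t near 1.0 (strength near 1.0) gives
    # nearly pure noise, while small sigma_t keeps the latents close to the encoded image.
    return sigma_t * noise + (1.0 - sigma_t) * image_latents
```

With `strength=1.0` the selected starting sigma should be at its maximum, so the pipeline effectively starts from pure noise, matching the previous behaviour.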
@@ -687,6 +736,7 @@ def __call__(
        masked_image_latents: Optional[torch.FloatTensor] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
+       strength: float = 1.0,
        num_inference_steps: int = 50,
        sigmas: Optional[List[float]] = None,
        guidance_scale: float = 30.0,
@@ -731,6 +781,12 @@ def __call__(
                The height in pixels of the generated image. This is set to 1024 by default for the best results.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image. This is set to 1024 by default for the best results.
+           strength (`float`, *optional*, defaults to 1.0):
+               Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
+               starting point and more noise is added the higher the `strength`. The number of denoising steps depends
+               on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
+               process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
+               essentially ignores `image`.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
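A quick usage sketch of the new argument (illustrative only; the checkpoint id, input files, and prompt are assumptions and not taken from this PR):

```python
import torch
from diffusers import FluxFillPipeline
from diffusers.utils import load_image

pipe = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
).to("cuda")

image = load_image("input.png")  # assumed local files
mask = load_image("mask.png")

result = pipe(
    prompt="a wooden bench in a park",
    image=image,
    mask_image=mask,
    strength=0.85,          # < 1.0 keeps some of the original content in the masked area
    num_inference_steps=50,
    guidance_scale=30.0,
).images[0]
result.save("filled.png")
```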
@@ -794,6 +850,7 @@ def __call__(
        self.check_inputs(
            prompt,
            prompt_2,
+           strength,
            height,
            width,
            prompt_embeds=prompt_embeds,
@@ -809,6 +866,10 @@ def __call__(
        self._joint_attention_kwargs = joint_attention_kwargs
        self._interrupt = False

+       original_image = image

Reviewer: Unused?
Author: Yep :) Maybe that came from the SDXL inpaint pipeline, but it is not used in this pipeline.

+       init_image = self.image_processor.preprocess(image, height=height, width=width)
+       init_image = init_image.to(dtype=torch.float32)

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
@@ -821,7 +882,9 @@ def __call__(

        # 3. Prepare prompt embeddings
        lora_scale = (
-           self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+           self.joint_attention_kwargs.get("scale", None)
+           if self.joint_attention_kwargs is not None
+           else None
        )
        (
            prompt_embeds,
@@ -838,9 +901,43 @@ def __call__(
            lora_scale=lora_scale,
        )

+       # 6. Prepare timesteps
+       sigmas = (
+           np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+           if sigmas is None
+           else sigmas
+       )
+       image_seq_len = (int(height) // self.vae_scale_factor // 2) * (
+           int(width) // self.vae_scale_factor // 2
+       )
+       mu = calculate_shift(
+           image_seq_len,
+           self.scheduler.config.base_image_seq_len,
+           self.scheduler.config.max_image_seq_len,
+           self.scheduler.config.base_shift,
+           self.scheduler.config.max_shift,
+       )
+       timesteps, num_inference_steps = retrieve_timesteps(
+           self.scheduler,
+           num_inference_steps,
+           device,
+           sigmas=sigmas,
+           mu=mu,
+       )
+       timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+
+       if num_inference_steps < 1:
+           raise ValueError(
+               f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+               f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+           )
+       latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

        # 4. Prepare latent variables
        num_channels_latents = self.vae.config.latent_channels
-       latents, latent_image_ids = self.prepare_latents(
+       latents, noise, image_latents, latent_image_ids = self.prepare_latents(
+           init_image,
+           latent_timestep,
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
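For reference (not part of the diff): `image_seq_len` is just the number of packed latent tokens, and it drives the `mu` shift. A small worked example, assuming the usual Flux VAE scale factor of 8 and a 1024x1024 output:

```python
# Assumed values for illustration; vae_scale_factor = 8 is the typical Flux setting.
height, width = 1024, 1024
vae_scale_factor = 8

image_seq_len = (height // vae_scale_factor // 2) * (width // vae_scale_factor // 2)
print(image_seq_len)  # 4096 -> with the default config, calculate_shift returns the scheduler's max_shift
```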
@@ -855,13 +952,13 @@ def __call__(
        if masked_image_latents is not None:
            masked_image_latents = masked_image_latents.to(latents.device)
        else:
-           image = self.image_processor.preprocess(image, height=height, width=width)
+           # image = self.image_processor.preprocess(image, height=height, width=width)

            mask_image = self.mask_processor.preprocess(mask_image, height=height, width=width)

-           masked_image = image * (1 - mask_image)
+           masked_image = init_image * (1 - mask_image)
            masked_image = masked_image.to(device=device, dtype=prompt_embeds.dtype)

-           height, width = image.shape[-2:]
+           height, width = init_image.shape[-2:]
            mask, masked_image_latents = self.prepare_mask_latents(
                mask_image,
                masked_image,
@@ -876,23 +973,6 @@ def __call__(
            )
            masked_image_latents = torch.cat((masked_image_latents, mask), dim=-1)

-       # 6. Prepare timesteps
-       sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
-       image_seq_len = latents.shape[1]
-       mu = calculate_shift(
-           image_seq_len,
-           self.scheduler.config.get("base_image_seq_len", 256),
-           self.scheduler.config.get("max_image_seq_len", 4096),
-           self.scheduler.config.get("base_shift", 0.5),
-           self.scheduler.config.get("max_shift", 1.16),
-       )
-       timesteps, num_inference_steps = retrieve_timesteps(
-           self.scheduler,
-           num_inference_steps,
-           device,
-           sigmas=sigmas,
-           mu=mu,
-       )
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)
Reviewer (on the `vae_latent_channels=self.vae.config.latent_channels` change in `__init__`): We need to use `latent_channels` here, not the config directly. This allows pipelines to be used without the component, e.g. FluxFillPipeline(vae=None, ...).

Author: I didn't know that, thanks. I changed it. :)
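For illustration (not part of the diff): the guard the reviewer refers to is the `latent_channels = ... if getattr(self, "vae", None) else 16` line this commit removed from `__init__`. A small standalone sketch of that fallback pattern (the helper name is hypothetical):

```python
from diffusers.image_processor import VaeImageProcessor


def build_mask_processor(vae, vae_scale_factor: int) -> VaeImageProcessor:
    # Fall back to 16 latent channels when the pipeline is built without a VAE,
    # mirroring the guard this commit removed (and which the author says was restored after review).
    latent_channels = vae.config.latent_channels if vae is not None else 16
    return VaeImageProcessor(
        vae_scale_factor=vae_scale_factor * 2,
        vae_latent_channels=latent_channels,
        do_normalize=False,
        do_binarize=True,
        do_convert_grayscale=True,
    )
```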