diff --git a/src/diffusers/pipelines/flux/pipeline_flux_fill.py b/src/diffusers/pipelines/flux/pipeline_flux_fill.py
index 1816b3ca6d9b..546a225aa999 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_fill.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_fill.py
@@ -224,11 +224,13 @@ def __init__(
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
         # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
-        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
-        latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
+        self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
+        self.image_processor = VaeImageProcessor(
+            vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels
+        )
         self.mask_processor = VaeImageProcessor(
             vae_scale_factor=self.vae_scale_factor * 2,
-            vae_latent_channels=latent_channels,
+            vae_latent_channels=self.latent_channels,
             do_normalize=False,
             do_binarize=True,
             do_convert_grayscale=True,
@@ -493,10 +495,38 @@ def encode_prompt(
 
         return prompt_embeds, pooled_prompt_embeds, text_ids
 
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
+    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+        if isinstance(generator, list):
+            image_latents = [
+                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+                for i in range(image.shape[0])
+            ]
+            image_latents = torch.cat(image_latents, dim=0)
+        else:
+            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+        image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+        return image_latents
+
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
+    def get_timesteps(self, num_inference_steps, strength, device):
+        # get the original timestep using init_timestep
+        init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
+        t_start = int(max(num_inference_steps - init_timestep, 0))
+        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+        if hasattr(self.scheduler, "set_begin_index"):
+            self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+        return timesteps, num_inference_steps - t_start
+
     def check_inputs(
         self,
         prompt,
         prompt_2,
+        strength,
         height,
         width,
         prompt_embeds=None,
@@ -507,6 +537,9 @@ def check_inputs(
         mask_image=None,
         masked_image_latents=None,
     ):
+        if strength < 0 or strength > 1:
+            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
         if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
             logger.warning(
                 f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
@@ -624,9 +657,11 @@ def disable_vae_tiling(self):
         """
         self.vae.disable_tiling()
 
-    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
+    # Copied from diffusers.pipelines.flux.pipeline_flux_img2img.FluxImg2ImgPipeline.prepare_latents
     def prepare_latents(
         self,
+        image,
+        timestep,
         batch_size,
         num_channels_latents,
         height,
@@ -636,28 +671,41 @@ def prepare_latents(
         generator,
         latents=None,
     ):
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
         # VAE applies 8x compression on images but we must also account for packing which requires
         # latent height and width to be divisible by 2.
         height = 2 * (int(height) // (self.vae_scale_factor * 2))
         width = 2 * (int(width) // (self.vae_scale_factor * 2))
-
         shape = (batch_size, num_channels_latents, height, width)
+        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
 
         if latents is not None:
-            latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
             return latents.to(device=device, dtype=dtype), latent_image_ids
 
-        if isinstance(generator, list) and len(generator) != batch_size:
+        image = image.to(device=device, dtype=dtype)
+        if image.shape[1] != self.latent_channels:
+            image_latents = self._encode_vae_image(image=image, generator=generator)
+        else:
+            image_latents = image
+        if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+            # expand init_latents for batch_size
+            additional_image_per_prompt = batch_size // image_latents.shape[0]
+            image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+        elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
             raise ValueError(
-                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+                f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
             )
+        else:
+            image_latents = torch.cat([image_latents], dim=0)
 
-        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        latents = self.scheduler.scale_noise(image_latents, timestep, noise)
         latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
-
-        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
-
         return latents, latent_image_ids
 
     @property
@@ -687,6 +735,7 @@ def __call__(
         masked_image_latents: Optional[torch.FloatTensor] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
+        strength: float = 1.0,
         num_inference_steps: int = 50,
         sigmas: Optional[List[float]] = None,
         guidance_scale: float = 30.0,
@@ -731,6 +780,12 @@ def __call__(
                 The height in pixels of the generated image. This is set to 1024 by default for the best results.
             width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                 The width in pixels of the generated image. This is set to 1024 by default for the best results.
+            strength (`float`, *optional*, defaults to 1.0):
+                Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
+                starting point and more noise is added the higher the `strength`. The number of denoising steps depends
+                on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
+                process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
+                essentially ignores `image`.
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
@@ -794,6 +849,7 @@ def __call__(
         self.check_inputs(
             prompt,
             prompt_2,
+            strength,
             height,
             width,
             prompt_embeds=prompt_embeds,
@@ -809,6 +865,9 @@ def __call__(
         self._joint_attention_kwargs = joint_attention_kwargs
         self._interrupt = False
 
+        init_image = self.image_processor.preprocess(image, height=height, width=width)
+        init_image = init_image.to(dtype=torch.float32)
+
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -838,9 +897,37 @@ def __call__(
             lora_scale=lora_scale,
         )
 
-        # 4. Prepare latent variables
+        # 4. Prepare timesteps
+        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+        image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
+        mu = calculate_shift(
+            image_seq_len,
+            self.scheduler.config.get("base_image_seq_len", 256),
+            self.scheduler.config.get("max_image_seq_len", 4096),
+            self.scheduler.config.get("base_shift", 0.5),
+            self.scheduler.config.get("max_shift", 1.15),
+        )
+        timesteps, num_inference_steps = retrieve_timesteps(
+            self.scheduler,
+            num_inference_steps,
+            device,
+            sigmas=sigmas,
+            mu=mu,
+        )
+        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+
+        if num_inference_steps < 1:
+            raise ValueError(
+                f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+                f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+            )
+        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+        # 5. Prepare latent variables
         num_channels_latents = self.vae.config.latent_channels
         latents, latent_image_ids = self.prepare_latents(
+            init_image,
+            latent_timestep,
             batch_size * num_images_per_prompt,
             num_channels_latents,
             height,
@@ -851,17 +938,16 @@ def __call__(
             latents,
         )
 
-        # 5. Prepare mask and masked image latents
+        # 6. Prepare mask and masked image latents
        if masked_image_latents is not None:
             masked_image_latents = masked_image_latents.to(latents.device)
         else:
-            image = self.image_processor.preprocess(image, height=height, width=width)
             mask_image = self.mask_processor.preprocess(mask_image, height=height, width=width)
 
-            masked_image = image * (1 - mask_image)
+            masked_image = init_image * (1 - mask_image)
             masked_image = masked_image.to(device=device, dtype=prompt_embeds.dtype)
 
-            height, width = image.shape[-2:]
+            height, width = init_image.shape[-2:]
             mask, masked_image_latents = self.prepare_mask_latents(
                 mask_image,
                 masked_image,
@@ -876,23 +962,6 @@ def __call__(
             )
             masked_image_latents = torch.cat((masked_image_latents, mask), dim=-1)
 
-        # 6. Prepare timesteps
-        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
-        image_seq_len = latents.shape[1]
-        mu = calculate_shift(
-            image_seq_len,
-            self.scheduler.config.get("base_image_seq_len", 256),
-            self.scheduler.config.get("max_image_seq_len", 4096),
-            self.scheduler.config.get("base_shift", 0.5),
-            self.scheduler.config.get("max_shift", 1.15),
-        )
-        timesteps, num_inference_steps = retrieve_timesteps(
-            self.scheduler,
-            num_inference_steps,
-            device,
-            sigmas=sigmas,
-            mu=mu,
-        )
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)
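
A minimal usage sketch of the new `strength` argument, assuming this patch is applied on top of a recent diffusers install. The checkpoint id, the local image/mask paths, and the concrete values (strength=0.85, 1232x1632 resolution, 50 steps) are illustrative choices for this sketch, not part of the patch itself.

import torch
from diffusers import FluxFillPipeline
from diffusers.utils import load_image

# Placeholder assets: any RGB image plus a white-on-black mask of the region to repaint.
image = load_image("inpaint_image.png")
mask = load_image("inpaint_mask.png")

pipe = FluxFillPipeline.from_pretrained("black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

result = pipe(
    prompt="a white paper cup",
    image=image,
    mask_image=mask,
    height=1632,
    width=1232,
    strength=0.85,  # < 1.0 starts denoising from a partially noised init image; 1.0 reproduces the previous behavior
    guidance_scale=30.0,
    num_inference_steps=50,
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]
result.save("flux_fill_strength.png")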