 from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, PriorTransformer, UNet2DConditionModel
 from ...models.embeddings import get_timestep_embedding
-from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
-    USE_PEFT_BACKEND,
-    deprecate,
     is_torch_xla_available,
     logging,
     replace_example_docstring,
-    scale_lora_layers,
-    unscale_lora_layers,
 )
 from ...utils.torch_utils import randn_tensor
-from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput, StableDiffusionMixin
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from .pipeline_stable_diffusion_utils import StableDiffusionMixin
 from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer


@@ -260,235 +256,7 @@ def _encode_prior_prompt(

         return prompt_embeds, text_enc_hid_states, text_mask

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
-    def _encode_prompt(
-        self,
-        prompt,
-        device,
-        num_images_per_prompt,
-        do_classifier_free_guidance,
-        negative_prompt=None,
-        prompt_embeds: Optional[torch.Tensor] = None,
-        negative_prompt_embeds: Optional[torch.Tensor] = None,
-        lora_scale: Optional[float] = None,
-        **kwargs,
-    ):
-        deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
-        deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
-
-        prompt_embeds_tuple = self.encode_prompt(
-            prompt=prompt,
-            device=device,
-            num_images_per_prompt=num_images_per_prompt,
-            do_classifier_free_guidance=do_classifier_free_guidance,
-            negative_prompt=negative_prompt,
-            prompt_embeds=prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
-            lora_scale=lora_scale,
-            **kwargs,
-        )
-
-        # concatenate for backwards comp
-        prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
-
-        return prompt_embeds
-
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
-    def encode_prompt(
-        self,
-        prompt,
-        device,
-        num_images_per_prompt,
-        do_classifier_free_guidance,
-        negative_prompt=None,
-        prompt_embeds: Optional[torch.Tensor] = None,
-        negative_prompt_embeds: Optional[torch.Tensor] = None,
-        lora_scale: Optional[float] = None,
-        clip_skip: Optional[int] = None,
-    ):
-        r"""
-        Encodes the prompt into text encoder hidden states.
-
-        Args:
-            prompt (`str` or `List[str]`, *optional*):
-                prompt to be encoded
-            device: (`torch.device`):
-                torch device
-            num_images_per_prompt (`int`):
-                number of images that should be generated per prompt
-            do_classifier_free_guidance (`bool`):
-                whether to use classifier free guidance or not
-            negative_prompt (`str` or `List[str]`, *optional*):
-                The prompt or prompts not to guide the image generation. If not defined, one has to pass
-                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
-                less than `1`).
-            prompt_embeds (`torch.Tensor`, *optional*):
-                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
-                provided, text embeddings will be generated from `prompt` input argument.
-            negative_prompt_embeds (`torch.Tensor`, *optional*):
-                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
-                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
-                argument.
-            lora_scale (`float`, *optional*):
-                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
-            clip_skip (`int`, *optional*):
-                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
-                the output of the pre-final layer will be used for computing the prompt embeddings.
-        """
-        # set lora scale so that monkey patched LoRA
-        # function of text encoder can correctly access it
-        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
-            self._lora_scale = lora_scale
-
-            # dynamically adjust the LoRA scale
-            if not USE_PEFT_BACKEND:
-                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
-            else:
-                scale_lora_layers(self.text_encoder, lora_scale)
-
-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
-            batch_size = len(prompt)
-        else:
-            batch_size = prompt_embeds.shape[0]
-
-        if prompt_embeds is None:
-            # textual inversion: process multi-vector tokens if necessary
-            if isinstance(self, TextualInversionLoaderMixin):
-                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
-
-            text_inputs = self.tokenizer(
-                prompt,
-                padding="max_length",
-                max_length=self.tokenizer.model_max_length,
-                truncation=True,
-                return_tensors="pt",
-            )
-            text_input_ids = text_inputs.input_ids
-            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
-
-            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
-                text_input_ids, untruncated_ids
-            ):
-                removed_text = self.tokenizer.batch_decode(
-                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
-                )
-                logger.warning(
-                    "The following part of your input was truncated because CLIP can only handle sequences up to"
-                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
-                )
-
-            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
-                attention_mask = text_inputs.attention_mask.to(device)
-            else:
-                attention_mask = None
-
-            if clip_skip is None:
-                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
-                prompt_embeds = prompt_embeds[0]
-            else:
-                prompt_embeds = self.text_encoder(
-                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
-                )
-                # Access the `hidden_states` first, that contains a tuple of
-                # all the hidden states from the encoder layers. Then index into
-                # the tuple to access the hidden states from the desired layer.
-                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
-                # We also need to apply the final LayerNorm here to not mess with the
-                # representations. The `last_hidden_states` that we typically use for
-                # obtaining the final prompt representations passes through the LayerNorm
-                # layer.
-                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
-
-        if self.text_encoder is not None:
-            prompt_embeds_dtype = self.text_encoder.dtype
-        elif self.unet is not None:
-            prompt_embeds_dtype = self.unet.dtype
-        else:
-            prompt_embeds_dtype = prompt_embeds.dtype
-
-        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
-
-        bs_embed, seq_len, _ = prompt_embeds.shape
-        # duplicate text embeddings for each generation per prompt, using mps friendly method
-        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-
-        # get unconditional embeddings for classifier free guidance
-        if do_classifier_free_guidance and negative_prompt_embeds is None:
-            uncond_tokens: List[str]
-            if negative_prompt is None:
-                uncond_tokens = [""] * batch_size
-            elif prompt is not None and type(prompt) is not type(negative_prompt):
-                raise TypeError(
-                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
-                    f" {type(prompt)}."
-                )
-            elif isinstance(negative_prompt, str):
-                uncond_tokens = [negative_prompt]
-            elif batch_size != len(negative_prompt):
-                raise ValueError(
-                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
-                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
-                    " the batch size of `prompt`."
-                )
-            else:
-                uncond_tokens = negative_prompt
-
-            # textual inversion: process multi-vector tokens if necessary
-            if isinstance(self, TextualInversionLoaderMixin):
-                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
-
-            max_length = prompt_embeds.shape[1]
-            uncond_input = self.tokenizer(
-                uncond_tokens,
-                padding="max_length",
-                max_length=max_length,
-                truncation=True,
-                return_tensors="pt",
-            )
-
-            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
-                attention_mask = uncond_input.attention_mask.to(device)
-            else:
-                attention_mask = None
-
-            negative_prompt_embeds = self.text_encoder(
-                uncond_input.input_ids.to(device),
-                attention_mask=attention_mask,
-            )
-            negative_prompt_embeds = negative_prompt_embeds[0]
-
-        if do_classifier_free_guidance:
-            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-            seq_len = negative_prompt_embeds.shape[1]
-
-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
-
-            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-
-        if self.text_encoder is not None:
-            if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
-                # Retrieve the original scale by scaling back the LoRA layers
-                unscale_lora_layers(self.text_encoder, lora_scale)
-
-        return prompt_embeds, negative_prompt_embeds
-
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
-    def decode_latents(self, latents):
-        deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
-        deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
-
-        latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents, return_dict=False)[0]
-        image = (image / 2 + 0.5).clamp(0, 1)
-        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
-        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
-        return image
-
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with prepare_extra_step_kwargs->prepare_prior_extra_step_kwargs, scheduler->prior_scheduler
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.StableDiffusionMixin.prepare_extra_step_kwargs with prepare_extra_step_kwargs->prepare_prior_extra_step_kwargs, scheduler->prior_scheduler
     def prepare_prior_extra_step_kwargs(self, generator, eta):
         # prepare extra kwargs for the prior_scheduler step, since not all prior_schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other prior_schedulers.
@@ -506,24 +274,6 @@ def prepare_prior_extra_step_kwargs(self, generator, eta):
             extra_step_kwargs["generator"] = generator
         return extra_step_kwargs

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
-    def prepare_extra_step_kwargs(self, generator, eta):
-        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
-        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
-        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
-        # and should be between [0, 1]
-
-        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
-        extra_step_kwargs = {}
-        if accepts_eta:
-            extra_step_kwargs["eta"] = eta
-
-        # check if the scheduler accepts generator
-        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
-        if accepts_generator:
-            extra_step_kwargs["generator"] = generator
-        return extra_step_kwargs
-
     def check_inputs(
         self,
         prompt,