
Commit 2d4f144

Commit message: up

1 parent 41c5921

2 files changed (+4 lines, -257 lines)


src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py

Lines changed: 1 addition & 4 deletions
@@ -81,10 +81,7 @@ def upcast_vae(components):
     components.vae.to(dtype=torch.float32)
     use_torch_2_0_or_xformers = isinstance(
         components.vae.decoder.mid_block.attentions[0].processor,
-        (
-            AttnProcessor2_0,
-            XFormersAttnProcessor,
-        ),
+        (AttnProcessor2_0, XFormersAttnProcessor),
     )
     # if xformers or torch_2_0 is used attention block does not need
     # to be in float32 which can save lots of memory
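
The change here is purely cosmetic: isinstance accepts a tuple of classes, so the collapsed one-liner behaves identically to the removed multi-line tuple literal. A minimal sketch of the check in isolation (the import path is the public one in diffusers; constructing AttnProcessor2_0 assumes PyTorch >= 2.0):

    from diffusers.models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor

    # isinstance takes a tuple of classes, so the single-line form is
    # behaviourally identical to the removed four-line tuple literal
    proc = AttnProcessor2_0()  # requires torch >= 2.0 (scaled_dot_product_attention)
    assert isinstance(proc, (AttnProcessor2_0, XFormersAttnProcessor))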

src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py

Lines changed: 3 additions & 253 deletions
@@ -23,19 +23,15 @@
 from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, PriorTransformer, UNet2DConditionModel
 from ...models.embeddings import get_timestep_embedding
-from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
-    USE_PEFT_BACKEND,
-    deprecate,
     is_torch_xla_available,
     logging,
     replace_example_docstring,
-    scale_lora_layers,
-    unscale_lora_layers,
 )
 from ...utils.torch_utils import randn_tensor
-from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput, StableDiffusionMixin
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from .pipeline_stable_diffusion_utils import StableDiffusionMixin
 from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
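
The import changes mirror the removals below: the LoRA-scaling helpers and deprecate are no longer used in this file, and StableDiffusionMixin now comes from a sibling utils module rather than ..pipeline_utils. A minimal sketch of the resulting import paths on this branch (the dotted module path is inferred from the relative import above; the issubclass check assumes the pipeline class keeps mixing in StableDiffusionMixin, which the retained import suggests):

    # public entry point is unchanged
    from diffusers import StableUnCLIPPipeline

    # internal location of the mixin after this commit (inferred from the + line above)
    from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils import StableDiffusionMixin

    # assumption: the pipeline still inherits the mixin, so its shared helpers
    # remain available on pipeline instances
    assert issubclass(StableUnCLIPPipeline, StableDiffusionMixin)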

@@ -260,235 +256,7 @@ def _encode_prior_prompt(
 
         return prompt_embeds, text_enc_hid_states, text_mask
 
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
-    def _encode_prompt(
-        self,
-        prompt,
-        device,
-        num_images_per_prompt,
-        do_classifier_free_guidance,
-        negative_prompt=None,
-        prompt_embeds: Optional[torch.Tensor] = None,
-        negative_prompt_embeds: Optional[torch.Tensor] = None,
-        lora_scale: Optional[float] = None,
-        **kwargs,
-    ):
-        deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
-        deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
-
-        prompt_embeds_tuple = self.encode_prompt(
-            prompt=prompt,
-            device=device,
-            num_images_per_prompt=num_images_per_prompt,
-            do_classifier_free_guidance=do_classifier_free_guidance,
-            negative_prompt=negative_prompt,
-            prompt_embeds=prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
-            lora_scale=lora_scale,
-            **kwargs,
-        )
-
-        # concatenate for backwards comp
-        prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
-
-        return prompt_embeds
-
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
-    def encode_prompt(
-        self,
-        prompt,
-        device,
-        num_images_per_prompt,
-        do_classifier_free_guidance,
-        negative_prompt=None,
-        prompt_embeds: Optional[torch.Tensor] = None,
-        negative_prompt_embeds: Optional[torch.Tensor] = None,
-        lora_scale: Optional[float] = None,
-        clip_skip: Optional[int] = None,
-    ):
-        r"""
-        Encodes the prompt into text encoder hidden states.
-
-        Args:
-            prompt (`str` or `List[str]`, *optional*):
-                prompt to be encoded
-            device: (`torch.device`):
-                torch device
-            num_images_per_prompt (`int`):
-                number of images that should be generated per prompt
-            do_classifier_free_guidance (`bool`):
-                whether to use classifier free guidance or not
-            negative_prompt (`str` or `List[str]`, *optional*):
-                The prompt or prompts not to guide the image generation. If not defined, one has to pass
-                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
-                less than `1`).
-            prompt_embeds (`torch.Tensor`, *optional*):
-                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
-                provided, text embeddings will be generated from `prompt` input argument.
-            negative_prompt_embeds (`torch.Tensor`, *optional*):
-                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
-                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
-                argument.
-            lora_scale (`float`, *optional*):
-                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
-            clip_skip (`int`, *optional*):
-                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
-                the output of the pre-final layer will be used for computing the prompt embeddings.
-        """
-        # set lora scale so that monkey patched LoRA
-        # function of text encoder can correctly access it
-        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
-            self._lora_scale = lora_scale
-
-            # dynamically adjust the LoRA scale
-            if not USE_PEFT_BACKEND:
-                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
-            else:
-                scale_lora_layers(self.text_encoder, lora_scale)
-
-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
-            batch_size = len(prompt)
-        else:
-            batch_size = prompt_embeds.shape[0]
-
-        if prompt_embeds is None:
-            # textual inversion: process multi-vector tokens if necessary
-            if isinstance(self, TextualInversionLoaderMixin):
-                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
-
-            text_inputs = self.tokenizer(
-                prompt,
-                padding="max_length",
-                max_length=self.tokenizer.model_max_length,
-                truncation=True,
-                return_tensors="pt",
-            )
-            text_input_ids = text_inputs.input_ids
-            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
-
-            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
-                text_input_ids, untruncated_ids
-            ):
-                removed_text = self.tokenizer.batch_decode(
-                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
-                )
-                logger.warning(
-                    "The following part of your input was truncated because CLIP can only handle sequences up to"
-                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
-                )
-
-            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
-                attention_mask = text_inputs.attention_mask.to(device)
-            else:
-                attention_mask = None
-
-            if clip_skip is None:
-                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
-                prompt_embeds = prompt_embeds[0]
-            else:
-                prompt_embeds = self.text_encoder(
-                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
-                )
-                # Access the `hidden_states` first, that contains a tuple of
-                # all the hidden states from the encoder layers. Then index into
-                # the tuple to access the hidden states from the desired layer.
-                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
-                # We also need to apply the final LayerNorm here to not mess with the
-                # representations. The `last_hidden_states` that we typically use for
-                # obtaining the final prompt representations passes through the LayerNorm
-                # layer.
-                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
-
-        if self.text_encoder is not None:
-            prompt_embeds_dtype = self.text_encoder.dtype
-        elif self.unet is not None:
-            prompt_embeds_dtype = self.unet.dtype
-        else:
-            prompt_embeds_dtype = prompt_embeds.dtype
-
-        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
-
-        bs_embed, seq_len, _ = prompt_embeds.shape
-        # duplicate text embeddings for each generation per prompt, using mps friendly method
-        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-
-        # get unconditional embeddings for classifier free guidance
-        if do_classifier_free_guidance and negative_prompt_embeds is None:
-            uncond_tokens: List[str]
-            if negative_prompt is None:
-                uncond_tokens = [""] * batch_size
-            elif prompt is not None and type(prompt) is not type(negative_prompt):
-                raise TypeError(
-                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
-                    f" {type(prompt)}."
-                )
-            elif isinstance(negative_prompt, str):
-                uncond_tokens = [negative_prompt]
-            elif batch_size != len(negative_prompt):
-                raise ValueError(
-                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
-                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
-                    " the batch size of `prompt`."
-                )
-            else:
-                uncond_tokens = negative_prompt
-
-            # textual inversion: process multi-vector tokens if necessary
-            if isinstance(self, TextualInversionLoaderMixin):
-                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
-
-            max_length = prompt_embeds.shape[1]
-            uncond_input = self.tokenizer(
-                uncond_tokens,
-                padding="max_length",
-                max_length=max_length,
-                truncation=True,
-                return_tensors="pt",
-            )
-
-            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
-                attention_mask = uncond_input.attention_mask.to(device)
-            else:
-                attention_mask = None
-
-            negative_prompt_embeds = self.text_encoder(
-                uncond_input.input_ids.to(device),
-                attention_mask=attention_mask,
-            )
-            negative_prompt_embeds = negative_prompt_embeds[0]
-
-        if do_classifier_free_guidance:
-            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-            seq_len = negative_prompt_embeds.shape[1]
-
-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
-
-            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-
-        if self.text_encoder is not None:
-            if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
-                # Retrieve the original scale by scaling back the LoRA layers
-                unscale_lora_layers(self.text_encoder, lora_scale)
-
-        return prompt_embeds, negative_prompt_embeds
-
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
-    def decode_latents(self, latents):
-        deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
-        deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
-
-        latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents, return_dict=False)[0]
-        image = (image / 2 + 0.5).clamp(0, 1)
-        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
-        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
-        return image
-
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with prepare_extra_step_kwargs->prepare_prior_extra_step_kwargs, scheduler->prior_scheduler
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.StableDiffusionMixin.prepare_extra_step_kwargs with prepare_extra_step_kwargs->prepare_prior_extra_step_kwargs, scheduler->prior_scheduler
     def prepare_prior_extra_step_kwargs(self, generator, eta):
         # prepare extra kwargs for the prior_scheduler step, since not all prior_schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other prior_schedulers.
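
Among the removals, decode_latents was already deprecated in favour of VaeImageProcessor.postprocess(...) (per the deprecation message in the removed code above). A minimal sketch of the replacement path, assuming a pipeline object with a vae and a VaeImageProcessor at pipe.image_processor (standard for Stable Diffusion pipelines); latents_to_images is a hypothetical helper name:

    import torch
    from diffusers.image_processor import VaeImageProcessor

    def latents_to_images(pipe, latents: torch.Tensor):
        # undo the VAE scaling, decode to image space, then let the image
        # processor handle denormalization and conversion to PIL
        latents = latents / pipe.vae.config.scaling_factor
        image = pipe.vae.decode(latents, return_dict=False)[0]
        processor: VaeImageProcessor = pipe.image_processor
        return processor.postprocess(image, output_type="pil")
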
@@ -506,24 +274,6 @@ def prepare_prior_extra_step_kwargs(self, generator, eta):
             extra_step_kwargs["generator"] = generator
         return extra_step_kwargs
 
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
-    def prepare_extra_step_kwargs(self, generator, eta):
-        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
-        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
-        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
-        # and should be between [0, 1]
-
-        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
-        extra_step_kwargs = {}
-        if accepts_eta:
-            extra_step_kwargs["eta"] = eta
-
-        # check if the scheduler accepts generator
-        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
-        if accepts_generator:
-            extra_step_kwargs["generator"] = generator
-        return extra_step_kwargs
-
     def check_inputs(
         self,
         prompt,
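
With the duplicate prepare_extra_step_kwargs gone, only the prior variant remains in this file; both rely on the same signature-introspection pattern: forward eta and generator to the scheduler's step() only when it accepts them. A standalone sketch of that pattern (build_extra_step_kwargs is a hypothetical name; the scheduler is just an example that accepts both kwargs):

    import inspect
    from diffusers import DDIMScheduler

    def build_extra_step_kwargs(scheduler, generator=None, eta: float = 0.0) -> dict:
        # inspect step()'s signature and only pass kwargs the scheduler supports
        step_params = set(inspect.signature(scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if "eta" in step_params:
            extra_step_kwargs["eta"] = eta  # eta (η) is DDIM-specific
        if "generator" in step_params:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    # DDIMScheduler.step accepts both, so both keys are populated
    print(build_extra_step_kwargs(DDIMScheduler(), eta=0.0))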
