diff --git a/docs/source/en/api/loaders/ip_adapter.md b/docs/source/en/api/loaders/ip_adapter.md index 946a8b1af875..260762f36d50 100644 --- a/docs/source/en/api/loaders/ip_adapter.md +++ b/docs/source/en/api/loaders/ip_adapter.md @@ -22,11 +22,11 @@ Learn how to load an IP-Adapter checkpoint and image in the IP-Adapter [loading] ## IPAdapterMixin -[[autodoc]] loaders.ip_adapter.IPAdapterMixin +[[autodoc]] loaders.ip_adapter.ip_adapter.IPAdapterMixin ## SD3IPAdapterMixin -[[autodoc]] loaders.ip_adapter.SD3IPAdapterMixin +[[autodoc]] loaders.ip_adapter.ip_adapter.SD3IPAdapterMixin - all - is_ip_adapter_active diff --git a/docs/source/en/api/loaders/lora.md b/docs/source/en/api/loaders/lora.md index 1c716f6d5e85..6b14c0b25869 100644 --- a/docs/source/en/api/loaders/lora.md +++ b/docs/source/en/api/loaders/lora.md @@ -39,58 +39,58 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse ## StableDiffusionLoraLoaderMixin -[[autodoc]] loaders.lora_pipeline.StableDiffusionLoraLoaderMixin +[[autodoc]] loaders.lora.lora_pipeline.StableDiffusionLoraLoaderMixin ## StableDiffusionXLLoraLoaderMixin -[[autodoc]] loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin +[[autodoc]] loaders.lora.lora_pipeline.StableDiffusionXLLoraLoaderMixin ## SD3LoraLoaderMixin -[[autodoc]] loaders.lora_pipeline.SD3LoraLoaderMixin +[[autodoc]] loaders.lora.lora_pipeline.SD3LoraLoaderMixin ## FluxLoraLoaderMixin -[[autodoc]] loaders.lora_pipeline.FluxLoraLoaderMixin +[[autodoc]] loaders.lora.lora_pipeline.FluxLoraLoaderMixin ## CogVideoXLoraLoaderMixin -[[autodoc]] loaders.lora_pipeline.CogVideoXLoraLoaderMixin +[[autodoc]] loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin ## Mochi1LoraLoaderMixin -[[autodoc]] loaders.lora_pipeline.Mochi1LoraLoaderMixin +[[autodoc]] loaders.lora.lora_pipeline.Mochi1LoraLoaderMixin ## AuraFlowLoraLoaderMixin -[[autodoc]] loaders.lora_pipeline.AuraFlowLoraLoaderMixin +[[autodoc]] loaders.lora.lora_pipeline.AuraFlowLoraLoaderMixin ## LTXVideoLoraLoaderMixin -[[autodoc]] loaders.lora_pipeline.LTXVideoLoraLoaderMixin +[[autodoc]] loaders.lora.lora_pipeline.LTXVideoLoraLoaderMixin ## SanaLoraLoaderMixin -[[autodoc]] loaders.lora_pipeline.SanaLoraLoaderMixin +[[autodoc]] loaders.lora.lora_pipeline.SanaLoraLoaderMixin ## HunyuanVideoLoraLoaderMixin -[[autodoc]] loaders.lora_pipeline.HunyuanVideoLoraLoaderMixin +[[autodoc]] loaders.lora.lora_pipeline.HunyuanVideoLoraLoaderMixin ## Lumina2LoraLoaderMixin -[[autodoc]] loaders.lora_pipeline.Lumina2LoraLoaderMixin +[[autodoc]] loaders.lora.lora_pipeline.Lumina2LoraLoaderMixin ## CogView4LoraLoaderMixin -[[autodoc]] loaders.lora_pipeline.CogView4LoraLoaderMixin +[[autodoc]] loaders.lora.lora_pipeline.CogView4LoraLoaderMixin ## WanLoraLoaderMixin -[[autodoc]] loaders.lora_pipeline.WanLoraLoaderMixin +[[autodoc]] loaders.lora.lora_pipeline.WanLoraLoaderMixin ## AmusedLoraLoaderMixin -[[autodoc]] loaders.lora_pipeline.AmusedLoraLoaderMixin +[[autodoc]] loaders.lora.lora_pipeline.AmusedLoraLoaderMixin ## HiDreamImageLoraLoaderMixin @@ -98,4 +98,4 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse ## LoraBaseMixin -[[autodoc]] loaders.lora_base.LoraBaseMixin \ No newline at end of file +[[autodoc]] loaders.lora.lora_base.LoraBaseMixin \ No newline at end of file diff --git
a/docs/source/en/api/loaders/transformer_sd3.md b/docs/source/en/api/loaders/transformer_sd3.md index 4fc9603054b4..cb4f9bdf37b6 100644 --- a/docs/source/en/api/loaders/transformer_sd3.md +++ b/docs/source/en/api/loaders/transformer_sd3.md @@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License. # SD3Transformer2D -This class is useful when *only* loading weights into a [`SD3Transformer2DModel`]. If you need to load weights into the text encoder or a text encoder and SD3Transformer2DModel, check [`SD3LoraLoaderMixin`](lora#diffusers.loaders.SD3LoraLoaderMixin) class instead. +This class is useful when *only* loading weights into a [`SD3Transformer2DModel`]. If you need to load weights into the text encoder or a text encoder and [`SD3Transformer2DModel`], check the [`SD3LoraLoaderMixin`](lora#diffusers.loaders.SD3LoraLoaderMixin) class instead. The [`SD3Transformer2DLoadersMixin`] class currently only loads IP-Adapter weights, but will be used in the future to save weights and load LoRAs. @@ -24,6 +24,6 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse ## SD3Transformer2DLoadersMixin -[[autodoc]] loaders.transformer_sd3.SD3Transformer2DLoadersMixin +[[autodoc]] loaders.ip_adapter.transformer_sd3.SD3Transformer2DLoadersMixin - all - _load_ip_adapter_weights \ No newline at end of file diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index 84c6d9f32c66..9b0dd6ea984b 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -54,14 +54,14 @@ def text_encoder_attn_modules(text_encoder): _import_structure = {} if is_torch_available(): - _import_structure["single_file_model"] = ["FromOriginalModelMixin"] - _import_structure["transformer_flux"] = ["FluxTransformer2DLoadersMixin"] - _import_structure["transformer_sd3"] = ["SD3Transformer2DLoadersMixin"] - _import_structure["unet"] = ["UNet2DConditionLoadersMixin"] + _import_structure["ip_adapter.transformer_flux"] = ["FluxTransformer2DLoadersMixin"] + _import_structure["ip_adapter.transformer_sd3"] = ["SD3Transformer2DLoadersMixin"] + _import_structure["single_file.single_file_model"] = ["FromOriginalModelMixin"] + _import_structure["unet.unet"] = ["UNet2DConditionLoadersMixin"] _import_structure["utils"] = ["AttnProcsLayers"] if is_transformers_available(): - _import_structure["single_file"] = ["FromSingleFileMixin"] - _import_structure["lora_pipeline"] = [ + _import_structure["single_file.single_file"] = ["FromSingleFileMixin"] + _import_structure["lora.lora_pipeline"] = [ "AmusedLoraLoaderMixin", "StableDiffusionLoraLoaderMixin", "SD3LoraLoaderMixin", @@ -80,7 +80,7 @@ def text_encoder_attn_modules(text_encoder): "HiDreamImageLoraLoaderMixin", ] _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"] - _import_structure["ip_adapter"] = [ + _import_structure["ip_adapter.ip_adapter"] = [ "IPAdapterMixin", "FluxIPAdapterMixin", "SD3IPAdapterMixin", ] @@ -91,19 +91,14 @@ def text_encoder_attn_modules(text_encoder): if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: if is_torch_available(): - from .single_file_model import FromOriginalModelMixin - from .transformer_flux import FluxTransformer2DLoadersMixin - from .transformer_sd3 import SD3Transformer2DLoadersMixin + from .ip_adapter import FluxTransformer2DLoadersMixin, SD3Transformer2DLoadersMixin + from .single_file import FromOriginalModelMixin from .unet import UNet2DConditionLoadersMixin from .utils import AttnProcsLayers if is_transformers_available(): - from
.ip_adapter import ( - FluxIPAdapterMixin, - IPAdapterMixin, - SD3IPAdapterMixin, - ) - from .lora_pipeline import ( + from .ip_adapter import FluxIPAdapterMixin, IPAdapterMixin, SD3IPAdapterMixin + from .lora import ( AmusedLoraLoaderMixin, AuraFlowLoraLoaderMixin, CogVideoXLoraLoaderMixin, @@ -111,6 +106,7 @@ def text_encoder_attn_modules(text_encoder): FluxLoraLoaderMixin, HiDreamImageLoraLoaderMixin, HunyuanVideoLoraLoaderMixin, + LoraBaseMixin, LoraLoaderMixin, LTXVideoLoraLoaderMixin, Lumina2LoraLoaderMixin, diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index f4c48f254c44..07043114187e 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -12,868 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. -from pathlib import Path -from typing import Dict, List, Optional, Union -import torch -import torch.nn.functional as F -from huggingface_hub.utils import validate_hf_hub_args -from safetensors import safe_open +from ..utils import deprecate +from .ip_adapter import FluxIPAdapterMixin, IPAdapterMixin, SD3IPAdapterMixin -from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict -from ..utils import ( - USE_PEFT_BACKEND, - _get_detailed_type, - _get_model_file, - _is_valid_type, - is_accelerate_available, - is_torch_version, - is_transformers_available, - logging, -) -from .unet_loader_utils import _maybe_expand_lora_scales +class IPAdapterMixin(IPAdapterMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `IPAdapterMixin` from diffusers.loaders.ip_adapter has been deprecated. Please use `from diffusers.loaders.ip_adapter.ip_adapter import IPAdapterMixin` instead." + deprecate("diffusers.loaders.ip_adapter.IPAdapterMixin", "0.36", deprecation_message) + super().__init__(*args, **kwargs) -if is_transformers_available(): - from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, SiglipImageProcessor, SiglipVisionModel -from ..models.attention_processor import ( - AttnProcessor, - AttnProcessor2_0, - FluxAttnProcessor2_0, - FluxIPAdapterJointAttnProcessor2_0, - IPAdapterAttnProcessor, - IPAdapterAttnProcessor2_0, - IPAdapterXFormersAttnProcessor, - JointAttnProcessor2_0, - SD3IPAdapterJointAttnProcessor2_0, -) +class FluxIPAdapterMixin(FluxIPAdapterMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `FluxIPAdapterMixin` from diffusers.loaders.ip_adapter has been deprecated. Please use `from diffusers.loaders.ip_adapter.ip_adapter import FluxIPAdapterMixin` instead." + deprecate("diffusers.loaders.ip_adapter.FluxIPAdapterMixin", "0.36", deprecation_message) + super().__init__(*args, **kwargs) -logger = logging.get_logger(__name__) - - -class IPAdapterMixin: - """Mixin for handling IP Adapters.""" - - @validate_hf_hub_args - def load_ip_adapter( - self, - pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]], - subfolder: Union[str, List[str]], - weight_name: Union[str, List[str]], - image_encoder_folder: Optional[str] = "image_encoder", - **kwargs, - ): - """ - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`): - Can be either: - - - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on - the Hub. 
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved - with [`ModelMixin.save_pretrained`]. - - A [torch state - dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). - subfolder (`str` or `List[str]`): - The subfolder location of a model file within a larger model repository on the Hub or locally. If a - list is passed, it should have the same length as `weight_name`. - weight_name (`str` or `List[str]`): - The name of the weight file to load. If a list is passed, it should have the same length as - `subfolder`. - image_encoder_folder (`str`, *optional*, defaults to `image_encoder`): - The subfolder location of the image encoder within a larger model repository on the Hub or locally. - Pass `None` to not load the image encoder. If the image encoder is located in a folder inside - `subfolder`, you only need to pass the name of the folder that contains image encoder weights, e.g. - `image_encoder_folder="image_encoder"`. If the image encoder is located in a folder other than - `subfolder`, you should pass the path to the folder that contains image encoder weights, for example, - `image_encoder_folder="different_subfolder/image_encoder"`. - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory where a downloaded pretrained model configuration is cached if the standard cache - is not used. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - local_files_only (`bool`, *optional*, defaults to `False`): - Whether to only load local model weights and configuration files or not. If set to `True`, the model - won't be downloaded from the Hub. - token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from - `diffusers-cli login` (stored in `~/.huggingface`) is used. - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier - allowed by Git. - low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): - Speed up model loading only loading the pretrained weights and not initializing the weights. This also - tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. - Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this - argument to `True` will raise an error. 
- """ - - # handle the list inputs for multiple IP Adapters - if not isinstance(weight_name, list): - weight_name = [weight_name] - - if not isinstance(pretrained_model_name_or_path_or_dict, list): - pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict] - if len(pretrained_model_name_or_path_or_dict) == 1: - pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name) - - if not isinstance(subfolder, list): - subfolder = [subfolder] - if len(subfolder) == 1: - subfolder = subfolder * len(weight_name) - - if len(weight_name) != len(pretrained_model_name_or_path_or_dict): - raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.") - - if len(weight_name) != len(subfolder): - raise ValueError("`weight_name` and `subfolder` must have the same length.") - - # Load the main state dict first. - cache_dir = kwargs.pop("cache_dir", None) - force_download = kwargs.pop("force_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", None) - token = kwargs.pop("token", None) - revision = kwargs.pop("revision", None) - low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) - - if low_cpu_mem_usage and not is_accelerate_available(): - low_cpu_mem_usage = False - logger.warning( - "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" - " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" - " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" - " install accelerate\n```\n." - ) - - if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): - raise NotImplementedError( - "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" - " `low_cpu_mem_usage=False`." 
- ) - - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } - state_dicts = [] - for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip( - pretrained_model_name_or_path_or_dict, weight_name, subfolder - ): - if not isinstance(pretrained_model_name_or_path_or_dict, dict): - model_file = _get_model_file( - pretrained_model_name_or_path_or_dict, - weights_name=weight_name, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - ) - if weight_name.endswith(".safetensors"): - state_dict = {"image_proj": {}, "ip_adapter": {}} - with safe_open(model_file, framework="pt", device="cpu") as f: - for key in f.keys(): - if key.startswith("image_proj."): - state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) - elif key.startswith("ip_adapter."): - state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) - else: - state_dict = load_state_dict(model_file) - else: - state_dict = pretrained_model_name_or_path_or_dict - - keys = list(state_dict.keys()) - if "image_proj" not in keys and "ip_adapter" not in keys: - raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.") - - state_dicts.append(state_dict) - - # load CLIP image encoder here if it has not been registered to the pipeline yet - if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None: - if image_encoder_folder is not None: - if not isinstance(pretrained_model_name_or_path_or_dict, dict): - logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}") - if image_encoder_folder.count("/") == 0: - image_encoder_subfolder = Path(subfolder, image_encoder_folder).as_posix() - else: - image_encoder_subfolder = Path(image_encoder_folder).as_posix() - - image_encoder = CLIPVisionModelWithProjection.from_pretrained( - pretrained_model_name_or_path_or_dict, - subfolder=image_encoder_subfolder, - low_cpu_mem_usage=low_cpu_mem_usage, - cache_dir=cache_dir, - local_files_only=local_files_only, - torch_dtype=self.dtype, - ).to(self.device) - self.register_modules(image_encoder=image_encoder) - else: - raise ValueError( - "`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict." - ) - else: - logger.warning( - "image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter." - "Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead." 
- ) - - # create feature extractor if it has not been registered to the pipeline yet - if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None: - # FaceID IP adapters don't need the image encoder so it's not present, in this case we default to 224 - default_clip_size = 224 - clip_image_size = ( - self.image_encoder.config.image_size if self.image_encoder is not None else default_clip_size - ) - feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size) - self.register_modules(feature_extractor=feature_extractor) - - # load ip-adapter into unet - unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet - unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage) - - extra_loras = unet._load_ip_adapter_loras(state_dicts) - if extra_loras != {}: - if not USE_PEFT_BACKEND: - logger.warning("PEFT backend is required to load these weights.") - else: - # apply the IP Adapter Face ID LoRA weights - peft_config = getattr(unet, "peft_config", {}) - for k, lora in extra_loras.items(): - if f"faceid_{k}" not in peft_config: - self.load_lora_weights(lora, adapter_name=f"faceid_{k}") - self.set_adapters([f"faceid_{k}"], adapter_weights=[1.0]) - - def set_ip_adapter_scale(self, scale): - """ - Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for - granular control over each IP-Adapter behavior. A config can be a float or a dictionary. - - Example: - - ```py - # To use original IP-Adapter - scale = 1.0 - pipeline.set_ip_adapter_scale(scale) - - # To use style block only - scale = { - "up": {"block_0": [0.0, 1.0, 0.0]}, - } - pipeline.set_ip_adapter_scale(scale) - - # To use style+layout blocks - scale = { - "down": {"block_2": [0.0, 1.0]}, - "up": {"block_0": [0.0, 1.0, 0.0]}, - } - pipeline.set_ip_adapter_scale(scale) - - # To use style and layout from 2 reference images - scales = [{"down": {"block_2": [0.0, 1.0]}}, {"up": {"block_0": [0.0, 1.0, 0.0]}}] - pipeline.set_ip_adapter_scale(scales) - ``` - """ - unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet - if not isinstance(scale, list): - scale = [scale] - scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0) - - for attn_name, attn_processor in unet.attn_processors.items(): - if isinstance( - attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor) - ): - if len(scale_configs) != len(attn_processor.scale): - raise ValueError( - f"Cannot assign {len(scale_configs)} scale_configs to {len(attn_processor.scale)} IP-Adapter." - ) - elif len(scale_configs) == 1: - scale_configs = scale_configs * len(attn_processor.scale) - for i, scale_config in enumerate(scale_configs): - if isinstance(scale_config, dict): - for k, s in scale_config.items(): - if attn_name.startswith(k): - attn_processor.scale[i] = s - else: - attn_processor.scale[i] = scale_config - - def unload_ip_adapter(self): - """ - Unloads the IP Adapter weights - - Examples: - - ```python - >>> # Assuming `pipeline` is already loaded with the IP Adapter weights. - >>> pipeline.unload_ip_adapter() - >>> ... 
- ``` - """ - # remove CLIP image encoder - if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None: - self.image_encoder = None - self.register_to_config(image_encoder=[None, None]) - - # remove feature extractor only when safety_checker is None as safety_checker uses - # the feature_extractor later - if not hasattr(self, "safety_checker"): - if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None: - self.feature_extractor = None - self.register_to_config(feature_extractor=[None, None]) - - # remove hidden encoder - self.unet.encoder_hid_proj = None - self.unet.config.encoder_hid_dim_type = None - - # Kolors: restore `encoder_hid_proj` with `text_encoder_hid_proj` - if hasattr(self.unet, "text_encoder_hid_proj") and self.unet.text_encoder_hid_proj is not None: - self.unet.encoder_hid_proj = self.unet.text_encoder_hid_proj - self.unet.text_encoder_hid_proj = None - self.unet.config.encoder_hid_dim_type = "text_proj" - - # restore original Unet attention processors layers - attn_procs = {} - for name, value in self.unet.attn_processors.items(): - attn_processor_class = ( - AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor() - ) - attn_procs[name] = ( - attn_processor_class - if isinstance( - value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor) - ) - else value.__class__() - ) - self.unet.set_attn_processor(attn_procs) - - -class FluxIPAdapterMixin: - """Mixin for handling Flux IP Adapters.""" - - @validate_hf_hub_args - def load_ip_adapter( - self, - pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]], - weight_name: Union[str, List[str]], - subfolder: Optional[Union[str, List[str]]] = "", - image_encoder_pretrained_model_name_or_path: Optional[str] = "image_encoder", - image_encoder_subfolder: Optional[str] = "", - image_encoder_dtype: torch.dtype = torch.float16, - **kwargs, - ): - """ - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`): - Can be either: - - - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on - the Hub. - - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved - with [`ModelMixin.save_pretrained`]. - - A [torch state - dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). - subfolder (`str` or `List[str]`): - The subfolder location of a model file within a larger model repository on the Hub or locally. If a - list is passed, it should have the same length as `weight_name`. - weight_name (`str` or `List[str]`): - The name of the weight file to load. If a list is passed, it should have the same length as - `weight_name`. - image_encoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `./image_encoder`): - Can be either: - - - A string, the *model id* (for example `openai/clip-vit-large-patch14`) of a pretrained model - hosted on the Hub. - - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved - with [`ModelMixin.save_pretrained`]. - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory where a downloaded pretrained model configuration is cached if the standard cache - is not used. 
- force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - local_files_only (`bool`, *optional*, defaults to `False`): - Whether to only load local model weights and configuration files or not. If set to `True`, the model - won't be downloaded from the Hub. - token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from - `diffusers-cli login` (stored in `~/.huggingface`) is used. - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier - allowed by Git. - low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): - Speed up model loading only loading the pretrained weights and not initializing the weights. This also - tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. - Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this - argument to `True` will raise an error. - """ - - # handle the list inputs for multiple IP Adapters - if not isinstance(weight_name, list): - weight_name = [weight_name] - - if not isinstance(pretrained_model_name_or_path_or_dict, list): - pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict] - if len(pretrained_model_name_or_path_or_dict) == 1: - pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name) - - if not isinstance(subfolder, list): - subfolder = [subfolder] - if len(subfolder) == 1: - subfolder = subfolder * len(weight_name) - - if len(weight_name) != len(pretrained_model_name_or_path_or_dict): - raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.") - - if len(weight_name) != len(subfolder): - raise ValueError("`weight_name` and `subfolder` must have the same length.") - - # Load the main state dict first. - cache_dir = kwargs.pop("cache_dir", None) - force_download = kwargs.pop("force_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", None) - token = kwargs.pop("token", None) - revision = kwargs.pop("revision", None) - low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) - - if low_cpu_mem_usage and not is_accelerate_available(): - low_cpu_mem_usage = False - logger.warning( - "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" - " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" - " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" - " install accelerate\n```\n." - ) - - if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): - raise NotImplementedError( - "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" - " `low_cpu_mem_usage=False`." 
- ) - - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } - state_dicts = [] - for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip( - pretrained_model_name_or_path_or_dict, weight_name, subfolder - ): - if not isinstance(pretrained_model_name_or_path_or_dict, dict): - model_file = _get_model_file( - pretrained_model_name_or_path_or_dict, - weights_name=weight_name, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - ) - if weight_name.endswith(".safetensors"): - state_dict = {"image_proj": {}, "ip_adapter": {}} - with safe_open(model_file, framework="pt", device="cpu") as f: - image_proj_keys = ["ip_adapter_proj_model.", "image_proj."] - ip_adapter_keys = ["double_blocks.", "ip_adapter."] - for key in f.keys(): - if any(key.startswith(prefix) for prefix in image_proj_keys): - diffusers_name = ".".join(key.split(".")[1:]) - state_dict["image_proj"][diffusers_name] = f.get_tensor(key) - elif any(key.startswith(prefix) for prefix in ip_adapter_keys): - diffusers_name = ( - ".".join(key.split(".")[1:]) - .replace("ip_adapter_double_stream_k_proj", "to_k_ip") - .replace("ip_adapter_double_stream_v_proj", "to_v_ip") - .replace("processor.", "") - ) - state_dict["ip_adapter"][diffusers_name] = f.get_tensor(key) - else: - state_dict = load_state_dict(model_file) - else: - state_dict = pretrained_model_name_or_path_or_dict - - keys = list(state_dict.keys()) - if keys != ["image_proj", "ip_adapter"]: - raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.") - - state_dicts.append(state_dict) - - # load CLIP image encoder here if it has not been registered to the pipeline yet - if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None: - if image_encoder_pretrained_model_name_or_path is not None: - if not isinstance(pretrained_model_name_or_path_or_dict, dict): - logger.info(f"loading image_encoder from {image_encoder_pretrained_model_name_or_path}") - image_encoder = ( - CLIPVisionModelWithProjection.from_pretrained( - image_encoder_pretrained_model_name_or_path, - subfolder=image_encoder_subfolder, - low_cpu_mem_usage=low_cpu_mem_usage, - cache_dir=cache_dir, - local_files_only=local_files_only, - torch_dtype=image_encoder_dtype, - ) - .to(self.device) - .eval() - ) - self.register_modules(image_encoder=image_encoder) - else: - raise ValueError( - "`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict." - ) - else: - logger.warning( - "image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter." - "Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead." 
- ) - - # create feature extractor if it has not been registered to the pipeline yet - if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None: - # FaceID IP adapters don't need the image encoder so it's not present, in this case we default to 224 - default_clip_size = 224 - clip_image_size = ( - self.image_encoder.config.image_size if self.image_encoder is not None else default_clip_size - ) - feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size) - self.register_modules(feature_extractor=feature_extractor) - - # load ip-adapter into transformer - self.transformer._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage) - - def set_ip_adapter_scale(self, scale: Union[float, List[float], List[List[float]]]): - """ - Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for - granular control over each IP-Adapter behavior. A config can be a float or a list. - - `float` is converted to list and repeated for the number of blocks and the number of IP adapters. `List[float]` - length match the number of blocks, it is repeated for each IP adapter. `List[List[float]]` must match the - number of IP adapters and each must match the number of blocks. - - Example: - - ```py - # To use original IP-Adapter - scale = 1.0 - pipeline.set_ip_adapter_scale(scale) - - - def LinearStrengthModel(start, finish, size): - return [(start + (finish - start) * (i / (size - 1))) for i in range(size)] - - - ip_strengths = LinearStrengthModel(0.3, 0.92, 19) - pipeline.set_ip_adapter_scale(ip_strengths) - ``` - """ - - scale_type = Union[int, float] - num_ip_adapters = self.transformer.encoder_hid_proj.num_ip_adapters - num_layers = self.transformer.config.num_layers - - # Single value for all layers of all IP-Adapters - if isinstance(scale, scale_type): - scale = [scale for _ in range(num_ip_adapters)] - # List of per-layer scales for a single IP-Adapter - elif _is_valid_type(scale, List[scale_type]) and num_ip_adapters == 1: - scale = [scale] - # Invalid scale type - elif not _is_valid_type(scale, List[Union[scale_type, List[scale_type]]]): - raise TypeError(f"Unexpected type {_get_detailed_type(scale)} for scale.") - - if len(scale) != num_ip_adapters: - raise ValueError(f"Cannot assign {len(scale)} scales to {num_ip_adapters} IP-Adapters.") - - if any(len(s) != num_layers for s in scale if isinstance(s, list)): - invalid_scale_sizes = {len(s) for s in scale if isinstance(s, list)} - {num_layers} - raise ValueError( - f"Expected list of {num_layers} scales, got {', '.join(str(x) for x in invalid_scale_sizes)}." - ) - - # Scalars are transformed to lists with length num_layers - scale_configs = [[s] * num_layers if isinstance(s, scale_type) else s for s in scale] - - # Set scales. zip over scale_configs prevents going into single transformer layers - for attn_processor, *scale in zip(self.transformer.attn_processors.values(), *scale_configs): - attn_processor.scale = scale - - def unload_ip_adapter(self): - """ - Unloads the IP Adapter weights - - Examples: - - ```python - >>> # Assuming `pipeline` is already loaded with the IP Adapter weights. - >>> pipeline.unload_ip_adapter() - >>> ... 
- ``` - """ - # remove CLIP image encoder - if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None: - self.image_encoder = None - self.register_to_config(image_encoder=[None, None]) - - # remove feature extractor only when safety_checker is None as safety_checker uses - # the feature_extractor later - if not hasattr(self, "safety_checker"): - if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None: - self.feature_extractor = None - self.register_to_config(feature_extractor=[None, None]) - - # remove hidden encoder - self.transformer.encoder_hid_proj = None - self.transformer.config.encoder_hid_dim_type = None - - # restore original Transformer attention processors layers - attn_procs = {} - for name, value in self.transformer.attn_processors.items(): - attn_processor_class = FluxAttnProcessor2_0() - attn_procs[name] = ( - attn_processor_class if isinstance(value, (FluxIPAdapterJointAttnProcessor2_0)) else value.__class__() - ) - self.transformer.set_attn_processor(attn_procs) - - -class SD3IPAdapterMixin: - """Mixin for handling StableDiffusion 3 IP Adapters.""" - - @property - def is_ip_adapter_active(self) -> bool: - """Checks if IP-Adapter is loaded and scale > 0. - - IP-Adapter scale controls the influence of the image prompt versus text prompt. When this value is set to 0, - the image context is irrelevant. - - Returns: - `bool`: True when IP-Adapter is loaded and any layer has scale > 0. - """ - scales = [ - attn_proc.scale - for attn_proc in self.transformer.attn_processors.values() - if isinstance(attn_proc, SD3IPAdapterJointAttnProcessor2_0) - ] - - return len(scales) > 0 and any(scale > 0 for scale in scales) - - @validate_hf_hub_args - def load_ip_adapter( - self, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - weight_name: str = "ip-adapter.safetensors", - subfolder: Optional[str] = None, - image_encoder_folder: Optional[str] = "image_encoder", - **kwargs, - ) -> None: - """ - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - Can be either: - - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on - the Hub. - - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved - with [`ModelMixin.save_pretrained`]. - - A [torch state - dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). - weight_name (`str`, defaults to "ip-adapter.safetensors"): - The name of the weight file to load. If a list is passed, it should have the same length as - `subfolder`. - subfolder (`str`, *optional*): - The subfolder location of a model file within a larger model repository on the Hub or locally. If a - list is passed, it should have the same length as `weight_name`. - image_encoder_folder (`str`, *optional*, defaults to `image_encoder`): - The subfolder location of the image encoder within a larger model repository on the Hub or locally. - Pass `None` to not load the image encoder. If the image encoder is located in a folder inside - `subfolder`, you only need to pass the name of the folder that contains image encoder weights, e.g. - `image_encoder_folder="image_encoder"`. If the image encoder is located in a folder other than - `subfolder`, you should pass the path to the folder that contains image encoder weights, for example, - `image_encoder_folder="different_subfolder/image_encoder"`. 
- cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory where a downloaded pretrained model configuration is cached if the standard cache - is not used. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - local_files_only (`bool`, *optional*, defaults to `False`): - Whether to only load local model weights and configuration files or not. If set to `True`, the model - won't be downloaded from the Hub. - token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from - `diffusers-cli login` (stored in `~/.huggingface`) is used. - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier - allowed by Git. - low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): - Speed up model loading only loading the pretrained weights and not initializing the weights. This also - tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. - Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this - argument to `True` will raise an error. - """ - # Load the main state dict first - cache_dir = kwargs.pop("cache_dir", None) - force_download = kwargs.pop("force_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", None) - token = kwargs.pop("token", None) - revision = kwargs.pop("revision", None) - low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) - - if low_cpu_mem_usage and not is_accelerate_available(): - low_cpu_mem_usage = False - logger.warning( - "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" - " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" - " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" - " install accelerate\n```\n." - ) - - if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): - raise NotImplementedError( - "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" - " `low_cpu_mem_usage=False`." 
- ) - - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } - - if not isinstance(pretrained_model_name_or_path_or_dict, dict): - model_file = _get_model_file( - pretrained_model_name_or_path_or_dict, - weights_name=weight_name, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - ) - if weight_name.endswith(".safetensors"): - state_dict = {"image_proj": {}, "ip_adapter": {}} - with safe_open(model_file, framework="pt", device="cpu") as f: - for key in f.keys(): - if key.startswith("image_proj."): - state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) - elif key.startswith("ip_adapter."): - state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) - else: - state_dict = load_state_dict(model_file) - else: - state_dict = pretrained_model_name_or_path_or_dict - - keys = list(state_dict.keys()) - if "image_proj" not in keys and "ip_adapter" not in keys: - raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.") - - # Load image_encoder and feature_extractor here if they haven't been registered to the pipeline yet - if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None: - if image_encoder_folder is not None: - if not isinstance(pretrained_model_name_or_path_or_dict, dict): - logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}") - if image_encoder_folder.count("/") == 0: - image_encoder_subfolder = Path(subfolder, image_encoder_folder).as_posix() - else: - image_encoder_subfolder = Path(image_encoder_folder).as_posix() - - # Commons args for loading image encoder and image processor - kwargs = { - "low_cpu_mem_usage": low_cpu_mem_usage, - "cache_dir": cache_dir, - "local_files_only": local_files_only, - } - - self.register_modules( - feature_extractor=SiglipImageProcessor.from_pretrained(image_encoder_subfolder, **kwargs), - image_encoder=SiglipVisionModel.from_pretrained( - image_encoder_subfolder, torch_dtype=self.dtype, **kwargs - ).to(self.device), - ) - else: - raise ValueError( - "`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict." - ) - else: - logger.warning( - "image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter." - "Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead." - ) - - # Load IP-Adapter into transformer - self.transformer._load_ip_adapter_weights(state_dict, low_cpu_mem_usage=low_cpu_mem_usage) - - def set_ip_adapter_scale(self, scale: float) -> None: - """ - Set IP-Adapter scale, which controls image prompt conditioning. A value of 1.0 means the model is only - conditioned on the image prompt, and 0.0 only conditioned by the text prompt. Lowering this value encourages - the model to produce more diverse images, but they may not be as aligned with the image prompt. - - Example: - - ```python - >>> # Assuming `pipeline` is already loaded with the IP Adapter weights. - >>> pipeline.set_ip_adapter_scale(0.6) - >>> ... - ``` - - Args: - scale (float): - IP-Adapter scale to be set. 
- - """ - for attn_processor in self.transformer.attn_processors.values(): - if isinstance(attn_processor, SD3IPAdapterJointAttnProcessor2_0): - attn_processor.scale = scale - - def unload_ip_adapter(self) -> None: - """ - Unloads the IP Adapter weights. - - Example: - - ```python - >>> # Assuming `pipeline` is already loaded with the IP Adapter weights. - >>> pipeline.unload_ip_adapter() - >>> ... - ``` - """ - # Remove image encoder - if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None: - self.image_encoder = None - self.register_to_config(image_encoder=None) - - # Remove feature extractor - if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None: - self.feature_extractor = None - self.register_to_config(feature_extractor=None) - - # Remove image projection - self.transformer.image_proj = None - - # Restore original attention processors layers - attn_procs = { - name: ( - JointAttnProcessor2_0() if isinstance(value, SD3IPAdapterJointAttnProcessor2_0) else value.__class__() - ) - for name, value in self.transformer.attn_processors.items() - } - self.transformer.set_attn_processor(attn_procs) +class SD3IPAdapterMixin(SD3IPAdapterMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `SD3IPAdapterMixin` from diffusers.loaders.ip_adapter has been deprecated. Please use `from diffusers.loaders.ip_adapter.ip_adapter import SD3IPAdapterMixin` instead." + deprecate("diffusers.loaders.ip_adapter.SD3IPAdapterMixin", "0.36", deprecation_message) + super().__init__(*args, **kwargs) diff --git a/src/diffusers/loaders/ip_adapter/__init__.py b/src/diffusers/loaders/ip_adapter/__init__.py new file mode 100644 index 000000000000..42b1d6e1b509 --- /dev/null +++ b/src/diffusers/loaders/ip_adapter/__init__.py @@ -0,0 +1,9 @@ +from ...utils.import_utils import is_torch_available, is_transformers_available + + +if is_torch_available(): + from .transformer_flux import FluxTransformer2DLoadersMixin + from .transformer_sd3 import SD3Transformer2DLoadersMixin + + if is_transformers_available(): + from .ip_adapter import FluxIPAdapterMixin, IPAdapterMixin, SD3IPAdapterMixin diff --git a/src/diffusers/loaders/ip_adapter/ip_adapter.py b/src/diffusers/loaders/ip_adapter/ip_adapter.py new file mode 100644 index 000000000000..782391d84a88 --- /dev/null +++ b/src/diffusers/loaders/ip_adapter/ip_adapter.py @@ -0,0 +1,879 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from pathlib import Path +from typing import Dict, List, Optional, Union + +import torch +import torch.nn.functional as F +from huggingface_hub.utils import validate_hf_hub_args +from safetensors import safe_open + +from ...models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict +from ...utils import ( + USE_PEFT_BACKEND, + _get_detailed_type, + _get_model_file, + _is_valid_type, + is_accelerate_available, + is_torch_version, + is_transformers_available, + logging, +) +from ..unet.unet_loader_utils import _maybe_expand_lora_scales + + +if is_transformers_available(): + from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, SiglipImageProcessor, SiglipVisionModel + +from ...models.attention_processor import ( + AttnProcessor, + AttnProcessor2_0, + FluxAttnProcessor2_0, + FluxIPAdapterJointAttnProcessor2_0, + IPAdapterAttnProcessor, + IPAdapterAttnProcessor2_0, + IPAdapterXFormersAttnProcessor, + JointAttnProcessor2_0, + SD3IPAdapterJointAttnProcessor2_0, +) + + +logger = logging.get_logger(__name__) + + +class IPAdapterMixin: + """Mixin for handling IP Adapters.""" + + @validate_hf_hub_args + def load_ip_adapter( + self, + pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]], + subfolder: Union[str, List[str]], + weight_name: Union[str, List[str]], + image_encoder_folder: Optional[str] = "image_encoder", + **kwargs, + ): + """ + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + subfolder (`str` or `List[str]`): + The subfolder location of a model file within a larger model repository on the Hub or locally. If a + list is passed, it should have the same length as `weight_name`. + weight_name (`str` or `List[str]`): + The name of the weight file to load. If a list is passed, it should have the same length as + `subfolder`. + image_encoder_folder (`str`, *optional*, defaults to `image_encoder`): + The subfolder location of the image encoder within a larger model repository on the Hub or locally. + Pass `None` to not load the image encoder. If the image encoder is located in a folder inside + `subfolder`, you only need to pass the name of the folder that contains image encoder weights, e.g. + `image_encoder_folder="image_encoder"`. If the image encoder is located in a folder other than + `subfolder`, you should pass the path to the folder that contains image encoder weights, for example, + `image_encoder_folder="different_subfolder/image_encoder"`. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. 
+ local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading by only loading the pretrained weights and not initializing the weights. This also + tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. + Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this + argument to `True` will raise an error. + """ + + # handle the list inputs for multiple IP Adapters + if not isinstance(weight_name, list): + weight_name = [weight_name] + + if not isinstance(pretrained_model_name_or_path_or_dict, list): + pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict] + if len(pretrained_model_name_or_path_or_dict) == 1: + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name) + + if not isinstance(subfolder, list): + subfolder = [subfolder] + if len(subfolder) == 1: + subfolder = subfolder * len(weight_name) + + if len(weight_name) != len(pretrained_model_name_or_path_or_dict): + raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.") + + if len(weight_name) != len(subfolder): + raise ValueError("`weight_name` and `subfolder` must have the same length.") + + # Load the main state dict first. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) + + if low_cpu_mem_usage and not is_accelerate_available(): + low_cpu_mem_usage = False + logger.warning( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" + " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" + " `accelerate` for faster and less memory-intensive model loading. You can do so with: \n```\npip" + " install accelerate\n```\n." + ) + + if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `low_cpu_mem_usage=False`."
+ ) + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + state_dicts = [] + for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip( + pretrained_model_name_or_path_or_dict, weight_name, subfolder + ): + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + if weight_name.endswith(".safetensors"): + state_dict = {"image_proj": {}, "ip_adapter": {}} + with safe_open(model_file, framework="pt", device="cpu") as f: + for key in f.keys(): + if key.startswith("image_proj."): + state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) + elif key.startswith("ip_adapter."): + state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) + else: + state_dict = load_state_dict(model_file) + else: + state_dict = pretrained_model_name_or_path_or_dict + + keys = list(state_dict.keys()) + if "image_proj" not in keys and "ip_adapter" not in keys: + raise ValueError("Required keys (`image_proj` and `ip_adapter`) are missing from the state dict.") + + state_dicts.append(state_dict) + + # load CLIP image encoder here if it has not been registered to the pipeline yet + if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None: + if image_encoder_folder is not None: + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}") + if image_encoder_folder.count("/") == 0: + image_encoder_subfolder = Path(subfolder, image_encoder_folder).as_posix() + else: + image_encoder_subfolder = Path(image_encoder_folder).as_posix() + + image_encoder = CLIPVisionModelWithProjection.from_pretrained( + pretrained_model_name_or_path_or_dict, + subfolder=image_encoder_subfolder, + low_cpu_mem_usage=low_cpu_mem_usage, + cache_dir=cache_dir, + local_files_only=local_files_only, + torch_dtype=self.dtype, + ).to(self.device) + self.register_modules(image_encoder=image_encoder) + else: + raise ValueError( + "`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict." + ) + else: + logger.warning( + "image_encoder is not loaded since `image_encoder_folder=None` was passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter." + " Use `ip_adapter_image_embeds` to pass pre-generated image embeddings instead."
+ ) + + # create feature extractor if it has not been registered to the pipeline yet + if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None: + # FaceID IP adapters don't need the image encoder so it's not present, in this case we default to 224 + default_clip_size = 224 + clip_image_size = ( + self.image_encoder.config.image_size if self.image_encoder is not None else default_clip_size + ) + feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size) + self.register_modules(feature_extractor=feature_extractor) + + # load ip-adapter into unet + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet + unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage) + + extra_loras = unet._load_ip_adapter_loras(state_dicts) + if extra_loras != {}: + if not USE_PEFT_BACKEND: + logger.warning("PEFT backend is required to load these weights.") + else: + # apply the IP Adapter Face ID LoRA weights + peft_config = getattr(unet, "peft_config", {}) + for k, lora in extra_loras.items(): + if f"faceid_{k}" not in peft_config: + self.load_lora_weights(lora, adapter_name=f"faceid_{k}") + self.set_adapters([f"faceid_{k}"], adapter_weights=[1.0]) + + def set_ip_adapter_scale(self, scale): + """ + Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for + granular control over each IP-Adapter behavior. A config can be a float or a dictionary. + + Example: + + ```py + # To use original IP-Adapter + scale = 1.0 + pipeline.set_ip_adapter_scale(scale) + + # To use style block only + scale = { + "up": {"block_0": [0.0, 1.0, 0.0]}, + } + pipeline.set_ip_adapter_scale(scale) + + # To use style+layout blocks + scale = { + "down": {"block_2": [0.0, 1.0]}, + "up": {"block_0": [0.0, 1.0, 0.0]}, + } + pipeline.set_ip_adapter_scale(scale) + + # To use style and layout from 2 reference images + scales = [{"down": {"block_2": [0.0, 1.0]}}, {"up": {"block_0": [0.0, 1.0, 0.0]}}] + pipeline.set_ip_adapter_scale(scales) + ``` + """ + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet + if not isinstance(scale, list): + scale = [scale] + scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0) + + for attn_name, attn_processor in unet.attn_processors.items(): + if isinstance( + attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor) + ): + if len(scale_configs) != len(attn_processor.scale): + raise ValueError( + f"Cannot assign {len(scale_configs)} scale_configs to {len(attn_processor.scale)} IP-Adapter." + ) + elif len(scale_configs) == 1: + scale_configs = scale_configs * len(attn_processor.scale) + for i, scale_config in enumerate(scale_configs): + if isinstance(scale_config, dict): + for k, s in scale_config.items(): + if attn_name.startswith(k): + attn_processor.scale[i] = s + else: + attn_processor.scale[i] = scale_config + + def unload_ip_adapter(self): + """ + Unloads the IP Adapter weights + + Examples: + + ```python + >>> # Assuming `pipeline` is already loaded with the IP Adapter weights. + >>> pipeline.unload_ip_adapter() + >>> ... 
+ ```
+ """
+ # remove CLIP image encoder
+ if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
+ self.image_encoder = None
+ self.register_to_config(image_encoder=[None, None])
+
+ # remove feature extractor only when safety_checker is None as safety_checker uses
+ # the feature_extractor later
+ if not hasattr(self, "safety_checker"):
+ if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
+ self.feature_extractor = None
+ self.register_to_config(feature_extractor=[None, None])
+
+ # remove hidden encoder
+ self.unet.encoder_hid_proj = None
+ self.unet.config.encoder_hid_dim_type = None
+
+ # Kolors: restore `encoder_hid_proj` with `text_encoder_hid_proj`
+ if hasattr(self.unet, "text_encoder_hid_proj") and self.unet.text_encoder_hid_proj is not None:
+ self.unet.encoder_hid_proj = self.unet.text_encoder_hid_proj
+ self.unet.text_encoder_hid_proj = None
+ self.unet.config.encoder_hid_dim_type = "text_proj"
+
+ # restore original Unet attention processors layers
+ attn_procs = {}
+ for name, value in self.unet.attn_processors.items():
+ attn_processor_class = (
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor()
+ )
+ attn_procs[name] = (
+ attn_processor_class
+ if isinstance(
+ value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
+ )
+ else value.__class__()
+ )
+ self.unet.set_attn_processor(attn_procs)
+
+
+class FluxIPAdapterMixin:
+ """Mixin for handling Flux IP Adapters."""
+
+ @validate_hf_hub_args
+ def load_ip_adapter(
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
+ weight_name: Union[str, List[str]],
+ subfolder: Optional[Union[str, List[str]]] = "",
+ image_encoder_pretrained_model_name_or_path: Optional[str] = "image_encoder",
+ image_encoder_subfolder: Optional[str] = "",
+ image_encoder_dtype: torch.dtype = torch.float16,
+ **kwargs,
+ ):
+ """
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+ subfolder (`str` or `List[str]`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally. If a
+ list is passed, it should have the same length as `weight_name`.
+ weight_name (`str` or `List[str]`):
+ The name of the weight file to load. If a list is passed, it should have the same length as
+ `pretrained_model_name_or_path_or_dict`.
+ image_encoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `"image_encoder"`):
+ Can be either:
+
+ - A string, the *model id* (for example `openai/clip-vit-large-patch14`) of a pretrained model
+ hosted on the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading only loading the pretrained weights and not initializing the weights. This also + tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. + Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this + argument to `True` will raise an error. + """ + + # handle the list inputs for multiple IP Adapters + if not isinstance(weight_name, list): + weight_name = [weight_name] + + if not isinstance(pretrained_model_name_or_path_or_dict, list): + pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict] + if len(pretrained_model_name_or_path_or_dict) == 1: + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name) + + if not isinstance(subfolder, list): + subfolder = [subfolder] + if len(subfolder) == 1: + subfolder = subfolder * len(weight_name) + + if len(weight_name) != len(pretrained_model_name_or_path_or_dict): + raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.") + + if len(weight_name) != len(subfolder): + raise ValueError("`weight_name` and `subfolder` must have the same length.") + + # Load the main state dict first. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) + + if low_cpu_mem_usage and not is_accelerate_available(): + low_cpu_mem_usage = False + logger.warning( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" + " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" + " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" + " install accelerate\n```\n." + ) + + if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `low_cpu_mem_usage=False`." 
+ )
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+ state_dicts = []
+ for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
+ pretrained_model_name_or_path_or_dict, weight_name, subfolder
+ ):
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+ model_file = _get_model_file(
+ pretrained_model_name_or_path_or_dict,
+ weights_name=weight_name,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ )
+ if weight_name.endswith(".safetensors"):
+ state_dict = {"image_proj": {}, "ip_adapter": {}}
+ with safe_open(model_file, framework="pt", device="cpu") as f:
+ image_proj_keys = ["ip_adapter_proj_model.", "image_proj."]
+ ip_adapter_keys = ["double_blocks.", "ip_adapter."]
+ for key in f.keys():
+ if any(key.startswith(prefix) for prefix in image_proj_keys):
+ diffusers_name = ".".join(key.split(".")[1:])
+ state_dict["image_proj"][diffusers_name] = f.get_tensor(key)
+ elif any(key.startswith(prefix) for prefix in ip_adapter_keys):
+ diffusers_name = (
+ ".".join(key.split(".")[1:])
+ .replace("ip_adapter_double_stream_k_proj", "to_k_ip")
+ .replace("ip_adapter_double_stream_v_proj", "to_v_ip")
+ .replace("processor.", "")
+ )
+ state_dict["ip_adapter"][diffusers_name] = f.get_tensor(key)
+ else:
+ state_dict = load_state_dict(model_file)
+ else:
+ state_dict = pretrained_model_name_or_path_or_dict
+
+ keys = list(state_dict.keys())
+ if keys != ["image_proj", "ip_adapter"]:
+ raise ValueError("Required keys (`image_proj` and `ip_adapter`) are missing from the state dict.")
+
+ state_dicts.append(state_dict)
+
+ # load CLIP image encoder here if it has not been registered to the pipeline yet
+ if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
+ if image_encoder_pretrained_model_name_or_path is not None:
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+ logger.info(f"loading image_encoder from {image_encoder_pretrained_model_name_or_path}")
+ image_encoder = (
+ CLIPVisionModelWithProjection.from_pretrained(
+ image_encoder_pretrained_model_name_or_path,
+ subfolder=image_encoder_subfolder,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ cache_dir=cache_dir,
+ local_files_only=local_files_only,
+ dtype=image_encoder_dtype,
+ )
+ .to(self.device)
+ .eval()
+ )
+ self.register_modules(image_encoder=image_encoder)
+ else:
+ raise ValueError(
+ "`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
+ )
+ else:
+ logger.warning(
+ "image_encoder is not loaded since `image_encoder_pretrained_model_name_or_path=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter. "
+ "Use `ip_adapter_image_embeds` to pass pre-generated image embeddings instead."
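The renaming above maps original-format Flux IP-Adapter keys (`double_blocks.*` names) onto the diffusers attention-processor attributes. Traced by hand on one hypothetical checkpoint key:

```py
key = "double_blocks.0.processor.ip_adapter_double_stream_k_proj.weight"

diffusers_name = (
    ".".join(key.split(".")[1:])  # drop the "double_blocks" prefix
    .replace("ip_adapter_double_stream_k_proj", "to_k_ip")
    .replace("ip_adapter_double_stream_v_proj", "to_v_ip")
    .replace("processor.", "")
)
assert diffusers_name == "0.to_k_ip.weight"
```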
+ )
+
+ # create feature extractor if it has not been registered to the pipeline yet
+ if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
+ # FaceID IP adapters don't need the image encoder so it's not present, in this case we default to 224
+ default_clip_size = 224
+ clip_image_size = (
+ self.image_encoder.config.image_size if self.image_encoder is not None else default_clip_size
+ )
+ feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size)
+ self.register_modules(feature_extractor=feature_extractor)
+
+ # load ip-adapter into transformer
+ self.transformer._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
+
+ def set_ip_adapter_scale(self, scale: Union[float, List[float], List[List[float]]]):
+ """
+ Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
+ granular control over each IP-Adapter behavior. A config can be a float or a list.
+
+ A `float` is converted to a list and repeated across all blocks and all IP-Adapters. A `List[float]` whose
+ length matches the number of blocks is repeated for each IP-Adapter. A `List[List[float]]` must match the
+ number of IP-Adapters, and each inner list must match the number of blocks.
+
+ Example:
+
+ ```py
+ # To use original IP-Adapter
+ scale = 1.0
+ pipeline.set_ip_adapter_scale(scale)
+
+
+ def LinearStrengthModel(start, finish, size):
+ return [(start + (finish - start) * (i / (size - 1))) for i in range(size)]
+
+
+ ip_strengths = LinearStrengthModel(0.3, 0.92, 19)
+ pipeline.set_ip_adapter_scale(ip_strengths)
+ ```
+ """
+
+ scale_type = Union[int, float]
+ num_ip_adapters = self.transformer.encoder_hid_proj.num_ip_adapters
+ num_layers = self.transformer.config.num_layers
+
+ # Single value for all layers of all IP-Adapters
+ if isinstance(scale, scale_type):
+ scale = [scale for _ in range(num_ip_adapters)]
+ # List of per-layer scales for a single IP-Adapter
+ elif _is_valid_type(scale, List[scale_type]) and num_ip_adapters == 1:
+ scale = [scale]
+ # Invalid scale type
+ elif not _is_valid_type(scale, List[Union[scale_type, List[scale_type]]]):
+ raise TypeError(f"Unexpected type {_get_detailed_type(scale)} for scale.")
+
+ if len(scale) != num_ip_adapters:
+ raise ValueError(f"Cannot assign {len(scale)} scales to {num_ip_adapters} IP-Adapters.")
+
+ if any(len(s) != num_layers for s in scale if isinstance(s, list)):
+ invalid_scale_sizes = {len(s) for s in scale if isinstance(s, list)} - {num_layers}
+ raise ValueError(
+ f"Expected list of {num_layers} scales, got {', '.join(str(x) for x in invalid_scale_sizes)}."
+ )
+
+ # Scalars are transformed to lists with length num_layers
+ scale_configs = [[s] * num_layers if isinstance(s, scale_type) else s for s in scale]
+
+ # Set scales; zipping over scale_configs stops before the single transformer blocks
+ for attn_processor, *scale in zip(self.transformer.attn_processors.values(), *scale_configs):
+ attn_processor.scale = scale
+
+ def unload_ip_adapter(self):
+ """
+ Unloads the IP Adapter weights.
+
+ Examples:
+
+ ```python
+ >>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
+ >>> pipeline.unload_ip_adapter()
+ >>> ...
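Beyond the single-adapter ramp in the docstring above, the same method accepts one per-layer list per loaded adapter. A sketch, assuming a pipeline with two Flux IP-Adapters already loaded and 19 double-stream blocks (the Flux.1 [dev] value; read `pipeline.transformer.config.num_layers` in practice):

```py
num_layers = 19  # assumed; Flux.1 [dev] has 19 double-stream blocks

constant = [0.5] * num_layers                             # adapter 0: flat influence
ramp = [i / (num_layers - 1) for i in range(num_layers)]  # adapter 1: 0.0 -> 1.0

# Outer length must equal the number of IP-Adapters; inner lengths must
# equal num_layers, or a ValueError is raised as shown above.
pipeline.set_ip_adapter_scale([constant, ramp])
```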
+ ``` + """ + # remove CLIP image encoder + if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None: + self.image_encoder = None + self.register_to_config(image_encoder=[None, None]) + + # remove feature extractor only when safety_checker is None as safety_checker uses + # the feature_extractor later + if not hasattr(self, "safety_checker"): + if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None: + self.feature_extractor = None + self.register_to_config(feature_extractor=[None, None]) + + # remove hidden encoder + self.transformer.encoder_hid_proj = None + self.transformer.config.encoder_hid_dim_type = None + + # restore original Transformer attention processors layers + attn_procs = {} + for name, value in self.transformer.attn_processors.items(): + attn_processor_class = FluxAttnProcessor2_0() + attn_procs[name] = ( + attn_processor_class if isinstance(value, (FluxIPAdapterJointAttnProcessor2_0)) else value.__class__() + ) + self.transformer.set_attn_processor(attn_procs) + + +class SD3IPAdapterMixin: + """Mixin for handling StableDiffusion 3 IP Adapters.""" + + @property + def is_ip_adapter_active(self) -> bool: + """Checks if IP-Adapter is loaded and scale > 0. + + IP-Adapter scale controls the influence of the image prompt versus text prompt. When this value is set to 0, + the image context is irrelevant. + + Returns: + `bool`: True when IP-Adapter is loaded and any layer has scale > 0. + """ + scales = [ + attn_proc.scale + for attn_proc in self.transformer.attn_processors.values() + if isinstance(attn_proc, SD3IPAdapterJointAttnProcessor2_0) + ] + + return len(scales) > 0 and any(scale > 0 for scale in scales) + + @validate_hf_hub_args + def load_ip_adapter( + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + weight_name: str = "ip-adapter.safetensors", + subfolder: Optional[str] = None, + image_encoder_folder: Optional[str] = "image_encoder", + **kwargs, + ) -> None: + """ + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + weight_name (`str`, defaults to "ip-adapter.safetensors"): + The name of the weight file to load. If a list is passed, it should have the same length as + `subfolder`. + subfolder (`str`, *optional*): + The subfolder location of a model file within a larger model repository on the Hub or locally. If a + list is passed, it should have the same length as `weight_name`. + image_encoder_folder (`str`, *optional*, defaults to `image_encoder`): + The subfolder location of the image encoder within a larger model repository on the Hub or locally. + Pass `None` to not load the image encoder. If the image encoder is located in a folder inside + `subfolder`, you only need to pass the name of the folder that contains image encoder weights, e.g. + `image_encoder_folder="image_encoder"`. If the image encoder is located in a folder other than + `subfolder`, you should pass the path to the folder that contains image encoder weights, for example, + `image_encoder_folder="different_subfolder/image_encoder"`. 
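Because a zero scale makes the image context irrelevant, `is_ip_adapter_active` lends itself to a simple guard. A sketch with placeholder `prompt` and `ref_image` variables:

```py
if pipeline.is_ip_adapter_active:
    image = pipeline(prompt=prompt, ip_adapter_image=ref_image).images[0]
else:
    # Either no IP-Adapter is loaded or every layer's scale is 0.0,
    # so image conditioning can be skipped entirely.
    image = pipeline(prompt=prompt).images[0]
```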
+ cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading only loading the pretrained weights and not initializing the weights. This also + tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. + Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this + argument to `True` will raise an error. + """ + # Load the main state dict first + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) + + if low_cpu_mem_usage and not is_accelerate_available(): + low_cpu_mem_usage = False + logger.warning( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" + " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" + " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" + " install accelerate\n```\n." + ) + + if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `low_cpu_mem_usage=False`." 
+ )
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+ model_file = _get_model_file(
+ pretrained_model_name_or_path_or_dict,
+ weights_name=weight_name,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ )
+ if weight_name.endswith(".safetensors"):
+ state_dict = {"image_proj": {}, "ip_adapter": {}}
+ with safe_open(model_file, framework="pt", device="cpu") as f:
+ for key in f.keys():
+ if key.startswith("image_proj."):
+ state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
+ elif key.startswith("ip_adapter."):
+ state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
+ else:
+ state_dict = load_state_dict(model_file)
+ else:
+ state_dict = pretrained_model_name_or_path_or_dict
+
+ keys = list(state_dict.keys())
+ if "image_proj" not in keys and "ip_adapter" not in keys:
+ raise ValueError("Required keys (`image_proj` and `ip_adapter`) are missing from the state dict.")
+
+ # Load image_encoder and feature_extractor here if they haven't been registered to the pipeline yet
+ if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
+ if image_encoder_folder is not None:
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+ logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
+ if image_encoder_folder.count("/") == 0:
+ image_encoder_subfolder = Path(subfolder, image_encoder_folder).as_posix()
+ else:
+ image_encoder_subfolder = Path(image_encoder_folder).as_posix()
+
+ # Common args for loading image encoder and image processor
+ kwargs = {
+ "low_cpu_mem_usage": low_cpu_mem_usage,
+ "cache_dir": cache_dir,
+ "local_files_only": local_files_only,
+ }
+
+ self.register_modules(
+ feature_extractor=SiglipImageProcessor.from_pretrained(image_encoder_subfolder, **kwargs),
+ image_encoder=SiglipVisionModel.from_pretrained(
+ image_encoder_subfolder, torch_dtype=self.dtype, **kwargs
+ ).to(self.device),
+ )
+ else:
+ raise ValueError(
+ "`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
+ )
+ else:
+ logger.warning(
+ "image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter. "
+ "Use `ip_adapter_image_embeds` to pass pre-generated image embeddings instead."
+ )
+
+ # Load IP-Adapter into transformer
+ self.transformer._load_ip_adapter_weights(state_dict, low_cpu_mem_usage=low_cpu_mem_usage)
+
+ def set_ip_adapter_scale(self, scale: float) -> None:
+ """
+ Set IP-Adapter scale, which controls image prompt conditioning. A value of 1.0 means the model is only
+ conditioned on the image prompt, and 0.0 means it is conditioned only on the text prompt. Lowering this
+ value encourages the model to produce more diverse images, but they may not be as aligned with the image
+ prompt.
+
+ Example:
+
+ ```python
+ >>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
+ >>> pipeline.set_ip_adapter_scale(0.6)
+ >>> ...
+ ```
+
+ Args:
+ scale (float):
+ IP-Adapter scale to be set.
+ + """ + for attn_processor in self.transformer.attn_processors.values(): + if isinstance(attn_processor, SD3IPAdapterJointAttnProcessor2_0): + attn_processor.scale = scale + + def unload_ip_adapter(self) -> None: + """ + Unloads the IP Adapter weights. + + Example: + + ```python + >>> # Assuming `pipeline` is already loaded with the IP Adapter weights. + >>> pipeline.unload_ip_adapter() + >>> ... + ``` + """ + # Remove image encoder + if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None: + self.image_encoder = None + self.register_to_config(image_encoder=None) + + # Remove feature extractor + if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None: + self.feature_extractor = None + self.register_to_config(feature_extractor=None) + + # Remove image projection + self.transformer.image_proj = None + + # Restore original attention processors layers + attn_procs = { + name: ( + JointAttnProcessor2_0() if isinstance(value, SD3IPAdapterJointAttnProcessor2_0) else value.__class__() + ) + for name, value in self.transformer.attn_processors.items() + } + self.transformer.set_attn_processor(attn_procs) diff --git a/src/diffusers/loaders/ip_adapter/transformer_flux.py b/src/diffusers/loaders/ip_adapter/transformer_flux.py new file mode 100644 index 000000000000..5d6e7c809b12 --- /dev/null +++ b/src/diffusers/loaders/ip_adapter/transformer_flux.py @@ -0,0 +1,168 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from contextlib import nullcontext + +from ...models.embeddings import ImageProjection, MultiIPAdapterImageProjection +from ...models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta +from ...utils import is_accelerate_available, is_torch_version, logging + + +logger = logging.get_logger(__name__) + + +class FluxTransformer2DLoadersMixin: + """ + Load layers into a [`FluxTransformer2DModel`]. + """ + + def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT): + if low_cpu_mem_usage: + if is_accelerate_available(): + from accelerate import init_empty_weights + + else: + low_cpu_mem_usage = False + logger.warning( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" + " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" + " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" + " install accelerate\n```\n." + ) + + if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `low_cpu_mem_usage=False`." 
+ ) + + updated_state_dict = {} + image_projection = None + init_context = init_empty_weights if low_cpu_mem_usage else nullcontext + + if "proj.weight" in state_dict: + # IP-Adapter + num_image_text_embeds = 4 + if state_dict["proj.weight"].shape[0] == 65536: + num_image_text_embeds = 16 + clip_embeddings_dim = state_dict["proj.weight"].shape[-1] + cross_attention_dim = state_dict["proj.weight"].shape[0] // num_image_text_embeds + + with init_context(): + image_projection = ImageProjection( + cross_attention_dim=cross_attention_dim, + image_embed_dim=clip_embeddings_dim, + num_image_text_embeds=num_image_text_embeds, + ) + + for key, value in state_dict.items(): + diffusers_name = key.replace("proj", "image_embeds") + updated_state_dict[diffusers_name] = value + + if not low_cpu_mem_usage: + image_projection.load_state_dict(updated_state_dict, strict=True) + else: + device_map = {"": self.device} + load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype) + + return image_projection + + def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT): + from ...models.attention_processor import FluxIPAdapterJointAttnProcessor2_0 + + if low_cpu_mem_usage: + if is_accelerate_available(): + from accelerate import init_empty_weights + + else: + low_cpu_mem_usage = False + logger.warning( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" + " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" + " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" + " install accelerate\n```\n." + ) + + if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `low_cpu_mem_usage=False`." 
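Every `ImageProjection` dimension above is inferred from the shape of `proj.weight` alone, which is `(num_image_text_embeds * cross_attention_dim, clip_embeddings_dim)`. The arithmetic, with a made-up tensor:

```py
import torch

proj_weight = torch.zeros(65536, 768)  # 65536 rows selects the 16-token variant

num_image_text_embeds = 16 if proj_weight.shape[0] == 65536 else 4
clip_embeddings_dim = proj_weight.shape[-1]                           # 768
cross_attention_dim = proj_weight.shape[0] // num_image_text_embeds  # 4096
```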
+ ) + + # set ip-adapter cross-attention processors & load state_dict + attn_procs = {} + key_id = 0 + init_context = init_empty_weights if low_cpu_mem_usage else nullcontext + for name in self.attn_processors.keys(): + if name.startswith("single_transformer_blocks"): + attn_processor_class = self.attn_processors[name].__class__ + attn_procs[name] = attn_processor_class() + else: + cross_attention_dim = self.config.joint_attention_dim + hidden_size = self.inner_dim + attn_processor_class = FluxIPAdapterJointAttnProcessor2_0 + num_image_text_embeds = [] + for state_dict in state_dicts: + if "proj.weight" in state_dict["image_proj"]: + num_image_text_embed = 4 + if state_dict["image_proj"]["proj.weight"].shape[0] == 65536: + num_image_text_embed = 16 + # IP-Adapter + num_image_text_embeds += [num_image_text_embed] + + with init_context(): + attn_procs[name] = attn_processor_class( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=1.0, + num_tokens=num_image_text_embeds, + dtype=self.dtype, + device=self.device, + ) + + value_dict = {} + for i, state_dict in enumerate(state_dicts): + value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]}) + value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]}) + value_dict.update({f"to_k_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_k_ip.bias"]}) + value_dict.update({f"to_v_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_v_ip.bias"]}) + + if not low_cpu_mem_usage: + attn_procs[name].load_state_dict(value_dict) + else: + device_map = {"": self.device} + dtype = self.dtype + load_model_dict_into_meta(attn_procs[name], value_dict, device_map=device_map, dtype=dtype) + + key_id += 1 + + return attn_procs + + def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT): + if not isinstance(state_dicts, list): + state_dicts = [state_dicts] + + self.encoder_hid_proj = None + + attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage) + self.set_attn_processor(attn_procs) + + image_projection_layers = [] + for state_dict in state_dicts: + image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers( + state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage + ) + image_projection_layers.append(image_projection_layer) + + self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers) + self.config.encoder_hid_dim_type = "ip_image_proj" diff --git a/src/diffusers/loaders/ip_adapter/transformer_sd3.py b/src/diffusers/loaders/ip_adapter/transformer_sd3.py new file mode 100644 index 000000000000..5911ec5903d5 --- /dev/null +++ b/src/diffusers/loaders/ip_adapter/transformer_sd3.py @@ -0,0 +1,170 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
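The `value_dict` built in the loop above indexes each adapter's key/value projections by position, so one attention processor hosts all loaded adapters for a given block. The four `update` calls compress to this equivalent sketch (toy values stand in for real tensors):

```py
key_id = 3  # running index of the double-stream block being populated
state_dicts = [
    {"ip_adapter": {f"{key_id}.{p}.{q}": f"adapter{i}_{p}_{q}"
                    for p in ("to_k_ip", "to_v_ip") for q in ("weight", "bias")}}
    for i in range(2)
]

value_dict = {}
for i, state_dict in enumerate(state_dicts):
    for proj in ("to_k_ip", "to_v_ip"):
        for param in ("weight", "bias"):
            value_dict[f"{proj}.{i}.{param}"] = state_dict["ip_adapter"][f"{key_id}.{proj}.{param}"]

assert value_dict["to_k_ip.1.weight"] == "adapter1_to_k_ip_weight"
```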
+from contextlib import nullcontext +from typing import Dict + +from ...models.attention_processor import SD3IPAdapterJointAttnProcessor2_0 +from ...models.embeddings import IPAdapterTimeImageProjection +from ...models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta +from ...utils import is_accelerate_available, is_torch_version, logging + + +logger = logging.get_logger(__name__) + + +class SD3Transformer2DLoadersMixin: + """Load IP-Adapters and LoRA layers into a `[SD3Transformer2DModel]`.""" + + def _convert_ip_adapter_attn_to_diffusers( + self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT + ) -> Dict: + if low_cpu_mem_usage: + if is_accelerate_available(): + from accelerate import init_empty_weights + + else: + low_cpu_mem_usage = False + logger.warning( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" + " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" + " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" + " install accelerate\n```\n." + ) + + if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `low_cpu_mem_usage=False`." + ) + + # IP-Adapter cross attention parameters + hidden_size = self.config.attention_head_dim * self.config.num_attention_heads + ip_hidden_states_dim = self.config.attention_head_dim * self.config.num_attention_heads + timesteps_emb_dim = state_dict["0.norm_ip.linear.weight"].shape[1] + + # Dict where key is transformer layer index, value is attention processor's state dict + # ip_adapter state dict keys example: "0.norm_ip.linear.weight" + layer_state_dict = {idx: {} for idx in range(len(self.attn_processors))} + for key, weights in state_dict.items(): + idx, name = key.split(".", maxsplit=1) + layer_state_dict[int(idx)][name] = weights + + # Create IP-Adapter attention processor & load state_dict + attn_procs = {} + init_context = init_empty_weights if low_cpu_mem_usage else nullcontext + for idx, name in enumerate(self.attn_processors.keys()): + with init_context(): + attn_procs[name] = SD3IPAdapterJointAttnProcessor2_0( + hidden_size=hidden_size, + ip_hidden_states_dim=ip_hidden_states_dim, + head_dim=self.config.attention_head_dim, + timesteps_emb_dim=timesteps_emb_dim, + ) + + if not low_cpu_mem_usage: + attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True) + else: + device_map = {"": self.device} + load_model_dict_into_meta( + attn_procs[name], layer_state_dict[idx], device_map=device_map, dtype=self.dtype + ) + + return attn_procs + + def _convert_ip_adapter_image_proj_to_diffusers( + self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT + ) -> IPAdapterTimeImageProjection: + if low_cpu_mem_usage: + if is_accelerate_available(): + from accelerate import init_empty_weights + + else: + low_cpu_mem_usage = False + logger.warning( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" + " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" + " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" + " install accelerate\n```\n." 
+ )
+
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+ raise NotImplementedError(
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+ " `low_cpu_mem_usage=False`."
+ )
+
+ init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
+
+ # Convert to diffusers
+ updated_state_dict = {}
+ for key, value in state_dict.items():
+ # InstantX/SD3.5-Large-IP-Adapter
+ if key.startswith("layers."):
+ idx = key.split(".")[1]
+ key = key.replace(f"layers.{idx}.0.norm1", f"layers.{idx}.ln0")
+ key = key.replace(f"layers.{idx}.0.norm2", f"layers.{idx}.ln1")
+ key = key.replace(f"layers.{idx}.0.to_q", f"layers.{idx}.attn.to_q")
+ key = key.replace(f"layers.{idx}.0.to_kv", f"layers.{idx}.attn.to_kv")
+ key = key.replace(f"layers.{idx}.0.to_out", f"layers.{idx}.attn.to_out.0")
+ key = key.replace(f"layers.{idx}.1.0", f"layers.{idx}.adaln_norm")
+ key = key.replace(f"layers.{idx}.1.1", f"layers.{idx}.ff.net.0.proj")
+ key = key.replace(f"layers.{idx}.1.3", f"layers.{idx}.ff.net.2")
+ key = key.replace(f"layers.{idx}.2.1", f"layers.{idx}.adaln_proj")
+ updated_state_dict[key] = value
+
+ # Image projection parameters
+ embed_dim = updated_state_dict["proj_in.weight"].shape[1]
+ output_dim = updated_state_dict["proj_out.weight"].shape[0]
+ hidden_dim = updated_state_dict["proj_in.weight"].shape[0]
+ heads = updated_state_dict["layers.0.attn.to_q.weight"].shape[0] // 64
+ num_queries = updated_state_dict["latents"].shape[1]
+ timestep_in_dim = updated_state_dict["time_embedding.linear_1.weight"].shape[1]
+
+ # Image projection
+ with init_context():
+ image_proj = IPAdapterTimeImageProjection(
+ embed_dim=embed_dim,
+ output_dim=output_dim,
+ hidden_dim=hidden_dim,
+ heads=heads,
+ num_queries=num_queries,
+ timestep_in_dim=timestep_in_dim,
+ )
+
+ if not low_cpu_mem_usage:
+ image_proj.load_state_dict(updated_state_dict, strict=True)
+ else:
+ device_map = {"": self.device}
+ load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)
+
+ return image_proj
+
+ def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT) -> None:
+ """Sets IP-Adapter attention processors, image projection, and loads state_dict.
+
+ Args:
+ state_dict (`Dict`):
+ State dict with keys "ip_adapter", which contains parameters for attention processors, and
+ "image_proj", which contains parameters for image projection net.
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+ Speed up model loading only loading the pretrained weights and not initializing the weights. This also
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+ argument to `True` will raise an error.
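The InstantX branch above rewrites `layers.<idx>.*` checkpoint keys into the diffusers `IPAdapterTimeImageProjection` naming before any shapes are read. One hypothetical key traced through the relevant rename:

```py
key = "layers.3.0.to_kv.weight"  # hypothetical InstantX-format key
idx = key.split(".")[1]
key = key.replace(f"layers.{idx}.0.to_kv", f"layers.{idx}.attn.to_kv")
assert key == "layers.3.attn.to_kv.weight"
```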
+ """ + + attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dict["ip_adapter"], low_cpu_mem_usage) + self.set_attn_processor(attn_procs) + + self.image_proj = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"], low_cpu_mem_usage) diff --git a/src/diffusers/loaders/lora/__init__.py b/src/diffusers/loaders/lora/__init__.py new file mode 100644 index 000000000000..a0f2111cfcab --- /dev/null +++ b/src/diffusers/loaders/lora/__init__.py @@ -0,0 +1,25 @@ +from ...utils import is_peft_available, is_torch_available, is_transformers_available + + +if is_torch_available(): + from .lora_base import LoraBaseMixin + + if is_transformers_available(): + from .lora_pipeline import ( + AmusedLoraLoaderMixin, + AuraFlowLoraLoaderMixin, + CogVideoXLoraLoaderMixin, + CogView4LoraLoaderMixin, + FluxLoraLoaderMixin, + HiDreamImageLoraLoaderMixin, + HunyuanVideoLoraLoaderMixin, + LoraLoaderMixin, + LTXVideoLoraLoaderMixin, + Lumina2LoraLoaderMixin, + Mochi1LoraLoaderMixin, + SanaLoraLoaderMixin, + SD3LoraLoaderMixin, + StableDiffusionLoraLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + WanLoraLoaderMixin, + ) diff --git a/src/diffusers/loaders/lora/lora_base.py b/src/diffusers/loaders/lora/lora_base.py new file mode 100644 index 000000000000..80445106015b --- /dev/null +++ b/src/diffusers/loaders/lora/lora_base.py @@ -0,0 +1,935 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import inspect +import os +from pathlib import Path +from typing import Callable, Dict, List, Optional, Union + +import safetensors +import torch +import torch.nn as nn +from huggingface_hub import model_info +from huggingface_hub.constants import HF_HUB_OFFLINE + +from ...models.modeling_utils import ModelMixin, load_state_dict +from ...utils import ( + USE_PEFT_BACKEND, + _get_model_file, + convert_state_dict_to_diffusers, + convert_state_dict_to_peft, + delete_adapter_layers, + deprecate, + get_adapter_name, + get_peft_kwargs, + is_accelerate_available, + is_peft_available, + is_peft_version, + is_transformers_available, + is_transformers_version, + logging, + recurse_remove_peft_layers, + scale_lora_layers, + set_adapter_layers, + set_weights_and_activate_adapters, +) + + +if is_transformers_available(): + from transformers import PreTrainedModel + + from ...models.lora import text_encoder_attn_modules, text_encoder_mlp_modules + +if is_peft_available(): + from peft.tuners.tuners_utils import BaseTunerLayer + +if is_accelerate_available(): + from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module + +logger = logging.get_logger(__name__) + +LORA_WEIGHT_NAME = "pytorch_lora_weights.bin" +LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors" + + +def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None): + """ + Fuses LoRAs for the text encoder. + + Args: + text_encoder (`torch.nn.Module`): + The text encoder module to set the adapter layers for. 
If `None`, it will try to get the `text_encoder` + attribute. + lora_scale (`float`, defaults to 1.0): + Controls how much to influence the outputs with the LoRA parameters. + safe_fusing (`bool`, defaults to `False`): + Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. + adapter_names (`List[str]` or `str`): + The names of the adapters to use. + """ + merge_kwargs = {"safe_merge": safe_fusing} + + for module in text_encoder.modules(): + if isinstance(module, BaseTunerLayer): + if lora_scale != 1.0: + module.scale_layer(lora_scale) + + # For BC with previous PEFT versions, we need to check the signature + # of the `merge` method to see if it supports the `adapter_names` argument. + supported_merge_kwargs = list(inspect.signature(module.merge).parameters) + if "adapter_names" in supported_merge_kwargs: + merge_kwargs["adapter_names"] = adapter_names + elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None: + raise ValueError( + "The `adapter_names` argument is not supported with your PEFT version. " + "Please upgrade to the latest version of PEFT. `pip install -U peft`" + ) + + module.merge(**merge_kwargs) + + +def unfuse_text_encoder_lora(text_encoder): + """ + Unfuses LoRAs for the text encoder. + + Args: + text_encoder (`torch.nn.Module`): + The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder` + attribute. + """ + for module in text_encoder.modules(): + if isinstance(module, BaseTunerLayer): + module.unmerge() + + +def set_adapters_for_text_encoder( + adapter_names: Union[List[str], str], + text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821 + text_encoder_weights: Optional[Union[float, List[float], List[None]]] = None, +): + """ + Sets the adapter layers for the text encoder. + + Args: + adapter_names (`List[str]` or `str`): + The names of the adapters to use. + text_encoder (`torch.nn.Module`, *optional*): + The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder` + attribute. + text_encoder_weights (`List[float]`, *optional*): + The weights to use for the text encoder. If `None`, the weights are set to `1.0` for all the adapters. + """ + if text_encoder is None: + raise ValueError( + "The pipeline does not have a default `pipe.text_encoder` class. Please make sure to pass a `text_encoder` instead." + ) + + def process_weights(adapter_names, weights): + # Expand weights into a list, one entry per adapter + # e.g. for 2 adapters: 7 -> [7,7] ; [3, None] -> [3, None] + if not isinstance(weights, list): + weights = [weights] * len(adapter_names) + + if len(adapter_names) != len(weights): + raise ValueError( + f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}" + ) + + # Set None values to default of 1.0 + # e.g. [7,7] -> [7,7] ; [3, None] -> [3,1] + weights = [w if w is not None else 1.0 for w in weights] + + return weights + + adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names + text_encoder_weights = process_weights(adapter_names, text_encoder_weights) + set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights) + + +def disable_lora_for_text_encoder(text_encoder: Optional["PreTrainedModel"] = None): + """ + Disables the LoRA layers for the text encoder. + + Args: + text_encoder (`torch.nn.Module`, *optional*): + The text encoder module to disable the LoRA layers for. 
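The `process_weights` helper above broadcasts a scalar (or `None`) to one weight per adapter and maps `None` entries to the default `1.0`. Its rules, restated as a self-contained sketch (length check omitted):

```py
def process_weights(adapter_names, weights):
    if not isinstance(weights, list):
        weights = [weights] * len(adapter_names)
    return [w if w is not None else 1.0 for w in weights]

names = ["style", "detail"]
assert process_weights(names, None) == [1.0, 1.0]         # default everywhere
assert process_weights(names, 0.7) == [0.7, 0.7]          # scalar broadcast
assert process_weights(names, [0.8, None]) == [0.8, 1.0]  # None -> 1.0
```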
If `None`, it will try to get the `text_encoder` + attribute. + """ + if text_encoder is None: + raise ValueError("Text Encoder not found.") + set_adapter_layers(text_encoder, enabled=False) + + +def enable_lora_for_text_encoder(text_encoder: Optional["PreTrainedModel"] = None): + """ + Enables the LoRA layers for the text encoder. + + Args: + text_encoder (`torch.nn.Module`, *optional*): + The text encoder module to enable the LoRA layers for. If `None`, it will try to get the `text_encoder` + attribute. + """ + if text_encoder is None: + raise ValueError("Text Encoder not found.") + set_adapter_layers(text_encoder, enabled=True) + + +def _remove_text_encoder_monkey_patch(text_encoder): + recurse_remove_peft_layers(text_encoder) + if getattr(text_encoder, "peft_config", None) is not None: + del text_encoder.peft_config + text_encoder._hf_peft_config_loaded = None + + +def _fetch_state_dict( + pretrained_model_name_or_path_or_dict, + weight_name, + use_safetensors, + local_files_only, + cache_dir, + force_download, + proxies, + token, + revision, + subfolder, + user_agent, + allow_pickle, +): + model_file = None + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + # Let's first try to load .safetensors weights + if (use_safetensors and weight_name is None) or ( + weight_name is not None and weight_name.endswith(".safetensors") + ): + try: + # Here we're relaxing the loading check to enable more Inference API + # friendliness where sometimes, it's not at all possible to automatically + # determine `weight_name`. + if weight_name is None: + weight_name = _best_guess_weight_name( + pretrained_model_name_or_path_or_dict, + file_extension=".safetensors", + local_files_only=local_files_only, + ) + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name or LORA_WEIGHT_NAME_SAFE, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = safetensors.torch.load_file(model_file, device="cpu") + except (IOError, safetensors.SafetensorError) as e: + if not allow_pickle: + raise e + # try loading non-safetensors weights + model_file = None + pass + + if model_file is None: + if weight_name is None: + weight_name = _best_guess_weight_name( + pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only + ) + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name or LORA_WEIGHT_NAME, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = load_state_dict(model_file) + else: + state_dict = pretrained_model_name_or_path_or_dict + + return state_dict + + +def _best_guess_weight_name( + pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False +): + if local_files_only or HF_HUB_OFFLINE: + raise ValueError("When using the offline mode, you must specify a `weight_name`.") + + targeted_files = [] + + if os.path.isfile(pretrained_model_name_or_path_or_dict): + return + elif os.path.isdir(pretrained_model_name_or_path_or_dict): + targeted_files = [f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)] + else: + files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings + targeted_files 
= [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)] + if len(targeted_files) == 0: + return + + # "scheduler" does not correspond to a LoRA checkpoint. + # "optimizer" does not correspond to a LoRA checkpoint + # only top-level checkpoints are considered and not the other ones, hence "checkpoint". + unallowed_substrings = {"scheduler", "optimizer", "checkpoint"} + targeted_files = list( + filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files) + ) + + if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files): + targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files)) + elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files): + targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files)) + + if len(targeted_files) > 1: + raise ValueError( + f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}." + ) + weight_name = targeted_files[0] + return weight_name + + +def _load_lora_into_text_encoder( + state_dict, + network_alphas, + text_encoder, + prefix=None, + lora_scale=1.0, + text_encoder_name="text_encoder", + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, +): + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + peft_kwargs = {} + if low_cpu_mem_usage: + if not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + if not is_transformers_version(">", "4.45.2"): + # Note from sayakpaul: It's not in `transformers` stable yet. + # https://github.com/huggingface/transformers/pull/33725/ + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`." + ) + peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage + + from peft import LoraConfig + + # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), + # then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as + # their prefixes. + prefix = text_encoder_name if prefix is None else prefix + + # Safe prefix to check with. + if hotswap and any(text_encoder_name in key for key in state_dict.keys()): + raise ValueError("At the moment, hotswapping is not supported for text encoders, please pass `hotswap=False`.") + + # Load the layers corresponding to text encoder and make necessary adjustments. 
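The filtering in `_best_guess_weight_name` above discards training artifacts before settling on a single candidate. With an illustrative repo listing:

```py
files = [
    "pytorch_lora_weights.safetensors",
    "optimizer.safetensors",
    "checkpoint-500/pytorch_lora_weights.safetensors",
]
unallowed_substrings = {"scheduler", "optimizer", "checkpoint"}
targeted_files = [f for f in files if all(s not in f for s in unallowed_substrings)]

# Exactly one survivor; two or more would raise the ValueError shown above.
assert targeted_files == ["pytorch_lora_weights.safetensors"]
```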
+ if prefix is not None: + state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")} + + if len(state_dict) > 0: + logger.info(f"Loading {prefix}.") + rank = {} + state_dict = convert_state_dict_to_diffusers(state_dict) + + # convert state dict + state_dict = convert_state_dict_to_peft(state_dict) + + for name, _ in text_encoder_attn_modules(text_encoder): + for module in ("out_proj", "q_proj", "k_proj", "v_proj"): + rank_key = f"{name}.{module}.lora_B.weight" + if rank_key not in state_dict: + continue + rank[rank_key] = state_dict[rank_key].shape[1] + + for name, _ in text_encoder_mlp_modules(text_encoder): + for module in ("fc1", "fc2"): + rank_key = f"{name}.{module}.lora_B.weight" + if rank_key not in state_dict: + continue + rank[rank_key] = state_dict[rank_key].shape[1] + + if network_alphas is not None: + alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix] + network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys} + + lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False) + + if "use_dora" in lora_config_kwargs: + if lora_config_kwargs["use_dora"]: + if is_peft_version("<", "0.9.0"): + raise ValueError( + "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." + ) + else: + if is_peft_version("<", "0.9.0"): + lora_config_kwargs.pop("use_dora") + + if "lora_bias" in lora_config_kwargs: + if lora_config_kwargs["lora_bias"]: + if is_peft_version("<=", "0.13.2"): + raise ValueError( + "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`." + ) + else: + if is_peft_version("<=", "0.13.2"): + lora_config_kwargs.pop("lora_bias") + + lora_config = LoraConfig(**lora_config_kwargs) + + # adapter_name + if adapter_name is None: + adapter_name = get_adapter_name(text_encoder) + + is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline) + + # inject LoRA layers and load the state dict + # in transformers we automatically check whether the adapter name is already in use or not + text_encoder.load_adapter( + adapter_name=adapter_name, + adapter_state_dict=state_dict, + peft_config=lora_config, + **peft_kwargs, + ) + + # scale LoRA layers with `lora_scale` + scale_lora_layers(text_encoder, weight=lora_scale) + + text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype) + + # Offload back. + if is_model_cpu_offload: + _pipeline.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + _pipeline.enable_sequential_cpu_offload() + # Unsafe code /> + + if prefix is not None and not state_dict: + logger.warning( + f"No LoRA keys associated to {text_encoder.__class__.__name__} found with the {prefix=}. " + "This is safe to ignore if LoRA state dict didn't originally have any " + f"{text_encoder.__class__.__name__} related params. You can also try specifying `prefix=None` " + "to resolve the warning. 
Otherwise, open an issue if you think it's unexpected: " + "https://github.com/huggingface/diffusers/issues/new" + ) + + +def _func_optionally_disable_offloading(_pipeline): + is_model_cpu_offload = False + is_sequential_cpu_offload = False + + if _pipeline is not None and _pipeline.hf_device_map is None: + for _, component in _pipeline.components.items(): + if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"): + if not is_model_cpu_offload: + is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload) + if not is_sequential_cpu_offload: + is_sequential_cpu_offload = ( + isinstance(component._hf_hook, AlignDevicesHook) + or hasattr(component._hf_hook, "hooks") + and isinstance(component._hf_hook.hooks[0], AlignDevicesHook) + ) + + logger.info( + "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." + ) + remove_hook_from_module(component, recurse=is_sequential_cpu_offload) + + return (is_model_cpu_offload, is_sequential_cpu_offload) + + +class LoraBaseMixin: + """Utility class for handling LoRAs.""" + + _lora_loadable_modules = [] + num_fused_loras = 0 + + def load_lora_weights(self, **kwargs): + raise NotImplementedError("`load_lora_weights()` is not implemented.") + + @classmethod + def save_lora_weights(cls, **kwargs): + raise NotImplementedError("`save_lora_weights()` not implemented.") + + @classmethod + def lora_state_dict(cls, **kwargs): + raise NotImplementedError("`lora_state_dict()` is not implemented.") + + @classmethod + def _optionally_disable_offloading(cls, _pipeline): + """ + Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU. + + Args: + _pipeline (`DiffusionPipeline`): + The pipeline to disable offloading for. + + Returns: + tuple: + A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True. + """ + return _func_optionally_disable_offloading(_pipeline=_pipeline) + + @classmethod + def _fetch_state_dict(cls, *args, **kwargs): + deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`." + deprecate("_fetch_state_dict", "0.35.0", deprecation_message) + return _fetch_state_dict(*args, **kwargs) + + @classmethod + def _best_guess_weight_name(cls, *args, **kwargs): + deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`." + deprecate("_best_guess_weight_name", "0.35.0", deprecation_message) + return _best_guess_weight_name(*args, **kwargs) + + def unload_lora_weights(self): + """ + Unloads the LoRA parameters. + + Examples: + + ```python + >>> # Assuming `pipeline` is already loaded with the LoRA parameters. + >>> pipeline.unload_lora_weights() + >>> ... 
+ ``` + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + for component in self._lora_loadable_modules: + model = getattr(self, component, None) + if model is not None: + if issubclass(model.__class__, ModelMixin): + model.unload_lora() + elif issubclass(model.__class__, PreTrainedModel): + _remove_text_encoder_monkey_patch(model) + + def fuse_lora( + self, + components: List[str] = [], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + **kwargs, + ): + r""" + Fuses the LoRA parameters into the original parameters of the corresponding blocks. + + + + This is an experimental API. + + + + Args: + components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. + lora_scale (`float`, defaults to 1.0): + Controls how much to influence the outputs with the LoRA parameters. + safe_fusing (`bool`, defaults to `False`): + Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. + adapter_names (`List[str]`, *optional*): + Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. + + Example: + + ```py + from diffusers import DiffusionPipeline + import torch + + pipeline = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ).to("cuda") + pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") + pipeline.fuse_lora(lora_scale=0.7) + ``` + """ + if "fuse_unet" in kwargs: + depr_message = "Passing `fuse_unet` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_unet` will be removed in a future version." + deprecate( + "fuse_unet", + "1.0.0", + depr_message, + ) + if "fuse_transformer" in kwargs: + depr_message = "Passing `fuse_transformer` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_transformer` will be removed in a future version." + deprecate( + "fuse_transformer", + "1.0.0", + depr_message, + ) + if "fuse_text_encoder" in kwargs: + depr_message = "Passing `fuse_text_encoder` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_text_encoder` will be removed in a future version." + deprecate( + "fuse_text_encoder", + "1.0.0", + depr_message, + ) + + if len(components) == 0: + raise ValueError("`components` cannot be an empty list.") + + for fuse_component in components: + if fuse_component not in self._lora_loadable_modules: + raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.") + + model = getattr(self, fuse_component, None) + if model is not None: + # check if diffusers model + if issubclass(model.__class__, ModelMixin): + model.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names) + # handle transformers models. 
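+                # Note: transformers text encoders don't implement `fuse_lora()` themselves,
+                # so fusing is delegated to the PEFT-based `fuse_text_encoder_lora` helper below.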
+ if issubclass(model.__class__, PreTrainedModel): + fuse_text_encoder_lora( + model, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names + ) + + self.num_fused_loras += 1 + + def unfuse_lora(self, components: List[str] = [], **kwargs): + r""" + Reverses the effect of + [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). + + + + This is an experimental API. + + + + Args: + components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. + unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. + unfuse_text_encoder (`bool`, defaults to `True`): + Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the + LoRA parameters then it won't have any effect. + """ + if "unfuse_unet" in kwargs: + depr_message = "Passing `unfuse_unet` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_unet` will be removed in a future version." + deprecate( + "unfuse_unet", + "1.0.0", + depr_message, + ) + if "unfuse_transformer" in kwargs: + depr_message = "Passing `unfuse_transformer` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_transformer` will be removed in a future version." + deprecate( + "unfuse_transformer", + "1.0.0", + depr_message, + ) + if "unfuse_text_encoder" in kwargs: + depr_message = "Passing `unfuse_text_encoder` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_text_encoder` will be removed in a future version." + deprecate( + "unfuse_text_encoder", + "1.0.0", + depr_message, + ) + + if len(components) == 0: + raise ValueError("`components` cannot be an empty list.") + + for fuse_component in components: + if fuse_component not in self._lora_loadable_modules: + raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.") + + model = getattr(self, fuse_component, None) + if model is not None: + if issubclass(model.__class__, (ModelMixin, PreTrainedModel)): + for module in model.modules(): + if isinstance(module, BaseTunerLayer): + module.unmerge() + + self.num_fused_loras -= 1 + + def set_adapters( + self, + adapter_names: Union[List[str], str], + adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None, + ): + if isinstance(adapter_weights, dict): + components_passed = set(adapter_weights.keys()) + lora_components = set(self._lora_loadable_modules) + + invalid_components = sorted(components_passed - lora_components) + if invalid_components: + logger.warning( + f"The following components in `adapter_weights` are not part of the pipeline: {invalid_components}. " + f"Available components that are LoRA-compatible: {self._lora_loadable_modules}. So, weights belonging " + "to the invalid components will be removed and ignored." 
+                )
+                adapter_weights = {k: v for k, v in adapter_weights.items() if k not in invalid_components}
+
+        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
+        adapter_weights = copy.deepcopy(adapter_weights)
+
+        # Expand weights into a list, one entry per adapter
+        if not isinstance(adapter_weights, list):
+            adapter_weights = [adapter_weights] * len(adapter_names)
+
+        if len(adapter_names) != len(adapter_weights):
+            raise ValueError(
+                f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(adapter_weights)}"
+            )
+
+        list_adapters = self.get_list_adapters()  # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
+        # eg ["adapter1", "adapter2"]
+        all_adapters = {adapter for adapters in list_adapters.values() for adapter in adapters}
+        missing_adapters = set(adapter_names) - all_adapters
+        if len(missing_adapters) > 0:
+            raise ValueError(
+                f"Adapter name(s) {missing_adapters} not in the list of present adapters: {all_adapters}."
+            )
+
+        # eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}
+        invert_list_adapters = {
+            adapter: [part for part, adapters in list_adapters.items() if adapter in adapters]
+            for adapter in all_adapters
+        }
+
+        # Decompose weights into weights for denoiser and text encoders.
+        _component_adapter_weights = {}
+        for component in self._lora_loadable_modules:
+            model = getattr(self, component)
+
+            for adapter_name, weights in zip(adapter_names, adapter_weights):
+                if isinstance(weights, dict):
+                    component_adapter_weights = weights.pop(component, None)
+                    if component_adapter_weights is not None and component not in invert_list_adapters[adapter_name]:
+                        logger.warning(
+                            (
+                                f"LoRA weight dict for adapter '{adapter_name}' contains {component}, "
+                                f"but this will be ignored because {adapter_name} does not contain weights for {component}. "
+                                f"Valid parts for {adapter_name} are: {invert_list_adapters[adapter_name]}."
+                            )
+                        )
+
+                else:
+                    component_adapter_weights = weights
+
+                _component_adapter_weights.setdefault(component, [])
+                _component_adapter_weights[component].append(component_adapter_weights)
+
+            if issubclass(model.__class__, ModelMixin):
+                model.set_adapters(adapter_names, _component_adapter_weights[component])
+            elif issubclass(model.__class__, PreTrainedModel):
+                set_adapters_for_text_encoder(adapter_names, model, _component_adapter_weights[component])
+
+    def disable_lora(self):
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        for component in self._lora_loadable_modules:
+            model = getattr(self, component, None)
+            if model is not None:
+                if issubclass(model.__class__, ModelMixin):
+                    model.disable_lora()
+                elif issubclass(model.__class__, PreTrainedModel):
+                    disable_lora_for_text_encoder(model)
+
+    def enable_lora(self):
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        for component in self._lora_loadable_modules:
+            model = getattr(self, component, None)
+            if model is not None:
+                if issubclass(model.__class__, ModelMixin):
+                    model.enable_lora()
+                elif issubclass(model.__class__, PreTrainedModel):
+                    enable_lora_for_text_encoder(model)
+
+    def delete_adapters(self, adapter_names: Union[List[str], str]):
+        """
+        Deletes the LoRA layers of `adapter_names` for the UNet and text encoder(s).
+
+        Args:
+            adapter_names (`Union[List[str], str]`):
+                The names of the adapters to delete.
Can be a single string or a list of strings + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + if isinstance(adapter_names, str): + adapter_names = [adapter_names] + + for component in self._lora_loadable_modules: + model = getattr(self, component, None) + if model is not None: + if issubclass(model.__class__, ModelMixin): + model.delete_adapters(adapter_names) + elif issubclass(model.__class__, PreTrainedModel): + for adapter_name in adapter_names: + delete_adapter_layers(model, adapter_name) + + def get_active_adapters(self) -> List[str]: + """ + Gets the list of the current active adapters. + + Example: + + ```python + from diffusers import DiffusionPipeline + + pipeline = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + ).to("cuda") + pipeline.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy") + pipeline.get_active_adapters() + ``` + """ + if not USE_PEFT_BACKEND: + raise ValueError( + "PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`" + ) + + active_adapters = [] + + for component in self._lora_loadable_modules: + model = getattr(self, component, None) + if model is not None and issubclass(model.__class__, ModelMixin): + for module in model.modules(): + if isinstance(module, BaseTunerLayer): + active_adapters = module.active_adapters + break + + return active_adapters + + def get_list_adapters(self) -> Dict[str, List[str]]: + """ + Gets the current list of all available adapters in the pipeline. + """ + if not USE_PEFT_BACKEND: + raise ValueError( + "PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`" + ) + + set_adapters = {} + + for component in self._lora_loadable_modules: + model = getattr(self, component, None) + if ( + model is not None + and issubclass(model.__class__, (ModelMixin, PreTrainedModel)) + and hasattr(model, "peft_config") + ): + set_adapters[component] = list(model.peft_config.keys()) + + return set_adapters + + def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, str, int]) -> None: + """ + Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case + you want to load multiple adapters and free some GPU memory. + + Args: + adapter_names (`List[str]`): + List of adapters to send device to. + device (`Union[torch.device, str, int]`): + Device to send the adapters to. Can be either a torch device, a str or an integer. 
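+
+        Example (a minimal sketch; assumes an adapter named "pixel" was previously loaded
+        via `load_lora_weights(..., adapter_name="pixel")`):
+
+        ```python
+        # park the adapter on CPU to free GPU memory ...
+        pipeline.set_lora_device(adapter_names=["pixel"], device="cpu")
+        # ... and move it back when it is needed again
+        pipeline.set_lora_device(adapter_names=["pixel"], device="cuda")
+        ```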
+ """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + for component in self._lora_loadable_modules: + model = getattr(self, component, None) + if model is not None: + for module in model.modules(): + if isinstance(module, BaseTunerLayer): + for adapter_name in adapter_names: + module.lora_A[adapter_name].to(device) + module.lora_B[adapter_name].to(device) + # this is a param, not a module, so device placement is not in-place -> re-assign + if hasattr(module, "lora_magnitude_vector") and module.lora_magnitude_vector is not None: + if adapter_name in module.lora_magnitude_vector: + module.lora_magnitude_vector[adapter_name] = module.lora_magnitude_vector[ + adapter_name + ].to(device) + + @staticmethod + def pack_weights(layers, prefix): + layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers + layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()} + return layers_state_dict + + @staticmethod + def write_lora_layers( + state_dict: Dict[str, torch.Tensor], + save_directory: str, + is_main_process: bool, + weight_name: str, + save_function: Callable, + safe_serialization: bool, + ): + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + if save_function is None: + if safe_serialization: + + def save_function(weights, filename): + return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"}) + + else: + save_function = torch.save + + os.makedirs(save_directory, exist_ok=True) + + if weight_name is None: + if safe_serialization: + weight_name = LORA_WEIGHT_NAME_SAFE + else: + weight_name = LORA_WEIGHT_NAME + + save_path = Path(save_directory, weight_name).as_posix() + save_function(state_dict, save_path) + logger.info(f"Model weights saved in {save_path}") + + @property + def lora_scale(self) -> float: + # property function that returns the lora scale which can be set at run time by the pipeline. + # if _lora_scale has not been set, return 1 + return self._lora_scale if hasattr(self, "_lora_scale") else 1.0 + + def enable_lora_hotswap(self, **kwargs) -> None: + """Enables the possibility to hotswap LoRA adapters. + + Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of + the loaded adapters differ. + + Args: + target_rank (`int`): + The highest rank among all the adapters that will be loaded. + check_compiled (`str`, *optional*, defaults to `"error"`): + How to handle the case when the model is already compiled, which should generally be avoided. 
The + options are: + - "error" (default): raise an error + - "warn": issue a warning + - "ignore": do nothing + """ + for key, component in self.components.items(): + if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules): + component.enable_lora_hotswap(**kwargs) diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora/lora_conversion_utils.py similarity index 99% rename from src/diffusers/loaders/lora_conversion_utils.py rename to src/diffusers/loaders/lora/lora_conversion_utils.py index d0c9611735ce..3c5afdcf6523 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora/lora_conversion_utils.py @@ -17,7 +17,7 @@ import torch -from ..utils import is_peft_version, logging, state_dict_all_zero +from ...utils import is_peft_version, logging, state_dict_all_zero logger = logging.get_logger(__name__) diff --git a/src/diffusers/loaders/lora/lora_pipeline.py b/src/diffusers/loaders/lora/lora_pipeline.py new file mode 100644 index 000000000000..27705041408d --- /dev/null +++ b/src/diffusers/loaders/lora/lora_pipeline.py @@ -0,0 +1,5686 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Callable, Dict, List, Optional, Union + +import torch +from huggingface_hub.utils import validate_hf_hub_args + +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + get_submodule_by_name, + is_bitsandbytes_available, + is_gguf_available, + is_peft_available, + is_peft_version, + is_torch_version, + is_transformers_available, + is_transformers_version, + logging, +) +from .lora_base import ( # noqa + LORA_WEIGHT_NAME, + LORA_WEIGHT_NAME_SAFE, + LoraBaseMixin, + _fetch_state_dict, + _load_lora_into_text_encoder, +) +from .lora_conversion_utils import ( + _convert_bfl_flux_control_lora_to_diffusers, + _convert_hunyuan_video_lora_to_diffusers, + _convert_kohya_flux_lora_to_diffusers, + _convert_musubi_wan_lora_to_diffusers, + _convert_non_diffusers_lora_to_diffusers, + _convert_non_diffusers_lumina2_lora_to_diffusers, + _convert_non_diffusers_wan_lora_to_diffusers, + _convert_xlabs_flux_lora_to_diffusers, + _maybe_map_sgm_blocks_to_diffusers, +) + + +_LOW_CPU_MEM_USAGE_DEFAULT_LORA = False +if is_torch_version(">=", "1.9.0"): + if ( + is_peft_available() + and is_peft_version(">=", "0.13.1") + and is_transformers_available() + and is_transformers_version(">", "4.45.2") + ): + _LOW_CPU_MEM_USAGE_DEFAULT_LORA = True + + +logger = logging.get_logger(__name__) + +TEXT_ENCODER_NAME = "text_encoder" +UNET_NAME = "unet" +TRANSFORMER_NAME = "transformer" + +_MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX = {"x_embedder": "in_channels"} + + +def _maybe_dequantize_weight_for_expanded_lora(model, module): + if is_bitsandbytes_available(): + from ...quantizers.bitsandbytes import dequantize_bnb_weight + + if is_gguf_available(): + from ...quantizers.gguf.utils import dequantize_gguf_tensor + + is_bnb_4bit_quantized = module.weight.__class__.__name__ == 
"Params4bit" + is_gguf_quantized = module.weight.__class__.__name__ == "GGUFParameter" + + if is_bnb_4bit_quantized and not is_bitsandbytes_available(): + raise ValueError( + "The checkpoint seems to have been quantized with `bitsandbytes` (4bits). Install `bitsandbytes` to load quantized checkpoints." + ) + if is_gguf_quantized and not is_gguf_available(): + raise ValueError( + "The checkpoint seems to have been quantized with `gguf`. Install `gguf` to load quantized checkpoints." + ) + + weight_on_cpu = False + if not module.weight.is_cuda: + weight_on_cpu = True + + if is_bnb_4bit_quantized: + module_weight = dequantize_bnb_weight( + module.weight.cuda() if weight_on_cpu else module.weight, + state=module.weight.quant_state, + dtype=model.dtype, + ).data + elif is_gguf_quantized: + module_weight = dequantize_gguf_tensor( + module.weight.cuda() if weight_on_cpu else module.weight, + ) + module_weight = module_weight.to(model.dtype) + else: + module_weight = module.weight.data + + if weight_on_cpu: + module_weight = module_weight.cpu() + + return module_weight + + +class StableDiffusionLoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into Stable Diffusion [`UNet2DConditionModel`] and + [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel). + """ + + _lora_loadable_modules = ["unet", "text_encoder"] + unet_name = UNET_NAME + text_encoder_name = TEXT_ENCODER_NAME + + def load_lora_weights( + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name: Optional[str] = None, + hotswap: bool = False, + **kwargs, + ): + """Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and + `self.text_encoder`. + + All kwargs are forwarded to `self.lora_state_dict`. + + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is + loaded. + + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is + loaded into `self.unet`. + + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state + dict is loaded into `self.text_encoder`. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + hotswap (`bool`, *optional*): + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. 
scaling), you need + to call an additional method before loading the adapter: + + ```py + pipeline = ... # load diffusers pipeline + max_rank = ... # the highest rank among all LoRAs that you want to load + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now + ``` + + Note that hotswapping adapters of the text encoder is not yet supported. There are some further + limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + + is_correct_format = all("lora" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint.") + + self.load_lora_into_unet( + state_dict, + network_alphas=network_alphas, + unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + self.load_lora_into_text_encoder( + state_dict, + network_alphas=network_alphas, + text_encoder=getattr(self, self.text_encoder_name) + if not hasattr(self, "text_encoder") + else self.text_encoder, + lora_scale=self.lora_scale, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + @validate_hf_hub_args + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + **kwargs, + ): + r""" + Return state dict for lora weights and the network alphas. + + + + We support loading A1111 formatted LoRA checkpoints in a limited capacity. + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. 
+ + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + weight_name (`str`, *optional*, defaults to None): + Name of the serialized state dict file. + """ + # Load the main state dict first which has the LoRA layers for either of + # UNet and text encoder or both. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + unet_config = kwargs.pop("unet_config", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + state_dict = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." + logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + + network_alphas = None + # TODO: replace it with a method from `state_dict_utils` + if all( + ( + k.startswith("lora_te_") + or k.startswith("lora_unet_") + or k.startswith("lora_te1_") + or k.startswith("lora_te2_") + ) + for k in state_dict.keys() + ): + # Map SDXL blocks correctly. 
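+            # Note: kohya-ss/SGM-style checkpoints name blocks "input_blocks"/"middle_block"/
+            # "output_blocks"; the unet config is needed to translate those indices into the
+            # diffusers down/mid/up block layout.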
+ if unet_config is not None: + # use unet config to remap block numbers + state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config) + state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict) + + return state_dict, network_alphas + + @classmethod + def load_lora_into_unet( + cls, + state_dict, + network_alphas, + unet, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + ): + """ + This will load the LoRA layers specified in `state_dict` into `unet`. + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The keys can either be indexed directly + into the unet or prefixed with an additional `unet` which can be used to distinguish between text + encoder lora layers. + network_alphas (`Dict[str, float]`): + The value of the network alpha used for stable learning and preventing underflow. This value has the + same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this + link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). + unet (`UNet2DConditionModel`): + The UNet model to load the LoRA layers into. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading only loading the pretrained LoRA weights and not initializing the random + weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), + # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as + # their prefixes. + logger.info(f"Loading {cls.unet_name}.") + unet.load_lora_adapter( + state_dict, + prefix=cls.unet_name, + network_alphas=network_alphas, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + def load_lora_into_text_encoder( + cls, + state_dict, + network_alphas, + text_encoder, + prefix=None, + lora_scale=1.0, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + ): + """ + This will load the LoRA layers specified in `state_dict` into `text_encoder` + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The key should be prefixed with an + additional `text_encoder` to distinguish between unet lora layers. + network_alphas (`Dict[str, float]`): + The value of the network alpha used for stable learning and preventing underflow. This value has the + same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this + link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). + text_encoder (`CLIPTextModel`): + The text encoder model to load the LoRA layers into. + prefix (`str`): + Expected prefix of the `text_encoder` in the `state_dict`. 
+ lora_scale (`float`): + How much to scale the output of the lora linear layer before it is added with the output of the regular + lora layer. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + """ + _load_lora_into_text_encoder( + state_dict=state_dict, + network_alphas=network_alphas, + lora_scale=lora_scale, + text_encoder=text_encoder, + prefix=prefix, + text_encoder_name=cls.text_encoder_name, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + text_encoder_lora_layers: Dict[str, torch.nn.Module] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + ): + r""" + Save the LoRA parameters corresponding to the UNet and text encoder. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. Will be created if it doesn't exist. + unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `unet`. + text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text + encoder LoRA state dict because it comes from 🤗 Transformers. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + """ + state_dict = {} + + if not (unet_lora_layers or text_encoder_lora_layers): + raise ValueError("You must pass at least one of `unet_lora_layers` and `text_encoder_lora_layers`.") + + if unet_lora_layers: + state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name)) + + if text_encoder_lora_layers: + state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) + + # Save the model + cls.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + def fuse_lora( + self, + components: List[str] = ["unet", "text_encoder"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + **kwargs, + ): + r""" + Fuses the LoRA parameters into the original parameters of the corresponding blocks. 
+ + + + This is an experimental API. + + + + Args: + components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. + lora_scale (`float`, defaults to 1.0): + Controls how much to influence the outputs with the LoRA parameters. + safe_fusing (`bool`, defaults to `False`): + Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. + adapter_names (`List[str]`, *optional*): + Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. + + Example: + + ```py + from diffusers import DiffusionPipeline + import torch + + pipeline = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ).to("cuda") + pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") + pipeline.fuse_lora(lora_scale=0.7) + ``` + """ + super().fuse_lora( + components=components, + lora_scale=lora_scale, + safe_fusing=safe_fusing, + adapter_names=adapter_names, + **kwargs, + ) + + def unfuse_lora(self, components: List[str] = ["unet", "text_encoder"], **kwargs): + r""" + Reverses the effect of + [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). + + + + This is an experimental API. + + + + Args: + components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. + unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. + unfuse_text_encoder (`bool`, defaults to `True`): + Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the + LoRA parameters then it won't have any effect. + """ + super().unfuse_lora(components=components, **kwargs) + + +class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into Stable Diffusion XL [`UNet2DConditionModel`], + [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), and + [`CLIPTextModelWithProjection`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection). + """ + + _lora_loadable_modules = ["unet", "text_encoder", "text_encoder_2"] + unet_name = UNET_NAME + text_encoder_name = TEXT_ENCODER_NAME + + def load_lora_weights( + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name: Optional[str] = None, + hotswap: bool = False, + **kwargs, + ): + """ + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and + `self.text_encoder`. + + All kwargs are forwarded to `self.lora_state_dict`. + + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is + loaded. + + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is + loaded into `self.unet`. + + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state + dict is loaded into `self.text_encoder`. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. 
+ low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # We could have accessed the unet config from `lora_state_dict()` too. We pass + # it here explicitly to be able to tell that it's coming from an SDXL + # pipeline. + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + state_dict, network_alphas = self.lora_state_dict( + pretrained_model_name_or_path_or_dict, + unet_config=self.unet.config, + **kwargs, + ) + + is_correct_format = all("lora" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint.") + + self.load_lora_into_unet( + state_dict, + network_alphas=network_alphas, + unet=self.unet, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + self.load_lora_into_text_encoder( + state_dict, + network_alphas=network_alphas, + text_encoder=self.text_encoder, + prefix=self.text_encoder_name, + lora_scale=self.lora_scale, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + self.load_lora_into_text_encoder( + state_dict, + network_alphas=network_alphas, + text_encoder=self.text_encoder_2, + prefix=f"{self.text_encoder_name}_2", + lora_scale=self.lora_scale, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + @validate_hf_hub_args + # Copied from diffusers.loaders.lora.lora_pipeline.StableDiffusionLoraLoaderMixin.lora_state_dict + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + **kwargs, + ): + r""" + Return state dict for lora weights and the network alphas. + + + + We support loading A1111 formatted LoRA checkpoints in a limited capacity. + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. 
+ force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + weight_name (`str`, *optional*, defaults to None): + Name of the serialized state dict file. + """ + # Load the main state dict first which has the LoRA layers for either of + # UNet and text encoder or both. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + unet_config = kwargs.pop("unet_config", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + state_dict = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." + logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + + network_alphas = None + # TODO: replace it with a method from `state_dict_utils` + if all( + ( + k.startswith("lora_te_") + or k.startswith("lora_unet_") + or k.startswith("lora_te1_") + or k.startswith("lora_te2_") + ) + for k in state_dict.keys() + ): + # Map SDXL blocks correctly. 
+ if unet_config is not None: + # use unet config to remap block numbers + state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config) + state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict) + + return state_dict, network_alphas + + @classmethod + # Copied from diffusers.loaders.lora.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_unet + def load_lora_into_unet( + cls, + state_dict, + network_alphas, + unet, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + ): + """ + This will load the LoRA layers specified in `state_dict` into `unet`. + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The keys can either be indexed directly + into the unet or prefixed with an additional `unet` which can be used to distinguish between text + encoder lora layers. + network_alphas (`Dict[str, float]`): + The value of the network alpha used for stable learning and preventing underflow. This value has the + same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this + link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). + unet (`UNet2DConditionModel`): + The UNet model to load the LoRA layers into. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading only loading the pretrained LoRA weights and not initializing the random + weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), + # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as + # their prefixes. + logger.info(f"Loading {cls.unet_name}.") + unet.load_lora_adapter( + state_dict, + prefix=cls.unet_name, + network_alphas=network_alphas, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + # Copied from diffusers.loaders.lora.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder + def load_lora_into_text_encoder( + cls, + state_dict, + network_alphas, + text_encoder, + prefix=None, + lora_scale=1.0, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + ): + """ + This will load the LoRA layers specified in `state_dict` into `text_encoder` + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The key should be prefixed with an + additional `text_encoder` to distinguish between unet lora layers. + network_alphas (`Dict[str, float]`): + The value of the network alpha used for stable learning and preventing underflow. This value has the + same meaning as the `--network_alpha` option in the kohya-ss trainer script. 
Refer to [this + link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). + text_encoder (`CLIPTextModel`): + The text encoder model to load the LoRA layers into. + prefix (`str`): + Expected prefix of the `text_encoder` in the `state_dict`. + lora_scale (`float`): + How much to scale the output of the lora linear layer before it is added with the output of the regular + lora layer. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + """ + _load_lora_into_text_encoder( + state_dict=state_dict, + network_alphas=network_alphas, + lora_scale=lora_scale, + text_encoder=text_encoder, + prefix=prefix, + text_encoder_name=cls.text_encoder_name, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + ): + r""" + Save the LoRA parameters corresponding to the UNet and text encoder. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. Will be created if it doesn't exist. + unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `unet`. + text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text + encoder LoRA state dict because it comes from 🤗 Transformers. + text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text + encoder LoRA state dict because it comes from 🤗 Transformers. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + """ + state_dict = {} + + if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers): + raise ValueError( + "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`." 
+ ) + + if unet_lora_layers: + state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name)) + + if text_encoder_lora_layers: + state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder")) + + if text_encoder_2_lora_layers: + state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) + + cls.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + def fuse_lora( + self, + components: List[str] = ["unet", "text_encoder", "text_encoder_2"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + **kwargs, + ): + r""" + Fuses the LoRA parameters into the original parameters of the corresponding blocks. + + + + This is an experimental API. + + + + Args: + components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. + lora_scale (`float`, defaults to 1.0): + Controls how much to influence the outputs with the LoRA parameters. + safe_fusing (`bool`, defaults to `False`): + Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. + adapter_names (`List[str]`, *optional*): + Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. + + Example: + + ```py + from diffusers import DiffusionPipeline + import torch + + pipeline = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ).to("cuda") + pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") + pipeline.fuse_lora(lora_scale=0.7) + ``` + """ + super().fuse_lora( + components=components, + lora_scale=lora_scale, + safe_fusing=safe_fusing, + adapter_names=adapter_names, + **kwargs, + ) + + def unfuse_lora(self, components: List[str] = ["unet", "text_encoder", "text_encoder_2"], **kwargs): + r""" + Reverses the effect of + [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). + + + + This is an experimental API. + + + + Args: + components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. + unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. + unfuse_text_encoder (`bool`, defaults to `True`): + Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the + LoRA parameters then it won't have any effect. + """ + super().unfuse_lora(components=components, **kwargs) + + +class SD3LoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`SD3Transformer2DModel`], + [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), and + [`CLIPTextModelWithProjection`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection). + + Specific to [`StableDiffusion3Pipeline`]. + """ + + _lora_loadable_modules = ["transformer", "text_encoder", "text_encoder_2"] + transformer_name = TRANSFORMER_NAME + text_encoder_name = TEXT_ENCODER_NAME + + @classmethod + @validate_hf_hub_args + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + **kwargs, + ): + r""" + Return state dict for lora weights and the network alphas. + + + + We support loading A1111 formatted LoRA checkpoints in a limited capacity. 
+ + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + + """ + # Load the main state dict first which has the LoRA layers for either of + # transformer and text encoder or both. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + state_dict = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." 
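+            # Note: `dora_scale` entries hold DoRA's learned magnitude vectors; filtering them
+            # out effectively loads the checkpoint as a plain LoRA, hence the warning below.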
+        logger.warning(warn_msg)
+        state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+        return state_dict
+
+    def load_lora_weights(
+        self,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        adapter_name=None,
+        hotswap: bool = False,
+        **kwargs,
+    ):
+        """
+        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
+        `self.text_encoder`.
+
+        All kwargs are forwarded to `self.lora_state_dict`.
+
+        See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
+        loaded.
+
+        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+        dict is loaded into `self.transformer`.
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+            hotswap (`bool`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+            kwargs (`dict`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
+        # if a dict is passed, copy it instead of modifying it inplace
+        if isinstance(pretrained_model_name_or_path_or_dict, dict):
+            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+        is_correct_format = all("lora" in key for key in state_dict.keys())
+        if not is_correct_format:
+            raise ValueError("Invalid LoRA checkpoint.")
+
+        self.load_lora_into_transformer(
+            state_dict,
+            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+            adapter_name=adapter_name,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
+        )
+        self.load_lora_into_text_encoder(
+            state_dict,
+            network_alphas=None,
+            text_encoder=self.text_encoder,
+            prefix=self.text_encoder_name,
+            lora_scale=self.lora_scale,
+            adapter_name=adapter_name,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
+        )
+        self.load_lora_into_text_encoder(
+            state_dict,
+            network_alphas=None,
+            text_encoder=self.text_encoder_2,
+            prefix=f"{self.text_encoder_name}_2",
+            lora_scale=self.lora_scale,
+            adapter_name=adapter_name,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
+        )
+
+    @classmethod
+    def load_lora_into_transformer(
+        cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
+    ):
+        """
+        This will load the LoRA layers specified in `state_dict` into `transformer`.
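For illustration, a minimal sketch of calling this classmethod directly (the LoRA repo id below is hypothetical; going through `load_lora_weights` is the usual path):

```py
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers")
# Fetch the checkpoint as a state dict first, then route it into the transformer only.
state_dict = pipe.lora_state_dict("some-user/sd3-lora")  # hypothetical repo id
pipe.load_lora_into_transformer(state_dict, transformer=pipe.transformer, adapter_name="default_0")
```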
+
+        Parameters:
+            state_dict (`dict`):
+                A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+                into the transformer or prefixed with an additional `transformer`, which can be used to distinguish
+                them from the text encoder lora layers.
+            transformer (`SD3Transformer2DModel`):
+                The Transformer model to load the LoRA layers into.
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+            hotswap (`bool`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+        """
+        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
+        # Load the layers corresponding to transformer.
+        logger.info(f"Loading {cls.transformer_name}.")
+        transformer.load_lora_adapter(
+            state_dict,
+            network_alphas=None,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
+        )
+
+    @classmethod
+    # Copied from diffusers.loaders.lora.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
+    def load_lora_into_text_encoder(
+        cls,
+        state_dict,
+        network_alphas,
+        text_encoder,
+        prefix=None,
+        lora_scale=1.0,
+        adapter_name=None,
+        _pipeline=None,
+        low_cpu_mem_usage=False,
+        hotswap: bool = False,
+    ):
+        """
+        This will load the LoRA layers specified in `state_dict` into `text_encoder`.
+
+        Parameters:
+            state_dict (`dict`):
+                A standard state dict containing the lora layer parameters. The key should be prefixed with an
+                additional `text_encoder` to distinguish between unet lora layers.
+            network_alphas (`Dict[str, float]`):
+                The value of the network alpha used for stable learning and preventing underflow. This value has the
+                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
+            text_encoder (`CLIPTextModel`):
+                The text encoder model to load the LoRA layers into.
+            prefix (`str`):
+                Expected prefix of the `text_encoder` in the `state_dict`.
+            lora_scale (`float`):
+                How much to scale the output of the lora linear layer before it is added with the output of the regular
+                lora layer.
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+            hotswap (`bool`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
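Example (a sketch; the LoRA repo id is hypothetical, and `network_alphas=None` mirrors the SD3 loading path above):

```py
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers")
state_dict = pipe.lora_state_dict("some-user/sd3-lora")  # hypothetical repo id
# Route only the keys prefixed with "text_encoder." into the first text encoder.
pipe.load_lora_into_text_encoder(
    state_dict, network_alphas=None, text_encoder=pipe.text_encoder, prefix="text_encoder"
)
```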
+ """ + _load_lora_into_text_encoder( + state_dict=state_dict, + network_alphas=network_alphas, + lora_scale=lora_scale, + text_encoder=text_encoder, + prefix=prefix, + text_encoder_name=cls.text_encoder_name, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + # Copied from diffusers.loaders.lora.lora_pipeline.StableDiffusionXLLoraLoaderMixin.save_lora_weights with unet->transformer + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + ): + r""" + Save the LoRA parameters corresponding to the UNet and text encoder. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. Will be created if it doesn't exist. + transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `transformer`. + text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text + encoder LoRA state dict because it comes from 🤗 Transformers. + text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text + encoder LoRA state dict because it comes from 🤗 Transformers. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + """ + state_dict = {} + + if not (transformer_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers): + raise ValueError( + "You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`." 
+            )
+
+        if transformer_lora_layers:
+            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+        if text_encoder_lora_layers:
+            state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
+
+        if text_encoder_2_lora_layers:
+            state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
+
+        cls.write_lora_layers(
+            state_dict=state_dict,
+            save_directory=save_directory,
+            is_main_process=is_main_process,
+            weight_name=weight_name,
+            save_function=save_function,
+            safe_serialization=safe_serialization,
+        )
+
+    # Copied from diffusers.loaders.lora.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer
+    def fuse_lora(
+        self,
+        components: List[str] = ["transformer", "text_encoder", "text_encoder_2"],
+        lora_scale: float = 1.0,
+        safe_fusing: bool = False,
+        adapter_names: Optional[List[str]] = None,
+        **kwargs,
+    ):
+        r"""
+        Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+        This is an experimental API.
+
+        Args:
+            components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+            lora_scale (`float`, defaults to 1.0):
+                Controls how much to influence the outputs with the LoRA parameters.
+            safe_fusing (`bool`, defaults to `False`):
+                Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+            adapter_names (`List[str]`, *optional*):
+                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+        Example:
+
+        ```py
+        from diffusers import DiffusionPipeline
+        import torch
+
+        pipeline = DiffusionPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+        pipeline.fuse_lora(lora_scale=0.7)
+        ```
+        """
+        super().fuse_lora(
+            components=components,
+            lora_scale=lora_scale,
+            safe_fusing=safe_fusing,
+            adapter_names=adapter_names,
+            **kwargs,
+        )
+
+    # Copied from diffusers.loaders.lora.lora_pipeline.StableDiffusionXLLoraLoaderMixin.unfuse_lora with unet->transformer
+    def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], **kwargs):
+        r"""
+        Reverses the effect of
+        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+        This is an experimental API.
+
+        Args:
+            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+            unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
+            unfuse_text_encoder (`bool`, defaults to `True`):
+                Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
+                LoRA parameters then it won't have any effect.
+        """
+        super().unfuse_lora(components=components, **kwargs)
+
+
+class AuraFlowLoraLoaderMixin(LoraBaseMixin):
+    r"""
+    Load LoRA layers into [`AuraFlowTransformer2DModel`]. Specific to [`AuraFlowPipeline`].
+    """
+
+    _lora_loadable_modules = ["transformer"]
+    transformer_name = TRANSFORMER_NAME
+
+    @classmethod
+    @validate_hf_hub_args
+    # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
+    def lora_state_dict(
+        cls,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        **kwargs,
+    ):
+        r"""
+        Return state dict for lora weights and the network alphas.
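For context, a sketch of typical usage through `AuraFlowPipeline` (the base model id is assumed and the LoRA repo id is hypothetical):

```py
import torch
from diffusers import AuraFlowPipeline

pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")
# Inspect the checkpoint without loading it...
state_dict = pipe.lora_state_dict("some-user/auraflow-lora")  # hypothetical repo id
# ...or load it directly; this mixin only injects LoRA into the transformer.
pipe.load_lora_weights("some-user/auraflow-lora", adapter_name="style")
```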
+ + + + We support loading A1111 formatted LoRA checkpoints in a limited capacity. + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + + """ + # Load the main state dict first which has the LoRA layers for either of + # transformer and text encoder or both. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + state_dict = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." 
+ logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + + return state_dict + + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights + def load_lora_weights( + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name: Optional[str] = None, + hotswap: bool = False, + **kwargs, + ): + """ + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and + `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See + [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state + dict is loaded into `self.transformer`. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + + is_correct_format = all("lora" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint.") + + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + # Copied from diffusers.loaders.lora.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->AuraFlowTransformer2DModel + def load_lora_into_transformer( + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False + ): + """ + This will load the LoRA layers specified in `state_dict` into `transformer`. + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The keys can either be indexed directly + into the unet or prefixed with an additional `unet` which can be used to distinguish between text + encoder lora layers. + transformer (`AuraFlowTransformer2DModel`): + The Transformer model to load the LoRA layers into. 
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+            hotswap (`bool`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+        """
+        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
+        # Load the layers corresponding to transformer.
+        logger.info(f"Loading {cls.transformer_name}.")
+        transformer.load_lora_adapter(
+            state_dict,
+            network_alphas=None,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
+        )
+
+    @classmethod
+    # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+    def save_lora_weights(
+        cls,
+        save_directory: Union[str, os.PathLike],
+        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        is_main_process: bool = True,
+        weight_name: str = None,
+        save_function: Callable = None,
+        safe_serialization: bool = True,
+    ):
+        r"""
+        Save the LoRA parameters corresponding to the transformer.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to save LoRA parameters to. Will be created if it doesn't exist.
+            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+                State dict of the LoRA layers corresponding to the `transformer`.
+            is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful during distributed training when
+                you need to call this function on all processes. In this case, set `is_main_process=True` only on the
+                main process to avoid race conditions.
+            save_function (`Callable`):
+                The function to use to save the state dictionary. Useful during distributed training when you need to
+                replace `torch.save` with another method. Can be configured with the environment variable
+                `DIFFUSERS_SAVE_MODE`.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+        """
+        state_dict = {}
+
+        if not transformer_lora_layers:
+            raise ValueError("You must pass `transformer_lora_layers`.")
+
+        if transformer_lora_layers:
+            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+        # Save the model
+        cls.write_lora_layers(
+            state_dict=state_dict,
+            save_directory=save_directory,
+            is_main_process=is_main_process,
+            weight_name=weight_name,
+            save_function=save_function,
+            safe_serialization=safe_serialization,
+        )
+
+    # Copied from diffusers.loaders.lora.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
+    def fuse_lora(
+        self,
+        components: List[str] = ["transformer"],
+        lora_scale: float = 1.0,
+        safe_fusing: bool = False,
+        adapter_names: Optional[List[str]] = None,
+        **kwargs,
+    ):
+        r"""
+        Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+        This is an experimental API.
+
+        Args:
+            components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+            lora_scale (`float`, defaults to 1.0):
+                Controls how much to influence the outputs with the LoRA parameters.
+            safe_fusing (`bool`, defaults to `False`):
+                Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+            adapter_names (`List[str]`, *optional*):
+                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+        Example:
+
+        ```py
+        from diffusers import DiffusionPipeline
+        import torch
+
+        pipeline = DiffusionPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+        pipeline.fuse_lora(lora_scale=0.7)
+        ```
+        """
+        super().fuse_lora(
+            components=components,
+            lora_scale=lora_scale,
+            safe_fusing=safe_fusing,
+            adapter_names=adapter_names,
+            **kwargs,
+        )
+
+    # Copied from diffusers.loaders.lora.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
+    def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
+        r"""
+        Reverses the effect of
+        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+        This is an experimental API.
+
+        Args:
+            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+            unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
+        """
+        super().unfuse_lora(components=components, **kwargs)
+
+
+class FluxLoraLoaderMixin(LoraBaseMixin):
+    r"""
+    Load LoRA layers into [`FluxTransformer2DModel`],
+    [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
+
+    Specific to [`FluxPipeline`].
+    """
+
+    _lora_loadable_modules = ["transformer", "text_encoder"]
+    transformer_name = TRANSFORMER_NAME
+    text_encoder_name = TEXT_ENCODER_NAME
+    _control_lora_supported_norm_keys = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"]
+
+    @classmethod
+    @validate_hf_hub_args
+    def lora_state_dict(
+        cls,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        return_alphas: bool = False,
+        **kwargs,
+    ):
+        r"""
+        Return state dict for lora weights and the network alphas.
+
+        We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+        This function is experimental and might change in the future.
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                Can be either:
+
+                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                      the Hub.
+                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+                      with [`ModelMixin.save_pretrained`].
+                    - A [torch state
+                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + + """ + # Load the main state dict first which has the LoRA layers for either of + # transformer and text encoder or both. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + state_dict = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." + logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + + # TODO (sayakpaul): to a follow-up to clean and try to unify the conditions. + is_kohya = any(".lora_down.weight" in k for k in state_dict) + if is_kohya: + state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict) + # Kohya already takes care of scaling the LoRA parameters with alpha. + return (state_dict, None) if return_alphas else state_dict + + is_xlabs = any("processor" in k for k in state_dict) + if is_xlabs: + state_dict = _convert_xlabs_flux_lora_to_diffusers(state_dict) + # xlabs doesn't use `alpha`. 
+            return (state_dict, None) if return_alphas else state_dict
+
+        is_bfl_control = any("query_norm.scale" in k for k in state_dict)
+        if is_bfl_control:
+            state_dict = _convert_bfl_flux_control_lora_to_diffusers(state_dict)
+            return (state_dict, None) if return_alphas else state_dict
+
+        # For state dicts like
+        # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
+        keys = list(state_dict.keys())
+        network_alphas = {}
+        for k in keys:
+            if "alpha" in k:
+                alpha_value = state_dict.get(k)
+                if (torch.is_tensor(alpha_value) and torch.is_floating_point(alpha_value)) or isinstance(
+                    alpha_value, float
+                ):
+                    network_alphas[k] = state_dict.pop(k)
+                else:
+                    raise ValueError(
+                        f"The alpha key ({k}) seems to be incorrect. If you think this error is unexpected, please open an issue."
+                    )
+
+        if return_alphas:
+            return state_dict, network_alphas
+        else:
+            return state_dict
+
+    def load_lora_weights(
+        self,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        adapter_name: Optional[str] = None,
+        hotswap: bool = False,
+        **kwargs,
+    ):
+        """
+        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
+        `self.text_encoder`.
+
+        All kwargs are forwarded to `self.lora_state_dict`.
+
+        See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
+        loaded.
+
+        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+        dict is loaded into `self.transformer`.
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+            hotswap (`bool`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+            kwargs (`dict`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+        if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
+        # if a dict is passed, copy it instead of modifying it inplace
+        if isinstance(pretrained_model_name_or_path_or_dict, dict):
+            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+        state_dict, network_alphas = self.lora_state_dict(
+            pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs
+        )
+
+        has_lora_keys = any("lora" in key for key in state_dict.keys())
+
+        # Flux Control LoRAs also have norm keys
+        has_norm_keys = any(
+            norm_key in key for key in state_dict.keys() for norm_key in self._control_lora_supported_norm_keys
+        )
+
+        if not (has_lora_keys or has_norm_keys):
+            raise ValueError("Invalid LoRA checkpoint.")
+
+        transformer_lora_state_dict = {
+            k: state_dict.get(k)
+            for k in list(state_dict.keys())
+            if k.startswith(f"{self.transformer_name}.") and "lora" in k
+        }
+        transformer_norm_state_dict = {
+            k: state_dict.pop(k)
+            for k in list(state_dict.keys())
+            if k.startswith(f"{self.transformer_name}.")
+            and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys)
+        }
+
+        transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
+        has_param_with_expanded_shape = False
+        if len(transformer_lora_state_dict) > 0:
+            has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_(
+                transformer, transformer_lora_state_dict, transformer_norm_state_dict
+            )
+
+        if has_param_with_expanded_shape:
+            logger.info(
+                "The LoRA weights contain parameters that have different shapes than expected by the transformer. "
+                "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. "
+                "To get a comprehensive list of parameter names that were modified, enable debug logging."
+            )
+        if len(transformer_lora_state_dict) > 0:
+            transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
+                transformer=transformer, lora_state_dict=transformer_lora_state_dict
+            )
+            for k in transformer_lora_state_dict:
+                state_dict.update({k: transformer_lora_state_dict[k]})
+
+        self.load_lora_into_transformer(
+            state_dict,
+            network_alphas=network_alphas,
+            transformer=transformer,
+            adapter_name=adapter_name,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
+        )
+
+        if len(transformer_norm_state_dict) > 0:
+            transformer._transformer_norm_layers = self._load_norm_into_transformer(
+                transformer_norm_state_dict,
+                transformer=transformer,
+                discard_original_layers=False,
+            )
+
+        self.load_lora_into_text_encoder(
+            state_dict,
+            network_alphas=network_alphas,
+            text_encoder=self.text_encoder,
+            prefix=self.text_encoder_name,
+            lora_scale=self.lora_scale,
+            adapter_name=adapter_name,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
+        )
+
+    @classmethod
+    def load_lora_into_transformer(
+        cls,
+        state_dict,
+        network_alphas,
+        transformer,
+        adapter_name=None,
+        _pipeline=None,
+        low_cpu_mem_usage=False,
+        hotswap: bool = False,
+    ):
+        """
+        This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+        Parameters:
+            state_dict (`dict`):
+                A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+                into the transformer or prefixed with an additional `transformer`, which can be used to distinguish
+                them from the text encoder lora layers.
+            network_alphas (`Dict[str, float]`):
+                The value of the network alpha used for stable learning and preventing underflow. This value has the
+                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
+ transformer (`FluxTransformer2DModel`): + The Transformer model to load the LoRA layers into. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + """ + if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # Load the layers corresponding to transformer. + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=network_alphas, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + def _load_norm_into_transformer( + cls, + state_dict, + transformer, + prefix=None, + discard_original_layers=False, + ) -> Dict[str, torch.Tensor]: + # Remove prefix if present + prefix = prefix or cls.transformer_name + for key in list(state_dict.keys()): + if key.split(".")[0] == prefix: + state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key) + + # Find invalid keys + transformer_state_dict = transformer.state_dict() + transformer_keys = set(transformer_state_dict.keys()) + state_dict_keys = set(state_dict.keys()) + extra_keys = list(state_dict_keys - transformer_keys) + + if extra_keys: + logger.warning( + f"Unsupported keys found in state dict when trying to load normalization layers into the transformer. The following keys will be ignored:\n{extra_keys}." + ) + + for key in extra_keys: + state_dict.pop(key) + + # Save the layers that are going to be overwritten so that unload_lora_weights can work as expected + overwritten_layers_state_dict = {} + if not discard_original_layers: + for key in state_dict.keys(): + overwritten_layers_state_dict[key] = transformer_state_dict[key].clone() + + logger.info( + "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will directly update the state_dict of the transformer " + 'as opposed to the LoRA layers that will co-exist separately until the "fuse_lora()" method is called. That is to say, the normalization layers will always be directly ' + "fused into the transformer and can only be unfused if `discard_original_layers=True` is passed. This might also have implications when dealing with multiple LoRAs. " + "If you notice something unexpected, please open an issue: https://github.com/huggingface/diffusers/issues." + ) + + # We can't load with strict=True because the current state_dict does not contain all the transformer keys + incompatible_keys = transformer.load_state_dict(state_dict, strict=False) + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + + # We shouldn't expect to see the supported norm keys here being present in the unexpected keys. + if unexpected_keys: + if any(norm_key in k for k in unexpected_keys for norm_key in cls._control_lora_supported_norm_keys): + raise ValueError( + f"Found {unexpected_keys} as unexpected keys while trying to load norm layers into the transformer." 
+ ) + + return overwritten_layers_state_dict + + @classmethod + # Copied from diffusers.loaders.lora.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder + def load_lora_into_text_encoder( + cls, + state_dict, + network_alphas, + text_encoder, + prefix=None, + lora_scale=1.0, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + ): + """ + This will load the LoRA layers specified in `state_dict` into `text_encoder` + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The key should be prefixed with an + additional `text_encoder` to distinguish between unet lora layers. + network_alphas (`Dict[str, float]`): + The value of the network alpha used for stable learning and preventing underflow. This value has the + same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this + link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). + text_encoder (`CLIPTextModel`): + The text encoder model to load the LoRA layers into. + prefix (`str`): + Expected prefix of the `text_encoder` in the `state_dict`. + lora_scale (`float`): + How much to scale the output of the lora linear layer before it is added with the output of the regular + lora layer. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + """ + _load_lora_into_text_encoder( + state_dict=state_dict, + network_alphas=network_alphas, + lora_scale=lora_scale, + text_encoder=text_encoder, + prefix=prefix, + text_encoder_name=cls.text_encoder_name, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + # Copied from diffusers.loaders.lora.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with unet->transformer + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + text_encoder_lora_layers: Dict[str, torch.nn.Module] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + ): + r""" + Save the LoRA parameters corresponding to the UNet and text encoder. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. Will be created if it doesn't exist. + transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `transformer`. + text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text + encoder LoRA state dict because it comes from 🤗 Transformers. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. 
+            save_function (`Callable`):
+                The function to use to save the state dictionary. Useful during distributed training when you need to
+                replace `torch.save` with another method. Can be configured with the environment variable
+                `DIFFUSERS_SAVE_MODE`.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+        """
+        state_dict = {}
+
+        if not (transformer_lora_layers or text_encoder_lora_layers):
+            raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.")
+
+        if transformer_lora_layers:
+            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+        if text_encoder_lora_layers:
+            state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
+
+        # Save the model
+        cls.write_lora_layers(
+            state_dict=state_dict,
+            save_directory=save_directory,
+            is_main_process=is_main_process,
+            weight_name=weight_name,
+            save_function=save_function,
+            safe_serialization=safe_serialization,
+        )
+
+    def fuse_lora(
+        self,
+        components: List[str] = ["transformer"],
+        lora_scale: float = 1.0,
+        safe_fusing: bool = False,
+        adapter_names: Optional[List[str]] = None,
+        **kwargs,
+    ):
+        r"""
+        Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+        This is an experimental API.
+
+        Args:
+            components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+            lora_scale (`float`, defaults to 1.0):
+                Controls how much to influence the outputs with the LoRA parameters.
+            safe_fusing (`bool`, defaults to `False`):
+                Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+            adapter_names (`List[str]`, *optional*):
+                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+        Example:
+
+        ```py
+        from diffusers import DiffusionPipeline
+        import torch
+
+        pipeline = DiffusionPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+        pipeline.fuse_lora(lora_scale=0.7)
+        ```
+        """
+
+        transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
+        if (
+            hasattr(transformer, "_transformer_norm_layers")
+            and isinstance(transformer._transformer_norm_layers, dict)
+            and len(transformer._transformer_norm_layers.keys()) > 0
+        ):
+            logger.info(
+                "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will directly update the state_dict of the transformer "
+                "as opposed to the LoRA layers that will co-exist separately until the 'fuse_lora()' method is called. That is to say, the normalization layers will always be directly "
+                "fused into the transformer and can only be unfused if `discard_original_layers=True` is passed."
+            )
+
+        super().fuse_lora(
+            components=components,
+            lora_scale=lora_scale,
+            safe_fusing=safe_fusing,
+            adapter_names=adapter_names,
+            **kwargs,
+        )
+
+    def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
+        r"""
+        Reverses the effect of
+        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+        This is an experimental API.
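A round-trip sketch of fusing and unfusing on Flux (the LoRA repo id is hypothetical):

```py
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights("some-user/flux-lora", adapter_name="style")  # hypothetical repo id
pipe.fuse_lora(lora_scale=0.8)  # bake the adapter into the base weights
pipe.unfuse_lora()  # restore the original, unfused parameters
```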
+
+        Args:
+            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+        """
+        transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
+        if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers:
+            transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
+
+        super().unfuse_lora(components=components, **kwargs)
+
+    # We override this here to account for `_transformer_norm_layers` and `_overwritten_params`.
+    def unload_lora_weights(self, reset_to_overwritten_params=False):
+        """
+        Unloads the LoRA parameters.
+
+        Args:
+            reset_to_overwritten_params (`bool`, defaults to `False`): Whether to reset the LoRA-loaded modules
+                to their original params. Refer to the [Flux
+                documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux) to learn more.
+
+        Examples:
+
+        ```python
+        >>> # Assuming `pipeline` is already loaded with the LoRA parameters.
+        >>> pipeline.unload_lora_weights()
+        >>> ...
+        ```
+        """
+        super().unload_lora_weights()
+
+        transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
+        if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers:
+            transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
+            transformer._transformer_norm_layers = None
+
+        if reset_to_overwritten_params and getattr(transformer, "_overwritten_params", None) is not None:
+            overwritten_params = transformer._overwritten_params
+            module_names = set()
+
+            for param_name in overwritten_params:
+                if param_name.endswith(".weight"):
+                    module_names.add(param_name.replace(".weight", ""))
+
+            for name, module in transformer.named_modules():
+                if isinstance(module, torch.nn.Linear) and name in module_names:
+                    module_weight = module.weight.data
+                    module_bias = module.bias.data if module.bias is not None else None
+                    bias = module_bias is not None
+
+                    parent_module_name, _, current_module_name = name.rpartition(".")
+                    parent_module = transformer.get_submodule(parent_module_name)
+
+                    current_param_weight = overwritten_params[f"{name}.weight"]
+                    in_features, out_features = current_param_weight.shape[1], current_param_weight.shape[0]
+                    with torch.device("meta"):
+                        original_module = torch.nn.Linear(
+                            in_features,
+                            out_features,
+                            bias=bias,
+                            dtype=module_weight.dtype,
+                        )
+
+                    tmp_state_dict = {"weight": current_param_weight}
+                    if module_bias is not None:
+                        tmp_state_dict.update({"bias": overwritten_params[f"{name}.bias"]})
+                    original_module.load_state_dict(tmp_state_dict, assign=True, strict=True)
+                    setattr(parent_module, current_module_name, original_module)
+
+                    del tmp_state_dict
+
+                    if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
+                        attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
+                        new_value = int(current_param_weight.shape[1])
+                        old_value = getattr(transformer.config, attribute_name)
+                        setattr(transformer.config, attribute_name, new_value)
+                        logger.info(
+                            f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
+                        )
+
+    @classmethod
+    def _maybe_expand_transformer_param_shape_or_error_(
+        cls,
+        transformer: torch.nn.Module,
+        lora_state_dict=None,
+        norm_state_dict=None,
+        prefix=None,
+    ) -> bool:
+        """
+        Control LoRA expands the shape of the input layer from (3072, 64) to (3072, 128).
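As a standalone sketch of that idea (illustrative only, not this method's code), a smaller linear layer can be zero-padded into an expanded one:

```py
import torch

old = torch.nn.Linear(64, 3072)
expanded = torch.nn.Linear(128, 3072)
with torch.no_grad():
    expanded.weight.zero_()
    expanded.weight[:, :64].copy_(old.weight)  # original weights occupy the leading slice
    expanded.bias.copy_(old.bias)  # bias keeps its shape; only the input features grow
```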
+        This method handles that and generalizes things a bit so that any parameter that needs expansion receives
+        appropriate treatment.
+        """
+        state_dict = {}
+        if lora_state_dict is not None:
+            state_dict.update(lora_state_dict)
+        if norm_state_dict is not None:
+            state_dict.update(norm_state_dict)
+
+        # Remove prefix if present
+        prefix = prefix or cls.transformer_name
+        for key in list(state_dict.keys()):
+            if key.split(".")[0] == prefix:
+                state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key)
+
+        # Expand transformer parameter shapes if they don't match lora
+        has_param_with_shape_update = False
+        overwritten_params = {}
+
+        is_peft_loaded = getattr(transformer, "peft_config", None) is not None
+        is_quantized = hasattr(transformer, "hf_quantizer")
+        for name, module in transformer.named_modules():
+            if isinstance(module, torch.nn.Linear):
+                module_weight = module.weight.data
+                module_bias = module.bias.data if module.bias is not None else None
+                bias = module_bias is not None
+
+                lora_base_name = name.replace(".base_layer", "") if is_peft_loaded else name
+                lora_A_weight_name = f"{lora_base_name}.lora_A.weight"
+                lora_B_weight_name = f"{lora_base_name}.lora_B.weight"
+                if lora_A_weight_name not in state_dict:
+                    continue
+
+                in_features = state_dict[lora_A_weight_name].shape[1]
+                out_features = state_dict[lora_B_weight_name].shape[0]
+
+                # The model may be loaded with different quantization schemes, which may flatten the params.
+                # `bitsandbytes`, for example, flattens the weights when using 4bit. 8bit bnb models
+                # preserve weight shape.
+                module_weight_shape = cls._calculate_module_shape(model=transformer, base_module=module)
+
+                # This means there's no need for an expansion in the params, so we simply skip.
+                if tuple(module_weight_shape) == (out_features, in_features):
+                    continue
+
+                module_out_features, module_in_features = module_weight_shape
+                debug_message = ""
+                if in_features > module_in_features:
+                    debug_message += (
+                        f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA '
+                        f"checkpoint contains a higher number of features than expected. The number of input_features will be "
+                        f"expanded from {module_in_features} to {in_features}"
+                    )
+                if out_features > module_out_features:
+                    debug_message += (
+                        ", and the number of output features will be "
+                        f"expanded from {module_out_features} to {out_features}."
+                    )
+                else:
+                    debug_message += "."
+                if debug_message:
+                    logger.debug(debug_message)
+
+                if out_features > module_out_features or in_features > module_in_features:
+                    has_param_with_shape_update = True
+                    parent_module_name, _, current_module_name = name.rpartition(".")
+                    parent_module = transformer.get_submodule(parent_module_name)
+
+                    if is_quantized:
+                        module_weight = _maybe_dequantize_weight_for_expanded_lora(transformer, module)
+
+                    # TODO: consider if this layer needs to be a quantized layer as well if `is_quantized` is True.
+                    with torch.device("meta"):
+                        expanded_module = torch.nn.Linear(
+                            in_features, out_features, bias=bias, dtype=module_weight.dtype
+                        )
+                    # Only weights are expanded and biases are not. This is because only the input dimensions
+                    # are changed while the output dimensions remain the same. The shape of the weight tensor
+                    # is (out_features, in_features), while the shape of bias tensor is (out_features,), which
+                    # explains the reason why only weights are expanded.
+ new_weight = torch.zeros_like( + expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype + ) + slices = tuple(slice(0, dim) for dim in module_weight_shape) + new_weight[slices] = module_weight + tmp_state_dict = {"weight": new_weight} + if module_bias is not None: + tmp_state_dict["bias"] = module_bias + expanded_module.load_state_dict(tmp_state_dict, strict=True, assign=True) + + setattr(parent_module, current_module_name, expanded_module) + + del tmp_state_dict + + if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX: + attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name] + new_value = int(expanded_module.weight.data.shape[1]) + old_value = getattr(transformer.config, attribute_name) + setattr(transformer.config, attribute_name, new_value) + logger.info( + f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}." + ) + + # For `unload_lora_weights()`. + # TODO: this could lead to more memory overhead if the number of overwritten params + # are large. Should be revisited later and tackled through a `discard_original_layers` arg. + overwritten_params[f"{current_module_name}.weight"] = module_weight + if module_bias is not None: + overwritten_params[f"{current_module_name}.bias"] = module_bias + + if len(overwritten_params) > 0: + transformer._overwritten_params = overwritten_params + + return has_param_with_shape_update + + @classmethod + def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict): + expanded_module_names = set() + transformer_state_dict = transformer.state_dict() + prefix = f"{cls.transformer_name}." + + lora_module_names = [ + key[: -len(".lora_A.weight")] for key in lora_state_dict if key.endswith(".lora_A.weight") + ] + lora_module_names = [name[len(prefix) :] for name in lora_module_names if name.startswith(prefix)] + lora_module_names = sorted(set(lora_module_names)) + transformer_module_names = sorted({name for name, _ in transformer.named_modules()}) + unexpected_modules = set(lora_module_names) - set(transformer_module_names) + if unexpected_modules: + logger.debug(f"Found unexpected modules: {unexpected_modules}. These will be ignored.") + + is_peft_loaded = getattr(transformer, "peft_config", None) is not None + for k in lora_module_names: + if k in unexpected_modules: + continue + + base_param_name = ( + f"{k.replace(prefix, '')}.base_layer.weight" + if is_peft_loaded and f"{k.replace(prefix, '')}.base_layer.weight" in transformer_state_dict + else f"{k.replace(prefix, '')}.weight" + ) + base_weight_param = transformer_state_dict[base_param_name] + lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"] + + # TODO (sayakpaul): Handle the cases when we actually need to expand when using quantization. + base_module_shape = cls._calculate_module_shape(model=transformer, base_weight_param_name=base_param_name) + + if base_module_shape[1] > lora_A_param.shape[1]: + shape = (lora_A_param.shape[0], base_weight_param.shape[1]) + expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device) + expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param) + lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight + expanded_module_names.add(k) + elif base_module_shape[1] < lora_A_param.shape[1]: + raise NotImplementedError( + f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new." 
+                )
+
+        if expanded_module_names:
+            logger.info(
+                f"The following LoRA modules were zero-padded to match the state dict of {cls.transformer_name}: {expanded_module_names}. Please open an issue if you think this was unexpected - https://github.com/huggingface/diffusers/issues/new."
+            )
+
+        return lora_state_dict
+
+    @staticmethod
+    def _calculate_module_shape(
+        model: "torch.nn.Module",
+        base_module: Optional["torch.nn.Linear"] = None,
+        base_weight_param_name: Optional[str] = None,
+    ) -> "torch.Size":
+        def _get_weight_shape(weight: torch.Tensor):
+            if weight.__class__.__name__ == "Params4bit":
+                return weight.quant_state.shape
+            elif weight.__class__.__name__ == "GGUFParameter":
+                return weight.quant_shape
+            else:
+                return weight.shape
+
+        if base_module is not None:
+            return _get_weight_shape(base_module.weight)
+        elif base_weight_param_name is not None:
+            if not base_weight_param_name.endswith(".weight"):
+                raise ValueError(
+                    f"Invalid `base_weight_param_name` passed as it does not end with '.weight': {base_weight_param_name=}."
+                )
+            module_path = base_weight_param_name.rsplit(".weight", 1)[0]
+            submodule = get_submodule_by_name(model, module_path)
+            return _get_weight_shape(submodule.weight)
+
+        raise ValueError("Either `base_module` or `base_weight_param_name` must be provided.")
+
+
+# We subclass from `StableDiffusionLoraLoaderMixin` here because Amused initially relied on it
+# for its LoRA support.
+class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
+    _lora_loadable_modules = ["transformer", "text_encoder"]
+    transformer_name = TRANSFORMER_NAME
+    text_encoder_name = TEXT_ENCODER_NAME
+
+    @classmethod
+    # Copied from diffusers.loaders.lora.lora_pipeline.FluxLoraLoaderMixin.load_lora_into_transformer with FluxTransformer2DModel->UVit2DModel
+    def load_lora_into_transformer(
+        cls,
+        state_dict,
+        network_alphas,
+        transformer,
+        adapter_name=None,
+        _pipeline=None,
+        low_cpu_mem_usage=False,
+        hotswap: bool = False,
+    ):
+        """
+        This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+        Parameters:
+            state_dict (`dict`):
+                A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+                into the unet or prefixed with an additional `unet` which can be used to distinguish between text
+                encoder lora layers.
+            network_alphas (`Dict[str, float]`):
+                The value of the network alpha used for stable learning and preventing underflow. This value has the
+                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
+            transformer (`UVit2DModel`):
+                The Transformer model to load the LoRA layers into.
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+            hotswap (`bool`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+        """
+        if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
+        # Load the layers corresponding to transformer.
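+        # `load_lora_adapter` is provided by the model-level PEFT integration
+        # (`PeftAdapterMixin`); it remaps the checkpoint keys, applies `network_alphas`
+        # scaling, and registers the adapter under `adapter_name`.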
+        logger.info(f"Loading {cls.transformer_name}.")
+        transformer.load_lora_adapter(
+            state_dict,
+            network_alphas=network_alphas,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
+        )
+
+    @classmethod
+    # Copied from diffusers.loaders.lora.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
+    def load_lora_into_text_encoder(
+        cls,
+        state_dict,
+        network_alphas,
+        text_encoder,
+        prefix=None,
+        lora_scale=1.0,
+        adapter_name=None,
+        _pipeline=None,
+        low_cpu_mem_usage=False,
+        hotswap: bool = False,
+    ):
+        """
+        This will load the LoRA layers specified in `state_dict` into `text_encoder`
+
+        Parameters:
+            state_dict (`dict`):
+                A standard state dict containing the lora layer parameters. The key should be prefixed with an
+                additional `text_encoder` to distinguish between unet lora layers.
+            network_alphas (`Dict[str, float]`):
+                The value of the network alpha used for stable learning and preventing underflow. This value has the
+                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
+            text_encoder (`CLIPTextModel`):
+                The text encoder model to load the LoRA layers into.
+            prefix (`str`):
+                Expected prefix of the `text_encoder` in the `state_dict`.
+            lora_scale (`float`):
+                How much to scale the output of the lora linear layer before it is added with the output of the regular
+                lora layer.
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+            hotswap (`bool`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+        """
+        _load_lora_into_text_encoder(
+            state_dict=state_dict,
+            network_alphas=network_alphas,
+            lora_scale=lora_scale,
+            text_encoder=text_encoder,
+            prefix=prefix,
+            text_encoder_name=cls.text_encoder_name,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
+        )
+
+    @classmethod
+    def save_lora_weights(
+        cls,
+        save_directory: Union[str, os.PathLike],
+        text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
+        transformer_lora_layers: Dict[str, torch.nn.Module] = None,
+        is_main_process: bool = True,
+        weight_name: str = None,
+        save_function: Callable = None,
+        safe_serialization: bool = True,
+    ):
+        r"""
+        Save the LoRA parameters corresponding to the transformer and text encoder.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to save LoRA parameters to. Will be created if it doesn't exist.
+            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+                State dict of the LoRA layers corresponding to the `transformer`.
+            text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+                State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
+                encoder LoRA state dict because it comes from 🤗 Transformers.
+            is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful during distributed training and you
+                need to call this function on all processes.
In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + """ + state_dict = {} + + if not (transformer_lora_layers or text_encoder_lora_layers): + raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.") + + if transformer_lora_layers: + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + if text_encoder_lora_layers: + state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) + + # Save the model + cls.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + +class CogVideoXLoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`CogVideoXTransformer3DModel`]. Specific to [`CogVideoXPipeline`]. + """ + + _lora_loadable_modules = ["transformer"] + transformer_name = TRANSFORMER_NAME + + @classmethod + @validate_hf_hub_args + # Copied from diffusers.loaders.lora.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + **kwargs, + ): + r""" + Return state dict for lora weights and the network alphas. + + + + We support loading A1111 formatted LoRA checkpoints in a limited capacity. + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. 
+ subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + + """ + # Load the main state dict first which has the LoRA layers for either of + # transformer and text encoder or both. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + state_dict = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." + logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + + return state_dict + + def load_lora_weights( + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name: Optional[str] = None, + hotswap: bool = False, + **kwargs, + ): + """ + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and + `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See + [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state + dict is loaded into `self.transformer`. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. 
Please update it with `pip install -U peft`."
+            )
+
+        # if a dict is passed, copy it instead of modifying it inplace
+        if isinstance(pretrained_model_name_or_path_or_dict, dict):
+            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+        is_correct_format = all("lora" in key for key in state_dict.keys())
+        if not is_correct_format:
+            raise ValueError("Invalid LoRA checkpoint.")
+
+        self.load_lora_into_transformer(
+            state_dict,
+            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+            adapter_name=adapter_name,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
+        )
+
+    @classmethod
+    # Copied from diffusers.loaders.lora.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogVideoXTransformer3DModel
+    def load_lora_into_transformer(
+        cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
+    ):
+        """
+        This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+        Parameters:
+            state_dict (`dict`):
+                A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+                into the unet or prefixed with an additional `unet` which can be used to distinguish between text
+                encoder lora layers.
+            transformer (`CogVideoXTransformer3DModel`):
+                The Transformer model to load the LoRA layers into.
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+            hotswap (`bool`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+        """
+        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
+        # Load the layers corresponding to transformer.
+        logger.info(f"Loading {cls.transformer_name}.")
+        transformer.load_lora_adapter(
+            state_dict,
+            network_alphas=None,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
+        )
+
+    @classmethod
+    # Adapted from diffusers.loaders.lora.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights without support for text encoder
+    def save_lora_weights(
+        cls,
+        save_directory: Union[str, os.PathLike],
+        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        is_main_process: bool = True,
+        weight_name: str = None,
+        save_function: Callable = None,
+        safe_serialization: bool = True,
+    ):
+        r"""
+        Save the LoRA parameters corresponding to the UNet and text encoder.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to save LoRA parameters to. Will be created if it doesn't exist.
+            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+                State dict of the LoRA layers corresponding to the `transformer`.
+            is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not.
Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + """ + state_dict = {} + + if not transformer_lora_layers: + raise ValueError("You must pass `transformer_lora_layers`.") + + if transformer_lora_layers: + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + # Save the model + cls.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + def fuse_lora( + self, + components: List[str] = ["transformer"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + **kwargs, + ): + r""" + Fuses the LoRA parameters into the original parameters of the corresponding blocks. + + + + This is an experimental API. + + + + Args: + components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. + lora_scale (`float`, defaults to 1.0): + Controls how much to influence the outputs with the LoRA parameters. + safe_fusing (`bool`, defaults to `False`): + Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. + adapter_names (`List[str]`, *optional*): + Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. + + Example: + + ```py + from diffusers import DiffusionPipeline + import torch + + pipeline = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ).to("cuda") + pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") + pipeline.fuse_lora(lora_scale=0.7) + ``` + """ + super().fuse_lora( + components=components, + lora_scale=lora_scale, + safe_fusing=safe_fusing, + adapter_names=adapter_names, + **kwargs, + ) + + def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): + r""" + Reverses the effect of + [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). + + + + This is an experimental API. + + + + Args: + components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. + unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. + """ + super().unfuse_lora(components=components, **kwargs) + + +class Mochi1LoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`MochiTransformer3DModel`]. Specific to [`MochiPipeline`]. + """ + + _lora_loadable_modules = ["transformer"] + transformer_name = TRANSFORMER_NAME + + @classmethod + @validate_hf_hub_args + # Copied from diffusers.loaders.lora.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + **kwargs, + ): + r""" + Return state dict for lora weights and the network alphas. 
+ + + + We support loading A1111 formatted LoRA checkpoints in a limited capacity. + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + + """ + # Load the main state dict first which has the LoRA layers for either of + # transformer and text encoder or both. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + state_dict = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." 
+ logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + + return state_dict + + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights + def load_lora_weights( + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name: Optional[str] = None, + hotswap: bool = False, + **kwargs, + ): + """ + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and + `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See + [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state + dict is loaded into `self.transformer`. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + + is_correct_format = all("lora" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint.") + + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + # Copied from diffusers.loaders.lora.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->MochiTransformer3DModel + def load_lora_into_transformer( + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False + ): + """ + This will load the LoRA layers specified in `state_dict` into `transformer`. + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The keys can either be indexed directly + into the unet or prefixed with an additional `unet` which can be used to distinguish between text + encoder lora layers. + transformer (`MochiTransformer3DModel`): + The Transformer model to load the LoRA layers into. 
+ adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + """ + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # Load the layers corresponding to transformer. + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=None, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + ): + r""" + Save the LoRA parameters corresponding to the UNet and text encoder. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. Will be created if it doesn't exist. + transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `transformer`. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + """ + state_dict = {} + + if not transformer_lora_layers: + raise ValueError("You must pass `transformer_lora_layers`.") + + if transformer_lora_layers: + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + # Save the model + cls.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora + def fuse_lora( + self, + components: List[str] = ["transformer"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + **kwargs, + ): + r""" + Fuses the LoRA parameters into the original parameters of the corresponding blocks. + + + + This is an experimental API. + + + + Args: + components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. + lora_scale (`float`, defaults to 1.0): + Controls how much to influence the outputs with the LoRA parameters. 
+ safe_fusing (`bool`, defaults to `False`): + Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. + adapter_names (`List[str]`, *optional*): + Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. + + Example: + + ```py + from diffusers import DiffusionPipeline + import torch + + pipeline = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ).to("cuda") + pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") + pipeline.fuse_lora(lora_scale=0.7) + ``` + """ + super().fuse_lora( + components=components, + lora_scale=lora_scale, + safe_fusing=safe_fusing, + adapter_names=adapter_names, + **kwargs, + ) + + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora + def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): + r""" + Reverses the effect of + [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). + + + + This is an experimental API. + + + + Args: + components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. + unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. + """ + super().unfuse_lora(components=components, **kwargs) + + +class LTXVideoLoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`LTXVideoTransformer3DModel`]. Specific to [`LTXPipeline`]. + """ + + _lora_loadable_modules = ["transformer"] + transformer_name = TRANSFORMER_NAME + + @classmethod + @validate_hf_hub_args + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + **kwargs, + ): + r""" + Return state dict for lora weights and the network alphas. + + + + We support loading A1111 formatted LoRA checkpoints in a limited capacity. + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. 
If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + + """ + # Load the main state dict first which has the LoRA layers for either of + # transformer and text encoder or both. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + state_dict = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." + logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + + return state_dict + + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights + def load_lora_weights( + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name: Optional[str] = None, + hotswap: bool = False, + **kwargs, + ): + """ + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and + `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See + [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state + dict is loaded into `self.transformer`. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. 
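+
+        Example (a minimal sketch; the LoRA repository id and `weight_name` below are
+        illustrative placeholders rather than a tested checkpoint):
+
+        ```py
+        import torch
+
+        from diffusers import LTXPipeline
+
+        pipeline = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16).to("cuda")
+        pipeline.load_lora_weights("your-org/ltx-video-lora", weight_name="lora.safetensors", adapter_name="example")
+        pipeline.set_adapters(["example"], [0.8])
+        ```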
+ """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + + is_correct_format = all("lora" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint.") + + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + # Copied from diffusers.loaders.lora.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->LTXVideoTransformer3DModel + def load_lora_into_transformer( + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False + ): + """ + This will load the LoRA layers specified in `state_dict` into `transformer`. + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The keys can either be indexed directly + into the unet or prefixed with an additional `unet` which can be used to distinguish between text + encoder lora layers. + transformer (`LTXVideoTransformer3DModel`): + The Transformer model to load the LoRA layers into. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + """ + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # Load the layers corresponding to transformer. + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=None, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + ): + r""" + Save the LoRA parameters corresponding to the UNet and text encoder. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. Will be created if it doesn't exist. 
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `transformer`. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + """ + state_dict = {} + + if not transformer_lora_layers: + raise ValueError("You must pass `transformer_lora_layers`.") + + if transformer_lora_layers: + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + # Save the model + cls.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora + def fuse_lora( + self, + components: List[str] = ["transformer"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + **kwargs, + ): + r""" + Fuses the LoRA parameters into the original parameters of the corresponding blocks. + + + + This is an experimental API. + + + + Args: + components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. + lora_scale (`float`, defaults to 1.0): + Controls how much to influence the outputs with the LoRA parameters. + safe_fusing (`bool`, defaults to `False`): + Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. + adapter_names (`List[str]`, *optional*): + Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. + + Example: + + ```py + from diffusers import DiffusionPipeline + import torch + + pipeline = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ).to("cuda") + pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") + pipeline.fuse_lora(lora_scale=0.7) + ``` + """ + super().fuse_lora( + components=components, + lora_scale=lora_scale, + safe_fusing=safe_fusing, + adapter_names=adapter_names, + **kwargs, + ) + + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora + def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): + r""" + Reverses the effect of + [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). + + + + This is an experimental API. + + + + Args: + components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. + unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. + """ + super().unfuse_lora(components=components, **kwargs) + + +class SanaLoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`SanaTransformer2DModel`]. Specific to [`SanaPipeline`]. 
+ """ + + _lora_loadable_modules = ["transformer"] + transformer_name = TRANSFORMER_NAME + + @classmethod + @validate_hf_hub_args + # Copied from diffusers.loaders.lora.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + **kwargs, + ): + r""" + Return state dict for lora weights and the network alphas. + + + + We support loading A1111 formatted LoRA checkpoints in a limited capacity. + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + + """ + # Load the main state dict first which has the LoRA layers for either of + # transformer and text encoder or both. 
+ cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + state_dict = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." + logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + + return state_dict + + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights + def load_lora_weights( + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name: Optional[str] = None, + hotswap: bool = False, + **kwargs, + ): + """ + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and + `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See + [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state + dict is loaded into `self.transformer`. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." 
+ ) + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + + is_correct_format = all("lora" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint.") + + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + # Copied from diffusers.loaders.lora.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SanaTransformer2DModel + def load_lora_into_transformer( + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False + ): + """ + This will load the LoRA layers specified in `state_dict` into `transformer`. + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The keys can either be indexed directly + into the unet or prefixed with an additional `unet` which can be used to distinguish between text + encoder lora layers. + transformer (`SanaTransformer2DModel`): + The Transformer model to load the LoRA layers into. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + """ + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # Load the layers corresponding to transformer. + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=None, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + ): + r""" + Save the LoRA parameters corresponding to the UNet and text encoder. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. Will be created if it doesn't exist. + transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `transformer`. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. 
In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + """ + state_dict = {} + + if not transformer_lora_layers: + raise ValueError("You must pass `transformer_lora_layers`.") + + if transformer_lora_layers: + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + # Save the model + cls.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora + def fuse_lora( + self, + components: List[str] = ["transformer"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + **kwargs, + ): + r""" + Fuses the LoRA parameters into the original parameters of the corresponding blocks. + + + + This is an experimental API. + + + + Args: + components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. + lora_scale (`float`, defaults to 1.0): + Controls how much to influence the outputs with the LoRA parameters. + safe_fusing (`bool`, defaults to `False`): + Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. + adapter_names (`List[str]`, *optional*): + Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. + + Example: + + ```py + from diffusers import DiffusionPipeline + import torch + + pipeline = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ).to("cuda") + pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") + pipeline.fuse_lora(lora_scale=0.7) + ``` + """ + super().fuse_lora( + components=components, + lora_scale=lora_scale, + safe_fusing=safe_fusing, + adapter_names=adapter_names, + **kwargs, + ) + + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora + def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): + r""" + Reverses the effect of + [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). + + + + This is an experimental API. + + + + Args: + components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. + unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. + """ + super().unfuse_lora(components=components, **kwargs) + + +class HunyuanVideoLoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`HunyuanVideoTransformer3DModel`]. Specific to [`HunyuanVideoPipeline`]. + """ + + _lora_loadable_modules = ["transformer"] + transformer_name = TRANSFORMER_NAME + + @classmethod + @validate_hf_hub_args + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + **kwargs, + ): + r""" + Return state dict for lora weights and the network alphas. 
+ + + + We support loading original format HunyuanVideo LoRA checkpoints. + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + + """ + # Load the main state dict first which has the LoRA layers for either of + # transformer and text encoder or both. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + state_dict = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." 
+ logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + + is_original_hunyuan_video = any("img_attn_qkv" in k for k in state_dict) + if is_original_hunyuan_video: + state_dict = _convert_hunyuan_video_lora_to_diffusers(state_dict) + + return state_dict + + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights + def load_lora_weights( + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name: Optional[str] = None, + hotswap: bool = False, + **kwargs, + ): + """ + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and + `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See + [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state + dict is loaded into `self.transformer`. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + + is_correct_format = all("lora" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint.") + + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + # Copied from diffusers.loaders.lora.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HunyuanVideoTransformer3DModel + def load_lora_into_transformer( + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False + ): + """ + This will load the LoRA layers specified in `state_dict` into `transformer`. + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. 
The keys can either be indexed directly + into the unet or prefixed with an additional `unet` which can be used to distinguish between text + encoder lora layers. + transformer (`HunyuanVideoTransformer3DModel`): + The Transformer model to load the LoRA layers into. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + """ + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # Load the layers corresponding to transformer. + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=None, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + ): + r""" + Save the LoRA parameters corresponding to the UNet and text encoder. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. Will be created if it doesn't exist. + transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `transformer`. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. 
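+
+        Example:
+
+        A minimal sketch (illustrative only): `pipe` is assumed to be an already fine-tuned `HunyuanVideoPipeline`
+        whose transformer carries trained LoRA layers, and `get_peft_model_state_dict` is a `peft` utility, not part
+        of this class.
+
+        ```py
+        from peft.utils import get_peft_model_state_dict
+
+        # Collect the trained LoRA layers from the transformer and write them to disk.
+        transformer_lora_layers = get_peft_model_state_dict(pipe.transformer)
+        pipe.save_lora_weights(
+            save_directory="./hunyuan-video-lora",
+            transformer_lora_layers=transformer_lora_layers,
+        )
+        ```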
+ """ + state_dict = {} + + if not transformer_lora_layers: + raise ValueError("You must pass `transformer_lora_layers`.") + + if transformer_lora_layers: + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + # Save the model + cls.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora + def fuse_lora( + self, + components: List[str] = ["transformer"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + **kwargs, + ): + r""" + Fuses the LoRA parameters into the original parameters of the corresponding blocks. + + + + This is an experimental API. + + + + Args: + components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. + lora_scale (`float`, defaults to 1.0): + Controls how much to influence the outputs with the LoRA parameters. + safe_fusing (`bool`, defaults to `False`): + Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. + adapter_names (`List[str]`, *optional*): + Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. + + Example: + + ```py + from diffusers import DiffusionPipeline + import torch + + pipeline = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ).to("cuda") + pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") + pipeline.fuse_lora(lora_scale=0.7) + ``` + """ + super().fuse_lora( + components=components, + lora_scale=lora_scale, + safe_fusing=safe_fusing, + adapter_names=adapter_names, + **kwargs, + ) + + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora + def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): + r""" + Reverses the effect of + [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). + + + + This is an experimental API. + + + + Args: + components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. + unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. + """ + super().unfuse_lora(components=components, **kwargs) + + +class Lumina2LoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`Lumina2Transformer2DModel`]. Specific to [`Lumina2Text2ImgPipeline`]. + """ + + _lora_loadable_modules = ["transformer"] + transformer_name = TRANSFORMER_NAME + + @classmethod + @validate_hf_hub_args + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + **kwargs, + ): + r""" + Return state dict for lora weights and the network alphas. + + + + We support loading A1111 formatted LoRA checkpoints in a limited capacity. + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. 
+                    - A [torch state
+                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                allowed by Git.
+            subfolder (`str`, *optional*, defaults to `""`):
+                The subfolder location of a model file within a larger model repository on the Hub or locally.
+
+        """
+        # Load the main state dict first which has the LoRA layers for either of
+        # transformer and text encoder or both.
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
+        revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)
+        weight_name = kwargs.pop("weight_name", None)
+        use_safetensors = kwargs.pop("use_safetensors", None)
+
+        allow_pickle = False
+        if use_safetensors is None:
+            use_safetensors = True
+            allow_pickle = True
+
+        user_agent = {
+            "file_type": "attn_procs_weights",
+            "framework": "pytorch",
+        }
+
+        state_dict = _fetch_state_dict(
+            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+            weight_name=weight_name,
+            use_safetensors=use_safetensors,
+            local_files_only=local_files_only,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            proxies=proxies,
+            token=token,
+            revision=revision,
+            subfolder=subfolder,
+            user_agent=user_agent,
+            allow_pickle=allow_pickle,
+        )
+
+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+        # Convert original-format (non-diffusers) Lumina2 LoRA keys to the diffusers naming scheme.
+ non_diffusers = any(k.startswith("diffusion_model.") for k in state_dict) + if non_diffusers: + state_dict = _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict) + + return state_dict + + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights + def load_lora_weights( + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name: Optional[str] = None, + hotswap: bool = False, + **kwargs, + ): + """ + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and + `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See + [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state + dict is loaded into `self.transformer`. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + + is_correct_format = all("lora" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint.") + + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + # Copied from diffusers.loaders.lora.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->Lumina2Transformer2DModel + def load_lora_into_transformer( + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False + ): + """ + This will load the LoRA layers specified in `state_dict` into `transformer`. + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The keys can either be indexed directly + into the unet or prefixed with an additional `unet` which can be used to distinguish between text + encoder lora layers. 
+ transformer (`Lumina2Transformer2DModel`): + The Transformer model to load the LoRA layers into. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + """ + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # Load the layers corresponding to transformer. + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=None, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + ): + r""" + Save the LoRA parameters corresponding to the UNet and text encoder. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. Will be created if it doesn't exist. + transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `transformer`. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + """ + state_dict = {} + + if not transformer_lora_layers: + raise ValueError("You must pass `transformer_lora_layers`.") + + if transformer_lora_layers: + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + # Save the model + cls.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + # Copied from diffusers.loaders.lora.lora_pipeline.SanaLoraLoaderMixin.fuse_lora + def fuse_lora( + self, + components: List[str] = ["transformer"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + **kwargs, + ): + r""" + Fuses the LoRA parameters into the original parameters of the corresponding blocks. + + + + This is an experimental API. + + + + Args: + components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. 
+            lora_scale (`float`, defaults to 1.0):
+                Controls how much to influence the outputs with the LoRA parameters.
+            safe_fusing (`bool`, defaults to `False`):
+                Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+            adapter_names (`List[str]`, *optional*):
+                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+        Example:
+
+        ```py
+        from diffusers import DiffusionPipeline
+        import torch
+
+        pipeline = DiffusionPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+        pipeline.fuse_lora(lora_scale=0.7)
+        ```
+        """
+        super().fuse_lora(
+            components=components,
+            lora_scale=lora_scale,
+            safe_fusing=safe_fusing,
+            adapter_names=adapter_names,
+            **kwargs,
+        )
+
+    # Copied from diffusers.loaders.lora.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
+    def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+        r"""
+        Reverses the effect of
+        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+        <Tip>
+
+        This is an experimental API.
+
+        </Tip>
+
+        Args:
+            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+            unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+        """
+        super().unfuse_lora(components=components, **kwargs)
+
+
+class WanLoraLoaderMixin(LoraBaseMixin):
+    r"""
+    Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and [`WanImageToVideoPipeline`].
+    """
+
+    _lora_loadable_modules = ["transformer"]
+    transformer_name = TRANSFORMER_NAME
+
+    @classmethod
+    @validate_hf_hub_args
+    def lora_state_dict(
+        cls,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        **kwargs,
+    ):
+        r"""
+        Return state dict for lora weights and the network alphas.
+
+        <Tip>
+
+        We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+        This function is experimental and might change in the future.
+
+        </Tip>
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                Can be either:
+
+                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                      the Hub.
+                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+                      with [`ModelMixin.save_pretrained`].
+                    - A [torch state
+                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files.
If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + + """ + # Load the main state dict first which has the LoRA layers for either of + # transformer and text encoder or both. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + state_dict = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + if any(k.startswith("diffusion_model.") for k in state_dict): + state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict) + elif any(k.startswith("lora_unet_") for k in state_dict): + state_dict = _convert_musubi_wan_lora_to_diffusers(state_dict) + + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." 
+ logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + + return state_dict + + @classmethod + def _maybe_expand_t2v_lora_for_i2v( + cls, + transformer: torch.nn.Module, + state_dict, + ): + if transformer.config.image_dim is None: + return state_dict + + if any(k.startswith("transformer.blocks.") for k in state_dict): + num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict}) + is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict) + + if is_i2v_lora: + return state_dict + + for i in range(num_blocks): + for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]): + state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like( + state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"] + ) + state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like( + state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"] + ) + + return state_dict + + def load_lora_weights( + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name: Optional[str] = None, + hotswap: bool = False, + **kwargs, + ): + """ + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and + `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See + [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state + dict is loaded into `self.transformer`. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. 
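+        # Note on `_maybe_expand_t2v_lora_for_i2v` below: the `add_k_proj`/`add_v_proj`
+        # entries it adds are zero tensors for both `lora_A` and `lora_B`, so the LoRA
+        # delta (`lora_B @ lora_A`) on the image-conditioning projections is exactly
+        # zero. A T2V LoRA therefore loads into the I2V transformer without changing
+        # the behavior of layers it never trained.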
+ state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + # convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers + state_dict = self._maybe_expand_t2v_lora_for_i2v( + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + state_dict=state_dict, + ) + is_correct_format = all("lora" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint.") + + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + # Copied from diffusers.loaders.lora.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel + def load_lora_into_transformer( + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False + ): + """ + This will load the LoRA layers specified in `state_dict` into `transformer`. + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The keys can either be indexed directly + into the unet or prefixed with an additional `unet` which can be used to distinguish between text + encoder lora layers. + transformer (`WanTransformer3DModel`): + The Transformer model to load the LoRA layers into. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + """ + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # Load the layers corresponding to transformer. + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=None, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + ): + r""" + Save the LoRA parameters corresponding to the UNet and text encoder. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. Will be created if it doesn't exist. + transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `transformer`. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. 
In this case, set `is_main_process=True` only on the main
+                process to avoid race conditions.
+            save_function (`Callable`):
+                The function to use to save the state dictionary. Useful during distributed training when you need to
+                replace `torch.save` with another method. Can be configured with the environment variable
+                `DIFFUSERS_SAVE_MODE`.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+        """
+        state_dict = {}
+
+        if not transformer_lora_layers:
+            raise ValueError("You must pass `transformer_lora_layers`.")
+
+        if transformer_lora_layers:
+            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+        # Save the model
+        cls.write_lora_layers(
+            state_dict=state_dict,
+            save_directory=save_directory,
+            is_main_process=is_main_process,
+            weight_name=weight_name,
+            save_function=save_function,
+            safe_serialization=safe_serialization,
+        )
+
+    # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
+    def fuse_lora(
+        self,
+        components: List[str] = ["transformer"],
+        lora_scale: float = 1.0,
+        safe_fusing: bool = False,
+        adapter_names: Optional[List[str]] = None,
+        **kwargs,
+    ):
+        r"""
+        Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+        <Tip>
+
+        This is an experimental API.
+
+        </Tip>
+
+        Args:
+            components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+            lora_scale (`float`, defaults to 1.0):
+                Controls how much to influence the outputs with the LoRA parameters.
+            safe_fusing (`bool`, defaults to `False`):
+                Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+            adapter_names (`List[str]`, *optional*):
+                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+        Example:
+
+        ```py
+        from diffusers import DiffusionPipeline
+        import torch
+
+        pipeline = DiffusionPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+        pipeline.fuse_lora(lora_scale=0.7)
+        ```
+        """
+        super().fuse_lora(
+            components=components,
+            lora_scale=lora_scale,
+            safe_fusing=safe_fusing,
+            adapter_names=adapter_names,
+            **kwargs,
+        )
+
+    # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
+    def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+        r"""
+        Reverses the effect of
+        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+        <Tip>
+
+        This is an experimental API.
+
+        </Tip>
+
+        Args:
+            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+            unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+        """
+        super().unfuse_lora(components=components, **kwargs)
+
+
+class CogView4LoraLoaderMixin(LoraBaseMixin):
+    r"""
+    Load LoRA layers into [`CogView4Transformer2DModel`]. Specific to [`CogView4Pipeline`].
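+
+    Example (a minimal sketch; the LoRA repository id is a placeholder, not a real checkpoint):
+
+    ```py
+    import torch
+    from diffusers import CogView4Pipeline
+
+    pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
+    pipe.load_lora_weights("<your-username>/<your-cogview4-lora>", adapter_name="my_style")
+    ```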
+ """ + + _lora_loadable_modules = ["transformer"] + transformer_name = TRANSFORMER_NAME + + @classmethod + @validate_hf_hub_args + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + **kwargs, + ): + r""" + Return state dict for lora weights and the network alphas. + + + + We support loading A1111 formatted LoRA checkpoints in a limited capacity. + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + + """ + # Load the main state dict first which has the LoRA layers for either of + # transformer and text encoder or both. 
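+        # The hub-download options below are popped off `kwargs` and forwarded to
+        # `_fetch_state_dict`; any entries remaining in `kwargs` afterwards are not
+        # used by this method.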
+ cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + state_dict = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." + logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + + return state_dict + + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights + def load_lora_weights( + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name: Optional[str] = None, + hotswap: bool = False, + **kwargs, + ): + """ + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and + `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See + [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state + dict is loaded into `self.transformer`. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." 
+ ) + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + + is_correct_format = all("lora" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint.") + + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + # Copied from diffusers.loaders.lora.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel + def load_lora_into_transformer( + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False + ): + """ + This will load the LoRA layers specified in `state_dict` into `transformer`. + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The keys can either be indexed directly + into the unet or prefixed with an additional `unet` which can be used to distinguish between text + encoder lora layers. + transformer (`CogView4Transformer2DModel`): + The Transformer model to load the LoRA layers into. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + """ + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # Load the layers corresponding to transformer. + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=None, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + ): + r""" + Save the LoRA parameters corresponding to the UNet and text encoder. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. Will be created if it doesn't exist. + transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `transformer`. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. 
In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + """ + state_dict = {} + + if not transformer_lora_layers: + raise ValueError("You must pass `transformer_lora_layers`.") + + if transformer_lora_layers: + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + # Save the model + cls.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora + def fuse_lora( + self, + components: List[str] = ["transformer"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + **kwargs, + ): + r""" + Fuses the LoRA parameters into the original parameters of the corresponding blocks. + + + + This is an experimental API. + + + + Args: + components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. + lora_scale (`float`, defaults to 1.0): + Controls how much to influence the outputs with the LoRA parameters. + safe_fusing (`bool`, defaults to `False`): + Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. + adapter_names (`List[str]`, *optional*): + Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. + + Example: + + ```py + from diffusers import DiffusionPipeline + import torch + + pipeline = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ).to("cuda") + pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") + pipeline.fuse_lora(lora_scale=0.7) + ``` + """ + super().fuse_lora( + components=components, + lora_scale=lora_scale, + safe_fusing=safe_fusing, + adapter_names=adapter_names, + **kwargs, + ) + + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora + def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): + r""" + Reverses the effect of + [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). + + + + This is an experimental API. + + + + Args: + components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. + unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. + """ + super().unfuse_lora(components=components, **kwargs) + + +class HiDreamImageLoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`HiDreamImageTransformer2DModel`]. Specific to [`HiDreamImagePipeline`]. 
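+
+    Example (a minimal sketch; both repository ids are placeholders rather than verified checkpoints):
+
+    ```py
+    from diffusers import HiDreamImagePipeline
+
+    pipe = HiDreamImagePipeline.from_pretrained("<org>/<hidream-base-model>")
+    pipe.load_lora_weights("<your-username>/<your-hidream-lora>")
+    pipe.fuse_lora(lora_scale=0.9)
+    ```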
+ """ + + _lora_loadable_modules = ["transformer"] + transformer_name = TRANSFORMER_NAME + + @classmethod + @validate_hf_hub_args + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + **kwargs, + ): + r""" + Return state dict for lora weights and the network alphas. + + + + We support loading A1111 formatted LoRA checkpoints in a limited capacity. + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + + """ + # Load the main state dict first which has the LoRA layers for either of + # transformer and text encoder or both. 
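+        # When `use_safetensors` is left unset, safetensors files are preferred, with
+        # `allow_pickle=True` permitting a fallback to pickled `.bin` checkpoints.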
+ cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + state_dict = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." + logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + + return state_dict + + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights + def load_lora_weights( + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name: Optional[str] = None, + hotswap: bool = False, + **kwargs, + ): + """ + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and + `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See + [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state + dict is loaded into `self.transformer`. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." 
+ ) + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + + is_correct_format = all("lora" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint.") + + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + # Copied from diffusers.loaders.lora.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HiDreamImageTransformer2DModel + def load_lora_into_transformer( + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False + ): + """ + This will load the LoRA layers specified in `state_dict` into `transformer`. + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The keys can either be indexed directly + into the transformer or prefixed with an additional `transformer`, which can be used to distinguish + them from text encoder lora layers. + transformer (`HiDreamImageTransformer2DModel`): + The Transformer model to load the LoRA layers into. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + """ + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # Load the layers corresponding to transformer. + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=None, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + ): + r""" + Save the LoRA parameters corresponding to the transformer. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. Will be created if it doesn't exist. + transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `transformer`. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training when you + need to call this function on all processes.
In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + """ + state_dict = {} + + if not transformer_lora_layers: + raise ValueError("You must pass `transformer_lora_layers`.") + + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + # Save the model + cls.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + # Copied from diffusers.loaders.lora.lora_pipeline.SanaLoraLoaderMixin.fuse_lora + def fuse_lora( + self, + components: List[str] = ["transformer"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + **kwargs, + ): + r""" + Fuses the LoRA parameters into the original parameters of the corresponding blocks. + + <Tip warning={true}> + + This is an experimental API. + + </Tip> + + Args: + components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. + lora_scale (`float`, defaults to 1.0): + Controls how much to influence the outputs with the LoRA parameters. + safe_fusing (`bool`, defaults to `False`): + Whether to check the fused weights for NaN values before fusing, and to skip fusing weights that + contain NaN values. + adapter_names (`List[str]`, *optional*): + Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. + + Example: + + ```py + from diffusers import DiffusionPipeline + import torch + + pipeline = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ).to("cuda") + pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") + pipeline.fuse_lora(lora_scale=0.7) + ``` + """ + super().fuse_lora( + components=components, + lora_scale=lora_scale, + safe_fusing=safe_fusing, + adapter_names=adapter_names, + **kwargs, + ) + + # Copied from diffusers.loaders.lora.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora + def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): + r""" + Reverses the effect of + [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). + + <Tip warning={true}> + + This is an experimental API. + + </Tip> + + Args: + components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. + unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters. + """ + super().unfuse_lora(components=components, **kwargs)
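For orientation, here is a minimal end-to-end sketch of the LoRA workflow this mixin wires together (`load_lora_weights` → `fuse_lora` → `unfuse_lora`). It assumes the `HiDreamImagePipeline` entry point; the repository ids and the LoRA file name are hypothetical placeholders, not shipped checkpoints.

```py
# Hedged sketch only: "<base-repo-id>", "<lora-repo-id>", and "lora.safetensors" are placeholders.
import torch
from diffusers import HiDreamImagePipeline

pipe = HiDreamImagePipeline.from_pretrained("<base-repo-id>", torch_dtype=torch.bfloat16).to("cuda")

# Load adapter weights into pipe.transformer, the mixin's only LoRA-loadable module.
pipe.load_lora_weights("<lora-repo-id>", weight_name="lora.safetensors", adapter_name="style")

# Optionally merge the adapter into the base weights (components defaults to ["transformer"]) ...
pipe.fuse_lora(lora_scale=0.8)
image = pipe("a watercolor fox").images[0]

# ... and restore the original weights afterwards.
pipe.unfuse_lora()
```

+ + +class LoraLoaderMixin(StableDiffusionLoraLoaderMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."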
+ deprecate("LoraLoaderMixin", "1.0.0", deprecation_message) + super().__init__(*args, **kwargs) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index 280a9fa6e73f..0970ef86f109 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -12,924 +12,66 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy -import inspect -import os -from pathlib import Path -from typing import Callable, Dict, List, Optional, Union -import safetensors -import torch -import torch.nn as nn -from huggingface_hub import model_info -from huggingface_hub.constants import HF_HUB_OFFLINE - -from ..models.modeling_utils import ModelMixin, load_state_dict -from ..utils import ( - USE_PEFT_BACKEND, - _get_model_file, - convert_state_dict_to_diffusers, - convert_state_dict_to_peft, - delete_adapter_layers, - deprecate, - get_adapter_name, - get_peft_kwargs, - is_accelerate_available, - is_peft_available, - is_peft_version, - is_transformers_available, - is_transformers_version, - logging, - recurse_remove_peft_layers, - scale_lora_layers, - set_adapter_layers, - set_weights_and_activate_adapters, -) - - -if is_transformers_available(): - from transformers import PreTrainedModel - - from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules - -if is_peft_available(): - from peft.tuners.tuners_utils import BaseTunerLayer - -if is_accelerate_available(): - from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module - -logger = logging.get_logger(__name__) - -LORA_WEIGHT_NAME = "pytorch_lora_weights.bin" -LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors" +from ..utils import deprecate +from .lora.lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, LoraBaseMixin # noqa: F401 def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None): - """ - Fuses LoRAs for the text encoder. - - Args: - text_encoder (`torch.nn.Module`): - The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder` - attribute. - lora_scale (`float`, defaults to 1.0): - Controls how much to influence the outputs with the LoRA parameters. - safe_fusing (`bool`, defaults to `False`): - Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. - adapter_names (`List[str]` or `str`): - The names of the adapters to use. - """ - merge_kwargs = {"safe_merge": safe_fusing} + from .lora.lora_base import fuse_text_encoder_lora - for module in text_encoder.modules(): - if isinstance(module, BaseTunerLayer): - if lora_scale != 1.0: - module.scale_layer(lora_scale) + deprecation_message = "Importing `fuse_text_encoder_lora()` from diffusers.loaders.lora_base has been deprecated. Please use `from diffusers.loaders.lora.lora_base import fuse_text_encoder_lora` instead." + deprecate("diffusers.loaders.lora_base.fuse_text_encoder_lora", "0.36", deprecation_message) - # For BC with previous PEFT versions, we need to check the signature - # of the `merge` method to see if it supports the `adapter_names` argument. - supported_merge_kwargs = list(inspect.signature(module.merge).parameters) - if "adapter_names" in supported_merge_kwargs: - merge_kwargs["adapter_names"] = adapter_names - elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None: - raise ValueError( - "The `adapter_names` argument is not supported with your PEFT version. 
" - "Please upgrade to the latest version of PEFT. `pip install -U peft`" - ) - - module.merge(**merge_kwargs) + return fuse_text_encoder_lora( + text_encoder, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names + ) def unfuse_text_encoder_lora(text_encoder): - """ - Unfuses LoRAs for the text encoder. - - Args: - text_encoder (`torch.nn.Module`): - The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder` - attribute. - """ - for module in text_encoder.modules(): - if isinstance(module, BaseTunerLayer): - module.unmerge() - - -def set_adapters_for_text_encoder( - adapter_names: Union[List[str], str], - text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821 - text_encoder_weights: Optional[Union[float, List[float], List[None]]] = None, -): - """ - Sets the adapter layers for the text encoder. - - Args: - adapter_names (`List[str]` or `str`): - The names of the adapters to use. - text_encoder (`torch.nn.Module`, *optional*): - The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder` - attribute. - text_encoder_weights (`List[float]`, *optional*): - The weights to use for the text encoder. If `None`, the weights are set to `1.0` for all the adapters. - """ - if text_encoder is None: - raise ValueError( - "The pipeline does not have a default `pipe.text_encoder` class. Please make sure to pass a `text_encoder` instead." - ) - - def process_weights(adapter_names, weights): - # Expand weights into a list, one entry per adapter - # e.g. for 2 adapters: 7 -> [7,7] ; [3, None] -> [3, None] - if not isinstance(weights, list): - weights = [weights] * len(adapter_names) + from .lora.lora_base import unfuse_text_encoder_lora - if len(adapter_names) != len(weights): - raise ValueError( - f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}" - ) + deprecation_message = "Importing `unfuse_text_encoder_lora()` from diffusers.loaders.lora_base has been deprecated. Please use `from diffusers.loaders.lora.lora_base import unfuse_text_encoder_lora` instead." + deprecate("diffusers.loaders.lora_base.unfuse_text_encoder_lora", "0.36", deprecation_message) - # Set None values to default of 1.0 - # e.g. [7,7] -> [7,7] ; [3, None] -> [3,1] - weights = [w if w is not None else 1.0 for w in weights] + return unfuse_text_encoder_lora(text_encoder) - return weights - adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names - text_encoder_weights = process_weights(adapter_names, text_encoder_weights) - set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights) - - -def disable_lora_for_text_encoder(text_encoder: Optional["PreTrainedModel"] = None): - """ - Disables the LoRA layers for the text encoder. - - Args: - text_encoder (`torch.nn.Module`, *optional*): - The text encoder module to disable the LoRA layers for. If `None`, it will try to get the `text_encoder` - attribute. - """ - if text_encoder is None: - raise ValueError("Text Encoder not found.") - set_adapter_layers(text_encoder, enabled=False) - - -def enable_lora_for_text_encoder(text_encoder: Optional["PreTrainedModel"] = None): - """ - Enables the LoRA layers for the text encoder. - - Args: - text_encoder (`torch.nn.Module`, *optional*): - The text encoder module to enable the LoRA layers for. If `None`, it will try to get the `text_encoder` - attribute. 
- """ - if text_encoder is None: - raise ValueError("Text Encoder not found.") - set_adapter_layers(text_encoder, enabled=True) - - -def _remove_text_encoder_monkey_patch(text_encoder): - recurse_remove_peft_layers(text_encoder) - if getattr(text_encoder, "peft_config", None) is not None: - del text_encoder.peft_config - text_encoder._hf_peft_config_loaded = None - - -def _fetch_state_dict( - pretrained_model_name_or_path_or_dict, - weight_name, - use_safetensors, - local_files_only, - cache_dir, - force_download, - proxies, - token, - revision, - subfolder, - user_agent, - allow_pickle, -): - model_file = None - if not isinstance(pretrained_model_name_or_path_or_dict, dict): - # Let's first try to load .safetensors weights - if (use_safetensors and weight_name is None) or ( - weight_name is not None and weight_name.endswith(".safetensors") - ): - try: - # Here we're relaxing the loading check to enable more Inference API - # friendliness where sometimes, it's not at all possible to automatically - # determine `weight_name`. - if weight_name is None: - weight_name = _best_guess_weight_name( - pretrained_model_name_or_path_or_dict, - file_extension=".safetensors", - local_files_only=local_files_only, - ) - model_file = _get_model_file( - pretrained_model_name_or_path_or_dict, - weights_name=weight_name or LORA_WEIGHT_NAME_SAFE, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - ) - state_dict = safetensors.torch.load_file(model_file, device="cpu") - except (IOError, safetensors.SafetensorError) as e: - if not allow_pickle: - raise e - # try loading non-safetensors weights - model_file = None - pass - - if model_file is None: - if weight_name is None: - weight_name = _best_guess_weight_name( - pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only - ) - model_file = _get_model_file( - pretrained_model_name_or_path_or_dict, - weights_name=weight_name or LORA_WEIGHT_NAME, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - ) - state_dict = load_state_dict(model_file) - else: - state_dict = pretrained_model_name_or_path_or_dict - - return state_dict - - -def _best_guess_weight_name( - pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False +def set_adapters_for_text_encoder( + adapter_names, + text_encoder=None, + text_encoder_weights=None, ): - if local_files_only or HF_HUB_OFFLINE: - raise ValueError("When using the offline mode, you must specify a `weight_name`.") - - targeted_files = [] + from .lora.lora_base import set_adapters_for_text_encoder - if os.path.isfile(pretrained_model_name_or_path_or_dict): - return - elif os.path.isdir(pretrained_model_name_or_path_or_dict): - targeted_files = [f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)] - else: - files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings - targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)] - if len(targeted_files) == 0: - return + deprecation_message = "Importing `set_adapters_for_text_encoder()` from diffusers.loaders.lora_base has been deprecated. Please use `from diffusers.loaders.lora.lora_base import set_adapters_for_text_encoder` instead." 
+ deprecate("diffusers.loaders.lora_base.set_adapters_for_text_encoder", "0.36", deprecation_message) - # "scheduler" does not correspond to a LoRA checkpoint. - # "optimizer" does not correspond to a LoRA checkpoint - # only top-level checkpoints are considered and not the other ones, hence "checkpoint". - unallowed_substrings = {"scheduler", "optimizer", "checkpoint"} - targeted_files = list( - filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files) + return set_adapters_for_text_encoder( + adapter_names=adapter_names, text_encoder=text_encoder, text_encoder_weights=text_encoder_weights ) - if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files): - targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files)) - elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files): - targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files)) - - if len(targeted_files) > 1: - raise ValueError( - f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}." - ) - weight_name = targeted_files[0] - return weight_name - - -def _load_lora_into_text_encoder( - state_dict, - network_alphas, - text_encoder, - prefix=None, - lora_scale=1.0, - text_encoder_name="text_encoder", - adapter_name=None, - _pipeline=None, - low_cpu_mem_usage=False, - hotswap: bool = False, -): - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - peft_kwargs = {} - if low_cpu_mem_usage: - if not is_peft_version(">=", "0.13.1"): - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." - ) - if not is_transformers_version(">", "4.45.2"): - # Note from sayakpaul: It's not in `transformers` stable yet. - # https://github.com/huggingface/transformers/pull/33725/ - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`." - ) - peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage - - from peft import LoraConfig - - # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), - # then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as - # their prefixes. - prefix = text_encoder_name if prefix is None else prefix - - # Safe prefix to check with. - if hotswap and any(text_encoder_name in key for key in state_dict.keys()): - raise ValueError("At the moment, hotswapping is not supported for text encoders, please pass `hotswap=False`.") - - # Load the layers corresponding to text encoder and make necessary adjustments. 
- if prefix is not None: - state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")} - - if len(state_dict) > 0: - logger.info(f"Loading {prefix}.") - rank = {} - state_dict = convert_state_dict_to_diffusers(state_dict) - - # convert state dict - state_dict = convert_state_dict_to_peft(state_dict) - - for name, _ in text_encoder_attn_modules(text_encoder): - for module in ("out_proj", "q_proj", "k_proj", "v_proj"): - rank_key = f"{name}.{module}.lora_B.weight" - if rank_key not in state_dict: - continue - rank[rank_key] = state_dict[rank_key].shape[1] - - for name, _ in text_encoder_mlp_modules(text_encoder): - for module in ("fc1", "fc2"): - rank_key = f"{name}.{module}.lora_B.weight" - if rank_key not in state_dict: - continue - rank[rank_key] = state_dict[rank_key].shape[1] - - if network_alphas is not None: - alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix] - network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys} - - lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False) - - if "use_dora" in lora_config_kwargs: - if lora_config_kwargs["use_dora"]: - if is_peft_version("<", "0.9.0"): - raise ValueError( - "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." - ) - else: - if is_peft_version("<", "0.9.0"): - lora_config_kwargs.pop("use_dora") - - if "lora_bias" in lora_config_kwargs: - if lora_config_kwargs["lora_bias"]: - if is_peft_version("<=", "0.13.2"): - raise ValueError( - "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`." - ) - else: - if is_peft_version("<=", "0.13.2"): - lora_config_kwargs.pop("lora_bias") - - lora_config = LoraConfig(**lora_config_kwargs) - - # adapter_name - if adapter_name is None: - adapter_name = get_adapter_name(text_encoder) - - is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline) - - # inject LoRA layers and load the state dict - # in transformers we automatically check whether the adapter name is already in use or not - text_encoder.load_adapter( - adapter_name=adapter_name, - adapter_state_dict=state_dict, - peft_config=lora_config, - **peft_kwargs, - ) - - # scale LoRA layers with `lora_scale` - scale_lora_layers(text_encoder, weight=lora_scale) - - text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype) - - # Offload back. - if is_model_cpu_offload: - _pipeline.enable_model_cpu_offload() - elif is_sequential_cpu_offload: - _pipeline.enable_sequential_cpu_offload() - # Unsafe code /> - - if prefix is not None and not state_dict: - logger.warning( - f"No LoRA keys associated to {text_encoder.__class__.__name__} found with the {prefix=}. " - "This is safe to ignore if LoRA state dict didn't originally have any " - f"{text_encoder.__class__.__name__} related params. You can also try specifying `prefix=None` " - "to resolve the warning. 
Otherwise, open an issue if you think it's unexpected: " - "https://github.com/huggingface/diffusers/issues/new" - ) - - -def _func_optionally_disable_offloading(_pipeline): - is_model_cpu_offload = False - is_sequential_cpu_offload = False - - if _pipeline is not None and _pipeline.hf_device_map is None: - for _, component in _pipeline.components.items(): - if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"): - if not is_model_cpu_offload: - is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload) - if not is_sequential_cpu_offload: - is_sequential_cpu_offload = ( - isinstance(component._hf_hook, AlignDevicesHook) - or hasattr(component._hf_hook, "hooks") - and isinstance(component._hf_hook.hooks[0], AlignDevicesHook) - ) - - logger.info( - "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." - ) - remove_hook_from_module(component, recurse=is_sequential_cpu_offload) - - return (is_model_cpu_offload, is_sequential_cpu_offload) - - -class LoraBaseMixin: - """Utility class for handling LoRAs.""" - - _lora_loadable_modules = [] - num_fused_loras = 0 - - def load_lora_weights(self, **kwargs): - raise NotImplementedError("`load_lora_weights()` is not implemented.") - - @classmethod - def save_lora_weights(cls, **kwargs): - raise NotImplementedError("`save_lora_weights()` not implemented.") - - @classmethod - def lora_state_dict(cls, **kwargs): - raise NotImplementedError("`lora_state_dict()` is not implemented.") - - @classmethod - def _optionally_disable_offloading(cls, _pipeline): - """ - Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU. - - Args: - _pipeline (`DiffusionPipeline`): - The pipeline to disable offloading for. - - Returns: - tuple: - A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True. - """ - return _func_optionally_disable_offloading(_pipeline=_pipeline) - - @classmethod - def _fetch_state_dict(cls, *args, **kwargs): - deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`." - deprecate("_fetch_state_dict", "0.35.0", deprecation_message) - return _fetch_state_dict(*args, **kwargs) - - @classmethod - def _best_guess_weight_name(cls, *args, **kwargs): - deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`." - deprecate("_best_guess_weight_name", "0.35.0", deprecation_message) - return _best_guess_weight_name(*args, **kwargs) - - def unload_lora_weights(self): - """ - Unloads the LoRA parameters. - - Examples: - - ```python - >>> # Assuming `pipeline` is already loaded with the LoRA parameters. - >>> pipeline.unload_lora_weights() - >>> ... 
- ``` - """ - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - for component in self._lora_loadable_modules: - model = getattr(self, component, None) - if model is not None: - if issubclass(model.__class__, ModelMixin): - model.unload_lora() - elif issubclass(model.__class__, PreTrainedModel): - _remove_text_encoder_monkey_patch(model) - - def fuse_lora( - self, - components: List[str] = [], - lora_scale: float = 1.0, - safe_fusing: bool = False, - adapter_names: Optional[List[str]] = None, - **kwargs, - ): - r""" - Fuses the LoRA parameters into the original parameters of the corresponding blocks. - - - - This is an experimental API. - - - - Args: - components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. - lora_scale (`float`, defaults to 1.0): - Controls how much to influence the outputs with the LoRA parameters. - safe_fusing (`bool`, defaults to `False`): - Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. - adapter_names (`List[str]`, *optional*): - Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. - - Example: - - ```py - from diffusers import DiffusionPipeline - import torch - - pipeline = DiffusionPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 - ).to("cuda") - pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") - pipeline.fuse_lora(lora_scale=0.7) - ``` - """ - if "fuse_unet" in kwargs: - depr_message = "Passing `fuse_unet` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_unet` will be removed in a future version." - deprecate( - "fuse_unet", - "1.0.0", - depr_message, - ) - if "fuse_transformer" in kwargs: - depr_message = "Passing `fuse_transformer` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_transformer` will be removed in a future version." - deprecate( - "fuse_transformer", - "1.0.0", - depr_message, - ) - if "fuse_text_encoder" in kwargs: - depr_message = "Passing `fuse_text_encoder` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_text_encoder` will be removed in a future version." - deprecate( - "fuse_text_encoder", - "1.0.0", - depr_message, - ) - - if len(components) == 0: - raise ValueError("`components` cannot be an empty list.") - - for fuse_component in components: - if fuse_component not in self._lora_loadable_modules: - raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.") - - model = getattr(self, fuse_component, None) - if model is not None: - # check if diffusers model - if issubclass(model.__class__, ModelMixin): - model.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names) - # handle transformers models. 
- if issubclass(model.__class__, PreTrainedModel): - fuse_text_encoder_lora( - model, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names - ) - - self.num_fused_loras += 1 - - def unfuse_lora(self, components: List[str] = [], **kwargs): - r""" - Reverses the effect of - [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). - - - - This is an experimental API. - - - - Args: - components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. - unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. - unfuse_text_encoder (`bool`, defaults to `True`): - Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the - LoRA parameters then it won't have any effect. - """ - if "unfuse_unet" in kwargs: - depr_message = "Passing `unfuse_unet` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_unet` will be removed in a future version." - deprecate( - "unfuse_unet", - "1.0.0", - depr_message, - ) - if "unfuse_transformer" in kwargs: - depr_message = "Passing `unfuse_transformer` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_transformer` will be removed in a future version." - deprecate( - "unfuse_transformer", - "1.0.0", - depr_message, - ) - if "unfuse_text_encoder" in kwargs: - depr_message = "Passing `unfuse_text_encoder` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_text_encoder` will be removed in a future version." - deprecate( - "unfuse_text_encoder", - "1.0.0", - depr_message, - ) - - if len(components) == 0: - raise ValueError("`components` cannot be an empty list.") - - for fuse_component in components: - if fuse_component not in self._lora_loadable_modules: - raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.") - - model = getattr(self, fuse_component, None) - if model is not None: - if issubclass(model.__class__, (ModelMixin, PreTrainedModel)): - for module in model.modules(): - if isinstance(module, BaseTunerLayer): - module.unmerge() - - self.num_fused_loras -= 1 - - def set_adapters( - self, - adapter_names: Union[List[str], str], - adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None, - ): - if isinstance(adapter_weights, dict): - components_passed = set(adapter_weights.keys()) - lora_components = set(self._lora_loadable_modules) - - invalid_components = sorted(components_passed - lora_components) - if invalid_components: - logger.warning( - f"The following components in `adapter_weights` are not part of the pipeline: {invalid_components}. " - f"Available components that are LoRA-compatible: {self._lora_loadable_modules}. So, weights belonging " - "to the invalid components will be removed and ignored." 
- ) - adapter_weights = {k: v for k, v in adapter_weights.items() if k not in invalid_components} - - adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names - adapter_weights = copy.deepcopy(adapter_weights) - - # Expand weights into a list, one entry per adapter - if not isinstance(adapter_weights, list): - adapter_weights = [adapter_weights] * len(adapter_names) - - if len(adapter_names) != len(adapter_weights): - raise ValueError( - f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(adapter_weights)}" - ) - - list_adapters = self.get_list_adapters() # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]} - # eg ["adapter1", "adapter2"] - all_adapters = {adapter for adapters in list_adapters.values() for adapter in adapters} - missing_adapters = set(adapter_names) - all_adapters - if len(missing_adapters) > 0: - raise ValueError( - f"Adapter name(s) {missing_adapters} not in the list of present adapters: {all_adapters}." - ) - - # eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]} - invert_list_adapters = { - adapter: [part for part, adapters in list_adapters.items() if adapter in adapters] - for adapter in all_adapters - } - - # Decompose weights into weights for denoiser and text encoders. - _component_adapter_weights = {} - for component in self._lora_loadable_modules: - model = getattr(self, component) - - for adapter_name, weights in zip(adapter_names, adapter_weights): - if isinstance(weights, dict): - component_adapter_weights = weights.pop(component, None) - if component_adapter_weights is not None and component not in invert_list_adapters[adapter_name]: - logger.warning( - ( - f"Lora weight dict for adapter '{adapter_name}' contains {component}," - f"but this will be ignored because {adapter_name} does not contain weights for {component}." - f"Valid parts for {adapter_name} are: {invert_list_adapters[adapter_name]}." - ) - ) - - else: - component_adapter_weights = weights - - _component_adapter_weights.setdefault(component, []) - _component_adapter_weights[component].append(component_adapter_weights) - - if issubclass(model.__class__, ModelMixin): - model.set_adapters(adapter_names, _component_adapter_weights[component]) - elif issubclass(model.__class__, PreTrainedModel): - set_adapters_for_text_encoder(adapter_names, model, _component_adapter_weights[component]) - - def disable_lora(self): - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - for component in self._lora_loadable_modules: - model = getattr(self, component, None) - if model is not None: - if issubclass(model.__class__, ModelMixin): - model.disable_lora() - elif issubclass(model.__class__, PreTrainedModel): - disable_lora_for_text_encoder(model) - - def enable_lora(self): - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - for component in self._lora_loadable_modules: - model = getattr(self, component, None) - if model is not None: - if issubclass(model.__class__, ModelMixin): - model.enable_lora() - elif issubclass(model.__class__, PreTrainedModel): - enable_lora_for_text_encoder(model) - - def delete_adapters(self, adapter_names: Union[List[str], str]): - """ - Args: - Deletes the LoRA layers of `adapter_name` for the unet and text-encoder(s). - adapter_names (`Union[List[str], str]`): - The names of the adapter to delete. 
Can be a single string or a list of strings - """ - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - if isinstance(adapter_names, str): - adapter_names = [adapter_names] - - for component in self._lora_loadable_modules: - model = getattr(self, component, None) - if model is not None: - if issubclass(model.__class__, ModelMixin): - model.delete_adapters(adapter_names) - elif issubclass(model.__class__, PreTrainedModel): - for adapter_name in adapter_names: - delete_adapter_layers(model, adapter_name) - - def get_active_adapters(self) -> List[str]: - """ - Gets the list of the current active adapters. - - Example: - - ```python - from diffusers import DiffusionPipeline - - pipeline = DiffusionPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", - ).to("cuda") - pipeline.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy") - pipeline.get_active_adapters() - ``` - """ - if not USE_PEFT_BACKEND: - raise ValueError( - "PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`" - ) - - active_adapters = [] - - for component in self._lora_loadable_modules: - model = getattr(self, component, None) - if model is not None and issubclass(model.__class__, ModelMixin): - for module in model.modules(): - if isinstance(module, BaseTunerLayer): - active_adapters = module.active_adapters - break - - return active_adapters - - def get_list_adapters(self) -> Dict[str, List[str]]: - """ - Gets the current list of all available adapters in the pipeline. - """ - if not USE_PEFT_BACKEND: - raise ValueError( - "PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`" - ) - - set_adapters = {} - - for component in self._lora_loadable_modules: - model = getattr(self, component, None) - if ( - model is not None - and issubclass(model.__class__, (ModelMixin, PreTrainedModel)) - and hasattr(model, "peft_config") - ): - set_adapters[component] = list(model.peft_config.keys()) - - return set_adapters - - def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, str, int]) -> None: - """ - Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case - you want to load multiple adapters and free some GPU memory. - - Args: - adapter_names (`List[str]`): - List of adapters to send device to. - device (`Union[torch.device, str, int]`): - Device to send the adapters to. Can be either a torch device, a str or an integer. 
- """ - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - for component in self._lora_loadable_modules: - model = getattr(self, component, None) - if model is not None: - for module in model.modules(): - if isinstance(module, BaseTunerLayer): - for adapter_name in adapter_names: - module.lora_A[adapter_name].to(device) - module.lora_B[adapter_name].to(device) - # this is a param, not a module, so device placement is not in-place -> re-assign - if hasattr(module, "lora_magnitude_vector") and module.lora_magnitude_vector is not None: - if adapter_name in module.lora_magnitude_vector: - module.lora_magnitude_vector[adapter_name] = module.lora_magnitude_vector[ - adapter_name - ].to(device) - - @staticmethod - def pack_weights(layers, prefix): - layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers - layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()} - return layers_state_dict - - @staticmethod - def write_lora_layers( - state_dict: Dict[str, torch.Tensor], - save_directory: str, - is_main_process: bool, - weight_name: str, - save_function: Callable, - safe_serialization: bool, - ): - if os.path.isfile(save_directory): - logger.error(f"Provided path ({save_directory}) should be a directory, not a file") - return - - if save_function is None: - if safe_serialization: - def save_function(weights, filename): - return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"}) +def disable_lora_for_text_encoder(text_encoder=None): + from .lora.lora_base import disable_lora_for_text_encoder - else: - save_function = torch.save + deprecation_message = "Importing `disable_lora_for_text_encoder()` from diffusers.loaders.lora_base has been deprecated. Please use `from diffusers.loaders.lora.lora_base import disable_lora_for_text_encoder` instead." + deprecate("diffusers.loaders.lora_base.disable_lora_for_text_encoder", "0.36", deprecation_message) - os.makedirs(save_directory, exist_ok=True) + return disable_lora_for_text_encoder(text_encoder=text_encoder) - if weight_name is None: - if safe_serialization: - weight_name = LORA_WEIGHT_NAME_SAFE - else: - weight_name = LORA_WEIGHT_NAME - save_path = Path(save_directory, weight_name).as_posix() - save_function(state_dict, save_path) - logger.info(f"Model weights saved in {save_path}") +def enable_lora_for_text_encoder(text_encoder=None): + from .lora.lora_base import enable_lora_for_text_encoder - @property - def lora_scale(self) -> float: - # property function that returns the lora scale which can be set at run time by the pipeline. - # if _lora_scale has not been set, return 1 - return self._lora_scale if hasattr(self, "_lora_scale") else 1.0 + deprecation_message = "Importing `enable_lora_for_text_encoder()` from diffusers.loaders.lora_base has been deprecated. Please use `from diffusers.loaders.lora.lora_base import enable_lora_for_text_encoder` instead." + deprecate("diffusers.loaders.lora_base.enable_lora_for_text_encoder", "0.36", deprecation_message) - def enable_lora_hotswap(self, **kwargs) -> None: - """Enables the possibility to hotswap LoRA adapters. + return enable_lora_for_text_encoder(text_encoder=text_encoder) - Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of - the loaded adapters differ. - Args: - target_rank (`int`): - The highest rank among all the adapters that will be loaded. 
- check_compiled (`str`, *optional*, defaults to `"error"`): - How to handle the case when the model is already compiled, which should generally be avoided. The - options are: - - "error" (default): raise an error - - "warn": issue a warning - - "ignore": do nothing - """ - for key, component in self.components.items(): - if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules): - component.enable_lora_hotswap(**kwargs) +class LoraBaseMixin(LoraBaseMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `LoraBaseMixin` from diffusers.loaders.lora_base has been deprecated. Please use `from diffusers.loaders.lora.lora_base import LoraBaseMixin` instead." + deprecate("diffusers.loaders.lora_base.LoraBaseMixin", "0.36", deprecation_message) + super().__init__(*args, **kwargs) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 50a99cee1d23..1494e6e0b82a 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -12,5672 +12,133 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -from typing import Callable, Dict, List, Optional, Union -import torch -from huggingface_hub.utils import validate_hf_hub_args - -from ..utils import ( - USE_PEFT_BACKEND, - deprecate, - get_submodule_by_name, - is_bitsandbytes_available, - is_gguf_available, - is_peft_available, - is_peft_version, - is_torch_version, - is_transformers_available, - is_transformers_version, - logging, -) -from .lora_base import ( # noqa - LORA_WEIGHT_NAME, - LORA_WEIGHT_NAME_SAFE, - LoraBaseMixin, - _fetch_state_dict, - _load_lora_into_text_encoder, +from ..utils import deprecate +from .lora.lora_pipeline import ( + TEXT_ENCODER_NAME, # noqa: F401 + TRANSFORMER_NAME, # noqa: F401 + UNET_NAME, # noqa: F401 + AmusedLoraLoaderMixin, + AuraFlowLoraLoaderMixin, + CogVideoXLoraLoaderMixin, + CogView4LoraLoaderMixin, + FluxLoraLoaderMixin, + HiDreamImageLoraLoaderMixin, + HunyuanVideoLoraLoaderMixin, + LTXVideoLoraLoaderMixin, + Lumina2LoraLoaderMixin, + Mochi1LoraLoaderMixin, + SanaLoraLoaderMixin, + SD3LoraLoaderMixin, + StableDiffusionLoraLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + WanLoraLoaderMixin, ) -from .lora_conversion_utils import ( - _convert_bfl_flux_control_lora_to_diffusers, - _convert_hunyuan_video_lora_to_diffusers, - _convert_kohya_flux_lora_to_diffusers, - _convert_musubi_wan_lora_to_diffusers, - _convert_non_diffusers_lora_to_diffusers, - _convert_non_diffusers_lumina2_lora_to_diffusers, - _convert_non_diffusers_wan_lora_to_diffusers, - _convert_xlabs_flux_lora_to_diffusers, - _maybe_map_sgm_blocks_to_diffusers, -) - - -_LOW_CPU_MEM_USAGE_DEFAULT_LORA = False -if is_torch_version(">=", "1.9.0"): - if ( - is_peft_available() - and is_peft_version(">=", "0.13.1") - and is_transformers_available() - and is_transformers_version(">", "4.45.2") - ): - _LOW_CPU_MEM_USAGE_DEFAULT_LORA = True - - -logger = logging.get_logger(__name__) - -TEXT_ENCODER_NAME = "text_encoder" -UNET_NAME = "unet" -TRANSFORMER_NAME = "transformer" - -_MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX = {"x_embedder": "in_channels"} - - -def _maybe_dequantize_weight_for_expanded_lora(model, module): - if is_bitsandbytes_available(): - from ..quantizers.bitsandbytes import dequantize_bnb_weight - - if is_gguf_available(): - from ..quantizers.gguf.utils import dequantize_gguf_tensor - - is_bnb_4bit_quantized = module.weight.__class__.__name__ == "Params4bit" - 
is_gguf_quantized = module.weight.__class__.__name__ == "GGUFParameter" - - if is_bnb_4bit_quantized and not is_bitsandbytes_available(): - raise ValueError( - "The checkpoint seems to have been quantized with `bitsandbytes` (4bits). Install `bitsandbytes` to load quantized checkpoints." - ) - if is_gguf_quantized and not is_gguf_available(): - raise ValueError( - "The checkpoint seems to have been quantized with `gguf`. Install `gguf` to load quantized checkpoints." - ) - - weight_on_cpu = False - if module.weight.device.type == "cpu": - weight_on_cpu = True - - device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda" - if is_bnb_4bit_quantized: - module_weight = dequantize_bnb_weight( - module.weight.to(device) if weight_on_cpu else module.weight, - state=module.weight.quant_state, - dtype=model.dtype, - ).data - elif is_gguf_quantized: - module_weight = dequantize_gguf_tensor( - module.weight.to(device) if weight_on_cpu else module.weight, - ) - module_weight = module_weight.to(model.dtype) - else: - module_weight = module.weight.data - - if weight_on_cpu: - module_weight = module_weight.cpu() - - return module_weight - - -class StableDiffusionLoraLoaderMixin(LoraBaseMixin): - r""" - Load LoRA layers into Stable Diffusion [`UNet2DConditionModel`] and - [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel). - """ - - _lora_loadable_modules = ["unet", "text_encoder"] - unet_name = UNET_NAME - text_encoder_name = TEXT_ENCODER_NAME - - def load_lora_weights( - self, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - adapter_name: Optional[str] = None, - hotswap: bool = False, - **kwargs, - ): - """Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and - `self.text_encoder`. - - All kwargs are forwarded to `self.lora_state_dict`. - - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is - loaded. - - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is - loaded into `self.unet`. - - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state - dict is loaded into `self.text_encoder`. - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random - weights. - hotswap (`bool`, *optional*): - Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter - in-place. This means that, instead of loading an additional adapter, this will take the existing - adapter weights and replace them with the weights of the new adapter. This can be faster and more - memory efficient. However, the main advantage of hotswapping is that when the model is compiled with - torch.compile, loading the new adapter does not require recompilation of the model. When using - hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. 
- - If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need - to call an additional method before loading the adapter: - - ```py - pipeline = ... # load diffusers pipeline - max_rank = ... # the highest rank among all LoRAs that you want to load - # call *before* compiling and loading the LoRA adapter - pipeline.enable_lora_hotswap(target_rank=max_rank) - pipeline.load_lora_weights(file_name) - # optionally compile the model now - ``` - - Note that hotswapping adapters of the text encoder is not yet supported. There are some further - limitations to this technique, which are documented here: - https://huggingface.co/docs/peft/main/en/package_reference/hotswap - kwargs (`dict`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - """ - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) - if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." - ) - - # if a dict is passed, copy it instead of modifying it inplace - if isinstance(pretrained_model_name_or_path_or_dict, dict): - pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() - - # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) - - is_correct_format = all("lora" in key for key in state_dict.keys()) - if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") - - self.load_lora_into_unet( - state_dict, - network_alphas=network_alphas, - unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet, - adapter_name=adapter_name, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - self.load_lora_into_text_encoder( - state_dict, - network_alphas=network_alphas, - text_encoder=getattr(self, self.text_encoder_name) - if not hasattr(self, "text_encoder") - else self.text_encoder, - lora_scale=self.lora_scale, - adapter_name=adapter_name, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - - @classmethod - @validate_hf_hub_args - def lora_state_dict( - cls, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - **kwargs, - ): - r""" - Return state dict for lora weights and the network alphas. - - - - We support loading A1111 formatted LoRA checkpoints in a limited capacity. - - This function is experimental and might change in the future. - - - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - Can be either: - - - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on - the Hub. - - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved - with [`ModelMixin.save_pretrained`]. - - A [torch state - dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). - - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory where a downloaded pretrained model configuration is cached if the standard cache - is not used. 
- force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - local_files_only (`bool`, *optional*, defaults to `False`): - Whether to only load local model weights and configuration files or not. If set to `True`, the model - won't be downloaded from the Hub. - token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from - `diffusers-cli login` (stored in `~/.huggingface`) is used. - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier - allowed by Git. - subfolder (`str`, *optional*, defaults to `""`): - The subfolder location of a model file within a larger model repository on the Hub or locally. - weight_name (`str`, *optional*, defaults to None): - Name of the serialized state dict file. - """ - # Load the main state dict first which has the LoRA layers for either of - # UNet and text encoder or both. - cache_dir = kwargs.pop("cache_dir", None) - force_download = kwargs.pop("force_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", None) - token = kwargs.pop("token", None) - revision = kwargs.pop("revision", None) - subfolder = kwargs.pop("subfolder", None) - weight_name = kwargs.pop("weight_name", None) - unet_config = kwargs.pop("unet_config", None) - use_safetensors = kwargs.pop("use_safetensors", None) - - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle = True - - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } - - state_dict = _fetch_state_dict( - pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, - weight_name=weight_name, - use_safetensors=use_safetensors, - local_files_only=local_files_only, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - allow_pickle=allow_pickle, - ) - is_dora_scale_present = any("dora_scale" in k for k in state_dict) - if is_dora_scale_present: - warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." - logger.warning(warn_msg) - state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - - network_alphas = None - # TODO: replace it with a method from `state_dict_utils` - if all( - ( - k.startswith("lora_te_") - or k.startswith("lora_unet_") - or k.startswith("lora_te1_") - or k.startswith("lora_te2_") - ) - for k in state_dict.keys() - ): - # Map SDXL blocks correctly. 
- if unet_config is not None: - # use unet config to remap block numbers - state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config) - state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict) - - return state_dict, network_alphas - - @classmethod - def load_lora_into_unet( - cls, - state_dict, - network_alphas, - unet, - adapter_name=None, - _pipeline=None, - low_cpu_mem_usage=False, - hotswap: bool = False, - ): - """ - This will load the LoRA layers specified in `state_dict` into `unet`. - - Parameters: - state_dict (`dict`): - A standard state dict containing the lora layer parameters. The keys can either be indexed directly - into the unet or prefixed with an additional `unet` which can be used to distinguish between text - encoder lora layers. - network_alphas (`Dict[str, float]`): - The value of the network alpha used for stable learning and preventing underflow. This value has the - same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this - link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). - unet (`UNet2DConditionModel`): - The UNet model to load the LoRA layers into. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading only loading the pretrained LoRA weights and not initializing the random - weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - """ - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." - ) - - # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), - # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as - # their prefixes. - logger.info(f"Loading {cls.unet_name}.") - unet.load_lora_adapter( - state_dict, - prefix=cls.unet_name, - network_alphas=network_alphas, - adapter_name=adapter_name, - _pipeline=_pipeline, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - - @classmethod - def load_lora_into_text_encoder( - cls, - state_dict, - network_alphas, - text_encoder, - prefix=None, - lora_scale=1.0, - adapter_name=None, - _pipeline=None, - low_cpu_mem_usage=False, - hotswap: bool = False, - ): - """ - This will load the LoRA layers specified in `state_dict` into `text_encoder` - - Parameters: - state_dict (`dict`): - A standard state dict containing the lora layer parameters. The key should be prefixed with an - additional `text_encoder` to distinguish between unet lora layers. - network_alphas (`Dict[str, float]`): - The value of the network alpha used for stable learning and preventing underflow. This value has the - same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this - link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). - text_encoder (`CLIPTextModel`): - The text encoder model to load the LoRA layers into. - prefix (`str`): - Expected prefix of the `text_encoder` in the `state_dict`. 
- lora_scale (`float`): - How much to scale the output of the lora linear layer before it is added with the output of the regular - lora layer. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random - weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - """ - _load_lora_into_text_encoder( - state_dict=state_dict, - network_alphas=network_alphas, - lora_scale=lora_scale, - text_encoder=text_encoder, - prefix=prefix, - text_encoder_name=cls.text_encoder_name, - adapter_name=adapter_name, - _pipeline=_pipeline, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - - @classmethod - def save_lora_weights( - cls, - save_directory: Union[str, os.PathLike], - unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, - text_encoder_lora_layers: Dict[str, torch.nn.Module] = None, - is_main_process: bool = True, - weight_name: str = None, - save_function: Callable = None, - safe_serialization: bool = True, - ): - r""" - Save the LoRA parameters corresponding to the UNet and text encoder. - - Arguments: - save_directory (`str` or `os.PathLike`): - Directory to save LoRA parameters to. Will be created if it doesn't exist. - unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): - State dict of the LoRA layers corresponding to the `unet`. - text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): - State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text - encoder LoRA state dict because it comes from 🤗 Transformers. - is_main_process (`bool`, *optional*, defaults to `True`): - Whether the process calling this is the main process or not. Useful during distributed training and you - need to call this function on all processes. In this case, set `is_main_process=True` only on the main - process to avoid race conditions. - save_function (`Callable`): - The function to use to save the state dictionary. Useful during distributed training when you need to - replace `torch.save` with another method. Can be configured with the environment variable - `DIFFUSERS_SAVE_MODE`. - safe_serialization (`bool`, *optional*, defaults to `True`): - Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. - """ - state_dict = {} - - if not (unet_lora_layers or text_encoder_lora_layers): - raise ValueError("You must pass at least one of `unet_lora_layers` and `text_encoder_lora_layers`.") - - if unet_lora_layers: - state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name)) - - if text_encoder_lora_layers: - state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) - - # Save the model - cls.write_lora_layers( - state_dict=state_dict, - save_directory=save_directory, - is_main_process=is_main_process, - weight_name=weight_name, - save_function=save_function, - safe_serialization=safe_serialization, - ) - - def fuse_lora( - self, - components: List[str] = ["unet", "text_encoder"], - lora_scale: float = 1.0, - safe_fusing: bool = False, - adapter_names: Optional[List[str]] = None, - **kwargs, - ): - r""" - Fuses the LoRA parameters into the original parameters of the corresponding blocks. 
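A sketch of the `save_lora_weights` path above, assuming LoRA adapters were attached with `peft` (the `unet` variable and the output directory are stand-ins; `get_peft_model_state_dict` and `convert_state_dict_to_diffusers` are the helpers the official training scripts use):

```py
from peft.utils import get_peft_model_state_dict
from diffusers import StableDiffusionPipeline
from diffusers.utils import convert_state_dict_to_diffusers

# `unet` is assumed to be a UNet2DConditionModel with PEFT LoRA layers injected.
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))

StableDiffusionPipeline.save_lora_weights(
    save_directory="./my-sd15-lora",        # created if it doesn't exist
    unet_lora_layers=unet_lora_state_dict,  # at least one of the two layer dicts is required
    safe_serialization=True,                # writes a .safetensors file
)
```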
- - - - This is an experimental API. - - - - Args: - components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. - lora_scale (`float`, defaults to 1.0): - Controls how much to influence the outputs with the LoRA parameters. - safe_fusing (`bool`, defaults to `False`): - Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. - adapter_names (`List[str]`, *optional*): - Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. - - Example: - - ```py - from diffusers import DiffusionPipeline - import torch - - pipeline = DiffusionPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 - ).to("cuda") - pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") - pipeline.fuse_lora(lora_scale=0.7) - ``` - """ - super().fuse_lora( - components=components, - lora_scale=lora_scale, - safe_fusing=safe_fusing, - adapter_names=adapter_names, - **kwargs, - ) - - def unfuse_lora(self, components: List[str] = ["unet", "text_encoder"], **kwargs): - r""" - Reverses the effect of - [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). - - - - This is an experimental API. - - - - Args: - components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. - unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. - unfuse_text_encoder (`bool`, defaults to `True`): - Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the - LoRA parameters then it won't have any effect. - """ - super().unfuse_lora(components=components, **kwargs) - - -class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin): - r""" - Load LoRA layers into Stable Diffusion XL [`UNet2DConditionModel`], - [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), and - [`CLIPTextModelWithProjection`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection). - """ - - _lora_loadable_modules = ["unet", "text_encoder", "text_encoder_2"] - unet_name = UNET_NAME - text_encoder_name = TEXT_ENCODER_NAME - - def load_lora_weights( - self, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - adapter_name: Optional[str] = None, - hotswap: bool = False, - **kwargs, - ): - """ - Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and - `self.text_encoder`. - - All kwargs are forwarded to `self.lora_state_dict`. - - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is - loaded. - - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is - loaded into `self.unet`. - - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state - dict is loaded into `self.text_encoder`. - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. 
- low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random - weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - kwargs (`dict`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - """ - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) - if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." - ) - - # We could have accessed the unet config from `lora_state_dict()` too. We pass - # it here explicitly to be able to tell that it's coming from an SDXL - # pipeline. - - # if a dict is passed, copy it instead of modifying it inplace - if isinstance(pretrained_model_name_or_path_or_dict, dict): - pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() - - # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict, network_alphas = self.lora_state_dict( - pretrained_model_name_or_path_or_dict, - unet_config=self.unet.config, - **kwargs, - ) - - is_correct_format = all("lora" in key for key in state_dict.keys()) - if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") - - self.load_lora_into_unet( - state_dict, - network_alphas=network_alphas, - unet=self.unet, - adapter_name=adapter_name, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - self.load_lora_into_text_encoder( - state_dict, - network_alphas=network_alphas, - text_encoder=self.text_encoder, - prefix=self.text_encoder_name, - lora_scale=self.lora_scale, - adapter_name=adapter_name, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - self.load_lora_into_text_encoder( - state_dict, - network_alphas=network_alphas, - text_encoder=self.text_encoder_2, - prefix=f"{self.text_encoder_name}_2", - lora_scale=self.lora_scale, - adapter_name=adapter_name, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - - @classmethod - @validate_hf_hub_args - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.lora_state_dict - def lora_state_dict( - cls, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - **kwargs, - ): - r""" - Return state dict for lora weights and the network alphas. - - - - We support loading A1111 formatted LoRA checkpoints in a limited capacity. - - This function is experimental and might change in the future. - - - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - Can be either: - - - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on - the Hub. - - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved - with [`ModelMixin.save_pretrained`]. - - A [torch state - dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). - - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory where a downloaded pretrained model configuration is cached if the standard cache - is not used. 
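Since `lora_state_dict` is a classmethod, a checkpoint can be inspected without building a pipeline first; a small sketch reusing the checkpoint from the `fuse_lora` examples:

```py
from diffusers import StableDiffusionXLPipeline

# Downloads and converts the checkpoint only; no model weights are materialized.
state_dict, network_alphas = StableDiffusionXLPipeline.lora_state_dict(
    "nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors"
)
print(len(state_dict), "LoRA tensors")
print(sorted({key.split(".")[0] for key in state_dict}))  # top-level prefixes, e.g. "unet"
```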
- force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - local_files_only (`bool`, *optional*, defaults to `False`): - Whether to only load local model weights and configuration files or not. If set to `True`, the model - won't be downloaded from the Hub. - token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from - `diffusers-cli login` (stored in `~/.huggingface`) is used. - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier - allowed by Git. - subfolder (`str`, *optional*, defaults to `""`): - The subfolder location of a model file within a larger model repository on the Hub or locally. - weight_name (`str`, *optional*, defaults to None): - Name of the serialized state dict file. - """ - # Load the main state dict first which has the LoRA layers for either of - # UNet and text encoder or both. - cache_dir = kwargs.pop("cache_dir", None) - force_download = kwargs.pop("force_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", None) - token = kwargs.pop("token", None) - revision = kwargs.pop("revision", None) - subfolder = kwargs.pop("subfolder", None) - weight_name = kwargs.pop("weight_name", None) - unet_config = kwargs.pop("unet_config", None) - use_safetensors = kwargs.pop("use_safetensors", None) - - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle = True - - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } - - state_dict = _fetch_state_dict( - pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, - weight_name=weight_name, - use_safetensors=use_safetensors, - local_files_only=local_files_only, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - allow_pickle=allow_pickle, - ) - is_dora_scale_present = any("dora_scale" in k for k in state_dict) - if is_dora_scale_present: - warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." - logger.warning(warn_msg) - state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - - network_alphas = None - # TODO: replace it with a method from `state_dict_utils` - if all( - ( - k.startswith("lora_te_") - or k.startswith("lora_unet_") - or k.startswith("lora_te1_") - or k.startswith("lora_te2_") - ) - for k in state_dict.keys() - ): - # Map SDXL blocks correctly. 
- if unet_config is not None: - # use unet config to remap block numbers - state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config) - state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict) - - return state_dict, network_alphas - - @classmethod - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_unet - def load_lora_into_unet( - cls, - state_dict, - network_alphas, - unet, - adapter_name=None, - _pipeline=None, - low_cpu_mem_usage=False, - hotswap: bool = False, - ): - """ - This will load the LoRA layers specified in `state_dict` into `unet`. - - Parameters: - state_dict (`dict`): - A standard state dict containing the lora layer parameters. The keys can either be indexed directly - into the unet or prefixed with an additional `unet` which can be used to distinguish between text - encoder lora layers. - network_alphas (`Dict[str, float]`): - The value of the network alpha used for stable learning and preventing underflow. This value has the - same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this - link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). - unet (`UNet2DConditionModel`): - The UNet model to load the LoRA layers into. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random - weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - """ - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." - ) - - # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), - # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as - # their prefixes. - logger.info(f"Loading {cls.unet_name}.") - unet.load_lora_adapter( - state_dict, - prefix=cls.unet_name, - network_alphas=network_alphas, - adapter_name=adapter_name, - _pipeline=_pipeline, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - - @classmethod - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder - def load_lora_into_text_encoder( - cls, - state_dict, - network_alphas, - text_encoder, - prefix=None, - lora_scale=1.0, - adapter_name=None, - _pipeline=None, - low_cpu_mem_usage=False, - hotswap: bool = False, - ): - """ - This will load the LoRA layers specified in `state_dict` into `text_encoder`. - - Parameters: - state_dict (`dict`): - A standard state dict containing the lora layer parameters. The key should be prefixed with an - additional `text_encoder` to distinguish between unet lora layers. - network_alphas (`Dict[str, float]`): - The value of the network alpha used for stable learning and preventing underflow. This value has the - same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this - link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
- text_encoder (`CLIPTextModel`): - The text encoder model to load the LoRA layers into. - prefix (`str`): - Expected prefix of the `text_encoder` in the `state_dict`. - lora_scale (`float`): - How much to scale the output of the lora linear layer before it is added with the output of the regular - lora layer. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random - weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - """ - _load_lora_into_text_encoder( - state_dict=state_dict, - network_alphas=network_alphas, - lora_scale=lora_scale, - text_encoder=text_encoder, - prefix=prefix, - text_encoder_name=cls.text_encoder_name, - adapter_name=adapter_name, - _pipeline=_pipeline, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - - @classmethod - def save_lora_weights( - cls, - save_directory: Union[str, os.PathLike], - unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, - text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, - text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, - is_main_process: bool = True, - weight_name: str = None, - save_function: Callable = None, - safe_serialization: bool = True, - ): - r""" - Save the LoRA parameters corresponding to the UNet and text encoder. - - Arguments: - save_directory (`str` or `os.PathLike`): - Directory to save LoRA parameters to. Will be created if it doesn't exist. - unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): - State dict of the LoRA layers corresponding to the `unet`. - text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): - State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text - encoder LoRA state dict because it comes from 🤗 Transformers. - text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): - State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text - encoder LoRA state dict because it comes from 🤗 Transformers. - is_main_process (`bool`, *optional*, defaults to `True`): - Whether the process calling this is the main process or not. Useful during distributed training and you - need to call this function on all processes. In this case, set `is_main_process=True` only on the main - process to avoid race conditions. - save_function (`Callable`): - The function to use to save the state dictionary. Useful during distributed training when you need to - replace `torch.save` with another method. Can be configured with the environment variable - `DIFFUSERS_SAVE_MODE`. - safe_serialization (`bool`, *optional*, defaults to `True`): - Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. - """ - state_dict = {} - - if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers): - raise ValueError( - "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`." 
- ) - - if unet_lora_layers: - state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name)) - - if text_encoder_lora_layers: - state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder")) - - if text_encoder_2_lora_layers: - state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) - - cls.write_lora_layers( - state_dict=state_dict, - save_directory=save_directory, - is_main_process=is_main_process, - weight_name=weight_name, - save_function=save_function, - safe_serialization=safe_serialization, - ) - - def fuse_lora( - self, - components: List[str] = ["unet", "text_encoder", "text_encoder_2"], - lora_scale: float = 1.0, - safe_fusing: bool = False, - adapter_names: Optional[List[str]] = None, - **kwargs, - ): - r""" - Fuses the LoRA parameters into the original parameters of the corresponding blocks. - - - - This is an experimental API. - - - - Args: - components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. - lora_scale (`float`, defaults to 1.0): - Controls how much to influence the outputs with the LoRA parameters. - safe_fusing (`bool`, defaults to `False`): - Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. - adapter_names (`List[str]`, *optional*): - Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. - - Example: - - ```py - from diffusers import DiffusionPipeline - import torch - - pipeline = DiffusionPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 - ).to("cuda") - pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") - pipeline.fuse_lora(lora_scale=0.7) - ``` - """ - super().fuse_lora( - components=components, - lora_scale=lora_scale, - safe_fusing=safe_fusing, - adapter_names=adapter_names, - **kwargs, - ) - - def unfuse_lora(self, components: List[str] = ["unet", "text_encoder", "text_encoder_2"], **kwargs): - r""" - Reverses the effect of - [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). - - - - This is an experimental API. - - - Args: - components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. - unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. - unfuse_text_encoder (`bool`, defaults to `True`): - Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the - LoRA parameters then it won't have any effect. - """ - super().unfuse_lora(components=components, **kwargs) - - -class SD3LoraLoaderMixin(LoraBaseMixin): - r""" - Load LoRA layers into [`SD3Transformer2DModel`], - [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), and - [`CLIPTextModelWithProjection`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection). - - Specific to [`StableDiffusion3Pipeline`]. - """ - - _lora_loadable_modules = ["transformer", "text_encoder", "text_encoder_2"] - transformer_name = TRANSFORMER_NAME - text_encoder_name = TEXT_ENCODER_NAME - - @classmethod - @validate_hf_hub_args - def lora_state_dict( - cls, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - **kwargs, - ): - r""" - Return state dict for lora weights and the network alphas. - - - - We support loading A1111 formatted LoRA checkpoints in a limited capacity. 
- - This function is experimental and might change in the future. - - - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - Can be either: - - - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on - the Hub. - - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved - with [`ModelMixin.save_pretrained`]. - - A [torch state - dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). - - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory where a downloaded pretrained model configuration is cached if the standard cache - is not used. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - local_files_only (`bool`, *optional*, defaults to `False`): - Whether to only load local model weights and configuration files or not. If set to `True`, the model - won't be downloaded from the Hub. - token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from - `diffusers-cli login` (stored in `~/.huggingface`) is used. - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier - allowed by Git. - subfolder (`str`, *optional*, defaults to `""`): - The subfolder location of a model file within a larger model repository on the Hub or locally. - - """ - # Load the main state dict first which has the LoRA layers for either of - # transformer and text encoder or both. - cache_dir = kwargs.pop("cache_dir", None) - force_download = kwargs.pop("force_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", None) - token = kwargs.pop("token", None) - revision = kwargs.pop("revision", None) - subfolder = kwargs.pop("subfolder", None) - weight_name = kwargs.pop("weight_name", None) - use_safetensors = kwargs.pop("use_safetensors", None) - - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle = True - - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } - - state_dict = _fetch_state_dict( - pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, - weight_name=weight_name, - use_safetensors=use_safetensors, - local_files_only=local_files_only, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - allow_pickle=allow_pickle, - ) - - is_dora_scale_present = any("dora_scale" in k for k in state_dict) - if is_dora_scale_present: - warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." 
- logger.warning(warn_msg) - state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - - return state_dict - - def load_lora_weights( - self, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - adapter_name=None, - hotswap: bool = False, - **kwargs, - ): - """ - Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and - `self.text_encoder`. - - All kwargs are forwarded to `self.lora_state_dict`. - - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is - loaded. - - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state - dict is loaded into `self.transformer`. - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random - weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - kwargs (`dict`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - """ - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) - if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." - ) - - # if a dict is passed, copy it instead of modifying it inplace - if isinstance(pretrained_model_name_or_path_or_dict, dict): - pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() - - # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) - - is_correct_format = all("lora" in key for key in state_dict.keys()) - if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") - - self.load_lora_into_transformer( - state_dict, - transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, - adapter_name=adapter_name, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - self.load_lora_into_text_encoder( - state_dict, - network_alphas=None, - text_encoder=self.text_encoder, - prefix=self.text_encoder_name, - lora_scale=self.lora_scale, - adapter_name=adapter_name, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - self.load_lora_into_text_encoder( - state_dict, - network_alphas=None, - text_encoder=self.text_encoder_2, - prefix=f"{self.text_encoder_name}_2", - lora_scale=self.lora_scale, - adapter_name=adapter_name, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - - @classmethod - def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False - ): - """ - This will load the LoRA layers specified in `state_dict` into `transformer`. 
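Before the per-component details, a minimal sketch of the SD3 pipeline-level call above (the LoRA repository id is a placeholder; the PEFT backend is required):

```py
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")

# One call loads the transformer LoRA layers and both CLIP text encoder LoRA layers.
pipe.load_lora_weights("your-org/your-sd3-lora", adapter_name="style")  # placeholder repo id
image = pipe("a corgi astronaut, detailed illustration").images[0]
```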
- - Parameters: - state_dict (`dict`): - A standard state dict containing the lora layer parameters. The keys can either be indexed directly - into the unet or prefixed with an additional `unet` which can be used to distinguish between text - encoder lora layers. - transformer (`SD3Transformer2DModel`): - The Transformer model to load the LoRA layers into. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random - weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - """ - if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." - ) - - # Load the layers corresponding to transformer. - logger.info(f"Loading {cls.transformer_name}.") - transformer.load_lora_adapter( - state_dict, - network_alphas=None, - adapter_name=adapter_name, - _pipeline=_pipeline, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - - @classmethod - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder - def load_lora_into_text_encoder( - cls, - state_dict, - network_alphas, - text_encoder, - prefix=None, - lora_scale=1.0, - adapter_name=None, - _pipeline=None, - low_cpu_mem_usage=False, - hotswap: bool = False, - ): - """ - This will load the LoRA layers specified in `state_dict` into `text_encoder` - - Parameters: - state_dict (`dict`): - A standard state dict containing the lora layer parameters. The key should be prefixed with an - additional `text_encoder` to distinguish between unet lora layers. - network_alphas (`Dict[str, float]`): - The value of the network alpha used for stable learning and preventing underflow. This value has the - same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this - link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). - text_encoder (`CLIPTextModel`): - The text encoder model to load the LoRA layers into. - prefix (`str`): - Expected prefix of the `text_encoder` in the `state_dict`. - lora_scale (`float`): - How much to scale the output of the lora linear layer before it is added with the output of the regular - lora layer. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random - weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. 
- """ - _load_lora_into_text_encoder( - state_dict=state_dict, - network_alphas=network_alphas, - lora_scale=lora_scale, - text_encoder=text_encoder, - prefix=prefix, - text_encoder_name=cls.text_encoder_name, - adapter_name=adapter_name, - _pipeline=_pipeline, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - - @classmethod - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.save_lora_weights with unet->transformer - def save_lora_weights( - cls, - save_directory: Union[str, os.PathLike], - transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, - text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, - text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, - is_main_process: bool = True, - weight_name: str = None, - save_function: Callable = None, - safe_serialization: bool = True, - ): - r""" - Save the LoRA parameters corresponding to the UNet and text encoder. - - Arguments: - save_directory (`str` or `os.PathLike`): - Directory to save LoRA parameters to. Will be created if it doesn't exist. - transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): - State dict of the LoRA layers corresponding to the `transformer`. - text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): - State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text - encoder LoRA state dict because it comes from 🤗 Transformers. - text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): - State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text - encoder LoRA state dict because it comes from 🤗 Transformers. - is_main_process (`bool`, *optional*, defaults to `True`): - Whether the process calling this is the main process or not. Useful during distributed training and you - need to call this function on all processes. In this case, set `is_main_process=True` only on the main - process to avoid race conditions. - save_function (`Callable`): - The function to use to save the state dictionary. Useful during distributed training when you need to - replace `torch.save` with another method. Can be configured with the environment variable - `DIFFUSERS_SAVE_MODE`. - safe_serialization (`bool`, *optional*, defaults to `True`): - Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. - """ - state_dict = {} - - if not (transformer_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers): - raise ValueError( - "You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`." 
) - - if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) - - if text_encoder_lora_layers: - state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder")) - - if text_encoder_2_lora_layers: - state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) - - cls.write_lora_layers( - state_dict=state_dict, - save_directory=save_directory, - is_main_process=is_main_process, - weight_name=weight_name, - save_function=save_function, - safe_serialization=safe_serialization, - ) - - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer - def fuse_lora( - self, - components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], - lora_scale: float = 1.0, - safe_fusing: bool = False, - adapter_names: Optional[List[str]] = None, - **kwargs, - ): - r""" - Fuses the LoRA parameters into the original parameters of the corresponding blocks. - - - - This is an experimental API. - - - - Args: - components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. - lora_scale (`float`, defaults to 1.0): - Controls how much to influence the outputs with the LoRA parameters. - safe_fusing (`bool`, defaults to `False`): - Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. - adapter_names (`List[str]`, *optional*): - Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. - - Example: - - ```py - from diffusers import DiffusionPipeline - import torch - - pipeline = DiffusionPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 - ).to("cuda") - pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") - pipeline.fuse_lora(lora_scale=0.7) - ``` - """ - super().fuse_lora( - components=components, - lora_scale=lora_scale, - safe_fusing=safe_fusing, - adapter_names=adapter_names, - **kwargs, - ) - - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.unfuse_lora with unet->transformer - def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], **kwargs): - r""" - Reverses the effect of - [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). - - - - This is an experimental API. - - - - Args: - components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. - unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters. - unfuse_text_encoder (`bool`, defaults to `True`): - Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the - LoRA parameters then it won't have any effect. - """ - super().unfuse_lora(components=components, **kwargs) - - -class AuraFlowLoraLoaderMixin(LoraBaseMixin): - r""" - Load LoRA layers into [`AuraFlowTransformer2DModel`]. Specific to [`AuraFlowPipeline`]. - """ - - _lora_loadable_modules = ["transformer"] - transformer_name = TRANSFORMER_NAME - - @classmethod - @validate_hf_hub_args - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict - def lora_state_dict( - cls, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - **kwargs, - ): - r""" - Return state dict for lora weights and the network alphas.
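The AuraFlow mixin follows the same transformer-only pattern; a sketch under the same assumptions (placeholder LoRA repository id):

```py
import torch
from diffusers import AuraFlowPipeline

pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")

# `_lora_loadable_modules = ["transformer"]`, so only transformer layers are loaded.
pipe.load_lora_weights("your-org/your-auraflow-lora", adapter_name="style")  # placeholder repo id
```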
- - - - We support loading A1111 formatted LoRA checkpoints in a limited capacity. - - This function is experimental and might change in the future. - - - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - Can be either: - - - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on - the Hub. - - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved - with [`ModelMixin.save_pretrained`]. - - A [torch state - dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). - - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory where a downloaded pretrained model configuration is cached if the standard cache - is not used. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - local_files_only (`bool`, *optional*, defaults to `False`): - Whether to only load local model weights and configuration files or not. If set to `True`, the model - won't be downloaded from the Hub. - token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from - `diffusers-cli login` (stored in `~/.huggingface`) is used. - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier - allowed by Git. - subfolder (`str`, *optional*, defaults to `""`): - The subfolder location of a model file within a larger model repository on the Hub or locally. - - """ - # Load the main state dict first which has the LoRA layers for either of - # transformer and text encoder or both. - cache_dir = kwargs.pop("cache_dir", None) - force_download = kwargs.pop("force_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", None) - token = kwargs.pop("token", None) - revision = kwargs.pop("revision", None) - subfolder = kwargs.pop("subfolder", None) - weight_name = kwargs.pop("weight_name", None) - use_safetensors = kwargs.pop("use_safetensors", None) - - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle = True - - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } - - state_dict = _fetch_state_dict( - pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, - weight_name=weight_name, - use_safetensors=use_safetensors, - local_files_only=local_files_only, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - allow_pickle=allow_pickle, - ) - - is_dora_scale_present = any("dora_scale" in k for k in state_dict) - if is_dora_scale_present: - warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." 
- logger.warning(warn_msg) - state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - - return state_dict - - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights - def load_lora_weights( - self, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - adapter_name: Optional[str] = None, - hotswap: bool = False, - **kwargs, - ): - """ - Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and - `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See - [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state - dict is loaded into `self.transformer`. - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random - weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - kwargs (`dict`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - """ - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) - if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." - ) - - # if a dict is passed, copy it instead of modifying it inplace - if isinstance(pretrained_model_name_or_path_or_dict, dict): - pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() - - # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) - - is_correct_format = all("lora" in key for key in state_dict.keys()) - if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") - - self.load_lora_into_transformer( - state_dict, - transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, - adapter_name=adapter_name, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - - @classmethod - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->AuraFlowTransformer2DModel - def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False - ): - """ - This will load the LoRA layers specified in `state_dict` into `transformer`. - - Parameters: - state_dict (`dict`): - A standard state dict containing the lora layer parameters. The keys can either be indexed directly - into the unet or prefixed with an additional `unet` which can be used to distinguish between text - encoder lora layers. - transformer (`AuraFlowTransformer2DModel`): - The Transformer model to load the LoRA layers into. 
- adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random - weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - """ - if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." - ) - - # Load the layers corresponding to transformer. - logger.info(f"Loading {cls.transformer_name}.") - transformer.load_lora_adapter( - state_dict, - network_alphas=None, - adapter_name=adapter_name, - _pipeline=_pipeline, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - - @classmethod - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights - def save_lora_weights( - cls, - save_directory: Union[str, os.PathLike], - transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, - is_main_process: bool = True, - weight_name: str = None, - save_function: Callable = None, - safe_serialization: bool = True, - ): - r""" - Save the LoRA parameters corresponding to the UNet and text encoder. - - Arguments: - save_directory (`str` or `os.PathLike`): - Directory to save LoRA parameters to. Will be created if it doesn't exist. - transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): - State dict of the LoRA layers corresponding to the `transformer`. - is_main_process (`bool`, *optional*, defaults to `True`): - Whether the process calling this is the main process or not. Useful during distributed training and you - need to call this function on all processes. In this case, set `is_main_process=True` only on the main - process to avoid race conditions. - save_function (`Callable`): - The function to use to save the state dictionary. Useful during distributed training when you need to - replace `torch.save` with another method. Can be configured with the environment variable - `DIFFUSERS_SAVE_MODE`. - safe_serialization (`bool`, *optional*, defaults to `True`): - Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. - """ - state_dict = {} - - if not transformer_lora_layers: - raise ValueError("You must pass `transformer_lora_layers`.") - - if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) - - # Save the model - cls.write_lora_layers( - state_dict=state_dict, - save_directory=save_directory, - is_main_process=is_main_process, - weight_name=weight_name, - save_function=save_function, - safe_serialization=safe_serialization, - ) - - # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora - def fuse_lora( - self, - components: List[str] = ["transformer"], - lora_scale: float = 1.0, - safe_fusing: bool = False, - adapter_names: Optional[List[str]] = None, - **kwargs, - ): - r""" - Fuses the LoRA parameters into the original parameters of the corresponding blocks. - - - - This is an experimental API. - - - - Args: - components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. - lora_scale (`float`, defaults to 1.0): - Controls how much to influence the outputs with the LoRA parameters. 
- safe_fusing (`bool`, defaults to `False`): - Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. - adapter_names (`List[str]`, *optional*): - Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. - - Example: - - ```py - from diffusers import DiffusionPipeline - import torch - - pipeline = DiffusionPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 - ).to("cuda") - pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") - pipeline.fuse_lora(lora_scale=0.7) - ``` - """ - super().fuse_lora( - components=components, - lora_scale=lora_scale, - safe_fusing=safe_fusing, - adapter_names=adapter_names, - **kwargs, - ) - - # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora - def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs): - r""" - Reverses the effect of - [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). - - - - This is an experimental API. - - - - Args: - components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. - unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters. - """ - super().unfuse_lora(components=components, **kwargs) +class StableDiffusionLoraLoaderMixin(StableDiffusionLoraLoaderMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `StableDiffusionLoraLoaderMixin` from diffusers.loaders.lora_pipeline has been deprecated. Please use `from diffusers.loaders.lora.lora_pipeline import StableDiffusionLoraLoaderMixin` instead." + deprecate("diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin", "0.36", deprecation_message) + super().__init__(*args, **kwargs) -class FluxLoraLoaderMixin(LoraBaseMixin): - r""" - Load LoRA layers into [`FluxTransformer2DModel`], - [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel). - Specific to [`FluxPipeline`]. - """ +class StableDiffusionXLLoraLoaderMixin(StableDiffusionXLLoraLoaderMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `StableDiffusionXLLoraLoaderMixin` from diffusers.loaders.lora_pipeline has been deprecated. Please use `from diffusers.loaders.lora.lora_pipeline import StableDiffusionXLLoraLoaderMixin` instead." + deprecate("diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin", "0.36", deprecation_message) + super().__init__(*args, **kwargs) - _lora_loadable_modules = ["transformer", "text_encoder"] - transformer_name = TRANSFORMER_NAME - text_encoder_name = TEXT_ENCODER_NAME - _control_lora_supported_norm_keys = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"] - @classmethod - @validate_hf_hub_args - def lora_state_dict( - cls, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - return_alphas: bool = False, - **kwargs, - ): - r""" - Return state dict for lora weights and the network alphas. +class SD3LoraLoaderMixin(SD3LoraLoaderMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `SD3LoraLoaderMixin` from diffusers.loaders.lora_pipeline has been deprecated. Please use `from diffusers.loaders.lora.lora_pipeline import SD3LoraLoaderMixin` instead."
+ deprecate("diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin", "0.36", deprecation_message) + super().__init__(*args, **kwargs) - - We support loading A1111 formatted LoRA checkpoints in a limited capacity. +class AuraFlowLoraLoaderMixin(AuraFlowLoraLoaderMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `AuraFlowLoraLoaderMixin` from diffusers.loaders.lora_pipeline has been deprecated. Please use `from diffusers.loaders.lora.lora_pipeline import AuraFlowLoraLoaderMixin` instead." + deprecate("diffusers.loaders.lora_pipeline.AuraFlowLoraLoaderMixin", "0.36", deprecation_message) + super().__init__(*args, **kwargs) - This function is experimental and might change in the future. - +class FluxLoraLoaderMixin(FluxLoraLoaderMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `FluxLoraLoaderMixin` from diffusers.loaders.lora_pipeline has been deprecated. Please use `from diffusers.loaders.lora.lora_pipeline import FluxLoraLoaderMixin` instead." + deprecate("diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin", "0.36", deprecation_message) + super().__init__(*args, **kwargs) - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - Can be either: - - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on - the Hub. - - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved - with [`ModelMixin.save_pretrained`]. - - A [torch state - dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). +class AmusedLoraLoaderMixin(AmusedLoraLoaderMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `AmusedLoraLoaderMixin` from diffusers.loaders.lora_pipeline has been deprecated. Please use `from diffusers.loaders.lora.lora_pipeline import AmusedLoraLoaderMixin` instead." + deprecate("diffusers.loaders.lora_pipeline.AmusedLoraLoaderMixin", "0.36", deprecation_message) + super().__init__(*args, **kwargs) - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory where a downloaded pretrained model configuration is cached if the standard cache - is not used. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - local_files_only (`bool`, *optional*, defaults to `False`): - Whether to only load local model weights and configuration files or not. If set to `True`, the model - won't be downloaded from the Hub. - token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from - `diffusers-cli login` (stored in `~/.huggingface`) is used. - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier - allowed by Git. - subfolder (`str`, *optional*, defaults to `""`): - The subfolder location of a model file within a larger model repository on the Hub or locally. 
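The `+` shims above all follow one pattern: keep the class importable at its old path, but warn on construction. A stripped-down sketch of the idea with stand-in names (the real `deprecate` helper lives in `diffusers.utils`):

```py
import warnings


def deprecate(name: str, version: str, message: str) -> None:
    # Stand-in for diffusers.utils.deprecate: emit a warning until `version` drops the alias.
    warnings.warn(f"`{name}` is deprecated and will be removed in {version}. {message}", FutureWarning)


class _RelocatedLoader:  # stands in for the class at its new home, e.g. loaders.lora.lora_pipeline
    def __init__(self, *args, **kwargs):
        pass


class OldPathLoader(_RelocatedLoader):  # re-exported under the old module path
    def __init__(self, *args, **kwargs):
        deprecate("old_module.OldPathLoader", "0.36", "Import it from its new location instead.")
        super().__init__(*args, **kwargs)
```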
+class CogVideoXLoraLoaderMixin(CogVideoXLoraLoaderMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `CogVideoXLoraLoaderMixin` from diffusers.loaders.lora_pipeline has been deprecated. Please use `from diffusers.loaders.lora.lora_pipeline import CogVideoXLoraLoaderMixin` instead." + deprecate("diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin", "0.36", deprecation_message) + super().__init__(*args, **kwargs) - """ - # Load the main state dict first which has the LoRA layers for either of - # transformer and text encoder or both. - cache_dir = kwargs.pop("cache_dir", None) - force_download = kwargs.pop("force_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", None) - token = kwargs.pop("token", None) - revision = kwargs.pop("revision", None) - subfolder = kwargs.pop("subfolder", None) - weight_name = kwargs.pop("weight_name", None) - use_safetensors = kwargs.pop("use_safetensors", None) - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle = True +class Mochi1LoraLoaderMixin(Mochi1LoraLoaderMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `Mochi1LoraLoaderMixin` from diffusers.loaders.lora_pipeline has been deprecated. Please use `from diffusers.loaders.lora.lora_pipeline import Mochi1LoraLoaderMixin` instead." + deprecate("diffusers.loaders.lora_pipeline.Mochi1LoraLoaderMixin", "0.36", deprecation_message) + super().__init__(*args, **kwargs) - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } - state_dict = _fetch_state_dict( - pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, - weight_name=weight_name, - use_safetensors=use_safetensors, - local_files_only=local_files_only, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - allow_pickle=allow_pickle, - ) - is_dora_scale_present = any("dora_scale" in k for k in state_dict) - if is_dora_scale_present: - warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." - logger.warning(warn_msg) - state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} +class LTXVideoLoraLoaderMixin(LTXVideoLoraLoaderMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `LTXVideoLoraLoaderMixin` from diffusers.loaders.lora_pipeline has been deprecated. Please use `from diffusers.loaders.lora.lora_pipeline import LTXVideoLoraLoaderMixin` instead." + deprecate("diffusers.loaders.lora_pipeline.LTXVideoLoraLoaderMixin", "0.36", deprecation_message) + super().__init__(*args, **kwargs) - # TODO (sayakpaul): to a follow-up to clean and try to unify the conditions. - is_kohya = any(".lora_down.weight" in k for k in state_dict) - if is_kohya: - state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict) - # Kohya already takes care of scaling the LoRA parameters with alpha. - return (state_dict, None) if return_alphas else state_dict - is_xlabs = any("processor" in k for k in state_dict) - if is_xlabs: - state_dict = _convert_xlabs_flux_lora_to_diffusers(state_dict) - # xlabs doesn't use `alpha`. 
- return (state_dict, None) if return_alphas else state_dict +class SanaLoraLoaderMixin(SanaLoraLoaderMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `SanaLoraLoaderMixin` from diffusers.loaders.lora_pipeline has been deprecated. Please use `from diffusers.loaders.lora.lora_pipeline import SanaLoraLoaderMixin` instead." + deprecate("diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin", "0.36", deprecation_message) + super().__init__(*args, **kwargs) - is_bfl_control = any("query_norm.scale" in k for k in state_dict) - if is_bfl_control: - state_dict = _convert_bfl_flux_control_lora_to_diffusers(state_dict) - return (state_dict, None) if return_alphas else state_dict - # For state dicts like - # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA - keys = list(state_dict.keys()) - network_alphas = {} - for k in keys: - if "alpha" in k: - alpha_value = state_dict.get(k) - if (torch.is_tensor(alpha_value) and torch.is_floating_point(alpha_value)) or isinstance( - alpha_value, float - ): - network_alphas[k] = state_dict.pop(k) - else: - raise ValueError( - f"The alpha key ({k}) seems to be incorrect. If you think this error is unexpected, please open an issue." - ) +class HunyuanVideoLoraLoaderMixin(HunyuanVideoLoraLoaderMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `HunyuanVideoLoraLoaderMixin` from diffusers.loaders.lora_pipeline has been deprecated. Please use `from diffusers.loaders.lora.lora_pipeline import HunyuanVideoLoraLoaderMixin` instead." + deprecate("diffusers.loaders.lora_pipeline.HunyuanVideoLoraLoaderMixin", "0.36", deprecation_message) + super().__init__(*args, **kwargs) - if return_alphas: - return state_dict, network_alphas - else: - return state_dict - def load_lora_weights( - self, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - adapter_name: Optional[str] = None, - hotswap: bool = False, - **kwargs, - ): - """ - Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and - `self.text_encoder`. +class Lumina2LoraLoaderMixin(Lumina2LoraLoaderMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `Lumina2LoraLoaderMixin` from diffusers.loaders.lora_pipeline has been deprecated. Please use `from diffusers.loaders.lora.lora_pipeline import Lumina2LoraLoaderMixin` instead." + deprecate("diffusers.loaders.lora_pipeline.Lumina2LoraLoaderMixin", "0.36", deprecation_message) + super().__init__(*args, **kwargs) - All kwargs are forwarded to `self.lora_state_dict`. - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is - loaded. +class WanLoraLoaderMixin(WanLoraLoaderMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `WanLoraLoaderMixin` from diffusers.loaders.lora_pipeline has been deprecated. Please use `from diffusers.loaders.lora.lora_pipeline import WanLoraLoaderMixin` instead." + deprecate("diffusers.loaders.lora_pipeline.WanLoraLoaderMixin", "0.36", deprecation_message) + super().__init__(*args, **kwargs) - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state - dict is loaded into `self.transformer`. - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model.
If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random - weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - kwargs (`dict`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - """ - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") +class CogView4LoraLoaderMixin(CogView4LoraLoaderMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `CogView4LoraLoaderMixin` from diffusers.loaders.lora_pipeline has been deprecated. Please use `from diffusers.loaders.lora.lora_pipeline import CogView4LoraLoaderMixin` instead." + deprecate("diffusers.loaders.lora_pipeline.CogView4LoraLoaderMixin", "0.36", deprecation_message) + super().__init__(*args, **kwargs) - low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) - if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." - ) - # if a dict is passed, copy it instead of modifying it inplace - if isinstance(pretrained_model_name_or_path_or_dict, dict): - pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() - - # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict, network_alphas = self.lora_state_dict( - pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs - ) - - has_lora_keys = any("lora" in key for key in state_dict.keys()) - - # Flux Control LoRAs also have norm keys - has_norm_keys = any( - norm_key in key for key in state_dict.keys() for norm_key in self._control_lora_supported_norm_keys - ) - - if not (has_lora_keys or has_norm_keys): - raise ValueError("Invalid LoRA checkpoint.") - - transformer_lora_state_dict = { - k: state_dict.get(k) - for k in list(state_dict.keys()) - if k.startswith(f"{self.transformer_name}.") and "lora" in k - } - transformer_norm_state_dict = { - k: state_dict.pop(k) - for k in list(state_dict.keys()) - if k.startswith(f"{self.transformer_name}.") - and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys) - } - - transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer - has_param_with_expanded_shape = False - if len(transformer_lora_state_dict) > 0: - has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_( - transformer, transformer_lora_state_dict, transformer_norm_state_dict - ) - - if has_param_with_expanded_shape: - logger.info( - "The LoRA weights contain parameters that have different shapes than expected by the transformer. " - "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. " - "To get a comprehensive list of parameter names that were modified, enable debug logging."
- ) - if len(transformer_lora_state_dict) > 0: - transformer_lora_state_dict = self._maybe_expand_lora_state_dict( - transformer=transformer, lora_state_dict=transformer_lora_state_dict - ) - for k in transformer_lora_state_dict: - state_dict.update({k: transformer_lora_state_dict[k]}) - - self.load_lora_into_transformer( - state_dict, - network_alphas=network_alphas, - transformer=transformer, - adapter_name=adapter_name, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - - if len(transformer_norm_state_dict) > 0: - transformer._transformer_norm_layers = self._load_norm_into_transformer( - transformer_norm_state_dict, - transformer=transformer, - discard_original_layers=False, - ) - - self.load_lora_into_text_encoder( - state_dict, - network_alphas=network_alphas, - text_encoder=self.text_encoder, - prefix=self.text_encoder_name, - lora_scale=self.lora_scale, - adapter_name=adapter_name, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - - @classmethod - def load_lora_into_transformer( - cls, - state_dict, - network_alphas, - transformer, - adapter_name=None, - _pipeline=None, - low_cpu_mem_usage=False, - hotswap: bool = False, - ): - """ - This will load the LoRA layers specified in `state_dict` into `transformer`. - - Parameters: - state_dict (`dict`): - A standard state dict containing the lora layer parameters. The keys can either be indexed directly - into the unet or prefixed with an additional `unet` which can be used to distinguish between text - encoder lora layers. - network_alphas (`Dict[str, float]`): - The value of the network alpha used for stable learning and preventing underflow. This value has the - same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this - link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). - transformer (`FluxTransformer2DModel`): - The Transformer model to load the LoRA layers into. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random - weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - """ - if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." - ) - - # Load the layers corresponding to transformer. 
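The removed Flux `load_lora_weights` above splits one checkpoint into two buckets before anything is loaded: `transformer.`-prefixed keys containing `lora` go down the PEFT adapter path, while keys matching the supported Control-LoRA norm names are loaded straight into the transformer. A rough standalone sketch of that partitioning (key names and values are made up):

```py
NORM_KEYS = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"]  # mirrors _control_lora_supported_norm_keys

state_dict = {
    "transformer.transformer_blocks.0.attn.to_q.lora_A.weight": "...",
    "transformer.transformer_blocks.0.attn.to_q.lora_B.weight": "...",
    "transformer.transformer_blocks.0.attn.norm_q.weight": "...",
}

lora_sd = {k: v for k, v in state_dict.items() if k.startswith("transformer.") and "lora" in k}
norm_sd = {
    k: v
    for k, v in state_dict.items()
    if k.startswith("transformer.") and any(nk in k for nk in NORM_KEYS)
}

assert set(lora_sd) | set(norm_sd) == set(state_dict)  # every key lands in a bucket here
```

In the real method the norm keys are `pop`-ed out of the state dict, since they are applied directly rather than handed to PEFT.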
- logger.info(f"Loading {cls.transformer_name}.") - transformer.load_lora_adapter( - state_dict, - network_alphas=network_alphas, - adapter_name=adapter_name, - _pipeline=_pipeline, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - - @classmethod - def _load_norm_into_transformer( - cls, - state_dict, - transformer, - prefix=None, - discard_original_layers=False, - ) -> Dict[str, torch.Tensor]: - # Remove prefix if present - prefix = prefix or cls.transformer_name - for key in list(state_dict.keys()): - if key.split(".")[0] == prefix: - state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key) - - # Find invalid keys - transformer_state_dict = transformer.state_dict() - transformer_keys = set(transformer_state_dict.keys()) - state_dict_keys = set(state_dict.keys()) - extra_keys = list(state_dict_keys - transformer_keys) - - if extra_keys: - logger.warning( - f"Unsupported keys found in state dict when trying to load normalization layers into the transformer. The following keys will be ignored:\n{extra_keys}." - ) - - for key in extra_keys: - state_dict.pop(key) - - # Save the layers that are going to be overwritten so that unload_lora_weights can work as expected - overwritten_layers_state_dict = {} - if not discard_original_layers: - for key in state_dict.keys(): - overwritten_layers_state_dict[key] = transformer_state_dict[key].clone() - - logger.info( - "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will directly update the state_dict of the transformer " - 'as opposed to the LoRA layers that will co-exist separately until the "fuse_lora()" method is called. That is to say, the normalization layers will always be directly ' - "fused into the transformer and can only be unfused if `discard_original_layers=True` is passed. This might also have implications when dealing with multiple LoRAs. " - "If you notice something unexpected, please open an issue: https://github.com/huggingface/diffusers/issues." - ) - - # We can't load with strict=True because the current state_dict does not contain all the transformer keys - incompatible_keys = transformer.load_state_dict(state_dict, strict=False) - unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) - - # We shouldn't expect to see the supported norm keys here being present in the unexpected keys. - if unexpected_keys: - if any(norm_key in k for k in unexpected_keys for norm_key in cls._control_lora_supported_norm_keys): - raise ValueError( - f"Found {unexpected_keys} as unexpected keys while trying to load norm layers into the transformer." - ) - - return overwritten_layers_state_dict - - @classmethod - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder - def load_lora_into_text_encoder( - cls, - state_dict, - network_alphas, - text_encoder, - prefix=None, - lora_scale=1.0, - adapter_name=None, - _pipeline=None, - low_cpu_mem_usage=False, - hotswap: bool = False, - ): - """ - This will load the LoRA layers specified in `state_dict` into `text_encoder` - - Parameters: - state_dict (`dict`): - A standard state dict containing the lora layer parameters. The key should be prefixed with an - additional `text_encoder` to distinguish between unet lora layers. - network_alphas (`Dict[str, float]`): - The value of the network alpha used for stable learning and preventing underflow. This value has the - same meaning as the `--network_alpha` option in the kohya-ss trainer script. 
Refer to [this - link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). - text_encoder (`CLIPTextModel`): - The text encoder model to load the LoRA layers into. - prefix (`str`): - Expected prefix of the `text_encoder` in the `state_dict`. - lora_scale (`float`): - How much to scale the output of the lora linear layer before it is added with the output of the regular - lora layer. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random - weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - """ - _load_lora_into_text_encoder( - state_dict=state_dict, - network_alphas=network_alphas, - lora_scale=lora_scale, - text_encoder=text_encoder, - prefix=prefix, - text_encoder_name=cls.text_encoder_name, - adapter_name=adapter_name, - _pipeline=_pipeline, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - - @classmethod - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with unet->transformer - def save_lora_weights( - cls, - save_directory: Union[str, os.PathLike], - transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, - text_encoder_lora_layers: Dict[str, torch.nn.Module] = None, - is_main_process: bool = True, - weight_name: str = None, - save_function: Callable = None, - safe_serialization: bool = True, - ): - r""" - Save the LoRA parameters corresponding to the UNet and text encoder. - - Arguments: - save_directory (`str` or `os.PathLike`): - Directory to save LoRA parameters to. Will be created if it doesn't exist. - transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): - State dict of the LoRA layers corresponding to the `transformer`. - text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): - State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text - encoder LoRA state dict because it comes from 🤗 Transformers. - is_main_process (`bool`, *optional*, defaults to `True`): - Whether the process calling this is the main process or not. Useful during distributed training and you - need to call this function on all processes. In this case, set `is_main_process=True` only on the main - process to avoid race conditions. - save_function (`Callable`): - The function to use to save the state dictionary. Useful during distributed training when you need to - replace `torch.save` with another method. Can be configured with the environment variable - `DIFFUSERS_SAVE_MODE`. - safe_serialization (`bool`, *optional*, defaults to `True`): - Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. 
- """ - state_dict = {} - - if not (transformer_lora_layers or text_encoder_lora_layers): - raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.") - - if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) - - if text_encoder_lora_layers: - state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) - - # Save the model - cls.write_lora_layers( - state_dict=state_dict, - save_directory=save_directory, - is_main_process=is_main_process, - weight_name=weight_name, - save_function=save_function, - safe_serialization=safe_serialization, - ) - - def fuse_lora( - self, - components: List[str] = ["transformer"], - lora_scale: float = 1.0, - safe_fusing: bool = False, - adapter_names: Optional[List[str]] = None, - **kwargs, - ): - r""" - Fuses the LoRA parameters into the original parameters of the corresponding blocks. - - - - This is an experimental API. - - - - Args: - components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. - lora_scale (`float`, defaults to 1.0): - Controls how much to influence the outputs with the LoRA parameters. - safe_fusing (`bool`, defaults to `False`): - Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. - adapter_names (`List[str]`, *optional*): - Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. - - Example: - - ```py - from diffusers import DiffusionPipeline - import torch - - pipeline = DiffusionPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 - ).to("cuda") - pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") - pipeline.fuse_lora(lora_scale=0.7) - ``` - """ - - transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer - if ( - hasattr(transformer, "_transformer_norm_layers") - and isinstance(transformer._transformer_norm_layers, dict) - and len(transformer._transformer_norm_layers.keys()) > 0 - ): - logger.info( - "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will be directly updated the state_dict of the transformer " - "as opposed to the LoRA layers that will co-exist separately until the 'fuse_lora()' method is called. That is to say, the normalization layers will always be directly " - "fused into the transformer and can only be unfused if `discard_original_layers=True` is passed." - ) - - super().fuse_lora( - components=components, - lora_scale=lora_scale, - safe_fusing=safe_fusing, - adapter_names=adapter_names, - **kwargs, - ) - - def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs): - r""" - Reverses the effect of - [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). - - - - This is an experimental API. - - - - Args: - components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. 
- """ - transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer - if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers: - transformer.load_state_dict(transformer._transformer_norm_layers, strict=False) - - super().unfuse_lora(components=components, **kwargs) - - # We override this here account for `_transformer_norm_layers` and `_overwritten_params`. - def unload_lora_weights(self, reset_to_overwritten_params=False): - """ - Unloads the LoRA parameters. - - Args: - reset_to_overwritten_params (`bool`, defaults to `False`): Whether to reset the LoRA-loaded modules - to their original params. Refer to the [Flux - documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux) to learn more. - - Examples: - - ```python - >>> # Assuming `pipeline` is already loaded with the LoRA parameters. - >>> pipeline.unload_lora_weights() - >>> ... - ``` - """ - super().unload_lora_weights() - - transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer - if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers: - transformer.load_state_dict(transformer._transformer_norm_layers, strict=False) - transformer._transformer_norm_layers = None - - if reset_to_overwritten_params and getattr(transformer, "_overwritten_params", None) is not None: - overwritten_params = transformer._overwritten_params - module_names = set() - - for param_name in overwritten_params: - if param_name.endswith(".weight"): - module_names.add(param_name.replace(".weight", "")) - - for name, module in transformer.named_modules(): - if isinstance(module, torch.nn.Linear) and name in module_names: - module_weight = module.weight.data - module_bias = module.bias.data if module.bias is not None else None - bias = module_bias is not None - - parent_module_name, _, current_module_name = name.rpartition(".") - parent_module = transformer.get_submodule(parent_module_name) - - current_param_weight = overwritten_params[f"{name}.weight"] - in_features, out_features = current_param_weight.shape[1], current_param_weight.shape[0] - with torch.device("meta"): - original_module = torch.nn.Linear( - in_features, - out_features, - bias=bias, - dtype=module_weight.dtype, - ) - - tmp_state_dict = {"weight": current_param_weight} - if module_bias is not None: - tmp_state_dict.update({"bias": overwritten_params[f"{name}.bias"]}) - original_module.load_state_dict(tmp_state_dict, assign=True, strict=True) - setattr(parent_module, current_module_name, original_module) - - del tmp_state_dict - - if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX: - attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name] - new_value = int(current_param_weight.shape[1]) - old_value = getattr(transformer.config, attribute_name) - setattr(transformer.config, attribute_name, new_value) - logger.info( - f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}." - ) - - @classmethod - def _maybe_expand_transformer_param_shape_or_error_( - cls, - transformer: torch.nn.Module, - lora_state_dict=None, - norm_state_dict=None, - prefix=None, - ) -> bool: - """ - Control LoRA expands the shape of the input layer from (3072, 64) to (3072, 128). This method handles that and - generalizes things a bit so that any parameter that needs expansion receives appropriate treatement. 
- """ - state_dict = {} - if lora_state_dict is not None: - state_dict.update(lora_state_dict) - if norm_state_dict is not None: - state_dict.update(norm_state_dict) - - # Remove prefix if present - prefix = prefix or cls.transformer_name - for key in list(state_dict.keys()): - if key.split(".")[0] == prefix: - state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key) - - # Expand transformer parameter shapes if they don't match lora - has_param_with_shape_update = False - overwritten_params = {} - - is_peft_loaded = getattr(transformer, "peft_config", None) is not None - is_quantized = hasattr(transformer, "hf_quantizer") - for name, module in transformer.named_modules(): - if isinstance(module, torch.nn.Linear): - module_weight = module.weight.data - module_bias = module.bias.data if module.bias is not None else None - bias = module_bias is not None - - lora_base_name = name.replace(".base_layer", "") if is_peft_loaded else name - lora_A_weight_name = f"{lora_base_name}.lora_A.weight" - lora_B_weight_name = f"{lora_base_name}.lora_B.weight" - if lora_A_weight_name not in state_dict: - continue - - in_features = state_dict[lora_A_weight_name].shape[1] - out_features = state_dict[lora_B_weight_name].shape[0] - - # Model maybe loaded with different quantization schemes which may flatten the params. - # `bitsandbytes`, for example, flatten the weights when using 4bit. 8bit bnb models - # preserve weight shape. - module_weight_shape = cls._calculate_module_shape(model=transformer, base_module=module) - - # This means there's no need for an expansion in the params, so we simply skip. - if tuple(module_weight_shape) == (out_features, in_features): - continue - - module_out_features, module_in_features = module_weight_shape - debug_message = "" - if in_features > module_in_features: - debug_message += ( - f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA ' - f"checkpoint contains higher number of features than expected. The number of input_features will be " - f"expanded from {module_in_features} to {in_features}" - ) - if out_features > module_out_features: - debug_message += ( - ", and the number of output features will be " - f"expanded from {module_out_features} to {out_features}." - ) - else: - debug_message += "." - if debug_message: - logger.debug(debug_message) - - if out_features > module_out_features or in_features > module_in_features: - has_param_with_shape_update = True - parent_module_name, _, current_module_name = name.rpartition(".") - parent_module = transformer.get_submodule(parent_module_name) - - if is_quantized: - module_weight = _maybe_dequantize_weight_for_expanded_lora(transformer, module) - - # TODO: consider if this layer needs to be a quantized layer as well if `is_quantized` is True. - with torch.device("meta"): - expanded_module = torch.nn.Linear( - in_features, out_features, bias=bias, dtype=module_weight.dtype - ) - # Only weights are expanded and biases are not. This is because only the input dimensions - # are changed while the output dimensions remain the same. The shape of the weight tensor - # is (out_features, in_features), while the shape of bias tensor is (out_features,), which - # explains the reason why only weights are expanded. 
- new_weight = torch.zeros_like( - expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype - ) - slices = tuple(slice(0, dim) for dim in module_weight_shape) - new_weight[slices] = module_weight - tmp_state_dict = {"weight": new_weight} - if module_bias is not None: - tmp_state_dict["bias"] = module_bias - expanded_module.load_state_dict(tmp_state_dict, strict=True, assign=True) - - setattr(parent_module, current_module_name, expanded_module) - - del tmp_state_dict - - if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX: - attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name] - new_value = int(expanded_module.weight.data.shape[1]) - old_value = getattr(transformer.config, attribute_name) - setattr(transformer.config, attribute_name, new_value) - logger.info( - f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}." - ) - - # For `unload_lora_weights()`. - # TODO: this could lead to more memory overhead if the number of overwritten params - # are large. Should be revisited later and tackled through a `discard_original_layers` arg. - overwritten_params[f"{current_module_name}.weight"] = module_weight - if module_bias is not None: - overwritten_params[f"{current_module_name}.bias"] = module_bias - - if len(overwritten_params) > 0: - transformer._overwritten_params = overwritten_params - - return has_param_with_shape_update - - @classmethod - def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict): - expanded_module_names = set() - transformer_state_dict = transformer.state_dict() - prefix = f"{cls.transformer_name}." - - lora_module_names = [ - key[: -len(".lora_A.weight")] for key in lora_state_dict if key.endswith(".lora_A.weight") - ] - lora_module_names = [name[len(prefix) :] for name in lora_module_names if name.startswith(prefix)] - lora_module_names = sorted(set(lora_module_names)) - transformer_module_names = sorted({name for name, _ in transformer.named_modules()}) - unexpected_modules = set(lora_module_names) - set(transformer_module_names) - if unexpected_modules: - logger.debug(f"Found unexpected modules: {unexpected_modules}. These will be ignored.") - - is_peft_loaded = getattr(transformer, "peft_config", None) is not None - for k in lora_module_names: - if k in unexpected_modules: - continue - - base_param_name = ( - f"{k.replace(prefix, '')}.base_layer.weight" - if is_peft_loaded and f"{k.replace(prefix, '')}.base_layer.weight" in transformer_state_dict - else f"{k.replace(prefix, '')}.weight" - ) - base_weight_param = transformer_state_dict[base_param_name] - lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"] - - # TODO (sayakpaul): Handle the cases when we actually need to expand when using quantization. - base_module_shape = cls._calculate_module_shape(model=transformer, base_weight_param_name=base_param_name) - - if base_module_shape[1] > lora_A_param.shape[1]: - shape = (lora_A_param.shape[0], base_weight_param.shape[1]) - expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device) - expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param) - lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight - expanded_module_names.add(k) - elif base_module_shape[1] < lora_A_param.shape[1]: - raise NotImplementedError( - f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new." 
- ) - - if expanded_module_names: - logger.info( - f"The following LoRA modules were zero padded to match the state dict of {cls.transformer_name}: {expanded_module_names}. Please open an issue if you think this was unexpected - https://github.com/huggingface/diffusers/issues/new." - ) - - return lora_state_dict - - @staticmethod - def _calculate_module_shape( - model: "torch.nn.Module", - base_module: "torch.nn.Linear" = None, - base_weight_param_name: str = None, - ) -> "torch.Size": - def _get_weight_shape(weight: torch.Tensor): - if weight.__class__.__name__ == "Params4bit": - return weight.quant_state.shape - elif weight.__class__.__name__ == "GGUFParameter": - return weight.quant_shape - else: - return weight.shape - - if base_module is not None: - return _get_weight_shape(base_module.weight) - elif base_weight_param_name is not None: - if not base_weight_param_name.endswith(".weight"): - raise ValueError( - f"Invalid `base_weight_param_name` passed as it does not end with '.weight' {base_weight_param_name=}." - ) - module_path = base_weight_param_name.rsplit(".weight", 1)[0] - submodule = get_submodule_by_name(model, module_path) - return _get_weight_shape(submodule.weight) - - raise ValueError("Either `base_module` or `base_weight_param_name` must be provided.") - - -# The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially -# relied on `StableDiffusionLoraLoaderMixin` for its LoRA support. -class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin): - _lora_loadable_modules = ["transformer", "text_encoder"] - transformer_name = TRANSFORMER_NAME - text_encoder_name = TEXT_ENCODER_NAME - - @classmethod - # Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin.load_lora_into_transformer with FluxTransformer2DModel->UVit2DModel - def load_lora_into_transformer( - cls, - state_dict, - network_alphas, - transformer, - adapter_name=None, - _pipeline=None, - low_cpu_mem_usage=False, - hotswap: bool = False, - ): - """ - This will load the LoRA layers specified in `state_dict` into `transformer`. - - Parameters: - state_dict (`dict`): - A standard state dict containing the lora layer parameters. The keys can either be indexed directly - into the unet or prefixed with an additional `unet` which can be used to distinguish between text - encoder lora layers. - network_alphas (`Dict[str, float]`): - The value of the network alpha used for stable learning and preventing underflow. This value has the - same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this - link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). - transformer (`UVit2DModel`): - The Transformer model to load the LoRA layers into. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random - weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - """ - if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." - ) - - # Load the layers corresponding to transformer. 
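The expansion logic in the removed `_maybe_expand_lora_state_dict` above comes down to right-padding a `lora_A` matrix with zeros when the base layer accepts more input features than the LoRA was trained with, so the extra inputs contribute nothing. A standalone sketch with illustrative shapes:

```py
import torch

lora_A = torch.randn(16, 64)   # (rank, lora in_features), e.g. a LoRA trained at 64 inputs
base_in_features = 128         # the transformer was expanded to 128 inputs (Control LoRA)

expanded = torch.zeros(lora_A.shape[0], base_in_features)
expanded[:, : lora_A.shape[1]].copy_(lora_A)  # original columns kept, new columns stay zero

assert torch.equal(expanded[:, : lora_A.shape[1]], lora_A)
assert expanded[:, lora_A.shape[1] :].abs().sum() == 0
```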
- logger.info(f"Loading {cls.transformer_name}.") - transformer.load_lora_adapter( - state_dict, - network_alphas=network_alphas, - adapter_name=adapter_name, - _pipeline=_pipeline, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - - @classmethod - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder - def load_lora_into_text_encoder( - cls, - state_dict, - network_alphas, - text_encoder, - prefix=None, - lora_scale=1.0, - adapter_name=None, - _pipeline=None, - low_cpu_mem_usage=False, - hotswap: bool = False, - ): - """ - This will load the LoRA layers specified in `state_dict` into `text_encoder` - - Parameters: - state_dict (`dict`): - A standard state dict containing the lora layer parameters. The key should be prefixed with an - additional `text_encoder` to distinguish between unet lora layers. - network_alphas (`Dict[str, float]`): - The value of the network alpha used for stable learning and preventing underflow. This value has the - same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this - link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). - text_encoder (`CLIPTextModel`): - The text encoder model to load the LoRA layers into. - prefix (`str`): - Expected prefix of the `text_encoder` in the `state_dict`. - lora_scale (`float`): - How much to scale the output of the lora linear layer before it is added with the output of the regular - lora layer. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random - weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - """ - _load_lora_into_text_encoder( - state_dict=state_dict, - network_alphas=network_alphas, - lora_scale=lora_scale, - text_encoder=text_encoder, - prefix=prefix, - text_encoder_name=cls.text_encoder_name, - adapter_name=adapter_name, - _pipeline=_pipeline, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - - @classmethod - def save_lora_weights( - cls, - save_directory: Union[str, os.PathLike], - text_encoder_lora_layers: Dict[str, torch.nn.Module] = None, - transformer_lora_layers: Dict[str, torch.nn.Module] = None, - is_main_process: bool = True, - weight_name: str = None, - save_function: Callable = None, - safe_serialization: bool = True, - ): - r""" - Save the LoRA parameters corresponding to the UNet and text encoder. - - Arguments: - save_directory (`str` or `os.PathLike`): - Directory to save LoRA parameters to. Will be created if it doesn't exist. - unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): - State dict of the LoRA layers corresponding to the `unet`. - text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): - State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text - encoder LoRA state dict because it comes from 🤗 Transformers. - is_main_process (`bool`, *optional*, defaults to `True`): - Whether the process calling this is the main process or not. Useful during distributed training and you - need to call this function on all processes. 
In this case, set `is_main_process=True` only on the main - process to avoid race conditions. - save_function (`Callable`): - The function to use to save the state dictionary. Useful during distributed training when you need to - replace `torch.save` with another method. Can be configured with the environment variable - `DIFFUSERS_SAVE_MODE`. - safe_serialization (`bool`, *optional*, defaults to `True`): - Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. - """ - state_dict = {} - - if not (transformer_lora_layers or text_encoder_lora_layers): - raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.") - - if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) - - if text_encoder_lora_layers: - state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) - - # Save the model - cls.write_lora_layers( - state_dict=state_dict, - save_directory=save_directory, - is_main_process=is_main_process, - weight_name=weight_name, - save_function=save_function, - safe_serialization=safe_serialization, - ) - - -class CogVideoXLoraLoaderMixin(LoraBaseMixin): - r""" - Load LoRA layers into [`CogVideoXTransformer3DModel`]. Specific to [`CogVideoXPipeline`]. - """ - - _lora_loadable_modules = ["transformer"] - transformer_name = TRANSFORMER_NAME - - @classmethod - @validate_hf_hub_args - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict - def lora_state_dict( - cls, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - **kwargs, - ): - r""" - Return state dict for lora weights and the network alphas. - - - - We support loading A1111 formatted LoRA checkpoints in a limited capacity. - - This function is experimental and might change in the future. - - - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - Can be either: - - - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on - the Hub. - - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved - with [`ModelMixin.save_pretrained`]. - - A [torch state - dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). - - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory where a downloaded pretrained model configuration is cached if the standard cache - is not used. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - local_files_only (`bool`, *optional*, defaults to `False`): - Whether to only load local model weights and configuration files or not. If set to `True`, the model - won't be downloaded from the Hub. - token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from - `diffusers-cli login` (stored in `~/.huggingface`) is used. - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier - allowed by Git. 
- subfolder (`str`, *optional*, defaults to `""`): - The subfolder location of a model file within a larger model repository on the Hub or locally. - - """ - # Load the main state dict first which has the LoRA layers for either of - # transformer and text encoder or both. - cache_dir = kwargs.pop("cache_dir", None) - force_download = kwargs.pop("force_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", None) - token = kwargs.pop("token", None) - revision = kwargs.pop("revision", None) - subfolder = kwargs.pop("subfolder", None) - weight_name = kwargs.pop("weight_name", None) - use_safetensors = kwargs.pop("use_safetensors", None) - - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle = True - - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } - - state_dict = _fetch_state_dict( - pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, - weight_name=weight_name, - use_safetensors=use_safetensors, - local_files_only=local_files_only, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - allow_pickle=allow_pickle, - ) - - is_dora_scale_present = any("dora_scale" in k for k in state_dict) - if is_dora_scale_present: - warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." - logger.warning(warn_msg) - state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - - return state_dict - - def load_lora_weights( - self, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - adapter_name: Optional[str] = None, - hotswap: bool = False, - **kwargs, - ): - """ - Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and - `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See - [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state - dict is loaded into `self.transformer`. - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random - weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - kwargs (`dict`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - """ - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) - if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `peft` version. 
Please update it with `pip install -U peft`." - ) - - # if a dict is passed, copy it instead of modifying it inplace - if isinstance(pretrained_model_name_or_path_or_dict, dict): - pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() - - # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) - - is_correct_format = all("lora" in key for key in state_dict.keys()) - if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") - - self.load_lora_into_transformer( - state_dict, - transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, - adapter_name=adapter_name, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - - @classmethod - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogVideoXTransformer3DModel - def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False - ): - """ - This will load the LoRA layers specified in `state_dict` into `transformer`. - - Parameters: - state_dict (`dict`): - A standard state dict containing the lora layer parameters. The keys can either be indexed directly - into the unet or prefixed with an additional `unet` which can be used to distinguish between text - encoder lora layers. - transformer (`CogVideoXTransformer3DModel`): - The Transformer model to load the LoRA layers into. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random - weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - """ - if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." - ) - - # Load the layers corresponding to transformer. - logger.info(f"Loading {cls.transformer_name}.") - transformer.load_lora_adapter( - state_dict, - network_alphas=None, - adapter_name=adapter_name, - _pipeline=_pipeline, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - - @classmethod - # Adapted from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights without support for text encoder - def save_lora_weights( - cls, - save_directory: Union[str, os.PathLike], - transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, - is_main_process: bool = True, - weight_name: str = None, - save_function: Callable = None, - safe_serialization: bool = True, - ): - r""" - Save the LoRA parameters corresponding to the UNet and text encoder. - - Arguments: - save_directory (`str` or `os.PathLike`): - Directory to save LoRA parameters to. Will be created if it doesn't exist. - transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): - State dict of the LoRA layers corresponding to the `transformer`. - is_main_process (`bool`, *optional*, defaults to `True`): - Whether the process calling this is the main process or not. 
Useful during distributed training and you - need to call this function on all processes. In this case, set `is_main_process=True` only on the main - process to avoid race conditions. - save_function (`Callable`): - The function to use to save the state dictionary. Useful during distributed training when you need to - replace `torch.save` with another method. Can be configured with the environment variable - `DIFFUSERS_SAVE_MODE`. - safe_serialization (`bool`, *optional*, defaults to `True`): - Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. - """ - state_dict = {} - - if not transformer_lora_layers: - raise ValueError("You must pass `transformer_lora_layers`.") - - if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) - - # Save the model - cls.write_lora_layers( - state_dict=state_dict, - save_directory=save_directory, - is_main_process=is_main_process, - weight_name=weight_name, - save_function=save_function, - safe_serialization=safe_serialization, - ) - - def fuse_lora( - self, - components: List[str] = ["transformer"], - lora_scale: float = 1.0, - safe_fusing: bool = False, - adapter_names: Optional[List[str]] = None, - **kwargs, - ): - r""" - Fuses the LoRA parameters into the original parameters of the corresponding blocks. - - - - This is an experimental API. - - - - Args: - components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. - lora_scale (`float`, defaults to 1.0): - Controls how much to influence the outputs with the LoRA parameters. - safe_fusing (`bool`, defaults to `False`): - Whether to check fused weights for NaN values before fusing, and to skip fusing if NaN values are found. - adapter_names (`List[str]`, *optional*): - Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. - - Example: - - ```py - from diffusers import DiffusionPipeline - import torch - - pipeline = DiffusionPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 - ).to("cuda") - pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") - pipeline.fuse_lora(lora_scale=0.7) - ``` - """ - super().fuse_lora( - components=components, - lora_scale=lora_scale, - safe_fusing=safe_fusing, - adapter_names=adapter_names, - **kwargs, - ) - - def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): - r""" - Reverses the effect of - [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). - - - - This is an experimental API. - - - - Args: - components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. - unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters. - """ - super().unfuse_lora(components=components, **kwargs) - - -class Mochi1LoraLoaderMixin(LoraBaseMixin): - r""" - Load LoRA layers into [`MochiTransformer3DModel`]. Specific to [`MochiPipeline`]. - """ - - _lora_loadable_modules = ["transformer"] - transformer_name = TRANSFORMER_NAME - - @classmethod - @validate_hf_hub_args - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict - def lora_state_dict( - cls, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - **kwargs, - ): - r""" - Return state dict for lora weights and the network alphas.
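Taken together, the fuse/unfuse pair documented above supports a simple round trip: fuse to bake the adapter into the base weights for inference, unfuse to get the original weights back. The checkpoint names below are the same placeholders used in the docstring example:

```py
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")

pipeline.fuse_lora(lora_scale=0.7)  # bake the adapter into the base weights
image = pipeline("pixel art of a corgi").images[0]
pipeline.unfuse_lora()              # restore the original base weights
```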
- - - - We support loading A1111 formatted LoRA checkpoints in a limited capacity. - - This function is experimental and might change in the future. - - - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - Can be either: - - - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on - the Hub. - - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved - with [`ModelMixin.save_pretrained`]. - - A [torch state - dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). - - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory where a downloaded pretrained model configuration is cached if the standard cache - is not used. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - local_files_only (`bool`, *optional*, defaults to `False`): - Whether to only load local model weights and configuration files or not. If set to `True`, the model - won't be downloaded from the Hub. - token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from - `diffusers-cli login` (stored in `~/.huggingface`) is used. - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier - allowed by Git. - subfolder (`str`, *optional*, defaults to `""`): - The subfolder location of a model file within a larger model repository on the Hub or locally. - - """ - # Load the main state dict first which has the LoRA layers for either of - # transformer and text encoder or both. - cache_dir = kwargs.pop("cache_dir", None) - force_download = kwargs.pop("force_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", None) - token = kwargs.pop("token", None) - revision = kwargs.pop("revision", None) - subfolder = kwargs.pop("subfolder", None) - weight_name = kwargs.pop("weight_name", None) - use_safetensors = kwargs.pop("use_safetensors", None) - - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle = True - - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } - - state_dict = _fetch_state_dict( - pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, - weight_name=weight_name, - use_safetensors=use_safetensors, - local_files_only=local_files_only, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - allow_pickle=allow_pickle, - ) - - is_dora_scale_present = any("dora_scale" in k for k in state_dict) - if is_dora_scale_present: - warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." 
- logger.warning(warn_msg) - state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - - return state_dict - - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights - def load_lora_weights( - self, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - adapter_name: Optional[str] = None, - hotswap: bool = False, - **kwargs, - ): - """ - Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and - `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See - [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state - dict is loaded into `self.transformer`. - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random - weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - kwargs (`dict`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - """ - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) - if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." - ) - - # if a dict is passed, copy it instead of modifying it inplace - if isinstance(pretrained_model_name_or_path_or_dict, dict): - pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() - - # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) - - is_correct_format = all("lora" in key for key in state_dict.keys()) - if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") - - self.load_lora_into_transformer( - state_dict, - transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, - adapter_name=adapter_name, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - - @classmethod - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->MochiTransformer3DModel - def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False - ): - """ - This will load the LoRA layers specified in `state_dict` into `transformer`. - - Parameters: - state_dict (`dict`): - A standard state dict containing the lora layer parameters. The keys can either be indexed directly - into the unet or prefixed with an additional `unet` which can be used to distinguish between text - encoder lora layers. - transformer (`MochiTransformer3DModel`): - The Transformer model to load the LoRA layers into. 
-            adapter_name (`str`, *optional*):
-                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
-                `default_{i}` where i is the total number of adapters being loaded.
-            low_cpu_mem_usage (`bool`, *optional*):
-                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
-                weights.
-            hotswap (`bool`, *optional*):
-                See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
-        """
-        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
-            raise ValueError(
-                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
-            )
-
-        # Load the layers corresponding to transformer.
-        logger.info(f"Loading {cls.transformer_name}.")
-        transformer.load_lora_adapter(
-            state_dict,
-            network_alphas=None,
-            adapter_name=adapter_name,
-            _pipeline=_pipeline,
-            low_cpu_mem_usage=low_cpu_mem_usage,
-            hotswap=hotswap,
-        )
-
-    @classmethod
-    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
-    def save_lora_weights(
-        cls,
-        save_directory: Union[str, os.PathLike],
-        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
-        is_main_process: bool = True,
-        weight_name: str = None,
-        save_function: Callable = None,
-        safe_serialization: bool = True,
-    ):
-        r"""
-        Save the LoRA parameters corresponding to the transformer.
-
-        Arguments:
-            save_directory (`str` or `os.PathLike`):
-                Directory to save LoRA parameters to. Will be created if it doesn't exist.
-            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
-                State dict of the LoRA layers corresponding to the `transformer`.
-            is_main_process (`bool`, *optional*, defaults to `True`):
-                Whether the process calling this is the main process or not. Useful during distributed training when
-                you need to call this function on all processes. In this case, set `is_main_process=True` only on the
-                main process to avoid race conditions.
-            save_function (`Callable`):
-                The function to use to save the state dictionary. Useful during distributed training when you need to
-                replace `torch.save` with another method. Can be configured with the environment variable
-                `DIFFUSERS_SAVE_MODE`.
-            safe_serialization (`bool`, *optional*, defaults to `True`):
-                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
-        """
-        state_dict = {}
-
-        if not transformer_lora_layers:
-            raise ValueError("You must pass `transformer_lora_layers`.")
-
-        if transformer_lora_layers:
-            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
-
-        # Save the model
-        cls.write_lora_layers(
-            state_dict=state_dict,
-            save_directory=save_directory,
-            is_main_process=is_main_process,
-            weight_name=weight_name,
-            save_function=save_function,
-            safe_serialization=safe_serialization,
-        )
-
-    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
-    def fuse_lora(
-        self,
-        components: List[str] = ["transformer"],
-        lora_scale: float = 1.0,
-        safe_fusing: bool = False,
-        adapter_names: Optional[List[str]] = None,
-        **kwargs,
-    ):
-        r"""
-        Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-        <Tip warning={true}>
-
-        This is an experimental API.
-
-        </Tip>
-
-        Args:
-            components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
-            lora_scale (`float`, defaults to 1.0):
-                Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`): - Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. - adapter_names (`List[str]`, *optional*): - Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. - - Example: - - ```py - from diffusers import DiffusionPipeline - import torch - - pipeline = DiffusionPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 - ).to("cuda") - pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") - pipeline.fuse_lora(lora_scale=0.7) - ``` - """ - super().fuse_lora( - components=components, - lora_scale=lora_scale, - safe_fusing=safe_fusing, - adapter_names=adapter_names, - **kwargs, - ) - - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora - def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): - r""" - Reverses the effect of - [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). - - - - This is an experimental API. - - - - Args: - components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. - unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. - """ - super().unfuse_lora(components=components, **kwargs) - - -class LTXVideoLoraLoaderMixin(LoraBaseMixin): - r""" - Load LoRA layers into [`LTXVideoTransformer3DModel`]. Specific to [`LTXPipeline`]. - """ - - _lora_loadable_modules = ["transformer"] - transformer_name = TRANSFORMER_NAME - - @classmethod - @validate_hf_hub_args - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict - def lora_state_dict( - cls, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - **kwargs, - ): - r""" - Return state dict for lora weights and the network alphas. - - - - We support loading A1111 formatted LoRA checkpoints in a limited capacity. - - This function is experimental and might change in the future. - - - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - Can be either: - - - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on - the Hub. - - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved - with [`ModelMixin.save_pretrained`]. - - A [torch state - dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). - - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory where a downloaded pretrained model configuration is cached if the standard cache - is not used. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - local_files_only (`bool`, *optional*, defaults to `False`): - Whether to only load local model weights and configuration files or not. If set to `True`, the model - won't be downloaded from the Hub. - token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. 
If `True`, the token generated from - `diffusers-cli login` (stored in `~/.huggingface`) is used. - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier - allowed by Git. - subfolder (`str`, *optional*, defaults to `""`): - The subfolder location of a model file within a larger model repository on the Hub or locally. - - """ - # Load the main state dict first which has the LoRA layers for either of - # transformer and text encoder or both. - cache_dir = kwargs.pop("cache_dir", None) - force_download = kwargs.pop("force_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", None) - token = kwargs.pop("token", None) - revision = kwargs.pop("revision", None) - subfolder = kwargs.pop("subfolder", None) - weight_name = kwargs.pop("weight_name", None) - use_safetensors = kwargs.pop("use_safetensors", None) - - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle = True - - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } - - state_dict = _fetch_state_dict( - pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, - weight_name=weight_name, - use_safetensors=use_safetensors, - local_files_only=local_files_only, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - allow_pickle=allow_pickle, - ) - - is_dora_scale_present = any("dora_scale" in k for k in state_dict) - if is_dora_scale_present: - warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." - logger.warning(warn_msg) - state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - - return state_dict - - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights - def load_lora_weights( - self, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - adapter_name: Optional[str] = None, - hotswap: bool = False, - **kwargs, - ): - """ - Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and - `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See - [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state - dict is loaded into `self.transformer`. - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random - weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - kwargs (`dict`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. 
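Taken together, these parameters cover the common loading paths. A short sketch of loading a LoRA into an LTX pipeline, assuming the public `Lightricks/LTX-Video` checkpoint; the LoRA repository and weight name are hypothetical placeholders:

```py
import torch
from diffusers import LTXPipeline

pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)

# "your-org/ltx-lora" and the weight name are placeholders for illustration only.
pipe.load_lora_weights(
    "your-org/ltx-lora",
    weight_name="ltx_lora.safetensors",
    adapter_name="my_style",
    low_cpu_mem_usage=True,  # skips random init of LoRA weights; requires peft>=0.13.0
)
```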
- """ - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) - if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." - ) - - # if a dict is passed, copy it instead of modifying it inplace - if isinstance(pretrained_model_name_or_path_or_dict, dict): - pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() - - # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) - - is_correct_format = all("lora" in key for key in state_dict.keys()) - if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") - - self.load_lora_into_transformer( - state_dict, - transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, - adapter_name=adapter_name, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - - @classmethod - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->LTXVideoTransformer3DModel - def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False - ): - """ - This will load the LoRA layers specified in `state_dict` into `transformer`. - - Parameters: - state_dict (`dict`): - A standard state dict containing the lora layer parameters. The keys can either be indexed directly - into the unet or prefixed with an additional `unet` which can be used to distinguish between text - encoder lora layers. - transformer (`LTXVideoTransformer3DModel`): - The Transformer model to load the LoRA layers into. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random - weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - """ - if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." - ) - - # Load the layers corresponding to transformer. - logger.info(f"Loading {cls.transformer_name}.") - transformer.load_lora_adapter( - state_dict, - network_alphas=None, - adapter_name=adapter_name, - _pipeline=_pipeline, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - - @classmethod - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights - def save_lora_weights( - cls, - save_directory: Union[str, os.PathLike], - transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, - is_main_process: bool = True, - weight_name: str = None, - save_function: Callable = None, - safe_serialization: bool = True, - ): - r""" - Save the LoRA parameters corresponding to the UNet and text encoder. - - Arguments: - save_directory (`str` or `os.PathLike`): - Directory to save LoRA parameters to. Will be created if it doesn't exist. 
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): - State dict of the LoRA layers corresponding to the `transformer`. - is_main_process (`bool`, *optional*, defaults to `True`): - Whether the process calling this is the main process or not. Useful during distributed training and you - need to call this function on all processes. In this case, set `is_main_process=True` only on the main - process to avoid race conditions. - save_function (`Callable`): - The function to use to save the state dictionary. Useful during distributed training when you need to - replace `torch.save` with another method. Can be configured with the environment variable - `DIFFUSERS_SAVE_MODE`. - safe_serialization (`bool`, *optional*, defaults to `True`): - Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. - """ - state_dict = {} - - if not transformer_lora_layers: - raise ValueError("You must pass `transformer_lora_layers`.") - - if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) - - # Save the model - cls.write_lora_layers( - state_dict=state_dict, - save_directory=save_directory, - is_main_process=is_main_process, - weight_name=weight_name, - save_function=save_function, - safe_serialization=safe_serialization, - ) - - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora - def fuse_lora( - self, - components: List[str] = ["transformer"], - lora_scale: float = 1.0, - safe_fusing: bool = False, - adapter_names: Optional[List[str]] = None, - **kwargs, - ): - r""" - Fuses the LoRA parameters into the original parameters of the corresponding blocks. - - - - This is an experimental API. - - - - Args: - components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. - lora_scale (`float`, defaults to 1.0): - Controls how much to influence the outputs with the LoRA parameters. - safe_fusing (`bool`, defaults to `False`): - Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. - adapter_names (`List[str]`, *optional*): - Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. - - Example: - - ```py - from diffusers import DiffusionPipeline - import torch - - pipeline = DiffusionPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 - ).to("cuda") - pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") - pipeline.fuse_lora(lora_scale=0.7) - ``` - """ - super().fuse_lora( - components=components, - lora_scale=lora_scale, - safe_fusing=safe_fusing, - adapter_names=adapter_names, - **kwargs, - ) - - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora - def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): - r""" - Reverses the effect of - [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). - - - - This is an experimental API. - - - - Args: - components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. - unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. - """ - super().unfuse_lora(components=components, **kwargs) - - -class SanaLoraLoaderMixin(LoraBaseMixin): - r""" - Load LoRA layers into [`SanaTransformer2DModel`]. Specific to [`SanaPipeline`]. 
- """ - - _lora_loadable_modules = ["transformer"] - transformer_name = TRANSFORMER_NAME - - @classmethod - @validate_hf_hub_args - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict - def lora_state_dict( - cls, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - **kwargs, - ): - r""" - Return state dict for lora weights and the network alphas. - - - - We support loading A1111 formatted LoRA checkpoints in a limited capacity. - - This function is experimental and might change in the future. - - - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - Can be either: - - - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on - the Hub. - - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved - with [`ModelMixin.save_pretrained`]. - - A [torch state - dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). - - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory where a downloaded pretrained model configuration is cached if the standard cache - is not used. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - local_files_only (`bool`, *optional*, defaults to `False`): - Whether to only load local model weights and configuration files or not. If set to `True`, the model - won't be downloaded from the Hub. - token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from - `diffusers-cli login` (stored in `~/.huggingface`) is used. - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier - allowed by Git. - subfolder (`str`, *optional*, defaults to `""`): - The subfolder location of a model file within a larger model repository on the Hub or locally. - - """ - # Load the main state dict first which has the LoRA layers for either of - # transformer and text encoder or both. 
- cache_dir = kwargs.pop("cache_dir", None) - force_download = kwargs.pop("force_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", None) - token = kwargs.pop("token", None) - revision = kwargs.pop("revision", None) - subfolder = kwargs.pop("subfolder", None) - weight_name = kwargs.pop("weight_name", None) - use_safetensors = kwargs.pop("use_safetensors", None) - - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle = True - - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } - - state_dict = _fetch_state_dict( - pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, - weight_name=weight_name, - use_safetensors=use_safetensors, - local_files_only=local_files_only, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - allow_pickle=allow_pickle, - ) - - is_dora_scale_present = any("dora_scale" in k for k in state_dict) - if is_dora_scale_present: - warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." - logger.warning(warn_msg) - state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - - return state_dict - - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights - def load_lora_weights( - self, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - adapter_name: Optional[str] = None, - hotswap: bool = False, - **kwargs, - ): - """ - Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and - `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See - [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state - dict is loaded into `self.transformer`. - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random - weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - kwargs (`dict`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - """ - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) - if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." 
- ) - - # if a dict is passed, copy it instead of modifying it inplace - if isinstance(pretrained_model_name_or_path_or_dict, dict): - pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() - - # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) - - is_correct_format = all("lora" in key for key in state_dict.keys()) - if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") - - self.load_lora_into_transformer( - state_dict, - transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, - adapter_name=adapter_name, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - - @classmethod - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SanaTransformer2DModel - def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False - ): - """ - This will load the LoRA layers specified in `state_dict` into `transformer`. - - Parameters: - state_dict (`dict`): - A standard state dict containing the lora layer parameters. The keys can either be indexed directly - into the unet or prefixed with an additional `unet` which can be used to distinguish between text - encoder lora layers. - transformer (`SanaTransformer2DModel`): - The Transformer model to load the LoRA layers into. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random - weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - """ - if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." - ) - - # Load the layers corresponding to transformer. - logger.info(f"Loading {cls.transformer_name}.") - transformer.load_lora_adapter( - state_dict, - network_alphas=None, - adapter_name=adapter_name, - _pipeline=_pipeline, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - - @classmethod - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights - def save_lora_weights( - cls, - save_directory: Union[str, os.PathLike], - transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, - is_main_process: bool = True, - weight_name: str = None, - save_function: Callable = None, - safe_serialization: bool = True, - ): - r""" - Save the LoRA parameters corresponding to the UNet and text encoder. - - Arguments: - save_directory (`str` or `os.PathLike`): - Directory to save LoRA parameters to. Will be created if it doesn't exist. - transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): - State dict of the LoRA layers corresponding to the `transformer`. - is_main_process (`bool`, *optional*, defaults to `True`): - Whether the process calling this is the main process or not. Useful during distributed training and you - need to call this function on all processes. 
In this case, set `is_main_process=True` only on the main - process to avoid race conditions. - save_function (`Callable`): - The function to use to save the state dictionary. Useful during distributed training when you need to - replace `torch.save` with another method. Can be configured with the environment variable - `DIFFUSERS_SAVE_MODE`. - safe_serialization (`bool`, *optional*, defaults to `True`): - Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. - """ - state_dict = {} - - if not transformer_lora_layers: - raise ValueError("You must pass `transformer_lora_layers`.") - - if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) - - # Save the model - cls.write_lora_layers( - state_dict=state_dict, - save_directory=save_directory, - is_main_process=is_main_process, - weight_name=weight_name, - save_function=save_function, - safe_serialization=safe_serialization, - ) - - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora - def fuse_lora( - self, - components: List[str] = ["transformer"], - lora_scale: float = 1.0, - safe_fusing: bool = False, - adapter_names: Optional[List[str]] = None, - **kwargs, - ): - r""" - Fuses the LoRA parameters into the original parameters of the corresponding blocks. - - - - This is an experimental API. - - - - Args: - components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. - lora_scale (`float`, defaults to 1.0): - Controls how much to influence the outputs with the LoRA parameters. - safe_fusing (`bool`, defaults to `False`): - Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. - adapter_names (`List[str]`, *optional*): - Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. - - Example: - - ```py - from diffusers import DiffusionPipeline - import torch - - pipeline = DiffusionPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 - ).to("cuda") - pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") - pipeline.fuse_lora(lora_scale=0.7) - ``` - """ - super().fuse_lora( - components=components, - lora_scale=lora_scale, - safe_fusing=safe_fusing, - adapter_names=adapter_names, - **kwargs, - ) - - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora - def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): - r""" - Reverses the effect of - [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). - - - - This is an experimental API. - - - - Args: - components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. - unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. - """ - super().unfuse_lora(components=components, **kwargs) - - -class HunyuanVideoLoraLoaderMixin(LoraBaseMixin): - r""" - Load LoRA layers into [`HunyuanVideoTransformer3DModel`]. Specific to [`HunyuanVideoPipeline`]. - """ - - _lora_loadable_modules = ["transformer"] - transformer_name = TRANSFORMER_NAME - - @classmethod - @validate_hf_hub_args - def lora_state_dict( - cls, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - **kwargs, - ): - r""" - Return state dict for lora weights and the network alphas. 
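Unlike the copied variants above, this override also normalizes original-format HunyuanVideo checkpoints. As the method body below shows, the detection is purely key-based; a standalone sketch of the same guard on a toy state dict (the key name is illustrative only):

```py
import torch

def looks_like_original_hunyuan_video(state_dict):
    # Original-format checkpoints use fused attention keys such as "img_attn_qkv";
    # diffusers-format checkpoints do not, so one matching key is enough to decide.
    return any("img_attn_qkv" in k for k in state_dict)

toy = {"double_blocks.0.img_attn_qkv.lora_A.weight": torch.zeros(4, 16)}  # toy key
print(looks_like_original_hunyuan_video(toy))  # True -> would be converted
```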
- - - - We support loading original format HunyuanVideo LoRA checkpoints. - - This function is experimental and might change in the future. - - - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - Can be either: - - - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on - the Hub. - - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved - with [`ModelMixin.save_pretrained`]. - - A [torch state - dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). - - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory where a downloaded pretrained model configuration is cached if the standard cache - is not used. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - local_files_only (`bool`, *optional*, defaults to `False`): - Whether to only load local model weights and configuration files or not. If set to `True`, the model - won't be downloaded from the Hub. - token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from - `diffusers-cli login` (stored in `~/.huggingface`) is used. - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier - allowed by Git. - subfolder (`str`, *optional*, defaults to `""`): - The subfolder location of a model file within a larger model repository on the Hub or locally. - - """ - # Load the main state dict first which has the LoRA layers for either of - # transformer and text encoder or both. - cache_dir = kwargs.pop("cache_dir", None) - force_download = kwargs.pop("force_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", None) - token = kwargs.pop("token", None) - revision = kwargs.pop("revision", None) - subfolder = kwargs.pop("subfolder", None) - weight_name = kwargs.pop("weight_name", None) - use_safetensors = kwargs.pop("use_safetensors", None) - - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle = True - - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } - - state_dict = _fetch_state_dict( - pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, - weight_name=weight_name, - use_safetensors=use_safetensors, - local_files_only=local_files_only, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - allow_pickle=allow_pickle, - ) - - is_dora_scale_present = any("dora_scale" in k for k in state_dict) - if is_dora_scale_present: - warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." 
- logger.warning(warn_msg) - state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - - is_original_hunyuan_video = any("img_attn_qkv" in k for k in state_dict) - if is_original_hunyuan_video: - state_dict = _convert_hunyuan_video_lora_to_diffusers(state_dict) - - return state_dict - - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights - def load_lora_weights( - self, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - adapter_name: Optional[str] = None, - hotswap: bool = False, - **kwargs, - ): - """ - Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and - `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See - [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state - dict is loaded into `self.transformer`. - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random - weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - kwargs (`dict`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - """ - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) - if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." - ) - - # if a dict is passed, copy it instead of modifying it inplace - if isinstance(pretrained_model_name_or_path_or_dict, dict): - pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() - - # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) - - is_correct_format = all("lora" in key for key in state_dict.keys()) - if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") - - self.load_lora_into_transformer( - state_dict, - transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, - adapter_name=adapter_name, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - - @classmethod - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HunyuanVideoTransformer3DModel - def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False - ): - """ - This will load the LoRA layers specified in `state_dict` into `transformer`. - - Parameters: - state_dict (`dict`): - A standard state dict containing the lora layer parameters. 
The keys can either be indexed directly
-                into the transformer or prefixed with an additional `transformer`, which can be used to distinguish
-                them from text encoder LoRA layers.
-            transformer (`HunyuanVideoTransformer3DModel`):
-                The Transformer model to load the LoRA layers into.
-            adapter_name (`str`, *optional*):
-                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
-                `default_{i}` where i is the total number of adapters being loaded.
-            low_cpu_mem_usage (`bool`, *optional*):
-                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
-                weights.
-            hotswap (`bool`, *optional*):
-                See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
-        """
-        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
-            raise ValueError(
-                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
-            )
-
-        # Load the layers corresponding to transformer.
-        logger.info(f"Loading {cls.transformer_name}.")
-        transformer.load_lora_adapter(
-            state_dict,
-            network_alphas=None,
-            adapter_name=adapter_name,
-            _pipeline=_pipeline,
-            low_cpu_mem_usage=low_cpu_mem_usage,
-            hotswap=hotswap,
-        )
-
-    @classmethod
-    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
-    def save_lora_weights(
-        cls,
-        save_directory: Union[str, os.PathLike],
-        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
-        is_main_process: bool = True,
-        weight_name: str = None,
-        save_function: Callable = None,
-        safe_serialization: bool = True,
-    ):
-        r"""
-        Save the LoRA parameters corresponding to the transformer.
-
-        Arguments:
-            save_directory (`str` or `os.PathLike`):
-                Directory to save LoRA parameters to. Will be created if it doesn't exist.
-            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
-                State dict of the LoRA layers corresponding to the `transformer`.
-            is_main_process (`bool`, *optional*, defaults to `True`):
-                Whether the process calling this is the main process or not. Useful during distributed training when
-                you need to call this function on all processes. In this case, set `is_main_process=True` only on the
-                main process to avoid race conditions.
-            save_function (`Callable`):
-                The function to use to save the state dictionary. Useful during distributed training when you need to
-                replace `torch.save` with another method. Can be configured with the environment variable
-                `DIFFUSERS_SAVE_MODE`.
-            safe_serialization (`bool`, *optional*, defaults to `True`):
-                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
-        """
-        state_dict = {}
-
-        if not transformer_lora_layers:
-            raise ValueError("You must pass `transformer_lora_layers`.")
-
-        if transformer_lora_layers:
-            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
-
-        # Save the model
-        cls.write_lora_layers(
-            state_dict=state_dict,
-            save_directory=save_directory,
-            is_main_process=is_main_process,
-            weight_name=weight_name,
-            save_function=save_function,
-            safe_serialization=safe_serialization,
-        )
-
-    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
-    def fuse_lora(
-        self,
-        components: List[str] = ["transformer"],
-        lora_scale: float = 1.0,
-        safe_fusing: bool = False,
-        adapter_names: Optional[List[str]] = None,
-        **kwargs,
-    ):
-        r"""
-        Fuses the LoRA parameters into the original parameters of the corresponding blocks.
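Fusing folds the LoRA deltas into the base weights so inference no longer pays per-step adapter overhead, and `unfuse_lora` restores the originals. A sketch of the roundtrip, assuming the `hunyuanvideo-community/HunyuanVideo` checkpoint; the LoRA repository is a hypothetical placeholder:

```py
import torch
from diffusers import HunyuanVideoPipeline

pipe = HunyuanVideoPipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo", torch_dtype=torch.bfloat16
)
pipe.load_lora_weights("your-org/hunyuan-lora")  # placeholder repository

pipe.fuse_lora(lora_scale=1.0)  # bake the adapter into the transformer weights
print(pipe.num_fused_loras)     # bookkeeping attribute from LoraBaseMixin

pipe.unfuse_lora()              # restore the original, unfused weights
```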
- - - - This is an experimental API. - - - - Args: - components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. - lora_scale (`float`, defaults to 1.0): - Controls how much to influence the outputs with the LoRA parameters. - safe_fusing (`bool`, defaults to `False`): - Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. - adapter_names (`List[str]`, *optional*): - Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. - - Example: - - ```py - from diffusers import DiffusionPipeline - import torch - - pipeline = DiffusionPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 - ).to("cuda") - pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") - pipeline.fuse_lora(lora_scale=0.7) - ``` - """ - super().fuse_lora( - components=components, - lora_scale=lora_scale, - safe_fusing=safe_fusing, - adapter_names=adapter_names, - **kwargs, - ) - - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora - def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): - r""" - Reverses the effect of - [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). - - - - This is an experimental API. - - - - Args: - components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. - unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. - """ - super().unfuse_lora(components=components, **kwargs) - - -class Lumina2LoraLoaderMixin(LoraBaseMixin): - r""" - Load LoRA layers into [`Lumina2Transformer2DModel`]. Specific to [`Lumina2Text2ImgPipeline`]. - """ - - _lora_loadable_modules = ["transformer"] - transformer_name = TRANSFORMER_NAME - - @classmethod - @validate_hf_hub_args - def lora_state_dict( - cls, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - **kwargs, - ): - r""" - Return state dict for lora weights and the network alphas. - - - - We support loading A1111 formatted LoRA checkpoints in a limited capacity. - - This function is experimental and might change in the future. - - - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - Can be either: - - - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on - the Hub. - - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved - with [`ModelMixin.save_pretrained`]. - - A [torch state - dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). - - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory where a downloaded pretrained model configuration is cached if the standard cache - is not used. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - local_files_only (`bool`, *optional*, defaults to `False`): - Whether to only load local model weights and configuration files or not. 
If set to `True`, the model - won't be downloaded from the Hub. - token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from - `diffusers-cli login` (stored in `~/.huggingface`) is used. - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier - allowed by Git. - subfolder (`str`, *optional*, defaults to `""`): - The subfolder location of a model file within a larger model repository on the Hub or locally. - - """ - # Load the main state dict first which has the LoRA layers for either of - # transformer and text encoder or both. - cache_dir = kwargs.pop("cache_dir", None) - force_download = kwargs.pop("force_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", None) - token = kwargs.pop("token", None) - revision = kwargs.pop("revision", None) - subfolder = kwargs.pop("subfolder", None) - weight_name = kwargs.pop("weight_name", None) - use_safetensors = kwargs.pop("use_safetensors", None) - - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle = True - - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } - - state_dict = _fetch_state_dict( - pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, - weight_name=weight_name, - use_safetensors=use_safetensors, - local_files_only=local_files_only, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - allow_pickle=allow_pickle, - ) - - is_dora_scale_present = any("dora_scale" in k for k in state_dict) - if is_dora_scale_present: - warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." - logger.warning(warn_msg) - state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - - # conversion. - non_diffusers = any(k.startswith("diffusion_model.") for k in state_dict) - if non_diffusers: - state_dict = _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict) - - return state_dict - - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights - def load_lora_weights( - self, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - adapter_name: Optional[str] = None, - hotswap: bool = False, - **kwargs, - ): - """ - Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and - `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See - [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state - dict is loaded into `self.transformer`. - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. 
- low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random - weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - kwargs (`dict`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - """ - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) - if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." - ) - - # if a dict is passed, copy it instead of modifying it inplace - if isinstance(pretrained_model_name_or_path_or_dict, dict): - pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() - - # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) - - is_correct_format = all("lora" in key for key in state_dict.keys()) - if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") - - self.load_lora_into_transformer( - state_dict, - transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, - adapter_name=adapter_name, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - - @classmethod - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->Lumina2Transformer2DModel - def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False - ): - """ - This will load the LoRA layers specified in `state_dict` into `transformer`. - - Parameters: - state_dict (`dict`): - A standard state dict containing the lora layer parameters. The keys can either be indexed directly - into the unet or prefixed with an additional `unet` which can be used to distinguish between text - encoder lora layers. - transformer (`Lumina2Transformer2DModel`): - The Transformer model to load the LoRA layers into. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random - weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - """ - if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." - ) - - # Load the layers corresponding to transformer. 
-        logger.info(f"Loading {cls.transformer_name}.")
-        transformer.load_lora_adapter(
-            state_dict,
-            network_alphas=None,
-            adapter_name=adapter_name,
-            _pipeline=_pipeline,
-            low_cpu_mem_usage=low_cpu_mem_usage,
-            hotswap=hotswap,
-        )
-
-    @classmethod
-    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
-    def save_lora_weights(
-        cls,
-        save_directory: Union[str, os.PathLike],
-        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
-        is_main_process: bool = True,
-        weight_name: str = None,
-        save_function: Callable = None,
-        safe_serialization: bool = True,
-    ):
-        r"""
-        Save the LoRA parameters corresponding to the transformer.
-
-        Arguments:
-            save_directory (`str` or `os.PathLike`):
-                Directory to save LoRA parameters to. Will be created if it doesn't exist.
-            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
-                State dict of the LoRA layers corresponding to the `transformer`.
-            is_main_process (`bool`, *optional*, defaults to `True`):
-                Whether the process calling this is the main process or not. Useful during distributed training when
-                you need to call this function on all processes. In this case, set `is_main_process=True` only on the
-                main process to avoid race conditions.
-            save_function (`Callable`):
-                The function to use to save the state dictionary. Useful during distributed training when you need to
-                replace `torch.save` with another method. Can be configured with the environment variable
-                `DIFFUSERS_SAVE_MODE`.
-            safe_serialization (`bool`, *optional*, defaults to `True`):
-                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
-        """
-        state_dict = {}
-
-        if not transformer_lora_layers:
-            raise ValueError("You must pass `transformer_lora_layers`.")
-
-        if transformer_lora_layers:
-            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
-
-        # Save the model
-        cls.write_lora_layers(
-            state_dict=state_dict,
-            save_directory=save_directory,
-            is_main_process=is_main_process,
-            weight_name=weight_name,
-            save_function=save_function,
-            safe_serialization=safe_serialization,
-        )
-
-    # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
-    def fuse_lora(
-        self,
-        components: List[str] = ["transformer"],
-        lora_scale: float = 1.0,
-        safe_fusing: bool = False,
-        adapter_names: Optional[List[str]] = None,
-        **kwargs,
-    ):
-        r"""
-        Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-        <Tip warning={true}>
-
-        This is an experimental API.
-
-        </Tip>
-
-        Args:
-            components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
-            lora_scale (`float`, defaults to 1.0):
-                Controls how much to influence the outputs with the LoRA parameters.
-            safe_fusing (`bool`, defaults to `False`):
-                Whether to check fused weights for NaN values before fusing, and to skip fusing them if NaNs are found.
-            adapter_names (`List[str]`, *optional*):
-                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
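Because fusing is per-adapter, `adapter_names` can restrict the operation to a subset of what is loaded, as in the sketch below. It assumes the `Alpha-VLLM/Lumina-Image-2.0` checkpoint; both LoRA repositories and adapter names are hypothetical placeholders:

```py
import torch
from diffusers import Lumina2Text2ImgPipeline

pipe = Lumina2Text2ImgPipeline.from_pretrained(
    "Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16
)

# Placeholder LoRA repositories for illustration.
pipe.load_lora_weights("your-org/lumina2-style-lora", adapter_name="style")
pipe.load_lora_weights("your-org/lumina2-detail-lora", adapter_name="detail")

# Activate both, but permanently fuse only the "style" adapter at half strength.
pipe.set_adapters(["style", "detail"])
pipe.fuse_lora(adapter_names=["style"], lora_scale=0.5)
```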
- - Example: - - ```py - from diffusers import DiffusionPipeline - import torch - - pipeline = DiffusionPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 - ).to("cuda") - pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") - pipeline.fuse_lora(lora_scale=0.7) - ``` - """ - super().fuse_lora( - components=components, - lora_scale=lora_scale, - safe_fusing=safe_fusing, - adapter_names=adapter_names, - **kwargs, - ) - - # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora - def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): - r""" - Reverses the effect of - [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). - - - - This is an experimental API. - - - - Args: - components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. - unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. - """ - super().unfuse_lora(components=components, **kwargs) - - -class WanLoraLoaderMixin(LoraBaseMixin): - r""" - Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and [`WanImageToVideoPipeline`]. - """ - - _lora_loadable_modules = ["transformer"] - transformer_name = TRANSFORMER_NAME - - @classmethod - @validate_hf_hub_args - def lora_state_dict( - cls, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - **kwargs, - ): - r""" - Return state dict for lora weights and the network alphas. - - - - We support loading A1111 formatted LoRA checkpoints in a limited capacity. - - This function is experimental and might change in the future. - - - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - Can be either: - - - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on - the Hub. - - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved - with [`ModelMixin.save_pretrained`]. - - A [torch state - dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). - - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory where a downloaded pretrained model configuration is cached if the standard cache - is not used. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - local_files_only (`bool`, *optional*, defaults to `False`): - Whether to only load local model weights and configuration files or not. If set to `True`, the model - won't be downloaded from the Hub. - token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from - `diffusers-cli login` (stored in `~/.huggingface`) is used. - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier - allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`): - The subfolder location of a model file within a larger model repository on the Hub or locally. - - """ - # Load the main state dict first which has the LoRA layers for either of - # transformer and text encoder or both. - cache_dir = kwargs.pop("cache_dir", None) - force_download = kwargs.pop("force_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", None) - token = kwargs.pop("token", None) - revision = kwargs.pop("revision", None) - subfolder = kwargs.pop("subfolder", None) - weight_name = kwargs.pop("weight_name", None) - use_safetensors = kwargs.pop("use_safetensors", None) - - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle = True - - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } - - state_dict = _fetch_state_dict( - pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, - weight_name=weight_name, - use_safetensors=use_safetensors, - local_files_only=local_files_only, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - allow_pickle=allow_pickle, - ) - if any(k.startswith("diffusion_model.") for k in state_dict): - state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict) - elif any(k.startswith("lora_unet_") for k in state_dict): - state_dict = _convert_musubi_wan_lora_to_diffusers(state_dict) - - is_dora_scale_present = any("dora_scale" in k for k in state_dict) - if is_dora_scale_present: - warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." - logger.warning(warn_msg) - state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - - return state_dict - - @classmethod - def _maybe_expand_t2v_lora_for_i2v( - cls, - transformer: torch.nn.Module, - state_dict, - ): - if transformer.config.image_dim is None: - return state_dict - - if any(k.startswith("transformer.blocks.") for k in state_dict): - num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict}) - is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict) - - if is_i2v_lora: - return state_dict - - for i in range(num_blocks): - for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]): - state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like( - state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"] - ) - state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like( - state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"] - ) - - return state_dict - - def load_lora_weights( - self, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - adapter_name: Optional[str] = None, - hotswap: bool = False, - **kwargs, - ): - """ - Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and - `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See - [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. 
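Before the loader methods continue, a standalone sketch (not library code) of the `_maybe_expand_t2v_lora_for_i2v` padding above: the key names are invented but shaped like Wan transformer LoRA keys, and zero-initialized `lora_A`/`lora_B` pairs leave the i2v-only image projections unchanged at inference:

```py
import torch

# A text-to-video Wan LoRA: no entries for the image cross-attention
# projections (add_k_proj / add_v_proj) used by the i2v transformer.
t2v_lora = {
    "transformer.blocks.0.attn2.to_k.lora_A.weight": torch.randn(4, 64),
    "transformer.blocks.0.attn2.to_k.lora_B.weight": torch.randn(64, 4),
}

def expand_for_i2v(state_dict):
    # One zero-filled A/B pair per block and per missing projection;
    # zero matrices contribute nothing to the fused weights.
    num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict})
    for i in range(num_blocks):
        for proj in ("add_k_proj", "add_v_proj"):
            for mat in ("lora_A", "lora_B"):
                ref = state_dict[f"transformer.blocks.{i}.attn2.to_k.{mat}.weight"]
                state_dict[f"transformer.blocks.{i}.attn2.{proj}.{mat}.weight"] = torch.zeros_like(ref)
    return state_dict

expanded = expand_for_i2v(dict(t2v_lora))
assert "transformer.blocks.0.attn2.add_k_proj.lora_B.weight" in expanded
```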
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state - dict is loaded into `self.transformer`. - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random - weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - kwargs (`dict`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - """ - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) - if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." - ) - - # if a dict is passed, copy it instead of modifying it inplace - if isinstance(pretrained_model_name_or_path_or_dict, dict): - pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() - - # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) - # convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers - state_dict = self._maybe_expand_t2v_lora_for_i2v( - transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, - state_dict=state_dict, - ) - is_correct_format = all("lora" in key for key in state_dict.keys()) - if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") - - self.load_lora_into_transformer( - state_dict, - transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, - adapter_name=adapter_name, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - - @classmethod - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel - def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False - ): - """ - This will load the LoRA layers specified in `state_dict` into `transformer`. - - Parameters: - state_dict (`dict`): - A standard state dict containing the lora layer parameters. The keys can either be indexed directly - into the unet or prefixed with an additional `unet` which can be used to distinguish between text - encoder lora layers. - transformer (`WanTransformer3DModel`): - The Transformer model to load the LoRA layers into. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random - weights. 
- hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - """ - if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." - ) - - # Load the layers corresponding to transformer. - logger.info(f"Loading {cls.transformer_name}.") - transformer.load_lora_adapter( - state_dict, - network_alphas=None, - adapter_name=adapter_name, - _pipeline=_pipeline, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - - @classmethod - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights - def save_lora_weights( - cls, - save_directory: Union[str, os.PathLike], - transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, - is_main_process: bool = True, - weight_name: str = None, - save_function: Callable = None, - safe_serialization: bool = True, - ): - r""" - Save the LoRA parameters corresponding to the UNet and text encoder. - - Arguments: - save_directory (`str` or `os.PathLike`): - Directory to save LoRA parameters to. Will be created if it doesn't exist. - transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): - State dict of the LoRA layers corresponding to the `transformer`. - is_main_process (`bool`, *optional*, defaults to `True`): - Whether the process calling this is the main process or not. Useful during distributed training and you - need to call this function on all processes. In this case, set `is_main_process=True` only on the main - process to avoid race conditions. - save_function (`Callable`): - The function to use to save the state dictionary. Useful during distributed training when you need to - replace `torch.save` with another method. Can be configured with the environment variable - `DIFFUSERS_SAVE_MODE`. - safe_serialization (`bool`, *optional*, defaults to `True`): - Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. - """ - state_dict = {} - - if not transformer_lora_layers: - raise ValueError("You must pass `transformer_lora_layers`.") - - if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) - - # Save the model - cls.write_lora_layers( - state_dict=state_dict, - save_directory=save_directory, - is_main_process=is_main_process, - weight_name=weight_name, - save_function=save_function, - safe_serialization=safe_serialization, - ) - - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora - def fuse_lora( - self, - components: List[str] = ["transformer"], - lora_scale: float = 1.0, - safe_fusing: bool = False, - adapter_names: Optional[List[str]] = None, - **kwargs, - ): - r""" - Fuses the LoRA parameters into the original parameters of the corresponding blocks. - - - - This is an experimental API. - - - - Args: - components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. - lora_scale (`float`, defaults to 1.0): - Controls how much to influence the outputs with the LoRA parameters. - safe_fusing (`bool`, defaults to `False`): - Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. - adapter_names (`List[str]`, *optional*): - Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. 
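As an aside before the fuse example below, a hedged sketch of the save path just shown, in which `pack_weights` namespaces the layer dict under the `transformer` prefix and `write_lora_layers` serializes it; the layer keys here are invented stand-ins for a trained LoRA:

```py
import torch
from diffusers import WanPipeline

# Invented stand-ins for trained LoRA tensors keyed by transformer module.
lora_layers = {
    "blocks.0.attn2.to_k.lora_A.weight": torch.zeros(4, 64),
    "blocks.0.attn2.to_k.lora_B.weight": torch.zeros(64, 4),
}

# save_lora_weights is a classmethod, so no pipeline instance is needed.
WanPipeline.save_lora_weights(
    save_directory="./wan-lora",
    transformer_lora_layers=lora_layers,
    weight_name="pytorch_lora_weights.safetensors",
)
```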
- - Example: - - ```py - from diffusers import DiffusionPipeline - import torch - - pipeline = DiffusionPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 - ).to("cuda") - pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") - pipeline.fuse_lora(lora_scale=0.7) - ``` - """ - super().fuse_lora( - components=components, - lora_scale=lora_scale, - safe_fusing=safe_fusing, - adapter_names=adapter_names, - **kwargs, - ) - - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora - def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): - r""" - Reverses the effect of - [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). - - - - This is an experimental API. - - - - Args: - components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. - unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. - """ - super().unfuse_lora(components=components, **kwargs) - - -class CogView4LoraLoaderMixin(LoraBaseMixin): - r""" - Load LoRA layers into [`CogView4Transformer2DModel`]. Specific to [`CogView4Pipeline`]. - """ - - _lora_loadable_modules = ["transformer"] - transformer_name = TRANSFORMER_NAME - - @classmethod - @validate_hf_hub_args - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict - def lora_state_dict( - cls, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - **kwargs, - ): - r""" - Return state dict for lora weights and the network alphas. - - - - We support loading A1111 formatted LoRA checkpoints in a limited capacity. - - This function is experimental and might change in the future. - - - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - Can be either: - - - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on - the Hub. - - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved - with [`ModelMixin.save_pretrained`]. - - A [torch state - dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). - - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory where a downloaded pretrained model configuration is cached if the standard cache - is not used. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - local_files_only (`bool`, *optional*, defaults to `False`): - Whether to only load local model weights and configuration files or not. If set to `True`, the model - won't be downloaded from the Hub. - token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from - `diffusers-cli login` (stored in `~/.huggingface`) is used. - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier - allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`): - The subfolder location of a model file within a larger model repository on the Hub or locally. - - """ - # Load the main state dict first which has the LoRA layers for either of - # transformer and text encoder or both. - cache_dir = kwargs.pop("cache_dir", None) - force_download = kwargs.pop("force_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", None) - token = kwargs.pop("token", None) - revision = kwargs.pop("revision", None) - subfolder = kwargs.pop("subfolder", None) - weight_name = kwargs.pop("weight_name", None) - use_safetensors = kwargs.pop("use_safetensors", None) - - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle = True - - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } - - state_dict = _fetch_state_dict( - pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, - weight_name=weight_name, - use_safetensors=use_safetensors, - local_files_only=local_files_only, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - allow_pickle=allow_pickle, - ) - - is_dora_scale_present = any("dora_scale" in k for k in state_dict) - if is_dora_scale_present: - warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." - logger.warning(warn_msg) - state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - - return state_dict - - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights - def load_lora_weights( - self, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - adapter_name: Optional[str] = None, - hotswap: bool = False, - **kwargs, - ): - """ - Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and - `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See - [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state - dict is loaded into `self.transformer`. - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random - weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - kwargs (`dict`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. 
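As an illustration of the `dora_scale` filtering performed by `lora_state_dict` above, a minimal sketch with invented keys; DoRA checkpoints carry an extra per-module scale tensor that the loader currently drops with a warning:

```py
import torch

state_dict = {
    "transformer.blocks.0.attn1.to_q.lora_A.weight": torch.zeros(4, 64),
    "transformer.blocks.0.attn1.to_q.lora_B.weight": torch.zeros(64, 4),
    "transformer.blocks.0.attn1.to_q.dora_scale": torch.ones(64),  # dropped
}
filtered = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
assert all("dora_scale" not in k for k in filtered)
```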
- """ - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) - if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." - ) - - # if a dict is passed, copy it instead of modifying it inplace - if isinstance(pretrained_model_name_or_path_or_dict, dict): - pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() - - # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) - - is_correct_format = all("lora" in key for key in state_dict.keys()) - if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") - - self.load_lora_into_transformer( - state_dict, - transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, - adapter_name=adapter_name, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - - @classmethod - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel - def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False - ): - """ - This will load the LoRA layers specified in `state_dict` into `transformer`. - - Parameters: - state_dict (`dict`): - A standard state dict containing the lora layer parameters. The keys can either be indexed directly - into the unet or prefixed with an additional `unet` which can be used to distinguish between text - encoder lora layers. - transformer (`CogView4Transformer2DModel`): - The Transformer model to load the LoRA layers into. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random - weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - """ - if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." - ) - - # Load the layers corresponding to transformer. - logger.info(f"Loading {cls.transformer_name}.") - transformer.load_lora_adapter( - state_dict, - network_alphas=None, - adapter_name=adapter_name, - _pipeline=_pipeline, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - - @classmethod - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights - def save_lora_weights( - cls, - save_directory: Union[str, os.PathLike], - transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, - is_main_process: bool = True, - weight_name: str = None, - save_function: Callable = None, - safe_serialization: bool = True, - ): - r""" - Save the LoRA parameters corresponding to the UNet and text encoder. - - Arguments: - save_directory (`str` or `os.PathLike`): - Directory to save LoRA parameters to. Will be created if it doesn't exist. 
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): - State dict of the LoRA layers corresponding to the `transformer`. - is_main_process (`bool`, *optional*, defaults to `True`): - Whether the process calling this is the main process or not. Useful during distributed training and you - need to call this function on all processes. In this case, set `is_main_process=True` only on the main - process to avoid race conditions. - save_function (`Callable`): - The function to use to save the state dictionary. Useful during distributed training when you need to - replace `torch.save` with another method. Can be configured with the environment variable - `DIFFUSERS_SAVE_MODE`. - safe_serialization (`bool`, *optional*, defaults to `True`): - Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. - """ - state_dict = {} - - if not transformer_lora_layers: - raise ValueError("You must pass `transformer_lora_layers`.") - - if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) - - # Save the model - cls.write_lora_layers( - state_dict=state_dict, - save_directory=save_directory, - is_main_process=is_main_process, - weight_name=weight_name, - save_function=save_function, - safe_serialization=safe_serialization, - ) - - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora - def fuse_lora( - self, - components: List[str] = ["transformer"], - lora_scale: float = 1.0, - safe_fusing: bool = False, - adapter_names: Optional[List[str]] = None, - **kwargs, - ): - r""" - Fuses the LoRA parameters into the original parameters of the corresponding blocks. - - - - This is an experimental API. - - - - Args: - components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. - lora_scale (`float`, defaults to 1.0): - Controls how much to influence the outputs with the LoRA parameters. - safe_fusing (`bool`, defaults to `False`): - Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. - adapter_names (`List[str]`, *optional*): - Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. - - Example: - - ```py - from diffusers import DiffusionPipeline - import torch - - pipeline = DiffusionPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 - ).to("cuda") - pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") - pipeline.fuse_lora(lora_scale=0.7) - ``` - """ - super().fuse_lora( - components=components, - lora_scale=lora_scale, - safe_fusing=safe_fusing, - adapter_names=adapter_names, - **kwargs, - ) - - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora - def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): - r""" - Reverses the effect of - [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). - - - - This is an experimental API. - - - - Args: - components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. - unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. - """ - super().unfuse_lora(components=components, **kwargs) - - -class HiDreamImageLoraLoaderMixin(LoraBaseMixin): - r""" - Load LoRA layers into [`HiDreamImageTransformer2DModel`]. Specific to [`HiDreamImagePipeline`]. 
- """ - - _lora_loadable_modules = ["transformer"] - transformer_name = TRANSFORMER_NAME - - @classmethod - @validate_hf_hub_args - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict - def lora_state_dict( - cls, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - **kwargs, - ): - r""" - Return state dict for lora weights and the network alphas. - - - - We support loading A1111 formatted LoRA checkpoints in a limited capacity. - - This function is experimental and might change in the future. - - - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - Can be either: - - - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on - the Hub. - - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved - with [`ModelMixin.save_pretrained`]. - - A [torch state - dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). - - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory where a downloaded pretrained model configuration is cached if the standard cache - is not used. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - local_files_only (`bool`, *optional*, defaults to `False`): - Whether to only load local model weights and configuration files or not. If set to `True`, the model - won't be downloaded from the Hub. - token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from - `diffusers-cli login` (stored in `~/.huggingface`) is used. - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier - allowed by Git. - subfolder (`str`, *optional*, defaults to `""`): - The subfolder location of a model file within a larger model repository on the Hub or locally. - - """ - # Load the main state dict first which has the LoRA layers for either of - # transformer and text encoder or both. 
- cache_dir = kwargs.pop("cache_dir", None) - force_download = kwargs.pop("force_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", None) - token = kwargs.pop("token", None) - revision = kwargs.pop("revision", None) - subfolder = kwargs.pop("subfolder", None) - weight_name = kwargs.pop("weight_name", None) - use_safetensors = kwargs.pop("use_safetensors", None) - - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle = True - - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } - - state_dict = _fetch_state_dict( - pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, - weight_name=weight_name, - use_safetensors=use_safetensors, - local_files_only=local_files_only, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - allow_pickle=allow_pickle, - ) - - is_dora_scale_present = any("dora_scale" in k for k in state_dict) - if is_dora_scale_present: - warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." - logger.warning(warn_msg) - state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - - return state_dict - - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights - def load_lora_weights( - self, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - adapter_name: Optional[str] = None, - hotswap: bool = False, - **kwargs, - ): - """ - Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and - `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See - [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state - dict is loaded into `self.transformer`. - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random - weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - kwargs (`dict`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - """ - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) - if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." 
- ) - - # if a dict is passed, copy it instead of modifying it inplace - if isinstance(pretrained_model_name_or_path_or_dict, dict): - pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() - - # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) - - is_correct_format = all("lora" in key for key in state_dict.keys()) - if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") - - self.load_lora_into_transformer( - state_dict, - transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, - adapter_name=adapter_name, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - - @classmethod - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HiDreamImageTransformer2DModel - def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False - ): - """ - This will load the LoRA layers specified in `state_dict` into `transformer`. - - Parameters: - state_dict (`dict`): - A standard state dict containing the lora layer parameters. The keys can either be indexed directly - into the unet or prefixed with an additional `unet` which can be used to distinguish between text - encoder lora layers. - transformer (`HiDreamImageTransformer2DModel`): - The Transformer model to load the LoRA layers into. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random - weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - """ - if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." - ) - - # Load the layers corresponding to transformer. - logger.info(f"Loading {cls.transformer_name}.") - transformer.load_lora_adapter( - state_dict, - network_alphas=None, - adapter_name=adapter_name, - _pipeline=_pipeline, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - - @classmethod - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights - def save_lora_weights( - cls, - save_directory: Union[str, os.PathLike], - transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, - is_main_process: bool = True, - weight_name: str = None, - save_function: Callable = None, - safe_serialization: bool = True, - ): - r""" - Save the LoRA parameters corresponding to the UNet and text encoder. - - Arguments: - save_directory (`str` or `os.PathLike`): - Directory to save LoRA parameters to. Will be created if it doesn't exist. - transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): - State dict of the LoRA layers corresponding to the `transformer`. - is_main_process (`bool`, *optional*, defaults to `True`): - Whether the process calling this is the main process or not. Useful during distributed training and you - need to call this function on all processes. 
In this case, set `is_main_process=True` only on the main - process to avoid race conditions. - save_function (`Callable`): - The function to use to save the state dictionary. Useful during distributed training when you need to - replace `torch.save` with another method. Can be configured with the environment variable - `DIFFUSERS_SAVE_MODE`. - safe_serialization (`bool`, *optional*, defaults to `True`): - Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. - """ - state_dict = {} - - if not transformer_lora_layers: - raise ValueError("You must pass `transformer_lora_layers`.") - - if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) - - # Save the model - cls.write_lora_layers( - state_dict=state_dict, - save_directory=save_directory, - is_main_process=is_main_process, - weight_name=weight_name, - save_function=save_function, - safe_serialization=safe_serialization, - ) - - # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora - def fuse_lora( - self, - components: List[str] = ["transformer"], - lora_scale: float = 1.0, - safe_fusing: bool = False, - adapter_names: Optional[List[str]] = None, - **kwargs, - ): - r""" - Fuses the LoRA parameters into the original parameters of the corresponding blocks. - - - - This is an experimental API. - - - - Args: - components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. - lora_scale (`float`, defaults to 1.0): - Controls how much to influence the outputs with the LoRA parameters. - safe_fusing (`bool`, defaults to `False`): - Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. - adapter_names (`List[str]`, *optional*): - Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. - - Example: - - ```py - from diffusers import DiffusionPipeline - import torch - - pipeline = DiffusionPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 - ).to("cuda") - pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") - pipeline.fuse_lora(lora_scale=0.7) - ``` - """ - super().fuse_lora( - components=components, - lora_scale=lora_scale, - safe_fusing=safe_fusing, - adapter_names=adapter_names, - **kwargs, - ) - - # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora - def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): - r""" - Reverses the effect of - [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). - - - - This is an experimental API. - - - - Args: - components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. - unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. - """ - super().unfuse_lora(components=components, **kwargs) +class HiDreamImageLoraLoaderMixin(HiDreamImageLoraLoaderMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `CogView4LoraLoaderMixin` from diffusers.loaders.lora_pipeline has been deprecated. Please use `from diffusers.loaders.lora.lora_pipeline import CogView4LoraLoaderMixin` instead." 
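+ # The base class name resolves to the relocated implementation imported from diffusers.loaders.lora.lora_pipeline, so this shim subclasses the real mixin; deprecate() below warns that the old import path is scheduled for removal in 0.36.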
+ deprecate("diffusers.loaders.lora_pipeline.CogView4LoraLoaderMixin", "0.36", deprecation_message) + super().__init__(*args, **kwargs) class LoraLoaderMixin(StableDiffusionLoraLoaderMixin): diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 50450ab7d880..708c118ff521 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -35,8 +35,8 @@ set_adapter_layers, set_weights_and_activate_adapters, ) -from .lora_base import _fetch_state_dict, _func_optionally_disable_offloading -from .unet_loader_utils import _maybe_expand_lora_scales +from .lora.lora_base import _fetch_state_dict, _func_optionally_disable_offloading +from .unet.unet_loader_utils import _maybe_expand_lora_scales logger = logging.get_logger(__name__) @@ -99,7 +99,7 @@ class PeftAdapterMixin: _prepare_lora_hotswap_kwargs: Optional[dict] = None @classmethod - # Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading + # Copied from diffusers.loaders.lora.lora_base.LoraBaseMixin._optionally_disable_offloading def _optionally_disable_offloading(cls, _pipeline): """ Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU. diff --git a/src/diffusers/loaders/single_file.py b/src/diffusers/loaders/single_file.py index c2843fc7406a..e6acefb43976 100644 --- a/src/diffusers/loaders/single_file.py +++ b/src/diffusers/loaders/single_file.py @@ -11,42 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import importlib -import inspect -import os - -import torch -from huggingface_hub import snapshot_download -from huggingface_hub.utils import LocalEntryNotFoundError, validate_hf_hub_args -from packaging import version -from typing_extensions import Self - -from ..utils import deprecate, is_transformers_available, logging -from .single_file_utils import ( - SingleFileComponentError, - _is_legacy_scheduler_kwargs, - _is_model_weights_in_cached_folder, - _legacy_load_clip_tokenizer, - _legacy_load_safety_checker, - _legacy_load_scheduler, - create_diffusers_clip_model_from_ldm, - create_diffusers_t5_model_from_checkpoint, - fetch_diffusers_config, - fetch_original_config, - is_clip_model_in_single_file, - is_t5_in_single_file, - load_single_file_checkpoint, -) - - -logger = logging.get_logger(__name__) - -# Legacy behaviour. `from_single_file` does not load the safety checker unless explicitly provided -SINGLE_FILE_OPTIONAL_COMPONENTS = ["safety_checker"] - -if is_transformers_available(): - import transformers - from transformers import PreTrainedModel, PreTrainedTokenizer +from ..utils import deprecate +from .single_file.single_file import FromSingleFileMixin def load_single_file_sub_model( @@ -64,502 +30,30 @@ def load_single_file_sub_model( disable_mmap=False, **kwargs, ): - if is_pipeline_module: - pipeline_module = getattr(pipelines, library_name) - class_obj = getattr(pipeline_module, class_name) - else: - # else we just import it from the library. 
- library = importlib.import_module(library_name) - class_obj = getattr(library, class_name) - - if is_transformers_available(): - transformers_version = version.parse(version.parse(transformers.__version__).base_version) - else: - transformers_version = "N/A" - - is_transformers_model = ( - is_transformers_available() - and issubclass(class_obj, PreTrainedModel) - and transformers_version >= version.parse("4.20.0") - ) - is_tokenizer = ( - is_transformers_available() - and issubclass(class_obj, PreTrainedTokenizer) - and transformers_version >= version.parse("4.20.0") + from .single_file.single_file import load_single_file_sub_model + + deprecation_message = "Importing `load_single_file_sub_model()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file import load_single_file_sub_model` instead." + deprecate("diffusers.loaders.single_file.load_single_file_sub_model", "0.36", deprecation_message) + + return load_single_file_sub_model( + library_name, + class_name, + name, + checkpoint, + pipelines, + is_pipeline_module, + cached_model_config_path, + original_config, + local_files_only, + torch_dtype, + is_legacy_loading, + disable_mmap, + **kwargs, ) - diffusers_module = importlib.import_module(__name__.split(".")[0]) - is_diffusers_single_file_model = issubclass(class_obj, diffusers_module.FromOriginalModelMixin) - is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin) - is_diffusers_scheduler = issubclass(class_obj, diffusers_module.SchedulerMixin) - - if is_diffusers_single_file_model: - load_method = getattr(class_obj, "from_single_file") - - # We cannot provide two different config options to the `from_single_file` method - # Here we have to ignore loading the config from `cached_model_config_path` if `original_config` is provided - if original_config: - cached_model_config_path = None - - loaded_sub_model = load_method( - pretrained_model_link_or_path_or_dict=checkpoint, - original_config=original_config, - config=cached_model_config_path, - subfolder=name, - torch_dtype=torch_dtype, - local_files_only=local_files_only, - disable_mmap=disable_mmap, - **kwargs, - ) - - elif is_transformers_model and is_clip_model_in_single_file(class_obj, checkpoint): - loaded_sub_model = create_diffusers_clip_model_from_ldm( - class_obj, - checkpoint=checkpoint, - config=cached_model_config_path, - subfolder=name, - torch_dtype=torch_dtype, - local_files_only=local_files_only, - is_legacy_loading=is_legacy_loading, - ) - - elif is_transformers_model and is_t5_in_single_file(checkpoint): - loaded_sub_model = create_diffusers_t5_model_from_checkpoint( - class_obj, - checkpoint=checkpoint, - config=cached_model_config_path, - subfolder=name, - torch_dtype=torch_dtype, - local_files_only=local_files_only, - ) - - elif is_tokenizer and is_legacy_loading: - loaded_sub_model = _legacy_load_clip_tokenizer( - class_obj, checkpoint=checkpoint, config=cached_model_config_path, local_files_only=local_files_only - ) - - elif is_diffusers_scheduler and (is_legacy_loading or _is_legacy_scheduler_kwargs(kwargs)): - loaded_sub_model = _legacy_load_scheduler( - class_obj, checkpoint=checkpoint, component_name=name, original_config=original_config, **kwargs - ) - - else: - if not hasattr(class_obj, "from_pretrained"): - raise ValueError( - ( - f"The component {class_obj.__name__} cannot be loaded as it does not seem to have" - " a supported loading method." 
- ) - ) - - loading_kwargs = {} - loading_kwargs.update( - { - "pretrained_model_name_or_path": cached_model_config_path, - "subfolder": name, - "local_files_only": local_files_only, - } - ) - - # Schedulers and Tokenizers don't make use of torch_dtype - # Skip passing it to those objects - if issubclass(class_obj, torch.nn.Module): - loading_kwargs.update({"torch_dtype": torch_dtype}) - - if is_diffusers_model or is_transformers_model: - if not _is_model_weights_in_cached_folder(cached_model_config_path, name): - raise SingleFileComponentError( - f"Failed to load {class_name}. Weights for this component appear to be missing in the checkpoint." - ) - - load_method = getattr(class_obj, "from_pretrained") - loaded_sub_model = load_method(**loading_kwargs) - - return loaded_sub_model - - -def _map_component_types_to_config_dict(component_types): - diffusers_module = importlib.import_module(__name__.split(".")[0]) - config_dict = {} - component_types.pop("self", None) - - if is_transformers_available(): - transformers_version = version.parse(version.parse(transformers.__version__).base_version) - else: - transformers_version = "N/A" - - for component_name, component_value in component_types.items(): - is_diffusers_model = issubclass(component_value[0], diffusers_module.ModelMixin) - is_scheduler_enum = component_value[0].__name__ == "KarrasDiffusionSchedulers" - is_scheduler = issubclass(component_value[0], diffusers_module.SchedulerMixin) - - is_transformers_model = ( - is_transformers_available() - and issubclass(component_value[0], PreTrainedModel) - and transformers_version >= version.parse("4.20.0") - ) - is_transformers_tokenizer = ( - is_transformers_available() - and issubclass(component_value[0], PreTrainedTokenizer) - and transformers_version >= version.parse("4.20.0") - ) - - if is_diffusers_model and component_name not in SINGLE_FILE_OPTIONAL_COMPONENTS: - config_dict[component_name] = ["diffusers", component_value[0].__name__] - - elif is_scheduler_enum or is_scheduler: - if is_scheduler_enum: - # Since we cannot fetch a scheduler config from the hub, we default to DDIMScheduler - # if the type hint is a KarrassDiffusionSchedulers enum - config_dict[component_name] = ["diffusers", "DDIMScheduler"] - - elif is_scheduler: - config_dict[component_name] = ["diffusers", component_value[0].__name__] - - elif ( - is_transformers_model or is_transformers_tokenizer - ) and component_name not in SINGLE_FILE_OPTIONAL_COMPONENTS: - config_dict[component_name] = ["transformers", component_value[0].__name__] - - else: - config_dict[component_name] = [None, None] - - return config_dict - - -def _infer_pipeline_config_dict(pipeline_class): - parameters = inspect.signature(pipeline_class.__init__).parameters - required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty} - component_types = pipeline_class._get_signature_types() - - # Ignore parameters that are not required for the pipeline - component_types = {k: v for k, v in component_types.items() if k in required_parameters} - config_dict = _map_component_types_to_config_dict(component_types) - - return config_dict - - -def _download_diffusers_model_config_from_hub( - pretrained_model_name_or_path, - cache_dir, - revision, - proxies, - force_download=None, - local_files_only=None, - token=None, -): - allow_patterns = ["**/*.json", "*.json", "*.txt", "**/*.txt", "**/*.model"] - cached_model_path = snapshot_download( - pretrained_model_name_or_path, - cache_dir=cache_dir, - revision=revision, - proxies=proxies, - 
force_download=force_download, - local_files_only=local_files_only, - token=token, - allow_patterns=allow_patterns, - ) - - return cached_model_path - - -class FromSingleFileMixin: - """ - Load model weights saved in the `.ckpt` format into a [`DiffusionPipeline`]. - """ - - @classmethod - @validate_hf_hub_args - def from_single_file(cls, pretrained_model_link_or_path, **kwargs) -> Self: - r""" - Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` or `.safetensors` - format. The pipeline is set in evaluation mode (`model.eval()`) by default. - - Parameters: - pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*): - Can be either: - - A link to the `.ckpt` file (for example - `"https://huggingface.co//blob/main/.ckpt"`) on the Hub. - - A path to a *file* containing all pipeline weights. - torch_dtype (`str` or `torch.dtype`, *optional*): - Override the default `torch.dtype` and load the model with another dtype. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory where a downloaded pretrained model configuration is cached if the standard cache - is not used. - - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - local_files_only (`bool`, *optional*, defaults to `False`): - Whether to only load local model weights and configuration files or not. If set to `True`, the model - won't be downloaded from the Hub. - token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from - `diffusers-cli login` (stored in `~/.huggingface`) is used. - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier - allowed by Git. - original_config_file (`str`, *optional*): - The path to the original config file that was used to train the model. If not provided, the config file - will be inferred from the checkpoint file. - config (`str`, *optional*): - Can be either: - - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline - hosted on the Hub. - - A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline - component configs in Diffusers format. - disable_mmap ('bool', *optional*, defaults to 'False'): - Whether to disable mmap when loading a Safetensors model. This option can perform better when the model - is on a network mount or hard drive. - kwargs (remaining dictionary of keyword arguments, *optional*): - Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline - class). The overwritten components are passed directly to the pipelines `__init__` method. See example - below for more information. - - Examples: - - ```py - >>> from diffusers import StableDiffusionPipeline - - >>> # Download pipeline from huggingface.co and cache. - >>> pipeline = StableDiffusionPipeline.from_single_file( - ... "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors" - ... 
) - - >>> # Download pipeline from local file - >>> # file is downloaded under ./v1-5-pruned-emaonly.ckpt - >>> pipeline = StableDiffusionPipeline.from_single_file("./v1-5-pruned-emaonly.ckpt") - - >>> # Enable float16 and move to GPU - >>> pipeline = StableDiffusionPipeline.from_single_file( - ... "https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt", - ... torch_dtype=torch.float16, - ... ) - >>> pipeline.to("cuda") - ``` - - """ - original_config_file = kwargs.pop("original_config_file", None) - config = kwargs.pop("config", None) - original_config = kwargs.pop("original_config", None) - - if original_config_file is not None: - deprecation_message = ( - "`original_config_file` argument is deprecated and will be removed in future versions. " - "Please use the `original_config` argument instead." - ) - deprecate("original_config_file", "1.0.0", deprecation_message) - original_config = original_config_file - - force_download = kwargs.pop("force_download", False) - proxies = kwargs.pop("proxies", None) - token = kwargs.pop("token", None) - cache_dir = kwargs.pop("cache_dir", None) - local_files_only = kwargs.pop("local_files_only", False) - revision = kwargs.pop("revision", None) - torch_dtype = kwargs.pop("torch_dtype", None) - disable_mmap = kwargs.pop("disable_mmap", False) - - is_legacy_loading = False - - if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): - logger.warning( - f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`." - ) - torch_dtype = torch.float32 - - # We shouldn't allow configuring individual model components through a Pipeline creation method - # These model kwargs should be deprecated - scaling_factor = kwargs.get("scaling_factor", None) - if scaling_factor is not None: - deprecation_message = ( - "Passing the `scaling_factor` argument to `from_single_file` is deprecated " - "and will be ignored in future versions." - ) - deprecate("scaling_factor", "1.0.0", deprecation_message) - - if original_config is not None: - original_config = fetch_original_config(original_config, local_files_only=local_files_only) - - from ..pipelines.pipeline_utils import _get_pipeline_class - - pipeline_class = _get_pipeline_class(cls, config=None) - - checkpoint = load_single_file_checkpoint( - pretrained_model_link_or_path, - force_download=force_download, - proxies=proxies, - token=token, - cache_dir=cache_dir, - local_files_only=local_files_only, - revision=revision, - disable_mmap=disable_mmap, - ) - - if config is None: - config = fetch_diffusers_config(checkpoint) - default_pretrained_model_config_name = config["pretrained_model_name_or_path"] - else: - default_pretrained_model_config_name = config - - if not os.path.isdir(default_pretrained_model_config_name): - # Provided config is a repo_id - if default_pretrained_model_config_name.count("/") > 1: - raise ValueError( - f'The provided config "{config}"' - " is neither a valid local path nor a valid repo id. Please check the parameter."
- ) - try: - # Attempt to download the config files for the pipeline - cached_model_config_path = _download_diffusers_model_config_from_hub( - default_pretrained_model_config_name, - cache_dir=cache_dir, - revision=revision, - proxies=proxies, - force_download=force_download, - local_files_only=local_files_only, - token=token, - ) - config_dict = pipeline_class.load_config(cached_model_config_path) - - except LocalEntryNotFoundError: - # `local_files_only=True` but a local diffusers format model config is not available in the cache - # If `original_config` is not provided, we need override `local_files_only` to False - # to fetch the config files from the hub so that we have a way - # to configure the pipeline components. - - if original_config is None: - logger.warning( - "`local_files_only` is True but no local configs were found for this checkpoint.\n" - "Attempting to download the necessary config files for this pipeline.\n" - ) - cached_model_config_path = _download_diffusers_model_config_from_hub( - default_pretrained_model_config_name, - cache_dir=cache_dir, - revision=revision, - proxies=proxies, - force_download=force_download, - local_files_only=False, - token=token, - ) - config_dict = pipeline_class.load_config(cached_model_config_path) - - else: - # For backwards compatibility - # If `original_config` is provided, then we need to assume we are using legacy loading for pipeline components - logger.warning( - "Detected legacy `from_single_file` loading behavior. Attempting to create the pipeline based on inferred components.\n" - "This may lead to errors if the model components are not correctly inferred. \n" - "To avoid this warning, please explicity pass the `config` argument to `from_single_file` with a path to a local diffusers model repo \n" - "e.g. `from_single_file(, config=) \n" - "or run `from_single_file` with `local_files_only=False` first to update the local cache directory with " - "the necessary config files.\n" - ) - is_legacy_loading = True - cached_model_config_path = None - - config_dict = _infer_pipeline_config_dict(pipeline_class) - config_dict["_class_name"] = pipeline_class.__name__ - - else: - # Provided config is a path to a local directory attempt to load directly. - cached_model_config_path = default_pretrained_model_config_name - config_dict = pipeline_class.load_config(cached_model_config_path) - - # pop out "_ignore_files" as it is only needed for download - config_dict.pop("_ignore_files", None) - - expected_modules, optional_kwargs = pipeline_class._get_signature_keys(cls) - passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} - passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs} - - init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) - init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict} - init_kwargs = {**init_kwargs, **passed_pipe_kwargs} - - from diffusers import pipelines - - # remove `null` components - def load_module(name, value): - if value[0] is None: - return False - if name in passed_class_obj and passed_class_obj[name] is None: - return False - if name in SINGLE_FILE_OPTIONAL_COMPONENTS: - return False - - return True - - init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)} - - for name, (library_name, class_name) in logging.tqdm( - sorted(init_dict.items()), desc="Loading pipeline components..." 
- ): - loaded_sub_model = None - is_pipeline_module = hasattr(pipelines, library_name) - - if name in passed_class_obj: - loaded_sub_model = passed_class_obj[name] - - else: - try: - loaded_sub_model = load_single_file_sub_model( - library_name=library_name, - class_name=class_name, - name=name, - checkpoint=checkpoint, - is_pipeline_module=is_pipeline_module, - cached_model_config_path=cached_model_config_path, - pipelines=pipelines, - torch_dtype=torch_dtype, - original_config=original_config, - local_files_only=local_files_only, - is_legacy_loading=is_legacy_loading, - disable_mmap=disable_mmap, - **kwargs, - ) - except SingleFileComponentError as e: - raise SingleFileComponentError( - ( - f"{e.message}\n" - f"Please load the component before passing it in as an argument to `from_single_file`.\n" - f"\n" - f"{name} = {class_name}.from_pretrained('...')\n" - f"pipe = {pipeline_class.__name__}.from_single_file(, {name}={name})\n" - f"\n" - ) - ) - - init_kwargs[name] = loaded_sub_model - - missing_modules = set(expected_modules) - set(init_kwargs.keys()) - passed_modules = list(passed_class_obj.keys()) - optional_modules = pipeline_class._optional_components - - if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules): - for module in missing_modules: - init_kwargs[module] = passed_class_obj.get(module, None) - elif len(missing_modules) > 0: - passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs - raise ValueError( - f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed." - ) - - # deprecated kwargs - load_safety_checker = kwargs.pop("load_safety_checker", None) - if load_safety_checker is not None: - deprecation_message = ( - "Please pass instances of `StableDiffusionSafetyChecker` and `AutoImageProcessor`" - "using the `safety_checker` and `feature_extractor` arguments in `from_single_file`" - ) - deprecate("load_safety_checker", "1.0.0", deprecation_message) - - safety_checker_components = _legacy_load_safety_checker(local_files_only, torch_dtype) - init_kwargs.update(safety_checker_components) - - pipe = pipeline_class(**init_kwargs) - return pipe +class FromSingleFileMixin(FromSingleFileMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `FromSingleFileMixin` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file import FromSingleFileMixin` instead." + deprecate("diffusers.loaders.single_file.FromSingleFileMixin", "0.36", deprecation_message) + super().__init__(*args, **kwargs) diff --git a/src/diffusers/loaders/single_file/__init__.py b/src/diffusers/loaders/single_file/__init__.py new file mode 100644 index 000000000000..827685218160 --- /dev/null +++ b/src/diffusers/loaders/single_file/__init__.py @@ -0,0 +1,8 @@ +from ...utils import is_torch_available, is_transformers_available + + +if is_torch_available(): + from .single_file_model import FromOriginalModelMixin + + if is_transformers_available(): + from .single_file import FromSingleFileMixin diff --git a/src/diffusers/loaders/single_file/single_file.py b/src/diffusers/loaders/single_file/single_file.py new file mode 100644 index 000000000000..26d94e7935ef --- /dev/null +++ b/src/diffusers/loaders/single_file/single_file.py @@ -0,0 +1,565 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib +import inspect +import os + +import torch +from huggingface_hub import snapshot_download +from huggingface_hub.utils import LocalEntryNotFoundError, validate_hf_hub_args +from packaging import version +from typing_extensions import Self + +from ...utils import deprecate, is_transformers_available, logging +from .single_file_utils import ( + SingleFileComponentError, + _is_legacy_scheduler_kwargs, + _is_model_weights_in_cached_folder, + _legacy_load_clip_tokenizer, + _legacy_load_safety_checker, + _legacy_load_scheduler, + create_diffusers_clip_model_from_ldm, + create_diffusers_t5_model_from_checkpoint, + fetch_diffusers_config, + fetch_original_config, + is_clip_model_in_single_file, + is_t5_in_single_file, + load_single_file_checkpoint, +) + + +logger = logging.get_logger(__name__) + +# Legacy behaviour. `from_single_file` does not load the safety checker unless explicitly provided +SINGLE_FILE_OPTIONAL_COMPONENTS = ["safety_checker"] + +if is_transformers_available(): + import transformers + from transformers import PreTrainedModel, PreTrainedTokenizer + + +def load_single_file_sub_model( + library_name, + class_name, + name, + checkpoint, + pipelines, + is_pipeline_module, + cached_model_config_path, + original_config=None, + local_files_only=False, + torch_dtype=None, + is_legacy_loading=False, + disable_mmap=False, + **kwargs, +): + if is_pipeline_module: + pipeline_module = getattr(pipelines, library_name) + class_obj = getattr(pipeline_module, class_name) + else: + # else we just import it from the library. 
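+        # For illustration (names are hypothetical): library_name="transformers" with
+        # class_name="CLIPTextModel" resolves to transformers.CLIPTextModel here; the
+        # pair comes from the pipeline's config entries (see _map_component_types_to_config_dict).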
+ library = importlib.import_module(library_name) + class_obj = getattr(library, class_name) + + if is_transformers_available(): + transformers_version = version.parse(version.parse(transformers.__version__).base_version) + else: + transformers_version = "N/A" + + is_transformers_model = ( + is_transformers_available() + and issubclass(class_obj, PreTrainedModel) + and transformers_version >= version.parse("4.20.0") + ) + is_tokenizer = ( + is_transformers_available() + and issubclass(class_obj, PreTrainedTokenizer) + and transformers_version >= version.parse("4.20.0") + ) + + diffusers_module = importlib.import_module(__name__.split(".")[0]) + is_diffusers_single_file_model = issubclass(class_obj, diffusers_module.FromOriginalModelMixin) + is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin) + is_diffusers_scheduler = issubclass(class_obj, diffusers_module.SchedulerMixin) + + if is_diffusers_single_file_model: + load_method = getattr(class_obj, "from_single_file") + + # We cannot provide two different config options to the `from_single_file` method + # Here we have to ignore loading the config from `cached_model_config_path` if `original_config` is provided + if original_config: + cached_model_config_path = None + + loaded_sub_model = load_method( + pretrained_model_link_or_path_or_dict=checkpoint, + original_config=original_config, + config=cached_model_config_path, + subfolder=name, + torch_dtype=torch_dtype, + local_files_only=local_files_only, + disable_mmap=disable_mmap, + **kwargs, + ) + + elif is_transformers_model and is_clip_model_in_single_file(class_obj, checkpoint): + loaded_sub_model = create_diffusers_clip_model_from_ldm( + class_obj, + checkpoint=checkpoint, + config=cached_model_config_path, + subfolder=name, + torch_dtype=torch_dtype, + local_files_only=local_files_only, + is_legacy_loading=is_legacy_loading, + ) + + elif is_transformers_model and is_t5_in_single_file(checkpoint): + loaded_sub_model = create_diffusers_t5_model_from_checkpoint( + class_obj, + checkpoint=checkpoint, + config=cached_model_config_path, + subfolder=name, + torch_dtype=torch_dtype, + local_files_only=local_files_only, + ) + + elif is_tokenizer and is_legacy_loading: + loaded_sub_model = _legacy_load_clip_tokenizer( + class_obj, checkpoint=checkpoint, config=cached_model_config_path, local_files_only=local_files_only + ) + + elif is_diffusers_scheduler and (is_legacy_loading or _is_legacy_scheduler_kwargs(kwargs)): + loaded_sub_model = _legacy_load_scheduler( + class_obj, checkpoint=checkpoint, component_name=name, original_config=original_config, **kwargs + ) + + else: + if not hasattr(class_obj, "from_pretrained"): + raise ValueError( + ( + f"The component {class_obj.__name__} cannot be loaded as it does not seem to have" + " a supported loading method." + ) + ) + + loading_kwargs = {} + loading_kwargs.update( + { + "pretrained_model_name_or_path": cached_model_config_path, + "subfolder": name, + "local_files_only": local_files_only, + } + ) + + # Schedulers and Tokenizers don't make use of torch_dtype + # Skip passing it to those objects + if issubclass(class_obj, torch.nn.Module): + loading_kwargs.update({"torch_dtype": torch_dtype}) + + if is_diffusers_model or is_transformers_model: + if not _is_model_weights_in_cached_folder(cached_model_config_path, name): + raise SingleFileComponentError( + f"Failed to load {class_name}. Weights for this component appear to be missing in the checkpoint." 
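+                        # raised e.g. when the pipeline config lists a component (say, a text
+                        # encoder) whose weights are absent from the single checkpoint file;
+                        # `from_single_file` re-raises it with guidance on passing the component in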
+                )
+
+            load_method = getattr(class_obj, "from_pretrained")
+            loaded_sub_model = load_method(**loading_kwargs)
+
+    return loaded_sub_model
+
+
+def _map_component_types_to_config_dict(component_types):
+    diffusers_module = importlib.import_module(__name__.split(".")[0])
+    config_dict = {}
+    component_types.pop("self", None)
+
+    if is_transformers_available():
+        transformers_version = version.parse(version.parse(transformers.__version__).base_version)
+    else:
+        transformers_version = "N/A"
+
+    for component_name, component_value in component_types.items():
+        is_diffusers_model = issubclass(component_value[0], diffusers_module.ModelMixin)
+        is_scheduler_enum = component_value[0].__name__ == "KarrasDiffusionSchedulers"
+        is_scheduler = issubclass(component_value[0], diffusers_module.SchedulerMixin)
+
+        is_transformers_model = (
+            is_transformers_available()
+            and issubclass(component_value[0], PreTrainedModel)
+            and transformers_version >= version.parse("4.20.0")
+        )
+        is_transformers_tokenizer = (
+            is_transformers_available()
+            and issubclass(component_value[0], PreTrainedTokenizer)
+            and transformers_version >= version.parse("4.20.0")
+        )
+
+        if is_diffusers_model and component_name not in SINGLE_FILE_OPTIONAL_COMPONENTS:
+            config_dict[component_name] = ["diffusers", component_value[0].__name__]
+
+        elif is_scheduler_enum or is_scheduler:
+            if is_scheduler_enum:
+                # Since we cannot fetch a scheduler config from the hub, we default to DDIMScheduler
+                # if the type hint is a KarrasDiffusionSchedulers enum
+                config_dict[component_name] = ["diffusers", "DDIMScheduler"]
+
+            elif is_scheduler:
+                config_dict[component_name] = ["diffusers", component_value[0].__name__]
+
+        elif (
+            is_transformers_model or is_transformers_tokenizer
+        ) and component_name not in SINGLE_FILE_OPTIONAL_COMPONENTS:
+            config_dict[component_name] = ["transformers", component_value[0].__name__]
+
+        else:
+            config_dict[component_name] = [None, None]
+
+    return config_dict
+
+
+def _infer_pipeline_config_dict(pipeline_class):
+    parameters = inspect.signature(pipeline_class.__init__).parameters
+    required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
+    component_types = pipeline_class._get_signature_types()
+
+    # Ignore parameters that are not required for the pipeline
+    component_types = {k: v for k, v in component_types.items() if k in required_parameters}
+    config_dict = _map_component_types_to_config_dict(component_types)
+
+    return config_dict
+
+
+def _download_diffusers_model_config_from_hub(
+    pretrained_model_name_or_path,
+    cache_dir,
+    revision,
+    proxies,
+    force_download=None,
+    local_files_only=None,
+    token=None,
+):
+    allow_patterns = ["**/*.json", "*.json", "*.txt", "**/*.txt", "**/*.model"]
+    cached_model_path = snapshot_download(
+        pretrained_model_name_or_path,
+        cache_dir=cache_dir,
+        revision=revision,
+        proxies=proxies,
+        force_download=force_download,
+        local_files_only=local_files_only,
+        token=token,
+        allow_patterns=allow_patterns,
+    )
+
+    return cached_model_path
+
+
+class FromSingleFileMixin:
+    """
+    Load model weights saved in the `.ckpt` format into a [`DiffusionPipeline`].
+    """
+
+    @classmethod
+    @validate_hf_hub_args
+    def from_single_file(cls, pretrained_model_link_or_path, **kwargs) -> Self:
+        r"""
+        Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` or `.safetensors`
+        format. The pipeline is set in evaluation mode (`model.eval()`) by default.
+
+        Parameters:
+            pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
+                Can be either:
+                    - A link to the `.ckpt` file (for example
+                      `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
+                    - A path to a *file* containing all pipeline weights.
+            torch_dtype (`str` or `torch.dtype`, *optional*):
+                Override the default `torch.dtype` and load the model with another dtype.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                allowed by Git.
+            original_config_file (`str`, *optional*):
+                The path to the original config file that was used to train the model. If not provided, the config file
+                will be inferred from the checkpoint file.
+            config (`str`, *optional*):
+                Can be either:
+                    - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline
+                      hosted on the Hub.
+                    - A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline
+                      component configs in Diffusers format.
+            disable_mmap (`bool`, *optional*, defaults to `False`):
+                Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
+                is on a network mount or hard drive.
+            kwargs (remaining dictionary of keyword arguments, *optional*):
+                Can be used to overwrite the loadable and saveable variables (the pipeline components of the specific
+                pipeline class). The overwritten components are passed directly to the pipeline's `__init__` method.
+                See example below for more information.
+
+        Examples:
+
+        ```py
+        >>> import torch
+        >>> from diffusers import StableDiffusionPipeline
+
+        >>> # Download pipeline from huggingface.co and cache.
+        >>> pipeline = StableDiffusionPipeline.from_single_file(
+        ...     "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
+        ... )
+
+        >>> # Load pipeline from a local file
+        >>> # assuming the file was previously downloaded to ./v1-5-pruned-emaonly.ckpt
+        >>> pipeline = StableDiffusionPipeline.from_single_file("./v1-5-pruned-emaonly.ckpt")
+
+        >>> # Enable float16 and move to GPU
+        >>> pipeline = StableDiffusionPipeline.from_single_file(
+        ...     "https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
+        ...     torch_dtype=torch.float16,
+        ... )
+        >>> pipeline.to("cuda")
+        ```
+
+        """
+        original_config_file = kwargs.pop("original_config_file", None)
+        config = kwargs.pop("config", None)
+        original_config = kwargs.pop("original_config", None)
+
+        if original_config_file is not None:
+            deprecation_message = (
+                "`original_config_file` argument is deprecated and will be removed in future versions. "
+                "Please use the `original_config` argument instead."
+            )
+            deprecate("original_config_file", "1.0.0", deprecation_message)
+            original_config = original_config_file
+
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        token = kwargs.pop("token", None)
+        cache_dir = kwargs.pop("cache_dir", None)
+        local_files_only = kwargs.pop("local_files_only", False)
+        revision = kwargs.pop("revision", None)
+        torch_dtype = kwargs.pop("torch_dtype", None)
+        disable_mmap = kwargs.pop("disable_mmap", False)
+
+        is_legacy_loading = False
+
+        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
+            logger.warning(
+                f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
+            )
+            torch_dtype = torch.float32
+
+        # We shouldn't allow configuring individual model components through a Pipeline creation method
+        # These model kwargs should be deprecated
+        scaling_factor = kwargs.get("scaling_factor", None)
+        if scaling_factor is not None:
+            deprecation_message = (
+                "Passing the `scaling_factor` argument to `from_single_file` is deprecated "
+                "and will be ignored in future versions."
+            )
+            deprecate("scaling_factor", "1.0.0", deprecation_message)
+
+        if original_config is not None:
+            original_config = fetch_original_config(original_config, local_files_only=local_files_only)
+
+        from ...pipelines.pipeline_utils import _get_pipeline_class
+
+        pipeline_class = _get_pipeline_class(cls, config=None)
+
+        checkpoint = load_single_file_checkpoint(
+            pretrained_model_link_or_path,
+            force_download=force_download,
+            proxies=proxies,
+            token=token,
+            cache_dir=cache_dir,
+            local_files_only=local_files_only,
+            revision=revision,
+            disable_mmap=disable_mmap,
+        )
+
+        if config is None:
+            config = fetch_diffusers_config(checkpoint)
+            default_pretrained_model_config_name = config["pretrained_model_name_or_path"]
+        else:
+            default_pretrained_model_config_name = config
+
+        if not os.path.isdir(default_pretrained_model_config_name):
+            # Provided config is a repo_id
+            if default_pretrained_model_config_name.count("/") > 1:
+                raise ValueError(
+                    f'The provided config "{config}"'
+                    " is neither a valid local path nor a valid repo id. Please check the parameter."
+                )
+            try:
+                # Attempt to download the config files for the pipeline
+                cached_model_config_path = _download_diffusers_model_config_from_hub(
+                    default_pretrained_model_config_name,
+                    cache_dir=cache_dir,
+                    revision=revision,
+                    proxies=proxies,
+                    force_download=force_download,
+                    local_files_only=local_files_only,
+                    token=token,
+                )
+                config_dict = pipeline_class.load_config(cached_model_config_path)
+
+            except LocalEntryNotFoundError:
+                # `local_files_only=True` but a local diffusers format model config is not available in the cache
+                # If `original_config` is not provided, we need to override `local_files_only` to False
+                # to fetch the config files from the hub so that we have a way
+                # to configure the pipeline components.
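+                # For example, a hypothetical StableDiffusionPipeline.from_single_file("ckpt.safetensors",
+                # local_files_only=True) lands here when the matching Diffusers-format configs
+                # were never cached locally.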
+
+                if original_config is None:
+                    logger.warning(
+                        "`local_files_only` is True but no local configs were found for this checkpoint.\n"
+                        "Attempting to download the necessary config files for this pipeline.\n"
+                    )
+                    cached_model_config_path = _download_diffusers_model_config_from_hub(
+                        default_pretrained_model_config_name,
+                        cache_dir=cache_dir,
+                        revision=revision,
+                        proxies=proxies,
+                        force_download=force_download,
+                        local_files_only=False,
+                        token=token,
+                    )
+                    config_dict = pipeline_class.load_config(cached_model_config_path)
+
+                else:
+                    # For backwards compatibility
+                    # If `original_config` is provided, then we need to assume we are using legacy loading for pipeline components
+                    logger.warning(
+                        "Detected legacy `from_single_file` loading behavior. Attempting to create the pipeline based on inferred components.\n"
+                        "This may lead to errors if the model components are not correctly inferred.\n"
+                        "To avoid this warning, please explicitly pass the `config` argument to `from_single_file` with a path to a local diffusers model repo,\n"
+                        "e.g. `from_single_file(<my model checkpoint path>, config=<path to local diffusers model repo>)`,\n"
+                        "or run `from_single_file` with `local_files_only=False` first to update the local cache directory with "
+                        "the necessary config files.\n"
+                    )
+                    is_legacy_loading = True
+                    cached_model_config_path = None
+
+                    config_dict = _infer_pipeline_config_dict(pipeline_class)
+                    config_dict["_class_name"] = pipeline_class.__name__
+
+        else:
+            # Provided config is a path to a local directory; attempt to load directly.
+            cached_model_config_path = default_pretrained_model_config_name
+            config_dict = pipeline_class.load_config(cached_model_config_path)
+
+        # pop out "_ignore_files" as it is only needed for download
+        config_dict.pop("_ignore_files", None)
+
+        expected_modules, optional_kwargs = pipeline_class._get_signature_keys(cls)
+        passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
+        passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
+
+        init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
+        init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
+        init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
+
+        from diffusers import pipelines
+
+        # remove `null` components
+        def load_module(name, value):
+            if value[0] is None:
+                return False
+            if name in passed_class_obj and passed_class_obj[name] is None:
+                return False
+            if name in SINGLE_FILE_OPTIONAL_COMPONENTS:
+                return False
+
+            return True
+
+        init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
+
+        for name, (library_name, class_name) in logging.tqdm(
+            sorted(init_dict.items()), desc="Loading pipeline components..."
+        ):
+            loaded_sub_model = None
+            is_pipeline_module = hasattr(pipelines, library_name)
+
+            if name in passed_class_obj:
+                loaded_sub_model = passed_class_obj[name]
+
+            else:
+                try:
+                    loaded_sub_model = load_single_file_sub_model(
+                        library_name=library_name,
+                        class_name=class_name,
+                        name=name,
+                        checkpoint=checkpoint,
+                        is_pipeline_module=is_pipeline_module,
+                        cached_model_config_path=cached_model_config_path,
+                        pipelines=pipelines,
+                        torch_dtype=torch_dtype,
+                        original_config=original_config,
+                        local_files_only=local_files_only,
+                        is_legacy_loading=is_legacy_loading,
+                        disable_mmap=disable_mmap,
+                        **kwargs,
+                    )
+                except SingleFileComponentError as e:
+                    raise SingleFileComponentError(
+                        (
+                            f"{e.message}\n"
+                            f"Please load the component before passing it in as an argument to `from_single_file`.\n"
+                            f"\n"
+                            f"{name} = {class_name}.from_pretrained('...')\n"
+                            f"pipe = {pipeline_class.__name__}.from_single_file(<checkpoint path>, {name}={name})\n"
+                            f"\n"
+                        )
+                    )
+
+            init_kwargs[name] = loaded_sub_model
+
+        missing_modules = set(expected_modules) - set(init_kwargs.keys())
+        passed_modules = list(passed_class_obj.keys())
+        optional_modules = pipeline_class._optional_components
+
+        if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules):
+            for module in missing_modules:
+                init_kwargs[module] = passed_class_obj.get(module, None)
+        elif len(missing_modules) > 0:
+            passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs
+            raise ValueError(
+                f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
+            )
+
+        # deprecated kwargs
+        load_safety_checker = kwargs.pop("load_safety_checker", None)
+        if load_safety_checker is not None:
+            deprecation_message = (
+                "Please pass instances of `StableDiffusionSafetyChecker` and `AutoImageProcessor` "
+                "using the `safety_checker` and `feature_extractor` arguments in `from_single_file`"
+            )
+            deprecate("load_safety_checker", "1.0.0", deprecation_message)
+
+            safety_checker_components = _legacy_load_safety_checker(local_files_only, torch_dtype)
+            init_kwargs.update(safety_checker_components)
+
+        pipe = pipeline_class(**init_kwargs)
+
+        return pipe
diff --git a/src/diffusers/loaders/single_file/single_file_model.py b/src/diffusers/loaders/single_file/single_file_model.py
new file mode 100644
index 000000000000..0b97d7fbfafd
--- /dev/null
+++ b/src/diffusers/loaders/single_file/single_file_model.py
@@ -0,0 +1,440 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import importlib
+import inspect
+import re
+from contextlib import nullcontext
+from typing import Optional
+
+import torch
+from huggingface_hub.utils import validate_hf_hub_args
+from typing_extensions import Self
+
+from ...
import __version__ +from ...quantizers import DiffusersAutoQuantizer +from ...utils import deprecate, is_accelerate_available, logging +from .single_file_utils import ( + SingleFileComponentError, + convert_animatediff_checkpoint_to_diffusers, + convert_auraflow_transformer_checkpoint_to_diffusers, + convert_autoencoder_dc_checkpoint_to_diffusers, + convert_controlnet_checkpoint, + convert_flux_transformer_checkpoint_to_diffusers, + convert_hunyuan_video_transformer_to_diffusers, + convert_ldm_unet_checkpoint, + convert_ldm_vae_checkpoint, + convert_ltx_transformer_checkpoint_to_diffusers, + convert_ltx_vae_checkpoint_to_diffusers, + convert_lumina2_to_diffusers, + convert_mochi_transformer_checkpoint_to_diffusers, + convert_sana_transformer_to_diffusers, + convert_sd3_transformer_checkpoint_to_diffusers, + convert_stable_cascade_unet_single_file_to_diffusers, + convert_wan_transformer_to_diffusers, + convert_wan_vae_to_diffusers, + create_controlnet_diffusers_config_from_ldm, + create_unet_diffusers_config_from_ldm, + create_vae_diffusers_config_from_ldm, + fetch_diffusers_config, + fetch_original_config, + load_single_file_checkpoint, +) + + +logger = logging.get_logger(__name__) + + +if is_accelerate_available(): + from accelerate import dispatch_model, init_empty_weights + + from ...models.modeling_utils import load_model_dict_into_meta + + +SINGLE_FILE_LOADABLE_CLASSES = { + "StableCascadeUNet": { + "checkpoint_mapping_fn": convert_stable_cascade_unet_single_file_to_diffusers, + }, + "UNet2DConditionModel": { + "checkpoint_mapping_fn": convert_ldm_unet_checkpoint, + "config_mapping_fn": create_unet_diffusers_config_from_ldm, + "default_subfolder": "unet", + "legacy_kwargs": { + "num_in_channels": "in_channels", # Legacy kwargs supported by `from_single_file` mapped to new args + }, + }, + "AutoencoderKL": { + "checkpoint_mapping_fn": convert_ldm_vae_checkpoint, + "config_mapping_fn": create_vae_diffusers_config_from_ldm, + "default_subfolder": "vae", + }, + "ControlNetModel": { + "checkpoint_mapping_fn": convert_controlnet_checkpoint, + "config_mapping_fn": create_controlnet_diffusers_config_from_ldm, + }, + "SD3Transformer2DModel": { + "checkpoint_mapping_fn": convert_sd3_transformer_checkpoint_to_diffusers, + "default_subfolder": "transformer", + }, + "MotionAdapter": { + "checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers, + }, + "SparseControlNetModel": { + "checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers, + }, + "FluxTransformer2DModel": { + "checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers, + "default_subfolder": "transformer", + }, + "LTXVideoTransformer3DModel": { + "checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers, + "default_subfolder": "transformer", + }, + "AutoencoderKLLTXVideo": { + "checkpoint_mapping_fn": convert_ltx_vae_checkpoint_to_diffusers, + "default_subfolder": "vae", + }, + "AutoencoderDC": {"checkpoint_mapping_fn": convert_autoencoder_dc_checkpoint_to_diffusers}, + "MochiTransformer3DModel": { + "checkpoint_mapping_fn": convert_mochi_transformer_checkpoint_to_diffusers, + "default_subfolder": "transformer", + }, + "HunyuanVideoTransformer3DModel": { + "checkpoint_mapping_fn": convert_hunyuan_video_transformer_to_diffusers, + "default_subfolder": "transformer", + }, + "AuraFlowTransformer2DModel": { + "checkpoint_mapping_fn": convert_auraflow_transformer_checkpoint_to_diffusers, + "default_subfolder": "transformer", + }, + "Lumina2Transformer2DModel": { + 
"checkpoint_mapping_fn": convert_lumina2_to_diffusers, + "default_subfolder": "transformer", + }, + "SanaTransformer2DModel": { + "checkpoint_mapping_fn": convert_sana_transformer_to_diffusers, + "default_subfolder": "transformer", + }, + "WanTransformer3DModel": { + "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers, + "default_subfolder": "transformer", + }, + "AutoencoderKLWan": { + "checkpoint_mapping_fn": convert_wan_vae_to_diffusers, + "default_subfolder": "vae", + }, +} + + +def _get_single_file_loadable_mapping_class(cls): + diffusers_module = importlib.import_module(__name__.split(".")[0]) + for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES: + loadable_class = getattr(diffusers_module, loadable_class_str) + + if issubclass(cls, loadable_class): + return loadable_class_str + + return None + + +def _get_mapping_function_kwargs(mapping_fn, **kwargs): + parameters = inspect.signature(mapping_fn).parameters + + mapping_kwargs = {} + for parameter in parameters: + if parameter in kwargs: + mapping_kwargs[parameter] = kwargs[parameter] + + return mapping_kwargs + + +class FromOriginalModelMixin: + """ + Load pretrained weights saved in the `.ckpt` or `.safetensors` format into a model. + """ + + @classmethod + @validate_hf_hub_args + def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = None, **kwargs) -> Self: + r""" + Instantiate a model from pretrained weights saved in the original `.ckpt` or `.safetensors` format. The model + is set in evaluation mode (`model.eval()`) by default. + + Parameters: + pretrained_model_link_or_path_or_dict (`str`, *optional*): + Can be either: + - A link to the `.safetensors` or `.ckpt` file (for example + `"https://huggingface.co//blob/main/.safetensors"`) on the Hub. + - A path to a local *file* containing the weights of the component model. + - A state dict containing the component model weights. + config (`str`, *optional*): + - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline hosted + on the Hub. + - A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline component + configs in Diffusers format. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + original_config (`str`, *optional*): + Dict or path to a yaml file containing the configuration for the model in its original format. + If a dict is provided, it will be used to initialize the model configuration. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the + dtype is automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. 
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                allowed by Git.
+            disable_mmap (`bool`, *optional*, defaults to `False`):
+                Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
+                is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
+            kwargs (remaining dictionary of keyword arguments, *optional*):
+                Can be used to overwrite the loadable and saveable variables (for example the pipeline components of
+                the specific pipeline class). The overwritten components are directly passed to the pipeline's
+                `__init__` method. See example below for more information.
+
+        ```py
+        >>> from diffusers import StableCascadeUNet
+
+        >>> ckpt_path = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b_lite.safetensors"
+        >>> model = StableCascadeUNet.from_single_file(ckpt_path)
+        ```
+        """
+
+        mapping_class_name = _get_single_file_loadable_mapping_class(cls)
+        if mapping_class_name is None:
+            raise ValueError(
+                f"FromOriginalModelMixin is currently only compatible with {', '.join(SINGLE_FILE_LOADABLE_CLASSES.keys())}"
+            )
+
+        pretrained_model_link_or_path = kwargs.get("pretrained_model_link_or_path", None)
+        if pretrained_model_link_or_path is not None:
+            deprecation_message = (
+                "Please use the `pretrained_model_link_or_path_or_dict` argument instead for model classes"
+            )
+            deprecate("pretrained_model_link_or_path", "1.0.0", deprecation_message)
+            pretrained_model_link_or_path_or_dict = pretrained_model_link_or_path
+
+        config = kwargs.pop("config", None)
+        original_config = kwargs.pop("original_config", None)
+
+        if config is not None and original_config is not None:
+            raise ValueError(
+                "`from_single_file` cannot accept both `config` and `original_config` arguments. Please provide only one of these arguments."
+            )
+
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        token = kwargs.pop("token", None)
+        cache_dir = kwargs.pop("cache_dir", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        subfolder = kwargs.pop("subfolder", None)
+        revision = kwargs.pop("revision", None)
+        config_revision = kwargs.pop("config_revision", None)
+        torch_dtype = kwargs.pop("torch_dtype", None)
+        quantization_config = kwargs.pop("quantization_config", None)
+        device = kwargs.pop("device", None)
+        disable_mmap = kwargs.pop("disable_mmap", False)
+
+        user_agent = {"diffusers": __version__, "file_type": "single_file", "framework": "pytorch"}
+        # Record the quantization method so popular methods remain supported; can be disabled with `disable_telemetry`
+        if quantization_config is not None:
+            user_agent["quant"] = quantization_config.quant_method.value
+
+        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
+            logger.warning(
+                f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
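+                # e.g. a (hypothetical) torch_dtype="float16" string is rejected here; pass torch.float16 instead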
+            )
+            torch_dtype = torch.float32
+
+        if isinstance(pretrained_model_link_or_path_or_dict, dict):
+            checkpoint = pretrained_model_link_or_path_or_dict
+        else:
+            checkpoint = load_single_file_checkpoint(
+                pretrained_model_link_or_path_or_dict,
+                force_download=force_download,
+                proxies=proxies,
+                token=token,
+                cache_dir=cache_dir,
+                local_files_only=local_files_only,
+                revision=revision,
+                disable_mmap=disable_mmap,
+                user_agent=user_agent,
+            )
+        if quantization_config is not None:
+            hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
+            hf_quantizer.validate_environment()
+            torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
+
+        else:
+            hf_quantizer = None
+
+        mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name]
+
+        checkpoint_mapping_fn = mapping_functions["checkpoint_mapping_fn"]
+        if original_config is not None:
+            if "config_mapping_fn" in mapping_functions:
+                config_mapping_fn = mapping_functions["config_mapping_fn"]
+            else:
+                config_mapping_fn = None
+
+            if config_mapping_fn is None:
+                raise ValueError(
+                    (
+                        f"`original_config` has been provided for {mapping_class_name} but no mapping function "
+                        "was found to convert the original config to a Diffusers config in "
+                        "`diffusers.loaders.single_file_utils`"
+                    )
+                )
+
+            if isinstance(original_config, str):
+                # If original_config is a URL or filepath fetch the original_config dict
+                original_config = fetch_original_config(original_config, local_files_only=local_files_only)
+
+            config_mapping_kwargs = _get_mapping_function_kwargs(config_mapping_fn, **kwargs)
+            diffusers_model_config = config_mapping_fn(
+                original_config=original_config, checkpoint=checkpoint, **config_mapping_kwargs
+            )
+        else:
+            if config is not None:
+                if isinstance(config, str):
+                    default_pretrained_model_config_name = config
+                else:
+                    raise ValueError(
+                        (
+                            "Invalid `config` argument. Please provide a string representing a repo id "
+                            "or path to a local Diffusers model repo."
+                        )
+                    )
+
+            else:
+                config = fetch_diffusers_config(checkpoint)
+                default_pretrained_model_config_name = config["pretrained_model_name_or_path"]
+
+                if "default_subfolder" in mapping_functions:
+                    subfolder = mapping_functions["default_subfolder"]
+
+                subfolder = subfolder or config.pop(
+                    "subfolder", None
+                )  # some configs contain a subfolder key, e.g. StableCascadeUNet
+
+            diffusers_model_config = cls.load_config(
+                pretrained_model_name_or_path=default_pretrained_model_config_name,
+                subfolder=subfolder,
+                local_files_only=local_files_only,
+                token=token,
+                revision=config_revision,
+            )
+            expected_kwargs, optional_kwargs = cls._get_signature_keys(cls)
+
+            # Map legacy kwargs to new kwargs
+            if "legacy_kwargs" in mapping_functions:
+                legacy_kwargs = mapping_functions["legacy_kwargs"]
+                for legacy_key, new_key in legacy_kwargs.items():
+                    if legacy_key in kwargs:
+                        kwargs[new_key] = kwargs.pop(legacy_key)
+
+            model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs}
+            diffusers_model_config.update(model_kwargs)
+
+        checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
+        diffusers_format_checkpoint = checkpoint_mapping_fn(
+            config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
+        )
+        if not diffusers_format_checkpoint:
+            raise SingleFileComponentError(
+                f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
+            )
+
+        ctx = init_empty_weights if is_accelerate_available() else nullcontext
+        with ctx():
+            model = cls.from_config(diffusers_model_config)
+
+        # Check if `_keep_in_fp32_modules` is not None
+        use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
+            (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
+        )
+        if use_keep_in_fp32_modules:
+            keep_in_fp32_modules = cls._keep_in_fp32_modules
+            if not isinstance(keep_in_fp32_modules, list):
+                keep_in_fp32_modules = [keep_in_fp32_modules]
+
+        else:
+            keep_in_fp32_modules = []
+
+        if hf_quantizer is not None:
+            hf_quantizer.preprocess_model(
+                model=model,
+                device_map=None,
+                state_dict=diffusers_format_checkpoint,
+                keep_in_fp32_modules=keep_in_fp32_modules,
+            )
+
+        device_map = None
+        if is_accelerate_available():
+            param_device = torch.device(device) if device else torch.device("cpu")
+            empty_state_dict = model.state_dict()
+            unexpected_keys = [
+                param_name for param_name in diffusers_format_checkpoint if param_name not in empty_state_dict
+            ]
+            device_map = {"": param_device}
+            load_model_dict_into_meta(
+                model,
+                diffusers_format_checkpoint,
+                dtype=torch_dtype,
+                device_map=device_map,
+                hf_quantizer=hf_quantizer,
+                keep_in_fp32_modules=keep_in_fp32_modules,
+                unexpected_keys=unexpected_keys,
+            )
+        else:
+            _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
+
+        if model._keys_to_ignore_on_load_unexpected is not None:
+            for pat in model._keys_to_ignore_on_load_unexpected:
+                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+
+        if len(unexpected_keys) > 0:
+            logger.warning(
+                f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {', '.join(unexpected_keys)}"
+            )
+
+        if hf_quantizer is not None:
+            hf_quantizer.postprocess_model(model)
+            model.hf_quantizer = hf_quantizer
+
+        if torch_dtype is not None and hf_quantizer is None:
+            model.to(torch_dtype)
+
+        model.eval()
+
+        if device_map is not None:
+            device_map_kwargs = {"device_map": device_map}
+            dispatch_model(model, **device_map_kwargs)
+
+        return model
diff --git a/src/diffusers/loaders/single_file/single_file_utils.py b/src/diffusers/loaders/single_file/single_file_utils.py
new file mode 100644
index 000000000000..a37a6c22d432
--- /dev/null
+++ b/src/diffusers/loaders/single_file/single_file_utils.py
@@ -0,0 +1,3295 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
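+
+# The helpers below are consumed by the single-file loaders above. For example (illustrative
+# sketch, assuming a local SD 1.5 checkpoint exists at the given path):
+#   >>> ckpt = load_single_file_checkpoint("./v1-5-pruned-emaonly.safetensors")
+#   >>> infer_diffusers_model_type(ckpt)
+#   'v1'
+#   >>> fetch_diffusers_config(ckpt)["pretrained_model_name_or_path"]
+#   'stable-diffusion-v1-5/stable-diffusion-v1-5'
+# i.e. the state dict keys are matched against CHECKPOINT_KEY_NAMES to pick a default config
+# repo from DIFFUSERS_DEFAULT_PIPELINE_PATHS.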
+"""Conversion script for the Stable Diffusion checkpoints.""" + +import copy +import os +import re +from contextlib import nullcontext +from io import BytesIO +from urllib.parse import urlparse + +import requests +import torch +import yaml + +from ...models.modeling_utils import load_state_dict +from ...schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EDMDPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from ...utils import ( + SAFETENSORS_WEIGHTS_NAME, + WEIGHTS_NAME, + deprecate, + is_accelerate_available, + is_transformers_available, + logging, +) +from ...utils.constants import DIFFUSERS_REQUEST_TIMEOUT +from ...utils.hub_utils import _get_model_file + + +if is_transformers_available(): + from transformers import AutoImageProcessor + +if is_accelerate_available(): + from accelerate import init_empty_weights + + from ...models.modeling_utils import load_model_dict_into_meta + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +CHECKPOINT_KEY_NAMES = { + "v2": "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight", + "xl_base": "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias", + "xl_refiner": "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias", + "upscale": "model.diffusion_model.input_blocks.10.0.skip_connection.bias", + "controlnet": [ + "control_model.time_embed.0.weight", + "controlnet_cond_embedding.conv_in.weight", + ], + # TODO: find non-Diffusers keys for controlnet_xl + "controlnet_xl": "add_embedding.linear_1.weight", + "controlnet_xl_large": "down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight", + "controlnet_xl_mid": "down_blocks.1.attentions.0.norm.weight", + "playground-v2-5": "edm_mean", + "inpainting": "model.diffusion_model.input_blocks.0.0.weight", + "clip": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight", + "clip_sdxl": "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight", + "clip_sd3": "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight", + "open_clip": "cond_stage_model.model.token_embedding.weight", + "open_clip_sdxl": "conditioner.embedders.1.model.positional_embedding", + "open_clip_sdxl_refiner": "conditioner.embedders.0.model.text_projection", + "open_clip_sd3": "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight", + "stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight", + "stable_cascade_stage_c": "clip_txt_mapper.weight", + "sd3": [ + "joint_blocks.0.context_block.adaLN_modulation.1.bias", + "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias", + ], + "sd35_large": [ + "joint_blocks.37.x_block.mlp.fc1.weight", + "model.diffusion_model.joint_blocks.37.x_block.mlp.fc1.weight", + ], + "animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe", + "animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias", + "animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight", + "animatediff_scribble": "controlnet_cond_embedding.conv_in.weight", + "animatediff_rgb": "controlnet_cond_embedding.weight", + "auraflow": [ + "double_layers.0.attn.w2q.weight", + "double_layers.0.attn.w1q.weight", + "cond_seq_linear.weight", + "t_embedder.mlp.0.weight", + ], + "flux": [ + 
"double_blocks.0.img_attn.norm.key_norm.scale", + "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale", + ], + "ltx-video": [ + "model.diffusion_model.patchify_proj.weight", + "model.diffusion_model.transformer_blocks.27.scale_shift_table", + "patchify_proj.weight", + "transformer_blocks.27.scale_shift_table", + "vae.per_channel_statistics.mean-of-means", + ], + "autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias", + "autoencoder-dc-sana": "encoder.project_in.conv.bias", + "mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"], + "hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias", + "instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight", + "lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"], + "sana": [ + "blocks.0.cross_attn.q_linear.weight", + "blocks.0.cross_attn.q_linear.bias", + "blocks.0.cross_attn.kv_linear.weight", + "blocks.0.cross_attn.kv_linear.bias", + ], + "wan": ["model.diffusion_model.head.modulation", "head.modulation"], + "wan_vae": "decoder.middle.0.residual.0.gamma", +} + +DIFFUSERS_DEFAULT_PIPELINE_PATHS = { + "xl_base": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-base-1.0"}, + "xl_refiner": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-refiner-1.0"}, + "xl_inpaint": {"pretrained_model_name_or_path": "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"}, + "playground-v2-5": {"pretrained_model_name_or_path": "playgroundai/playground-v2.5-1024px-aesthetic"}, + "upscale": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-x4-upscaler"}, + "inpainting": {"pretrained_model_name_or_path": "stable-diffusion-v1-5/stable-diffusion-inpainting"}, + "inpainting_v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-inpainting"}, + "controlnet": {"pretrained_model_name_or_path": "lllyasviel/control_v11p_sd15_canny"}, + "controlnet_xl_large": {"pretrained_model_name_or_path": "diffusers/controlnet-canny-sdxl-1.0"}, + "controlnet_xl_mid": {"pretrained_model_name_or_path": "diffusers/controlnet-canny-sdxl-1.0-mid"}, + "controlnet_xl_small": {"pretrained_model_name_or_path": "diffusers/controlnet-canny-sdxl-1.0-small"}, + "v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-1"}, + "v1": {"pretrained_model_name_or_path": "stable-diffusion-v1-5/stable-diffusion-v1-5"}, + "stable_cascade_stage_b": {"pretrained_model_name_or_path": "stabilityai/stable-cascade", "subfolder": "decoder"}, + "stable_cascade_stage_b_lite": { + "pretrained_model_name_or_path": "stabilityai/stable-cascade", + "subfolder": "decoder_lite", + }, + "stable_cascade_stage_c": { + "pretrained_model_name_or_path": "stabilityai/stable-cascade-prior", + "subfolder": "prior", + }, + "stable_cascade_stage_c_lite": { + "pretrained_model_name_or_path": "stabilityai/stable-cascade-prior", + "subfolder": "prior_lite", + }, + "sd3": { + "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3-medium-diffusers", + }, + "sd35_large": { + "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3.5-large", + }, + "sd35_medium": { + "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3.5-medium", + }, + "animatediff_v1": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5"}, + "animatediff_v2": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-2"}, + "animatediff_v3": {"pretrained_model_name_or_path": 
"guoyww/animatediff-motion-adapter-v1-5-3"}, + "animatediff_sdxl_beta": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-sdxl-beta"}, + "animatediff_scribble": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-scribble"}, + "animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"}, + "auraflow": {"pretrained_model_name_or_path": "fal/AuraFlow-v0.3"}, + "flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"}, + "flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"}, + "flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"}, + "flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"}, + "ltx-video": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.0"}, + "ltx-video-0.9.1": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.1"}, + "ltx-video-0.9.5": {"pretrained_model_name_or_path": "Lightricks/LTX-Video-0.9.5"}, + "autoencoder-dc-f128c512": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"}, + "autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"}, + "autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"}, + "autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"}, + "mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"}, + "hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"}, + "instruct-pix2pix": {"pretrained_model_name_or_path": "timbrooks/instruct-pix2pix"}, + "lumina2": {"pretrained_model_name_or_path": "Alpha-VLLM/Lumina-Image-2.0"}, + "sana": {"pretrained_model_name_or_path": "Efficient-Large-Model/Sana_1600M_1024px_diffusers"}, + "wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"}, + "wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"}, + "wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"}, +} + +# Use to configure model sample size when original config is provided +DIFFUSERS_TO_LDM_DEFAULT_IMAGE_SIZE_MAP = { + "xl_base": 1024, + "xl_refiner": 1024, + "xl_inpaint": 1024, + "playground-v2-5": 1024, + "upscale": 512, + "inpainting": 512, + "inpainting_v2": 512, + "controlnet": 512, + "instruct-pix2pix": 512, + "v2": 768, + "v1": 512, +} + + +DIFFUSERS_TO_LDM_MAPPING = { + "unet": { + "layers": { + "time_embedding.linear_1.weight": "time_embed.0.weight", + "time_embedding.linear_1.bias": "time_embed.0.bias", + "time_embedding.linear_2.weight": "time_embed.2.weight", + "time_embedding.linear_2.bias": "time_embed.2.bias", + "conv_in.weight": "input_blocks.0.0.weight", + "conv_in.bias": "input_blocks.0.0.bias", + "conv_norm_out.weight": "out.0.weight", + "conv_norm_out.bias": "out.0.bias", + "conv_out.weight": "out.2.weight", + "conv_out.bias": "out.2.bias", + }, + "class_embed_type": { + "class_embedding.linear_1.weight": "label_emb.0.0.weight", + "class_embedding.linear_1.bias": "label_emb.0.0.bias", + "class_embedding.linear_2.weight": "label_emb.0.2.weight", + "class_embedding.linear_2.bias": "label_emb.0.2.bias", + }, + "addition_embed_type": { + "add_embedding.linear_1.weight": "label_emb.0.0.weight", + "add_embedding.linear_1.bias": "label_emb.0.0.bias", + "add_embedding.linear_2.weight": "label_emb.0.2.weight", + "add_embedding.linear_2.bias": 
"label_emb.0.2.bias", + }, + }, + "controlnet": { + "layers": { + "time_embedding.linear_1.weight": "time_embed.0.weight", + "time_embedding.linear_1.bias": "time_embed.0.bias", + "time_embedding.linear_2.weight": "time_embed.2.weight", + "time_embedding.linear_2.bias": "time_embed.2.bias", + "conv_in.weight": "input_blocks.0.0.weight", + "conv_in.bias": "input_blocks.0.0.bias", + "controlnet_cond_embedding.conv_in.weight": "input_hint_block.0.weight", + "controlnet_cond_embedding.conv_in.bias": "input_hint_block.0.bias", + "controlnet_cond_embedding.conv_out.weight": "input_hint_block.14.weight", + "controlnet_cond_embedding.conv_out.bias": "input_hint_block.14.bias", + }, + "class_embed_type": { + "class_embedding.linear_1.weight": "label_emb.0.0.weight", + "class_embedding.linear_1.bias": "label_emb.0.0.bias", + "class_embedding.linear_2.weight": "label_emb.0.2.weight", + "class_embedding.linear_2.bias": "label_emb.0.2.bias", + }, + "addition_embed_type": { + "add_embedding.linear_1.weight": "label_emb.0.0.weight", + "add_embedding.linear_1.bias": "label_emb.0.0.bias", + "add_embedding.linear_2.weight": "label_emb.0.2.weight", + "add_embedding.linear_2.bias": "label_emb.0.2.bias", + }, + }, + "vae": { + "encoder.conv_in.weight": "encoder.conv_in.weight", + "encoder.conv_in.bias": "encoder.conv_in.bias", + "encoder.conv_out.weight": "encoder.conv_out.weight", + "encoder.conv_out.bias": "encoder.conv_out.bias", + "encoder.conv_norm_out.weight": "encoder.norm_out.weight", + "encoder.conv_norm_out.bias": "encoder.norm_out.bias", + "decoder.conv_in.weight": "decoder.conv_in.weight", + "decoder.conv_in.bias": "decoder.conv_in.bias", + "decoder.conv_out.weight": "decoder.conv_out.weight", + "decoder.conv_out.bias": "decoder.conv_out.bias", + "decoder.conv_norm_out.weight": "decoder.norm_out.weight", + "decoder.conv_norm_out.bias": "decoder.norm_out.bias", + "quant_conv.weight": "quant_conv.weight", + "quant_conv.bias": "quant_conv.bias", + "post_quant_conv.weight": "post_quant_conv.weight", + "post_quant_conv.bias": "post_quant_conv.bias", + }, + "openclip": { + "layers": { + "text_model.embeddings.position_embedding.weight": "positional_embedding", + "text_model.embeddings.token_embedding.weight": "token_embedding.weight", + "text_model.final_layer_norm.weight": "ln_final.weight", + "text_model.final_layer_norm.bias": "ln_final.bias", + "text_projection.weight": "text_projection", + }, + "transformer": { + "text_model.encoder.layers.": "resblocks.", + "layer_norm1": "ln_1", + "layer_norm2": "ln_2", + ".fc1.": ".c_fc.", + ".fc2.": ".c_proj.", + ".self_attn": ".attn", + "transformer.text_model.final_layer_norm.": "ln_final.", + "transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight", + "transformer.text_model.embeddings.position_embedding.weight": "positional_embedding", + }, + }, +} + +SD_2_TEXT_ENCODER_KEYS_TO_IGNORE = [ + "cond_stage_model.model.transformer.resblocks.23.attn.in_proj_bias", + "cond_stage_model.model.transformer.resblocks.23.attn.in_proj_weight", + "cond_stage_model.model.transformer.resblocks.23.attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.23.attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.23.ln_1.bias", + "cond_stage_model.model.transformer.resblocks.23.ln_1.weight", + "cond_stage_model.model.transformer.resblocks.23.ln_2.bias", + "cond_stage_model.model.transformer.resblocks.23.ln_2.weight", + "cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.bias", + 
"cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.weight", + "cond_stage_model.model.transformer.resblocks.23.mlp.c_proj.bias", + "cond_stage_model.model.transformer.resblocks.23.mlp.c_proj.weight", + "cond_stage_model.model.text_projection", +] + +# To support legacy scheduler_type argument +SCHEDULER_DEFAULT_CONFIG = { + "beta_schedule": "scaled_linear", + "beta_start": 0.00085, + "beta_end": 0.012, + "interpolation_type": "linear", + "num_train_timesteps": 1000, + "prediction_type": "epsilon", + "sample_max_value": 1.0, + "set_alpha_to_one": False, + "skip_prk_steps": True, + "steps_offset": 1, + "timestep_spacing": "leading", +} + +LDM_VAE_KEYS = ["first_stage_model.", "vae."] +LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215 +PLAYGROUND_VAE_SCALING_FACTOR = 0.5 +LDM_UNET_KEY = "model.diffusion_model." +LDM_CONTROLNET_KEY = "control_model." +LDM_CLIP_PREFIX_TO_REMOVE = [ + "cond_stage_model.transformer.", + "conditioner.embedders.0.transformer.", +] +LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024 +SCHEDULER_LEGACY_KWARGS = ["prediction_type", "scheduler_type"] + +VALID_URL_PREFIXES = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"] + + +class SingleFileComponentError(Exception): + def __init__(self, message=None): + self.message = message + super().__init__(self.message) + + +def is_valid_url(url): + result = urlparse(url) + if result.scheme and result.netloc: + return True + + return False + + +def _extract_repo_id_and_weights_name(pretrained_model_name_or_path): + if not is_valid_url(pretrained_model_name_or_path): + raise ValueError("Invalid `pretrained_model_name_or_path` provided. Please set it to a valid URL.") + + pattern = r"([^/]+)/([^/]+)/(?:blob/main/)?(.+)" + weights_name = None + repo_id = (None,) + for prefix in VALID_URL_PREFIXES: + pretrained_model_name_or_path = pretrained_model_name_or_path.replace(prefix, "") + match = re.match(pattern, pretrained_model_name_or_path) + if not match: + logger.warning("Unable to identify the repo_id and weights_name from the provided URL.") + return repo_id, weights_name + + repo_id = f"{match.group(1)}/{match.group(2)}" + weights_name = match.group(3) + + return repo_id, weights_name + + +def _is_model_weights_in_cached_folder(cached_folder, name): + pretrained_model_name_or_path = os.path.join(cached_folder, name) + weights_exist = False + + for weights_name in [WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME]: + if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)): + weights_exist = True + + return weights_exist + + +def _is_legacy_scheduler_kwargs(kwargs): + return any(k in SCHEDULER_LEGACY_KWARGS for k in kwargs.keys()) + + +def load_single_file_checkpoint( + pretrained_model_link_or_path, + force_download=False, + proxies=None, + token=None, + cache_dir=None, + local_files_only=None, + revision=None, + disable_mmap=False, + user_agent=None, +): + if user_agent is None: + user_agent = {"file_type": "single_file", "framework": "pytorch"} + + if os.path.isfile(pretrained_model_link_or_path): + pretrained_model_link_or_path = pretrained_model_link_or_path + + else: + repo_id, weights_name = _extract_repo_id_and_weights_name(pretrained_model_link_or_path) + pretrained_model_link_or_path = _get_model_file( + repo_id, + weights_name=weights_name, + force_download=force_download, + cache_dir=cache_dir, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + user_agent=user_agent, + ) + + checkpoint = load_state_dict(pretrained_model_link_or_path, 
+def fetch_original_config(original_config_file, local_files_only=False): + if os.path.isfile(original_config_file): + with open(original_config_file, "r") as fp: + original_config_file = fp.read() + + elif is_valid_url(original_config_file): + if local_files_only: + raise ValueError( + "`local_files_only` is set to True, but a URL was provided as `original_config_file`. " + "Please provide a valid local file path." + ) + + original_config_file = BytesIO(requests.get(original_config_file, timeout=DIFFUSERS_REQUEST_TIMEOUT).content) + + else: + raise ValueError("Invalid `original_config_file` provided. Please set it to a valid file path or URL.") + + original_config = yaml.safe_load(original_config_file) + + return original_config + + +def is_clip_model(checkpoint): + return CHECKPOINT_KEY_NAMES["clip"] in checkpoint + + +def is_clip_sdxl_model(checkpoint): + return CHECKPOINT_KEY_NAMES["clip_sdxl"] in checkpoint + + +def is_clip_sd3_model(checkpoint): + return CHECKPOINT_KEY_NAMES["clip_sd3"] in checkpoint + + +def is_open_clip_model(checkpoint): + return CHECKPOINT_KEY_NAMES["open_clip"] in checkpoint + + +def is_open_clip_sdxl_model(checkpoint): + return CHECKPOINT_KEY_NAMES["open_clip_sdxl"] in checkpoint + + +def is_open_clip_sd3_model(checkpoint): + return CHECKPOINT_KEY_NAMES["open_clip_sd3"] in checkpoint + + +def is_open_clip_sdxl_refiner_model(checkpoint): + return CHECKPOINT_KEY_NAMES["open_clip_sdxl_refiner"] in checkpoint + + +def is_clip_model_in_single_file(class_obj, checkpoint): + is_clip_in_checkpoint = any( + [ + is_clip_model(checkpoint), + is_clip_sd3_model(checkpoint), + is_open_clip_model(checkpoint), + is_open_clip_sdxl_model(checkpoint), + is_open_clip_sdxl_refiner_model(checkpoint), + is_open_clip_sd3_model(checkpoint), + ] + ) + return class_obj.__name__ in ("CLIPTextModel", "CLIPTextModelWithProjection") and is_clip_in_checkpoint + + +def infer_diffusers_model_type(checkpoint): + if ( + CHECKPOINT_KEY_NAMES["inpainting"] in checkpoint + and checkpoint[CHECKPOINT_KEY_NAMES["inpainting"]].shape[1] == 9 + ): + if CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024: + model_type = "inpainting_v2" + elif CHECKPOINT_KEY_NAMES["xl_base"] in checkpoint: + model_type = "xl_inpaint" + else: + model_type = "inpainting" + + elif CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024: + model_type = "v2" + + elif CHECKPOINT_KEY_NAMES["playground-v2-5"] in checkpoint: + model_type = "playground-v2-5" + + elif CHECKPOINT_KEY_NAMES["xl_base"] in checkpoint: + model_type = "xl_base" + + elif CHECKPOINT_KEY_NAMES["xl_refiner"] in checkpoint: + model_type = "xl_refiner" + + elif CHECKPOINT_KEY_NAMES["upscale"] in checkpoint: + model_type = "upscale" + + elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["controlnet"]): + if CHECKPOINT_KEY_NAMES["controlnet_xl"] in checkpoint: + if CHECKPOINT_KEY_NAMES["controlnet_xl_large"] in checkpoint: + model_type = "controlnet_xl_large" + elif CHECKPOINT_KEY_NAMES["controlnet_xl_mid"] in 
checkpoint: + model_type = "controlnet_xl_mid" + else: + model_type = "controlnet_xl_small" + else: + model_type = "controlnet" + + elif ( + CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"] in checkpoint + and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"]].shape[0] == 1536 + ): + model_type = "stable_cascade_stage_c_lite" + + elif ( + CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"] in checkpoint + and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"]].shape[0] == 2048 + ): + model_type = "stable_cascade_stage_c" + + elif ( + CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"] in checkpoint + and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"]].shape[-1] == 576 + ): + model_type = "stable_cascade_stage_b_lite" + + elif ( + CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"] in checkpoint + and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"]].shape[-1] == 640 + ): + model_type = "stable_cascade_stage_b" + + elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sd3"]) and any( + checkpoint[key].shape[-1] == 9216 if key in checkpoint else False for key in CHECKPOINT_KEY_NAMES["sd3"] + ): + if "model.diffusion_model.pos_embed" in checkpoint: + key = "model.diffusion_model.pos_embed" + else: + key = "pos_embed" + + if checkpoint[key].shape[1] == 36864: + model_type = "sd3" + elif checkpoint[key].shape[1] == 147456: + model_type = "sd35_medium" + + elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sd35_large"]): + model_type = "sd35_large" + + elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint: + if CHECKPOINT_KEY_NAMES["animatediff_scribble"] in checkpoint: + model_type = "animatediff_scribble" + + elif CHECKPOINT_KEY_NAMES["animatediff_rgb"] in checkpoint: + model_type = "animatediff_rgb" + + elif CHECKPOINT_KEY_NAMES["animatediff_v2"] in checkpoint: + model_type = "animatediff_v2" + + elif checkpoint[CHECKPOINT_KEY_NAMES["animatediff_sdxl_beta"]].shape[-1] == 320: + model_type = "animatediff_sdxl_beta" + + elif checkpoint[CHECKPOINT_KEY_NAMES["animatediff"]].shape[1] == 24: + model_type = "animatediff_v1" + + else: + model_type = "animatediff_v3" + + elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux"]): + if any( + g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"] + ): + if "model.diffusion_model.img_in.weight" in checkpoint: + key = "model.diffusion_model.img_in.weight" + else: + key = "img_in.weight" + + if checkpoint[key].shape[1] == 384: + model_type = "flux-fill" + elif checkpoint[key].shape[1] == 128: + model_type = "flux-depth" + else: + model_type = "flux-dev" + else: + model_type = "flux-schnell" + + elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx-video"]): + if checkpoint["vae.encoder.conv_out.conv.weight"].shape[1] == 2048: + model_type = "ltx-video-0.9.5" + elif "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in checkpoint: + model_type = "ltx-video-0.9.1" + else: + model_type = "ltx-video" + + elif CHECKPOINT_KEY_NAMES["autoencoder-dc"] in checkpoint: + encoder_key = "encoder.project_in.conv.conv.bias" + decoder_key = "decoder.project_in.main.conv.weight" + + if CHECKPOINT_KEY_NAMES["autoencoder-dc-sana"] in checkpoint: + model_type = "autoencoder-dc-f32c32-sana" + + elif checkpoint[encoder_key].shape[-1] == 64 and checkpoint[decoder_key].shape[1] == 32: + model_type = "autoencoder-dc-f32c32" + + elif checkpoint[encoder_key].shape[-1] == 64 and checkpoint[decoder_key].shape[1] == 128: + model_type = "autoencoder-dc-f64c128" + + 
else: + model_type = "autoencoder-dc-f128c512" + + elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["mochi-1-preview"]): + model_type = "mochi-1-preview" + + elif CHECKPOINT_KEY_NAMES["hunyuan-video"] in checkpoint: + model_type = "hunyuan-video" + + elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["auraflow"]): + model_type = "auraflow" + + elif ( + CHECKPOINT_KEY_NAMES["instruct-pix2pix"] in checkpoint + and checkpoint[CHECKPOINT_KEY_NAMES["instruct-pix2pix"]].shape[1] == 8 + ): + model_type = "instruct-pix2pix" + + elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"]): + model_type = "lumina2" + + elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sana"]): + model_type = "sana" + + elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["wan"]): + if "model.diffusion_model.patch_embedding.weight" in checkpoint: + target_key = "model.diffusion_model.patch_embedding.weight" + else: + target_key = "patch_embedding.weight" + + if checkpoint[target_key].shape[0] == 1536: + model_type = "wan-t2v-1.3B" + elif checkpoint[target_key].shape[0] == 5120 and checkpoint[target_key].shape[1] == 16: + model_type = "wan-t2v-14B" + else: + model_type = "wan-i2v-14B" + elif CHECKPOINT_KEY_NAMES["wan_vae"] in checkpoint: + # All Wan models use the same VAE so we can use the same default model repo to fetch the config + model_type = "wan-t2v-14B" + else: + model_type = "v1" + + return model_type + + +def fetch_diffusers_config(checkpoint): + model_type = infer_diffusers_model_type(checkpoint) + model_path = DIFFUSERS_DEFAULT_PIPELINE_PATHS[model_type] + model_path = copy.deepcopy(model_path) + + return model_path + + +def set_image_size(checkpoint, image_size=None): + if image_size: + return image_size + + model_type = infer_diffusers_model_type(checkpoint) + image_size = DIFFUSERS_TO_LDM_DEFAULT_IMAGE_SIZE_MAP[model_type] + + return image_size + + +# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear +def conv_attn_to_linear(checkpoint): + keys = list(checkpoint.keys()) + attn_keys = ["query.weight", "key.weight", "value.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif "proj_attn.weight" in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] + + +def create_unet_diffusers_config_from_ldm( + original_config, checkpoint, image_size=None, upcast_attention=None, num_in_channels=None +): + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + if image_size is not None: + deprecation_message = ( + "Configuring UNet2DConditionModel with the `image_size` argument to `from_single_file`" + "is deprecated and will be ignored in future versions." + ) + deprecate("image_size", "1.0.0", deprecation_message) + + image_size = set_image_size(checkpoint, image_size=image_size) + + if ( + "unet_config" in original_config["model"]["params"] + and original_config["model"]["params"]["unet_config"] is not None + ): + unet_params = original_config["model"]["params"]["unet_config"]["params"] + else: + unet_params = original_config["model"]["params"]["network_config"]["params"] + + if num_in_channels is not None: + deprecation_message = ( + "Configuring UNet2DConditionModel with the `num_in_channels` argument to `from_single_file`" + "is deprecated and will be ignored in future versions." 
+ ) + deprecate("image_size", "1.0.0", deprecation_message) + in_channels = num_in_channels + else: + in_channels = unet_params["in_channels"] + + vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"] + block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D" + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D" + up_block_types.append(block_type) + resolution //= 2 + + if unet_params["transformer_depth"] is not None: + transformer_layers_per_block = ( + unet_params["transformer_depth"] + if isinstance(unet_params["transformer_depth"], int) + else list(unet_params["transformer_depth"]) + ) + else: + transformer_layers_per_block = 1 + + vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1) + + head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None + use_linear_projection = ( + unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False + ) + if use_linear_projection: + # stable diffusion 2-base-512 and 2-768 + if head_dim is None: + head_dim_mult = unet_params["model_channels"] // unet_params["num_head_channels"] + head_dim = [head_dim_mult * c for c in list(unet_params["channel_mult"])] + + class_embed_type = None + addition_embed_type = None + addition_time_embed_dim = None + projection_class_embeddings_input_dim = None + context_dim = None + + if unet_params["context_dim"] is not None: + context_dim = ( + unet_params["context_dim"] + if isinstance(unet_params["context_dim"], int) + else unet_params["context_dim"][0] + ) + + if "num_classes" in unet_params: + if unet_params["num_classes"] == "sequential": + if context_dim in [2048, 1280]: + # SDXL + addition_embed_type = "text_time" + addition_time_embed_dim = 256 + else: + class_embed_type = "projection" + assert "adm_in_channels" in unet_params + projection_class_embeddings_input_dim = unet_params["adm_in_channels"] + + config = { + "sample_size": image_size // vae_scale_factor, + "in_channels": in_channels, + "down_block_types": down_block_types, + "block_out_channels": block_out_channels, + "layers_per_block": unet_params["num_res_blocks"], + "cross_attention_dim": context_dim, + "attention_head_dim": head_dim, + "use_linear_projection": use_linear_projection, + "class_embed_type": class_embed_type, + "addition_embed_type": addition_embed_type, + "addition_time_embed_dim": addition_time_embed_dim, + "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, + "transformer_layers_per_block": transformer_layers_per_block, + } + + if upcast_attention is not None: + deprecation_message = ( + "Configuring UNet2DConditionModel with the `upcast_attention` argument to `from_single_file`" + "is deprecated and will be ignored in future versions." 
+ ) + deprecate("image_size", "1.0.0", deprecation_message) + config["upcast_attention"] = upcast_attention + + if "disable_self_attentions" in unet_params: + config["only_cross_attention"] = unet_params["disable_self_attentions"] + + if "num_classes" in unet_params and isinstance(unet_params["num_classes"], int): + config["num_class_embeds"] = unet_params["num_classes"] + + config["out_channels"] = unet_params["out_channels"] + config["up_block_types"] = up_block_types + + return config + + +def create_controlnet_diffusers_config_from_ldm(original_config, checkpoint, image_size=None, **kwargs): + if image_size is not None: + deprecation_message = ( + "Configuring ControlNetModel with the `image_size` argument" + "is deprecated and will be ignored in future versions." + ) + deprecate("image_size", "1.0.0", deprecation_message) + + image_size = set_image_size(checkpoint, image_size=image_size) + + unet_params = original_config["model"]["params"]["control_stage_config"]["params"] + diffusers_unet_config = create_unet_diffusers_config_from_ldm(original_config, image_size=image_size) + + controlnet_config = { + "conditioning_channels": unet_params["hint_channels"], + "in_channels": diffusers_unet_config["in_channels"], + "down_block_types": diffusers_unet_config["down_block_types"], + "block_out_channels": diffusers_unet_config["block_out_channels"], + "layers_per_block": diffusers_unet_config["layers_per_block"], + "cross_attention_dim": diffusers_unet_config["cross_attention_dim"], + "attention_head_dim": diffusers_unet_config["attention_head_dim"], + "use_linear_projection": diffusers_unet_config["use_linear_projection"], + "class_embed_type": diffusers_unet_config["class_embed_type"], + "addition_embed_type": diffusers_unet_config["addition_embed_type"], + "addition_time_embed_dim": diffusers_unet_config["addition_time_embed_dim"], + "projection_class_embeddings_input_dim": diffusers_unet_config["projection_class_embeddings_input_dim"], + "transformer_layers_per_block": diffusers_unet_config["transformer_layers_per_block"], + } + + return controlnet_config + + +def create_vae_diffusers_config_from_ldm(original_config, checkpoint, image_size=None, scaling_factor=None): + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + if image_size is not None: + deprecation_message = ( + "Configuring AutoencoderKL with the `image_size` argument" + "is deprecated and will be ignored in future versions." 
+ ) + deprecate("image_size", "1.0.0", deprecation_message) + + image_size = set_image_size(checkpoint, image_size=image_size) + + if "edm_mean" in checkpoint and "edm_std" in checkpoint: + latents_mean = checkpoint["edm_mean"] + latents_std = checkpoint["edm_std"] + else: + latents_mean = None + latents_std = None + + vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"] + if (scaling_factor is None) and (latents_mean is not None) and (latents_std is not None): + scaling_factor = PLAYGROUND_VAE_SCALING_FACTOR + + elif (scaling_factor is None) and ("scale_factor" in original_config["model"]["params"]): + scaling_factor = original_config["model"]["params"]["scale_factor"] + + elif scaling_factor is None: + scaling_factor = LDM_VAE_DEFAULT_SCALING_FACTOR + + block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]] + down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) + up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) + + config = { + "sample_size": image_size, + "in_channels": vae_params["in_channels"], + "out_channels": vae_params["out_ch"], + "down_block_types": down_block_types, + "up_block_types": up_block_types, + "block_out_channels": block_out_channels, + "latent_channels": vae_params["z_channels"], + "layers_per_block": vae_params["num_res_blocks"], + "scaling_factor": scaling_factor, + } + if latents_mean is not None and latents_std is not None: + config.update({"latents_mean": latents_mean, "latents_std": latents_std}) + + return config + + +def update_unet_resnet_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint, mapping=None): + for ldm_key in ldm_keys: + diffusers_key = ( + ldm_key.replace("in_layers.0", "norm1") + .replace("in_layers.2", "conv1") + .replace("out_layers.0", "norm2") + .replace("out_layers.3", "conv2") + .replace("emb_layers.1", "time_emb_proj") + .replace("skip_connection", "conv_shortcut") + ) + if mapping: + diffusers_key = diffusers_key.replace(mapping["old"], mapping["new"]) + new_checkpoint[diffusers_key] = checkpoint.get(ldm_key) + + +def update_unet_attention_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint, mapping): + for ldm_key in ldm_keys: + diffusers_key = ldm_key.replace(mapping["old"], mapping["new"]) + new_checkpoint[diffusers_key] = checkpoint.get(ldm_key) + + +def update_vae_resnet_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping): + for ldm_key in keys: + diffusers_key = ldm_key.replace(mapping["old"], mapping["new"]).replace("nin_shortcut", "conv_shortcut") + new_checkpoint[diffusers_key] = checkpoint.get(ldm_key) + + +def update_vae_attentions_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping): + for ldm_key in keys: + diffusers_key = ( + ldm_key.replace(mapping["old"], mapping["new"]) + .replace("norm.weight", "group_norm.weight") + .replace("norm.bias", "group_norm.bias") + .replace("q.weight", "to_q.weight") + .replace("q.bias", "to_q.bias") + .replace("k.weight", "to_k.weight") + .replace("k.bias", "to_k.bias") + .replace("v.weight", "to_v.weight") + .replace("v.bias", "to_v.bias") + .replace("proj_out.weight", "to_out.0.weight") + .replace("proj_out.bias", "to_out.0.bias") + ) + new_checkpoint[diffusers_key] = checkpoint.get(ldm_key) + + # proj_attn.weight has to be converted from conv 1D to linear + shape = new_checkpoint[diffusers_key].shape + + if len(shape) == 3: + new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0] + elif len(shape) == 4: + new_checkpoint[diffusers_key] = 
new_checkpoint[diffusers_key][:, :, 0, 0] + + +def convert_stable_cascade_unet_single_file_to_diffusers(checkpoint, **kwargs): + is_stage_c = "clip_txt_mapper.weight" in checkpoint + + state_dict = {} + for key in checkpoint.keys(): + if key.endswith("in_proj_weight"): + weights = checkpoint[key].chunk(3, 0) + state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0] + state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1] + state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2] + elif key.endswith("in_proj_bias"): + weights = checkpoint[key].chunk(3, 0) + state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0] + state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1] + state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2] + elif key.endswith("out_proj.weight"): + state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = checkpoint[key] + elif key.endswith("out_proj.bias"): + state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = checkpoint[key] + # stage b checkpoints use `clip_mapper`, which diffusers names `clip_txt_pooled_mapper` + elif not is_stage_c and key.endswith("clip_mapper.weight"): + state_dict[key.replace("clip_mapper.weight", "clip_txt_pooled_mapper.weight")] = checkpoint[key] + elif not is_stage_c and key.endswith("clip_mapper.bias"): + state_dict[key.replace("clip_mapper.bias", "clip_txt_pooled_mapper.bias")] = checkpoint[key] + else: + state_dict[key] = checkpoint[key] + + return state_dict + + +def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False, **kwargs): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + # extract state_dict for UNet + unet_state_dict = {} + keys = list(checkpoint.keys()) + unet_key = LDM_UNET_KEY + + # at least 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA + if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: + logger.warning("Checkpoint has both EMA and non-EMA weights.") + logger.warning( + "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" + " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." + ) + for key in keys: + if key.startswith("model.diffusion_model"): + flat_ema_key = "model_ema." 
+ "".join(key.split(".")[1:]) + unet_state_dict[key.replace(unet_key, "")] = checkpoint.get(flat_ema_key) + else: + if sum(k.startswith("model_ema") for k in keys) > 100: + logger.warning( + "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" + " weights (usually better for inference), please make sure to add the `--extract_ema` flag." + ) + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint.get(key) + + new_checkpoint = {} + ldm_unet_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["layers"] + for diffusers_key, ldm_key in ldm_unet_keys.items(): + if ldm_key not in unet_state_dict: + continue + new_checkpoint[diffusers_key] = unet_state_dict[ldm_key] + + if ("class_embed_type" in config) and (config["class_embed_type"] in ["timestep", "projection"]): + class_embed_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["class_embed_type"] + for diffusers_key, ldm_key in class_embed_keys.items(): + new_checkpoint[diffusers_key] = unet_state_dict[ldm_key] + + if ("addition_embed_type" in config) and (config["addition_embed_type"] == "text_time"): + addition_embed_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["addition_embed_type"] + for diffusers_key, ldm_key in addition_embed_keys.items(): + new_checkpoint[diffusers_key] = unet_state_dict[ldm_key] + + # Relevant to StableDiffusionUpscalePipeline + if "num_class_embeds" in config: + if (config["num_class_embeds"] is not None) and ("label_emb.weight" in unet_state_dict): + new_checkpoint["class_embedding.weight"] = unet_state_dict["label_emb.weight"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] + for layer_id in range(num_output_blocks) + } + + # Down blocks + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + update_unet_resnet_ldm_to_diffusers( + resnets, + new_checkpoint, + unet_state_dict, + {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}, + ) + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.get( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.get( + f"input_blocks.{i}.0.op.bias" + ) + + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + if attentions: + update_unet_attention_ldm_to_diffusers( + attentions, + new_checkpoint, + unet_state_dict, + 
{"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}, + ) + + # Mid blocks + for key in middle_blocks.keys(): + diffusers_key = max(key - 1, 0) + if key % 2 == 0: + update_unet_resnet_ldm_to_diffusers( + middle_blocks[key], + new_checkpoint, + unet_state_dict, + mapping={"old": f"middle_block.{key}", "new": f"mid_block.resnets.{diffusers_key}"}, + ) + else: + update_unet_attention_ldm_to_diffusers( + middle_blocks[key], + new_checkpoint, + unet_state_dict, + mapping={"old": f"middle_block.{key}", "new": f"mid_block.attentions.{diffusers_key}"}, + ) + + # Up Blocks + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + + resnets = [ + key for key in output_blocks[i] if f"output_blocks.{i}.0" in key and f"output_blocks.{i}.0.op" not in key + ] + update_unet_resnet_ldm_to_diffusers( + resnets, + new_checkpoint, + unet_state_dict, + {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}, + ) + + attentions = [ + key for key in output_blocks[i] if f"output_blocks.{i}.1" in key and f"output_blocks.{i}.1.conv" not in key + ] + if attentions: + update_unet_attention_ldm_to_diffusers( + attentions, + new_checkpoint, + unet_state_dict, + {"old": f"output_blocks.{i}.1", "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}"}, + ) + + if f"output_blocks.{i}.1.conv.weight" in unet_state_dict: + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.1.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.1.conv.bias" + ] + if f"output_blocks.{i}.2.conv.weight" in unet_state_dict: + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.2.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.2.conv.bias" + ] + + return new_checkpoint + + +def convert_controlnet_checkpoint( + checkpoint, + config, + **kwargs, +): + # Return checkpoint if it's already been converted + if "time_embedding.linear_1.weight" in checkpoint: + return checkpoint + # Some controlnet ckpt files are distributed independently from the rest of the + # model components i.e. 
https://huggingface.co/thibaud/controlnet-sd21/ + if "time_embed.0.weight" in checkpoint: + controlnet_state_dict = checkpoint + + else: + controlnet_state_dict = {} + keys = list(checkpoint.keys()) + controlnet_key = LDM_CONTROLNET_KEY + for key in keys: + if key.startswith(controlnet_key): + controlnet_state_dict[key.replace(controlnet_key, "")] = checkpoint.get(key) + + new_checkpoint = {} + ldm_controlnet_keys = DIFFUSERS_TO_LDM_MAPPING["controlnet"]["layers"] + for diffusers_key, ldm_key in ldm_controlnet_keys.items(): + if ldm_key not in controlnet_state_dict: + continue + new_checkpoint[diffusers_key] = controlnet_state_dict[ldm_key] + + # Retrieves the keys for the input blocks only + num_input_blocks = len( + {".".join(layer.split(".")[:2]) for layer in controlnet_state_dict if "input_blocks" in layer} + ) + input_blocks = { + layer_id: [key for key in controlnet_state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + + # Down blocks + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + update_unet_resnet_ldm_to_diffusers( + resnets, + new_checkpoint, + controlnet_state_dict, + {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}, + ) + + if f"input_blocks.{i}.0.op.weight" in controlnet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = controlnet_state_dict.get( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = controlnet_state_dict.get( + f"input_blocks.{i}.0.op.bias" + ) + + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + if attentions: + update_unet_attention_ldm_to_diffusers( + attentions, + new_checkpoint, + controlnet_state_dict, + {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}, + ) + + # controlnet down blocks + for i in range(num_input_blocks): + new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = controlnet_state_dict.get(f"zero_convs.{i}.0.weight") + new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = controlnet_state_dict.get(f"zero_convs.{i}.0.bias") + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len( + {".".join(layer.split(".")[:2]) for layer in controlnet_state_dict if "middle_block" in layer} + ) + middle_blocks = { + layer_id: [key for key in controlnet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Mid blocks + for key in middle_blocks.keys(): + diffusers_key = max(key - 1, 0) + if key % 2 == 0: + update_unet_resnet_ldm_to_diffusers( + middle_blocks[key], + new_checkpoint, + controlnet_state_dict, + mapping={"old": f"middle_block.{key}", "new": f"mid_block.resnets.{diffusers_key}"}, + ) + else: + update_unet_attention_ldm_to_diffusers( + middle_blocks[key], + new_checkpoint, + controlnet_state_dict, + mapping={"old": f"middle_block.{key}", "new": f"mid_block.attentions.{diffusers_key}"}, + ) + + # mid block + new_checkpoint["controlnet_mid_block.weight"] = controlnet_state_dict.get("middle_block_out.0.weight") + new_checkpoint["controlnet_mid_block.bias"] = controlnet_state_dict.get("middle_block_out.0.bias") + + # controlnet cond embedding blocks + cond_embedding_blocks = { + 
".".join(layer.split(".")[:2]) + for layer in controlnet_state_dict + if "input_hint_block" in layer and ("input_hint_block.0" not in layer) and ("input_hint_block.14" not in layer) + } + num_cond_embedding_blocks = len(cond_embedding_blocks) + + for idx in range(1, num_cond_embedding_blocks + 1): + diffusers_idx = idx - 1 + cond_block_id = 2 * idx + + new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.weight"] = controlnet_state_dict.get( + f"input_hint_block.{cond_block_id}.weight" + ) + new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.bias"] = controlnet_state_dict.get( + f"input_hint_block.{cond_block_id}.bias" + ) + + return new_checkpoint + + +def convert_ldm_vae_checkpoint(checkpoint, config): + # extract state dict for VAE + # remove the LDM_VAE_KEY prefix from the ldm checkpoint keys so that it is easier to map them to diffusers keys + vae_state_dict = {} + keys = list(checkpoint.keys()) + vae_key = "" + for ldm_vae_key in LDM_VAE_KEYS: + if any(k.startswith(ldm_vae_key) for k in keys): + vae_key = ldm_vae_key + + for key in keys: + if key.startswith(vae_key): + vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) + + new_checkpoint = {} + vae_diffusers_ldm_map = DIFFUSERS_TO_LDM_MAPPING["vae"] + for diffusers_key, ldm_key in vae_diffusers_ldm_map.items(): + if ldm_key not in vae_state_dict: + continue + new_checkpoint[diffusers_key] = vae_state_dict[ldm_key] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len(config["down_block_types"]) + down_blocks = { + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + update_vae_resnet_ldm_to_diffusers( + resnets, + new_checkpoint, + vae_state_dict, + mapping={"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}, + ) + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.get( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.get( + f"encoder.down.{i}.downsample.conv.bias" + ) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + update_vae_resnet_ldm_to_diffusers( + resnets, + new_checkpoint, + vae_state_dict, + mapping={"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}, + ) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + update_vae_attentions_ldm_to_diffusers( + mid_attentions, new_checkpoint, vae_state_dict, mapping={"old": "mid.attn_1", "new": "mid_block.attentions.0"} + ) + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len(config["up_block_types"]) + up_blocks = { + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) + } + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [ + key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + ] + update_vae_resnet_ldm_to_diffusers( + resnets, + new_checkpoint, + vae_state_dict, + mapping={"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}, + ) + if 
f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + update_vae_resnet_ldm_to_diffusers( + resnets, + new_checkpoint, + vae_state_dict, + mapping={"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}, + ) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + update_vae_attentions_ldm_to_diffusers( + mid_attentions, new_checkpoint, vae_state_dict, mapping={"old": "mid.attn_1", "new": "mid_block.attentions.0"} + ) + conv_attn_to_linear(new_checkpoint) + + return new_checkpoint + + +def convert_ldm_clip_checkpoint(checkpoint, remove_prefix=None): + keys = list(checkpoint.keys()) + text_model_dict = {} + + remove_prefixes = [] + remove_prefixes.extend(LDM_CLIP_PREFIX_TO_REMOVE) + if remove_prefix: + remove_prefixes.append(remove_prefix) + + for key in keys: + for prefix in remove_prefixes: + if key.startswith(prefix): + diffusers_key = key.replace(prefix, "") + text_model_dict[diffusers_key] = checkpoint.get(key) + + return text_model_dict + + +def convert_open_clip_checkpoint( + text_model, + checkpoint, + prefix="cond_stage_model.model.", +): + text_model_dict = {} + text_proj_key = prefix + "text_projection" + + if text_proj_key in checkpoint: + text_proj_dim = int(checkpoint[text_proj_key].shape[0]) + elif hasattr(text_model.config, "hidden_size"): + text_proj_dim = text_model.config.hidden_size + else: + text_proj_dim = LDM_OPEN_CLIP_TEXT_PROJECTION_DIM + + keys = list(checkpoint.keys()) + keys_to_ignore = SD_2_TEXT_ENCODER_KEYS_TO_IGNORE + + openclip_diffusers_ldm_map = DIFFUSERS_TO_LDM_MAPPING["openclip"]["layers"] + for diffusers_key, ldm_key in openclip_diffusers_ldm_map.items(): + ldm_key = prefix + ldm_key + if ldm_key not in checkpoint: + continue + if ldm_key in keys_to_ignore: + continue + if ldm_key.endswith("text_projection"): + text_model_dict[diffusers_key] = checkpoint[ldm_key].T.contiguous() + else: + text_model_dict[diffusers_key] = checkpoint[ldm_key] + + for key in keys: + if key in keys_to_ignore: + continue + + if not key.startswith(prefix + "transformer."): + continue + + diffusers_key = key.replace(prefix + "transformer.", "") + transformer_diffusers_to_ldm_map = DIFFUSERS_TO_LDM_MAPPING["openclip"]["transformer"] + for new_key, old_key in transformer_diffusers_to_ldm_map.items(): + diffusers_key = ( + diffusers_key.replace(old_key, new_key).replace(".in_proj_weight", "").replace(".in_proj_bias", "") + ) + + if key.endswith(".in_proj_weight"): + weight_value = checkpoint.get(key) + + text_model_dict[diffusers_key + ".q_proj.weight"] = weight_value[:text_proj_dim, :].clone().detach() + text_model_dict[diffusers_key + ".k_proj.weight"] = ( + weight_value[text_proj_dim : text_proj_dim * 2, :].clone().detach() + ) + text_model_dict[diffusers_key + ".v_proj.weight"] = weight_value[text_proj_dim * 2 :, :].clone().detach() + + elif key.endswith(".in_proj_bias"): + weight_value = checkpoint.get(key) + text_model_dict[diffusers_key + ".q_proj.bias"] = weight_value[:text_proj_dim].clone().detach() + text_model_dict[diffusers_key + 
".k_proj.bias"] = ( + weight_value[text_proj_dim : text_proj_dim * 2].clone().detach() + ) + text_model_dict[diffusers_key + ".v_proj.bias"] = weight_value[text_proj_dim * 2 :].clone().detach() + else: + text_model_dict[diffusers_key] = checkpoint.get(key) + + return text_model_dict + + +def create_diffusers_clip_model_from_ldm( + cls, + checkpoint, + subfolder="", + config=None, + torch_dtype=None, + local_files_only=None, + is_legacy_loading=False, +): + if config: + config = {"pretrained_model_name_or_path": config} + else: + config = fetch_diffusers_config(checkpoint) + + # For backwards compatibility + # Older versions of `from_single_file` expected CLIP configs to be placed in their original transformers model repo + # in the cache_dir, rather than in a subfolder of the Diffusers model + if is_legacy_loading: + logger.warning( + ( + "Detected legacy CLIP loading behavior. Please run `from_single_file` with `local_files_only=False once to update " + "the local cache directory with the necessary CLIP model config files. " + "Attempting to load CLIP model from legacy cache directory." + ) + ) + + if is_clip_model(checkpoint) or is_clip_sdxl_model(checkpoint): + clip_config = "openai/clip-vit-large-patch14" + config["pretrained_model_name_or_path"] = clip_config + subfolder = "" + + elif is_open_clip_model(checkpoint): + clip_config = "stabilityai/stable-diffusion-2" + config["pretrained_model_name_or_path"] = clip_config + subfolder = "text_encoder" + + else: + clip_config = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" + config["pretrained_model_name_or_path"] = clip_config + subfolder = "" + + model_config = cls.config_class.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only) + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + model = cls(model_config) + + position_embedding_dim = model.text_model.embeddings.position_embedding.weight.shape[-1] + + if is_clip_model(checkpoint): + diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint) + + elif ( + is_clip_sdxl_model(checkpoint) + and checkpoint[CHECKPOINT_KEY_NAMES["clip_sdxl"]].shape[-1] == position_embedding_dim + ): + diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint) + + elif ( + is_clip_sd3_model(checkpoint) + and checkpoint[CHECKPOINT_KEY_NAMES["clip_sd3"]].shape[-1] == position_embedding_dim + ): + diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_l.transformer.") + diffusers_format_checkpoint["text_projection.weight"] = torch.eye(position_embedding_dim) + + elif is_open_clip_model(checkpoint): + prefix = "cond_stage_model.model." + diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix) + + elif ( + is_open_clip_sdxl_model(checkpoint) + and checkpoint[CHECKPOINT_KEY_NAMES["open_clip_sdxl"]].shape[-1] == position_embedding_dim + ): + prefix = "conditioner.embedders.1.model." + diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix) + + elif is_open_clip_sdxl_refiner_model(checkpoint): + prefix = "conditioner.embedders.0.model." 
+ diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix) + + elif ( + is_open_clip_sd3_model(checkpoint) + and checkpoint[CHECKPOINT_KEY_NAMES["open_clip_sd3"]].shape[-1] == position_embedding_dim + ): + diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_g.transformer.") + + else: + raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.") + + if is_accelerate_available(): + load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) + else: + model.load_state_dict(diffusers_format_checkpoint, strict=False) + + if torch_dtype is not None: + model.to(torch_dtype) + + model.eval() + + return model + + +def _legacy_load_scheduler( + cls, + checkpoint, + component_name, + original_config=None, + **kwargs, +): + scheduler_type = kwargs.get("scheduler_type", None) + prediction_type = kwargs.get("prediction_type", None) + + if scheduler_type is not None: + deprecation_message = ( + "Please pass an instance of a Scheduler object directly to the `scheduler` argument in `from_single_file`\n\n" + "Example:\n\n" + "from diffusers import StableDiffusionPipeline, DDIMScheduler\n\n" + "scheduler = DDIMScheduler()\n" + "pipe = StableDiffusionPipeline.from_single_file(, scheduler=scheduler)\n" + ) + deprecate("scheduler_type", "1.0.0", deprecation_message) + + if prediction_type is not None: + deprecation_message = ( + "Please configure an instance of a Scheduler with the appropriate `prediction_type` and " + "pass the object directly to the `scheduler` argument in `from_single_file`.\n\n" + "Example:\n\n" + "from diffusers import StableDiffusionPipeline, DDIMScheduler\n\n" + 'scheduler = DDIMScheduler(prediction_type="v_prediction")\n' + "pipe = StableDiffusionPipeline.from_single_file(, scheduler=scheduler)\n" + ) + deprecate("prediction_type", "1.0.0", deprecation_message) + + # copy the defaults so repeated calls do not mutate the module-level config + scheduler_config = copy.deepcopy(SCHEDULER_DEFAULT_CONFIG) + model_type = infer_diffusers_model_type(checkpoint=checkpoint) + + global_step = checkpoint["global_step"] if "global_step" in checkpoint else None + + if original_config: + num_train_timesteps = original_config["model"]["params"].get("timesteps", 1000) + else: + num_train_timesteps = 1000 + + scheduler_config["num_train_timesteps"] = num_train_timesteps + + if model_type == "v2": + if prediction_type is None: + # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type="epsilon"`, + # as the fallback below relies on a brittle global step heuristic + prediction_type = "epsilon" if global_step == 875000 else "v_prediction" + + else: + prediction_type = prediction_type or "epsilon" + + scheduler_config["prediction_type"] = prediction_type + + if model_type in ["xl_base", "xl_refiner"]: + scheduler_type = "euler" + elif model_type == "playground-v2-5": + scheduler_type = "edm_dpm_solver_multistep" + else: + if original_config: + beta_start = original_config["model"]["params"].get("linear_start") + beta_end = original_config["model"]["params"].get("linear_end") + + else: + beta_start = 0.02 + beta_end = 0.085 + + scheduler_config["beta_start"] = beta_start + scheduler_config["beta_end"] = beta_end + scheduler_config["beta_schedule"] = "scaled_linear" + scheduler_config["clip_sample"] = False + scheduler_config["set_alpha_to_one"] = False + + # to deal with an edge case: the StableDiffusionUpscale pipeline has two schedulers + if component_name == "low_res_scheduler": + return cls.from_config( + { + "beta_end": 0.02, + "beta_schedule": "scaled_linear", + 
"beta_start": 0.0001, + "clip_sample": True, + "num_train_timesteps": 1000, + "prediction_type": "epsilon", + "trained_betas": None, + "variance_type": "fixed_small", + } + ) + + if scheduler_type is None: + return cls.from_config(scheduler_config) + + elif scheduler_type == "pndm": + scheduler_config["skip_prk_steps"] = True + scheduler = PNDMScheduler.from_config(scheduler_config) + + elif scheduler_type == "lms": + scheduler = LMSDiscreteScheduler.from_config(scheduler_config) + + elif scheduler_type == "heun": + scheduler = HeunDiscreteScheduler.from_config(scheduler_config) + + elif scheduler_type == "euler": + scheduler = EulerDiscreteScheduler.from_config(scheduler_config) + + elif scheduler_type == "euler-ancestral": + scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler_config) + + elif scheduler_type == "dpm": + scheduler = DPMSolverMultistepScheduler.from_config(scheduler_config) + + elif scheduler_type == "ddim": + scheduler = DDIMScheduler.from_config(scheduler_config) + + elif scheduler_type == "edm_dpm_solver_multistep": + scheduler_config = { + "algorithm_type": "dpmsolver++", + "dynamic_thresholding_ratio": 0.995, + "euler_at_final": False, + "final_sigmas_type": "zero", + "lower_order_final": True, + "num_train_timesteps": 1000, + "prediction_type": "epsilon", + "rho": 7.0, + "sample_max_value": 1.0, + "sigma_data": 0.5, + "sigma_max": 80.0, + "sigma_min": 0.002, + "solver_order": 2, + "solver_type": "midpoint", + "thresholding": False, + } + scheduler = EDMDPMSolverMultistepScheduler(**scheduler_config) + + else: + raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!") + + return scheduler + + +def _legacy_load_clip_tokenizer(cls, checkpoint, config=None, local_files_only=False): + if config: + config = {"pretrained_model_name_or_path": config} + else: + config = fetch_diffusers_config(checkpoint) + + if is_clip_model(checkpoint) or is_clip_sdxl_model(checkpoint): + clip_config = "openai/clip-vit-large-patch14" + config["pretrained_model_name_or_path"] = clip_config + subfolder = "" + + elif is_open_clip_model(checkpoint): + clip_config = "stabilityai/stable-diffusion-2" + config["pretrained_model_name_or_path"] = clip_config + subfolder = "tokenizer" + + else: + clip_config = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" + config["pretrained_model_name_or_path"] = clip_config + subfolder = "" + + tokenizer = cls.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only) + + return tokenizer + + +def _legacy_load_safety_checker(local_files_only, torch_dtype): + # Support for loading safety checker components using the deprecated + # `load_safety_checker` argument. + + from ..pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + feature_extractor = AutoImageProcessor.from_pretrained( + "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only, torch_dtype=torch_dtype + ) + safety_checker = StableDiffusionSafetyChecker.from_pretrained( + "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only, torch_dtype=torch_dtype + ) + + return {"safety_checker": safety_checker, "feature_extractor": feature_extractor} + + +# in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale; +# while in diffusers it split into scale, shift. 
Here we swap the linear projection weights in order to be able to use the diffusers implementation +def swap_scale_shift(weight, dim): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + + +def swap_proj_gate(weight): + proj, gate = weight.chunk(2, dim=0) + new_weight = torch.cat([gate, proj], dim=0) + return new_weight + + +def get_attn2_layers(state_dict): + attn2_layers = [] + for key in state_dict.keys(): + if "attn2." in key: + # Extract the layer number from the key + layer_num = int(key.split(".")[1]) + attn2_layers.append(layer_num) + + return tuple(sorted(set(attn2_layers))) + + +def get_caption_projection_dim(state_dict): + caption_projection_dim = state_dict["context_embedder.weight"].shape[0] + return caption_projection_dim + + +def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): + converted_state_dict = {} + keys = list(checkpoint.keys()) + for k in keys: + if "model.diffusion_model." in k: + checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) + + num_layers = max(int(k.split(".", 2)[1]) for k in checkpoint if "joint_blocks" in k) + 1 + dual_attention_layers = get_attn2_layers(checkpoint) + + caption_projection_dim = get_caption_projection_dim(checkpoint) + has_qk_norm = any("ln_q" in key for key in checkpoint.keys()) + + # Positional and patch embeddings. + converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("pos_embed") + converted_state_dict["pos_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight") + converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias") + + # Timestep embeddings. + converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop( + "t_embedder.mlp.0.weight" + ) + converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias") + converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop( + "t_embedder.mlp.2.weight" + ) + converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias") + + # Context projections. + converted_state_dict["context_embedder.weight"] = checkpoint.pop("context_embedder.weight") + converted_state_dict["context_embedder.bias"] = checkpoint.pop("context_embedder.bias") + + # Pooled context projection. + converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("y_embedder.mlp.0.weight") + converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("y_embedder.mlp.0.bias") + converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop("y_embedder.mlp.2.weight") + converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("y_embedder.mlp.2.bias") + + # Transformer blocks 🎸. 
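+ # Each joint block carries an x_block (latent stream) and a context_block (text stream); their fused + # qkv projections are split into separate Q, K, V tensors with torch.chunk(..., 3, dim=0) below.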
+ for i in range(num_layers): + # Q, K, V + sample_q, sample_k, sample_v = torch.chunk( + checkpoint.pop(f"joint_blocks.{i}.x_block.attn.qkv.weight"), 3, dim=0 + ) + context_q, context_k, context_v = torch.chunk( + checkpoint.pop(f"joint_blocks.{i}.context_block.attn.qkv.weight"), 3, dim=0 + ) + sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk( + checkpoint.pop(f"joint_blocks.{i}.x_block.attn.qkv.bias"), 3, dim=0 + ) + context_q_bias, context_k_bias, context_v_bias = torch.chunk( + checkpoint.pop(f"joint_blocks.{i}.context_block.attn.qkv.bias"), 3, dim=0 + ) + + converted_state_dict[f"transformer_blocks.{i}.attn.to_q.weight"] = torch.cat([sample_q]) + converted_state_dict[f"transformer_blocks.{i}.attn.to_q.bias"] = torch.cat([sample_q_bias]) + converted_state_dict[f"transformer_blocks.{i}.attn.to_k.weight"] = torch.cat([sample_k]) + converted_state_dict[f"transformer_blocks.{i}.attn.to_k.bias"] = torch.cat([sample_k_bias]) + converted_state_dict[f"transformer_blocks.{i}.attn.to_v.weight"] = torch.cat([sample_v]) + converted_state_dict[f"transformer_blocks.{i}.attn.to_v.bias"] = torch.cat([sample_v_bias]) + + converted_state_dict[f"transformer_blocks.{i}.attn.add_q_proj.weight"] = torch.cat([context_q]) + converted_state_dict[f"transformer_blocks.{i}.attn.add_q_proj.bias"] = torch.cat([context_q_bias]) + converted_state_dict[f"transformer_blocks.{i}.attn.add_k_proj.weight"] = torch.cat([context_k]) + converted_state_dict[f"transformer_blocks.{i}.attn.add_k_proj.bias"] = torch.cat([context_k_bias]) + converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.weight"] = torch.cat([context_v]) + converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.bias"] = torch.cat([context_v_bias]) + + # qk norm + if has_qk_norm: + converted_state_dict[f"transformer_blocks.{i}.attn.norm_q.weight"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.attn.ln_q.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn.norm_k.weight"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.attn.ln_k.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_q.weight"] = checkpoint.pop( + f"joint_blocks.{i}.context_block.attn.ln_q.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_k.weight"] = checkpoint.pop( + f"joint_blocks.{i}.context_block.attn.ln_k.weight" + ) + + # output projections. 
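+ # (the context stream of the final joint block has no output projection or feed-forward, hence the i == num_layers - 1 guards below)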
+ converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.weight"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.attn.proj.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.bias"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.attn.proj.bias" + ) + if not (i == num_layers - 1): + converted_state_dict[f"transformer_blocks.{i}.attn.to_add_out.weight"] = checkpoint.pop( + f"joint_blocks.{i}.context_block.attn.proj.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn.to_add_out.bias"] = checkpoint.pop( + f"joint_blocks.{i}.context_block.attn.proj.bias" + ) + + if i in dual_attention_layers: + # Q, K, V + sample_q2, sample_k2, sample_v2 = torch.chunk( + checkpoint.pop(f"joint_blocks.{i}.x_block.attn2.qkv.weight"), 3, dim=0 + ) + sample_q2_bias, sample_k2_bias, sample_v2_bias = torch.chunk( + checkpoint.pop(f"joint_blocks.{i}.x_block.attn2.qkv.bias"), 3, dim=0 + ) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = torch.cat([sample_q2]) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = torch.cat([sample_q2_bias]) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = torch.cat([sample_k2]) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = torch.cat([sample_k2_bias]) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = torch.cat([sample_v2]) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = torch.cat([sample_v2_bias]) + + # qk norm + if has_qk_norm: + converted_state_dict[f"transformer_blocks.{i}.attn2.norm_q.weight"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.attn2.ln_q.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn2.norm_k.weight"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.attn2.ln_k.weight" + ) + + # output projections. + converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.attn2.proj.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.attn2.proj.bias" + ) + + # norms. + converted_state_dict[f"transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.adaLN_modulation.1.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.norm1.linear.bias"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.adaLN_modulation.1.bias" + ) + if not (i == num_layers - 1): + converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.weight"] = checkpoint.pop( + f"joint_blocks.{i}.context_block.adaLN_modulation.1.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.bias"] = checkpoint.pop( + f"joint_blocks.{i}.context_block.adaLN_modulation.1.bias" + ) + else: + converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.weight"] = swap_scale_shift( + checkpoint.pop(f"joint_blocks.{i}.context_block.adaLN_modulation.1.weight"), + dim=caption_projection_dim, + ) + converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.bias"] = swap_scale_shift( + checkpoint.pop(f"joint_blocks.{i}.context_block.adaLN_modulation.1.bias"), + dim=caption_projection_dim, + ) + + # ffs. 
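+ # The original mlp.fc1/mlp.fc2 pair maps onto the diffusers FeedForward layout:
+ # fc1 becomes ff.net.0.proj (the input projection inside the activation module) and
+ # fc2 becomes ff.net.2 (the output projection).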
+ converted_state_dict[f"transformer_blocks.{i}.ff.net.0.proj.weight"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.mlp.fc1.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.ff.net.0.proj.bias"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.mlp.fc1.bias" + ) + converted_state_dict[f"transformer_blocks.{i}.ff.net.2.weight"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.mlp.fc2.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.ff.net.2.bias"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.mlp.fc2.bias" + ) + if not (i == num_layers - 1): + converted_state_dict[f"transformer_blocks.{i}.ff_context.net.0.proj.weight"] = checkpoint.pop( + f"joint_blocks.{i}.context_block.mlp.fc1.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.ff_context.net.0.proj.bias"] = checkpoint.pop( + f"joint_blocks.{i}.context_block.mlp.fc1.bias" + ) + converted_state_dict[f"transformer_blocks.{i}.ff_context.net.2.weight"] = checkpoint.pop( + f"joint_blocks.{i}.context_block.mlp.fc2.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.ff_context.net.2.bias"] = checkpoint.pop( + f"joint_blocks.{i}.context_block.mlp.fc2.bias" + ) + + # Final blocks. + converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight") + converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias") + converted_state_dict["norm_out.linear.weight"] = swap_scale_shift( + checkpoint.pop("final_layer.adaLN_modulation.1.weight"), dim=caption_projection_dim + ) + converted_state_dict["norm_out.linear.bias"] = swap_scale_shift( + checkpoint.pop("final_layer.adaLN_modulation.1.bias"), dim=caption_projection_dim + ) + + return converted_state_dict + + +def is_t5_in_single_file(checkpoint): + if "text_encoders.t5xxl.transformer.shared.weight" in checkpoint: + return True + + return False + + +def convert_sd3_t5_checkpoint_to_diffusers(checkpoint): + keys = list(checkpoint.keys()) + text_model_dict = {} + + remove_prefixes = ["text_encoders.t5xxl.transformer."] + + for key in keys: + for prefix in remove_prefixes: + if key.startswith(prefix): + diffusers_key = key.replace(prefix, "") + text_model_dict[diffusers_key] = checkpoint.get(key) + + return text_model_dict + + +def create_diffusers_t5_model_from_checkpoint( + cls, + checkpoint, + subfolder="", + config=None, + torch_dtype=None, + local_files_only=None, +): + if config: + config = {"pretrained_model_name_or_path": config} + else: + config = fetch_diffusers_config(checkpoint) + + model_config = cls.config_class.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only) + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + model = cls(model_config) + + diffusers_format_checkpoint = convert_sd3_t5_checkpoint_to_diffusers(checkpoint) + + if is_accelerate_available(): + load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) + else: + model.load_state_dict(diffusers_format_checkpoint) + + use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (torch_dtype == torch.float16) + if use_keep_in_fp32_modules: + keep_in_fp32_modules = model._keep_in_fp32_modules + else: + keep_in_fp32_modules = [] + + if keep_in_fp32_modules is not None: + for name, param in model.named_parameters(): + if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules): + # param = param.to(torch.float32) does not work here as only in the local scope. 
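+ # Mutating param.data upcasts the tensor registered on the module in place;
+ # rebinding the local name would leave the model itself unchanged.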
+ param.data = param.data.to(torch.float32) + + return model + + +def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs): + converted_state_dict = {} + for k, v in checkpoint.items(): + if "pos_encoder" in k: + continue + + else: + converted_state_dict[ + k.replace(".norms.0", ".norm1") + .replace(".norms.1", ".norm2") + .replace(".ff_norm", ".norm3") + .replace(".attention_blocks.0", ".attn1") + .replace(".attention_blocks.1", ".attn2") + .replace(".temporal_transformer", "") + ] = v + + return converted_state_dict + + +def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): + converted_state_dict = {} + keys = list(checkpoint.keys()) + + for k in keys: + if "model.diffusion_model." in k: + checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) + + num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1 # noqa: C401 + num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1 # noqa: C401 + mlp_ratio = 4.0 + inner_dim = 3072 + + # in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale; + # while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation + def swap_scale_shift(weight): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + + ## time_text_embed.timestep_embedder <- time_in + converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop( + "time_in.in_layer.weight" + ) + converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("time_in.in_layer.bias") + converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop( + "time_in.out_layer.weight" + ) + converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("time_in.out_layer.bias") + + ## time_text_embed.text_embedder <- vector_in + converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("vector_in.in_layer.weight") + converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("vector_in.in_layer.bias") + converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop( + "vector_in.out_layer.weight" + ) + converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("vector_in.out_layer.bias") + + # guidance + has_guidance = any("guidance" in k for k in checkpoint) + if has_guidance: + converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = checkpoint.pop( + "guidance_in.in_layer.weight" + ) + converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = checkpoint.pop( + "guidance_in.in_layer.bias" + ) + converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = checkpoint.pop( + "guidance_in.out_layer.weight" + ) + converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = checkpoint.pop( + "guidance_in.out_layer.bias" + ) + + # context_embedder + converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight") + converted_state_dict["context_embedder.bias"] = checkpoint.pop("txt_in.bias") + + # x_embedder + converted_state_dict["x_embedder.weight"] = checkpoint.pop("img_in.weight") + converted_state_dict["x_embedder.bias"] = checkpoint.pop("img_in.bias") + + # double transformer blocks + for i in 
range(num_layers): + block_prefix = f"transformer_blocks.{i}." + # norms. + ## norm1 + converted_state_dict[f"{block_prefix}norm1.linear.weight"] = checkpoint.pop( + f"double_blocks.{i}.img_mod.lin.weight" + ) + converted_state_dict[f"{block_prefix}norm1.linear.bias"] = checkpoint.pop( + f"double_blocks.{i}.img_mod.lin.bias" + ) + ## norm1_context + converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = checkpoint.pop( + f"double_blocks.{i}.txt_mod.lin.weight" + ) + converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = checkpoint.pop( + f"double_blocks.{i}.txt_mod.lin.bias" + ) + # Q, K, V + sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0) + context_q, context_k, context_v = torch.chunk( + checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0 + ) + sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk( + checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0 + ) + context_q_bias, context_k_bias, context_v_bias = torch.chunk( + checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0 + ) + converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q]) + converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias]) + converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k]) + converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias]) + converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v]) + converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias]) + converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q]) + converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias]) + converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k]) + converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias]) + converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v]) + converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias]) + # qk_norm + converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop( + f"double_blocks.{i}.img_attn.norm.query_norm.scale" + ) + converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop( + f"double_blocks.{i}.img_attn.norm.key_norm.scale" + ) + converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = checkpoint.pop( + f"double_blocks.{i}.txt_attn.norm.query_norm.scale" + ) + converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = checkpoint.pop( + f"double_blocks.{i}.txt_attn.norm.key_norm.scale" + ) + # ff img_mlp + converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = checkpoint.pop( + f"double_blocks.{i}.img_mlp.0.weight" + ) + converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.0.bias") + converted_state_dict[f"{block_prefix}ff.net.2.weight"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.weight") + converted_state_dict[f"{block_prefix}ff.net.2.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.bias") + converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = checkpoint.pop( + f"double_blocks.{i}.txt_mlp.0.weight" + ) + converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = checkpoint.pop( + f"double_blocks.{i}.txt_mlp.0.bias" + ) + converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = checkpoint.pop( + 
f"double_blocks.{i}.txt_mlp.2.weight" + ) + converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = checkpoint.pop( + f"double_blocks.{i}.txt_mlp.2.bias" + ) + # output projections. + converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = checkpoint.pop( + f"double_blocks.{i}.img_attn.proj.weight" + ) + converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = checkpoint.pop( + f"double_blocks.{i}.img_attn.proj.bias" + ) + converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = checkpoint.pop( + f"double_blocks.{i}.txt_attn.proj.weight" + ) + converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = checkpoint.pop( + f"double_blocks.{i}.txt_attn.proj.bias" + ) + + # single transfomer blocks + for i in range(num_single_layers): + block_prefix = f"single_transformer_blocks.{i}." + # norm.linear <- single_blocks.0.modulation.lin + converted_state_dict[f"{block_prefix}norm.linear.weight"] = checkpoint.pop( + f"single_blocks.{i}.modulation.lin.weight" + ) + converted_state_dict[f"{block_prefix}norm.linear.bias"] = checkpoint.pop( + f"single_blocks.{i}.modulation.lin.bias" + ) + # Q, K, V, mlp + mlp_hidden_dim = int(inner_dim * mlp_ratio) + split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim) + q, k, v, mlp = torch.split(checkpoint.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0) + q_bias, k_bias, v_bias, mlp_bias = torch.split( + checkpoint.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0 + ) + converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q]) + converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias]) + converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k]) + converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias]) + converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v]) + converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias]) + converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp]) + converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias]) + # qk norm + converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop( + f"single_blocks.{i}.norm.query_norm.scale" + ) + converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop( + f"single_blocks.{i}.norm.key_norm.scale" + ) + # output projections. 
+ converted_state_dict[f"{block_prefix}proj_out.weight"] = checkpoint.pop(f"single_blocks.{i}.linear2.weight") + converted_state_dict[f"{block_prefix}proj_out.bias"] = checkpoint.pop(f"single_blocks.{i}.linear2.bias") + + converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight") + converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias") + converted_state_dict["norm_out.linear.weight"] = swap_scale_shift( + checkpoint.pop("final_layer.adaLN_modulation.1.weight") + ) + converted_state_dict["norm_out.linear.bias"] = swap_scale_shift( + checkpoint.pop("final_layer.adaLN_modulation.1.bias") + ) + + return converted_state_dict + + +def convert_ltx_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): + converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "vae" not in key} + + TRANSFORMER_KEYS_RENAME_DICT = { + "model.diffusion_model.": "", + "patchify_proj": "proj_in", + "adaln_single": "time_embed", + "q_norm": "norm_q", + "k_norm": "norm_k", + } + + TRANSFORMER_SPECIAL_KEYS_REMAP = {} + + for key in list(converted_state_dict.keys()): + new_key = key + for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + converted_state_dict[new_key] = converted_state_dict.pop(key) + + for key in list(converted_state_dict.keys()): + for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items(): + if special_key not in key: + continue + handler_fn_inplace(key, converted_state_dict) + + return converted_state_dict + + +def convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs): + converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "vae." in key} + + def remove_keys_(key: str, state_dict): + state_dict.pop(key) + + VAE_KEYS_RENAME_DICT = { + # common + "vae.": "", + # decoder + "up_blocks.0": "mid_block", + "up_blocks.1": "up_blocks.0", + "up_blocks.2": "up_blocks.1.upsamplers.0", + "up_blocks.3": "up_blocks.1", + "up_blocks.4": "up_blocks.2.conv_in", + "up_blocks.5": "up_blocks.2.upsamplers.0", + "up_blocks.6": "up_blocks.2", + "up_blocks.7": "up_blocks.3.conv_in", + "up_blocks.8": "up_blocks.3.upsamplers.0", + "up_blocks.9": "up_blocks.3", + # encoder + "down_blocks.0": "down_blocks.0", + "down_blocks.1": "down_blocks.0.downsamplers.0", + "down_blocks.2": "down_blocks.0.conv_out", + "down_blocks.3": "down_blocks.1", + "down_blocks.4": "down_blocks.1.downsamplers.0", + "down_blocks.5": "down_blocks.1.conv_out", + "down_blocks.6": "down_blocks.2", + "down_blocks.7": "down_blocks.2.downsamplers.0", + "down_blocks.8": "down_blocks.3", + "down_blocks.9": "mid_block", + # common + "conv_shortcut": "conv_shortcut.conv", + "res_blocks": "resnets", + "norm3.norm": "norm3", + "per_channel_statistics.mean-of-means": "latents_mean", + "per_channel_statistics.std-of-means": "latents_std", + } + + VAE_091_RENAME_DICT = { + # decoder + "up_blocks.0": "mid_block", + "up_blocks.1": "up_blocks.0.upsamplers.0", + "up_blocks.2": "up_blocks.0", + "up_blocks.3": "up_blocks.1.upsamplers.0", + "up_blocks.4": "up_blocks.1", + "up_blocks.5": "up_blocks.2.upsamplers.0", + "up_blocks.6": "up_blocks.2", + "up_blocks.7": "up_blocks.3.upsamplers.0", + "up_blocks.8": "up_blocks.3", + # common + "last_time_embedder": "time_embedder", + "last_scale_shift_table": "scale_shift_table", + } + + VAE_095_RENAME_DICT = { + # decoder + "up_blocks.0": "mid_block", + "up_blocks.1": "up_blocks.0.upsamplers.0", + "up_blocks.2": 
"up_blocks.0", + "up_blocks.3": "up_blocks.1.upsamplers.0", + "up_blocks.4": "up_blocks.1", + "up_blocks.5": "up_blocks.2.upsamplers.0", + "up_blocks.6": "up_blocks.2", + "up_blocks.7": "up_blocks.3.upsamplers.0", + "up_blocks.8": "up_blocks.3", + # encoder + "down_blocks.0": "down_blocks.0", + "down_blocks.1": "down_blocks.0.downsamplers.0", + "down_blocks.2": "down_blocks.1", + "down_blocks.3": "down_blocks.1.downsamplers.0", + "down_blocks.4": "down_blocks.2", + "down_blocks.5": "down_blocks.2.downsamplers.0", + "down_blocks.6": "down_blocks.3", + "down_blocks.7": "down_blocks.3.downsamplers.0", + "down_blocks.8": "mid_block", + # common + "last_time_embedder": "time_embedder", + "last_scale_shift_table": "scale_shift_table", + } + + VAE_SPECIAL_KEYS_REMAP = { + "per_channel_statistics.channel": remove_keys_, + "per_channel_statistics.mean-of-means": remove_keys_, + "per_channel_statistics.mean-of-stds": remove_keys_, + } + + if converted_state_dict["vae.encoder.conv_out.conv.weight"].shape[1] == 2048: + VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT) + elif "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict: + VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT) + + for key in list(converted_state_dict.keys()): + new_key = key + for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + converted_state_dict[new_key] = converted_state_dict.pop(key) + + for key in list(converted_state_dict.keys()): + for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items(): + if special_key not in key: + continue + handler_fn_inplace(key, converted_state_dict) + + return converted_state_dict + + +def convert_autoencoder_dc_checkpoint_to_diffusers(checkpoint, **kwargs): + converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())} + + def remap_qkv_(key: str, state_dict): + qkv = state_dict.pop(key) + q, k, v = torch.chunk(qkv, 3, dim=0) + parent_module, _, _ = key.rpartition(".qkv.conv.weight") + state_dict[f"{parent_module}.to_q.weight"] = q.squeeze() + state_dict[f"{parent_module}.to_k.weight"] = k.squeeze() + state_dict[f"{parent_module}.to_v.weight"] = v.squeeze() + + def remap_proj_conv_(key: str, state_dict): + parent_module, _, _ = key.rpartition(".proj.conv.weight") + state_dict[f"{parent_module}.to_out.weight"] = state_dict.pop(key).squeeze() + + AE_KEYS_RENAME_DICT = { + # common + "main.": "", + "op_list.": "", + "context_module": "attn", + "local_module": "conv_out", + # NOTE: The below two lines work because scales in the available configs only have a tuple length of 1 + # If there were more scales, there would be more layers, so a loop would be better to handle this + "aggreg.0.0": "to_qkv_multiscale.0.proj_in", + "aggreg.0.1": "to_qkv_multiscale.0.proj_out", + "depth_conv.conv": "conv_depth", + "inverted_conv.conv": "conv_inverted", + "point_conv.conv": "conv_point", + "point_conv.norm": "norm", + "conv.conv.": "conv.", + "conv1.conv": "conv1", + "conv2.conv": "conv2", + "conv2.norm": "norm", + "proj.norm": "norm_out", + # encoder + "encoder.project_in.conv": "encoder.conv_in", + "encoder.project_out.0.conv": "encoder.conv_out", + "encoder.stages": "encoder.down_blocks", + # decoder + "decoder.project_in.conv": "decoder.conv_in", + "decoder.project_out.0": "decoder.norm_out", + "decoder.project_out.2.conv": "decoder.conv_out", + "decoder.stages": "decoder.up_blocks", + } + + AE_F32C32_F64C128_F128C512_KEYS = { + "encoder.project_in.conv": 
"encoder.conv_in.conv", + "decoder.project_out.2.conv": "decoder.conv_out.conv", + } + + AE_SPECIAL_KEYS_REMAP = { + "qkv.conv.weight": remap_qkv_, + "proj.conv.weight": remap_proj_conv_, + } + if "encoder.project_in.conv.bias" not in converted_state_dict: + AE_KEYS_RENAME_DICT.update(AE_F32C32_F64C128_F128C512_KEYS) + + for key in list(converted_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in AE_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + converted_state_dict[new_key] = converted_state_dict.pop(key) + + for key in list(converted_state_dict.keys()): + for special_key, handler_fn_inplace in AE_SPECIAL_KEYS_REMAP.items(): + if special_key not in key: + continue + handler_fn_inplace(key, converted_state_dict) + + return converted_state_dict + + +def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): + converted_state_dict = {} + + # Comfy checkpoints add this prefix + keys = list(checkpoint.keys()) + for k in keys: + if "model.diffusion_model." in k: + checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) + + # Convert patch_embed + converted_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight") + converted_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias") + + # Convert time_embed + converted_state_dict["time_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight") + converted_state_dict["time_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias") + converted_state_dict["time_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight") + converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias") + converted_state_dict["time_embed.pooler.to_kv.weight"] = checkpoint.pop("t5_y_embedder.to_kv.weight") + converted_state_dict["time_embed.pooler.to_kv.bias"] = checkpoint.pop("t5_y_embedder.to_kv.bias") + converted_state_dict["time_embed.pooler.to_q.weight"] = checkpoint.pop("t5_y_embedder.to_q.weight") + converted_state_dict["time_embed.pooler.to_q.bias"] = checkpoint.pop("t5_y_embedder.to_q.bias") + converted_state_dict["time_embed.pooler.to_out.weight"] = checkpoint.pop("t5_y_embedder.to_out.weight") + converted_state_dict["time_embed.pooler.to_out.bias"] = checkpoint.pop("t5_y_embedder.to_out.bias") + converted_state_dict["time_embed.caption_proj.weight"] = checkpoint.pop("t5_yproj.weight") + converted_state_dict["time_embed.caption_proj.bias"] = checkpoint.pop("t5_yproj.bias") + + # Convert transformer blocks + num_layers = 48 + for i in range(num_layers): + block_prefix = f"transformer_blocks.{i}." + old_prefix = f"blocks.{i}." 
+ + # norm1 + converted_state_dict[block_prefix + "norm1.linear.weight"] = checkpoint.pop(old_prefix + "mod_x.weight") + converted_state_dict[block_prefix + "norm1.linear.bias"] = checkpoint.pop(old_prefix + "mod_x.bias") + if i < num_layers - 1: + converted_state_dict[block_prefix + "norm1_context.linear.weight"] = checkpoint.pop( + old_prefix + "mod_y.weight" + ) + converted_state_dict[block_prefix + "norm1_context.linear.bias"] = checkpoint.pop( + old_prefix + "mod_y.bias" + ) + else: + converted_state_dict[block_prefix + "norm1_context.linear_1.weight"] = checkpoint.pop( + old_prefix + "mod_y.weight" + ) + converted_state_dict[block_prefix + "norm1_context.linear_1.bias"] = checkpoint.pop( + old_prefix + "mod_y.bias" + ) + + # Visual attention + qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_x.weight") + q, k, v = qkv_weight.chunk(3, dim=0) + + converted_state_dict[block_prefix + "attn1.to_q.weight"] = q + converted_state_dict[block_prefix + "attn1.to_k.weight"] = k + converted_state_dict[block_prefix + "attn1.to_v.weight"] = v + converted_state_dict[block_prefix + "attn1.norm_q.weight"] = checkpoint.pop( + old_prefix + "attn.q_norm_x.weight" + ) + converted_state_dict[block_prefix + "attn1.norm_k.weight"] = checkpoint.pop( + old_prefix + "attn.k_norm_x.weight" + ) + converted_state_dict[block_prefix + "attn1.to_out.0.weight"] = checkpoint.pop( + old_prefix + "attn.proj_x.weight" + ) + converted_state_dict[block_prefix + "attn1.to_out.0.bias"] = checkpoint.pop(old_prefix + "attn.proj_x.bias") + + # Context attention + qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_y.weight") + q, k, v = qkv_weight.chunk(3, dim=0) + + converted_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q + converted_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k + converted_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v + converted_state_dict[block_prefix + "attn1.norm_added_q.weight"] = checkpoint.pop( + old_prefix + "attn.q_norm_y.weight" + ) + converted_state_dict[block_prefix + "attn1.norm_added_k.weight"] = checkpoint.pop( + old_prefix + "attn.k_norm_y.weight" + ) + if i < num_layers - 1: + converted_state_dict[block_prefix + "attn1.to_add_out.weight"] = checkpoint.pop( + old_prefix + "attn.proj_y.weight" + ) + converted_state_dict[block_prefix + "attn1.to_add_out.bias"] = checkpoint.pop( + old_prefix + "attn.proj_y.bias" + ) + + # MLP + converted_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate( + checkpoint.pop(old_prefix + "mlp_x.w1.weight") + ) + converted_state_dict[block_prefix + "ff.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_x.w2.weight") + if i < num_layers - 1: + converted_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate( + checkpoint.pop(old_prefix + "mlp_y.w1.weight") + ) + converted_state_dict[block_prefix + "ff_context.net.2.weight"] = checkpoint.pop( + old_prefix + "mlp_y.w2.weight" + ) + + # Output layers + converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(checkpoint.pop("final_layer.mod.weight"), dim=0) + converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(checkpoint.pop("final_layer.mod.bias"), dim=0) + converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight") + converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias") + + converted_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies") + + return converted_state_dict + + +def convert_hunyuan_video_transformer_to_diffusers(checkpoint, **kwargs): + def 
remap_norm_scale_shift_(key, state_dict): + weight = state_dict.pop(key) + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight + + def remap_txt_in_(key, state_dict): + def rename_key(key): + new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks") + new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear") + new_key = new_key.replace("txt_in", "context_embedder") + new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1") + new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2") + new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder") + new_key = new_key.replace("mlp", "ff") + return new_key + + if "self_attn_qkv" in key: + weight = state_dict.pop(key) + to_q, to_k, to_v = weight.chunk(3, dim=0) + state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q + state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k + state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v + else: + state_dict[rename_key(key)] = state_dict.pop(key) + + def remap_img_attn_qkv_(key, state_dict): + weight = state_dict.pop(key) + to_q, to_k, to_v = weight.chunk(3, dim=0) + state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q + state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k + state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v + + def remap_txt_attn_qkv_(key, state_dict): + weight = state_dict.pop(key) + to_q, to_k, to_v = weight.chunk(3, dim=0) + state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q + state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k + state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v + + def remap_single_transformer_blocks_(key, state_dict): + hidden_size = 3072 + + if "linear1.weight" in key: + linear1_weight = state_dict.pop(key) + split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size) + q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0) + new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.weight") + state_dict[f"{new_key}.attn.to_q.weight"] = q + state_dict[f"{new_key}.attn.to_k.weight"] = k + state_dict[f"{new_key}.attn.to_v.weight"] = v + state_dict[f"{new_key}.proj_mlp.weight"] = mlp + + elif "linear1.bias" in key: + linear1_bias = state_dict.pop(key) + split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size) + q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0) + new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.bias") + state_dict[f"{new_key}.attn.to_q.bias"] = q_bias + state_dict[f"{new_key}.attn.to_k.bias"] = k_bias + state_dict[f"{new_key}.attn.to_v.bias"] = v_bias + state_dict[f"{new_key}.proj_mlp.bias"] = mlp_bias + + else: + new_key = key.replace("single_blocks", "single_transformer_blocks") + new_key = new_key.replace("linear2", "proj_out") + new_key = new_key.replace("q_norm", "attn.norm_q") + new_key = new_key.replace("k_norm", "attn.norm_k") + state_dict[new_key] = state_dict.pop(key) + + TRANSFORMER_KEYS_RENAME_DICT = { + "img_in": "x_embedder", + "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1", + "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2", + "guidance_in.mlp.0": 
"time_text_embed.guidance_embedder.linear_1", + "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2", + "vector_in.in_layer": "time_text_embed.text_embedder.linear_1", + "vector_in.out_layer": "time_text_embed.text_embedder.linear_2", + "double_blocks": "transformer_blocks", + "img_attn_q_norm": "attn.norm_q", + "img_attn_k_norm": "attn.norm_k", + "img_attn_proj": "attn.to_out.0", + "txt_attn_q_norm": "attn.norm_added_q", + "txt_attn_k_norm": "attn.norm_added_k", + "txt_attn_proj": "attn.to_add_out", + "img_mod.linear": "norm1.linear", + "img_norm1": "norm1.norm", + "img_norm2": "norm2", + "img_mlp": "ff", + "txt_mod.linear": "norm1_context.linear", + "txt_norm1": "norm1.norm", + "txt_norm2": "norm2_context", + "txt_mlp": "ff_context", + "self_attn_proj": "attn.to_out.0", + "modulation.linear": "norm.linear", + "pre_norm": "norm.norm", + "final_layer.norm_final": "norm_out.norm", + "final_layer.linear": "proj_out", + "fc1": "net.0.proj", + "fc2": "net.2", + "input_embedder": "proj_in", + } + + TRANSFORMER_SPECIAL_KEYS_REMAP = { + "txt_in": remap_txt_in_, + "img_attn_qkv": remap_img_attn_qkv_, + "txt_attn_qkv": remap_txt_attn_qkv_, + "single_blocks": remap_single_transformer_blocks_, + "final_layer.adaLN_modulation.1": remap_norm_scale_shift_, + } + + def update_state_dict_(state_dict, old_key, new_key): + state_dict[new_key] = state_dict.pop(old_key) + + for key in list(checkpoint.keys()): + new_key = key[:] + for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_(checkpoint, key, new_key) + + for key in list(checkpoint.keys()): + for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items(): + if special_key not in key: + continue + handler_fn_inplace(key, checkpoint) + + return checkpoint + + +def convert_auraflow_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): + converted_state_dict = {} + state_dict_keys = list(checkpoint.keys()) + + # Handle register tokens and positional embeddings + converted_state_dict["register_tokens"] = checkpoint.pop("register_tokens", None) + + # Handle time step projection + converted_state_dict["time_step_proj.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight", None) + converted_state_dict["time_step_proj.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias", None) + converted_state_dict["time_step_proj.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight", None) + converted_state_dict["time_step_proj.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias", None) + + # Handle context embedder + converted_state_dict["context_embedder.weight"] = checkpoint.pop("cond_seq_linear.weight", None) + + # Calculate the number of layers + def calculate_layers(keys, key_prefix): + layers = set() + for k in keys: + if key_prefix in k: + layer_num = int(k.split(".")[1]) # get the layer number + layers.add(layer_num) + return len(layers) + + mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers") + single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers") + + # MMDiT blocks + for i in range(mmdit_layers): + # Feed-forward + path_mapping = {"mlpX": "ff", "mlpC": "ff_context"} + weight_mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"} + for orig_k, diffuser_k in path_mapping.items(): + for k, v in weight_mapping.items(): + converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.{v}.weight"] = checkpoint.pop( + 
f"double_layers.{i}.{orig_k}.{k}.weight", None + ) + + # Norms + path_mapping = {"modX": "norm1", "modC": "norm1_context"} + for orig_k, diffuser_k in path_mapping.items(): + converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.linear.weight"] = checkpoint.pop( + f"double_layers.{i}.{orig_k}.1.weight", None + ) + + # Attentions + x_attn_mapping = {"w2q": "to_q", "w2k": "to_k", "w2v": "to_v", "w2o": "to_out.0"} + context_attn_mapping = {"w1q": "add_q_proj", "w1k": "add_k_proj", "w1v": "add_v_proj", "w1o": "to_add_out"} + for attn_mapping in [x_attn_mapping, context_attn_mapping]: + for k, v in attn_mapping.items(): + converted_state_dict[f"joint_transformer_blocks.{i}.attn.{v}.weight"] = checkpoint.pop( + f"double_layers.{i}.attn.{k}.weight", None + ) + + # Single-DiT blocks + for i in range(single_dit_layers): + # Feed-forward + mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"} + for k, v in mapping.items(): + converted_state_dict[f"single_transformer_blocks.{i}.ff.{v}.weight"] = checkpoint.pop( + f"single_layers.{i}.mlp.{k}.weight", None + ) + + # Norms + converted_state_dict[f"single_transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop( + f"single_layers.{i}.modCX.1.weight", None + ) + + # Attentions + x_attn_mapping = {"w1q": "to_q", "w1k": "to_k", "w1v": "to_v", "w1o": "to_out.0"} + for k, v in x_attn_mapping.items(): + converted_state_dict[f"single_transformer_blocks.{i}.attn.{v}.weight"] = checkpoint.pop( + f"single_layers.{i}.attn.{k}.weight", None + ) + # Final blocks + converted_state_dict["proj_out.weight"] = checkpoint.pop("final_linear.weight", None) + + # Handle the final norm layer + norm_weight = checkpoint.pop("modF.1.weight", None) + if norm_weight is not None: + converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(norm_weight, dim=None) + else: + converted_state_dict["norm_out.linear.weight"] = None + + converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("positional_encoding") + converted_state_dict["pos_embed.proj.weight"] = checkpoint.pop("init_x_linear.weight") + converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("init_x_linear.bias") + + return converted_state_dict + + +def convert_lumina2_to_diffusers(checkpoint, **kwargs): + converted_state_dict = {} + + # Original Lumina-Image-2 has an extra norm paramter that is unused + # We just remove it here + checkpoint.pop("norm_final.weight", None) + + # Comfy checkpoints add this prefix + keys = list(checkpoint.keys()) + for k in keys: + if "model.diffusion_model." 
in k: + checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) + + LUMINA_KEY_MAP = { + "cap_embedder": "time_caption_embed.caption_embedder", + "t_embedder.mlp.0": "time_caption_embed.timestep_embedder.linear_1", + "t_embedder.mlp.2": "time_caption_embed.timestep_embedder.linear_2", + "attention": "attn", + ".out.": ".to_out.0.", + "k_norm": "norm_k", + "q_norm": "norm_q", + "w1": "linear_1", + "w2": "linear_2", + "w3": "linear_3", + "adaLN_modulation.1": "norm1.linear", + } + ATTENTION_NORM_MAP = { + "attention_norm1": "norm1.norm", + "attention_norm2": "norm2", + } + CONTEXT_REFINER_MAP = { + "context_refiner.0.attention_norm1": "context_refiner.0.norm1", + "context_refiner.0.attention_norm2": "context_refiner.0.norm2", + "context_refiner.1.attention_norm1": "context_refiner.1.norm1", + "context_refiner.1.attention_norm2": "context_refiner.1.norm2", + } + FINAL_LAYER_MAP = { + "final_layer.adaLN_modulation.1": "norm_out.linear_1", + "final_layer.linear": "norm_out.linear_2", + } + + def convert_lumina_attn_to_diffusers(tensor, diffusers_key): + q_dim = 2304 + k_dim = v_dim = 768 + + to_q, to_k, to_v = torch.split(tensor, [q_dim, k_dim, v_dim], dim=0) + + return { + diffusers_key.replace("qkv", "to_q"): to_q, + diffusers_key.replace("qkv", "to_k"): to_k, + diffusers_key.replace("qkv", "to_v"): to_v, + } + + for key in keys: + diffusers_key = key + for k, v in CONTEXT_REFINER_MAP.items(): + diffusers_key = diffusers_key.replace(k, v) + for k, v in FINAL_LAYER_MAP.items(): + diffusers_key = diffusers_key.replace(k, v) + for k, v in ATTENTION_NORM_MAP.items(): + diffusers_key = diffusers_key.replace(k, v) + for k, v in LUMINA_KEY_MAP.items(): + diffusers_key = diffusers_key.replace(k, v) + + if "qkv" in diffusers_key: + converted_state_dict.update(convert_lumina_attn_to_diffusers(checkpoint.pop(key), diffusers_key)) + else: + converted_state_dict[diffusers_key] = checkpoint.pop(key) + + return converted_state_dict + + +def convert_sana_transformer_to_diffusers(checkpoint, **kwargs): + converted_state_dict = {} + keys = list(checkpoint.keys()) + for k in keys: + if "model.diffusion_model." in k: + checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) + + num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "blocks" in k))[-1] + 1 # noqa: C401 + + # Positional and patch embeddings. + checkpoint.pop("pos_embed") + converted_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight") + converted_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias") + + # Timestep embeddings. + converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = checkpoint.pop( + "t_embedder.mlp.0.weight" + ) + converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias") + converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = checkpoint.pop( + "t_embedder.mlp.2.weight" + ) + converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias") + converted_state_dict["time_embed.linear.weight"] = checkpoint.pop("t_block.1.weight") + converted_state_dict["time_embed.linear.bias"] = checkpoint.pop("t_block.1.bias") + + # Caption Projection. 
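+ # The learned null-caption embedding ("y_embedder.y_embedding"), used for
+ # classifier-free guidance dropout in the original codebase, has no counterpart in
+ # the diffusers module, so it is dropped rather than converted.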
+ checkpoint.pop("y_embedder.y_embedding") + converted_state_dict["caption_projection.linear_1.weight"] = checkpoint.pop("y_embedder.y_proj.fc1.weight") + converted_state_dict["caption_projection.linear_1.bias"] = checkpoint.pop("y_embedder.y_proj.fc1.bias") + converted_state_dict["caption_projection.linear_2.weight"] = checkpoint.pop("y_embedder.y_proj.fc2.weight") + converted_state_dict["caption_projection.linear_2.bias"] = checkpoint.pop("y_embedder.y_proj.fc2.bias") + converted_state_dict["caption_norm.weight"] = checkpoint.pop("attention_y_norm.weight") + + for i in range(num_layers): + converted_state_dict[f"transformer_blocks.{i}.scale_shift_table"] = checkpoint.pop( + f"blocks.{i}.scale_shift_table" + ) + + # Self-Attention + sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"blocks.{i}.attn.qkv.weight"), 3, dim=0) + converted_state_dict[f"transformer_blocks.{i}.attn1.to_q.weight"] = torch.cat([sample_q]) + converted_state_dict[f"transformer_blocks.{i}.attn1.to_k.weight"] = torch.cat([sample_k]) + converted_state_dict[f"transformer_blocks.{i}.attn1.to_v.weight"] = torch.cat([sample_v]) + + # Output Projections + converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.weight"] = checkpoint.pop( + f"blocks.{i}.attn.proj.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.bias"] = checkpoint.pop( + f"blocks.{i}.attn.proj.bias" + ) + + # Cross-Attention + converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = checkpoint.pop( + f"blocks.{i}.cross_attn.q_linear.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = checkpoint.pop( + f"blocks.{i}.cross_attn.q_linear.bias" + ) + + linear_sample_k, linear_sample_v = torch.chunk( + checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.weight"), 2, dim=0 + ) + linear_sample_k_bias, linear_sample_v_bias = torch.chunk( + checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.bias"), 2, dim=0 + ) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = linear_sample_k + converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = linear_sample_v + converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = linear_sample_k_bias + converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = linear_sample_v_bias + + # Output Projections + converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = checkpoint.pop( + f"blocks.{i}.cross_attn.proj.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = checkpoint.pop( + f"blocks.{i}.cross_attn.proj.bias" + ) + + # MLP + converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.weight"] = checkpoint.pop( + f"blocks.{i}.mlp.inverted_conv.conv.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.bias"] = checkpoint.pop( + f"blocks.{i}.mlp.inverted_conv.conv.bias" + ) + converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.weight"] = checkpoint.pop( + f"blocks.{i}.mlp.depth_conv.conv.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.bias"] = checkpoint.pop( + f"blocks.{i}.mlp.depth_conv.conv.bias" + ) + converted_state_dict[f"transformer_blocks.{i}.ff.conv_point.weight"] = checkpoint.pop( + f"blocks.{i}.mlp.point_conv.conv.weight" + ) + + # Final layer + converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight") + converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias") + converted_state_dict["scale_shift_table"] = checkpoint.pop("final_layer.scale_shift_table") + 
+ + return converted_state_dict + + +def convert_wan_transformer_to_diffusers(checkpoint, **kwargs): + converted_state_dict = {} + + keys = list(checkpoint.keys()) + for k in keys: + if "model.diffusion_model." in k: + checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) + + TRANSFORMER_KEYS_RENAME_DICT = { + "time_embedding.0": "condition_embedder.time_embedder.linear_1", + "time_embedding.2": "condition_embedder.time_embedder.linear_2", + "text_embedding.0": "condition_embedder.text_embedder.linear_1", + "text_embedding.2": "condition_embedder.text_embedder.linear_2", + "time_projection.1": "condition_embedder.time_proj", + "cross_attn": "attn2", + "self_attn": "attn1", + ".o.": ".to_out.0.", + ".q.": ".to_q.", + ".k.": ".to_k.", + ".v.": ".to_v.", + ".k_img.": ".add_k_proj.", + ".v_img.": ".add_v_proj.", + ".norm_k_img.": ".norm_added_k.", + "head.modulation": "scale_shift_table", + "head.head": "proj_out", + "modulation": "scale_shift_table", + "ffn.0": "ffn.net.0.proj", + "ffn.2": "ffn.net.2", + # Hack to swap the layer names + # The original model calls the norms in the following order: norm1, norm3, norm2 + # We convert it to: norm1, norm2, norm3 + "norm2": "norm__placeholder", + "norm3": "norm2", + "norm__placeholder": "norm3", + # For the I2V model + "img_emb.proj.0": "condition_embedder.image_embedder.norm1", + "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj", + "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2", + "img_emb.proj.4": "condition_embedder.image_embedder.norm2", + } + + for key in list(checkpoint.keys()): + new_key = key[:] + for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + + converted_state_dict[new_key] = checkpoint.pop(key) + + return converted_state_dict + + +def convert_wan_vae_to_diffusers(checkpoint, **kwargs): + converted_state_dict = {} + + # Create mappings for specific components + middle_key_mapping = { + # Encoder middle block + "encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma", + "encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias", + "encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight", + "encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma", + "encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias", + "encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight", + "encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma", + "encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias", + "encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight", + "encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma", + "encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias", + "encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight", + # Decoder middle block + "decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma", + "decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias", + "decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight", + "decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma", + "decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias", + "decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight", + "decoder.middle.2.residual.0.gamma":
"decoder.mid_block.resnets.1.norm1.gamma", + "decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias", + "decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight", + "decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma", + "decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias", + "decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight", + } + + # Create a mapping for attention blocks + attention_mapping = { + # Encoder middle attention + "encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma", + "encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight", + "encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias", + "encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight", + "encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias", + # Decoder middle attention + "decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma", + "decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight", + "decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias", + "decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight", + "decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias", + } + + # Create a mapping for the head components + head_mapping = { + # Encoder head + "encoder.head.0.gamma": "encoder.norm_out.gamma", + "encoder.head.2.bias": "encoder.conv_out.bias", + "encoder.head.2.weight": "encoder.conv_out.weight", + # Decoder head + "decoder.head.0.gamma": "decoder.norm_out.gamma", + "decoder.head.2.bias": "decoder.conv_out.bias", + "decoder.head.2.weight": "decoder.conv_out.weight", + } + + # Create a mapping for the quant components + quant_mapping = { + "conv1.weight": "quant_conv.weight", + "conv1.bias": "quant_conv.bias", + "conv2.weight": "post_quant_conv.weight", + "conv2.bias": "post_quant_conv.bias", + } + + # Process each key in the state dict + for key, value in checkpoint.items(): + # Handle middle block keys using the mapping + if key in middle_key_mapping: + new_key = middle_key_mapping[key] + converted_state_dict[new_key] = value + # Handle attention blocks using the mapping + elif key in attention_mapping: + new_key = attention_mapping[key] + converted_state_dict[new_key] = value + # Handle head keys using the mapping + elif key in head_mapping: + new_key = head_mapping[key] + converted_state_dict[new_key] = value + # Handle quant keys using the mapping + elif key in quant_mapping: + new_key = quant_mapping[key] + converted_state_dict[new_key] = value + # Handle encoder conv1 + elif key == "encoder.conv1.weight": + converted_state_dict["encoder.conv_in.weight"] = value + elif key == "encoder.conv1.bias": + converted_state_dict["encoder.conv_in.bias"] = value + # Handle decoder conv1 + elif key == "decoder.conv1.weight": + converted_state_dict["decoder.conv_in.weight"] = value + elif key == "decoder.conv1.bias": + converted_state_dict["decoder.conv_in.bias"] = value + # Handle encoder downsamples + elif key.startswith("encoder.downsamples."): + # Convert to down_blocks + new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.") + + # Convert residual block naming but keep the original structure + if ".residual.0.gamma" in new_key: + new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma") + elif ".residual.2.bias" in new_key: + new_key = new_key.replace(".residual.2.bias", 
".conv1.bias") + elif ".residual.2.weight" in new_key: + new_key = new_key.replace(".residual.2.weight", ".conv1.weight") + elif ".residual.3.gamma" in new_key: + new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma") + elif ".residual.6.bias" in new_key: + new_key = new_key.replace(".residual.6.bias", ".conv2.bias") + elif ".residual.6.weight" in new_key: + new_key = new_key.replace(".residual.6.weight", ".conv2.weight") + elif ".shortcut.bias" in new_key: + new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias") + elif ".shortcut.weight" in new_key: + new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight") + + converted_state_dict[new_key] = value + + # Handle decoder upsamples + elif key.startswith("decoder.upsamples."): + # Convert to up_blocks + parts = key.split(".") + block_idx = int(parts[2]) + + # Group residual blocks + if "residual" in key: + if block_idx in [0, 1, 2]: + new_block_idx = 0 + resnet_idx = block_idx + elif block_idx in [4, 5, 6]: + new_block_idx = 1 + resnet_idx = block_idx - 4 + elif block_idx in [8, 9, 10]: + new_block_idx = 2 + resnet_idx = block_idx - 8 + elif block_idx in [12, 13, 14]: + new_block_idx = 3 + resnet_idx = block_idx - 12 + else: + # Keep as is for other blocks + converted_state_dict[key] = value + continue + + # Convert residual block naming + if ".residual.0.gamma" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm1.gamma" + elif ".residual.2.bias" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.bias" + elif ".residual.2.weight" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.weight" + elif ".residual.3.gamma" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm2.gamma" + elif ".residual.6.bias" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.bias" + elif ".residual.6.weight" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.weight" + else: + new_key = key + + converted_state_dict[new_key] = value + + # Handle shortcut connections + elif ".shortcut." in key: + if block_idx == 4: + new_key = key.replace(".shortcut.", ".resnets.0.conv_shortcut.") + new_key = new_key.replace("decoder.upsamples.4", "decoder.up_blocks.1") + else: + new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.") + new_key = new_key.replace(".shortcut.", ".conv_shortcut.") + + converted_state_dict[new_key] = value + + # Handle upsamplers + elif ".resample." in key or ".time_conv." 
in key: + if block_idx == 3: + new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.0.upsamplers.0") + elif block_idx == 7: + new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.1.upsamplers.0") + elif block_idx == 11: + new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.2.upsamplers.0") + else: + new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.") + + converted_state_dict[new_key] = value + else: + new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.") + converted_state_dict[new_key] = value + else: + # Keep other keys unchanged + converted_state_dict[key] = value + + return converted_state_dict diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index a2f27b765a1b..4eeab1679e66 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -11,430 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import importlib -import inspect -import re -from contextlib import nullcontext -from typing import Optional -import torch -from huggingface_hub.utils import validate_hf_hub_args -from typing_extensions import Self -from .. import __version__ -from ..quantizers import DiffusersAutoQuantizer -from ..utils import deprecate, is_accelerate_available, logging -from .single_file_utils import ( - SingleFileComponentError, - convert_animatediff_checkpoint_to_diffusers, - convert_auraflow_transformer_checkpoint_to_diffusers, - convert_autoencoder_dc_checkpoint_to_diffusers, - convert_controlnet_checkpoint, - convert_flux_transformer_checkpoint_to_diffusers, - convert_hunyuan_video_transformer_to_diffusers, - convert_ldm_unet_checkpoint, - convert_ldm_vae_checkpoint, - convert_ltx_transformer_checkpoint_to_diffusers, - convert_ltx_vae_checkpoint_to_diffusers, - convert_lumina2_to_diffusers, - convert_mochi_transformer_checkpoint_to_diffusers, - convert_sana_transformer_to_diffusers, - convert_sd3_transformer_checkpoint_to_diffusers, - convert_stable_cascade_unet_single_file_to_diffusers, - convert_wan_transformer_to_diffusers, - convert_wan_vae_to_diffusers, - create_controlnet_diffusers_config_from_ldm, - create_unet_diffusers_config_from_ldm, - create_vae_diffusers_config_from_ldm, - fetch_diffusers_config, - fetch_original_config, - load_single_file_checkpoint, +from ..utils import deprecate +from .single_file.single_file_model import ( + SINGLE_FILE_LOADABLE_CLASSES, # noqa: F401 + FromOriginalModelMixin, ) -logger = logging.get_logger(__name__) - - -if is_accelerate_available(): - from accelerate import dispatch_model, init_empty_weights - - from ..models.modeling_utils import load_model_dict_into_meta - - -SINGLE_FILE_LOADABLE_CLASSES = { - "StableCascadeUNet": { - "checkpoint_mapping_fn": convert_stable_cascade_unet_single_file_to_diffusers, - }, - "UNet2DConditionModel": { - "checkpoint_mapping_fn": convert_ldm_unet_checkpoint, - "config_mapping_fn": create_unet_diffusers_config_from_ldm, - "default_subfolder": "unet", - "legacy_kwargs": { - "num_in_channels": "in_channels", # Legacy kwargs supported by `from_single_file` mapped to new args - }, - }, - "AutoencoderKL": { - "checkpoint_mapping_fn": convert_ldm_vae_checkpoint, - "config_mapping_fn": create_vae_diffusers_config_from_ldm, - "default_subfolder": "vae", - }, - "ControlNetModel": { - "checkpoint_mapping_fn": 
convert_controlnet_checkpoint, - "config_mapping_fn": create_controlnet_diffusers_config_from_ldm, - }, - "SD3Transformer2DModel": { - "checkpoint_mapping_fn": convert_sd3_transformer_checkpoint_to_diffusers, - "default_subfolder": "transformer", - }, - "MotionAdapter": { - "checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers, - }, - "SparseControlNetModel": { - "checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers, - }, - "FluxTransformer2DModel": { - "checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers, - "default_subfolder": "transformer", - }, - "LTXVideoTransformer3DModel": { - "checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers, - "default_subfolder": "transformer", - }, - "AutoencoderKLLTXVideo": { - "checkpoint_mapping_fn": convert_ltx_vae_checkpoint_to_diffusers, - "default_subfolder": "vae", - }, - "AutoencoderDC": {"checkpoint_mapping_fn": convert_autoencoder_dc_checkpoint_to_diffusers}, - "MochiTransformer3DModel": { - "checkpoint_mapping_fn": convert_mochi_transformer_checkpoint_to_diffusers, - "default_subfolder": "transformer", - }, - "HunyuanVideoTransformer3DModel": { - "checkpoint_mapping_fn": convert_hunyuan_video_transformer_to_diffusers, - "default_subfolder": "transformer", - }, - "AuraFlowTransformer2DModel": { - "checkpoint_mapping_fn": convert_auraflow_transformer_checkpoint_to_diffusers, - "default_subfolder": "transformer", - }, - "Lumina2Transformer2DModel": { - "checkpoint_mapping_fn": convert_lumina2_to_diffusers, - "default_subfolder": "transformer", - }, - "SanaTransformer2DModel": { - "checkpoint_mapping_fn": convert_sana_transformer_to_diffusers, - "default_subfolder": "transformer", - }, - "WanTransformer3DModel": { - "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers, - "default_subfolder": "transformer", - }, - "AutoencoderKLWan": { - "checkpoint_mapping_fn": convert_wan_vae_to_diffusers, - "default_subfolder": "vae", - }, -} - - -def _get_single_file_loadable_mapping_class(cls): - diffusers_module = importlib.import_module(__name__.split(".")[0]) - for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES: - loadable_class = getattr(diffusers_module, loadable_class_str) - - if issubclass(cls, loadable_class): - return loadable_class_str - - return None - - -def _get_mapping_function_kwargs(mapping_fn, **kwargs): - parameters = inspect.signature(mapping_fn).parameters - - mapping_kwargs = {} - for parameter in parameters: - if parameter in kwargs: - mapping_kwargs[parameter] = kwargs[parameter] - - return mapping_kwargs - - -class FromOriginalModelMixin: - """ - Load pretrained weights saved in the `.ckpt` or `.safetensors` format into a model. - """ - - @classmethod - @validate_hf_hub_args - def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = None, **kwargs) -> Self: - r""" - Instantiate a model from pretrained weights saved in the original `.ckpt` or `.safetensors` format. The model - is set in evaluation mode (`model.eval()`) by default. - - Parameters: - pretrained_model_link_or_path_or_dict (`str`, *optional*): - Can be either: - - A link to the `.safetensors` or `.ckpt` file (for example - `"https://huggingface.co//blob/main/.safetensors"`) on the Hub. - - A path to a local *file* containing the weights of the component model. - - A state dict containing the component model weights. - config (`str`, *optional*): - - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline hosted - on the Hub. 
- - A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline component - configs in Diffusers format. - subfolder (`str`, *optional*, defaults to `""`): - The subfolder location of a model file within a larger model repository on the Hub or locally. - original_config (`str`, *optional*): - Dict or path to a yaml file containing the configuration for the model in its original format. - If a dict is provided, it will be used to initialize the model configuration. - torch_dtype (`str` or `torch.dtype`, *optional*): - Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the - dtype is automatically derived from the model's weights. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory where a downloaded pretrained model configuration is cached if the standard cache - is not used. - - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - local_files_only (`bool`, *optional*, defaults to `False`): - Whether to only load local model weights and configuration files or not. If set to True, the model - won't be downloaded from the Hub. - token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from - `diffusers-cli login` (stored in `~/.huggingface`) is used. - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier - allowed by Git. - disable_mmap ('bool', *optional*, defaults to 'False'): - Whether to disable mmap when loading a Safetensors model. This option can perform better when the model - is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well. - kwargs (remaining dictionary of keyword arguments, *optional*): - Can be used to overwrite load and saveable variables (for example the pipeline components of the - specific pipeline class). The overwritten components are directly passed to the pipelines `__init__` - method. See example below for more information. 
- - ```py - >>> from diffusers import StableCascadeUNet - - >>> ckpt_path = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b_lite.safetensors" - >>> model = StableCascadeUNet.from_single_file(ckpt_path) - ``` - """ - - mapping_class_name = _get_single_file_loadable_mapping_class(cls) - # if class_name not in SINGLE_FILE_LOADABLE_CLASSES: - if mapping_class_name is None: - raise ValueError( - f"FromOriginalModelMixin is currently only compatible with {', '.join(SINGLE_FILE_LOADABLE_CLASSES.keys())}" - ) - - pretrained_model_link_or_path = kwargs.get("pretrained_model_link_or_path", None) - if pretrained_model_link_or_path is not None: - deprecation_message = ( - "Please use `pretrained_model_link_or_path_or_dict` argument instead for model classes" - ) - deprecate("pretrained_model_link_or_path", "1.0.0", deprecation_message) - pretrained_model_link_or_path_or_dict = pretrained_model_link_or_path - - config = kwargs.pop("config", None) - original_config = kwargs.pop("original_config", None) - - if config is not None and original_config is not None: - raise ValueError( - "`from_single_file` cannot accept both `config` and `original_config` arguments. Please provide only one of these arguments" - ) - - force_download = kwargs.pop("force_download", False) - proxies = kwargs.pop("proxies", None) - token = kwargs.pop("token", None) - cache_dir = kwargs.pop("cache_dir", None) - local_files_only = kwargs.pop("local_files_only", None) - subfolder = kwargs.pop("subfolder", None) - revision = kwargs.pop("revision", None) - config_revision = kwargs.pop("config_revision", None) - torch_dtype = kwargs.pop("torch_dtype", None) - quantization_config = kwargs.pop("quantization_config", None) - device = kwargs.pop("device", None) - disable_mmap = kwargs.pop("disable_mmap", False) - - user_agent = {"diffusers": __version__, "file_type": "single_file", "framework": "pytorch"} - # In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry` - if quantization_config is not None: - user_agent["quant"] = quantization_config.quant_method.value - - if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): - torch_dtype = torch.float32 - logger.warning( - f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`." 
- ) - - if isinstance(pretrained_model_link_or_path_or_dict, dict): - checkpoint = pretrained_model_link_or_path_or_dict - else: - checkpoint = load_single_file_checkpoint( - pretrained_model_link_or_path_or_dict, - force_download=force_download, - proxies=proxies, - token=token, - cache_dir=cache_dir, - local_files_only=local_files_only, - revision=revision, - disable_mmap=disable_mmap, - user_agent=user_agent, - ) - if quantization_config is not None: - hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config) - hf_quantizer.validate_environment() - torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype) - - else: - hf_quantizer = None - - mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name] - - checkpoint_mapping_fn = mapping_functions["checkpoint_mapping_fn"] - if original_config is not None: - if "config_mapping_fn" in mapping_functions: - config_mapping_fn = mapping_functions["config_mapping_fn"] - else: - config_mapping_fn = None - - if config_mapping_fn is None: - raise ValueError( - ( - f"`original_config` has been provided for {mapping_class_name} but no mapping function" - "was found to convert the original config to a Diffusers config in" - "`diffusers.loaders.single_file_utils`" - ) - ) - - if isinstance(original_config, str): - # If original_config is a URL or filepath fetch the original_config dict - original_config = fetch_original_config(original_config, local_files_only=local_files_only) - - config_mapping_kwargs = _get_mapping_function_kwargs(config_mapping_fn, **kwargs) - diffusers_model_config = config_mapping_fn( - original_config=original_config, checkpoint=checkpoint, **config_mapping_kwargs - ) - else: - if config is not None: - if isinstance(config, str): - default_pretrained_model_config_name = config - else: - raise ValueError( - ( - "Invalid `config` argument. Please provide a string representing a repo id" - "or path to a local Diffusers model repo." - ) - ) - - else: - config = fetch_diffusers_config(checkpoint) - default_pretrained_model_config_name = config["pretrained_model_name_or_path"] - - if "default_subfolder" in mapping_functions: - subfolder = mapping_functions["default_subfolder"] - - subfolder = subfolder or config.pop( - "subfolder", None - ) # some configs contain a subfolder key, e.g. StableCascadeUNet - - diffusers_model_config = cls.load_config( - pretrained_model_name_or_path=default_pretrained_model_config_name, - subfolder=subfolder, - local_files_only=local_files_only, - token=token, - revision=config_revision, - ) - expected_kwargs, optional_kwargs = cls._get_signature_keys(cls) - - # Map legacy kwargs to new kwargs - if "legacy_kwargs" in mapping_functions: - legacy_kwargs = mapping_functions["legacy_kwargs"] - for legacy_key, new_key in legacy_kwargs.items(): - if legacy_key in kwargs: - kwargs[new_key] = kwargs.pop(legacy_key) - - model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs} - diffusers_model_config.update(model_kwargs) - - checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs) - diffusers_format_checkpoint = checkpoint_mapping_fn( - config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs - ) - if not diffusers_format_checkpoint: - raise SingleFileComponentError( - f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint." 
- ) - - ctx = init_empty_weights if is_accelerate_available() else nullcontext - with ctx(): - model = cls.from_config(diffusers_model_config) - - # Check if `_keep_in_fp32_modules` is not None - use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and ( - (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules") - ) - if use_keep_in_fp32_modules: - keep_in_fp32_modules = cls._keep_in_fp32_modules - if not isinstance(keep_in_fp32_modules, list): - keep_in_fp32_modules = [keep_in_fp32_modules] - - else: - keep_in_fp32_modules = [] - - if hf_quantizer is not None: - hf_quantizer.preprocess_model( - model=model, - device_map=None, - state_dict=diffusers_format_checkpoint, - keep_in_fp32_modules=keep_in_fp32_modules, - ) - - device_map = None - if is_accelerate_available(): - param_device = torch.device(device) if device else torch.device("cpu") - empty_state_dict = model.state_dict() - unexpected_keys = [ - param_name for param_name in diffusers_format_checkpoint if param_name not in empty_state_dict - ] - device_map = {"": param_device} - load_model_dict_into_meta( - model, - diffusers_format_checkpoint, - dtype=torch_dtype, - device_map=device_map, - hf_quantizer=hf_quantizer, - keep_in_fp32_modules=keep_in_fp32_modules, - unexpected_keys=unexpected_keys, - ) - else: - _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False) - - if model._keys_to_ignore_on_load_unexpected is not None: - for pat in model._keys_to_ignore_on_load_unexpected: - unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] - - if len(unexpected_keys) > 0: - logger.warning( - f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" - ) - - if hf_quantizer is not None: - hf_quantizer.postprocess_model(model) - model.hf_quantizer = hf_quantizer - - if torch_dtype is not None and hf_quantizer is None: - model.to(torch_dtype) - - model.eval() - - if device_map is not None: - device_map_kwargs = {"device_map": device_map} - dispatch_model(model, **device_map_kwargs) - - return model +class FromOriginalModelMixin(FromOriginalModelMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `FromOriginalModelMixin` from diffusers.loaders.single_file_model has been deprecated. Please use `from diffusers.loaders.single_file.single_file_model import FromOriginalModelMixin` instead." + deprecate("diffusers.loaders.single_file_model.FromOriginalModelMixin", "0.36", deprecation_message) + super().__init__(*args, **kwargs) diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index b55b1b55206e..13bb67bfaad0 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -12,388 +12,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
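The `FromOriginalModelMixin` shim above, and the function shims that follow in `single_file_utils.py`, all use the same back-compat pattern: import the relocated symbol from its new `single_file` submodule, then re-declare it under the old name so that the legacy import path keeps working but warns on use. Note that `class FromOriginalModelMixin(FromOriginalModelMixin)` is valid Python because the base-class expression is evaluated before the `class` statement rebinds the name. Below is a minimal, self-contained sketch of the pattern; the names are illustrative only, and plain `warnings.warn` stands in for diffusers' internal `deprecate` helper:

```py
import warnings

# Stand-in for the relocated implementation, i.e. what would now live under
# diffusers/loaders/single_file/single_file_utils.py.
class SingleFileComponentError(Exception):
    pass

_relocated_error = SingleFileComponentError  # keep a handle to the "new" class

# The legacy module re-declares the same name as a subclass of the imported one,
# so `except`/`isinstance` checks written against either class still match,
# while raising or constructing it through the old path emits a warning.
class SingleFileComponentError(_relocated_error):
    def __init__(self, *args, **kwargs):
        warnings.warn(
            "Importing `SingleFileComponentError` from the legacy module is "
            "deprecated; import it from its new location instead.",
            FutureWarning,
            stacklevel=2,
        )
        super().__init__(*args, **kwargs)

try:
    raise SingleFileComponentError("missing weights")  # warns, then raises as usual
except _relocated_error:
    print("handlers typed against the new class still catch the legacy subclass")
```

For plain functions the PR applies the same trick with a local `from .single_file.single_file_utils import <name>` inside each wrapper; the local import shadows the wrapper's own name, so the trailing `return <name>(...)` dispatches to the relocated implementation rather than recursing.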
-"""Conversion script for the Stable Diffusion checkpoints.""" - -import copy -import os -import re -from contextlib import nullcontext -from io import BytesIO -from urllib.parse import urlparse - -import requests -import torch -import yaml - -from ..models.modeling_utils import load_state_dict -from ..schedulers import ( - DDIMScheduler, - DPMSolverMultistepScheduler, - EDMDPMSolverMultistepScheduler, - EulerAncestralDiscreteScheduler, - EulerDiscreteScheduler, - HeunDiscreteScheduler, - LMSDiscreteScheduler, - PNDMScheduler, -) -from ..utils import ( - SAFETENSORS_WEIGHTS_NAME, - WEIGHTS_NAME, - deprecate, - is_accelerate_available, - is_transformers_available, - logging, -) -from ..utils.constants import DIFFUSERS_REQUEST_TIMEOUT -from ..utils.hub_utils import _get_model_file - - -if is_transformers_available(): - from transformers import AutoImageProcessor - -if is_accelerate_available(): - from accelerate import init_empty_weights - - from ..models.modeling_utils import load_model_dict_into_meta - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -CHECKPOINT_KEY_NAMES = { - "v2": "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight", - "xl_base": "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias", - "xl_refiner": "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias", - "upscale": "model.diffusion_model.input_blocks.10.0.skip_connection.bias", - "controlnet": [ - "control_model.time_embed.0.weight", - "controlnet_cond_embedding.conv_in.weight", - ], - # TODO: find non-Diffusers keys for controlnet_xl - "controlnet_xl": "add_embedding.linear_1.weight", - "controlnet_xl_large": "down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight", - "controlnet_xl_mid": "down_blocks.1.attentions.0.norm.weight", - "playground-v2-5": "edm_mean", - "inpainting": "model.diffusion_model.input_blocks.0.0.weight", - "clip": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight", - "clip_sdxl": "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight", - "clip_sd3": "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight", - "open_clip": "cond_stage_model.model.token_embedding.weight", - "open_clip_sdxl": "conditioner.embedders.1.model.positional_embedding", - "open_clip_sdxl_refiner": "conditioner.embedders.0.model.text_projection", - "open_clip_sd3": "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight", - "stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight", - "stable_cascade_stage_c": "clip_txt_mapper.weight", - "sd3": [ - "joint_blocks.0.context_block.adaLN_modulation.1.bias", - "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias", - ], - "sd35_large": [ - "joint_blocks.37.x_block.mlp.fc1.weight", - "model.diffusion_model.joint_blocks.37.x_block.mlp.fc1.weight", - ], - "animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe", - "animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias", - "animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight", - "animatediff_scribble": "controlnet_cond_embedding.conv_in.weight", - "animatediff_rgb": "controlnet_cond_embedding.weight", - "auraflow": [ - "double_layers.0.attn.w2q.weight", - "double_layers.0.attn.w1q.weight", - "cond_seq_linear.weight", - "t_embedder.mlp.0.weight", - ], - "flux": [ - 
"double_blocks.0.img_attn.norm.key_norm.scale", - "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale", - ], - "ltx-video": [ - "model.diffusion_model.patchify_proj.weight", - "model.diffusion_model.transformer_blocks.27.scale_shift_table", - "patchify_proj.weight", - "transformer_blocks.27.scale_shift_table", - "vae.per_channel_statistics.mean-of-means", - ], - "autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias", - "autoencoder-dc-sana": "encoder.project_in.conv.bias", - "mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"], - "hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias", - "instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight", - "lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"], - "sana": [ - "blocks.0.cross_attn.q_linear.weight", - "blocks.0.cross_attn.q_linear.bias", - "blocks.0.cross_attn.kv_linear.weight", - "blocks.0.cross_attn.kv_linear.bias", - ], - "wan": ["model.diffusion_model.head.modulation", "head.modulation"], - "wan_vae": "decoder.middle.0.residual.0.gamma", -} - -DIFFUSERS_DEFAULT_PIPELINE_PATHS = { - "xl_base": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-base-1.0"}, - "xl_refiner": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-refiner-1.0"}, - "xl_inpaint": {"pretrained_model_name_or_path": "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"}, - "playground-v2-5": {"pretrained_model_name_or_path": "playgroundai/playground-v2.5-1024px-aesthetic"}, - "upscale": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-x4-upscaler"}, - "inpainting": {"pretrained_model_name_or_path": "stable-diffusion-v1-5/stable-diffusion-inpainting"}, - "inpainting_v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-inpainting"}, - "controlnet": {"pretrained_model_name_or_path": "lllyasviel/control_v11p_sd15_canny"}, - "controlnet_xl_large": {"pretrained_model_name_or_path": "diffusers/controlnet-canny-sdxl-1.0"}, - "controlnet_xl_mid": {"pretrained_model_name_or_path": "diffusers/controlnet-canny-sdxl-1.0-mid"}, - "controlnet_xl_small": {"pretrained_model_name_or_path": "diffusers/controlnet-canny-sdxl-1.0-small"}, - "v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-1"}, - "v1": {"pretrained_model_name_or_path": "stable-diffusion-v1-5/stable-diffusion-v1-5"}, - "stable_cascade_stage_b": {"pretrained_model_name_or_path": "stabilityai/stable-cascade", "subfolder": "decoder"}, - "stable_cascade_stage_b_lite": { - "pretrained_model_name_or_path": "stabilityai/stable-cascade", - "subfolder": "decoder_lite", - }, - "stable_cascade_stage_c": { - "pretrained_model_name_or_path": "stabilityai/stable-cascade-prior", - "subfolder": "prior", - }, - "stable_cascade_stage_c_lite": { - "pretrained_model_name_or_path": "stabilityai/stable-cascade-prior", - "subfolder": "prior_lite", - }, - "sd3": { - "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3-medium-diffusers", - }, - "sd35_large": { - "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3.5-large", - }, - "sd35_medium": { - "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3.5-medium", - }, - "animatediff_v1": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5"}, - "animatediff_v2": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-2"}, - "animatediff_v3": {"pretrained_model_name_or_path": 
"guoyww/animatediff-motion-adapter-v1-5-3"}, - "animatediff_sdxl_beta": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-sdxl-beta"}, - "animatediff_scribble": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-scribble"}, - "animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"}, - "auraflow": {"pretrained_model_name_or_path": "fal/AuraFlow-v0.3"}, - "flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"}, - "flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"}, - "flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"}, - "flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"}, - "ltx-video": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.0"}, - "ltx-video-0.9.1": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.1"}, - "ltx-video-0.9.5": {"pretrained_model_name_or_path": "Lightricks/LTX-Video-0.9.5"}, - "autoencoder-dc-f128c512": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"}, - "autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"}, - "autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"}, - "autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"}, - "mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"}, - "hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"}, - "instruct-pix2pix": {"pretrained_model_name_or_path": "timbrooks/instruct-pix2pix"}, - "lumina2": {"pretrained_model_name_or_path": "Alpha-VLLM/Lumina-Image-2.0"}, - "sana": {"pretrained_model_name_or_path": "Efficient-Large-Model/Sana_1600M_1024px_diffusers"}, - "wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"}, - "wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"}, - "wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"}, -} - -# Use to configure model sample size when original config is provided -DIFFUSERS_TO_LDM_DEFAULT_IMAGE_SIZE_MAP = { - "xl_base": 1024, - "xl_refiner": 1024, - "xl_inpaint": 1024, - "playground-v2-5": 1024, - "upscale": 512, - "inpainting": 512, - "inpainting_v2": 512, - "controlnet": 512, - "instruct-pix2pix": 512, - "v2": 768, - "v1": 512, -} - - -DIFFUSERS_TO_LDM_MAPPING = { - "unet": { - "layers": { - "time_embedding.linear_1.weight": "time_embed.0.weight", - "time_embedding.linear_1.bias": "time_embed.0.bias", - "time_embedding.linear_2.weight": "time_embed.2.weight", - "time_embedding.linear_2.bias": "time_embed.2.bias", - "conv_in.weight": "input_blocks.0.0.weight", - "conv_in.bias": "input_blocks.0.0.bias", - "conv_norm_out.weight": "out.0.weight", - "conv_norm_out.bias": "out.0.bias", - "conv_out.weight": "out.2.weight", - "conv_out.bias": "out.2.bias", - }, - "class_embed_type": { - "class_embedding.linear_1.weight": "label_emb.0.0.weight", - "class_embedding.linear_1.bias": "label_emb.0.0.bias", - "class_embedding.linear_2.weight": "label_emb.0.2.weight", - "class_embedding.linear_2.bias": "label_emb.0.2.bias", - }, - "addition_embed_type": { - "add_embedding.linear_1.weight": "label_emb.0.0.weight", - "add_embedding.linear_1.bias": "label_emb.0.0.bias", - "add_embedding.linear_2.weight": "label_emb.0.2.weight", - "add_embedding.linear_2.bias": 
"label_emb.0.2.bias", - }, - }, - "controlnet": { - "layers": { - "time_embedding.linear_1.weight": "time_embed.0.weight", - "time_embedding.linear_1.bias": "time_embed.0.bias", - "time_embedding.linear_2.weight": "time_embed.2.weight", - "time_embedding.linear_2.bias": "time_embed.2.bias", - "conv_in.weight": "input_blocks.0.0.weight", - "conv_in.bias": "input_blocks.0.0.bias", - "controlnet_cond_embedding.conv_in.weight": "input_hint_block.0.weight", - "controlnet_cond_embedding.conv_in.bias": "input_hint_block.0.bias", - "controlnet_cond_embedding.conv_out.weight": "input_hint_block.14.weight", - "controlnet_cond_embedding.conv_out.bias": "input_hint_block.14.bias", - }, - "class_embed_type": { - "class_embedding.linear_1.weight": "label_emb.0.0.weight", - "class_embedding.linear_1.bias": "label_emb.0.0.bias", - "class_embedding.linear_2.weight": "label_emb.0.2.weight", - "class_embedding.linear_2.bias": "label_emb.0.2.bias", - }, - "addition_embed_type": { - "add_embedding.linear_1.weight": "label_emb.0.0.weight", - "add_embedding.linear_1.bias": "label_emb.0.0.bias", - "add_embedding.linear_2.weight": "label_emb.0.2.weight", - "add_embedding.linear_2.bias": "label_emb.0.2.bias", - }, - }, - "vae": { - "encoder.conv_in.weight": "encoder.conv_in.weight", - "encoder.conv_in.bias": "encoder.conv_in.bias", - "encoder.conv_out.weight": "encoder.conv_out.weight", - "encoder.conv_out.bias": "encoder.conv_out.bias", - "encoder.conv_norm_out.weight": "encoder.norm_out.weight", - "encoder.conv_norm_out.bias": "encoder.norm_out.bias", - "decoder.conv_in.weight": "decoder.conv_in.weight", - "decoder.conv_in.bias": "decoder.conv_in.bias", - "decoder.conv_out.weight": "decoder.conv_out.weight", - "decoder.conv_out.bias": "decoder.conv_out.bias", - "decoder.conv_norm_out.weight": "decoder.norm_out.weight", - "decoder.conv_norm_out.bias": "decoder.norm_out.bias", - "quant_conv.weight": "quant_conv.weight", - "quant_conv.bias": "quant_conv.bias", - "post_quant_conv.weight": "post_quant_conv.weight", - "post_quant_conv.bias": "post_quant_conv.bias", - }, - "openclip": { - "layers": { - "text_model.embeddings.position_embedding.weight": "positional_embedding", - "text_model.embeddings.token_embedding.weight": "token_embedding.weight", - "text_model.final_layer_norm.weight": "ln_final.weight", - "text_model.final_layer_norm.bias": "ln_final.bias", - "text_projection.weight": "text_projection", - }, - "transformer": { - "text_model.encoder.layers.": "resblocks.", - "layer_norm1": "ln_1", - "layer_norm2": "ln_2", - ".fc1.": ".c_fc.", - ".fc2.": ".c_proj.", - ".self_attn": ".attn", - "transformer.text_model.final_layer_norm.": "ln_final.", - "transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight", - "transformer.text_model.embeddings.position_embedding.weight": "positional_embedding", - }, - }, -} - -SD_2_TEXT_ENCODER_KEYS_TO_IGNORE = [ - "cond_stage_model.model.transformer.resblocks.23.attn.in_proj_bias", - "cond_stage_model.model.transformer.resblocks.23.attn.in_proj_weight", - "cond_stage_model.model.transformer.resblocks.23.attn.out_proj.bias", - "cond_stage_model.model.transformer.resblocks.23.attn.out_proj.weight", - "cond_stage_model.model.transformer.resblocks.23.ln_1.bias", - "cond_stage_model.model.transformer.resblocks.23.ln_1.weight", - "cond_stage_model.model.transformer.resblocks.23.ln_2.bias", - "cond_stage_model.model.transformer.resblocks.23.ln_2.weight", - "cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.bias", - 
"cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.weight", - "cond_stage_model.model.transformer.resblocks.23.mlp.c_proj.bias", - "cond_stage_model.model.transformer.resblocks.23.mlp.c_proj.weight", - "cond_stage_model.model.text_projection", -] - -# To support legacy scheduler_type argument -SCHEDULER_DEFAULT_CONFIG = { - "beta_schedule": "scaled_linear", - "beta_start": 0.00085, - "beta_end": 0.012, - "interpolation_type": "linear", - "num_train_timesteps": 1000, - "prediction_type": "epsilon", - "sample_max_value": 1.0, - "set_alpha_to_one": False, - "skip_prk_steps": True, - "steps_offset": 1, - "timestep_spacing": "leading", -} - -LDM_VAE_KEYS = ["first_stage_model.", "vae."] -LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215 -PLAYGROUND_VAE_SCALING_FACTOR = 0.5 -LDM_UNET_KEY = "model.diffusion_model." -LDM_CONTROLNET_KEY = "control_model." -LDM_CLIP_PREFIX_TO_REMOVE = [ - "cond_stage_model.transformer.", - "conditioner.embedders.0.transformer.", -] -LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024 -SCHEDULER_LEGACY_KWARGS = ["prediction_type", "scheduler_type"] - -VALID_URL_PREFIXES = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"] - - -class SingleFileComponentError(Exception): - def __init__(self, message=None): - self.message = message - super().__init__(self.message) -def is_valid_url(url): - result = urlparse(url) - if result.scheme and result.netloc: - return True - - return False - - -def _extract_repo_id_and_weights_name(pretrained_model_name_or_path): - if not is_valid_url(pretrained_model_name_or_path): - raise ValueError("Invalid `pretrained_model_name_or_path` provided. Please set it to a valid URL.") - - pattern = r"([^/]+)/([^/]+)/(?:blob/main/)?(.+)" - weights_name = None - repo_id = (None,) - for prefix in VALID_URL_PREFIXES: - pretrained_model_name_or_path = pretrained_model_name_or_path.replace(prefix, "") - match = re.match(pattern, pretrained_model_name_or_path) - if not match: - logger.warning("Unable to identify the repo_id and weights_name from the provided URL.") - return repo_id, weights_name - - repo_id = f"{match.group(1)}/{match.group(2)}" - weights_name = match.group(3) - - return repo_id, weights_name +from ..utils import deprecate +from .single_file.single_file_utils import SingleFileComponentError -def _is_model_weights_in_cached_folder(cached_folder, name): - pretrained_model_name_or_path = os.path.join(cached_folder, name) - weights_exist = False +class SingleFileComponentError(SingleFileComponentError): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `SingleFileComponentError` from diffusers.loaders.single_file_utils has been deprecated. Please use `from diffusers.loaders.single_file.single_files_utils import SingleFileComponentError` instead." + deprecate("diffusers.loaders.single_file_utils. ", "0.36", deprecation_message) + super().__init__(*args, **kwargs) - for weights_name in [WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME]: - if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)): - weights_exist = True - return weights_exist +def is_valid_url(url): + from .single_file.single_file_utils import is_valid_url + deprecation_message = "Importing `is_valid_url()` from diffusers.loaders.single_file_utils has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import is_valid_url` instead." 
+ deprecate("diffusers.loaders.single_file_utils.is_valid_url", "0.36", deprecation_message) -def _is_legacy_scheduler_kwargs(kwargs): - return any(k in SCHEDULER_LEGACY_KWARGS for k in kwargs.keys()) + return is_valid_url(url) def load_single_file_checkpoint( @@ -407,825 +45,228 @@ def load_single_file_checkpoint( disable_mmap=False, user_agent=None, ): - if user_agent is None: - user_agent = {"file_type": "single_file", "framework": "pytorch"} - - if os.path.isfile(pretrained_model_link_or_path): - pretrained_model_link_or_path = pretrained_model_link_or_path - - else: - repo_id, weights_name = _extract_repo_id_and_weights_name(pretrained_model_link_or_path) - pretrained_model_link_or_path = _get_model_file( - repo_id, - weights_name=weights_name, - force_download=force_download, - cache_dir=cache_dir, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - user_agent=user_agent, - ) - - checkpoint = load_state_dict(pretrained_model_link_or_path, disable_mmap=disable_mmap) - - # some checkpoints contain the model state dict under a "state_dict" key - while "state_dict" in checkpoint: - checkpoint = checkpoint["state_dict"] - - return checkpoint + from .single_file.single_file_utils import load_single_file_checkpoint + + deprecation_message = "Importing `load_single_file_checkpoint()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import load_single_file_checkpoint` instead." + deprecate("diffusers.loaders.single_file_utils.load_single_file_checkpoint", "0.36", deprecation_message) + + return load_single_file_checkpoint( + pretrained_model_link_or_path, + force_download, + proxies, + token, + cache_dir, + local_files_only, + revision, + disable_mmap, + user_agent, + ) def fetch_original_config(original_config_file, local_files_only=False): - if os.path.isfile(original_config_file): - with open(original_config_file, "r") as fp: - original_config_file = fp.read() + from .single_file.single_file_utils import fetch_original_config - elif is_valid_url(original_config_file): - if local_files_only: - raise ValueError( - "`local_files_only` is set to True, but a URL was provided as `original_config_file`. " - "Please provide a valid local file path." - ) + deprecation_message = "Importing `fetch_original_config()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import fetch_original_config` instead." + deprecate("diffusers.loaders.single_file_utils.fetch_original_config", "0.36", deprecation_message) - original_config_file = BytesIO(requests.get(original_config_file, timeout=DIFFUSERS_REQUEST_TIMEOUT).content) - - else: - raise ValueError("Invalid `original_config_file` provided. Please set it to a valid file path or URL.") - - original_config = yaml.safe_load(original_config_file) - - return original_config + return fetch_original_config(original_config_file, local_files_only) def is_clip_model(checkpoint): - if CHECKPOINT_KEY_NAMES["clip"] in checkpoint: - return True + from .single_file.single_file_utils import is_clip_model + + deprecation_message = "Importing `is_clip_model()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import is_clip_model` instead." 
+ deprecate("diffusers.loaders.single_file_utils.is_clip_model", "0.36", deprecation_message) - return False + return is_clip_model(checkpoint) def is_clip_sdxl_model(checkpoint): - if CHECKPOINT_KEY_NAMES["clip_sdxl"] in checkpoint: - return True + from .single_file.single_file_utils import is_clip_sdxl_model - return False + deprecation_message = "Importing `is_clip_sdxl_model()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import is_clip_sdxl_model` instead." + deprecate("diffusers.loaders.single_file_utils.is_clip_sdxl_model", "0.36", deprecation_message) + + return is_clip_sdxl_model(checkpoint) def is_clip_sd3_model(checkpoint): - if CHECKPOINT_KEY_NAMES["clip_sd3"] in checkpoint: - return True + from .single_file.single_file_utils import is_clip_sd3_model + + deprecation_message = "Importing `is_clip_sd3_model()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import is_clip_sd3_model` instead." + deprecate("diffusers.loaders.single_file_utils.is_clip_sd3_model", "0.36", deprecation_message) - return False + return is_clip_sd3_model(checkpoint) def is_open_clip_model(checkpoint): - if CHECKPOINT_KEY_NAMES["open_clip"] in checkpoint: - return True + deprecation_message = "Importing `is_open_clip_model()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import is_open_clip_model` instead." + deprecate("diffusers.loaders.single_file_utils.is_open_clip_model", "0.36", deprecation_message) - return False + return is_open_clip_model(checkpoint) def is_open_clip_sdxl_model(checkpoint): - if CHECKPOINT_KEY_NAMES["open_clip_sdxl"] in checkpoint: - return True + from .single_file.single_file_utils import is_open_clip_sdxl_model + + deprecation_message = "Importing `is_open_clip_sdxl_model()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import is_open_clip_sdxl_model` instead." + deprecate("diffusers.loaders.single_file_utils.is_open_clip_sdxl_model", "0.36", deprecation_message) - return False + return is_open_clip_sdxl_model(checkpoint) def is_open_clip_sd3_model(checkpoint): - if CHECKPOINT_KEY_NAMES["open_clip_sd3"] in checkpoint: - return True + from .single_file.single_file_utils import is_open_clip_sd3_model - return False + deprecation_message = "Importing `is_open_clip_sd3_model()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import is_open_clip_sd3_model` instead." + deprecate("diffusers.loaders.single_file_utils.is_open_clip_sd3_model", "0.36", deprecation_message) + + return is_open_clip_sd3_model(checkpoint) def is_open_clip_sdxl_refiner_model(checkpoint): - if CHECKPOINT_KEY_NAMES["open_clip_sdxl_refiner"] in checkpoint: - return True + from .single_file.single_file_utils import is_open_clip_sdxl_refiner_model + + deprecation_message = "Importing `is_open_clip_sdxl_refiner_model()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import is_open_clip_sdxl_refiner_model` instead." 
+ deprecate("diffusers.loaders.single_file_utils.is_open_clip_sdxl_refiner_model", "0.36", deprecation_message) - return False + return is_open_clip_sdxl_refiner_model(checkpoint) def is_clip_model_in_single_file(class_obj, checkpoint): - is_clip_in_checkpoint = any( - [ - is_clip_model(checkpoint), - is_clip_sd3_model(checkpoint), - is_open_clip_model(checkpoint), - is_open_clip_sdxl_model(checkpoint), - is_open_clip_sdxl_refiner_model(checkpoint), - is_open_clip_sd3_model(checkpoint), - ] - ) - if ( - class_obj.__name__ == "CLIPTextModel" or class_obj.__name__ == "CLIPTextModelWithProjection" - ) and is_clip_in_checkpoint: - return True + from .single_file.single_file_utils import is_clip_model_in_single_file + + deprecation_message = "Importing `is_clip_model_in_single_file()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import is_clip_model_in_single_file` instead." + deprecate("diffusers.loaders.single_file_utils.is_clip_model_in_single_file", "0.36", deprecation_message) - return False + return is_clip_model_in_single_file(class_obj, checkpoint) def infer_diffusers_model_type(checkpoint): - if ( - CHECKPOINT_KEY_NAMES["inpainting"] in checkpoint - and checkpoint[CHECKPOINT_KEY_NAMES["inpainting"]].shape[1] == 9 - ): - if CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024: - model_type = "inpainting_v2" - elif CHECKPOINT_KEY_NAMES["xl_base"] in checkpoint: - model_type = "xl_inpaint" - else: - model_type = "inpainting" - - elif CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024: - model_type = "v2" - - elif CHECKPOINT_KEY_NAMES["playground-v2-5"] in checkpoint: - model_type = "playground-v2-5" - - elif CHECKPOINT_KEY_NAMES["xl_base"] in checkpoint: - model_type = "xl_base" - - elif CHECKPOINT_KEY_NAMES["xl_refiner"] in checkpoint: - model_type = "xl_refiner" - - elif CHECKPOINT_KEY_NAMES["upscale"] in checkpoint: - model_type = "upscale" - - elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["controlnet"]): - if CHECKPOINT_KEY_NAMES["controlnet_xl"] in checkpoint: - if CHECKPOINT_KEY_NAMES["controlnet_xl_large"] in checkpoint: - model_type = "controlnet_xl_large" - elif CHECKPOINT_KEY_NAMES["controlnet_xl_mid"] in checkpoint: - model_type = "controlnet_xl_mid" - else: - model_type = "controlnet_xl_small" - else: - model_type = "controlnet" - - elif ( - CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"] in checkpoint - and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"]].shape[0] == 1536 - ): - model_type = "stable_cascade_stage_c_lite" - - elif ( - CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"] in checkpoint - and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"]].shape[0] == 2048 - ): - model_type = "stable_cascade_stage_c" - - elif ( - CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"] in checkpoint - and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"]].shape[-1] == 576 - ): - model_type = "stable_cascade_stage_b_lite" - - elif ( - CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"] in checkpoint - and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"]].shape[-1] == 640 - ): - model_type = "stable_cascade_stage_b" - - elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sd3"]) and any( - checkpoint[key].shape[-1] == 9216 if key in checkpoint else False for key in CHECKPOINT_KEY_NAMES["sd3"] - ): - if "model.diffusion_model.pos_embed" in checkpoint: - key = 
"model.diffusion_model.pos_embed" - else: - key = "pos_embed" - - if checkpoint[key].shape[1] == 36864: - model_type = "sd3" - elif checkpoint[key].shape[1] == 147456: - model_type = "sd35_medium" - - elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sd35_large"]): - model_type = "sd35_large" - - elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint: - if CHECKPOINT_KEY_NAMES["animatediff_scribble"] in checkpoint: - model_type = "animatediff_scribble" - - elif CHECKPOINT_KEY_NAMES["animatediff_rgb"] in checkpoint: - model_type = "animatediff_rgb" - - elif CHECKPOINT_KEY_NAMES["animatediff_v2"] in checkpoint: - model_type = "animatediff_v2" - - elif checkpoint[CHECKPOINT_KEY_NAMES["animatediff_sdxl_beta"]].shape[-1] == 320: - model_type = "animatediff_sdxl_beta" - - elif checkpoint[CHECKPOINT_KEY_NAMES["animatediff"]].shape[1] == 24: - model_type = "animatediff_v1" - - else: - model_type = "animatediff_v3" - - elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux"]): - if any( - g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"] - ): - if "model.diffusion_model.img_in.weight" in checkpoint: - key = "model.diffusion_model.img_in.weight" - else: - key = "img_in.weight" - - if checkpoint[key].shape[1] == 384: - model_type = "flux-fill" - elif checkpoint[key].shape[1] == 128: - model_type = "flux-depth" - else: - model_type = "flux-dev" - else: - model_type = "flux-schnell" - - elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx-video"]): - if checkpoint["vae.encoder.conv_out.conv.weight"].shape[1] == 2048: - model_type = "ltx-video-0.9.5" - elif "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in checkpoint: - model_type = "ltx-video-0.9.1" - else: - model_type = "ltx-video" - - elif CHECKPOINT_KEY_NAMES["autoencoder-dc"] in checkpoint: - encoder_key = "encoder.project_in.conv.conv.bias" - decoder_key = "decoder.project_in.main.conv.weight" - - if CHECKPOINT_KEY_NAMES["autoencoder-dc-sana"] in checkpoint: - model_type = "autoencoder-dc-f32c32-sana" - - elif checkpoint[encoder_key].shape[-1] == 64 and checkpoint[decoder_key].shape[1] == 32: - model_type = "autoencoder-dc-f32c32" - - elif checkpoint[encoder_key].shape[-1] == 64 and checkpoint[decoder_key].shape[1] == 128: - model_type = "autoencoder-dc-f64c128" - - else: - model_type = "autoencoder-dc-f128c512" - - elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["mochi-1-preview"]): - model_type = "mochi-1-preview" - - elif CHECKPOINT_KEY_NAMES["hunyuan-video"] in checkpoint: - model_type = "hunyuan-video" - - elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["auraflow"]): - model_type = "auraflow" - - elif ( - CHECKPOINT_KEY_NAMES["instruct-pix2pix"] in checkpoint - and checkpoint[CHECKPOINT_KEY_NAMES["instruct-pix2pix"]].shape[1] == 8 - ): - model_type = "instruct-pix2pix" - - elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"]): - model_type = "lumina2" - - elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sana"]): - model_type = "sana" - - elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["wan"]): - if "model.diffusion_model.patch_embedding.weight" in checkpoint: - target_key = "model.diffusion_model.patch_embedding.weight" - else: - target_key = "patch_embedding.weight" - - if checkpoint[target_key].shape[0] == 1536: - model_type = "wan-t2v-1.3B" - elif checkpoint[target_key].shape[0] == 5120 and checkpoint[target_key].shape[1] == 16: - model_type = "wan-t2v-14B" - else: - 
model_type = "wan-i2v-14B" - elif CHECKPOINT_KEY_NAMES["wan_vae"] in checkpoint: - # All Wan models use the same VAE so we can use the same default model repo to fetch the config - model_type = "wan-t2v-14B" - else: - model_type = "v1" - - return model_type + from .single_file.single_file_utils import infer_diffusers_model_type + + deprecation_message = "Importing `infer_diffusers_model_type()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import infer_diffusers_model_type` instead." + deprecate("diffusers.loaders.single_file_utils.infer_diffusers_model_type", "0.36", deprecation_message) + + return infer_diffusers_model_type(checkpoint) def fetch_diffusers_config(checkpoint): - model_type = infer_diffusers_model_type(checkpoint) - model_path = DIFFUSERS_DEFAULT_PIPELINE_PATHS[model_type] - model_path = copy.deepcopy(model_path) + from .single_file.single_file_utils import fetch_diffusers_config + + deprecation_message = "Importing `fetch_diffusers_config()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import fetch_diffusers_config` instead." + deprecate("diffusers.loaders.single_file_utils.fetch_diffusers_config", "0.36", deprecation_message) - return model_path + return fetch_diffusers_config(checkpoint) def set_image_size(checkpoint, image_size=None): - if image_size: - return image_size + from .single_file.single_file_utils import set_image_size - model_type = infer_diffusers_model_type(checkpoint) - image_size = DIFFUSERS_TO_LDM_DEFAULT_IMAGE_SIZE_MAP[model_type] + deprecation_message = "Importing `set_image_size()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import set_image_size` instead." + deprecate("diffusers.loaders.single_file_utils.set_image_size", "0.36", deprecation_message) - return image_size + return set_image_size(checkpoint, image_size) -# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear def conv_attn_to_linear(checkpoint): - keys = list(checkpoint.keys()) - attn_keys = ["query.weight", "key.weight", "value.weight"] - for key in keys: - if ".".join(key.split(".")[-2:]) in attn_keys: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0, 0] - elif "proj_attn.weight" in key: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0] + from .single_file.single_file_utils import conv_attn_to_linear + + deprecation_message = "Importing `conv_attn_to_linear()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import conv_attn_to_linear` instead." + deprecate("diffusers.loaders.single_file_utils.conv_attn_to_linear", "0.36", deprecation_message) + + return conv_attn_to_linear(checkpoint) def create_unet_diffusers_config_from_ldm( original_config, checkpoint, image_size=None, upcast_attention=None, num_in_channels=None ): - """ - Creates a config for the diffusers based on the config of the LDM model. - """ - if image_size is not None: - deprecation_message = ( - "Configuring UNet2DConditionModel with the `image_size` argument to `from_single_file`" - "is deprecated and will be ignored in future versions." 
- ) - deprecate("image_size", "1.0.0", deprecation_message) - - image_size = set_image_size(checkpoint, image_size=image_size) - - if ( - "unet_config" in original_config["model"]["params"] - and original_config["model"]["params"]["unet_config"] is not None - ): - unet_params = original_config["model"]["params"]["unet_config"]["params"] - else: - unet_params = original_config["model"]["params"]["network_config"]["params"] - - if num_in_channels is not None: - deprecation_message = ( - "Configuring UNet2DConditionModel with the `num_in_channels` argument to `from_single_file`" - "is deprecated and will be ignored in future versions." - ) - deprecate("image_size", "1.0.0", deprecation_message) - in_channels = num_in_channels - else: - in_channels = unet_params["in_channels"] - - vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"] - block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]] - - down_block_types = [] - resolution = 1 - for i in range(len(block_out_channels)): - block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D" - down_block_types.append(block_type) - if i != len(block_out_channels) - 1: - resolution *= 2 - - up_block_types = [] - for i in range(len(block_out_channels)): - block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D" - up_block_types.append(block_type) - resolution //= 2 - - if unet_params["transformer_depth"] is not None: - transformer_layers_per_block = ( - unet_params["transformer_depth"] - if isinstance(unet_params["transformer_depth"], int) - else list(unet_params["transformer_depth"]) - ) - else: - transformer_layers_per_block = 1 - - vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1) - - head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None - use_linear_projection = ( - unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False + from .single_file.single_file_utils import create_unet_diffusers_config_from_ldm + + deprecation_message = "Importing `create_unet_diffusers_config_from_ldm()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import create_unet_diffusers_config_from_ldm` instead." 
+ deprecate("diffusers.loaders.single_file_utils.create_unet_diffusers_config_from_ldm", "0.36", deprecation_message) + + return create_unet_diffusers_config_from_ldm( + original_config, checkpoint, image_size, upcast_attention, num_in_channels ) - if use_linear_projection: - # stable diffusion 2-base-512 and 2-768 - if head_dim is None: - head_dim_mult = unet_params["model_channels"] // unet_params["num_head_channels"] - head_dim = [head_dim_mult * c for c in list(unet_params["channel_mult"])] - - class_embed_type = None - addition_embed_type = None - addition_time_embed_dim = None - projection_class_embeddings_input_dim = None - context_dim = None - - if unet_params["context_dim"] is not None: - context_dim = ( - unet_params["context_dim"] - if isinstance(unet_params["context_dim"], int) - else unet_params["context_dim"][0] - ) - - if "num_classes" in unet_params: - if unet_params["num_classes"] == "sequential": - if context_dim in [2048, 1280]: - # SDXL - addition_embed_type = "text_time" - addition_time_embed_dim = 256 - else: - class_embed_type = "projection" - assert "adm_in_channels" in unet_params - projection_class_embeddings_input_dim = unet_params["adm_in_channels"] - - config = { - "sample_size": image_size // vae_scale_factor, - "in_channels": in_channels, - "down_block_types": down_block_types, - "block_out_channels": block_out_channels, - "layers_per_block": unet_params["num_res_blocks"], - "cross_attention_dim": context_dim, - "attention_head_dim": head_dim, - "use_linear_projection": use_linear_projection, - "class_embed_type": class_embed_type, - "addition_embed_type": addition_embed_type, - "addition_time_embed_dim": addition_time_embed_dim, - "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, - "transformer_layers_per_block": transformer_layers_per_block, - } - - if upcast_attention is not None: - deprecation_message = ( - "Configuring UNet2DConditionModel with the `upcast_attention` argument to `from_single_file`" - "is deprecated and will be ignored in future versions." - ) - deprecate("image_size", "1.0.0", deprecation_message) - config["upcast_attention"] = upcast_attention - - if "disable_self_attentions" in unet_params: - config["only_cross_attention"] = unet_params["disable_self_attentions"] - - if "num_classes" in unet_params and isinstance(unet_params["num_classes"], int): - config["num_class_embeds"] = unet_params["num_classes"] - - config["out_channels"] = unet_params["out_channels"] - config["up_block_types"] = up_block_types - - return config def create_controlnet_diffusers_config_from_ldm(original_config, checkpoint, image_size=None, **kwargs): - if image_size is not None: - deprecation_message = ( - "Configuring ControlNetModel with the `image_size` argument" - "is deprecated and will be ignored in future versions." 
- ) - deprecate("image_size", "1.0.0", deprecation_message) - - image_size = set_image_size(checkpoint, image_size=image_size) - - unet_params = original_config["model"]["params"]["control_stage_config"]["params"] - diffusers_unet_config = create_unet_diffusers_config_from_ldm(original_config, image_size=image_size) - - controlnet_config = { - "conditioning_channels": unet_params["hint_channels"], - "in_channels": diffusers_unet_config["in_channels"], - "down_block_types": diffusers_unet_config["down_block_types"], - "block_out_channels": diffusers_unet_config["block_out_channels"], - "layers_per_block": diffusers_unet_config["layers_per_block"], - "cross_attention_dim": diffusers_unet_config["cross_attention_dim"], - "attention_head_dim": diffusers_unet_config["attention_head_dim"], - "use_linear_projection": diffusers_unet_config["use_linear_projection"], - "class_embed_type": diffusers_unet_config["class_embed_type"], - "addition_embed_type": diffusers_unet_config["addition_embed_type"], - "addition_time_embed_dim": diffusers_unet_config["addition_time_embed_dim"], - "projection_class_embeddings_input_dim": diffusers_unet_config["projection_class_embeddings_input_dim"], - "transformer_layers_per_block": diffusers_unet_config["transformer_layers_per_block"], - } - - return controlnet_config + from .single_file.single_file_utils import create_controlnet_diffusers_config_from_ldm + + deprecation_message = "Importing `create_controlnet_diffusers_config_from_ldm()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import create_controlnet_diffusers_config_from_ldm` instead." + deprecate( + "diffusers.loaders.single_file_utils.create_controlnet_diffusers_config_from_ldm", "0.36", deprecation_message + ) + return create_controlnet_diffusers_config_from_ldm(original_config, checkpoint, image_size, **kwargs) def create_vae_diffusers_config_from_ldm(original_config, checkpoint, image_size=None, scaling_factor=None): - """ - Creates a config for the diffusers based on the config of the LDM model. - """ - if image_size is not None: - deprecation_message = ( - "Configuring AutoencoderKL with the `image_size` argument" - "is deprecated and will be ignored in future versions." 
- ) - deprecate("image_size", "1.0.0", deprecation_message) - - image_size = set_image_size(checkpoint, image_size=image_size) - - if "edm_mean" in checkpoint and "edm_std" in checkpoint: - latents_mean = checkpoint["edm_mean"] - latents_std = checkpoint["edm_std"] - else: - latents_mean = None - latents_std = None - - vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"] - if (scaling_factor is None) and (latents_mean is not None) and (latents_std is not None): - scaling_factor = PLAYGROUND_VAE_SCALING_FACTOR - - elif (scaling_factor is None) and ("scale_factor" in original_config["model"]["params"]): - scaling_factor = original_config["model"]["params"]["scale_factor"] - - elif scaling_factor is None: - scaling_factor = LDM_VAE_DEFAULT_SCALING_FACTOR - - block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]] - down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) - up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) - - config = { - "sample_size": image_size, - "in_channels": vae_params["in_channels"], - "out_channels": vae_params["out_ch"], - "down_block_types": down_block_types, - "up_block_types": up_block_types, - "block_out_channels": block_out_channels, - "latent_channels": vae_params["z_channels"], - "layers_per_block": vae_params["num_res_blocks"], - "scaling_factor": scaling_factor, - } - if latents_mean is not None and latents_std is not None: - config.update({"latents_mean": latents_mean, "latents_std": latents_std}) - - return config + from .single_file.single_file_utils import create_vae_diffusers_config_from_ldm + + deprecation_message = "Importing `create_vae_diffusers_config_from_ldm()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import create_vae_diffusers_config_from_ldm` instead." + deprecate("diffusers.loaders.single_file_utils.create_vae_diffusers_config_from_ldm", "0.36", deprecation_message) + return create_vae_diffusers_config_from_ldm(original_config, checkpoint, image_size, scaling_factor) def update_unet_resnet_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint, mapping=None): - for ldm_key in ldm_keys: - diffusers_key = ( - ldm_key.replace("in_layers.0", "norm1") - .replace("in_layers.2", "conv1") - .replace("out_layers.0", "norm2") - .replace("out_layers.3", "conv2") - .replace("emb_layers.1", "time_emb_proj") - .replace("skip_connection", "conv_shortcut") - ) - if mapping: - diffusers_key = diffusers_key.replace(mapping["old"], mapping["new"]) - new_checkpoint[diffusers_key] = checkpoint.get(ldm_key) + from .single_file.single_file_utils import update_unet_resnet_ldm_to_diffusers + + deprecation_message = "Importing `update_unet_resnet_ldm_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import update_unet_resnet_ldm_to_diffusers` instead." 
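+    # Emits a deprecation warning (removal planned for 0.36), then delegates to the relocated implementation unchanged.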
+ deprecate("diffusers.loaders.single_file_utils.update_unet_resnet_ldm_to_diffusers", "0.36", deprecation_message) + + return update_unet_resnet_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint, mapping) def update_unet_attention_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint, mapping): - for ldm_key in ldm_keys: - diffusers_key = ldm_key.replace(mapping["old"], mapping["new"]) - new_checkpoint[diffusers_key] = checkpoint.get(ldm_key) + from .single_file.single_file_utils import update_unet_attention_ldm_to_diffusers + + deprecation_message = "Importing `update_unet_attention_ldm_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import update_unet_attention_ldm_to_diffusers` instead." + deprecate( + "diffusers.loaders.single_file_utils.update_unet_attention_ldm_to_diffusers", "0.36", deprecation_message + ) + + return update_unet_attention_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint, mapping) def update_vae_resnet_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping): - for ldm_key in keys: - diffusers_key = ldm_key.replace(mapping["old"], mapping["new"]).replace("nin_shortcut", "conv_shortcut") - new_checkpoint[diffusers_key] = checkpoint.get(ldm_key) + from .single_file.single_file_utils import update_vae_resnet_ldm_to_diffusers + + deprecation_message = "Importing `update_vae_resnet_ldm_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import update_vae_resnet_ldm_to_diffusers` instead." + deprecate("diffusers.loaders.single_file_utils.update_vae_resnet_ldm_to_diffusers", "0.36", deprecation_message) + + return update_vae_resnet_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping) def update_vae_attentions_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping): - for ldm_key in keys: - diffusers_key = ( - ldm_key.replace(mapping["old"], mapping["new"]) - .replace("norm.weight", "group_norm.weight") - .replace("norm.bias", "group_norm.bias") - .replace("q.weight", "to_q.weight") - .replace("q.bias", "to_q.bias") - .replace("k.weight", "to_k.weight") - .replace("k.bias", "to_k.bias") - .replace("v.weight", "to_v.weight") - .replace("v.bias", "to_v.bias") - .replace("proj_out.weight", "to_out.0.weight") - .replace("proj_out.bias", "to_out.0.bias") - ) - new_checkpoint[diffusers_key] = checkpoint.get(ldm_key) - - # proj_attn.weight has to be converted from conv 1D to linear - shape = new_checkpoint[diffusers_key].shape - - if len(shape) == 3: - new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0] - elif len(shape) == 4: - new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0, 0] + from .single_file.single_file_utils import update_vae_attentions_ldm_to_diffusers + + deprecation_message = "Importing `update_vae_attentions_ldm_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import update_vae_attentions_ldm_to_diffusers` instead." 
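+    # Beyond key renaming, the relocated helper still reshapes `proj_attn` weights from conv 1D to linear form.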
+ deprecate( + "diffusers.loaders.single_file_utils.update_vae_attentions_ldm_to_diffusers", "0.36", deprecation_message + ) + + return update_vae_attentions_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping) def convert_stable_cascade_unet_single_file_to_diffusers(checkpoint, **kwargs): - is_stage_c = "clip_txt_mapper.weight" in checkpoint - - if is_stage_c: - state_dict = {} - for key in checkpoint.keys(): - if key.endswith("in_proj_weight"): - weights = checkpoint[key].chunk(3, 0) - state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0] - state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1] - state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2] - elif key.endswith("in_proj_bias"): - weights = checkpoint[key].chunk(3, 0) - state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0] - state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1] - state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2] - elif key.endswith("out_proj.weight"): - weights = checkpoint[key] - state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights - elif key.endswith("out_proj.bias"): - weights = checkpoint[key] - state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights - else: - state_dict[key] = checkpoint[key] - else: - state_dict = {} - for key in checkpoint.keys(): - if key.endswith("in_proj_weight"): - weights = checkpoint[key].chunk(3, 0) - state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0] - state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1] - state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2] - elif key.endswith("in_proj_bias"): - weights = checkpoint[key].chunk(3, 0) - state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0] - state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1] - state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2] - elif key.endswith("out_proj.weight"): - weights = checkpoint[key] - state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights - elif key.endswith("out_proj.bias"): - weights = checkpoint[key] - state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights - # rename clip_mapper to clip_txt_pooled_mapper - elif key.endswith("clip_mapper.weight"): - weights = checkpoint[key] - state_dict[key.replace("clip_mapper.weight", "clip_txt_pooled_mapper.weight")] = weights - elif key.endswith("clip_mapper.bias"): - weights = checkpoint[key] - state_dict[key.replace("clip_mapper.bias", "clip_txt_pooled_mapper.bias")] = weights - else: - state_dict[key] = checkpoint[key] - - return state_dict + from .single_file.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers + + deprecation_message = "Importing `convert_stable_cascade_unet_single_file_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers` instead." + deprecate( + "diffusers.loaders.single_file_utils.convert_stable_cascade_unet_single_file_to_diffusers", + "0.36", + deprecation_message, + ) + return convert_stable_cascade_unet_single_file_to_diffusers(checkpoint, **kwargs) def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False, **kwargs): - """ - Takes a state dict and a config, and returns a converted checkpoint. 
- """ - # extract state_dict for UNet - unet_state_dict = {} - keys = list(checkpoint.keys()) - unet_key = LDM_UNET_KEY - - # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA - if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: - logger.warning("Checkpoint has both EMA and non-EMA weights.") - logger.warning( - "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" - " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." - ) - for key in keys: - if key.startswith("model.diffusion_model"): - flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) - unet_state_dict[key.replace(unet_key, "")] = checkpoint.get(flat_ema_key) - else: - if sum(k.startswith("model_ema") for k in keys) > 100: - logger.warning( - "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" - " weights (usually better for inference), please make sure to add the `--extract_ema` flag." - ) - for key in keys: - if key.startswith(unet_key): - unet_state_dict[key.replace(unet_key, "")] = checkpoint.get(key) - - new_checkpoint = {} - ldm_unet_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["layers"] - for diffusers_key, ldm_key in ldm_unet_keys.items(): - if ldm_key not in unet_state_dict: - continue - new_checkpoint[diffusers_key] = unet_state_dict[ldm_key] - - if ("class_embed_type" in config) and (config["class_embed_type"] in ["timestep", "projection"]): - class_embed_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["class_embed_type"] - for diffusers_key, ldm_key in class_embed_keys.items(): - new_checkpoint[diffusers_key] = unet_state_dict[ldm_key] - - if ("addition_embed_type" in config) and (config["addition_embed_type"] == "text_time"): - addition_embed_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["addition_embed_type"] - for diffusers_key, ldm_key in addition_embed_keys.items(): - new_checkpoint[diffusers_key] = unet_state_dict[ldm_key] - - # Relevant to StableDiffusionUpscalePipeline - if "num_class_embeds" in config: - if (config["num_class_embeds"] is not None) and ("label_emb.weight" in unet_state_dict): - new_checkpoint["class_embedding.weight"] = unet_state_dict["label_emb.weight"] - - # Retrieves the keys for the input blocks only - num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) - input_blocks = { - layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] - for layer_id in range(num_input_blocks) - } - - # Retrieves the keys for the middle blocks only - num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) - middle_blocks = { - layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] - for layer_id in range(num_middle_blocks) - } - - # Retrieves the keys for the output blocks only - num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) - output_blocks = { - layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] - for layer_id in range(num_output_blocks) - } - - # Down blocks - for i in range(1, num_input_blocks): - block_id = (i - 1) // (config["layers_per_block"] + 1) - layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) - - resnets = [ - key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key - ] - 
update_unet_resnet_ldm_to_diffusers( - resnets, - new_checkpoint, - unet_state_dict, - {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}, - ) - - if f"input_blocks.{i}.0.op.weight" in unet_state_dict: - new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.get( - f"input_blocks.{i}.0.op.weight" - ) - new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.get( - f"input_blocks.{i}.0.op.bias" - ) - - attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] - if attentions: - update_unet_attention_ldm_to_diffusers( - attentions, - new_checkpoint, - unet_state_dict, - {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}, - ) - - # Mid blocks - for key in middle_blocks.keys(): - diffusers_key = max(key - 1, 0) - if key % 2 == 0: - update_unet_resnet_ldm_to_diffusers( - middle_blocks[key], - new_checkpoint, - unet_state_dict, - mapping={"old": f"middle_block.{key}", "new": f"mid_block.resnets.{diffusers_key}"}, - ) - else: - update_unet_attention_ldm_to_diffusers( - middle_blocks[key], - new_checkpoint, - unet_state_dict, - mapping={"old": f"middle_block.{key}", "new": f"mid_block.attentions.{diffusers_key}"}, - ) - - # Up Blocks - for i in range(num_output_blocks): - block_id = i // (config["layers_per_block"] + 1) - layer_in_block_id = i % (config["layers_per_block"] + 1) - - resnets = [ - key for key in output_blocks[i] if f"output_blocks.{i}.0" in key and f"output_blocks.{i}.0.op" not in key - ] - update_unet_resnet_ldm_to_diffusers( - resnets, - new_checkpoint, - unet_state_dict, - {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}, - ) - - attentions = [ - key for key in output_blocks[i] if f"output_blocks.{i}.1" in key and f"output_blocks.{i}.1.conv" not in key - ] - if attentions: - update_unet_attention_ldm_to_diffusers( - attentions, - new_checkpoint, - unet_state_dict, - {"old": f"output_blocks.{i}.1", "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}"}, - ) - - if f"output_blocks.{i}.1.conv.weight" in unet_state_dict: - new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ - f"output_blocks.{i}.1.conv.weight" - ] - new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ - f"output_blocks.{i}.1.conv.bias" - ] - if f"output_blocks.{i}.2.conv.weight" in unet_state_dict: - new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ - f"output_blocks.{i}.2.conv.weight" - ] - new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ - f"output_blocks.{i}.2.conv.bias" - ] - - return new_checkpoint + from .single_file.single_file_utils import convert_ldm_unet_checkpoint + + deprecation_message = "Importing `convert_ldm_unet_checkpoint()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_ldm_unet_checkpoint` instead." 
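+    # Illustrative migration for downstream code:
+    #   before: from diffusers.loaders.single_file_utils import convert_ldm_unet_checkpoint
+    #   after:  from diffusers.loaders.single_file.single_file_utils import convert_ldm_unet_checkpoint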
+ deprecate("diffusers.loaders.single_file_utils.convert_ldm_unet_checkpoint", "0.36", deprecation_message) + return convert_ldm_unet_checkpoint(checkpoint, config, extract_ema, **kwargs) def convert_controlnet_checkpoint( @@ -1233,248 +274,27 @@ def convert_controlnet_checkpoint( config, **kwargs, ): - # Return checkpoint if it's already been converted - if "time_embedding.linear_1.weight" in checkpoint: - return checkpoint - # Some controlnet ckpt files are distributed independently from the rest of the - # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/ - if "time_embed.0.weight" in checkpoint: - controlnet_state_dict = checkpoint - - else: - controlnet_state_dict = {} - keys = list(checkpoint.keys()) - controlnet_key = LDM_CONTROLNET_KEY - for key in keys: - if key.startswith(controlnet_key): - controlnet_state_dict[key.replace(controlnet_key, "")] = checkpoint.get(key) - - new_checkpoint = {} - ldm_controlnet_keys = DIFFUSERS_TO_LDM_MAPPING["controlnet"]["layers"] - for diffusers_key, ldm_key in ldm_controlnet_keys.items(): - if ldm_key not in controlnet_state_dict: - continue - new_checkpoint[diffusers_key] = controlnet_state_dict[ldm_key] - - # Retrieves the keys for the input blocks only - num_input_blocks = len( - {".".join(layer.split(".")[:2]) for layer in controlnet_state_dict if "input_blocks" in layer} - ) - input_blocks = { - layer_id: [key for key in controlnet_state_dict if f"input_blocks.{layer_id}" in key] - for layer_id in range(num_input_blocks) - } - - # Down blocks - for i in range(1, num_input_blocks): - block_id = (i - 1) // (config["layers_per_block"] + 1) - layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) - - resnets = [ - key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key - ] - update_unet_resnet_ldm_to_diffusers( - resnets, - new_checkpoint, - controlnet_state_dict, - {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}, - ) - - if f"input_blocks.{i}.0.op.weight" in controlnet_state_dict: - new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = controlnet_state_dict.get( - f"input_blocks.{i}.0.op.weight" - ) - new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = controlnet_state_dict.get( - f"input_blocks.{i}.0.op.bias" - ) - - attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] - if attentions: - update_unet_attention_ldm_to_diffusers( - attentions, - new_checkpoint, - controlnet_state_dict, - {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}, - ) - - # controlnet down blocks - for i in range(num_input_blocks): - new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = controlnet_state_dict.get(f"zero_convs.{i}.0.weight") - new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = controlnet_state_dict.get(f"zero_convs.{i}.0.bias") - - # Retrieves the keys for the middle blocks only - num_middle_blocks = len( - {".".join(layer.split(".")[:2]) for layer in controlnet_state_dict if "middle_block" in layer} - ) - middle_blocks = { - layer_id: [key for key in controlnet_state_dict if f"middle_block.{layer_id}" in key] - for layer_id in range(num_middle_blocks) - } - - # Mid blocks - for key in middle_blocks.keys(): - diffusers_key = max(key - 1, 0) - if key % 2 == 0: - update_unet_resnet_ldm_to_diffusers( - middle_blocks[key], - new_checkpoint, - controlnet_state_dict, - mapping={"old": f"middle_block.{key}", "new": 
f"mid_block.resnets.{diffusers_key}"}, - ) - else: - update_unet_attention_ldm_to_diffusers( - middle_blocks[key], - new_checkpoint, - controlnet_state_dict, - mapping={"old": f"middle_block.{key}", "new": f"mid_block.attentions.{diffusers_key}"}, - ) - - # mid block - new_checkpoint["controlnet_mid_block.weight"] = controlnet_state_dict.get("middle_block_out.0.weight") - new_checkpoint["controlnet_mid_block.bias"] = controlnet_state_dict.get("middle_block_out.0.bias") - - # controlnet cond embedding blocks - cond_embedding_blocks = { - ".".join(layer.split(".")[:2]) - for layer in controlnet_state_dict - if "input_hint_block" in layer and ("input_hint_block.0" not in layer) and ("input_hint_block.14" not in layer) - } - num_cond_embedding_blocks = len(cond_embedding_blocks) - - for idx in range(1, num_cond_embedding_blocks + 1): - diffusers_idx = idx - 1 - cond_block_id = 2 * idx - - new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.weight"] = controlnet_state_dict.get( - f"input_hint_block.{cond_block_id}.weight" - ) - new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.bias"] = controlnet_state_dict.get( - f"input_hint_block.{cond_block_id}.bias" - ) - - return new_checkpoint + from .single_file.single_file_utils import convert_controlnet_checkpoint + deprecation_message = "Importing `convert_controlnet_checkpoint()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_controlnet_checkpoint` instead." + deprecate("diffusers.loaders.single_file_utils.convert_controlnet_checkpoint", "0.36", deprecation_message) + return convert_controlnet_checkpoint(checkpoint, config, **kwargs) -def convert_ldm_vae_checkpoint(checkpoint, config): - # extract state dict for VAE - # remove the LDM_VAE_KEY prefix from the ldm checkpoint keys so that it is easier to map them to diffusers keys - vae_state_dict = {} - keys = list(checkpoint.keys()) - vae_key = "" - for ldm_vae_key in LDM_VAE_KEYS: - if any(k.startswith(ldm_vae_key) for k in keys): - vae_key = ldm_vae_key - - for key in keys: - if key.startswith(vae_key): - vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) - - new_checkpoint = {} - vae_diffusers_ldm_map = DIFFUSERS_TO_LDM_MAPPING["vae"] - for diffusers_key, ldm_key in vae_diffusers_ldm_map.items(): - if ldm_key not in vae_state_dict: - continue - new_checkpoint[diffusers_key] = vae_state_dict[ldm_key] - - # Retrieves the keys for the encoder down blocks only - num_down_blocks = len(config["down_block_types"]) - down_blocks = { - layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) - } - - for i in range(num_down_blocks): - resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] - update_vae_resnet_ldm_to_diffusers( - resnets, - new_checkpoint, - vae_state_dict, - mapping={"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}, - ) - if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: - new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.get( - f"encoder.down.{i}.downsample.conv.weight" - ) - new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.get( - f"encoder.down.{i}.downsample.conv.bias" - ) - - mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] - num_mid_res_blocks = 2 - for i in range(1, num_mid_res_blocks + 1): - resnets = [key for key in mid_resnets if 
f"encoder.mid.block_{i}" in key]
-        update_vae_resnet_ldm_to_diffusers(
-            resnets,
-            new_checkpoint,
-            vae_state_dict,
-            mapping={"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"},
-        )
-
-    mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
-    update_vae_attentions_ldm_to_diffusers(
-        mid_attentions, new_checkpoint, vae_state_dict, mapping={"old": "mid.attn_1", "new": "mid_block.attentions.0"}
-    )
-    # Retrieves the keys for the decoder up blocks only
-    num_up_blocks = len(config["up_block_types"])
-    up_blocks = {
-        layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
-    }
-
-    for i in range(num_up_blocks):
-        block_id = num_up_blocks - 1 - i
-        resnets = [
-            key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
-        ]
-        update_vae_resnet_ldm_to_diffusers(
-            resnets,
-            new_checkpoint,
-            vae_state_dict,
-            mapping={"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"},
-        )
-        if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
-            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
-                f"decoder.up.{block_id}.upsample.conv.weight"
-            ]
-            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
-                f"decoder.up.{block_id}.upsample.conv.bias"
-            ]
-
-    mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
-    num_mid_res_blocks = 2
-    for i in range(1, num_mid_res_blocks + 1):
-        resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
-        update_vae_resnet_ldm_to_diffusers(
-            resnets,
-            new_checkpoint,
-            vae_state_dict,
-            mapping={"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"},
-        )
-
-    mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
-    update_vae_attentions_ldm_to_diffusers(
-        mid_attentions, new_checkpoint, vae_state_dict, mapping={"old": "mid.attn_1", "new": "mid_block.attentions.0"}
-    )
-    conv_attn_to_linear(new_checkpoint)

+def convert_ldm_vae_checkpoint(checkpoint, config):
+    from .single_file.single_file_utils import convert_ldm_vae_checkpoint

-    return new_checkpoint
+    deprecation_message = "Importing `convert_ldm_vae_checkpoint()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_ldm_vae_checkpoint` instead."
+    deprecate("diffusers.loaders.single_file_utils.convert_ldm_vae_checkpoint", "0.36", deprecation_message)
+    return convert_ldm_vae_checkpoint(checkpoint, config)

 def convert_ldm_clip_checkpoint(checkpoint, remove_prefix=None):
-    keys = list(checkpoint.keys())
-    text_model_dict = {}
-
-    remove_prefixes = []
-    remove_prefixes.extend(LDM_CLIP_PREFIX_TO_REMOVE)
-    if remove_prefix:
-        remove_prefixes.append(remove_prefix)
+    from .single_file.single_file_utils import convert_ldm_clip_checkpoint

-    for key in keys:
-        for prefix in remove_prefixes:
-            if key.startswith(prefix):
-                diffusers_key = key.replace(prefix, "")
-                text_model_dict[diffusers_key] = checkpoint.get(key)
-
-    return text_model_dict
+    deprecation_message = "Importing `convert_ldm_clip_checkpoint()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_ldm_clip_checkpoint` instead."
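+    # Arguments are forwarded positionally, so existing call sites (including the optional `remove_prefix`) keep working.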
+ deprecate("diffusers.loaders.single_file_utils.convert_ldm_clip_checkpoint", "0.36", deprecation_message) + return convert_ldm_clip_checkpoint(checkpoint, remove_prefix) def convert_open_clip_checkpoint( @@ -1482,65 +302,11 @@ def convert_open_clip_checkpoint( checkpoint, prefix="cond_stage_model.model.", ): - text_model_dict = {} - text_proj_key = prefix + "text_projection" - - if text_proj_key in checkpoint: - text_proj_dim = int(checkpoint[text_proj_key].shape[0]) - elif hasattr(text_model.config, "hidden_size"): - text_proj_dim = text_model.config.hidden_size - else: - text_proj_dim = LDM_OPEN_CLIP_TEXT_PROJECTION_DIM - - keys = list(checkpoint.keys()) - keys_to_ignore = SD_2_TEXT_ENCODER_KEYS_TO_IGNORE - - openclip_diffusers_ldm_map = DIFFUSERS_TO_LDM_MAPPING["openclip"]["layers"] - for diffusers_key, ldm_key in openclip_diffusers_ldm_map.items(): - ldm_key = prefix + ldm_key - if ldm_key not in checkpoint: - continue - if ldm_key in keys_to_ignore: - continue - if ldm_key.endswith("text_projection"): - text_model_dict[diffusers_key] = checkpoint[ldm_key].T.contiguous() - else: - text_model_dict[diffusers_key] = checkpoint[ldm_key] - - for key in keys: - if key in keys_to_ignore: - continue - - if not key.startswith(prefix + "transformer."): - continue - - diffusers_key = key.replace(prefix + "transformer.", "") - transformer_diffusers_to_ldm_map = DIFFUSERS_TO_LDM_MAPPING["openclip"]["transformer"] - for new_key, old_key in transformer_diffusers_to_ldm_map.items(): - diffusers_key = ( - diffusers_key.replace(old_key, new_key).replace(".in_proj_weight", "").replace(".in_proj_bias", "") - ) - - if key.endswith(".in_proj_weight"): - weight_value = checkpoint.get(key) - - text_model_dict[diffusers_key + ".q_proj.weight"] = weight_value[:text_proj_dim, :].clone().detach() - text_model_dict[diffusers_key + ".k_proj.weight"] = ( - weight_value[text_proj_dim : text_proj_dim * 2, :].clone().detach() - ) - text_model_dict[diffusers_key + ".v_proj.weight"] = weight_value[text_proj_dim * 2 :, :].clone().detach() - - elif key.endswith(".in_proj_bias"): - weight_value = checkpoint.get(key) - text_model_dict[diffusers_key + ".q_proj.bias"] = weight_value[:text_proj_dim].clone().detach() - text_model_dict[diffusers_key + ".k_proj.bias"] = ( - weight_value[text_proj_dim : text_proj_dim * 2].clone().detach() - ) - text_model_dict[diffusers_key + ".v_proj.bias"] = weight_value[text_proj_dim * 2 :].clone().detach() - else: - text_model_dict[diffusers_key] = checkpoint.get(key) - - return text_model_dict + from .single_file.single_file_utils import convert_open_clip_checkpoint + + deprecation_message = "Importing `convert_open_clip_checkpoint()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_open_clip_checkpoint` instead." 
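+    # The default prefix targets SD 2.x checkpoints ("cond_stage_model.model."); SDXL loaders pass their own embedder prefixes.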
+ deprecate("diffusers.loaders.single_file_utils.convert_open_clip_checkpoint", "0.36", deprecation_message) + return convert_open_clip_checkpoint(text_model, checkpoint, prefix) def create_diffusers_clip_model_from_ldm( @@ -1552,522 +318,77 @@ def create_diffusers_clip_model_from_ldm( local_files_only=None, is_legacy_loading=False, ): - if config: - config = {"pretrained_model_name_or_path": config} - else: - config = fetch_diffusers_config(checkpoint) - - # For backwards compatibility - # Older versions of `from_single_file` expected CLIP configs to be placed in their original transformers model repo - # in the cache_dir, rather than in a subfolder of the Diffusers model - if is_legacy_loading: - logger.warning( - ( - "Detected legacy CLIP loading behavior. Please run `from_single_file` with `local_files_only=False once to update " - "the local cache directory with the necessary CLIP model config files. " - "Attempting to load CLIP model from legacy cache directory." - ) - ) - - if is_clip_model(checkpoint) or is_clip_sdxl_model(checkpoint): - clip_config = "openai/clip-vit-large-patch14" - config["pretrained_model_name_or_path"] = clip_config - subfolder = "" - - elif is_open_clip_model(checkpoint): - clip_config = "stabilityai/stable-diffusion-2" - config["pretrained_model_name_or_path"] = clip_config - subfolder = "text_encoder" - - else: - clip_config = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" - config["pretrained_model_name_or_path"] = clip_config - subfolder = "" - - model_config = cls.config_class.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only) - ctx = init_empty_weights if is_accelerate_available() else nullcontext - with ctx(): - model = cls(model_config) - - position_embedding_dim = model.text_model.embeddings.position_embedding.weight.shape[-1] - - if is_clip_model(checkpoint): - diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint) - - elif ( - is_clip_sdxl_model(checkpoint) - and checkpoint[CHECKPOINT_KEY_NAMES["clip_sdxl"]].shape[-1] == position_embedding_dim - ): - diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint) - - elif ( - is_clip_sd3_model(checkpoint) - and checkpoint[CHECKPOINT_KEY_NAMES["clip_sd3"]].shape[-1] == position_embedding_dim - ): - diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_l.transformer.") - diffusers_format_checkpoint["text_projection.weight"] = torch.eye(position_embedding_dim) - - elif is_open_clip_model(checkpoint): - prefix = "cond_stage_model.model." - diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix) - - elif ( - is_open_clip_sdxl_model(checkpoint) - and checkpoint[CHECKPOINT_KEY_NAMES["open_clip_sdxl"]].shape[-1] == position_embedding_dim - ): - prefix = "conditioner.embedders.1.model." - diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix) - - elif is_open_clip_sdxl_refiner_model(checkpoint): - prefix = "conditioner.embedders.0.model." 
- diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix) - - elif ( - is_open_clip_sd3_model(checkpoint) - and checkpoint[CHECKPOINT_KEY_NAMES["open_clip_sd3"]].shape[-1] == position_embedding_dim - ): - diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_g.transformer.") - - else: - raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.") - - if is_accelerate_available(): - load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) - else: - model.load_state_dict(diffusers_format_checkpoint, strict=False) - - if torch_dtype is not None: - model.to(torch_dtype) - - model.eval() - - return model - - -def _legacy_load_scheduler( - cls, - checkpoint, - component_name, - original_config=None, - **kwargs, -): - scheduler_type = kwargs.get("scheduler_type", None) - prediction_type = kwargs.get("prediction_type", None) - - if scheduler_type is not None: - deprecation_message = ( - "Please pass an instance of a Scheduler object directly to the `scheduler` argument in `from_single_file`\n\n" - "Example:\n\n" - "from diffusers import StableDiffusionPipeline, DDIMScheduler\n\n" - "scheduler = DDIMScheduler()\n" - "pipe = StableDiffusionPipeline.from_single_file(, scheduler=scheduler)\n" - ) - deprecate("scheduler_type", "1.0.0", deprecation_message) - - if prediction_type is not None: - deprecation_message = ( - "Please configure an instance of a Scheduler with the appropriate `prediction_type` and " - "pass the object directly to the `scheduler` argument in `from_single_file`.\n\n" - "Example:\n\n" - "from diffusers import StableDiffusionPipeline, DDIMScheduler\n\n" - 'scheduler = DDIMScheduler(prediction_type="v_prediction")\n' - "pipe = StableDiffusionPipeline.from_single_file(, scheduler=scheduler)\n" - ) - deprecate("prediction_type", "1.0.0", deprecation_message) - - scheduler_config = SCHEDULER_DEFAULT_CONFIG - model_type = infer_diffusers_model_type(checkpoint=checkpoint) - - global_step = checkpoint["global_step"] if "global_step" in checkpoint else None - - if original_config: - num_train_timesteps = getattr(original_config["model"]["params"], "timesteps", 1000) - else: - num_train_timesteps = 1000 - - scheduler_config["num_train_timesteps"] = num_train_timesteps - - if model_type == "v2": - if prediction_type is None: - # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"` # as it relies on a brittle global step parameter here - prediction_type = "epsilon" if global_step == 875000 else "v_prediction" - - else: - prediction_type = prediction_type or "epsilon" - - scheduler_config["prediction_type"] = prediction_type - - if model_type in ["xl_base", "xl_refiner"]: - scheduler_type = "euler" - elif model_type == "playground": - scheduler_type = "edm_dpm_solver_multistep" - else: - if original_config: - beta_start = original_config["model"]["params"].get("linear_start") - beta_end = original_config["model"]["params"].get("linear_end") - - else: - beta_start = 0.02 - beta_end = 0.085 - - scheduler_config["beta_start"] = beta_start - scheduler_config["beta_end"] = beta_end - scheduler_config["beta_schedule"] = "scaled_linear" - scheduler_config["clip_sample"] = False - scheduler_config["set_alpha_to_one"] = False - - # to deal with an edge case StableDiffusionUpscale pipeline has two schedulers - if component_name == "low_res_scheduler": - return cls.from_config( - { - "beta_end": 0.02, - "beta_schedule": "scaled_linear", - 
"beta_start": 0.0001, - "clip_sample": True, - "num_train_timesteps": 1000, - "prediction_type": "epsilon", - "trained_betas": None, - "variance_type": "fixed_small", - } - ) - - if scheduler_type is None: - return cls.from_config(scheduler_config) - - elif scheduler_type == "pndm": - scheduler_config["skip_prk_steps"] = True - scheduler = PNDMScheduler.from_config(scheduler_config) - - elif scheduler_type == "lms": - scheduler = LMSDiscreteScheduler.from_config(scheduler_config) - - elif scheduler_type == "heun": - scheduler = HeunDiscreteScheduler.from_config(scheduler_config) - - elif scheduler_type == "euler": - scheduler = EulerDiscreteScheduler.from_config(scheduler_config) - - elif scheduler_type == "euler-ancestral": - scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler_config) - - elif scheduler_type == "dpm": - scheduler = DPMSolverMultistepScheduler.from_config(scheduler_config) - - elif scheduler_type == "ddim": - scheduler = DDIMScheduler.from_config(scheduler_config) - - elif scheduler_type == "edm_dpm_solver_multistep": - scheduler_config = { - "algorithm_type": "dpmsolver++", - "dynamic_thresholding_ratio": 0.995, - "euler_at_final": False, - "final_sigmas_type": "zero", - "lower_order_final": True, - "num_train_timesteps": 1000, - "prediction_type": "epsilon", - "rho": 7.0, - "sample_max_value": 1.0, - "sigma_data": 0.5, - "sigma_max": 80.0, - "sigma_min": 0.002, - "solver_order": 2, - "solver_type": "midpoint", - "thresholding": False, - } - scheduler = EDMDPMSolverMultistepScheduler(**scheduler_config) - - else: - raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!") - - return scheduler - - -def _legacy_load_clip_tokenizer(cls, checkpoint, config=None, local_files_only=False): - if config: - config = {"pretrained_model_name_or_path": config} - else: - config = fetch_diffusers_config(checkpoint) - - if is_clip_model(checkpoint) or is_clip_sdxl_model(checkpoint): - clip_config = "openai/clip-vit-large-patch14" - config["pretrained_model_name_or_path"] = clip_config - subfolder = "" - - elif is_open_clip_model(checkpoint): - clip_config = "stabilityai/stable-diffusion-2" - config["pretrained_model_name_or_path"] = clip_config - subfolder = "tokenizer" - - else: - clip_config = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" - config["pretrained_model_name_or_path"] = clip_config - subfolder = "" - - tokenizer = cls.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only) - - return tokenizer - - -def _legacy_load_safety_checker(local_files_only, torch_dtype): - # Support for loading safety checker components using the deprecated - # `load_safety_checker` argument. - - from ..pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker - - feature_extractor = AutoImageProcessor.from_pretrained( - "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only, torch_dtype=torch_dtype - ) - safety_checker = StableDiffusionSafetyChecker.from_pretrained( - "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only, torch_dtype=torch_dtype - ) + from .single_file.single_file_utils import create_diffusers_clip_model_from_ldm - return {"safety_checker": safety_checker, "feature_extractor": feature_extractor} + deprecation_message = "Importing `create_diffusers_clip_model_from_ldm()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import create_diffusers_clip_model_from_ldm` instead." 
+ deprecate("diffusers.loaders.single_file_utils.create_diffusers_clip_model_from_ldm", "0.36", deprecation_message) + return create_diffusers_clip_model_from_ldm( + cls, checkpoint, subfolder, config, torch_dtype, local_files_only, is_legacy_loading + ) # in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale; # while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation def swap_scale_shift(weight, dim): - shift, scale = weight.chunk(2, dim=0) - new_weight = torch.cat([scale, shift], dim=0) - return new_weight + from .single_file.single_file_utils import swap_scale_shift + + deprecation_message = "Importing `swap_scale_shift()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import swap_scale_shift` instead." + deprecate("diffusers.loaders.single_file_utils.swap_scale_shift", "0.36", deprecation_message) + return swap_scale_shift(weight, dim) def swap_proj_gate(weight): - proj, gate = weight.chunk(2, dim=0) - new_weight = torch.cat([gate, proj], dim=0) - return new_weight + from .single_file.single_file_utils import swap_proj_gate + + deprecation_message = "Importing `swap_proj_gate()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import swap_proj_gate` instead." + deprecate("diffusers.loaders.single_file_utils.swap_proj_gate", "0.36", deprecation_message) + return swap_proj_gate(weight) def get_attn2_layers(state_dict): - attn2_layers = [] - for key in state_dict.keys(): - if "attn2." in key: - # Extract the layer number from the key - layer_num = int(key.split(".")[1]) - attn2_layers.append(layer_num) + from .single_file.single_file_utils import get_attn2_layers - return tuple(sorted(set(attn2_layers))) + deprecation_message = "Importing `get_attn2_layers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import get_attn2_layers` instead." + deprecate("diffusers.loaders.single_file_utils.get_attn2_layers", "0.36", deprecation_message) + return get_attn2_layers(state_dict) def get_caption_projection_dim(state_dict): - caption_projection_dim = state_dict["context_embedder.weight"].shape[0] - return caption_projection_dim + from .single_file.single_file_utils import get_caption_projection_dim + + deprecation_message = "Importing `get_caption_projection_dim()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import get_caption_projection_dim` instead." + deprecate("diffusers.loaders.single_file_utils.get_caption_projection_dim", "0.36", deprecation_message) + return get_caption_projection_dim(state_dict) def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): - converted_state_dict = {} - keys = list(checkpoint.keys()) - for k in keys: - if "model.diffusion_model." in k: - checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) - - num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "joint_blocks" in k))[-1] + 1 # noqa: C401 - dual_attention_layers = get_attn2_layers(checkpoint) - - caption_projection_dim = get_caption_projection_dim(checkpoint) - has_qk_norm = any("ln_q" in key for key in checkpoint.keys()) - - # Positional and patch embeddings. 
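+    # The SD3 qkv chunking and adaLN scale/shift handling moved wholesale to the relocated helper delegated to below.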
- converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("pos_embed") - converted_state_dict["pos_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight") - converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias") - - # Timestep embeddings. - converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop( - "t_embedder.mlp.0.weight" - ) - converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias") - converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop( - "t_embedder.mlp.2.weight" - ) - converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias") - - # Context projections. - converted_state_dict["context_embedder.weight"] = checkpoint.pop("context_embedder.weight") - converted_state_dict["context_embedder.bias"] = checkpoint.pop("context_embedder.bias") - - # Pooled context projection. - converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("y_embedder.mlp.0.weight") - converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("y_embedder.mlp.0.bias") - converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop("y_embedder.mlp.2.weight") - converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("y_embedder.mlp.2.bias") - - # Transformer blocks 🎸. - for i in range(num_layers): - # Q, K, V - sample_q, sample_k, sample_v = torch.chunk( - checkpoint.pop(f"joint_blocks.{i}.x_block.attn.qkv.weight"), 3, dim=0 - ) - context_q, context_k, context_v = torch.chunk( - checkpoint.pop(f"joint_blocks.{i}.context_block.attn.qkv.weight"), 3, dim=0 - ) - sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk( - checkpoint.pop(f"joint_blocks.{i}.x_block.attn.qkv.bias"), 3, dim=0 - ) - context_q_bias, context_k_bias, context_v_bias = torch.chunk( - checkpoint.pop(f"joint_blocks.{i}.context_block.attn.qkv.bias"), 3, dim=0 - ) - - converted_state_dict[f"transformer_blocks.{i}.attn.to_q.weight"] = torch.cat([sample_q]) - converted_state_dict[f"transformer_blocks.{i}.attn.to_q.bias"] = torch.cat([sample_q_bias]) - converted_state_dict[f"transformer_blocks.{i}.attn.to_k.weight"] = torch.cat([sample_k]) - converted_state_dict[f"transformer_blocks.{i}.attn.to_k.bias"] = torch.cat([sample_k_bias]) - converted_state_dict[f"transformer_blocks.{i}.attn.to_v.weight"] = torch.cat([sample_v]) - converted_state_dict[f"transformer_blocks.{i}.attn.to_v.bias"] = torch.cat([sample_v_bias]) - - converted_state_dict[f"transformer_blocks.{i}.attn.add_q_proj.weight"] = torch.cat([context_q]) - converted_state_dict[f"transformer_blocks.{i}.attn.add_q_proj.bias"] = torch.cat([context_q_bias]) - converted_state_dict[f"transformer_blocks.{i}.attn.add_k_proj.weight"] = torch.cat([context_k]) - converted_state_dict[f"transformer_blocks.{i}.attn.add_k_proj.bias"] = torch.cat([context_k_bias]) - converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.weight"] = torch.cat([context_v]) - converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.bias"] = torch.cat([context_v_bias]) - - # qk norm - if has_qk_norm: - converted_state_dict[f"transformer_blocks.{i}.attn.norm_q.weight"] = checkpoint.pop( - f"joint_blocks.{i}.x_block.attn.ln_q.weight" - ) - converted_state_dict[f"transformer_blocks.{i}.attn.norm_k.weight"] = checkpoint.pop( - f"joint_blocks.{i}.x_block.attn.ln_k.weight" - ) - 
converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_q.weight"] = checkpoint.pop( - f"joint_blocks.{i}.context_block.attn.ln_q.weight" - ) - converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_k.weight"] = checkpoint.pop( - f"joint_blocks.{i}.context_block.attn.ln_k.weight" - ) - - # output projections. - converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.weight"] = checkpoint.pop( - f"joint_blocks.{i}.x_block.attn.proj.weight" - ) - converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.bias"] = checkpoint.pop( - f"joint_blocks.{i}.x_block.attn.proj.bias" - ) - if not (i == num_layers - 1): - converted_state_dict[f"transformer_blocks.{i}.attn.to_add_out.weight"] = checkpoint.pop( - f"joint_blocks.{i}.context_block.attn.proj.weight" - ) - converted_state_dict[f"transformer_blocks.{i}.attn.to_add_out.bias"] = checkpoint.pop( - f"joint_blocks.{i}.context_block.attn.proj.bias" - ) - - if i in dual_attention_layers: - # Q, K, V - sample_q2, sample_k2, sample_v2 = torch.chunk( - checkpoint.pop(f"joint_blocks.{i}.x_block.attn2.qkv.weight"), 3, dim=0 - ) - sample_q2_bias, sample_k2_bias, sample_v2_bias = torch.chunk( - checkpoint.pop(f"joint_blocks.{i}.x_block.attn2.qkv.bias"), 3, dim=0 - ) - converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = torch.cat([sample_q2]) - converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = torch.cat([sample_q2_bias]) - converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = torch.cat([sample_k2]) - converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = torch.cat([sample_k2_bias]) - converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = torch.cat([sample_v2]) - converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = torch.cat([sample_v2_bias]) - - # qk norm - if has_qk_norm: - converted_state_dict[f"transformer_blocks.{i}.attn2.norm_q.weight"] = checkpoint.pop( - f"joint_blocks.{i}.x_block.attn2.ln_q.weight" - ) - converted_state_dict[f"transformer_blocks.{i}.attn2.norm_k.weight"] = checkpoint.pop( - f"joint_blocks.{i}.x_block.attn2.ln_k.weight" - ) - - # output projections. - converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = checkpoint.pop( - f"joint_blocks.{i}.x_block.attn2.proj.weight" - ) - converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = checkpoint.pop( - f"joint_blocks.{i}.x_block.attn2.proj.bias" - ) - - # norms. - converted_state_dict[f"transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop( - f"joint_blocks.{i}.x_block.adaLN_modulation.1.weight" - ) - converted_state_dict[f"transformer_blocks.{i}.norm1.linear.bias"] = checkpoint.pop( - f"joint_blocks.{i}.x_block.adaLN_modulation.1.bias" - ) - if not (i == num_layers - 1): - converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.weight"] = checkpoint.pop( - f"joint_blocks.{i}.context_block.adaLN_modulation.1.weight" - ) - converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.bias"] = checkpoint.pop( - f"joint_blocks.{i}.context_block.adaLN_modulation.1.bias" - ) - else: - converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.weight"] = swap_scale_shift( - checkpoint.pop(f"joint_blocks.{i}.context_block.adaLN_modulation.1.weight"), - dim=caption_projection_dim, - ) - converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.bias"] = swap_scale_shift( - checkpoint.pop(f"joint_blocks.{i}.context_block.adaLN_modulation.1.bias"), - dim=caption_projection_dim, - ) - - # ffs. 
- converted_state_dict[f"transformer_blocks.{i}.ff.net.0.proj.weight"] = checkpoint.pop( - f"joint_blocks.{i}.x_block.mlp.fc1.weight" - ) - converted_state_dict[f"transformer_blocks.{i}.ff.net.0.proj.bias"] = checkpoint.pop( - f"joint_blocks.{i}.x_block.mlp.fc1.bias" - ) - converted_state_dict[f"transformer_blocks.{i}.ff.net.2.weight"] = checkpoint.pop( - f"joint_blocks.{i}.x_block.mlp.fc2.weight" - ) - converted_state_dict[f"transformer_blocks.{i}.ff.net.2.bias"] = checkpoint.pop( - f"joint_blocks.{i}.x_block.mlp.fc2.bias" - ) - if not (i == num_layers - 1): - converted_state_dict[f"transformer_blocks.{i}.ff_context.net.0.proj.weight"] = checkpoint.pop( - f"joint_blocks.{i}.context_block.mlp.fc1.weight" - ) - converted_state_dict[f"transformer_blocks.{i}.ff_context.net.0.proj.bias"] = checkpoint.pop( - f"joint_blocks.{i}.context_block.mlp.fc1.bias" - ) - converted_state_dict[f"transformer_blocks.{i}.ff_context.net.2.weight"] = checkpoint.pop( - f"joint_blocks.{i}.context_block.mlp.fc2.weight" - ) - converted_state_dict[f"transformer_blocks.{i}.ff_context.net.2.bias"] = checkpoint.pop( - f"joint_blocks.{i}.context_block.mlp.fc2.bias" - ) - - # Final blocks. - converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight") - converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias") - converted_state_dict["norm_out.linear.weight"] = swap_scale_shift( - checkpoint.pop("final_layer.adaLN_modulation.1.weight"), dim=caption_projection_dim - ) - converted_state_dict["norm_out.linear.bias"] = swap_scale_shift( - checkpoint.pop("final_layer.adaLN_modulation.1.bias"), dim=caption_projection_dim - ) + from .single_file.single_file_utils import convert_sd3_transformer_checkpoint_to_diffusers - return converted_state_dict + deprecation_message = "Importing `convert_sd3_transformer_checkpoint_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_sd3_transformer_checkpoint_to_diffusers` instead." + deprecate( + "diffusers.loaders.single_file_utils.convert_sd3_transformer_checkpoint_to_diffusers", + "0.36", + deprecation_message, + ) + return convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs) def is_t5_in_single_file(checkpoint): - if "text_encoders.t5xxl.transformer.shared.weight" in checkpoint: - return True + from .single_file.single_file_utils import is_t5_in_single_file - return False + deprecation_message = "Importing `is_t5_in_single_file()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import is_t5_in_single_file` instead." + deprecate("diffusers.loaders.single_file_utils.is_t5_in_single_file", "0.36", deprecation_message) + return is_t5_in_single_file(checkpoint) def convert_sd3_t5_checkpoint_to_diffusers(checkpoint): - keys = list(checkpoint.keys()) - text_model_dict = {} - - remove_prefixes = ["text_encoders.t5xxl.transformer."] + from .single_file.single_file_utils import convert_sd3_t5_checkpoint_to_diffusers - for key in keys: - for prefix in remove_prefixes: - if key.startswith(prefix): - diffusers_key = key.replace(prefix, "") - text_model_dict[diffusers_key] = checkpoint.get(key) - - return text_model_dict + deprecation_message = "Importing `convert_sd3_t5_checkpoint_to_diffusers()` from diffusers.loaders.single_file has been deprecated. 
Please use `from diffusers.loaders.single_file.single_file_utils import convert_sd3_t5_checkpoint_to_diffusers` instead." + deprecate( + "diffusers.loaders.single_file_utils.convert_sd3_t5_checkpoint_to_diffusers", "0.36", deprecation_message + ) + return convert_sd3_t5_checkpoint_to_diffusers(checkpoint) def create_diffusers_t5_model_from_checkpoint( @@ -2078,1218 +399,134 @@ def create_diffusers_t5_model_from_checkpoint( torch_dtype=None, local_files_only=None, ): - if config: - config = {"pretrained_model_name_or_path": config} - else: - config = fetch_diffusers_config(checkpoint) - - model_config = cls.config_class.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only) - ctx = init_empty_weights if is_accelerate_available() else nullcontext - with ctx(): - model = cls(model_config) + from .single_file.single_file_utils import create_diffusers_t5_model_from_checkpoint - diffusers_format_checkpoint = convert_sd3_t5_checkpoint_to_diffusers(checkpoint) - - if is_accelerate_available(): - load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) - else: - model.load_state_dict(diffusers_format_checkpoint) - - use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (torch_dtype == torch.float16) - if use_keep_in_fp32_modules: - keep_in_fp32_modules = model._keep_in_fp32_modules - else: - keep_in_fp32_modules = [] - - if keep_in_fp32_modules is not None: - for name, param in model.named_parameters(): - if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules): - # param = param.to(torch.float32) does not work here as only in the local scope. - param.data = param.data.to(torch.float32) - - return model + deprecation_message = "Importing `create_diffusers_t5_model_from_checkpoint()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import create_diffusers_t5_model_from_checkpoint` instead." + deprecate( + "diffusers.loaders.single_file_utils.create_diffusers_t5_model_from_checkpoint", "0.36", deprecation_message + ) + return create_diffusers_t5_model_from_checkpoint(cls, checkpoint, subfolder, config, torch_dtype, local_files_only) def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs): - converted_state_dict = {} - for k, v in checkpoint.items(): - if "pos_encoder" in k: - continue - - else: - converted_state_dict[ - k.replace(".norms.0", ".norm1") - .replace(".norms.1", ".norm2") - .replace(".ff_norm", ".norm3") - .replace(".attention_blocks.0", ".attn1") - .replace(".attention_blocks.1", ".attn2") - .replace(".temporal_transformer", "") - ] = v + from .single_file.single_file_utils import convert_animatediff_checkpoint_to_diffusers - return converted_state_dict + deprecation_message = "Importing `convert_animatediff_checkpoint_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_animatediff_checkpoint_to_diffusers` instead." + deprecate( + "diffusers.loaders.single_file_utils.convert_animatediff_checkpoint_to_diffusers", "0.36", deprecation_message + ) + return convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs) def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): - converted_state_dict = {} - keys = list(checkpoint.keys()) - - for k in keys: - if "model.diffusion_model." 
in k: - checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) - - num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1 # noqa: C401 - num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1 # noqa: C401 - mlp_ratio = 4.0 - inner_dim = 3072 - - # in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale; - # while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation - def swap_scale_shift(weight): - shift, scale = weight.chunk(2, dim=0) - new_weight = torch.cat([scale, shift], dim=0) - return new_weight - - ## time_text_embed.timestep_embedder <- time_in - converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop( - "time_in.in_layer.weight" - ) - converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("time_in.in_layer.bias") - converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop( - "time_in.out_layer.weight" - ) - converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("time_in.out_layer.bias") + from .single_file.single_file_utils import convert_flux_transformer_checkpoint_to_diffusers - ## time_text_embed.text_embedder <- vector_in - converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("vector_in.in_layer.weight") - converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("vector_in.in_layer.bias") - converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop( - "vector_in.out_layer.weight" - ) - converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("vector_in.out_layer.bias") - - # guidance - has_guidance = any("guidance" in k for k in checkpoint) - if has_guidance: - converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = checkpoint.pop( - "guidance_in.in_layer.weight" - ) - converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = checkpoint.pop( - "guidance_in.in_layer.bias" - ) - converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = checkpoint.pop( - "guidance_in.out_layer.weight" - ) - converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = checkpoint.pop( - "guidance_in.out_layer.bias" - ) - - # context_embedder - converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight") - converted_state_dict["context_embedder.bias"] = checkpoint.pop("txt_in.bias") - - # x_embedder - converted_state_dict["x_embedder.weight"] = checkpoint.pop("img_in.weight") - converted_state_dict["x_embedder.bias"] = checkpoint.pop("img_in.bias") - - # double transformer blocks - for i in range(num_layers): - block_prefix = f"transformer_blocks.{i}." - # norms. 
- ## norm1 - converted_state_dict[f"{block_prefix}norm1.linear.weight"] = checkpoint.pop( - f"double_blocks.{i}.img_mod.lin.weight" - ) - converted_state_dict[f"{block_prefix}norm1.linear.bias"] = checkpoint.pop( - f"double_blocks.{i}.img_mod.lin.bias" - ) - ## norm1_context - converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = checkpoint.pop( - f"double_blocks.{i}.txt_mod.lin.weight" - ) - converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = checkpoint.pop( - f"double_blocks.{i}.txt_mod.lin.bias" - ) - # Q, K, V - sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0) - context_q, context_k, context_v = torch.chunk( - checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0 - ) - sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk( - checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0 - ) - context_q_bias, context_k_bias, context_v_bias = torch.chunk( - checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0 - ) - converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q]) - converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias]) - converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k]) - converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias]) - converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v]) - converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias]) - converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q]) - converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias]) - converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k]) - converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias]) - converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v]) - converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias]) - # qk_norm - converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop( - f"double_blocks.{i}.img_attn.norm.query_norm.scale" - ) - converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop( - f"double_blocks.{i}.img_attn.norm.key_norm.scale" - ) - converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = checkpoint.pop( - f"double_blocks.{i}.txt_attn.norm.query_norm.scale" - ) - converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = checkpoint.pop( - f"double_blocks.{i}.txt_attn.norm.key_norm.scale" - ) - # ff img_mlp - converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = checkpoint.pop( - f"double_blocks.{i}.img_mlp.0.weight" - ) - converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.0.bias") - converted_state_dict[f"{block_prefix}ff.net.2.weight"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.weight") - converted_state_dict[f"{block_prefix}ff.net.2.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.bias") - converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = checkpoint.pop( - f"double_blocks.{i}.txt_mlp.0.weight" - ) - converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = checkpoint.pop( - f"double_blocks.{i}.txt_mlp.0.bias" - ) - converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = checkpoint.pop( - f"double_blocks.{i}.txt_mlp.2.weight" - ) - 
converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = checkpoint.pop( - f"double_blocks.{i}.txt_mlp.2.bias" - ) - # output projections. - converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = checkpoint.pop( - f"double_blocks.{i}.img_attn.proj.weight" - ) - converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = checkpoint.pop( - f"double_blocks.{i}.img_attn.proj.bias" - ) - converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = checkpoint.pop( - f"double_blocks.{i}.txt_attn.proj.weight" - ) - converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = checkpoint.pop( - f"double_blocks.{i}.txt_attn.proj.bias" - ) - - # single transfomer blocks - for i in range(num_single_layers): - block_prefix = f"single_transformer_blocks.{i}." - # norm.linear <- single_blocks.0.modulation.lin - converted_state_dict[f"{block_prefix}norm.linear.weight"] = checkpoint.pop( - f"single_blocks.{i}.modulation.lin.weight" - ) - converted_state_dict[f"{block_prefix}norm.linear.bias"] = checkpoint.pop( - f"single_blocks.{i}.modulation.lin.bias" - ) - # Q, K, V, mlp - mlp_hidden_dim = int(inner_dim * mlp_ratio) - split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim) - q, k, v, mlp = torch.split(checkpoint.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0) - q_bias, k_bias, v_bias, mlp_bias = torch.split( - checkpoint.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0 - ) - converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q]) - converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias]) - converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k]) - converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias]) - converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v]) - converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias]) - converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp]) - converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias]) - # qk norm - converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop( - f"single_blocks.{i}.norm.query_norm.scale" - ) - converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop( - f"single_blocks.{i}.norm.key_norm.scale" - ) - # output projections. - converted_state_dict[f"{block_prefix}proj_out.weight"] = checkpoint.pop(f"single_blocks.{i}.linear2.weight") - converted_state_dict[f"{block_prefix}proj_out.bias"] = checkpoint.pop(f"single_blocks.{i}.linear2.bias") - - converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight") - converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias") - converted_state_dict["norm_out.linear.weight"] = swap_scale_shift( - checkpoint.pop("final_layer.adaLN_modulation.1.weight") + deprecation_message = "Importing `convert_flux_transformer_checkpoint_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_flux_transformer_checkpoint_to_diffusers` instead." 
+ deprecate( + "diffusers.loaders.single_file_utils.convert_flux_transformer_checkpoint_to_diffusers", + "0.36", + deprecation_message, ) - converted_state_dict["norm_out.linear.bias"] = swap_scale_shift( - checkpoint.pop("final_layer.adaLN_modulation.1.bias") - ) - - return converted_state_dict + return convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs) def convert_ltx_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): - converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "vae" not in key} - - TRANSFORMER_KEYS_RENAME_DICT = { - "model.diffusion_model.": "", - "patchify_proj": "proj_in", - "adaln_single": "time_embed", - "q_norm": "norm_q", - "k_norm": "norm_k", - } + from .single_file.single_file_utils import convert_ltx_transformer_checkpoint_to_diffusers - TRANSFORMER_SPECIAL_KEYS_REMAP = {} - - for key in list(converted_state_dict.keys()): - new_key = key - for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): - new_key = new_key.replace(replace_key, rename_key) - converted_state_dict[new_key] = converted_state_dict.pop(key) - - for key in list(converted_state_dict.keys()): - for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items(): - if special_key not in key: - continue - handler_fn_inplace(key, converted_state_dict) - - return converted_state_dict + deprecation_message = "Importing `convert_ltx_transformer_checkpoint_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_ltx_transformer_checkpoint_to_diffusers` instead." + deprecate( + "diffusers.loaders.single_file_utils.convert_ltx_transformer_checkpoint_to_diffusers", + "0.36", + deprecation_message, + ) + return convert_ltx_transformer_checkpoint_to_diffusers(checkpoint, **kwargs) def convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs): - converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "vae." 
in key} - - def remove_keys_(key: str, state_dict): - state_dict.pop(key) - - VAE_KEYS_RENAME_DICT = { - # common - "vae.": "", - # decoder - "up_blocks.0": "mid_block", - "up_blocks.1": "up_blocks.0", - "up_blocks.2": "up_blocks.1.upsamplers.0", - "up_blocks.3": "up_blocks.1", - "up_blocks.4": "up_blocks.2.conv_in", - "up_blocks.5": "up_blocks.2.upsamplers.0", - "up_blocks.6": "up_blocks.2", - "up_blocks.7": "up_blocks.3.conv_in", - "up_blocks.8": "up_blocks.3.upsamplers.0", - "up_blocks.9": "up_blocks.3", - # encoder - "down_blocks.0": "down_blocks.0", - "down_blocks.1": "down_blocks.0.downsamplers.0", - "down_blocks.2": "down_blocks.0.conv_out", - "down_blocks.3": "down_blocks.1", - "down_blocks.4": "down_blocks.1.downsamplers.0", - "down_blocks.5": "down_blocks.1.conv_out", - "down_blocks.6": "down_blocks.2", - "down_blocks.7": "down_blocks.2.downsamplers.0", - "down_blocks.8": "down_blocks.3", - "down_blocks.9": "mid_block", - # common - "conv_shortcut": "conv_shortcut.conv", - "res_blocks": "resnets", - "norm3.norm": "norm3", - "per_channel_statistics.mean-of-means": "latents_mean", - "per_channel_statistics.std-of-means": "latents_std", - } - - VAE_091_RENAME_DICT = { - # decoder - "up_blocks.0": "mid_block", - "up_blocks.1": "up_blocks.0.upsamplers.0", - "up_blocks.2": "up_blocks.0", - "up_blocks.3": "up_blocks.1.upsamplers.0", - "up_blocks.4": "up_blocks.1", - "up_blocks.5": "up_blocks.2.upsamplers.0", - "up_blocks.6": "up_blocks.2", - "up_blocks.7": "up_blocks.3.upsamplers.0", - "up_blocks.8": "up_blocks.3", - # common - "last_time_embedder": "time_embedder", - "last_scale_shift_table": "scale_shift_table", - } - - VAE_095_RENAME_DICT = { - # decoder - "up_blocks.0": "mid_block", - "up_blocks.1": "up_blocks.0.upsamplers.0", - "up_blocks.2": "up_blocks.0", - "up_blocks.3": "up_blocks.1.upsamplers.0", - "up_blocks.4": "up_blocks.1", - "up_blocks.5": "up_blocks.2.upsamplers.0", - "up_blocks.6": "up_blocks.2", - "up_blocks.7": "up_blocks.3.upsamplers.0", - "up_blocks.8": "up_blocks.3", - # encoder - "down_blocks.0": "down_blocks.0", - "down_blocks.1": "down_blocks.0.downsamplers.0", - "down_blocks.2": "down_blocks.1", - "down_blocks.3": "down_blocks.1.downsamplers.0", - "down_blocks.4": "down_blocks.2", - "down_blocks.5": "down_blocks.2.downsamplers.0", - "down_blocks.6": "down_blocks.3", - "down_blocks.7": "down_blocks.3.downsamplers.0", - "down_blocks.8": "mid_block", - # common - "last_time_embedder": "time_embedder", - "last_scale_shift_table": "scale_shift_table", - } - - VAE_SPECIAL_KEYS_REMAP = { - "per_channel_statistics.channel": remove_keys_, - "per_channel_statistics.mean-of-means": remove_keys_, - "per_channel_statistics.mean-of-stds": remove_keys_, - } - - if converted_state_dict["vae.encoder.conv_out.conv.weight"].shape[1] == 2048: - VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT) - elif "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict: - VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT) - - for key in list(converted_state_dict.keys()): - new_key = key - for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items(): - new_key = new_key.replace(replace_key, rename_key) - converted_state_dict[new_key] = converted_state_dict.pop(key) - - for key in list(converted_state_dict.keys()): - for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items(): - if special_key not in key: - continue - handler_fn_inplace(key, converted_state_dict) - - return converted_state_dict + from .single_file.single_file_utils import 
convert_ltx_vae_checkpoint_to_diffusers + + deprecation_message = "Importing `convert_ltx_vae_checkpoint_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_ltx_vae_checkpoint_to_diffusers` instead." + deprecate( + "diffusers.loaders.single_file_utils.convert_ltx_vae_checkpoint_to_diffusers", "0.36", deprecation_message + ) + return convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs) def convert_autoencoder_dc_checkpoint_to_diffusers(checkpoint, **kwargs): - converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())} - - def remap_qkv_(key: str, state_dict): - qkv = state_dict.pop(key) - q, k, v = torch.chunk(qkv, 3, dim=0) - parent_module, _, _ = key.rpartition(".qkv.conv.weight") - state_dict[f"{parent_module}.to_q.weight"] = q.squeeze() - state_dict[f"{parent_module}.to_k.weight"] = k.squeeze() - state_dict[f"{parent_module}.to_v.weight"] = v.squeeze() - - def remap_proj_conv_(key: str, state_dict): - parent_module, _, _ = key.rpartition(".proj.conv.weight") - state_dict[f"{parent_module}.to_out.weight"] = state_dict.pop(key).squeeze() - - AE_KEYS_RENAME_DICT = { - # common - "main.": "", - "op_list.": "", - "context_module": "attn", - "local_module": "conv_out", - # NOTE: The below two lines work because scales in the available configs only have a tuple length of 1 - # If there were more scales, there would be more layers, so a loop would be better to handle this - "aggreg.0.0": "to_qkv_multiscale.0.proj_in", - "aggreg.0.1": "to_qkv_multiscale.0.proj_out", - "depth_conv.conv": "conv_depth", - "inverted_conv.conv": "conv_inverted", - "point_conv.conv": "conv_point", - "point_conv.norm": "norm", - "conv.conv.": "conv.", - "conv1.conv": "conv1", - "conv2.conv": "conv2", - "conv2.norm": "norm", - "proj.norm": "norm_out", - # encoder - "encoder.project_in.conv": "encoder.conv_in", - "encoder.project_out.0.conv": "encoder.conv_out", - "encoder.stages": "encoder.down_blocks", - # decoder - "decoder.project_in.conv": "decoder.conv_in", - "decoder.project_out.0": "decoder.norm_out", - "decoder.project_out.2.conv": "decoder.conv_out", - "decoder.stages": "decoder.up_blocks", - } - - AE_F32C32_F64C128_F128C512_KEYS = { - "encoder.project_in.conv": "encoder.conv_in.conv", - "decoder.project_out.2.conv": "decoder.conv_out.conv", - } - - AE_SPECIAL_KEYS_REMAP = { - "qkv.conv.weight": remap_qkv_, - "proj.conv.weight": remap_proj_conv_, - } - if "encoder.project_in.conv.bias" not in converted_state_dict: - AE_KEYS_RENAME_DICT.update(AE_F32C32_F64C128_F128C512_KEYS) - - for key in list(converted_state_dict.keys()): - new_key = key[:] - for replace_key, rename_key in AE_KEYS_RENAME_DICT.items(): - new_key = new_key.replace(replace_key, rename_key) - converted_state_dict[new_key] = converted_state_dict.pop(key) - - for key in list(converted_state_dict.keys()): - for special_key, handler_fn_inplace in AE_SPECIAL_KEYS_REMAP.items(): - if special_key not in key: - continue - handler_fn_inplace(key, converted_state_dict) - - return converted_state_dict + from .single_file.single_file_utils import convert_autoencoder_dc_checkpoint_to_diffusers + + deprecation_message = "Importing `convert_autoencoder_dc_checkpoint_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_autoencoder_dc_checkpoint_to_diffusers` instead." 
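A pattern worth noting in the removed DC-AE converter (and several others in this hunk) is the fused-QKV split: one stacked projection is chunked into separate q/k/v entries. A toy sketch, with invented dimensions:

```python
# Fused-QKV split in isolation (dimensions are invented; the real handler
# also squeezes conv weights and pops the source key from the checkpoint).
import torch

dim = 4
qkv = torch.randn(3 * dim, dim)  # stands in for a "...qkv.conv.weight" entry

q, k, v = torch.chunk(qkv, 3, dim=0)
converted = {"attn.to_q.weight": q, "attn.to_k.weight": k, "attn.to_v.weight": v}

assert all(t.shape == (dim, dim) for t in converted.values())
```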
+ deprecate( + "diffusers.loaders.single_file_utils.convert_autoencoder_dc_checkpoint_to_diffusers", + "0.36", + deprecation_message, + ) + return convert_autoencoder_dc_checkpoint_to_diffusers(checkpoint, **kwargs) def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): - converted_state_dict = {} - - # Comfy checkpoints add this prefix - keys = list(checkpoint.keys()) - for k in keys: - if "model.diffusion_model." in k: - checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) - - # Convert patch_embed - converted_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight") - converted_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias") - - # Convert time_embed - converted_state_dict["time_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight") - converted_state_dict["time_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias") - converted_state_dict["time_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight") - converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias") - converted_state_dict["time_embed.pooler.to_kv.weight"] = checkpoint.pop("t5_y_embedder.to_kv.weight") - converted_state_dict["time_embed.pooler.to_kv.bias"] = checkpoint.pop("t5_y_embedder.to_kv.bias") - converted_state_dict["time_embed.pooler.to_q.weight"] = checkpoint.pop("t5_y_embedder.to_q.weight") - converted_state_dict["time_embed.pooler.to_q.bias"] = checkpoint.pop("t5_y_embedder.to_q.bias") - converted_state_dict["time_embed.pooler.to_out.weight"] = checkpoint.pop("t5_y_embedder.to_out.weight") - converted_state_dict["time_embed.pooler.to_out.bias"] = checkpoint.pop("t5_y_embedder.to_out.bias") - converted_state_dict["time_embed.caption_proj.weight"] = checkpoint.pop("t5_yproj.weight") - converted_state_dict["time_embed.caption_proj.bias"] = checkpoint.pop("t5_yproj.bias") - - # Convert transformer blocks - num_layers = 48 - for i in range(num_layers): - block_prefix = f"transformer_blocks.{i}." - old_prefix = f"blocks.{i}." 
- - # norm1 - converted_state_dict[block_prefix + "norm1.linear.weight"] = checkpoint.pop(old_prefix + "mod_x.weight") - converted_state_dict[block_prefix + "norm1.linear.bias"] = checkpoint.pop(old_prefix + "mod_x.bias") - if i < num_layers - 1: - converted_state_dict[block_prefix + "norm1_context.linear.weight"] = checkpoint.pop( - old_prefix + "mod_y.weight" - ) - converted_state_dict[block_prefix + "norm1_context.linear.bias"] = checkpoint.pop( - old_prefix + "mod_y.bias" - ) - else: - converted_state_dict[block_prefix + "norm1_context.linear_1.weight"] = checkpoint.pop( - old_prefix + "mod_y.weight" - ) - converted_state_dict[block_prefix + "norm1_context.linear_1.bias"] = checkpoint.pop( - old_prefix + "mod_y.bias" - ) - - # Visual attention - qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_x.weight") - q, k, v = qkv_weight.chunk(3, dim=0) - - converted_state_dict[block_prefix + "attn1.to_q.weight"] = q - converted_state_dict[block_prefix + "attn1.to_k.weight"] = k - converted_state_dict[block_prefix + "attn1.to_v.weight"] = v - converted_state_dict[block_prefix + "attn1.norm_q.weight"] = checkpoint.pop( - old_prefix + "attn.q_norm_x.weight" - ) - converted_state_dict[block_prefix + "attn1.norm_k.weight"] = checkpoint.pop( - old_prefix + "attn.k_norm_x.weight" - ) - converted_state_dict[block_prefix + "attn1.to_out.0.weight"] = checkpoint.pop( - old_prefix + "attn.proj_x.weight" - ) - converted_state_dict[block_prefix + "attn1.to_out.0.bias"] = checkpoint.pop(old_prefix + "attn.proj_x.bias") - - # Context attention - qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_y.weight") - q, k, v = qkv_weight.chunk(3, dim=0) - - converted_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q - converted_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k - converted_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v - converted_state_dict[block_prefix + "attn1.norm_added_q.weight"] = checkpoint.pop( - old_prefix + "attn.q_norm_y.weight" - ) - converted_state_dict[block_prefix + "attn1.norm_added_k.weight"] = checkpoint.pop( - old_prefix + "attn.k_norm_y.weight" - ) - if i < num_layers - 1: - converted_state_dict[block_prefix + "attn1.to_add_out.weight"] = checkpoint.pop( - old_prefix + "attn.proj_y.weight" - ) - converted_state_dict[block_prefix + "attn1.to_add_out.bias"] = checkpoint.pop( - old_prefix + "attn.proj_y.bias" - ) - - # MLP - converted_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate( - checkpoint.pop(old_prefix + "mlp_x.w1.weight") - ) - converted_state_dict[block_prefix + "ff.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_x.w2.weight") - if i < num_layers - 1: - converted_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate( - checkpoint.pop(old_prefix + "mlp_y.w1.weight") - ) - converted_state_dict[block_prefix + "ff_context.net.2.weight"] = checkpoint.pop( - old_prefix + "mlp_y.w2.weight" - ) - - # Output layers - converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(checkpoint.pop("final_layer.mod.weight"), dim=0) - converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(checkpoint.pop("final_layer.mod.bias"), dim=0) - converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight") - converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias") - - converted_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies") - - return converted_state_dict + from .single_file.single_file_utils import convert_mochi_transformer_checkpoint_to_diffusers + + 
deprecation_message = "Importing `convert_mochi_transformer_checkpoint_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_mochi_transformer_checkpoint_to_diffusers` instead." + deprecate( + "diffusers.loaders.single_file_utils.convert_mochi_transformer_checkpoint_to_diffusers", + "0.36", + deprecation_message, + ) + return convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs) def convert_hunyuan_video_transformer_to_diffusers(checkpoint, **kwargs): - def remap_norm_scale_shift_(key, state_dict): - weight = state_dict.pop(key) - shift, scale = weight.chunk(2, dim=0) - new_weight = torch.cat([scale, shift], dim=0) - state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight - - def remap_txt_in_(key, state_dict): - def rename_key(key): - new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks") - new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear") - new_key = new_key.replace("txt_in", "context_embedder") - new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1") - new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2") - new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder") - new_key = new_key.replace("mlp", "ff") - return new_key - - if "self_attn_qkv" in key: - weight = state_dict.pop(key) - to_q, to_k, to_v = weight.chunk(3, dim=0) - state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q - state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k - state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v - else: - state_dict[rename_key(key)] = state_dict.pop(key) - - def remap_img_attn_qkv_(key, state_dict): - weight = state_dict.pop(key) - to_q, to_k, to_v = weight.chunk(3, dim=0) - state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q - state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k - state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v - - def remap_txt_attn_qkv_(key, state_dict): - weight = state_dict.pop(key) - to_q, to_k, to_v = weight.chunk(3, dim=0) - state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q - state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k - state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v - - def remap_single_transformer_blocks_(key, state_dict): - hidden_size = 3072 - - if "linear1.weight" in key: - linear1_weight = state_dict.pop(key) - split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size) - q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0) - new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.weight") - state_dict[f"{new_key}.attn.to_q.weight"] = q - state_dict[f"{new_key}.attn.to_k.weight"] = k - state_dict[f"{new_key}.attn.to_v.weight"] = v - state_dict[f"{new_key}.proj_mlp.weight"] = mlp - - elif "linear1.bias" in key: - linear1_bias = state_dict.pop(key) - split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size) - q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0) - new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.bias") - state_dict[f"{new_key}.attn.to_q.bias"] = q_bias - state_dict[f"{new_key}.attn.to_k.bias"] = k_bias - state_dict[f"{new_key}.attn.to_v.bias"] = 
v_bias - state_dict[f"{new_key}.proj_mlp.bias"] = mlp_bias - - else: - new_key = key.replace("single_blocks", "single_transformer_blocks") - new_key = new_key.replace("linear2", "proj_out") - new_key = new_key.replace("q_norm", "attn.norm_q") - new_key = new_key.replace("k_norm", "attn.norm_k") - state_dict[new_key] = state_dict.pop(key) - - TRANSFORMER_KEYS_RENAME_DICT = { - "img_in": "x_embedder", - "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1", - "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2", - "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1", - "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2", - "vector_in.in_layer": "time_text_embed.text_embedder.linear_1", - "vector_in.out_layer": "time_text_embed.text_embedder.linear_2", - "double_blocks": "transformer_blocks", - "img_attn_q_norm": "attn.norm_q", - "img_attn_k_norm": "attn.norm_k", - "img_attn_proj": "attn.to_out.0", - "txt_attn_q_norm": "attn.norm_added_q", - "txt_attn_k_norm": "attn.norm_added_k", - "txt_attn_proj": "attn.to_add_out", - "img_mod.linear": "norm1.linear", - "img_norm1": "norm1.norm", - "img_norm2": "norm2", - "img_mlp": "ff", - "txt_mod.linear": "norm1_context.linear", - "txt_norm1": "norm1.norm", - "txt_norm2": "norm2_context", - "txt_mlp": "ff_context", - "self_attn_proj": "attn.to_out.0", - "modulation.linear": "norm.linear", - "pre_norm": "norm.norm", - "final_layer.norm_final": "norm_out.norm", - "final_layer.linear": "proj_out", - "fc1": "net.0.proj", - "fc2": "net.2", - "input_embedder": "proj_in", - } - - TRANSFORMER_SPECIAL_KEYS_REMAP = { - "txt_in": remap_txt_in_, - "img_attn_qkv": remap_img_attn_qkv_, - "txt_attn_qkv": remap_txt_attn_qkv_, - "single_blocks": remap_single_transformer_blocks_, - "final_layer.adaLN_modulation.1": remap_norm_scale_shift_, - } - - def update_state_dict_(state_dict, old_key, new_key): - state_dict[new_key] = state_dict.pop(old_key) - - for key in list(checkpoint.keys()): - new_key = key[:] - for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): - new_key = new_key.replace(replace_key, rename_key) - update_state_dict_(checkpoint, key, new_key) - - for key in list(checkpoint.keys()): - for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items(): - if special_key not in key: - continue - handler_fn_inplace(key, checkpoint) - - return checkpoint + from .single_file.single_file_utils import convert_hunyuan_video_transformer_to_diffusers + + deprecation_message = "Importing `convert_hunyuan_video_transformer_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_hunyuan_video_transformer_to_diffusers` instead." 
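The removed HunyuanVideo remapper's single-block handling shows the uneven split these converters use: q, k, and v each take `hidden_size` rows of the fused `linear1`, and the MLP takes whatever remains. A toy sketch:

```python
# Uneven torch.split over the fused linear1 weight (toy sizes; in the real
# model hidden_size is 3072 and the MLP remainder is much larger).
import torch

hidden_size = 4
linear1 = torch.zeros(3 * hidden_size + 8, hidden_size)  # 8 = toy MLP width

split_size = (hidden_size, hidden_size, hidden_size, linear1.size(0) - 3 * hidden_size)
q, k, v, mlp = torch.split(linear1, split_size, dim=0)

assert mlp.shape == (8, hidden_size)
```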
+ deprecate( + "diffusers.loaders.single_file_utils.convert_hunyuan_video_transformer_to_diffusers", + "0.36", + deprecation_message, + ) + return convert_hunyuan_video_transformer_to_diffusers(checkpoint, **kwargs) def convert_auraflow_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): - converted_state_dict = {} - state_dict_keys = list(checkpoint.keys()) - - # Handle register tokens and positional embeddings - converted_state_dict["register_tokens"] = checkpoint.pop("register_tokens", None) - - # Handle time step projection - converted_state_dict["time_step_proj.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight", None) - converted_state_dict["time_step_proj.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias", None) - converted_state_dict["time_step_proj.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight", None) - converted_state_dict["time_step_proj.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias", None) - - # Handle context embedder - converted_state_dict["context_embedder.weight"] = checkpoint.pop("cond_seq_linear.weight", None) - - # Calculate the number of layers - def calculate_layers(keys, key_prefix): - layers = set() - for k in keys: - if key_prefix in k: - layer_num = int(k.split(".")[1]) # get the layer number - layers.add(layer_num) - return len(layers) - - mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers") - single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers") - - # MMDiT blocks - for i in range(mmdit_layers): - # Feed-forward - path_mapping = {"mlpX": "ff", "mlpC": "ff_context"} - weight_mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"} - for orig_k, diffuser_k in path_mapping.items(): - for k, v in weight_mapping.items(): - converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.{v}.weight"] = checkpoint.pop( - f"double_layers.{i}.{orig_k}.{k}.weight", None - ) - - # Norms - path_mapping = {"modX": "norm1", "modC": "norm1_context"} - for orig_k, diffuser_k in path_mapping.items(): - converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.linear.weight"] = checkpoint.pop( - f"double_layers.{i}.{orig_k}.1.weight", None - ) - - # Attentions - x_attn_mapping = {"w2q": "to_q", "w2k": "to_k", "w2v": "to_v", "w2o": "to_out.0"} - context_attn_mapping = {"w1q": "add_q_proj", "w1k": "add_k_proj", "w1v": "add_v_proj", "w1o": "to_add_out"} - for attn_mapping in [x_attn_mapping, context_attn_mapping]: - for k, v in attn_mapping.items(): - converted_state_dict[f"joint_transformer_blocks.{i}.attn.{v}.weight"] = checkpoint.pop( - f"double_layers.{i}.attn.{k}.weight", None - ) - - # Single-DiT blocks - for i in range(single_dit_layers): - # Feed-forward - mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"} - for k, v in mapping.items(): - converted_state_dict[f"single_transformer_blocks.{i}.ff.{v}.weight"] = checkpoint.pop( - f"single_layers.{i}.mlp.{k}.weight", None - ) - - # Norms - converted_state_dict[f"single_transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop( - f"single_layers.{i}.modCX.1.weight", None - ) - - # Attentions - x_attn_mapping = {"w1q": "to_q", "w1k": "to_k", "w1v": "to_v", "w1o": "to_out.0"} - for k, v in x_attn_mapping.items(): - converted_state_dict[f"single_transformer_blocks.{i}.attn.{v}.weight"] = checkpoint.pop( - f"single_layers.{i}.attn.{k}.weight", None - ) - # Final blocks - converted_state_dict["proj_out.weight"] = checkpoint.pop("final_linear.weight", None) - - # Handle the 
final norm layer - norm_weight = checkpoint.pop("modF.1.weight", None) - if norm_weight is not None: - converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(norm_weight, dim=None) - else: - converted_state_dict["norm_out.linear.weight"] = None - - converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("positional_encoding") - converted_state_dict["pos_embed.proj.weight"] = checkpoint.pop("init_x_linear.weight") - converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("init_x_linear.bias") - - return converted_state_dict + from .single_file.single_file_utils import convert_auraflow_transformer_checkpoint_to_diffusers + + deprecation_message = "Importing `convert_auraflow_transformer_checkpoint_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_auraflow_transformer_checkpoint_to_diffusers` instead." + deprecate( + "diffusers.loaders.single_file_utils.convert_auraflow_transformer_checkpoint_to_diffusers", + "0.36", + deprecation_message, + ) + return convert_auraflow_transformer_checkpoint_to_diffusers(checkpoint, **kwargs) def convert_lumina2_to_diffusers(checkpoint, **kwargs): - converted_state_dict = {} - - # Original Lumina-Image-2 has an extra norm paramter that is unused - # We just remove it here - checkpoint.pop("norm_final.weight", None) - - # Comfy checkpoints add this prefix - keys = list(checkpoint.keys()) - for k in keys: - if "model.diffusion_model." in k: - checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) - - LUMINA_KEY_MAP = { - "cap_embedder": "time_caption_embed.caption_embedder", - "t_embedder.mlp.0": "time_caption_embed.timestep_embedder.linear_1", - "t_embedder.mlp.2": "time_caption_embed.timestep_embedder.linear_2", - "attention": "attn", - ".out.": ".to_out.0.", - "k_norm": "norm_k", - "q_norm": "norm_q", - "w1": "linear_1", - "w2": "linear_2", - "w3": "linear_3", - "adaLN_modulation.1": "norm1.linear", - } - ATTENTION_NORM_MAP = { - "attention_norm1": "norm1.norm", - "attention_norm2": "norm2", - } - CONTEXT_REFINER_MAP = { - "context_refiner.0.attention_norm1": "context_refiner.0.norm1", - "context_refiner.0.attention_norm2": "context_refiner.0.norm2", - "context_refiner.1.attention_norm1": "context_refiner.1.norm1", - "context_refiner.1.attention_norm2": "context_refiner.1.norm2", - } - FINAL_LAYER_MAP = { - "final_layer.adaLN_modulation.1": "norm_out.linear_1", - "final_layer.linear": "norm_out.linear_2", - } - - def convert_lumina_attn_to_diffusers(tensor, diffusers_key): - q_dim = 2304 - k_dim = v_dim = 768 - - to_q, to_k, to_v = torch.split(tensor, [q_dim, k_dim, v_dim], dim=0) - - return { - diffusers_key.replace("qkv", "to_q"): to_q, - diffusers_key.replace("qkv", "to_k"): to_k, - diffusers_key.replace("qkv", "to_v"): to_v, - } - - for key in keys: - diffusers_key = key - for k, v in CONTEXT_REFINER_MAP.items(): - diffusers_key = diffusers_key.replace(k, v) - for k, v in FINAL_LAYER_MAP.items(): - diffusers_key = diffusers_key.replace(k, v) - for k, v in ATTENTION_NORM_MAP.items(): - diffusers_key = diffusers_key.replace(k, v) - for k, v in LUMINA_KEY_MAP.items(): - diffusers_key = diffusers_key.replace(k, v) - - if "qkv" in diffusers_key: - converted_state_dict.update(convert_lumina_attn_to_diffusers(checkpoint.pop(key), diffusers_key)) - else: - converted_state_dict[diffusers_key] = checkpoint.pop(key) - - return converted_state_dict + from .single_file.single_file_utils import convert_lumina2_to_diffusers + + 
deprecation_message = "Importing `convert_lumina2_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_lumina2_to_diffusers` instead." + deprecate("diffusers.loaders.single_file_utils.convert_lumina2_to_diffusers", "0.36", deprecation_message) + return convert_lumina2_to_diffusers(checkpoint, **kwargs) def convert_sana_transformer_to_diffusers(checkpoint, **kwargs): - converted_state_dict = {} - keys = list(checkpoint.keys()) - for k in keys: - if "model.diffusion_model." in k: - checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) - - num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "blocks" in k))[-1] + 1 # noqa: C401 - - # Positional and patch embeddings. - checkpoint.pop("pos_embed") - converted_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight") - converted_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias") - - # Timestep embeddings. - converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = checkpoint.pop( - "t_embedder.mlp.0.weight" - ) - converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias") - converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = checkpoint.pop( - "t_embedder.mlp.2.weight" - ) - converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias") - converted_state_dict["time_embed.linear.weight"] = checkpoint.pop("t_block.1.weight") - converted_state_dict["time_embed.linear.bias"] = checkpoint.pop("t_block.1.bias") - - # Caption Projection. - checkpoint.pop("y_embedder.y_embedding") - converted_state_dict["caption_projection.linear_1.weight"] = checkpoint.pop("y_embedder.y_proj.fc1.weight") - converted_state_dict["caption_projection.linear_1.bias"] = checkpoint.pop("y_embedder.y_proj.fc1.bias") - converted_state_dict["caption_projection.linear_2.weight"] = checkpoint.pop("y_embedder.y_proj.fc2.weight") - converted_state_dict["caption_projection.linear_2.bias"] = checkpoint.pop("y_embedder.y_proj.fc2.bias") - converted_state_dict["caption_norm.weight"] = checkpoint.pop("attention_y_norm.weight") - - for i in range(num_layers): - converted_state_dict[f"transformer_blocks.{i}.scale_shift_table"] = checkpoint.pop( - f"blocks.{i}.scale_shift_table" - ) - - # Self-Attention - sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"blocks.{i}.attn.qkv.weight"), 3, dim=0) - converted_state_dict[f"transformer_blocks.{i}.attn1.to_q.weight"] = torch.cat([sample_q]) - converted_state_dict[f"transformer_blocks.{i}.attn1.to_k.weight"] = torch.cat([sample_k]) - converted_state_dict[f"transformer_blocks.{i}.attn1.to_v.weight"] = torch.cat([sample_v]) - - # Output Projections - converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.weight"] = checkpoint.pop( - f"blocks.{i}.attn.proj.weight" - ) - converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.bias"] = checkpoint.pop( - f"blocks.{i}.attn.proj.bias" - ) - - # Cross-Attention - converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = checkpoint.pop( - f"blocks.{i}.cross_attn.q_linear.weight" - ) - converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = checkpoint.pop( - f"blocks.{i}.cross_attn.q_linear.bias" - ) - - linear_sample_k, linear_sample_v = torch.chunk( - checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.weight"), 2, dim=0 - ) - linear_sample_k_bias, 
linear_sample_v_bias = torch.chunk( - checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.bias"), 2, dim=0 - ) - converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = linear_sample_k - converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = linear_sample_v - converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = linear_sample_k_bias - converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = linear_sample_v_bias - - # Output Projections - converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = checkpoint.pop( - f"blocks.{i}.cross_attn.proj.weight" - ) - converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = checkpoint.pop( - f"blocks.{i}.cross_attn.proj.bias" - ) - - # MLP - converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.weight"] = checkpoint.pop( - f"blocks.{i}.mlp.inverted_conv.conv.weight" - ) - converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.bias"] = checkpoint.pop( - f"blocks.{i}.mlp.inverted_conv.conv.bias" - ) - converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.weight"] = checkpoint.pop( - f"blocks.{i}.mlp.depth_conv.conv.weight" - ) - converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.bias"] = checkpoint.pop( - f"blocks.{i}.mlp.depth_conv.conv.bias" - ) - converted_state_dict[f"transformer_blocks.{i}.ff.conv_point.weight"] = checkpoint.pop( - f"blocks.{i}.mlp.point_conv.conv.weight" - ) - - # Final layer - converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight") - converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias") - converted_state_dict["scale_shift_table"] = checkpoint.pop("final_layer.scale_shift_table") - - return converted_state_dict + from .single_file.single_file_utils import convert_sana_transformer_to_diffusers + + deprecation_message = "Importing `convert_sana_transformer_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_sana_transformer_to_diffusers` instead." + deprecate("diffusers.loaders.single_file_utils.convert_sana_transformer_to_diffusers", "0.36", deprecation_message) + return convert_sana_transformer_to_diffusers(checkpoint, **kwargs) def convert_wan_transformer_to_diffusers(checkpoint, **kwargs): - converted_state_dict = {} - - keys = list(checkpoint.keys()) - for k in keys: - if "model.diffusion_model." 
in k: - checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) - - TRANSFORMER_KEYS_RENAME_DICT = { - "time_embedding.0": "condition_embedder.time_embedder.linear_1", - "time_embedding.2": "condition_embedder.time_embedder.linear_2", - "text_embedding.0": "condition_embedder.text_embedder.linear_1", - "text_embedding.2": "condition_embedder.text_embedder.linear_2", - "time_projection.1": "condition_embedder.time_proj", - "cross_attn": "attn2", - "self_attn": "attn1", - ".o.": ".to_out.0.", - ".q.": ".to_q.", - ".k.": ".to_k.", - ".v.": ".to_v.", - ".k_img.": ".add_k_proj.", - ".v_img.": ".add_v_proj.", - ".norm_k_img.": ".norm_added_k.", - "head.modulation": "scale_shift_table", - "head.head": "proj_out", - "modulation": "scale_shift_table", - "ffn.0": "ffn.net.0.proj", - "ffn.2": "ffn.net.2", - # Hack to swap the layer names - # The original model calls the norms in following order: norm1, norm3, norm2 - # We convert it to: norm1, norm2, norm3 - "norm2": "norm__placeholder", - "norm3": "norm2", - "norm__placeholder": "norm3", - # For the I2V model - "img_emb.proj.0": "condition_embedder.image_embedder.norm1", - "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj", - "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2", - "img_emb.proj.4": "condition_embedder.image_embedder.norm2", - } - - for key in list(checkpoint.keys()): - new_key = key[:] - for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): - new_key = new_key.replace(replace_key, rename_key) - - converted_state_dict[new_key] = checkpoint.pop(key) - - return converted_state_dict + from .single_file.single_file_utils import convert_wan_transformer_to_diffusers + + deprecation_message = "Importing `convert_wan_transformer_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_wan_transformer_to_diffusers` instead." 
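Most of the removed converters, the Wan one included, boil down to the same rename-map loop: walk the keys, apply ordered substring replacements, and rebuild the dict. Note the `norm__placeholder` entries above, which swap two names without clobbering either. A minimal sketch with invented keys:

```python
# Rename-map loop in isolation (keys invented; the real map is the
# TRANSFORMER_KEYS_RENAME_DICT in the removed code above).
import torch

RENAME = {"time_embedding.0": "condition_embedder.time_embedder.linear_1", ".q.": ".to_q."}
checkpoint = {"time_embedding.0.weight": torch.zeros(2), "blocks.0.attn.q.weight": torch.zeros(2)}

converted = {}
for key in list(checkpoint.keys()):
    new_key = key
    for old, new in RENAME.items():
        new_key = new_key.replace(old, new)
    converted[new_key] = checkpoint.pop(key)

assert "condition_embedder.time_embedder.linear_1.weight" in converted
assert "blocks.0.attn.to_q.weight" in converted
```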
+ deprecate("diffusers.loaders.single_file_utils.convert_wan_transformer_to_diffusers", "0.36", deprecation_message) + return convert_wan_transformer_to_diffusers(checkpoint, **kwargs) def convert_wan_vae_to_diffusers(checkpoint, **kwargs): - converted_state_dict = {} - - # Create mappings for specific components - middle_key_mapping = { - # Encoder middle block - "encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma", - "encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias", - "encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight", - "encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma", - "encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias", - "encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight", - "encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma", - "encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias", - "encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight", - "encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma", - "encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias", - "encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight", - # Decoder middle block - "decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma", - "decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias", - "decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight", - "decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma", - "decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias", - "decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight", - "decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma", - "decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias", - "decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight", - "decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma", - "decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias", - "decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight", - } - - # Create a mapping for attention blocks - attention_mapping = { - # Encoder middle attention - "encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma", - "encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight", - "encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias", - "encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight", - "encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias", - # Decoder middle attention - "decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma", - "decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight", - "decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias", - "decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight", - "decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias", - } - - # Create a mapping for the head components - head_mapping = { - # Encoder head - "encoder.head.0.gamma": "encoder.norm_out.gamma", - "encoder.head.2.bias": "encoder.conv_out.bias", - "encoder.head.2.weight": "encoder.conv_out.weight", - # Decoder head - 
"decoder.head.0.gamma": "decoder.norm_out.gamma", - "decoder.head.2.bias": "decoder.conv_out.bias", - "decoder.head.2.weight": "decoder.conv_out.weight", - } - - # Create a mapping for the quant components - quant_mapping = { - "conv1.weight": "quant_conv.weight", - "conv1.bias": "quant_conv.bias", - "conv2.weight": "post_quant_conv.weight", - "conv2.bias": "post_quant_conv.bias", - } - - # Process each key in the state dict - for key, value in checkpoint.items(): - # Handle middle block keys using the mapping - if key in middle_key_mapping: - new_key = middle_key_mapping[key] - converted_state_dict[new_key] = value - # Handle attention blocks using the mapping - elif key in attention_mapping: - new_key = attention_mapping[key] - converted_state_dict[new_key] = value - # Handle head keys using the mapping - elif key in head_mapping: - new_key = head_mapping[key] - converted_state_dict[new_key] = value - # Handle quant keys using the mapping - elif key in quant_mapping: - new_key = quant_mapping[key] - converted_state_dict[new_key] = value - # Handle encoder conv1 - elif key == "encoder.conv1.weight": - converted_state_dict["encoder.conv_in.weight"] = value - elif key == "encoder.conv1.bias": - converted_state_dict["encoder.conv_in.bias"] = value - # Handle decoder conv1 - elif key == "decoder.conv1.weight": - converted_state_dict["decoder.conv_in.weight"] = value - elif key == "decoder.conv1.bias": - converted_state_dict["decoder.conv_in.bias"] = value - # Handle encoder downsamples - elif key.startswith("encoder.downsamples."): - # Convert to down_blocks - new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.") - - # Convert residual block naming but keep the original structure - if ".residual.0.gamma" in new_key: - new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma") - elif ".residual.2.bias" in new_key: - new_key = new_key.replace(".residual.2.bias", ".conv1.bias") - elif ".residual.2.weight" in new_key: - new_key = new_key.replace(".residual.2.weight", ".conv1.weight") - elif ".residual.3.gamma" in new_key: - new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma") - elif ".residual.6.bias" in new_key: - new_key = new_key.replace(".residual.6.bias", ".conv2.bias") - elif ".residual.6.weight" in new_key: - new_key = new_key.replace(".residual.6.weight", ".conv2.weight") - elif ".shortcut.bias" in new_key: - new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias") - elif ".shortcut.weight" in new_key: - new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight") - - converted_state_dict[new_key] = value - - # Handle decoder upsamples - elif key.startswith("decoder.upsamples."): - # Convert to up_blocks - parts = key.split(".") - block_idx = int(parts[2]) - - # Group residual blocks - if "residual" in key: - if block_idx in [0, 1, 2]: - new_block_idx = 0 - resnet_idx = block_idx - elif block_idx in [4, 5, 6]: - new_block_idx = 1 - resnet_idx = block_idx - 4 - elif block_idx in [8, 9, 10]: - new_block_idx = 2 - resnet_idx = block_idx - 8 - elif block_idx in [12, 13, 14]: - new_block_idx = 3 - resnet_idx = block_idx - 12 - else: - # Keep as is for other blocks - converted_state_dict[key] = value - continue - - # Convert residual block naming - if ".residual.0.gamma" in key: - new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm1.gamma" - elif ".residual.2.bias" in key: - new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.bias" - elif ".residual.2.weight" in key: - new_key = 
f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.weight" - elif ".residual.3.gamma" in key: - new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm2.gamma" - elif ".residual.6.bias" in key: - new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.bias" - elif ".residual.6.weight" in key: - new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.weight" - else: - new_key = key - - converted_state_dict[new_key] = value - - # Handle shortcut connections - elif ".shortcut." in key: - if block_idx == 4: - new_key = key.replace(".shortcut.", ".resnets.0.conv_shortcut.") - new_key = new_key.replace("decoder.upsamples.4", "decoder.up_blocks.1") - else: - new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.") - new_key = new_key.replace(".shortcut.", ".conv_shortcut.") - - converted_state_dict[new_key] = value - - # Handle upsamplers - elif ".resample." in key or ".time_conv." in key: - if block_idx == 3: - new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.0.upsamplers.0") - elif block_idx == 7: - new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.1.upsamplers.0") - elif block_idx == 11: - new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.2.upsamplers.0") - else: - new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.") - - converted_state_dict[new_key] = value - else: - new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.") - converted_state_dict[new_key] = value - else: - # Keep other keys unchanged - converted_state_dict[key] = value - - return converted_state_dict + from .single_file.single_file_utils import convert_wan_vae_to_diffusers + + deprecation_message = "Importing `convert_wan_vae_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_wan_vae_to_diffusers` instead." + deprecate("diffusers.loaders.single_file_utils.convert_wan_vae_to_diffusers", "0.36", deprecation_message) + return convert_wan_vae_to_diffusers(checkpoint, **kwargs) diff --git a/src/diffusers/loaders/transformer_flux.py b/src/diffusers/loaders/transformer_flux.py index 38a8a7ebe266..8487b0652feb 100644 --- a/src/diffusers/loaders/transformer_flux.py +++ b/src/diffusers/loaders/transformer_flux.py @@ -11,170 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from contextlib import nullcontext -from ..models.embeddings import ( - ImageProjection, - MultiIPAdapterImageProjection, -) -from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta -from ..utils import ( - is_accelerate_available, - is_torch_version, - logging, -) +from ..utils import deprecate +from .ip_adapter.transformer_flux import FluxTransformer2DLoadersMixin -if is_accelerate_available(): - pass - -logger = logging.get_logger(__name__) - - -class FluxTransformer2DLoadersMixin: - """ - Load layers into a [`FluxTransformer2DModel`]. - """ - - def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT): - if low_cpu_mem_usage: - if is_accelerate_available(): - from accelerate import init_empty_weights - - else: - low_cpu_mem_usage = False - logger.warning( - "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" - " environment. 
Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" - " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" - " install accelerate\n```\n." - ) - - if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): - raise NotImplementedError( - "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" - " `low_cpu_mem_usage=False`." - ) - - updated_state_dict = {} - image_projection = None - init_context = init_empty_weights if low_cpu_mem_usage else nullcontext - - if "proj.weight" in state_dict: - # IP-Adapter - num_image_text_embeds = 4 - if state_dict["proj.weight"].shape[0] == 65536: - num_image_text_embeds = 16 - clip_embeddings_dim = state_dict["proj.weight"].shape[-1] - cross_attention_dim = state_dict["proj.weight"].shape[0] // num_image_text_embeds - - with init_context(): - image_projection = ImageProjection( - cross_attention_dim=cross_attention_dim, - image_embed_dim=clip_embeddings_dim, - num_image_text_embeds=num_image_text_embeds, - ) - - for key, value in state_dict.items(): - diffusers_name = key.replace("proj", "image_embeds") - updated_state_dict[diffusers_name] = value - - if not low_cpu_mem_usage: - image_projection.load_state_dict(updated_state_dict, strict=True) - else: - device_map = {"": self.device} - load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype) - - return image_projection - - def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT): - from ..models.attention_processor import ( - FluxIPAdapterJointAttnProcessor2_0, - ) - - if low_cpu_mem_usage: - if is_accelerate_available(): - from accelerate import init_empty_weights - - else: - low_cpu_mem_usage = False - logger.warning( - "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" - " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" - " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" - " install accelerate\n```\n." - ) - - if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): - raise NotImplementedError( - "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" - " `low_cpu_mem_usage=False`." 
- ) - - # set ip-adapter cross-attention processors & load state_dict - attn_procs = {} - key_id = 0 - init_context = init_empty_weights if low_cpu_mem_usage else nullcontext - for name in self.attn_processors.keys(): - if name.startswith("single_transformer_blocks"): - attn_processor_class = self.attn_processors[name].__class__ - attn_procs[name] = attn_processor_class() - else: - cross_attention_dim = self.config.joint_attention_dim - hidden_size = self.inner_dim - attn_processor_class = FluxIPAdapterJointAttnProcessor2_0 - num_image_text_embeds = [] - for state_dict in state_dicts: - if "proj.weight" in state_dict["image_proj"]: - num_image_text_embed = 4 - if state_dict["image_proj"]["proj.weight"].shape[0] == 65536: - num_image_text_embed = 16 - # IP-Adapter - num_image_text_embeds += [num_image_text_embed] - - with init_context(): - attn_procs[name] = attn_processor_class( - hidden_size=hidden_size, - cross_attention_dim=cross_attention_dim, - scale=1.0, - num_tokens=num_image_text_embeds, - dtype=self.dtype, - device=self.device, - ) - - value_dict = {} - for i, state_dict in enumerate(state_dicts): - value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]}) - value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]}) - value_dict.update({f"to_k_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_k_ip.bias"]}) - value_dict.update({f"to_v_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_v_ip.bias"]}) - - if not low_cpu_mem_usage: - attn_procs[name].load_state_dict(value_dict) - else: - device_map = {"": self.device} - dtype = self.dtype - load_model_dict_into_meta(attn_procs[name], value_dict, device_map=device_map, dtype=dtype) - - key_id += 1 - - return attn_procs - - def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT): - if not isinstance(state_dicts, list): - state_dicts = [state_dicts] - - self.encoder_hid_proj = None - - attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage) - self.set_attn_processor(attn_procs) - - image_projection_layers = [] - for state_dict in state_dicts: - image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers( - state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage - ) - image_projection_layers.append(image_projection_layer) - - self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers) - self.config.encoder_hid_dim_type = "ip_image_proj" +class FluxTransformer2DLoadersMixin(FluxTransformer2DLoadersMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `FluxTransformer2DLoadersMixin` from diffusers.loaders.transformer_flux has been deprecated. Please use `from diffusers.loaders.ip_adapter.transformer_flux import FluxTransformer2DLoadersMixin` instead." + deprecate("diffusers.loaders.transformer_flux.FluxTransformer2DLoadersMixin", "0.36", deprecation_message) + super().__init__(*args, **kwargs) diff --git a/src/diffusers/loaders/transformer_sd3.py b/src/diffusers/loaders/transformer_sd3.py index ece17e6728fa..e2e4c9fb67a3 100644 --- a/src/diffusers/loaders/transformer_sd3.py +++ b/src/diffusers/loaders/transformer_sd3.py @@ -11,160 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from contextlib import nullcontext -from typing import Dict +from ..utils import deprecate +from .ip_adapter.transformer_sd3 import SD3Transformer2DLoadersMixin -from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0 -from ..models.embeddings import IPAdapterTimeImageProjection -from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta -from ..utils import is_accelerate_available, is_torch_version, logging - -logger = logging.get_logger(__name__) - - -class SD3Transformer2DLoadersMixin: - """Load IP-Adapters and LoRA layers into a `[SD3Transformer2DModel]`.""" - - def _convert_ip_adapter_attn_to_diffusers( - self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT - ) -> Dict: - if low_cpu_mem_usage: - if is_accelerate_available(): - from accelerate import init_empty_weights - - else: - low_cpu_mem_usage = False - logger.warning( - "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" - " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" - " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" - " install accelerate\n```\n." - ) - - if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): - raise NotImplementedError( - "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" - " `low_cpu_mem_usage=False`." - ) - - # IP-Adapter cross attention parameters - hidden_size = self.config.attention_head_dim * self.config.num_attention_heads - ip_hidden_states_dim = self.config.attention_head_dim * self.config.num_attention_heads - timesteps_emb_dim = state_dict["0.norm_ip.linear.weight"].shape[1] - - # Dict where key is transformer layer index, value is attention processor's state dict - # ip_adapter state dict keys example: "0.norm_ip.linear.weight" - layer_state_dict = {idx: {} for idx in range(len(self.attn_processors))} - for key, weights in state_dict.items(): - idx, name = key.split(".", maxsplit=1) - layer_state_dict[int(idx)][name] = weights - - # Create IP-Adapter attention processor & load state_dict - attn_procs = {} - init_context = init_empty_weights if low_cpu_mem_usage else nullcontext - for idx, name in enumerate(self.attn_processors.keys()): - with init_context(): - attn_procs[name] = SD3IPAdapterJointAttnProcessor2_0( - hidden_size=hidden_size, - ip_hidden_states_dim=ip_hidden_states_dim, - head_dim=self.config.attention_head_dim, - timesteps_emb_dim=timesteps_emb_dim, - ) - - if not low_cpu_mem_usage: - attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True) - else: - device_map = {"": self.device} - load_model_dict_into_meta( - attn_procs[name], layer_state_dict[idx], device_map=device_map, dtype=self.dtype - ) - - return attn_procs - - def _convert_ip_adapter_image_proj_to_diffusers( - self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT - ) -> IPAdapterTimeImageProjection: - if low_cpu_mem_usage: - if is_accelerate_available(): - from accelerate import init_empty_weights - - else: - low_cpu_mem_usage = False - logger.warning( - "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" - " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" - " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" - " install accelerate\n```\n." 
- ) - - if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): - raise NotImplementedError( - "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" - " `low_cpu_mem_usage=False`." - ) - - init_context = init_empty_weights if low_cpu_mem_usage else nullcontext - - # Convert to diffusers - updated_state_dict = {} - for key, value in state_dict.items(): - # InstantX/SD3.5-Large-IP-Adapter - if key.startswith("layers."): - idx = key.split(".")[1] - key = key.replace(f"layers.{idx}.0.norm1", f"layers.{idx}.ln0") - key = key.replace(f"layers.{idx}.0.norm2", f"layers.{idx}.ln1") - key = key.replace(f"layers.{idx}.0.to_q", f"layers.{idx}.attn.to_q") - key = key.replace(f"layers.{idx}.0.to_kv", f"layers.{idx}.attn.to_kv") - key = key.replace(f"layers.{idx}.0.to_out", f"layers.{idx}.attn.to_out.0") - key = key.replace(f"layers.{idx}.1.0", f"layers.{idx}.adaln_norm") - key = key.replace(f"layers.{idx}.1.1", f"layers.{idx}.ff.net.0.proj") - key = key.replace(f"layers.{idx}.1.3", f"layers.{idx}.ff.net.2") - key = key.replace(f"layers.{idx}.2.1", f"layers.{idx}.adaln_proj") - updated_state_dict[key] = value - - # Image projetion parameters - embed_dim = updated_state_dict["proj_in.weight"].shape[1] - output_dim = updated_state_dict["proj_out.weight"].shape[0] - hidden_dim = updated_state_dict["proj_in.weight"].shape[0] - heads = updated_state_dict["layers.0.attn.to_q.weight"].shape[0] // 64 - num_queries = updated_state_dict["latents"].shape[1] - timestep_in_dim = updated_state_dict["time_embedding.linear_1.weight"].shape[1] - - # Image projection - with init_context(): - image_proj = IPAdapterTimeImageProjection( - embed_dim=embed_dim, - output_dim=output_dim, - hidden_dim=hidden_dim, - heads=heads, - num_queries=num_queries, - timestep_in_dim=timestep_in_dim, - ) - - if not low_cpu_mem_usage: - image_proj.load_state_dict(updated_state_dict, strict=True) - else: - device_map = {"": self.device} - load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype) - - return image_proj - - def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT) -> None: - """Sets IP-Adapter attention processors, image projection, and loads state_dict. - - Args: - state_dict (`Dict`): - State dict with keys "ip_adapter", which contains parameters for attention processors, and - "image_proj", which contains parameters for image projection net. - low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): - Speed up model loading only loading the pretrained weights and not initializing the weights. This also - tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. - Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this - argument to `True` will raise an error. - """ - - attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dict["ip_adapter"], low_cpu_mem_usage) - self.set_attn_processor(attn_procs) - - self.image_proj = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"], low_cpu_mem_usage) +class SD3Transformer2DLoadersMixin(SD3Transformer2DLoadersMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `SD3Transformer2DLoadersMixin` from diffusers.loaders.transformer_sd3 has been deprecated. Please use `from diffusers.loaders.ip_adapter.transformer_sd3 import SD3Transformer2DLoadersMixin` instead." 
+ deprecate("diffusers.loaders.ip_adapter.SD3Transformer2DLoadersMixin", "0.36", deprecation_message) + super().__init__(*args, **kwargs) diff --git a/src/diffusers/loaders/unet/__init__.py b/src/diffusers/loaders/unet/__init__.py new file mode 100644 index 000000000000..700ecbe53642 --- /dev/null +++ b/src/diffusers/loaders/unet/__init__.py @@ -0,0 +1,5 @@ +from ...utils import is_torch_available + + +if is_torch_available(): + from .unet import UNet2DConditionLoadersMixin diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet/unet.py similarity index 98% rename from src/diffusers/loaders/unet.py rename to src/diffusers/loaders/unet/unet.py index 1d8aba900c85..a632ceac210f 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet/unet.py @@ -22,7 +22,7 @@ import torch.nn.functional as F from huggingface_hub.utils import validate_hf_hub_args -from ..models.embeddings import ( +from ...models.embeddings import ( ImageProjection, IPAdapterFaceIDImageProjection, IPAdapterFaceIDPlusImageProjection, @@ -30,8 +30,8 @@ IPAdapterPlusImageProjection, MultiIPAdapterImageProjection, ) -from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict -from ..utils import ( +from ...models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict +from ...utils import ( USE_PEFT_BACKEND, _get_model_file, convert_unet_state_dict_to_peft, @@ -43,9 +43,9 @@ is_torch_version, logging, ) -from .lora_base import _func_optionally_disable_offloading -from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME -from .utils import AttnProcsLayers +from ..lora.lora_base import _func_optionally_disable_offloading +from ..lora.lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME +from ..utils import AttnProcsLayers logger = logging.get_logger(__name__) @@ -247,7 +247,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict # Unsafe code /> def _process_custom_diffusion(self, state_dict): - from ..models.attention_processor import CustomDiffusionAttnProcessor + from ...models.attention_processor import CustomDiffusionAttnProcessor attn_processors = {} custom_diffusion_grouped_dict = defaultdict(dict) @@ -395,7 +395,7 @@ def _process_lora( return is_model_cpu_offload, is_sequential_cpu_offload @classmethod - # Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading + # Copied from diffusers.loaders.lora.lora_base.LoraBaseMixin._optionally_disable_offloading def _optionally_disable_offloading(cls, _pipeline): """ Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU. 
@@ -451,7 +451,7 @@ def save_attn_procs( pipeline.unet.save_attn_procs("path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin") ``` """ - from ..models.attention_processor import ( + from ...models.attention_processor import ( CustomDiffusionAttnProcessor, CustomDiffusionAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor, @@ -513,7 +513,7 @@ def save_function(weights, filename): logger.info(f"Model weights saved in {save_path}") def _get_custom_diffusion_state_dict(self): - from ..models.attention_processor import ( + from ...models.attention_processor import ( CustomDiffusionAttnProcessor, CustomDiffusionAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor, @@ -759,7 +759,7 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us return image_projection def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT): - from ..models.attention_processor import ( + from ...models.attention_processor import ( IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor, diff --git a/src/diffusers/loaders/unet_loader_utils.py b/src/diffusers/loaders/unet/unet_loader_utils.py similarity index 98% rename from src/diffusers/loaders/unet_loader_utils.py rename to src/diffusers/loaders/unet/unet_loader_utils.py index 8f202ed4d44b..bfa935d13be8 100644 --- a/src/diffusers/loaders/unet_loader_utils.py +++ b/src/diffusers/loaders/unet/unet_loader_utils.py @@ -14,12 +14,12 @@ import copy from typing import TYPE_CHECKING, Dict, List, Union -from ..utils import logging +from ...utils import logging if TYPE_CHECKING: # import here to avoid circular imports - from ..models import UNet2DConditionModel + from ...models import UNet2DConditionModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py index 357df0c31087..119409d4c02c 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl.py @@ -17,8 +17,7 @@ import torch.nn as nn from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import PeftAdapterMixin -from ...loaders.single_file_model import FromOriginalModelMixin +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import deprecate from ...utils.accelerate_utils import apply_forward_hook from ..attention_processor import ( diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index e2b26396899f..0f4be88c92d3 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -21,7 +21,7 @@ import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders.single_file_model import FromOriginalModelMixin +from ...loaders import FromOriginalModelMixin from ...utils import logging from ...utils.accelerate_utils import apply_forward_hook from ..activations import get_activation diff --git a/src/diffusers/models/controlnets/controlnet.py b/src/diffusers/models/controlnets/controlnet.py index 7a6ca886caed..d1bf033459f0 100644 --- a/src/diffusers/models/controlnets/controlnet.py +++ b/src/diffusers/models/controlnets/controlnet.py @@ -19,7 +19,7 @@ from torch.nn import functional as F from ...configuration_utils import ConfigMixin, 
register_to_config -from ...loaders.single_file_model import FromOriginalModelMixin +from ...loaders import FromOriginalModelMixin from ...utils import BaseOutput, logging from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, diff --git a/src/diffusers/models/controlnets/controlnet_union.py b/src/diffusers/models/controlnets/controlnet_union.py index 26cb86718a21..0716934a7d8e 100644 --- a/src/diffusers/models/controlnets/controlnet_union.py +++ b/src/diffusers/models/controlnets/controlnet_union.py @@ -17,7 +17,7 @@ from torch import nn from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders.single_file_model import FromOriginalModelMixin +from ...loaders import FromOriginalModelMixin from ...utils import logging from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, diff --git a/src/diffusers/models/transformers/transformer_lumina2.py b/src/diffusers/models/transformers/transformer_lumina2.py index a873a6ec9444..25420394bdcd 100644 --- a/src/diffusers/models/transformers/transformer_lumina2.py +++ b/src/diffusers/models/transformers/transformer_lumina2.py @@ -20,8 +20,7 @@ import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import PeftAdapterMixin -from ...loaders.single_file_model import FromOriginalModelMixin +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ..attention import LuminaFeedForward from ..attention_processor import Attention diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index e6532f080d72..964c44ae4db4 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -19,8 +19,7 @@ import torch.nn as nn from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import PeftAdapterMixin -from ...loaders.single_file_model import FromOriginalModelMixin +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index 5674d8ba26ec..a35aa1a5b32c 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -19,8 +19,7 @@ import torch.utils.checkpoint from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin -from ...loaders.single_file_model import FromOriginalModelMixin +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers from ..activations import get_activation from ..attention_processor import ( diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index 860aa6511689..7124fcb41200 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -352,7 +352,7 @@ def test_with_norm_in_state_dict(self): _, _, inputs = self.get_dummy_inputs(with_generator=False) - logger = logging.get_logger("diffusers.loaders.lora_pipeline") + logger = 
logging.get_logger("diffusers.loaders.lora.lora_pipeline") logger.setLevel(logging.INFO) original_output = pipe(**inputs, generator=torch.manual_seed(0))[0] @@ -403,7 +403,7 @@ def test_lora_parameter_expanded_shapes(self): _, _, inputs = self.get_dummy_inputs(with_generator=False) original_out = pipe(**inputs, generator=torch.manual_seed(0))[0] - logger = logging.get_logger("diffusers.loaders.lora_pipeline") + logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline") logger.setLevel(logging.DEBUG) # Change the transformer config to mimic a real use case. @@ -486,7 +486,7 @@ def test_normal_lora_with_expanded_lora_raises_error(self): pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - logger = logging.get_logger("diffusers.loaders.lora_pipeline") + logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline") logger.setLevel(logging.DEBUG) out_features, in_features = pipe.transformer.x_embedder.weight.shape @@ -541,7 +541,7 @@ def test_normal_lora_with_expanded_lora_raises_error(self): pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - logger = logging.get_logger("diffusers.loaders.lora_pipeline") + logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline") logger.setLevel(logging.DEBUG) out_features, in_features = pipe.transformer.x_embedder.weight.shape @@ -590,7 +590,7 @@ def test_fuse_expanded_lora_with_regular_lora(self): pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - logger = logging.get_logger("diffusers.loaders.lora_pipeline") + logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline") logger.setLevel(logging.DEBUG) out_features, in_features = pipe.transformer.x_embedder.weight.shape @@ -653,7 +653,7 @@ def test_load_regular_lora(self): "transformer.x_embedder.lora_B.weight": normal_lora_B.weight, } - logger = logging.get_logger("diffusers.loaders.lora_pipeline") + logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline") logger.setLevel(logging.INFO) with CaptureLogger(logger) as cap_logger: pipe.load_lora_weights(lora_state_dict, "adapter-1") @@ -668,7 +668,7 @@ def test_load_regular_lora(self): def test_lora_unload_with_parameter_expanded_shapes(self): components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) - logger = logging.get_logger("diffusers.loaders.lora_pipeline") + logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline") logger.setLevel(logging.DEBUG) # Change the transformer config to mimic a real use case. @@ -734,7 +734,7 @@ def test_lora_unload_with_parameter_expanded_shapes(self): def test_lora_unload_with_parameter_expanded_shapes_and_no_reset(self): components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) - logger = logging.get_logger("diffusers.loaders.lora_pipeline") + logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline") logger.setLevel(logging.DEBUG) # Change the transformer config to mimic a real use case. 
diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 87a8fddfa583..fe132944a835 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -1017,7 +1017,7 @@ def test_multiple_wrong_adapter_name_raises_error(self): ) scale_with_wrong_components = {"foo": 0.0, "bar": 0.0, "tik": 0.0} - logger = logging.get_logger("diffusers.loaders.lora_base") + logger = logging.get_logger("diffusers.loaders.lora.lora_base") logger.setLevel(30) with CaptureLogger(logger) as cap_logger: pipe.set_adapters(adapter_name, adapter_weights=scale_with_wrong_components) @@ -1824,7 +1824,7 @@ def test_logs_info_when_no_lora_keys_found(self): elif lora_module == "text_encoder_2": prefix = "text_encoder_2" - logger = logging.get_logger("diffusers.loaders.lora_base") + logger = logging.get_logger("diffusers.loaders.lora.lora_base") logger.setLevel(logging.WARNING) with CaptureLogger(logger) as cap_logger: @@ -1925,7 +1925,7 @@ def test_lora_B_bias(self): _, _, inputs = self.get_dummy_inputs(with_generator=False) - logger = logging.get_logger("diffusers.loaders.lora_pipeline") + logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline") logger.setLevel(logging.INFO) original_output = pipe(**inputs, generator=torch.manual_seed(0))[0] diff --git a/tests/single_file/single_file_testing_utils.py b/tests/single_file/single_file_testing_utils.py index 4e1713c9ceb1..10c229ba5e64 100644 --- a/tests/single_file/single_file_testing_utils.py +++ b/tests/single_file/single_file_testing_utils.py @@ -5,7 +5,7 @@ import torch from huggingface_hub import hf_hub_download, snapshot_download -from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name +from diffusers.loaders.single_file.single_file_utils import _extract_repo_id_and_weights_name from diffusers.models.attention_processor import AttnProcessor from diffusers.utils.testing_utils import ( numpy_cosine_similarity_distance, diff --git a/tests/single_file/test_model_vae_single_file.py b/tests/single_file/test_model_vae_single_file.py index bba1726ae380..350a433cf68b 100644 --- a/tests/single_file/test_model_vae_single_file.py +++ b/tests/single_file/test_model_vae_single_file.py @@ -18,9 +18,7 @@ import torch -from diffusers import ( - AutoencoderKL, -) +from diffusers import AutoencoderKL from diffusers.utils.testing_utils import ( backend_empty_cache, enable_full_determinism, diff --git a/tests/single_file/test_model_wan_autoencoder_single_file.py b/tests/single_file/test_model_wan_autoencoder_single_file.py index 7f0e1c1a4b0b..83acbdc56504 100644 --- a/tests/single_file/test_model_wan_autoencoder_single_file.py +++ b/tests/single_file/test_model_wan_autoencoder_single_file.py @@ -16,9 +16,7 @@ import gc import unittest -from diffusers import ( - AutoencoderKLWan, -) +from diffusers import AutoencoderKLWan from diffusers.utils.testing_utils import ( backend_empty_cache, enable_full_determinism, diff --git a/tests/single_file/test_model_wan_transformer3d_single_file.py b/tests/single_file/test_model_wan_transformer3d_single_file.py index 36f0919cacb5..8a20135a8aab 100644 --- a/tests/single_file/test_model_wan_transformer3d_single_file.py +++ b/tests/single_file/test_model_wan_transformer3d_single_file.py @@ -18,9 +18,7 @@ import torch -from diffusers import ( - WanTransformer3DModel, -) +from diffusers import WanTransformer3DModel from diffusers.utils.testing_utils import ( backend_empty_cache, enable_full_determinism, diff --git a/tests/single_file/test_sana_transformer.py b/tests/single_file/test_sana_transformer.py index 
802ca37abfc3..061a5698d33a 100644 --- a/tests/single_file/test_sana_transformer.py +++ b/tests/single_file/test_sana_transformer.py @@ -3,9 +3,7 @@ import torch -from diffusers import ( - SanaTransformer2DModel, -) +from diffusers import SanaTransformer2DModel from diffusers.utils.testing_utils import ( backend_empty_cache, enable_full_determinism, diff --git a/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py b/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py index 7589b48028c2..29972241151d 100644 --- a/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py +++ b/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py @@ -5,7 +5,7 @@ import torch from diffusers import ControlNetModel, StableDiffusionControlNetPipeline -from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name +from diffusers.loaders.single_file.single_file_utils import _extract_repo_id_and_weights_name from diffusers.utils import load_image from diffusers.utils.testing_utils import ( backend_empty_cache, diff --git a/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py b/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py index 1555831db6db..166365a9a064 100644 --- a/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py +++ b/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py @@ -5,7 +5,7 @@ import torch from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline -from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name +from diffusers.loaders.single_file.single_file_utils import _extract_repo_id_and_weights_name from diffusers.utils import load_image from diffusers.utils.testing_utils import ( backend_empty_cache, diff --git a/tests/single_file/test_stable_diffusion_controlnet_single_file.py b/tests/single_file/test_stable_diffusion_controlnet_single_file.py index 2c1e414e5e36..35ce5daac250 100644 --- a/tests/single_file/test_stable_diffusion_controlnet_single_file.py +++ b/tests/single_file/test_stable_diffusion_controlnet_single_file.py @@ -5,7 +5,7 @@ import torch from diffusers import ControlNetModel, StableDiffusionControlNetPipeline -from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name +from diffusers.loaders.single_file.single_file_utils import _extract_repo_id_and_weights_name from diffusers.utils import load_image from diffusers.utils.testing_utils import ( backend_empty_cache, diff --git a/tests/single_file/test_stable_diffusion_single_file.py b/tests/single_file/test_stable_diffusion_single_file.py index 78baeb94929c..338842e8d08e 100644 --- a/tests/single_file/test_stable_diffusion_single_file.py +++ b/tests/single_file/test_stable_diffusion_single_file.py @@ -5,7 +5,7 @@ import torch from diffusers import EulerDiscreteScheduler, StableDiffusionInstructPix2PixPipeline, StableDiffusionPipeline -from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name +from diffusers.loaders.single_file.single_file_utils import _extract_repo_id_and_weights_name from diffusers.utils import load_image from diffusers.utils.testing_utils import ( backend_empty_cache, diff --git a/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py b/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py index fb5f8725b86e..06a4f0bf5b3a 100644 --- a/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py +++ 
b/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py @@ -8,7 +8,7 @@ StableDiffusionXLAdapterPipeline, T2IAdapter, ) -from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name +from diffusers.loaders.single_file.single_file_utils import _extract_repo_id_and_weights_name from diffusers.utils import load_image from diffusers.utils.testing_utils import ( backend_empty_cache, diff --git a/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py b/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py index 6d8c4369e1e1..9ea2fcc29710 100644 --- a/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py +++ b/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py @@ -5,7 +5,7 @@ import torch from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline -from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name +from diffusers.loaders.single_file.single_file_utils import _extract_repo_id_and_weights_name from diffusers.utils import load_image from diffusers.utils.testing_utils import ( backend_empty_cache, diff --git a/utils/check_support_list.py b/utils/check_support_list.py index 5d06f0cb92a3..6f548796893c 100644 --- a/utils/check_support_list.py +++ b/utils/check_support_list.py @@ -98,7 +98,7 @@ def check_documentation(doc_path, src_path, doc_regex, src_regex, exclude_condit }, "LoRA Mixins": { "doc_path": "docs/source/en/api/loaders/lora.md", - "src_path": "src/diffusers/loaders/lora_pipeline.py", + "src_path": "src/diffusers/loaders/lora/lora_pipeline.py", "doc_regex": r"\[\[autodoc\]\]\s([^\n]+)", "src_regex": r"class\s+(\w+LoraLoaderMixin(?:\d*_?\d*))[:(]", },
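Taken together, the forwarding shims earlier in this patch keep every pre-existing import path alive until the stated 0.36 removal, while the real implementations move one package level deeper. A hedged sketch of what the two paths resolve to under this layout (the subclass relationship comes straight from the shim definitions above; aliases are illustrative):

    # Deprecated location: a thin subclass whose __init__ calls deprecate(),
    # so the warning fires when the mixin is actually instantiated.
    from diffusers.loaders.transformer_sd3 import SD3Transformer2DLoadersMixin as OldMixin

    # New location: the real implementation, with no deprecation machinery.
    from diffusers.loaders.ip_adapter.transformer_sd3 import SD3Transformer2DLoadersMixin as NewMixin

    # The shim subclasses the relocated class, so isinstance/issubclass checks
    # against the old name keep working during the transition window.
    assert issubclass(OldMixin, NewMixin) and OldMixin is not NewMixin
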