diff --git a/docs/source/en/api/loaders/ip_adapter.md b/docs/source/en/api/loaders/ip_adapter.md
index 946a8b1af875..260762f36d50 100644
--- a/docs/source/en/api/loaders/ip_adapter.md
+++ b/docs/source/en/api/loaders/ip_adapter.md
@@ -22,11 +22,11 @@ Learn how to load an IP-Adapter checkpoint and image in the IP-Adapter [loading]
## IPAdapterMixin
-[[autodoc]] loaders.ip_adapter.IPAdapterMixin
+[[autodoc]] loaders.ip_adapter.ip_adapter.IPAdapterMixin
## SD3IPAdapterMixin
-[[autodoc]] loaders.ip_adapter.SD3IPAdapterMixin
+[[autodoc]] loaders.ip_adapter.ip_adapter.SD3IPAdapterMixin
- all
- is_ip_adapter_active
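The mixins keep their public API after the move; only the autodoc module path changes. A minimal usage sketch (repo id, subfolder, and weight file are illustrative, borrowed from the standard SDXL IP-Adapter example):

```python
import torch
from diffusers import AutoPipelineForText2Image

# IPAdapterMixin is mixed into the pipeline class, so loading is unaffected by the
# relocation to loaders/ip_adapter/ip_adapter.py documented above.
pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.safetensors")
pipe.set_ip_adapter_scale(0.6)
```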
diff --git a/docs/source/en/api/loaders/lora.md b/docs/source/en/api/loaders/lora.md
index 1c716f6d5e85..6b14c0b25869 100644
--- a/docs/source/en/api/loaders/lora.md
+++ b/docs/source/en/api/loaders/lora.md
@@ -39,58 +39,58 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
## StableDiffusionLoraLoaderMixin
-[[autodoc]] loaders.lora_pipeline.StableDiffusionLoraLoaderMixin
+[[autodoc]] loaders.lora.lora_pipeline.StableDiffusionLoraLoaderMixin
## StableDiffusionXLLoraLoaderMixin
-[[autodoc]] loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin
+[[autodoc]] loaders.lora.lora_pipeline.StableDiffusionXLLoraLoaderMixin
## SD3LoraLoaderMixin
-[[autodoc]] loaders.lora_pipeline.SD3LoraLoaderMixin
+[[autodoc]] loaders.lora.lora_pipeline.SD3LoraLoaderMixin
## FluxLoraLoaderMixin
-[[autodoc]] loaders.lora_pipeline.FluxLoraLoaderMixin
+[[autodoc]] loaders.lora.lora_pipeline.FluxLoraLoaderMixin
## CogVideoXLoraLoaderMixin
-[[autodoc]] loaders.lora_pipeline.CogVideoXLoraLoaderMixin
+[[autodoc]] loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin
## Mochi1LoraLoaderMixin
-[[autodoc]] loaders.lora_pipeline.Mochi1LoraLoaderMixin
+[[autodoc]] loaders.lora.lora_pipeline.Mochi1LoraLoaderMixin
## AuraFlowLoraLoaderMixin
-[[autodoc]] loaders.lora_pipeline.AuraFlowLoraLoaderMixin
+[[autodoc]] loaders.lora.lora_pipeline.AuraFlowLoraLoaderMixin
## LTXVideoLoraLoaderMixin
-[[autodoc]] loaders.lora_pipeline.LTXVideoLoraLoaderMixin
+[[autodoc]] loaders.lora.lora_pipeline.LTXVideoLoraLoaderMixin
## SanaLoraLoaderMixin
-[[autodoc]] loaders.lora_pipeline.SanaLoraLoaderMixin
+[[autodoc]] loaders.lora.lora_pipeline.SanaLoraLoaderMixin
## HunyuanVideoLoraLoaderMixin
-[[autodoc]] loaders.lora_pipeline.HunyuanVideoLoraLoaderMixin
+[[autodoc]] loaders.lora.lora_pipeline.HunyuanVideoLoraLoaderMixin
## Lumina2LoraLoaderMixin
-[[autodoc]] loaders.lora_pipeline.Lumina2LoraLoaderMixin
+[[autodoc]] loaders.lora.lora_pipeline.Lumina2LoraLoaderMixin
## CogView4LoraLoaderMixin
-[[autodoc]] loaders.lora_pipeline.CogView4LoraLoaderMixin
+[[autodoc]] loaders.lora.lora_pipeline.CogView4LoraLoaderMixin
## WanLoraLoaderMixin
-[[autodoc]] loaders.lora_pipeline.WanLoraLoaderMixin
+[[autodoc]] loaders.lora.lora_pipeline.WanLoraLoaderMixin
## AmusedLoraLoaderMixin
-[[autodoc]] loaders.lora_pipeline.AmusedLoraLoaderMixin
+[[autodoc]] loaders.lora.lora_pipeline.AmusedLoraLoaderMixin
## HiDreamImageLoraLoaderMixin
@@ -98,4 +98,4 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
## LoraBaseMixin
-[[autodoc]] loaders.lora_base.LoraBaseMixin
\ No newline at end of file
+[[autodoc]] loaders.lora.lora_base.LoraBaseMixin
\ No newline at end of file
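As above, only the documented module paths change here; the pipeline-level LoRA API is untouched. A brief sketch (checkpoint and LoRA repo ids are illustrative):

```python
import torch
from diffusers import DiffusionPipeline

# StableDiffusionXLLoraLoaderMixin (now documented under loaders.lora.lora_pipeline)
# provides load_lora_weights/fuse_lora on the SDXL pipelines.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("ostris/super-cereal-sdxl-lora", weight_name="cereal_box_sdxl_v1.safetensors")
pipe.fuse_lora(lora_scale=0.7)
```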
diff --git a/docs/source/en/api/loaders/transformer_sd3.md b/docs/source/en/api/loaders/transformer_sd3.md
index 4fc9603054b4..cb4f9bdf37b6 100644
--- a/docs/source/en/api/loaders/transformer_sd3.md
+++ b/docs/source/en/api/loaders/transformer_sd3.md
@@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License.
# SD3Transformer2D
-This class is useful when *only* loading weights into a [`SD3Transformer2DModel`]. If you need to load weights into the text encoder or a text encoder and SD3Transformer2DModel, check [`SD3LoraLoaderMixin`](lora#diffusers.loaders.SD3LoraLoaderMixin) class instead.
+This class is useful when *only* loading weights into a [`SD3Transformer2DModel`]. If you need to load weights into the text encoder, or into a text encoder and a [`SD3Transformer2DModel`], check the [`SD3LoraLoaderMixin`](lora#diffusers.loaders.SD3LoraLoaderMixin) class instead.
The [`SD3Transformer2DLoadersMixin`] class currently only loads IP-Adapter weights, but will be used in the future to save weights and load LoRAs.
@@ -24,6 +24,6 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
## SD3Transformer2DLoadersMixin
-[[autodoc]] loaders.transformer_sd3.SD3Transformer2DLoadersMixin
+[[autodoc]] loaders.ip_adapter.transformer_sd3.SD3Transformer2DLoadersMixin
- all
- _load_ip_adapter_weights
\ No newline at end of file
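For context, a pipeline-level sketch of how this mixin is exercised (model and adapter repo ids are illustrative): [`SD3IPAdapterMixin.load_ip_adapter`] loads the image encoder and feature extractor, then hands the transformer part of the checkpoint to [`SD3Transformer2DLoadersMixin._load_ip_adapter_weights`].

```python
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large", torch_dtype=torch.float16
).to("cuda")
# Internally, the transformer portion of this checkpoint is loaded through
# SD3Transformer2DLoadersMixin._load_ip_adapter_weights.
pipe.load_ip_adapter("InstantX/SD3.5-Large-IP-Adapter")
pipe.set_ip_adapter_scale(0.6)
```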
diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py
index 84c6d9f32c66..9b0dd6ea984b 100644
--- a/src/diffusers/loaders/__init__.py
+++ b/src/diffusers/loaders/__init__.py
@@ -54,14 +54,14 @@ def text_encoder_attn_modules(text_encoder):
_import_structure = {}
if is_torch_available():
- _import_structure["single_file_model"] = ["FromOriginalModelMixin"]
- _import_structure["transformer_flux"] = ["FluxTransformer2DLoadersMixin"]
- _import_structure["transformer_sd3"] = ["SD3Transformer2DLoadersMixin"]
- _import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
+ _import_structure["ip_adapter.transformer_flux"] = ["FluxTransformer2DLoadersMixin"]
+ _import_structure["ip_adapter.transformer_sd3"] = ["SD3Transformer2DLoadersMixin"]
+ _import_structure["single_file.single_file_model"] = ["FromOriginalModelMixin"]
+ _import_structure["unet.unet"] = ["UNet2DConditionLoadersMixin"]
_import_structure["utils"] = ["AttnProcsLayers"]
if is_transformers_available():
- _import_structure["single_file"] = ["FromSingleFileMixin"]
- _import_structure["lora_pipeline"] = [
+ _import_structure["single_file.single_file"] = ["FromSingleFileMixin"]
+ _import_structure["lora.lora_pipeline"] = [
"AmusedLoraLoaderMixin",
"StableDiffusionLoraLoaderMixin",
"SD3LoraLoaderMixin",
@@ -80,7 +80,7 @@ def text_encoder_attn_modules(text_encoder):
"HiDreamImageLoraLoaderMixin",
]
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
- _import_structure["ip_adapter"] = [
+ _import_structure["ip_adapter.ip_adapter"] = [
"IPAdapterMixin",
"FluxIPAdapterMixin",
"SD3IPAdapterMixin",
@@ -91,19 +91,14 @@ def text_encoder_attn_modules(text_encoder):
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if is_torch_available():
- from .single_file_model import FromOriginalModelMixin
- from .transformer_flux import FluxTransformer2DLoadersMixin
- from .transformer_sd3 import SD3Transformer2DLoadersMixin
+ from .ip_adapter import FluxTransformer2DLoadersMixin, SD3Transformer2DLoadersMixin
+ from .single_file import FromOriginalModelMixin
from .unet import UNet2DConditionLoadersMixin
from .utils import AttnProcsLayers
if is_transformers_available():
- from .ip_adapter import (
- FluxIPAdapterMixin,
- IPAdapterMixin,
- SD3IPAdapterMixin,
- )
- from .lora_pipeline import (
+ from .ip_adapter import FluxIPAdapterMixin, IPAdapterMixin, SD3IPAdapterMixin
+ from .lora import (
AmusedLoraLoaderMixin,
AuraFlowLoraLoaderMixin,
CogVideoXLoraLoaderMixin,
@@ -111,6 +106,7 @@ def text_encoder_attn_modules(text_encoder):
FluxLoraLoaderMixin,
HiDreamImageLoraLoaderMixin,
HunyuanVideoLoraLoaderMixin,
+ LoraBaseMixin,
LoraLoaderMixin,
LTXVideoLoraLoaderMixin,
Lumina2LoraLoaderMixin,
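For reviewers unfamiliar with the lazy-import setup: the dotted keys in `_import_structure` are module paths relative to `diffusers.loaders`, so `"ip_adapter.ip_adapter"` now points at `loaders/ip_adapter/ip_adapter.py`. A simplified sketch of how such a mapping is resolved (the real logic lives in diffusers' `_LazyModule`; the helper name here is illustrative):

```python
import importlib
from typing import Dict, List


def resolve(package: str, import_structure: Dict[str, List[str]], name: str):
    """Import the submodule that declares `name` and return the attribute."""
    for module_path, exported in import_structure.items():
        if name in exported:
            # e.g. "ip_adapter.ip_adapter" -> diffusers.loaders.ip_adapter.ip_adapter
            module = importlib.import_module(f".{module_path}", package)
            return getattr(module, name)
    raise AttributeError(f"{name!r} not found in {package}")


# resolve("diffusers.loaders", _import_structure, "IPAdapterMixin")
```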
diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py
index f4c48f254c44..07043114187e 100644
--- a/src/diffusers/loaders/ip_adapter.py
+++ b/src/diffusers/loaders/ip_adapter.py
@@ -12,868 +12,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from pathlib import Path
-from typing import Dict, List, Optional, Union
-import torch
-import torch.nn.functional as F
-from huggingface_hub.utils import validate_hf_hub_args
-from safetensors import safe_open
+from ..utils import deprecate
+from .ip_adapter import FluxIPAdapterMixin, IPAdapterMixin, SD3IPAdapterMixin
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
-from ..utils import (
- USE_PEFT_BACKEND,
- _get_detailed_type,
- _get_model_file,
- _is_valid_type,
- is_accelerate_available,
- is_torch_version,
- is_transformers_available,
- logging,
-)
-from .unet_loader_utils import _maybe_expand_lora_scales
+class IPAdapterMixin(IPAdapterMixin):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = "Importing `IPAdapterMixin` from diffusers.loaders.ip_adapter has been deprecated. Please use `from diffusers.loaders.ip_adapter.ip_adapter import IPAdapterMixin` instead."
+ deprecate("diffusers.loaders.ip_adapter.IPAdapterMixin", "0.36", deprecation_message)
+ super().__init__(*args, **kwargs)
-if is_transformers_available():
- from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, SiglipImageProcessor, SiglipVisionModel
-from ..models.attention_processor import (
- AttnProcessor,
- AttnProcessor2_0,
- FluxAttnProcessor2_0,
- FluxIPAdapterJointAttnProcessor2_0,
- IPAdapterAttnProcessor,
- IPAdapterAttnProcessor2_0,
- IPAdapterXFormersAttnProcessor,
- JointAttnProcessor2_0,
- SD3IPAdapterJointAttnProcessor2_0,
-)
+class FluxIPAdapterMixin(FluxIPAdapterMixin):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = "Importing `FluxIPAdapterMixin` from diffusers.loaders.ip_adapter has been deprecated. Please use `from diffusers.loaders.ip_adapter.ip_adapter import FluxIPAdapterMixin` instead."
+ deprecate("diffusers.loaders.ip_adapter.FluxIPAdapterMixin", "0.36", deprecation_message)
+ super().__init__(*args, **kwargs)
-logger = logging.get_logger(__name__)
-
-
-class IPAdapterMixin:
- """Mixin for handling IP Adapters."""
-
- @validate_hf_hub_args
- def load_ip_adapter(
- self,
- pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
- subfolder: Union[str, List[str]],
- weight_name: Union[str, List[str]],
- image_encoder_folder: Optional[str] = "image_encoder",
- **kwargs,
- ):
- """
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
- subfolder (`str` or `List[str]`):
- The subfolder location of a model file within a larger model repository on the Hub or locally. If a
- list is passed, it should have the same length as `weight_name`.
- weight_name (`str` or `List[str]`):
- The name of the weight file to load. If a list is passed, it should have the same length as
- `subfolder`.
- image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
- The subfolder location of the image encoder within a larger model repository on the Hub or locally.
- Pass `None` to not load the image encoder. If the image encoder is located in a folder inside
- `subfolder`, you only need to pass the name of the folder that contains image encoder weights, e.g.
- `image_encoder_folder="image_encoder"`. If the image encoder is located in a folder other than
- `subfolder`, you should pass the path to the folder that contains image encoder weights, for example,
- `image_encoder_folder="different_subfolder/image_encoder"`.
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
- Speed up model loading only loading the pretrained weights and not initializing the weights. This also
- tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
- Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
- argument to `True` will raise an error.
- """
-
- # handle the list inputs for multiple IP Adapters
- if not isinstance(weight_name, list):
- weight_name = [weight_name]
-
- if not isinstance(pretrained_model_name_or_path_or_dict, list):
- pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
- if len(pretrained_model_name_or_path_or_dict) == 1:
- pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
-
- if not isinstance(subfolder, list):
- subfolder = [subfolder]
- if len(subfolder) == 1:
- subfolder = subfolder * len(weight_name)
-
- if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
- raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")
-
- if len(weight_name) != len(subfolder):
- raise ValueError("`weight_name` and `subfolder` must have the same length.")
-
- # Load the main state dict first.
- cache_dir = kwargs.pop("cache_dir", None)
- force_download = kwargs.pop("force_download", False)
- proxies = kwargs.pop("proxies", None)
- local_files_only = kwargs.pop("local_files_only", None)
- token = kwargs.pop("token", None)
- revision = kwargs.pop("revision", None)
- low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
-
- if low_cpu_mem_usage and not is_accelerate_available():
- low_cpu_mem_usage = False
- logger.warning(
- "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
- " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
- " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
- " install accelerate\n```\n."
- )
-
- if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
- raise NotImplementedError(
- "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
- " `low_cpu_mem_usage=False`."
- )
-
- user_agent = {
- "file_type": "attn_procs_weights",
- "framework": "pytorch",
- }
- state_dicts = []
- for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
- pretrained_model_name_or_path_or_dict, weight_name, subfolder
- ):
- if not isinstance(pretrained_model_name_or_path_or_dict, dict):
- model_file = _get_model_file(
- pretrained_model_name_or_path_or_dict,
- weights_name=weight_name,
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- local_files_only=local_files_only,
- token=token,
- revision=revision,
- subfolder=subfolder,
- user_agent=user_agent,
- )
- if weight_name.endswith(".safetensors"):
- state_dict = {"image_proj": {}, "ip_adapter": {}}
- with safe_open(model_file, framework="pt", device="cpu") as f:
- for key in f.keys():
- if key.startswith("image_proj."):
- state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
- elif key.startswith("ip_adapter."):
- state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
- else:
- state_dict = load_state_dict(model_file)
- else:
- state_dict = pretrained_model_name_or_path_or_dict
-
- keys = list(state_dict.keys())
- if "image_proj" not in keys and "ip_adapter" not in keys:
- raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
-
- state_dicts.append(state_dict)
-
- # load CLIP image encoder here if it has not been registered to the pipeline yet
- if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
- if image_encoder_folder is not None:
- if not isinstance(pretrained_model_name_or_path_or_dict, dict):
- logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
- if image_encoder_folder.count("/") == 0:
- image_encoder_subfolder = Path(subfolder, image_encoder_folder).as_posix()
- else:
- image_encoder_subfolder = Path(image_encoder_folder).as_posix()
-
- image_encoder = CLIPVisionModelWithProjection.from_pretrained(
- pretrained_model_name_or_path_or_dict,
- subfolder=image_encoder_subfolder,
- low_cpu_mem_usage=low_cpu_mem_usage,
- cache_dir=cache_dir,
- local_files_only=local_files_only,
- torch_dtype=self.dtype,
- ).to(self.device)
- self.register_modules(image_encoder=image_encoder)
- else:
- raise ValueError(
- "`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
- )
- else:
- logger.warning(
- "image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
- "Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead."
- )
-
- # create feature extractor if it has not been registered to the pipeline yet
- if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
- # FaceID IP adapters don't need the image encoder so it's not present, in this case we default to 224
- default_clip_size = 224
- clip_image_size = (
- self.image_encoder.config.image_size if self.image_encoder is not None else default_clip_size
- )
- feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size)
- self.register_modules(feature_extractor=feature_extractor)
-
- # load ip-adapter into unet
- unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
- unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
-
- extra_loras = unet._load_ip_adapter_loras(state_dicts)
- if extra_loras != {}:
- if not USE_PEFT_BACKEND:
- logger.warning("PEFT backend is required to load these weights.")
- else:
- # apply the IP Adapter Face ID LoRA weights
- peft_config = getattr(unet, "peft_config", {})
- for k, lora in extra_loras.items():
- if f"faceid_{k}" not in peft_config:
- self.load_lora_weights(lora, adapter_name=f"faceid_{k}")
- self.set_adapters([f"faceid_{k}"], adapter_weights=[1.0])
-
- def set_ip_adapter_scale(self, scale):
- """
- Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
- granular control over each IP-Adapter behavior. A config can be a float or a dictionary.
-
- Example:
-
- ```py
- # To use original IP-Adapter
- scale = 1.0
- pipeline.set_ip_adapter_scale(scale)
-
- # To use style block only
- scale = {
- "up": {"block_0": [0.0, 1.0, 0.0]},
- }
- pipeline.set_ip_adapter_scale(scale)
-
- # To use style+layout blocks
- scale = {
- "down": {"block_2": [0.0, 1.0]},
- "up": {"block_0": [0.0, 1.0, 0.0]},
- }
- pipeline.set_ip_adapter_scale(scale)
-
- # To use style and layout from 2 reference images
- scales = [{"down": {"block_2": [0.0, 1.0]}}, {"up": {"block_0": [0.0, 1.0, 0.0]}}]
- pipeline.set_ip_adapter_scale(scales)
- ```
- """
- unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
- if not isinstance(scale, list):
- scale = [scale]
- scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0)
-
- for attn_name, attn_processor in unet.attn_processors.items():
- if isinstance(
- attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
- ):
- if len(scale_configs) != len(attn_processor.scale):
- raise ValueError(
- f"Cannot assign {len(scale_configs)} scale_configs to {len(attn_processor.scale)} IP-Adapter."
- )
- elif len(scale_configs) == 1:
- scale_configs = scale_configs * len(attn_processor.scale)
- for i, scale_config in enumerate(scale_configs):
- if isinstance(scale_config, dict):
- for k, s in scale_config.items():
- if attn_name.startswith(k):
- attn_processor.scale[i] = s
- else:
- attn_processor.scale[i] = scale_config
-
- def unload_ip_adapter(self):
- """
- Unloads the IP Adapter weights
-
- Examples:
-
- ```python
- >>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
- >>> pipeline.unload_ip_adapter()
- >>> ...
- ```
- """
- # remove CLIP image encoder
- if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
- self.image_encoder = None
- self.register_to_config(image_encoder=[None, None])
-
- # remove feature extractor only when safety_checker is None as safety_checker uses
- # the feature_extractor later
- if not hasattr(self, "safety_checker"):
- if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
- self.feature_extractor = None
- self.register_to_config(feature_extractor=[None, None])
-
- # remove hidden encoder
- self.unet.encoder_hid_proj = None
- self.unet.config.encoder_hid_dim_type = None
-
- # Kolors: restore `encoder_hid_proj` with `text_encoder_hid_proj`
- if hasattr(self.unet, "text_encoder_hid_proj") and self.unet.text_encoder_hid_proj is not None:
- self.unet.encoder_hid_proj = self.unet.text_encoder_hid_proj
- self.unet.text_encoder_hid_proj = None
- self.unet.config.encoder_hid_dim_type = "text_proj"
-
- # restore original Unet attention processors layers
- attn_procs = {}
- for name, value in self.unet.attn_processors.items():
- attn_processor_class = (
- AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor()
- )
- attn_procs[name] = (
- attn_processor_class
- if isinstance(
- value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
- )
- else value.__class__()
- )
- self.unet.set_attn_processor(attn_procs)
-
-
-class FluxIPAdapterMixin:
- """Mixin for handling Flux IP Adapters."""
-
- @validate_hf_hub_args
- def load_ip_adapter(
- self,
- pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
- weight_name: Union[str, List[str]],
- subfolder: Optional[Union[str, List[str]]] = "",
- image_encoder_pretrained_model_name_or_path: Optional[str] = "image_encoder",
- image_encoder_subfolder: Optional[str] = "",
- image_encoder_dtype: torch.dtype = torch.float16,
- **kwargs,
- ):
- """
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
- subfolder (`str` or `List[str]`):
- The subfolder location of a model file within a larger model repository on the Hub or locally. If a
- list is passed, it should have the same length as `weight_name`.
- weight_name (`str` or `List[str]`):
- The name of the weight file to load. If a list is passed, it should have the same length as
- `weight_name`.
- image_encoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `./image_encoder`):
- Can be either:
-
- - A string, the *model id* (for example `openai/clip-vit-large-patch14`) of a pretrained model
- hosted on the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
- Speed up model loading only loading the pretrained weights and not initializing the weights. This also
- tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
- Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
- argument to `True` will raise an error.
- """
-
- # handle the list inputs for multiple IP Adapters
- if not isinstance(weight_name, list):
- weight_name = [weight_name]
-
- if not isinstance(pretrained_model_name_or_path_or_dict, list):
- pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
- if len(pretrained_model_name_or_path_or_dict) == 1:
- pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
-
- if not isinstance(subfolder, list):
- subfolder = [subfolder]
- if len(subfolder) == 1:
- subfolder = subfolder * len(weight_name)
-
- if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
- raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")
-
- if len(weight_name) != len(subfolder):
- raise ValueError("`weight_name` and `subfolder` must have the same length.")
-
- # Load the main state dict first.
- cache_dir = kwargs.pop("cache_dir", None)
- force_download = kwargs.pop("force_download", False)
- proxies = kwargs.pop("proxies", None)
- local_files_only = kwargs.pop("local_files_only", None)
- token = kwargs.pop("token", None)
- revision = kwargs.pop("revision", None)
- low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
-
- if low_cpu_mem_usage and not is_accelerate_available():
- low_cpu_mem_usage = False
- logger.warning(
- "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
- " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
- " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
- " install accelerate\n```\n."
- )
-
- if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
- raise NotImplementedError(
- "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
- " `low_cpu_mem_usage=False`."
- )
-
- user_agent = {
- "file_type": "attn_procs_weights",
- "framework": "pytorch",
- }
- state_dicts = []
- for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
- pretrained_model_name_or_path_or_dict, weight_name, subfolder
- ):
- if not isinstance(pretrained_model_name_or_path_or_dict, dict):
- model_file = _get_model_file(
- pretrained_model_name_or_path_or_dict,
- weights_name=weight_name,
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- local_files_only=local_files_only,
- token=token,
- revision=revision,
- subfolder=subfolder,
- user_agent=user_agent,
- )
- if weight_name.endswith(".safetensors"):
- state_dict = {"image_proj": {}, "ip_adapter": {}}
- with safe_open(model_file, framework="pt", device="cpu") as f:
- image_proj_keys = ["ip_adapter_proj_model.", "image_proj."]
- ip_adapter_keys = ["double_blocks.", "ip_adapter."]
- for key in f.keys():
- if any(key.startswith(prefix) for prefix in image_proj_keys):
- diffusers_name = ".".join(key.split(".")[1:])
- state_dict["image_proj"][diffusers_name] = f.get_tensor(key)
- elif any(key.startswith(prefix) for prefix in ip_adapter_keys):
- diffusers_name = (
- ".".join(key.split(".")[1:])
- .replace("ip_adapter_double_stream_k_proj", "to_k_ip")
- .replace("ip_adapter_double_stream_v_proj", "to_v_ip")
- .replace("processor.", "")
- )
- state_dict["ip_adapter"][diffusers_name] = f.get_tensor(key)
- else:
- state_dict = load_state_dict(model_file)
- else:
- state_dict = pretrained_model_name_or_path_or_dict
-
- keys = list(state_dict.keys())
- if keys != ["image_proj", "ip_adapter"]:
- raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
-
- state_dicts.append(state_dict)
-
- # load CLIP image encoder here if it has not been registered to the pipeline yet
- if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
- if image_encoder_pretrained_model_name_or_path is not None:
- if not isinstance(pretrained_model_name_or_path_or_dict, dict):
- logger.info(f"loading image_encoder from {image_encoder_pretrained_model_name_or_path}")
- image_encoder = (
- CLIPVisionModelWithProjection.from_pretrained(
- image_encoder_pretrained_model_name_or_path,
- subfolder=image_encoder_subfolder,
- low_cpu_mem_usage=low_cpu_mem_usage,
- cache_dir=cache_dir,
- local_files_only=local_files_only,
- torch_dtype=image_encoder_dtype,
- )
- .to(self.device)
- .eval()
- )
- self.register_modules(image_encoder=image_encoder)
- else:
- raise ValueError(
- "`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
- )
- else:
- logger.warning(
- "image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
- "Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead."
- )
-
- # create feature extractor if it has not been registered to the pipeline yet
- if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
- # FaceID IP adapters don't need the image encoder so it's not present, in this case we default to 224
- default_clip_size = 224
- clip_image_size = (
- self.image_encoder.config.image_size if self.image_encoder is not None else default_clip_size
- )
- feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size)
- self.register_modules(feature_extractor=feature_extractor)
-
- # load ip-adapter into transformer
- self.transformer._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
-
- def set_ip_adapter_scale(self, scale: Union[float, List[float], List[List[float]]]):
- """
- Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
- granular control over each IP-Adapter behavior. A config can be a float or a list.
-
- `float` is converted to list and repeated for the number of blocks and the number of IP adapters. `List[float]`
- length match the number of blocks, it is repeated for each IP adapter. `List[List[float]]` must match the
- number of IP adapters and each must match the number of blocks.
-
- Example:
-
- ```py
- # To use original IP-Adapter
- scale = 1.0
- pipeline.set_ip_adapter_scale(scale)
-
-
- def LinearStrengthModel(start, finish, size):
- return [(start + (finish - start) * (i / (size - 1))) for i in range(size)]
-
-
- ip_strengths = LinearStrengthModel(0.3, 0.92, 19)
- pipeline.set_ip_adapter_scale(ip_strengths)
- ```
- """
-
- scale_type = Union[int, float]
- num_ip_adapters = self.transformer.encoder_hid_proj.num_ip_adapters
- num_layers = self.transformer.config.num_layers
-
- # Single value for all layers of all IP-Adapters
- if isinstance(scale, scale_type):
- scale = [scale for _ in range(num_ip_adapters)]
- # List of per-layer scales for a single IP-Adapter
- elif _is_valid_type(scale, List[scale_type]) and num_ip_adapters == 1:
- scale = [scale]
- # Invalid scale type
- elif not _is_valid_type(scale, List[Union[scale_type, List[scale_type]]]):
- raise TypeError(f"Unexpected type {_get_detailed_type(scale)} for scale.")
-
- if len(scale) != num_ip_adapters:
- raise ValueError(f"Cannot assign {len(scale)} scales to {num_ip_adapters} IP-Adapters.")
-
- if any(len(s) != num_layers for s in scale if isinstance(s, list)):
- invalid_scale_sizes = {len(s) for s in scale if isinstance(s, list)} - {num_layers}
- raise ValueError(
- f"Expected list of {num_layers} scales, got {', '.join(str(x) for x in invalid_scale_sizes)}."
- )
-
- # Scalars are transformed to lists with length num_layers
- scale_configs = [[s] * num_layers if isinstance(s, scale_type) else s for s in scale]
-
- # Set scales. zip over scale_configs prevents going into single transformer layers
- for attn_processor, *scale in zip(self.transformer.attn_processors.values(), *scale_configs):
- attn_processor.scale = scale
-
- def unload_ip_adapter(self):
- """
- Unloads the IP Adapter weights
-
- Examples:
-
- ```python
- >>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
- >>> pipeline.unload_ip_adapter()
- >>> ...
- ```
- """
- # remove CLIP image encoder
- if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
- self.image_encoder = None
- self.register_to_config(image_encoder=[None, None])
-
- # remove feature extractor only when safety_checker is None as safety_checker uses
- # the feature_extractor later
- if not hasattr(self, "safety_checker"):
- if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
- self.feature_extractor = None
- self.register_to_config(feature_extractor=[None, None])
-
- # remove hidden encoder
- self.transformer.encoder_hid_proj = None
- self.transformer.config.encoder_hid_dim_type = None
-
- # restore original Transformer attention processors layers
- attn_procs = {}
- for name, value in self.transformer.attn_processors.items():
- attn_processor_class = FluxAttnProcessor2_0()
- attn_procs[name] = (
- attn_processor_class if isinstance(value, (FluxIPAdapterJointAttnProcessor2_0)) else value.__class__()
- )
- self.transformer.set_attn_processor(attn_procs)
-
-
-class SD3IPAdapterMixin:
- """Mixin for handling StableDiffusion 3 IP Adapters."""
-
- @property
- def is_ip_adapter_active(self) -> bool:
- """Checks if IP-Adapter is loaded and scale > 0.
-
- IP-Adapter scale controls the influence of the image prompt versus text prompt. When this value is set to 0,
- the image context is irrelevant.
-
- Returns:
- `bool`: True when IP-Adapter is loaded and any layer has scale > 0.
- """
- scales = [
- attn_proc.scale
- for attn_proc in self.transformer.attn_processors.values()
- if isinstance(attn_proc, SD3IPAdapterJointAttnProcessor2_0)
- ]
-
- return len(scales) > 0 and any(scale > 0 for scale in scales)
-
- @validate_hf_hub_args
- def load_ip_adapter(
- self,
- pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
- weight_name: str = "ip-adapter.safetensors",
- subfolder: Optional[str] = None,
- image_encoder_folder: Optional[str] = "image_encoder",
- **kwargs,
- ) -> None:
- """
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
- weight_name (`str`, defaults to "ip-adapter.safetensors"):
- The name of the weight file to load. If a list is passed, it should have the same length as
- `subfolder`.
- subfolder (`str`, *optional*):
- The subfolder location of a model file within a larger model repository on the Hub or locally. If a
- list is passed, it should have the same length as `weight_name`.
- image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
- The subfolder location of the image encoder within a larger model repository on the Hub or locally.
- Pass `None` to not load the image encoder. If the image encoder is located in a folder inside
- `subfolder`, you only need to pass the name of the folder that contains image encoder weights, e.g.
- `image_encoder_folder="image_encoder"`. If the image encoder is located in a folder other than
- `subfolder`, you should pass the path to the folder that contains image encoder weights, for example,
- `image_encoder_folder="different_subfolder/image_encoder"`.
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
- Speed up model loading only loading the pretrained weights and not initializing the weights. This also
- tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
- Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
- argument to `True` will raise an error.
- """
- # Load the main state dict first
- cache_dir = kwargs.pop("cache_dir", None)
- force_download = kwargs.pop("force_download", False)
- proxies = kwargs.pop("proxies", None)
- local_files_only = kwargs.pop("local_files_only", None)
- token = kwargs.pop("token", None)
- revision = kwargs.pop("revision", None)
- low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
-
- if low_cpu_mem_usage and not is_accelerate_available():
- low_cpu_mem_usage = False
- logger.warning(
- "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
- " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
- " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
- " install accelerate\n```\n."
- )
-
- if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
- raise NotImplementedError(
- "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
- " `low_cpu_mem_usage=False`."
- )
-
- user_agent = {
- "file_type": "attn_procs_weights",
- "framework": "pytorch",
- }
-
- if not isinstance(pretrained_model_name_or_path_or_dict, dict):
- model_file = _get_model_file(
- pretrained_model_name_or_path_or_dict,
- weights_name=weight_name,
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- local_files_only=local_files_only,
- token=token,
- revision=revision,
- subfolder=subfolder,
- user_agent=user_agent,
- )
- if weight_name.endswith(".safetensors"):
- state_dict = {"image_proj": {}, "ip_adapter": {}}
- with safe_open(model_file, framework="pt", device="cpu") as f:
- for key in f.keys():
- if key.startswith("image_proj."):
- state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
- elif key.startswith("ip_adapter."):
- state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
- else:
- state_dict = load_state_dict(model_file)
- else:
- state_dict = pretrained_model_name_or_path_or_dict
-
- keys = list(state_dict.keys())
- if "image_proj" not in keys and "ip_adapter" not in keys:
- raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
-
- # Load image_encoder and feature_extractor here if they haven't been registered to the pipeline yet
- if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
- if image_encoder_folder is not None:
- if not isinstance(pretrained_model_name_or_path_or_dict, dict):
- logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
- if image_encoder_folder.count("/") == 0:
- image_encoder_subfolder = Path(subfolder, image_encoder_folder).as_posix()
- else:
- image_encoder_subfolder = Path(image_encoder_folder).as_posix()
-
- # Commons args for loading image encoder and image processor
- kwargs = {
- "low_cpu_mem_usage": low_cpu_mem_usage,
- "cache_dir": cache_dir,
- "local_files_only": local_files_only,
- }
-
- self.register_modules(
- feature_extractor=SiglipImageProcessor.from_pretrained(image_encoder_subfolder, **kwargs),
- image_encoder=SiglipVisionModel.from_pretrained(
- image_encoder_subfolder, torch_dtype=self.dtype, **kwargs
- ).to(self.device),
- )
- else:
- raise ValueError(
- "`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
- )
- else:
- logger.warning(
- "image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
- "Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead."
- )
-
- # Load IP-Adapter into transformer
- self.transformer._load_ip_adapter_weights(state_dict, low_cpu_mem_usage=low_cpu_mem_usage)
-
- def set_ip_adapter_scale(self, scale: float) -> None:
- """
- Set IP-Adapter scale, which controls image prompt conditioning. A value of 1.0 means the model is only
- conditioned on the image prompt, and 0.0 only conditioned by the text prompt. Lowering this value encourages
- the model to produce more diverse images, but they may not be as aligned with the image prompt.
-
- Example:
-
- ```python
- >>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
- >>> pipeline.set_ip_adapter_scale(0.6)
- >>> ...
- ```
-
- Args:
- scale (float):
- IP-Adapter scale to be set.
-
- """
- for attn_processor in self.transformer.attn_processors.values():
- if isinstance(attn_processor, SD3IPAdapterJointAttnProcessor2_0):
- attn_processor.scale = scale
-
- def unload_ip_adapter(self) -> None:
- """
- Unloads the IP Adapter weights.
-
- Example:
-
- ```python
- >>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
- >>> pipeline.unload_ip_adapter()
- >>> ...
- ```
- """
- # Remove image encoder
- if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
- self.image_encoder = None
- self.register_to_config(image_encoder=None)
-
- # Remove feature extractor
- if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
- self.feature_extractor = None
- self.register_to_config(feature_extractor=None)
-
- # Remove image projection
- self.transformer.image_proj = None
-
- # Restore original attention processors layers
- attn_procs = {
- name: (
- JointAttnProcessor2_0() if isinstance(value, SD3IPAdapterJointAttnProcessor2_0) else value.__class__()
- )
- for name, value in self.transformer.attn_processors.items()
- }
- self.transformer.set_attn_processor(attn_procs)
+class SD3IPAdapterMixin(SD3IPAdapterMixin):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = "Importing `SD3IPAdapterMixin` from diffusers.loaders.ip_adapter has been deprecated. Please use `from diffusers.loaders.ip_adapter.ip_adapter import SD3IPAdapterMixin` instead."
+ deprecate("diffusers.loaders.ip_adapter.SD3IPAdapterMixin", "0.36", deprecation_message)
+ super().__init__(*args, **kwargs)
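The shims above follow the usual back-compat pattern: the old name subclasses the relocated class and warns when constructed, so existing imports keep working until the stated removal version. A self-contained sketch of the pattern with hypothetical names (`Widget`, `old_module`):

```python
from diffusers.utils import deprecate


class _RelocatedWidget:  # stand-in for the class at its new location
    pass


class Widget(_RelocatedWidget):
    """Back-compat alias kept at the old import path (hypothetical example)."""

    def __init__(self, *args, **kwargs):
        # Emits a deprecation warning pointing users at the new import path,
        # slated for removal in 0.36.
        deprecate("old_module.Widget", "0.36", "Import `Widget` from its new module instead.")
        super().__init__(*args, **kwargs)
```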
diff --git a/src/diffusers/loaders/ip_adapter/__init__.py b/src/diffusers/loaders/ip_adapter/__init__.py
new file mode 100644
index 000000000000..42b1d6e1b509
--- /dev/null
+++ b/src/diffusers/loaders/ip_adapter/__init__.py
@@ -0,0 +1,9 @@
+from ...utils.import_utils import is_torch_available, is_transformers_available
+
+
+if is_torch_available():
+ from .transformer_flux import FluxTransformer2DLoadersMixin
+ from .transformer_sd3 import SD3Transformer2DLoadersMixin
+
+ if is_transformers_available():
+ from .ip_adapter import FluxIPAdapterMixin, IPAdapterMixin, SD3IPAdapterMixin
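Assuming torch and transformers are both installed, the new package re-exports the mixins at the package level, so both of the following imports should resolve to the same classes (a quick sanity-check sketch):

```python
# Both paths should point at the same class objects, since the package __init__
# simply re-imports them from ip_adapter/ip_adapter.py.
from diffusers.loaders.ip_adapter import IPAdapterMixin
from diffusers.loaders.ip_adapter.ip_adapter import IPAdapterMixin as ip_adapter_impl

assert IPAdapterMixin is ip_adapter_impl
```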
diff --git a/src/diffusers/loaders/ip_adapter/ip_adapter.py b/src/diffusers/loaders/ip_adapter/ip_adapter.py
new file mode 100644
index 000000000000..782391d84a88
--- /dev/null
+++ b/src/diffusers/loaders/ip_adapter/ip_adapter.py
@@ -0,0 +1,879 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pathlib import Path
+from typing import Dict, List, Optional, Union
+
+import torch
+import torch.nn.functional as F
+from huggingface_hub.utils import validate_hf_hub_args
+from safetensors import safe_open
+
+from ...models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
+from ...utils import (
+ USE_PEFT_BACKEND,
+ _get_detailed_type,
+ _get_model_file,
+ _is_valid_type,
+ is_accelerate_available,
+ is_torch_version,
+ is_transformers_available,
+ logging,
+)
+from ..unet.unet_loader_utils import _maybe_expand_lora_scales
+
+
+if is_transformers_available():
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, SiglipImageProcessor, SiglipVisionModel
+
+from ...models.attention_processor import (
+ AttnProcessor,
+ AttnProcessor2_0,
+ FluxAttnProcessor2_0,
+ FluxIPAdapterJointAttnProcessor2_0,
+ IPAdapterAttnProcessor,
+ IPAdapterAttnProcessor2_0,
+ IPAdapterXFormersAttnProcessor,
+ JointAttnProcessor2_0,
+ SD3IPAdapterJointAttnProcessor2_0,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+class IPAdapterMixin:
+ """Mixin for handling IP Adapters."""
+
+ @validate_hf_hub_args
+ def load_ip_adapter(
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
+ subfolder: Union[str, List[str]],
+ weight_name: Union[str, List[str]],
+ image_encoder_folder: Optional[str] = "image_encoder",
+ **kwargs,
+ ):
+ """
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+ subfolder (`str` or `List[str]`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally. If a
+ list is passed, it should have the same length as `weight_name`.
+ weight_name (`str` or `List[str]`):
+ The name of the weight file to load. If a list is passed, it should have the same length as
+ `subfolder`.
+ image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
+ The subfolder location of the image encoder within a larger model repository on the Hub or locally.
+ Pass `None` to not load the image encoder. If the image encoder is located in a folder inside
+ `subfolder`, you only need to pass the name of the folder that contains image encoder weights, e.g.
+ `image_encoder_folder="image_encoder"`. If the image encoder is located in a folder other than
+ `subfolder`, you should pass the path to the folder that contains image encoder weights, for example,
+ `image_encoder_folder="different_subfolder/image_encoder"`.
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+ Speed up model loading only loading the pretrained weights and not initializing the weights. This also
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+ argument to `True` will raise an error.
+ """
+
+ # handle the list inputs for multiple IP Adapters
+ if not isinstance(weight_name, list):
+ weight_name = [weight_name]
+
+ if not isinstance(pretrained_model_name_or_path_or_dict, list):
+ pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
+ if len(pretrained_model_name_or_path_or_dict) == 1:
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
+
+ if not isinstance(subfolder, list):
+ subfolder = [subfolder]
+ if len(subfolder) == 1:
+ subfolder = subfolder * len(weight_name)
+
+ if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
+ raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")
+
+ if len(weight_name) != len(subfolder):
+ raise ValueError("`weight_name` and `subfolder` must have the same length.")
+
+ # Load the main state dict first.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+
+ if low_cpu_mem_usage and not is_accelerate_available():
+ low_cpu_mem_usage = False
+ logger.warning(
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+ " install accelerate\n```\n."
+ )
+
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+ raise NotImplementedError(
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+ " `low_cpu_mem_usage=False`."
+ )
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+ state_dicts = []
+ for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
+ pretrained_model_name_or_path_or_dict, weight_name, subfolder
+ ):
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+ model_file = _get_model_file(
+ pretrained_model_name_or_path_or_dict,
+ weights_name=weight_name,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ )
+ if weight_name.endswith(".safetensors"):
+ state_dict = {"image_proj": {}, "ip_adapter": {}}
+ with safe_open(model_file, framework="pt", device="cpu") as f:
+ for key in f.keys():
+ if key.startswith("image_proj."):
+ state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
+ elif key.startswith("ip_adapter."):
+ state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
+ else:
+ state_dict = load_state_dict(model_file)
+ else:
+ state_dict = pretrained_model_name_or_path_or_dict
+
+ keys = list(state_dict.keys())
+ if "image_proj" not in keys and "ip_adapter" not in keys:
+ raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
+
+ state_dicts.append(state_dict)
+
+ # load CLIP image encoder here if it has not been registered to the pipeline yet
+ if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
+ if image_encoder_folder is not None:
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+ logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
+ if image_encoder_folder.count("/") == 0:
+ image_encoder_subfolder = Path(subfolder, image_encoder_folder).as_posix()
+ else:
+ image_encoder_subfolder = Path(image_encoder_folder).as_posix()
+
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+ pretrained_model_name_or_path_or_dict,
+ subfolder=image_encoder_subfolder,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ cache_dir=cache_dir,
+ local_files_only=local_files_only,
+ torch_dtype=self.dtype,
+ ).to(self.device)
+ self.register_modules(image_encoder=image_encoder)
+ else:
+ raise ValueError(
+ "`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
+ )
+ else:
+ logger.warning(
+ "image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
+ "Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead."
+ )
+
+ # create feature extractor if it has not been registered to the pipeline yet
+ if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
+ # FaceID IP adapters don't need the image encoder so it's not present, in this case we default to 224
+ default_clip_size = 224
+ clip_image_size = (
+ self.image_encoder.config.image_size if self.image_encoder is not None else default_clip_size
+ )
+ feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size)
+ self.register_modules(feature_extractor=feature_extractor)
+
+ # load ip-adapter into unet
+ unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+ unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
+
+ extra_loras = unet._load_ip_adapter_loras(state_dicts)
+ if extra_loras != {}:
+ if not USE_PEFT_BACKEND:
+ logger.warning("PEFT backend is required to load these weights.")
+ else:
+ # apply the IP Adapter Face ID LoRA weights
+ peft_config = getattr(unet, "peft_config", {})
+ for k, lora in extra_loras.items():
+ if f"faceid_{k}" not in peft_config:
+ self.load_lora_weights(lora, adapter_name=f"faceid_{k}")
+ self.set_adapters([f"faceid_{k}"], adapter_weights=[1.0])
+
+ def set_ip_adapter_scale(self, scale):
+ """
+ Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
+ granular control over each IP-Adapter behavior. A config can be a float or a dictionary.
+
+ Example:
+
+ ```py
+ # To use original IP-Adapter
+ scale = 1.0
+ pipeline.set_ip_adapter_scale(scale)
+
+ # To use style block only
+ scale = {
+ "up": {"block_0": [0.0, 1.0, 0.0]},
+ }
+ pipeline.set_ip_adapter_scale(scale)
+
+ # To use style+layout blocks
+ scale = {
+ "down": {"block_2": [0.0, 1.0]},
+ "up": {"block_0": [0.0, 1.0, 0.0]},
+ }
+ pipeline.set_ip_adapter_scale(scale)
+
+ # To use style and layout from 2 reference images
+ scales = [{"down": {"block_2": [0.0, 1.0]}}, {"up": {"block_0": [0.0, 1.0, 0.0]}}]
+ pipeline.set_ip_adapter_scale(scales)
+ ```
+ """
+ unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+ if not isinstance(scale, list):
+ scale = [scale]
+ scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0)
+
+ for attn_name, attn_processor in unet.attn_processors.items():
+ if isinstance(
+ attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
+ ):
+ if len(scale_configs) != len(attn_processor.scale):
+ raise ValueError(
+ f"Cannot assign {len(scale_configs)} scale_configs to {len(attn_processor.scale)} IP-Adapter."
+ )
+ elif len(scale_configs) == 1:
+ scale_configs = scale_configs * len(attn_processor.scale)
+ for i, scale_config in enumerate(scale_configs):
+ if isinstance(scale_config, dict):
+ for k, s in scale_config.items():
+ if attn_name.startswith(k):
+ attn_processor.scale[i] = s
+ else:
+ attn_processor.scale[i] = scale_config
+
+ def unload_ip_adapter(self):
+ """
+ Unloads the IP Adapter weights
+
+ Examples:
+
+ ```python
+ >>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
+ >>> pipeline.unload_ip_adapter()
+ >>> ...
+ ```
+ """
+ # remove CLIP image encoder
+ if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
+ self.image_encoder = None
+ self.register_to_config(image_encoder=[None, None])
+
+ # remove feature extractor only when safety_checker is None as safety_checker uses
+ # the feature_extractor later
+ if not hasattr(self, "safety_checker"):
+ if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
+ self.feature_extractor = None
+ self.register_to_config(feature_extractor=[None, None])
+
+ # remove hidden encoder
+ self.unet.encoder_hid_proj = None
+ self.unet.config.encoder_hid_dim_type = None
+
+ # Kolors: restore `encoder_hid_proj` with `text_encoder_hid_proj`
+ if hasattr(self.unet, "text_encoder_hid_proj") and self.unet.text_encoder_hid_proj is not None:
+ self.unet.encoder_hid_proj = self.unet.text_encoder_hid_proj
+ self.unet.text_encoder_hid_proj = None
+ self.unet.config.encoder_hid_dim_type = "text_proj"
+
+ # restore the original UNet attention processor layers
+ attn_procs = {}
+ for name, value in self.unet.attn_processors.items():
+ attn_processor_class = (
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor()
+ )
+ attn_procs[name] = (
+ attn_processor_class
+ if isinstance(
+ value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
+ )
+ else value.__class__()
+ )
+ self.unet.set_attn_processor(attn_procs)
+
+
+class FluxIPAdapterMixin:
+ """Mixin for handling Flux IP Adapters."""
+
+ @validate_hf_hub_args
+ def load_ip_adapter(
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
+ weight_name: Union[str, List[str]],
+ subfolder: Optional[Union[str, List[str]]] = "",
+ image_encoder_pretrained_model_name_or_path: Optional[str] = "image_encoder",
+ image_encoder_subfolder: Optional[str] = "",
+ image_encoder_dtype: torch.dtype = torch.float16,
+ **kwargs,
+ ):
+ """
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+ subfolder (`str` or `List[str]`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally. If a
+ list is passed, it should have the same length as `weight_name`.
+ weight_name (`str` or `List[str]`):
+ The name of the weight file to load. If a list is passed, it should have the same length as
+ `pretrained_model_name_or_path_or_dict`.
+ image_encoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `"image_encoder"`):
+ Can be either:
+
+ - A string, the *model id* (for example `openai/clip-vit-large-patch14`) of a pretrained model
+ hosted on the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+ Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+ argument to `True` will raise an error.
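+
+ Example:
+
+ A minimal usage sketch; the repository id, weight file name, and image encoder id below are illustrative
+ assumptions rather than guaranteed values, and `pipeline` is assumed to be an already-loaded Flux pipeline:
+
+ ```py
+ pipeline.load_ip_adapter(
+     "XLabs-AI/flux-ip-adapter",
+     weight_name="ip_adapter.safetensors",
+     image_encoder_pretrained_model_name_or_path="openai/clip-vit-large-patch14",
+ )
+ pipeline.set_ip_adapter_scale(0.6)
+ ```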
+ """
+
+ # handle the list inputs for multiple IP Adapters
+ if not isinstance(weight_name, list):
+ weight_name = [weight_name]
+
+ if not isinstance(pretrained_model_name_or_path_or_dict, list):
+ pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
+ if len(pretrained_model_name_or_path_or_dict) == 1:
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
+
+ if not isinstance(subfolder, list):
+ subfolder = [subfolder]
+ if len(subfolder) == 1:
+ subfolder = subfolder * len(weight_name)
+
+ if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
+ raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")
+
+ if len(weight_name) != len(subfolder):
+ raise ValueError("`weight_name` and `subfolder` must have the same length.")
+
+ # Load the main state dict first.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+
+ if low_cpu_mem_usage and not is_accelerate_available():
+ low_cpu_mem_usage = False
+ logger.warning(
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+ " install accelerate\n```\n."
+ )
+
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+ raise NotImplementedError(
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+ " `low_cpu_mem_usage=False`."
+ )
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+ state_dicts = []
+ for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
+ pretrained_model_name_or_path_or_dict, weight_name, subfolder
+ ):
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+ model_file = _get_model_file(
+ pretrained_model_name_or_path_or_dict,
+ weights_name=weight_name,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ )
+ if weight_name.endswith(".safetensors"):
+ state_dict = {"image_proj": {}, "ip_adapter": {}}
+ with safe_open(model_file, framework="pt", device="cpu") as f:
+ image_proj_keys = ["ip_adapter_proj_model.", "image_proj."]
+ ip_adapter_keys = ["double_blocks.", "ip_adapter."]
+ for key in f.keys():
+ if any(key.startswith(prefix) for prefix in image_proj_keys):
+ diffusers_name = ".".join(key.split(".")[1:])
+ state_dict["image_proj"][diffusers_name] = f.get_tensor(key)
+ elif any(key.startswith(prefix) for prefix in ip_adapter_keys):
+ diffusers_name = (
+ ".".join(key.split(".")[1:])
+ .replace("ip_adapter_double_stream_k_proj", "to_k_ip")
+ .replace("ip_adapter_double_stream_v_proj", "to_v_ip")
+ .replace("processor.", "")
+ )
+ state_dict["ip_adapter"][diffusers_name] = f.get_tensor(key)
+ else:
+ state_dict = load_state_dict(model_file)
+ else:
+ state_dict = pretrained_model_name_or_path_or_dict
+
+ keys = list(state_dict.keys())
+ if keys != ["image_proj", "ip_adapter"]:
+ raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
+
+ state_dicts.append(state_dict)
+
+ # load CLIP image encoder here if it has not been registered to the pipeline yet
+ if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
+ if image_encoder_pretrained_model_name_or_path is not None:
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+ logger.info(f"loading image_encoder from {image_encoder_pretrained_model_name_or_path}")
+ image_encoder = (
+ CLIPVisionModelWithProjection.from_pretrained(
+ image_encoder_pretrained_model_name_or_path,
+ subfolder=image_encoder_subfolder,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ cache_dir=cache_dir,
+ local_files_only=local_files_only,
+ dtype=image_encoder_dtype,
+ )
+ .to(self.device)
+ .eval()
+ )
+ self.register_modules(image_encoder=image_encoder)
+ else:
+ raise ValueError(
+ "`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
+ )
+ else:
+ logger.warning(
+ "image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
+ "Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead."
+ )
+
+ # create feature extractor if it has not been registered to the pipeline yet
+ if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
+ # FaceID IP-Adapters don't need the image encoder, so it isn't present; in that case we default to 224
+ default_clip_size = 224
+ clip_image_size = (
+ self.image_encoder.config.image_size if self.image_encoder is not None else default_clip_size
+ )
+ feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size)
+ self.register_modules(feature_extractor=feature_extractor)
+
+ # load ip-adapter into transformer
+ self.transformer._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
+
+ def set_ip_adapter_scale(self, scale: Union[float, List[float], List[List[float]]]):
+ """
+ Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
+ granular control over each IP-Adapter behavior. A config can be a float or a list.
+
+ A `float` is converted to a list and repeated for the number of blocks and the number of IP adapters. A
+ `List[float]` must match the number of blocks and is repeated for each IP adapter. A `List[List[float]]` must
+ match the number of IP adapters, and each inner list must match the number of blocks.
+
+ Example:
+
+ ```py
+ # To use original IP-Adapter
+ scale = 1.0
+ pipeline.set_ip_adapter_scale(scale)
+
+
+ def LinearStrengthModel(start, finish, size):
+ return [(start + (finish - start) * (i / (size - 1))) for i in range(size)]
+
+
+ ip_strengths = LinearStrengthModel(0.3, 0.92, 19)
+ pipeline.set_ip_adapter_scale(ip_strengths)
+ ```
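+
+ A per-adapter sketch (assumed setup: two IP-Adapters loaded into a transformer with 19 double-stream blocks;
+ adjust the list lengths to match your model):
+
+ ```py
+ # one per-layer list of scales for each loaded IP-Adapter
+ scales = [[0.5] * 19, LinearStrengthModel(0.3, 0.92, 19)]
+ pipeline.set_ip_adapter_scale(scales)
+ ```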
+ """
+
+ scale_type = Union[int, float]
+ num_ip_adapters = self.transformer.encoder_hid_proj.num_ip_adapters
+ num_layers = self.transformer.config.num_layers
+
+ # Single value for all layers of all IP-Adapters
+ if isinstance(scale, scale_type):
+ scale = [scale for _ in range(num_ip_adapters)]
+ # List of per-layer scales for a single IP-Adapter
+ elif _is_valid_type(scale, List[scale_type]) and num_ip_adapters == 1:
+ scale = [scale]
+ # Invalid scale type
+ elif not _is_valid_type(scale, List[Union[scale_type, List[scale_type]]]):
+ raise TypeError(f"Unexpected type {_get_detailed_type(scale)} for scale.")
+
+ if len(scale) != num_ip_adapters:
+ raise ValueError(f"Cannot assign {len(scale)} scales to {num_ip_adapters} IP-Adapters.")
+
+ if any(len(s) != num_layers for s in scale if isinstance(s, list)):
+ invalid_scale_sizes = {len(s) for s in scale if isinstance(s, list)} - {num_layers}
+ raise ValueError(
+ f"Expected list of {num_layers} scales, got {', '.join(str(x) for x in invalid_scale_sizes)}."
+ )
+
+ # Scalars are transformed to lists with length num_layers
+ scale_configs = [[s] * num_layers if isinstance(s, scale_type) else s for s in scale]
+
+ # Set scales. Zipping against scale_configs stops the loop before it reaches the single transformer blocks.
+ for attn_processor, *scale in zip(self.transformer.attn_processors.values(), *scale_configs):
+ attn_processor.scale = scale
+
+ def unload_ip_adapter(self):
+ """
+ Unloads the IP Adapter weights
+
+ Examples:
+
+ ```python
+ >>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
+ >>> pipeline.unload_ip_adapter()
+ >>> ...
+ ```
+ """
+ # remove CLIP image encoder
+ if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
+ self.image_encoder = None
+ self.register_to_config(image_encoder=[None, None])
+
+ # remove feature extractor only when safety_checker is None as safety_checker uses
+ # the feature_extractor later
+ if not hasattr(self, "safety_checker"):
+ if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
+ self.feature_extractor = None
+ self.register_to_config(feature_extractor=[None, None])
+
+ # remove hidden encoder
+ self.transformer.encoder_hid_proj = None
+ self.transformer.config.encoder_hid_dim_type = None
+
+ # restore the original Transformer attention processor layers
+ attn_procs = {}
+ for name, value in self.transformer.attn_processors.items():
+ attn_processor_class = FluxAttnProcessor2_0()
+ attn_procs[name] = (
+ attn_processor_class if isinstance(value, (FluxIPAdapterJointAttnProcessor2_0)) else value.__class__()
+ )
+ self.transformer.set_attn_processor(attn_procs)
+
+
+class SD3IPAdapterMixin:
+ """Mixin for handling StableDiffusion 3 IP Adapters."""
+
+ @property
+ def is_ip_adapter_active(self) -> bool:
+ """Checks if IP-Adapter is loaded and scale > 0.
+
+ IP-Adapter scale controls the influence of the image prompt versus text prompt. When this value is set to 0,
+ the image context is irrelevant.
+
+ Returns:
+ `bool`: True when IP-Adapter is loaded and any layer has scale > 0.
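+
+ Example (an illustrative check; assumes `pipeline` already has an IP-Adapter loaded):
+
+ ```python
+ >>> pipeline.set_ip_adapter_scale(0.0)
+ >>> assert not pipeline.is_ip_adapter_active
+ >>> pipeline.set_ip_adapter_scale(0.6)
+ >>> assert pipeline.is_ip_adapter_active
+ ```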
+ """
+ scales = [
+ attn_proc.scale
+ for attn_proc in self.transformer.attn_processors.values()
+ if isinstance(attn_proc, SD3IPAdapterJointAttnProcessor2_0)
+ ]
+
+ return len(scales) > 0 and any(scale > 0 for scale in scales)
+
+ @validate_hf_hub_args
+ def load_ip_adapter(
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ weight_name: str = "ip-adapter.safetensors",
+ subfolder: Optional[str] = None,
+ image_encoder_folder: Optional[str] = "image_encoder",
+ **kwargs,
+ ) -> None:
+ """
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+ weight_name (`str`, defaults to "ip-adapter.safetensors"):
+ The name of the weight file to load.
+ subfolder (`str`, *optional*):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+ image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
+ The subfolder location of the image encoder within a larger model repository on the Hub or locally.
+ Pass `None` to not load the image encoder. If the image encoder is located in a folder inside
+ `subfolder`, you only need to pass the name of the folder that contains image encoder weights, e.g.
+ `image_encoder_folder="image_encoder"`. If the image encoder is located in a folder other than
+ `subfolder`, you should pass the path to the folder that contains image encoder weights, for example,
+ `image_encoder_folder="different_subfolder/image_encoder"`.
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+ Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+ argument to `True` will raise an error.
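+
+ Example:
+
+ A minimal usage sketch; the repository id below is illustrative, `pipeline` is assumed to be a Stable
+ Diffusion 3 pipeline, and the default `weight_name` and `image_encoder_folder` are used:
+
+ ```py
+ pipeline.load_ip_adapter("InstantX/SD3.5-Large-IP-Adapter")
+ pipeline.set_ip_adapter_scale(0.6)
+ ```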
+ """
+ # Load the main state dict first
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+
+ if low_cpu_mem_usage and not is_accelerate_available():
+ low_cpu_mem_usage = False
+ logger.warning(
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+ " install accelerate\n```\n."
+ )
+
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+ raise NotImplementedError(
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+ " `low_cpu_mem_usage=False`."
+ )
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+ model_file = _get_model_file(
+ pretrained_model_name_or_path_or_dict,
+ weights_name=weight_name,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ )
+ if weight_name.endswith(".safetensors"):
+ state_dict = {"image_proj": {}, "ip_adapter": {}}
+ with safe_open(model_file, framework="pt", device="cpu") as f:
+ for key in f.keys():
+ if key.startswith("image_proj."):
+ state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
+ elif key.startswith("ip_adapter."):
+ state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
+ else:
+ state_dict = load_state_dict(model_file)
+ else:
+ state_dict = pretrained_model_name_or_path_or_dict
+
+ keys = list(state_dict.keys())
+ if "image_proj" not in keys and "ip_adapter" not in keys:
+ raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
+
+ # Load image_encoder and feature_extractor here if they haven't been registered to the pipeline yet
+ if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
+ if image_encoder_folder is not None:
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+ logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
+ if image_encoder_folder.count("/") == 0:
+ image_encoder_subfolder = Path(subfolder, image_encoder_folder).as_posix()
+ else:
+ image_encoder_subfolder = Path(image_encoder_folder).as_posix()
+
+ # Common args for loading the image encoder and image processor
+ kwargs = {
+ "low_cpu_mem_usage": low_cpu_mem_usage,
+ "cache_dir": cache_dir,
+ "local_files_only": local_files_only,
+ }
+
+ self.register_modules(
+ feature_extractor=SiglipImageProcessor.from_pretrained(image_encoder_subfolder, **kwargs),
+ image_encoder=SiglipVisionModel.from_pretrained(
+ image_encoder_subfolder, torch_dtype=self.dtype, **kwargs
+ ).to(self.device),
+ )
+ else:
+ raise ValueError(
+ "`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
+ )
+ else:
+ logger.warning(
+ "image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
+ "Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead."
+ )
+
+ # Load IP-Adapter into transformer
+ self.transformer._load_ip_adapter_weights(state_dict, low_cpu_mem_usage=low_cpu_mem_usage)
+
+ def set_ip_adapter_scale(self, scale: float) -> None:
+ """
+ Set IP-Adapter scale, which controls image prompt conditioning. A value of 1.0 means the model is conditioned
+ only on the image prompt, and a value of 0.0 means it is conditioned only on the text prompt. Lowering this
+ value encourages the model to produce more diverse images, but they may not be as aligned with the image prompt.
+
+ Example:
+
+ ```python
+ >>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
+ >>> pipeline.set_ip_adapter_scale(0.6)
+ >>> ...
+ ```
+
+ Args:
+ scale (float):
+ IP-Adapter scale to be set.
+
+ """
+ for attn_processor in self.transformer.attn_processors.values():
+ if isinstance(attn_processor, SD3IPAdapterJointAttnProcessor2_0):
+ attn_processor.scale = scale
+
+ def unload_ip_adapter(self) -> None:
+ """
+ Unloads the IP Adapter weights.
+
+ Example:
+
+ ```python
+ >>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
+ >>> pipeline.unload_ip_adapter()
+ >>> ...
+ ```
+ """
+ # Remove image encoder
+ if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
+ self.image_encoder = None
+ self.register_to_config(image_encoder=None)
+
+ # Remove feature extractor
+ if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
+ self.feature_extractor = None
+ self.register_to_config(feature_extractor=None)
+
+ # Remove image projection
+ self.transformer.image_proj = None
+
+ # Restore the original attention processor layers
+ attn_procs = {
+ name: (
+ JointAttnProcessor2_0() if isinstance(value, SD3IPAdapterJointAttnProcessor2_0) else value.__class__()
+ )
+ for name, value in self.transformer.attn_processors.items()
+ }
+ self.transformer.set_attn_processor(attn_procs)
diff --git a/src/diffusers/loaders/ip_adapter/transformer_flux.py b/src/diffusers/loaders/ip_adapter/transformer_flux.py
new file mode 100644
index 000000000000..5d6e7c809b12
--- /dev/null
+++ b/src/diffusers/loaders/ip_adapter/transformer_flux.py
@@ -0,0 +1,168 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from contextlib import nullcontext
+
+from ...models.embeddings import ImageProjection, MultiIPAdapterImageProjection
+from ...models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
+from ...utils import is_accelerate_available, is_torch_version, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class FluxTransformer2DLoadersMixin:
+ """
+ Load layers into a [`FluxTransformer2DModel`].
+ """
+
+ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
+ if low_cpu_mem_usage:
+ if is_accelerate_available():
+ from accelerate import init_empty_weights
+
+ else:
+ low_cpu_mem_usage = False
+ logger.warning(
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+ " install accelerate\n```\n."
+ )
+
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+ raise NotImplementedError(
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+ " `low_cpu_mem_usage=False`."
+ )
+
+ updated_state_dict = {}
+ image_projection = None
+ init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
+
+ if "proj.weight" in state_dict:
+ # IP-Adapter
+ num_image_text_embeds = 4
+ if state_dict["proj.weight"].shape[0] == 65536:
+ num_image_text_embeds = 16
+ clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
+ cross_attention_dim = state_dict["proj.weight"].shape[0] // num_image_text_embeds
+
+ with init_context():
+ image_projection = ImageProjection(
+ cross_attention_dim=cross_attention_dim,
+ image_embed_dim=clip_embeddings_dim,
+ num_image_text_embeds=num_image_text_embeds,
+ )
+
+ for key, value in state_dict.items():
+ diffusers_name = key.replace("proj", "image_embeds")
+ updated_state_dict[diffusers_name] = value
+
+ if not low_cpu_mem_usage:
+ image_projection.load_state_dict(updated_state_dict, strict=True)
+ else:
+ device_map = {"": self.device}
+ load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
+
+ return image_projection
+
+ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
+ from ...models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
+
+ if low_cpu_mem_usage:
+ if is_accelerate_available():
+ from accelerate import init_empty_weights
+
+ else:
+ low_cpu_mem_usage = False
+ logger.warning(
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+ " install accelerate\n```\n."
+ )
+
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+ raise NotImplementedError(
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+ " `low_cpu_mem_usage=False`."
+ )
+
+ # set ip-adapter cross-attention processors & load state_dict
+ attn_procs = {}
+ key_id = 0
+ init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
+ for name in self.attn_processors.keys():
+ if name.startswith("single_transformer_blocks"):
+ attn_processor_class = self.attn_processors[name].__class__
+ attn_procs[name] = attn_processor_class()
+ else:
+ cross_attention_dim = self.config.joint_attention_dim
+ hidden_size = self.inner_dim
+ attn_processor_class = FluxIPAdapterJointAttnProcessor2_0
+ num_image_text_embeds = []
+ for state_dict in state_dicts:
+ if "proj.weight" in state_dict["image_proj"]:
+ num_image_text_embed = 4
+ if state_dict["image_proj"]["proj.weight"].shape[0] == 65536:
+ num_image_text_embed = 16
+ # IP-Adapter
+ num_image_text_embeds += [num_image_text_embed]
+
+ with init_context():
+ attn_procs[name] = attn_processor_class(
+ hidden_size=hidden_size,
+ cross_attention_dim=cross_attention_dim,
+ scale=1.0,
+ num_tokens=num_image_text_embeds,
+ dtype=self.dtype,
+ device=self.device,
+ )
+
+ value_dict = {}
+ for i, state_dict in enumerate(state_dicts):
+ value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
+ value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
+ value_dict.update({f"to_k_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_k_ip.bias"]})
+ value_dict.update({f"to_v_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_v_ip.bias"]})
+
+ if not low_cpu_mem_usage:
+ attn_procs[name].load_state_dict(value_dict)
+ else:
+ device_map = {"": self.device}
+ dtype = self.dtype
+ load_model_dict_into_meta(attn_procs[name], value_dict, device_map=device_map, dtype=dtype)
+
+ key_id += 1
+
+ return attn_procs
+
+ def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
+ if not isinstance(state_dicts, list):
+ state_dicts = [state_dicts]
+
+ self.encoder_hid_proj = None
+
+ attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
+ self.set_attn_processor(attn_procs)
+
+ image_projection_layers = []
+ for state_dict in state_dicts:
+ image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(
+ state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage
+ )
+ image_projection_layers.append(image_projection_layer)
+
+ self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
+ self.config.encoder_hid_dim_type = "ip_image_proj"
diff --git a/src/diffusers/loaders/ip_adapter/transformer_sd3.py b/src/diffusers/loaders/ip_adapter/transformer_sd3.py
new file mode 100644
index 000000000000..5911ec5903d5
--- /dev/null
+++ b/src/diffusers/loaders/ip_adapter/transformer_sd3.py
@@ -0,0 +1,170 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from contextlib import nullcontext
+from typing import Dict
+
+from ...models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
+from ...models.embeddings import IPAdapterTimeImageProjection
+from ...models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
+from ...utils import is_accelerate_available, is_torch_version, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class SD3Transformer2DLoadersMixin:
+ """Load IP-Adapters and LoRA layers into a `[SD3Transformer2DModel]`."""
+
+ def _convert_ip_adapter_attn_to_diffusers(
+ self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT
+ ) -> Dict:
+ if low_cpu_mem_usage:
+ if is_accelerate_available():
+ from accelerate import init_empty_weights
+
+ else:
+ low_cpu_mem_usage = False
+ logger.warning(
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+ " install accelerate\n```\n."
+ )
+
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+ raise NotImplementedError(
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+ " `low_cpu_mem_usage=False`."
+ )
+
+ # IP-Adapter cross attention parameters
+ hidden_size = self.config.attention_head_dim * self.config.num_attention_heads
+ ip_hidden_states_dim = self.config.attention_head_dim * self.config.num_attention_heads
+ timesteps_emb_dim = state_dict["0.norm_ip.linear.weight"].shape[1]
+
+ # Dict where key is transformer layer index, value is attention processor's state dict
+ # ip_adapter state dict keys example: "0.norm_ip.linear.weight"
+ layer_state_dict = {idx: {} for idx in range(len(self.attn_processors))}
+ for key, weights in state_dict.items():
+ idx, name = key.split(".", maxsplit=1)
+ layer_state_dict[int(idx)][name] = weights
+
+ # Create IP-Adapter attention processor & load state_dict
+ attn_procs = {}
+ init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
+ for idx, name in enumerate(self.attn_processors.keys()):
+ with init_context():
+ attn_procs[name] = SD3IPAdapterJointAttnProcessor2_0(
+ hidden_size=hidden_size,
+ ip_hidden_states_dim=ip_hidden_states_dim,
+ head_dim=self.config.attention_head_dim,
+ timesteps_emb_dim=timesteps_emb_dim,
+ )
+
+ if not low_cpu_mem_usage:
+ attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True)
+ else:
+ device_map = {"": self.device}
+ load_model_dict_into_meta(
+ attn_procs[name], layer_state_dict[idx], device_map=device_map, dtype=self.dtype
+ )
+
+ return attn_procs
+
+ def _convert_ip_adapter_image_proj_to_diffusers(
+ self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT
+ ) -> IPAdapterTimeImageProjection:
+ if low_cpu_mem_usage:
+ if is_accelerate_available():
+ from accelerate import init_empty_weights
+
+ else:
+ low_cpu_mem_usage = False
+ logger.warning(
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+ " install accelerate\n```\n."
+ )
+
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+ raise NotImplementedError(
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+ " `low_cpu_mem_usage=False`."
+ )
+
+ init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
+
+ # Convert to diffusers
+ updated_state_dict = {}
+ for key, value in state_dict.items():
+ # InstantX/SD3.5-Large-IP-Adapter
+ if key.startswith("layers."):
+ idx = key.split(".")[1]
+ key = key.replace(f"layers.{idx}.0.norm1", f"layers.{idx}.ln0")
+ key = key.replace(f"layers.{idx}.0.norm2", f"layers.{idx}.ln1")
+ key = key.replace(f"layers.{idx}.0.to_q", f"layers.{idx}.attn.to_q")
+ key = key.replace(f"layers.{idx}.0.to_kv", f"layers.{idx}.attn.to_kv")
+ key = key.replace(f"layers.{idx}.0.to_out", f"layers.{idx}.attn.to_out.0")
+ key = key.replace(f"layers.{idx}.1.0", f"layers.{idx}.adaln_norm")
+ key = key.replace(f"layers.{idx}.1.1", f"layers.{idx}.ff.net.0.proj")
+ key = key.replace(f"layers.{idx}.1.3", f"layers.{idx}.ff.net.2")
+ key = key.replace(f"layers.{idx}.2.1", f"layers.{idx}.adaln_proj")
+ updated_state_dict[key] = value
+
+ # Image projection parameters
+ embed_dim = updated_state_dict["proj_in.weight"].shape[1]
+ output_dim = updated_state_dict["proj_out.weight"].shape[0]
+ hidden_dim = updated_state_dict["proj_in.weight"].shape[0]
+ heads = updated_state_dict["layers.0.attn.to_q.weight"].shape[0] // 64
+ num_queries = updated_state_dict["latents"].shape[1]
+ timestep_in_dim = updated_state_dict["time_embedding.linear_1.weight"].shape[1]
+
+ # Image projection
+ with init_context():
+ image_proj = IPAdapterTimeImageProjection(
+ embed_dim=embed_dim,
+ output_dim=output_dim,
+ hidden_dim=hidden_dim,
+ heads=heads,
+ num_queries=num_queries,
+ timestep_in_dim=timestep_in_dim,
+ )
+
+ if not low_cpu_mem_usage:
+ image_proj.load_state_dict(updated_state_dict, strict=True)
+ else:
+ device_map = {"": self.device}
+ load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)
+
+ return image_proj
+
+ def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT) -> None:
+ """Sets IP-Adapter attention processors, image projection, and loads state_dict.
+
+ Args:
+ state_dict (`Dict`):
+ State dict with keys "ip_adapter", which contains parameters for attention processors, and
+ "image_proj", which contains parameters for image projection net.
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+ Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+ argument to `True` will raise an error.
+ """
+
+ attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dict["ip_adapter"], low_cpu_mem_usage)
+ self.set_attn_processor(attn_procs)
+
+ self.image_proj = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"], low_cpu_mem_usage)
diff --git a/src/diffusers/loaders/lora/__init__.py b/src/diffusers/loaders/lora/__init__.py
new file mode 100644
index 000000000000..a0f2111cfcab
--- /dev/null
+++ b/src/diffusers/loaders/lora/__init__.py
@@ -0,0 +1,25 @@
+from ...utils import is_peft_available, is_torch_available, is_transformers_available
+
+
+if is_torch_available():
+ from .lora_base import LoraBaseMixin
+
+ if is_transformers_available():
+ from .lora_pipeline import (
+ AmusedLoraLoaderMixin,
+ AuraFlowLoraLoaderMixin,
+ CogVideoXLoraLoaderMixin,
+ CogView4LoraLoaderMixin,
+ FluxLoraLoaderMixin,
+ HiDreamImageLoraLoaderMixin,
+ HunyuanVideoLoraLoaderMixin,
+ LoraLoaderMixin,
+ LTXVideoLoraLoaderMixin,
+ Lumina2LoraLoaderMixin,
+ Mochi1LoraLoaderMixin,
+ SanaLoraLoaderMixin,
+ SD3LoraLoaderMixin,
+ StableDiffusionLoraLoaderMixin,
+ StableDiffusionXLLoraLoaderMixin,
+ WanLoraLoaderMixin,
+ )
diff --git a/src/diffusers/loaders/lora/lora_base.py b/src/diffusers/loaders/lora/lora_base.py
new file mode 100644
index 000000000000..80445106015b
--- /dev/null
+++ b/src/diffusers/loaders/lora/lora_base.py
@@ -0,0 +1,935 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import inspect
+import os
+from pathlib import Path
+from typing import Callable, Dict, List, Optional, Union
+
+import safetensors
+import torch
+import torch.nn as nn
+from huggingface_hub import model_info
+from huggingface_hub.constants import HF_HUB_OFFLINE
+
+from ...models.modeling_utils import ModelMixin, load_state_dict
+from ...utils import (
+ USE_PEFT_BACKEND,
+ _get_model_file,
+ convert_state_dict_to_diffusers,
+ convert_state_dict_to_peft,
+ delete_adapter_layers,
+ deprecate,
+ get_adapter_name,
+ get_peft_kwargs,
+ is_accelerate_available,
+ is_peft_available,
+ is_peft_version,
+ is_transformers_available,
+ is_transformers_version,
+ logging,
+ recurse_remove_peft_layers,
+ scale_lora_layers,
+ set_adapter_layers,
+ set_weights_and_activate_adapters,
+)
+
+
+if is_transformers_available():
+ from transformers import PreTrainedModel
+
+ from ...models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
+
+if is_peft_available():
+ from peft.tuners.tuners_utils import BaseTunerLayer
+
+if is_accelerate_available():
+ from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
+
+logger = logging.get_logger(__name__)
+
+LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
+LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
+
+
+def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
+ """
+ Fuses LoRAs for the text encoder.
+
+ Args:
+ text_encoder (`torch.nn.Module`):
+ The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
+ attribute.
+ lora_scale (`float`, defaults to 1.0):
+ Controls how much to influence the outputs with the LoRA parameters.
+ safe_fusing (`bool`, defaults to `False`):
+ Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+ adapter_names (`List[str]` or `str`):
+ The names of the adapters to use.
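+
+ Example (a hedged sketch; assumes the text encoder already has a PEFT LoRA adapter injected, for example via
+ `pipeline.load_lora_weights(...)`):
+
+ ```py
+ fuse_text_encoder_lora(pipeline.text_encoder, lora_scale=0.8, safe_fusing=True)
+ ```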
+ """
+ merge_kwargs = {"safe_merge": safe_fusing}
+
+ for module in text_encoder.modules():
+ if isinstance(module, BaseTunerLayer):
+ if lora_scale != 1.0:
+ module.scale_layer(lora_scale)
+
+ # For BC with previous PEFT versions, we need to check the signature
+ # of the `merge` method to see if it supports the `adapter_names` argument.
+ supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
+ if "adapter_names" in supported_merge_kwargs:
+ merge_kwargs["adapter_names"] = adapter_names
+ elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
+ raise ValueError(
+ "The `adapter_names` argument is not supported with your PEFT version. "
+ "Please upgrade to the latest version of PEFT. `pip install -U peft`"
+ )
+
+ module.merge(**merge_kwargs)
+
+
+def unfuse_text_encoder_lora(text_encoder):
+ """
+ Unfuses LoRAs for the text encoder.
+
+ Args:
+ text_encoder (`torch.nn.Module`):
+ The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
+ attribute.
+ """
+ for module in text_encoder.modules():
+ if isinstance(module, BaseTunerLayer):
+ module.unmerge()
+
+
+def set_adapters_for_text_encoder(
+ adapter_names: Union[List[str], str],
+ text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821
+ text_encoder_weights: Optional[Union[float, List[float], List[None]]] = None,
+):
+ """
+ Sets the adapter layers for the text encoder.
+
+ Args:
+ adapter_names (`List[str]` or `str`):
+ The names of the adapters to use.
+ text_encoder (`torch.nn.Module`, *optional*):
+ The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
+ attribute.
+ text_encoder_weights (`List[float]`, *optional*):
+ The weights to use for the text encoder. If `None`, the weights are set to `1.0` for all the adapters.
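+
+ Example (a hedged sketch; the adapter names "toy" and "style" are assumed to already be loaded into the text
+ encoder):
+
+ ```py
+ set_adapters_for_text_encoder(["toy", "style"], text_encoder=pipeline.text_encoder, text_encoder_weights=[0.7, None])
+ ```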
+ """
+ if text_encoder is None:
+ raise ValueError(
+ "The pipeline does not have a default `pipe.text_encoder` class. Please make sure to pass a `text_encoder` instead."
+ )
+
+ def process_weights(adapter_names, weights):
+ # Expand weights into a list, one entry per adapter
+ # e.g. for 2 adapters: 7 -> [7,7] ; [3, None] -> [3, None]
+ if not isinstance(weights, list):
+ weights = [weights] * len(adapter_names)
+
+ if len(adapter_names) != len(weights):
+ raise ValueError(
+ f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}"
+ )
+
+ # Set None values to default of 1.0
+ # e.g. [7,7] -> [7,7] ; [3, None] -> [3,1]
+ weights = [w if w is not None else 1.0 for w in weights]
+
+ return weights
+
+ adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
+ text_encoder_weights = process_weights(adapter_names, text_encoder_weights)
+ set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights)
+
+
+def disable_lora_for_text_encoder(text_encoder: Optional["PreTrainedModel"] = None):
+ """
+ Disables the LoRA layers for the text encoder.
+
+ Args:
+ text_encoder (`torch.nn.Module`, *optional*):
+ The text encoder module to disable the LoRA layers for. If `None`, it will try to get the `text_encoder`
+ attribute.
+ """
+ if text_encoder is None:
+ raise ValueError("Text Encoder not found.")
+ set_adapter_layers(text_encoder, enabled=False)
+
+
+def enable_lora_for_text_encoder(text_encoder: Optional["PreTrainedModel"] = None):
+ """
+ Enables the LoRA layers for the text encoder.
+
+ Args:
+ text_encoder (`torch.nn.Module`, *optional*):
+ The text encoder module to enable the LoRA layers for. If `None`, it will try to get the `text_encoder`
+ attribute.
+ """
+ if text_encoder is None:
+ raise ValueError("Text Encoder not found.")
+ set_adapter_layers(text_encoder, enabled=True)
+
+
+def _remove_text_encoder_monkey_patch(text_encoder):
+ recurse_remove_peft_layers(text_encoder)
+ if getattr(text_encoder, "peft_config", None) is not None:
+ del text_encoder.peft_config
+ text_encoder._hf_peft_config_loaded = None
+
+
+def _fetch_state_dict(
+ pretrained_model_name_or_path_or_dict,
+ weight_name,
+ use_safetensors,
+ local_files_only,
+ cache_dir,
+ force_download,
+ proxies,
+ token,
+ revision,
+ subfolder,
+ user_agent,
+ allow_pickle,
+):
+ model_file = None
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+ # Let's first try to load .safetensors weights
+ if (use_safetensors and weight_name is None) or (
+ weight_name is not None and weight_name.endswith(".safetensors")
+ ):
+ try:
+ # Here we're relaxing the loading check to enable more Inference API
+ # friendliness where sometimes, it's not at all possible to automatically
+ # determine `weight_name`.
+ if weight_name is None:
+ weight_name = _best_guess_weight_name(
+ pretrained_model_name_or_path_or_dict,
+ file_extension=".safetensors",
+ local_files_only=local_files_only,
+ )
+ model_file = _get_model_file(
+ pretrained_model_name_or_path_or_dict,
+ weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ )
+ state_dict = safetensors.torch.load_file(model_file, device="cpu")
+ except (IOError, safetensors.SafetensorError) as e:
+ if not allow_pickle:
+ raise e
+ # try loading non-safetensors weights
+ model_file = None
+ pass
+
+ if model_file is None:
+ if weight_name is None:
+ weight_name = _best_guess_weight_name(
+ pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
+ )
+ model_file = _get_model_file(
+ pretrained_model_name_or_path_or_dict,
+ weights_name=weight_name or LORA_WEIGHT_NAME,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ )
+ state_dict = load_state_dict(model_file)
+ else:
+ state_dict = pretrained_model_name_or_path_or_dict
+
+ return state_dict
+
+
+def _best_guess_weight_name(
+ pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
+):
+ if local_files_only or HF_HUB_OFFLINE:
+ raise ValueError("When using the offline mode, you must specify a `weight_name`.")
+
+ targeted_files = []
+
+ if os.path.isfile(pretrained_model_name_or_path_or_dict):
+ return
+ elif os.path.isdir(pretrained_model_name_or_path_or_dict):
+ targeted_files = [f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)]
+ else:
+ files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings
+ targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]
+ if len(targeted_files) == 0:
+ return
+
+ # "scheduler" does not correspond to a LoRA checkpoint.
+ # "optimizer" does not correspond to a LoRA checkpoint
+ # only top-level checkpoints are considered and not the other ones, hence "checkpoint".
+ unallowed_substrings = {"scheduler", "optimizer", "checkpoint"}
+ targeted_files = list(
+ filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files)
+ )
+
+ if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files):
+ targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files))
+ elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files):
+ targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files))
+
+ if len(targeted_files) > 1:
+ raise ValueError(
+ f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
+ )
+ weight_name = targeted_files[0]
+ return weight_name
+
+
+def _load_lora_into_text_encoder(
+ state_dict,
+ network_alphas,
+ text_encoder,
+ prefix=None,
+ lora_scale=1.0,
+ text_encoder_name="text_encoder",
+ adapter_name=None,
+ _pipeline=None,
+ low_cpu_mem_usage=False,
+ hotswap: bool = False,
+):
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ peft_kwargs = {}
+ if low_cpu_mem_usage:
+ if not is_peft_version(">=", "0.13.1"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+ if not is_transformers_version(">", "4.45.2"):
+ # Note from sayakpaul: It's not in `transformers` stable yet.
+ # https://github.com/huggingface/transformers/pull/33725/
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
+ )
+ peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+
+ from peft import LoraConfig
+
+ # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
+ # then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as
+ # their prefixes.
+ prefix = text_encoder_name if prefix is None else prefix
+
+ # Safe prefix to check with.
+ if hotswap and any(text_encoder_name in key for key in state_dict.keys()):
+ raise ValueError("At the moment, hotswapping is not supported for text encoders, please pass `hotswap=False`.")
+
+ # Load the layers corresponding to text encoder and make necessary adjustments.
+ if prefix is not None:
+ state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
+
+ if len(state_dict) > 0:
+ logger.info(f"Loading {prefix}.")
+ rank = {}
+ state_dict = convert_state_dict_to_diffusers(state_dict)
+
+ # convert state dict
+ state_dict = convert_state_dict_to_peft(state_dict)
+
+ for name, _ in text_encoder_attn_modules(text_encoder):
+ for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
+ rank_key = f"{name}.{module}.lora_B.weight"
+ if rank_key not in state_dict:
+ continue
+ rank[rank_key] = state_dict[rank_key].shape[1]
+
+ for name, _ in text_encoder_mlp_modules(text_encoder):
+ for module in ("fc1", "fc2"):
+ rank_key = f"{name}.{module}.lora_B.weight"
+ if rank_key not in state_dict:
+ continue
+ rank[rank_key] = state_dict[rank_key].shape[1]
+
+ if network_alphas is not None:
+ alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
+ network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
+
+ lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
+
+ if "use_dora" in lora_config_kwargs:
+ if lora_config_kwargs["use_dora"]:
+ if is_peft_version("<", "0.9.0"):
+ raise ValueError(
+ "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+ )
+ else:
+ if is_peft_version("<", "0.9.0"):
+ lora_config_kwargs.pop("use_dora")
+
+ if "lora_bias" in lora_config_kwargs:
+ if lora_config_kwargs["lora_bias"]:
+ if is_peft_version("<=", "0.13.2"):
+ raise ValueError(
+ "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
+ )
+ else:
+ if is_peft_version("<=", "0.13.2"):
+ lora_config_kwargs.pop("lora_bias")
+
+ lora_config = LoraConfig(**lora_config_kwargs)
+
+ # adapter_name
+ if adapter_name is None:
+ adapter_name = get_adapter_name(text_encoder)
+
+ is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
+
+ # inject LoRA layers and load the state dict
+ # in transformers we automatically check whether the adapter name is already in use or not
+ text_encoder.load_adapter(
+ adapter_name=adapter_name,
+ adapter_state_dict=state_dict,
+ peft_config=lora_config,
+ **peft_kwargs,
+ )
+
+ # scale LoRA layers with `lora_scale`
+ scale_lora_layers(text_encoder, weight=lora_scale)
+
+ text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
+
+ # Offload back.
+ if is_model_cpu_offload:
+ _pipeline.enable_model_cpu_offload()
+ elif is_sequential_cpu_offload:
+ _pipeline.enable_sequential_cpu_offload()
+ # Unsafe code />
+
+ if prefix is not None and not state_dict:
+ logger.warning(
+ f"No LoRA keys associated to {text_encoder.__class__.__name__} found with the {prefix=}. "
+ "This is safe to ignore if LoRA state dict didn't originally have any "
+ f"{text_encoder.__class__.__name__} related params. You can also try specifying `prefix=None` "
+ "to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
+ "https://github.com/huggingface/diffusers/issues/new"
+ )
+
+
+def _func_optionally_disable_offloading(_pipeline):
+ is_model_cpu_offload = False
+ is_sequential_cpu_offload = False
+
+ if _pipeline is not None and _pipeline.hf_device_map is None:
+ for _, component in _pipeline.components.items():
+ if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
+ if not is_model_cpu_offload:
+ is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
+ if not is_sequential_cpu_offload:
+ is_sequential_cpu_offload = (
+ isinstance(component._hf_hook, AlignDevicesHook)
+ or hasattr(component._hf_hook, "hooks")
+ and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
+ )
+
+ logger.info(
+ "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+ )
+ remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+
+ return (is_model_cpu_offload, is_sequential_cpu_offload)
+
+
+class LoraBaseMixin:
+ """Utility class for handling LoRAs."""
+
+ _lora_loadable_modules = []
+ num_fused_loras = 0
+
+ def load_lora_weights(self, **kwargs):
+ raise NotImplementedError("`load_lora_weights()` is not implemented.")
+
+ @classmethod
+ def save_lora_weights(cls, **kwargs):
+        raise NotImplementedError("`save_lora_weights()` is not implemented.")
+
+ @classmethod
+ def lora_state_dict(cls, **kwargs):
+ raise NotImplementedError("`lora_state_dict()` is not implemented.")
+
+ @classmethod
+ def _optionally_disable_offloading(cls, _pipeline):
+ """
+        Optionally removes offloading in case the pipeline has already been sequentially offloaded to CPU.
+
+ Args:
+ _pipeline (`DiffusionPipeline`):
+ The pipeline to disable offloading for.
+
+ Returns:
+ tuple:
+ A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
+ """
+ return _func_optionally_disable_offloading(_pipeline=_pipeline)
+
+ @classmethod
+ def _fetch_state_dict(cls, *args, **kwargs):
+ deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
+ deprecate("_fetch_state_dict", "0.35.0", deprecation_message)
+ return _fetch_state_dict(*args, **kwargs)
+
+ @classmethod
+ def _best_guess_weight_name(cls, *args, **kwargs):
+ deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
+ deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
+ return _best_guess_weight_name(*args, **kwargs)
+
+ def unload_lora_weights(self):
+ """
+ Unloads the LoRA parameters.
+
+ Examples:
+
+ ```python
+ >>> # Assuming `pipeline` is already loaded with the LoRA parameters.
+ >>> pipeline.unload_lora_weights()
+ >>> ...
+ ```
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ for component in self._lora_loadable_modules:
+ model = getattr(self, component, None)
+ if model is not None:
+ if issubclass(model.__class__, ModelMixin):
+ model.unload_lora()
+ elif issubclass(model.__class__, PreTrainedModel):
+ _remove_text_encoder_monkey_patch(model)
+
+ def fuse_lora(
+ self,
+ components: List[str] = [],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+ Args:
+            components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+            lora_scale (`float`, defaults to 1.0):
+                Controls how much to influence the outputs with the LoRA parameters.
+            safe_fusing (`bool`, defaults to `False`):
+                Whether to check fused weights for NaN values before fusing and, if NaN values are present, skip
+                fusing them.
+ adapter_names (`List[str]`, *optional*):
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+ Example:
+
+ ```py
+ from diffusers import DiffusionPipeline
+ import torch
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+ pipeline.fuse_lora(lora_scale=0.7)
+ ```
+ """
+ if "fuse_unet" in kwargs:
+ depr_message = "Passing `fuse_unet` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_unet` will be removed in a future version."
+ deprecate(
+ "fuse_unet",
+ "1.0.0",
+ depr_message,
+ )
+ if "fuse_transformer" in kwargs:
+ depr_message = "Passing `fuse_transformer` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_transformer` will be removed in a future version."
+ deprecate(
+ "fuse_transformer",
+ "1.0.0",
+ depr_message,
+ )
+ if "fuse_text_encoder" in kwargs:
+ depr_message = "Passing `fuse_text_encoder` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_text_encoder` will be removed in a future version."
+ deprecate(
+ "fuse_text_encoder",
+ "1.0.0",
+ depr_message,
+ )
+
+ if len(components) == 0:
+ raise ValueError("`components` cannot be an empty list.")
+
+ for fuse_component in components:
+ if fuse_component not in self._lora_loadable_modules:
+ raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.")
+
+ model = getattr(self, fuse_component, None)
+ if model is not None:
+ # check if diffusers model
+ if issubclass(model.__class__, ModelMixin):
+ model.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
+ # handle transformers models.
+ if issubclass(model.__class__, PreTrainedModel):
+ fuse_text_encoder_lora(
+ model, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+ )
+
+ self.num_fused_loras += 1
+
+ def unfuse_lora(self, components: List[str] = [], **kwargs):
+ r"""
+ Reverses the effect of
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+            unfuse_unet (`bool`, defaults to `True`):
+                *Deprecated*; pass the component through `components` instead. Whether to unfuse the UNet LoRA
+                parameters.
+            unfuse_text_encoder (`bool`, defaults to `True`):
+                *Deprecated*; pass the component through `components` instead. Whether to unfuse the text encoder
+                LoRA parameters. If the text encoder wasn't monkey-patched with the LoRA parameters then it won't
+                have any effect.
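+
+        Example (a minimal sketch; assumes a LoRA was already loaded and fused into the denoiser beforehand):
+
+        ```py
+        pipeline.fuse_lora(components=["unet"], lora_scale=0.7)
+        # ... run inference with the fused weights ...
+        pipeline.unfuse_lora(components=["unet"])
+        ```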
+ """
+ if "unfuse_unet" in kwargs:
+ depr_message = "Passing `unfuse_unet` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_unet` will be removed in a future version."
+ deprecate(
+ "unfuse_unet",
+ "1.0.0",
+ depr_message,
+ )
+ if "unfuse_transformer" in kwargs:
+ depr_message = "Passing `unfuse_transformer` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_transformer` will be removed in a future version."
+ deprecate(
+ "unfuse_transformer",
+ "1.0.0",
+ depr_message,
+ )
+ if "unfuse_text_encoder" in kwargs:
+ depr_message = "Passing `unfuse_text_encoder` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_text_encoder` will be removed in a future version."
+ deprecate(
+ "unfuse_text_encoder",
+ "1.0.0",
+ depr_message,
+ )
+
+ if len(components) == 0:
+ raise ValueError("`components` cannot be an empty list.")
+
+ for fuse_component in components:
+ if fuse_component not in self._lora_loadable_modules:
+ raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.")
+
+ model = getattr(self, fuse_component, None)
+ if model is not None:
+ if issubclass(model.__class__, (ModelMixin, PreTrainedModel)):
+ for module in model.modules():
+ if isinstance(module, BaseTunerLayer):
+ module.unmerge()
+
+ self.num_fused_loras -= 1
+
+ def set_adapters(
+ self,
+ adapter_names: Union[List[str], str],
+ adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
+ ):
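+        """
+        Sets the currently active adapters and, optionally, their weights for every LoRA-loadable component.
+
+        Example (a minimal sketch; assumes adapters named "toy" and "pixel" were loaded into the pipeline earlier):
+
+        ```python
+        pipeline.set_adapters(["toy", "pixel"], adapter_weights=[0.7, 0.5])
+        ```
+        """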
+ if isinstance(adapter_weights, dict):
+ components_passed = set(adapter_weights.keys())
+ lora_components = set(self._lora_loadable_modules)
+
+ invalid_components = sorted(components_passed - lora_components)
+ if invalid_components:
+ logger.warning(
+ f"The following components in `adapter_weights` are not part of the pipeline: {invalid_components}. "
+ f"Available components that are LoRA-compatible: {self._lora_loadable_modules}. So, weights belonging "
+ "to the invalid components will be removed and ignored."
+ )
+ adapter_weights = {k: v for k, v in adapter_weights.items() if k not in invalid_components}
+
+ adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
+ adapter_weights = copy.deepcopy(adapter_weights)
+
+ # Expand weights into a list, one entry per adapter
+ if not isinstance(adapter_weights, list):
+ adapter_weights = [adapter_weights] * len(adapter_names)
+
+ if len(adapter_names) != len(adapter_weights):
+ raise ValueError(
+ f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(adapter_weights)}"
+ )
+
+ list_adapters = self.get_list_adapters() # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
+ # eg ["adapter1", "adapter2"]
+ all_adapters = {adapter for adapters in list_adapters.values() for adapter in adapters}
+ missing_adapters = set(adapter_names) - all_adapters
+ if len(missing_adapters) > 0:
+ raise ValueError(
+ f"Adapter name(s) {missing_adapters} not in the list of present adapters: {all_adapters}."
+ )
+
+ # eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}
+ invert_list_adapters = {
+ adapter: [part for part, adapters in list_adapters.items() if adapter in adapters]
+ for adapter in all_adapters
+ }
+
+ # Decompose weights into weights for denoiser and text encoders.
+ _component_adapter_weights = {}
+ for component in self._lora_loadable_modules:
+ model = getattr(self, component)
+
+ for adapter_name, weights in zip(adapter_names, adapter_weights):
+ if isinstance(weights, dict):
+ component_adapter_weights = weights.pop(component, None)
+ if component_adapter_weights is not None and component not in invert_list_adapters[adapter_name]:
+ logger.warning(
+ (
+                                f"Lora weight dict for adapter '{adapter_name}' contains {component}, "
+                                f"but this will be ignored because {adapter_name} does not contain weights for {component}. "
+                                f"Valid parts for {adapter_name} are: {invert_list_adapters[adapter_name]}."
+ )
+ )
+
+ else:
+ component_adapter_weights = weights
+
+ _component_adapter_weights.setdefault(component, [])
+ _component_adapter_weights[component].append(component_adapter_weights)
+
+ if issubclass(model.__class__, ModelMixin):
+ model.set_adapters(adapter_names, _component_adapter_weights[component])
+ elif issubclass(model.__class__, PreTrainedModel):
+ set_adapters_for_text_encoder(adapter_names, model, _component_adapter_weights[component])
+
+ def disable_lora(self):
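+        """
+        Disables the loaded LoRA layers of the pipeline without unloading them.
+
+        Example (a minimal sketch; assumes a LoRA was already loaded into the pipeline):
+
+        ```python
+        pipeline.disable_lora()  # run inference without the LoRA influence
+        pipeline.enable_lora()  # turn the loaded LoRA layers back on
+        ```
+        """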
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ for component in self._lora_loadable_modules:
+ model = getattr(self, component, None)
+ if model is not None:
+ if issubclass(model.__class__, ModelMixin):
+ model.disable_lora()
+ elif issubclass(model.__class__, PreTrainedModel):
+ disable_lora_for_text_encoder(model)
+
+ def enable_lora(self):
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ for component in self._lora_loadable_modules:
+ model = getattr(self, component, None)
+ if model is not None:
+ if issubclass(model.__class__, ModelMixin):
+ model.enable_lora()
+ elif issubclass(model.__class__, PreTrainedModel):
+ enable_lora_for_text_encoder(model)
+
+ def delete_adapters(self, adapter_names: Union[List[str], str]):
+ """
+        Deletes the LoRA layers of `adapter_names` from the unet and text encoder(s).
+
+        Args:
+            adapter_names (`Union[List[str], str]`):
+                The names of the adapters to delete. Can be a single string or a list of strings.
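+
+        Example (a minimal sketch; assumes adapters named "toy" and "pixel" were loaded earlier):
+
+        ```python
+        pipeline.delete_adapters(["toy", "pixel"])
+        ```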
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ if isinstance(adapter_names, str):
+ adapter_names = [adapter_names]
+
+ for component in self._lora_loadable_modules:
+ model = getattr(self, component, None)
+ if model is not None:
+ if issubclass(model.__class__, ModelMixin):
+ model.delete_adapters(adapter_names)
+ elif issubclass(model.__class__, PreTrainedModel):
+ for adapter_name in adapter_names:
+ delete_adapter_layers(model, adapter_name)
+
+ def get_active_adapters(self) -> List[str]:
+ """
+ Gets the list of the current active adapters.
+
+ Example:
+
+ ```python
+ from diffusers import DiffusionPipeline
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ ).to("cuda")
+ pipeline.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
+ pipeline.get_active_adapters()
+ ```
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError(
+                "PEFT backend is required for this method. Please install the latest version of PEFT with `pip install -U peft`."
+ )
+
+ active_adapters = []
+
+ for component in self._lora_loadable_modules:
+ model = getattr(self, component, None)
+ if model is not None and issubclass(model.__class__, ModelMixin):
+ for module in model.modules():
+ if isinstance(module, BaseTunerLayer):
+ active_adapters = module.active_adapters
+ break
+
+ return active_adapters
+
+ def get_list_adapters(self) -> Dict[str, List[str]]:
+ """
+ Gets the current list of all available adapters in the pipeline.
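+
+        Example (a minimal sketch; assumes two adapters were loaded into an SDXL-style pipeline earlier):
+
+        ```python
+        pipeline.get_list_adapters()
+        # e.g. {"unet": ["toy", "pixel"], "text_encoder": ["toy"]}
+        ```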
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError(
+                "PEFT backend is required for this method. Please install the latest version of PEFT with `pip install -U peft`."
+ )
+
+ set_adapters = {}
+
+ for component in self._lora_loadable_modules:
+ model = getattr(self, component, None)
+ if (
+ model is not None
+ and issubclass(model.__class__, (ModelMixin, PreTrainedModel))
+ and hasattr(model, "peft_config")
+ ):
+ set_adapters[component] = list(model.peft_config.keys())
+
+ return set_adapters
+
+ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, str, int]) -> None:
+ """
+ Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case
+ you want to load multiple adapters and free some GPU memory.
+
+ Args:
+ adapter_names (`List[str]`):
+                List of adapters to send to the device.
+ device (`Union[torch.device, str, int]`):
+ Device to send the adapters to. Can be either a torch device, a str or an integer.
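+
+        Example (a minimal sketch; assumes an adapter named "toy" has already been loaded into the pipeline):
+
+        ```python
+        # Park the "toy" LoRA weights on the CPU to free GPU memory, then move them back before inference.
+        pipeline.set_lora_device(["toy"], device="cpu")
+        pipeline.set_lora_device(["toy"], device="cuda")
+        ```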
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ for component in self._lora_loadable_modules:
+ model = getattr(self, component, None)
+ if model is not None:
+ for module in model.modules():
+ if isinstance(module, BaseTunerLayer):
+ for adapter_name in adapter_names:
+ module.lora_A[adapter_name].to(device)
+ module.lora_B[adapter_name].to(device)
+ # this is a param, not a module, so device placement is not in-place -> re-assign
+ if hasattr(module, "lora_magnitude_vector") and module.lora_magnitude_vector is not None:
+ if adapter_name in module.lora_magnitude_vector:
+ module.lora_magnitude_vector[adapter_name] = module.lora_magnitude_vector[
+ adapter_name
+ ].to(device)
+
+ @staticmethod
+ def pack_weights(layers, prefix):
+ layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
+ layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
+ return layers_state_dict
+
+ @staticmethod
+ def write_lora_layers(
+ state_dict: Dict[str, torch.Tensor],
+ save_directory: str,
+ is_main_process: bool,
+ weight_name: str,
+ save_function: Callable,
+ safe_serialization: bool,
+ ):
+ if os.path.isfile(save_directory):
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
+ return
+
+ if save_function is None:
+ if safe_serialization:
+
+ def save_function(weights, filename):
+ return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+
+ else:
+ save_function = torch.save
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ if weight_name is None:
+ if safe_serialization:
+ weight_name = LORA_WEIGHT_NAME_SAFE
+ else:
+ weight_name = LORA_WEIGHT_NAME
+
+ save_path = Path(save_directory, weight_name).as_posix()
+ save_function(state_dict, save_path)
+ logger.info(f"Model weights saved in {save_path}")
+
+ @property
+ def lora_scale(self) -> float:
+ # property function that returns the lora scale which can be set at run time by the pipeline.
+ # if _lora_scale has not been set, return 1
+ return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
+
+ def enable_lora_hotswap(self, **kwargs) -> None:
+        """Enables hotswapping of LoRA adapters.
+
+ Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of
+ the loaded adapters differ.
+
+ Args:
+ target_rank (`int`):
+ The highest rank among all the adapters that will be loaded.
+ check_compiled (`str`, *optional*, defaults to `"error"`):
+ How to handle the case when the model is already compiled, which should generally be avoided. The
+ options are:
+ - "error" (default): raise an error
+ - "warn": issue a warning
+ - "ignore": do nothing
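+
+        Example (a minimal sketch; the file names, adapter name, and rank value are placeholders):
+
+        ```python
+        # Call before loading the first adapter and before compiling the model.
+        pipeline.enable_lora_hotswap(target_rank=64)
+        pipeline.load_lora_weights("first_lora.safetensors", adapter_name="default_0")
+        # ... optionally compile the denoiser here ...
+        pipeline.load_lora_weights("second_lora.safetensors", adapter_name="default_0", hotswap=True)
+        ```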
+ """
+ for key, component in self.components.items():
+ if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules):
+ component.enable_lora_hotswap(**kwargs)
diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora/lora_conversion_utils.py
similarity index 99%
rename from src/diffusers/loaders/lora_conversion_utils.py
rename to src/diffusers/loaders/lora/lora_conversion_utils.py
index d0c9611735ce..3c5afdcf6523 100644
--- a/src/diffusers/loaders/lora_conversion_utils.py
+++ b/src/diffusers/loaders/lora/lora_conversion_utils.py
@@ -17,7 +17,7 @@
import torch
-from ..utils import is_peft_version, logging, state_dict_all_zero
+from ...utils import is_peft_version, logging, state_dict_all_zero
logger = logging.get_logger(__name__)
diff --git a/src/diffusers/loaders/lora/lora_pipeline.py b/src/diffusers/loaders/lora/lora_pipeline.py
new file mode 100644
index 000000000000..27705041408d
--- /dev/null
+++ b/src/diffusers/loaders/lora/lora_pipeline.py
@@ -0,0 +1,5686 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from typing import Callable, Dict, List, Optional, Union
+
+import torch
+from huggingface_hub.utils import validate_hf_hub_args
+
+from ...utils import (
+ USE_PEFT_BACKEND,
+ deprecate,
+ get_submodule_by_name,
+ is_bitsandbytes_available,
+ is_gguf_available,
+ is_peft_available,
+ is_peft_version,
+ is_torch_version,
+ is_transformers_available,
+ is_transformers_version,
+ logging,
+)
+from .lora_base import ( # noqa
+ LORA_WEIGHT_NAME,
+ LORA_WEIGHT_NAME_SAFE,
+ LoraBaseMixin,
+ _fetch_state_dict,
+ _load_lora_into_text_encoder,
+)
+from .lora_conversion_utils import (
+ _convert_bfl_flux_control_lora_to_diffusers,
+ _convert_hunyuan_video_lora_to_diffusers,
+ _convert_kohya_flux_lora_to_diffusers,
+ _convert_musubi_wan_lora_to_diffusers,
+ _convert_non_diffusers_lora_to_diffusers,
+ _convert_non_diffusers_lumina2_lora_to_diffusers,
+ _convert_non_diffusers_wan_lora_to_diffusers,
+ _convert_xlabs_flux_lora_to_diffusers,
+ _maybe_map_sgm_blocks_to_diffusers,
+)
+
+
+_LOW_CPU_MEM_USAGE_DEFAULT_LORA = False
+if is_torch_version(">=", "1.9.0"):
+ if (
+ is_peft_available()
+ and is_peft_version(">=", "0.13.1")
+ and is_transformers_available()
+ and is_transformers_version(">", "4.45.2")
+ ):
+ _LOW_CPU_MEM_USAGE_DEFAULT_LORA = True
+
+
+logger = logging.get_logger(__name__)
+
+TEXT_ENCODER_NAME = "text_encoder"
+UNET_NAME = "unet"
+TRANSFORMER_NAME = "transformer"
+
+_MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX = {"x_embedder": "in_channels"}
+
+
+def _maybe_dequantize_weight_for_expanded_lora(model, module):
+ if is_bitsandbytes_available():
+ from ...quantizers.bitsandbytes import dequantize_bnb_weight
+
+ if is_gguf_available():
+ from ...quantizers.gguf.utils import dequantize_gguf_tensor
+
+ is_bnb_4bit_quantized = module.weight.__class__.__name__ == "Params4bit"
+ is_gguf_quantized = module.weight.__class__.__name__ == "GGUFParameter"
+
+ if is_bnb_4bit_quantized and not is_bitsandbytes_available():
+ raise ValueError(
+ "The checkpoint seems to have been quantized with `bitsandbytes` (4bits). Install `bitsandbytes` to load quantized checkpoints."
+ )
+ if is_gguf_quantized and not is_gguf_available():
+ raise ValueError(
+ "The checkpoint seems to have been quantized with `gguf`. Install `gguf` to load quantized checkpoints."
+ )
+
+ weight_on_cpu = False
+ if not module.weight.is_cuda:
+ weight_on_cpu = True
+
+ if is_bnb_4bit_quantized:
+ module_weight = dequantize_bnb_weight(
+ module.weight.cuda() if weight_on_cpu else module.weight,
+ state=module.weight.quant_state,
+ dtype=model.dtype,
+ ).data
+ elif is_gguf_quantized:
+ module_weight = dequantize_gguf_tensor(
+ module.weight.cuda() if weight_on_cpu else module.weight,
+ )
+ module_weight = module_weight.to(model.dtype)
+ else:
+ module_weight = module.weight.data
+
+ if weight_on_cpu:
+ module_weight = module_weight.cpu()
+
+ return module_weight
+
+
+class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
+ r"""
+ Load LoRA layers into Stable Diffusion [`UNet2DConditionModel`] and
+ [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
+ """
+
+ _lora_loadable_modules = ["unet", "text_encoder"]
+ unet_name = UNET_NAME
+ text_encoder_name = TEXT_ENCODER_NAME
+
+ def load_lora_weights(
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ adapter_name: Optional[str] = None,
+ hotswap: bool = False,
+ **kwargs,
+ ):
+ """Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
+ `self.text_encoder`.
+
+ All kwargs are forwarded to `self.lora_state_dict`.
+
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
+ loaded.
+
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is
+ loaded into `self.unet`.
+
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state
+ dict is loaded into `self.text_encoder`.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ hotswap (`bool`, *optional*):
+ Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
+ in-place. This means that, instead of loading an additional adapter, this will take the existing
+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
+ torch.compile, loading the new adapter does not require recompilation of the model. When using
+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
+
+ If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
+ to call an additional method before loading the adapter:
+
+ ```py
+ pipeline = ... # load diffusers pipeline
+ max_rank = ... # the highest rank among all LoRAs that you want to load
+ # call *before* compiling and loading the LoRA adapter
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
+ pipeline.load_lora_weights(file_name)
+ # optionally compile the model now
+ ```
+
+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
+ limitations to this technique, which are documented here:
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
+ kwargs (`dict`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
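+
+        Example (a minimal usage sketch; the LoRA repository and file name are placeholders):
+
+        ```py
+        from diffusers import StableDiffusionPipeline
+        import torch
+
+        pipeline = StableDiffusionPipeline.from_pretrained(
+            "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights("path/to/lora", weight_name="lora.safetensors", adapter_name="my_lora")
+        ```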
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # if a dict is passed, copy it instead of modifying it inplace
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+ is_correct_format = all("lora" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ self.load_lora_into_unet(
+ state_dict,
+ network_alphas=network_alphas,
+ unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+ self.load_lora_into_text_encoder(
+ state_dict,
+ network_alphas=network_alphas,
+ text_encoder=getattr(self, self.text_encoder_name)
+ if not hasattr(self, "text_encoder")
+ else self.text_encoder,
+ lora_scale=self.lora_scale,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ @validate_hf_hub_args
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ Return state dict for lora weights and the network alphas.
+
+        <Tip warning={true}>
+
+        We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+        This function is experimental and might change in the future.
+
+        </Tip>
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+ weight_name (`str`, *optional*, defaults to None):
+ Name of the serialized state dict file.
+ """
+ # Load the main state dict first which has the LoRA layers for either of
+ # UNet and text encoder or both.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ unet_config = kwargs.pop("unet_config", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+
+ state_dict = _fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+ network_alphas = None
+ # TODO: replace it with a method from `state_dict_utils`
+ if all(
+ (
+ k.startswith("lora_te_")
+ or k.startswith("lora_unet_")
+ or k.startswith("lora_te1_")
+ or k.startswith("lora_te2_")
+ )
+ for k in state_dict.keys()
+ ):
+ # Map SDXL blocks correctly.
+ if unet_config is not None:
+ # use unet config to remap block numbers
+ state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
+ state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)
+
+ return state_dict, network_alphas
+
+ @classmethod
+ def load_lora_into_unet(
+ cls,
+ state_dict,
+ network_alphas,
+ unet,
+ adapter_name=None,
+ _pipeline=None,
+ low_cpu_mem_usage=False,
+ hotswap: bool = False,
+ ):
+ """
+ This will load the LoRA layers specified in `state_dict` into `unet`.
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+                into the unet or prefixed with an additional `unet`, which can be used to distinguish them from the
+                text encoder lora layers.
+ network_alphas (`Dict[str, float]`):
+ The value of the network alpha used for stable learning and preventing underflow. This value has the
+ same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+ link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
+ unet (`UNet2DConditionModel`):
+ The UNet model to load the LoRA layers into.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
+ # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
+ # their prefixes.
+ logger.info(f"Loading {cls.unet_name}.")
+ unet.load_lora_adapter(
+ state_dict,
+ prefix=cls.unet_name,
+ network_alphas=network_alphas,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ def load_lora_into_text_encoder(
+ cls,
+ state_dict,
+ network_alphas,
+ text_encoder,
+ prefix=None,
+ lora_scale=1.0,
+ adapter_name=None,
+ _pipeline=None,
+ low_cpu_mem_usage=False,
+ hotswap: bool = False,
+ ):
+ """
+ This will load the LoRA layers specified in `state_dict` into `text_encoder`
+
+ Parameters:
+ state_dict (`dict`):
+                A standard state dict containing the lora layer parameters. The keys should be prefixed with an
+                additional `text_encoder` to distinguish them from the unet lora layers.
+ network_alphas (`Dict[str, float]`):
+ The value of the network alpha used for stable learning and preventing underflow. This value has the
+ same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+ link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
+ text_encoder (`CLIPTextModel`):
+ The text encoder model to load the LoRA layers into.
+ prefix (`str`):
+ Expected prefix of the `text_encoder` in the `state_dict`.
+ lora_scale (`float`):
+                How much to scale the output of the lora linear layer before it is added to the output of the regular
+                linear layer.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+ """
+ _load_lora_into_text_encoder(
+ state_dict=state_dict,
+ network_alphas=network_alphas,
+ lora_scale=lora_scale,
+ text_encoder=text_encoder,
+ prefix=prefix,
+ text_encoder_name=cls.text_encoder_name,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ r"""
+ Save the LoRA parameters corresponding to the UNet and text encoder.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
+ unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `unet`.
+ text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
+ encoder LoRA state dict because it comes from 🤗 Transformers.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful during distributed training when
+                you need to call this function on all processes. In this case, set `is_main_process=True` only on the
+                main process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
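+
+        Example (a minimal sketch; `unet_lora_state_dict` and `text_encoder_lora_state_dict` are assumed to come from
+        a training loop):
+
+        ```py
+        from diffusers import StableDiffusionPipeline
+
+        StableDiffusionPipeline.save_lora_weights(
+            save_directory="my-sd-lora",
+            unet_lora_layers=unet_lora_state_dict,
+            text_encoder_lora_layers=text_encoder_lora_state_dict,
+        )
+        ```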
+ """
+ state_dict = {}
+
+ if not (unet_lora_layers or text_encoder_lora_layers):
+ raise ValueError("You must pass at least one of `unet_lora_layers` and `text_encoder_lora_layers`.")
+
+ if unet_lora_layers:
+ state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name))
+
+ if text_encoder_lora_layers:
+ state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
+
+ # Save the model
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ def fuse_lora(
+ self,
+ components: List[str] = ["unet", "text_encoder"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+ Args:
+            components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+            lora_scale (`float`, defaults to 1.0):
+                Controls how much to influence the outputs with the LoRA parameters.
+            safe_fusing (`bool`, defaults to `False`):
+                Whether to check fused weights for NaN values before fusing and, if NaN values are present, skip
+                fusing them.
+ adapter_names (`List[str]`, *optional*):
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+ Example:
+
+ ```py
+ from diffusers import DiffusionPipeline
+ import torch
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+ pipeline.fuse_lora(lora_scale=0.7)
+ ```
+ """
+ super().fuse_lora(
+ components=components,
+ lora_scale=lora_scale,
+ safe_fusing=safe_fusing,
+ adapter_names=adapter_names,
+ **kwargs,
+ )
+
+ def unfuse_lora(self, components: List[str] = ["unet", "text_encoder"], **kwargs):
+ r"""
+ Reverses the effect of
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+            unfuse_unet (`bool`, defaults to `True`):
+                *Deprecated*; pass the component through `components` instead. Whether to unfuse the UNet LoRA
+                parameters.
+            unfuse_text_encoder (`bool`, defaults to `True`):
+                *Deprecated*; pass the component through `components` instead. Whether to unfuse the text encoder
+                LoRA parameters. If the text encoder wasn't monkey-patched with the LoRA parameters then it won't
+                have any effect.
+ """
+ super().unfuse_lora(components=components, **kwargs)
+
+
+class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
+ r"""
+ Load LoRA layers into Stable Diffusion XL [`UNet2DConditionModel`],
+ [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), and
+ [`CLIPTextModelWithProjection`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection).
+ """
+
+ _lora_loadable_modules = ["unet", "text_encoder", "text_encoder_2"]
+ unet_name = UNET_NAME
+ text_encoder_name = TEXT_ENCODER_NAME
+
+ def load_lora_weights(
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ adapter_name: Optional[str] = None,
+ hotswap: bool = False,
+ **kwargs,
+ ):
+ """
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
+ `self.text_encoder`.
+
+ All kwargs are forwarded to `self.lora_state_dict`.
+
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
+ loaded.
+
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is
+ loaded into `self.unet`.
+
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state
+ dict is loaded into `self.text_encoder`.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+ kwargs (`dict`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
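+
+        Example (a minimal usage sketch; the checkpoints mirror the ones used elsewhere in this file):
+
+        ```py
+        from diffusers import StableDiffusionXLPipeline
+        import torch
+
+        pipeline = StableDiffusionXLPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+        ```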
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # We could have accessed the unet config from `lora_state_dict()` too. We pass
+ # it here explicitly to be able to tell that it's coming from an SDXL
+ # pipeline.
+
+ # if a dict is passed, copy it instead of modifying it inplace
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ state_dict, network_alphas = self.lora_state_dict(
+ pretrained_model_name_or_path_or_dict,
+ unet_config=self.unet.config,
+ **kwargs,
+ )
+
+ is_correct_format = all("lora" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ self.load_lora_into_unet(
+ state_dict,
+ network_alphas=network_alphas,
+ unet=self.unet,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+ self.load_lora_into_text_encoder(
+ state_dict,
+ network_alphas=network_alphas,
+ text_encoder=self.text_encoder,
+ prefix=self.text_encoder_name,
+ lora_scale=self.lora_scale,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+ self.load_lora_into_text_encoder(
+ state_dict,
+ network_alphas=network_alphas,
+ text_encoder=self.text_encoder_2,
+ prefix=f"{self.text_encoder_name}_2",
+ lora_scale=self.lora_scale,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ @validate_hf_hub_args
+ # Copied from diffusers.loaders.lora.lora_pipeline.StableDiffusionLoraLoaderMixin.lora_state_dict
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ Return state dict for lora weights and the network alphas.
+
+        <Tip warning={true}>
+
+        We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+        This function is experimental and might change in the future.
+
+        </Tip>
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+ weight_name (`str`, *optional*, defaults to None):
+ Name of the serialized state dict file.
+ """
+ # Load the main state dict first which has the LoRA layers for either of
+ # UNet and text encoder or both.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ unet_config = kwargs.pop("unet_config", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+
+ state_dict = _fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+ network_alphas = None
+ # TODO: replace it with a method from `state_dict_utils`
+ if all(
+ (
+ k.startswith("lora_te_")
+ or k.startswith("lora_unet_")
+ or k.startswith("lora_te1_")
+ or k.startswith("lora_te2_")
+ )
+ for k in state_dict.keys()
+ ):
+ # Map SDXL blocks correctly.
+ if unet_config is not None:
+ # use unet config to remap block numbers
+ state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
+ state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)
+
+ return state_dict, network_alphas
+
+ @classmethod
+ # Copied from diffusers.loaders.lora.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_unet
+ def load_lora_into_unet(
+ cls,
+ state_dict,
+ network_alphas,
+ unet,
+ adapter_name=None,
+ _pipeline=None,
+ low_cpu_mem_usage=False,
+ hotswap: bool = False,
+ ):
+ """
+ This will load the LoRA layers specified in `state_dict` into `unet`.
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+                into the unet or prefixed with an additional `unet`, which can be used to distinguish them from the
+                text encoder lora layers.
+ network_alphas (`Dict[str, float]`):
+ The value of the network alpha used for stable learning and preventing underflow. This value has the
+ same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+ link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
+ unet (`UNet2DConditionModel`):
+ The UNet model to load the LoRA layers into.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
+ # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
+ # their prefixes.
+ logger.info(f"Loading {cls.unet_name}.")
+ unet.load_lora_adapter(
+ state_dict,
+ prefix=cls.unet_name,
+ network_alphas=network_alphas,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
+ def load_lora_into_text_encoder(
+ cls,
+ state_dict,
+ network_alphas,
+ text_encoder,
+ prefix=None,
+ lora_scale=1.0,
+ adapter_name=None,
+ _pipeline=None,
+ low_cpu_mem_usage=False,
+ hotswap: bool = False,
+ ):
+ """
+ This will load the LoRA layers specified in `state_dict` into `text_encoder`
+
+ Parameters:
+ state_dict (`dict`):
+                A standard state dict containing the lora layer parameters. The keys should be prefixed with an
+                additional `text_encoder` to distinguish them from the unet lora layers.
+ network_alphas (`Dict[str, float]`):
+ The value of the network alpha used for stable learning and preventing underflow. This value has the
+ same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+ link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
+ text_encoder (`CLIPTextModel`):
+ The text encoder model to load the LoRA layers into.
+ prefix (`str`):
+ Expected prefix of the `text_encoder` in the `state_dict`.
+ lora_scale (`float`):
+                How much to scale the output of the lora linear layer before it is added to the output of the regular
+                linear layer.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+ """
+ _load_lora_into_text_encoder(
+ state_dict=state_dict,
+ network_alphas=network_alphas,
+ lora_scale=lora_scale,
+ text_encoder=text_encoder,
+ prefix=prefix,
+ text_encoder_name=cls.text_encoder_name,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ r"""
+ Save the LoRA parameters corresponding to the UNet and text encoder.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
+ unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `unet`.
+ text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
+ encoder LoRA state dict because it comes from 🤗 Transformers.
+ text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text
+ encoder LoRA state dict because it comes from 🤗 Transformers.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful during distributed training when
+                you need to call this function on all processes. In this case, set `is_main_process=True` only on the
+                main process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ """
+ state_dict = {}
+
+ if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
+ raise ValueError(
+ "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
+ )
+
+ if unet_lora_layers:
+ state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name))
+
+ if text_encoder_lora_layers:
+ state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
+
+ if text_encoder_2_lora_layers:
+ state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
+
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ def fuse_lora(
+ self,
+ components: List[str] = ["unet", "text_encoder", "text_encoder_2"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+ Args:
+            components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+            lora_scale (`float`, defaults to 1.0):
+                Controls how much to influence the outputs with the LoRA parameters.
+            safe_fusing (`bool`, defaults to `False`):
+                Whether to check fused weights for NaN values before fusing and, if NaN values are present, skip
+                fusing them.
+ adapter_names (`List[str]`, *optional*):
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+ Example:
+
+ ```py
+ from diffusers import DiffusionPipeline
+ import torch
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+ pipeline.fuse_lora(lora_scale=0.7)
+ ```
+ """
+ super().fuse_lora(
+ components=components,
+ lora_scale=lora_scale,
+ safe_fusing=safe_fusing,
+ adapter_names=adapter_names,
+ **kwargs,
+ )
+
+ def unfuse_lora(self, components: List[str] = ["unet", "text_encoder", "text_encoder_2"], **kwargs):
+ r"""
+ Reverses the effect of
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+ <Tip warning={true}>
+
+ This is an experimental API.
+
+ </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+ unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ unfuse_text_encoder (`bool`, defaults to `True`):
+ Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
+ LoRA parameters then it won't have any effect.
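+
+ Example (sketch; assumes `pipeline` already has a LoRA fused via `fuse_lora`, as shown above):
+
+ ```py
+ # `pipeline` is assumed to be an SDXL pipeline with a LoRA fused via `fuse_lora`.
+ pipeline.unfuse_lora()
+ ```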
+ """
+ super().unfuse_lora(components=components, **kwargs)
+
+
+class SD3LoraLoaderMixin(LoraBaseMixin):
+ r"""
+ Load LoRA layers into [`SD3Transformer2DModel`],
+ [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), and
+ [`CLIPTextModelWithProjection`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection).
+
+ Specific to [`StableDiffusion3Pipeline`].
+ """
+
+ _lora_loadable_modules = ["transformer", "text_encoder", "text_encoder_2"]
+ transformer_name = TRANSFORMER_NAME
+ text_encoder_name = TEXT_ENCODER_NAME
+
+ @classmethod
+ @validate_hf_hub_args
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ Return state dict for lora weights and the network alphas.
+
+ <Tip warning={true}>
+
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+ This function is experimental and might change in the future.
+
+ </Tip>
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+
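+ Example (a minimal sketch; the repository id below is a placeholder for any SD3-compatible LoRA checkpoint):
+
+ ```py
+ from diffusers import StableDiffusion3Pipeline
+
+ # Placeholder repository id; substitute a real SD3 LoRA checkpoint.
+ state_dict = StableDiffusion3Pipeline.lora_state_dict("some-user/sd3-lora")
+ print(sorted(state_dict)[:5])  # inspect a few of the returned LoRA keys
+ ```
+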
+ """
+ # Load the main state dict first which has the LoRA layers for either of
+ # transformer and text encoder or both.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+
+ state_dict = _fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
+
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+ return state_dict
+
+ def load_lora_weights(
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ adapter_name=None,
+ hotswap: bool = False,
+ **kwargs,
+ ):
+ """
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
+ `self.text_encoder`.
+
+ All kwargs are forwarded to `self.lora_state_dict`.
+
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
+ loaded.
+
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+ dict is loaded into `self.transformer`.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+ kwargs (`dict`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
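+
+ Example (a minimal sketch; the LoRA repository id is a placeholder):
+
+ ```py
+ import torch
+ from diffusers import StableDiffusion3Pipeline
+
+ pipe = StableDiffusion3Pipeline.from_pretrained(
+     "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
+ ).to("cuda")
+ # Placeholder repository id; substitute a real SD3 LoRA checkpoint.
+ pipe.load_lora_weights("some-user/sd3-lora", adapter_name="my_lora")
+ ```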
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # if a dict is passed, copy it instead of modifying it inplace
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+ is_correct_format = all("lora" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+ self.load_lora_into_text_encoder(
+ state_dict,
+ network_alphas=None,
+ text_encoder=self.text_encoder,
+ prefix=self.text_encoder_name,
+ lora_scale=self.lora_scale,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+ self.load_lora_into_text_encoder(
+ state_dict,
+ network_alphas=None,
+ text_encoder=self.text_encoder_2,
+ prefix=f"{self.text_encoder_name}_2",
+ lora_scale=self.lora_scale,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ def load_lora_into_transformer(
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
+ ):
+ """
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+ into the transformer or prefixed with an additional `transformer`, which can be used to distinguish
+ them from the text encoder lora layers.
+ transformer (`SD3Transformer2DModel`):
+ The Transformer model to load the LoRA layers into.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
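+
+ Example (a sketch of calling this classmethod directly on a standalone transformer; the LoRA repository id is
+ a placeholder):
+
+ ```py
+ from diffusers import SD3Transformer2DModel, StableDiffusion3Pipeline
+
+ transformer = SD3Transformer2DModel.from_pretrained(
+     "stabilityai/stable-diffusion-3-medium-diffusers", subfolder="transformer"
+ )
+ # Placeholder repository id; substitute a real SD3 LoRA checkpoint.
+ state_dict = StableDiffusion3Pipeline.lora_state_dict("some-user/sd3-lora")
+ StableDiffusion3Pipeline.load_lora_into_transformer(state_dict, transformer=transformer)
+ ```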
+ """
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # Load the layers corresponding to transformer.
+ logger.info(f"Loading {cls.transformer_name}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=None,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
+ def load_lora_into_text_encoder(
+ cls,
+ state_dict,
+ network_alphas,
+ text_encoder,
+ prefix=None,
+ lora_scale=1.0,
+ adapter_name=None,
+ _pipeline=None,
+ low_cpu_mem_usage=False,
+ hotswap: bool = False,
+ ):
+ """
+ This will load the LoRA layers specified in `state_dict` into `text_encoder`
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The keys should be prefixed with an
+ additional `text_encoder` to distinguish them from the unet lora layers.
+ network_alphas (`Dict[str, float]`):
+ The value of the network alpha used for stable learning and preventing underflow. This value has the
+ same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+ link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
+ text_encoder (`CLIPTextModel`):
+ The text encoder model to load the LoRA layers into.
+ prefix (`str`):
+ Expected prefix of the `text_encoder` in the `state_dict`.
+ lora_scale (`float`):
+ How much to scale the output of the lora linear layer before it is added to the output of the regular
+ layer.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+ """
+ _load_lora_into_text_encoder(
+ state_dict=state_dict,
+ network_alphas=network_alphas,
+ lora_scale=lora_scale,
+ text_encoder=text_encoder,
+ prefix=prefix,
+ text_encoder_name=cls.text_encoder_name,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora.lora_pipeline.StableDiffusionXLLoraLoaderMixin.save_lora_weights with unet->transformer
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ r"""
+ Save the LoRA parameters corresponding to the transformer and text encoders.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `transformer`.
+ text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
+ encoder LoRA state dict because it comes from 🤗 Transformers.
+ text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text
+ encoder LoRA state dict because it comes from 🤗 Transformers.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful during distributed training when you
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+ process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ """
+ state_dict = {}
+
+ if not (transformer_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
+ raise ValueError(
+ "You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
+ )
+
+ if transformer_lora_layers:
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+ if text_encoder_lora_layers:
+ state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
+
+ if text_encoder_2_lora_layers:
+ state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
+
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ # Copied from diffusers.loaders.lora.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer
+ def fuse_lora(
+ self,
+ components: List[str] = ["transformer", "text_encoder", "text_encoder_2"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+ <Tip warning={true}>
+
+ This is an experimental API.
+
+ </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+ lora_scale (`float`, defaults to 1.0):
+ Controls how much to influence the outputs with the LoRA parameters.
+ safe_fusing (`bool`, defaults to `False`):
+ Whether to check the fused weights for NaN values before fusing and to skip fusing them if values are NaN.
+ adapter_names (`List[str]`, *optional*):
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+ Example:
+
+ ```py
+ from diffusers import DiffusionPipeline
+ import torch
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+ pipeline.fuse_lora(lora_scale=0.7)
+ ```
+ """
+ super().fuse_lora(
+ components=components,
+ lora_scale=lora_scale,
+ safe_fusing=safe_fusing,
+ adapter_names=adapter_names,
+ **kwargs,
+ )
+
+ # Copied from diffusers.loaders.lora.lora_pipeline.StableDiffusionXLLoraLoaderMixin.unfuse_lora with unet->transformer
+ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], **kwargs):
+ r"""
+ Reverses the effect of
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+ <Tip warning={true}>
+
+ This is an experimental API.
+
+ </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
+ unfuse_text_encoder (`bool`, defaults to `True`):
+ Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
+ LoRA parameters then it won't have any effect.
+ """
+ super().unfuse_lora(components=components, **kwargs)
+
+
+class AuraFlowLoraLoaderMixin(LoraBaseMixin):
+ r"""
+ Load LoRA layers into [`AuraFlowTransformer2DModel`]. Specific to [`AuraFlowPipeline`].
+ """
+
+ _lora_loadable_modules = ["transformer"]
+ transformer_name = TRANSFORMER_NAME
+
+ @classmethod
+ @validate_hf_hub_args
+ # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ Return state dict for lora weights and the network alphas.
+
+ <Tip warning={true}>
+
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+ This function is experimental and might change in the future.
+
+ </Tip>
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+
+ """
+ # Load the main state dict first which has the LoRA layers for either of
+ # transformer and text encoder or both.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+
+ state_dict = _fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
+
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+ return state_dict
+
+ # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
+ def load_lora_weights(
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ adapter_name: Optional[str] = None,
+ hotswap: bool = False,
+ **kwargs,
+ ):
+ """
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
+ `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+ dict is loaded into `self.transformer`.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+ kwargs (`dict`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # if a dict is passed, copy it instead of modifying it inplace
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+ is_correct_format = all("lora" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->AuraFlowTransformer2DModel
+ def load_lora_into_transformer(
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
+ ):
+ """
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+ into the transformer or prefixed with an additional `transformer`, which can be used to distinguish
+ them from the text encoder lora layers.
+ transformer (`AuraFlowTransformer2DModel`):
+ The Transformer model to load the LoRA layers into.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+ """
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # Load the layers corresponding to transformer.
+ logger.info(f"Loading {cls.transformer_name}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=None,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ r"""
+ Save the LoRA parameters corresponding to the transformer.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `transformer`.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful during distributed training when you
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+ process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ """
+ state_dict = {}
+
+ if not transformer_lora_layers:
+ raise ValueError("You must pass `transformer_lora_layers`.")
+
+ if transformer_lora_layers:
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+ # Save the model
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ # Copied from diffusers.loaders.lora.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
+ def fuse_lora(
+ self,
+ components: List[str] = ["transformer"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+ <Tip warning={true}>
+
+ This is an experimental API.
+
+ </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+ lora_scale (`float`, defaults to 1.0):
+ Controls how much to influence the outputs with the LoRA parameters.
+ safe_fusing (`bool`, defaults to `False`):
+ Whether to check the fused weights for NaN values before fusing and to skip fusing them if values are NaN.
+ adapter_names (`List[str]`, *optional*):
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+ Example:
+
+ ```py
+ from diffusers import DiffusionPipeline
+ import torch
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+ pipeline.fuse_lora(lora_scale=0.7)
+ ```
+ """
+ super().fuse_lora(
+ components=components,
+ lora_scale=lora_scale,
+ safe_fusing=safe_fusing,
+ adapter_names=adapter_names,
+ **kwargs,
+ )
+
+ # Copied from diffusers.loaders.lora.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
+ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
+ r"""
+ Reverses the effect of
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+ <Tip warning={true}>
+
+ This is an experimental API.
+
+ </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
+ """
+ super().unfuse_lora(components=components, **kwargs)
+
+
+class FluxLoraLoaderMixin(LoraBaseMixin):
+ r"""
+ Load LoRA layers into [`FluxTransformer2DModel`] and
+ [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
+
+ Specific to [`FluxPipeline`].
+ """
+
+ _lora_loadable_modules = ["transformer", "text_encoder"]
+ transformer_name = TRANSFORMER_NAME
+ text_encoder_name = TEXT_ENCODER_NAME
+ _control_lora_supported_norm_keys = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"]
+
+ @classmethod
+ @validate_hf_hub_args
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ return_alphas: bool = False,
+ **kwargs,
+ ):
+ r"""
+ Return state dict for lora weights and the network alphas.
+
+ <Tip warning={true}>
+
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+ This function is experimental and might change in the future.
+
+ </Tip>
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+
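+ Example (a minimal sketch; `return_alphas=True` additionally returns the network alphas found in the
+ checkpoint, and the checkpoint id below is only illustrative):
+
+ ```py
+ from diffusers import FluxPipeline
+
+ state_dict, network_alphas = FluxPipeline.lora_state_dict(
+     "TheLastBen/Jon_Snow_Flux_LoRA", return_alphas=True
+ )
+ ```
+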
+ """
+ # Load the main state dict first which has the LoRA layers for either of
+ # transformer and text encoder or both.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+
+ state_dict = _fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+ # TODO (sayakpaul): do a follow-up to clean up and try to unify the conditions.
+ is_kohya = any(".lora_down.weight" in k for k in state_dict)
+ if is_kohya:
+ state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict)
+ # Kohya already takes care of scaling the LoRA parameters with alpha.
+ return (state_dict, None) if return_alphas else state_dict
+
+ is_xlabs = any("processor" in k for k in state_dict)
+ if is_xlabs:
+ state_dict = _convert_xlabs_flux_lora_to_diffusers(state_dict)
+ # xlabs doesn't use `alpha`.
+ return (state_dict, None) if return_alphas else state_dict
+
+ is_bfl_control = any("query_norm.scale" in k for k in state_dict)
+ if is_bfl_control:
+ state_dict = _convert_bfl_flux_control_lora_to_diffusers(state_dict)
+ return (state_dict, None) if return_alphas else state_dict
+
+ # For state dicts like
+ # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
+ keys = list(state_dict.keys())
+ network_alphas = {}
+ for k in keys:
+ if "alpha" in k:
+ alpha_value = state_dict.get(k)
+ if (torch.is_tensor(alpha_value) and torch.is_floating_point(alpha_value)) or isinstance(
+ alpha_value, float
+ ):
+ network_alphas[k] = state_dict.pop(k)
+ else:
+ raise ValueError(
+ f"The alpha key ({k}) seems to be incorrect. If you think this error is unexpected, please open as issue."
+ )
+
+ if return_alphas:
+ return state_dict, network_alphas
+ else:
+ return state_dict
+
+ def load_lora_weights(
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ adapter_name: Optional[str] = None,
+ hotswap: bool = False,
+ **kwargs,
+ ):
+ """
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
+ `self.text_encoder`.
+
+ All kwargs are forwarded to `self.lora_state_dict`.
+
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
+ loaded.
+
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+ dict is loaded into `self.transformer`.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+ kwargs (`dict`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
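+
+ Example (a minimal sketch; the LoRA checkpoint id is only illustrative):
+
+ ```py
+ import torch
+ from diffusers import FluxPipeline
+
+ pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
+ pipe.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA")
+ ```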
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # if a dict is passed, copy it instead of modifying it inplace
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ state_dict, network_alphas = self.lora_state_dict(
+ pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs
+ )
+
+ has_lora_keys = any("lora" in key for key in state_dict.keys())
+
+ # Flux Control LoRAs also have norm keys
+ has_norm_keys = any(
+ norm_key in key for key in state_dict.keys() for norm_key in self._control_lora_supported_norm_keys
+ )
+
+ if not (has_lora_keys or has_norm_keys):
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ transformer_lora_state_dict = {
+ k: state_dict.get(k)
+ for k in list(state_dict.keys())
+ if k.startswith(f"{self.transformer_name}.") and "lora" in k
+ }
+ transformer_norm_state_dict = {
+ k: state_dict.pop(k)
+ for k in list(state_dict.keys())
+ if k.startswith(f"{self.transformer_name}.")
+ and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys)
+ }
+
+ transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
+ has_param_with_expanded_shape = False
+ if len(transformer_lora_state_dict) > 0:
+ has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_(
+ transformer, transformer_lora_state_dict, transformer_norm_state_dict
+ )
+
+ if has_param_with_expanded_shape:
+ logger.info(
+ "The LoRA weights contain parameters that have different shapes that expected by the transformer. "
+ "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. "
+ "To get a comprehensive list of parameter names that were modified, enable debug logging."
+ )
+ if len(transformer_lora_state_dict) > 0:
+ transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
+ transformer=transformer, lora_state_dict=transformer_lora_state_dict
+ )
+ for k in transformer_lora_state_dict:
+ state_dict.update({k: transformer_lora_state_dict[k]})
+
+ self.load_lora_into_transformer(
+ state_dict,
+ network_alphas=network_alphas,
+ transformer=transformer,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ if len(transformer_norm_state_dict) > 0:
+ transformer._transformer_norm_layers = self._load_norm_into_transformer(
+ transformer_norm_state_dict,
+ transformer=transformer,
+ discard_original_layers=False,
+ )
+
+ self.load_lora_into_text_encoder(
+ state_dict,
+ network_alphas=network_alphas,
+ text_encoder=self.text_encoder,
+ prefix=self.text_encoder_name,
+ lora_scale=self.lora_scale,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ def load_lora_into_transformer(
+ cls,
+ state_dict,
+ network_alphas,
+ transformer,
+ adapter_name=None,
+ _pipeline=None,
+ low_cpu_mem_usage=False,
+ hotswap: bool = False,
+ ):
+ """
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+ into the transformer or prefixed with an additional `transformer`, which can be used to distinguish
+ them from the text encoder lora layers.
+ network_alphas (`Dict[str, float]`):
+ The value of the network alpha used for stable learning and preventing underflow. This value has the
+ same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+ link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
+ transformer (`FluxTransformer2DModel`):
+ The Transformer model to load the LoRA layers into.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+ """
+ if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # Load the layers corresponding to transformer.
+ logger.info(f"Loading {cls.transformer_name}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=network_alphas,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ def _load_norm_into_transformer(
+ cls,
+ state_dict,
+ transformer,
+ prefix=None,
+ discard_original_layers=False,
+ ) -> Dict[str, torch.Tensor]:
+ # Remove prefix if present
+ prefix = prefix or cls.transformer_name
+ for key in list(state_dict.keys()):
+ if key.split(".")[0] == prefix:
+ state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key)
+
+ # Find invalid keys
+ transformer_state_dict = transformer.state_dict()
+ transformer_keys = set(transformer_state_dict.keys())
+ state_dict_keys = set(state_dict.keys())
+ extra_keys = list(state_dict_keys - transformer_keys)
+
+ if extra_keys:
+ logger.warning(
+ f"Unsupported keys found in state dict when trying to load normalization layers into the transformer. The following keys will be ignored:\n{extra_keys}."
+ )
+
+ for key in extra_keys:
+ state_dict.pop(key)
+
+ # Save the layers that are going to be overwritten so that unload_lora_weights can work as expected
+ overwritten_layers_state_dict = {}
+ if not discard_original_layers:
+ for key in state_dict.keys():
+ overwritten_layers_state_dict[key] = transformer_state_dict[key].clone()
+
+ logger.info(
+ "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will directly update the state_dict of the transformer "
+ 'as opposed to the LoRA layers that will co-exist separately until the "fuse_lora()" method is called. That is to say, the normalization layers will always be directly '
+ "fused into the transformer and can only be unfused if `discard_original_layers=True` is passed. This might also have implications when dealing with multiple LoRAs. "
+ "If you notice something unexpected, please open an issue: https://github.com/huggingface/diffusers/issues."
+ )
+
+ # We can't load with strict=True because the current state_dict does not contain all the transformer keys
+ incompatible_keys = transformer.load_state_dict(state_dict, strict=False)
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+
+ # We shouldn't expect to see the supported norm keys here being present in the unexpected keys.
+ if unexpected_keys:
+ if any(norm_key in k for k in unexpected_keys for norm_key in cls._control_lora_supported_norm_keys):
+ raise ValueError(
+ f"Found {unexpected_keys} as unexpected keys while trying to load norm layers into the transformer."
+ )
+
+ return overwritten_layers_state_dict
+
+ @classmethod
+ # Copied from diffusers.loaders.lora.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
+ def load_lora_into_text_encoder(
+ cls,
+ state_dict,
+ network_alphas,
+ text_encoder,
+ prefix=None,
+ lora_scale=1.0,
+ adapter_name=None,
+ _pipeline=None,
+ low_cpu_mem_usage=False,
+ hotswap: bool = False,
+ ):
+ """
+ This will load the LoRA layers specified in `state_dict` into `text_encoder`
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The keys should be prefixed with an
+ additional `text_encoder` to distinguish them from the unet lora layers.
+ network_alphas (`Dict[str, float]`):
+ The value of the network alpha used for stable learning and preventing underflow. This value has the
+ same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+ link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
+ text_encoder (`CLIPTextModel`):
+ The text encoder model to load the LoRA layers into.
+ prefix (`str`):
+ Expected prefix of the `text_encoder` in the `state_dict`.
+ lora_scale (`float`):
+ How much to scale the output of the lora linear layer before it is added to the output of the regular
+ layer.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+ """
+ _load_lora_into_text_encoder(
+ state_dict=state_dict,
+ network_alphas=network_alphas,
+ lora_scale=lora_scale,
+ text_encoder=text_encoder,
+ prefix=prefix,
+ text_encoder_name=cls.text_encoder_name,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with unet->transformer
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ r"""
+ Save the LoRA parameters corresponding to the transformer and text encoder.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `transformer`.
+ text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
+ encoder LoRA state dict because it comes from 🤗 Transformers.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful during distributed training when you
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+ process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ """
+ state_dict = {}
+
+ if not (transformer_lora_layers or text_encoder_lora_layers):
+ raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.")
+
+ if transformer_lora_layers:
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+ if text_encoder_lora_layers:
+ state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
+
+ # Save the model
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ def fuse_lora(
+ self,
+ components: List[str] = ["transformer"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+ <Tip warning={true}>
+
+ This is an experimental API.
+
+ </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+ lora_scale (`float`, defaults to 1.0):
+ Controls how much to influence the outputs with the LoRA parameters.
+ safe_fusing (`bool`, defaults to `False`):
+ Whether to check the fused weights for NaN values before fusing and to skip fusing them if values are NaN.
+ adapter_names (`List[str]`, *optional*):
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+ Example:
+
+ ```py
+ from diffusers import DiffusionPipeline
+ import torch
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+ pipeline.fuse_lora(lora_scale=0.7)
+ ```
+ """
+
+ transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
+ if (
+ hasattr(transformer, "_transformer_norm_layers")
+ and isinstance(transformer._transformer_norm_layers, dict)
+ and len(transformer._transformer_norm_layers.keys()) > 0
+ ):
+ logger.info(
+ "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will be directly updated the state_dict of the transformer "
+ "as opposed to the LoRA layers that will co-exist separately until the 'fuse_lora()' method is called. That is to say, the normalization layers will always be directly "
+ "fused into the transformer and can only be unfused if `discard_original_layers=True` is passed."
+ )
+
+ super().fuse_lora(
+ components=components,
+ lora_scale=lora_scale,
+ safe_fusing=safe_fusing,
+ adapter_names=adapter_names,
+ **kwargs,
+ )
+
+ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
+ r"""
+ Reverses the effect of
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+ <Tip warning={true}>
+
+ This is an experimental API.
+
+ </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+ """
+ transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
+ if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers:
+ transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
+
+ super().unfuse_lora(components=components, **kwargs)
+
+ # We override this here to account for `_transformer_norm_layers` and `_overwritten_params`.
+ def unload_lora_weights(self, reset_to_overwritten_params=False):
+ """
+ Unloads the LoRA parameters.
+
+ Args:
+ reset_to_overwritten_params (`bool`, defaults to `False`): Whether to reset the LoRA-loaded modules
+ to their original params. Refer to the [Flux
+ documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux) to learn more.
+
+ Examples:
+
+ ```python
+ >>> # Assuming `pipeline` is already loaded with the LoRA parameters.
+ >>> pipeline.unload_lora_weights()
+ >>> ...
+ ```
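+
+ When a Flux Control LoRA previously expanded the transformer's input layers, the flag below also restores the
+ original parameter shapes (sketch; assumes such a LoRA was loaded earlier):
+
+ ```python
+ >>> pipeline.unload_lora_weights(reset_to_overwritten_params=True)
+ ```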
+ """
+ super().unload_lora_weights()
+
+ transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
+ if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers:
+ transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
+ transformer._transformer_norm_layers = None
+
+ if reset_to_overwritten_params and getattr(transformer, "_overwritten_params", None) is not None:
+ overwritten_params = transformer._overwritten_params
+ module_names = set()
+
+ for param_name in overwritten_params:
+ if param_name.endswith(".weight"):
+ module_names.add(param_name.replace(".weight", ""))
+
+ for name, module in transformer.named_modules():
+ if isinstance(module, torch.nn.Linear) and name in module_names:
+ module_weight = module.weight.data
+ module_bias = module.bias.data if module.bias is not None else None
+ bias = module_bias is not None
+
+ parent_module_name, _, current_module_name = name.rpartition(".")
+ parent_module = transformer.get_submodule(parent_module_name)
+
+ current_param_weight = overwritten_params[f"{name}.weight"]
+ in_features, out_features = current_param_weight.shape[1], current_param_weight.shape[0]
+ with torch.device("meta"):
+ original_module = torch.nn.Linear(
+ in_features,
+ out_features,
+ bias=bias,
+ dtype=module_weight.dtype,
+ )
+
+ tmp_state_dict = {"weight": current_param_weight}
+ if module_bias is not None:
+ tmp_state_dict.update({"bias": overwritten_params[f"{name}.bias"]})
+ original_module.load_state_dict(tmp_state_dict, assign=True, strict=True)
+ setattr(parent_module, current_module_name, original_module)
+
+ del tmp_state_dict
+
+ if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
+ attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
+ new_value = int(current_param_weight.shape[1])
+ old_value = getattr(transformer.config, attribute_name)
+ setattr(transformer.config, attribute_name, new_value)
+ logger.info(
+ f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
+ )
+
+ @classmethod
+ def _maybe_expand_transformer_param_shape_or_error_(
+ cls,
+ transformer: torch.nn.Module,
+ lora_state_dict=None,
+ norm_state_dict=None,
+ prefix=None,
+ ) -> bool:
+ """
+ Control LoRA expands the shape of the input layer from (3072, 64) to (3072, 128). This method handles that and
+ generalizes things a bit so that any parameter that needs expansion receives appropriate treatment.
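+
+ A miniature sketch of the expansion performed below, assuming a plain, unquantized `nn.Linear`: the old weight
+ is copied into the leading slice of a zero-initialized, wider layer.
+
+ ```py
+ import torch
+
+ old = torch.nn.Linear(64, 3072)
+ new = torch.nn.Linear(128, 3072)
+ with torch.no_grad():
+     new.weight.zero_()
+     new.weight[:, : old.in_features].copy_(old.weight)
+     new.bias.copy_(old.bias)
+ ```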
+ """
+ state_dict = {}
+ if lora_state_dict is not None:
+ state_dict.update(lora_state_dict)
+ if norm_state_dict is not None:
+ state_dict.update(norm_state_dict)
+
+ # Remove prefix if present
+ prefix = prefix or cls.transformer_name
+ for key in list(state_dict.keys()):
+ if key.split(".")[0] == prefix:
+ state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key)
+
+ # Expand transformer parameter shapes if they don't match lora
+ has_param_with_shape_update = False
+ overwritten_params = {}
+
+ is_peft_loaded = getattr(transformer, "peft_config", None) is not None
+ is_quantized = hasattr(transformer, "hf_quantizer")
+ for name, module in transformer.named_modules():
+ if isinstance(module, torch.nn.Linear):
+ module_weight = module.weight.data
+ module_bias = module.bias.data if module.bias is not None else None
+ bias = module_bias is not None
+
+ lora_base_name = name.replace(".base_layer", "") if is_peft_loaded else name
+ lora_A_weight_name = f"{lora_base_name}.lora_A.weight"
+ lora_B_weight_name = f"{lora_base_name}.lora_B.weight"
+ if lora_A_weight_name not in state_dict:
+ continue
+
+ in_features = state_dict[lora_A_weight_name].shape[1]
+ out_features = state_dict[lora_B_weight_name].shape[0]
+
+ # Model maybe loaded with different quantization schemes which may flatten the params.
+ # `bitsandbytes`, for example, flattens the weights when using 4bit. 8bit bnb models
+ # preserve weight shape.
+ module_weight_shape = cls._calculate_module_shape(model=transformer, base_module=module)
+
+ # This means there's no need for an expansion in the params, so we simply skip.
+ if tuple(module_weight_shape) == (out_features, in_features):
+ continue
+
+ module_out_features, module_in_features = module_weight_shape
+ debug_message = ""
+ if in_features > module_in_features:
+ debug_message += (
+ f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA '
+ f"checkpoint contains a higher number of features than expected. The number of input_features will be "
+ f"expanded from {module_in_features} to {in_features}"
+ )
+ if out_features > module_out_features:
+ debug_message += (
+ ", and the number of output features will be "
+ f"expanded from {module_out_features} to {out_features}."
+ )
+ else:
+ debug_message += "."
+ if debug_message:
+ logger.debug(debug_message)
+
+ if out_features > module_out_features or in_features > module_in_features:
+ has_param_with_shape_update = True
+ parent_module_name, _, current_module_name = name.rpartition(".")
+ parent_module = transformer.get_submodule(parent_module_name)
+
+ if is_quantized:
+ module_weight = _maybe_dequantize_weight_for_expanded_lora(transformer, module)
+
+ # TODO: consider if this layer needs to be a quantized layer as well if `is_quantized` is True.
+ with torch.device("meta"):
+ expanded_module = torch.nn.Linear(
+ in_features, out_features, bias=bias, dtype=module_weight.dtype
+ )
+ # Only the weights are expanded, not the biases: only the input dimensions change while the
+ # output dimensions remain the same. Since the weight tensor has shape
+ # (out_features, in_features) and the bias tensor has shape (out_features,), only the weight
+ # needs to grow.
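+ # For the Flux control LoRA case, for example, the input projection grows from (3072, 64) to
+ # (3072, 128) by zero-filling the extra input columns below, while its bias keeps the (3072,) shape.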
+ new_weight = torch.zeros_like(
+ expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
+ )
+ slices = tuple(slice(0, dim) for dim in module_weight_shape)
+ new_weight[slices] = module_weight
+ tmp_state_dict = {"weight": new_weight}
+ if module_bias is not None:
+ tmp_state_dict["bias"] = module_bias
+ expanded_module.load_state_dict(tmp_state_dict, strict=True, assign=True)
+
+ setattr(parent_module, current_module_name, expanded_module)
+
+ del tmp_state_dict
+
+ if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
+ attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
+ new_value = int(expanded_module.weight.data.shape[1])
+ old_value = getattr(transformer.config, attribute_name)
+ setattr(transformer.config, attribute_name, new_value)
+ logger.info(
+ f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
+ )
+
+ # For `unload_lora_weights()`.
+ # TODO: this could lead to more memory overhead if the number of overwritten params
+ # is large. Should be revisited later and tackled through a `discard_original_layers` arg.
+ overwritten_params[f"{current_module_name}.weight"] = module_weight
+ if module_bias is not None:
+ overwritten_params[f"{current_module_name}.bias"] = module_bias
+
+ if len(overwritten_params) > 0:
+ transformer._overwritten_params = overwritten_params
+
+ return has_param_with_shape_update
+
+ @classmethod
+ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
+ expanded_module_names = set()
+ transformer_state_dict = transformer.state_dict()
+ prefix = f"{cls.transformer_name}."
+
+ lora_module_names = [
+ key[: -len(".lora_A.weight")] for key in lora_state_dict if key.endswith(".lora_A.weight")
+ ]
+ lora_module_names = [name[len(prefix) :] for name in lora_module_names if name.startswith(prefix)]
+ lora_module_names = sorted(set(lora_module_names))
+ transformer_module_names = sorted({name for name, _ in transformer.named_modules()})
+ unexpected_modules = set(lora_module_names) - set(transformer_module_names)
+ if unexpected_modules:
+ logger.debug(f"Found unexpected modules: {unexpected_modules}. These will be ignored.")
+
+ is_peft_loaded = getattr(transformer, "peft_config", None) is not None
+ for k in lora_module_names:
+ if k in unexpected_modules:
+ continue
+
+ base_param_name = (
+ f"{k.replace(prefix, '')}.base_layer.weight"
+ if is_peft_loaded and f"{k.replace(prefix, '')}.base_layer.weight" in transformer_state_dict
+ else f"{k.replace(prefix, '')}.weight"
+ )
+ base_weight_param = transformer_state_dict[base_param_name]
+ lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]
+
+ # TODO (sayakpaul): Handle the cases when we actually need to expand when using quantization.
+ base_module_shape = cls._calculate_module_shape(model=transformer, base_weight_param_name=base_param_name)
+
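+ # Example: if the base layer was already expanded to (3072, 128) but this LoRA was trained against
+ # the original (3072, 64) layer, its lora_A weight of shape (rank, 64) is zero-padded to (rank, 128).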
+ if base_module_shape[1] > lora_A_param.shape[1]:
+ shape = (lora_A_param.shape[0], base_weight_param.shape[1])
+ expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device)
+ expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param)
+ lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight
+ expanded_module_names.add(k)
+ elif base_module_shape[1] < lora_A_param.shape[1]:
+ raise NotImplementedError(
+ f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file a feature request - https://github.com/huggingface/diffusers/issues/new."
+ )
+
+ if expanded_module_names:
+ logger.info(
+ f"The following LoRA modules were zero padded to match the state dict of {cls.transformer_name}: {expanded_module_names}. Please open an issue if you think this was unexpected - https://github.com/huggingface/diffusers/issues/new."
+ )
+
+ return lora_state_dict
+
+ @staticmethod
+ def _calculate_module_shape(
+ model: "torch.nn.Module",
+ base_module: "torch.nn.Linear" = None,
+ base_weight_param_name: str = None,
+ ) -> "torch.Size":
+ def _get_weight_shape(weight: torch.Tensor):
+ if weight.__class__.__name__ == "Params4bit":
+ return weight.quant_state.shape
+ elif weight.__class__.__name__ == "GGUFParameter":
+ return weight.quant_shape
+ else:
+ return weight.shape
+
+ if base_module is not None:
+ return _get_weight_shape(base_module.weight)
+ elif base_weight_param_name is not None:
+ if not base_weight_param_name.endswith(".weight"):
+ raise ValueError(
+ f"Invalid `base_weight_param_name` passed as it does not end with '.weight' {base_weight_param_name=}."
+ )
+ module_path = base_weight_param_name.rsplit(".weight", 1)[0]
+ submodule = get_submodule_by_name(model, module_path)
+ return _get_weight_shape(submodule.weight)
+
+ raise ValueError("Either `base_module` or `base_weight_param_name` must be provided.")
+
+
+# We subclass from `StableDiffusionLoraLoaderMixin` here because Amused initially
+# relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
+class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
+ _lora_loadable_modules = ["transformer", "text_encoder"]
+ transformer_name = TRANSFORMER_NAME
+ text_encoder_name = TEXT_ENCODER_NAME
+
+ @classmethod
+ # Copied from diffusers.loaders.lora.lora_pipeline.FluxLoraLoaderMixin.load_lora_into_transformer with FluxTransformer2DModel->UVit2DModel
+ def load_lora_into_transformer(
+ cls,
+ state_dict,
+ network_alphas,
+ transformer,
+ adapter_name=None,
+ _pipeline=None,
+ low_cpu_mem_usage=False,
+ hotswap: bool = False,
+ ):
+ """
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+ into the unet or prefixed with an additional `unet` which can be used to distinguish between text
+ encoder lora layers.
+ network_alphas (`Dict[str, float]`):
+ The value of the network alpha used for stable learning and preventing underflow. This value has the
+ same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+ link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
+ transformer (`UVit2DModel`):
+ The Transformer model to load the LoRA layers into.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+ """
+ if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # Load the layers corresponding to transformer.
+ logger.info(f"Loading {cls.transformer_name}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=network_alphas,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
+ def load_lora_into_text_encoder(
+ cls,
+ state_dict,
+ network_alphas,
+ text_encoder,
+ prefix=None,
+ lora_scale=1.0,
+ adapter_name=None,
+ _pipeline=None,
+ low_cpu_mem_usage=False,
+ hotswap: bool = False,
+ ):
+ """
+ This will load the LoRA layers specified in `state_dict` into `text_encoder`
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The key should be prefixed with an
+ additional `text_encoder` to distinguish between unet lora layers.
+ network_alphas (`Dict[str, float]`):
+ The value of the network alpha used for stable learning and preventing underflow. This value has the
+ same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+ link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
+ text_encoder (`CLIPTextModel`):
+ The text encoder model to load the LoRA layers into.
+ prefix (`str`):
+ Expected prefix of the `text_encoder` in the `state_dict`.
+ lora_scale (`float`):
+ How much to scale the output of the lora linear layer before it is added with the output of the regular
+ lora layer.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+ """
+ _load_lora_into_text_encoder(
+ state_dict=state_dict,
+ network_alphas=network_alphas,
+ lora_scale=lora_scale,
+ text_encoder=text_encoder,
+ prefix=prefix,
+ text_encoder_name=cls.text_encoder_name,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
+ transformer_lora_layers: Dict[str, torch.nn.Module] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ r"""
+ Save the LoRA parameters corresponding to the transformer and text encoder.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `transformer`.
+ text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
+ encoder LoRA state dict because it comes from 🤗 Transformers.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful during distributed training when you
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+ process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ """
+ state_dict = {}
+
+ if not (transformer_lora_layers or text_encoder_lora_layers):
+ raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
+
+ if transformer_lora_layers:
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+ if text_encoder_lora_layers:
+ state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
+
+ # Save the model
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+
+class CogVideoXLoraLoaderMixin(LoraBaseMixin):
+ r"""
+ Load LoRA layers into [`CogVideoXTransformer3DModel`]. Specific to [`CogVideoXPipeline`].
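+
+ A minimal usage sketch (the LoRA checkpoint path below is illustrative, not a real repository):
+
+ ```py
+ import torch
+
+ from diffusers import CogVideoXPipeline
+
+ pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")
+ # "path/to/cogvideox-lora" stands in for any CogVideoX LoRA checkpoint (local directory or Hub repo id)
+ pipe.load_lora_weights("path/to/cogvideox-lora", adapter_name="example-lora")
+ pipe.set_adapters(["example-lora"], [0.9])
+ ```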
+ """
+
+ _lora_loadable_modules = ["transformer"]
+ transformer_name = TRANSFORMER_NAME
+
+ @classmethod
+ @validate_hf_hub_args
+ # Copied from diffusers.loaders.lora.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ Return state dict for lora weights and the network alphas.
+
+ <Tip warning={true}>
+
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+ This function is experimental and might change in the future.
+
+ </Tip>
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+
+ """
+ # Load the main state dict first, which has the LoRA layers for the transformer,
+ # the text encoder, or both.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+
+ state_dict = _fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
+
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible with Diffusers at the moment. So, we are going to filter out the keys associated with 'dora_scale' from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+ return state_dict
+
+ def load_lora_weights(
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ adapter_name: Optional[str] = None,
+ hotswap: bool = False,
+ **kwargs,
+ ):
+ """
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
+ `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+ dict is loaded into `self.transformer`.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+ kwargs (`dict`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # if a dict is passed, copy it instead of modifying it inplace
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+ is_correct_format = all("lora" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogVideoXTransformer3DModel
+ def load_lora_into_transformer(
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
+ ):
+ """
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+ into the unet or prefixed with an additional `unet` which can be used to distinguish between text
+ encoder lora layers.
+ transformer (`CogVideoXTransformer3DModel`):
+ The Transformer model to load the LoRA layers into.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+ """
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # Load the layers corresponding to transformer.
+ logger.info(f"Loading {cls.transformer_name}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=None,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ # Adapted from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights without support for text encoder
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ r"""
+ Save the LoRA parameters corresponding to the transformer.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `transformer`.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful during distributed training when you
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+ process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ """
+ state_dict = {}
+
+ if not transformer_lora_layers:
+ raise ValueError("You must pass `transformer_lora_layers`.")
+
+ if transformer_lora_layers:
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+ # Save the model
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ def fuse_lora(
+ self,
+ components: List[str] = ["transformer"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+ <Tip warning={true}>
+
+ This is an experimental API.
+
+ </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+ lora_scale (`float`, defaults to 1.0):
+ Controls how much to influence the outputs with the LoRA parameters.
+ safe_fusing (`bool`, defaults to `False`):
+ Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+ adapter_names (`List[str]`, *optional*):
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+ Example:
+
+ ```py
+ from diffusers import DiffusionPipeline
+ import torch
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+ pipeline.fuse_lora(lora_scale=0.7)
+ ```
+ """
+ super().fuse_lora(
+ components=components,
+ lora_scale=lora_scale,
+ safe_fusing=safe_fusing,
+ adapter_names=adapter_names,
+ **kwargs,
+ )
+
+ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+ r"""
+ Reverses the effect of
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+ <Tip warning={true}>
+
+ This is an experimental API.
+
+ </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
+ """
+ super().unfuse_lora(components=components, **kwargs)
+
+
+class Mochi1LoraLoaderMixin(LoraBaseMixin):
+ r"""
+ Load LoRA layers into [`MochiTransformer3DModel`]. Specific to [`MochiPipeline`].
+ """
+
+ _lora_loadable_modules = ["transformer"]
+ transformer_name = TRANSFORMER_NAME
+
+ @classmethod
+ @validate_hf_hub_args
+ # Copied from diffusers.loaders.lora.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ Return state dict for lora weights and the network alphas.
+
+ <Tip warning={true}>
+
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+ This function is experimental and might change in the future.
+
+ </Tip>
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+
+ """
+ # Load the main state dict first, which has the LoRA layers for the transformer,
+ # the text encoder, or both.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+
+ state_dict = _fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
+
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible with Diffusers at the moment. So, we are going to filter out the keys associated with 'dora_scale' from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+ return state_dict
+
+ # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
+ def load_lora_weights(
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ adapter_name: Optional[str] = None,
+ hotswap: bool = False,
+ **kwargs,
+ ):
+ """
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
+ `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+ dict is loaded into `self.transformer`.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+ kwargs (`dict`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # if a dict is passed, copy it instead of modifying it inplace
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+ is_correct_format = all("lora" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->MochiTransformer3DModel
+ def load_lora_into_transformer(
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
+ ):
+ """
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+ into the unet or prefixed with an additional `unet` which can be used to distinguish between text
+ encoder lora layers.
+ transformer (`MochiTransformer3DModel`):
+ The Transformer model to load the LoRA layers into.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+ """
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # Load the layers corresponding to transformer.
+ logger.info(f"Loading {cls.transformer_name}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=None,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ r"""
+ Save the LoRA parameters corresponding to the transformer.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `transformer`.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful during distributed training when you
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+ process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ """
+ state_dict = {}
+
+ if not transformer_lora_layers:
+ raise ValueError("You must pass `transformer_lora_layers`.")
+
+ if transformer_lora_layers:
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+ # Save the model
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
+ def fuse_lora(
+ self,
+ components: List[str] = ["transformer"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+ <Tip warning={true}>
+
+ This is an experimental API.
+
+ </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+ lora_scale (`float`, defaults to 1.0):
+ Controls how much to influence the outputs with the LoRA parameters.
+ safe_fusing (`bool`, defaults to `False`):
+ Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+ adapter_names (`List[str]`, *optional*):
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+ Example:
+
+ ```py
+ from diffusers import DiffusionPipeline
+ import torch
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+ pipeline.fuse_lora(lora_scale=0.7)
+ ```
+ """
+ super().fuse_lora(
+ components=components,
+ lora_scale=lora_scale,
+ safe_fusing=safe_fusing,
+ adapter_names=adapter_names,
+ **kwargs,
+ )
+
+ # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
+ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+ r"""
+ Reverses the effect of
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+ <Tip warning={true}>
+
+ This is an experimental API.
+
+ </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
+ """
+ super().unfuse_lora(components=components, **kwargs)
+
+
+class LTXVideoLoraLoaderMixin(LoraBaseMixin):
+ r"""
+ Load LoRA layers into [`LTXVideoTransformer3DModel`]. Specific to [`LTXPipeline`].
+ """
+
+ _lora_loadable_modules = ["transformer"]
+ transformer_name = TRANSFORMER_NAME
+
+ @classmethod
+ @validate_hf_hub_args
+ # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ Return state dict for lora weights and the network alphas.
+
+ <Tip warning={true}>
+
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+ This function is experimental and might change in the future.
+
+ </Tip>
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+
+ """
+ # Load the main state dict first, which has the LoRA layers for the transformer,
+ # the text encoder, or both.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+
+ state_dict = _fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
+
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible with Diffusers at the moment. So, we are going to filter out the keys associated with 'dora_scale' from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+ return state_dict
+
+ # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
+ def load_lora_weights(
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ adapter_name: Optional[str] = None,
+ hotswap: bool = False,
+ **kwargs,
+ ):
+ """
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
+ `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+ dict is loaded into `self.transformer`.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+ kwargs (`dict`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # if a dict is passed, copy it instead of modifying it inplace
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+ is_correct_format = all("lora" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->LTXVideoTransformer3DModel
+ def load_lora_into_transformer(
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
+ ):
+ """
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+ into the unet or prefixed with an additional `unet` which can be used to distinguish between text
+ encoder lora layers.
+ transformer (`LTXVideoTransformer3DModel`):
+ The Transformer model to load the LoRA layers into.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+ """
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # Load the layers corresponding to transformer.
+ logger.info(f"Loading {cls.transformer_name}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=None,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ r"""
+ Save the LoRA parameters corresponding to the transformer.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `transformer`.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful during distributed training when you
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+ process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ """
+ state_dict = {}
+
+ if not transformer_lora_layers:
+ raise ValueError("You must pass `transformer_lora_layers`.")
+
+ if transformer_lora_layers:
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+ # Save the model
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
+ def fuse_lora(
+ self,
+ components: List[str] = ["transformer"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+ <Tip warning={true}>
+
+ This is an experimental API.
+
+ </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+ lora_scale (`float`, defaults to 1.0):
+ Controls how much to influence the outputs with the LoRA parameters.
+ safe_fusing (`bool`, defaults to `False`):
+ Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+ adapter_names (`List[str]`, *optional*):
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+ Example:
+
+ ```py
+ from diffusers import DiffusionPipeline
+ import torch
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+ pipeline.fuse_lora(lora_scale=0.7)
+ ```
+ """
+ super().fuse_lora(
+ components=components,
+ lora_scale=lora_scale,
+ safe_fusing=safe_fusing,
+ adapter_names=adapter_names,
+ **kwargs,
+ )
+
+ # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
+ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+ r"""
+ Reverses the effect of
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+ <Tip warning={true}>
+
+ This is an experimental API.
+
+ </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
+ """
+ super().unfuse_lora(components=components, **kwargs)
+
+
+class SanaLoraLoaderMixin(LoraBaseMixin):
+ r"""
+ Load LoRA layers into [`SanaTransformer2DModel`]. Specific to [`SanaPipeline`].
+ """
+
+ _lora_loadable_modules = ["transformer"]
+ transformer_name = TRANSFORMER_NAME
+
+ @classmethod
+ @validate_hf_hub_args
+ # Copied from diffusers.loaders.lora.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ Return state dict for lora weights and the network alphas.
+
+ <Tip warning={true}>
+
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+ This function is experimental and might change in the future.
+
+ </Tip>
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+
+ """
+ # Load the main state dict first, which has the LoRA layers for the transformer,
+ # the text encoder, or both.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+
+ state_dict = _fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
+
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible with Diffusers at the moment. So, we are going to filter out the keys associated with 'dora_scale' from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+ return state_dict
+
+ # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
+ def load_lora_weights(
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ adapter_name: Optional[str] = None,
+ hotswap: bool = False,
+ **kwargs,
+ ):
+ """
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
+ `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+ dict is loaded into `self.transformer`.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+ kwargs (`dict`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # if a dict is passed, copy it instead of modifying it inplace
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+ is_correct_format = all("lora" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SanaTransformer2DModel
+ def load_lora_into_transformer(
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
+ ):
+ """
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+ into the unet or prefixed with an additional `unet` which can be used to distinguish between text
+ encoder lora layers.
+ transformer (`SanaTransformer2DModel`):
+ The Transformer model to load the LoRA layers into.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+ """
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # Load the layers corresponding to transformer.
+ logger.info(f"Loading {cls.transformer_name}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=None,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ r"""
+ Save the LoRA parameters corresponding to the transformer.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `transformer`.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful during distributed training and you
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+ process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ """
+ state_dict = {}
+
+ if not transformer_lora_layers:
+ raise ValueError("You must pass `transformer_lora_layers`.")
+
+ if transformer_lora_layers:
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+ # Save the model
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
+ def fuse_lora(
+ self,
+ components: List[str] = ["transformer"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+ <Tip warning={true}>
+
+ This is an experimental API.
+
+ </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+ lora_scale (`float`, defaults to 1.0):
+ Controls how much to influence the outputs with the LoRA parameters.
+ safe_fusing (`bool`, defaults to `False`):
+ Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+ adapter_names (`List[str]`, *optional*):
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+ Example:
+
+ ```py
+ from diffusers import DiffusionPipeline
+ import torch
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+ pipeline.fuse_lora(lora_scale=0.7)
+ ```
+ """
+ super().fuse_lora(
+ components=components,
+ lora_scale=lora_scale,
+ safe_fusing=safe_fusing,
+ adapter_names=adapter_names,
+ **kwargs,
+ )
+
+ # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
+ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+ r"""
+ Reverses the effect of
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+ <Tip warning={true}>
+
+ This is an experimental API.
+
+ </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
+ """
+ super().unfuse_lora(components=components, **kwargs)
+
+
+class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
+ r"""
+ Load LoRA layers into [`HunyuanVideoTransformer3DModel`]. Specific to [`HunyuanVideoPipeline`].
+ """
+
+ _lora_loadable_modules = ["transformer"]
+ transformer_name = TRANSFORMER_NAME
+
+ @classmethod
+ @validate_hf_hub_args
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ Return state dict for lora weights and the network alphas.
+
+ <Tip warning={true}>
+
+ We support loading original format HunyuanVideo LoRA checkpoints.
+
+ This function is experimental and might change in the future.
+
+ </Tip>
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+
+ """
+ # Load the main state dict first which has the LoRA layers for either of
+ # transformer and text encoder or both.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+
+ state_dict = _fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
+
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible with Diffusers at the moment. So, we are going to filter out the keys associated with `dora_scale` from the state dict. If you think this is a mistake, please open an issue at https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
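+ # Checkpoints in the original HunyuanVideo format use keys such as `img_attn_qkv`;
+ # convert them to the diffusers layout before returning.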
+ is_original_hunyuan_video = any("img_attn_qkv" in k for k in state_dict)
+ if is_original_hunyuan_video:
+ state_dict = _convert_hunyuan_video_lora_to_diffusers(state_dict)
+
+ return state_dict
+
+ # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
+ def load_lora_weights(
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ adapter_name: Optional[str] = None,
+ hotswap: bool = False,
+ **kwargs,
+ ):
+ """
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
+ `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+ dict is loaded into `self.transformer`.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+ kwargs (`dict`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
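+
+ Example (a minimal sketch; the base checkpoint id and the LoRA path and file name below are
+ illustrative placeholders):
+
+ ```py
+ from diffusers import HunyuanVideoPipeline
+ import torch
+
+ pipeline = HunyuanVideoPipeline.from_pretrained(
+ "hunyuanvideo-community/HunyuanVideo", torch_dtype=torch.bfloat16
+ ).to("cuda")
+ pipeline.load_lora_weights("path/to/lora", weight_name="pytorch_lora_weights.safetensors", adapter_name="custom")
+ ```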
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # if a dict is passed, copy it instead of modifying it inplace
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+ is_correct_format = all("lora" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HunyuanVideoTransformer3DModel
+ def load_lora_into_transformer(
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
+ ):
+ """
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+ into the unet or prefixed with an additional `unet` which can be used to distinguish between text
+ encoder lora layers.
+ transformer (`HunyuanVideoTransformer3DModel`):
+ The Transformer model to load the LoRA layers into.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+ """
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # Load the layers corresponding to transformer.
+ logger.info(f"Loading {cls.transformer_name}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=None,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ r"""
+ Save the LoRA parameters corresponding to the transformer.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `transformer`.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful during distributed training and you
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+ process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ """
+ state_dict = {}
+
+ if not transformer_lora_layers:
+ raise ValueError("You must pass `transformer_lora_layers`.")
+
+ if transformer_lora_layers:
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+ # Save the model
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
+ def fuse_lora(
+ self,
+ components: List[str] = ["transformer"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+ <Tip warning={true}>
+
+ This is an experimental API.
+
+ </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+ lora_scale (`float`, defaults to 1.0):
+ Controls how much to influence the outputs with the LoRA parameters.
+ safe_fusing (`bool`, defaults to `False`):
+ Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+ adapter_names (`List[str]`, *optional*):
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+ Example:
+
+ ```py
+ from diffusers import DiffusionPipeline
+ import torch
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+ pipeline.fuse_lora(lora_scale=0.7)
+ ```
+ """
+ super().fuse_lora(
+ components=components,
+ lora_scale=lora_scale,
+ safe_fusing=safe_fusing,
+ adapter_names=adapter_names,
+ **kwargs,
+ )
+
+ # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
+ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+ r"""
+ Reverses the effect of
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+ <Tip warning={true}>
+
+ This is an experimental API.
+
+ </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
+ """
+ super().unfuse_lora(components=components, **kwargs)
+
+
+class Lumina2LoraLoaderMixin(LoraBaseMixin):
+ r"""
+ Load LoRA layers into [`Lumina2Transformer2DModel`]. Specific to [`Lumina2Text2ImgPipeline`].
+ """
+
+ _lora_loadable_modules = ["transformer"]
+ transformer_name = TRANSFORMER_NAME
+
+ @classmethod
+ @validate_hf_hub_args
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ Return state dict for lora weights and the network alphas.
+
+ <Tip warning={true}>
+
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+ This function is experimental and might change in the future.
+
+ </Tip>
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+
+ """
+ # Load the main state dict first which has the LoRA layers for either of
+ # transformer and text encoder or both.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+
+ state_dict = _fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
+
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible with Diffusers at the moment. So, we are going to filter out the keys associated with `dora_scale` from the state dict. If you think this is a mistake, please open an issue at https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+ # Checkpoints in the original (non-diffusers) Lumina2 format prefix their keys with
+ # `diffusion_model.`; convert them to the diffusers layout before returning.
+ non_diffusers = any(k.startswith("diffusion_model.") for k in state_dict)
+ if non_diffusers:
+ state_dict = _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict)
+
+ return state_dict
+
+ # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
+ def load_lora_weights(
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ adapter_name: Optional[str] = None,
+ hotswap: bool = False,
+ **kwargs,
+ ):
+ """
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
+ `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+ dict is loaded into `self.transformer`.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+ kwargs (`dict`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
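+
+ Example (a minimal sketch; the base checkpoint id and the LoRA path and file name below are
+ illustrative placeholders):
+
+ ```py
+ from diffusers import Lumina2Text2ImgPipeline
+ import torch
+
+ pipeline = Lumina2Text2ImgPipeline.from_pretrained(
+ "Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16
+ ).to("cuda")
+ pipeline.load_lora_weights("path/to/lora", weight_name="pytorch_lora_weights.safetensors", adapter_name="custom")
+ ```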
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # if a dict is passed, copy it instead of modifying it inplace
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+ is_correct_format = all("lora" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->Lumina2Transformer2DModel
+ def load_lora_into_transformer(
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
+ ):
+ """
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+ into the unet or prefixed with an additional `unet` which can be used to distinguish between text
+ encoder lora layers.
+ transformer (`Lumina2Transformer2DModel`):
+ The Transformer model to load the LoRA layers into.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+ """
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # Load the layers corresponding to transformer.
+ logger.info(f"Loading {cls.transformer_name}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=None,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ r"""
+ Save the LoRA parameters corresponding to the transformer.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `transformer`.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful during distributed training and you
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+ process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ """
+ state_dict = {}
+
+ if not transformer_lora_layers:
+ raise ValueError("You must pass `transformer_lora_layers`.")
+
+ if transformer_lora_layers:
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+ # Save the model
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ # Copied from diffusers.loaders.lora.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
+ def fuse_lora(
+ self,
+ components: List[str] = ["transformer"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+ <Tip warning={true}>
+
+ This is an experimental API.
+
+ </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+ lora_scale (`float`, defaults to 1.0):
+ Controls how much to influence the outputs with the LoRA parameters.
+ safe_fusing (`bool`, defaults to `False`):
+ Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+ adapter_names (`List[str]`, *optional*):
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+ Example:
+
+ ```py
+ from diffusers import DiffusionPipeline
+ import torch
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+ pipeline.fuse_lora(lora_scale=0.7)
+ ```
+ """
+ super().fuse_lora(
+ components=components,
+ lora_scale=lora_scale,
+ safe_fusing=safe_fusing,
+ adapter_names=adapter_names,
+ **kwargs,
+ )
+
+ # Copied from diffusers.loaders.lora.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
+ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+ r"""
+ Reverses the effect of
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+ <Tip warning={true}>
+
+ This is an experimental API.
+
+ </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
+ """
+ super().unfuse_lora(components=components, **kwargs)
+
+
+class WanLoraLoaderMixin(LoraBaseMixin):
+ r"""
+ Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and [`WanImageToVideoPipeline`].
+ """
+
+ _lora_loadable_modules = ["transformer"]
+ transformer_name = TRANSFORMER_NAME
+
+ @classmethod
+ @validate_hf_hub_args
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ Return state dict for lora weights and the network alphas.
+
+ <Tip warning={true}>
+
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+ This function is experimental and might change in the future.
+
+ </Tip>
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+
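+ Example (a minimal sketch; the LoRA path and file name below are illustrative placeholders):
+
+ ```py
+ from diffusers import WanPipeline
+
+ state_dict = WanPipeline.lora_state_dict("path/to/lora", weight_name="pytorch_lora_weights.safetensors")
+ print(sorted(state_dict)[:5])  # keys are returned in the diffusers layout, converted if necessary
+ ```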
+ """
+ # Load the main state dict first which has the LoRA layers for either of
+ # transformer and text encoder or both.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+
+ state_dict = _fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
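+ # Non-diffusers Wan checkpoints are detected by their key prefixes: `diffusion_model.` for the
+ # original Wan trainer format and `lora_unet_` for musubi-tuner style checkpoints; both are
+ # converted to the diffusers layout.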
+ if any(k.startswith("diffusion_model.") for k in state_dict):
+ state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict)
+ elif any(k.startswith("lora_unet_") for k in state_dict):
+ state_dict = _convert_musubi_wan_lora_to_diffusers(state_dict)
+
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible with Diffusers at the moment. So, we are going to filter out the keys associated with `dora_scale` from the state dict. If you think this is a mistake, please open an issue at https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+ return state_dict
+
+ @classmethod
+ def _maybe_expand_t2v_lora_for_i2v(
+ cls,
+ transformer: torch.nn.Module,
+ state_dict,
+ ):
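+ # A Wan text-to-video (T2V) LoRA has no weights for the image-conditioning cross-attention
+ # projections (`add_k_proj`/`add_v_proj`) that an image-to-video (I2V) transformer expects.
+ # When the target transformer is an I2V model (`config.image_dim` is set), pad the state dict
+ # with zero-valued LoRA weights for those projections so the adapter loads cleanly without
+ # changing its T2V behavior.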
+ if transformer.config.image_dim is None:
+ return state_dict
+
+ if any(k.startswith("transformer.blocks.") for k in state_dict):
+ num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict})
+ is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict)
+
+ if is_i2v_lora:
+ return state_dict
+
+ for i in range(num_blocks):
+ for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
+ state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
+ state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"]
+ )
+ state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
+ state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"]
+ )
+
+ return state_dict
+
+ def load_lora_weights(
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ adapter_name: Optional[str] = None,
+ hotswap: bool = False,
+ **kwargs,
+ ):
+ """
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
+ `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+ dict is loaded into `self.transformer`.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+ kwargs (`dict`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
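+
+ Example (a minimal sketch; the base checkpoint id and the LoRA path and file name below are
+ illustrative placeholders). Loading a T2V LoRA into a Wan I2V pipeline also works: missing
+ image-conditioning layers are padded with zeros before loading.
+
+ ```py
+ from diffusers import WanPipeline
+ import torch
+
+ pipeline = WanPipeline.from_pretrained(
+ "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16
+ ).to("cuda")
+ pipeline.load_lora_weights("path/to/lora", weight_name="pytorch_lora_weights.safetensors", adapter_name="custom")
+ ```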
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # if a dict is passed, copy it instead of modifying it inplace
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+ # convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers
+ state_dict = self._maybe_expand_t2v_lora_for_i2v(
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ state_dict=state_dict,
+ )
+ is_correct_format = all("lora" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel
+ def load_lora_into_transformer(
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
+ ):
+ """
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+ into the unet or prefixed with an additional `unet` which can be used to distinguish between text
+ encoder lora layers.
+ transformer (`WanTransformer3DModel`):
+ The Transformer model to load the LoRA layers into.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+ """
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # Load the layers corresponding to transformer.
+ logger.info(f"Loading {cls.transformer_name}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=None,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ r"""
+ Save the LoRA parameters corresponding to the transformer.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `transformer`.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful during distributed training and you
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+ process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
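+
+ Example (a minimal sketch; the toy LoRA entries and the output directory below are illustrative
+ placeholders, and real entries would come from a trained adapter, for example via
+ `peft.utils.get_peft_model_state_dict`):
+
+ ```py
+ import torch
+
+ from diffusers import WanPipeline
+
+ transformer_lora_layers = {
+ "blocks.0.attn1.to_q.lora_A.weight": torch.zeros(4, 1536),
+ "blocks.0.attn1.to_q.lora_B.weight": torch.zeros(1536, 4),
+ }
+ WanPipeline.save_lora_weights("path/to/output", transformer_lora_layers=transformer_lora_layers)
+ ```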
+ """
+ state_dict = {}
+
+ if not transformer_lora_layers:
+ raise ValueError("You must pass `transformer_lora_layers`.")
+
+ if transformer_lora_layers:
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+ # Save the model
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
+ def fuse_lora(
+ self,
+ components: List[str] = ["transformer"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+ <Tip warning={true}>
+
+ This is an experimental API.
+
+ </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+ lora_scale (`float`, defaults to 1.0):
+ Controls how much to influence the outputs with the LoRA parameters.
+ safe_fusing (`bool`, defaults to `False`):
+ Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+ adapter_names (`List[str]`, *optional*):
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+ Example:
+
+ ```py
+ from diffusers import DiffusionPipeline
+ import torch
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+ pipeline.fuse_lora(lora_scale=0.7)
+ ```
+ """
+ super().fuse_lora(
+ components=components,
+ lora_scale=lora_scale,
+ safe_fusing=safe_fusing,
+ adapter_names=adapter_names,
+ **kwargs,
+ )
+
+ # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
+ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+ r"""
+ Reverses the effect of
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+ <Tip warning={true}>
+
+ This is an experimental API.
+
+ </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
+ """
+ super().unfuse_lora(components=components, **kwargs)
+
+
+class CogView4LoraLoaderMixin(LoraBaseMixin):
+ r"""
+ Load LoRA layers into [`CogView4Transformer2DModel`]. Specific to [`CogView4Pipeline`].
+ """
+
+ _lora_loadable_modules = ["transformer"]
+ transformer_name = TRANSFORMER_NAME
+
+ @classmethod
+ @validate_hf_hub_args
+ # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ Return state dict for lora weights and the network alphas.
+
+ <Tip warning={true}>
+
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+ This function is experimental and might change in the future.
+
+ </Tip>
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+
+ """
+ # Load the main state dict first which has the LoRA layers for either of
+ # transformer and text encoder or both.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+
+ state_dict = _fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
+
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible with Diffusers at the moment. So, we are going to filter out the keys associated with `dora_scale` from the state dict. If you think this is a mistake, please open an issue at https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+ return state_dict
+
+ # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
+ def load_lora_weights(
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ adapter_name: Optional[str] = None,
+ hotswap: bool = False,
+ **kwargs,
+ ):
+ """
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
+ `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+ dict is loaded into `self.transformer`.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+ kwargs (`dict`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
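+
+ Example (a minimal sketch; the base checkpoint id and the LoRA path and file name below are
+ illustrative placeholders):
+
+ ```py
+ from diffusers import CogView4Pipeline
+ import torch
+
+ pipeline = CogView4Pipeline.from_pretrained(
+ "THUDM/CogView4-6B", torch_dtype=torch.bfloat16
+ ).to("cuda")
+ pipeline.load_lora_weights("path/to/lora", weight_name="pytorch_lora_weights.safetensors", adapter_name="custom")
+ ```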
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # if a dict is passed, copy it instead of modifying it inplace
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+ is_correct_format = all("lora" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel
+ def load_lora_into_transformer(
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
+ ):
+ """
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+ into the unet or prefixed with an additional `unet` which can be used to distinguish between text
+ encoder lora layers.
+ transformer (`CogView4Transformer2DModel`):
+ The Transformer model to load the LoRA layers into.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+ """
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # Load the layers corresponding to transformer.
+ logger.info(f"Loading {cls.transformer_name}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=None,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ r"""
+ Save the LoRA parameters corresponding to the transformer.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `transformer`.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful during distributed training and you
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+ process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ """
+ state_dict = {}
+
+ if not transformer_lora_layers:
+ raise ValueError("You must pass `transformer_lora_layers`.")
+
+ if transformer_lora_layers:
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+ # Save the model
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
+ def fuse_lora(
+ self,
+ components: List[str] = ["transformer"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+ <Tip warning={true}>
+
+ This is an experimental API.
+
+ </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+ lora_scale (`float`, defaults to 1.0):
+ Controls how much to influence the outputs with the LoRA parameters.
+ safe_fusing (`bool`, defaults to `False`):
+ Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+ adapter_names (`List[str]`, *optional*):
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+ Example:
+
+ ```py
+ from diffusers import DiffusionPipeline
+ import torch
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+ pipeline.fuse_lora(lora_scale=0.7)
+ ```
+ """
+ super().fuse_lora(
+ components=components,
+ lora_scale=lora_scale,
+ safe_fusing=safe_fusing,
+ adapter_names=adapter_names,
+ **kwargs,
+ )
+
+ # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
+ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+ r"""
+ Reverses the effect of
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+ <Tip warning={true}>
+
+ This is an experimental API.
+
+ </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
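+
+ Example (a minimal sketch; assumes `pipeline` has had `fuse_lora()` called on it beforehand):
+
+ ```py
+ # Restore the original, unfused transformer weights.
+ pipeline.unfuse_lora(components=["transformer"])
+ ```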
+ """
+ super().unfuse_lora(components=components, **kwargs)
+
+
+class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
+ r"""
+ Load LoRA layers into [`HiDreamImageTransformer2DModel`]. Specific to [`HiDreamImagePipeline`].
+ """
+
+ _lora_loadable_modules = ["transformer"]
+ transformer_name = TRANSFORMER_NAME
+
+ @classmethod
+ @validate_hf_hub_args
+ # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ Return state dict for lora weights and the network alphas.
+
+ <Tip warning={true}>
+
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+ This function is experimental and might change in the future.
+
+ </Tip>
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+
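+ Example (a minimal sketch; the repository id below is a placeholder, not a published checkpoint):
+
+ ```py
+ from diffusers import HiDreamImagePipeline
+
+ # Inspect the LoRA keys without loading them into a model.
+ state_dict = HiDreamImagePipeline.lora_state_dict("your-account/hidream-lora")  # placeholder repo id
+ print(list(state_dict.keys())[:5])
+ ```
+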
+ """
+ # Load the main state dict first which has the LoRA layers for either of
+ # transformer and text encoder or both.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+
+ state_dict = _fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
+
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible with Diffusers at the moment. So, we are going to filter out the keys associated with `dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+ return state_dict
+
+ # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
+ def load_lora_weights(
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ adapter_name: Optional[str] = None,
+ hotswap: bool = False,
+ **kwargs,
+ ):
+ """
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`. All kwargs
+ are forwarded to `self.lora_state_dict`. See
+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+ See [`~loaders.HiDreamImageLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+ dict is loaded into `self.transformer`.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+ kwargs (`dict`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
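+
+ Example (a minimal sketch; the base-model and LoRA repository ids below are placeholders you would replace with real checkpoints):
+
+ ```py
+ import torch
+
+ from diffusers import HiDreamImagePipeline
+
+ pipe = HiDreamImagePipeline.from_pretrained(
+     "your-org/hidream-base", torch_dtype=torch.bfloat16  # placeholder repo id
+ ).to("cuda")
+ pipe.load_lora_weights("your-account/hidream-lora", adapter_name="my_style")  # placeholder repo id
+ # The adapter is now active on `pipe.transformer` and will be used for inference.
+ ```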
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # if a dict is passed, copy it instead of modifying it inplace
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+ is_correct_format = all("lora" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HiDreamImageTransformer2DModel
+ def load_lora_into_transformer(
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
+ ):
+ """
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the LoRA layer parameters. The keys can either be indexed directly
+ into the transformer or prefixed with an additional `transformer`, which can be used to distinguish
+ between text encoder LoRA layers.
+ transformer (`HiDreamImageTransformer2DModel`):
+ The Transformer model to load the LoRA layers into.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
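+
+ Example (a minimal sketch; assumes `pipe` is an already-instantiated `HiDreamImagePipeline` and the repository id is a placeholder):
+
+ ```py
+ from diffusers import HiDreamImagePipeline
+
+ state_dict = HiDreamImagePipeline.lora_state_dict("your-account/hidream-lora")  # placeholder repo id
+ HiDreamImagePipeline.load_lora_into_transformer(state_dict, transformer=pipe.transformer)
+ ```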
+ """
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # Load the layers corresponding to transformer.
+ logger.info(f"Loading {cls.transformer_name}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=None,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ r"""
+ Save the LoRA parameters corresponding to the transformer.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `transformer`.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process. Useful during distributed training when you need
+ to call this function on all processes; in that case, set `is_main_process=True` only on the main
+ process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
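+
+ Example (a minimal sketch; assumes `pipe` is a `HiDreamImagePipeline` that already carries a LoRA adapter on its transformer, and the output directory is a placeholder):
+
+ ```py
+ from diffusers import HiDreamImagePipeline
+ from peft.utils import get_peft_model_state_dict
+
+ transformer_lora_layers = get_peft_model_state_dict(pipe.transformer)
+ HiDreamImagePipeline.save_lora_weights(
+     save_directory="./hidream-lora",  # placeholder output directory
+     transformer_lora_layers=transformer_lora_layers,
+ )
+ # The saved directory can be loaded back with `load_lora_weights`.
+ pipe.load_lora_weights("./hidream-lora")
+ ```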
+ """
+ state_dict = {}
+
+ if not transformer_lora_layers:
+ raise ValueError("You must pass `transformer_lora_layers`.")
+
+ if transformer_lora_layers:
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+ # Save the model
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ # Copied from diffusers.loaders.lora.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
+ def fuse_lora(
+ self,
+ components: List[str] = ["transformer"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+ <Tip warning={true}>
+
+ This is an experimental API.
+
+ </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+ lora_scale (`float`, defaults to 1.0):
+ Controls how much to influence the outputs with the LoRA parameters.
+ safe_fusing (`bool`, defaults to `False`):
+ Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+ adapter_names (`List[str]`, *optional*):
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+ Example:
+
+ ```py
+ from diffusers import DiffusionPipeline
+ import torch
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+ pipeline.fuse_lora(lora_scale=0.7)
+ ```
+ """
+ super().fuse_lora(
+ components=components,
+ lora_scale=lora_scale,
+ safe_fusing=safe_fusing,
+ adapter_names=adapter_names,
+ **kwargs,
+ )
+
+ # Copied from diffusers.loaders.lora.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
+ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+ r"""
+ Reverses the effect of
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+ <Tip warning={true}>
+
+ This is an experimental API.
+
+ </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
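+
+ Example (a minimal sketch; assumes `pipe` is a `HiDreamImagePipeline` with a LoRA adapter already loaded):
+
+ ```py
+ # Fuse for faster inference, then restore the original weights afterwards.
+ pipe.fuse_lora(components=["transformer"], lora_scale=0.9)
+ # ... run inference with the fused weights ...
+ pipe.unfuse_lora(components=["transformer"])
+ ```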
+ """
+ super().unfuse_lora(components=components, **kwargs)
+
+
+class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."
+ deprecate("LoraLoaderMixin", "1.0.0", deprecation_message)
+ super().__init__(*args, **kwargs)
diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py
index 280a9fa6e73f..0970ef86f109 100644
--- a/src/diffusers/loaders/lora_base.py
+++ b/src/diffusers/loaders/lora_base.py
@@ -12,924 +12,66 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import copy
-import inspect
-import os
-from pathlib import Path
-from typing import Callable, Dict, List, Optional, Union
-import safetensors
-import torch
-import torch.nn as nn
-from huggingface_hub import model_info
-from huggingface_hub.constants import HF_HUB_OFFLINE
-
-from ..models.modeling_utils import ModelMixin, load_state_dict
-from ..utils import (
- USE_PEFT_BACKEND,
- _get_model_file,
- convert_state_dict_to_diffusers,
- convert_state_dict_to_peft,
- delete_adapter_layers,
- deprecate,
- get_adapter_name,
- get_peft_kwargs,
- is_accelerate_available,
- is_peft_available,
- is_peft_version,
- is_transformers_available,
- is_transformers_version,
- logging,
- recurse_remove_peft_layers,
- scale_lora_layers,
- set_adapter_layers,
- set_weights_and_activate_adapters,
-)
-
-
-if is_transformers_available():
- from transformers import PreTrainedModel
-
- from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
-
-if is_peft_available():
- from peft.tuners.tuners_utils import BaseTunerLayer
-
-if is_accelerate_available():
- from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
-
-logger = logging.get_logger(__name__)
-
-LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
-LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
+from ..utils import deprecate
+from .lora.lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, LoraBaseMixin # noqa: F401
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
- """
- Fuses LoRAs for the text encoder.
-
- Args:
- text_encoder (`torch.nn.Module`):
- The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
- attribute.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]` or `str`):
- The names of the adapters to use.
- """
- merge_kwargs = {"safe_merge": safe_fusing}
+ from .lora.lora_base import fuse_text_encoder_lora
- for module in text_encoder.modules():
- if isinstance(module, BaseTunerLayer):
- if lora_scale != 1.0:
- module.scale_layer(lora_scale)
+ deprecation_message = "Importing `fuse_text_encoder_lora()` from diffusers.loaders.lora_base has been deprecated. Please use `from diffusers.loaders.lora.lora_base import fuse_text_encoder_lora` instead."
+ deprecate("diffusers.loaders.lora_base.fuse_text_encoder_lora", "0.36", deprecation_message)
- # For BC with previous PEFT versions, we need to check the signature
- # of the `merge` method to see if it supports the `adapter_names` argument.
- supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
- if "adapter_names" in supported_merge_kwargs:
- merge_kwargs["adapter_names"] = adapter_names
- elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
- raise ValueError(
- "The `adapter_names` argument is not supported with your PEFT version. "
- "Please upgrade to the latest version of PEFT. `pip install -U peft`"
- )
-
- module.merge(**merge_kwargs)
+ return fuse_text_encoder_lora(
+ text_encoder, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+ )
def unfuse_text_encoder_lora(text_encoder):
- """
- Unfuses LoRAs for the text encoder.
-
- Args:
- text_encoder (`torch.nn.Module`):
- The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
- attribute.
- """
- for module in text_encoder.modules():
- if isinstance(module, BaseTunerLayer):
- module.unmerge()
-
-
-def set_adapters_for_text_encoder(
- adapter_names: Union[List[str], str],
- text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821
- text_encoder_weights: Optional[Union[float, List[float], List[None]]] = None,
-):
- """
- Sets the adapter layers for the text encoder.
-
- Args:
- adapter_names (`List[str]` or `str`):
- The names of the adapters to use.
- text_encoder (`torch.nn.Module`, *optional*):
- The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
- attribute.
- text_encoder_weights (`List[float]`, *optional*):
- The weights to use for the text encoder. If `None`, the weights are set to `1.0` for all the adapters.
- """
- if text_encoder is None:
- raise ValueError(
- "The pipeline does not have a default `pipe.text_encoder` class. Please make sure to pass a `text_encoder` instead."
- )
-
- def process_weights(adapter_names, weights):
- # Expand weights into a list, one entry per adapter
- # e.g. for 2 adapters: 7 -> [7,7] ; [3, None] -> [3, None]
- if not isinstance(weights, list):
- weights = [weights] * len(adapter_names)
+ from .lora.lora_base import unfuse_text_encoder_lora
- if len(adapter_names) != len(weights):
- raise ValueError(
- f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}"
- )
+ deprecation_message = "Importing `unfuse_text_encoder_lora()` from diffusers.loaders.lora_base has been deprecated. Please use `from diffusers.loaders.lora.lora_base import unfuse_text_encoder_lora` instead."
+ deprecate("diffusers.loaders.lora_base.unfuse_text_encoder_lora", "0.36", deprecation_message)
- # Set None values to default of 1.0
- # e.g. [7,7] -> [7,7] ; [3, None] -> [3,1]
- weights = [w if w is not None else 1.0 for w in weights]
+ return unfuse_text_encoder_lora(text_encoder)
- return weights
- adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
- text_encoder_weights = process_weights(adapter_names, text_encoder_weights)
- set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights)
-
-
-def disable_lora_for_text_encoder(text_encoder: Optional["PreTrainedModel"] = None):
- """
- Disables the LoRA layers for the text encoder.
-
- Args:
- text_encoder (`torch.nn.Module`, *optional*):
- The text encoder module to disable the LoRA layers for. If `None`, it will try to get the `text_encoder`
- attribute.
- """
- if text_encoder is None:
- raise ValueError("Text Encoder not found.")
- set_adapter_layers(text_encoder, enabled=False)
-
-
-def enable_lora_for_text_encoder(text_encoder: Optional["PreTrainedModel"] = None):
- """
- Enables the LoRA layers for the text encoder.
-
- Args:
- text_encoder (`torch.nn.Module`, *optional*):
- The text encoder module to enable the LoRA layers for. If `None`, it will try to get the `text_encoder`
- attribute.
- """
- if text_encoder is None:
- raise ValueError("Text Encoder not found.")
- set_adapter_layers(text_encoder, enabled=True)
-
-
-def _remove_text_encoder_monkey_patch(text_encoder):
- recurse_remove_peft_layers(text_encoder)
- if getattr(text_encoder, "peft_config", None) is not None:
- del text_encoder.peft_config
- text_encoder._hf_peft_config_loaded = None
-
-
-def _fetch_state_dict(
- pretrained_model_name_or_path_or_dict,
- weight_name,
- use_safetensors,
- local_files_only,
- cache_dir,
- force_download,
- proxies,
- token,
- revision,
- subfolder,
- user_agent,
- allow_pickle,
-):
- model_file = None
- if not isinstance(pretrained_model_name_or_path_or_dict, dict):
- # Let's first try to load .safetensors weights
- if (use_safetensors and weight_name is None) or (
- weight_name is not None and weight_name.endswith(".safetensors")
- ):
- try:
- # Here we're relaxing the loading check to enable more Inference API
- # friendliness where sometimes, it's not at all possible to automatically
- # determine `weight_name`.
- if weight_name is None:
- weight_name = _best_guess_weight_name(
- pretrained_model_name_or_path_or_dict,
- file_extension=".safetensors",
- local_files_only=local_files_only,
- )
- model_file = _get_model_file(
- pretrained_model_name_or_path_or_dict,
- weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- local_files_only=local_files_only,
- token=token,
- revision=revision,
- subfolder=subfolder,
- user_agent=user_agent,
- )
- state_dict = safetensors.torch.load_file(model_file, device="cpu")
- except (IOError, safetensors.SafetensorError) as e:
- if not allow_pickle:
- raise e
- # try loading non-safetensors weights
- model_file = None
- pass
-
- if model_file is None:
- if weight_name is None:
- weight_name = _best_guess_weight_name(
- pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
- )
- model_file = _get_model_file(
- pretrained_model_name_or_path_or_dict,
- weights_name=weight_name or LORA_WEIGHT_NAME,
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- local_files_only=local_files_only,
- token=token,
- revision=revision,
- subfolder=subfolder,
- user_agent=user_agent,
- )
- state_dict = load_state_dict(model_file)
- else:
- state_dict = pretrained_model_name_or_path_or_dict
-
- return state_dict
-
-
-def _best_guess_weight_name(
- pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
+def set_adapters_for_text_encoder(
+ adapter_names,
+ text_encoder=None,
+ text_encoder_weights=None,
):
- if local_files_only or HF_HUB_OFFLINE:
- raise ValueError("When using the offline mode, you must specify a `weight_name`.")
-
- targeted_files = []
+ from .lora.lora_base import set_adapters_for_text_encoder
- if os.path.isfile(pretrained_model_name_or_path_or_dict):
- return
- elif os.path.isdir(pretrained_model_name_or_path_or_dict):
- targeted_files = [f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)]
- else:
- files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings
- targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]
- if len(targeted_files) == 0:
- return
+ deprecation_message = "Importing `set_adapters_for_text_encoder()` from diffusers.loaders.lora_base has been deprecated. Please use `from diffusers.loaders.lora.lora_base import set_adapters_for_text_encoder` instead."
+ deprecate("diffusers.loaders.lora_base.set_adapters_for_text_encoder", "0.36", deprecation_message)
- # "scheduler" does not correspond to a LoRA checkpoint.
- # "optimizer" does not correspond to a LoRA checkpoint
- # only top-level checkpoints are considered and not the other ones, hence "checkpoint".
- unallowed_substrings = {"scheduler", "optimizer", "checkpoint"}
- targeted_files = list(
- filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files)
+ return set_adapters_for_text_encoder(
+ adapter_names=adapter_names, text_encoder=text_encoder, text_encoder_weights=text_encoder_weights
)
- if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files):
- targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files))
- elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files):
- targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files))
-
- if len(targeted_files) > 1:
- raise ValueError(
- f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
- )
- weight_name = targeted_files[0]
- return weight_name
-
-
-def _load_lora_into_text_encoder(
- state_dict,
- network_alphas,
- text_encoder,
- prefix=None,
- lora_scale=1.0,
- text_encoder_name="text_encoder",
- adapter_name=None,
- _pipeline=None,
- low_cpu_mem_usage=False,
- hotswap: bool = False,
-):
- if not USE_PEFT_BACKEND:
- raise ValueError("PEFT backend is required for this method.")
-
- peft_kwargs = {}
- if low_cpu_mem_usage:
- if not is_peft_version(">=", "0.13.1"):
- raise ValueError(
- "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
- )
- if not is_transformers_version(">", "4.45.2"):
- # Note from sayakpaul: It's not in `transformers` stable yet.
- # https://github.com/huggingface/transformers/pull/33725/
- raise ValueError(
- "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
- )
- peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
-
- from peft import LoraConfig
-
- # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
- # then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as
- # their prefixes.
- prefix = text_encoder_name if prefix is None else prefix
-
- # Safe prefix to check with.
- if hotswap and any(text_encoder_name in key for key in state_dict.keys()):
- raise ValueError("At the moment, hotswapping is not supported for text encoders, please pass `hotswap=False`.")
-
- # Load the layers corresponding to text encoder and make necessary adjustments.
- if prefix is not None:
- state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
-
- if len(state_dict) > 0:
- logger.info(f"Loading {prefix}.")
- rank = {}
- state_dict = convert_state_dict_to_diffusers(state_dict)
-
- # convert state dict
- state_dict = convert_state_dict_to_peft(state_dict)
-
- for name, _ in text_encoder_attn_modules(text_encoder):
- for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
- rank_key = f"{name}.{module}.lora_B.weight"
- if rank_key not in state_dict:
- continue
- rank[rank_key] = state_dict[rank_key].shape[1]
-
- for name, _ in text_encoder_mlp_modules(text_encoder):
- for module in ("fc1", "fc2"):
- rank_key = f"{name}.{module}.lora_B.weight"
- if rank_key not in state_dict:
- continue
- rank[rank_key] = state_dict[rank_key].shape[1]
-
- if network_alphas is not None:
- alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
- network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
-
- lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
-
- if "use_dora" in lora_config_kwargs:
- if lora_config_kwargs["use_dora"]:
- if is_peft_version("<", "0.9.0"):
- raise ValueError(
- "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
- )
- else:
- if is_peft_version("<", "0.9.0"):
- lora_config_kwargs.pop("use_dora")
-
- if "lora_bias" in lora_config_kwargs:
- if lora_config_kwargs["lora_bias"]:
- if is_peft_version("<=", "0.13.2"):
- raise ValueError(
- "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
- )
- else:
- if is_peft_version("<=", "0.13.2"):
- lora_config_kwargs.pop("lora_bias")
-
- lora_config = LoraConfig(**lora_config_kwargs)
-
- # adapter_name
- if adapter_name is None:
- adapter_name = get_adapter_name(text_encoder)
-
- is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
-
- # inject LoRA layers and load the state dict
- # in transformers we automatically check whether the adapter name is already in use or not
- text_encoder.load_adapter(
- adapter_name=adapter_name,
- adapter_state_dict=state_dict,
- peft_config=lora_config,
- **peft_kwargs,
- )
-
- # scale LoRA layers with `lora_scale`
- scale_lora_layers(text_encoder, weight=lora_scale)
-
- text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
-
- # Offload back.
- if is_model_cpu_offload:
- _pipeline.enable_model_cpu_offload()
- elif is_sequential_cpu_offload:
- _pipeline.enable_sequential_cpu_offload()
- # Unsafe code />
-
- if prefix is not None and not state_dict:
- logger.warning(
- f"No LoRA keys associated to {text_encoder.__class__.__name__} found with the {prefix=}. "
- "This is safe to ignore if LoRA state dict didn't originally have any "
- f"{text_encoder.__class__.__name__} related params. You can also try specifying `prefix=None` "
- "to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
- "https://github.com/huggingface/diffusers/issues/new"
- )
-
-
-def _func_optionally_disable_offloading(_pipeline):
- is_model_cpu_offload = False
- is_sequential_cpu_offload = False
-
- if _pipeline is not None and _pipeline.hf_device_map is None:
- for _, component in _pipeline.components.items():
- if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
- if not is_model_cpu_offload:
- is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
- if not is_sequential_cpu_offload:
- is_sequential_cpu_offload = (
- isinstance(component._hf_hook, AlignDevicesHook)
- or hasattr(component._hf_hook, "hooks")
- and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
- )
-
- logger.info(
- "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
- )
- remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
-
- return (is_model_cpu_offload, is_sequential_cpu_offload)
-
-
-class LoraBaseMixin:
- """Utility class for handling LoRAs."""
-
- _lora_loadable_modules = []
- num_fused_loras = 0
-
- def load_lora_weights(self, **kwargs):
- raise NotImplementedError("`load_lora_weights()` is not implemented.")
-
- @classmethod
- def save_lora_weights(cls, **kwargs):
- raise NotImplementedError("`save_lora_weights()` not implemented.")
-
- @classmethod
- def lora_state_dict(cls, **kwargs):
- raise NotImplementedError("`lora_state_dict()` is not implemented.")
-
- @classmethod
- def _optionally_disable_offloading(cls, _pipeline):
- """
- Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
-
- Args:
- _pipeline (`DiffusionPipeline`):
- The pipeline to disable offloading for.
-
- Returns:
- tuple:
- A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
- """
- return _func_optionally_disable_offloading(_pipeline=_pipeline)
-
- @classmethod
- def _fetch_state_dict(cls, *args, **kwargs):
- deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
- deprecate("_fetch_state_dict", "0.35.0", deprecation_message)
- return _fetch_state_dict(*args, **kwargs)
-
- @classmethod
- def _best_guess_weight_name(cls, *args, **kwargs):
- deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
- deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
- return _best_guess_weight_name(*args, **kwargs)
-
- def unload_lora_weights(self):
- """
- Unloads the LoRA parameters.
-
- Examples:
-
- ```python
- >>> # Assuming `pipeline` is already loaded with the LoRA parameters.
- >>> pipeline.unload_lora_weights()
- >>> ...
- ```
- """
- if not USE_PEFT_BACKEND:
- raise ValueError("PEFT backend is required for this method.")
-
- for component in self._lora_loadable_modules:
- model = getattr(self, component, None)
- if model is not None:
- if issubclass(model.__class__, ModelMixin):
- model.unload_lora()
- elif issubclass(model.__class__, PreTrainedModel):
- _remove_text_encoder_monkey_patch(model)
-
- def fuse_lora(
- self,
- components: List[str] = [],
- lora_scale: float = 1.0,
- safe_fusing: bool = False,
- adapter_names: Optional[List[str]] = None,
- **kwargs,
- ):
- r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
- <Tip warning={true}>
-
- This is an experimental API.
-
- </Tip>
-
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
- """
- if "fuse_unet" in kwargs:
- depr_message = "Passing `fuse_unet` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_unet` will be removed in a future version."
- deprecate(
- "fuse_unet",
- "1.0.0",
- depr_message,
- )
- if "fuse_transformer" in kwargs:
- depr_message = "Passing `fuse_transformer` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_transformer` will be removed in a future version."
- deprecate(
- "fuse_transformer",
- "1.0.0",
- depr_message,
- )
- if "fuse_text_encoder" in kwargs:
- depr_message = "Passing `fuse_text_encoder` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_text_encoder` will be removed in a future version."
- deprecate(
- "fuse_text_encoder",
- "1.0.0",
- depr_message,
- )
-
- if len(components) == 0:
- raise ValueError("`components` cannot be an empty list.")
-
- for fuse_component in components:
- if fuse_component not in self._lora_loadable_modules:
- raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.")
-
- model = getattr(self, fuse_component, None)
- if model is not None:
- # check if diffusers model
- if issubclass(model.__class__, ModelMixin):
- model.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
- # handle transformers models.
- if issubclass(model.__class__, PreTrainedModel):
- fuse_text_encoder_lora(
- model, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
- )
-
- self.num_fused_loras += 1
-
- def unfuse_lora(self, components: List[str] = [], **kwargs):
- r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
- <Tip warning={true}>
-
- This is an experimental API.
-
- </Tip>
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
- unfuse_text_encoder (`bool`, defaults to `True`):
- Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
- LoRA parameters then it won't have any effect.
- """
- if "unfuse_unet" in kwargs:
- depr_message = "Passing `unfuse_unet` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_unet` will be removed in a future version."
- deprecate(
- "unfuse_unet",
- "1.0.0",
- depr_message,
- )
- if "unfuse_transformer" in kwargs:
- depr_message = "Passing `unfuse_transformer` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_transformer` will be removed in a future version."
- deprecate(
- "unfuse_transformer",
- "1.0.0",
- depr_message,
- )
- if "unfuse_text_encoder" in kwargs:
- depr_message = "Passing `unfuse_text_encoder` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_text_encoder` will be removed in a future version."
- deprecate(
- "unfuse_text_encoder",
- "1.0.0",
- depr_message,
- )
-
- if len(components) == 0:
- raise ValueError("`components` cannot be an empty list.")
-
- for fuse_component in components:
- if fuse_component not in self._lora_loadable_modules:
- raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.")
-
- model = getattr(self, fuse_component, None)
- if model is not None:
- if issubclass(model.__class__, (ModelMixin, PreTrainedModel)):
- for module in model.modules():
- if isinstance(module, BaseTunerLayer):
- module.unmerge()
-
- self.num_fused_loras -= 1
-
- def set_adapters(
- self,
- adapter_names: Union[List[str], str],
- adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
- ):
- if isinstance(adapter_weights, dict):
- components_passed = set(adapter_weights.keys())
- lora_components = set(self._lora_loadable_modules)
-
- invalid_components = sorted(components_passed - lora_components)
- if invalid_components:
- logger.warning(
- f"The following components in `adapter_weights` are not part of the pipeline: {invalid_components}. "
- f"Available components that are LoRA-compatible: {self._lora_loadable_modules}. So, weights belonging "
- "to the invalid components will be removed and ignored."
- )
- adapter_weights = {k: v for k, v in adapter_weights.items() if k not in invalid_components}
-
- adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
- adapter_weights = copy.deepcopy(adapter_weights)
-
- # Expand weights into a list, one entry per adapter
- if not isinstance(adapter_weights, list):
- adapter_weights = [adapter_weights] * len(adapter_names)
-
- if len(adapter_names) != len(adapter_weights):
- raise ValueError(
- f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(adapter_weights)}"
- )
-
- list_adapters = self.get_list_adapters() # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
- # eg ["adapter1", "adapter2"]
- all_adapters = {adapter for adapters in list_adapters.values() for adapter in adapters}
- missing_adapters = set(adapter_names) - all_adapters
- if len(missing_adapters) > 0:
- raise ValueError(
- f"Adapter name(s) {missing_adapters} not in the list of present adapters: {all_adapters}."
- )
-
- # eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}
- invert_list_adapters = {
- adapter: [part for part, adapters in list_adapters.items() if adapter in adapters]
- for adapter in all_adapters
- }
-
- # Decompose weights into weights for denoiser and text encoders.
- _component_adapter_weights = {}
- for component in self._lora_loadable_modules:
- model = getattr(self, component)
-
- for adapter_name, weights in zip(adapter_names, adapter_weights):
- if isinstance(weights, dict):
- component_adapter_weights = weights.pop(component, None)
- if component_adapter_weights is not None and component not in invert_list_adapters[adapter_name]:
- logger.warning(
- (
- f"Lora weight dict for adapter '{adapter_name}' contains {component},"
- f"but this will be ignored because {adapter_name} does not contain weights for {component}."
- f"Valid parts for {adapter_name} are: {invert_list_adapters[adapter_name]}."
- )
- )
-
- else:
- component_adapter_weights = weights
-
- _component_adapter_weights.setdefault(component, [])
- _component_adapter_weights[component].append(component_adapter_weights)
-
- if issubclass(model.__class__, ModelMixin):
- model.set_adapters(adapter_names, _component_adapter_weights[component])
- elif issubclass(model.__class__, PreTrainedModel):
- set_adapters_for_text_encoder(adapter_names, model, _component_adapter_weights[component])
-
- def disable_lora(self):
- if not USE_PEFT_BACKEND:
- raise ValueError("PEFT backend is required for this method.")
-
- for component in self._lora_loadable_modules:
- model = getattr(self, component, None)
- if model is not None:
- if issubclass(model.__class__, ModelMixin):
- model.disable_lora()
- elif issubclass(model.__class__, PreTrainedModel):
- disable_lora_for_text_encoder(model)
-
- def enable_lora(self):
- if not USE_PEFT_BACKEND:
- raise ValueError("PEFT backend is required for this method.")
-
- for component in self._lora_loadable_modules:
- model = getattr(self, component, None)
- if model is not None:
- if issubclass(model.__class__, ModelMixin):
- model.enable_lora()
- elif issubclass(model.__class__, PreTrainedModel):
- enable_lora_for_text_encoder(model)
-
- def delete_adapters(self, adapter_names: Union[List[str], str]):
- """
- Args:
- Deletes the LoRA layers of `adapter_name` for the unet and text-encoder(s).
- adapter_names (`Union[List[str], str]`):
- The names of the adapter to delete. Can be a single string or a list of strings
- """
- if not USE_PEFT_BACKEND:
- raise ValueError("PEFT backend is required for this method.")
-
- if isinstance(adapter_names, str):
- adapter_names = [adapter_names]
-
- for component in self._lora_loadable_modules:
- model = getattr(self, component, None)
- if model is not None:
- if issubclass(model.__class__, ModelMixin):
- model.delete_adapters(adapter_names)
- elif issubclass(model.__class__, PreTrainedModel):
- for adapter_name in adapter_names:
- delete_adapter_layers(model, adapter_name)
-
- def get_active_adapters(self) -> List[str]:
- """
- Gets the list of the current active adapters.
-
- Example:
-
- ```python
- from diffusers import DiffusionPipeline
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0",
- ).to("cuda")
- pipeline.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
- pipeline.get_active_adapters()
- ```
- """
- if not USE_PEFT_BACKEND:
- raise ValueError(
- "PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`"
- )
-
- active_adapters = []
-
- for component in self._lora_loadable_modules:
- model = getattr(self, component, None)
- if model is not None and issubclass(model.__class__, ModelMixin):
- for module in model.modules():
- if isinstance(module, BaseTunerLayer):
- active_adapters = module.active_adapters
- break
-
- return active_adapters
-
- def get_list_adapters(self) -> Dict[str, List[str]]:
- """
- Gets the current list of all available adapters in the pipeline.
- """
- if not USE_PEFT_BACKEND:
- raise ValueError(
- "PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`"
- )
-
- set_adapters = {}
-
- for component in self._lora_loadable_modules:
- model = getattr(self, component, None)
- if (
- model is not None
- and issubclass(model.__class__, (ModelMixin, PreTrainedModel))
- and hasattr(model, "peft_config")
- ):
- set_adapters[component] = list(model.peft_config.keys())
-
- return set_adapters
-
- def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, str, int]) -> None:
- """
- Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case
- you want to load multiple adapters and free some GPU memory.
-
- Args:
- adapter_names (`List[str]`):
- List of adapters to send device to.
- device (`Union[torch.device, str, int]`):
- Device to send the adapters to. Can be either a torch device, a str or an integer.
- """
- if not USE_PEFT_BACKEND:
- raise ValueError("PEFT backend is required for this method.")
-
- for component in self._lora_loadable_modules:
- model = getattr(self, component, None)
- if model is not None:
- for module in model.modules():
- if isinstance(module, BaseTunerLayer):
- for adapter_name in adapter_names:
- module.lora_A[adapter_name].to(device)
- module.lora_B[adapter_name].to(device)
- # this is a param, not a module, so device placement is not in-place -> re-assign
- if hasattr(module, "lora_magnitude_vector") and module.lora_magnitude_vector is not None:
- if adapter_name in module.lora_magnitude_vector:
- module.lora_magnitude_vector[adapter_name] = module.lora_magnitude_vector[
- adapter_name
- ].to(device)
-
- @staticmethod
- def pack_weights(layers, prefix):
- layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
- layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
- return layers_state_dict
-
- @staticmethod
- def write_lora_layers(
- state_dict: Dict[str, torch.Tensor],
- save_directory: str,
- is_main_process: bool,
- weight_name: str,
- save_function: Callable,
- safe_serialization: bool,
- ):
- if os.path.isfile(save_directory):
- logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
- return
-
- if save_function is None:
- if safe_serialization:
- def save_function(weights, filename):
- return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+def disable_lora_for_text_encoder(text_encoder=None):
+ from .lora.lora_base import disable_lora_for_text_encoder
- else:
- save_function = torch.save
+ deprecation_message = "Importing `disable_lora_for_text_encoder()` from diffusers.loaders.lora_base has been deprecated. Please use `from diffusers.loaders.lora.lora_base import disable_lora_for_text_encoder` instead."
+ deprecate("diffusers.loaders.lora_base.disable_lora_for_text_encoder", "0.36", deprecation_message)
- os.makedirs(save_directory, exist_ok=True)
+ return disable_lora_for_text_encoder(text_encoder=text_encoder)
- if weight_name is None:
- if safe_serialization:
- weight_name = LORA_WEIGHT_NAME_SAFE
- else:
- weight_name = LORA_WEIGHT_NAME
- save_path = Path(save_directory, weight_name).as_posix()
- save_function(state_dict, save_path)
- logger.info(f"Model weights saved in {save_path}")
+def enable_lora_for_text_encoder(text_encoder=None):
+ from .lora.lora_base import enable_lora_for_text_encoder
- @property
- def lora_scale(self) -> float:
- # property function that returns the lora scale which can be set at run time by the pipeline.
- # if _lora_scale has not been set, return 1
- return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
+ deprecation_message = "Importing `enable_lora_for_text_encoder()` from diffusers.loaders.lora_base has been deprecated. Please use `from diffusers.loaders.lora.lora_base import enable_lora_for_text_encoder` instead."
+ deprecate("diffusers.loaders.lora_base.enable_lora_for_text_encoder", "0.36", deprecation_message)
- def enable_lora_hotswap(self, **kwargs) -> None:
- """Enables the possibility to hotswap LoRA adapters.
+ return enable_lora_for_text_encoder(text_encoder=text_encoder)
- Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of
- the loaded adapters differ.
- Args:
- target_rank (`int`):
- The highest rank among all the adapters that will be loaded.
- check_compiled (`str`, *optional*, defaults to `"error"`):
- How to handle the case when the model is already compiled, which should generally be avoided. The
- options are:
- - "error" (default): raise an error
- - "warn": issue a warning
- - "ignore": do nothing
- """
- for key, component in self.components.items():
- if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules):
- component.enable_lora_hotswap(**kwargs)
+class LoraBaseMixin(LoraBaseMixin):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = "Importing `LoraBaseMixin` from diffusers.loaders.lora_base has been deprecated. Please use `from diffusers.loaders.lora.lora_base import LoraBaseMixin` instead."
+ deprecate("diffusers.loaders.lora_base.LoraBaseMixin", "0.36", deprecation_message)
+ super().__init__(*args, **kwargs)
diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 50a99cee1d23..1494e6e0b82a 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -12,5672 +12,133 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
-from typing import Callable, Dict, List, Optional, Union
-import torch
-from huggingface_hub.utils import validate_hf_hub_args
-
-from ..utils import (
- USE_PEFT_BACKEND,
- deprecate,
- get_submodule_by_name,
- is_bitsandbytes_available,
- is_gguf_available,
- is_peft_available,
- is_peft_version,
- is_torch_version,
- is_transformers_available,
- is_transformers_version,
- logging,
-)
-from .lora_base import ( # noqa
- LORA_WEIGHT_NAME,
- LORA_WEIGHT_NAME_SAFE,
- LoraBaseMixin,
- _fetch_state_dict,
- _load_lora_into_text_encoder,
+from ..utils import deprecate
+from .lora.lora_pipeline import (
+ TEXT_ENCODER_NAME, # noqa: F401
+ TRANSFORMER_NAME, # noqa: F401
+ UNET_NAME, # noqa: F401
+ AmusedLoraLoaderMixin,
+ AuraFlowLoraLoaderMixin,
+ CogVideoXLoraLoaderMixin,
+ CogView4LoraLoaderMixin,
+ FluxLoraLoaderMixin,
+ HiDreamImageLoraLoaderMixin,
+ HunyuanVideoLoraLoaderMixin,
+ LTXVideoLoraLoaderMixin,
+ Lumina2LoraLoaderMixin,
+ Mochi1LoraLoaderMixin,
+ SanaLoraLoaderMixin,
+ SD3LoraLoaderMixin,
+ StableDiffusionLoraLoaderMixin,
+ StableDiffusionXLLoraLoaderMixin,
+ WanLoraLoaderMixin,
)
-from .lora_conversion_utils import (
- _convert_bfl_flux_control_lora_to_diffusers,
- _convert_hunyuan_video_lora_to_diffusers,
- _convert_kohya_flux_lora_to_diffusers,
- _convert_musubi_wan_lora_to_diffusers,
- _convert_non_diffusers_lora_to_diffusers,
- _convert_non_diffusers_lumina2_lora_to_diffusers,
- _convert_non_diffusers_wan_lora_to_diffusers,
- _convert_xlabs_flux_lora_to_diffusers,
- _maybe_map_sgm_blocks_to_diffusers,
-)
-
-
-_LOW_CPU_MEM_USAGE_DEFAULT_LORA = False
-if is_torch_version(">=", "1.9.0"):
- if (
- is_peft_available()
- and is_peft_version(">=", "0.13.1")
- and is_transformers_available()
- and is_transformers_version(">", "4.45.2")
- ):
- _LOW_CPU_MEM_USAGE_DEFAULT_LORA = True
-
-
-logger = logging.get_logger(__name__)
-
-TEXT_ENCODER_NAME = "text_encoder"
-UNET_NAME = "unet"
-TRANSFORMER_NAME = "transformer"
-
-_MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX = {"x_embedder": "in_channels"}
-
-
-def _maybe_dequantize_weight_for_expanded_lora(model, module):
- if is_bitsandbytes_available():
- from ..quantizers.bitsandbytes import dequantize_bnb_weight
-
- if is_gguf_available():
- from ..quantizers.gguf.utils import dequantize_gguf_tensor
-
- is_bnb_4bit_quantized = module.weight.__class__.__name__ == "Params4bit"
- is_gguf_quantized = module.weight.__class__.__name__ == "GGUFParameter"
-
- if is_bnb_4bit_quantized and not is_bitsandbytes_available():
- raise ValueError(
- "The checkpoint seems to have been quantized with `bitsandbytes` (4bits). Install `bitsandbytes` to load quantized checkpoints."
- )
- if is_gguf_quantized and not is_gguf_available():
- raise ValueError(
- "The checkpoint seems to have been quantized with `gguf`. Install `gguf` to load quantized checkpoints."
- )
-
- weight_on_cpu = False
- if module.weight.device.type == "cpu":
- weight_on_cpu = True
-
- device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
- if is_bnb_4bit_quantized:
- module_weight = dequantize_bnb_weight(
- module.weight.to(device) if weight_on_cpu else module.weight,
- state=module.weight.quant_state,
- dtype=model.dtype,
- ).data
- elif is_gguf_quantized:
- module_weight = dequantize_gguf_tensor(
- module.weight.to(device) if weight_on_cpu else module.weight,
- )
- module_weight = module_weight.to(model.dtype)
- else:
- module_weight = module.weight.data
-
- if weight_on_cpu:
- module_weight = module_weight.cpu()
-
- return module_weight
-
-
-class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
- r"""
- Load LoRA layers into Stable Diffusion [`UNet2DConditionModel`] and
- [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
- """
-
- _lora_loadable_modules = ["unet", "text_encoder"]
- unet_name = UNET_NAME
- text_encoder_name = TEXT_ENCODER_NAME
-
- def load_lora_weights(
- self,
- pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
- adapter_name: Optional[str] = None,
- hotswap: bool = False,
- **kwargs,
- ):
- """Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
- `self.text_encoder`.
-
- All kwargs are forwarded to `self.lora_state_dict`.
-
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
- loaded.
-
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is
- loaded into `self.unet`.
-
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state
- dict is loaded into `self.text_encoder`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
- in-place. This means that, instead of loading an additional adapter, this will take the existing
- adapter weights and replace them with the weights of the new adapter. This can be faster and more
- memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
- torch.compile, loading the new adapter does not require recompilation of the model. When using
- hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
-
- If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
- to call an additional method before loading the adapter:
-
- ```py
- pipeline = ... # load diffusers pipeline
- max_rank = ... # the highest rank among all LoRAs that you want to load
- # call *before* compiling and loading the LoRA adapter
- pipeline.enable_lora_hotswap(target_rank=max_rank)
- pipeline.load_lora_weights(file_name)
- # optionally compile the model now
- ```
-
- Note that hotswapping adapters of the text encoder is not yet supported. There are some further
- limitations to this technique, which are documented here:
- https://huggingface.co/docs/peft/main/en/package_reference/hotswap
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
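-
- Example (a minimal sketch; the LoRA repository id and weight file name below are placeholders, not real
- checkpoints):
-
- ```py
- import torch
- from diffusers import StableDiffusionPipeline
-
- pipeline = StableDiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
- ).to("cuda")
- # loads matching LoRA layers into both the UNet and the text encoder
- pipeline.load_lora_weights("some-user/some-sd15-lora", weight_name="lora.safetensors", adapter_name="style")
- ```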
- """
- if not USE_PEFT_BACKEND:
- raise ValueError("PEFT backend is required for this method.")
-
- low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
- if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
- raise ValueError(
- "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
- )
-
- # if a dict is passed, copy it instead of modifying it inplace
- if isinstance(pretrained_model_name_or_path_or_dict, dict):
- pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
-
- # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
- state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
-
- is_correct_format = all("lora" in key for key in state_dict.keys())
- if not is_correct_format:
- raise ValueError("Invalid LoRA checkpoint.")
-
- self.load_lora_into_unet(
- state_dict,
- network_alphas=network_alphas,
- unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
- adapter_name=adapter_name,
- _pipeline=self,
- low_cpu_mem_usage=low_cpu_mem_usage,
- hotswap=hotswap,
- )
- self.load_lora_into_text_encoder(
- state_dict,
- network_alphas=network_alphas,
- text_encoder=getattr(self, self.text_encoder_name)
- if not hasattr(self, "text_encoder")
- else self.text_encoder,
- lora_scale=self.lora_scale,
- adapter_name=adapter_name,
- _pipeline=self,
- low_cpu_mem_usage=low_cpu_mem_usage,
- hotswap=hotswap,
- )
-
- @classmethod
- @validate_hf_hub_args
- def lora_state_dict(
- cls,
- pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
- **kwargs,
- ):
- r"""
- Return state dict for lora weights and the network alphas.
-
- <Tip warning={true}>
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
- </Tip>
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
- weight_name (`str`, *optional*, defaults to None):
- Name of the serialized state dict file.
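-
- A rough usage sketch (the repository id and weight file name are placeholders):
-
- ```py
- from diffusers import StableDiffusionPipeline
-
- # downloads and converts the checkpoint without modifying any pipeline weights
- state_dict, network_alphas = StableDiffusionPipeline.lora_state_dict(
- "some-user/some-sd15-lora", weight_name="lora.safetensors"
- )
- # after conversion the keys are typically prefixed with `unet.` and/or `text_encoder.`
- print(len(state_dict), network_alphas is None)
- ```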
- """
- # Load the main state dict first which has the LoRA layers for either of
- # UNet and text encoder or both.
- cache_dir = kwargs.pop("cache_dir", None)
- force_download = kwargs.pop("force_download", False)
- proxies = kwargs.pop("proxies", None)
- local_files_only = kwargs.pop("local_files_only", None)
- token = kwargs.pop("token", None)
- revision = kwargs.pop("revision", None)
- subfolder = kwargs.pop("subfolder", None)
- weight_name = kwargs.pop("weight_name", None)
- unet_config = kwargs.pop("unet_config", None)
- use_safetensors = kwargs.pop("use_safetensors", None)
-
- allow_pickle = False
- if use_safetensors is None:
- use_safetensors = True
- allow_pickle = True
-
- user_agent = {
- "file_type": "attn_procs_weights",
- "framework": "pytorch",
- }
-
- state_dict = _fetch_state_dict(
- pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
- weight_name=weight_name,
- use_safetensors=use_safetensors,
- local_files_only=local_files_only,
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- token=token,
- revision=revision,
- subfolder=subfolder,
- user_agent=user_agent,
- allow_pickle=allow_pickle,
- )
- is_dora_scale_present = any("dora_scale" in k for k in state_dict)
- if is_dora_scale_present:
- warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
- logger.warning(warn_msg)
- state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
-
- network_alphas = None
- # TODO: replace it with a method from `state_dict_utils`
- if all(
- (
- k.startswith("lora_te_")
- or k.startswith("lora_unet_")
- or k.startswith("lora_te1_")
- or k.startswith("lora_te2_")
- )
- for k in state_dict.keys()
- ):
- # Map SDXL blocks correctly.
- if unet_config is not None:
- # use unet config to remap block numbers
- state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
- state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)
-
- return state_dict, network_alphas
-
- @classmethod
- def load_lora_into_unet(
- cls,
- state_dict,
- network_alphas,
- unet,
- adapter_name=None,
- _pipeline=None,
- low_cpu_mem_usage=False,
- hotswap: bool = False,
- ):
- """
- This will load the LoRA layers specified in `state_dict` into `unet`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the LoRA layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet`, which can be used to distinguish them from the
- text encoder LoRA layers.
- network_alphas (`Dict[str, float]`):
- The value of the network alpha used for stable learning and preventing underflow. This value has the
- same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
- link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
- unet (`UNet2DConditionModel`):
- The UNet model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
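-
- A rough sketch of calling this classmethod through a pipeline (it assumes the state dict was obtained with
- [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]):
-
- ```py
- state_dict, network_alphas = pipeline.lora_state_dict("some-user/some-sd15-lora")
- # only the `unet.`-prefixed entries are applied here
- pipeline.load_lora_into_unet(
- state_dict, network_alphas=network_alphas, unet=pipeline.unet, adapter_name="style"
- )
- ```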
- """
- if not USE_PEFT_BACKEND:
- raise ValueError("PEFT backend is required for this method.")
-
- if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
- raise ValueError(
- "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
- )
-
- # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
- # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
- # their prefixes.
- logger.info(f"Loading {cls.unet_name}.")
- unet.load_lora_adapter(
- state_dict,
- prefix=cls.unet_name,
- network_alphas=network_alphas,
- adapter_name=adapter_name,
- _pipeline=_pipeline,
- low_cpu_mem_usage=low_cpu_mem_usage,
- hotswap=hotswap,
- )
-
- @classmethod
- def load_lora_into_text_encoder(
- cls,
- state_dict,
- network_alphas,
- text_encoder,
- prefix=None,
- lora_scale=1.0,
- adapter_name=None,
- _pipeline=None,
- low_cpu_mem_usage=False,
- hotswap: bool = False,
- ):
- """
- This will load the LoRA layers specified in `state_dict` into `text_encoder`
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the LoRA layer parameters. The keys should be prefixed with an
- additional `text_encoder` to distinguish them from the UNet LoRA layers.
- network_alphas (`Dict[str, float]`):
- The value of the network alpha used for stable learning and preventing underflow. This value has the
- same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
- link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
- text_encoder (`CLIPTextModel`):
- The text encoder model to load the LoRA layers into.
- prefix (`str`):
- Expected prefix of the `text_encoder` in the `state_dict`.
- lora_scale (`float`):
- How much to scale the output of the LoRA linear layer before it is added to the output of the regular
- layer.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
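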
- """
- _load_lora_into_text_encoder(
- state_dict=state_dict,
- network_alphas=network_alphas,
- lora_scale=lora_scale,
- text_encoder=text_encoder,
- prefix=prefix,
- text_encoder_name=cls.text_encoder_name,
- adapter_name=adapter_name,
- _pipeline=_pipeline,
- low_cpu_mem_usage=low_cpu_mem_usage,
- hotswap=hotswap,
- )
-
- @classmethod
- def save_lora_weights(
- cls,
- save_directory: Union[str, os.PathLike],
- unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
- text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
- is_main_process: bool = True,
- weight_name: str = None,
- save_function: Callable = None,
- safe_serialization: bool = True,
- ):
- r"""
- Save the LoRA parameters corresponding to the UNet and text encoder.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `unet`.
- text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
- encoder LoRA state dict because it comes from 🤗 Transformers.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training when you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
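-
- A minimal sketch (it assumes `unet_lora_state_dict` was produced elsewhere, for example by a LoRA training
- script):
-
- ```py
- from diffusers import StableDiffusionPipeline
-
- # only the components that were actually trained need to be passed
- StableDiffusionPipeline.save_lora_weights(
- save_directory="./my-sd15-lora",
- unet_lora_layers=unet_lora_state_dict,
- weight_name="pytorch_lora_weights.safetensors",
- )
- ```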
- """
- state_dict = {}
-
- if not (unet_lora_layers or text_encoder_lora_layers):
- raise ValueError("You must pass at least one of `unet_lora_layers` and `text_encoder_lora_layers`.")
-
- if unet_lora_layers:
- state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name))
-
- if text_encoder_lora_layers:
- state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
-
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
- save_directory=save_directory,
- is_main_process=is_main_process,
- weight_name=weight_name,
- save_function=save_function,
- safe_serialization=safe_serialization,
- )
-
- def fuse_lora(
- self,
- components: List[str] = ["unet", "text_encoder"],
- lora_scale: float = 1.0,
- safe_fusing: bool = False,
- adapter_names: Optional[List[str]] = None,
- **kwargs,
- ):
- r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
- <Tip warning={true}>
-
- This is an experimental API.
-
- </Tip>
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check the fused weights for NaN values before fusing and to skip fusing if NaN values are present.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
- """
- super().fuse_lora(
- components=components,
- lora_scale=lora_scale,
- safe_fusing=safe_fusing,
- adapter_names=adapter_names,
- **kwargs,
- )
-
- def unfuse_lora(self, components: List[str] = ["unet", "text_encoder"], **kwargs):
- r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
- <Tip warning={true}>
-
- This is an experimental API.
-
- </Tip>
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
- unfuse_text_encoder (`bool`, defaults to `True`):
- Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
- LoRA parameters then it won't have any effect.
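-
- A short sketch pairing this with `fuse_lora` (the LoRA repository id is a placeholder):
-
- ```py
- pipeline.load_lora_weights("some-user/some-sd15-lora", adapter_name="style")
- pipeline.fuse_lora(lora_scale=0.7)
- image = pipeline("a cozy cabin in the woods").images[0]
- # restore the original, unfused weights
- pipeline.unfuse_lora()
- ```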
- """
- super().unfuse_lora(components=components, **kwargs)
-
-
-class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
- r"""
- Load LoRA layers into Stable Diffusion XL [`UNet2DConditionModel`],
- [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), and
- [`CLIPTextModelWithProjection`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection).
- """
-
- _lora_loadable_modules = ["unet", "text_encoder", "text_encoder_2"]
- unet_name = UNET_NAME
- text_encoder_name = TEXT_ENCODER_NAME
-
- def load_lora_weights(
- self,
- pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
- adapter_name: Optional[str] = None,
- hotswap: bool = False,
- **kwargs,
- ):
- """
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
- `self.text_encoder`.
-
- All kwargs are forwarded to `self.lora_state_dict`.
-
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
- loaded.
-
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is
- loaded into `self.unet`.
-
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state
- dict is loaded into `self.text_encoder`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
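-
- A minimal sketch (it reuses the public checkpoint from the `fuse_lora` example; text encoder LoRA layers
- are loaded only if the checkpoint contains them):
-
- ```py
- import torch
- from diffusers import DiffusionPipeline
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- image = pipeline("pixel art, a cute corgi").images[0]
- ```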
- """
- if not USE_PEFT_BACKEND:
- raise ValueError("PEFT backend is required for this method.")
-
- low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
- if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
- raise ValueError(
- "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
- )
-
- # We could have accessed the unet config from `lora_state_dict()` too. We pass
- # it here explicitly to be able to tell that it's coming from an SDXL
- # pipeline.
-
- # if a dict is passed, copy it instead of modifying it inplace
- if isinstance(pretrained_model_name_or_path_or_dict, dict):
- pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
-
- # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
- state_dict, network_alphas = self.lora_state_dict(
- pretrained_model_name_or_path_or_dict,
- unet_config=self.unet.config,
- **kwargs,
- )
-
- is_correct_format = all("lora" in key for key in state_dict.keys())
- if not is_correct_format:
- raise ValueError("Invalid LoRA checkpoint.")
-
- self.load_lora_into_unet(
- state_dict,
- network_alphas=network_alphas,
- unet=self.unet,
- adapter_name=adapter_name,
- _pipeline=self,
- low_cpu_mem_usage=low_cpu_mem_usage,
- hotswap=hotswap,
- )
- self.load_lora_into_text_encoder(
- state_dict,
- network_alphas=network_alphas,
- text_encoder=self.text_encoder,
- prefix=self.text_encoder_name,
- lora_scale=self.lora_scale,
- adapter_name=adapter_name,
- _pipeline=self,
- low_cpu_mem_usage=low_cpu_mem_usage,
- hotswap=hotswap,
- )
- self.load_lora_into_text_encoder(
- state_dict,
- network_alphas=network_alphas,
- text_encoder=self.text_encoder_2,
- prefix=f"{self.text_encoder_name}_2",
- lora_scale=self.lora_scale,
- adapter_name=adapter_name,
- _pipeline=self,
- low_cpu_mem_usage=low_cpu_mem_usage,
- hotswap=hotswap,
- )
-
- @classmethod
- @validate_hf_hub_args
- # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.lora_state_dict
- def lora_state_dict(
- cls,
- pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
- **kwargs,
- ):
- r"""
- Return state dict for lora weights and the network alphas.
-
- <Tip warning={true}>
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
- </Tip>
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
- weight_name (`str`, *optional*, defaults to None):
- Name of the serialized state dict file.
- """
- # Load the main state dict first which has the LoRA layers for either of
- # UNet and text encoder or both.
- cache_dir = kwargs.pop("cache_dir", None)
- force_download = kwargs.pop("force_download", False)
- proxies = kwargs.pop("proxies", None)
- local_files_only = kwargs.pop("local_files_only", None)
- token = kwargs.pop("token", None)
- revision = kwargs.pop("revision", None)
- subfolder = kwargs.pop("subfolder", None)
- weight_name = kwargs.pop("weight_name", None)
- unet_config = kwargs.pop("unet_config", None)
- use_safetensors = kwargs.pop("use_safetensors", None)
-
- allow_pickle = False
- if use_safetensors is None:
- use_safetensors = True
- allow_pickle = True
-
- user_agent = {
- "file_type": "attn_procs_weights",
- "framework": "pytorch",
- }
-
- state_dict = _fetch_state_dict(
- pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
- weight_name=weight_name,
- use_safetensors=use_safetensors,
- local_files_only=local_files_only,
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- token=token,
- revision=revision,
- subfolder=subfolder,
- user_agent=user_agent,
- allow_pickle=allow_pickle,
- )
- is_dora_scale_present = any("dora_scale" in k for k in state_dict)
- if is_dora_scale_present:
- warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
- logger.warning(warn_msg)
- state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
-
- network_alphas = None
- # TODO: replace it with a method from `state_dict_utils`
- if all(
- (
- k.startswith("lora_te_")
- or k.startswith("lora_unet_")
- or k.startswith("lora_te1_")
- or k.startswith("lora_te2_")
- )
- for k in state_dict.keys()
- ):
- # Map SDXL blocks correctly.
- if unet_config is not None:
- # use unet config to remap block numbers
- state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
- state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)
-
- return state_dict, network_alphas
-
- @classmethod
- # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_unet
- def load_lora_into_unet(
- cls,
- state_dict,
- network_alphas,
- unet,
- adapter_name=None,
- _pipeline=None,
- low_cpu_mem_usage=False,
- hotswap: bool = False,
- ):
- """
- This will load the LoRA layers specified in `state_dict` into `unet`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the LoRA layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet`, which can be used to distinguish them from the
- text encoder LoRA layers.
- network_alphas (`Dict[str, float]`):
- The value of the network alpha used for stable learning and preventing underflow. This value has the
- same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
- link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
- unet (`UNet2DConditionModel`):
- The UNet model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- """
- if not USE_PEFT_BACKEND:
- raise ValueError("PEFT backend is required for this method.")
-
- if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
- raise ValueError(
- "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
- )
-
- # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
- # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
- # their prefixes.
- logger.info(f"Loading {cls.unet_name}.")
- unet.load_lora_adapter(
- state_dict,
- prefix=cls.unet_name,
- network_alphas=network_alphas,
- adapter_name=adapter_name,
- _pipeline=_pipeline,
- low_cpu_mem_usage=low_cpu_mem_usage,
- hotswap=hotswap,
- )
-
- @classmethod
- # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
- def load_lora_into_text_encoder(
- cls,
- state_dict,
- network_alphas,
- text_encoder,
- prefix=None,
- lora_scale=1.0,
- adapter_name=None,
- _pipeline=None,
- low_cpu_mem_usage=False,
- hotswap: bool = False,
- ):
- """
- This will load the LoRA layers specified in `state_dict` into `text_encoder`
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the LoRA layer parameters. The keys should be prefixed with an
- additional `text_encoder` to distinguish them from the UNet LoRA layers.
- network_alphas (`Dict[str, float]`):
- The value of the network alpha used for stable learning and preventing underflow. This value has the
- same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
- link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
- text_encoder (`CLIPTextModel`):
- The text encoder model to load the LoRA layers into.
- prefix (`str`):
- Expected prefix of the `text_encoder` in the `state_dict`.
- lora_scale (`float`):
- How much to scale the output of the LoRA linear layer before it is added to the output of the regular
- layer.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- """
- _load_lora_into_text_encoder(
- state_dict=state_dict,
- network_alphas=network_alphas,
- lora_scale=lora_scale,
- text_encoder=text_encoder,
- prefix=prefix,
- text_encoder_name=cls.text_encoder_name,
- adapter_name=adapter_name,
- _pipeline=_pipeline,
- low_cpu_mem_usage=low_cpu_mem_usage,
- hotswap=hotswap,
- )
-
- @classmethod
- def save_lora_weights(
- cls,
- save_directory: Union[str, os.PathLike],
- unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
- text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
- text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
- is_main_process: bool = True,
- weight_name: str = None,
- save_function: Callable = None,
- safe_serialization: bool = True,
- ):
- r"""
- Save the LoRA parameters corresponding to the UNet and text encoder.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `unet`.
- text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
- encoder LoRA state dict because it comes from 🤗 Transformers.
- text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text
- encoder LoRA state dict because it comes from 🤗 Transformers.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training when you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
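-
- A minimal sketch (the three state dicts are assumed to come from a training loop; pass only the ones that
- were trained):
-
- ```py
- from diffusers import StableDiffusionXLPipeline
-
- StableDiffusionXLPipeline.save_lora_weights(
- save_directory="./my-sdxl-lora",
- unet_lora_layers=unet_lora_state_dict,
- text_encoder_lora_layers=text_encoder_one_lora_state_dict,
- text_encoder_2_lora_layers=text_encoder_two_lora_state_dict,
- )
- ```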
- """
- state_dict = {}
-
- if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
- raise ValueError(
- "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
- )
-
- if unet_lora_layers:
- state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name))
-
- if text_encoder_lora_layers:
- state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
-
- if text_encoder_2_lora_layers:
- state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
-
- cls.write_lora_layers(
- state_dict=state_dict,
- save_directory=save_directory,
- is_main_process=is_main_process,
- weight_name=weight_name,
- save_function=save_function,
- safe_serialization=safe_serialization,
- )
-
- def fuse_lora(
- self,
- components: List[str] = ["unet", "text_encoder", "text_encoder_2"],
- lora_scale: float = 1.0,
- safe_fusing: bool = False,
- adapter_names: Optional[List[str]] = None,
- **kwargs,
- ):
- r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
- <Tip warning={true}>
-
- This is an experimental API.
-
- </Tip>
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check the fused weights for NaN values before fusing and to skip fusing if NaN values are present.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
- """
- super().fuse_lora(
- components=components,
- lora_scale=lora_scale,
- safe_fusing=safe_fusing,
- adapter_names=adapter_names,
- **kwargs,
- )
-
- def unfuse_lora(self, components: List[str] = ["unet", "text_encoder", "text_encoder_2"], **kwargs):
- r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
- <Tip warning={true}>
-
- This is an experimental API.
-
- </Tip>
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
- unfuse_text_encoder (`bool`, defaults to `True`):
- Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
- LoRA parameters then it won't have any effect.
- """
- super().unfuse_lora(components=components, **kwargs)
-
-
-class SD3LoraLoaderMixin(LoraBaseMixin):
- r"""
- Load LoRA layers into [`SD3Transformer2DModel`],
- [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), and
- [`CLIPTextModelWithProjection`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection).
-
- Specific to [`StableDiffusion3Pipeline`].
- """
-
- _lora_loadable_modules = ["transformer", "text_encoder", "text_encoder_2"]
- transformer_name = TRANSFORMER_NAME
- text_encoder_name = TEXT_ENCODER_NAME
-
- @classmethod
- @validate_hf_hub_args
- def lora_state_dict(
- cls,
- pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
- **kwargs,
- ):
- r"""
- Return state dict for lora weights.
-
- <Tip warning={true}>
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
- </Tip>
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
-
- """
- # Load the main state dict first which has the LoRA layers for either of
- # transformer and text encoder or both.
- cache_dir = kwargs.pop("cache_dir", None)
- force_download = kwargs.pop("force_download", False)
- proxies = kwargs.pop("proxies", None)
- local_files_only = kwargs.pop("local_files_only", None)
- token = kwargs.pop("token", None)
- revision = kwargs.pop("revision", None)
- subfolder = kwargs.pop("subfolder", None)
- weight_name = kwargs.pop("weight_name", None)
- use_safetensors = kwargs.pop("use_safetensors", None)
-
- allow_pickle = False
- if use_safetensors is None:
- use_safetensors = True
- allow_pickle = True
-
- user_agent = {
- "file_type": "attn_procs_weights",
- "framework": "pytorch",
- }
-
- state_dict = _fetch_state_dict(
- pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
- weight_name=weight_name,
- use_safetensors=use_safetensors,
- local_files_only=local_files_only,
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- token=token,
- revision=revision,
- subfolder=subfolder,
- user_agent=user_agent,
- allow_pickle=allow_pickle,
- )
-
- is_dora_scale_present = any("dora_scale" in k for k in state_dict)
- if is_dora_scale_present:
- warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
- logger.warning(warn_msg)
- state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
-
- return state_dict
-
- def load_lora_weights(
- self,
- pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
- adapter_name=None,
- hotswap: bool = False,
- **kwargs,
- ):
- """
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
- `self.text_encoder`.
-
- All kwargs are forwarded to `self.lora_state_dict`.
-
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
- loaded.
-
- See [`~loaders.SD3LoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
- dict is loaded into `self.transformer`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
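-
- A minimal sketch (the LoRA repository id and weight file name are placeholders):
-
- ```py
- import torch
- from diffusers import StableDiffusion3Pipeline
-
- pipeline = StableDiffusion3Pipeline.from_pretrained(
- "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
- ).to("cuda")
- # transformer LoRA layers are always loaded; text encoder layers only if present in the checkpoint
- pipeline.load_lora_weights("some-user/some-sd3-lora", weight_name="lora.safetensors", adapter_name="style")
- ```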
- """
- if not USE_PEFT_BACKEND:
- raise ValueError("PEFT backend is required for this method.")
-
- low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
- if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
- raise ValueError(
- "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
- )
-
- # if a dict is passed, copy it instead of modifying it inplace
- if isinstance(pretrained_model_name_or_path_or_dict, dict):
- pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
-
- # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
- state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
-
- is_correct_format = all("lora" in key for key in state_dict.keys())
- if not is_correct_format:
- raise ValueError("Invalid LoRA checkpoint.")
-
- self.load_lora_into_transformer(
- state_dict,
- transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
- adapter_name=adapter_name,
- _pipeline=self,
- low_cpu_mem_usage=low_cpu_mem_usage,
- hotswap=hotswap,
- )
- self.load_lora_into_text_encoder(
- state_dict,
- network_alphas=None,
- text_encoder=self.text_encoder,
- prefix=self.text_encoder_name,
- lora_scale=self.lora_scale,
- adapter_name=adapter_name,
- _pipeline=self,
- low_cpu_mem_usage=low_cpu_mem_usage,
- hotswap=hotswap,
- )
- self.load_lora_into_text_encoder(
- state_dict,
- network_alphas=None,
- text_encoder=self.text_encoder_2,
- prefix=f"{self.text_encoder_name}_2",
- lora_scale=self.lora_scale,
- adapter_name=adapter_name,
- _pipeline=self,
- low_cpu_mem_usage=low_cpu_mem_usage,
- hotswap=hotswap,
- )
-
- @classmethod
- def load_lora_into_transformer(
- cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
- ):
- """
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the LoRA layer parameters. The keys can either be indexed directly
- into the transformer or prefixed with an additional `transformer`, which can be used to distinguish
- them from the text encoder LoRA layers.
- transformer (`SD3Transformer2DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
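-
- A rough sketch of loading a transformer-only state dict directly (it assumes the keys carry the
- `transformer.` prefix):
-
- ```py
- state_dict = pipeline.lora_state_dict("some-user/some-sd3-lora")
- pipeline.load_lora_into_transformer(
- state_dict, transformer=pipeline.transformer, adapter_name="style"
- )
- ```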
- """
- if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
- raise ValueError(
- "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
- )
-
- # Load the layers corresponding to transformer.
- logger.info(f"Loading {cls.transformer_name}.")
- transformer.load_lora_adapter(
- state_dict,
- network_alphas=None,
- adapter_name=adapter_name,
- _pipeline=_pipeline,
- low_cpu_mem_usage=low_cpu_mem_usage,
- hotswap=hotswap,
- )
-
- @classmethod
- # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
- def load_lora_into_text_encoder(
- cls,
- state_dict,
- network_alphas,
- text_encoder,
- prefix=None,
- lora_scale=1.0,
- adapter_name=None,
- _pipeline=None,
- low_cpu_mem_usage=False,
- hotswap: bool = False,
- ):
- """
- This will load the LoRA layers specified in `state_dict` into `text_encoder`
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the LoRA layer parameters. The keys should be prefixed with an
- additional `text_encoder` to distinguish them from the transformer LoRA layers.
- network_alphas (`Dict[str, float]`):
- The value of the network alpha used for stable learning and preventing underflow. This value has the
- same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
- link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
- text_encoder (`CLIPTextModel`):
- The text encoder model to load the LoRA layers into.
- prefix (`str`):
- Expected prefix of the `text_encoder` in the `state_dict`.
- lora_scale (`float`):
- How much to scale the output of the LoRA linear layer before it is added to the output of the regular
- layer.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- """
- _load_lora_into_text_encoder(
- state_dict=state_dict,
- network_alphas=network_alphas,
- lora_scale=lora_scale,
- text_encoder=text_encoder,
- prefix=prefix,
- text_encoder_name=cls.text_encoder_name,
- adapter_name=adapter_name,
- _pipeline=_pipeline,
- low_cpu_mem_usage=low_cpu_mem_usage,
- hotswap=hotswap,
- )
-
- @classmethod
- # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.save_lora_weights with unet->transformer
- def save_lora_weights(
- cls,
- save_directory: Union[str, os.PathLike],
- transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
- text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
- text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
- is_main_process: bool = True,
- weight_name: str = None,
- save_function: Callable = None,
- safe_serialization: bool = True,
- ):
- r"""
- Save the LoRA parameters corresponding to the transformer and text encoders.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `transformer`.
- text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
- encoder LoRA state dict because it comes from 🤗 Transformers.
- text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text
- encoder LoRA state dict because it comes from 🤗 Transformers.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training when you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
- """
- state_dict = {}
-
- if not (transformer_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
- raise ValueError(
- "You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
- )
-
- if transformer_lora_layers:
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
-
- if text_encoder_lora_layers:
- state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
-
- if text_encoder_2_lora_layers:
- state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
-
- cls.write_lora_layers(
- state_dict=state_dict,
- save_directory=save_directory,
- is_main_process=is_main_process,
- weight_name=weight_name,
- save_function=save_function,
- safe_serialization=safe_serialization,
- )
-
- # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer
- def fuse_lora(
- self,
- components: List[str] = ["transformer", "text_encoder", "text_encoder_2"],
- lora_scale: float = 1.0,
- safe_fusing: bool = False,
- adapter_names: Optional[List[str]] = None,
- **kwargs,
- ):
- r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
- <Tip warning={true}>
-
- This is an experimental API.
-
- </Tip>
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check the fused weights for NaN values before fusing and to skip fusing if NaN values are present.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
- """
- super().fuse_lora(
- components=components,
- lora_scale=lora_scale,
- safe_fusing=safe_fusing,
- adapter_names=adapter_names,
- **kwargs,
- )
-
- # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.unfuse_lora with unet->transformer
- def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], **kwargs):
- r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
- <Tip warning={true}>
-
- This is an experimental API.
-
- </Tip>
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
- unfuse_text_encoder (`bool`, defaults to `True`):
- Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
- LoRA parameters then it won't have any effect.
- """
- super().unfuse_lora(components=components, **kwargs)
-
-
-class AuraFlowLoraLoaderMixin(LoraBaseMixin):
- r"""
- Load LoRA layers into [`AuraFlowTransformer2DModel`]. Specific to [`AuraFlowPipeline`].
- """
-
- _lora_loadable_modules = ["transformer"]
- transformer_name = TRANSFORMER_NAME
-
- @classmethod
- @validate_hf_hub_args
- # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
- def lora_state_dict(
- cls,
- pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
- **kwargs,
- ):
- r"""
- Return state dict for lora weights.
-
- <Tip warning={true}>
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
- </Tip>
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
-
- """
- # Load the main state dict first which has the LoRA layers for either of
- # transformer and text encoder or both.
- cache_dir = kwargs.pop("cache_dir", None)
- force_download = kwargs.pop("force_download", False)
- proxies = kwargs.pop("proxies", None)
- local_files_only = kwargs.pop("local_files_only", None)
- token = kwargs.pop("token", None)
- revision = kwargs.pop("revision", None)
- subfolder = kwargs.pop("subfolder", None)
- weight_name = kwargs.pop("weight_name", None)
- use_safetensors = kwargs.pop("use_safetensors", None)
-
- allow_pickle = False
- if use_safetensors is None:
- use_safetensors = True
- allow_pickle = True
-
- user_agent = {
- "file_type": "attn_procs_weights",
- "framework": "pytorch",
- }
-
- state_dict = _fetch_state_dict(
- pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
- weight_name=weight_name,
- use_safetensors=use_safetensors,
- local_files_only=local_files_only,
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- token=token,
- revision=revision,
- subfolder=subfolder,
- user_agent=user_agent,
- allow_pickle=allow_pickle,
- )
-
- is_dora_scale_present = any("dora_scale" in k for k in state_dict)
- if is_dora_scale_present:
- warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
- logger.warning(warn_msg)
- state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
-
- return state_dict
-
- # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
- def load_lora_weights(
- self,
- pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
- adapter_name: Optional[str] = None,
- hotswap: bool = False,
- **kwargs,
- ):
- """
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`. All
- kwargs are forwarded to `self.lora_state_dict`. See
- [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
- See [`~loaders.AuraFlowLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
- dict is loaded into `self.transformer`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
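-
- A minimal sketch (the LoRA repository id and weight file name are placeholders):
-
- ```py
- import torch
- from diffusers import AuraFlowPipeline
-
- pipeline = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")
- # only the transformer is LoRA-loadable for this pipeline
- pipeline.load_lora_weights("some-user/some-auraflow-lora", weight_name="lora.safetensors", adapter_name="style")
- ```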
- """
- if not USE_PEFT_BACKEND:
- raise ValueError("PEFT backend is required for this method.")
-
- low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
- if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
- raise ValueError(
- "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
- )
-
- # if a dict is passed, copy it instead of modifying it inplace
- if isinstance(pretrained_model_name_or_path_or_dict, dict):
- pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
-
- # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
- state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
-
- is_correct_format = all("lora" in key for key in state_dict.keys())
- if not is_correct_format:
- raise ValueError("Invalid LoRA checkpoint.")
-
- self.load_lora_into_transformer(
- state_dict,
- transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
- adapter_name=adapter_name,
- _pipeline=self,
- low_cpu_mem_usage=low_cpu_mem_usage,
- hotswap=hotswap,
- )
-
- @classmethod
- # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->AuraFlowTransformer2DModel
- def load_lora_into_transformer(
- cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
- ):
- """
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the LoRA layer parameters. The keys can either be indexed directly
- into the transformer or prefixed with an additional `transformer`, which can be used to distinguish
- them from the text encoder LoRA layers.
- transformer (`AuraFlowTransformer2DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- """
- if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
- raise ValueError(
- "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
- )
-
- # Load the layers corresponding to transformer.
- logger.info(f"Loading {cls.transformer_name}.")
- transformer.load_lora_adapter(
- state_dict,
- network_alphas=None,
- adapter_name=adapter_name,
- _pipeline=_pipeline,
- low_cpu_mem_usage=low_cpu_mem_usage,
- hotswap=hotswap,
- )
-
- @classmethod
- # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
- def save_lora_weights(
- cls,
- save_directory: Union[str, os.PathLike],
- transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
- is_main_process: bool = True,
- weight_name: str = None,
- save_function: Callable = None,
- safe_serialization: bool = True,
- ):
- r"""
- Save the LoRA parameters corresponding to the UNet and text encoder.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `transformer`.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training and you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
- """
- state_dict = {}
-
- if not transformer_lora_layers:
- raise ValueError("You must pass `transformer_lora_layers`.")
-
- if transformer_lora_layers:
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
-
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
- save_directory=save_directory,
- is_main_process=is_main_process,
- weight_name=weight_name,
- save_function=save_function,
- safe_serialization=safe_serialization,
- )
-
- # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
- def fuse_lora(
- self,
- components: List[str] = ["transformer"],
- lora_scale: float = 1.0,
- safe_fusing: bool = False,
- adapter_names: Optional[List[str]] = None,
- **kwargs,
- ):
- r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-        <Tip warning={true}>
-
-        This is an experimental API.
-
-        </Tip>
-
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
- """
- super().fuse_lora(
- components=components,
- lora_scale=lora_scale,
- safe_fusing=safe_fusing,
- adapter_names=adapter_names,
- **kwargs,
- )
-
- # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
- def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
- r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-        <Tip warning={true}>
-
-        This is an experimental API.
-
-        </Tip>
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
- """
- super().unfuse_lora(components=components, **kwargs)
+class StableDiffusionLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = "Importing `StableDiffusionLoraLoaderMixin` from diffusers.loaders.lora_pipeline has been deprecated. Please use `from diffusers.loaders.lora.lora_pipeline import StableDiffusionLoraLoaderMixin` instead."
+ deprecate("diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin", "0.36", deprecation_message)
+ super().__init__(*args, **kwargs)
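The shim pattern used for these re-exports (a class that subclasses the relocated class of the same name and warns on construction) can be shown in isolation. A minimal sketch with invented names, using Python's `warnings` module in place of `deprecate`:

```py
import warnings


class ExampleLoraLoaderMixin:  # stands in for the class imported from the new location
    def __init__(self, *args, **kwargs):
        self.initialized = True


class ExampleLoraLoaderMixin(ExampleLoraLoaderMixin):  # old-module shim; same name shadows the import
    def __init__(self, *args, **kwargs):
        warnings.warn(
            "Importing `ExampleLoraLoaderMixin` from the old module is deprecated; "
            "import it from the new location instead.",
            FutureWarning,
        )
        super().__init__(*args, **kwargs)


obj = ExampleLoraLoaderMixin()  # warns once, then behaves exactly like the relocated class
assert obj.initialized
```

Existing `from diffusers.loaders.lora_pipeline import ...` statements therefore keep working, with a deprecation warning, until the 0.36 removal noted in the messages above.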
-class FluxLoraLoaderMixin(LoraBaseMixin):
- r"""
- Load LoRA layers into [`FluxTransformer2DModel`],
- [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
-    Specific to [`FluxPipeline`].
- """
+class StableDiffusionXLLoraLoaderMixin(StableDiffusionXLLoraLoaderMixin):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = "Importing `StableDiffusionXLLoraLoaderMixin` from diffusers.loaders.lora_pipeline has been deprecated. Please use `from diffusers.loaders.lora.lora_pipeline import StableDiffusionXLLoraLoaderMixin` instead."
+ deprecate("diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin", "0.36", deprecation_message)
+ super().__init__(*args, **kwargs)
- _lora_loadable_modules = ["transformer", "text_encoder"]
- transformer_name = TRANSFORMER_NAME
- text_encoder_name = TEXT_ENCODER_NAME
- _control_lora_supported_norm_keys = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"]
- @classmethod
- @validate_hf_hub_args
- def lora_state_dict(
- cls,
- pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
- return_alphas: bool = False,
- **kwargs,
- ):
- r"""
- Return state dict for lora weights and the network alphas.
+class SD3LoraLoaderMixin(SD3LoraLoaderMixin):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = "Importing `SD3LoraLoaderMixin` from diffusers.loaders.lora_pipeline has been deprecated. Please use `from diffusers.loaders.lora.lora_pipeline import SD3LoraLoaderMixin` instead."
+ deprecate("diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin", "0.36", deprecation_message)
+ super().__init__(*args, **kwargs)
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+class AuraFlowLoraLoaderMixin(AuraFlowLoraLoaderMixin):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = "Importing `AuraFlowLoraLoaderMixin` from diffusers.loaders.lora_pipeline has been deprecated. Please use `from diffusers.loaders.lora.lora_pipeline import AuraFlowLoraLoaderMixin` instead."
+ deprecate("diffusers.loaders.lora_pipeline.AuraFlowLoraLoaderMixin", "0.36", deprecation_message)
+ super().__init__(*args, **kwargs)
- This function is experimental and might change in the future.
-
+class FluxLoraLoaderMixin(FluxLoraLoaderMixin):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = "Importing `FluxLoraLoaderMixin` from diffusers.loaders.lora_pipeline has been deprecated. Please use `from diffusers.loaders.lora.lora_pipeline import FluxLoraLoaderMixin` instead."
+ deprecate("diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin", "0.36", deprecation_message)
+ super().__init__(*args, **kwargs)
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+class AmusedLoraLoaderMixin(AmusedLoraLoaderMixin):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = "Importing `AmusedLoraLoaderMixin` from diffusers.loaders.lora_pipeline has been deprecated. Please use `from diffusers.loaders.lora.lora_pipeline import AmusedLoraLoaderMixin` instead."
+ deprecate("diffusers.loaders.lora_pipeline.AmusedLoraLoaderMixin", "0.36", deprecation_message)
+ super().__init__(*args, **kwargs)
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
+class CogVideoXLoraLoaderMixin(CogVideoXLoraLoaderMixin):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = "Importing `CogVideoXLoraLoaderMixin` from diffusers.loaders.lora_pipeline has been deprecated. Please use `from diffusers.loaders.lora.lora_pipeline import CogVideoXLoraLoaderMixin` instead."
+ deprecate("diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin", "0.36", deprecation_message)
+ super().__init__(*args, **kwargs)
- """
- # Load the main state dict first which has the LoRA layers for either of
- # transformer and text encoder or both.
- cache_dir = kwargs.pop("cache_dir", None)
- force_download = kwargs.pop("force_download", False)
- proxies = kwargs.pop("proxies", None)
- local_files_only = kwargs.pop("local_files_only", None)
- token = kwargs.pop("token", None)
- revision = kwargs.pop("revision", None)
- subfolder = kwargs.pop("subfolder", None)
- weight_name = kwargs.pop("weight_name", None)
- use_safetensors = kwargs.pop("use_safetensors", None)
- allow_pickle = False
- if use_safetensors is None:
- use_safetensors = True
- allow_pickle = True
+class Mochi1LoraLoaderMixin(Mochi1LoraLoaderMixin):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = "Importing `Mochi1LoraLoaderMixin` from diffusers.loaders.lora_pipeline has been deprecated. Please use `from diffusers.loaders.lora.lora_pipeline import Mochi1LoraLoaderMixin` instead."
+ deprecate("diffusers.loaders.lora_pipeline.Mochi1LoraLoaderMixin", "0.36", deprecation_message)
+ super().__init__(*args, **kwargs)
- user_agent = {
- "file_type": "attn_procs_weights",
- "framework": "pytorch",
- }
- state_dict = _fetch_state_dict(
- pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
- weight_name=weight_name,
- use_safetensors=use_safetensors,
- local_files_only=local_files_only,
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- token=token,
- revision=revision,
- subfolder=subfolder,
- user_agent=user_agent,
- allow_pickle=allow_pickle,
- )
- is_dora_scale_present = any("dora_scale" in k for k in state_dict)
- if is_dora_scale_present:
-            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible with Diffusers at the moment. So, we are going to filter out the keys associated with `dora_scale` from the state dict. If you think this is a mistake, please open an issue at https://github.com/huggingface/diffusers/issues/new."
- logger.warning(warn_msg)
- state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+class LTXVideoLoraLoaderMixin(LTXVideoLoraLoaderMixin):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = "Importing `LTXVideoLoraLoaderMixin` from diffusers.loaders.lora_pipeline has been deprecated. Please use `from diffusers.loaders.lora.lora_pipeline import LTXVideoLoraLoaderMixin` instead."
+ deprecate("diffusers.loaders.lora_pipeline.LTXVideoLoraLoaderMixin", "0.36", deprecation_message)
+ super().__init__(*args, **kwargs)
-        # TODO (sayakpaul): do a follow-up to clean up and try to unify the conditions.
- is_kohya = any(".lora_down.weight" in k for k in state_dict)
- if is_kohya:
- state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict)
- # Kohya already takes care of scaling the LoRA parameters with alpha.
- return (state_dict, None) if return_alphas else state_dict
- is_xlabs = any("processor" in k for k in state_dict)
- if is_xlabs:
- state_dict = _convert_xlabs_flux_lora_to_diffusers(state_dict)
- # xlabs doesn't use `alpha`.
- return (state_dict, None) if return_alphas else state_dict
+class SanaLoraLoaderMixin(SanaLoraLoaderMixin):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = "Importing `SanaLoraLoaderMixin` from diffusers.loaders.lora_pipeline has been deprecated. Please use `from diffusers.loaders.lora.lora_pipeline import SanaLoraLoaderMixin` instead."
+ deprecate("diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin", "0.36", deprecation_message)
+ super().__init__(*args, **kwargs)
- is_bfl_control = any("query_norm.scale" in k for k in state_dict)
- if is_bfl_control:
- state_dict = _convert_bfl_flux_control_lora_to_diffusers(state_dict)
- return (state_dict, None) if return_alphas else state_dict
- # For state dicts like
- # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
- keys = list(state_dict.keys())
- network_alphas = {}
- for k in keys:
- if "alpha" in k:
- alpha_value = state_dict.get(k)
- if (torch.is_tensor(alpha_value) and torch.is_floating_point(alpha_value)) or isinstance(
- alpha_value, float
- ):
- network_alphas[k] = state_dict.pop(k)
- else:
- raise ValueError(
-                        f"The alpha key ({k}) seems to be incorrect. If you think this error is unexpected, please open an issue."
- )
+class HunyuanVideoLoraLoaderMixin(HunyuanVideoLoraLoaderMixin):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = "Importing `HunyuanVideoLoraLoaderMixin` from diffusers.loaders.lora_pipeline has been deprecated. Please use `from diffusers.loaders.lora.lora_pipeline import HunyuanVideoLoraLoaderMixin` instead."
+ deprecate("diffusers.loaders.lora_pipeline.HunyuanVideoLoraLoaderMixin", "0.36", deprecation_message)
+ super().__init__(*args, **kwargs)
- if return_alphas:
- return state_dict, network_alphas
- else:
- return state_dict
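The alpha handling just above reduces to popping any `alpha` keys out of the state dict before loading, so they can be passed on as `network_alphas`. A toy illustration (key names and shapes invented):

```py
import torch

state_dict = {
    "transformer.single_blocks.0.proj.lora_A.weight": torch.zeros(4, 64),
    "transformer.single_blocks.0.proj.lora_B.weight": torch.zeros(64, 4),
    "transformer.single_blocks.0.proj.alpha": torch.tensor(8.0),
}

# Same loop as in `lora_state_dict`: float-valued `alpha` entries are split off.
network_alphas = {}
for k in list(state_dict):
    if "alpha" in k:
        network_alphas[k] = state_dict.pop(k)

print(sorted(network_alphas))  # ['transformer.single_blocks.0.proj.alpha']
print(len(state_dict))         # 2
```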
- def load_lora_weights(
- self,
- pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
- adapter_name: Optional[str] = None,
- hotswap: bool = False,
- **kwargs,
- ):
- """
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
- `self.text_encoder`.
+class Lumina2LoraLoaderMixin(Lumina2LoraLoaderMixin):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = "Importing `Lumina2LoraLoaderMixin` from diffusers.loaders.lora_pipeline has been deprecated. Please use `from diffusers.loaders.lora.lora_pipeline import Lumina2LoraLoaderMixin` instead."
+ deprecate("diffusers.loaders.lora_pipeline.Lumina2LoraLoaderMixin", "0.36", deprecation_message)
+ super().__init__(*args, **kwargs)
- All kwargs are forwarded to `self.lora_state_dict`.
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
- loaded.
+class WanLoraLoaderMixin(WanLoraLoaderMixin):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = "Importing `WanLoraLoaderMixin` from diffusers.loaders.lora_pipeline has been deprecated. Please use `from diffusers.loaders.lora.lora_pipeline import WanLoraLoaderMixin` instead."
+ deprecate("diffusers.loaders.lora_pipeline.WanLoraLoaderMixin", "0.36", deprecation_message)
+ super().__init__(*args, **kwargs)
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
- dict is loaded into `self.transformer`.
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
-                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- """
- if not USE_PEFT_BACKEND:
- raise ValueError("PEFT backend is required for this method.")
+class CogView4LoraLoaderMixin(CogView4LoraLoaderMixin):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = "Importing `CogView4LoraLoaderMixin` from diffusers.loaders.lora_pipeline has been deprecated. Please use `from diffusers.loaders.lora.lora_pipeline import CogView4LoraLoaderMixin` instead."
+ deprecate("diffusers.loaders.lora_pipeline.CogView4LoraLoaderMixin", "0.36", deprecation_message)
+ super().__init__(*args, **kwargs)
- low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
- if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
- raise ValueError(
- "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
- )
- # if a dict is passed, copy it instead of modifying it inplace
- if isinstance(pretrained_model_name_or_path_or_dict, dict):
- pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
-
- # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
- state_dict, network_alphas = self.lora_state_dict(
- pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs
- )
-
- has_lora_keys = any("lora" in key for key in state_dict.keys())
-
- # Flux Control LoRAs also have norm keys
- has_norm_keys = any(
- norm_key in key for key in state_dict.keys() for norm_key in self._control_lora_supported_norm_keys
- )
-
- if not (has_lora_keys or has_norm_keys):
- raise ValueError("Invalid LoRA checkpoint.")
-
- transformer_lora_state_dict = {
- k: state_dict.get(k)
- for k in list(state_dict.keys())
- if k.startswith(f"{self.transformer_name}.") and "lora" in k
- }
- transformer_norm_state_dict = {
- k: state_dict.pop(k)
- for k in list(state_dict.keys())
- if k.startswith(f"{self.transformer_name}.")
- and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys)
- }
-
- transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
- has_param_with_expanded_shape = False
- if len(transformer_lora_state_dict) > 0:
- has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_(
- transformer, transformer_lora_state_dict, transformer_norm_state_dict
- )
-
- if has_param_with_expanded_shape:
- logger.info(
-                "The LoRA weights contain parameters that have different shapes than expected by the transformer. "
- "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. "
- "To get a comprehensive list of parameter names that were modified, enable debug logging."
- )
- if len(transformer_lora_state_dict) > 0:
- transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
- transformer=transformer, lora_state_dict=transformer_lora_state_dict
- )
- for k in transformer_lora_state_dict:
- state_dict.update({k: transformer_lora_state_dict[k]})
-
- self.load_lora_into_transformer(
- state_dict,
- network_alphas=network_alphas,
- transformer=transformer,
- adapter_name=adapter_name,
- _pipeline=self,
- low_cpu_mem_usage=low_cpu_mem_usage,
- hotswap=hotswap,
- )
-
- if len(transformer_norm_state_dict) > 0:
- transformer._transformer_norm_layers = self._load_norm_into_transformer(
- transformer_norm_state_dict,
- transformer=transformer,
- discard_original_layers=False,
- )
-
- self.load_lora_into_text_encoder(
- state_dict,
- network_alphas=network_alphas,
- text_encoder=self.text_encoder,
- prefix=self.text_encoder_name,
- lora_scale=self.lora_scale,
- adapter_name=adapter_name,
- _pipeline=self,
- low_cpu_mem_usage=low_cpu_mem_usage,
- hotswap=hotswap,
- )
-
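For context, everything above is what runs under a single pipeline-level call. A hedged usage sketch (the LoRA repository id is a placeholder; any Flux LoRA in the diffusers/PEFT format would do):

```py
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Goes through lora_state_dict() -> load_lora_into_transformer()/load_lora_into_text_encoder().
pipe.load_lora_weights("your-username/your-flux-lora", adapter_name="style")  # placeholder repo id

image = pipe("a corgi wearing sunglasses", guidance_scale=3.5).images[0]
image.save("flux_lora.png")
```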
- @classmethod
- def load_lora_into_transformer(
- cls,
- state_dict,
- network_alphas,
- transformer,
- adapter_name=None,
- _pipeline=None,
- low_cpu_mem_usage=False,
- hotswap: bool = False,
- ):
- """
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
- encoder lora layers.
- network_alphas (`Dict[str, float]`):
- The value of the network alpha used for stable learning and preventing underflow. This value has the
- same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
- link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
- transformer (`FluxTransformer2DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- """
- if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
- raise ValueError(
- "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
- )
-
- # Load the layers corresponding to transformer.
- logger.info(f"Loading {cls.transformer_name}.")
- transformer.load_lora_adapter(
- state_dict,
- network_alphas=network_alphas,
- adapter_name=adapter_name,
- _pipeline=_pipeline,
- low_cpu_mem_usage=low_cpu_mem_usage,
- hotswap=hotswap,
- )
-
- @classmethod
- def _load_norm_into_transformer(
- cls,
- state_dict,
- transformer,
- prefix=None,
- discard_original_layers=False,
- ) -> Dict[str, torch.Tensor]:
- # Remove prefix if present
- prefix = prefix or cls.transformer_name
- for key in list(state_dict.keys()):
- if key.split(".")[0] == prefix:
- state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key)
-
- # Find invalid keys
- transformer_state_dict = transformer.state_dict()
- transformer_keys = set(transformer_state_dict.keys())
- state_dict_keys = set(state_dict.keys())
- extra_keys = list(state_dict_keys - transformer_keys)
-
- if extra_keys:
- logger.warning(
- f"Unsupported keys found in state dict when trying to load normalization layers into the transformer. The following keys will be ignored:\n{extra_keys}."
- )
-
- for key in extra_keys:
- state_dict.pop(key)
-
- # Save the layers that are going to be overwritten so that unload_lora_weights can work as expected
- overwritten_layers_state_dict = {}
- if not discard_original_layers:
- for key in state_dict.keys():
- overwritten_layers_state_dict[key] = transformer_state_dict[key].clone()
-
- logger.info(
- "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will directly update the state_dict of the transformer "
- 'as opposed to the LoRA layers that will co-exist separately until the "fuse_lora()" method is called. That is to say, the normalization layers will always be directly '
- "fused into the transformer and can only be unfused if `discard_original_layers=True` is passed. This might also have implications when dealing with multiple LoRAs. "
- "If you notice something unexpected, please open an issue: https://github.com/huggingface/diffusers/issues."
- )
-
- # We can't load with strict=True because the current state_dict does not contain all the transformer keys
- incompatible_keys = transformer.load_state_dict(state_dict, strict=False)
- unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
-
- # We shouldn't expect to see the supported norm keys here being present in the unexpected keys.
- if unexpected_keys:
- if any(norm_key in k for k in unexpected_keys for norm_key in cls._control_lora_supported_norm_keys):
- raise ValueError(
- f"Found {unexpected_keys} as unexpected keys while trying to load norm layers into the transformer."
- )
-
- return overwritten_layers_state_dict
-
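The save-then-`strict=False` load performed above can be shown on a plain `LayerNorm`, which stands in here for the transformer's norm layers:

```py
import torch

norm = torch.nn.LayerNorm(4)                        # weight initialised to ones
norm_update = {"weight": torch.full((4,), 2.0)}     # stands in for the checkpoint's norm keys

# Keep the originals so unload_lora_weights() can put them back later.
overwritten = {k: norm.state_dict()[k].clone() for k in norm_update}

norm.load_state_dict(norm_update, strict=False)     # partial load: bias is left untouched
print(norm.weight.data)                             # tensor([2., 2., 2., 2.])

norm.load_state_dict(overwritten, strict=False)     # restore the saved originals
print(norm.weight.data)                             # tensor([1., 1., 1., 1.])
```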
- @classmethod
- # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
- def load_lora_into_text_encoder(
- cls,
- state_dict,
- network_alphas,
- text_encoder,
- prefix=None,
- lora_scale=1.0,
- adapter_name=None,
- _pipeline=None,
- low_cpu_mem_usage=False,
- hotswap: bool = False,
- ):
- """
- This will load the LoRA layers specified in `state_dict` into `text_encoder`
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The key should be prefixed with an
- additional `text_encoder` to distinguish between unet lora layers.
- network_alphas (`Dict[str, float]`):
- The value of the network alpha used for stable learning and preventing underflow. This value has the
- same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
- link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
- text_encoder (`CLIPTextModel`):
- The text encoder model to load the LoRA layers into.
- prefix (`str`):
- Expected prefix of the `text_encoder` in the `state_dict`.
- lora_scale (`float`):
- How much to scale the output of the lora linear layer before it is added with the output of the regular
- lora layer.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- """
- _load_lora_into_text_encoder(
- state_dict=state_dict,
- network_alphas=network_alphas,
- lora_scale=lora_scale,
- text_encoder=text_encoder,
- prefix=prefix,
- text_encoder_name=cls.text_encoder_name,
- adapter_name=adapter_name,
- _pipeline=_pipeline,
- low_cpu_mem_usage=low_cpu_mem_usage,
- hotswap=hotswap,
- )
-
- @classmethod
- # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with unet->transformer
- def save_lora_weights(
- cls,
- save_directory: Union[str, os.PathLike],
- transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
- text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
- is_main_process: bool = True,
- weight_name: str = None,
- save_function: Callable = None,
- safe_serialization: bool = True,
- ):
- r"""
- Save the LoRA parameters corresponding to the UNet and text encoder.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `transformer`.
- text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
- encoder LoRA state dict because it comes from 🤗 Transformers.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training and you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
- """
- state_dict = {}
-
- if not (transformer_lora_layers or text_encoder_lora_layers):
- raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.")
-
- if transformer_lora_layers:
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
-
- if text_encoder_lora_layers:
- state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
-
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
- save_directory=save_directory,
- is_main_process=is_main_process,
- weight_name=weight_name,
- save_function=save_function,
- safe_serialization=safe_serialization,
- )
-
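One common way to drive the saving side, sketched here under the assumption that `pipe` is a Flux pipeline whose transformer already carries trained PEFT LoRA layers (directory and file names are placeholders): the `transformer_lora_layers` argument is simply a state dict of LoRA weights, for example as produced by `get_peft_model_state_dict`.

```py
from peft.utils import get_peft_model_state_dict

# Assumes `pipe.transformer` has LoRA adapters injected and trained.
transformer_lora_layers = get_peft_model_state_dict(pipe.transformer)

pipe.save_lora_weights(
    save_directory="./my-flux-lora",           # created if it does not exist
    transformer_lora_layers=transformer_lora_layers,
    weight_name="pytorch_lora_weights.safetensors",
    safe_serialization=True,
)
```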
- def fuse_lora(
- self,
- components: List[str] = ["transformer"],
- lora_scale: float = 1.0,
- safe_fusing: bool = False,
- adapter_names: Optional[List[str]] = None,
- **kwargs,
- ):
- r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-        <Tip warning={true}>
-
-        This is an experimental API.
-
-        </Tip>
-
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
- """
-
- transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
- if (
- hasattr(transformer, "_transformer_norm_layers")
- and isinstance(transformer._transformer_norm_layers, dict)
- and len(transformer._transformer_norm_layers.keys()) > 0
- ):
- logger.info(
-                "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will directly update the state_dict of the transformer "
- "as opposed to the LoRA layers that will co-exist separately until the 'fuse_lora()' method is called. That is to say, the normalization layers will always be directly "
- "fused into the transformer and can only be unfused if `discard_original_layers=True` is passed."
- )
-
- super().fuse_lora(
- components=components,
- lora_scale=lora_scale,
- safe_fusing=safe_fusing,
- adapter_names=adapter_names,
- **kwargs,
- )
-
- def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
- r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-        <Tip warning={true}>
-
-        This is an experimental API.
-
-        </Tip>
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- """
- transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
- if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers:
- transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
-
- super().unfuse_lora(components=components, **kwargs)
-
-    # We override this here to account for `_transformer_norm_layers` and `_overwritten_params`.
- def unload_lora_weights(self, reset_to_overwritten_params=False):
- """
- Unloads the LoRA parameters.
-
- Args:
- reset_to_overwritten_params (`bool`, defaults to `False`): Whether to reset the LoRA-loaded modules
- to their original params. Refer to the [Flux
- documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux) to learn more.
-
- Examples:
-
- ```python
- >>> # Assuming `pipeline` is already loaded with the LoRA parameters.
- >>> pipeline.unload_lora_weights()
- >>> ...
- ```
- """
- super().unload_lora_weights()
-
- transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
- if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers:
- transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
- transformer._transformer_norm_layers = None
-
- if reset_to_overwritten_params and getattr(transformer, "_overwritten_params", None) is not None:
- overwritten_params = transformer._overwritten_params
- module_names = set()
-
- for param_name in overwritten_params:
- if param_name.endswith(".weight"):
- module_names.add(param_name.replace(".weight", ""))
-
- for name, module in transformer.named_modules():
- if isinstance(module, torch.nn.Linear) and name in module_names:
- module_weight = module.weight.data
- module_bias = module.bias.data if module.bias is not None else None
- bias = module_bias is not None
-
- parent_module_name, _, current_module_name = name.rpartition(".")
- parent_module = transformer.get_submodule(parent_module_name)
-
- current_param_weight = overwritten_params[f"{name}.weight"]
- in_features, out_features = current_param_weight.shape[1], current_param_weight.shape[0]
- with torch.device("meta"):
- original_module = torch.nn.Linear(
- in_features,
- out_features,
- bias=bias,
- dtype=module_weight.dtype,
- )
-
- tmp_state_dict = {"weight": current_param_weight}
- if module_bias is not None:
- tmp_state_dict.update({"bias": overwritten_params[f"{name}.bias"]})
- original_module.load_state_dict(tmp_state_dict, assign=True, strict=True)
- setattr(parent_module, current_module_name, original_module)
-
- del tmp_state_dict
-
- if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
- attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
- new_value = int(current_param_weight.shape[1])
- old_value = getattr(transformer.config, attribute_name)
- setattr(transformer.config, attribute_name, new_value)
- logger.info(
- f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
- )
-
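The module-swap trick used in `unload_lora_weights` above (build the replacement `nn.Linear` on the meta device, then materialise it with `assign=True`) looks like this in isolation; shapes are arbitrary:

```py
import torch

original_weight = torch.randn(8, 16)   # stands in for overwritten_params["<name>.weight"]

with torch.device("meta"):
    restored = torch.nn.Linear(16, 8, bias=False)   # no real storage allocated yet

# assign=True adopts the provided tensor instead of copying into the meta storage.
restored.load_state_dict({"weight": original_weight}, assign=True, strict=True)
print(restored.weight.shape)   # torch.Size([8, 16])
```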
- @classmethod
- def _maybe_expand_transformer_param_shape_or_error_(
- cls,
- transformer: torch.nn.Module,
- lora_state_dict=None,
- norm_state_dict=None,
- prefix=None,
- ) -> bool:
- """
- Control LoRA expands the shape of the input layer from (3072, 64) to (3072, 128). This method handles that and
-        generalizes things a bit so that any parameter that needs expansion receives appropriate treatment.
- """
- state_dict = {}
- if lora_state_dict is not None:
- state_dict.update(lora_state_dict)
- if norm_state_dict is not None:
- state_dict.update(norm_state_dict)
-
- # Remove prefix if present
- prefix = prefix or cls.transformer_name
- for key in list(state_dict.keys()):
- if key.split(".")[0] == prefix:
- state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key)
-
- # Expand transformer parameter shapes if they don't match lora
- has_param_with_shape_update = False
- overwritten_params = {}
-
- is_peft_loaded = getattr(transformer, "peft_config", None) is not None
- is_quantized = hasattr(transformer, "hf_quantizer")
- for name, module in transformer.named_modules():
- if isinstance(module, torch.nn.Linear):
- module_weight = module.weight.data
- module_bias = module.bias.data if module.bias is not None else None
- bias = module_bias is not None
-
- lora_base_name = name.replace(".base_layer", "") if is_peft_loaded else name
- lora_A_weight_name = f"{lora_base_name}.lora_A.weight"
- lora_B_weight_name = f"{lora_base_name}.lora_B.weight"
- if lora_A_weight_name not in state_dict:
- continue
-
- in_features = state_dict[lora_A_weight_name].shape[1]
- out_features = state_dict[lora_B_weight_name].shape[0]
-
-                # The model may be loaded with different quantization schemes, which may flatten the params.
-                # `bitsandbytes`, for example, flattens the weights when using 4bit. 8bit bnb models
- # preserve weight shape.
- module_weight_shape = cls._calculate_module_shape(model=transformer, base_module=module)
-
- # This means there's no need for an expansion in the params, so we simply skip.
- if tuple(module_weight_shape) == (out_features, in_features):
- continue
-
- module_out_features, module_in_features = module_weight_shape
- debug_message = ""
- if in_features > module_in_features:
- debug_message += (
- f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA '
- f"checkpoint contains higher number of features than expected. The number of input_features will be "
- f"expanded from {module_in_features} to {in_features}"
- )
- if out_features > module_out_features:
- debug_message += (
- ", and the number of output features will be "
- f"expanded from {module_out_features} to {out_features}."
- )
- else:
- debug_message += "."
- if debug_message:
- logger.debug(debug_message)
-
- if out_features > module_out_features or in_features > module_in_features:
- has_param_with_shape_update = True
- parent_module_name, _, current_module_name = name.rpartition(".")
- parent_module = transformer.get_submodule(parent_module_name)
-
- if is_quantized:
- module_weight = _maybe_dequantize_weight_for_expanded_lora(transformer, module)
-
- # TODO: consider if this layer needs to be a quantized layer as well if `is_quantized` is True.
- with torch.device("meta"):
- expanded_module = torch.nn.Linear(
- in_features, out_features, bias=bias, dtype=module_weight.dtype
- )
- # Only weights are expanded and biases are not. This is because only the input dimensions
- # are changed while the output dimensions remain the same. The shape of the weight tensor
- # is (out_features, in_features), while the shape of bias tensor is (out_features,), which
- # explains the reason why only weights are expanded.
- new_weight = torch.zeros_like(
- expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
- )
- slices = tuple(slice(0, dim) for dim in module_weight_shape)
- new_weight[slices] = module_weight
- tmp_state_dict = {"weight": new_weight}
- if module_bias is not None:
- tmp_state_dict["bias"] = module_bias
- expanded_module.load_state_dict(tmp_state_dict, strict=True, assign=True)
-
- setattr(parent_module, current_module_name, expanded_module)
-
- del tmp_state_dict
-
- if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
- attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
- new_value = int(expanded_module.weight.data.shape[1])
- old_value = getattr(transformer.config, attribute_name)
- setattr(transformer.config, attribute_name, new_value)
- logger.info(
- f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
- )
-
- # For `unload_lora_weights()`.
- # TODO: this could lead to more memory overhead if the number of overwritten params
- # are large. Should be revisited later and tackled through a `discard_original_layers` arg.
- overwritten_params[f"{current_module_name}.weight"] = module_weight
- if module_bias is not None:
- overwritten_params[f"{current_module_name}.bias"] = module_bias
-
- if len(overwritten_params) > 0:
- transformer._overwritten_params = overwritten_params
-
- return has_param_with_shape_update
-
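A toy version of the expansion performed above: a `Linear(64, 32)` is replaced by a `Linear(128, 32)` whose leading input slice reproduces the old weights and whose extra columns are zero, so existing inputs are unaffected. The bias is carried over unchanged because only the input dimension grows, matching the comment in the code.

```py
import torch

module = torch.nn.Linear(64, 32)
module_weight, module_bias = module.weight.data, module.bias.data

expanded = torch.nn.Linear(128, 32, bias=True, dtype=module_weight.dtype)
new_weight = torch.zeros_like(expanded.weight.data)
new_weight[:, :64] = module_weight                      # copy originals into the leading slice
expanded.load_state_dict({"weight": new_weight, "bias": module_bias}, strict=True)

x = torch.randn(1, 64)
x_padded = torch.nn.functional.pad(x, (0, 64))          # zero-pad the input to 128 features
print(torch.allclose(module(x), expanded(x_padded)))    # True: behaviour is preserved
```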
- @classmethod
- def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
- expanded_module_names = set()
- transformer_state_dict = transformer.state_dict()
- prefix = f"{cls.transformer_name}."
-
- lora_module_names = [
- key[: -len(".lora_A.weight")] for key in lora_state_dict if key.endswith(".lora_A.weight")
- ]
- lora_module_names = [name[len(prefix) :] for name in lora_module_names if name.startswith(prefix)]
- lora_module_names = sorted(set(lora_module_names))
- transformer_module_names = sorted({name for name, _ in transformer.named_modules()})
- unexpected_modules = set(lora_module_names) - set(transformer_module_names)
- if unexpected_modules:
- logger.debug(f"Found unexpected modules: {unexpected_modules}. These will be ignored.")
-
- is_peft_loaded = getattr(transformer, "peft_config", None) is not None
- for k in lora_module_names:
- if k in unexpected_modules:
- continue
-
- base_param_name = (
- f"{k.replace(prefix, '')}.base_layer.weight"
- if is_peft_loaded and f"{k.replace(prefix, '')}.base_layer.weight" in transformer_state_dict
- else f"{k.replace(prefix, '')}.weight"
- )
- base_weight_param = transformer_state_dict[base_param_name]
- lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]
-
- # TODO (sayakpaul): Handle the cases when we actually need to expand when using quantization.
- base_module_shape = cls._calculate_module_shape(model=transformer, base_weight_param_name=base_param_name)
-
- if base_module_shape[1] > lora_A_param.shape[1]:
- shape = (lora_A_param.shape[0], base_weight_param.shape[1])
- expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device)
- expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param)
- lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight
- expanded_module_names.add(k)
- elif base_module_shape[1] < lora_A_param.shape[1]:
- raise NotImplementedError(
- f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new."
- )
-
- if expanded_module_names:
- logger.info(
- f"The following LoRA modules were zero padded to match the state dict of {cls.transformer_name}: {expanded_module_names}. Please open an issue if you think this was unexpected - https://github.com/huggingface/diffusers/issues/new."
- )
-
- return lora_state_dict
-
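And the counterpart on the LoRA side, as in `_maybe_expand_lora_state_dict` above: when the base layer is wider than the adapter's `lora_A` matrix, the `lora_A` weight is zero-padded along its input dimension (shapes invented):

```py
import torch

base_weight = torch.randn(32, 128)   # (out_features, in_features) of the expanded base layer
lora_A = torch.randn(4, 64)          # rank-4 adapter trained against the original 64 inputs

expanded_A = torch.zeros(lora_A.shape[0], base_weight.shape[1], device=base_weight.device)
expanded_A[:, : lora_A.shape[1]].copy_(lora_A)

print(expanded_A.shape)                          # torch.Size([4, 128])
print(torch.equal(expanded_A[:, :64], lora_A))   # True: original adapter values preserved
```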
- @staticmethod
- def _calculate_module_shape(
- model: "torch.nn.Module",
- base_module: "torch.nn.Linear" = None,
- base_weight_param_name: str = None,
- ) -> "torch.Size":
- def _get_weight_shape(weight: torch.Tensor):
- if weight.__class__.__name__ == "Params4bit":
- return weight.quant_state.shape
- elif weight.__class__.__name__ == "GGUFParameter":
- return weight.quant_shape
- else:
- return weight.shape
-
- if base_module is not None:
- return _get_weight_shape(base_module.weight)
- elif base_weight_param_name is not None:
- if not base_weight_param_name.endswith(".weight"):
- raise ValueError(
- f"Invalid `base_weight_param_name` passed as it does not end with '.weight' {base_weight_param_name=}."
- )
- module_path = base_weight_param_name.rsplit(".weight", 1)[0]
- submodule = get_submodule_by_name(model, module_path)
- return _get_weight_shape(submodule.weight)
-
- raise ValueError("Either `base_module` or `base_weight_param_name` must be provided.")
-
-
-# The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
-# relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
-class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
- _lora_loadable_modules = ["transformer", "text_encoder"]
- transformer_name = TRANSFORMER_NAME
- text_encoder_name = TEXT_ENCODER_NAME
-
- @classmethod
- # Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin.load_lora_into_transformer with FluxTransformer2DModel->UVit2DModel
- def load_lora_into_transformer(
- cls,
- state_dict,
- network_alphas,
- transformer,
- adapter_name=None,
- _pipeline=None,
- low_cpu_mem_usage=False,
- hotswap: bool = False,
- ):
- """
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
- encoder lora layers.
- network_alphas (`Dict[str, float]`):
- The value of the network alpha used for stable learning and preventing underflow. This value has the
- same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
- link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
- transformer (`UVit2DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- """
- if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
- raise ValueError(
- "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
- )
-
- # Load the layers corresponding to transformer.
- logger.info(f"Loading {cls.transformer_name}.")
- transformer.load_lora_adapter(
- state_dict,
- network_alphas=network_alphas,
- adapter_name=adapter_name,
- _pipeline=_pipeline,
- low_cpu_mem_usage=low_cpu_mem_usage,
- hotswap=hotswap,
- )
-
- @classmethod
- # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
- def load_lora_into_text_encoder(
- cls,
- state_dict,
- network_alphas,
- text_encoder,
- prefix=None,
- lora_scale=1.0,
- adapter_name=None,
- _pipeline=None,
- low_cpu_mem_usage=False,
- hotswap: bool = False,
- ):
- """
- This will load the LoRA layers specified in `state_dict` into `text_encoder`
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The key should be prefixed with an
- additional `text_encoder` to distinguish between unet lora layers.
- network_alphas (`Dict[str, float]`):
- The value of the network alpha used for stable learning and preventing underflow. This value has the
- same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
- link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
- text_encoder (`CLIPTextModel`):
- The text encoder model to load the LoRA layers into.
- prefix (`str`):
- Expected prefix of the `text_encoder` in the `state_dict`.
- lora_scale (`float`):
- How much to scale the output of the lora linear layer before it is added with the output of the regular
- lora layer.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- """
- _load_lora_into_text_encoder(
- state_dict=state_dict,
- network_alphas=network_alphas,
- lora_scale=lora_scale,
- text_encoder=text_encoder,
- prefix=prefix,
- text_encoder_name=cls.text_encoder_name,
- adapter_name=adapter_name,
- _pipeline=_pipeline,
- low_cpu_mem_usage=low_cpu_mem_usage,
- hotswap=hotswap,
- )
-
- @classmethod
- def save_lora_weights(
- cls,
- save_directory: Union[str, os.PathLike],
- text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
- transformer_lora_layers: Dict[str, torch.nn.Module] = None,
- is_main_process: bool = True,
- weight_name: str = None,
- save_function: Callable = None,
- safe_serialization: bool = True,
- ):
- r"""
- Save the LoRA parameters corresponding to the UNet and text encoder.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
-            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
-                State dict of the LoRA layers corresponding to the `transformer`.
- text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
- encoder LoRA state dict because it comes from 🤗 Transformers.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training and you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
- """
- state_dict = {}
-
- if not (transformer_lora_layers or text_encoder_lora_layers):
- raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
-
- if transformer_lora_layers:
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
-
- if text_encoder_lora_layers:
- state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
-
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
- save_directory=save_directory,
- is_main_process=is_main_process,
- weight_name=weight_name,
- save_function=save_function,
- safe_serialization=safe_serialization,
- )
-
-
-class CogVideoXLoraLoaderMixin(LoraBaseMixin):
- r"""
- Load LoRA layers into [`CogVideoXTransformer3DModel`]. Specific to [`CogVideoXPipeline`].
- """
-
- _lora_loadable_modules = ["transformer"]
- transformer_name = TRANSFORMER_NAME
-
- @classmethod
- @validate_hf_hub_args
- # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
- def lora_state_dict(
- cls,
- pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
- **kwargs,
- ):
- r"""
- Return state dict for lora weights and the network alphas.
-
-        <Tip warning={true}>
-
-        We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
-        This function is experimental and might change in the future.
-
-        </Tip>
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
-
- """
- # Load the main state dict first which has the LoRA layers for either of
- # transformer and text encoder or both.
- cache_dir = kwargs.pop("cache_dir", None)
- force_download = kwargs.pop("force_download", False)
- proxies = kwargs.pop("proxies", None)
- local_files_only = kwargs.pop("local_files_only", None)
- token = kwargs.pop("token", None)
- revision = kwargs.pop("revision", None)
- subfolder = kwargs.pop("subfolder", None)
- weight_name = kwargs.pop("weight_name", None)
- use_safetensors = kwargs.pop("use_safetensors", None)
-
- allow_pickle = False
- if use_safetensors is None:
- use_safetensors = True
- allow_pickle = True
-
- user_agent = {
- "file_type": "attn_procs_weights",
- "framework": "pytorch",
- }
-
- state_dict = _fetch_state_dict(
- pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
- weight_name=weight_name,
- use_safetensors=use_safetensors,
- local_files_only=local_files_only,
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- token=token,
- revision=revision,
- subfolder=subfolder,
- user_agent=user_agent,
- allow_pickle=allow_pickle,
- )
-
- is_dora_scale_present = any("dora_scale" in k for k in state_dict)
- if is_dora_scale_present:
-            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible with Diffusers at the moment. So, we are going to filter out the keys associated with `dora_scale` from the state dict. If you think this is a mistake, please open an issue at https://github.com/huggingface/diffusers/issues/new."
- logger.warning(warn_msg)
- state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
-
- return state_dict
-
- def load_lora_weights(
- self,
- pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
- adapter_name: Optional[str] = None,
- hotswap: bool = False,
- **kwargs,
- ):
- """
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
- `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
- [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
- dict is loaded into `self.transformer`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- """
- if not USE_PEFT_BACKEND:
- raise ValueError("PEFT backend is required for this method.")
-
- low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
- if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
- raise ValueError(
- "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
- )
-
- # if a dict is passed, copy it instead of modifying it inplace
- if isinstance(pretrained_model_name_or_path_or_dict, dict):
- pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
-
- # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
- state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
-
- is_correct_format = all("lora" in key for key in state_dict.keys())
- if not is_correct_format:
- raise ValueError("Invalid LoRA checkpoint.")
-
- self.load_lora_into_transformer(
- state_dict,
- transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
- adapter_name=adapter_name,
- _pipeline=self,
- low_cpu_mem_usage=low_cpu_mem_usage,
- hotswap=hotswap,
- )
-
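
As a hedged usage sketch of this method on a CogVideoX pipeline (the base checkpoint id is assumed to be available on the Hub; the LoRA path and weight file name are placeholders):

```py
import torch
from diffusers import CogVideoXPipeline

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")

# Load a LoRA checkpoint into the transformer under an explicit adapter name.
pipe.load_lora_weights(
    "path/to/cogvideox_lora",                        # placeholder repo id or local directory
    weight_name="pytorch_lora_weights.safetensors",  # placeholder file name
    adapter_name="my_style",
)

video = pipe(prompt="a panda playing a guitar, pixel art style").frames[0]
```
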
- @classmethod
- # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogVideoXTransformer3DModel
- def load_lora_into_transformer(
- cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
- ):
- """
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the LoRA layer parameters. The keys can either be indexed directly
- into the transformer or prefixed with an additional `transformer`, which can be used to distinguish
- between text encoder LoRA layers.
- transformer (`CogVideoXTransformer3DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- """
- if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
- raise ValueError(
- "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
- )
-
- # Load the layers corresponding to transformer.
- logger.info(f"Loading {cls.transformer_name}.")
- transformer.load_lora_adapter(
- state_dict,
- network_alphas=None,
- adapter_name=adapter_name,
- _pipeline=_pipeline,
- low_cpu_mem_usage=low_cpu_mem_usage,
- hotswap=hotswap,
- )
-
- @classmethod
- # Adapted from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights without support for text encoder
- def save_lora_weights(
- cls,
- save_directory: Union[str, os.PathLike],
- transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
- is_main_process: bool = True,
- weight_name: str = None,
- save_function: Callable = None,
- safe_serialization: bool = True,
- ):
- r"""
- Save the LoRA parameters corresponding to the transformer.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `transformer`.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training when
- you need to call this function on all processes; in that case, set `is_main_process=True` only on the
- main process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
- """
- state_dict = {}
-
- if not transformer_lora_layers:
- raise ValueError("You must pass `transformer_lora_layers`.")
-
- if transformer_lora_layers:
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
-
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
- save_directory=save_directory,
- is_main_process=is_main_process,
- weight_name=weight_name,
- save_function=save_function,
- safe_serialization=safe_serialization,
- )
-
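
A sketch of how this classmethod fits into a training loop, assuming `transformer` already carries a PEFT LoRA adapter; the helper name and directory are illustrative only.

```py
from peft.utils import get_peft_model_state_dict
from diffusers import CogVideoXPipeline


def save_transformer_lora(transformer, save_directory: str) -> None:
    # `transformer` is assumed to be a LoRA-injected model, e.g. the one being fine-tuned.
    lora_layers = get_peft_model_state_dict(transformer)
    # Writes a single serialized file (safetensors by default) into `save_directory`,
    # which can later be passed back to `load_lora_weights`.
    CogVideoXPipeline.save_lora_weights(
        save_directory=save_directory,
        transformer_lora_layers=lora_layers,
    )
```
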
- def fuse_lora(
- self,
- components: List[str] = ["transformer"],
- lora_scale: float = 1.0,
- safe_fusing: bool = False,
- adapter_names: Optional[List[str]] = None,
- **kwargs,
- ):
- r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
- <Tip warning={true}>
-
- This is an experimental API.
-
- </Tip>
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check the fused weights for NaN values before fusing, and to skip fusing if NaN values are present.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
- """
- super().fuse_lora(
- components=components,
- lora_scale=lora_scale,
- safe_fusing=safe_fusing,
- adapter_names=adapter_names,
- **kwargs,
- )
-
- def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
- r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
- <Tip warning={true}>
-
- This is an experimental API.
-
- </Tip>
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
- """
- super().unfuse_lora(components=components, **kwargs)
-
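
A short sketch of the fuse/unfuse round trip; `pipe` stands for any pipeline using this mixin that already has a LoRA loaded via `load_lora_weights`.

```py
def fuse_then_unfuse(pipe, prompt: str, scale: float = 0.7):
    # Merge the active LoRA into the transformer weights at the given strength.
    pipe.fuse_lora(components=["transformer"], lora_scale=scale)
    output = pipe(prompt=prompt)  # inference runs on the fused weights
    # Restore the original, unfused weights afterwards.
    pipe.unfuse_lora(components=["transformer"])
    return output
```
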
-
-class Mochi1LoraLoaderMixin(LoraBaseMixin):
- r"""
- Load LoRA layers into [`MochiTransformer3DModel`]. Specific to [`MochiPipeline`].
- """
-
- _lora_loadable_modules = ["transformer"]
- transformer_name = TRANSFORMER_NAME
-
- @classmethod
- @validate_hf_hub_args
- # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
- def lora_state_dict(
- cls,
- pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
- **kwargs,
- ):
- r"""
- Return state dict for lora weights and the network alphas.
-
- <Tip warning={true}>
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
- </Tip>
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
-
- """
- # Load the main state dict first which has the LoRA layers for either of
- # transformer and text encoder or both.
- cache_dir = kwargs.pop("cache_dir", None)
- force_download = kwargs.pop("force_download", False)
- proxies = kwargs.pop("proxies", None)
- local_files_only = kwargs.pop("local_files_only", None)
- token = kwargs.pop("token", None)
- revision = kwargs.pop("revision", None)
- subfolder = kwargs.pop("subfolder", None)
- weight_name = kwargs.pop("weight_name", None)
- use_safetensors = kwargs.pop("use_safetensors", None)
-
- allow_pickle = False
- if use_safetensors is None:
- use_safetensors = True
- allow_pickle = True
-
- user_agent = {
- "file_type": "attn_procs_weights",
- "framework": "pytorch",
- }
-
- state_dict = _fetch_state_dict(
- pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
- weight_name=weight_name,
- use_safetensors=use_safetensors,
- local_files_only=local_files_only,
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- token=token,
- revision=revision,
- subfolder=subfolder,
- user_agent=user_agent,
- allow_pickle=allow_pickle,
- )
-
- is_dora_scale_present = any("dora_scale" in k for k in state_dict)
- if is_dora_scale_present:
- warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
- logger.warning(warn_msg)
- state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
-
- return state_dict
-
- # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
- def load_lora_weights(
- self,
- pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
- adapter_name: Optional[str] = None,
- hotswap: bool = False,
- **kwargs,
- ):
- """
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`. All
- kwargs are forwarded to `self.lora_state_dict`. See
- [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
- dict is loaded into `self.transformer`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- """
- if not USE_PEFT_BACKEND:
- raise ValueError("PEFT backend is required for this method.")
-
- low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
- if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
- raise ValueError(
- "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
- )
-
- # if a dict is passed, copy it instead of modifying it inplace
- if isinstance(pretrained_model_name_or_path_or_dict, dict):
- pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
-
- # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
- state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
-
- is_correct_format = all("lora" in key for key in state_dict.keys())
- if not is_correct_format:
- raise ValueError("Invalid LoRA checkpoint.")
-
- self.load_lora_into_transformer(
- state_dict,
- transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
- adapter_name=adapter_name,
- _pipeline=self,
- low_cpu_mem_usage=low_cpu_mem_usage,
- hotswap=hotswap,
- )
-
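
As a hedged illustration of combining this method with the adapter helpers from `LoraBaseMixin` on a Mochi pipeline (the base checkpoint id is assumed; the LoRA paths are placeholders):

```py
import torch
from diffusers import MochiPipeline

pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.bfloat16).to("cuda")

# Load two hypothetical LoRAs under explicit adapter names ...
pipe.load_lora_weights("path/to/motion_lora", adapter_name="motion")
pipe.load_lora_weights("path/to/style_lora", adapter_name="style")

# ... and blend them with per-adapter weights (`set_adapters` comes from `LoraBaseMixin`).
pipe.set_adapters(["motion", "style"], adapter_weights=[1.0, 0.6])
```
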
- @classmethod
- # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->MochiTransformer3DModel
- def load_lora_into_transformer(
- cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
- ):
- """
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the LoRA layer parameters. The keys can either be indexed directly
- into the transformer or prefixed with an additional `transformer`, which can be used to distinguish
- between text encoder LoRA layers.
- transformer (`MochiTransformer3DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- """
- if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
- raise ValueError(
- "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
- )
-
- # Load the layers corresponding to transformer.
- logger.info(f"Loading {cls.transformer_name}.")
- transformer.load_lora_adapter(
- state_dict,
- network_alphas=None,
- adapter_name=adapter_name,
- _pipeline=_pipeline,
- low_cpu_mem_usage=low_cpu_mem_usage,
- hotswap=hotswap,
- )
-
- @classmethod
- # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
- def save_lora_weights(
- cls,
- save_directory: Union[str, os.PathLike],
- transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
- is_main_process: bool = True,
- weight_name: str = None,
- save_function: Callable = None,
- safe_serialization: bool = True,
- ):
- r"""
- Save the LoRA parameters corresponding to the transformer.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `transformer`.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training when
- you need to call this function on all processes; in that case, set `is_main_process=True` only on the
- main process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
- """
- state_dict = {}
-
- if not transformer_lora_layers:
- raise ValueError("You must pass `transformer_lora_layers`.")
-
- if transformer_lora_layers:
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
-
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
- save_directory=save_directory,
- is_main_process=is_main_process,
- weight_name=weight_name,
- save_function=save_function,
- safe_serialization=safe_serialization,
- )
-
- # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
- def fuse_lora(
- self,
- components: List[str] = ["transformer"],
- lora_scale: float = 1.0,
- safe_fusing: bool = False,
- adapter_names: Optional[List[str]] = None,
- **kwargs,
- ):
- r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
- <Tip warning={true}>
-
- This is an experimental API.
-
- </Tip>
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check the fused weights for NaN values before fusing, and to skip fusing if NaN values are present.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
- """
- super().fuse_lora(
- components=components,
- lora_scale=lora_scale,
- safe_fusing=safe_fusing,
- adapter_names=adapter_names,
- **kwargs,
- )
-
- # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
- def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
- r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
- <Tip warning={true}>
-
- This is an experimental API.
-
- </Tip>
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
- """
- super().unfuse_lora(components=components, **kwargs)
-
-
-class LTXVideoLoraLoaderMixin(LoraBaseMixin):
- r"""
- Load LoRA layers into [`LTXVideoTransformer3DModel`]. Specific to [`LTXPipeline`].
- """
-
- _lora_loadable_modules = ["transformer"]
- transformer_name = TRANSFORMER_NAME
-
- @classmethod
- @validate_hf_hub_args
- # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
- def lora_state_dict(
- cls,
- pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
- **kwargs,
- ):
- r"""
- Return state dict for lora weights and the network alphas.
-
- <Tip warning={true}>
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
- </Tip>
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
-
- """
- # Load the main state dict first which has the LoRA layers for either of
- # transformer and text encoder or both.
- cache_dir = kwargs.pop("cache_dir", None)
- force_download = kwargs.pop("force_download", False)
- proxies = kwargs.pop("proxies", None)
- local_files_only = kwargs.pop("local_files_only", None)
- token = kwargs.pop("token", None)
- revision = kwargs.pop("revision", None)
- subfolder = kwargs.pop("subfolder", None)
- weight_name = kwargs.pop("weight_name", None)
- use_safetensors = kwargs.pop("use_safetensors", None)
-
- allow_pickle = False
- if use_safetensors is None:
- use_safetensors = True
- allow_pickle = True
-
- user_agent = {
- "file_type": "attn_procs_weights",
- "framework": "pytorch",
- }
-
- state_dict = _fetch_state_dict(
- pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
- weight_name=weight_name,
- use_safetensors=use_safetensors,
- local_files_only=local_files_only,
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- token=token,
- revision=revision,
- subfolder=subfolder,
- user_agent=user_agent,
- allow_pickle=allow_pickle,
- )
-
- is_dora_scale_present = any("dora_scale" in k for k in state_dict)
- if is_dora_scale_present:
- warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
- logger.warning(warn_msg)
- state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
-
- return state_dict
-
- # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
- def load_lora_weights(
- self,
- pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
- adapter_name: Optional[str] = None,
- hotswap: bool = False,
- **kwargs,
- ):
- """
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`. All
- kwargs are forwarded to `self.lora_state_dict`. See
- [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
- dict is loaded into `self.transformer`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- """
- if not USE_PEFT_BACKEND:
- raise ValueError("PEFT backend is required for this method.")
-
- low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
- if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
- raise ValueError(
- "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
- )
-
- # if a dict is passed, copy it instead of modifying it inplace
- if isinstance(pretrained_model_name_or_path_or_dict, dict):
- pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
-
- # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
- state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
-
- is_correct_format = all("lora" in key for key in state_dict.keys())
- if not is_correct_format:
- raise ValueError("Invalid LoRA checkpoint.")
-
- self.load_lora_into_transformer(
- state_dict,
- transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
- adapter_name=adapter_name,
- _pipeline=self,
- low_cpu_mem_usage=low_cpu_mem_usage,
- hotswap=hotswap,
- )
-
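
A minimal sketch of the `hotswap` flag on an LTX pipeline: the second call replaces the weights of an already-loaded adapter in place instead of registering a new one, which is mainly useful when the transformer has been compiled (the base checkpoint id is assumed; the LoRA paths are placeholders).

```py
import torch
from diffusers import LTXPipeline

pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16).to("cuda")

# Load a first (hypothetical) LoRA under a fixed adapter name.
pipe.load_lora_weights("path/to/first_lora", adapter_name="swap_slot")

# Later, swap different weights into the same adapter slot without re-creating it.
pipe.load_lora_weights("path/to/second_lora", adapter_name="swap_slot", hotswap=True)
```
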
- @classmethod
- # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->LTXVideoTransformer3DModel
- def load_lora_into_transformer(
- cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
- ):
- """
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the LoRA layer parameters. The keys can either be indexed directly
- into the transformer or prefixed with an additional `transformer`, which can be used to distinguish
- between text encoder LoRA layers.
- transformer (`LTXVideoTransformer3DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- """
- if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
- raise ValueError(
- "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
- )
-
- # Load the layers corresponding to transformer.
- logger.info(f"Loading {cls.transformer_name}.")
- transformer.load_lora_adapter(
- state_dict,
- network_alphas=None,
- adapter_name=adapter_name,
- _pipeline=_pipeline,
- low_cpu_mem_usage=low_cpu_mem_usage,
- hotswap=hotswap,
- )
-
- @classmethod
- # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
- def save_lora_weights(
- cls,
- save_directory: Union[str, os.PathLike],
- transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
- is_main_process: bool = True,
- weight_name: str = None,
- save_function: Callable = None,
- safe_serialization: bool = True,
- ):
- r"""
- Save the LoRA parameters corresponding to the transformer.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `transformer`.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training when
- you need to call this function on all processes; in that case, set `is_main_process=True` only on the
- main process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
- """
- state_dict = {}
-
- if not transformer_lora_layers:
- raise ValueError("You must pass `transformer_lora_layers`.")
-
- if transformer_lora_layers:
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
-
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
- save_directory=save_directory,
- is_main_process=is_main_process,
- weight_name=weight_name,
- save_function=save_function,
- safe_serialization=safe_serialization,
- )
-
- # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
- def fuse_lora(
- self,
- components: List[str] = ["transformer"],
- lora_scale: float = 1.0,
- safe_fusing: bool = False,
- adapter_names: Optional[List[str]] = None,
- **kwargs,
- ):
- r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
- <Tip warning={true}>
-
- This is an experimental API.
-
- </Tip>
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check the fused weights for NaN values before fusing, and to skip fusing if NaN values are present.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
- """
- super().fuse_lora(
- components=components,
- lora_scale=lora_scale,
- safe_fusing=safe_fusing,
- adapter_names=adapter_names,
- **kwargs,
- )
-
- # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
- def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
- r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
- <Tip warning={true}>
-
- This is an experimental API.
-
- </Tip>
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
- """
- super().unfuse_lora(components=components, **kwargs)
-
-
-class SanaLoraLoaderMixin(LoraBaseMixin):
- r"""
- Load LoRA layers into [`SanaTransformer2DModel`]. Specific to [`SanaPipeline`].
- """
-
- _lora_loadable_modules = ["transformer"]
- transformer_name = TRANSFORMER_NAME
-
- @classmethod
- @validate_hf_hub_args
- # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
- def lora_state_dict(
- cls,
- pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
- **kwargs,
- ):
- r"""
- Return state dict for lora weights and the network alphas.
-
- <Tip warning={true}>
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
- </Tip>
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
-
- """
- # Load the main state dict first which has the LoRA layers for either of
- # transformer and text encoder or both.
- cache_dir = kwargs.pop("cache_dir", None)
- force_download = kwargs.pop("force_download", False)
- proxies = kwargs.pop("proxies", None)
- local_files_only = kwargs.pop("local_files_only", None)
- token = kwargs.pop("token", None)
- revision = kwargs.pop("revision", None)
- subfolder = kwargs.pop("subfolder", None)
- weight_name = kwargs.pop("weight_name", None)
- use_safetensors = kwargs.pop("use_safetensors", None)
-
- allow_pickle = False
- if use_safetensors is None:
- use_safetensors = True
- allow_pickle = True
-
- user_agent = {
- "file_type": "attn_procs_weights",
- "framework": "pytorch",
- }
-
- state_dict = _fetch_state_dict(
- pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
- weight_name=weight_name,
- use_safetensors=use_safetensors,
- local_files_only=local_files_only,
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- token=token,
- revision=revision,
- subfolder=subfolder,
- user_agent=user_agent,
- allow_pickle=allow_pickle,
- )
-
- is_dora_scale_present = any("dora_scale" in k for k in state_dict)
- if is_dora_scale_present:
- warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
- logger.warning(warn_msg)
- state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
-
- return state_dict
-
- # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
- def load_lora_weights(
- self,
- pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
- adapter_name: Optional[str] = None,
- hotswap: bool = False,
- **kwargs,
- ):
- """
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`. All
- kwargs are forwarded to `self.lora_state_dict`. See
- [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
- dict is loaded into `self.transformer`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- """
- if not USE_PEFT_BACKEND:
- raise ValueError("PEFT backend is required for this method.")
-
- low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
- if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
- raise ValueError(
- "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
- )
-
- # if a dict is passed, copy it instead of modifying it inplace
- if isinstance(pretrained_model_name_or_path_or_dict, dict):
- pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
-
- # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
- state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
-
- is_correct_format = all("lora" in key for key in state_dict.keys())
- if not is_correct_format:
- raise ValueError("Invalid LoRA checkpoint.")
-
- self.load_lora_into_transformer(
- state_dict,
- transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
- adapter_name=adapter_name,
- _pipeline=self,
- low_cpu_mem_usage=low_cpu_mem_usage,
- hotswap=hotswap,
- )
-
- @classmethod
- # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SanaTransformer2DModel
- def load_lora_into_transformer(
- cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
- ):
- """
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the LoRA layer parameters. The keys can either be indexed directly
- into the transformer or prefixed with an additional `transformer`, which can be used to distinguish
- between text encoder LoRA layers.
- transformer (`SanaTransformer2DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- """
- if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
- raise ValueError(
- "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
- )
-
- # Load the layers corresponding to transformer.
- logger.info(f"Loading {cls.transformer_name}.")
- transformer.load_lora_adapter(
- state_dict,
- network_alphas=None,
- adapter_name=adapter_name,
- _pipeline=_pipeline,
- low_cpu_mem_usage=low_cpu_mem_usage,
- hotswap=hotswap,
- )
-
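
The classmethod can also be used outside a pipeline to attach LoRA layers to a standalone transformer; a sketch assuming a Sana checkpoint laid out with a `transformer` subfolder and a hypothetical LoRA path.

```py
import torch
from diffusers import SanaPipeline, SanaTransformer2DModel

transformer = SanaTransformer2DModel.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers",  # assumed checkpoint id
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)

# Fetch the checkpoint with `lora_state_dict`, then inject it into the bare model.
state_dict = SanaPipeline.lora_state_dict("path/to/sana_lora")
SanaPipeline.load_lora_into_transformer(state_dict, transformer=transformer, adapter_name="default")
```
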
- @classmethod
- # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
- def save_lora_weights(
- cls,
- save_directory: Union[str, os.PathLike],
- transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
- is_main_process: bool = True,
- weight_name: str = None,
- save_function: Callable = None,
- safe_serialization: bool = True,
- ):
- r"""
- Save the LoRA parameters corresponding to the transformer.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `transformer`.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training when
- you need to call this function on all processes; in that case, set `is_main_process=True` only on the
- main process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
- """
- state_dict = {}
-
- if not transformer_lora_layers:
- raise ValueError("You must pass `transformer_lora_layers`.")
-
- if transformer_lora_layers:
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
-
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
- save_directory=save_directory,
- is_main_process=is_main_process,
- weight_name=weight_name,
- save_function=save_function,
- safe_serialization=safe_serialization,
- )
-
- # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
- def fuse_lora(
- self,
- components: List[str] = ["transformer"],
- lora_scale: float = 1.0,
- safe_fusing: bool = False,
- adapter_names: Optional[List[str]] = None,
- **kwargs,
- ):
- r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
- <Tip warning={true}>
-
- This is an experimental API.
-
- </Tip>
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check the fused weights for NaN values before fusing, and to skip fusing if NaN values are present.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
- """
- super().fuse_lora(
- components=components,
- lora_scale=lora_scale,
- safe_fusing=safe_fusing,
- adapter_names=adapter_names,
- **kwargs,
- )
-
- # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
- def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
- r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
- <Tip warning={true}>
-
- This is an experimental API.
-
- </Tip>
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
- """
- super().unfuse_lora(components=components, **kwargs)
-
-
-class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
- r"""
- Load LoRA layers into [`HunyuanVideoTransformer3DModel`]. Specific to [`HunyuanVideoPipeline`].
- """
-
- _lora_loadable_modules = ["transformer"]
- transformer_name = TRANSFORMER_NAME
-
- @classmethod
- @validate_hf_hub_args
- def lora_state_dict(
- cls,
- pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
- **kwargs,
- ):
- r"""
- Return state dict for lora weights and the network alphas.
-
- <Tip warning={true}>
-
- We support loading original format HunyuanVideo LoRA checkpoints.
-
- This function is experimental and might change in the future.
-
- </Tip>
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
-
- """
- # Load the main state dict first which has the LoRA layers for either of
- # transformer and text encoder or both.
- cache_dir = kwargs.pop("cache_dir", None)
- force_download = kwargs.pop("force_download", False)
- proxies = kwargs.pop("proxies", None)
- local_files_only = kwargs.pop("local_files_only", None)
- token = kwargs.pop("token", None)
- revision = kwargs.pop("revision", None)
- subfolder = kwargs.pop("subfolder", None)
- weight_name = kwargs.pop("weight_name", None)
- use_safetensors = kwargs.pop("use_safetensors", None)
-
- allow_pickle = False
- if use_safetensors is None:
- use_safetensors = True
- allow_pickle = True
-
- user_agent = {
- "file_type": "attn_procs_weights",
- "framework": "pytorch",
- }
-
- state_dict = _fetch_state_dict(
- pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
- weight_name=weight_name,
- use_safetensors=use_safetensors,
- local_files_only=local_files_only,
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- token=token,
- revision=revision,
- subfolder=subfolder,
- user_agent=user_agent,
- allow_pickle=allow_pickle,
- )
-
- is_dora_scale_present = any("dora_scale" in k for k in state_dict)
- if is_dora_scale_present:
- warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
- logger.warning(warn_msg)
- state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
-
- is_original_hunyuan_video = any("img_attn_qkv" in k for k in state_dict)
- if is_original_hunyuan_video:
- state_dict = _convert_hunyuan_video_lora_to_diffusers(state_dict)
-
- return state_dict
-
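
Because this override detects original-format checkpoints via keys such as `img_attn_qkv` and converts them, a LoRA trained against the original HunyuanVideo code base can be passed straight to the pipeline; the base checkpoint id below is assumed and the LoRA path is a placeholder.

```py
import torch
from diffusers import HunyuanVideoPipeline

pipe = HunyuanVideoPipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo", torch_dtype=torch.bfloat16
).to("cuda")

# Both Diffusers-style and original-format HunyuanVideo LoRA files are accepted;
# the latter are converted on the fly by `_convert_hunyuan_video_lora_to_diffusers`.
pipe.load_lora_weights("path/to/hunyuanvideo_lora", weight_name="lora.safetensors")
```
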
- # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
- def load_lora_weights(
- self,
- pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
- adapter_name: Optional[str] = None,
- hotswap: bool = False,
- **kwargs,
- ):
- """
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`. All
- kwargs are forwarded to `self.lora_state_dict`. See
- [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
- dict is loaded into `self.transformer`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- """
- if not USE_PEFT_BACKEND:
- raise ValueError("PEFT backend is required for this method.")
-
- low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
- if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
- raise ValueError(
- "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
- )
-
- # if a dict is passed, copy it instead of modifying it inplace
- if isinstance(pretrained_model_name_or_path_or_dict, dict):
- pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
-
- # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
- state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
-
- is_correct_format = all("lora" in key for key in state_dict.keys())
- if not is_correct_format:
- raise ValueError("Invalid LoRA checkpoint.")
-
- self.load_lora_into_transformer(
- state_dict,
- transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
- adapter_name=adapter_name,
- _pipeline=self,
- low_cpu_mem_usage=low_cpu_mem_usage,
- hotswap=hotswap,
- )
-
- @classmethod
- # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HunyuanVideoTransformer3DModel
- def load_lora_into_transformer(
- cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
- ):
- """
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the LoRA layer parameters. The keys can either be indexed directly
- into the transformer or prefixed with an additional `transformer`, which can be used to distinguish
- between text encoder LoRA layers.
- transformer (`HunyuanVideoTransformer3DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- """
- if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
- raise ValueError(
- "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
- )
-
- # Load the layers corresponding to transformer.
- logger.info(f"Loading {cls.transformer_name}.")
- transformer.load_lora_adapter(
- state_dict,
- network_alphas=None,
- adapter_name=adapter_name,
- _pipeline=_pipeline,
- low_cpu_mem_usage=low_cpu_mem_usage,
- hotswap=hotswap,
- )
-
- @classmethod
- # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
- def save_lora_weights(
- cls,
- save_directory: Union[str, os.PathLike],
- transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
- is_main_process: bool = True,
- weight_name: str = None,
- save_function: Callable = None,
- safe_serialization: bool = True,
- ):
- r"""
- Save the LoRA parameters corresponding to the transformer.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `transformer`.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training when
- you need to call this function on all processes; in that case, set `is_main_process=True` only on the
- main process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
- """
- state_dict = {}
-
- if not transformer_lora_layers:
- raise ValueError("You must pass `transformer_lora_layers`.")
-
- if transformer_lora_layers:
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
-
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
- save_directory=save_directory,
- is_main_process=is_main_process,
- weight_name=weight_name,
- save_function=save_function,
- safe_serialization=safe_serialization,
- )
-
- # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
- def fuse_lora(
- self,
- components: List[str] = ["transformer"],
- lora_scale: float = 1.0,
- safe_fusing: bool = False,
- adapter_names: Optional[List[str]] = None,
- **kwargs,
- ):
- r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
- <Tip warning={true}>
-
- This is an experimental API.
-
- </Tip>
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check the fused weights for NaN values before fusing, and to skip fusing if NaN values are present.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
- """
- super().fuse_lora(
- components=components,
- lora_scale=lora_scale,
- safe_fusing=safe_fusing,
- adapter_names=adapter_names,
- **kwargs,
- )
-
- # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
- def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
- r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
- """
- super().unfuse_lora(components=components, **kwargs)
-
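The `fuse_lora` docstring above only shows the fusing step. A minimal sketch of the full fuse → generate → unfuse round trip described by these two methods (the model and LoRA repo ids follow the docstring example and are purely illustrative):

```py
import torch
from diffusers import DiffusionPipeline

# Repos mirror the docstring example above and are only illustrative.
pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights(
    "nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel"
)

# Bake the LoRA deltas into the base weights at 70% strength, generate, then undo the fusion.
pipeline.fuse_lora(lora_scale=0.7)
image = pipeline("a pixel art corgi").images[0]
pipeline.unfuse_lora()
```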
-
-class Lumina2LoraLoaderMixin(LoraBaseMixin):
- r"""
- Load LoRA layers into [`Lumina2Transformer2DModel`]. Specific to [`Lumina2Text2ImgPipeline`].
- """
-
- _lora_loadable_modules = ["transformer"]
- transformer_name = TRANSFORMER_NAME
-
- @classmethod
- @validate_hf_hub_args
- def lora_state_dict(
- cls,
- pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
- **kwargs,
- ):
- r"""
- Return state dict for lora weights and the network alphas.
-
-
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
-
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
-
- """
- # Load the main state dict first which has the LoRA layers for either of
- # transformer and text encoder or both.
- cache_dir = kwargs.pop("cache_dir", None)
- force_download = kwargs.pop("force_download", False)
- proxies = kwargs.pop("proxies", None)
- local_files_only = kwargs.pop("local_files_only", None)
- token = kwargs.pop("token", None)
- revision = kwargs.pop("revision", None)
- subfolder = kwargs.pop("subfolder", None)
- weight_name = kwargs.pop("weight_name", None)
- use_safetensors = kwargs.pop("use_safetensors", None)
-
- allow_pickle = False
- if use_safetensors is None:
- use_safetensors = True
- allow_pickle = True
-
- user_agent = {
- "file_type": "attn_procs_weights",
- "framework": "pytorch",
- }
-
- state_dict = _fetch_state_dict(
- pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
- weight_name=weight_name,
- use_safetensors=use_safetensors,
- local_files_only=local_files_only,
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- token=token,
- revision=revision,
- subfolder=subfolder,
- user_agent=user_agent,
- allow_pickle=allow_pickle,
- )
-
- is_dora_scale_present = any("dora_scale" in k for k in state_dict)
- if is_dora_scale_present:
- warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
- logger.warning(warn_msg)
- state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
-
- # conversion.
- non_diffusers = any(k.startswith("diffusion_model.") for k in state_dict)
- if non_diffusers:
- state_dict = _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict)
-
- return state_dict
-
- # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
- def load_lora_weights(
- self,
- pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
- adapter_name: Optional[str] = None,
- hotswap: bool = False,
- **kwargs,
- ):
- """
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
- `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
- [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
- dict is loaded into `self.transformer`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- """
- if not USE_PEFT_BACKEND:
- raise ValueError("PEFT backend is required for this method.")
-
- low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
- if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
- raise ValueError(
- "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
- )
-
- # if a dict is passed, copy it instead of modifying it inplace
- if isinstance(pretrained_model_name_or_path_or_dict, dict):
- pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
-
- # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
- state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
-
- is_correct_format = all("lora" in key for key in state_dict.keys())
- if not is_correct_format:
- raise ValueError("Invalid LoRA checkpoint.")
-
- self.load_lora_into_transformer(
- state_dict,
- transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
- adapter_name=adapter_name,
- _pipeline=self,
- low_cpu_mem_usage=low_cpu_mem_usage,
- hotswap=hotswap,
- )
-
- @classmethod
- # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->Lumina2Transformer2DModel
- def load_lora_into_transformer(
- cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
- ):
- """
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
- encoder lora layers.
- transformer (`Lumina2Transformer2DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- """
- if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
- raise ValueError(
- "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
- )
-
- # Load the layers corresponding to transformer.
- logger.info(f"Loading {cls.transformer_name}.")
- transformer.load_lora_adapter(
- state_dict,
- network_alphas=None,
- adapter_name=adapter_name,
- _pipeline=_pipeline,
- low_cpu_mem_usage=low_cpu_mem_usage,
- hotswap=hotswap,
- )
-
- @classmethod
- # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
- def save_lora_weights(
- cls,
- save_directory: Union[str, os.PathLike],
- transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
- is_main_process: bool = True,
- weight_name: str = None,
- save_function: Callable = None,
- safe_serialization: bool = True,
- ):
- r"""
- Save the LoRA parameters corresponding to the UNet and text encoder.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `transformer`.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training and you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
- """
- state_dict = {}
-
- if not transformer_lora_layers:
- raise ValueError("You must pass `transformer_lora_layers`.")
-
- if transformer_lora_layers:
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
-
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
- save_directory=save_directory,
- is_main_process=is_main_process,
- weight_name=weight_name,
- save_function=save_function,
- safe_serialization=safe_serialization,
- )
-
- # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
- def fuse_lora(
- self,
- components: List[str] = ["transformer"],
- lora_scale: float = 1.0,
- safe_fusing: bool = False,
- adapter_names: Optional[List[str]] = None,
- **kwargs,
- ):
- r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
- """
- super().fuse_lora(
- components=components,
- lora_scale=lora_scale,
- safe_fusing=safe_fusing,
- adapter_names=adapter_names,
- **kwargs,
- )
-
- # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
- def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
- r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
- """
- super().unfuse_lora(components=components, **kwargs)
-
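`Lumina2LoraLoaderMixin` above is consumed through the pipeline's `load_lora_weights`. A hedged usage sketch (the pipeline class name comes from the mixin docstring; the repo ids are placeholders):

```py
import torch
from diffusers import Lumina2Text2ImgPipeline  # pipeline class named in the mixin docstring

# Placeholder repo ids, for illustration only.
pipe = Lumina2Text2ImgPipeline.from_pretrained(
    "Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16
).to("cuda")

# Fetches the checkpoint via `lora_state_dict` (converting `diffusion_model.*` keys when present)
# and injects the LoRA layers into `pipe.transformer` under the given adapter name.
pipe.load_lora_weights("your-username/lumina2-lora", adapter_name="style")
image = pipe("a watercolor fox").images[0]
```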
-
-class WanLoraLoaderMixin(LoraBaseMixin):
- r"""
-    Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and [`WanImageToVideoPipeline`].
- """
-
- _lora_loadable_modules = ["transformer"]
- transformer_name = TRANSFORMER_NAME
-
- @classmethod
- @validate_hf_hub_args
- def lora_state_dict(
- cls,
- pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
- **kwargs,
- ):
- r"""
- Return state dict for lora weights and the network alphas.
-
-
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
-
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
-
- """
- # Load the main state dict first which has the LoRA layers for either of
- # transformer and text encoder or both.
- cache_dir = kwargs.pop("cache_dir", None)
- force_download = kwargs.pop("force_download", False)
- proxies = kwargs.pop("proxies", None)
- local_files_only = kwargs.pop("local_files_only", None)
- token = kwargs.pop("token", None)
- revision = kwargs.pop("revision", None)
- subfolder = kwargs.pop("subfolder", None)
- weight_name = kwargs.pop("weight_name", None)
- use_safetensors = kwargs.pop("use_safetensors", None)
-
- allow_pickle = False
- if use_safetensors is None:
- use_safetensors = True
- allow_pickle = True
-
- user_agent = {
- "file_type": "attn_procs_weights",
- "framework": "pytorch",
- }
-
- state_dict = _fetch_state_dict(
- pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
- weight_name=weight_name,
- use_safetensors=use_safetensors,
- local_files_only=local_files_only,
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- token=token,
- revision=revision,
- subfolder=subfolder,
- user_agent=user_agent,
- allow_pickle=allow_pickle,
- )
- if any(k.startswith("diffusion_model.") for k in state_dict):
- state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict)
- elif any(k.startswith("lora_unet_") for k in state_dict):
- state_dict = _convert_musubi_wan_lora_to_diffusers(state_dict)
-
- is_dora_scale_present = any("dora_scale" in k for k in state_dict)
- if is_dora_scale_present:
- warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
- logger.warning(warn_msg)
- state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
-
- return state_dict
-
- @classmethod
- def _maybe_expand_t2v_lora_for_i2v(
- cls,
- transformer: torch.nn.Module,
- state_dict,
- ):
- if transformer.config.image_dim is None:
- return state_dict
-
- if any(k.startswith("transformer.blocks.") for k in state_dict):
- num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict})
- is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict)
-
- if is_i2v_lora:
- return state_dict
-
- for i in range(num_blocks):
- for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
- state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
- state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"]
- )
- state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
- state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"]
- )
-
- return state_dict
-
- def load_lora_weights(
- self,
- pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
- adapter_name: Optional[str] = None,
- hotswap: bool = False,
- **kwargs,
- ):
- """
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
- `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
- [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
- dict is loaded into `self.transformer`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- """
- if not USE_PEFT_BACKEND:
- raise ValueError("PEFT backend is required for this method.")
-
- low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
- if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
- raise ValueError(
- "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
- )
-
- # if a dict is passed, copy it instead of modifying it inplace
- if isinstance(pretrained_model_name_or_path_or_dict, dict):
- pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
-
- # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
- state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
- # convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers
- state_dict = self._maybe_expand_t2v_lora_for_i2v(
- transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
- state_dict=state_dict,
- )
- is_correct_format = all("lora" in key for key in state_dict.keys())
- if not is_correct_format:
- raise ValueError("Invalid LoRA checkpoint.")
-
- self.load_lora_into_transformer(
- state_dict,
- transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
- adapter_name=adapter_name,
- _pipeline=self,
- low_cpu_mem_usage=low_cpu_mem_usage,
- hotswap=hotswap,
- )
-
- @classmethod
- # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel
- def load_lora_into_transformer(
- cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
- ):
- """
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
- encoder lora layers.
- transformer (`WanTransformer3DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- """
- if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
- raise ValueError(
- "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
- )
-
- # Load the layers corresponding to transformer.
- logger.info(f"Loading {cls.transformer_name}.")
- transformer.load_lora_adapter(
- state_dict,
- network_alphas=None,
- adapter_name=adapter_name,
- _pipeline=_pipeline,
- low_cpu_mem_usage=low_cpu_mem_usage,
- hotswap=hotswap,
- )
-
- @classmethod
- # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
- def save_lora_weights(
- cls,
- save_directory: Union[str, os.PathLike],
- transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
- is_main_process: bool = True,
- weight_name: str = None,
- save_function: Callable = None,
- safe_serialization: bool = True,
- ):
- r"""
- Save the LoRA parameters corresponding to the UNet and text encoder.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `transformer`.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training and you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
- """
- state_dict = {}
-
- if not transformer_lora_layers:
- raise ValueError("You must pass `transformer_lora_layers`.")
-
- if transformer_lora_layers:
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
-
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
- save_directory=save_directory,
- is_main_process=is_main_process,
- weight_name=weight_name,
- save_function=save_function,
- safe_serialization=safe_serialization,
- )
-
- # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
- def fuse_lora(
- self,
- components: List[str] = ["transformer"],
- lora_scale: float = 1.0,
- safe_fusing: bool = False,
- adapter_names: Optional[List[str]] = None,
- **kwargs,
- ):
- r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
- """
- super().fuse_lora(
- components=components,
- lora_scale=lora_scale,
- safe_fusing=safe_fusing,
- adapter_names=adapter_names,
- **kwargs,
- )
-
- # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
- def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
- r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
- """
- super().unfuse_lora(components=components, **kwargs)
-
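`WanLoraLoaderMixin._maybe_expand_t2v_lora_for_i2v` above lets a text-to-video LoRA be loaded into an image-to-video Wan transformer by padding the missing `add_k_proj`/`add_v_proj` cross-attention projections with zeros. A standalone sketch of that padding idea (key names mirror the method; the shipped code additionally checks `transformer.config.image_dim` and skips LoRAs that already carry the image branch):

```py
import torch


def expand_t2v_lora_for_i2v(state_dict: dict) -> dict:
    """Pad a Wan T2V LoRA with zero weights for the image-branch projections it lacks."""
    block_ids = {k.split("blocks.")[1].split(".")[0] for k in state_dict if k.startswith("transformer.blocks.")}
    for i in sorted(block_ids, key=int):
        ref_a = state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"]
        ref_b = state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"]
        for proj in ("add_k_proj", "add_v_proj"):
            # Zero tensors keep the adapter a no-op on the extra image cross-attention branch.
            state_dict.setdefault(f"transformer.blocks.{i}.attn2.{proj}.lora_A.weight", torch.zeros_like(ref_a))
            state_dict.setdefault(f"transformer.blocks.{i}.attn2.{proj}.lora_B.weight", torch.zeros_like(ref_b))
    return state_dict
```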
-
-class CogView4LoraLoaderMixin(LoraBaseMixin):
- r"""
-    Load LoRA layers into [`CogView4Transformer2DModel`]. Specific to [`CogView4Pipeline`].
- """
-
- _lora_loadable_modules = ["transformer"]
- transformer_name = TRANSFORMER_NAME
-
- @classmethod
- @validate_hf_hub_args
- # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
- def lora_state_dict(
- cls,
- pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
- **kwargs,
- ):
- r"""
- Return state dict for lora weights and the network alphas.
-
-
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
-
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
-
- """
- # Load the main state dict first which has the LoRA layers for either of
- # transformer and text encoder or both.
- cache_dir = kwargs.pop("cache_dir", None)
- force_download = kwargs.pop("force_download", False)
- proxies = kwargs.pop("proxies", None)
- local_files_only = kwargs.pop("local_files_only", None)
- token = kwargs.pop("token", None)
- revision = kwargs.pop("revision", None)
- subfolder = kwargs.pop("subfolder", None)
- weight_name = kwargs.pop("weight_name", None)
- use_safetensors = kwargs.pop("use_safetensors", None)
-
- allow_pickle = False
- if use_safetensors is None:
- use_safetensors = True
- allow_pickle = True
-
- user_agent = {
- "file_type": "attn_procs_weights",
- "framework": "pytorch",
- }
-
- state_dict = _fetch_state_dict(
- pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
- weight_name=weight_name,
- use_safetensors=use_safetensors,
- local_files_only=local_files_only,
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- token=token,
- revision=revision,
- subfolder=subfolder,
- user_agent=user_agent,
- allow_pickle=allow_pickle,
- )
-
- is_dora_scale_present = any("dora_scale" in k for k in state_dict)
- if is_dora_scale_present:
- warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
- logger.warning(warn_msg)
- state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
-
- return state_dict
-
- # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
- def load_lora_weights(
- self,
- pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
- adapter_name: Optional[str] = None,
- hotswap: bool = False,
- **kwargs,
- ):
- """
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
- `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
- [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
- dict is loaded into `self.transformer`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- """
- if not USE_PEFT_BACKEND:
- raise ValueError("PEFT backend is required for this method.")
-
- low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
- if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
- raise ValueError(
- "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
- )
-
- # if a dict is passed, copy it instead of modifying it inplace
- if isinstance(pretrained_model_name_or_path_or_dict, dict):
- pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
-
- # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
- state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
-
- is_correct_format = all("lora" in key for key in state_dict.keys())
- if not is_correct_format:
- raise ValueError("Invalid LoRA checkpoint.")
-
- self.load_lora_into_transformer(
- state_dict,
- transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
- adapter_name=adapter_name,
- _pipeline=self,
- low_cpu_mem_usage=low_cpu_mem_usage,
- hotswap=hotswap,
- )
-
- @classmethod
- # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel
- def load_lora_into_transformer(
- cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
- ):
- """
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
- encoder lora layers.
- transformer (`CogView4Transformer2DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- """
- if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
- raise ValueError(
- "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
- )
-
- # Load the layers corresponding to transformer.
- logger.info(f"Loading {cls.transformer_name}.")
- transformer.load_lora_adapter(
- state_dict,
- network_alphas=None,
- adapter_name=adapter_name,
- _pipeline=_pipeline,
- low_cpu_mem_usage=low_cpu_mem_usage,
- hotswap=hotswap,
- )
-
- @classmethod
- # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
- def save_lora_weights(
- cls,
- save_directory: Union[str, os.PathLike],
- transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
- is_main_process: bool = True,
- weight_name: str = None,
- save_function: Callable = None,
- safe_serialization: bool = True,
- ):
- r"""
- Save the LoRA parameters corresponding to the UNet and text encoder.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `transformer`.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training and you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
- """
- state_dict = {}
-
- if not transformer_lora_layers:
- raise ValueError("You must pass `transformer_lora_layers`.")
-
- if transformer_lora_layers:
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
-
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
- save_directory=save_directory,
- is_main_process=is_main_process,
- weight_name=weight_name,
- save_function=save_function,
- safe_serialization=safe_serialization,
- )
-
- # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
- def fuse_lora(
- self,
- components: List[str] = ["transformer"],
- lora_scale: float = 1.0,
- safe_fusing: bool = False,
- adapter_names: Optional[List[str]] = None,
- **kwargs,
- ):
- r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
- """
- super().fuse_lora(
- components=components,
- lora_scale=lora_scale,
- safe_fusing=safe_fusing,
- adapter_names=adapter_names,
- **kwargs,
- )
-
- # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
- def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
- r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
- """
- super().unfuse_lora(components=components, **kwargs)
-
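`save_lora_weights` in these mixins only ever serializes transformer LoRA layers. A hedged sketch of the save/reload round trip, assuming the transformer was LoRA-adapted through PEFT; the repo id and file names are placeholders:

```py
import torch
from peft import get_peft_model_state_dict  # assumption: LoRA adapters were attached via PEFT
from diffusers import CogView4Pipeline  # pipeline named in the mixin docstring; repo id is a placeholder

pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)

# After attaching/training LoRA adapters on `pipe.transformer`, collect their tensors ...
transformer_lora_layers = get_peft_model_state_dict(pipe.transformer)

# ... and write them out. `pack_weights` prefixes the keys with "transformer." and
# safetensors serialization is the default.
pipe.save_lora_weights(
    save_directory="./cogview4-lora",
    transformer_lora_layers=transformer_lora_layers,
    weight_name="adapter.safetensors",
)

# Any pipeline exposing the mixin can reload the file later.
pipe.load_lora_weights("./cogview4-lora", weight_name="adapter.safetensors", adapter_name="finetune")
```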
-
-class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
- r"""
- Load LoRA layers into [`HiDreamImageTransformer2DModel`]. Specific to [`HiDreamImagePipeline`].
- """
-
- _lora_loadable_modules = ["transformer"]
- transformer_name = TRANSFORMER_NAME
-
- @classmethod
- @validate_hf_hub_args
- # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
- def lora_state_dict(
- cls,
- pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
- **kwargs,
- ):
- r"""
- Return state dict for lora weights and the network alphas.
-
-
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
-
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
-
- """
- # Load the main state dict first which has the LoRA layers for either of
- # transformer and text encoder or both.
- cache_dir = kwargs.pop("cache_dir", None)
- force_download = kwargs.pop("force_download", False)
- proxies = kwargs.pop("proxies", None)
- local_files_only = kwargs.pop("local_files_only", None)
- token = kwargs.pop("token", None)
- revision = kwargs.pop("revision", None)
- subfolder = kwargs.pop("subfolder", None)
- weight_name = kwargs.pop("weight_name", None)
- use_safetensors = kwargs.pop("use_safetensors", None)
-
- allow_pickle = False
- if use_safetensors is None:
- use_safetensors = True
- allow_pickle = True
-
- user_agent = {
- "file_type": "attn_procs_weights",
- "framework": "pytorch",
- }
-
- state_dict = _fetch_state_dict(
- pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
- weight_name=weight_name,
- use_safetensors=use_safetensors,
- local_files_only=local_files_only,
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- token=token,
- revision=revision,
- subfolder=subfolder,
- user_agent=user_agent,
- allow_pickle=allow_pickle,
- )
-
- is_dora_scale_present = any("dora_scale" in k for k in state_dict)
- if is_dora_scale_present:
- warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
- logger.warning(warn_msg)
- state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
-
- return state_dict
-
- # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
- def load_lora_weights(
- self,
- pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
- adapter_name: Optional[str] = None,
- hotswap: bool = False,
- **kwargs,
- ):
- """
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
- `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
- [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
- dict is loaded into `self.transformer`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- """
- if not USE_PEFT_BACKEND:
- raise ValueError("PEFT backend is required for this method.")
-
- low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
- if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
- raise ValueError(
- "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
- )
-
- # if a dict is passed, copy it instead of modifying it inplace
- if isinstance(pretrained_model_name_or_path_or_dict, dict):
- pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
-
- # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
- state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
-
- is_correct_format = all("lora" in key for key in state_dict.keys())
- if not is_correct_format:
- raise ValueError("Invalid LoRA checkpoint.")
-
- self.load_lora_into_transformer(
- state_dict,
- transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
- adapter_name=adapter_name,
- _pipeline=self,
- low_cpu_mem_usage=low_cpu_mem_usage,
- hotswap=hotswap,
- )
-
- @classmethod
- # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HiDreamImageTransformer2DModel
- def load_lora_into_transformer(
- cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
- ):
- """
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
- encoder lora layers.
- transformer (`HiDreamImageTransformer2DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- """
- if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
- raise ValueError(
- "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
- )
-
- # Load the layers corresponding to transformer.
- logger.info(f"Loading {cls.transformer_name}.")
- transformer.load_lora_adapter(
- state_dict,
- network_alphas=None,
- adapter_name=adapter_name,
- _pipeline=_pipeline,
- low_cpu_mem_usage=low_cpu_mem_usage,
- hotswap=hotswap,
- )
-
- @classmethod
- # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
- def save_lora_weights(
- cls,
- save_directory: Union[str, os.PathLike],
- transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
- is_main_process: bool = True,
- weight_name: str = None,
- save_function: Callable = None,
- safe_serialization: bool = True,
- ):
- r"""
- Save the LoRA parameters corresponding to the UNet and text encoder.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `transformer`.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training and you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
- """
- state_dict = {}
-
- if not transformer_lora_layers:
- raise ValueError("You must pass `transformer_lora_layers`.")
-
- if transformer_lora_layers:
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
-
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
- save_directory=save_directory,
- is_main_process=is_main_process,
- weight_name=weight_name,
- save_function=save_function,
- safe_serialization=safe_serialization,
- )
-
- # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
- def fuse_lora(
- self,
- components: List[str] = ["transformer"],
- lora_scale: float = 1.0,
- safe_fusing: bool = False,
- adapter_names: Optional[List[str]] = None,
- **kwargs,
- ):
- r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
- """
- super().fuse_lora(
- components=components,
- lora_scale=lora_scale,
- safe_fusing=safe_fusing,
- adapter_names=adapter_names,
- **kwargs,
- )
-
- # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
- def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
- r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
- """
- super().unfuse_lora(components=components, **kwargs)
+class HiDreamImageLoraLoaderMixin(HiDreamImageLoraLoaderMixin):
+ def __init__(self, *args, **kwargs):
+        deprecation_message = "Importing `HiDreamImageLoraLoaderMixin` from diffusers.loaders.lora_pipeline has been deprecated. Please use `from diffusers.loaders.lora.lora_pipeline import HiDreamImageLoraLoaderMixin` instead."
+        deprecate("diffusers.loaders.lora_pipeline.HiDreamImageLoraLoaderMixin", "0.36", deprecation_message)
+ super().__init__(*args, **kwargs)
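The shim above keeps `HiDreamImageLoraLoaderMixin` importable from the legacy `diffusers.loaders.lora_pipeline` module (presumably by subclassing the relocated class re-imported under the same name) while warning that the old path is slated for removal in 0.36. For callers the change is just the import path, as the deprecation message spells out:

```py
# Legacy path: still resolves, but using the shim emits a deprecation warning (removal planned for 0.36).
from diffusers.loaders.lora_pipeline import HiDreamImageLoraLoaderMixin  # deprecated

# Preferred path going forward, per the deprecation message.
from diffusers.loaders.lora.lora_pipeline import HiDreamImageLoraLoaderMixin
```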
class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py
index 50450ab7d880..708c118ff521 100644
--- a/src/diffusers/loaders/peft.py
+++ b/src/diffusers/loaders/peft.py
@@ -35,8 +35,8 @@
set_adapter_layers,
set_weights_and_activate_adapters,
)
-from .lora_base import _fetch_state_dict, _func_optionally_disable_offloading
-from .unet_loader_utils import _maybe_expand_lora_scales
+from .lora.lora_base import _fetch_state_dict, _func_optionally_disable_offloading
+from .unet.unet_loader_utils import _maybe_expand_lora_scales
logger = logging.get_logger(__name__)
@@ -99,7 +99,7 @@ class PeftAdapterMixin:
_prepare_lora_hotswap_kwargs: Optional[dict] = None
@classmethod
- # Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
+ # Copied from diffusers.loaders.lora.lora_base.LoraBaseMixin._optionally_disable_offloading
def _optionally_disable_offloading(cls, _pipeline):
"""
Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
diff --git a/src/diffusers/loaders/single_file.py b/src/diffusers/loaders/single_file.py
index c2843fc7406a..e6acefb43976 100644
--- a/src/diffusers/loaders/single_file.py
+++ b/src/diffusers/loaders/single_file.py
@@ -11,42 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import importlib
-import inspect
-import os
-
-import torch
-from huggingface_hub import snapshot_download
-from huggingface_hub.utils import LocalEntryNotFoundError, validate_hf_hub_args
-from packaging import version
-from typing_extensions import Self
-
-from ..utils import deprecate, is_transformers_available, logging
-from .single_file_utils import (
- SingleFileComponentError,
- _is_legacy_scheduler_kwargs,
- _is_model_weights_in_cached_folder,
- _legacy_load_clip_tokenizer,
- _legacy_load_safety_checker,
- _legacy_load_scheduler,
- create_diffusers_clip_model_from_ldm,
- create_diffusers_t5_model_from_checkpoint,
- fetch_diffusers_config,
- fetch_original_config,
- is_clip_model_in_single_file,
- is_t5_in_single_file,
- load_single_file_checkpoint,
-)
-
-
-logger = logging.get_logger(__name__)
-
-# Legacy behaviour. `from_single_file` does not load the safety checker unless explicitly provided
-SINGLE_FILE_OPTIONAL_COMPONENTS = ["safety_checker"]
-
-if is_transformers_available():
- import transformers
- from transformers import PreTrainedModel, PreTrainedTokenizer
+from ..utils import deprecate
+from .single_file.single_file import FromSingleFileMixin
def load_single_file_sub_model(
@@ -64,502 +30,30 @@ def load_single_file_sub_model(
disable_mmap=False,
**kwargs,
):
- if is_pipeline_module:
- pipeline_module = getattr(pipelines, library_name)
- class_obj = getattr(pipeline_module, class_name)
- else:
- # else we just import it from the library.
- library = importlib.import_module(library_name)
- class_obj = getattr(library, class_name)
-
- if is_transformers_available():
- transformers_version = version.parse(version.parse(transformers.__version__).base_version)
- else:
- transformers_version = "N/A"
-
- is_transformers_model = (
- is_transformers_available()
- and issubclass(class_obj, PreTrainedModel)
- and transformers_version >= version.parse("4.20.0")
- )
- is_tokenizer = (
- is_transformers_available()
- and issubclass(class_obj, PreTrainedTokenizer)
- and transformers_version >= version.parse("4.20.0")
+ from .single_file.single_file import load_single_file_sub_model
+
+ deprecation_message = "Importing `load_single_file_sub_model()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file import load_single_file_sub_model` instead."
+ deprecate("diffusers.loaders.single_file.load_single_file_sub_model", "0.36", deprecation_message)
+
+ return load_single_file_sub_model(
+ library_name,
+ class_name,
+ name,
+ checkpoint,
+ pipelines,
+ is_pipeline_module,
+ cached_model_config_path,
+ original_config,
+ local_files_only,
+ torch_dtype,
+ is_legacy_loading,
+ disable_mmap,
+ **kwargs,
)
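`single_file.py` gets the same treatment: the old module keeps `load_single_file_sub_model` as a thin wrapper that forwards every argument to the relocated implementation and emits a deprecation warning, so existing call sites should keep working until 0.36. A sketch of the import migration, following the paths given in the deprecation message:

```py
# Legacy path: the wrapper shown above forwards to the new implementation and warns.
from diffusers.loaders.single_file import load_single_file_sub_model  # deprecated

# New path named in the deprecation message.
from diffusers.loaders.single_file.single_file import load_single_file_sub_model
```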
- diffusers_module = importlib.import_module(__name__.split(".")[0])
- is_diffusers_single_file_model = issubclass(class_obj, diffusers_module.FromOriginalModelMixin)
- is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)
- is_diffusers_scheduler = issubclass(class_obj, diffusers_module.SchedulerMixin)
-
- if is_diffusers_single_file_model:
- load_method = getattr(class_obj, "from_single_file")
-
- # We cannot provide two different config options to the `from_single_file` method
- # Here we have to ignore loading the config from `cached_model_config_path` if `original_config` is provided
- if original_config:
- cached_model_config_path = None
-
- loaded_sub_model = load_method(
- pretrained_model_link_or_path_or_dict=checkpoint,
- original_config=original_config,
- config=cached_model_config_path,
- subfolder=name,
- torch_dtype=torch_dtype,
- local_files_only=local_files_only,
- disable_mmap=disable_mmap,
- **kwargs,
- )
-
- elif is_transformers_model and is_clip_model_in_single_file(class_obj, checkpoint):
- loaded_sub_model = create_diffusers_clip_model_from_ldm(
- class_obj,
- checkpoint=checkpoint,
- config=cached_model_config_path,
- subfolder=name,
- torch_dtype=torch_dtype,
- local_files_only=local_files_only,
- is_legacy_loading=is_legacy_loading,
- )
-
- elif is_transformers_model and is_t5_in_single_file(checkpoint):
- loaded_sub_model = create_diffusers_t5_model_from_checkpoint(
- class_obj,
- checkpoint=checkpoint,
- config=cached_model_config_path,
- subfolder=name,
- torch_dtype=torch_dtype,
- local_files_only=local_files_only,
- )
-
- elif is_tokenizer and is_legacy_loading:
- loaded_sub_model = _legacy_load_clip_tokenizer(
- class_obj, checkpoint=checkpoint, config=cached_model_config_path, local_files_only=local_files_only
- )
-
- elif is_diffusers_scheduler and (is_legacy_loading or _is_legacy_scheduler_kwargs(kwargs)):
- loaded_sub_model = _legacy_load_scheduler(
- class_obj, checkpoint=checkpoint, component_name=name, original_config=original_config, **kwargs
- )
-
- else:
- if not hasattr(class_obj, "from_pretrained"):
- raise ValueError(
- (
- f"The component {class_obj.__name__} cannot be loaded as it does not seem to have"
- " a supported loading method."
- )
- )
-
- loading_kwargs = {}
- loading_kwargs.update(
- {
- "pretrained_model_name_or_path": cached_model_config_path,
- "subfolder": name,
- "local_files_only": local_files_only,
- }
- )
-
- # Schedulers and Tokenizers don't make use of torch_dtype
- # Skip passing it to those objects
- if issubclass(class_obj, torch.nn.Module):
- loading_kwargs.update({"torch_dtype": torch_dtype})
-
- if is_diffusers_model or is_transformers_model:
- if not _is_model_weights_in_cached_folder(cached_model_config_path, name):
- raise SingleFileComponentError(
- f"Failed to load {class_name}. Weights for this component appear to be missing in the checkpoint."
- )
-
- load_method = getattr(class_obj, "from_pretrained")
- loaded_sub_model = load_method(**loading_kwargs)
-
- return loaded_sub_model
-
-
-def _map_component_types_to_config_dict(component_types):
- diffusers_module = importlib.import_module(__name__.split(".")[0])
- config_dict = {}
- component_types.pop("self", None)
-
- if is_transformers_available():
- transformers_version = version.parse(version.parse(transformers.__version__).base_version)
- else:
- transformers_version = "N/A"
-
- for component_name, component_value in component_types.items():
- is_diffusers_model = issubclass(component_value[0], diffusers_module.ModelMixin)
- is_scheduler_enum = component_value[0].__name__ == "KarrasDiffusionSchedulers"
- is_scheduler = issubclass(component_value[0], diffusers_module.SchedulerMixin)
-
- is_transformers_model = (
- is_transformers_available()
- and issubclass(component_value[0], PreTrainedModel)
- and transformers_version >= version.parse("4.20.0")
- )
- is_transformers_tokenizer = (
- is_transformers_available()
- and issubclass(component_value[0], PreTrainedTokenizer)
- and transformers_version >= version.parse("4.20.0")
- )
-
- if is_diffusers_model and component_name not in SINGLE_FILE_OPTIONAL_COMPONENTS:
- config_dict[component_name] = ["diffusers", component_value[0].__name__]
-
- elif is_scheduler_enum or is_scheduler:
- if is_scheduler_enum:
- # Since we cannot fetch a scheduler config from the hub, we default to DDIMScheduler
- # if the type hint is a KarrassDiffusionSchedulers enum
- config_dict[component_name] = ["diffusers", "DDIMScheduler"]
-
- elif is_scheduler:
- config_dict[component_name] = ["diffusers", component_value[0].__name__]
-
- elif (
- is_transformers_model or is_transformers_tokenizer
- ) and component_name not in SINGLE_FILE_OPTIONAL_COMPONENTS:
- config_dict[component_name] = ["transformers", component_value[0].__name__]
-
- else:
- config_dict[component_name] = [None, None]
-
- return config_dict
-
-
-def _infer_pipeline_config_dict(pipeline_class):
- parameters = inspect.signature(pipeline_class.__init__).parameters
- required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
- component_types = pipeline_class._get_signature_types()
-
- # Ignore parameters that are not required for the pipeline
- component_types = {k: v for k, v in component_types.items() if k in required_parameters}
- config_dict = _map_component_types_to_config_dict(component_types)
-
- return config_dict
-
-
-def _download_diffusers_model_config_from_hub(
- pretrained_model_name_or_path,
- cache_dir,
- revision,
- proxies,
- force_download=None,
- local_files_only=None,
- token=None,
-):
- allow_patterns = ["**/*.json", "*.json", "*.txt", "**/*.txt", "**/*.model"]
- cached_model_path = snapshot_download(
- pretrained_model_name_or_path,
- cache_dir=cache_dir,
- revision=revision,
- proxies=proxies,
- force_download=force_download,
- local_files_only=local_files_only,
- token=token,
- allow_patterns=allow_patterns,
- )
-
- return cached_model_path
-
-
-class FromSingleFileMixin:
- """
- Load model weights saved in the `.ckpt` format into a [`DiffusionPipeline`].
- """
-
- @classmethod
- @validate_hf_hub_args
- def from_single_file(cls, pretrained_model_link_or_path, **kwargs) -> Self:
- r"""
- Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` or `.safetensors`
- format. The pipeline is set in evaluation mode (`model.eval()`) by default.
-
- Parameters:
- pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
- Can be either:
- - A link to the `.ckpt` file (for example
-                  `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
- - A path to a *file* containing all pipeline weights.
- torch_dtype (`str` or `torch.dtype`, *optional*):
- Override the default `torch.dtype` and load the model with another dtype.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- original_config_file (`str`, *optional*):
- The path to the original config file that was used to train the model. If not provided, the config file
- will be inferred from the checkpoint file.
- config (`str`, *optional*):
- Can be either:
- - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline
- hosted on the Hub.
- - A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline
- component configs in Diffusers format.
- disable_mmap ('bool', *optional*, defaults to 'False'):
- Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
- is on a network mount or hard drive.
- kwargs (remaining dictionary of keyword arguments, *optional*):
- Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
- class). The overwritten components are passed directly to the pipelines `__init__` method. See example
- below for more information.
-
- Examples:
-
- ```py
- >>> from diffusers import StableDiffusionPipeline
-
- >>> # Download pipeline from huggingface.co and cache.
- >>> pipeline = StableDiffusionPipeline.from_single_file(
- ... "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
- ... )
-
- >>> # Download pipeline from local file
- >>> # file is downloaded under ./v1-5-pruned-emaonly.ckpt
- >>> pipeline = StableDiffusionPipeline.from_single_file("./v1-5-pruned-emaonly.ckpt")
-
- >>> # Enable float16 and move to GPU
- >>> pipeline = StableDiffusionPipeline.from_single_file(
- ... "https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
- ... torch_dtype=torch.float16,
- ... )
- >>> pipeline.to("cuda")
- ```
-
- """
- original_config_file = kwargs.pop("original_config_file", None)
- config = kwargs.pop("config", None)
- original_config = kwargs.pop("original_config", None)
-
- if original_config_file is not None:
- deprecation_message = (
- "`original_config_file` argument is deprecated and will be removed in future versions."
- "please use the `original_config` argument instead."
- )
- deprecate("original_config_file", "1.0.0", deprecation_message)
- original_config = original_config_file
-
- force_download = kwargs.pop("force_download", False)
- proxies = kwargs.pop("proxies", None)
- token = kwargs.pop("token", None)
- cache_dir = kwargs.pop("cache_dir", None)
- local_files_only = kwargs.pop("local_files_only", False)
- revision = kwargs.pop("revision", None)
- torch_dtype = kwargs.pop("torch_dtype", None)
- disable_mmap = kwargs.pop("disable_mmap", False)
-
- is_legacy_loading = False
-
- if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
- torch_dtype = torch.float32
- logger.warning(
- f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
- )
-
- # We shouldn't allow configuring individual models components through a Pipeline creation method
- # These model kwargs should be deprecated
- scaling_factor = kwargs.get("scaling_factor", None)
- if scaling_factor is not None:
- deprecation_message = (
- "Passing the `scaling_factor` argument to `from_single_file is deprecated "
- "and will be ignored in future versions."
- )
- deprecate("scaling_factor", "1.0.0", deprecation_message)
-
- if original_config is not None:
- original_config = fetch_original_config(original_config, local_files_only=local_files_only)
-
- from ..pipelines.pipeline_utils import _get_pipeline_class
-
- pipeline_class = _get_pipeline_class(cls, config=None)
-
- checkpoint = load_single_file_checkpoint(
- pretrained_model_link_or_path,
- force_download=force_download,
- proxies=proxies,
- token=token,
- cache_dir=cache_dir,
- local_files_only=local_files_only,
- revision=revision,
- disable_mmap=disable_mmap,
- )
-
- if config is None:
- config = fetch_diffusers_config(checkpoint)
- default_pretrained_model_config_name = config["pretrained_model_name_or_path"]
- else:
- default_pretrained_model_config_name = config
-
- if not os.path.isdir(default_pretrained_model_config_name):
- # Provided config is a repo_id
- if default_pretrained_model_config_name.count("/") > 1:
- raise ValueError(
- f'The provided config "{config}"'
- " is neither a valid local path nor a valid repo id. Please check the parameter."
- )
- try:
- # Attempt to download the config files for the pipeline
- cached_model_config_path = _download_diffusers_model_config_from_hub(
- default_pretrained_model_config_name,
- cache_dir=cache_dir,
- revision=revision,
- proxies=proxies,
- force_download=force_download,
- local_files_only=local_files_only,
- token=token,
- )
- config_dict = pipeline_class.load_config(cached_model_config_path)
-
- except LocalEntryNotFoundError:
- # `local_files_only=True` but a local diffusers format model config is not available in the cache
- # If `original_config` is not provided, we need override `local_files_only` to False
- # to fetch the config files from the hub so that we have a way
- # to configure the pipeline components.
-
- if original_config is None:
- logger.warning(
- "`local_files_only` is True but no local configs were found for this checkpoint.\n"
- "Attempting to download the necessary config files for this pipeline.\n"
- )
- cached_model_config_path = _download_diffusers_model_config_from_hub(
- default_pretrained_model_config_name,
- cache_dir=cache_dir,
- revision=revision,
- proxies=proxies,
- force_download=force_download,
- local_files_only=False,
- token=token,
- )
- config_dict = pipeline_class.load_config(cached_model_config_path)
-
- else:
- # For backwards compatibility
- # If `original_config` is provided, then we need to assume we are using legacy loading for pipeline components
- logger.warning(
- "Detected legacy `from_single_file` loading behavior. Attempting to create the pipeline based on inferred components.\n"
- "This may lead to errors if the model components are not correctly inferred. \n"
- "To avoid this warning, please explicity pass the `config` argument to `from_single_file` with a path to a local diffusers model repo \n"
-                    "e.g. `from_single_file(<my model checkpoint path>, config=<path to local diffusers model repo>)` \n"
- "or run `from_single_file` with `local_files_only=False` first to update the local cache directory with "
- "the necessary config files.\n"
- )
- is_legacy_loading = True
- cached_model_config_path = None
-
- config_dict = _infer_pipeline_config_dict(pipeline_class)
- config_dict["_class_name"] = pipeline_class.__name__
-
- else:
- # Provided config is a path to a local directory attempt to load directly.
- cached_model_config_path = default_pretrained_model_config_name
- config_dict = pipeline_class.load_config(cached_model_config_path)
-
- # pop out "_ignore_files" as it is only needed for download
- config_dict.pop("_ignore_files", None)
-
- expected_modules, optional_kwargs = pipeline_class._get_signature_keys(cls)
- passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
- passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
-
- init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
- init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
- init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
-
- from diffusers import pipelines
-
- # remove `null` components
- def load_module(name, value):
- if value[0] is None:
- return False
- if name in passed_class_obj and passed_class_obj[name] is None:
- return False
- if name in SINGLE_FILE_OPTIONAL_COMPONENTS:
- return False
-
- return True
-
- init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
-
- for name, (library_name, class_name) in logging.tqdm(
- sorted(init_dict.items()), desc="Loading pipeline components..."
- ):
- loaded_sub_model = None
- is_pipeline_module = hasattr(pipelines, library_name)
-
- if name in passed_class_obj:
- loaded_sub_model = passed_class_obj[name]
-
- else:
- try:
- loaded_sub_model = load_single_file_sub_model(
- library_name=library_name,
- class_name=class_name,
- name=name,
- checkpoint=checkpoint,
- is_pipeline_module=is_pipeline_module,
- cached_model_config_path=cached_model_config_path,
- pipelines=pipelines,
- torch_dtype=torch_dtype,
- original_config=original_config,
- local_files_only=local_files_only,
- is_legacy_loading=is_legacy_loading,
- disable_mmap=disable_mmap,
- **kwargs,
- )
- except SingleFileComponentError as e:
- raise SingleFileComponentError(
- (
- f"{e.message}\n"
- f"Please load the component before passing it in as an argument to `from_single_file`.\n"
- f"\n"
- f"{name} = {class_name}.from_pretrained('...')\n"
-                        f"pipe = {pipeline_class.__name__}.from_single_file(<checkpoint path>, {name}={name})\n"
- f"\n"
- )
- )
-
- init_kwargs[name] = loaded_sub_model
-
- missing_modules = set(expected_modules) - set(init_kwargs.keys())
- passed_modules = list(passed_class_obj.keys())
- optional_modules = pipeline_class._optional_components
-
- if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules):
- for module in missing_modules:
- init_kwargs[module] = passed_class_obj.get(module, None)
- elif len(missing_modules) > 0:
- passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs
- raise ValueError(
- f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
- )
-
- # deprecated kwargs
- load_safety_checker = kwargs.pop("load_safety_checker", None)
- if load_safety_checker is not None:
- deprecation_message = (
- "Please pass instances of `StableDiffusionSafetyChecker` and `AutoImageProcessor`"
- "using the `safety_checker` and `feature_extractor` arguments in `from_single_file`"
- )
- deprecate("load_safety_checker", "1.0.0", deprecation_message)
-
- safety_checker_components = _legacy_load_safety_checker(local_files_only, torch_dtype)
- init_kwargs.update(safety_checker_components)
-
- pipe = pipeline_class(**init_kwargs)
- return pipe
+class FromSingleFileMixin(FromSingleFileMixin):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = "Importing `FromSingleFileMixin` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file import FromSingleFileMixin` instead."
+ deprecate("diffusers.loaders.single_file.FromSingleFileMixin", "0.36", deprecation_message)
+ super().__init__(*args, **kwargs)
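
The shim above keeps the old module importable while its deprecation messages point callers at the relocated implementation (planned removal in 0.36). The two paths, as named by those messages:

```py
# Path the deprecation message marks as legacy:
from diffusers.loaders.single_file import FromSingleFileMixin

# Replacement path named in the deprecation message (preferred going forward):
# from diffusers.loaders.single_file.single_file import FromSingleFileMixin
```
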
diff --git a/src/diffusers/loaders/single_file/__init__.py b/src/diffusers/loaders/single_file/__init__.py
new file mode 100644
index 000000000000..827685218160
--- /dev/null
+++ b/src/diffusers/loaders/single_file/__init__.py
@@ -0,0 +1,8 @@
+from ...utils import is_torch_available, is_transformers_available
+
+
+if is_torch_available():
+ from .single_file_model import FromOriginalModelMixin
+
+ if is_transformers_available():
+ from .single_file import FromSingleFileMixin
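
The new package `__init__` gates its exports on the available backends: `FromOriginalModelMixin` only needs torch, while `FromSingleFileMixin` additionally requires transformers. A rough sketch of a downstream import that respects the same guards (the availability helpers come from `diffusers.utils`):

```py
from diffusers.utils import is_torch_available, is_transformers_available

if is_torch_available():
    # Model-level single-file loading only depends on torch.
    from diffusers.loaders.single_file import FromOriginalModelMixin

    if is_transformers_available():
        # Pipeline-level single-file loading also needs transformers (text encoders, tokenizers).
        from diffusers.loaders.single_file import FromSingleFileMixin
```
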
diff --git a/src/diffusers/loaders/single_file/single_file.py b/src/diffusers/loaders/single_file/single_file.py
new file mode 100644
index 000000000000..26d94e7935ef
--- /dev/null
+++ b/src/diffusers/loaders/single_file/single_file.py
@@ -0,0 +1,565 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import importlib
+import inspect
+import os
+
+import torch
+from huggingface_hub import snapshot_download
+from huggingface_hub.utils import LocalEntryNotFoundError, validate_hf_hub_args
+from packaging import version
+from typing_extensions import Self
+
+from ...utils import deprecate, is_transformers_available, logging
+from .single_file_utils import (
+ SingleFileComponentError,
+ _is_legacy_scheduler_kwargs,
+ _is_model_weights_in_cached_folder,
+ _legacy_load_clip_tokenizer,
+ _legacy_load_safety_checker,
+ _legacy_load_scheduler,
+ create_diffusers_clip_model_from_ldm,
+ create_diffusers_t5_model_from_checkpoint,
+ fetch_diffusers_config,
+ fetch_original_config,
+ is_clip_model_in_single_file,
+ is_t5_in_single_file,
+ load_single_file_checkpoint,
+)
+
+
+logger = logging.get_logger(__name__)
+
+# Legacy behaviour. `from_single_file` does not load the safety checker unless explicitly provided
+SINGLE_FILE_OPTIONAL_COMPONENTS = ["safety_checker"]
+
+if is_transformers_available():
+ import transformers
+ from transformers import PreTrainedModel, PreTrainedTokenizer
+
+
+def load_single_file_sub_model(
+ library_name,
+ class_name,
+ name,
+ checkpoint,
+ pipelines,
+ is_pipeline_module,
+ cached_model_config_path,
+ original_config=None,
+ local_files_only=False,
+ torch_dtype=None,
+ is_legacy_loading=False,
+ disable_mmap=False,
+ **kwargs,
+):
+ if is_pipeline_module:
+ pipeline_module = getattr(pipelines, library_name)
+ class_obj = getattr(pipeline_module, class_name)
+ else:
+ # else we just import it from the library.
+ library = importlib.import_module(library_name)
+ class_obj = getattr(library, class_name)
+
+ if is_transformers_available():
+ transformers_version = version.parse(version.parse(transformers.__version__).base_version)
+ else:
+ transformers_version = "N/A"
+
+ is_transformers_model = (
+ is_transformers_available()
+ and issubclass(class_obj, PreTrainedModel)
+ and transformers_version >= version.parse("4.20.0")
+ )
+ is_tokenizer = (
+ is_transformers_available()
+ and issubclass(class_obj, PreTrainedTokenizer)
+ and transformers_version >= version.parse("4.20.0")
+ )
+
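+    # `__name__.split(".")[0]` resolves to the top-level `diffusers` package, so `class_obj` can be
+    # checked against the diffusers mixin and scheduler base classes below without a module-level import.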
+ diffusers_module = importlib.import_module(__name__.split(".")[0])
+ is_diffusers_single_file_model = issubclass(class_obj, diffusers_module.FromOriginalModelMixin)
+ is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)
+ is_diffusers_scheduler = issubclass(class_obj, diffusers_module.SchedulerMixin)
+
+ if is_diffusers_single_file_model:
+ load_method = getattr(class_obj, "from_single_file")
+
+ # We cannot provide two different config options to the `from_single_file` method
+ # Here we have to ignore loading the config from `cached_model_config_path` if `original_config` is provided
+ if original_config:
+ cached_model_config_path = None
+
+ loaded_sub_model = load_method(
+ pretrained_model_link_or_path_or_dict=checkpoint,
+ original_config=original_config,
+ config=cached_model_config_path,
+ subfolder=name,
+ torch_dtype=torch_dtype,
+ local_files_only=local_files_only,
+ disable_mmap=disable_mmap,
+ **kwargs,
+ )
+
+ elif is_transformers_model and is_clip_model_in_single_file(class_obj, checkpoint):
+ loaded_sub_model = create_diffusers_clip_model_from_ldm(
+ class_obj,
+ checkpoint=checkpoint,
+ config=cached_model_config_path,
+ subfolder=name,
+ torch_dtype=torch_dtype,
+ local_files_only=local_files_only,
+ is_legacy_loading=is_legacy_loading,
+ )
+
+ elif is_transformers_model and is_t5_in_single_file(checkpoint):
+ loaded_sub_model = create_diffusers_t5_model_from_checkpoint(
+ class_obj,
+ checkpoint=checkpoint,
+ config=cached_model_config_path,
+ subfolder=name,
+ torch_dtype=torch_dtype,
+ local_files_only=local_files_only,
+ )
+
+ elif is_tokenizer and is_legacy_loading:
+ loaded_sub_model = _legacy_load_clip_tokenizer(
+ class_obj, checkpoint=checkpoint, config=cached_model_config_path, local_files_only=local_files_only
+ )
+
+ elif is_diffusers_scheduler and (is_legacy_loading or _is_legacy_scheduler_kwargs(kwargs)):
+ loaded_sub_model = _legacy_load_scheduler(
+ class_obj, checkpoint=checkpoint, component_name=name, original_config=original_config, **kwargs
+ )
+
+ else:
+ if not hasattr(class_obj, "from_pretrained"):
+ raise ValueError(
+ (
+ f"The component {class_obj.__name__} cannot be loaded as it does not seem to have"
+ " a supported loading method."
+ )
+ )
+
+ loading_kwargs = {}
+ loading_kwargs.update(
+ {
+ "pretrained_model_name_or_path": cached_model_config_path,
+ "subfolder": name,
+ "local_files_only": local_files_only,
+ }
+ )
+
+ # Schedulers and Tokenizers don't make use of torch_dtype
+ # Skip passing it to those objects
+ if issubclass(class_obj, torch.nn.Module):
+ loading_kwargs.update({"torch_dtype": torch_dtype})
+
+ if is_diffusers_model or is_transformers_model:
+ if not _is_model_weights_in_cached_folder(cached_model_config_path, name):
+ raise SingleFileComponentError(
+ f"Failed to load {class_name}. Weights for this component appear to be missing in the checkpoint."
+ )
+
+ load_method = getattr(class_obj, "from_pretrained")
+ loaded_sub_model = load_method(**loading_kwargs)
+
+ return loaded_sub_model
+
+
+def _map_component_types_to_config_dict(component_types):
+ diffusers_module = importlib.import_module(__name__.split(".")[0])
+ config_dict = {}
+ component_types.pop("self", None)
+
+ if is_transformers_available():
+ transformers_version = version.parse(version.parse(transformers.__version__).base_version)
+ else:
+ transformers_version = "N/A"
+
+ for component_name, component_value in component_types.items():
+ is_diffusers_model = issubclass(component_value[0], diffusers_module.ModelMixin)
+ is_scheduler_enum = component_value[0].__name__ == "KarrasDiffusionSchedulers"
+ is_scheduler = issubclass(component_value[0], diffusers_module.SchedulerMixin)
+
+ is_transformers_model = (
+ is_transformers_available()
+ and issubclass(component_value[0], PreTrainedModel)
+ and transformers_version >= version.parse("4.20.0")
+ )
+ is_transformers_tokenizer = (
+ is_transformers_available()
+ and issubclass(component_value[0], PreTrainedTokenizer)
+ and transformers_version >= version.parse("4.20.0")
+ )
+
+ if is_diffusers_model and component_name not in SINGLE_FILE_OPTIONAL_COMPONENTS:
+ config_dict[component_name] = ["diffusers", component_value[0].__name__]
+
+ elif is_scheduler_enum or is_scheduler:
+ if is_scheduler_enum:
+ # Since we cannot fetch a scheduler config from the hub, we default to DDIMScheduler
+                # if the type hint is a KarrasDiffusionSchedulers enum
+ config_dict[component_name] = ["diffusers", "DDIMScheduler"]
+
+ elif is_scheduler:
+ config_dict[component_name] = ["diffusers", component_value[0].__name__]
+
+ elif (
+ is_transformers_model or is_transformers_tokenizer
+ ) and component_name not in SINGLE_FILE_OPTIONAL_COMPONENTS:
+ config_dict[component_name] = ["transformers", component_value[0].__name__]
+
+ else:
+ config_dict[component_name] = [None, None]
+
+ return config_dict
+
+
+def _infer_pipeline_config_dict(pipeline_class):
+ parameters = inspect.signature(pipeline_class.__init__).parameters
+ required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
+ component_types = pipeline_class._get_signature_types()
+
+ # Ignore parameters that are not required for the pipeline
+ component_types = {k: v for k, v in component_types.items() if k in required_parameters}
+ config_dict = _map_component_types_to_config_dict(component_types)
+
+ return config_dict
+
+
+def _download_diffusers_model_config_from_hub(
+ pretrained_model_name_or_path,
+ cache_dir,
+ revision,
+ proxies,
+ force_download=None,
+ local_files_only=None,
+ token=None,
+):
+ allow_patterns = ["**/*.json", "*.json", "*.txt", "**/*.txt", "**/*.model"]
+ cached_model_path = snapshot_download(
+ pretrained_model_name_or_path,
+ cache_dir=cache_dir,
+ revision=revision,
+ proxies=proxies,
+ force_download=force_download,
+ local_files_only=local_files_only,
+ token=token,
+ allow_patterns=allow_patterns,
+ )
+
+ return cached_model_path
+
+
+class FromSingleFileMixin:
+ """
+ Load model weights saved in the `.ckpt` format into a [`DiffusionPipeline`].
+ """
+
+ @classmethod
+ @validate_hf_hub_args
+ def from_single_file(cls, pretrained_model_link_or_path, **kwargs) -> Self:
+ r"""
+ Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` or `.safetensors`
+ format. The pipeline is set in evaluation mode (`model.eval()`) by default.
+
+ Parameters:
+ pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+ - A link to the `.ckpt` file (for example
+                  `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
+ - A path to a *file* containing all pipeline weights.
+ torch_dtype (`str` or `torch.dtype`, *optional*):
+ Override the default `torch.dtype` and load the model with another dtype.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ original_config_file (`str`, *optional*):
+ The path to the original config file that was used to train the model. If not provided, the config file
+ will be inferred from the checkpoint file.
+ config (`str`, *optional*):
+ Can be either:
+ - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline
+ hosted on the Hub.
+ - A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline
+ component configs in Diffusers format.
+ disable_mmap ('bool', *optional*, defaults to 'False'):
+ Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
+ is on a network mount or hard drive.
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+                Can be used to overwrite loadable and saveable variables (the pipeline components of the specific pipeline
+                class). The overwritten components are passed directly to the pipeline's `__init__` method. See example
+ below for more information.
+
+ Examples:
+
+ ```py
+ >>> from diffusers import StableDiffusionPipeline
+
+ >>> # Download pipeline from huggingface.co and cache.
+ >>> pipeline = StableDiffusionPipeline.from_single_file(
+ ... "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
+ ... )
+
+ >>> # Download pipeline from local file
+ >>> # file is downloaded under ./v1-5-pruned-emaonly.ckpt
+ >>> pipeline = StableDiffusionPipeline.from_single_file("./v1-5-pruned-emaonly.ckpt")
+
+ >>> # Enable float16 and move to GPU
+ >>> pipeline = StableDiffusionPipeline.from_single_file(
+ ... "https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
+ ... torch_dtype=torch.float16,
+ ... )
+ >>> pipeline.to("cuda")
+ ```
+
+ """
+ original_config_file = kwargs.pop("original_config_file", None)
+ config = kwargs.pop("config", None)
+ original_config = kwargs.pop("original_config", None)
+
+ if original_config_file is not None:
+ deprecation_message = (
+                "`original_config_file` argument is deprecated and will be removed in future versions. "
+                "Please use the `original_config` argument instead."
+ )
+ deprecate("original_config_file", "1.0.0", deprecation_message)
+ original_config = original_config_file
+
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ token = kwargs.pop("token", None)
+ cache_dir = kwargs.pop("cache_dir", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+ revision = kwargs.pop("revision", None)
+ torch_dtype = kwargs.pop("torch_dtype", None)
+ disable_mmap = kwargs.pop("disable_mmap", False)
+
+ is_legacy_loading = False
+
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
+            logger.warning(
+                f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
+            )
+            torch_dtype = torch.float32
+
+ # We shouldn't allow configuring individual models components through a Pipeline creation method
+        # We shouldn't allow configuring individual model components through a Pipeline creation method
+ scaling_factor = kwargs.get("scaling_factor", None)
+ if scaling_factor is not None:
+ deprecation_message = (
+                "Passing the `scaling_factor` argument to `from_single_file` is deprecated "
+ "and will be ignored in future versions."
+ )
+ deprecate("scaling_factor", "1.0.0", deprecation_message)
+
+ if original_config is not None:
+ original_config = fetch_original_config(original_config, local_files_only=local_files_only)
+
+ from ..pipelines.pipeline_utils import _get_pipeline_class
+
+ pipeline_class = _get_pipeline_class(cls, config=None)
+
+ checkpoint = load_single_file_checkpoint(
+ pretrained_model_link_or_path,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ cache_dir=cache_dir,
+ local_files_only=local_files_only,
+ revision=revision,
+ disable_mmap=disable_mmap,
+ )
+
+ if config is None:
+ config = fetch_diffusers_config(checkpoint)
+ default_pretrained_model_config_name = config["pretrained_model_name_or_path"]
+ else:
+ default_pretrained_model_config_name = config
+
+ if not os.path.isdir(default_pretrained_model_config_name):
+ # Provided config is a repo_id
+ if default_pretrained_model_config_name.count("/") > 1:
+ raise ValueError(
+ f'The provided config "{config}"'
+ " is neither a valid local path nor a valid repo id. Please check the parameter."
+ )
+ try:
+ # Attempt to download the config files for the pipeline
+ cached_model_config_path = _download_diffusers_model_config_from_hub(
+ default_pretrained_model_config_name,
+ cache_dir=cache_dir,
+ revision=revision,
+ proxies=proxies,
+ force_download=force_download,
+ local_files_only=local_files_only,
+ token=token,
+ )
+ config_dict = pipeline_class.load_config(cached_model_config_path)
+
+ except LocalEntryNotFoundError:
+ # `local_files_only=True` but a local diffusers format model config is not available in the cache
+            # If `original_config` is not provided, we need to override `local_files_only` to False
+ # to fetch the config files from the hub so that we have a way
+ # to configure the pipeline components.
+
+ if original_config is None:
+ logger.warning(
+ "`local_files_only` is True but no local configs were found for this checkpoint.\n"
+ "Attempting to download the necessary config files for this pipeline.\n"
+ )
+ cached_model_config_path = _download_diffusers_model_config_from_hub(
+ default_pretrained_model_config_name,
+ cache_dir=cache_dir,
+ revision=revision,
+ proxies=proxies,
+ force_download=force_download,
+ local_files_only=False,
+ token=token,
+ )
+ config_dict = pipeline_class.load_config(cached_model_config_path)
+
+ else:
+ # For backwards compatibility
+ # If `original_config` is provided, then we need to assume we are using legacy loading for pipeline components
+ logger.warning(
+ "Detected legacy `from_single_file` loading behavior. Attempting to create the pipeline based on inferred components.\n"
+ "This may lead to errors if the model components are not correctly inferred. \n"
+                    "To avoid this warning, please explicitly pass the `config` argument to `from_single_file` with a path to a local diffusers model repo \n"
+                    "e.g. `from_single_file(<my model checkpoint path>, config=<path to local diffusers model repo>)` \n"
+ "or run `from_single_file` with `local_files_only=False` first to update the local cache directory with "
+ "the necessary config files.\n"
+ )
+ is_legacy_loading = True
+ cached_model_config_path = None
+
+ config_dict = _infer_pipeline_config_dict(pipeline_class)
+ config_dict["_class_name"] = pipeline_class.__name__
+
+ else:
+ # Provided config is a path to a local directory attempt to load directly.
+ cached_model_config_path = default_pretrained_model_config_name
+ config_dict = pipeline_class.load_config(cached_model_config_path)
+
+ # pop out "_ignore_files" as it is only needed for download
+ config_dict.pop("_ignore_files", None)
+
+ expected_modules, optional_kwargs = pipeline_class._get_signature_keys(cls)
+ passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
+ passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
+
+ init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
+ init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
+ init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
+
+ from diffusers import pipelines
+
+ # remove `null` components
+ def load_module(name, value):
+ if value[0] is None:
+ return False
+ if name in passed_class_obj and passed_class_obj[name] is None:
+ return False
+ if name in SINGLE_FILE_OPTIONAL_COMPONENTS:
+ return False
+
+ return True
+
+ init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
+
+ for name, (library_name, class_name) in logging.tqdm(
+ sorted(init_dict.items()), desc="Loading pipeline components..."
+ ):
+ loaded_sub_model = None
+ is_pipeline_module = hasattr(pipelines, library_name)
+
+ if name in passed_class_obj:
+ loaded_sub_model = passed_class_obj[name]
+
+ else:
+ try:
+ loaded_sub_model = load_single_file_sub_model(
+ library_name=library_name,
+ class_name=class_name,
+ name=name,
+ checkpoint=checkpoint,
+ is_pipeline_module=is_pipeline_module,
+ cached_model_config_path=cached_model_config_path,
+ pipelines=pipelines,
+ torch_dtype=torch_dtype,
+ original_config=original_config,
+ local_files_only=local_files_only,
+ is_legacy_loading=is_legacy_loading,
+ disable_mmap=disable_mmap,
+ **kwargs,
+ )
+ except SingleFileComponentError as e:
+ raise SingleFileComponentError(
+ (
+ f"{e.message}\n"
+ f"Please load the component before passing it in as an argument to `from_single_file`.\n"
+ f"\n"
+ f"{name} = {class_name}.from_pretrained('...')\n"
+                        f"pipe = {pipeline_class.__name__}.from_single_file(<checkpoint path>, {name}={name})\n"
+ f"\n"
+ )
+ )
+
+ init_kwargs[name] = loaded_sub_model
+
+ missing_modules = set(expected_modules) - set(init_kwargs.keys())
+ passed_modules = list(passed_class_obj.keys())
+ optional_modules = pipeline_class._optional_components
+
+ if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules):
+ for module in missing_modules:
+ init_kwargs[module] = passed_class_obj.get(module, None)
+ elif len(missing_modules) > 0:
+ passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs
+ raise ValueError(
+ f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
+ )
+
+ # deprecated kwargs
+ load_safety_checker = kwargs.pop("load_safety_checker", None)
+ if load_safety_checker is not None:
+ deprecation_message = (
+                "Please pass instances of `StableDiffusionSafetyChecker` and `AutoImageProcessor` "
+ "using the `safety_checker` and `feature_extractor` arguments in `from_single_file`"
+ )
+ deprecate("load_safety_checker", "1.0.0", deprecation_message)
+
+ safety_checker_components = _legacy_load_safety_checker(local_files_only, torch_dtype)
+ init_kwargs.update(safety_checker_components)
+
+ pipe = pipeline_class(**init_kwargs)
+
+ return pipe
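
When a required sub-model cannot be recovered from the checkpoint, the loop above surfaces `SingleFileComponentError` and asks the caller to load that component separately. A hedged sketch of the suggested remediation, reusing the SD 1.5 checkpoint and repo id already shown in the docstring (the exact component is illustrative):

```py
import torch
from transformers import CLIPTextModel
from diffusers import StableDiffusionPipeline

# Load the missing component yourself, then hand it to `from_single_file`,
# exactly as the SingleFileComponentError message suggests.
text_encoder = CLIPTextModel.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="text_encoder", torch_dtype=torch.float16
)
pipe = StableDiffusionPipeline.from_single_file(
    "./v1-5-pruned-emaonly.ckpt", text_encoder=text_encoder, torch_dtype=torch.float16
)
```
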
diff --git a/src/diffusers/loaders/single_file/single_file_model.py b/src/diffusers/loaders/single_file/single_file_model.py
new file mode 100644
index 000000000000..0b97d7fbfafd
--- /dev/null
+++ b/src/diffusers/loaders/single_file/single_file_model.py
@@ -0,0 +1,440 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import importlib
+import inspect
+import re
+from contextlib import nullcontext
+from typing import Optional
+
+import torch
+from huggingface_hub.utils import validate_hf_hub_args
+from typing_extensions import Self
+
+from ... import __version__
+from ...quantizers import DiffusersAutoQuantizer
+from ...utils import deprecate, is_accelerate_available, logging
+from .single_file_utils import (
+ SingleFileComponentError,
+ convert_animatediff_checkpoint_to_diffusers,
+ convert_auraflow_transformer_checkpoint_to_diffusers,
+ convert_autoencoder_dc_checkpoint_to_diffusers,
+ convert_controlnet_checkpoint,
+ convert_flux_transformer_checkpoint_to_diffusers,
+ convert_hunyuan_video_transformer_to_diffusers,
+ convert_ldm_unet_checkpoint,
+ convert_ldm_vae_checkpoint,
+ convert_ltx_transformer_checkpoint_to_diffusers,
+ convert_ltx_vae_checkpoint_to_diffusers,
+ convert_lumina2_to_diffusers,
+ convert_mochi_transformer_checkpoint_to_diffusers,
+ convert_sana_transformer_to_diffusers,
+ convert_sd3_transformer_checkpoint_to_diffusers,
+ convert_stable_cascade_unet_single_file_to_diffusers,
+ convert_wan_transformer_to_diffusers,
+ convert_wan_vae_to_diffusers,
+ create_controlnet_diffusers_config_from_ldm,
+ create_unet_diffusers_config_from_ldm,
+ create_vae_diffusers_config_from_ldm,
+ fetch_diffusers_config,
+ fetch_original_config,
+ load_single_file_checkpoint,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+if is_accelerate_available():
+ from accelerate import dispatch_model, init_empty_weights
+
+ from ...models.modeling_utils import load_model_dict_into_meta
+
+
+SINGLE_FILE_LOADABLE_CLASSES = {
+ "StableCascadeUNet": {
+ "checkpoint_mapping_fn": convert_stable_cascade_unet_single_file_to_diffusers,
+ },
+ "UNet2DConditionModel": {
+ "checkpoint_mapping_fn": convert_ldm_unet_checkpoint,
+ "config_mapping_fn": create_unet_diffusers_config_from_ldm,
+ "default_subfolder": "unet",
+ "legacy_kwargs": {
+ "num_in_channels": "in_channels", # Legacy kwargs supported by `from_single_file` mapped to new args
+ },
+ },
+ "AutoencoderKL": {
+ "checkpoint_mapping_fn": convert_ldm_vae_checkpoint,
+ "config_mapping_fn": create_vae_diffusers_config_from_ldm,
+ "default_subfolder": "vae",
+ },
+ "ControlNetModel": {
+ "checkpoint_mapping_fn": convert_controlnet_checkpoint,
+ "config_mapping_fn": create_controlnet_diffusers_config_from_ldm,
+ },
+ "SD3Transformer2DModel": {
+ "checkpoint_mapping_fn": convert_sd3_transformer_checkpoint_to_diffusers,
+ "default_subfolder": "transformer",
+ },
+ "MotionAdapter": {
+ "checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
+ },
+ "SparseControlNetModel": {
+ "checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
+ },
+ "FluxTransformer2DModel": {
+ "checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
+ "default_subfolder": "transformer",
+ },
+ "LTXVideoTransformer3DModel": {
+ "checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers,
+ "default_subfolder": "transformer",
+ },
+ "AutoencoderKLLTXVideo": {
+ "checkpoint_mapping_fn": convert_ltx_vae_checkpoint_to_diffusers,
+ "default_subfolder": "vae",
+ },
+ "AutoencoderDC": {"checkpoint_mapping_fn": convert_autoencoder_dc_checkpoint_to_diffusers},
+ "MochiTransformer3DModel": {
+ "checkpoint_mapping_fn": convert_mochi_transformer_checkpoint_to_diffusers,
+ "default_subfolder": "transformer",
+ },
+ "HunyuanVideoTransformer3DModel": {
+ "checkpoint_mapping_fn": convert_hunyuan_video_transformer_to_diffusers,
+ "default_subfolder": "transformer",
+ },
+ "AuraFlowTransformer2DModel": {
+ "checkpoint_mapping_fn": convert_auraflow_transformer_checkpoint_to_diffusers,
+ "default_subfolder": "transformer",
+ },
+ "Lumina2Transformer2DModel": {
+ "checkpoint_mapping_fn": convert_lumina2_to_diffusers,
+ "default_subfolder": "transformer",
+ },
+ "SanaTransformer2DModel": {
+ "checkpoint_mapping_fn": convert_sana_transformer_to_diffusers,
+ "default_subfolder": "transformer",
+ },
+ "WanTransformer3DModel": {
+ "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
+ "default_subfolder": "transformer",
+ },
+ "AutoencoderKLWan": {
+ "checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
+ "default_subfolder": "vae",
+ },
+}
+
+
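+# Resolve which entry of SINGLE_FILE_LOADABLE_CLASSES applies to `cls` by walking the registry and
+# checking subclass relationships; returns None when the class has no single-file conversion path.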
+def _get_single_file_loadable_mapping_class(cls):
+ diffusers_module = importlib.import_module(__name__.split(".")[0])
+ for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
+ loadable_class = getattr(diffusers_module, loadable_class_str)
+
+ if issubclass(cls, loadable_class):
+ return loadable_class_str
+
+ return None
+
+
+def _get_mapping_function_kwargs(mapping_fn, **kwargs):
+ parameters = inspect.signature(mapping_fn).parameters
+
+ mapping_kwargs = {}
+ for parameter in parameters:
+ if parameter in kwargs:
+ mapping_kwargs[parameter] = kwargs[parameter]
+
+ return mapping_kwargs
+
+
+class FromOriginalModelMixin:
+ """
+ Load pretrained weights saved in the `.ckpt` or `.safetensors` format into a model.
+ """
+
+ @classmethod
+ @validate_hf_hub_args
+ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = None, **kwargs) -> Self:
+ r"""
+ Instantiate a model from pretrained weights saved in the original `.ckpt` or `.safetensors` format. The model
+ is set in evaluation mode (`model.eval()`) by default.
+
+ Parameters:
+ pretrained_model_link_or_path_or_dict (`str`, *optional*):
+ Can be either:
+ - A link to the `.safetensors` or `.ckpt` file (for example
+                  `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.safetensors"`) on the Hub.
+ - A path to a local *file* containing the weights of the component model.
+ - A state dict containing the component model weights.
+ config (`str`, *optional*):
+ - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline hosted
+ on the Hub.
+ - A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline component
+ configs in Diffusers format.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+ original_config (`str`, *optional*):
+ Dict or path to a yaml file containing the configuration for the model in its original format.
+ If a dict is provided, it will be used to initialize the model configuration.
+ torch_dtype (`str` or `torch.dtype`, *optional*):
+ Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
+ dtype is automatically derived from the model's weights.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to True, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ disable_mmap ('bool', *optional*, defaults to 'False'):
+ Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
+                is on a network mount or hard drive, which may not handle the seek-heavy access pattern of mmap very well.
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+                Can be used to overwrite loadable and saveable variables (for example the pipeline components of the
+                specific pipeline class). The overwritten components are directly passed to the pipeline's `__init__`
+ method. See example below for more information.
+
+ ```py
+ >>> from diffusers import StableCascadeUNet
+
+ >>> ckpt_path = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b_lite.safetensors"
+ >>> model = StableCascadeUNet.from_single_file(ckpt_path)
+ ```
+ """
+
+ mapping_class_name = _get_single_file_loadable_mapping_class(cls)
+ # if class_name not in SINGLE_FILE_LOADABLE_CLASSES:
+ if mapping_class_name is None:
+ raise ValueError(
+ f"FromOriginalModelMixin is currently only compatible with {', '.join(SINGLE_FILE_LOADABLE_CLASSES.keys())}"
+ )
+
+ pretrained_model_link_or_path = kwargs.get("pretrained_model_link_or_path", None)
+ if pretrained_model_link_or_path is not None:
+ deprecation_message = (
+ "Please use `pretrained_model_link_or_path_or_dict` argument instead for model classes"
+ )
+ deprecate("pretrained_model_link_or_path", "1.0.0", deprecation_message)
+ pretrained_model_link_or_path_or_dict = pretrained_model_link_or_path
+
+ config = kwargs.pop("config", None)
+ original_config = kwargs.pop("original_config", None)
+
+ if config is not None and original_config is not None:
+ raise ValueError(
+ "`from_single_file` cannot accept both `config` and `original_config` arguments. Please provide only one of these arguments"
+ )
+
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ token = kwargs.pop("token", None)
+ cache_dir = kwargs.pop("cache_dir", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ subfolder = kwargs.pop("subfolder", None)
+ revision = kwargs.pop("revision", None)
+ config_revision = kwargs.pop("config_revision", None)
+ torch_dtype = kwargs.pop("torch_dtype", None)
+ quantization_config = kwargs.pop("quantization_config", None)
+ device = kwargs.pop("device", None)
+ disable_mmap = kwargs.pop("disable_mmap", False)
+
+ user_agent = {"diffusers": __version__, "file_type": "single_file", "framework": "pytorch"}
+        # Record the quantization method in the user agent so popular quantization methods are supported; can be disabled with `disable_telemetry`.
+ if quantization_config is not None:
+ user_agent["quant"] = quantization_config.quant_method.value
+
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
+            logger.warning(
+                f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
+            )
+            torch_dtype = torch.float32
+
+ if isinstance(pretrained_model_link_or_path_or_dict, dict):
+ checkpoint = pretrained_model_link_or_path_or_dict
+ else:
+ checkpoint = load_single_file_checkpoint(
+ pretrained_model_link_or_path_or_dict,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ cache_dir=cache_dir,
+ local_files_only=local_files_only,
+ revision=revision,
+ disable_mmap=disable_mmap,
+ user_agent=user_agent,
+ )
+ if quantization_config is not None:
+ hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
+ hf_quantizer.validate_environment()
+ torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
+
+ else:
+ hf_quantizer = None
+
+ mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name]
+
+ checkpoint_mapping_fn = mapping_functions["checkpoint_mapping_fn"]
+ if original_config is not None:
+ if "config_mapping_fn" in mapping_functions:
+ config_mapping_fn = mapping_functions["config_mapping_fn"]
+ else:
+ config_mapping_fn = None
+
+ if config_mapping_fn is None:
+ raise ValueError(
+ (
+                        f"`original_config` has been provided for {mapping_class_name} but no mapping function "
+                        "was found to convert the original config to a Diffusers config in "
+                        "`diffusers.loaders.single_file_utils`"
+ )
+ )
+
+ if isinstance(original_config, str):
+ # If original_config is a URL or filepath fetch the original_config dict
+ original_config = fetch_original_config(original_config, local_files_only=local_files_only)
+
+ config_mapping_kwargs = _get_mapping_function_kwargs(config_mapping_fn, **kwargs)
+ diffusers_model_config = config_mapping_fn(
+ original_config=original_config, checkpoint=checkpoint, **config_mapping_kwargs
+ )
+ else:
+ if config is not None:
+ if isinstance(config, str):
+ default_pretrained_model_config_name = config
+ else:
+ raise ValueError(
+ (
+                        "Invalid `config` argument. Please provide a string representing a repo id "
+                        "or path to a local Diffusers model repo."
+ )
+ )
+
+ else:
+ config = fetch_diffusers_config(checkpoint)
+ default_pretrained_model_config_name = config["pretrained_model_name_or_path"]
+
+ if "default_subfolder" in mapping_functions:
+ subfolder = mapping_functions["default_subfolder"]
+
+ subfolder = subfolder or config.pop(
+ "subfolder", None
+ ) # some configs contain a subfolder key, e.g. StableCascadeUNet
+
+ diffusers_model_config = cls.load_config(
+ pretrained_model_name_or_path=default_pretrained_model_config_name,
+ subfolder=subfolder,
+ local_files_only=local_files_only,
+ token=token,
+ revision=config_revision,
+ )
+ expected_kwargs, optional_kwargs = cls._get_signature_keys(cls)
+
+ # Map legacy kwargs to new kwargs
+ if "legacy_kwargs" in mapping_functions:
+ legacy_kwargs = mapping_functions["legacy_kwargs"]
+ for legacy_key, new_key in legacy_kwargs.items():
+ if legacy_key in kwargs:
+ kwargs[new_key] = kwargs.pop(legacy_key)
+
+ model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs}
+ diffusers_model_config.update(model_kwargs)
+
+ checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
+ diffusers_format_checkpoint = checkpoint_mapping_fn(
+ config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
+ )
+ if not diffusers_format_checkpoint:
+ raise SingleFileComponentError(
+ f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
+ )
+
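+        # When accelerate is available, build the model skeleton on the meta device so parameters are
+        # only materialized once, when the converted state dict is loaded further below.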
+ ctx = init_empty_weights if is_accelerate_available() else nullcontext
+ with ctx():
+ model = cls.from_config(diffusers_model_config)
+
+ # Check if `_keep_in_fp32_modules` is not None
+ use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
+ (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
+ )
+ if use_keep_in_fp32_modules:
+ keep_in_fp32_modules = cls._keep_in_fp32_modules
+ if not isinstance(keep_in_fp32_modules, list):
+ keep_in_fp32_modules = [keep_in_fp32_modules]
+
+ else:
+ keep_in_fp32_modules = []
+
+ if hf_quantizer is not None:
+ hf_quantizer.preprocess_model(
+ model=model,
+ device_map=None,
+ state_dict=diffusers_format_checkpoint,
+ keep_in_fp32_modules=keep_in_fp32_modules,
+ )
+
+ device_map = None
+ if is_accelerate_available():
+ param_device = torch.device(device) if device else torch.device("cpu")
+ empty_state_dict = model.state_dict()
+ unexpected_keys = [
+ param_name for param_name in diffusers_format_checkpoint if param_name not in empty_state_dict
+ ]
+ device_map = {"": param_device}
+ load_model_dict_into_meta(
+ model,
+ diffusers_format_checkpoint,
+ dtype=torch_dtype,
+ device_map=device_map,
+ hf_quantizer=hf_quantizer,
+ keep_in_fp32_modules=keep_in_fp32_modules,
+ unexpected_keys=unexpected_keys,
+ )
+ else:
+ _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
+
+ if model._keys_to_ignore_on_load_unexpected is not None:
+ for pat in model._keys_to_ignore_on_load_unexpected:
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+
+ if len(unexpected_keys) > 0:
+ logger.warning(
+                f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {', '.join(unexpected_keys)}"
+ )
+
+ if hf_quantizer is not None:
+ hf_quantizer.postprocess_model(model)
+ model.hf_quantizer = hf_quantizer
+
+ if torch_dtype is not None and hf_quantizer is None:
+ model.to(torch_dtype)
+
+ model.eval()
+
+ if device_map is not None:
+ device_map_kwargs = {"device_map": device_map}
+ dispatch_model(model, **device_map_kwargs)
+
+ return model
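
Beyond the plain `StableCascadeUNet` example in the docstring, the method also threads `quantization_config` through `DiffusersAutoQuantizer`, so single-file checkpoints can be quantized on load. A hedged sketch of that path (the checkpoint URL is illustrative and bitsandbytes must be installed):

```py
import torch
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

# Quantize the transformer to 4-bit while loading it from a single .safetensors file.
quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
transformer = FluxTransformer2DModel.from_single_file(
    "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
```
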
diff --git a/src/diffusers/loaders/single_file/single_file_utils.py b/src/diffusers/loaders/single_file/single_file_utils.py
new file mode 100644
index 000000000000..a37a6c22d432
--- /dev/null
+++ b/src/diffusers/loaders/single_file/single_file_utils.py
@@ -0,0 +1,3295 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Conversion script for the Stable Diffusion checkpoints."""
+
+import copy
+import os
+import re
+from contextlib import nullcontext
+from io import BytesIO
+from urllib.parse import urlparse
+
+import requests
+import torch
+import yaml
+
+from ...models.modeling_utils import load_state_dict
+from ...schedulers import (
+ DDIMScheduler,
+ DPMSolverMultistepScheduler,
+ EDMDPMSolverMultistepScheduler,
+ EulerAncestralDiscreteScheduler,
+ EulerDiscreteScheduler,
+ HeunDiscreteScheduler,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+)
+from ...utils import (
+ SAFETENSORS_WEIGHTS_NAME,
+ WEIGHTS_NAME,
+ deprecate,
+ is_accelerate_available,
+ is_transformers_available,
+ logging,
+)
+from ...utils.constants import DIFFUSERS_REQUEST_TIMEOUT
+from ...utils.hub_utils import _get_model_file
+
+
+if is_transformers_available():
+ from transformers import AutoImageProcessor
+
+if is_accelerate_available():
+ from accelerate import init_empty_weights
+
+ from ...models.modeling_utils import load_model_dict_into_meta
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
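+# State dict keys used to identify which original model a single-file checkpoint belongs to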
+CHECKPOINT_KEY_NAMES = {
+ "v2": "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
+ "xl_base": "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias",
+ "xl_refiner": "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias",
+ "upscale": "model.diffusion_model.input_blocks.10.0.skip_connection.bias",
+ "controlnet": [
+ "control_model.time_embed.0.weight",
+ "controlnet_cond_embedding.conv_in.weight",
+ ],
+ # TODO: find non-Diffusers keys for controlnet_xl
+ "controlnet_xl": "add_embedding.linear_1.weight",
+ "controlnet_xl_large": "down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight",
+ "controlnet_xl_mid": "down_blocks.1.attentions.0.norm.weight",
+ "playground-v2-5": "edm_mean",
+ "inpainting": "model.diffusion_model.input_blocks.0.0.weight",
+ "clip": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight",
+ "clip_sdxl": "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight",
+ "clip_sd3": "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight",
+ "open_clip": "cond_stage_model.model.token_embedding.weight",
+ "open_clip_sdxl": "conditioner.embedders.1.model.positional_embedding",
+ "open_clip_sdxl_refiner": "conditioner.embedders.0.model.text_projection",
+ "open_clip_sd3": "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight",
+ "stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",
+ "stable_cascade_stage_c": "clip_txt_mapper.weight",
+ "sd3": [
+ "joint_blocks.0.context_block.adaLN_modulation.1.bias",
+ "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
+ ],
+ "sd35_large": [
+ "joint_blocks.37.x_block.mlp.fc1.weight",
+ "model.diffusion_model.joint_blocks.37.x_block.mlp.fc1.weight",
+ ],
+ "animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe",
+ "animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias",
+ "animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
+ "animatediff_scribble": "controlnet_cond_embedding.conv_in.weight",
+ "animatediff_rgb": "controlnet_cond_embedding.weight",
+ "auraflow": [
+ "double_layers.0.attn.w2q.weight",
+ "double_layers.0.attn.w1q.weight",
+ "cond_seq_linear.weight",
+ "t_embedder.mlp.0.weight",
+ ],
+ "flux": [
+ "double_blocks.0.img_attn.norm.key_norm.scale",
+ "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
+ ],
+ "ltx-video": [
+ "model.diffusion_model.patchify_proj.weight",
+ "model.diffusion_model.transformer_blocks.27.scale_shift_table",
+ "patchify_proj.weight",
+ "transformer_blocks.27.scale_shift_table",
+ "vae.per_channel_statistics.mean-of-means",
+ ],
+ "autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias",
+ "autoencoder-dc-sana": "encoder.project_in.conv.bias",
+ "mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"],
+ "hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
+ "instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight",
+ "lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"],
+ "sana": [
+ "blocks.0.cross_attn.q_linear.weight",
+ "blocks.0.cross_attn.q_linear.bias",
+ "blocks.0.cross_attn.kv_linear.weight",
+ "blocks.0.cross_attn.kv_linear.bias",
+ ],
+ "wan": ["model.diffusion_model.head.modulation", "head.modulation"],
+ "wan_vae": "decoder.middle.0.residual.0.gamma",
+}
+
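+# Default Hub repos used to fetch a diffusers-format config for each inferred model type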
+DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
+ "xl_base": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-base-1.0"},
+ "xl_refiner": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-refiner-1.0"},
+ "xl_inpaint": {"pretrained_model_name_or_path": "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"},
+ "playground-v2-5": {"pretrained_model_name_or_path": "playgroundai/playground-v2.5-1024px-aesthetic"},
+ "upscale": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-x4-upscaler"},
+ "inpainting": {"pretrained_model_name_or_path": "stable-diffusion-v1-5/stable-diffusion-inpainting"},
+ "inpainting_v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-inpainting"},
+ "controlnet": {"pretrained_model_name_or_path": "lllyasviel/control_v11p_sd15_canny"},
+ "controlnet_xl_large": {"pretrained_model_name_or_path": "diffusers/controlnet-canny-sdxl-1.0"},
+ "controlnet_xl_mid": {"pretrained_model_name_or_path": "diffusers/controlnet-canny-sdxl-1.0-mid"},
+ "controlnet_xl_small": {"pretrained_model_name_or_path": "diffusers/controlnet-canny-sdxl-1.0-small"},
+ "v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-1"},
+ "v1": {"pretrained_model_name_or_path": "stable-diffusion-v1-5/stable-diffusion-v1-5"},
+ "stable_cascade_stage_b": {"pretrained_model_name_or_path": "stabilityai/stable-cascade", "subfolder": "decoder"},
+ "stable_cascade_stage_b_lite": {
+ "pretrained_model_name_or_path": "stabilityai/stable-cascade",
+ "subfolder": "decoder_lite",
+ },
+ "stable_cascade_stage_c": {
+ "pretrained_model_name_or_path": "stabilityai/stable-cascade-prior",
+ "subfolder": "prior",
+ },
+ "stable_cascade_stage_c_lite": {
+ "pretrained_model_name_or_path": "stabilityai/stable-cascade-prior",
+ "subfolder": "prior_lite",
+ },
+ "sd3": {
+ "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3-medium-diffusers",
+ },
+ "sd35_large": {
+ "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3.5-large",
+ },
+ "sd35_medium": {
+ "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3.5-medium",
+ },
+ "animatediff_v1": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5"},
+ "animatediff_v2": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-2"},
+ "animatediff_v3": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-3"},
+ "animatediff_sdxl_beta": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-sdxl-beta"},
+ "animatediff_scribble": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-scribble"},
+ "animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"},
+ "auraflow": {"pretrained_model_name_or_path": "fal/AuraFlow-v0.3"},
+ "flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
+ "flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"},
+ "flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"},
+ "flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
+ "ltx-video": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.0"},
+ "ltx-video-0.9.1": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.1"},
+ "ltx-video-0.9.5": {"pretrained_model_name_or_path": "Lightricks/LTX-Video-0.9.5"},
+ "autoencoder-dc-f128c512": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"},
+ "autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"},
+ "autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
+ "autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"},
+ "mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"},
+ "hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"},
+ "instruct-pix2pix": {"pretrained_model_name_or_path": "timbrooks/instruct-pix2pix"},
+ "lumina2": {"pretrained_model_name_or_path": "Alpha-VLLM/Lumina-Image-2.0"},
+ "sana": {"pretrained_model_name_or_path": "Efficient-Large-Model/Sana_1600M_1024px_diffusers"},
+ "wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"},
+ "wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
+ "wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
+}
+
+# Use to configure model sample size when original config is provided
+DIFFUSERS_TO_LDM_DEFAULT_IMAGE_SIZE_MAP = {
+ "xl_base": 1024,
+ "xl_refiner": 1024,
+ "xl_inpaint": 1024,
+ "playground-v2-5": 1024,
+ "upscale": 512,
+ "inpainting": 512,
+ "inpainting_v2": 512,
+ "controlnet": 512,
+ "instruct-pix2pix": 512,
+ "v2": 768,
+ "v1": 512,
+}
+
+
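+# Direct key renames between diffusers module names and the original LDM state dict names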
+DIFFUSERS_TO_LDM_MAPPING = {
+ "unet": {
+ "layers": {
+ "time_embedding.linear_1.weight": "time_embed.0.weight",
+ "time_embedding.linear_1.bias": "time_embed.0.bias",
+ "time_embedding.linear_2.weight": "time_embed.2.weight",
+ "time_embedding.linear_2.bias": "time_embed.2.bias",
+ "conv_in.weight": "input_blocks.0.0.weight",
+ "conv_in.bias": "input_blocks.0.0.bias",
+ "conv_norm_out.weight": "out.0.weight",
+ "conv_norm_out.bias": "out.0.bias",
+ "conv_out.weight": "out.2.weight",
+ "conv_out.bias": "out.2.bias",
+ },
+ "class_embed_type": {
+ "class_embedding.linear_1.weight": "label_emb.0.0.weight",
+ "class_embedding.linear_1.bias": "label_emb.0.0.bias",
+ "class_embedding.linear_2.weight": "label_emb.0.2.weight",
+ "class_embedding.linear_2.bias": "label_emb.0.2.bias",
+ },
+ "addition_embed_type": {
+ "add_embedding.linear_1.weight": "label_emb.0.0.weight",
+ "add_embedding.linear_1.bias": "label_emb.0.0.bias",
+ "add_embedding.linear_2.weight": "label_emb.0.2.weight",
+ "add_embedding.linear_2.bias": "label_emb.0.2.bias",
+ },
+ },
+ "controlnet": {
+ "layers": {
+ "time_embedding.linear_1.weight": "time_embed.0.weight",
+ "time_embedding.linear_1.bias": "time_embed.0.bias",
+ "time_embedding.linear_2.weight": "time_embed.2.weight",
+ "time_embedding.linear_2.bias": "time_embed.2.bias",
+ "conv_in.weight": "input_blocks.0.0.weight",
+ "conv_in.bias": "input_blocks.0.0.bias",
+ "controlnet_cond_embedding.conv_in.weight": "input_hint_block.0.weight",
+ "controlnet_cond_embedding.conv_in.bias": "input_hint_block.0.bias",
+ "controlnet_cond_embedding.conv_out.weight": "input_hint_block.14.weight",
+ "controlnet_cond_embedding.conv_out.bias": "input_hint_block.14.bias",
+ },
+ "class_embed_type": {
+ "class_embedding.linear_1.weight": "label_emb.0.0.weight",
+ "class_embedding.linear_1.bias": "label_emb.0.0.bias",
+ "class_embedding.linear_2.weight": "label_emb.0.2.weight",
+ "class_embedding.linear_2.bias": "label_emb.0.2.bias",
+ },
+ "addition_embed_type": {
+ "add_embedding.linear_1.weight": "label_emb.0.0.weight",
+ "add_embedding.linear_1.bias": "label_emb.0.0.bias",
+ "add_embedding.linear_2.weight": "label_emb.0.2.weight",
+ "add_embedding.linear_2.bias": "label_emb.0.2.bias",
+ },
+ },
+ "vae": {
+ "encoder.conv_in.weight": "encoder.conv_in.weight",
+ "encoder.conv_in.bias": "encoder.conv_in.bias",
+ "encoder.conv_out.weight": "encoder.conv_out.weight",
+ "encoder.conv_out.bias": "encoder.conv_out.bias",
+ "encoder.conv_norm_out.weight": "encoder.norm_out.weight",
+ "encoder.conv_norm_out.bias": "encoder.norm_out.bias",
+ "decoder.conv_in.weight": "decoder.conv_in.weight",
+ "decoder.conv_in.bias": "decoder.conv_in.bias",
+ "decoder.conv_out.weight": "decoder.conv_out.weight",
+ "decoder.conv_out.bias": "decoder.conv_out.bias",
+ "decoder.conv_norm_out.weight": "decoder.norm_out.weight",
+ "decoder.conv_norm_out.bias": "decoder.norm_out.bias",
+ "quant_conv.weight": "quant_conv.weight",
+ "quant_conv.bias": "quant_conv.bias",
+ "post_quant_conv.weight": "post_quant_conv.weight",
+ "post_quant_conv.bias": "post_quant_conv.bias",
+ },
+ "openclip": {
+ "layers": {
+ "text_model.embeddings.position_embedding.weight": "positional_embedding",
+ "text_model.embeddings.token_embedding.weight": "token_embedding.weight",
+ "text_model.final_layer_norm.weight": "ln_final.weight",
+ "text_model.final_layer_norm.bias": "ln_final.bias",
+ "text_projection.weight": "text_projection",
+ },
+ "transformer": {
+ "text_model.encoder.layers.": "resblocks.",
+ "layer_norm1": "ln_1",
+ "layer_norm2": "ln_2",
+ ".fc1.": ".c_fc.",
+ ".fc2.": ".c_proj.",
+ ".self_attn": ".attn",
+ "transformer.text_model.final_layer_norm.": "ln_final.",
+ "transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight",
+ "transformer.text_model.embeddings.position_embedding.weight": "positional_embedding",
+ },
+ },
+}
+
+SD_2_TEXT_ENCODER_KEYS_TO_IGNORE = [
+ "cond_stage_model.model.transformer.resblocks.23.attn.in_proj_bias",
+ "cond_stage_model.model.transformer.resblocks.23.attn.in_proj_weight",
+ "cond_stage_model.model.transformer.resblocks.23.attn.out_proj.bias",
+ "cond_stage_model.model.transformer.resblocks.23.attn.out_proj.weight",
+ "cond_stage_model.model.transformer.resblocks.23.ln_1.bias",
+ "cond_stage_model.model.transformer.resblocks.23.ln_1.weight",
+ "cond_stage_model.model.transformer.resblocks.23.ln_2.bias",
+ "cond_stage_model.model.transformer.resblocks.23.ln_2.weight",
+ "cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.bias",
+ "cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.weight",
+ "cond_stage_model.model.transformer.resblocks.23.mlp.c_proj.bias",
+ "cond_stage_model.model.transformer.resblocks.23.mlp.c_proj.weight",
+ "cond_stage_model.model.text_projection",
+]
+
+# To support legacy scheduler_type argument
+SCHEDULER_DEFAULT_CONFIG = {
+ "beta_schedule": "scaled_linear",
+ "beta_start": 0.00085,
+ "beta_end": 0.012,
+ "interpolation_type": "linear",
+ "num_train_timesteps": 1000,
+ "prediction_type": "epsilon",
+ "sample_max_value": 1.0,
+ "set_alpha_to_one": False,
+ "skip_prk_steps": True,
+ "steps_offset": 1,
+ "timestep_spacing": "leading",
+}
+
+LDM_VAE_KEYS = ["first_stage_model.", "vae."]
+LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215
+PLAYGROUND_VAE_SCALING_FACTOR = 0.5
+LDM_UNET_KEY = "model.diffusion_model."
+LDM_CONTROLNET_KEY = "control_model."
+LDM_CLIP_PREFIX_TO_REMOVE = [
+ "cond_stage_model.transformer.",
+ "conditioner.embedders.0.transformer.",
+]
+LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024
+SCHEDULER_LEGACY_KWARGS = ["prediction_type", "scheduler_type"]
+
+VALID_URL_PREFIXES = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]
+
+
+class SingleFileComponentError(Exception):
+ def __init__(self, message=None):
+ self.message = message
+ super().__init__(self.message)
+
+
+def is_valid_url(url):
+ result = urlparse(url)
+ if result.scheme and result.netloc:
+ return True
+
+ return False
+
+
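+# Parse a Hugging Face Hub file URL, e.g. "https://huggingface.co/org/repo/blob/main/model.safetensors", into (repo_id, weights_name)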
+def _extract_repo_id_and_weights_name(pretrained_model_name_or_path):
+ if not is_valid_url(pretrained_model_name_or_path):
+ raise ValueError("Invalid `pretrained_model_name_or_path` provided. Please set it to a valid URL.")
+
+ pattern = r"([^/]+)/([^/]+)/(?:blob/main/)?(.+)"
+ weights_name = None
+ repo_id = None
+ for prefix in VALID_URL_PREFIXES:
+ pretrained_model_name_or_path = pretrained_model_name_or_path.replace(prefix, "")
+ match = re.match(pattern, pretrained_model_name_or_path)
+ if not match:
+ logger.warning("Unable to identify the repo_id and weights_name from the provided URL.")
+ return repo_id, weights_name
+
+ repo_id = f"{match.group(1)}/{match.group(2)}"
+ weights_name = match.group(3)
+
+ return repo_id, weights_name
+
+
+def _is_model_weights_in_cached_folder(cached_folder, name):
+ pretrained_model_name_or_path = os.path.join(cached_folder, name)
+ weights_exist = False
+
+ for weights_name in [WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME]:
+ if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
+ weights_exist = True
+
+ return weights_exist
+
+
+def _is_legacy_scheduler_kwargs(kwargs):
+ return any(k in SCHEDULER_LEGACY_KWARGS for k in kwargs.keys())
+
+
+def load_single_file_checkpoint(
+ pretrained_model_link_or_path,
+ force_download=False,
+ proxies=None,
+ token=None,
+ cache_dir=None,
+ local_files_only=None,
+ revision=None,
+ disable_mmap=False,
+ user_agent=None,
+):
+ if user_agent is None:
+ user_agent = {"file_type": "single_file", "framework": "pytorch"}
+
+ if os.path.isfile(pretrained_model_link_or_path):
+ pretrained_model_link_or_path = pretrained_model_link_or_path
+
+ else:
+ repo_id, weights_name = _extract_repo_id_and_weights_name(pretrained_model_link_or_path)
+ pretrained_model_link_or_path = _get_model_file(
+ repo_id,
+ weights_name=weights_name,
+ force_download=force_download,
+ cache_dir=cache_dir,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ user_agent=user_agent,
+ )
+
+ checkpoint = load_state_dict(pretrained_model_link_or_path, disable_mmap=disable_mmap)
+
+ # some checkpoints contain the model state dict under a "state_dict" key
+ while "state_dict" in checkpoint:
+ checkpoint = checkpoint["state_dict"]
+
+ return checkpoint
+
+
+def fetch_original_config(original_config_file, local_files_only=False):
+ if os.path.isfile(original_config_file):
+ with open(original_config_file, "r") as fp:
+ original_config_file = fp.read()
+
+ elif is_valid_url(original_config_file):
+ if local_files_only:
+ raise ValueError(
+ "`local_files_only` is set to True, but a URL was provided as `original_config_file`. "
+ "Please provide a valid local file path."
+ )
+
+ original_config_file = BytesIO(requests.get(original_config_file, timeout=DIFFUSERS_REQUEST_TIMEOUT).content)
+
+ else:
+ raise ValueError("Invalid `original_config_file` provided. Please set it to a valid file path or URL.")
+
+ original_config = yaml.safe_load(original_config_file)
+
+ return original_config
+
+
+def is_clip_model(checkpoint):
+ if CHECKPOINT_KEY_NAMES["clip"] in checkpoint:
+ return True
+
+ return False
+
+
+def is_clip_sdxl_model(checkpoint):
+ if CHECKPOINT_KEY_NAMES["clip_sdxl"] in checkpoint:
+ return True
+
+ return False
+
+
+def is_clip_sd3_model(checkpoint):
+ if CHECKPOINT_KEY_NAMES["clip_sd3"] in checkpoint:
+ return True
+
+ return False
+
+
+def is_open_clip_model(checkpoint):
+ if CHECKPOINT_KEY_NAMES["open_clip"] in checkpoint:
+ return True
+
+ return False
+
+
+def is_open_clip_sdxl_model(checkpoint):
+ if CHECKPOINT_KEY_NAMES["open_clip_sdxl"] in checkpoint:
+ return True
+
+ return False
+
+
+def is_open_clip_sd3_model(checkpoint):
+ if CHECKPOINT_KEY_NAMES["open_clip_sd3"] in checkpoint:
+ return True
+
+ return False
+
+
+def is_open_clip_sdxl_refiner_model(checkpoint):
+ if CHECKPOINT_KEY_NAMES["open_clip_sdxl_refiner"] in checkpoint:
+ return True
+
+ return False
+
+
+def is_clip_model_in_single_file(class_obj, checkpoint):
+ is_clip_in_checkpoint = any(
+ [
+ is_clip_model(checkpoint),
+ is_clip_sd3_model(checkpoint),
+ is_open_clip_model(checkpoint),
+ is_open_clip_sdxl_model(checkpoint),
+ is_open_clip_sdxl_refiner_model(checkpoint),
+ is_open_clip_sd3_model(checkpoint),
+ ]
+ )
+ if (
+ class_obj.__name__ == "CLIPTextModel" or class_obj.__name__ == "CLIPTextModelWithProjection"
+ ) and is_clip_in_checkpoint:
+ return True
+
+ return False
+
+
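+# Heuristically infer the original model type from characteristic state dict keys and tensor shapes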
+def infer_diffusers_model_type(checkpoint):
+ if (
+ CHECKPOINT_KEY_NAMES["inpainting"] in checkpoint
+ and checkpoint[CHECKPOINT_KEY_NAMES["inpainting"]].shape[1] == 9
+ ):
+ if CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024:
+ model_type = "inpainting_v2"
+ elif CHECKPOINT_KEY_NAMES["xl_base"] in checkpoint:
+ model_type = "xl_inpaint"
+ else:
+ model_type = "inpainting"
+
+ elif CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024:
+ model_type = "v2"
+
+ elif CHECKPOINT_KEY_NAMES["playground-v2-5"] in checkpoint:
+ model_type = "playground-v2-5"
+
+ elif CHECKPOINT_KEY_NAMES["xl_base"] in checkpoint:
+ model_type = "xl_base"
+
+ elif CHECKPOINT_KEY_NAMES["xl_refiner"] in checkpoint:
+ model_type = "xl_refiner"
+
+ elif CHECKPOINT_KEY_NAMES["upscale"] in checkpoint:
+ model_type = "upscale"
+
+ elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["controlnet"]):
+ if CHECKPOINT_KEY_NAMES["controlnet_xl"] in checkpoint:
+ if CHECKPOINT_KEY_NAMES["controlnet_xl_large"] in checkpoint:
+ model_type = "controlnet_xl_large"
+ elif CHECKPOINT_KEY_NAMES["controlnet_xl_mid"] in checkpoint:
+ model_type = "controlnet_xl_mid"
+ else:
+ model_type = "controlnet_xl_small"
+ else:
+ model_type = "controlnet"
+
+ elif (
+ CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"] in checkpoint
+ and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"]].shape[0] == 1536
+ ):
+ model_type = "stable_cascade_stage_c_lite"
+
+ elif (
+ CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"] in checkpoint
+ and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"]].shape[0] == 2048
+ ):
+ model_type = "stable_cascade_stage_c"
+
+ elif (
+ CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"] in checkpoint
+ and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"]].shape[-1] == 576
+ ):
+ model_type = "stable_cascade_stage_b_lite"
+
+ elif (
+ CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"] in checkpoint
+ and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"]].shape[-1] == 640
+ ):
+ model_type = "stable_cascade_stage_b"
+
+ elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sd3"]) and any(
+ checkpoint[key].shape[-1] == 9216 if key in checkpoint else False for key in CHECKPOINT_KEY_NAMES["sd3"]
+ ):
+ if "model.diffusion_model.pos_embed" in checkpoint:
+ key = "model.diffusion_model.pos_embed"
+ else:
+ key = "pos_embed"
+
+ if checkpoint[key].shape[1] == 36864:
+ model_type = "sd3"
+ elif checkpoint[key].shape[1] == 147456:
+ model_type = "sd35_medium"
+
+ elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sd35_large"]):
+ model_type = "sd35_large"
+
+ elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint:
+ if CHECKPOINT_KEY_NAMES["animatediff_scribble"] in checkpoint:
+ model_type = "animatediff_scribble"
+
+ elif CHECKPOINT_KEY_NAMES["animatediff_rgb"] in checkpoint:
+ model_type = "animatediff_rgb"
+
+ elif CHECKPOINT_KEY_NAMES["animatediff_v2"] in checkpoint:
+ model_type = "animatediff_v2"
+
+ elif checkpoint[CHECKPOINT_KEY_NAMES["animatediff_sdxl_beta"]].shape[-1] == 320:
+ model_type = "animatediff_sdxl_beta"
+
+ elif checkpoint[CHECKPOINT_KEY_NAMES["animatediff"]].shape[1] == 24:
+ model_type = "animatediff_v1"
+
+ else:
+ model_type = "animatediff_v3"
+
+ elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux"]):
+ if any(
+ g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"]
+ ):
+ if "model.diffusion_model.img_in.weight" in checkpoint:
+ key = "model.diffusion_model.img_in.weight"
+ else:
+ key = "img_in.weight"
+
+ if checkpoint[key].shape[1] == 384:
+ model_type = "flux-fill"
+ elif checkpoint[key].shape[1] == 128:
+ model_type = "flux-depth"
+ else:
+ model_type = "flux-dev"
+ else:
+ model_type = "flux-schnell"
+
+ elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx-video"]):
+ if checkpoint["vae.encoder.conv_out.conv.weight"].shape[1] == 2048:
+ model_type = "ltx-video-0.9.5"
+ elif "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in checkpoint:
+ model_type = "ltx-video-0.9.1"
+ else:
+ model_type = "ltx-video"
+
+ elif CHECKPOINT_KEY_NAMES["autoencoder-dc"] in checkpoint:
+ encoder_key = "encoder.project_in.conv.conv.bias"
+ decoder_key = "decoder.project_in.main.conv.weight"
+
+ if CHECKPOINT_KEY_NAMES["autoencoder-dc-sana"] in checkpoint:
+ model_type = "autoencoder-dc-f32c32-sana"
+
+ elif checkpoint[encoder_key].shape[-1] == 64 and checkpoint[decoder_key].shape[1] == 32:
+ model_type = "autoencoder-dc-f32c32"
+
+ elif checkpoint[encoder_key].shape[-1] == 64 and checkpoint[decoder_key].shape[1] == 128:
+ model_type = "autoencoder-dc-f64c128"
+
+ else:
+ model_type = "autoencoder-dc-f128c512"
+
+ elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["mochi-1-preview"]):
+ model_type = "mochi-1-preview"
+
+ elif CHECKPOINT_KEY_NAMES["hunyuan-video"] in checkpoint:
+ model_type = "hunyuan-video"
+
+ elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["auraflow"]):
+ model_type = "auraflow"
+
+ elif (
+ CHECKPOINT_KEY_NAMES["instruct-pix2pix"] in checkpoint
+ and checkpoint[CHECKPOINT_KEY_NAMES["instruct-pix2pix"]].shape[1] == 8
+ ):
+ model_type = "instruct-pix2pix"
+
+ elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"]):
+ model_type = "lumina2"
+
+ elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sana"]):
+ model_type = "sana"
+
+ elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["wan"]):
+ if "model.diffusion_model.patch_embedding.weight" in checkpoint:
+ target_key = "model.diffusion_model.patch_embedding.weight"
+ else:
+ target_key = "patch_embedding.weight"
+
+ if checkpoint[target_key].shape[0] == 1536:
+ model_type = "wan-t2v-1.3B"
+ elif checkpoint[target_key].shape[0] == 5120 and checkpoint[target_key].shape[1] == 16:
+ model_type = "wan-t2v-14B"
+ else:
+ model_type = "wan-i2v-14B"
+ elif CHECKPOINT_KEY_NAMES["wan_vae"] in checkpoint:
+ # All Wan models use the same VAE so we can use the same default model repo to fetch the config
+ model_type = "wan-t2v-14B"
+ else:
+ model_type = "v1"
+
+ return model_type
+
+
+def fetch_diffusers_config(checkpoint):
+ model_type = infer_diffusers_model_type(checkpoint)
+ model_path = DIFFUSERS_DEFAULT_PIPELINE_PATHS[model_type]
+ model_path = copy.deepcopy(model_path)
+
+ return model_path
+
+
+def set_image_size(checkpoint, image_size=None):
+ if image_size:
+ return image_size
+
+ model_type = infer_diffusers_model_type(checkpoint)
+ image_size = DIFFUSERS_TO_LDM_DEFAULT_IMAGE_SIZE_MAP[model_type]
+
+ return image_size
+
+
+# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear
+def conv_attn_to_linear(checkpoint):
+ keys = list(checkpoint.keys())
+ attn_keys = ["query.weight", "key.weight", "value.weight"]
+ for key in keys:
+ if ".".join(key.split(".")[-2:]) in attn_keys:
+ if checkpoint[key].ndim > 2:
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
+ elif "proj_attn.weight" in key:
+ if checkpoint[key].ndim > 2:
+ checkpoint[key] = checkpoint[key][:, :, 0]
+
+
+def create_unet_diffusers_config_from_ldm(
+ original_config, checkpoint, image_size=None, upcast_attention=None, num_in_channels=None
+):
+ """
+ Creates a diffusers UNet config from the config of the original LDM model.
+ """
+ if image_size is not None:
+ deprecation_message = (
+ "Configuring UNet2DConditionModel with the `image_size` argument to `from_single_file`"
+ "is deprecated and will be ignored in future versions."
+ )
+ deprecate("image_size", "1.0.0", deprecation_message)
+
+ image_size = set_image_size(checkpoint, image_size=image_size)
+
+ if (
+ "unet_config" in original_config["model"]["params"]
+ and original_config["model"]["params"]["unet_config"] is not None
+ ):
+ unet_params = original_config["model"]["params"]["unet_config"]["params"]
+ else:
+ unet_params = original_config["model"]["params"]["network_config"]["params"]
+
+ if num_in_channels is not None:
+ deprecation_message = (
+ "Configuring UNet2DConditionModel with the `num_in_channels` argument to `from_single_file`"
+ "is deprecated and will be ignored in future versions."
+ )
+ deprecate("image_size", "1.0.0", deprecation_message)
+ in_channels = num_in_channels
+ else:
+ in_channels = unet_params["in_channels"]
+
+ vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
+ block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]]
+
+ down_block_types = []
+ resolution = 1
+ for i in range(len(block_out_channels)):
+ block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D"
+ down_block_types.append(block_type)
+ if i != len(block_out_channels) - 1:
+ resolution *= 2
+
+ up_block_types = []
+ for i in range(len(block_out_channels)):
+ block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D"
+ up_block_types.append(block_type)
+ resolution //= 2
+
+ if unet_params["transformer_depth"] is not None:
+ transformer_layers_per_block = (
+ unet_params["transformer_depth"]
+ if isinstance(unet_params["transformer_depth"], int)
+ else list(unet_params["transformer_depth"])
+ )
+ else:
+ transformer_layers_per_block = 1
+
+ vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1)
+
+ head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None
+ use_linear_projection = (
+ unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False
+ )
+ if use_linear_projection:
+ # stable diffusion 2-base-512 and 2-768
+ if head_dim is None:
+ head_dim_mult = unet_params["model_channels"] // unet_params["num_head_channels"]
+ head_dim = [head_dim_mult * c for c in list(unet_params["channel_mult"])]
+
+ class_embed_type = None
+ addition_embed_type = None
+ addition_time_embed_dim = None
+ projection_class_embeddings_input_dim = None
+ context_dim = None
+
+ if unet_params["context_dim"] is not None:
+ context_dim = (
+ unet_params["context_dim"]
+ if isinstance(unet_params["context_dim"], int)
+ else unet_params["context_dim"][0]
+ )
+
+ if "num_classes" in unet_params:
+ if unet_params["num_classes"] == "sequential":
+ if context_dim in [2048, 1280]:
+ # SDXL
+ addition_embed_type = "text_time"
+ addition_time_embed_dim = 256
+ else:
+ class_embed_type = "projection"
+ assert "adm_in_channels" in unet_params
+ projection_class_embeddings_input_dim = unet_params["adm_in_channels"]
+
+ config = {
+ "sample_size": image_size // vae_scale_factor,
+ "in_channels": in_channels,
+ "down_block_types": down_block_types,
+ "block_out_channels": block_out_channels,
+ "layers_per_block": unet_params["num_res_blocks"],
+ "cross_attention_dim": context_dim,
+ "attention_head_dim": head_dim,
+ "use_linear_projection": use_linear_projection,
+ "class_embed_type": class_embed_type,
+ "addition_embed_type": addition_embed_type,
+ "addition_time_embed_dim": addition_time_embed_dim,
+ "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
+ "transformer_layers_per_block": transformer_layers_per_block,
+ }
+
+ if upcast_attention is not None:
+ deprecation_message = (
+ "Configuring UNet2DConditionModel with the `upcast_attention` argument to `from_single_file`"
+ "is deprecated and will be ignored in future versions."
+ )
+ deprecate("image_size", "1.0.0", deprecation_message)
+ config["upcast_attention"] = upcast_attention
+
+ if "disable_self_attentions" in unet_params:
+ config["only_cross_attention"] = unet_params["disable_self_attentions"]
+
+ if "num_classes" in unet_params and isinstance(unet_params["num_classes"], int):
+ config["num_class_embeds"] = unet_params["num_classes"]
+
+ config["out_channels"] = unet_params["out_channels"]
+ config["up_block_types"] = up_block_types
+
+ return config
+
+
+def create_controlnet_diffusers_config_from_ldm(original_config, checkpoint, image_size=None, **kwargs):
+ if image_size is not None:
+ deprecation_message = (
+ "Configuring ControlNetModel with the `image_size` argument"
+ "is deprecated and will be ignored in future versions."
+ )
+ deprecate("image_size", "1.0.0", deprecation_message)
+
+ image_size = set_image_size(checkpoint, image_size=image_size)
+
+ unet_params = original_config["model"]["params"]["control_stage_config"]["params"]
+ diffusers_unet_config = create_unet_diffusers_config_from_ldm(original_config, checkpoint, image_size=image_size)
+
+ controlnet_config = {
+ "conditioning_channels": unet_params["hint_channels"],
+ "in_channels": diffusers_unet_config["in_channels"],
+ "down_block_types": diffusers_unet_config["down_block_types"],
+ "block_out_channels": diffusers_unet_config["block_out_channels"],
+ "layers_per_block": diffusers_unet_config["layers_per_block"],
+ "cross_attention_dim": diffusers_unet_config["cross_attention_dim"],
+ "attention_head_dim": diffusers_unet_config["attention_head_dim"],
+ "use_linear_projection": diffusers_unet_config["use_linear_projection"],
+ "class_embed_type": diffusers_unet_config["class_embed_type"],
+ "addition_embed_type": diffusers_unet_config["addition_embed_type"],
+ "addition_time_embed_dim": diffusers_unet_config["addition_time_embed_dim"],
+ "projection_class_embeddings_input_dim": diffusers_unet_config["projection_class_embeddings_input_dim"],
+ "transformer_layers_per_block": diffusers_unet_config["transformer_layers_per_block"],
+ }
+
+ return controlnet_config
+
+
+def create_vae_diffusers_config_from_ldm(original_config, checkpoint, image_size=None, scaling_factor=None):
+ """
+ Creates a diffusers VAE config from the config of the original LDM model.
+ """
+ if image_size is not None:
+ deprecation_message = (
+ "Configuring AutoencoderKL with the `image_size` argument"
+ "is deprecated and will be ignored in future versions."
+ )
+ deprecate("image_size", "1.0.0", deprecation_message)
+
+ image_size = set_image_size(checkpoint, image_size=image_size)
+
+ if "edm_mean" in checkpoint and "edm_std" in checkpoint:
+ latents_mean = checkpoint["edm_mean"]
+ latents_std = checkpoint["edm_std"]
+ else:
+ latents_mean = None
+ latents_std = None
+
+ vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
+ if (scaling_factor is None) and (latents_mean is not None) and (latents_std is not None):
+ scaling_factor = PLAYGROUND_VAE_SCALING_FACTOR
+
+ elif (scaling_factor is None) and ("scale_factor" in original_config["model"]["params"]):
+ scaling_factor = original_config["model"]["params"]["scale_factor"]
+
+ elif scaling_factor is None:
+ scaling_factor = LDM_VAE_DEFAULT_SCALING_FACTOR
+
+ block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
+ down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
+ up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
+
+ config = {
+ "sample_size": image_size,
+ "in_channels": vae_params["in_channels"],
+ "out_channels": vae_params["out_ch"],
+ "down_block_types": down_block_types,
+ "up_block_types": up_block_types,
+ "block_out_channels": block_out_channels,
+ "latent_channels": vae_params["z_channels"],
+ "layers_per_block": vae_params["num_res_blocks"],
+ "scaling_factor": scaling_factor,
+ }
+ if latents_mean is not None and latents_std is not None:
+ config.update({"latents_mean": latents_mean, "latents_std": latents_std})
+
+ return config
+
+
+def update_unet_resnet_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint, mapping=None):
+ for ldm_key in ldm_keys:
+ diffusers_key = (
+ ldm_key.replace("in_layers.0", "norm1")
+ .replace("in_layers.2", "conv1")
+ .replace("out_layers.0", "norm2")
+ .replace("out_layers.3", "conv2")
+ .replace("emb_layers.1", "time_emb_proj")
+ .replace("skip_connection", "conv_shortcut")
+ )
+ if mapping:
+ diffusers_key = diffusers_key.replace(mapping["old"], mapping["new"])
+ new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)
+
+
+def update_unet_attention_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint, mapping):
+ for ldm_key in ldm_keys:
+ diffusers_key = ldm_key.replace(mapping["old"], mapping["new"])
+ new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)
+
+
+def update_vae_resnet_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
+ for ldm_key in keys:
+ diffusers_key = ldm_key.replace(mapping["old"], mapping["new"]).replace("nin_shortcut", "conv_shortcut")
+ new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)
+
+
+def update_vae_attentions_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
+ for ldm_key in keys:
+ diffusers_key = (
+ ldm_key.replace(mapping["old"], mapping["new"])
+ .replace("norm.weight", "group_norm.weight")
+ .replace("norm.bias", "group_norm.bias")
+ .replace("q.weight", "to_q.weight")
+ .replace("q.bias", "to_q.bias")
+ .replace("k.weight", "to_k.weight")
+ .replace("k.bias", "to_k.bias")
+ .replace("v.weight", "to_v.weight")
+ .replace("v.bias", "to_v.bias")
+ .replace("proj_out.weight", "to_out.0.weight")
+ .replace("proj_out.bias", "to_out.0.bias")
+ )
+ new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)
+
+ # proj_attn.weight has to be converted from conv 1D to linear
+ shape = new_checkpoint[diffusers_key].shape
+
+ if len(shape) == 3:
+ new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0]
+ elif len(shape) == 4:
+ new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0, 0]
+
+
+def convert_stable_cascade_unet_single_file_to_diffusers(checkpoint, **kwargs):
+ is_stage_c = "clip_txt_mapper.weight" in checkpoint
+
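+ # both stages store fused attention projections in a single in_proj tensor; split it into to_q/to_k/to_v for diffusers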
+ if is_stage_c:
+ state_dict = {}
+ for key in checkpoint.keys():
+ if key.endswith("in_proj_weight"):
+ weights = checkpoint[key].chunk(3, 0)
+ state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
+ state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
+ state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
+ elif key.endswith("in_proj_bias"):
+ weights = checkpoint[key].chunk(3, 0)
+ state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
+ state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
+ state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
+ elif key.endswith("out_proj.weight"):
+ weights = checkpoint[key]
+ state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
+ elif key.endswith("out_proj.bias"):
+ weights = checkpoint[key]
+ state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
+ else:
+ state_dict[key] = checkpoint[key]
+ else:
+ state_dict = {}
+ for key in checkpoint.keys():
+ if key.endswith("in_proj_weight"):
+ weights = checkpoint[key].chunk(3, 0)
+ state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
+ state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
+ state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
+ elif key.endswith("in_proj_bias"):
+ weights = checkpoint[key].chunk(3, 0)
+ state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
+ state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
+ state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
+ elif key.endswith("out_proj.weight"):
+ weights = checkpoint[key]
+ state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
+ elif key.endswith("out_proj.bias"):
+ weights = checkpoint[key]
+ state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
+ # rename clip_mapper to clip_txt_pooled_mapper
+ elif key.endswith("clip_mapper.weight"):
+ weights = checkpoint[key]
+ state_dict[key.replace("clip_mapper.weight", "clip_txt_pooled_mapper.weight")] = weights
+ elif key.endswith("clip_mapper.bias"):
+ weights = checkpoint[key]
+ state_dict[key.replace("clip_mapper.bias", "clip_txt_pooled_mapper.bias")] = weights
+ else:
+ state_dict[key] = checkpoint[key]
+
+ return state_dict
+
+
+def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False, **kwargs):
+ """
+ Takes a state dict and a config, and returns a converted checkpoint.
+ """
+ # extract state_dict for UNet
+ unet_state_dict = {}
+ keys = list(checkpoint.keys())
+ unet_key = LDM_UNET_KEY
+
+ # at least 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
+ if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
+ logger.warning("Checkpoint has both EMA and non-EMA weights.")
+ logger.warning(
+ "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
+ " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
+ )
+ for key in keys:
+ if key.startswith("model.diffusion_model"):
+ flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.get(flat_ema_key)
+ else:
+ if sum(k.startswith("model_ema") for k in keys) > 100:
+ logger.warning(
+ "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
+ " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
+ )
+ for key in keys:
+ if key.startswith(unet_key):
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.get(key)
+
+ new_checkpoint = {}
+ ldm_unet_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["layers"]
+ for diffusers_key, ldm_key in ldm_unet_keys.items():
+ if ldm_key not in unet_state_dict:
+ continue
+ new_checkpoint[diffusers_key] = unet_state_dict[ldm_key]
+
+ if ("class_embed_type" in config) and (config["class_embed_type"] in ["timestep", "projection"]):
+ class_embed_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["class_embed_type"]
+ for diffusers_key, ldm_key in class_embed_keys.items():
+ new_checkpoint[diffusers_key] = unet_state_dict[ldm_key]
+
+ if ("addition_embed_type" in config) and (config["addition_embed_type"] == "text_time"):
+ addition_embed_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["addition_embed_type"]
+ for diffusers_key, ldm_key in addition_embed_keys.items():
+ new_checkpoint[diffusers_key] = unet_state_dict[ldm_key]
+
+ # Relevant to StableDiffusionUpscalePipeline
+ if "num_class_embeds" in config:
+ if (config["num_class_embeds"] is not None) and ("label_emb.weight" in unet_state_dict):
+ new_checkpoint["class_embedding.weight"] = unet_state_dict["label_emb.weight"]
+
+ # Retrieves the keys for the input blocks only
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
+ input_blocks = {
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
+ for layer_id in range(num_input_blocks)
+ }
+
+ # Retrieves the keys for the middle blocks only
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
+ middle_blocks = {
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
+ for layer_id in range(num_middle_blocks)
+ }
+
+ # Retrieves the keys for the output blocks only
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
+ output_blocks = {
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
+ for layer_id in range(num_output_blocks)
+ }
+
+ # Down blocks
+ for i in range(1, num_input_blocks):
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
+
+ resnets = [
+ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
+ ]
+ update_unet_resnet_ldm_to_diffusers(
+ resnets,
+ new_checkpoint,
+ unet_state_dict,
+ {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"},
+ )
+
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.get(
+ f"input_blocks.{i}.0.op.weight"
+ )
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.get(
+ f"input_blocks.{i}.0.op.bias"
+ )
+
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
+ if attentions:
+ update_unet_attention_ldm_to_diffusers(
+ attentions,
+ new_checkpoint,
+ unet_state_dict,
+ {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"},
+ )
+
+ # Mid blocks
+ for key in middle_blocks.keys():
+ diffusers_key = max(key - 1, 0)
+ if key % 2 == 0:
+ update_unet_resnet_ldm_to_diffusers(
+ middle_blocks[key],
+ new_checkpoint,
+ unet_state_dict,
+ mapping={"old": f"middle_block.{key}", "new": f"mid_block.resnets.{diffusers_key}"},
+ )
+ else:
+ update_unet_attention_ldm_to_diffusers(
+ middle_blocks[key],
+ new_checkpoint,
+ unet_state_dict,
+ mapping={"old": f"middle_block.{key}", "new": f"mid_block.attentions.{diffusers_key}"},
+ )
+
+ # Up Blocks
+ for i in range(num_output_blocks):
+ block_id = i // (config["layers_per_block"] + 1)
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
+
+ resnets = [
+ key for key in output_blocks[i] if f"output_blocks.{i}.0" in key and f"output_blocks.{i}.0.op" not in key
+ ]
+ update_unet_resnet_ldm_to_diffusers(
+ resnets,
+ new_checkpoint,
+ unet_state_dict,
+ {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"},
+ )
+
+ attentions = [
+ key for key in output_blocks[i] if f"output_blocks.{i}.1" in key and f"output_blocks.{i}.1.conv" not in key
+ ]
+ if attentions:
+ update_unet_attention_ldm_to_diffusers(
+ attentions,
+ new_checkpoint,
+ unet_state_dict,
+ {"old": f"output_blocks.{i}.1", "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}"},
+ )
+
+ if f"output_blocks.{i}.1.conv.weight" in unet_state_dict:
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
+ f"output_blocks.{i}.1.conv.weight"
+ ]
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
+ f"output_blocks.{i}.1.conv.bias"
+ ]
+ if f"output_blocks.{i}.2.conv.weight" in unet_state_dict:
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
+ f"output_blocks.{i}.2.conv.weight"
+ ]
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
+ f"output_blocks.{i}.2.conv.bias"
+ ]
+
+ return new_checkpoint
+
+
+def convert_controlnet_checkpoint(
+ checkpoint,
+ config,
+ **kwargs,
+):
+ # Return checkpoint if it's already been converted
+ if "time_embedding.linear_1.weight" in checkpoint:
+ return checkpoint
+ # Some controlnet ckpt files are distributed independently from the rest of the
+ # model components, e.g. https://huggingface.co/thibaud/controlnet-sd21/
+ if "time_embed.0.weight" in checkpoint:
+ controlnet_state_dict = checkpoint
+
+ else:
+ controlnet_state_dict = {}
+ keys = list(checkpoint.keys())
+ controlnet_key = LDM_CONTROLNET_KEY
+ for key in keys:
+ if key.startswith(controlnet_key):
+ controlnet_state_dict[key.replace(controlnet_key, "")] = checkpoint.get(key)
+
+ new_checkpoint = {}
+ ldm_controlnet_keys = DIFFUSERS_TO_LDM_MAPPING["controlnet"]["layers"]
+ for diffusers_key, ldm_key in ldm_controlnet_keys.items():
+ if ldm_key not in controlnet_state_dict:
+ continue
+ new_checkpoint[diffusers_key] = controlnet_state_dict[ldm_key]
+
+ # Retrieves the keys for the input blocks only
+ num_input_blocks = len(
+ {".".join(layer.split(".")[:2]) for layer in controlnet_state_dict if "input_blocks" in layer}
+ )
+ input_blocks = {
+ layer_id: [key for key in controlnet_state_dict if f"input_blocks.{layer_id}" in key]
+ for layer_id in range(num_input_blocks)
+ }
+
+ # Down blocks
+ for i in range(1, num_input_blocks):
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
+
+ resnets = [
+ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
+ ]
+ update_unet_resnet_ldm_to_diffusers(
+ resnets,
+ new_checkpoint,
+ controlnet_state_dict,
+ {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"},
+ )
+
+ if f"input_blocks.{i}.0.op.weight" in controlnet_state_dict:
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = controlnet_state_dict.get(
+ f"input_blocks.{i}.0.op.weight"
+ )
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = controlnet_state_dict.get(
+ f"input_blocks.{i}.0.op.bias"
+ )
+
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
+ if attentions:
+ update_unet_attention_ldm_to_diffusers(
+ attentions,
+ new_checkpoint,
+ controlnet_state_dict,
+ {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"},
+ )
+
+ # controlnet down blocks
+ for i in range(num_input_blocks):
+ new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = controlnet_state_dict.get(f"zero_convs.{i}.0.weight")
+ new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = controlnet_state_dict.get(f"zero_convs.{i}.0.bias")
+
+ # Retrieves the keys for the middle blocks only
+ num_middle_blocks = len(
+ {".".join(layer.split(".")[:2]) for layer in controlnet_state_dict if "middle_block" in layer}
+ )
+ middle_blocks = {
+ layer_id: [key for key in controlnet_state_dict if f"middle_block.{layer_id}" in key]
+ for layer_id in range(num_middle_blocks)
+ }
+
+ # Mid blocks
+ for key in middle_blocks.keys():
+ diffusers_key = max(key - 1, 0)
+ if key % 2 == 0:
+ update_unet_resnet_ldm_to_diffusers(
+ middle_blocks[key],
+ new_checkpoint,
+ controlnet_state_dict,
+ mapping={"old": f"middle_block.{key}", "new": f"mid_block.resnets.{diffusers_key}"},
+ )
+ else:
+ update_unet_attention_ldm_to_diffusers(
+ middle_blocks[key],
+ new_checkpoint,
+ controlnet_state_dict,
+ mapping={"old": f"middle_block.{key}", "new": f"mid_block.attentions.{diffusers_key}"},
+ )
+
+ # mid block
+ new_checkpoint["controlnet_mid_block.weight"] = controlnet_state_dict.get("middle_block_out.0.weight")
+ new_checkpoint["controlnet_mid_block.bias"] = controlnet_state_dict.get("middle_block_out.0.bias")
+
+ # controlnet cond embedding blocks
+ cond_embedding_blocks = {
+ ".".join(layer.split(".")[:2])
+ for layer in controlnet_state_dict
+ if "input_hint_block" in layer and ("input_hint_block.0" not in layer) and ("input_hint_block.14" not in layer)
+ }
+ num_cond_embedding_blocks = len(cond_embedding_blocks)
+
+ for idx in range(1, num_cond_embedding_blocks + 1):
+ diffusers_idx = idx - 1
+ cond_block_id = 2 * idx
+
+ new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.weight"] = controlnet_state_dict.get(
+ f"input_hint_block.{cond_block_id}.weight"
+ )
+ new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.bias"] = controlnet_state_dict.get(
+ f"input_hint_block.{cond_block_id}.bias"
+ )
+
+ return new_checkpoint
+
+
+def convert_ldm_vae_checkpoint(checkpoint, config):
+ # extract state dict for VAE
+ # remove the LDM_VAE_KEY prefix from the ldm checkpoint keys so that it is easier to map them to diffusers keys
+ vae_state_dict = {}
+ keys = list(checkpoint.keys())
+ vae_key = ""
+ for ldm_vae_key in LDM_VAE_KEYS:
+ if any(k.startswith(ldm_vae_key) for k in keys):
+ vae_key = ldm_vae_key
+
+ for key in keys:
+ if key.startswith(vae_key):
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
+
+ new_checkpoint = {}
+ vae_diffusers_ldm_map = DIFFUSERS_TO_LDM_MAPPING["vae"]
+ for diffusers_key, ldm_key in vae_diffusers_ldm_map.items():
+ if ldm_key not in vae_state_dict:
+ continue
+ new_checkpoint[diffusers_key] = vae_state_dict[ldm_key]
+
+ # Retrieves the keys for the encoder down blocks only
+ num_down_blocks = len(config["down_block_types"])
+ down_blocks = {
+ layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
+ }
+
+ for i in range(num_down_blocks):
+ resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
+ update_vae_resnet_ldm_to_diffusers(
+ resnets,
+ new_checkpoint,
+ vae_state_dict,
+ mapping={"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"},
+ )
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.get(
+ f"encoder.down.{i}.downsample.conv.weight"
+ )
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.get(
+ f"encoder.down.{i}.downsample.conv.bias"
+ )
+
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
+ num_mid_res_blocks = 2
+ for i in range(1, num_mid_res_blocks + 1):
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
+ update_vae_resnet_ldm_to_diffusers(
+ resnets,
+ new_checkpoint,
+ vae_state_dict,
+ mapping={"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"},
+ )
+
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
+ update_vae_attentions_ldm_to_diffusers(
+ mid_attentions, new_checkpoint, vae_state_dict, mapping={"old": "mid.attn_1", "new": "mid_block.attentions.0"}
+ )
+
+ # Retrieves the keys for the decoder up blocks only
+ num_up_blocks = len(config["up_block_types"])
+ up_blocks = {
+ layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
+ }
+
+ for i in range(num_up_blocks):
+ block_id = num_up_blocks - 1 - i
+ resnets = [
+ key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
+ ]
+ update_vae_resnet_ldm_to_diffusers(
+ resnets,
+ new_checkpoint,
+ vae_state_dict,
+ mapping={"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"},
+ )
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
+ f"decoder.up.{block_id}.upsample.conv.weight"
+ ]
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
+ f"decoder.up.{block_id}.upsample.conv.bias"
+ ]
+
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
+ num_mid_res_blocks = 2
+ for i in range(1, num_mid_res_blocks + 1):
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
+ update_vae_resnet_ldm_to_diffusers(
+ resnets,
+ new_checkpoint,
+ vae_state_dict,
+ mapping={"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"},
+ )
+
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
+ update_vae_attentions_ldm_to_diffusers(
+ mid_attentions, new_checkpoint, vae_state_dict, mapping={"old": "mid.attn_1", "new": "mid_block.attentions.0"}
+ )
+ conv_attn_to_linear(new_checkpoint)
+
+ return new_checkpoint
+
+
+def convert_ldm_clip_checkpoint(checkpoint, remove_prefix=None):
+ keys = list(checkpoint.keys())
+ text_model_dict = {}
+
+ remove_prefixes = []
+ remove_prefixes.extend(LDM_CLIP_PREFIX_TO_REMOVE)
+ if remove_prefix:
+ remove_prefixes.append(remove_prefix)
+
+ for key in keys:
+ for prefix in remove_prefixes:
+ if key.startswith(prefix):
+ diffusers_key = key.replace(prefix, "")
+ text_model_dict[diffusers_key] = checkpoint.get(key)
+
+ return text_model_dict
+
+
+def convert_open_clip_checkpoint(
+ text_model,
+ checkpoint,
+ prefix="cond_stage_model.model.",
+):
+ text_model_dict = {}
+ text_proj_key = prefix + "text_projection"
+
+ if text_proj_key in checkpoint:
+ text_proj_dim = int(checkpoint[text_proj_key].shape[0])
+ elif hasattr(text_model.config, "hidden_size"):
+ text_proj_dim = text_model.config.hidden_size
+ else:
+ text_proj_dim = LDM_OPEN_CLIP_TEXT_PROJECTION_DIM
+
+ keys = list(checkpoint.keys())
+ keys_to_ignore = SD_2_TEXT_ENCODER_KEYS_TO_IGNORE
+
+ openclip_diffusers_ldm_map = DIFFUSERS_TO_LDM_MAPPING["openclip"]["layers"]
+ for diffusers_key, ldm_key in openclip_diffusers_ldm_map.items():
+ ldm_key = prefix + ldm_key
+ if ldm_key not in checkpoint:
+ continue
+ if ldm_key in keys_to_ignore:
+ continue
+ if ldm_key.endswith("text_projection"):
+ text_model_dict[diffusers_key] = checkpoint[ldm_key].T.contiguous()
+ else:
+ text_model_dict[diffusers_key] = checkpoint[ldm_key]
+
+ for key in keys:
+ if key in keys_to_ignore:
+ continue
+
+ if not key.startswith(prefix + "transformer."):
+ continue
+
+ diffusers_key = key.replace(prefix + "transformer.", "")
+ transformer_diffusers_to_ldm_map = DIFFUSERS_TO_LDM_MAPPING["openclip"]["transformer"]
+ for new_key, old_key in transformer_diffusers_to_ldm_map.items():
+ diffusers_key = (
+ diffusers_key.replace(old_key, new_key).replace(".in_proj_weight", "").replace(".in_proj_bias", "")
+ )
+
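+ # OpenCLIP stores the attention q/k/v projections fused in a single
+ # in_proj_weight / in_proj_bias tensor; split it into three chunks of size
+ # text_proj_dim so it matches the separate q/k/v projections used on the
+ # diffusers side.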
+ if key.endswith(".in_proj_weight"):
+ weight_value = checkpoint.get(key)
+
+ text_model_dict[diffusers_key + ".q_proj.weight"] = weight_value[:text_proj_dim, :].clone().detach()
+ text_model_dict[diffusers_key + ".k_proj.weight"] = (
+ weight_value[text_proj_dim : text_proj_dim * 2, :].clone().detach()
+ )
+ text_model_dict[diffusers_key + ".v_proj.weight"] = weight_value[text_proj_dim * 2 :, :].clone().detach()
+
+ elif key.endswith(".in_proj_bias"):
+ weight_value = checkpoint.get(key)
+ text_model_dict[diffusers_key + ".q_proj.bias"] = weight_value[:text_proj_dim].clone().detach()
+ text_model_dict[diffusers_key + ".k_proj.bias"] = (
+ weight_value[text_proj_dim : text_proj_dim * 2].clone().detach()
+ )
+ text_model_dict[diffusers_key + ".v_proj.bias"] = weight_value[text_proj_dim * 2 :].clone().detach()
+ else:
+ text_model_dict[diffusers_key] = checkpoint.get(key)
+
+ return text_model_dict
+
+
+def create_diffusers_clip_model_from_ldm(
+ cls,
+ checkpoint,
+ subfolder="",
+ config=None,
+ torch_dtype=None,
+ local_files_only=None,
+ is_legacy_loading=False,
+):
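+ # Illustrative usage only (assumes `CLIPTextModel` from transformers as `cls`
+ # and a single-file SD state dict as `checkpoint`):
+ #   text_encoder = create_diffusers_clip_model_from_ldm(
+ #       CLIPTextModel, checkpoint, torch_dtype=torch.float16
+ #   )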
+ if config:
+ config = {"pretrained_model_name_or_path": config}
+ else:
+ config = fetch_diffusers_config(checkpoint)
+
+ # For backwards compatibility
+ # Older versions of `from_single_file` expected CLIP configs to be placed in their original transformers model repo
+ # in the cache_dir, rather than in a subfolder of the Diffusers model
+ if is_legacy_loading:
+ logger.warning(
+ (
+ "Detected legacy CLIP loading behavior. Please run `from_single_file` with `local_files_only=False once to update "
+ "the local cache directory with the necessary CLIP model config files. "
+ "Attempting to load CLIP model from legacy cache directory."
+ )
+ )
+
+ if is_clip_model(checkpoint) or is_clip_sdxl_model(checkpoint):
+ clip_config = "openai/clip-vit-large-patch14"
+ config["pretrained_model_name_or_path"] = clip_config
+ subfolder = ""
+
+ elif is_open_clip_model(checkpoint):
+ clip_config = "stabilityai/stable-diffusion-2"
+ config["pretrained_model_name_or_path"] = clip_config
+ subfolder = "text_encoder"
+
+ else:
+ clip_config = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
+ config["pretrained_model_name_or_path"] = clip_config
+ subfolder = ""
+
+ model_config = cls.config_class.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only)
+ ctx = init_empty_weights if is_accelerate_available() else nullcontext
+ with ctx():
+ model = cls(model_config)
+
+ position_embedding_dim = model.text_model.embeddings.position_embedding.weight.shape[-1]
+
+ if is_clip_model(checkpoint):
+ diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint)
+
+ elif (
+ is_clip_sdxl_model(checkpoint)
+ and checkpoint[CHECKPOINT_KEY_NAMES["clip_sdxl"]].shape[-1] == position_embedding_dim
+ ):
+ diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint)
+
+ elif (
+ is_clip_sd3_model(checkpoint)
+ and checkpoint[CHECKPOINT_KEY_NAMES["clip_sd3"]].shape[-1] == position_embedding_dim
+ ):
+ diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_l.transformer.")
+ diffusers_format_checkpoint["text_projection.weight"] = torch.eye(position_embedding_dim)
+
+ elif is_open_clip_model(checkpoint):
+ prefix = "cond_stage_model.model."
+ diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
+
+ elif (
+ is_open_clip_sdxl_model(checkpoint)
+ and checkpoint[CHECKPOINT_KEY_NAMES["open_clip_sdxl"]].shape[-1] == position_embedding_dim
+ ):
+ prefix = "conditioner.embedders.1.model."
+ diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
+
+ elif is_open_clip_sdxl_refiner_model(checkpoint):
+ prefix = "conditioner.embedders.0.model."
+ diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
+
+ elif (
+ is_open_clip_sd3_model(checkpoint)
+ and checkpoint[CHECKPOINT_KEY_NAMES["open_clip_sd3"]].shape[-1] == position_embedding_dim
+ ):
+ diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_g.transformer.")
+
+ else:
+ raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")
+
+ if is_accelerate_available():
+ load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
+ else:
+ model.load_state_dict(diffusers_format_checkpoint, strict=False)
+
+ if torch_dtype is not None:
+ model.to(torch_dtype)
+
+ model.eval()
+
+ return model
+
+
+def _legacy_load_scheduler(
+ cls,
+ checkpoint,
+ component_name,
+ original_config=None,
+ **kwargs,
+):
+ scheduler_type = kwargs.get("scheduler_type", None)
+ prediction_type = kwargs.get("prediction_type", None)
+
+ if scheduler_type is not None:
+ deprecation_message = (
+ "Please pass an instance of a Scheduler object directly to the `scheduler` argument in `from_single_file`\n\n"
+ "Example:\n\n"
+ "from diffusers import StableDiffusionPipeline, DDIMScheduler\n\n"
+ "scheduler = DDIMScheduler()\n"
+ "pipe = StableDiffusionPipeline.from_single_file(, scheduler=scheduler)\n"
+ )
+ deprecate("scheduler_type", "1.0.0", deprecation_message)
+
+ if prediction_type is not None:
+ deprecation_message = (
+ "Please configure an instance of a Scheduler with the appropriate `prediction_type` and "
+ "pass the object directly to the `scheduler` argument in `from_single_file`.\n\n"
+ "Example:\n\n"
+ "from diffusers import StableDiffusionPipeline, DDIMScheduler\n\n"
+ 'scheduler = DDIMScheduler(prediction_type="v_prediction")\n'
+ "pipe = StableDiffusionPipeline.from_single_file(, scheduler=scheduler)\n"
+ )
+ deprecate("prediction_type", "1.0.0", deprecation_message)
+
+ scheduler_config = SCHEDULER_DEFAULT_CONFIG
+ model_type = infer_diffusers_model_type(checkpoint=checkpoint)
+
+ global_step = checkpoint["global_step"] if "global_step" in checkpoint else None
+
+ if original_config:
+ num_train_timesteps = original_config["model"]["params"].get("timesteps", 1000)
+ else:
+ num_train_timesteps = 1000
+
+ scheduler_config["num_train_timesteps"] = num_train_timesteps
+
+ if model_type == "v2":
+ if prediction_type is None:
+ # NOTE: For Stable Diffusion 2 base, it is recommended to pass `prediction_type="epsilon"` explicitly,
+ # since the fallback below relies on a brittle global-step check
+ prediction_type = "epsilon" if global_step == 875000 else "v_prediction"
+
+ else:
+ prediction_type = prediction_type or "epsilon"
+
+ scheduler_config["prediction_type"] = prediction_type
+
+ if model_type in ["xl_base", "xl_refiner"]:
+ scheduler_type = "euler"
+ elif model_type == "playground":
+ scheduler_type = "edm_dpm_solver_multistep"
+ else:
+ if original_config:
+ beta_start = original_config["model"]["params"].get("linear_start")
+ beta_end = original_config["model"]["params"].get("linear_end")
+
+ else:
+ beta_start = 0.02
+ beta_end = 0.085
+
+ scheduler_config["beta_start"] = beta_start
+ scheduler_config["beta_end"] = beta_end
+ scheduler_config["beta_schedule"] = "scaled_linear"
+ scheduler_config["clip_sample"] = False
+ scheduler_config["set_alpha_to_one"] = False
+
+ # to deal with an edge case: the StableDiffusionUpscale pipeline has two schedulers
+ if component_name == "low_res_scheduler":
+ return cls.from_config(
+ {
+ "beta_end": 0.02,
+ "beta_schedule": "scaled_linear",
+ "beta_start": 0.0001,
+ "clip_sample": True,
+ "num_train_timesteps": 1000,
+ "prediction_type": "epsilon",
+ "trained_betas": None,
+ "variance_type": "fixed_small",
+ }
+ )
+
+ if scheduler_type is None:
+ return cls.from_config(scheduler_config)
+
+ elif scheduler_type == "pndm":
+ scheduler_config["skip_prk_steps"] = True
+ scheduler = PNDMScheduler.from_config(scheduler_config)
+
+ elif scheduler_type == "lms":
+ scheduler = LMSDiscreteScheduler.from_config(scheduler_config)
+
+ elif scheduler_type == "heun":
+ scheduler = HeunDiscreteScheduler.from_config(scheduler_config)
+
+ elif scheduler_type == "euler":
+ scheduler = EulerDiscreteScheduler.from_config(scheduler_config)
+
+ elif scheduler_type == "euler-ancestral":
+ scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler_config)
+
+ elif scheduler_type == "dpm":
+ scheduler = DPMSolverMultistepScheduler.from_config(scheduler_config)
+
+ elif scheduler_type == "ddim":
+ scheduler = DDIMScheduler.from_config(scheduler_config)
+
+ elif scheduler_type == "edm_dpm_solver_multistep":
+ scheduler_config = {
+ "algorithm_type": "dpmsolver++",
+ "dynamic_thresholding_ratio": 0.995,
+ "euler_at_final": False,
+ "final_sigmas_type": "zero",
+ "lower_order_final": True,
+ "num_train_timesteps": 1000,
+ "prediction_type": "epsilon",
+ "rho": 7.0,
+ "sample_max_value": 1.0,
+ "sigma_data": 0.5,
+ "sigma_max": 80.0,
+ "sigma_min": 0.002,
+ "solver_order": 2,
+ "solver_type": "midpoint",
+ "thresholding": False,
+ }
+ scheduler = EDMDPMSolverMultistepScheduler(**scheduler_config)
+
+ else:
+ raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
+
+ return scheduler
+
+
+def _legacy_load_clip_tokenizer(cls, checkpoint, config=None, local_files_only=False):
+ if config:
+ config = {"pretrained_model_name_or_path": config}
+ else:
+ config = fetch_diffusers_config(checkpoint)
+
+ if is_clip_model(checkpoint) or is_clip_sdxl_model(checkpoint):
+ clip_config = "openai/clip-vit-large-patch14"
+ config["pretrained_model_name_or_path"] = clip_config
+ subfolder = ""
+
+ elif is_open_clip_model(checkpoint):
+ clip_config = "stabilityai/stable-diffusion-2"
+ config["pretrained_model_name_or_path"] = clip_config
+ subfolder = "tokenizer"
+
+ else:
+ clip_config = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
+ config["pretrained_model_name_or_path"] = clip_config
+ subfolder = ""
+
+ tokenizer = cls.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only)
+
+ return tokenizer
+
+
+def _legacy_load_safety_checker(local_files_only, torch_dtype):
+ # Support for loading safety checker components using the deprecated
+ # `load_safety_checker` argument.
+
+ from ..pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+
+ feature_extractor = AutoImageProcessor.from_pretrained(
+ "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only, torch_dtype=torch_dtype
+ )
+ safety_checker = StableDiffusionSafetyChecker.from_pretrained(
+ "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only, torch_dtype=torch_dtype
+ )
+
+ return {"safety_checker": safety_checker, "feature_extractor": feature_extractor}
+
+
+# In the original SD3 implementation of AdaLayerNormContinuous, the linear projection output is split into shift, scale;
+# in diffusers it is split into scale, shift. Swap the linear projection weights here so the diffusers implementation can be used.
+def swap_scale_shift(weight, dim):
+ shift, scale = weight.chunk(2, dim=0)
+ new_weight = torch.cat([scale, shift], dim=0)
+ return new_weight
+
+
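+# Mochi fuses the feed-forward gate and projection halves into one tensor with the
+# projection half first; the diffusers layout expects the gate half first, so the
+# two halves are swapped (used for the Mochi MLP weights below).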
+def swap_proj_gate(weight):
+ proj, gate = weight.chunk(2, dim=0)
+ new_weight = torch.cat([gate, proj], dim=0)
+ return new_weight
+
+
+def get_attn2_layers(state_dict):
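+ # Some SD3 checkpoints (the dual-attention variants) add a second self-attention
+ # ("attn2") to a subset of joint blocks; collect the indices of those blocks so
+ # the converter can emit the matching dual-attention weights.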
+ attn2_layers = []
+ for key in state_dict.keys():
+ if "attn2." in key:
+ # Extract the layer number from the key
+ layer_num = int(key.split(".")[1])
+ attn2_layers.append(layer_num)
+
+ return tuple(sorted(set(attn2_layers)))
+
+
+def get_caption_projection_dim(state_dict):
+ caption_projection_dim = state_dict["context_embedder.weight"].shape[0]
+ return caption_projection_dim
+
+
+def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+ converted_state_dict = {}
+ keys = list(checkpoint.keys())
+ for k in keys:
+ if "model.diffusion_model." in k:
+ checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+
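+ # Infer the transformer depth from the highest joint_blocks index, and probe the
+ # checkpoint for dual-attention layers and qk-norm weights so the conversion below
+ # only emits keys that actually exist in the model.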
+ num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "joint_blocks" in k))[-1] + 1 # noqa: C401
+ dual_attention_layers = get_attn2_layers(checkpoint)
+
+ caption_projection_dim = get_caption_projection_dim(checkpoint)
+ has_qk_norm = any("ln_q" in key for key in checkpoint.keys())
+
+ # Positional and patch embeddings.
+ converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("pos_embed")
+ converted_state_dict["pos_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
+ converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
+
+ # Timestep embeddings.
+ converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop(
+ "t_embedder.mlp.0.weight"
+ )
+ converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
+ converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop(
+ "t_embedder.mlp.2.weight"
+ )
+ converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
+
+ # Context projections.
+ converted_state_dict["context_embedder.weight"] = checkpoint.pop("context_embedder.weight")
+ converted_state_dict["context_embedder.bias"] = checkpoint.pop("context_embedder.bias")
+
+ # Pooled context projection.
+ converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("y_embedder.mlp.0.weight")
+ converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("y_embedder.mlp.0.bias")
+ converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop("y_embedder.mlp.2.weight")
+ converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("y_embedder.mlp.2.bias")
+
+ # Transformer blocks 🎸.
+ for i in range(num_layers):
+ # Q, K, V
+ sample_q, sample_k, sample_v = torch.chunk(
+ checkpoint.pop(f"joint_blocks.{i}.x_block.attn.qkv.weight"), 3, dim=0
+ )
+ context_q, context_k, context_v = torch.chunk(
+ checkpoint.pop(f"joint_blocks.{i}.context_block.attn.qkv.weight"), 3, dim=0
+ )
+ sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
+ checkpoint.pop(f"joint_blocks.{i}.x_block.attn.qkv.bias"), 3, dim=0
+ )
+ context_q_bias, context_k_bias, context_v_bias = torch.chunk(
+ checkpoint.pop(f"joint_blocks.{i}.context_block.attn.qkv.bias"), 3, dim=0
+ )
+
+ converted_state_dict[f"transformer_blocks.{i}.attn.to_q.weight"] = torch.cat([sample_q])
+ converted_state_dict[f"transformer_blocks.{i}.attn.to_q.bias"] = torch.cat([sample_q_bias])
+ converted_state_dict[f"transformer_blocks.{i}.attn.to_k.weight"] = torch.cat([sample_k])
+ converted_state_dict[f"transformer_blocks.{i}.attn.to_k.bias"] = torch.cat([sample_k_bias])
+ converted_state_dict[f"transformer_blocks.{i}.attn.to_v.weight"] = torch.cat([sample_v])
+ converted_state_dict[f"transformer_blocks.{i}.attn.to_v.bias"] = torch.cat([sample_v_bias])
+
+ converted_state_dict[f"transformer_blocks.{i}.attn.add_q_proj.weight"] = torch.cat([context_q])
+ converted_state_dict[f"transformer_blocks.{i}.attn.add_q_proj.bias"] = torch.cat([context_q_bias])
+ converted_state_dict[f"transformer_blocks.{i}.attn.add_k_proj.weight"] = torch.cat([context_k])
+ converted_state_dict[f"transformer_blocks.{i}.attn.add_k_proj.bias"] = torch.cat([context_k_bias])
+ converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.weight"] = torch.cat([context_v])
+ converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.bias"] = torch.cat([context_v_bias])
+
+ # qk norm
+ if has_qk_norm:
+ converted_state_dict[f"transformer_blocks.{i}.attn.norm_q.weight"] = checkpoint.pop(
+ f"joint_blocks.{i}.x_block.attn.ln_q.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.attn.norm_k.weight"] = checkpoint.pop(
+ f"joint_blocks.{i}.x_block.attn.ln_k.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_q.weight"] = checkpoint.pop(
+ f"joint_blocks.{i}.context_block.attn.ln_q.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_k.weight"] = checkpoint.pop(
+ f"joint_blocks.{i}.context_block.attn.ln_k.weight"
+ )
+
+ # output projections.
+ converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.weight"] = checkpoint.pop(
+ f"joint_blocks.{i}.x_block.attn.proj.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.bias"] = checkpoint.pop(
+ f"joint_blocks.{i}.x_block.attn.proj.bias"
+ )
+ if not (i == num_layers - 1):
+ converted_state_dict[f"transformer_blocks.{i}.attn.to_add_out.weight"] = checkpoint.pop(
+ f"joint_blocks.{i}.context_block.attn.proj.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.attn.to_add_out.bias"] = checkpoint.pop(
+ f"joint_blocks.{i}.context_block.attn.proj.bias"
+ )
+
+ if i in dual_attention_layers:
+ # Q, K, V
+ sample_q2, sample_k2, sample_v2 = torch.chunk(
+ checkpoint.pop(f"joint_blocks.{i}.x_block.attn2.qkv.weight"), 3, dim=0
+ )
+ sample_q2_bias, sample_k2_bias, sample_v2_bias = torch.chunk(
+ checkpoint.pop(f"joint_blocks.{i}.x_block.attn2.qkv.bias"), 3, dim=0
+ )
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = torch.cat([sample_q2])
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = torch.cat([sample_q2_bias])
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = torch.cat([sample_k2])
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = torch.cat([sample_k2_bias])
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = torch.cat([sample_v2])
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = torch.cat([sample_v2_bias])
+
+ # qk norm
+ if has_qk_norm:
+ converted_state_dict[f"transformer_blocks.{i}.attn2.norm_q.weight"] = checkpoint.pop(
+ f"joint_blocks.{i}.x_block.attn2.ln_q.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.attn2.norm_k.weight"] = checkpoint.pop(
+ f"joint_blocks.{i}.x_block.attn2.ln_k.weight"
+ )
+
+ # output projections.
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = checkpoint.pop(
+ f"joint_blocks.{i}.x_block.attn2.proj.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = checkpoint.pop(
+ f"joint_blocks.{i}.x_block.attn2.proj.bias"
+ )
+
+ # norms.
+ converted_state_dict[f"transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop(
+ f"joint_blocks.{i}.x_block.adaLN_modulation.1.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.norm1.linear.bias"] = checkpoint.pop(
+ f"joint_blocks.{i}.x_block.adaLN_modulation.1.bias"
+ )
+ if not (i == num_layers - 1):
+ converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.weight"] = checkpoint.pop(
+ f"joint_blocks.{i}.context_block.adaLN_modulation.1.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.bias"] = checkpoint.pop(
+ f"joint_blocks.{i}.context_block.adaLN_modulation.1.bias"
+ )
+ else:
+ converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.weight"] = swap_scale_shift(
+ checkpoint.pop(f"joint_blocks.{i}.context_block.adaLN_modulation.1.weight"),
+ dim=caption_projection_dim,
+ )
+ converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.bias"] = swap_scale_shift(
+ checkpoint.pop(f"joint_blocks.{i}.context_block.adaLN_modulation.1.bias"),
+ dim=caption_projection_dim,
+ )
+
+ # ffs.
+ converted_state_dict[f"transformer_blocks.{i}.ff.net.0.proj.weight"] = checkpoint.pop(
+ f"joint_blocks.{i}.x_block.mlp.fc1.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.ff.net.0.proj.bias"] = checkpoint.pop(
+ f"joint_blocks.{i}.x_block.mlp.fc1.bias"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.ff.net.2.weight"] = checkpoint.pop(
+ f"joint_blocks.{i}.x_block.mlp.fc2.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.ff.net.2.bias"] = checkpoint.pop(
+ f"joint_blocks.{i}.x_block.mlp.fc2.bias"
+ )
+ if not (i == num_layers - 1):
+ converted_state_dict[f"transformer_blocks.{i}.ff_context.net.0.proj.weight"] = checkpoint.pop(
+ f"joint_blocks.{i}.context_block.mlp.fc1.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.ff_context.net.0.proj.bias"] = checkpoint.pop(
+ f"joint_blocks.{i}.context_block.mlp.fc1.bias"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.ff_context.net.2.weight"] = checkpoint.pop(
+ f"joint_blocks.{i}.context_block.mlp.fc2.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.ff_context.net.2.bias"] = checkpoint.pop(
+ f"joint_blocks.{i}.context_block.mlp.fc2.bias"
+ )
+
+ # Final blocks.
+ converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
+ converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
+ converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
+ checkpoint.pop("final_layer.adaLN_modulation.1.weight"), dim=caption_projection_dim
+ )
+ converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
+ checkpoint.pop("final_layer.adaLN_modulation.1.bias"), dim=caption_projection_dim
+ )
+
+ return converted_state_dict
+
+
+def is_t5_in_single_file(checkpoint):
+ if "text_encoders.t5xxl.transformer.shared.weight" in checkpoint:
+ return True
+
+ return False
+
+
+def convert_sd3_t5_checkpoint_to_diffusers(checkpoint):
+ keys = list(checkpoint.keys())
+ text_model_dict = {}
+
+ remove_prefixes = ["text_encoders.t5xxl.transformer."]
+
+ for key in keys:
+ for prefix in remove_prefixes:
+ if key.startswith(prefix):
+ diffusers_key = key.replace(prefix, "")
+ text_model_dict[diffusers_key] = checkpoint.get(key)
+
+ return text_model_dict
+
+
+def create_diffusers_t5_model_from_checkpoint(
+ cls,
+ checkpoint,
+ subfolder="",
+ config=None,
+ torch_dtype=None,
+ local_files_only=None,
+):
+ if config:
+ config = {"pretrained_model_name_or_path": config}
+ else:
+ config = fetch_diffusers_config(checkpoint)
+
+ model_config = cls.config_class.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only)
+ ctx = init_empty_weights if is_accelerate_available() else nullcontext
+ with ctx():
+ model = cls(model_config)
+
+ diffusers_format_checkpoint = convert_sd3_t5_checkpoint_to_diffusers(checkpoint)
+
+ if is_accelerate_available():
+ load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
+ else:
+ model.load_state_dict(diffusers_format_checkpoint)
+
+ use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (torch_dtype == torch.float16)
+ if use_keep_in_fp32_modules:
+ keep_in_fp32_modules = model._keep_in_fp32_modules
+ else:
+ keep_in_fp32_modules = []
+
+ if keep_in_fp32_modules is not None:
+ for name, param in model.named_parameters():
+ if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules):
+ # `param = param.to(torch.float32)` does not work here, as it only rebinds the local variable.
+ param.data = param.data.to(torch.float32)
+
+ return model
+
+
+def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
+ converted_state_dict = {}
+ for k, v in checkpoint.items():
+ if "pos_encoder" in k:
+ continue
+
+ else:
+ converted_state_dict[
+ k.replace(".norms.0", ".norm1")
+ .replace(".norms.1", ".norm2")
+ .replace(".ff_norm", ".norm3")
+ .replace(".attention_blocks.0", ".attn1")
+ .replace(".attention_blocks.1", ".attn2")
+ .replace(".temporal_transformer", "")
+ ] = v
+
+ return converted_state_dict
+
+
+def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+ converted_state_dict = {}
+ keys = list(checkpoint.keys())
+
+ for k in keys:
+ if "model.diffusion_model." in k:
+ checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+
+ num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1 # noqa: C401
+ num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1 # noqa: C401
+ mlp_ratio = 4.0
+ inner_dim = 3072
+
+ # In the original SD3 implementation of AdaLayerNormContinuous, the linear projection output is split into shift, scale;
+ # in diffusers it is split into scale, shift. Swap the linear projection weights here so the diffusers implementation can be used.
+ def swap_scale_shift(weight):
+ shift, scale = weight.chunk(2, dim=0)
+ new_weight = torch.cat([scale, shift], dim=0)
+ return new_weight
+
+ ## time_text_embed.timestep_embedder <- time_in
+ converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop(
+ "time_in.in_layer.weight"
+ )
+ converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("time_in.in_layer.bias")
+ converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop(
+ "time_in.out_layer.weight"
+ )
+ converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("time_in.out_layer.bias")
+
+ ## time_text_embed.text_embedder <- vector_in
+ converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("vector_in.in_layer.weight")
+ converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("vector_in.in_layer.bias")
+ converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop(
+ "vector_in.out_layer.weight"
+ )
+ converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("vector_in.out_layer.bias")
+
+ # guidance
+ has_guidance = any("guidance" in k for k in checkpoint)
+ if has_guidance:
+ converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = checkpoint.pop(
+ "guidance_in.in_layer.weight"
+ )
+ converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = checkpoint.pop(
+ "guidance_in.in_layer.bias"
+ )
+ converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = checkpoint.pop(
+ "guidance_in.out_layer.weight"
+ )
+ converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = checkpoint.pop(
+ "guidance_in.out_layer.bias"
+ )
+
+ # context_embedder
+ converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight")
+ converted_state_dict["context_embedder.bias"] = checkpoint.pop("txt_in.bias")
+
+ # x_embedder
+ converted_state_dict["x_embedder.weight"] = checkpoint.pop("img_in.weight")
+ converted_state_dict["x_embedder.bias"] = checkpoint.pop("img_in.bias")
+
+ # double transformer blocks
+ for i in range(num_layers):
+ block_prefix = f"transformer_blocks.{i}."
+ # norms.
+ ## norm1
+ converted_state_dict[f"{block_prefix}norm1.linear.weight"] = checkpoint.pop(
+ f"double_blocks.{i}.img_mod.lin.weight"
+ )
+ converted_state_dict[f"{block_prefix}norm1.linear.bias"] = checkpoint.pop(
+ f"double_blocks.{i}.img_mod.lin.bias"
+ )
+ ## norm1_context
+ converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = checkpoint.pop(
+ f"double_blocks.{i}.txt_mod.lin.weight"
+ )
+ converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = checkpoint.pop(
+ f"double_blocks.{i}.txt_mod.lin.bias"
+ )
+ # Q, K, V
+ sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0)
+ context_q, context_k, context_v = torch.chunk(
+ checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0
+ )
+ sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
+ checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0
+ )
+ context_q_bias, context_k_bias, context_v_bias = torch.chunk(
+ checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0
+ )
+ converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q])
+ converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias])
+ converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k])
+ converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias])
+ converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v])
+ converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias])
+ converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q])
+ converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias])
+ converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k])
+ converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias])
+ converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v])
+ converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias])
+ # qk_norm
+ converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
+ f"double_blocks.{i}.img_attn.norm.query_norm.scale"
+ )
+ converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
+ f"double_blocks.{i}.img_attn.norm.key_norm.scale"
+ )
+ converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = checkpoint.pop(
+ f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
+ )
+ converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = checkpoint.pop(
+ f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
+ )
+ # ff img_mlp
+ converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = checkpoint.pop(
+ f"double_blocks.{i}.img_mlp.0.weight"
+ )
+ converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.0.bias")
+ converted_state_dict[f"{block_prefix}ff.net.2.weight"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.weight")
+ converted_state_dict[f"{block_prefix}ff.net.2.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.bias")
+ converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = checkpoint.pop(
+ f"double_blocks.{i}.txt_mlp.0.weight"
+ )
+ converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = checkpoint.pop(
+ f"double_blocks.{i}.txt_mlp.0.bias"
+ )
+ converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = checkpoint.pop(
+ f"double_blocks.{i}.txt_mlp.2.weight"
+ )
+ converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = checkpoint.pop(
+ f"double_blocks.{i}.txt_mlp.2.bias"
+ )
+ # output projections.
+ converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = checkpoint.pop(
+ f"double_blocks.{i}.img_attn.proj.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = checkpoint.pop(
+ f"double_blocks.{i}.img_attn.proj.bias"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = checkpoint.pop(
+ f"double_blocks.{i}.txt_attn.proj.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = checkpoint.pop(
+ f"double_blocks.{i}.txt_attn.proj.bias"
+ )
+
+ # single transformer blocks
+ for i in range(num_single_layers):
+ block_prefix = f"single_transformer_blocks.{i}."
+ # norm.linear <- single_blocks.0.modulation.lin
+ converted_state_dict[f"{block_prefix}norm.linear.weight"] = checkpoint.pop(
+ f"single_blocks.{i}.modulation.lin.weight"
+ )
+ converted_state_dict[f"{block_prefix}norm.linear.bias"] = checkpoint.pop(
+ f"single_blocks.{i}.modulation.lin.bias"
+ )
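+ # Flux single blocks fuse q, k, v and the MLP up-projection into one linear1
+ # tensor; split it as (inner_dim, inner_dim, inner_dim, mlp_hidden_dim) so each
+ # piece can be assigned to its own diffusers parameter.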
+ # Q, K, V, mlp
+ mlp_hidden_dim = int(inner_dim * mlp_ratio)
+ split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
+ q, k, v, mlp = torch.split(checkpoint.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0)
+ q_bias, k_bias, v_bias, mlp_bias = torch.split(
+ checkpoint.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0
+ )
+ converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q])
+ converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias])
+ converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k])
+ converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias])
+ converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v])
+ converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias])
+ converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp])
+ converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias])
+ # qk norm
+ converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
+ f"single_blocks.{i}.norm.query_norm.scale"
+ )
+ converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
+ f"single_blocks.{i}.norm.key_norm.scale"
+ )
+ # output projections.
+ converted_state_dict[f"{block_prefix}proj_out.weight"] = checkpoint.pop(f"single_blocks.{i}.linear2.weight")
+ converted_state_dict[f"{block_prefix}proj_out.bias"] = checkpoint.pop(f"single_blocks.{i}.linear2.bias")
+
+ converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
+ converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
+ converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
+ checkpoint.pop("final_layer.adaLN_modulation.1.weight")
+ )
+ converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
+ checkpoint.pop("final_layer.adaLN_modulation.1.bias")
+ )
+
+ return converted_state_dict
+
+
+def convert_ltx_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+ converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "vae" not in key}
+
+ TRANSFORMER_KEYS_RENAME_DICT = {
+ "model.diffusion_model.": "",
+ "patchify_proj": "proj_in",
+ "adaln_single": "time_embed",
+ "q_norm": "norm_q",
+ "k_norm": "norm_k",
+ }
+
+ TRANSFORMER_SPECIAL_KEYS_REMAP = {}
+
+ for key in list(converted_state_dict.keys()):
+ new_key = key
+ for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+ new_key = new_key.replace(replace_key, rename_key)
+ converted_state_dict[new_key] = converted_state_dict.pop(key)
+
+ for key in list(converted_state_dict.keys()):
+ for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+ if special_key not in key:
+ continue
+ handler_fn_inplace(key, converted_state_dict)
+
+ return converted_state_dict
+
+
+def convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs):
+ converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "vae." in key}
+
+ def remove_keys_(key: str, state_dict):
+ state_dict.pop(key)
+
+ VAE_KEYS_RENAME_DICT = {
+ # common
+ "vae.": "",
+ # decoder
+ "up_blocks.0": "mid_block",
+ "up_blocks.1": "up_blocks.0",
+ "up_blocks.2": "up_blocks.1.upsamplers.0",
+ "up_blocks.3": "up_blocks.1",
+ "up_blocks.4": "up_blocks.2.conv_in",
+ "up_blocks.5": "up_blocks.2.upsamplers.0",
+ "up_blocks.6": "up_blocks.2",
+ "up_blocks.7": "up_blocks.3.conv_in",
+ "up_blocks.8": "up_blocks.3.upsamplers.0",
+ "up_blocks.9": "up_blocks.3",
+ # encoder
+ "down_blocks.0": "down_blocks.0",
+ "down_blocks.1": "down_blocks.0.downsamplers.0",
+ "down_blocks.2": "down_blocks.0.conv_out",
+ "down_blocks.3": "down_blocks.1",
+ "down_blocks.4": "down_blocks.1.downsamplers.0",
+ "down_blocks.5": "down_blocks.1.conv_out",
+ "down_blocks.6": "down_blocks.2",
+ "down_blocks.7": "down_blocks.2.downsamplers.0",
+ "down_blocks.8": "down_blocks.3",
+ "down_blocks.9": "mid_block",
+ # common
+ "conv_shortcut": "conv_shortcut.conv",
+ "res_blocks": "resnets",
+ "norm3.norm": "norm3",
+ "per_channel_statistics.mean-of-means": "latents_mean",
+ "per_channel_statistics.std-of-means": "latents_std",
+ }
+
+ VAE_091_RENAME_DICT = {
+ # decoder
+ "up_blocks.0": "mid_block",
+ "up_blocks.1": "up_blocks.0.upsamplers.0",
+ "up_blocks.2": "up_blocks.0",
+ "up_blocks.3": "up_blocks.1.upsamplers.0",
+ "up_blocks.4": "up_blocks.1",
+ "up_blocks.5": "up_blocks.2.upsamplers.0",
+ "up_blocks.6": "up_blocks.2",
+ "up_blocks.7": "up_blocks.3.upsamplers.0",
+ "up_blocks.8": "up_blocks.3",
+ # common
+ "last_time_embedder": "time_embedder",
+ "last_scale_shift_table": "scale_shift_table",
+ }
+
+ VAE_095_RENAME_DICT = {
+ # decoder
+ "up_blocks.0": "mid_block",
+ "up_blocks.1": "up_blocks.0.upsamplers.0",
+ "up_blocks.2": "up_blocks.0",
+ "up_blocks.3": "up_blocks.1.upsamplers.0",
+ "up_blocks.4": "up_blocks.1",
+ "up_blocks.5": "up_blocks.2.upsamplers.0",
+ "up_blocks.6": "up_blocks.2",
+ "up_blocks.7": "up_blocks.3.upsamplers.0",
+ "up_blocks.8": "up_blocks.3",
+ # encoder
+ "down_blocks.0": "down_blocks.0",
+ "down_blocks.1": "down_blocks.0.downsamplers.0",
+ "down_blocks.2": "down_blocks.1",
+ "down_blocks.3": "down_blocks.1.downsamplers.0",
+ "down_blocks.4": "down_blocks.2",
+ "down_blocks.5": "down_blocks.2.downsamplers.0",
+ "down_blocks.6": "down_blocks.3",
+ "down_blocks.7": "down_blocks.3.downsamplers.0",
+ "down_blocks.8": "mid_block",
+ # common
+ "last_time_embedder": "time_embedder",
+ "last_scale_shift_table": "scale_shift_table",
+ }
+
+ VAE_SPECIAL_KEYS_REMAP = {
+ "per_channel_statistics.channel": remove_keys_,
+ "per_channel_statistics.mean-of-means": remove_keys_,
+ "per_channel_statistics.mean-of-stds": remove_keys_,
+ }
+
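+ # Pick the rename table based on the checkpoint layout: the 0.9.5-style VAE is
+ # detected via the width of the encoder conv_out weight, the 0.9.1-style one via
+ # the presence of the extra decoder time embedder.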
+ if converted_state_dict["vae.encoder.conv_out.conv.weight"].shape[1] == 2048:
+ VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT)
+ elif "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict:
+ VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
+
+ for key in list(converted_state_dict.keys()):
+ new_key = key
+ for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
+ new_key = new_key.replace(replace_key, rename_key)
+ converted_state_dict[new_key] = converted_state_dict.pop(key)
+
+ for key in list(converted_state_dict.keys()):
+ for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
+ if special_key not in key:
+ continue
+ handler_fn_inplace(key, converted_state_dict)
+
+ return converted_state_dict
+
+
+def convert_autoencoder_dc_checkpoint_to_diffusers(checkpoint, **kwargs):
+ converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
+
+ def remap_qkv_(key: str, state_dict):
+ qkv = state_dict.pop(key)
+ q, k, v = torch.chunk(qkv, 3, dim=0)
+ parent_module, _, _ = key.rpartition(".qkv.conv.weight")
+ state_dict[f"{parent_module}.to_q.weight"] = q.squeeze()
+ state_dict[f"{parent_module}.to_k.weight"] = k.squeeze()
+ state_dict[f"{parent_module}.to_v.weight"] = v.squeeze()
+
+ def remap_proj_conv_(key: str, state_dict):
+ parent_module, _, _ = key.rpartition(".proj.conv.weight")
+ state_dict[f"{parent_module}.to_out.weight"] = state_dict.pop(key).squeeze()
+
+ AE_KEYS_RENAME_DICT = {
+ # common
+ "main.": "",
+ "op_list.": "",
+ "context_module": "attn",
+ "local_module": "conv_out",
+ # NOTE: The two mappings below work because the scales in the available configs only have a tuple length of 1.
+ # If there were more scales, there would be more layers, so a loop would be a better way to handle this.
+ "aggreg.0.0": "to_qkv_multiscale.0.proj_in",
+ "aggreg.0.1": "to_qkv_multiscale.0.proj_out",
+ "depth_conv.conv": "conv_depth",
+ "inverted_conv.conv": "conv_inverted",
+ "point_conv.conv": "conv_point",
+ "point_conv.norm": "norm",
+ "conv.conv.": "conv.",
+ "conv1.conv": "conv1",
+ "conv2.conv": "conv2",
+ "conv2.norm": "norm",
+ "proj.norm": "norm_out",
+ # encoder
+ "encoder.project_in.conv": "encoder.conv_in",
+ "encoder.project_out.0.conv": "encoder.conv_out",
+ "encoder.stages": "encoder.down_blocks",
+ # decoder
+ "decoder.project_in.conv": "decoder.conv_in",
+ "decoder.project_out.0": "decoder.norm_out",
+ "decoder.project_out.2.conv": "decoder.conv_out",
+ "decoder.stages": "decoder.up_blocks",
+ }
+
+ AE_F32C32_F64C128_F128C512_KEYS = {
+ "encoder.project_in.conv": "encoder.conv_in.conv",
+ "decoder.project_out.2.conv": "decoder.conv_out.conv",
+ }
+
+ AE_SPECIAL_KEYS_REMAP = {
+ "qkv.conv.weight": remap_qkv_,
+ "proj.conv.weight": remap_proj_conv_,
+ }
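+ # Checkpoints matching the f32c32 / f64c128 / f128c512 layouts have no bias on the
+ # input projection conv; for those, the conv-in / conv-out weights live one level
+ # deeper (under a `.conv` submodule) in the diffusers model, so extra renames apply.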
+ if "encoder.project_in.conv.bias" not in converted_state_dict:
+ AE_KEYS_RENAME_DICT.update(AE_F32C32_F64C128_F128C512_KEYS)
+
+ for key in list(converted_state_dict.keys()):
+ new_key = key[:]
+ for replace_key, rename_key in AE_KEYS_RENAME_DICT.items():
+ new_key = new_key.replace(replace_key, rename_key)
+ converted_state_dict[new_key] = converted_state_dict.pop(key)
+
+ for key in list(converted_state_dict.keys()):
+ for special_key, handler_fn_inplace in AE_SPECIAL_KEYS_REMAP.items():
+ if special_key not in key:
+ continue
+ handler_fn_inplace(key, converted_state_dict)
+
+ return converted_state_dict
+
+
+def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+ converted_state_dict = {}
+
+ # Comfy checkpoints add this prefix
+ keys = list(checkpoint.keys())
+ for k in keys:
+ if "model.diffusion_model." in k:
+ checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+
+ # Convert patch_embed
+ converted_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
+ converted_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
+
+ # Convert time_embed
+ converted_state_dict["time_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight")
+ converted_state_dict["time_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
+ converted_state_dict["time_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight")
+ converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
+ converted_state_dict["time_embed.pooler.to_kv.weight"] = checkpoint.pop("t5_y_embedder.to_kv.weight")
+ converted_state_dict["time_embed.pooler.to_kv.bias"] = checkpoint.pop("t5_y_embedder.to_kv.bias")
+ converted_state_dict["time_embed.pooler.to_q.weight"] = checkpoint.pop("t5_y_embedder.to_q.weight")
+ converted_state_dict["time_embed.pooler.to_q.bias"] = checkpoint.pop("t5_y_embedder.to_q.bias")
+ converted_state_dict["time_embed.pooler.to_out.weight"] = checkpoint.pop("t5_y_embedder.to_out.weight")
+ converted_state_dict["time_embed.pooler.to_out.bias"] = checkpoint.pop("t5_y_embedder.to_out.bias")
+ converted_state_dict["time_embed.caption_proj.weight"] = checkpoint.pop("t5_yproj.weight")
+ converted_state_dict["time_embed.caption_proj.bias"] = checkpoint.pop("t5_yproj.bias")
+
+ # Convert transformer blocks
+ num_layers = 48
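+ # Mochi's transformer has a fixed depth of 48 blocks; the final block treats the
+ # caption (y) stream differently, which is why several assignments below are
+ # guarded with `i < num_layers - 1`.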
+ for i in range(num_layers):
+ block_prefix = f"transformer_blocks.{i}."
+ old_prefix = f"blocks.{i}."
+
+ # norm1
+ converted_state_dict[block_prefix + "norm1.linear.weight"] = checkpoint.pop(old_prefix + "mod_x.weight")
+ converted_state_dict[block_prefix + "norm1.linear.bias"] = checkpoint.pop(old_prefix + "mod_x.bias")
+ if i < num_layers - 1:
+ converted_state_dict[block_prefix + "norm1_context.linear.weight"] = checkpoint.pop(
+ old_prefix + "mod_y.weight"
+ )
+ converted_state_dict[block_prefix + "norm1_context.linear.bias"] = checkpoint.pop(
+ old_prefix + "mod_y.bias"
+ )
+ else:
+ converted_state_dict[block_prefix + "norm1_context.linear_1.weight"] = checkpoint.pop(
+ old_prefix + "mod_y.weight"
+ )
+ converted_state_dict[block_prefix + "norm1_context.linear_1.bias"] = checkpoint.pop(
+ old_prefix + "mod_y.bias"
+ )
+
+ # Visual attention
+ qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_x.weight")
+ q, k, v = qkv_weight.chunk(3, dim=0)
+
+ converted_state_dict[block_prefix + "attn1.to_q.weight"] = q
+ converted_state_dict[block_prefix + "attn1.to_k.weight"] = k
+ converted_state_dict[block_prefix + "attn1.to_v.weight"] = v
+ converted_state_dict[block_prefix + "attn1.norm_q.weight"] = checkpoint.pop(
+ old_prefix + "attn.q_norm_x.weight"
+ )
+ converted_state_dict[block_prefix + "attn1.norm_k.weight"] = checkpoint.pop(
+ old_prefix + "attn.k_norm_x.weight"
+ )
+ converted_state_dict[block_prefix + "attn1.to_out.0.weight"] = checkpoint.pop(
+ old_prefix + "attn.proj_x.weight"
+ )
+ converted_state_dict[block_prefix + "attn1.to_out.0.bias"] = checkpoint.pop(old_prefix + "attn.proj_x.bias")
+
+ # Context attention
+ qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_y.weight")
+ q, k, v = qkv_weight.chunk(3, dim=0)
+
+ converted_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q
+ converted_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k
+ converted_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v
+ converted_state_dict[block_prefix + "attn1.norm_added_q.weight"] = checkpoint.pop(
+ old_prefix + "attn.q_norm_y.weight"
+ )
+ converted_state_dict[block_prefix + "attn1.norm_added_k.weight"] = checkpoint.pop(
+ old_prefix + "attn.k_norm_y.weight"
+ )
+ if i < num_layers - 1:
+ converted_state_dict[block_prefix + "attn1.to_add_out.weight"] = checkpoint.pop(
+ old_prefix + "attn.proj_y.weight"
+ )
+ converted_state_dict[block_prefix + "attn1.to_add_out.bias"] = checkpoint.pop(
+ old_prefix + "attn.proj_y.bias"
+ )
+
+ # MLP
+ converted_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate(
+ checkpoint.pop(old_prefix + "mlp_x.w1.weight")
+ )
+ converted_state_dict[block_prefix + "ff.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_x.w2.weight")
+ if i < num_layers - 1:
+ converted_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate(
+ checkpoint.pop(old_prefix + "mlp_y.w1.weight")
+ )
+ converted_state_dict[block_prefix + "ff_context.net.2.weight"] = checkpoint.pop(
+ old_prefix + "mlp_y.w2.weight"
+ )
+
+ # Output layers
+ converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(checkpoint.pop("final_layer.mod.weight"), dim=0)
+ converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(checkpoint.pop("final_layer.mod.bias"), dim=0)
+ converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
+ converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
+
+ converted_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies")
+
+ return converted_state_dict
+
+
+def convert_hunyuan_video_transformer_to_diffusers(checkpoint, **kwargs):
+ def remap_norm_scale_shift_(key, state_dict):
+ weight = state_dict.pop(key)
+ shift, scale = weight.chunk(2, dim=0)
+ new_weight = torch.cat([scale, shift], dim=0)
+ state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight
+
+ def remap_txt_in_(key, state_dict):
+ def rename_key(key):
+ new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
+ new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
+ new_key = new_key.replace("txt_in", "context_embedder")
+ new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
+ new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
+ new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
+ new_key = new_key.replace("mlp", "ff")
+ return new_key
+
+ if "self_attn_qkv" in key:
+ weight = state_dict.pop(key)
+ to_q, to_k, to_v = weight.chunk(3, dim=0)
+ state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
+ state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
+ state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
+ else:
+ state_dict[rename_key(key)] = state_dict.pop(key)
+
+ def remap_img_attn_qkv_(key, state_dict):
+ weight = state_dict.pop(key)
+ to_q, to_k, to_v = weight.chunk(3, dim=0)
+ state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
+ state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
+ state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v
+
+ def remap_txt_attn_qkv_(key, state_dict):
+ weight = state_dict.pop(key)
+ to_q, to_k, to_v = weight.chunk(3, dim=0)
+ state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
+ state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
+ state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v
+
+ def remap_single_transformer_blocks_(key, state_dict):
+ hidden_size = 3072
+
+ if "linear1.weight" in key:
+ linear1_weight = state_dict.pop(key)
+ split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
+ q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
+ new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.weight")
+ state_dict[f"{new_key}.attn.to_q.weight"] = q
+ state_dict[f"{new_key}.attn.to_k.weight"] = k
+ state_dict[f"{new_key}.attn.to_v.weight"] = v
+ state_dict[f"{new_key}.proj_mlp.weight"] = mlp
+
+ elif "linear1.bias" in key:
+ linear1_bias = state_dict.pop(key)
+ split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
+ q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
+ new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.bias")
+ state_dict[f"{new_key}.attn.to_q.bias"] = q_bias
+ state_dict[f"{new_key}.attn.to_k.bias"] = k_bias
+ state_dict[f"{new_key}.attn.to_v.bias"] = v_bias
+ state_dict[f"{new_key}.proj_mlp.bias"] = mlp_bias
+
+ else:
+ new_key = key.replace("single_blocks", "single_transformer_blocks")
+ new_key = new_key.replace("linear2", "proj_out")
+ new_key = new_key.replace("q_norm", "attn.norm_q")
+ new_key = new_key.replace("k_norm", "attn.norm_k")
+ state_dict[new_key] = state_dict.pop(key)
+
+ TRANSFORMER_KEYS_RENAME_DICT = {
+ "img_in": "x_embedder",
+ "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
+ "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
+ "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
+ "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
+ "vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
+ "vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
+ "double_blocks": "transformer_blocks",
+ "img_attn_q_norm": "attn.norm_q",
+ "img_attn_k_norm": "attn.norm_k",
+ "img_attn_proj": "attn.to_out.0",
+ "txt_attn_q_norm": "attn.norm_added_q",
+ "txt_attn_k_norm": "attn.norm_added_k",
+ "txt_attn_proj": "attn.to_add_out",
+ "img_mod.linear": "norm1.linear",
+ "img_norm1": "norm1.norm",
+ "img_norm2": "norm2",
+ "img_mlp": "ff",
+ "txt_mod.linear": "norm1_context.linear",
+ "txt_norm1": "norm1.norm",
+ "txt_norm2": "norm2_context",
+ "txt_mlp": "ff_context",
+ "self_attn_proj": "attn.to_out.0",
+ "modulation.linear": "norm.linear",
+ "pre_norm": "norm.norm",
+ "final_layer.norm_final": "norm_out.norm",
+ "final_layer.linear": "proj_out",
+ "fc1": "net.0.proj",
+ "fc2": "net.2",
+ "input_embedder": "proj_in",
+ }
+
+ TRANSFORMER_SPECIAL_KEYS_REMAP = {
+ "txt_in": remap_txt_in_,
+ "img_attn_qkv": remap_img_attn_qkv_,
+ "txt_attn_qkv": remap_txt_attn_qkv_,
+ "single_blocks": remap_single_transformer_blocks_,
+ "final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
+ }
+
+ def update_state_dict_(state_dict, old_key, new_key):
+ state_dict[new_key] = state_dict.pop(old_key)
+
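+ # The conversion runs in two passes: first apply the plain key renames above, then
+ # let the special handlers split the fused qkv / linear1 tensors and fix the
+ # final-layer scale/shift ordering in place.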
+ for key in list(checkpoint.keys()):
+ new_key = key[:]
+ for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+ new_key = new_key.replace(replace_key, rename_key)
+ update_state_dict_(checkpoint, key, new_key)
+
+ for key in list(checkpoint.keys()):
+ for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+ if special_key not in key:
+ continue
+ handler_fn_inplace(key, checkpoint)
+
+ return checkpoint
+
+
+def convert_auraflow_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+ converted_state_dict = {}
+ state_dict_keys = list(checkpoint.keys())
+
+ # Handle register tokens and positional embeddings
+ converted_state_dict["register_tokens"] = checkpoint.pop("register_tokens", None)
+
+ # Handle time step projection
+ converted_state_dict["time_step_proj.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight", None)
+ converted_state_dict["time_step_proj.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias", None)
+ converted_state_dict["time_step_proj.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight", None)
+ converted_state_dict["time_step_proj.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias", None)
+
+ # Handle context embedder
+ converted_state_dict["context_embedder.weight"] = checkpoint.pop("cond_seq_linear.weight", None)
+
+ # Calculate the number of layers
+ def calculate_layers(keys, key_prefix):
+ layers = set()
+ for k in keys:
+ if key_prefix in k:
+ layer_num = int(k.split(".")[1]) # get the layer number
+ layers.add(layer_num)
+ return len(layers)
+
+ mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers")
+ single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers")
+
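+ # AuraFlow splits its depth between joint (MMDiT-style) blocks and single-stream
+ # DiT blocks; the two loops below convert them separately.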
+ # MMDiT blocks
+ for i in range(mmdit_layers):
+ # Feed-forward
+ path_mapping = {"mlpX": "ff", "mlpC": "ff_context"}
+ weight_mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
+ for orig_k, diffuser_k in path_mapping.items():
+ for k, v in weight_mapping.items():
+ converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.{v}.weight"] = checkpoint.pop(
+ f"double_layers.{i}.{orig_k}.{k}.weight", None
+ )
+
+ # Norms
+ path_mapping = {"modX": "norm1", "modC": "norm1_context"}
+ for orig_k, diffuser_k in path_mapping.items():
+ converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.linear.weight"] = checkpoint.pop(
+ f"double_layers.{i}.{orig_k}.1.weight", None
+ )
+
+ # Attentions
+ x_attn_mapping = {"w2q": "to_q", "w2k": "to_k", "w2v": "to_v", "w2o": "to_out.0"}
+ context_attn_mapping = {"w1q": "add_q_proj", "w1k": "add_k_proj", "w1v": "add_v_proj", "w1o": "to_add_out"}
+ for attn_mapping in [x_attn_mapping, context_attn_mapping]:
+ for k, v in attn_mapping.items():
+ converted_state_dict[f"joint_transformer_blocks.{i}.attn.{v}.weight"] = checkpoint.pop(
+ f"double_layers.{i}.attn.{k}.weight", None
+ )
+
+ # Single-DiT blocks
+ for i in range(single_dit_layers):
+ # Feed-forward
+ mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
+ for k, v in mapping.items():
+ converted_state_dict[f"single_transformer_blocks.{i}.ff.{v}.weight"] = checkpoint.pop(
+ f"single_layers.{i}.mlp.{k}.weight", None
+ )
+
+ # Norms
+ converted_state_dict[f"single_transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop(
+ f"single_layers.{i}.modCX.1.weight", None
+ )
+
+ # Attentions
+ x_attn_mapping = {"w1q": "to_q", "w1k": "to_k", "w1v": "to_v", "w1o": "to_out.0"}
+ for k, v in x_attn_mapping.items():
+ converted_state_dict[f"single_transformer_blocks.{i}.attn.{v}.weight"] = checkpoint.pop(
+ f"single_layers.{i}.attn.{k}.weight", None
+ )
+ # Final blocks
+ converted_state_dict["proj_out.weight"] = checkpoint.pop("final_linear.weight", None)
+
+ # Handle the final norm layer
+ norm_weight = checkpoint.pop("modF.1.weight", None)
+ if norm_weight is not None:
+ converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(norm_weight, dim=None)
+ else:
+ converted_state_dict["norm_out.linear.weight"] = None
+
+ converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("positional_encoding")
+ converted_state_dict["pos_embed.proj.weight"] = checkpoint.pop("init_x_linear.weight")
+ converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("init_x_linear.bias")
+
+ return converted_state_dict
+
+
+def convert_lumina2_to_diffusers(checkpoint, **kwargs):
+ converted_state_dict = {}
+
+ # The original Lumina-Image-2 checkpoint has an extra norm parameter that is unused,
+ # so we simply remove it here
+ checkpoint.pop("norm_final.weight", None)
+
+ # Comfy checkpoints add this prefix
+ keys = list(checkpoint.keys())
+ for k in keys:
+ if "model.diffusion_model." in k:
+ checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+
+ LUMINA_KEY_MAP = {
+ "cap_embedder": "time_caption_embed.caption_embedder",
+ "t_embedder.mlp.0": "time_caption_embed.timestep_embedder.linear_1",
+ "t_embedder.mlp.2": "time_caption_embed.timestep_embedder.linear_2",
+ "attention": "attn",
+ ".out.": ".to_out.0.",
+ "k_norm": "norm_k",
+ "q_norm": "norm_q",
+ "w1": "linear_1",
+ "w2": "linear_2",
+ "w3": "linear_3",
+ "adaLN_modulation.1": "norm1.linear",
+ }
+ ATTENTION_NORM_MAP = {
+ "attention_norm1": "norm1.norm",
+ "attention_norm2": "norm2",
+ }
+ CONTEXT_REFINER_MAP = {
+ "context_refiner.0.attention_norm1": "context_refiner.0.norm1",
+ "context_refiner.0.attention_norm2": "context_refiner.0.norm2",
+ "context_refiner.1.attention_norm1": "context_refiner.1.norm1",
+ "context_refiner.1.attention_norm2": "context_refiner.1.norm2",
+ }
+ FINAL_LAYER_MAP = {
+ "final_layer.adaLN_modulation.1": "norm_out.linear_1",
+ "final_layer.linear": "norm_out.linear_2",
+ }
+
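+ # Lumina2 fuses q, k and v into one projection whose query slice (2304) is wider
+ # than the key and value slices (768 each), so the tensor is split with explicit
+ # sizes rather than into equal thirds.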
+ def convert_lumina_attn_to_diffusers(tensor, diffusers_key):
+ q_dim = 2304
+ k_dim = v_dim = 768
+
+ to_q, to_k, to_v = torch.split(tensor, [q_dim, k_dim, v_dim], dim=0)
+
+ return {
+ diffusers_key.replace("qkv", "to_q"): to_q,
+ diffusers_key.replace("qkv", "to_k"): to_k,
+ diffusers_key.replace("qkv", "to_v"): to_v,
+ }
+
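+
+ # Apply the more specific rename maps before the general LUMINA_KEY_MAP, then split any fused qkv weight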
+ for key in list(checkpoint.keys()):  # re-list keys since Comfy-prefixed entries may have been renamed above
+ diffusers_key = key
+ for k, v in CONTEXT_REFINER_MAP.items():
+ diffusers_key = diffusers_key.replace(k, v)
+ for k, v in FINAL_LAYER_MAP.items():
+ diffusers_key = diffusers_key.replace(k, v)
+ for k, v in ATTENTION_NORM_MAP.items():
+ diffusers_key = diffusers_key.replace(k, v)
+ for k, v in LUMINA_KEY_MAP.items():
+ diffusers_key = diffusers_key.replace(k, v)
+
+ if "qkv" in diffusers_key:
+ converted_state_dict.update(convert_lumina_attn_to_diffusers(checkpoint.pop(key), diffusers_key))
+ else:
+ converted_state_dict[diffusers_key] = checkpoint.pop(key)
+
+ return converted_state_dict
+
+
+def convert_sana_transformer_to_diffusers(checkpoint, **kwargs):
+ converted_state_dict = {}
+ keys = list(checkpoint.keys())
+ for k in keys:
+ if "model.diffusion_model." in k:
+ checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+
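+
+ # Infer the number of transformer blocks from the `blocks.{i}` indices present in the checkpoint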
+ num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "blocks" in k))[-1] + 1 # noqa: C401
+
+ # Positional and patch embeddings.
+ checkpoint.pop("pos_embed")
+ converted_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
+ converted_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
+
+ # Timestep embeddings.
+ converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = checkpoint.pop(
+ "t_embedder.mlp.0.weight"
+ )
+ converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
+ converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = checkpoint.pop(
+ "t_embedder.mlp.2.weight"
+ )
+ converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
+ converted_state_dict["time_embed.linear.weight"] = checkpoint.pop("t_block.1.weight")
+ converted_state_dict["time_embed.linear.bias"] = checkpoint.pop("t_block.1.bias")
+
+ # Caption Projection.
+ checkpoint.pop("y_embedder.y_embedding")
+ converted_state_dict["caption_projection.linear_1.weight"] = checkpoint.pop("y_embedder.y_proj.fc1.weight")
+ converted_state_dict["caption_projection.linear_1.bias"] = checkpoint.pop("y_embedder.y_proj.fc1.bias")
+ converted_state_dict["caption_projection.linear_2.weight"] = checkpoint.pop("y_embedder.y_proj.fc2.weight")
+ converted_state_dict["caption_projection.linear_2.bias"] = checkpoint.pop("y_embedder.y_proj.fc2.bias")
+ converted_state_dict["caption_norm.weight"] = checkpoint.pop("attention_y_norm.weight")
+
+ for i in range(num_layers):
+ converted_state_dict[f"transformer_blocks.{i}.scale_shift_table"] = checkpoint.pop(
+ f"blocks.{i}.scale_shift_table"
+ )
+
+ # Self-Attention
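+ # The original checkpoint fuses q, k and v into a single qkv projection; split it into three equal chunks along dim 0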
+ sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"blocks.{i}.attn.qkv.weight"), 3, dim=0)
+ converted_state_dict[f"transformer_blocks.{i}.attn1.to_q.weight"] = torch.cat([sample_q])
+ converted_state_dict[f"transformer_blocks.{i}.attn1.to_k.weight"] = torch.cat([sample_k])
+ converted_state_dict[f"transformer_blocks.{i}.attn1.to_v.weight"] = torch.cat([sample_v])
+
+ # Output Projections
+ converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.weight"] = checkpoint.pop(
+ f"blocks.{i}.attn.proj.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.bias"] = checkpoint.pop(
+ f"blocks.{i}.attn.proj.bias"
+ )
+
+ # Cross-Attention
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = checkpoint.pop(
+ f"blocks.{i}.cross_attn.q_linear.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = checkpoint.pop(
+ f"blocks.{i}.cross_attn.q_linear.bias"
+ )
+
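+
+ # Key and value share a fused kv projection in the original checkpoint; split weights and biases in half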
+ linear_sample_k, linear_sample_v = torch.chunk(
+ checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.weight"), 2, dim=0
+ )
+ linear_sample_k_bias, linear_sample_v_bias = torch.chunk(
+ checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.bias"), 2, dim=0
+ )
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = linear_sample_k
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = linear_sample_v
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = linear_sample_k_bias
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = linear_sample_v_bias
+
+ # Output Projections
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = checkpoint.pop(
+ f"blocks.{i}.cross_attn.proj.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = checkpoint.pop(
+ f"blocks.{i}.cross_attn.proj.bias"
+ )
+
+ # MLP
+ converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.weight"] = checkpoint.pop(
+ f"blocks.{i}.mlp.inverted_conv.conv.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.bias"] = checkpoint.pop(
+ f"blocks.{i}.mlp.inverted_conv.conv.bias"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.weight"] = checkpoint.pop(
+ f"blocks.{i}.mlp.depth_conv.conv.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.bias"] = checkpoint.pop(
+ f"blocks.{i}.mlp.depth_conv.conv.bias"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.ff.conv_point.weight"] = checkpoint.pop(
+ f"blocks.{i}.mlp.point_conv.conv.weight"
+ )
+
+ # Final layer
+ converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
+ converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
+ converted_state_dict["scale_shift_table"] = checkpoint.pop("final_layer.scale_shift_table")
+
+ return converted_state_dict
+
+
+def convert_wan_transformer_to_diffusers(checkpoint, **kwargs):
+ converted_state_dict = {}
+
+ keys = list(checkpoint.keys())
+ for k in keys:
+ if "model.diffusion_model." in k:
+ checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+
+ TRANSFORMER_KEYS_RENAME_DICT = {
+ "time_embedding.0": "condition_embedder.time_embedder.linear_1",
+ "time_embedding.2": "condition_embedder.time_embedder.linear_2",
+ "text_embedding.0": "condition_embedder.text_embedder.linear_1",
+ "text_embedding.2": "condition_embedder.text_embedder.linear_2",
+ "time_projection.1": "condition_embedder.time_proj",
+ "cross_attn": "attn2",
+ "self_attn": "attn1",
+ ".o.": ".to_out.0.",
+ ".q.": ".to_q.",
+ ".k.": ".to_k.",
+ ".v.": ".to_v.",
+ ".k_img.": ".add_k_proj.",
+ ".v_img.": ".add_v_proj.",
+ ".norm_k_img.": ".norm_added_k.",
+ "head.modulation": "scale_shift_table",
+ "head.head": "proj_out",
+ "modulation": "scale_shift_table",
+ "ffn.0": "ffn.net.0.proj",
+ "ffn.2": "ffn.net.2",
+ # Hack to swap the layer names
+ # The original model calls the norms in following order: norm1, norm3, norm2
+ # We convert it to: norm1, norm2, norm3
+ "norm2": "norm__placeholder",
+ "norm3": "norm2",
+ "norm__placeholder": "norm3",
+ # For the I2V model
+ "img_emb.proj.0": "condition_embedder.image_embedder.norm1",
+ "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
+ "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
+ "img_emb.proj.4": "condition_embedder.image_embedder.norm2",
+ }
+
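+
+ # Apply the renames as ordered substring replacements on every key in the checkpoint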
+ for key in list(checkpoint.keys()):
+ new_key = key[:]
+ for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+ new_key = new_key.replace(replace_key, rename_key)
+
+ converted_state_dict[new_key] = checkpoint.pop(key)
+
+ return converted_state_dict
+
+
+def convert_wan_vae_to_diffusers(checkpoint, **kwargs):
+ converted_state_dict = {}
+
+ # Create mappings for specific components
+ middle_key_mapping = {
+ # Encoder middle block
+ "encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma",
+ "encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias",
+ "encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight",
+ "encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma",
+ "encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias",
+ "encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight",
+ "encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma",
+ "encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias",
+ "encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight",
+ "encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma",
+ "encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias",
+ "encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight",
+ # Decoder middle block
+ "decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma",
+ "decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias",
+ "decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight",
+ "decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma",
+ "decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias",
+ "decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight",
+ "decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma",
+ "decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias",
+ "decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight",
+ "decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma",
+ "decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias",
+ "decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight",
+ }
+
+ # Create a mapping for attention blocks
+ attention_mapping = {
+ # Encoder middle attention
+ "encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma",
+ "encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight",
+ "encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias",
+ "encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight",
+ "encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias",
+ # Decoder middle attention
+ "decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma",
+ "decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight",
+ "decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias",
+ "decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight",
+ "decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias",
+ }
+
+ # Create a mapping for the head components
+ head_mapping = {
+ # Encoder head
+ "encoder.head.0.gamma": "encoder.norm_out.gamma",
+ "encoder.head.2.bias": "encoder.conv_out.bias",
+ "encoder.head.2.weight": "encoder.conv_out.weight",
+ # Decoder head
+ "decoder.head.0.gamma": "decoder.norm_out.gamma",
+ "decoder.head.2.bias": "decoder.conv_out.bias",
+ "decoder.head.2.weight": "decoder.conv_out.weight",
+ }
+
+ # Create a mapping for the quant components
+ quant_mapping = {
+ "conv1.weight": "quant_conv.weight",
+ "conv1.bias": "quant_conv.bias",
+ "conv2.weight": "post_quant_conv.weight",
+ "conv2.bias": "post_quant_conv.bias",
+ }
+
+ # Process each key in the state dict
+ for key, value in checkpoint.items():
+ # Handle middle block keys using the mapping
+ if key in middle_key_mapping:
+ new_key = middle_key_mapping[key]
+ converted_state_dict[new_key] = value
+ # Handle attention blocks using the mapping
+ elif key in attention_mapping:
+ new_key = attention_mapping[key]
+ converted_state_dict[new_key] = value
+ # Handle head keys using the mapping
+ elif key in head_mapping:
+ new_key = head_mapping[key]
+ converted_state_dict[new_key] = value
+ # Handle quant keys using the mapping
+ elif key in quant_mapping:
+ new_key = quant_mapping[key]
+ converted_state_dict[new_key] = value
+ # Handle encoder conv1
+ elif key == "encoder.conv1.weight":
+ converted_state_dict["encoder.conv_in.weight"] = value
+ elif key == "encoder.conv1.bias":
+ converted_state_dict["encoder.conv_in.bias"] = value
+ # Handle decoder conv1
+ elif key == "decoder.conv1.weight":
+ converted_state_dict["decoder.conv_in.weight"] = value
+ elif key == "decoder.conv1.bias":
+ converted_state_dict["decoder.conv_in.bias"] = value
+ # Handle encoder downsamples
+ elif key.startswith("encoder.downsamples."):
+ # Convert to down_blocks
+ new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.")
+
+ # Convert residual block naming but keep the original structure
+ if ".residual.0.gamma" in new_key:
+ new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
+ elif ".residual.2.bias" in new_key:
+ new_key = new_key.replace(".residual.2.bias", ".conv1.bias")
+ elif ".residual.2.weight" in new_key:
+ new_key = new_key.replace(".residual.2.weight", ".conv1.weight")
+ elif ".residual.3.gamma" in new_key:
+ new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma")
+ elif ".residual.6.bias" in new_key:
+ new_key = new_key.replace(".residual.6.bias", ".conv2.bias")
+ elif ".residual.6.weight" in new_key:
+ new_key = new_key.replace(".residual.6.weight", ".conv2.weight")
+ elif ".shortcut.bias" in new_key:
+ new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")
+ elif ".shortcut.weight" in new_key:
+ new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")
+
+ converted_state_dict[new_key] = value
+
+ # Handle decoder upsamples
+ elif key.startswith("decoder.upsamples."):
+ # Convert to up_blocks
+ parts = key.split(".")
+ block_idx = int(parts[2])
+
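+
+ # The original decoder keeps a flat `upsamples` list: indices 0-2, 4-6, 8-10 and 12-14 map to the
+ # resnets of up_blocks 0-3, index 4 additionally carries a shortcut conv, and indices 3, 7 and 11
+ # are the per-stage upsamplers handled below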
+ # Group residual blocks
+ if "residual" in key:
+ if block_idx in [0, 1, 2]:
+ new_block_idx = 0
+ resnet_idx = block_idx
+ elif block_idx in [4, 5, 6]:
+ new_block_idx = 1
+ resnet_idx = block_idx - 4
+ elif block_idx in [8, 9, 10]:
+ new_block_idx = 2
+ resnet_idx = block_idx - 8
+ elif block_idx in [12, 13, 14]:
+ new_block_idx = 3
+ resnet_idx = block_idx - 12
+ else:
+ # Keep as is for other blocks
+ converted_state_dict[key] = value
+ continue
+
+ # Convert residual block naming
+ if ".residual.0.gamma" in key:
+ new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm1.gamma"
+ elif ".residual.2.bias" in key:
+ new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.bias"
+ elif ".residual.2.weight" in key:
+ new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.weight"
+ elif ".residual.3.gamma" in key:
+ new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm2.gamma"
+ elif ".residual.6.bias" in key:
+ new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.bias"
+ elif ".residual.6.weight" in key:
+ new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.weight"
+ else:
+ new_key = key
+
+ converted_state_dict[new_key] = value
+
+ # Handle shortcut connections
+ elif ".shortcut." in key:
+ if block_idx == 4:
+ new_key = key.replace(".shortcut.", ".resnets.0.conv_shortcut.")
+ new_key = new_key.replace("decoder.upsamples.4", "decoder.up_blocks.1")
+ else:
+ new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
+ new_key = new_key.replace(".shortcut.", ".conv_shortcut.")
+
+ converted_state_dict[new_key] = value
+
+ # Handle upsamplers
+ elif ".resample." in key or ".time_conv." in key:
+ if block_idx == 3:
+ new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.0.upsamplers.0")
+ elif block_idx == 7:
+ new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.1.upsamplers.0")
+ elif block_idx == 11:
+ new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.2.upsamplers.0")
+ else:
+ new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
+
+ converted_state_dict[new_key] = value
+ else:
+ new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
+ converted_state_dict[new_key] = value
+ else:
+ # Keep other keys unchanged
+ converted_state_dict[key] = value
+
+ return converted_state_dict
diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py
index a2f27b765a1b..4eeab1679e66 100644
--- a/src/diffusers/loaders/single_file_model.py
+++ b/src/diffusers/loaders/single_file_model.py
@@ -11,430 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import importlib
-import inspect
-import re
-from contextlib import nullcontext
-from typing import Optional
-import torch
-from huggingface_hub.utils import validate_hf_hub_args
-from typing_extensions import Self
-from .. import __version__
-from ..quantizers import DiffusersAutoQuantizer
-from ..utils import deprecate, is_accelerate_available, logging
-from .single_file_utils import (
- SingleFileComponentError,
- convert_animatediff_checkpoint_to_diffusers,
- convert_auraflow_transformer_checkpoint_to_diffusers,
- convert_autoencoder_dc_checkpoint_to_diffusers,
- convert_controlnet_checkpoint,
- convert_flux_transformer_checkpoint_to_diffusers,
- convert_hunyuan_video_transformer_to_diffusers,
- convert_ldm_unet_checkpoint,
- convert_ldm_vae_checkpoint,
- convert_ltx_transformer_checkpoint_to_diffusers,
- convert_ltx_vae_checkpoint_to_diffusers,
- convert_lumina2_to_diffusers,
- convert_mochi_transformer_checkpoint_to_diffusers,
- convert_sana_transformer_to_diffusers,
- convert_sd3_transformer_checkpoint_to_diffusers,
- convert_stable_cascade_unet_single_file_to_diffusers,
- convert_wan_transformer_to_diffusers,
- convert_wan_vae_to_diffusers,
- create_controlnet_diffusers_config_from_ldm,
- create_unet_diffusers_config_from_ldm,
- create_vae_diffusers_config_from_ldm,
- fetch_diffusers_config,
- fetch_original_config,
- load_single_file_checkpoint,
+from ..utils import deprecate
+from .single_file.single_file_model import (
+ SINGLE_FILE_LOADABLE_CLASSES, # noqa: F401
+ FromOriginalModelMixin,
)
-logger = logging.get_logger(__name__)
-
-
-if is_accelerate_available():
- from accelerate import dispatch_model, init_empty_weights
-
- from ..models.modeling_utils import load_model_dict_into_meta
-
-
-SINGLE_FILE_LOADABLE_CLASSES = {
- "StableCascadeUNet": {
- "checkpoint_mapping_fn": convert_stable_cascade_unet_single_file_to_diffusers,
- },
- "UNet2DConditionModel": {
- "checkpoint_mapping_fn": convert_ldm_unet_checkpoint,
- "config_mapping_fn": create_unet_diffusers_config_from_ldm,
- "default_subfolder": "unet",
- "legacy_kwargs": {
- "num_in_channels": "in_channels", # Legacy kwargs supported by `from_single_file` mapped to new args
- },
- },
- "AutoencoderKL": {
- "checkpoint_mapping_fn": convert_ldm_vae_checkpoint,
- "config_mapping_fn": create_vae_diffusers_config_from_ldm,
- "default_subfolder": "vae",
- },
- "ControlNetModel": {
- "checkpoint_mapping_fn": convert_controlnet_checkpoint,
- "config_mapping_fn": create_controlnet_diffusers_config_from_ldm,
- },
- "SD3Transformer2DModel": {
- "checkpoint_mapping_fn": convert_sd3_transformer_checkpoint_to_diffusers,
- "default_subfolder": "transformer",
- },
- "MotionAdapter": {
- "checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
- },
- "SparseControlNetModel": {
- "checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
- },
- "FluxTransformer2DModel": {
- "checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
- "default_subfolder": "transformer",
- },
- "LTXVideoTransformer3DModel": {
- "checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers,
- "default_subfolder": "transformer",
- },
- "AutoencoderKLLTXVideo": {
- "checkpoint_mapping_fn": convert_ltx_vae_checkpoint_to_diffusers,
- "default_subfolder": "vae",
- },
- "AutoencoderDC": {"checkpoint_mapping_fn": convert_autoencoder_dc_checkpoint_to_diffusers},
- "MochiTransformer3DModel": {
- "checkpoint_mapping_fn": convert_mochi_transformer_checkpoint_to_diffusers,
- "default_subfolder": "transformer",
- },
- "HunyuanVideoTransformer3DModel": {
- "checkpoint_mapping_fn": convert_hunyuan_video_transformer_to_diffusers,
- "default_subfolder": "transformer",
- },
- "AuraFlowTransformer2DModel": {
- "checkpoint_mapping_fn": convert_auraflow_transformer_checkpoint_to_diffusers,
- "default_subfolder": "transformer",
- },
- "Lumina2Transformer2DModel": {
- "checkpoint_mapping_fn": convert_lumina2_to_diffusers,
- "default_subfolder": "transformer",
- },
- "SanaTransformer2DModel": {
- "checkpoint_mapping_fn": convert_sana_transformer_to_diffusers,
- "default_subfolder": "transformer",
- },
- "WanTransformer3DModel": {
- "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
- "default_subfolder": "transformer",
- },
- "AutoencoderKLWan": {
- "checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
- "default_subfolder": "vae",
- },
-}
-
-
-def _get_single_file_loadable_mapping_class(cls):
- diffusers_module = importlib.import_module(__name__.split(".")[0])
- for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
- loadable_class = getattr(diffusers_module, loadable_class_str)
-
- if issubclass(cls, loadable_class):
- return loadable_class_str
-
- return None
-
-
-def _get_mapping_function_kwargs(mapping_fn, **kwargs):
- parameters = inspect.signature(mapping_fn).parameters
-
- mapping_kwargs = {}
- for parameter in parameters:
- if parameter in kwargs:
- mapping_kwargs[parameter] = kwargs[parameter]
-
- return mapping_kwargs
-
-
-class FromOriginalModelMixin:
- """
- Load pretrained weights saved in the `.ckpt` or `.safetensors` format into a model.
- """
-
- @classmethod
- @validate_hf_hub_args
- def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = None, **kwargs) -> Self:
- r"""
- Instantiate a model from pretrained weights saved in the original `.ckpt` or `.safetensors` format. The model
- is set in evaluation mode (`model.eval()`) by default.
-
- Parameters:
- pretrained_model_link_or_path_or_dict (`str`, *optional*):
- Can be either:
- - A link to the `.safetensors` or `.ckpt` file (for example
- `"https://huggingface.co//blob/main/.safetensors"`) on the Hub.
- - A path to a local *file* containing the weights of the component model.
- - A state dict containing the component model weights.
- config (`str`, *optional*):
- - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline hosted
- on the Hub.
- - A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline component
- configs in Diffusers format.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
- original_config (`str`, *optional*):
- Dict or path to a yaml file containing the configuration for the model in its original format.
- If a dict is provided, it will be used to initialize the model configuration.
- torch_dtype (`str` or `torch.dtype`, *optional*):
- Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
- dtype is automatically derived from the model's weights.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to True, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- disable_mmap ('bool', *optional*, defaults to 'False'):
- Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
- is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
- kwargs (remaining dictionary of keyword arguments, *optional*):
- Can be used to overwrite load and saveable variables (for example the pipeline components of the
- specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
- method. See example below for more information.
-
- ```py
- >>> from diffusers import StableCascadeUNet
-
- >>> ckpt_path = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b_lite.safetensors"
- >>> model = StableCascadeUNet.from_single_file(ckpt_path)
- ```
- """
-
- mapping_class_name = _get_single_file_loadable_mapping_class(cls)
- # if class_name not in SINGLE_FILE_LOADABLE_CLASSES:
- if mapping_class_name is None:
- raise ValueError(
- f"FromOriginalModelMixin is currently only compatible with {', '.join(SINGLE_FILE_LOADABLE_CLASSES.keys())}"
- )
-
- pretrained_model_link_or_path = kwargs.get("pretrained_model_link_or_path", None)
- if pretrained_model_link_or_path is not None:
- deprecation_message = (
- "Please use `pretrained_model_link_or_path_or_dict` argument instead for model classes"
- )
- deprecate("pretrained_model_link_or_path", "1.0.0", deprecation_message)
- pretrained_model_link_or_path_or_dict = pretrained_model_link_or_path
-
- config = kwargs.pop("config", None)
- original_config = kwargs.pop("original_config", None)
-
- if config is not None and original_config is not None:
- raise ValueError(
- "`from_single_file` cannot accept both `config` and `original_config` arguments. Please provide only one of these arguments"
- )
-
- force_download = kwargs.pop("force_download", False)
- proxies = kwargs.pop("proxies", None)
- token = kwargs.pop("token", None)
- cache_dir = kwargs.pop("cache_dir", None)
- local_files_only = kwargs.pop("local_files_only", None)
- subfolder = kwargs.pop("subfolder", None)
- revision = kwargs.pop("revision", None)
- config_revision = kwargs.pop("config_revision", None)
- torch_dtype = kwargs.pop("torch_dtype", None)
- quantization_config = kwargs.pop("quantization_config", None)
- device = kwargs.pop("device", None)
- disable_mmap = kwargs.pop("disable_mmap", False)
-
- user_agent = {"diffusers": __version__, "file_type": "single_file", "framework": "pytorch"}
- # In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
- if quantization_config is not None:
- user_agent["quant"] = quantization_config.quant_method.value
-
- if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
- torch_dtype = torch.float32
- logger.warning(
- f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
- )
-
- if isinstance(pretrained_model_link_or_path_or_dict, dict):
- checkpoint = pretrained_model_link_or_path_or_dict
- else:
- checkpoint = load_single_file_checkpoint(
- pretrained_model_link_or_path_or_dict,
- force_download=force_download,
- proxies=proxies,
- token=token,
- cache_dir=cache_dir,
- local_files_only=local_files_only,
- revision=revision,
- disable_mmap=disable_mmap,
- user_agent=user_agent,
- )
- if quantization_config is not None:
- hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
- hf_quantizer.validate_environment()
- torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
-
- else:
- hf_quantizer = None
-
- mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name]
-
- checkpoint_mapping_fn = mapping_functions["checkpoint_mapping_fn"]
- if original_config is not None:
- if "config_mapping_fn" in mapping_functions:
- config_mapping_fn = mapping_functions["config_mapping_fn"]
- else:
- config_mapping_fn = None
-
- if config_mapping_fn is None:
- raise ValueError(
- (
- f"`original_config` has been provided for {mapping_class_name} but no mapping function"
- "was found to convert the original config to a Diffusers config in"
- "`diffusers.loaders.single_file_utils`"
- )
- )
-
- if isinstance(original_config, str):
- # If original_config is a URL or filepath fetch the original_config dict
- original_config = fetch_original_config(original_config, local_files_only=local_files_only)
-
- config_mapping_kwargs = _get_mapping_function_kwargs(config_mapping_fn, **kwargs)
- diffusers_model_config = config_mapping_fn(
- original_config=original_config, checkpoint=checkpoint, **config_mapping_kwargs
- )
- else:
- if config is not None:
- if isinstance(config, str):
- default_pretrained_model_config_name = config
- else:
- raise ValueError(
- (
- "Invalid `config` argument. Please provide a string representing a repo id"
- "or path to a local Diffusers model repo."
- )
- )
-
- else:
- config = fetch_diffusers_config(checkpoint)
- default_pretrained_model_config_name = config["pretrained_model_name_or_path"]
-
- if "default_subfolder" in mapping_functions:
- subfolder = mapping_functions["default_subfolder"]
-
- subfolder = subfolder or config.pop(
- "subfolder", None
- ) # some configs contain a subfolder key, e.g. StableCascadeUNet
-
- diffusers_model_config = cls.load_config(
- pretrained_model_name_or_path=default_pretrained_model_config_name,
- subfolder=subfolder,
- local_files_only=local_files_only,
- token=token,
- revision=config_revision,
- )
- expected_kwargs, optional_kwargs = cls._get_signature_keys(cls)
-
- # Map legacy kwargs to new kwargs
- if "legacy_kwargs" in mapping_functions:
- legacy_kwargs = mapping_functions["legacy_kwargs"]
- for legacy_key, new_key in legacy_kwargs.items():
- if legacy_key in kwargs:
- kwargs[new_key] = kwargs.pop(legacy_key)
-
- model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs}
- diffusers_model_config.update(model_kwargs)
-
- checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
- diffusers_format_checkpoint = checkpoint_mapping_fn(
- config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
- )
- if not diffusers_format_checkpoint:
- raise SingleFileComponentError(
- f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
- )
-
- ctx = init_empty_weights if is_accelerate_available() else nullcontext
- with ctx():
- model = cls.from_config(diffusers_model_config)
-
- # Check if `_keep_in_fp32_modules` is not None
- use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
- (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
- )
- if use_keep_in_fp32_modules:
- keep_in_fp32_modules = cls._keep_in_fp32_modules
- if not isinstance(keep_in_fp32_modules, list):
- keep_in_fp32_modules = [keep_in_fp32_modules]
-
- else:
- keep_in_fp32_modules = []
-
- if hf_quantizer is not None:
- hf_quantizer.preprocess_model(
- model=model,
- device_map=None,
- state_dict=diffusers_format_checkpoint,
- keep_in_fp32_modules=keep_in_fp32_modules,
- )
-
- device_map = None
- if is_accelerate_available():
- param_device = torch.device(device) if device else torch.device("cpu")
- empty_state_dict = model.state_dict()
- unexpected_keys = [
- param_name for param_name in diffusers_format_checkpoint if param_name not in empty_state_dict
- ]
- device_map = {"": param_device}
- load_model_dict_into_meta(
- model,
- diffusers_format_checkpoint,
- dtype=torch_dtype,
- device_map=device_map,
- hf_quantizer=hf_quantizer,
- keep_in_fp32_modules=keep_in_fp32_modules,
- unexpected_keys=unexpected_keys,
- )
- else:
- _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
-
- if model._keys_to_ignore_on_load_unexpected is not None:
- for pat in model._keys_to_ignore_on_load_unexpected:
- unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
-
- if len(unexpected_keys) > 0:
- logger.warning(
- f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
- )
-
- if hf_quantizer is not None:
- hf_quantizer.postprocess_model(model)
- model.hf_quantizer = hf_quantizer
-
- if torch_dtype is not None and hf_quantizer is None:
- model.to(torch_dtype)
-
- model.eval()
-
- if device_map is not None:
- device_map_kwargs = {"device_map": device_map}
- dispatch_model(model, **device_map_kwargs)
-
- return model
+class FromOriginalModelMixin(FromOriginalModelMixin):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = "Importing `FromOriginalModelMixin` from diffusers.loaders.single_file_model has been deprecated. Please use `from diffusers.loaders.single_file.single_file_model import FromOriginalModelMixin` instead."
+ deprecate("diffusers.loaders.single_file_model.FromOriginalModelMixin", "0.36", deprecation_message)
+ super().__init__(*args, **kwargs)
diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py
index b55b1b55206e..13bb67bfaad0 100644
--- a/src/diffusers/loaders/single_file_utils.py
+++ b/src/diffusers/loaders/single_file_utils.py
@@ -12,388 +12,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Conversion script for the Stable Diffusion checkpoints."""
-
-import copy
-import os
-import re
-from contextlib import nullcontext
-from io import BytesIO
-from urllib.parse import urlparse
-
-import requests
-import torch
-import yaml
-
-from ..models.modeling_utils import load_state_dict
-from ..schedulers import (
- DDIMScheduler,
- DPMSolverMultistepScheduler,
- EDMDPMSolverMultistepScheduler,
- EulerAncestralDiscreteScheduler,
- EulerDiscreteScheduler,
- HeunDiscreteScheduler,
- LMSDiscreteScheduler,
- PNDMScheduler,
-)
-from ..utils import (
- SAFETENSORS_WEIGHTS_NAME,
- WEIGHTS_NAME,
- deprecate,
- is_accelerate_available,
- is_transformers_available,
- logging,
-)
-from ..utils.constants import DIFFUSERS_REQUEST_TIMEOUT
-from ..utils.hub_utils import _get_model_file
-
-
-if is_transformers_available():
- from transformers import AutoImageProcessor
-
-if is_accelerate_available():
- from accelerate import init_empty_weights
-
- from ..models.modeling_utils import load_model_dict_into_meta
-
-logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-
-CHECKPOINT_KEY_NAMES = {
- "v2": "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
- "xl_base": "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias",
- "xl_refiner": "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias",
- "upscale": "model.diffusion_model.input_blocks.10.0.skip_connection.bias",
- "controlnet": [
- "control_model.time_embed.0.weight",
- "controlnet_cond_embedding.conv_in.weight",
- ],
- # TODO: find non-Diffusers keys for controlnet_xl
- "controlnet_xl": "add_embedding.linear_1.weight",
- "controlnet_xl_large": "down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight",
- "controlnet_xl_mid": "down_blocks.1.attentions.0.norm.weight",
- "playground-v2-5": "edm_mean",
- "inpainting": "model.diffusion_model.input_blocks.0.0.weight",
- "clip": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight",
- "clip_sdxl": "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight",
- "clip_sd3": "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight",
- "open_clip": "cond_stage_model.model.token_embedding.weight",
- "open_clip_sdxl": "conditioner.embedders.1.model.positional_embedding",
- "open_clip_sdxl_refiner": "conditioner.embedders.0.model.text_projection",
- "open_clip_sd3": "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight",
- "stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",
- "stable_cascade_stage_c": "clip_txt_mapper.weight",
- "sd3": [
- "joint_blocks.0.context_block.adaLN_modulation.1.bias",
- "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
- ],
- "sd35_large": [
- "joint_blocks.37.x_block.mlp.fc1.weight",
- "model.diffusion_model.joint_blocks.37.x_block.mlp.fc1.weight",
- ],
- "animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe",
- "animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias",
- "animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
- "animatediff_scribble": "controlnet_cond_embedding.conv_in.weight",
- "animatediff_rgb": "controlnet_cond_embedding.weight",
- "auraflow": [
- "double_layers.0.attn.w2q.weight",
- "double_layers.0.attn.w1q.weight",
- "cond_seq_linear.weight",
- "t_embedder.mlp.0.weight",
- ],
- "flux": [
- "double_blocks.0.img_attn.norm.key_norm.scale",
- "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
- ],
- "ltx-video": [
- "model.diffusion_model.patchify_proj.weight",
- "model.diffusion_model.transformer_blocks.27.scale_shift_table",
- "patchify_proj.weight",
- "transformer_blocks.27.scale_shift_table",
- "vae.per_channel_statistics.mean-of-means",
- ],
- "autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias",
- "autoencoder-dc-sana": "encoder.project_in.conv.bias",
- "mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"],
- "hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
- "instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight",
- "lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"],
- "sana": [
- "blocks.0.cross_attn.q_linear.weight",
- "blocks.0.cross_attn.q_linear.bias",
- "blocks.0.cross_attn.kv_linear.weight",
- "blocks.0.cross_attn.kv_linear.bias",
- ],
- "wan": ["model.diffusion_model.head.modulation", "head.modulation"],
- "wan_vae": "decoder.middle.0.residual.0.gamma",
-}
-
-DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
- "xl_base": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-base-1.0"},
- "xl_refiner": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-refiner-1.0"},
- "xl_inpaint": {"pretrained_model_name_or_path": "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"},
- "playground-v2-5": {"pretrained_model_name_or_path": "playgroundai/playground-v2.5-1024px-aesthetic"},
- "upscale": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-x4-upscaler"},
- "inpainting": {"pretrained_model_name_or_path": "stable-diffusion-v1-5/stable-diffusion-inpainting"},
- "inpainting_v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-inpainting"},
- "controlnet": {"pretrained_model_name_or_path": "lllyasviel/control_v11p_sd15_canny"},
- "controlnet_xl_large": {"pretrained_model_name_or_path": "diffusers/controlnet-canny-sdxl-1.0"},
- "controlnet_xl_mid": {"pretrained_model_name_or_path": "diffusers/controlnet-canny-sdxl-1.0-mid"},
- "controlnet_xl_small": {"pretrained_model_name_or_path": "diffusers/controlnet-canny-sdxl-1.0-small"},
- "v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-1"},
- "v1": {"pretrained_model_name_or_path": "stable-diffusion-v1-5/stable-diffusion-v1-5"},
- "stable_cascade_stage_b": {"pretrained_model_name_or_path": "stabilityai/stable-cascade", "subfolder": "decoder"},
- "stable_cascade_stage_b_lite": {
- "pretrained_model_name_or_path": "stabilityai/stable-cascade",
- "subfolder": "decoder_lite",
- },
- "stable_cascade_stage_c": {
- "pretrained_model_name_or_path": "stabilityai/stable-cascade-prior",
- "subfolder": "prior",
- },
- "stable_cascade_stage_c_lite": {
- "pretrained_model_name_or_path": "stabilityai/stable-cascade-prior",
- "subfolder": "prior_lite",
- },
- "sd3": {
- "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3-medium-diffusers",
- },
- "sd35_large": {
- "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3.5-large",
- },
- "sd35_medium": {
- "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3.5-medium",
- },
- "animatediff_v1": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5"},
- "animatediff_v2": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-2"},
- "animatediff_v3": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-3"},
- "animatediff_sdxl_beta": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-sdxl-beta"},
- "animatediff_scribble": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-scribble"},
- "animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"},
- "auraflow": {"pretrained_model_name_or_path": "fal/AuraFlow-v0.3"},
- "flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
- "flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"},
- "flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"},
- "flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
- "ltx-video": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.0"},
- "ltx-video-0.9.1": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.1"},
- "ltx-video-0.9.5": {"pretrained_model_name_or_path": "Lightricks/LTX-Video-0.9.5"},
- "autoencoder-dc-f128c512": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"},
- "autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"},
- "autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
- "autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"},
- "mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"},
- "hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"},
- "instruct-pix2pix": {"pretrained_model_name_or_path": "timbrooks/instruct-pix2pix"},
- "lumina2": {"pretrained_model_name_or_path": "Alpha-VLLM/Lumina-Image-2.0"},
- "sana": {"pretrained_model_name_or_path": "Efficient-Large-Model/Sana_1600M_1024px_diffusers"},
- "wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"},
- "wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
- "wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
-}
-
-# Use to configure model sample size when original config is provided
-DIFFUSERS_TO_LDM_DEFAULT_IMAGE_SIZE_MAP = {
- "xl_base": 1024,
- "xl_refiner": 1024,
- "xl_inpaint": 1024,
- "playground-v2-5": 1024,
- "upscale": 512,
- "inpainting": 512,
- "inpainting_v2": 512,
- "controlnet": 512,
- "instruct-pix2pix": 512,
- "v2": 768,
- "v1": 512,
-}
-
-
-DIFFUSERS_TO_LDM_MAPPING = {
- "unet": {
- "layers": {
- "time_embedding.linear_1.weight": "time_embed.0.weight",
- "time_embedding.linear_1.bias": "time_embed.0.bias",
- "time_embedding.linear_2.weight": "time_embed.2.weight",
- "time_embedding.linear_2.bias": "time_embed.2.bias",
- "conv_in.weight": "input_blocks.0.0.weight",
- "conv_in.bias": "input_blocks.0.0.bias",
- "conv_norm_out.weight": "out.0.weight",
- "conv_norm_out.bias": "out.0.bias",
- "conv_out.weight": "out.2.weight",
- "conv_out.bias": "out.2.bias",
- },
- "class_embed_type": {
- "class_embedding.linear_1.weight": "label_emb.0.0.weight",
- "class_embedding.linear_1.bias": "label_emb.0.0.bias",
- "class_embedding.linear_2.weight": "label_emb.0.2.weight",
- "class_embedding.linear_2.bias": "label_emb.0.2.bias",
- },
- "addition_embed_type": {
- "add_embedding.linear_1.weight": "label_emb.0.0.weight",
- "add_embedding.linear_1.bias": "label_emb.0.0.bias",
- "add_embedding.linear_2.weight": "label_emb.0.2.weight",
- "add_embedding.linear_2.bias": "label_emb.0.2.bias",
- },
- },
- "controlnet": {
- "layers": {
- "time_embedding.linear_1.weight": "time_embed.0.weight",
- "time_embedding.linear_1.bias": "time_embed.0.bias",
- "time_embedding.linear_2.weight": "time_embed.2.weight",
- "time_embedding.linear_2.bias": "time_embed.2.bias",
- "conv_in.weight": "input_blocks.0.0.weight",
- "conv_in.bias": "input_blocks.0.0.bias",
- "controlnet_cond_embedding.conv_in.weight": "input_hint_block.0.weight",
- "controlnet_cond_embedding.conv_in.bias": "input_hint_block.0.bias",
- "controlnet_cond_embedding.conv_out.weight": "input_hint_block.14.weight",
- "controlnet_cond_embedding.conv_out.bias": "input_hint_block.14.bias",
- },
- "class_embed_type": {
- "class_embedding.linear_1.weight": "label_emb.0.0.weight",
- "class_embedding.linear_1.bias": "label_emb.0.0.bias",
- "class_embedding.linear_2.weight": "label_emb.0.2.weight",
- "class_embedding.linear_2.bias": "label_emb.0.2.bias",
- },
- "addition_embed_type": {
- "add_embedding.linear_1.weight": "label_emb.0.0.weight",
- "add_embedding.linear_1.bias": "label_emb.0.0.bias",
- "add_embedding.linear_2.weight": "label_emb.0.2.weight",
- "add_embedding.linear_2.bias": "label_emb.0.2.bias",
- },
- },
- "vae": {
- "encoder.conv_in.weight": "encoder.conv_in.weight",
- "encoder.conv_in.bias": "encoder.conv_in.bias",
- "encoder.conv_out.weight": "encoder.conv_out.weight",
- "encoder.conv_out.bias": "encoder.conv_out.bias",
- "encoder.conv_norm_out.weight": "encoder.norm_out.weight",
- "encoder.conv_norm_out.bias": "encoder.norm_out.bias",
- "decoder.conv_in.weight": "decoder.conv_in.weight",
- "decoder.conv_in.bias": "decoder.conv_in.bias",
- "decoder.conv_out.weight": "decoder.conv_out.weight",
- "decoder.conv_out.bias": "decoder.conv_out.bias",
- "decoder.conv_norm_out.weight": "decoder.norm_out.weight",
- "decoder.conv_norm_out.bias": "decoder.norm_out.bias",
- "quant_conv.weight": "quant_conv.weight",
- "quant_conv.bias": "quant_conv.bias",
- "post_quant_conv.weight": "post_quant_conv.weight",
- "post_quant_conv.bias": "post_quant_conv.bias",
- },
- "openclip": {
- "layers": {
- "text_model.embeddings.position_embedding.weight": "positional_embedding",
- "text_model.embeddings.token_embedding.weight": "token_embedding.weight",
- "text_model.final_layer_norm.weight": "ln_final.weight",
- "text_model.final_layer_norm.bias": "ln_final.bias",
- "text_projection.weight": "text_projection",
- },
- "transformer": {
- "text_model.encoder.layers.": "resblocks.",
- "layer_norm1": "ln_1",
- "layer_norm2": "ln_2",
- ".fc1.": ".c_fc.",
- ".fc2.": ".c_proj.",
- ".self_attn": ".attn",
- "transformer.text_model.final_layer_norm.": "ln_final.",
- "transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight",
- "transformer.text_model.embeddings.position_embedding.weight": "positional_embedding",
- },
- },
-}
-
-SD_2_TEXT_ENCODER_KEYS_TO_IGNORE = [
- "cond_stage_model.model.transformer.resblocks.23.attn.in_proj_bias",
- "cond_stage_model.model.transformer.resblocks.23.attn.in_proj_weight",
- "cond_stage_model.model.transformer.resblocks.23.attn.out_proj.bias",
- "cond_stage_model.model.transformer.resblocks.23.attn.out_proj.weight",
- "cond_stage_model.model.transformer.resblocks.23.ln_1.bias",
- "cond_stage_model.model.transformer.resblocks.23.ln_1.weight",
- "cond_stage_model.model.transformer.resblocks.23.ln_2.bias",
- "cond_stage_model.model.transformer.resblocks.23.ln_2.weight",
- "cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.bias",
- "cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.weight",
- "cond_stage_model.model.transformer.resblocks.23.mlp.c_proj.bias",
- "cond_stage_model.model.transformer.resblocks.23.mlp.c_proj.weight",
- "cond_stage_model.model.text_projection",
-]
-
-# To support legacy scheduler_type argument
-SCHEDULER_DEFAULT_CONFIG = {
- "beta_schedule": "scaled_linear",
- "beta_start": 0.00085,
- "beta_end": 0.012,
- "interpolation_type": "linear",
- "num_train_timesteps": 1000,
- "prediction_type": "epsilon",
- "sample_max_value": 1.0,
- "set_alpha_to_one": False,
- "skip_prk_steps": True,
- "steps_offset": 1,
- "timestep_spacing": "leading",
-}
-
-LDM_VAE_KEYS = ["first_stage_model.", "vae."]
-LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215
-PLAYGROUND_VAE_SCALING_FACTOR = 0.5
-LDM_UNET_KEY = "model.diffusion_model."
-LDM_CONTROLNET_KEY = "control_model."
-LDM_CLIP_PREFIX_TO_REMOVE = [
- "cond_stage_model.transformer.",
- "conditioner.embedders.0.transformer.",
-]
-LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024
-SCHEDULER_LEGACY_KWARGS = ["prediction_type", "scheduler_type"]
-
-VALID_URL_PREFIXES = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]
-
-
-class SingleFileComponentError(Exception):
- def __init__(self, message=None):
- self.message = message
- super().__init__(self.message)
-def is_valid_url(url):
- result = urlparse(url)
- if result.scheme and result.netloc:
- return True
-
- return False
-
-
-def _extract_repo_id_and_weights_name(pretrained_model_name_or_path):
- if not is_valid_url(pretrained_model_name_or_path):
- raise ValueError("Invalid `pretrained_model_name_or_path` provided. Please set it to a valid URL.")
-
- pattern = r"([^/]+)/([^/]+)/(?:blob/main/)?(.+)"
- weights_name = None
- repo_id = (None,)
- for prefix in VALID_URL_PREFIXES:
- pretrained_model_name_or_path = pretrained_model_name_or_path.replace(prefix, "")
- match = re.match(pattern, pretrained_model_name_or_path)
- if not match:
- logger.warning("Unable to identify the repo_id and weights_name from the provided URL.")
- return repo_id, weights_name
-
- repo_id = f"{match.group(1)}/{match.group(2)}"
- weights_name = match.group(3)
-
- return repo_id, weights_name
+from ..utils import deprecate
+from .single_file.single_file_utils import SingleFileComponentError
-def _is_model_weights_in_cached_folder(cached_folder, name):
- pretrained_model_name_or_path = os.path.join(cached_folder, name)
- weights_exist = False
+class SingleFileComponentError(SingleFileComponentError):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = "Importing `SingleFileComponentError` from diffusers.loaders.single_file_utils has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import SingleFileComponentError` instead."
+ deprecate("diffusers.loaders.single_file_utils.SingleFileComponentError", "0.36", deprecation_message)
+ super().__init__(*args, **kwargs)
- for weights_name in [WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME]:
- if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
- weights_exist = True
- return weights_exist
+def is_valid_url(url):
+ from .single_file.single_file_utils import is_valid_url
+ deprecation_message = "Importing `is_valid_url()` from diffusers.loaders.single_file_utils has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import is_valid_url` instead."
+ deprecate("diffusers.loaders.single_file_utils.is_valid_url", "0.36", deprecation_message)
-def _is_legacy_scheduler_kwargs(kwargs):
- return any(k in SCHEDULER_LEGACY_KWARGS for k in kwargs.keys())
+ return is_valid_url(url)
def load_single_file_checkpoint(
@@ -407,825 +45,228 @@ def load_single_file_checkpoint(
disable_mmap=False,
user_agent=None,
):
- if user_agent is None:
- user_agent = {"file_type": "single_file", "framework": "pytorch"}
-
- if os.path.isfile(pretrained_model_link_or_path):
- pretrained_model_link_or_path = pretrained_model_link_or_path
-
- else:
- repo_id, weights_name = _extract_repo_id_and_weights_name(pretrained_model_link_or_path)
- pretrained_model_link_or_path = _get_model_file(
- repo_id,
- weights_name=weights_name,
- force_download=force_download,
- cache_dir=cache_dir,
- proxies=proxies,
- local_files_only=local_files_only,
- token=token,
- revision=revision,
- user_agent=user_agent,
- )
-
- checkpoint = load_state_dict(pretrained_model_link_or_path, disable_mmap=disable_mmap)
-
- # some checkpoints contain the model state dict under a "state_dict" key
- while "state_dict" in checkpoint:
- checkpoint = checkpoint["state_dict"]
-
- return checkpoint
+ from .single_file.single_file_utils import load_single_file_checkpoint
+
+ deprecation_message = "Importing `load_single_file_checkpoint()` from diffusers.loaders.single_file_utils has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import load_single_file_checkpoint` instead."
+ deprecate("diffusers.loaders.single_file_utils.load_single_file_checkpoint", "0.36", deprecation_message)
+
+ return load_single_file_checkpoint(
+ pretrained_model_link_or_path,
+ force_download,
+ proxies,
+ token,
+ cache_dir,
+ local_files_only,
+ revision,
+ disable_mmap,
+ user_agent,
+ )
def fetch_original_config(original_config_file, local_files_only=False):
- if os.path.isfile(original_config_file):
- with open(original_config_file, "r") as fp:
- original_config_file = fp.read()
+ from .single_file.single_file_utils import fetch_original_config
- elif is_valid_url(original_config_file):
- if local_files_only:
- raise ValueError(
- "`local_files_only` is set to True, but a URL was provided as `original_config_file`. "
- "Please provide a valid local file path."
- )
+ deprecation_message = "Importing `fetch_original_config()` from diffusers.loaders.single_file_utils has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import fetch_original_config` instead."
+ deprecate("diffusers.loaders.single_file_utils.fetch_original_config", "0.36", deprecation_message)
- original_config_file = BytesIO(requests.get(original_config_file, timeout=DIFFUSERS_REQUEST_TIMEOUT).content)
-
- else:
- raise ValueError("Invalid `original_config_file` provided. Please set it to a valid file path or URL.")
-
- original_config = yaml.safe_load(original_config_file)
-
- return original_config
+ return fetch_original_config(original_config_file, local_files_only)
def is_clip_model(checkpoint):
- if CHECKPOINT_KEY_NAMES["clip"] in checkpoint:
- return True
+ from .single_file.single_file_utils import is_clip_model
+
+ deprecation_message = "Importing `is_clip_model()` from diffusers.loaders.single_file_utils has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import is_clip_model` instead."
+ deprecate("diffusers.loaders.single_file_utils.is_clip_model", "0.36", deprecation_message)
- return False
+ return is_clip_model(checkpoint)
def is_clip_sdxl_model(checkpoint):
- if CHECKPOINT_KEY_NAMES["clip_sdxl"] in checkpoint:
- return True
+ from .single_file.single_file_utils import is_clip_sdxl_model
- return False
+ deprecation_message = "Importing `is_clip_sdxl_model()` from diffusers.loaders.single_file_utils has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import is_clip_sdxl_model` instead."
+ deprecate("diffusers.loaders.single_file_utils.is_clip_sdxl_model", "0.36", deprecation_message)
+
+ return is_clip_sdxl_model(checkpoint)
def is_clip_sd3_model(checkpoint):
- if CHECKPOINT_KEY_NAMES["clip_sd3"] in checkpoint:
- return True
+ from .single_file.single_file_utils import is_clip_sd3_model
+
+ deprecation_message = "Importing `is_clip_sd3_model()` from diffusers.loaders.single_file_utils has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import is_clip_sd3_model` instead."
+ deprecate("diffusers.loaders.single_file_utils.is_clip_sd3_model", "0.36", deprecation_message)
- return False
+ return is_clip_sd3_model(checkpoint)
def is_open_clip_model(checkpoint):
- if CHECKPOINT_KEY_NAMES["open_clip"] in checkpoint:
- return True
+ from .single_file.single_file_utils import is_open_clip_model
+
+ deprecation_message = "Importing `is_open_clip_model()` from diffusers.loaders.single_file_utils has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import is_open_clip_model` instead."
+ deprecate("diffusers.loaders.single_file_utils.is_open_clip_model", "0.36", deprecation_message)
- return False
+ return is_open_clip_model(checkpoint)
def is_open_clip_sdxl_model(checkpoint):
- if CHECKPOINT_KEY_NAMES["open_clip_sdxl"] in checkpoint:
- return True
+ from .single_file.single_file_utils import is_open_clip_sdxl_model
+
+ deprecation_message = "Importing `is_open_clip_sdxl_model()` from diffusers.loaders.single_file_utils has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import is_open_clip_sdxl_model` instead."
+ deprecate("diffusers.loaders.single_file_utils.is_open_clip_sdxl_model", "0.36", deprecation_message)
- return False
+ return is_open_clip_sdxl_model(checkpoint)
def is_open_clip_sd3_model(checkpoint):
- if CHECKPOINT_KEY_NAMES["open_clip_sd3"] in checkpoint:
- return True
+ from .single_file.single_file_utils import is_open_clip_sd3_model
- return False
+ deprecation_message = "Importing `is_open_clip_sd3_model()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import is_open_clip_sd3_model` instead."
+ deprecate("diffusers.loaders.single_file_utils.is_open_clip_sd3_model", "0.36", deprecation_message)
+
+ return is_open_clip_sd3_model(checkpoint)
def is_open_clip_sdxl_refiner_model(checkpoint):
- if CHECKPOINT_KEY_NAMES["open_clip_sdxl_refiner"] in checkpoint:
- return True
+ from .single_file.single_file_utils import is_open_clip_sdxl_refiner_model
+
+ deprecation_message = "Importing `is_open_clip_sdxl_refiner_model()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import is_open_clip_sdxl_refiner_model` instead."
+ deprecate("diffusers.loaders.single_file_utils.is_open_clip_sdxl_refiner_model", "0.36", deprecation_message)
- return False
+ return is_open_clip_sdxl_refiner_model(checkpoint)
def is_clip_model_in_single_file(class_obj, checkpoint):
- is_clip_in_checkpoint = any(
- [
- is_clip_model(checkpoint),
- is_clip_sd3_model(checkpoint),
- is_open_clip_model(checkpoint),
- is_open_clip_sdxl_model(checkpoint),
- is_open_clip_sdxl_refiner_model(checkpoint),
- is_open_clip_sd3_model(checkpoint),
- ]
- )
- if (
- class_obj.__name__ == "CLIPTextModel" or class_obj.__name__ == "CLIPTextModelWithProjection"
- ) and is_clip_in_checkpoint:
- return True
+ from .single_file.single_file_utils import is_clip_model_in_single_file
+
+ deprecation_message = "Importing `is_clip_model_in_single_file()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import is_clip_model_in_single_file` instead."
+ deprecate("diffusers.loaders.single_file_utils.is_clip_model_in_single_file", "0.36", deprecation_message)
- return False
+ return is_clip_model_in_single_file(class_obj, checkpoint)
def infer_diffusers_model_type(checkpoint):
- if (
- CHECKPOINT_KEY_NAMES["inpainting"] in checkpoint
- and checkpoint[CHECKPOINT_KEY_NAMES["inpainting"]].shape[1] == 9
- ):
- if CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024:
- model_type = "inpainting_v2"
- elif CHECKPOINT_KEY_NAMES["xl_base"] in checkpoint:
- model_type = "xl_inpaint"
- else:
- model_type = "inpainting"
-
- elif CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024:
- model_type = "v2"
-
- elif CHECKPOINT_KEY_NAMES["playground-v2-5"] in checkpoint:
- model_type = "playground-v2-5"
-
- elif CHECKPOINT_KEY_NAMES["xl_base"] in checkpoint:
- model_type = "xl_base"
-
- elif CHECKPOINT_KEY_NAMES["xl_refiner"] in checkpoint:
- model_type = "xl_refiner"
-
- elif CHECKPOINT_KEY_NAMES["upscale"] in checkpoint:
- model_type = "upscale"
-
- elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["controlnet"]):
- if CHECKPOINT_KEY_NAMES["controlnet_xl"] in checkpoint:
- if CHECKPOINT_KEY_NAMES["controlnet_xl_large"] in checkpoint:
- model_type = "controlnet_xl_large"
- elif CHECKPOINT_KEY_NAMES["controlnet_xl_mid"] in checkpoint:
- model_type = "controlnet_xl_mid"
- else:
- model_type = "controlnet_xl_small"
- else:
- model_type = "controlnet"
-
- elif (
- CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"] in checkpoint
- and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"]].shape[0] == 1536
- ):
- model_type = "stable_cascade_stage_c_lite"
-
- elif (
- CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"] in checkpoint
- and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"]].shape[0] == 2048
- ):
- model_type = "stable_cascade_stage_c"
-
- elif (
- CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"] in checkpoint
- and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"]].shape[-1] == 576
- ):
- model_type = "stable_cascade_stage_b_lite"
-
- elif (
- CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"] in checkpoint
- and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"]].shape[-1] == 640
- ):
- model_type = "stable_cascade_stage_b"
-
- elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sd3"]) and any(
- checkpoint[key].shape[-1] == 9216 if key in checkpoint else False for key in CHECKPOINT_KEY_NAMES["sd3"]
- ):
- if "model.diffusion_model.pos_embed" in checkpoint:
- key = "model.diffusion_model.pos_embed"
- else:
- key = "pos_embed"
-
- if checkpoint[key].shape[1] == 36864:
- model_type = "sd3"
- elif checkpoint[key].shape[1] == 147456:
- model_type = "sd35_medium"
-
- elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sd35_large"]):
- model_type = "sd35_large"
-
- elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint:
- if CHECKPOINT_KEY_NAMES["animatediff_scribble"] in checkpoint:
- model_type = "animatediff_scribble"
-
- elif CHECKPOINT_KEY_NAMES["animatediff_rgb"] in checkpoint:
- model_type = "animatediff_rgb"
-
- elif CHECKPOINT_KEY_NAMES["animatediff_v2"] in checkpoint:
- model_type = "animatediff_v2"
-
- elif checkpoint[CHECKPOINT_KEY_NAMES["animatediff_sdxl_beta"]].shape[-1] == 320:
- model_type = "animatediff_sdxl_beta"
-
- elif checkpoint[CHECKPOINT_KEY_NAMES["animatediff"]].shape[1] == 24:
- model_type = "animatediff_v1"
-
- else:
- model_type = "animatediff_v3"
-
- elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux"]):
- if any(
- g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"]
- ):
- if "model.diffusion_model.img_in.weight" in checkpoint:
- key = "model.diffusion_model.img_in.weight"
- else:
- key = "img_in.weight"
-
- if checkpoint[key].shape[1] == 384:
- model_type = "flux-fill"
- elif checkpoint[key].shape[1] == 128:
- model_type = "flux-depth"
- else:
- model_type = "flux-dev"
- else:
- model_type = "flux-schnell"
-
- elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx-video"]):
- if checkpoint["vae.encoder.conv_out.conv.weight"].shape[1] == 2048:
- model_type = "ltx-video-0.9.5"
- elif "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in checkpoint:
- model_type = "ltx-video-0.9.1"
- else:
- model_type = "ltx-video"
-
- elif CHECKPOINT_KEY_NAMES["autoencoder-dc"] in checkpoint:
- encoder_key = "encoder.project_in.conv.conv.bias"
- decoder_key = "decoder.project_in.main.conv.weight"
-
- if CHECKPOINT_KEY_NAMES["autoencoder-dc-sana"] in checkpoint:
- model_type = "autoencoder-dc-f32c32-sana"
-
- elif checkpoint[encoder_key].shape[-1] == 64 and checkpoint[decoder_key].shape[1] == 32:
- model_type = "autoencoder-dc-f32c32"
-
- elif checkpoint[encoder_key].shape[-1] == 64 and checkpoint[decoder_key].shape[1] == 128:
- model_type = "autoencoder-dc-f64c128"
-
- else:
- model_type = "autoencoder-dc-f128c512"
-
- elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["mochi-1-preview"]):
- model_type = "mochi-1-preview"
-
- elif CHECKPOINT_KEY_NAMES["hunyuan-video"] in checkpoint:
- model_type = "hunyuan-video"
-
- elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["auraflow"]):
- model_type = "auraflow"
-
- elif (
- CHECKPOINT_KEY_NAMES["instruct-pix2pix"] in checkpoint
- and checkpoint[CHECKPOINT_KEY_NAMES["instruct-pix2pix"]].shape[1] == 8
- ):
- model_type = "instruct-pix2pix"
-
- elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"]):
- model_type = "lumina2"
-
- elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sana"]):
- model_type = "sana"
-
- elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["wan"]):
- if "model.diffusion_model.patch_embedding.weight" in checkpoint:
- target_key = "model.diffusion_model.patch_embedding.weight"
- else:
- target_key = "patch_embedding.weight"
-
- if checkpoint[target_key].shape[0] == 1536:
- model_type = "wan-t2v-1.3B"
- elif checkpoint[target_key].shape[0] == 5120 and checkpoint[target_key].shape[1] == 16:
- model_type = "wan-t2v-14B"
- else:
- model_type = "wan-i2v-14B"
- elif CHECKPOINT_KEY_NAMES["wan_vae"] in checkpoint:
- # All Wan models use the same VAE so we can use the same default model repo to fetch the config
- model_type = "wan-t2v-14B"
- else:
- model_type = "v1"
-
- return model_type
+ from .single_file.single_file_utils import infer_diffusers_model_type
+
+ deprecation_message = "Importing `infer_diffusers_model_type()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import infer_diffusers_model_type` instead."
+ deprecate("diffusers.loaders.single_file_utils.infer_diffusers_model_type", "0.36", deprecation_message)
+
+ return infer_diffusers_model_type(checkpoint)
def fetch_diffusers_config(checkpoint):
- model_type = infer_diffusers_model_type(checkpoint)
- model_path = DIFFUSERS_DEFAULT_PIPELINE_PATHS[model_type]
- model_path = copy.deepcopy(model_path)
+ from .single_file.single_file_utils import fetch_diffusers_config
+
+ deprecation_message = "Importing `fetch_diffusers_config()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import fetch_diffusers_config` instead."
+ deprecate("diffusers.loaders.single_file_utils.fetch_diffusers_config", "0.36", deprecation_message)
- return model_path
+ return fetch_diffusers_config(checkpoint)
def set_image_size(checkpoint, image_size=None):
- if image_size:
- return image_size
+ from .single_file.single_file_utils import set_image_size
- model_type = infer_diffusers_model_type(checkpoint)
- image_size = DIFFUSERS_TO_LDM_DEFAULT_IMAGE_SIZE_MAP[model_type]
+ deprecation_message = "Importing `set_image_size()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import set_image_size` instead."
+ deprecate("diffusers.loaders.single_file_utils.set_image_size", "0.36", deprecation_message)
- return image_size
+ return set_image_size(checkpoint, image_size)
-# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear
def conv_attn_to_linear(checkpoint):
- keys = list(checkpoint.keys())
- attn_keys = ["query.weight", "key.weight", "value.weight"]
- for key in keys:
- if ".".join(key.split(".")[-2:]) in attn_keys:
- if checkpoint[key].ndim > 2:
- checkpoint[key] = checkpoint[key][:, :, 0, 0]
- elif "proj_attn.weight" in key:
- if checkpoint[key].ndim > 2:
- checkpoint[key] = checkpoint[key][:, :, 0]
+ from .single_file.single_file_utils import conv_attn_to_linear
+
+ deprecation_message = "Importing `conv_attn_to_linear()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import conv_attn_to_linear` instead."
+ deprecate("diffusers.loaders.single_file_utils.conv_attn_to_linear", "0.36", deprecation_message)
+
+ return conv_attn_to_linear(checkpoint)
def create_unet_diffusers_config_from_ldm(
original_config, checkpoint, image_size=None, upcast_attention=None, num_in_channels=None
):
- """
- Creates a config for the diffusers based on the config of the LDM model.
- """
- if image_size is not None:
- deprecation_message = (
- "Configuring UNet2DConditionModel with the `image_size` argument to `from_single_file`"
- "is deprecated and will be ignored in future versions."
- )
- deprecate("image_size", "1.0.0", deprecation_message)
-
- image_size = set_image_size(checkpoint, image_size=image_size)
-
- if (
- "unet_config" in original_config["model"]["params"]
- and original_config["model"]["params"]["unet_config"] is not None
- ):
- unet_params = original_config["model"]["params"]["unet_config"]["params"]
- else:
- unet_params = original_config["model"]["params"]["network_config"]["params"]
-
- if num_in_channels is not None:
- deprecation_message = (
- "Configuring UNet2DConditionModel with the `num_in_channels` argument to `from_single_file`"
- "is deprecated and will be ignored in future versions."
- )
- deprecate("image_size", "1.0.0", deprecation_message)
- in_channels = num_in_channels
- else:
- in_channels = unet_params["in_channels"]
-
- vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
- block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]]
-
- down_block_types = []
- resolution = 1
- for i in range(len(block_out_channels)):
- block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D"
- down_block_types.append(block_type)
- if i != len(block_out_channels) - 1:
- resolution *= 2
-
- up_block_types = []
- for i in range(len(block_out_channels)):
- block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D"
- up_block_types.append(block_type)
- resolution //= 2
-
- if unet_params["transformer_depth"] is not None:
- transformer_layers_per_block = (
- unet_params["transformer_depth"]
- if isinstance(unet_params["transformer_depth"], int)
- else list(unet_params["transformer_depth"])
- )
- else:
- transformer_layers_per_block = 1
-
- vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1)
-
- head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None
- use_linear_projection = (
- unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False
+ from .single_file.single_file_utils import create_unet_diffusers_config_from_ldm
+
+ deprecation_message = "Importing `create_unet_diffusers_config_from_ldm()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import create_unet_diffusers_config_from_ldm` instead."
+ deprecate("diffusers.loaders.single_file_utils.create_unet_diffusers_config_from_ldm", "0.36", deprecation_message)
+
+ return create_unet_diffusers_config_from_ldm(
+ original_config, checkpoint, image_size, upcast_attention, num_in_channels
)
- if use_linear_projection:
- # stable diffusion 2-base-512 and 2-768
- if head_dim is None:
- head_dim_mult = unet_params["model_channels"] // unet_params["num_head_channels"]
- head_dim = [head_dim_mult * c for c in list(unet_params["channel_mult"])]
-
- class_embed_type = None
- addition_embed_type = None
- addition_time_embed_dim = None
- projection_class_embeddings_input_dim = None
- context_dim = None
-
- if unet_params["context_dim"] is not None:
- context_dim = (
- unet_params["context_dim"]
- if isinstance(unet_params["context_dim"], int)
- else unet_params["context_dim"][0]
- )
-
- if "num_classes" in unet_params:
- if unet_params["num_classes"] == "sequential":
- if context_dim in [2048, 1280]:
- # SDXL
- addition_embed_type = "text_time"
- addition_time_embed_dim = 256
- else:
- class_embed_type = "projection"
- assert "adm_in_channels" in unet_params
- projection_class_embeddings_input_dim = unet_params["adm_in_channels"]
-
- config = {
- "sample_size": image_size // vae_scale_factor,
- "in_channels": in_channels,
- "down_block_types": down_block_types,
- "block_out_channels": block_out_channels,
- "layers_per_block": unet_params["num_res_blocks"],
- "cross_attention_dim": context_dim,
- "attention_head_dim": head_dim,
- "use_linear_projection": use_linear_projection,
- "class_embed_type": class_embed_type,
- "addition_embed_type": addition_embed_type,
- "addition_time_embed_dim": addition_time_embed_dim,
- "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
- "transformer_layers_per_block": transformer_layers_per_block,
- }
-
- if upcast_attention is not None:
- deprecation_message = (
- "Configuring UNet2DConditionModel with the `upcast_attention` argument to `from_single_file`"
- "is deprecated and will be ignored in future versions."
- )
- deprecate("image_size", "1.0.0", deprecation_message)
- config["upcast_attention"] = upcast_attention
-
- if "disable_self_attentions" in unet_params:
- config["only_cross_attention"] = unet_params["disable_self_attentions"]
-
- if "num_classes" in unet_params and isinstance(unet_params["num_classes"], int):
- config["num_class_embeds"] = unet_params["num_classes"]
-
- config["out_channels"] = unet_params["out_channels"]
- config["up_block_types"] = up_block_types
-
- return config
def create_controlnet_diffusers_config_from_ldm(original_config, checkpoint, image_size=None, **kwargs):
- if image_size is not None:
- deprecation_message = (
- "Configuring ControlNetModel with the `image_size` argument"
- "is deprecated and will be ignored in future versions."
- )
- deprecate("image_size", "1.0.0", deprecation_message)
-
- image_size = set_image_size(checkpoint, image_size=image_size)
-
- unet_params = original_config["model"]["params"]["control_stage_config"]["params"]
- diffusers_unet_config = create_unet_diffusers_config_from_ldm(original_config, image_size=image_size)
-
- controlnet_config = {
- "conditioning_channels": unet_params["hint_channels"],
- "in_channels": diffusers_unet_config["in_channels"],
- "down_block_types": diffusers_unet_config["down_block_types"],
- "block_out_channels": diffusers_unet_config["block_out_channels"],
- "layers_per_block": diffusers_unet_config["layers_per_block"],
- "cross_attention_dim": diffusers_unet_config["cross_attention_dim"],
- "attention_head_dim": diffusers_unet_config["attention_head_dim"],
- "use_linear_projection": diffusers_unet_config["use_linear_projection"],
- "class_embed_type": diffusers_unet_config["class_embed_type"],
- "addition_embed_type": diffusers_unet_config["addition_embed_type"],
- "addition_time_embed_dim": diffusers_unet_config["addition_time_embed_dim"],
- "projection_class_embeddings_input_dim": diffusers_unet_config["projection_class_embeddings_input_dim"],
- "transformer_layers_per_block": diffusers_unet_config["transformer_layers_per_block"],
- }
-
- return controlnet_config
+ from .single_file.single_file_utils import create_controlnet_diffusers_config_from_ldm
+
+ deprecation_message = "Importing `create_controlnet_diffusers_config_from_ldm()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import create_controlnet_diffusers_config_from_ldm` instead."
+ deprecate(
+ "diffusers.loaders.single_file_utils.create_controlnet_diffusers_config_from_ldm", "0.36", deprecation_message
+ )
+ return create_controlnet_diffusers_config_from_ldm(original_config, checkpoint, image_size, **kwargs)
def create_vae_diffusers_config_from_ldm(original_config, checkpoint, image_size=None, scaling_factor=None):
- """
- Creates a config for the diffusers based on the config of the LDM model.
- """
- if image_size is not None:
- deprecation_message = (
- "Configuring AutoencoderKL with the `image_size` argument"
- "is deprecated and will be ignored in future versions."
- )
- deprecate("image_size", "1.0.0", deprecation_message)
-
- image_size = set_image_size(checkpoint, image_size=image_size)
-
- if "edm_mean" in checkpoint and "edm_std" in checkpoint:
- latents_mean = checkpoint["edm_mean"]
- latents_std = checkpoint["edm_std"]
- else:
- latents_mean = None
- latents_std = None
-
- vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
- if (scaling_factor is None) and (latents_mean is not None) and (latents_std is not None):
- scaling_factor = PLAYGROUND_VAE_SCALING_FACTOR
-
- elif (scaling_factor is None) and ("scale_factor" in original_config["model"]["params"]):
- scaling_factor = original_config["model"]["params"]["scale_factor"]
-
- elif scaling_factor is None:
- scaling_factor = LDM_VAE_DEFAULT_SCALING_FACTOR
-
- block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
- down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
- up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
-
- config = {
- "sample_size": image_size,
- "in_channels": vae_params["in_channels"],
- "out_channels": vae_params["out_ch"],
- "down_block_types": down_block_types,
- "up_block_types": up_block_types,
- "block_out_channels": block_out_channels,
- "latent_channels": vae_params["z_channels"],
- "layers_per_block": vae_params["num_res_blocks"],
- "scaling_factor": scaling_factor,
- }
- if latents_mean is not None and latents_std is not None:
- config.update({"latents_mean": latents_mean, "latents_std": latents_std})
-
- return config
+ from .single_file.single_file_utils import create_vae_diffusers_config_from_ldm
+
+ deprecation_message = "Importing `create_vae_diffusers_config_from_ldm()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import create_vae_diffusers_config_from_ldm` instead."
+ deprecate("diffusers.loaders.single_file_utils.create_vae_diffusers_config_from_ldm", "0.36", deprecation_message)
+ return create_vae_diffusers_config_from_ldm(original_config, checkpoint, image_size, scaling_factor)
def update_unet_resnet_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint, mapping=None):
- for ldm_key in ldm_keys:
- diffusers_key = (
- ldm_key.replace("in_layers.0", "norm1")
- .replace("in_layers.2", "conv1")
- .replace("out_layers.0", "norm2")
- .replace("out_layers.3", "conv2")
- .replace("emb_layers.1", "time_emb_proj")
- .replace("skip_connection", "conv_shortcut")
- )
- if mapping:
- diffusers_key = diffusers_key.replace(mapping["old"], mapping["new"])
- new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)
+ from .single_file.single_file_utils import update_unet_resnet_ldm_to_diffusers
+
+ deprecation_message = "Importing `update_unet_resnet_ldm_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import update_unet_resnet_ldm_to_diffusers` instead."
+ deprecate("diffusers.loaders.single_file_utils.update_unet_resnet_ldm_to_diffusers", "0.36", deprecation_message)
+
+ return update_unet_resnet_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint, mapping)
def update_unet_attention_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint, mapping):
- for ldm_key in ldm_keys:
- diffusers_key = ldm_key.replace(mapping["old"], mapping["new"])
- new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)
+ from .single_file.single_file_utils import update_unet_attention_ldm_to_diffusers
+
+ deprecation_message = "Importing `update_unet_attention_ldm_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import update_unet_attention_ldm_to_diffusers` instead."
+ deprecate(
+ "diffusers.loaders.single_file_utils.update_unet_attention_ldm_to_diffusers", "0.36", deprecation_message
+ )
+
+ return update_unet_attention_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint, mapping)
def update_vae_resnet_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
- for ldm_key in keys:
- diffusers_key = ldm_key.replace(mapping["old"], mapping["new"]).replace("nin_shortcut", "conv_shortcut")
- new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)
+ from .single_file.single_file_utils import update_vae_resnet_ldm_to_diffusers
+
+ deprecation_message = "Importing `update_vae_resnet_ldm_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import update_vae_resnet_ldm_to_diffusers` instead."
+ deprecate("diffusers.loaders.single_file_utils.update_vae_resnet_ldm_to_diffusers", "0.36", deprecation_message)
+
+ return update_vae_resnet_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping)
def update_vae_attentions_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
- for ldm_key in keys:
- diffusers_key = (
- ldm_key.replace(mapping["old"], mapping["new"])
- .replace("norm.weight", "group_norm.weight")
- .replace("norm.bias", "group_norm.bias")
- .replace("q.weight", "to_q.weight")
- .replace("q.bias", "to_q.bias")
- .replace("k.weight", "to_k.weight")
- .replace("k.bias", "to_k.bias")
- .replace("v.weight", "to_v.weight")
- .replace("v.bias", "to_v.bias")
- .replace("proj_out.weight", "to_out.0.weight")
- .replace("proj_out.bias", "to_out.0.bias")
- )
- new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)
-
- # proj_attn.weight has to be converted from conv 1D to linear
- shape = new_checkpoint[diffusers_key].shape
-
- if len(shape) == 3:
- new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0]
- elif len(shape) == 4:
- new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0, 0]
+ from .single_file.single_file_utils import update_vae_attentions_ldm_to_diffusers
+
+ deprecation_message = "Importing `update_vae_attentions_ldm_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import update_vae_attentions_ldm_to_diffusers` instead."
+ deprecate(
+ "diffusers.loaders.single_file_utils.update_vae_attentions_ldm_to_diffusers", "0.36", deprecation_message
+ )
+
+ return update_vae_attentions_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping)
def convert_stable_cascade_unet_single_file_to_diffusers(checkpoint, **kwargs):
- is_stage_c = "clip_txt_mapper.weight" in checkpoint
-
- if is_stage_c:
- state_dict = {}
- for key in checkpoint.keys():
- if key.endswith("in_proj_weight"):
- weights = checkpoint[key].chunk(3, 0)
- state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
- state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
- state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
- elif key.endswith("in_proj_bias"):
- weights = checkpoint[key].chunk(3, 0)
- state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
- state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
- state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
- elif key.endswith("out_proj.weight"):
- weights = checkpoint[key]
- state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
- elif key.endswith("out_proj.bias"):
- weights = checkpoint[key]
- state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
- else:
- state_dict[key] = checkpoint[key]
- else:
- state_dict = {}
- for key in checkpoint.keys():
- if key.endswith("in_proj_weight"):
- weights = checkpoint[key].chunk(3, 0)
- state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
- state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
- state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
- elif key.endswith("in_proj_bias"):
- weights = checkpoint[key].chunk(3, 0)
- state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
- state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
- state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
- elif key.endswith("out_proj.weight"):
- weights = checkpoint[key]
- state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
- elif key.endswith("out_proj.bias"):
- weights = checkpoint[key]
- state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
- # rename clip_mapper to clip_txt_pooled_mapper
- elif key.endswith("clip_mapper.weight"):
- weights = checkpoint[key]
- state_dict[key.replace("clip_mapper.weight", "clip_txt_pooled_mapper.weight")] = weights
- elif key.endswith("clip_mapper.bias"):
- weights = checkpoint[key]
- state_dict[key.replace("clip_mapper.bias", "clip_txt_pooled_mapper.bias")] = weights
- else:
- state_dict[key] = checkpoint[key]
-
- return state_dict
+ from .single_file.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers
+
+ deprecation_message = "Importing `convert_stable_cascade_unet_single_file_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers` instead."
+ deprecate(
+ "diffusers.loaders.single_file_utils.convert_stable_cascade_unet_single_file_to_diffusers",
+ "0.36",
+ deprecation_message,
+ )
+ return convert_stable_cascade_unet_single_file_to_diffusers(checkpoint, **kwargs)
def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False, **kwargs):
- """
- Takes a state dict and a config, and returns a converted checkpoint.
- """
- # extract state_dict for UNet
- unet_state_dict = {}
- keys = list(checkpoint.keys())
- unet_key = LDM_UNET_KEY
-
- # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
- if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
- logger.warning("Checkpoint has both EMA and non-EMA weights.")
- logger.warning(
- "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
- " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
- )
- for key in keys:
- if key.startswith("model.diffusion_model"):
- flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
- unet_state_dict[key.replace(unet_key, "")] = checkpoint.get(flat_ema_key)
- else:
- if sum(k.startswith("model_ema") for k in keys) > 100:
- logger.warning(
- "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
- " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
- )
- for key in keys:
- if key.startswith(unet_key):
- unet_state_dict[key.replace(unet_key, "")] = checkpoint.get(key)
-
- new_checkpoint = {}
- ldm_unet_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["layers"]
- for diffusers_key, ldm_key in ldm_unet_keys.items():
- if ldm_key not in unet_state_dict:
- continue
- new_checkpoint[diffusers_key] = unet_state_dict[ldm_key]
-
- if ("class_embed_type" in config) and (config["class_embed_type"] in ["timestep", "projection"]):
- class_embed_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["class_embed_type"]
- for diffusers_key, ldm_key in class_embed_keys.items():
- new_checkpoint[diffusers_key] = unet_state_dict[ldm_key]
-
- if ("addition_embed_type" in config) and (config["addition_embed_type"] == "text_time"):
- addition_embed_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["addition_embed_type"]
- for diffusers_key, ldm_key in addition_embed_keys.items():
- new_checkpoint[diffusers_key] = unet_state_dict[ldm_key]
-
- # Relevant to StableDiffusionUpscalePipeline
- if "num_class_embeds" in config:
- if (config["num_class_embeds"] is not None) and ("label_emb.weight" in unet_state_dict):
- new_checkpoint["class_embedding.weight"] = unet_state_dict["label_emb.weight"]
-
- # Retrieves the keys for the input blocks only
- num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
- input_blocks = {
- layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
- for layer_id in range(num_input_blocks)
- }
-
- # Retrieves the keys for the middle blocks only
- num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
- middle_blocks = {
- layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
- for layer_id in range(num_middle_blocks)
- }
-
- # Retrieves the keys for the output blocks only
- num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
- output_blocks = {
- layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
- for layer_id in range(num_output_blocks)
- }
-
- # Down blocks
- for i in range(1, num_input_blocks):
- block_id = (i - 1) // (config["layers_per_block"] + 1)
- layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
-
- resnets = [
- key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
- ]
- update_unet_resnet_ldm_to_diffusers(
- resnets,
- new_checkpoint,
- unet_state_dict,
- {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"},
- )
-
- if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
- new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.get(
- f"input_blocks.{i}.0.op.weight"
- )
- new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.get(
- f"input_blocks.{i}.0.op.bias"
- )
-
- attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
- if attentions:
- update_unet_attention_ldm_to_diffusers(
- attentions,
- new_checkpoint,
- unet_state_dict,
- {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"},
- )
-
- # Mid blocks
- for key in middle_blocks.keys():
- diffusers_key = max(key - 1, 0)
- if key % 2 == 0:
- update_unet_resnet_ldm_to_diffusers(
- middle_blocks[key],
- new_checkpoint,
- unet_state_dict,
- mapping={"old": f"middle_block.{key}", "new": f"mid_block.resnets.{diffusers_key}"},
- )
- else:
- update_unet_attention_ldm_to_diffusers(
- middle_blocks[key],
- new_checkpoint,
- unet_state_dict,
- mapping={"old": f"middle_block.{key}", "new": f"mid_block.attentions.{diffusers_key}"},
- )
-
- # Up Blocks
- for i in range(num_output_blocks):
- block_id = i // (config["layers_per_block"] + 1)
- layer_in_block_id = i % (config["layers_per_block"] + 1)
-
- resnets = [
- key for key in output_blocks[i] if f"output_blocks.{i}.0" in key and f"output_blocks.{i}.0.op" not in key
- ]
- update_unet_resnet_ldm_to_diffusers(
- resnets,
- new_checkpoint,
- unet_state_dict,
- {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"},
- )
-
- attentions = [
- key for key in output_blocks[i] if f"output_blocks.{i}.1" in key and f"output_blocks.{i}.1.conv" not in key
- ]
- if attentions:
- update_unet_attention_ldm_to_diffusers(
- attentions,
- new_checkpoint,
- unet_state_dict,
- {"old": f"output_blocks.{i}.1", "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}"},
- )
-
- if f"output_blocks.{i}.1.conv.weight" in unet_state_dict:
- new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
- f"output_blocks.{i}.1.conv.weight"
- ]
- new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
- f"output_blocks.{i}.1.conv.bias"
- ]
- if f"output_blocks.{i}.2.conv.weight" in unet_state_dict:
- new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
- f"output_blocks.{i}.2.conv.weight"
- ]
- new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
- f"output_blocks.{i}.2.conv.bias"
- ]
-
- return new_checkpoint
+ from .single_file.single_file_utils import convert_ldm_unet_checkpoint
+
+ deprecation_message = "Importing `convert_ldm_unet_checkpoint()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_ldm_unet_checkpoint` instead."
+ deprecate("diffusers.loaders.single_file_utils.convert_ldm_unet_checkpoint", "0.36", deprecation_message)
+ return convert_ldm_unet_checkpoint(checkpoint, config, extract_ema, **kwargs)
def convert_controlnet_checkpoint(
@@ -1233,248 +274,27 @@ def convert_controlnet_checkpoint(
config,
**kwargs,
):
- # Return checkpoint if it's already been converted
- if "time_embedding.linear_1.weight" in checkpoint:
- return checkpoint
- # Some controlnet ckpt files are distributed independently from the rest of the
- # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/
- if "time_embed.0.weight" in checkpoint:
- controlnet_state_dict = checkpoint
-
- else:
- controlnet_state_dict = {}
- keys = list(checkpoint.keys())
- controlnet_key = LDM_CONTROLNET_KEY
- for key in keys:
- if key.startswith(controlnet_key):
- controlnet_state_dict[key.replace(controlnet_key, "")] = checkpoint.get(key)
-
- new_checkpoint = {}
- ldm_controlnet_keys = DIFFUSERS_TO_LDM_MAPPING["controlnet"]["layers"]
- for diffusers_key, ldm_key in ldm_controlnet_keys.items():
- if ldm_key not in controlnet_state_dict:
- continue
- new_checkpoint[diffusers_key] = controlnet_state_dict[ldm_key]
-
- # Retrieves the keys for the input blocks only
- num_input_blocks = len(
- {".".join(layer.split(".")[:2]) for layer in controlnet_state_dict if "input_blocks" in layer}
- )
- input_blocks = {
- layer_id: [key for key in controlnet_state_dict if f"input_blocks.{layer_id}" in key]
- for layer_id in range(num_input_blocks)
- }
-
- # Down blocks
- for i in range(1, num_input_blocks):
- block_id = (i - 1) // (config["layers_per_block"] + 1)
- layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
-
- resnets = [
- key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
- ]
- update_unet_resnet_ldm_to_diffusers(
- resnets,
- new_checkpoint,
- controlnet_state_dict,
- {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"},
- )
-
- if f"input_blocks.{i}.0.op.weight" in controlnet_state_dict:
- new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = controlnet_state_dict.get(
- f"input_blocks.{i}.0.op.weight"
- )
- new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = controlnet_state_dict.get(
- f"input_blocks.{i}.0.op.bias"
- )
-
- attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
- if attentions:
- update_unet_attention_ldm_to_diffusers(
- attentions,
- new_checkpoint,
- controlnet_state_dict,
- {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"},
- )
-
- # controlnet down blocks
- for i in range(num_input_blocks):
- new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = controlnet_state_dict.get(f"zero_convs.{i}.0.weight")
- new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = controlnet_state_dict.get(f"zero_convs.{i}.0.bias")
-
- # Retrieves the keys for the middle blocks only
- num_middle_blocks = len(
- {".".join(layer.split(".")[:2]) for layer in controlnet_state_dict if "middle_block" in layer}
- )
- middle_blocks = {
- layer_id: [key for key in controlnet_state_dict if f"middle_block.{layer_id}" in key]
- for layer_id in range(num_middle_blocks)
- }
-
- # Mid blocks
- for key in middle_blocks.keys():
- diffusers_key = max(key - 1, 0)
- if key % 2 == 0:
- update_unet_resnet_ldm_to_diffusers(
- middle_blocks[key],
- new_checkpoint,
- controlnet_state_dict,
- mapping={"old": f"middle_block.{key}", "new": f"mid_block.resnets.{diffusers_key}"},
- )
- else:
- update_unet_attention_ldm_to_diffusers(
- middle_blocks[key],
- new_checkpoint,
- controlnet_state_dict,
- mapping={"old": f"middle_block.{key}", "new": f"mid_block.attentions.{diffusers_key}"},
- )
-
- # mid block
- new_checkpoint["controlnet_mid_block.weight"] = controlnet_state_dict.get("middle_block_out.0.weight")
- new_checkpoint["controlnet_mid_block.bias"] = controlnet_state_dict.get("middle_block_out.0.bias")
-
- # controlnet cond embedding blocks
- cond_embedding_blocks = {
- ".".join(layer.split(".")[:2])
- for layer in controlnet_state_dict
- if "input_hint_block" in layer and ("input_hint_block.0" not in layer) and ("input_hint_block.14" not in layer)
- }
- num_cond_embedding_blocks = len(cond_embedding_blocks)
-
- for idx in range(1, num_cond_embedding_blocks + 1):
- diffusers_idx = idx - 1
- cond_block_id = 2 * idx
-
- new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.weight"] = controlnet_state_dict.get(
- f"input_hint_block.{cond_block_id}.weight"
- )
- new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.bias"] = controlnet_state_dict.get(
- f"input_hint_block.{cond_block_id}.bias"
- )
-
- return new_checkpoint
+ from .single_file.single_file_utils import convert_controlnet_checkpoint
+ deprecation_message = "Importing `convert_controlnet_checkpoint()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_controlnet_checkpoint` instead."
+ deprecate("diffusers.loaders.single_file_utils.convert_controlnet_checkpoint", "0.36", deprecation_message)
+ return convert_controlnet_checkpoint(checkpoint, config, **kwargs)
-def convert_ldm_vae_checkpoint(checkpoint, config):
- # extract state dict for VAE
- # remove the LDM_VAE_KEY prefix from the ldm checkpoint keys so that it is easier to map them to diffusers keys
- vae_state_dict = {}
- keys = list(checkpoint.keys())
- vae_key = ""
- for ldm_vae_key in LDM_VAE_KEYS:
- if any(k.startswith(ldm_vae_key) for k in keys):
- vae_key = ldm_vae_key
-
- for key in keys:
- if key.startswith(vae_key):
- vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
-
- new_checkpoint = {}
- vae_diffusers_ldm_map = DIFFUSERS_TO_LDM_MAPPING["vae"]
- for diffusers_key, ldm_key in vae_diffusers_ldm_map.items():
- if ldm_key not in vae_state_dict:
- continue
- new_checkpoint[diffusers_key] = vae_state_dict[ldm_key]
-
- # Retrieves the keys for the encoder down blocks only
- num_down_blocks = len(config["down_block_types"])
- down_blocks = {
- layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
- }
-
- for i in range(num_down_blocks):
- resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
- update_vae_resnet_ldm_to_diffusers(
- resnets,
- new_checkpoint,
- vae_state_dict,
- mapping={"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"},
- )
- if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
- new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.get(
- f"encoder.down.{i}.downsample.conv.weight"
- )
- new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.get(
- f"encoder.down.{i}.downsample.conv.bias"
- )
-
- mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
- num_mid_res_blocks = 2
- for i in range(1, num_mid_res_blocks + 1):
- resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
- update_vae_resnet_ldm_to_diffusers(
- resnets,
- new_checkpoint,
- vae_state_dict,
- mapping={"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"},
- )
-
- mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
- update_vae_attentions_ldm_to_diffusers(
- mid_attentions, new_checkpoint, vae_state_dict, mapping={"old": "mid.attn_1", "new": "mid_block.attentions.0"}
- )
- # Retrieves the keys for the decoder up blocks only
- num_up_blocks = len(config["up_block_types"])
- up_blocks = {
- layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
- }
-
- for i in range(num_up_blocks):
- block_id = num_up_blocks - 1 - i
- resnets = [
- key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
- ]
- update_vae_resnet_ldm_to_diffusers(
- resnets,
- new_checkpoint,
- vae_state_dict,
- mapping={"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"},
- )
- if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
- new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
- f"decoder.up.{block_id}.upsample.conv.weight"
- ]
- new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
- f"decoder.up.{block_id}.upsample.conv.bias"
- ]
-
- mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
- num_mid_res_blocks = 2
- for i in range(1, num_mid_res_blocks + 1):
- resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
- update_vae_resnet_ldm_to_diffusers(
- resnets,
- new_checkpoint,
- vae_state_dict,
- mapping={"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"},
- )
-
- mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
- update_vae_attentions_ldm_to_diffusers(
- mid_attentions, new_checkpoint, vae_state_dict, mapping={"old": "mid.attn_1", "new": "mid_block.attentions.0"}
- )
- conv_attn_to_linear(new_checkpoint)
+def convert_ldm_vae_checkpoint(checkpoint, config):
+ from .single_file.single_file_utils import convert_ldm_vae_checkpoint
- return new_checkpoint
+ deprecation_message = "Importing `convert_ldm_vae_checkpoint()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_ldm_vae_checkpoint` instead."
+ deprecate("diffusers.loaders.single_file_utils.convert_ldm_vae_checkpoint", "0.36", deprecation_message)
+ return convert_ldm_vae_checkpoint(checkpoint, config)
def convert_ldm_clip_checkpoint(checkpoint, remove_prefix=None):
- keys = list(checkpoint.keys())
- text_model_dict = {}
-
- remove_prefixes = []
- remove_prefixes.extend(LDM_CLIP_PREFIX_TO_REMOVE)
- if remove_prefix:
- remove_prefixes.append(remove_prefix)
+ from .single_file.single_file_utils import convert_ldm_clip_checkpoint
- for key in keys:
- for prefix in remove_prefixes:
- if key.startswith(prefix):
- diffusers_key = key.replace(prefix, "")
- text_model_dict[diffusers_key] = checkpoint.get(key)
-
- return text_model_dict
+ deprecation_message = "Importing `convert_ldm_clip_checkpoint()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_ldm_clip_checkpoint` instead."
+ deprecate("diffusers.loaders.single_file_utils.convert_ldm_clip_checkpoint", "0.36", deprecation_message)
+ return convert_ldm_clip_checkpoint(checkpoint, remove_prefix)
def convert_open_clip_checkpoint(
@@ -1482,65 +302,11 @@ def convert_open_clip_checkpoint(
checkpoint,
prefix="cond_stage_model.model.",
):
- text_model_dict = {}
- text_proj_key = prefix + "text_projection"
-
- if text_proj_key in checkpoint:
- text_proj_dim = int(checkpoint[text_proj_key].shape[0])
- elif hasattr(text_model.config, "hidden_size"):
- text_proj_dim = text_model.config.hidden_size
- else:
- text_proj_dim = LDM_OPEN_CLIP_TEXT_PROJECTION_DIM
-
- keys = list(checkpoint.keys())
- keys_to_ignore = SD_2_TEXT_ENCODER_KEYS_TO_IGNORE
-
- openclip_diffusers_ldm_map = DIFFUSERS_TO_LDM_MAPPING["openclip"]["layers"]
- for diffusers_key, ldm_key in openclip_diffusers_ldm_map.items():
- ldm_key = prefix + ldm_key
- if ldm_key not in checkpoint:
- continue
- if ldm_key in keys_to_ignore:
- continue
- if ldm_key.endswith("text_projection"):
- text_model_dict[diffusers_key] = checkpoint[ldm_key].T.contiguous()
- else:
- text_model_dict[diffusers_key] = checkpoint[ldm_key]
-
- for key in keys:
- if key in keys_to_ignore:
- continue
-
- if not key.startswith(prefix + "transformer."):
- continue
-
- diffusers_key = key.replace(prefix + "transformer.", "")
- transformer_diffusers_to_ldm_map = DIFFUSERS_TO_LDM_MAPPING["openclip"]["transformer"]
- for new_key, old_key in transformer_diffusers_to_ldm_map.items():
- diffusers_key = (
- diffusers_key.replace(old_key, new_key).replace(".in_proj_weight", "").replace(".in_proj_bias", "")
- )
-
- if key.endswith(".in_proj_weight"):
- weight_value = checkpoint.get(key)
-
- text_model_dict[diffusers_key + ".q_proj.weight"] = weight_value[:text_proj_dim, :].clone().detach()
- text_model_dict[diffusers_key + ".k_proj.weight"] = (
- weight_value[text_proj_dim : text_proj_dim * 2, :].clone().detach()
- )
- text_model_dict[diffusers_key + ".v_proj.weight"] = weight_value[text_proj_dim * 2 :, :].clone().detach()
-
- elif key.endswith(".in_proj_bias"):
- weight_value = checkpoint.get(key)
- text_model_dict[diffusers_key + ".q_proj.bias"] = weight_value[:text_proj_dim].clone().detach()
- text_model_dict[diffusers_key + ".k_proj.bias"] = (
- weight_value[text_proj_dim : text_proj_dim * 2].clone().detach()
- )
- text_model_dict[diffusers_key + ".v_proj.bias"] = weight_value[text_proj_dim * 2 :].clone().detach()
- else:
- text_model_dict[diffusers_key] = checkpoint.get(key)
-
- return text_model_dict
+ from .single_file.single_file_utils import convert_open_clip_checkpoint
+
+ deprecation_message = "Importing `convert_open_clip_checkpoint()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_open_clip_checkpoint` instead."
+ deprecate("diffusers.loaders.single_file_utils.convert_open_clip_checkpoint", "0.36", deprecation_message)
+ return convert_open_clip_checkpoint(text_model, checkpoint, prefix)
def create_diffusers_clip_model_from_ldm(
@@ -1552,522 +318,77 @@ def create_diffusers_clip_model_from_ldm(
local_files_only=None,
is_legacy_loading=False,
):
- if config:
- config = {"pretrained_model_name_or_path": config}
- else:
- config = fetch_diffusers_config(checkpoint)
-
- # For backwards compatibility
- # Older versions of `from_single_file` expected CLIP configs to be placed in their original transformers model repo
- # in the cache_dir, rather than in a subfolder of the Diffusers model
- if is_legacy_loading:
- logger.warning(
- (
- "Detected legacy CLIP loading behavior. Please run `from_single_file` with `local_files_only=False once to update "
- "the local cache directory with the necessary CLIP model config files. "
- "Attempting to load CLIP model from legacy cache directory."
- )
- )
-
- if is_clip_model(checkpoint) or is_clip_sdxl_model(checkpoint):
- clip_config = "openai/clip-vit-large-patch14"
- config["pretrained_model_name_or_path"] = clip_config
- subfolder = ""
-
- elif is_open_clip_model(checkpoint):
- clip_config = "stabilityai/stable-diffusion-2"
- config["pretrained_model_name_or_path"] = clip_config
- subfolder = "text_encoder"
-
- else:
- clip_config = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
- config["pretrained_model_name_or_path"] = clip_config
- subfolder = ""
-
- model_config = cls.config_class.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only)
- ctx = init_empty_weights if is_accelerate_available() else nullcontext
- with ctx():
- model = cls(model_config)
-
- position_embedding_dim = model.text_model.embeddings.position_embedding.weight.shape[-1]
-
- if is_clip_model(checkpoint):
- diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint)
-
- elif (
- is_clip_sdxl_model(checkpoint)
- and checkpoint[CHECKPOINT_KEY_NAMES["clip_sdxl"]].shape[-1] == position_embedding_dim
- ):
- diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint)
-
- elif (
- is_clip_sd3_model(checkpoint)
- and checkpoint[CHECKPOINT_KEY_NAMES["clip_sd3"]].shape[-1] == position_embedding_dim
- ):
- diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_l.transformer.")
- diffusers_format_checkpoint["text_projection.weight"] = torch.eye(position_embedding_dim)
-
- elif is_open_clip_model(checkpoint):
- prefix = "cond_stage_model.model."
- diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
-
- elif (
- is_open_clip_sdxl_model(checkpoint)
- and checkpoint[CHECKPOINT_KEY_NAMES["open_clip_sdxl"]].shape[-1] == position_embedding_dim
- ):
- prefix = "conditioner.embedders.1.model."
- diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
-
- elif is_open_clip_sdxl_refiner_model(checkpoint):
- prefix = "conditioner.embedders.0.model."
- diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
-
- elif (
- is_open_clip_sd3_model(checkpoint)
- and checkpoint[CHECKPOINT_KEY_NAMES["open_clip_sd3"]].shape[-1] == position_embedding_dim
- ):
- diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_g.transformer.")
-
- else:
- raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")
-
- if is_accelerate_available():
- load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
- else:
- model.load_state_dict(diffusers_format_checkpoint, strict=False)
-
- if torch_dtype is not None:
- model.to(torch_dtype)
-
- model.eval()
-
- return model
-
-
-def _legacy_load_scheduler(
- cls,
- checkpoint,
- component_name,
- original_config=None,
- **kwargs,
-):
- scheduler_type = kwargs.get("scheduler_type", None)
- prediction_type = kwargs.get("prediction_type", None)
-
- if scheduler_type is not None:
- deprecation_message = (
- "Please pass an instance of a Scheduler object directly to the `scheduler` argument in `from_single_file`\n\n"
- "Example:\n\n"
- "from diffusers import StableDiffusionPipeline, DDIMScheduler\n\n"
- "scheduler = DDIMScheduler()\n"
- "pipe = StableDiffusionPipeline.from_single_file(, scheduler=scheduler)\n"
- )
- deprecate("scheduler_type", "1.0.0", deprecation_message)
-
- if prediction_type is not None:
- deprecation_message = (
- "Please configure an instance of a Scheduler with the appropriate `prediction_type` and "
- "pass the object directly to the `scheduler` argument in `from_single_file`.\n\n"
- "Example:\n\n"
- "from diffusers import StableDiffusionPipeline, DDIMScheduler\n\n"
- 'scheduler = DDIMScheduler(prediction_type="v_prediction")\n'
- "pipe = StableDiffusionPipeline.from_single_file(, scheduler=scheduler)\n"
- )
- deprecate("prediction_type", "1.0.0", deprecation_message)
-
- scheduler_config = SCHEDULER_DEFAULT_CONFIG
- model_type = infer_diffusers_model_type(checkpoint=checkpoint)
-
- global_step = checkpoint["global_step"] if "global_step" in checkpoint else None
-
- if original_config:
- num_train_timesteps = getattr(original_config["model"]["params"], "timesteps", 1000)
- else:
- num_train_timesteps = 1000
-
- scheduler_config["num_train_timesteps"] = num_train_timesteps
-
- if model_type == "v2":
- if prediction_type is None:
- # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"` # as it relies on a brittle global step parameter here
- prediction_type = "epsilon" if global_step == 875000 else "v_prediction"
-
- else:
- prediction_type = prediction_type or "epsilon"
-
- scheduler_config["prediction_type"] = prediction_type
-
- if model_type in ["xl_base", "xl_refiner"]:
- scheduler_type = "euler"
- elif model_type == "playground":
- scheduler_type = "edm_dpm_solver_multistep"
- else:
- if original_config:
- beta_start = original_config["model"]["params"].get("linear_start")
- beta_end = original_config["model"]["params"].get("linear_end")
-
- else:
- beta_start = 0.02
- beta_end = 0.085
-
- scheduler_config["beta_start"] = beta_start
- scheduler_config["beta_end"] = beta_end
- scheduler_config["beta_schedule"] = "scaled_linear"
- scheduler_config["clip_sample"] = False
- scheduler_config["set_alpha_to_one"] = False
-
- # to deal with an edge case StableDiffusionUpscale pipeline has two schedulers
- if component_name == "low_res_scheduler":
- return cls.from_config(
- {
- "beta_end": 0.02,
- "beta_schedule": "scaled_linear",
- "beta_start": 0.0001,
- "clip_sample": True,
- "num_train_timesteps": 1000,
- "prediction_type": "epsilon",
- "trained_betas": None,
- "variance_type": "fixed_small",
- }
- )
-
- if scheduler_type is None:
- return cls.from_config(scheduler_config)
-
- elif scheduler_type == "pndm":
- scheduler_config["skip_prk_steps"] = True
- scheduler = PNDMScheduler.from_config(scheduler_config)
-
- elif scheduler_type == "lms":
- scheduler = LMSDiscreteScheduler.from_config(scheduler_config)
-
- elif scheduler_type == "heun":
- scheduler = HeunDiscreteScheduler.from_config(scheduler_config)
-
- elif scheduler_type == "euler":
- scheduler = EulerDiscreteScheduler.from_config(scheduler_config)
-
- elif scheduler_type == "euler-ancestral":
- scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler_config)
-
- elif scheduler_type == "dpm":
- scheduler = DPMSolverMultistepScheduler.from_config(scheduler_config)
-
- elif scheduler_type == "ddim":
- scheduler = DDIMScheduler.from_config(scheduler_config)
-
- elif scheduler_type == "edm_dpm_solver_multistep":
- scheduler_config = {
- "algorithm_type": "dpmsolver++",
- "dynamic_thresholding_ratio": 0.995,
- "euler_at_final": False,
- "final_sigmas_type": "zero",
- "lower_order_final": True,
- "num_train_timesteps": 1000,
- "prediction_type": "epsilon",
- "rho": 7.0,
- "sample_max_value": 1.0,
- "sigma_data": 0.5,
- "sigma_max": 80.0,
- "sigma_min": 0.002,
- "solver_order": 2,
- "solver_type": "midpoint",
- "thresholding": False,
- }
- scheduler = EDMDPMSolverMultistepScheduler(**scheduler_config)
-
- else:
- raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
-
- return scheduler
-
-
-def _legacy_load_clip_tokenizer(cls, checkpoint, config=None, local_files_only=False):
- if config:
- config = {"pretrained_model_name_or_path": config}
- else:
- config = fetch_diffusers_config(checkpoint)
-
- if is_clip_model(checkpoint) or is_clip_sdxl_model(checkpoint):
- clip_config = "openai/clip-vit-large-patch14"
- config["pretrained_model_name_or_path"] = clip_config
- subfolder = ""
-
- elif is_open_clip_model(checkpoint):
- clip_config = "stabilityai/stable-diffusion-2"
- config["pretrained_model_name_or_path"] = clip_config
- subfolder = "tokenizer"
-
- else:
- clip_config = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
- config["pretrained_model_name_or_path"] = clip_config
- subfolder = ""
-
- tokenizer = cls.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only)
-
- return tokenizer
-
-
-def _legacy_load_safety_checker(local_files_only, torch_dtype):
- # Support for loading safety checker components using the deprecated
- # `load_safety_checker` argument.
-
- from ..pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
-
- feature_extractor = AutoImageProcessor.from_pretrained(
- "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only, torch_dtype=torch_dtype
- )
- safety_checker = StableDiffusionSafetyChecker.from_pretrained(
- "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only, torch_dtype=torch_dtype
- )
- return {"safety_checker": safety_checker, "feature_extractor": feature_extractor}
+ from .single_file.single_file_utils import create_diffusers_clip_model_from_ldm
+
+ deprecation_message = "Importing `create_diffusers_clip_model_from_ldm()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import create_diffusers_clip_model_from_ldm` instead."
+ deprecate("diffusers.loaders.single_file_utils.create_diffusers_clip_model_from_ldm", "0.36", deprecation_message)
+ return create_diffusers_clip_model_from_ldm(
+ cls, checkpoint, subfolder, config, torch_dtype, local_files_only, is_legacy_loading
+ )
# in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale;
# while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation
def swap_scale_shift(weight, dim):
- shift, scale = weight.chunk(2, dim=0)
- new_weight = torch.cat([scale, shift], dim=0)
- return new_weight
+ from .single_file.single_file_utils import swap_scale_shift
+
+ deprecation_message = "Importing `swap_scale_shift()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import swap_scale_shift` instead."
+ deprecate("diffusers.loaders.single_file_utils.swap_scale_shift", "0.36", deprecation_message)
+ return swap_scale_shift(weight, dim)
def swap_proj_gate(weight):
- proj, gate = weight.chunk(2, dim=0)
- new_weight = torch.cat([gate, proj], dim=0)
- return new_weight
+ from .single_file.single_file_utils import swap_proj_gate
+
+ deprecation_message = "Importing `swap_proj_gate()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import swap_proj_gate` instead."
+ deprecate("diffusers.loaders.single_file_utils.swap_proj_gate", "0.36", deprecation_message)
+ return swap_proj_gate(weight)
def get_attn2_layers(state_dict):
- attn2_layers = []
- for key in state_dict.keys():
- if "attn2." in key:
- # Extract the layer number from the key
- layer_num = int(key.split(".")[1])
- attn2_layers.append(layer_num)
- return tuple(sorted(set(attn2_layers)))
+ from .single_file.single_file_utils import get_attn2_layers
+
+ deprecation_message = "Importing `get_attn2_layers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import get_attn2_layers` instead."
+ deprecate("diffusers.loaders.single_file_utils.get_attn2_layers", "0.36", deprecation_message)
+ return get_attn2_layers(state_dict)
def get_caption_projection_dim(state_dict):
- caption_projection_dim = state_dict["context_embedder.weight"].shape[0]
- return caption_projection_dim
+ from .single_file.single_file_utils import get_caption_projection_dim
+
+ deprecation_message = "Importing `get_caption_projection_dim()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import get_caption_projection_dim` instead."
+ deprecate("diffusers.loaders.single_file_utils.get_caption_projection_dim", "0.36", deprecation_message)
+ return get_caption_projection_dim(state_dict)
def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
- converted_state_dict = {}
- keys = list(checkpoint.keys())
- for k in keys:
- if "model.diffusion_model." in k:
- checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
-
- num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "joint_blocks" in k))[-1] + 1 # noqa: C401
- dual_attention_layers = get_attn2_layers(checkpoint)
-
- caption_projection_dim = get_caption_projection_dim(checkpoint)
- has_qk_norm = any("ln_q" in key for key in checkpoint.keys())
-
- # Positional and patch embeddings.
- converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("pos_embed")
- converted_state_dict["pos_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
- converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
-
- # Timestep embeddings.
- converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop(
- "t_embedder.mlp.0.weight"
- )
- converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
- converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop(
- "t_embedder.mlp.2.weight"
- )
- converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
-
- # Context projections.
- converted_state_dict["context_embedder.weight"] = checkpoint.pop("context_embedder.weight")
- converted_state_dict["context_embedder.bias"] = checkpoint.pop("context_embedder.bias")
-
- # Pooled context projection.
- converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("y_embedder.mlp.0.weight")
- converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("y_embedder.mlp.0.bias")
- converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop("y_embedder.mlp.2.weight")
- converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("y_embedder.mlp.2.bias")
-
- # Transformer blocks 🎸.
- for i in range(num_layers):
- # Q, K, V
- sample_q, sample_k, sample_v = torch.chunk(
- checkpoint.pop(f"joint_blocks.{i}.x_block.attn.qkv.weight"), 3, dim=0
- )
- context_q, context_k, context_v = torch.chunk(
- checkpoint.pop(f"joint_blocks.{i}.context_block.attn.qkv.weight"), 3, dim=0
- )
- sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
- checkpoint.pop(f"joint_blocks.{i}.x_block.attn.qkv.bias"), 3, dim=0
- )
- context_q_bias, context_k_bias, context_v_bias = torch.chunk(
- checkpoint.pop(f"joint_blocks.{i}.context_block.attn.qkv.bias"), 3, dim=0
- )
-
- converted_state_dict[f"transformer_blocks.{i}.attn.to_q.weight"] = torch.cat([sample_q])
- converted_state_dict[f"transformer_blocks.{i}.attn.to_q.bias"] = torch.cat([sample_q_bias])
- converted_state_dict[f"transformer_blocks.{i}.attn.to_k.weight"] = torch.cat([sample_k])
- converted_state_dict[f"transformer_blocks.{i}.attn.to_k.bias"] = torch.cat([sample_k_bias])
- converted_state_dict[f"transformer_blocks.{i}.attn.to_v.weight"] = torch.cat([sample_v])
- converted_state_dict[f"transformer_blocks.{i}.attn.to_v.bias"] = torch.cat([sample_v_bias])
-
- converted_state_dict[f"transformer_blocks.{i}.attn.add_q_proj.weight"] = torch.cat([context_q])
- converted_state_dict[f"transformer_blocks.{i}.attn.add_q_proj.bias"] = torch.cat([context_q_bias])
- converted_state_dict[f"transformer_blocks.{i}.attn.add_k_proj.weight"] = torch.cat([context_k])
- converted_state_dict[f"transformer_blocks.{i}.attn.add_k_proj.bias"] = torch.cat([context_k_bias])
- converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.weight"] = torch.cat([context_v])
- converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.bias"] = torch.cat([context_v_bias])
-
- # qk norm
- if has_qk_norm:
- converted_state_dict[f"transformer_blocks.{i}.attn.norm_q.weight"] = checkpoint.pop(
- f"joint_blocks.{i}.x_block.attn.ln_q.weight"
- )
- converted_state_dict[f"transformer_blocks.{i}.attn.norm_k.weight"] = checkpoint.pop(
- f"joint_blocks.{i}.x_block.attn.ln_k.weight"
- )
- converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_q.weight"] = checkpoint.pop(
- f"joint_blocks.{i}.context_block.attn.ln_q.weight"
- )
- converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_k.weight"] = checkpoint.pop(
- f"joint_blocks.{i}.context_block.attn.ln_k.weight"
- )
-
- # output projections.
- converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.weight"] = checkpoint.pop(
- f"joint_blocks.{i}.x_block.attn.proj.weight"
- )
- converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.bias"] = checkpoint.pop(
- f"joint_blocks.{i}.x_block.attn.proj.bias"
- )
- if not (i == num_layers - 1):
- converted_state_dict[f"transformer_blocks.{i}.attn.to_add_out.weight"] = checkpoint.pop(
- f"joint_blocks.{i}.context_block.attn.proj.weight"
- )
- converted_state_dict[f"transformer_blocks.{i}.attn.to_add_out.bias"] = checkpoint.pop(
- f"joint_blocks.{i}.context_block.attn.proj.bias"
- )
-
- if i in dual_attention_layers:
- # Q, K, V
- sample_q2, sample_k2, sample_v2 = torch.chunk(
- checkpoint.pop(f"joint_blocks.{i}.x_block.attn2.qkv.weight"), 3, dim=0
- )
- sample_q2_bias, sample_k2_bias, sample_v2_bias = torch.chunk(
- checkpoint.pop(f"joint_blocks.{i}.x_block.attn2.qkv.bias"), 3, dim=0
- )
- converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = torch.cat([sample_q2])
- converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = torch.cat([sample_q2_bias])
- converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = torch.cat([sample_k2])
- converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = torch.cat([sample_k2_bias])
- converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = torch.cat([sample_v2])
- converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = torch.cat([sample_v2_bias])
-
- # qk norm
- if has_qk_norm:
- converted_state_dict[f"transformer_blocks.{i}.attn2.norm_q.weight"] = checkpoint.pop(
- f"joint_blocks.{i}.x_block.attn2.ln_q.weight"
- )
- converted_state_dict[f"transformer_blocks.{i}.attn2.norm_k.weight"] = checkpoint.pop(
- f"joint_blocks.{i}.x_block.attn2.ln_k.weight"
- )
-
- # output projections.
- converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = checkpoint.pop(
- f"joint_blocks.{i}.x_block.attn2.proj.weight"
- )
- converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = checkpoint.pop(
- f"joint_blocks.{i}.x_block.attn2.proj.bias"
- )
-
- # norms.
- converted_state_dict[f"transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop(
- f"joint_blocks.{i}.x_block.adaLN_modulation.1.weight"
- )
- converted_state_dict[f"transformer_blocks.{i}.norm1.linear.bias"] = checkpoint.pop(
- f"joint_blocks.{i}.x_block.adaLN_modulation.1.bias"
- )
- if not (i == num_layers - 1):
- converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.weight"] = checkpoint.pop(
- f"joint_blocks.{i}.context_block.adaLN_modulation.1.weight"
- )
- converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.bias"] = checkpoint.pop(
- f"joint_blocks.{i}.context_block.adaLN_modulation.1.bias"
- )
- else:
- converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.weight"] = swap_scale_shift(
- checkpoint.pop(f"joint_blocks.{i}.context_block.adaLN_modulation.1.weight"),
- dim=caption_projection_dim,
- )
- converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.bias"] = swap_scale_shift(
- checkpoint.pop(f"joint_blocks.{i}.context_block.adaLN_modulation.1.bias"),
- dim=caption_projection_dim,
- )
-
- # ffs.
- converted_state_dict[f"transformer_blocks.{i}.ff.net.0.proj.weight"] = checkpoint.pop(
- f"joint_blocks.{i}.x_block.mlp.fc1.weight"
- )
- converted_state_dict[f"transformer_blocks.{i}.ff.net.0.proj.bias"] = checkpoint.pop(
- f"joint_blocks.{i}.x_block.mlp.fc1.bias"
- )
- converted_state_dict[f"transformer_blocks.{i}.ff.net.2.weight"] = checkpoint.pop(
- f"joint_blocks.{i}.x_block.mlp.fc2.weight"
- )
- converted_state_dict[f"transformer_blocks.{i}.ff.net.2.bias"] = checkpoint.pop(
- f"joint_blocks.{i}.x_block.mlp.fc2.bias"
- )
- if not (i == num_layers - 1):
- converted_state_dict[f"transformer_blocks.{i}.ff_context.net.0.proj.weight"] = checkpoint.pop(
- f"joint_blocks.{i}.context_block.mlp.fc1.weight"
- )
- converted_state_dict[f"transformer_blocks.{i}.ff_context.net.0.proj.bias"] = checkpoint.pop(
- f"joint_blocks.{i}.context_block.mlp.fc1.bias"
- )
- converted_state_dict[f"transformer_blocks.{i}.ff_context.net.2.weight"] = checkpoint.pop(
- f"joint_blocks.{i}.context_block.mlp.fc2.weight"
- )
- converted_state_dict[f"transformer_blocks.{i}.ff_context.net.2.bias"] = checkpoint.pop(
- f"joint_blocks.{i}.context_block.mlp.fc2.bias"
- )
-
- # Final blocks.
- converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
- converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
- converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
- checkpoint.pop("final_layer.adaLN_modulation.1.weight"), dim=caption_projection_dim
- )
- converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
- checkpoint.pop("final_layer.adaLN_modulation.1.bias"), dim=caption_projection_dim
- )
- return converted_state_dict
+ from .single_file.single_file_utils import convert_sd3_transformer_checkpoint_to_diffusers
+
+ deprecation_message = "Importing `convert_sd3_transformer_checkpoint_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_sd3_transformer_checkpoint_to_diffusers` instead."
+ deprecate(
+ "diffusers.loaders.single_file_utils.convert_sd3_transformer_checkpoint_to_diffusers",
+ "0.36",
+ deprecation_message,
+ )
+ return convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs)
def is_t5_in_single_file(checkpoint):
- if "text_encoders.t5xxl.transformer.shared.weight" in checkpoint:
- return True
- return False
+ from .single_file.single_file_utils import is_t5_in_single_file
+
+ deprecation_message = "Importing `is_t5_in_single_file()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import is_t5_in_single_file` instead."
+ deprecate("diffusers.loaders.single_file_utils.is_t5_in_single_file", "0.36", deprecation_message)
+ return is_t5_in_single_file(checkpoint)
def convert_sd3_t5_checkpoint_to_diffusers(checkpoint):
- keys = list(checkpoint.keys())
- text_model_dict = {}
-
- remove_prefixes = ["text_encoders.t5xxl.transformer."]
- for key in keys:
- for prefix in remove_prefixes:
- if key.startswith(prefix):
- diffusers_key = key.replace(prefix, "")
- text_model_dict[diffusers_key] = checkpoint.get(key)
-
- return text_model_dict
+ from .single_file.single_file_utils import convert_sd3_t5_checkpoint_to_diffusers
+
+ deprecation_message = "Importing `convert_sd3_t5_checkpoint_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_sd3_t5_checkpoint_to_diffusers` instead."
+ deprecate(
+ "diffusers.loaders.single_file_utils.convert_sd3_t5_checkpoint_to_diffusers", "0.36", deprecation_message
+ )
+ return convert_sd3_t5_checkpoint_to_diffusers(checkpoint)
def create_diffusers_t5_model_from_checkpoint(
@@ -2078,1218 +399,134 @@ def create_diffusers_t5_model_from_checkpoint(
torch_dtype=None,
local_files_only=None,
):
- if config:
- config = {"pretrained_model_name_or_path": config}
- else:
- config = fetch_diffusers_config(checkpoint)
-
- model_config = cls.config_class.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only)
- ctx = init_empty_weights if is_accelerate_available() else nullcontext
- with ctx():
- model = cls(model_config)
- diffusers_format_checkpoint = convert_sd3_t5_checkpoint_to_diffusers(checkpoint)
-
- if is_accelerate_available():
- load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
- else:
- model.load_state_dict(diffusers_format_checkpoint)
-
- use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (torch_dtype == torch.float16)
- if use_keep_in_fp32_modules:
- keep_in_fp32_modules = model._keep_in_fp32_modules
- else:
- keep_in_fp32_modules = []
-
- if keep_in_fp32_modules is not None:
- for name, param in model.named_parameters():
- if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules):
- # param = param.to(torch.float32) does not work here as only in the local scope.
- param.data = param.data.to(torch.float32)
-
- return model
+ from .single_file.single_file_utils import create_diffusers_t5_model_from_checkpoint
+
+ deprecation_message = "Importing `create_diffusers_t5_model_from_checkpoint()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import create_diffusers_t5_model_from_checkpoint` instead."
+ deprecate(
+ "diffusers.loaders.single_file_utils.create_diffusers_t5_model_from_checkpoint", "0.36", deprecation_message
+ )
+ return create_diffusers_t5_model_from_checkpoint(cls, checkpoint, subfolder, config, torch_dtype, local_files_only)
def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
- converted_state_dict = {}
- for k, v in checkpoint.items():
- if "pos_encoder" in k:
- continue
-
- else:
- converted_state_dict[
- k.replace(".norms.0", ".norm1")
- .replace(".norms.1", ".norm2")
- .replace(".ff_norm", ".norm3")
- .replace(".attention_blocks.0", ".attn1")
- .replace(".attention_blocks.1", ".attn2")
- .replace(".temporal_transformer", "")
- ] = v
- return converted_state_dict
+ from .single_file.single_file_utils import convert_animatediff_checkpoint_to_diffusers
+
+ deprecation_message = "Importing `convert_animatediff_checkpoint_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_animatediff_checkpoint_to_diffusers` instead."
+ deprecate(
+ "diffusers.loaders.single_file_utils.convert_animatediff_checkpoint_to_diffusers", "0.36", deprecation_message
+ )
+ return convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs)
def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
- converted_state_dict = {}
- keys = list(checkpoint.keys())
-
- for k in keys:
- if "model.diffusion_model." in k:
- checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
-
- num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1 # noqa: C401
- num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1 # noqa: C401
- mlp_ratio = 4.0
- inner_dim = 3072
-
- # in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale;
- # while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation
- def swap_scale_shift(weight):
- shift, scale = weight.chunk(2, dim=0)
- new_weight = torch.cat([scale, shift], dim=0)
- return new_weight
-
- ## time_text_embed.timestep_embedder <- time_in
- converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop(
- "time_in.in_layer.weight"
- )
- converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("time_in.in_layer.bias")
- converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop(
- "time_in.out_layer.weight"
- )
- converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("time_in.out_layer.bias")
- ## time_text_embed.text_embedder <- vector_in
- converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("vector_in.in_layer.weight")
- converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("vector_in.in_layer.bias")
- converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop(
- "vector_in.out_layer.weight"
- )
- converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("vector_in.out_layer.bias")
-
- # guidance
- has_guidance = any("guidance" in k for k in checkpoint)
- if has_guidance:
- converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = checkpoint.pop(
- "guidance_in.in_layer.weight"
- )
- converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = checkpoint.pop(
- "guidance_in.in_layer.bias"
- )
- converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = checkpoint.pop(
- "guidance_in.out_layer.weight"
- )
- converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = checkpoint.pop(
- "guidance_in.out_layer.bias"
- )
-
- # context_embedder
- converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight")
- converted_state_dict["context_embedder.bias"] = checkpoint.pop("txt_in.bias")
-
- # x_embedder
- converted_state_dict["x_embedder.weight"] = checkpoint.pop("img_in.weight")
- converted_state_dict["x_embedder.bias"] = checkpoint.pop("img_in.bias")
-
- # double transformer blocks
- for i in range(num_layers):
- block_prefix = f"transformer_blocks.{i}."
- # norms.
- ## norm1
- converted_state_dict[f"{block_prefix}norm1.linear.weight"] = checkpoint.pop(
- f"double_blocks.{i}.img_mod.lin.weight"
- )
- converted_state_dict[f"{block_prefix}norm1.linear.bias"] = checkpoint.pop(
- f"double_blocks.{i}.img_mod.lin.bias"
- )
- ## norm1_context
- converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = checkpoint.pop(
- f"double_blocks.{i}.txt_mod.lin.weight"
- )
- converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = checkpoint.pop(
- f"double_blocks.{i}.txt_mod.lin.bias"
- )
- # Q, K, V
- sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0)
- context_q, context_k, context_v = torch.chunk(
- checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0
- )
- sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
- checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0
- )
- context_q_bias, context_k_bias, context_v_bias = torch.chunk(
- checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0
- )
- converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q])
- converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias])
- converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k])
- converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias])
- converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v])
- converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias])
- converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q])
- converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias])
- converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k])
- converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias])
- converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v])
- converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias])
- # qk_norm
- converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
- f"double_blocks.{i}.img_attn.norm.query_norm.scale"
- )
- converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
- f"double_blocks.{i}.img_attn.norm.key_norm.scale"
- )
- converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = checkpoint.pop(
- f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
- )
- converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = checkpoint.pop(
- f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
- )
- # ff img_mlp
- converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = checkpoint.pop(
- f"double_blocks.{i}.img_mlp.0.weight"
- )
- converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.0.bias")
- converted_state_dict[f"{block_prefix}ff.net.2.weight"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.weight")
- converted_state_dict[f"{block_prefix}ff.net.2.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.bias")
- converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = checkpoint.pop(
- f"double_blocks.{i}.txt_mlp.0.weight"
- )
- converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = checkpoint.pop(
- f"double_blocks.{i}.txt_mlp.0.bias"
- )
- converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = checkpoint.pop(
- f"double_blocks.{i}.txt_mlp.2.weight"
- )
- converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = checkpoint.pop(
- f"double_blocks.{i}.txt_mlp.2.bias"
- )
- # output projections.
- converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = checkpoint.pop(
- f"double_blocks.{i}.img_attn.proj.weight"
- )
- converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = checkpoint.pop(
- f"double_blocks.{i}.img_attn.proj.bias"
- )
- converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = checkpoint.pop(
- f"double_blocks.{i}.txt_attn.proj.weight"
- )
- converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = checkpoint.pop(
- f"double_blocks.{i}.txt_attn.proj.bias"
- )
-
- # single transfomer blocks
- for i in range(num_single_layers):
- block_prefix = f"single_transformer_blocks.{i}."
- # norm.linear <- single_blocks.0.modulation.lin
- converted_state_dict[f"{block_prefix}norm.linear.weight"] = checkpoint.pop(
- f"single_blocks.{i}.modulation.lin.weight"
- )
- converted_state_dict[f"{block_prefix}norm.linear.bias"] = checkpoint.pop(
- f"single_blocks.{i}.modulation.lin.bias"
- )
- # Q, K, V, mlp
- mlp_hidden_dim = int(inner_dim * mlp_ratio)
- split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
- q, k, v, mlp = torch.split(checkpoint.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0)
- q_bias, k_bias, v_bias, mlp_bias = torch.split(
- checkpoint.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0
- )
- converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q])
- converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias])
- converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k])
- converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias])
- converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v])
- converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias])
- converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp])
- converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias])
- # qk norm
- converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
- f"single_blocks.{i}.norm.query_norm.scale"
- )
- converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
- f"single_blocks.{i}.norm.key_norm.scale"
- )
- # output projections.
- converted_state_dict[f"{block_prefix}proj_out.weight"] = checkpoint.pop(f"single_blocks.{i}.linear2.weight")
- converted_state_dict[f"{block_prefix}proj_out.bias"] = checkpoint.pop(f"single_blocks.{i}.linear2.bias")
-
- converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
- converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
- converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
- checkpoint.pop("final_layer.adaLN_modulation.1.weight")
- )
- converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
- checkpoint.pop("final_layer.adaLN_modulation.1.bias")
- )
-
- return converted_state_dict
+ from .single_file.single_file_utils import convert_flux_transformer_checkpoint_to_diffusers
+
+ deprecation_message = "Importing `convert_flux_transformer_checkpoint_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_flux_transformer_checkpoint_to_diffusers` instead."
+ deprecate(
+ "diffusers.loaders.single_file_utils.convert_flux_transformer_checkpoint_to_diffusers",
+ "0.36",
+ deprecation_message,
+ )
+ return convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs)
def convert_ltx_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
- converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "vae" not in key}
-
- TRANSFORMER_KEYS_RENAME_DICT = {
- "model.diffusion_model.": "",
- "patchify_proj": "proj_in",
- "adaln_single": "time_embed",
- "q_norm": "norm_q",
- "k_norm": "norm_k",
- }
- TRANSFORMER_SPECIAL_KEYS_REMAP = {}
-
- for key in list(converted_state_dict.keys()):
- new_key = key
- for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
- new_key = new_key.replace(replace_key, rename_key)
- converted_state_dict[new_key] = converted_state_dict.pop(key)
-
- for key in list(converted_state_dict.keys()):
- for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
- if special_key not in key:
- continue
- handler_fn_inplace(key, converted_state_dict)
-
- return converted_state_dict
+ from .single_file.single_file_utils import convert_ltx_transformer_checkpoint_to_diffusers
+
+ deprecation_message = "Importing `convert_ltx_transformer_checkpoint_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_ltx_transformer_checkpoint_to_diffusers` instead."
+ deprecate(
+ "diffusers.loaders.single_file_utils.convert_ltx_transformer_checkpoint_to_diffusers",
+ "0.36",
+ deprecation_message,
+ )
+ return convert_ltx_transformer_checkpoint_to_diffusers(checkpoint, **kwargs)
def convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs):
- converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "vae." in key}
-
- def remove_keys_(key: str, state_dict):
- state_dict.pop(key)
-
- VAE_KEYS_RENAME_DICT = {
- # common
- "vae.": "",
- # decoder
- "up_blocks.0": "mid_block",
- "up_blocks.1": "up_blocks.0",
- "up_blocks.2": "up_blocks.1.upsamplers.0",
- "up_blocks.3": "up_blocks.1",
- "up_blocks.4": "up_blocks.2.conv_in",
- "up_blocks.5": "up_blocks.2.upsamplers.0",
- "up_blocks.6": "up_blocks.2",
- "up_blocks.7": "up_blocks.3.conv_in",
- "up_blocks.8": "up_blocks.3.upsamplers.0",
- "up_blocks.9": "up_blocks.3",
- # encoder
- "down_blocks.0": "down_blocks.0",
- "down_blocks.1": "down_blocks.0.downsamplers.0",
- "down_blocks.2": "down_blocks.0.conv_out",
- "down_blocks.3": "down_blocks.1",
- "down_blocks.4": "down_blocks.1.downsamplers.0",
- "down_blocks.5": "down_blocks.1.conv_out",
- "down_blocks.6": "down_blocks.2",
- "down_blocks.7": "down_blocks.2.downsamplers.0",
- "down_blocks.8": "down_blocks.3",
- "down_blocks.9": "mid_block",
- # common
- "conv_shortcut": "conv_shortcut.conv",
- "res_blocks": "resnets",
- "norm3.norm": "norm3",
- "per_channel_statistics.mean-of-means": "latents_mean",
- "per_channel_statistics.std-of-means": "latents_std",
- }
-
- VAE_091_RENAME_DICT = {
- # decoder
- "up_blocks.0": "mid_block",
- "up_blocks.1": "up_blocks.0.upsamplers.0",
- "up_blocks.2": "up_blocks.0",
- "up_blocks.3": "up_blocks.1.upsamplers.0",
- "up_blocks.4": "up_blocks.1",
- "up_blocks.5": "up_blocks.2.upsamplers.0",
- "up_blocks.6": "up_blocks.2",
- "up_blocks.7": "up_blocks.3.upsamplers.0",
- "up_blocks.8": "up_blocks.3",
- # common
- "last_time_embedder": "time_embedder",
- "last_scale_shift_table": "scale_shift_table",
- }
-
- VAE_095_RENAME_DICT = {
- # decoder
- "up_blocks.0": "mid_block",
- "up_blocks.1": "up_blocks.0.upsamplers.0",
- "up_blocks.2": "up_blocks.0",
- "up_blocks.3": "up_blocks.1.upsamplers.0",
- "up_blocks.4": "up_blocks.1",
- "up_blocks.5": "up_blocks.2.upsamplers.0",
- "up_blocks.6": "up_blocks.2",
- "up_blocks.7": "up_blocks.3.upsamplers.0",
- "up_blocks.8": "up_blocks.3",
- # encoder
- "down_blocks.0": "down_blocks.0",
- "down_blocks.1": "down_blocks.0.downsamplers.0",
- "down_blocks.2": "down_blocks.1",
- "down_blocks.3": "down_blocks.1.downsamplers.0",
- "down_blocks.4": "down_blocks.2",
- "down_blocks.5": "down_blocks.2.downsamplers.0",
- "down_blocks.6": "down_blocks.3",
- "down_blocks.7": "down_blocks.3.downsamplers.0",
- "down_blocks.8": "mid_block",
- # common
- "last_time_embedder": "time_embedder",
- "last_scale_shift_table": "scale_shift_table",
- }
-
- VAE_SPECIAL_KEYS_REMAP = {
- "per_channel_statistics.channel": remove_keys_,
- "per_channel_statistics.mean-of-means": remove_keys_,
- "per_channel_statistics.mean-of-stds": remove_keys_,
- }
-
- if converted_state_dict["vae.encoder.conv_out.conv.weight"].shape[1] == 2048:
- VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT)
- elif "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict:
- VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
-
- for key in list(converted_state_dict.keys()):
- new_key = key
- for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
- new_key = new_key.replace(replace_key, rename_key)
- converted_state_dict[new_key] = converted_state_dict.pop(key)
-
- for key in list(converted_state_dict.keys()):
- for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
- if special_key not in key:
- continue
- handler_fn_inplace(key, converted_state_dict)
-
- return converted_state_dict
+ from .single_file.single_file_utils import convert_ltx_vae_checkpoint_to_diffusers
+
+ deprecation_message = "Importing `convert_ltx_vae_checkpoint_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_ltx_vae_checkpoint_to_diffusers` instead."
+ deprecate(
+ "diffusers.loaders.single_file_utils.convert_ltx_vae_checkpoint_to_diffusers", "0.36", deprecation_message
+ )
+ return convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs)
def convert_autoencoder_dc_checkpoint_to_diffusers(checkpoint, **kwargs):
- converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
-
- def remap_qkv_(key: str, state_dict):
- qkv = state_dict.pop(key)
- q, k, v = torch.chunk(qkv, 3, dim=0)
- parent_module, _, _ = key.rpartition(".qkv.conv.weight")
- state_dict[f"{parent_module}.to_q.weight"] = q.squeeze()
- state_dict[f"{parent_module}.to_k.weight"] = k.squeeze()
- state_dict[f"{parent_module}.to_v.weight"] = v.squeeze()
-
- def remap_proj_conv_(key: str, state_dict):
- parent_module, _, _ = key.rpartition(".proj.conv.weight")
- state_dict[f"{parent_module}.to_out.weight"] = state_dict.pop(key).squeeze()
-
- AE_KEYS_RENAME_DICT = {
- # common
- "main.": "",
- "op_list.": "",
- "context_module": "attn",
- "local_module": "conv_out",
- # NOTE: The below two lines work because scales in the available configs only have a tuple length of 1
- # If there were more scales, there would be more layers, so a loop would be better to handle this
- "aggreg.0.0": "to_qkv_multiscale.0.proj_in",
- "aggreg.0.1": "to_qkv_multiscale.0.proj_out",
- "depth_conv.conv": "conv_depth",
- "inverted_conv.conv": "conv_inverted",
- "point_conv.conv": "conv_point",
- "point_conv.norm": "norm",
- "conv.conv.": "conv.",
- "conv1.conv": "conv1",
- "conv2.conv": "conv2",
- "conv2.norm": "norm",
- "proj.norm": "norm_out",
- # encoder
- "encoder.project_in.conv": "encoder.conv_in",
- "encoder.project_out.0.conv": "encoder.conv_out",
- "encoder.stages": "encoder.down_blocks",
- # decoder
- "decoder.project_in.conv": "decoder.conv_in",
- "decoder.project_out.0": "decoder.norm_out",
- "decoder.project_out.2.conv": "decoder.conv_out",
- "decoder.stages": "decoder.up_blocks",
- }
-
- AE_F32C32_F64C128_F128C512_KEYS = {
- "encoder.project_in.conv": "encoder.conv_in.conv",
- "decoder.project_out.2.conv": "decoder.conv_out.conv",
- }
-
- AE_SPECIAL_KEYS_REMAP = {
- "qkv.conv.weight": remap_qkv_,
- "proj.conv.weight": remap_proj_conv_,
- }
- if "encoder.project_in.conv.bias" not in converted_state_dict:
- AE_KEYS_RENAME_DICT.update(AE_F32C32_F64C128_F128C512_KEYS)
-
- for key in list(converted_state_dict.keys()):
- new_key = key[:]
- for replace_key, rename_key in AE_KEYS_RENAME_DICT.items():
- new_key = new_key.replace(replace_key, rename_key)
- converted_state_dict[new_key] = converted_state_dict.pop(key)
-
- for key in list(converted_state_dict.keys()):
- for special_key, handler_fn_inplace in AE_SPECIAL_KEYS_REMAP.items():
- if special_key not in key:
- continue
- handler_fn_inplace(key, converted_state_dict)
-
- return converted_state_dict
+ from .single_file.single_file_utils import convert_autoencoder_dc_checkpoint_to_diffusers
+
+ deprecation_message = "Importing `convert_autoencoder_dc_checkpoint_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_autoencoder_dc_checkpoint_to_diffusers` instead."
+ deprecate(
+ "diffusers.loaders.single_file_utils.convert_autoencoder_dc_checkpoint_to_diffusers",
+ "0.36",
+ deprecation_message,
+ )
+ return convert_autoencoder_dc_checkpoint_to_diffusers(checkpoint, **kwargs)
def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
- converted_state_dict = {}
-
- # Comfy checkpoints add this prefix
- keys = list(checkpoint.keys())
- for k in keys:
- if "model.diffusion_model." in k:
- checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
-
- # Convert patch_embed
- converted_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
- converted_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
-
- # Convert time_embed
- converted_state_dict["time_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight")
- converted_state_dict["time_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
- converted_state_dict["time_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight")
- converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
- converted_state_dict["time_embed.pooler.to_kv.weight"] = checkpoint.pop("t5_y_embedder.to_kv.weight")
- converted_state_dict["time_embed.pooler.to_kv.bias"] = checkpoint.pop("t5_y_embedder.to_kv.bias")
- converted_state_dict["time_embed.pooler.to_q.weight"] = checkpoint.pop("t5_y_embedder.to_q.weight")
- converted_state_dict["time_embed.pooler.to_q.bias"] = checkpoint.pop("t5_y_embedder.to_q.bias")
- converted_state_dict["time_embed.pooler.to_out.weight"] = checkpoint.pop("t5_y_embedder.to_out.weight")
- converted_state_dict["time_embed.pooler.to_out.bias"] = checkpoint.pop("t5_y_embedder.to_out.bias")
- converted_state_dict["time_embed.caption_proj.weight"] = checkpoint.pop("t5_yproj.weight")
- converted_state_dict["time_embed.caption_proj.bias"] = checkpoint.pop("t5_yproj.bias")
-
- # Convert transformer blocks
- num_layers = 48
- for i in range(num_layers):
- block_prefix = f"transformer_blocks.{i}."
- old_prefix = f"blocks.{i}."
-
- # norm1
- converted_state_dict[block_prefix + "norm1.linear.weight"] = checkpoint.pop(old_prefix + "mod_x.weight")
- converted_state_dict[block_prefix + "norm1.linear.bias"] = checkpoint.pop(old_prefix + "mod_x.bias")
- if i < num_layers - 1:
- converted_state_dict[block_prefix + "norm1_context.linear.weight"] = checkpoint.pop(
- old_prefix + "mod_y.weight"
- )
- converted_state_dict[block_prefix + "norm1_context.linear.bias"] = checkpoint.pop(
- old_prefix + "mod_y.bias"
- )
- else:
- converted_state_dict[block_prefix + "norm1_context.linear_1.weight"] = checkpoint.pop(
- old_prefix + "mod_y.weight"
- )
- converted_state_dict[block_prefix + "norm1_context.linear_1.bias"] = checkpoint.pop(
- old_prefix + "mod_y.bias"
- )
-
- # Visual attention
- qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_x.weight")
- q, k, v = qkv_weight.chunk(3, dim=0)
-
- converted_state_dict[block_prefix + "attn1.to_q.weight"] = q
- converted_state_dict[block_prefix + "attn1.to_k.weight"] = k
- converted_state_dict[block_prefix + "attn1.to_v.weight"] = v
- converted_state_dict[block_prefix + "attn1.norm_q.weight"] = checkpoint.pop(
- old_prefix + "attn.q_norm_x.weight"
- )
- converted_state_dict[block_prefix + "attn1.norm_k.weight"] = checkpoint.pop(
- old_prefix + "attn.k_norm_x.weight"
- )
- converted_state_dict[block_prefix + "attn1.to_out.0.weight"] = checkpoint.pop(
- old_prefix + "attn.proj_x.weight"
- )
- converted_state_dict[block_prefix + "attn1.to_out.0.bias"] = checkpoint.pop(old_prefix + "attn.proj_x.bias")
-
- # Context attention
- qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_y.weight")
- q, k, v = qkv_weight.chunk(3, dim=0)
-
- converted_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q
- converted_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k
- converted_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v
- converted_state_dict[block_prefix + "attn1.norm_added_q.weight"] = checkpoint.pop(
- old_prefix + "attn.q_norm_y.weight"
- )
- converted_state_dict[block_prefix + "attn1.norm_added_k.weight"] = checkpoint.pop(
- old_prefix + "attn.k_norm_y.weight"
- )
- if i < num_layers - 1:
- converted_state_dict[block_prefix + "attn1.to_add_out.weight"] = checkpoint.pop(
- old_prefix + "attn.proj_y.weight"
- )
- converted_state_dict[block_prefix + "attn1.to_add_out.bias"] = checkpoint.pop(
- old_prefix + "attn.proj_y.bias"
- )
-
- # MLP
- converted_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate(
- checkpoint.pop(old_prefix + "mlp_x.w1.weight")
- )
- converted_state_dict[block_prefix + "ff.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_x.w2.weight")
- if i < num_layers - 1:
- converted_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate(
- checkpoint.pop(old_prefix + "mlp_y.w1.weight")
- )
- converted_state_dict[block_prefix + "ff_context.net.2.weight"] = checkpoint.pop(
- old_prefix + "mlp_y.w2.weight"
- )
-
- # Output layers
- converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(checkpoint.pop("final_layer.mod.weight"), dim=0)
- converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(checkpoint.pop("final_layer.mod.bias"), dim=0)
- converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
- converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
-
- converted_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies")
-
- return converted_state_dict
+ from .single_file.single_file_utils import convert_mochi_transformer_checkpoint_to_diffusers
+
+ deprecation_message = "Importing `convert_mochi_transformer_checkpoint_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_mochi_transformer_checkpoint_to_diffusers` instead."
+ deprecate(
+ "diffusers.loaders.single_file_utils.convert_mochi_transformer_checkpoint_to_diffusers",
+ "0.36",
+ deprecation_message,
+ )
+ return convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs)
def convert_hunyuan_video_transformer_to_diffusers(checkpoint, **kwargs):
- def remap_norm_scale_shift_(key, state_dict):
- weight = state_dict.pop(key)
- shift, scale = weight.chunk(2, dim=0)
- new_weight = torch.cat([scale, shift], dim=0)
- state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight
-
- def remap_txt_in_(key, state_dict):
- def rename_key(key):
- new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
- new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
- new_key = new_key.replace("txt_in", "context_embedder")
- new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
- new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
- new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
- new_key = new_key.replace("mlp", "ff")
- return new_key
-
- if "self_attn_qkv" in key:
- weight = state_dict.pop(key)
- to_q, to_k, to_v = weight.chunk(3, dim=0)
- state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
- state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
- state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
- else:
- state_dict[rename_key(key)] = state_dict.pop(key)
-
- def remap_img_attn_qkv_(key, state_dict):
- weight = state_dict.pop(key)
- to_q, to_k, to_v = weight.chunk(3, dim=0)
- state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
- state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
- state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v
-
- def remap_txt_attn_qkv_(key, state_dict):
- weight = state_dict.pop(key)
- to_q, to_k, to_v = weight.chunk(3, dim=0)
- state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
- state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
- state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v
-
- def remap_single_transformer_blocks_(key, state_dict):
- hidden_size = 3072
-
- if "linear1.weight" in key:
- linear1_weight = state_dict.pop(key)
- split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
- q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
- new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.weight")
- state_dict[f"{new_key}.attn.to_q.weight"] = q
- state_dict[f"{new_key}.attn.to_k.weight"] = k
- state_dict[f"{new_key}.attn.to_v.weight"] = v
- state_dict[f"{new_key}.proj_mlp.weight"] = mlp
-
- elif "linear1.bias" in key:
- linear1_bias = state_dict.pop(key)
- split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
- q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
- new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.bias")
- state_dict[f"{new_key}.attn.to_q.bias"] = q_bias
- state_dict[f"{new_key}.attn.to_k.bias"] = k_bias
- state_dict[f"{new_key}.attn.to_v.bias"] = v_bias
- state_dict[f"{new_key}.proj_mlp.bias"] = mlp_bias
-
- else:
- new_key = key.replace("single_blocks", "single_transformer_blocks")
- new_key = new_key.replace("linear2", "proj_out")
- new_key = new_key.replace("q_norm", "attn.norm_q")
- new_key = new_key.replace("k_norm", "attn.norm_k")
- state_dict[new_key] = state_dict.pop(key)
-
- TRANSFORMER_KEYS_RENAME_DICT = {
- "img_in": "x_embedder",
- "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
- "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
- "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
- "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
- "vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
- "vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
- "double_blocks": "transformer_blocks",
- "img_attn_q_norm": "attn.norm_q",
- "img_attn_k_norm": "attn.norm_k",
- "img_attn_proj": "attn.to_out.0",
- "txt_attn_q_norm": "attn.norm_added_q",
- "txt_attn_k_norm": "attn.norm_added_k",
- "txt_attn_proj": "attn.to_add_out",
- "img_mod.linear": "norm1.linear",
- "img_norm1": "norm1.norm",
- "img_norm2": "norm2",
- "img_mlp": "ff",
- "txt_mod.linear": "norm1_context.linear",
- "txt_norm1": "norm1.norm",
- "txt_norm2": "norm2_context",
- "txt_mlp": "ff_context",
- "self_attn_proj": "attn.to_out.0",
- "modulation.linear": "norm.linear",
- "pre_norm": "norm.norm",
- "final_layer.norm_final": "norm_out.norm",
- "final_layer.linear": "proj_out",
- "fc1": "net.0.proj",
- "fc2": "net.2",
- "input_embedder": "proj_in",
- }
-
- TRANSFORMER_SPECIAL_KEYS_REMAP = {
- "txt_in": remap_txt_in_,
- "img_attn_qkv": remap_img_attn_qkv_,
- "txt_attn_qkv": remap_txt_attn_qkv_,
- "single_blocks": remap_single_transformer_blocks_,
- "final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
- }
-
- def update_state_dict_(state_dict, old_key, new_key):
- state_dict[new_key] = state_dict.pop(old_key)
-
- for key in list(checkpoint.keys()):
- new_key = key[:]
- for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
- new_key = new_key.replace(replace_key, rename_key)
- update_state_dict_(checkpoint, key, new_key)
-
- for key in list(checkpoint.keys()):
- for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
- if special_key not in key:
- continue
- handler_fn_inplace(key, checkpoint)
-
- return checkpoint
+ from .single_file.single_file_utils import convert_hunyuan_video_transformer_to_diffusers
+
+ deprecation_message = "Importing `convert_hunyuan_video_transformer_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_hunyuan_video_transformer_to_diffusers` instead."
+ deprecate(
+ "diffusers.loaders.single_file_utils.convert_hunyuan_video_transformer_to_diffusers",
+ "0.36",
+ deprecation_message,
+ )
+ return convert_hunyuan_video_transformer_to_diffusers(checkpoint, **kwargs)
def convert_auraflow_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
- converted_state_dict = {}
- state_dict_keys = list(checkpoint.keys())
-
- # Handle register tokens and positional embeddings
- converted_state_dict["register_tokens"] = checkpoint.pop("register_tokens", None)
-
- # Handle time step projection
- converted_state_dict["time_step_proj.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight", None)
- converted_state_dict["time_step_proj.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias", None)
- converted_state_dict["time_step_proj.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight", None)
- converted_state_dict["time_step_proj.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias", None)
-
- # Handle context embedder
- converted_state_dict["context_embedder.weight"] = checkpoint.pop("cond_seq_linear.weight", None)
-
- # Calculate the number of layers
- def calculate_layers(keys, key_prefix):
- layers = set()
- for k in keys:
- if key_prefix in k:
- layer_num = int(k.split(".")[1]) # get the layer number
- layers.add(layer_num)
- return len(layers)
-
- mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers")
- single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers")
-
- # MMDiT blocks
- for i in range(mmdit_layers):
- # Feed-forward
- path_mapping = {"mlpX": "ff", "mlpC": "ff_context"}
- weight_mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
- for orig_k, diffuser_k in path_mapping.items():
- for k, v in weight_mapping.items():
- converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.{v}.weight"] = checkpoint.pop(
- f"double_layers.{i}.{orig_k}.{k}.weight", None
- )
-
- # Norms
- path_mapping = {"modX": "norm1", "modC": "norm1_context"}
- for orig_k, diffuser_k in path_mapping.items():
- converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.linear.weight"] = checkpoint.pop(
- f"double_layers.{i}.{orig_k}.1.weight", None
- )
-
- # Attentions
- x_attn_mapping = {"w2q": "to_q", "w2k": "to_k", "w2v": "to_v", "w2o": "to_out.0"}
- context_attn_mapping = {"w1q": "add_q_proj", "w1k": "add_k_proj", "w1v": "add_v_proj", "w1o": "to_add_out"}
- for attn_mapping in [x_attn_mapping, context_attn_mapping]:
- for k, v in attn_mapping.items():
- converted_state_dict[f"joint_transformer_blocks.{i}.attn.{v}.weight"] = checkpoint.pop(
- f"double_layers.{i}.attn.{k}.weight", None
- )
-
- # Single-DiT blocks
- for i in range(single_dit_layers):
- # Feed-forward
- mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
- for k, v in mapping.items():
- converted_state_dict[f"single_transformer_blocks.{i}.ff.{v}.weight"] = checkpoint.pop(
- f"single_layers.{i}.mlp.{k}.weight", None
- )
-
- # Norms
- converted_state_dict[f"single_transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop(
- f"single_layers.{i}.modCX.1.weight", None
- )
-
- # Attentions
- x_attn_mapping = {"w1q": "to_q", "w1k": "to_k", "w1v": "to_v", "w1o": "to_out.0"}
- for k, v in x_attn_mapping.items():
- converted_state_dict[f"single_transformer_blocks.{i}.attn.{v}.weight"] = checkpoint.pop(
- f"single_layers.{i}.attn.{k}.weight", None
- )
- # Final blocks
- converted_state_dict["proj_out.weight"] = checkpoint.pop("final_linear.weight", None)
-
- # Handle the final norm layer
- norm_weight = checkpoint.pop("modF.1.weight", None)
- if norm_weight is not None:
- converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(norm_weight, dim=None)
- else:
- converted_state_dict["norm_out.linear.weight"] = None
-
- converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("positional_encoding")
- converted_state_dict["pos_embed.proj.weight"] = checkpoint.pop("init_x_linear.weight")
- converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("init_x_linear.bias")
-
- return converted_state_dict
+ from .single_file.single_file_utils import convert_auraflow_transformer_checkpoint_to_diffusers
+
+ deprecation_message = "Importing `convert_auraflow_transformer_checkpoint_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_auraflow_transformer_checkpoint_to_diffusers` instead."
+ deprecate(
+ "diffusers.loaders.single_file_utils.convert_auraflow_transformer_checkpoint_to_diffusers",
+ "0.36",
+ deprecation_message,
+ )
+ return convert_auraflow_transformer_checkpoint_to_diffusers(checkpoint, **kwargs)
def convert_lumina2_to_diffusers(checkpoint, **kwargs):
- converted_state_dict = {}
-
- # Original Lumina-Image-2 has an extra norm paramter that is unused
- # We just remove it here
- checkpoint.pop("norm_final.weight", None)
-
- # Comfy checkpoints add this prefix
- keys = list(checkpoint.keys())
- for k in keys:
- if "model.diffusion_model." in k:
- checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
-
- LUMINA_KEY_MAP = {
- "cap_embedder": "time_caption_embed.caption_embedder",
- "t_embedder.mlp.0": "time_caption_embed.timestep_embedder.linear_1",
- "t_embedder.mlp.2": "time_caption_embed.timestep_embedder.linear_2",
- "attention": "attn",
- ".out.": ".to_out.0.",
- "k_norm": "norm_k",
- "q_norm": "norm_q",
- "w1": "linear_1",
- "w2": "linear_2",
- "w3": "linear_3",
- "adaLN_modulation.1": "norm1.linear",
- }
- ATTENTION_NORM_MAP = {
- "attention_norm1": "norm1.norm",
- "attention_norm2": "norm2",
- }
- CONTEXT_REFINER_MAP = {
- "context_refiner.0.attention_norm1": "context_refiner.0.norm1",
- "context_refiner.0.attention_norm2": "context_refiner.0.norm2",
- "context_refiner.1.attention_norm1": "context_refiner.1.norm1",
- "context_refiner.1.attention_norm2": "context_refiner.1.norm2",
- }
- FINAL_LAYER_MAP = {
- "final_layer.adaLN_modulation.1": "norm_out.linear_1",
- "final_layer.linear": "norm_out.linear_2",
- }
-
- def convert_lumina_attn_to_diffusers(tensor, diffusers_key):
- q_dim = 2304
- k_dim = v_dim = 768
-
- to_q, to_k, to_v = torch.split(tensor, [q_dim, k_dim, v_dim], dim=0)
-
- return {
- diffusers_key.replace("qkv", "to_q"): to_q,
- diffusers_key.replace("qkv", "to_k"): to_k,
- diffusers_key.replace("qkv", "to_v"): to_v,
- }
-
- for key in keys:
- diffusers_key = key
- for k, v in CONTEXT_REFINER_MAP.items():
- diffusers_key = diffusers_key.replace(k, v)
- for k, v in FINAL_LAYER_MAP.items():
- diffusers_key = diffusers_key.replace(k, v)
- for k, v in ATTENTION_NORM_MAP.items():
- diffusers_key = diffusers_key.replace(k, v)
- for k, v in LUMINA_KEY_MAP.items():
- diffusers_key = diffusers_key.replace(k, v)
-
- if "qkv" in diffusers_key:
- converted_state_dict.update(convert_lumina_attn_to_diffusers(checkpoint.pop(key), diffusers_key))
- else:
- converted_state_dict[diffusers_key] = checkpoint.pop(key)
-
- return converted_state_dict
+    from .single_file.single_file_utils import convert_lumina2_to_diffusers
+
+    deprecation_message = "Importing `convert_lumina2_to_diffusers()` from diffusers.loaders.single_file_utils has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_lumina2_to_diffusers` instead."
+    deprecate("diffusers.loaders.single_file_utils.convert_lumina2_to_diffusers", "0.36", deprecation_message)
+    return convert_lumina2_to_diffusers(checkpoint, **kwargs)
def convert_sana_transformer_to_diffusers(checkpoint, **kwargs):
- converted_state_dict = {}
- keys = list(checkpoint.keys())
- for k in keys:
- if "model.diffusion_model." in k:
- checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
-
- num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "blocks" in k))[-1] + 1 # noqa: C401
-
- # Positional and patch embeddings.
- checkpoint.pop("pos_embed")
- converted_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
- converted_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
-
- # Timestep embeddings.
- converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = checkpoint.pop(
- "t_embedder.mlp.0.weight"
- )
- converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
- converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = checkpoint.pop(
- "t_embedder.mlp.2.weight"
- )
- converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
- converted_state_dict["time_embed.linear.weight"] = checkpoint.pop("t_block.1.weight")
- converted_state_dict["time_embed.linear.bias"] = checkpoint.pop("t_block.1.bias")
-
- # Caption Projection.
- checkpoint.pop("y_embedder.y_embedding")
- converted_state_dict["caption_projection.linear_1.weight"] = checkpoint.pop("y_embedder.y_proj.fc1.weight")
- converted_state_dict["caption_projection.linear_1.bias"] = checkpoint.pop("y_embedder.y_proj.fc1.bias")
- converted_state_dict["caption_projection.linear_2.weight"] = checkpoint.pop("y_embedder.y_proj.fc2.weight")
- converted_state_dict["caption_projection.linear_2.bias"] = checkpoint.pop("y_embedder.y_proj.fc2.bias")
- converted_state_dict["caption_norm.weight"] = checkpoint.pop("attention_y_norm.weight")
-
- for i in range(num_layers):
- converted_state_dict[f"transformer_blocks.{i}.scale_shift_table"] = checkpoint.pop(
- f"blocks.{i}.scale_shift_table"
- )
-
- # Self-Attention
- sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"blocks.{i}.attn.qkv.weight"), 3, dim=0)
- converted_state_dict[f"transformer_blocks.{i}.attn1.to_q.weight"] = torch.cat([sample_q])
- converted_state_dict[f"transformer_blocks.{i}.attn1.to_k.weight"] = torch.cat([sample_k])
- converted_state_dict[f"transformer_blocks.{i}.attn1.to_v.weight"] = torch.cat([sample_v])
-
- # Output Projections
- converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.weight"] = checkpoint.pop(
- f"blocks.{i}.attn.proj.weight"
- )
- converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.bias"] = checkpoint.pop(
- f"blocks.{i}.attn.proj.bias"
- )
-
- # Cross-Attention
- converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = checkpoint.pop(
- f"blocks.{i}.cross_attn.q_linear.weight"
- )
- converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = checkpoint.pop(
- f"blocks.{i}.cross_attn.q_linear.bias"
- )
-
- linear_sample_k, linear_sample_v = torch.chunk(
- checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.weight"), 2, dim=0
- )
- linear_sample_k_bias, linear_sample_v_bias = torch.chunk(
- checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.bias"), 2, dim=0
- )
- converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = linear_sample_k
- converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = linear_sample_v
- converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = linear_sample_k_bias
- converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = linear_sample_v_bias
-
- # Output Projections
- converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = checkpoint.pop(
- f"blocks.{i}.cross_attn.proj.weight"
- )
- converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = checkpoint.pop(
- f"blocks.{i}.cross_attn.proj.bias"
- )
-
- # MLP
- converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.weight"] = checkpoint.pop(
- f"blocks.{i}.mlp.inverted_conv.conv.weight"
- )
- converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.bias"] = checkpoint.pop(
- f"blocks.{i}.mlp.inverted_conv.conv.bias"
- )
- converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.weight"] = checkpoint.pop(
- f"blocks.{i}.mlp.depth_conv.conv.weight"
- )
- converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.bias"] = checkpoint.pop(
- f"blocks.{i}.mlp.depth_conv.conv.bias"
- )
- converted_state_dict[f"transformer_blocks.{i}.ff.conv_point.weight"] = checkpoint.pop(
- f"blocks.{i}.mlp.point_conv.conv.weight"
- )
-
- # Final layer
- converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
- converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
- converted_state_dict["scale_shift_table"] = checkpoint.pop("final_layer.scale_shift_table")
-
- return converted_state_dict
+    from .single_file.single_file_utils import convert_sana_transformer_to_diffusers
+
+    deprecation_message = "Importing `convert_sana_transformer_to_diffusers()` from diffusers.loaders.single_file_utils has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_sana_transformer_to_diffusers` instead."
+    deprecate("diffusers.loaders.single_file_utils.convert_sana_transformer_to_diffusers", "0.36", deprecation_message)
+    return convert_sana_transformer_to_diffusers(checkpoint, **kwargs)
def convert_wan_transformer_to_diffusers(checkpoint, **kwargs):
- converted_state_dict = {}
-
- keys = list(checkpoint.keys())
- for k in keys:
- if "model.diffusion_model." in k:
- checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
-
- TRANSFORMER_KEYS_RENAME_DICT = {
- "time_embedding.0": "condition_embedder.time_embedder.linear_1",
- "time_embedding.2": "condition_embedder.time_embedder.linear_2",
- "text_embedding.0": "condition_embedder.text_embedder.linear_1",
- "text_embedding.2": "condition_embedder.text_embedder.linear_2",
- "time_projection.1": "condition_embedder.time_proj",
- "cross_attn": "attn2",
- "self_attn": "attn1",
- ".o.": ".to_out.0.",
- ".q.": ".to_q.",
- ".k.": ".to_k.",
- ".v.": ".to_v.",
- ".k_img.": ".add_k_proj.",
- ".v_img.": ".add_v_proj.",
- ".norm_k_img.": ".norm_added_k.",
- "head.modulation": "scale_shift_table",
- "head.head": "proj_out",
- "modulation": "scale_shift_table",
- "ffn.0": "ffn.net.0.proj",
- "ffn.2": "ffn.net.2",
- # Hack to swap the layer names
- # The original model calls the norms in following order: norm1, norm3, norm2
- # We convert it to: norm1, norm2, norm3
- "norm2": "norm__placeholder",
- "norm3": "norm2",
- "norm__placeholder": "norm3",
- # For the I2V model
- "img_emb.proj.0": "condition_embedder.image_embedder.norm1",
- "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
- "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
- "img_emb.proj.4": "condition_embedder.image_embedder.norm2",
- }
-
- for key in list(checkpoint.keys()):
- new_key = key[:]
- for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
- new_key = new_key.replace(replace_key, rename_key)
-
- converted_state_dict[new_key] = checkpoint.pop(key)
-
- return converted_state_dict
+    from .single_file.single_file_utils import convert_wan_transformer_to_diffusers
+
+    deprecation_message = "Importing `convert_wan_transformer_to_diffusers()` from diffusers.loaders.single_file_utils has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_wan_transformer_to_diffusers` instead."
+    deprecate("diffusers.loaders.single_file_utils.convert_wan_transformer_to_diffusers", "0.36", deprecation_message)
+    return convert_wan_transformer_to_diffusers(checkpoint, **kwargs)
def convert_wan_vae_to_diffusers(checkpoint, **kwargs):
- converted_state_dict = {}
-
- # Create mappings for specific components
- middle_key_mapping = {
- # Encoder middle block
- "encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma",
- "encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias",
- "encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight",
- "encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma",
- "encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias",
- "encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight",
- "encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma",
- "encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias",
- "encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight",
- "encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma",
- "encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias",
- "encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight",
- # Decoder middle block
- "decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma",
- "decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias",
- "decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight",
- "decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma",
- "decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias",
- "decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight",
- "decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma",
- "decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias",
- "decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight",
- "decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma",
- "decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias",
- "decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight",
- }
-
- # Create a mapping for attention blocks
- attention_mapping = {
- # Encoder middle attention
- "encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma",
- "encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight",
- "encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias",
- "encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight",
- "encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias",
- # Decoder middle attention
- "decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma",
- "decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight",
- "decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias",
- "decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight",
- "decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias",
- }
-
- # Create a mapping for the head components
- head_mapping = {
- # Encoder head
- "encoder.head.0.gamma": "encoder.norm_out.gamma",
- "encoder.head.2.bias": "encoder.conv_out.bias",
- "encoder.head.2.weight": "encoder.conv_out.weight",
- # Decoder head
- "decoder.head.0.gamma": "decoder.norm_out.gamma",
- "decoder.head.2.bias": "decoder.conv_out.bias",
- "decoder.head.2.weight": "decoder.conv_out.weight",
- }
-
- # Create a mapping for the quant components
- quant_mapping = {
- "conv1.weight": "quant_conv.weight",
- "conv1.bias": "quant_conv.bias",
- "conv2.weight": "post_quant_conv.weight",
- "conv2.bias": "post_quant_conv.bias",
- }
-
- # Process each key in the state dict
- for key, value in checkpoint.items():
- # Handle middle block keys using the mapping
- if key in middle_key_mapping:
- new_key = middle_key_mapping[key]
- converted_state_dict[new_key] = value
- # Handle attention blocks using the mapping
- elif key in attention_mapping:
- new_key = attention_mapping[key]
- converted_state_dict[new_key] = value
- # Handle head keys using the mapping
- elif key in head_mapping:
- new_key = head_mapping[key]
- converted_state_dict[new_key] = value
- # Handle quant keys using the mapping
- elif key in quant_mapping:
- new_key = quant_mapping[key]
- converted_state_dict[new_key] = value
- # Handle encoder conv1
- elif key == "encoder.conv1.weight":
- converted_state_dict["encoder.conv_in.weight"] = value
- elif key == "encoder.conv1.bias":
- converted_state_dict["encoder.conv_in.bias"] = value
- # Handle decoder conv1
- elif key == "decoder.conv1.weight":
- converted_state_dict["decoder.conv_in.weight"] = value
- elif key == "decoder.conv1.bias":
- converted_state_dict["decoder.conv_in.bias"] = value
- # Handle encoder downsamples
- elif key.startswith("encoder.downsamples."):
- # Convert to down_blocks
- new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.")
-
- # Convert residual block naming but keep the original structure
- if ".residual.0.gamma" in new_key:
- new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
- elif ".residual.2.bias" in new_key:
- new_key = new_key.replace(".residual.2.bias", ".conv1.bias")
- elif ".residual.2.weight" in new_key:
- new_key = new_key.replace(".residual.2.weight", ".conv1.weight")
- elif ".residual.3.gamma" in new_key:
- new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma")
- elif ".residual.6.bias" in new_key:
- new_key = new_key.replace(".residual.6.bias", ".conv2.bias")
- elif ".residual.6.weight" in new_key:
- new_key = new_key.replace(".residual.6.weight", ".conv2.weight")
- elif ".shortcut.bias" in new_key:
- new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")
- elif ".shortcut.weight" in new_key:
- new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")
-
- converted_state_dict[new_key] = value
-
- # Handle decoder upsamples
- elif key.startswith("decoder.upsamples."):
- # Convert to up_blocks
- parts = key.split(".")
- block_idx = int(parts[2])
-
- # Group residual blocks
- if "residual" in key:
- if block_idx in [0, 1, 2]:
- new_block_idx = 0
- resnet_idx = block_idx
- elif block_idx in [4, 5, 6]:
- new_block_idx = 1
- resnet_idx = block_idx - 4
- elif block_idx in [8, 9, 10]:
- new_block_idx = 2
- resnet_idx = block_idx - 8
- elif block_idx in [12, 13, 14]:
- new_block_idx = 3
- resnet_idx = block_idx - 12
- else:
- # Keep as is for other blocks
- converted_state_dict[key] = value
- continue
-
- # Convert residual block naming
- if ".residual.0.gamma" in key:
- new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm1.gamma"
- elif ".residual.2.bias" in key:
- new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.bias"
- elif ".residual.2.weight" in key:
- new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.weight"
- elif ".residual.3.gamma" in key:
- new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm2.gamma"
- elif ".residual.6.bias" in key:
- new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.bias"
- elif ".residual.6.weight" in key:
- new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.weight"
- else:
- new_key = key
-
- converted_state_dict[new_key] = value
-
- # Handle shortcut connections
- elif ".shortcut." in key:
- if block_idx == 4:
- new_key = key.replace(".shortcut.", ".resnets.0.conv_shortcut.")
- new_key = new_key.replace("decoder.upsamples.4", "decoder.up_blocks.1")
- else:
- new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
- new_key = new_key.replace(".shortcut.", ".conv_shortcut.")
-
- converted_state_dict[new_key] = value
-
- # Handle upsamplers
- elif ".resample." in key or ".time_conv." in key:
- if block_idx == 3:
- new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.0.upsamplers.0")
- elif block_idx == 7:
- new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.1.upsamplers.0")
- elif block_idx == 11:
- new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.2.upsamplers.0")
- else:
- new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
-
- converted_state_dict[new_key] = value
- else:
- new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
- converted_state_dict[new_key] = value
- else:
- # Keep other keys unchanged
- converted_state_dict[key] = value
-
- return converted_state_dict
+    from .single_file.single_file_utils import convert_wan_vae_to_diffusers
+
+    deprecation_message = "Importing `convert_wan_vae_to_diffusers()` from diffusers.loaders.single_file_utils has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_wan_vae_to_diffusers` instead."
+    deprecate("diffusers.loaders.single_file_utils.convert_wan_vae_to_diffusers", "0.36", deprecation_message)
+    return convert_wan_vae_to_diffusers(checkpoint, **kwargs)
diff --git a/src/diffusers/loaders/transformer_flux.py b/src/diffusers/loaders/transformer_flux.py
index 38a8a7ebe266..8487b0652feb 100644
--- a/src/diffusers/loaders/transformer_flux.py
+++ b/src/diffusers/loaders/transformer_flux.py
@@ -11,170 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from contextlib import nullcontext
-from ..models.embeddings import (
- ImageProjection,
- MultiIPAdapterImageProjection,
-)
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
-from ..utils import (
- is_accelerate_available,
- is_torch_version,
- logging,
-)
+from ..utils import deprecate
+from .ip_adapter.transformer_flux import FluxTransformer2DLoadersMixin
-if is_accelerate_available():
- pass
-
-logger = logging.get_logger(__name__)
-
-
-class FluxTransformer2DLoadersMixin:
- """
- Load layers into a [`FluxTransformer2DModel`].
- """
-
- def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
- if low_cpu_mem_usage:
- if is_accelerate_available():
- from accelerate import init_empty_weights
-
- else:
- low_cpu_mem_usage = False
- logger.warning(
- "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
- " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
- " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
- " install accelerate\n```\n."
- )
-
- if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
- raise NotImplementedError(
- "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
- " `low_cpu_mem_usage=False`."
- )
-
- updated_state_dict = {}
- image_projection = None
- init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
-
- if "proj.weight" in state_dict:
- # IP-Adapter
- num_image_text_embeds = 4
- if state_dict["proj.weight"].shape[0] == 65536:
- num_image_text_embeds = 16
- clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
- cross_attention_dim = state_dict["proj.weight"].shape[0] // num_image_text_embeds
-
- with init_context():
- image_projection = ImageProjection(
- cross_attention_dim=cross_attention_dim,
- image_embed_dim=clip_embeddings_dim,
- num_image_text_embeds=num_image_text_embeds,
- )
-
- for key, value in state_dict.items():
- diffusers_name = key.replace("proj", "image_embeds")
- updated_state_dict[diffusers_name] = value
-
- if not low_cpu_mem_usage:
- image_projection.load_state_dict(updated_state_dict, strict=True)
- else:
- device_map = {"": self.device}
- load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
-
- return image_projection
-
- def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
- from ..models.attention_processor import (
- FluxIPAdapterJointAttnProcessor2_0,
- )
-
- if low_cpu_mem_usage:
- if is_accelerate_available():
- from accelerate import init_empty_weights
-
- else:
- low_cpu_mem_usage = False
- logger.warning(
- "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
- " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
- " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
- " install accelerate\n```\n."
- )
-
- if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
- raise NotImplementedError(
- "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
- " `low_cpu_mem_usage=False`."
- )
-
- # set ip-adapter cross-attention processors & load state_dict
- attn_procs = {}
- key_id = 0
- init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
- for name in self.attn_processors.keys():
- if name.startswith("single_transformer_blocks"):
- attn_processor_class = self.attn_processors[name].__class__
- attn_procs[name] = attn_processor_class()
- else:
- cross_attention_dim = self.config.joint_attention_dim
- hidden_size = self.inner_dim
- attn_processor_class = FluxIPAdapterJointAttnProcessor2_0
- num_image_text_embeds = []
- for state_dict in state_dicts:
- if "proj.weight" in state_dict["image_proj"]:
- num_image_text_embed = 4
- if state_dict["image_proj"]["proj.weight"].shape[0] == 65536:
- num_image_text_embed = 16
- # IP-Adapter
- num_image_text_embeds += [num_image_text_embed]
-
- with init_context():
- attn_procs[name] = attn_processor_class(
- hidden_size=hidden_size,
- cross_attention_dim=cross_attention_dim,
- scale=1.0,
- num_tokens=num_image_text_embeds,
- dtype=self.dtype,
- device=self.device,
- )
-
- value_dict = {}
- for i, state_dict in enumerate(state_dicts):
- value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
- value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
- value_dict.update({f"to_k_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_k_ip.bias"]})
- value_dict.update({f"to_v_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_v_ip.bias"]})
-
- if not low_cpu_mem_usage:
- attn_procs[name].load_state_dict(value_dict)
- else:
- device_map = {"": self.device}
- dtype = self.dtype
- load_model_dict_into_meta(attn_procs[name], value_dict, device_map=device_map, dtype=dtype)
-
- key_id += 1
-
- return attn_procs
-
- def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
- if not isinstance(state_dicts, list):
- state_dicts = [state_dicts]
-
- self.encoder_hid_proj = None
-
- attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
- self.set_attn_processor(attn_procs)
-
- image_projection_layers = []
- for state_dict in state_dicts:
- image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(
- state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage
- )
- image_projection_layers.append(image_projection_layer)
-
- self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
- self.config.encoder_hid_dim_type = "ip_image_proj"
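+# Deprecation shim: `FluxTransformer2DLoadersMixin` now lives in loaders/ip_adapter/transformer_flux.py;
+# the subclass below only emits a deprecation warning on construction before deferring to the new implementation.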
+class FluxTransformer2DLoadersMixin(FluxTransformer2DLoadersMixin):
+ def __init__(self, *args, **kwargs):
+        deprecation_message = "Importing `FluxTransformer2DLoadersMixin` from diffusers.loaders.transformer_flux has been deprecated. Please use `from diffusers.loaders.ip_adapter.transformer_flux import FluxTransformer2DLoadersMixin` instead."
+        deprecate("diffusers.loaders.transformer_flux.FluxTransformer2DLoadersMixin", "0.36", deprecation_message)
+ super().__init__(*args, **kwargs)
diff --git a/src/diffusers/loaders/transformer_sd3.py b/src/diffusers/loaders/transformer_sd3.py
index ece17e6728fa..e2e4c9fb67a3 100644
--- a/src/diffusers/loaders/transformer_sd3.py
+++ b/src/diffusers/loaders/transformer_sd3.py
@@ -11,160 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from contextlib import nullcontext
-from typing import Dict
+from ..utils import deprecate
+from .ip_adapter.transformer_sd3 import SD3Transformer2DLoadersMixin
-from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
-from ..models.embeddings import IPAdapterTimeImageProjection
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
-from ..utils import is_accelerate_available, is_torch_version, logging
-
-logger = logging.get_logger(__name__)
-
-
-class SD3Transformer2DLoadersMixin:
- """Load IP-Adapters and LoRA layers into a `[SD3Transformer2DModel]`."""
-
- def _convert_ip_adapter_attn_to_diffusers(
- self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT
- ) -> Dict:
- if low_cpu_mem_usage:
- if is_accelerate_available():
- from accelerate import init_empty_weights
-
- else:
- low_cpu_mem_usage = False
- logger.warning(
- "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
- " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
- " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
- " install accelerate\n```\n."
- )
-
- if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
- raise NotImplementedError(
- "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
- " `low_cpu_mem_usage=False`."
- )
-
- # IP-Adapter cross attention parameters
- hidden_size = self.config.attention_head_dim * self.config.num_attention_heads
- ip_hidden_states_dim = self.config.attention_head_dim * self.config.num_attention_heads
- timesteps_emb_dim = state_dict["0.norm_ip.linear.weight"].shape[1]
-
- # Dict where key is transformer layer index, value is attention processor's state dict
- # ip_adapter state dict keys example: "0.norm_ip.linear.weight"
- layer_state_dict = {idx: {} for idx in range(len(self.attn_processors))}
- for key, weights in state_dict.items():
- idx, name = key.split(".", maxsplit=1)
- layer_state_dict[int(idx)][name] = weights
-
- # Create IP-Adapter attention processor & load state_dict
- attn_procs = {}
- init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
- for idx, name in enumerate(self.attn_processors.keys()):
- with init_context():
- attn_procs[name] = SD3IPAdapterJointAttnProcessor2_0(
- hidden_size=hidden_size,
- ip_hidden_states_dim=ip_hidden_states_dim,
- head_dim=self.config.attention_head_dim,
- timesteps_emb_dim=timesteps_emb_dim,
- )
-
- if not low_cpu_mem_usage:
- attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True)
- else:
- device_map = {"": self.device}
- load_model_dict_into_meta(
- attn_procs[name], layer_state_dict[idx], device_map=device_map, dtype=self.dtype
- )
-
- return attn_procs
-
- def _convert_ip_adapter_image_proj_to_diffusers(
- self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT
- ) -> IPAdapterTimeImageProjection:
- if low_cpu_mem_usage:
- if is_accelerate_available():
- from accelerate import init_empty_weights
-
- else:
- low_cpu_mem_usage = False
- logger.warning(
- "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
- " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
- " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
- " install accelerate\n```\n."
- )
-
- if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
- raise NotImplementedError(
- "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
- " `low_cpu_mem_usage=False`."
- )
-
- init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
-
- # Convert to diffusers
- updated_state_dict = {}
- for key, value in state_dict.items():
- # InstantX/SD3.5-Large-IP-Adapter
- if key.startswith("layers."):
- idx = key.split(".")[1]
- key = key.replace(f"layers.{idx}.0.norm1", f"layers.{idx}.ln0")
- key = key.replace(f"layers.{idx}.0.norm2", f"layers.{idx}.ln1")
- key = key.replace(f"layers.{idx}.0.to_q", f"layers.{idx}.attn.to_q")
- key = key.replace(f"layers.{idx}.0.to_kv", f"layers.{idx}.attn.to_kv")
- key = key.replace(f"layers.{idx}.0.to_out", f"layers.{idx}.attn.to_out.0")
- key = key.replace(f"layers.{idx}.1.0", f"layers.{idx}.adaln_norm")
- key = key.replace(f"layers.{idx}.1.1", f"layers.{idx}.ff.net.0.proj")
- key = key.replace(f"layers.{idx}.1.3", f"layers.{idx}.ff.net.2")
- key = key.replace(f"layers.{idx}.2.1", f"layers.{idx}.adaln_proj")
- updated_state_dict[key] = value
-
- # Image projetion parameters
- embed_dim = updated_state_dict["proj_in.weight"].shape[1]
- output_dim = updated_state_dict["proj_out.weight"].shape[0]
- hidden_dim = updated_state_dict["proj_in.weight"].shape[0]
- heads = updated_state_dict["layers.0.attn.to_q.weight"].shape[0] // 64
- num_queries = updated_state_dict["latents"].shape[1]
- timestep_in_dim = updated_state_dict["time_embedding.linear_1.weight"].shape[1]
-
- # Image projection
- with init_context():
- image_proj = IPAdapterTimeImageProjection(
- embed_dim=embed_dim,
- output_dim=output_dim,
- hidden_dim=hidden_dim,
- heads=heads,
- num_queries=num_queries,
- timestep_in_dim=timestep_in_dim,
- )
-
- if not low_cpu_mem_usage:
- image_proj.load_state_dict(updated_state_dict, strict=True)
- else:
- device_map = {"": self.device}
- load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)
-
- return image_proj
-
- def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT) -> None:
- """Sets IP-Adapter attention processors, image projection, and loads state_dict.
-
- Args:
- state_dict (`Dict`):
- State dict with keys "ip_adapter", which contains parameters for attention processors, and
- "image_proj", which contains parameters for image projection net.
- low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
- Speed up model loading only loading the pretrained weights and not initializing the weights. This also
- tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
- Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
- argument to `True` will raise an error.
- """
-
- attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dict["ip_adapter"], low_cpu_mem_usage)
- self.set_attn_processor(attn_procs)
-
- self.image_proj = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"], low_cpu_mem_usage)
+class SD3Transformer2DLoadersMixin(SD3Transformer2DLoadersMixin):
+ def __init__(self, *args, **kwargs):
+        deprecation_message = "Importing `SD3Transformer2DLoadersMixin` from diffusers.loaders.transformer_sd3 has been deprecated. Please use `from diffusers.loaders.ip_adapter.transformer_sd3 import SD3Transformer2DLoadersMixin` instead."
+        deprecate("diffusers.loaders.transformer_sd3.SD3Transformer2DLoadersMixin", "0.36", deprecation_message)
+ super().__init__(*args, **kwargs)
diff --git a/src/diffusers/loaders/unet/__init__.py b/src/diffusers/loaders/unet/__init__.py
new file mode 100644
index 000000000000..700ecbe53642
--- /dev/null
+++ b/src/diffusers/loaders/unet/__init__.py
@@ -0,0 +1,5 @@
+from ...utils import is_torch_available
+
+
+if is_torch_available():
+ from .unet import UNet2DConditionLoadersMixin
diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet/unet.py
similarity index 98%
rename from src/diffusers/loaders/unet.py
rename to src/diffusers/loaders/unet/unet.py
index 1d8aba900c85..a632ceac210f 100644
--- a/src/diffusers/loaders/unet.py
+++ b/src/diffusers/loaders/unet/unet.py
@@ -22,7 +22,7 @@
import torch.nn.functional as F
from huggingface_hub.utils import validate_hf_hub_args
-from ..models.embeddings import (
+from ...models.embeddings import (
ImageProjection,
IPAdapterFaceIDImageProjection,
IPAdapterFaceIDPlusImageProjection,
@@ -30,8 +30,8 @@
IPAdapterPlusImageProjection,
MultiIPAdapterImageProjection,
)
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict
-from ..utils import (
+from ...models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict
+from ...utils import (
USE_PEFT_BACKEND,
_get_model_file,
convert_unet_state_dict_to_peft,
@@ -43,9 +43,9 @@
is_torch_version,
logging,
)
-from .lora_base import _func_optionally_disable_offloading
-from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
-from .utils import AttnProcsLayers
+from ..lora.lora_base import _func_optionally_disable_offloading
+from ..lora.lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
+from ..utils import AttnProcsLayers
logger = logging.get_logger(__name__)
@@ -247,7 +247,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
# Unsafe code />
def _process_custom_diffusion(self, state_dict):
- from ..models.attention_processor import CustomDiffusionAttnProcessor
+ from ...models.attention_processor import CustomDiffusionAttnProcessor
attn_processors = {}
custom_diffusion_grouped_dict = defaultdict(dict)
@@ -395,7 +395,7 @@ def _process_lora(
return is_model_cpu_offload, is_sequential_cpu_offload
@classmethod
- # Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
+ # Copied from diffusers.loaders.lora.lora_base.LoraBaseMixin._optionally_disable_offloading
def _optionally_disable_offloading(cls, _pipeline):
"""
Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
@@ -451,7 +451,7 @@ def save_attn_procs(
pipeline.unet.save_attn_procs("path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin")
```
"""
- from ..models.attention_processor import (
+ from ...models.attention_processor import (
CustomDiffusionAttnProcessor,
CustomDiffusionAttnProcessor2_0,
CustomDiffusionXFormersAttnProcessor,
@@ -513,7 +513,7 @@ def save_function(weights, filename):
logger.info(f"Model weights saved in {save_path}")
def _get_custom_diffusion_state_dict(self):
- from ..models.attention_processor import (
+ from ...models.attention_processor import (
CustomDiffusionAttnProcessor,
CustomDiffusionAttnProcessor2_0,
CustomDiffusionXFormersAttnProcessor,
@@ -759,7 +759,7 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
return image_projection
def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
- from ..models.attention_processor import (
+ from ...models.attention_processor import (
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
IPAdapterXFormersAttnProcessor,
diff --git a/src/diffusers/loaders/unet_loader_utils.py b/src/diffusers/loaders/unet/unet_loader_utils.py
similarity index 98%
rename from src/diffusers/loaders/unet_loader_utils.py
rename to src/diffusers/loaders/unet/unet_loader_utils.py
index 8f202ed4d44b..bfa935d13be8 100644
--- a/src/diffusers/loaders/unet_loader_utils.py
+++ b/src/diffusers/loaders/unet/unet_loader_utils.py
@@ -14,12 +14,12 @@
import copy
from typing import TYPE_CHECKING, Dict, List, Union
-from ..utils import logging
+from ...utils import logging
if TYPE_CHECKING:
# import here to avoid circular imports
- from ..models import UNet2DConditionModel
+ from ...models import UNet2DConditionModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py
index 357df0c31087..119409d4c02c 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl.py
@@ -17,8 +17,7 @@
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders import PeftAdapterMixin
-from ...loaders.single_file_model import FromOriginalModelMixin
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import deprecate
from ...utils.accelerate_utils import apply_forward_hook
from ..attention_processor import (
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
index e2b26396899f..0f4be88c92d3 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
@@ -21,7 +21,7 @@
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders.single_file_model import FromOriginalModelMixin
+from ...loaders import FromOriginalModelMixin
from ...utils import logging
from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
diff --git a/src/diffusers/models/controlnets/controlnet.py b/src/diffusers/models/controlnets/controlnet.py
index 7a6ca886caed..d1bf033459f0 100644
--- a/src/diffusers/models/controlnets/controlnet.py
+++ b/src/diffusers/models/controlnets/controlnet.py
@@ -19,7 +19,7 @@
from torch.nn import functional as F
from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders.single_file_model import FromOriginalModelMixin
+from ...loaders import FromOriginalModelMixin
from ...utils import BaseOutput, logging
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
diff --git a/src/diffusers/models/controlnets/controlnet_union.py b/src/diffusers/models/controlnets/controlnet_union.py
index 26cb86718a21..0716934a7d8e 100644
--- a/src/diffusers/models/controlnets/controlnet_union.py
+++ b/src/diffusers/models/controlnets/controlnet_union.py
@@ -17,7 +17,7 @@
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders.single_file_model import FromOriginalModelMixin
+from ...loaders import FromOriginalModelMixin
from ...utils import logging
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
diff --git a/src/diffusers/models/transformers/transformer_lumina2.py b/src/diffusers/models/transformers/transformer_lumina2.py
index a873a6ec9444..25420394bdcd 100644
--- a/src/diffusers/models/transformers/transformer_lumina2.py
+++ b/src/diffusers/models/transformers/transformer_lumina2.py
@@ -20,8 +20,7 @@
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders import PeftAdapterMixin
-from ...loaders.single_file_model import FromOriginalModelMixin
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..attention import LuminaFeedForward
from ..attention_processor import Attention
diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py
index e6532f080d72..964c44ae4db4 100644
--- a/src/diffusers/models/transformers/transformer_mochi.py
+++ b/src/diffusers/models/transformers/transformer_mochi.py
@@ -19,8 +19,7 @@
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders import PeftAdapterMixin
-from ...loaders.single_file_model import FromOriginalModelMixin
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py
index 5674d8ba26ec..a35aa1a5b32c 100644
--- a/src/diffusers/models/unets/unet_2d_condition.py
+++ b/src/diffusers/models/unets/unet_2d_condition.py
@@ -19,8 +19,7 @@
import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
-from ...loaders.single_file_model import FromOriginalModelMixin
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ..activations import get_activation
from ..attention_processor import (
diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py
index 860aa6511689..7124fcb41200 100644
--- a/tests/lora/test_lora_layers_flux.py
+++ b/tests/lora/test_lora_layers_flux.py
@@ -352,7 +352,7 @@ def test_with_norm_in_state_dict(self):
_, _, inputs = self.get_dummy_inputs(with_generator=False)
- logger = logging.get_logger("diffusers.loaders.lora_pipeline")
+ logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline")
logger.setLevel(logging.INFO)
original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -403,7 +403,7 @@ def test_lora_parameter_expanded_shapes(self):
_, _, inputs = self.get_dummy_inputs(with_generator=False)
original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
- logger = logging.get_logger("diffusers.loaders.lora_pipeline")
+ logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline")
logger.setLevel(logging.DEBUG)
# Change the transformer config to mimic a real use case.
@@ -486,7 +486,7 @@ def test_normal_lora_with_expanded_lora_raises_error(self):
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
- logger = logging.get_logger("diffusers.loaders.lora_pipeline")
+ logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline")
logger.setLevel(logging.DEBUG)
out_features, in_features = pipe.transformer.x_embedder.weight.shape
@@ -541,7 +541,7 @@ def test_normal_lora_with_expanded_lora_raises_error(self):
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
- logger = logging.get_logger("diffusers.loaders.lora_pipeline")
+ logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline")
logger.setLevel(logging.DEBUG)
out_features, in_features = pipe.transformer.x_embedder.weight.shape
@@ -590,7 +590,7 @@ def test_fuse_expanded_lora_with_regular_lora(self):
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
- logger = logging.get_logger("diffusers.loaders.lora_pipeline")
+ logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline")
logger.setLevel(logging.DEBUG)
out_features, in_features = pipe.transformer.x_embedder.weight.shape
@@ -653,7 +653,7 @@ def test_load_regular_lora(self):
"transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
}
- logger = logging.get_logger("diffusers.loaders.lora_pipeline")
+ logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline")
logger.setLevel(logging.INFO)
with CaptureLogger(logger) as cap_logger:
pipe.load_lora_weights(lora_state_dict, "adapter-1")
@@ -668,7 +668,7 @@ def test_load_regular_lora(self):
def test_lora_unload_with_parameter_expanded_shapes(self):
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
- logger = logging.get_logger("diffusers.loaders.lora_pipeline")
+ logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline")
logger.setLevel(logging.DEBUG)
# Change the transformer config to mimic a real use case.
@@ -734,7 +734,7 @@ def test_lora_unload_with_parameter_expanded_shapes(self):
def test_lora_unload_with_parameter_expanded_shapes_and_no_reset(self):
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
- logger = logging.get_logger("diffusers.loaders.lora_pipeline")
+ logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline")
logger.setLevel(logging.DEBUG)
# Change the transformer config to mimic a real use case.
diff --git a/tests/lora/utils.py b/tests/lora/utils.py
index 87a8fddfa583..fe132944a835 100644
--- a/tests/lora/utils.py
+++ b/tests/lora/utils.py
@@ -1017,7 +1017,7 @@ def test_multiple_wrong_adapter_name_raises_error(self):
)
scale_with_wrong_components = {"foo": 0.0, "bar": 0.0, "tik": 0.0}
- logger = logging.get_logger("diffusers.loaders.lora_base")
+ logger = logging.get_logger("diffusers.loaders.lora.lora_base")
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
pipe.set_adapters(adapter_name, adapter_weights=scale_with_wrong_components)
@@ -1824,7 +1824,7 @@ def test_logs_info_when_no_lora_keys_found(self):
elif lora_module == "text_encoder_2":
prefix = "text_encoder_2"
- logger = logging.get_logger("diffusers.loaders.lora_base")
+ logger = logging.get_logger("diffusers.loaders.lora.lora_base")
logger.setLevel(logging.WARNING)
with CaptureLogger(logger) as cap_logger:
@@ -1925,7 +1925,7 @@ def test_lora_B_bias(self):
_, _, inputs = self.get_dummy_inputs(with_generator=False)
- logger = logging.get_logger("diffusers.loaders.lora_pipeline")
+ logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline")
logger.setLevel(logging.INFO)
original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
diff --git a/tests/single_file/single_file_testing_utils.py b/tests/single_file/single_file_testing_utils.py
index 4e1713c9ceb1..10c229ba5e64 100644
--- a/tests/single_file/single_file_testing_utils.py
+++ b/tests/single_file/single_file_testing_utils.py
@@ -5,7 +5,7 @@
import torch
from huggingface_hub import hf_hub_download, snapshot_download
-from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
+from diffusers.loaders.single_file.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.models.attention_processor import AttnProcessor
from diffusers.utils.testing_utils import (
numpy_cosine_similarity_distance,
diff --git a/tests/single_file/test_model_vae_single_file.py b/tests/single_file/test_model_vae_single_file.py
index bba1726ae380..350a433cf68b 100644
--- a/tests/single_file/test_model_vae_single_file.py
+++ b/tests/single_file/test_model_vae_single_file.py
@@ -18,9 +18,7 @@
import torch
-from diffusers import (
- AutoencoderKL,
-)
+from diffusers import AutoencoderKL
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
diff --git a/tests/single_file/test_model_wan_autoencoder_single_file.py b/tests/single_file/test_model_wan_autoencoder_single_file.py
index 7f0e1c1a4b0b..83acbdc56504 100644
--- a/tests/single_file/test_model_wan_autoencoder_single_file.py
+++ b/tests/single_file/test_model_wan_autoencoder_single_file.py
@@ -16,9 +16,7 @@
import gc
import unittest
-from diffusers import (
- AutoencoderKLWan,
-)
+from diffusers import AutoencoderKLWan
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
diff --git a/tests/single_file/test_model_wan_transformer3d_single_file.py b/tests/single_file/test_model_wan_transformer3d_single_file.py
index 36f0919cacb5..8a20135a8aab 100644
--- a/tests/single_file/test_model_wan_transformer3d_single_file.py
+++ b/tests/single_file/test_model_wan_transformer3d_single_file.py
@@ -18,9 +18,7 @@
import torch
-from diffusers import (
- WanTransformer3DModel,
-)
+from diffusers import WanTransformer3DModel
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
diff --git a/tests/single_file/test_sana_transformer.py b/tests/single_file/test_sana_transformer.py
index 802ca37abfc3..061a5698d33a 100644
--- a/tests/single_file/test_sana_transformer.py
+++ b/tests/single_file/test_sana_transformer.py
@@ -3,9 +3,7 @@
import torch
-from diffusers import (
- SanaTransformer2DModel,
-)
+from diffusers import SanaTransformer2DModel
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
diff --git a/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py b/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py
index 7589b48028c2..29972241151d 100644
--- a/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py
+++ b/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py
@@ -5,7 +5,7 @@
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
-from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
+from diffusers.loaders.single_file.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
backend_empty_cache,
diff --git a/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py b/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py
index 1555831db6db..166365a9a064 100644
--- a/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py
+++ b/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py
@@ -5,7 +5,7 @@
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline
-from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
+from diffusers.loaders.single_file.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
backend_empty_cache,
diff --git a/tests/single_file/test_stable_diffusion_controlnet_single_file.py b/tests/single_file/test_stable_diffusion_controlnet_single_file.py
index 2c1e414e5e36..35ce5daac250 100644
--- a/tests/single_file/test_stable_diffusion_controlnet_single_file.py
+++ b/tests/single_file/test_stable_diffusion_controlnet_single_file.py
@@ -5,7 +5,7 @@
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
-from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
+from diffusers.loaders.single_file.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
backend_empty_cache,
diff --git a/tests/single_file/test_stable_diffusion_single_file.py b/tests/single_file/test_stable_diffusion_single_file.py
index 78baeb94929c..338842e8d08e 100644
--- a/tests/single_file/test_stable_diffusion_single_file.py
+++ b/tests/single_file/test_stable_diffusion_single_file.py
@@ -5,7 +5,7 @@
import torch
from diffusers import EulerDiscreteScheduler, StableDiffusionInstructPix2PixPipeline, StableDiffusionPipeline
-from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
+from diffusers.loaders.single_file.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
backend_empty_cache,
diff --git a/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py b/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py
index fb5f8725b86e..06a4f0bf5b3a 100644
--- a/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py
+++ b/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py
@@ -8,7 +8,7 @@
StableDiffusionXLAdapterPipeline,
T2IAdapter,
)
-from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
+from diffusers.loaders.single_file.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
backend_empty_cache,
diff --git a/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py b/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py
index 6d8c4369e1e1..9ea2fcc29710 100644
--- a/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py
+++ b/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py
@@ -5,7 +5,7 @@
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
-from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
+from diffusers.loaders.single_file.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
backend_empty_cache,
diff --git a/utils/check_support_list.py b/utils/check_support_list.py
index 5d06f0cb92a3..6f548796893c 100644
--- a/utils/check_support_list.py
+++ b/utils/check_support_list.py
@@ -98,7 +98,7 @@ def check_documentation(doc_path, src_path, doc_regex, src_regex, exclude_condit
},
"LoRA Mixins": {
"doc_path": "docs/source/en/api/loaders/lora.md",
- "src_path": "src/diffusers/loaders/lora_pipeline.py",
+ "src_path": "src/diffusers/loaders/lora/lora_pipeline.py",
"doc_regex": r"\[\[autodoc\]\]\s([^\n]+)",
"src_regex": r"class\s+(\w+LoraLoaderMixin(?:\d*_?\d*))[:(]",
},