From 8267677a241f90d3aa7f15c9f86ec1814c56ccca Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 16 Apr 2025 12:02:06 +0530 Subject: [PATCH 01/26] start folderizing the loaders. --- src/diffusers/loaders/__init__.py | 17 ++++++------- src/diffusers/loaders/ip_adapter/__init__.py | 9 +++++++ .../loaders/{ => ip_adapter}/ip_adapter.py | 0 .../{ => ip_adapter}/transformer_flux.py | 20 ++++------------ .../{ => ip_adapter}/transformer_sd3.py | 8 +++---- src/diffusers/loaders/lora/__init__.py | 24 +++++++++++++++++++ src/diffusers/loaders/{ => lora}/lora_base.py | 0 .../{ => lora}/lora_conversion_utils.py | 0 .../loaders/{ => lora}/lora_pipeline.py | 0 src/diffusers/loaders/single_file/__init__.py | 8 +++++++ .../loaders/{ => single_file}/single_file.py | 2 +- .../{ => single_file}/single_file_model.py | 8 +++---- .../{ => single_file}/single_file_utils.py | 6 ++--- src/diffusers/loaders/unet/__init__.py | 5 ++++ src/diffusers/loaders/{ => unet}/unet.py | 20 ++++++++-------- .../loaders/{ => unet}/unet_loader_utils.py | 0 16 files changed, 79 insertions(+), 48 deletions(-) create mode 100644 src/diffusers/loaders/ip_adapter/__init__.py rename src/diffusers/loaders/{ => ip_adapter}/ip_adapter.py (100%) rename src/diffusers/loaders/{ => ip_adapter}/transformer_flux.py (94%) rename src/diffusers/loaders/{ => ip_adapter}/transformer_sd3.py (96%) create mode 100644 src/diffusers/loaders/lora/__init__.py rename src/diffusers/loaders/{ => lora}/lora_base.py (100%) rename src/diffusers/loaders/{ => lora}/lora_conversion_utils.py (100%) rename src/diffusers/loaders/{ => lora}/lora_pipeline.py (100%) create mode 100644 src/diffusers/loaders/single_file/__init__.py rename src/diffusers/loaders/{ => single_file}/single_file.py (99%) rename src/diffusers/loaders/{ => single_file}/single_file_model.py (98%) rename src/diffusers/loaders/{ => single_file}/single_file_utils.py (99%) create mode 100644 src/diffusers/loaders/unet/__init__.py rename src/diffusers/loaders/{ => unet}/unet.py (98%) rename src/diffusers/loaders/{ => unet}/unet_loader_utils.py (100%) diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index 7b440f6f4515..b4f3fb805496 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -54,7 +54,7 @@ def text_encoder_attn_modules(text_encoder): _import_structure = {} if is_torch_available(): - _import_structure["single_file_model"] = ["FromOriginalModelMixin"] + _import_structure["single_file.single_file_model"] = ["FromOriginalModelMixin"] _import_structure["transformer_flux"] = ["FluxTransformer2DLoadersMixin"] _import_structure["transformer_sd3"] = ["SD3Transformer2DLoadersMixin"] _import_structure["unet"] = ["UNet2DConditionLoadersMixin"] @@ -77,6 +77,7 @@ def text_encoder_attn_modules(text_encoder): "SanaLoraLoaderMixin", "Lumina2LoraLoaderMixin", "WanLoraLoaderMixin", + "LoraBaseMixin", ] _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"] _import_structure["ip_adapter"] = [ @@ -90,25 +91,21 @@ def text_encoder_attn_modules(text_encoder): if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: if is_torch_available(): - from .single_file_model import FromOriginalModelMixin - from .transformer_flux import FluxTransformer2DLoadersMixin - from .transformer_sd3 import SD3Transformer2DLoadersMixin + from .ip_adapter import FluxTransformer2DLoadersMixin, SD3Transformer2DLoadersMixin + from .single_file import FromOriginalModelMixin from .unet import UNet2DConditionLoadersMixin from .utils import AttnProcsLayers if 
is_transformers_available(): - from .ip_adapter import ( - FluxIPAdapterMixin, - IPAdapterMixin, - SD3IPAdapterMixin, - ) - from .lora_pipeline import ( + from .ip_adapter import FluxIPAdapterMixin, IPAdapterMixin, SD3IPAdapterMixin + from .lora import ( AmusedLoraLoaderMixin, AuraFlowLoraLoaderMixin, CogVideoXLoraLoaderMixin, CogView4LoraLoaderMixin, FluxLoraLoaderMixin, HunyuanVideoLoraLoaderMixin, + LoraBaseMixin, LoraLoaderMixin, LTXVideoLoraLoaderMixin, Lumina2LoraLoaderMixin, diff --git a/src/diffusers/loaders/ip_adapter/__init__.py b/src/diffusers/loaders/ip_adapter/__init__.py new file mode 100644 index 000000000000..42b1d6e1b509 --- /dev/null +++ b/src/diffusers/loaders/ip_adapter/__init__.py @@ -0,0 +1,9 @@ +from ...utils.import_utils import is_torch_available, is_transformers_available + + +if is_torch_available(): + from .transformer_flux import FluxTransformer2DLoadersMixin + from .transformer_sd3 import SD3Transformer2DLoadersMixin + + if is_transformers_available(): + from .ip_adapter import FluxIPAdapterMixin, IPAdapterMixin, SD3IPAdapterMixin diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter/ip_adapter.py similarity index 100% rename from src/diffusers/loaders/ip_adapter.py rename to src/diffusers/loaders/ip_adapter/ip_adapter.py diff --git a/src/diffusers/loaders/transformer_flux.py b/src/diffusers/loaders/ip_adapter/transformer_flux.py similarity index 94% rename from src/diffusers/loaders/transformer_flux.py rename to src/diffusers/loaders/ip_adapter/transformer_flux.py index 38a8a7ebe266..5d6e7c809b12 100644 --- a/src/diffusers/loaders/transformer_flux.py +++ b/src/diffusers/loaders/ip_adapter/transformer_flux.py @@ -13,21 +13,11 @@ # limitations under the License. from contextlib import nullcontext -from ..models.embeddings import ( - ImageProjection, - MultiIPAdapterImageProjection, -) -from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta -from ..utils import ( - is_accelerate_available, - is_torch_version, - logging, -) +from ...models.embeddings import ImageProjection, MultiIPAdapterImageProjection +from ...models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta +from ...utils import is_accelerate_available, is_torch_version, logging -if is_accelerate_available(): - pass - logger = logging.get_logger(__name__) @@ -88,9 +78,7 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us return image_projection def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT): - from ..models.attention_processor import ( - FluxIPAdapterJointAttnProcessor2_0, - ) + from ...models.attention_processor import FluxIPAdapterJointAttnProcessor2_0 if low_cpu_mem_usage: if is_accelerate_available(): diff --git a/src/diffusers/loaders/transformer_sd3.py b/src/diffusers/loaders/ip_adapter/transformer_sd3.py similarity index 96% rename from src/diffusers/loaders/transformer_sd3.py rename to src/diffusers/loaders/ip_adapter/transformer_sd3.py index ece17e6728fa..5911ec5903d5 100644 --- a/src/diffusers/loaders/transformer_sd3.py +++ b/src/diffusers/loaders/ip_adapter/transformer_sd3.py @@ -14,10 +14,10 @@ from contextlib import nullcontext from typing import Dict -from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0 -from ..models.embeddings import IPAdapterTimeImageProjection -from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta -from ..utils import 
is_accelerate_available, is_torch_version, logging +from ...models.attention_processor import SD3IPAdapterJointAttnProcessor2_0 +from ...models.embeddings import IPAdapterTimeImageProjection +from ...models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta +from ...utils import is_accelerate_available, is_torch_version, logging logger = logging.get_logger(__name__) diff --git a/src/diffusers/loaders/lora/__init__.py b/src/diffusers/loaders/lora/__init__.py new file mode 100644 index 000000000000..c7b1d9988e78 --- /dev/null +++ b/src/diffusers/loaders/lora/__init__.py @@ -0,0 +1,24 @@ +from ...utils import is_peft_available, is_torch_available, is_transformers_available + + +if is_torch_available(): + from .lora_base import LoraBaseMixin + + if is_transformers_available(): + from .lora_pipeline import ( + AmusedLoraLoaderMixin, + AuraFlowLoraLoaderMixin, + CogVideoXLoraLoaderMixin, + CogView4LoraLoaderMixin, + FluxLoraLoaderMixin, + HunyuanVideoLoraLoaderMixin, + LoraLoaderMixin, + LTXVideoLoraLoaderMixin, + Lumina2LoraLoaderMixin, + Mochi1LoraLoaderMixin, + SanaLoraLoaderMixin, + SD3LoraLoaderMixin, + StableDiffusionLoraLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + WanLoraLoaderMixin, + ) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora/lora_base.py similarity index 100% rename from src/diffusers/loaders/lora_base.py rename to src/diffusers/loaders/lora/lora_base.py diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora/lora_conversion_utils.py similarity index 100% rename from src/diffusers/loaders/lora_conversion_utils.py rename to src/diffusers/loaders/lora/lora_conversion_utils.py diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora/lora_pipeline.py similarity index 100% rename from src/diffusers/loaders/lora_pipeline.py rename to src/diffusers/loaders/lora/lora_pipeline.py diff --git a/src/diffusers/loaders/single_file/__init__.py b/src/diffusers/loaders/single_file/__init__.py new file mode 100644 index 000000000000..827685218160 --- /dev/null +++ b/src/diffusers/loaders/single_file/__init__.py @@ -0,0 +1,8 @@ +from ...utils import is_torch_available, is_transformers_available + + +if is_torch_available(): + from .single_file_model import FromOriginalModelMixin + + if is_transformers_available(): + from .single_file import FromSingleFileMixin diff --git a/src/diffusers/loaders/single_file.py b/src/diffusers/loaders/single_file/single_file.py similarity index 99% rename from src/diffusers/loaders/single_file.py rename to src/diffusers/loaders/single_file/single_file.py index c2843fc7406a..26d94e7935ef 100644 --- a/src/diffusers/loaders/single_file.py +++ b/src/diffusers/loaders/single_file/single_file.py @@ -21,7 +21,7 @@ from packaging import version from typing_extensions import Self -from ..utils import deprecate, is_transformers_available, logging +from ...utils import deprecate, is_transformers_available, logging from .single_file_utils import ( SingleFileComponentError, _is_legacy_scheduler_kwargs, diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file/single_file_model.py similarity index 98% rename from src/diffusers/loaders/single_file_model.py rename to src/diffusers/loaders/single_file/single_file_model.py index a2f27b765a1b..0b97d7fbfafd 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file/single_file_model.py @@ -21,9 +21,9 @@ from huggingface_hub.utils import 
validate_hf_hub_args from typing_extensions import Self -from .. import __version__ -from ..quantizers import DiffusersAutoQuantizer -from ..utils import deprecate, is_accelerate_available, logging +from ... import __version__ +from ...quantizers import DiffusersAutoQuantizer +from ...utils import deprecate, is_accelerate_available, logging from .single_file_utils import ( SingleFileComponentError, convert_animatediff_checkpoint_to_diffusers, @@ -58,7 +58,7 @@ if is_accelerate_available(): from accelerate import dispatch_model, init_empty_weights - from ..models.modeling_utils import load_model_dict_into_meta + from ...models.modeling_utils import load_model_dict_into_meta SINGLE_FILE_LOADABLE_CLASSES = { diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file/single_file_utils.py similarity index 99% rename from src/diffusers/loaders/single_file_utils.py rename to src/diffusers/loaders/single_file/single_file_utils.py index b55b1b55206e..0c13fadcfd43 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file/single_file_utils.py @@ -25,8 +25,8 @@ import torch import yaml -from ..models.modeling_utils import load_state_dict -from ..schedulers import ( +from ...models.modeling_utils import load_state_dict +from ...schedulers import ( DDIMScheduler, DPMSolverMultistepScheduler, EDMDPMSolverMultistepScheduler, @@ -54,7 +54,7 @@ if is_accelerate_available(): from accelerate import init_empty_weights - from ..models.modeling_utils import load_model_dict_into_meta + from ...models.modeling_utils import load_model_dict_into_meta logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/loaders/unet/__init__.py b/src/diffusers/loaders/unet/__init__.py new file mode 100644 index 000000000000..700ecbe53642 --- /dev/null +++ b/src/diffusers/loaders/unet/__init__.py @@ -0,0 +1,5 @@ +from ...utils import is_torch_available + + +if is_torch_available(): + from .unet import UNet2DConditionLoadersMixin diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet/unet.py similarity index 98% rename from src/diffusers/loaders/unet.py rename to src/diffusers/loaders/unet/unet.py index 1d8aba900c85..564316a2c51d 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet/unet.py @@ -22,7 +22,7 @@ import torch.nn.functional as F from huggingface_hub.utils import validate_hf_hub_args -from ..models.embeddings import ( +from ...models.embeddings import ( ImageProjection, IPAdapterFaceIDImageProjection, IPAdapterFaceIDPlusImageProjection, @@ -30,8 +30,8 @@ IPAdapterPlusImageProjection, MultiIPAdapterImageProjection, ) -from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict -from ..utils import ( +from ...models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict +from ...utils import ( USE_PEFT_BACKEND, _get_model_file, convert_unet_state_dict_to_peft, @@ -43,9 +43,9 @@ is_torch_version, logging, ) -from .lora_base import _func_optionally_disable_offloading -from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME -from .utils import AttnProcsLayers +from ..lora import _func_optionally_disable_offloading +from ..lora.lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME +from ..utils import AttnProcsLayers logger = logging.get_logger(__name__) @@ -247,7 +247,7 @@ def load_attn_procs(self, 
pretrained_model_name_or_path_or_dict: Union[str, Dict # Unsafe code /> def _process_custom_diffusion(self, state_dict): - from ..models.attention_processor import CustomDiffusionAttnProcessor + from ...models.attention_processor import CustomDiffusionAttnProcessor attn_processors = {} custom_diffusion_grouped_dict = defaultdict(dict) @@ -451,7 +451,7 @@ def save_attn_procs( pipeline.unet.save_attn_procs("path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin") ``` """ - from ..models.attention_processor import ( + from ...models.attention_processor import ( CustomDiffusionAttnProcessor, CustomDiffusionAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor, @@ -513,7 +513,7 @@ def save_function(weights, filename): logger.info(f"Model weights saved in {save_path}") def _get_custom_diffusion_state_dict(self): - from ..models.attention_processor import ( + from ...models.attention_processor import ( CustomDiffusionAttnProcessor, CustomDiffusionAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor, @@ -759,7 +759,7 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us return image_projection def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT): - from ..models.attention_processor import ( + from ...models.attention_processor import ( IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor, diff --git a/src/diffusers/loaders/unet_loader_utils.py b/src/diffusers/loaders/unet/unet_loader_utils.py similarity index 100% rename from src/diffusers/loaders/unet_loader_utils.py rename to src/diffusers/loaders/unet/unet_loader_utils.py From eb47a67d5012d018be08213ee3a6dd4ae60114ca Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 16 Apr 2025 12:18:30 +0530 Subject: [PATCH 02/26] fix --- src/diffusers/loaders/ip_adapter/ip_adapter.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/loaders/ip_adapter/ip_adapter.py b/src/diffusers/loaders/ip_adapter/ip_adapter.py index 025f52521485..782391d84a88 100644 --- a/src/diffusers/loaders/ip_adapter/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter/ip_adapter.py @@ -20,8 +20,8 @@ from huggingface_hub.utils import validate_hf_hub_args from safetensors import safe_open -from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict -from ..utils import ( +from ...models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict +from ...utils import ( USE_PEFT_BACKEND, _get_detailed_type, _get_model_file, @@ -31,13 +31,13 @@ is_transformers_available, logging, ) -from .unet_loader_utils import _maybe_expand_lora_scales +from ..unet.unet_loader_utils import _maybe_expand_lora_scales if is_transformers_available(): from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, SiglipImageProcessor, SiglipVisionModel -from ..models.attention_processor import ( +from ...models.attention_processor import ( AttnProcessor, AttnProcessor2_0, FluxAttnProcessor2_0, From a71334b86144588eac4fe72729529225a8d547e4 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 16 Apr 2025 12:22:12 +0530 Subject: [PATCH 03/26] fixes --- src/diffusers/loaders/lora/lora_pipeline.py | 114 ++++++++++---------- src/diffusers/loaders/peft.py | 2 +- src/diffusers/loaders/unet/unet.py | 2 +- 3 files changed, 59 insertions(+), 59 deletions(-) diff --git a/src/diffusers/loaders/lora/lora_pipeline.py b/src/diffusers/loaders/lora/lora_pipeline.py index aa508cf87f40..7a46be45984b 100644 --- 
a/src/diffusers/loaders/lora/lora_pipeline.py +++ b/src/diffusers/loaders/lora/lora_pipeline.py @@ -713,7 +713,7 @@ def load_lora_weights( @classmethod @validate_hf_hub_args - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.lora_state_dict + # Copied from diffusers.loaders.lora.lora_pipeline.StableDiffusionLoraLoaderMixin.lora_state_dict def lora_state_dict( cls, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], @@ -828,7 +828,7 @@ def lora_state_dict( return state_dict, network_alphas @classmethod - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_unet + # Copied from diffusers.loaders.lora.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_unet def load_lora_into_unet( cls, state_dict, @@ -906,7 +906,7 @@ def load_lora_into_unet( ) @classmethod - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder + # Copied from diffusers.loaders.lora.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder def load_lora_into_text_encoder( cls, state_dict, @@ -1386,7 +1386,7 @@ def load_lora_into_transformer( ) @classmethod - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder + # Copied from diffusers.loaders.lora.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder def load_lora_into_text_encoder( cls, state_dict, @@ -1461,7 +1461,7 @@ def load_lora_into_text_encoder( ) @classmethod - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.save_lora_weights with unet->transformer + # Copied from diffusers.loaders.lora.lora_pipeline.StableDiffusionXLLoraLoaderMixin.save_lora_weights with unet->transformer def save_lora_weights( cls, save_directory: Union[str, os.PathLike], @@ -1523,7 +1523,7 @@ def save_lora_weights( safe_serialization=safe_serialization, ) - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer + # Copied from diffusers.loaders.lora.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer def fuse_lora( self, components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], @@ -1571,7 +1571,7 @@ def fuse_lora( **kwargs, ) - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.unfuse_lora with unet->transformer + # Copied from diffusers.loaders.lora.lora_pipeline.StableDiffusionXLLoraLoaderMixin.unfuse_lora with unet->transformer def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], **kwargs): r""" Reverses the effect of @@ -1603,7 +1603,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin): @classmethod @validate_hf_hub_args - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict def lora_state_dict( cls, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], @@ -1699,7 +1699,7 @@ def lora_state_dict( return state_dict - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs ): @@ -1751,7 +1751,7 @@ def load_lora_weights( ) @classmethod - # Copied from 
diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->AuraFlowTransformer2DModel + # Copied from diffusers.loaders.lora.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->AuraFlowTransformer2DModel def load_lora_into_transformer( cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False ): @@ -1812,7 +1812,7 @@ def load_lora_into_transformer( ) @classmethod - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights def save_lora_weights( cls, save_directory: Union[str, os.PathLike], @@ -1859,7 +1859,7 @@ def save_lora_weights( safe_serialization=safe_serialization, ) - # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora + # Copied from diffusers.loaders.lora.lora_pipeline.SanaLoraLoaderMixin.fuse_lora def fuse_lora( self, components: List[str] = ["transformer"], @@ -1907,7 +1907,7 @@ def fuse_lora( **kwargs, ) - # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora + # Copied from diffusers.loaders.lora.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs): r""" Reverses the effect of @@ -2339,7 +2339,7 @@ def _load_norm_into_transformer( return overwritten_layers_state_dict @classmethod - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder + # Copied from diffusers.loaders.lora.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder def load_lora_into_text_encoder( cls, state_dict, @@ -2414,7 +2414,7 @@ def load_lora_into_text_encoder( ) @classmethod - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with unet->transformer + # Copied from diffusers.loaders.lora.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with unet->transformer def save_lora_weights( cls, save_directory: Union[str, os.PathLike], @@ -2827,7 +2827,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin): text_encoder_name = TEXT_ENCODER_NAME @classmethod - # Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin.load_lora_into_transformer with FluxTransformer2DModel->UVit2DModel + # Copied from diffusers.loaders.lora.lora_pipeline.FluxLoraLoaderMixin.load_lora_into_transformer with FluxTransformer2DModel->UVit2DModel def load_lora_into_transformer( cls, state_dict, @@ -2899,7 +2899,7 @@ def load_lora_into_transformer( ) @classmethod - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder + # Copied from diffusers.loaders.lora.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder def load_lora_into_text_encoder( cls, state_dict, @@ -3038,7 +3038,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin): @classmethod @validate_hf_hub_args - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict + # Copied from diffusers.loaders.lora.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict def lora_state_dict( cls, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], @@ -3185,7 +3185,7 @@ def load_lora_weights( ) @classmethod - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogVideoXTransformer3DModel + # Copied 
from diffusers.loaders.lora.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogVideoXTransformer3DModel def load_lora_into_transformer( cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False ): @@ -3368,7 +3368,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin): @classmethod @validate_hf_hub_args - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict + # Copied from diffusers.loaders.lora.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict def lora_state_dict( cls, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], @@ -3464,7 +3464,7 @@ def lora_state_dict( return state_dict - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs ): @@ -3516,7 +3516,7 @@ def load_lora_weights( ) @classmethod - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->MochiTransformer3DModel + # Copied from diffusers.loaders.lora.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->MochiTransformer3DModel def load_lora_into_transformer( cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False ): @@ -3577,7 +3577,7 @@ def load_lora_into_transformer( ) @classmethod - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights def save_lora_weights( cls, save_directory: Union[str, os.PathLike], @@ -3624,7 +3624,7 @@ def save_lora_weights( safe_serialization=safe_serialization, ) - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora def fuse_lora( self, components: List[str] = ["transformer"], @@ -3672,7 +3672,7 @@ def fuse_lora( **kwargs, ) - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): r""" Reverses the effect of @@ -3701,7 +3701,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin): @classmethod @validate_hf_hub_args - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict def lora_state_dict( cls, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], @@ -3797,7 +3797,7 @@ def lora_state_dict( return state_dict - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs ): @@ -3849,7 +3849,7 @@ def load_lora_weights( ) @classmethod - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->LTXVideoTransformer3DModel + # Copied from 
diffusers.loaders.lora.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->LTXVideoTransformer3DModel def load_lora_into_transformer( cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False ): @@ -3910,7 +3910,7 @@ def load_lora_into_transformer( ) @classmethod - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights def save_lora_weights( cls, save_directory: Union[str, os.PathLike], @@ -3957,7 +3957,7 @@ def save_lora_weights( safe_serialization=safe_serialization, ) - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora def fuse_lora( self, components: List[str] = ["transformer"], @@ -4005,7 +4005,7 @@ def fuse_lora( **kwargs, ) - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): r""" Reverses the effect of @@ -4034,7 +4034,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin): @classmethod @validate_hf_hub_args - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict + # Copied from diffusers.loaders.lora.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict def lora_state_dict( cls, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], @@ -4130,7 +4130,7 @@ def lora_state_dict( return state_dict - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs ): @@ -4182,7 +4182,7 @@ def load_lora_weights( ) @classmethod - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SanaTransformer2DModel + # Copied from diffusers.loaders.lora.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SanaTransformer2DModel def load_lora_into_transformer( cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False ): @@ -4243,7 +4243,7 @@ def load_lora_into_transformer( ) @classmethod - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights def save_lora_weights( cls, save_directory: Union[str, os.PathLike], @@ -4290,7 +4290,7 @@ def save_lora_weights( safe_serialization=safe_serialization, ) - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora def fuse_lora( self, components: List[str] = ["transformer"], @@ -4338,7 +4338,7 @@ def fuse_lora( **kwargs, ) - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): r""" Reverses the effect of @@ -4466,7 +4466,7 @@ def lora_state_dict( return state_dict - # Copied 
from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs ): @@ -4518,7 +4518,7 @@ def load_lora_weights( ) @classmethod - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HunyuanVideoTransformer3DModel + # Copied from diffusers.loaders.lora.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HunyuanVideoTransformer3DModel def load_lora_into_transformer( cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False ): @@ -4579,7 +4579,7 @@ def load_lora_into_transformer( ) @classmethod - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights def save_lora_weights( cls, save_directory: Union[str, os.PathLike], @@ -4626,7 +4626,7 @@ def save_lora_weights( safe_serialization=safe_serialization, ) - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora def fuse_lora( self, components: List[str] = ["transformer"], @@ -4674,7 +4674,7 @@ def fuse_lora( **kwargs, ) - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): r""" Reverses the effect of @@ -4803,7 +4803,7 @@ def lora_state_dict( return state_dict - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs ): @@ -4855,7 +4855,7 @@ def load_lora_weights( ) @classmethod - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->Lumina2Transformer2DModel + # Copied from diffusers.loaders.lora.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->Lumina2Transformer2DModel def load_lora_into_transformer( cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False ): @@ -4916,7 +4916,7 @@ def load_lora_into_transformer( ) @classmethod - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights def save_lora_weights( cls, save_directory: Union[str, os.PathLike], @@ -4963,7 +4963,7 @@ def save_lora_weights( safe_serialization=safe_serialization, ) - # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora + # Copied from diffusers.loaders.lora.lora_pipeline.SanaLoraLoaderMixin.fuse_lora def fuse_lora( self, components: List[str] = ["transformer"], @@ -5011,7 +5011,7 @@ def fuse_lora( **kwargs, ) - # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora + # Copied from diffusers.loaders.lora.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora def 
unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): r""" Reverses the effect of @@ -5221,7 +5221,7 @@ def load_lora_weights( ) @classmethod - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel + # Copied from diffusers.loaders.lora.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel def load_lora_into_transformer( cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False ): @@ -5282,7 +5282,7 @@ def load_lora_into_transformer( ) @classmethod - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights def save_lora_weights( cls, save_directory: Union[str, os.PathLike], @@ -5329,7 +5329,7 @@ def save_lora_weights( safe_serialization=safe_serialization, ) - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora def fuse_lora( self, components: List[str] = ["transformer"], @@ -5377,7 +5377,7 @@ def fuse_lora( **kwargs, ) - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): r""" Reverses the effect of @@ -5406,7 +5406,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin): @classmethod @validate_hf_hub_args - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict def lora_state_dict( cls, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], @@ -5502,7 +5502,7 @@ def lora_state_dict( return state_dict - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs ): @@ -5554,7 +5554,7 @@ def load_lora_weights( ) @classmethod - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel + # Copied from diffusers.loaders.lora.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel def load_lora_into_transformer( cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False ): @@ -5615,7 +5615,7 @@ def load_lora_into_transformer( ) @classmethod - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights def save_lora_weights( cls, save_directory: Union[str, os.PathLike], @@ -5662,7 +5662,7 @@ def save_lora_weights( safe_serialization=safe_serialization, ) - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora def fuse_lora( self, components: List[str] = ["transformer"], @@ -5710,7 +5710,7 @@ def fuse_lora( **kwargs, ) - # Copied 
from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): r""" Reverses the effect of diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 1d990e81458d..7d114ea311df 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -98,7 +98,7 @@ class PeftAdapterMixin: _prepare_lora_hotswap_kwargs: Optional[dict] = None @classmethod - # Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading + # Copied from diffusers.loaders.lora.lora_base.LoraBaseMixin._optionally_disable_offloading def _optionally_disable_offloading(cls, _pipeline): """ Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU. diff --git a/src/diffusers/loaders/unet/unet.py b/src/diffusers/loaders/unet/unet.py index 564316a2c51d..fd86d26be132 100644 --- a/src/diffusers/loaders/unet/unet.py +++ b/src/diffusers/loaders/unet/unet.py @@ -395,7 +395,7 @@ def _process_lora( return is_model_cpu_offload, is_sequential_cpu_offload @classmethod - # Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading + # Copied from diffusers.loaders.lora.lora_base.LoraBaseMixin._optionally_disable_offloading def _optionally_disable_offloading(cls, _pipeline): """ Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU. From 21b25669330d8b54e2ab26d4dbdc659056a210af Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 16 Apr 2025 12:23:16 +0530 Subject: [PATCH 04/26] fixes --- src/diffusers/loaders/single_file/single_file_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/loaders/single_file/single_file_utils.py b/src/diffusers/loaders/single_file/single_file_utils.py index 0c13fadcfd43..a37a6c22d432 100644 --- a/src/diffusers/loaders/single_file/single_file_utils.py +++ b/src/diffusers/loaders/single_file/single_file_utils.py @@ -36,7 +36,7 @@ LMSDiscreteScheduler, PNDMScheduler, ) -from ..utils import ( +from ...utils import ( SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, deprecate, @@ -44,8 +44,8 @@ is_transformers_available, logging, ) -from ..utils.constants import DIFFUSERS_REQUEST_TIMEOUT -from ..utils.hub_utils import _get_model_file +from ...utils.constants import DIFFUSERS_REQUEST_TIMEOUT +from ...utils.hub_utils import _get_model_file if is_transformers_available(): From ea3ba4f431361ac3f2d3eebdb60f8000d313be25 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 16 Apr 2025 12:26:30 +0530 Subject: [PATCH 05/26] fies --- utils/check_support_list.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/check_support_list.py b/utils/check_support_list.py index 89cfce62de0b..3248fae4ec14 100644 --- a/utils/check_support_list.py +++ b/utils/check_support_list.py @@ -98,7 +98,7 @@ def check_documentation(doc_path, src_path, doc_regex, src_regex, exclude_condit }, "LoRA Mixins": { "doc_path": "docs/source/en/api/loaders/lora.md", - "src_path": "src/diffusers/loaders/lora_pipeline.py", + "src_path": "src/diffusers/loaders/lora/lora_pipeline.py", "doc_regex": r"\[\[autodoc\]\]\s([^\n]+)", "src_regex": r"class\s+(\w+)\s*\(.*?nn\.Module.*?\):", }, From 2da3cb4a8ccd58f173a2bf0980c22e26cc407129 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 16 Apr 2025 12:27:37 +0530 Subject: [PATCH 06/26] fixes --- 
src/diffusers/loaders/lora/lora_base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/loaders/lora/lora_base.py b/src/diffusers/loaders/lora/lora_base.py index 280a9fa6e73f..80445106015b 100644 --- a/src/diffusers/loaders/lora/lora_base.py +++ b/src/diffusers/loaders/lora/lora_base.py @@ -24,8 +24,8 @@ from huggingface_hub import model_info from huggingface_hub.constants import HF_HUB_OFFLINE -from ..models.modeling_utils import ModelMixin, load_state_dict -from ..utils import ( +from ...models.modeling_utils import ModelMixin, load_state_dict +from ...utils import ( USE_PEFT_BACKEND, _get_model_file, convert_state_dict_to_diffusers, @@ -50,7 +50,7 @@ if is_transformers_available(): from transformers import PreTrainedModel - from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules + from ...models.lora import text_encoder_attn_modules, text_encoder_mlp_modules if is_peft_available(): from peft.tuners.tuners_utils import BaseTunerLayer From 178b884673c319be4f929e0442a1e4d54041cd5c Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 16 Apr 2025 12:29:16 +0530 Subject: [PATCH 07/26] updates --- src/diffusers/loaders/peft.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 7d114ea311df..85a52913bb2f 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -35,8 +35,8 @@ set_adapter_layers, set_weights_and_activate_adapters, ) -from .lora_base import _fetch_state_dict, _func_optionally_disable_offloading -from .unet_loader_utils import _maybe_expand_lora_scales +from .lora.lora_base import _fetch_state_dict, _func_optionally_disable_offloading +from .unet.unet_loader_utils import _maybe_expand_lora_scales logger = logging.get_logger(__name__) From d870e3c9a67d291b756e23ff902bb0f3d1123325 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 16 Apr 2025 12:35:09 +0530 Subject: [PATCH 08/26] update --- src/diffusers/loaders/lora/lora_pipeline.py | 2 +- src/diffusers/loaders/unet/unet.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/lora/lora_pipeline.py b/src/diffusers/loaders/lora/lora_pipeline.py index 7a46be45984b..fce66e43937e 100644 --- a/src/diffusers/loaders/lora/lora_pipeline.py +++ b/src/diffusers/loaders/lora/lora_pipeline.py @@ -18,7 +18,7 @@ import torch from huggingface_hub.utils import validate_hf_hub_args -from ..utils import ( +from ...utils import ( USE_PEFT_BACKEND, deprecate, get_submodule_by_name, diff --git a/src/diffusers/loaders/unet/unet.py b/src/diffusers/loaders/unet/unet.py index fd86d26be132..a632ceac210f 100644 --- a/src/diffusers/loaders/unet/unet.py +++ b/src/diffusers/loaders/unet/unet.py @@ -43,7 +43,7 @@ is_torch_version, logging, ) -from ..lora import _func_optionally_disable_offloading +from ..lora.lora_base import _func_optionally_disable_offloading from ..lora.lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME from ..utils import AttnProcsLayers From 4faac73219dfc5fde7a119b55dd65c6baac1d2e7 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 16 Apr 2025 12:38:58 +0530 Subject: [PATCH 09/26] update --- src/diffusers/loaders/lora/lora_conversion_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders/lora/lora_conversion_utils.py b/src/diffusers/loaders/lora/lora_conversion_utils.py index 7fec3299eeac..e15ddeb9682f 100644 --- 
a/src/diffusers/loaders/lora/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora/lora_conversion_utils.py @@ -17,7 +17,7 @@ import torch -from ..utils import is_peft_version, logging, state_dict_all_zero +from ...utils import is_peft_version, logging, state_dict_all_zero logger = logging.get_logger(__name__) From f2aa2f91dc54506071f7033f0aee442f522be65b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 16 Apr 2025 12:46:04 +0530 Subject: [PATCH 10/26] fix --- src/diffusers/loaders/lora/lora_pipeline.py | 4 ++-- src/diffusers/loaders/unet/unet_loader_utils.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/loaders/lora/lora_pipeline.py b/src/diffusers/loaders/lora/lora_pipeline.py index fce66e43937e..07ee90891ffb 100644 --- a/src/diffusers/loaders/lora/lora_pipeline.py +++ b/src/diffusers/loaders/lora/lora_pipeline.py @@ -73,10 +73,10 @@ def _maybe_dequantize_weight_for_expanded_lora(model, module): if is_bitsandbytes_available(): - from ..quantizers.bitsandbytes import dequantize_bnb_weight + from ...quantizers.bitsandbytes import dequantize_bnb_weight if is_gguf_available(): - from ..quantizers.gguf.utils import dequantize_gguf_tensor + from ...quantizers.gguf.utils import dequantize_gguf_tensor is_bnb_4bit_quantized = module.weight.__class__.__name__ == "Params4bit" is_gguf_quantized = module.weight.__class__.__name__ == "GGUFParameter" diff --git a/src/diffusers/loaders/unet/unet_loader_utils.py b/src/diffusers/loaders/unet/unet_loader_utils.py index 8f202ed4d44b..bfa935d13be8 100644 --- a/src/diffusers/loaders/unet/unet_loader_utils.py +++ b/src/diffusers/loaders/unet/unet_loader_utils.py @@ -14,12 +14,12 @@ import copy from typing import TYPE_CHECKING, Dict, List, Union -from ..utils import logging +from ...utils import logging if TYPE_CHECKING: # import here to avoid circular imports - from ..models import UNet2DConditionModel + from ...models import UNet2DConditionModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name From ea0ce4bfabfeb06b63857f027f5e2dc5030fb6df Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 16 Apr 2025 12:50:09 +0530 Subject: [PATCH 11/26] fixes --- src/diffusers/models/unets/unet_2d_condition.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index 5674d8ba26ec..a35aa1a5b32c 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -19,8 +19,7 @@ import torch.utils.checkpoint from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin -from ...loaders.single_file_model import FromOriginalModelMixin +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers from ..activations import get_activation from ..attention_processor import ( From 6138cc17202f7890883af199c2c30603356f9654 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 16 Apr 2025 13:01:48 +0530 Subject: [PATCH 12/26] updates --- src/diffusers/loaders/__init__.py | 12 ++++++------ src/diffusers/models/autoencoders/autoencoder_kl.py | 3 +-- .../models/autoencoders/autoencoder_kl_cogvideox.py | 2 +- src/diffusers/models/controlnets/controlnet.py | 2 +- src/diffusers/models/controlnets/controlnet_union.py | 2 +- 
.../models/transformers/transformer_lumina2.py | 3 +-- .../models/transformers/transformer_mochi.py | 3 +-- 7 files changed, 12 insertions(+), 15 deletions(-) diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index b4f3fb805496..eac72c704816 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -54,14 +54,14 @@ def text_encoder_attn_modules(text_encoder): _import_structure = {} if is_torch_available(): + _import_structure["ip_adapter.transformer_flux"] = ["FluxTransformer2DLoadersMixin"] + _import_structure["ip_adapter.transformer_sd3"] = ["SD3Transformer2DLoadersMixin"] _import_structure["single_file.single_file_model"] = ["FromOriginalModelMixin"] - _import_structure["transformer_flux"] = ["FluxTransformer2DLoadersMixin"] - _import_structure["transformer_sd3"] = ["SD3Transformer2DLoadersMixin"] - _import_structure["unet"] = ["UNet2DConditionLoadersMixin"] + _import_structure["unet.unet"] = ["UNet2DConditionLoadersMixin"] _import_structure["utils"] = ["AttnProcsLayers"] if is_transformers_available(): - _import_structure["single_file"] = ["FromSingleFileMixin"] - _import_structure["lora_pipeline"] = [ + _import_structure["single_file.single_file"] = ["FromSingleFileMixin"] + _import_structure["lora.lora_pipeline"] = [ "AmusedLoraLoaderMixin", "StableDiffusionLoraLoaderMixin", "SD3LoraLoaderMixin", @@ -80,7 +80,7 @@ def text_encoder_attn_modules(text_encoder): "LoraBaseMixin", ] _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"] - _import_structure["ip_adapter"] = [ + _import_structure["ip_adapter.ip_adapter"] = [ "IPAdapterMixin", "FluxIPAdapterMixin", "SD3IPAdapterMixin", diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py index 357df0c31087..119409d4c02c 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl.py @@ -17,8 +17,7 @@ import torch.nn as nn from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import PeftAdapterMixin -from ...loaders.single_file_model import FromOriginalModelMixin +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import deprecate from ...utils.accelerate_utils import apply_forward_hook from ..attention_processor import ( diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index e2b26396899f..0f4be88c92d3 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -21,7 +21,7 @@ import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders.single_file_model import FromOriginalModelMixin +from ...loaders import FromOriginalModelMixin from ...utils import logging from ...utils.accelerate_utils import apply_forward_hook from ..activations import get_activation diff --git a/src/diffusers/models/controlnets/controlnet.py b/src/diffusers/models/controlnets/controlnet.py index 7a6ca886caed..d1bf033459f0 100644 --- a/src/diffusers/models/controlnets/controlnet.py +++ b/src/diffusers/models/controlnets/controlnet.py @@ -19,7 +19,7 @@ from torch.nn import functional as F from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders.single_file_model import FromOriginalModelMixin +from ...loaders import FromOriginalModelMixin from ...utils import 
BaseOutput, logging from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, diff --git a/src/diffusers/models/controlnets/controlnet_union.py b/src/diffusers/models/controlnets/controlnet_union.py index 26cb86718a21..0716934a7d8e 100644 --- a/src/diffusers/models/controlnets/controlnet_union.py +++ b/src/diffusers/models/controlnets/controlnet_union.py @@ -17,7 +17,7 @@ from torch import nn from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders.single_file_model import FromOriginalModelMixin +from ...loaders import FromOriginalModelMixin from ...utils import logging from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, diff --git a/src/diffusers/models/transformers/transformer_lumina2.py b/src/diffusers/models/transformers/transformer_lumina2.py index a873a6ec9444..25420394bdcd 100644 --- a/src/diffusers/models/transformers/transformer_lumina2.py +++ b/src/diffusers/models/transformers/transformer_lumina2.py @@ -20,8 +20,7 @@ import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import PeftAdapterMixin -from ...loaders.single_file_model import FromOriginalModelMixin +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ..attention import LuminaFeedForward from ..attention_processor import Attention diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index e6532f080d72..964c44ae4db4 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -19,8 +19,7 @@ import torch.nn as nn from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import PeftAdapterMixin -from ...loaders.single_file_model import FromOriginalModelMixin +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward From 1b7c286974d0731707dbd691cc8df7302c5c6cf4 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 16 Apr 2025 15:09:23 +0530 Subject: [PATCH 13/26] fix --- docs/source/en/api/loaders/lora.md | 34 ++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/docs/source/en/api/loaders/lora.md b/docs/source/en/api/loaders/lora.md index 093cc99972a3..f09e990e96bf 100644 --- a/docs/source/en/api/loaders/lora.md +++ b/docs/source/en/api/loaders/lora.md @@ -25,6 +25,8 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi - [`SanaLoraLoaderMixin`] provides similar functions for [Sana](https://huggingface.co/docs/diffusers/main/en/api/pipelines/sana). - [`HunyuanVideoLoraLoaderMixin`] provides similar functions for [HunyuanVideo](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hunyuan_video). - [`Lumina2LoraLoaderMixin`] provides similar functions for [Lumina2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/lumina2). +- [`WanLoraLoaderMixin`] provides similar functions for [Wan](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan). +- [`CogView4LoraLoaderMixin`] provides similar functions for [CogView4](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogview4). - [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`].
- [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, unload, LoRAs and more. @@ -36,50 +38,58 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse ## StableDiffusionLoraLoaderMixin -[[autodoc]] loaders.lora_pipeline.StableDiffusionLoraLoaderMixin +[[autodoc]] loaders.lora.lora_pipeline.StableDiffusionLoraLoaderMixin ## StableDiffusionXLLoraLoaderMixin -[[autodoc]] loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin +[[autodoc]] loaders.lora.lora_pipeline.StableDiffusionXLLoraLoaderMixin ## SD3LoraLoaderMixin -[[autodoc]] loaders.lora_pipeline.SD3LoraLoaderMixin +[[autodoc]] loaders.lora.lora_pipeline.SD3LoraLoaderMixin ## FluxLoraLoaderMixin -[[autodoc]] loaders.lora_pipeline.FluxLoraLoaderMixin +[[autodoc]] loaders.lora.lora_pipeline.FluxLoraLoaderMixin ## CogVideoXLoraLoaderMixin -[[autodoc]] loaders.lora_pipeline.CogVideoXLoraLoaderMixin +[[autodoc]] loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin ## Mochi1LoraLoaderMixin -[[autodoc]] loaders.lora_pipeline.Mochi1LoraLoaderMixin +[[autodoc]] loaders.lora.lora_pipeline.Mochi1LoraLoaderMixin ## AuraFlowLoraLoaderMixin -[[autodoc]] loaders.lora_pipeline.AuraFlowLoraLoaderMixin +[[autodoc]] loaders.lora.lora_pipeline.AuraFlowLoraLoaderMixin ## LTXVideoLoraLoaderMixin -[[autodoc]] loaders.lora_pipeline.LTXVideoLoraLoaderMixin +[[autodoc]] loaders.lora.lora_pipeline.LTXVideoLoraLoaderMixin ## SanaLoraLoaderMixin -[[autodoc]] loaders.lora_pipeline.SanaLoraLoaderMixin +[[autodoc]] loaders.lora.lora_pipeline.SanaLoraLoaderMixin ## HunyuanVideoLoraLoaderMixin -[[autodoc]] loaders.lora_pipeline.HunyuanVideoLoraLoaderMixin +[[autodoc]] loaders.lora.lora_pipeline.HunyuanVideoLoraLoaderMixin ## Lumina2LoraLoaderMixin -[[autodoc]] loaders.lora_pipeline.Lumina2LoraLoaderMixin +[[autodoc]] loaders.lora.lora_pipeline.Lumina2LoraLoaderMixin + +## WanLoraLoaderMixin + +[[autodoc]] loaders.lora.lora_pipeline.WanLoraLoaderMixin + +## CogView4LoraLoaderMixin + +[[autodoc]] loaders.lora.lora_pipeline.CogView4LoraLoaderMixin ## AmusedLoraLoaderMixin -[[autodoc]] loaders.lora_pipeline.AmusedLoraLoaderMixin +[[autodoc]] loaders.lora.lora_pipeline.AmusedLoraLoaderMixin ## LoraBaseMixin From f0ea9ff2e27b1a1e2314380068dce8a108d2eb72 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 16 Apr 2025 15:47:40 +0530 Subject: [PATCH 14/26] deprecate lora loader from loaders easily. --- src/diffusers/loaders/lora_base.py | 77 ++++++++++++++ src/diffusers/loaders/lora_pipeline.py | 140 +++++++++++++++++++++++++ 2 files changed, 217 insertions(+) create mode 100644 src/diffusers/loaders/lora_base.py create mode 100644 src/diffusers/loaders/lora_pipeline.py diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py new file mode 100644 index 000000000000..0970ef86f109 --- /dev/null +++ b/src/diffusers/loaders/lora_base.py @@ -0,0 +1,77 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from ..utils import deprecate +from .lora.lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, LoraBaseMixin # noqa: F401 + + +def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None): + from .lora.lora_base import fuse_text_encoder_lora + + deprecation_message = "Importing `fuse_text_encoder_lora()` from diffusers.loaders.lora_base has been deprecated. Please use `from diffusers.loaders.lora.lora_base import fuse_text_encoder_lora` instead." + deprecate("diffusers.loaders.lora_base.fuse_text_encoder_lora", "0.36", deprecation_message) + + return fuse_text_encoder_lora( + text_encoder, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names + ) + + +def unfuse_text_encoder_lora(text_encoder): + from .lora.lora_base import unfuse_text_encoder_lora + + deprecation_message = "Importing `unfuse_text_encoder_lora()` from diffusers.loaders.lora_base has been deprecated. Please use `from diffusers.loaders.lora.lora_base import unfuse_text_encoder_lora` instead." + deprecate("diffusers.loaders.lora_base.unfuse_text_encoder_lora", "0.36", deprecation_message) + + return unfuse_text_encoder_lora(text_encoder) + + +def set_adapters_for_text_encoder( + adapter_names, + text_encoder=None, + text_encoder_weights=None, +): + from .lora.lora_base import set_adapters_for_text_encoder + + deprecation_message = "Importing `set_adapters_for_text_encoder()` from diffusers.loaders.lora_base has been deprecated. Please use `from diffusers.loaders.lora.lora_base import set_adapters_for_text_encoder` instead." + deprecate("diffusers.loaders.lora_base.set_adapters_for_text_encoder", "0.36", deprecation_message) + + return set_adapters_for_text_encoder( + adapter_names=adapter_names, text_encoder=text_encoder, text_encoder_weights=text_encoder_weights + ) + + +def disable_lora_for_text_encoder(text_encoder=None): + from .lora.lora_base import disable_lora_for_text_encoder + + deprecation_message = "Importing `disable_lora_for_text_encoder()` from diffusers.loaders.lora_base has been deprecated. Please use `from diffusers.loaders.lora.lora_base import disable_lora_for_text_encoder` instead." + deprecate("diffusers.loaders.lora_base.disable_lora_for_text_encoder", "0.36", deprecation_message) + + return disable_lora_for_text_encoder(text_encoder=text_encoder) + + +def enable_lora_for_text_encoder(text_encoder=None): + from .lora.lora_base import enable_lora_for_text_encoder + + deprecation_message = "Importing `enable_lora_for_text_encoder()` from diffusers.loaders.lora_base has been deprecated. Please use `from diffusers.loaders.lora.lora_base import enable_lora_for_text_encoder` instead." + deprecate("diffusers.loaders.lora_base.enable_lora_for_text_encoder", "0.36", deprecation_message) + + return enable_lora_for_text_encoder(text_encoder=text_encoder) + + +class LoraBaseMixin(LoraBaseMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `LoraBaseMixin` from diffusers.loaders.lora_base has been deprecated. Please use `from diffusers.loaders.lora.lora_base import LoraBaseMixin` instead." + deprecate("diffusers.loaders.lora_base.LoraBaseMixin", "0.36", deprecation_message) + super().__init__(*args, **kwargs) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py new file mode 100644 index 000000000000..a1cd93feb0b1 --- /dev/null +++ b/src/diffusers/loaders/lora_pipeline.py @@ -0,0 +1,140 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ..utils import deprecate +from .lora.lora_pipeline import ( + TEXT_ENCODER_NAME, # noqa: F401 + TRANSFORMER_NAME, # noqa: F401 + UNET_NAME, # noqa: F401 + AmusedLoraLoaderMixin, + AuraFlowLoraLoaderMixin, + CogVideoXLoraLoaderMixin, + CogView4LoraLoaderMixin, + FluxLoraLoaderMixin, + HunyuanVideoLoraLoaderMixin, + LTXVideoLoraLoaderMixin, + Lumina2LoraLoaderMixin, + Mochi1LoraLoaderMixin, + SanaLoraLoaderMixin, + SD3LoraLoaderMixin, + StableDiffusionLoraLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + WanLoraLoaderMixin, +) + + +class StableDiffusionLoraLoaderMixin(StableDiffusionLoraLoaderMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `StableDiffusionLoraLoaderMixin` from diffusers.loaders.lora_pipeline has been deprecated. Please use `from diffusers.loaders.lora.lora_pipeline import StableDiffusionLoraLoaderMixin` instead." + deprecate("diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin", "0.36", deprecation_message) + super().__init__(*args, **kwargs) + + +class StableDiffusionXLLoraLoaderMixin(StableDiffusionXLLoraLoaderMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `StableDiffusionXLLoraLoaderMixin` from diffusers.loaders.lora_pipeline has been deprecated. Please use `from diffusers.loaders.lora.lora_pipeline import StableDiffusionXLLoraLoaderMixin` instead." + deprecate("diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin", "0.36", deprecation_message) + super().__init__(*args, **kwargs) + + +class SD3LoraLoaderMixin(SD3LoraLoaderMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `SD3LoraLoaderMixin` from diffusers.loaders.lora_pipeline has been deprecated. Please use `from diffusers.loaders.lora.lora_pipeline import SD3LoraLoaderMixin` instead." + deprecate("diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin", "0.36", deprecation_message) + super().__init__(*args, **kwargs) + + +class AuraFlowLoraLoaderMixin(AuraFlowLoraLoaderMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `AuraFlowLoraLoaderMixin` from diffusers.loaders.lora_pipeline has been deprecated. Please use `from diffusers.loaders.lora.lora_pipeline import AuraFlowLoraLoaderMixin` instead." + deprecate("diffusers.loaders.lora_pipeline.AuraFlowLoraLoaderMixin", "0.36", deprecation_message) + super().__init__(*args, **kwargs) + + +class FluxLoraLoaderMixin(FluxLoraLoaderMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `FluxLoraLoaderMixin` from diffusers.loaders.lora_pipeline has been deprecated. Please use `from diffusers.loaders.lora.lora_pipeline import FluxLoraLoaderMixin` instead." 
+ deprecate("diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin", "0.36", deprecation_message) + super().__init__(*args, **kwargs) + + +class AmusedLoraLoaderMixin(AmusedLoraLoaderMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `AmusedLoraLoaderMixin` from diffusers.loaders.lora_pipeline has been deprecated. Please use `from diffusers.loaders.lora.lora_pipeline import AmusedLoraLoaderMixin` instead." + deprecate("diffusers.loaders.lora_pipeline.AmusedLoraLoaderMixin", "0.36", deprecation_message) + super().__init__(*args, **kwargs) + + +class CogVideoXLoraLoaderMixin(CogVideoXLoraLoaderMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `CogVideoXLoraLoaderMixin` from diffusers.loaders.lora_pipeline has been deprecated. Please use `from diffusers.loaders.lora.lora_pipeline import CogVideoXLoraLoaderMixin` instead." + deprecate("diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin", "0.36", deprecation_message) + super().__init__(*args, **kwargs) + + +class Mochi1LoraLoaderMixin(Mochi1LoraLoaderMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `Mochi1LoraLoaderMixin` from diffusers.loaders.lora_pipeline has been deprecated. Please use `from diffusers.loaders.lora.lora_pipeline import Mochi1LoraLoaderMixin` instead." + deprecate("diffusers.loaders.lora_pipeline.Mochi1LoraLoaderMixin", "0.36", deprecation_message) + super().__init__(*args, **kwargs) + + +class LTXVideoLoraLoaderMixin(LTXVideoLoraLoaderMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `LTXVideoLoraLoaderMixin` from diffusers.loaders.lora_pipeline has been deprecated. Please use `from diffusers.loaders.lora.lora_pipeline import LTXVideoLoraLoaderMixin` instead." + deprecate("diffusers.loaders.lora_pipeline.LTXVideoLoraLoaderMixin", "0.36", deprecation_message) + super().__init__(*args, **kwargs) + + +class SanaLoraLoaderMixin(SanaLoraLoaderMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `SanaLoraLoaderMixin` from diffusers.loaders.lora_pipeline has been deprecated. Please use `from diffusers.loaders.lora.lora_pipeline import SanaLoraLoaderMixin` instead." + deprecate("diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin", "0.36", deprecation_message) + super().__init__(*args, **kwargs) + + +class HunyuanVideoLoraLoaderMixin(HunyuanVideoLoraLoaderMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `HunyuanVideoLoraLoaderMixin` from diffusers.loaders.lora_pipeline has been deprecated. Please use `from diffusers.loaders.lora.lora_pipeline import HunyuanVideoLoraLoaderMixin` instead." + deprecate("diffusers.loaders.lora_pipeline.HunyuanVideoLoraLoaderMixin", "0.36", deprecation_message) + super().__init__(*args, **kwargs) + + +class Lumina2LoraLoaderMixin(Lumina2LoraLoaderMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `Lumina2LoraLoaderMixin` from diffusers.loaders.lora_pipeline has been deprecated. Please use `from diffusers.loaders.lora.lora_pipeline import Lumina2LoraLoaderMixin` instead." + deprecate("diffusers.loaders.lora_pipeline.Lumina2LoraLoaderMixin", "0.36", deprecation_message) + super().__init__(*args, **kwargs) + + +class WanLoraLoaderMixin(WanLoraLoaderMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `WanLoraLoaderMixin` from diffusers.loaders.lora_pipeline has been deprecated. 
Please use `from diffusers.loaders.lora.lora_pipeline import WanLoraLoaderMixin` instead." + deprecate("diffusers.loaders.lora_pipeline.WanLoraLoaderMixin", "0.36", deprecation_message) + super().__init__(*args, **kwargs) + + +class CogView4LoraLoaderMixin(CogView4LoraLoaderMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `CogView4LoraLoaderMixin` from diffusers.loaders.lora_pipeline has been deprecated. Please use `from diffusers.loaders.lora.lora_pipeline import CogView4LoraLoaderMixin` instead." + deprecate("diffusers.loaders.lora_pipeline.CogView4LoraLoaderMixin", "0.36", deprecation_message) + super().__init__(*args, **kwargs) + + +class LoraLoaderMixin(StableDiffusionLoraLoaderMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead." + deprecate("LoraLoaderMixin", "1.0.0", deprecation_message) + super().__init__(*args, **kwargs) From ea3f0b8d6821c99c330984f84462faa1769744f6 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 16 Apr 2025 15:49:14 +0530 Subject: [PATCH 15/26] update --- docs/source/en/api/loaders/ip_adapter.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/api/loaders/ip_adapter.md b/docs/source/en/api/loaders/ip_adapter.md index 946a8b1af875..260762f36d50 100644 --- a/docs/source/en/api/loaders/ip_adapter.md +++ b/docs/source/en/api/loaders/ip_adapter.md @@ -22,11 +22,11 @@ Learn how to load an IP-Adapter checkpoint and image in the IP-Adapter [loading] ## IPAdapterMixin -[[autodoc]] loaders.ip_adapter.IPAdapterMixin +[[autodoc]] loaders.ip_adapter.ip_adapter.IPAdapterMixin ## SD3IPAdapterMixin -[[autodoc]] loaders.ip_adapter.SD3IPAdapterMixin +[[autodoc]] loaders.ip_adapter.ip_adapter.SD3IPAdapterMixin - all - is_ip_adapter_active From 546446ae21f8a677f8f4d848e5e0295d1b27f501 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 16 Apr 2025 15:52:57 +0530 Subject: [PATCH 16/26] ip_adapter. --- src/diffusers/loaders/ip_adapter.py | 38 +++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 src/diffusers/loaders/ip_adapter.py diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py new file mode 100644 index 000000000000..07043114187e --- /dev/null +++ b/src/diffusers/loaders/ip_adapter.py @@ -0,0 +1,38 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ..utils import deprecate +from .ip_adapter import FluxIPAdapterMixin, IPAdapterMixin, SD3IPAdapterMixin + + +class IPAdapterMixin(IPAdapterMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `IPAdapterMixin` from diffusers.loaders.ip_adapter has been deprecated. Please use `from diffusers.loaders.ip_adapter.ip_adapter import IPAdapterMixin` instead." 
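+        # This wrapper subclasses (and shadows) the real `IPAdapterMixin` imported
+        # at module level, so all inherited behaviour is unchanged; constructing it
+        # only adds the deprecation warning emitted below.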
+ deprecate("diffusers.loaders.ip_adapter.IPAdapterMixin", "0.36", deprecation_message) + super().__init__(*args, **kwargs) + + +class FluxIPAdapterMixin(FluxIPAdapterMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `FluxIPAdapterMixin` from diffusers.loaders.ip_adapter has been deprecated. Please use `from diffusers.loaders.ip_adapter.ip_adapter import FluxIPAdapterMixin` instead." + deprecate("diffusers.loaders.ip_adapter.FluxIPAdapterMixin", "0.36", deprecation_message) + super().__init__(*args, **kwargs) + + +class SD3IPAdapterMixin(SD3IPAdapterMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `SD3IPAdapterMixin` from diffusers.loaders.ip_adapter has been deprecated. Please use `from diffusers.loaders.ip_adapter.ip_adapter import SD3IPAdapterMixin` instead." + deprecate("diffusers.loaders.ip_adapter.SD3IPAdapterMixin", "0.36", deprecation_message) + super().__init__(*args, **kwargs) From 0e8d1d25eb2e0710b49a585e50352958f4b1c2e0 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 16 Apr 2025 15:59:01 +0530 Subject: [PATCH 17/26] ip_adapter --- docs/source/en/api/loaders/transformer_sd3.md | 4 ++-- src/diffusers/loaders/transformer_flux.py | 23 +++++++++++++++++++ src/diffusers/loaders/transformer_sd3.py | 22 ++++++++++++++++++ 3 files changed, 47 insertions(+), 2 deletions(-) create mode 100644 src/diffusers/loaders/transformer_flux.py create mode 100644 src/diffusers/loaders/transformer_sd3.py diff --git a/docs/source/en/api/loaders/transformer_sd3.md b/docs/source/en/api/loaders/transformer_sd3.md index 4fc9603054b4..cb4f9bdf37b6 100644 --- a/docs/source/en/api/loaders/transformer_sd3.md +++ b/docs/source/en/api/loaders/transformer_sd3.md @@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License. # SD3Transformer2D -This class is useful when *only* loading weights into a [`SD3Transformer2DModel`]. If you need to load weights into the text encoder or a text encoder and SD3Transformer2DModel, check [`SD3LoraLoaderMixin`](lora#diffusers.loaders.SD3LoraLoaderMixin) class instead. +This class is useful when *only* loading weights into a [`SD3Transformer2DModel`]. If you need to load weights into the text encoder or a text encoder and [SD3Transformer2DModel], check [`SD3LoraLoaderMixin`](lora#diffusers.loaders.SD3LoraLoaderMixin) class instead. The [`SD3Transformer2DLoadersMixin`] class currently only loads IP-Adapter weights, but will be used in the future to save weights and load LoRAs. @@ -24,6 +24,6 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse ## SD3Transformer2DLoadersMixin -[[autodoc]] loaders.transformer_sd3.SD3Transformer2DLoadersMixin +[[autodoc]] loaders.ip_adapter.transformer_sd3.SD3Transformer2DLoadersMixin - all - _load_ip_adapter_weights \ No newline at end of file diff --git a/src/diffusers/loaders/transformer_flux.py b/src/diffusers/loaders/transformer_flux.py new file mode 100644 index 000000000000..8487b0652feb --- /dev/null +++ b/src/diffusers/loaders/transformer_flux.py @@ -0,0 +1,23 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ..utils import deprecate +from .ip_adapter.transformer_flux import FluxTransformer2DLoadersMixin + + +class FluxTransformer2DLoadersMixin(FluxTransformer2DLoadersMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `FluxTransformer2DLoadersMixin` from diffusers.loaders.ip_adapter has been deprecated. Please use `from diffusers.loaders.ip_adapter.transformer_flux import FluxTransformer2DLoadersMixin` instead." + deprecate("diffusers.loaders.ip_adapter.FluxTransformer2DLoadersMixin", "0.36", deprecation_message) + super().__init__(*args, **kwargs) diff --git a/src/diffusers/loaders/transformer_sd3.py b/src/diffusers/loaders/transformer_sd3.py new file mode 100644 index 000000000000..e2e4c9fb67a3 --- /dev/null +++ b/src/diffusers/loaders/transformer_sd3.py @@ -0,0 +1,22 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ..utils import deprecate +from .ip_adapter.transformer_sd3 import SD3Transformer2DLoadersMixin + + +class SD3Transformer2DLoadersMixin(SD3Transformer2DLoadersMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `SD3Transformer2DLoadersMixin` from diffusers.loaders.ip_adapter has been deprecated. Please use `from diffusers.loaders.ip_adapter.transformer_sd3 import SD3Transformer2DLoadersMixin` instead." + deprecate("diffusers.loaders.ip_adapter.SD3Transformer2DLoadersMixin", "0.36", deprecation_message) + super().__init__(*args, **kwargs) From 1ddfe14220e8c94326ec990a8af4e3ab9952bb89 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 16 Apr 2025 16:30:39 +0530 Subject: [PATCH 18/26] single_file --- src/diffusers/loaders/single_file.py | 59 ++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 src/diffusers/loaders/single_file.py diff --git a/src/diffusers/loaders/single_file.py b/src/diffusers/loaders/single_file.py new file mode 100644 index 000000000000..e6acefb43976 --- /dev/null +++ b/src/diffusers/loaders/single_file.py @@ -0,0 +1,59 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +from ..utils import deprecate +from .single_file.single_file import FromSingleFileMixin + + +def load_single_file_sub_model( + library_name, + class_name, + name, + checkpoint, + pipelines, + is_pipeline_module, + cached_model_config_path, + original_config=None, + local_files_only=False, + torch_dtype=None, + is_legacy_loading=False, + disable_mmap=False, + **kwargs, +): + from .single_file.single_file import load_single_file_sub_model + + deprecation_message = "Importing `load_single_file_sub_model()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file import load_single_file_sub_model` instead." + deprecate("diffusers.loaders.single_file.load_single_file_sub_model", "0.36", deprecation_message) + + return load_single_file_sub_model( + library_name, + class_name, + name, + checkpoint, + pipelines, + is_pipeline_module, + cached_model_config_path, + original_config, + local_files_only, + torch_dtype, + is_legacy_loading, + disable_mmap, + **kwargs, + ) + + +class FromSingleFileMixin(FromSingleFileMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `FromSingleFileMixin` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file import FromSingleFileMixin` instead." + deprecate("diffusers.loaders.single_file.FromSingleFileMixin", "0.36", deprecation_message) + super().__init__(*args, **kwargs) From 27d2401e591bc2876b7045a796c7b45e543f6361 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 16 Apr 2025 16:52:54 +0530 Subject: [PATCH 19/26] partially complete single_file_utils --- src/diffusers/loaders/single_file_utils.py | 1939 ++++++++++++++++++++ 1 file changed, 1939 insertions(+) create mode 100644 src/diffusers/loaders/single_file_utils.py diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py new file mode 100644 index 000000000000..4ae1850fa769 --- /dev/null +++ b/src/diffusers/loaders/single_file_utils.py @@ -0,0 +1,1939 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from contextlib import nullcontext + +import torch + +from ..utils import deprecate +from .single_file.single_file_utils import ( + CHECKPOINT_KEY_NAMES, + DIFFUSERS_TO_LDM_MAPPING, + LDM_CLIP_PREFIX_TO_REMOVE, + LDM_OPEN_CLIP_TEXT_PROJECTION_DIM, + SD_2_TEXT_ENCODER_KEYS_TO_IGNORE, + SingleFileComponentError, +) + + +class SingleFileComponentError(SingleFileComponentError): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `SingleFileComponentError` from diffusers.loaders.single_file_utils has been deprecated. Please use `from diffusers.loaders.single_file.single_files_utils import SingleFileComponentError` instead." + deprecate("diffusers.loaders.single_file_utils. 
", "0.36", deprecation_message) + super().__init__(*args, **kwargs) + + +def is_valid_url(url): + from .single_file.single_file_utils import is_valid_url + + deprecation_message = "Importing `is_valid_url()` from diffusers.loaders.single_file_utils has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import is_valid_url` instead." + deprecate("diffusers.loaders.single_file_utils.is_valid_url", "0.36", deprecation_message) + + return is_valid_url(url) + + +def load_single_file_checkpoint( + pretrained_model_link_or_path, + force_download=False, + proxies=None, + token=None, + cache_dir=None, + local_files_only=None, + revision=None, + disable_mmap=False, + user_agent=None, +): + from .single_file.single_file_utils import load_single_file_checkpoint + + deprecation_message = "Importing `load_single_file_checkpoint()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import load_single_file_checkpoint` instead." + deprecate("diffusers.loaders.single_file_utils.load_single_file_checkpoint", "0.36", deprecation_message) + + return load_single_file_checkpoint( + pretrained_model_link_or_path, + force_download, + proxies, + token, + cache_dir, + local_files_only, + revision, + disable_mmap, + user_agent, + ) + + +def fetch_original_config(original_config_file, local_files_only=False): + from .single_file.single_file_utils import fetch_original_config + + deprecation_message = "Importing `fetch_original_config()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import fetch_original_config` instead." + deprecate("diffusers.loaders.single_file_utils.fetch_original_config", "0.36", deprecation_message) + + return fetch_original_config(original_config_file, local_files_only) + + +def is_clip_model(checkpoint): + from .single_file.single_file_utils import is_clip_model + + deprecation_message = "Importing `is_clip_model()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import is_clip_model` instead." + deprecate("diffusers.loaders.single_file_utils.is_clip_model", "0.36", deprecation_message) + + return is_clip_model(checkpoint) + + +def is_clip_sdxl_model(checkpoint): + from .single_file.single_file_utils import is_clip_sdxl_model + + deprecation_message = "Importing `is_clip_sdxl_model()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import is_clip_sdxl_model` instead." + deprecate("diffusers.loaders.single_file_utils.is_clip_sdxl_model", "0.36", deprecation_message) + + return is_clip_sdxl_model(checkpoint) + + +def is_clip_sd3_model(checkpoint): + from .single_file.single_file_utils import is_clip_sd3_model + + deprecation_message = "Importing `is_clip_sd3_model()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import is_clip_sd3_model` instead." + deprecate("diffusers.loaders.single_file_utils.is_clip_sd3_model", "0.36", deprecation_message) + + return is_clip_sd3_model(checkpoint) + + +def is_open_clip_model(checkpoint): + + deprecation_message = "Importing `is_open_clip_model()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import is_open_clip_model` instead." 
+ deprecate("diffusers.loaders.single_file_utils.is_open_clip_model", "0.36", deprecation_message) + + return is_open_clip_model(checkpoint) + + +def is_open_clip_sdxl_model(checkpoint): + from .single_file.single_file_utils import is_open_clip_sdxl_model + + deprecation_message = "Importing `is_open_clip_sdxl_model()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import is_open_clip_sdxl_model` instead." + deprecate("diffusers.loaders.single_file_utils.is_open_clip_sdxl_model", "0.36", deprecation_message) + + return is_open_clip_sdxl_model(checkpoint) + +def is_open_clip_sd3_model(checkpoint): + from .single_file.single_file_utils import is_open_clip_sd3_model + + deprecation_message = "Importing `is_open_clip_sd3_model()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import is_open_clip_sd3_model` instead." + deprecate("diffusers.loaders.single_file_utils.is_open_clip_sd3_model", "0.36", deprecation_message) + + return is_open_clip_sd3_model(checkpoint) + + +def is_open_clip_sdxl_refiner_model(checkpoint): + from .single_file.single_file_utils import is_open_clip_sdxl_refiner_model + + deprecation_message = "Importing `is_open_clip_sdxl_refiner_model()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import is_open_clip_sdxl_refiner_model` instead." + deprecate("diffusers.loaders.single_file_utils.is_open_clip_sdxl_refiner_model", "0.36", deprecation_message) + + return is_open_clip_sdxl_refiner_model(checkpoint) + + +def is_clip_model_in_single_file(class_obj, checkpoint): + from .single_file.single_file_utils import is_clip_model_in_single_file + + deprecation_message = "Importing `is_clip_model_in_single_file()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import is_clip_model_in_single_file` instead." + deprecate("diffusers.loaders.single_file_utils.is_clip_model_in_single_file", "0.36", deprecation_message) + + return is_clip_model_in_single_file(class_obj, checkpoint) + + +def infer_diffusers_model_type(checkpoint): + from .single_file.single_file_utils import infer_diffusers_model_type + + deprecation_message = "Importing `infer_diffusers_model_type()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import infer_diffusers_model_type` instead." + deprecate("diffusers.loaders.single_file_utils.infer_diffusers_model_type", "0.36", deprecation_message) + + return infer_diffusers_model_type(checkpoint) + + + +def fetch_diffusers_config(checkpoint): + from .single_file.single_file_utils import fetch_diffusers_config + + deprecation_message = "Importing `fetch_diffusers_config()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import fetch_diffusers_config` instead." + deprecate("diffusers.loaders.single_file_utils.fetch_diffusers_config", "0.36", deprecation_message) + + return fetch_diffusers_config(checkpoint) + + +def set_image_size(checkpoint, image_size=None): + from .single_file.single_file_utils import set_image_size + + deprecation_message = "Importing `set_image_size()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import set_image_size` instead." 
+ deprecate("diffusers.loaders.single_file_utils.set_image_size", "0.36", deprecation_message) + + return set_image_size(checkpoint, image_size) + + +def conv_attn_to_linear(checkpoint): + from .single_file.single_file_utils import conv_attn_to_linear + + deprecation_message = "Importing `conv_attn_to_linear()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import conv_attn_to_linear` instead." + deprecate("diffusers.loaders.single_file_utils.conv_attn_to_linear", "0.36", deprecation_message) + + return conv_attn_to_linear(checkpoint) + + + +def create_unet_diffusers_config_from_ldm( + original_config, checkpoint, image_size=None, upcast_attention=None, num_in_channels=None +): + from .single_file.single_file_utils import create_unet_diffusers_config_from_ldm + + deprecation_message = "Importing `create_unet_diffusers_config_from_ldm()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import create_unet_diffusers_config_from_ldm` instead." + deprecate("diffusers.loaders.single_file_utils.create_unet_diffusers_config_from_ldm", "0.36", deprecation_message) + + return create_unet_diffusers_config_from_ldm(original_config, checkpoint, image_size, upcast_attention, num_in_channels) + + +def create_controlnet_diffusers_config_from_ldm(original_config, checkpoint, image_size=None, **kwargs): + from .single_file.single_file_utils import create_controlnet_diffusers_config_from_ldm + + deprecation_message = "Importing `create_controlnet_diffusers_config_from_ldm()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import create_controlnet_diffusers_config_from_ldm` instead." + deprecate("diffusers.loaders.single_file_utils.create_controlnet_diffusers_config_from_ldm", "0.36", deprecation_message) + return create_controlnet_diffusers_config_from_ldm(original_config, checkpoint, image_size, **kwargs) + +def create_vae_diffusers_config_from_ldm(original_config, checkpoint, image_size=None, scaling_factor=None): + from .single_file.single_file_utils import create_vae_diffusers_config_from_ldm + + deprecation_message = "Importing `create_vae_diffusers_config_from_ldm()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import create_vae_diffusers_config_from_ldm` instead." + deprecate("diffusers.loaders.single_file_utils.create_vae_diffusers_config_from_ldm", "0.36", deprecation_message) + return create_vae_diffusers_config_from_ldm(original_config, checkpoint, image_size, scaling_factor) + +def update_unet_resnet_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint, mapping=None): + from .single_file.single_file_utils import update_unet_resnet_ldm_to_diffusers + + deprecation_message = "Importing `update_unet_resnet_ldm_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import update_unet_resnet_ldm_to_diffusers` instead." 
+ deprecate("diffusers.loaders.single_file_utils.update_unet_resnet_ldm_to_diffusers", "0.36", deprecation_message) + + return update_unet_resnet_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint, mapping) + + +def update_unet_attention_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint, mapping): + from .single_file.single_file_utils import update_unet_attention_ldm_to_diffusers + + deprecation_message = "Importing `update_unet_attention_ldm_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import update_unet_attention_ldm_to_diffusers` instead." + deprecate("diffusers.loaders.single_file_utils.update_unet_attention_ldm_to_diffusers", "0.36", deprecation_message) + + return update_unet_attention_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint, mapping) + + +def update_vae_resnet_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping): + from .single_file.single_file_utils import update_vae_resnet_ldm_to_diffusers + + deprecation_message = "Importing `update_vae_resnet_ldm_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import update_vae_resnet_ldm_to_diffusers` instead." + deprecate("diffusers.loaders.single_file_utils.update_vae_resnet_ldm_to_diffusers", "0.36", deprecation_message) + + return update_vae_resnet_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping) + + +def update_vae_attentions_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping): + from .single_file.single_file_utils import update_vae_attentions_ldm_to_diffusers + + deprecation_message = "Importing `update_vae_attentions_ldm_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import update_vae_attentions_ldm_to_diffusers` instead." + deprecate("diffusers.loaders.single_file_utils.update_vae_attentions_ldm_to_diffusers", "0.36", deprecation_message) + + return update_vae_attentions_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping) + +def convert_stable_cascade_unet_single_file_to_diffusers(checkpoint, **kwargs): + from .single_file.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers + deprecation_message = "Importing `convert_stable_cascade_unet_single_file_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers` instead." + deprecate("diffusers.loaders.single_file_utils.convert_stable_cascade_unet_single_file_to_diffusers", "0.36", deprecation_message) + return convert_stable_cascade_unet_single_file_to_diffusers(checkpoint, **kwargs) + + +def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False, **kwargs): + from .single_file.single_file_utils import convert_ldm_unet_checkpoint + deprecation_message = "Importing `convert_ldm_unet_checkpoint()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_ldm_unet_checkpoint` instead." 
+ deprecate("diffusers.loaders.single_file_utils.convert_ldm_unet_checkpoint", "0.36", deprecation_message) + return convert_ldm_unet_checkpoint(checkpoint, config, extract_ema, **kwargs) + + +def convert_controlnet_checkpoint( + checkpoint, + config, + **kwargs, +): + from .single_file.single_file_utils import convert_controlnet_checkpoint + deprecation_message = "Importing `convert_controlnet_checkpoint()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_controlnet_checkpoint` instead." + deprecate("diffusers.loaders.single_file_utils.convert_controlnet_checkpoint", "0.36", deprecation_message) + return convert_controlnet_checkpoint(checkpoint, config, **kwargs) + + + +def convert_ldm_vae_checkpoint(checkpoint, config): + from .single_file.single_file_utils import convert_ldm_vae_checkpoint + deprecation_message = "Importing `convert_ldm_vae_checkpoint()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_ldm_vae_checkpoint` instead." + deprecate("diffusers.loaders.single_file_utils.convert_ldm_vae_checkpoint", "0.36", deprecation_message) + return convert_ldm_vae_checkpoint(checkpoint, config, config) + + +def convert_ldm_clip_checkpoint(checkpoint, remove_prefix=None): + keys = list(checkpoint.keys()) + text_model_dict = {} + + remove_prefixes = [] + remove_prefixes.extend(LDM_CLIP_PREFIX_TO_REMOVE) + if remove_prefix: + remove_prefixes.append(remove_prefix) + + for key in keys: + for prefix in remove_prefixes: + if key.startswith(prefix): + diffusers_key = key.replace(prefix, "") + text_model_dict[diffusers_key] = checkpoint.get(key) + + return text_model_dict + + +def convert_open_clip_checkpoint( + text_model, + checkpoint, + prefix="cond_stage_model.model.", +): + text_model_dict = {} + text_proj_key = prefix + "text_projection" + + if text_proj_key in checkpoint: + text_proj_dim = int(checkpoint[text_proj_key].shape[0]) + elif hasattr(text_model.config, "hidden_size"): + text_proj_dim = text_model.config.hidden_size + else: + text_proj_dim = LDM_OPEN_CLIP_TEXT_PROJECTION_DIM + + keys = list(checkpoint.keys()) + keys_to_ignore = SD_2_TEXT_ENCODER_KEYS_TO_IGNORE + + openclip_diffusers_ldm_map = DIFFUSERS_TO_LDM_MAPPING["openclip"]["layers"] + for diffusers_key, ldm_key in openclip_diffusers_ldm_map.items(): + ldm_key = prefix + ldm_key + if ldm_key not in checkpoint: + continue + if ldm_key in keys_to_ignore: + continue + if ldm_key.endswith("text_projection"): + text_model_dict[diffusers_key] = checkpoint[ldm_key].T.contiguous() + else: + text_model_dict[diffusers_key] = checkpoint[ldm_key] + + for key in keys: + if key in keys_to_ignore: + continue + + if not key.startswith(prefix + "transformer."): + continue + + diffusers_key = key.replace(prefix + "transformer.", "") + transformer_diffusers_to_ldm_map = DIFFUSERS_TO_LDM_MAPPING["openclip"]["transformer"] + for new_key, old_key in transformer_diffusers_to_ldm_map.items(): + diffusers_key = ( + diffusers_key.replace(old_key, new_key).replace(".in_proj_weight", "").replace(".in_proj_bias", "") + ) + + if key.endswith(".in_proj_weight"): + weight_value = checkpoint.get(key) + + text_model_dict[diffusers_key + ".q_proj.weight"] = weight_value[:text_proj_dim, :].clone().detach() + text_model_dict[diffusers_key + ".k_proj.weight"] = ( + weight_value[text_proj_dim : text_proj_dim * 2, :].clone().detach() + ) + text_model_dict[diffusers_key + ".v_proj.weight"] = 
weight_value[text_proj_dim * 2 :, :].clone().detach() + + elif key.endswith(".in_proj_bias"): + weight_value = checkpoint.get(key) + text_model_dict[diffusers_key + ".q_proj.bias"] = weight_value[:text_proj_dim].clone().detach() + text_model_dict[diffusers_key + ".k_proj.bias"] = ( + weight_value[text_proj_dim : text_proj_dim * 2].clone().detach() + ) + text_model_dict[diffusers_key + ".v_proj.bias"] = weight_value[text_proj_dim * 2 :].clone().detach() + else: + text_model_dict[diffusers_key] = checkpoint.get(key) + + return text_model_dict + + +def create_diffusers_clip_model_from_ldm( + cls, + checkpoint, + subfolder="", + config=None, + torch_dtype=None, + local_files_only=None, + is_legacy_loading=False, +): + if config: + config = {"pretrained_model_name_or_path": config} + else: + config = fetch_diffusers_config(checkpoint) + + # For backwards compatibility + # Older versions of `from_single_file` expected CLIP configs to be placed in their original transformers model repo + # in the cache_dir, rather than in a subfolder of the Diffusers model + if is_legacy_loading: + logger.warning( + ( + "Detected legacy CLIP loading behavior. Please run `from_single_file` with `local_files_only=False once to update " + "the local cache directory with the necessary CLIP model config files. " + "Attempting to load CLIP model from legacy cache directory." + ) + ) + + if is_clip_model(checkpoint) or is_clip_sdxl_model(checkpoint): + clip_config = "openai/clip-vit-large-patch14" + config["pretrained_model_name_or_path"] = clip_config + subfolder = "" + + elif is_open_clip_model(checkpoint): + clip_config = "stabilityai/stable-diffusion-2" + config["pretrained_model_name_or_path"] = clip_config + subfolder = "text_encoder" + + else: + clip_config = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" + config["pretrained_model_name_or_path"] = clip_config + subfolder = "" + + model_config = cls.config_class.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only) + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + model = cls(model_config) + + position_embedding_dim = model.text_model.embeddings.position_embedding.weight.shape[-1] + + if is_clip_model(checkpoint): + diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint) + + elif ( + is_clip_sdxl_model(checkpoint) + and checkpoint[CHECKPOINT_KEY_NAMES["clip_sdxl"]].shape[-1] == position_embedding_dim + ): + diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint) + + elif ( + is_clip_sd3_model(checkpoint) + and checkpoint[CHECKPOINT_KEY_NAMES["clip_sd3"]].shape[-1] == position_embedding_dim + ): + diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_l.transformer.") + diffusers_format_checkpoint["text_projection.weight"] = torch.eye(position_embedding_dim) + + elif is_open_clip_model(checkpoint): + prefix = "cond_stage_model.model." + diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix) + + elif ( + is_open_clip_sdxl_model(checkpoint) + and checkpoint[CHECKPOINT_KEY_NAMES["open_clip_sdxl"]].shape[-1] == position_embedding_dim + ): + prefix = "conditioner.embedders.1.model." + diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix) + + elif is_open_clip_sdxl_refiner_model(checkpoint): + prefix = "conditioner.embedders.0.model." 
+ diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix) + + elif ( + is_open_clip_sd3_model(checkpoint) + and checkpoint[CHECKPOINT_KEY_NAMES["open_clip_sd3"]].shape[-1] == position_embedding_dim + ): + diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_g.transformer.") + + else: + raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.") + + if is_accelerate_available(): + load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) + else: + model.load_state_dict(diffusers_format_checkpoint, strict=False) + + if torch_dtype is not None: + model.to(torch_dtype) + + model.eval() + + return model + + +# in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale; +# while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation +def swap_scale_shift(weight, dim): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + + +def swap_proj_gate(weight): + proj, gate = weight.chunk(2, dim=0) + new_weight = torch.cat([gate, proj], dim=0) + return new_weight + + +def get_attn2_layers(state_dict): + attn2_layers = [] + for key in state_dict.keys(): + if "attn2." in key: + # Extract the layer number from the key + layer_num = int(key.split(".")[1]) + attn2_layers.append(layer_num) + + return tuple(sorted(set(attn2_layers))) + + +def get_caption_projection_dim(state_dict): + caption_projection_dim = state_dict["context_embedder.weight"].shape[0] + return caption_projection_dim + + +def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): + converted_state_dict = {} + keys = list(checkpoint.keys()) + for k in keys: + if "model.diffusion_model." in k: + checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) + + num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "joint_blocks" in k))[-1] + 1 # noqa: C401 + dual_attention_layers = get_attn2_layers(checkpoint) + + caption_projection_dim = get_caption_projection_dim(checkpoint) + has_qk_norm = any("ln_q" in key for key in checkpoint.keys()) + + # Positional and patch embeddings. + converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("pos_embed") + converted_state_dict["pos_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight") + converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias") + + # Timestep embeddings. + converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop( + "t_embedder.mlp.0.weight" + ) + converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias") + converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop( + "t_embedder.mlp.2.weight" + ) + converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias") + + # Context projections. + converted_state_dict["context_embedder.weight"] = checkpoint.pop("context_embedder.weight") + converted_state_dict["context_embedder.bias"] = checkpoint.pop("context_embedder.bias") + + # Pooled context projection. 
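+    # (y_embedder is the MLP over the pooled text embedding; in diffusers it
+    # becomes time_text_embed.text_embedder)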
+ converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("y_embedder.mlp.0.weight") + converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("y_embedder.mlp.0.bias") + converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop("y_embedder.mlp.2.weight") + converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("y_embedder.mlp.2.bias") + + # Transformer blocks 🎸. + for i in range(num_layers): + # Q, K, V + sample_q, sample_k, sample_v = torch.chunk( + checkpoint.pop(f"joint_blocks.{i}.x_block.attn.qkv.weight"), 3, dim=0 + ) + context_q, context_k, context_v = torch.chunk( + checkpoint.pop(f"joint_blocks.{i}.context_block.attn.qkv.weight"), 3, dim=0 + ) + sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk( + checkpoint.pop(f"joint_blocks.{i}.x_block.attn.qkv.bias"), 3, dim=0 + ) + context_q_bias, context_k_bias, context_v_bias = torch.chunk( + checkpoint.pop(f"joint_blocks.{i}.context_block.attn.qkv.bias"), 3, dim=0 + ) + + converted_state_dict[f"transformer_blocks.{i}.attn.to_q.weight"] = torch.cat([sample_q]) + converted_state_dict[f"transformer_blocks.{i}.attn.to_q.bias"] = torch.cat([sample_q_bias]) + converted_state_dict[f"transformer_blocks.{i}.attn.to_k.weight"] = torch.cat([sample_k]) + converted_state_dict[f"transformer_blocks.{i}.attn.to_k.bias"] = torch.cat([sample_k_bias]) + converted_state_dict[f"transformer_blocks.{i}.attn.to_v.weight"] = torch.cat([sample_v]) + converted_state_dict[f"transformer_blocks.{i}.attn.to_v.bias"] = torch.cat([sample_v_bias]) + + converted_state_dict[f"transformer_blocks.{i}.attn.add_q_proj.weight"] = torch.cat([context_q]) + converted_state_dict[f"transformer_blocks.{i}.attn.add_q_proj.bias"] = torch.cat([context_q_bias]) + converted_state_dict[f"transformer_blocks.{i}.attn.add_k_proj.weight"] = torch.cat([context_k]) + converted_state_dict[f"transformer_blocks.{i}.attn.add_k_proj.bias"] = torch.cat([context_k_bias]) + converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.weight"] = torch.cat([context_v]) + converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.bias"] = torch.cat([context_v_bias]) + + # qk norm + if has_qk_norm: + converted_state_dict[f"transformer_blocks.{i}.attn.norm_q.weight"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.attn.ln_q.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn.norm_k.weight"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.attn.ln_k.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_q.weight"] = checkpoint.pop( + f"joint_blocks.{i}.context_block.attn.ln_q.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_k.weight"] = checkpoint.pop( + f"joint_blocks.{i}.context_block.attn.ln_k.weight" + ) + + # output projections. 
+ converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.weight"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.attn.proj.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.bias"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.attn.proj.bias" + ) + if not (i == num_layers - 1): + converted_state_dict[f"transformer_blocks.{i}.attn.to_add_out.weight"] = checkpoint.pop( + f"joint_blocks.{i}.context_block.attn.proj.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn.to_add_out.bias"] = checkpoint.pop( + f"joint_blocks.{i}.context_block.attn.proj.bias" + ) + + if i in dual_attention_layers: + # Q, K, V + sample_q2, sample_k2, sample_v2 = torch.chunk( + checkpoint.pop(f"joint_blocks.{i}.x_block.attn2.qkv.weight"), 3, dim=0 + ) + sample_q2_bias, sample_k2_bias, sample_v2_bias = torch.chunk( + checkpoint.pop(f"joint_blocks.{i}.x_block.attn2.qkv.bias"), 3, dim=0 + ) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = torch.cat([sample_q2]) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = torch.cat([sample_q2_bias]) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = torch.cat([sample_k2]) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = torch.cat([sample_k2_bias]) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = torch.cat([sample_v2]) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = torch.cat([sample_v2_bias]) + + # qk norm + if has_qk_norm: + converted_state_dict[f"transformer_blocks.{i}.attn2.norm_q.weight"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.attn2.ln_q.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn2.norm_k.weight"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.attn2.ln_k.weight" + ) + + # output projections. + converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.attn2.proj.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.attn2.proj.bias" + ) + + # norms. + converted_state_dict[f"transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.adaLN_modulation.1.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.norm1.linear.bias"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.adaLN_modulation.1.bias" + ) + if not (i == num_layers - 1): + converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.weight"] = checkpoint.pop( + f"joint_blocks.{i}.context_block.adaLN_modulation.1.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.bias"] = checkpoint.pop( + f"joint_blocks.{i}.context_block.adaLN_modulation.1.bias" + ) + else: + converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.weight"] = swap_scale_shift( + checkpoint.pop(f"joint_blocks.{i}.context_block.adaLN_modulation.1.weight"), + dim=caption_projection_dim, + ) + converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.bias"] = swap_scale_shift( + checkpoint.pop(f"joint_blocks.{i}.context_block.adaLN_modulation.1.bias"), + dim=caption_projection_dim, + ) + + # ffs. 
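+        # ("ffs" = the feed-forward/MLP sub-layers: x_block.mlp.fc1/fc2 map to
+        # ff.net.0.proj/ff.net.2, and the context block's MLP maps to ff_context)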
+ converted_state_dict[f"transformer_blocks.{i}.ff.net.0.proj.weight"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.mlp.fc1.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.ff.net.0.proj.bias"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.mlp.fc1.bias" + ) + converted_state_dict[f"transformer_blocks.{i}.ff.net.2.weight"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.mlp.fc2.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.ff.net.2.bias"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.mlp.fc2.bias" + ) + if not (i == num_layers - 1): + converted_state_dict[f"transformer_blocks.{i}.ff_context.net.0.proj.weight"] = checkpoint.pop( + f"joint_blocks.{i}.context_block.mlp.fc1.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.ff_context.net.0.proj.bias"] = checkpoint.pop( + f"joint_blocks.{i}.context_block.mlp.fc1.bias" + ) + converted_state_dict[f"transformer_blocks.{i}.ff_context.net.2.weight"] = checkpoint.pop( + f"joint_blocks.{i}.context_block.mlp.fc2.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.ff_context.net.2.bias"] = checkpoint.pop( + f"joint_blocks.{i}.context_block.mlp.fc2.bias" + ) + + # Final blocks. + converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight") + converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias") + converted_state_dict["norm_out.linear.weight"] = swap_scale_shift( + checkpoint.pop("final_layer.adaLN_modulation.1.weight"), dim=caption_projection_dim + ) + converted_state_dict["norm_out.linear.bias"] = swap_scale_shift( + checkpoint.pop("final_layer.adaLN_modulation.1.bias"), dim=caption_projection_dim + ) + + return converted_state_dict + + +def is_t5_in_single_file(checkpoint): + if "text_encoders.t5xxl.transformer.shared.weight" in checkpoint: + return True + + return False + + +def convert_sd3_t5_checkpoint_to_diffusers(checkpoint): + keys = list(checkpoint.keys()) + text_model_dict = {} + + remove_prefixes = ["text_encoders.t5xxl.transformer."] + + for key in keys: + for prefix in remove_prefixes: + if key.startswith(prefix): + diffusers_key = key.replace(prefix, "") + text_model_dict[diffusers_key] = checkpoint.get(key) + + return text_model_dict + + +def create_diffusers_t5_model_from_checkpoint( + cls, + checkpoint, + subfolder="", + config=None, + torch_dtype=None, + local_files_only=None, +): + if config: + config = {"pretrained_model_name_or_path": config} + else: + config = fetch_diffusers_config(checkpoint) + + model_config = cls.config_class.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only) + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + model = cls(model_config) + + diffusers_format_checkpoint = convert_sd3_t5_checkpoint_to_diffusers(checkpoint) + + if is_accelerate_available(): + load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) + else: + model.load_state_dict(diffusers_format_checkpoint) + + use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (torch_dtype == torch.float16) + if use_keep_in_fp32_modules: + keep_in_fp32_modules = model._keep_in_fp32_modules + else: + keep_in_fp32_modules = [] + + if keep_in_fp32_modules is not None: + for name, param in model.named_parameters(): + if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules): + # param = param.to(torch.float32) does not work here as only in the local scope. 
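+                # Assigning to `param.data` mutates the existing parameter tensor in
+                # place, so these modules stay in float32 even after the rest of the
+                # model is cast to `torch_dtype`.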
+ param.data = param.data.to(torch.float32) + + return model + + +def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs): + converted_state_dict = {} + for k, v in checkpoint.items(): + if "pos_encoder" in k: + continue + + else: + converted_state_dict[ + k.replace(".norms.0", ".norm1") + .replace(".norms.1", ".norm2") + .replace(".ff_norm", ".norm3") + .replace(".attention_blocks.0", ".attn1") + .replace(".attention_blocks.1", ".attn2") + .replace(".temporal_transformer", "") + ] = v + + return converted_state_dict + + +def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): + converted_state_dict = {} + keys = list(checkpoint.keys()) + + for k in keys: + if "model.diffusion_model." in k: + checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) + + num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1 # noqa: C401 + num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1 # noqa: C401 + mlp_ratio = 4.0 + inner_dim = 3072 + + # in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale; + # while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation + def swap_scale_shift(weight): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + + ## time_text_embed.timestep_embedder <- time_in + converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop( + "time_in.in_layer.weight" + ) + converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("time_in.in_layer.bias") + converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop( + "time_in.out_layer.weight" + ) + converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("time_in.out_layer.bias") + + ## time_text_embed.text_embedder <- vector_in + converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("vector_in.in_layer.weight") + converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("vector_in.in_layer.bias") + converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop( + "vector_in.out_layer.weight" + ) + converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("vector_in.out_layer.bias") + + # guidance + has_guidance = any("guidance" in k for k in checkpoint) + if has_guidance: + converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = checkpoint.pop( + "guidance_in.in_layer.weight" + ) + converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = checkpoint.pop( + "guidance_in.in_layer.bias" + ) + converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = checkpoint.pop( + "guidance_in.out_layer.weight" + ) + converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = checkpoint.pop( + "guidance_in.out_layer.bias" + ) + + # context_embedder + converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight") + converted_state_dict["context_embedder.bias"] = checkpoint.pop("txt_in.bias") + + # x_embedder + converted_state_dict["x_embedder.weight"] = checkpoint.pop("img_in.weight") + converted_state_dict["x_embedder.bias"] = checkpoint.pop("img_in.bias") + + # double transformer blocks + for i in 
range(num_layers): + block_prefix = f"transformer_blocks.{i}." + # norms. + ## norm1 + converted_state_dict[f"{block_prefix}norm1.linear.weight"] = checkpoint.pop( + f"double_blocks.{i}.img_mod.lin.weight" + ) + converted_state_dict[f"{block_prefix}norm1.linear.bias"] = checkpoint.pop( + f"double_blocks.{i}.img_mod.lin.bias" + ) + ## norm1_context + converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = checkpoint.pop( + f"double_blocks.{i}.txt_mod.lin.weight" + ) + converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = checkpoint.pop( + f"double_blocks.{i}.txt_mod.lin.bias" + ) + # Q, K, V + sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0) + context_q, context_k, context_v = torch.chunk( + checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0 + ) + sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk( + checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0 + ) + context_q_bias, context_k_bias, context_v_bias = torch.chunk( + checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0 + ) + converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q]) + converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias]) + converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k]) + converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias]) + converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v]) + converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias]) + converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q]) + converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias]) + converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k]) + converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias]) + converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v]) + converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias]) + # qk_norm + converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop( + f"double_blocks.{i}.img_attn.norm.query_norm.scale" + ) + converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop( + f"double_blocks.{i}.img_attn.norm.key_norm.scale" + ) + converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = checkpoint.pop( + f"double_blocks.{i}.txt_attn.norm.query_norm.scale" + ) + converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = checkpoint.pop( + f"double_blocks.{i}.txt_attn.norm.key_norm.scale" + ) + # ff img_mlp + converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = checkpoint.pop( + f"double_blocks.{i}.img_mlp.0.weight" + ) + converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.0.bias") + converted_state_dict[f"{block_prefix}ff.net.2.weight"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.weight") + converted_state_dict[f"{block_prefix}ff.net.2.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.bias") + converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = checkpoint.pop( + f"double_blocks.{i}.txt_mlp.0.weight" + ) + converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = checkpoint.pop( + f"double_blocks.{i}.txt_mlp.0.bias" + ) + converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = checkpoint.pop( + 
f"double_blocks.{i}.txt_mlp.2.weight" + ) + converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = checkpoint.pop( + f"double_blocks.{i}.txt_mlp.2.bias" + ) + # output projections. + converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = checkpoint.pop( + f"double_blocks.{i}.img_attn.proj.weight" + ) + converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = checkpoint.pop( + f"double_blocks.{i}.img_attn.proj.bias" + ) + converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = checkpoint.pop( + f"double_blocks.{i}.txt_attn.proj.weight" + ) + converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = checkpoint.pop( + f"double_blocks.{i}.txt_attn.proj.bias" + ) + + # single transfomer blocks + for i in range(num_single_layers): + block_prefix = f"single_transformer_blocks.{i}." + # norm.linear <- single_blocks.0.modulation.lin + converted_state_dict[f"{block_prefix}norm.linear.weight"] = checkpoint.pop( + f"single_blocks.{i}.modulation.lin.weight" + ) + converted_state_dict[f"{block_prefix}norm.linear.bias"] = checkpoint.pop( + f"single_blocks.{i}.modulation.lin.bias" + ) + # Q, K, V, mlp + mlp_hidden_dim = int(inner_dim * mlp_ratio) + split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim) + q, k, v, mlp = torch.split(checkpoint.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0) + q_bias, k_bias, v_bias, mlp_bias = torch.split( + checkpoint.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0 + ) + converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q]) + converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias]) + converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k]) + converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias]) + converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v]) + converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias]) + converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp]) + converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias]) + # qk norm + converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop( + f"single_blocks.{i}.norm.query_norm.scale" + ) + converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop( + f"single_blocks.{i}.norm.key_norm.scale" + ) + # output projections. 
+ converted_state_dict[f"{block_prefix}proj_out.weight"] = checkpoint.pop(f"single_blocks.{i}.linear2.weight") + converted_state_dict[f"{block_prefix}proj_out.bias"] = checkpoint.pop(f"single_blocks.{i}.linear2.bias") + + converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight") + converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias") + converted_state_dict["norm_out.linear.weight"] = swap_scale_shift( + checkpoint.pop("final_layer.adaLN_modulation.1.weight") + ) + converted_state_dict["norm_out.linear.bias"] = swap_scale_shift( + checkpoint.pop("final_layer.adaLN_modulation.1.bias") + ) + + return converted_state_dict + + +def convert_ltx_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): + converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "vae" not in key} + + TRANSFORMER_KEYS_RENAME_DICT = { + "model.diffusion_model.": "", + "patchify_proj": "proj_in", + "adaln_single": "time_embed", + "q_norm": "norm_q", + "k_norm": "norm_k", + } + + TRANSFORMER_SPECIAL_KEYS_REMAP = {} + + for key in list(converted_state_dict.keys()): + new_key = key + for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + converted_state_dict[new_key] = converted_state_dict.pop(key) + + for key in list(converted_state_dict.keys()): + for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items(): + if special_key not in key: + continue + handler_fn_inplace(key, converted_state_dict) + + return converted_state_dict + + +def convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs): + converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "vae." in key} + + def remove_keys_(key: str, state_dict): + state_dict.pop(key) + + VAE_KEYS_RENAME_DICT = { + # common + "vae.": "", + # decoder + "up_blocks.0": "mid_block", + "up_blocks.1": "up_blocks.0", + "up_blocks.2": "up_blocks.1.upsamplers.0", + "up_blocks.3": "up_blocks.1", + "up_blocks.4": "up_blocks.2.conv_in", + "up_blocks.5": "up_blocks.2.upsamplers.0", + "up_blocks.6": "up_blocks.2", + "up_blocks.7": "up_blocks.3.conv_in", + "up_blocks.8": "up_blocks.3.upsamplers.0", + "up_blocks.9": "up_blocks.3", + # encoder + "down_blocks.0": "down_blocks.0", + "down_blocks.1": "down_blocks.0.downsamplers.0", + "down_blocks.2": "down_blocks.0.conv_out", + "down_blocks.3": "down_blocks.1", + "down_blocks.4": "down_blocks.1.downsamplers.0", + "down_blocks.5": "down_blocks.1.conv_out", + "down_blocks.6": "down_blocks.2", + "down_blocks.7": "down_blocks.2.downsamplers.0", + "down_blocks.8": "down_blocks.3", + "down_blocks.9": "mid_block", + # common + "conv_shortcut": "conv_shortcut.conv", + "res_blocks": "resnets", + "norm3.norm": "norm3", + "per_channel_statistics.mean-of-means": "latents_mean", + "per_channel_statistics.std-of-means": "latents_std", + } + + VAE_091_RENAME_DICT = { + # decoder + "up_blocks.0": "mid_block", + "up_blocks.1": "up_blocks.0.upsamplers.0", + "up_blocks.2": "up_blocks.0", + "up_blocks.3": "up_blocks.1.upsamplers.0", + "up_blocks.4": "up_blocks.1", + "up_blocks.5": "up_blocks.2.upsamplers.0", + "up_blocks.6": "up_blocks.2", + "up_blocks.7": "up_blocks.3.upsamplers.0", + "up_blocks.8": "up_blocks.3", + # common + "last_time_embedder": "time_embedder", + "last_scale_shift_table": "scale_shift_table", + } + + VAE_095_RENAME_DICT = { + # decoder + "up_blocks.0": "mid_block", + "up_blocks.1": "up_blocks.0.upsamplers.0", + "up_blocks.2": 
"up_blocks.0", + "up_blocks.3": "up_blocks.1.upsamplers.0", + "up_blocks.4": "up_blocks.1", + "up_blocks.5": "up_blocks.2.upsamplers.0", + "up_blocks.6": "up_blocks.2", + "up_blocks.7": "up_blocks.3.upsamplers.0", + "up_blocks.8": "up_blocks.3", + # encoder + "down_blocks.0": "down_blocks.0", + "down_blocks.1": "down_blocks.0.downsamplers.0", + "down_blocks.2": "down_blocks.1", + "down_blocks.3": "down_blocks.1.downsamplers.0", + "down_blocks.4": "down_blocks.2", + "down_blocks.5": "down_blocks.2.downsamplers.0", + "down_blocks.6": "down_blocks.3", + "down_blocks.7": "down_blocks.3.downsamplers.0", + "down_blocks.8": "mid_block", + # common + "last_time_embedder": "time_embedder", + "last_scale_shift_table": "scale_shift_table", + } + + VAE_SPECIAL_KEYS_REMAP = { + "per_channel_statistics.channel": remove_keys_, + "per_channel_statistics.mean-of-means": remove_keys_, + "per_channel_statistics.mean-of-stds": remove_keys_, + } + + if converted_state_dict["vae.encoder.conv_out.conv.weight"].shape[1] == 2048: + VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT) + elif "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict: + VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT) + + for key in list(converted_state_dict.keys()): + new_key = key + for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + converted_state_dict[new_key] = converted_state_dict.pop(key) + + for key in list(converted_state_dict.keys()): + for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items(): + if special_key not in key: + continue + handler_fn_inplace(key, converted_state_dict) + + return converted_state_dict + + +def convert_autoencoder_dc_checkpoint_to_diffusers(checkpoint, **kwargs): + converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())} + + def remap_qkv_(key: str, state_dict): + qkv = state_dict.pop(key) + q, k, v = torch.chunk(qkv, 3, dim=0) + parent_module, _, _ = key.rpartition(".qkv.conv.weight") + state_dict[f"{parent_module}.to_q.weight"] = q.squeeze() + state_dict[f"{parent_module}.to_k.weight"] = k.squeeze() + state_dict[f"{parent_module}.to_v.weight"] = v.squeeze() + + def remap_proj_conv_(key: str, state_dict): + parent_module, _, _ = key.rpartition(".proj.conv.weight") + state_dict[f"{parent_module}.to_out.weight"] = state_dict.pop(key).squeeze() + + AE_KEYS_RENAME_DICT = { + # common + "main.": "", + "op_list.": "", + "context_module": "attn", + "local_module": "conv_out", + # NOTE: The below two lines work because scales in the available configs only have a tuple length of 1 + # If there were more scales, there would be more layers, so a loop would be better to handle this + "aggreg.0.0": "to_qkv_multiscale.0.proj_in", + "aggreg.0.1": "to_qkv_multiscale.0.proj_out", + "depth_conv.conv": "conv_depth", + "inverted_conv.conv": "conv_inverted", + "point_conv.conv": "conv_point", + "point_conv.norm": "norm", + "conv.conv.": "conv.", + "conv1.conv": "conv1", + "conv2.conv": "conv2", + "conv2.norm": "norm", + "proj.norm": "norm_out", + # encoder + "encoder.project_in.conv": "encoder.conv_in", + "encoder.project_out.0.conv": "encoder.conv_out", + "encoder.stages": "encoder.down_blocks", + # decoder + "decoder.project_in.conv": "decoder.conv_in", + "decoder.project_out.0": "decoder.norm_out", + "decoder.project_out.2.conv": "decoder.conv_out", + "decoder.stages": "decoder.up_blocks", + } + + AE_F32C32_F64C128_F128C512_KEYS = { + "encoder.project_in.conv": 
"encoder.conv_in.conv", + "decoder.project_out.2.conv": "decoder.conv_out.conv", + } + + AE_SPECIAL_KEYS_REMAP = { + "qkv.conv.weight": remap_qkv_, + "proj.conv.weight": remap_proj_conv_, + } + if "encoder.project_in.conv.bias" not in converted_state_dict: + AE_KEYS_RENAME_DICT.update(AE_F32C32_F64C128_F128C512_KEYS) + + for key in list(converted_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in AE_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + converted_state_dict[new_key] = converted_state_dict.pop(key) + + for key in list(converted_state_dict.keys()): + for special_key, handler_fn_inplace in AE_SPECIAL_KEYS_REMAP.items(): + if special_key not in key: + continue + handler_fn_inplace(key, converted_state_dict) + + return converted_state_dict + + +def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): + converted_state_dict = {} + + # Comfy checkpoints add this prefix + keys = list(checkpoint.keys()) + for k in keys: + if "model.diffusion_model." in k: + checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) + + # Convert patch_embed + converted_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight") + converted_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias") + + # Convert time_embed + converted_state_dict["time_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight") + converted_state_dict["time_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias") + converted_state_dict["time_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight") + converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias") + converted_state_dict["time_embed.pooler.to_kv.weight"] = checkpoint.pop("t5_y_embedder.to_kv.weight") + converted_state_dict["time_embed.pooler.to_kv.bias"] = checkpoint.pop("t5_y_embedder.to_kv.bias") + converted_state_dict["time_embed.pooler.to_q.weight"] = checkpoint.pop("t5_y_embedder.to_q.weight") + converted_state_dict["time_embed.pooler.to_q.bias"] = checkpoint.pop("t5_y_embedder.to_q.bias") + converted_state_dict["time_embed.pooler.to_out.weight"] = checkpoint.pop("t5_y_embedder.to_out.weight") + converted_state_dict["time_embed.pooler.to_out.bias"] = checkpoint.pop("t5_y_embedder.to_out.bias") + converted_state_dict["time_embed.caption_proj.weight"] = checkpoint.pop("t5_yproj.weight") + converted_state_dict["time_embed.caption_proj.bias"] = checkpoint.pop("t5_yproj.bias") + + # Convert transformer blocks + num_layers = 48 + for i in range(num_layers): + block_prefix = f"transformer_blocks.{i}." + old_prefix = f"blocks.{i}." 
+ + # norm1 + converted_state_dict[block_prefix + "norm1.linear.weight"] = checkpoint.pop(old_prefix + "mod_x.weight") + converted_state_dict[block_prefix + "norm1.linear.bias"] = checkpoint.pop(old_prefix + "mod_x.bias") + if i < num_layers - 1: + converted_state_dict[block_prefix + "norm1_context.linear.weight"] = checkpoint.pop( + old_prefix + "mod_y.weight" + ) + converted_state_dict[block_prefix + "norm1_context.linear.bias"] = checkpoint.pop( + old_prefix + "mod_y.bias" + ) + else: + converted_state_dict[block_prefix + "norm1_context.linear_1.weight"] = checkpoint.pop( + old_prefix + "mod_y.weight" + ) + converted_state_dict[block_prefix + "norm1_context.linear_1.bias"] = checkpoint.pop( + old_prefix + "mod_y.bias" + ) + + # Visual attention + qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_x.weight") + q, k, v = qkv_weight.chunk(3, dim=0) + + converted_state_dict[block_prefix + "attn1.to_q.weight"] = q + converted_state_dict[block_prefix + "attn1.to_k.weight"] = k + converted_state_dict[block_prefix + "attn1.to_v.weight"] = v + converted_state_dict[block_prefix + "attn1.norm_q.weight"] = checkpoint.pop( + old_prefix + "attn.q_norm_x.weight" + ) + converted_state_dict[block_prefix + "attn1.norm_k.weight"] = checkpoint.pop( + old_prefix + "attn.k_norm_x.weight" + ) + converted_state_dict[block_prefix + "attn1.to_out.0.weight"] = checkpoint.pop( + old_prefix + "attn.proj_x.weight" + ) + converted_state_dict[block_prefix + "attn1.to_out.0.bias"] = checkpoint.pop(old_prefix + "attn.proj_x.bias") + + # Context attention + qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_y.weight") + q, k, v = qkv_weight.chunk(3, dim=0) + + converted_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q + converted_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k + converted_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v + converted_state_dict[block_prefix + "attn1.norm_added_q.weight"] = checkpoint.pop( + old_prefix + "attn.q_norm_y.weight" + ) + converted_state_dict[block_prefix + "attn1.norm_added_k.weight"] = checkpoint.pop( + old_prefix + "attn.k_norm_y.weight" + ) + if i < num_layers - 1: + converted_state_dict[block_prefix + "attn1.to_add_out.weight"] = checkpoint.pop( + old_prefix + "attn.proj_y.weight" + ) + converted_state_dict[block_prefix + "attn1.to_add_out.bias"] = checkpoint.pop( + old_prefix + "attn.proj_y.bias" + ) + + # MLP + converted_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate( + checkpoint.pop(old_prefix + "mlp_x.w1.weight") + ) + converted_state_dict[block_prefix + "ff.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_x.w2.weight") + if i < num_layers - 1: + converted_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate( + checkpoint.pop(old_prefix + "mlp_y.w1.weight") + ) + converted_state_dict[block_prefix + "ff_context.net.2.weight"] = checkpoint.pop( + old_prefix + "mlp_y.w2.weight" + ) + + # Output layers + converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(checkpoint.pop("final_layer.mod.weight"), dim=0) + converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(checkpoint.pop("final_layer.mod.bias"), dim=0) + converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight") + converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias") + + converted_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies") + + return converted_state_dict + + +def convert_hunyuan_video_transformer_to_diffusers(checkpoint, **kwargs): + def 
remap_norm_scale_shift_(key, state_dict): + weight = state_dict.pop(key) + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight + + def remap_txt_in_(key, state_dict): + def rename_key(key): + new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks") + new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear") + new_key = new_key.replace("txt_in", "context_embedder") + new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1") + new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2") + new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder") + new_key = new_key.replace("mlp", "ff") + return new_key + + if "self_attn_qkv" in key: + weight = state_dict.pop(key) + to_q, to_k, to_v = weight.chunk(3, dim=0) + state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q + state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k + state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v + else: + state_dict[rename_key(key)] = state_dict.pop(key) + + def remap_img_attn_qkv_(key, state_dict): + weight = state_dict.pop(key) + to_q, to_k, to_v = weight.chunk(3, dim=0) + state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q + state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k + state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v + + def remap_txt_attn_qkv_(key, state_dict): + weight = state_dict.pop(key) + to_q, to_k, to_v = weight.chunk(3, dim=0) + state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q + state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k + state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v + + def remap_single_transformer_blocks_(key, state_dict): + hidden_size = 3072 + + if "linear1.weight" in key: + linear1_weight = state_dict.pop(key) + split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size) + q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0) + new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.weight") + state_dict[f"{new_key}.attn.to_q.weight"] = q + state_dict[f"{new_key}.attn.to_k.weight"] = k + state_dict[f"{new_key}.attn.to_v.weight"] = v + state_dict[f"{new_key}.proj_mlp.weight"] = mlp + + elif "linear1.bias" in key: + linear1_bias = state_dict.pop(key) + split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size) + q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0) + new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.bias") + state_dict[f"{new_key}.attn.to_q.bias"] = q_bias + state_dict[f"{new_key}.attn.to_k.bias"] = k_bias + state_dict[f"{new_key}.attn.to_v.bias"] = v_bias + state_dict[f"{new_key}.proj_mlp.bias"] = mlp_bias + + else: + new_key = key.replace("single_blocks", "single_transformer_blocks") + new_key = new_key.replace("linear2", "proj_out") + new_key = new_key.replace("q_norm", "attn.norm_q") + new_key = new_key.replace("k_norm", "attn.norm_k") + state_dict[new_key] = state_dict.pop(key) + + TRANSFORMER_KEYS_RENAME_DICT = { + "img_in": "x_embedder", + "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1", + "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2", + "guidance_in.mlp.0": 
"time_text_embed.guidance_embedder.linear_1", + "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2", + "vector_in.in_layer": "time_text_embed.text_embedder.linear_1", + "vector_in.out_layer": "time_text_embed.text_embedder.linear_2", + "double_blocks": "transformer_blocks", + "img_attn_q_norm": "attn.norm_q", + "img_attn_k_norm": "attn.norm_k", + "img_attn_proj": "attn.to_out.0", + "txt_attn_q_norm": "attn.norm_added_q", + "txt_attn_k_norm": "attn.norm_added_k", + "txt_attn_proj": "attn.to_add_out", + "img_mod.linear": "norm1.linear", + "img_norm1": "norm1.norm", + "img_norm2": "norm2", + "img_mlp": "ff", + "txt_mod.linear": "norm1_context.linear", + "txt_norm1": "norm1.norm", + "txt_norm2": "norm2_context", + "txt_mlp": "ff_context", + "self_attn_proj": "attn.to_out.0", + "modulation.linear": "norm.linear", + "pre_norm": "norm.norm", + "final_layer.norm_final": "norm_out.norm", + "final_layer.linear": "proj_out", + "fc1": "net.0.proj", + "fc2": "net.2", + "input_embedder": "proj_in", + } + + TRANSFORMER_SPECIAL_KEYS_REMAP = { + "txt_in": remap_txt_in_, + "img_attn_qkv": remap_img_attn_qkv_, + "txt_attn_qkv": remap_txt_attn_qkv_, + "single_blocks": remap_single_transformer_blocks_, + "final_layer.adaLN_modulation.1": remap_norm_scale_shift_, + } + + def update_state_dict_(state_dict, old_key, new_key): + state_dict[new_key] = state_dict.pop(old_key) + + for key in list(checkpoint.keys()): + new_key = key[:] + for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_(checkpoint, key, new_key) + + for key in list(checkpoint.keys()): + for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items(): + if special_key not in key: + continue + handler_fn_inplace(key, checkpoint) + + return checkpoint + + +def convert_auraflow_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): + converted_state_dict = {} + state_dict_keys = list(checkpoint.keys()) + + # Handle register tokens and positional embeddings + converted_state_dict["register_tokens"] = checkpoint.pop("register_tokens", None) + + # Handle time step projection + converted_state_dict["time_step_proj.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight", None) + converted_state_dict["time_step_proj.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias", None) + converted_state_dict["time_step_proj.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight", None) + converted_state_dict["time_step_proj.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias", None) + + # Handle context embedder + converted_state_dict["context_embedder.weight"] = checkpoint.pop("cond_seq_linear.weight", None) + + # Calculate the number of layers + def calculate_layers(keys, key_prefix): + layers = set() + for k in keys: + if key_prefix in k: + layer_num = int(k.split(".")[1]) # get the layer number + layers.add(layer_num) + return len(layers) + + mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers") + single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers") + + # MMDiT blocks + for i in range(mmdit_layers): + # Feed-forward + path_mapping = {"mlpX": "ff", "mlpC": "ff_context"} + weight_mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"} + for orig_k, diffuser_k in path_mapping.items(): + for k, v in weight_mapping.items(): + converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.{v}.weight"] = checkpoint.pop( + 
f"double_layers.{i}.{orig_k}.{k}.weight", None + ) + + # Norms + path_mapping = {"modX": "norm1", "modC": "norm1_context"} + for orig_k, diffuser_k in path_mapping.items(): + converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.linear.weight"] = checkpoint.pop( + f"double_layers.{i}.{orig_k}.1.weight", None + ) + + # Attentions + x_attn_mapping = {"w2q": "to_q", "w2k": "to_k", "w2v": "to_v", "w2o": "to_out.0"} + context_attn_mapping = {"w1q": "add_q_proj", "w1k": "add_k_proj", "w1v": "add_v_proj", "w1o": "to_add_out"} + for attn_mapping in [x_attn_mapping, context_attn_mapping]: + for k, v in attn_mapping.items(): + converted_state_dict[f"joint_transformer_blocks.{i}.attn.{v}.weight"] = checkpoint.pop( + f"double_layers.{i}.attn.{k}.weight", None + ) + + # Single-DiT blocks + for i in range(single_dit_layers): + # Feed-forward + mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"} + for k, v in mapping.items(): + converted_state_dict[f"single_transformer_blocks.{i}.ff.{v}.weight"] = checkpoint.pop( + f"single_layers.{i}.mlp.{k}.weight", None + ) + + # Norms + converted_state_dict[f"single_transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop( + f"single_layers.{i}.modCX.1.weight", None + ) + + # Attentions + x_attn_mapping = {"w1q": "to_q", "w1k": "to_k", "w1v": "to_v", "w1o": "to_out.0"} + for k, v in x_attn_mapping.items(): + converted_state_dict[f"single_transformer_blocks.{i}.attn.{v}.weight"] = checkpoint.pop( + f"single_layers.{i}.attn.{k}.weight", None + ) + # Final blocks + converted_state_dict["proj_out.weight"] = checkpoint.pop("final_linear.weight", None) + + # Handle the final norm layer + norm_weight = checkpoint.pop("modF.1.weight", None) + if norm_weight is not None: + converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(norm_weight, dim=None) + else: + converted_state_dict["norm_out.linear.weight"] = None + + converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("positional_encoding") + converted_state_dict["pos_embed.proj.weight"] = checkpoint.pop("init_x_linear.weight") + converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("init_x_linear.bias") + + return converted_state_dict + + +def convert_lumina2_to_diffusers(checkpoint, **kwargs): + converted_state_dict = {} + + # Original Lumina-Image-2 has an extra norm paramter that is unused + # We just remove it here + checkpoint.pop("norm_final.weight", None) + + # Comfy checkpoints add this prefix + keys = list(checkpoint.keys()) + for k in keys: + if "model.diffusion_model." 
in k: + checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) + + LUMINA_KEY_MAP = { + "cap_embedder": "time_caption_embed.caption_embedder", + "t_embedder.mlp.0": "time_caption_embed.timestep_embedder.linear_1", + "t_embedder.mlp.2": "time_caption_embed.timestep_embedder.linear_2", + "attention": "attn", + ".out.": ".to_out.0.", + "k_norm": "norm_k", + "q_norm": "norm_q", + "w1": "linear_1", + "w2": "linear_2", + "w3": "linear_3", + "adaLN_modulation.1": "norm1.linear", + } + ATTENTION_NORM_MAP = { + "attention_norm1": "norm1.norm", + "attention_norm2": "norm2", + } + CONTEXT_REFINER_MAP = { + "context_refiner.0.attention_norm1": "context_refiner.0.norm1", + "context_refiner.0.attention_norm2": "context_refiner.0.norm2", + "context_refiner.1.attention_norm1": "context_refiner.1.norm1", + "context_refiner.1.attention_norm2": "context_refiner.1.norm2", + } + FINAL_LAYER_MAP = { + "final_layer.adaLN_modulation.1": "norm_out.linear_1", + "final_layer.linear": "norm_out.linear_2", + } + + def convert_lumina_attn_to_diffusers(tensor, diffusers_key): + q_dim = 2304 + k_dim = v_dim = 768 + + to_q, to_k, to_v = torch.split(tensor, [q_dim, k_dim, v_dim], dim=0) + + return { + diffusers_key.replace("qkv", "to_q"): to_q, + diffusers_key.replace("qkv", "to_k"): to_k, + diffusers_key.replace("qkv", "to_v"): to_v, + } + + for key in keys: + diffusers_key = key + for k, v in CONTEXT_REFINER_MAP.items(): + diffusers_key = diffusers_key.replace(k, v) + for k, v in FINAL_LAYER_MAP.items(): + diffusers_key = diffusers_key.replace(k, v) + for k, v in ATTENTION_NORM_MAP.items(): + diffusers_key = diffusers_key.replace(k, v) + for k, v in LUMINA_KEY_MAP.items(): + diffusers_key = diffusers_key.replace(k, v) + + if "qkv" in diffusers_key: + converted_state_dict.update(convert_lumina_attn_to_diffusers(checkpoint.pop(key), diffusers_key)) + else: + converted_state_dict[diffusers_key] = checkpoint.pop(key) + + return converted_state_dict + + +def convert_sana_transformer_to_diffusers(checkpoint, **kwargs): + converted_state_dict = {} + keys = list(checkpoint.keys()) + for k in keys: + if "model.diffusion_model." in k: + checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) + + num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "blocks" in k))[-1] + 1 # noqa: C401 + + # Positional and patch embeddings. + checkpoint.pop("pos_embed") + converted_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight") + converted_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias") + + # Timestep embeddings. + converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = checkpoint.pop( + "t_embedder.mlp.0.weight" + ) + converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias") + converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = checkpoint.pop( + "t_embedder.mlp.2.weight" + ) + converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias") + converted_state_dict["time_embed.linear.weight"] = checkpoint.pop("t_block.1.weight") + converted_state_dict["time_embed.linear.bias"] = checkpoint.pop("t_block.1.bias") + + # Caption Projection. 
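+    # y_embedder.y_proj.fc1/fc2 map to caption_projection.linear_1/linear_2; the fixed y_embedding buffer is dropped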
+ checkpoint.pop("y_embedder.y_embedding") + converted_state_dict["caption_projection.linear_1.weight"] = checkpoint.pop("y_embedder.y_proj.fc1.weight") + converted_state_dict["caption_projection.linear_1.bias"] = checkpoint.pop("y_embedder.y_proj.fc1.bias") + converted_state_dict["caption_projection.linear_2.weight"] = checkpoint.pop("y_embedder.y_proj.fc2.weight") + converted_state_dict["caption_projection.linear_2.bias"] = checkpoint.pop("y_embedder.y_proj.fc2.bias") + converted_state_dict["caption_norm.weight"] = checkpoint.pop("attention_y_norm.weight") + + for i in range(num_layers): + converted_state_dict[f"transformer_blocks.{i}.scale_shift_table"] = checkpoint.pop( + f"blocks.{i}.scale_shift_table" + ) + + # Self-Attention + sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"blocks.{i}.attn.qkv.weight"), 3, dim=0) + converted_state_dict[f"transformer_blocks.{i}.attn1.to_q.weight"] = torch.cat([sample_q]) + converted_state_dict[f"transformer_blocks.{i}.attn1.to_k.weight"] = torch.cat([sample_k]) + converted_state_dict[f"transformer_blocks.{i}.attn1.to_v.weight"] = torch.cat([sample_v]) + + # Output Projections + converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.weight"] = checkpoint.pop( + f"blocks.{i}.attn.proj.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.bias"] = checkpoint.pop( + f"blocks.{i}.attn.proj.bias" + ) + + # Cross-Attention + converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = checkpoint.pop( + f"blocks.{i}.cross_attn.q_linear.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = checkpoint.pop( + f"blocks.{i}.cross_attn.q_linear.bias" + ) + + linear_sample_k, linear_sample_v = torch.chunk( + checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.weight"), 2, dim=0 + ) + linear_sample_k_bias, linear_sample_v_bias = torch.chunk( + checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.bias"), 2, dim=0 + ) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = linear_sample_k + converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = linear_sample_v + converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = linear_sample_k_bias + converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = linear_sample_v_bias + + # Output Projections + converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = checkpoint.pop( + f"blocks.{i}.cross_attn.proj.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = checkpoint.pop( + f"blocks.{i}.cross_attn.proj.bias" + ) + + # MLP + converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.weight"] = checkpoint.pop( + f"blocks.{i}.mlp.inverted_conv.conv.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.bias"] = checkpoint.pop( + f"blocks.{i}.mlp.inverted_conv.conv.bias" + ) + converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.weight"] = checkpoint.pop( + f"blocks.{i}.mlp.depth_conv.conv.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.bias"] = checkpoint.pop( + f"blocks.{i}.mlp.depth_conv.conv.bias" + ) + converted_state_dict[f"transformer_blocks.{i}.ff.conv_point.weight"] = checkpoint.pop( + f"blocks.{i}.mlp.point_conv.conv.weight" + ) + + # Final layer + converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight") + converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias") + converted_state_dict["scale_shift_table"] = checkpoint.pop("final_layer.scale_shift_table") + 
+ return converted_state_dict + + +def convert_wan_transformer_to_diffusers(checkpoint, **kwargs): + converted_state_dict = {} + + keys = list(checkpoint.keys()) + for k in keys: + if "model.diffusion_model." in k: + checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) + + TRANSFORMER_KEYS_RENAME_DICT = { + "time_embedding.0": "condition_embedder.time_embedder.linear_1", + "time_embedding.2": "condition_embedder.time_embedder.linear_2", + "text_embedding.0": "condition_embedder.text_embedder.linear_1", + "text_embedding.2": "condition_embedder.text_embedder.linear_2", + "time_projection.1": "condition_embedder.time_proj", + "cross_attn": "attn2", + "self_attn": "attn1", + ".o.": ".to_out.0.", + ".q.": ".to_q.", + ".k.": ".to_k.", + ".v.": ".to_v.", + ".k_img.": ".add_k_proj.", + ".v_img.": ".add_v_proj.", + ".norm_k_img.": ".norm_added_k.", + "head.modulation": "scale_shift_table", + "head.head": "proj_out", + "modulation": "scale_shift_table", + "ffn.0": "ffn.net.0.proj", + "ffn.2": "ffn.net.2", + # Hack to swap the layer names + # The original model calls the norms in following order: norm1, norm3, norm2 + # We convert it to: norm1, norm2, norm3 + "norm2": "norm__placeholder", + "norm3": "norm2", + "norm__placeholder": "norm3", + # For the I2V model + "img_emb.proj.0": "condition_embedder.image_embedder.norm1", + "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj", + "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2", + "img_emb.proj.4": "condition_embedder.image_embedder.norm2", + } + + for key in list(checkpoint.keys()): + new_key = key[:] + for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + + converted_state_dict[new_key] = checkpoint.pop(key) + + return converted_state_dict + + +def convert_wan_vae_to_diffusers(checkpoint, **kwargs): + converted_state_dict = {} + + # Create mappings for specific components + middle_key_mapping = { + # Encoder middle block + "encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma", + "encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias", + "encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight", + "encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma", + "encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias", + "encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight", + "encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma", + "encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias", + "encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight", + "encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma", + "encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias", + "encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight", + # Decoder middle block + "decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma", + "decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias", + "decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight", + "decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma", + "decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias", + "decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight", + "decoder.middle.2.residual.0.gamma": 
"decoder.mid_block.resnets.1.norm1.gamma", + "decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias", + "decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight", + "decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma", + "decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias", + "decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight", + } + + # Create a mapping for attention blocks + attention_mapping = { + # Encoder middle attention + "encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma", + "encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight", + "encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias", + "encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight", + "encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias", + # Decoder middle attention + "decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma", + "decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight", + "decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias", + "decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight", + "decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias", + } + + # Create a mapping for the head components + head_mapping = { + # Encoder head + "encoder.head.0.gamma": "encoder.norm_out.gamma", + "encoder.head.2.bias": "encoder.conv_out.bias", + "encoder.head.2.weight": "encoder.conv_out.weight", + # Decoder head + "decoder.head.0.gamma": "decoder.norm_out.gamma", + "decoder.head.2.bias": "decoder.conv_out.bias", + "decoder.head.2.weight": "decoder.conv_out.weight", + } + + # Create a mapping for the quant components + quant_mapping = { + "conv1.weight": "quant_conv.weight", + "conv1.bias": "quant_conv.bias", + "conv2.weight": "post_quant_conv.weight", + "conv2.bias": "post_quant_conv.bias", + } + + # Process each key in the state dict + for key, value in checkpoint.items(): + # Handle middle block keys using the mapping + if key in middle_key_mapping: + new_key = middle_key_mapping[key] + converted_state_dict[new_key] = value + # Handle attention blocks using the mapping + elif key in attention_mapping: + new_key = attention_mapping[key] + converted_state_dict[new_key] = value + # Handle head keys using the mapping + elif key in head_mapping: + new_key = head_mapping[key] + converted_state_dict[new_key] = value + # Handle quant keys using the mapping + elif key in quant_mapping: + new_key = quant_mapping[key] + converted_state_dict[new_key] = value + # Handle encoder conv1 + elif key == "encoder.conv1.weight": + converted_state_dict["encoder.conv_in.weight"] = value + elif key == "encoder.conv1.bias": + converted_state_dict["encoder.conv_in.bias"] = value + # Handle decoder conv1 + elif key == "decoder.conv1.weight": + converted_state_dict["decoder.conv_in.weight"] = value + elif key == "decoder.conv1.bias": + converted_state_dict["decoder.conv_in.bias"] = value + # Handle encoder downsamples + elif key.startswith("encoder.downsamples."): + # Convert to down_blocks + new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.") + + # Convert residual block naming but keep the original structure + if ".residual.0.gamma" in new_key: + new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma") + elif ".residual.2.bias" in new_key: + new_key = new_key.replace(".residual.2.bias", 
".conv1.bias") + elif ".residual.2.weight" in new_key: + new_key = new_key.replace(".residual.2.weight", ".conv1.weight") + elif ".residual.3.gamma" in new_key: + new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma") + elif ".residual.6.bias" in new_key: + new_key = new_key.replace(".residual.6.bias", ".conv2.bias") + elif ".residual.6.weight" in new_key: + new_key = new_key.replace(".residual.6.weight", ".conv2.weight") + elif ".shortcut.bias" in new_key: + new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias") + elif ".shortcut.weight" in new_key: + new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight") + + converted_state_dict[new_key] = value + + # Handle decoder upsamples + elif key.startswith("decoder.upsamples."): + # Convert to up_blocks + parts = key.split(".") + block_idx = int(parts[2]) + + # Group residual blocks + if "residual" in key: + if block_idx in [0, 1, 2]: + new_block_idx = 0 + resnet_idx = block_idx + elif block_idx in [4, 5, 6]: + new_block_idx = 1 + resnet_idx = block_idx - 4 + elif block_idx in [8, 9, 10]: + new_block_idx = 2 + resnet_idx = block_idx - 8 + elif block_idx in [12, 13, 14]: + new_block_idx = 3 + resnet_idx = block_idx - 12 + else: + # Keep as is for other blocks + converted_state_dict[key] = value + continue + + # Convert residual block naming + if ".residual.0.gamma" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm1.gamma" + elif ".residual.2.bias" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.bias" + elif ".residual.2.weight" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.weight" + elif ".residual.3.gamma" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm2.gamma" + elif ".residual.6.bias" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.bias" + elif ".residual.6.weight" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.weight" + else: + new_key = key + + converted_state_dict[new_key] = value + + # Handle shortcut connections + elif ".shortcut." in key: + if block_idx == 4: + new_key = key.replace(".shortcut.", ".resnets.0.conv_shortcut.") + new_key = new_key.replace("decoder.upsamples.4", "decoder.up_blocks.1") + else: + new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.") + new_key = new_key.replace(".shortcut.", ".conv_shortcut.") + + converted_state_dict[new_key] = value + + # Handle upsamplers + elif ".resample." in key or ".time_conv." in key: + if block_idx == 3: + new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.0.upsamplers.0") + elif block_idx == 7: + new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.1.upsamplers.0") + elif block_idx == 11: + new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.2.upsamplers.0") + else: + new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.") + + converted_state_dict[new_key] = value + else: + new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.") + converted_state_dict[new_key] = value + else: + # Keep other keys unchanged + converted_state_dict[key] = value + + return converted_state_dict From 6b8b225acad7c9fc18520de35b39e8340d367c3e Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 16 Apr 2025 17:26:30 +0530 Subject: [PATCH 20/26] single file utils. 
--- src/diffusers/loaders/single_file_utils.py | 1761 ++------------------ 1 file changed, 177 insertions(+), 1584 deletions(-) diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 4ae1850fa769..13bb67bfaad0 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -13,19 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from contextlib import nullcontext - -import torch from ..utils import deprecate -from .single_file.single_file_utils import ( - CHECKPOINT_KEY_NAMES, - DIFFUSERS_TO_LDM_MAPPING, - LDM_CLIP_PREFIX_TO_REMOVE, - LDM_OPEN_CLIP_TEXT_PROJECTION_DIM, - SD_2_TEXT_ENCODER_KEYS_TO_IGNORE, - SingleFileComponentError, -) +from .single_file.single_file_utils import SingleFileComponentError class SingleFileComponentError(SingleFileComponentError): @@ -62,14 +52,14 @@ def load_single_file_checkpoint( return load_single_file_checkpoint( pretrained_model_link_or_path, - force_download, - proxies, - token, - cache_dir, - local_files_only, - revision, - disable_mmap, - user_agent, + force_download, + proxies, + token, + cache_dir, + local_files_only, + revision, + disable_mmap, + user_agent, ) @@ -110,7 +100,6 @@ def is_clip_sd3_model(checkpoint): def is_open_clip_model(checkpoint): - deprecation_message = "Importing `is_open_clip_model()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import is_open_clip_model` instead." deprecate("diffusers.loaders.single_file_utils.is_open_clip_model", "0.36", deprecation_message) @@ -125,6 +114,7 @@ def is_open_clip_sdxl_model(checkpoint): return is_open_clip_sdxl_model(checkpoint) + def is_open_clip_sd3_model(checkpoint): from .single_file.single_file_utils import is_open_clip_sd3_model @@ -161,7 +151,6 @@ def infer_diffusers_model_type(checkpoint): return infer_diffusers_model_type(checkpoint) - def fetch_diffusers_config(checkpoint): from .single_file.single_file_utils import fetch_diffusers_config @@ -189,7 +178,6 @@ def conv_attn_to_linear(checkpoint): return conv_attn_to_linear(checkpoint) - def create_unet_diffusers_config_from_ldm( original_config, checkpoint, image_size=None, upcast_attention=None, num_in_channels=None ): @@ -198,16 +186,21 @@ def create_unet_diffusers_config_from_ldm( deprecation_message = "Importing `create_unet_diffusers_config_from_ldm()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import create_unet_diffusers_config_from_ldm` instead." deprecate("diffusers.loaders.single_file_utils.create_unet_diffusers_config_from_ldm", "0.36", deprecation_message) - return create_unet_diffusers_config_from_ldm(original_config, checkpoint, image_size, upcast_attention, num_in_channels) + return create_unet_diffusers_config_from_ldm( + original_config, checkpoint, image_size, upcast_attention, num_in_channels + ) def create_controlnet_diffusers_config_from_ldm(original_config, checkpoint, image_size=None, **kwargs): from .single_file.single_file_utils import create_controlnet_diffusers_config_from_ldm deprecation_message = "Importing `create_controlnet_diffusers_config_from_ldm()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import create_controlnet_diffusers_config_from_ldm` instead." 
- deprecate("diffusers.loaders.single_file_utils.create_controlnet_diffusers_config_from_ldm", "0.36", deprecation_message) + deprecate( + "diffusers.loaders.single_file_utils.create_controlnet_diffusers_config_from_ldm", "0.36", deprecation_message + ) return create_controlnet_diffusers_config_from_ldm(original_config, checkpoint, image_size, **kwargs) + def create_vae_diffusers_config_from_ldm(original_config, checkpoint, image_size=None, scaling_factor=None): from .single_file.single_file_utils import create_vae_diffusers_config_from_ldm @@ -215,6 +208,7 @@ def create_vae_diffusers_config_from_ldm(original_config, checkpoint, image_size deprecate("diffusers.loaders.single_file_utils.create_vae_diffusers_config_from_ldm", "0.36", deprecation_message) return create_vae_diffusers_config_from_ldm(original_config, checkpoint, image_size, scaling_factor) + def update_unet_resnet_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint, mapping=None): from .single_file.single_file_utils import update_unet_resnet_ldm_to_diffusers @@ -228,7 +222,9 @@ def update_unet_attention_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint, from .single_file.single_file_utils import update_unet_attention_ldm_to_diffusers deprecation_message = "Importing `update_unet_attention_ldm_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import update_unet_attention_ldm_to_diffusers` instead." - deprecate("diffusers.loaders.single_file_utils.update_unet_attention_ldm_to_diffusers", "0.36", deprecation_message) + deprecate( + "diffusers.loaders.single_file_utils.update_unet_attention_ldm_to_diffusers", "0.36", deprecation_message + ) return update_unet_attention_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint, mapping) @@ -246,19 +242,28 @@ def update_vae_attentions_ldm_to_diffusers(keys, new_checkpoint, checkpoint, map from .single_file.single_file_utils import update_vae_attentions_ldm_to_diffusers deprecation_message = "Importing `update_vae_attentions_ldm_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import update_vae_attentions_ldm_to_diffusers` instead." - deprecate("diffusers.loaders.single_file_utils.update_vae_attentions_ldm_to_diffusers", "0.36", deprecation_message) + deprecate( + "diffusers.loaders.single_file_utils.update_vae_attentions_ldm_to_diffusers", "0.36", deprecation_message + ) return update_vae_attentions_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping) + def convert_stable_cascade_unet_single_file_to_diffusers(checkpoint, **kwargs): from .single_file.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers + deprecation_message = "Importing `convert_stable_cascade_unet_single_file_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers` instead." 
-    deprecate("diffusers.loaders.single_file_utils.convert_stable_cascade_unet_single_file_to_diffusers", "0.36", deprecation_message)
+    deprecate(
+        "diffusers.loaders.single_file_utils.convert_stable_cascade_unet_single_file_to_diffusers",
+        "0.36",
+        deprecation_message,
+    )
     return convert_stable_cascade_unet_single_file_to_diffusers(checkpoint, **kwargs)
 
 
 def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False, **kwargs):
     from .single_file.single_file_utils import convert_ldm_unet_checkpoint
+
     deprecation_message = "Importing `convert_ldm_unet_checkpoint()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_ldm_unet_checkpoint` instead."
     deprecate("diffusers.loaders.single_file_utils.convert_ldm_unet_checkpoint", "0.36", deprecation_message)
     return convert_ldm_unet_checkpoint(checkpoint, config, extract_ema, **kwargs)
 
@@ -270,35 +275,26 @@ def convert_controlnet_checkpoint(
     **kwargs,
 ):
     from .single_file.single_file_utils import convert_controlnet_checkpoint
+
     deprecation_message = "Importing `convert_controlnet_checkpoint()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_controlnet_checkpoint` instead."
     deprecate("diffusers.loaders.single_file_utils.convert_controlnet_checkpoint", "0.36", deprecation_message)
     return convert_controlnet_checkpoint(checkpoint, config, **kwargs)
 
-
 def convert_ldm_vae_checkpoint(checkpoint, config):
     from .single_file.single_file_utils import convert_ldm_vae_checkpoint
+
     deprecation_message = "Importing `convert_ldm_vae_checkpoint()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_ldm_vae_checkpoint` instead."
     deprecate("diffusers.loaders.single_file_utils.convert_ldm_vae_checkpoint", "0.36", deprecation_message)
     return convert_ldm_vae_checkpoint(checkpoint, config)
 
 
 def convert_ldm_clip_checkpoint(checkpoint, remove_prefix=None):
-    keys = list(checkpoint.keys())
-    text_model_dict = {}
-
-    remove_prefixes = []
-    remove_prefixes.extend(LDM_CLIP_PREFIX_TO_REMOVE)
-    if remove_prefix:
-        remove_prefixes.append(remove_prefix)
-
-    for key in keys:
-        for prefix in remove_prefixes:
-            if key.startswith(prefix):
-                diffusers_key = key.replace(prefix, "")
-                text_model_dict[diffusers_key] = checkpoint.get(key)
+    from .single_file.single_file_utils import convert_ldm_clip_checkpoint
 
-    return text_model_dict
+    deprecation_message = "Importing `convert_ldm_clip_checkpoint()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_ldm_clip_checkpoint` instead."
+ deprecate("diffusers.loaders.single_file_utils.convert_ldm_clip_checkpoint", "0.36", deprecation_message) + return convert_ldm_clip_checkpoint(checkpoint, remove_prefix) def convert_open_clip_checkpoint( @@ -306,65 +302,11 @@ def convert_open_clip_checkpoint( checkpoint, prefix="cond_stage_model.model.", ): - text_model_dict = {} - text_proj_key = prefix + "text_projection" - - if text_proj_key in checkpoint: - text_proj_dim = int(checkpoint[text_proj_key].shape[0]) - elif hasattr(text_model.config, "hidden_size"): - text_proj_dim = text_model.config.hidden_size - else: - text_proj_dim = LDM_OPEN_CLIP_TEXT_PROJECTION_DIM - - keys = list(checkpoint.keys()) - keys_to_ignore = SD_2_TEXT_ENCODER_KEYS_TO_IGNORE - - openclip_diffusers_ldm_map = DIFFUSERS_TO_LDM_MAPPING["openclip"]["layers"] - for diffusers_key, ldm_key in openclip_diffusers_ldm_map.items(): - ldm_key = prefix + ldm_key - if ldm_key not in checkpoint: - continue - if ldm_key in keys_to_ignore: - continue - if ldm_key.endswith("text_projection"): - text_model_dict[diffusers_key] = checkpoint[ldm_key].T.contiguous() - else: - text_model_dict[diffusers_key] = checkpoint[ldm_key] - - for key in keys: - if key in keys_to_ignore: - continue - - if not key.startswith(prefix + "transformer."): - continue - - diffusers_key = key.replace(prefix + "transformer.", "") - transformer_diffusers_to_ldm_map = DIFFUSERS_TO_LDM_MAPPING["openclip"]["transformer"] - for new_key, old_key in transformer_diffusers_to_ldm_map.items(): - diffusers_key = ( - diffusers_key.replace(old_key, new_key).replace(".in_proj_weight", "").replace(".in_proj_bias", "") - ) - - if key.endswith(".in_proj_weight"): - weight_value = checkpoint.get(key) - - text_model_dict[diffusers_key + ".q_proj.weight"] = weight_value[:text_proj_dim, :].clone().detach() - text_model_dict[diffusers_key + ".k_proj.weight"] = ( - weight_value[text_proj_dim : text_proj_dim * 2, :].clone().detach() - ) - text_model_dict[diffusers_key + ".v_proj.weight"] = weight_value[text_proj_dim * 2 :, :].clone().detach() - - elif key.endswith(".in_proj_bias"): - weight_value = checkpoint.get(key) - text_model_dict[diffusers_key + ".q_proj.bias"] = weight_value[:text_proj_dim].clone().detach() - text_model_dict[diffusers_key + ".k_proj.bias"] = ( - weight_value[text_proj_dim : text_proj_dim * 2].clone().detach() - ) - text_model_dict[diffusers_key + ".v_proj.bias"] = weight_value[text_proj_dim * 2 :].clone().detach() - else: - text_model_dict[diffusers_key] = checkpoint.get(key) - - return text_model_dict + from .single_file.single_file_utils import convert_open_clip_checkpoint + + deprecation_message = "Importing `convert_open_clip_checkpoint()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_open_clip_checkpoint` instead." 
+ deprecate("diffusers.loaders.single_file_utils.convert_open_clip_checkpoint", "0.36", deprecation_message) + return convert_open_clip_checkpoint(text_model, checkpoint, prefix) def create_diffusers_clip_model_from_ldm( @@ -376,342 +318,77 @@ def create_diffusers_clip_model_from_ldm( local_files_only=None, is_legacy_loading=False, ): - if config: - config = {"pretrained_model_name_or_path": config} - else: - config = fetch_diffusers_config(checkpoint) - - # For backwards compatibility - # Older versions of `from_single_file` expected CLIP configs to be placed in their original transformers model repo - # in the cache_dir, rather than in a subfolder of the Diffusers model - if is_legacy_loading: - logger.warning( - ( - "Detected legacy CLIP loading behavior. Please run `from_single_file` with `local_files_only=False once to update " - "the local cache directory with the necessary CLIP model config files. " - "Attempting to load CLIP model from legacy cache directory." - ) - ) - - if is_clip_model(checkpoint) or is_clip_sdxl_model(checkpoint): - clip_config = "openai/clip-vit-large-patch14" - config["pretrained_model_name_or_path"] = clip_config - subfolder = "" - - elif is_open_clip_model(checkpoint): - clip_config = "stabilityai/stable-diffusion-2" - config["pretrained_model_name_or_path"] = clip_config - subfolder = "text_encoder" - - else: - clip_config = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" - config["pretrained_model_name_or_path"] = clip_config - subfolder = "" - - model_config = cls.config_class.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only) - ctx = init_empty_weights if is_accelerate_available() else nullcontext - with ctx(): - model = cls(model_config) - - position_embedding_dim = model.text_model.embeddings.position_embedding.weight.shape[-1] - - if is_clip_model(checkpoint): - diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint) - - elif ( - is_clip_sdxl_model(checkpoint) - and checkpoint[CHECKPOINT_KEY_NAMES["clip_sdxl"]].shape[-1] == position_embedding_dim - ): - diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint) - - elif ( - is_clip_sd3_model(checkpoint) - and checkpoint[CHECKPOINT_KEY_NAMES["clip_sd3"]].shape[-1] == position_embedding_dim - ): - diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_l.transformer.") - diffusers_format_checkpoint["text_projection.weight"] = torch.eye(position_embedding_dim) - - elif is_open_clip_model(checkpoint): - prefix = "cond_stage_model.model." - diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix) - - elif ( - is_open_clip_sdxl_model(checkpoint) - and checkpoint[CHECKPOINT_KEY_NAMES["open_clip_sdxl"]].shape[-1] == position_embedding_dim - ): - prefix = "conditioner.embedders.1.model." - diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix) - - elif is_open_clip_sdxl_refiner_model(checkpoint): - prefix = "conditioner.embedders.0.model." 
- diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix) - - elif ( - is_open_clip_sd3_model(checkpoint) - and checkpoint[CHECKPOINT_KEY_NAMES["open_clip_sd3"]].shape[-1] == position_embedding_dim - ): - diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_g.transformer.") - - else: - raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.") - - if is_accelerate_available(): - load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) - else: - model.load_state_dict(diffusers_format_checkpoint, strict=False) - - if torch_dtype is not None: - model.to(torch_dtype) - - model.eval() - - return model + from .single_file.single_file_utils import create_diffusers_clip_model_from_ldm + + deprecation_message = "Importing `create_diffusers_clip_model_from_ldm()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import create_diffusers_clip_model_from_ldm` instead." + deprecate("diffusers.loaders.single_file_utils.create_diffusers_clip_model_from_ldm", "0.36", deprecation_message) + return create_diffusers_clip_model_from_ldm( + cls, checkpoint, subfolder, config, torch_dtype, local_files_only, is_legacy_loading + ) # in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale; # while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation def swap_scale_shift(weight, dim): - shift, scale = weight.chunk(2, dim=0) - new_weight = torch.cat([scale, shift], dim=0) - return new_weight + from .single_file.single_file_utils import swap_scale_shift + + deprecation_message = "Importing `swap_scale_shift()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import swap_scale_shift` instead." + deprecate("diffusers.loaders.single_file_utils.swap_scale_shift", "0.36", deprecation_message) + return swap_scale_shift(weight, dim) def swap_proj_gate(weight): - proj, gate = weight.chunk(2, dim=0) - new_weight = torch.cat([gate, proj], dim=0) - return new_weight + from .single_file.single_file_utils import swap_proj_gate + + deprecation_message = "Importing `swap_proj_gate()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import swap_proj_gate` instead." + deprecate("diffusers.loaders.single_file_utils.swap_proj_gate", "0.36", deprecation_message) + return swap_proj_gate(weight) def get_attn2_layers(state_dict): - attn2_layers = [] - for key in state_dict.keys(): - if "attn2." in key: - # Extract the layer number from the key - layer_num = int(key.split(".")[1]) - attn2_layers.append(layer_num) + from .single_file.single_file_utils import get_attn2_layers - return tuple(sorted(set(attn2_layers))) + deprecation_message = "Importing `get_attn2_layers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import get_attn2_layers` instead." 
+ deprecate("diffusers.loaders.single_file_utils.get_attn2_layers", "0.36", deprecation_message) + return get_attn2_layers(state_dict) def get_caption_projection_dim(state_dict): - caption_projection_dim = state_dict["context_embedder.weight"].shape[0] - return caption_projection_dim + from .single_file.single_file_utils import get_caption_projection_dim + + deprecation_message = "Importing `get_caption_projection_dim()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import get_caption_projection_dim` instead." + deprecate("diffusers.loaders.single_file_utils.get_caption_projection_dim", "0.36", deprecation_message) + return get_caption_projection_dim(state_dict) def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): - converted_state_dict = {} - keys = list(checkpoint.keys()) - for k in keys: - if "model.diffusion_model." in k: - checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) - - num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "joint_blocks" in k))[-1] + 1 # noqa: C401 - dual_attention_layers = get_attn2_layers(checkpoint) - - caption_projection_dim = get_caption_projection_dim(checkpoint) - has_qk_norm = any("ln_q" in key for key in checkpoint.keys()) - - # Positional and patch embeddings. - converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("pos_embed") - converted_state_dict["pos_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight") - converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias") - - # Timestep embeddings. - converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop( - "t_embedder.mlp.0.weight" - ) - converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias") - converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop( - "t_embedder.mlp.2.weight" - ) - converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias") - - # Context projections. - converted_state_dict["context_embedder.weight"] = checkpoint.pop("context_embedder.weight") - converted_state_dict["context_embedder.bias"] = checkpoint.pop("context_embedder.bias") - - # Pooled context projection. - converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("y_embedder.mlp.0.weight") - converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("y_embedder.mlp.0.bias") - converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop("y_embedder.mlp.2.weight") - converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("y_embedder.mlp.2.bias") - - # Transformer blocks 🎸. 
- for i in range(num_layers): - # Q, K, V - sample_q, sample_k, sample_v = torch.chunk( - checkpoint.pop(f"joint_blocks.{i}.x_block.attn.qkv.weight"), 3, dim=0 - ) - context_q, context_k, context_v = torch.chunk( - checkpoint.pop(f"joint_blocks.{i}.context_block.attn.qkv.weight"), 3, dim=0 - ) - sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk( - checkpoint.pop(f"joint_blocks.{i}.x_block.attn.qkv.bias"), 3, dim=0 - ) - context_q_bias, context_k_bias, context_v_bias = torch.chunk( - checkpoint.pop(f"joint_blocks.{i}.context_block.attn.qkv.bias"), 3, dim=0 - ) - - converted_state_dict[f"transformer_blocks.{i}.attn.to_q.weight"] = torch.cat([sample_q]) - converted_state_dict[f"transformer_blocks.{i}.attn.to_q.bias"] = torch.cat([sample_q_bias]) - converted_state_dict[f"transformer_blocks.{i}.attn.to_k.weight"] = torch.cat([sample_k]) - converted_state_dict[f"transformer_blocks.{i}.attn.to_k.bias"] = torch.cat([sample_k_bias]) - converted_state_dict[f"transformer_blocks.{i}.attn.to_v.weight"] = torch.cat([sample_v]) - converted_state_dict[f"transformer_blocks.{i}.attn.to_v.bias"] = torch.cat([sample_v_bias]) - - converted_state_dict[f"transformer_blocks.{i}.attn.add_q_proj.weight"] = torch.cat([context_q]) - converted_state_dict[f"transformer_blocks.{i}.attn.add_q_proj.bias"] = torch.cat([context_q_bias]) - converted_state_dict[f"transformer_blocks.{i}.attn.add_k_proj.weight"] = torch.cat([context_k]) - converted_state_dict[f"transformer_blocks.{i}.attn.add_k_proj.bias"] = torch.cat([context_k_bias]) - converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.weight"] = torch.cat([context_v]) - converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.bias"] = torch.cat([context_v_bias]) - - # qk norm - if has_qk_norm: - converted_state_dict[f"transformer_blocks.{i}.attn.norm_q.weight"] = checkpoint.pop( - f"joint_blocks.{i}.x_block.attn.ln_q.weight" - ) - converted_state_dict[f"transformer_blocks.{i}.attn.norm_k.weight"] = checkpoint.pop( - f"joint_blocks.{i}.x_block.attn.ln_k.weight" - ) - converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_q.weight"] = checkpoint.pop( - f"joint_blocks.{i}.context_block.attn.ln_q.weight" - ) - converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_k.weight"] = checkpoint.pop( - f"joint_blocks.{i}.context_block.attn.ln_k.weight" - ) - - # output projections. 
- converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.weight"] = checkpoint.pop( - f"joint_blocks.{i}.x_block.attn.proj.weight" - ) - converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.bias"] = checkpoint.pop( - f"joint_blocks.{i}.x_block.attn.proj.bias" - ) - if not (i == num_layers - 1): - converted_state_dict[f"transformer_blocks.{i}.attn.to_add_out.weight"] = checkpoint.pop( - f"joint_blocks.{i}.context_block.attn.proj.weight" - ) - converted_state_dict[f"transformer_blocks.{i}.attn.to_add_out.bias"] = checkpoint.pop( - f"joint_blocks.{i}.context_block.attn.proj.bias" - ) - - if i in dual_attention_layers: - # Q, K, V - sample_q2, sample_k2, sample_v2 = torch.chunk( - checkpoint.pop(f"joint_blocks.{i}.x_block.attn2.qkv.weight"), 3, dim=0 - ) - sample_q2_bias, sample_k2_bias, sample_v2_bias = torch.chunk( - checkpoint.pop(f"joint_blocks.{i}.x_block.attn2.qkv.bias"), 3, dim=0 - ) - converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = torch.cat([sample_q2]) - converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = torch.cat([sample_q2_bias]) - converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = torch.cat([sample_k2]) - converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = torch.cat([sample_k2_bias]) - converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = torch.cat([sample_v2]) - converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = torch.cat([sample_v2_bias]) - - # qk norm - if has_qk_norm: - converted_state_dict[f"transformer_blocks.{i}.attn2.norm_q.weight"] = checkpoint.pop( - f"joint_blocks.{i}.x_block.attn2.ln_q.weight" - ) - converted_state_dict[f"transformer_blocks.{i}.attn2.norm_k.weight"] = checkpoint.pop( - f"joint_blocks.{i}.x_block.attn2.ln_k.weight" - ) - - # output projections. - converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = checkpoint.pop( - f"joint_blocks.{i}.x_block.attn2.proj.weight" - ) - converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = checkpoint.pop( - f"joint_blocks.{i}.x_block.attn2.proj.bias" - ) - - # norms. - converted_state_dict[f"transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop( - f"joint_blocks.{i}.x_block.adaLN_modulation.1.weight" - ) - converted_state_dict[f"transformer_blocks.{i}.norm1.linear.bias"] = checkpoint.pop( - f"joint_blocks.{i}.x_block.adaLN_modulation.1.bias" - ) - if not (i == num_layers - 1): - converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.weight"] = checkpoint.pop( - f"joint_blocks.{i}.context_block.adaLN_modulation.1.weight" - ) - converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.bias"] = checkpoint.pop( - f"joint_blocks.{i}.context_block.adaLN_modulation.1.bias" - ) - else: - converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.weight"] = swap_scale_shift( - checkpoint.pop(f"joint_blocks.{i}.context_block.adaLN_modulation.1.weight"), - dim=caption_projection_dim, - ) - converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.bias"] = swap_scale_shift( - checkpoint.pop(f"joint_blocks.{i}.context_block.adaLN_modulation.1.bias"), - dim=caption_projection_dim, - ) - - # ffs. 
- converted_state_dict[f"transformer_blocks.{i}.ff.net.0.proj.weight"] = checkpoint.pop( - f"joint_blocks.{i}.x_block.mlp.fc1.weight" - ) - converted_state_dict[f"transformer_blocks.{i}.ff.net.0.proj.bias"] = checkpoint.pop( - f"joint_blocks.{i}.x_block.mlp.fc1.bias" - ) - converted_state_dict[f"transformer_blocks.{i}.ff.net.2.weight"] = checkpoint.pop( - f"joint_blocks.{i}.x_block.mlp.fc2.weight" - ) - converted_state_dict[f"transformer_blocks.{i}.ff.net.2.bias"] = checkpoint.pop( - f"joint_blocks.{i}.x_block.mlp.fc2.bias" - ) - if not (i == num_layers - 1): - converted_state_dict[f"transformer_blocks.{i}.ff_context.net.0.proj.weight"] = checkpoint.pop( - f"joint_blocks.{i}.context_block.mlp.fc1.weight" - ) - converted_state_dict[f"transformer_blocks.{i}.ff_context.net.0.proj.bias"] = checkpoint.pop( - f"joint_blocks.{i}.context_block.mlp.fc1.bias" - ) - converted_state_dict[f"transformer_blocks.{i}.ff_context.net.2.weight"] = checkpoint.pop( - f"joint_blocks.{i}.context_block.mlp.fc2.weight" - ) - converted_state_dict[f"transformer_blocks.{i}.ff_context.net.2.bias"] = checkpoint.pop( - f"joint_blocks.{i}.context_block.mlp.fc2.bias" - ) - - # Final blocks. - converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight") - converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias") - converted_state_dict["norm_out.linear.weight"] = swap_scale_shift( - checkpoint.pop("final_layer.adaLN_modulation.1.weight"), dim=caption_projection_dim - ) - converted_state_dict["norm_out.linear.bias"] = swap_scale_shift( - checkpoint.pop("final_layer.adaLN_modulation.1.bias"), dim=caption_projection_dim - ) + from .single_file.single_file_utils import convert_sd3_transformer_checkpoint_to_diffusers - return converted_state_dict + deprecation_message = "Importing `convert_sd3_transformer_checkpoint_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_sd3_transformer_checkpoint_to_diffusers` instead." + deprecate( + "diffusers.loaders.single_file_utils.convert_sd3_transformer_checkpoint_to_diffusers", + "0.36", + deprecation_message, + ) + return convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs) def is_t5_in_single_file(checkpoint): - if "text_encoders.t5xxl.transformer.shared.weight" in checkpoint: - return True + from .single_file.single_file_utils import is_t5_in_single_file - return False + deprecation_message = "Importing `is_t5_in_single_file()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import is_t5_in_single_file` instead." + deprecate("diffusers.loaders.single_file_utils.is_t5_in_single_file", "0.36", deprecation_message) + return is_t5_in_single_file(checkpoint) def convert_sd3_t5_checkpoint_to_diffusers(checkpoint): - keys = list(checkpoint.keys()) - text_model_dict = {} - - remove_prefixes = ["text_encoders.t5xxl.transformer."] + from .single_file.single_file_utils import convert_sd3_t5_checkpoint_to_diffusers - for key in keys: - for prefix in remove_prefixes: - if key.startswith(prefix): - diffusers_key = key.replace(prefix, "") - text_model_dict[diffusers_key] = checkpoint.get(key) - - return text_model_dict + deprecation_message = "Importing `convert_sd3_t5_checkpoint_to_diffusers()` from diffusers.loaders.single_file has been deprecated. 
Please use `from diffusers.loaders.single_file.single_file_utils import convert_sd3_t5_checkpoint_to_diffusers` instead." + deprecate( + "diffusers.loaders.single_file_utils.convert_sd3_t5_checkpoint_to_diffusers", "0.36", deprecation_message + ) + return convert_sd3_t5_checkpoint_to_diffusers(checkpoint) def create_diffusers_t5_model_from_checkpoint( @@ -722,1218 +399,134 @@ def create_diffusers_t5_model_from_checkpoint( torch_dtype=None, local_files_only=None, ): - if config: - config = {"pretrained_model_name_or_path": config} - else: - config = fetch_diffusers_config(checkpoint) + from .single_file.single_file_utils import create_diffusers_t5_model_from_checkpoint - model_config = cls.config_class.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only) - ctx = init_empty_weights if is_accelerate_available() else nullcontext - with ctx(): - model = cls(model_config) - - diffusers_format_checkpoint = convert_sd3_t5_checkpoint_to_diffusers(checkpoint) - - if is_accelerate_available(): - load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) - else: - model.load_state_dict(diffusers_format_checkpoint) - - use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (torch_dtype == torch.float16) - if use_keep_in_fp32_modules: - keep_in_fp32_modules = model._keep_in_fp32_modules - else: - keep_in_fp32_modules = [] - - if keep_in_fp32_modules is not None: - for name, param in model.named_parameters(): - if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules): - # param = param.to(torch.float32) does not work here as only in the local scope. - param.data = param.data.to(torch.float32) - - return model + deprecation_message = "Importing `create_diffusers_t5_model_from_checkpoint()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import create_diffusers_t5_model_from_checkpoint` instead." + deprecate( + "diffusers.loaders.single_file_utils.create_diffusers_t5_model_from_checkpoint", "0.36", deprecation_message + ) + return create_diffusers_t5_model_from_checkpoint(cls, checkpoint, subfolder, config, torch_dtype, local_files_only) def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs): - converted_state_dict = {} - for k, v in checkpoint.items(): - if "pos_encoder" in k: - continue - - else: - converted_state_dict[ - k.replace(".norms.0", ".norm1") - .replace(".norms.1", ".norm2") - .replace(".ff_norm", ".norm3") - .replace(".attention_blocks.0", ".attn1") - .replace(".attention_blocks.1", ".attn2") - .replace(".temporal_transformer", "") - ] = v + from .single_file.single_file_utils import convert_animatediff_checkpoint_to_diffusers - return converted_state_dict + deprecation_message = "Importing `convert_animatediff_checkpoint_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_animatediff_checkpoint_to_diffusers` instead." + deprecate( + "diffusers.loaders.single_file_utils.convert_animatediff_checkpoint_to_diffusers", "0.36", deprecation_message + ) + return convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs) def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): - converted_state_dict = {} - keys = list(checkpoint.keys()) - - for k in keys: - if "model.diffusion_model." 
in k: - checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) - - num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1 # noqa: C401 - num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1 # noqa: C401 - mlp_ratio = 4.0 - inner_dim = 3072 - - # in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale; - # while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation - def swap_scale_shift(weight): - shift, scale = weight.chunk(2, dim=0) - new_weight = torch.cat([scale, shift], dim=0) - return new_weight - - ## time_text_embed.timestep_embedder <- time_in - converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop( - "time_in.in_layer.weight" - ) - converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("time_in.in_layer.bias") - converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop( - "time_in.out_layer.weight" - ) - converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("time_in.out_layer.bias") + from .single_file.single_file_utils import convert_flux_transformer_checkpoint_to_diffusers - ## time_text_embed.text_embedder <- vector_in - converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("vector_in.in_layer.weight") - converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("vector_in.in_layer.bias") - converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop( - "vector_in.out_layer.weight" - ) - converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("vector_in.out_layer.bias") - - # guidance - has_guidance = any("guidance" in k for k in checkpoint) - if has_guidance: - converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = checkpoint.pop( - "guidance_in.in_layer.weight" - ) - converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = checkpoint.pop( - "guidance_in.in_layer.bias" - ) - converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = checkpoint.pop( - "guidance_in.out_layer.weight" - ) - converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = checkpoint.pop( - "guidance_in.out_layer.bias" - ) - - # context_embedder - converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight") - converted_state_dict["context_embedder.bias"] = checkpoint.pop("txt_in.bias") - - # x_embedder - converted_state_dict["x_embedder.weight"] = checkpoint.pop("img_in.weight") - converted_state_dict["x_embedder.bias"] = checkpoint.pop("img_in.bias") - - # double transformer blocks - for i in range(num_layers): - block_prefix = f"transformer_blocks.{i}." - # norms. 
- ## norm1 - converted_state_dict[f"{block_prefix}norm1.linear.weight"] = checkpoint.pop( - f"double_blocks.{i}.img_mod.lin.weight" - ) - converted_state_dict[f"{block_prefix}norm1.linear.bias"] = checkpoint.pop( - f"double_blocks.{i}.img_mod.lin.bias" - ) - ## norm1_context - converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = checkpoint.pop( - f"double_blocks.{i}.txt_mod.lin.weight" - ) - converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = checkpoint.pop( - f"double_blocks.{i}.txt_mod.lin.bias" - ) - # Q, K, V - sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0) - context_q, context_k, context_v = torch.chunk( - checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0 - ) - sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk( - checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0 - ) - context_q_bias, context_k_bias, context_v_bias = torch.chunk( - checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0 - ) - converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q]) - converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias]) - converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k]) - converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias]) - converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v]) - converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias]) - converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q]) - converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias]) - converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k]) - converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias]) - converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v]) - converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias]) - # qk_norm - converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop( - f"double_blocks.{i}.img_attn.norm.query_norm.scale" - ) - converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop( - f"double_blocks.{i}.img_attn.norm.key_norm.scale" - ) - converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = checkpoint.pop( - f"double_blocks.{i}.txt_attn.norm.query_norm.scale" - ) - converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = checkpoint.pop( - f"double_blocks.{i}.txt_attn.norm.key_norm.scale" - ) - # ff img_mlp - converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = checkpoint.pop( - f"double_blocks.{i}.img_mlp.0.weight" - ) - converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.0.bias") - converted_state_dict[f"{block_prefix}ff.net.2.weight"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.weight") - converted_state_dict[f"{block_prefix}ff.net.2.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.bias") - converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = checkpoint.pop( - f"double_blocks.{i}.txt_mlp.0.weight" - ) - converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = checkpoint.pop( - f"double_blocks.{i}.txt_mlp.0.bias" - ) - converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = checkpoint.pop( - f"double_blocks.{i}.txt_mlp.2.weight" - ) - 
converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = checkpoint.pop( - f"double_blocks.{i}.txt_mlp.2.bias" - ) - # output projections. - converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = checkpoint.pop( - f"double_blocks.{i}.img_attn.proj.weight" - ) - converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = checkpoint.pop( - f"double_blocks.{i}.img_attn.proj.bias" - ) - converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = checkpoint.pop( - f"double_blocks.{i}.txt_attn.proj.weight" - ) - converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = checkpoint.pop( - f"double_blocks.{i}.txt_attn.proj.bias" - ) - - # single transfomer blocks - for i in range(num_single_layers): - block_prefix = f"single_transformer_blocks.{i}." - # norm.linear <- single_blocks.0.modulation.lin - converted_state_dict[f"{block_prefix}norm.linear.weight"] = checkpoint.pop( - f"single_blocks.{i}.modulation.lin.weight" - ) - converted_state_dict[f"{block_prefix}norm.linear.bias"] = checkpoint.pop( - f"single_blocks.{i}.modulation.lin.bias" - ) - # Q, K, V, mlp - mlp_hidden_dim = int(inner_dim * mlp_ratio) - split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim) - q, k, v, mlp = torch.split(checkpoint.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0) - q_bias, k_bias, v_bias, mlp_bias = torch.split( - checkpoint.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0 - ) - converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q]) - converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias]) - converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k]) - converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias]) - converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v]) - converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias]) - converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp]) - converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias]) - # qk norm - converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop( - f"single_blocks.{i}.norm.query_norm.scale" - ) - converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop( - f"single_blocks.{i}.norm.key_norm.scale" - ) - # output projections. - converted_state_dict[f"{block_prefix}proj_out.weight"] = checkpoint.pop(f"single_blocks.{i}.linear2.weight") - converted_state_dict[f"{block_prefix}proj_out.bias"] = checkpoint.pop(f"single_blocks.{i}.linear2.bias") - - converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight") - converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias") - converted_state_dict["norm_out.linear.weight"] = swap_scale_shift( - checkpoint.pop("final_layer.adaLN_modulation.1.weight") + deprecation_message = "Importing `convert_flux_transformer_checkpoint_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_flux_transformer_checkpoint_to_diffusers` instead." 
+ deprecate( + "diffusers.loaders.single_file_utils.convert_flux_transformer_checkpoint_to_diffusers", + "0.36", + deprecation_message, ) - converted_state_dict["norm_out.linear.bias"] = swap_scale_shift( - checkpoint.pop("final_layer.adaLN_modulation.1.bias") - ) - - return converted_state_dict + return convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs) def convert_ltx_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): - converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "vae" not in key} - - TRANSFORMER_KEYS_RENAME_DICT = { - "model.diffusion_model.": "", - "patchify_proj": "proj_in", - "adaln_single": "time_embed", - "q_norm": "norm_q", - "k_norm": "norm_k", - } - - TRANSFORMER_SPECIAL_KEYS_REMAP = {} - - for key in list(converted_state_dict.keys()): - new_key = key - for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): - new_key = new_key.replace(replace_key, rename_key) - converted_state_dict[new_key] = converted_state_dict.pop(key) - - for key in list(converted_state_dict.keys()): - for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items(): - if special_key not in key: - continue - handler_fn_inplace(key, converted_state_dict) + from .single_file.single_file_utils import convert_ltx_transformer_checkpoint_to_diffusers - return converted_state_dict + deprecation_message = "Importing `convert_ltx_transformer_checkpoint_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_ltx_transformer_checkpoint_to_diffusers` instead." + deprecate( + "diffusers.loaders.single_file_utils.convert_ltx_transformer_checkpoint_to_diffusers", + "0.36", + deprecation_message, + ) + return convert_ltx_transformer_checkpoint_to_diffusers(checkpoint, **kwargs) def convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs): - converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "vae." 
in key} - - def remove_keys_(key: str, state_dict): - state_dict.pop(key) - - VAE_KEYS_RENAME_DICT = { - # common - "vae.": "", - # decoder - "up_blocks.0": "mid_block", - "up_blocks.1": "up_blocks.0", - "up_blocks.2": "up_blocks.1.upsamplers.0", - "up_blocks.3": "up_blocks.1", - "up_blocks.4": "up_blocks.2.conv_in", - "up_blocks.5": "up_blocks.2.upsamplers.0", - "up_blocks.6": "up_blocks.2", - "up_blocks.7": "up_blocks.3.conv_in", - "up_blocks.8": "up_blocks.3.upsamplers.0", - "up_blocks.9": "up_blocks.3", - # encoder - "down_blocks.0": "down_blocks.0", - "down_blocks.1": "down_blocks.0.downsamplers.0", - "down_blocks.2": "down_blocks.0.conv_out", - "down_blocks.3": "down_blocks.1", - "down_blocks.4": "down_blocks.1.downsamplers.0", - "down_blocks.5": "down_blocks.1.conv_out", - "down_blocks.6": "down_blocks.2", - "down_blocks.7": "down_blocks.2.downsamplers.0", - "down_blocks.8": "down_blocks.3", - "down_blocks.9": "mid_block", - # common - "conv_shortcut": "conv_shortcut.conv", - "res_blocks": "resnets", - "norm3.norm": "norm3", - "per_channel_statistics.mean-of-means": "latents_mean", - "per_channel_statistics.std-of-means": "latents_std", - } - - VAE_091_RENAME_DICT = { - # decoder - "up_blocks.0": "mid_block", - "up_blocks.1": "up_blocks.0.upsamplers.0", - "up_blocks.2": "up_blocks.0", - "up_blocks.3": "up_blocks.1.upsamplers.0", - "up_blocks.4": "up_blocks.1", - "up_blocks.5": "up_blocks.2.upsamplers.0", - "up_blocks.6": "up_blocks.2", - "up_blocks.7": "up_blocks.3.upsamplers.0", - "up_blocks.8": "up_blocks.3", - # common - "last_time_embedder": "time_embedder", - "last_scale_shift_table": "scale_shift_table", - } - - VAE_095_RENAME_DICT = { - # decoder - "up_blocks.0": "mid_block", - "up_blocks.1": "up_blocks.0.upsamplers.0", - "up_blocks.2": "up_blocks.0", - "up_blocks.3": "up_blocks.1.upsamplers.0", - "up_blocks.4": "up_blocks.1", - "up_blocks.5": "up_blocks.2.upsamplers.0", - "up_blocks.6": "up_blocks.2", - "up_blocks.7": "up_blocks.3.upsamplers.0", - "up_blocks.8": "up_blocks.3", - # encoder - "down_blocks.0": "down_blocks.0", - "down_blocks.1": "down_blocks.0.downsamplers.0", - "down_blocks.2": "down_blocks.1", - "down_blocks.3": "down_blocks.1.downsamplers.0", - "down_blocks.4": "down_blocks.2", - "down_blocks.5": "down_blocks.2.downsamplers.0", - "down_blocks.6": "down_blocks.3", - "down_blocks.7": "down_blocks.3.downsamplers.0", - "down_blocks.8": "mid_block", - # common - "last_time_embedder": "time_embedder", - "last_scale_shift_table": "scale_shift_table", - } - - VAE_SPECIAL_KEYS_REMAP = { - "per_channel_statistics.channel": remove_keys_, - "per_channel_statistics.mean-of-means": remove_keys_, - "per_channel_statistics.mean-of-stds": remove_keys_, - } - - if converted_state_dict["vae.encoder.conv_out.conv.weight"].shape[1] == 2048: - VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT) - elif "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict: - VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT) - - for key in list(converted_state_dict.keys()): - new_key = key - for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items(): - new_key = new_key.replace(replace_key, rename_key) - converted_state_dict[new_key] = converted_state_dict.pop(key) - - for key in list(converted_state_dict.keys()): - for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items(): - if special_key not in key: - continue - handler_fn_inplace(key, converted_state_dict) - - return converted_state_dict + from .single_file.single_file_utils import 
convert_ltx_vae_checkpoint_to_diffusers + + deprecation_message = "Importing `convert_ltx_vae_checkpoint_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_ltx_vae_checkpoint_to_diffusers` instead." + deprecate( + "diffusers.loaders.single_file_utils.convert_ltx_vae_checkpoint_to_diffusers", "0.36", deprecation_message + ) + return convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs) def convert_autoencoder_dc_checkpoint_to_diffusers(checkpoint, **kwargs): - converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())} - - def remap_qkv_(key: str, state_dict): - qkv = state_dict.pop(key) - q, k, v = torch.chunk(qkv, 3, dim=0) - parent_module, _, _ = key.rpartition(".qkv.conv.weight") - state_dict[f"{parent_module}.to_q.weight"] = q.squeeze() - state_dict[f"{parent_module}.to_k.weight"] = k.squeeze() - state_dict[f"{parent_module}.to_v.weight"] = v.squeeze() - - def remap_proj_conv_(key: str, state_dict): - parent_module, _, _ = key.rpartition(".proj.conv.weight") - state_dict[f"{parent_module}.to_out.weight"] = state_dict.pop(key).squeeze() - - AE_KEYS_RENAME_DICT = { - # common - "main.": "", - "op_list.": "", - "context_module": "attn", - "local_module": "conv_out", - # NOTE: The below two lines work because scales in the available configs only have a tuple length of 1 - # If there were more scales, there would be more layers, so a loop would be better to handle this - "aggreg.0.0": "to_qkv_multiscale.0.proj_in", - "aggreg.0.1": "to_qkv_multiscale.0.proj_out", - "depth_conv.conv": "conv_depth", - "inverted_conv.conv": "conv_inverted", - "point_conv.conv": "conv_point", - "point_conv.norm": "norm", - "conv.conv.": "conv.", - "conv1.conv": "conv1", - "conv2.conv": "conv2", - "conv2.norm": "norm", - "proj.norm": "norm_out", - # encoder - "encoder.project_in.conv": "encoder.conv_in", - "encoder.project_out.0.conv": "encoder.conv_out", - "encoder.stages": "encoder.down_blocks", - # decoder - "decoder.project_in.conv": "decoder.conv_in", - "decoder.project_out.0": "decoder.norm_out", - "decoder.project_out.2.conv": "decoder.conv_out", - "decoder.stages": "decoder.up_blocks", - } - - AE_F32C32_F64C128_F128C512_KEYS = { - "encoder.project_in.conv": "encoder.conv_in.conv", - "decoder.project_out.2.conv": "decoder.conv_out.conv", - } - - AE_SPECIAL_KEYS_REMAP = { - "qkv.conv.weight": remap_qkv_, - "proj.conv.weight": remap_proj_conv_, - } - if "encoder.project_in.conv.bias" not in converted_state_dict: - AE_KEYS_RENAME_DICT.update(AE_F32C32_F64C128_F128C512_KEYS) - - for key in list(converted_state_dict.keys()): - new_key = key[:] - for replace_key, rename_key in AE_KEYS_RENAME_DICT.items(): - new_key = new_key.replace(replace_key, rename_key) - converted_state_dict[new_key] = converted_state_dict.pop(key) - - for key in list(converted_state_dict.keys()): - for special_key, handler_fn_inplace in AE_SPECIAL_KEYS_REMAP.items(): - if special_key not in key: - continue - handler_fn_inplace(key, converted_state_dict) - - return converted_state_dict + from .single_file.single_file_utils import convert_autoencoder_dc_checkpoint_to_diffusers + + deprecation_message = "Importing `convert_autoencoder_dc_checkpoint_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_autoencoder_dc_checkpoint_to_diffusers` instead." 
+ deprecate( + "diffusers.loaders.single_file_utils.convert_autoencoder_dc_checkpoint_to_diffusers", + "0.36", + deprecation_message, + ) + return convert_autoencoder_dc_checkpoint_to_diffusers(checkpoint, **kwargs) def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): - converted_state_dict = {} - - # Comfy checkpoints add this prefix - keys = list(checkpoint.keys()) - for k in keys: - if "model.diffusion_model." in k: - checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) - - # Convert patch_embed - converted_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight") - converted_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias") - - # Convert time_embed - converted_state_dict["time_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight") - converted_state_dict["time_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias") - converted_state_dict["time_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight") - converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias") - converted_state_dict["time_embed.pooler.to_kv.weight"] = checkpoint.pop("t5_y_embedder.to_kv.weight") - converted_state_dict["time_embed.pooler.to_kv.bias"] = checkpoint.pop("t5_y_embedder.to_kv.bias") - converted_state_dict["time_embed.pooler.to_q.weight"] = checkpoint.pop("t5_y_embedder.to_q.weight") - converted_state_dict["time_embed.pooler.to_q.bias"] = checkpoint.pop("t5_y_embedder.to_q.bias") - converted_state_dict["time_embed.pooler.to_out.weight"] = checkpoint.pop("t5_y_embedder.to_out.weight") - converted_state_dict["time_embed.pooler.to_out.bias"] = checkpoint.pop("t5_y_embedder.to_out.bias") - converted_state_dict["time_embed.caption_proj.weight"] = checkpoint.pop("t5_yproj.weight") - converted_state_dict["time_embed.caption_proj.bias"] = checkpoint.pop("t5_yproj.bias") - - # Convert transformer blocks - num_layers = 48 - for i in range(num_layers): - block_prefix = f"transformer_blocks.{i}." - old_prefix = f"blocks.{i}." 
- - # norm1 - converted_state_dict[block_prefix + "norm1.linear.weight"] = checkpoint.pop(old_prefix + "mod_x.weight") - converted_state_dict[block_prefix + "norm1.linear.bias"] = checkpoint.pop(old_prefix + "mod_x.bias") - if i < num_layers - 1: - converted_state_dict[block_prefix + "norm1_context.linear.weight"] = checkpoint.pop( - old_prefix + "mod_y.weight" - ) - converted_state_dict[block_prefix + "norm1_context.linear.bias"] = checkpoint.pop( - old_prefix + "mod_y.bias" - ) - else: - converted_state_dict[block_prefix + "norm1_context.linear_1.weight"] = checkpoint.pop( - old_prefix + "mod_y.weight" - ) - converted_state_dict[block_prefix + "norm1_context.linear_1.bias"] = checkpoint.pop( - old_prefix + "mod_y.bias" - ) - - # Visual attention - qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_x.weight") - q, k, v = qkv_weight.chunk(3, dim=0) - - converted_state_dict[block_prefix + "attn1.to_q.weight"] = q - converted_state_dict[block_prefix + "attn1.to_k.weight"] = k - converted_state_dict[block_prefix + "attn1.to_v.weight"] = v - converted_state_dict[block_prefix + "attn1.norm_q.weight"] = checkpoint.pop( - old_prefix + "attn.q_norm_x.weight" - ) - converted_state_dict[block_prefix + "attn1.norm_k.weight"] = checkpoint.pop( - old_prefix + "attn.k_norm_x.weight" - ) - converted_state_dict[block_prefix + "attn1.to_out.0.weight"] = checkpoint.pop( - old_prefix + "attn.proj_x.weight" - ) - converted_state_dict[block_prefix + "attn1.to_out.0.bias"] = checkpoint.pop(old_prefix + "attn.proj_x.bias") - - # Context attention - qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_y.weight") - q, k, v = qkv_weight.chunk(3, dim=0) - - converted_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q - converted_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k - converted_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v - converted_state_dict[block_prefix + "attn1.norm_added_q.weight"] = checkpoint.pop( - old_prefix + "attn.q_norm_y.weight" - ) - converted_state_dict[block_prefix + "attn1.norm_added_k.weight"] = checkpoint.pop( - old_prefix + "attn.k_norm_y.weight" - ) - if i < num_layers - 1: - converted_state_dict[block_prefix + "attn1.to_add_out.weight"] = checkpoint.pop( - old_prefix + "attn.proj_y.weight" - ) - converted_state_dict[block_prefix + "attn1.to_add_out.bias"] = checkpoint.pop( - old_prefix + "attn.proj_y.bias" - ) - - # MLP - converted_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate( - checkpoint.pop(old_prefix + "mlp_x.w1.weight") - ) - converted_state_dict[block_prefix + "ff.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_x.w2.weight") - if i < num_layers - 1: - converted_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate( - checkpoint.pop(old_prefix + "mlp_y.w1.weight") - ) - converted_state_dict[block_prefix + "ff_context.net.2.weight"] = checkpoint.pop( - old_prefix + "mlp_y.w2.weight" - ) - - # Output layers - converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(checkpoint.pop("final_layer.mod.weight"), dim=0) - converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(checkpoint.pop("final_layer.mod.bias"), dim=0) - converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight") - converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias") - - converted_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies") - - return converted_state_dict + from .single_file.single_file_utils import convert_mochi_transformer_checkpoint_to_diffusers + + 
deprecation_message = "Importing `convert_mochi_transformer_checkpoint_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_mochi_transformer_checkpoint_to_diffusers` instead." + deprecate( + "diffusers.loaders.single_file_utils.convert_mochi_transformer_checkpoint_to_diffusers", + "0.36", + deprecation_message, + ) + return convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs) def convert_hunyuan_video_transformer_to_diffusers(checkpoint, **kwargs): - def remap_norm_scale_shift_(key, state_dict): - weight = state_dict.pop(key) - shift, scale = weight.chunk(2, dim=0) - new_weight = torch.cat([scale, shift], dim=0) - state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight - - def remap_txt_in_(key, state_dict): - def rename_key(key): - new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks") - new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear") - new_key = new_key.replace("txt_in", "context_embedder") - new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1") - new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2") - new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder") - new_key = new_key.replace("mlp", "ff") - return new_key - - if "self_attn_qkv" in key: - weight = state_dict.pop(key) - to_q, to_k, to_v = weight.chunk(3, dim=0) - state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q - state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k - state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v - else: - state_dict[rename_key(key)] = state_dict.pop(key) - - def remap_img_attn_qkv_(key, state_dict): - weight = state_dict.pop(key) - to_q, to_k, to_v = weight.chunk(3, dim=0) - state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q - state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k - state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v - - def remap_txt_attn_qkv_(key, state_dict): - weight = state_dict.pop(key) - to_q, to_k, to_v = weight.chunk(3, dim=0) - state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q - state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k - state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v - - def remap_single_transformer_blocks_(key, state_dict): - hidden_size = 3072 - - if "linear1.weight" in key: - linear1_weight = state_dict.pop(key) - split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size) - q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0) - new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.weight") - state_dict[f"{new_key}.attn.to_q.weight"] = q - state_dict[f"{new_key}.attn.to_k.weight"] = k - state_dict[f"{new_key}.attn.to_v.weight"] = v - state_dict[f"{new_key}.proj_mlp.weight"] = mlp - - elif "linear1.bias" in key: - linear1_bias = state_dict.pop(key) - split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size) - q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0) - new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.bias") - state_dict[f"{new_key}.attn.to_q.bias"] = q_bias - state_dict[f"{new_key}.attn.to_k.bias"] = k_bias - state_dict[f"{new_key}.attn.to_v.bias"] = 
v_bias - state_dict[f"{new_key}.proj_mlp.bias"] = mlp_bias - - else: - new_key = key.replace("single_blocks", "single_transformer_blocks") - new_key = new_key.replace("linear2", "proj_out") - new_key = new_key.replace("q_norm", "attn.norm_q") - new_key = new_key.replace("k_norm", "attn.norm_k") - state_dict[new_key] = state_dict.pop(key) - - TRANSFORMER_KEYS_RENAME_DICT = { - "img_in": "x_embedder", - "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1", - "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2", - "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1", - "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2", - "vector_in.in_layer": "time_text_embed.text_embedder.linear_1", - "vector_in.out_layer": "time_text_embed.text_embedder.linear_2", - "double_blocks": "transformer_blocks", - "img_attn_q_norm": "attn.norm_q", - "img_attn_k_norm": "attn.norm_k", - "img_attn_proj": "attn.to_out.0", - "txt_attn_q_norm": "attn.norm_added_q", - "txt_attn_k_norm": "attn.norm_added_k", - "txt_attn_proj": "attn.to_add_out", - "img_mod.linear": "norm1.linear", - "img_norm1": "norm1.norm", - "img_norm2": "norm2", - "img_mlp": "ff", - "txt_mod.linear": "norm1_context.linear", - "txt_norm1": "norm1.norm", - "txt_norm2": "norm2_context", - "txt_mlp": "ff_context", - "self_attn_proj": "attn.to_out.0", - "modulation.linear": "norm.linear", - "pre_norm": "norm.norm", - "final_layer.norm_final": "norm_out.norm", - "final_layer.linear": "proj_out", - "fc1": "net.0.proj", - "fc2": "net.2", - "input_embedder": "proj_in", - } - - TRANSFORMER_SPECIAL_KEYS_REMAP = { - "txt_in": remap_txt_in_, - "img_attn_qkv": remap_img_attn_qkv_, - "txt_attn_qkv": remap_txt_attn_qkv_, - "single_blocks": remap_single_transformer_blocks_, - "final_layer.adaLN_modulation.1": remap_norm_scale_shift_, - } - - def update_state_dict_(state_dict, old_key, new_key): - state_dict[new_key] = state_dict.pop(old_key) - - for key in list(checkpoint.keys()): - new_key = key[:] - for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): - new_key = new_key.replace(replace_key, rename_key) - update_state_dict_(checkpoint, key, new_key) - - for key in list(checkpoint.keys()): - for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items(): - if special_key not in key: - continue - handler_fn_inplace(key, checkpoint) - - return checkpoint + from .single_file.single_file_utils import convert_hunyuan_video_transformer_to_diffusers + + deprecation_message = "Importing `convert_hunyuan_video_transformer_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_hunyuan_video_transformer_to_diffusers` instead." 
+ deprecate( + "diffusers.loaders.single_file_utils.convert_hunyuan_video_transformer_to_diffusers", + "0.36", + deprecation_message, + ) + return convert_hunyuan_video_transformer_to_diffusers(checkpoint, **kwargs) def convert_auraflow_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): - converted_state_dict = {} - state_dict_keys = list(checkpoint.keys()) - - # Handle register tokens and positional embeddings - converted_state_dict["register_tokens"] = checkpoint.pop("register_tokens", None) - - # Handle time step projection - converted_state_dict["time_step_proj.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight", None) - converted_state_dict["time_step_proj.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias", None) - converted_state_dict["time_step_proj.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight", None) - converted_state_dict["time_step_proj.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias", None) - - # Handle context embedder - converted_state_dict["context_embedder.weight"] = checkpoint.pop("cond_seq_linear.weight", None) - - # Calculate the number of layers - def calculate_layers(keys, key_prefix): - layers = set() - for k in keys: - if key_prefix in k: - layer_num = int(k.split(".")[1]) # get the layer number - layers.add(layer_num) - return len(layers) - - mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers") - single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers") - - # MMDiT blocks - for i in range(mmdit_layers): - # Feed-forward - path_mapping = {"mlpX": "ff", "mlpC": "ff_context"} - weight_mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"} - for orig_k, diffuser_k in path_mapping.items(): - for k, v in weight_mapping.items(): - converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.{v}.weight"] = checkpoint.pop( - f"double_layers.{i}.{orig_k}.{k}.weight", None - ) - - # Norms - path_mapping = {"modX": "norm1", "modC": "norm1_context"} - for orig_k, diffuser_k in path_mapping.items(): - converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.linear.weight"] = checkpoint.pop( - f"double_layers.{i}.{orig_k}.1.weight", None - ) - - # Attentions - x_attn_mapping = {"w2q": "to_q", "w2k": "to_k", "w2v": "to_v", "w2o": "to_out.0"} - context_attn_mapping = {"w1q": "add_q_proj", "w1k": "add_k_proj", "w1v": "add_v_proj", "w1o": "to_add_out"} - for attn_mapping in [x_attn_mapping, context_attn_mapping]: - for k, v in attn_mapping.items(): - converted_state_dict[f"joint_transformer_blocks.{i}.attn.{v}.weight"] = checkpoint.pop( - f"double_layers.{i}.attn.{k}.weight", None - ) - - # Single-DiT blocks - for i in range(single_dit_layers): - # Feed-forward - mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"} - for k, v in mapping.items(): - converted_state_dict[f"single_transformer_blocks.{i}.ff.{v}.weight"] = checkpoint.pop( - f"single_layers.{i}.mlp.{k}.weight", None - ) - - # Norms - converted_state_dict[f"single_transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop( - f"single_layers.{i}.modCX.1.weight", None - ) - - # Attentions - x_attn_mapping = {"w1q": "to_q", "w1k": "to_k", "w1v": "to_v", "w1o": "to_out.0"} - for k, v in x_attn_mapping.items(): - converted_state_dict[f"single_transformer_blocks.{i}.attn.{v}.weight"] = checkpoint.pop( - f"single_layers.{i}.attn.{k}.weight", None - ) - # Final blocks - converted_state_dict["proj_out.weight"] = checkpoint.pop("final_linear.weight", None) - - # Handle the 
final norm layer - norm_weight = checkpoint.pop("modF.1.weight", None) - if norm_weight is not None: - converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(norm_weight, dim=None) - else: - converted_state_dict["norm_out.linear.weight"] = None - - converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("positional_encoding") - converted_state_dict["pos_embed.proj.weight"] = checkpoint.pop("init_x_linear.weight") - converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("init_x_linear.bias") - - return converted_state_dict + from .single_file.single_file_utils import convert_auraflow_transformer_checkpoint_to_diffusers + + deprecation_message = "Importing `convert_auraflow_transformer_checkpoint_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_auraflow_transformer_checkpoint_to_diffusers` instead." + deprecate( + "diffusers.loaders.single_file_utils.convert_auraflow_transformer_checkpoint_to_diffusers", + "0.36", + deprecation_message, + ) + return convert_auraflow_transformer_checkpoint_to_diffusers(checkpoint, **kwargs) def convert_lumina2_to_diffusers(checkpoint, **kwargs): - converted_state_dict = {} - - # Original Lumina-Image-2 has an extra norm paramter that is unused - # We just remove it here - checkpoint.pop("norm_final.weight", None) - - # Comfy checkpoints add this prefix - keys = list(checkpoint.keys()) - for k in keys: - if "model.diffusion_model." in k: - checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) - - LUMINA_KEY_MAP = { - "cap_embedder": "time_caption_embed.caption_embedder", - "t_embedder.mlp.0": "time_caption_embed.timestep_embedder.linear_1", - "t_embedder.mlp.2": "time_caption_embed.timestep_embedder.linear_2", - "attention": "attn", - ".out.": ".to_out.0.", - "k_norm": "norm_k", - "q_norm": "norm_q", - "w1": "linear_1", - "w2": "linear_2", - "w3": "linear_3", - "adaLN_modulation.1": "norm1.linear", - } - ATTENTION_NORM_MAP = { - "attention_norm1": "norm1.norm", - "attention_norm2": "norm2", - } - CONTEXT_REFINER_MAP = { - "context_refiner.0.attention_norm1": "context_refiner.0.norm1", - "context_refiner.0.attention_norm2": "context_refiner.0.norm2", - "context_refiner.1.attention_norm1": "context_refiner.1.norm1", - "context_refiner.1.attention_norm2": "context_refiner.1.norm2", - } - FINAL_LAYER_MAP = { - "final_layer.adaLN_modulation.1": "norm_out.linear_1", - "final_layer.linear": "norm_out.linear_2", - } - - def convert_lumina_attn_to_diffusers(tensor, diffusers_key): - q_dim = 2304 - k_dim = v_dim = 768 - - to_q, to_k, to_v = torch.split(tensor, [q_dim, k_dim, v_dim], dim=0) - - return { - diffusers_key.replace("qkv", "to_q"): to_q, - diffusers_key.replace("qkv", "to_k"): to_k, - diffusers_key.replace("qkv", "to_v"): to_v, - } - - for key in keys: - diffusers_key = key - for k, v in CONTEXT_REFINER_MAP.items(): - diffusers_key = diffusers_key.replace(k, v) - for k, v in FINAL_LAYER_MAP.items(): - diffusers_key = diffusers_key.replace(k, v) - for k, v in ATTENTION_NORM_MAP.items(): - diffusers_key = diffusers_key.replace(k, v) - for k, v in LUMINA_KEY_MAP.items(): - diffusers_key = diffusers_key.replace(k, v) - - if "qkv" in diffusers_key: - converted_state_dict.update(convert_lumina_attn_to_diffusers(checkpoint.pop(key), diffusers_key)) - else: - converted_state_dict[diffusers_key] = checkpoint.pop(key) - - return converted_state_dict + from .single_file.single_file_utils import convert_lumina2_to_diffusers + + 
deprecation_message = "Importing `convert_lumina2_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_lumina2_to_diffusers` instead." + deprecate("diffusers.loaders.single_file_utils.convert_lumina2_to_diffusers", "0.36", deprecation_message) + return convert_lumina2_to_diffusers(checkpoint, **kwargs) def convert_sana_transformer_to_diffusers(checkpoint, **kwargs): - converted_state_dict = {} - keys = list(checkpoint.keys()) - for k in keys: - if "model.diffusion_model." in k: - checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) - - num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "blocks" in k))[-1] + 1 # noqa: C401 - - # Positional and patch embeddings. - checkpoint.pop("pos_embed") - converted_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight") - converted_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias") - - # Timestep embeddings. - converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = checkpoint.pop( - "t_embedder.mlp.0.weight" - ) - converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias") - converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = checkpoint.pop( - "t_embedder.mlp.2.weight" - ) - converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias") - converted_state_dict["time_embed.linear.weight"] = checkpoint.pop("t_block.1.weight") - converted_state_dict["time_embed.linear.bias"] = checkpoint.pop("t_block.1.bias") - - # Caption Projection. - checkpoint.pop("y_embedder.y_embedding") - converted_state_dict["caption_projection.linear_1.weight"] = checkpoint.pop("y_embedder.y_proj.fc1.weight") - converted_state_dict["caption_projection.linear_1.bias"] = checkpoint.pop("y_embedder.y_proj.fc1.bias") - converted_state_dict["caption_projection.linear_2.weight"] = checkpoint.pop("y_embedder.y_proj.fc2.weight") - converted_state_dict["caption_projection.linear_2.bias"] = checkpoint.pop("y_embedder.y_proj.fc2.bias") - converted_state_dict["caption_norm.weight"] = checkpoint.pop("attention_y_norm.weight") - - for i in range(num_layers): - converted_state_dict[f"transformer_blocks.{i}.scale_shift_table"] = checkpoint.pop( - f"blocks.{i}.scale_shift_table" - ) - - # Self-Attention - sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"blocks.{i}.attn.qkv.weight"), 3, dim=0) - converted_state_dict[f"transformer_blocks.{i}.attn1.to_q.weight"] = torch.cat([sample_q]) - converted_state_dict[f"transformer_blocks.{i}.attn1.to_k.weight"] = torch.cat([sample_k]) - converted_state_dict[f"transformer_blocks.{i}.attn1.to_v.weight"] = torch.cat([sample_v]) - - # Output Projections - converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.weight"] = checkpoint.pop( - f"blocks.{i}.attn.proj.weight" - ) - converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.bias"] = checkpoint.pop( - f"blocks.{i}.attn.proj.bias" - ) - - # Cross-Attention - converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = checkpoint.pop( - f"blocks.{i}.cross_attn.q_linear.weight" - ) - converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = checkpoint.pop( - f"blocks.{i}.cross_attn.q_linear.bias" - ) - - linear_sample_k, linear_sample_v = torch.chunk( - checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.weight"), 2, dim=0 - ) - linear_sample_k_bias, 
linear_sample_v_bias = torch.chunk( - checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.bias"), 2, dim=0 - ) - converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = linear_sample_k - converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = linear_sample_v - converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = linear_sample_k_bias - converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = linear_sample_v_bias - - # Output Projections - converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = checkpoint.pop( - f"blocks.{i}.cross_attn.proj.weight" - ) - converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = checkpoint.pop( - f"blocks.{i}.cross_attn.proj.bias" - ) - - # MLP - converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.weight"] = checkpoint.pop( - f"blocks.{i}.mlp.inverted_conv.conv.weight" - ) - converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.bias"] = checkpoint.pop( - f"blocks.{i}.mlp.inverted_conv.conv.bias" - ) - converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.weight"] = checkpoint.pop( - f"blocks.{i}.mlp.depth_conv.conv.weight" - ) - converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.bias"] = checkpoint.pop( - f"blocks.{i}.mlp.depth_conv.conv.bias" - ) - converted_state_dict[f"transformer_blocks.{i}.ff.conv_point.weight"] = checkpoint.pop( - f"blocks.{i}.mlp.point_conv.conv.weight" - ) - - # Final layer - converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight") - converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias") - converted_state_dict["scale_shift_table"] = checkpoint.pop("final_layer.scale_shift_table") - - return converted_state_dict + from .single_file.single_file_utils import convert_sana_transformer_to_diffusers + + deprecation_message = "Importing `convert_sana_transformer_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_sana_transformer_to_diffusers` instead." + deprecate("diffusers.loaders.single_file_utils.convert_sana_transformer_to_diffusers", "0.36", deprecation_message) + return convert_sana_transformer_to_diffusers(checkpoint, **kwargs) def convert_wan_transformer_to_diffusers(checkpoint, **kwargs): - converted_state_dict = {} - - keys = list(checkpoint.keys()) - for k in keys: - if "model.diffusion_model." 
in k: - checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) - - TRANSFORMER_KEYS_RENAME_DICT = { - "time_embedding.0": "condition_embedder.time_embedder.linear_1", - "time_embedding.2": "condition_embedder.time_embedder.linear_2", - "text_embedding.0": "condition_embedder.text_embedder.linear_1", - "text_embedding.2": "condition_embedder.text_embedder.linear_2", - "time_projection.1": "condition_embedder.time_proj", - "cross_attn": "attn2", - "self_attn": "attn1", - ".o.": ".to_out.0.", - ".q.": ".to_q.", - ".k.": ".to_k.", - ".v.": ".to_v.", - ".k_img.": ".add_k_proj.", - ".v_img.": ".add_v_proj.", - ".norm_k_img.": ".norm_added_k.", - "head.modulation": "scale_shift_table", - "head.head": "proj_out", - "modulation": "scale_shift_table", - "ffn.0": "ffn.net.0.proj", - "ffn.2": "ffn.net.2", - # Hack to swap the layer names - # The original model calls the norms in following order: norm1, norm3, norm2 - # We convert it to: norm1, norm2, norm3 - "norm2": "norm__placeholder", - "norm3": "norm2", - "norm__placeholder": "norm3", - # For the I2V model - "img_emb.proj.0": "condition_embedder.image_embedder.norm1", - "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj", - "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2", - "img_emb.proj.4": "condition_embedder.image_embedder.norm2", - } - - for key in list(checkpoint.keys()): - new_key = key[:] - for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): - new_key = new_key.replace(replace_key, rename_key) - - converted_state_dict[new_key] = checkpoint.pop(key) - - return converted_state_dict + from .single_file.single_file_utils import convert_wan_transformer_to_diffusers + + deprecation_message = "Importing `convert_wan_transformer_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_wan_transformer_to_diffusers` instead." 
+ deprecate("diffusers.loaders.single_file_utils.convert_wan_transformer_to_diffusers", "0.36", deprecation_message) + return convert_wan_transformer_to_diffusers(checkpoint, **kwargs) def convert_wan_vae_to_diffusers(checkpoint, **kwargs): - converted_state_dict = {} - - # Create mappings for specific components - middle_key_mapping = { - # Encoder middle block - "encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma", - "encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias", - "encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight", - "encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma", - "encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias", - "encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight", - "encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma", - "encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias", - "encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight", - "encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma", - "encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias", - "encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight", - # Decoder middle block - "decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma", - "decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias", - "decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight", - "decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma", - "decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias", - "decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight", - "decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma", - "decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias", - "decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight", - "decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma", - "decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias", - "decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight", - } - - # Create a mapping for attention blocks - attention_mapping = { - # Encoder middle attention - "encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma", - "encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight", - "encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias", - "encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight", - "encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias", - # Decoder middle attention - "decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma", - "decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight", - "decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias", - "decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight", - "decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias", - } - - # Create a mapping for the head components - head_mapping = { - # Encoder head - "encoder.head.0.gamma": "encoder.norm_out.gamma", - "encoder.head.2.bias": "encoder.conv_out.bias", - "encoder.head.2.weight": "encoder.conv_out.weight", - # Decoder head - 
"decoder.head.0.gamma": "decoder.norm_out.gamma", - "decoder.head.2.bias": "decoder.conv_out.bias", - "decoder.head.2.weight": "decoder.conv_out.weight", - } - - # Create a mapping for the quant components - quant_mapping = { - "conv1.weight": "quant_conv.weight", - "conv1.bias": "quant_conv.bias", - "conv2.weight": "post_quant_conv.weight", - "conv2.bias": "post_quant_conv.bias", - } - - # Process each key in the state dict - for key, value in checkpoint.items(): - # Handle middle block keys using the mapping - if key in middle_key_mapping: - new_key = middle_key_mapping[key] - converted_state_dict[new_key] = value - # Handle attention blocks using the mapping - elif key in attention_mapping: - new_key = attention_mapping[key] - converted_state_dict[new_key] = value - # Handle head keys using the mapping - elif key in head_mapping: - new_key = head_mapping[key] - converted_state_dict[new_key] = value - # Handle quant keys using the mapping - elif key in quant_mapping: - new_key = quant_mapping[key] - converted_state_dict[new_key] = value - # Handle encoder conv1 - elif key == "encoder.conv1.weight": - converted_state_dict["encoder.conv_in.weight"] = value - elif key == "encoder.conv1.bias": - converted_state_dict["encoder.conv_in.bias"] = value - # Handle decoder conv1 - elif key == "decoder.conv1.weight": - converted_state_dict["decoder.conv_in.weight"] = value - elif key == "decoder.conv1.bias": - converted_state_dict["decoder.conv_in.bias"] = value - # Handle encoder downsamples - elif key.startswith("encoder.downsamples."): - # Convert to down_blocks - new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.") - - # Convert residual block naming but keep the original structure - if ".residual.0.gamma" in new_key: - new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma") - elif ".residual.2.bias" in new_key: - new_key = new_key.replace(".residual.2.bias", ".conv1.bias") - elif ".residual.2.weight" in new_key: - new_key = new_key.replace(".residual.2.weight", ".conv1.weight") - elif ".residual.3.gamma" in new_key: - new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma") - elif ".residual.6.bias" in new_key: - new_key = new_key.replace(".residual.6.bias", ".conv2.bias") - elif ".residual.6.weight" in new_key: - new_key = new_key.replace(".residual.6.weight", ".conv2.weight") - elif ".shortcut.bias" in new_key: - new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias") - elif ".shortcut.weight" in new_key: - new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight") - - converted_state_dict[new_key] = value - - # Handle decoder upsamples - elif key.startswith("decoder.upsamples."): - # Convert to up_blocks - parts = key.split(".") - block_idx = int(parts[2]) - - # Group residual blocks - if "residual" in key: - if block_idx in [0, 1, 2]: - new_block_idx = 0 - resnet_idx = block_idx - elif block_idx in [4, 5, 6]: - new_block_idx = 1 - resnet_idx = block_idx - 4 - elif block_idx in [8, 9, 10]: - new_block_idx = 2 - resnet_idx = block_idx - 8 - elif block_idx in [12, 13, 14]: - new_block_idx = 3 - resnet_idx = block_idx - 12 - else: - # Keep as is for other blocks - converted_state_dict[key] = value - continue - - # Convert residual block naming - if ".residual.0.gamma" in key: - new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm1.gamma" - elif ".residual.2.bias" in key: - new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.bias" - elif ".residual.2.weight" in key: - new_key = 
f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.weight" - elif ".residual.3.gamma" in key: - new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm2.gamma" - elif ".residual.6.bias" in key: - new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.bias" - elif ".residual.6.weight" in key: - new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.weight" - else: - new_key = key - - converted_state_dict[new_key] = value - - # Handle shortcut connections - elif ".shortcut." in key: - if block_idx == 4: - new_key = key.replace(".shortcut.", ".resnets.0.conv_shortcut.") - new_key = new_key.replace("decoder.upsamples.4", "decoder.up_blocks.1") - else: - new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.") - new_key = new_key.replace(".shortcut.", ".conv_shortcut.") - - converted_state_dict[new_key] = value - - # Handle upsamplers - elif ".resample." in key or ".time_conv." in key: - if block_idx == 3: - new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.0.upsamplers.0") - elif block_idx == 7: - new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.1.upsamplers.0") - elif block_idx == 11: - new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.2.upsamplers.0") - else: - new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.") - - converted_state_dict[new_key] = value - else: - new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.") - converted_state_dict[new_key] = value - else: - # Keep other keys unchanged - converted_state_dict[key] = value - - return converted_state_dict + from .single_file.single_file_utils import convert_wan_vae_to_diffusers + + deprecation_message = "Importing `convert_wan_vae_to_diffusers()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file_utils import convert_wan_vae_to_diffusers` instead." + deprecate("diffusers.loaders.single_file_utils.convert_wan_vae_to_diffusers", "0.36", deprecation_message) + return convert_wan_vae_to_diffusers(checkpoint, **kwargs) From 11a23d11fe68155cbeed738079cfca7805ed2863 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 16 Apr 2025 17:29:26 +0530 Subject: [PATCH 21/26] updates --- src/diffusers/loaders/single_file_model.py | 27 ++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 src/diffusers/loaders/single_file_model.py diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py new file mode 100644 index 000000000000..4eeab1679e66 --- /dev/null +++ b/src/diffusers/loaders/single_file_model.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from ..utils import deprecate +from .single_file.single_file_model import ( + SINGLE_FILE_LOADABLE_CLASSES, # noqa: F401 + FromOriginalModelMixin, +) + + +class FromOriginalModelMixin(FromOriginalModelMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `FromOriginalModelMixin` from diffusers.loaders.single_file_model has been deprecated. Please use `from diffusers.loaders.single_file.single_file_model import FromOriginalModelMixin` instead." + deprecate("diffusers.loaders.single_file_model.FromOriginalModelMixin", "0.36", deprecation_message) + super().__init__(*args, **kwargs) From 1597ae6ac97470fefa054e23c8b4685ed5fd6ee8 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 16 Apr 2025 17:37:17 +0530 Subject: [PATCH 22/26] updates --- tests/single_file/single_file_testing_utils.py | 2 +- tests/single_file/test_model_vae_single_file.py | 4 +--- tests/single_file/test_model_wan_autoencoder_single_file.py | 4 +--- tests/single_file/test_model_wan_transformer3d_single_file.py | 4 +--- tests/single_file/test_sana_transformer.py | 4 +--- .../test_stable_diffusion_controlnet_img2img_single_file.py | 2 +- .../test_stable_diffusion_controlnet_inpaint_single_file.py | 2 +- .../test_stable_diffusion_controlnet_single_file.py | 2 +- tests/single_file/test_stable_diffusion_single_file.py | 2 +- .../test_stable_diffusion_xl_adapter_single_file.py | 2 +- .../test_stable_diffusion_xl_controlnet_single_file.py | 2 +- 11 files changed, 11 insertions(+), 19 deletions(-) diff --git a/tests/single_file/single_file_testing_utils.py b/tests/single_file/single_file_testing_utils.py index 4e1713c9ceb1..10c229ba5e64 100644 --- a/tests/single_file/single_file_testing_utils.py +++ b/tests/single_file/single_file_testing_utils.py @@ -5,7 +5,7 @@ import torch from huggingface_hub import hf_hub_download, snapshot_download -from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name +from diffusers.loaders.single_file.single_file_utils import _extract_repo_id_and_weights_name from diffusers.models.attention_processor import AttnProcessor from diffusers.utils.testing_utils import ( numpy_cosine_similarity_distance, diff --git a/tests/single_file/test_model_vae_single_file.py b/tests/single_file/test_model_vae_single_file.py index bba1726ae380..350a433cf68b 100644 --- a/tests/single_file/test_model_vae_single_file.py +++ b/tests/single_file/test_model_vae_single_file.py @@ -18,9 +18,7 @@ import torch -from diffusers import ( - AutoencoderKL, -) +from diffusers import AutoencoderKL from diffusers.utils.testing_utils import ( backend_empty_cache, enable_full_determinism, diff --git a/tests/single_file/test_model_wan_autoencoder_single_file.py b/tests/single_file/test_model_wan_autoencoder_single_file.py index 7f0e1c1a4b0b..83acbdc56504 100644 --- a/tests/single_file/test_model_wan_autoencoder_single_file.py +++ b/tests/single_file/test_model_wan_autoencoder_single_file.py @@ -16,9 +16,7 @@ import gc import unittest -from diffusers import ( - AutoencoderKLWan, -) +from diffusers import AutoencoderKLWan from diffusers.utils.testing_utils import ( backend_empty_cache, enable_full_determinism, diff --git a/tests/single_file/test_model_wan_transformer3d_single_file.py b/tests/single_file/test_model_wan_transformer3d_single_file.py index 36f0919cacb5..8a20135a8aab 100644 --- a/tests/single_file/test_model_wan_transformer3d_single_file.py +++ b/tests/single_file/test_model_wan_transformer3d_single_file.py @@ -18,9 +18,7 @@ import torch -from diffusers import ( - 
WanTransformer3DModel, -) +from diffusers import WanTransformer3DModel from diffusers.utils.testing_utils import ( backend_empty_cache, enable_full_determinism, diff --git a/tests/single_file/test_sana_transformer.py b/tests/single_file/test_sana_transformer.py index 802ca37abfc3..061a5698d33a 100644 --- a/tests/single_file/test_sana_transformer.py +++ b/tests/single_file/test_sana_transformer.py @@ -3,9 +3,7 @@ import torch -from diffusers import ( - SanaTransformer2DModel, -) +from diffusers import SanaTransformer2DModel from diffusers.utils.testing_utils import ( backend_empty_cache, enable_full_determinism, diff --git a/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py b/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py index 7589b48028c2..29972241151d 100644 --- a/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py +++ b/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py @@ -5,7 +5,7 @@ import torch from diffusers import ControlNetModel, StableDiffusionControlNetPipeline -from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name +from diffusers.loaders.single_file.single_file_utils import _extract_repo_id_and_weights_name from diffusers.utils import load_image from diffusers.utils.testing_utils import ( backend_empty_cache, diff --git a/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py b/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py index 1555831db6db..166365a9a064 100644 --- a/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py +++ b/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py @@ -5,7 +5,7 @@ import torch from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline -from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name +from diffusers.loaders.single_file.single_file_utils import _extract_repo_id_and_weights_name from diffusers.utils import load_image from diffusers.utils.testing_utils import ( backend_empty_cache, diff --git a/tests/single_file/test_stable_diffusion_controlnet_single_file.py b/tests/single_file/test_stable_diffusion_controlnet_single_file.py index 2c1e414e5e36..35ce5daac250 100644 --- a/tests/single_file/test_stable_diffusion_controlnet_single_file.py +++ b/tests/single_file/test_stable_diffusion_controlnet_single_file.py @@ -5,7 +5,7 @@ import torch from diffusers import ControlNetModel, StableDiffusionControlNetPipeline -from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name +from diffusers.loaders.single_file.single_file_utils import _extract_repo_id_and_weights_name from diffusers.utils import load_image from diffusers.utils.testing_utils import ( backend_empty_cache, diff --git a/tests/single_file/test_stable_diffusion_single_file.py b/tests/single_file/test_stable_diffusion_single_file.py index 78baeb94929c..338842e8d08e 100644 --- a/tests/single_file/test_stable_diffusion_single_file.py +++ b/tests/single_file/test_stable_diffusion_single_file.py @@ -5,7 +5,7 @@ import torch from diffusers import EulerDiscreteScheduler, StableDiffusionInstructPix2PixPipeline, StableDiffusionPipeline -from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name +from diffusers.loaders.single_file.single_file_utils import _extract_repo_id_and_weights_name from diffusers.utils import load_image from diffusers.utils.testing_utils import ( backend_empty_cache, diff --git 
a/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py b/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py index fb5f8725b86e..06a4f0bf5b3a 100644 --- a/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py +++ b/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py @@ -8,7 +8,7 @@ StableDiffusionXLAdapterPipeline, T2IAdapter, ) -from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name +from diffusers.loaders.single_file.single_file_utils import _extract_repo_id_and_weights_name from diffusers.utils import load_image from diffusers.utils.testing_utils import ( backend_empty_cache, diff --git a/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py b/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py index 6d8c4369e1e1..9ea2fcc29710 100644 --- a/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py +++ b/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py @@ -5,7 +5,7 @@ import torch from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline -from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name +from diffusers.loaders.single_file.single_file_utils import _extract_repo_id_and_weights_name from diffusers.utils import load_image from diffusers.utils.testing_utils import ( backend_empty_cache, From d6430c79a30897c66a13065a3fe2815dedbc02e9 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 16 Apr 2025 18:11:39 +0530 Subject: [PATCH 23/26] updates --- tests/lora/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 87a8fddfa583..fe132944a835 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -1017,7 +1017,7 @@ def test_multiple_wrong_adapter_name_raises_error(self): ) scale_with_wrong_components = {"foo": 0.0, "bar": 0.0, "tik": 0.0} - logger = logging.get_logger("diffusers.loaders.lora_base") + logger = logging.get_logger("diffusers.loaders.lora.lora_base") logger.setLevel(30) with CaptureLogger(logger) as cap_logger: pipe.set_adapters(adapter_name, adapter_weights=scale_with_wrong_components) @@ -1824,7 +1824,7 @@ def test_logs_info_when_no_lora_keys_found(self): elif lora_module == "text_encoder_2": prefix = "text_encoder_2" - logger = logging.get_logger("diffusers.loaders.lora_base") + logger = logging.get_logger("diffusers.loaders.lora.lora_base") logger.setLevel(logging.WARNING) with CaptureLogger(logger) as cap_logger: @@ -1925,7 +1925,7 @@ def test_lora_B_bias(self): _, _, inputs = self.get_dummy_inputs(with_generator=False) - logger = logging.get_logger("diffusers.loaders.lora_pipeline") + logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline") logger.setLevel(logging.INFO) original_output = pipe(**inputs, generator=torch.manual_seed(0))[0] From e4dd7c533389fc8dcbfc721e0e5e2eb459d9e8ba Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 16 Apr 2025 18:26:27 +0530 Subject: [PATCH 24/26] updates --- tests/lora/test_lora_layers_flux.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index 860aa6511689..7124fcb41200 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -352,7 +352,7 @@ def test_with_norm_in_state_dict(self): _, _, inputs = self.get_dummy_inputs(with_generator=False) - logger = logging.get_logger("diffusers.loaders.lora_pipeline") + logger = 
logging.get_logger("diffusers.loaders.lora.lora_pipeline") logger.setLevel(logging.INFO) original_output = pipe(**inputs, generator=torch.manual_seed(0))[0] @@ -403,7 +403,7 @@ def test_lora_parameter_expanded_shapes(self): _, _, inputs = self.get_dummy_inputs(with_generator=False) original_out = pipe(**inputs, generator=torch.manual_seed(0))[0] - logger = logging.get_logger("diffusers.loaders.lora_pipeline") + logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline") logger.setLevel(logging.DEBUG) # Change the transformer config to mimic a real use case. @@ -486,7 +486,7 @@ def test_normal_lora_with_expanded_lora_raises_error(self): pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - logger = logging.get_logger("diffusers.loaders.lora_pipeline") + logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline") logger.setLevel(logging.DEBUG) out_features, in_features = pipe.transformer.x_embedder.weight.shape @@ -541,7 +541,7 @@ def test_normal_lora_with_expanded_lora_raises_error(self): pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - logger = logging.get_logger("diffusers.loaders.lora_pipeline") + logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline") logger.setLevel(logging.DEBUG) out_features, in_features = pipe.transformer.x_embedder.weight.shape @@ -590,7 +590,7 @@ def test_fuse_expanded_lora_with_regular_lora(self): pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - logger = logging.get_logger("diffusers.loaders.lora_pipeline") + logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline") logger.setLevel(logging.DEBUG) out_features, in_features = pipe.transformer.x_embedder.weight.shape @@ -653,7 +653,7 @@ def test_load_regular_lora(self): "transformer.x_embedder.lora_B.weight": normal_lora_B.weight, } - logger = logging.get_logger("diffusers.loaders.lora_pipeline") + logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline") logger.setLevel(logging.INFO) with CaptureLogger(logger) as cap_logger: pipe.load_lora_weights(lora_state_dict, "adapter-1") @@ -668,7 +668,7 @@ def test_load_regular_lora(self): def test_lora_unload_with_parameter_expanded_shapes(self): components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) - logger = logging.get_logger("diffusers.loaders.lora_pipeline") + logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline") logger.setLevel(logging.DEBUG) # Change the transformer config to mimic a real use case. @@ -734,7 +734,7 @@ def test_lora_unload_with_parameter_expanded_shapes(self): def test_lora_unload_with_parameter_expanded_shapes_and_no_reset(self): components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) - logger = logging.get_logger("diffusers.loaders.lora_pipeline") + logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline") logger.setLevel(logging.DEBUG) # Change the transformer config to mimic a real use case. From 9bebdf225d92ef4abb92d7ef05c20831ec2d2e2d Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 29 Apr 2025 00:11:52 +0800 Subject: [PATCH 25/26] fix repo consistency. 
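
This patch only touches the `# Copied from ...` markers so they point at the new
`diffusers.loaders.lora.lora_pipeline` paths introduced by the folderization; the copied code
itself is unchanged. A minimal sketch of how such a dotted marker resolves to a concrete object
under the new layout (an illustrative helper, not the actual repo-consistency tool):

```py
import importlib


def resolve_copied_from(dotted_path: str):
    # Find the longest importable module prefix of "pkg.module.Class.method",
    # then walk the remaining attributes (class, then method).
    parts = dotted_path.split(".")
    for split in range(len(parts), 0, -1):
        try:
            obj = importlib.import_module(".".join(parts[:split]))
        except ModuleNotFoundError:
            continue
        for attr in parts[split:]:
            obj = getattr(obj, attr)
        return obj
    raise ModuleNotFoundError(dotted_path)


# e.g. resolve_copied_from(
#     "diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict"
# ) now resolves against loaders/lora/lora_pipeline.py rather than loaders/lora_pipeline.py.
```

Because the markers are resolved by dotted path, they have to track module moves exactly, which is
what the comment updates below do.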
--- src/diffusers/loaders/lora/lora_pipeline.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/diffusers/loaders/lora/lora_pipeline.py b/src/diffusers/loaders/lora/lora_pipeline.py index 75613bfcb4fc..32e5fa1e6b71 100644 --- a/src/diffusers/loaders/lora/lora_pipeline.py +++ b/src/diffusers/loaders/lora/lora_pipeline.py @@ -5739,7 +5739,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin): @classmethod @validate_hf_hub_args - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict def lora_state_dict( cls, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], @@ -5835,7 +5835,7 @@ def lora_state_dict( return state_dict - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], @@ -5894,7 +5894,7 @@ def load_lora_weights( ) @classmethod - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HiDreamImageTransformer2DModel + # Copied from diffusers.loaders.lora.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HiDreamImageTransformer2DModel def load_lora_into_transformer( cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False ): @@ -5934,7 +5934,7 @@ def load_lora_into_transformer( ) @classmethod - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights + # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights def save_lora_weights( cls, save_directory: Union[str, os.PathLike], @@ -5981,7 +5981,7 @@ def save_lora_weights( safe_serialization=safe_serialization, ) - # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora + # Copied from diffusers.loaders.lora.lora_pipeline.SanaLoraLoaderMixin.fuse_lora def fuse_lora( self, components: List[str] = ["transformer"], @@ -6029,7 +6029,7 @@ def fuse_lora( **kwargs, ) - # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora + # Copied from diffusers.loaders.lora.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): r""" Reverses the effect of From 0c3789544034b6579ec53bda011fbff60f6b70f4 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 29 Apr 2025 00:26:28 +0800 Subject: [PATCH 26/26] consistency --- src/diffusers/loaders/lora/lora_pipeline.py | 619 ++++---------------- 1 file changed, 125 insertions(+), 494 deletions(-) diff --git a/src/diffusers/loaders/lora/lora_pipeline.py b/src/diffusers/loaders/lora/lora_pipeline.py index 32e5fa1e6b71..27705041408d 100644 --- a/src/diffusers/loaders/lora/lora_pipeline.py +++ b/src/diffusers/loaders/lora/lora_pipeline.py @@ -127,7 +127,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin): def load_lora_weights( self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - adapter_name=None, + adapter_name: Optional[str] = None, hotswap: bool = False, **kwargs, ): @@ -154,7 +154,7 @@ def load_lora_weights( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random 
weights. - hotswap : (`bool`, *optional*) + hotswap (`bool`, *optional*): Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter in-place. This means that, instead of loading an additional adapter, this will take the existing adapter weights and replace them with the weights of the new adapter. This can be faster and more @@ -368,29 +368,8 @@ def load_lora_into_unet( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading only loading the pretrained LoRA weights and not initializing the random weights. - hotswap : (`bool`, *optional*) - Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter - in-place. This means that, instead of loading an additional adapter, this will take the existing - adapter weights and replace them with the weights of the new adapter. This can be faster and more - memory efficient. However, the main advantage of hotswapping is that when the model is compiled with - torch.compile, loading the new adapter does not require recompilation of the model. When using - hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. - - If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need - to call an additional method before loading the adapter: - - ```py - pipeline = ... # load diffusers pipeline - max_rank = ... # the highest rank among all LoRAs that you want to load - # call *before* compiling and loading the LoRA adapter - pipeline.enable_lora_hotswap(target_rank=max_rank) - pipeline.load_lora_weights(file_name) - # optionally compile the model now - ``` - - Note that hotswapping adapters of the text encoder is not yet supported. There are some further - limitations to this technique, which are documented here: - https://huggingface.co/docs/peft/main/en/package_reference/hotswap + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -451,29 +430,8 @@ def load_lora_into_text_encoder( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap : (`bool`, *optional*) - Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter - in-place. This means that, instead of loading an additional adapter, this will take the existing - adapter weights and replace them with the weights of the new adapter. This can be faster and more - memory efficient. However, the main advantage of hotswapping is that when the model is compiled with - torch.compile, loading the new adapter does not require recompilation of the model. When using - hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. - - If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need - to call an additional method before loading the adapter: - - ```py - pipeline = ... # load diffusers pipeline - max_rank = ... # the highest rank among all LoRAs that you want to load - # call *before* compiling and loading the LoRA adapter - pipeline.enable_lora_hotswap(target_rank=max_rank) - pipeline.load_lora_weights(file_name) - # optionally compile the model now - ``` - - Note that hotswapping adapters of the text encoder is not yet supported. 
There are some further - limitations to this technique, which are documented here: - https://huggingface.co/docs/peft/main/en/package_reference/hotswap + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. """ _load_lora_into_text_encoder( state_dict=state_dict, @@ -625,6 +583,7 @@ def load_lora_weights( self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name: Optional[str] = None, + hotswap: bool = False, **kwargs, ): """ @@ -651,6 +610,8 @@ def load_lora_weights( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ @@ -689,6 +650,7 @@ def load_lora_weights( adapter_name=adapter_name, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) self.load_lora_into_text_encoder( state_dict, @@ -699,6 +661,7 @@ def load_lora_weights( adapter_name=adapter_name, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) self.load_lora_into_text_encoder( state_dict, @@ -709,6 +672,7 @@ def load_lora_weights( adapter_name=adapter_name, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) @classmethod @@ -859,29 +823,8 @@ def load_lora_into_unet( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading only loading the pretrained LoRA weights and not initializing the random weights. - hotswap : (`bool`, *optional*) - Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter - in-place. This means that, instead of loading an additional adapter, this will take the existing - adapter weights and replace them with the weights of the new adapter. This can be faster and more - memory efficient. However, the main advantage of hotswapping is that when the model is compiled with - torch.compile, loading the new adapter does not require recompilation of the model. When using - hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. - - If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need - to call an additional method before loading the adapter: - - ```py - pipeline = ... # load diffusers pipeline - max_rank = ... # the highest rank among all LoRAs that you want to load - # call *before* compiling and loading the LoRA adapter - pipeline.enable_lora_hotswap(target_rank=max_rank) - pipeline.load_lora_weights(file_name) - # optionally compile the model now - ``` - - Note that hotswapping adapters of the text encoder is not yet supported. There are some further - limitations to this technique, which are documented here: - https://huggingface.co/docs/peft/main/en/package_reference/hotswap + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -943,29 +886,8 @@ def load_lora_into_text_encoder( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap : (`bool`, *optional*) - Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter - in-place. 
This means that, instead of loading an additional adapter, this will take the existing - adapter weights and replace them with the weights of the new adapter. This can be faster and more - memory efficient. However, the main advantage of hotswapping is that when the model is compiled with - torch.compile, loading the new adapter does not require recompilation of the model. When using - hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. - - If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need - to call an additional method before loading the adapter: - - ```py - pipeline = ... # load diffusers pipeline - max_rank = ... # the highest rank among all LoRAs that you want to load - # call *before* compiling and loading the LoRA adapter - pipeline.enable_lora_hotswap(target_rank=max_rank) - pipeline.load_lora_weights(file_name) - # optionally compile the model now - ``` - - Note that hotswapping adapters of the text encoder is not yet supported. There are some further - limitations to this technique, which are documented here: - https://huggingface.co/docs/peft/main/en/package_reference/hotswap + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. """ _load_lora_into_text_encoder( state_dict=state_dict, @@ -1248,29 +1170,8 @@ def load_lora_weights( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap : (`bool`, *optional*) - Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter - in-place. This means that, instead of loading an additional adapter, this will take the existing - adapter weights and replace them with the weights of the new adapter. This can be faster and more - memory efficient. However, the main advantage of hotswapping is that when the model is compiled with - torch.compile, loading the new adapter does not require recompilation of the model. When using - hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. - - If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need - to call an additional method before loading the adapter: - - ```py - pipeline = ... # load diffusers pipeline - max_rank = ... # the highest rank among all LoRAs that you want to load - # call *before* compiling and loading the LoRA adapter - pipeline.enable_lora_hotswap(target_rank=max_rank) - pipeline.load_lora_weights(file_name) - # optionally compile the model now - ``` - - Note that hotswapping adapters of the text encoder is not yet supported. There are some further - limitations to this technique, which are documented here: - https://huggingface.co/docs/peft/main/en/package_reference/hotswap + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ @@ -1345,29 +1246,8 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap : (`bool`, *optional*) - Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter - in-place. 
This means that, instead of loading an additional adapter, this will take the existing - adapter weights and replace them with the weights of the new adapter. This can be faster and more - memory efficient. However, the main advantage of hotswapping is that when the model is compiled with - torch.compile, loading the new adapter does not require recompilation of the model. When using - hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. - - If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need - to call an additional method before loading the adapter: - - ```py - pipeline = ... # load diffusers pipeline - max_rank = ... # the highest rank among all LoRAs that you want to load - # call *before* compiling and loading the LoRA adapter - pipeline.enable_lora_hotswap(target_rank=max_rank) - pipeline.load_lora_weights(file_name) - # optionally compile the model now - ``` - - Note that hotswapping adapters of the text encoder is not yet supported. There are some further - limitations to this technique, which are documented here: - https://huggingface.co/docs/peft/main/en/package_reference/hotswap + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -1423,29 +1303,8 @@ def load_lora_into_text_encoder( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap : (`bool`, *optional*) - Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter - in-place. This means that, instead of loading an additional adapter, this will take the existing - adapter weights and replace them with the weights of the new adapter. This can be faster and more - memory efficient. However, the main advantage of hotswapping is that when the model is compiled with - torch.compile, loading the new adapter does not require recompilation of the model. When using - hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. - - If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need - to call an additional method before loading the adapter: - - ```py - pipeline = ... # load diffusers pipeline - max_rank = ... # the highest rank among all LoRAs that you want to load - # call *before* compiling and loading the LoRA adapter - pipeline.enable_lora_hotswap(target_rank=max_rank) - pipeline.load_lora_weights(file_name) - # optionally compile the model now - ``` - - Note that hotswapping adapters of the text encoder is not yet supported. There are some further - limitations to this technique, which are documented here: - https://huggingface.co/docs/peft/main/en/package_reference/hotswap + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. 
""" _load_lora_into_text_encoder( state_dict=state_dict, @@ -1701,7 +1560,11 @@ def lora_state_dict( # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( - self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name: Optional[str] = None, + hotswap: bool = False, + **kwargs, ): """ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and @@ -1719,6 +1582,8 @@ def load_lora_weights( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ @@ -1748,6 +1613,7 @@ def load_lora_weights( adapter_name=adapter_name, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) @classmethod @@ -1771,29 +1637,8 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap : (`bool`, *optional*) - Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter - in-place. This means that, instead of loading an additional adapter, this will take the existing - adapter weights and replace them with the weights of the new adapter. This can be faster and more - memory efficient. However, the main advantage of hotswapping is that when the model is compiled with - torch.compile, loading the new adapter does not require recompilation of the model. When using - hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. - - If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need - to call an additional method before loading the adapter: - - ```py - pipeline = ... # load diffusers pipeline - max_rank = ... # the highest rank among all LoRAs that you want to load - # call *before* compiling and loading the LoRA adapter - pipeline.enable_lora_hotswap(target_rank=max_rank) - pipeline.load_lora_weights(file_name) - # optionally compile the model now - ``` - - Note that hotswapping adapters of the text encoder is not yet supported. There are some further - limitations to this technique, which are documented here: - https://huggingface.co/docs/peft/main/en/package_reference/hotswap + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -2076,7 +1921,7 @@ def lora_state_dict( def load_lora_weights( self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - adapter_name=None, + adapter_name: Optional[str] = None, hotswap: bool = False, **kwargs, ): @@ -2095,34 +1940,16 @@ def load_lora_weights( Parameters: pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - kwargs (`dict`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. 
If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. low_cpu_mem_usage (`bool`, *optional*): `Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap : (`bool`, *optional*) - Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter - in-place. This means that, instead of loading an additional adapter, this will take the existing - adapter weights and replace them with the weights of the new adapter. This can be faster and more - memory efficient. However, the main advantage of hotswapping is that when the model is compiled with - torch.compile, loading the new adapter does not require recompilation of the model. When using - hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. If the new - adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need to call an - additional method before loading the adapter: - ```py - pipeline = ... # load diffusers pipeline - max_rank = ... # the highest rank among all LoRAs that you want to load - # call *before* compiling and loading the LoRA adapter - pipeline.enable_lora_hotswap(target_rank=max_rank) - pipeline.load_lora_weights(file_name) - # optionally compile the model now - ``` - Note that hotswapping adapters of the text encoder is not yet supported. There are some further - limitations to this technique, which are documented here: - https://huggingface.co/docs/peft/main/en/package_reference/hotswap + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -2244,29 +2071,8 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap : (`bool`, *optional*) - Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter - in-place. This means that, instead of loading an additional adapter, this will take the existing - adapter weights and replace them with the weights of the new adapter. This can be faster and more - memory efficient. However, the main advantage of hotswapping is that when the model is compiled with - torch.compile, loading the new adapter does not require recompilation of the model. When using - hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. - - If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need - to call an additional method before loading the adapter: - - ```py - pipeline = ... # load diffusers pipeline - max_rank = ... # the highest rank among all LoRAs that you want to load - # call *before* compiling and loading the LoRA adapter - pipeline.enable_lora_hotswap(target_rank=max_rank) - pipeline.load_lora_weights(file_name) - # optionally compile the model now - ``` - - Note that hotswapping adapters of the text encoder is not yet supported. There are some further - limitations to this technique, which are documented here: - https://huggingface.co/docs/peft/main/en/package_reference/hotswap + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. 
""" if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): raise ValueError( @@ -2376,29 +2182,8 @@ def load_lora_into_text_encoder( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap : (`bool`, *optional*) - Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter - in-place. This means that, instead of loading an additional adapter, this will take the existing - adapter weights and replace them with the weights of the new adapter. This can be faster and more - memory efficient. However, the main advantage of hotswapping is that when the model is compiled with - torch.compile, loading the new adapter does not require recompilation of the model. When using - hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. - - If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need - to call an additional method before loading the adapter: - - ```py - pipeline = ... # load diffusers pipeline - max_rank = ... # the highest rank among all LoRAs that you want to load - # call *before* compiling and loading the LoRA adapter - pipeline.enable_lora_hotswap(target_rank=max_rank) - pipeline.load_lora_weights(file_name) - # optionally compile the model now - ``` - - Note that hotswapping adapters of the text encoder is not yet supported. There are some further - limitations to this technique, which are documented here: - https://huggingface.co/docs/peft/main/en/package_reference/hotswap + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. """ _load_lora_into_text_encoder( state_dict=state_dict, @@ -2858,29 +2643,8 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap : (`bool`, *optional*) - Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter - in-place. This means that, instead of loading an additional adapter, this will take the existing - adapter weights and replace them with the weights of the new adapter. This can be faster and more - memory efficient. However, the main advantage of hotswapping is that when the model is compiled with - torch.compile, loading the new adapter does not require recompilation of the model. When using - hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. - - If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need - to call an additional method before loading the adapter: - - ```py - pipeline = ... # load diffusers pipeline - max_rank = ... # the highest rank among all LoRAs that you want to load - # call *before* compiling and loading the LoRA adapter - pipeline.enable_lora_hotswap(target_rank=max_rank) - pipeline.load_lora_weights(file_name) - # optionally compile the model now - ``` - - Note that hotswapping adapters of the text encoder is not yet supported. There are some further - limitations to this technique, which are documented here: - https://huggingface.co/docs/peft/main/en/package_reference/hotswap + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. 
""" if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): raise ValueError( @@ -2936,29 +2700,8 @@ def load_lora_into_text_encoder( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap : (`bool`, *optional*) - Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter - in-place. This means that, instead of loading an additional adapter, this will take the existing - adapter weights and replace them with the weights of the new adapter. This can be faster and more - memory efficient. However, the main advantage of hotswapping is that when the model is compiled with - torch.compile, loading the new adapter does not require recompilation of the model. When using - hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. - - If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need - to call an additional method before loading the adapter: - - ```py - pipeline = ... # load diffusers pipeline - max_rank = ... # the highest rank among all LoRAs that you want to load - # call *before* compiling and loading the LoRA adapter - pipeline.enable_lora_hotswap(target_rank=max_rank) - pipeline.load_lora_weights(file_name) - # optionally compile the model now - ``` - - Note that hotswapping adapters of the text encoder is not yet supported. There are some further - limitations to this technique, which are documented here: - https://huggingface.co/docs/peft/main/en/package_reference/hotswap + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. """ _load_lora_into_text_encoder( state_dict=state_dict, @@ -3135,7 +2878,11 @@ def lora_state_dict( return state_dict def load_lora_weights( - self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name: Optional[str] = None, + hotswap: bool = False, + **kwargs, ): """ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and @@ -3153,6 +2900,8 @@ def load_lora_weights( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ @@ -3182,6 +2931,7 @@ def load_lora_weights( adapter_name=adapter_name, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) @classmethod @@ -3205,29 +2955,8 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap : (`bool`, *optional*) - Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter - in-place. This means that, instead of loading an additional adapter, this will take the existing - adapter weights and replace them with the weights of the new adapter. This can be faster and more - memory efficient. However, the main advantage of hotswapping is that when the model is compiled with - torch.compile, loading the new adapter does not require recompilation of the model. 
When using - hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. - - If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need - to call an additional method before loading the adapter: - - ```py - pipeline = ... # load diffusers pipeline - max_rank = ... # the highest rank among all LoRAs that you want to load - # call *before* compiling and loading the LoRA adapter - pipeline.enable_lora_hotswap(target_rank=max_rank) - pipeline.load_lora_weights(file_name) - # optionally compile the model now - ``` - - Note that hotswapping adapters of the text encoder is not yet supported. There are some further - limitations to this technique, which are documented here: - https://huggingface.co/docs/peft/main/en/package_reference/hotswap + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -3466,7 +3195,11 @@ def lora_state_dict( # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( - self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name: Optional[str] = None, + hotswap: bool = False, + **kwargs, ): """ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and @@ -3484,6 +3217,8 @@ def load_lora_weights( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ @@ -3513,6 +3248,7 @@ def load_lora_weights( adapter_name=adapter_name, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) @classmethod @@ -3536,29 +3272,8 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap : (`bool`, *optional*) - Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter - in-place. This means that, instead of loading an additional adapter, this will take the existing - adapter weights and replace them with the weights of the new adapter. This can be faster and more - memory efficient. However, the main advantage of hotswapping is that when the model is compiled with - torch.compile, loading the new adapter does not require recompilation of the model. When using - hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. - - If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need - to call an additional method before loading the adapter: - - ```py - pipeline = ... # load diffusers pipeline - max_rank = ... # the highest rank among all LoRAs that you want to load - # call *before* compiling and loading the LoRA adapter - pipeline.enable_lora_hotswap(target_rank=max_rank) - pipeline.load_lora_weights(file_name) - # optionally compile the model now - ``` - - Note that hotswapping adapters of the text encoder is not yet supported. 
There are some further - limitations to this technique, which are documented here: - https://huggingface.co/docs/peft/main/en/package_reference/hotswap + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -3799,7 +3514,11 @@ def lora_state_dict( # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( - self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name: Optional[str] = None, + hotswap: bool = False, + **kwargs, ): """ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and @@ -3817,6 +3536,8 @@ def load_lora_weights( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ @@ -3846,6 +3567,7 @@ def load_lora_weights( adapter_name=adapter_name, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) @classmethod @@ -3869,29 +3591,8 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap : (`bool`, *optional*) - Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter - in-place. This means that, instead of loading an additional adapter, this will take the existing - adapter weights and replace them with the weights of the new adapter. This can be faster and more - memory efficient. However, the main advantage of hotswapping is that when the model is compiled with - torch.compile, loading the new adapter does not require recompilation of the model. When using - hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. - - If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need - to call an additional method before loading the adapter: - - ```py - pipeline = ... # load diffusers pipeline - max_rank = ... # the highest rank among all LoRAs that you want to load - # call *before* compiling and loading the LoRA adapter - pipeline.enable_lora_hotswap(target_rank=max_rank) - pipeline.load_lora_weights(file_name) - # optionally compile the model now - ``` - - Note that hotswapping adapters of the text encoder is not yet supported. There are some further - limitations to this technique, which are documented here: - https://huggingface.co/docs/peft/main/en/package_reference/hotswap + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. 
""" if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -4132,7 +3833,11 @@ def lora_state_dict( # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( - self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name: Optional[str] = None, + hotswap: bool = False, + **kwargs, ): """ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and @@ -4150,6 +3855,8 @@ def load_lora_weights( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ @@ -4179,6 +3886,7 @@ def load_lora_weights( adapter_name=adapter_name, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) @classmethod @@ -4202,29 +3910,8 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap : (`bool`, *optional*) - Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter - in-place. This means that, instead of loading an additional adapter, this will take the existing - adapter weights and replace them with the weights of the new adapter. This can be faster and more - memory efficient. However, the main advantage of hotswapping is that when the model is compiled with - torch.compile, loading the new adapter does not require recompilation of the model. When using - hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. - - If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need - to call an additional method before loading the adapter: - - ```py - pipeline = ... # load diffusers pipeline - max_rank = ... # the highest rank among all LoRAs that you want to load - # call *before* compiling and loading the LoRA adapter - pipeline.enable_lora_hotswap(target_rank=max_rank) - pipeline.load_lora_weights(file_name) - # optionally compile the model now - ``` - - Note that hotswapping adapters of the text encoder is not yet supported. There are some further - limitations to this technique, which are documented here: - https://huggingface.co/docs/peft/main/en/package_reference/hotswap + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. 
""" if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -4468,7 +4155,11 @@ def lora_state_dict( # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( - self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name: Optional[str] = None, + hotswap: bool = False, + **kwargs, ): """ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and @@ -4486,6 +4177,8 @@ def load_lora_weights( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ @@ -4515,6 +4208,7 @@ def load_lora_weights( adapter_name=adapter_name, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) @classmethod @@ -4538,29 +4232,8 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap : (`bool`, *optional*) - Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter - in-place. This means that, instead of loading an additional adapter, this will take the existing - adapter weights and replace them with the weights of the new adapter. This can be faster and more - memory efficient. However, the main advantage of hotswapping is that when the model is compiled with - torch.compile, loading the new adapter does not require recompilation of the model. When using - hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. - - If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need - to call an additional method before loading the adapter: - - ```py - pipeline = ... # load diffusers pipeline - max_rank = ... # the highest rank among all LoRAs that you want to load - # call *before* compiling and loading the LoRA adapter - pipeline.enable_lora_hotswap(target_rank=max_rank) - pipeline.load_lora_weights(file_name) - # optionally compile the model now - ``` - - Note that hotswapping adapters of the text encoder is not yet supported. There are some further - limitations to this technique, which are documented here: - https://huggingface.co/docs/peft/main/en/package_reference/hotswap + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. 
""" if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -4805,7 +4478,11 @@ def lora_state_dict( # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( - self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name: Optional[str] = None, + hotswap: bool = False, + **kwargs, ): """ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and @@ -4823,6 +4500,8 @@ def load_lora_weights( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ @@ -4852,6 +4531,7 @@ def load_lora_weights( adapter_name=adapter_name, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) @classmethod @@ -4875,29 +4555,8 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap : (`bool`, *optional*) - Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter - in-place. This means that, instead of loading an additional adapter, this will take the existing - adapter weights and replace them with the weights of the new adapter. This can be faster and more - memory efficient. However, the main advantage of hotswapping is that when the model is compiled with - torch.compile, loading the new adapter does not require recompilation of the model. When using - hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. - - If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need - to call an additional method before loading the adapter: - - ```py - pipeline = ... # load diffusers pipeline - max_rank = ... # the highest rank among all LoRAs that you want to load - # call *before* compiling and loading the LoRA adapter - pipeline.enable_lora_hotswap(target_rank=max_rank) - pipeline.load_lora_weights(file_name) - # optionally compile the model now - ``` - - Note that hotswapping adapters of the text encoder is not yet supported. There are some further - limitations to this technique, which are documented here: - https://huggingface.co/docs/peft/main/en/package_reference/hotswap + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. 
""" if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -5167,7 +4826,11 @@ def _maybe_expand_t2v_lora_for_i2v( return state_dict def load_lora_weights( - self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name: Optional[str] = None, + hotswap: bool = False, + **kwargs, ): """ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and @@ -5185,6 +4848,8 @@ def load_lora_weights( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ @@ -5218,6 +4883,7 @@ def load_lora_weights( adapter_name=adapter_name, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) @classmethod @@ -5241,29 +4907,8 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap : (`bool`, *optional*) - Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter - in-place. This means that, instead of loading an additional adapter, this will take the existing - adapter weights and replace them with the weights of the new adapter. This can be faster and more - memory efficient. However, the main advantage of hotswapping is that when the model is compiled with - torch.compile, loading the new adapter does not require recompilation of the model. When using - hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. - - If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need - to call an additional method before loading the adapter: - - ```py - pipeline = ... # load diffusers pipeline - max_rank = ... # the highest rank among all LoRAs that you want to load - # call *before* compiling and loading the LoRA adapter - pipeline.enable_lora_hotswap(target_rank=max_rank) - pipeline.load_lora_weights(file_name) - # optionally compile the model now - ``` - - Note that hotswapping adapters of the text encoder is not yet supported. There are some further - limitations to this technique, which are documented here: - https://huggingface.co/docs/peft/main/en/package_reference/hotswap + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. 
""" if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -5504,7 +5149,11 @@ def lora_state_dict( # Copied from diffusers.loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( - self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name: Optional[str] = None, + hotswap: bool = False, + **kwargs, ): """ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and @@ -5522,6 +5171,8 @@ def load_lora_weights( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ @@ -5551,6 +5202,7 @@ def load_lora_weights( adapter_name=adapter_name, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) @classmethod @@ -5574,29 +5226,8 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap : (`bool`, *optional*) - Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter - in-place. This means that, instead of loading an additional adapter, this will take the existing - adapter weights and replace them with the weights of the new adapter. This can be faster and more - memory efficient. However, the main advantage of hotswapping is that when the model is compiled with - torch.compile, loading the new adapter does not require recompilation of the model. When using - hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. - - If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need - to call an additional method before loading the adapter: - - ```py - pipeline = ... # load diffusers pipeline - max_rank = ... # the highest rank among all LoRAs that you want to load - # call *before* compiling and loading the LoRA adapter - pipeline.enable_lora_hotswap(target_rank=max_rank) - pipeline.load_lora_weights(file_name) - # optionally compile the model now - ``` - - Note that hotswapping adapters of the text encoder is not yet supported. There are some further - limitations to this technique, which are documented here: - https://huggingface.co/docs/peft/main/en/package_reference/hotswap + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError(