From 5139de1165bc833253dd09099fa9b60c080aba81 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Tue, 15 Apr 2025 11:44:31 +0530
Subject: [PATCH 1/5] feat: parse metadata from lora state dicts.

---
 src/diffusers/loaders/lora_base.py     | 21 +++++++++++++-
 src/diffusers/loaders/lora_pipeline.py | 23 +++++++++++++++-
 src/diffusers/loaders/peft.py          | 38 +++++++++++++++++++++++---
 src/diffusers/utils/peft_utils.py      |  7 ++++-
 tests/lora/test_lora_layers_wan.py     | 33 ++++++++++++++++++----
 5 files changed, 110 insertions(+), 12 deletions(-)

diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py
index 280a9fa6e73f..ae590245f3c8 100644
--- a/src/diffusers/loaders/lora_base.py
+++ b/src/diffusers/loaders/lora_base.py
@@ -206,6 +206,7 @@ def _fetch_state_dict(
     subfolder,
     user_agent,
     allow_pickle,
+    load_with_metadata=False,
 ):
     model_file = None
     if not isinstance(pretrained_model_name_or_path_or_dict, dict):
@@ -223,6 +224,9 @@ def _fetch_state_dict(
                     file_extension=".safetensors",
                     local_files_only=local_files_only,
                 )
+            if load_with_metadata and not weight_name.endswith(".safetensors"):
+                raise ValueError("`load_with_metadata` cannot be set to True when not using safetensors.")
+
             model_file = _get_model_file(
                 pretrained_model_name_or_path_or_dict,
                 weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
@@ -236,6 +240,12 @@ def _fetch_state_dict(
                 user_agent=user_agent,
             )
             state_dict = safetensors.torch.load_file(model_file, device="cpu")
+            if load_with_metadata:
+                with safetensors.torch.safe_open(model_file, framework="pt", device="cpu") as f:
+                    if hasattr(f, "metadata") and f.metadata() is not None:
+                        state_dict["_metadata"] = f.metadata()
+                    else:
+                        raise ValueError("Metadata couldn't be parsed from the safetensors file.")
         except (IOError, safetensors.SafetensorError) as e:
             if not allow_pickle:
                 raise e
@@ -882,16 +892,25 @@ def write_lora_layers(
     weight_name: str,
     save_function: Callable,
     safe_serialization: bool,
+    lora_adapter_metadata: dict = None,
 ):
     if os.path.isfile(save_directory):
         logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
         return

+    if lora_adapter_metadata is not None and not safe_serialization:
+        raise ValueError("`lora_adapter_metadata` cannot be specified when not using `safe_serialization`.")
+    if lora_adapter_metadata is not None and not isinstance(lora_adapter_metadata, dict):
+        raise ValueError("`lora_adapter_metadata` must be of type `dict`.")
+
     if save_function is None:
         if safe_serialization:

             def save_function(weights, filename):
-                return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+                metadata = {"format": "pt"}
+                if lora_adapter_metadata is not None:
+                    metadata.update(lora_adapter_metadata)
+                return safetensors.torch.save_file(weights, filename, metadata=metadata)

         else:
             save_function = torch.save

diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 2e241bc9ffad..00596b1b0139 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -4734,6 +4734,7 @@ def lora_state_dict(
                     - A [torch state
                       dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

+            load_with_metadata (`bool`, *optional*, defaults to `False`): Whether to also parse the LoRA adapter metadata stored in the safetensors file.
             cache_dir (`Union[str, os.PathLike]`, *optional*):
                 Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                 is not used.
@@ -4768,6 +4769,7 @@ def lora_state_dict(
         subfolder = kwargs.pop("subfolder", None)
         weight_name = kwargs.pop("weight_name", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        load_with_metadata = kwargs.pop("load_with_metadata", False)

         allow_pickle = False
         if use_safetensors is None:
@@ -4792,6 +4794,7 @@ def lora_state_dict(
             subfolder=subfolder,
             user_agent=user_agent,
             allow_pickle=allow_pickle,
+            load_with_metadata=load_with_metadata,
         )
         if any(k.startswith("diffusion_model.") for k in state_dict):
             state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict)
@@ -4859,6 +4862,7 @@ def load_lora_weights(
             raise ValueError("PEFT backend is required for this method.")

         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+        load_with_metdata = kwargs.get("load_with_metdata", False)
         if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
             raise ValueError(
                 "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
@@ -4885,12 +4889,20 @@ def load_lora_weights(
             adapter_name=adapter_name,
             _pipeline=self,
             low_cpu_mem_usage=low_cpu_mem_usage,
+            load_with_metdata=load_with_metdata,
         )

     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel
     def load_lora_into_transformer(
-        cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
+        cls,
+        state_dict,
+        transformer,
+        adapter_name=None,
+        _pipeline=None,
+        low_cpu_mem_usage=False,
+        hotswap: bool = False,
+        load_with_metadata: bool = False,
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -4931,6 +4943,7 @@ def load_lora_into_transformer(
                 Note that hotswapping adapters of the text encoder is not yet supported. There are some further
                 limitations to this technique, which are documented here:
                 https://huggingface.co/docs/peft/main/en/package_reference/hotswap
+            load_with_metadata (`bool`, *optional*, defaults to `False`): Whether to use the metadata parsed from the state dict to initialize the underlying `LoraConfig`.
         """
         if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
             raise ValueError(
                 "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
             )
@@ -4946,6 +4959,7 @@ def load_lora_into_transformer(
             _pipeline=_pipeline,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
+            load_with_metadata=load_with_metadata,
         )

     @classmethod
@@ -4958,6 +4972,7 @@ def save_lora_weights(
         weight_name: str = None,
         save_function: Callable = None,
         safe_serialization: bool = True,
+        transformer_lora_adapter_metadata: Optional[dict] = None,
     ):
         r"""
         Save the LoRA parameters corresponding to the UNet and text encoder.
@@ -4977,8 +4992,10 @@ def save_lora_weights(
                 `DIFFUSERS_SAVE_MODE`.
             safe_serialization (`bool`, *optional*, defaults to `True`):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+            transformer_lora_adapter_metadata: LoRA adapter metadata associated with the transformer, serialized alongside the state dict.
         """
         state_dict = {}
+        lora_adapter_metadata = {}

         if not transformer_lora_layers:
             raise ValueError("You must pass `transformer_lora_layers`.")
@@ -4986,6 +5003,9 @@ def save_lora_weights(
         if transformer_lora_layers:
             state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))

+        if transformer_lora_adapter_metadata:
+            lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name))
+
         # Save the model
         cls.write_lora_layers(
             state_dict=state_dict,
@@ -4994,6 +5014,7 @@ def save_lora_weights(
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
+            lora_adapter_metadata=lora_adapter_metadata,
         )

     # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora

diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py
index 9165c46f3c78..208ee4b7a5fe 100644
--- a/src/diffusers/loaders/peft.py
+++ b/src/diffusers/loaders/peft.py
@@ -113,7 +113,12 @@ def _optionally_disable_offloading(cls, _pipeline):
         return _func_optionally_disable_offloading(_pipeline=_pipeline)

     def load_lora_adapter(
-        self, pretrained_model_name_or_path_or_dict, prefix="transformer", hotswap: bool = False, **kwargs
+        self,
+        pretrained_model_name_or_path_or_dict,
+        prefix="transformer",
+        hotswap: bool = False,
+        load_with_metadata: bool = False,
+        **kwargs,
     ):
         r"""
         Loads a LoRA adapter into the underlying model.
@@ -181,6 +186,8 @@ def load_lora_adapter(
                 Note that hotswapping adapters of the text encoder is not yet supported. There are some further
                 limitations to this technique, which are documented here:
                 https://huggingface.co/docs/peft/main/en/package_reference/hotswap
+
+            load_with_metadata (`bool`, *optional*, defaults to `False`): Whether to parse the LoRA adapter metadata from the safetensors file and use it to initialize the `LoraConfig`.
         """
         from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
         from peft.tuners.tuners_utils import BaseTunerLayer
@@ -223,10 +230,14 @@ def load_lora_adapter(
             subfolder=subfolder,
             user_agent=user_agent,
             allow_pickle=allow_pickle,
+            load_with_metadata=load_with_metadata,
         )
         if network_alphas is not None and prefix is None:
             raise ValueError("`network_alphas` cannot be None when `prefix` is None.")

+        if load_with_metadata is not None and not use_safetensors:
+            raise ValueError("`load_with_metadata` cannot be specified when not using `use_safetensors`.")
+
         if prefix is not None:
             state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}

@@ -261,7 +272,12 @@ def load_lora_adapter(
                 alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
                 network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}

-            lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
+            lora_config_kwargs = get_peft_kwargs(
+                rank,
+                network_alpha_dict=network_alphas,
+                peft_state_dict=state_dict,
+                load_with_metadata=load_with_metadata,
+            )
             _maybe_raise_error_for_ambiguity(lora_config_kwargs)

             if "use_dora" in lora_config_kwargs:
@@ -284,7 +300,11 @@ def load_lora_adapter(
                 if is_peft_version("<=", "0.13.2"):
                     lora_config_kwargs.pop("lora_bias")

-            lora_config = LoraConfig(**lora_config_kwargs)
+            try:
+                lora_config = LoraConfig(**lora_config_kwargs)
+            except TypeError as e:
+                logger.error(f"`LoraConfig` class could not be instantiated with the following trace: {e}.")
+                raise
             # adapter_name
             if adapter_name is None:
                 adapter_name = get_adapter_name(self)
@@ -428,6 +448,7 @@ def save_lora_adapter(
         upcast_before_saving: bool = False,
         safe_serialization: bool = True,
         weight_name: Optional[str] = None,
+        lora_adapter_metadata: Optional[dict] = None,
     ):
         """
         Save the LoRA parameters corresponding to the underlying model.
@@ -446,11 +467,17 @@ def save_lora_adapter(
             safe_serialization (`bool`, *optional*, defaults to `True`):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
             weight_name: (`str`, *optional*, defaults to `None`): Name of the file to serialize the state dict with.
+            lora_adapter_metadata (`dict`, *optional*, defaults to `None`): LoRA adapter metadata to serialize with the state dict; requires `safe_serialization` to be set to `True`.
         """
         from peft.utils import get_peft_model_state_dict

         from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE

+        if lora_adapter_metadata is not None and not safe_serialization:
+            raise ValueError("`lora_adapter_metadata` cannot be specified when not using `safe_serialization`.")
+        if lora_adapter_metadata is not None and not isinstance(lora_adapter_metadata, dict):
+            raise ValueError("`lora_adapter_metadata` must be of type `dict`.")
+
         if adapter_name is None:
             adapter_name = get_adapter_name(self)
@@ -466,7 +493,10 @@ def save_lora_adapter(
         if safe_serialization:

             def save_function(weights, filename):
-                return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+                metadata = {"format": "pt"}
+                if lora_adapter_metadata is not None:
+                    metadata.update(lora_adapter_metadata)
+                return safetensors.torch.save_file(weights, filename, metadata=metadata)

         else:
             save_function = torch.save

diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py
index d1269fbc5f20..28c3ab29773f 100644
--- a/src/diffusers/utils/peft_utils.py
+++ b/src/diffusers/utils/peft_utils.py
@@ -147,7 +147,12 @@ def unscale_lora_layers(model, weight: Optional[float] = None):
             module.set_scale(adapter_name, 1.0)


-def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True):
+def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, load_with_metadata=False):
+    if load_with_metadata:
+        if "_metadata" not in peft_state_dict:
+            raise ValueError("Couldn't find '_metadata' key in the `peft_state_dict`.")
+        return peft_state_dict["_metadata"]
+
     rank_pattern = {}
     alpha_pattern = {}
     r = lora_alpha = list(rank_dict.values())[0]

diff --git a/tests/lora/test_lora_layers_wan.py b/tests/lora/test_lora_layers_wan.py
index c2498fa68c3d..8d13339e554d 100644
--- a/tests/lora/test_lora_layers_wan.py
+++ b/tests/lora/test_lora_layers_wan.py
@@ -13,6 +13,7 @@
 # limitations under the License.

 import sys
+import tempfile
 import unittest

 import torch
@@ -24,11 +25,7 @@
     WanPipeline,
     WanTransformer3DModel,
 )
-from diffusers.utils.testing_utils import (
-    floats_tensor,
-    require_peft_backend,
-    skip_mps,
-)
+from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps, torch_device


 sys.path.append(".")
@@ -141,3 +138,29 @@ def test_simple_inference_with_text_lora_fused(self):
     @unittest.skip("Text encoder LoRA is not supported in Wan.")
     def test_simple_inference_with_text_lora_save_load(self):
         pass
+
+    def test_save_load_with_adapter_metadata(self):
+        # Will write the test in utils.py eventually.
+        scheduler_cls = self.scheduler_classes[0]
+        components, _, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        self.assertTrue(output_no_lora.shape == self.output_shape)
+
+        pipe, _ = self.check_if_adapters_added_correctly(
+            pipe, text_lora_config=None, denoiser_lora_config=denoiser_lora_config
+        )
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+            self.pipeline_class.save_lora_weights(
+                save_directory=tmpdirname,
+                safe_serialization=False,
+                lora_adapter_metadata=denoiser_lora_config.to_dict(),
+                **lora_state_dicts,
+            )

From d8a305e0eec38c01158aa9acce784cacc6303bdc Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Tue, 15 Apr 2025 14:43:49 +0530
Subject: [PATCH 2/5] tests

---
 src/diffusers/loaders/lora_base.py      | 19 ++++++---
 src/diffusers/loaders/peft.py           | 18 ++++++---
 src/diffusers/utils/peft_utils.py       |  9 ++++-
 src/diffusers/utils/state_dict_utils.py | 17 ++++++++
 tests/lora/test_lora_layers_wan.py      | 52 +++++++++++++++++++++----
 tests/lora/utils.py                     | 27 +++++++++++++
 6 files changed, 121 insertions(+), 21 deletions(-)

diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py
index ae590245f3c8..a845fab8d97e 100644
--- a/src/diffusers/loaders/lora_base.py
+++ b/src/diffusers/loaders/lora_base.py
@@ -14,6 +14,7 @@

 import copy
 import inspect
+import json
 import os
 from pathlib import Path
 from typing import Callable, Dict, List, Optional, Union
@@ -45,6 +46,7 @@
     set_adapter_layers,
     set_weights_and_activate_adapters,
 )
+from ..utils.state_dict_utils import _maybe_populate_state_dict_with_metadata


 if is_transformers_available():
@@ -241,11 +243,10 @@ def _fetch_state_dict(
             )
             state_dict = safetensors.torch.load_file(model_file, device="cpu")
             if load_with_metadata:
-                with safetensors.torch.safe_open(model_file, framework="pt", device="cpu") as f:
-                    if hasattr(f, "metadata") and f.metadata() is not None:
-                        state_dict["_metadata"] = f.metadata()
-                    else:
-                        raise ValueError("Metadata couldn't be parsed from the safetensors file.")
+                state_dict = _maybe_populate_state_dict_with_metadata(
+                    state_dict, model_file, metadata_key="lora_adapter_config"
+                )
+
         except (IOError, safetensors.SafetensorError) as e:
             if not allow_pickle:
                 raise e
@@ -907,9 +908,15 @@ def write_lora_layers(
         if safe_serialization:

             def save_function(weights, filename):
+                # We need to be able to serialize the NoneTypes too, otherwise we run into
+                # 'NoneType' object cannot be converted to 'PyString'
                 metadata = {"format": "pt"}
                 if lora_adapter_metadata is not None:
-                    metadata.update(lora_adapter_metadata)
+                    for key, value in lora_adapter_metadata.items():
+                        if isinstance(value, set):
+                            lora_adapter_metadata[key] = list(value)
+                    metadata["lora_adapter_config"] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
+
                 return safetensors.torch.save_file(weights, filename, metadata=metadata)

         else:
             save_function = torch.save

diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py
index 208ee4b7a5fe..70425c96b153 100644
--- a/src/diffusers/loaders/peft.py
+++ b/src/diffusers/loaders/peft.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
+import json
 import os
 from functools import partial
 from pathlib import Path
@@ -239,8 +240,12 @@ def load_lora_adapter(
             raise ValueError("`load_with_metadata` cannot be specified when not using `use_safetensors`.")

         if prefix is not None:
+            metadata = state_dict.pop("_metadata", None)
             state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}

+            if metadata is not None:
+                state_dict["_metadata"] = metadata
+
         if len(state_dict) > 0:
             if adapter_name in getattr(self, "peft_config", {}) and not hotswap:
                 raise ValueError(
@@ -277,6 +282,7 @@ def load_lora_adapter(
                 network_alpha_dict=network_alphas,
                 peft_state_dict=state_dict,
                 load_with_metadata=load_with_metadata,
+                prefix=prefix,
             )
             _maybe_raise_error_for_ambiguity(lora_config_kwargs)

@@ -460,10 +466,6 @@ def save_lora_adapter(
             underlying model has multiple adapters loaded.
         upcast_before_saving (`bool`, defaults to `False`):
             Whether to cast the underlying model to `torch.float32` before serialization.
-        save_function (`Callable`):
-            The function to use to save the state dictionary. Useful during distributed training when you need to
-            replace `torch.save` with another method. Can be configured with the environment variable
-            `DIFFUSERS_SAVE_MODE`.
         safe_serialization (`bool`, *optional*, defaults to `True`):
             Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
         weight_name: (`str`, *optional*, defaults to `None`): Name of the file to serialize the state dict with.
         lora_adapter_metadata (`dict`, *optional*, defaults to `None`): LoRA adapter metadata to serialize with the state dict; requires `safe_serialization` to be set to `True`.
         """
         from peft.utils import get_peft_model_state_dict

         from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
@@ -493,9 +495,15 @@ def save_function(weights, filename):
         if safe_serialization:

             def save_function(weights, filename):
+                # We need to be able to serialize the NoneTypes too, otherwise we run into
+                # 'NoneType' object cannot be converted to 'PyString'
                 metadata = {"format": "pt"}
                 if lora_adapter_metadata is not None:
-                    metadata.update(lora_adapter_metadata)
+                    for key, value in lora_adapter_metadata.items():
+                        if isinstance(value, set):
+                            lora_adapter_metadata[key] = list(value)
+                    metadata["lora_adapter_config"] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
+
                 return safetensors.torch.save_file(weights, filename, metadata=metadata)

         else:
             save_function = torch.save

diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py
index 28c3ab29773f..cd63616178bc 100644
--- a/src/diffusers/utils/peft_utils.py
+++ b/src/diffusers/utils/peft_utils.py
@@ -147,11 +147,16 @@ def unscale_lora_layers(model, weight: Optional[float] = None):
             module.set_scale(adapter_name, 1.0)


-def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, load_with_metadata=False):
+def get_peft_kwargs(
+    rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, prefix=None, load_with_metadata=False
+):
     if load_with_metadata:
         if "_metadata" not in peft_state_dict:
             raise ValueError("Couldn't find '_metadata' key in the `peft_state_dict`.")
-        return peft_state_dict["_metadata"]
+        metadata = peft_state_dict["_metadata"]
+        if prefix is not None:
+            metadata = {k.replace(f"{prefix}.", ""): v for k, v in metadata.items()}
+        return metadata

     rank_pattern = {}
     alpha_pattern = {}

diff --git a/src/diffusers/utils/state_dict_utils.py b/src/diffusers/utils/state_dict_utils.py
index 3682c5bfacd6..45922ef162d2 100644
--- a/src/diffusers/utils/state_dict_utils.py
+++ b/src/diffusers/utils/state_dict_utils.py
@@ -16,6 +16,7 @@
 """

 import enum
+import json

 from .import_utils import is_torch_available
 from .logging import get_logger
@@ -347,3 +348,19 @@ def state_dict_all_zero(state_dict, filter_str=None):
         state_dict = {k: v for k, v in state_dict.items() if any(f in k for f in filter_str)}
     return all(torch.all(param == 0).item() for param in state_dict.values())
+
+
+def _maybe_populate_state_dict_with_metadata(state_dict, model_file, metadata_key):
+    import safetensors.torch
+
+    with safetensors.torch.safe_open(model_file, framework="pt", device="cpu") as f:
+        if hasattr(f, "metadata"):
+            metadata = f.metadata()
+            if metadata is not None:
+                metadata_keys = list(metadata.keys())
+                if not (len(metadata_keys) == 1 and metadata_keys[0] == "format"):
+                    peft_metadata = {k: v for k, v in metadata.items() if k != "format"}
+                    state_dict["_metadata"] = json.loads(peft_metadata[metadata_key])
+        else:
+            raise ValueError("Metadata couldn't be parsed from the safetensors file.")
+    return state_dict

diff --git a/tests/lora/test_lora_layers_wan.py b/tests/lora/test_lora_layers_wan.py
index 8d13339e554d..0de2c5978516 100644
--- a/tests/lora/test_lora_layers_wan.py
+++ b/tests/lora/test_lora_layers_wan.py
@@ -16,6 +16,7 @@
 import tempfile
 import unittest

+import numpy as np
 import torch
 from transformers import AutoTokenizer, T5EncoderModel

@@ -30,7 +31,7 @@

 sys.path.append(".")

-from utils import PeftLoraLoaderMixinTests  # noqa: E402
+from utils import PeftLoraLoaderMixinTests, check_if_dicts_are_equal  # noqa: E402


 @require_peft_backend
@@ -139,13 +140,39 @@ def test_simple_inference_with_text_lora_fused(self):
     def test_simple_inference_with_text_lora_save_load(self):
         pass

-    def test_save_load_with_adapter_metadata(self):
+    def test_adapter_metadata_is_loaded_correctly(self):
         # Will write the test in utils.py eventually.
         scheduler_cls = self.scheduler_classes[0]
         components, _, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
         pipe = self.pipeline_class(**components)
+
+        pipe, _ = self.check_if_adapters_added_correctly(
+            pipe, text_lora_config=None, denoiser_lora_config=denoiser_lora_config
+        )
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+            metadata = denoiser_lora_config.to_dict()
+            self.pipeline_class.save_lora_weights(
+                save_directory=tmpdir,
+                transformer_lora_adapter_metadata=metadata,
+                **lora_state_dicts,
+            )
+            pipe.unload_lora_weights()
+            state_dict = pipe.lora_state_dict(tmpdir, load_with_metadata=True)
+
+            self.assertTrue("_metadata" in state_dict)
+
+            parsed_metadata = state_dict["_metadata"]
+            parsed_metadata = {k[len("transformer.") :]: v for k, v in parsed_metadata.items()}
+            check_if_dicts_are_equal(parsed_metadata, metadata)
+
+    def test_adapter_metadata_save_load_inference(self):
+        # Will write the test in utils.py eventually.
+        scheduler_cls = self.scheduler_classes[0]
+        components, _, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        pipe = self.pipeline_class(**components).to(torch_device)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)
@@ -154,13 +181,22 @@ def test_save_load_with_adapter_metadata(self):
         pipe, _ = self.check_if_adapters_added_correctly(
             pipe, text_lora_config=None, denoiser_lora_config=denoiser_lora_config
         )
+        output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

-        with tempfile.TemporaryDirectory() as tmpdirname:
+        with tempfile.TemporaryDirectory() as tmpdir:
             modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
             lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+            metadata = denoiser_lora_config.to_dict()
             self.pipeline_class.save_lora_weights(
-                save_directory=tmpdirname,
-                safe_serialization=False,
-                lora_adapter_metadata=denoiser_lora_config.to_dict(),
+                save_directory=tmpdir,
+                transformer_lora_adapter_metadata=metadata,
                 **lora_state_dicts,
             )
+            pipe.unload_lora_weights()
+            pipe.load_lora_weights(tmpdir, load_with_metadata=True)
+
+            output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+            self.assertTrue(
+                np.allclose(output_lora, output_lora_pretrained, atol=1e-3, rtol=1e-3), "Lora outputs should match."
+            )

diff --git a/tests/lora/utils.py b/tests/lora/utils.py
index 27fef495a484..9cd26f221850 100644
--- a/tests/lora/utils.py
+++ b/tests/lora/utils.py
@@ -61,6 +61,33 @@ def state_dicts_almost_equal(sd1, sd2):
     return models_are_equal


+def check_if_dicts_are_equal(dict1, dict2):
+    for key, value in dict1.items():
+        if isinstance(value, set):
+            dict1[key] = list(value)
+    for key, value in dict2.items():
+        if isinstance(value, set):
+            dict2[key] = list(value)
+
+    for key in dict1:
+        if key not in dict2:
+            raise ValueError(
+                f"Key '{key}' is missing in the second dictionary. Its value in the first dictionary is {dict1[key]}."
+            )
+        if dict1[key] != dict2[key]:
+            raise ValueError(
+                f"Difference found at key '{key}': first dictionary has {dict1[key]}, second dictionary has {dict2[key]}."
+            )
+
+    for key in dict2:
+        if key not in dict1:
+            raise ValueError(
+                f"Key '{key}' is missing in the first dictionary. Its value in the second dictionary is {dict2[key]}."
+ ) + + return True + + def check_if_lora_correctly_set(model) -> bool: """ Checks if the LoRA layers are correctly set with peft From ba546bcbd84dbe1bfa33fe873ded354a8f26751e Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 15 Apr 2025 15:05:50 +0530 Subject: [PATCH 3/5] fix tests --- src/diffusers/loaders/lora_pipeline.py | 2 +- src/diffusers/loaders/peft.py | 7 ++----- src/diffusers/utils/peft_utils.py | 4 ++-- src/diffusers/utils/state_dict_utils.py | 2 +- tests/lora/test_lora_layers_wan.py | 4 ++-- 5 files changed, 8 insertions(+), 11 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 00596b1b0139..8456f04d9e84 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -4889,7 +4889,7 @@ def load_lora_weights( adapter_name=adapter_name, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, - load_with_metdata=load_with_metdata, + load_with_metadata=load_with_metdata, ) @classmethod diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 70425c96b153..6b61fe03724a 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -236,15 +236,12 @@ def load_lora_adapter( if network_alphas is not None and prefix is None: raise ValueError("`network_alphas` cannot be None when `prefix` is None.") - if load_with_metadata is not None and not use_safetensors: - raise ValueError("`load_with_metadata` cannot be specified when not using `use_safetensors`.") - if prefix is not None: - metadata = state_dict.pop("_metadata", None) + metadata = state_dict.pop("lora_metadata", None) state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")} if metadata is not None: - state_dict["_metadata"] = metadata + state_dict["lora_metadata"] = metadata if len(state_dict) > 0: if adapter_name in getattr(self, "peft_config", {}) and not hotswap: diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index cd63616178bc..6408de79c2a7 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -151,9 +151,9 @@ def get_peft_kwargs( rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, prefix=None, load_with_metadata=False ): if load_with_metadata: - if "_metadata" not in peft_state_dict: + if "lora_metadata" not in peft_state_dict: raise ValueError("Couldn't find '_metadata' key in the `peft_state_dict`.") - metadata = peft_state_dict["_metadata"] + metadata = peft_state_dict["lora_metadata"] if prefix is not None: metadata = {k.replace(f"{prefix}.", ""): v for k, v in metadata.items()} return metadata diff --git a/src/diffusers/utils/state_dict_utils.py b/src/diffusers/utils/state_dict_utils.py index 45922ef162d2..2723ab822df1 100644 --- a/src/diffusers/utils/state_dict_utils.py +++ b/src/diffusers/utils/state_dict_utils.py @@ -360,7 +360,7 @@ def _maybe_populate_state_dict_with_metadata(state_dict, model_file, metadata_ke metadata_keys = list(metadata.keys()) if not (len(metadata_keys) == 1 and metadata_keys[0] == "format"): peft_metadata = {k: v for k, v in metadata.items() if k != "format"} - state_dict["_metadata"] = json.loads(peft_metadata[metadata_key]) + state_dict["lora_metadata"] = json.loads(peft_metadata[metadata_key]) else: raise ValueError("Metadata couldn't be parsed from the safetensors file.") return state_dict diff --git a/tests/lora/test_lora_layers_wan.py b/tests/lora/test_lora_layers_wan.py index 0de2c5978516..4d8b06d748e1 100644 --- 
+++ b/tests/lora/test_lora_layers_wan.py
@@ -162,9 +162,9 @@ def test_adapter_metadata_is_loaded_correctly(self):
             pipe.unload_lora_weights()
             state_dict = pipe.lora_state_dict(tmpdir, load_with_metadata=True)

-            self.assertTrue("_metadata" in state_dict)
+            self.assertTrue("lora_metadata" in state_dict)

-            parsed_metadata = state_dict["_metadata"]
+            parsed_metadata = state_dict["lora_metadata"]
             parsed_metadata = {k[len("transformer.") :]: v for k, v in parsed_metadata.items()}
             check_if_dicts_are_equal(parsed_metadata, metadata)

From 61d37086b5a04fc3e4ef513d37f66cf8ecc92fef Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Tue, 15 Apr 2025 17:26:38 +0530
Subject: [PATCH 4/5] key renaming

---
 src/diffusers/loaders/lora_base.py | 4 ++--
 src/diffusers/loaders/peft.py      | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py
index a845fab8d97e..a2295edb7b2b 100644
--- a/src/diffusers/loaders/lora_base.py
+++ b/src/diffusers/loaders/lora_base.py
@@ -244,7 +244,7 @@ def _fetch_state_dict(
             state_dict = safetensors.torch.load_file(model_file, device="cpu")
             if load_with_metadata:
                 state_dict = _maybe_populate_state_dict_with_metadata(
-                    state_dict, model_file, metadata_key="lora_adapter_config"
+                    state_dict, model_file, metadata_key="lora_adapter_metadata"
                 )

@@ -915,7 +915,7 @@ def save_function(weights, filename):
                     for key, value in lora_adapter_metadata.items():
                         if isinstance(value, set):
                             lora_adapter_metadata[key] = list(value)
-                    metadata["lora_adapter_config"] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
+                    metadata["lora_adapter_metadata"] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)

                 return safetensors.torch.save_file(weights, filename, metadata=metadata)

diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py
index 81f70ce587bb..8fd18c64df40 100644
--- a/src/diffusers/loaders/peft.py
+++ b/src/diffusers/loaders/peft.py
@@ -500,7 +500,7 @@ def save_function(weights, filename):
                     for key, value in lora_adapter_metadata.items():
                         if isinstance(value, set):
                             lora_adapter_metadata[key] = list(value)
-                    metadata["lora_adapter_config"] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
+                    metadata["lora_adapter_metadata"] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)

                 return safetensors.torch.save_file(weights, filename, metadata=metadata)

From e98fb846e48e709ce2ce5918a1badc85de487179 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Tue, 15 Apr 2025 18:54:54 +0530
Subject: [PATCH 5/5] fix

---
 src/diffusers/loaders/lora_pipeline.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 35b19b13ae22..76ad07355d0f 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -5195,7 +5195,7 @@ def load_lora_weights(
             raise ValueError("PEFT backend is required for this method.")

         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
-        load_with_metdata = kwargs.get("load_with_metdata", False)
+        load_with_metadata = kwargs.get("load_with_metadata", False)
         if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
             raise ValueError(
                 "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
@@ -5222,7 +5222,7 @@ def load_lora_weights(
             adapter_name=adapter_name,
             _pipeline=self,
             low_cpu_mem_usage=low_cpu_mem_usage,
-            load_with_metadata=load_with_metdata,
+            load_with_metadata=load_with_metadata,
         )

     @classmethod
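
Taken together, the five patches let a LoRA adapter's `LoraConfig` round-trip through the safetensors header: `save_lora_weights` JSON-serializes the metadata under the "lora_adapter_metadata" key, and `load_lora_weights(..., load_with_metadata=True)` has `get_peft_kwargs` rebuild the config from that metadata instead of inferring it from state-dict keys. The following is a minimal usage sketch mirroring what the new Wan tests exercise; the output directory, rank, and target modules are illustrative assumptions, not values pinned down by the patches.

import torch
from diffusers import WanPipeline
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict

pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16)

# Attach a fresh (here untrained) LoRA to the transformer so there are adapter
# weights to save; rank and target modules are illustrative.
lora_config = LoraConfig(r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"])
pipe.transformer.add_adapter(lora_config)

# Save: `transformer_lora_adapter_metadata` ends up JSON-serialized in the
# safetensors header, so it requires `safe_serialization=True` (the default).
WanPipeline.save_lora_weights(
    save_directory="wan-lora-with-metadata",  # hypothetical output directory
    transformer_lora_layers=get_peft_model_state_dict(pipe.transformer),
    transformer_lora_adapter_metadata=lora_config.to_dict(),
)

# Load: with `load_with_metadata=True`, the `LoraConfig` is reconstructed from
# the stored metadata rather than guessed from the state-dict keys.
pipe.unload_lora_weights()
pipe.load_lora_weights("wan-lora-with-metadata", load_with_metadata=True)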