[LoRA] parse metadata from LoRA and save metadata #11324

Draft · wants to merge 9 commits into base: main
Changes from all commits
28 changes: 27 additions & 1 deletion src/diffusers/loaders/lora_base.py
@@ -14,6 +14,7 @@

import copy
import inspect
import json
import os
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
@@ -45,6 +46,7 @@
set_adapter_layers,
set_weights_and_activate_adapters,
)
from ..utils.state_dict_utils import _maybe_populate_state_dict_with_metadata


if is_transformers_available():
@@ -206,6 +208,7 @@ def _fetch_state_dict(
subfolder,
user_agent,
allow_pickle,
load_with_metadata=False,
):
model_file = None
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
@@ -223,6 +226,9 @@
file_extension=".safetensors",
local_files_only=local_files_only,
)
if load_with_metadata and not weight_name.endswith(".safetensors"):
raise ValueError("`load_with_metadata` cannot be set to True when not using safetensors.")

model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
@@ -236,6 +242,11 @@
user_agent=user_agent,
)
state_dict = safetensors.torch.load_file(model_file, device="cpu")
if load_with_metadata:
state_dict = _maybe_populate_state_dict_with_metadata(
state_dict, model_file, metadata_key="lora_adapter_metadata"
)

except (IOError, safetensors.SafetensorError) as e:
if not allow_pickle:
raise e
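`_maybe_populate_state_dict_with_metadata` itself is not part of this diff. Judging from how `load_lora_adapter` and `get_peft_kwargs` later look for a `"lora_metadata"` entry, it presumably behaves roughly like the sketch below; the storage key and exact behavior are assumptions, not the actual implementation.

```py
import json

from safetensors import safe_open


def _maybe_populate_state_dict_with_metadata(state_dict, model_file, metadata_key):
    # Read only the safetensors header; the tensors themselves are not touched.
    with safe_open(model_file, framework="pt", device="cpu") as f:
        metadata = f.metadata() or {}
    if metadata_key in metadata:
        # Assumed storage key, inferred from `load_lora_adapter` / `get_peft_kwargs` below.
        state_dict["lora_metadata"] = json.loads(metadata[metadata_key])
    return state_dict
```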
@@ -882,16 +893,31 @@ def write_lora_layers(
weight_name: str,
save_function: Callable,
safe_serialization: bool,
lora_adapter_metadata: dict = None,
):
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return

if lora_adapter_metadata is not None and not safe_serialization:
raise ValueError("`lora_adapter_metadata` cannot be specified when not using `safe_serialization`.")
if lora_adapter_metadata is not None and not isinstance(lora_adapter_metadata, dict):
raise ValueError("`lora_adapter_metadata` must be of type `dict`.")

if save_function is None:
if safe_serialization:

def save_function(weights, filename):
return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
# safetensors requires all header metadata values to be strings, so the adapter metadata is
# JSON-encoded; otherwise values such as `None` fail with "'NoneType' object cannot be converted to 'PyString'".
metadata = {"format": "pt"}
if lora_adapter_metadata is not None:
for key, value in lora_adapter_metadata.items():
if isinstance(value, set):
lora_adapter_metadata[key] = list(value)
metadata["lora_adapter_metadata"] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)

return safetensors.torch.save_file(weights, filename, metadata=metadata)

else:
save_function = torch.save
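For reference, here is the round trip this `save_function` enables, sketched with the public `safetensors` API: the adapter config is JSON-encoded into the file header (whose values must all be strings) and can be read back without loading any tensors. The tensor name and metadata values below are made up for illustration.

```py
import json

import torch
from safetensors import safe_open
from safetensors.torch import save_file

state_dict = {"transformer.blocks.0.attn.to_q.lora_A.weight": torch.zeros(4, 16)}  # toy tensor
adapter_metadata = {"r": 4, "lora_alpha": 4, "target_modules": ["to_q"]}

# Header values must be strings, hence the JSON encoding (mirroring `write_lora_layers` above).
save_file(
    state_dict,
    "pytorch_lora_weights.safetensors",
    metadata={"format": "pt", "lora_adapter_metadata": json.dumps(adapter_metadata)},
)

# Reading the header back does not require loading the tensors.
with safe_open("pytorch_lora_weights.safetensors", framework="pt") as f:
    print(json.loads(f.metadata()["lora_adapter_metadata"]))  # {'r': 4, 'lora_alpha': 4, ...}
```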
84 changes: 66 additions & 18 deletions src/diffusers/loaders/lora_pipeline.py
@@ -4726,6 +4726,7 @@ def lora_state_dict(
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

load_with_metadata (`bool`, *optional*, defaults to `False`):
Whether to also load the LoRA adapter metadata stored in the checkpoint. Requires the checkpoint
to be in the safetensors format.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
@@ -4760,6 +4761,7 @@
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
load_with_metadata = kwargs.pop("load_with_metadata", False)

allow_pickle = False
if use_safetensors is None:
@@ -4784,6 +4786,7 @@
subfolder=subfolder,
user_agent=user_agent,
allow_pickle=allow_pickle,
load_with_metadata=load_with_metadata,
)
if any(k.startswith("diffusion_model.") for k in state_dict):
state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict)
@@ -4857,6 +4860,7 @@ def load_lora_weights(
raise ValueError("PEFT backend is required for this method.")

low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
load_with_metadata = kwargs.get("load_with_metadata", False)
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
@@ -4883,32 +4887,68 @@
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
load_with_metadata=load_with_metadata,
hotswap=hotswap,
)

@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel
def load_lora_into_transformer(
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
cls,
state_dict,
transformer,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
load_with_metadata: bool = False,
):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.

Parameters:
state_dict (`dict`):
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
encoder lora layers.
transformer (`WanTransformer3DModel`):
The Transformer model to load the LoRA layers into.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
This will load the LoRA layers specified in `state_dict` into `transformer`.

Parameters:
state_dict (`dict`):
A standard state dict containing the lora layer parameters. The keys can either be indexed
directly into the unet or prefixed with an additional `unet` which can be used to distinguish
between text encoder lora layers.
transformer (`WanTransformer3DModel`):
The Transformer model to load the LoRA layers into.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the
random weights.
hotswap (`bool`, *optional*, defaults to `False`):
Whether to substitute an existing (LoRA) adapter with the newly loaded adapter in-place. This
means that, instead of loading an additional adapter, this will take the existing adapter
weights and replace them with the weights of the new adapter. This can be faster and more
memory efficient. However, the main advantage of hotswapping is that when the model is
compiled with torch.compile, loading the new adapter does not require recompilation of the
model. When using hotswapping, the passed `adapter_name` should be the name of an already
loaded adapter.

If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling),
you need to call an additional method before loading the adapter:

```py
pipeline = ...  # load diffusers pipeline
max_rank = ...  # the highest rank among all LoRAs that you want to load
# call *before* compiling and loading the LoRA adapter
pipeline.enable_lora_hotswap(target_rank=max_rank)
pipeline.load_lora_weights(file_name)
# optionally compile the model now
```

Note that hotswapping adapters of the text encoder is not yet supported. There are some further
limitations to this technique, which are documented here:
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
load_with_metadata (`bool`, *optional*, defaults to `False`):
Whether to build the adapter's `LoraConfig` from the metadata stored in the safetensors
checkpoint instead of inferring it from the state dict.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
@@ -4924,6 +4964,7 @@ def load_lora_into_transformer(
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
load_with_metadata=load_with_metadata,
)

@classmethod
Expand All @@ -4936,6 +4977,7 @@ def save_lora_weights(
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
Save the LoRA parameters corresponding to the UNet and text encoder.
@@ -4955,15 +4997,20 @@
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
transformer_lora_adapter_metadata (`dict`, *optional*):
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
"""
state_dict = {}
lora_adapter_metadata = {}

if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.")

if transformer_lora_layers:
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))

if transformer_lora_adapter_metadata:
lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name))

# Save the model
cls.write_lora_layers(
state_dict=state_dict,
@@ -4972,6 +5019,7 @@
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
)

# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
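Taken together, the `save_lora_weights` / `load_lora_weights` changes in this file would be exercised roughly as follows. This is a sketch: the repo id, output directory, metadata values, and the `transformer_lora_layers` variable (which would come from a training loop) are placeholders, not taken from this PR.

```py
import torch
from diffusers import WanPipeline

# Placeholder repo id; `transformer_lora_layers` is assumed to come from a training loop
# (e.g. `get_peft_model_state_dict(pipe.transformer)`).
pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16)

pipe.save_lora_weights(
    "my-wan-lora",
    transformer_lora_layers=transformer_lora_layers,
    transformer_lora_adapter_metadata={"r": 16, "lora_alpha": 16},
)

# Rebuild the LoraConfig from the stored metadata instead of inferring it from tensor shapes.
pipe.load_lora_weights("my-wan-lora", load_with_metadata=True)
```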
51 changes: 43 additions & 8 deletions src/diffusers/loaders/peft.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import json
import os
from functools import partial
from pathlib import Path
@@ -114,7 +115,12 @@ def _optionally_disable_offloading(cls, _pipeline):
return _func_optionally_disable_offloading(_pipeline=_pipeline)

def load_lora_adapter(
self, pretrained_model_name_or_path_or_dict, prefix="transformer", hotswap: bool = False, **kwargs
self,
pretrained_model_name_or_path_or_dict,
prefix="transformer",
hotswap: bool = False,
load_with_metadata: bool = False,
**kwargs,
):
r"""
Loads a LoRA adapter into the underlying model.
@@ -182,6 +188,8 @@ def load_lora_adapter(
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
limitations to this technique, which are documented here:
https://huggingface.co/docs/peft/main/en/package_reference/hotswap

load_with_metadata (`bool`, *optional*, defaults to `False`):
Whether to build the adapter's `LoraConfig` from the metadata stored in the safetensors
checkpoint instead of inferring it from the state dict. Requires a safetensors checkpoint.
"""
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
from peft.tuners.tuners_utils import BaseTunerLayer
@@ -224,13 +232,18 @@ def load_lora_adapter(
subfolder=subfolder,
user_agent=user_agent,
allow_pickle=allow_pickle,
load_with_metadata=load_with_metadata,
)
if network_alphas is not None and prefix is None:
raise ValueError("`network_alphas` cannot be None when `prefix` is None.")

if prefix is not None:
metadata = state_dict.pop("lora_metadata", None)
state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}

if metadata is not None:
state_dict["lora_metadata"] = metadata

if len(state_dict) > 0:
if adapter_name in getattr(self, "peft_config", {}) and not hotswap:
raise ValueError(
@@ -262,7 +275,13 @@ def load_lora_adapter(
alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}

lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
lora_config_kwargs = get_peft_kwargs(
rank,
network_alpha_dict=network_alphas,
peft_state_dict=state_dict,
load_with_metadata=load_with_metadata,
prefix=prefix,
)
_maybe_raise_error_for_ambiguity(lora_config_kwargs)

if "use_dora" in lora_config_kwargs:
@@ -285,7 +304,11 @@ def load_lora_adapter(
if is_peft_version("<=", "0.13.2"):
lora_config_kwargs.pop("lora_bias")

lora_config = LoraConfig(**lora_config_kwargs)
try:
lora_config = LoraConfig(**lora_config_kwargs)
except TypeError as e:
logger.error(f"`LoraConfig` class could not be instantiated with the following trace: {e}.")

# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(self)
@@ -429,6 +452,7 @@ def save_lora_adapter(
upcast_before_saving: bool = False,
safe_serialization: bool = True,
weight_name: Optional[str] = None,
lora_adapter_metadata: Optional[dict] = None,
):
"""
Save the LoRA parameters corresponding to the underlying model.
@@ -440,18 +464,20 @@
underlying model has multiple adapters loaded.
upcast_before_saving (`bool`, defaults to `False`):
Whether to cast the underlying model to `torch.float32` before serialization.
save_function (`Callable`):
The function to use to save the state dictionary. Useful during distributed training when you need to
replace `torch.save` with another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
weight_name: (`str`, *optional*, defaults to `None`): Name of the file to serialize the state dict with.
lora_adapter_metadata (`dict`, *optional*):
LoRA adapter metadata to serialize into the safetensors header alongside the state dict. Only
supported when `safe_serialization=True`.
"""
from peft.utils import get_peft_model_state_dict

from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE

if lora_adapter_metadata is not None and not safe_serialization:
raise ValueError("`lora_adapter_metadata` cannot be specified when not using `safe_serialization`.")
if lora_adapter_metadata is not None and not isinstance(lora_adapter_metadata, dict):
raise ValueError("`lora_adapter_metadata` must be of type `dict`.")

if adapter_name is None:
adapter_name = get_adapter_name(self)

@@ -467,7 +493,16 @@
if safe_serialization:

def save_function(weights, filename):
return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
# safetensors requires all header metadata values to be strings, so the adapter metadata is
# JSON-encoded; otherwise values such as `None` fail with "'NoneType' object cannot be converted to 'PyString'".
metadata = {"format": "pt"}
if lora_adapter_metadata is not None:
for key, value in lora_adapter_metadata.items():
if isinstance(value, set):
lora_adapter_metadata[key] = list(value)
metadata["lora_adapter_metadata"] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)

return safetensors.torch.save_file(weights, filename, metadata=metadata)

else:
save_function = torch.save
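At the model level, `save_lora_adapter` accepts the same kind of metadata dict. Below is a minimal sketch of how it might be assembled from a `peft.LoraConfig`; the `transformer` variable and output directory are placeholders, and passing the config fields this way is an assumption about intended usage rather than something prescribed by the PR.

```py
from peft import LoraConfig

lora_config = LoraConfig(r=16, lora_alpha=16, target_modules={"to_q", "to_k", "to_v"})

# Pick out the kwargs that loading would otherwise have to reverse-engineer from tensor shapes.
# `target_modules` is a set here, which the save path above converts to a list before JSON-encoding.
adapter_metadata = {
    "r": lora_config.r,
    "lora_alpha": lora_config.lora_alpha,
    "target_modules": lora_config.target_modules,
}

# `transformer` is assumed to be a diffusers model that already has this LoRA adapter injected.
transformer.save_lora_adapter("wan-lora-adapter", lora_adapter_metadata=adapter_metadata)
```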
12 changes: 11 additions & 1 deletion src/diffusers/utils/peft_utils.py
@@ -147,7 +147,17 @@ def unscale_lora_layers(model, weight: Optional[float] = None):
module.set_scale(adapter_name, 1.0)


def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True):
def get_peft_kwargs(
rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, prefix=None, load_with_metadata=False
):
if load_with_metadata:
if "lora_metadata" not in peft_state_dict:
raise ValueError("Couldn't find '_metadata' key in the `peft_state_dict`.")
metadata = peft_state_dict["lora_metadata"]
if prefix is not None:
metadata = {k.replace(f"{prefix}.", ""): v for k, v in metadata.items()}
return metadata

rank_pattern = {}
alpha_pattern = {}
r = lora_alpha = list(rank_dict.values())[0]
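With this change applied, a metadata dict stored under `"lora_metadata"` short-circuits the usual rank/alpha inference in `get_peft_kwargs`. A small illustration with made-up values:

```py
from diffusers.utils.peft_utils import get_peft_kwargs

peft_state_dict = {
    # In practice the LoRA tensors live alongside this entry; only the metadata matters here.
    "lora_metadata": {"r": 8, "lora_alpha": 16, "target_modules": ["to_q", "to_k"]},
}

lora_config_kwargs = get_peft_kwargs(
    rank_dict={}, network_alpha_dict=None, peft_state_dict=peft_state_dict, load_with_metadata=True
)
print(lora_config_kwargs)  # {'r': 8, 'lora_alpha': 16, 'target_modules': ['to_q', 'to_k']}
```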