
[LoRA] parse metadata from LoRA and save metadata #11324

Open

sayakpaul wants to merge 44 commits into main from metadata-lora

Showing changes from 5 of 44 commits.

Commits
5139de1  feat: parse metadata from lora state dicts. (sayakpaul, Apr 15, 2025)
d8a305e  tests (sayakpaul, Apr 15, 2025)
ba546bc  fix tests (sayakpaul, Apr 15, 2025)
25f826e  Merge branch 'main' into metadata-lora (sayakpaul, Apr 15, 2025)
61d3708  key renaming (sayakpaul, Apr 15, 2025)
e98fb84  fix (sayakpaul, Apr 15, 2025)
2f1c326  Merge branch 'main' into metadata-lora (sayakpaul, Apr 15, 2025)
d390d4d  Merge branch 'main' into metadata-lora (sayakpaul, Apr 16, 2025)
201bd7b  resolve conflicts. (sayakpaul, Apr 21, 2025)
a771982  Merge branch 'main' into metadata-lora (sayakpaul, May 2, 2025)
42bb6bc  smol update (sayakpaul, May 2, 2025)
7ec4ef4  smol updates (sayakpaul, May 2, 2025)
7f59ca0  load metadata. (sayakpaul, May 2, 2025)
ded2fd6  automatically save metadata in save_lora_adapter. (sayakpaul, May 2, 2025)
d5b3037  propagate changes. (sayakpaul, May 2, 2025)
bee9e00  changes (sayakpaul, May 2, 2025)
a9f5088  add test to models too. (sayakpaul, May 2, 2025)
7716303  tigher tests. (sayakpaul, May 2, 2025)
0ac1a39  updates (sayakpaul, May 2, 2025)
4b51bbf  fixes (sayakpaul, May 2, 2025)
e2ca95a  rename tests. (sayakpaul, May 2, 2025)
7a2ba69  Merge branch 'main' into metadata-lora (sayakpaul, May 3, 2025)
e0449c2  sorted. (sayakpaul, May 3, 2025)
918aef1  Update src/diffusers/loaders/lora_base.py (sayakpaul, May 3, 2025)
4bd325c  review suggestions. (sayakpaul, May 3, 2025)
e8bec86  removeprefix. (sayakpaul, May 5, 2025)
aa5cb3c  Merge branch 'main' into metadata-lora (sayakpaul, May 5, 2025)
7bb6c9f  propagate changes. (sayakpaul, May 8, 2025)
116306e  fix-copies (sayakpaul, May 8, 2025)
ae0580a  sd (sayakpaul, May 8, 2025)
f6fde6f  docs. (sayakpaul, May 8, 2025)
cbb4071  resolve conflicts. (sayakpaul, May 8, 2025)
87417b2  fixes (sayakpaul, May 8, 2025)
55a41bf  Merge branch 'main' into metadata-lora (sayakpaul, May 9, 2025)
16dba2d  get review ready. (sayakpaul, May 9, 2025)
023c0fe  Merge branch 'main' into metadata-lora (sayakpaul, May 9, 2025)
67bceda  one more test to catch error. (sayakpaul, May 9, 2025)
83a8995  merge conflicts. (sayakpaul, May 9, 2025)
d336486  Merge branch 'main' into metadata-lora (sayakpaul, May 11, 2025)
4f2d90c  Merge branch 'main' into metadata-lora (sayakpaul, May 12, 2025)
42a0d1c  Merge branch 'main' into metadata-lora (sayakpaul, May 15, 2025)
9c32dc2  Merge branch 'main' into metadata-lora (linoytsaban, May 18, 2025)
5d578c9  Merge branch 'main' into metadata-lora (sayakpaul, May 19, 2025)
1c37845  Merge branch 'main' into metadata-lora (linoytsaban, May 20, 2025)

28 changes: 27 additions & 1 deletion src/diffusers/loaders/lora_base.py
@@ -14,6 +14,7 @@

import copy
import inspect
import json
import os
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
@@ -45,6 +46,7 @@
set_adapter_layers,
set_weights_and_activate_adapters,
)
from ..utils.state_dict_utils import _maybe_populate_state_dict_with_metadata


if is_transformers_available():
@@ -206,6 +208,7 @@ def _fetch_state_dict(
subfolder,
user_agent,
allow_pickle,
load_with_metadata=False,
):
model_file = None
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
@@ -223,6 +226,9 @@
file_extension=".safetensors",
local_files_only=local_files_only,
)
if load_with_metadata and not weight_name.endswith(".safetensors"):
raise ValueError("`load_with_metadata` cannot be set to True when not using safetensors.")

model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
@@ -236,6 +242,11 @@
user_agent=user_agent,
)
state_dict = safetensors.torch.load_file(model_file, device="cpu")
if load_with_metadata:
state_dict = _maybe_populate_state_dict_with_metadata(
state_dict, model_file, metadata_key="lora_adapter_metadata"
)

except (IOError, safetensors.SafetensorError) as e:
if not allow_pickle:
raise e
@@ -882,16 +893,31 @@ def write_lora_layers(
weight_name: str,
save_function: Callable,
safe_serialization: bool,
lora_adapter_metadata: dict = None,
):
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return

if lora_adapter_metadata is not None and not safe_serialization:
raise ValueError("`lora_adapter_metadata` cannot be specified when not using `safe_serialization`.")
if lora_adapter_metadata is not None and not isinstance(lora_adapter_metadata, dict):
raise ValueError("`lora_adapter_metadata` must be of type `dict`.")

if save_function is None:
if safe_serialization:

def save_function(weights, filename):
return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
# We need to be able to serialize the NoneTypes too, otherwise we run into
# 'NoneType' object cannot be converted to 'PyString'
metadata = {"format": "pt"}
if lora_adapter_metadata is not None:
for key, value in lora_adapter_metadata.items():
if isinstance(value, set):
lora_adapter_metadata[key] = list(value)
metadata["lora_adapter_metadata"] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)

return safetensors.torch.save_file(weights, filename, metadata=metadata)
Review comment (Member) on lines +920 to +922: In light of the previous comment, I see that the metadata is being added to the state dict as well as the safetensors metadata. A little confused as to what the intention is with adding the metadata to the state dict.

else:
save_function = torch.save
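For context on the serialization change above, here is a minimal sketch (not part of the PR) of how a JSON-encoded `lora_adapter_metadata` entry round-trips through the safetensors header; the file name, tensor names, and metadata values are illustrative only:

```python
import json

import safetensors.torch
import torch
from safetensors import safe_open

# Mirror what the patched save_function does: safetensors metadata values must
# be strings, so the adapter metadata is JSON-encoded under a dedicated key.
weights = {"transformer.blocks.0.attn.to_q.lora_A.weight": torch.zeros(8, 64)}
adapter_metadata = {"r": 8, "lora_alpha": 8, "target_modules": ["to_q", "to_k"]}
metadata = {
    "format": "pt",
    "lora_adapter_metadata": json.dumps(adapter_metadata, indent=2, sort_keys=True),
}
safetensors.torch.save_file(weights, "pytorch_lora_weights.safetensors", metadata=metadata)

# The metadata lives in the file header, not in the tensors, so it can be read
# back without loading any weights.
with safe_open("pytorch_lora_weights.safetensors", framework="pt", device="cpu") as f:
    stored = f.metadata()
print(json.loads(stored["lora_adapter_metadata"]))
```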
23 changes: 22 additions & 1 deletion src/diffusers/loaders/lora_pipeline.py
@@ -5067,6 +5067,7 @@ def lora_state_dict(
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

load_with_metadata (`bool`, *optional*, defaults to `False`):
Whether to also load the LoRA adapter metadata stored in the safetensors file, if present.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
@@ -5101,6 +5102,7 @@
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
load_with_metadata = kwargs.pop("load_with_metadata", False)

allow_pickle = False
if use_safetensors is None:
@@ -5125,6 +5127,7 @@
subfolder=subfolder,
user_agent=user_agent,
allow_pickle=allow_pickle,
load_with_metadata=load_with_metadata,
)
if any(k.startswith("diffusion_model.") for k in state_dict):
state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict)
@@ -5192,6 +5195,7 @@ def load_lora_weights(
raise ValueError("PEFT backend is required for this method.")

low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
load_with_metadata = kwargs.get("load_with_metadata", False)
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
Expand All @@ -5218,12 +5222,20 @@ def load_lora_weights(
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
load_with_metadata=load_with_metadata,
)

@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel
def load_lora_into_transformer(
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
cls,
state_dict,
transformer,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
load_with_metadata: bool = False,
):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -5264,6 +5276,7 @@ def load_lora_into_transformer(
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
limitations to this technique, which are documented here:
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
load_with_metadata (`bool`, *optional*, defaults to `False`):
Whether to also load the LoRA adapter metadata stored in the safetensors file, if present.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
@@ -5279,6 +5292,7 @@
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
load_with_metadata=load_with_metadata,
)

@classmethod
@@ -5291,6 +5305,7 @@ def save_lora_weights(
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
Save the LoRA parameters corresponding to the UNet and text encoder.
@@ -5310,15 +5325,20 @@
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
transformer_lora_adapter_metadata (`dict`, *optional*):
LoRA adapter metadata associated with the transformer, serialized with the state dict.
"""
state_dict = {}
lora_adapter_metadata = {}

if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.")

if transformer_lora_layers:
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))

if transformer_lora_adapter_metadata:
lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name))

# Save the model
cls.write_lora_layers(
state_dict=state_dict,
@@ -5327,6 +5347,7 @@
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
)

# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
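Taken together, the new `transformer_lora_adapter_metadata` and `load_with_metadata` arguments let adapter metadata round-trip through the pipeline-level save/load APIs. Below is a rough sketch of that flow, assuming this hunk belongs to `WanLoraLoaderMixin`; the directory, tensor names, and metadata values are illustrative:

```python
import torch
from diffusers.loaders.lora_pipeline import WanLoraLoaderMixin

# Saving: the metadata dict is packed under the transformer prefix and embedded
# in the safetensors header (illustrative layers; normally a trained adapter).
transformer_lora_layers = {"blocks.0.attn1.to_q.lora_A.weight": torch.zeros(8, 64)}
WanLoraLoaderMixin.save_lora_weights(
    save_directory="my-wan-lora",
    transformer_lora_layers=transformer_lora_layers,
    transformer_lora_adapter_metadata={"r": 8, "lora_alpha": 8},
)

# Loading: with load_with_metadata=True the returned state dict additionally
# carries the parsed metadata under the "lora_metadata" key.
state_dict = WanLoraLoaderMixin.lora_state_dict(
    "my-wan-lora",
    weight_name="pytorch_lora_weights.safetensors",
    load_with_metadata=True,
)
print(state_dict["lora_metadata"])
```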
51 changes: 43 additions & 8 deletions src/diffusers/loaders/peft.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import json
import os
from functools import partial
from pathlib import Path
@@ -114,7 +115,12 @@ def _optionally_disable_offloading(cls, _pipeline):
return _func_optionally_disable_offloading(_pipeline=_pipeline)

def load_lora_adapter(
self, pretrained_model_name_or_path_or_dict, prefix="transformer", hotswap: bool = False, **kwargs
self,
pretrained_model_name_or_path_or_dict,
prefix="transformer",
hotswap: bool = False,
load_with_metadata: bool = False,
**kwargs,
):
r"""
Loads a LoRA adapter into the underlying model.
@@ -182,6 +188,8 @@ def load_lora_adapter(
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
limitations to this technique, which are documented here:
https://huggingface.co/docs/peft/main/en/package_reference/hotswap

load_with_metadata (`bool`, *optional*, defaults to `False`):
Whether to also load the LoRA adapter metadata stored in the safetensors file, if present.
"""
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
from peft.tuners.tuners_utils import BaseTunerLayer
@@ -224,13 +232,18 @@ def load_lora_adapter(
subfolder=subfolder,
user_agent=user_agent,
allow_pickle=allow_pickle,
load_with_metadata=load_with_metadata,
)
if network_alphas is not None and prefix is None:
raise ValueError("`network_alphas` cannot be None when `prefix` is None.")

if prefix is not None:
metadata = state_dict.pop("lora_metadata", None)
state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}

if metadata is not None:
state_dict["lora_metadata"] = metadata

if len(state_dict) > 0:
if adapter_name in getattr(self, "peft_config", {}) and not hotswap:
raise ValueError(
@@ -262,7 +275,13 @@ def load_lora_adapter(
alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}

lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
lora_config_kwargs = get_peft_kwargs(
rank,
network_alpha_dict=network_alphas,
peft_state_dict=state_dict,
load_with_metadata=load_with_metadata,
prefix=prefix,
)
_maybe_raise_error_for_ambiguity(lora_config_kwargs)

if "use_dora" in lora_config_kwargs:
@@ -285,7 +304,11 @@ def load_lora_adapter(
if is_peft_version("<=", "0.13.2"):
lora_config_kwargs.pop("lora_bias")

lora_config = LoraConfig(**lora_config_kwargs)
try:
lora_config = LoraConfig(**lora_config_kwargs)
except TypeError as e:
logger.error(f"`LoraConfig` class could not be instantiated with the following trace: {e}.")

# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(self)
@@ -429,6 +452,7 @@ def save_lora_adapter(
upcast_before_saving: bool = False,
safe_serialization: bool = True,
weight_name: Optional[str] = None,
lora_adapter_metadata: Optional[dict] = None,
):
"""
Save the LoRA parameters corresponding to the underlying model.
@@ -440,18 +464,20 @@ def save_lora_adapter(
underlying model has multiple adapters loaded.
upcast_before_saving (`bool`, defaults to `False`):
Whether to cast the underlying model to `torch.float32` before serialization.
save_function (`Callable`):
The function to use to save the state dictionary. Useful during distributed training when you need to
replace `torch.save` with another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`.
Review comment (Member Author) on lines -447 to -450: Because it's not used.

safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
weight_name: (`str`, *optional*, defaults to `None`): Name of the file to serialize the state dict with.
lora_adapter_metadata (`dict`, *optional*):
LoRA adapter metadata to be serialized with the state dict. It is embedded in the safetensors header when `safe_serialization` is enabled.
"""
from peft.utils import get_peft_model_state_dict

from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE

if lora_adapter_metadata is not None and not safe_serialization:
raise ValueError("`lora_adapter_metadata` cannot be specified when not using `safe_serialization`.")
if lora_adapter_metadata is not None and not isinstance(lora_adapter_metadata, dict):
raise ValueError("`lora_adapter_metadata` must be of type `dict`.")

if adapter_name is None:
adapter_name = get_adapter_name(self)

Expand All @@ -467,7 +493,16 @@ def save_lora_adapter(
if safe_serialization:

def save_function(weights, filename):
return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
# We need to be able to serialize the NoneTypes too, otherwise we run into
# 'NoneType' object cannot be converted to 'PyString'
metadata = {"format": "pt"}
if lora_adapter_metadata is not None:
for key, value in lora_adapter_metadata.items():
if isinstance(value, set):
lora_adapter_metadata[key] = list(value)
metadata["lora_adapter_metadata"] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)

return safetensors.torch.save_file(weights, filename, metadata=metadata)

else:
save_function = torch.save
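At the model level, the same metadata path is exposed through `save_lora_adapter` and `load_lora_adapter`. A minimal sketch of the intended round trip, assuming `transformer` is a diffusers model (any `PeftAdapterMixin` subclass); the adapter settings and directory are illustrative, and files written by `save_lora_adapter` carry unprefixed keys, hence `prefix=None` on load:

```python
from peft import LoraConfig

# Assumes `transformer` is a diffusers transformer/UNet instance; attach a small
# LoRA adapter so there is something to save.
transformer.add_adapter(LoraConfig(r=4, lora_alpha=4, target_modules=["to_q", "to_k"]))

# The metadata dict ends up JSON-encoded in the safetensors header.
transformer.save_lora_adapter(
    save_directory="my-lora-adapter",
    lora_adapter_metadata={"r": 4, "lora_alpha": 4, "target_modules": ["to_q", "to_k"]},
)

# With load_with_metadata=True, get_peft_kwargs builds the LoraConfig from the
# stored metadata instead of re-inferring ranks/alphas from state dict keys.
transformer.load_lora_adapter(
    "my-lora-adapter",
    prefix=None,
    load_with_metadata=True,
)
```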
12 changes: 11 additions & 1 deletion src/diffusers/utils/peft_utils.py
@@ -147,7 +147,17 @@ def unscale_lora_layers(model, weight: Optional[float] = None):
module.set_scale(adapter_name, 1.0)


def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True):
def get_peft_kwargs(
rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, prefix=None, load_with_metadata=False
):
if load_with_metadata:
if "lora_metadata" not in peft_state_dict:
raise ValueError("Couldn't find '_metadata' key in the `peft_state_dict`.")
metadata = peft_state_dict["lora_metadata"]
if prefix is not None:
metadata = {k.replace(f"{prefix}.", ""): v for k, v in metadata.items()}
return metadata

rank_pattern = {}
alpha_pattern = {}
r = lora_alpha = list(rank_dict.values())[0]
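For illustration, this is roughly what the new metadata branch returns when the stored metadata was packed under a model prefix (a standalone sketch with made-up values, not part of the PR):

```python
# The metadata entry sits alongside the LoRA tensors in the loaded state dict.
peft_state_dict = {
    "lora_metadata": {
        "transformer.r": 8,
        "transformer.lora_alpha": 8,
        "transformer.target_modules": ["to_q", "to_k"],
    },
    # ... the "lora_A"/"lora_B" tensors are omitted here ...
}

prefix = "transformer"
metadata = peft_state_dict["lora_metadata"]
metadata = {k.replace(f"{prefix}.", ""): v for k, v in metadata.items()}
print(metadata)  # {'r': 8, 'lora_alpha': 8, 'target_modules': ['to_q', 'to_k']}
# These are exactly the kwargs later fed into peft.LoraConfig(**metadata).
```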
17 changes: 17 additions & 0 deletions src/diffusers/utils/state_dict_utils.py
@@ -16,6 +16,7 @@
"""

import enum
import json

from .import_utils import is_torch_available
from .logging import get_logger
@@ -347,3 +348,19 @@ def state_dict_all_zero(state_dict, filter_str=None):
state_dict = {k: v for k, v in state_dict.items() if any(f in k for f in filter_str)}

return all(torch.all(param == 0).item() for param in state_dict.values())


def _maybe_populate_state_dict_with_metadata(state_dict, model_file, metadata_key):
import safetensors.torch

with safetensors.torch.safe_open(model_file, framework="pt", device="cpu") as f:
if hasattr(f, "metadata"):
Review comment (Member): This is a safetensors feature, no? So why do we have to check with hasattr if the method exists?

Reply (Member Author): Indeed. Will remove.

metadata = f.metadata()
if metadata is not None:
metadata_keys = list(metadata.keys())
if not (len(metadata_keys) == 1 and metadata_keys[0] == "format"):
peft_metadata = {k: v for k, v in metadata.items() if k != "format"}
state_dict["lora_metadata"] = json.loads(peft_metadata[metadata_key])
else:
raise ValueError("Metadata couldn't be parsed from the safetensors file.")
return state_dict
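
A quick usage sketch of the helper above; the file path is illustrative, and since the function is private its import location may change:

```python
import safetensors.torch

from diffusers.utils.state_dict_utils import _maybe_populate_state_dict_with_metadata

model_file = "pytorch_lora_weights.safetensors"  # illustrative local file
state_dict = safetensors.torch.load_file(model_file, device="cpu")
state_dict = _maybe_populate_state_dict_with_metadata(
    state_dict, model_file, metadata_key="lora_adapter_metadata"
)
# If the header carried a "lora_adapter_metadata" entry, the parsed dict is now
# available under state_dict["lora_metadata"].
print("lora_metadata" in state_dict)
```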