
[LoRA] parse metadata from LoRA and save metadata #11324


Draft · wants to merge 9 commits into base: main

Conversation

@sayakpaul (Member) commented Apr 15, 2025

What does this PR do?

I know we have revisited this over and over again, but it is becoming increasingly important, so we should treat it as a priority.

@a-r-r-o-w brought this issue to me while we were debugging something in Wan LoRA training. So, I started by just modifying the Wan LoRA loader (eventually, the changes will be propagated to other loaders too). Aryan, could you check if this change fixes the issue we were facing?

Admittedly, the PR can be cleaned and modularized a bit but I wanted to get something up quickly to get feedback on the direction.

TODOs

  • Docs
  • Modularize where possible
  • Propagate to other pipelines

@BenjaminBossan (Member) left a comment

I remember there were debates in the past, but I don't remember what the arguments for and against were. Could you quickly sketch why this is becoming increasingly important?

Implementation-wise, it generally LGTM. I was a bit confused about the lora_adapter_config key vs the lora_metadata key, what is the difference?

@sayakpaul (Member, Author)

I remember there were debates in the past, but I don't remember what the arguments for and against were. Could you quickly sketch why this is becoming increasingly important?

Folks are using different ranks and alphas and finding it to be very effective in practice.

I was a bit confused about the lora_adapter_config key vs the lora_metadata key, what is the difference?

I will rename it to lora_adapter_metadata_key.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@a-r-r-o-w (Member) left a comment

Thanks for reviving this conversation again @sayakpaul!

I see some TODOs marked, so I will take a better look once you let us know the PR is ready, but the changes look good to me and seem sensible from the user end. Maybe it would make sense to allow the user to also specify the metadata key, or for us to automatically detect some common names (are there any that you know of?)

@BenjaminBossan To provide some more context, my discussion with Sayak involved training "Control" LoRAs.

Expand for my conversation with Sayak

Hello curious 🐱 In short:

  • Expand the input projection layer's (patch embedding) input channels. Say we want to train with Canny conditioning on a 64-input channel model (Flux, for example), we'd expand the input projection linear to 128 channels with zeroed weights.
  • For the input latent to match the expanded linear proj, you concatenate the latents with the latents of the control condition.
  • In the "normal" latent stream (base layer), since we expanded with zeroed out weights, the control condition is effectively not added.
  • In the "lora" latent stream, the noisy latents and control latents are projected to the expected inner dim of the model. The lora latent stream outputs are added to the normal latent stream, effectively adding some amount of conditioning information from the new Canny condition
  • Now, in Flux, the linear projection goes from 64 channels -> 4096 channels. With the new conditioning related expansion, it is 128 channels -> 4096 channels. This is the "normal" latent stream.
  • In the "lora" latent stream, you have 128 channels -> "rank" channels -> 4096 channels. If the rank for the input projection layer is less than 128 channels, you might lose some information from the control conditioning (since the lora layer then acts as a 1-layer autoencoder).
  • I've noticed from a limited set of experiments that training with a low input projection rank (does not matter what the rank for rest of the network is) almost always produces worse results and contains random noise in the generations.
  • Due to this, in finetrainers, I made it so that the input projection layer has the same rank as inner_dim. So, effectively, in the "lora" latent stream, you have 128 channels -> 4096 channels -> 4096 channels. This results in much faster convergence (conclusion made from a very limited set of experiments and it intuitively made sense to me if you compare this low-rank version against a full-rank Control model).
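For concreteness, here is a hedged sketch (not finetrainers' exact code) of the kind of adapter config described above; module names like "x_embedder" are illustrative for a Flux-like transformer:

from peft import LoraConfig

inner_dim = 4096  # hidden size of the transformer

# rank/alpha 64 for most layers, full rank for the expanded input projection so
# the control condition is not bottlenecked by a low-rank lora_A.
config = LoraConfig(
    r=64,
    lora_alpha=64,
    target_modules=["to_q", "to_k", "to_v", "to_out.0", "x_embedder"],
    rank_pattern={"x_embedder": inner_dim},
    alpha_pattern={"x_embedder": inner_dim},
)

With such a config, the "lora" latent stream for the input projection is 128 channels -> 4096 channels -> 4096 channels, while every other layer stays at rank 64.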

Here are the relevant pieces of code for a quick look:

  • Expanding input channels: this
  • Creating the LoRA adapter with custom rank for input projection layer vs rest of the network: this
  • Channel-wise input concatenation with control latents: this
Example results (media omitted): "Convert to impasto style painting", "Canny to Image".

These results come with just 2500 training steps on a rank=64 lora (except rank=4096 for the input projection).

For these control LoRA training settings, every lora_alpha is the same as the rank. So, the lora_alpha for the input projection is 4096, while for the remaining layers it is 64. The size of the LoRA weights is between 300-1000 MB in most cases, which is much smaller than a full control conditioning model.

The problem comes when trying to load such a LoRA with diffusers. Since we only know the ranks, but not the lora_alpha, the diffusers config uses the same lora_alpha as the inferred rank (which is 4096, since the input projection layer is also the first lora layer in the state dict). As you can imagine, setting an alpha of 4096 on all layers (even the ones that originally had rank = lora_alpha = 64) will result in random noise.

This is a more general problem because a lot of loras on CivitAI and from different trainers are trained with different alpha configurations from the rank. So making the assumption that lora_alpha=rank (which is our current default behaviour) is incorrect, and having this metadata information will be really helpful.

To be able to train and run validation in finetrainers, we just serialize the lora adapter config directly into the safetensors metadata. See this and this, and we use a custom function to load the lora weights. Really big kudos to whoever implemented the save_function allowing custom ones to be provided!
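For illustration, here is a minimal sketch of that approach; the metadata key name and the exact serialization are assumptions rather than finetrainers' actual code:

import json

import torch
from peft import LoraConfig
from safetensors.torch import save_file

lora_config = LoraConfig(r=64, lora_alpha=64, target_modules=["to_q", "to_k", "to_v"])

# Toy LoRA state dict standing in for the real trained weights.
lora_state_dict = {
    "transformer.blocks.0.attn.to_q.lora_A.weight": torch.zeros(64, 1536),
    "transformer.blocks.0.attn.to_q.lora_B.weight": torch.zeros(1536, 64),
}

# safetensors metadata must map str -> str; sets (e.g. target_modules) are not
# JSON-serializable, so fall back to lists.
metadata = {"lora_config": json.dumps(lora_config.to_dict(), default=list)}
save_file(lora_state_dict, "pytorch_lora_weights.safetensors", metadata=metadata)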

@sayakpaul (Member, Author)

I see some TODOs marked, so I will take a better look once you let us know the PR is ready, but the changes look good to me and seem sensible from the user end. Maybe it would make sense to allow the user to also specify the metadata key, or for us to automatically detect some common names (are there any that you know of?)

Well, I wanted to check if the changes help solve the issue of finetrainers. If possible could you let me know about that?

Regarding detecting metadata in non-diffusers checkpoints, we already either infer alphas or directly scale the LoRA weights. So, that should already be covered.

@a-r-r-o-w (Member)

Currently, finetrainers exports loras with a metadata attribute lora_config and not lora_adapter_metadata, so I don't have any checkpoint to test with at the moment. I will also need to do a full dummy run (training + validation) to check that all the changes here work as expected - from serializing to loading - so please give me some time. I'm caught up with something else at the moment, so I'll run tests tomorrow and let you know.

@sayakpaul (Member, Author)

SGTM, no rush. Thanks for helping.

@BenjaminBossan (Member)

Thanks for explaining this further @a-r-r-o-w

diffusers config uses the same lora_alpha as the inferred rank (which is 4096 since the input projection layer is the first lora layer in the state dict too)

There could be better heuristics, like using the most common rank, but I imagine changing that now would be backwards incompatible.

This is a more general problem because a lot of loras on CivitAI and from different trainers are trained with different alpha configurations from the rank. So making the assumption that lora_alpha=rank (which is our current default behaviour) is incorrect, and having this metadata information will be really helpful.

Any idea how other libraries/apps deal with this lack of info? I thought there was a bit of a standard to multiply the scale into the weights before saving, so that the assumption scale == 1 holds true at the end, but I'm not knowledgeable enough about common practices in this area.

@sayakpaul (Member, Author)

Any idea how other libraries/apps deal with this lack of info? I thought there was a bit of a standard to multiply the scale into the weights before saving, so that the assumption scale == 1 holds true at the end, but I'm not knowledgeable enough about common practices in this area.

@BenjaminBossan many trainers just embed the alpha values in the state dict instead of maintaining separate configuration metadata. See here as an example:

scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_cross_attn_{o}")
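For reference, a hedged sketch of what such a helper does; the signature and key names are approximated here and this is not the exact diffusers implementation:

import torch

def get_alpha_scales(down_weight: torch.Tensor, key: str, state_dict: dict):
    # kohya-style checkpoints store a scalar "<key>.alpha" tensor next to the
    # LoRA matrices; if it is missing, fall back to alpha == rank (scale 1.0).
    rank = down_weight.shape[0]
    alpha = state_dict.get(f"{key}.alpha", torch.tensor(float(rank))).item()
    scale = alpha / rank
    # Split the scale evenly between the down and up weights so their product
    # carries the full alpha / rank factor.
    scale_down = scale**0.5
    scale_up = scale**0.5
    return scale_down, scale_up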

@sayakpaul (Member, Author)

If there are no objections, I will button up the PR and make it generally available to the other loaders. But LMK if you have anything to share or any concerns about this direction.

@BenjaminBossan @a-r-r-o-w

@BenjaminBossan (Member)

Personally, I don't have any issues, just wondering if anything needs to be discussed (again) with the safetensors folks.

@sayakpaul (Member, Author)

just wondering if anything needs to be discussed (again) with the safetensors folks.

Well, previously there was no problem with them IIRC. It was about time and complexity. But now we have enough evidence to ship it while maintaining the single-file format for LoRA checkpoints.

@a-r-r-o-w (Member)

@sayakpaul I've verified it to work, thanks for working on this. My understanding was that the lora initialization config would also be automatically saved into the exported safetensors as metadata, but it needs to be passed manually to save_lora_weights. If not passed, it leads to the following error (this example uses a different rank/alpha for the patch layer as described above):

RuntimeError: Error(s) in loading state_dict for WanTransformer3DModel:
        size mismatch for patch_embedding.lora_A.default_0.weight: copying a param with shape torch.Size([1536, 32, 1, 2, 2]) from checkpoint, the shape in current model is torch.Size([128, 32, 1, 2, 2]).
        size mismatch for patch_embedding.lora_B.default_0.weight: copying a param with shape torch.Size([1536, 1536, 1, 1, 1]) from checkpoint, the shape in current model is torch.Size([1536, 128, 1, 1, 1]).

I think it might be good if we serialize this info automatically in save_lora_weights, if you think there's an easy way to do it.

@sayakpaul (Member, Author)

Thanks for trying! Well, we're serializing the entire LoraConfig() when save_lora_weights() is called with a non-None transformer_lora_adapter_metadata:

transformer_lora_adapter_metadata: Optional[dict] = None,
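For context, a hedged usage sketch of how a training script might pass this; transformer is assumed to be the peft-injected model and lora_config the LoraConfig used to add the adapter, and whether to_dict() or the raw config dict is expected is an assumption:

from diffusers import WanPipeline
from peft import get_peft_model_state_dict

transformer_lora_layers = get_peft_model_state_dict(transformer)
WanPipeline.save_lora_weights(
    save_directory="wan-control-lora",
    transformer_lora_layers=transformer_lora_layers,
    transformer_lora_adapter_metadata=lora_config.to_dict(),
)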

The test below also confirms that:

def test_adapter_metadata_is_loaded_correctly(self):

If you could provide a minimal reproducer for the issue I would be happy to look into it on priority.

@a-r-r-o-w (Member)

Yes, what you mentioned works as expected: we need to pass transformer_lora_adapter_metadata ourselves with the lora config when calling save_lora_weights. This is already quite convenient. I was asking if it would be possible/make sense to do this behind-the-scenes without needing to pass the metadata explicitly (since the model already "knows" what lora layers it has and what it was initialized with). Or is it not really ideal?

@sayakpaul (Member, Author)

I was asking if it would be possible/make sense to do this behind-the-scenes without needing to pass the metadata explicitly (since the model already "knows" what lora layers it has and what it was initialized with). Or is it not really ideal?

Ah I see. I might have an idea. Will update soon. Thanks again for trying it out quickly.
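One possible direction, sketched here as an assumption rather than what the PR ends up doing: a peft-injected model keeps its adapter configs around, so the metadata could be derived from the model itself.

# Hypothetical helper: pull the LoraConfig back out of the injected model.
def infer_adapter_metadata(model, adapter_name: str = "default") -> dict:
    lora_config = model.peft_config[adapter_name]  # set when the adapter was injected
    return lora_config.to_dict()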

@bghira (Contributor) commented Apr 17, 2025

why can't we add alpha to the state dict like kohya trainer does?

@sayakpaul (Member, Author)

why can't we add alpha to the state dict like kohya trainer does?

The code path for the diffusers state dict would have to be modified a lot more, and the current solution proposed in this PR is simpler than those changes.

@bghira (Contributor) commented Apr 18, 2025

so because "its harder" we will end up with a solution that requires other tools like comfyUI to implement special support to read these attributes and set them?

i expected we would be adding compatibility points and not adding more tech debt. please reconsider just adding alpha to the state dict or leave the work to the community if it is too hard?

this kinda thing really takes an eternity to be implemented by tools like Swarm and Forge and ComfyUI while the alpha attribute in dict key would Just Work even with AUTOMATIC1111. please tell me you see the value in this..

@bghira (Contributor) commented Apr 18, 2025

also it wasn't arrow who brought this up first, it was me, because simpletuner has the option to set rank differently from alpha, which peft has always supported but diffusers does not. setting alpha to 1 always would allow learning rates to stay the same across every rank. the initial request was just to write the alpha into the state dict key.

this proposal is heavy handed and requires too much effort from everybody.

@sayakpaul (Member, Author)

so because "its harder" we will end up with a solution that requires other tools like comfyUI to implement special support to read these attributes and set them?

i expected we would be adding compatibility points and not adding more tech debt. please reconsider just adding alpha to the state dict or leave the work to the community if it is too hard?

this kinda thing really takes an eternity to be implemented by tools like Swarm and Forge and ComfyUI while the alpha attribute in dict key would Just Work even with AUTOMATIC1111. please tell me you see the value in this..

We value other formats and hence we support most of them directly within the library.

At the same time, I also think it makes sense to consider what works for our own libraries and to think about long-term maintainability. With the current proposal, we can serialize additional useful things like rank_pattern, alpha_pattern, etc. (supported by LoraConfig), which are becoming important to tweak from what I understand.

I am not sure that for diffusers-format LoRAs, it's sufficient to just write alpha keys and the other community tools would just work correctly with them.

also it wasnt arrow who brought this up first, it was me, because simpletuner has the option to set rank differently to alpha, which peft has always supported but diffusers does not.

Corrected my issue description but please note what my description starts with:

I know we have revisited this over and over again

setting alpha to 1

We set alpha=rank in our official scripts currently and also for diffusers-format state dicts.


I would like to take a second opinion from @DN6, especially for maintainability. To summarize, the issue is that we just serialize the LoRA state dict in diffusers, but no other info from LoraConfig is provided in the serialized state dict when training. So, if someone trains with, say, an alpha that doesn't equal the rank, we won't be able to infer it from the state dict, as there's no info in the state dict about it. This will make inference wrong.

This PR proposes to serialize the LoraConfig as a part of the metadata of the state dict, as safetensors allows for that. This will allow us to parse the entire LoraConfig from the state dict easily. If it were only for alpha values, we could have simply serialized them as a part of the state dict (not metadata), but with the current PR proposal we have the flexibility to leverage the full spectrum of LoraConfig (additional things like alpha_pattern, rank_pattern, etc.).
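For illustration, this is the kind of LoraConfig information that would travel in the safetensors metadata under this proposal (keys and values below are hypothetical):

serialized_config = {
    "r": 64,
    "lora_alpha": 64,
    "target_modules": ["to_q", "to_k", "to_v", "to_out.0", "patch_embedding"],
    "rank_pattern": {"patch_embedding": 4096},
    "alpha_pattern": {"patch_embedding": 4096},
}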

@sayakpaul (Member, Author)

We can still just serialize alpha keys in the diffusers-format LoRA state dicts (in retrospect, I think it will actually be easier to do), considering what's passed to the LoraConfig, and then make changes to the loader as needed. But before that I just want to be sure this is the design we want to have. So, I will let @DN6 also comment here.

@bghira (Contributor) commented Apr 18, 2025

i agree the config is needed for other more intricate patterns, and whatever Diffusers does there could even become the standard depending on how @kohya-ss and @comfyanonymous feel about it. i'd rather something "common" happen (cc @Nerogar), but it is a hard path to go down to try and get everybody on the same page before things start happening/becoming committed to.

@Nerogar (Contributor) commented Apr 18, 2025

I feel like there is some background information I'm missing, I haven't used the diffusers LoRA implementation yet, so don't know about the state dict and metadata format.

That said, this sounds like you are trying to create yet another format where the alpha is stored in the model metadata instead of the state dict. I've never seen a lora where this was done. All files I've seen so far store the alpha in the state dict with the suffix .alpha.

Over the last few months I was talking mostly to kohya about a more standardized LoRA format. We've now agreed on something similar to the currently supported format. Just with a few small adjustments. I've already written conversion code for this format, and you can find example files here.

If you are now changing the format for your LoRA output, I strongly suggest following that format instead of re-inventing something well established again, just because it's easier to implement.

@sayakpaul (Member, Author)

For the record, the structure (keys, for example) of the state dict won't change. It will just be populated with metadata.

@kohya-ss commented Apr 18, 2025

Thanks for working on this.

As Nerogar wrote, standardization is progressing. I expect ComfyUI to support it too.

In addition, I think the de-facto compatibility target today is ComfyUI. Therefore, if diffusers adopts metadata and the keys are not compatible with the existing (or the new standard) format, I recommend shipping bidirectional converters to/from the ComfyUI-supported format at the same time. That keeps the barrier for ordinary users close to zero.
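A hedged sketch of what such a converter could look like (the metadata key and key suffixes are illustrative assumptions, not an agreed format): it materializes per-layer alphas from the metadata as ".alpha" tensors so tools expecting the existing convention can read the file.

import json

import torch
from safetensors import safe_open
from safetensors.torch import save_file

def add_alpha_keys(src_path: str, dst_path: str, metadata_key: str = "lora_metadata"):
    with safe_open(src_path, framework="pt") as f:
        config = json.loads(f.metadata()[metadata_key])
        state_dict = {key: f.get_tensor(key) for key in f.keys()}

    default_alpha = float(config["lora_alpha"])
    alpha_pattern = config.get("alpha_pattern") or {}

    for key in list(state_dict):
        if key.endswith("lora_A.weight") or key.endswith("lora_down.weight"):
            module = key.rsplit(".", 2)[0]
            alpha = next((a for name, a in alpha_pattern.items() if name in module), default_alpha)
            state_dict[f"{module}.alpha"] = torch.tensor(float(alpha))

    save_file(state_dict, dst_path)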

@a-r-r-o-w (Member)

To provide more context, this is not a new format and no existing non-diffusers code for lora loading needs to be changed. All this PR does is serialize a simple string dict into the safetensors file format (docs). The state dict formats remain unaffected.

This may actually help improve loading diffusers implementation state dicts more easily because everything needed for loading the peft-related parts becomes part of the file:

  • rank, alpha
  • rank_pattern (for any layers that have a different rank than rank)
  • alpha_pattern (for any layers that have a different alpha than alpha)
  • info like weight init and other configs for bookkeeping (basically what's documented here)

Knowing this info and using the native model implementations, the lora loading code becomes:

import json

from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
from safetensors import safe_open

with safe_open(model_file, framework="pt") as f:
    metadata = json.loads(f.metadata()["lora_metadata"])
    state_dict = {key: f.get_tensor(key) for key in f.keys()}

config = LoraConfig(**metadata)
adapter_name = "my-awesome-lora"
inject_adapter_in_model(config, model, adapter_name=adapter_name, low_cpu_mem_usage=True)
set_peft_model_state_dict(model, state_dict, adapter_name=adapter_name, ignore_mismatched_sizes=False, low_cpu_mem_usage=True)

No changes are needed in any existing code that supports loading diffusers loras, since the state dict remains completely unaltered. However, if another line or two are added to read the metadata and create the peft-init config based on that, the exact ranks/alphas used in training will be used.

@bghira (Contributor) commented Apr 18, 2025

the rank can be different from the parameter shape?

@Nerogar (Contributor) commented Apr 18, 2025

rank_pattern/alpha_pattern might be useful for the diffusers library itself. But I don't know any other implementation that expects the alpha value in that format. If you want to be more compatible with other implementations the correct way of storing that information is a tensor in the state dict for every lora layer. If that key is missing, there are probably defaults that are used. But they are implementation specific. For example, OneTrainer always assumes alpha=1 as a default value.

the rank can be different from the parameter shape?

That's also something I don't understand. The rank is already defined by the shape of the down/up tensors. There's no need to store that information separately.

@bghira (Contributor) commented Apr 18, 2025

OneTrainer always assumes alpha=1 as a default value.

can't do that with Diffusers models. they are alpha=rank.

@a-r-r-o-w (Member) commented Apr 18, 2025

the rank can be different from the parameter shape?

Not sure what you mean. What I mean is that different layers can have different ranks and alphas (see example below). We will now store this information as part of the safetensors metadata (which one may choose to use or not).

Supporting existing compatibility with storing the alphas in the state dict is a separate issue and unrelated to what's happening here (we can look into exporting loras that way if there's an ask)

Example model:

proj_in: in_features=4, out_features=64
layer_1: in_features=64, out_features=64
proj_out: in_features=64, out_features=4

Now, let's assume we want a lora on all layers with rank=X, alpha=K, rank_pattern={"layer_1": Y}, alpha_pattern={"layer_1": Z}

The equivalent peft init config is:

# You can have arbitrary values for K, X, Y, Z here
config = LoraConfig(
    target_modules=["proj_in", "layer_1", "proj_out"],
    r=X,
    lora_alpha=K,
    rank_pattern={"layer_1": Y},
    alpha_pattern={"layer_1": Z},
)
proj_in: lora_A (in_features=4, out_features=X), alpha=K, lora_B (in_features=X, out_features=64)
layer_1: lora_A (in_features=64, out_features=Y), alpha=Z, lora_B(in_features=Y, out_features=64)
proj_out: lora_A (in_features=64, out_features=X), alpha=K, lora_B (in_features=X, out_features=4)

The extra metadata that will be stored in the safetensors metadata (emphasizing again that it's not the state dict, so no existing code needs to be modified) is the serialized config

Does this make sense?
