[LoRA] parse metadata from LoRA and save metadata #11324
Conversation
I remember there were debates in the past, but I don't remember what the arguments for and against were. Could you quickly sketch why this is becoming increasingly important?
Implementation-wise, it generally LGTM. I was a bit confused about the `lora_adapter_config` key vs. the `lora_metadata` key: what is the difference?
Folks are using different ranks and alphas and finding it to be very effective in practice.
I will rename it to ...
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks for reviving this conversation again @sayakpaul!
I see some TODOs marked, so I'll take a better look once you let us know the PR is ready, but the changes look good to me and seem sensible from the user end. Maybe it would make sense to allow the user to also specify the metadata key, or for us to automatically detect some common names (are there any that you know of?)
@BenjaminBossan To provide some more context, my discussion with Sayak involved training "Control" LoRAs.
Hello curious 🐱 In short:
- Expand the input projection layer's (patch embedding) input channels. Say we want to train with Canny conditioning on a 64-input-channel model (Flux, for example); we'd expand the input projection linear layer to 128 input channels with zeroed weights.
- For the input latent to match the expanded linear proj, you concatenate the latents with the latents of the control condition.
- In the "normal" latent stream (base layer), since we expanded with zeroed out weights, the control condition is effectively not added.
- In the "lora" latent stream, the noisy latents and control latents are projected to the expected inner dim of the model. The lora latent stream outputs are added to the normal latent stream, effectively adding some amount of conditioning information from the new Canny condition
- Now, in Flux, the linear projection goes from `64 channels -> 4096 channels`. With the new conditioning-related expansion, it is `128 channels -> 4096 channels`. This is the "normal" latent stream.
- In the "lora" latent stream, you have `128 channels -> "rank" channels -> 4096 channels`. If the rank for the input projection layer is less than 128, you might lose some information from the control conditioning (since the lora layer then acts as a 1-layer autoencoder).
- I've noticed from a limited set of experiments that training with a low input projection rank (no matter what the rank for the rest of the network is) almost always produces worse results and contains random noise in the generations.
- Due to this, in finetrainers, I made it so that the input projection layer has the same rank as `inner_dim`. So, effectively, in the "lora" latent stream, you have `128 channels -> 4096 channels -> 4096 channels`. This results in much faster convergence (a conclusion drawn from a very limited set of experiments; it intuitively made sense to me when comparing this low-rank version against a full-rank Control model).
Here are the relevant pieces of code for a quick look:
- Expanding input channels: this
- Creating the LoRA adapter with custom rank for input projection layer vs rest of the network: this
- Channel-wise input concatenation with control latents: this
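For a self-contained illustration of the idea, here is a minimal, hypothetical sketch (layer shapes follow the Flux example above; the function and variable names are illustrative, not finetrainers' actual code):

```python
import torch
import torch.nn as nn

def expand_input_projection(proj: nn.Linear, extra_in_channels: int) -> nn.Linear:
    # Widen the input projection; the new input columns are zero-initialized so that,
    # in the base ("normal") stream, the control condition contributes nothing.
    expanded = nn.Linear(
        proj.in_features + extra_in_channels, proj.out_features, bias=proj.bias is not None
    )
    with torch.no_grad():
        expanded.weight.zero_()
        expanded.weight[:, : proj.in_features].copy_(proj.weight)
        if proj.bias is not None:
            expanded.bias.copy_(proj.bias)
    return expanded

# Flux-like patch embedding: 64 -> 4096, expanded to 128 -> 4096.
proj = expand_input_projection(nn.Linear(64, 4096), extra_in_channels=64)

# Channel-wise concatenation of the noisy latents with the control (e.g. Canny) latents.
noisy_latents = torch.randn(1, 1024, 64)    # (batch, tokens, channels)
control_latents = torch.randn(1, 1024, 64)
hidden_states = proj(torch.cat([noisy_latents, control_latents], dim=-1))
```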
Example results (images omitted): "Convert to impasto style painting" and "Canny to Image".
These results come with just 2500 training steps on a `rank=64` LoRA (except `rank=4096` for the input projection).
For these control LoRA training settings, every `lora_alpha` is the same as the `rank`. So, the `lora_alpha` for the input projection is `4096`, while for the remaining layers it is `64`. The size of the LoRA weights is between 300-1000 MB in most cases, which is much less than a full control conditioning model.
The problem comes when trying to load such a LoRA with diffusers. Since we only know the ranks, but not the `lora_alpha`, the diffusers config uses the same `lora_alpha` as the inferred `rank` (which is `4096`, since the input projection layer is also the first lora layer in the state dict). As you can imagine, setting an alpha of 4096 on all layers (even the ones that originally had `rank = lora_alpha = 64`) will result in random noise.
This is a more general problem because a lot of LoRAs on CivitAI and from different trainers are trained with alpha configurations that differ from the rank. So making the assumption that `lora_alpha=rank` (which is our current default behaviour) is incorrect, and having this metadata information will be really helpful.
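To make the failure mode concrete, here is a minimal sketch of the scaling arithmetic (standard peft-style LoRA scales the update by `lora_alpha / rank`; the numbers mirror the example above):

```python
def lora_scaling(lora_alpha: int, rank: int) -> float:
    # Standard (non-rslora) peft scaling applied to the LoRA update.
    return lora_alpha / rank

# Training-time intent: a scaling of 1.0 everywhere.
print(lora_scaling(4096, 4096))  # input projection -> 1.0
print(lora_scaling(64, 64))      # remaining layers -> 1.0

# What a loader gets if it assumes lora_alpha=4096 for every layer,
# because only the ranks can be inferred from the state dict:
print(lora_scaling(4096, 64))    # remaining layers -> 64.0, i.e. outputs blown up into noise
```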
In order to solve the problem of being able to train and run validation in finetrainers, we just serialize the lora adapter config directly into the safetensors metadata. See this and this; we use a custom function to load the lora weights. Really big kudos to whoever implemented the `save_function` argument allowing custom ones to be provided!
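A minimal sketch of that approach, assuming a toy model and an illustrative `lora_metadata` key (not finetrainers' actual code): safetensors metadata is a flat `str -> str` mapping, so the peft config can be serialized as a JSON string next to the tensors.

```python
import json

import torch.nn as nn
from peft import LoraConfig, get_peft_model, get_peft_model_state_dict
from safetensors.torch import save_file

# Toy stand-in for a transformer layer; real code would target the actual model.
model = nn.Sequential()
model.add_module("to_q", nn.Linear(64, 64))

config = LoraConfig(r=64, lora_alpha=64, target_modules=["to_q"])
peft_model = get_peft_model(model, config)
lora_state_dict = get_peft_model_state_dict(peft_model)

# `target_modules` is stored as a set inside LoraConfig, hence the `default=` handler.
metadata = {"lora_metadata": json.dumps(config.to_dict(), default=lambda o: sorted(o))}
save_file(lora_state_dict, "pytorch_lora_weights.safetensors", metadata=metadata)
```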
Well, I wanted to check if the changes help solve the issue of ...
Regarding detecting metadata in non-diffusers checkpoints, we already either infer alphas or directly scale the LoRA weights. So, that should already be covered.
Currently, finetrainers exports LoRAs with a metadata attribute ...
SGTM, no rush. Thanks for helping.
Thanks for explaining this further @a-r-r-o-w.
There could be better heuristics, like using the most common rank, but I imagine changing that now would be backwards incompatible.
Any idea how other libraries/apps deal with this lack of info? I thought there was a bit of a standard to multiply the scale into the weights before saving, so that the assumption that ...
@BenjaminBossan many trainers just embed the alpha values in the state dict instead of maintaining separate configuration metadata. See here as an example:
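As a rough illustration of that convention (the file name and key layout here are assumptions, not the linked example): such checkpoints carry a scalar alpha tensor per LoRA module, which a loader can read back like this.

```python
from safetensors import safe_open

# Hypothetical kohya-style checkpoint where each module has
# "<name>.lora_down.weight", "<name>.lora_up.weight" and "<name>.alpha" keys.
alphas, ranks = {}, {}
with safe_open("lora.safetensors", framework="pt") as f:
    for key in f.keys():
        if key.endswith(".alpha"):
            alphas[key[: -len(".alpha")]] = float(f.get_tensor(key))
        elif key.endswith(".lora_down.weight"):
            # rank is the output dimension of the down projection
            ranks[key[: -len(".lora_down.weight")]] = f.get_tensor(key).shape[0]

# The effective per-module scaling is then alpha / rank.
scalings = {name: alphas[name] / ranks[name] for name in ranks if name in alphas}
```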
If there are no objections, I will button up the PR and make it generally available to the other loaders. But LMK if you have anything to share or any concerns over this direction.
Personally, I don't have any issues, just wondering if anything needs to be discussed (again) with the safetensors folks.
Well, previously there was no problem with them IIRC. It was about time and complexity. But now we have enough evidence to ship it while maintaining the single-file format for LoRA checkpoints.
@sayakpaul I've verified it to work, thanks for working on this. My understanding was that the lora initialization config would also be automatically saved into the exported safetensors as metadata, but it needs to be passed manually to ...
I think it might be good if we serialize this info automatically in `save_lora_weights`, if you think there's an easy way to do it.
Thanks for trying! Well, we're already serializing the entire metadata; see `diffusers/src/diffusers/loaders/lora_pipeline.py`, line 5308 in `d390d4d`.
The following test also confirms that: `diffusers/tests/lora/test_lora_layers_wan.py`, line 143 in `d390d4d`.
If you could provide a minimal reproducer for the issue, I would be happy to look into it as a priority.
Yes, what you mentioned works as expected: we need to pass ...
Ah I see. I might have an idea. Will update soon. Thanks again for trying it out quickly.
why can't we add alpha to the state dict like the kohya trainer does?
The code path for ...
so because "its harder" we will end up with a solution that requires other tools like comfyUI to implement special support to read these attributes and set them? i expected we would be adding compatibility points and not adding more tech debt. please reconsider just adding alpha to the state dict or leave the work to the community if it is too hard? this kinda thing really takes an eternity to be implemented by tools like Swarm and Forge and ComfyUI while the alpha attribute in dict key would Just Work even with AUTOMATIC1111. please tell me you see the value in this.. |
also it wasn't arrow who brought this up first, it was me, because simpletuner has the option to set the rank differently from the alpha, which peft has always supported but diffusers does not. setting alpha to 1 always would allow learning rates to stay the same across every rank. the initial request was just to write the alpha into the state dict key. this proposal is heavy-handed and requires too much effort from everybody.
We value other formats, and hence we support most of them directly within the library. At the same time, I also think it makes sense to be lenient towards our own libraries and think about long-term maintainability. With the current proposal, we can serialize additional useful things, too. I am not sure that for ...
Corrected my issue description but please note what my description starts with:
We set `lora_alpha` equal to the rank by default. I would like to take a second opinion from @DN6, especially for maintainability.
To summarize, the issue is that we just serialize the LoRA state dict in `save_lora_weights()` and nothing else. This PR proposes to serialize the LoRA adapter metadata alongside it.
We can still just serialize ...
i agree the config is needed for other more intricate patterns, and whatever Diffusers does there can even become the standard depending on how @kohya-ss and @comfyanonymous feel about it. i'd rather something "common" happen (cc @Nerogar), but it is a hard path to go down to try and get everybody on the same page before things start happening/becoming committed to.
I feel like there is some background information I'm missing. I haven't used the diffusers LoRA implementation yet, so I don't know about the state dict and metadata format. That said, this sounds like you are trying to create yet another format where the alpha is stored in the model metadata instead of the state dict. I've never seen a lora where this was done. All files I've seen so far store the alpha in the state dict under a dedicated key suffix.
Over the last few months I was talking mostly to kohya about a more standardized LoRA format. We've now agreed on something similar to the currently supported format, just with a few small adjustments. I've already written conversion code for this format, and you can find example files here. If you are now changing the format for your LoRA output, I strongly suggest following that format instead of re-inventing something well-established again, just because it's easier to implement.
For the record, the structure (keys, for example) of the state dict won't change. The file will just be populated with metadata.
Thanks for working on this. As Nerogar wrote, standardization is progressing, and I expect ComfyUI to support it too. In addition, I think the de-facto compatibility target today is ComfyUI. Therefore, if diffusers adopts metadata and the keys are not compatible with the existing (or the new standard) format, changed or not, I recommend shipping bidirectional converters to/from the format supported by ComfyUI at the same time. That keeps the barrier for ordinary users close to zero.
To provide more context, this is not a new format, and no existing non-diffusers code for lora loading needs to be changed. All this PR does is serialize a simple string dict into the safetensors file format (docs). The state dict format remains unaffected. This may actually help with loading diffusers-implementation state dicts more easily, because everything needed for loading the peft-related parts becomes part of the file.
Knowing this info and using the native model implementations, the lora loading code becomes:

```python
import json

from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
from safetensors import safe_open

# `model` is the target model and `model_file` the path to the .safetensors file.
with safe_open(model_file, framework="pt") as f:
    metadata = json.loads(f.metadata()["lora_metadata"])
    state_dict = {key: f.get_tensor(key) for key in f.keys()}  # load the LoRA tensors

config = LoraConfig(**metadata)
adapter_name = "my-awesome-lora"
inject_adapter_in_model(config, model, adapter_name=adapter_name, low_cpu_mem_usage=True)
set_peft_model_state_dict(model, state_dict, adapter_name=adapter_name, ignore_mismatched_sizes=False, low_cpu_mem_usage=True)
```

No changes are needed in any existing code that supports loading diffusers loras, since the state dict remains completely unaltered. However, if another line or two are added to read the metadata and create the peft-init config based on that, the exact ranks/alphas used in training will be used.
the rank can be different from the parameter shape?
rank_pattern/alpha_pattern might be useful for the diffusers library itself, but I don't know any other implementation that expects the alpha value in that format. If you want to be more compatible with other implementations, the correct way of storing that information is a tensor in the state dict for every lora layer. If that key is missing, there are probably defaults that are used, but they are implementation-specific. For example, OneTrainer always assumes alpha=1 as a default value.
That's also something I don't understand. The rank is already defined by the shape of the down/up tensors. There's no need to store that information separately.
can't do that with Diffusers models. they are ...
Not sure what you mean. What I mean is that different layers can have different ranks and alphas (see the example below). We will now store this information as part of the safetensors metadata (which one may choose to use or not). Supporting existing compatibility with storing the alphas in the state dict is a separate issue and unrelated to what's happening here (we can look into exporting LoRAs that way if there's an ask).
Example model: ...
Now, let's assume we want a lora on all layers, with ... The equivalent peft init config is: ...
The extra metadata that will be stored in the safetensors metadata (emphasizing again that it's not the state dict, so no existing code needs to be modified) is the serialized config. Does this make sense?
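As a hypothetical stand-in for that example (module names and numbers are illustrative, mirroring the control-LoRA discussion above): a peft config where the input projection uses rank/alpha 4096 while every other targeted layer uses rank = alpha = 64, together with the serialized form that would land in the safetensors metadata.

```python
import json

from peft import LoraConfig

config = LoraConfig(
    r=64,
    lora_alpha=64,
    target_modules=["x_embedder", "to_q", "to_k", "to_v", "to_out.0"],
    rank_pattern={"x_embedder": 4096},
    alpha_pattern={"x_embedder": 4096},
)

# This JSON string (not the state dict!) is what would be written into the
# safetensors metadata, so a loader can recreate the exact training-time config.
print(json.dumps(config.to_dict(), default=lambda o: sorted(o), indent=2))
```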
What does this PR do?
I know we have revisited this over and over again, but this is becoming increasingly important. So, we should treat it as a priority.
@a-r-r-o-w brought this issue to me while we were debugging something in Wan LoRA training. So, I started by just modifying the Wan LoRA loader (eventually, the changes will be propagated to other loaders too). Aryan, could you check if this change fixes the issue we were facing?
Admittedly, the PR can be cleaned and modularized a bit but I wanted to get something up quickly to get feedback on the direction.
TODOs