Skip to content

[LoRA] support Kohya Flux LoRAs that have text encoders as well #9542

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Sep 30, 2024

Conversation

sayakpaul
Copy link
Member

What does this PR do?

https://huggingface.co/cocktailpeanut has a bunch of very nice Flux LoRAs that were trained using Kohya but has text encoder components too. This PR adds support for fully loading those LoRAs.

Test code (has a slow test in this PR too):

from diffusers import FluxPipeline
import torch 

pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

pipeline.load_lora_weights("cocktailpeanut/optimus", weight_name="optimus.safetensors")
  
prompts = [
    "optimus is cleaning the house with broomstick",
    "optimus is a DJ performing at a hip nightclub",
    "optimus is competing in a bboy break dancing competition",
    "optimus is playing tennis in a tennis court"
]
images = pipeline(
    prompts, 
    num_inference_steps=50,
    guidance_scale=3.5,
    max_sequence_length=512,
    generator=torch.manual_seed(0)
).images
for i, image in enumerate(images):
    image.save(f"{i}.png")

Favorite sample:
image

optimus is a DJ performing at a hip nightclub

@sayakpaul sayakpaul requested review from apolinario and yiyixuxu and removed request for apolinario and yiyixuxu September 27, 2024 12:27
@sayakpaul
Copy link
Member Author

Sorry for the review request messup.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@yiyixuxu
Copy link
Collaborator

@asomoza can you give this a review too?

@sayakpaul
Copy link
Member Author

sayakpaul commented Sep 28, 2024

Just for the sake of comparison, if we run the example code provided in the PR description with diffusers:main, we get:

with text encoder without text encoder
Image 1 Image 1
Image 2 Image 2
Image 3 Image 3
Image 4 Image 4

@asomoza
Copy link
Member

asomoza commented Sep 28, 2024

Thanks, LGTM in respect to the changes.

Since we now have text encoders training with a transformer model, I did some tests with the blockwise loras that I often use with SDXL:

no lora just transformer with TEs
flux-optimum flux-lora-scales-no-te flux-lora-te

This is not related to this PR but so that we just know.

if I do this:

scales = {"text_encoder": 0.0, "text_encoder_2": 0.0, "transformer": 0.0}
pipe.set_adapters("optimus", adapter_weights=scales)

It works as intended, but I copy & pasted the same code I use for SDXL:

scales = {"text_encoder": 0.0, "text_encoder_2": 0.0, "unet": 0.0}
pipe.set_adapters("optimus", adapter_weights=scales)

This doesn't work, the transformer keeps the lora scale at 1.0 but it doesn't show an error or warning that I'm setting the "unet" instead of the "transformer".

@sayakpaul
Copy link
Member Author

sayakpaul commented Sep 28, 2024

@asomoza

scales = {"text_encoder": 0.0, "text_encoder_2": 0.0, "transformer": 0.0}

I think it shouldn't apply to this LoRA because it doesn't have the text_encoder_2 component in the first place. I can look into catching this and erroring/warning as needed. Will look into the "unet" thingy as well. Thanks much for flagging!

Copy link
Collaborator

@apolinario apolinario left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking great! Thanks for adding it

@yiyixuxu yiyixuxu merged commit f9fd511 into main Sep 30, 2024
18 checks passed
@yiyixuxu yiyixuxu deleted the kohya-flux-lora-te branch September 30, 2024 17:59
leisuzz pushed a commit to leisuzz/diffusers that referenced this pull request Oct 11, 2024
sayakpaul added a commit that referenced this pull request Dec 23, 2024
@RIOFornium
Copy link

RIOFornium commented Jan 8, 2025

Hello!
Seems:
if not all(k.startswith("lora_te1") for k in remaining_keys): raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}")
It is unnecessary or logic should be changed.
If this exists many Kohya models without text encoders are crashed.
For example (crashed):
https://civitai.com/models/200251?modelVersionId=1081295
And many others are crashed. It is a serious bug, please fix it.
Thank you!

@sayakpaul
Copy link
Member Author

Can you open a new issue thread with a fully reproducible code snippet?

@RIOFornium
Copy link

Hello, the reproducing way is very simple, you can try to open any Kohya model without a text encoder.

@sayakpaul
Copy link
Member Author

Not sure about it really. I just ran pytest tests/lora/ -k "test_flux_kohya" and everything passed.

@RIOFornium
Copy link

Hello, but in real life is not working :-)
Tests also can have bugs...
Check it manually, you can test by:
https://civitai.com/models/200251?modelVersionId=1081295

@sayakpaul
Copy link
Member Author

sayakpaul commented Jan 11, 2025

@RIOFornium #10532.

Also, I would suggest avoid being anecdotal here.

Hello, the reproducing way is very simple, you can try to open any Kohya model without a text encoder.
Hello, but in real life is not working :-)
Tests also can have bugs...

  1. I ran two tests where one LoRA has a text encoder and another one doesn't. Both of them passed. So, your first statement is utterly incorrect.
  2. What is real-life? We cannot predict every possibility in the community checkpoints beforehand and we fix them on a case-by-case basis. Hence, the testing setup. You cited one LoRA that had the bug and that maybe useful for your application. That may not be the case for others, and for those, our test cases might very well be sufficiently representative. So, there's no clear definition of "real-world" here, IMO.

I kept asking for a simple reproducer and you just referred me to the Civit AI model link. Please try empathizing with the maintainer's time. To get it to work:

  1. I had to download the CivitAI model.
  2. Upload to the Hub. I need this because I don't have a local GPU and have to rely on remote ones.
  3. Figure out a reproducible code snippet.

Expecting a simple reproducer isn't too much of an ask and I would suggest respecting that going forward. So far I have engaged in good faith. If you keep providing incomplete information, I will just stop engaging.

This applies to you and everyone else having similar attitude.

@RIOFornium
Copy link

Hello!
Sorry, I did not want to hurt you.
It is my first collaboration here.
Next time I will be more constructive.
Thank you very much for the fix and your beneficial work!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants