
Loading a Lora on quantized model ? TorchaoLoraLinear.__init__() missing 1 required keyword-only argument: 'get_apply_tensor_subclass' #10621

Open
christopher5106 opened this issue Jan 21, 2025 · 11 comments

@christopher5106
import time
import torch
from diffusers import FluxPipeline
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, PerRow

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16,
).to("cuda")

quantize_(pipe.transformer, float8_dynamic_activation_float8_weight(granularity=PerRow()))

pipe.load_lora_weights('Octree/flux-schnell-lora', weight_name='flux-schnell-lora.safetensors')

gives the following error:

  File "venv/lib/python3.11/site-packages/peft/tuners/lora/torchao.py", line 147, in dispatch_torchao
    new_module = TorchaoLoraLinear(target, adapter_name, **kwargs)
@bghira
Contributor

bghira commented Jan 23, 2025

I think you have to use a quantization config during pipeline init, but then you run into #10578.
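Something along these lines (untested sketch; "float8dq" is one of the TorchAO quant type strings diffusers accepts, and the per-row float8 variant may be exposed under a different string depending on the diffusers version):

import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig

# quantize with TorchAO at load time via a quantization config,
# instead of calling quantize_() on an already-built pipeline
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    subfolder="transformer",
    quantization_config=TorchAoConfig("float8dq"),
    torch_dtype=torch.bfloat16,
)

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
).to("cuda")

pipe.load_lora_weights('Octree/flux-schnell-lora', weight_name='flux-schnell-lora.safetensors')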

@christopher5106
Author

that works, thanks!

@christopher5106
Author

But what is the equivalent of float8dq or float8_dynamic_activation_float8_weight in a bitsandbytes quantization config for diffusers?

@sayakpaul

@christopher5106
Author

And the PerRow granularity?

@a-r-r-o-w
Member

@christopher5106 TorchAO supports more quantization options and algorithms than BnB. They have separate kernel implementations for what they do, so there is not really an equivalent way of comparing them, or of finding features of one in the other.

@christopher5106
Author

So I have to use TorchAO, and my initial issue was about loading a LoRA on a model quantized with TorchAO. So the issue remains open.

@bghira
Contributor

bghira commented Jan 23, 2025

Well, it's already being tracked but I couldn't find it 🤷. For TorchAO you'll have to quantize after loading, and I think the issue there is possibly in PEFT.
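i.e. roughly this order (untested sketch; once the transformer is quantized, swapping the LoRA out again probably won't work):

import torch
from diffusers import FluxPipeline
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, PerRow

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16,
).to("cuda")

# load the LoRA while the transformer still consists of plain nn.Linear layers ...
pipe.load_lora_weights('Octree/flux-schnell-lora', weight_name='flux-schnell-lora.safetensors')

# ... and only then quantize the transformer in place
quantize_(pipe.transformer, float8_dynamic_activation_float8_weight(granularity=PerRow()))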

@christopher5106
Author

That means unloading/reloading is no longer possible after quantization, right?

Couldn't we quantize the LoRA separately before merging it?

@bghira
Contributor

bghira commented Jan 23, 2025

Oh, you could try that using get_peft_model-style workarounds.
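For example, a rough sketch of the "merge first, quantize later" idea (untested): fuse the LoRA into the base weights while everything is still a regular nn.Linear, drop the adapter, then quantize the merged transformer.

import torch
from diffusers import FluxPipeline
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, PerRow

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16,
).to("cuda")

pipe.load_lora_weights('Octree/flux-schnell-lora', weight_name='flux-schnell-lora.safetensors')
pipe.fuse_lora()            # merge the LoRA deltas into the base weights
pipe.unload_lora_weights()  # remove the adapter modules, leaving plain Linear layers

quantize_(pipe.transformer, float8_dynamic_activation_float8_weight(granularity=PerRow()))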

@christopher5106
Author

In the PEFT docs, it says that torch.compile with quantization (bitsandbytes) is not supported. What about other quantization backends, such as TorchAO?

@christopher5106
Author

christopher5106 commented Jan 27, 2025

@bghira I get the same error with a PEFT LoraConfig as with load_lora_weights():

import torch
from diffusers import FluxTransformer2DModel
from peft import LoraConfig
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight

torch.set_float32_matmul_precision("high")
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.benchmark_limit = 20

flux_transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", subfolder="transformer"
).to("cuda", torch.float16)

quantize_(flux_transformer, float8_dynamic_activation_float8_weight())

target_modules = [
    "x_embedder",
    "attn.to_k",
    "attn.to_q",
    "attn.to_v",
    "attn.to_out.0",
    "attn.add_k_proj",
    "attn.add_q_proj",
    "attn.add_v_proj",
    "attn.to_add_out",
    "ff.net.0.proj",
    "ff.net.2",
    "ff_context.net.0.proj",
    "ff_context.net.2",
]

transformer_lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    init_lora_weights=True,
    target_modules=target_modules,
    lora_bias=False,
)
flux_transformer.add_adapter(transformer_lora_config)

File "venv/lib/python3.11/site-packages/peft/tuners/lora/torchao.py", line 147, in dispatch_torchao
    new_module = TorchaoLoraLinear(target, adapter_name, **kwargs)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: TorchaoLoraLinear.__init__() missing 1 required keyword-only argument: 'get_apply_tensor_subclass'

Same error if I write model = get_peft_model(flux_transformer, transformer_lora_config) instead.

christopher5106 changed the title from "Loading a Lora on quantized model ?" to "Loading a Lora on quantized model ? TorchaoLoraLinear.__init__() missing 1 required keyword-only argument: 'get_apply_tensor_subclass'" on Jan 27, 2025