
Commit 5d49b3e

Authored by hlky, sayakpaul, github-actions[bot], and DN6
Flux quantized with lora (#10990)
* Flux quantized with lora
* fix
* changes
* Apply suggestions from code review
  Co-authored-by: Sayak Paul <[email protected]>
* Apply style fixes
* enable model cpu offload()
* Update src/diffusers/loaders/lora_pipeline.py
  Co-authored-by: hlky <[email protected]>
* update
* Apply suggestions from code review
* update
* add peft as an additional dependency for gguf

---------

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Dhruv Nair <[email protected]>
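For context, a minimal usage sketch distilled from the tests added in this commit: it loads a pre-quantized bnb-4bit Flux Control checkpoint and then loads the Canny Control LoRA on top, which is exactly the path this change enables. The `eramth/flux-4bit` checkpoint and the inference settings come straight from the new test; treat them as illustrative, not canonical.

```python
import torch
from PIL import Image
from diffusers import FluxControlPipeline

# Load a Flux Control checkpoint whose transformer is already bnb-4bit quantized.
pipe = FluxControlPipeline.from_pretrained("eramth/flux-4bit", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()

# This now works on the quantized model: shape-expanded modules are
# dequantized on the fly before the LoRA is loaded into them.
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")

image = pipe(
    prompt="A robot made of exotic candies and chocolates of different kinds.",
    control_image=Image.new(mode="RGB", size=(256, 256)),  # placeholder control image
    height=256,
    width=256,
    num_inference_steps=8,
    generator=torch.Generator().manual_seed(42),
).images[0]
```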
1 parent 71f34fc commit 5d49b3e

File tree

5 files changed: +152 -8


.github/workflows/nightly_tests.yml (+1 -1)

@@ -417,7 +417,7 @@ jobs:
           additional_deps: ["peft"]
         - backend: "gguf"
           test_location: "gguf"
-          additional_deps: []
+          additional_deps: ["peft"]
         - backend: "torchao"
           test_location: "torchao"
           additional_deps: []

src/diffusers/loaders/lora_pipeline.py (+58 -5)

@@ -22,6 +22,8 @@
     USE_PEFT_BACKEND,
     deprecate,
     get_submodule_by_name,
+    is_bitsandbytes_available,
+    is_gguf_available,
     is_peft_available,
     is_peft_version,
     is_torch_version,
@@ -68,6 +70,49 @@
 _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX = {"x_embedder": "in_channels"}
 
 
+def _maybe_dequantize_weight_for_expanded_lora(model, module):
+    if is_bitsandbytes_available():
+        from ..quantizers.bitsandbytes import dequantize_bnb_weight
+
+    if is_gguf_available():
+        from ..quantizers.gguf.utils import dequantize_gguf_tensor
+
+    is_bnb_4bit_quantized = module.weight.__class__.__name__ == "Params4bit"
+    is_gguf_quantized = module.weight.__class__.__name__ == "GGUFParameter"
+
+    if is_bnb_4bit_quantized and not is_bitsandbytes_available():
+        raise ValueError(
+            "The checkpoint seems to have been quantized with `bitsandbytes` (4bits). Install `bitsandbytes` to load quantized checkpoints."
+        )
+    if is_gguf_quantized and not is_gguf_available():
+        raise ValueError(
+            "The checkpoint seems to have been quantized with `gguf`. Install `gguf` to load quantized checkpoints."
+        )
+
+    weight_on_cpu = False
+    if not module.weight.is_cuda:
+        weight_on_cpu = True
+
+    if is_bnb_4bit_quantized:
+        module_weight = dequantize_bnb_weight(
+            module.weight.cuda() if weight_on_cpu else module.weight,
+            state=module.weight.quant_state,
+            dtype=model.dtype,
+        ).data
+    elif is_gguf_quantized:
+        module_weight = dequantize_gguf_tensor(
+            module.weight.cuda() if weight_on_cpu else module.weight,
+        )
+        module_weight = module_weight.to(model.dtype)
+    else:
+        module_weight = module.weight.data
+
+    if weight_on_cpu:
+        module_weight = module_weight.cpu()
+
+    return module_weight
+
+
 class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
     r"""
     Load LoRA layers into Stable Diffusion [`UNet2DConditionModel`] and
@@ -2267,6 +2312,7 @@ def _maybe_expand_transformer_param_shape_or_error_(
         overwritten_params = {}
 
         is_peft_loaded = getattr(transformer, "peft_config", None) is not None
+        is_quantized = hasattr(transformer, "hf_quantizer")
         for name, module in transformer.named_modules():
             if isinstance(module, torch.nn.Linear):
                 module_weight = module.weight.data
@@ -2291,9 +2337,7 @@
                 if tuple(module_weight_shape) == (out_features, in_features):
                     continue
 
-                # TODO (sayakpaul): We still need to consider if the module we're expanding is
-                # quantized and handle it accordingly if that is the case.
-                module_out_features, module_in_features = module_weight.shape
+                module_out_features, module_in_features = module_weight_shape
                 debug_message = ""
                 if in_features > module_in_features:
                     debug_message += (
@@ -2316,6 +2360,10 @@
                 parent_module_name, _, current_module_name = name.rpartition(".")
                 parent_module = transformer.get_submodule(parent_module_name)
 
+                if is_quantized:
+                    module_weight = _maybe_dequantize_weight_for_expanded_lora(transformer, module)
+
+                # TODO: consider if this layer needs to be a quantized layer as well if `is_quantized` is True.
                 with torch.device("meta"):
                     expanded_module = torch.nn.Linear(
                         in_features, out_features, bias=bias, dtype=module_weight.dtype
@@ -2327,7 +2375,7 @@
                 new_weight = torch.zeros_like(
                     expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
                 )
-                slices = tuple(slice(0, dim) for dim in module_weight.shape)
+                slices = tuple(slice(0, dim) for dim in module_weight_shape)
                 new_weight[slices] = module_weight
                 tmp_state_dict = {"weight": new_weight}
                 if module_bias is not None:
@@ -2416,7 +2464,12 @@ def _calculate_module_shape(
         base_weight_param_name: str = None,
     ) -> "torch.Size":
        def _get_weight_shape(weight: torch.Tensor):
-            return weight.quant_state.shape if weight.__class__.__name__ == "Params4bit" else weight.shape
+            if weight.__class__.__name__ == "Params4bit":
+                return weight.quant_state.shape
+            elif weight.__class__.__name__ == "GGUFParameter":
+                return weight.quant_shape
+            else:
+                return weight.shape
 
         if base_module is not None:
             return _get_weight_shape(base_module.weight)
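The core mechanic in `_maybe_expand_transformer_param_shape_or_error_` is worth seeing in isolation: the (now possibly dequantized) weight is copied into the leading block of a zero-initialized, larger `Linear`, so the added input features are a no-op until the LoRA populates them. A self-contained sketch with toy dimensions (not Flux's real ones):

```python
import torch

# Toy stand-in for the (already dequantized) Linear being expanded.
old = torch.nn.Linear(64, 128, bias=True)
in_features, out_features = 128, 128  # expanded shape demanded by the LoRA state dict

# Zero-init the larger weight, then copy the old weight into its leading block,
# mirroring `slices = tuple(slice(0, dim) for dim in module_weight_shape)`.
expanded = torch.nn.Linear(in_features, out_features, bias=True, dtype=old.weight.dtype)
new_weight = torch.zeros(out_features, in_features, dtype=old.weight.dtype)
slices = tuple(slice(0, dim) for dim in old.weight.shape)
new_weight[slices] = old.weight.data
expanded.weight.data = new_weight
expanded.bias.data = old.bias.data.clone()

# Inputs zero-padded on the new channels reproduce the old module's output.
x = torch.randn(2, 64)
x_padded = torch.cat([x, torch.zeros(2, 64)], dim=-1)
assert torch.allclose(expanded(x_padded), old(x), atol=1e-6)
```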

src/diffusers/quantizers/gguf/utils.py (+2 -0)

@@ -400,6 +400,8 @@ def __new__(cls, data, requires_grad=False, quant_type=None):
         data = data if data is not None else torch.empty(0)
         self = torch.Tensor._make_subclass(cls, data, requires_grad)
         self.quant_type = quant_type
+        block_size, type_size = GGML_QUANT_SIZES[quant_type]
+        self.quant_shape = _quant_shape_from_byte_shape(self.shape, type_size, block_size)
 
         return self
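The new `quant_shape` attribute records the tensor's logical weight shape, which differs from its raw byte shape because GGUF packs weights into fixed-size blocks. A hedged sketch of the conversion `_quant_shape_from_byte_shape` is assumed to perform, using Q8_0's block layout from the GGML spec as the example:

```python
# Hypothetical helper mirroring what `_quant_shape_from_byte_shape` presumably
# does: only the last dimension is packed, so it rescales by block_size/type_size.
def quant_shape_from_byte_shape(byte_shape, type_size, block_size):
    *leading, last = byte_shape
    return (*leading, last // type_size * block_size)

# Q8_0 packs 32 weights into a 34-byte block (2-byte scale + 32 int8 values),
# so a logical (4096, 4096) weight is stored with byte shape (4096, 4352).
assert quant_shape_from_byte_shape((4096, 4352), type_size=34, block_size=32) == (4096, 4096)
```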

tests/quantization/bnb/test_4bit.py (+45 -2)

@@ -21,8 +21,15 @@
 import pytest
 import safetensors.torch
 from huggingface_hub import hf_hub_download
-
-from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel
+from PIL import Image
+
+from diffusers import (
+    BitsAndBytesConfig,
+    DiffusionPipeline,
+    FluxControlPipeline,
+    FluxTransformer2DModel,
+    SD3Transformer2DModel,
+)
 from diffusers.utils import is_accelerate_version, logging
 from diffusers.utils.testing_utils import (
     CaptureLogger,
@@ -696,6 +703,42 @@ def test_lora_loading(self):
         self.assertTrue(max_diff < 1e-3)
 
 
+@require_transformers_version_greater("4.44.0")
+@require_peft_backend
+class SlowBnb4BitFluxControlWithLoraTests(Base4bitTests):
+    def setUp(self) -> None:
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        self.pipeline_4bit = FluxControlPipeline.from_pretrained("eramth/flux-4bit", torch_dtype=torch.float16)
+        self.pipeline_4bit.enable_model_cpu_offload()
+
+    def tearDown(self):
+        del self.pipeline_4bit
+
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def test_lora_loading(self):
+        self.pipeline_4bit.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")
+
+        output = self.pipeline_4bit(
+            prompt=self.prompt,
+            control_image=Image.new(mode="RGB", size=(256, 256)),
+            height=256,
+            width=256,
+            max_sequence_length=64,
+            output_type="np",
+            num_inference_steps=8,
+            generator=torch.Generator().manual_seed(42),
+        ).images
+        out_slice = output[0, -3:, -3:, -1].flatten()
+        expected_slice = np.array([0.1636, 0.1675, 0.1982, 0.1743, 0.1809, 0.1936, 0.1743, 0.2095, 0.2139])
+
+        max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
+        self.assertTrue(max_diff < 1e-3, msg=f"{out_slice=} != {expected_slice=}")
+
+
 @slow
 class BaseBnb4BitSerializationTests(Base4bitTests):
     def tearDown(self):

tests/quantization/gguf/test_gguf.py (+46 -0)

@@ -8,19 +8,22 @@
 from diffusers import (
     AuraFlowPipeline,
     AuraFlowTransformer2DModel,
+    FluxControlPipeline,
     FluxPipeline,
     FluxTransformer2DModel,
     GGUFQuantizationConfig,
     SD3Transformer2DModel,
     StableDiffusion3Pipeline,
 )
+from diffusers.utils import load_image
 from diffusers.utils.testing_utils import (
     is_gguf_available,
     nightly,
     numpy_cosine_similarity_distance,
     require_accelerate,
     require_big_gpu_with_torch_cuda,
     require_gguf_version_greater_or_equal,
+    require_peft_backend,
     torch_device,
 )
 
@@ -456,3 +459,46 @@ def test_pipeline_inference(self):
         )
         max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice)
         assert max_diff < 1e-4
+
+
+@require_peft_backend
+@nightly
+@require_big_gpu_with_torch_cuda
+@require_accelerate
+@require_gguf_version_greater_or_equal("0.10.0")
+class FluxControlLoRAGGUFTests(unittest.TestCase):
+    def test_lora_loading(self):
+        ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
+        transformer = FluxTransformer2DModel.from_single_file(
+            ckpt_path,
+            quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
+            torch_dtype=torch.bfloat16,
+        )
+        pipe = FluxControlPipeline.from_pretrained(
+            "black-forest-labs/FLUX.1-dev",
+            transformer=transformer,
+            torch_dtype=torch.bfloat16,
+        ).to("cuda")
+        pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")
+
+        prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
+        control_image = load_image(
+            "https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/control_image_robot_canny.png"
+        )
+
+        output = pipe(
+            prompt=prompt,
+            control_image=control_image,
+            height=256,
+            width=256,
+            num_inference_steps=10,
+            guidance_scale=30.0,
+            output_type="np",
+            generator=torch.manual_seed(0),
+        ).images
+
+        out_slice = output[0, -3:, -3:, -1].flatten()
+        expected_slice = np.array([0.8047, 0.8359, 0.8711, 0.6875, 0.7070, 0.7383, 0.5469, 0.5820, 0.6641])
+
+        max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
+        self.assertTrue(max_diff < 1e-3)
