     USE_PEFT_BACKEND,
     deprecate,
     get_submodule_by_name,
+    is_bitsandbytes_available,
+    is_gguf_available,
     is_peft_available,
     is_peft_version,
     is_torch_version,
@@ -68,6 +70,49 @@
 _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX = {"x_embedder": "in_channels"}


+def _maybe_dequantize_weight_for_expanded_lora(model, module):
+    if is_bitsandbytes_available():
+        from ..quantizers.bitsandbytes import dequantize_bnb_weight
+
+    if is_gguf_available():
+        from ..quantizers.gguf.utils import dequantize_gguf_tensor
+
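+    # Identify the quantization backend from the class of the weight parameter.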
+    is_bnb_4bit_quantized = module.weight.__class__.__name__ == "Params4bit"
+    is_gguf_quantized = module.weight.__class__.__name__ == "GGUFParameter"
+
+    if is_bnb_4bit_quantized and not is_bitsandbytes_available():
+        raise ValueError(
+            "The checkpoint seems to have been quantized with `bitsandbytes` (4bits). Install `bitsandbytes` to load quantized checkpoints."
+        )
+    if is_gguf_quantized and not is_gguf_available():
+        raise ValueError(
+            "The checkpoint seems to have been quantized with `gguf`. Install `gguf` to load quantized checkpoints."
+        )
+
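+    # Dequantization below runs on CUDA, so remember whether the weight originally lived on CPU.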
+    weight_on_cpu = False
+    if not module.weight.is_cuda:
+        weight_on_cpu = True
+
+    if is_bnb_4bit_quantized:
+        module_weight = dequantize_bnb_weight(
+            module.weight.cuda() if weight_on_cpu else module.weight,
+            state=module.weight.quant_state,
+            dtype=model.dtype,
+        ).data
+    elif is_gguf_quantized:
+        module_weight = dequantize_gguf_tensor(
+            module.weight.cuda() if weight_on_cpu else module.weight,
+        )
+        module_weight = module_weight.to(model.dtype)
+    else:
+        module_weight = module.weight.data
+
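+    # Hand the dequantized weight back on the device the original parameter lived on.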
+    if weight_on_cpu:
+        module_weight = module_weight.cpu()
+
+    return module_weight
+
+
 class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
     r"""
     Load LoRA layers into Stable Diffusion [`UNet2DConditionModel`] and
@@ -2267,6 +2312,7 @@ def _maybe_expand_transformer_param_shape_or_error_(
         overwritten_params = {}

         is_peft_loaded = getattr(transformer, "peft_config", None) is not None
+        is_quantized = hasattr(transformer, "hf_quantizer")
         for name, module in transformer.named_modules():
             if isinstance(module, torch.nn.Linear):
                 module_weight = module.weight.data
@@ -2291,9 +2337,7 @@ def _maybe_expand_transformer_param_shape_or_error_(
                 if tuple(module_weight_shape) == (out_features, in_features):
                     continue

-                # TODO (sayakpaul): We still need to consider if the module we're expanding is
-                # quantized and handle it accordingly if that is the case.
-                module_out_features, module_in_features = module_weight.shape
+                module_out_features, module_in_features = module_weight_shape
                 debug_message = ""
                 if in_features > module_in_features:
                     debug_message += (
@@ -2316,6 +2360,10 @@ def _maybe_expand_transformer_param_shape_or_error_(
                 parent_module_name, _, current_module_name = name.rpartition(".")
                 parent_module = transformer.get_submodule(parent_module_name)

+                if is_quantized:
+                    module_weight = _maybe_dequantize_weight_for_expanded_lora(transformer, module)
+
+                # TODO: consider if this layer needs to be a quantized layer as well if `is_quantized` is True.
                 with torch.device("meta"):
                     expanded_module = torch.nn.Linear(
                         in_features, out_features, bias=bias, dtype=module_weight.dtype
@@ -2327,7 +2375,7 @@ def _maybe_expand_transformer_param_shape_or_error_(
                 new_weight = torch.zeros_like(
                     expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
                 )
-                slices = tuple(slice(0, dim) for dim in module_weight.shape)
+                slices = tuple(slice(0, dim) for dim in module_weight_shape)
                 new_weight[slices] = module_weight
                 tmp_state_dict = {"weight": new_weight}
                 if module_bias is not None:
@@ -2416,7 +2464,12 @@ def _calculate_module_shape(
         base_weight_param_name: str = None,
     ) -> "torch.Size":
         def _get_weight_shape(weight: torch.Tensor):
-            return weight.quant_state.shape if weight.__class__.__name__ == "Params4bit" else weight.shape
+            if weight.__class__.__name__ == "Params4bit":
+                return weight.quant_state.shape
+            elif weight.__class__.__name__ == "GGUFParameter":
+                return weight.quant_shape
+            else:
+                return weight.shape

         if base_module is not None:
             return _get_weight_shape(base_module.weight)
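
For context, a minimal standalone sketch (not part of the diff) of what the expansion path above does once a weight is available as a plain tensor: the existing weight is copied into the leading slice of a zero-initialized, larger Linear layer. The shapes and names (`small`, `expanded`) are illustrative assumptions.

import torch

# Hypothetical shapes: expand a Linear(32 -> 64) weight into a Linear(64 -> 128).
small = torch.randn(64, 32)                      # stands in for the (dequantized) module weight
expanded = torch.nn.Linear(64, 128, bias=False)  # the new, wider layer created in the diff

# Zero-initialize, then copy the original weight into the top-left slice,
# mirroring `new_weight[slices] = module_weight` above.
new_weight = torch.zeros_like(expanded.weight.data)
slices = tuple(slice(0, dim) for dim in small.shape)
new_weight[slices] = small
expanded.weight.data.copy_(new_weight)

assert torch.equal(expanded.weight.data[:64, :32], small)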