Skip to content

Commit 9f5ad1d

Browse files
authored
[LoRA] fix peft state dict parsing (#10532)
* fix peft state dict parsing
* updates
1 parent 464374f commit 9f5ad1d

File tree

1 file changed

+83
-1
lines changed

1 file changed

+83
-1
lines changed

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -519,7 +519,7 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
519519
remaining_keys = list(sds_sd.keys())
520520
te_state_dict = {}
521521
if remaining_keys:
522-
if not all(k.startswith("lora_te1") for k in remaining_keys):
522+
if not all(k.startswith("lora_te") for k in remaining_keys):
523523
raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}")
524524
for key in remaining_keys:
525525
if not key.endswith("lora_down.weight"):
@@ -558,6 +558,88 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
558558
new_state_dict = {**ait_sd, **te_state_dict}
559559
return new_state_dict
560560

561+
def _convert_mixture_state_dict_to_diffusers(state_dict):
562+
new_state_dict = {}
563+
564+
def _convert(original_key, diffusers_key, state_dict, new_state_dict):
565+
down_key = f"{original_key}.lora_down.weight"
566+
down_weight = state_dict.pop(down_key)
567+
lora_rank = down_weight.shape[0]
568+
569+
up_weight_key = f"{original_key}.lora_up.weight"
570+
up_weight = state_dict.pop(up_weight_key)
571+
572+
alpha_key = f"{original_key}.alpha"
573+
alpha = state_dict.pop(alpha_key)
574+
575+
# scale weight by alpha and dim
576+
scale = alpha / lora_rank
577+
# calculate scale_down and scale_up
578+
scale_down = scale
579+
scale_up = 1.0
580+
while scale_down * 2 < scale_up:
581+
scale_down *= 2
582+
scale_up /= 2
583+
down_weight = down_weight * scale_down
584+
up_weight = up_weight * scale_up
585+
586+
diffusers_down_key = f"{diffusers_key}.lora_A.weight"
587+
new_state_dict[diffusers_down_key] = down_weight
588+
new_state_dict[diffusers_down_key.replace(".lora_A.", ".lora_B.")] = up_weight
589+
590+
all_unique_keys = {
591+
k.replace(".lora_down.weight", "").replace(".lora_up.weight", "").replace(".alpha", "") for k in state_dict
592+
}
593+
all_unique_keys = sorted(all_unique_keys)
594+
assert all("lora_transformer_" in k for k in all_unique_keys), f"{all_unique_keys=}"
595+
596+
for k in all_unique_keys:
597+
if k.startswith("lora_transformer_single_transformer_blocks_"):
598+
i = int(k.split("lora_transformer_single_transformer_blocks_")[-1].split("_")[0])
599+
diffusers_key = f"single_transformer_blocks.{i}"
600+
elif k.startswith("lora_transformer_transformer_blocks_"):
601+
i = int(k.split("lora_transformer_transformer_blocks_")[-1].split("_")[0])
602+
diffusers_key = f"transformer_blocks.{i}"
603+
else:
604+
raise NotImplementedError
605+
606+
if "attn_" in k:
607+
if "_to_out_0" in k:
608+
diffusers_key += ".attn.to_out.0"
609+
elif "_to_add_out" in k:
610+
diffusers_key += ".attn.to_add_out"
611+
elif any(qkv in k for qkv in ["to_q", "to_k", "to_v"]):
612+
remaining = k.split("attn_")[-1]
613+
diffusers_key += f".attn.{remaining}"
614+
elif any(add_qkv in k for add_qkv in ["add_q_proj", "add_k_proj", "add_v_proj"]):
615+
remaining = k.split("attn_")[-1]
616+
diffusers_key += f".attn.{remaining}"
617+
618+
if diffusers_key == f"transformer_blocks.{i}":
619+
print(k, diffusers_key)
620+
_convert(k, diffusers_key, state_dict, new_state_dict)
621+
622+
if len(state_dict) > 0:
623+
raise ValueError(
624+
f"Expected an empty state dict at this point but its has these keys which couldn't be parsed: {list(state_dict.keys())}."
625+
)
626+
627+
new_state_dict = {f"transformer.{k}": v for k, v in new_state_dict.items()}
628+
return new_state_dict
629+
630+
# This is weird.
631+
# https://huggingface.co/sayakpaul/different-lora-from-civitai/tree/main?show_file_info=sharp_detailed_foot.safetensors
632+
# has both `peft` and non-peft state dict.
633+
has_peft_state_dict = any(k.startswith("transformer.") for k in state_dict)
634+
if has_peft_state_dict:
635+
state_dict = {k: v for k, v in state_dict.items() if k.startswith("transformer.")}
636+
return state_dict
637+
# Another weird one.
638+
has_mixture = any(
639+
k.startswith("lora_transformer_") and ("lora_down" in k or "lora_up" in k or "alpha" in k) for k in state_dict
640+
)
641+
if has_mixture:
642+
return _convert_mixture_state_dict_to_diffusers(state_dict)
561643
return _convert_sd_scripts_to_ai_toolkit(state_dict)
562644

563645

0 commit comments

Comments
 (0)