@@ -519,7 +519,7 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
         remaining_keys = list(sds_sd.keys())
         te_state_dict = {}
         if remaining_keys:
-            if not all(k.startswith("lora_te1") for k in remaining_keys):
+            if not all(k.startswith("lora_te") for k in remaining_keys):
                 raise ValueError(f"Incompatible keys detected: \n\n{', '.join(remaining_keys)}")
             for key in remaining_keys:
                 if not key.endswith("lora_down.weight"):
@@ -558,6 +558,99 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
         new_state_dict = {**ait_sd, **te_state_dict}
         return new_state_dict
 
+    def _convert_mixture_state_dict_to_diffusers(state_dict):
+        new_state_dict = {}
+
+        def _convert(original_key, diffusers_key, state_dict, new_state_dict):
+            down_key = f"{original_key}.lora_down.weight"
+            down_weight = state_dict.pop(down_key)
+            lora_rank = down_weight.shape[0]
+
+            up_weight_key = f"{original_key}.lora_up.weight"
+            up_weight = state_dict.pop(up_weight_key)
+
+            alpha_key = f"{original_key}.alpha"
+            alpha = state_dict.pop(alpha_key)
+
+            # scale weight by alpha and dim
+            scale = alpha / lora_rank
+            # calculate scale_down and scale_up
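+            # the loop below keeps scale_down * scale_up == alpha / lora_rank while
+            # balancing the two factors between the down and up matrices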
+            scale_down = scale
+            scale_up = 1.0
+            while scale_down * 2 < scale_up:
+                scale_down *= 2
+                scale_up /= 2
+            down_weight = down_weight * scale_down
+            up_weight = up_weight * scale_up
+
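+            # PEFT naming convention: lora_down -> lora_A, lora_up -> lora_B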
+            diffusers_down_key = f"{diffusers_key}.lora_A.weight"
+            new_state_dict[diffusers_down_key] = down_weight
+            new_state_dict[diffusers_down_key.replace(".lora_A.", ".lora_B.")] = up_weight
+
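+        # collapse each module's key triplet (.lora_down.weight / .lora_up.weight /
+        # .alpha) into a single base key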
+        all_unique_keys = {
+            k.replace(".lora_down.weight", "").replace(".lora_up.weight", "").replace(".alpha", "") for k in state_dict
+        }
+        all_unique_keys = sorted(all_unique_keys)
+        assert all("lora_transformer_" in k for k in all_unique_keys), f"{all_unique_keys=}"
+
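+        # map flat kohya-style names to diffusers module paths, e.g.
+        # "lora_transformer_single_transformer_blocks_0_attn_to_q"
+        # becomes "single_transformer_blocks.0.attn.to_q"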
+        for k in all_unique_keys:
+            if k.startswith("lora_transformer_single_transformer_blocks_"):
+                i = int(k.split("lora_transformer_single_transformer_blocks_")[-1].split("_")[0])
+                diffusers_key = f"single_transformer_blocks.{i}"
+            elif k.startswith("lora_transformer_transformer_blocks_"):
+                i = int(k.split("lora_transformer_transformer_blocks_")[-1].split("_")[0])
+                diffusers_key = f"transformer_blocks.{i}"
+            else:
+                raise NotImplementedError
+
+            if "attn_" in k:
+                if "_to_out_0" in k:
+                    diffusers_key += ".attn.to_out.0"
+                elif "_to_add_out" in k:
+                    diffusers_key += ".attn.to_add_out"
+                elif any(qkv in k for qkv in ["to_q", "to_k", "to_v"]):
+                    remaining = k.split("attn_")[-1]
+                    diffusers_key += f".attn.{remaining}"
+                elif any(add_qkv in k for add_qkv in ["add_q_proj", "add_k_proj", "add_v_proj"]):
+                    remaining = k.split("attn_")[-1]
+                    diffusers_key += f".attn.{remaining}"
+
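+            # a key that still equals the bare block prefix matched none of the
+            # submodule patterns above; trace it before converting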
+            if diffusers_key == f"transformer_blocks.{i}":
+                print(k, diffusers_key)
+            _convert(k, diffusers_key, state_dict, new_state_dict)
+
+        if len(state_dict) > 0:
+            raise ValueError(
+                f"Expected an empty state dict at this point but it has these keys which couldn't be parsed: {list(state_dict.keys())}."
+            )
+
+        new_state_dict = {f"transformer.{k}": v for k, v in new_state_dict.items()}
+        return new_state_dict
+
+    # This is weird.
+    # https://huggingface.co/sayakpaul/different-lora-from-civitai/tree/main?show_file_info=sharp_detailed_foot.safetensors
+    # has both `peft` and non-peft state dict.
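+    # when a peft-style copy exists, keep it and discard the duplicate non-peft keys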
+    has_peft_state_dict = any(k.startswith("transformer.") for k in state_dict)
+    if has_peft_state_dict:
+        state_dict = {k: v for k, v in state_dict.items() if k.startswith("transformer.")}
+        return state_dict
+    # Another weird one: "lora_transformer_"-prefixed keys in the kohya lora_down/lora_up/alpha format.
+    has_mixture = any(
+        k.startswith("lora_transformer_") and ("lora_down" in k or "lora_up" in k or "alpha" in k) for k in state_dict
+    )
+    if has_mixture:
+        return _convert_mixture_state_dict_to_diffusers(state_dict)
     return _convert_sd_scripts_to_ai_toolkit(state_dict)
 
 