@@ -519,7 +519,7 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
         remaining_keys = list(sds_sd.keys())
         te_state_dict = {}
         if remaining_keys:
-            if not all(k.startswith("lora_te1") for k in remaining_keys):
+            if not all(k.startswith("lora_te") for k in remaining_keys):
                 raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}")
             for key in remaining_keys:
                 if not key.endswith("lora_down.weight"):
@@ -558,6 +558,88 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
         new_state_dict = {**ait_sd, **te_state_dict}
         return new_state_dict

+    def _convert_mixture_state_dict_to_diffusers(state_dict):
+        new_state_dict = {}
+
+        def _convert(original_key, diffusers_key, state_dict, new_state_dict):
+            down_key = f"{original_key}.lora_down.weight"
+            down_weight = state_dict.pop(down_key)
+            lora_rank = down_weight.shape[0]
+
+            up_weight_key = f"{original_key}.lora_up.weight"
+            up_weight = state_dict.pop(up_weight_key)
+
+            alpha_key = f"{original_key}.alpha"
+            alpha = state_dict.pop(alpha_key)
+
+            # scale weight by alpha and dim
+            scale = alpha / lora_rank
+            # calculate scale_down and scale_up
+            scale_down = scale
+            scale_up = 1.0
+            while scale_down * 2 < scale_up:
+                scale_down *= 2
+                scale_up /= 2
+            down_weight = down_weight * scale_down
+            up_weight = up_weight * scale_up
+
+            diffusers_down_key = f"{diffusers_key}.lora_A.weight"
+            new_state_dict[diffusers_down_key] = down_weight
+            new_state_dict[diffusers_down_key.replace(".lora_A.", ".lora_B.")] = up_weight
+
+        all_unique_keys = {
+            k.replace(".lora_down.weight", "").replace(".lora_up.weight", "").replace(".alpha", "") for k in state_dict
+        }
+        all_unique_keys = sorted(all_unique_keys)
+        assert all("lora_transformer_" in k for k in all_unique_keys), f"{all_unique_keys=}"
+
+        for k in all_unique_keys:
+            if k.startswith("lora_transformer_single_transformer_blocks_"):
+                i = int(k.split("lora_transformer_single_transformer_blocks_")[-1].split("_")[0])
+                diffusers_key = f"single_transformer_blocks.{i}"
+            elif k.startswith("lora_transformer_transformer_blocks_"):
+                i = int(k.split("lora_transformer_transformer_blocks_")[-1].split("_")[0])
+                diffusers_key = f"transformer_blocks.{i}"
+            else:
+                raise NotImplementedError
+
+            if "attn_" in k:
+                if "_to_out_0" in k:
+                    diffusers_key += ".attn.to_out.0"
+                elif "_to_add_out" in k:
+                    diffusers_key += ".attn.to_add_out"
+                elif any(qkv in k for qkv in ["to_q", "to_k", "to_v"]):
+                    remaining = k.split("attn_")[-1]
+                    diffusers_key += f".attn.{remaining}"
+                elif any(add_qkv in k for add_qkv in ["add_q_proj", "add_k_proj", "add_v_proj"]):
+                    remaining = k.split("attn_")[-1]
+                    diffusers_key += f".attn.{remaining}"
+
+            if diffusers_key == f"transformer_blocks.{i}":
+                print(k, diffusers_key)
+            _convert(k, diffusers_key, state_dict, new_state_dict)
+
+        if len(state_dict) > 0:
+            raise ValueError(
+                f"Expected an empty state dict at this point but it has these keys which couldn't be parsed: {list(state_dict.keys())}."
+            )
+
+        new_state_dict = {f"transformer.{k}": v for k, v in new_state_dict.items()}
+        return new_state_dict
+
+    # This is weird.
+    # https://huggingface.co/sayakpaul/different-lora-from-civitai/tree/main?show_file_info=sharp_detailed_foot.safetensors
+    # has both `peft` and non-peft state dict.
+    has_peft_state_dict = any(k.startswith("transformer.") for k in state_dict)
+    if has_peft_state_dict:
+        state_dict = {k: v for k, v in state_dict.items() if k.startswith("transformer.")}
+        return state_dict
+    # Another weird one.
+    has_mixture = any(
+        k.startswith("lora_transformer_") and ("lora_down" in k or "lora_up" in k or "alpha" in k) for k in state_dict
+    )
+    if has_mixture:
+        return _convert_mixture_state_dict_to_diffusers(state_dict)
     return _convert_sd_scripts_to_ai_toolkit(state_dict)
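A note on the alpha scaling inside `_convert`: the effective LoRA delta for a layer is `(alpha / rank) * lora_up @ lora_down`. Instead of multiplying one factor by the full scale, the doubling loop splits it into `scale_down * scale_up == alpha / rank` in powers of two, so the product baked into the two matrices is unchanged while their magnitudes stay balanced. Below is a minimal standalone sketch of that invariant; the shapes and the `split_scale` helper are hypothetical and only mirror the loop above.

import torch

def split_scale(alpha: float, rank: int):
    # Mirror of the loop in _convert: distribute alpha / rank across the two
    # LoRA factors in powers of two while preserving their product.
    scale = alpha / rank
    scale_down, scale_up = scale, 1.0
    while scale_down * 2 < scale_up:
        scale_down *= 2
        scale_up /= 2
    return scale_down, scale_up

rank, features, alpha = 16, 64, 4.0
down = torch.randn(rank, features)  # hypothetical lora_down.weight
up = torch.randn(features, rank)    # hypothetical lora_up.weight

scale_down, scale_up = split_scale(alpha, rank)
assert scale_down * scale_up == alpha / rank

# Baking the split scale into each factor leaves the overall delta unchanged.
reference = (alpha / rank) * (up @ down)
rescaled = (up * scale_up) @ (down * scale_down)
assert torch.allclose(reference, rescaled)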
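For a concrete picture of the renaming done by the mixture branch (selected when keys start with `lora_transformer_` and carry `lora_down`/`lora_up`/`alpha` suffixes, while PEFT-style dicts with `transformer.`-prefixed keys are passed through after filtering), here is a small illustrative re-derivation of the mapping for one single-block attention key. The key name and the helper are made up for illustration; they are not part of the commit or the library API.

def map_mixture_key(k: str) -> str:
    # Illustrative re-derivation of the mapping implemented above for
    # "lora_transformer_single_transformer_blocks_<i>_attn_<proj>" keys.
    prefix = "lora_transformer_single_transformer_blocks_"
    assert k.startswith(prefix)
    i = int(k[len(prefix):].split("_")[0])
    proj = k.split("attn_")[-1]
    return f"single_transformer_blocks.{i}.attn.{proj}"

assert map_mixture_key("lora_transformer_single_transformer_blocks_7_attn_to_q") == (
    "single_transformer_blocks.7.attn.to_q"
)
# With the lora_A/lora_B suffixes added in _convert and the "transformer." prefix
# added at the end of _convert_mixture_state_dict_to_diffusers, the final
# state-dict entries for this layer would be named:
#   transformer.single_transformer_blocks.7.attn.to_q.lora_A.weight  (from lora_down)
#   transformer.single_transformer_blocks.7.attn.to_q.lora_B.weight  (from lora_up)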