diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py
index e064aeba43b6..72daccfe5aec 100644
--- a/src/diffusers/loaders/lora_conversion_utils.py
+++ b/src/diffusers/loaders/lora_conversion_utils.py
@@ -519,7 +519,7 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
         remaining_keys = list(sds_sd.keys())
         te_state_dict = {}
         if remaining_keys:
-            if not all(k.startswith("lora_te1") for k in remaining_keys):
+            if not all(k.startswith("lora_te") for k in remaining_keys):
                 raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}")
             for key in remaining_keys:
                 if not key.endswith("lora_down.weight"):
@@ -558,6 +558,88 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
         new_state_dict = {**ait_sd, **te_state_dict}
         return new_state_dict
 
+    def _convert_mixture_state_dict_to_diffusers(state_dict):
+        new_state_dict = {}
+
+        def _convert(original_key, diffusers_key, state_dict, new_state_dict):
+            down_key = f"{original_key}.lora_down.weight"
+            down_weight = state_dict.pop(down_key)
+            lora_rank = down_weight.shape[0]
+
+            up_weight_key = f"{original_key}.lora_up.weight"
+            up_weight = state_dict.pop(up_weight_key)
+
+            alpha_key = f"{original_key}.alpha"
+            alpha = state_dict.pop(alpha_key)
+
+            # scale weight by alpha and dim
+            scale = alpha / lora_rank
+            # calculate scale_down and scale_up: their product stays equal to
+            # `scale`, while powers of two are shifted between the two factors
+            # to keep the magnitudes of the rescaled matrices balanced
+            scale_down = scale
+            scale_up = 1.0
+            while scale_down * 2 < scale_up:
+                scale_down *= 2
+                scale_up /= 2
+            down_weight = down_weight * scale_down
+            up_weight = up_weight * scale_up
+
+            diffusers_down_key = f"{diffusers_key}.lora_A.weight"
+            new_state_dict[diffusers_down_key] = down_weight
+            new_state_dict[diffusers_down_key.replace(".lora_A.", ".lora_B.")] = up_weight
+
+        all_unique_keys = {
+            k.replace(".lora_down.weight", "").replace(".lora_up.weight", "").replace(".alpha", "") for k in state_dict
+        }
+        all_unique_keys = sorted(all_unique_keys)
+        assert all("lora_transformer_" in k for k in all_unique_keys), f"{all_unique_keys=}"
+
+        for k in all_unique_keys:
+            if k.startswith("lora_transformer_single_transformer_blocks_"):
+                i = int(k.split("lora_transformer_single_transformer_blocks_")[-1].split("_")[0])
+                diffusers_key = f"single_transformer_blocks.{i}"
+            elif k.startswith("lora_transformer_transformer_blocks_"):
+                i = int(k.split("lora_transformer_transformer_blocks_")[-1].split("_")[0])
+                diffusers_key = f"transformer_blocks.{i}"
+            else:
+                raise NotImplementedError
+
+            if "attn_" in k:
+                if "_to_out_0" in k:
+                    diffusers_key += ".attn.to_out.0"
+                elif "_to_add_out" in k:
+                    diffusers_key += ".attn.to_add_out"
+                elif any(qkv in k for qkv in ["to_q", "to_k", "to_v"]):
+                    remaining = k.split("attn_")[-1]
+                    diffusers_key += f".attn.{remaining}"
+                elif any(add_qkv in k for add_qkv in ["add_q_proj", "add_k_proj", "add_v_proj"]):
+                    remaining = k.split("attn_")[-1]
+                    diffusers_key += f".attn.{remaining}"
+
+            # Flag keys that could not be mapped to a more specific submodule.
+            if diffusers_key == f"transformer_blocks.{i}":
+                logger.debug(f"Unexpected key {k} mapped to bare {diffusers_key}.")
+            _convert(k, diffusers_key, state_dict, new_state_dict)
+
+        if len(state_dict) > 0:
+            raise ValueError(
+                f"Expected an empty state dict at this point but it has these keys which couldn't be parsed: {list(state_dict.keys())}."
+            )
+
+        new_state_dict = {f"transformer.{k}": v for k, v in new_state_dict.items()}
+        return new_state_dict
+
+    # This is weird.
+    # https://huggingface.co/sayakpaul/different-lora-from-civitai/tree/main?show_file_info=sharp_detailed_foot.safetensors
+    # has both `peft` and non-peft state dict.
+    has_peft_state_dict = any(k.startswith("transformer.") for k in state_dict)
+    if has_peft_state_dict:
+        state_dict = {k: v for k, v in state_dict.items() if k.startswith("transformer.")}
+        return state_dict
+
+    # Another weird one.
+    has_mixture = any(
+        k.startswith("lora_transformer_") and ("lora_down" in k or "lora_up" in k or "alpha" in k) for k in state_dict
+    )
+    if has_mixture:
+        return _convert_mixture_state_dict_to_diffusers(state_dict)
+
     return _convert_sd_scripts_to_ai_toolkit(state_dict)
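
A note on the `scale_down` / `scale_up` computation in `_convert`: since the converted checkpoint no longer carries `alpha`, the effective LoRA scale `alpha / rank` has to be folded into the weights themselves. Splitting the factor across both matrices, rather than multiplying one of them by the full scale, keeps the two weights at comparable magnitudes, which is friendlier to low-precision storage. A minimal sketch of the invariant (all names here are illustrative, not part of the PR):

```python
import torch

def split_scale(scale: float):
    # Mirrors the loop in `_convert` above: the product scale_down * scale_up
    # always equals `scale`; powers of two are shifted from scale_up to
    # scale_down while scale_down lags behind by more than a factor of two.
    scale_down, scale_up = scale, 1.0
    while scale_down * 2 < scale_up:
        scale_down *= 2
        scale_up /= 2
    return scale_down, scale_up

rank, dim, alpha = 16, 64, 4.0
down = torch.randn(rank, dim)  # lora_down / lora_A
up = torch.randn(dim, rank)    # lora_up / lora_B

scale = alpha / rank
scale_down, scale_up = split_scale(scale)

# Rescaling each matrix separately is equivalent to scaling their product.
assert torch.allclose((up * scale_up) @ (down * scale_down), scale * (up @ down), atol=1e-5)
```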
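
The loop over `all_unique_keys` is essentially a renaming pass from the kohya-style flat names to diffusers module paths. For a hypothetical key (the naming pattern is the one the PR handles; the concrete key is made up):

```python
# Hypothetical "mixture" key following the lora_transformer_* pattern above.
k = "lora_transformer_single_transformer_blocks_7_attn_to_q"

i = int(k.split("lora_transformer_single_transformer_blocks_")[-1].split("_")[0])
diffusers_key = f"single_transformer_blocks.{i}"
if "attn_" in k:
    diffusers_key += f".attn.{k.split('attn_')[-1]}"

# `_convert` would then emit (with the "transformer." prefix added at the end):
#   transformer.single_transformer_blocks.7.attn.to_q.lora_A.weight
#   transformer.single_transformer_blocks.7.attn.to_q.lora_B.weight
assert diffusers_key == "single_transformer_blocks.7.attn.to_q"
```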
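
For context, none of this is called directly by users; it runs inside `load_lora_weights` when a Flux LoRA in one of these formats is detected. Loading the mixed-format checkpoint referenced in the comment would look roughly like this (standard diffusers API; the base model and prompt are just examples):

```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# This checkpoint ships both a peft-style ("transformer.*") and a kohya-style
# state dict; the new detection branch keeps only the peft-style keys.
pipe.load_lora_weights(
    "sayakpaul/different-lora-from-civitai", weight_name="sharp_detailed_foot.safetensors"
)
image = pipe("your prompt here", num_inference_steps=28).images[0]
```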