Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LoRA] fix peft state dict parsing #10532

Merged
merged 7 commits into from
Feb 10, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 83 additions & 1 deletion src/diffusers/loaders/lora_conversion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
remaining_keys = list(sds_sd.keys())
te_state_dict = {}
if remaining_keys:
if not all(k.startswith("lora_te1") for k in remaining_keys):
if not all(k.startswith("lora_te") for k in remaining_keys):
raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}")
for key in remaining_keys:
if not key.endswith("lora_down.weight"):
Expand Down Expand Up @@ -558,6 +558,88 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
new_state_dict = {**ait_sd, **te_state_dict}
return new_state_dict

def _convert_mixture_state_dict_to_diffusers(state_dict):
new_state_dict = {}

def _convert(original_key, diffusers_key, state_dict, new_state_dict):
down_key = f"{original_key}.lora_down.weight"
down_weight = state_dict.pop(down_key)
lora_rank = down_weight.shape[0]

up_weight_key = f"{original_key}.lora_up.weight"
up_weight = state_dict.pop(up_weight_key)

alpha_key = f"{original_key}.alpha"
alpha = state_dict.pop(alpha_key)

# scale weight by alpha and dim
scale = alpha / lora_rank
# calculate scale_down and scale_up
scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2
down_weight = down_weight * scale_down
up_weight = up_weight * scale_up

diffusers_down_key = f"{diffusers_key}.lora_A.weight"
new_state_dict[diffusers_down_key] = down_weight
new_state_dict[diffusers_down_key.replace(".lora_A.", ".lora_B.")] = up_weight

all_unique_keys = {
k.replace(".lora_down.weight", "").replace(".lora_up.weight", "").replace(".alpha", "") for k in state_dict
}
all_unique_keys = sorted(all_unique_keys)
assert all("lora_transformer_" in k for k in all_unique_keys), f"{all_unique_keys=}"

for k in all_unique_keys:
if k.startswith("lora_transformer_single_transformer_blocks_"):
i = int(k.split("lora_transformer_single_transformer_blocks_")[-1].split("_")[0])
diffusers_key = f"single_transformer_blocks.{i}"
elif k.startswith("lora_transformer_transformer_blocks_"):
i = int(k.split("lora_transformer_transformer_blocks_")[-1].split("_")[0])
diffusers_key = f"transformer_blocks.{i}"
else:
raise NotImplementedError

if "attn_" in k:
if "_to_out_0" in k:
diffusers_key += ".attn.to_out.0"
elif "_to_add_out" in k:
diffusers_key += ".attn.to_add_out"
elif any(qkv in k for qkv in ["to_q", "to_k", "to_v"]):
remaining = k.split("attn_")[-1]
diffusers_key += f".attn.{remaining}"
elif any(add_qkv in k for add_qkv in ["add_q_proj", "add_k_proj", "add_v_proj"]):
remaining = k.split("attn_")[-1]
diffusers_key += f".attn.{remaining}"

if diffusers_key == f"transformer_blocks.{i}":
print(k, diffusers_key)
_convert(k, diffusers_key, state_dict, new_state_dict)

if len(state_dict) > 0:
raise ValueError(
f"Expected an empty state dict at this point but its has these keys which couldn't be parsed: {list(state_dict.keys())}."
)

new_state_dict = {f"transformer.{k}": v for k, v in new_state_dict.items()}
return new_state_dict

# This is weird.
# https://huggingface.co/sayakpaul/different-lora-from-civitai/tree/main?show_file_info=sharp_detailed_foot.safetensors
# has both `peft` and non-peft state dict.
has_peft_state_dict = any(k.startswith("transformer.") for k in state_dict)
if has_peft_state_dict:
state_dict = {k: v for k, v in state_dict.items() if k.startswith("transformer.")}
return state_dict
# Another weird one.
has_mixture = any(
k.startswith("lora_transformer_") and ("lora_down" in k or "lora_up" in k or "alpha" in k) for k in state_dict
)
if has_mixture:
return _convert_mixture_state_dict_to_diffusers(state_dict)
return _convert_sd_scripts_to_ai_toolkit(state_dict)


Expand Down
Loading