diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py
index c440526b9b9..99f70646db2 100644
--- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py
+++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py
@@ -13,6 +13,12 @@ def linear_lora_forward(input: torch.Tensor, lora_layer: LoRALayer, lora_weight: float) -> torch.Tensor:
     """An optimized implementation of the residual calculation for a sidecar linear LoRALayer."""
+    # The up and down matrices have different ranks, so they can't simply be multiplied together.
+    if lora_layer.up.shape[1] != lora_layer.down.shape[0]:
+        x = torch.nn.functional.linear(input, lora_layer.get_weight(lora_weight), bias=lora_layer.bias)
+        x *= lora_weight * lora_layer.scale()
+        return x
+
     x = torch.nn.functional.linear(input, lora_layer.down)
     if lora_layer.mid is not None:
         x = torch.nn.functional.linear(x, lora_layer.mid)
diff --git a/invokeai/backend/patches/layers/lora_layer.py b/invokeai/backend/patches/layers/lora_layer.py
index c9210dce933..34183b005b8 100644
--- a/invokeai/backend/patches/layers/lora_layer.py
+++ b/invokeai/backend/patches/layers/lora_layer.py
@@ -19,6 +19,7 @@ def __init__(
         self.up = up
         self.mid = mid
         self.down = down
+        self.are_ranks_equal = up.shape[1] == down.shape[0]

     @classmethod
     def from_state_dict_values(
@@ -58,12 +59,44 @@
     def _rank(self) -> int:
         return self.down.shape[0]

+    def fuse_weights(self, up: torch.Tensor, down: torch.Tensor) -> torch.Tensor:
+        """
+        Fuse the up and down weight matrices of a LoRA layer whose ranks differ.
+
+        Since the HuggingFace implementation fuses the QKV projections, converting to the Kohya format can
+        produce LoRA up and down matrices with different ranks. This function handles the fusion of these
+        differently sized matrices.
+ """ + + fused_lora = torch.zeros( + (up.shape[0], down.shape[1]), device=down.device, dtype=down.dtype + ) + rank_diff = down.shape[0]/up.shape[1] + + if rank_diff > 1: + rank_diff = down.shape[0]/up.shape[1] + w_down = down.chunk(int(rank_diff), dim=0) + for w_down_chunk in w_down: + fused_lora = fused_lora + (torch.mm(up, w_down_chunk)) + else: + rank_diff = up.shape[1]/down.shape[0] + w_up = up.chunk(int(rank_diff), dim=0) + for w_up_chunk in w_up: + fused_lora = fused_lora + (torch.mm(w_up_chunk, down)) + + return fused_lora + def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor: if self.mid is not None: up = self.up.reshape(self.up.shape[0], self.up.shape[1]) down = self.down.reshape(self.down.shape[0], self.down.shape[1]) weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down) else: + # up matrix and down matrix have different ranks so we can't simply multiply them + if not self.are_ranks_equal: + weight = self.fuse_weights(self.up, self.down) + return weight + weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1) return weight diff --git a/invokeai/backend/patches/lora_conversions/flux_kohya_lora_conversion_utils.py b/invokeai/backend/patches/lora_conversions/flux_kohya_lora_conversion_utils.py index 41e41dbb517..7b5f3468963 100644 --- a/invokeai/backend/patches/lora_conversions/flux_kohya_lora_conversion_utils.py +++ b/invokeai/backend/patches/lora_conversions/flux_kohya_lora_conversion_utils.py @@ -20,6 +20,14 @@ FLUX_KOHYA_TRANSFORMER_KEY_REGEX = ( r"lora_unet_(\w+_blocks)_(\d+)_(img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear1|linear2|modulation)_?(.*)" ) + +# A regex pattern that matches all of the last layer keys in the Kohya FLUX LoRA format. +# Example keys: +# lora_unet_final_layer_linear.alpha +# lora_unet_final_layer_linear.lora_down.weight +# lora_unet_final_layer_linear.lora_up.weight +FLUX_KOHYA_LAST_LAYER_KEY_REGEX = r"lora_unet_final_layer_(linear|linear1|linear2)_?(.*)" + # A regex pattern that matches all of the CLIP keys in the Kohya FLUX LoRA format. # Example keys: # lora_te1_text_model_encoder_layers_0_mlp_fc1.alpha @@ -44,6 +52,7 @@ def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> boo """ return all( re.match(FLUX_KOHYA_TRANSFORMER_KEY_REGEX, k) + or re.match(FLUX_KOHYA_LAST_LAYER_KEY_REGEX, k) or re.match(FLUX_KOHYA_CLIP_KEY_REGEX, k) or re.match(FLUX_KOHYA_T5_KEY_REGEX, k) for k in state_dict.keys() @@ -65,6 +74,9 @@ def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) - t5_grouped_sd: dict[str, dict[str, torch.Tensor]] = {} for layer_name, layer_state_dict in grouped_state_dict.items(): if layer_name.startswith("lora_unet"): + # Skip the final layer. This is incompatible with current model definition. + if layer_name.startswith("lora_unet_final_layer"): + continue transformer_grouped_sd[layer_name] = layer_state_dict elif layer_name.startswith("lora_te1"): clip_grouped_sd[layer_name] = layer_state_dict