diff --git a/docs/source/package_reference/hira.md b/docs/source/package_reference/hira.md
new file mode 100644
index 0000000000..e29a43bd1a
--- /dev/null
+++ b/docs/source/package_reference/hira.md
@@ -0,0 +1,90 @@
+# HiRA
+
+Hadamard High-Rank Adaptation ([HiRA](https://openreview.net/pdf?id=TwJrTz9cRS)) is a PEFT method that extends the LoRA approach by applying an element-wise modulation to the original weight matrix. Instead of adding a low-rank update directly, HiRA computes:
+
+$$
+W' = W_0 + W_0 \odot (B A)
+$$
+
+where $W_0$ is the frozen base weight, $A \in \mathbb{R}^{r \times \text{in_features}}$ and $B \in \mathbb{R}^{\text{out_features} \times r}$ are trainable low-rank factors with rank $r \ll \min(\text{in_features}, \text{out_features})$, and $\odot$ denotes the Hadamard (element-wise) product. Because the update multiplicatively modulates $W_0$ rather than adding a low-rank term, the effective update $W_0 \odot (B A)$ can be high-rank even though only the low-rank factors are trained, which often improves expressiveness over LoRA at the same parameter cost. Merging an adapter simply rescales the base weight to $W_0 \odot (1 + B A)$.
+
+The abstract from the HiRA paper is:
+
+> *We propose Hadamard High-Rank Adaptation (HiRA), a parameter-efficient fine-tuning (PEFT) method that enhances the adaptability of Large Language Models (LLMs). While Low-rank Adaptation (LoRA) is widely used to reduce resource demands, its low-rank updates may limit its expressiveness for new tasks. HiRA addresses this by using a Hadamard product to retain high-rank update parameters, improving the model capacity. Empirically, HiRA outperforms LoRA and its variants on several tasks, with extensive ablation studies validating its effectiveness.*
+
+
+## Examples
+
+```python
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from peft import HiRAConfig, get_peft_model
+
+# Example 1: HiRA on opt-125m for causal language modeling
+model_id = "facebook/opt-125m"
+base_model = AutoModelForCausalLM.from_pretrained(model_id)
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+# Define the HiRA configuration: apply to the attention projections and MLP dense layers in each transformer block
+hira_config = HiRAConfig(
+    r=32,
+    target_modules=["k_proj", "q_proj", "v_proj", "fc1", "fc2"],
+    hira_dropout=0.0,
+    init_hira_weights=True,
+)
+peft_model = get_peft_model(base_model, hira_config)
+
+peft_model.print_trainable_parameters()
+# trainable params: 4,718,592 || all params: 129,957,888 || trainable%: 3.6309
+```
+
+## HiRAConfig
+
+[[autodoc]] tuners.hira.config.HiRAConfig
+
+## Core Layers
+
+### HiRALayer
+
+[[autodoc]] tuners.hira.layer.HiRALayer
+
+### Linear Adapter
+
+[[autodoc]] tuners.hira.layer.Linear
+
+### Embedding Adapter
+
+[[autodoc]] tuners.hira.layer.Embedding
+
+### Convolutional Adapters
+
+[[autodoc]] tuners.hira.layer.Conv2d
+
+[[autodoc]] tuners.hira.layer.Conv3d
+
+## BitsAndBytes Integration
+
+* **8-bit Quantized**: [[autodoc]] tuners.hira.bnb.Linear8bitLt
+* **4-bit Quantized**: [[autodoc]] tuners.hira.bnb.Linear4bit
+* **Dispatch Utilities**:
+
+  * [[autodoc]] tuners.hira.bnb.dispatch_bnb_8bit
+  * [[autodoc]] tuners.hira.bnb.dispatch_bnb_4bit
+
+## Dispatch Handler
+
+Default layer replacement for HiRA adapters:
+
+[[autodoc]] tuners.hira.dispatch.dispatch_default
+
+
+## Citation
+
+If you find HiRA useful, please cite it as follows:
+
+```bibtex
+@inproceedings{
+huang2025hira,
+title={Hi{RA}: Parameter-Efficient Hadamard High-Rank Adaptation for Large Language Models},
+author={Qiushi Huang and Tom Ko and Zhan Zhuang and Lilian Tang and Yu Zhang},
+booktitle={The Thirteenth International Conference on Learning Representations},
+year={2025},
+url={https://openreview.net/forum?id=TwJrTz9cRS}
+}
+```
diff --git a/method_comparison/sanitizer.py b/method_comparison/sanitizer.py
index
7659d650c0..099133da44 100644 --- a/method_comparison/sanitizer.py +++ b/method_comparison/sanitizer.py @@ -26,14 +26,14 @@ def _evaluate_node(df, node): raise ValueError("Right side of comparison must be a literal (number, string, list).") operator_map = { - ast.Gt: lambda c, v: df[c] > v, - ast.GtE: lambda c, v: df[c] >= v, - ast.Lt: lambda c, v: df[c] < v, - ast.LtE: lambda c, v: df[c] <= v, - ast.Eq: lambda c, v: df[c] == v, + ast.Gt: lambda c, v: df[c] > v, + ast.GtE: lambda c, v: df[c] >= v, + ast.Lt: lambda c, v: df[c] < v, + ast.LtE: lambda c, v: df[c] <= v, + ast.Eq: lambda c, v: df[c] == v, ast.NotEq: lambda c, v: df[c] != v, - ast.In: lambda c, v: df[c].isin(v), - ast.NotIn: lambda c, v: ~df[c].isin(v) + ast.In: lambda c, v: df[c].isin(v), + ast.NotIn: lambda c, v: ~df[c].isin(v), } op_type = type(op_node) if op_type not in operator_map: @@ -90,7 +90,7 @@ def parse_and_filter(df, filter_str): try: # 'eval' mode ensures the source is a single expression. - tree = ast.parse(filter_str, mode='eval') + tree = ast.parse(filter_str, mode="eval") expression_node = tree.body except (SyntaxError, ValueError) as e: raise ValueError(f"Invalid filter syntax: {e}") diff --git a/method_comparison/test_sanitizer.py b/method_comparison/test_sanitizer.py index 59c0dd191e..6dc7bd63b2 100644 --- a/method_comparison/test_sanitizer.py +++ b/method_comparison/test_sanitizer.py @@ -7,32 +7,34 @@ @pytest.fixture def df_products(): data = { - 'product_id': [101, 102, 103, 104, 105, 106], - 'category': ['Electronics', 'Books', 'Electronics', 'Home Goods', 'Books', 'Electronics'], - 'price': [799.99, 19.99, 49.50, 120.00, 24.99, 150.00], - 'stock': [15, 300, 50, 25, 150, 0] + "product_id": [101, 102, 103, 104, 105, 106], + "category": ["Electronics", "Books", "Electronics", "Home Goods", "Books", "Electronics"], + "price": [799.99, 19.99, 49.50, 120.00, 24.99, 150.00], + "stock": [15, 300, 50, 25, 150, 0], } return pd.DataFrame(data) def test_exploit_fails(df_products): with pytest.raises(ValueError) as e: - mask1 = parse_and_filter(df_products, - """price < 50 and @os.system("/bin/echo password")""") - assert 'Invalid filter syntax' in str(e) + mask1 = parse_and_filter(df_products, """price < 50 and @os.system("/bin/echo password")""") + assert "Invalid filter syntax" in str(e) -@pytest.mark.parametrize('expression,ids', [ - ("price < 50", [102, 103, 105]), - ("product_id in [101, 102]", [101, 102]), - ("price < 50 and category == 'Electronics'", [103]), - ("stock < 100 or category == 'Home Goods'", [101, 103, 104, 106]), - ("(price > 100 and stock < 20) or category == 'Books'", [101, 102, 105, 106]), - ("not (price > 50 or stock > 100)", [103]), - ("not price > 50", [102, 103, 105]), - ("(price < 50) & (category == 'Electronics')", [103]), - ("(stock < 100) | (category == 'Home Goods')", [101, 103, 104, 106]), -]) +@pytest.mark.parametrize( + "expression,ids", + [ + ("price < 50", [102, 103, 105]), + ("product_id in [101, 102]", [101, 102]), + ("price < 50 and category == 'Electronics'", [103]), + ("stock < 100 or category == 'Home Goods'", [101, 103, 104, 106]), + ("(price > 100 and stock < 20) or category == 'Books'", [101, 102, 105, 106]), + ("not (price > 50 or stock > 100)", [103]), + ("not price > 50", [102, 103, 105]), + ("(price < 50) & (category == 'Electronics')", [103]), + ("(stock < 100) | (category == 'Home Goods')", [101, 103, 104, 106]), + ], +) def test_operations(df_products, expression, ids): mask1 = parse_and_filter(df_products, expression) assert sorted(df_products[mask1].product_id) 
== sorted(ids)
diff --git a/src/peft/__init__.py b/src/peft/__init__.py
index b2fcbe901f..397cddc098 100644
--- a/src/peft/__init__.py
+++ b/src/peft/__init__.py
@@ -61,6 +61,9 @@
     EvaConfig,
     FourierFTConfig,
     FourierFTModel,
+    HiRAConfig,
+    HiRAModel,
+    HiRARuntimeConfig,
     HRAConfig,
     HRAModel,
     IA3Config,
@@ -149,6 +152,9 @@
     "FourierFTModel",
     "HRAConfig",
     "HRAModel",
+    "HiRAConfig",
+    "HiRAModel",
+    "HiRARuntimeConfig",
     "IA3Config",
     "IA3Model",
     "LNTuningConfig",
diff --git a/src/peft/tuners/__init__.py b/src/peft/tuners/__init__.py
index f758499e12..051fedf503 100644
--- a/src/peft/tuners/__init__.py
+++ b/src/peft/tuners/__init__.py
@@ -19,6 +19,11 @@
 from .c3a import C3AConfig, C3AModel
 from .cpt import CPTConfig, CPTEmbedding
 from .fourierft import FourierFTConfig, FourierFTModel
+from .hira import (
+    HiRAConfig,
+    HiRAModel,
+    HiRARuntimeConfig,
+)
 from .hra import HRAConfig, HRAModel
 from .ia3 import IA3Config, IA3Model
 from .ln_tuning import LNTuningConfig, LNTuningModel
@@ -66,6 +71,9 @@
     "FourierFTModel",
     "HRAConfig",
     "HRAModel",
+    "HiRAConfig",
+    "HiRAModel",
+    "HiRARuntimeConfig",
     "IA3Config",
     "IA3Model",
     "LNTuningConfig",
diff --git a/src/peft/tuners/hira/__init__.py b/src/peft/tuners/hira/__init__.py
new file mode 100644
index 0000000000..aaf228afe6
--- /dev/null
+++ b/src/peft/tuners/hira/__init__.py
@@ -0,0 +1,55 @@
+# Copyright 2023-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from peft.import_utils import is_bnb_4bit_available, is_bnb_available
+from peft.utils import register_peft_method
+
+from .config import HiRAConfig, HiRARuntimeConfig
+from .layer import Conv2d, Conv3d, Embedding, HiRALayer, Linear
+from .model import HiRAModel
+
+
+__all__ = [
+    "Conv2d",
+    "Conv3d",
+    "Embedding",
+    "HiRAConfig",
+    "HiRALayer",
+    "HiRAModel",
+    "HiRARuntimeConfig",
+    "Linear",
+]
+
+register_peft_method(name="hira", config_cls=HiRAConfig, model_cls=HiRAModel, is_mixed_compatible=True)
+
+
+def __getattr__(name):
+    if (name == "Linear8bitLt") and is_bnb_available():
+        from .bnb import Linear8bitLt
+
+        return Linear8bitLt
+
+    if (name == "Linear4bit") and is_bnb_4bit_available():
+        from .bnb import Linear4bit
+
+        return Linear4bit
+
+    raise AttributeError(f"module {__name__} has no attribute {name}")
diff --git a/src/peft/tuners/hira/bnb.py b/src/peft/tuners/hira/bnb.py
new file mode 100644
index 0000000000..cea9a1d535
--- /dev/null
+++ b/src/peft/tuners/hira/bnb.py
@@ -0,0 +1,484 @@
+# Copyright 2023-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import warnings +from typing import Any, Optional + +import bitsandbytes as bnb +import torch +import torch.nn.functional as F + +from peft.import_utils import is_bnb_4bit_available, is_bnb_available +from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge +from peft.utils.integrations import dequantize_bnb_weight +from peft.utils.other import transpose + +from .layer import HiRALayer + + +if is_bnb_available(): + + class Linear8bitLt(torch.nn.Module, HiRALayer): + # HiRA implemented in a dense layer + def __init__( + self, + base_layer: torch.nn.Module, + adapter_name: str, + r: int = 0, + hira_dropout: float = 0, + init_hira_weights: bool = True, + **kwargs, + ) -> None: + super().__init__() + HiRALayer.__init__(self, base_layer) + self.fan_in_fan_out = False + + self._active_adapter = adapter_name + self.update_layer( + adapter_name, + r, + hira_dropout=hira_dropout, + init_hira_weights=init_hira_weights, + ) + + def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + """ + Merge the active adapter weights into the base weights + + Args: + safe_merge (`bool`, *optional*): + If True, the merge operation will be performed in a copy of the original weights and check for NaNs + before merging the weights. This is useful if you want to check if the merge operation will produce + NaNs. Defaults to `False`. + adapter_names (`list[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. + Defaults to `None`. + """ + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + # no adapter to merge + return + + for active_adapter in adapter_names: + if active_adapter not in self.hira_A.keys(): + continue + + warnings.warn( + "Merge hira module to 8-bit linear may get different generations due to rounding errors." + ) + + weight = self.get_base_layer().weight + state = self.get_base_layer().state + if state.SCB is None: + state.SCB = weight.SCB + + # Dequantize the result of identity matrix and int8 weight because bitsandbytes does not support int8 + # dequantization directly + output = dequantize_bnb_weight(weight, state=state) + hira_data = self.get_delta_weight(active_adapter) + hira_data = output.to(hira_data.dtype).to(hira_data.device) * hira_data + w_data = output.to(hira_data.dtype).to(hira_data.device) + hira_data + + if safe_merge and not torch.isfinite(w_data).all(): + raise ValueError( + f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" + ) + + self.get_base_layer().weight = bnb.nn.Int8Params( + w_data.to("cpu"), requires_grad=False, has_fp16_weights=weight.has_fp16_weights + ).to(weight.device) + + state.reset_grads() + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + """ + This method unmerges all merged adapter layers from the base weights. + """ + if not self.merged: + warnings.warn("Already unmerged. 
Nothing to do.") + return + + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter not in self.hira_A.keys(): + continue + warnings.warn( + "Unmerge hira module to 8-bit linear may get different generations due to rounding errors." + ) + + weight = self.get_base_layer().weight + state = self.get_base_layer().state + if state.SCB is None: + state.SCB = weight.SCB + output = dequantize_bnb_weight(weight, state=state) + + hira_data = self.get_delta_weight(active_adapter) + w_data = output.to(hira_data.dtype).to(hira_data.device) / (1 + hira_data) + self.get_base_layer().weight = bnb.nn.Int8Params( + w_data.to("cpu"), requires_grad=False, has_fp16_weights=weight.has_fp16_weights + ).to(weight.device) + + state.reset_grads() + + def get_delta_weight(self, adapter): + return transpose( + self.hira_B[adapter] @ self.hira_A[adapter], + False, + ) + + def _mixed_batch_forward( + self, x: torch.Tensor, *args: Any, adapter_names: list[str], **kwargs: Any + ) -> torch.Tensor: + # This is a special method that handles the case when users pass the argument `adapter_names`. This is an + # extra argument that allows mixing different adapters in the same batch at inference time. + result = self.base_layer(x, *args, **kwargs) + + unique_adapters = set(adapter_names) + sub_batch_indices_list = [] + for adapter in unique_adapters: + sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter]) + + # dequantitize + weight = self.get_base_layer().weight + state = self.get_base_layer().state + if state.SCB is None: + state.SCB = weight.SCB + # Dequantize the result of identity matrix and int8 weight because bitsandbytes does not support int8 + # dequantization directly + dequant_w = dequantize_bnb_weight(weight, state=state) + for i, active_adapter in enumerate(unique_adapters): + if active_adapter == "__base__": + continue + if active_adapter not in self.hira_A.keys(): + continue + + hira_A = self.hira_A[active_adapter] + hira_B = self.hira_B[active_adapter] + dropout = self.hira_dropout[active_adapter] + + requires_conversion = not torch.is_autocast_enabled() + if requires_conversion: + expected_dtype = result.dtype + x = self._cast_input_dtype(x, hira_A.dtype) + + # getting the sub-batch, passing it to HiRA layers and updating the corresponding indices of the linear + # layer output + sub_batch = x[sub_batch_indices_list[i]] + sub_batch = dropout(sub_batch) + prod_AB = torch.mm(hira_A.T, hira_B.T) + + eff_weight = transpose(dequant_w.to(prod_AB.dtype).to(prod_AB.device), self.fan_in_fan_out) * prod_AB.T + hira_out = F.linear(sub_batch, eff_weight) + if requires_conversion: + hira_out = hira_out.to(expected_dtype) + result[sub_batch_indices_list[i]] += hira_out + + return result + + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + self._check_forward_args(x, *args, **kwargs) + adapter_names = kwargs.pop("adapter_names", None) + + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif adapter_names is not None: + result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + # 2) now compute the HiRA “delta” in float32 and add it: + # dequantize once per forward + raw_weight = self.get_base_layer().weight + state = self.get_base_layer().state + dequant_w = dequantize_bnb_weight(raw_weight, state=state) # float32 tensor [out, in] + result = 
self.base_layer(x, *args, **kwargs) + for active_adapter in self.active_adapters: + if active_adapter not in self.hira_A.keys(): + continue + hira_A = self.hira_A[active_adapter] + hira_B = self.hira_B[active_adapter] + dropout = self.hira_dropout[active_adapter] + + requires_conversion = not torch.is_autocast_enabled() + if requires_conversion: + expected_dtype = result.dtype + x = self._cast_input_dtype(x, hira_A.dtype) + dropout_sub = dropout(x) + # compute Δ = B @ A → shape (out, in) + prod_AB = hira_B @ hira_A + # effective weight = W₀ ⊙ Δ + eff_w = dequant_w * prod_AB + hira_out = F.linear(dropout_sub, eff_w) + if requires_conversion: + hira_out = hira_out.to(expected_dtype) + result = result + hira_out + if requires_conversion: + result = result.to(expected_dtype) + + return result + + def __repr__(self) -> str: + rep = super().__repr__() + return "hira." + rep + + def dispatch_bnb_8bit(target: torch.nn.Module, adapter_name: str, **kwargs): + new_module = None + + if isinstance(target, BaseTunerLayer): + target_base_layer = target.get_base_layer() + else: + target_base_layer = target + + loaded_in_8bit = kwargs.get("loaded_in_8bit", False) + if loaded_in_8bit and isinstance(target_base_layer, bnb.nn.Linear8bitLt): + eightbit_kwargs = kwargs.copy() + eightbit_kwargs.update( + { + "has_fp16_weights": target.state.has_fp16_weights, + "threshold": target.state.threshold, + "index": target.index, + } + ) + new_module = Linear8bitLt(target, adapter_name, **eightbit_kwargs) + + return new_module + + +if is_bnb_4bit_available(): + + class Linear4bit(torch.nn.Module, HiRALayer): + # HiRA implemented in a dense layer + def __init__( + self, + base_layer: torch.nn.Module, + adapter_name: str, + r: int = 0, + hira_dropout: float = 0.0, + init_hira_weights: bool = True, + **kwargs, + ) -> None: + super().__init__() + HiRALayer.__init__(self, base_layer) + self.fan_in_fan_out = False + + self._active_adapter = adapter_name + self.update_layer( + adapter_name, + r, + hira_dropout=hira_dropout, + init_hira_weights=init_hira_weights, + ) + + def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + """ + Merge the active adapter weights into the base weights + + Args: + safe_merge (`bool`, *optional*): + If True, the merge operation will be performed in a copy of the original weights and check for NaNs + before merging the weights. This is useful if you want to check if the merge operation will produce + NaNs. Defaults to `False`. + adapter_names (`list[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. + Defaults to `None`. + """ + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + # no adapter to merge + return + + for active_adapter in adapter_names: + if active_adapter not in self.hira_A.keys(): + continue + + warnings.warn( + "Merge hira module to 4-bit linear may get different generations due to rounding errors." + ) + # Refer to https://gist.github.com/ChrisHayduk/1a53463331f52dca205e55982baf9930 + weight = self.get_base_layer().weight + kwargs = weight.__dict__ + + output = dequantize_bnb_weight(weight, state=weight.quant_state) + hira_data = self.get_delta_weight(active_adapter) + w_data = output * (1 + hira_data) + + if safe_merge and not torch.isfinite(w_data).all(): + raise ValueError( + f"NaNs detected in the merged weights. 
The adapter {active_adapter} seems to be broken" + ) + + if "bnb_quantized" in kwargs: + kwargs["bnb_quantized"] = False + kwargs["requires_grad"] = False + kwargs.pop("data", None) + # torch.compile can introduce attributes preceded by '_', remove them + kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")} + self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), **kwargs).to(weight.device) + + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + """ + This method unmerges all merged adapter layers from the base weights. + """ + if not self.merged: + warnings.warn("Already unmerged. Nothing to do.") + return + + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter not in self.hira_A.keys(): + continue + warnings.warn( + "Unmerge hira module to 4-bit linear may get different generations due to rounding errors." + ) + + weight = self.get_base_layer().weight + kwargs = weight.__dict__ + output = dequantize_bnb_weight(weight, state=weight.quant_state) + + hira_data = self.get_delta_weight(active_adapter) + w_data = output / (1 + hira_data) + + if "bnb_quantized" in kwargs: + kwargs["bnb_quantized"] = False + kwargs["requires_grad"] = False + kwargs.pop("data", None) + self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), **kwargs).to(weight.device) + + def get_delta_weight(self, adapter): + return transpose( + self.hira_B[adapter] @ self.hira_A[adapter], + False, + ) + + def _mixed_batch_forward( + self, x: torch.Tensor, *args: Any, adapter_names: list[str], **kwargs: Any + ) -> torch.Tensor: + # This is a special method that handles the case when users pass the argument `adapter_names`. This is an + # extra argument that allows mixing different adapters in the same batch at inference time. 
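+            # The flow below mirrors the non-quantized HiRA linear layer: the base output is computed once
+            # for the whole batch, the 4-bit base weight is dequantized a single time, and then, for each
+            # sub-batch that shares an adapter, the dequantized weight is modulated element-wise by (B @ A)
+            # and the resulting HiRA contribution is added back at that sub-batch's row indices.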
+ result = self.base_layer(x, *args, **kwargs) + + unique_adapters = set(adapter_names) + sub_batch_indices_list = [] + for adapter in unique_adapters: + sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter]) + weight = self.get_base_layer().weight + dequant_w = dequantize_bnb_weight(weight, state=weight.quant_state) + for i, active_adapter in enumerate(unique_adapters): + if active_adapter == "__base__": + continue + if active_adapter not in self.hira_A.keys(): + continue + + hira_A = self.hira_A[active_adapter] + hira_B = self.hira_B[active_adapter] + dropout = self.hira_dropout[active_adapter] + + requires_conversion = not torch.is_autocast_enabled() + if requires_conversion: + expected_dtype = result.dtype + x = self._cast_input_dtype(x, hira_A.dtype) + + # getting the sub-batch, passing it to HiRA layers and updating the corresponding indices of the linear + # layer output + sub_batch = x[sub_batch_indices_list[i]] + prod_AB = torch.mm(hira_A.T, hira_B.T) + sub_batch = dropout(sub_batch) + eff_weight = transpose(dequant_w.to(prod_AB.dtype).to(prod_AB.device), self.fan_in_fan_out) * prod_AB.T + hira_out = F.linear(sub_batch, eff_weight) + if requires_conversion: + hira_out = hira_out.to(expected_dtype) + result[sub_batch_indices_list[i]] += hira_out + + return result + + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + self._check_forward_args(x, *args, **kwargs) + adapter_names = kwargs.pop("adapter_names", None) + + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif adapter_names is not None: + result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + result = self.base_layer(x, *args, **kwargs) + # As per Tim Dettmers, for 4bit, we need to defensively clone here. + # The reason is that in some cases, an error can occur that backprop + # does not work on a manipulated view. This issue may be solved with + # newer PyTorch versions but this would need extensive testing to be + # sure. + result = result.clone() + # dequanted w + weight = self.get_base_layer().weight + dequant_w = dequantize_bnb_weight(weight, state=weight.quant_state) + for active_adapter in self.active_adapters: + if active_adapter not in self.hira_A.keys(): + continue + hira_A = self.hira_A[active_adapter] + hira_B = self.hira_B[active_adapter] + dropout = self.hira_dropout[active_adapter] + + requires_conversion = not torch.is_autocast_enabled() + if requires_conversion: + expected_dtype = result.dtype + x = self._cast_input_dtype(x, hira_A.dtype) + + dropout_sub = dropout(x) + # compute Δ = B @ A → shape (out, in) + prod_AB = hira_B @ hira_A + # effective weight = W₀ ⊙ Δ + eff_w = dequant_w * prod_AB + hira_out = F.linear(dropout_sub, eff_w) + if requires_conversion: + hira_out = hira_out.to(expected_dtype) + result = result + hira_out + if requires_conversion: + result = result.to(expected_dtype) + + return result + + def __repr__(self) -> str: + rep = super().__repr__() + return "hira." 
+ rep + + def dispatch_bnb_4bit(target: torch.nn.Module, adapter_name: str, **kwargs): + new_module = None + + if isinstance(target, BaseTunerLayer): + target_base_layer = target.get_base_layer() + else: + target_base_layer = target + + loaded_in_4bit = kwargs.get("loaded_in_4bit", False) + if loaded_in_4bit and is_bnb_4bit_available() and isinstance(target_base_layer, bnb.nn.Linear4bit): + fourbit_kwargs = kwargs.copy() + fourbit_kwargs.update( + { + "compute_dtype": target_base_layer.compute_dtype, + "compress_statistics": target_base_layer.weight.compress_statistics, + "quant_type": target_base_layer.weight.quant_type, + } + ) + new_module = Linear4bit(target, adapter_name, **fourbit_kwargs) + + return new_module diff --git a/src/peft/tuners/hira/config.py b/src/peft/tuners/hira/config.py new file mode 100644 index 0000000000..e5b9388e0e --- /dev/null +++ b/src/peft/tuners/hira/config.py @@ -0,0 +1,243 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Literal, Optional, Union + +from torch import nn + +from peft.config import PeftConfig +from peft.utils import PeftType + + +@dataclass +class HiRARuntimeConfig: + """ + This is the sub-configuration class to store the runtime configurations for the model. + + Args: + ephemeral_gpu_offload (`bool`): + Whether to use ephemeral GPU offloading for models partially kept in CPU memory. + """ + + ephemeral_gpu_offload: bool = field( + default=False, + metadata={ + "help": ( + "Whether to use ephemeral GPU offloading for models partially kept in CPU memory. Ephemeral GPU offloading result in " + "the data involved in intense operations being momentarily copied over to the GPU, and the results copied " + "back to CPU. There is a momentary VRAM overhead, but operations are generally orders of magnitude faster " + "compared to performing them on the CPU. This is useful when parts of the model and/or components (such " + "as adapters) are kept in CPU memory until they are needed. Rather than perform expensive operations on " + "small data, the data is transferred to the GPU on-demand, the operation(s) performed, and the results " + "moved back to CPU memory. Currently only affects DoRA initialization." + ) + }, + ) + + +@dataclass +class HiRAConfig(PeftConfig): + """ + This is the configuration class to store the configuration of a [`HiRAModel`]. + + Args: + r (`int`): + HiRA r configuration (the "r"). + target_modules (`Optional[Union[List[str], str]]`): + The names of the modules to apply the adapter to. If this is specified, only the modules with the specified + names will be replaced. When passing a string, a regex match will be performed. When passing a list of + strings, either an exact match will be performed or it is checked if the name of the module ends with any + of the passed strings. 
If this is specified as 'all-linear', then all linear/Conv1D modules are chosen (if + the model is a PreTrainedModel, the output layer excluded). If this is not specified, modules will be + chosen according to the model architecture. If the architecture is not known, an error will be raised -- in + this case, you should specify the target modules manually. + exclude_modules (`Optional[Union[List[str], str]]`): + The names of the modules to not apply the adapter. When passing a string, a regex match will be performed. + When passing a list of strings, either an exact match will be performed or it is checked if the name of the + module ends with any of the passed strings. + hira_dropout (`float`): + The dropout probability for HiRA layers. + fan_in_fan_out (`bool`): + Set this to True if the layer to replace stores weight like (fan_in, fan_out). For example, gpt-2 uses + `Conv1D` which stores weights like (fan_in, fan_out) and hence this should be set to `True`. + modules_to_save (`List[str]`): + List of modules apart from adapter layers to be set as trainable and saved in the final checkpoint. + init_hira_weights (`bool` | `Literal["gaussian"]`): + How to initialize the weights of the adapter layers. Passing True (default) results in the default + initialization from the reference implementation from Microsoft, with the HiRA B weight being set to 0. + layers_to_transform (`Union[List[int], int]`): + The layer indices to transform. If a list of ints is passed, it will apply the adapter to the layer indices + that are specified in this list. If a single integer is passed, it will apply the transformations on the + layer at this index. + layers_pattern (`Optional[Union[List[str], str]]`): + The layer pattern name, used only if `layers_to_transform` is different from `None`. This should target the + `nn.ModuleList` of the model, which is often called `'layers'` or `'h'`. + r_pattern (`dict`): + The mapping from layer names or regexp expression to ranks which are different from the default r specified + by `r`. For example, `{'^model.decoder.layers.0.encoder_attn.k_proj': 16}`. + layer_replication (`List[Tuple[int, int]]`): + Build a new stack of layers by stacking the original model layers according to the ranges specified. This + allows expanding (or shrinking) the model without duplicating the base model weights. The new layers will + all have separate HiRA adapters attached to them. + runtime_config (`HiRARuntimeConfig`): + Runtime configurations (which are not saved or restored). + """ + + r: int = field(default=32, metadata={"help": "HiRA intermediate r configuration"}) + target_modules: Optional[Union[list[str], str]] = field( + default=None, + metadata={ + "help": ( + "List of module names or regex expression of the module names to replace with HiRA." + "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$'." + "This can also be a wildcard 'all-linear' which matches all linear/Conv1D " + "(if the model is a PreTrainedModel, the output layer excluded)." + "If not specified, modules will be chosen according to the model architecture, If the architecture is " + "not known, an error will be raised -- in this case, you should specify the target modules manually." 
+ ), + }, + ) + exclude_modules: Optional[Union[list[str], str]] = field( + default=None, + metadata={"help": "List of module names or regex expression of the module names to exclude from HiRA."}, + ) + hira_dropout: float = field(default=0.0, metadata={"help": "HiRA dropout"}) + fan_in_fan_out: bool = field( + default=False, + metadata={"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"}, + ) + modules_to_save: Optional[list[str]] = field( + default=None, + metadata={ + "help": "List of modules apart from HiRA layers to be set as trainable and saved in the final checkpoint. " + "For example, in Sequence Classification or Token Classification tasks, " + "the final layer `classifier/score` are randomly initialized and as such need to be trainable and saved." + }, + ) + init_hira_weights: bool | Literal["gaussian"] = field( + default=True, + metadata={ + "help": ( + "How to initialize the weights of the HiRA layers. " + "Passing True (default) results in the default initialization" + ", with the HiRA B weight being set to 0. This means that without further training, the HiRA " + "adapter will be a no-op. " + "Setting the initialization to False leads to random initialization of HiRA A and B, meaning that HiRA " + "is not a no-op before training; this setting is intended for debugging purposes. " + "Passing `'gaussian'` results in Gaussian initialization scaled by the HiRA rank for linear and layers. " + ), + }, + ) + layers_to_transform: Optional[Union[list[int], int]] = field( + default=None, + metadata={ + "help": "The layer indexes to transform, is this argument is specified, PEFT will transform only the layers indexes that are specified inside this list. If a single integer is passed, PEFT will transform only the layer at this index. " + "This only works when target_modules is a list of str." + }, + ) + layers_pattern: Optional[Union[list[str], str]] = field( + default=None, + metadata={ + "help": "The layer pattern name, used only if `layers_to_transform` is different to None and if the layer pattern is not in the common layers pattern." + "This only works when target_modules is a list of str. This should target the `nn.ModuleList` of the " + "model, which is often called `'layers'` or `'h'`." + }, + ) + r_pattern: Optional[dict] = field( + default_factory=dict, + metadata={ + "help": ( + "The mapping from layer names or regexp expression to r which are different from the default rank specified by `r`. " + "For example, `{'^model.decoder.layers.0.encoder_attn.k_proj': 16}`." + ) + }, + ) + + # Enables replicating layers in a model to expand it to a larger model. + layer_replication: Optional[list[tuple[int, int]]] = field( + default=None, + metadata={ + "help": ( + "This enables using HiRA to effectively expand a transformer model to a larger size by repeating some layers. " + "The transformation handles models (currently Llama, Bert or Falcon compatible architectures) with " + "a module list in the model which it modifies to expand the number of modules. " + "Base weights are shared so the memory usage is close to the original model. The intended use is these base weights " + "remain fixed during finetuning but each layer has a separate HiRA adapter so the layers can be specialed via " + "the adapter layers fit during fine tuning." + "The format is a list of [start, end) pairs which specify the layer ranges to stack. 
For example:\n" + " Original model has 5 layers labelled by their position in the model: `[0, 1, 2, 3, 4]`\n" + " layer_replication: `[[0, 4], [2, 5]]`\n" + " Final model will have this arrangement of original layers: `[0, 1, 2, 3, 2, 3, 4]`\n" + "This format is based on what is used for pass-through merges in mergekit. It makes it simple to select sequential " + "ranges of a model and stack them while reusing layers at either end of each sequence." + ) + }, + ) + runtime_config: HiRARuntimeConfig = field( + default_factory=HiRARuntimeConfig, metadata={"help": "Runtime configurations"} + ) + + def to_dict(self): + """ + Returns the configuration for your adapter model as a dictionary. Removes runtime configurations. + """ + rv = super().to_dict() + rv.pop("runtime_config") + return rv + + def __post_init__(self): + super().__post_init__() + self.peft_type = PeftType.HIRA + self.target_modules = ( + set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules + ) + self.exclude_modules = ( + set(self.exclude_modules) if isinstance(self.exclude_modules, list) else self.exclude_modules + ) + + # if target_modules is a regex expression, then layers_to_transform should be None + if isinstance(self.target_modules, str) and self.layers_to_transform is not None: + raise ValueError("`layers_to_transform` cannot be used when `target_modules` is a str.") + + # if target_modules is a regex expression, then layers_pattern should be None + if isinstance(self.target_modules, str) and self.layers_pattern is not None: + raise ValueError("`layers_pattern` cannot be used when `target_modules` is a str.") + + # check for layers_to_transform and layers_pattern + if self.layers_pattern and not self.layers_to_transform: + raise ValueError("When `layers_pattern` is specified, `layers_to_transform` must also be specified. ") + + self._custom_modules: Optional[dict[type[nn.Module], type[nn.Module]]] = None + + def _register_custom_module(self, mapping: dict[type[nn.Module], type[nn.Module]]) -> None: + """ + Experimental API to support providing custom HiRA layers. + + This API is subject to change, you should carefully read the docs before deciding to use it: + + https://huggingface.co/docs/peft/developer_guides/custom_models + + To register custom HiRA module types, call this method with a `mapping` argument that is a dict that maps from + the target layer type to the custom HiRA layer type. The dict can contain multiple items if you wish to target + multiple layer types. The target layer type can be any nn.Module that we currently don't support in PEFT, + whether that is an official PyTorch layer type or a custom layer type. The custom HiRA module class has to be + implemented by the user and follow the PEFT conventions for HiRA layers. + + """ + if self._custom_modules is None: + self._custom_modules = {} + self._custom_modules.update(mapping) diff --git a/src/peft/tuners/hira/layer.py b/src/peft/tuners/hira/layer.py new file mode 100644 index 0000000000..3ce51a46d6 --- /dev/null +++ b/src/peft/tuners/hira/layer.py @@ -0,0 +1,953 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import math +import warnings +from typing import Any, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers.pytorch_utils import Conv1D + +from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge +from peft.utils.other import transpose + +from .config import HiRAConfig + + +class HiRALayer(BaseTunerLayer): + # All names of layers that may contain (trainable) adapter weights + adapter_layer_names: tuple[str, ...] = ("hira_A", "hira_B", "hira_embedding_A", "hira_embedding_B") + # All names of other parameters that may contain adapter-related parameters + other_param_names: tuple[str, ...] = ("r", "hira_dropout") + + def __init__(self, base_layer: nn.Module, ephemeral_gpu_offload: bool = False, **kwargs) -> None: + self.base_layer = base_layer + self.r = {} + self.hira_dropout = nn.ModuleDict({}) + self.hira_A = nn.ParameterDict({}) + self.hira_B = nn.ParameterDict({}) + # For Embedding layer + self.hira_embedding_A = nn.ParameterDict({}) + self.hira_embedding_B = nn.ParameterDict({}) + # Mark the weight as unmerged + self._disable_adapters = False + self.merged_adapters = [] + self._caches: dict[str, Any] = {} + self.ephemeral_gpu_offload: bool = ephemeral_gpu_offload + # flag to enable/disable casting of input to weight dtype during forward call + self.cast_input_dtype_enabled: bool = True + self.kwargs = kwargs + + base_layer = self.get_base_layer() + if isinstance(base_layer, nn.Linear): + in_features, out_features = base_layer.in_features, base_layer.out_features + elif isinstance(base_layer, nn.Conv1d): + in_features, out_features = base_layer.in_channels, base_layer.out_channels + elif isinstance(base_layer, nn.Conv2d): + in_features, out_features = base_layer.in_channels, base_layer.out_channels + elif isinstance(base_layer, nn.Conv3d): + in_features, out_features = base_layer.in_channels, base_layer.out_channels + elif isinstance(base_layer, nn.Embedding): + in_features, out_features = base_layer.num_embeddings, base_layer.embedding_dim + elif isinstance(base_layer, Conv1D): + in_features, out_features = ( + base_layer.weight.ds_shape if hasattr(base_layer.weight, "ds_shape") else base_layer.weight.shape + ) + elif isinstance(base_layer, nn.MultiheadAttention): + if not base_layer._qkv_same_embed_dim: + raise ValueError(f"Only same dim for query/key/value is supported as of now for {self.__class__}.") + in_features, out_features = base_layer.embed_dim, 3 * base_layer.embed_dim + elif hasattr(base_layer, "infeatures") and hasattr(base_layer, "outfeatures"): + # QuantLinear + in_features, out_features = base_layer.infeatures, base_layer.outfeatures + elif hasattr(base_layer, "input_size") and hasattr(base_layer, "output_size"): + # Megatron ColumnParallelLinear,RowParallelLinear + in_features, out_features = base_layer.input_size, base_layer.output_size + elif hasattr(base_layer, "codebooks") and base_layer.__class__.__name__ == "QuantizedLinear": + # AQLM QuantLinear + in_features, out_features = base_layer.in_features, base_layer.out_features + elif 
hasattr(base_layer, "w_bit") and base_layer.__class__.__name__ == "WQLinear_GEMM": + # Awq layers + in_features, out_features = base_layer.in_features, base_layer.out_features + elif base_layer.__class__.__name__ == "EetqLinear": + # Eetq layers + in_features, out_features = base_layer.in_features, base_layer.out_features + elif hasattr(base_layer, "W_q") and base_layer.__class__.__name__ == "HQQLinear": + # HQQ layers + in_features, out_features = base_layer.in_features, base_layer.out_features + elif base_layer.__class__.__name__ == "PatchedLinear": + # INC layers + in_features, out_features = base_layer.in_features, base_layer.out_features + else: + # possibly support user provided custom layer types using dynamic dispatch + if hasattr(base_layer, "in_features") and hasattr(base_layer, "out_features"): + in_features, out_features = base_layer.in_features, base_layer.out_features + else: + in_features, out_features = None, None + warnings.warn( + f"Unsupported layer type '{type(base_layer)}' encountered, proceed at your own risk.", UserWarning + ) + + self.in_features = in_features + self.out_features = out_features + + def update_layer( + self, + adapter_name, + r, + hira_dropout, + init_hira_weights, + ): + # collect the kwargs + kwargs = locals().copy() + del kwargs["self"] + + # This code works for linear layers, override for other layer types + if r <= 0: + raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") + + self.r[adapter_name] = r + if hira_dropout > 0.0: + hira_dropout_layer = nn.Dropout(p=hira_dropout) + else: + hira_dropout_layer = nn.Identity() + + self.hira_dropout.update(nn.ModuleDict({adapter_name: hira_dropout_layer})) + # Actual trainable parameters + self.hira_A.update(nn.ParameterDict({adapter_name: nn.Parameter(torch.randn(r, self.in_features))})) + self.hira_B.update(nn.ParameterDict({adapter_name: nn.Parameter(torch.randn(self.out_features, r))})) + + # for inits that require access to the base weight, use gather_param_ctx so that the weight is gathered when using DeepSpeed + if init_hira_weights: + self.reset_hira_parameters(adapter_name, init_hira_weights) + self._move_adapter_to_device_of_base_layer(adapter_name) + + self.set_adapter(self.active_adapters) + + def reset_hira_parameters(self, adapter_name, init_hira_weights): + if init_hira_weights is False: + return + + if adapter_name in self.hira_A.keys(): + if init_hira_weights is True: + # initialize A the same way as the default for nn.Linear and B to zero + nn.init.kaiming_uniform_(self.hira_A[adapter_name], a=math.sqrt(5)) + elif init_hira_weights.lower() == "gaussian": + nn.init.normal_(self.hira_A[adapter_name], std=1 / self.r[adapter_name]) + else: + raise ValueError(f"Unknown initialization {init_hira_weights=}") + nn.init.zeros_(self.hira_B[adapter_name]) + if adapter_name in self.hira_embedding_A.keys(): + # Initialize A to zeros and B the same way as the default for nn.Embedding, see: + nn.init.zeros_(self.hira_embedding_A[adapter_name]) + nn.init.normal_(self.hira_embedding_B[adapter_name]) + + def _cache_store(self, key: str, value: Any) -> None: + self._caches[key] = value + + def _cache_pop(self, key: str) -> Any: + value = self._caches.pop(key) + return value + + def _check_forward_args(self, x, *args, **kwargs): + """Check if the arguments are compatible with the configs and state of the model""" + adapter_names = kwargs.get("adapter_names", None) + if adapter_names is None: + return + + if len(x) != len(adapter_names): + msg = ( + "Length of 
`adapter_names` should be the same as the number of inputs, but got " + f"{len(adapter_names)} and {len(x)} respectively." + ) + raise ValueError(msg) + + if self.merged: + # It is unclear what would be the right thing to do if users pass adapter_names and there are merged + # adapters. Therefore, it is better to raise an error in this case. + msg = "Cannot pass `adapter_names` when there are merged adapters, please call `unmerge_adapter` first." + raise ValueError(msg) + + def _mixed_batch_forward( + self, + x: torch.Tensor, + *args: Any, + adapter_names: list[str], + **kwargs: Any, + ) -> torch.Tensor: + """ + Forward pass that allows *different* adapters to be used for different examples in the same batch + (``adapter_names`` must have length == len(x)). + + The base projection is computed once; the HiRA update is then added separately for each sub-batch that shares + the same adapter. + """ + # 0. run the expensive base layer once for the whole batch + result = self.base_layer(x, *args, **kwargs) + torch_result_dtype = result.dtype + + # 1. collect indices for each adapter in the batch + adapter_to_indices: dict[str, list[int]] = {} + for idx, name in enumerate(adapter_names): + adapter_to_indices.setdefault(name, []).append(idx) + + # 2. apply the HiRA update for each (adapter, sub-batch) + for active_adapter, idxs in adapter_to_indices.items(): + if active_adapter == "__base__": + continue # base only → nothing to add + if active_adapter not in self.hira_A: + continue # adapter not initialised + + # --- grab HiRA params & dropout --- + hira_A = self.hira_A[active_adapter] # (r, in_dim) + hira_B = self.hira_B[active_adapter] # (out_dim, r) + dropout = self.hira_dropout[active_adapter] + + # --- slice out sub-batch for this adapter --- + sub_batch = x[idxs] + sub_batch = self._cast_input_dtype(sub_batch, hira_A.dtype) + sub_batch = dropout(sub_batch) + + # --- compute element-wise modulated weight: W0 ⊙ (B @ A) --- + prod_AB = torch.mm(hira_A.T, hira_B.T) # (in_dim, out_dim) + assert prod_AB.T.shape == self.get_base_layer().weight.shape + eff_weight = transpose(self.get_base_layer().weight, self.fan_in_fan_out) * prod_AB.T + + hira_out = F.linear(sub_batch, eff_weight) + result[idxs] += hira_out.to(torch_result_dtype) + + return result + + +class Linear(nn.Module, HiRALayer): + # HiRA implemented in a dense layer + def __init__( + self, + base_layer, + adapter_name: str, + r: int = 0, + hira_dropout: float = 0.0, + fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) + is_target_conv_1d_layer: bool = False, + init_hira_weights: Union[bool, str] = True, + **kwargs, + ) -> None: + super().__init__() + HiRALayer.__init__(self, base_layer, **kwargs) + self.fan_in_fan_out = fan_in_fan_out + + self._active_adapter = adapter_name + self.update_layer( + adapter_name, + r, + hira_dropout=hira_dropout, + init_hira_weights=init_hira_weights, + ) + self.is_target_conv_1d_layer = is_target_conv_1d_layer + + def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + """ + Merge the active adapter weights into the base weights + + Args: + safe_merge (`bool`, *optional*): + If True, the merge operation will be performed in a copy of the original weights and check for NaNs + before merging the weights. This is useful if you want to check if the merge operation will produce + NaNs. Defaults to `False`. + adapter_names (`list[str]`, *optional*): + The list of adapter names that should be merged. 
If None, all active adapters will be merged. Defaults + to `None`. + """ + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + # no adapter to merge + return + + for active_adapter in adapter_names: + if active_adapter in self.hira_A.keys(): + base_layer = self.get_base_layer() + if safe_merge: + # Note that safe_merge will be slower than the normal merge + # because of the copy operation. + orig_weight = base_layer.weight.data.clone() + orig_dtype = orig_weight.dtype + delta_weight = self.get_delta_weight(active_adapter) + orig_weight *= 1 + delta_weight.to(orig_dtype) + + if not torch.isfinite(orig_weight).all(): + raise ValueError( + f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" + ) + + base_layer.weight.data = orig_weight + else: + delta_weight = self.get_delta_weight(active_adapter) + base_layer.weight.data *= 1 + delta_weight + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + """ + This method unmerges all merged adapter layers from the base weights. + """ + if not self.merged: + warnings.warn("Already unmerged. Nothing to do.") + return + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter in self.hira_A.keys(): + weight = self.get_base_layer().weight + orig_dtype = weight.dtype + delta_weight = self.get_delta_weight(active_adapter) + weight.data /= 1 + delta_weight.to(orig_dtype) + + def get_delta_weight(self, adapter) -> torch.Tensor: + """ + Compute the delta weight for the given adapter. + + Args: + adapter (str): + The name of the adapter for which the delta weight should be computed. + """ + device = self.hira_B[adapter].device + dtype = self.hira_B[adapter].dtype + + # In case users wants to merge the adapter weights that are in + # (b)float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to + # (b)float16 because some CPUs have slow bf16/fp16 matmuls. 
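+        # Note that, unlike LoRA, the returned tensor is not an additive delta: it is the multiplicative
+        # factor B @ A. merge() applies it as W <- W * (1 + delta) and unmerge() reverses it by dividing
+        # by (1 + delta).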
+ cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16) + + weight_A = self.hira_A[adapter] + weight_B = self.hira_B[adapter] + + if cast_to_fp32: + weight_A = weight_A.float() + weight_B = weight_B.float() + output_tensor = transpose((weight_B @ weight_A), self.fan_in_fan_out) + assert self.get_base_layer().weight.shape == output_tensor.shape + if cast_to_fp32: + output_tensor = output_tensor.to(dtype=dtype) + + # cast back the weights + self.hira_A[adapter].data = weight_A.to(dtype) + self.hira_B[adapter].data = weight_B.to(dtype) + + return output_tensor + + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: + self._check_forward_args(x, *args, **kwargs) + adapter_names = kwargs.pop("adapter_names", None) + + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif adapter_names is not None: + result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + result = self.base_layer(x, *args, **kwargs) + torch_result_dtype = result.dtype + + hira_A_keys = self.hira_A.keys() + for active_adapter in self.active_adapters: + if active_adapter not in hira_A_keys: + continue + hira_A = self.hira_A[active_adapter] + hira_B = self.hira_B[active_adapter] + _prod_AB = torch.mm(hira_A.T, hira_B.T) + x = self._cast_input_dtype(x, hira_A.dtype) + dropout = self.hira_dropout[active_adapter] + dropout_sub = dropout(x) + hira_result = F.linear( + dropout_sub, transpose(self.get_base_layer().weight, self.fan_in_fan_out) * _prod_AB.T + ) + result = result + hira_result + + result = result.to(torch_result_dtype) + + return result + + def __repr__(self) -> str: + rep = super().__repr__() + return "hira." 
+ rep + + +class Embedding(nn.Module, HiRALayer): + # HiRA implemented in a Embedding layer + def __init__( + self, + base_layer: nn.Module, + adapter_name: str, + r: int = 0, + hira_dropout: float = 0.0, + fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) + init_hira_weights: Union[bool, str] = True, + **kwargs, + ) -> None: + super().__init__() + HiRALayer.__init__(self, base_layer) + self.fan_in_fan_out = fan_in_fan_out + + self._active_adapter = adapter_name + self.update_layer( + adapter_name, + r, + hira_dropout=hira_dropout, + init_hira_weights=init_hira_weights, + ) + + def update_layer(self, adapter_name, r, hira_dropout, init_hira_weights): + # collect the kwargs + kwargs = locals().copy() + del kwargs["self"] + + if r <= 0: + raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") + + self.r[adapter_name] = r + if hira_dropout > 0.0: + hira_dropout_layer = nn.Dropout(p=hira_dropout) + else: + hira_dropout_layer = nn.Identity() + + self.hira_dropout[adapter_name] = hira_dropout_layer + # Actual trainable parameters + weight_A = torch.randn((r, self.in_features)) + weight_B = torch.randn((self.out_features, r)) + self.hira_embedding_A[adapter_name] = nn.Parameter(weight_A) + self.hira_embedding_B[adapter_name] = nn.Parameter(weight_B) + self.reset_hira_parameters(adapter_name, init_hira_weights) + + # call this before init of the hira variants + self._move_adapter_to_device_of_base_layer(adapter_name) + + self.set_adapter(self.active_adapters) + + def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + """ + Merge the active adapter weights into the base weights + + Args: + safe_merge (`bool`, *optional*): + If True, the merge operation will be performed in a copy of the original weights and check for NaNs + before merging the weights. This is useful if you want to check if the merge operation will produce + NaNs. Defaults to `False`. + adapter_names (`list[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. + """ + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + # no adapter to merge + return + + for active_adapter in adapter_names: + if active_adapter in self.hira_embedding_A.keys(): + base_layer = self.get_base_layer() + orig_dtype = base_layer.weight.dtype + if safe_merge: + # Note that safe_merge will be slower than the normal merge + # because of the copy operation. + orig_weight = base_layer.weight.data.clone() + delta_weight = self.get_delta_weight(active_adapter).to(orig_dtype) + orig_weight *= 1 + delta_weight + if not torch.isfinite(orig_weight).all(): + raise ValueError( + f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" + ) + base_layer.weight.data = orig_weight + else: + delta_weight = self.get_delta_weight(active_adapter).to(orig_dtype) + base_layer.weight.data *= 1 + delta_weight + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + """ + This method unmerges all merged adapter layers from the base weights. + """ + if not self.merged: + warnings.warn("Already unmerged. 
Nothing to do.") + return + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + orig_dtype = self.get_base_layer().weight.dtype + if active_adapter in self.hira_embedding_A.keys(): + weight = self.get_base_layer().weight + weight.data /= 1 + self.get_delta_weight(active_adapter).to(orig_dtype) + + def get_delta_weight(self, adapter) -> torch.Tensor: + """ + Compute the delta weight for the given adapter. + + Args: + adapter (str): + The name of the adapter for which the delta weight should be computed. + """ + device = self.hira_embedding_B[adapter].device + dtype = self.hira_embedding_A[adapter].dtype + + # In case users wants to merge the adapter weights that are in + # (b)float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to + # (b)float16 because some CPUs have slow bf16/fp16 matmuls. + cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16) + + weight_A = self.hira_embedding_A[adapter] + weight_B = self.hira_embedding_B[adapter] + + if cast_to_fp32: + weight_A = weight_A.float() + weight_B = weight_B.float() + output_tensor = transpose(weight_B @ weight_A, True) + + if cast_to_fp32: + output_tensor = output_tensor.to(dtype=dtype) + + # cast back the weights + self.hira_embedding_A[adapter] = weight_A.to(dtype) + self.hira_embedding_B[adapter] = weight_B.to(dtype) + + return output_tensor + + def _mixed_batch_forward( + self, x: torch.Tensor, *args: Any, adapter_names: list[str], **kwargs: Any + ) -> torch.Tensor: + # Compute base embedding once for efficiency + result = self.base_layer(x, *args, **kwargs) + + # Group batch indices by adapter + adapter_to_indices = {} + for idx, adapter_name in enumerate(adapter_names): + adapter_to_indices.setdefault(adapter_name, []).append(idx) + + # Apply HiRA embedding updates + for adapter_name, indices in adapter_to_indices.items(): + if adapter_name == "__base__": + continue + if adapter_name not in self.hira_embedding_A: + continue + + embedding_A = self.hira_embedding_A[adapter_name] # shape: (r, num_embeddings) + embedding_B = self.hira_embedding_B[adapter_name] # shape: (embedding_dim, r) + + sub_batch = x[indices] # shape: (sub_batch_size, sequence_length) + + # Compute the low-rank update: (B @ A)[:, sub_batch].T + low_rank_update = F.embedding(sub_batch, (embedding_B @ embedding_A).T) + + # Element-wise modulation with base embedding + base_sub_embedding = result[indices] + hira_update = base_sub_embedding * low_rank_update + + # Update the result tensor + result[indices] += hira_update + + return result + + def _embed(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + base_layer = self.get_base_layer() + return F.embedding( + input, + weight, + padding_idx=base_layer.padding_idx, + max_norm=base_layer.max_norm, + norm_type=base_layer.norm_type, + scale_grad_by_freq=base_layer.scale_grad_by_freq, + sparse=base_layer.sparse, + ) + + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: + """ + HiRA forward for Embedding layer. Supports mixed adapters per batch or single adapter. 
+ """ + self._check_forward_args(x, *args, **kwargs) + adapter_names = kwargs.pop("adapter_names", None) + + # Adapter disabled or after merge: use base embedding + if self.disable_adapters: + if self.merged: + self.unmerge() + return self.base_layer(x, *args, **kwargs) + if adapter_names is not None: + return self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs) + if self.merged: + return self.base_layer(x, *args, **kwargs) + + # Single adapter active: compute base embedding + HiRA residual + result = self.base_layer(x, *args, **kwargs) + torch_result_dtype = result.dtype + base_weight = self.get_base_layer().weight # (num_embeddings, embedding_dim) + + for adapter in self.active_adapters: + if adapter not in self.hira_embedding_A: + continue + # HiRA factors + hira_A = self.hira_embedding_A[adapter] # (r, num_embeddings) + hira_B = self.hira_embedding_B[adapter] # (embedding_dim, r) + + # Compute modulation matrix: B @ A -> (embedding_dim, num_embeddings) + mod_matrix = torch.mm(hira_B, hira_A) + # Element-wise modulated embedding weight + eff_weight = base_weight * mod_matrix.T # (num_embeddings, embedding_dim) + + # Compute HiRA residual via embedding lookup + hira_out = F.embedding(x, eff_weight) + result = result + hira_out + + return result.to(torch_result_dtype) + + def __repr__(self) -> str: + rep = super().__repr__() + return "hira." + rep + + +class _ConvNd(nn.Module, HiRALayer): + # HiRA implemented in a conv(2,3)d layer + def __init__( + self, + base_layer: nn.Module, + adapter_name: str, + r: int = 0, + hira_dropout: float = 0.0, + init_hira_weights: Union[bool, str] = True, + **kwargs, + ) -> None: + super().__init__() + HiRALayer.__init__(self, base_layer) + + if base_layer.groups > 1: + warnings.warn("HiRA adapter added to ConvNd layer with groups > 1. 
Merging is not supported.") + + self._active_adapter = adapter_name + self._kernel_dim = base_layer.weight.dim() + + self.update_layer( + adapter_name, + r, + hira_dropout=hira_dropout, + init_hira_weights=init_hira_weights, + ) + + def update_layer(self, adapter_name, r, hira_dropout, init_hira_weights): + # collect the kwargs + kwargs = locals().copy() + del kwargs["self"] + + if r <= 0: + raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") + + self.r[adapter_name] = r + if hira_dropout > 0.0: + hira_dropout_layer = nn.Dropout(p=hira_dropout) + else: + hira_dropout_layer = nn.Identity() + + self.hira_dropout[adapter_name] = hira_dropout_layer + # Determine base conv parameters + base = self.get_base_layer() + conv_cls = type(base) + in_channels = base.in_channels + out_channels = base.out_channels + kernel_size = base.kernel_size + stride = base.stride + padding = base.padding + dilation = getattr(base, "dilation", (1,) * (base.weight.dim() - 2)) + groups = getattr(base, "groups", 1) + # Spatial dims for B: 1 in each spatial dimension + spatial_ones = (1,) * (base.weight.dim() - 2) + + # HiRA factor A: conv from in_channels to r + self.hira_A[adapter_name] = conv_cls( + in_channels, + r, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=False, + ) + # HiRA factor B: conv from r back to out_channels + self.hira_B[adapter_name] = conv_cls( + r, + out_channels, + kernel_size=spatial_ones, + stride=(1,) * len(spatial_ones), + padding=(0,) * len(spatial_ones), + dilation=(1,) * len(spatial_ones), + groups=groups, + bias=False, + ) + + # Initialize HiRA parameters (A with Kaiming, B zeros) + if init_hira_weights: + # A: same init as base conv weight + nn.init.kaiming_uniform_(self.hira_A[adapter_name].weight, a=math.sqrt(5)) + # B: initialize to zero + nn.init.zeros_(self.hira_B[adapter_name].weight) + + # Place adapters on correct device and register + self._move_adapter_to_device_of_base_layer(adapter_name) + self.set_adapter(self.active_adapters) + + def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + """ + Merge the active adapter weights inside the base weights + + Args: + safe_merge (`bool`, *optional*): + If True, the merge operation will be performed in a copy of the original weights and check for NaNs + before merging the weights. This is useful if you want to check if the merge operation will produce + NaNs. Defaults to `False`. + adapter_names (`list[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. + """ + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + # no adapter to merge + return + + for active_adapter in adapter_names: + if active_adapter in self.hira_A.keys(): + base_layer = self.get_base_layer() + orig_dtype = base_layer.weight.dtype + + if base_layer.groups > 1: + # https://github.com/huggingface/peft/pull/2403 + raise NotImplementedError("Merging is not supported for _ConvNd layers with groups > 1!") + + if safe_merge: + # Note that safe_merge will be slower than the normal merge + # because of the copy operation. + orig_weight = base_layer.weight.data.clone() + delta_weight = self.get_delta_weight(active_adapter) + orig_weight *= 1 + delta_weight.to(orig_dtype) + + if not torch.isfinite(orig_weight).all(): + raise ValueError( + f"NaNs detected in the merged weights. 
The adapter {active_adapter} seems to be broken" + ) + + base_layer.weight.data = orig_weight + + else: + delta_weight = self.get_delta_weight(active_adapter) + base_layer.weight.data *= 1 + delta_weight.to(orig_dtype) + + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + """ + This method unmerges all merged adapter layers from the base weights. + """ + if not self.merged: + warnings.warn("Already unmerged. Nothing to do.") + return + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter in self.hira_A.keys(): + weight = self.get_base_layer().weight + orig_dtype = weight.dtype + delta_weight = self.get_delta_weight(active_adapter) + weight.data /= 1 + delta_weight.to(orig_dtype) + + def get_delta_weight(self, adapter) -> torch.Tensor: + """ + Compute the delta weight for the given adapter. + + Args: + adapter (str): + The name of the adapter for which the delta weight should be computed. + """ + device = self.hira_B[adapter].weight.device + dtype = self.hira_A[adapter].weight.dtype + + # In case users wants to merge the adapter weights that are in + # (b)float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to + # (b)float16 because some CPUs have slow bf16/fp16 matmuls. + cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16) + + weight_A = self.hira_A[adapter].weight + weight_B = self.hira_B[adapter].weight + + if cast_to_fp32: + weight_A = weight_A.float() + weight_B = weight_B.float() + + if self.get_base_layer().weight.size()[2:4] == (1, 1): + # conv2d 1x1 + output_tensor = (weight_B.squeeze(3).squeeze(2) @ weight_A.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + else: + output_tensor = self.conv_fn(weight_A.transpose(0, 1), weight_B) + + if self.get_base_layer().groups > 1: + output_tensor = output_tensor + else: + output_tensor = output_tensor.transpose(0, 1) + + if cast_to_fp32: + output_tensor = output_tensor.to(dtype=dtype) + + # cast back the weights + self.hira_A[adapter].weight.data = weight_A.to(dtype) + self.hira_B[adapter].weight.data = weight_B.to(dtype) + + return output_tensor + + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + self._check_forward_args(x, *args, **kwargs) + adapter_names = kwargs.pop("adapter_names", None) + + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif adapter_names is not None: + result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + + else: + result = self.base_layer(x, *args, **kwargs) + torch_result_dtype = result.dtype + base = self.get_base_layer() + # Determine which convolution function to use + if isinstance(base, nn.Conv1d): + conv_fn = F.conv1d + elif isinstance(base, nn.Conv2d): + conv_fn = F.conv2d + elif isinstance(base, nn.Conv3d): + conv_fn = F.conv3d + else: + raise TypeError(f"Unsupported conv layer {type(base)} for HiRA.") + for active_adapter in self.active_adapters: + if active_adapter not in self.hira_A.keys(): + continue + hira_A = self.hira_A[active_adapter] + dropout = self.hira_dropout[active_adapter] + x = self._cast_input_dtype(x, hira_A.weight.dtype) + x_in = self._cast_input_dtype(x, hira_A.weight.dtype) + x_drop = dropout(x_in) + # low-rank factor B@A + bia = self.get_delta_weight(active_adapter) # now returns (B@A) tensor + # element-wise modulate base weight: W0 ⊙ 
(B@A) + base_weight = base.weight + eff_weight = base_weight * bia + hira_out = conv_fn( + x_drop, + eff_weight, + bias=None, + stride=base.stride, + padding=base.padding, + dilation=getattr(base, "dilation", 1), + groups=base.groups, + ) + result = result + hira_out + + result = result.to(torch_result_dtype) + return result + + def __repr__(self) -> str: + rep = super().__repr__() + return "hira." + rep + + +class Conv2d(_ConvNd): + # HiRA implemented in a conv2d layer + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if not self._kernel_dim == 4: + raise ValueError(f"Conv2d layer kernel must have 4 dimensions, not {self._kernel_dim}") + self.conv_fn = F.conv2d + + +class Conv1d(_ConvNd): + # HiRA implemented in a conv1d layer + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if not self._kernel_dim == 3: + raise ValueError(f"Conv1d layer kernel must have 3 dimensions, not {self._kernel_dim}") + self.conv_fn = F.conv1d + + +class Conv3d(_ConvNd): + # HiRA implemented in a conv3d layer + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if not self._kernel_dim == 5: + raise ValueError(f"Conv3d layer kernel must have 5 dimensions, not {self._kernel_dim}") + self.conv_fn = F.conv3d + + +def dispatch_default( + target: torch.nn.Module, + adapter_name: str, + hira_config: HiRAConfig, + **kwargs, +) -> Optional[torch.nn.Module]: + """ + Dispatch function for HiRA adapters that avoids LoFT-Q related config. + """ + # Determine the base module (unwrap BaseTunerLayer) + if isinstance(target, BaseTunerLayer): + base = target.get_base_layer() + else: + base = target + + # Common HiRA init kwargs + module_kwargs = { + "r": hira_config.r, + "hira_dropout": hira_config.hira_dropout, + "init_hira_weights": hira_config.init_hira_weights, + } + # If fan_in_fan_out present in config, include it + if hasattr(hira_config, "fan_in_fan_out"): + module_kwargs["fan_in_fan_out"] = hira_config.fan_in_fan_out + + new_module = None + # Embedding + if isinstance(base, nn.Embedding): + new_module = Embedding(target, adapter_name, **module_kwargs) + # Conv layers + elif isinstance(base, nn.Conv2d): + new_module = Conv2d(target, adapter_name, **module_kwargs) + elif isinstance(base, nn.Conv3d): + new_module = Conv3d(target, adapter_name, **module_kwargs) + elif isinstance(base, nn.Conv1d): + new_module = Conv1d(target, adapter_name, **module_kwargs) + # Linear layers + elif isinstance(base, nn.Linear): + # Linear always uses fan_in_fan_out=False + if module_kwargs.get("fan_in_fan_out", False): + warnings.warn( + "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. " + "Setting fan_in_fan_out to False." + ) + module_kwargs["fan_in_fan_out"] = False + new_module = Linear(target, adapter_name, **module_kwargs) + # HuggingFace Conv1D + elif isinstance(base, Conv1D): + # Conv1D expects fan_in_fan_out=True + if not module_kwargs.get("fan_in_fan_out", False): + warnings.warn( + "fan_in_fan_out is set to False but the target module is `Conv1D`. Setting fan_in_fan_out to True." + ) + module_kwargs["fan_in_fan_out"] = True + new_module = Linear(target, adapter_name, is_target_conv_1d_layer=True, **module_kwargs) + + return new_module diff --git a/src/peft/tuners/hira/model.py b/src/peft/tuners/hira/model.py new file mode 100644 index 0000000000..8b81b41f05 --- /dev/null +++ b/src/peft/tuners/hira/model.py @@ -0,0 +1,743 @@ +# Copyright 2023-present the HuggingFace Inc. team. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+import operator
+import warnings
+from contextlib import contextmanager
+from dataclasses import asdict, replace
+from enum import Enum
+from functools import partial, reduce
+from typing import Literal, Optional
+
+import torch
+from torch import nn
+from tqdm import tqdm
+
+from peft.import_utils import is_bnb_4bit_available, is_bnb_available
+from peft.tuners.tuners_utils import (
+    BaseTuner,
+    BaseTunerLayer,
+    check_target_module_exists,
+    onload_layer,
+    replicate_layers,
+)
+from peft.utils import (
+    AuxiliaryTrainingWrapper,
+    ModulesToSaveWrapper,
+    _freeze_adapter,
+    _get_submodules,
+    get_quantization_config,
+)
+from peft.utils.other import get_pattern_key
+
+from ...utils.constants import TRANSFORMERS_MODELS_TO_HIRA_TARGET_MODULES_MAPPING
+from .config import HiRAConfig
+from .layer import HiRALayer, dispatch_default
+
+
+def _adapter_names_pre_forward_hook(target, args, kwargs, adapter_names):
+    # pre-forward hook to inject the adapter_names argument when using mixed adapter batches inference
+    kwargs["adapter_names"] = adapter_names
+    return args, kwargs
+
+
+class HiRAModel(BaseTuner):
+    """
+    Creates HiRA Adapter model from a pretrained transformers model.
+
+    The method is described in detail in https://openreview.net/pdf?id=TwJrTz9cRS.
+
+    Args:
+        model ([`torch.nn.Module`]): The model to be adapted.
+        config ([`HiRAConfig`]): The configuration of the HiRA model.
+        adapter_name (`str`): The name of the adapter, defaults to `"default"`.
+        low_cpu_mem_usage (`bool`, `optional`, defaults to `False`):
+            Create empty adapter weights on meta device. Useful to speed up the loading process.
+
+    Returns:
+        `torch.nn.Module`: The HiRA model.
+
+    Example:
+
+        ```py
+        >>> from transformers import AutoModelForSeq2SeqLM
+        >>> from peft import HiRAModel, HiRAConfig
+
+        >>> config = HiRAConfig(
+        ...     task_type="SEQ_2_SEQ_LM",
+        ...     r=32,
+        ...     target_modules=["q", "v"],
+        ...     hira_dropout=0.01,
+        ... )
+
+        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
+        >>> hira_model = HiRAModel(model, config, "default")
+        ```
+
+        ```py
+        >>> import torch
+        >>> import transformers
+        >>> from peft import HiRAConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
+
+        >>> rank = ...
+        >>> target_modules = ["q_proj", "k_proj", "v_proj", "out_proj", "fc_in", "fc_out", "wte"]
+        >>> config = HiRAConfig(r=32, target_modules=target_modules, hira_dropout=0.1, task_type="CAUSAL_LM")
+        >>> quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True)
+
+        >>> tokenizer = transformers.AutoTokenizer.from_pretrained(
+        ...     "kakaobrain/kogpt",
+        ...     revision="KoGPT6B-ryan1.5b-float16",  # or float32 version: revision=KoGPT6B-ryan1.5b
+        ...     bos_token="[BOS]",
+        ...     eos_token="[EOS]",
+        ...     unk_token="[UNK]",
+        ...     pad_token="[PAD]",
+        ...     mask_token="[MASK]",
+        ... )
+        >>> model = transformers.GPTJForCausalLM.from_pretrained(
+        ...     "kakaobrain/kogpt",
+        ...     revision="KoGPT6B-ryan1.5b-float16",  # or float32 version: revision=KoGPT6B-ryan1.5b
+        ...     pad_token_id=tokenizer.eos_token_id,
+        ...     use_cache=False,
+        ...     device_map={"": rank},
+        ...     torch_dtype=torch.float16,
+        ...     quantization_config=quantization_config,
+        ... )
+        >>> model = prepare_model_for_kbit_training(model)
+        >>> hira_model = get_peft_model(model, config)
+        ```
+
+    **Attributes**:
+        - **model** ([`~transformers.PreTrainedModel`]) -- The model to be adapted.
+        - **peft_config** ([`HiRAConfig`]): The configuration of the HiRA model.
+    """
+
+    prefix: str = "hira_"
+
+    def __init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False) -> None:
+        super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)
+
+    def _check_new_adapter_config(self, config: HiRAConfig) -> None:
+        """
+        A helper method to check the config when a new adapter is being added.
+
+        Raise a ValueError if there is something wrong with the config or if it conflicts with existing adapters.
+
+        """
+        return True
+
+    @staticmethod
+    def _check_target_module_exists(hira_config, key):
+        return check_target_module_exists(hira_config, key)
+
+    def _prepare_model(self, peft_config: HiRAConfig, model: nn.Module):
+        r"""
+        A private method to modify the model structure before adapter is applied.
+
+        Args:
+            peft_config (`PeftConfig`):
+                The prepared adapter config.
+            model (`nn.Module`):
+                The model that is going to be adapted.
+        """
+        if peft_config.layer_replication:
+            replicate_layers(model, peft_config.layer_replication)
+
+    def _create_and_replace(
+        self,
+        hira_config,
+        adapter_name,
+        target,
+        target_name,
+        parent,
+        current_key,
+    ):
+        if current_key is None:
+            raise ValueError("Current Key shouldn't be `None`")
+
+        # Regexp matching - Find key which matches current target_name in patterns provided
+        r_key = get_pattern_key(hira_config.r_pattern.keys(), current_key)
+        r = hira_config.r_pattern.get(r_key, hira_config.r)
+
+        kwargs = {
+            "r": r,
+            "hira_dropout": hira_config.hira_dropout,
+            "fan_in_fan_out": hira_config.fan_in_fan_out,
+            "init_hira_weights": hira_config.init_hira_weights,
+            "ephemeral_gpu_offload": hira_config.runtime_config.ephemeral_gpu_offload,
+            "loaded_in_8bit": getattr(self.model, "is_loaded_in_8bit", False),
+            "loaded_in_4bit": getattr(self.model, "is_loaded_in_4bit", False),
+        }
+        # for torchao merging, we need the get_apply_tensor_subclass from the quantization config
+        try:
+            kwargs["get_apply_tensor_subclass"] = operator.attrgetter(
+                "hf_quantizer.quantization_config.get_apply_tensor_subclass"
+            )(self.model)
+        except AttributeError:
+            pass
+
+        quant_methods = ["gptq", "aqlm", "awq"]
+        for quant_method in quant_methods:
+            quantization_config = get_quantization_config(self.model, method=quant_method)
+            if quantization_config is not None:
+                kwargs[f"{quant_method}_quantization_config"] = quantization_config
+
+        if isinstance(target, HiRALayer):
+            target.update_layer(
+                adapter_name,
+                r,
+                hira_dropout=hira_config.hira_dropout,
+                init_hira_weights=hira_config.init_hira_weights,
+            )
+        else:
+            device_map = self.model.hf_device_map if hasattr(self.model, "hf_device_map") else None
+            new_module = self._create_new_module(hira_config, adapter_name, target, device_map=device_map, **kwargs)
+            if adapter_name not in self.active_adapters:
+                # adding an additional adapter: it is not automatically trainable
+                new_module.requires_grad_(False)
+            self._replace_module(parent, target_name, new_module, target)
+
+    def _replace_module(self, parent, child_name, new_module, child):
+        setattr(parent, child_name, new_module)
+        # It's not necessary to set requires_grad here, as that is handled by
+        # _mark_only_adapters_as_trainable
+
+        # child layer wraps the original module, unpack it
+        if hasattr(child, "base_layer"):
+            child = child.base_layer
+
+        meta = torch.device("meta")
+        # dispatch to correct device
+        for name, module in new_module.named_modules():
+            if (self.prefix in name) or ("ranknum" in name):
+                if hasattr(child, "qweight"):
+                    weight = child.qweight
+                elif hasattr(child, "W_q"):
+                    weight = child.W_q
+                elif hasattr(child, "weight"):
+                    weight = child.weight
+                elif getattr(child, "in_proj_weight", None) is not None:  # MHA
+                    weight = child.in_proj_weight
+                else:
+                    weight = next(child.parameters())
+                if not any(p.device == meta for p in module.parameters()):
+                    module.to(weight.device)
+
+    def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None:
+        for n, p in model.named_parameters():
+            if self.prefix not in n:
+                p.requires_grad = False
+
+        # No bias for HiRA, skipping
+
+    @staticmethod
+    def _create_new_module(hira_config: HiRAConfig, adapter_name, target, **kwargs):
+        # Collect dispatcher functions to decide what backend to use for the replaced HiRA layer. The order matters,
+        # because the first match is always used. Therefore, the default layers should be checked last.
+        dispatchers = []
+
+        if hira_config._custom_modules:
+            # Experimental custom HiRA module support. Allows users to pass a custom mapping for unsupported layer
+            # types by implementing their own HiRA layers.
+            def dynamic_dispatch_func(target, adapter_name, hira_config, **kwargs):
+                new_module = None
+
+                if isinstance(target, BaseTunerLayer):
+                    target_base_layer = target.get_base_layer()
+                else:
+                    target_base_layer = target
+
+                for key, custom_cls in hira_config._custom_modules.items():
+                    if isinstance(target_base_layer, key):
+                        new_module = custom_cls(target, adapter_name, **kwargs)
+                        break
+
+                return new_module
+
+            dispatchers.append(dynamic_dispatch_func)
+
+        # avoid eager bnb import
+        if is_bnb_available():
+            from .bnb import dispatch_bnb_8bit
+
+            dispatchers.append(dispatch_bnb_8bit)
+
+        if is_bnb_4bit_available():
+            from .bnb import dispatch_bnb_4bit
+
+            dispatchers.append(dispatch_bnb_4bit)
+        # TODO: Needs check here
+        dispatchers.extend(
+            [
+                dispatch_default,
+            ]
+        )
+
+        new_module = None
+        for dispatcher in dispatchers:
+            new_module = dispatcher(target, adapter_name, hira_config=hira_config, **kwargs)
+            if new_module is not None:  # first match wins
+                break
+
+        if new_module is None:
+            # no module could be matched
+            raise ValueError(
+                f"Target module {target} is not supported. Currently, only the following modules are supported: "
+                "`torch.nn.Linear`, `torch.nn.Embedding`, `torch.nn.Conv1d`, `torch.nn.Conv2d`, `torch.nn.Conv3d`, "
+                "`transformers.pytorch_utils.Conv1D`."
+            )
+
+        return new_module
+
+    def __getattr__(self, name: str):
+        """Forward missing attributes to the wrapped module."""
+        try:
+            return super().__getattr__(name)  # defer to nn.Module's logic
+        except AttributeError:
+            if name == "model":  # see #1892: prevent infinite recursion if class is not initialized
+                raise
+            return getattr(self.model, name)
+
+    def get_peft_config_as_dict(self, inference: bool = False):
+        config_dict = {}
+        for key, value in self.peft_config.items():
+            config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(value).items()}
+            if inference:
+                config["inference_mode"] = True
+            config_dict[key] = config
+        return config_dict
+
+    def _set_adapter_layers(self, enabled: bool = True) -> None:
+        for module in self.model.modules():
+            if isinstance(module, (BaseTunerLayer, AuxiliaryTrainingWrapper)):
+                module.enable_adapters(enabled)
+
+    def enable_adapter_layers(self) -> None:
+        """Enable all adapters.
+
+        Call this if you have previously disabled all adapters and want to re-enable them.
+        """
+        self._set_adapter_layers(enabled=True)
+
+    def disable_adapter_layers(self) -> None:
+        """Disable all adapters.
+
+        When disabling all adapters, the model output corresponds to the output of the base model.
+        """
+        self._set_adapter_layers(enabled=False)
+
+    def set_adapter(self, adapter_name: str | list[str]) -> None:
+        """Set the active adapter(s).
+
+        Additionally, this function will set the specified adapters to trainable (i.e., requires_grad=True). If this is
+        not desired, use the following code.
+
+        ```py
+        >>> for name, param in model_peft.named_parameters():
+        ...     if ...:  # some check on name (ex. if 'hira' in name)
+        ...         param.requires_grad = False
+        ```
+
+        Args:
+            adapter_name (`str` or `list[str]`): Name of the adapter(s) to be activated.
+        """
+        for module in self.model.modules():
+            if isinstance(module, HiRALayer):
+                if module.merged:
+                    warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
+                    module.unmerge()
+                module.set_adapter(adapter_name)
+        self.active_adapter = adapter_name
+
+    @contextmanager
+    def _enable_peft_forward_hooks(self, *args, **kwargs):
+        # If adapter_names is passed as an argument, we inject it into the forward arguments.
+        adapter_names = kwargs.pop("adapter_names", None)
+        if adapter_names is None:
+            # nothing to do
+            yield
+            return
+
+        if self.training:
+            raise ValueError("Cannot pass `adapter_names` when the model is in training mode.")
+
+        # Check that users only passed actually existing adapters.
+        # Note: We cannot do this on the layer level, as each individual layer may not have each adapter. Still, we want
+        # to check that there is at least one layer with the given name, or else something like typos can easily slip.
+ expected_adapters = set() + for layer in self.modules(): + if isinstance(layer, HiRALayer): + expected_adapters |= layer.hira_A.keys() + expected_adapters |= layer.hira_embedding_A.keys() + unique_adapters = {name for name in adapter_names if name != "__base__"} + unexpected_adapters = unique_adapters - expected_adapters + if unexpected_adapters: + raise ValueError(f"Trying to infer with non-existing adapter(s): {', '.join(sorted(unexpected_adapters))}") + + # deal with beam search + num_beams = kwargs.get("num_beams", None) + uses_beam_search = isinstance(num_beams, int) and (num_beams > 1) + original_adapter_names = adapter_names[:] + if uses_beam_search: + if not isinstance(adapter_names, (list, tuple)): + raise TypeError(f"Got adapter names of type {type(adapter_names)}, expected a list of str.") + # When there is beam search, the inputs are repeated n times, thus we repeat each adapter name n times and + # then flatten the nested list. For encoder-decoder models, this extended list should not be applied to the + # encoder part. Further below, the original argument is thus restored for the encoder. + adapter_names = sum(([n] * kwargs["num_beams"] for n in adapter_names), []) + + hook_handles = [] + for module in self.modules(): + if isinstance(module, HiRALayer) or isinstance(module, AuxiliaryTrainingWrapper): + pre_forward = partial(_adapter_names_pre_forward_hook, adapter_names=adapter_names) + handle = module.register_forward_pre_hook(pre_forward, with_kwargs=True) + hook_handles.append(handle) + + if uses_beam_search and hasattr(self.model, "get_encoder"): + # For encoder-decoder models, even when applying beam search, the encoder part of the model should not use + # the extended adapter_names. This is because the encoder still uses the original, non-extended samples. + for module in self.model.get_encoder().modules(): + if isinstance(module, HiRALayer) or isinstance(module, AuxiliaryTrainingWrapper): + # Add another hook to overwrite the kwargs with the original adapter names -- this is easier than + # trying to exclude the encoder. + pre_forward = partial(_adapter_names_pre_forward_hook, adapter_names=original_adapter_names) + handle = module.register_forward_pre_hook(pre_forward, with_kwargs=True) + hook_handles.append(handle) + + yield + + for handle in hook_handles: + handle.remove() + + def _check_merge_allowed(self): + """Verify that the configuration supports merging. + + Currently gptq quantization and replicated layers do not support merging. 
+ """ + super()._check_merge_allowed() + if getattr(self.model, "quantization_method", None) == "gptq": + raise ValueError("Cannot merge HiRA layers when the model is gptq quantized") + if self.peft_config.get("layer_replication"): + raise ValueError("Cannot merge HiRA layers when base model layers are replicated") + + @staticmethod + def _prepare_adapter_config(peft_config, model_config): + if peft_config.target_modules is None: + if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_HIRA_TARGET_MODULES_MAPPING: + raise ValueError("Please specify `target_modules` in `peft_config`") + peft_config.target_modules = set( + TRANSFORMERS_MODELS_TO_HIRA_TARGET_MODULES_MAPPING[model_config["model_type"]] + ) + return peft_config + + def _unload_and_optionally_merge( + self, + merge=True, + progressbar: bool = False, + safe_merge: bool = False, + adapter_names: Optional[list[str]] = None, + ): + if merge: + self._check_merge_allowed() + + key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key] + desc = "Unloading " + ("and merging " if merge else "") + "model" + for key in tqdm(key_list, disable=not progressbar, desc=desc): + try: + parent, target, target_name = _get_submodules(self.model, key) + except AttributeError: + continue + with onload_layer(target): + if hasattr(target, "unload_and_optionally_merge_module"): + # if layers have special unloading method, like MultiheadAttention, use that + unloaded_module = target.unload_and_optionally_merge_module( + merge=merge, safe_merge=safe_merge, adapter_names=adapter_names + ) + self._replace_module(parent, target_name, unloaded_module, target) + elif hasattr(target, "base_layer"): + if merge: + target.merge(safe_merge=safe_merge, adapter_names=adapter_names) + self._replace_module(parent, target_name, target.get_base_layer(), target) + + return self.model + + def _check_add_weighted_adapter( + self, adapters: list[str], combination_type: str, svd_rank: int | None + ) -> tuple[str, int, str]: + """ + Helper function to check if the arguments to add_weighted_adapter are valid and compatible with the underlying + model. + """ + for adapter in adapters: + if adapter not in list(self.peft_config.keys()): + raise ValueError(f"Adapter {adapter} does not exist") + + # If more than one of the adapters targets the same module with modules_to_save, raise an error, as these + # modules cannot be merged. First, find the ModulesToSaveWrapper instances in the model, then check if they + # have modules for the adapters to be merged. + modules_to_save_wrappers = [module for module in self.modules() if isinstance(module, ModulesToSaveWrapper)] + problematic_wrappers = [ + wrapper + for wrapper in modules_to_save_wrappers + if sum(adapter in wrapper.modules_to_save for adapter in adapters) > 1 + ] + if problematic_wrappers: + raise ValueError( + "Cannot add weighted adapters if they target the same module with modules_to_save, but found " + f"{len(problematic_wrappers)} such instance(s)." 
+ ) + + # if there is only one adapter, we can only use linear merging + combination_type = "linear" if len(adapters) == 1 else combination_type + + adapters_rs: list[int] = [ + # When allocating tensors for the new adapter, we need the maximum possible r to not overflow + config.r if not config.r_pattern else max(config.r, *config.r_pattern.values()) + for config in (self.peft_config[adapter] for adapter in adapters) + ] + + if combination_type in ("linear", "ties", "dare_ties", "dare_linear", "magnitude_prune"): + # all adapters ranks should be same, new rank is just this value + if len(set(adapters_rs)) != 1: + raise ValueError( + "All adapters must have the same r value when using combination_type linear, ties, dare_ties or " + "dare_linear." + ) + new_r = adapters_rs[0] + elif combination_type == "cat": + # adapters ranks may be different, new rank is sum of all ranks + # be careful, because output adapter rank may be really big if mixing a lot of adapters + new_r = sum(adapters_rs) + elif combination_type.endswith("svd"): + # new rank is the max of all ranks of the adapters if not provided + new_r = svd_rank or max(adapters_rs) + else: + raise ValueError(f"Invalid combination_type: {combination_type}") + + target_module_types = [type(self.peft_config[adapter].target_modules) for adapter in adapters] + if not target_module_types: + raise ValueError(f"Found no adapter matching the names in {adapters}") + if len(set(target_module_types)) > 1: + raise ValueError( + "all adapter configs should follow the same target modules type. " + "Combining adapters with `target_modules` type being a mix of list/set and string is not supported." + ) + + if target_module_types[0] is str: + new_target_modules = "|".join(f"({self.peft_config[adapter].target_modules})" for adapter in adapters) + elif target_module_types[0] is set: + new_target_modules = reduce( + operator.or_, (self.peft_config[adapter].target_modules for adapter in adapters) + ) + else: + raise TypeError(f"Invalid type {target_module_types[0]} found in target_modules") + + return combination_type, new_r, new_target_modules + + def add_weighted_adapter( + self, + adapters: list[str], + weights: list[float], + adapter_name: str, + combination_type: str = "svd", + svd_rank: int | None = None, + svd_clamp: int | None = None, + svd_full_matrices: bool = True, + svd_driver: str | None = None, + density: float | None = None, + majority_sign_method: Literal["total", "frequency"] = "total", + ) -> None: + """ + This method adds a new adapter by merging the given adapters with the given weights. + + When using the `cat` combination_type you should be aware that rank of the resulting adapter will be equal to + the sum of all adapters ranks. So it's possible that the mixed adapter may become too big and result in OOM + errors. + + Args: + adapters (`list`): + List of adapter names to be merged. + weights (`list`): + List of weights for each adapter. + adapter_name (`str`): + Name of the new adapter. + combination_type (`str`): + The merging type can be one of [`svd`, `linear`, `cat`, `ties`, `ties_svd`, `dare_ties`, `dare_linear`, + `dare_ties_svd`, `dare_linear_svd`, `magnitude_prune`, `magnitude_prune_svd`]. When using the `cat` + combination_type, the rank of the resulting adapter is equal to the sum of all adapters ranks (the + mixed adapter may be too big and result in OOM errors). + svd_rank (`int`, *optional*): + Rank of output adapter for svd. If None provided, will use max rank of merging adapters. 
+ svd_clamp (`float`, *optional*): + A quantile threshold for clamping SVD decomposition output. If None is provided, do not perform + clamping. Defaults to None. + svd_full_matrices (`bool`, *optional*): + Controls whether to compute the full or reduced SVD, and consequently, the shape of the returned + tensors U and Vh. Defaults to True. + svd_driver (`str`, *optional*): + Name of the cuSOLVER method to be used. This keyword argument only works when merging on CUDA. Can be + one of [None, `gesvd`, `gesvdj`, `gesvda`]. For more info please refer to `torch.linalg.svd` + documentation. Defaults to None. + density (`float`, *optional*): + Value between 0 and 1. 0 means all values are pruned and 1 means no values are pruned. Should be used + with [`ties`, `ties_svd`, `dare_ties`, `dare_linear`, `dare_ties_svd`, `dare_linear_svd`, + `magnintude_prune`, `magnitude_prune_svd`] + majority_sign_method (`str`): + The method, should be one of ["total", "frequency"], to use to get the magnitude of the sign values. + Should be used with [`ties`, `ties_svd`, `dare_ties`, `dare_ties_svd`] + """ + + if adapter_name in list(self.peft_config.keys()): + return + + combination_type, new_rank, new_target_modules = self._check_add_weighted_adapter( + adapters=adapters, + combination_type=combination_type, + svd_rank=svd_rank, + ) + + self.peft_config[adapter_name] = replace( + self.peft_config[adapters[0]], + r=new_rank, + target_modules=new_target_modules, + alpha_pattern={}, + r_pattern={}, + ) + self.inject_adapter(self.model, adapter_name) + + # Do we really need that? + _freeze_adapter(self.model, adapter_name) + + key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key] + for key in key_list: + _, target, _ = _get_submodules(self.model, key) + if isinstance(target, HiRALayer): + if adapter_name in target.hira_A: + target_hira_A = target.hira_A[adapter_name] + target_hira_B = target.hira_B[adapter_name] + elif adapter_name in target.hira_embedding_A: + target_hira_A = target.hira_embedding_A[adapter_name] + target_hira_B = target.hira_embedding_B[adapter_name] + else: + continue + + target_hira_A.data = target_hira_A.data * 0.0 + target_hira_B.data = target_hira_B.data * 0.0 + if combination_type == "cat": + hiras_A, hiras_B = [], [] + for adapter, weight in zip(adapters, weights): + if adapter in target.hira_A: + current_adapter_hira_A = target.hira_A[adapter] + current_adapter_hira_B = target.hira_B[adapter] + elif adapter in target.hira_embedding_A: + current_adapter_hira_A = target.hira_embedding_A[adapter] + current_adapter_hira_B = target.hira_embedding_B[adapter] + else: + continue + hiras_A.append(current_adapter_hira_A.data * weight) + hiras_B.append(current_adapter_hira_B.data) + + if len(hiras_A) == 0: + raise ValueError("No matching HiRAs found. 
Please raise an issue on GitHub.") + hiras_A = torch.cat(hiras_A, dim=0) + hiras_B = torch.cat(hiras_B, dim=1) + target_hira_A.data[: hiras_A.shape[0], :] = hiras_A + target_hira_B.data[:, : hiras_B.shape[1]] = hiras_B + elif combination_type in [ + "svd", + "ties_svd", + "dare_linear_svd", + "dare_ties_svd", + "magnitude_prune_svd", + ]: + target_hira_A.data, target_hira_B.data = self._svd_generalized_task_arithmetic_weighted_adapter( + combination_type, + adapters, + weights, + new_rank, + target, + target_hira_A, + target_hira_B, + density, + majority_sign_method, + svd_clamp, + full_matrices=svd_full_matrices, + driver=svd_driver, + ) + elif combination_type in ["linear", "ties", "dare_linear", "dare_ties", "magnitude_prune"]: + target_hira_A.data, target_hira_B.data = self._generalized_task_arithmetic_weighted_adapter( + combination_type, adapters, weights, target, density, majority_sign_method + ) + + def delete_adapter(self, adapter_name: str) -> None: + """ + Deletes an existing adapter. + + Args: + adapter_name (str): Name of the adapter to be deleted. + """ + if adapter_name not in list(self.peft_config.keys()): + raise ValueError(f"Adapter {adapter_name} does not exist") + del self.peft_config[adapter_name] + + key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key] + new_adapter = None + for key in key_list: + _, target, _ = _get_submodules(self.model, key) + if isinstance(target, HiRALayer): + target.delete_adapter(adapter_name) + if new_adapter is None: + new_adapter = target.active_adapters[:] + + self.active_adapter = new_adapter or [] + self._delete_auxiliary_adapter(adapter_name, new_active_adapters=new_adapter) + + def merge_and_unload( + self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[list[str]] = None + ) -> torch.nn.Module: + r""" + This method merges the HiRA layers into the base model. This is needed if someone wants to use the base model + as a standalone model. + + Args: + progressbar (`bool`): + whether to show a progressbar indicating the unload and merge process + safe_merge (`bool`): + whether to activate the safe merging check to check if there is any potential Nan in the adapter + weights + adapter_names (`List[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. + Example: + + ```py + >>> from transformers import AutoModelForCausalLM + >>> from peft import PeftModel + + >>> base_model = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-40b") + >>> peft_model_id = "smangrul/falcon-40B-int4-peft-hira-sfttrainer-sample" + >>> model = PeftModel.from_pretrained(base_model, peft_model_id) + >>> merged_model = model.merge_and_unload() + ``` + """ + return self._unload_and_optionally_merge( + progressbar=progressbar, safe_merge=safe_merge, adapter_names=adapter_names + ) + + def unload(self) -> torch.nn.Module: + """ + Gets back the base model by removing all the hira modules without merging. This gives back the original base + model. 
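+
+        Example:
+
+        ```py
+        >>> from transformers import AutoModelForCausalLM
+        >>> from peft import PeftModel
+
+        >>> # usage sketch: the adapter path below is a placeholder for a saved HiRA adapter
+        >>> base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
+        >>> model = PeftModel.from_pretrained(base_model, "path/to/hira-adapter")
+        >>> base_model = model.unload()
+        ```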
+ """ + return self._unload_and_optionally_merge(merge=False) diff --git a/src/peft/utils/constants.py b/src/peft/utils/constants.py index 02453b283e..3fdefd0ca8 100644 --- a/src/peft/utils/constants.py +++ b/src/peft/utils/constants.py @@ -136,6 +136,8 @@ def starcoder_model_postprocess_past_key_value(past_key_values): "qwen3": ["q_proj", "v_proj"], } +TRANSFORMERS_MODELS_TO_HIRA_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING + TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING = { "t5": ["k", "v", "wo"], "mt5": ["k", "v", "wi_1"], diff --git a/src/peft/utils/peft_types.py b/src/peft/utils/peft_types.py index 6e4aeae248..35886f839b 100644 --- a/src/peft/utils/peft_types.py +++ b/src/peft/utils/peft_types.py @@ -41,6 +41,7 @@ class PeftType(str, enum.Enum): - HRA - BONE - RANDLORA + - HIRA - SHIRA - C3A """ @@ -68,6 +69,7 @@ class PeftType(str, enum.Enum): BONE = "BONE" RANDLORA = "RANDLORA" TRAINABLE_TOKENS = "TRAINABLE_TOKENS" + HIRA = "HIRA" SHIRA = "SHIRA" C3A = "C3A" diff --git a/tests/test_gpu_examples.py b/tests/test_gpu_examples.py index 4ca9be8898..558016a58d 100644 --- a/tests/test_gpu_examples.py +++ b/tests/test_gpu_examples.py @@ -58,6 +58,7 @@ from peft import ( AdaLoraConfig, EvaConfig, + HiRAConfig, LoftQConfig, LoraConfig, PeftModel, @@ -1791,6 +1792,118 @@ def tokenize(samples): # sanity check: assert loss is not None assert trainer.state.log_history[-1]["train_loss"] is not None + @pytest.mark.single_gpu_tests + def test_causal_lm_training_hira(self): + r""" + Test the CausalLM training on a single GPU device. This test is a converted version of + https://github.com/huggingface/peft/blob/main/examples/int8_training/Finetune_opt_bnb_peft.ipynb where we train + `opt-6.7b` on `english_quotes` dataset in few steps. The test would simply fail if the adapters are not set + correctly. + """ + with tempfile.TemporaryDirectory() as tmp_dir: + model = AutoModelForCausalLM.from_pretrained( + self.causal_lm_model_id, + quantization_config=BitsAndBytesConfig(load_in_8bit=True), + device_map="auto", + ) + + tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id) + model = prepare_model_for_kbit_training(model) + + config = HiRAConfig( + r=16, + target_modules=["q_proj", "v_proj"], + hira_dropout=0.05, + task_type="CAUSAL_LM", + ) + + model = get_peft_model(model, config) + + data = load_dataset_english_quotes() + data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True) + + trainer = Trainer( + model=model, + train_dataset=data["train"], + args=TrainingArguments( + per_device_train_batch_size=4, + gradient_accumulation_steps=4, + warmup_steps=2, + max_steps=3, + learning_rate=2e-4, + fp16=True, + logging_steps=1, + output_dir=tmp_dir, + ), + data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False), + ) + model.config.use_cache = False + trainer.train() + + model.cpu().save_pretrained(tmp_dir) + + assert "adapter_config.json" in os.listdir(tmp_dir) + assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir) + + # assert loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + @pytest.mark.single_gpu_tests + def test_causal_lm_training_4bit_hira(self): + r""" + Test the CausalLM training on a single GPU device. This test is a converted version of + https://github.com/huggingface/peft/blob/main/examples/int8_training/Finetune_opt_bnb_peft.ipynb where we train + `opt-6.7b` on `english_quotes` dataset in few steps using 4bit base model. 
The test would simply fail if the + adapters are not set correctly. + """ + with tempfile.TemporaryDirectory() as tmp_dir: + model = AutoModelForCausalLM.from_pretrained( + self.causal_lm_model_id, + quantization_config=BitsAndBytesConfig(load_in_4bit=True), + device_map="auto", + ) + + tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id) + model = prepare_model_for_kbit_training(model) + + config = HiRAConfig( + r=16, + target_modules=["q_proj", "v_proj"], + hira_dropout=0.05, + task_type="CAUSAL_LM", + ) + + model = get_peft_model(model, config) + + data = load_dataset_english_quotes() + data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True) + + trainer = Trainer( + model=model, + train_dataset=data["train"], + args=TrainingArguments( + per_device_train_batch_size=4, + gradient_accumulation_steps=4, + warmup_steps=2, + max_steps=3, + learning_rate=2e-4, + fp16=True, + logging_steps=1, + output_dir=tmp_dir, + ), + data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False), + ) + model.config.use_cache = False + trainer.train() + + model.cpu().save_pretrained(tmp_dir) + + assert "adapter_config.json" in os.listdir(tmp_dir) + assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir) + + # assert loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + @require_torch_gpu @require_auto_gptq diff --git a/tests/test_hira.py b/tests/test_hira.py new file mode 100644 index 0000000000..e1261dd6b0 --- /dev/null +++ b/tests/test_hira.py @@ -0,0 +1,164 @@ +import pytest +import torch +import torch.nn as nn + +from peft.tuners.hira import Linear # Assuming your HiRA implementation is under peft.tuners.hira +from peft.tuners.hira.layer import Conv1d as HiraConv1d +from peft.tuners.hira.layer import Conv2d as HiraConv2d + + +def test_hira_linear_merge_unmerge_basic(): + """ + Basic test for HiRA Linear layer: ensures merge and unmerge preserve outputs. 
+ """ + # Setup + input_dim, output_dim, rank = 10, 5, 2 + adapter_name = "test_adapter" + batch_size = 4 + + # Base layer + base_linear = nn.Linear(input_dim, output_dim) + + # Wrap base layer with HiRA + hira_linear = Linear( + base_layer=base_linear, + adapter_name=adapter_name, + r=rank, + hira_alpha=rank, + hira_dropout=0.0, + init_hira_weights=True, + ) + + # Dummy input + x = torch.randn(batch_size, input_dim) + + # Forward pass without merging + output_before_merge = hira_linear(x) + + # Merge adapter weights + hira_linear.merge() + output_after_merge = hira_linear(x) + + # Assert merge preserves output + assert torch.allclose(output_before_merge, output_after_merge, atol=1e-5), ( + "Merged HiRA Linear output doesn't match original" + ) + + # Unmerge adapter weights + hira_linear.unmerge() + output_after_unmerge = hira_linear(x) + + # Assert unmerge restores original output + assert torch.allclose(output_before_merge, output_after_unmerge, atol=1e-5), ( + "Unmerged HiRA Linear output doesn't match original" + ) + + +@pytest.mark.parametrize( + "batch_size,in_ch,out_ch,length,rank", + [ + (2, 4, 6, 10, 3), + (3, 2, 5, 8, 2), + ], +) +def test_hira_conv1d_merge_unmerge(batch_size, in_ch, out_ch, length, rank): + base_conv = nn.Conv1d(in_ch, out_ch, kernel_size=3, padding=1) + hira_conv = HiraConv1d( + base_layer=base_conv, + adapter_name="test_adapter", + r=rank, + hira_dropout=0.0, + init_hira_weights=True, + ) + x = torch.randn(batch_size, in_ch, length) + + # Before merge + y0 = hira_conv(x) + + # Merge into W and test + hira_conv.merge() + y1 = hira_conv(x) + assert torch.allclose(y0, y1, atol=1e-5), "Merged Conv1d HiRA output doesn't match original" + + # Unmerge and test + hira_conv.unmerge() + y2 = hira_conv(x) + assert torch.allclose(y0, y2, atol=1e-5), "Unmerged Conv1d HiRA output doesn't match original" + + +@pytest.mark.parametrize( + "batch_size,in_ch,out_ch,H,W,rank", + [ + (2, 3, 5, 8, 8, 2), + (1, 1, 4, 10, 10, 1), + ], +) +def test_hira_conv2d_merge_unmerge(batch_size, in_ch, out_ch, H, W, rank): + base_conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1) + hira_conv = HiraConv2d( + base_layer=base_conv, + adapter_name="test_adapter", + r=rank, + hira_dropout=0.0, + init_hira_weights=True, + ) + x = torch.randn(batch_size, in_ch, H, W) + + # Before merge + y0 = hira_conv(x) + + # Merge into W and test + hira_conv.merge() + y1 = hira_conv(x) + assert torch.allclose(y0, y1, atol=1e-5), "Merged Conv2d HiRA output doesn't match original" + + # Unmerge and test + hira_conv.unmerge() + y2 = hira_conv(x) + assert torch.allclose(y0, y2, atol=1e-5), "Unmerged Conv2d HiRA output doesn't match original" + + +def test_manual_hira_linear_equivalence(): + import torch.nn.functional as F + + torch.manual_seed(42) + batch_size, input_dim, output_dim, rank = 3, 8, 6, 2 + adapter_name = "manual_test" + + # create base linear and HiRA wrapper + base = nn.Linear(input_dim, output_dim, bias=False) + # init W0 to something deterministic + nn.init.uniform_(base.weight, -0.5, 0.5) + + hira = Linear( + base_layer=base, + adapter_name=adapter_name, + r=rank, + hira_dropout=0.0, + init_hira_weights=True, + ) + # force A, B to known values + with torch.no_grad(): + hira.hira_A[adapter_name].copy_(torch.randn(rank, input_dim)) + hira.hira_B[adapter_name].copy_(torch.randn(output_dim, rank)) + + x = torch.randn(batch_size, input_dim) + + # HiRA forward (without merging) + y_hira = hira(x) + + # manual forward + W0 = base.weight.data # (out, in) + A = hira.hira_A[adapter_name] # (r, in) + B = 
hira.hira_B[adapter_name] # (out, r) + BA = B @ A # (out, in) + effW = W0 * BA # element-wise + # base output + y0 = F.linear(x, W0) # (batch, out) + # delta output + y_delta = F.linear(x, effW) + y_manual = y0 + y_delta + + assert torch.allclose(y_hira, y_manual, atol=1e-6), ( + f"HiRA forward mismatch: max diff = {(y_hira - y_manual).abs().max()}" + )
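+
+
+# Additional check (illustrative sketch): the Embedding adapter should preserve outputs
+# across merge/unmerge, just like the Linear and Conv adapters tested above. The import
+# path `peft.tuners.hira.layer.Embedding` is assumed to mirror the Conv1d/Conv2d imports
+# at the top of this file.
+def test_hira_embedding_merge_unmerge():
+    from peft.tuners.hira.layer import Embedding as HiraEmbedding
+
+    base_emb = nn.Embedding(num_embeddings=20, embedding_dim=8)
+    hira_emb = HiraEmbedding(
+        base_layer=base_emb,
+        adapter_name="test_adapter",
+        r=2,
+        hira_dropout=0.0,
+        init_hira_weights=True,
+    )
+    x = torch.randint(0, 20, (4, 6))
+
+    # Before merge
+    y0 = hira_emb(x)
+
+    # Merge the multiplicative update into the embedding weight and compare
+    hira_emb.merge()
+    y1 = hira_emb(x)
+    assert torch.allclose(y0, y1, atol=1e-5), "Merged Embedding HiRA output doesn't match original"
+
+    # Unmerge and check that the original output is restored
+    hira_emb.unmerge()
+    y2 = hira_emb(x)
+    assert torch.allclose(y0, y2, atol=1e-5), "Unmerged Embedding HiRA output doesn't match original"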