diff --git a/areal/experimental/engine/archon_engine.py b/areal/experimental/engine/archon_engine.py index 52262be172..7644ff0521 100644 --- a/areal/experimental/engine/archon_engine.py +++ b/areal/experimental/engine/archon_engine.py @@ -192,6 +192,25 @@ def __init__(self, config: TrainEngineConfig): self._initialized = False self.is_offload = False + # LoRA Configuration (extract from config if enabled) + self.lora_config = None + if hasattr(config, "use_lora") and config.use_lora: + from dataclasses import dataclass + + @dataclass + class LoRAConfig: + enabled: bool + rank: int + alpha: float + target_modules: list[str] + + self.lora_config = LoRAConfig( + enabled=True, + rank=config.lora_rank, + alpha=float(config.lora_alpha), + target_modules=config.target_modules if config.target_modules else [], + ) + def create_process_group( self, parallel_strategy: ParallelStrategy | None = None, @@ -661,7 +680,19 @@ def update_weights(self, meta: WeightUpdateMeta): ) def save(self, meta: SaveLoadMeta): - """Save model in HuggingFace or DCP format.""" + """Save model in HuggingFace or DCP format. + + When LoRA is enabled, only the adapter weights are saved in PEFT format. + When LoRA is disabled, the full model is saved. + """ + if self.lora_config is not None: + from areal.experimental.engine.archon_lora_checkpoint import ( + save_lora_adapter, + ) + + save_lora_adapter(self, meta.path, meta.base_model_path) + return + if meta.weight_format == "hf": save_model_to_hf(self, meta.path, meta.tokenizer, meta.processor) elif meta.weight_format == "dcp": @@ -673,7 +704,20 @@ def save(self, meta: SaveLoadMeta): save_optimizer_state(self, meta.path) def load(self, meta: SaveLoadMeta): - """Load model from HuggingFace or DCP format.""" + """Load model from HuggingFace or DCP format. + + When LoRA is enabled and the checkpoint is a PEFT adapter, + only adapter weights are loaded. + """ + from areal.experimental.engine.archon_lora_checkpoint import ( + is_lora_adapter_checkpoint, + load_lora_adapter, + ) + + if self.lora_config is not None and is_lora_adapter_checkpoint(meta.path): + load_lora_adapter(self, meta.path) + return + if meta.weight_format == "hf": load_model_from_hf(self, meta.path) elif meta.weight_format == "dcp": diff --git a/areal/experimental/engine/archon_lora_checkpoint.py b/areal/experimental/engine/archon_lora_checkpoint.py new file mode 100644 index 0000000000..888c42985c --- /dev/null +++ b/areal/experimental/engine/archon_lora_checkpoint.py @@ -0,0 +1,241 @@ +"""LoRA adapter checkpoint I/O in PEFT format. + +This module provides functions to save and load LoRA adapters in PEFT-compatible +format for HuggingFace ecosystem interoperability. + +PEFT checkpoint structure: + adapter_checkpoint/ + ├── adapter_model.safetensors # LoRA weights only + └── adapter_config.json # PEFT configuration + +Reference: peft/src/peft/utils/save_and_load.py +""" + +from __future__ import annotations + +import json +import os +from pathlib import Path +from typing import TYPE_CHECKING + +import torch +import torch.distributed as dist +from safetensors.torch import load_file, save_file + +from areal.experimental.models.archon.lora.adapter import get_adapter_params +from areal.utils import logging + +if TYPE_CHECKING: + from areal.experimental.engine.archon_engine import ArchonEngine + +logger = logging.getLogger("LoRACheckpoint") + + +def save_lora_adapter( + engine: "ArchonEngine", + path: str, + base_model_path: str | None = None, +) -> None: + """Save LoRA adapter in PEFT format. + + Creates two files: + - adapter_model.safetensors: LoRA weights (lora_a, lora_b) + - adapter_config.json: PEFT configuration + + Args: + engine: ArchonEngine instance with LoRA-enabled model + path: Directory path to save adapter checkpoint + base_model_path: Optional path to base model (for config reference) + + Raises: + RuntimeError: If LoRA is not enabled on engine + """ + if engine.lora_config is None: + raise RuntimeError("Cannot save LoRA adapter: LoRA not enabled on engine") + + if dist.is_initialized(): + rank = dist.get_rank() + else: + rank = 0 + + if rank == 0: + os.makedirs(path, exist_ok=True) + logger.info(f"Saving LoRA adapter to {path}") + + # Extract adapter parameters from model + adapter_params = get_adapter_params(engine.model) + + if not adapter_params: + logger.warning("No adapter parameters found in model") + if rank == 0: + logger.warning("Creating empty adapter checkpoint") + + # Convert to HF format using state dict adapter + archon_state = {k: v.cpu().detach().clone() for k, v in adapter_params.items()} + hf_state = engine.state_dict_adapter.to_hf(archon_state) + + # Add PEFT prefix: base_model.model.{key} + peft_state = {f"base_model.model.{k}": v for k, v in hf_state.items()} + + # Save weights (only rank 0) + if rank == 0: + weights_path = os.path.join(path, "adapter_model.safetensors") + save_file(peft_state, weights_path) + logger.info(f"Saved {len(peft_state)} adapter tensors to {weights_path}") + + # Determine target modules from actual adapter parameters + target_modules = set() + for key in adapter_params: + parts = key.split(".") + for i, part in enumerate(parts): + if part in ("lora_a", "lora_b") and i > 0: + module_name = parts[i - 1] + target_modules.add(module_name) + break + + # Create config copy with actual target modules + from dataclasses import replace + + lora_config_for_save = replace( + engine.lora_config, target_modules=sorted(target_modules) + ) + + # Generate adapter config using model-specific state dict adapter + adapter_config = engine.state_dict_adapter.create_peft_adapter_config( + lora_config=lora_config_for_save, + base_model_path=base_model_path, + ) + + config_path = os.path.join(path, "adapter_config.json") + with open(config_path, "w") as f: + json.dump(adapter_config, f, indent=2) + logger.info(f"Saved adapter config to {config_path}") + + # Synchronize all ranks + if dist.is_initialized(): + dist.barrier() + + +def load_lora_adapter( + engine: "ArchonEngine", + path: str, + strict: bool = True, +) -> None: + """Load LoRA adapter from PEFT format checkpoint. + + Args: + engine: ArchonEngine instance with LoRA-enabled model + path: Directory path containing adapter checkpoint + strict: If True, raise error on missing/unexpected keys + + Raises: + RuntimeError: If LoRA is not enabled on engine + FileNotFoundError: If adapter checkpoint files not found + ValueError: If strict=True and keys don't match + """ + if engine.lora_config is None: + raise RuntimeError("Cannot load LoRA adapter: LoRA not enabled on engine") + + if dist.is_initialized(): + rank = dist.get_rank() + else: + rank = 0 + + if rank == 0: + logger.info(f"Loading LoRA adapter from {path}") + + # Load adapter weights + weights_path = os.path.join(path, "adapter_model.safetensors") + if not os.path.exists(weights_path): + # Fallback to .bin format + weights_path = os.path.join(path, "adapter_model.bin") + if not os.path.exists(weights_path): + raise FileNotFoundError( + f"Adapter weights not found at {path}. " + "Expected adapter_model.safetensors or adapter_model.bin" + ) + peft_state = torch.load(weights_path, map_location="cpu", weights_only=True) + else: + peft_state = load_file(weights_path) + + if rank == 0: + logger.info(f"Loaded {len(peft_state)} adapter tensors from {weights_path}") + + # Strip PEFT prefix: base_model.model.{key} -> {key} + hf_state = {} + for key, value in peft_state.items(): + if key.startswith("base_model.model."): + hf_key = key.replace("base_model.model.", "", 1) + hf_state[hf_key] = value + else: + hf_state[key] = value + + # Convert from HF format to Archon format + archon_state = engine.state_dict_adapter.from_hf(hf_state) + + # Get expected adapter keys from model + expected_adapter_params = get_adapter_params(engine.model) + expected_keys = set(expected_adapter_params.keys()) + loaded_keys = set(archon_state.keys()) + + missing_keys = expected_keys - loaded_keys + unexpected_keys = loaded_keys - expected_keys + + if missing_keys or unexpected_keys: + if strict: + error_msg = [] + if missing_keys: + error_msg.append(f"Missing keys: {sorted(missing_keys)[:5]}...") + if unexpected_keys: + error_msg.append(f"Unexpected keys: {sorted(unexpected_keys)[:5]}...") + raise ValueError( + "Adapter checkpoint keys don't match model. " + " ".join(error_msg) + ) + else: + if missing_keys and rank == 0: + logger.warning( + f"Missing {len(missing_keys)} adapter keys: " + f"{sorted(missing_keys)[:5]}..." + ) + if unexpected_keys and rank == 0: + logger.warning( + f"Unexpected {len(unexpected_keys)} adapter keys: " + f"{sorted(unexpected_keys)[:5]}..." + ) + + # Load adapter weights into model + loaded_count = 0 + for key, value in archon_state.items(): + if key in expected_adapter_params: + param = expected_adapter_params[key] + value = value.to(device=param.device, dtype=param.dtype) + param.data.copy_(value) + loaded_count += 1 + + if rank == 0: + logger.info(f"Loaded {loaded_count} adapter parameters into model") + + if dist.is_initialized(): + dist.barrier() + + +def is_lora_adapter_checkpoint(path: str) -> bool: + """Check if path contains a PEFT LoRA adapter checkpoint. + + Args: + path: Directory path to check + + Returns: + True if path contains adapter_config.json with peft_type="LORA" + """ + config_path = Path(path) / "adapter_config.json" + + if not config_path.exists(): + return False + + try: + with open(config_path) as f: + config = json.load(f) + return config.get("peft_type") == "LORA" + except (OSError, json.JSONDecodeError): + return False diff --git a/areal/experimental/models/archon/base.py b/areal/experimental/models/archon/base.py index fd8978e203..d05f6eab5a 100644 --- a/areal/experimental/models/archon/base.py +++ b/areal/experimental/models/archon/base.py @@ -58,6 +58,9 @@ def __init__( self.to_hf_map: dict[str, str] = {} self.hf_assets_path = hf_assets_path self.fqn_to_index_mapping = None + # Model-specific mapping from Archon module names to PEFT module names + # Subclasses should define this mapping for LoRA adapter config generation + self.to_peft_module_map: dict[str, str] = {} if hf_assets_path: self._load_safetensors_index(hf_assets_path) @@ -115,6 +118,43 @@ def convert_single_to_hf( self, name: str, tensor: torch.Tensor ) -> list[tuple[str, torch.Tensor]]: ... + def create_peft_adapter_config( + self, lora_config: Any, base_model_path: str | None = None + ) -> dict: + """Generate adapter_config.json for PEFT format checkpoint. + + Args: + lora_config: LoRA configuration object with rank, alpha, target_modules + base_model_path: Optional path to base model + + Returns: + Dictionary containing PEFT adapter configuration + """ + # Convert Archon module names to PEFT names using model-specific mapping + peft_target_modules = [ + self.to_peft_module_map.get(name, name) + for name in lora_config.target_modules + ] + + return { + "peft_type": "LORA", + "auto_mapping": None, + "base_model_name_or_path": base_model_path or "", + "revision": None, + "task_type": "CAUSAL_LM", + "inference_mode": False, + "r": lora_config.rank, + "lora_alpha": int(lora_config.alpha), + "lora_dropout": 0.0, + "target_modules": peft_target_modules, + "fan_in_fan_out": False, + "bias": "none", + "modules_to_save": None, + "init_lora_weights": True, + "layers_to_transform": None, + "layers_pattern": None, + } + class BaseArchonModel(nn.Module, ABC): """Base class for Archon models.""" diff --git a/areal/experimental/models/archon/lora/__init__.py b/areal/experimental/models/archon/lora/__init__.py new file mode 100644 index 0000000000..36b1ec13ad --- /dev/null +++ b/areal/experimental/models/archon/lora/__init__.py @@ -0,0 +1,26 @@ +"""LoRA (Low-Rank Adaptation) infrastructure for Archon engine. + +Following torchtune's design patterns for FSDP2 compatibility. +This module provides custom LoRALinear implementation and utilities +for parameter-efficient fine-tuning of large language models. +""" + +from areal.experimental.models.archon.lora.adapter import ( + AdapterModule, + disable_adapter, + enable_adapter, + get_adapter_params, + get_adapter_state_dict, + set_trainable_params, +) +from areal.experimental.models.archon.lora.lora_linear import LoRALinear + +__all__ = [ + "LoRALinear", + "AdapterModule", + "get_adapter_params", + "get_adapter_state_dict", + "set_trainable_params", + "disable_adapter", + "enable_adapter", +] diff --git a/areal/experimental/models/archon/lora/adapter.py b/areal/experimental/models/archon/lora/adapter.py new file mode 100644 index 0000000000..f6325c66f4 --- /dev/null +++ b/areal/experimental/models/archon/lora/adapter.py @@ -0,0 +1,111 @@ +"""AdapterModule protocol and utilities for LoRA parameter management. + +Reference: torchtune/torchtune/modules/peft/_utils.py +Provides utilities for extracting, filtering, and managing adapter parameters. +""" + +from typing import Protocol, runtime_checkable + +import torch.nn as nn + + +@runtime_checkable +class AdapterModule(Protocol): + """Protocol for modules that contain adapter parameters. + + Any module implementing this protocol should provide an adapter_params() + method that returns a list of parameter names (relative to the module) + that should be treated as trainable adapters. + """ + + def adapter_params(self) -> list[str]: + """Return list of adapter parameter names relative to this module. + + Returns: + List of parameter names (e.g., ["lora_a.weight", "lora_b.weight"]) + """ + ... + + +def get_adapter_params(model: nn.Module) -> dict[str, nn.Parameter]: + """Extract all adapter parameters from model using AdapterModule protocol. + + Walks through all modules in the model and collects parameters from modules + that implement the AdapterModule protocol. + + Args: + model: Model to extract adapter parameters from + + Returns: + Dictionary mapping fully-qualified parameter names to Parameter objects + """ + adapter_params = {} + + for module_name, module in model.named_modules(): + if isinstance(module, AdapterModule): + current_adapter_params = module.adapter_params() + + for param_name, param in module.named_parameters(recurse=True): + if param_name in current_adapter_params: + full_key = ( + f"{module_name}.{param_name}" if module_name else param_name + ) + adapter_params[full_key] = param + + return adapter_params + + +def set_trainable_params(model: nn.Module, adapter_param_names: set[str]) -> None: + """Freeze all parameters except those in adapter_param_names. + + Args: + model: Model to configure + adapter_param_names: Set of fully-qualified parameter names to keep trainable + """ + for name, param in model.named_parameters(): + param.requires_grad_(name in adapter_param_names) + + +def get_adapter_state_dict(state_dict: dict, device: str = "cpu") -> dict: + """Filter state dict to only adapter parameters. + + Args: + state_dict: Full model state dict + device: Device to move parameters to (default: "cpu") + + Returns: + Filtered state dict containing only adapter parameters + """ + + def is_adapter_key(k: str) -> bool: + return "lora_a" in k or "lora_b" in k + + return {k: v.to(device) for k, v in state_dict.items() if is_adapter_key(k)} + + +def disable_adapter(model: nn.Module) -> None: + """Disable LoRA adapters in all LoRALinear modules. + + Sets the ``disabled`` flag to True, causing forward passes to only use + the base weights. Useful for reference models in DPO/PPO. + + Args: + model: Model containing LoRALinear modules + """ + for module in model.modules(): + if isinstance(module, AdapterModule) and hasattr(module, "disabled"): + module.disabled = True + + +def enable_adapter(model: nn.Module) -> None: + """Enable LoRA adapters in all LoRALinear modules. + + Sets the ``disabled`` flag to False, enabling LoRA contributions + during forward passes. + + Args: + model: Model containing LoRALinear modules + """ + for module in model.modules(): + if isinstance(module, AdapterModule) and hasattr(module, "disabled"): + module.disabled = False diff --git a/areal/experimental/models/archon/lora/lora_linear.py b/areal/experimental/models/archon/lora/lora_linear.py new file mode 100644 index 0000000000..e66da99ac1 --- /dev/null +++ b/areal/experimental/models/archon/lora/lora_linear.py @@ -0,0 +1,162 @@ +"""LoRALinear module implementation following torchtune patterns. + +Reference: torchtune/torchtune/modules/peft/lora.py +This implementation provides FSDP2-compatible LoRA layers that work naturally +with Archon's meta device initialization flow. +""" + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class LoRALinear(nn.Module): + """Linear layer with Low-Rank Adaptation (LoRA). + + LoRA decomposes weight updates into low-rank matrices A and B: + W' = W + (alpha/rank) * B @ A + + During forward pass: + output = x @ W^T + (alpha/rank) * x @ A^T @ B^T + + Args: + in_dim: Input dimension + out_dim: Output dimension + rank: LoRA rank (r parameter) + alpha: LoRA scaling factor + dropout: Dropout probability for LoRA path (default: 0.0) + use_bias: Whether to include bias term (default: False) + + Attributes: + weight: Base weight parameter (frozen during LoRA training) + bias: Optional bias parameter (frozen during LoRA training) + lora_a: Low-rank matrix A (trainable) + lora_b: Low-rank matrix B (trainable) + dropout: Dropout layer for LoRA path + scaling: Computed scaling factor (alpha/rank) + disabled: Flag to disable LoRA during forward pass (for reference models) + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + rank: int, + alpha: float, + dropout: float = 0.0, + use_bias: bool = False, + ): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.rank = rank + self.alpha = alpha + self.scaling = alpha / rank + self.disabled = False + + # Base weight (frozen during LoRA training) + self.weight = nn.Parameter(torch.empty(out_dim, in_dim)) + if use_bias: + self.bias = nn.Parameter(torch.empty(out_dim)) + else: + self.register_parameter("bias", None) + + # LoRA adapters (trainable) + # Note: naming lora_a, lora_b (lowercase) matches PEFT convention + self.lora_a = nn.Linear(in_dim, rank, bias=False) + self.lora_b = nn.Linear(rank, out_dim, bias=False) + self.dropout = nn.Dropout(dropout) + + self._initialize_weights() + + def _initialize_weights(self): + """Initialize weights following torchtune/PEFT conventions. + + Base weight: Kaiming uniform (will be overwritten by checkpoint loading) + Bias: Zeros + lora_a: Kaiming uniform (random initialization) + lora_b: Zeros (so initial LoRA contribution is 0) + """ + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + nn.init.zeros_(self.bias) + + # LoRA init: lora_a=kaiming, lora_b=zeros + # This ensures initial output matches base model output + nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5)) + nn.init.zeros_(self.lora_b.weight) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass with LoRA. + + Args: + x: Input tensor of shape [..., in_dim] + + Returns: + Output tensor of shape [..., out_dim] + """ + # Base forward pass + base_out = F.linear(x, self.weight, self.bias) + + # If LoRA is disabled (e.g., for reference model), return base output + if self.disabled: + return base_out + + # LoRA forward pass: dropout -> A -> B -> scale + lora_out = self.lora_b(self.lora_a(self.dropout(x))) + return base_out + self.scaling * lora_out + + @classmethod + def from_linear( + cls, + linear: nn.Linear, + rank: int, + alpha: float, + dropout: float = 0.0, + ) -> "LoRALinear": + """Convert an existing nn.Linear to LoRALinear. + + Args: + linear: Existing linear layer + rank: LoRA rank + alpha: LoRA scaling factor + dropout: Dropout probability (default: 0.0) + + Returns: + LoRALinear with base weights copied from input linear layer + """ + lora_linear = cls( + in_dim=linear.in_features, + out_dim=linear.out_features, + rank=rank, + alpha=alpha, + dropout=dropout, + use_bias=linear.bias is not None, + ) + + # Copy base weights from original linear layer + lora_linear.weight.data.copy_(linear.weight.data) + if linear.bias is not None: + lora_linear.bias.data.copy_(linear.bias.data) + + return lora_linear + + def adapter_params(self) -> list[str]: + """Return list of adapter parameter names. + + Implements AdapterModule protocol for parameter extraction. + + Returns: + List of parameter names relative to this module + """ + return ["lora_a.weight", "lora_b.weight"] + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"in_dim={self.in_dim}, out_dim={self.out_dim}, " + f"rank={self.rank}, alpha={self.alpha}, " + f"dropout={self.dropout.p}, bias={self.bias is not None})" + ) diff --git a/areal/experimental/models/archon/qwen2/model/state_dict_adapter.py b/areal/experimental/models/archon/qwen2/model/state_dict_adapter.py index 7021282e2b..6ad2695f39 100644 --- a/areal/experimental/models/archon/qwen2/model/state_dict_adapter.py +++ b/areal/experimental/models/archon/qwen2/model/state_dict_adapter.py @@ -39,6 +39,25 @@ def __init__( "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight", "model.norm.weight": "norm.weight", "lm_head.weight": "output.weight", + # LoRA adapter key mappings (Attention) + "model.layers.{}.self_attn.q_proj.lora_A.weight": "layers.{}.attention.wq.lora_a.weight", + "model.layers.{}.self_attn.q_proj.lora_B.weight": "layers.{}.attention.wq.lora_b.weight", + "model.layers.{}.self_attn.k_proj.lora_A.weight": "layers.{}.attention.wk.lora_a.weight", + "model.layers.{}.self_attn.k_proj.lora_B.weight": "layers.{}.attention.wk.lora_b.weight", + "model.layers.{}.self_attn.v_proj.lora_A.weight": "layers.{}.attention.wv.lora_a.weight", + "model.layers.{}.self_attn.v_proj.lora_B.weight": "layers.{}.attention.wv.lora_b.weight", + "model.layers.{}.self_attn.o_proj.lora_A.weight": "layers.{}.attention.wo.lora_a.weight", + "model.layers.{}.self_attn.o_proj.lora_B.weight": "layers.{}.attention.wo.lora_b.weight", + # LoRA adapter key mappings (MLP) + "model.layers.{}.mlp.gate_proj.lora_A.weight": "layers.{}.feed_forward.w1.lora_a.weight", + "model.layers.{}.mlp.gate_proj.lora_B.weight": "layers.{}.feed_forward.w1.lora_b.weight", + "model.layers.{}.mlp.up_proj.lora_A.weight": "layers.{}.feed_forward.w3.lora_a.weight", + "model.layers.{}.mlp.up_proj.lora_B.weight": "layers.{}.feed_forward.w3.lora_b.weight", + "model.layers.{}.mlp.down_proj.lora_A.weight": "layers.{}.feed_forward.w2.lora_a.weight", + "model.layers.{}.mlp.down_proj.lora_B.weight": "layers.{}.feed_forward.w2.lora_b.weight", + # LoRA adapter key mappings (LM Head) + "lm_head.lora_A.weight": "output.lora_a.weight", + "lm_head.lora_B.weight": "output.lora_b.weight", } # Build reverse mapping @@ -49,6 +68,19 @@ def __init__( self.enable_weight_tying = getattr(model_config, "tie_word_embeddings", False) + # Archon module names to HF PEFT module names for LoRA adapters + # Used when generating adapter_config.json + self.to_peft_module_map = { + "wq": "q_proj", + "wk": "k_proj", + "wv": "v_proj", + "wo": "o_proj", + "w1": "gate_proj", + "w2": "down_proj", + "w3": "up_proj", + "output": "lm_head", + } + def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: hf_state_dict = {} diff --git a/areal/utils/logging.py b/areal/utils/logging.py index 6a0884dc76..5cac4d1b60 100644 --- a/areal/utils/logging.py +++ b/areal/utils/logging.py @@ -88,6 +88,7 @@ "Saver": "blue", "AsyncCheckpoint": "blue", "ArchonCheckpoint": "blue", + "LoRACheckpoint": "blue", # Platforms - cyan "Platform": "light_cyan", "PlatformInit": "light_cyan", diff --git a/tests/experimental/archon/test_archon_lora_checkpoint.py b/tests/experimental/archon/test_archon_lora_checkpoint.py new file mode 100644 index 0000000000..3d76e77ce7 --- /dev/null +++ b/tests/experimental/archon/test_archon_lora_checkpoint.py @@ -0,0 +1,259 @@ +"""Tests for LoRA adapter checkpoint I/O and PEFT format conversion.""" + +from __future__ import annotations + +import json +import tempfile +from pathlib import Path +from unittest.mock import Mock + +import pytest +import torch +from torch import nn + +from areal.experimental.models.archon.lora.adapter import get_adapter_params +from areal.experimental.models.archon.lora.lora_linear import LoRALinear +from areal.experimental.models.archon.qwen2.model.state_dict_adapter import ( + Qwen2StateDictAdapter, +) + +# Try to import PEFT for compatibility tests +try: + import peft # noqa: F401 + + PEFT_AVAILABLE = True +except ImportError: + PEFT_AVAILABLE = False + + +class TestStateDictAdapterLoRAKeys: + """Test LoRA key conversion in Qwen2StateDictAdapter.""" + + def setup_method(self): + mock_config = Mock() + mock_config.tie_word_embeddings = False + self.adapter = Qwen2StateDictAdapter(mock_config) + + def test_qwen2_lora_key_conversion_attention(self): + """Test attention LoRA key conversion (wq, wk, wv, wo).""" + hf_key = "model.layers.0.self_attn.q_proj.lora_A.weight" + archon_key = self.adapter._convert_key_from_hf(hf_key) + assert archon_key == "layers.0.attention.wq.lora_a.weight" + + hf_key_back = self.adapter._convert_key_to_hf(archon_key) + assert hf_key_back == hf_key + + def test_qwen2_lora_key_conversion_mlp(self): + """Test MLP LoRA key conversion (w1, w2, w3).""" + hf_key = "model.layers.5.mlp.gate_proj.lora_A.weight" + archon_key = self.adapter._convert_key_from_hf(hf_key) + assert archon_key == "layers.5.feed_forward.w1.lora_a.weight" + + hf_key_back = self.adapter._convert_key_to_hf(archon_key) + assert hf_key_back == hf_key + + def test_qwen2_lora_key_conversion_lm_head(self): + """Test LM head LoRA key conversion.""" + hf_key = "lm_head.lora_A.weight" + archon_key = self.adapter._convert_key_from_hf(hf_key) + assert archon_key == "output.lora_a.weight" + + hf_key_back = self.adapter._convert_key_to_hf(archon_key) + assert hf_key_back == hf_key + + def test_qwen2_lora_case_conversion(self): + """Test lora_A/lora_B (HF) <-> lora_a/lora_b (Archon) case handling.""" + # lora_A -> lora_a + hf_key_a = "model.layers.0.self_attn.v_proj.lora_A.weight" + archon_key_a = self.adapter._convert_key_from_hf(hf_key_a) + assert "lora_a" in archon_key_a + + # lora_B -> lora_b + hf_key_b = "model.layers.0.self_attn.v_proj.lora_B.weight" + archon_key_b = self.adapter._convert_key_from_hf(hf_key_b) + assert "lora_b" in archon_key_b + + def test_qwen2_all_16_lora_mappings(self): + """Verify all 16 LoRA key patterns convert correctly.""" + lora_mappings = [ + ("model.layers.{}.self_attn.q_proj.lora_A.weight", "layers.{}.attention.wq.lora_a.weight"), + ("model.layers.{}.self_attn.q_proj.lora_B.weight", "layers.{}.attention.wq.lora_b.weight"), + ("model.layers.{}.self_attn.k_proj.lora_A.weight", "layers.{}.attention.wk.lora_a.weight"), + ("model.layers.{}.self_attn.k_proj.lora_B.weight", "layers.{}.attention.wk.lora_b.weight"), + ("model.layers.{}.self_attn.v_proj.lora_A.weight", "layers.{}.attention.wv.lora_a.weight"), + ("model.layers.{}.self_attn.v_proj.lora_B.weight", "layers.{}.attention.wv.lora_b.weight"), + ("model.layers.{}.self_attn.o_proj.lora_A.weight", "layers.{}.attention.wo.lora_a.weight"), + ("model.layers.{}.self_attn.o_proj.lora_B.weight", "layers.{}.attention.wo.lora_b.weight"), + ("model.layers.{}.mlp.gate_proj.lora_A.weight", "layers.{}.feed_forward.w1.lora_a.weight"), + ("model.layers.{}.mlp.gate_proj.lora_B.weight", "layers.{}.feed_forward.w1.lora_b.weight"), + ("model.layers.{}.mlp.up_proj.lora_A.weight", "layers.{}.feed_forward.w3.lora_a.weight"), + ("model.layers.{}.mlp.up_proj.lora_B.weight", "layers.{}.feed_forward.w3.lora_b.weight"), + ("model.layers.{}.mlp.down_proj.lora_A.weight", "layers.{}.feed_forward.w2.lora_a.weight"), + ("model.layers.{}.mlp.down_proj.lora_B.weight", "layers.{}.feed_forward.w2.lora_b.weight"), + ("lm_head.lora_A.weight", "output.lora_a.weight"), + ("lm_head.lora_B.weight", "output.lora_b.weight"), + ] + + for hf_pattern, archon_pattern in lora_mappings: + # Substitute layer index + hf_key = hf_pattern.replace("{}", "3") + archon_key = archon_pattern.replace("{}", "3") + + converted = self.adapter._convert_key_from_hf(hf_key) + assert converted == archon_key, f"from_hf failed: {hf_key} -> {converted}" + + back = self.adapter._convert_key_to_hf(archon_key) + assert back == hf_key, f"to_hf failed: {archon_key} -> {back}" + + def test_to_peft_module_map(self): + """Test Archon -> PEFT module name mapping.""" + assert self.adapter.to_peft_module_map["wq"] == "q_proj" + assert self.adapter.to_peft_module_map["wk"] == "k_proj" + assert self.adapter.to_peft_module_map["wv"] == "v_proj" + assert self.adapter.to_peft_module_map["wo"] == "o_proj" + assert self.adapter.to_peft_module_map["w1"] == "gate_proj" + assert self.adapter.to_peft_module_map["w2"] == "down_proj" + assert self.adapter.to_peft_module_map["w3"] == "up_proj" + assert self.adapter.to_peft_module_map["output"] == "lm_head" + + +class TestPEFTAdapterConfig: + """Test PEFT adapter config generation.""" + + def setup_method(self): + mock_config = Mock() + mock_config.tie_word_embeddings = False + self.adapter = Qwen2StateDictAdapter(mock_config) + + def test_create_peft_adapter_config(self): + """Test adapter_config.json generation.""" + from dataclasses import dataclass + + @dataclass + class LoRAConfig: + rank: int + alpha: float + target_modules: list[str] + + lora_cfg = LoRAConfig(rank=8, alpha=16.0, target_modules=["wq", "wv"]) + config = self.adapter.create_peft_adapter_config(lora_cfg) + + assert config["peft_type"] == "LORA" + assert config["r"] == 8 + assert config["lora_alpha"] == 16 + assert config["task_type"] == "CAUSAL_LM" + assert "q_proj" in config["target_modules"] + assert "v_proj" in config["target_modules"] + assert config["bias"] == "none" + + def test_create_peft_adapter_config_with_base_model(self): + """Test adapter config with base model path.""" + from dataclasses import dataclass + + @dataclass + class LoRAConfig: + rank: int + alpha: float + target_modules: list[str] + + lora_cfg = LoRAConfig(rank=16, alpha=32.0, target_modules=["wq"]) + config = self.adapter.create_peft_adapter_config( + lora_cfg, base_model_path="Qwen/Qwen2-0.5B" + ) + + assert config["base_model_name_or_path"] == "Qwen/Qwen2-0.5B" + assert config["r"] == 16 + assert config["lora_alpha"] == 32 + + +class TestLoRAAdapterCheckpointDetection: + """Test is_lora_adapter_checkpoint function.""" + + def test_detects_valid_adapter(self, tmp_path): + """Test detection of valid PEFT adapter checkpoint.""" + from areal.experimental.engine.archon_lora_checkpoint import ( + is_lora_adapter_checkpoint, + ) + + config = {"peft_type": "LORA", "r": 8, "lora_alpha": 16} + config_path = tmp_path / "adapter_config.json" + with open(config_path, "w") as f: + json.dump(config, f) + + assert is_lora_adapter_checkpoint(str(tmp_path)) + + def test_rejects_missing_config(self, tmp_path): + """Test rejection when no adapter_config.json exists.""" + from areal.experimental.engine.archon_lora_checkpoint import ( + is_lora_adapter_checkpoint, + ) + + assert not is_lora_adapter_checkpoint(str(tmp_path)) + + def test_rejects_non_lora_config(self, tmp_path): + """Test rejection of non-LoRA adapter type.""" + from areal.experimental.engine.archon_lora_checkpoint import ( + is_lora_adapter_checkpoint, + ) + + config = {"peft_type": "PREFIX_TUNING"} + config_path = tmp_path / "adapter_config.json" + with open(config_path, "w") as f: + json.dump(config, f) + + assert not is_lora_adapter_checkpoint(str(tmp_path)) + + def test_handles_invalid_json(self, tmp_path): + """Test handling of invalid JSON config file.""" + from areal.experimental.engine.archon_lora_checkpoint import ( + is_lora_adapter_checkpoint, + ) + + config_path = tmp_path / "adapter_config.json" + config_path.write_text("not valid json {{{") + + assert not is_lora_adapter_checkpoint(str(tmp_path)) + + +class TestStateDictRoundTrip: + """Test state dict round-trip conversion with LoRA keys.""" + + def setup_method(self): + mock_config = Mock() + mock_config.tie_word_embeddings = False + self.adapter = Qwen2StateDictAdapter(mock_config) + + def test_lora_state_dict_roundtrip(self): + """Test that LoRA keys survive HF -> Archon -> HF round-trip.""" + hf_state = { + "model.layers.0.self_attn.q_proj.lora_A.weight": torch.randn(8, 64), + "model.layers.0.self_attn.q_proj.lora_B.weight": torch.randn(64, 8), + "model.layers.0.self_attn.v_proj.lora_A.weight": torch.randn(8, 64), + "model.layers.0.self_attn.v_proj.lora_B.weight": torch.randn(64, 8), + } + + # HF -> Archon + archon_state = self.adapter.from_hf(hf_state) + assert "layers.0.attention.wq.lora_a.weight" in archon_state + assert "layers.0.attention.wv.lora_b.weight" in archon_state + + # Archon -> HF + hf_state_back = self.adapter.to_hf(archon_state) + + assert set(hf_state.keys()) == set(hf_state_back.keys()) + for key in hf_state: + assert torch.allclose(hf_state[key], hf_state_back[key]) + + def test_mixed_base_and_lora_roundtrip(self): + """Test round-trip with both base and LoRA keys.""" + hf_state = { + "model.layers.0.self_attn.q_proj.weight": torch.randn(64, 64), + "model.layers.0.self_attn.q_proj.lora_A.weight": torch.randn(8, 64), + "model.layers.0.self_attn.q_proj.lora_B.weight": torch.randn(64, 8), + "model.norm.weight": torch.randn(64), + } + + archon_state = self.adapter.from_hf(hf_state) + hf_state_back = self.adapter.to_hf(archon_state) + + assert set(hf_state.keys()) == set(hf_state_back.keys()) diff --git a/tests/experimental/archon/test_lora_linear.py b/tests/experimental/archon/test_lora_linear.py new file mode 100644 index 0000000000..0cf7988144 --- /dev/null +++ b/tests/experimental/archon/test_lora_linear.py @@ -0,0 +1,528 @@ +"""Unit tests for LoRALinear module and adapter utilities.""" + +import pytest +import torch +import torch.nn as nn + +from areal.experimental.models.archon.lora.adapter import ( + AdapterModule, + disable_adapter, + enable_adapter, + get_adapter_params, + get_adapter_state_dict, + set_trainable_params, +) +from areal.experimental.models.archon.lora.lora_linear import LoRALinear + +# Try to import PEFT's LoRA Linear module for comparison tests +try: + from peft.tuners.lora import Linear as PEFTLoRALinear + + PEFT_AVAILABLE = True +except ImportError: + PEFT_AVAILABLE = False + + +class TestLoRALinear: + """Test LoRALinear module functionality.""" + + def test_initialization(self): + """Test that LoRALinear is properly initialized with zero LoRA contribution.""" + torch.manual_seed(42) + + lora_linear = LoRALinear(in_dim=64, out_dim=32, rank=8, alpha=16.0) + + # lora_b initialized to zeros + assert torch.allclose( + lora_linear.lora_b.weight, torch.zeros_like(lora_linear.lora_b.weight) + ), "lora_b should be initialized to zeros" + + # Initial forward matches base-only output + x = torch.randn(2, 10, 64) + with torch.no_grad(): + out_with_lora = lora_linear(x) + base_out = torch.nn.functional.linear( + x, lora_linear.weight, lora_linear.bias + ) + assert torch.allclose(out_with_lora, base_out, atol=1e-6), ( + "Initial output should match base output (zero LoRA contribution)" + ) + + def test_forward_pass(self): + """Test that forward pass correctly implements LoRA computation.""" + torch.manual_seed(42) + + in_dim, out_dim, rank = 64, 32, 8 + alpha = 16.0 + lora_linear = LoRALinear( + in_dim=in_dim, out_dim=out_dim, rank=rank, alpha=alpha, dropout=0.0 + ) + + with torch.no_grad(): + lora_linear.weight.fill_(0.1) + lora_linear.lora_a.weight.fill_(0.2) + lora_linear.lora_b.weight.fill_(0.3) + + x = torch.randn(2, 10, in_dim) + output = lora_linear(x) + + # Manual calculation + base_out = torch.nn.functional.linear(x, lora_linear.weight, lora_linear.bias) + lora_a_out = torch.nn.functional.linear(x, lora_linear.lora_a.weight) + lora_b_out = torch.nn.functional.linear(lora_a_out, lora_linear.lora_b.weight) + scaling = alpha / rank + expected_output = base_out + scaling * lora_b_out + + assert torch.allclose(output, expected_output, atol=1e-5), ( + "Forward pass computation mismatch" + ) + + def test_gradient_flow(self): + """Test that only LoRA parameters receive gradients.""" + torch.manual_seed(42) + + lora_linear = LoRALinear(in_dim=64, out_dim=32, rank=8, alpha=16.0) + + # Freeze base weights + lora_linear.weight.requires_grad_(False) + if lora_linear.bias is not None: + lora_linear.bias.requires_grad_(False) + + # Set non-zero lora_b so gradients flow to lora_a + with torch.no_grad(): + lora_linear.lora_b.weight.fill_(0.1) + + x = torch.randn(2, 10, 64, requires_grad=True) + output = lora_linear(x) + loss = output.sum() + loss.backward() + + assert lora_linear.weight.grad is None, "Base weight should not have gradients" + if lora_linear.bias is not None: + assert lora_linear.bias.grad is None, "Bias should not have gradients" + + assert lora_linear.lora_a.weight.grad is not None, ( + "lora_a should have gradients" + ) + assert lora_linear.lora_b.weight.grad is not None, ( + "lora_b should have gradients" + ) + assert not torch.allclose( + lora_linear.lora_a.weight.grad, + torch.zeros_like(lora_linear.lora_a.weight.grad), + ), "lora_a gradients should be non-zero" + + def test_from_linear(self): + """Test conversion from nn.Linear to LoRALinear.""" + torch.manual_seed(42) + + linear = nn.Linear(64, 32, bias=True) + with torch.no_grad(): + linear.weight.fill_(0.5) + linear.bias.fill_(0.1) + + lora_linear = LoRALinear.from_linear(linear, rank=8, alpha=16.0) + + assert torch.allclose(lora_linear.weight, linear.weight), ( + "Base weight should match original linear weight" + ) + assert torch.allclose(lora_linear.bias, linear.bias), ( + "Bias should match original bias" + ) + assert lora_linear.in_dim == linear.in_features + assert lora_linear.out_dim == linear.out_features + + # Output with disabled LoRA matches original + x = torch.randn(2, 10, 64) + with torch.no_grad(): + lora_linear.disabled = True + lora_out = lora_linear(x) + linear_out = linear(x) + + assert torch.allclose(lora_out, linear_out, atol=1e-6), ( + "Output with disabled LoRA should match original linear" + ) + + def test_adapter_params_protocol(self): + """Test AdapterModule protocol implementation.""" + lora_linear = LoRALinear(in_dim=64, out_dim=32, rank=8, alpha=16.0) + + assert isinstance(lora_linear, AdapterModule), ( + "LoRALinear should implement AdapterModule" + ) + + adapter_param_names = lora_linear.adapter_params() + assert adapter_param_names == [ + "lora_a.weight", + "lora_b.weight", + ], "adapter_params() should return LoRA parameter names" + + def test_disabled_flag(self): + """Test that disabled flag correctly disables LoRA contribution.""" + torch.manual_seed(42) + + lora_linear = LoRALinear( + in_dim=64, out_dim=32, rank=8, alpha=16.0, dropout=0.0 + ) + + with torch.no_grad(): + lora_linear.lora_b.weight.fill_(0.1) + + x = torch.randn(2, 10, 64) + + with torch.no_grad(): + lora_linear.disabled = False + out_enabled = lora_linear(x) + + lora_linear.disabled = True + out_disabled = lora_linear(x) + + base_out = torch.nn.functional.linear( + x, lora_linear.weight, lora_linear.bias + ) + + assert torch.allclose(out_disabled, base_out, atol=1e-6), ( + "Disabled LoRA should match base output" + ) + assert not torch.allclose(out_enabled, base_out, atol=1e-5), ( + "Enabled LoRA should differ from base output" + ) + + def test_repr(self): + """Test __repr__ output.""" + lora_linear = LoRALinear( + in_dim=64, out_dim=32, rank=8, alpha=16.0, dropout=0.1, use_bias=True + ) + repr_str = repr(lora_linear) + assert "LoRALinear" in repr_str + assert "in_dim=64" in repr_str + assert "out_dim=32" in repr_str + assert "rank=8" in repr_str + + +class TestAdapterUtilities: + """Test adapter utility functions.""" + + def test_get_adapter_params(self): + """Test extraction of adapter parameters from model.""" + + class SimpleModel(nn.Module): + def __init__(self): + super().__init__() + self.lora1 = LoRALinear(64, 32, rank=8, alpha=16.0) + self.lora2 = LoRALinear(32, 16, rank=8, alpha=16.0) + self.linear = nn.Linear(16, 8) + + def forward(self, x): + return self.linear(self.lora2(self.lora1(x))) + + model = SimpleModel() + adapter_params = get_adapter_params(model) + + expected_keys = { + "lora1.lora_a.weight", + "lora1.lora_b.weight", + "lora2.lora_a.weight", + "lora2.lora_b.weight", + } + assert set(adapter_params.keys()) == expected_keys, ( + "Should extract only LoRA parameters" + ) + + for param in adapter_params.values(): + assert isinstance(param, nn.Parameter), "Should return Parameter objects" + + def test_set_trainable_params(self): + """Test freezing/unfreezing parameters.""" + + class SimpleModel(nn.Module): + def __init__(self): + super().__init__() + self.lora = LoRALinear(64, 32, rank=8, alpha=16.0) + self.linear = nn.Linear(32, 16) + + model = SimpleModel() + adapter_param_names = set(get_adapter_params(model).keys()) + set_trainable_params(model, adapter_param_names) + + for name, param in model.named_parameters(): + if name in adapter_param_names: + assert param.requires_grad, f"{name} should be trainable" + else: + assert not param.requires_grad, f"{name} should be frozen" + + def test_get_adapter_state_dict(self): + """Test filtering state dict to adapter parameters only.""" + state_dict = { + "model.weight": torch.randn(10, 10), + "model.bias": torch.randn(10), + "model.lora_a.weight": torch.randn(8, 10), + "model.lora_b.weight": torch.randn(10, 8), + "other.lora_a.weight": torch.randn(8, 10), + "other.lora_b.weight": torch.randn(10, 8), + } + + adapter_state_dict = get_adapter_state_dict(state_dict) + + expected_keys = { + "model.lora_a.weight", + "model.lora_b.weight", + "other.lora_a.weight", + "other.lora_b.weight", + } + assert set(adapter_state_dict.keys()) == expected_keys, ( + "Should filter to LoRA params only" + ) + + def test_disable_enable_adapter(self): + """Test disabling/enabling adapters in model.""" + + class SimpleModel(nn.Module): + def __init__(self): + super().__init__() + self.lora1 = LoRALinear(64, 32, rank=8, alpha=16.0) + self.lora2 = LoRALinear(32, 16, rank=8, alpha=16.0) + + model = SimpleModel() + + assert not model.lora1.disabled + assert not model.lora2.disabled + + disable_adapter(model) + assert model.lora1.disabled + assert model.lora2.disabled + + enable_adapter(model) + assert not model.lora1.disabled + assert not model.lora2.disabled + + +class TestLoRALinearWithBias: + """Test LoRALinear with bias enabled.""" + + def test_bias_initialization(self): + """Test bias is properly initialized.""" + lora_linear = LoRALinear( + in_dim=64, out_dim=32, rank=8, alpha=16.0, use_bias=True + ) + + assert lora_linear.bias is not None, "Bias should exist" + assert torch.allclose(lora_linear.bias, torch.zeros_like(lora_linear.bias)), ( + "Bias should be initialized to zeros" + ) + + def test_bias_forward(self): + """Test forward pass with bias.""" + torch.manual_seed(42) + + lora_linear = LoRALinear( + in_dim=64, out_dim=32, rank=8, alpha=16.0, use_bias=True, dropout=0.0 + ) + + with torch.no_grad(): + lora_linear.weight.fill_(0.1) + lora_linear.bias.fill_(0.5) + lora_linear.lora_a.weight.fill_(0.2) + lora_linear.lora_b.weight.fill_(0.3) + + x = torch.randn(2, 10, 64) + output = lora_linear(x) + + base_out = torch.nn.functional.linear(x, lora_linear.weight, lora_linear.bias) + lora_a_out = torch.nn.functional.linear(x, lora_linear.lora_a.weight) + lora_b_out = torch.nn.functional.linear(lora_a_out, lora_linear.lora_b.weight) + scaling = lora_linear.alpha / lora_linear.rank + expected_output = base_out + scaling * lora_b_out + + assert torch.allclose(output, expected_output, atol=1e-5), ( + "Forward with bias should match manual calc" + ) + + +class TestLoRALinearDropout: + """Test LoRALinear with dropout.""" + + def test_dropout_training_mode(self): + """Test that dropout is active in training mode.""" + torch.manual_seed(42) + + lora_linear = LoRALinear( + in_dim=64, out_dim=32, rank=8, alpha=16.0, dropout=0.5 + ) + lora_linear.train() + + with torch.no_grad(): + lora_linear.lora_b.weight.fill_(0.1) + lora_linear.lora_a.weight.fill_(0.1) + + x = torch.randn(2, 10, 64) + + with torch.no_grad(): + out1 = lora_linear(x) + out2 = lora_linear(x) + + assert not torch.allclose(out1, out2, atol=1e-5), ( + "Dropout should cause different outputs" + ) + + def test_dropout_eval_mode(self): + """Test that dropout is disabled in eval mode.""" + torch.manual_seed(42) + + lora_linear = LoRALinear( + in_dim=64, out_dim=32, rank=8, alpha=16.0, dropout=0.5 + ) + lora_linear.eval() + + x = torch.randn(2, 10, 64) + + with torch.no_grad(): + out1 = lora_linear(x) + out2 = lora_linear(x) + + assert torch.allclose(out1, out2, atol=1e-6), ( + "Eval mode should have deterministic output" + ) + + +@pytest.mark.skipif(not PEFT_AVAILABLE, reason="PEFT not installed") +class TestPEFTCompatibility: + """Test compatibility with PEFT library implementation.""" + + def test_forward_pass_vs_peft(self): + """Compare forward pass output with PEFT's LoRA Linear module.""" + torch.manual_seed(42) + + in_dim, out_dim, rank = 128, 128, 16 + alpha = 32.0 + adapter_name = "default" + + base_linear = nn.Linear(in_dim, out_dim, bias=False) + + our_lora = LoRALinear.from_linear( + base_linear, rank=rank, alpha=alpha, dropout=0.0 + ) + + peft_lora = PEFTLoRALinear( + base_layer=nn.Linear(in_dim, out_dim, bias=False), + adapter_name=adapter_name, + r=rank, + lora_alpha=alpha, + lora_dropout=0.0, + init_lora_weights=True, + ) + peft_lora.base_layer.weight.data.copy_(base_linear.weight.data) + + peft_lora_a = peft_lora.lora_A[adapter_name].weight + peft_lora_b = peft_lora.lora_B[adapter_name].weight + + with torch.no_grad(): + our_lora.lora_a.weight.copy_(peft_lora_a) + our_lora.lora_b.weight.copy_(peft_lora_b) + + x = torch.randn(2, 10, in_dim) + + with torch.no_grad(): + our_output = our_lora(x) + peft_output = peft_lora(x) + + max_diff = (our_output - peft_output).abs().max().item() + assert torch.allclose(our_output, peft_output, atol=1e-5), ( + f"Output mismatch vs PEFT: max diff = {max_diff}" + ) + + def test_gradient_flow_vs_peft(self): + """Compare gradient flow with PEFT's implementation.""" + torch.manual_seed(42) + + in_dim, out_dim, rank = 64, 64, 8 + alpha = 16.0 + adapter_name = "default" + + base_linear = nn.Linear(in_dim, out_dim, bias=False) + + our_lora = LoRALinear.from_linear( + base_linear, rank=rank, alpha=alpha, dropout=0.0 + ) + our_lora.weight.requires_grad_(False) + + peft_lora = PEFTLoRALinear( + base_layer=nn.Linear(in_dim, out_dim, bias=False), + adapter_name=adapter_name, + r=rank, + lora_alpha=alpha, + lora_dropout=0.0, + init_lora_weights=True, + ) + peft_lora.base_layer.weight.data.copy_(base_linear.weight.data) + peft_lora.base_layer.weight.requires_grad_(False) + + peft_lora_a = peft_lora.lora_A[adapter_name].weight + peft_lora_b = peft_lora.lora_B[adapter_name].weight + + with torch.no_grad(): + our_lora.lora_a.weight.copy_(peft_lora_a) + our_lora.lora_b.weight.copy_(peft_lora_b) + our_lora.lora_b.weight.fill_(0.1) + peft_lora_b.fill_(0.1) + + x = torch.randn(2, 10, in_dim) + + our_output = our_lora(x) + our_output.sum().backward() + + peft_output = peft_lora(x) + peft_output.sum().backward() + + assert torch.allclose( + our_lora.lora_a.weight.grad, peft_lora_a.grad, atol=1e-4 + ), "lora_a gradient mismatch vs PEFT" + assert torch.allclose( + our_lora.lora_b.weight.grad, peft_lora_b.grad, atol=1e-4 + ), "lora_b gradient mismatch vs PEFT" + + def test_scaling_factor_vs_peft(self): + """Verify scaling factor matches PEFT's implementation.""" + rank = 16 + alpha = 32.0 + adapter_name = "default" + + our_lora = LoRALinear(in_dim=64, out_dim=64, rank=rank, alpha=alpha) + + peft_lora = PEFTLoRALinear( + base_layer=nn.Linear(64, 64, bias=False), + adapter_name=adapter_name, + r=rank, + lora_alpha=alpha, + lora_dropout=0.0, + init_lora_weights=True, + ) + + expected_scaling = alpha / rank + peft_scaling = peft_lora.scaling[adapter_name] + + assert our_lora.scaling == expected_scaling + assert our_lora.scaling == peft_scaling + + def test_initialization_vs_peft(self): + """Compare initialization strategy with PEFT.""" + torch.manual_seed(42) + + our_lora = LoRALinear(in_dim=64, out_dim=64, rank=16, alpha=32.0) + + # lora_b zeros, lora_a non-zero + assert torch.allclose( + our_lora.lora_b.weight, torch.zeros_like(our_lora.lora_b.weight) + ), "lora_b should be zeros (PEFT convention)" + + assert not torch.allclose( + our_lora.lora_a.weight, torch.zeros_like(our_lora.lora_a.weight) + ), "lora_a should be non-zero (kaiming_uniform)" + + x = torch.randn(2, 10, 64) + with torch.no_grad(): + output = our_lora(x) + base_output = torch.nn.functional.linear(x, our_lora.weight, our_lora.bias) + assert torch.allclose(output, base_output, atol=1e-6), ( + "Initial output should match base (PEFT convention)" + )