From a4a532a2ca1f9f6e0cf9fce3fceac2da1854cb9f Mon Sep 17 00:00:00 2001 From: MikaStars39 Date: Tue, 17 Mar 2026 01:28:59 +0800 Subject: [PATCH] feat(archon): add LoRA infrastructure with FSDP2-DTensor deadlock fix --- areal/engine/fsdp_utils/grad.py | 21 +- .../experimental/engine/archon_checkpoint.py | 16 +- areal/experimental/engine/archon_engine.py | 165 +++++- .../engine/archon_lora_checkpoint.py | 257 +++++++++ areal/experimental/engine/archon_runner.py | 4 +- .../experimental/engine/archon_weight_sync.py | 23 +- areal/experimental/models/archon/base.py | 40 ++ .../models/archon/lora/__init__.py | 27 + .../models/archon/lora/adapter.py | 116 ++++ .../models/archon/lora/lora_linear.py | 327 +++++++++++ .../models/archon/qwen2/infra/parallelize.py | 6 + .../archon/qwen2/model/state_dict_adapter.py | 32 ++ .../models/archon/qwen3/infra/parallelize.py | 6 + areal/infra/remote_inf_engine.py | 35 +- areal/trainer/ppo/actor.py | 3 +- areal/trainer/rl_trainer.py | 1 - areal/utils/logging.py | 1 + .../archon/test_archon_engine_lora.py | 226 ++++++++ .../archon/test_archon_lora_checkpoint.py | 259 +++++++++ tests/experimental/archon/test_lora_linear.py | 528 ++++++++++++++++++ 20 files changed, 2079 insertions(+), 14 deletions(-) create mode 100644 areal/experimental/engine/archon_lora_checkpoint.py create mode 100644 areal/experimental/models/archon/lora/__init__.py create mode 100644 areal/experimental/models/archon/lora/adapter.py create mode 100644 areal/experimental/models/archon/lora/lora_linear.py create mode 100644 tests/experimental/archon/test_archon_engine_lora.py create mode 100644 tests/experimental/archon/test_archon_lora_checkpoint.py create mode 100644 tests/experimental/archon/test_lora_linear.py diff --git a/areal/engine/fsdp_utils/grad.py b/areal/engine/fsdp_utils/grad.py index 48fd5c6e7c..9c5b0eb37c 100644 --- a/areal/engine/fsdp_utils/grad.py +++ b/areal/engine/fsdp_utils/grad.py @@ -99,11 +99,26 @@ def get_grad_norm_fp32( norm_type = 
float(norm_type) total_norm = 0.0 - if not grads_for_norm: - return 0.0 - device = current_platform.current_device() + if not grads_for_norm: + # Still participate in all_reduce with zero contribution so that + # ranks with grads don't hang waiting for this rank. + total_norm_cuda = torch.tensor(0.0, dtype=torch.float, device=device) + reduce_op = ( + dist.ReduceOp.MAX if norm_type == torch.inf else dist.ReduceOp.SUM + ) + if data_parallel_group: + dist.all_reduce(total_norm_cuda, op=reduce_op, group=data_parallel_group) + if model_parallel_group is not None: + dist.all_reduce( + total_norm_cuda, op=reduce_op, group=model_parallel_group + ) + total_norm = float(total_norm_cuda.item()) + if norm_type != torch.inf and total_norm > 0: + total_norm = total_norm ** (1.0 / norm_type) + return total_norm + if norm_type == torch.inf: norms = [grad.abs().max() for grad in grads_for_norm] total_norm = torch.max(torch.stack(norms)) if norms else 0.0 diff --git a/areal/experimental/engine/archon_checkpoint.py b/areal/experimental/engine/archon_checkpoint.py index d55e83b93d..9975f22df6 100644 --- a/areal/experimental/engine/archon_checkpoint.py +++ b/areal/experimental/engine/archon_checkpoint.py @@ -337,6 +337,19 @@ def load_model_from_hf(engine: ArchonEngine, path: str) -> None: # Convert to HF format to match checkpoint keys hf_state_dict = engine.state_dict_adapter.to_hf(state_dict) + # LoRA adapter parameters don't exist in the base HF checkpoint. + # Strip them before calling dcp.load() so it won't raise on missing keys. + lora_archon_keys: set[str] = set() + if engine.lora_config is not None: + lora_hf_keys = { + k for k in hf_state_dict if ".lora_A." in k or ".lora_B." in k + } + lora_archon_keys = { + k for k in state_dict if ".lora_a." in k or ".lora_b." 
in k + } + for k in lora_hf_keys: + del hf_state_dict[k] + # PP mode + weight tying fix: last stage needs embed_tokens weight for output layer # When tie_word_embeddings=True, HF checkpoint only stores embed_tokens.weight, # not lm_head.weight. In PP mode, last stage has output.weight but no tok_embeddings, @@ -377,10 +390,11 @@ def load_model_from_hf(engine: ArchonEngine, path: str) -> None: # Filter known expected missing keys expected_missing = set() for key in list(missing_keys): - # rotary_emb is computed at runtime, not stored in checkpoint if "rotary_emb" in key: expected_missing.add(key) missing_keys -= expected_missing + # LoRA adapter keys are initialised separately, not loaded from base ckpt + missing_keys -= lora_archon_keys if dist.get_rank() == 0: if missing_keys: diff --git a/areal/experimental/engine/archon_engine.py b/areal/experimental/engine/archon_engine.py index d7df9879e1..34c76b2db9 100644 --- a/areal/experimental/engine/archon_engine.py +++ b/areal/experimental/engine/archon_engine.py @@ -190,6 +190,25 @@ def __init__(self, config: TrainEngineConfig): self._initialized = False self.is_offload = False + # LoRA Configuration (extract from config if enabled) + self.lora_config = None + if hasattr(config, "use_lora") and config.use_lora: + from dataclasses import dataclass + + @dataclass + class LoRAConfig: + enabled: bool + rank: int + alpha: float + target_modules: list[str] + + self.lora_config = LoRAConfig( + enabled=True, + rank=config.lora_rank, + alpha=float(config.lora_alpha), + target_modules=config.target_modules if config.target_modules else [], + ) + def create_process_group( self, parallel_strategy: ParallelStrategy | None = None, @@ -319,6 +338,8 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): ) self._materialize_and_load_weights() + if self.lora_config is not None: + self._freeze_non_lora_params() self._create_optimizer(ft_spec) self.runner = create_runner( @@ -486,6 +507,16 @@ def 
process_output( self.forward_backward_batch(mb_list, process_output, forward_only=False) + if self.lora_config is not None: + from areal.experimental.models.archon.lora.lora_linear import ( + sync_lora_grads, + ) + sync_lora_grads( + self.model, + tp_group=self._tp_group, + dp_group=self.data_parallel_group, + ) + return self.optimizer_step() @torch.no_grad() @@ -659,7 +690,19 @@ def update_weights(self, meta: WeightUpdateMeta): ) def save(self, meta: SaveLoadMeta): - """Save model in HuggingFace or DCP format.""" + """Save model in HuggingFace or DCP format. + + When LoRA is enabled, only the adapter weights are saved in PEFT format. + When LoRA is disabled, the full model is saved. + """ + if self.lora_config is not None: + from areal.experimental.engine.archon_lora_checkpoint import ( + save_lora_adapter, + ) + + save_lora_adapter(self, meta.path, meta.base_model_path) + return + if meta.weight_format == "hf": save_model_to_hf(self, meta.path, meta.tokenizer, meta.processor) elif meta.weight_format == "dcp": @@ -671,7 +714,20 @@ def save(self, meta: SaveLoadMeta): save_optimizer_state(self, meta.path) def load(self, meta: SaveLoadMeta): - """Load model from HuggingFace or DCP format.""" + """Load model from HuggingFace or DCP format. + + When LoRA is enabled and the checkpoint is a PEFT adapter, + only adapter weights are loaded. 
+ """ + from areal.experimental.engine.archon_lora_checkpoint import ( + is_lora_adapter_checkpoint, + load_lora_adapter, + ) + + if self.lora_config is not None and is_lora_adapter_checkpoint(meta.path): + load_lora_adapter(self, meta.path) + return + if meta.weight_format == "hf": load_model_from_hf(self, meta.path) elif meta.weight_format == "dcp": @@ -790,6 +846,7 @@ def _apply_pipeline_parallelism( reshard_after_forward_policy=self.config.archon.reshard_after_forward_policy, ac_config=ac_config, enable_compile=enable_compile, + apply_lora_fn=self._apply_lora if self.lora_config is not None else None, ) # Delete original model to free memory @@ -828,6 +885,7 @@ def _apply_parallelism( reshard_after_forward_policy=self.config.archon.reshard_after_forward_policy, ac_config=ac_config, enable_compile=enable_compile, + apply_lora_fn=self._apply_lora if self.lora_config is not None else None, ) self.model_parts = [self.model] @@ -929,8 +987,109 @@ def _create_state_dict_adapter(self) -> BaseStateDictAdapter | None: self.model_config, hf_assets_path=self.config.path ) + def _apply_lora(self, module: nn.Module | None = None) -> None: + from areal.experimental.models.archon.lora import ( + LoRALinear, + get_adapter_params, + ) + + assert self.lora_config is not None + module = self.model if module is None else module + + target_modules = set(self.lora_config.target_modules) + apply_to_all_linears = "all-linear" in target_modules + peft_name_map = ( + self.state_dict_adapter.to_peft_module_map + if self.state_dict_adapter is not None + else {} + ) + replaced_modules: list[str] = [] + + def replace_linear_modules(parent_module: nn.Module, prefix: str = "") -> None: + for child_name, child in list(parent_module.named_children()): + child_prefix = f"{prefix}.{child_name}" if prefix else child_name + if isinstance(child, nn.Linear): + peft_name = peft_name_map.get(child_name) + if ( + not apply_to_all_linears + and child_name not in target_modules + and peft_name not in 
target_modules + ): + continue + + lora_mod = LoRALinear.from_linear( + child, + rank=self.lora_config.rank, + alpha=self.lora_config.alpha, + ) + lora_mod._debug_name = child_prefix + setattr(parent_module, child_name, lora_mod) + replaced_modules.append(child_prefix) + continue + + replace_linear_modules(child, child_prefix) + + replace_linear_modules(module) + + adapter_params = get_adapter_params(module) + + if replaced_modules: + self.logger.info( + f"Applied LoRA to {len(replaced_modules)} linear modules and created " + f"{len(adapter_params)} adapter parameters" + ) + + def _freeze_non_lora_params(self) -> None: + from areal.experimental.models.archon.lora import ( + LoRALinear, + get_adapter_params, + set_trainable_params, + ) + + adapter_param_count = 0 + for model in self.model_parts: + # LoRA weights are plain tensors created on meta device during + # model structure creation. FSDP2 only materialises + # nn.Parameters, so we must move LoRA tensors ourselves. + for module in model.modules(): + if isinstance(module, LoRALinear): + module.materialize_lora(self.device) + + adapter_params = get_adapter_params(model) + if not adapter_params: + continue + + # Re-initialize lora_a (kaiming) and lora_b (zeros) so the + # initial LoRA contribution is exactly zero. + with torch.no_grad(): + for name, tensor in adapter_params.items(): + if "lora_b" in name: + nn.init.zeros_(tensor) + elif "lora_a" in name: + nn.init.kaiming_uniform_(tensor, a=math.sqrt(5)) + + adapter_param_count += len(adapter_params) + set_trainable_params(model, set(adapter_params.keys())) + + if adapter_param_count == 0: + raise RuntimeError( + "LoRA is enabled but no adapter parameters were found after weight loading." 
+ ) + + self.logger.info( + f"Froze base weights and kept {adapter_param_count} adapter parameters trainable" + ) + def _get_all_parameters(self) -> list[nn.Parameter]: - return [p for m in self.model_parts for p in m.parameters()] + params = [p for m in self.model_parts for p in m.parameters()] + if self.lora_config is not None: + from areal.experimental.models.archon.lora import LoRALinear + + for m in self.model_parts: + for module in m.modules(): + if isinstance(module, LoRALinear): + params.extend(module.lora_parameters()) + return params def _get_model_name_parameters(self) -> Iterator[tuple[str, nn.Parameter]]: for m in self.model_parts: diff --git a/areal/experimental/engine/archon_lora_checkpoint.py b/areal/experimental/engine/archon_lora_checkpoint.py new file mode 100644 index 0000000000..414541f68c --- /dev/null +++ b/areal/experimental/engine/archon_lora_checkpoint.py @@ -0,0 +1,257 @@ +"""LoRA adapter checkpoint I/O in PEFT format. + +This module provides functions to save and load LoRA adapters in PEFT-compatible +format for HuggingFace ecosystem interoperability. + +PEFT checkpoint structure: + adapter_checkpoint/ + ├── adapter_model.safetensors # LoRA weights only + └── adapter_config.json # PEFT configuration + +Reference: peft/src/peft/utils/save_and_load.py +""" + +from __future__ import annotations + +import json +import os +from pathlib import Path +from typing import TYPE_CHECKING + +import torch +import torch.distributed as dist +from safetensors.torch import load_file, save_file + +from areal.experimental.models.archon.lora.adapter import get_adapter_params +from areal.utils import logging + +if TYPE_CHECKING: + from areal.experimental.engine.archon_engine import ArchonEngine + +logger = logging.getLogger("LoRACheckpoint") + + +def save_lora_adapter( + engine: "ArchonEngine", + path: str, + base_model_path: str | None = None, +) -> None: + """Save LoRA adapter in PEFT format. 
+ + Creates two files: + - adapter_model.safetensors: LoRA weights (lora_a, lora_b) + - adapter_config.json: PEFT configuration + + Args: + engine: ArchonEngine instance with LoRA-enabled model + path: Directory path to save adapter checkpoint + base_model_path: Optional path to base model (for config reference) + + Raises: + RuntimeError: If LoRA is not enabled on engine + """ + if engine.lora_config is None: + raise RuntimeError("Cannot save LoRA adapter: LoRA not enabled on engine") + + if dist.is_initialized(): + rank = dist.get_rank() + else: + rank = 0 + + if rank == 0: + os.makedirs(path, exist_ok=True) + logger.info(f"Saving LoRA adapter to {path}") + + # Extract adapter parameters from model + adapter_params = get_adapter_params(engine.model) + + if not adapter_params: + logger.warning("No adapter parameters found in model") + if rank == 0: + logger.warning("Creating empty adapter checkpoint") + + # After FSDP2, adapter params are DTensors (sharded). Gather them + # into plain CPU tensors so safetensors can serialise them. + from torch.distributed.tensor import DTensor + + archon_state = {} + for k, v in adapter_params.items(): + if isinstance(v, DTensor): + v = v.full_tensor() + # AC and torch.compile insert wrapper prefixes into the FQN + # (e.g. "._checkpoint_wrapped_module", "._orig_mod"). + # Strip them so the key converter can recognise the names. 
+ k = k.replace("._checkpoint_wrapped_module", "") + k = k.replace("._orig_mod", "") + archon_state[k] = v.detach().cpu().clone() + hf_state = engine.state_dict_adapter.to_hf(archon_state) + + # Add PEFT prefix: base_model.model.{key} + peft_state = {f"base_model.model.{k}": v for k, v in hf_state.items()} + + # Save weights (only rank 0) + if rank == 0: + weights_path = os.path.join(path, "adapter_model.safetensors") + save_file(peft_state, weights_path) + logger.info(f"Saved {len(peft_state)} adapter tensors to {weights_path}") + + # Determine target modules from actual adapter parameters + target_modules = set() + for key in adapter_params: + parts = key.split(".") + for i, part in enumerate(parts): + is_lora = part in ("lora_a", "lora_b") or part.startswith("_lora_") + if is_lora and i > 0: + module_name = parts[i - 1] + target_modules.add(module_name) + break + + # Create config copy with actual target modules + from dataclasses import replace + + lora_config_for_save = replace( + engine.lora_config, target_modules=sorted(target_modules) + ) + + # Generate adapter config using model-specific state dict adapter + adapter_config = engine.state_dict_adapter.create_peft_adapter_config( + lora_config=lora_config_for_save, + base_model_path=base_model_path, + ) + + config_path = os.path.join(path, "adapter_config.json") + with open(config_path, "w") as f: + json.dump(adapter_config, f, indent=2) + logger.info(f"Saved adapter config to {config_path}") + + # Synchronize all ranks (use gloo-based cpu_group, matching save_model_to_hf + # and update_weights_from_disk; avoids potential NCCL barrier issues) + if dist.is_initialized(): + dist.barrier(group=engine.cpu_group) + + +def load_lora_adapter( + engine: "ArchonEngine", + path: str, + strict: bool = True, +) -> None: + """Load LoRA adapter from PEFT format checkpoint. 
+ + Args: + engine: ArchonEngine instance with LoRA-enabled model + path: Directory path containing adapter checkpoint + strict: If True, raise error on missing/unexpected keys + + Raises: + RuntimeError: If LoRA is not enabled on engine + FileNotFoundError: If adapter checkpoint files not found + ValueError: If strict=True and keys don't match + """ + if engine.lora_config is None: + raise RuntimeError("Cannot load LoRA adapter: LoRA not enabled on engine") + + if dist.is_initialized(): + rank = dist.get_rank() + else: + rank = 0 + + if rank == 0: + logger.info(f"Loading LoRA adapter from {path}") + + # Load adapter weights + weights_path = os.path.join(path, "adapter_model.safetensors") + if not os.path.exists(weights_path): + # Fallback to .bin format + weights_path = os.path.join(path, "adapter_model.bin") + if not os.path.exists(weights_path): + raise FileNotFoundError( + f"Adapter weights not found at {path}. " + "Expected adapter_model.safetensors or adapter_model.bin" + ) + peft_state = torch.load(weights_path, map_location="cpu", weights_only=True) + else: + peft_state = load_file(weights_path) + + if rank == 0: + logger.info(f"Loaded {len(peft_state)} adapter tensors from {weights_path}") + + # Strip PEFT prefix: base_model.model.{key} -> {key} + hf_state = {} + for key, value in peft_state.items(): + if key.startswith("base_model.model."): + hf_key = key.replace("base_model.model.", "", 1) + hf_state[hf_key] = value + else: + hf_state[key] = value + + # Convert from HF format to Archon format + archon_state = engine.state_dict_adapter.from_hf(hf_state) + + # Get expected adapter keys from model + expected_adapter_params = get_adapter_params(engine.model) + expected_keys = set(expected_adapter_params.keys()) + loaded_keys = set(archon_state.keys()) + + missing_keys = expected_keys - loaded_keys + unexpected_keys = loaded_keys - expected_keys + + if missing_keys or unexpected_keys: + if strict: + error_msg = [] + if missing_keys: + 
error_msg.append(f"Missing keys: {sorted(missing_keys)[:5]}...") + if unexpected_keys: + error_msg.append(f"Unexpected keys: {sorted(unexpected_keys)[:5]}...") + raise ValueError( + "Adapter checkpoint keys don't match model. " + " ".join(error_msg) + ) + else: + if missing_keys and rank == 0: + logger.warning( + f"Missing {len(missing_keys)} adapter keys: " + f"{sorted(missing_keys)[:5]}..." + ) + if unexpected_keys and rank == 0: + logger.warning( + f"Unexpected {len(unexpected_keys)} adapter keys: " + f"{sorted(unexpected_keys)[:5]}..." + ) + + # Load adapter weights into model + loaded_count = 0 + for key, value in archon_state.items(): + if key in expected_adapter_params: + param = expected_adapter_params[key] + value = value.to(device=param.device, dtype=param.dtype) + param.data.copy_(value) + loaded_count += 1 + + if rank == 0: + logger.info(f"Loaded {loaded_count} adapter parameters into model") + + if dist.is_initialized() and hasattr(engine, "cpu_group"): + dist.barrier(group=engine.cpu_group) + elif dist.is_initialized(): + dist.barrier() + + +def is_lora_adapter_checkpoint(path: str) -> bool: + """Check if path contains a PEFT LoRA adapter checkpoint. 
+ + Args: + path: Directory path to check + + Returns: + True if path contains adapter_config.json with peft_type="LORA" + """ + config_path = Path(path) / "adapter_config.json" + + if not config_path.exists(): + return False + + try: + with open(config_path) as f: + config = json.load(f) + return config.get("peft_type") == "LORA" + except (OSError, json.JSONDecodeError): + return False diff --git a/areal/experimental/engine/archon_runner.py b/areal/experimental/engine/archon_runner.py index f50ce0a264..d25d9f3d36 100644 --- a/areal/experimental/engine/archon_runner.py +++ b/areal/experimental/engine/archon_runner.py @@ -72,8 +72,9 @@ def run( forward_only: bool, ) -> list[torch.Tensor | dict[int, torch.Tensor]]: results: list[torch.Tensor | dict[int, torch.Tensor]] = [] + total_mbs = len(mb_list) - for mb_item in mb_list: + for mb_idx, mb_item in enumerate(mb_list): inputs, ctx = self.prepare_inputs_fn(mb_item) tree_attn_meta = None @@ -108,7 +109,6 @@ def run( if result is not None: if forward_only: - # Result can be a tensor or dict (for tree training) if isinstance(result, dict): results.append({k: v.detach() for k, v in result.items()}) else: diff --git a/areal/experimental/engine/archon_weight_sync.py b/areal/experimental/engine/archon_weight_sync.py index 7470445756..e70c0704a0 100644 --- a/areal/experimental/engine/archon_weight_sync.py +++ b/areal/experimental/engine/archon_weight_sync.py @@ -25,6 +25,9 @@ from areal.experimental.engine.archon_engine import ArchonEngine +WEIGHT_UPDATE_READY_FILE = ".areal_weight_update_ready" + + class WeightSyncState: """State container for weight synchronization. 
@@ -220,16 +223,32 @@ def update_weights_from_disk( fut = engine.rollout_engine.update_weights_from_disk(meta) assert meta.path is not None - save_model_to_hf(engine, meta.path, engine.tokenizer, None) + if engine.lora_config is not None: + from areal.experimental.engine.archon_lora_checkpoint import save_lora_adapter + + save_lora_adapter( + engine, + meta.path, + meta.base_model_name or engine.config.path, + ) + else: + save_model_to_hf(engine, meta.path, engine.tokenizer, None) if dist.get_rank() == 0: + ready_path = os.path.join(meta.path, WEIGHT_UPDATE_READY_FILE) + ready_tmp_path = ready_path + ".tmp" + ready_timestamp = str(datetime.now().timestamp()) + with open(ready_tmp_path, "w") as f: + f.write(ready_timestamp) + os.replace(ready_tmp_path, ready_path) + update_name = names.update_weights_from_disk( engine.config.experiment_name, engine.config.trial_name, engine.get_version(), ) name_resolve.add( - update_name, str(datetime.now().timestamp()), keepalive_ttl=120 + update_name, ready_timestamp, keepalive_ttl=600 ) assert fut is not None diff --git a/areal/experimental/models/archon/base.py b/areal/experimental/models/archon/base.py index fd8978e203..d05f6eab5a 100644 --- a/areal/experimental/models/archon/base.py +++ b/areal/experimental/models/archon/base.py @@ -58,6 +58,9 @@ def __init__( self.to_hf_map: dict[str, str] = {} self.hf_assets_path = hf_assets_path self.fqn_to_index_mapping = None + # Model-specific mapping from Archon module names to PEFT module names + # Subclasses should define this mapping for LoRA adapter config generation + self.to_peft_module_map: dict[str, str] = {} if hf_assets_path: self._load_safetensors_index(hf_assets_path) @@ -115,6 +118,43 @@ def convert_single_to_hf( self, name: str, tensor: torch.Tensor ) -> list[tuple[str, torch.Tensor]]: ... + def create_peft_adapter_config( + self, lora_config: Any, base_model_path: str | None = None + ) -> dict: + """Generate adapter_config.json for PEFT format checkpoint. 
+ + Args: + lora_config: LoRA configuration object with rank, alpha, target_modules + base_model_path: Optional path to base model + + Returns: + Dictionary containing PEFT adapter configuration + """ + # Convert Archon module names to PEFT names using model-specific mapping + peft_target_modules = [ + self.to_peft_module_map.get(name, name) + for name in lora_config.target_modules + ] + + return { + "peft_type": "LORA", + "auto_mapping": None, + "base_model_name_or_path": base_model_path or "", + "revision": None, + "task_type": "CAUSAL_LM", + "inference_mode": False, + "r": lora_config.rank, + "lora_alpha": int(lora_config.alpha), + "lora_dropout": 0.0, + "target_modules": peft_target_modules, + "fan_in_fan_out": False, + "bias": "none", + "modules_to_save": None, + "init_lora_weights": True, + "layers_to_transform": None, + "layers_pattern": None, + } + class BaseArchonModel(nn.Module, ABC): """Base class for Archon models.""" diff --git a/areal/experimental/models/archon/lora/__init__.py b/areal/experimental/models/archon/lora/__init__.py new file mode 100644 index 0000000000..66eca61f50 --- /dev/null +++ b/areal/experimental/models/archon/lora/__init__.py @@ -0,0 +1,27 @@ +"""LoRA (Low-Rank Adaptation) infrastructure for Archon engine. + +Following torchtune's design patterns for FSDP2 compatibility. +This module provides custom LoRALinear implementation and utilities +for parameter-efficient fine-tuning of large language models. 
+""" + +from areal.experimental.models.archon.lora.adapter import ( + AdapterModule, + disable_adapter, + enable_adapter, + get_adapter_params, + get_adapter_state_dict, + set_trainable_params, +) +from areal.experimental.models.archon.lora.lora_linear import LoRALinear, sync_lora_grads + +__all__ = [ + "LoRALinear", + "AdapterModule", + "get_adapter_params", + "get_adapter_state_dict", + "set_trainable_params", + "disable_adapter", + "enable_adapter", + "sync_lora_grads", +] diff --git a/areal/experimental/models/archon/lora/adapter.py b/areal/experimental/models/archon/lora/adapter.py new file mode 100644 index 0000000000..b2f5af4f20 --- /dev/null +++ b/areal/experimental/models/archon/lora/adapter.py @@ -0,0 +1,116 @@ +"""AdapterModule protocol and utilities for LoRA parameter management. + +Reference: torchtune/torchtune/modules/peft/_utils.py +Provides utilities for extracting, filtering, and managing adapter parameters. +""" + +from typing import Protocol, runtime_checkable + +import torch +import torch.nn as nn + + +@runtime_checkable +class AdapterModule(Protocol): + """Protocol for modules that contain adapter parameters. + + Any module implementing this protocol should provide an adapter_params() + method that returns a list of parameter names (relative to the module) + that should be treated as trainable adapters. + """ + + def adapter_params(self) -> list[str]: + """Return list of adapter parameter names relative to this module. + + Returns: + List of parameter names (e.g., ["lora_a.weight", "lora_b.weight"]) + """ + ... + + +def get_adapter_params(model: nn.Module) -> dict[str, torch.Tensor]: + """Extract all adapter parameters from model using AdapterModule protocol. + + Walks through all modules in the model and collects adapter tensors. + Supports both ``nn.Parameter`` attributes (found via ``named_parameters``) + and plain tensor attributes stored via ``object.__setattr__`` (which are + invisible to ``nn.Module`` tracking and therefore to FSDP2). 
+ + Args: + model: Model to extract adapter parameters from + + Returns: + Dictionary mapping fully-qualified names to tensors + """ + adapter_params: dict[str, torch.Tensor] = {} + + for module_name, module in model.named_modules(): + if isinstance(module, AdapterModule): + for attr_name in module.adapter_params(): + tensor = getattr(module, attr_name, None) + if tensor is not None and isinstance(tensor, torch.Tensor): + full_key = ( + f"{module_name}.{attr_name}" if module_name else attr_name + ) + adapter_params[full_key] = tensor + + return adapter_params + + +def set_trainable_params(model: nn.Module, adapter_param_names: set[str]) -> None: + """Freeze all nn.Parameters except those in *adapter_param_names*. + + Plain-tensor LoRA weights (stored via ``object.__setattr__``) are not + affected by this function – they always keep ``requires_grad=True``. + + Args: + model: Model to configure + adapter_param_names: Set of fully-qualified parameter names to keep trainable + """ + for name, param in model.named_parameters(): + param.requires_grad_(name in adapter_param_names) + + +def get_adapter_state_dict(state_dict: dict, device: str = "cpu") -> dict: + """Filter state dict to only adapter parameters. + + Args: + state_dict: Full model state dict + device: Device to move parameters to (default: "cpu") + + Returns: + Filtered state dict containing only adapter parameters + """ + + def is_adapter_key(k: str) -> bool: + return "lora_a" in k or "lora_b" in k + + return {k: v.to(device) for k, v in state_dict.items() if is_adapter_key(k)} + + +def disable_adapter(model: nn.Module) -> None: + """Disable LoRA adapters in all LoRALinear modules. + + Sets the ``disabled`` flag to True, causing forward passes to only use + the base weights. Useful for reference models in DPO/PPO. 
+ + Args: + model: Model containing LoRALinear modules + """ + for module in model.modules(): + if isinstance(module, AdapterModule) and hasattr(module, "disabled"): + module.disabled = True + + +def enable_adapter(model: nn.Module) -> None: + """Enable LoRA adapters in all LoRALinear modules. + + Sets the ``disabled`` flag to False, enabling LoRA contributions + during forward passes. + + Args: + model: Model containing LoRALinear modules + """ + for module in model.modules(): + if isinstance(module, AdapterModule) and hasattr(module, "disabled"): + module.disabled = False diff --git a/areal/experimental/models/archon/lora/lora_linear.py b/areal/experimental/models/archon/lora/lora_linear.py new file mode 100644 index 0000000000..201b7d41f7 --- /dev/null +++ b/areal/experimental/models/archon/lora/lora_linear.py @@ -0,0 +1,327 @@ +"""LoRALinear module implementation following torchtune patterns. + +Reference: torchtune/torchtune/modules/peft/lora.py + +LoRA weights are stored as **plain tensors** (not ``nn.Parameter``) so that +FSDP2 does not register ``post_accumulate_grad_hook`` on them. This avoids +FSDP DP reduce-scatter operations interleaving with DTensor TP operations +during backward, which would otherwise create a diamond deadlock across the +TP and DP communicators. + +After backward, ``sync_lora_grads`` must be called to all-reduce LoRA +gradients across both TP and DP groups before the optimizer step. +""" + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class LoRALinear(nn.Module): + """Linear layer with Low-Rank Adaptation (LoRA). + + LoRA decomposes weight updates into low-rank matrices A and B: + W' = W + (alpha/rank) * B @ A + + During forward pass: + output = x @ W^T + (alpha/rank) * x @ A^T @ B^T + + LoRA weights (_lora_a_weight, _lora_b_weight) are plain tensors stored + via ``object.__setattr__`` to keep them invisible to ``nn.Module`` + parameter/buffer tracking and therefore to FSDP2. 
+ """ + + def __init__( + self, + in_dim: int, + out_dim: int, + rank: int, + alpha: float, + dropout: float = 0.0, + use_bias: bool = False, + ): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.rank = rank + self.alpha = alpha + self.scaling = alpha / rank + self.disabled = False + self._dropout_p = dropout + + self.weight = nn.Parameter(torch.empty(out_dim, in_dim)) + if use_bias: + self.bias = nn.Parameter(torch.empty(out_dim)) + else: + self.register_parameter("bias", None) + + _a = torch.empty(rank, in_dim) + _b = torch.empty(out_dim, rank) + _a.requires_grad_(True) + _b.requires_grad_(True) + object.__setattr__(self, "_lora_a_weight", _a) + object.__setattr__(self, "_lora_b_weight", _b) + + self._tp_enabled = False + + self._initialize_weights() + + def _initialize_weights(self): + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + nn.init.zeros_(self.bias) + nn.init.kaiming_uniform_(self._lora_a_weight, a=math.sqrt(5)) + nn.init.zeros_(self._lora_b_weight) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + base_out = F.linear(x, self.weight, self.bias) + + if self.disabled: + return base_out + + if self._tp_enabled: + result = self._tp_lora_forward(x, base_out) + if result.requires_grad and hasattr(self, "_debug_name"): + _name = self._debug_name + + result.register_hook(lambda grad: grad) + return result + + h = F.dropout(x, p=self._dropout_p, training=self.training) + h = F.linear(h, self._lora_a_weight) + lora_out = F.linear(h, self._lora_b_weight) + return base_out + self.scaling * lora_out + + def _tp_lora_forward( + self, x: torch.Tensor, base_out: torch.Tensor + ) -> torch.Tensor: + """LoRA forward compatible with TP + FSDP2. + + 1. Input is DETACHED for the LoRA path so the input gradient comes + entirely from the base ``F.linear`` path (which handles TP). + 2. 
The LoRA output is wrapped as a DTensor with the SAME placements + as ``base_out`` and added in DTensor space, keeping the autograd + connection to ``base_out`` intact. + 3. LoRA weights are plain tensors; their gradients are synced after + backward via ``sync_lora_grads``. + """ + from torch.distributed.tensor import DTensor + + local_x = x._local_tensor.detach() if isinstance(x, DTensor) else x.detach() + h = F.dropout(local_x, p=self._dropout_p, training=self.training) + + if self._tp_style == "rowwise": + s = self._tp_rank * self._tp_local_in + lora_a_w = self._lora_a_weight[:, s : s + self._tp_local_in] + h = F.linear(h, lora_a_w) + else: + h = F.linear(h, self._lora_a_weight) + + lora_out = F.linear(h, self._lora_b_weight) + lora_out = self.scaling * lora_out + + if self._tp_style == "colwise": + s = self._tp_rank * self._tp_local_out + lora_out = lora_out[..., s : s + self._tp_local_out] + + if isinstance(base_out, DTensor): + lora_dtensor = DTensor.from_local( + lora_out, + base_out.device_mesh, + list(base_out.placements), + run_check=False, + ) + return base_out + lora_dtensor + + return base_out + lora_out + + # ------------------------------------------------------------------ + # LoRA weight access helpers + # ------------------------------------------------------------------ + + def lora_parameters(self) -> list[torch.Tensor]: + """Return the raw LoRA weight tensors (for the optimizer).""" + return [self._lora_a_weight, self._lora_b_weight] + + def materialize_lora(self, device: torch.device) -> None: + """Move LoRA weights from meta device to *device* and re-init.""" + if self._lora_a_weight.device.type == "meta": + a = torch.empty( + self._lora_a_weight.shape, + dtype=self._lora_a_weight.dtype, + device=device, + ).requires_grad_(True) + object.__setattr__(self, "_lora_a_weight", a) + if self._lora_b_weight.device.type == "meta": + b = torch.empty( + self._lora_b_weight.shape, + dtype=self._lora_b_weight.dtype, + device=device, + 
).requires_grad_(True) + object.__setattr__(self, "_lora_b_weight", b) + + # ------------------------------------------------------------------ + # State dict helpers (plain tensors are invisible to nn.Module) + # ------------------------------------------------------------------ + + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + if self._lora_a_weight.device.type == "meta": + return + a = self._lora_a_weight if keep_vars else self._lora_a_weight.detach() + b = self._lora_b_weight if keep_vars else self._lora_b_weight.detach() + destination[prefix + "_lora_a_weight"] = a + destination[prefix + "_lora_b_weight"] = b + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, + unexpected_keys, error_msgs, + ): + a_key = prefix + "_lora_a_weight" + b_key = prefix + "_lora_b_weight" + if a_key in state_dict: + self._lora_a_weight.data.copy_(state_dict.pop(a_key)) + elif strict: + missing_keys.append(a_key) + if b_key in state_dict: + self._lora_b_weight.data.copy_(state_dict.pop(b_key)) + elif strict: + missing_keys.append(b_key) + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs, + ) + + # ------------------------------------------------------------------ + # Factory & protocol + # ------------------------------------------------------------------ + + @classmethod + def from_linear( + cls, + linear: nn.Linear, + rank: int, + alpha: float, + dropout: float = 0.0, + ) -> "LoRALinear": + """Convert an existing nn.Linear to LoRALinear. + + After TP, ``linear.weight`` is a DTensor. LoRA weights are kept as + plain tensors (NOT nn.Parameter) so FSDP2 ignores them entirely, + preventing DP reduce-scatter from interleaving with TP operations + during backward. 
+ """ + lora_linear = cls.__new__(cls) + nn.Module.__init__(lora_linear) + + lora_linear.in_dim = linear.in_features + lora_linear.out_dim = linear.out_features + lora_linear.rank = rank + lora_linear.alpha = alpha + lora_linear.scaling = alpha / rank + lora_linear.disabled = False + lora_linear._dropout_p = dropout + + lora_linear.weight = linear.weight + if linear.bias is not None: + lora_linear.bias = linear.bias + else: + lora_linear.register_parameter("bias", None) + + local_w = getattr(linear.weight, "_local_tensor", linear.weight) + _a = torch.empty(rank, linear.in_features, device=local_w.device, dtype=local_w.dtype) + _b = torch.empty(linear.out_features, rank, device=local_w.device, dtype=local_w.dtype) + _a.requires_grad_(True) + _b.requires_grad_(True) + object.__setattr__(lora_linear, "_lora_a_weight", _a) + object.__setattr__(lora_linear, "_lora_b_weight", _b) + + from torch.distributed.tensor import DTensor + + lora_linear._tp_enabled = False + if isinstance(linear.weight, DTensor): + from torch.distributed.tensor import Shard + + tp_mesh = linear.weight.device_mesh + placement = linear.weight.placements[0] + local_shape = linear.weight._local_tensor.shape + + lora_linear._tp_enabled = True + lora_linear._tp_rank = tp_mesh.get_local_rank(0) + lora_linear._tp_size = tp_mesh.size(0) + + if isinstance(placement, Shard) and placement.dim == 0: + lora_linear._tp_style = "colwise" + lora_linear._tp_local_out = local_shape[0] + elif isinstance(placement, Shard) and placement.dim == 1: + lora_linear._tp_style = "rowwise" + lora_linear._tp_local_in = local_shape[1] + else: + lora_linear._tp_style = "replicate" + + nn.init.kaiming_uniform_(lora_linear._lora_a_weight, a=math.sqrt(5)) + nn.init.zeros_(lora_linear._lora_b_weight) + + # Preserve TP forward hooks registered by parallelize_module. 
+ lora_linear._forward_pre_hooks = linear._forward_pre_hooks.copy() + lora_linear._forward_hooks = linear._forward_hooks.copy() + lora_linear._forward_hooks_with_kwargs = ( + linear._forward_hooks_with_kwargs.copy() + ) + lora_linear._forward_hooks_always_called = ( + linear._forward_hooks_always_called.copy() + ) + lora_linear._forward_pre_hooks_with_kwargs = ( + linear._forward_pre_hooks_with_kwargs.copy() + ) + + return lora_linear + + def adapter_params(self) -> list[str]: + return ["_lora_a_weight", "_lora_b_weight"] + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"in_dim={self.in_dim}, out_dim={self.out_dim}, " + f"rank={self.rank}, alpha={self.alpha}, " + f"dropout={self._dropout_p}, bias={self.bias is not None})" + ) + + +def sync_lora_grads( + model: nn.Module, + tp_group, + dp_group=None, +) -> None: + """All-reduce LoRA weight gradients across TP and DP groups. + + Because LoRA weights are plain tensors (not nn.Parameter), FSDP2 does + not handle their gradient synchronisation. This function must be + called between backward and optimizer_step. + + Args: + model: The model containing LoRALinear modules. + tp_group: Process group for tensor parallelism (required). + dp_group: Process group for data parallelism (optional but + recommended; without it gradients are only TP-synced). 
+ """ + import torch.distributed as dist + + if tp_group is None and dp_group is None: + return + + for module in model.modules(): + if isinstance(module, LoRALinear) and module._tp_enabled: + for _pname, tensor in [ + ("a", module._lora_a_weight), + ("b", module._lora_b_weight), + ]: + if tensor.grad is not None: + grad = tensor.grad + if tp_group is not None: + dist.all_reduce(grad, group=tp_group) + if dp_group is not None: + dist.all_reduce(grad, group=dp_group) diff --git a/areal/experimental/models/archon/qwen2/infra/parallelize.py b/areal/experimental/models/archon/qwen2/infra/parallelize.py index 0ae6872401..323548209b 100644 --- a/areal/experimental/models/archon/qwen2/infra/parallelize.py +++ b/areal/experimental/models/archon/qwen2/infra/parallelize.py @@ -1,6 +1,7 @@ from __future__ import annotations import functools +from collections.abc import Callable from typing import TYPE_CHECKING, Any import torch @@ -76,6 +77,7 @@ def parallelize_qwen2( reshard_after_forward_policy: str = "default", ac_config: ActivationCheckpointConfig | None = None, enable_compile: bool = True, + apply_lora_fn: Callable[[nn.Module], None] | None = None, ) -> nn.Module: """Apply parallelization to Qwen2 model. @@ -120,6 +122,10 @@ def parallelize_qwen2( cp_group = parallel_dims.get_group("cp") apply_cp(model, cp_group, tp_size=parallel_dims.tp) + # Inject LoRA after TP/CP so tensor-parallel planning still sees nn.Linear. 
+ if apply_lora_fn is not None: + apply_lora_fn(model) + # AC must be after TP/CP if ac_config is not None and ac_config.mode != "none": apply_ac( diff --git a/areal/experimental/models/archon/qwen2/model/state_dict_adapter.py b/areal/experimental/models/archon/qwen2/model/state_dict_adapter.py index 7021282e2b..57a2f15f7d 100644 --- a/areal/experimental/models/archon/qwen2/model/state_dict_adapter.py +++ b/areal/experimental/models/archon/qwen2/model/state_dict_adapter.py @@ -39,6 +39,25 @@ def __init__( "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight", "model.norm.weight": "norm.weight", "lm_head.weight": "output.weight", + # LoRA adapter key mappings (Attention) + "model.layers.{}.self_attn.q_proj.lora_A.weight": "layers.{}.attention.wq._lora_a_weight", + "model.layers.{}.self_attn.q_proj.lora_B.weight": "layers.{}.attention.wq._lora_b_weight", + "model.layers.{}.self_attn.k_proj.lora_A.weight": "layers.{}.attention.wk._lora_a_weight", + "model.layers.{}.self_attn.k_proj.lora_B.weight": "layers.{}.attention.wk._lora_b_weight", + "model.layers.{}.self_attn.v_proj.lora_A.weight": "layers.{}.attention.wv._lora_a_weight", + "model.layers.{}.self_attn.v_proj.lora_B.weight": "layers.{}.attention.wv._lora_b_weight", + "model.layers.{}.self_attn.o_proj.lora_A.weight": "layers.{}.attention.wo._lora_a_weight", + "model.layers.{}.self_attn.o_proj.lora_B.weight": "layers.{}.attention.wo._lora_b_weight", + # LoRA adapter key mappings (MLP) + "model.layers.{}.mlp.gate_proj.lora_A.weight": "layers.{}.feed_forward.w1._lora_a_weight", + "model.layers.{}.mlp.gate_proj.lora_B.weight": "layers.{}.feed_forward.w1._lora_b_weight", + "model.layers.{}.mlp.up_proj.lora_A.weight": "layers.{}.feed_forward.w3._lora_a_weight", + "model.layers.{}.mlp.up_proj.lora_B.weight": "layers.{}.feed_forward.w3._lora_b_weight", + "model.layers.{}.mlp.down_proj.lora_A.weight": "layers.{}.feed_forward.w2._lora_a_weight", + "model.layers.{}.mlp.down_proj.lora_B.weight": 
"layers.{}.feed_forward.w2._lora_b_weight", + # LoRA adapter key mappings (LM Head) + "lm_head.lora_A.weight": "output._lora_a_weight", + "lm_head.lora_B.weight": "output._lora_b_weight", } # Build reverse mapping @@ -49,6 +68,19 @@ def __init__( self.enable_weight_tying = getattr(model_config, "tie_word_embeddings", False) + # Archon module names to HF PEFT module names for LoRA adapters + # Used when generating adapter_config.json + self.to_peft_module_map = { + "wq": "q_proj", + "wk": "k_proj", + "wv": "v_proj", + "wo": "o_proj", + "w1": "gate_proj", + "w2": "down_proj", + "w3": "up_proj", + "output": "lm_head", + } + def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: hf_state_dict = {} diff --git a/areal/experimental/models/archon/qwen3/infra/parallelize.py b/areal/experimental/models/archon/qwen3/infra/parallelize.py index 58a8690a54..aff43d31dd 100644 --- a/areal/experimental/models/archon/qwen3/infra/parallelize.py +++ b/areal/experimental/models/archon/qwen3/infra/parallelize.py @@ -1,6 +1,7 @@ from __future__ import annotations import functools +from collections.abc import Callable from typing import TYPE_CHECKING, Any import torch @@ -91,6 +92,7 @@ def parallelize_qwen3( reshard_after_forward_policy: str = "default", ac_config: ActivationCheckpointConfig | None = None, enable_compile: bool = True, + apply_lora_fn: Callable[[nn.Module], None] | None = None, ) -> nn.Module: """Apply parallelization to Qwen3 model. @@ -152,6 +154,10 @@ def parallelize_qwen3( cp_group = parallel_dims.get_group("cp") apply_cp(model, cp_group, tp_size=parallel_dims.tp) + # Inject LoRA after TP/EP/CP so TP planning still operates on nn.Linear. 
+ if apply_lora_fn is not None: + apply_lora_fn(model) + # AC must be after TP/CP if ac_config is not None and ac_config.mode != "none": apply_ac( diff --git a/areal/infra/remote_inf_engine.py b/areal/infra/remote_inf_engine.py index c465c2c8d3..891bccd67d 100644 --- a/areal/infra/remote_inf_engine.py +++ b/areal/infra/remote_inf_engine.py @@ -58,6 +58,37 @@ RID_CACHE_SIZE = 128 logger = logging.getLogger("RemoteInfEngine") +WEIGHT_UPDATE_READY_FILE = ".areal_weight_update_ready" + + +def _wait_for_disk_weight_update_ready( + meta: WeightUpdateMeta, update_name: str, timeout: float +) -> float: + """Wait until the checkpoint directory is ready for remote loading. + + Prefer a ready file in the checkpoint directory, which lives on the same + shared storage as the actual weights. Fall back to the legacy name_resolve + key so older training workers still work. + """ + ready_path = None if meta.path is None else os.path.join(meta.path, WEIGHT_UPDATE_READY_FILE) + deadline = time.monotonic() + timeout + + while True: + if ready_path is not None and os.path.isfile(ready_path): + with open(ready_path) as f: + return float(f.read().strip()) + + try: + return float(name_resolve.get(update_name)) + except Exception: + pass + + if time.monotonic() > deadline: + raise TimeoutError( + f"Timeout waiting for checkpoint ready signal at " + f"'{ready_path}' or key '{update_name}'" + ) + time.sleep(1.0) class GroupedRolloutWorkflow(RolloutWorkflow): @@ -1295,7 +1326,9 @@ async def _fn(): update_name = names.update_weights_from_disk( experiment_name, trial_name, model_version ) - save_timestamp = float(name_resolve.wait(update_name, timeout=120)) + save_timestamp = _wait_for_disk_weight_update_ready( + meta, update_name, timeout=600 + ) load_timestamp = datetime.now().timestamp() # Get requests from backend with version for LoRA name diff --git a/areal/trainer/ppo/actor.py b/areal/trainer/ppo/actor.py index 3f9fc7b2cd..e653cdf3d4 100644 --- a/areal/trainer/ppo/actor.py +++ 
b/areal/trainer/ppo/actor.py @@ -319,8 +319,9 @@ def ppo_update(self, data: dict[str, Any]) -> None: with stats_tracker.scope("update"): # Get current version for proximal approximation metrics current_version = self.engine.get_version() + _n_mbs = len(mb_inputs.mbs) - for mb in mb_inputs.mbs: + for _mb_idx, mb in enumerate(mb_inputs.mbs): train_stat = self.engine.train_batch( mb, loss_fn=functools.partial( diff --git a/areal/trainer/rl_trainer.py b/areal/trainer/rl_trainer.py index 5c48649f03..160beea832 100644 --- a/areal/trainer/rl_trainer.py +++ b/areal/trainer/rl_trainer.py @@ -422,7 +422,6 @@ def train( adv_batch = self.actor.compute_advantages(rollout_batch) self.actor.get_device_stats().log("compute advantages") - # Wait for async checkpoint staging to complete before modifying parameters self.saver.maybe_wait_for_staging() with ( diff --git a/areal/utils/logging.py b/areal/utils/logging.py index 6a0884dc76..5cac4d1b60 100644 --- a/areal/utils/logging.py +++ b/areal/utils/logging.py @@ -88,6 +88,7 @@ "Saver": "blue", "AsyncCheckpoint": "blue", "ArchonCheckpoint": "blue", + "LoRACheckpoint": "blue", # Platforms - cyan "Platform": "light_cyan", "PlatformInit": "light_cyan", diff --git a/tests/experimental/archon/test_archon_engine_lora.py b/tests/experimental/archon/test_archon_engine_lora.py new file mode 100644 index 0000000000..0f4a91f44c --- /dev/null +++ b/tests/experimental/archon/test_archon_engine_lora.py @@ -0,0 +1,226 @@ +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import Mock + +import torch.nn as nn + +from areal.api.io_struct import WeightUpdateMeta +from areal.experimental.engine import archon_weight_sync +from areal.experimental.engine.archon_engine import ArchonEngine +from areal.experimental.models.archon.lora import LoRALinear, get_adapter_params +from areal.experimental.models.archon.qwen2.infra import parallelize as qwen2_parallelize + + +class _ToyBlock(nn.Module): + def __init__(self): + 
super().__init__() + self.wq = nn.Linear(8, 8) + self.other = nn.Linear(8, 8) + self.inner = nn.Module() + self.inner.wv = nn.Linear(8, 8) + + +def _make_engine(model: nn.Module, target_modules: list[str]) -> ArchonEngine: + engine = ArchonEngine.__new__(ArchonEngine) + engine.model = model + engine.model_parts = [model] + engine.logger = Mock() + engine.lora_config = SimpleNamespace( + rank=4, + alpha=8.0, + target_modules=target_modules, + ) + engine.state_dict_adapter = SimpleNamespace( + to_peft_module_map={ + "wq": "q_proj", + "wv": "v_proj", + } + ) + return engine + + +def test_apply_lora_replaces_target_linear_modules(): + model = _ToyBlock() + engine = _make_engine(model, ["wq", "v_proj"]) + + engine._apply_lora() + + assert isinstance(model.wq, LoRALinear) + assert isinstance(model.inner.wv, LoRALinear) + assert isinstance(model.other, nn.Linear) + assert get_adapter_params(model) + + +def test_from_linear_preserves_weight_identity(): + """from_linear must transfer the original parameter, not copy it. + + After TP the weight is a DTensor; copying would fail with + ``got mixed torch.Tensor and DTensor``. 
+ """ + import torch + + linear = nn.Linear(8, 8) + original_weight = linear.weight + lora = LoRALinear.from_linear(linear, rank=4, alpha=8.0) + assert lora.weight is original_weight + + +def test_freeze_non_lora_params_keeps_only_adapter_trainable(): + model = _ToyBlock() + engine = _make_engine(model, ["wq"]) + + engine._apply_lora() + engine._freeze_non_lora_params() + + assert model.wq.weight.requires_grad is False + assert model.wq.lora_a.weight.requires_grad is True + assert model.wq.lora_b.weight.requires_grad is True + assert model.other.weight.requires_grad is False + assert model.inner.wv.weight.requires_grad is False + + +class _ImmediateFuture: + def result(self): + return None + + +def test_update_weights_from_disk_uses_lora_adapter(monkeypatch, tmp_path): + meta = WeightUpdateMeta.from_disk( + experiment_name="exp", + trial_name="trial", + file_root=str(tmp_path), + use_lora=True, + lora_name="lora", + base_model_name="base-model", + ) + + calls: list[tuple[str, str, str]] = [] + rollout_engine = SimpleNamespace( + update_weights_from_disk=lambda _: _ImmediateFuture() + ) + engine = SimpleNamespace( + rollout_engine=rollout_engine, + lora_config=SimpleNamespace(rank=4), + config=SimpleNamespace( + experiment_name="exp", + trial_name="trial", + path="engine-model", + ), + cpu_group=None, + get_version=lambda: 0, + ) + + monkeypatch.setattr(archon_weight_sync.dist, "get_rank", lambda: 0) + monkeypatch.setattr(archon_weight_sync.dist, "barrier", lambda group=None: None) + monkeypatch.setattr(archon_weight_sync.current_platform, "synchronize", lambda: None) + monkeypatch.setattr(archon_weight_sync.name_resolve, "add", lambda *args, **kwargs: None) + monkeypatch.setattr( + archon_weight_sync.names, + "update_weights_from_disk", + lambda *args: "update-name", + ) + monkeypatch.setattr( + "areal.experimental.engine.archon_lora_checkpoint.save_lora_adapter", + lambda engine_arg, path_arg, base_model_path: calls.append( + ("lora", path_arg, base_model_path) + 
), + ) + monkeypatch.setattr( + archon_weight_sync, + "save_model_to_hf", + lambda *args, **kwargs: calls.append(("full", "", "")), + ) + + archon_weight_sync.update_weights_from_disk(meta, engine) + + assert calls == [("lora", meta.path, "base-model")] + + +def test_update_weights_from_disk_falls_back_to_full_model(monkeypatch, tmp_path): + meta = WeightUpdateMeta.from_disk( + experiment_name="exp", + trial_name="trial", + file_root=str(tmp_path), + ) + + calls: list[str] = [] + rollout_engine = SimpleNamespace( + update_weights_from_disk=lambda _: _ImmediateFuture() + ) + engine = SimpleNamespace( + rollout_engine=rollout_engine, + lora_config=None, + tokenizer=None, + config=SimpleNamespace( + experiment_name="exp", + trial_name="trial", + path="engine-model", + ), + cpu_group=None, + get_version=lambda: 0, + ) + + monkeypatch.setattr(archon_weight_sync.dist, "get_rank", lambda: 0) + monkeypatch.setattr(archon_weight_sync.dist, "barrier", lambda group=None: None) + monkeypatch.setattr(archon_weight_sync.current_platform, "synchronize", lambda: None) + monkeypatch.setattr(archon_weight_sync.name_resolve, "add", lambda *args, **kwargs: None) + monkeypatch.setattr( + archon_weight_sync.names, + "update_weights_from_disk", + lambda *args: "update-name", + ) + monkeypatch.setattr( + archon_weight_sync, + "save_model_to_hf", + lambda *args, **kwargs: calls.append("full"), + ) + + archon_weight_sync.update_weights_from_disk(meta, engine) + + assert calls == ["full"] + + +def test_qwen2_parallelize_applies_lora_after_tp_and_cp(monkeypatch): + order: list[str] = [] + model = SimpleNamespace( + model_args=SimpleNamespace(enable_weight_tying=False), + ) + parallel_dims = SimpleNamespace( + tp_enabled=True, + cp_enabled=True, + pp_enabled=False, + tp=2, + get_mesh=lambda name: object(), + get_group=lambda name: object(), + ) + + monkeypatch.setattr( + qwen2_parallelize, + "apply_tp", + lambda *args, **kwargs: order.append("tp"), + ) + monkeypatch.setattr( + 
qwen2_parallelize, + "apply_cp", + lambda *args, **kwargs: order.append("cp"), + ) + monkeypatch.setattr( + qwen2_parallelize, + "apply_compile", + lambda *args, **kwargs: order.append("compile"), + ) + monkeypatch.setattr( + qwen2_parallelize, + "apply_fsdp", + lambda *args, **kwargs: order.append("fsdp"), + ) + + qwen2_parallelize.parallelize_qwen2( + model, + parallel_dims, + enable_compile=True, + apply_lora_fn=lambda module: order.append("lora"), + ) + + assert order == ["tp", "cp", "lora", "compile", "fsdp"] diff --git a/tests/experimental/archon/test_archon_lora_checkpoint.py b/tests/experimental/archon/test_archon_lora_checkpoint.py new file mode 100644 index 0000000000..3d76e77ce7 --- /dev/null +++ b/tests/experimental/archon/test_archon_lora_checkpoint.py @@ -0,0 +1,259 @@ +"""Tests for LoRA adapter checkpoint I/O and PEFT format conversion.""" + +from __future__ import annotations + +import json +import tempfile +from pathlib import Path +from unittest.mock import Mock + +import pytest +import torch +from torch import nn + +from areal.experimental.models.archon.lora.adapter import get_adapter_params +from areal.experimental.models.archon.lora.lora_linear import LoRALinear +from areal.experimental.models.archon.qwen2.model.state_dict_adapter import ( + Qwen2StateDictAdapter, +) + +# Try to import PEFT for compatibility tests +try: + import peft # noqa: F401 + + PEFT_AVAILABLE = True +except ImportError: + PEFT_AVAILABLE = False + + +class TestStateDictAdapterLoRAKeys: + """Test LoRA key conversion in Qwen2StateDictAdapter.""" + + def setup_method(self): + mock_config = Mock() + mock_config.tie_word_embeddings = False + self.adapter = Qwen2StateDictAdapter(mock_config) + + def test_qwen2_lora_key_conversion_attention(self): + """Test attention LoRA key conversion (wq, wk, wv, wo).""" + hf_key = "model.layers.0.self_attn.q_proj.lora_A.weight" + archon_key = self.adapter._convert_key_from_hf(hf_key) + assert archon_key == 
"layers.0.attention.wq.lora_a.weight" + + hf_key_back = self.adapter._convert_key_to_hf(archon_key) + assert hf_key_back == hf_key + + def test_qwen2_lora_key_conversion_mlp(self): + """Test MLP LoRA key conversion (w1, w2, w3).""" + hf_key = "model.layers.5.mlp.gate_proj.lora_A.weight" + archon_key = self.adapter._convert_key_from_hf(hf_key) + assert archon_key == "layers.5.feed_forward.w1.lora_a.weight" + + hf_key_back = self.adapter._convert_key_to_hf(archon_key) + assert hf_key_back == hf_key + + def test_qwen2_lora_key_conversion_lm_head(self): + """Test LM head LoRA key conversion.""" + hf_key = "lm_head.lora_A.weight" + archon_key = self.adapter._convert_key_from_hf(hf_key) + assert archon_key == "output.lora_a.weight" + + hf_key_back = self.adapter._convert_key_to_hf(archon_key) + assert hf_key_back == hf_key + + def test_qwen2_lora_case_conversion(self): + """Test lora_A/lora_B (HF) <-> lora_a/lora_b (Archon) case handling.""" + # lora_A -> lora_a + hf_key_a = "model.layers.0.self_attn.v_proj.lora_A.weight" + archon_key_a = self.adapter._convert_key_from_hf(hf_key_a) + assert "lora_a" in archon_key_a + + # lora_B -> lora_b + hf_key_b = "model.layers.0.self_attn.v_proj.lora_B.weight" + archon_key_b = self.adapter._convert_key_from_hf(hf_key_b) + assert "lora_b" in archon_key_b + + def test_qwen2_all_16_lora_mappings(self): + """Verify all 16 LoRA key patterns convert correctly.""" + lora_mappings = [ + ("model.layers.{}.self_attn.q_proj.lora_A.weight", "layers.{}.attention.wq.lora_a.weight"), + ("model.layers.{}.self_attn.q_proj.lora_B.weight", "layers.{}.attention.wq.lora_b.weight"), + ("model.layers.{}.self_attn.k_proj.lora_A.weight", "layers.{}.attention.wk.lora_a.weight"), + ("model.layers.{}.self_attn.k_proj.lora_B.weight", "layers.{}.attention.wk.lora_b.weight"), + ("model.layers.{}.self_attn.v_proj.lora_A.weight", "layers.{}.attention.wv.lora_a.weight"), + ("model.layers.{}.self_attn.v_proj.lora_B.weight", "layers.{}.attention.wv.lora_b.weight"), + 
("model.layers.{}.self_attn.o_proj.lora_A.weight", "layers.{}.attention.wo.lora_a.weight"), + ("model.layers.{}.self_attn.o_proj.lora_B.weight", "layers.{}.attention.wo.lora_b.weight"), + ("model.layers.{}.mlp.gate_proj.lora_A.weight", "layers.{}.feed_forward.w1.lora_a.weight"), + ("model.layers.{}.mlp.gate_proj.lora_B.weight", "layers.{}.feed_forward.w1.lora_b.weight"), + ("model.layers.{}.mlp.up_proj.lora_A.weight", "layers.{}.feed_forward.w3.lora_a.weight"), + ("model.layers.{}.mlp.up_proj.lora_B.weight", "layers.{}.feed_forward.w3.lora_b.weight"), + ("model.layers.{}.mlp.down_proj.lora_A.weight", "layers.{}.feed_forward.w2.lora_a.weight"), + ("model.layers.{}.mlp.down_proj.lora_B.weight", "layers.{}.feed_forward.w2.lora_b.weight"), + ("lm_head.lora_A.weight", "output.lora_a.weight"), + ("lm_head.lora_B.weight", "output.lora_b.weight"), + ] + + for hf_pattern, archon_pattern in lora_mappings: + # Substitute layer index + hf_key = hf_pattern.replace("{}", "3") + archon_key = archon_pattern.replace("{}", "3") + + converted = self.adapter._convert_key_from_hf(hf_key) + assert converted == archon_key, f"from_hf failed: {hf_key} -> {converted}" + + back = self.adapter._convert_key_to_hf(archon_key) + assert back == hf_key, f"to_hf failed: {archon_key} -> {back}" + + def test_to_peft_module_map(self): + """Test Archon -> PEFT module name mapping.""" + assert self.adapter.to_peft_module_map["wq"] == "q_proj" + assert self.adapter.to_peft_module_map["wk"] == "k_proj" + assert self.adapter.to_peft_module_map["wv"] == "v_proj" + assert self.adapter.to_peft_module_map["wo"] == "o_proj" + assert self.adapter.to_peft_module_map["w1"] == "gate_proj" + assert self.adapter.to_peft_module_map["w2"] == "down_proj" + assert self.adapter.to_peft_module_map["w3"] == "up_proj" + assert self.adapter.to_peft_module_map["output"] == "lm_head" + + +class TestPEFTAdapterConfig: + """Test PEFT adapter config generation.""" + + def setup_method(self): + mock_config = Mock() + 
mock_config.tie_word_embeddings = False + self.adapter = Qwen2StateDictAdapter(mock_config) + + def test_create_peft_adapter_config(self): + """Test adapter_config.json generation.""" + from dataclasses import dataclass + + @dataclass + class LoRAConfig: + rank: int + alpha: float + target_modules: list[str] + + lora_cfg = LoRAConfig(rank=8, alpha=16.0, target_modules=["wq", "wv"]) + config = self.adapter.create_peft_adapter_config(lora_cfg) + + assert config["peft_type"] == "LORA" + assert config["r"] == 8 + assert config["lora_alpha"] == 16 + assert config["task_type"] == "CAUSAL_LM" + assert "q_proj" in config["target_modules"] + assert "v_proj" in config["target_modules"] + assert config["bias"] == "none" + + def test_create_peft_adapter_config_with_base_model(self): + """Test adapter config with base model path.""" + from dataclasses import dataclass + + @dataclass + class LoRAConfig: + rank: int + alpha: float + target_modules: list[str] + + lora_cfg = LoRAConfig(rank=16, alpha=32.0, target_modules=["wq"]) + config = self.adapter.create_peft_adapter_config( + lora_cfg, base_model_path="Qwen/Qwen2-0.5B" + ) + + assert config["base_model_name_or_path"] == "Qwen/Qwen2-0.5B" + assert config["r"] == 16 + assert config["lora_alpha"] == 32 + + +class TestLoRAAdapterCheckpointDetection: + """Test is_lora_adapter_checkpoint function.""" + + def test_detects_valid_adapter(self, tmp_path): + """Test detection of valid PEFT adapter checkpoint.""" + from areal.experimental.engine.archon_lora_checkpoint import ( + is_lora_adapter_checkpoint, + ) + + config = {"peft_type": "LORA", "r": 8, "lora_alpha": 16} + config_path = tmp_path / "adapter_config.json" + with open(config_path, "w") as f: + json.dump(config, f) + + assert is_lora_adapter_checkpoint(str(tmp_path)) + + def test_rejects_missing_config(self, tmp_path): + """Test rejection when no adapter_config.json exists.""" + from areal.experimental.engine.archon_lora_checkpoint import ( + is_lora_adapter_checkpoint, + ) + 
+ assert not is_lora_adapter_checkpoint(str(tmp_path)) + + def test_rejects_non_lora_config(self, tmp_path): + """Test rejection of non-LoRA adapter type.""" + from areal.experimental.engine.archon_lora_checkpoint import ( + is_lora_adapter_checkpoint, + ) + + config = {"peft_type": "PREFIX_TUNING"} + config_path = tmp_path / "adapter_config.json" + with open(config_path, "w") as f: + json.dump(config, f) + + assert not is_lora_adapter_checkpoint(str(tmp_path)) + + def test_handles_invalid_json(self, tmp_path): + """Test handling of invalid JSON config file.""" + from areal.experimental.engine.archon_lora_checkpoint import ( + is_lora_adapter_checkpoint, + ) + + config_path = tmp_path / "adapter_config.json" + config_path.write_text("not valid json {{{") + + assert not is_lora_adapter_checkpoint(str(tmp_path)) + + +class TestStateDictRoundTrip: + """Test state dict round-trip conversion with LoRA keys.""" + + def setup_method(self): + mock_config = Mock() + mock_config.tie_word_embeddings = False + self.adapter = Qwen2StateDictAdapter(mock_config) + + def test_lora_state_dict_roundtrip(self): + """Test that LoRA keys survive HF -> Archon -> HF round-trip.""" + hf_state = { + "model.layers.0.self_attn.q_proj.lora_A.weight": torch.randn(8, 64), + "model.layers.0.self_attn.q_proj.lora_B.weight": torch.randn(64, 8), + "model.layers.0.self_attn.v_proj.lora_A.weight": torch.randn(8, 64), + "model.layers.0.self_attn.v_proj.lora_B.weight": torch.randn(64, 8), + } + + # HF -> Archon + archon_state = self.adapter.from_hf(hf_state) + assert "layers.0.attention.wq.lora_a.weight" in archon_state + assert "layers.0.attention.wv.lora_b.weight" in archon_state + + # Archon -> HF + hf_state_back = self.adapter.to_hf(archon_state) + + assert set(hf_state.keys()) == set(hf_state_back.keys()) + for key in hf_state: + assert torch.allclose(hf_state[key], hf_state_back[key]) + + def test_mixed_base_and_lora_roundtrip(self): + """Test round-trip with both base and LoRA keys.""" + 
hf_state = { + "model.layers.0.self_attn.q_proj.weight": torch.randn(64, 64), + "model.layers.0.self_attn.q_proj.lora_A.weight": torch.randn(8, 64), + "model.layers.0.self_attn.q_proj.lora_B.weight": torch.randn(64, 8), + "model.norm.weight": torch.randn(64), + } + + archon_state = self.adapter.from_hf(hf_state) + hf_state_back = self.adapter.to_hf(archon_state) + + assert set(hf_state.keys()) == set(hf_state_back.keys())
diff --git a/tests/experimental/archon/test_lora_linear.py b/tests/experimental/archon/test_lora_linear.py
new file mode 100644
index 0000000000..0cf7988144
--- /dev/null
+++ b/tests/experimental/archon/test_lora_linear.py
@@ -0,0 +1,536 @@
+"""Unit tests for LoRALinear module and adapter utilities."""
+
+import pytest
+import torch
+import torch.nn as nn
+
+from areal.experimental.models.archon.lora.adapter import (
+    AdapterModule,
+    disable_adapter,
+    enable_adapter,
+    get_adapter_params,
+    get_adapter_state_dict,
+    set_trainable_params,
+)
+from areal.experimental.models.archon.lora.lora_linear import LoRALinear
+
+# Try to import PEFT's LoRA Linear module for comparison tests
+try:
+    from peft.tuners.lora import Linear as PEFTLoRALinear
+
+    PEFT_AVAILABLE = True
+except ImportError:
+    PEFT_AVAILABLE = False
+
+
+class TestLoRALinear:
+    """Test LoRALinear module functionality."""
+
+    def test_initialization(self):
+        """Test that LoRALinear is properly initialized with zero LoRA contribution."""
+        torch.manual_seed(42)
+
+        lora_linear = LoRALinear(in_dim=64, out_dim=32, rank=8, alpha=16.0)
+
+        # lora_b initialized to zeros
+        assert torch.allclose(
+            lora_linear.lora_b.weight, torch.zeros_like(lora_linear.lora_b.weight)
+        ), "lora_b should be initialized to zeros"
+
+        # Initial forward matches base-only output
+        x = torch.randn(2, 10, 64)
+        with torch.no_grad():
+            out_with_lora = lora_linear(x)
+            base_out = torch.nn.functional.linear(
+                x, lora_linear.weight, lora_linear.bias
+            )
+        assert torch.allclose(out_with_lora, base_out, atol=1e-6), (
+            "Initial output should match base output (zero LoRA contribution)"
+        )
+
+    def test_forward_pass(self):
+        """Test that forward pass correctly implements LoRA computation."""
+        torch.manual_seed(42)
+
+        in_dim, out_dim, rank = 64, 32, 8
+        alpha = 16.0
+        lora_linear = LoRALinear(
+            in_dim=in_dim, out_dim=out_dim, rank=rank, alpha=alpha, dropout=0.0
+        )
+
+        with torch.no_grad():
+            lora_linear.weight.fill_(0.1)
+            lora_linear.lora_a.weight.fill_(0.2)
+            lora_linear.lora_b.weight.fill_(0.3)
+
+        x = torch.randn(2, 10, in_dim)
+        output = lora_linear(x)
+
+        # Manual calculation
+        base_out = torch.nn.functional.linear(x, lora_linear.weight, lora_linear.bias)
+        lora_a_out = torch.nn.functional.linear(x, lora_linear.lora_a.weight)
+        lora_b_out = torch.nn.functional.linear(lora_a_out, lora_linear.lora_b.weight)
+        scaling = alpha / rank
+        expected_output = base_out + scaling * lora_b_out
+
+        assert torch.allclose(output, expected_output, atol=1e-5), (
+            "Forward pass computation mismatch"
+        )
+
+    def test_gradient_flow(self):
+        """Test that only LoRA parameters receive gradients."""
+        torch.manual_seed(42)
+
+        lora_linear = LoRALinear(in_dim=64, out_dim=32, rank=8, alpha=16.0)
+
+        # Freeze base weights
+        lora_linear.weight.requires_grad_(False)
+        if lora_linear.bias is not None:
+            lora_linear.bias.requires_grad_(False)
+
+        # Set non-zero lora_b so gradients flow to lora_a
+        with torch.no_grad():
+            lora_linear.lora_b.weight.fill_(0.1)
+
+        x = torch.randn(2, 10, 64, requires_grad=True)
+        output = lora_linear(x)
+        loss = output.sum()
+        loss.backward()
+
+        assert lora_linear.weight.grad is None, "Base weight should not have gradients"
+        if lora_linear.bias is not None:
+            assert lora_linear.bias.grad is None, "Bias should not have gradients"
+
+        assert lora_linear.lora_a.weight.grad is not None, (
+            "lora_a should have gradients"
+        )
+        assert lora_linear.lora_b.weight.grad is not None, (
+            "lora_b should have gradients"
+        )
+        assert not torch.allclose(
+            lora_linear.lora_a.weight.grad,
+            torch.zeros_like(lora_linear.lora_a.weight.grad),
+        ), "lora_a gradients should be non-zero"
+
+    def test_from_linear(self):
+        """Test conversion from nn.Linear to LoRALinear."""
+        torch.manual_seed(42)
+
+        linear = nn.Linear(64, 32, bias=True)
+        with torch.no_grad():
+            linear.weight.fill_(0.5)
+            linear.bias.fill_(0.1)
+
+        lora_linear = LoRALinear.from_linear(linear, rank=8, alpha=16.0)
+
+        assert torch.allclose(lora_linear.weight, linear.weight), (
+            "Base weight should match original linear weight"
+        )
+        assert torch.allclose(lora_linear.bias, linear.bias), (
+            "Bias should match original bias"
+        )
+        assert lora_linear.in_dim == linear.in_features
+        assert lora_linear.out_dim == linear.out_features
+
+        # Output with disabled LoRA matches original. lora_b is made non-zero
+        # first; a zero lora_b would make this hold even if `disabled` did nothing.
+        x = torch.randn(2, 10, 64)
+        with torch.no_grad():
+            lora_linear.lora_b.weight.fill_(0.1)
+            lora_linear.disabled = True
+            lora_out = lora_linear(x)
+            linear_out = linear(x)
+
+        assert torch.allclose(lora_out, linear_out, atol=1e-6), (
+            "Output with disabled LoRA should match original linear"
+        )
+
+    def test_adapter_params_protocol(self):
+        """Test AdapterModule protocol implementation."""
+        lora_linear = LoRALinear(in_dim=64, out_dim=32, rank=8, alpha=16.0)
+
+        assert isinstance(lora_linear, AdapterModule), (
+            "LoRALinear should implement AdapterModule"
+        )
+
+        adapter_param_names = lora_linear.adapter_params()
+        assert adapter_param_names == [
+            "lora_a.weight",
+            "lora_b.weight",
+        ], "adapter_params() should return LoRA parameter names"
+
+    def test_disabled_flag(self):
+        """Test that disabled flag correctly disables LoRA contribution."""
+        torch.manual_seed(42)
+
+        lora_linear = LoRALinear(
+            in_dim=64, out_dim=32, rank=8, alpha=16.0, dropout=0.0
+        )
+
+        with torch.no_grad():
+            lora_linear.lora_b.weight.fill_(0.1)
+
+        x = torch.randn(2, 10, 64)
+
+        with torch.no_grad():
+            lora_linear.disabled = False
+            out_enabled = lora_linear(x)
+
+            lora_linear.disabled = True
+            out_disabled = lora_linear(x)
+
+            base_out = torch.nn.functional.linear(
+                x, lora_linear.weight, lora_linear.bias
+            )
+
+        assert torch.allclose(out_disabled, base_out, atol=1e-6), (
+            "Disabled LoRA should match base output"
+        )
+        assert not torch.allclose(out_enabled, base_out, atol=1e-5), (
+            "Enabled LoRA should differ from base output"
+        )
+
+    def test_repr(self):
+        """Test __repr__ output."""
+        lora_linear = LoRALinear(
+            in_dim=64, out_dim=32, rank=8, alpha=16.0, dropout=0.1, use_bias=True
+        )
+        repr_str = repr(lora_linear)
+        assert "LoRALinear" in repr_str
+        assert "in_dim=64" in repr_str
+        assert "out_dim=32" in repr_str
+        assert "rank=8" in repr_str
+
+
+class TestAdapterUtilities:
+    """Test adapter utility functions."""
+
+    def test_get_adapter_params(self):
+        """Test extraction of adapter parameters from model."""
+
+        class SimpleModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.lora1 = LoRALinear(64, 32, rank=8, alpha=16.0)
+                self.lora2 = LoRALinear(32, 16, rank=8, alpha=16.0)
+                self.linear = nn.Linear(16, 8)
+
+            def forward(self, x):
+                return self.linear(self.lora2(self.lora1(x)))
+
+        model = SimpleModel()
+        adapter_params = get_adapter_params(model)
+
+        expected_keys = {
+            "lora1.lora_a.weight",
+            "lora1.lora_b.weight",
+            "lora2.lora_a.weight",
+            "lora2.lora_b.weight",
+        }
+        assert set(adapter_params.keys()) == expected_keys, (
+            "Should extract only LoRA parameters"
+        )
+
+        for param in adapter_params.values():
+            assert isinstance(param, nn.Parameter), "Should return Parameter objects"
+
+    def test_set_trainable_params(self):
+        """Test freezing/unfreezing parameters."""
+
+        class SimpleModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.lora = LoRALinear(64, 32, rank=8, alpha=16.0)
+                self.linear = nn.Linear(32, 16)
+
+        model = SimpleModel()
+        adapter_param_names = set(get_adapter_params(model).keys())
+        set_trainable_params(model, adapter_param_names)
+
+        for name, param in model.named_parameters():
+            if name in adapter_param_names:
+                assert param.requires_grad, f"{name} should be trainable"
+            else:
+                assert not param.requires_grad, f"{name} should be frozen"
+
+    def test_get_adapter_state_dict(self):
+        """Test filtering state dict to adapter parameters only."""
+        state_dict = {
+            "model.weight": torch.randn(10, 10),
+            "model.bias": torch.randn(10),
+            "model.lora_a.weight": torch.randn(8, 10),
+            "model.lora_b.weight": torch.randn(10, 8),
+            "other.lora_a.weight": torch.randn(8, 10),
+            "other.lora_b.weight": torch.randn(10, 8),
+        }
+
+        adapter_state_dict = get_adapter_state_dict(state_dict)
+
+        expected_keys = {
+            "model.lora_a.weight",
+            "model.lora_b.weight",
+            "other.lora_a.weight",
+            "other.lora_b.weight",
+        }
+        assert set(adapter_state_dict.keys()) == expected_keys, (
+            "Should filter to LoRA params only"
+        )
+
+    def test_disable_enable_adapter(self):
+        """Test disabling/enabling adapters in model."""
+
+        class SimpleModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.lora1 = LoRALinear(64, 32, rank=8, alpha=16.0)
+                self.lora2 = LoRALinear(32, 16, rank=8, alpha=16.0)
+
+        model = SimpleModel()
+
+        assert not model.lora1.disabled
+        assert not model.lora2.disabled
+
+        disable_adapter(model)
+        assert model.lora1.disabled
+        assert model.lora2.disabled
+
+        enable_adapter(model)
+        assert not model.lora1.disabled
+        assert not model.lora2.disabled
+
+
+class TestLoRALinearWithBias:
+    """Test LoRALinear with bias enabled."""
+
+    def test_bias_initialization(self):
+        """Test bias is properly initialized."""
+        lora_linear = LoRALinear(
+            in_dim=64, out_dim=32, rank=8, alpha=16.0, use_bias=True
+        )
+
+        assert lora_linear.bias is not None, "Bias should exist"
+        assert torch.allclose(lora_linear.bias, torch.zeros_like(lora_linear.bias)), (
+            "Bias should be initialized to zeros"
+        )
+
+    def test_bias_forward(self):
+        """Test forward pass with bias."""
+        torch.manual_seed(42)
+
+        lora_linear = LoRALinear(
+            in_dim=64, out_dim=32, rank=8, alpha=16.0, use_bias=True, dropout=0.0
+        )
+
+        with torch.no_grad():
+            lora_linear.weight.fill_(0.1)
+            lora_linear.bias.fill_(0.5)
+            lora_linear.lora_a.weight.fill_(0.2)
+            lora_linear.lora_b.weight.fill_(0.3)
+
+        x = torch.randn(2, 10, 64)
+        output = lora_linear(x)
+
+        base_out = torch.nn.functional.linear(x, lora_linear.weight, lora_linear.bias)
+        lora_a_out = torch.nn.functional.linear(x, lora_linear.lora_a.weight)
+        lora_b_out = torch.nn.functional.linear(lora_a_out, lora_linear.lora_b.weight)
+        scaling = lora_linear.alpha / lora_linear.rank
+        expected_output = base_out + scaling * lora_b_out
+
+        assert torch.allclose(output, expected_output, atol=1e-5), (
+            "Forward with bias should match manual calc"
+        )
+
+
+class TestLoRALinearDropout:
+    """Test LoRALinear with dropout."""
+
+    def test_dropout_training_mode(self):
+        """Test that dropout is active in training mode."""
+        torch.manual_seed(42)
+
+        lora_linear = LoRALinear(
+            in_dim=64, out_dim=32, rank=8, alpha=16.0, dropout=0.5
+        )
+        lora_linear.train()
+
+        with torch.no_grad():
+            lora_linear.lora_b.weight.fill_(0.1)
+            lora_linear.lora_a.weight.fill_(0.1)
+
+        x = torch.randn(2, 10, 64)
+
+        with torch.no_grad():
+            out1 = lora_linear(x)
+            out2 = lora_linear(x)
+
+        assert not torch.allclose(out1, out2, atol=1e-5), (
+            "Dropout should cause different outputs"
+        )
+
+    def test_dropout_eval_mode(self):
+        """Test that dropout is disabled in eval mode."""
+        torch.manual_seed(42)
+
+        lora_linear = LoRALinear(
+            in_dim=64, out_dim=32, rank=8, alpha=16.0, dropout=0.5
+        )
+        lora_linear.eval()
+
+        x = torch.randn(2, 10, 64)
+
+        with torch.no_grad():
+            # lora_b starts at zero, which would hide any dropout noise; make
+            # the LoRA branch contribute so this check can actually fail.
+            lora_linear.lora_b.weight.fill_(0.1)
+            out1 = lora_linear(x)
+            out2 = lora_linear(x)
+
+        assert torch.allclose(out1, out2, atol=1e-6), (
+            "Eval mode should have deterministic output"
+        )
+
+
+@pytest.mark.skipif(not PEFT_AVAILABLE, reason="PEFT not installed")
+class TestPEFTCompatibility:
+    """Test compatibility with PEFT library implementation."""
+
+    def test_forward_pass_vs_peft(self):
+        """Compare forward pass output with PEFT's LoRA Linear module."""
+        torch.manual_seed(42)
+
+        in_dim, out_dim, rank = 128, 128, 16
+        alpha = 32.0
+        adapter_name = "default"
+
+        base_linear = nn.Linear(in_dim, out_dim, bias=False)
+
+        our_lora = LoRALinear.from_linear(
+            base_linear, rank=rank, alpha=alpha, dropout=0.0
+        )
+
+        peft_lora = PEFTLoRALinear(
+            base_layer=nn.Linear(in_dim, out_dim, bias=False),
+            adapter_name=adapter_name,
+            r=rank,
+            lora_alpha=alpha,
+            lora_dropout=0.0,
+            init_lora_weights=True,
+        )
+        peft_lora.base_layer.weight.data.copy_(base_linear.weight.data)
+
+        peft_lora_a = peft_lora.lora_A[adapter_name].weight
+        peft_lora_b = peft_lora.lora_B[adapter_name].weight
+
+        with torch.no_grad():
+            # PEFT zero-initialises lora_B, which would reduce both outputs to
+            # the base projection; randomise it so the LoRA path is compared.
+            peft_lora_b.normal_(mean=0.0, std=0.02)
+            our_lora.lora_a.weight.copy_(peft_lora_a)
+            our_lora.lora_b.weight.copy_(peft_lora_b)
+
+        x = torch.randn(2, 10, in_dim)
+
+        with torch.no_grad():
+            our_output = our_lora(x)
+            peft_output = peft_lora(x)
+
+        max_diff = (our_output - peft_output).abs().max().item()
+        assert torch.allclose(our_output, peft_output, atol=1e-5), (
+            f"Output mismatch vs PEFT: max diff = {max_diff}"
+        )
+
+    def test_gradient_flow_vs_peft(self):
+        """Compare gradient flow with PEFT's implementation."""
+        torch.manual_seed(42)
+
+        in_dim, out_dim, rank = 64, 64, 8
+        alpha = 16.0
+        adapter_name = "default"
+
+        base_linear = nn.Linear(in_dim, out_dim, bias=False)
+
+        our_lora = LoRALinear.from_linear(
+            base_linear, rank=rank, alpha=alpha, dropout=0.0
+        )
+        our_lora.weight.requires_grad_(False)
+
+        peft_lora = PEFTLoRALinear(
+            base_layer=nn.Linear(in_dim, out_dim, bias=False),
+            adapter_name=adapter_name,
+            r=rank,
+            lora_alpha=alpha,
+            lora_dropout=0.0,
+            init_lora_weights=True,
+        )
+        peft_lora.base_layer.weight.data.copy_(base_linear.weight.data)
+        peft_lora.base_layer.weight.requires_grad_(False)
+
+        peft_lora_a = peft_lora.lora_A[adapter_name].weight
+        peft_lora_b = peft_lora.lora_B[adapter_name].weight
+
+        with torch.no_grad():
+            our_lora.lora_a.weight.copy_(peft_lora_a)
+            # Same non-zero lora_b on both sides so gradients reach lora_a
+            our_lora.lora_b.weight.fill_(0.1)
+            peft_lora_b.fill_(0.1)
+
+        x = torch.randn(2, 10, in_dim)
+
+        our_output = our_lora(x)
+        our_output.sum().backward()
+
+        peft_output = peft_lora(x)
+        peft_output.sum().backward()
+
+        assert torch.allclose(
+            our_lora.lora_a.weight.grad, peft_lora_a.grad, atol=1e-4
+        ), "lora_a gradient mismatch vs PEFT"
+        assert torch.allclose(
+            our_lora.lora_b.weight.grad, peft_lora_b.grad, atol=1e-4
+        ), "lora_b gradient mismatch vs PEFT"
+
+    def test_scaling_factor_vs_peft(self):
+        """Verify scaling factor matches PEFT's implementation."""
+        rank = 16
+        alpha = 32.0
+        adapter_name = "default"
+
+        our_lora = LoRALinear(in_dim=64, out_dim=64, rank=rank, alpha=alpha)
+
+        peft_lora = PEFTLoRALinear(
+            base_layer=nn.Linear(64, 64, bias=False),
+            adapter_name=adapter_name,
+            r=rank,
+            lora_alpha=alpha,
+            lora_dropout=0.0,
+            init_lora_weights=True,
+        )
+
+        expected_scaling = alpha / rank
+        peft_scaling = peft_lora.scaling[adapter_name]
+
+        assert our_lora.scaling == expected_scaling
+        assert our_lora.scaling == peft_scaling
+
+    def test_initialization_vs_peft(self):
+        """Compare initialization strategy with PEFT."""
+        torch.manual_seed(42)
+
+        our_lora = LoRALinear(in_dim=64, out_dim=64, rank=16, alpha=32.0)
+
+        # lora_b zeros, lora_a non-zero
+        assert torch.allclose(
+            our_lora.lora_b.weight, torch.zeros_like(our_lora.lora_b.weight)
+        ), "lora_b should be zeros (PEFT convention)"
+
+        assert not torch.allclose(
+            our_lora.lora_a.weight, torch.zeros_like(our_lora.lora_a.weight)
+        ), "lora_a should be non-zero (kaiming_uniform)"
+
+        x = torch.randn(2, 10, 64)
+        with torch.no_grad():
+            output = our_lora(x)
+            base_output = torch.nn.functional.linear(x, our_lora.weight, our_lora.bias)
+        assert torch.allclose(output, base_output, atol=1e-6), (
+            "Initial output should match base (PEFT convention)"
+        )