-
Notifications
You must be signed in to change notification settings - Fork 501
feat: add LoRA infrastructure for Archon engine (Phase 1 & 2) #1000
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -192,6 +192,25 @@ def __init__(self, config: TrainEngineConfig): | |||||||||||||||||||||||||||||||||||
| self._initialized = False | ||||||||||||||||||||||||||||||||||||
| self.is_offload = False | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| # LoRA Configuration (extract from config if enabled) | ||||||||||||||||||||||||||||||||||||
| self.lora_config = None | ||||||||||||||||||||||||||||||||||||
| if hasattr(config, "use_lora") and config.use_lora: | ||||||||||||||||||||||||||||||||||||
| from dataclasses import dataclass | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| @dataclass | ||||||||||||||||||||||||||||||||||||
| class LoRAConfig: | ||||||||||||||||||||||||||||||||||||
| enabled: bool | ||||||||||||||||||||||||||||||||||||
| rank: int | ||||||||||||||||||||||||||||||||||||
| alpha: float | ||||||||||||||||||||||||||||||||||||
| target_modules: list[str] | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| self.lora_config = LoRAConfig( | ||||||||||||||||||||||||||||||||||||
| enabled=True, | ||||||||||||||||||||||||||||||||||||
| rank=config.lora_rank, | ||||||||||||||||||||||||||||||||||||
| alpha=float(config.lora_alpha), | ||||||||||||||||||||||||||||||||||||
| target_modules=config.target_modules if config.target_modules else [], | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
|
Comment on lines
+197
to
+212
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def create_process_group( | ||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||
| parallel_strategy: ParallelStrategy | None = None, | ||||||||||||||||||||||||||||||||||||
|
|
@@ -661,7 +680,19 @@ def update_weights(self, meta: WeightUpdateMeta): | |||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def save(self, meta: SaveLoadMeta): | ||||||||||||||||||||||||||||||||||||
| """Save model in HuggingFace or DCP format.""" | ||||||||||||||||||||||||||||||||||||
| """Save model in HuggingFace or DCP format. | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| When LoRA is enabled, only the adapter weights are saved in PEFT format. | ||||||||||||||||||||||||||||||||||||
| When LoRA is disabled, the full model is saved. | ||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||
| if self.lora_config is not None: | ||||||||||||||||||||||||||||||||||||
| from areal.experimental.engine.archon_lora_checkpoint import ( | ||||||||||||||||||||||||||||||||||||
| save_lora_adapter, | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| save_lora_adapter(self, meta.path, meta.base_model_path) | ||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||
|
Comment on lines
+688
to
+694
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| if meta.weight_format == "hf": | ||||||||||||||||||||||||||||||||||||
| save_model_to_hf(self, meta.path, meta.tokenizer, meta.processor) | ||||||||||||||||||||||||||||||||||||
| elif meta.weight_format == "dcp": | ||||||||||||||||||||||||||||||||||||
|
|
@@ -673,7 +704,20 @@ def save(self, meta: SaveLoadMeta): | |||||||||||||||||||||||||||||||||||
| save_optimizer_state(self, meta.path) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def load(self, meta: SaveLoadMeta): | ||||||||||||||||||||||||||||||||||||
| """Load model from HuggingFace or DCP format.""" | ||||||||||||||||||||||||||||||||||||
| """Load model from HuggingFace or DCP format. | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| When LoRA is enabled and the checkpoint is a PEFT adapter, | ||||||||||||||||||||||||||||||||||||
| only adapter weights are loaded. | ||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||
| from areal.experimental.engine.archon_lora_checkpoint import ( | ||||||||||||||||||||||||||||||||||||
| is_lora_adapter_checkpoint, | ||||||||||||||||||||||||||||||||||||
| load_lora_adapter, | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| if self.lora_config is not None and is_lora_adapter_checkpoint(meta.path): | ||||||||||||||||||||||||||||||||||||
| load_lora_adapter(self, meta.path) | ||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||
|
Comment on lines
+712
to
+719
|
||||||||||||||||||||||||||||||||||||
| from areal.experimental.engine.archon_lora_checkpoint import ( | |
| is_lora_adapter_checkpoint, | |
| load_lora_adapter, | |
| ) | |
| if self.lora_config is not None and is_lora_adapter_checkpoint(meta.path): | |
| load_lora_adapter(self, meta.path) | |
| return | |
| if self.lora_config is not None: | |
| from areal.experimental.engine.archon_lora_checkpoint import ( | |
| is_lora_adapter_checkpoint, | |
| load_lora_adapter, | |
| ) | |
| if is_lora_adapter_checkpoint(meta.path): | |
| load_lora_adapter(self, meta.path) | |
| return |
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,241 @@ | ||||||
| """LoRA adapter checkpoint I/O in PEFT format. | ||||||
|
|
||||||
| This module provides functions to save and load LoRA adapters in PEFT-compatible | ||||||
| format for HuggingFace ecosystem interoperability. | ||||||
|
|
||||||
| PEFT checkpoint structure: | ||||||
| adapter_checkpoint/ | ||||||
| ├── adapter_model.safetensors # LoRA weights only | ||||||
| └── adapter_config.json # PEFT configuration | ||||||
|
|
||||||
| Reference: peft/src/peft/utils/save_and_load.py | ||||||
| """ | ||||||
|
|
||||||
| from __future__ import annotations | ||||||
|
|
||||||
| import json | ||||||
| import os | ||||||
| from pathlib import Path | ||||||
| from typing import TYPE_CHECKING | ||||||
|
|
||||||
| import torch | ||||||
| import torch.distributed as dist | ||||||
| from safetensors.torch import load_file, save_file | ||||||
|
|
||||||
| from areal.experimental.models.archon.lora.adapter import get_adapter_params | ||||||
| from areal.utils import logging | ||||||
|
|
||||||
| if TYPE_CHECKING: | ||||||
| from areal.experimental.engine.archon_engine import ArchonEngine | ||||||
|
|
||||||
| logger = logging.getLogger("LoRACheckpoint") | ||||||
|
|
||||||
|
|
||||||
| def save_lora_adapter( | ||||||
| engine: "ArchonEngine", | ||||||
| path: str, | ||||||
| base_model_path: str | None = None, | ||||||
| ) -> None: | ||||||
| """Save LoRA adapter in PEFT format. | ||||||
|
|
||||||
| Creates two files: | ||||||
| - adapter_model.safetensors: LoRA weights (lora_a, lora_b) | ||||||
| - adapter_config.json: PEFT configuration | ||||||
|
|
||||||
| Args: | ||||||
| engine: ArchonEngine instance with LoRA-enabled model | ||||||
| path: Directory path to save adapter checkpoint | ||||||
| base_model_path: Optional path to base model (for config reference) | ||||||
|
|
||||||
| Raises: | ||||||
| RuntimeError: If LoRA is not enabled on engine | ||||||
| """ | ||||||
| if engine.lora_config is None: | ||||||
| raise RuntimeError("Cannot save LoRA adapter: LoRA not enabled on engine") | ||||||
|
Comment on lines
+53
to
+54
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This raise RuntimeError("Cannot save LoRA adapter: LoRA not enabled on engine. Please ensure 'use_lora' is set to True in the training configuration.") |
||||||
|
|
||||||
| if dist.is_initialized(): | ||||||
| rank = dist.get_rank() | ||||||
| else: | ||||||
| rank = 0 | ||||||
|
|
||||||
| if rank == 0: | ||||||
| os.makedirs(path, exist_ok=True) | ||||||
| logger.info(f"Saving LoRA adapter to {path}") | ||||||
|
|
||||||
| # Extract adapter parameters from model | ||||||
| adapter_params = get_adapter_params(engine.model) | ||||||
|
|
||||||
| if not adapter_params: | ||||||
| logger.warning("No adapter parameters found in model") | ||||||
| if rank == 0: | ||||||
| logger.warning("Creating empty adapter checkpoint") | ||||||
|
|
||||||
| # Convert to HF format using state dict adapter | ||||||
| archon_state = {k: v.cpu().detach().clone() for k, v in adapter_params.items()} | ||||||
| hf_state = engine.state_dict_adapter.to_hf(archon_state) | ||||||
|
|
||||||
| # Add PEFT prefix: base_model.model.{key} | ||||||
| peft_state = {f"base_model.model.{k}": v for k, v in hf_state.items()} | ||||||
|
Comment on lines
+66
to
+78
|
||||||
|
|
||||||
| # Save weights (only rank 0) | ||||||
| if rank == 0: | ||||||
| weights_path = os.path.join(path, "adapter_model.safetensors") | ||||||
| save_file(peft_state, weights_path) | ||||||
| logger.info(f"Saved {len(peft_state)} adapter tensors to {weights_path}") | ||||||
|
|
||||||
| # Determine target modules from actual adapter parameters | ||||||
| target_modules = set() | ||||||
| for key in adapter_params: | ||||||
| parts = key.split(".") | ||||||
| for i, part in enumerate(parts): | ||||||
| if part in ("lora_a", "lora_b") and i > 0: | ||||||
| module_name = parts[i - 1] | ||||||
| target_modules.add(module_name) | ||||||
| break | ||||||
|
|
||||||
| # Create config copy with actual target modules | ||||||
| from dataclasses import replace | ||||||
|
|
||||||
| lora_config_for_save = replace( | ||||||
| engine.lora_config, target_modules=sorted(target_modules) | ||||||
| ) | ||||||
|
|
||||||
| # Generate adapter config using model-specific state dict adapter | ||||||
| adapter_config = engine.state_dict_adapter.create_peft_adapter_config( | ||||||
| lora_config=lora_config_for_save, | ||||||
| base_model_path=base_model_path, | ||||||
| ) | ||||||
|
|
||||||
| config_path = os.path.join(path, "adapter_config.json") | ||||||
| with open(config_path, "w") as f: | ||||||
| json.dump(adapter_config, f, indent=2) | ||||||
| logger.info(f"Saved adapter config to {config_path}") | ||||||
|
|
||||||
| # Synchronize all ranks | ||||||
| if dist.is_initialized(): | ||||||
| dist.barrier() | ||||||
|
||||||
| dist.barrier() | |
| dist.barrier(group=engine.cpu_group) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similar to the save_lora_adapter function, this RuntimeError is appropriate for preventing the loading of LoRA adapters when LoRA is not enabled. Consider adding a more descriptive error message to provide better guidance to the user, such as suggesting they enable LoRA in the engine configuration.
raise RuntimeError("Cannot load LoRA adapter: LoRA not enabled on engine. Please ensure 'use_lora' is set to True in the training configuration.")There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This FileNotFoundError is appropriate for handling missing adapter weights. Consider including the path in the error message to help the user quickly identify the missing file.
raise FileNotFoundError(
f"Adapter weights not found at {weights_path}. "
"Expected adapter_model.safetensors or adapter_model.bin"
)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Copilot
AI
Mar 6, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similar to the save path issue: under FSDP2, expected_adapter_params[key] will be a DTensor (sharded). Directly assigning to param.data.copy_(value) on a DTensor won't properly distribute the loaded weights across the FSDP2 mesh. The existing load code in archon_checkpoint.py uses set_model_state_dict() to handle this correctly. The LoRA load path should use an equivalent mechanism to ensure proper shard distribution.
Copilot
AI
Mar 6, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same issue as the save path: dist.barrier() should use engine.cpu_group to match the convention in archon_checkpoint.py (e.g., line 437, 450, 463).
| dist.barrier() | |
| dist.barrier(group=engine.cpu_group) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,26 @@ | ||
| """LoRA (Low-Rank Adaptation) infrastructure for Archon engine. | ||
|
|
||
| Following torchtune's design patterns for FSDP2 compatibility. | ||
| This module provides custom LoRALinear implementation and utilities | ||
| for parameter-efficient fine-tuning of large language models. | ||
| """ | ||
|
|
||
| from areal.experimental.models.archon.lora.adapter import ( | ||
| AdapterModule, | ||
| disable_adapter, | ||
| enable_adapter, | ||
| get_adapter_params, | ||
| get_adapter_state_dict, | ||
| set_trainable_params, | ||
| ) | ||
| from areal.experimental.models.archon.lora.lora_linear import LoRALinear | ||
|
|
||
| __all__ = [ | ||
| "LoRALinear", | ||
| "AdapterModule", | ||
| "get_adapter_params", | ||
| "get_adapter_state_dict", | ||
| "set_trainable_params", | ||
| "disable_adapter", | ||
| "enable_adapter", | ||
| ] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's good to see the alpha value being converted to float to ensure type consistency. However, consider adding a try-except block to handle potential
ValueErrorifconfig.lora_alphacannot be converted to a float. This will make the code more robust.