Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 46 additions & 2 deletions areal/experimental/engine/archon_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,25 @@ def __init__(self, config: TrainEngineConfig):
self._initialized = False
self.is_offload = False

# LoRA Configuration (extract from config if enabled)
self.lora_config = None
if hasattr(config, "use_lora") and config.use_lora:
from dataclasses import dataclass

@dataclass
class LoRAConfig:
enabled: bool
rank: int
alpha: float
target_modules: list[str]

self.lora_config = LoRAConfig(
enabled=True,
rank=config.lora_rank,
alpha=float(config.lora_alpha),
target_modules=config.target_modules if config.target_modules else [],
)

def create_process_group(
self,
parallel_strategy: ParallelStrategy | None = None,
Expand Down Expand Up @@ -661,7 +680,19 @@ def update_weights(self, meta: WeightUpdateMeta):
)

def save(self, meta: SaveLoadMeta):
"""Save model in HuggingFace or DCP format."""
"""Save model in HuggingFace or DCP format.

When LoRA is enabled, only the adapter weights are saved in PEFT format.
When LoRA is disabled, the full model is saved.
"""
if self.lora_config is not None:
from areal.experimental.engine.archon_lora_checkpoint import (
save_lora_adapter,
)

save_lora_adapter(self, meta.path, meta.base_model_path)
return

if meta.weight_format == "hf":
save_model_to_hf(self, meta.path, meta.tokenizer, meta.processor)
elif meta.weight_format == "dcp":
Expand All @@ -673,7 +704,20 @@ def save(self, meta: SaveLoadMeta):
save_optimizer_state(self, meta.path)

def load(self, meta: SaveLoadMeta):
"""Load model from HuggingFace or DCP format."""
"""Load model from HuggingFace or DCP format.

When LoRA is enabled and the checkpoint is a PEFT adapter,
only adapter weights are loaded.
"""
from areal.experimental.engine.archon_lora_checkpoint import (
is_lora_adapter_checkpoint,
load_lora_adapter,
)

if self.lora_config is not None and is_lora_adapter_checkpoint(meta.path):
load_lora_adapter(self, meta.path)
return

if meta.weight_format == "hf":
load_model_from_hf(self, meta.path)
elif meta.weight_format == "dcp":
Expand Down
241 changes: 241 additions & 0 deletions areal/experimental/engine/archon_lora_checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
"""LoRA adapter checkpoint I/O in PEFT format.

This module provides functions to save and load LoRA adapters in PEFT-compatible
format for HuggingFace ecosystem interoperability.

PEFT checkpoint structure:
adapter_checkpoint/
├── adapter_model.safetensors # LoRA weights only
└── adapter_config.json # PEFT configuration

Reference: peft/src/peft/utils/save_and_load.py
"""

from __future__ import annotations

import json
import os
from pathlib import Path
from typing import TYPE_CHECKING

import torch
import torch.distributed as dist
from safetensors.torch import load_file, save_file

from areal.experimental.models.archon.lora.adapter import get_adapter_params
from areal.utils import logging

if TYPE_CHECKING:
from areal.experimental.engine.archon_engine import ArchonEngine

logger = logging.getLogger("LoRACheckpoint")


def save_lora_adapter(
    engine: "ArchonEngine",
    path: str,
    base_model_path: str | None = None,
) -> None:
    """Write the engine's LoRA adapter to ``path`` as a PEFT-style checkpoint.

    Two files are produced by rank 0 (or by the only process when
    ``torch.distributed`` is not initialized):

    - ``adapter_model.safetensors``: the adapter tensors, with every key
      prefixed by ``base_model.model.``
    - ``adapter_config.json``: the dictionary returned by
      ``engine.state_dict_adapter.create_peft_adapter_config``

    Every rank collects the adapter parameters and runs the state-dict
    conversion; only rank 0 touches the filesystem. All ranks meet at a
    barrier before returning.

    Args:
        engine: Engine whose ``lora_config`` is set and whose ``model``
            carries the adapter parameters.
        path: Directory that receives the checkpoint (created if missing).
        base_model_path: Optional base-model path forwarded to the adapter
            config generator.

    Raises:
        RuntimeError: If ``engine.lora_config`` is ``None``.
    """
    if engine.lora_config is None:
        raise RuntimeError("Cannot save LoRA adapter: LoRA not enabled on engine")

    rank = dist.get_rank() if dist.is_initialized() else 0
    is_writer = rank == 0

    if is_writer:
        os.makedirs(path, exist_ok=True)
        logger.info(f"Saving LoRA adapter to {path}")

    # Collect the adapter parameters exposed by the model.
    adapter_params = get_adapter_params(engine.model)
    if not adapter_params:
        logger.warning("No adapter parameters found in model")
        if is_writer:
            logger.warning("Creating empty adapter checkpoint")

    # Detached CPU copies go through the model-specific Archon -> HF conversion.
    archon_state = {}
    for name, tensor in adapter_params.items():
        archon_state[name] = tensor.cpu().detach().clone()
    hf_state = engine.state_dict_adapter.to_hf(archon_state)

    # PEFT expects each key under the "base_model.model." namespace.
    peft_state = {}
    for hf_name, tensor in hf_state.items():
        peft_state[f"base_model.model.{hf_name}"] = tensor

    if is_writer:
        weights_path = os.path.join(path, "adapter_model.safetensors")
        save_file(peft_state, weights_path)
        logger.info(f"Saved {len(peft_state)} adapter tensors to {weights_path}")

        # A target module is the name segment that directly precedes the
        # first "lora_a"/"lora_b" segment of a parameter name (a leading
        # "lora_a"/"lora_b" has no predecessor and is skipped).
        lora_leaves = ("lora_a", "lora_b")
        target_modules = set()
        for name in adapter_params:
            segments = name.split(".")
            owner = next(
                (
                    prev
                    for prev, cur in zip(segments, segments[1:])
                    if cur in lora_leaves
                ),
                None,
            )
            if owner is not None:
                target_modules.add(owner)

        # Copy of the engine config that records the modules actually found.
        from dataclasses import replace

        lora_config_for_save = replace(
            engine.lora_config, target_modules=sorted(target_modules)
        )

        # The model-specific state dict adapter builds the PEFT config dict.
        adapter_config = engine.state_dict_adapter.create_peft_adapter_config(
            lora_config=lora_config_for_save,
            base_model_path=base_model_path,
        )

        config_path = os.path.join(path, "adapter_config.json")
        with open(config_path, "w") as f:
            json.dump(adapter_config, f, indent=2)
        logger.info(f"Saved adapter config to {config_path}")

    # Keep all ranks in step before returning.
    if dist.is_initialized():
        dist.barrier()


def _read_peft_adapter_weights(path: str):
    """Read the adapter tensors stored under ``path``.

    Prefers ``adapter_model.safetensors`` and falls back to
    ``adapter_model.bin``. Returns ``(state_dict, weights_path)``.
    """
    weights_path = os.path.join(path, "adapter_model.safetensors")
    if os.path.exists(weights_path):
        return load_file(weights_path), weights_path

    # Fallback to .bin format
    weights_path = os.path.join(path, "adapter_model.bin")
    if not os.path.exists(weights_path):
        raise FileNotFoundError(
            f"Adapter weights not found at {path}. "
            "Expected adapter_model.safetensors or adapter_model.bin"
        )
    # weights_only=True restricts unpickling to tensor data.
    state = torch.load(weights_path, map_location="cpu", weights_only=True)
    return state, weights_path


def load_lora_adapter(
    engine: "ArchonEngine",
    path: str,
    strict: bool = True,
) -> None:
    """Load LoRA adapter weights from a PEFT-format checkpoint into the model.

    Every rank reads the checkpoint, strips the ``base_model.model.`` key
    prefix, converts the keys through ``engine.state_dict_adapter.from_hf``
    and copies each matching tensor into the corresponding adapter
    parameter. All tensor shapes are validated before the first copy so a
    bad checkpoint cannot leave the model partially overwritten.

    Args:
        engine: Engine whose ``lora_config`` is set and whose ``model``
            carries the adapter parameters.
        path: Directory containing ``adapter_model.safetensors`` or
            ``adapter_model.bin``.
        strict: If True, raise on missing/unexpected keys; otherwise only
            log a warning on rank 0 and load the keys that match.

    Raises:
        RuntimeError: If LoRA is not enabled on the engine, or if a
            checkpoint tensor's shape differs from its parameter's shape.
        FileNotFoundError: If no adapter weights file is found.
        ValueError: If ``strict`` is True and the key sets differ.
    """
    if engine.lora_config is None:
        raise RuntimeError("Cannot load LoRA adapter: LoRA not enabled on engine")

    rank = dist.get_rank() if dist.is_initialized() else 0
    is_main = rank == 0

    if is_main:
        logger.info(f"Loading LoRA adapter from {path}")

    peft_state, weights_path = _read_peft_adapter_weights(path)

    if is_main:
        logger.info(f"Loaded {len(peft_state)} adapter tensors from {weights_path}")

    # Strip PEFT prefix: base_model.model.{key} -> {key}
    prefix = "base_model.model."
    hf_state = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in peft_state.items()
    }

    # Convert from HF format to Archon format
    archon_state = engine.state_dict_adapter.from_hf(hf_state)

    # Compare the checkpoint's keys with the adapter parameters of the model.
    expected_adapter_params = get_adapter_params(engine.model)
    expected_keys = set(expected_adapter_params.keys())
    loaded_keys = set(archon_state.keys())

    missing_keys = expected_keys - loaded_keys
    unexpected_keys = loaded_keys - expected_keys

    if strict and (missing_keys or unexpected_keys):
        error_msg = []
        if missing_keys:
            error_msg.append(f"Missing keys: {sorted(missing_keys)[:5]}...")
        if unexpected_keys:
            error_msg.append(f"Unexpected keys: {sorted(unexpected_keys)[:5]}...")
        raise ValueError(
            "Adapter checkpoint keys don't match model. " + " ".join(error_msg)
        )

    if is_main:
        if missing_keys:
            logger.warning(
                f"Missing {len(missing_keys)} adapter keys: "
                f"{sorted(missing_keys)[:5]}..."
            )
        if unexpected_keys:
            logger.warning(
                f"Unexpected {len(unexpected_keys)} adapter keys: "
                f"{sorted(unexpected_keys)[:5]}..."
            )

    matched = [
        (key, value, expected_adapter_params[key])
        for key, value in archon_state.items()
        if key in expected_adapter_params
    ]

    # Tensor.copy_ broadcasts its source, so a checkpoint tensor with a
    # different but broadcastable shape would be accepted silently and
    # corrupt the adapter. Reject any mismatch before copying anything.
    for key, value, param in matched:
        if tuple(value.shape) != tuple(param.shape):
            raise RuntimeError(
                f"Shape mismatch for adapter parameter '{key}': checkpoint has "
                f"{tuple(value.shape)}, model expects {tuple(param.shape)}"
            )

    # Load adapter weights into model
    for _key, value, param in matched:
        param.data.copy_(value.to(device=param.device, dtype=param.dtype))
    loaded_count = len(matched)

    if is_main:
        logger.info(f"Loaded {loaded_count} adapter parameters into model")

    if dist.is_initialized():
        dist.barrier()


def is_lora_adapter_checkpoint(path: str) -> bool:
    """Check if path contains a PEFT LoRA adapter checkpoint.

    The check never raises for a missing, unreadable or malformed
    ``adapter_config.json``; any such condition simply yields ``False``.

    Args:
        path: Directory path to check

    Returns:
        True if ``path`` contains an ``adapter_config.json`` whose top-level
        JSON object has ``peft_type == "LORA"``; False otherwise.
    """
    config_path = Path(path) / "adapter_config.json"

    # Open directly instead of testing existence first: a missing file
    # raises FileNotFoundError (an OSError), so one handler covers both the
    # "absent" and the "unreadable" cases without a check-then-open race.
    try:
        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)
    except (OSError, ValueError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        return False

    # Valid JSON need not be an object (it may be a list, string, number...);
    # only a mapping can carry the "peft_type" field.
    return isinstance(config, dict) and config.get("peft_type") == "LORA"
40 changes: 40 additions & 0 deletions areal/experimental/models/archon/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ def __init__(
self.to_hf_map: dict[str, str] = {}
self.hf_assets_path = hf_assets_path
self.fqn_to_index_mapping = None
# Model-specific mapping from Archon module names to PEFT module names
# Subclasses should define this mapping for LoRA adapter config generation
self.to_peft_module_map: dict[str, str] = {}

if hf_assets_path:
self._load_safetensors_index(hf_assets_path)
Expand Down Expand Up @@ -115,6 +118,43 @@ def convert_single_to_hf(
self, name: str, tensor: torch.Tensor
) -> list[tuple[str, torch.Tensor]]: ...

def create_peft_adapter_config(
self, lora_config: Any, base_model_path: str | None = None
) -> dict:
"""Generate adapter_config.json for PEFT format checkpoint.

Args:
lora_config: LoRA configuration object with rank, alpha, target_modules
base_model_path: Optional path to base model

Returns:
Dictionary containing PEFT adapter configuration
"""
# Convert Archon module names to PEFT names using model-specific mapping
peft_target_modules = [
self.to_peft_module_map.get(name, name)
for name in lora_config.target_modules
]

return {
"peft_type": "LORA",
"auto_mapping": None,
"base_model_name_or_path": base_model_path or "",
"revision": None,
"task_type": "CAUSAL_LM",
"inference_mode": False,
"r": lora_config.rank,
"lora_alpha": int(lora_config.alpha),
"lora_dropout": 0.0,
"target_modules": peft_target_modules,
"fan_in_fan_out": False,
"bias": "none",
"modules_to_save": None,
"init_lora_weights": True,
"layers_to_transform": None,
"layers_pattern": None,
}


class BaseArchonModel(nn.Module, ABC):
"""Base class for Archon models."""
Expand Down
26 changes: 26 additions & 0 deletions areal/experimental/models/archon/lora/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""LoRA (Low-Rank Adaptation) infrastructure for Archon engine.

Following torchtune's design patterns for FSDP2 compatibility.
This module provides custom LoRALinear implementation and utilities
for parameter-efficient fine-tuning of large language models.
"""

from areal.experimental.models.archon.lora.adapter import (
AdapterModule,
disable_adapter,
enable_adapter,
get_adapter_params,
get_adapter_state_dict,
set_trainable_params,
)
from areal.experimental.models.archon.lora.lora_linear import LoRALinear

__all__ = [
"LoRALinear",
"AdapterModule",
"get_adapter_params",
"get_adapter_state_dict",
"set_trainable_params",
"disable_adapter",
"enable_adapter",
]
Loading
Loading