Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 46 additions & 2 deletions areal/experimental/engine/archon_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,25 @@ def __init__(self, config: TrainEngineConfig):
self._initialized = False
self.is_offload = False

# LoRA Configuration (extract from config if enabled)
self.lora_config = None
if hasattr(config, "use_lora") and config.use_lora:
from dataclasses import dataclass

@dataclass
class LoRAConfig:
enabled: bool
rank: int
alpha: float
target_modules: list[str]

self.lora_config = LoRAConfig(
enabled=True,
rank=config.lora_rank,
alpha=float(config.lora_alpha),
target_modules=config.target_modules if config.target_modules else [],
)
Comment on lines +210 to +212
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

It's good to see the alpha value being converted to float to ensure type consistency. However, float(config.lora_alpha) raises a ValueError (or TypeError) when the value cannot be converted. Consider guarding the conversion, either with a try-except block or, as in the suggestion below, with an isinstance check. This will make the code more robust.

                rank=config.lora_rank,
                alpha=float(config.lora_alpha) if isinstance(config.lora_alpha, (int, float)) else config.lora_alpha,
                target_modules=config.target_modules if config.target_modules else [],

Comment on lines +197 to +212
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Defining the LoRAConfig dataclass inside the __init__ method body is an anti-pattern. This creates a new class object on every ArchonEngine instantiation, makes the class non-importable/non-reusable, and complicates type annotations elsewhere. The dataclass should be defined at module level or in a separate config module. Additionally, the enabled field is redundant since self.lora_config being non-None already indicates LoRA is enabled.

Copilot uses AI. Check for mistakes.

def create_process_group(
self,
parallel_strategy: ParallelStrategy | None = None,
Expand Down Expand Up @@ -661,7 +680,19 @@ def update_weights(self, meta: WeightUpdateMeta):
)

def save(self, meta: SaveLoadMeta):
"""Save model in HuggingFace or DCP format."""
"""Save model in HuggingFace or DCP format.

When LoRA is enabled, only the adapter weights are saved in PEFT format.
When LoRA is disabled, the full model is saved.
"""
if self.lora_config is not None:
from areal.experimental.engine.archon_lora_checkpoint import (
save_lora_adapter,
)

save_lora_adapter(self, meta.path, meta.base_model_path)
return
Comment on lines +688 to +694
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When lora_config is not None, the save method unconditionally saves only the LoRA adapter and returns early, ignoring meta.weight_format and meta.with_optim. This means optimizer state is never saved during LoRA training, making it impossible to resume training from a checkpoint. Consider also saving optimizer state when meta.with_optim is True, similar to the non-LoRA code path.

Copilot uses AI. Check for mistakes.

if meta.weight_format == "hf":
save_model_to_hf(self, meta.path, meta.tokenizer, meta.processor)
elif meta.weight_format == "dcp":
Expand All @@ -673,7 +704,20 @@ def save(self, meta: SaveLoadMeta):
save_optimizer_state(self, meta.path)

def load(self, meta: SaveLoadMeta):
"""Load model from HuggingFace or DCP format."""
"""Load model from HuggingFace or DCP format.

When LoRA is enabled and the checkpoint is a PEFT adapter,
only adapter weights are loaded.
"""
from areal.experimental.engine.archon_lora_checkpoint import (
is_lora_adapter_checkpoint,
load_lora_adapter,
)

if self.lora_config is not None and is_lora_adapter_checkpoint(meta.path):
load_lora_adapter(self, meta.path)
return
Comment on lines +712 to +719
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The LoRA checkpoint imports are done unconditionally on every load() call, even when LoRA is not enabled. In save(), the import is correctly guarded inside if self.lora_config is not None. For consistency and to avoid unnecessary imports, move the import inside the if block, similar to save():

if self.lora_config is not None:
    from areal.experimental.engine.archon_lora_checkpoint import (
        is_lora_adapter_checkpoint,
        load_lora_adapter,
    )
    if is_lora_adapter_checkpoint(meta.path):
        load_lora_adapter(self, meta.path)
        return
Suggested change
from areal.experimental.engine.archon_lora_checkpoint import (
is_lora_adapter_checkpoint,
load_lora_adapter,
)
if self.lora_config is not None and is_lora_adapter_checkpoint(meta.path):
load_lora_adapter(self, meta.path)
return
if self.lora_config is not None:
from areal.experimental.engine.archon_lora_checkpoint import (
is_lora_adapter_checkpoint,
load_lora_adapter,
)
if is_lora_adapter_checkpoint(meta.path):
load_lora_adapter(self, meta.path)
return

Copilot uses AI. Check for mistakes.

if meta.weight_format == "hf":
load_model_from_hf(self, meta.path)
elif meta.weight_format == "dcp":
Expand Down
241 changes: 241 additions & 0 deletions areal/experimental/engine/archon_lora_checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
"""LoRA adapter checkpoint I/O in PEFT format.

This module provides functions to save and load LoRA adapters in PEFT-compatible
format for HuggingFace ecosystem interoperability.

PEFT checkpoint structure:
adapter_checkpoint/
├── adapter_model.safetensors # LoRA weights only
└── adapter_config.json # PEFT configuration

Reference: peft/src/peft/utils/save_and_load.py
"""

from __future__ import annotations

import json
import os
from pathlib import Path
from typing import TYPE_CHECKING

import torch
import torch.distributed as dist
from safetensors.torch import load_file, save_file

from areal.experimental.models.archon.lora.adapter import get_adapter_params
from areal.utils import logging

if TYPE_CHECKING:
from areal.experimental.engine.archon_engine import ArchonEngine

logger = logging.getLogger("LoRACheckpoint")


def save_lora_adapter(
engine: "ArchonEngine",
path: str,
base_model_path: str | None = None,
) -> None:
"""Save LoRA adapter in PEFT format.

Creates two files:
- adapter_model.safetensors: LoRA weights (lora_a, lora_b)
- adapter_config.json: PEFT configuration

Args:
engine: ArchonEngine instance with LoRA-enabled model
path: Directory path to save adapter checkpoint
base_model_path: Optional path to base model (for config reference)

Raises:
RuntimeError: If LoRA is not enabled on engine
"""
if engine.lora_config is None:
raise RuntimeError("Cannot save LoRA adapter: LoRA not enabled on engine")
Comment on lines +53 to +54
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This RuntimeError is appropriate for preventing the saving of LoRA adapters when LoRA is not enabled. Consider adding a more descriptive error message to provide better guidance to the user, such as suggesting they enable LoRA in the engine configuration.

        raise RuntimeError("Cannot save LoRA adapter: LoRA not enabled on engine. Please ensure 'use_lora' is set to True in the training configuration.")


if dist.is_initialized():
rank = dist.get_rank()
else:
rank = 0

if rank == 0:
os.makedirs(path, exist_ok=True)
logger.info(f"Saving LoRA adapter to {path}")

# Extract adapter parameters from model
adapter_params = get_adapter_params(engine.model)

if not adapter_params:
logger.warning("No adapter parameters found in model")
if rank == 0:
logger.warning("Creating empty adapter checkpoint")

# Convert to HF format using state dict adapter
archon_state = {k: v.cpu().detach().clone() for k, v in adapter_params.items()}
hf_state = engine.state_dict_adapter.to_hf(archon_state)

# Add PEFT prefix: base_model.model.{key}
peft_state = {f"base_model.model.{k}": v for k, v in hf_state.items()}
Comment on lines +66 to +78
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Under FSDP2, parameters returned by get_adapter_params(engine.model) will be DTensor (sharded across ranks). Calling .cpu().detach().clone() on a DTensor only captures the local shard, not the full tensor. This means each rank would save different fragments of the LoRA weights, and the rank-0 checkpoint would be incomplete/incorrect.

The existing checkpoint code in archon_checkpoint.py uses get_model_state_dict() with StateDictOptions to properly gather/unshard tensors before saving. The LoRA save path should do the same — e.g., use get_model_state_dict then filter to adapter keys, or use full_tensor() on DTensors before saving.

Copilot uses AI. Check for mistakes.

# Save weights (only rank 0)
if rank == 0:
weights_path = os.path.join(path, "adapter_model.safetensors")
save_file(peft_state, weights_path)
logger.info(f"Saved {len(peft_state)} adapter tensors to {weights_path}")

# Determine target modules from actual adapter parameters
target_modules = set()
for key in adapter_params:
parts = key.split(".")
for i, part in enumerate(parts):
if part in ("lora_a", "lora_b") and i > 0:
module_name = parts[i - 1]
target_modules.add(module_name)
break

# Create config copy with actual target modules
from dataclasses import replace

lora_config_for_save = replace(
engine.lora_config, target_modules=sorted(target_modules)
)

# Generate adapter config using model-specific state dict adapter
adapter_config = engine.state_dict_adapter.create_peft_adapter_config(
lora_config=lora_config_for_save,
base_model_path=base_model_path,
)

config_path = os.path.join(path, "adapter_config.json")
with open(config_path, "w") as f:
json.dump(adapter_config, f, indent=2)
logger.info(f"Saved adapter config to {config_path}")

# Synchronize all ranks
if dist.is_initialized():
dist.barrier()
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The barrier calls here use the default process group (dist.barrier()), but the existing checkpoint code in archon_checkpoint.py consistently uses dist.barrier(group=engine.cpu_group) to synchronize. Using different groups for barriers can cause hangs or incorrect synchronization in multi-group setups. Consider passing engine.cpu_group to dist.barrier() for consistency with the existing checkpoint code.

Suggested change
dist.barrier()
dist.barrier(group=engine.cpu_group)

Copilot uses AI. Check for mistakes.


def load_lora_adapter(
engine: "ArchonEngine",
path: str,
strict: bool = True,
) -> None:
"""Load LoRA adapter from PEFT format checkpoint.

Args:
engine: ArchonEngine instance with LoRA-enabled model
path: Directory path containing adapter checkpoint
strict: If True, raise error on missing/unexpected keys

Raises:
RuntimeError: If LoRA is not enabled on engine
FileNotFoundError: If adapter checkpoint files not found
ValueError: If strict=True and keys don't match
"""
if engine.lora_config is None:
raise RuntimeError("Cannot load LoRA adapter: LoRA not enabled on engine")
Comment on lines +136 to +137
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Similar to the save_lora_adapter function, this RuntimeError is appropriate for preventing the loading of LoRA adapters when LoRA is not enabled. Consider adding a more descriptive error message to provide better guidance to the user, such as suggesting they enable LoRA in the engine configuration.

        raise RuntimeError("Cannot load LoRA adapter: LoRA not enabled on engine. Please ensure 'use_lora' is set to True in the training configuration.")


if dist.is_initialized():
rank = dist.get_rank()
else:
rank = 0

if rank == 0:
logger.info(f"Loading LoRA adapter from {path}")

# Load adapter weights
weights_path = os.path.join(path, "adapter_model.safetensors")
if not os.path.exists(weights_path):
# Fallback to .bin format
weights_path = os.path.join(path, "adapter_model.bin")
if not os.path.exists(weights_path):
raise FileNotFoundError(
f"Adapter weights not found at {path}. "
"Expected adapter_model.safetensors or adapter_model.bin"
)
Comment on lines +152 to +156
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This FileNotFoundError is appropriate for handling missing adapter weights. The message currently reports only the checkpoint directory (path); consider reporting the full file path that was checked (weights_path) to help the user quickly identify the missing file.

            raise FileNotFoundError(
                f"Adapter weights not found at {weights_path}. "
                "Expected adapter_model.safetensors or adapter_model.bin"
            )

peft_state = torch.load(weights_path, map_location="cpu", weights_only=True)
else:
peft_state = load_file(weights_path)

if rank == 0:
logger.info(f"Loaded {len(peft_state)} adapter tensors from {weights_path}")

# Strip PEFT prefix: base_model.model.{key} -> {key}
hf_state = {}
for key, value in peft_state.items():
if key.startswith("base_model.model."):
hf_key = key.replace("base_model.model.", "", 1)
hf_state[hf_key] = value
else:
hf_state[key] = value

# Convert from HF format to Archon format
archon_state = engine.state_dict_adapter.from_hf(hf_state)

# Get expected adapter keys from model
expected_adapter_params = get_adapter_params(engine.model)
expected_keys = set(expected_adapter_params.keys())
loaded_keys = set(archon_state.keys())

missing_keys = expected_keys - loaded_keys
unexpected_keys = loaded_keys - expected_keys

if missing_keys or unexpected_keys:
if strict:
error_msg = []
if missing_keys:
error_msg.append(f"Missing keys: {sorted(missing_keys)[:5]}...")
if unexpected_keys:
error_msg.append(f"Unexpected keys: {sorted(unexpected_keys)[:5]}...")
raise ValueError(
"Adapter checkpoint keys don't match model. " + " ".join(error_msg)
)
Comment on lines +191 to +193
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This ValueError is appropriate for handling key mismatches. Consider including the path in the error message to help the user quickly identify the incorrect checkpoint.

            raise ValueError(
                f"Adapter checkpoint keys don't match model at {path}. " + " ".join(error_msg)
            )

else:
if missing_keys and rank == 0:
logger.warning(
f"Missing {len(missing_keys)} adapter keys: "
f"{sorted(missing_keys)[:5]}..."
)
if unexpected_keys and rank == 0:
logger.warning(
f"Unexpected {len(unexpected_keys)} adapter keys: "
f"{sorted(unexpected_keys)[:5]}..."
)

# Load adapter weights into model
loaded_count = 0
for key, value in archon_state.items():
if key in expected_adapter_params:
param = expected_adapter_params[key]
value = value.to(device=param.device, dtype=param.dtype)
param.data.copy_(value)
loaded_count += 1
Comment on lines +206 to +213
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to the save path issue: under FSDP2, expected_adapter_params[key] will be a DTensor (sharded). Calling param.data.copy_(value) directly on a DTensor won't properly distribute the loaded weights across the FSDP2 mesh. The existing load code in archon_checkpoint.py uses set_model_state_dict() to handle this correctly. The LoRA load path should use an equivalent mechanism to ensure proper shard distribution.

Copilot uses AI. Check for mistakes.

if rank == 0:
logger.info(f"Loaded {loaded_count} adapter parameters into model")

if dist.is_initialized():
dist.barrier()
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same issue as the save path: dist.barrier() should use engine.cpu_group to match the convention in archon_checkpoint.py (e.g., lines 437, 450, and 463).

Suggested change
dist.barrier()
dist.barrier(group=engine.cpu_group)

Copilot uses AI. Check for mistakes.


def is_lora_adapter_checkpoint(path: str) -> bool:
    """Check if path contains a PEFT LoRA adapter checkpoint.

    The check never raises for a bad checkpoint directory: a missing,
    unreadable, non-UTF-8, malformed, or non-object ``adapter_config.json``
    is reported as "not a LoRA adapter checkpoint".

    Args:
        path: Directory path to check

    Returns:
        True if ``path`` contains an ``adapter_config.json`` whose top-level
        JSON object has ``peft_type == "LORA"``, False otherwise.
    """
    config_path = Path(path) / "adapter_config.json"

    # EAFP: a missing file surfaces as FileNotFoundError (an OSError), so no
    # separate exists() probe is needed. ValueError covers both
    # json.JSONDecodeError and UnicodeDecodeError.
    try:
        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)
    except (OSError, ValueError):
        return False

    # Valid JSON is not necessarily an object (it may be a list, string, ...);
    # only a mapping can carry the "peft_type" marker.
    return isinstance(config, dict) and config.get("peft_type") == "LORA"
40 changes: 40 additions & 0 deletions areal/experimental/models/archon/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ def __init__(
self.to_hf_map: dict[str, str] = {}
self.hf_assets_path = hf_assets_path
self.fqn_to_index_mapping = None
# Model-specific mapping from Archon module names to PEFT module names
# Subclasses should define this mapping for LoRA adapter config generation
self.to_peft_module_map: dict[str, str] = {}

if hf_assets_path:
self._load_safetensors_index(hf_assets_path)
Expand Down Expand Up @@ -115,6 +118,43 @@ def convert_single_to_hf(
self, name: str, tensor: torch.Tensor
) -> list[tuple[str, torch.Tensor]]: ...

def create_peft_adapter_config(
self, lora_config: Any, base_model_path: str | None = None
) -> dict:
"""Generate adapter_config.json for PEFT format checkpoint.

Args:
lora_config: LoRA configuration object with rank, alpha, target_modules
base_model_path: Optional path to base model

Returns:
Dictionary containing PEFT adapter configuration
"""
# Convert Archon module names to PEFT names using model-specific mapping
peft_target_modules = [
self.to_peft_module_map.get(name, name)
for name in lora_config.target_modules
]

return {
"peft_type": "LORA",
"auto_mapping": None,
"base_model_name_or_path": base_model_path or "",
"revision": None,
"task_type": "CAUSAL_LM",
"inference_mode": False,
"r": lora_config.rank,
"lora_alpha": int(lora_config.alpha),
"lora_dropout": 0.0,
"target_modules": peft_target_modules,
"fan_in_fan_out": False,
"bias": "none",
"modules_to_save": None,
"init_lora_weights": True,
"layers_to_transform": None,
"layers_pattern": None,
}


class BaseArchonModel(nn.Module, ABC):
"""Base class for Archon models."""
Expand Down
26 changes: 26 additions & 0 deletions areal/experimental/models/archon/lora/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""LoRA (Low-Rank Adaptation) infrastructure for Archon engine.

Following torchtune's design patterns for FSDP2 compatibility.
This module provides custom LoRALinear implementation and utilities
for parameter-efficient fine-tuning of large language models.
"""

from areal.experimental.models.archon.lora.adapter import (
AdapterModule,
disable_adapter,
enable_adapter,
get_adapter_params,
get_adapter_state_dict,
set_trainable_params,
)
from areal.experimental.models.archon.lora.lora_linear import LoRALinear

__all__ = [
"LoRALinear",
"AdapterModule",
"get_adapter_params",
"get_adapter_state_dict",
"set_trainable_params",
"disable_adapter",
"enable_adapter",
]
Loading
Loading