Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions areal/engine/fsdp_utils/grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,26 @@ def get_grad_norm_fp32(
norm_type = float(norm_type)
total_norm = 0.0

if not grads_for_norm:
return 0.0

device = current_platform.current_device()

if not grads_for_norm:
# Still participate in all_reduce with zero contribution so that
# ranks with grads don't hang waiting for this rank.
total_norm_cuda = torch.tensor(0.0, dtype=torch.float, device=device)
reduce_op = (
dist.ReduceOp.MAX if norm_type == torch.inf else dist.ReduceOp.SUM
)
if data_parallel_group:
dist.all_reduce(total_norm_cuda, op=reduce_op, group=data_parallel_group)
if model_parallel_group is not None:
dist.all_reduce(
total_norm_cuda, op=reduce_op, group=model_parallel_group
)
total_norm = float(total_norm_cuda.item())
if norm_type != torch.inf and total_norm > 0:
total_norm = total_norm ** (1.0 / norm_type)
return total_norm

if norm_type == torch.inf:
norms = [grad.abs().max() for grad in grads_for_norm]
total_norm = torch.max(torch.stack(norms)) if norms else 0.0
Expand Down
16 changes: 15 additions & 1 deletion areal/experimental/engine/archon_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,19 @@ def load_model_from_hf(engine: ArchonEngine, path: str) -> None:
# Convert to HF format to match checkpoint keys
hf_state_dict = engine.state_dict_adapter.to_hf(state_dict)

# LoRA adapter parameters don't exist in the base HF checkpoint.
# Strip them before calling dcp.load() so it won't raise on missing keys.
lora_archon_keys: set[str] = set()
if engine.lora_config is not None:
lora_hf_keys = {
k for k in hf_state_dict if ".lora_A." in k or ".lora_B." in k
}
lora_archon_keys = {
k for k in state_dict if ".lora_a." in k or ".lora_b." in k
}
for k in lora_hf_keys:
del hf_state_dict[k]

# PP mode + weight tying fix: last stage needs embed_tokens weight for output layer
# When tie_word_embeddings=True, HF checkpoint only stores embed_tokens.weight,
# not lm_head.weight. In PP mode, last stage has output.weight but no tok_embeddings,
Expand Down Expand Up @@ -377,10 +390,11 @@ def load_model_from_hf(engine: ArchonEngine, path: str) -> None:
# Filter known expected missing keys
expected_missing = set()
for key in list(missing_keys):
# rotary_emb is computed at runtime, not stored in checkpoint
if "rotary_emb" in key:
expected_missing.add(key)
missing_keys -= expected_missing
# LoRA adapter keys are initialised separately, not loaded from base ckpt
missing_keys -= lora_archon_keys

if dist.get_rank() == 0:
if missing_keys:
Expand Down
165 changes: 162 additions & 3 deletions areal/experimental/engine/archon_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,25 @@ def __init__(self, config: TrainEngineConfig):
self._initialized = False
self.is_offload = False

# LoRA Configuration (extract from config if enabled)
self.lora_config = None
if hasattr(config, "use_lora") and config.use_lora:
from dataclasses import dataclass

@dataclass
class LoRAConfig:
enabled: bool
rank: int
alpha: float
target_modules: list[str]

self.lora_config = LoRAConfig(
enabled=True,
rank=config.lora_rank,
alpha=float(config.lora_alpha),
target_modules=config.target_modules if config.target_modules else [],
)

def create_process_group(
self,
parallel_strategy: ParallelStrategy | None = None,
Expand Down Expand Up @@ -319,6 +338,8 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs):
)

self._materialize_and_load_weights()
if self.lora_config is not None:
self._freeze_non_lora_params()
self._create_optimizer(ft_spec)

self.runner = create_runner(
Expand Down Expand Up @@ -486,6 +507,16 @@ def process_output(

self.forward_backward_batch(mb_list, process_output, forward_only=False)

if self.lora_config is not None:
from areal.experimental.models.archon.lora.lora_linear import (
sync_lora_grads,
)
sync_lora_grads(
self.model,
tp_group=self._tp_group,
dp_group=self.data_parallel_group,
)

return self.optimizer_step()

@torch.no_grad()
Expand Down Expand Up @@ -659,7 +690,19 @@ def update_weights(self, meta: WeightUpdateMeta):
)

def save(self, meta: SaveLoadMeta):
"""Save model in HuggingFace or DCP format."""
"""Save model in HuggingFace or DCP format.

When LoRA is enabled, only the adapter weights are saved in PEFT format.
When LoRA is disabled, the full model is saved.
"""
if self.lora_config is not None:
from areal.experimental.engine.archon_lora_checkpoint import (
save_lora_adapter,
)

save_lora_adapter(self, meta.path, meta.base_model_path)
return

if meta.weight_format == "hf":
save_model_to_hf(self, meta.path, meta.tokenizer, meta.processor)
elif meta.weight_format == "dcp":
Expand All @@ -671,7 +714,20 @@ def save(self, meta: SaveLoadMeta):
save_optimizer_state(self, meta.path)

def load(self, meta: SaveLoadMeta):
"""Load model from HuggingFace or DCP format."""
"""Load model from HuggingFace or DCP format.

When LoRA is enabled and the checkpoint is a PEFT adapter,
only adapter weights are loaded.
"""
from areal.experimental.engine.archon_lora_checkpoint import (
is_lora_adapter_checkpoint,
load_lora_adapter,
)

if self.lora_config is not None and is_lora_adapter_checkpoint(meta.path):
load_lora_adapter(self, meta.path)
return

if meta.weight_format == "hf":
load_model_from_hf(self, meta.path)
elif meta.weight_format == "dcp":
Expand Down Expand Up @@ -790,6 +846,7 @@ def _apply_pipeline_parallelism(
reshard_after_forward_policy=self.config.archon.reshard_after_forward_policy,
ac_config=ac_config,
enable_compile=enable_compile,
apply_lora_fn=self._apply_lora if self.lora_config is not None else None,
)

# Delete original model to free memory
Expand Down Expand Up @@ -828,6 +885,7 @@ def _apply_parallelism(
reshard_after_forward_policy=self.config.archon.reshard_after_forward_policy,
ac_config=ac_config,
enable_compile=enable_compile,
apply_lora_fn=self._apply_lora if self.lora_config is not None else None,
)
self.model_parts = [self.model]

Expand Down Expand Up @@ -929,8 +987,109 @@ def _create_state_dict_adapter(self) -> BaseStateDictAdapter | None:
self.model_config, hf_assets_path=self.config.path
)

def _apply_lora(self, module: nn.Module | None = None) -> None:
    """Swap selected ``nn.Linear`` children for ``LoRALinear`` wrappers in place.

    The module tree is walked recursively. A linear child is wrapped when
    ``"all-linear"`` is listed in ``lora_config.target_modules``, or when
    either the child's attribute name or its alias from the state-dict
    adapter's ``to_peft_module_map`` is listed there.

    Args:
        module: Root of the tree to rewrite. Falls back to ``self.model``
            when ``None``.
    """
    from areal.experimental.models.archon.lora import (
        LoRALinear,
        get_adapter_params,
    )

    assert self.lora_config is not None
    if module is None:
        module = self.model

    lora_cfg = self.lora_config
    targets = set(lora_cfg.target_modules)
    wrap_every_linear = "all-linear" in targets
    adapter = self.state_dict_adapter
    # Without a state-dict adapter there are no PEFT aliases to match on.
    peft_aliases = {} if adapter is None else adapter.to_peft_module_map
    wrapped_paths: list[str] = []

    def _wrap_children(parent: nn.Module, path: str = "") -> None:
        # Snapshot the children first: setattr below mutates the parent.
        for name, child in list(parent.named_children()):
            qualified = name if not path else f"{path}.{name}"

            if not isinstance(child, nn.Linear):
                _wrap_children(child, qualified)
                continue

            alias = peft_aliases.get(name)
            selected = wrap_every_linear or name in targets or alias in targets
            if not selected:
                continue

            replacement = LoRALinear.from_linear(
                child,
                rank=lora_cfg.rank,
                alpha=lora_cfg.alpha,
            )
            replacement._debug_name = qualified
            setattr(parent, name, replacement)
            wrapped_paths.append(qualified)

    _wrap_children(module)

    adapter_params = get_adapter_params(module)

    if not wrapped_paths:
        return
    self.logger.info(
        f"Applied LoRA to {len(wrapped_paths)} linear modules and created "
        f"{len(adapter_params)} adapter parameters"
    )

def _freeze_non_lora_params(self) -> None:
    """Materialise and re-initialise LoRA adapters, leaving only them trainable.

    For every entry of ``self.model_parts`` this moves each ``LoRALinear``'s
    adapter tensors to ``self.device``, resets them, and passes the adapter
    names to ``set_trainable_params``. Parts without adapter parameters are
    left untouched.

    Raises:
        RuntimeError: If no model part exposes any adapter parameter.
    """
    from areal.experimental.models.archon.lora import (
        LoRALinear,
        get_adapter_params,
        set_trainable_params,
    )

    total_adapters = 0
    for part in self.model_parts:
        # LoRA weights are plain tensors created on the meta device while the
        # model structure is built. FSDP2 only materialises nn.Parameters, so
        # the LoRA tensors have to be moved explicitly here.
        for layer in (m for m in part.modules() if isinstance(m, LoRALinear)):
            layer.materialize_lora(self.device)

        adapters = get_adapter_params(part)
        if adapters:
            # lora_b is zeroed and lora_a gets a Kaiming-uniform init, so the
            # initial LoRA contribution is exactly zero.
            with torch.no_grad():
                for key, weight in adapters.items():
                    if "lora_b" in key:
                        nn.init.zeros_(weight)
                    elif "lora_a" in key:
                        nn.init.kaiming_uniform_(weight, a=math.sqrt(5))

            total_adapters += len(adapters)
            set_trainable_params(part, set(adapters.keys()))

    if total_adapters == 0:
        raise RuntimeError(
            "LoRA is enabled but no adapter parameters were found after weight loading."
        )

    self.logger.info(
        f"Froze base weights and kept {total_adapters} adapter parameters trainable"
    )

def _get_all_parameters(self) -> list[nn.Parameter]:
    """Return the parameters of every model part, plus LoRA adapter tensors.

    The stale unconditional ``return`` that preceded the LoRA branch has been
    removed; it made the adapter collection below unreachable.

    Returns:
        All ``parameters()`` of each entry in ``self.model_parts``. When
        ``self.lora_config`` is set, the result is extended with
        ``lora_parameters()`` of every ``LoRALinear`` module.
    """
    params = [p for m in self.model_parts for p in m.parameters()]
    if self.lora_config is None:
        return params

    # NOTE(review): presumably the adapter tensors are not registered as
    # nn.Parameters, so ``parameters()`` does not yield them and they must
    # be gathered explicitly -- confirm against LoRALinear.
    from areal.experimental.models.archon.lora import LoRALinear

    for m in self.model_parts:
        for module in m.modules():
            if isinstance(module, LoRALinear):
                params.extend(module.lora_parameters())
    return params

def _get_model_name_parameters(self) -> Iterator[tuple[str, nn.Parameter]]:
for m in self.model_parts:
Expand Down
Loading
Loading