diff --git a/tests/unit_tests/tools/__init__.py b/tests/unit_tests/tools/__init__.py new file mode 100644 index 00000000000..b5dff7b5663 --- /dev/null +++ b/tests/unit_tests/tools/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. diff --git a/tests/unit_tests/tools/checkpoint/__init__.py b/tests/unit_tests/tools/checkpoint/__init__.py new file mode 100644 index 00000000000..b5dff7b5663 --- /dev/null +++ b/tests/unit_tests/tools/checkpoint/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. diff --git a/tests/unit_tests/tools/checkpoint/test_distributed_round_trip.py b/tests/unit_tests/tools/checkpoint/test_distributed_round_trip.py new file mode 100644 index 00000000000..4e080f8f0b3 --- /dev/null +++ b/tests/unit_tests/tools/checkpoint/test_distributed_round_trip.py @@ -0,0 +1,328 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +""" +Multi-rank distributed round-trip test for gpt_hybrid_conversion. + +Each rank participates in a multi-rank DCP save of a synthetic GPT (or MoE +GPT) state dict; rank 0 then runs the converter and verifies the GPT->Hybrid-> +GPT round-trip exactly. + +This test is meant to be launched under SLURM/srun (or torchrun) with +WORLD_SIZE = TP * PP * FSDP * EP. The (tp, pp, fsdp, ep) values are passed +as flags purely as labels — the converter sees only the DCP-stored +``global_shape`` per tensor and is agnostic to *which* dimension(s) the +source was sharded along. The test value is in: + + 1. Coordinating a real multi-rank ``dcp.save`` (cross-node networking, + collective barriers, shared-filesystem write). + 2. Verifying the converter loads a multi-rank-written checkpoint and + round-trips it through both directions. + +Usage (under srun): + export RANK=$SLURM_PROCID + export LOCAL_RANK=$SLURM_LOCALID + export WORLD_SIZE=$SLURM_NTASKS + export MASTER_ADDR=$(scontrol show hostnames $SLURM_NODELIST | head -1) + export MASTER_PORT=29500 + python test_distributed_round_trip.py \\ + --tp 2 --pp 2 --fsdp 2 --ep 2 --label TP2-PP2-FSDP2-EP2 \\ + --output-root /lustre/.../scratch/dist_test +""" + +import argparse +import copy +import os +import shutil +import sys +import time +from collections import OrderedDict +from types import SimpleNamespace + +import torch +import torch.distributed as dist + +# Make the conversion tool and helpers importable. +_THIS_DIR = os.path.dirname(os.path.abspath(__file__)) +_REPO_ROOT = os.path.join(_THIS_DIR, '..', '..', '..', '..') +sys.path.insert(0, os.path.join(_REPO_ROOT, 'tools', 'checkpoint')) +sys.path.insert(0, _THIS_DIR) + + +def _log(msg, rank, label=""): + prefix = f"[rank={rank}{(' ' + label) if label else ''}]" + print(f"{prefix} {msg}", flush=True) + + +def _build_state_dict(num_layers, hidden_size, num_moe_experts, vocab_size, dtype): + """Build a deterministic GPT(MoE) state dict identical on every rank. + + Determinism (via fixed seed) lets every rank produce the same tensors so + DCP's de-duplication writes a single coherent checkpoint. After load, we + re-derive the same tensors on rank 0 to verify round-trip. + """ + torch.manual_seed(0xC0FFEE) + sd = OrderedDict() + sd['embedding.word_embeddings.weight'] = torch.randn(vocab_size, hidden_size, dtype=dtype) + + for i in range(num_layers): + p = f'decoder.layers.{i}.' 
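+        # The keys below mirror Megatron's GPT naming: attention weights under
+        # self_attention.linear_{qkv,proj}, then either a dense MLP
+        # (mlp.linear_fc{1,2}) or a router plus per-expert linear_fc{1,2} weights.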
+ sd[p + 'input_layernorm.weight'] = torch.randn(hidden_size, dtype=dtype) + sd[p + 'self_attention.linear_qkv.weight'] = torch.randn( + 3 * hidden_size, hidden_size, dtype=dtype + ) + sd[p + 'self_attention.linear_proj.weight'] = torch.randn( + hidden_size, hidden_size, dtype=dtype + ) + sd[p + 'pre_mlp_layernorm.weight'] = torch.randn(hidden_size, dtype=dtype) + + if num_moe_experts is None: + sd[p + 'mlp.linear_fc1.weight'] = torch.randn(4 * hidden_size, hidden_size, dtype=dtype) + sd[p + 'mlp.linear_fc2.weight'] = torch.randn(hidden_size, 4 * hidden_size, dtype=dtype) + else: + sd[p + 'mlp.router.weight'] = torch.randn(num_moe_experts, hidden_size, dtype=dtype) + for j in range(num_moe_experts): + ep = p + f'mlp.experts.local_experts.{j}.' + sd[ep + 'linear_fc1.weight'] = torch.randn( + 4 * hidden_size, hidden_size, dtype=dtype + ) + sd[ep + 'linear_fc2.weight'] = torch.randn( + hidden_size, 4 * hidden_size, dtype=dtype + ) + + sd['decoder.final_layernorm.weight'] = torch.randn(hidden_size, dtype=dtype) + sd['output_layer.weight'] = torch.randn(vocab_size, hidden_size, dtype=dtype) + return sd + + +def _build_ckpt_args(num_layers, hidden_size, num_moe_experts): + return SimpleNamespace( + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=4, + ffn_hidden_size=hidden_size * 4, + seq_length=256, + max_position_embeddings=256, + iteration=100, + consumed_train_samples=0, + consumed_valid_samples=0, + train_iters=1000, + train_samples=0, + tokenizer_type='GPT2BPETokenizer', + position_embedding_type='rope', + params_dtype=torch.float32, + fp16=False, + bf16=False, + num_moe_experts=num_moe_experts, + moe_shared_expert_intermediate_size=None, + moe_layer_freq=1, + ) + + +def _init_process_group(init_file): + """Initialize via file:// rendezvous on a shared filesystem. + + RANK / WORLD_SIZE come from env (set by srun). MASTER_ADDR / MASTER_PORT + are not needed — every rank just opens the same shared file. This avoids + the SLURM CLI tools (e.g. scontrol) which are not always present inside + container images. + + The init file MUST NOT pre-exist; rank 0 cleans up any stale leftover. + """ + if dist.is_initialized(): + return + rank = int(os.environ['RANK']) + world_size = int(os.environ['WORLD_SIZE']) + if rank == 0 and os.path.exists(init_file): + os.remove(init_file) + # Brief settle so other ranks don't race ahead of the cleanup. + time.sleep(1) + dist.init_process_group( + backend='gloo', # CPU-only synthetic save; no NCCL needed + init_method=f'file://{init_file}', + rank=rank, + world_size=world_size, + ) + + +def _verify_round_trip(original, recovered, label): + missing, mismatch = [], [] + for k, v in original.items(): + if k not in recovered: + missing.append(k) + continue + if not torch.equal(v, recovered[k].to(v.dtype)): + mismatch.append(k) + + leaked_ssm = [k for k in recovered if 'mixer.' 
in k] + + if missing or mismatch or leaked_ssm: + print(f"FAIL [{label}]:") + for k in missing[:5]: + print(f" MISSING: {k}") + for k in mismatch[:5]: + print(f" MISMATCH: {k}") + for k in leaked_ssm[:5]: + print(f" SSM LEAKED: {k}") + raise AssertionError( + f"{label}: {len(missing)} missing, {len(mismatch)} mismatched, " + f"{len(leaked_ssm)} SSM keys leaked" + ) + + print(f"PASS [{label}]: {len(original)} keys round-tripped exactly") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--tp', type=int, default=1) + parser.add_argument('--pp', type=int, default=1) + parser.add_argument('--fsdp', type=int, default=1) + parser.add_argument('--ep', type=int, default=1) + parser.add_argument('--label', type=str, required=True) + parser.add_argument( + '--output-root', + type=str, + required=True, + help='Shared-filesystem path where this scenario writes its ' + 'checkpoints (must be visible from every rank).', + ) + parser.add_argument('--num-layers', type=int, default=3) + parser.add_argument('--hidden-size', type=int, default=64) + parser.add_argument('--vocab-size', type=int, default=512) + parser.add_argument( + '--num-moe-experts', + type=int, + default=None, + help='If set, use the MoE GPT state-dict layout ' + '(mlp.router + mlp.experts.local_experts.*).', + ) + parser.add_argument( + '--pattern', + type=str, + default=None, + help='Hybrid layer pattern. Defaults to "M*-M*-M*-" for ' + 'dense or "M*EM*EM*E" when --num-moe-experts is set.', + ) + args = parser.parse_args() + + expected_world = args.tp * args.pp * args.fsdp * args.ep + pattern = args.pattern or ('M*EM*EM*E' if args.num_moe_experts is not None else 'M*-M*-M*-') + + # Shared init file lives on the same shared FS we use for checkpoints, so + # all ranks on all nodes see the same path. + init_file = os.path.join(args.output_root, f'_pg_init_{args.label}') + if int(os.environ.get('RANK', 0)) == 0: + os.makedirs(args.output_root, exist_ok=True) + _init_process_group(init_file) + rank = dist.get_rank() + world = dist.get_world_size() + if world != expected_world: + if rank == 0: + print(f"FAIL [{args.label}]: world={world} but tp*pp*fsdp*ep={expected_world}") + sys.exit(2) + + # Lazy imports after sys.path is set. + from dist_checkpoint_io import ( + load_dist_checkpoint_full, + save_dist_checkpoint_full, + write_latest_iteration_marker, + ) + from gpt_hybrid_conversion import main as conversion_main + + if rank == 0: + _log( + f"label={args.label} tp={args.tp} pp={args.pp} fsdp={args.fsdp} " + f"ep={args.ep} world={world} pattern={pattern} " + f"num_moe_experts={args.num_moe_experts}", + rank, + args.label, + ) + + # Each rank builds the same full state dict — DCP de-duplicates writes + # across ranks via its writer planner. 
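+    # (The fixed seed in _build_state_dict, torch.manual_seed(0xC0FFEE), is what
+    # makes every rank's tensors byte-identical, so the de-duplication is safe.)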
+ state_dict = _build_state_dict( + args.num_layers, args.hidden_size, args.num_moe_experts, args.vocab_size, torch.float32 + ) + ckpt_args = _build_ckpt_args(args.num_layers, args.hidden_size, args.num_moe_experts) + + scratch = os.path.join(args.output_root, args.label) + src_dir = os.path.join(scratch, 'gpt_src') + mid_dir = os.path.join(scratch, 'hybrid_mid') + dst_dir = os.path.join(scratch, 'gpt_dst') + iter_subdir = os.path.join(src_dir, 'iter_0000100') + + if rank == 0: + if os.path.isdir(scratch): + shutil.rmtree(scratch, ignore_errors=True) + os.makedirs(iter_subdir, exist_ok=True) + dist.barrier() + + # --- Multi-rank DCP write of the source GPT checkpoint --- + # dcp.save / dcp.load are both COLLECTIVE in the active process group, so + # every rank in this PG must participate in every save and every load. + # That includes the two conversion_main calls below, which internally call + # load_dist_checkpoint_full + save_dist_checkpoint_full once each. + # If a rank exits early its gloo socket closes and rank 0's reduce_scatter + # dies with "Connection closed by peer". + common_state = {'args': copy.deepcopy(ckpt_args), 'checkpoint_version': 3.0, 'iteration': 100} + save_dist_checkpoint_full( + state_dict, common_state, iter_subdir, model_prefix='model.', backend='torch_dist' + ) + if rank == 0: + write_latest_iteration_marker(iter_subdir, 100) + dist.barrier() + + # --- Convert GPT -> hybrid -> GPT (every rank participates collectively). + common_kwargs = dict( + hybrid_layer_pattern=pattern, + d_model=args.hidden_size, + mamba_version=2, + mamba_d_state=16, + mamba2_n_groups=2, + mamba2_head_dim=32, + d_conv=4, + init_method_std=0.02, + reset_iterations=False, + input_format='auto', + output_format='torch_dist', + ) + + # Silence non-rank-0 stdout to keep logs readable; collective behavior + # is unaffected. + if rank != 0: + sys.stdout = open(os.devnull, 'w') + + t0 = time.time() + conversion_main( + argparse.Namespace( + direction='gpt-to-hybrid', load_dir=src_dir, save_dir=mid_dir, **common_kwargs + ) + ) + dist.barrier() + conversion_main( + argparse.Namespace( + direction='hybrid-to-gpt', load_dir=mid_dir, save_dir=dst_dir, **common_kwargs + ) + ) + dist.barrier() + dt = time.time() - t0 + + # Restore stdout for rank 0's verify message. + if rank != 0: + sys.stdout = sys.__stdout__ + + # --- Load final + (rank 0 only) verify --- + recovered, _, _, _, _ = load_dist_checkpoint_full(dst_dir) + dist.barrier() + + if rank == 0: + _log(f"conversion time: {dt:.2f}s", rank, args.label) + _verify_round_trip(state_dict, recovered, args.label) + shutil.rmtree(scratch, ignore_errors=True) + if os.path.exists(init_file): + os.remove(init_file) + dist.barrier() + dist.destroy_process_group() + + +if __name__ == '__main__': + main() diff --git a/tests/unit_tests/tools/checkpoint/test_gpt_hybrid_conversion.py b/tests/unit_tests/tools/checkpoint/test_gpt_hybrid_conversion.py new file mode 100644 index 00000000000..79ec3e5ab36 --- /dev/null +++ b/tests/unit_tests/tools/checkpoint/test_gpt_hybrid_conversion.py @@ -0,0 +1,809 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +""" +Unit tests for the GPT <-> Hybrid checkpoint conversion tool. 
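+The conversion functions are exercised directly on small in-memory state dicts,
+so no distributed setup or checkpoint I/O is needed here.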
+ +These tests validate: +- Hybrid layer pattern parsing +- Layer index mapping (GPT <-> Hybrid) +- State dict key renaming (final_layernorm <-> final_norm) +- Shared parameter copying (embeddings, output_layer) +- SSM parameter initialization shapes and dtypes +- Round-trip conversion: GPT -> Hybrid -> GPT preserves attention and MLP weights +- TP split dimension lookup +""" + +import argparse +import math +import os +import sys +import tempfile +from collections import OrderedDict + +import pytest +import torch + +# Add the tools/checkpoint directory to the path so we can import the module +sys.path.insert( + 0, os.path.join(os.path.dirname(__file__), '..', '..', '..', '..', 'tools', 'checkpoint') +) + +from gpt_hybrid_conversion import ( + build_layer_index_mapping, + convert_gpt_to_hybrid, + convert_hybrid_to_gpt, + get_layer_num_from_key, + initialize_ssm_layer_params, + is_attention_param, + is_mlp_param, + is_ssm_param, + parse_hybrid_layer_pattern, + replace_layer_num, + validate_pattern_gpt_compatible, + validate_source_args_gpt_compatible, +) + +# --------------------------------------------------------------------------- +# Pattern parsing tests +# --------------------------------------------------------------------------- + + +class TestPatternParsing: + def test_simple_pattern(self): + result = parse_hybrid_layer_pattern("M*-M*-") + assert result == ['M', '*', '-', 'M', '*', '-'] + + def test_all_mamba(self): + result = parse_hybrid_layer_pattern("MMMM") + assert result == ['M', 'M', 'M', 'M'] + + def test_all_attention(self): + result = parse_hybrid_layer_pattern("****") + assert result == ['*', '*', '*', '*'] + + def test_with_mtp_separator(self): + # Should strip MTP patterns (only main pattern) + result = parse_hybrid_layer_pattern("M*-M*-/MM/MM") + assert result == ['M', '*', '-', 'M', '*', '-'] + + def test_with_pipe_separator(self): + # Should strip pipeline stage separators + result = parse_hybrid_layer_pattern("M*-|M*-") + assert result == ['M', '*', '-', 'M', '*', '-'] + + def test_with_both_separators(self): + result = parse_hybrid_layer_pattern("M*-|M*-/MM/MM") + assert result == ['M', '*', '-', 'M', '*', '-'] + + def test_mixed_layers(self): + result = parse_hybrid_layer_pattern("M*-EG") + assert result == ['M', '*', '-', 'E', 'G'] + + def test_invalid_symbol(self): + with pytest.raises(ValueError, match="Invalid layer symbol"): + parse_hybrid_layer_pattern("M*X") + + +# --------------------------------------------------------------------------- +# Layer index mapping tests +# --------------------------------------------------------------------------- + + +class TestLayerIndexMapping: + def test_gpt_to_hybrid_basic(self): + # Pattern: M*-M*- (2 attn at pos 1,4; 2 MLP at pos 2,5) + layer_types = ['M', '*', '-', 'M', '*', '-'] + attn_map, mlp_map, ssm_indices = build_layer_index_mapping(layer_types, 'gpt-to-hybrid') + # 2 GPT layers -> attn at [1,4], MLP at [2,5] + assert attn_map == {0: 1, 1: 4} + assert mlp_map == {0: 2, 1: 5} + assert ssm_indices == [0, 3] + + def test_hybrid_to_gpt_basic(self): + layer_types = ['M', '*', '-', 'M', '*', '-'] + attn_map, mlp_map, ssm_indices = build_layer_index_mapping(layer_types, 'hybrid-to-gpt') + # attn at mamba layer 1 -> GPT layer 0, attn at 4 -> GPT layer 1 + assert attn_map == {1: 0, 4: 1} + assert mlp_map == {2: 0, 5: 1} + assert ssm_indices == [0, 3] + + def test_alternating_pattern(self): + layer_types = ['*', '-', '*', '-', '*', '-'] + attn_map, mlp_map, ssm_indices = build_layer_index_mapping(layer_types, 
'gpt-to-hybrid') + assert attn_map == {0: 0, 1: 2, 2: 4} + assert mlp_map == {0: 1, 1: 3, 2: 5} + assert ssm_indices == [] + + def test_mismatched_attn_mlp_count(self): + # 2 attn but 1 MLP -> should raise + layer_types = ['*', '*', '-', 'M'] + with pytest.raises(ValueError, match="must equal"): + build_layer_index_mapping(layer_types, 'gpt-to-hybrid') + + def test_unknown_direction(self): + with pytest.raises(ValueError, match="Unknown direction"): + build_layer_index_mapping(['*', '-'], 'invalid') + + +# --------------------------------------------------------------------------- +# Key helper tests +# --------------------------------------------------------------------------- + + +class TestKeyHelpers: + def test_get_layer_num(self): + assert get_layer_num_from_key('decoder.layers.5.mlp.linear_fc1.weight') == 5 + assert get_layer_num_from_key('decoder.layers.0.self_attention.linear_qkv.weight') == 0 + assert get_layer_num_from_key('decoder.layers.99.mixer.A_log') == 99 + assert get_layer_num_from_key('embedding.word_embeddings.weight') is None + + def test_replace_layer_num(self): + key = 'decoder.layers.3.mlp.linear_fc1.weight' + assert replace_layer_num(key, 3, 7) == 'decoder.layers.7.mlp.linear_fc1.weight' + + def test_is_attention_param(self): + assert is_attention_param('decoder.layers.0.self_attention.linear_qkv.weight') + assert is_attention_param('decoder.layers.0.input_layernorm.weight') + assert not is_attention_param('decoder.layers.0.mlp.linear_fc1.weight') + assert not is_attention_param('decoder.layers.0.mixer.A_log') + + def test_is_mlp_param(self): + assert is_mlp_param('decoder.layers.0.mlp.linear_fc1.weight') + assert is_mlp_param('decoder.layers.0.pre_mlp_layernorm.weight') + assert not is_mlp_param('decoder.layers.0.self_attention.linear_qkv.weight') + + def test_is_ssm_param(self): + assert is_ssm_param('decoder.layers.0.mixer.A_log') + assert is_ssm_param('decoder.layers.0.mixer.in_proj.weight') + assert is_ssm_param('decoder.layers.0.mixer.conv1d.weight') + assert is_ssm_param('decoder.layers.0.mixer.D') + assert is_ssm_param('decoder.layers.0.mixer.dt_bias') + assert is_ssm_param('decoder.layers.0.mixer.norm.weight') + assert is_ssm_param('decoder.layers.0.mixer.out_proj.weight') + assert not is_ssm_param('decoder.layers.0.mlp.linear_fc1.weight') + assert not is_ssm_param('decoder.layers.0.self_attention.linear_qkv.weight') + + +# --------------------------------------------------------------------------- +# SSM initialization tests +# --------------------------------------------------------------------------- + + +class TestSSMInitialization: + def test_shapes(self): + d_model = 256 + d_inner = 512 # 2 * d_model + d_state = 64 + n_groups = 4 + head_dim = 32 + n_heads = d_inner // head_dim # 16 + d_conv = 4 + conv_dim = d_inner + 2 * n_groups * d_state + + params = initialize_ssm_layer_params( + layer_idx=0, + d_model=d_model, + mamba_d_inner=d_inner, + mamba_d_state=d_state, + mamba2_n_groups=n_groups, + mamba2_n_heads=n_heads, + mamba_head_dim=head_dim, + d_conv=d_conv, + dtype=torch.float32, + ) + + prefix = 'decoder.layers.0.mixer.' 
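+        # Worked numbers for this configuration: in_proj rows =
+        # 2*512 + 2*4*64 + 16 = 1552, and conv_dim = 512 + 2*4*64 = 1024.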
+ + # in_proj: [2*d_inner + 2*n_groups*d_state + n_heads, d_model] + in_proj_out = 2 * d_inner + 2 * n_groups * d_state + n_heads + assert params[prefix + 'in_proj.weight'].shape == (in_proj_out, d_model) + + # in_proj layer norm weight + assert params[prefix + 'in_proj.layer_norm_weight'].shape == (d_model,) + + # conv1d: [conv_dim, 1, d_conv] + assert params[prefix + 'conv1d.weight'].shape == (conv_dim, 1, d_conv) + assert params[prefix + 'conv1d.bias'].shape == (conv_dim,) + + # A_log: [n_heads] + assert params[prefix + 'A_log'].shape == (n_heads,) + assert params[prefix + 'A_log'].dtype == torch.float32 + + # D: [n_heads] + assert params[prefix + 'D'].shape == (n_heads,) + assert params[prefix + 'D'].dtype == torch.float32 + + # dt_bias: [n_heads] + assert params[prefix + 'dt_bias'].shape == (n_heads,) + + # norm: [d_inner] + assert params[prefix + 'norm.weight'].shape == (d_inner,) + + # out_proj: [d_model, d_inner] + assert params[prefix + 'out_proj.weight'].shape == (d_model, d_inner) + + def test_A_log_values(self): + params = initialize_ssm_layer_params( + layer_idx=0, + d_model=64, + mamba_d_inner=128, + mamba_d_state=16, + mamba2_n_groups=2, + mamba2_n_heads=4, + mamba_head_dim=32, + ) + A_log = params['decoder.layers.0.mixer.A_log'] + # A was uniform in (1, 16), so A_log should be in (log(1), log(16)) = (0, 2.77) + assert (A_log >= 0).all() + assert (A_log <= math.log(16) + 0.01).all() + + def test_D_values(self): + params = initialize_ssm_layer_params( + layer_idx=0, + d_model=64, + mamba_d_inner=128, + mamba_d_state=16, + mamba2_n_groups=2, + mamba2_n_heads=4, + mamba_head_dim=32, + ) + D = params['decoder.layers.0.mixer.D'] + assert torch.allclose(D, torch.ones_like(D)) + + def test_conv1d_bias_zeros(self): + params = initialize_ssm_layer_params( + layer_idx=0, + d_model=64, + mamba_d_inner=128, + mamba_d_state=16, + mamba2_n_groups=2, + mamba2_n_heads=4, + mamba_head_dim=32, + ) + bias = params['decoder.layers.0.mixer.conv1d.bias'] + assert torch.allclose(bias, torch.zeros_like(bias)) + + def test_norm_weight_ones(self): + params = initialize_ssm_layer_params( + layer_idx=0, + d_model=64, + mamba_d_inner=128, + mamba_d_state=16, + mamba2_n_groups=2, + mamba2_n_heads=4, + mamba_head_dim=32, + ) + norm = params['decoder.layers.0.mixer.norm.weight'] + assert torch.allclose(norm, torch.ones_like(norm)) + + def test_layer_norm_weight_ones(self): + params = initialize_ssm_layer_params( + layer_idx=0, + d_model=64, + mamba_d_inner=128, + mamba_d_state=16, + mamba2_n_groups=2, + mamba2_n_heads=4, + mamba_head_dim=32, + ) + ln = params['decoder.layers.0.mixer.in_proj.layer_norm_weight'] + assert torch.allclose(ln, torch.ones_like(ln)) + + def test_different_layer_idx(self): + params = initialize_ssm_layer_params( + layer_idx=7, + d_model=64, + mamba_d_inner=128, + mamba_d_state=16, + mamba2_n_groups=2, + mamba2_n_heads=4, + mamba_head_dim=32, + ) + assert 'decoder.layers.7.mixer.A_log' in params + assert 'decoder.layers.0.mixer.A_log' not in params + + +# --------------------------------------------------------------------------- +# Synthetic GPT checkpoint builder +# --------------------------------------------------------------------------- + + +def make_synthetic_gpt_checkpoint(num_layers, d_model, dtype=torch.float32): + """Create a minimal synthetic GPT state dict for testing.""" + state_dict = OrderedDict() + + # Embeddings + state_dict['embedding.word_embeddings.weight'] = torch.randn(1000, d_model, dtype=dtype) + + # Transformer layers + for i in range(num_layers): + prefix 
= f'decoder.layers.{i}.' + # Attention + state_dict[prefix + 'input_layernorm.weight'] = torch.randn(d_model, dtype=dtype) + state_dict[prefix + 'self_attention.linear_qkv.weight'] = torch.randn( + 3 * d_model, d_model, dtype=dtype + ) + state_dict[prefix + 'self_attention.linear_proj.weight'] = torch.randn( + d_model, d_model, dtype=dtype + ) + # MLP + state_dict[prefix + 'pre_mlp_layernorm.weight'] = torch.randn(d_model, dtype=dtype) + state_dict[prefix + 'mlp.linear_fc1.weight'] = torch.randn( + 4 * d_model, d_model, dtype=dtype + ) + state_dict[prefix + 'mlp.linear_fc2.weight'] = torch.randn( + d_model, 4 * d_model, dtype=dtype + ) + + # Final norm + state_dict['decoder.final_layernorm.weight'] = torch.randn(d_model, dtype=dtype) + + # Output layer + state_dict['output_layer.weight'] = torch.randn(1000, d_model, dtype=dtype) + + return state_dict + + +# --------------------------------------------------------------------------- +# Full conversion tests +# --------------------------------------------------------------------------- + + +class TestGPTToHybridConversion: + def setup_method(self): + self.d_model = 64 + self.num_gpt_layers = 2 + self.pattern = "M*-M*-" # 6 total: 2 SSM, 2 attn, 2 MLP + self.gpt_state = make_synthetic_gpt_checkpoint(self.num_gpt_layers, self.d_model) + self.args = argparse.Namespace( + d_model=self.d_model, + mamba_d_inner=self.d_model * 2, + mamba_d_state=16, + mamba2_n_groups=2, + mamba2_n_heads=(self.d_model * 2) // 32, + mamba2_head_dim=32, + mamba_version=2, + d_conv=4, + init_method_std=0.02, + ) + + def test_shared_params_preserved(self): + layer_types = parse_hybrid_layer_pattern(self.pattern) + result = convert_gpt_to_hybrid(self.gpt_state, layer_types, self.args) + + # Embeddings should be identical + assert torch.equal( + result['embedding.word_embeddings.weight'], + self.gpt_state['embedding.word_embeddings.weight'], + ) + # Output layer + assert torch.equal(result['output_layer.weight'], self.gpt_state['output_layer.weight']) + + def test_final_norm_renamed(self): + layer_types = parse_hybrid_layer_pattern(self.pattern) + result = convert_gpt_to_hybrid(self.gpt_state, layer_types, self.args) + + assert 'decoder.final_norm.weight' in result + assert 'decoder.final_layernorm.weight' not in result + assert torch.equal( + result['decoder.final_norm.weight'], self.gpt_state['decoder.final_layernorm.weight'] + ) + + def test_attention_params_mapped(self): + layer_types = parse_hybrid_layer_pattern(self.pattern) + result = convert_gpt_to_hybrid(self.gpt_state, layer_types, self.args) + + # GPT layer 0 attn -> Mamba layer 1 (first '*' in M*-M*-) + assert torch.equal( + result['decoder.layers.1.self_attention.linear_qkv.weight'], + self.gpt_state['decoder.layers.0.self_attention.linear_qkv.weight'], + ) + # GPT layer 1 attn -> Mamba layer 4 (second '*') + assert torch.equal( + result['decoder.layers.4.self_attention.linear_qkv.weight'], + self.gpt_state['decoder.layers.1.self_attention.linear_qkv.weight'], + ) + + def test_mlp_params_mapped(self): + layer_types = parse_hybrid_layer_pattern(self.pattern) + result = convert_gpt_to_hybrid(self.gpt_state, layer_types, self.args) + + # GPT layer 0 MLP -> Mamba layer 2 (first '-') + assert torch.equal( + result['decoder.layers.2.mlp.linear_fc1.weight'], + self.gpt_state['decoder.layers.0.mlp.linear_fc1.weight'], + ) + # GPT layer 1 MLP -> Mamba layer 5 (second '-') + assert torch.equal( + result['decoder.layers.5.mlp.linear_fc2.weight'], + self.gpt_state['decoder.layers.1.mlp.linear_fc2.weight'], + ) + + def 
test_ssm_layers_initialized(self): + layer_types = parse_hybrid_layer_pattern(self.pattern) + result = convert_gpt_to_hybrid(self.gpt_state, layer_types, self.args) + + # SSM layers at index 0 and 3 + for idx in [0, 3]: + prefix = f'decoder.layers.{idx}.mixer.' + assert prefix + 'A_log' in result + assert prefix + 'D' in result + assert prefix + 'dt_bias' in result + assert prefix + 'conv1d.weight' in result + assert prefix + 'conv1d.bias' in result + assert prefix + 'in_proj.weight' in result + assert prefix + 'norm.weight' in result + assert prefix + 'out_proj.weight' in result + + def test_layer_count_mismatch_raises(self): + # Pattern with 3 attn but only 2 GPT layers + layer_types = parse_hybrid_layer_pattern("M*-*-*-") + with pytest.raises(ValueError, match="layers"): + convert_gpt_to_hybrid(self.gpt_state, layer_types, self.args) + + +class TestHybridToGPTConversion: + def setup_method(self): + self.d_model = 64 + self.pattern = "M*-M*-" + self.args = argparse.Namespace( + d_model=self.d_model, + mamba_d_inner=self.d_model * 2, + mamba_d_state=16, + mamba2_n_groups=2, + mamba2_n_heads=(self.d_model * 2) // 32, + mamba2_head_dim=32, + mamba_version=2, + d_conv=4, + init_method_std=0.02, + ) + + def _make_mamba_state(self): + """Build a synthetic Mamba state dict matching pattern M*-M*-.""" + state_dict = OrderedDict() + state_dict['embedding.word_embeddings.weight'] = torch.randn(1000, self.d_model) + state_dict['output_layer.weight'] = torch.randn(1000, self.d_model) + state_dict['decoder.final_norm.weight'] = torch.randn(self.d_model) + + layer_types = parse_hybrid_layer_pattern(self.pattern) + d_inner = self.d_model * 2 + n_heads = self.args.mamba2_n_heads + n_groups = self.args.mamba2_n_groups + d_state = self.args.mamba_d_state + + for i, lt in enumerate(layer_types): + prefix = f'decoder.layers.{i}.' + if lt == 'M': + # SSM params + ssm = initialize_ssm_layer_params( + i, self.d_model, d_inner, d_state, n_groups, n_heads, self.args.mamba2_head_dim + ) + state_dict.update(ssm) + elif lt == '*': + state_dict[prefix + 'input_layernorm.weight'] = torch.randn(self.d_model) + state_dict[prefix + 'self_attention.linear_qkv.weight'] = torch.randn( + 3 * self.d_model, self.d_model + ) + state_dict[prefix + 'self_attention.linear_proj.weight'] = torch.randn( + self.d_model, self.d_model + ) + elif lt == '-': + state_dict[prefix + 'pre_mlp_layernorm.weight'] = torch.randn(self.d_model) + state_dict[prefix + 'mlp.linear_fc1.weight'] = torch.randn( + 4 * self.d_model, self.d_model + ) + state_dict[prefix + 'mlp.linear_fc2.weight'] = torch.randn( + self.d_model, 4 * self.d_model + ) + + return state_dict + + def test_final_norm_renamed_back(self): + mamba_state = self._make_mamba_state() + layer_types = parse_hybrid_layer_pattern(self.pattern) + result = convert_hybrid_to_gpt(mamba_state, layer_types, self.args) + + assert 'decoder.final_layernorm.weight' in result + assert 'decoder.final_norm.weight' not in result + + def test_ssm_params_discarded(self): + mamba_state = self._make_mamba_state() + layer_types = parse_hybrid_layer_pattern(self.pattern) + result = convert_hybrid_to_gpt(mamba_state, layer_types, self.args) + + # No SSM keys should remain + for key in result: + assert 'mixer.' 
not in key, f"SSM key not discarded: {key}" + + def test_attention_params_mapped(self): + mamba_state = self._make_mamba_state() + layer_types = parse_hybrid_layer_pattern(self.pattern) + result = convert_hybrid_to_gpt(mamba_state, layer_types, self.args) + + # Mamba layer 1 (first *) -> GPT layer 0 + assert torch.equal( + result['decoder.layers.0.self_attention.linear_qkv.weight'], + mamba_state['decoder.layers.1.self_attention.linear_qkv.weight'], + ) + # Mamba layer 4 (second *) -> GPT layer 1 + assert torch.equal( + result['decoder.layers.1.self_attention.linear_qkv.weight'], + mamba_state['decoder.layers.4.self_attention.linear_qkv.weight'], + ) + + def test_mlp_params_mapped(self): + mamba_state = self._make_mamba_state() + layer_types = parse_hybrid_layer_pattern(self.pattern) + result = convert_hybrid_to_gpt(mamba_state, layer_types, self.args) + + # Mamba layer 2 (first -) -> GPT layer 0 + assert torch.equal( + result['decoder.layers.0.mlp.linear_fc1.weight'], + mamba_state['decoder.layers.2.mlp.linear_fc1.weight'], + ) + + def test_gpt_layer_count(self): + mamba_state = self._make_mamba_state() + layer_types = parse_hybrid_layer_pattern(self.pattern) + result = convert_hybrid_to_gpt(mamba_state, layer_types, self.args) + + # Should have 2 GPT layers (layers 0 and 1) + layer_nums = set() + for key in result: + lnum = get_layer_num_from_key(key) + if lnum is not None: + layer_nums.add(lnum) + assert layer_nums == {0, 1} + + +# --------------------------------------------------------------------------- +# Round-trip test: GPT -> Hybrid -> GPT; using Mamba as the example below +# --------------------------------------------------------------------------- + + +class TestRoundTrip: + def test_gpt_hybrid_gpt_preserves_weights(self): + """Converting GPT -> Hybrid -> GPT should preserve all attention & MLP weights.""" + d_model = 64 + num_layers = 2 + pattern = "M*-M*-" + + args = argparse.Namespace( + d_model=d_model, + mamba_d_inner=d_model * 2, + mamba_d_state=16, + mamba2_n_groups=2, + mamba2_n_heads=(d_model * 2) // 32, + mamba2_head_dim=32, + mamba_version=2, + d_conv=4, + init_method_std=0.02, + ) + + original_gpt = make_synthetic_gpt_checkpoint(num_layers, d_model) + layer_types = parse_hybrid_layer_pattern(pattern) + + # GPT -> Hybrid + mamba_state = convert_gpt_to_hybrid(original_gpt, layer_types, args) + + # Hybrid -> GPT + recovered_gpt = convert_hybrid_to_gpt(mamba_state, layer_types, args) + + # Check all original GPT keys are preserved + for key in original_gpt: + # final_layernorm is renamed in the round trip + if 'final_layernorm' in key: + continue + assert key in recovered_gpt, f"Missing key after round-trip: {key}" + assert torch.equal( + original_gpt[key], recovered_gpt[key] + ), f"Weight mismatch after round-trip for {key}" + + # Check final_layernorm was properly renamed back + assert torch.equal( + original_gpt['decoder.final_layernorm.weight'], + recovered_gpt['decoder.final_layernorm.weight'], + ) + + def test_round_trip_different_pattern(self): + """Test with a pattern that has more SSM layers.""" + d_model = 64 + num_layers = 3 + pattern = "M*-M*-M*-" + + args = argparse.Namespace( + d_model=d_model, + mamba_d_inner=d_model * 2, + mamba_d_state=16, + mamba2_n_groups=2, + mamba2_n_heads=(d_model * 2) // 32, + mamba2_head_dim=32, + mamba_version=2, + d_conv=4, + init_method_std=0.02, + ) + + original_gpt = make_synthetic_gpt_checkpoint(num_layers, d_model) + layer_types = parse_hybrid_layer_pattern(pattern) + + mamba_state = convert_gpt_to_hybrid(original_gpt, 
layer_types, args) + recovered_gpt = convert_hybrid_to_gpt(mamba_state, layer_types, args) + + for key in original_gpt: + if 'final_layernorm' in key: + continue + assert key in recovered_gpt, f"Missing key: {key}" + assert torch.equal(original_gpt[key], recovered_gpt[key]), f"Mismatch for {key}" + + +# --------------------------------------------------------------------------- +# GPT compatibility whitelist tests +# --------------------------------------------------------------------------- + + +class TestPatternWhitelist: + """validate_pattern_gpt_compatible rejects hybrid patterns GPTModel can't express.""" + + def test_accepts_mamba_attn_mlp(self): + # Standard hybrid with equal attn/MLP counts. + layer_types = parse_hybrid_layer_pattern("M*-M*-M*-") + validate_pattern_gpt_compatible(layer_types, 'gpt-to-hybrid') + + def test_accepts_pure_transformer_pattern(self): + layer_types = parse_hybrid_layer_pattern("*-*-*-") + validate_pattern_gpt_compatible(layer_types, 'hybrid-to-gpt') + + def test_accepts_pure_ssm_pattern(self): + # Pure-SSM models have no attention/MLP, so trivially GPT-compatible + # in the pattern sense (the GPT side would be empty). + layer_types = parse_hybrid_layer_pattern("MMMM") + validate_pattern_gpt_compatible(layer_types, 'gpt-to-hybrid') + + def test_accepts_moe_pattern(self): + # MoE layers ('E') round-trip through the converter as long as every + # MLP-bearing position is the same kind. + layer_types = parse_hybrid_layer_pattern("M*EM*EM*E") + validate_pattern_gpt_compatible(layer_types, 'gpt-to-hybrid') + + def test_accepts_pure_attn_moe_pattern(self): + # No SSM, alternating attn/MoE — i.e. a Mixtral-like GPT. + layer_types = parse_hybrid_layer_pattern("*E*E*E") + validate_pattern_gpt_compatible(layer_types, 'hybrid-to-gpt') + + def test_rejects_mixed_dense_and_moe(self): + # GPT layers must be uniform: '-' (dense) and 'E' (MoE) cannot both + # appear in the same pattern. + layer_types = parse_hybrid_layer_pattern("M*-M*E") + with pytest.raises(ValueError, match="uniform"): + validate_pattern_gpt_compatible(layer_types, 'gpt-to-hybrid') + + def test_rejects_gdn_symbol(self): + layer_types = parse_hybrid_layer_pattern("G*-*-") + with pytest.raises(ValueError, match="not GPT-compatible"): + validate_pattern_gpt_compatible(layer_types, 'gpt-to-hybrid') + + def test_rejects_unequal_attn_mlp(self): + layer_types = parse_hybrid_layer_pattern("M**-") # 2 attn, 1 MLP + with pytest.raises(ValueError, match="pair every attention"): + validate_pattern_gpt_compatible(layer_types, 'gpt-to-hybrid') + + def test_unequal_attn_moe_also_rejected(self): + # Same uniformity check, but with MoE — 2 attn, 1 MoE. + layer_types = parse_hybrid_layer_pattern("M**E") + with pytest.raises(ValueError, match="pair every attention"): + validate_pattern_gpt_compatible(layer_types, 'gpt-to-hybrid') + + def test_error_lists_offending_symbols(self): + # 'G' is still rejected; the error message should mention it. + layer_types = parse_hybrid_layer_pattern("M*-G") + with pytest.raises(ValueError) as exc: + validate_pattern_gpt_compatible(layer_types, 'hybrid-to-gpt') + assert 'G' in str(exc.value) + + +class TestSourceArgsWhitelist: + """validate_source_args_gpt_compatible rejects source checkpoints with + non-GPT-expressible features.""" + + def _ok_args(self, **overrides): + """Build a minimal args namespace that mimics a plain GPT/hybrid + training run. 
Any GPT-incompatible flags default to their + "off" value.""" + base = dict( + num_moe_experts=None, + moe_shared_expert_intermediate_size=None, + moe_layer_freq=1, + experimental_attention_variant=None, + linear_attention_freq=None, + heterogeneous_block_specs=False, + heterogeneous_layers_config_path=None, + heterogeneous_layers_config_encoded_json=None, + multi_latent_attention=False, + mtp_num_layers=None, + ) + base.update(overrides) + return argparse.Namespace(**base) + + def test_accepts_plain_gpt_args(self): + validate_source_args_gpt_compatible(self._ok_args(), 'gpt-to-hybrid') + + def test_none_args_is_noop(self): + # Dist checkpoints sometimes have no cached args blob. + validate_source_args_gpt_compatible(None, 'gpt-to-hybrid') + + def test_accepts_missing_optional_fields(self): + # Older checkpoints may not have every field; the validator should + # silently skip fields it doesn't find. + minimal = argparse.Namespace(num_moe_experts=None) + validate_source_args_gpt_compatible(minimal, 'hybrid-to-gpt') + + def test_accepts_moe_args(self): + # MoE keys live under decoder.layers..mlp.* and round-trip as-is. + validate_source_args_gpt_compatible(self._ok_args(num_moe_experts=8), 'gpt-to-hybrid') + + def test_accepts_shared_expert_args(self): + # Shared experts also live under mlp.shared_experts.* and round-trip. + validate_source_args_gpt_compatible( + self._ok_args(num_moe_experts=8, moe_shared_expert_intermediate_size=4096), + 'gpt-to-hybrid', + ) + + def test_rejects_moe_layer_freq_list(self): + # Heterogeneous interleaving (some dense, some MoE) breaks GPT uniformity. + with pytest.raises(ValueError, match="interleaved"): + validate_source_args_gpt_compatible( + self._ok_args(moe_layer_freq=[1, 0, 1, 0]), 'gpt-to-hybrid' + ) + + def test_accepts_moe_layer_freq_1(self): + validate_source_args_gpt_compatible(self._ok_args(moe_layer_freq=1), 'gpt-to-hybrid') + + def test_accepts_moe_layer_freq_all_ones_list(self): + # An all-1s list is uniform (every layer is the same kind) and accepted. + validate_source_args_gpt_compatible( + self._ok_args(moe_layer_freq=[1, 1, 1, 1]), 'gpt-to-hybrid' + ) + + def test_rejects_experimental_attention(self): + with pytest.raises(ValueError, match="experimental attention"): + validate_source_args_gpt_compatible( + self._ok_args(experimental_attention_variant='gated_delta_net'), 'gpt-to-hybrid' + ) + + def test_rejects_linear_attention(self): + with pytest.raises(ValueError, match="linear attention"): + validate_source_args_gpt_compatible( + self._ok_args(linear_attention_freq=4), 'gpt-to-hybrid' + ) + + def test_rejects_heterogeneous_block_specs(self): + with pytest.raises(ValueError, match="heterogeneous"): + validate_source_args_gpt_compatible( + self._ok_args(heterogeneous_block_specs=True), 'hybrid-to-gpt' + ) + + def test_rejects_heterogeneous_config_path(self): + with pytest.raises(ValueError, match="heterogeneous"): + validate_source_args_gpt_compatible( + self._ok_args(heterogeneous_layers_config_path='/tmp/x.json'), 'gpt-to-hybrid' + ) + + def test_rejects_mla(self): + with pytest.raises(ValueError, match="Multi-Latent"): + validate_source_args_gpt_compatible( + self._ok_args(multi_latent_attention=True), 'gpt-to-hybrid' + ) + + def test_rejects_mtp(self): + with pytest.raises(ValueError, match="Multi-Token Prediction"): + validate_source_args_gpt_compatible(self._ok_args(mtp_num_layers=2), 'gpt-to-hybrid') + + def test_reports_multiple_reasons(self): + # Both heterogeneous moe_layer_freq and MLA set — both should be reported. 
+ with pytest.raises(ValueError) as exc: + validate_source_args_gpt_compatible( + self._ok_args(moe_layer_freq=[1, 0], multi_latent_attention=True), 'gpt-to-hybrid' + ) + msg = str(exc.value) + assert 'interleaved' in msg + assert 'Multi-Latent' in msg diff --git a/tests/unit_tests/tools/checkpoint/test_gpt_hybrid_conversion_parallelism.py b/tests/unit_tests/tools/checkpoint/test_gpt_hybrid_conversion_parallelism.py new file mode 100644 index 00000000000..8102b7018ae --- /dev/null +++ b/tests/unit_tests/tools/checkpoint/test_gpt_hybrid_conversion_parallelism.py @@ -0,0 +1,418 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +""" +Parallelism-matrix integration tests for gpt_hybrid_conversion.py. + +The converter operates on dist (``torch_dist`` / ``fsdp_dtensor``) checkpoints +only — DCP's metadata stores each tensor's ``global_shape``, so the on-disk +TP / PP / FSDP layout is abstracted away from the conversion logic. We +synthesize a DCP checkpoint via a single-rank ``dcp.save`` and round-trip +GPT -> Hybrid -> GPT through the conversion CLI, asserting attention and MLP +weights match exactly. + +Each scenario is run as a distinct test to document the supported matrix and +catch regressions in dispatch logic. Designed to run on a single-GPU node via +SLURM (no torchrun needed). +""" + +import argparse +import copy +import os +import shutil +import sys +import tempfile +from collections import OrderedDict +from types import SimpleNamespace + +import pytest +import torch +import torch.distributed as dist + +# Make the conversion tool importable under both `python ` and `pytest`. +_THIS_DIR = os.path.dirname(os.path.abspath(__file__)) +_REPO_ROOT = os.path.join(_THIS_DIR, '..', '..', '..', '..') +sys.path.insert(0, os.path.join(_REPO_ROOT, 'tools', 'checkpoint')) +sys.path.insert(0, _THIS_DIR) + +from gpt_hybrid_conversion import main as conversion_main + + +# These scenarios are SYNTHETIC and single-rank by design: each one writes a +# tiny synthetic DCP checkpoint and round-trips it through the converter on +# rank 0. They share the default torch.distributed process group with whatever +# harness launched pytest. When that default PG is multi-rank (e.g. Megatron's +# CI/CD initialises NCCL with world_size>1 before pytest collection), the +# dcp.save/dcp.load collectives stall: each rank has its own +# tempfile.mkdtemp() path and its own torch.randn() tensors, so the metadata +# coordination across ranks never converges and the NCCL watchdog kills the +# job after 10 minutes (see ProcessGroupNCCL ALLGATHER timeout). +# +# Multi-rank coverage lives in test_distributed_round_trip.py, which uses a +# fresh single-rank gloo subgroup per scenario via SLURM/srun in +# run_slurm_ckpt_convert_tests.sh. Skip these synthetic tests whenever the +# default PG is already multi-rank. +@pytest.fixture(autouse=True) +def _skip_when_multi_rank_pg(): + if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1: + pytest.skip( + "Synthetic single-rank tests skipped under a multi-rank default " + "process group; multi-rank coverage is in " + "test_distributed_round_trip.py." 
+ ) + + +# --------------------------------------------------------------------------- +# Synthetic-checkpoint helpers +# --------------------------------------------------------------------------- + + +def make_checkpoint_args( + num_layers=4, + hidden_size=128, + num_attention_heads=4, + seq_length=256, + max_position_embeddings=256, + iteration=100, + num_moe_experts=None, + moe_shared_expert_intermediate_size=None, +): + """Build a minimal checkpoint 'args' namespace mirroring Megatron's. + + Set ``num_moe_experts`` to make the source/target a MoE GPT; the converter + will then pass the MoE config through unchanged so the round-trip stays + structurally consistent. + """ + return SimpleNamespace( + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + ffn_hidden_size=hidden_size * 4, + seq_length=seq_length, + max_position_embeddings=max_position_embeddings, + iteration=iteration, + consumed_train_samples=0, + consumed_valid_samples=0, + train_iters=1000, + train_samples=0, + tokenizer_type='GPT2BPETokenizer', + position_embedding_type='rope', + params_dtype=torch.float32, + fp16=False, + bf16=False, + num_moe_experts=num_moe_experts, + moe_shared_expert_intermediate_size=moe_shared_expert_intermediate_size, + moe_layer_freq=1, + ) + + +def make_gpt_state_dict( + num_layers, + hidden_size, + vocab_size=1024, + dtype=torch.float32, + num_moe_experts=None, + shared_expert_size=None, +): + """Create a minimal GPT state dict with the standard Megatron keys. + + Dense MLP layout (default): ``mlp.linear_fc1`` / ``mlp.linear_fc2``. + MoE layout (``num_moe_experts`` set): ``mlp.router`` plus N experts under + ``mlp.experts.local_experts..linear_fc{1,2}``, optionally a shared + expert under ``mlp.shared_experts.linear_fc{1,2}``. These are exactly the + keys Megatron writes for non-grouped-GEMM MoE — they all live under + ``decoder.layers..mlp.*`` so the converter ferries them through with no + MoE-specific code. + """ + sd = OrderedDict() + sd['embedding.word_embeddings.weight'] = torch.randn(vocab_size, hidden_size, dtype=dtype) + + for i in range(num_layers): + p = f'decoder.layers.{i}.' + sd[p + 'input_layernorm.weight'] = torch.randn(hidden_size, dtype=dtype) + sd[p + 'self_attention.linear_qkv.weight'] = torch.randn( + 3 * hidden_size, hidden_size, dtype=dtype + ) + sd[p + 'self_attention.linear_proj.weight'] = torch.randn( + hidden_size, hidden_size, dtype=dtype + ) + sd[p + 'pre_mlp_layernorm.weight'] = torch.randn(hidden_size, dtype=dtype) + + if num_moe_experts is None: + # Dense MLP + sd[p + 'mlp.linear_fc1.weight'] = torch.randn(4 * hidden_size, hidden_size, dtype=dtype) + sd[p + 'mlp.linear_fc2.weight'] = torch.randn(hidden_size, 4 * hidden_size, dtype=dtype) + else: + # MoE: router + N experts (+ optional shared expert) + sd[p + 'mlp.router.weight'] = torch.randn(num_moe_experts, hidden_size, dtype=dtype) + for j in range(num_moe_experts): + ep = p + f'mlp.experts.local_experts.{j}.' + sd[ep + 'linear_fc1.weight'] = torch.randn( + 4 * hidden_size, hidden_size, dtype=dtype + ) + sd[ep + 'linear_fc2.weight'] = torch.randn( + hidden_size, 4 * hidden_size, dtype=dtype + ) + if shared_expert_size is not None: + sp = p + 'mlp.shared_experts.' 
+ sd[sp + 'linear_fc1.weight'] = torch.randn( + shared_expert_size, hidden_size, dtype=dtype + ) + sd[sp + 'linear_fc2.weight'] = torch.randn( + hidden_size, shared_expert_size, dtype=dtype + ) + + sd['decoder.final_layernorm.weight'] = torch.randn(hidden_size, dtype=dtype) + sd['output_layer.weight'] = torch.randn(vocab_size, hidden_size, dtype=dtype) + return sd + + +# --------------------------------------------------------------------------- +# Dist (torch_dist / fsdp_dtensor) fixture builders +# --------------------------------------------------------------------------- + + +def _save_dist_checkpoint( + root_dir, full_sd, ckpt_args, iteration=100, prefix='model.', backend='torch_dist' +): + """Write a full state dict as a single-rank DCP checkpoint. + + From the converter's POV, this is indistinguishable from a multi-rank + TP+PP+FSDP save: DCP stores each tensor's global shape in its metadata + and the read planner reassembles the full tensor regardless of how many + processes wrote it. + """ + from dist_checkpoint_io import ( + ensure_single_rank_process_group, + save_dist_checkpoint_full, + write_latest_iteration_marker, + ) + + ensure_single_rank_process_group() + + iter_dir = os.path.join(root_dir, f'iter_{iteration:07d}') + common_state = { + 'args': copy.deepcopy(ckpt_args), + 'checkpoint_version': 3.0, + 'iteration': iteration, + } + save_dist_checkpoint_full(full_sd, common_state, iter_dir, model_prefix=prefix, backend=backend) + write_latest_iteration_marker(iter_dir, iteration) + + +def _load_converted_dist(ckpt_dir): + """Read a dist-format converted checkpoint back into a full state dict.""" + from dist_checkpoint_io import load_dist_checkpoint_full + + sd, common, prefix, backend, iteration = load_dist_checkpoint_full(ckpt_dir) + return sd, common.get('args', None) + + +# --------------------------------------------------------------------------- +# Core scenario runner +# --------------------------------------------------------------------------- + + +def _run_scenario( + label, + source_format, + target_format, + num_layers=4, + hidden_size=128, + pattern="M*-M*-M*-M*-", + source_prefix='model.', + num_moe_experts=None, + shared_expert_size=None, +): + """Build a GPT source ckpt, convert GPT->Hybrid->GPT, verify round-trip.""" + print(f"\n=== {label} ===") + print(f" source={source_format} (prefix='{source_prefix}')") + print(f" target={target_format}") + if num_moe_experts is not None: + print(f" MoE: num_experts={num_moe_experts} shared={shared_expert_size}") + + tmpdir = tempfile.mkdtemp(prefix=f'gpt_hybrid_{label.replace(" ", "_")}_') + try: + src_gpt_dir = os.path.join(tmpdir, 'gpt_src') + hybrid_dir = os.path.join(tmpdir, 'hybrid_mid') + dst_gpt_dir = os.path.join(tmpdir, 'gpt_dst') + + ckpt_args = make_checkpoint_args( + num_layers=num_layers, + hidden_size=hidden_size, + num_moe_experts=num_moe_experts, + moe_shared_expert_intermediate_size=shared_expert_size, + ) + gpt_sd = make_gpt_state_dict( + num_layers, + hidden_size, + num_moe_experts=num_moe_experts, + shared_expert_size=shared_expert_size, + ) + + _save_dist_checkpoint( + src_gpt_dir, gpt_sd, ckpt_args, prefix=source_prefix, backend=source_format + ) + + common_kwargs = dict( + hybrid_layer_pattern=pattern, + d_model=hidden_size, + mamba_version=2, + mamba_d_state=16, + mamba2_n_groups=2, + mamba2_head_dim=32, + d_conv=4, + init_method_std=0.02, + reset_iterations=False, + input_format='auto', + output_format=target_format, + ) + + # --- GPT -> Hybrid --- + conversion_main( + argparse.Namespace( + 
direction='gpt-to-hybrid', + load_dir=src_gpt_dir, + save_dir=hybrid_dir, + **common_kwargs, + ) + ) + + # --- Hybrid -> GPT --- + conversion_main( + argparse.Namespace( + direction='hybrid-to-gpt', + load_dir=hybrid_dir, + save_dir=dst_gpt_dir, + **common_kwargs, + ) + ) + + # --- Verify --- + recovered_sd, _ = _load_converted_dist(dst_gpt_dir) + # The hybrid->gpt step renames decoder.final_norm -> decoder.final_layernorm, + # mirroring the original GPT key. So recovered_sd should have the same + # keys and tensor values as gpt_sd. + + mismatches = [] + for key, original in gpt_sd.items(): + if key not in recovered_sd: + mismatches.append(f"MISSING: {key}") + continue + if not torch.equal(original, recovered_sd[key]): + max_diff = (original - recovered_sd[key]).abs().max().item() + mismatches.append(f"MISMATCH: {key} (max_diff={max_diff})") + + if mismatches: + for m in mismatches[:10]: + print(f" FAIL: {m}") + raise AssertionError(f"{label} failed with {len(mismatches)} weight mismatches") + + # SSM keys must be absent in the final GPT output. + assert not any('mixer.' in k for k in recovered_sd), ( + f"SSM keys leaked into final GPT output: " + f"{[k for k in recovered_sd if 'mixer.' in k][:5]}" + ) + + print(f"PASSED: {label}") + finally: + shutil.rmtree(tmpdir, ignore_errors=True) + + +# --------------------------------------------------------------------------- +# Test cases — one per (source backend, target backend, pattern) combo +# --------------------------------------------------------------------------- + + +def test_torch_dist_roundtrip(): + _run_scenario("torch_dist roundtrip", 'torch_dist', 'torch_dist') + + +def test_fsdp_dtensor_roundtrip(): + _run_scenario("fsdp_dtensor roundtrip", 'fsdp_dtensor', 'fsdp_dtensor') + + +def test_fsdp_dtensor_prefix(): + """fsdp_dtensor backend uses the 'model.module.' key prefix — verify we + auto-detect and strip it correctly.""" + _run_scenario( + "fsdp_dtensor prefix", 'fsdp_dtensor', 'fsdp_dtensor', source_prefix='model.module.' + ) + + +def test_torch_dist_alternating_pattern(): + """Pure transformer pattern (no SSM) round-trips.""" + _run_scenario("torch_dist alternating", 'torch_dist', 'torch_dist', pattern="*-*-*-*-") + + +def test_torch_dist_dense_ssm_pattern(): + """Dense SSM pattern still round-trips on the attn/MLP layers.""" + _run_scenario("torch_dist dense SSM", 'torch_dist', 'torch_dist', pattern="MM*-MM*-MM*-MM*-") + + +def test_torch_dist_moe_roundtrip(): + """MoE GPT (Mixtral-style) round-trips through an 'E'-bearing pattern. + + Source has num_moe_experts=4 and writes mlp.router / mlp.experts.* keys. + The hybrid pattern 'M*EM*EM*E' has 3 'E' positions, one per source layer. + The converter should ferry the router + every per-expert tensor through + verbatim — no MoE-specific code path involved. + """ + _run_scenario( + "torch_dist MoE roundtrip", + 'torch_dist', + 'torch_dist', + num_layers=3, + pattern="M*EM*EM*E", + num_moe_experts=4, + ) + + +def test_torch_dist_moe_with_shared_experts(): + """MoE + shared experts round-trip together (mlp.shared_experts.* keys).""" + _run_scenario( + "torch_dist MoE+shared", + 'torch_dist', + 'torch_dist', + num_layers=3, + hidden_size=64, + pattern="*E*E*E", + num_moe_experts=4, + shared_expert_size=64 * 2, + ) + + +def test_fsdp_dtensor_moe_roundtrip(): + """MoE round-trips through fsdp_dtensor (covers the 'model.module.' 
prefix + case combined with MoE keys).""" + _run_scenario( + "fsdp_dtensor MoE roundtrip", + 'fsdp_dtensor', + 'fsdp_dtensor', + num_layers=3, + pattern="M*EM*EM*E", + num_moe_experts=4, + source_prefix='model.module.', + ) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +if __name__ == '__main__': + print("=" * 60) + print("GPT <-> Hybrid Conversion Parallelism Matrix Tests") + print("=" * 60) + + test_torch_dist_roundtrip() + test_fsdp_dtensor_roundtrip() + test_fsdp_dtensor_prefix() + test_torch_dist_alternating_pattern() + test_torch_dist_dense_ssm_pattern() + test_torch_dist_moe_roundtrip() + test_torch_dist_moe_with_shared_experts() + test_fsdp_dtensor_moe_roundtrip() + + print("=" * 60) + print("ALL PARALLELISM MATRIX TESTS PASSED") + print("=" * 60) diff --git a/tools/checkpoint/dist_checkpoint_io.py b/tools/checkpoint/dist_checkpoint_io.py new file mode 100644 index 00000000000..33f0814cc5a --- /dev/null +++ b/tools/checkpoint/dist_checkpoint_io.py @@ -0,0 +1,271 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +""" +Distributed checkpoint I/O helpers for structural model-conversion tools. + +Provides format detection, model-free full-tensor loading, and single-rank +saving for Megatron-LM distributed checkpoints (``torch_dist`` and +``fsdp_dtensor`` backends). This lets conversion tools operate on +TP+PP+FSDP-trained checkpoints without needing to instantiate the model. + +The key observation is that PyTorch DCP stores each logical parameter with a +``global_shape`` in its metadata, and the TP / PP / FSDP slicing is just an +on-disk layout detail handled by the read planner. Loading into a plain +``torch.empty(global_shape)`` state dict on rank 0 therefore yields fully +gathered tensors regardless of the parallelism the checkpoint was trained with. +""" + +import os +from collections import OrderedDict +from pathlib import Path + +import torch +import torch.distributed as dist +import torch.distributed.checkpoint as dcp +from torch.distributed.checkpoint import ( + DefaultLoadPlanner, + DefaultSavePlanner, + FileSystemReader, + FileSystemWriter, +) +from torch.distributed.checkpoint.metadata import ( + BytesStorageMetadata, + TensorStorageMetadata, +) + +from megatron.core.dist_checkpointing.core import ( + CheckpointingConfig, + maybe_load_config, + save_config, +) +from megatron.core.dist_checkpointing.strategies.common import ( + load_common, + save_common, +) + + +FORMAT_TORCH_DIST = 'torch_dist' +FORMAT_FSDP_DTENSOR = 'fsdp_dtensor' +DIST_FORMATS = (FORMAT_TORCH_DIST, FORMAT_FSDP_DTENSOR) + +# Prefixes under which model weights may be keyed in a dist checkpoint. +_KNOWN_MODEL_PREFIXES = ('model.module.module.', 'model.module.', 'model.', '') +# Well-known bare-key suffixes we probe for when detecting the prefix. +_PROBE_SUFFIXES = ( + 'embedding.word_embeddings.weight', + 'decoder.layers.', + 'decoder.final_norm.', + 'decoder.final_layernorm.', + 'output_layer.weight', +) +# Keys that identify non-model state we drop during architecture conversion. +_NON_MODEL_TOP_LEVEL_PREFIXES = ( + 'optimizer.', + 'rng_state', + 'rerun_state_machine_state', +) + + +def resolve_checkpoint_subdir(load_dir): + """Return ``(ckpt_dir, iteration)``. + + Megatron writes checkpoints either flat or under ``iter_XXXXXXX/``. This + picks the right directory and reports the iteration when it can be + determined. 
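+    For example, a checkpoint root whose ``latest_checkpointed_iteration.txt``
+    contains ``100`` resolves to ``<load_dir>/iter_0000100`` with iteration 100.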
+ """ + if os.path.exists(os.path.join(load_dir, 'metadata.json')): + return load_dir, None + + latest_iter = os.path.join(load_dir, 'latest_checkpointed_iteration.txt') + if os.path.exists(latest_iter): + with open(latest_iter, 'r') as f: + iteration = int(f.read().strip()) + iter_dir = os.path.join(load_dir, f'iter_{iteration:07d}') + if os.path.isdir(iter_dir): + return iter_dir, iteration + + return load_dir, None + + +def detect_checkpoint_format(load_dir): + """Return one of ``{'torch_dist', 'fsdp_dtensor'}``. + + Raises ``ValueError`` if the directory looks like the legacy + ``mp_rank_XX`` layout (no longer supported) or doesn't match any known + dist-checkpoint metadata. + """ + ckpt_dir, _ = resolve_checkpoint_subdir(load_dir) + config = maybe_load_config(ckpt_dir) + if config is not None: + return config.sharded_backend + + if os.path.isdir(ckpt_dir) and any( + name.startswith('mp_rank_') for name in os.listdir(ckpt_dir) + ): + raise ValueError( + f"{load_dir} looks like a legacy mp_rank_XX checkpoint. " + f"Legacy format is no longer supported — convert to torch_dist first." + ) + + raise ValueError(f"Unrecognized checkpoint format at {load_dir}") + + +def ensure_single_rank_process_group(): + """Initialize a 1-rank gloo process group if one isn't already up. + + DCP requires a default process group; this lets the conversion tool run + in a plain ``python`` invocation (no ``torchrun`` needed). + """ + if not dist.is_available(): + raise RuntimeError("torch.distributed is not available.") + if dist.is_initialized(): + return + os.environ.setdefault('MASTER_ADDR', '127.0.0.1') + os.environ.setdefault('MASTER_PORT', '29500') + os.environ.setdefault('RANK', '0') + os.environ.setdefault('WORLD_SIZE', '1') + os.environ.setdefault('LOCAL_RANK', '0') + dist.init_process_group(backend='gloo', rank=0, world_size=1) + + +def detect_model_prefix(keys): + """Return the prefix under which model weights live in ``keys``. + + Looks for a recognizable suffix (``embedding.word_embeddings.weight``, + ``decoder.layers.``, etc.) and returns the matching prefix. Falls back + to ``''`` if nothing obvious is found. + """ + keys = list(keys) + for prefix in _KNOWN_MODEL_PREFIXES: + for suffix in _PROBE_SUFFIXES: + probe = prefix + suffix + for key in keys: + if key.startswith(probe): + return prefix + return '' + + +def _is_non_model_key(bare_key): + if bare_key.startswith(_NON_MODEL_TOP_LEVEL_PREFIXES): + return True + # _extra_state blobs are TE per-module state; they are tied to a specific + # TP/parallelism configuration and aren't meaningful after a structural + # model conversion, so we drop them. + if '_extra_state' in bare_key: + return True + return False + + +def load_dist_checkpoint_full(load_dir): + """Load a dist checkpoint and return fully-gathered model weights. + + Returns: + model_state_dict (OrderedDict[str, torch.Tensor]): bare keys, full + tensors on CPU. Optimizer state, RNG state, and ``_extra_state`` + blobs are filtered out. + common_state (dict): contents of ``common.pt`` (e.g. ``args``). + model_prefix (str): the prefix we stripped (re-apply on save). + backend (str): ``'torch_dist'`` or ``'fsdp_dtensor'``. + iteration (int or None): iteration number if discoverable. 
+ """ + ensure_single_rank_process_group() + + ckpt_dir, iteration = resolve_checkpoint_subdir(load_dir) + config = maybe_load_config(ckpt_dir) + if config is None: + raise ValueError( + f"{load_dir} is not a distributed checkpoint (no metadata.json)" + ) + backend = config.sharded_backend + + reader = FileSystemReader(ckpt_dir) + metadata = reader.read_metadata() + + model_prefix = detect_model_prefix(metadata.state_dict_metadata.keys()) + + raw_state_dict = {} + for key, md in metadata.state_dict_metadata.items(): + if not isinstance(md, TensorStorageMetadata): + continue + if model_prefix and not key.startswith(model_prefix): + continue + bare_key = key[len(model_prefix):] if model_prefix else key + if _is_non_model_key(bare_key): + continue + raw_state_dict[key] = torch.empty( + md.size, dtype=md.properties.dtype, device='cpu' + ) + + if not raw_state_dict: + raise ValueError( + f"No model tensors found in {ckpt_dir} (detected prefix " + f"'{model_prefix}', backend '{backend}')." + ) + + dcp.load(raw_state_dict, storage_reader=reader, planner=DefaultLoadPlanner()) + + model_state_dict = OrderedDict() + for key, tensor in raw_state_dict.items(): + bare_key = key[len(model_prefix):] if model_prefix else key + model_state_dict[bare_key] = tensor + + common_state = {} + try: + common_state = load_common(ckpt_dir) + except Exception: + pass + + return model_state_dict, common_state, model_prefix, backend, iteration + + +def save_dist_checkpoint_full( + model_state_dict, + common_state, + save_dir, + model_prefix='model.', + backend=FORMAT_TORCH_DIST, +): + """Save a fully-gathered state dict as a distributed checkpoint. + + The output is written as a single-rank, fully-replicated DCP checkpoint + plus ``common.pt`` and ``metadata.json``. A downstream Megatron training + job reads it back through ``dist_checkpointing.load()`` with its own + sharded_state_dict template — TP+PP+FSDP resharding happens transparently + on load, since the on-disk tensors carry their full logical shape. + """ + ensure_single_rank_process_group() + + os.makedirs(save_dir, exist_ok=True) + + raw_state_dict = OrderedDict() + for bare_key, tensor in model_state_dict.items(): + full_key = f"{model_prefix}{bare_key}" if model_prefix else bare_key + raw_state_dict[full_key] = tensor.contiguous() if tensor.is_contiguous() else tensor.contiguous() + + writer = FileSystemWriter(save_dir) + dcp.save( + state_dict=raw_state_dict, + storage_writer=writer, + planner=DefaultSavePlanner(), + ) + + if common_state: + save_common(common_state, save_dir) + + if dist.get_rank() == 0: + save_config(CheckpointingConfig(sharded_backend=backend), save_dir) + dist.barrier() + + +def write_latest_iteration_marker(save_dir, iteration): + """Mirror the legacy ``latest_checkpointed_iteration.txt`` convention. + + When ``save_dir`` points at a top-level checkpoint root with an + ``iter_XXXXXXX/`` subdirectory, the tracker file lets Megatron auto-find + the latest iteration on load. + """ + parent = os.path.dirname(save_dir.rstrip('/')) or save_dir + if os.path.basename(save_dir.rstrip('/')).startswith('iter_'): + tracker = os.path.join(parent, 'latest_checkpointed_iteration.txt') + with open(tracker, 'w') as f: + f.write(str(iteration)) diff --git a/tools/checkpoint/gpt_hybrid_conversion.py b/tools/checkpoint/gpt_hybrid_conversion.py new file mode 100644 index 00000000000..94236daeed9 --- /dev/null +++ b/tools/checkpoint/gpt_hybrid_conversion.py @@ -0,0 +1,901 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. 
+ +""" +GPT <-> Hybrid Checkpoint Conversion Tool +========================================= + +Directly converts checkpoints between GPTModel (homogeneous Transformer) and +HybridModel (hybrid Mamba+Transformer) without going through HuggingFace as an +intermediary. + +Supported directions: + gpt-to-hybrid : Convert a GPT checkpoint to Hybrid format. + hybrid-to-gpt : Convert a Hybrid checkpoint to GPT format. + +How the hybrid layer pattern maps GPT layers (gpt-to-hybrid): + - Each GPT layer contains both attention and MLP sub-layers. + - The target hybrid model's hybrid_layer_pattern specifies per-layer types: + M = Mamba SSM layer + * = Attention-only layer + - = MLP-only layer (dense) + E = MoE MLP-only layer (router + experts; supports EP) + G = GDN layer (not currently mapped) + - GPT layer i's attention params map to the i-th '*' layer in the pattern. + - GPT layer i's MLP/MoE params map to the i-th MLP-bearing position + ('-' or 'E') in the pattern. Dense ('-') and MoE ('E') cannot be mixed: + GPT layers are uniform. + - The number of '*' positions and MLP-bearing positions must each equal + the number of GPT layers. + - Mamba SSM ('M') layers have no GPT equivalent and are initialized from + scratch using standard Mamba initialization. + +How MoE / Expert Parallelism (EP) works through the converter: + - GPTModel can run with MoE (Mixtral-style: every layer has a router and + N local experts). State-dict keys live under + `decoder.layers..mlp.{router,experts,shared_experts}.*`. + - Hybrid 'E' layers use the same key naming, so MoE tensors round-trip + verbatim — no expert collapsing, no router init, no per-expert work. + - EP-sharded checkpoints load through DCP transparently because each + tensor's `global_shape` is in the metadata, regardless of how many + EP / TP / PP / FSDP ranks wrote it. + - Use a pattern like 'M*EM*EM*E' to pair Mamba/Attn/MoE-MLP per stage. + +What happens to SSM parameters: + gpt-to-hybrid: SSM layers (M) are initialized from scratch: + - A_log: log(uniform(1, 16)) + - dt_bias: inverse_softplus(log_uniform(dt_min, dt_max)) + - D: ones + - conv1d.weight: kaiming_uniform(a=sqrt(5)) + - conv1d.bias: zeros + - in_proj.weight: kaiming_uniform(a=sqrt(5)) + - in_proj.layer_norm_weight: ones + - out_proj.weight: kaiming_uniform(a=sqrt(5)) + - norm.weight: ones + hybrid-to-gpt: SSM layers are discarded with a warning. + +Supported checkpoint formats: + - torch_dist : Megatron distributed checkpoint (TP + PP + FSDP). + - fsdp_dtensor : FSDP DTensor export (TP + PP + FSDP). + + PyTorch DCP gathers TP/PP/FSDP shards via the checkpoint's global-shape + metadata, so no explicit TP/PP/DP config is needed on input. The input + format is auto-detected; the output format defaults to the input format. + + The legacy ``mp_rank_XX/model_optim_rng.pt`` layout is not supported — + convert old checkpoints to ``torch_dist`` first. + +GPT compatibility whitelist (safeguard): + GPTModel is a strict homogeneous transformer (self-attention + MLP per + layer, standard linear_qkv / linear_fc1 / linear_fc2 state-dict keys). + The converter fails fast if either side uses features that GPTModel + cannot express. + + Rejected pattern symbols: 'G' (GDN), 'D' (DS-attention), 'E' (MoE). + Allowed: 'M' (Mamba SSM), '*' (attention), '-' (MLP). The number of + '*' and '-' layers must be equal. 
+
+    Rejected source-args features (checked against the args stored in the
+    source checkpoint):
+    - moe_layer_freq values that interleave dense and MoE layers
+      (num_moe_experts and moe_shared_expert_intermediate_size on their own
+      are fine and round-trip untouched)
+    - experimental_attention_variant (gated_delta_net, dsa, ...)
+    - linear_attention_freq
+    - heterogeneous_block_specs / heterogeneous_layers_config_path
+    - multi_latent_attention (MLA)
+    - mtp_num_layers (Multi-Token Prediction)
+
+    See `validate_pattern_gpt_compatible` and
+    `validate_source_args_gpt_compatible` for the exact rules.
+
+Example commands:
+    # GPT -> Hybrid (TP+PP+FSDP dist checkpoint)
+    python tools/checkpoint/gpt_hybrid_conversion.py \\
+        --direction gpt-to-hybrid \\
+        --load-dir /path/to/gpt-dist-checkpoint \\
+        --save-dir /path/to/hybrid-dist-checkpoint \\
+        --hybrid-layer-pattern "M*-M*-M*-M*-" \\
+        --d-model 4096 \\
+        --mamba-d-state 128 \\
+        --mamba2-n-groups 8 \\
+        --mamba2-head-dim 64
+
+    # Hybrid -> GPT (dist checkpoint)
+    python tools/checkpoint/gpt_hybrid_conversion.py \\
+        --direction hybrid-to-gpt \\
+        --load-dir /path/to/hybrid-dist-checkpoint \\
+        --save-dir /path/to/gpt-dist-checkpoint \\
+        --hybrid-layer-pattern "M*-M*-M*-M*-" \\
+        --d-model 4096 \\
+        --mamba-d-state 128 \\
+        --mamba2-n-groups 8 \\
+        --mamba2-head-dim 64
+"""
+
+import argparse
+import copy
+import math
+import os
+import re
+from collections import OrderedDict
+
+import torch
+
+from dist_checkpoint_io import (
+    DIST_FORMATS,
+    FORMAT_TORCH_DIST,
+    detect_checkpoint_format,
+    load_dist_checkpoint_full,
+    save_dist_checkpoint_full,
+    write_latest_iteration_marker,
+)
+
+
+# ---------------------------------------------------------------------------
+# Hybrid layer pattern parsing (standalone, no Megatron imports needed)
+# ---------------------------------------------------------------------------
+
+VALID_LAYER_SYMBOLS = {'M', 'G', '*', '-', 'E'}
+
+# Layer symbols GPTModel can emit or absorb:
+#   '*' : standard self-attention layer (MHA / GQA / MQA)
+#   '-' : standard (optionally gated) dense MLP layer
+#   'E' : MoE MLP layer. Both sides keep the keys under
+#         decoder.layers.<i>.mlp.{router,experts,shared_experts}.* so MoE
+#         tensors round-trip verbatim (see convert_gpt_to_hybrid and
+#         convert_hybrid_to_gpt — `is_mlp_param` already matches `mlp.*`).
+#   SSM ('M') has no GPT equivalent and is initialized from scratch /
+#   discarded (see convert_gpt_to_hybrid / convert_hybrid_to_gpt).
+#   'G' (GDN) and 'D' (DS-attention) are not currently mapped — they would
+#   need separate key-naming work. Reject for now.
+GPT_COMPATIBLE_PATTERN_SYMBOLS = {'M', '*', '-', 'E'}
+
+
+def parse_hybrid_layer_pattern(pattern):
+    """Parse a hybrid layer pattern string into a list of layer types.
+
+    Strips MTP separators (/) and pipeline stage separators (|), returning only
+    the main decoder pattern as a list of single-character layer types.
+
+    Returns:
+        list[str]: e.g. ['M', '*', '-', 'M', '*', '-']
+    """
+    # Take only the main pattern (before first '/')
+    main_pattern = pattern.split('/')[0]
+    # Remove pipeline stage separators
+    main_pattern = main_pattern.replace('|', '')
+    layer_types = list(main_pattern)
+    for ch in layer_types:
+        if ch not in VALID_LAYER_SYMBOLS:
+            raise ValueError(
+                f"Invalid layer symbol '{ch}' in pattern. "
+                f"Valid symbols: {VALID_LAYER_SYMBOLS}"
+            )
+    return layer_types
+
+
+# Pattern symbols that pair to a GPT-side MLP block. Both dense ('-') and MoE
+# ('E') keep their state-dict keys under `decoder.layers.<i>.mlp.*`, so they
+# round-trip identically.
The pattern uniformity check +# (validate_pattern_gpt_compatible) ensures '-' and 'E' don't appear together, +# which would mean GPT layers aren't uniform. +_MLP_BEARING_SYMBOLS = ('-', 'E') + + +def build_layer_index_mapping(layer_types, direction): + """Build mapping between GPT layer indices and hybrid-model layer indices. + + For gpt-to-hybrid: + Returns (attn_map, mlp_map, ssm_indices) where: + - attn_map[gpt_layer_i] = hybrid_layer_j (j is the index of the i-th '*') + - mlp_map[gpt_layer_i] = hybrid_layer_k (k is the index of the i-th + MLP-bearing position; either '-' or 'E') + + For hybrid-to-gpt: + Returns (attn_map, mlp_map, ssm_indices) where: + - attn_map[hybrid_attn_idx] = gpt_layer_i + - mlp_map[hybrid_mlp_idx] = gpt_layer_i + """ + attn_indices = [i for i, t in enumerate(layer_types) if t == '*'] + mlp_indices = [i for i, t in enumerate(layer_types) if t in _MLP_BEARING_SYMBOLS] + ssm_indices = [i for i, t in enumerate(layer_types) if t == 'M'] + + if direction == 'gpt-to-hybrid': + if len(attn_indices) != len(mlp_indices): + raise ValueError( + f"For gpt-to-hybrid, the number of attention layers ({len(attn_indices)}) " + f"must equal the number of MLP/MoE layers ({len(mlp_indices)}) in the pattern." + ) + attn_map = {i: attn_indices[i] for i in range(len(attn_indices))} + mlp_map = {i: mlp_indices[i] for i in range(len(mlp_indices))} + return attn_map, mlp_map, ssm_indices + + elif direction == 'hybrid-to-gpt': + if len(attn_indices) != len(mlp_indices): + raise ValueError( + f"For hybrid-to-gpt, the number of attention layers ({len(attn_indices)}) " + f"must equal the number of MLP/MoE layers ({len(mlp_indices)}) in the pattern." + ) + attn_map = {attn_indices[i]: i for i in range(len(attn_indices))} + mlp_map = {mlp_indices[i]: i for i in range(len(mlp_indices))} + return attn_map, mlp_map, ssm_indices + + else: + raise ValueError(f"Unknown direction: {direction}") + + +# --------------------------------------------------------------------------- +# GPT compatibility whitelist +# --------------------------------------------------------------------------- +# +# GPTModel is a *uniform* transformer: every decoder layer is the same kind. +# It can run with dense MLP or MoE MLP — both keep keys under +# decoder.layers..mlp.* — so MoE checkpoints round-trip through the +# converter as long as both sides share the same kind on every layer. +# The helpers below reject any hybrid layout or source-args combination that +# violates uniformity (and would therefore silently produce a corrupt target). +# +# Pattern-level rules (checked on the parsed hybrid_layer_pattern): +# * only 'M', '*', '-', 'E' are allowed (no 'G' GDN, no 'D' DS-attention) +# * MLP-bearing symbols must be uniform: '-' and 'E' cannot both appear +# (that would imply GPT has both dense and MoE layers — heterogeneous) +# * '*' count must equal '-'+'E' count (one-to-one GPT attn<->MLP pairing) +# +# Args-level rules (checked against the training args stored in the source +# checkpoint): reject anything that makes GPT layers heterogeneous OR uses +# attention variants the converter doesn't currently key-translate: +# * moe_layer_freq != 1 (interleaved dense/MoE layers) +# * experimental_attention_variant (gated_delta_net, dsa, ...) 
+# * linear_attention_freq (interleaved linear-attention) +# * heterogeneous_block_specs / heterogeneous_layers_config_* +# (Nemotron-NAS per-layer specs) +# * multi_latent_attention (MLA: different QKV key layout) +# * mtp_num_layers (Multi-Token Prediction head) +# +# Notably NOT rejected (they round-trip via mlp.* / self_attention.* keys): +# * num_moe_experts (MoE on every layer) +# * moe_shared_expert_intermediate_size (shared experts on every layer) +# +# All rejected configurations raise ValueError early, before any tensors +# are touched. + +# Source-args field name -> (predicate-that-means-"reject", human reason). +# Predicates are applied with getattr(args, field, None); missing fields +# are treated as "absent" and pass. +_GPT_COMPAT_REJECT_FIELDS = ( + ( + 'moe_layer_freq', + # moe_layer_freq is None or 1 when every layer is the same kind (all + # dense or all MoE). A value > 1 or a list with mixed entries means + # GPT has interleaved dense/MoE layers — heterogeneous, can't pair + # one-to-one with a uniform hybrid pattern. + lambda v: ( + v is not None + and not (isinstance(v, int) and v == 1) + and not (isinstance(v, str) and v.strip() in ('', '1')) + and not (isinstance(v, (list, tuple)) and all(x == 1 for x in v)) + ), + 'interleaved dense/MoE layers (moe_layer_freq)', + ), + ( + 'experimental_attention_variant', + lambda v: v is not None and v != '', + 'experimental attention variant (gated_delta_net / dsa / ...)', + ), + ( + 'linear_attention_freq', + lambda v: v is not None, + 'linear attention layers (linear_attention_freq)', + ), + ( + 'heterogeneous_block_specs', + lambda v: bool(v), + 'heterogeneous per-layer block specs', + ), + ( + 'heterogeneous_layers_config_path', + lambda v: v is not None and v != '', + 'heterogeneous layers config (Nemotron-NAS)', + ), + ( + 'heterogeneous_layers_config_encoded_json', + lambda v: v is not None and v != '', + 'heterogeneous layers config (Nemotron-NAS, inline JSON)', + ), + ( + 'multi_latent_attention', + lambda v: bool(v), + 'Multi-Latent Attention (MLA)', + ), + ( + 'mtp_num_layers', + lambda v: v is not None and v > 0, + 'Multi-Token Prediction head (mtp_num_layers)', + ), +) + + +def validate_pattern_gpt_compatible(layer_types, direction): + """Raise ValueError if the hybrid pattern cannot round-trip with GPTModel. + + Args: + layer_types: list of layer-type chars from parse_hybrid_layer_pattern(). + direction: 'gpt-to-hybrid' or 'hybrid-to-gpt' (for error messages). + + Rules: + * Allowed symbols: 'M', '*', '-', 'E'. 'G' (GDN) and 'D' (DS-attention) + are not currently key-translated. + * MLP-bearing symbols must be uniform: '-' (dense) and 'E' (MoE) cannot + both appear, because that would imply GPT has both dense and MoE + layers — the GPT side must be uniform. + * The number of attention positions must equal the number of + MLP-bearing positions: every GPT layer pairs one attention with one + MLP/MoE. + """ + bad = sorted({c for c in layer_types if c not in GPT_COMPATIBLE_PATTERN_SYMBOLS}) + if bad: + raise ValueError( + f"Hybrid layer pattern contains symbols {bad} that are not " + f"GPT-compatible (allowed: {sorted(GPT_COMPATIBLE_PATTERN_SYMBOLS)}). " + f"'G' (GDN) and 'D' (DS-attention) are not currently key-translated " + f"and cannot be {direction}-converted." + ) + + mlp_kinds_present = {t for t in layer_types if t in _MLP_BEARING_SYMBOLS} + if len(mlp_kinds_present) > 1: + raise ValueError( + f"Hybrid layer pattern mixes '-' (dense MLP) and 'E' (MoE) " + f"positions. 
GPTModel layers must be uniform — either all GPT " + f"layers are dense MLP, or all are MoE. Use only one of '-' or " + f"'E' in the pattern." + ) + + n_attn = sum(1 for t in layer_types if t == '*') + n_mlp = sum(1 for t in layer_types if t in _MLP_BEARING_SYMBOLS) + if n_attn != n_mlp: + raise ValueError( + f"GPT-compatible hybrid patterns must pair every attention layer " + f"('*') with one MLP/MoE layer ('-' or 'E'). Got {n_attn} '*' " + f"and {n_mlp} MLP-bearing layers in the pattern." + ) + + +def validate_source_args_gpt_compatible(source_args, direction): + """Raise ValueError if the source checkpoint uses features GPTModel can't express. + + Args: + source_args: argparse.Namespace (or any attribute-bag) loaded from the + source checkpoint; may be None, in which case this check is a no-op + (dist checkpoints without a cached args blob). + direction: 'gpt-to-hybrid' or 'hybrid-to-gpt'. + + Rejects MoE, MLA, MTP, linear / experimental attention, and heterogeneous + per-layer specs. See the module header for the full list. + """ + if source_args is None: + return + + rejected = [] + for field, predicate, reason in _GPT_COMPAT_REJECT_FIELDS: + if not hasattr(source_args, field): + continue + value = getattr(source_args, field) + try: + if predicate(value): + rejected.append(f" - {reason}: {field}={value!r}") + except Exception: + # Defensive: never let the validator crash on an unexpected + # value type — treat it as "cannot verify, pass". + continue + + if rejected: + joined = "\n".join(rejected) + raise ValueError( + f"Source checkpoint is not GPT-compatible for {direction} " + f"conversion. The following features have no GPTModel equivalent " + f"and would produce a corrupt target checkpoint:\n{joined}\n" + f"Remove these features from the model (or use a different " + f"conversion tool) before running gpt_hybrid_conversion." + ) + + +# --------------------------------------------------------------------------- +# SSM parameter initialization (for gpt-to-hybrid) +# --------------------------------------------------------------------------- + +def initialize_ssm_layer_params( + layer_idx, + d_model, + mamba_d_inner, + mamba_d_state, + mamba2_n_groups, + mamba2_n_heads, + mamba_head_dim, + d_conv=4, + dt_min=0.001, + dt_max=0.1, + dt_init_floor=1e-4, + A_init_range=(1, 16), + init_method_std=0.02, + dtype=torch.float32, +): + """Initialize parameters for a single Mamba SSM layer from scratch. + + Follows the initialization logic from MambaMixer.__init__: + - A_log: log(uniform(A_init_range)) + - dt_bias: inverse_softplus(log_uniform(dt_min, dt_max)) + - D: ones(nheads) + - conv1d.weight: kaiming_uniform(a=sqrt(5)) + - conv1d.bias: zeros + - in_proj.weight: kaiming_uniform(a=sqrt(5)) + - in_proj.layer_norm_weight: ones(d_model) + - out_proj.weight: kaiming_uniform(a=sqrt(5)) or normal(0, std) + - norm.weight: ones(d_inner) + + Returns: + dict: {param_suffix: tensor} for one SSM layer + """ + prefix = f'decoder.layers.{layer_idx}.mixer.' 
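+    # Dimension bookkeeping (Mamba2 conventions). With the CLI defaults
+    # (d_model=4096, mamba2_head_dim=64 -> mamba_d_inner=8192,
+    # mamba2_n_heads=128, mamba2_n_groups=8, mamba_d_state=128) the derived
+    # sizes below are conv_dim = 8192 + 2*8*128 = 10240 and
+    # in_proj_out_dim = 2*8192 + 2*8*128 + 128 = 18560.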
+ + nheads = mamba2_n_heads + conv_dim = mamba_d_inner + 2 * mamba2_n_groups * mamba_d_state + in_proj_out_dim = 2 * mamba_d_inner + 2 * mamba2_n_groups * mamba_d_state + nheads + + params = OrderedDict() + + # in_proj (ColumnParallelLinear) + in_proj_weight = torch.empty(in_proj_out_dim, d_model, dtype=dtype) + torch.nn.init.kaiming_uniform_(in_proj_weight, a=math.sqrt(5)) + params[prefix + 'in_proj.weight'] = in_proj_weight + + # in_proj layer norm weight (fused into ColumnParallelLinear in TE) + params[prefix + 'in_proj.layer_norm_weight'] = torch.ones(d_model, dtype=dtype) + + # conv1d + conv_weight = torch.empty(conv_dim, 1, d_conv, dtype=dtype) + torch.nn.init.kaiming_uniform_(conv_weight, a=math.sqrt(5)) + params[prefix + 'conv1d.weight'] = conv_weight + params[prefix + 'conv1d.bias'] = torch.zeros(conv_dim, dtype=dtype) + + # A_log (kept in fp32) + A = torch.empty(nheads, dtype=torch.float32) + A.uniform_(*A_init_range) + params[prefix + 'A_log'] = torch.log(A) + + # D (kept in fp32) + params[prefix + 'D'] = torch.ones(nheads, dtype=torch.float32) + + # dt_bias + dt = torch.exp( + torch.rand(nheads, dtype=dtype) + * (math.log(dt_max) - math.log(dt_min)) + + math.log(dt_min) + ).clamp(min=dt_init_floor) + inv_dt = dt + torch.log(-torch.expm1(-dt)) + params[prefix + 'dt_bias'] = inv_dt + + # norm (RMSNorm) + params[prefix + 'norm.weight'] = torch.ones(mamba_d_inner, dtype=dtype) + + # out_proj (RowParallelLinear) + out_proj_weight = torch.empty(d_model, mamba_d_inner, dtype=dtype) + torch.nn.init.kaiming_uniform_(out_proj_weight, a=math.sqrt(5)) + params[prefix + 'out_proj.weight'] = out_proj_weight + + return params + + +# --------------------------------------------------------------------------- +# Key name helpers +# --------------------------------------------------------------------------- + +def get_layer_num_from_key(key): + """Extract the layer number from a state dict key like 'decoder.layers.5.mlp...'""" + match = re.search(r'decoder\.layers\.(\d+)\.', key) + if match: + return int(match.group(1)) + return None + + +def replace_layer_num(key, old_num, new_num): + """Replace the layer number in a state dict key.""" + return key.replace(f'decoder.layers.{old_num}.', f'decoder.layers.{new_num}.', 1) + + +def is_attention_param(key): + """Check if a key belongs to an attention sub-layer.""" + return 'self_attention.' in key or 'input_layernorm.' in key + + +def is_mlp_param(key): + """Check if a key belongs to an MLP sub-layer.""" + return ('mlp.' in key or 'pre_mlp_layernorm.' in key) and 'self_attention' not in key + + +def is_ssm_param(key): + """Check if a key belongs to a Mamba SSM mixer sub-layer.""" + ssm_markers = ['mixer.in_proj', 'mixer.conv1d', 'mixer.A_log', 'mixer.D', + 'mixer.dt_bias', 'mixer.norm', 'mixer.out_proj', + 'mixer.x_proj', 'mixer.dt_proj'] + return any(m in key for m in ssm_markers) + + +def is_layer_norm_for_ssm(key): + """Check if a key is the input layer norm for an SSM layer. + + In hybrid models, SSM layers can have their own input_layernorm or the + norm can be fused into in_proj.layer_norm_weight. + """ + return 'in_proj.layer_norm_weight' in key + + +# --------------------------------------------------------------------------- +# Core conversion: GPT -> Hybrid +# --------------------------------------------------------------------------- + +def convert_gpt_to_hybrid(full_model, layer_types, args): + """Convert a GPT state dict to a Hybrid state dict. + + Args: + full_model: OrderedDict with globally-indexed GPT state dict keys. 
+ layer_types: list of layer type chars from hybrid_layer_pattern. + args: Parsed CLI arguments. + + Returns: + OrderedDict: Hybrid state dict with globally-indexed keys. + """ + attn_map, mlp_map, ssm_indices = build_layer_index_mapping( + layer_types, 'gpt-to-hybrid' + ) + num_gpt_layers = len(attn_map) + + # Validate GPT layer count + gpt_layer_nums = set() + for key in full_model: + lnum = get_layer_num_from_key(key) + if lnum is not None: + gpt_layer_nums.add(lnum) + + if len(gpt_layer_nums) != num_gpt_layers: + raise ValueError( + f"GPT checkpoint has {len(gpt_layer_nums)} layers, but the pattern " + f"has {num_gpt_layers} attention ('*') and {num_gpt_layers} MLP ('-') " + f"layers. These must match." + ) + + target = OrderedDict() + dtype = None + + # Copy / rename non-layer params + for key, tensor in full_model.items(): + if dtype is None and tensor.dtype in (torch.float16, torch.bfloat16, torch.float32): + dtype = tensor.dtype + + if 'decoder.layers.' in key: + continue + + # Rename final_layernorm -> final_norm + if 'decoder.final_layernorm' in key: + new_key = key.replace('decoder.final_layernorm', 'decoder.final_norm') + target[new_key] = tensor + else: + target[key] = tensor + + if dtype is None: + dtype = torch.float32 + + # Map attention and MLP params + for key, tensor in full_model.items(): + lnum = get_layer_num_from_key(key) + if lnum is None: + continue + + if is_attention_param(key): + target_layer = attn_map[lnum] + new_key = replace_layer_num(key, lnum, target_layer) + target[new_key] = tensor + + elif is_mlp_param(key): + target_layer = mlp_map[lnum] + new_key = replace_layer_num(key, lnum, target_layer) + target[new_key] = tensor + + # (any other layer params get copied as-is with their own mapping, + # but for pure GPT there should only be attention + MLP) + + # Initialize SSM layers from scratch + print(f" Initializing {len(ssm_indices)} SSM layers from scratch...") + for layer_idx in ssm_indices: + ssm_params = initialize_ssm_layer_params( + layer_idx=layer_idx, + d_model=args.d_model, + mamba_d_inner=args.mamba_d_inner, + mamba_d_state=args.mamba_d_state, + mamba2_n_groups=args.mamba2_n_groups, + mamba2_n_heads=args.mamba2_n_heads, + mamba_head_dim=args.mamba2_head_dim, + d_conv=getattr(args, 'd_conv', 4), + init_method_std=getattr(args, 'init_method_std', 0.02), + dtype=dtype, + ) + target.update(ssm_params) + + # Sort by layer index for consistent ordering + target = _sort_state_dict(target) + + return target + + +# --------------------------------------------------------------------------- +# Core conversion: Hybrid -> GPT +# --------------------------------------------------------------------------- + +def convert_hybrid_to_gpt(full_model, layer_types, args): + """Convert a Hybrid state dict to a GPT state dict. + + Args: + full_model: OrderedDict with globally-indexed Hybrid state dict keys. + layer_types: list of layer type chars from hybrid_layer_pattern. + args: Parsed CLI arguments. + + Returns: + OrderedDict: GPT state dict with globally-indexed keys. + """ + attn_map, mlp_map, ssm_indices = build_layer_index_mapping( + layer_types, 'hybrid-to-gpt' + ) + num_gpt_layers = len(attn_map) + + target = OrderedDict() + discarded_ssm_keys = [] + + # Copy / rename non-layer params + for key, tensor in full_model.items(): + if 'decoder.layers.' 
in key: + continue + + # Rename final_norm -> final_layernorm + if 'decoder.final_norm' in key: + new_key = key.replace('decoder.final_norm', 'decoder.final_layernorm') + target[new_key] = tensor + else: + target[key] = tensor + + # Map attention and MLP params, discard SSM + for key, tensor in full_model.items(): + lnum = get_layer_num_from_key(key) + if lnum is None: + continue + + if is_ssm_param(key) or is_layer_norm_for_ssm(key): + # Discard SSM params + if lnum in ssm_indices: + discarded_ssm_keys.append(key) + continue + + if is_attention_param(key) and lnum in attn_map: + target_layer = attn_map[lnum] + new_key = replace_layer_num(key, lnum, target_layer) + target[new_key] = tensor + + elif is_mlp_param(key) and lnum in mlp_map: + target_layer = mlp_map[lnum] + new_key = replace_layer_num(key, lnum, target_layer) + target[new_key] = tensor + + elif lnum in ssm_indices: + # Any remaining SSM-layer param not caught above + discarded_ssm_keys.append(key) + + if discarded_ssm_keys: + print(f"\n WARNING: Discarded {len(discarded_ssm_keys)} SSM parameter tensors " + f"from {len(ssm_indices)} SSM layers (no GPT equivalent).") + print(f" First few discarded keys: {discarded_ssm_keys[:5]}") + + target = _sort_state_dict(target) + + return target + + +# --------------------------------------------------------------------------- +# Sorting helper +# --------------------------------------------------------------------------- + +def _sort_state_dict(state_dict): + """Sort state dict keys so that layer-indexed keys are in order.""" + def sort_key(item): + key = item[0] + # Extract layer number if present + match = re.search(r'decoder\.layers\.(\d+)\.', key) + if match: + return (1, int(match.group(1)), key) + # Non-layer keys: embeddings first, output_layer last + if 'embedding' in key: + return (0, 0, key) + if 'output_layer' in key: + return (2, 0, key) + if 'decoder.final' in key: + return (1, 999999, key) + return (0, 1, key) + + return OrderedDict(sorted(state_dict.items(), key=sort_key)) + + +# --------------------------------------------------------------------------- +# Format-aware save +# --------------------------------------------------------------------------- + +def _save_dist_full(target_state_dict, common_state, model_prefix, backend, + args, iteration): + """Save a fully-gathered state dict in dist-ckpt format. + + The on-disk tensors carry their full logical shape, so downstream Megatron + training reads them back with any TP+PP+FSDP configuration. + """ + if iteration is None: + out_iter = 0 if args.reset_iterations else 0 + iter_dir = args.save_dir + else: + out_iter = 0 if args.reset_iterations else iteration + iter_dir = os.path.join(args.save_dir, f'iter_{out_iter:07d}') + + # Update common state args to reflect target model structure. 
+ common_state = copy.deepcopy(common_state) if common_state else {} + if 'args' in common_state and common_state['args'] is not None: + ckpt_args = common_state['args'] + ckpt_args.num_layers = args.target_num_layers + if hasattr(ckpt_args, 'hybrid_layer_pattern'): + if args.direction == 'gpt-to-hybrid': + ckpt_args.hybrid_layer_pattern = args.hybrid_layer_pattern + else: + ckpt_args.hybrid_layer_pattern = None + if args.reset_iterations: + for attr in ('iteration', 'consumed_valid_samples', + 'consumed_train_samples', 'train_iters', 'train_samples'): + if hasattr(ckpt_args, attr): + setattr(ckpt_args, attr, 0) + if args.reset_iterations and 'iteration' in common_state: + common_state['iteration'] = 0 + + print(f" Writing dist checkpoint to {iter_dir} " + f"(backend={backend}, prefix='{model_prefix}')...") + save_dist_checkpoint_full( + target_state_dict, common_state, iter_dir, + model_prefix=model_prefix, backend=backend, + ) + write_latest_iteration_marker(iter_dir, out_iter) + + +def main(args): + print("\n====RUNNING GPT <-> Hybrid CHECKPOINT CONVERSION====\n") + print(f" Direction: {args.direction}") + print(f" Source: {args.load_dir}") + print(f" Target: {args.save_dir}") + print(f" Hybrid layer pattern: {args.hybrid_layer_pattern}") + + # Compute derived Mamba dimensions + args.mamba_d_inner = args.d_model * 2 + args.mamba2_n_heads = args.mamba_d_inner // args.mamba2_head_dim + + # Parse hybrid layer pattern + layer_types = parse_hybrid_layer_pattern(args.hybrid_layer_pattern) + total_hybrid_layers = len(layer_types) + attn_count = sum(1 for t in layer_types if t == '*') + mlp_count = sum(1 for t in layer_types if t == '-') + ssm_count = sum(1 for t in layer_types if t == 'M') + print(f"\n Pattern: {len(layer_types)} total layers " + f"({attn_count} attn, {mlp_count} MLP, {ssm_count} SSM, " + f"{len(layer_types) - attn_count - mlp_count - ssm_count} other)") + + # Pattern-level GPT compatibility whitelist (fails fast, pre-load). + validate_pattern_gpt_compatible(layer_types, args.direction) + + # 1. Resolve input format + input_format = getattr(args, 'input_format', 'auto') + if input_format == 'auto': + input_format = detect_checkpoint_format(args.load_dir) + output_format = getattr(args, 'output_format', 'auto') + if output_format == 'auto': + output_format = input_format + print(f"\n Input format: {input_format}") + print(f" Output format: {output_format}") + + if input_format not in DIST_FORMATS: + raise ValueError( + f"Unsupported input format: {input_format}. " + f"Only dist formats are supported: {DIST_FORMATS}." + ) + if output_format not in DIST_FORMATS: + raise ValueError( + f"Unsupported output format: {output_format}. " + f"Only dist formats are supported: {DIST_FORMATS}." + ) + + # 2. Load source checkpoint into a fully-gathered state dict + print("\n[Step 1] Loading source checkpoint...") + full_model, common_state, model_prefix, dist_backend, iteration = ( + load_dist_checkpoint_full(args.load_dir) + ) + print(f" Source: dist backend={dist_backend}, prefix='{model_prefix}', " + f"iteration={iteration}, params={len(full_model)}") + + # Args-level GPT compatibility whitelist: reject MoE, MLA, MTP, linear / + # experimental attention, heterogeneous block specs, etc. See module header. + source_args = common_state.get('args') if common_state else None + validate_source_args_gpt_compatible(source_args, args.direction) + + # 3. 
Convert + print(f"\n[Step 2] Converting ({args.direction})...") + if args.direction == 'gpt-to-hybrid': + target_state_dict = convert_gpt_to_hybrid(full_model, layer_types, args) + args.target_num_layers = total_hybrid_layers + elif args.direction == 'hybrid-to-gpt': + target_state_dict = convert_hybrid_to_gpt(full_model, layer_types, args) + args.target_num_layers = attn_count + else: + raise ValueError(f"Unknown direction: {args.direction}") + print(f" Target model: {len(target_state_dict)} parameters") + + # 4. Save + print(f"\n[Step 3] Saving to {args.save_dir}...") + _save_dist_full( + target_state_dict, common_state, model_prefix, output_format, + args, iteration, + ) + + print("\n====CONVERSION COMPLETE====\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Convert checkpoints between GPTModel and HybridModel formats.", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + parser.add_argument( + '--direction', type=str, required=True, + choices=['gpt-to-hybrid', 'hybrid-to-gpt'], + help='Conversion direction.', + ) + parser.add_argument('--load-dir', type=str, required=True, + help='Path to source checkpoint directory.') + parser.add_argument('--save-dir', type=str, required=True, + help='Path to target checkpoint directory.') + parser.add_argument('--hybrid-layer-pattern', type=str, required=True, + help='Hybrid layer pattern string, e.g. "M*-M*-M*-M*-".') + + parser.add_argument( + '--input-format', type=str, default='auto', + choices=('auto',) + DIST_FORMATS, + help='Source checkpoint format. "auto" detects from metadata.json.', + ) + parser.add_argument( + '--output-format', type=str, default='auto', + choices=('auto',) + DIST_FORMATS, + help='Target checkpoint format. "auto" matches the input format. ' + 'Dist formats (torch_dist / fsdp_dtensor) transparently support ' + 'TP+PP+FSDP training checkpoints.', + ) + + # Model architecture params + parser.add_argument('--d-model', type=int, default=4096, + help='Model hidden dimension.') + parser.add_argument('--mamba-version', type=int, default=2, + choices=[1, 2], help='Mamba SSM version.') + parser.add_argument('--mamba-d-state', type=int, default=128, + help='Mamba state dimension.') + parser.add_argument('--mamba2-n-groups', type=int, default=8, + help='Number of groups (Mamba v2).') + parser.add_argument('--mamba2-head-dim', type=int, default=64, + help='Head dimension (Mamba v2).') + parser.add_argument('--d-conv', type=int, default=4, + help='Causal convolution kernel size.') + + # Initialization params + parser.add_argument('--init-method-std', type=float, default=0.02, + help='Std for initializing new Mamba SSM params.') + + # Checkpoint control + parser.add_argument('--reset-iterations', action='store_true', + help='Zero out the training iteration count.') + + args = parser.parse_args() + main(args)
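The unit tests above drive this round-trip through real DCP saves and loads; the conversion functions themselves can also be exercised directly on an in-memory toy state dict. A minimal sketch (toy shapes and the 'M*-M*-' pattern are illustrative; it assumes the repo root as the working directory and an importable Megatron-LM, since dist_checkpoint_io pulls in megatron.core):

import sys
sys.path.insert(0, 'tools/checkpoint')  # assumes the repo root as cwd

from collections import OrderedDict
from types import SimpleNamespace

import torch

from gpt_hybrid_conversion import (
    convert_gpt_to_hybrid,
    convert_hybrid_to_gpt,
    parse_hybrid_layer_pattern,
)

hidden = 8
gpt_sd = OrderedDict()
gpt_sd['embedding.word_embeddings.weight'] = torch.randn(16, hidden)
for i in range(2):  # two toy GPT layers
    p = f'decoder.layers.{i}.'
    gpt_sd[p + 'input_layernorm.weight'] = torch.randn(hidden)
    gpt_sd[p + 'self_attention.linear_qkv.weight'] = torch.randn(3 * hidden, hidden)
    gpt_sd[p + 'self_attention.linear_proj.weight'] = torch.randn(hidden, hidden)
    gpt_sd[p + 'pre_mlp_layernorm.weight'] = torch.randn(hidden)
    gpt_sd[p + 'mlp.linear_fc1.weight'] = torch.randn(4 * hidden, hidden)
    gpt_sd[p + 'mlp.linear_fc2.weight'] = torch.randn(hidden, 4 * hidden)
gpt_sd['decoder.final_layernorm.weight'] = torch.randn(hidden)
gpt_sd['output_layer.weight'] = torch.randn(16, hidden)

# Only the fields the converters actually read are needed (toy Mamba sizes).
conv_args = SimpleNamespace(
    d_model=hidden, mamba_d_inner=2 * hidden, mamba_d_state=4,
    mamba2_n_groups=1, mamba2_n_heads=4, mamba2_head_dim=4,
    d_conv=4, init_method_std=0.02,
)
layer_types = parse_hybrid_layer_pattern('M*-M*-')  # 2 attn, 2 MLP, 2 SSM

hybrid_sd = convert_gpt_to_hybrid(gpt_sd, layer_types, conv_args)
recovered = convert_hybrid_to_gpt(hybrid_sd, layer_types, conv_args)

for k, v in gpt_sd.items():
    assert torch.equal(v, recovered[k]), k
print('in-memory round-trip OK')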