From 0e7f4005dec18c5e8e1617865cb5a1d5c2031f0f Mon Sep 17 00:00:00 2001
From: guihong-nv
Date: Mon, 27 Apr 2026 08:10:34 -0700
Subject: [PATCH 01/10] Add support for checkpoint conversion between GPT and
 hybrid Mamba models

Signed-off-by: guihong-nv
---
 hybrid_conversion.py                          |  398 +++++
 tests/unit_tests/tools/__init__.py            |    0
 tests/unit_tests/tools/checkpoint/__init__.py |    0
 .../tools/checkpoint/run_slurm_tests.sh       |  120 ++
 .../checkpoint/test_gpt_mamba_conversion.py   |  946 +++++++++++
 .../test_gpt_mamba_conversion_integration.py  |  467 ++++++
 .../test_gpt_mamba_conversion_parallelism.py  |  360 +++++
 tools/checkpoint/dist_checkpoint_io.py        |  264 ++++
 tools/checkpoint/gpt_mamba_conversion.py      | 1392 +++++++++++++++++
 9 files changed, 3947 insertions(+)
 create mode 100644 hybrid_conversion.py
 create mode 100644 tests/unit_tests/tools/__init__.py
 create mode 100644 tests/unit_tests/tools/checkpoint/__init__.py
 create mode 100755 tests/unit_tests/tools/checkpoint/run_slurm_tests.sh
 create mode 100644 tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion.py
 create mode 100644 tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion_integration.py
 create mode 100644 tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion_parallelism.py
 create mode 100644 tools/checkpoint/dist_checkpoint_io.py
 create mode 100644 tools/checkpoint/gpt_mamba_conversion.py

diff --git a/hybrid_conversion.py b/hybrid_conversion.py
new file mode 100644
index 00000000000..da384e31ced
--- /dev/null
+++ b/hybrid_conversion.py
@@ -0,0 +1,398 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+
+# Note (rwaleffe): This is a temporary file for hybrid mamba-transformer model checkpoint conversion.
+# This functionality should be integrated with the megatron core checkpoint loader/saver.
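+#
+# High-level flow (a summary of main() below):
+#   1. Read latest_checkpointed_iteration.txt under --load-dir to find the latest iteration.
+#   2. For each input pipeline rank, load every tensor-parallel shard and concatenate the
+#      shards along each tensor's TP split dimension (the Mamba in_proj / conv1d tensors are
+#      first split into their internal x/z/B/C/dt groups so each group is concatenated
+#      separately, rather than doing a plain concat).
+#   3. Stitch the pipeline ranks into one full model, renumbering layers globally.
+#   4. Re-split the full model for --target-tp-size / --target-pp-size and write the new
+#      checkpoint, copying args and rng state from the source and optionally resetting
+#      iteration counters (--reset-iterations).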
+ + +import copy +import os +import re +import shutil +from collections import OrderedDict + +import torch +import argparse + + +tp_split_dim = { + 'word_embeddings.weight': 0, + 'norm.weight': -1, + 'final_norm.weight': -1, + 'output_layer.weight': 0, + # mamba1/2 + 'A_log': 0, + 'D': 0, + 'dt_bias': 0, + 'in_proj.weight': 0, + 'conv1d.weight': 0, + 'conv1d.bias': 0, + 'x_proj.weight': 1, + 'dt_proj.weight': 0, + 'dt_proj.bias': 0, + 'out_proj.weight': 1, + 'mixer.norm.weight': 0, + # mlp + 'linear_fc1.layer_norm_weight': -1, + 'linear_fc1.weight': 0, + 'linear_fc2.weight': 1, + # attention + 'self_attention.linear_proj.weight': 1, + 'self_attention.linear_qkv.layer_norm_weight': -1, + 'self_attention.linear_qkv.weight': 0, +} + + +def get_split_dim(tensor_name): + # norm.weight will match tensor_name of mixer.norm.weight and norm.weight, need to distinguish + if 'norm.weight' in tensor_name: + if 'mixer.norm.weight' in tensor_name: + return tp_split_dim['mixer.norm.weight'] + else: + return tp_split_dim['norm.weight'] + + for key in tp_split_dim.keys(): + if key in tensor_name: + return tp_split_dim[key] + raise Exception("Unknown tensor name {}".format(tensor_name)) + + +def combine_tp_tensors(params, key, dim, tensors): + tp_size = len(tensors) + + if 'mixer.in_proj.weight' in key and params.mamba_version == 1: + xs = []; zs = [] + for tensor in tensors: + x, z = torch.split(tensor, [params.mamba_d_inner//tp_size, + params.mamba_d_inner//tp_size], dim=dim) + xs.append(x); zs.append(z) + return torch.cat([torch.cat(xs, dim=dim), torch.cat(zs, dim=dim)], dim=dim) + + elif 'mixer.in_proj.weight' in key and params.mamba_version == 2: + xs = []; zs = []; Bs = []; Cs = []; dts = [] + for tensor in tensors: + x, z, B, C, dt = torch.split(tensor, [params.mamba_d_inner // tp_size, + params.mamba_d_inner // tp_size, + (params.mamba2_n_groups // tp_size) * args.mamba_d_state, + (params.mamba2_n_groups // tp_size) * args.mamba_d_state, + params.mamba2_n_heads // tp_size], dim=dim) + xs.append(x); zs.append(z); Bs.append(B); Cs.append(C); dts.append(dt) + + for ii in range(len(Bs)): + Bs[ii] = torch.reshape(Bs[ii], (-1, params.mamba_d_state, Bs[ii].shape[-1])) + Cs[ii] = torch.reshape(Cs[ii], (-1, params.mamba_d_state, Cs[ii].shape[-1])) + B = torch.cat(Bs, dim=dim); C = torch.cat(Cs, dim=dim) + x = torch.cat(xs, dim=dim); z = torch.cat(zs, dim=dim); dt = torch.cat(dts, dim=dim) + + return torch.cat([x, z, B.flatten(0, 1), C.flatten(0, 1), dt], dim=dim) + + elif 'mixer.conv1d' in key and params.mamba_version == 2: + xs = []; Bs = []; Cs = [] + for tensor in tensors: + x, B, C = torch.split(tensor, [params.mamba_d_inner//tp_size, + (params.mamba2_n_groups // tp_size) * params.mamba_d_state, + (params.mamba2_n_groups // tp_size) * params.mamba_d_state], dim=dim) + xs.append(x); Bs.append(B); Cs.append(C) + + for ii in range(len(Bs)): + if 'weight' in key: + Bs[ii] = torch.reshape(Bs[ii], (-1, params.mamba_d_state, Bs[ii].shape[-2], Bs[ii].shape[-1])) + Cs[ii] = torch.reshape(Cs[ii], (-1, params.mamba_d_state, Cs[ii].shape[-2], Cs[ii].shape[-1])) + elif 'bias' in key: + Bs[ii] = torch.reshape(Bs[ii], (-1, params.mamba_d_state)) + Cs[ii] = torch.reshape(Cs[ii], (-1, params.mamba_d_state)) + else: + raise Exception("Unknown key") + B = torch.cat(Bs, dim=dim); C = torch.cat(Cs, dim=dim) + x = torch.cat(xs, dim=dim) + + return torch.cat([x, B.flatten(0, 1), C.flatten(0, 1)], dim=dim) + + else: + return torch.cat(tensors, dim=dim) + + +def split_tensor_for_tp(params, key, dim, tensor): + tp_size = 
params.target_tp_size + tensor_sliced = [] + + if 'mixer.in_proj.weight' in key and params.mamba_version == 1: + x, z = torch.split(tensor, [params.mamba_d_inner, params.mamba_d_inner], dim=dim) + x_sliced = torch.chunk(x, tp_size, dim=dim) + z_sliced = torch.chunk(z, tp_size, dim=dim) + for (x, z) in zip(x_sliced, z_sliced): + tensor_sliced.append(torch.cat((x, z), dim=dim)) + + elif 'mixer.in_proj.weight' in key and params.mamba_version == 2: + x, z, B, C, dt = torch.split(tensor, [params.mamba_d_inner, params.mamba_d_inner, + params.mamba2_n_groups * params.mamba_d_state, + params.mamba2_n_groups * params.mamba_d_state, + params.mamba2_n_heads], dim=dim) + B = torch.reshape(B, (-1, params.mamba_d_state, B.shape[-1])) + C = torch.reshape(C, (-1, params.mamba_d_state, C.shape[-1])) + + B_sliced = torch.chunk(B, tp_size, dim=dim) + C_sliced = torch.chunk(C, tp_size, dim=dim) + x_sliced = torch.chunk(x, tp_size, dim=dim) + z_sliced = torch.chunk(z, tp_size, dim=dim) + dt_sliced = torch.chunk(dt, tp_size, dim=dim) + + tensor_sliced = [] + for (x, z, B, C, dt) in zip(x_sliced, z_sliced, B_sliced, C_sliced, dt_sliced): + tensor_sliced.append(torch.cat((x, z, B.flatten(0, 1), C.flatten(0, 1), dt), dim=dim)) + + elif 'mixer.conv1d' in key and params.mamba_version == 2: + x, B, C = torch.split(tensor, [params.mamba_d_inner, + params.mamba2_n_groups * params.mamba_d_state, + params.mamba2_n_groups * params.mamba_d_state], dim=dim) + if 'weight' in key: + B = torch.reshape(B, (-1, params.mamba_d_state, B.shape[-2], B.shape[-1])) + C = torch.reshape(C, (-1, params.mamba_d_state, C.shape[-2], C.shape[-1])) + elif 'bias' in key: + B = torch.reshape(B, (-1, params.mamba_d_state)) + C = torch.reshape(C, (-1, params.mamba_d_state)) + else: + raise Exception("Unknown key") + + B_sliced = torch.chunk(B, tp_size, dim=dim) + C_sliced = torch.chunk(C, tp_size, dim=dim) + x_sliced = torch.chunk(x, tp_size, dim=dim) + + tensor_sliced = [] + for (x, B, C) in zip(x_sliced, B_sliced, C_sliced): + tensor_sliced.append(torch.cat((x, B.flatten(0, 1), C.flatten(0, 1)), dim=dim)) + + else: + tensor_sliced = torch.chunk(tensor, tp_size, dim=dim) + + return tensor_sliced + + +def finalize_checkpoint(sample_model, model, params, verbose=False): + # make sure the rest of the checkpoint is how we want it from the original (i.e., other than the 'model') + reset_iterations = params.reset_iterations + + # checkpoint 'args' + model['args'] = copy.deepcopy(sample_model['args']) + model['args'].tensor_model_parallel_size = params.target_tp_size + model['args'].pipeline_model_parallel_size = params.target_pp_size + if reset_iterations: + model['args'].iteration = 0 + model['args'].consumed_valid_samples = 0 + model['args'].consumed_train_samples = 0 + model['args'].train_iters = 0 + model['args'].train_samples = 0 + + # checkpoint 'checkpoint_version' + model['checkpoint_version'] = copy.deepcopy(sample_model['checkpoint_version']) + + # checkpoint 'iteration' + model['iteration'] = copy.deepcopy(sample_model['iteration']) + if reset_iterations: + model['iteration'] = 0 + + # checkpoint 'optimizer' + # ignore + + # checkpoint 'opt_param_scheduler' + if 'opt_param_scheduler' in sample_model.keys(): + model['opt_param_scheduler'] = copy.deepcopy(sample_model['opt_param_scheduler']) + + # checkpoint 'rng_state' + model['rng_state'] = copy.deepcopy(sample_model['rng_state']) + + # report on argument difference + if verbose: + original_args = sample_model['args'].__dict__ + final_args = model['args'].__dict__ + for key in 
original_args: + if key in final_args: + if final_args[key] != original_args[key]: + print("KEY MISMATCH: {}".format(key)) + print("\toriginal: {}\n\tfinal: {}".format(original_args[key], final_args[key])) + else: + print("KEY MISSING from final: {}, value {}".format(key, original_args[key])) + print("") + for key in final_args: + if key not in original_args: + print("KEY ADDED to final: {}, value {}".format(key, final_args[key])) + + return model + + +def main(args): + print("\n====RUNNING CHECKPOINT CONVERSION====\n") + + args.mamba_d_inner = args.d_model * 2 + args.mamba2_n_heads = args.mamba_d_inner // args.mamba2_head_dim + + # get the latest iteration + tracker_filename = os.path.join(args.load_dir, 'latest_checkpointed_iteration.txt') + with open(tracker_filename, 'r') as f: + metastring = f.read().strip() + try: + iteration = int(metastring) + except ValueError: + raise Exception("Invalid iteration found in latest_checkpointed_iteration.txt!") + out_iteration = iteration if not args.reset_iterations else 0 + + # get model directory and model parallel ranks + input_model_dir = os.path.join(args.load_dir, 'iter_{:07d}'.format(iteration)) + input_sub_models = os.listdir(input_model_dir) + # input_sub_models = sorted(input_sub_models, key=lambda x: int(re.search(r'\d+', x).group())) + + # load one of the model parallel ranks to get arguments + sample_model_file = os.path.join(input_model_dir, input_sub_models[0], "model_optim_rng.pt") + sample_model = torch.load(sample_model_file) + print(f"Sample model {sample_model_file} is loaded.\n") + + # input tensor and pipeline parallel size + input_tp_rank = sample_model['args'].tensor_model_parallel_size + input_pp_rank = sample_model['args'].pipeline_model_parallel_size + num_layers_per_pipeline_rank = sample_model['args'].num_layers // input_pp_rank + + # construct full model + full_model = OrderedDict() + for pp in range(input_pp_rank): + print("[INFO] Processing input pipeline rank {}".format(pp)) + tp_models = [] + for tp in range(input_tp_rank): + dir_name = "mp_rank_{:02d}".format(tp) + if input_pp_rank > 1: + dir_name += "_{:03d}".format(pp) + model_file = os.path.join(input_model_dir, dir_name, "model_optim_rng.pt") + + tp_models.append(torch.load(model_file)) + print(f"Model {model_file} is loaded.") + + if input_tp_rank > 1: + combined_tp_model = OrderedDict() + for ii, (key, original_tensor) in enumerate(tp_models[0]['model'].items()): + if "_extra_state" in key: + combined_tp_model[key] = original_tensor + continue + + split_dim = get_split_dim(key) + original_shape = list(original_tensor.shape) + combined_shape = copy.deepcopy(original_shape) + combined_shape[split_dim] *= input_tp_rank + # print("{}, {}, {}".format(ii, key, split_dim)) + + if split_dim != -1: + # slice together model + # print("\tshape mismatch: original {}, combined {}".format(original_shape, combined_shape)) + combined_tensor = combine_tp_tensors(args, key, split_dim, + [tp_models[jj]['model'][key].cpu() for jj in range(input_tp_rank)]) + combined_tp_model[key] = combined_tensor + else: + # copy model + combined_tp_model[key] = original_tensor + else: + combined_tp_model = tp_models[0]['model'] + # print("Combined tp model: {}".format(combined_tp_model.keys())) + + for ii, (key, original_tensor) in enumerate(combined_tp_model.items()): + try: + layer_num = int(re.findall(r'\d+', key)[0]) + new_key = key.replace(str(layer_num), str(layer_num + pp*num_layers_per_pipeline_rank), 1) + except: + new_key = key + full_model[new_key] = original_tensor + # 
print("Combined model: {}".format(full_model.keys())) + print("\n[INFO] Loaded combined model\n") + + # sort by layer + # full_model_sorted = dict(sorted(people.items(), key=lambda item: item[1])) + + # create new split model + pp_offset = 0 + num_layers_per_pipeline_rank = sample_model['args'].num_layers // args.target_pp_size + + for pp in range(args.target_pp_size): + print("[INFO] Processing output pipeline rank {}".format(pp)) + tp_models = [] + for ii in range(args.target_tp_size): + tp_models.append({'model': OrderedDict()}) + + for ii, (key, original_tensor) in enumerate(full_model.items()): + try: + layer_num = int(re.findall(r'\d+', key)[0]) + if layer_num >= num_layers_per_pipeline_rank * (pp+1): + break + new_key = key.replace(str(layer_num), str(layer_num - (pp * num_layers_per_pipeline_rank)), 1) + except Exception: + new_key = key + + if ii < pp_offset: + continue + else: + pp_offset += 1 + + if "_extra_state" in new_key: + # copy + for jj in range(args.target_tp_size): + tp_models[jj]['model'][new_key] = original_tensor + continue + + split_dim = get_split_dim(new_key) + original_shape = list(original_tensor.shape) + v0 = original_shape[split_dim] + split_size = v0 // args.target_tp_size + split_shape = copy.deepcopy(original_shape) + split_shape[split_dim] = split_size + # print("{}, {}, {}".format(ii, new_key, split_dim)) + + if split_dim != -1: + # split model + # print("\tshape mismatch: original {}, combined {}".format(original_shape, split_shape)) + tensor_sliced = split_tensor_for_tp(args, new_key, split_dim, original_tensor) + for jj in range(args.target_tp_size): + tp_models[jj]['model'][new_key] = tensor_sliced[jj] + else: + # copy model + for jj in range(args.target_tp_size): + tp_models[jj]['model'][new_key] = original_tensor + # print(tp_models[0]['model'].keys()) + + for tp in range(args.target_tp_size): + dir_name = "mp_rank_{:02d}".format(tp) + if args.target_pp_size > 1: + dir_name += "_{:03d}".format(pp) + + model = finalize_checkpoint(sample_model, tp_models[tp], args, verbose=False) + + save_dir = os.path.join(args.save_dir, 'iter_{:07d}'.format(out_iteration), dir_name) + os.makedirs(save_dir, exist_ok=True) + model_file = os.path.join(save_dir, "model_optim_rng.pt") + torch.save(model, model_file) + print(f"Model {model_file} is saved.") + + # shutil.copyfile(tracker_filename, os.path.join(args.save_dir, 'latest_checkpointed_iteration.txt')) + tracker_filename = os.path.join(args.save_dir, 'latest_checkpointed_iteration.txt') + with open(tracker_filename, 'w') as f: + f.write(str(out_iteration)) + + +if __name__ == "__main__": + # example run command: + # python hybrid_conversion.py + # --load-dir mamba2-840m-test/checkpoints/ + # --save-dir mamba2-840m-test-conversion/checkpoints/ + # --target-pp-size 1 + # --target-tp-size 1 + + parser = argparse.ArgumentParser() + parser.add_argument('--load-dir', type=str) + parser.add_argument('--save-dir', type=str) + parser.add_argument('--target-tp-size', type=int, default=1) + parser.add_argument('--target-pp-size', type=int, default=1) + parser.add_argument('--reset-iterations', action='store_true') + + parser.add_argument('--d-model', type=int, default=4096) + parser.add_argument('--mamba-version', type=int, default=2) + parser.add_argument('--mamba-d-state', type=int, default=128) + parser.add_argument('--mamba2-n-groups', type=int, default=8) + parser.add_argument('--mamba2-head-dim', type=int, default=64) + + args = parser.parse_args() + + main(args) diff --git a/tests/unit_tests/tools/__init__.py 
b/tests/unit_tests/tools/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/tests/unit_tests/tools/checkpoint/__init__.py b/tests/unit_tests/tools/checkpoint/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/tests/unit_tests/tools/checkpoint/run_slurm_tests.sh b/tests/unit_tests/tools/checkpoint/run_slurm_tests.sh
new file mode 100755
index 00000000000..d5b9e504703
--- /dev/null
+++ b/tests/unit_tests/tools/checkpoint/run_slurm_tests.sh
@@ -0,0 +1,120 @@
+#!/bin/bash
+# Run GPT <-> Mamba checkpoint conversion tests on SLURM.
+#
+# Covers:
+#   Phase 1 - Unit tests (pattern parsing, key mapping, SSM init, round-trip,
+#             and the new GPT-compatibility whitelist)
+#   Phase 2 - Integration tests (legacy TP=1/PP=1 on-disk round-trip)
+#   Phase 3 - Parallelism matrix (TP / PP / FSDP and all combinations,
+#             across legacy and torch_dist / fsdp_dtensor formats;
+#             hybrid patterns exercised: pure-attention, M*-, M*-M*-,
+#             alternating, and pure-SSM)
+#
+# Single-node mode (default) exercises the full matrix on one GPU.
+# Multi-node mode launches the same pytest invocation on N nodes to verify
+# the converter is deterministic across nodes and that dist-checkpoint load
+# works from a shared filesystem.
+#
+# Usage:
+#   bash run_slurm_tests.sh                        # single-node, default repo path
+#   NODES=2 bash run_slurm_tests.sh                # 2 nodes
+#   MEGATRON_LM_DIR=/path bash run_slurm_tests.sh
+#
+# Environment knobs:
+#   MEGATRON_LM_DIR   Path to the Megatron-LM checkout (default: this repo root)
+#   CONTAINER_IMAGE   Container image (default: nvcr.io/nvidia/nemo:26.02)
+#   NODES             Number of nodes (default: 1)
+#   GPUS_PER_NODE     GPUs per node (default: 1)
+#   PARTITION         SLURM partition (default: batch)
+#   ACCOUNT           SLURM account (default: coreai_dlalgo_mcore)
+#   TIME              SLURM time limit (default: 00:45:00)
+
+set -euo pipefail
+
+# Default to the repo that contains this script.
+_SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+_DEFAULT_REPO="$(cd "${_SCRIPT_DIR}/../../../.." && pwd)"
+
+CONTAINER_IMAGE="${CONTAINER_IMAGE:-nvcr.io/nvidia/nemo:26.02}"
+MEGATRON_LM_DIR="${MEGATRON_LM_DIR:-${_DEFAULT_REPO}}"
+NODES="${NODES:-1}"
+GPUS_PER_NODE="${GPUS_PER_NODE:-1}"
+PARTITION="${PARTITION:-batch}"
+ACCOUNT="${ACCOUNT:-coreai_dlalgo_mcore}"
+TIME="${TIME:-00:45:00}"
+
+LOG_DIR="${MEGATRON_LM_DIR}/logs"
+mkdir -p "${LOG_DIR}"
+
+echo "======================================================"
+echo "GPT <-> Mamba Conversion Tests"
+echo "  Repo          : ${MEGATRON_LM_DIR}"
+echo "  Container     : ${CONTAINER_IMAGE}"
+echo "  Nodes         : ${NODES}"
+echo "  GPUs per node : ${GPUS_PER_NODE}"
+echo "  Partition     : ${PARTITION}"
+echo "  Account       : ${ACCOUNT}"
+echo "  Time limit    : ${TIME}"
+echo "  Logs          : ${LOG_DIR}"
+echo "======================================================"
+
+# The conversion unit tests are CPU-only; the parallelism matrix only needs
+# one GPU per test process (it exercises the sharding logic, not kernels).
+# On multi-node runs we invoke pytest on every task so each node independently
+# validates its view of the checkpoint, i.e. the *same* shared
+# checkpoint must round-trip from every node.
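+#
+# The three phases can also be launched locally without SLURM (a sketch,
+# assuming the repo root is the current directory and at least one GPU is visible):
+#   PYTHONPATH=. python -m pytest -vs tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion.py
+#   PYTHONPATH=. python tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion_integration.py
+#   PYTHONPATH=. python -m pytest -vs tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion_parallelism.py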
+srun \ + --job-name=gpt-mamba-conv-test \ + --nodes="${NODES}" \ + --ntasks-per-node=1 \ + --gpus-per-node="${GPUS_PER_NODE}" \ + --cpus-per-gpu=16 \ + --time="${TIME}" \ + --partition="${PARTITION}" \ + --account="${ACCOUNT}" \ + --output="${LOG_DIR}/gpt_mamba_conv_test_%j_%t.out" \ + --error="${LOG_DIR}/gpt_mamba_conv_test_%j_%t.err" \ + --container-image="${CONTAINER_IMAGE}" \ + --container-mounts="${MEGATRON_LM_DIR}:/opt/megatron-lm" \ + --container-workdir="/opt/megatron-lm" \ + bash -c ' + set -euo pipefail + export PYTHONPATH=/opt/megatron-lm:${PYTHONPATH:-} + + RANK="${SLURM_PROCID:-0}" + NODE="${SLURMD_NODENAME:-local}" + echo "[node=${NODE} rank=${RANK}] Python : $(python --version 2>&1)" + echo "[node=${NODE} rank=${RANK}] torch : $(python -c "import torch; print(torch.__version__)")" + echo "[node=${NODE} rank=${RANK}] cuda : $(python -c "import torch; print(torch.cuda.is_available(), torch.cuda.device_count())")" + echo "------------------------------------------------------" + + echo "" + echo "=== [node=${NODE}] Phase 1: Unit tests ===" + echo "" + python -m pytest -vs \ + tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion.py + + echo "" + echo "=== [node=${NODE}] Phase 1b: GPT-compatibility whitelist tests ===" + echo "" + # Run the whitelist classes in isolation so a regression in the + # safeguard is easy to spot in CI logs. + python -m pytest -vs \ + tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion.py::TestPatternWhitelist \ + tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion.py::TestSourceArgsWhitelist + + echo "" + echo "=== [node=${NODE}] Phase 2: Integration (legacy TP=1/PP=1) ===" + echo "" + python tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion_integration.py + + echo "" + echo "=== [node=${NODE}] Phase 3: Parallelism matrix (TP/PP/FSDP/combos) ===" + echo "" + python -m pytest -vs \ + tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion_parallelism.py + ' + +echo "======================================================" +echo "Test complete. Logs: ${LOG_DIR}" +echo "======================================================" diff --git a/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion.py b/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion.py new file mode 100644 index 00000000000..5ab5a435bdd --- /dev/null +++ b/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion.py @@ -0,0 +1,946 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +""" +Unit tests for the GPT <-> Mamba checkpoint conversion tool. 
+ +These tests validate: +- Hybrid layer pattern parsing +- Layer index mapping (GPT <-> Mamba) +- State dict key renaming (final_layernorm <-> final_norm) +- Shared parameter copying (embeddings, output_layer) +- SSM parameter initialization shapes and dtypes +- Round-trip conversion: GPT -> Mamba -> GPT preserves attention and MLP weights +- TP split dimension lookup +""" + +import argparse +import math +import os +import sys +import tempfile +from collections import OrderedDict + +import pytest +import torch + +# Add the tools/checkpoint directory to the path so we can import the module +sys.path.insert( + 0, + os.path.join(os.path.dirname(__file__), '..', '..', '..', '..', 'tools', 'checkpoint'), +) + +from gpt_mamba_conversion import ( + build_layer_index_mapping, + combine_tp_tensors, + convert_gpt_to_mamba, + convert_mamba_to_gpt, + get_layer_num_from_key, + get_split_dim, + initialize_ssm_layer_params, + is_attention_param, + is_mlp_param, + is_ssm_param, + parse_hybrid_layer_pattern, + replace_layer_num, + split_tensor_for_tp, + validate_pattern_gpt_compatible, + validate_source_args_gpt_compatible, +) + + +# --------------------------------------------------------------------------- +# Pattern parsing tests +# --------------------------------------------------------------------------- + +class TestPatternParsing: + def test_simple_pattern(self): + result = parse_hybrid_layer_pattern("M*-M*-") + assert result == ['M', '*', '-', 'M', '*', '-'] + + def test_all_mamba(self): + result = parse_hybrid_layer_pattern("MMMM") + assert result == ['M', 'M', 'M', 'M'] + + def test_all_attention(self): + result = parse_hybrid_layer_pattern("****") + assert result == ['*', '*', '*', '*'] + + def test_with_mtp_separator(self): + # Should strip MTP patterns (only main pattern) + result = parse_hybrid_layer_pattern("M*-M*-/MM/MM") + assert result == ['M', '*', '-', 'M', '*', '-'] + + def test_with_pipe_separator(self): + # Should strip pipeline stage separators + result = parse_hybrid_layer_pattern("M*-|M*-") + assert result == ['M', '*', '-', 'M', '*', '-'] + + def test_with_both_separators(self): + result = parse_hybrid_layer_pattern("M*-|M*-/MM/MM") + assert result == ['M', '*', '-', 'M', '*', '-'] + + def test_mixed_layers(self): + result = parse_hybrid_layer_pattern("M*-EG") + assert result == ['M', '*', '-', 'E', 'G'] + + def test_invalid_symbol(self): + with pytest.raises(ValueError, match="Invalid layer symbol"): + parse_hybrid_layer_pattern("M*X") + + +# --------------------------------------------------------------------------- +# Layer index mapping tests +# --------------------------------------------------------------------------- + +class TestLayerIndexMapping: + def test_gpt_to_mamba_basic(self): + # Pattern: M*-M*- (2 attn at pos 1,4; 2 MLP at pos 2,5) + layer_types = ['M', '*', '-', 'M', '*', '-'] + attn_map, mlp_map, ssm_indices = build_layer_index_mapping( + layer_types, 'gpt-to-mamba' + ) + # 2 GPT layers -> attn at [1,4], MLP at [2,5] + assert attn_map == {0: 1, 1: 4} + assert mlp_map == {0: 2, 1: 5} + assert ssm_indices == [0, 3] + + def test_mamba_to_gpt_basic(self): + layer_types = ['M', '*', '-', 'M', '*', '-'] + attn_map, mlp_map, ssm_indices = build_layer_index_mapping( + layer_types, 'mamba-to-gpt' + ) + # attn at mamba layer 1 -> GPT layer 0, attn at 4 -> GPT layer 1 + assert attn_map == {1: 0, 4: 1} + assert mlp_map == {2: 0, 5: 1} + assert ssm_indices == [0, 3] + + def test_alternating_pattern(self): + layer_types = ['*', '-', '*', '-', '*', '-'] + attn_map, mlp_map, 
ssm_indices = build_layer_index_mapping( + layer_types, 'gpt-to-mamba' + ) + assert attn_map == {0: 0, 1: 2, 2: 4} + assert mlp_map == {0: 1, 1: 3, 2: 5} + assert ssm_indices == [] + + def test_mismatched_attn_mlp_count(self): + # 2 attn but 1 MLP -> should raise + layer_types = ['*', '*', '-', 'M'] + with pytest.raises(ValueError, match="must equal"): + build_layer_index_mapping(layer_types, 'gpt-to-mamba') + + def test_unknown_direction(self): + with pytest.raises(ValueError, match="Unknown direction"): + build_layer_index_mapping(['*', '-'], 'invalid') + + +# --------------------------------------------------------------------------- +# Key helper tests +# --------------------------------------------------------------------------- + +class TestKeyHelpers: + def test_get_layer_num(self): + assert get_layer_num_from_key('decoder.layers.5.mlp.linear_fc1.weight') == 5 + assert get_layer_num_from_key('decoder.layers.0.self_attention.linear_qkv.weight') == 0 + assert get_layer_num_from_key('decoder.layers.99.mixer.A_log') == 99 + assert get_layer_num_from_key('embedding.word_embeddings.weight') is None + + def test_replace_layer_num(self): + key = 'decoder.layers.3.mlp.linear_fc1.weight' + assert replace_layer_num(key, 3, 7) == 'decoder.layers.7.mlp.linear_fc1.weight' + + def test_is_attention_param(self): + assert is_attention_param('decoder.layers.0.self_attention.linear_qkv.weight') + assert is_attention_param('decoder.layers.0.input_layernorm.weight') + assert not is_attention_param('decoder.layers.0.mlp.linear_fc1.weight') + assert not is_attention_param('decoder.layers.0.mixer.A_log') + + def test_is_mlp_param(self): + assert is_mlp_param('decoder.layers.0.mlp.linear_fc1.weight') + assert is_mlp_param('decoder.layers.0.pre_mlp_layernorm.weight') + assert not is_mlp_param('decoder.layers.0.self_attention.linear_qkv.weight') + + def test_is_ssm_param(self): + assert is_ssm_param('decoder.layers.0.mixer.A_log') + assert is_ssm_param('decoder.layers.0.mixer.in_proj.weight') + assert is_ssm_param('decoder.layers.0.mixer.conv1d.weight') + assert is_ssm_param('decoder.layers.0.mixer.D') + assert is_ssm_param('decoder.layers.0.mixer.dt_bias') + assert is_ssm_param('decoder.layers.0.mixer.norm.weight') + assert is_ssm_param('decoder.layers.0.mixer.out_proj.weight') + assert not is_ssm_param('decoder.layers.0.mlp.linear_fc1.weight') + assert not is_ssm_param('decoder.layers.0.self_attention.linear_qkv.weight') + + +# --------------------------------------------------------------------------- +# TP split dim tests +# --------------------------------------------------------------------------- + +class TestTPSplitDim: + def test_embedding_split(self): + assert get_split_dim('embedding.word_embeddings.weight') == 0 + + def test_output_layer_split(self): + assert get_split_dim('output_layer.weight') == 0 + + def test_norm_replicated(self): + assert get_split_dim('decoder.layers.0.input_layernorm.weight') == -1 + + def test_final_layernorm(self): + assert get_split_dim('decoder.final_layernorm.weight') == -1 + + def test_final_norm(self): + assert get_split_dim('decoder.final_norm.weight') == -1 + + def test_qkv_weight_column(self): + assert get_split_dim('decoder.layers.0.self_attention.linear_qkv.weight') == 0 + + def test_proj_weight_row(self): + assert get_split_dim('decoder.layers.0.self_attention.linear_proj.weight') == 1 + + def test_mlp_fc1_column(self): + assert get_split_dim('decoder.layers.0.mlp.linear_fc1.weight') == 0 + + def test_mlp_fc2_row(self): + assert 
get_split_dim('decoder.layers.0.mlp.linear_fc2.weight') == 1 + + def test_mamba_mixer_norm(self): + assert get_split_dim('decoder.layers.0.mixer.norm.weight') == 0 + + def test_mamba_A_log(self): + assert get_split_dim('decoder.layers.0.mixer.A_log') == 0 + + def test_mamba_out_proj(self): + assert get_split_dim('decoder.layers.0.mixer.out_proj.weight') == 1 + + def test_unknown_key(self): + with pytest.raises(ValueError, match="Unknown tensor name"): + get_split_dim('decoder.layers.0.some_unknown_param') + + +# --------------------------------------------------------------------------- +# SSM initialization tests +# --------------------------------------------------------------------------- + +class TestSSMInitialization: + def test_shapes(self): + d_model = 256 + d_inner = 512 # 2 * d_model + d_state = 64 + n_groups = 4 + head_dim = 32 + n_heads = d_inner // head_dim # 16 + d_conv = 4 + conv_dim = d_inner + 2 * n_groups * d_state + + params = initialize_ssm_layer_params( + layer_idx=0, + d_model=d_model, + mamba_d_inner=d_inner, + mamba_d_state=d_state, + mamba2_n_groups=n_groups, + mamba2_n_heads=n_heads, + mamba_head_dim=head_dim, + d_conv=d_conv, + dtype=torch.float32, + ) + + prefix = 'decoder.layers.0.mixer.' + + # in_proj: [2*d_inner + 2*n_groups*d_state + n_heads, d_model] + in_proj_out = 2 * d_inner + 2 * n_groups * d_state + n_heads + assert params[prefix + 'in_proj.weight'].shape == (in_proj_out, d_model) + + # in_proj layer norm weight + assert params[prefix + 'in_proj.layer_norm_weight'].shape == (d_model,) + + # conv1d: [conv_dim, 1, d_conv] + assert params[prefix + 'conv1d.weight'].shape == (conv_dim, 1, d_conv) + assert params[prefix + 'conv1d.bias'].shape == (conv_dim,) + + # A_log: [n_heads] + assert params[prefix + 'A_log'].shape == (n_heads,) + assert params[prefix + 'A_log'].dtype == torch.float32 + + # D: [n_heads] + assert params[prefix + 'D'].shape == (n_heads,) + assert params[prefix + 'D'].dtype == torch.float32 + + # dt_bias: [n_heads] + assert params[prefix + 'dt_bias'].shape == (n_heads,) + + # norm: [d_inner] + assert params[prefix + 'norm.weight'].shape == (d_inner,) + + # out_proj: [d_model, d_inner] + assert params[prefix + 'out_proj.weight'].shape == (d_model, d_inner) + + def test_A_log_values(self): + params = initialize_ssm_layer_params( + layer_idx=0, + d_model=64, + mamba_d_inner=128, + mamba_d_state=16, + mamba2_n_groups=2, + mamba2_n_heads=4, + mamba_head_dim=32, + ) + A_log = params['decoder.layers.0.mixer.A_log'] + # A was uniform in (1, 16), so A_log should be in (log(1), log(16)) = (0, 2.77) + assert (A_log >= 0).all() + assert (A_log <= math.log(16) + 0.01).all() + + def test_D_values(self): + params = initialize_ssm_layer_params( + layer_idx=0, + d_model=64, + mamba_d_inner=128, + mamba_d_state=16, + mamba2_n_groups=2, + mamba2_n_heads=4, + mamba_head_dim=32, + ) + D = params['decoder.layers.0.mixer.D'] + assert torch.allclose(D, torch.ones_like(D)) + + def test_conv1d_bias_zeros(self): + params = initialize_ssm_layer_params( + layer_idx=0, + d_model=64, + mamba_d_inner=128, + mamba_d_state=16, + mamba2_n_groups=2, + mamba2_n_heads=4, + mamba_head_dim=32, + ) + bias = params['decoder.layers.0.mixer.conv1d.bias'] + assert torch.allclose(bias, torch.zeros_like(bias)) + + def test_norm_weight_ones(self): + params = initialize_ssm_layer_params( + layer_idx=0, + d_model=64, + mamba_d_inner=128, + mamba_d_state=16, + mamba2_n_groups=2, + mamba2_n_heads=4, + mamba_head_dim=32, + ) + norm = params['decoder.layers.0.mixer.norm.weight'] + assert 
torch.allclose(norm, torch.ones_like(norm)) + + def test_layer_norm_weight_ones(self): + params = initialize_ssm_layer_params( + layer_idx=0, + d_model=64, + mamba_d_inner=128, + mamba_d_state=16, + mamba2_n_groups=2, + mamba2_n_heads=4, + mamba_head_dim=32, + ) + ln = params['decoder.layers.0.mixer.in_proj.layer_norm_weight'] + assert torch.allclose(ln, torch.ones_like(ln)) + + def test_different_layer_idx(self): + params = initialize_ssm_layer_params( + layer_idx=7, + d_model=64, + mamba_d_inner=128, + mamba_d_state=16, + mamba2_n_groups=2, + mamba2_n_heads=4, + mamba_head_dim=32, + ) + assert 'decoder.layers.7.mixer.A_log' in params + assert 'decoder.layers.0.mixer.A_log' not in params + + +# --------------------------------------------------------------------------- +# Synthetic GPT checkpoint builder +# --------------------------------------------------------------------------- + +def make_synthetic_gpt_checkpoint(num_layers, d_model, dtype=torch.float32): + """Create a minimal synthetic GPT state dict for testing.""" + state_dict = OrderedDict() + + # Embeddings + state_dict['embedding.word_embeddings.weight'] = torch.randn(1000, d_model, dtype=dtype) + + # Transformer layers + for i in range(num_layers): + prefix = f'decoder.layers.{i}.' + # Attention + state_dict[prefix + 'input_layernorm.weight'] = torch.randn(d_model, dtype=dtype) + state_dict[prefix + 'self_attention.linear_qkv.weight'] = torch.randn( + 3 * d_model, d_model, dtype=dtype + ) + state_dict[prefix + 'self_attention.linear_proj.weight'] = torch.randn( + d_model, d_model, dtype=dtype + ) + # MLP + state_dict[prefix + 'pre_mlp_layernorm.weight'] = torch.randn(d_model, dtype=dtype) + state_dict[prefix + 'mlp.linear_fc1.weight'] = torch.randn( + 4 * d_model, d_model, dtype=dtype + ) + state_dict[prefix + 'mlp.linear_fc2.weight'] = torch.randn( + d_model, 4 * d_model, dtype=dtype + ) + + # Final norm + state_dict['decoder.final_layernorm.weight'] = torch.randn(d_model, dtype=dtype) + + # Output layer + state_dict['output_layer.weight'] = torch.randn(1000, d_model, dtype=dtype) + + return state_dict + + +# --------------------------------------------------------------------------- +# Full conversion tests +# --------------------------------------------------------------------------- + +class TestGPTToMambaConversion: + def setup_method(self): + self.d_model = 64 + self.num_gpt_layers = 2 + self.pattern = "M*-M*-" # 6 total: 2 SSM, 2 attn, 2 MLP + self.gpt_state = make_synthetic_gpt_checkpoint( + self.num_gpt_layers, self.d_model + ) + self.args = argparse.Namespace( + d_model=self.d_model, + mamba_d_inner=self.d_model * 2, + mamba_d_state=16, + mamba2_n_groups=2, + mamba2_n_heads=(self.d_model * 2) // 32, + mamba2_head_dim=32, + mamba_version=2, + d_conv=4, + init_method_std=0.02, + ) + + def test_shared_params_preserved(self): + layer_types = parse_hybrid_layer_pattern(self.pattern) + result = convert_gpt_to_mamba(self.gpt_state, layer_types, self.args) + + # Embeddings should be identical + assert torch.equal( + result['embedding.word_embeddings.weight'], + self.gpt_state['embedding.word_embeddings.weight'], + ) + # Output layer + assert torch.equal( + result['output_layer.weight'], + self.gpt_state['output_layer.weight'], + ) + + def test_final_norm_renamed(self): + layer_types = parse_hybrid_layer_pattern(self.pattern) + result = convert_gpt_to_mamba(self.gpt_state, layer_types, self.args) + + assert 'decoder.final_norm.weight' in result + assert 'decoder.final_layernorm.weight' not in result + assert torch.equal( + 
result['decoder.final_norm.weight'], + self.gpt_state['decoder.final_layernorm.weight'], + ) + + def test_attention_params_mapped(self): + layer_types = parse_hybrid_layer_pattern(self.pattern) + result = convert_gpt_to_mamba(self.gpt_state, layer_types, self.args) + + # GPT layer 0 attn -> Mamba layer 1 (first '*' in M*-M*-) + assert torch.equal( + result['decoder.layers.1.self_attention.linear_qkv.weight'], + self.gpt_state['decoder.layers.0.self_attention.linear_qkv.weight'], + ) + # GPT layer 1 attn -> Mamba layer 4 (second '*') + assert torch.equal( + result['decoder.layers.4.self_attention.linear_qkv.weight'], + self.gpt_state['decoder.layers.1.self_attention.linear_qkv.weight'], + ) + + def test_mlp_params_mapped(self): + layer_types = parse_hybrid_layer_pattern(self.pattern) + result = convert_gpt_to_mamba(self.gpt_state, layer_types, self.args) + + # GPT layer 0 MLP -> Mamba layer 2 (first '-') + assert torch.equal( + result['decoder.layers.2.mlp.linear_fc1.weight'], + self.gpt_state['decoder.layers.0.mlp.linear_fc1.weight'], + ) + # GPT layer 1 MLP -> Mamba layer 5 (second '-') + assert torch.equal( + result['decoder.layers.5.mlp.linear_fc2.weight'], + self.gpt_state['decoder.layers.1.mlp.linear_fc2.weight'], + ) + + def test_ssm_layers_initialized(self): + layer_types = parse_hybrid_layer_pattern(self.pattern) + result = convert_gpt_to_mamba(self.gpt_state, layer_types, self.args) + + # SSM layers at index 0 and 3 + for idx in [0, 3]: + prefix = f'decoder.layers.{idx}.mixer.' + assert prefix + 'A_log' in result + assert prefix + 'D' in result + assert prefix + 'dt_bias' in result + assert prefix + 'conv1d.weight' in result + assert prefix + 'conv1d.bias' in result + assert prefix + 'in_proj.weight' in result + assert prefix + 'norm.weight' in result + assert prefix + 'out_proj.weight' in result + + def test_layer_count_mismatch_raises(self): + # Pattern with 3 attn but only 2 GPT layers + layer_types = parse_hybrid_layer_pattern("M*-*-*-") + with pytest.raises(ValueError, match="layers"): + convert_gpt_to_mamba(self.gpt_state, layer_types, self.args) + + +class TestMambaToGPTConversion: + def setup_method(self): + self.d_model = 64 + self.pattern = "M*-M*-" + self.args = argparse.Namespace( + d_model=self.d_model, + mamba_d_inner=self.d_model * 2, + mamba_d_state=16, + mamba2_n_groups=2, + mamba2_n_heads=(self.d_model * 2) // 32, + mamba2_head_dim=32, + mamba_version=2, + d_conv=4, + init_method_std=0.02, + ) + + def _make_mamba_state(self): + """Build a synthetic Mamba state dict matching pattern M*-M*-.""" + state_dict = OrderedDict() + state_dict['embedding.word_embeddings.weight'] = torch.randn(1000, self.d_model) + state_dict['output_layer.weight'] = torch.randn(1000, self.d_model) + state_dict['decoder.final_norm.weight'] = torch.randn(self.d_model) + + layer_types = parse_hybrid_layer_pattern(self.pattern) + d_inner = self.d_model * 2 + n_heads = self.args.mamba2_n_heads + n_groups = self.args.mamba2_n_groups + d_state = self.args.mamba_d_state + + for i, lt in enumerate(layer_types): + prefix = f'decoder.layers.{i}.' 
+ if lt == 'M': + # SSM params + ssm = initialize_ssm_layer_params( + i, self.d_model, d_inner, d_state, n_groups, n_heads, + self.args.mamba2_head_dim, + ) + state_dict.update(ssm) + elif lt == '*': + state_dict[prefix + 'input_layernorm.weight'] = torch.randn(self.d_model) + state_dict[prefix + 'self_attention.linear_qkv.weight'] = torch.randn( + 3 * self.d_model, self.d_model + ) + state_dict[prefix + 'self_attention.linear_proj.weight'] = torch.randn( + self.d_model, self.d_model + ) + elif lt == '-': + state_dict[prefix + 'pre_mlp_layernorm.weight'] = torch.randn(self.d_model) + state_dict[prefix + 'mlp.linear_fc1.weight'] = torch.randn( + 4 * self.d_model, self.d_model + ) + state_dict[prefix + 'mlp.linear_fc2.weight'] = torch.randn( + self.d_model, 4 * self.d_model + ) + + return state_dict + + def test_final_norm_renamed_back(self): + mamba_state = self._make_mamba_state() + layer_types = parse_hybrid_layer_pattern(self.pattern) + result = convert_mamba_to_gpt(mamba_state, layer_types, self.args) + + assert 'decoder.final_layernorm.weight' in result + assert 'decoder.final_norm.weight' not in result + + def test_ssm_params_discarded(self): + mamba_state = self._make_mamba_state() + layer_types = parse_hybrid_layer_pattern(self.pattern) + result = convert_mamba_to_gpt(mamba_state, layer_types, self.args) + + # No SSM keys should remain + for key in result: + assert 'mixer.' not in key, f"SSM key not discarded: {key}" + + def test_attention_params_mapped(self): + mamba_state = self._make_mamba_state() + layer_types = parse_hybrid_layer_pattern(self.pattern) + result = convert_mamba_to_gpt(mamba_state, layer_types, self.args) + + # Mamba layer 1 (first *) -> GPT layer 0 + assert torch.equal( + result['decoder.layers.0.self_attention.linear_qkv.weight'], + mamba_state['decoder.layers.1.self_attention.linear_qkv.weight'], + ) + # Mamba layer 4 (second *) -> GPT layer 1 + assert torch.equal( + result['decoder.layers.1.self_attention.linear_qkv.weight'], + mamba_state['decoder.layers.4.self_attention.linear_qkv.weight'], + ) + + def test_mlp_params_mapped(self): + mamba_state = self._make_mamba_state() + layer_types = parse_hybrid_layer_pattern(self.pattern) + result = convert_mamba_to_gpt(mamba_state, layer_types, self.args) + + # Mamba layer 2 (first -) -> GPT layer 0 + assert torch.equal( + result['decoder.layers.0.mlp.linear_fc1.weight'], + mamba_state['decoder.layers.2.mlp.linear_fc1.weight'], + ) + + def test_gpt_layer_count(self): + mamba_state = self._make_mamba_state() + layer_types = parse_hybrid_layer_pattern(self.pattern) + result = convert_mamba_to_gpt(mamba_state, layer_types, self.args) + + # Should have 2 GPT layers (layers 0 and 1) + layer_nums = set() + for key in result: + lnum = get_layer_num_from_key(key) + if lnum is not None: + layer_nums.add(lnum) + assert layer_nums == {0, 1} + + +# --------------------------------------------------------------------------- +# Round-trip test: GPT -> Mamba -> GPT +# --------------------------------------------------------------------------- + +class TestRoundTrip: + def test_gpt_mamba_gpt_preserves_weights(self): + """Converting GPT -> Mamba -> GPT should preserve all attention & MLP weights.""" + d_model = 64 + num_layers = 2 + pattern = "M*-M*-" + + args = argparse.Namespace( + d_model=d_model, + mamba_d_inner=d_model * 2, + mamba_d_state=16, + mamba2_n_groups=2, + mamba2_n_heads=(d_model * 2) // 32, + mamba2_head_dim=32, + mamba_version=2, + d_conv=4, + init_method_std=0.02, + ) + + original_gpt = 
make_synthetic_gpt_checkpoint(num_layers, d_model) + layer_types = parse_hybrid_layer_pattern(pattern) + + # GPT -> Mamba + mamba_state = convert_gpt_to_mamba(original_gpt, layer_types, args) + + # Mamba -> GPT + recovered_gpt = convert_mamba_to_gpt(mamba_state, layer_types, args) + + # Check all original GPT keys are preserved + for key in original_gpt: + # final_layernorm is renamed in the round trip + if 'final_layernorm' in key: + continue + assert key in recovered_gpt, f"Missing key after round-trip: {key}" + assert torch.equal(original_gpt[key], recovered_gpt[key]), ( + f"Weight mismatch after round-trip for {key}" + ) + + # Check final_layernorm was properly renamed back + assert torch.equal( + original_gpt['decoder.final_layernorm.weight'], + recovered_gpt['decoder.final_layernorm.weight'], + ) + + def test_round_trip_different_pattern(self): + """Test with a pattern that has more SSM layers.""" + d_model = 64 + num_layers = 3 + pattern = "M*-M*-M*-" + + args = argparse.Namespace( + d_model=d_model, + mamba_d_inner=d_model * 2, + mamba_d_state=16, + mamba2_n_groups=2, + mamba2_n_heads=(d_model * 2) // 32, + mamba2_head_dim=32, + mamba_version=2, + d_conv=4, + init_method_std=0.02, + ) + + original_gpt = make_synthetic_gpt_checkpoint(num_layers, d_model) + layer_types = parse_hybrid_layer_pattern(pattern) + + mamba_state = convert_gpt_to_mamba(original_gpt, layer_types, args) + recovered_gpt = convert_mamba_to_gpt(mamba_state, layer_types, args) + + for key in original_gpt: + if 'final_layernorm' in key: + continue + assert key in recovered_gpt, f"Missing key: {key}" + assert torch.equal(original_gpt[key], recovered_gpt[key]), ( + f"Mismatch for {key}" + ) + + +# --------------------------------------------------------------------------- +# TP combine/split round-trip tests +# --------------------------------------------------------------------------- + +class TestTPCombineSplit: + def test_simple_tensor_roundtrip(self): + params = argparse.Namespace( + mamba_d_inner=256, mamba_d_state=64, mamba2_n_groups=4, + mamba2_n_heads=8, mamba_version=2, target_tp_size=2, + ) + tensor = torch.randn(512, 128) + key = 'decoder.layers.0.mlp.linear_fc1.weight' + dim = 0 + + # Split into 2 + slices = split_tensor_for_tp(params, key, dim, tensor) + assert len(slices) == 2 + assert slices[0].shape == (256, 128) + + # Combine back + combined = combine_tp_tensors(params, key, dim, slices) + assert torch.equal(combined, tensor) + + def test_mamba_in_proj_v2_roundtrip(self): + d_inner = 256 + n_groups = 4 + d_state = 64 + n_heads = 8 + d_model = 128 + tp_size = 2 + + params = argparse.Namespace( + mamba_d_inner=d_inner, mamba_d_state=d_state, + mamba2_n_groups=n_groups, mamba2_n_heads=n_heads, + mamba_version=2, target_tp_size=tp_size, + ) + + # Full in_proj: [2*d_inner + 2*n_groups*d_state + n_heads, d_model] + out_dim = 2 * d_inner + 2 * n_groups * d_state + n_heads + tensor = torch.randn(out_dim, d_model) + key = 'decoder.layers.0.mixer.in_proj.weight' + dim = 0 + + slices = split_tensor_for_tp(params, key, dim, tensor) + assert len(slices) == tp_size + + combined = combine_tp_tensors(params, key, dim, slices) + assert torch.allclose(combined, tensor), "Mamba v2 in_proj round-trip failed" + + def test_mamba_conv1d_weight_v2_roundtrip(self): + d_inner = 256 + n_groups = 4 + d_state = 64 + d_conv = 4 + tp_size = 2 + + params = argparse.Namespace( + mamba_d_inner=d_inner, mamba_d_state=d_state, + mamba2_n_groups=n_groups, mamba2_n_heads=8, + mamba_version=2, target_tp_size=tp_size, + ) + + conv_dim = 
d_inner + 2 * n_groups * d_state + tensor = torch.randn(conv_dim, 1, d_conv) + key = 'decoder.layers.0.mixer.conv1d.weight' + dim = 0 + + slices = split_tensor_for_tp(params, key, dim, tensor) + combined = combine_tp_tensors(params, key, dim, slices) + assert torch.allclose(combined, tensor), "conv1d weight round-trip failed" + + def test_mamba_conv1d_bias_v2_roundtrip(self): + d_inner = 256 + n_groups = 4 + d_state = 64 + tp_size = 2 + + params = argparse.Namespace( + mamba_d_inner=d_inner, mamba_d_state=d_state, + mamba2_n_groups=n_groups, mamba2_n_heads=8, + mamba_version=2, target_tp_size=tp_size, + ) + + conv_dim = d_inner + 2 * n_groups * d_state + tensor = torch.randn(conv_dim) + key = 'decoder.layers.0.mixer.conv1d.bias' + dim = 0 + + slices = split_tensor_for_tp(params, key, dim, tensor) + combined = combine_tp_tensors(params, key, dim, slices) + assert torch.allclose(combined, tensor), "conv1d bias round-trip failed" + + +# --------------------------------------------------------------------------- +# GPT compatibility whitelist tests +# --------------------------------------------------------------------------- + +class TestPatternWhitelist: + """validate_pattern_gpt_compatible rejects hybrid patterns GPTModel can't express.""" + + def test_accepts_mamba_attn_mlp(self): + # Standard hybrid with equal attn/MLP counts. + layer_types = parse_hybrid_layer_pattern("M*-M*-M*-") + validate_pattern_gpt_compatible(layer_types, 'gpt-to-mamba') + + def test_accepts_pure_transformer_pattern(self): + layer_types = parse_hybrid_layer_pattern("*-*-*-") + validate_pattern_gpt_compatible(layer_types, 'mamba-to-gpt') + + def test_accepts_pure_ssm_pattern(self): + # Pure-SSM models have no attention/MLP, so trivially GPT-compatible + # in the pattern sense (the GPT side would be empty). + layer_types = parse_hybrid_layer_pattern("MMMM") + validate_pattern_gpt_compatible(layer_types, 'gpt-to-mamba') + + def test_rejects_moe_symbol(self): + layer_types = parse_hybrid_layer_pattern("M*-E") + with pytest.raises(ValueError, match="not GPT-compatible"): + validate_pattern_gpt_compatible(layer_types, 'gpt-to-mamba') + + def test_rejects_gdn_symbol(self): + layer_types = parse_hybrid_layer_pattern("G*-*-") + with pytest.raises(ValueError, match="not GPT-compatible"): + validate_pattern_gpt_compatible(layer_types, 'gpt-to-mamba') + + def test_rejects_unequal_attn_mlp(self): + layer_types = parse_hybrid_layer_pattern("M**-") # 2 attn, 1 MLP + with pytest.raises(ValueError, match="pair every attention"): + validate_pattern_gpt_compatible(layer_types, 'gpt-to-mamba') + + def test_error_lists_offending_symbols(self): + layer_types = parse_hybrid_layer_pattern("M*-EG") + with pytest.raises(ValueError) as exc: + validate_pattern_gpt_compatible(layer_types, 'mamba-to-gpt') + msg = str(exc.value) + assert 'E' in msg + assert 'G' in msg + + +class TestSourceArgsWhitelist: + """validate_source_args_gpt_compatible rejects source checkpoints with + non-GPT-expressible features.""" + + def _ok_args(self, **overrides): + """Build a minimal args namespace that mimics a plain GPT/hybrid + training run. 
Any GPT-incompatible flags default to their + "off" value.""" + base = dict( + num_moe_experts=None, + moe_shared_expert_intermediate_size=None, + moe_layer_freq=1, + experimental_attention_variant=None, + linear_attention_freq=None, + heterogeneous_block_specs=False, + heterogeneous_layers_config_path=None, + heterogeneous_layers_config_encoded_json=None, + multi_latent_attention=False, + mtp_num_layers=None, + ) + base.update(overrides) + return argparse.Namespace(**base) + + def test_accepts_plain_gpt_args(self): + validate_source_args_gpt_compatible(self._ok_args(), 'gpt-to-mamba') + + def test_none_args_is_noop(self): + # Dist checkpoints sometimes have no cached args blob. + validate_source_args_gpt_compatible(None, 'gpt-to-mamba') + + def test_accepts_missing_optional_fields(self): + # Older checkpoints may not have every field; the validator should + # silently skip fields it doesn't find. + minimal = argparse.Namespace(num_moe_experts=None) + validate_source_args_gpt_compatible(minimal, 'mamba-to-gpt') + + def test_rejects_moe(self): + with pytest.raises(ValueError, match="MoE"): + validate_source_args_gpt_compatible( + self._ok_args(num_moe_experts=8), 'gpt-to-mamba' + ) + + def test_rejects_shared_expert(self): + with pytest.raises(ValueError, match="shared expert"): + validate_source_args_gpt_compatible( + self._ok_args(moe_shared_expert_intermediate_size=4096), + 'gpt-to-mamba', + ) + + def test_rejects_moe_layer_freq_list(self): + with pytest.raises(ValueError, match="MoE layers"): + validate_source_args_gpt_compatible( + self._ok_args(moe_layer_freq=[1, 0, 1, 0]), + 'gpt-to-mamba', + ) + + def test_accepts_moe_layer_freq_1(self): + validate_source_args_gpt_compatible( + self._ok_args(moe_layer_freq=1), 'gpt-to-mamba' + ) + + def test_rejects_experimental_attention(self): + with pytest.raises(ValueError, match="experimental attention"): + validate_source_args_gpt_compatible( + self._ok_args(experimental_attention_variant='gated_delta_net'), + 'gpt-to-mamba', + ) + + def test_rejects_linear_attention(self): + with pytest.raises(ValueError, match="linear attention"): + validate_source_args_gpt_compatible( + self._ok_args(linear_attention_freq=4), 'gpt-to-mamba' + ) + + def test_rejects_heterogeneous_block_specs(self): + with pytest.raises(ValueError, match="heterogeneous"): + validate_source_args_gpt_compatible( + self._ok_args(heterogeneous_block_specs=True), + 'mamba-to-gpt', + ) + + def test_rejects_heterogeneous_config_path(self): + with pytest.raises(ValueError, match="heterogeneous"): + validate_source_args_gpt_compatible( + self._ok_args(heterogeneous_layers_config_path='/tmp/x.json'), + 'gpt-to-mamba', + ) + + def test_rejects_mla(self): + with pytest.raises(ValueError, match="Multi-Latent"): + validate_source_args_gpt_compatible( + self._ok_args(multi_latent_attention=True), 'gpt-to-mamba' + ) + + def test_rejects_mtp(self): + with pytest.raises(ValueError, match="Multi-Token Prediction"): + validate_source_args_gpt_compatible( + self._ok_args(mtp_num_layers=2), 'gpt-to-mamba' + ) + + def test_reports_multiple_reasons(self): + # Both MoE and MLA set: the error should surface both. 
+ with pytest.raises(ValueError) as exc: + validate_source_args_gpt_compatible( + self._ok_args(num_moe_experts=8, multi_latent_attention=True), + 'gpt-to-mamba', + ) + msg = str(exc.value) + assert 'MoE' in msg + assert 'Multi-Latent' in msg diff --git a/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion_integration.py b/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion_integration.py new file mode 100644 index 00000000000..1dc499e13b9 --- /dev/null +++ b/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion_integration.py @@ -0,0 +1,467 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +""" +Integration tests for gpt_mamba_conversion.py. + +Creates minimal synthetic GPT checkpoints on disk, runs the full conversion +pipeline (load -> combine TP -> stitch PP -> convert -> split TP/PP -> save), +and verifies: + - Shapes, dtypes, and key names in the output checkpoint. + - Round-trip GPT -> Mamba -> GPT preserves attention and MLP weights exactly. + +Designed to run on a single-GPU node via SLURM (no distributed launch needed). +""" + +import argparse +import copy +import os +import shutil +import sys +import tempfile +from collections import OrderedDict +from types import SimpleNamespace + +import torch + +# Ensure the conversion tool is importable +sys.path.insert( + 0, + os.path.join(os.path.dirname(__file__), '..', '..', '..', '..', 'tools', 'checkpoint'), +) + +from gpt_mamba_conversion import ( + get_checkpoint_iteration, + initialize_ssm_layer_params, + main as conversion_main, + parse_hybrid_layer_pattern, +) + + +# --------------------------------------------------------------------------- +# Helpers: create a minimal on-disk GPT checkpoint +# --------------------------------------------------------------------------- + +def make_checkpoint_args( + num_layers=4, + hidden_size=128, + num_attention_heads=4, + seq_length=256, + max_position_embeddings=256, + tp_size=1, + pp_size=1, + iteration=100, +): + """Build a minimal checkpoint 'args' namespace mirroring Megatron's.""" + return SimpleNamespace( + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + ffn_hidden_size=hidden_size * 4, + seq_length=seq_length, + max_position_embeddings=max_position_embeddings, + tensor_model_parallel_size=tp_size, + pipeline_model_parallel_size=pp_size, + iteration=iteration, + consumed_train_samples=0, + consumed_valid_samples=0, + train_iters=1000, + train_samples=0, + tokenizer_type='GPT2BPETokenizer', + position_embedding_type='rope', + params_dtype=torch.float32, + fp16=False, + bf16=False, + ) + + +def make_gpt_state_dict(num_layers, hidden_size, vocab_size=1024, dtype=torch.float32): + """Create a minimal GPT model state dict.""" + sd = OrderedDict() + + sd['embedding.word_embeddings.weight'] = torch.randn(vocab_size, hidden_size, dtype=dtype) + + for i in range(num_layers): + p = f'decoder.layers.{i}.' 
+ # attention + sd[p + 'input_layernorm.weight'] = torch.randn(hidden_size, dtype=dtype) + sd[p + 'self_attention.linear_qkv.weight'] = torch.randn( + 3 * hidden_size, hidden_size, dtype=dtype + ) + sd[p + 'self_attention.linear_proj.weight'] = torch.randn( + hidden_size, hidden_size, dtype=dtype + ) + # MLP + sd[p + 'pre_mlp_layernorm.weight'] = torch.randn(hidden_size, dtype=dtype) + sd[p + 'mlp.linear_fc1.weight'] = torch.randn( + 4 * hidden_size, hidden_size, dtype=dtype + ) + sd[p + 'mlp.linear_fc2.weight'] = torch.randn( + hidden_size, 4 * hidden_size, dtype=dtype + ) + + sd['decoder.final_layernorm.weight'] = torch.randn(hidden_size, dtype=dtype) + sd['output_layer.weight'] = torch.randn(vocab_size, hidden_size, dtype=dtype) + + return sd + + +def write_checkpoint_to_disk(root_dir, state_dict, ckpt_args, iteration=100): + """Write a single-rank (TP=1, PP=1) checkpoint to disk in Megatron format. + + Directory layout: + root_dir/ + latest_checkpointed_iteration.txt -> "100" + iter_0000100/ + mp_rank_00/ + model_optim_rng.pt + """ + iter_dir = os.path.join(root_dir, f'iter_{iteration:07d}', 'mp_rank_00') + os.makedirs(iter_dir, exist_ok=True) + + checkpoint = { + 'model': state_dict, + 'args': copy.deepcopy(ckpt_args), + 'checkpoint_version': 3.0, + 'iteration': iteration, + 'rng_state': [ + { + 'random_rng_state': [0] * 625, + 'np_rng_state': ('MT19937', [0] * 625, 0, 0, 0.0), + 'torch_rng_state': torch.ByteTensor(8), + 'cuda_rng_state': torch.ByteTensor(8), + 'rng_tracker_states': {}, + } + ], + } + + torch.save(checkpoint, os.path.join(iter_dir, 'model_optim_rng.pt')) + + with open(os.path.join(root_dir, 'latest_checkpointed_iteration.txt'), 'w') as f: + f.write(str(iteration)) + + return root_dir + + +def load_converted_state_dict(ckpt_dir): + """Load the state dict from a converted checkpoint (TP=1, PP=1).""" + iteration = get_checkpoint_iteration(ckpt_dir) + model_file = os.path.join( + ckpt_dir, f'iter_{iteration:07d}', 'mp_rank_00', 'model_optim_rng.pt' + ) + checkpoint = torch.load(model_file, map_location='cpu', weights_only=False) + return checkpoint['model'], checkpoint['args'] + + +# --------------------------------------------------------------------------- +# Test 1: GPT -> Mamba shapes, dtypes, and key names +# --------------------------------------------------------------------------- + +def test_gpt_to_mamba_shapes_and_keys(): + """Create a 4-layer GPT ckpt, convert to Mamba with M*-M*-M*-M*-, verify output.""" + print("\n=== Test 1: GPT -> Mamba shapes, dtypes, and key names ===") + + num_layers = 4 + hidden_size = 128 + d_state = 16 + n_groups = 2 + head_dim = 32 + d_inner = hidden_size * 2 + n_heads = d_inner // head_dim + pattern = "M*-M*-M*-M*-" # 12 layers: 4 SSM, 4 attn, 4 MLP + + tmpdir = tempfile.mkdtemp(prefix='gpt_mamba_test_') + try: + src_dir = os.path.join(tmpdir, 'gpt_src') + dst_dir = os.path.join(tmpdir, 'mamba_dst') + + ckpt_args = make_checkpoint_args(num_layers=num_layers, hidden_size=hidden_size) + gpt_sd = make_gpt_state_dict(num_layers, hidden_size) + write_checkpoint_to_disk(src_dir, gpt_sd, ckpt_args) + + # Run conversion + args = argparse.Namespace( + direction='gpt-to-mamba', + load_dir=src_dir, + save_dir=dst_dir, + hybrid_layer_pattern=pattern, + target_tp_size=1, + target_pp_size=1, + d_model=hidden_size, + mamba_version=2, + mamba_d_state=d_state, + mamba2_n_groups=n_groups, + mamba2_head_dim=head_dim, + d_conv=4, + init_method_std=0.02, + reset_iterations=False, + ) + conversion_main(args) + + # Load and verify + mamba_sd, mamba_args = 
load_converted_state_dict(dst_dir) + + layer_types = parse_hybrid_layer_pattern(pattern) + total_layers = len(layer_types) + + # 1) Check total layer count in args + assert mamba_args.num_layers == total_layers, ( + f"Expected num_layers={total_layers}, got {mamba_args.num_layers}" + ) + + # 2) Check key names + assert 'decoder.final_norm.weight' in mamba_sd, "Missing decoder.final_norm.weight" + assert 'decoder.final_layernorm.weight' not in mamba_sd, "Old final_layernorm key present" + assert 'embedding.word_embeddings.weight' in mamba_sd + assert 'output_layer.weight' in mamba_sd + + # 3) Check SSM layer params exist with correct shapes + ssm_indices = [i for i, t in enumerate(layer_types) if t == 'M'] + conv_dim = d_inner + 2 * n_groups * d_state + in_proj_out = 2 * d_inner + 2 * n_groups * d_state + n_heads + + for idx in ssm_indices: + prefix = f'decoder.layers.{idx}.mixer.' + assert prefix + 'A_log' in mamba_sd, f"Missing {prefix}A_log" + assert mamba_sd[prefix + 'A_log'].shape == (n_heads,) + assert mamba_sd[prefix + 'A_log'].dtype == torch.float32 + + assert prefix + 'D' in mamba_sd + assert mamba_sd[prefix + 'D'].shape == (n_heads,) + + assert prefix + 'dt_bias' in mamba_sd + assert mamba_sd[prefix + 'dt_bias'].shape == (n_heads,) + + assert prefix + 'conv1d.weight' in mamba_sd + assert mamba_sd[prefix + 'conv1d.weight'].shape == (conv_dim, 1, 4) + + assert prefix + 'conv1d.bias' in mamba_sd + assert mamba_sd[prefix + 'conv1d.bias'].shape == (conv_dim,) + + assert prefix + 'in_proj.weight' in mamba_sd + assert mamba_sd[prefix + 'in_proj.weight'].shape == (in_proj_out, hidden_size) + + assert prefix + 'norm.weight' in mamba_sd + assert mamba_sd[prefix + 'norm.weight'].shape == (d_inner,) + + assert prefix + 'out_proj.weight' in mamba_sd + assert mamba_sd[prefix + 'out_proj.weight'].shape == (hidden_size, d_inner) + + # 4) Check attention layer params exist at correct indices + attn_indices = [i for i, t in enumerate(layer_types) if t == '*'] + for idx in attn_indices: + prefix = f'decoder.layers.{idx}.' + assert prefix + 'self_attention.linear_qkv.weight' in mamba_sd + assert mamba_sd[prefix + 'self_attention.linear_qkv.weight'].shape == ( + 3 * hidden_size, hidden_size + ) + + # 5) Check MLP layer params exist at correct indices + mlp_indices = [i for i, t in enumerate(layer_types) if t == '-'] + for idx in mlp_indices: + prefix = f'decoder.layers.{idx}.' 
+ assert prefix + 'mlp.linear_fc1.weight' in mamba_sd + assert mamba_sd[prefix + 'mlp.linear_fc1.weight'].shape == ( + 4 * hidden_size, hidden_size + ) + + print("PASSED: All shapes, dtypes, and key names verified.\n") + + finally: + shutil.rmtree(tmpdir, ignore_errors=True) + + +# --------------------------------------------------------------------------- +# Test 2: Round-trip GPT -> Mamba -> GPT weight preservation +# --------------------------------------------------------------------------- + +def test_roundtrip_weight_preservation(): + """Convert GPT -> Mamba -> GPT and verify attention/MLP weights match exactly.""" + print("\n=== Test 2: Round-trip GPT -> Mamba -> GPT weight preservation ===") + + num_layers = 2 + hidden_size = 64 + pattern = "M*-M*-" + + tmpdir = tempfile.mkdtemp(prefix='gpt_mamba_rt_test_') + try: + src_gpt_dir = os.path.join(tmpdir, 'gpt_src') + mamba_dir = os.path.join(tmpdir, 'mamba_mid') + dst_gpt_dir = os.path.join(tmpdir, 'gpt_dst') + + # Create and save source GPT checkpoint + ckpt_args = make_checkpoint_args(num_layers=num_layers, hidden_size=hidden_size) + gpt_sd = make_gpt_state_dict(num_layers, hidden_size) + write_checkpoint_to_disk(src_gpt_dir, gpt_sd, ckpt_args) + + common_args = dict( + hybrid_layer_pattern=pattern, + target_tp_size=1, + target_pp_size=1, + d_model=hidden_size, + mamba_version=2, + mamba_d_state=16, + mamba2_n_groups=2, + mamba2_head_dim=32, + d_conv=4, + init_method_std=0.02, + reset_iterations=False, + ) + + # Step 1: GPT -> Mamba + conversion_main(argparse.Namespace( + direction='gpt-to-mamba', + load_dir=src_gpt_dir, + save_dir=mamba_dir, + **common_args, + )) + + # Step 2: Mamba -> GPT + conversion_main(argparse.Namespace( + direction='mamba-to-gpt', + load_dir=mamba_dir, + save_dir=dst_gpt_dir, + **common_args, + )) + + # Load and compare + recovered_sd, recovered_args = load_converted_state_dict(dst_gpt_dir) + + assert recovered_args.num_layers == num_layers, ( + f"Expected num_layers={num_layers}, got {recovered_args.num_layers}" + ) + + # Compare every key in the original + mismatches = [] + for key, original_tensor in gpt_sd.items(): + # final_layernorm is renamed in the round trip + if key not in recovered_sd: + mismatches.append(f"MISSING: {key}") + continue + if not torch.equal(original_tensor, recovered_sd[key]): + max_diff = (original_tensor - recovered_sd[key]).abs().max().item() + mismatches.append(f"MISMATCH: {key} (max_diff={max_diff})") + + if mismatches: + for m in mismatches: + print(f" FAIL: {m}") + raise AssertionError( + f"Round-trip failed with {len(mismatches)} mismatches:\n" + + "\n".join(mismatches) + ) + + print("PASSED: All attention and MLP weights preserved exactly.\n") + + finally: + shutil.rmtree(tmpdir, ignore_errors=True) + + +# --------------------------------------------------------------------------- +# Test 3: Verify that Mamba -> GPT discards SSM params cleanly +# --------------------------------------------------------------------------- + +def test_mamba_to_gpt_discards_ssm(): + """Convert Mamba -> GPT and verify no SSM keys leak through.""" + print("\n=== Test 3: Mamba -> GPT discards SSM params ===") + + hidden_size = 64 + pattern = "M*-M*-" + d_inner = hidden_size * 2 + d_state = 16 + n_groups = 2 + head_dim = 32 + n_heads = d_inner // head_dim + + tmpdir = tempfile.mkdtemp(prefix='gpt_mamba_discard_test_') + try: + # Build a Mamba-style state dict + mamba_sd = OrderedDict() + mamba_sd['embedding.word_embeddings.weight'] = torch.randn(512, hidden_size) + mamba_sd['output_layer.weight'] = 
torch.randn(512, hidden_size) + mamba_sd['decoder.final_norm.weight'] = torch.randn(hidden_size) + + layer_types = parse_hybrid_layer_pattern(pattern) + for i, lt in enumerate(layer_types): + p = f'decoder.layers.{i}.' + if lt == 'M': + ssm = initialize_ssm_layer_params( + i, hidden_size, d_inner, d_state, n_groups, n_heads, head_dim + ) + mamba_sd.update(ssm) + elif lt == '*': + mamba_sd[p + 'input_layernorm.weight'] = torch.randn(hidden_size) + mamba_sd[p + 'self_attention.linear_qkv.weight'] = torch.randn( + 3 * hidden_size, hidden_size + ) + mamba_sd[p + 'self_attention.linear_proj.weight'] = torch.randn( + hidden_size, hidden_size + ) + elif lt == '-': + mamba_sd[p + 'pre_mlp_layernorm.weight'] = torch.randn(hidden_size) + mamba_sd[p + 'mlp.linear_fc1.weight'] = torch.randn( + 4 * hidden_size, hidden_size + ) + mamba_sd[p + 'mlp.linear_fc2.weight'] = torch.randn( + hidden_size, 4 * hidden_size + ) + + # Write to disk + src_dir = os.path.join(tmpdir, 'mamba_src') + dst_dir = os.path.join(tmpdir, 'gpt_dst') + ckpt_args = make_checkpoint_args( + num_layers=len(layer_types), hidden_size=hidden_size + ) + write_checkpoint_to_disk(src_dir, mamba_sd, ckpt_args) + + # Convert + conversion_main(argparse.Namespace( + direction='mamba-to-gpt', + load_dir=src_dir, + save_dir=dst_dir, + hybrid_layer_pattern=pattern, + target_tp_size=1, + target_pp_size=1, + d_model=hidden_size, + mamba_version=2, + mamba_d_state=d_state, + mamba2_n_groups=n_groups, + mamba2_head_dim=head_dim, + d_conv=4, + init_method_std=0.02, + reset_iterations=False, + )) + + gpt_sd, gpt_args = load_converted_state_dict(dst_dir) + + # Verify no SSM keys + ssm_keys = [k for k in gpt_sd if 'mixer.' in k] + assert len(ssm_keys) == 0, f"SSM keys leaked: {ssm_keys}" + + # Verify correct GPT layer count + assert gpt_args.num_layers == 2 + + # Verify final_layernorm renamed back + assert 'decoder.final_layernorm.weight' in gpt_sd + assert 'decoder.final_norm.weight' not in gpt_sd + + print("PASSED: No SSM keys in GPT output, norms renamed correctly.\n") + + finally: + shutil.rmtree(tmpdir, ignore_errors=True) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +if __name__ == '__main__': + print("=" * 60) + print("GPT <-> Mamba Conversion Integration Tests") + print("=" * 60) + + test_gpt_to_mamba_shapes_and_keys() + test_roundtrip_weight_preservation() + test_mamba_to_gpt_discards_ssm() + + print("=" * 60) + print("ALL INTEGRATION TESTS PASSED") + print("=" * 60) diff --git a/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion_parallelism.py b/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion_parallelism.py new file mode 100644 index 00000000000..04ecb08a532 --- /dev/null +++ b/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion_parallelism.py @@ -0,0 +1,360 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +""" +Parallelism-matrix integration tests for gpt_mamba_conversion.py. 
+ +Covers every combination the user cares about: + + source format exercises + TP TP=2, PP=1 legacy TP-combine + TP-split + PP TP=1, PP=2 legacy PP-stitch + FSDP world=1 dist (torch_dist) DCP load + DCP save + TP+PP TP=2, PP=2 legacy TP+PP both paths + TP+FSDP world=1 dist DCP load + DCP save + PP+FSDP world=1 dist DCP load + DCP save + TP+PP+FSDP world=1 dist DCP load + DCP save + +Legacy configs synthesize ``mp_rank_XX[_YYY]/model_optim_rng.pt`` shards by +re-using the converter's own save routine (which implements the exact TP-split +and PP-stitch layout Megatron produces). Dist configs synthesize a DCP +checkpoint via a single-rank ``torch.distributed.checkpoint.save``; at the +converter level the TP/PP/FSDP sharding layout of a dist checkpoint is +abstracted away by DCP's global-shape metadata, so one save code path +exercises every ``*+FSDP`` combination. Each config is run as a distinct test +to document the matrix and catch regressions in the dispatch logic. + +Designed to run on a single-GPU node via SLURM; no torchrun needed. +""" + +import argparse +import copy +import os +import shutil +import sys +import tempfile +from collections import OrderedDict +from types import SimpleNamespace + +import torch + +# Make the conversion tool and the sibling integration-test helpers importable +# under both `python ` and `pytest` (pytest doesn't put the test file's +# directory on sys.path). +_THIS_DIR = os.path.dirname(os.path.abspath(__file__)) +_REPO_ROOT = os.path.join(_THIS_DIR, '..', '..', '..', '..') +sys.path.insert(0, os.path.join(_REPO_ROOT, 'tools', 'checkpoint')) +sys.path.insert(0, _THIS_DIR) + +from gpt_mamba_conversion import ( + combine_tp_shards, + convert_gpt_to_mamba, + get_checkpoint_iteration, + load_checkpoint_shards, + main as conversion_main, + parse_hybrid_layer_pattern, + save_checkpoint_shards, + stitch_pp_shards, +) + +from test_gpt_mamba_conversion_integration import ( + make_checkpoint_args, + make_gpt_state_dict, +) + + +# --------------------------------------------------------------------------- +# Legacy (mp_rank_XX) fixture builders +# --------------------------------------------------------------------------- + +def _save_legacy_sharded(root_dir, full_sd, ckpt_args, tp_size, pp_size, + hybrid_layer_pattern='', + hidden_size=128, + iteration=100): + """Write a full state dict to disk as a sharded legacy checkpoint. + + We delegate to ``save_checkpoint_shards`` so the on-disk layout matches + exactly what Megatron training would produce at the given TP/PP. + """ + # save_checkpoint_shards expects a "sample_model" shape that mirrors a + # single rank's on-disk file. Any args object with the target fields works. + ckpt_args = copy.deepcopy(ckpt_args) + ckpt_args.tensor_model_parallel_size = tp_size + ckpt_args.pipeline_model_parallel_size = pp_size + sample_model = { + 'args': ckpt_args, + 'checkpoint_version': 3.0, + 'iteration': iteration, + 'rng_state': [], + } + params = SimpleNamespace( + target_tp_size=tp_size, + target_pp_size=pp_size, + target_num_layers=ckpt_args.num_layers, + reset_iterations=False, + # Mamba-only TP-split args; irrelevant for pure GPT shards but required. 
+ mamba_version=2, + mamba_d_inner=hidden_size * 2, + mamba_d_state=16, + mamba2_n_groups=2, + mamba2_n_heads=hidden_size * 2 // 32, + ) + save_checkpoint_shards(full_sd, sample_model, params, root_dir, iteration) + + +# --------------------------------------------------------------------------- +# Dist (torch_dist / fsdp_dtensor) fixture builders +# --------------------------------------------------------------------------- + +def _save_dist_checkpoint(root_dir, full_sd, ckpt_args, iteration=100, + prefix='model.', backend='torch_dist'): + """Write a full state dict as a single-rank DCP checkpoint. + + From the converter's POV, this is indistinguishable from a multi-rank + TP+PP+FSDP save: DCP stores each tensor's global shape in its metadata + and the read planner reassembles the full tensor regardless of how many + processes wrote it. + """ + from dist_checkpoint_io import ( + ensure_single_rank_process_group, + save_dist_checkpoint_full, + write_latest_iteration_marker, + ) + + ensure_single_rank_process_group() + + iter_dir = os.path.join(root_dir, f'iter_{iteration:07d}') + common_state = { + 'args': copy.deepcopy(ckpt_args), + 'checkpoint_version': 3.0, + 'iteration': iteration, + } + save_dist_checkpoint_full( + full_sd, common_state, iter_dir, + model_prefix=prefix, backend=backend, + ) + write_latest_iteration_marker(iter_dir, iteration) + + +# --------------------------------------------------------------------------- +# Output readers +# --------------------------------------------------------------------------- + +def _load_converted_dist(ckpt_dir): + """Read a dist-format converted checkpoint back into a full state dict.""" + from dist_checkpoint_io import load_dist_checkpoint_full + sd, common, prefix, backend, iteration = load_dist_checkpoint_full(ckpt_dir) + return sd, common.get('args', None) + + +def _load_converted_legacy_full(ckpt_dir): + """Read a legacy TP+PP-sharded converted checkpoint into a full state dict. + + Peeks the first shard to discover TP/PP sizes and total layers, then reuses + the converter's own load / TP-combine / PP-stitch routines. + """ + iteration = get_checkpoint_iteration(ckpt_dir) + model_dir = os.path.join(ckpt_dir, f'iter_{iteration:07d}') + first_shard = sorted(os.listdir(model_dir))[0] + sample = torch.load( + os.path.join(model_dir, first_shard, 'model_optim_rng.pt'), + map_location='cpu', weights_only=False, + ) + tp_size = sample['args'].tensor_model_parallel_size + pp_size = sample['args'].pipeline_model_parallel_size + num_layers = sample['args'].num_layers + num_layers_per_pp_rank = num_layers // pp_size + + all_shards, sample = load_checkpoint_shards( + ckpt_dir, iteration, tp_size, pp_size, + ) + # combine_tp_tensors only touches mamba-specific branches for mamba keys; + # any hidden_size-consistent defaults work for GPT-only outputs. 
+ combine_params = SimpleNamespace( + mamba_version=2, + mamba_d_inner=0, + mamba_d_state=0, + mamba2_n_groups=0, + mamba2_n_heads=0, + ) + combined_pp = [combine_tp_shards(all_shards[pp], combine_params) + for pp in range(pp_size)] + full = stitch_pp_shards(combined_pp, num_layers_per_pp_rank) + return full, sample['args'] + + +def _load_converted(ckpt_dir, output_format): + if output_format == 'legacy': + return _load_converted_legacy_full(ckpt_dir) + return _load_converted_dist(ckpt_dir) + + +# --------------------------------------------------------------------------- +# Core scenario runner +# --------------------------------------------------------------------------- + +def _run_scenario( + label, + source_format, + source_tp, + source_pp, + target_format, + target_tp=1, + target_pp=1, + num_layers=4, + hidden_size=128, + pattern="M*-M*-M*-M*-", + source_prefix='model.', +): + """Build a GPT source ckpt, convert GPT->Mamba->GPT, verify round-trip.""" + print(f"\n=== {label} ===") + print(f" source={source_format} (tp={source_tp}, pp={source_pp}, prefix='{source_prefix}')") + print(f" target={target_format} (tp={target_tp}, pp={target_pp})") + + tmpdir = tempfile.mkdtemp(prefix=f'gpt_mamba_{label.replace(" ", "_")}_') + try: + src_gpt_dir = os.path.join(tmpdir, 'gpt_src') + mamba_dir = os.path.join(tmpdir, 'mamba_mid') + dst_gpt_dir = os.path.join(tmpdir, 'gpt_dst') + + # --- Build source --- + ckpt_args = make_checkpoint_args( + num_layers=num_layers, hidden_size=hidden_size, + tp_size=source_tp, pp_size=source_pp, + ) + gpt_sd = make_gpt_state_dict(num_layers, hidden_size) + + if source_format == 'legacy': + _save_legacy_sharded( + src_gpt_dir, gpt_sd, ckpt_args, source_tp, source_pp, + hidden_size=hidden_size, + ) + else: + _save_dist_checkpoint( + src_gpt_dir, gpt_sd, ckpt_args, + prefix=source_prefix, backend=source_format, + ) + + common_kwargs = dict( + hybrid_layer_pattern=pattern, + target_tp_size=target_tp, + target_pp_size=target_pp, + d_model=hidden_size, + mamba_version=2, + mamba_d_state=16, + mamba2_n_groups=2, + mamba2_head_dim=32, + d_conv=4, + init_method_std=0.02, + reset_iterations=False, + input_format='auto', + output_format=target_format, + ) + + # --- GPT -> Mamba --- + conversion_main(argparse.Namespace( + direction='gpt-to-mamba', + load_dir=src_gpt_dir, + save_dir=mamba_dir, + **common_kwargs, + )) + + # --- Mamba -> GPT --- + conversion_main(argparse.Namespace( + direction='mamba-to-gpt', + load_dir=mamba_dir, + save_dir=dst_gpt_dir, + **common_kwargs, + )) + + # --- Verify --- + recovered_sd, recovered_args = _load_converted(dst_gpt_dir, target_format) + layer_types = parse_hybrid_layer_pattern(pattern) + + mismatches = [] + for key, original in gpt_sd.items(): + if key not in recovered_sd: + mismatches.append(f"MISSING: {key}") + continue + if not torch.equal(original, recovered_sd[key]): + max_diff = (original - recovered_sd[key]).abs().max().item() + mismatches.append(f"MISMATCH: {key} (max_diff={max_diff})") + + if mismatches: + for m in mismatches[:10]: + print(f" FAIL: {m}") + raise AssertionError( + f"{label} failed with {len(mismatches)} weight mismatches" + ) + + # SSM keys must be absent in the final GPT output. + assert not any('mixer.' in k for k in recovered_sd), \ + f"SSM keys leaked into final GPT output: " \ + f"{[k for k in recovered_sd if 'mixer.' 
in k][:5]}" + + print(f"PASSED: {label}") + finally: + shutil.rmtree(tmpdir, ignore_errors=True) + + +# --------------------------------------------------------------------------- +# Test cases — one per parallelism combo +# --------------------------------------------------------------------------- + +def test_tp_only_legacy(): + _run_scenario("TP only (legacy)", 'legacy', 2, 1, 'legacy', target_tp=2, target_pp=1) + + +def test_pp_only_legacy(): + _run_scenario("PP only (legacy)", 'legacy', 1, 2, 'legacy', target_tp=1, target_pp=2) + + +def test_tp_pp_legacy(): + _run_scenario("TP+PP (legacy)", 'legacy', 2, 2, 'legacy', target_tp=2, target_pp=2) + + +def test_fsdp_only_dist(): + _run_scenario("FSDP only (torch_dist)", 'torch_dist', 1, 1, 'torch_dist') + + +def test_tp_fsdp_dist(): + _run_scenario("TP + FSDP (torch_dist)", 'torch_dist', 1, 1, 'torch_dist') + + +def test_pp_fsdp_dist(): + _run_scenario("PP + FSDP (torch_dist)", 'torch_dist', 1, 1, 'torch_dist') + + +def test_tp_pp_fsdp_dist(): + _run_scenario("TP+PP+FSDP (torch_dist)", 'torch_dist', 1, 1, 'torch_dist') + + +def test_fsdp_dtensor_prefix(): + """fsdp_dtensor backend uses the 'model.module.' key prefix — verify we + auto-detect and strip it correctly.""" + _run_scenario( + "FSDP dtensor prefix", 'fsdp_dtensor', 1, 1, 'fsdp_dtensor', + source_prefix='model.module.', + ) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +if __name__ == '__main__': + print("=" * 60) + print("GPT <-> Mamba Conversion Parallelism Matrix Tests") + print("=" * 60) + + test_tp_only_legacy() + test_pp_only_legacy() + test_tp_pp_legacy() + test_fsdp_only_dist() + test_tp_fsdp_dist() + test_pp_fsdp_dist() + test_tp_pp_fsdp_dist() + test_fsdp_dtensor_prefix() + + print("=" * 60) + print("ALL PARALLELISM MATRIX TESTS PASSED") + print("=" * 60) diff --git a/tools/checkpoint/dist_checkpoint_io.py b/tools/checkpoint/dist_checkpoint_io.py new file mode 100644 index 00000000000..1fece3a3f91 --- /dev/null +++ b/tools/checkpoint/dist_checkpoint_io.py @@ -0,0 +1,264 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +""" +Distributed checkpoint I/O helpers for structural model-conversion tools. + +Provides format detection, model-free full-tensor loading, and single-rank +saving for Megatron-LM distributed checkpoints (``torch_dist`` and +``fsdp_dtensor`` backends). This lets conversion tools operate on +TP+PP+FSDP-trained checkpoints without needing to instantiate the model. + +The key observation is that PyTorch DCP stores each logical parameter with a +``global_shape`` in its metadata, and the TP / PP / FSDP slicing is just an +on-disk layout detail handled by the read planner. Loading into a plain +``torch.empty(global_shape)`` state dict on rank 0 therefore yields fully +gathered tensors regardless of the parallelism the checkpoint was trained with. 
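+
+A typical conversion-tool flow looks like this (a minimal sketch; the
+directory paths are placeholders):
+
+    from dist_checkpoint_io import (
+        load_dist_checkpoint_full,
+        save_dist_checkpoint_full,
+        write_latest_iteration_marker,
+    )
+
+    sd, common, prefix, backend, it = load_dist_checkpoint_full('/ckpts/src')
+    # ... structurally edit `sd` here ...
+    save_dist_checkpoint_full(sd, common, '/ckpts/dst/iter_0000100',
+                              model_prefix=prefix, backend=backend)
+    write_latest_iteration_marker('/ckpts/dst/iter_0000100', it or 100)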
+""" + +import os +from collections import OrderedDict +from pathlib import Path + +import torch +import torch.distributed as dist +import torch.distributed.checkpoint as dcp +from torch.distributed.checkpoint import ( + DefaultLoadPlanner, + DefaultSavePlanner, + FileSystemReader, + FileSystemWriter, +) +from torch.distributed.checkpoint.metadata import ( + BytesStorageMetadata, + TensorStorageMetadata, +) + +from megatron.core.dist_checkpointing.core import ( + CheckpointingConfig, + maybe_load_config, + save_config, +) +from megatron.core.dist_checkpointing.strategies.common import ( + load_common, + save_common, +) + + +FORMAT_LEGACY = 'legacy' +FORMAT_TORCH_DIST = 'torch_dist' +FORMAT_FSDP_DTENSOR = 'fsdp_dtensor' +DIST_FORMATS = (FORMAT_TORCH_DIST, FORMAT_FSDP_DTENSOR) + +# Prefixes under which model weights may be keyed in a dist checkpoint. +_KNOWN_MODEL_PREFIXES = ('model.module.module.', 'model.module.', 'model.', '') +# Well-known bare-key suffixes we probe for when detecting the prefix. +_PROBE_SUFFIXES = ( + 'embedding.word_embeddings.weight', + 'decoder.layers.', + 'decoder.final_norm.', + 'decoder.final_layernorm.', + 'output_layer.weight', +) +# Keys that identify non-model state we drop during architecture conversion. +_NON_MODEL_TOP_LEVEL_PREFIXES = ( + 'optimizer.', + 'rng_state', + 'rerun_state_machine_state', +) + + +def resolve_checkpoint_subdir(load_dir): + """Return ``(ckpt_dir, iteration)``. + + Megatron writes checkpoints either flat or under ``iter_XXXXXXX/``. This + picks the right directory and reports the iteration when it can be + determined. + """ + if os.path.exists(os.path.join(load_dir, 'metadata.json')): + return load_dir, None + + latest_iter = os.path.join(load_dir, 'latest_checkpointed_iteration.txt') + if os.path.exists(latest_iter): + with open(latest_iter, 'r') as f: + iteration = int(f.read().strip()) + iter_dir = os.path.join(load_dir, f'iter_{iteration:07d}') + if os.path.isdir(iter_dir): + return iter_dir, iteration + + return load_dir, None + + +def detect_checkpoint_format(load_dir): + """Return one of ``{'legacy', 'torch_dist', 'fsdp_dtensor'}``.""" + ckpt_dir, _ = resolve_checkpoint_subdir(load_dir) + config = maybe_load_config(ckpt_dir) + if config is not None: + return config.sharded_backend + + if os.path.isdir(ckpt_dir) and any( + name.startswith('mp_rank_') for name in os.listdir(ckpt_dir) + ): + return FORMAT_LEGACY + + raise ValueError(f"Unrecognized checkpoint format at {load_dir}") + + +def ensure_single_rank_process_group(): + """Initialize a 1-rank gloo process group if one isn't already up. + + DCP requires a default process group; this lets the conversion tool run + in a plain ``python`` invocation (no ``torchrun`` needed). + """ + if not dist.is_available(): + raise RuntimeError("torch.distributed is not available.") + if dist.is_initialized(): + return + os.environ.setdefault('MASTER_ADDR', '127.0.0.1') + os.environ.setdefault('MASTER_PORT', '29500') + os.environ.setdefault('RANK', '0') + os.environ.setdefault('WORLD_SIZE', '1') + os.environ.setdefault('LOCAL_RANK', '0') + dist.init_process_group(backend='gloo', rank=0, world_size=1) + + +def detect_model_prefix(keys): + """Return the prefix under which model weights live in ``keys``. + + Looks for a recognizable suffix (``embedding.word_embeddings.weight``, + ``decoder.layers.``, etc.) and returns the matching prefix. Falls back + to ``''`` if nothing obvious is found. 
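+
+    For example, a key list containing the (hypothetical) entry
+    ``model.module.decoder.layers.0.mlp.linear_fc1.weight`` yields
+    ``'model.module.'``, while plain ``decoder.layers.0. ...`` keys yield ``''``.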
+ """ + keys = list(keys) + for prefix in _KNOWN_MODEL_PREFIXES: + for suffix in _PROBE_SUFFIXES: + probe = prefix + suffix + for key in keys: + if key.startswith(probe): + return prefix + return '' + + +def _is_non_model_key(bare_key): + if bare_key.startswith(_NON_MODEL_TOP_LEVEL_PREFIXES): + return True + # _extra_state blobs are TE per-module state; they are tied to a specific + # TP/parallelism configuration and aren't meaningful after a structural + # model conversion, so we drop them. + if '_extra_state' in bare_key: + return True + return False + + +def load_dist_checkpoint_full(load_dir): + """Load a dist checkpoint and return fully-gathered model weights. + + Returns: + model_state_dict (OrderedDict[str, torch.Tensor]): bare keys, full + tensors on CPU. Optimizer state, RNG state, and ``_extra_state`` + blobs are filtered out. + common_state (dict): contents of ``common.pt`` (e.g. ``args``). + model_prefix (str): the prefix we stripped (re-apply on save). + backend (str): ``'torch_dist'`` or ``'fsdp_dtensor'``. + iteration (int or None): iteration number if discoverable. + """ + ensure_single_rank_process_group() + + ckpt_dir, iteration = resolve_checkpoint_subdir(load_dir) + config = maybe_load_config(ckpt_dir) + if config is None: + raise ValueError( + f"{load_dir} is not a distributed checkpoint (no metadata.json)" + ) + backend = config.sharded_backend + + reader = FileSystemReader(ckpt_dir) + metadata = reader.read_metadata() + + model_prefix = detect_model_prefix(metadata.state_dict_metadata.keys()) + + raw_state_dict = {} + for key, md in metadata.state_dict_metadata.items(): + if not isinstance(md, TensorStorageMetadata): + continue + if model_prefix and not key.startswith(model_prefix): + continue + bare_key = key[len(model_prefix):] if model_prefix else key + if _is_non_model_key(bare_key): + continue + raw_state_dict[key] = torch.empty( + md.size, dtype=md.properties.dtype, device='cpu' + ) + + if not raw_state_dict: + raise ValueError( + f"No model tensors found in {ckpt_dir} (detected prefix " + f"'{model_prefix}', backend '{backend}')." + ) + + dcp.load(raw_state_dict, storage_reader=reader, planner=DefaultLoadPlanner()) + + model_state_dict = OrderedDict() + for key, tensor in raw_state_dict.items(): + bare_key = key[len(model_prefix):] if model_prefix else key + model_state_dict[bare_key] = tensor + + common_state = {} + try: + common_state = load_common(ckpt_dir) + except Exception: + pass + + return model_state_dict, common_state, model_prefix, backend, iteration + + +def save_dist_checkpoint_full( + model_state_dict, + common_state, + save_dir, + model_prefix='model.', + backend=FORMAT_TORCH_DIST, +): + """Save a fully-gathered state dict as a distributed checkpoint. + + The output is written as a single-rank, fully-replicated DCP checkpoint + plus ``common.pt`` and ``metadata.json``. A downstream Megatron training + job reads it back through ``dist_checkpointing.load()`` with its own + sharded_state_dict template — TP+PP+FSDP resharding happens transparently + on load, since the on-disk tensors carry their full logical shape. 
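+
+    The checkpointing config written at the end records ``backend``, so a later
+    ``detect_checkpoint_format`` call on the output directory reports the same
+    backend that was requested here.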
+ """ + ensure_single_rank_process_group() + + os.makedirs(save_dir, exist_ok=True) + + raw_state_dict = OrderedDict() + for bare_key, tensor in model_state_dict.items(): + full_key = f"{model_prefix}{bare_key}" if model_prefix else bare_key + raw_state_dict[full_key] = tensor.contiguous() if tensor.is_contiguous() else tensor.contiguous() + + writer = FileSystemWriter(save_dir) + dcp.save( + state_dict=raw_state_dict, + storage_writer=writer, + planner=DefaultSavePlanner(), + ) + + if common_state: + save_common(common_state, save_dir) + + if dist.get_rank() == 0: + save_config(CheckpointingConfig(sharded_backend=backend), save_dir) + dist.barrier() + + +def write_latest_iteration_marker(save_dir, iteration): + """Mirror the legacy ``latest_checkpointed_iteration.txt`` convention. + + When ``save_dir`` points at a top-level checkpoint root with an + ``iter_XXXXXXX/`` subdirectory, the tracker file lets Megatron auto-find + the latest iteration on load. + """ + parent = os.path.dirname(save_dir.rstrip('/')) or save_dir + if os.path.basename(save_dir.rstrip('/')).startswith('iter_'): + tracker = os.path.join(parent, 'latest_checkpointed_iteration.txt') + with open(tracker, 'w') as f: + f.write(str(iteration)) diff --git a/tools/checkpoint/gpt_mamba_conversion.py b/tools/checkpoint/gpt_mamba_conversion.py new file mode 100644 index 00000000000..ff3d5851b4c --- /dev/null +++ b/tools/checkpoint/gpt_mamba_conversion.py @@ -0,0 +1,1392 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +""" +GPT <-> Mamba Checkpoint Conversion Tool +========================================= + +Directly converts checkpoints between GPTModel (homogeneous Transformer) and +MambaModel (hybrid Mamba+Transformer) without going through HuggingFace as an +intermediary. + +Supported directions: + gpt-to-mamba : Convert a GPT checkpoint to Mamba hybrid format. + mamba-to-gpt : Convert a Mamba hybrid checkpoint to GPT format. + +How the hybrid layer pattern maps GPT layers (gpt-to-mamba): + - Each GPT layer contains both attention and MLP sub-layers. + - The target Mamba model's hybrid_layer_pattern specifies per-layer types: + M = Mamba SSM layer + * = Attention-only layer + - = MLP-only layer + G = GDN layer + E = MoE layer + - GPT layer i's attention params map to the i-th '*' layer in the pattern. + - GPT layer i's MLP params map to the i-th '-' layer in the pattern. + - The number of '*' and '-' layers in the pattern must both equal the number + of GPT layers. + - Mamba SSM ('M') layers have no GPT equivalent and are initialized from + scratch using standard Mamba initialization. + +What happens to SSM parameters: + gpt-to-mamba: SSM layers (M) are initialized from scratch: + - A_log: log(uniform(1, 16)) + - dt_bias: inverse_softplus(log_uniform(dt_min, dt_max)) + - D: ones + - conv1d.weight: kaiming_uniform(a=sqrt(5)) + - conv1d.bias: zeros + - in_proj.weight: kaiming_uniform(a=sqrt(5)) + - in_proj.layer_norm_weight: ones + - out_proj.weight: kaiming_uniform(a=sqrt(5)) + - norm.weight: ones + mamba-to-gpt: SSM layers are discarded with a warning. + +Supported checkpoint formats: + - legacy : mp_rank_XX[_YYY]/model_optim_rng.pt (TP + PP, no FSDP). + - torch_dist : Megatron distributed checkpoint (TP + PP + FSDP). + - fsdp_dtensor : FSDP DTensor export (TP + PP + FSDP). + + For distributed formats, PyTorch DCP gathers TP/PP/FSDP shards via the + checkpoint's global-shape metadata, so no explicit TP/PP/DP config is + needed on input. 
The input format is auto-detected; the output format + defaults to the input format. + +GPT compatibility whitelist (safeguard): + GPTModel is a strict homogeneous transformer (self-attention + MLP per + layer, standard linear_qkv / linear_fc1 / linear_fc2 state-dict keys). + The converter fails fast if either side uses features that GPTModel + cannot express. + + Rejected pattern symbols: 'G' (GDN), 'D' (DS-attention), 'E' (MoE). + Allowed: 'M' (Mamba SSM), '*' (attention), '-' (MLP). The number of + '*' and '-' layers must be equal. + + Rejected source-args features (checked against the args stored in the + source checkpoint): + - num_moe_experts / moe_shared_expert_intermediate_size / moe_layer_freq + - experimental_attention_variant (gated_delta_net, dsa, ...) + - linear_attention_freq + - heterogeneous_block_specs / heterogeneous_layers_config_path + - multi_latent_attention (MLA) + - mtp_num_layers (Multi-Token Prediction) + + See `validate_pattern_gpt_compatible` and + `validate_source_args_gpt_compatible` for the exact rules. + +Example commands: + # GPT -> Mamba (legacy TP+PP checkpoint) + python tools/checkpoint/gpt_mamba_conversion.py \\ + --direction gpt-to-mamba \\ + --load-dir /path/to/gpt-checkpoint \\ + --save-dir /path/to/mamba-checkpoint \\ + --hybrid-layer-pattern "M*-M*-M*-M*-" \\ + --target-tp-size 1 \\ + --target-pp-size 1 \\ + --d-model 4096 \\ + --mamba-d-state 128 \\ + --mamba2-n-groups 8 \\ + --mamba2-head-dim 64 + + # GPT -> Mamba (TP+PP+FSDP dist checkpoint) + python tools/checkpoint/gpt_mamba_conversion.py \\ + --direction gpt-to-mamba \\ + --load-dir /path/to/gpt-dist-checkpoint \\ + --save-dir /path/to/mamba-dist-checkpoint \\ + --hybrid-layer-pattern "M*-M*-M*-M*-" \\ + --d-model 4096 \\ + --mamba-d-state 128 \\ + --mamba2-n-groups 8 \\ + --mamba2-head-dim 64 + + # Mamba -> GPT (legacy) + python tools/checkpoint/gpt_mamba_conversion.py \\ + --direction mamba-to-gpt \\ + --load-dir /path/to/mamba-checkpoint \\ + --save-dir /path/to/gpt-checkpoint \\ + --hybrid-layer-pattern "M*-M*-M*-M*-" \\ + --target-tp-size 1 \\ + --target-pp-size 1 \\ + --d-model 4096 \\ + --mamba-d-state 128 \\ + --mamba2-n-groups 8 \\ + --mamba2-head-dim 64 +""" + +import argparse +import copy +import math +import os +import re +from collections import OrderedDict + +import torch + +from dist_checkpoint_io import ( + DIST_FORMATS, + FORMAT_LEGACY, + FORMAT_TORCH_DIST, + detect_checkpoint_format, + load_dist_checkpoint_full, + save_dist_checkpoint_full, + write_latest_iteration_marker, +) + + +# --------------------------------------------------------------------------- +# TP split-dim mapping (reused from hybrid_conversion.py) +# --------------------------------------------------------------------------- + +# Maps parameter-name substrings to the tensor dimension along which they are +# sharded across TP ranks. -1 means "replicated" (not sharded). 
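+# For example, column-parallel weights such as 'linear_fc1.weight' and
+# 'self_attention.linear_qkv.weight' are split along dim 0 (output features),
+# row-parallel weights such as 'linear_fc2.weight' and
+# 'self_attention.linear_proj.weight' along dim 1 (input features), and layer
+# norms are replicated (-1).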
+TP_SPLIT_DIM = { + # embeddings / output + 'word_embeddings.weight': 0, + 'output_layer.weight': 0, + # norms (replicated) + 'norm.weight': -1, + 'final_norm.weight': -1, + 'final_layernorm.weight': -1, + 'final_layernorm.bias': -1, + # mamba SSM params + 'A_log': 0, + 'D': 0, + 'dt_bias': 0, + 'in_proj.weight': 0, + 'conv1d.weight': 0, + 'conv1d.bias': 0, + 'x_proj.weight': 1, + 'dt_proj.weight': 0, + 'dt_proj.bias': 0, + 'out_proj.weight': 1, + 'mixer.norm.weight': 0, + # MLP (transformer-style) + 'linear_fc1.layer_norm_weight': -1, + 'linear_fc1.weight': 0, + 'linear_fc2.weight': 1, + # attention (transformer-style) + 'self_attention.linear_proj.weight': 1, + 'self_attention.linear_qkv.layer_norm_weight': -1, + 'self_attention.linear_qkv.weight': 0, + # standalone layer norms (used in non-TE / "local" transformer impl) + 'input_layernorm.weight': -1, + 'input_layernorm.bias': -1, + 'pre_mlp_layernorm.weight': -1, + 'pre_mlp_layernorm.bias': -1, + # TE-fused layer norms in Mamba in_proj + 'in_proj.layer_norm_weight': -1, + 'in_proj.layer_norm_bias': -1, +} + + +def get_split_dim(tensor_name): + """Determine the TP-split dimension for a given parameter name.""" + # Disambiguate mixer.norm.weight vs generic norm.weight + if 'norm.weight' in tensor_name: + if 'mixer.norm.weight' in tensor_name: + return TP_SPLIT_DIM['mixer.norm.weight'] + elif 'final_norm.weight' in tensor_name: + return TP_SPLIT_DIM['final_norm.weight'] + elif 'final_layernorm.weight' in tensor_name: + return TP_SPLIT_DIM['final_layernorm.weight'] + elif 'layer_norm_weight' in tensor_name: + # TE-fused layer norm weights + for key in TP_SPLIT_DIM: + if key in tensor_name: + return TP_SPLIT_DIM[key] + return -1 + else: + return TP_SPLIT_DIM['norm.weight'] + + for key in TP_SPLIT_DIM: + if key in tensor_name: + return TP_SPLIT_DIM[key] + raise ValueError(f"Unknown tensor name for TP splitting: {tensor_name}") + + +# --------------------------------------------------------------------------- +# TP combine / split (reused from hybrid_conversion.py) +# --------------------------------------------------------------------------- + +def combine_tp_tensors(params, key, dim, tensors): + """Combine TP-sharded tensors back into one full tensor. + + Handles special Mamba v2 in_proj and conv1d interleaved layouts. 
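+
+    Each Mamba v2 in_proj shard concatenates its rank-local [x, z, B, C, dt]
+    blocks, so a naive ``torch.cat`` across ranks would interleave blocks from
+    different ranks. As an illustration (toy sizes, not a real config): with
+    tp_size=2, d_inner=4, n_groups=2, d_state=1, n_heads=2, each shard holds
+    rows ordered [x(2), z(2), B(1), C(1), dt(1)], and this function regroups
+    them into the full [x(4), z(4), B(2), C(2), dt(2)] layout.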
+ """ + tp_size = len(tensors) + + if 'mixer.in_proj.weight' in key and params.mamba_version == 1: + xs, zs = [], [] + for tensor in tensors: + x, z = torch.split( + tensor, + [params.mamba_d_inner // tp_size, params.mamba_d_inner // tp_size], + dim=dim, + ) + xs.append(x) + zs.append(z) + return torch.cat([torch.cat(xs, dim=dim), torch.cat(zs, dim=dim)], dim=dim) + + elif 'mixer.in_proj.weight' in key and params.mamba_version == 2: + xs, zs, Bs, Cs, dts = [], [], [], [], [] + for tensor in tensors: + x, z, B, C, dt = torch.split( + tensor, + [ + params.mamba_d_inner // tp_size, + params.mamba_d_inner // tp_size, + (params.mamba2_n_groups // tp_size) * params.mamba_d_state, + (params.mamba2_n_groups // tp_size) * params.mamba_d_state, + params.mamba2_n_heads // tp_size, + ], + dim=dim, + ) + xs.append(x) + zs.append(z) + Bs.append(B) + Cs.append(C) + dts.append(dt) + + for ii in range(len(Bs)): + Bs[ii] = Bs[ii].reshape(-1, params.mamba_d_state, Bs[ii].shape[-1]) + Cs[ii] = Cs[ii].reshape(-1, params.mamba_d_state, Cs[ii].shape[-1]) + B = torch.cat(Bs, dim=dim) + C = torch.cat(Cs, dim=dim) + x = torch.cat(xs, dim=dim) + z = torch.cat(zs, dim=dim) + dt = torch.cat(dts, dim=dim) + return torch.cat([x, z, B.flatten(0, 1), C.flatten(0, 1), dt], dim=dim) + + elif 'mixer.conv1d' in key and params.mamba_version == 2: + xs, Bs, Cs = [], [], [] + for tensor in tensors: + x, B, C = torch.split( + tensor, + [ + params.mamba_d_inner // tp_size, + (params.mamba2_n_groups // tp_size) * params.mamba_d_state, + (params.mamba2_n_groups // tp_size) * params.mamba_d_state, + ], + dim=dim, + ) + xs.append(x) + Bs.append(B) + Cs.append(C) + + for ii in range(len(Bs)): + if 'weight' in key: + Bs[ii] = Bs[ii].reshape(-1, params.mamba_d_state, Bs[ii].shape[-2], Bs[ii].shape[-1]) + Cs[ii] = Cs[ii].reshape(-1, params.mamba_d_state, Cs[ii].shape[-2], Cs[ii].shape[-1]) + elif 'bias' in key: + Bs[ii] = Bs[ii].reshape(-1, params.mamba_d_state) + Cs[ii] = Cs[ii].reshape(-1, params.mamba_d_state) + else: + raise ValueError(f"Unknown conv1d key: {key}") + B = torch.cat(Bs, dim=dim) + C = torch.cat(Cs, dim=dim) + x = torch.cat(xs, dim=dim) + return torch.cat([x, B.flatten(0, 1), C.flatten(0, 1)], dim=dim) + + else: + return torch.cat(tensors, dim=dim) + + +def split_tensor_for_tp(params, key, dim, tensor): + """Split a full tensor into TP shards. + + Handles special Mamba v2 in_proj and conv1d interleaved layouts. 
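+
+    This is the inverse of ``combine_tp_tensors``: splitting a full tensor into
+    ``params.target_tp_size`` shards and recombining them is expected to
+    reproduce the original exactly (assuming the sharded dimensions divide
+    evenly by the TP size), which makes a convenient sanity check when adding
+    new parameter layouts:
+
+        shards = split_tensor_for_tp(params, key, dim, full_tensor)
+        assert torch.equal(combine_tp_tensors(params, key, dim, shards), full_tensor)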
+ """ + tp_size = params.target_tp_size + + if 'mixer.in_proj.weight' in key and params.mamba_version == 1: + x, z = torch.split( + tensor, [params.mamba_d_inner, params.mamba_d_inner], dim=dim + ) + x_sliced = torch.chunk(x, tp_size, dim=dim) + z_sliced = torch.chunk(z, tp_size, dim=dim) + return [torch.cat((xi, zi), dim=dim) for xi, zi in zip(x_sliced, z_sliced)] + + elif 'mixer.in_proj.weight' in key and params.mamba_version == 2: + x, z, B, C, dt = torch.split( + tensor, + [ + params.mamba_d_inner, + params.mamba_d_inner, + params.mamba2_n_groups * params.mamba_d_state, + params.mamba2_n_groups * params.mamba_d_state, + params.mamba2_n_heads, + ], + dim=dim, + ) + B = B.reshape(-1, params.mamba_d_state, B.shape[-1]) + C = C.reshape(-1, params.mamba_d_state, C.shape[-1]) + x_s = torch.chunk(x, tp_size, dim=dim) + z_s = torch.chunk(z, tp_size, dim=dim) + B_s = torch.chunk(B, tp_size, dim=dim) + C_s = torch.chunk(C, tp_size, dim=dim) + dt_s = torch.chunk(dt, tp_size, dim=dim) + return [ + torch.cat((xi, zi, Bi.flatten(0, 1), Ci.flatten(0, 1), dti), dim=dim) + for xi, zi, Bi, Ci, dti in zip(x_s, z_s, B_s, C_s, dt_s) + ] + + elif 'mixer.conv1d' in key and params.mamba_version == 2: + x, B, C = torch.split( + tensor, + [ + params.mamba_d_inner, + params.mamba2_n_groups * params.mamba_d_state, + params.mamba2_n_groups * params.mamba_d_state, + ], + dim=dim, + ) + if 'weight' in key: + B = B.reshape(-1, params.mamba_d_state, B.shape[-2], B.shape[-1]) + C = C.reshape(-1, params.mamba_d_state, C.shape[-2], C.shape[-1]) + elif 'bias' in key: + B = B.reshape(-1, params.mamba_d_state) + C = C.reshape(-1, params.mamba_d_state) + else: + raise ValueError(f"Unknown conv1d key: {key}") + + x_s = torch.chunk(x, tp_size, dim=dim) + B_s = torch.chunk(B, tp_size, dim=dim) + C_s = torch.chunk(C, tp_size, dim=dim) + return [ + torch.cat((xi, Bi.flatten(0, 1), Ci.flatten(0, 1)), dim=dim) + for xi, Bi, Ci in zip(x_s, B_s, C_s) + ] + + else: + return list(torch.chunk(tensor, tp_size, dim=dim)) + + +# --------------------------------------------------------------------------- +# Hybrid layer pattern parsing (standalone, no Megatron imports needed) +# --------------------------------------------------------------------------- + +VALID_LAYER_SYMBOLS = {'M', 'G', '*', '-', 'E'} + +# Layer symbols GPTModel can emit or absorb: +# '*' : standard self-attention layer (MHA / GQA / MQA) +# '-' : standard (optionally gated) MLP layer +# SSM ('M') has no GPT equivalent and is initialized from scratch / +# discarded (see convert_gpt_to_mamba / convert_mamba_to_gpt). +# Everything else is an architecture feature GPTModel does NOT +# produce: GDN ('G'), DS-attention ('D'), MoE ('E'). If the hybrid +# model contains any of those, we cannot faithfully translate. +GPT_COMPATIBLE_PATTERN_SYMBOLS = {'M', '*', '-'} + + +def parse_hybrid_layer_pattern(pattern): + """Parse a hybrid layer pattern string into a list of layer types. + + Strips MTP separators (/) and pipeline stage separators (|), returning only + the main decoder pattern as a list of single-character layer types. + + Returns: + list[str]: e.g. ['M', '*', '-', 'M', '*', '-'] + """ + # Take only the main pattern (before first '/') + main_pattern = pattern.split('/')[0] + # Remove pipeline stage separators + main_pattern = main_pattern.replace('|', '') + layer_types = list(main_pattern) + for ch in layer_types: + if ch not in VALID_LAYER_SYMBOLS: + raise ValueError( + f"Invalid layer symbol '{ch}' in pattern. 
" + f"Valid symbols: {VALID_LAYER_SYMBOLS}" + ) + return layer_types + + +def build_layer_index_mapping(layer_types, direction): + """Build mapping between GPT layer indices and Mamba layer indices. + + For gpt-to-mamba: + Returns (attn_map, mlp_map) where: + - attn_map[gpt_layer_i] = mamba_layer_j (j is the index of the i-th '*') + - mlp_map[gpt_layer_i] = mamba_layer_k (k is the index of the i-th '-') + + For mamba-to-gpt: + Returns (attn_map, mlp_map) where: + - attn_map[mamba_attn_idx] = gpt_layer_i + - mlp_map[mamba_mlp_idx] = gpt_layer_i + """ + attn_indices = [i for i, t in enumerate(layer_types) if t == '*'] + mlp_indices = [i for i, t in enumerate(layer_types) if t == '-'] + ssm_indices = [i for i, t in enumerate(layer_types) if t == 'M'] + + if direction == 'gpt-to-mamba': + if len(attn_indices) != len(mlp_indices): + raise ValueError( + f"For gpt-to-mamba, the number of attention layers ({len(attn_indices)}) " + f"must equal the number of MLP layers ({len(mlp_indices)}) in the pattern." + ) + # attn_map: gpt_layer_i -> mamba_layer_j + attn_map = {i: attn_indices[i] for i in range(len(attn_indices))} + mlp_map = {i: mlp_indices[i] for i in range(len(mlp_indices))} + return attn_map, mlp_map, ssm_indices + + elif direction == 'mamba-to-gpt': + if len(attn_indices) != len(mlp_indices): + raise ValueError( + f"For mamba-to-gpt, the number of attention layers ({len(attn_indices)}) " + f"must equal the number of MLP layers ({len(mlp_indices)}) in the pattern." + ) + # attn_map: mamba_layer_idx -> gpt_layer_i + attn_map = {attn_indices[i]: i for i in range(len(attn_indices))} + mlp_map = {mlp_indices[i]: i for i in range(len(mlp_indices))} + return attn_map, mlp_map, ssm_indices + + else: + raise ValueError(f"Unknown direction: {direction}") + + +# --------------------------------------------------------------------------- +# GPT compatibility whitelist +# --------------------------------------------------------------------------- +# +# GPTModel is a strict homogeneous transformer: every decoder layer is a +# (self-attention + MLP) pair with standard linear_qkv / linear_fc1 / +# linear_fc2 state-dict naming. The hybrid <-> GPT converter is only safe +# when the hybrid side agrees with that shape. The helpers below act as a +# safeguard: they reject any hybrid layout or source-args combination that +# would silently produce a broken checkpoint. +# +# Pattern-level rules (checked on the parsed hybrid_layer_pattern): +# * only 'M', '*', '-' are allowed (no 'G' GDN, no 'D' DS-attention, +# no 'E' MoE) +# * '*' count must equal '-' count (one-to-one GPT attention<->MLP pairing) +# +# Args-level rules (checked against the training args stored in the source +# checkpoint): reject anything that would make GPTModel's layer shape +# inapplicable to either side: +# * num_moe_experts (MoE routing, different keys) +# * moe_shared_expert_intermediate_size (shared-expert branch) +# * moe_layer_freq (MoE-every-N layer insertion) +# * experimental_attention_variant (gated_delta_net, dsa, ...) +# * linear_attention_freq (linear-attention layers) +# * heterogeneous_block_specs / heterogeneous_layers_config_path +# (Nemotron-NAS per-layer specs) +# * multi_latent_attention (MLA: different QKV layout) +# * mtp_num_layers (Multi-Token Prediction head) +# +# All rejected configurations raise ValueError early, before any tensors +# are touched. + +# Source-args field name -> (predicate-that-means-"reject", human reason). 
+# Predicates are applied with getattr(args, field, None); missing fields +# are treated as "absent" and pass. +_GPT_COMPAT_REJECT_FIELDS = ( + ( + 'num_moe_experts', + lambda v: v is not None and v > 0, + 'MoE routing (num_moe_experts)', + ), + ( + 'moe_shared_expert_intermediate_size', + lambda v: v is not None and v > 0, + 'MoE shared experts (moe_shared_expert_intermediate_size)', + ), + ( + 'moe_layer_freq', + # moe_layer_freq is None or 1 for non-MoE models; a list or a value + # > 1 means interleaved MoE layers. + lambda v: ( + v is not None + and not (isinstance(v, int) and v == 1) + and not (isinstance(v, str) and v.strip() in ('', '1')) + ), + 'interleaved MoE layers (moe_layer_freq)', + ), + ( + 'experimental_attention_variant', + lambda v: v is not None and v != '', + 'experimental attention variant (gated_delta_net / dsa / ...)', + ), + ( + 'linear_attention_freq', + lambda v: v is not None, + 'linear attention layers (linear_attention_freq)', + ), + ( + 'heterogeneous_block_specs', + lambda v: bool(v), + 'heterogeneous per-layer block specs', + ), + ( + 'heterogeneous_layers_config_path', + lambda v: v is not None and v != '', + 'heterogeneous layers config (Nemotron-NAS)', + ), + ( + 'heterogeneous_layers_config_encoded_json', + lambda v: v is not None and v != '', + 'heterogeneous layers config (Nemotron-NAS, inline JSON)', + ), + ( + 'multi_latent_attention', + lambda v: bool(v), + 'Multi-Latent Attention (MLA)', + ), + ( + 'mtp_num_layers', + lambda v: v is not None and v > 0, + 'Multi-Token Prediction head (mtp_num_layers)', + ), +) + + +def validate_pattern_gpt_compatible(layer_types, direction): + """Raise ValueError if the hybrid pattern cannot round-trip with GPTModel. + + Args: + layer_types: list of layer-type chars from parse_hybrid_layer_pattern(). + direction: 'gpt-to-mamba' or 'mamba-to-gpt' (for error messages). + + Rules: + * Allowed symbols are M / * / - only. G, D, E are rejected because + they denote layer kinds (GDN, DS-attention, MoE) that GPTModel + cannot emit or absorb. + * The number of '*' and '-' layers must match: every GPT layer pairs + one attention with one MLP. + """ + bad = sorted({c for c in layer_types if c not in GPT_COMPATIBLE_PATTERN_SYMBOLS}) + if bad: + raise ValueError( + f"Hybrid layer pattern contains symbols {bad} that are not " + f"GPT-compatible (allowed: {sorted(GPT_COMPATIBLE_PATTERN_SYMBOLS)}). " + f"GPTModel only supports standard attention ('*') and MLP ('-') " + f"layers; 'G' (GDN), 'D' (DS-attention), and 'E' (MoE) have no " + f"GPT equivalent and cannot be {direction}-converted." + ) + + n_attn = sum(1 for t in layer_types if t == '*') + n_mlp = sum(1 for t in layer_types if t == '-') + if n_attn != n_mlp: + raise ValueError( + f"GPT-compatible hybrid patterns must pair every attention layer " + f"('*') with one MLP layer ('-'). Got {n_attn} '*' and {n_mlp} '-' " + f"in the pattern." + ) + + +def validate_source_args_gpt_compatible(source_args, direction): + """Raise ValueError if the source checkpoint uses features GPTModel can't express. + + Args: + source_args: argparse.Namespace (or any attribute-bag) loaded from the + source checkpoint; may be None, in which case this check is a no-op + (dist checkpoints without a cached args blob). + direction: 'gpt-to-mamba' or 'mamba-to-gpt'. + + Rejects MoE, MLA, MTP, linear / experimental attention, and heterogeneous + per-layer specs. See the module header for the full list. 
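+
+    For example, a source checkpoint whose stored args include
+    ``num_moe_experts=8`` raises ValueError naming "MoE routing
+    (num_moe_experts)", while args that simply lack these fields pass
+    unchanged.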
+ """ + if source_args is None: + return + + rejected = [] + for field, predicate, reason in _GPT_COMPAT_REJECT_FIELDS: + if not hasattr(source_args, field): + continue + value = getattr(source_args, field) + try: + if predicate(value): + rejected.append(f" - {reason}: {field}={value!r}") + except Exception: + # Defensive: never let the validator crash on an unexpected + # value type — treat it as "cannot verify, pass". + continue + + if rejected: + joined = "\n".join(rejected) + raise ValueError( + f"Source checkpoint is not GPT-compatible for {direction} " + f"conversion. The following features have no GPTModel equivalent " + f"and would produce a corrupt target checkpoint:\n{joined}\n" + f"Remove these features from the model (or use a different " + f"conversion tool) before running gpt_mamba_conversion." + ) + + +# --------------------------------------------------------------------------- +# SSM parameter initialization (for gpt-to-mamba) +# --------------------------------------------------------------------------- + +def initialize_ssm_layer_params( + layer_idx, + d_model, + mamba_d_inner, + mamba_d_state, + mamba2_n_groups, + mamba2_n_heads, + mamba_head_dim, + d_conv=4, + dt_min=0.001, + dt_max=0.1, + dt_init_floor=1e-4, + A_init_range=(1, 16), + init_method_std=0.02, + dtype=torch.float32, +): + """Initialize parameters for a single Mamba SSM layer from scratch. + + Follows the initialization logic from MambaMixer.__init__: + - A_log: log(uniform(A_init_range)) + - dt_bias: inverse_softplus(log_uniform(dt_min, dt_max)) + - D: ones(nheads) + - conv1d.weight: kaiming_uniform(a=sqrt(5)) + - conv1d.bias: zeros + - in_proj.weight: kaiming_uniform(a=sqrt(5)) + - in_proj.layer_norm_weight: ones(d_model) + - out_proj.weight: kaiming_uniform(a=sqrt(5)) or normal(0, std) + - norm.weight: ones(d_inner) + + Returns: + dict: {param_suffix: tensor} for one SSM layer + """ + prefix = f'decoder.layers.{layer_idx}.mixer.' 
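+    # Worked example (hypothetical config, assuming d_inner = 2 * d_model):
+    # with d_model=4096, d_state=128, n_groups=8, head_dim=64 we get
+    # d_inner=8192, n_heads=128, conv_dim = 8192 + 2*8*128 = 10240, and
+    # in_proj_out_dim = 2*8192 + 2*8*128 + 128 = 18560.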
+ + nheads = mamba2_n_heads + conv_dim = mamba_d_inner + 2 * mamba2_n_groups * mamba_d_state + in_proj_out_dim = 2 * mamba_d_inner + 2 * mamba2_n_groups * mamba_d_state + nheads + + params = OrderedDict() + + # in_proj (ColumnParallelLinear) + in_proj_weight = torch.empty(in_proj_out_dim, d_model, dtype=dtype) + torch.nn.init.kaiming_uniform_(in_proj_weight, a=math.sqrt(5)) + params[prefix + 'in_proj.weight'] = in_proj_weight + + # in_proj layer norm weight (fused into ColumnParallelLinear in TE) + params[prefix + 'in_proj.layer_norm_weight'] = torch.ones(d_model, dtype=dtype) + + # conv1d + conv_weight = torch.empty(conv_dim, 1, d_conv, dtype=dtype) + torch.nn.init.kaiming_uniform_(conv_weight, a=math.sqrt(5)) + params[prefix + 'conv1d.weight'] = conv_weight + params[prefix + 'conv1d.bias'] = torch.zeros(conv_dim, dtype=dtype) + + # A_log (kept in fp32) + A = torch.empty(nheads, dtype=torch.float32) + A.uniform_(*A_init_range) + params[prefix + 'A_log'] = torch.log(A) + + # D (kept in fp32) + params[prefix + 'D'] = torch.ones(nheads, dtype=torch.float32) + + # dt_bias + dt = torch.exp( + torch.rand(nheads, dtype=dtype) + * (math.log(dt_max) - math.log(dt_min)) + + math.log(dt_min) + ).clamp(min=dt_init_floor) + inv_dt = dt + torch.log(-torch.expm1(-dt)) + params[prefix + 'dt_bias'] = inv_dt + + # norm (RMSNorm) + params[prefix + 'norm.weight'] = torch.ones(mamba_d_inner, dtype=dtype) + + # out_proj (RowParallelLinear) + out_proj_weight = torch.empty(d_model, mamba_d_inner, dtype=dtype) + torch.nn.init.kaiming_uniform_(out_proj_weight, a=math.sqrt(5)) + params[prefix + 'out_proj.weight'] = out_proj_weight + + return params + + +# --------------------------------------------------------------------------- +# Checkpoint I/O helpers (patterns from hybrid_conversion.py) +# --------------------------------------------------------------------------- + +def get_checkpoint_iteration(load_dir): + """Read the latest iteration number from a checkpoint directory.""" + tracker_file = os.path.join(load_dir, 'latest_checkpointed_iteration.txt') + with open(tracker_file, 'r') as f: + metastring = f.read().strip() + try: + iteration = int(metastring) + except ValueError: + raise ValueError( + f"Invalid iteration in {tracker_file}: '{metastring}'" + ) + return iteration + + +def load_checkpoint_shards(load_dir, iteration, input_tp_size, input_pp_size): + """Load all TP/PP shards of a checkpoint. 
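+
+    Shard files are expected at ``iter_XXXXXXX/mp_rank_XX[_YYY]/model_optim_rng.pt``;
+    the ``_YYY`` pipeline suffix is only present when ``input_pp_size > 1``.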
+ + Returns: + list[list[dict]]: models[pp_rank][tp_rank] = checkpoint dict + dict: sample_model (first shard, for metadata) + """ + model_dir = os.path.join(load_dir, f'iter_{iteration:07d}') + sample_model = None + all_shards = [] + + for pp in range(input_pp_size): + tp_shards = [] + for tp in range(input_tp_size): + dir_name = f"mp_rank_{tp:02d}" + if input_pp_size > 1: + dir_name += f"_{pp:03d}" + model_file = os.path.join(model_dir, dir_name, "model_optim_rng.pt") + checkpoint = torch.load(model_file, map_location='cpu', weights_only=False) + tp_shards.append(checkpoint) + if sample_model is None: + sample_model = checkpoint + print(f" Loaded {model_file}") + all_shards.append(tp_shards) + + return all_shards, sample_model + + +def combine_tp_shards(tp_models, params): + """Combine TP-sharded models into a single state dict with full tensors.""" + input_tp_size = len(tp_models) + if input_tp_size == 1: + return OrderedDict(tp_models[0]['model']) + + combined = OrderedDict() + for key, original_tensor in tp_models[0]['model'].items(): + if '_extra_state' in key: + combined[key] = original_tensor + continue + + split_dim = get_split_dim(key) + if split_dim != -1: + tensors = [tp_models[j]['model'][key].cpu() for j in range(input_tp_size)] + combined[key] = combine_tp_tensors(params, key, split_dim, tensors) + else: + combined[key] = original_tensor + + return combined + + +def stitch_pp_shards(all_combined_shards, num_layers_per_pp_rank): + """Stitch PP shards into one flat model with globally-indexed layers.""" + full_model = OrderedDict() + + for pp, combined_shard in enumerate(all_combined_shards): + for key, tensor in combined_shard.items(): + try: + layer_num = int(re.findall(r'\d+', key)[0]) + new_key = key.replace( + str(layer_num), + str(layer_num + pp * num_layers_per_pp_rank), + 1, + ) + except (IndexError, ValueError): + new_key = key + full_model[new_key] = tensor + + return full_model + + +def finalize_checkpoint(sample_model, model, params, verbose=False): + """Finalize checkpoint metadata from a sample source checkpoint.""" + reset_iterations = params.reset_iterations + + model['args'] = copy.deepcopy(sample_model['args']) + model['args'].tensor_model_parallel_size = params.target_tp_size + model['args'].pipeline_model_parallel_size = params.target_pp_size + if reset_iterations: + model['args'].iteration = 0 + model['args'].consumed_valid_samples = 0 + model['args'].consumed_train_samples = 0 + model['args'].train_iters = 0 + model['args'].train_samples = 0 + + model['checkpoint_version'] = copy.deepcopy(sample_model['checkpoint_version']) + + model['iteration'] = copy.deepcopy(sample_model['iteration']) + if reset_iterations: + model['iteration'] = 0 + + if 'opt_param_scheduler' in sample_model: + model['opt_param_scheduler'] = copy.deepcopy(sample_model['opt_param_scheduler']) + + model['rng_state'] = copy.deepcopy(sample_model['rng_state']) + + if verbose: + original_args = sample_model['args'].__dict__ + final_args = model['args'].__dict__ + for key in original_args: + if key in final_args: + if final_args[key] != original_args[key]: + print(f" ARG MISMATCH: {key}") + print(f" original: {original_args[key]}") + print(f" final: {final_args[key]}") + else: + print(f" ARG MISSING from final: {key} = {original_args[key]}") + for key in final_args: + if key not in original_args: + print(f" ARG ADDED to final: {key} = {final_args[key]}") + + return model + + +def save_checkpoint_shards(target_state_dicts, sample_model, params, save_dir, iteration): + """Split and save 
checkpoint for target TP/PP configuration. + + Args: + target_state_dicts: OrderedDict with globally-indexed layer keys (full tensors). + sample_model: Source checkpoint dict for metadata. + params: argparse namespace with target_tp_size, target_pp_size, etc. + save_dir: Output directory. + iteration: Iteration number to write. + """ + total_layers = params.target_num_layers + num_layers_per_pp_rank = total_layers // params.target_pp_size + + out_iteration = iteration if not params.reset_iterations else 0 + + pp_offset = 0 + # Build a list of (key, tensor) for iteration + all_items = list(target_state_dicts.items()) + + for pp in range(params.target_pp_size): + print(f" Saving PP rank {pp}") + tp_models = [{'model': OrderedDict()} for _ in range(params.target_tp_size)] + + for idx in range(pp_offset, len(all_items)): + key, tensor = all_items[idx] + + # Determine if this key belongs to this PP rank + try: + layer_num = int(re.findall(r'\d+', key)[0]) + if layer_num >= num_layers_per_pp_rank * (pp + 1): + break + new_key = key.replace( + str(layer_num), + str(layer_num - pp * num_layers_per_pp_rank), + 1, + ) + except (IndexError, ValueError): + new_key = key + + pp_offset += 1 + + if '_extra_state' in new_key: + for j in range(params.target_tp_size): + tp_models[j]['model'][new_key] = tensor + continue + + split_dim = get_split_dim(new_key) + if split_dim != -1: + slices = split_tensor_for_tp(params, new_key, split_dim, tensor) + for j in range(params.target_tp_size): + tp_models[j]['model'][new_key] = slices[j] + else: + for j in range(params.target_tp_size): + tp_models[j]['model'][new_key] = tensor + + for tp in range(params.target_tp_size): + dir_name = f"mp_rank_{tp:02d}" + if params.target_pp_size > 1: + dir_name += f"_{pp:03d}" + + model = finalize_checkpoint(sample_model, tp_models[tp], params, verbose=False) + + out_dir = os.path.join(save_dir, f'iter_{out_iteration:07d}', dir_name) + os.makedirs(out_dir, exist_ok=True) + model_file = os.path.join(out_dir, "model_optim_rng.pt") + torch.save(model, model_file) + print(f" Saved {model_file}") + + # Write iteration tracker + tracker_file = os.path.join(save_dir, 'latest_checkpointed_iteration.txt') + with open(tracker_file, 'w') as f: + f.write(str(out_iteration)) + + +# --------------------------------------------------------------------------- +# Key name helpers +# --------------------------------------------------------------------------- + +def get_layer_num_from_key(key): + """Extract the layer number from a state dict key like 'decoder.layers.5.mlp...'""" + match = re.search(r'decoder\.layers\.(\d+)\.', key) + if match: + return int(match.group(1)) + return None + + +def replace_layer_num(key, old_num, new_num): + """Replace the layer number in a state dict key.""" + return key.replace(f'decoder.layers.{old_num}.', f'decoder.layers.{new_num}.', 1) + + +def is_attention_param(key): + """Check if a key belongs to an attention sub-layer.""" + return 'self_attention.' in key or 'input_layernorm.' in key + + +def is_mlp_param(key): + """Check if a key belongs to an MLP sub-layer.""" + return ('mlp.' in key or 'pre_mlp_layernorm.' 
in key) and 'self_attention' not in key + + +def is_ssm_param(key): + """Check if a key belongs to a Mamba SSM mixer sub-layer.""" + ssm_markers = ['mixer.in_proj', 'mixer.conv1d', 'mixer.A_log', 'mixer.D', + 'mixer.dt_bias', 'mixer.norm', 'mixer.out_proj', + 'mixer.x_proj', 'mixer.dt_proj'] + return any(m in key for m in ssm_markers) + + +def is_layer_norm_for_ssm(key): + """Check if a key is the input layer norm for an SSM layer. + + In hybrid models, SSM layers can have their own input_layernorm or the + norm can be fused into in_proj.layer_norm_weight. + """ + return 'in_proj.layer_norm_weight' in key + + +# --------------------------------------------------------------------------- +# Core conversion: GPT -> Mamba +# --------------------------------------------------------------------------- + +def convert_gpt_to_mamba(full_model, layer_types, args): + """Convert a GPT state dict to a Mamba hybrid state dict. + + Args: + full_model: OrderedDict with globally-indexed GPT state dict keys. + layer_types: list of layer type chars from hybrid_layer_pattern. + args: Parsed CLI arguments. + + Returns: + OrderedDict: Mamba state dict with globally-indexed keys. + """ + attn_map, mlp_map, ssm_indices = build_layer_index_mapping( + layer_types, 'gpt-to-mamba' + ) + num_gpt_layers = len(attn_map) + + # Validate GPT layer count + gpt_layer_nums = set() + for key in full_model: + lnum = get_layer_num_from_key(key) + if lnum is not None: + gpt_layer_nums.add(lnum) + + if len(gpt_layer_nums) != num_gpt_layers: + raise ValueError( + f"GPT checkpoint has {len(gpt_layer_nums)} layers, but the pattern " + f"has {num_gpt_layers} attention ('*') and {num_gpt_layers} MLP ('-') " + f"layers. These must match." + ) + + target = OrderedDict() + dtype = None + + # Copy / rename non-layer params + for key, tensor in full_model.items(): + if dtype is None and tensor.dtype in (torch.float16, torch.bfloat16, torch.float32): + dtype = tensor.dtype + + if 'decoder.layers.' 
in key: + continue + + # Rename final_layernorm -> final_norm + if 'decoder.final_layernorm' in key: + new_key = key.replace('decoder.final_layernorm', 'decoder.final_norm') + target[new_key] = tensor + else: + target[key] = tensor + + if dtype is None: + dtype = torch.float32 + + # Map attention and MLP params + for key, tensor in full_model.items(): + lnum = get_layer_num_from_key(key) + if lnum is None: + continue + + if is_attention_param(key): + target_layer = attn_map[lnum] + new_key = replace_layer_num(key, lnum, target_layer) + target[new_key] = tensor + + elif is_mlp_param(key): + target_layer = mlp_map[lnum] + new_key = replace_layer_num(key, lnum, target_layer) + target[new_key] = tensor + + # (any other layer params get copied as-is with their own mapping, + # but for pure GPT there should only be attention + MLP) + + # Initialize SSM layers from scratch + print(f" Initializing {len(ssm_indices)} SSM layers from scratch...") + for layer_idx in ssm_indices: + ssm_params = initialize_ssm_layer_params( + layer_idx=layer_idx, + d_model=args.d_model, + mamba_d_inner=args.mamba_d_inner, + mamba_d_state=args.mamba_d_state, + mamba2_n_groups=args.mamba2_n_groups, + mamba2_n_heads=args.mamba2_n_heads, + mamba_head_dim=args.mamba2_head_dim, + d_conv=getattr(args, 'd_conv', 4), + init_method_std=getattr(args, 'init_method_std', 0.02), + dtype=dtype, + ) + target.update(ssm_params) + + # Sort by layer index for consistent ordering + target = _sort_state_dict(target) + + return target + + +# --------------------------------------------------------------------------- +# Core conversion: Mamba -> GPT +# --------------------------------------------------------------------------- + +def convert_mamba_to_gpt(full_model, layer_types, args): + """Convert a Mamba hybrid state dict to a GPT state dict. + + Args: + full_model: OrderedDict with globally-indexed Mamba state dict keys. + layer_types: list of layer type chars from hybrid_layer_pattern. + args: Parsed CLI arguments. + + Returns: + OrderedDict: GPT state dict with globally-indexed keys. + """ + attn_map, mlp_map, ssm_indices = build_layer_index_mapping( + layer_types, 'mamba-to-gpt' + ) + num_gpt_layers = len(attn_map) + + target = OrderedDict() + discarded_ssm_keys = [] + + # Copy / rename non-layer params + for key, tensor in full_model.items(): + if 'decoder.layers.' 
in key: + continue + + # Rename final_norm -> final_layernorm + if 'decoder.final_norm' in key: + new_key = key.replace('decoder.final_norm', 'decoder.final_layernorm') + target[new_key] = tensor + else: + target[key] = tensor + + # Map attention and MLP params, discard SSM + for key, tensor in full_model.items(): + lnum = get_layer_num_from_key(key) + if lnum is None: + continue + + if is_ssm_param(key) or is_layer_norm_for_ssm(key): + # Discard SSM params + if lnum in ssm_indices: + discarded_ssm_keys.append(key) + continue + + if is_attention_param(key) and lnum in attn_map: + target_layer = attn_map[lnum] + new_key = replace_layer_num(key, lnum, target_layer) + target[new_key] = tensor + + elif is_mlp_param(key) and lnum in mlp_map: + target_layer = mlp_map[lnum] + new_key = replace_layer_num(key, lnum, target_layer) + target[new_key] = tensor + + elif lnum in ssm_indices: + # Any remaining SSM-layer param not caught above + discarded_ssm_keys.append(key) + + if discarded_ssm_keys: + print(f"\n WARNING: Discarded {len(discarded_ssm_keys)} SSM parameter tensors " + f"from {len(ssm_indices)} Mamba layers (no GPT equivalent).") + print(f" First few discarded keys: {discarded_ssm_keys[:5]}") + + target = _sort_state_dict(target) + + return target + + +# --------------------------------------------------------------------------- +# Sorting helper +# --------------------------------------------------------------------------- + +def _sort_state_dict(state_dict): + """Sort state dict keys so that layer-indexed keys are in order.""" + def sort_key(item): + key = item[0] + # Extract layer number if present + match = re.search(r'decoder\.layers\.(\d+)\.', key) + if match: + return (1, int(match.group(1)), key) + # Non-layer keys: embeddings first, output_layer last + if 'embedding' in key: + return (0, 0, key) + if 'output_layer' in key: + return (2, 0, key) + if 'decoder.final' in key: + return (1, 999999, key) + return (0, 1, key) + + return OrderedDict(sorted(state_dict.items(), key=sort_key)) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +# --------------------------------------------------------------------------- +# Format-aware load / save +# --------------------------------------------------------------------------- + +def _load_legacy_full(args): + """Load a legacy mp_rank_XX checkpoint and return a full (TP+PP gathered) + state dict plus a sample shard for metadata. + + Returns: + full_model (OrderedDict): globally-indexed, TP-combined state dict. + sample_model (dict): one source shard (for args/iteration/etc.). + iteration (int): source iteration. 
+ """ + iteration = get_checkpoint_iteration(args.load_dir) + print(f" Iteration: {iteration}") + + model_dir = os.path.join(args.load_dir, f'iter_{iteration:07d}') + sub_models = os.listdir(model_dir) + sample_file = os.path.join(model_dir, sub_models[0], "model_optim_rng.pt") + sample_model = torch.load(sample_file, map_location='cpu', weights_only=False) + + input_tp_size = sample_model['args'].tensor_model_parallel_size + input_pp_size = sample_model['args'].pipeline_model_parallel_size + input_num_layers = sample_model['args'].num_layers + num_layers_per_pp_rank = input_num_layers // input_pp_size + + print(f" Source: TP={input_tp_size}, PP={input_pp_size}, " + f"num_layers={input_num_layers}") + + all_shards, sample_model = load_checkpoint_shards( + args.load_dir, iteration, input_tp_size, input_pp_size + ) + + print(" Combining TP shards into full tensors...") + combined_pp_shards = [] + for pp in range(input_pp_size): + combined = combine_tp_shards(all_shards[pp], args) + combined_pp_shards.append(combined) + + print(" Stitching PP shards into flat model...") + full_model = stitch_pp_shards(combined_pp_shards, num_layers_per_pp_rank) + print(f" Full model: {len(full_model)} parameters") + + return full_model, sample_model, iteration + + +def _save_dist_full(target_state_dict, common_state, model_prefix, backend, + args, iteration): + """Save a fully-gathered state dict in dist-ckpt format. + + The on-disk tensors carry their full logical shape, so downstream Megatron + training reads them back with any TP+PP+FSDP configuration. + """ + if iteration is None: + out_iter = 0 if args.reset_iterations else 0 + iter_dir = args.save_dir + else: + out_iter = 0 if args.reset_iterations else iteration + iter_dir = os.path.join(args.save_dir, f'iter_{out_iter:07d}') + + # Update common state args to reflect target model structure. 
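+    # Only num_layers, hybrid_layer_pattern, and (with --reset-iterations) the
+    # iteration / sample counters are rewritten below; everything else in the
+    # common state is copied through from the source checkpoint unchanged.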
+ common_state = copy.deepcopy(common_state) if common_state else {} + if 'args' in common_state and common_state['args'] is not None: + ckpt_args = common_state['args'] + ckpt_args.num_layers = args.target_num_layers + if hasattr(ckpt_args, 'hybrid_layer_pattern'): + if args.direction == 'gpt-to-mamba': + ckpt_args.hybrid_layer_pattern = args.hybrid_layer_pattern + else: + ckpt_args.hybrid_layer_pattern = None + if args.reset_iterations: + for attr in ('iteration', 'consumed_valid_samples', + 'consumed_train_samples', 'train_iters', 'train_samples'): + if hasattr(ckpt_args, attr): + setattr(ckpt_args, attr, 0) + if args.reset_iterations and 'iteration' in common_state: + common_state['iteration'] = 0 + + print(f" Writing dist checkpoint to {iter_dir} " + f"(backend={backend}, prefix='{model_prefix}')...") + save_dist_checkpoint_full( + target_state_dict, common_state, iter_dir, + model_prefix=model_prefix, backend=backend, + ) + write_latest_iteration_marker(iter_dir, out_iter) + + +def main(args): + print("\n====RUNNING GPT <-> MAMBA CHECKPOINT CONVERSION====\n") + print(f" Direction: {args.direction}") + print(f" Source: {args.load_dir}") + print(f" Target: {args.save_dir}") + print(f" Hybrid layer pattern: {args.hybrid_layer_pattern}") + + # Compute derived Mamba dimensions + args.mamba_d_inner = args.d_model * 2 + args.mamba2_n_heads = args.mamba_d_inner // args.mamba2_head_dim + + # Parse hybrid layer pattern + layer_types = parse_hybrid_layer_pattern(args.hybrid_layer_pattern) + total_mamba_layers = len(layer_types) + attn_count = sum(1 for t in layer_types if t == '*') + mlp_count = sum(1 for t in layer_types if t == '-') + ssm_count = sum(1 for t in layer_types if t == 'M') + print(f"\n Pattern: {len(layer_types)} total layers " + f"({attn_count} attn, {mlp_count} MLP, {ssm_count} SSM, " + f"{len(layer_types) - attn_count - mlp_count - ssm_count} other)") + + # Pattern-level GPT compatibility whitelist (fails fast, pre-load). + validate_pattern_gpt_compatible(layer_types, args.direction) + + # 1. Resolve input format + input_format = getattr(args, 'input_format', 'auto') + if input_format == 'auto': + input_format = detect_checkpoint_format(args.load_dir) + output_format = getattr(args, 'output_format', 'auto') + if output_format == 'auto': + output_format = input_format + print(f"\n Input format: {input_format}") + print(f" Output format: {output_format}") + if output_format == FORMAT_LEGACY: + print(f" Target TP size: {args.target_tp_size}") + print(f" Target PP size: {args.target_pp_size}") + + # 2. Load source checkpoint into a fully-gathered state dict + print("\n[Step 1] Loading source checkpoint...") + sample_model = None + common_state = {} + model_prefix = 'model.' + dist_backend = FORMAT_TORCH_DIST + + if input_format == FORMAT_LEGACY: + full_model, sample_model, iteration = _load_legacy_full(args) + elif input_format in DIST_FORMATS: + full_model, common_state, model_prefix, dist_backend, iteration = ( + load_dist_checkpoint_full(args.load_dir) + ) + print(f" Source: dist backend={dist_backend}, prefix='{model_prefix}', " + f"iteration={iteration}, params={len(full_model)}") + else: + raise ValueError(f"Unsupported input format: {input_format}") + + # Args-level GPT compatibility whitelist: reject MoE, MLA, MTP, linear / + # experimental attention, heterogeneous block specs, etc. See module header. 
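+    # The args namespace comes from whichever metadata source the input format
+    # provides: the sampled legacy shard, or the dist checkpoint's common state.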
+ source_args = None + if sample_model is not None and 'args' in sample_model: + source_args = sample_model['args'] + elif common_state and 'args' in common_state: + source_args = common_state['args'] + validate_source_args_gpt_compatible(source_args, args.direction) + + # 3. Convert + print(f"\n[Step 2] Converting ({args.direction})...") + if args.direction == 'gpt-to-mamba': + target_state_dict = convert_gpt_to_mamba(full_model, layer_types, args) + args.target_num_layers = total_mamba_layers + elif args.direction == 'mamba-to-gpt': + target_state_dict = convert_mamba_to_gpt(full_model, layer_types, args) + args.target_num_layers = attn_count + else: + raise ValueError(f"Unknown direction: {args.direction}") + print(f" Target model: {len(target_state_dict)} parameters") + + # 4. Save + print(f"\n[Step 3] Saving to {args.save_dir}...") + if output_format == FORMAT_LEGACY: + if sample_model is None: + raise ValueError( + "Legacy output requires a legacy source checkpoint for metadata. " + "Use --output-format torch_dist when loading a dist checkpoint." + ) + sample_model['args'].num_layers = args.target_num_layers + save_checkpoint_shards( + target_state_dict, sample_model, args, args.save_dir, + iteration if iteration is not None else 0, + ) + elif output_format in DIST_FORMATS: + _save_dist_full( + target_state_dict, common_state, model_prefix, output_format, + args, iteration, + ) + else: + raise ValueError(f"Unsupported output format: {output_format}") + + print("\n====CONVERSION COMPLETE====\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Convert checkpoints between GPTModel and MambaModel formats.", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + parser.add_argument( + '--direction', type=str, required=True, + choices=['gpt-to-mamba', 'mamba-to-gpt'], + help='Conversion direction.', + ) + parser.add_argument('--load-dir', type=str, required=True, + help='Path to source checkpoint directory.') + parser.add_argument('--save-dir', type=str, required=True, + help='Path to target checkpoint directory.') + parser.add_argument('--hybrid-layer-pattern', type=str, required=True, + help='Hybrid layer pattern string, e.g. "M*-M*-M*-M*-".') + parser.add_argument('--target-tp-size', type=int, default=1, + help='Target tensor parallel size (legacy output only; ' + 'dist formats are saved fully-replicated and ' + 'resharded at training load time).') + parser.add_argument('--target-pp-size', type=int, default=1, + help='Target pipeline parallel size (legacy output only).') + + parser.add_argument( + '--input-format', type=str, default='auto', + choices=['auto', FORMAT_LEGACY, FORMAT_TORCH_DIST, 'fsdp_dtensor'], + help='Source checkpoint format. "auto" detects from metadata.json / ' + 'mp_rank_XX layout.', + ) + parser.add_argument( + '--output-format', type=str, default='auto', + choices=['auto', FORMAT_LEGACY, FORMAT_TORCH_DIST, 'fsdp_dtensor'], + help='Target checkpoint format. "auto" matches the input format. 
' + 'Dist formats (torch_dist / fsdp_dtensor) transparently support ' + 'TP+PP+FSDP training checkpoints.', + ) + + # Model architecture params + parser.add_argument('--d-model', type=int, default=4096, + help='Model hidden dimension.') + parser.add_argument('--mamba-version', type=int, default=2, + choices=[1, 2], help='Mamba SSM version.') + parser.add_argument('--mamba-d-state', type=int, default=128, + help='Mamba state dimension.') + parser.add_argument('--mamba2-n-groups', type=int, default=8, + help='Number of groups (Mamba v2).') + parser.add_argument('--mamba2-head-dim', type=int, default=64, + help='Head dimension (Mamba v2).') + parser.add_argument('--d-conv', type=int, default=4, + help='Causal convolution kernel size.') + + # Initialization params + parser.add_argument('--init-method-std', type=float, default=0.02, + help='Std for initializing new Mamba SSM params.') + + # Checkpoint control + parser.add_argument('--reset-iterations', action='store_true', + help='Zero out the training iteration count.') + + args = parser.parse_args() + main(args) From fef2eb05061ebc94f85e65a90cbf808e7e6a8ff2 Mon Sep 17 00:00:00 2001 From: guihong-nv Date: Mon, 27 Apr 2026 08:28:31 -0700 Subject: [PATCH 02/10] Remove the support for GPT legacy model Signed-off-by: guihong-nv --- .../tools/checkpoint/run_slurm_tests.sh | 120 ---- .../checkpoint/test_gpt_mamba_conversion.py | 142 ----- .../test_gpt_mamba_conversion_integration.py | 467 -------------- .../test_gpt_mamba_conversion_parallelism.py | 279 +++------ tools/checkpoint/dist_checkpoint_io.py | 13 +- tools/checkpoint/gpt_mamba_conversion.py | 589 +----------------- 6 files changed, 147 insertions(+), 1463 deletions(-) delete mode 100755 tests/unit_tests/tools/checkpoint/run_slurm_tests.sh delete mode 100644 tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion_integration.py diff --git a/tests/unit_tests/tools/checkpoint/run_slurm_tests.sh b/tests/unit_tests/tools/checkpoint/run_slurm_tests.sh deleted file mode 100755 index d5b9e504703..00000000000 --- a/tests/unit_tests/tools/checkpoint/run_slurm_tests.sh +++ /dev/null @@ -1,120 +0,0 @@ -#!/bin/bash -# Run GPT <-> Mamba checkpoint conversion tests on SLURM. -# -# Covers: -# Phase 1 - Unit tests (pattern parsing, key mapping, SSM init, round-trip, -# and the new GPT-compatibility whitelist) -# Phase 2 - Integration tests (legacy TP=1/PP=1 on-disk round-trip) -# Phase 3 - Parallelism matrix (TP / PP / FSDP and all combinations, -# across legacy and torch_dist / fsdp_dtensor formats; -# hybrid patterns exercised: pure-attention, M*-, M*-M*-, -# alternating, and pure-SSM) -# -# Single-node mode (default) exercises the full matrix on one GPU. -# Multi-node mode launches the same pytest invocation on N nodes to verify -# the converter is deterministic across nodes and that dist-checkpoint load -# works from a shared filesystem. -# -# Usage: -# bash run_slurm_tests.sh # single-node, default repo path -# NODES=2 bash run_slurm_tests.sh # 2 nodes -# MEGATRON_LM_DIR=/path bash run_slurm_tests.sh -# -# Environment knobs: -# MEGATRON_LM_DIR Path to the Megatron-LM checkout (default: this repo root) -# CONTAINER_IMAGE Container image (default: nemo:26.04) -# NODES Number of nodes (default: 1) -# GPUS_PER_NODE GPUs per node (default: 1) -# PARTITION SLURM partition (default: batch) -# ACCOUNT SLURM account (default: coreai_dlalgo_genai) -# TIME SLURM time limit (default: 00:45:00) - -set -euo pipefail - -# Default to the repo that contains this script. 
-_SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -_DEFAULT_REPO="$(cd "${_SCRIPT_DIR}/../../../.." && pwd)" - -CONTAINER_IMAGE="${CONTAINER_IMAGE:-nvcr.io/nvidia/nemo:26.02}" -MEGATRON_LM_DIR="${MEGATRON_LM_DIR:-${_DEFAULT_REPO}}" -NODES="${NODES:-1}" -GPUS_PER_NODE="${GPUS_PER_NODE:-1}" -PARTITION="${PARTITION:-batch}" -ACCOUNT="${ACCOUNT:-coreai_dlalgo_mcore}" -TIME="${TIME:-00:45:00}" - -LOG_DIR="${MEGATRON_LM_DIR}/logs" -mkdir -p "${LOG_DIR}" - -echo "======================================================" -echo "GPT <-> Mamba Conversion Tests" -echo " Repo : ${MEGATRON_LM_DIR}" -echo " Container : ${CONTAINER_IMAGE}" -echo " Nodes : ${NODES}" -echo " GPUs per node : ${GPUS_PER_NODE}" -echo " Partition : ${PARTITION}" -echo " Account : ${ACCOUNT}" -echo " Time limit : ${TIME}" -echo " Logs : ${LOG_DIR}" -echo "======================================================" - -# The conversion unit tests are CPU-only; the parallelism matrix only needs -# one GPU per test process (it exercises the sharding logic, not kernels). -# On multi-node runs we invoke pytest on every task so each node independently -# validates its view of the checkpoint — i.e. the *same* shared -# checkpoint must round-trip from every node. -srun \ - --job-name=gpt-mamba-conv-test \ - --nodes="${NODES}" \ - --ntasks-per-node=1 \ - --gpus-per-node="${GPUS_PER_NODE}" \ - --cpus-per-gpu=16 \ - --time="${TIME}" \ - --partition="${PARTITION}" \ - --account="${ACCOUNT}" \ - --output="${LOG_DIR}/gpt_mamba_conv_test_%j_%t.out" \ - --error="${LOG_DIR}/gpt_mamba_conv_test_%j_%t.err" \ - --container-image="${CONTAINER_IMAGE}" \ - --container-mounts="${MEGATRON_LM_DIR}:/opt/megatron-lm" \ - --container-workdir="/opt/megatron-lm" \ - bash -c ' - set -euo pipefail - export PYTHONPATH=/opt/megatron-lm:${PYTHONPATH:-} - - RANK="${SLURM_PROCID:-0}" - NODE="${SLURMD_NODENAME:-local}" - echo "[node=${NODE} rank=${RANK}] Python : $(python --version 2>&1)" - echo "[node=${NODE} rank=${RANK}] torch : $(python -c "import torch; print(torch.__version__)")" - echo "[node=${NODE} rank=${RANK}] cuda : $(python -c "import torch; print(torch.cuda.is_available(), torch.cuda.device_count())")" - echo "------------------------------------------------------" - - echo "" - echo "=== [node=${NODE}] Phase 1: Unit tests ===" - echo "" - python -m pytest -vs \ - tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion.py - - echo "" - echo "=== [node=${NODE}] Phase 1b: GPT-compatibility whitelist tests ===" - echo "" - # Run the whitelist classes in isolation so a regression in the - # safeguard is easy to spot in CI logs. - python -m pytest -vs \ - tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion.py::TestPatternWhitelist \ - tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion.py::TestSourceArgsWhitelist - - echo "" - echo "=== [node=${NODE}] Phase 2: Integration (legacy TP=1/PP=1) ===" - echo "" - python tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion_integration.py - - echo "" - echo "=== [node=${NODE}] Phase 3: Parallelism matrix (TP/PP/FSDP/combos) ===" - echo "" - python -m pytest -vs \ - tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion_parallelism.py - ' - -echo "======================================================" -echo "Test complete. 
Logs: ${LOG_DIR}" -echo "======================================================" diff --git a/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion.py b/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion.py index 5ab5a435bdd..2b3aca046be 100644 --- a/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion.py +++ b/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion.py @@ -31,18 +31,15 @@ from gpt_mamba_conversion import ( build_layer_index_mapping, - combine_tp_tensors, convert_gpt_to_mamba, convert_mamba_to_gpt, get_layer_num_from_key, - get_split_dim, initialize_ssm_layer_params, is_attention_param, is_mlp_param, is_ssm_param, parse_hybrid_layer_pattern, replace_layer_num, - split_tensor_for_tp, validate_pattern_gpt_compatible, validate_source_args_gpt_compatible, ) @@ -172,52 +169,6 @@ def test_is_ssm_param(self): assert not is_ssm_param('decoder.layers.0.self_attention.linear_qkv.weight') -# --------------------------------------------------------------------------- -# TP split dim tests -# --------------------------------------------------------------------------- - -class TestTPSplitDim: - def test_embedding_split(self): - assert get_split_dim('embedding.word_embeddings.weight') == 0 - - def test_output_layer_split(self): - assert get_split_dim('output_layer.weight') == 0 - - def test_norm_replicated(self): - assert get_split_dim('decoder.layers.0.input_layernorm.weight') == -1 - - def test_final_layernorm(self): - assert get_split_dim('decoder.final_layernorm.weight') == -1 - - def test_final_norm(self): - assert get_split_dim('decoder.final_norm.weight') == -1 - - def test_qkv_weight_column(self): - assert get_split_dim('decoder.layers.0.self_attention.linear_qkv.weight') == 0 - - def test_proj_weight_row(self): - assert get_split_dim('decoder.layers.0.self_attention.linear_proj.weight') == 1 - - def test_mlp_fc1_column(self): - assert get_split_dim('decoder.layers.0.mlp.linear_fc1.weight') == 0 - - def test_mlp_fc2_row(self): - assert get_split_dim('decoder.layers.0.mlp.linear_fc2.weight') == 1 - - def test_mamba_mixer_norm(self): - assert get_split_dim('decoder.layers.0.mixer.norm.weight') == 0 - - def test_mamba_A_log(self): - assert get_split_dim('decoder.layers.0.mixer.A_log') == 0 - - def test_mamba_out_proj(self): - assert get_split_dim('decoder.layers.0.mixer.out_proj.weight') == 1 - - def test_unknown_key(self): - with pytest.raises(ValueError, match="Unknown tensor name"): - get_split_dim('decoder.layers.0.some_unknown_param') - - # --------------------------------------------------------------------------- # SSM initialization tests # --------------------------------------------------------------------------- @@ -695,99 +646,6 @@ def test_round_trip_different_pattern(self): ) -# --------------------------------------------------------------------------- -# TP combine/split round-trip tests -# --------------------------------------------------------------------------- - -class TestTPCombineSplit: - def test_simple_tensor_roundtrip(self): - params = argparse.Namespace( - mamba_d_inner=256, mamba_d_state=64, mamba2_n_groups=4, - mamba2_n_heads=8, mamba_version=2, target_tp_size=2, - ) - tensor = torch.randn(512, 128) - key = 'decoder.layers.0.mlp.linear_fc1.weight' - dim = 0 - - # Split into 2 - slices = split_tensor_for_tp(params, key, dim, tensor) - assert len(slices) == 2 - assert slices[0].shape == (256, 128) - - # Combine back - combined = combine_tp_tensors(params, key, dim, slices) - assert torch.equal(combined, tensor) - - def 
test_mamba_in_proj_v2_roundtrip(self): - d_inner = 256 - n_groups = 4 - d_state = 64 - n_heads = 8 - d_model = 128 - tp_size = 2 - - params = argparse.Namespace( - mamba_d_inner=d_inner, mamba_d_state=d_state, - mamba2_n_groups=n_groups, mamba2_n_heads=n_heads, - mamba_version=2, target_tp_size=tp_size, - ) - - # Full in_proj: [2*d_inner + 2*n_groups*d_state + n_heads, d_model] - out_dim = 2 * d_inner + 2 * n_groups * d_state + n_heads - tensor = torch.randn(out_dim, d_model) - key = 'decoder.layers.0.mixer.in_proj.weight' - dim = 0 - - slices = split_tensor_for_tp(params, key, dim, tensor) - assert len(slices) == tp_size - - combined = combine_tp_tensors(params, key, dim, slices) - assert torch.allclose(combined, tensor), "Mamba v2 in_proj round-trip failed" - - def test_mamba_conv1d_weight_v2_roundtrip(self): - d_inner = 256 - n_groups = 4 - d_state = 64 - d_conv = 4 - tp_size = 2 - - params = argparse.Namespace( - mamba_d_inner=d_inner, mamba_d_state=d_state, - mamba2_n_groups=n_groups, mamba2_n_heads=8, - mamba_version=2, target_tp_size=tp_size, - ) - - conv_dim = d_inner + 2 * n_groups * d_state - tensor = torch.randn(conv_dim, 1, d_conv) - key = 'decoder.layers.0.mixer.conv1d.weight' - dim = 0 - - slices = split_tensor_for_tp(params, key, dim, tensor) - combined = combine_tp_tensors(params, key, dim, slices) - assert torch.allclose(combined, tensor), "conv1d weight round-trip failed" - - def test_mamba_conv1d_bias_v2_roundtrip(self): - d_inner = 256 - n_groups = 4 - d_state = 64 - tp_size = 2 - - params = argparse.Namespace( - mamba_d_inner=d_inner, mamba_d_state=d_state, - mamba2_n_groups=n_groups, mamba2_n_heads=8, - mamba_version=2, target_tp_size=tp_size, - ) - - conv_dim = d_inner + 2 * n_groups * d_state - tensor = torch.randn(conv_dim) - key = 'decoder.layers.0.mixer.conv1d.bias' - dim = 0 - - slices = split_tensor_for_tp(params, key, dim, tensor) - combined = combine_tp_tensors(params, key, dim, slices) - assert torch.allclose(combined, tensor), "conv1d bias round-trip failed" - - # --------------------------------------------------------------------------- # GPT compatibility whitelist tests # --------------------------------------------------------------------------- diff --git a/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion_integration.py b/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion_integration.py deleted file mode 100644 index 1dc499e13b9..00000000000 --- a/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion_integration.py +++ /dev/null @@ -1,467 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. - -""" -Integration tests for gpt_mamba_conversion.py. - -Creates minimal synthetic GPT checkpoints on disk, runs the full conversion -pipeline (load -> combine TP -> stitch PP -> convert -> split TP/PP -> save), -and verifies: - - Shapes, dtypes, and key names in the output checkpoint. - - Round-trip GPT -> Mamba -> GPT preserves attention and MLP weights exactly. - -Designed to run on a single-GPU node via SLURM (no distributed launch needed). 
-""" - -import argparse -import copy -import os -import shutil -import sys -import tempfile -from collections import OrderedDict -from types import SimpleNamespace - -import torch - -# Ensure the conversion tool is importable -sys.path.insert( - 0, - os.path.join(os.path.dirname(__file__), '..', '..', '..', '..', 'tools', 'checkpoint'), -) - -from gpt_mamba_conversion import ( - get_checkpoint_iteration, - initialize_ssm_layer_params, - main as conversion_main, - parse_hybrid_layer_pattern, -) - - -# --------------------------------------------------------------------------- -# Helpers: create a minimal on-disk GPT checkpoint -# --------------------------------------------------------------------------- - -def make_checkpoint_args( - num_layers=4, - hidden_size=128, - num_attention_heads=4, - seq_length=256, - max_position_embeddings=256, - tp_size=1, - pp_size=1, - iteration=100, -): - """Build a minimal checkpoint 'args' namespace mirroring Megatron's.""" - return SimpleNamespace( - num_layers=num_layers, - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - ffn_hidden_size=hidden_size * 4, - seq_length=seq_length, - max_position_embeddings=max_position_embeddings, - tensor_model_parallel_size=tp_size, - pipeline_model_parallel_size=pp_size, - iteration=iteration, - consumed_train_samples=0, - consumed_valid_samples=0, - train_iters=1000, - train_samples=0, - tokenizer_type='GPT2BPETokenizer', - position_embedding_type='rope', - params_dtype=torch.float32, - fp16=False, - bf16=False, - ) - - -def make_gpt_state_dict(num_layers, hidden_size, vocab_size=1024, dtype=torch.float32): - """Create a minimal GPT model state dict.""" - sd = OrderedDict() - - sd['embedding.word_embeddings.weight'] = torch.randn(vocab_size, hidden_size, dtype=dtype) - - for i in range(num_layers): - p = f'decoder.layers.{i}.' - # attention - sd[p + 'input_layernorm.weight'] = torch.randn(hidden_size, dtype=dtype) - sd[p + 'self_attention.linear_qkv.weight'] = torch.randn( - 3 * hidden_size, hidden_size, dtype=dtype - ) - sd[p + 'self_attention.linear_proj.weight'] = torch.randn( - hidden_size, hidden_size, dtype=dtype - ) - # MLP - sd[p + 'pre_mlp_layernorm.weight'] = torch.randn(hidden_size, dtype=dtype) - sd[p + 'mlp.linear_fc1.weight'] = torch.randn( - 4 * hidden_size, hidden_size, dtype=dtype - ) - sd[p + 'mlp.linear_fc2.weight'] = torch.randn( - hidden_size, 4 * hidden_size, dtype=dtype - ) - - sd['decoder.final_layernorm.weight'] = torch.randn(hidden_size, dtype=dtype) - sd['output_layer.weight'] = torch.randn(vocab_size, hidden_size, dtype=dtype) - - return sd - - -def write_checkpoint_to_disk(root_dir, state_dict, ckpt_args, iteration=100): - """Write a single-rank (TP=1, PP=1) checkpoint to disk in Megatron format. 
- - Directory layout: - root_dir/ - latest_checkpointed_iteration.txt -> "100" - iter_0000100/ - mp_rank_00/ - model_optim_rng.pt - """ - iter_dir = os.path.join(root_dir, f'iter_{iteration:07d}', 'mp_rank_00') - os.makedirs(iter_dir, exist_ok=True) - - checkpoint = { - 'model': state_dict, - 'args': copy.deepcopy(ckpt_args), - 'checkpoint_version': 3.0, - 'iteration': iteration, - 'rng_state': [ - { - 'random_rng_state': [0] * 625, - 'np_rng_state': ('MT19937', [0] * 625, 0, 0, 0.0), - 'torch_rng_state': torch.ByteTensor(8), - 'cuda_rng_state': torch.ByteTensor(8), - 'rng_tracker_states': {}, - } - ], - } - - torch.save(checkpoint, os.path.join(iter_dir, 'model_optim_rng.pt')) - - with open(os.path.join(root_dir, 'latest_checkpointed_iteration.txt'), 'w') as f: - f.write(str(iteration)) - - return root_dir - - -def load_converted_state_dict(ckpt_dir): - """Load the state dict from a converted checkpoint (TP=1, PP=1).""" - iteration = get_checkpoint_iteration(ckpt_dir) - model_file = os.path.join( - ckpt_dir, f'iter_{iteration:07d}', 'mp_rank_00', 'model_optim_rng.pt' - ) - checkpoint = torch.load(model_file, map_location='cpu', weights_only=False) - return checkpoint['model'], checkpoint['args'] - - -# --------------------------------------------------------------------------- -# Test 1: GPT -> Mamba shapes, dtypes, and key names -# --------------------------------------------------------------------------- - -def test_gpt_to_mamba_shapes_and_keys(): - """Create a 4-layer GPT ckpt, convert to Mamba with M*-M*-M*-M*-, verify output.""" - print("\n=== Test 1: GPT -> Mamba shapes, dtypes, and key names ===") - - num_layers = 4 - hidden_size = 128 - d_state = 16 - n_groups = 2 - head_dim = 32 - d_inner = hidden_size * 2 - n_heads = d_inner // head_dim - pattern = "M*-M*-M*-M*-" # 12 layers: 4 SSM, 4 attn, 4 MLP - - tmpdir = tempfile.mkdtemp(prefix='gpt_mamba_test_') - try: - src_dir = os.path.join(tmpdir, 'gpt_src') - dst_dir = os.path.join(tmpdir, 'mamba_dst') - - ckpt_args = make_checkpoint_args(num_layers=num_layers, hidden_size=hidden_size) - gpt_sd = make_gpt_state_dict(num_layers, hidden_size) - write_checkpoint_to_disk(src_dir, gpt_sd, ckpt_args) - - # Run conversion - args = argparse.Namespace( - direction='gpt-to-mamba', - load_dir=src_dir, - save_dir=dst_dir, - hybrid_layer_pattern=pattern, - target_tp_size=1, - target_pp_size=1, - d_model=hidden_size, - mamba_version=2, - mamba_d_state=d_state, - mamba2_n_groups=n_groups, - mamba2_head_dim=head_dim, - d_conv=4, - init_method_std=0.02, - reset_iterations=False, - ) - conversion_main(args) - - # Load and verify - mamba_sd, mamba_args = load_converted_state_dict(dst_dir) - - layer_types = parse_hybrid_layer_pattern(pattern) - total_layers = len(layer_types) - - # 1) Check total layer count in args - assert mamba_args.num_layers == total_layers, ( - f"Expected num_layers={total_layers}, got {mamba_args.num_layers}" - ) - - # 2) Check key names - assert 'decoder.final_norm.weight' in mamba_sd, "Missing decoder.final_norm.weight" - assert 'decoder.final_layernorm.weight' not in mamba_sd, "Old final_layernorm key present" - assert 'embedding.word_embeddings.weight' in mamba_sd - assert 'output_layer.weight' in mamba_sd - - # 3) Check SSM layer params exist with correct shapes - ssm_indices = [i for i, t in enumerate(layer_types) if t == 'M'] - conv_dim = d_inner + 2 * n_groups * d_state - in_proj_out = 2 * d_inner + 2 * n_groups * d_state + n_heads - - for idx in ssm_indices: - prefix = f'decoder.layers.{idx}.mixer.' 
- assert prefix + 'A_log' in mamba_sd, f"Missing {prefix}A_log" - assert mamba_sd[prefix + 'A_log'].shape == (n_heads,) - assert mamba_sd[prefix + 'A_log'].dtype == torch.float32 - - assert prefix + 'D' in mamba_sd - assert mamba_sd[prefix + 'D'].shape == (n_heads,) - - assert prefix + 'dt_bias' in mamba_sd - assert mamba_sd[prefix + 'dt_bias'].shape == (n_heads,) - - assert prefix + 'conv1d.weight' in mamba_sd - assert mamba_sd[prefix + 'conv1d.weight'].shape == (conv_dim, 1, 4) - - assert prefix + 'conv1d.bias' in mamba_sd - assert mamba_sd[prefix + 'conv1d.bias'].shape == (conv_dim,) - - assert prefix + 'in_proj.weight' in mamba_sd - assert mamba_sd[prefix + 'in_proj.weight'].shape == (in_proj_out, hidden_size) - - assert prefix + 'norm.weight' in mamba_sd - assert mamba_sd[prefix + 'norm.weight'].shape == (d_inner,) - - assert prefix + 'out_proj.weight' in mamba_sd - assert mamba_sd[prefix + 'out_proj.weight'].shape == (hidden_size, d_inner) - - # 4) Check attention layer params exist at correct indices - attn_indices = [i for i, t in enumerate(layer_types) if t == '*'] - for idx in attn_indices: - prefix = f'decoder.layers.{idx}.' - assert prefix + 'self_attention.linear_qkv.weight' in mamba_sd - assert mamba_sd[prefix + 'self_attention.linear_qkv.weight'].shape == ( - 3 * hidden_size, hidden_size - ) - - # 5) Check MLP layer params exist at correct indices - mlp_indices = [i for i, t in enumerate(layer_types) if t == '-'] - for idx in mlp_indices: - prefix = f'decoder.layers.{idx}.' - assert prefix + 'mlp.linear_fc1.weight' in mamba_sd - assert mamba_sd[prefix + 'mlp.linear_fc1.weight'].shape == ( - 4 * hidden_size, hidden_size - ) - - print("PASSED: All shapes, dtypes, and key names verified.\n") - - finally: - shutil.rmtree(tmpdir, ignore_errors=True) - - -# --------------------------------------------------------------------------- -# Test 2: Round-trip GPT -> Mamba -> GPT weight preservation -# --------------------------------------------------------------------------- - -def test_roundtrip_weight_preservation(): - """Convert GPT -> Mamba -> GPT and verify attention/MLP weights match exactly.""" - print("\n=== Test 2: Round-trip GPT -> Mamba -> GPT weight preservation ===") - - num_layers = 2 - hidden_size = 64 - pattern = "M*-M*-" - - tmpdir = tempfile.mkdtemp(prefix='gpt_mamba_rt_test_') - try: - src_gpt_dir = os.path.join(tmpdir, 'gpt_src') - mamba_dir = os.path.join(tmpdir, 'mamba_mid') - dst_gpt_dir = os.path.join(tmpdir, 'gpt_dst') - - # Create and save source GPT checkpoint - ckpt_args = make_checkpoint_args(num_layers=num_layers, hidden_size=hidden_size) - gpt_sd = make_gpt_state_dict(num_layers, hidden_size) - write_checkpoint_to_disk(src_gpt_dir, gpt_sd, ckpt_args) - - common_args = dict( - hybrid_layer_pattern=pattern, - target_tp_size=1, - target_pp_size=1, - d_model=hidden_size, - mamba_version=2, - mamba_d_state=16, - mamba2_n_groups=2, - mamba2_head_dim=32, - d_conv=4, - init_method_std=0.02, - reset_iterations=False, - ) - - # Step 1: GPT -> Mamba - conversion_main(argparse.Namespace( - direction='gpt-to-mamba', - load_dir=src_gpt_dir, - save_dir=mamba_dir, - **common_args, - )) - - # Step 2: Mamba -> GPT - conversion_main(argparse.Namespace( - direction='mamba-to-gpt', - load_dir=mamba_dir, - save_dir=dst_gpt_dir, - **common_args, - )) - - # Load and compare - recovered_sd, recovered_args = load_converted_state_dict(dst_gpt_dir) - - assert recovered_args.num_layers == num_layers, ( - f"Expected num_layers={num_layers}, got {recovered_args.num_layers}" - ) - - # 
Compare every key in the original - mismatches = [] - for key, original_tensor in gpt_sd.items(): - # final_layernorm is renamed in the round trip - if key not in recovered_sd: - mismatches.append(f"MISSING: {key}") - continue - if not torch.equal(original_tensor, recovered_sd[key]): - max_diff = (original_tensor - recovered_sd[key]).abs().max().item() - mismatches.append(f"MISMATCH: {key} (max_diff={max_diff})") - - if mismatches: - for m in mismatches: - print(f" FAIL: {m}") - raise AssertionError( - f"Round-trip failed with {len(mismatches)} mismatches:\n" - + "\n".join(mismatches) - ) - - print("PASSED: All attention and MLP weights preserved exactly.\n") - - finally: - shutil.rmtree(tmpdir, ignore_errors=True) - - -# --------------------------------------------------------------------------- -# Test 3: Verify that Mamba -> GPT discards SSM params cleanly -# --------------------------------------------------------------------------- - -def test_mamba_to_gpt_discards_ssm(): - """Convert Mamba -> GPT and verify no SSM keys leak through.""" - print("\n=== Test 3: Mamba -> GPT discards SSM params ===") - - hidden_size = 64 - pattern = "M*-M*-" - d_inner = hidden_size * 2 - d_state = 16 - n_groups = 2 - head_dim = 32 - n_heads = d_inner // head_dim - - tmpdir = tempfile.mkdtemp(prefix='gpt_mamba_discard_test_') - try: - # Build a Mamba-style state dict - mamba_sd = OrderedDict() - mamba_sd['embedding.word_embeddings.weight'] = torch.randn(512, hidden_size) - mamba_sd['output_layer.weight'] = torch.randn(512, hidden_size) - mamba_sd['decoder.final_norm.weight'] = torch.randn(hidden_size) - - layer_types = parse_hybrid_layer_pattern(pattern) - for i, lt in enumerate(layer_types): - p = f'decoder.layers.{i}.' - if lt == 'M': - ssm = initialize_ssm_layer_params( - i, hidden_size, d_inner, d_state, n_groups, n_heads, head_dim - ) - mamba_sd.update(ssm) - elif lt == '*': - mamba_sd[p + 'input_layernorm.weight'] = torch.randn(hidden_size) - mamba_sd[p + 'self_attention.linear_qkv.weight'] = torch.randn( - 3 * hidden_size, hidden_size - ) - mamba_sd[p + 'self_attention.linear_proj.weight'] = torch.randn( - hidden_size, hidden_size - ) - elif lt == '-': - mamba_sd[p + 'pre_mlp_layernorm.weight'] = torch.randn(hidden_size) - mamba_sd[p + 'mlp.linear_fc1.weight'] = torch.randn( - 4 * hidden_size, hidden_size - ) - mamba_sd[p + 'mlp.linear_fc2.weight'] = torch.randn( - hidden_size, 4 * hidden_size - ) - - # Write to disk - src_dir = os.path.join(tmpdir, 'mamba_src') - dst_dir = os.path.join(tmpdir, 'gpt_dst') - ckpt_args = make_checkpoint_args( - num_layers=len(layer_types), hidden_size=hidden_size - ) - write_checkpoint_to_disk(src_dir, mamba_sd, ckpt_args) - - # Convert - conversion_main(argparse.Namespace( - direction='mamba-to-gpt', - load_dir=src_dir, - save_dir=dst_dir, - hybrid_layer_pattern=pattern, - target_tp_size=1, - target_pp_size=1, - d_model=hidden_size, - mamba_version=2, - mamba_d_state=d_state, - mamba2_n_groups=n_groups, - mamba2_head_dim=head_dim, - d_conv=4, - init_method_std=0.02, - reset_iterations=False, - )) - - gpt_sd, gpt_args = load_converted_state_dict(dst_dir) - - # Verify no SSM keys - ssm_keys = [k for k in gpt_sd if 'mixer.' 
in k] - assert len(ssm_keys) == 0, f"SSM keys leaked: {ssm_keys}" - - # Verify correct GPT layer count - assert gpt_args.num_layers == 2 - - # Verify final_layernorm renamed back - assert 'decoder.final_layernorm.weight' in gpt_sd - assert 'decoder.final_norm.weight' not in gpt_sd - - print("PASSED: No SSM keys in GPT output, norms renamed correctly.\n") - - finally: - shutil.rmtree(tmpdir, ignore_errors=True) - - -# --------------------------------------------------------------------------- -# Main -# --------------------------------------------------------------------------- - -if __name__ == '__main__': - print("=" * 60) - print("GPT <-> Mamba Conversion Integration Tests") - print("=" * 60) - - test_gpt_to_mamba_shapes_and_keys() - test_roundtrip_weight_preservation() - test_mamba_to_gpt_discards_ssm() - - print("=" * 60) - print("ALL INTEGRATION TESTS PASSED") - print("=" * 60) diff --git a/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion_parallelism.py b/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion_parallelism.py index 04ecb08a532..d3a082afd16 100644 --- a/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion_parallelism.py +++ b/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion_parallelism.py @@ -3,27 +3,16 @@ """ Parallelism-matrix integration tests for gpt_mamba_conversion.py. -Covers every combination the user cares about: - - source format exercises - TP TP=2, PP=1 legacy TP-combine + TP-split - PP TP=1, PP=2 legacy PP-stitch - FSDP world=1 dist (torch_dist) DCP load + DCP save - TP+PP TP=2, PP=2 legacy TP+PP both paths - TP+FSDP world=1 dist DCP load + DCP save - PP+FSDP world=1 dist DCP load + DCP save - TP+PP+FSDP world=1 dist DCP load + DCP save - -Legacy configs synthesize ``mp_rank_XX[_YYY]/model_optim_rng.pt`` shards by -re-using the converter's own save routine (which implements the exact TP-split -and PP-stitch layout Megatron produces). Dist configs synthesize a DCP -checkpoint via a single-rank ``torch.distributed.checkpoint.save``; at the -converter level the TP/PP/FSDP sharding layout of a dist checkpoint is -abstracted away by DCP's global-shape metadata, so one save code path -exercises every ``*+FSDP`` combination. Each config is run as a distinct test -to document the matrix and catch regressions in the dispatch logic. - -Designed to run on a single-GPU node via SLURM; no torchrun needed. +The converter operates on dist (``torch_dist`` / ``fsdp_dtensor``) checkpoints +only — DCP's metadata stores each tensor's ``global_shape``, so the on-disk +TP / PP / FSDP layout is abstracted away from the conversion logic. We +synthesize a DCP checkpoint via a single-rank ``dcp.save`` and round-trip +GPT -> Mamba -> GPT through the conversion CLI, asserting attention and MLP +weights match exactly. + +Each scenario is run as a distinct test to document the supported matrix and +catch regressions in dispatch logic. Designed to run on a single-GPU node via +SLURM (no torchrun needed). """ import argparse @@ -37,68 +26,73 @@ import torch -# Make the conversion tool and the sibling integration-test helpers importable -# under both `python ` and `pytest` (pytest doesn't put the test file's -# directory on sys.path). +# Make the conversion tool importable under both `python ` and `pytest`. 
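+# Both the repo's tools/checkpoint directory (for gpt_mamba_conversion and
+# dist_checkpoint_io) and this test directory are prepended to sys.path.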
_THIS_DIR = os.path.dirname(os.path.abspath(__file__)) _REPO_ROOT = os.path.join(_THIS_DIR, '..', '..', '..', '..') sys.path.insert(0, os.path.join(_REPO_ROOT, 'tools', 'checkpoint')) sys.path.insert(0, _THIS_DIR) -from gpt_mamba_conversion import ( - combine_tp_shards, - convert_gpt_to_mamba, - get_checkpoint_iteration, - load_checkpoint_shards, - main as conversion_main, - parse_hybrid_layer_pattern, - save_checkpoint_shards, - stitch_pp_shards, -) - -from test_gpt_mamba_conversion_integration import ( - make_checkpoint_args, - make_gpt_state_dict, -) +from gpt_mamba_conversion import main as conversion_main # --------------------------------------------------------------------------- -# Legacy (mp_rank_XX) fixture builders +# Synthetic-checkpoint helpers # --------------------------------------------------------------------------- -def _save_legacy_sharded(root_dir, full_sd, ckpt_args, tp_size, pp_size, - hybrid_layer_pattern='', - hidden_size=128, - iteration=100): - """Write a full state dict to disk as a sharded legacy checkpoint. - - We delegate to ``save_checkpoint_shards`` so the on-disk layout matches - exactly what Megatron training would produce at the given TP/PP. - """ - # save_checkpoint_shards expects a "sample_model" shape that mirrors a - # single rank's on-disk file. Any args object with the target fields works. - ckpt_args = copy.deepcopy(ckpt_args) - ckpt_args.tensor_model_parallel_size = tp_size - ckpt_args.pipeline_model_parallel_size = pp_size - sample_model = { - 'args': ckpt_args, - 'checkpoint_version': 3.0, - 'iteration': iteration, - 'rng_state': [], - } - params = SimpleNamespace( - target_tp_size=tp_size, - target_pp_size=pp_size, - target_num_layers=ckpt_args.num_layers, - reset_iterations=False, - # Mamba-only TP-split args; irrelevant for pure GPT shards but required. - mamba_version=2, - mamba_d_inner=hidden_size * 2, - mamba_d_state=16, - mamba2_n_groups=2, - mamba2_n_heads=hidden_size * 2 // 32, +def make_checkpoint_args( + num_layers=4, + hidden_size=128, + num_attention_heads=4, + seq_length=256, + max_position_embeddings=256, + iteration=100, +): + """Build a minimal checkpoint 'args' namespace mirroring Megatron's.""" + return SimpleNamespace( + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + ffn_hidden_size=hidden_size * 4, + seq_length=seq_length, + max_position_embeddings=max_position_embeddings, + iteration=iteration, + consumed_train_samples=0, + consumed_valid_samples=0, + train_iters=1000, + train_samples=0, + tokenizer_type='GPT2BPETokenizer', + position_embedding_type='rope', + params_dtype=torch.float32, + fp16=False, + bf16=False, ) - save_checkpoint_shards(full_sd, sample_model, params, root_dir, iteration) + + +def make_gpt_state_dict(num_layers, hidden_size, vocab_size=1024, dtype=torch.float32): + """Create a minimal GPT model state dict with the standard Megatron keys.""" + sd = OrderedDict() + sd['embedding.word_embeddings.weight'] = torch.randn(vocab_size, hidden_size, dtype=dtype) + + for i in range(num_layers): + p = f'decoder.layers.{i}.' 
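+        # attention sub-layer params; MLP sub-layer params follow below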
+ sd[p + 'input_layernorm.weight'] = torch.randn(hidden_size, dtype=dtype) + sd[p + 'self_attention.linear_qkv.weight'] = torch.randn( + 3 * hidden_size, hidden_size, dtype=dtype + ) + sd[p + 'self_attention.linear_proj.weight'] = torch.randn( + hidden_size, hidden_size, dtype=dtype + ) + sd[p + 'pre_mlp_layernorm.weight'] = torch.randn(hidden_size, dtype=dtype) + sd[p + 'mlp.linear_fc1.weight'] = torch.randn( + 4 * hidden_size, hidden_size, dtype=dtype + ) + sd[p + 'mlp.linear_fc2.weight'] = torch.randn( + hidden_size, 4 * hidden_size, dtype=dtype + ) + + sd['decoder.final_layernorm.weight'] = torch.randn(hidden_size, dtype=dtype) + sd['output_layer.weight'] = torch.randn(vocab_size, hidden_size, dtype=dtype) + return sd # --------------------------------------------------------------------------- @@ -135,10 +129,6 @@ def _save_dist_checkpoint(root_dir, full_sd, ckpt_args, iteration=100, write_latest_iteration_marker(iter_dir, iteration) -# --------------------------------------------------------------------------- -# Output readers -# --------------------------------------------------------------------------- - def _load_converted_dist(ckpt_dir): """Read a dist-format converted checkpoint back into a full state dict.""" from dist_checkpoint_io import load_dist_checkpoint_full @@ -146,48 +136,6 @@ def _load_converted_dist(ckpt_dir): return sd, common.get('args', None) -def _load_converted_legacy_full(ckpt_dir): - """Read a legacy TP+PP-sharded converted checkpoint into a full state dict. - - Peeks the first shard to discover TP/PP sizes and total layers, then reuses - the converter's own load / TP-combine / PP-stitch routines. - """ - iteration = get_checkpoint_iteration(ckpt_dir) - model_dir = os.path.join(ckpt_dir, f'iter_{iteration:07d}') - first_shard = sorted(os.listdir(model_dir))[0] - sample = torch.load( - os.path.join(model_dir, first_shard, 'model_optim_rng.pt'), - map_location='cpu', weights_only=False, - ) - tp_size = sample['args'].tensor_model_parallel_size - pp_size = sample['args'].pipeline_model_parallel_size - num_layers = sample['args'].num_layers - num_layers_per_pp_rank = num_layers // pp_size - - all_shards, sample = load_checkpoint_shards( - ckpt_dir, iteration, tp_size, pp_size, - ) - # combine_tp_tensors only touches mamba-specific branches for mamba keys; - # any hidden_size-consistent defaults work for GPT-only outputs. 
- combine_params = SimpleNamespace( - mamba_version=2, - mamba_d_inner=0, - mamba_d_state=0, - mamba2_n_groups=0, - mamba2_n_heads=0, - ) - combined_pp = [combine_tp_shards(all_shards[pp], combine_params) - for pp in range(pp_size)] - full = stitch_pp_shards(combined_pp, num_layers_per_pp_rank) - return full, sample['args'] - - -def _load_converted(ckpt_dir, output_format): - if output_format == 'legacy': - return _load_converted_legacy_full(ckpt_dir) - return _load_converted_dist(ckpt_dir) - - # --------------------------------------------------------------------------- # Core scenario runner # --------------------------------------------------------------------------- @@ -195,11 +143,7 @@ def _load_converted(ckpt_dir, output_format): def _run_scenario( label, source_format, - source_tp, - source_pp, target_format, - target_tp=1, - target_pp=1, num_layers=4, hidden_size=128, pattern="M*-M*-M*-M*-", @@ -207,8 +151,8 @@ def _run_scenario( ): """Build a GPT source ckpt, convert GPT->Mamba->GPT, verify round-trip.""" print(f"\n=== {label} ===") - print(f" source={source_format} (tp={source_tp}, pp={source_pp}, prefix='{source_prefix}')") - print(f" target={target_format} (tp={target_tp}, pp={target_pp})") + print(f" source={source_format} (prefix='{source_prefix}')") + print(f" target={target_format}") tmpdir = tempfile.mkdtemp(prefix=f'gpt_mamba_{label.replace(" ", "_")}_') try: @@ -216,28 +160,16 @@ def _run_scenario( mamba_dir = os.path.join(tmpdir, 'mamba_mid') dst_gpt_dir = os.path.join(tmpdir, 'gpt_dst') - # --- Build source --- - ckpt_args = make_checkpoint_args( - num_layers=num_layers, hidden_size=hidden_size, - tp_size=source_tp, pp_size=source_pp, - ) + ckpt_args = make_checkpoint_args(num_layers=num_layers, hidden_size=hidden_size) gpt_sd = make_gpt_state_dict(num_layers, hidden_size) - if source_format == 'legacy': - _save_legacy_sharded( - src_gpt_dir, gpt_sd, ckpt_args, source_tp, source_pp, - hidden_size=hidden_size, - ) - else: - _save_dist_checkpoint( - src_gpt_dir, gpt_sd, ckpt_args, - prefix=source_prefix, backend=source_format, - ) + _save_dist_checkpoint( + src_gpt_dir, gpt_sd, ckpt_args, + prefix=source_prefix, backend=source_format, + ) common_kwargs = dict( hybrid_layer_pattern=pattern, - target_tp_size=target_tp, - target_pp_size=target_pp, d_model=hidden_size, mamba_version=2, mamba_d_state=16, @@ -267,8 +199,10 @@ def _run_scenario( )) # --- Verify --- - recovered_sd, recovered_args = _load_converted(dst_gpt_dir, target_format) - layer_types = parse_hybrid_layer_pattern(pattern) + recovered_sd, _ = _load_converted_dist(dst_gpt_dir) + # The mamba->gpt step renames decoder.final_norm -> decoder.final_layernorm, + # mirroring the original GPT key. So recovered_sd should have the same + # keys and tensor values as gpt_sd. 
mismatches = [] for key, original in gpt_sd.items(): @@ -297,46 +231,42 @@ def _run_scenario( # --------------------------------------------------------------------------- -# Test cases — one per parallelism combo +# Test cases — one per (source backend, target backend, pattern) combo # --------------------------------------------------------------------------- -def test_tp_only_legacy(): - _run_scenario("TP only (legacy)", 'legacy', 2, 1, 'legacy', target_tp=2, target_pp=1) - - -def test_pp_only_legacy(): - _run_scenario("PP only (legacy)", 'legacy', 1, 2, 'legacy', target_tp=1, target_pp=2) - - -def test_tp_pp_legacy(): - _run_scenario("TP+PP (legacy)", 'legacy', 2, 2, 'legacy', target_tp=2, target_pp=2) - +def test_torch_dist_roundtrip(): + _run_scenario("torch_dist roundtrip", 'torch_dist', 'torch_dist') -def test_fsdp_only_dist(): - _run_scenario("FSDP only (torch_dist)", 'torch_dist', 1, 1, 'torch_dist') - -def test_tp_fsdp_dist(): - _run_scenario("TP + FSDP (torch_dist)", 'torch_dist', 1, 1, 'torch_dist') - - -def test_pp_fsdp_dist(): - _run_scenario("PP + FSDP (torch_dist)", 'torch_dist', 1, 1, 'torch_dist') - - -def test_tp_pp_fsdp_dist(): - _run_scenario("TP+PP+FSDP (torch_dist)", 'torch_dist', 1, 1, 'torch_dist') +def test_fsdp_dtensor_roundtrip(): + _run_scenario("fsdp_dtensor roundtrip", 'fsdp_dtensor', 'fsdp_dtensor') def test_fsdp_dtensor_prefix(): """fsdp_dtensor backend uses the 'model.module.' key prefix — verify we auto-detect and strip it correctly.""" _run_scenario( - "FSDP dtensor prefix", 'fsdp_dtensor', 1, 1, 'fsdp_dtensor', + "fsdp_dtensor prefix", 'fsdp_dtensor', 'fsdp_dtensor', source_prefix='model.module.', ) +def test_torch_dist_alternating_pattern(): + """Pure transformer pattern (no SSM) round-trips.""" + _run_scenario( + "torch_dist alternating", 'torch_dist', 'torch_dist', + pattern="*-*-*-*-", + ) + + +def test_torch_dist_dense_ssm_pattern(): + """Dense SSM pattern still round-trips on the attn/MLP layers.""" + _run_scenario( + "torch_dist dense SSM", 'torch_dist', 'torch_dist', + pattern="MM*-MM*-MM*-MM*-", + ) + + # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- @@ -346,14 +276,11 @@ def test_fsdp_dtensor_prefix(): print("GPT <-> Mamba Conversion Parallelism Matrix Tests") print("=" * 60) - test_tp_only_legacy() - test_pp_only_legacy() - test_tp_pp_legacy() - test_fsdp_only_dist() - test_tp_fsdp_dist() - test_pp_fsdp_dist() - test_tp_pp_fsdp_dist() + test_torch_dist_roundtrip() + test_fsdp_dtensor_roundtrip() test_fsdp_dtensor_prefix() + test_torch_dist_alternating_pattern() + test_torch_dist_dense_ssm_pattern() print("=" * 60) print("ALL PARALLELISM MATRIX TESTS PASSED") diff --git a/tools/checkpoint/dist_checkpoint_io.py b/tools/checkpoint/dist_checkpoint_io.py index 1fece3a3f91..33f0814cc5a 100644 --- a/tools/checkpoint/dist_checkpoint_io.py +++ b/tools/checkpoint/dist_checkpoint_io.py @@ -44,7 +44,6 @@ ) -FORMAT_LEGACY = 'legacy' FORMAT_TORCH_DIST = 'torch_dist' FORMAT_FSDP_DTENSOR = 'fsdp_dtensor' DIST_FORMATS = (FORMAT_TORCH_DIST, FORMAT_FSDP_DTENSOR) @@ -89,7 +88,12 @@ def resolve_checkpoint_subdir(load_dir): def detect_checkpoint_format(load_dir): - """Return one of ``{'legacy', 'torch_dist', 'fsdp_dtensor'}``.""" + """Return one of ``{'torch_dist', 'fsdp_dtensor'}``. + + Raises ``ValueError`` if the directory looks like the legacy + ``mp_rank_XX`` layout (no longer supported) or doesn't match any known + dist-checkpoint metadata. 
+ """ ckpt_dir, _ = resolve_checkpoint_subdir(load_dir) config = maybe_load_config(ckpt_dir) if config is not None: @@ -98,7 +102,10 @@ def detect_checkpoint_format(load_dir): if os.path.isdir(ckpt_dir) and any( name.startswith('mp_rank_') for name in os.listdir(ckpt_dir) ): - return FORMAT_LEGACY + raise ValueError( + f"{load_dir} looks like a legacy mp_rank_XX checkpoint. " + f"Legacy format is no longer supported — convert to torch_dist first." + ) raise ValueError(f"Unrecognized checkpoint format at {load_dir}") diff --git a/tools/checkpoint/gpt_mamba_conversion.py b/tools/checkpoint/gpt_mamba_conversion.py index ff3d5851b4c..77780125ae8 100644 --- a/tools/checkpoint/gpt_mamba_conversion.py +++ b/tools/checkpoint/gpt_mamba_conversion.py @@ -41,14 +41,15 @@ mamba-to-gpt: SSM layers are discarded with a warning. Supported checkpoint formats: - - legacy : mp_rank_XX[_YYY]/model_optim_rng.pt (TP + PP, no FSDP). - torch_dist : Megatron distributed checkpoint (TP + PP + FSDP). - fsdp_dtensor : FSDP DTensor export (TP + PP + FSDP). - For distributed formats, PyTorch DCP gathers TP/PP/FSDP shards via the - checkpoint's global-shape metadata, so no explicit TP/PP/DP config is - needed on input. The input format is auto-detected; the output format - defaults to the input format. + PyTorch DCP gathers TP/PP/FSDP shards via the checkpoint's global-shape + metadata, so no explicit TP/PP/DP config is needed on input. The input + format is auto-detected; the output format defaults to the input format. + + The legacy ``mp_rank_XX/model_optim_rng.pt`` layout is not supported — + convert old checkpoints to ``torch_dist`` first. GPT compatibility whitelist (safeguard): GPTModel is a strict homogeneous transformer (self-attention + MLP per @@ -73,19 +74,6 @@ `validate_source_args_gpt_compatible` for the exact rules. Example commands: - # GPT -> Mamba (legacy TP+PP checkpoint) - python tools/checkpoint/gpt_mamba_conversion.py \\ - --direction gpt-to-mamba \\ - --load-dir /path/to/gpt-checkpoint \\ - --save-dir /path/to/mamba-checkpoint \\ - --hybrid-layer-pattern "M*-M*-M*-M*-" \\ - --target-tp-size 1 \\ - --target-pp-size 1 \\ - --d-model 4096 \\ - --mamba-d-state 128 \\ - --mamba2-n-groups 8 \\ - --mamba2-head-dim 64 - # GPT -> Mamba (TP+PP+FSDP dist checkpoint) python tools/checkpoint/gpt_mamba_conversion.py \\ --direction gpt-to-mamba \\ @@ -97,14 +85,12 @@ --mamba2-n-groups 8 \\ --mamba2-head-dim 64 - # Mamba -> GPT (legacy) + # Mamba -> GPT (dist checkpoint) python tools/checkpoint/gpt_mamba_conversion.py \\ --direction mamba-to-gpt \\ - --load-dir /path/to/mamba-checkpoint \\ - --save-dir /path/to/gpt-checkpoint \\ + --load-dir /path/to/mamba-dist-checkpoint \\ + --save-dir /path/to/gpt-dist-checkpoint \\ --hybrid-layer-pattern "M*-M*-M*-M*-" \\ - --target-tp-size 1 \\ - --target-pp-size 1 \\ --d-model 4096 \\ --mamba-d-state 128 \\ --mamba2-n-groups 8 \\ @@ -122,7 +108,6 @@ from dist_checkpoint_io import ( DIST_FORMATS, - FORMAT_LEGACY, FORMAT_TORCH_DIST, detect_checkpoint_format, load_dist_checkpoint_full, @@ -131,234 +116,6 @@ ) -# --------------------------------------------------------------------------- -# TP split-dim mapping (reused from hybrid_conversion.py) -# --------------------------------------------------------------------------- - -# Maps parameter-name substrings to the tensor dimension along which they are -# sharded across TP ranks. -1 means "replicated" (not sharded). 
-TP_SPLIT_DIM = { - # embeddings / output - 'word_embeddings.weight': 0, - 'output_layer.weight': 0, - # norms (replicated) - 'norm.weight': -1, - 'final_norm.weight': -1, - 'final_layernorm.weight': -1, - 'final_layernorm.bias': -1, - # mamba SSM params - 'A_log': 0, - 'D': 0, - 'dt_bias': 0, - 'in_proj.weight': 0, - 'conv1d.weight': 0, - 'conv1d.bias': 0, - 'x_proj.weight': 1, - 'dt_proj.weight': 0, - 'dt_proj.bias': 0, - 'out_proj.weight': 1, - 'mixer.norm.weight': 0, - # MLP (transformer-style) - 'linear_fc1.layer_norm_weight': -1, - 'linear_fc1.weight': 0, - 'linear_fc2.weight': 1, - # attention (transformer-style) - 'self_attention.linear_proj.weight': 1, - 'self_attention.linear_qkv.layer_norm_weight': -1, - 'self_attention.linear_qkv.weight': 0, - # standalone layer norms (used in non-TE / "local" transformer impl) - 'input_layernorm.weight': -1, - 'input_layernorm.bias': -1, - 'pre_mlp_layernorm.weight': -1, - 'pre_mlp_layernorm.bias': -1, - # TE-fused layer norms in Mamba in_proj - 'in_proj.layer_norm_weight': -1, - 'in_proj.layer_norm_bias': -1, -} - - -def get_split_dim(tensor_name): - """Determine the TP-split dimension for a given parameter name.""" - # Disambiguate mixer.norm.weight vs generic norm.weight - if 'norm.weight' in tensor_name: - if 'mixer.norm.weight' in tensor_name: - return TP_SPLIT_DIM['mixer.norm.weight'] - elif 'final_norm.weight' in tensor_name: - return TP_SPLIT_DIM['final_norm.weight'] - elif 'final_layernorm.weight' in tensor_name: - return TP_SPLIT_DIM['final_layernorm.weight'] - elif 'layer_norm_weight' in tensor_name: - # TE-fused layer norm weights - for key in TP_SPLIT_DIM: - if key in tensor_name: - return TP_SPLIT_DIM[key] - return -1 - else: - return TP_SPLIT_DIM['norm.weight'] - - for key in TP_SPLIT_DIM: - if key in tensor_name: - return TP_SPLIT_DIM[key] - raise ValueError(f"Unknown tensor name for TP splitting: {tensor_name}") - - -# --------------------------------------------------------------------------- -# TP combine / split (reused from hybrid_conversion.py) -# --------------------------------------------------------------------------- - -def combine_tp_tensors(params, key, dim, tensors): - """Combine TP-sharded tensors back into one full tensor. - - Handles special Mamba v2 in_proj and conv1d interleaved layouts. 
- """ - tp_size = len(tensors) - - if 'mixer.in_proj.weight' in key and params.mamba_version == 1: - xs, zs = [], [] - for tensor in tensors: - x, z = torch.split( - tensor, - [params.mamba_d_inner // tp_size, params.mamba_d_inner // tp_size], - dim=dim, - ) - xs.append(x) - zs.append(z) - return torch.cat([torch.cat(xs, dim=dim), torch.cat(zs, dim=dim)], dim=dim) - - elif 'mixer.in_proj.weight' in key and params.mamba_version == 2: - xs, zs, Bs, Cs, dts = [], [], [], [], [] - for tensor in tensors: - x, z, B, C, dt = torch.split( - tensor, - [ - params.mamba_d_inner // tp_size, - params.mamba_d_inner // tp_size, - (params.mamba2_n_groups // tp_size) * params.mamba_d_state, - (params.mamba2_n_groups // tp_size) * params.mamba_d_state, - params.mamba2_n_heads // tp_size, - ], - dim=dim, - ) - xs.append(x) - zs.append(z) - Bs.append(B) - Cs.append(C) - dts.append(dt) - - for ii in range(len(Bs)): - Bs[ii] = Bs[ii].reshape(-1, params.mamba_d_state, Bs[ii].shape[-1]) - Cs[ii] = Cs[ii].reshape(-1, params.mamba_d_state, Cs[ii].shape[-1]) - B = torch.cat(Bs, dim=dim) - C = torch.cat(Cs, dim=dim) - x = torch.cat(xs, dim=dim) - z = torch.cat(zs, dim=dim) - dt = torch.cat(dts, dim=dim) - return torch.cat([x, z, B.flatten(0, 1), C.flatten(0, 1), dt], dim=dim) - - elif 'mixer.conv1d' in key and params.mamba_version == 2: - xs, Bs, Cs = [], [], [] - for tensor in tensors: - x, B, C = torch.split( - tensor, - [ - params.mamba_d_inner // tp_size, - (params.mamba2_n_groups // tp_size) * params.mamba_d_state, - (params.mamba2_n_groups // tp_size) * params.mamba_d_state, - ], - dim=dim, - ) - xs.append(x) - Bs.append(B) - Cs.append(C) - - for ii in range(len(Bs)): - if 'weight' in key: - Bs[ii] = Bs[ii].reshape(-1, params.mamba_d_state, Bs[ii].shape[-2], Bs[ii].shape[-1]) - Cs[ii] = Cs[ii].reshape(-1, params.mamba_d_state, Cs[ii].shape[-2], Cs[ii].shape[-1]) - elif 'bias' in key: - Bs[ii] = Bs[ii].reshape(-1, params.mamba_d_state) - Cs[ii] = Cs[ii].reshape(-1, params.mamba_d_state) - else: - raise ValueError(f"Unknown conv1d key: {key}") - B = torch.cat(Bs, dim=dim) - C = torch.cat(Cs, dim=dim) - x = torch.cat(xs, dim=dim) - return torch.cat([x, B.flatten(0, 1), C.flatten(0, 1)], dim=dim) - - else: - return torch.cat(tensors, dim=dim) - - -def split_tensor_for_tp(params, key, dim, tensor): - """Split a full tensor into TP shards. - - Handles special Mamba v2 in_proj and conv1d interleaved layouts. 
- """ - tp_size = params.target_tp_size - - if 'mixer.in_proj.weight' in key and params.mamba_version == 1: - x, z = torch.split( - tensor, [params.mamba_d_inner, params.mamba_d_inner], dim=dim - ) - x_sliced = torch.chunk(x, tp_size, dim=dim) - z_sliced = torch.chunk(z, tp_size, dim=dim) - return [torch.cat((xi, zi), dim=dim) for xi, zi in zip(x_sliced, z_sliced)] - - elif 'mixer.in_proj.weight' in key and params.mamba_version == 2: - x, z, B, C, dt = torch.split( - tensor, - [ - params.mamba_d_inner, - params.mamba_d_inner, - params.mamba2_n_groups * params.mamba_d_state, - params.mamba2_n_groups * params.mamba_d_state, - params.mamba2_n_heads, - ], - dim=dim, - ) - B = B.reshape(-1, params.mamba_d_state, B.shape[-1]) - C = C.reshape(-1, params.mamba_d_state, C.shape[-1]) - x_s = torch.chunk(x, tp_size, dim=dim) - z_s = torch.chunk(z, tp_size, dim=dim) - B_s = torch.chunk(B, tp_size, dim=dim) - C_s = torch.chunk(C, tp_size, dim=dim) - dt_s = torch.chunk(dt, tp_size, dim=dim) - return [ - torch.cat((xi, zi, Bi.flatten(0, 1), Ci.flatten(0, 1), dti), dim=dim) - for xi, zi, Bi, Ci, dti in zip(x_s, z_s, B_s, C_s, dt_s) - ] - - elif 'mixer.conv1d' in key and params.mamba_version == 2: - x, B, C = torch.split( - tensor, - [ - params.mamba_d_inner, - params.mamba2_n_groups * params.mamba_d_state, - params.mamba2_n_groups * params.mamba_d_state, - ], - dim=dim, - ) - if 'weight' in key: - B = B.reshape(-1, params.mamba_d_state, B.shape[-2], B.shape[-1]) - C = C.reshape(-1, params.mamba_d_state, C.shape[-2], C.shape[-1]) - elif 'bias' in key: - B = B.reshape(-1, params.mamba_d_state) - C = C.reshape(-1, params.mamba_d_state) - else: - raise ValueError(f"Unknown conv1d key: {key}") - - x_s = torch.chunk(x, tp_size, dim=dim) - B_s = torch.chunk(B, tp_size, dim=dim) - C_s = torch.chunk(C, tp_size, dim=dim) - return [ - torch.cat((xi, Bi.flatten(0, 1), Ci.flatten(0, 1)), dim=dim) - for xi, Bi, Ci in zip(x_s, B_s, C_s) - ] - - else: - return list(torch.chunk(tensor, tp_size, dim=dim)) - - # --------------------------------------------------------------------------- # Hybrid layer pattern parsing (standalone, no Megatron imports needed) # --------------------------------------------------------------------------- @@ -696,211 +453,6 @@ def initialize_ssm_layer_params( return params -# --------------------------------------------------------------------------- -# Checkpoint I/O helpers (patterns from hybrid_conversion.py) -# --------------------------------------------------------------------------- - -def get_checkpoint_iteration(load_dir): - """Read the latest iteration number from a checkpoint directory.""" - tracker_file = os.path.join(load_dir, 'latest_checkpointed_iteration.txt') - with open(tracker_file, 'r') as f: - metastring = f.read().strip() - try: - iteration = int(metastring) - except ValueError: - raise ValueError( - f"Invalid iteration in {tracker_file}: '{metastring}'" - ) - return iteration - - -def load_checkpoint_shards(load_dir, iteration, input_tp_size, input_pp_size): - """Load all TP/PP shards of a checkpoint. 
- - Returns: - list[list[dict]]: models[pp_rank][tp_rank] = checkpoint dict - dict: sample_model (first shard, for metadata) - """ - model_dir = os.path.join(load_dir, f'iter_{iteration:07d}') - sample_model = None - all_shards = [] - - for pp in range(input_pp_size): - tp_shards = [] - for tp in range(input_tp_size): - dir_name = f"mp_rank_{tp:02d}" - if input_pp_size > 1: - dir_name += f"_{pp:03d}" - model_file = os.path.join(model_dir, dir_name, "model_optim_rng.pt") - checkpoint = torch.load(model_file, map_location='cpu', weights_only=False) - tp_shards.append(checkpoint) - if sample_model is None: - sample_model = checkpoint - print(f" Loaded {model_file}") - all_shards.append(tp_shards) - - return all_shards, sample_model - - -def combine_tp_shards(tp_models, params): - """Combine TP-sharded models into a single state dict with full tensors.""" - input_tp_size = len(tp_models) - if input_tp_size == 1: - return OrderedDict(tp_models[0]['model']) - - combined = OrderedDict() - for key, original_tensor in tp_models[0]['model'].items(): - if '_extra_state' in key: - combined[key] = original_tensor - continue - - split_dim = get_split_dim(key) - if split_dim != -1: - tensors = [tp_models[j]['model'][key].cpu() for j in range(input_tp_size)] - combined[key] = combine_tp_tensors(params, key, split_dim, tensors) - else: - combined[key] = original_tensor - - return combined - - -def stitch_pp_shards(all_combined_shards, num_layers_per_pp_rank): - """Stitch PP shards into one flat model with globally-indexed layers.""" - full_model = OrderedDict() - - for pp, combined_shard in enumerate(all_combined_shards): - for key, tensor in combined_shard.items(): - try: - layer_num = int(re.findall(r'\d+', key)[0]) - new_key = key.replace( - str(layer_num), - str(layer_num + pp * num_layers_per_pp_rank), - 1, - ) - except (IndexError, ValueError): - new_key = key - full_model[new_key] = tensor - - return full_model - - -def finalize_checkpoint(sample_model, model, params, verbose=False): - """Finalize checkpoint metadata from a sample source checkpoint.""" - reset_iterations = params.reset_iterations - - model['args'] = copy.deepcopy(sample_model['args']) - model['args'].tensor_model_parallel_size = params.target_tp_size - model['args'].pipeline_model_parallel_size = params.target_pp_size - if reset_iterations: - model['args'].iteration = 0 - model['args'].consumed_valid_samples = 0 - model['args'].consumed_train_samples = 0 - model['args'].train_iters = 0 - model['args'].train_samples = 0 - - model['checkpoint_version'] = copy.deepcopy(sample_model['checkpoint_version']) - - model['iteration'] = copy.deepcopy(sample_model['iteration']) - if reset_iterations: - model['iteration'] = 0 - - if 'opt_param_scheduler' in sample_model: - model['opt_param_scheduler'] = copy.deepcopy(sample_model['opt_param_scheduler']) - - model['rng_state'] = copy.deepcopy(sample_model['rng_state']) - - if verbose: - original_args = sample_model['args'].__dict__ - final_args = model['args'].__dict__ - for key in original_args: - if key in final_args: - if final_args[key] != original_args[key]: - print(f" ARG MISMATCH: {key}") - print(f" original: {original_args[key]}") - print(f" final: {final_args[key]}") - else: - print(f" ARG MISSING from final: {key} = {original_args[key]}") - for key in final_args: - if key not in original_args: - print(f" ARG ADDED to final: {key} = {final_args[key]}") - - return model - - -def save_checkpoint_shards(target_state_dicts, sample_model, params, save_dir, iteration): - """Split and save 
checkpoint for target TP/PP configuration. - - Args: - target_state_dicts: OrderedDict with globally-indexed layer keys (full tensors). - sample_model: Source checkpoint dict for metadata. - params: argparse namespace with target_tp_size, target_pp_size, etc. - save_dir: Output directory. - iteration: Iteration number to write. - """ - total_layers = params.target_num_layers - num_layers_per_pp_rank = total_layers // params.target_pp_size - - out_iteration = iteration if not params.reset_iterations else 0 - - pp_offset = 0 - # Build a list of (key, tensor) for iteration - all_items = list(target_state_dicts.items()) - - for pp in range(params.target_pp_size): - print(f" Saving PP rank {pp}") - tp_models = [{'model': OrderedDict()} for _ in range(params.target_tp_size)] - - for idx in range(pp_offset, len(all_items)): - key, tensor = all_items[idx] - - # Determine if this key belongs to this PP rank - try: - layer_num = int(re.findall(r'\d+', key)[0]) - if layer_num >= num_layers_per_pp_rank * (pp + 1): - break - new_key = key.replace( - str(layer_num), - str(layer_num - pp * num_layers_per_pp_rank), - 1, - ) - except (IndexError, ValueError): - new_key = key - - pp_offset += 1 - - if '_extra_state' in new_key: - for j in range(params.target_tp_size): - tp_models[j]['model'][new_key] = tensor - continue - - split_dim = get_split_dim(new_key) - if split_dim != -1: - slices = split_tensor_for_tp(params, new_key, split_dim, tensor) - for j in range(params.target_tp_size): - tp_models[j]['model'][new_key] = slices[j] - else: - for j in range(params.target_tp_size): - tp_models[j]['model'][new_key] = tensor - - for tp in range(params.target_tp_size): - dir_name = f"mp_rank_{tp:02d}" - if params.target_pp_size > 1: - dir_name += f"_{pp:03d}" - - model = finalize_checkpoint(sample_model, tp_models[tp], params, verbose=False) - - out_dir = os.path.join(save_dir, f'iter_{out_iteration:07d}', dir_name) - os.makedirs(out_dir, exist_ok=True) - model_file = os.path.join(out_dir, "model_optim_rng.pt") - torch.save(model, model_file) - print(f" Saved {model_file}") - - # Write iteration tracker - tracker_file = os.path.join(save_dir, 'latest_checkpointed_iteration.txt') - with open(tracker_file, 'w') as f: - f.write(str(out_iteration)) - - # --------------------------------------------------------------------------- # Key name helpers # --------------------------------------------------------------------------- @@ -1138,55 +690,9 @@ def sort_key(item): # --------------------------------------------------------------------------- -# Main -# --------------------------------------------------------------------------- - -# --------------------------------------------------------------------------- -# Format-aware load / save +# Format-aware save # --------------------------------------------------------------------------- -def _load_legacy_full(args): - """Load a legacy mp_rank_XX checkpoint and return a full (TP+PP gathered) - state dict plus a sample shard for metadata. - - Returns: - full_model (OrderedDict): globally-indexed, TP-combined state dict. - sample_model (dict): one source shard (for args/iteration/etc.). - iteration (int): source iteration. 
- """ - iteration = get_checkpoint_iteration(args.load_dir) - print(f" Iteration: {iteration}") - - model_dir = os.path.join(args.load_dir, f'iter_{iteration:07d}') - sub_models = os.listdir(model_dir) - sample_file = os.path.join(model_dir, sub_models[0], "model_optim_rng.pt") - sample_model = torch.load(sample_file, map_location='cpu', weights_only=False) - - input_tp_size = sample_model['args'].tensor_model_parallel_size - input_pp_size = sample_model['args'].pipeline_model_parallel_size - input_num_layers = sample_model['args'].num_layers - num_layers_per_pp_rank = input_num_layers // input_pp_size - - print(f" Source: TP={input_tp_size}, PP={input_pp_size}, " - f"num_layers={input_num_layers}") - - all_shards, sample_model = load_checkpoint_shards( - args.load_dir, iteration, input_tp_size, input_pp_size - ) - - print(" Combining TP shards into full tensors...") - combined_pp_shards = [] - for pp in range(input_pp_size): - combined = combine_tp_shards(all_shards[pp], args) - combined_pp_shards.append(combined) - - print(" Stitching PP shards into flat model...") - full_model = stitch_pp_shards(combined_pp_shards, num_layers_per_pp_rank) - print(f" Full model: {len(full_model)} parameters") - - return full_model, sample_model, iteration - - def _save_dist_full(target_state_dict, common_state, model_prefix, backend, args, iteration): """Save a fully-gathered state dict in dist-ckpt format. @@ -1261,35 +767,29 @@ def main(args): output_format = input_format print(f"\n Input format: {input_format}") print(f" Output format: {output_format}") - if output_format == FORMAT_LEGACY: - print(f" Target TP size: {args.target_tp_size}") - print(f" Target PP size: {args.target_pp_size}") + + if input_format not in DIST_FORMATS: + raise ValueError( + f"Unsupported input format: {input_format}. " + f"Only dist formats are supported: {DIST_FORMATS}." + ) + if output_format not in DIST_FORMATS: + raise ValueError( + f"Unsupported output format: {output_format}. " + f"Only dist formats are supported: {DIST_FORMATS}." + ) # 2. Load source checkpoint into a fully-gathered state dict print("\n[Step 1] Loading source checkpoint...") - sample_model = None - common_state = {} - model_prefix = 'model.' - dist_backend = FORMAT_TORCH_DIST - - if input_format == FORMAT_LEGACY: - full_model, sample_model, iteration = _load_legacy_full(args) - elif input_format in DIST_FORMATS: - full_model, common_state, model_prefix, dist_backend, iteration = ( - load_dist_checkpoint_full(args.load_dir) - ) - print(f" Source: dist backend={dist_backend}, prefix='{model_prefix}', " - f"iteration={iteration}, params={len(full_model)}") - else: - raise ValueError(f"Unsupported input format: {input_format}") + full_model, common_state, model_prefix, dist_backend, iteration = ( + load_dist_checkpoint_full(args.load_dir) + ) + print(f" Source: dist backend={dist_backend}, prefix='{model_prefix}', " + f"iteration={iteration}, params={len(full_model)}") # Args-level GPT compatibility whitelist: reject MoE, MLA, MTP, linear / # experimental attention, heterogeneous block specs, etc. See module header. - source_args = None - if sample_model is not None and 'args' in sample_model: - source_args = sample_model['args'] - elif common_state and 'args' in common_state: - source_args = common_state['args'] + source_args = common_state.get('args') if common_state else None validate_source_args_gpt_compatible(source_args, args.direction) # 3. Convert @@ -1306,24 +806,10 @@ def main(args): # 4. 
Save print(f"\n[Step 3] Saving to {args.save_dir}...") - if output_format == FORMAT_LEGACY: - if sample_model is None: - raise ValueError( - "Legacy output requires a legacy source checkpoint for metadata. " - "Use --output-format torch_dist when loading a dist checkpoint." - ) - sample_model['args'].num_layers = args.target_num_layers - save_checkpoint_shards( - target_state_dict, sample_model, args, args.save_dir, - iteration if iteration is not None else 0, - ) - elif output_format in DIST_FORMATS: - _save_dist_full( - target_state_dict, common_state, model_prefix, output_format, - args, iteration, - ) - else: - raise ValueError(f"Unsupported output format: {output_format}") + _save_dist_full( + target_state_dict, common_state, model_prefix, output_format, + args, iteration, + ) print("\n====CONVERSION COMPLETE====\n") @@ -1345,22 +831,15 @@ def main(args): help='Path to target checkpoint directory.') parser.add_argument('--hybrid-layer-pattern', type=str, required=True, help='Hybrid layer pattern string, e.g. "M*-M*-M*-M*-".') - parser.add_argument('--target-tp-size', type=int, default=1, - help='Target tensor parallel size (legacy output only; ' - 'dist formats are saved fully-replicated and ' - 'resharded at training load time).') - parser.add_argument('--target-pp-size', type=int, default=1, - help='Target pipeline parallel size (legacy output only).') parser.add_argument( '--input-format', type=str, default='auto', - choices=['auto', FORMAT_LEGACY, FORMAT_TORCH_DIST, 'fsdp_dtensor'], - help='Source checkpoint format. "auto" detects from metadata.json / ' - 'mp_rank_XX layout.', + choices=('auto',) + DIST_FORMATS, + help='Source checkpoint format. "auto" detects from metadata.json.', ) parser.add_argument( '--output-format', type=str, default='auto', - choices=['auto', FORMAT_LEGACY, FORMAT_TORCH_DIST, 'fsdp_dtensor'], + choices=('auto',) + DIST_FORMATS, help='Target checkpoint format. "auto" matches the input format. 
' 'Dist formats (torch_dist / fsdp_dtensor) transparently support ' 'TP+PP+FSDP training checkpoints.', From 366d5f84dadbe0e3b780784c93bab1999a7a06f5 Mon Sep 17 00:00:00 2001 From: guihong-nv Date: Mon, 27 Apr 2026 08:45:22 -0700 Subject: [PATCH 03/10] Auto-format the code Signed-off-by: guihong-nv --- .../checkpoint/test_gpt_mamba_conversion.py | 79 +++++++------------ .../test_gpt_mamba_conversion_parallelism.py | 71 +++++++---------- 2 files changed, 58 insertions(+), 92 deletions(-) diff --git a/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion.py b/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion.py index 2b3aca046be..c1f1396d4eb 100644 --- a/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion.py +++ b/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion.py @@ -25,8 +25,7 @@ # Add the tools/checkpoint directory to the path so we can import the module sys.path.insert( - 0, - os.path.join(os.path.dirname(__file__), '..', '..', '..', '..', 'tools', 'checkpoint'), + 0, os.path.join(os.path.dirname(__file__), '..', '..', '..', '..', 'tools', 'checkpoint') ) from gpt_mamba_conversion import ( @@ -44,11 +43,11 @@ validate_source_args_gpt_compatible, ) - # --------------------------------------------------------------------------- # Pattern parsing tests # --------------------------------------------------------------------------- + class TestPatternParsing: def test_simple_pattern(self): result = parse_hybrid_layer_pattern("M*-M*-") @@ -89,13 +88,12 @@ def test_invalid_symbol(self): # Layer index mapping tests # --------------------------------------------------------------------------- + class TestLayerIndexMapping: def test_gpt_to_mamba_basic(self): # Pattern: M*-M*- (2 attn at pos 1,4; 2 MLP at pos 2,5) layer_types = ['M', '*', '-', 'M', '*', '-'] - attn_map, mlp_map, ssm_indices = build_layer_index_mapping( - layer_types, 'gpt-to-mamba' - ) + attn_map, mlp_map, ssm_indices = build_layer_index_mapping(layer_types, 'gpt-to-mamba') # 2 GPT layers -> attn at [1,4], MLP at [2,5] assert attn_map == {0: 1, 1: 4} assert mlp_map == {0: 2, 1: 5} @@ -103,9 +101,7 @@ def test_gpt_to_mamba_basic(self): def test_mamba_to_gpt_basic(self): layer_types = ['M', '*', '-', 'M', '*', '-'] - attn_map, mlp_map, ssm_indices = build_layer_index_mapping( - layer_types, 'mamba-to-gpt' - ) + attn_map, mlp_map, ssm_indices = build_layer_index_mapping(layer_types, 'mamba-to-gpt') # attn at mamba layer 1 -> GPT layer 0, attn at 4 -> GPT layer 1 assert attn_map == {1: 0, 4: 1} assert mlp_map == {2: 0, 5: 1} @@ -113,9 +109,7 @@ def test_mamba_to_gpt_basic(self): def test_alternating_pattern(self): layer_types = ['*', '-', '*', '-', '*', '-'] - attn_map, mlp_map, ssm_indices = build_layer_index_mapping( - layer_types, 'gpt-to-mamba' - ) + attn_map, mlp_map, ssm_indices = build_layer_index_mapping(layer_types, 'gpt-to-mamba') assert attn_map == {0: 0, 1: 2, 2: 4} assert mlp_map == {0: 1, 1: 3, 2: 5} assert ssm_indices == [] @@ -135,6 +129,7 @@ def test_unknown_direction(self): # Key helper tests # --------------------------------------------------------------------------- + class TestKeyHelpers: def test_get_layer_num(self): assert get_layer_num_from_key('decoder.layers.5.mlp.linear_fc1.weight') == 5 @@ -173,6 +168,7 @@ def test_is_ssm_param(self): # SSM initialization tests # --------------------------------------------------------------------------- + class TestSSMInitialization: def test_shapes(self): d_model = 256 @@ -311,6 +307,7 @@ def test_different_layer_idx(self): # Synthetic 
GPT checkpoint builder # --------------------------------------------------------------------------- + def make_synthetic_gpt_checkpoint(num_layers, d_model, dtype=torch.float32): """Create a minimal synthetic GPT state dict for testing.""" state_dict = OrderedDict() @@ -351,14 +348,13 @@ def make_synthetic_gpt_checkpoint(num_layers, d_model, dtype=torch.float32): # Full conversion tests # --------------------------------------------------------------------------- + class TestGPTToMambaConversion: def setup_method(self): self.d_model = 64 self.num_gpt_layers = 2 self.pattern = "M*-M*-" # 6 total: 2 SSM, 2 attn, 2 MLP - self.gpt_state = make_synthetic_gpt_checkpoint( - self.num_gpt_layers, self.d_model - ) + self.gpt_state = make_synthetic_gpt_checkpoint(self.num_gpt_layers, self.d_model) self.args = argparse.Namespace( d_model=self.d_model, mamba_d_inner=self.d_model * 2, @@ -381,10 +377,7 @@ def test_shared_params_preserved(self): self.gpt_state['embedding.word_embeddings.weight'], ) # Output layer - assert torch.equal( - result['output_layer.weight'], - self.gpt_state['output_layer.weight'], - ) + assert torch.equal(result['output_layer.weight'], self.gpt_state['output_layer.weight']) def test_final_norm_renamed(self): layer_types = parse_hybrid_layer_pattern(self.pattern) @@ -393,8 +386,7 @@ def test_final_norm_renamed(self): assert 'decoder.final_norm.weight' in result assert 'decoder.final_layernorm.weight' not in result assert torch.equal( - result['decoder.final_norm.weight'], - self.gpt_state['decoder.final_layernorm.weight'], + result['decoder.final_norm.weight'], self.gpt_state['decoder.final_layernorm.weight'] ) def test_attention_params_mapped(self): @@ -484,8 +476,7 @@ def _make_mamba_state(self): if lt == 'M': # SSM params ssm = initialize_ssm_layer_params( - i, self.d_model, d_inner, d_state, n_groups, n_heads, - self.args.mamba2_head_dim, + i, self.d_model, d_inner, d_state, n_groups, n_heads, self.args.mamba2_head_dim ) state_dict.update(ssm) elif lt == '*': @@ -569,6 +560,7 @@ def test_gpt_layer_count(self): # Round-trip test: GPT -> Mamba -> GPT # --------------------------------------------------------------------------- + class TestRoundTrip: def test_gpt_mamba_gpt_preserves_weights(self): """Converting GPT -> Mamba -> GPT should preserve all attention & MLP weights.""" @@ -603,9 +595,9 @@ def test_gpt_mamba_gpt_preserves_weights(self): if 'final_layernorm' in key: continue assert key in recovered_gpt, f"Missing key after round-trip: {key}" - assert torch.equal(original_gpt[key], recovered_gpt[key]), ( - f"Weight mismatch after round-trip for {key}" - ) + assert torch.equal( + original_gpt[key], recovered_gpt[key] + ), f"Weight mismatch after round-trip for {key}" # Check final_layernorm was properly renamed back assert torch.equal( @@ -641,15 +633,14 @@ def test_round_trip_different_pattern(self): if 'final_layernorm' in key: continue assert key in recovered_gpt, f"Missing key: {key}" - assert torch.equal(original_gpt[key], recovered_gpt[key]), ( - f"Mismatch for {key}" - ) + assert torch.equal(original_gpt[key], recovered_gpt[key]), f"Mismatch for {key}" # --------------------------------------------------------------------------- # GPT compatibility whitelist tests # --------------------------------------------------------------------------- + class TestPatternWhitelist: """validate_pattern_gpt_compatible rejects hybrid patterns GPTModel can't express.""" @@ -730,34 +721,27 @@ def test_accepts_missing_optional_fields(self): def test_rejects_moe(self): with 
pytest.raises(ValueError, match="MoE"): - validate_source_args_gpt_compatible( - self._ok_args(num_moe_experts=8), 'gpt-to-mamba' - ) + validate_source_args_gpt_compatible(self._ok_args(num_moe_experts=8), 'gpt-to-mamba') def test_rejects_shared_expert(self): with pytest.raises(ValueError, match="shared expert"): validate_source_args_gpt_compatible( - self._ok_args(moe_shared_expert_intermediate_size=4096), - 'gpt-to-mamba', + self._ok_args(moe_shared_expert_intermediate_size=4096), 'gpt-to-mamba' ) def test_rejects_moe_layer_freq_list(self): with pytest.raises(ValueError, match="MoE layers"): validate_source_args_gpt_compatible( - self._ok_args(moe_layer_freq=[1, 0, 1, 0]), - 'gpt-to-mamba', + self._ok_args(moe_layer_freq=[1, 0, 1, 0]), 'gpt-to-mamba' ) def test_accepts_moe_layer_freq_1(self): - validate_source_args_gpt_compatible( - self._ok_args(moe_layer_freq=1), 'gpt-to-mamba' - ) + validate_source_args_gpt_compatible(self._ok_args(moe_layer_freq=1), 'gpt-to-mamba') def test_rejects_experimental_attention(self): with pytest.raises(ValueError, match="experimental attention"): validate_source_args_gpt_compatible( - self._ok_args(experimental_attention_variant='gated_delta_net'), - 'gpt-to-mamba', + self._ok_args(experimental_attention_variant='gated_delta_net'), 'gpt-to-mamba' ) def test_rejects_linear_attention(self): @@ -769,15 +753,13 @@ def test_rejects_linear_attention(self): def test_rejects_heterogeneous_block_specs(self): with pytest.raises(ValueError, match="heterogeneous"): validate_source_args_gpt_compatible( - self._ok_args(heterogeneous_block_specs=True), - 'mamba-to-gpt', + self._ok_args(heterogeneous_block_specs=True), 'mamba-to-gpt' ) def test_rejects_heterogeneous_config_path(self): with pytest.raises(ValueError, match="heterogeneous"): validate_source_args_gpt_compatible( - self._ok_args(heterogeneous_layers_config_path='/tmp/x.json'), - 'gpt-to-mamba', + self._ok_args(heterogeneous_layers_config_path='/tmp/x.json'), 'gpt-to-mamba' ) def test_rejects_mla(self): @@ -788,16 +770,13 @@ def test_rejects_mla(self): def test_rejects_mtp(self): with pytest.raises(ValueError, match="Multi-Token Prediction"): - validate_source_args_gpt_compatible( - self._ok_args(mtp_num_layers=2), 'gpt-to-mamba' - ) + validate_source_args_gpt_compatible(self._ok_args(mtp_num_layers=2), 'gpt-to-mamba') def test_reports_multiple_reasons(self): # Both MoE and MLA set: the error should surface both. 
with pytest.raises(ValueError) as exc: validate_source_args_gpt_compatible( - self._ok_args(num_moe_experts=8, multi_latent_attention=True), - 'gpt-to-mamba', + self._ok_args(num_moe_experts=8, multi_latent_attention=True), 'gpt-to-mamba' ) msg = str(exc.value) assert 'MoE' in msg diff --git a/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion_parallelism.py b/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion_parallelism.py index d3a082afd16..191a26cef55 100644 --- a/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion_parallelism.py +++ b/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion_parallelism.py @@ -34,11 +34,11 @@ from gpt_mamba_conversion import main as conversion_main - # --------------------------------------------------------------------------- # Synthetic-checkpoint helpers # --------------------------------------------------------------------------- + def make_checkpoint_args( num_layers=4, hidden_size=128, @@ -83,12 +83,8 @@ def make_gpt_state_dict(num_layers, hidden_size, vocab_size=1024, dtype=torch.fl hidden_size, hidden_size, dtype=dtype ) sd[p + 'pre_mlp_layernorm.weight'] = torch.randn(hidden_size, dtype=dtype) - sd[p + 'mlp.linear_fc1.weight'] = torch.randn( - 4 * hidden_size, hidden_size, dtype=dtype - ) - sd[p + 'mlp.linear_fc2.weight'] = torch.randn( - hidden_size, 4 * hidden_size, dtype=dtype - ) + sd[p + 'mlp.linear_fc1.weight'] = torch.randn(4 * hidden_size, hidden_size, dtype=dtype) + sd[p + 'mlp.linear_fc2.weight'] = torch.randn(hidden_size, 4 * hidden_size, dtype=dtype) sd['decoder.final_layernorm.weight'] = torch.randn(hidden_size, dtype=dtype) sd['output_layer.weight'] = torch.randn(vocab_size, hidden_size, dtype=dtype) @@ -99,8 +95,10 @@ def make_gpt_state_dict(num_layers, hidden_size, vocab_size=1024, dtype=torch.fl # Dist (torch_dist / fsdp_dtensor) fixture builders # --------------------------------------------------------------------------- -def _save_dist_checkpoint(root_dir, full_sd, ckpt_args, iteration=100, - prefix='model.', backend='torch_dist'): + +def _save_dist_checkpoint( + root_dir, full_sd, ckpt_args, iteration=100, prefix='model.', backend='torch_dist' +): """Write a full state dict as a single-rank DCP checkpoint. 
From the converter's POV, this is indistinguishable from a multi-rank @@ -122,16 +120,14 @@ def _save_dist_checkpoint(root_dir, full_sd, ckpt_args, iteration=100, 'checkpoint_version': 3.0, 'iteration': iteration, } - save_dist_checkpoint_full( - full_sd, common_state, iter_dir, - model_prefix=prefix, backend=backend, - ) + save_dist_checkpoint_full(full_sd, common_state, iter_dir, model_prefix=prefix, backend=backend) write_latest_iteration_marker(iter_dir, iteration) def _load_converted_dist(ckpt_dir): """Read a dist-format converted checkpoint back into a full state dict.""" from dist_checkpoint_io import load_dist_checkpoint_full + sd, common, prefix, backend, iteration = load_dist_checkpoint_full(ckpt_dir) return sd, common.get('args', None) @@ -140,6 +136,7 @@ def _load_converted_dist(ckpt_dir): # Core scenario runner # --------------------------------------------------------------------------- + def _run_scenario( label, source_format, @@ -164,8 +161,7 @@ def _run_scenario( gpt_sd = make_gpt_state_dict(num_layers, hidden_size) _save_dist_checkpoint( - src_gpt_dir, gpt_sd, ckpt_args, - prefix=source_prefix, backend=source_format, + src_gpt_dir, gpt_sd, ckpt_args, prefix=source_prefix, backend=source_format ) common_kwargs = dict( @@ -183,20 +179,18 @@ def _run_scenario( ) # --- GPT -> Mamba --- - conversion_main(argparse.Namespace( - direction='gpt-to-mamba', - load_dir=src_gpt_dir, - save_dir=mamba_dir, - **common_kwargs, - )) + conversion_main( + argparse.Namespace( + direction='gpt-to-mamba', load_dir=src_gpt_dir, save_dir=mamba_dir, **common_kwargs + ) + ) # --- Mamba -> GPT --- - conversion_main(argparse.Namespace( - direction='mamba-to-gpt', - load_dir=mamba_dir, - save_dir=dst_gpt_dir, - **common_kwargs, - )) + conversion_main( + argparse.Namespace( + direction='mamba-to-gpt', load_dir=mamba_dir, save_dir=dst_gpt_dir, **common_kwargs + ) + ) # --- Verify --- recovered_sd, _ = _load_converted_dist(dst_gpt_dir) @@ -216,14 +210,13 @@ def _run_scenario( if mismatches: for m in mismatches[:10]: print(f" FAIL: {m}") - raise AssertionError( - f"{label} failed with {len(mismatches)} weight mismatches" - ) + raise AssertionError(f"{label} failed with {len(mismatches)} weight mismatches") # SSM keys must be absent in the final GPT output. - assert not any('mixer.' in k for k in recovered_sd), \ - f"SSM keys leaked into final GPT output: " \ + assert not any('mixer.' in k for k in recovered_sd), ( + f"SSM keys leaked into final GPT output: " f"{[k for k in recovered_sd if 'mixer.' in k][:5]}" + ) print(f"PASSED: {label}") finally: @@ -234,6 +227,7 @@ def _run_scenario( # Test cases — one per (source backend, target backend, pattern) combo # --------------------------------------------------------------------------- + def test_torch_dist_roundtrip(): _run_scenario("torch_dist roundtrip", 'torch_dist', 'torch_dist') @@ -246,25 +240,18 @@ def test_fsdp_dtensor_prefix(): """fsdp_dtensor backend uses the 'model.module.' key prefix — verify we auto-detect and strip it correctly.""" _run_scenario( - "fsdp_dtensor prefix", 'fsdp_dtensor', 'fsdp_dtensor', - source_prefix='model.module.', + "fsdp_dtensor prefix", 'fsdp_dtensor', 'fsdp_dtensor', source_prefix='model.module.' 
) def test_torch_dist_alternating_pattern(): """Pure transformer pattern (no SSM) round-trips.""" - _run_scenario( - "torch_dist alternating", 'torch_dist', 'torch_dist', - pattern="*-*-*-*-", - ) + _run_scenario("torch_dist alternating", 'torch_dist', 'torch_dist', pattern="*-*-*-*-") def test_torch_dist_dense_ssm_pattern(): """Dense SSM pattern still round-trips on the attn/MLP layers.""" - _run_scenario( - "torch_dist dense SSM", 'torch_dist', 'torch_dist', - pattern="MM*-MM*-MM*-MM*-", - ) + _run_scenario("torch_dist dense SSM", 'torch_dist', 'torch_dist', pattern="MM*-MM*-MM*-MM*-") # --------------------------------------------------------------------------- From cfe49698b3e5558b44b72889527636c5e38651de Mon Sep 17 00:00:00 2001 From: guihong-nv Date: Tue, 28 Apr 2026 18:27:46 -0700 Subject: [PATCH 04/10] Add the EP support and test cases Signed-off-by: guihong-nv --- .../checkpoint/test_distributed_round_trip.py | 332 ++++++++++++++++++ .../checkpoint/test_gpt_mamba_conversion.py | 69 +++- .../test_gpt_mamba_conversion_parallelism.py | 124 ++++++- tools/checkpoint/gpt_mamba_conversion.py | 160 +++++---- 4 files changed, 594 insertions(+), 91 deletions(-) create mode 100644 tests/unit_tests/tools/checkpoint/test_distributed_round_trip.py diff --git a/tests/unit_tests/tools/checkpoint/test_distributed_round_trip.py b/tests/unit_tests/tools/checkpoint/test_distributed_round_trip.py new file mode 100644 index 00000000000..1b1f2c3080d --- /dev/null +++ b/tests/unit_tests/tools/checkpoint/test_distributed_round_trip.py @@ -0,0 +1,332 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +""" +Multi-rank distributed round-trip test for gpt_mamba_conversion. + +Each rank participates in a multi-rank DCP save of a synthetic GPT (or MoE +GPT) state dict; rank 0 then runs the converter and verifies the GPT->Mamba-> +GPT round-trip exactly. + +This test is meant to be launched under SLURM/srun (or torchrun) with +WORLD_SIZE = TP * PP * FSDP * EP. The (tp, pp, fsdp, ep) values are passed +as flags purely as labels — the converter sees only the DCP-stored +``global_shape`` per tensor and is agnostic to *which* dimension(s) the +source was sharded along. The test value is in: + + 1. Coordinating a real multi-rank ``dcp.save`` (cross-node networking, + collective barriers, shared-filesystem write). + 2. Verifying the converter loads a multi-rank-written checkpoint and + round-trips it through both directions. + +Usage (under srun): + export RANK=$SLURM_PROCID + export LOCAL_RANK=$SLURM_LOCALID + export WORLD_SIZE=$SLURM_NTASKS + export MASTER_ADDR=$(scontrol show hostnames $SLURM_NODELIST | head -1) + export MASTER_PORT=29500 + python test_distributed_round_trip.py \\ + --tp 2 --pp 2 --fsdp 2 --ep 2 --label TP2-PP2-FSDP2-EP2 \\ + --output-root /lustre/.../scratch/dist_test +""" + +import argparse +import copy +import os +import shutil +import sys +import time +from collections import OrderedDict +from types import SimpleNamespace + +import torch +import torch.distributed as dist + +# Make the conversion tool and helpers importable. 
+_THIS_DIR = os.path.dirname(os.path.abspath(__file__)) +_REPO_ROOT = os.path.join(_THIS_DIR, '..', '..', '..', '..') +sys.path.insert(0, os.path.join(_REPO_ROOT, 'tools', 'checkpoint')) +sys.path.insert(0, _THIS_DIR) + + +def _log(msg, rank, label=""): + prefix = f"[rank={rank}{(' ' + label) if label else ''}]" + print(f"{prefix} {msg}", flush=True) + + +def _build_state_dict(num_layers, hidden_size, num_moe_experts, vocab_size, dtype): + """Build a deterministic GPT(MoE) state dict identical on every rank. + + Determinism (via fixed seed) lets every rank produce the same tensors so + DCP's de-duplication writes a single coherent checkpoint. After load, we + re-derive the same tensors on rank 0 to verify round-trip. + """ + torch.manual_seed(0xC0FFEE) + sd = OrderedDict() + sd['embedding.word_embeddings.weight'] = torch.randn(vocab_size, hidden_size, dtype=dtype) + + for i in range(num_layers): + p = f'decoder.layers.{i}.' + sd[p + 'input_layernorm.weight'] = torch.randn(hidden_size, dtype=dtype) + sd[p + 'self_attention.linear_qkv.weight'] = torch.randn( + 3 * hidden_size, hidden_size, dtype=dtype + ) + sd[p + 'self_attention.linear_proj.weight'] = torch.randn( + hidden_size, hidden_size, dtype=dtype + ) + sd[p + 'pre_mlp_layernorm.weight'] = torch.randn(hidden_size, dtype=dtype) + + if num_moe_experts is None: + sd[p + 'mlp.linear_fc1.weight'] = torch.randn( + 4 * hidden_size, hidden_size, dtype=dtype + ) + sd[p + 'mlp.linear_fc2.weight'] = torch.randn( + hidden_size, 4 * hidden_size, dtype=dtype + ) + else: + sd[p + 'mlp.router.weight'] = torch.randn( + num_moe_experts, hidden_size, dtype=dtype + ) + for j in range(num_moe_experts): + ep = p + f'mlp.experts.local_experts.{j}.' + sd[ep + 'linear_fc1.weight'] = torch.randn( + 4 * hidden_size, hidden_size, dtype=dtype + ) + sd[ep + 'linear_fc2.weight'] = torch.randn( + hidden_size, 4 * hidden_size, dtype=dtype + ) + + sd['decoder.final_layernorm.weight'] = torch.randn(hidden_size, dtype=dtype) + sd['output_layer.weight'] = torch.randn(vocab_size, hidden_size, dtype=dtype) + return sd + + +def _build_ckpt_args(num_layers, hidden_size, num_moe_experts): + return SimpleNamespace( + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=4, + ffn_hidden_size=hidden_size * 4, + seq_length=256, + max_position_embeddings=256, + iteration=100, + consumed_train_samples=0, + consumed_valid_samples=0, + train_iters=1000, + train_samples=0, + tokenizer_type='GPT2BPETokenizer', + position_embedding_type='rope', + params_dtype=torch.float32, + fp16=False, + bf16=False, + num_moe_experts=num_moe_experts, + moe_shared_expert_intermediate_size=None, + moe_layer_freq=1, + ) + + +def _init_process_group(init_file): + """Initialize via file:// rendezvous on a shared filesystem. + + RANK / WORLD_SIZE come from env (set by srun). MASTER_ADDR / MASTER_PORT + are not needed — every rank just opens the same shared file. This avoids + the SLURM CLI tools (e.g. scontrol) which are not always present inside + container images. + + The init file MUST NOT pre-exist; rank 0 cleans up any stale leftover. + """ + if dist.is_initialized(): + return + rank = int(os.environ['RANK']) + world_size = int(os.environ['WORLD_SIZE']) + if rank == 0 and os.path.exists(init_file): + os.remove(init_file) + # Brief settle so other ranks don't race ahead of the cleanup. 
+ time.sleep(1) + dist.init_process_group( + backend='gloo', # CPU-only synthetic save; no NCCL needed + init_method=f'file://{init_file}', + rank=rank, + world_size=world_size, + ) + + +def _verify_round_trip(original, recovered, label): + missing, mismatch = [], [] + for k, v in original.items(): + if k not in recovered: + missing.append(k) + continue + if not torch.equal(v, recovered[k].to(v.dtype)): + mismatch.append(k) + + leaked_ssm = [k for k in recovered if 'mixer.' in k] + + if missing or mismatch or leaked_ssm: + print(f"FAIL [{label}]:") + for k in missing[:5]: + print(f" MISSING: {k}") + for k in mismatch[:5]: + print(f" MISMATCH: {k}") + for k in leaked_ssm[:5]: + print(f" SSM LEAKED: {k}") + raise AssertionError( + f"{label}: {len(missing)} missing, {len(mismatch)} mismatched, " + f"{len(leaked_ssm)} SSM keys leaked" + ) + + print(f"PASS [{label}]: {len(original)} keys round-tripped exactly") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--tp', type=int, default=1) + parser.add_argument('--pp', type=int, default=1) + parser.add_argument('--fsdp', type=int, default=1) + parser.add_argument('--ep', type=int, default=1) + parser.add_argument('--label', type=str, required=True) + parser.add_argument('--output-root', type=str, required=True, + help='Shared-filesystem path where this scenario writes its ' + 'checkpoints (must be visible from every rank).') + parser.add_argument('--num-layers', type=int, default=3) + parser.add_argument('--hidden-size', type=int, default=64) + parser.add_argument('--vocab-size', type=int, default=512) + parser.add_argument('--num-moe-experts', type=int, default=None, + help='If set, use the MoE GPT state-dict layout ' + '(mlp.router + mlp.experts.local_experts.*).') + parser.add_argument('--pattern', type=str, default=None, + help='Hybrid layer pattern. Defaults to "M*-M*-M*-" for ' + 'dense or "M*EM*EM*E" when --num-moe-experts is set.') + args = parser.parse_args() + + expected_world = args.tp * args.pp * args.fsdp * args.ep + pattern = args.pattern or ( + 'M*EM*EM*E' if args.num_moe_experts is not None else 'M*-M*-M*-' + ) + + # Shared init file lives on the same shared FS we use for checkpoints, so + # all ranks on all nodes see the same path. + init_file = os.path.join(args.output_root, f'_pg_init_{args.label}') + if int(os.environ.get('RANK', 0)) == 0: + os.makedirs(args.output_root, exist_ok=True) + _init_process_group(init_file) + rank = dist.get_rank() + world = dist.get_world_size() + if world != expected_world: + if rank == 0: + print(f"FAIL [{args.label}]: world={world} but tp*pp*fsdp*ep={expected_world}") + sys.exit(2) + + # Lazy imports after sys.path is set. + from gpt_mamba_conversion import main as conversion_main + from dist_checkpoint_io import ( + load_dist_checkpoint_full, + save_dist_checkpoint_full, + write_latest_iteration_marker, + ) + + if rank == 0: + _log( + f"label={args.label} tp={args.tp} pp={args.pp} fsdp={args.fsdp} " + f"ep={args.ep} world={world} pattern={pattern} " + f"num_moe_experts={args.num_moe_experts}", + rank, + args.label, + ) + + # Each rank builds the same full state dict — DCP de-duplicates writes + # across ranks via its writer planner. 
+ state_dict = _build_state_dict( + args.num_layers, args.hidden_size, args.num_moe_experts, + args.vocab_size, torch.float32, + ) + ckpt_args = _build_ckpt_args( + args.num_layers, args.hidden_size, args.num_moe_experts, + ) + + scratch = os.path.join(args.output_root, args.label) + src_dir = os.path.join(scratch, 'gpt_src') + mid_dir = os.path.join(scratch, 'mamba_mid') + dst_dir = os.path.join(scratch, 'gpt_dst') + iter_subdir = os.path.join(src_dir, 'iter_0000100') + + if rank == 0: + if os.path.isdir(scratch): + shutil.rmtree(scratch, ignore_errors=True) + os.makedirs(iter_subdir, exist_ok=True) + dist.barrier() + + # --- Multi-rank DCP write of the source GPT checkpoint --- + # dcp.save / dcp.load are both COLLECTIVE in the active process group, so + # every rank in this PG must participate in every save and every load. + # That includes the two conversion_main calls below, which internally call + # load_dist_checkpoint_full + save_dist_checkpoint_full once each. + # If a rank exits early its gloo socket closes and rank 0's reduce_scatter + # dies with "Connection closed by peer". + common_state = { + 'args': copy.deepcopy(ckpt_args), + 'checkpoint_version': 3.0, + 'iteration': 100, + } + save_dist_checkpoint_full( + state_dict, common_state, iter_subdir, + model_prefix='model.', backend='torch_dist', + ) + if rank == 0: + write_latest_iteration_marker(iter_subdir, 100) + dist.barrier() + + # --- Convert GPT -> hybrid -> GPT (every rank participates collectively). + common_kwargs = dict( + hybrid_layer_pattern=pattern, + d_model=args.hidden_size, + mamba_version=2, + mamba_d_state=16, + mamba2_n_groups=2, + mamba2_head_dim=32, + d_conv=4, + init_method_std=0.02, + reset_iterations=False, + input_format='auto', + output_format='torch_dist', + ) + + # Silence non-rank-0 stdout to keep logs readable; collective behavior + # is unaffected. + if rank != 0: + sys.stdout = open(os.devnull, 'w') + + t0 = time.time() + conversion_main(argparse.Namespace( + direction='gpt-to-mamba', + load_dir=src_dir, save_dir=mid_dir, + **common_kwargs, + )) + dist.barrier() + conversion_main(argparse.Namespace( + direction='mamba-to-gpt', + load_dir=mid_dir, save_dir=dst_dir, + **common_kwargs, + )) + dist.barrier() + dt = time.time() - t0 + + # Restore stdout for rank 0's verify message. 
+ if rank != 0: + sys.stdout = sys.__stdout__ + + # --- Load final + (rank 0 only) verify --- + recovered, _, _, _, _ = load_dist_checkpoint_full(dst_dir) + dist.barrier() + + if rank == 0: + _log(f"conversion time: {dt:.2f}s", rank, args.label) + _verify_round_trip(state_dict, recovered, args.label) + shutil.rmtree(scratch, ignore_errors=True) + if os.path.exists(init_file): + os.remove(init_file) + dist.barrier() + dist.destroy_process_group() + + +if __name__ == '__main__': + main() diff --git a/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion.py b/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion.py index c1f1396d4eb..ebecb197da6 100644 --- a/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion.py +++ b/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion.py @@ -659,9 +659,22 @@ def test_accepts_pure_ssm_pattern(self): layer_types = parse_hybrid_layer_pattern("MMMM") validate_pattern_gpt_compatible(layer_types, 'gpt-to-mamba') - def test_rejects_moe_symbol(self): - layer_types = parse_hybrid_layer_pattern("M*-E") - with pytest.raises(ValueError, match="not GPT-compatible"): + def test_accepts_moe_pattern(self): + # MoE layers ('E') round-trip through the converter as long as every + # MLP-bearing position is the same kind. + layer_types = parse_hybrid_layer_pattern("M*EM*EM*E") + validate_pattern_gpt_compatible(layer_types, 'gpt-to-mamba') + + def test_accepts_pure_attn_moe_pattern(self): + # No SSM, alternating attn/MoE — i.e. a Mixtral-like GPT. + layer_types = parse_hybrid_layer_pattern("*E*E*E") + validate_pattern_gpt_compatible(layer_types, 'mamba-to-gpt') + + def test_rejects_mixed_dense_and_moe(self): + # GPT layers must be uniform: '-' (dense) and 'E' (MoE) cannot both + # appear in the same pattern. + layer_types = parse_hybrid_layer_pattern("M*-M*E") + with pytest.raises(ValueError, match="uniform"): validate_pattern_gpt_compatible(layer_types, 'gpt-to-mamba') def test_rejects_gdn_symbol(self): @@ -674,13 +687,18 @@ def test_rejects_unequal_attn_mlp(self): with pytest.raises(ValueError, match="pair every attention"): validate_pattern_gpt_compatible(layer_types, 'gpt-to-mamba') + def test_unequal_attn_moe_also_rejected(self): + # Same uniformity check, but with MoE — 2 attn, 1 MoE. + layer_types = parse_hybrid_layer_pattern("M**E") + with pytest.raises(ValueError, match="pair every attention"): + validate_pattern_gpt_compatible(layer_types, 'gpt-to-mamba') + def test_error_lists_offending_symbols(self): - layer_types = parse_hybrid_layer_pattern("M*-EG") + # 'G' is still rejected; the error message should mention it. + layer_types = parse_hybrid_layer_pattern("M*-G") with pytest.raises(ValueError) as exc: validate_pattern_gpt_compatible(layer_types, 'mamba-to-gpt') - msg = str(exc.value) - assert 'E' in msg - assert 'G' in msg + assert 'G' in str(exc.value) class TestSourceArgsWhitelist: @@ -719,18 +737,24 @@ def test_accepts_missing_optional_fields(self): minimal = argparse.Namespace(num_moe_experts=None) validate_source_args_gpt_compatible(minimal, 'mamba-to-gpt') - def test_rejects_moe(self): - with pytest.raises(ValueError, match="MoE"): - validate_source_args_gpt_compatible(self._ok_args(num_moe_experts=8), 'gpt-to-mamba') + def test_accepts_moe_args(self): + # MoE keys live under decoder.layers..mlp.* and round-trip as-is. 
+ validate_source_args_gpt_compatible( + self._ok_args(num_moe_experts=8), 'gpt-to-mamba' + ) - def test_rejects_shared_expert(self): - with pytest.raises(ValueError, match="shared expert"): - validate_source_args_gpt_compatible( - self._ok_args(moe_shared_expert_intermediate_size=4096), 'gpt-to-mamba' - ) + def test_accepts_shared_expert_args(self): + # Shared experts also live under mlp.shared_experts.* and round-trip. + validate_source_args_gpt_compatible( + self._ok_args( + num_moe_experts=8, moe_shared_expert_intermediate_size=4096 + ), + 'gpt-to-mamba', + ) def test_rejects_moe_layer_freq_list(self): - with pytest.raises(ValueError, match="MoE layers"): + # Heterogeneous interleaving (some dense, some MoE) breaks GPT uniformity. + with pytest.raises(ValueError, match="interleaved"): validate_source_args_gpt_compatible( self._ok_args(moe_layer_freq=[1, 0, 1, 0]), 'gpt-to-mamba' ) @@ -738,6 +762,12 @@ def test_rejects_moe_layer_freq_list(self): def test_accepts_moe_layer_freq_1(self): validate_source_args_gpt_compatible(self._ok_args(moe_layer_freq=1), 'gpt-to-mamba') + def test_accepts_moe_layer_freq_all_ones_list(self): + # An all-1s list is uniform (every layer is the same kind) and accepted. + validate_source_args_gpt_compatible( + self._ok_args(moe_layer_freq=[1, 1, 1, 1]), 'gpt-to-mamba' + ) + def test_rejects_experimental_attention(self): with pytest.raises(ValueError, match="experimental attention"): validate_source_args_gpt_compatible( @@ -773,11 +803,12 @@ def test_rejects_mtp(self): validate_source_args_gpt_compatible(self._ok_args(mtp_num_layers=2), 'gpt-to-mamba') def test_reports_multiple_reasons(self): - # Both MoE and MLA set: the error should surface both. + # Both heterogeneous moe_layer_freq and MLA set — both should be reported. with pytest.raises(ValueError) as exc: validate_source_args_gpt_compatible( - self._ok_args(num_moe_experts=8, multi_latent_attention=True), 'gpt-to-mamba' + self._ok_args(moe_layer_freq=[1, 0], multi_latent_attention=True), + 'gpt-to-mamba', ) msg = str(exc.value) - assert 'MoE' in msg + assert 'interleaved' in msg assert 'Multi-Latent' in msg diff --git a/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion_parallelism.py b/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion_parallelism.py index 191a26cef55..1fb576a7bd4 100644 --- a/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion_parallelism.py +++ b/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion_parallelism.py @@ -46,8 +46,15 @@ def make_checkpoint_args( seq_length=256, max_position_embeddings=256, iteration=100, + num_moe_experts=None, + moe_shared_expert_intermediate_size=None, ): - """Build a minimal checkpoint 'args' namespace mirroring Megatron's.""" + """Build a minimal checkpoint 'args' namespace mirroring Megatron's. + + Set ``num_moe_experts`` to make the source/target a MoE GPT; the converter + will then pass the MoE config through unchanged so the round-trip stays + structurally consistent. 
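+
+    For example (mirroring the MoE scenarios below)::
+
+        args = make_checkpoint_args(num_layers=3, hidden_size=64, num_moe_experts=4)
+        # args.moe_layer_freq == 1, so every layer is MoE and the layout stays
+        # uniform and GPT-compatible.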
+ """ return SimpleNamespace( num_layers=num_layers, hidden_size=hidden_size, @@ -65,11 +72,30 @@ def make_checkpoint_args( params_dtype=torch.float32, fp16=False, bf16=False, + num_moe_experts=num_moe_experts, + moe_shared_expert_intermediate_size=moe_shared_expert_intermediate_size, + moe_layer_freq=1, ) -def make_gpt_state_dict(num_layers, hidden_size, vocab_size=1024, dtype=torch.float32): - """Create a minimal GPT model state dict with the standard Megatron keys.""" +def make_gpt_state_dict( + num_layers, + hidden_size, + vocab_size=1024, + dtype=torch.float32, + num_moe_experts=None, + shared_expert_size=None, +): + """Create a minimal GPT state dict with the standard Megatron keys. + + Dense MLP layout (default): ``mlp.linear_fc1`` / ``mlp.linear_fc2``. + MoE layout (``num_moe_experts`` set): ``mlp.router`` plus N experts under + ``mlp.experts.local_experts..linear_fc{1,2}``, optionally a shared + expert under ``mlp.shared_experts.linear_fc{1,2}``. These are exactly the + keys Megatron writes for non-grouped-GEMM MoE — they all live under + ``decoder.layers..mlp.*`` so the converter ferries them through with no + MoE-specific code. + """ sd = OrderedDict() sd['embedding.word_embeddings.weight'] = torch.randn(vocab_size, hidden_size, dtype=dtype) @@ -83,8 +109,30 @@ def make_gpt_state_dict(num_layers, hidden_size, vocab_size=1024, dtype=torch.fl hidden_size, hidden_size, dtype=dtype ) sd[p + 'pre_mlp_layernorm.weight'] = torch.randn(hidden_size, dtype=dtype) - sd[p + 'mlp.linear_fc1.weight'] = torch.randn(4 * hidden_size, hidden_size, dtype=dtype) - sd[p + 'mlp.linear_fc2.weight'] = torch.randn(hidden_size, 4 * hidden_size, dtype=dtype) + + if num_moe_experts is None: + # Dense MLP + sd[p + 'mlp.linear_fc1.weight'] = torch.randn(4 * hidden_size, hidden_size, dtype=dtype) + sd[p + 'mlp.linear_fc2.weight'] = torch.randn(hidden_size, 4 * hidden_size, dtype=dtype) + else: + # MoE: router + N experts (+ optional shared expert) + sd[p + 'mlp.router.weight'] = torch.randn(num_moe_experts, hidden_size, dtype=dtype) + for j in range(num_moe_experts): + ep = p + f'mlp.experts.local_experts.{j}.' + sd[ep + 'linear_fc1.weight'] = torch.randn( + 4 * hidden_size, hidden_size, dtype=dtype + ) + sd[ep + 'linear_fc2.weight'] = torch.randn( + hidden_size, 4 * hidden_size, dtype=dtype + ) + if shared_expert_size is not None: + sp = p + 'mlp.shared_experts.' 
+ sd[sp + 'linear_fc1.weight'] = torch.randn( + shared_expert_size, hidden_size, dtype=dtype + ) + sd[sp + 'linear_fc2.weight'] = torch.randn( + hidden_size, shared_expert_size, dtype=dtype + ) sd['decoder.final_layernorm.weight'] = torch.randn(hidden_size, dtype=dtype) sd['output_layer.weight'] = torch.randn(vocab_size, hidden_size, dtype=dtype) @@ -145,11 +193,15 @@ def _run_scenario( hidden_size=128, pattern="M*-M*-M*-M*-", source_prefix='model.', + num_moe_experts=None, + shared_expert_size=None, ): """Build a GPT source ckpt, convert GPT->Mamba->GPT, verify round-trip.""" print(f"\n=== {label} ===") print(f" source={source_format} (prefix='{source_prefix}')") print(f" target={target_format}") + if num_moe_experts is not None: + print(f" MoE: num_experts={num_moe_experts} shared={shared_expert_size}") tmpdir = tempfile.mkdtemp(prefix=f'gpt_mamba_{label.replace(" ", "_")}_') try: @@ -157,8 +209,17 @@ def _run_scenario( mamba_dir = os.path.join(tmpdir, 'mamba_mid') dst_gpt_dir = os.path.join(tmpdir, 'gpt_dst') - ckpt_args = make_checkpoint_args(num_layers=num_layers, hidden_size=hidden_size) - gpt_sd = make_gpt_state_dict(num_layers, hidden_size) + ckpt_args = make_checkpoint_args( + num_layers=num_layers, + hidden_size=hidden_size, + num_moe_experts=num_moe_experts, + moe_shared_expert_intermediate_size=shared_expert_size, + ) + gpt_sd = make_gpt_state_dict( + num_layers, hidden_size, + num_moe_experts=num_moe_experts, + shared_expert_size=shared_expert_size, + ) _save_dist_checkpoint( src_gpt_dir, gpt_sd, ckpt_args, prefix=source_prefix, backend=source_format @@ -254,6 +315,52 @@ def test_torch_dist_dense_ssm_pattern(): _run_scenario("torch_dist dense SSM", 'torch_dist', 'torch_dist', pattern="MM*-MM*-MM*-MM*-") +def test_torch_dist_moe_roundtrip(): + """MoE GPT (Mixtral-style) round-trips through an 'E'-bearing pattern. + + Source has num_moe_experts=4 and writes mlp.router / mlp.experts.* keys. + The hybrid pattern 'M*EM*EM*E' has 3 'E' positions, one per source layer. + The converter should ferry the router + every per-expert tensor through + verbatim — no MoE-specific code path involved. + """ + _run_scenario( + "torch_dist MoE roundtrip", + 'torch_dist', + 'torch_dist', + num_layers=3, + pattern="M*EM*EM*E", + num_moe_experts=4, + ) + + +def test_torch_dist_moe_with_shared_experts(): + """MoE + shared experts round-trip together (mlp.shared_experts.* keys).""" + _run_scenario( + "torch_dist MoE+shared", + 'torch_dist', + 'torch_dist', + num_layers=3, + hidden_size=64, + pattern="*E*E*E", + num_moe_experts=4, + shared_expert_size=64 * 2, + ) + + +def test_fsdp_dtensor_moe_roundtrip(): + """MoE round-trips through fsdp_dtensor (covers the 'model.module.' 
prefix + case combined with MoE keys).""" + _run_scenario( + "fsdp_dtensor MoE roundtrip", + 'fsdp_dtensor', + 'fsdp_dtensor', + num_layers=3, + pattern="M*EM*EM*E", + num_moe_experts=4, + source_prefix='model.module.', + ) + + # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- @@ -268,6 +375,9 @@ def test_torch_dist_dense_ssm_pattern(): test_fsdp_dtensor_prefix() test_torch_dist_alternating_pattern() test_torch_dist_dense_ssm_pattern() + test_torch_dist_moe_roundtrip() + test_torch_dist_moe_with_shared_experts() + test_fsdp_dtensor_moe_roundtrip() print("=" * 60) print("ALL PARALLELISM MATRIX TESTS PASSED") diff --git a/tools/checkpoint/gpt_mamba_conversion.py b/tools/checkpoint/gpt_mamba_conversion.py index 77780125ae8..a65a14a76e7 100644 --- a/tools/checkpoint/gpt_mamba_conversion.py +++ b/tools/checkpoint/gpt_mamba_conversion.py @@ -14,19 +14,32 @@ How the hybrid layer pattern maps GPT layers (gpt-to-mamba): - Each GPT layer contains both attention and MLP sub-layers. - - The target Mamba model's hybrid_layer_pattern specifies per-layer types: + - The target hybrid model's hybrid_layer_pattern specifies per-layer types: M = Mamba SSM layer * = Attention-only layer - - = MLP-only layer - G = GDN layer - E = MoE layer + - = MLP-only layer (dense) + E = MoE MLP-only layer (router + experts; supports EP) + G = GDN layer (not currently mapped) - GPT layer i's attention params map to the i-th '*' layer in the pattern. - - GPT layer i's MLP params map to the i-th '-' layer in the pattern. - - The number of '*' and '-' layers in the pattern must both equal the number - of GPT layers. + - GPT layer i's MLP/MoE params map to the i-th MLP-bearing position + ('-' or 'E') in the pattern. Dense ('-') and MoE ('E') cannot be mixed: + GPT layers are uniform. + - The number of '*' positions and MLP-bearing positions must each equal + the number of GPT layers. - Mamba SSM ('M') layers have no GPT equivalent and are initialized from scratch using standard Mamba initialization. +How MoE / Expert Parallelism (EP) works through the converter: + - GPTModel can run with MoE (Mixtral-style: every layer has a router and + N local experts). State-dict keys live under + `decoder.layers..mlp.{router,experts,shared_experts}.*`. + - Hybrid 'E' layers use the same key naming, so MoE tensors round-trip + verbatim — no expert collapsing, no router init, no per-expert work. + - EP-sharded checkpoints load through DCP transparently because each + tensor's `global_shape` is in the metadata, regardless of how many + EP / TP / PP / FSDP ranks wrote it. + - Use a pattern like 'M*EM*EM*E' to pair Mamba/Attn/MoE-MLP per stage. + What happens to SSM parameters: gpt-to-mamba: SSM layers (M) are initialized from scratch: - A_log: log(uniform(1, 16)) @@ -124,13 +137,16 @@ # Layer symbols GPTModel can emit or absorb: # '*' : standard self-attention layer (MHA / GQA / MQA) -# '-' : standard (optionally gated) MLP layer +# '-' : standard (optionally gated) dense MLP layer +# 'E' : MoE MLP layer. Both sides keep the keys under +# decoder.layers..mlp.{router,experts,shared_experts}.* so MoE +# tensors round-trip verbatim (see convert_gpt_to_mamba and +# convert_mamba_to_gpt — `is_mlp_param` already matches `mlp.*`). # SSM ('M') has no GPT equivalent and is initialized from scratch / # discarded (see convert_gpt_to_mamba / convert_mamba_to_gpt). 
-# Everything else is an architecture feature GPTModel does NOT -# produce: GDN ('G'), DS-attention ('D'), MoE ('E'). If the hybrid -# model contains any of those, we cannot faithfully translate. -GPT_COMPATIBLE_PATTERN_SYMBOLS = {'M', '*', '-'} +# 'G' (GDN) and 'D' (DS-attention) are not currently mapped — they would +# need separate key-naming work. Reject for now. +GPT_COMPATIBLE_PATTERN_SYMBOLS = {'M', '*', '-', 'E'} def parse_hybrid_layer_pattern(pattern): @@ -156,30 +172,38 @@ def parse_hybrid_layer_pattern(pattern): return layer_types +# Pattern symbols that pair to a GPT-side MLP block. Both dense ('-') and MoE +# ('E') keep their state-dict keys under `decoder.layers..mlp.*`, so they +# round-trip identically. The pattern uniformity check +# (validate_pattern_gpt_compatible) ensures '-' and 'E' don't appear together, +# which would mean GPT layers aren't uniform. +_MLP_BEARING_SYMBOLS = ('-', 'E') + + def build_layer_index_mapping(layer_types, direction): - """Build mapping between GPT layer indices and Mamba layer indices. + """Build mapping between GPT layer indices and hybrid-model layer indices. For gpt-to-mamba: - Returns (attn_map, mlp_map) where: - - attn_map[gpt_layer_i] = mamba_layer_j (j is the index of the i-th '*') - - mlp_map[gpt_layer_i] = mamba_layer_k (k is the index of the i-th '-') + Returns (attn_map, mlp_map, ssm_indices) where: + - attn_map[gpt_layer_i] = hybrid_layer_j (j is the index of the i-th '*') + - mlp_map[gpt_layer_i] = hybrid_layer_k (k is the index of the i-th + MLP-bearing position; either '-' or 'E') For mamba-to-gpt: - Returns (attn_map, mlp_map) where: - - attn_map[mamba_attn_idx] = gpt_layer_i - - mlp_map[mamba_mlp_idx] = gpt_layer_i + Returns (attn_map, mlp_map, ssm_indices) where: + - attn_map[hybrid_attn_idx] = gpt_layer_i + - mlp_map[hybrid_mlp_idx] = gpt_layer_i """ attn_indices = [i for i, t in enumerate(layer_types) if t == '*'] - mlp_indices = [i for i, t in enumerate(layer_types) if t == '-'] + mlp_indices = [i for i, t in enumerate(layer_types) if t in _MLP_BEARING_SYMBOLS] ssm_indices = [i for i, t in enumerate(layer_types) if t == 'M'] if direction == 'gpt-to-mamba': if len(attn_indices) != len(mlp_indices): raise ValueError( f"For gpt-to-mamba, the number of attention layers ({len(attn_indices)}) " - f"must equal the number of MLP layers ({len(mlp_indices)}) in the pattern." + f"must equal the number of MLP/MoE layers ({len(mlp_indices)}) in the pattern." ) - # attn_map: gpt_layer_i -> mamba_layer_j attn_map = {i: attn_indices[i] for i in range(len(attn_indices))} mlp_map = {i: mlp_indices[i] for i in range(len(mlp_indices))} return attn_map, mlp_map, ssm_indices @@ -188,9 +212,8 @@ def build_layer_index_mapping(layer_types, direction): if len(attn_indices) != len(mlp_indices): raise ValueError( f"For mamba-to-gpt, the number of attention layers ({len(attn_indices)}) " - f"must equal the number of MLP layers ({len(mlp_indices)}) in the pattern." + f"must equal the number of MLP/MoE layers ({len(mlp_indices)}) in the pattern." 
) - # attn_map: mamba_layer_idx -> gpt_layer_i attn_map = {attn_indices[i]: i for i in range(len(attn_indices))} mlp_map = {mlp_indices[i]: i for i in range(len(mlp_indices))} return attn_map, mlp_map, ssm_indices @@ -203,31 +226,34 @@ def build_layer_index_mapping(layer_types, direction): # GPT compatibility whitelist # --------------------------------------------------------------------------- # -# GPTModel is a strict homogeneous transformer: every decoder layer is a -# (self-attention + MLP) pair with standard linear_qkv / linear_fc1 / -# linear_fc2 state-dict naming. The hybrid <-> GPT converter is only safe -# when the hybrid side agrees with that shape. The helpers below act as a -# safeguard: they reject any hybrid layout or source-args combination that -# would silently produce a broken checkpoint. +# GPTModel is a *uniform* transformer: every decoder layer is the same kind. +# It can run with dense MLP or MoE MLP — both keep keys under +# decoder.layers..mlp.* — so MoE checkpoints round-trip through the +# converter as long as both sides share the same kind on every layer. +# The helpers below reject any hybrid layout or source-args combination that +# violates uniformity (and would therefore silently produce a corrupt target). # # Pattern-level rules (checked on the parsed hybrid_layer_pattern): -# * only 'M', '*', '-' are allowed (no 'G' GDN, no 'D' DS-attention, -# no 'E' MoE) -# * '*' count must equal '-' count (one-to-one GPT attention<->MLP pairing) +# * only 'M', '*', '-', 'E' are allowed (no 'G' GDN, no 'D' DS-attention) +# * MLP-bearing symbols must be uniform: '-' and 'E' cannot both appear +# (that would imply GPT has both dense and MoE layers — heterogeneous) +# * '*' count must equal '-'+'E' count (one-to-one GPT attn<->MLP pairing) # # Args-level rules (checked against the training args stored in the source -# checkpoint): reject anything that would make GPTModel's layer shape -# inapplicable to either side: -# * num_moe_experts (MoE routing, different keys) -# * moe_shared_expert_intermediate_size (shared-expert branch) -# * moe_layer_freq (MoE-every-N layer insertion) +# checkpoint): reject anything that makes GPT layers heterogeneous OR uses +# attention variants the converter doesn't currently key-translate: +# * moe_layer_freq != 1 (interleaved dense/MoE layers) # * experimental_attention_variant (gated_delta_net, dsa, ...) -# * linear_attention_freq (linear-attention layers) -# * heterogeneous_block_specs / heterogeneous_layers_config_path +# * linear_attention_freq (interleaved linear-attention) +# * heterogeneous_block_specs / heterogeneous_layers_config_* # (Nemotron-NAS per-layer specs) -# * multi_latent_attention (MLA: different QKV layout) +# * multi_latent_attention (MLA: different QKV key layout) # * mtp_num_layers (Multi-Token Prediction head) # +# Notably NOT rejected (they round-trip via mlp.* / self_attention.* keys): +# * num_moe_experts (MoE on every layer) +# * moe_shared_expert_intermediate_size (shared experts on every layer) +# # All rejected configurations raise ValueError early, before any tensors # are touched. @@ -235,26 +261,19 @@ def build_layer_index_mapping(layer_types, direction): # Predicates are applied with getattr(args, field, None); missing fields # are treated as "absent" and pass. 
_GPT_COMPAT_REJECT_FIELDS = ( - ( - 'num_moe_experts', - lambda v: v is not None and v > 0, - 'MoE routing (num_moe_experts)', - ), - ( - 'moe_shared_expert_intermediate_size', - lambda v: v is not None and v > 0, - 'MoE shared experts (moe_shared_expert_intermediate_size)', - ), ( 'moe_layer_freq', - # moe_layer_freq is None or 1 for non-MoE models; a list or a value - # > 1 means interleaved MoE layers. + # moe_layer_freq is None or 1 when every layer is the same kind (all + # dense or all MoE). A value > 1 or a list with mixed entries means + # GPT has interleaved dense/MoE layers — heterogeneous, can't pair + # one-to-one with a uniform hybrid pattern. lambda v: ( v is not None and not (isinstance(v, int) and v == 1) and not (isinstance(v, str) and v.strip() in ('', '1')) + and not (isinstance(v, (list, tuple)) and all(x == 1 for x in v)) ), - 'interleaved MoE layers (moe_layer_freq)', + 'interleaved dense/MoE layers (moe_layer_freq)', ), ( 'experimental_attention_variant', @@ -302,29 +321,40 @@ def validate_pattern_gpt_compatible(layer_types, direction): direction: 'gpt-to-mamba' or 'mamba-to-gpt' (for error messages). Rules: - * Allowed symbols are M / * / - only. G, D, E are rejected because - they denote layer kinds (GDN, DS-attention, MoE) that GPTModel - cannot emit or absorb. - * The number of '*' and '-' layers must match: every GPT layer pairs - one attention with one MLP. + * Allowed symbols: 'M', '*', '-', 'E'. 'G' (GDN) and 'D' (DS-attention) + are not currently key-translated. + * MLP-bearing symbols must be uniform: '-' (dense) and 'E' (MoE) cannot + both appear, because that would imply GPT has both dense and MoE + layers — the GPT side must be uniform. + * The number of attention positions must equal the number of + MLP-bearing positions: every GPT layer pairs one attention with one + MLP/MoE. """ bad = sorted({c for c in layer_types if c not in GPT_COMPATIBLE_PATTERN_SYMBOLS}) if bad: raise ValueError( f"Hybrid layer pattern contains symbols {bad} that are not " f"GPT-compatible (allowed: {sorted(GPT_COMPATIBLE_PATTERN_SYMBOLS)}). " - f"GPTModel only supports standard attention ('*') and MLP ('-') " - f"layers; 'G' (GDN), 'D' (DS-attention), and 'E' (MoE) have no " - f"GPT equivalent and cannot be {direction}-converted." + f"'G' (GDN) and 'D' (DS-attention) are not currently key-translated " + f"and cannot be {direction}-converted." + ) + + mlp_kinds_present = {t for t in layer_types if t in _MLP_BEARING_SYMBOLS} + if len(mlp_kinds_present) > 1: + raise ValueError( + f"Hybrid layer pattern mixes '-' (dense MLP) and 'E' (MoE) " + f"positions. GPTModel layers must be uniform — either all GPT " + f"layers are dense MLP, or all are MoE. Use only one of '-' or " + f"'E' in the pattern." ) n_attn = sum(1 for t in layer_types if t == '*') - n_mlp = sum(1 for t in layer_types if t == '-') + n_mlp = sum(1 for t in layer_types if t in _MLP_BEARING_SYMBOLS) if n_attn != n_mlp: raise ValueError( f"GPT-compatible hybrid patterns must pair every attention layer " - f"('*') with one MLP layer ('-'). Got {n_attn} '*' and {n_mlp} '-' " - f"in the pattern." + f"('*') with one MLP/MoE layer ('-' or 'E'). Got {n_attn} '*' " + f"and {n_mlp} MLP-bearing layers in the pattern." 
) From c0a9682fe41e3472f912012e4dde2ca4f7af2a6e Mon Sep 17 00:00:00 2001 From: guihong-nv Date: Tue, 28 Apr 2026 18:30:15 -0700 Subject: [PATCH 05/10] Auto format the code Signed-off-by: guihong-nv --- .../checkpoint/test_distributed_round_trip.py | 84 +++++++++---------- .../checkpoint/test_gpt_mamba_conversion.py | 11 +-- .../test_gpt_mamba_conversion_parallelism.py | 3 +- 3 files changed, 45 insertions(+), 53 deletions(-) diff --git a/tests/unit_tests/tools/checkpoint/test_distributed_round_trip.py b/tests/unit_tests/tools/checkpoint/test_distributed_round_trip.py index 1b1f2c3080d..c1a6146bc8f 100644 --- a/tests/unit_tests/tools/checkpoint/test_distributed_round_trip.py +++ b/tests/unit_tests/tools/checkpoint/test_distributed_round_trip.py @@ -76,16 +76,10 @@ def _build_state_dict(num_layers, hidden_size, num_moe_experts, vocab_size, dtyp sd[p + 'pre_mlp_layernorm.weight'] = torch.randn(hidden_size, dtype=dtype) if num_moe_experts is None: - sd[p + 'mlp.linear_fc1.weight'] = torch.randn( - 4 * hidden_size, hidden_size, dtype=dtype - ) - sd[p + 'mlp.linear_fc2.weight'] = torch.randn( - hidden_size, 4 * hidden_size, dtype=dtype - ) + sd[p + 'mlp.linear_fc1.weight'] = torch.randn(4 * hidden_size, hidden_size, dtype=dtype) + sd[p + 'mlp.linear_fc2.weight'] = torch.randn(hidden_size, 4 * hidden_size, dtype=dtype) else: - sd[p + 'mlp.router.weight'] = torch.randn( - num_moe_experts, hidden_size, dtype=dtype - ) + sd[p + 'mlp.router.weight'] = torch.randn(num_moe_experts, hidden_size, dtype=dtype) for j in range(num_moe_experts): ep = p + f'mlp.experts.local_experts.{j}.' sd[ep + 'linear_fc1.weight'] = torch.randn( @@ -184,24 +178,34 @@ def main(): parser.add_argument('--fsdp', type=int, default=1) parser.add_argument('--ep', type=int, default=1) parser.add_argument('--label', type=str, required=True) - parser.add_argument('--output-root', type=str, required=True, - help='Shared-filesystem path where this scenario writes its ' - 'checkpoints (must be visible from every rank).') + parser.add_argument( + '--output-root', + type=str, + required=True, + help='Shared-filesystem path where this scenario writes its ' + 'checkpoints (must be visible from every rank).', + ) parser.add_argument('--num-layers', type=int, default=3) parser.add_argument('--hidden-size', type=int, default=64) parser.add_argument('--vocab-size', type=int, default=512) - parser.add_argument('--num-moe-experts', type=int, default=None, - help='If set, use the MoE GPT state-dict layout ' - '(mlp.router + mlp.experts.local_experts.*).') - parser.add_argument('--pattern', type=str, default=None, - help='Hybrid layer pattern. Defaults to "M*-M*-M*-" for ' - 'dense or "M*EM*EM*E" when --num-moe-experts is set.') + parser.add_argument( + '--num-moe-experts', + type=int, + default=None, + help='If set, use the MoE GPT state-dict layout ' + '(mlp.router + mlp.experts.local_experts.*).', + ) + parser.add_argument( + '--pattern', + type=str, + default=None, + help='Hybrid layer pattern. Defaults to "M*-M*-M*-" for ' + 'dense or "M*EM*EM*E" when --num-moe-experts is set.', + ) args = parser.parse_args() expected_world = args.tp * args.pp * args.fsdp * args.ep - pattern = args.pattern or ( - 'M*EM*EM*E' if args.num_moe_experts is not None else 'M*-M*-M*-' - ) + pattern = args.pattern or ('M*EM*EM*E' if args.num_moe_experts is not None else 'M*-M*-M*-') # Shared init file lives on the same shared FS we use for checkpoints, so # all ranks on all nodes see the same path. 
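As a concrete illustration (hypothetical paths; the flag names are exactly the ones added to the parser above, and the launched world size must equal tp * pp * fsdp * ep), one MoE scenario could be driven under torchrun roughly as:

    torchrun --nproc_per_node 4 test_distributed_round_trip.py \
        --tp 2 --pp 2 --fsdp 1 --ep 1 \
        --label tp2_pp2_moe \
        --output-root /shared/scratch/ckpt_convert \
        --num-moe-experts 4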
@@ -217,12 +221,12 @@ def main(): sys.exit(2) # Lazy imports after sys.path is set. - from gpt_mamba_conversion import main as conversion_main from dist_checkpoint_io import ( load_dist_checkpoint_full, save_dist_checkpoint_full, write_latest_iteration_marker, ) + from gpt_mamba_conversion import main as conversion_main if rank == 0: _log( @@ -236,12 +240,9 @@ def main(): # Each rank builds the same full state dict — DCP de-duplicates writes # across ranks via its writer planner. state_dict = _build_state_dict( - args.num_layers, args.hidden_size, args.num_moe_experts, - args.vocab_size, torch.float32, - ) - ckpt_args = _build_ckpt_args( - args.num_layers, args.hidden_size, args.num_moe_experts, + args.num_layers, args.hidden_size, args.num_moe_experts, args.vocab_size, torch.float32 ) + ckpt_args = _build_ckpt_args(args.num_layers, args.hidden_size, args.num_moe_experts) scratch = os.path.join(args.output_root, args.label) src_dir = os.path.join(scratch, 'gpt_src') @@ -262,14 +263,9 @@ def main(): # load_dist_checkpoint_full + save_dist_checkpoint_full once each. # If a rank exits early its gloo socket closes and rank 0's reduce_scatter # dies with "Connection closed by peer". - common_state = { - 'args': copy.deepcopy(ckpt_args), - 'checkpoint_version': 3.0, - 'iteration': 100, - } + common_state = {'args': copy.deepcopy(ckpt_args), 'checkpoint_version': 3.0, 'iteration': 100} save_dist_checkpoint_full( - state_dict, common_state, iter_subdir, - model_prefix='model.', backend='torch_dist', + state_dict, common_state, iter_subdir, model_prefix='model.', backend='torch_dist' ) if rank == 0: write_latest_iteration_marker(iter_subdir, 100) @@ -296,17 +292,17 @@ def main(): sys.stdout = open(os.devnull, 'w') t0 = time.time() - conversion_main(argparse.Namespace( - direction='gpt-to-mamba', - load_dir=src_dir, save_dir=mid_dir, - **common_kwargs, - )) + conversion_main( + argparse.Namespace( + direction='gpt-to-mamba', load_dir=src_dir, save_dir=mid_dir, **common_kwargs + ) + ) dist.barrier() - conversion_main(argparse.Namespace( - direction='mamba-to-gpt', - load_dir=mid_dir, save_dir=dst_dir, - **common_kwargs, - )) + conversion_main( + argparse.Namespace( + direction='mamba-to-gpt', load_dir=mid_dir, save_dir=dst_dir, **common_kwargs + ) + ) dist.barrier() dt = time.time() - t0 diff --git a/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion.py b/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion.py index ebecb197da6..3d10483d383 100644 --- a/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion.py +++ b/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion.py @@ -739,16 +739,12 @@ def test_accepts_missing_optional_fields(self): def test_accepts_moe_args(self): # MoE keys live under decoder.layers..mlp.* and round-trip as-is. - validate_source_args_gpt_compatible( - self._ok_args(num_moe_experts=8), 'gpt-to-mamba' - ) + validate_source_args_gpt_compatible(self._ok_args(num_moe_experts=8), 'gpt-to-mamba') def test_accepts_shared_expert_args(self): # Shared experts also live under mlp.shared_experts.* and round-trip. validate_source_args_gpt_compatible( - self._ok_args( - num_moe_experts=8, moe_shared_expert_intermediate_size=4096 - ), + self._ok_args(num_moe_experts=8, moe_shared_expert_intermediate_size=4096), 'gpt-to-mamba', ) @@ -806,8 +802,7 @@ def test_reports_multiple_reasons(self): # Both heterogeneous moe_layer_freq and MLA set — both should be reported. 
with pytest.raises(ValueError) as exc: validate_source_args_gpt_compatible( - self._ok_args(moe_layer_freq=[1, 0], multi_latent_attention=True), - 'gpt-to-mamba', + self._ok_args(moe_layer_freq=[1, 0], multi_latent_attention=True), 'gpt-to-mamba' ) msg = str(exc.value) assert 'interleaved' in msg diff --git a/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion_parallelism.py b/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion_parallelism.py index 1fb576a7bd4..3b67a9582bd 100644 --- a/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion_parallelism.py +++ b/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion_parallelism.py @@ -216,7 +216,8 @@ def _run_scenario( moe_shared_expert_intermediate_size=shared_expert_size, ) gpt_sd = make_gpt_state_dict( - num_layers, hidden_size, + num_layers, + hidden_size, num_moe_experts=num_moe_experts, shared_expert_size=shared_expert_size, ) From 62786624ff554004176d8ea7d83a7d4558da1cb8 Mon Sep 17 00:00:00 2001 From: guihong-nv Date: Tue, 28 Apr 2026 18:49:00 -0700 Subject: [PATCH 06/10] Add missing copyright header Signed-off-by: guihong-nv --- hybrid_conversion.py | 398 ------------------ tests/unit_tests/tools/__init__.py | 1 + tests/unit_tests/tools/checkpoint/__init__.py | 1 + tools/checkpoint/gpt_mamba_conversion.py | 2 +- 4 files changed, 3 insertions(+), 399 deletions(-) delete mode 100644 hybrid_conversion.py diff --git a/hybrid_conversion.py b/hybrid_conversion.py deleted file mode 100644 index da384e31ced..00000000000 --- a/hybrid_conversion.py +++ /dev/null @@ -1,398 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. - -# Note (rwaleffe): This is a temporary file for hybrid mamba-transformer model checkpoint conversion. -# This functionality should be integrated with the megatron core checkpoint loader/saver. 
- - -import copy -import os -import re -import shutil -from collections import OrderedDict - -import torch -import argparse - - -tp_split_dim = { - 'word_embeddings.weight': 0, - 'norm.weight': -1, - 'final_norm.weight': -1, - 'output_layer.weight': 0, - # mamba1/2 - 'A_log': 0, - 'D': 0, - 'dt_bias': 0, - 'in_proj.weight': 0, - 'conv1d.weight': 0, - 'conv1d.bias': 0, - 'x_proj.weight': 1, - 'dt_proj.weight': 0, - 'dt_proj.bias': 0, - 'out_proj.weight': 1, - 'mixer.norm.weight': 0, - # mlp - 'linear_fc1.layer_norm_weight': -1, - 'linear_fc1.weight': 0, - 'linear_fc2.weight': 1, - # attention - 'self_attention.linear_proj.weight': 1, - 'self_attention.linear_qkv.layer_norm_weight': -1, - 'self_attention.linear_qkv.weight': 0, -} - - -def get_split_dim(tensor_name): - # norm.weight will match tensor_name of mixer.norm.weight and norm.weight, need to distinguish - if 'norm.weight' in tensor_name: - if 'mixer.norm.weight' in tensor_name: - return tp_split_dim['mixer.norm.weight'] - else: - return tp_split_dim['norm.weight'] - - for key in tp_split_dim.keys(): - if key in tensor_name: - return tp_split_dim[key] - raise Exception("Unknown tensor name {}".format(tensor_name)) - - -def combine_tp_tensors(params, key, dim, tensors): - tp_size = len(tensors) - - if 'mixer.in_proj.weight' in key and params.mamba_version == 1: - xs = []; zs = [] - for tensor in tensors: - x, z = torch.split(tensor, [params.mamba_d_inner//tp_size, - params.mamba_d_inner//tp_size], dim=dim) - xs.append(x); zs.append(z) - return torch.cat([torch.cat(xs, dim=dim), torch.cat(zs, dim=dim)], dim=dim) - - elif 'mixer.in_proj.weight' in key and params.mamba_version == 2: - xs = []; zs = []; Bs = []; Cs = []; dts = [] - for tensor in tensors: - x, z, B, C, dt = torch.split(tensor, [params.mamba_d_inner // tp_size, - params.mamba_d_inner // tp_size, - (params.mamba2_n_groups // tp_size) * args.mamba_d_state, - (params.mamba2_n_groups // tp_size) * args.mamba_d_state, - params.mamba2_n_heads // tp_size], dim=dim) - xs.append(x); zs.append(z); Bs.append(B); Cs.append(C); dts.append(dt) - - for ii in range(len(Bs)): - Bs[ii] = torch.reshape(Bs[ii], (-1, params.mamba_d_state, Bs[ii].shape[-1])) - Cs[ii] = torch.reshape(Cs[ii], (-1, params.mamba_d_state, Cs[ii].shape[-1])) - B = torch.cat(Bs, dim=dim); C = torch.cat(Cs, dim=dim) - x = torch.cat(xs, dim=dim); z = torch.cat(zs, dim=dim); dt = torch.cat(dts, dim=dim) - - return torch.cat([x, z, B.flatten(0, 1), C.flatten(0, 1), dt], dim=dim) - - elif 'mixer.conv1d' in key and params.mamba_version == 2: - xs = []; Bs = []; Cs = [] - for tensor in tensors: - x, B, C = torch.split(tensor, [params.mamba_d_inner//tp_size, - (params.mamba2_n_groups // tp_size) * params.mamba_d_state, - (params.mamba2_n_groups // tp_size) * params.mamba_d_state], dim=dim) - xs.append(x); Bs.append(B); Cs.append(C) - - for ii in range(len(Bs)): - if 'weight' in key: - Bs[ii] = torch.reshape(Bs[ii], (-1, params.mamba_d_state, Bs[ii].shape[-2], Bs[ii].shape[-1])) - Cs[ii] = torch.reshape(Cs[ii], (-1, params.mamba_d_state, Cs[ii].shape[-2], Cs[ii].shape[-1])) - elif 'bias' in key: - Bs[ii] = torch.reshape(Bs[ii], (-1, params.mamba_d_state)) - Cs[ii] = torch.reshape(Cs[ii], (-1, params.mamba_d_state)) - else: - raise Exception("Unknown key") - B = torch.cat(Bs, dim=dim); C = torch.cat(Cs, dim=dim) - x = torch.cat(xs, dim=dim) - - return torch.cat([x, B.flatten(0, 1), C.flatten(0, 1)], dim=dim) - - else: - return torch.cat(tensors, dim=dim) - - -def split_tensor_for_tp(params, key, dim, tensor): - tp_size = 
params.target_tp_size - tensor_sliced = [] - - if 'mixer.in_proj.weight' in key and params.mamba_version == 1: - x, z = torch.split(tensor, [params.mamba_d_inner, params.mamba_d_inner], dim=dim) - x_sliced = torch.chunk(x, tp_size, dim=dim) - z_sliced = torch.chunk(z, tp_size, dim=dim) - for (x, z) in zip(x_sliced, z_sliced): - tensor_sliced.append(torch.cat((x, z), dim=dim)) - - elif 'mixer.in_proj.weight' in key and params.mamba_version == 2: - x, z, B, C, dt = torch.split(tensor, [params.mamba_d_inner, params.mamba_d_inner, - params.mamba2_n_groups * params.mamba_d_state, - params.mamba2_n_groups * params.mamba_d_state, - params.mamba2_n_heads], dim=dim) - B = torch.reshape(B, (-1, params.mamba_d_state, B.shape[-1])) - C = torch.reshape(C, (-1, params.mamba_d_state, C.shape[-1])) - - B_sliced = torch.chunk(B, tp_size, dim=dim) - C_sliced = torch.chunk(C, tp_size, dim=dim) - x_sliced = torch.chunk(x, tp_size, dim=dim) - z_sliced = torch.chunk(z, tp_size, dim=dim) - dt_sliced = torch.chunk(dt, tp_size, dim=dim) - - tensor_sliced = [] - for (x, z, B, C, dt) in zip(x_sliced, z_sliced, B_sliced, C_sliced, dt_sliced): - tensor_sliced.append(torch.cat((x, z, B.flatten(0, 1), C.flatten(0, 1), dt), dim=dim)) - - elif 'mixer.conv1d' in key and params.mamba_version == 2: - x, B, C = torch.split(tensor, [params.mamba_d_inner, - params.mamba2_n_groups * params.mamba_d_state, - params.mamba2_n_groups * params.mamba_d_state], dim=dim) - if 'weight' in key: - B = torch.reshape(B, (-1, params.mamba_d_state, B.shape[-2], B.shape[-1])) - C = torch.reshape(C, (-1, params.mamba_d_state, C.shape[-2], C.shape[-1])) - elif 'bias' in key: - B = torch.reshape(B, (-1, params.mamba_d_state)) - C = torch.reshape(C, (-1, params.mamba_d_state)) - else: - raise Exception("Unknown key") - - B_sliced = torch.chunk(B, tp_size, dim=dim) - C_sliced = torch.chunk(C, tp_size, dim=dim) - x_sliced = torch.chunk(x, tp_size, dim=dim) - - tensor_sliced = [] - for (x, B, C) in zip(x_sliced, B_sliced, C_sliced): - tensor_sliced.append(torch.cat((x, B.flatten(0, 1), C.flatten(0, 1)), dim=dim)) - - else: - tensor_sliced = torch.chunk(tensor, tp_size, dim=dim) - - return tensor_sliced - - -def finalize_checkpoint(sample_model, model, params, verbose=False): - # make sure the rest of the checkpoint is how we want it from the original (i.e., other than the 'model') - reset_iterations = params.reset_iterations - - # checkpoint 'args' - model['args'] = copy.deepcopy(sample_model['args']) - model['args'].tensor_model_parallel_size = params.target_tp_size - model['args'].pipeline_model_parallel_size = params.target_pp_size - if reset_iterations: - model['args'].iteration = 0 - model['args'].consumed_valid_samples = 0 - model['args'].consumed_train_samples = 0 - model['args'].train_iters = 0 - model['args'].train_samples = 0 - - # checkpoint 'checkpoint_version' - model['checkpoint_version'] = copy.deepcopy(sample_model['checkpoint_version']) - - # checkpoint 'iteration' - model['iteration'] = copy.deepcopy(sample_model['iteration']) - if reset_iterations: - model['iteration'] = 0 - - # checkpoint 'optimizer' - # ignore - - # checkpoint 'opt_param_scheduler' - if 'opt_param_scheduler' in sample_model.keys(): - model['opt_param_scheduler'] = copy.deepcopy(sample_model['opt_param_scheduler']) - - # checkpoint 'rng_state' - model['rng_state'] = copy.deepcopy(sample_model['rng_state']) - - # report on argument difference - if verbose: - original_args = sample_model['args'].__dict__ - final_args = model['args'].__dict__ - for key in 
original_args: - if key in final_args: - if final_args[key] != original_args[key]: - print("KEY MISMATCH: {}".format(key)) - print("\toriginal: {}\n\tfinal: {}".format(original_args[key], final_args[key])) - else: - print("KEY MISSING from final: {}, value {}".format(key, original_args[key])) - print("") - for key in final_args: - if key not in original_args: - print("KEY ADDED to final: {}, value {}".format(key, final_args[key])) - - return model - - -def main(args): - print("\n====RUNNING CHECKPOINT CONVERSION====\n") - - args.mamba_d_inner = args.d_model * 2 - args.mamba2_n_heads = args.mamba_d_inner // args.mamba2_head_dim - - # get the latest iteration - tracker_filename = os.path.join(args.load_dir, 'latest_checkpointed_iteration.txt') - with open(tracker_filename, 'r') as f: - metastring = f.read().strip() - try: - iteration = int(metastring) - except ValueError: - raise Exception("Invalid iteration found in latest_checkpointed_iteration.txt!") - out_iteration = iteration if not args.reset_iterations else 0 - - # get model directory and model parallel ranks - input_model_dir = os.path.join(args.load_dir, 'iter_{:07d}'.format(iteration)) - input_sub_models = os.listdir(input_model_dir) - # input_sub_models = sorted(input_sub_models, key=lambda x: int(re.search(r'\d+', x).group())) - - # load one of the model parallel ranks to get arguments - sample_model_file = os.path.join(input_model_dir, input_sub_models[0], "model_optim_rng.pt") - sample_model = torch.load(sample_model_file) - print(f"Sample model {sample_model_file} is loaded.\n") - - # input tensor and pipeline parallel size - input_tp_rank = sample_model['args'].tensor_model_parallel_size - input_pp_rank = sample_model['args'].pipeline_model_parallel_size - num_layers_per_pipeline_rank = sample_model['args'].num_layers // input_pp_rank - - # construct full model - full_model = OrderedDict() - for pp in range(input_pp_rank): - print("[INFO] Processing input pipeline rank {}".format(pp)) - tp_models = [] - for tp in range(input_tp_rank): - dir_name = "mp_rank_{:02d}".format(tp) - if input_pp_rank > 1: - dir_name += "_{:03d}".format(pp) - model_file = os.path.join(input_model_dir, dir_name, "model_optim_rng.pt") - - tp_models.append(torch.load(model_file)) - print(f"Model {model_file} is loaded.") - - if input_tp_rank > 1: - combined_tp_model = OrderedDict() - for ii, (key, original_tensor) in enumerate(tp_models[0]['model'].items()): - if "_extra_state" in key: - combined_tp_model[key] = original_tensor - continue - - split_dim = get_split_dim(key) - original_shape = list(original_tensor.shape) - combined_shape = copy.deepcopy(original_shape) - combined_shape[split_dim] *= input_tp_rank - # print("{}, {}, {}".format(ii, key, split_dim)) - - if split_dim != -1: - # slice together model - # print("\tshape mismatch: original {}, combined {}".format(original_shape, combined_shape)) - combined_tensor = combine_tp_tensors(args, key, split_dim, - [tp_models[jj]['model'][key].cpu() for jj in range(input_tp_rank)]) - combined_tp_model[key] = combined_tensor - else: - # copy model - combined_tp_model[key] = original_tensor - else: - combined_tp_model = tp_models[0]['model'] - # print("Combined tp model: {}".format(combined_tp_model.keys())) - - for ii, (key, original_tensor) in enumerate(combined_tp_model.items()): - try: - layer_num = int(re.findall(r'\d+', key)[0]) - new_key = key.replace(str(layer_num), str(layer_num + pp*num_layers_per_pipeline_rank), 1) - except: - new_key = key - full_model[new_key] = original_tensor - # 
print("Combined model: {}".format(full_model.keys())) - print("\n[INFO] Loaded combined model\n") - - # sort by layer - # full_model_sorted = dict(sorted(people.items(), key=lambda item: item[1])) - - # create new split model - pp_offset = 0 - num_layers_per_pipeline_rank = sample_model['args'].num_layers // args.target_pp_size - - for pp in range(args.target_pp_size): - print("[INFO] Processing output pipeline rank {}".format(pp)) - tp_models = [] - for ii in range(args.target_tp_size): - tp_models.append({'model': OrderedDict()}) - - for ii, (key, original_tensor) in enumerate(full_model.items()): - try: - layer_num = int(re.findall(r'\d+', key)[0]) - if layer_num >= num_layers_per_pipeline_rank * (pp+1): - break - new_key = key.replace(str(layer_num), str(layer_num - (pp * num_layers_per_pipeline_rank)), 1) - except Exception: - new_key = key - - if ii < pp_offset: - continue - else: - pp_offset += 1 - - if "_extra_state" in new_key: - # copy - for jj in range(args.target_tp_size): - tp_models[jj]['model'][new_key] = original_tensor - continue - - split_dim = get_split_dim(new_key) - original_shape = list(original_tensor.shape) - v0 = original_shape[split_dim] - split_size = v0 // args.target_tp_size - split_shape = copy.deepcopy(original_shape) - split_shape[split_dim] = split_size - # print("{}, {}, {}".format(ii, new_key, split_dim)) - - if split_dim != -1: - # split model - # print("\tshape mismatch: original {}, combined {}".format(original_shape, split_shape)) - tensor_sliced = split_tensor_for_tp(args, new_key, split_dim, original_tensor) - for jj in range(args.target_tp_size): - tp_models[jj]['model'][new_key] = tensor_sliced[jj] - else: - # copy model - for jj in range(args.target_tp_size): - tp_models[jj]['model'][new_key] = original_tensor - # print(tp_models[0]['model'].keys()) - - for tp in range(args.target_tp_size): - dir_name = "mp_rank_{:02d}".format(tp) - if args.target_pp_size > 1: - dir_name += "_{:03d}".format(pp) - - model = finalize_checkpoint(sample_model, tp_models[tp], args, verbose=False) - - save_dir = os.path.join(args.save_dir, 'iter_{:07d}'.format(out_iteration), dir_name) - os.makedirs(save_dir, exist_ok=True) - model_file = os.path.join(save_dir, "model_optim_rng.pt") - torch.save(model, model_file) - print(f"Model {model_file} is saved.") - - # shutil.copyfile(tracker_filename, os.path.join(args.save_dir, 'latest_checkpointed_iteration.txt')) - tracker_filename = os.path.join(args.save_dir, 'latest_checkpointed_iteration.txt') - with open(tracker_filename, 'w') as f: - f.write(str(out_iteration)) - - -if __name__ == "__main__": - # example run command: - # python hybrid_conversion.py - # --load-dir mamba2-840m-test/checkpoints/ - # --save-dir mamba2-840m-test-conversion/checkpoints/ - # --target-pp-size 1 - # --target-tp-size 1 - - parser = argparse.ArgumentParser() - parser.add_argument('--load-dir', type=str) - parser.add_argument('--save-dir', type=str) - parser.add_argument('--target-tp-size', type=int, default=1) - parser.add_argument('--target-pp-size', type=int, default=1) - parser.add_argument('--reset-iterations', action='store_true') - - parser.add_argument('--d-model', type=int, default=4096) - parser.add_argument('--mamba-version', type=int, default=2) - parser.add_argument('--mamba-d-state', type=int, default=128) - parser.add_argument('--mamba2-n-groups', type=int, default=8) - parser.add_argument('--mamba2-head-dim', type=int, default=64) - - args = parser.parse_args() - - main(args) diff --git a/tests/unit_tests/tools/__init__.py 
b/tests/unit_tests/tools/__init__.py index e69de29bb2d..b5dff7b5663 100644 --- a/tests/unit_tests/tools/__init__.py +++ b/tests/unit_tests/tools/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. diff --git a/tests/unit_tests/tools/checkpoint/__init__.py b/tests/unit_tests/tools/checkpoint/__init__.py index e69de29bb2d..b5dff7b5663 100644 --- a/tests/unit_tests/tools/checkpoint/__init__.py +++ b/tests/unit_tests/tools/checkpoint/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. diff --git a/tools/checkpoint/gpt_mamba_conversion.py b/tools/checkpoint/gpt_mamba_conversion.py index a65a14a76e7..951ce1abec6 100644 --- a/tools/checkpoint/gpt_mamba_conversion.py +++ b/tools/checkpoint/gpt_mamba_conversion.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. """ GPT <-> Mamba Checkpoint Conversion Tool From 763bdb6659bc3c38246d7983fe891d8ea6ddf370 Mon Sep 17 00:00:00 2001 From: guihong-nv Date: Wed, 29 Apr 2026 17:32:36 -0700 Subject: [PATCH 07/10] Add the multi-node unit test bypass logic Signed-off-by: guihong-nv --- .../test_gpt_mamba_conversion_parallelism.py | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion_parallelism.py b/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion_parallelism.py index 3b67a9582bd..fd201a0d215 100644 --- a/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion_parallelism.py +++ b/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion_parallelism.py @@ -24,7 +24,9 @@ from collections import OrderedDict from types import SimpleNamespace +import pytest import torch +import torch.distributed as dist # Make the conversion tool importable under both `python ` and `pytest`. _THIS_DIR = os.path.dirname(os.path.abspath(__file__)) @@ -34,6 +36,34 @@ from gpt_mamba_conversion import main as conversion_main + +# These scenarios are SYNTHETIC and single-rank by design: each one writes a +# tiny synthetic DCP checkpoint and round-trips it through the converter on +# rank 0. They share the default torch.distributed process group with whatever +# harness launched pytest. When that default PG is multi-rank (e.g. Megatron's +# CI/CD initialises NCCL with world_size>1 before pytest collection), the +# dcp.save/dcp.load collectives stall: each rank has its own +# tempfile.mkdtemp() path and its own torch.randn() tensors, so the metadata +# coordination across ranks never converges and the NCCL watchdog kills the +# job after 10 minutes (see ProcessGroupNCCL ALLGATHER timeout). +# +# Multi-rank coverage lives in test_distributed_round_trip.py, which uses a +# fresh single-rank gloo subgroup per scenario via SLURM/srun in +# run_slurm_ckpt_convert_tests.sh. Skip these synthetic tests whenever the +# default PG is already multi-rank. +@pytest.fixture(autouse=True) +def _skip_when_multi_rank_pg(): + if ( + dist.is_available() + and dist.is_initialized() + and dist.get_world_size() > 1 + ): + pytest.skip( + "Synthetic single-rank tests skipped under a multi-rank default " + "process group; multi-rank coverage is in " + "test_distributed_round_trip.py." 
+ ) + # --------------------------------------------------------------------------- # Synthetic-checkpoint helpers # --------------------------------------------------------------------------- From 2f9c46d38d6b090785c2af278b8532f4e816ab89 Mon Sep 17 00:00:00 2001 From: guihong-nv Date: Wed, 29 Apr 2026 18:04:51 -0700 Subject: [PATCH 08/10] Format the code to meet the linting requirement Signed-off-by: guihong-nv --- .../checkpoint/test_gpt_mamba_conversion_parallelism.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion_parallelism.py b/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion_parallelism.py index fd201a0d215..4c908330610 100644 --- a/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion_parallelism.py +++ b/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion_parallelism.py @@ -53,17 +53,14 @@ # default PG is already multi-rank. @pytest.fixture(autouse=True) def _skip_when_multi_rank_pg(): - if ( - dist.is_available() - and dist.is_initialized() - and dist.get_world_size() > 1 - ): + if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1: pytest.skip( "Synthetic single-rank tests skipped under a multi-rank default " "process group; multi-rank coverage is in " "test_distributed_round_trip.py." ) + # --------------------------------------------------------------------------- # Synthetic-checkpoint helpers # --------------------------------------------------------------------------- From e2e635701ac9b3b1c3e113cb0fe3d1585190ab2e Mon Sep 17 00:00:00 2001 From: guihong-nv Date: Thu, 30 Apr 2026 14:04:34 -0700 Subject: [PATCH 09/10] Rename mamba_model into hybrid_model --- .../checkpoint/test_distributed_round_trip.py | 12 +- ...rsion.py => test_gpt_hybrid_conversion.py} | 118 +++++++++--------- ...test_gpt_hybrid_conversion_parallelism.py} | 24 ++-- ...conversion.py => gpt_hybrid_conversion.py} | 98 +++++++-------- 4 files changed, 126 insertions(+), 126 deletions(-) rename tests/unit_tests/tools/checkpoint/{test_gpt_mamba_conversion.py => test_gpt_hybrid_conversion.py} (91%) rename tests/unit_tests/tools/checkpoint/{test_gpt_mamba_conversion_parallelism.py => test_gpt_hybrid_conversion_parallelism.py} (94%) rename tools/checkpoint/{gpt_mamba_conversion.py => gpt_hybrid_conversion.py} (92%) diff --git a/tests/unit_tests/tools/checkpoint/test_distributed_round_trip.py b/tests/unit_tests/tools/checkpoint/test_distributed_round_trip.py index c1a6146bc8f..4e080f8f0b3 100644 --- a/tests/unit_tests/tools/checkpoint/test_distributed_round_trip.py +++ b/tests/unit_tests/tools/checkpoint/test_distributed_round_trip.py @@ -1,10 +1,10 @@ # Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. """ -Multi-rank distributed round-trip test for gpt_mamba_conversion. +Multi-rank distributed round-trip test for gpt_hybrid_conversion. Each rank participates in a multi-rank DCP save of a synthetic GPT (or MoE -GPT) state dict; rank 0 then runs the converter and verifies the GPT->Mamba-> +GPT) state dict; rank 0 then runs the converter and verifies the GPT->Hybrid-> GPT round-trip exactly. 
This test is meant to be launched under SLURM/srun (or torchrun) with @@ -226,7 +226,7 @@ def main(): save_dist_checkpoint_full, write_latest_iteration_marker, ) - from gpt_mamba_conversion import main as conversion_main + from gpt_hybrid_conversion import main as conversion_main if rank == 0: _log( @@ -246,7 +246,7 @@ def main(): scratch = os.path.join(args.output_root, args.label) src_dir = os.path.join(scratch, 'gpt_src') - mid_dir = os.path.join(scratch, 'mamba_mid') + mid_dir = os.path.join(scratch, 'hybrid_mid') dst_dir = os.path.join(scratch, 'gpt_dst') iter_subdir = os.path.join(src_dir, 'iter_0000100') @@ -294,13 +294,13 @@ def main(): t0 = time.time() conversion_main( argparse.Namespace( - direction='gpt-to-mamba', load_dir=src_dir, save_dir=mid_dir, **common_kwargs + direction='gpt-to-hybrid', load_dir=src_dir, save_dir=mid_dir, **common_kwargs ) ) dist.barrier() conversion_main( argparse.Namespace( - direction='mamba-to-gpt', load_dir=mid_dir, save_dir=dst_dir, **common_kwargs + direction='hybrid-to-gpt', load_dir=mid_dir, save_dir=dst_dir, **common_kwargs ) ) dist.barrier() diff --git a/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion.py b/tests/unit_tests/tools/checkpoint/test_gpt_hybrid_conversion.py similarity index 91% rename from tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion.py rename to tests/unit_tests/tools/checkpoint/test_gpt_hybrid_conversion.py index 3d10483d383..79ec3e5ab36 100644 --- a/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion.py +++ b/tests/unit_tests/tools/checkpoint/test_gpt_hybrid_conversion.py @@ -1,15 +1,15 @@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. """ -Unit tests for the GPT <-> Mamba checkpoint conversion tool. +Unit tests for the GPT <-> Hybrid checkpoint conversion tool. 
These tests validate: - Hybrid layer pattern parsing -- Layer index mapping (GPT <-> Mamba) +- Layer index mapping (GPT <-> Hybrid) - State dict key renaming (final_layernorm <-> final_norm) - Shared parameter copying (embeddings, output_layer) - SSM parameter initialization shapes and dtypes -- Round-trip conversion: GPT -> Mamba -> GPT preserves attention and MLP weights +- Round-trip conversion: GPT -> Hybrid -> GPT preserves attention and MLP weights - TP split dimension lookup """ @@ -28,10 +28,10 @@ 0, os.path.join(os.path.dirname(__file__), '..', '..', '..', '..', 'tools', 'checkpoint') ) -from gpt_mamba_conversion import ( +from gpt_hybrid_conversion import ( build_layer_index_mapping, - convert_gpt_to_mamba, - convert_mamba_to_gpt, + convert_gpt_to_hybrid, + convert_hybrid_to_gpt, get_layer_num_from_key, initialize_ssm_layer_params, is_attention_param, @@ -90,18 +90,18 @@ def test_invalid_symbol(self): class TestLayerIndexMapping: - def test_gpt_to_mamba_basic(self): + def test_gpt_to_hybrid_basic(self): # Pattern: M*-M*- (2 attn at pos 1,4; 2 MLP at pos 2,5) layer_types = ['M', '*', '-', 'M', '*', '-'] - attn_map, mlp_map, ssm_indices = build_layer_index_mapping(layer_types, 'gpt-to-mamba') + attn_map, mlp_map, ssm_indices = build_layer_index_mapping(layer_types, 'gpt-to-hybrid') # 2 GPT layers -> attn at [1,4], MLP at [2,5] assert attn_map == {0: 1, 1: 4} assert mlp_map == {0: 2, 1: 5} assert ssm_indices == [0, 3] - def test_mamba_to_gpt_basic(self): + def test_hybrid_to_gpt_basic(self): layer_types = ['M', '*', '-', 'M', '*', '-'] - attn_map, mlp_map, ssm_indices = build_layer_index_mapping(layer_types, 'mamba-to-gpt') + attn_map, mlp_map, ssm_indices = build_layer_index_mapping(layer_types, 'hybrid-to-gpt') # attn at mamba layer 1 -> GPT layer 0, attn at 4 -> GPT layer 1 assert attn_map == {1: 0, 4: 1} assert mlp_map == {2: 0, 5: 1} @@ -109,7 +109,7 @@ def test_mamba_to_gpt_basic(self): def test_alternating_pattern(self): layer_types = ['*', '-', '*', '-', '*', '-'] - attn_map, mlp_map, ssm_indices = build_layer_index_mapping(layer_types, 'gpt-to-mamba') + attn_map, mlp_map, ssm_indices = build_layer_index_mapping(layer_types, 'gpt-to-hybrid') assert attn_map == {0: 0, 1: 2, 2: 4} assert mlp_map == {0: 1, 1: 3, 2: 5} assert ssm_indices == [] @@ -118,7 +118,7 @@ def test_mismatched_attn_mlp_count(self): # 2 attn but 1 MLP -> should raise layer_types = ['*', '*', '-', 'M'] with pytest.raises(ValueError, match="must equal"): - build_layer_index_mapping(layer_types, 'gpt-to-mamba') + build_layer_index_mapping(layer_types, 'gpt-to-hybrid') def test_unknown_direction(self): with pytest.raises(ValueError, match="Unknown direction"): @@ -349,7 +349,7 @@ def make_synthetic_gpt_checkpoint(num_layers, d_model, dtype=torch.float32): # --------------------------------------------------------------------------- -class TestGPTToMambaConversion: +class TestGPTToHybridConversion: def setup_method(self): self.d_model = 64 self.num_gpt_layers = 2 @@ -369,7 +369,7 @@ def setup_method(self): def test_shared_params_preserved(self): layer_types = parse_hybrid_layer_pattern(self.pattern) - result = convert_gpt_to_mamba(self.gpt_state, layer_types, self.args) + result = convert_gpt_to_hybrid(self.gpt_state, layer_types, self.args) # Embeddings should be identical assert torch.equal( @@ -381,7 +381,7 @@ def test_shared_params_preserved(self): def test_final_norm_renamed(self): layer_types = parse_hybrid_layer_pattern(self.pattern) - result = convert_gpt_to_mamba(self.gpt_state, layer_types, self.args) 
+ result = convert_gpt_to_hybrid(self.gpt_state, layer_types, self.args) assert 'decoder.final_norm.weight' in result assert 'decoder.final_layernorm.weight' not in result @@ -391,7 +391,7 @@ def test_final_norm_renamed(self): def test_attention_params_mapped(self): layer_types = parse_hybrid_layer_pattern(self.pattern) - result = convert_gpt_to_mamba(self.gpt_state, layer_types, self.args) + result = convert_gpt_to_hybrid(self.gpt_state, layer_types, self.args) # GPT layer 0 attn -> Mamba layer 1 (first '*' in M*-M*-) assert torch.equal( @@ -406,7 +406,7 @@ def test_attention_params_mapped(self): def test_mlp_params_mapped(self): layer_types = parse_hybrid_layer_pattern(self.pattern) - result = convert_gpt_to_mamba(self.gpt_state, layer_types, self.args) + result = convert_gpt_to_hybrid(self.gpt_state, layer_types, self.args) # GPT layer 0 MLP -> Mamba layer 2 (first '-') assert torch.equal( @@ -421,7 +421,7 @@ def test_mlp_params_mapped(self): def test_ssm_layers_initialized(self): layer_types = parse_hybrid_layer_pattern(self.pattern) - result = convert_gpt_to_mamba(self.gpt_state, layer_types, self.args) + result = convert_gpt_to_hybrid(self.gpt_state, layer_types, self.args) # SSM layers at index 0 and 3 for idx in [0, 3]: @@ -439,10 +439,10 @@ def test_layer_count_mismatch_raises(self): # Pattern with 3 attn but only 2 GPT layers layer_types = parse_hybrid_layer_pattern("M*-*-*-") with pytest.raises(ValueError, match="layers"): - convert_gpt_to_mamba(self.gpt_state, layer_types, self.args) + convert_gpt_to_hybrid(self.gpt_state, layer_types, self.args) -class TestMambaToGPTConversion: +class TestHybridToGPTConversion: def setup_method(self): self.d_model = 64 self.pattern = "M*-M*-" @@ -501,7 +501,7 @@ def _make_mamba_state(self): def test_final_norm_renamed_back(self): mamba_state = self._make_mamba_state() layer_types = parse_hybrid_layer_pattern(self.pattern) - result = convert_mamba_to_gpt(mamba_state, layer_types, self.args) + result = convert_hybrid_to_gpt(mamba_state, layer_types, self.args) assert 'decoder.final_layernorm.weight' in result assert 'decoder.final_norm.weight' not in result @@ -509,7 +509,7 @@ def test_final_norm_renamed_back(self): def test_ssm_params_discarded(self): mamba_state = self._make_mamba_state() layer_types = parse_hybrid_layer_pattern(self.pattern) - result = convert_mamba_to_gpt(mamba_state, layer_types, self.args) + result = convert_hybrid_to_gpt(mamba_state, layer_types, self.args) # No SSM keys should remain for key in result: @@ -518,7 +518,7 @@ def test_ssm_params_discarded(self): def test_attention_params_mapped(self): mamba_state = self._make_mamba_state() layer_types = parse_hybrid_layer_pattern(self.pattern) - result = convert_mamba_to_gpt(mamba_state, layer_types, self.args) + result = convert_hybrid_to_gpt(mamba_state, layer_types, self.args) # Mamba layer 1 (first *) -> GPT layer 0 assert torch.equal( @@ -534,7 +534,7 @@ def test_attention_params_mapped(self): def test_mlp_params_mapped(self): mamba_state = self._make_mamba_state() layer_types = parse_hybrid_layer_pattern(self.pattern) - result = convert_mamba_to_gpt(mamba_state, layer_types, self.args) + result = convert_hybrid_to_gpt(mamba_state, layer_types, self.args) # Mamba layer 2 (first -) -> GPT layer 0 assert torch.equal( @@ -545,7 +545,7 @@ def test_mlp_params_mapped(self): def test_gpt_layer_count(self): mamba_state = self._make_mamba_state() layer_types = parse_hybrid_layer_pattern(self.pattern) - result = convert_mamba_to_gpt(mamba_state, layer_types, self.args) + result = 
convert_hybrid_to_gpt(mamba_state, layer_types, self.args) # Should have 2 GPT layers (layers 0 and 1) layer_nums = set() @@ -557,13 +557,13 @@ def test_gpt_layer_count(self): # --------------------------------------------------------------------------- -# Round-trip test: GPT -> Mamba -> GPT +# Round-trip test: GPT -> Hybrid -> GPT; using Mamba as the example below # --------------------------------------------------------------------------- class TestRoundTrip: - def test_gpt_mamba_gpt_preserves_weights(self): - """Converting GPT -> Mamba -> GPT should preserve all attention & MLP weights.""" + def test_gpt_hybrid_gpt_preserves_weights(self): + """Converting GPT -> Hybrid -> GPT should preserve all attention & MLP weights.""" d_model = 64 num_layers = 2 pattern = "M*-M*-" @@ -583,11 +583,11 @@ def test_gpt_mamba_gpt_preserves_weights(self): original_gpt = make_synthetic_gpt_checkpoint(num_layers, d_model) layer_types = parse_hybrid_layer_pattern(pattern) - # GPT -> Mamba - mamba_state = convert_gpt_to_mamba(original_gpt, layer_types, args) + # GPT -> Hybrid + mamba_state = convert_gpt_to_hybrid(original_gpt, layer_types, args) - # Mamba -> GPT - recovered_gpt = convert_mamba_to_gpt(mamba_state, layer_types, args) + # Hybrid -> GPT + recovered_gpt = convert_hybrid_to_gpt(mamba_state, layer_types, args) # Check all original GPT keys are preserved for key in original_gpt: @@ -626,8 +626,8 @@ def test_round_trip_different_pattern(self): original_gpt = make_synthetic_gpt_checkpoint(num_layers, d_model) layer_types = parse_hybrid_layer_pattern(pattern) - mamba_state = convert_gpt_to_mamba(original_gpt, layer_types, args) - recovered_gpt = convert_mamba_to_gpt(mamba_state, layer_types, args) + mamba_state = convert_gpt_to_hybrid(original_gpt, layer_types, args) + recovered_gpt = convert_hybrid_to_gpt(mamba_state, layer_types, args) for key in original_gpt: if 'final_layernorm' in key: @@ -647,57 +647,57 @@ class TestPatternWhitelist: def test_accepts_mamba_attn_mlp(self): # Standard hybrid with equal attn/MLP counts. layer_types = parse_hybrid_layer_pattern("M*-M*-M*-") - validate_pattern_gpt_compatible(layer_types, 'gpt-to-mamba') + validate_pattern_gpt_compatible(layer_types, 'gpt-to-hybrid') def test_accepts_pure_transformer_pattern(self): layer_types = parse_hybrid_layer_pattern("*-*-*-") - validate_pattern_gpt_compatible(layer_types, 'mamba-to-gpt') + validate_pattern_gpt_compatible(layer_types, 'hybrid-to-gpt') def test_accepts_pure_ssm_pattern(self): # Pure-SSM models have no attention/MLP, so trivially GPT-compatible # in the pattern sense (the GPT side would be empty). layer_types = parse_hybrid_layer_pattern("MMMM") - validate_pattern_gpt_compatible(layer_types, 'gpt-to-mamba') + validate_pattern_gpt_compatible(layer_types, 'gpt-to-hybrid') def test_accepts_moe_pattern(self): # MoE layers ('E') round-trip through the converter as long as every # MLP-bearing position is the same kind. layer_types = parse_hybrid_layer_pattern("M*EM*EM*E") - validate_pattern_gpt_compatible(layer_types, 'gpt-to-mamba') + validate_pattern_gpt_compatible(layer_types, 'gpt-to-hybrid') def test_accepts_pure_attn_moe_pattern(self): # No SSM, alternating attn/MoE — i.e. a Mixtral-like GPT. layer_types = parse_hybrid_layer_pattern("*E*E*E") - validate_pattern_gpt_compatible(layer_types, 'mamba-to-gpt') + validate_pattern_gpt_compatible(layer_types, 'hybrid-to-gpt') def test_rejects_mixed_dense_and_moe(self): # GPT layers must be uniform: '-' (dense) and 'E' (MoE) cannot both # appear in the same pattern. 
         layer_types = parse_hybrid_layer_pattern("M*-M*E")
         with pytest.raises(ValueError, match="uniform"):
-            validate_pattern_gpt_compatible(layer_types, 'gpt-to-mamba')
+            validate_pattern_gpt_compatible(layer_types, 'gpt-to-hybrid')
 
     def test_rejects_gdn_symbol(self):
         layer_types = parse_hybrid_layer_pattern("G*-*-")
         with pytest.raises(ValueError, match="not GPT-compatible"):
-            validate_pattern_gpt_compatible(layer_types, 'gpt-to-mamba')
+            validate_pattern_gpt_compatible(layer_types, 'gpt-to-hybrid')
 
     def test_rejects_unequal_attn_mlp(self):
         layer_types = parse_hybrid_layer_pattern("M**-")  # 2 attn, 1 MLP
         with pytest.raises(ValueError, match="pair every attention"):
-            validate_pattern_gpt_compatible(layer_types, 'gpt-to-mamba')
+            validate_pattern_gpt_compatible(layer_types, 'gpt-to-hybrid')
 
     def test_unequal_attn_moe_also_rejected(self):
         # Same uniformity check, but with MoE — 2 attn, 1 MoE.
         layer_types = parse_hybrid_layer_pattern("M**E")
         with pytest.raises(ValueError, match="pair every attention"):
-            validate_pattern_gpt_compatible(layer_types, 'gpt-to-mamba')
+            validate_pattern_gpt_compatible(layer_types, 'gpt-to-hybrid')
 
     def test_error_lists_offending_symbols(self):
         # 'G' is still rejected; the error message should mention it.
         layer_types = parse_hybrid_layer_pattern("M*-G")
         with pytest.raises(ValueError) as exc:
-            validate_pattern_gpt_compatible(layer_types, 'mamba-to-gpt')
+            validate_pattern_gpt_compatible(layer_types, 'hybrid-to-gpt')
         assert 'G' in str(exc.value)
 
 
@@ -725,84 +725,84 @@ def _ok_args(self, **overrides):
         return argparse.Namespace(**base)
 
     def test_accepts_plain_gpt_args(self):
-        validate_source_args_gpt_compatible(self._ok_args(), 'gpt-to-mamba')
+        validate_source_args_gpt_compatible(self._ok_args(), 'gpt-to-hybrid')
 
     def test_none_args_is_noop(self):
         # Dist checkpoints sometimes have no cached args blob.
-        validate_source_args_gpt_compatible(None, 'gpt-to-mamba')
+        validate_source_args_gpt_compatible(None, 'gpt-to-hybrid')
 
     def test_accepts_missing_optional_fields(self):
         # Older checkpoints may not have every field; the validator should
         # silently skip fields it doesn't find.
         minimal = argparse.Namespace(num_moe_experts=None)
-        validate_source_args_gpt_compatible(minimal, 'mamba-to-gpt')
+        validate_source_args_gpt_compatible(minimal, 'hybrid-to-gpt')
 
     def test_accepts_moe_args(self):
         # MoE keys live under decoder.layers.<N>.mlp.* and round-trip as-is.
-        validate_source_args_gpt_compatible(self._ok_args(num_moe_experts=8), 'gpt-to-mamba')
+        validate_source_args_gpt_compatible(self._ok_args(num_moe_experts=8), 'gpt-to-hybrid')
 
     def test_accepts_shared_expert_args(self):
         # Shared experts also live under mlp.shared_experts.* and round-trip.
         validate_source_args_gpt_compatible(
             self._ok_args(num_moe_experts=8, moe_shared_expert_intermediate_size=4096),
-            'gpt-to-mamba',
+            'gpt-to-hybrid',
         )
 
     def test_rejects_moe_layer_freq_list(self):
         # Heterogeneous interleaving (some dense, some MoE) breaks GPT uniformity.
         with pytest.raises(ValueError, match="interleaved"):
             validate_source_args_gpt_compatible(
-                self._ok_args(moe_layer_freq=[1, 0, 1, 0]), 'gpt-to-mamba'
+                self._ok_args(moe_layer_freq=[1, 0, 1, 0]), 'gpt-to-hybrid'
             )
 
     def test_accepts_moe_layer_freq_1(self):
-        validate_source_args_gpt_compatible(self._ok_args(moe_layer_freq=1), 'gpt-to-mamba')
+        validate_source_args_gpt_compatible(self._ok_args(moe_layer_freq=1), 'gpt-to-hybrid')
 
     def test_accepts_moe_layer_freq_all_ones_list(self):
         # An all-1s list is uniform (every layer is the same kind) and accepted.
         validate_source_args_gpt_compatible(
-            self._ok_args(moe_layer_freq=[1, 1, 1, 1]), 'gpt-to-mamba'
+            self._ok_args(moe_layer_freq=[1, 1, 1, 1]), 'gpt-to-hybrid'
         )
 
     def test_rejects_experimental_attention(self):
         with pytest.raises(ValueError, match="experimental attention"):
             validate_source_args_gpt_compatible(
-                self._ok_args(experimental_attention_variant='gated_delta_net'), 'gpt-to-mamba'
+                self._ok_args(experimental_attention_variant='gated_delta_net'), 'gpt-to-hybrid'
             )
 
     def test_rejects_linear_attention(self):
         with pytest.raises(ValueError, match="linear attention"):
             validate_source_args_gpt_compatible(
-                self._ok_args(linear_attention_freq=4), 'gpt-to-mamba'
+                self._ok_args(linear_attention_freq=4), 'gpt-to-hybrid'
             )
 
     def test_rejects_heterogeneous_block_specs(self):
         with pytest.raises(ValueError, match="heterogeneous"):
             validate_source_args_gpt_compatible(
-                self._ok_args(heterogeneous_block_specs=True), 'mamba-to-gpt'
+                self._ok_args(heterogeneous_block_specs=True), 'hybrid-to-gpt'
             )
 
     def test_rejects_heterogeneous_config_path(self):
         with pytest.raises(ValueError, match="heterogeneous"):
             validate_source_args_gpt_compatible(
-                self._ok_args(heterogeneous_layers_config_path='/tmp/x.json'), 'gpt-to-mamba'
+                self._ok_args(heterogeneous_layers_config_path='/tmp/x.json'), 'gpt-to-hybrid'
             )
 
     def test_rejects_mla(self):
         with pytest.raises(ValueError, match="Multi-Latent"):
             validate_source_args_gpt_compatible(
-                self._ok_args(multi_latent_attention=True), 'gpt-to-mamba'
+                self._ok_args(multi_latent_attention=True), 'gpt-to-hybrid'
             )
 
     def test_rejects_mtp(self):
         with pytest.raises(ValueError, match="Multi-Token Prediction"):
-            validate_source_args_gpt_compatible(self._ok_args(mtp_num_layers=2), 'gpt-to-mamba')
+            validate_source_args_gpt_compatible(self._ok_args(mtp_num_layers=2), 'gpt-to-hybrid')
 
     def test_reports_multiple_reasons(self):
         # Both heterogeneous moe_layer_freq and MLA set — both should be reported.
         with pytest.raises(ValueError) as exc:
             validate_source_args_gpt_compatible(
-                self._ok_args(moe_layer_freq=[1, 0], multi_latent_attention=True), 'gpt-to-mamba'
+                self._ok_args(moe_layer_freq=[1, 0], multi_latent_attention=True), 'gpt-to-hybrid'
             )
         msg = str(exc.value)
         assert 'interleaved' in msg
diff --git a/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion_parallelism.py b/tests/unit_tests/tools/checkpoint/test_gpt_hybrid_conversion_parallelism.py
similarity index 94%
rename from tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion_parallelism.py
rename to tests/unit_tests/tools/checkpoint/test_gpt_hybrid_conversion_parallelism.py
index 4c908330610..4e9056008af 100644
--- a/tests/unit_tests/tools/checkpoint/test_gpt_mamba_conversion_parallelism.py
+++ b/tests/unit_tests/tools/checkpoint/test_gpt_hybrid_conversion_parallelism.py
@@ -1,13 +1,13 @@
 # Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
 """
-Parallelism-matrix integration tests for gpt_mamba_conversion.py.
+Parallelism-matrix integration tests for gpt_hybrid_conversion.py.
 
 The converter operates on dist (``torch_dist`` / ``fsdp_dtensor``) checkpoints
 only — DCP's metadata stores each tensor's ``global_shape``, so the on-disk
 TP / PP / FSDP layout is abstracted away from the conversion logic.
 
 We synthesize a DCP checkpoint via a single-rank ``dcp.save`` and round-trip
-GPT -> Mamba -> GPT through the conversion CLI, asserting attention and MLP
+GPT -> Hybrid -> GPT through the conversion CLI, asserting attention and MLP
 weights match exactly.
 
 Each scenario is run as a distinct test to document the supported matrix and
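To make the synthesis step described in that docstring concrete: a single-rank DCP checkpoint can be written and read back roughly as below. This is a minimal sketch, not the test helper itself; it assumes a recent PyTorch in which torch.distributed.checkpoint falls back to single-process mode when no process group is initialized, and the key names and path are invented for illustration.

    import torch
    import torch.distributed.checkpoint as dcp

    # Synthetic "GPT" tensors keyed the way a Megatron dist checkpoint would key them.
    state_dict = {
        'model.decoder.layers.0.self_attention.linear_qkv.weight': torch.randn(192, 64),
        'model.decoder.layers.0.mlp.linear_fc1.weight': torch.randn(256, 64),
    }
    dcp.save(state_dict, checkpoint_id='/tmp/gpt_src_sketch')  # hypothetical path

    # DCP's metadata records each tensor's global shape, so a later reader only
    # needs same-keyed buffers of the right shape; the TP/PP layout at save time
    # is invisible to the converter.
    restored = {k: torch.empty_like(v) for k, v in state_dict.items()}
    dcp.load(restored, checkpoint_id='/tmp/gpt_src_sketch')
    assert all(torch.equal(restored[k], state_dict[k]) for k in state_dict)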
@@ -34,7 +34,7 @@
 sys.path.insert(0, os.path.join(_REPO_ROOT, 'tools', 'checkpoint'))
 sys.path.insert(0, _THIS_DIR)
 
-from gpt_mamba_conversion import main as conversion_main
+from gpt_hybrid_conversion import main as conversion_main
 
 
 # These scenarios are SYNTHETIC and single-rank by design: each one writes a
@@ -223,17 +223,17 @@ def _run_scenario(
     num_moe_experts=None,
     shared_expert_size=None,
 ):
-    """Build a GPT source ckpt, convert GPT->Mamba->GPT, verify round-trip."""
+    """Build a GPT source ckpt, convert GPT->Hybrid->GPT, verify round-trip."""
     print(f"\n=== {label} ===")
     print(f"  source={source_format} (prefix='{source_prefix}')")
     print(f"  target={target_format}")
     if num_moe_experts is not None:
         print(f"  MoE: num_experts={num_moe_experts} shared={shared_expert_size}")
 
-    tmpdir = tempfile.mkdtemp(prefix=f'gpt_mamba_{label.replace(" ", "_")}_')
+    tmpdir = tempfile.mkdtemp(prefix=f'gpt_hybrid_{label.replace(" ", "_")}_')
     try:
         src_gpt_dir = os.path.join(tmpdir, 'gpt_src')
-        mamba_dir = os.path.join(tmpdir, 'mamba_mid')
+        hybrid_dir = os.path.join(tmpdir, 'hybrid_mid')
         dst_gpt_dir = os.path.join(tmpdir, 'gpt_dst')
 
         ckpt_args = make_checkpoint_args(
@@ -267,23 +267,23 @@ def _run_scenario(
             output_format=target_format,
         )
 
-        # --- GPT -> Mamba ---
+        # --- GPT -> Hybrid ---
         conversion_main(
             argparse.Namespace(
-                direction='gpt-to-mamba', load_dir=src_gpt_dir, save_dir=mamba_dir, **common_kwargs
+                direction='gpt-to-hybrid', load_dir=src_gpt_dir, save_dir=hybrid_dir, **common_kwargs
             )
         )
 
-        # --- Mamba -> GPT ---
+        # --- Hybrid -> GPT ---
         conversion_main(
             argparse.Namespace(
-                direction='mamba-to-gpt', load_dir=mamba_dir, save_dir=dst_gpt_dir, **common_kwargs
+                direction='hybrid-to-gpt', load_dir=hybrid_dir, save_dir=dst_gpt_dir, **common_kwargs
             )
         )
 
         # --- Verify ---
         recovered_sd, _ = _load_converted_dist(dst_gpt_dir)
 
-        # The mamba->gpt step renames decoder.final_norm -> decoder.final_layernorm,
+        # The hybrid->gpt step renames decoder.final_norm -> decoder.final_layernorm,
         # mirroring the original GPT key. So recovered_sd should have the same
         # keys and tensor values as gpt_sd.
@@ -395,7 +395,7 @@ def test_fsdp_dtensor_moe_roundtrip():
 
 if __name__ == '__main__':
     print("=" * 60)
-    print("GPT <-> Mamba Conversion Parallelism Matrix Tests")
+    print("GPT <-> Hybrid Conversion Parallelism Matrix Tests")
     print("=" * 60)
 
     test_torch_dist_roundtrip()
diff --git a/tools/checkpoint/gpt_mamba_conversion.py b/tools/checkpoint/gpt_hybrid_conversion.py
similarity index 92%
rename from tools/checkpoint/gpt_mamba_conversion.py
rename to tools/checkpoint/gpt_hybrid_conversion.py
index 951ce1abec6..94236daeed9 100644
--- a/tools/checkpoint/gpt_mamba_conversion.py
+++ b/tools/checkpoint/gpt_hybrid_conversion.py
@@ -1,18 +1,18 @@
 # Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
 """
-GPT <-> Mamba Checkpoint Conversion Tool
+GPT <-> Hybrid Checkpoint Conversion Tool
 =========================================
 
 Directly converts checkpoints between GPTModel (homogeneous Transformer) and
-MambaModel (hybrid Mamba+Transformer) without going through HuggingFace as an
+HybridModel (hybrid Mamba+Transformer) without going through HuggingFace as an
 intermediary.
 
 Supported directions:
-  gpt-to-mamba : Convert a GPT checkpoint to Mamba hybrid format.
-  mamba-to-gpt : Convert a Mamba hybrid checkpoint to GPT format.
+  gpt-to-hybrid : Convert a GPT checkpoint to Hybrid format.
+  hybrid-to-gpt : Convert a Hybrid checkpoint to GPT format.
 
-How the hybrid layer pattern maps GPT layers (gpt-to-mamba):
+How the hybrid layer pattern maps GPT layers (gpt-to-hybrid):
 - Each GPT layer contains both attention and MLP sub-layers.
 - The target hybrid model's hybrid_layer_pattern specifies per-layer types:
     M = Mamba SSM layer
@@ -41,7 +41,7 @@
   - Use a pattern like 'M*EM*EM*E' to pair Mamba/Attn/MoE-MLP per stage.
 
 What happens to SSM parameters:
-  gpt-to-mamba: SSM layers (M) are initialized from scratch:
+  gpt-to-hybrid: SSM layers (M) are initialized from scratch:
     - A_log: log(uniform(1, 16))
    - dt_bias: inverse_softplus(log_uniform(dt_min, dt_max))
    - D: ones
@@ -51,7 +51,7 @@
     - in_proj.layer_norm_weight: ones
     - out_proj.weight: kaiming_uniform(a=sqrt(5))
     - norm.weight: ones
-  mamba-to-gpt: SSM layers are discarded with a warning.
+  hybrid-to-gpt: SSM layers are discarded with a warning.
 
 Supported checkpoint formats:
   - torch_dist : Megatron distributed checkpoint (TP + PP + FSDP).
@@ -87,21 +87,21 @@
   `validate_source_args_gpt_compatible` for the exact rules.
 
 Example commands:
-  # GPT -> Mamba (TP+PP+FSDP dist checkpoint)
-  python tools/checkpoint/gpt_mamba_conversion.py \\
-    --direction gpt-to-mamba \\
+  # GPT -> Hybrid (TP+PP+FSDP dist checkpoint)
+  python tools/checkpoint/gpt_hybrid_conversion.py \\
+    --direction gpt-to-hybrid \\
     --load-dir /path/to/gpt-dist-checkpoint \\
-    --save-dir /path/to/mamba-dist-checkpoint \\
+    --save-dir /path/to/hybrid-dist-checkpoint \\
     --hybrid-layer-pattern "M*-M*-M*-M*-" \\
     --d-model 4096 \\
     --mamba-d-state 128 \\
     --mamba2-n-groups 8 \\
     --mamba2-head-dim 64
 
-  # Mamba -> GPT (dist checkpoint)
-  python tools/checkpoint/gpt_mamba_conversion.py \\
-    --direction mamba-to-gpt \\
-    --load-dir /path/to/mamba-dist-checkpoint \\
+  # Hybrid -> GPT (dist checkpoint)
+  python tools/checkpoint/gpt_hybrid_conversion.py \\
+    --direction hybrid-to-gpt \\
+    --load-dir /path/to/hybrid-dist-checkpoint \\
     --save-dir /path/to/gpt-dist-checkpoint \\
     --hybrid-layer-pattern "M*-M*-M*-M*-" \\
     --d-model 4096 \\
@@ -140,10 +140,10 @@
 #  '-' : standard (optionally gated) dense MLP layer
 #  'E' : MoE MLP layer. Both sides keep the keys under
 #        decoder.layers.<N>.mlp.{router,experts,shared_experts}.* so MoE
-#        tensors round-trip verbatim (see convert_gpt_to_mamba and
-#        convert_mamba_to_gpt — `is_mlp_param` already matches `mlp.*`).
+#        tensors round-trip verbatim (see convert_gpt_to_hybrid and
+#        convert_hybrid_to_gpt — `is_mlp_param` already matches `mlp.*`).
 #  SSM ('M') has no GPT equivalent and is initialized from scratch /
-#  discarded (see convert_gpt_to_mamba / convert_mamba_to_gpt).
+#  discarded (see convert_gpt_to_hybrid / convert_hybrid_to_gpt).
 #  'G' (GDN) and 'D' (DS-attention) are not currently mapped — they would
 #  need separate key-naming work. Reject for now.
 GPT_COMPATIBLE_PATTERN_SYMBOLS = {'M', '*', '-', 'E'}
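A rough sketch of the per-head scalar initializations listed under "What happens to SSM parameters" in the module docstring above. This is not the tool's initialize_ssm_layer_params; n_heads and the dt_min/dt_max defaults are assumptions made only for the illustration.

    import math
    import torch

    def sketch_ssm_scalar_init(n_heads, dt_min=0.001, dt_max=0.1):
        # A_log: log(uniform(1, 16))
        A_log = torch.log(torch.empty(n_heads).uniform_(1.0, 16.0))
        # dt_bias: inverse_softplus(log_uniform(dt_min, dt_max)), i.e. chosen so
        # that softplus(dt_bias) lands back on the sampled dt.
        dt = torch.exp(
            torch.rand(n_heads) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)
        )
        dt_bias = dt + torch.log(-torch.expm1(-dt))
        # D (and the mixer norm / layer_norm weights): ones
        D = torch.ones(n_heads)
        return A_log, dt_bias, D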
@@ -183,13 +183,13 @@ def parse_hybrid_layer_pattern(pattern):
 def build_layer_index_mapping(layer_types, direction):
     """Build mapping between GPT layer indices and hybrid-model layer indices.
 
-    For gpt-to-mamba:
+    For gpt-to-hybrid:
         Returns (attn_map, mlp_map, ssm_indices) where:
         - attn_map[gpt_layer_i] = hybrid_layer_j (j is the index of the i-th '*')
         - mlp_map[gpt_layer_i] = hybrid_layer_k (k is the index of the i-th
           MLP-bearing position; either '-' or 'E')
 
-    For mamba-to-gpt:
+    For hybrid-to-gpt:
         Returns (attn_map, mlp_map, ssm_indices) where:
         - attn_map[hybrid_attn_idx] = gpt_layer_i
         - mlp_map[hybrid_mlp_idx] = gpt_layer_i
@@ -198,20 +198,20 @@ def build_layer_index_mapping(layer_types, direction):
     mlp_indices = [i for i, t in enumerate(layer_types) if t in _MLP_BEARING_SYMBOLS]
     ssm_indices = [i for i, t in enumerate(layer_types) if t == 'M']
 
-    if direction == 'gpt-to-mamba':
+    if direction == 'gpt-to-hybrid':
         if len(attn_indices) != len(mlp_indices):
             raise ValueError(
-                f"For gpt-to-mamba, the number of attention layers ({len(attn_indices)}) "
+                f"For gpt-to-hybrid, the number of attention layers ({len(attn_indices)}) "
                 f"must equal the number of MLP/MoE layers ({len(mlp_indices)}) in the pattern."
             )
         attn_map = {i: attn_indices[i] for i in range(len(attn_indices))}
         mlp_map = {i: mlp_indices[i] for i in range(len(mlp_indices))}
         return attn_map, mlp_map, ssm_indices
 
-    elif direction == 'mamba-to-gpt':
+    elif direction == 'hybrid-to-gpt':
         if len(attn_indices) != len(mlp_indices):
             raise ValueError(
-                f"For mamba-to-gpt, the number of attention layers ({len(attn_indices)}) "
+                f"For hybrid-to-gpt, the number of attention layers ({len(attn_indices)}) "
                 f"must equal the number of MLP/MoE layers ({len(mlp_indices)}) in the pattern."
             )
         attn_map = {attn_indices[i]: i for i in range(len(attn_indices))}
@@ -318,7 +318,7 @@ def validate_pattern_gpt_compatible(layer_types, direction):
 
     Args:
         layer_types: list of layer-type chars from parse_hybrid_layer_pattern().
-        direction: 'gpt-to-mamba' or 'mamba-to-gpt' (for error messages).
+        direction: 'gpt-to-hybrid' or 'hybrid-to-gpt' (for error messages).
 
     Rules:
       * Allowed symbols: 'M', '*', '-', 'E'. 'G' (GDN) and 'D' (DS-attention)
@@ -365,7 +365,7 @@ def validate_source_args_gpt_compatible(source_args, direction):
         source_args: argparse.Namespace (or any attribute-bag) loaded from the
             source checkpoint; may be None, in which case this check is a no-op
             (dist checkpoints without a cached args blob).
-        direction: 'gpt-to-mamba' or 'mamba-to-gpt'.
+        direction: 'gpt-to-hybrid' or 'hybrid-to-gpt'.
 
     Rejects MoE, MLA, MTP, linear / experimental attention, and heterogeneous
     per-layer specs. See the module header for the full list.
@@ -393,12 +393,12 @@ def validate_source_args_gpt_compatible(source_args, direction):
             f"conversion. The following features have no GPTModel equivalent "
             f"and would produce a corrupt target checkpoint:\n{joined}\n"
             f"Remove these features from the model (or use a different "
-            f"conversion tool) before running gpt_mamba_conversion."
+            f"conversion tool) before running gpt_hybrid_conversion."
         )
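A hand-worked example of the mapping this docstring describes, for the small pattern used throughout the unit tests; the values are written out by hand, not produced by running the tool.

    layer_types = list("M*-M*-")      # indices: M=0, *=1, -=2, M=3, *=4, -=5
    # gpt-to-hybrid:
    #   attn_map    = {0: 1, 1: 4}    # GPT layer i -> hybrid index of the i-th '*'
    #   mlp_map     = {0: 2, 1: 5}    # GPT layer i -> hybrid index of the i-th '-'/'E'
    #   ssm_indices = [0, 3]          # Mamba layers, initialized from scratch
    # hybrid-to-gpt inverts the same pairs:
    #   attn_map    = {1: 0, 4: 1}
    #   mlp_map     = {2: 0, 5: 1}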
 
 
 # ---------------------------------------------------------------------------
-# SSM parameter initialization (for gpt-to-mamba)
+# SSM parameter initialization (for gpt-to-hybrid)
 # ---------------------------------------------------------------------------
 
 def initialize_ssm_layer_params(
@@ -528,11 +528,11 @@ def is_layer_norm_for_ssm(key):
 
 
 # ---------------------------------------------------------------------------
-# Core conversion: GPT -> Mamba
+# Core conversion: GPT -> Hybrid
 # ---------------------------------------------------------------------------
 
-def convert_gpt_to_mamba(full_model, layer_types, args):
-    """Convert a GPT state dict to a Mamba hybrid state dict.
+def convert_gpt_to_hybrid(full_model, layer_types, args):
+    """Convert a GPT state dict to a Hybrid state dict.
 
     Args:
         full_model: OrderedDict with globally-indexed GPT state dict keys.
@@ -540,10 +540,10 @@
         args: Parsed CLI arguments.
 
     Returns:
-        OrderedDict: Mamba state dict with globally-indexed keys.
+        OrderedDict: Hybrid state dict with globally-indexed keys.
     """
     attn_map, mlp_map, ssm_indices = build_layer_index_mapping(
-        layer_types, 'gpt-to-mamba'
+        layer_types, 'gpt-to-hybrid'
     )
     num_gpt_layers = len(attn_map)
@@ -625,14 +625,14 @@
 
 
 # ---------------------------------------------------------------------------
-# Core conversion: Mamba -> GPT
+# Core conversion: Hybrid -> GPT
 # ---------------------------------------------------------------------------
 
-def convert_mamba_to_gpt(full_model, layer_types, args):
-    """Convert a Mamba hybrid state dict to a GPT state dict.
+def convert_hybrid_to_gpt(full_model, layer_types, args):
+    """Convert a Hybrid state dict to a GPT state dict.
 
     Args:
-        full_model: OrderedDict with globally-indexed Mamba state dict keys.
+        full_model: OrderedDict with globally-indexed Hybrid state dict keys.
         layer_types: list of layer type chars from hybrid_layer_pattern.
         args: Parsed CLI arguments.
 
     Returns:
         OrderedDict: GPT state dict with globally-indexed keys.
     """
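The per-layer key rewrite that both converters perform can be pictured with a hypothetical helper like the one below; remap_layer_key is not a function in this patch, it only illustrates the decoder.layers.<i> index substitution driven by attn_map / mlp_map.

    import re

    def remap_layer_key(key, layer_map):
        # Rewrite the layer index in 'decoder.layers.<i>.<rest>' via layer_map.
        m = re.match(r'decoder\.layers\.(\d+)\.(.*)', key)
        if m is None:
            return key  # embeddings, final norm, output layer, ... pass through
        return f'decoder.layers.{layer_map[int(m.group(1))]}.{m.group(2)}'

    # With the gpt-to-hybrid attn_map {0: 1, 1: 4} from the example above:
    # 'decoder.layers.0.self_attention.linear_qkv.weight'
    #     -> 'decoder.layers.1.self_attention.linear_qkv.weight'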
""" attn_map, mlp_map, ssm_indices = build_layer_index_mapping( - layer_types, 'mamba-to-gpt' + layer_types, 'hybrid-to-gpt' ) num_gpt_layers = len(attn_map) @@ -687,7 +687,7 @@ def convert_mamba_to_gpt(full_model, layer_types, args): if discarded_ssm_keys: print(f"\n WARNING: Discarded {len(discarded_ssm_keys)} SSM parameter tensors " - f"from {len(ssm_indices)} Mamba layers (no GPT equivalent).") + f"from {len(ssm_indices)} SSM layers (no GPT equivalent).") print(f" First few discarded keys: {discarded_ssm_keys[:5]}") target = _sort_state_dict(target) @@ -743,7 +743,7 @@ def _save_dist_full(target_state_dict, common_state, model_prefix, backend, ckpt_args = common_state['args'] ckpt_args.num_layers = args.target_num_layers if hasattr(ckpt_args, 'hybrid_layer_pattern'): - if args.direction == 'gpt-to-mamba': + if args.direction == 'gpt-to-hybrid': ckpt_args.hybrid_layer_pattern = args.hybrid_layer_pattern else: ckpt_args.hybrid_layer_pattern = None @@ -765,7 +765,7 @@ def _save_dist_full(target_state_dict, common_state, model_prefix, backend, def main(args): - print("\n====RUNNING GPT <-> MAMBA CHECKPOINT CONVERSION====\n") + print("\n====RUNNING GPT <-> Hybrid CHECKPOINT CONVERSION====\n") print(f" Direction: {args.direction}") print(f" Source: {args.load_dir}") print(f" Target: {args.save_dir}") @@ -777,7 +777,7 @@ def main(args): # Parse hybrid layer pattern layer_types = parse_hybrid_layer_pattern(args.hybrid_layer_pattern) - total_mamba_layers = len(layer_types) + total_hybrid_layers = len(layer_types) attn_count = sum(1 for t in layer_types if t == '*') mlp_count = sum(1 for t in layer_types if t == '-') ssm_count = sum(1 for t in layer_types if t == 'M') @@ -824,11 +824,11 @@ def main(args): # 3. Convert print(f"\n[Step 2] Converting ({args.direction})...") - if args.direction == 'gpt-to-mamba': - target_state_dict = convert_gpt_to_mamba(full_model, layer_types, args) - args.target_num_layers = total_mamba_layers - elif args.direction == 'mamba-to-gpt': - target_state_dict = convert_mamba_to_gpt(full_model, layer_types, args) + if args.direction == 'gpt-to-hybrid': + target_state_dict = convert_gpt_to_hybrid(full_model, layer_types, args) + args.target_num_layers = total_hybrid_layers + elif args.direction == 'hybrid-to-gpt': + target_state_dict = convert_hybrid_to_gpt(full_model, layer_types, args) args.target_num_layers = attn_count else: raise ValueError(f"Unknown direction: {args.direction}") @@ -846,13 +846,13 @@ def main(args): if __name__ == "__main__": parser = argparse.ArgumentParser( - description="Convert checkpoints between GPTModel and MambaModel formats.", + description="Convert checkpoints between GPTModel and HybridModel formats.", formatter_class=argparse.RawDescriptionHelpFormatter, ) parser.add_argument( '--direction', type=str, required=True, - choices=['gpt-to-mamba', 'mamba-to-gpt'], + choices=['gpt-to-hybrid', 'hybrid-to-gpt'], help='Conversion direction.', ) parser.add_argument('--load-dir', type=str, required=True, From e532385f1e49414e3b04f9437849aa4d12190bfe Mon Sep 17 00:00:00 2001 From: guihong-nv Date: Thu, 30 Apr 2026 14:07:45 -0700 Subject: [PATCH 10/10] Fix linting issues Signed-off-by: guihong-nv --- .../test_gpt_hybrid_conversion_parallelism.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/unit_tests/tools/checkpoint/test_gpt_hybrid_conversion_parallelism.py b/tests/unit_tests/tools/checkpoint/test_gpt_hybrid_conversion_parallelism.py index 4e9056008af..8102b7018ae 100644 --- 
+++ b/tests/unit_tests/tools/checkpoint/test_gpt_hybrid_conversion_parallelism.py
@@ -270,14 +270,20 @@ def _run_scenario(
         # --- GPT -> Hybrid ---
         conversion_main(
             argparse.Namespace(
-                direction='gpt-to-hybrid', load_dir=src_gpt_dir, save_dir=hybrid_dir, **common_kwargs
+                direction='gpt-to-hybrid',
+                load_dir=src_gpt_dir,
+                save_dir=hybrid_dir,
+                **common_kwargs,
             )
         )
 
         # --- Hybrid -> GPT ---
         conversion_main(
             argparse.Namespace(
-                direction='hybrid-to-gpt', load_dir=hybrid_dir, save_dir=dst_gpt_dir, **common_kwargs
+                direction='hybrid-to-gpt',
+                load_dir=hybrid_dir,
+                save_dir=dst_gpt_dir,
+                **common_kwargs,
             )
         )