1 change: 1 addition & 0 deletions tests/unit_tests/tools/__init__.py
@@ -0,0 +1 @@
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1 change: 1 addition & 0 deletions tests/unit_tests/tools/checkpoint/__init__.py
@@ -0,0 +1 @@
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
328 changes: 328 additions & 0 deletions tests/unit_tests/tools/checkpoint/test_distributed_round_trip.py
@@ -0,0 +1,328 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.

"""
Multi-rank distributed round-trip test for gpt_mamba_conversion.

Each rank participates in a multi-rank DCP save of a synthetic GPT (or MoE
GPT) state dict; every rank then collectively runs the converter in both
directions, and rank 0 verifies the GPT -> Mamba -> GPT round-trip exactly.

This test is meant to be launched under SLURM/srun (or torchrun) with
WORLD_SIZE = TP * PP * FSDP * EP. The --tp/--pp/--fsdp/--ep flags serve
purely as labels — the converter sees only the DCP-stored ``global_shape``
per tensor and is agnostic to *which* dimension(s) the source was sharded
along. The test's value lies in:

1. Coordinating a real multi-rank ``dcp.save`` (cross-node networking,
collective barriers, shared-filesystem write).
2. Verifying the converter loads a multi-rank-written checkpoint and
round-trips it through both directions.

Usage (under srun):
export RANK=$SLURM_PROCID
export LOCAL_RANK=$SLURM_LOCALID
export WORLD_SIZE=$SLURM_NTASKS
export MASTER_ADDR=$(scontrol show hostnames $SLURM_NODELIST | head -1)
export MASTER_PORT=29500
python test_distributed_round_trip.py \\
--tp 2 --pp 2 --fsdp 2 --ep 2 --label TP2-PP2-FSDP2-EP2 \\
--output-root /lustre/.../scratch/dist_test
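
Usage (under torchrun; an illustrative single-node variant. torchrun exports
RANK/LOCAL_RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT itself, and the flag
values must satisfy tp * pp * fsdp * ep == nproc-per-node):
    torchrun --nproc-per-node=8 test_distributed_round_trip.py \\
        --tp 2 --pp 2 --fsdp 2 --ep 1 --label TP2-PP2-FSDP2 \\
        --output-root /lustre/.../scratch/dist_test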
"""

import argparse
import copy
import os
import shutil
import sys
import time
from collections import OrderedDict
from types import SimpleNamespace

import torch
import torch.distributed as dist

# Make the conversion tool and helpers importable.
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
_REPO_ROOT = os.path.join(_THIS_DIR, '..', '..', '..', '..')
sys.path.insert(0, os.path.join(_REPO_ROOT, 'tools', 'checkpoint'))
sys.path.insert(0, _THIS_DIR)


def _log(msg, rank, label=""):
prefix = f"[rank={rank}{(' ' + label) if label else ''}]"
print(f"{prefix} {msg}", flush=True)


def _build_state_dict(num_layers, hidden_size, num_moe_experts, vocab_size, dtype):
"""Build a deterministic GPT(MoE) state dict identical on every rank.

Determinism (via fixed seed) lets every rank produce the same tensors so
    DCP's de-duplication writes a single coherent checkpoint. After the final
    load, rank 0 compares the recovered tensors against this deterministic
    reference to verify the round-trip.
"""
torch.manual_seed(0xC0FFEE)
sd = OrderedDict()
sd['embedding.word_embeddings.weight'] = torch.randn(vocab_size, hidden_size, dtype=dtype)

for i in range(num_layers):
p = f'decoder.layers.{i}.'
sd[p + 'input_layernorm.weight'] = torch.randn(hidden_size, dtype=dtype)
sd[p + 'self_attention.linear_qkv.weight'] = torch.randn(
3 * hidden_size, hidden_size, dtype=dtype
)
sd[p + 'self_attention.linear_proj.weight'] = torch.randn(
hidden_size, hidden_size, dtype=dtype
)
sd[p + 'pre_mlp_layernorm.weight'] = torch.randn(hidden_size, dtype=dtype)

if num_moe_experts is None:
sd[p + 'mlp.linear_fc1.weight'] = torch.randn(4 * hidden_size, hidden_size, dtype=dtype)
sd[p + 'mlp.linear_fc2.weight'] = torch.randn(hidden_size, 4 * hidden_size, dtype=dtype)
else:
sd[p + 'mlp.router.weight'] = torch.randn(num_moe_experts, hidden_size, dtype=dtype)
for j in range(num_moe_experts):
ep = p + f'mlp.experts.local_experts.{j}.'
sd[ep + 'linear_fc1.weight'] = torch.randn(
4 * hidden_size, hidden_size, dtype=dtype
)
sd[ep + 'linear_fc2.weight'] = torch.randn(
hidden_size, 4 * hidden_size, dtype=dtype
)

sd['decoder.final_layernorm.weight'] = torch.randn(hidden_size, dtype=dtype)
sd['output_layer.weight'] = torch.randn(vocab_size, hidden_size, dtype=dtype)
return sd


def _build_ckpt_args(num_layers, hidden_size, num_moe_experts):
return SimpleNamespace(
num_layers=num_layers,
hidden_size=hidden_size,
num_attention_heads=4,
ffn_hidden_size=hidden_size * 4,
seq_length=256,
max_position_embeddings=256,
iteration=100,
consumed_train_samples=0,
consumed_valid_samples=0,
train_iters=1000,
train_samples=0,
tokenizer_type='GPT2BPETokenizer',
position_embedding_type='rope',
params_dtype=torch.float32,
fp16=False,
bf16=False,
num_moe_experts=num_moe_experts,
moe_shared_expert_intermediate_size=None,
moe_layer_freq=1,
)


def _init_process_group(init_file):
"""Initialize via file:// rendezvous on a shared filesystem.

RANK / WORLD_SIZE come from env (set by srun). MASTER_ADDR / MASTER_PORT
    are not needed — every rank just opens the same shared file. This avoids
    the SLURM CLI tools (e.g. scontrol), which are not always present inside
    container images.

The init file MUST NOT pre-exist; rank 0 cleans up any stale leftover.
"""
if dist.is_initialized():
return
rank = int(os.environ['RANK'])
world_size = int(os.environ['WORLD_SIZE'])
if rank == 0 and os.path.exists(init_file):
os.remove(init_file)
# Brief settle so other ranks don't race ahead of the cleanup.
time.sleep(1)
dist.init_process_group(
backend='gloo', # CPU-only synthetic save; no NCCL needed
init_method=f'file://{init_file}',
rank=rank,
world_size=world_size,
)


def _verify_round_trip(original, recovered, label):
missing, mismatch = [], []
for k, v in original.items():
if k not in recovered:
missing.append(k)
continue
if not torch.equal(v, recovered[k].to(v.dtype)):
mismatch.append(k)

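    # SSM/Mamba parameters live under a '...mixer.' prefix in the hybrid
    # model; none of them should survive the conversion back to GPT.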
leaked_ssm = [k for k in recovered if 'mixer.' in k]

if missing or mismatch or leaked_ssm:
print(f"FAIL [{label}]:")
for k in missing[:5]:
print(f" MISSING: {k}")
for k in mismatch[:5]:
print(f" MISMATCH: {k}")
for k in leaked_ssm[:5]:
print(f" SSM LEAKED: {k}")
raise AssertionError(
f"{label}: {len(missing)} missing, {len(mismatch)} mismatched, "
f"{len(leaked_ssm)} SSM keys leaked"
)

print(f"PASS [{label}]: {len(original)} keys round-tripped exactly")


def main():
parser = argparse.ArgumentParser()
parser.add_argument('--tp', type=int, default=1)
parser.add_argument('--pp', type=int, default=1)
parser.add_argument('--fsdp', type=int, default=1)
parser.add_argument('--ep', type=int, default=1)
parser.add_argument('--label', type=str, required=True)
parser.add_argument(
'--output-root',
type=str,
required=True,
help='Shared-filesystem path where this scenario writes its '
'checkpoints (must be visible from every rank).',
)
parser.add_argument('--num-layers', type=int, default=3)
parser.add_argument('--hidden-size', type=int, default=64)
parser.add_argument('--vocab-size', type=int, default=512)
parser.add_argument(
'--num-moe-experts',
type=int,
default=None,
help='If set, use the MoE GPT state-dict layout '
'(mlp.router + mlp.experts.local_experts.*).',
)
parser.add_argument(
'--pattern',
type=str,
default=None,
help='Hybrid layer pattern. Defaults to "M*-M*-M*-" for '
'dense or "M*EM*EM*E" when --num-moe-experts is set.',
)
args = parser.parse_args()

expected_world = args.tp * args.pp * args.fsdp * args.ep
pattern = args.pattern or ('M*EM*EM*E' if args.num_moe_experts is not None else 'M*-M*-M*-')

# Shared init file lives on the same shared FS we use for checkpoints, so
# all ranks on all nodes see the same path.
init_file = os.path.join(args.output_root, f'_pg_init_{args.label}')
if int(os.environ.get('RANK', 0)) == 0:
os.makedirs(args.output_root, exist_ok=True)
_init_process_group(init_file)
rank = dist.get_rank()
world = dist.get_world_size()
if world != expected_world:
if rank == 0:
print(f"FAIL [{args.label}]: world={world} but tp*pp*fsdp*ep={expected_world}")
sys.exit(2)

# Lazy imports after sys.path is set.
from dist_checkpoint_io import (
load_dist_checkpoint_full,
save_dist_checkpoint_full,
write_latest_iteration_marker,
)
from gpt_mamba_conversion import main as conversion_main

if rank == 0:
_log(
f"label={args.label} tp={args.tp} pp={args.pp} fsdp={args.fsdp} "
f"ep={args.ep} world={world} pattern={pattern} "
f"num_moe_experts={args.num_moe_experts}",
rank,
args.label,
)

# Each rank builds the same full state dict — DCP de-duplicates writes
# across ranks via its writer planner.
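    # (torch.distributed.checkpoint's default save planner drops duplicate
    # shards of replicated tensors, so each tensor lands on disk once.)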
state_dict = _build_state_dict(
args.num_layers, args.hidden_size, args.num_moe_experts, args.vocab_size, torch.float32
)
ckpt_args = _build_ckpt_args(args.num_layers, args.hidden_size, args.num_moe_experts)

scratch = os.path.join(args.output_root, args.label)
src_dir = os.path.join(scratch, 'gpt_src')
mid_dir = os.path.join(scratch, 'mamba_mid')
dst_dir = os.path.join(scratch, 'gpt_dst')
iter_subdir = os.path.join(src_dir, 'iter_0000100')

if rank == 0:
if os.path.isdir(scratch):
shutil.rmtree(scratch, ignore_errors=True)
os.makedirs(iter_subdir, exist_ok=True)
dist.barrier()

# --- Multi-rank DCP write of the source GPT checkpoint ---
# dcp.save / dcp.load are both COLLECTIVE in the active process group, so
# every rank in this PG must participate in every save and every load.
# That includes the two conversion_main calls below, which internally call
# load_dist_checkpoint_full + save_dist_checkpoint_full once each.
# If a rank exits early its gloo socket closes and rank 0's reduce_scatter
# dies with "Connection closed by peer".
common_state = {'args': copy.deepcopy(ckpt_args), 'checkpoint_version': 3.0, 'iteration': 100}
save_dist_checkpoint_full(
state_dict, common_state, iter_subdir, model_prefix='model.', backend='torch_dist'
)
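    # The marker is a plain metadata file, so one rank writing it suffices.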
if rank == 0:
write_latest_iteration_marker(iter_subdir, 100)
dist.barrier()

# --- Convert GPT -> hybrid -> GPT (every rank participates collectively).
common_kwargs = dict(
hybrid_layer_pattern=pattern,
d_model=args.hidden_size,
mamba_version=2,
mamba_d_state=16,
mamba2_n_groups=2,
mamba2_head_dim=32,
d_conv=4,
init_method_std=0.02,
reset_iterations=False,
input_format='auto',
output_format='torch_dist',
)

# Silence non-rank-0 stdout to keep logs readable; collective behavior
# is unaffected.
if rank != 0:
sys.stdout = open(os.devnull, 'w')

t0 = time.time()
conversion_main(
argparse.Namespace(
direction='gpt-to-mamba', load_dir=src_dir, save_dir=mid_dir, **common_kwargs
)
)
dist.barrier()
conversion_main(
argparse.Namespace(
direction='mamba-to-gpt', load_dir=mid_dir, save_dir=dst_dir, **common_kwargs
)
)
dist.barrier()
dt = time.time() - t0

    # Restore the real stdout on the ranks we silenced (rank 0 kept its own).
if rank != 0:
sys.stdout = sys.__stdout__

# --- Load final + (rank 0 only) verify ---
recovered, _, _, _, _ = load_dist_checkpoint_full(dst_dir)
dist.barrier()

if rank == 0:
_log(f"conversion time: {dt:.2f}s", rank, args.label)
_verify_round_trip(state_dict, recovered, args.label)
shutil.rmtree(scratch, ignore_errors=True)
if os.path.exists(init_file):
os.remove(init_file)
dist.barrier()
dist.destroy_process_group()


if __name__ == '__main__':
main()