diff --git a/tests/distributed/test_context_parallel.py b/tests/distributed/test_context_parallel.py
new file mode 100644
index 000000000..b65c8fb47
--- /dev/null
+++ b/tests/distributed/test_context_parallel.py
@@ -0,0 +1,76 @@
+# Copyright (c) 2025, EleutherAI
+# Licensed under the Apache 2.0 license.
+
+"""
+Unit tests for context parallelism.
+
+We patch `megatron.mpu.get_context_parallel_*` so that we don't have to set up
+a distributed environment on the GitHub CI runner. Each test pretends a 2-way
+context-parallel world is running, then verifies that:
+
+1. `zigzag_data` returns the correct slice for each (fake) rank.
+2. `RotaryEmbedding` builds `cos_cached` / `sin_cached` using the same
+   zig-zag time indices.
+"""
+
+import torch
+import pytest
+import megatron.mpu as mpu
+from megatron.mpu.data import zigzag_data
+from megatron.model.positional_embeddings import RotaryEmbedding
+
+
+@pytest.mark.parametrize("rank", [0, 1])
+def test_zigzag_and_rotary(monkeypatch, rank):
+    """
+    Simulate a 2-GPU context-parallel group and check that both the low-level
+    zig-zag utility and the higher-level rotary-embedding cache behave as
+    expected on each rank.
+    """
+    # Patch the MPU helpers to fake a 2-way group
+    monkeypatch.setattr(mpu, "get_context_parallel_world_size", lambda: 2)
+    monkeypatch.setattr(mpu, "get_context_parallel_rank", lambda: rank)
+
+    # zigzag_data
+    seq_dim = 1
+    x = torch.arange(16).view(2, 8)  # shape: (batch=2, seq=8)
+
+    # Compute the expected zig-zag slice manually: the sequence is cut into
+    # 2 * world_size chunks, and rank r keeps chunks r and -(r + 1).
+    chunks = torch.chunk(x, 2 * 2, dim=seq_dim)  # 4 chunks of length 2
+    expected = (
+        torch.cat((chunks[0], chunks[-1]), dim=seq_dim)
+        if rank == 0
+        else torch.cat((chunks[1], chunks[-2]), dim=seq_dim)
+    )
+
+    out = zigzag_data(x, seq_dim=seq_dim)
+    assert torch.equal(out, expected), "zig-zag sharding mismatch"
+
+    # RotaryEmbedding cache
+    dim = 8
+    rope = RotaryEmbedding(
+        dim=dim,
+        max_seq_len=8,
+        base=10_000,
+        precision=torch.float32,
+        zigzag=True,
+    )
+
+    # Re-create the `t` indices that _prepare_cache() should have used
+    full_t = torch.arange(8)
+    expected_t = (
+        torch.cat((full_t[:2], full_t[-2:]))  # rank 0
+        if rank == 0
+        else torch.cat((full_t[2:4], full_t[-4:-2]))  # rank 1
+    )
+
+    inv_freq = 1.0 / (10_000 ** (torch.arange(0, dim, 2).float() / dim))
+    # Cast `t` to float so einsum doesn't mix integer and float dtypes.
+    freqs = torch.einsum("i,j->ij", expected_t.float(), inv_freq)
+    emb = torch.cat((freqs, freqs), dim=-1)
+    cos_ref, sin_ref = emb.cos(), emb.sin()
+
+    assert rope.cos_cached.shape == cos_ref.shape
+    assert rope.sin_cached.shape == sin_ref.shape
+    assert torch.allclose(rope.cos_cached, cos_ref, atol=1e-6)
+    assert torch.allclose(rope.sin_cached, sin_ref, atol=1e-6)
diff --git a/tests/model/test_dmoe.py b/tests/model/test_dmoe.py
new file mode 100644
index 000000000..daeda4d58
--- /dev/null
+++ b/tests/model/test_dmoe.py
@@ -0,0 +1,107 @@
+# Copyright (c) 2025, EleutherAI
+# Licensed under the Apache 2.0 license.
+
+"""
+▸ Part 1 – expert-token helper utilities:
+  * `get_expert_tokens_for_rank`
+  * `get_expert_token_counts_for_rank`
+
+▸ Part 2 – lightweight router (`TopKTokenChoiceRouter`)
+  * shape & range of returned weights / indices
+  * determinism under identical input
+"""
+
+import types
+import torch
+import pytest
+import importlib
+
+
+@pytest.fixture(autouse=True)
+def patch_mpu(monkeypatch):
+    """
+    Pretend we have a 2-way tensor-parallel group; most MoE helpers only query
+    `get_model_parallel_world_size` and `get_model_parallel_rank`.
+ """ + import megatron.mpu as mpu + + monkeypatch.setattr(mpu, "get_model_parallel_world_size", lambda: 2, raising=False) + # `rank` will be injected per‑test case + yield + + +def _set_rank(monkeypatch, rank: int): + import megatron.mpu as mpu + monkeypatch.setattr(mpu, "get_model_parallel_rank", lambda: rank, raising=False) + + +# Part 1 – expert‑token split / gather helpers +@pytest.mark.parametrize("rank", [0, 1]) +def test_expert_token_helpers(monkeypatch, rank): + """ + A tiny batch of 6 routed tokens divided among 4 experts with the pattern + [2,1,0,3]. With world_size==2 each rank owns 2 experts ⇒ verify that + the expected slices/counts are returned. + """ + from megatron.mpu.initialize import ( + get_expert_tokens_for_rank, + get_expert_token_counts_for_rank, + ) + + _set_rank(monkeypatch, rank) + + tokens_per_expert = torch.tensor([2, 1, 0, 3]) # len == num_experts + routed = torch.arange(6*3).view(6, 3) # shape (6, 3) + + # ‑‑ expected slice for this fake rank + # cumulative sums → [2,3,3,6]; rank 0 gets experts 0&1, rank 1 gets 2&3 + start = 0 if rank == 0 else 3 + end = 3 if rank == 0 else 6 + want_slice = routed[start:end] + + out_tokens = get_expert_tokens_for_rank(routed, tokens_per_expert) + out_counts = get_expert_token_counts_for_rank(tokens_per_expert) + + assert torch.equal(out_tokens, want_slice) + assert out_counts.tolist() == ([2, 1] if rank == 0 else [0, 3]) + + +# Part 2 – Top‑K token‑choice router +def _dummy_args(num_experts=8, top_k=2, hidden_size=16): + """Return a minimal object that TopKTokenChoiceRouter expects.""" + return types.SimpleNamespace( + hidden_size = hidden_size, + moe_num_experts = num_experts, + moe_top_k = top_k, + moe_jitter_eps = None, + params_dtype = torch.float32, # keep everything on CPU + ) + + +@pytest.mark.parametrize("top_k", [1, 2]) +def test_router_shapes_and_range(top_k): + """Router must return (batch, top_k) tensors; indices < num_experts.""" + mod = importlib.import_module("megatron.model.router") + Router = mod.TopKTokenChoiceRouter + + args = _dummy_args(num_experts=5, top_k=top_k, hidden_size=32) + router = Router(args, init_method=torch.nn.init.uniform_) + + seq, bs = 4, 3 + x = torch.randn(seq, bs, args.hidden_size) + + w, idx = router(x) + + assert w.shape == (seq * bs, top_k) + assert idx.shape == (seq * bs, top_k) + assert torch.all(idx < args.moe_num_experts) + # Probabilities must be positive and ≤1 + assert torch.all(w >= 0) and torch.all(w <= 1) + + # Deterministic behaviour for identical input (no jitter, eval mode). + router.eval() + w2, idx2 = router(x) + assert torch.equal(w, w2) + assert torch.equal(idx, idx2) + diff --git a/tests/model/test_mup.py b/tests/model/test_mup.py new file mode 100644 index 000000000..ab1770384 --- /dev/null +++ b/tests/model/test_mup.py @@ -0,0 +1,80 @@ +# Copyright (c) 2025, EleutherAI +# Licensed under the Apache 2.0 license. + +import types +import torch +import pytest + +from megatron.model.utils import get_params_for_weight_decay_optimization +from megatron.learning_rates import AnnealingLR + + +class TinyNet(torch.nn.Module): + """Just enough structure to exercise the param‑group builder.""" + def __init__(self): + super().__init__() + self.lin = torch.nn.Linear(4, 4) # should get weight‑decay + self.norm = torch.nn.LayerNorm(4) # should be no‑decay + + +@pytest.fixture(scope="module") +def dummy_args(): + # Only the attributes that `get_params_for_weight_decay_optimization` + # actually accesses. 
+    return types.SimpleNamespace(weight_decay=0.1)
+
+
+def _new_scheduler(monkeypatch, optimizer, use_mup, width_mult):
+    """
+    Construct an AnnealingLR and monkeypatch ``get_lr`` so the test is
+    independent of the exact schedule math.
+    """
+    sched = AnnealingLR(
+        optimizer,
+        start_lr=0.0,
+        max_lr=0.02,
+        min_lr=0.0,
+        warmup_iter=0,
+        total_iters=1,
+        decay_style="constant",
+        use_checkpoint_lr_scheduler=False,
+        override_lr_scheduler=False,
+        use_mup=use_mup,
+        mup_width_multiplier=width_mult,
+    )
+    # Force the scheduler to report an LR of 0.02 every step. Use the
+    # monkeypatch fixture so the patch is undone after each test instead of
+    # leaking onto the class for the rest of the session.
+    monkeypatch.setattr(AnnealingLR, "get_lr", lambda self: 0.02)
+    return sched
+
+
+def test_param_groups_have_lr_adjust(dummy_args):
+    """Builder should tag both WD and no-WD groups with ``lr_adjust``."""
+    net = TinyNet()
+    groups = get_params_for_weight_decay_optimization(net, dummy_args)
+
+    assert len(groups) == 2
+    assert all(g.get("lr_adjust", False) for g in groups), (
+        "Every param-group returned by the builder must carry lr_adjust=True "
+        "so muP knows to divide its LR."
+    )
+
+
+@pytest.mark.parametrize("use_mup,expected_factor", [(True, 4.0), (False, 1.0)])
+def test_scheduler_scales_learning_rate(monkeypatch, dummy_args, use_mup, expected_factor):
+    """
+    When ``use_mup`` is True the LR of *lr_adjust* groups must be divided by
+    ``mup_width_multiplier``; otherwise it must stay unchanged.
+    """
+    net = TinyNet()
+    param_groups = get_params_for_weight_decay_optimization(net, dummy_args)
+
+    optimizer = torch.optim.SGD(param_groups, lr=0.0)  # fine for a sanity check
+    width_mult = 4.0
+    sched = _new_scheduler(monkeypatch, optimizer, use_mup=use_mup, width_mult=width_mult)
+
+    sched.step()
+
+    lrs = [g["lr"] for g in optimizer.param_groups]
+    assert pytest.approx(lrs[0], rel=1e-7) == 0.02 / expected_factor
+    assert pytest.approx(lrs[1], rel=1e-7) == 0.02 / expected_factor
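For readers skimming the diff, a minimal sketch of the zig-zag sharding rule that `test_zigzag_and_rotary` encodes may help. This is not the repository's `megatron.mpu.data.zigzag_data` (which queries the context-parallel group itself); it is only an illustration, with rank and world size passed explicitly, of the chunk pattern the assertions assume: the sequence dimension is split into 2 * world_size chunks and rank r keeps chunks r and 2 * world_size - 1 - r.

import torch


def zigzag_reference(x: torch.Tensor, seq_dim: int, rank: int, world_size: int) -> torch.Tensor:
    """Illustrative zig-zag shard of `x` for `rank` in a `world_size`-way group."""
    chunks = torch.chunk(x, 2 * world_size, dim=seq_dim)
    # Pair the r-th chunk from the front with the r-th chunk from the back so
    # every rank sees a balanced mix of early and late positions.
    return torch.cat((chunks[rank], chunks[2 * world_size - 1 - rank]), dim=seq_dim)


# Reproduces the expectation in test_zigzag_and_rotary for a 2-way group:
x = torch.arange(16).view(2, 8)
assert zigzag_reference(x, seq_dim=1, rank=0, world_size=2).tolist() == [[0, 1, 6, 7], [8, 9, 14, 15]]
assert zigzag_reference(x, seq_dim=1, rank=1, world_size=2).tolist() == [[2, 3, 4, 5], [10, 11, 12, 13]]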