Commit bf18374

fix the outdated end-to-end training examples of MoE + torchtitan
make quant recipe flexible
1 parent 03c2d28 commit bf18374
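
With this change the scaling recipe is selected at runtime via the new --scaling_type flag (default "fp8_rowwise", with "mxfp8" as the alternative), so one example script covers both recipes. Assuming a local torchao checkout with torchtitan installed, the updated example can be exercised with, e.g., python torchao/prototype/moe_training/examples/simple_moe_layer.py --scaling_type mxfp8.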

File tree

1 file changed: +39 -12 lines


torchao/prototype/moe_training/examples/simple_moe_layer.py

Lines changed: 39 additions & 12 deletions
@@ -1,31 +1,51 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
 import torch
 from torch import nn
 from torch.nn import functional as F
 
 # this feature requires CUDA and SM89+
 assert torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
 
-from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
+from torchao.prototype.moe_training.conversion_utils import (
+    MoEScalingType,
+    MoETrainingConfig,
+)
 from torchao.quantization.quant_api import quantize_
 
 # this example uses torchtitan llama4 MoE, see
 try:
-    from torchtitan.experiments.llama4.model.args import TransformerModelArgs
-    from torchtitan.experiments.llama4.model.moe import MoE
+    from torchtitan.models.moe import MoE, MoEArgs
+    from torchtitan.models.moe.utils import set_token_group_alignment_size_m
 except ImportError as e:
     raise ImportError(
        "torchtitan not installed, see installation instructions at https://github.com/pytorch/torchtitan"
     ) from e
 
 
+from argparse import ArgumentParser
+
+parser = ArgumentParser()
+parser.add_argument(
+    "--scaling_type",
+    type=str,
+    default="fp8_rowwise",
+    choices=["fp8_rowwise", "mxfp8"],
+)
+args = parser.parse_args()
+
+
 # initialize model
 device = torch.device("cuda")
-model_args = TransformerModelArgs(
-    moe_enabled=True,
-    num_experts=8,
-    dim=256,
-)
-model = MoE(model_args).to(torch.bfloat16).to(device)
+torch.manual_seed(42)
+model_args = MoEArgs(num_experts=8, top_k=2, use_grouped_mm=True)
+dim = 1024
+hidden_dim = dim * 4
+model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).to(device)
 init_std = 0.02
 model.init_weights(init_std, device)
 
@@ -40,14 +60,21 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
     return False
 
 
-# quantize the model
-config = MoETrainingConfig()
+if args.scaling_type == "fp8_rowwise":
+    config = MoETrainingConfig()
+    alignment_size = 16
+
+elif args.scaling_type == "mxfp8":
+    config = MoETrainingConfig(scaling_type=MoEScalingType.MXFP8)
+    alignment_size = 32
+
 quantize_(model, config=config, filter_fn=moe_module_filter_fn)
+set_token_group_alignment_size_m(alignment_size)
 
 # training loop
 optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
 for step in range(10):
-    batch, seq, dim = 8, 2048, 256
+    batch, seq = 8, 2048
     x = torch.randn(
         batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
     )
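
Only the tail of moe_module_filter_fn (its final return False) appears as context in the hunk above; the rest of the filter is unchanged by this commit and therefore not shown. A minimal sketch of what such a filter might look like, assuming it targets the MoE module itself (the isinstance check is an illustrative assumption, not taken from this diff):

# hypothetical sketch: the example's actual filter body is not shown in this diff
def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
    # quantize_ applies the MoE training config only to modules where this returns True
    if isinstance(mod, MoE):  # assumed target module
        return True
    return False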

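The second hunk stops at the construction of the input tensor, so the remainder of the training step is not part of the diff. A sketch of how the loop body might be completed, using a placeholder scalar loss that is purely illustrative and not taken from this commit:

# hypothetical completion of the training loop (placeholder loss, not from this commit)
for step in range(10):
    batch, seq = 8, 2048
    x = torch.randn(
        batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
    )
    out = model(x)
    loss = out.float().sum()  # illustrative scalar loss only
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()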