+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD 3-Clause license found in the
+ # LICENSE file in the root directory of this source tree.
+
import torch
from torch import nn
from torch.nn import functional as F

# this feature requires CUDA and SM89+
assert torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)

- from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
+ from torchao.prototype.moe_training.conversion_utils import (
+     MoEScalingType,
+     MoETrainingConfig,
+ )
from torchao.quantization.quant_api import quantize_

# this example uses torchtitan llama4 MoE, see https://github.com/pytorch/torchtitan
try:
-     from torchtitan.experiments.llama4.model.args import TransformerModelArgs
-     from torchtitan.experiments.llama4.model.moe import MoE
+     from torchtitan.models.moe import MoE, MoEArgs
+     from torchtitan.models.moe.utils import set_token_group_alignment_size_m
except ImportError as e:
    raise ImportError(
        "torchtitan not installed, see installation instructions at https://github.com/pytorch/torchtitan"
    ) from e


+ from argparse import ArgumentParser
+
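+ # choose the low precision scaling recipe from the command line: fp8 rowwise (default) or mxfp8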
+ parser = ArgumentParser()
+ parser.add_argument(
+     "--scaling_type",
+     type=str,
+     default="fp8_rowwise",
+     choices=["fp8_rowwise", "mxfp8"],
+ )
+ args = parser.parse_args()
+
+
# initialize model
device = torch.device("cuda")
- model_args = TransformerModelArgs(
-     moe_enabled=True,
-     num_experts=8,
-     dim=256,
- )
- model = MoE(model_args).to(torch.bfloat16).to(device)
+ torch.manual_seed(42)
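+ # a single MoE layer: 8 experts, top-2 routing, grouped GEMM for the expert computation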
+ model_args = MoEArgs(num_experts=8, top_k=2, use_grouped_mm=True)
+ dim = 1024
+ hidden_dim = dim * 4
+ model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).to(device)
init_std = 0.02
model.init_weights(init_std, device)

@@ -40,14 +60,21 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
    return False


- # quantize the model
- config = MoETrainingConfig()
+ if args.scaling_type == "fp8_rowwise":
+     config = MoETrainingConfig()
+     alignment_size = 16
+
+ elif args.scaling_type == "mxfp8":
+     config = MoETrainingConfig(scaling_type=MoEScalingType.MXFP8)
+     alignment_size = 32
+
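+ # convert the MoE submodules selected by the filter fn to use low precision training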
quantize_(model, config=config, filter_fn=moe_module_filter_fn)
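+ # per-expert token groups are padded to this alignment for the low precision grouped GEMMs (16 for fp8 rowwise, 32 for mxfp8)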
+ set_token_group_alignment_size_m(alignment_size)

# training loop
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
for step in range(10):
-     batch, seq, dim = 8, 2048, 256
+     batch, seq = 8, 2048
    x = torch.randn(
        batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
    )