import jax.numpy as jnp
import torch
import torch.nn.functional as F
+from compressed_tensors.quantization import QuantizationArgs
from jax.experimental.layout import Format, Layout
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from torchax.ops.mappings import t2j
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEConfig
-from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import \
-    CompressedTensorsConfig
-from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import \
-    CompressedTensorsW8A8Fp8MoEMethod
-from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (  # noqa
-    WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP)
+from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (  # noqa: E501
+    CompressedTensorsMoEMethod, CompressedTensorsW8A8Fp8MoEMethod)

from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
+from tpu_inference.layers.vllm.quantization.unquantized import \
+    VllmUnquantizedFusedMoEMethod

logger = init_logger(__name__)


+class VllmCompressedTensorsMoEMethod(CompressedTensorsMoEMethod):
+
+    @staticmethod
+    def get_moe_method(
+        quant_config: "VllmCompressedTensorsConfig",  # type: ignore # noqa E501
+        layer: torch.nn.Module,
+        layer_name: str,
+    ) -> CompressedTensorsMoEMethod:
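+        """Select the quant method matching this layer's compressed-tensors
+        scheme: unquantized fallback or fp8 W8A8."""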
+
+        assert isinstance(layer, FusedMoE)
+
+        # FusedMoE was made by combining multiple Linears, so make sure the
+        # quantization config for Linear can target it.
+        quant_config._add_fused_moe_to_target_scheme_map()
+        unfused_names = [
+            layer_name + proj_name
+            for proj_name in [".0.gate_proj", ".0.up_proj", ".0.down_proj"]
+        ]
+        # TODO: refactor this to use expert_mapping and check all layer numbers
+        all_scheme_dicts = [
+            quant_config.get_scheme_dict(layer, name) for name in unfused_names
+        ]
+        scheme_dict = all_scheme_dicts.pop()
+
+        # All projections must share a single quantization scheme.
+        if not all(cur_dict == scheme_dict for cur_dict in all_scheme_dicts):
+            raise ValueError("All MoE projections need to have the same "
+                             "quantization scheme, but found multiple")
+
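+        # No scheme targets this layer: fall back to the unquantized MoE path.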
+        if scheme_dict is None:
+            return VllmUnquantizedFusedMoEMethod(layer.moe_config,
+                                                 quant_config.mesh)
+
+        weight_quant = scheme_dict.get("weights")
+        input_quant = scheme_dict.get("input_activations")
+
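+        # Only the fp8 W8A8 scheme is supported here; anything else is rejected.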
+        if quant_config._is_fp8_w8a8(weight_quant, input_quant):
+            return VllmCompressedTensorsW8A8Fp8MoEMethod(
+                weight_quant, input_quant, layer.moe_config, quant_config.mesh)
+        else:
+            raise RuntimeError(
+                f"Unsupported FusedMoE scheme: {weight_quant}, {input_quant}")
+
+
+
class VllmCompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsW8A8Fp8MoEMethod,
                                            JaxCommonConfig):

-    def __init__(self, quant_config: "CompressedTensorsConfig",
-                 moe: FusedMoEConfig, mesh: Mesh):
-        super().__init__(quant_config, moe)
+    def __init__(self, weight_quant: QuantizationArgs,
+                 input_quant: QuantizationArgs, moe: FusedMoEConfig,
+                 mesh: Mesh):
+        super().__init__(weight_quant, input_quant, moe)
        self.mesh = mesh
-        self.quant_config = quant_config
-
-        # disable GPU paths
-        self.use_marlin = False
-        self.rocm_aiter_moe_enabled = False  # is_rocm_aiter_moe_enabled()
-        self.is_fp8_w8a8_sm100 = False
-        self.use_cutlass = False
-        self.disable_expert_map = False

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        assert isinstance(layer, FusedMoE)

-        intermediate_size = layer.w13_weight.shape[1] // 2
-        w1_weight = layer.w13_weight[:, :intermediate_size]
-        w3_weight = layer.w13_weight[:, intermediate_size:]
-        w1_weight_scale = layer.w13_weight_scale[:, :intermediate_size]
-        w3_weight_scale = layer.w13_weight_scale[:, intermediate_size:]
-
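+        # Convert the torch weights and their scales into JAX arrays.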
+        w13_weight = t2j(layer.w13_weight, use_dlpack=False)
+        w13_weight_scale = t2j(layer.w13_weight_scale, use_dlpack=False)
        w2_weight = t2j(layer.w2_weight, use_dlpack=False)
-        w2_weight_scale = t2j(layer.w2_weight_scale.to(torch.bfloat16),
-                              use_dlpack=False)
-        w1_weight = t2j(w1_weight, use_dlpack=False)
-        w1_weight_scale = t2j(w1_weight_scale.to(torch.bfloat16),
-                              use_dlpack=False)
-        w3_weight = t2j(w3_weight, use_dlpack=False)
-        w3_weight_scale = t2j(w3_weight_scale.to(torch.bfloat16),
-                              use_dlpack=False)
+        w2_weight_scale = t2j(layer.w2_weight_scale, use_dlpack=False)
+
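+        # Cast the weight scales to bfloat16.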
+        w13_weight_scale = w13_weight_scale.astype(jnp.bfloat16)
+        w2_weight_scale = w2_weight_scale.astype(jnp.bfloat16)
+
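+        # Expected layouts: w2 is (experts, hidden, intermediate); w13 stacks
+        # the gate and up projections along its second dimension.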
+        num_experts, hidden_size, intermediate_size = w2_weight.shape
+        assert w2_weight_scale.shape == (num_experts, hidden_size, 1)
+        assert w13_weight.shape == (num_experts, 2 * intermediate_size,
+                                    hidden_size)
+        assert w13_weight_scale.shape == (num_experts, 2 * intermediate_size,
+                                          1)
+
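+        # Split the fused w13 tensor (and its scales) back into w1 and w3.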
+        w1_weight, w3_weight = jnp.split(w13_weight, 2, 1)
+        w1_weight_scale, w3_weight_scale = jnp.split(w13_weight_scale, 2, 1)

        if layer.use_ep:
            format = Format(Layout((0, 1, 2)),
@@ -69,16 +107,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
            w2_weight = jax.device_put(w2_weight, format)
            w2_weight_scale = jax.device_put(w2_weight_scale, format)
        else:
-            assert intermediate_size == w2_weight.shape[-1]
            n_shards = self.mesh.shape["model"]
            assert intermediate_size % n_shards == 0

-            # TODO: enable this if using fused weights
-            # output_sizes = [intermediate_size, intermediate_size]
-            # w13_weight = reorder_concatenated_tensor_for_sharding(
-            #     w13_weight, output_sizes, n_shards, dim=1
-            # )
-
            w13_format = Format(
                Layout((0, 1, 2)),
                NamedSharding(self.mesh, P(None, "model", None)))
@@ -128,12 +159,7 @@ def apply(
            raise NotImplementedError(
                "Only softmax is supported for scoring_func")

-        # import sys
-        # sys.stdin = open(0)
-        # breakpoint()
-
        # TODO: Use MoE kernel when it supports fp8
-
        seqlen = x.shape[0]

        expert_weights = F.softmax(router_logits, dim=-1)