Commit ae5dd2e

[Bugfix][Refactor] Fix compressed tensor moe init (#1283)
Signed-off-by: Kyuyeun Kim <[email protected]>
1 parent cceae1c commit ae5dd2e

3 files changed: +73 −48 lines


tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py

Lines changed: 4 additions & 3 deletions
@@ -20,7 +20,7 @@
     get_tpu_quant_method)
 from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
 from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors_moe import \
-    VllmCompressedTensorsW8A8Fp8MoEMethod
+    VllmCompressedTensorsMoEMethod
 from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import \
     VllmCompressedTensorsW8A8Fp8
 from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_int8 import \
@@ -113,8 +113,9 @@ def get_quant_method(
             layer.scheme = scheme
             return CompressedTensorsLinearMethod(self)
         if isinstance(layer, FusedMoE):
-            return VllmCompressedTensorsW8A8Fp8MoEMethod(
-                self, layer.quant_config, self.mesh)
+            layer.moe_config = self.get_moe_config(layer)
+            return VllmCompressedTensorsMoEMethod.get_moe_method(
+                self, layer, layer_name=prefix)
         if isinstance(layer, Attention):
             return CompressedTensorsKVCacheMethod(self)
         return None
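
The FusedMoE branch now attaches the layer's MoE config and delegates method selection to VllmCompressedTensorsMoEMethod.get_moe_method (added in the next file). A minimal, self-contained sketch of that selection logic, using plain dicts and string placeholders instead of the real QuantizationArgs/config objects (the function and string names below are illustrative only, not part of the codebase):

# Simplified stand-in for VllmCompressedTensorsMoEMethod.get_moe_method:
# every unfused projection (gate/up/down) must resolve to the same scheme,
# None means unquantized, and only fp8 W8A8 is supported otherwise.
def pick_moe_method(scheme_dicts):
    first = scheme_dicts[0]
    if any(d != first for d in scheme_dicts[1:]):
        raise ValueError("All MoE projections need to have same "
                         "quantization scheme but found multiple")
    if first is None:
        return "unquantized"
    weight_quant = first.get("weights")
    input_quant = first.get("input_activations")
    if weight_quant == "fp8" and input_quant == "fp8":
        return "w8a8_fp8"
    raise RuntimeError(
        f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")

# All three projections of an expert carry the same fp8 scheme -> fp8 method.
schemes = [{"weights": "fp8", "input_activations": "fp8"}] * 3
assert pick_moe_method(schemes) == "w8a8_fp8"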

tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 69 additions & 43 deletions
@@ -4,6 +4,7 @@
 import jax.numpy as jnp
 import torch
 import torch.nn.functional as F
+from compressed_tensors.quantization import QuantizationArgs
 from jax.experimental.layout import Format, Layout
 from jax.sharding import Mesh, NamedSharding
 from jax.sharding import PartitionSpec as P
@@ -12,52 +13,89 @@
 from torchax.ops.mappings import t2j
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEConfig
-from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import \
-    CompressedTensorsConfig
-from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import \
-    CompressedTensorsW8A8Fp8MoEMethod
-from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (  # noqa
-    WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP)
+from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (  # noqa: E501
+    CompressedTensorsMoEMethod, CompressedTensorsW8A8Fp8MoEMethod)

 from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
+from tpu_inference.layers.vllm.quantization.unquantized import \
+    VllmUnquantizedFusedMoEMethod

 logger = init_logger(__name__)


+class VllmCompressedTensorsMoEMethod(CompressedTensorsMoEMethod):
+
+    @staticmethod
+    def get_moe_method(
+        quant_config: "VllmCompressedTensorsConfig",  # type: ignore # noqa E501
+        layer: torch.nn.Module,
+        layer_name: str,
+    ) -> CompressedTensorsMoEMethod:
+
+        assert isinstance(layer, FusedMoE)
+
+        # FusedMoE was made by combining multiple Linears so need to
+        # make sure quantization config for Linear can target it
+        quant_config._add_fused_moe_to_target_scheme_map()
+        unfused_names = [
+            layer_name + proj_name
+            for proj_name in [".0.gate_proj", ".0.up_proj", ".0.down_proj"]
+        ]
+        # TODO: refactor this to use expert_mapping and check all layer numbers
+        all_scheme_dicts = [
+            quant_config.get_scheme_dict(layer, name) for name in unfused_names
+        ]
+        scheme_dict = all_scheme_dicts.pop()
+
+        # multiple schemes found
+        if not all([cur_dict == scheme_dict for cur_dict in all_scheme_dicts]):
+            raise ValueError("All MoE projections need to have same "
+                             "quantization scheme but found multiple")
+
+        if scheme_dict is None:
+            return VllmUnquantizedFusedMoEMethod(layer.moe_config,
+                                                 quant_config.mesh)
+
+        weight_quant = scheme_dict.get("weights")
+        input_quant = scheme_dict.get("input_activations")
+
+        if quant_config._is_fp8_w8a8(weight_quant, input_quant):
+            return VllmCompressedTensorsW8A8Fp8MoEMethod(
+                weight_quant, input_quant, layer.moe_config, quant_config.mesh)
+        else:
+            raise RuntimeError(
+                f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")
+
+
 class VllmCompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsW8A8Fp8MoEMethod,
                                             JaxCommonConfig):

-    def __init__(self, quant_config: "CompressedTensorsConfig",
-                 moe: FusedMoEConfig, mesh: Mesh):
-        super().__init__(quant_config, moe)
+    def __init__(self, weight_quant: QuantizationArgs,
+                 input_quant: QuantizationArgs, moe: FusedMoEConfig,
+                 mesh: Mesh):
+        super().__init__(weight_quant, input_quant, moe)
         self.mesh = mesh
-        self.quant_config = quant_config
-
-        # disable GPU paths
-        self.use_marlin = False
-        self.rocm_aiter_moe_enabled = False  # is_rocm_aiter_moe_enabled()
-        self.is_fp8_w8a8_sm100 = False
-        self.use_cutlass = False
-        self.disable_expert_map = False

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         assert isinstance(layer, FusedMoE)

-        intermediate_size = layer.w13_weight.shape[1] // 2
-        w1_weight = layer.w13_weight[:, :intermediate_size]
-        w3_weight = layer.w13_weight[:, intermediate_size:]
-        w1_weight_scale = layer.w13_weight_scale[:, :intermediate_size]
-        w3_weight_scale = layer.w13_weight_scale[:, intermediate_size:]
-
+        w13_weight = t2j(layer.w13_weight, use_dlpack=False)
+        w13_weight_scale = t2j(layer.w13_weight_scale, use_dlpack=False)
         w2_weight = t2j(layer.w2_weight, use_dlpack=False)
-        w2_weight_scale = t2j(layer.w2_weight_scale.to(torch.bfloat16),
-                              use_dlpack=False)
-        w1_weight = t2j(w1_weight, use_dlpack=False)
-        w1_weight_scale = t2j(w1_weight_scale.to(torch.bfloat16),
-                              use_dlpack=False)
-        w3_weight = t2j(w3_weight, use_dlpack=False)
-        w3_weight_scale = t2j(w3_weight_scale.to(torch.bfloat16),
-                              use_dlpack=False)
+        w2_weight_scale = t2j(layer.w2_weight_scale, use_dlpack=False)
+
+        w13_weight_scale = w13_weight_scale.astype(jnp.bfloat16)
+        w2_weight_scale = w2_weight_scale.astype(jnp.bfloat16)
+
+        num_experts, hidden_size, intermediate_size = w2_weight.shape
+        assert w2_weight_scale.shape == (num_experts, hidden_size, 1)
+        assert w13_weight.shape == (num_experts, 2 * intermediate_size,
+                                    hidden_size)
+        assert w13_weight_scale.shape == (num_experts, 2 * intermediate_size,
+                                          1)
+
+        w1_weight, w3_weight = jnp.split(w13_weight, 2, 1)
+        w1_weight_scale, w3_weight_scale = jnp.split(w13_weight_scale, 2, 1)

         if layer.use_ep:
             format = Format(Layout((0, 1, 2)),
@@ -69,16 +107,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             w2_weight = jax.device_put(w2_weight, format)
             w2_weight_scale = jax.device_put(w2_weight_scale, format)
         else:
-            assert intermediate_size == w2_weight.shape[-1]
             n_shards = self.mesh.shape["model"]
             assert intermediate_size % n_shards == 0

-            # TODO: enable this if using fused weights
-            # output_sizes = [intermediate_size, intermediate_size]
-            # w13_weight = reorder_concatenated_tensor_for_sharding(
-            #     w13_weight, output_sizes, n_shards, dim=1
-            # )
-
             w13_format = Format(
                 Layout((0, 1, 2)),
                 NamedSharding(self.mesh, P(None, "model", None)))
@@ -128,12 +159,7 @@ def apply(
             raise NotImplementedError(
                 "Only softmax is supported for scoring_func")

-        # import sys
-        # sys.stdin = open(0)
-        # breakpoint()
-
         # TODO: Use MoE kernel when it supports fp8
-
         seqlen = x.shape[0]

         expert_weights = F.softmax(router_logits, dim=-1)
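
The refactored process_weights_after_loading converts the fused tensors to JAX arrays first, casts the scales to bfloat16, and only then splits w13 into w1 and w3 with jnp.split. A small runnable sketch of just that reshaping step, with dummy bfloat16 arrays standing in for the real fp8 weights and illustrative shapes:

import jax.numpy as jnp

num_experts, intermediate_size, hidden_size = 4, 8, 16

# Dummy stand-ins: the real weights are fp8, the scales are per output channel.
w13_weight = jnp.ones((num_experts, 2 * intermediate_size, hidden_size),
                      dtype=jnp.bfloat16)
w13_weight_scale = jnp.ones((num_experts, 2 * intermediate_size, 1),
                            dtype=jnp.float32)

# Scales are kept in bfloat16, matching the diff above.
w13_weight_scale = w13_weight_scale.astype(jnp.bfloat16)

# Split the fused projection into its two halves, w1 and w3
# (gate and up projections in vLLM's fused MoE layout), along dim 1.
w1_weight, w3_weight = jnp.split(w13_weight, 2, axis=1)
w1_weight_scale, w3_weight_scale = jnp.split(w13_weight_scale, 2, axis=1)

assert w1_weight.shape == (num_experts, intermediate_size, hidden_size)
assert w3_weight.shape == (num_experts, intermediate_size, hidden_size)
assert w1_weight_scale.shape == (num_experts, intermediate_size, 1)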

tpu_inference/layers/vllm/quantization/mxfp4.py

Lines changed: 0 additions & 2 deletions
@@ -261,8 +261,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module):
         layer.w13_bias = Parameter(torch_view(w13_bias), requires_grad=False)
         layer.w2_bias = Parameter(torch_view(w2_bias), requires_grad=False)

-        pass
-
     def apply(
         self,
         layer: torch.nn.Module,
