diff --git a/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py b/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py
index 985f2c6c4..89c89e97a 100644
--- a/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py
+++ b/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py
@@ -20,7 +20,7 @@
     get_tpu_quant_method)
 from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
 from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors_moe import \
-    VllmCompressedTensorsW8A8Fp8MoEMethod
+    VllmCompressedTensorsMoEMethod
 from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import \
     VllmCompressedTensorsW8A8Fp8
 from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_int8 import \
@@ -113,8 +113,9 @@ def get_quant_method(
             layer.scheme = scheme
             return CompressedTensorsLinearMethod(self)
         if isinstance(layer, FusedMoE):
-            return VllmCompressedTensorsW8A8Fp8MoEMethod(
-                self, layer.quant_config, self.mesh)
+            layer.moe_config = self.get_moe_config(layer)
+            return VllmCompressedTensorsMoEMethod.get_moe_method(
+                self, layer, layer_name=prefix)
         if isinstance(layer, Attention):
             return CompressedTensorsKVCacheMethod(self)
         return None
diff --git a/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py b/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py
index bc9721d8c..bfe4d4ef3 100644
--- a/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -4,6 +4,7 @@
 import jax.numpy as jnp
 import torch
 import torch.nn.functional as F
+from compressed_tensors.quantization import QuantizationArgs
 from jax.experimental.layout import Format, Layout
 from jax.sharding import Mesh, NamedSharding
 from jax.sharding import PartitionSpec as P
@@ -12,52 +13,89 @@
 from torchax.ops.mappings import t2j
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEConfig
-from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import \
-    CompressedTensorsConfig
-from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import \
-    CompressedTensorsW8A8Fp8MoEMethod
-from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (  # noqa
-    WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP)
+from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (  # noqa: E501
+    CompressedTensorsMoEMethod, CompressedTensorsW8A8Fp8MoEMethod)
 
 from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
+from tpu_inference.layers.vllm.quantization.unquantized import \
+    VllmUnquantizedFusedMoEMethod
 
 logger = init_logger(__name__)
 
 
+class VllmCompressedTensorsMoEMethod(CompressedTensorsMoEMethod):
+
+    @staticmethod
+    def get_moe_method(
+        quant_config: "VllmCompressedTensorsConfig",  # type: ignore # noqa E501
+        layer: torch.nn.Module,
+        layer_name: str,
+    ) -> CompressedTensorsMoEMethod:
+
+        assert isinstance(layer, FusedMoE)
+
+        # FusedMoE was made by combining multiple Linears so need to
+        # make sure quantization config for Linear can target it
+        quant_config._add_fused_moe_to_target_scheme_map()
+        unfused_names = [
+            layer_name + proj_name
+            for proj_name in [".0.gate_proj", ".0.up_proj", ".0.down_proj"]
+        ]
+        # TODO: refactor this to use expert_mapping and check all layer numbers
+        all_scheme_dicts = [
+            quant_config.get_scheme_dict(layer, name) for name in unfused_names
+        ]
+        scheme_dict = all_scheme_dicts.pop()
+
+        # multiple schemes found
+        if not all([cur_dict == scheme_dict for cur_dict in all_scheme_dicts]):
+            raise ValueError("All MoE projections need to have same "
+                             "quantization scheme but found multiple")
+
+        if scheme_dict is None:
+            return VllmUnquantizedFusedMoEMethod(layer.moe_config,
+                                                 quant_config.mesh)
+
+        weight_quant = scheme_dict.get("weights")
+        input_quant = scheme_dict.get("input_activations")
+
+        if quant_config._is_fp8_w8a8(weight_quant, input_quant):
+            return VllmCompressedTensorsW8A8Fp8MoEMethod(
+                weight_quant, input_quant, layer.moe_config, quant_config.mesh)
+        else:
+            raise RuntimeError(
+                f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")
+
+
 class VllmCompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsW8A8Fp8MoEMethod,
                                             JaxCommonConfig):
 
-    def __init__(self, quant_config: "CompressedTensorsConfig",
-                 moe: FusedMoEConfig, mesh: Mesh):
-        super().__init__(quant_config, moe)
+    def __init__(self, weight_quant: QuantizationArgs,
+                 input_quant: QuantizationArgs, moe: FusedMoEConfig,
+                 mesh: Mesh):
+        super().__init__(weight_quant, input_quant, moe)
         self.mesh = mesh
-        self.quant_config = quant_config
-
-        # disable GPU paths
-        self.use_marlin = False
-        self.rocm_aiter_moe_enabled = False  # is_rocm_aiter_moe_enabled()
-        self.is_fp8_w8a8_sm100 = False
-        self.use_cutlass = False
-        self.disable_expert_map = False
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         assert isinstance(layer, FusedMoE)
 
-        intermediate_size = layer.w13_weight.shape[1] // 2
-        w1_weight = layer.w13_weight[:, :intermediate_size]
-        w3_weight = layer.w13_weight[:, intermediate_size:]
-        w1_weight_scale = layer.w13_weight_scale[:, :intermediate_size]
-        w3_weight_scale = layer.w13_weight_scale[:, intermediate_size:]
-
+        w13_weight = t2j(layer.w13_weight, use_dlpack=False)
+        w13_weight_scale = t2j(layer.w13_weight_scale, use_dlpack=False)
         w2_weight = t2j(layer.w2_weight, use_dlpack=False)
-        w2_weight_scale = t2j(layer.w2_weight_scale.to(torch.bfloat16),
-                              use_dlpack=False)
-        w1_weight = t2j(w1_weight, use_dlpack=False)
-        w1_weight_scale = t2j(w1_weight_scale.to(torch.bfloat16),
-                              use_dlpack=False)
-        w3_weight = t2j(w3_weight, use_dlpack=False)
-        w3_weight_scale = t2j(w3_weight_scale.to(torch.bfloat16),
-                              use_dlpack=False)
+        w2_weight_scale = t2j(layer.w2_weight_scale, use_dlpack=False)
+
+        w13_weight_scale = w13_weight_scale.astype(jnp.bfloat16)
+        w2_weight_scale = w2_weight_scale.astype(jnp.bfloat16)
+
+        num_experts, hidden_size, intermediate_size = w2_weight.shape
+        assert w2_weight_scale.shape == (num_experts, hidden_size, 1)
+        assert w13_weight.shape == (num_experts, 2 * intermediate_size,
+                                    hidden_size)
+        assert w13_weight_scale.shape == (num_experts, 2 * intermediate_size,
+                                          1)
+
+        w1_weight, w3_weight = jnp.split(w13_weight, 2, 1)
+        w1_weight_scale, w3_weight_scale = jnp.split(w13_weight_scale, 2, 1)
 
         if layer.use_ep:
             format = Format(Layout((0, 1, 2)),
@@ -69,16 +107,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             w2_weight = jax.device_put(w2_weight, format)
             w2_weight_scale = jax.device_put(w2_weight_scale, format)
         else:
-            assert intermediate_size == w2_weight.shape[-1]
             n_shards = self.mesh.shape["model"]
             assert intermediate_size % n_shards == 0
 
-            # TODO: enable this if using fused weights
-            # output_sizes = [intermediate_size, intermediate_size]
-            # w13_weight = reorder_concatenated_tensor_for_sharding(
-            #     w13_weight, output_sizes, n_shards, dim=1
-            # )
-
             w13_format = Format(
                 Layout((0, 1, 2)),
                 NamedSharding(self.mesh, P(None, "model", None)))
@@ -128,12 +159,7 @@ def apply(
             raise NotImplementedError(
                 "Only softmax is supported for scoring_func")
 
-        # import sys
-        # sys.stdin = open(0)
-        # breakpoint()
-        # TODO: Use MoE kernel when it supports fp8
-
         seqlen = x.shape[0]
         expert_weights = F.softmax(router_logits, dim=-1)
diff --git a/tpu_inference/layers/vllm/quantization/mxfp4.py b/tpu_inference/layers/vllm/quantization/mxfp4.py
index 3ff9726d7..5fdf12c84 100644
--- a/tpu_inference/layers/vllm/quantization/mxfp4.py
+++ b/tpu_inference/layers/vllm/quantization/mxfp4.py
@@ -261,8 +261,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module):
         layer.w13_bias = Parameter(torch_view(w13_bias), requires_grad=False)
         layer.w2_bias = Parameter(torch_view(w2_bias), requires_grad=False)
 
-        pass
-
     def apply(
         self,
         layer: torch.nn.Module,