@@ -20,7 +20,7 @@
get_tpu_quant_method)
from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors_moe import \
VllmCompressedTensorsW8A8Fp8MoEMethod
VllmCompressedTensorsMoEMethod
from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import \
VllmCompressedTensorsW8A8Fp8
from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_int8 import \
@@ -113,8 +113,9 @@ def get_quant_method(
layer.scheme = scheme
return CompressedTensorsLinearMethod(self)
if isinstance(layer, FusedMoE):
return VllmCompressedTensorsW8A8Fp8MoEMethod(
self, layer.quant_config, self.mesh)
layer.moe_config = self.get_moe_config(layer)
return VllmCompressedTensorsMoEMethod.get_moe_method(
self, layer, layer_name=prefix)
if isinstance(layer, Attention):
return CompressedTensorsKVCacheMethod(self)
return None
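
For orientation, the dispatch introduced above reduces to: look up the quantization scheme of each unfused expert projection, require that they all agree, then map the agreed scheme to a MoE method. The real implementation is get_moe_method in the hunks below; the following is only a minimal, self-contained sketch with made-up scheme dictionaries (keys and values are placeholders, not the real compressed-tensors config format):

# Placeholder per-projection scheme lookups for one MoE layer.
schemes = {
    "model.layers.0.experts.0.gate_proj": {"weights": "fp8", "input_activations": "fp8"},
    "model.layers.0.experts.0.up_proj": {"weights": "fp8", "input_activations": "fp8"},
    "model.layers.0.experts.0.down_proj": {"weights": "fp8", "input_activations": "fp8"},
}

all_scheme_dicts = list(schemes.values())
scheme_dict = all_scheme_dicts.pop()
# All projections that were fused into the MoE layer must share one scheme.
if not all(cur == scheme_dict for cur in all_scheme_dicts):
    raise ValueError("All MoE projections need the same quantization scheme")

if scheme_dict is None:
    chosen = "unquantized"   # stands in for VllmUnquantizedFusedMoEMethod
elif scheme_dict.get("weights") == "fp8" and scheme_dict.get("input_activations") == "fp8":
    chosen = "w8a8-fp8"      # stands in for VllmCompressedTensorsW8A8Fp8MoEMethod
else:
    raise RuntimeError(f"Unsupported FusedMoE scheme: {scheme_dict}")
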
@@ -4,6 +4,7 @@
import jax.numpy as jnp
import torch
import torch.nn.functional as F
from compressed_tensors.quantization import QuantizationArgs
from jax.experimental.layout import Format, Layout
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
@@ -12,52 +13,89 @@
from torchax.ops.mappings import t2j
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEConfig
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import \
CompressedTensorsConfig
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import \
CompressedTensorsW8A8Fp8MoEMethod
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa: E501
CompressedTensorsMoEMethod, CompressedTensorsW8A8Fp8MoEMethod)

from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
from tpu_inference.layers.vllm.quantization.unquantized import \
VllmUnquantizedFusedMoEMethod

logger = init_logger(__name__)


class VllmCompressedTensorsMoEMethod(CompressedTensorsMoEMethod):

@staticmethod
def get_moe_method(
quant_config: "VllmCompressedTensorsConfig", # type: ignore # noqa E501
layer: torch.nn.Module,
layer_name: str,
) -> CompressedTensorsMoEMethod:

assert isinstance(layer, FusedMoE)

# FusedMoE was made by combining multiple Linears so need to
# make sure quantization config for Linear can target it
quant_config._add_fused_moe_to_target_scheme_map()
unfused_names = [
layer_name + proj_name
for proj_name in [".0.gate_proj", ".0.up_proj", ".0.down_proj"]
]
# TODO: refactor this to use expert_mapping and check all layer numbers
all_scheme_dicts = [
quant_config.get_scheme_dict(layer, name) for name in unfused_names
]
scheme_dict = all_scheme_dicts.pop()

# multiple schemes found
if not all([cur_dict == scheme_dict for cur_dict in all_scheme_dicts]):
raise ValueError("All MoE projections need to have same "
"quantization scheme but found multiple")

if scheme_dict is None:
return VllmUnquantizedFusedMoEMethod(layer.moe_config,
quant_config.mesh)

weight_quant = scheme_dict.get("weights")
input_quant = scheme_dict.get("input_activations")

if quant_config._is_fp8_w8a8(weight_quant, input_quant):
return VllmCompressedTensorsW8A8Fp8MoEMethod(
weight_quant, input_quant, layer.moe_config, quant_config.mesh)
else:
raise RuntimeError(
f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")


class VllmCompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsW8A8Fp8MoEMethod,
JaxCommonConfig):

def __init__(self, quant_config: "CompressedTensorsConfig",
moe: FusedMoEConfig, mesh: Mesh):
super().__init__(quant_config, moe)
def __init__(self, weight_quant: QuantizationArgs,
input_quant: QuantizationArgs, moe: FusedMoEConfig,
mesh: Mesh):
super().__init__(weight_quant, input_quant, moe)
self.mesh = mesh
self.quant_config = quant_config

# disable GPU paths
self.use_marlin = False
self.rocm_aiter_moe_enabled = False # is_rocm_aiter_moe_enabled()
self.is_fp8_w8a8_sm100 = False
self.use_cutlass = False
self.disable_expert_map = False

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
assert isinstance(layer, FusedMoE)

intermediate_size = layer.w13_weight.shape[1] // 2
w1_weight = layer.w13_weight[:, :intermediate_size]
w3_weight = layer.w13_weight[:, intermediate_size:]
w1_weight_scale = layer.w13_weight_scale[:, :intermediate_size]
w3_weight_scale = layer.w13_weight_scale[:, intermediate_size:]

w13_weight = t2j(layer.w13_weight, use_dlpack=False)
w13_weight_scale = t2j(layer.w13_weight_scale, use_dlpack=False)
w2_weight = t2j(layer.w2_weight, use_dlpack=False)
w2_weight_scale = t2j(layer.w2_weight_scale.to(torch.bfloat16),
use_dlpack=False)
w1_weight = t2j(w1_weight, use_dlpack=False)
w1_weight_scale = t2j(w1_weight_scale.to(torch.bfloat16),
use_dlpack=False)
w3_weight = t2j(w3_weight, use_dlpack=False)
w3_weight_scale = t2j(w3_weight_scale.to(torch.bfloat16),
use_dlpack=False)
w2_weight_scale = t2j(layer.w2_weight_scale, use_dlpack=False)

w13_weight_scale = w13_weight_scale.astype(jnp.bfloat16)
w2_weight_scale = w2_weight_scale.astype(jnp.bfloat16)

num_experts, hidden_size, intermediate_size = w2_weight.shape
assert w2_weight_scale.shape == (num_experts, hidden_size, 1)
assert w13_weight.shape == (num_experts, 2 * intermediate_size,
hidden_size)
assert w13_weight_scale.shape == (num_experts, 2 * intermediate_size,
1)

w1_weight, w3_weight = jnp.split(w13_weight, 2, 1)
w1_weight_scale, w3_weight_scale = jnp.split(w13_weight_scale, 2, 1)

if layer.use_ep:
format = Format(Layout((0, 1, 2)),
@@ -69,16 +107,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
w2_weight = jax.device_put(w2_weight, format)
w2_weight_scale = jax.device_put(w2_weight_scale, format)
else:
assert intermediate_size == w2_weight.shape[-1]
n_shards = self.mesh.shape["model"]
assert intermediate_size % n_shards == 0

# TODO: enable this if using fused weights
# output_sizes = [intermediate_size, intermediate_size]
# w13_weight = reorder_concatenated_tensor_for_sharding(
# w13_weight, output_sizes, n_shards, dim=1
# )

w13_format = Format(
Layout((0, 1, 2)),
NamedSharding(self.mesh, P(None, "model", None)))
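
As a self-contained illustration of the fused-weight layout assumed by process_weights_after_loading above: w13 stacks the gate (w1) and up (w3) projections along dim 1, the per-channel scales carry a trailing singleton dim, and jnp.split recovers the two halves. Sizes and dtypes below are toy values for the sketch, not taken from any real model (the real weights are fp8):

import jax.numpy as jnp

num_experts, intermediate_size, hidden_size = 2, 4, 8

# Fused gate/up weights and their per-output-channel scales.
w13_weight = jnp.ones((num_experts, 2 * intermediate_size, hidden_size), dtype=jnp.bfloat16)
w13_weight_scale = jnp.ones((num_experts, 2 * intermediate_size, 1), dtype=jnp.bfloat16)
w2_weight_scale = jnp.ones((num_experts, hidden_size, 1), dtype=jnp.bfloat16)

# Split the fused tensor back into gate (w1) and up (w3) halves along dim 1.
w1_weight, w3_weight = jnp.split(w13_weight, 2, 1)
w1_weight_scale, w3_weight_scale = jnp.split(w13_weight_scale, 2, 1)

assert w1_weight.shape == (num_experts, intermediate_size, hidden_size)
assert w3_weight.shape == (num_experts, intermediate_size, hidden_size)
assert w1_weight_scale.shape == (num_experts, intermediate_size, 1)
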
@@ -128,12 +159,7 @@ def apply(
raise NotImplementedError(
"Only softmax is supported for scoring_func")

# import sys
# sys.stdin = open(0)
# breakpoint()

# TODO: Use MoE kernel when it supports fp8

seqlen = x.shape[0]

expert_weights = F.softmax(router_logits, dim=-1)
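
The visible tail of apply computes dense per-expert weights with a softmax over the router logits. A minimal, self-contained sketch of that routing step, followed by a typical top-k selection (the top-k part is an assumption for illustration; it is not shown in the excerpt above):

import torch
import torch.nn.functional as F

num_tokens, num_experts, top_k = 4, 8, 2  # toy sizes

router_logits = torch.randn(num_tokens, num_experts)
expert_weights = F.softmax(router_logits, dim=-1)                 # [num_tokens, num_experts]
topk_weights, topk_ids = torch.topk(expert_weights, top_k, dim=-1)
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)  # renormalize over the selected experts
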
tpu_inference/layers/vllm/quantization/mxfp4.py (2 changes: 0 additions & 2 deletions)
@@ -261,8 +261,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module):
layer.w13_bias = Parameter(torch_view(w13_bias), requires_grad=False)
layer.w2_bias = Parameter(torch_view(w2_bias), requires_grad=False)

pass

def apply(
self,
layer: torch.nn.Module,