Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 221 additions & 8 deletions python/sglang/srt/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,13 @@
can_auto_enable_marlin_fp8,
cutlass_fp8_supported,
dispatch_w8a8_block_fp8_linear,
dispatch_w8a8_mxfp8_linear,
flashinfer_mxfp8_blockscaled_linear,
flashinfer_mxfp8_quantize,
input_to_float8,
mxfp8_group_quantize,
normalize_e4m3fn_to_e4m3fnuz,
requant_weight_ue8m0_inplace,
triton_mxfp8_blockscaled_linear,
)
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
from sglang.srt.layers.quantization.marlin_utils_fp8 import (
Expand All @@ -73,11 +75,13 @@
get_bool_env_var,
is_cpu,
is_cuda,
is_flashinfer_available,
is_hip,
is_npu,
is_sm90_supported,
is_sm100_supported,
log_info_on_rank0,
next_power_of_2,
print_warning_once,
set_weight_attrs,
use_intel_amx_backend,
Expand All @@ -102,6 +106,11 @@
from aiter.fused_moe import fused_moe
from aiter.ops.shuffle import shuffle_weight

if is_flashinfer_available():
from flashinfer import block_scale_interleave as flashinfer_block_scale_interleave
from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
from flashinfer.fused_moe.core import ActivationType as FlashInferActivationType


ACTIVATION_SCHEMES = ["static", "dynamic"]

Expand Down Expand Up @@ -247,7 +256,16 @@ def __init__(self, quant_config: Union[Fp8Config, W4AFp8Config]):
self.block_quant = (
self.use_mxfp8 or self.quant_config.weight_block_size is not None
)
self.w8a8_block_fp8_linear = dispatch_w8a8_block_fp8_linear()
self.w8a8_block_fp8_linear = None
self.w8a8_mxfp8_linear = None
if self.use_mxfp8:
self.w8a8_mxfp8_linear = dispatch_w8a8_mxfp8_linear()
else:
self.w8a8_block_fp8_linear = dispatch_w8a8_block_fp8_linear()
self.use_flashinfer_mxfp8_linear = (
self.use_mxfp8
and self.w8a8_mxfp8_linear is flashinfer_mxfp8_blockscaled_linear
)
self.is_checkpoint_fp8_serialized = (
self.quant_config.is_checkpoint_fp8_serialized
)
Expand Down Expand Up @@ -420,6 +438,7 @@ def process_weights_after_loading_block_quant(self, layer: Module) -> None:
# Keep parameter object to preserve weight_loader attrs for hot reload.
layer.weight_scale_inv.requires_grad_(False)
layer.weight_scale_inv.format_ue8m0 = True
self._maybe_prepare_flashinfer_mxfp8_weight_scale(layer)
return
else:
# For fp8 linear weights run with deepgemm, the weights and scales need be requantized to ue8m0
Expand Down Expand Up @@ -453,6 +472,58 @@ def process_weights_after_loading_block_quant(self, layer: Module) -> None:
layer.weight.data = weight.data
layer.weight_scale_inv.data = weight_scale.data

def _maybe_prepare_flashinfer_mxfp8_weight_scale(self, layer: Module) -> None:
if not self.use_flashinfer_mxfp8_linear:
return

if not is_sm100_supported() or not is_flashinfer_available():
raise RuntimeError(
"FlashInfer MXFP8 weight-scale swizzle requires SM100 and FlashInfer."
)

scale_u8 = layer.weight_scale_inv.data
if scale_u8.dtype != torch.uint8:
raise TypeError(f"Expected uint8 scale tensor, got {scale_u8.dtype}.")
if scale_u8.ndim != 2:
raise ValueError(
f"Expected 2D scale tensor, got shape {tuple(scale_u8.shape)}."
)
new_swizzled = flashinfer_block_scale_interleave(
scale_u8.contiguous()
).contiguous()

cached_swizzled = getattr(layer, "weight_scale_inv_swizzled", None)
# Keep storage address stable when possible so CUDA graph captures remain valid.
if (
cached_swizzled is not None
and cached_swizzled.shape == new_swizzled.shape
and cached_swizzled.device == new_swizzled.device
and cached_swizzled.dtype == new_swizzled.dtype
):
cached_swizzled.copy_(new_swizzled)
else:
layer.weight_scale_inv_swizzled = new_swizzled
layer._weight_scale_inv_swizzled_src_version = layer.weight_scale_inv._version
layer._weight_scale_inv_swizzled_src_data_ptr = (
layer.weight_scale_inv.data_ptr()
)

def _get_mxfp8_weight_scale(self, layer: Module) -> torch.Tensor:
if self.use_flashinfer_mxfp8_linear:
swizzled = getattr(layer, "weight_scale_inv_swizzled", None)
src_version = getattr(layer, "_weight_scale_inv_swizzled_src_version", -1)
src_data_ptr = getattr(layer, "_weight_scale_inv_swizzled_src_data_ptr", -1)
if (
swizzled is None
or swizzled.device != layer.weight_scale_inv.device
or swizzled.dtype != torch.uint8
or src_version != layer.weight_scale_inv._version
or src_data_ptr != layer.weight_scale_inv.data_ptr()
):
self._maybe_prepare_flashinfer_mxfp8_weight_scale(layer)
return layer.weight_scale_inv_swizzled
return layer.weight_scale_inv

def _quantize_mxfp8_weights(self, layer: Module) -> None:
weight = layer.weight.data
qweight, weight_scale = mxfp8_group_quantize(weight)
Expand All @@ -468,6 +539,7 @@ def _quantize_mxfp8_weights(self, layer: Module) -> None:
"weight_scale_inv", Parameter(weight_scale, requires_grad=False)
)
layer.weight_scale_inv.format_ue8m0 = True
self._maybe_prepare_flashinfer_mxfp8_weight_scale(layer)
layer.input_scale = None

def process_weights_after_loading(self, layer: Module) -> None:
Expand Down Expand Up @@ -600,18 +672,20 @@ def apply(
)

if self.use_mxfp8:
assert self.w8a8_mxfp8_linear is not None
weight_scale = self._get_mxfp8_weight_scale(layer)
if isinstance(x, tuple):
return triton_mxfp8_blockscaled_linear(
return self.w8a8_mxfp8_linear(
input=x[0],
weight=layer.weight,
weight_scale=layer.weight_scale_inv,
weight_scale=weight_scale,
input_scale=x[1],
bias=bias,
)
return triton_mxfp8_blockscaled_linear(
return self.w8a8_mxfp8_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale_inv,
weight_scale=weight_scale,
input_scale=None,
bias=bias,
)
Expand Down Expand Up @@ -677,12 +751,38 @@ def __init__(self, quant_config: Fp8Config):
self.block_quant = (
self.use_mxfp8 or self.quant_config.weight_block_size is not None
)
if get_moe_runner_backend().is_cutlass():
moe_runner_backend = get_moe_runner_backend()
if moe_runner_backend.is_cutlass():
assert (
cutlass_fp8_supported()
), "cutlass_fp8 MoE requires CUDA 12.0+ with SM90 or CUDA 12.4+ with SM89"
assert self.block_quant, "cutlass_fp8 MoE requires block quantization"
assert is_sm100_supported() or is_sm90_supported()
if moe_runner_backend.is_flashinfer_cutlass():
assert self.use_mxfp8, (
"flashinfer_cutlass FP8 MoE currently requires MXFP8 "
"(quantization=mxfp8)."
)
assert is_sm100_supported(), "flashinfer_cutlass MXFP8 MoE requires SM100."
assert (
is_flashinfer_available()
), "flashinfer_cutlass backend requires FlashInfer."
assert (
flashinfer_mxfp8_quantize is not None
), "flashinfer_cutlass MXFP8 MoE requires flashinfer.mxfp8_quantize."

@property
def load_up_proj_weight_first(self) -> bool:
    """Whether gated W13 weights must be loaded in [Up, Gate] order.

    The FlashInfer CUTLASS MXFP8 MoE kernel assumes [Up, Gate] ordering
    for the fused gate/up projection; every other path keeps the default
    loading order.
    """
    backend = get_moe_runner_backend()
    if not (
        self.use_mxfp8
        and backend.is_flashinfer_cutlass()
        and is_flashinfer_available()
    ):
        return False
    runner_config = getattr(self, "moe_runner_config", None)
    # A missing runner config is treated as gated (the common case).
    return runner_config is None or runner_config.is_gated

@staticmethod
def is_deepgemm_moe_runner_backend_enabled() -> bool:
Expand Down Expand Up @@ -1059,7 +1159,11 @@ def _quantize_and_swizzle_with_triton_kernel(weight: torch.Tensor):
return qweight, scale

if quantize:
if get_moe_runner_backend().is_cutlass():
moe_runner_backend = get_moe_runner_backend()
if (
moe_runner_backend.is_cutlass()
or moe_runner_backend.is_flashinfer_cutlass()
):
w13_q, w13_s = _quantize_and_swizzle_with_cutlass_es_kernel(
layer.w13_weight.data
)
Expand Down Expand Up @@ -1335,6 +1439,94 @@ def create_moe_runner(
# TODO(cwan): refactor other backends
pass

def _apply_flashinfer_cutlass_mxfp8(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    moe_runner_config: MoeRunnerConfig,
    output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run an MXFP8 MoE forward pass through FlashInfer's CUTLASS kernel.

    Quantizes the activations to MXFP8 on the fly, reinterprets the UE8M0
    weight scales into the kernel's packed-int32 layout, and invokes
    ``flashinfer_cutlass_fused_moe``. Only silu-gated or gelu-nongated
    activations are supported, and router weights must not be applied on
    the input.

    Args:
        layer: MoE layer holding MXFP8 weights/scales and EP/TP topology.
        x: Activations; assumed 2D (tokens, hidden) — TODO confirm at callers.
        topk_weights: Per-token expert combine weights.
        topk_ids: Per-token selected expert indices.
        moe_runner_config: Supplies activation name and gating flag.
        output: Optional preallocated output buffer; allocated when ``None``.

    Returns:
        Combined expert outputs with the same shape/dtype as ``x``.

    Raises:
        RuntimeError: if FlashInfer or its MXFP8 quantizer is unavailable.
        TypeError: if the weight scales are not uint8 UE8M0.
        ValueError: on unsupported scale shapes, activation combinations, or
            ``apply_router_weight_on_input``.
    """
    if not is_flashinfer_available() or flashinfer_mxfp8_quantize is None:
        raise RuntimeError(
            "FlashInfer MXFP8 MoE requires flashinfer.fused_moe.cutlass_fused_moe "
            "and flashinfer.mxfp8_quantize."
        )
    if layer.w13_weight_scale_inv.dtype != torch.uint8:
        raise TypeError("w13_weight_scale_inv must be uint8 UE8M0 for MXFP8.")
    if layer.w2_weight_scale_inv.dtype != torch.uint8:
        raise TypeError("w2_weight_scale_inv must be uint8 UE8M0 for MXFP8.")
    if layer.w13_weight_scale_inv.shape[-1] % 4 != 0:
        raise ValueError(
            "MXFP8 FlashInfer MoE requires w13 scale last dim divisible by 4."
        )
    if layer.w2_weight_scale_inv.shape[-1] % 4 != 0:
        raise ValueError(
            "MXFP8 FlashInfer MoE requires w2 scale last dim divisible by 4."
        )

    activation = moe_runner_config.activation
    is_gated = moe_runner_config.is_gated
    if activation == "silu" and is_gated:
        activation_type = FlashInferActivationType.Swiglu
    elif activation == "gelu" and not is_gated:
        activation_type = FlashInferActivationType.Relu2
    else:
        raise ValueError(
            "FlashInfer MXFP8 MoE supports only silu-gated or gelu-nongated "
            f"activation, but got activation={activation!r}, is_gated={is_gated}."
        )
    if moe_runner_config.apply_router_weight_on_input:
        raise ValueError(
            "apply_router_weight_on_input is not supported for FlashInfer MXFP8 MoE."
        )

    x_2d = x.contiguous()
    x_mxfp8, x_sf = flashinfer_mxfp8_quantize(
        x_2d, is_sf_swizzled_layout=True, alignment=32
    )
    # The quantizer may pad the hidden dim to the alignment; trim back to
    # the original width before handing it to the kernel.
    x_mxfp8 = x_mxfp8[:, : x_2d.shape[1]].contiguous()
    x_sf = x_sf.contiguous()

    # The FlashInfer kernel expects scales packed as int32, where each int32
    # holds four consecutive uint8 UE8M0 scales. The `% 4 == 0` shape checks
    # above guarantee this reinterpreting view is safe.
    w13_scale_block = layer.w13_weight_scale_inv.contiguous().view(torch.int32)
    w2_scale_block = layer.w2_weight_scale_inv.contiguous().view(torch.int32)

    # Per-expert global scales are fixed at 1.0; cache the buffer on the
    # layer and rebuild only if expert count, device, or dtype changed.
    num_experts = layer.w13_weight.shape[0]
    global_scale = getattr(layer, "_flashinfer_mxfp8_global_scale", None)
    if (
        global_scale is None
        or global_scale.shape != (num_experts,)
        or global_scale.device != x.device
        or global_scale.dtype != torch.float32
    ):
        global_scale = torch.ones(num_experts, dtype=torch.float32, device=x.device)
        layer._flashinfer_mxfp8_global_scale = global_scale

    if output is None:
        with use_symmetric_memory(
            get_tp_group(), disabled=not is_allocation_symmetric()
        ):
            output = torch.empty_like(x)

    # cutlass_fused_moe returns a tuple; the combined hidden states are [0].
    return flashinfer_cutlass_fused_moe(
        output=output,
        input=x_mxfp8,
        token_selected_experts=topk_ids.to(torch.int),
        token_final_scales=topk_weights,
        fc1_expert_weights=layer.w13_weight,
        fc2_expert_weights=layer.w2_weight,
        output_dtype=x.dtype,
        quant_scales=[w13_scale_block, global_scale, w2_scale_block, global_scale],
        input_sf=x_sf,
        use_mxfp8_act_scaling=True,
        ep_size=layer.moe_ep_size,
        ep_rank=layer.moe_ep_rank,
        tp_size=layer.moe_tp_size,
        tp_rank=layer.moe_tp_rank,
        tune_max_num_tokens=next_power_of_2(x.shape[0]),
        activation_type=activation_type,
    )[0]

def apply(
self,
layer: torch.nn.Module,
Expand All @@ -1345,6 +1537,7 @@ def apply(

x = dispatch_output.hidden_states
moe_runner_config = self.moe_runner_config
moe_runner_backend = get_moe_runner_backend()

if use_intel_amx_backend(layer):
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
Expand Down Expand Up @@ -1420,6 +1613,26 @@ def apply(
)
return StandardCombineInput(hidden_states=output)

if self.use_mxfp8 and get_moe_runner_backend().is_flashinfer_cutlass():
from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker

topk_weights, topk_ids, _ = dispatch_output.topk_output
preallocated_output = (
dispatch_output.moe_output
if DispatchOutputChecker.format_is_flashinfer(dispatch_output)
else None
)
output = self._apply_flashinfer_cutlass_mxfp8(
layer=layer,
x=x,
topk_weights=topk_weights,
topk_ids=topk_ids,
moe_runner_config=moe_runner_config,
output=preallocated_output,
)
return StandardCombineInput(hidden_states=output)


if self.runner.runner_backend.is_deep_gemm():

w13_weight = layer.w13_weight
Expand Down
Loading