diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py
index 4cbc3a7e3fc3..f2e9bf4c1cc1 100644
--- a/python/sglang/srt/layers/quantization/fp8.py
+++ b/python/sglang/srt/layers/quantization/fp8.py
@@ -49,11 +49,13 @@
     can_auto_enable_marlin_fp8,
     cutlass_fp8_supported,
     dispatch_w8a8_block_fp8_linear,
+    dispatch_w8a8_mxfp8_linear,
+    flashinfer_mxfp8_blockscaled_linear,
+    flashinfer_mxfp8_quantize,
     input_to_float8,
     mxfp8_group_quantize,
     normalize_e4m3fn_to_e4m3fnuz,
     requant_weight_ue8m0_inplace,
-    triton_mxfp8_blockscaled_linear,
 )
 from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
 from sglang.srt.layers.quantization.marlin_utils_fp8 import (
@@ -73,11 +75,13 @@
     get_bool_env_var,
     is_cpu,
     is_cuda,
+    is_flashinfer_available,
     is_hip,
     is_npu,
     is_sm90_supported,
     is_sm100_supported,
     log_info_on_rank0,
+    next_power_of_2,
     print_warning_once,
     set_weight_attrs,
     use_intel_amx_backend,
@@ -102,6 +106,11 @@
     from aiter.fused_moe import fused_moe
     from aiter.ops.shuffle import shuffle_weight
 
+if is_flashinfer_available():
+    from flashinfer import block_scale_interleave as flashinfer_block_scale_interleave
+    from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
+    from flashinfer.fused_moe.core import ActivationType as FlashInferActivationType
+
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]
 
@@ -247,7 +256,16 @@ def __init__(self, quant_config: Union[Fp8Config, W4AFp8Config]):
         self.block_quant = (
             self.use_mxfp8 or self.quant_config.weight_block_size is not None
         )
-        self.w8a8_block_fp8_linear = dispatch_w8a8_block_fp8_linear()
+        self.w8a8_block_fp8_linear = None
+        self.w8a8_mxfp8_linear = None
+        if self.use_mxfp8:
+            self.w8a8_mxfp8_linear = dispatch_w8a8_mxfp8_linear()
+        else:
+            self.w8a8_block_fp8_linear = dispatch_w8a8_block_fp8_linear()
+        self.use_flashinfer_mxfp8_linear = (
+            self.use_mxfp8
+            and self.w8a8_mxfp8_linear is flashinfer_mxfp8_blockscaled_linear
+        )
         self.is_checkpoint_fp8_serialized = (
             self.quant_config.is_checkpoint_fp8_serialized
         )
@@ -420,6 +438,7 @@ def process_weights_after_loading_block_quant(self, layer: Module) -> None:
                 # Keep parameter object to preserve weight_loader attrs for hot reload.
                 layer.weight_scale_inv.requires_grad_(False)
                 layer.weight_scale_inv.format_ue8m0 = True
+                self._maybe_prepare_flashinfer_mxfp8_weight_scale(layer)
                 return
             else:
                 # For fp8 linear weights run with deepgemm, the weights and scales need be requantized to ue8m0
@@ -453,6 +472,60 @@ def process_weights_after_loading_block_quant(self, layer: Module) -> None:
         layer.weight.data = weight.data
         layer.weight_scale_inv.data = weight_scale.data
 
+    def _maybe_prepare_flashinfer_mxfp8_weight_scale(self, layer: Module) -> None:
+        """Cache the FlashInfer-interleaved weight scale on the layer."""
+        if not self.use_flashinfer_mxfp8_linear:
+            return
+
+        if not is_sm100_supported() or not is_flashinfer_available():
+            raise RuntimeError(
+                "FlashInfer MXFP8 weight-scale swizzle requires SM100 and FlashInfer."
+            )
+
+        scale_u8 = layer.weight_scale_inv.data
+        if scale_u8.dtype != torch.uint8:
+            raise TypeError(f"Expected uint8 scale tensor, got {scale_u8.dtype}.")
+        if scale_u8.ndim != 2:
+            raise ValueError(
+                f"Expected 2D scale tensor, got shape {tuple(scale_u8.shape)}."
+            )
+        new_swizzled = flashinfer_block_scale_interleave(
+            scale_u8.contiguous()
+        ).contiguous()
+
+        cached_swizzled = getattr(layer, "weight_scale_inv_swizzled", None)
+        # Keep storage address stable when possible so CUDA graph captures remain valid.
+        if (
+            cached_swizzled is not None
+            and cached_swizzled.shape == new_swizzled.shape
+            and cached_swizzled.device == new_swizzled.device
+            and cached_swizzled.dtype == new_swizzled.dtype
+        ):
+            cached_swizzled.copy_(new_swizzled)
+        else:
+            layer.weight_scale_inv_swizzled = new_swizzled
+        layer._weight_scale_inv_swizzled_src_version = layer.weight_scale_inv._version
+        layer._weight_scale_inv_swizzled_src_data_ptr = (
+            layer.weight_scale_inv.data_ptr()
+        )
+
+    def _get_mxfp8_weight_scale(self, layer: Module) -> torch.Tensor:
+        """Return the swizzled scale for FlashInfer, else layer.weight_scale_inv."""
+        if self.use_flashinfer_mxfp8_linear:
+            swizzled = getattr(layer, "weight_scale_inv_swizzled", None)
+            src_version = getattr(layer, "_weight_scale_inv_swizzled_src_version", -1)
+            src_data_ptr = getattr(layer, "_weight_scale_inv_swizzled_src_data_ptr", -1)
+            if (
+                swizzled is None
+                or swizzled.device != layer.weight_scale_inv.device
+                or swizzled.dtype != torch.uint8
+                or src_version != layer.weight_scale_inv._version
+                or src_data_ptr != layer.weight_scale_inv.data_ptr()
+            ):
+                self._maybe_prepare_flashinfer_mxfp8_weight_scale(layer)
+            return layer.weight_scale_inv_swizzled
+        return layer.weight_scale_inv
+
     def _quantize_mxfp8_weights(self, layer: Module) -> None:
         weight = layer.weight.data
         qweight, weight_scale = mxfp8_group_quantize(weight)
@@ -468,6 +541,7 @@ def _quantize_mxfp8_weights(self, layer: Module) -> None:
                 "weight_scale_inv", Parameter(weight_scale, requires_grad=False)
             )
         layer.weight_scale_inv.format_ue8m0 = True
+        self._maybe_prepare_flashinfer_mxfp8_weight_scale(layer)
         layer.input_scale = None
 
     def process_weights_after_loading(self, layer: Module) -> None:
@@ -600,18 +674,20 @@ def apply(
             )
 
         if self.use_mxfp8:
+            assert self.w8a8_mxfp8_linear is not None
+            weight_scale = self._get_mxfp8_weight_scale(layer)
             if isinstance(x, tuple):
-                return triton_mxfp8_blockscaled_linear(
+                return self.w8a8_mxfp8_linear(
                     input=x[0],
                     weight=layer.weight,
-                    weight_scale=layer.weight_scale_inv,
+                    weight_scale=weight_scale,
                     input_scale=x[1],
                     bias=bias,
                 )
-            return triton_mxfp8_blockscaled_linear(
+            return self.w8a8_mxfp8_linear(
                 input=x,
                 weight=layer.weight,
-                weight_scale=layer.weight_scale_inv,
+                weight_scale=weight_scale,
                 input_scale=None,
                 bias=bias,
             )
@@ -677,12 +753,38 @@ def __init__(self, quant_config: Fp8Config):
         self.block_quant = (
             self.use_mxfp8 or self.quant_config.weight_block_size is not None
         )
-        if get_moe_runner_backend().is_cutlass():
+        moe_runner_backend = get_moe_runner_backend()
+        if moe_runner_backend.is_cutlass():
             assert (
                 cutlass_fp8_supported()
             ), "cutlass_fp8 MoE requires CUDA 12.0+ with SM90 or CUDA 12.4+ with SM89"
             assert self.block_quant, "cutlass_fp8 MoE requires block quantization"
             assert is_sm100_supported() or is_sm90_supported()
+        if moe_runner_backend.is_flashinfer_cutlass():
+            assert self.use_mxfp8, (
+                "flashinfer_cutlass FP8 MoE currently requires MXFP8 "
+                "(quantization=mxfp8)."
+            )
+            assert is_sm100_supported(), "flashinfer_cutlass MXFP8 MoE requires SM100."
+            assert (
+                is_flashinfer_available()
+            ), "flashinfer_cutlass backend requires FlashInfer."
+            assert (
+                flashinfer_mxfp8_quantize is not None
+            ), "flashinfer_cutlass MXFP8 MoE requires flashinfer.mxfp8_quantize."
+
+    @property
+    def load_up_proj_weight_first(self) -> bool:
+        # FlashInfer CUTLASS kernel assumes [Up, Gate] order for gated W13.
+        moe_runner_config = getattr(self, "moe_runner_config", None)
+        is_gated = moe_runner_config is None or moe_runner_config.is_gated
+        moe_runner_backend = get_moe_runner_backend()
+        return (
+            self.use_mxfp8
+            and moe_runner_backend.is_flashinfer_cutlass()
+            and is_flashinfer_available()
+            and is_gated
+        )
 
     @staticmethod
     def is_deepgemm_moe_runner_backend_enabled() -> bool:
@@ -1059,7 +1161,11 @@ def _quantize_and_swizzle_with_triton_kernel(weight: torch.Tensor):
             return qweight, scale
 
         if quantize:
-            if get_moe_runner_backend().is_cutlass():
+            moe_runner_backend = get_moe_runner_backend()
+            if (
+                moe_runner_backend.is_cutlass()
+                or moe_runner_backend.is_flashinfer_cutlass()
+            ):
                 w13_q, w13_s = _quantize_and_swizzle_with_cutlass_es_kernel(
                     layer.w13_weight.data
                 )
@@ -1335,6 +1441,95 @@ def create_moe_runner(
             # TODO(cwan): refactor other backends
             pass
 
+    def _apply_flashinfer_cutlass_mxfp8(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        moe_runner_config: MoeRunnerConfig,
+        output: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Run the MoE through FlashInfer's CUTLASS fused kernel with MXFP8 scaling."""
+        if not is_flashinfer_available() or flashinfer_mxfp8_quantize is None:
+            raise RuntimeError(
+                "FlashInfer MXFP8 MoE requires flashinfer.fused_moe.cutlass_fused_moe "
+                "and flashinfer.mxfp8_quantize."
+            )
+        if layer.w13_weight_scale_inv.dtype != torch.uint8:
+            raise TypeError("w13_weight_scale_inv must be uint8 UE8M0 for MXFP8.")
+        if layer.w2_weight_scale_inv.dtype != torch.uint8:
+            raise TypeError("w2_weight_scale_inv must be uint8 UE8M0 for MXFP8.")
+        if layer.w13_weight_scale_inv.shape[-1] % 4 != 0:
+            raise ValueError(
+                "MXFP8 FlashInfer MoE requires w13 scale last dim divisible by 4."
+            )
+        if layer.w2_weight_scale_inv.shape[-1] % 4 != 0:
+            raise ValueError(
+                "MXFP8 FlashInfer MoE requires w2 scale last dim divisible by 4."
+            )
+
+        activation = moe_runner_config.activation
+        is_gated = moe_runner_config.is_gated
+        if activation == "silu" and is_gated:
+            activation_type = FlashInferActivationType.Swiglu
+        elif activation == "gelu" and not is_gated:
+            activation_type = FlashInferActivationType.Relu2
+        else:
+            raise ValueError(
+                "FlashInfer MXFP8 MoE supports only silu-gated or gelu-nongated "
+                f"activation, but got activation={activation!r}, is_gated={is_gated}."
+            )
+        if moe_runner_config.apply_router_weight_on_input:
+            raise ValueError(
+                "apply_router_weight_on_input is not supported for FlashInfer MXFP8 MoE."
+            )
+
+        x_2d = x.contiguous()
+        x_mxfp8, x_sf = flashinfer_mxfp8_quantize(
+            x_2d, is_sf_swizzled_layout=True, alignment=32
+        )
+        x_mxfp8 = x_mxfp8[:, : x_2d.shape[1]].contiguous()
+        x_sf = x_sf.contiguous()
+
+        w13_scale_block = layer.w13_weight_scale_inv.contiguous().view(torch.int32)
+        w2_scale_block = layer.w2_weight_scale_inv.contiguous().view(torch.int32)
+        num_experts = layer.w13_weight.shape[0]
+        global_scale = getattr(layer, "_flashinfer_mxfp8_global_scale", None)
+        if (
+            global_scale is None
+            or global_scale.shape != (num_experts,)
+            or global_scale.device != x.device
+            or global_scale.dtype != torch.float32
+        ):
+            global_scale = torch.ones(num_experts, dtype=torch.float32, device=x.device)
+            layer._flashinfer_mxfp8_global_scale = global_scale
+
+        if output is None:
+            with use_symmetric_memory(
+                get_tp_group(), disabled=not is_allocation_symmetric()
+            ):
+                output = torch.empty_like(x)
+
+        return flashinfer_cutlass_fused_moe(
+            output=output,
+            input=x_mxfp8,
+            token_selected_experts=topk_ids.to(torch.int),
+            token_final_scales=topk_weights,
+            fc1_expert_weights=layer.w13_weight,
+            fc2_expert_weights=layer.w2_weight,
+            output_dtype=x.dtype,
+            quant_scales=[w13_scale_block, global_scale, w2_scale_block, global_scale],
+            input_sf=x_sf,
+            use_mxfp8_act_scaling=True,
+            ep_size=layer.moe_ep_size,
+            ep_rank=layer.moe_ep_rank,
+            tp_size=layer.moe_tp_size,
+            tp_rank=layer.moe_tp_rank,
+            tune_max_num_tokens=next_power_of_2(x.shape[0]),
+            activation_type=activation_type,
+        )[0]
+
     def apply(
         self,
         layer: torch.nn.Module,
@@ -1345,6 +1540,7 @@ def apply(
 
         x = dispatch_output.hidden_states
         moe_runner_config = self.moe_runner_config
+        moe_runner_backend = get_moe_runner_backend()
 
         if use_intel_amx_backend(layer):
             from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
@@ -1420,6 +1616,25 @@ def apply(
             )
             return StandardCombineInput(hidden_states=output)
 
+        if self.use_mxfp8 and moe_runner_backend.is_flashinfer_cutlass():
+            from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker
+
+            topk_weights, topk_ids, _ = dispatch_output.topk_output
+            preallocated_output = (
+                dispatch_output.moe_output
+                if DispatchOutputChecker.format_is_flashinfer(dispatch_output)
+                else None
+            )
+            output = self._apply_flashinfer_cutlass_mxfp8(
+                layer=layer,
+                x=x,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                moe_runner_config=moe_runner_config,
+                output=preallocated_output,
+            )
+            return StandardCombineInput(hidden_states=output)
+
         if self.runner.runner_backend.is_deep_gemm():
 
             w13_weight = layer.w13_weight
diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py
index 3b9616e2798f..b964ab2a75a4 100644
--- a/python/sglang/srt/layers/quantization/fp8_utils.py
+++ b/python/sglang/srt/layers/quantization/fp8_utils.py
@@ -172,6 +172,12 @@ def _check_cutlass_block_fp8_hardware_support() -> bool:
 
 
 if is_blackwell_supported() and is_flashinfer_available():
+    from flashinfer import mm_mxfp8 as flashinfer_mm_mxfp8
+    from flashinfer import mxfp8_quantize as flashinfer_mxfp8_quantize
     from flashinfer.gemm import gemm_fp8_nt_groupwise
+else:
+    # Keep these names importable everywhere; callers compare against None.
+    flashinfer_mm_mxfp8 = None
+    flashinfer_mxfp8_quantize = None
 
 if is_sm90_supported() and is_flashinfer_available():
@@ -197,6 +203,23 @@ def dispatch_w8a8_block_fp8_linear() -> Callable:
     return _dispatch_auto_backend()
 
 
+def dispatch_w8a8_mxfp8_linear() -> Callable:
+    """Dispatch MXFP8 linear kernel by --fp8-gemm-backend.
+
+    For MXFP8, Triton remains the default path. We only route to FlashInfer
+    when backend is explicitly set to flashinfer_trtllm.
+    """
+    backend = get_fp8_gemm_runner_backend()
+    if backend.is_flashinfer_trtllm():
+        if not (is_blackwell_supported() and is_flashinfer_available()):
+            raise RuntimeError(
+                "MXFP8 FlashInfer GEMM requested via --fp8-gemm-backend=flashinfer_trtllm, "
+                "but FlashInfer is unavailable or unsupported on this hardware."
+            )
+        return flashinfer_mxfp8_blockscaled_linear
+    return triton_mxfp8_blockscaled_linear
+
+
 def _dispatch_explicit_backend(backend: Fp8GemmRunnerBackend) -> Callable:
     """Dispatch based on explicitly selected backend."""
     if backend.is_flashinfer_trtllm():
@@ -730,6 +753,110 @@ def triton_mxfp8_blockscaled_linear(
     return output.to(dtype=output_dtype).view(*output_shape)
 
 
+def _validate_mxfp8_scale_tensor(
+    scale: torch.Tensor, *, rows: int, k_scales: int, name: str
+) -> None:
+    """Check dtype and shape of a 2D or flattened (swizzled) UE8M0 scale tensor."""
+    if scale.dtype != torch.uint8:
+        raise TypeError(f"{name} must be UE8M0 uint8.")
+
+    if scale.ndim == 2:
+        expected_shape = (rows, k_scales)
+        if scale.shape != expected_shape:
+            raise ValueError(
+                f"Expected {name} shape {expected_shape}, got {scale.shape}."
+            )
+        return
+
+    if scale.ndim == 1:
+        expected_len = ceil_div(rows, 128) * 128 * ceil_div(k_scales, 4) * 4
+        if scale.numel() != expected_len:
+            raise ValueError(
+                f"Expected swizzled {name} length {expected_len}, got {scale.numel()}."
+            )
+        return
+
+    raise ValueError(f"Expected {name} to be 1D or 2D, got {scale.ndim}D.")
+
+
+def flashinfer_mxfp8_blockscaled_linear(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    input_scale: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+    output_dtype: Optional[torch.dtype] = None,
+) -> torch.Tensor:
+    """MXFP8 dense linear via FlashInfer mm_mxfp8."""
+    if not (is_blackwell_supported() and is_flashinfer_available()):
+        raise RuntimeError(
+            "MXFP8 FlashInfer GEMM requested, but flashinfer.mm_mxfp8 is unavailable."
+        )
+
+    if not (_is_cuda and is_sm100_supported()):
+        raise RuntimeError("MXFP8 dense linear requires Blackwell GPUs (SM100+).")
+
+    input_2d = input.view(-1, input.shape[-1]).contiguous()
+    output_shape = [*input.shape[:-1], weight.shape[0]]
+
+    m, k = input_2d.shape
+    n, k_w = weight.shape
+    if k != k_w:
+        raise ValueError(f"Input K={k} does not match weight K={k_w}.")
+    if k % 32 != 0:
+        raise ValueError(f"K={k} must be divisible by 32 for MXFP8.")
+    if weight.dtype != torch.float8_e4m3fn:
+        raise TypeError("MXFP8 weight must be FP8 E4M3.")
+    k_scales = k // 32
+    _validate_mxfp8_scale_tensor(
+        weight_scale, rows=n, k_scales=k_scales, name="weight_scale"
+    )
+
+    if input_scale is None:
+        q_input, x_scale_u8 = flashinfer_mxfp8_quantize(
+            input_2d, is_sf_swizzled_layout=True, alignment=32
+        )
+        q_input = q_input[:, :k].contiguous()
+        x_scale_u8 = x_scale_u8.contiguous()
+    else:
+        q_input = input_2d
+        x_scale_u8 = input_scale.contiguous()
+        if q_input.dtype != torch.float8_e4m3fn:
+            raise TypeError(
+                "When input_scale is provided, input must be MXFP8 tensor "
+                "(torch.float8_e4m3fn)."
+            )
+        _validate_mxfp8_scale_tensor(
+            x_scale_u8, rows=m, k_scales=k_scales, name="input_scale"
+        )
+
+    if output_dtype is None:
+        if input_2d.dtype in (torch.float16, torch.bfloat16, torch.float32):
+            output_dtype = input_2d.dtype
+        else:
+            output_dtype = torch.bfloat16
+
+    # Ensure transposed tensors are contiguous for FlashInfer's internal runner.
+    weight_t = weight.contiguous().t()
+    weight_scale_t = (
+        weight_scale.contiguous().t()
+        if weight_scale.ndim == 2
+        else weight_scale.contiguous()
+    )
+    output = flashinfer_mm_mxfp8(
+        q_input,
+        weight_t,
+        x_scale_u8,
+        weight_scale_t,
+        out_dtype=output_dtype,
+        backend="auto",
+    )
+
+    if bias is not None:
+        output += bias
+    return output.to(dtype=output_dtype).view(*output_shape)
+
+
 def dequant_mxfp4(
     w_block: torch.Tensor,
     w_scale: torch.Tensor,
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 89688d34da54..dddebd2ac7af 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -2105,19 +2105,28 @@ def _handle_data_parallelism(self):
 
     def _handle_moe_kernel_config(self):
         if self.quantization == "mxfp8":
-            if self.moe_runner_backend not in ["auto", "cutlass"]:
-                logger.warning(
-                    "mxfp8 quantization forces --moe-runner-backend=cutlass. "
-                    f"Overriding {self.moe_runner_backend!r}."
+            allowed_mxfp8_backends = ("auto", "cutlass", "flashinfer_cutlass")
+            if self.moe_runner_backend not in allowed_mxfp8_backends:
+                raise ValueError(
+                    "mxfp8 quantization supports --moe-runner-backend in "
+                    f"{allowed_mxfp8_backends}, but got "
+                    f"{self.moe_runner_backend!r}."
                 )
-                self.moe_runner_backend = "cutlass"
+            if self.moe_runner_backend == "auto":
+                self.moe_runner_backend = "cutlass"
 
         if self.moe_runner_backend == "flashinfer_cutlass":
-            assert self.quantization in [
+            flashinfer_cutlass_quantizations = (
                 "modelopt_fp4",
                 "modelopt_fp8",
+                "mxfp8",
                 None,
-            ], f"Invalid quantization '{self.quantization}'. \nFlashInfer Cutlass MOE supports only: 'modelopt_fp4', 'modelopt_fp8', or bfloat16 (None)."
+            )
+            assert self.quantization in flashinfer_cutlass_quantizations, (
+                f"Invalid quantization '{self.quantization}'.\n"
+                "FlashInfer Cutlass MOE supports only "
+                f"{flashinfer_cutlass_quantizations}."
+            )
             assert self.ep_size in [
                 1,
                 self.tp_size,