68 changes: 68 additions & 0 deletions python/sglang/srt/layers/attention/flashinfer_ops.py
Contributor: This is a mm op; why put it under the attention layer?

@@ -0,0 +1,68 @@
import torch

from sglang.srt.utils import is_flashinfer_available

if is_flashinfer_available():

    @torch.library.custom_op(
        "sglang::flashinfer_mm_fp4",
        mutates_args=[],
        device_types="cuda",
    )
    def flashinfer_mm_fp4(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        g_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        from flashinfer.gemm import mm_fp4 as flashinfer_mm_fp4_

        return flashinfer_mm_fp4_(
            A, B, A_scale, B_scale, g_scale, dtype, block_size=16, backend=backend
        )

    @torch.library.register_fake(
        "sglang::flashinfer_mm_fp4",
    )
    def flashinfer_mm_fp4_fake(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        g_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        return torch.empty(A.shape[0], B.shape[1], dtype=dtype, device=A.device)


def flashinfer_scaled_fp4_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    block_scale_a: torch.Tensor,
    block_scale_b: torch.Tensor,
    alpha: torch.Tensor,
    out_dtype: torch.dtype,
    backend: str,
) -> torch.Tensor:
    assert a.ndim == 2 and b.ndim == 2
    assert block_scale_a.ndim == 2 and block_scale_b.ndim == 2
    assert a.stride(-1) == 1 and b.stride(-1) == 1
    assert a.shape[1] == b.shape[1]

    if backend == "cutlass":
        block_scale_a = block_scale_a.view(torch.uint8)
        block_scale_b = block_scale_b.view(torch.uint8)

    return flashinfer_mm_fp4(
        a,
        b.t(),
        block_scale_a,
        block_scale_b.t(),
        alpha,
        out_dtype,
        backend=backend,
    )
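
For orientation, here is a rough usage sketch of the wrapper above. It is not part of the PR: the shapes, dtypes, and scale layout are assumptions based on the usual NVFP4 packing convention (two FP4 values per uint8 byte, one FP8-E4M3 scale per 16-element group along K), it requires flashinfer and an FP4-capable GPU, and the random bit patterns make the numerical output meaningless.

import torch

from sglang.srt.layers.attention.flashinfer_ops import flashinfer_scaled_fp4_mm

# Illustrative shapes only: M x K activation against N x K weight, both packed FP4.
M, N, K = 128, 256, 512
a = torch.randint(0, 256, (M, K // 2), dtype=torch.uint8, device="cuda")  # packed FP4 activations
b = torch.randint(0, 256, (N, K // 2), dtype=torch.uint8, device="cuda")  # packed FP4 weights
# One FP8-E4M3 scale per 16-element group along K; real call sites use the
# backend-specific (possibly swizzled) layout produced by the quant scheme.
block_scale_a = torch.ones(M, K // 16, device="cuda").to(torch.float8_e4m3fn)
block_scale_b = torch.ones(N, K // 16, device="cuda").to(torch.float8_e4m3fn)
alpha = torch.tensor(1.0, device="cuda")  # global rescale factor

out = flashinfer_scaled_fp4_mm(
    a, b, block_scale_a, block_scale_b, alpha,
    out_dtype=torch.bfloat16, backend="cutlass",
)
print(out.shape)  # torch.Size([128, 256])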
Second changed file (compressed-tensors quantization config):

@@ -30,6 +30,8 @@
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
WNA16_SUPPORTED_BITS,
CompressedTensorsScheme,
CompressedTensorsW4A4Fp4,
CompressedTensorsW4A16Fp4,
CompressedTensorsW8A8Fp8,
CompressedTensorsW8A8Int8,
CompressedTensorsW8A16Fp8,
@@ -42,6 +44,7 @@
)
from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.utils import cutlass_fp4_supported

logger = logging.getLogger(__name__)

@@ -376,6 +379,35 @@ def _is_fp8_w8a16(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool:
        # All conditions satisfied.
        return True

    def _is_fp4a4_nvfp4(
        self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
    ):
        if weight_quant is None or input_quant is None:
            return False

        is_tensor_group_quant = (
            weight_quant.strategy == QuantizationStrategy.TENSOR_GROUP.value
            and input_quant.strategy == QuantizationStrategy.TENSOR_GROUP.value
        )
        is_symmetric = weight_quant.symmetric and input_quant.symmetric

        is_group_size_16 = (
            weight_quant.group_size == 16 and input_quant.group_size == 16
        )
        is_float_type = (
            weight_quant.type == QuantizationType.FLOAT
            and input_quant.type == QuantizationType.FLOAT
        )
        is_4_bits = weight_quant.num_bits == 4 and input_quant.num_bits == 4

        return (
            is_tensor_group_quant
            and is_float_type
            and is_4_bits
            and is_group_size_16
            and is_symmetric
        )

    def _is_wNa16_group_channel(
        self, weight_quant: BaseModel, input_quant: BaseModel
    ) -> bool:
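
For reference, the `_is_fp4a4_nvfp4` predicate above would match quantization arguments along the lines of the fragment below. This is a hypothetical illustration, not taken from the PR; the keys mirror compressed-tensors' QuantizationArgs fields, and real checkpoint configs may spell them differently.

# Hypothetical NVFP4 W4A4 fragment (illustrative only, not from the PR).
nvfp4_tensor_group = {
    "num_bits": 4,               # is_4_bits
    "type": "float",             # QuantizationType.FLOAT
    "strategy": "tensor_group",  # QuantizationStrategy.TENSOR_GROUP
    "group_size": 16,            # is_group_size_16
    "symmetric": True,           # is_symmetric
}
quant_args = {
    "weights": dict(nvfp4_tensor_group),
    "input_activations": dict(nvfp4_tensor_group),  # the weight-only W4A16 case sets this to None
}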
@@ -389,10 +421,35 @@ def _is_wNa16_group_channel(

        return is_channel_group and input_quant_none and is_symmetric and is_static

    def _is_fp4a16_nvfp4(
        self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
    ):
        is_weight_only = weight_quant is not None and input_quant is None
        is_tensor_group_quant = (
            weight_quant.strategy == QuantizationStrategy.TENSOR_GROUP.value
        )
        is_symmetric = weight_quant.symmetric

        is_group_size_16 = weight_quant.group_size == 16
        is_float_type = weight_quant.type == QuantizationType.FLOAT
        is_4_bits = weight_quant.num_bits == 4

        return (
            is_weight_only
            and is_tensor_group_quant
            and is_float_type
            and is_4_bits
            and is_group_size_16
            and is_symmetric
        )

    def _get_scheme_from_parts(
        self, weight_quant: BaseModel, input_quant: BaseModel
    ) -> CompressedTensorsScheme:

        if self._is_fp4a16_nvfp4(weight_quant, input_quant):
            return CompressedTensorsW4A16Fp4()

        # Detect If Mixed Precision
        if self._is_wNa16_group_channel(weight_quant, input_quant):
            if (
@@ -411,6 +468,16 @@ def _get_scheme_from_parts(
                )

        if is_activation_quantization_format(self.quant_format):
            if self._is_fp4a4_nvfp4(weight_quant, input_quant):
                if cutlass_fp4_supported():
Contributor: w4a4 supports both flashinfer and cutlass, right? I think we should do something similar to the method below and check the capability.

                    return CompressedTensorsW4A4Fp4()
                else:
                    logger.warning_once(
                        "Current platform does not support cutlass NVFP4."
                        " Running CompressedTensorsW4A16Fp4."
                    )
                    return CompressedTensorsW4A16Fp4(has_input_global_scale=True)

            if self._is_fp8_w8a8(weight_quant, input_quant):
                is_fp8_w8a8_supported = self._check_scheme_supported(
                    CompressedTensorsW8A8Fp8.get_min_capability(), error=False
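
One way to act on the reviewer's suggestion is to gate the W4A4 path on device capability, mirroring the FP8 branch above, rather than probing only the cutlass kernel. The snippet below is a sketch, not part of the PR, and it assumes CompressedTensorsW4A4Fp4 implements get_min_capability() like the other schemes referenced in this file.

            # Sketch of the reviewer's suggestion (not part of this PR).
            if self._is_fp4a4_nvfp4(weight_quant, input_quant):
                is_fp4a4_supported = self._check_scheme_supported(
                    CompressedTensorsW4A4Fp4.get_min_capability(), error=False
                )
                if is_fp4a4_supported:
                    return CompressedTensorsW4A4Fp4()
                logger.warning_once(
                    "Current platform does not meet the minimum capability for"
                    " NVFP4 W4A4. Falling back to CompressedTensorsW4A16Fp4."
                )
                return CompressedTensorsW4A16Fp4(has_input_global_scale=True)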