diff --git a/docs/advanced_features/attention_backend.md b/docs/advanced_features/attention_backend.md
index cbc34043e46..705ace1c135 100644
--- a/docs/advanced_features/attention_backend.md
+++ b/docs/advanced_features/attention_backend.md
@@ -13,33 +13,33 @@ The support matrix is split into two parts: MHA (standard attention) and MLA (mu
 
 ### MHA Backends
 
-| **Backend** | **Page Size > 1 (native)** | **FP8 KV Cache** | **Spec topk=1** | **Spec topk>1** | **Sliding Window** | **MultiModal** |
-|---------------------------------|-----------------------------|------------------|-----------------|-----------------|--------------------|----------------|
-| **FlashInfer** | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
-| **FA3 (FlashAttention 3)** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
-| **FA4 (FlashAttention 4)** | 128 | ❌ | ❌ | ❌ | ❌ | ❌ |
-| **Triton** | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ |
-| **Torch Native (SDPA)** | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ |
-| **FlexAttention (PyTorch)** | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
-| **TRTLLM MHA** | 16, 32 or 64 | ✅ | ✅ | ❌ | ✅ | ❌ |
-| **Dual Chunk FlashAttention** | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
-| **AITER (ROCm)** | ✅ | ❌ | ✅ | ✅ | ❌ | ✅ |
-| **Wave (ROCm)** | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
-| **Ascend (NPU)** | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ |
-| **Intel XPU** | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
+| **Backend** | **Page Size > 1 (native)** | **FP8 KV Cache** | **FP4 KV Cache** | **Spec topk=1** | **Spec topk>1** | **Sliding Window** | **MultiModal** |
+|---------------------------------|-----------------------------|------------------|-----------------|-----------------|-----------------|--------------------|----------------|
+| **FlashInfer** | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ |
+| **FA3 (FlashAttention 3)** | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ |
+| **FA4 (FlashAttention 4)** | 128 | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
+| **Triton** | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ |
+| **Torch Native (SDPA)** | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ |
+| **FlexAttention (PyTorch)** | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
+| **TRTLLM MHA** | 16, 32 or 64 | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ |
+| **Dual Chunk FlashAttention** | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
+| **AITER (ROCm)** | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ✅ |
+| **Wave (ROCm)** | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
+| **Ascend (NPU)** | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ |
+| **Intel XPU** | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ |
 
 ### MLA Backends
 
-| **Backend** | **Native Page Sizes** | **FP8 KV Cache** | **Chunked Prefix Cache** | **Spec topk=1** | **Spec topk>1** |
-|----------------------------|---------------------------|------------------|--------------------------|-----------------|-----------------|
-| **FlashInfer MLA** | 1 | ❌ | ✅ | ✅ | ❌ |
-| **FlashMLA** | 64 | ✅ | ✅ | ✅ | ❌ |
-| **Cutlass MLA** | 128 | ✅ | ✅ | ✅ | ❌ |
-| **TRTLLM MLA (Blackwell)** | 32 or 64 | ✅ | ✅ | ✅ | ❌ |
-| **FA3 (FlashAttention 3)** | n/a | ❌ | ✅ | ✅ | ⚠️ (page_size=1 only) |
-| **Triton** | n/a | ❌ | ❌ | ✅ | ⚠️ (page_size=1 only) |
-| **FA4** | 1 | ❌ | ❌ | ❌ | ❌ |
-| **Ascend MLA (NPU)** | 128 | ❌ | ❌ | ❌ | ❌ |
+| **Backend** | **Native Page Sizes** | **FP8 KV Cache** | **FP4 KV Cache** | **Chunked Prefix Cache** | **Spec topk=1** | **Spec topk>1** |
+|----------------------------|---------------------------|------------------|------------------|--------------------------|-----------------|-----------------|
+| **FlashInfer MLA** | 1 | ❌ | ✅ | ✅ | ✅ | ❌ |
+| **FlashMLA** | 64 | ✅ | ❌ | ✅ | ✅ | ❌ |
+| **Cutlass MLA** | 128 | ✅ | ✅ | ✅ | ✅ | ❌ |
+| **TRTLLM MLA (Blackwell)** | 32 or 64 | ✅ | ✅ | ✅ | ✅ | ❌ |
+| **FA3 (FlashAttention 3)** | n/a | ❌ | ❌ | ✅ | ✅ | ⚠️ (page_size=1 only) |
+| **Triton** | n/a | ❌ | ❌ | ❌ | ✅ | ⚠️ (page_size=1 only) |
+| **FA4** | 1 | ❌ | ✅ | ❌ | ❌ | ❌ |
+| **Ascend MLA (NPU)** | 128 | ❌ | ❌ | ❌ | ❌ | ❌ |
 
 ```{note}
 Multimodal attention is selected by `--mm-attention-backend`. The "MultiModal" column indicates whether a corresponding multimodal implementation exists for that backend family.
@@ -54,6 +54,10 @@ Multimodal attention is selected by `--mm-attention-backend`. The "MultiModal" c
 Speculative decoding topk: `topk` is the number of draft tokens sampled per step from the draft model. `topk = 1` follows classic EAGLE; `topk > 1` explores multiple branches and requires backend support in both draft and verification paths.
 ```
 
+```{note}
+In the KV4 (FP4 KV cache) + FA4 scenario, FA4 must be paired with a different `--decode-attention-backend`. Except for `trtllm_mha`, which is incompatible with FA4, all other decode backends behave as shown in the table.
+```
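+
+A minimal sketch of one such launch, using the offline `Engine` API (the `fp4_e2m1` dtype string and the `prefill_attention_backend` argument name are assumptions; check `ServerArgs` in your build for the exact spellings):
+
+```python
+# Illustrative sketch: FA4 prefill paired with a Triton decode backend and an FP4 ("KV4") KV cache.
+import sglang as sgl
+
+llm = sgl.Engine(
+    model_path="meta-llama/Llama-3.1-8B-Instruct",  # example model
+    prefill_attention_backend="fa4",                # FA4 handles prefill
+    decode_attention_backend="triton",              # a KV4-compatible decode backend
+    kv_cache_dtype="fp4_e2m1",                      # assumed name of the FP4 KV cache dtype
+)
+print(llm.generate("The capital of France is", {"max_new_tokens": 8}))
+llm.shutdown()
+```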
+
 Note: Many backends that do not natively operate on pages can emulate `page_size > 1` at the wrapper layer by expanding page tables to per-token indices. The "Page Size > 1 (native)" column indicates true in-kernel paging. Some backends require fixed native page sizes and cannot be reduced/emulated differently: TRTLLM MHA (16/32/64), TRTLLM MLA (32/64), FlashMLA (64), Cutlass MLA (128), Ascend (128).
 
 MLA page-size constraints:
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 568e9b8488c..9db90eb9cbc 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -353,6 +353,9 @@ def __init__(
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
             deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)
 
+        # Check FP4 KV cache compatibility with the attention backend
+        self._handle_kv4_compatibility()
+
         # Initialize the model runner
         self.initialize(min_per_gpu_memory)
 
@@ -384,6 +387,80 @@ def init_mindspore_runner(self):
             port=self.dist_port,
         )
 
+    def _handle_kv4_compatibility(self):
+        self.prefill_attention_backend_str, self.decode_attention_backend_str = (
+            self.server_args.get_attention_backends()
+        )
+
+        if is_cuda():
+            if (
+                self.prefill_attention_backend_str != self.decode_attention_backend_str
+                and self.prefill_attention_backend_str != "fa4"
+            ):  # Take care of prefill=fa4 later
+                logger.warning(
+                    f"Attention: Using KV4 with PREFILL = {self.prefill_attention_backend_str} "
+                    f"and DECODE = {self.decode_attention_backend_str}. "
+                    f"Compatibility issues are unlikely, but may occur in rare edge cases."
+                )
+            else:
+                if self.prefill_attention_backend_str == "fa4":
+                    if self.use_mla_backend:  # FA4 + MLA
+                        KV4_FA4_MLA_BACKEND_CHOICES = [
+                            "cutlass_mla",
+                            "flashinfer",
+                            "trtllm_mla",
+                        ]
+                        assert (
+                            self.decode_attention_backend_str
+                            in KV4_FA4_MLA_BACKEND_CHOICES
+                        ), (
+                            f"KV4 FA4 MLA expects decode_attention_backend to be one of "
+                            f"{KV4_FA4_MLA_BACKEND_CHOICES}, but got {self.decode_attention_backend_str}"
+                        )
+                    else:  # FA4 + MHA
+                        KV4_FA4_MHA_BACKEND_CHOICES = [
+                            "triton",
+                            "torch_native",
+                            "flex_attention",
+                        ]
+                        assert (
+                            self.decode_attention_backend_str
+                            in KV4_FA4_MHA_BACKEND_CHOICES
+                        ), (
+                            f"KV4 FA4 MHA expects decode_attention_backend to be one of "
+                            f"{KV4_FA4_MHA_BACKEND_CHOICES}, but got {self.decode_attention_backend_str}"
+                        )
+                else:
+                    if self.use_mla_backend:  # !FA4 + MLA
+                        KV4_ATTENTION_MLA_BACKEND_CHOICES = [
+                            "cutlass_mla",
+                            "flashinfer",
+                            "trtllm_mla",
+                        ]
+                        assert (
+                            self.server_args.attention_backend
+                            in KV4_ATTENTION_MLA_BACKEND_CHOICES
+                        ), (
+                            f"KV4 MLA expects attention_backend to be one of "
+                            f"{KV4_ATTENTION_MLA_BACKEND_CHOICES}, but got {self.server_args.attention_backend}"
+                        )
+                    else:  # !FA4 + MHA
+                        KV4_ATTENTION_MHA_BACKEND_CHOICES = [
+                            "triton",
+                            "torch_native",
+                            "flex_attention",
+                            "trtllm_mha",
+                        ]
+                        assert (
+                            self.server_args.attention_backend
+                            in KV4_ATTENTION_MHA_BACKEND_CHOICES
+                        ), (
+                            f"KV4 MHA expects attention_backend to be one of "
+                            f"{KV4_ATTENTION_MHA_BACKEND_CHOICES}, but got {self.server_args.attention_backend}"
+                        )
+        else:
+            raise RuntimeError("KV4 is not tested on non-CUDA platforms.")
+
     def initialize(self, min_per_gpu_memory: float):
         server_args = self.server_args