52 changes: 28 additions & 24 deletions docs/advanced_features/attention_backend.md
@@ -13,33 +13,33 @@ The support matrix is split into two parts: MHA (standard attention) and MLA (mu

### MHA Backends

| **Backend** | **Page Size > 1 (native)** | **FP8 KV Cache** | **Spec topk=1** | **Spec topk>1** | **Sliding Window** | **MultiModal** |
|---------------------------------|-----------------------------|------------------|-----------------|-----------------|--------------------|----------------|
| **FlashInfer** | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| **FA3 (FlashAttention 3)** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| **FA4 (FlashAttention 4)** | 128 | ❌ | ❌ | ❌ | ❌ | ❌ |
| **Triton** | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ |
| **Torch Native (SDPA)** | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ |
| **FlexAttention (PyTorch)** | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| **TRTLLM MHA** | 16, 32 or 64 | ✅ | ✅ | ❌ | ✅ | ❌ |
| **Dual Chunk FlashAttention** | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
| **AITER (ROCm)** | ✅ | ❌ | ✅ | ✅ | ❌ | ✅ |
| **Wave (ROCm)** | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
| **Ascend (NPU)** | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ |
| **Intel XPU** | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
| **Backend** | **Page Size > 1 (native)** | **FP8 KV Cache** | **FP4 KV Cache** | **Spec topk=1** | **Spec topk>1** | **Sliding Window** | **MultiModal** |
|---------------------------------|-----------------------------|------------------|-----------------|-----------------|-----------------|--------------------|----------------|
| **FlashInfer** | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ |
| **FA3 (FlashAttention 3)** | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ |
| **FA4 (FlashAttention 4)** | 128 | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
| **Triton** | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ |
| **Torch Native (SDPA)** | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ |
| **FlexAttention (PyTorch)** | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
| **TRTLLM MHA** | 16, 32 or 64 | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ |
| **Dual Chunk FlashAttention** | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| **AITER (ROCm)** | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ✅ |
| **Wave (ROCm)** | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| **Ascend (NPU)** | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ |
| **Intel XPU** | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ |

### MLA Backends

| **Backend** | **Native Page Sizes** | **FP8 KV Cache** | **Chunked Prefix Cache** | **Spec topk=1** | **Spec topk>1** |
|----------------------------|---------------------------|------------------|--------------------------|-----------------|-----------------|
| **FlashInfer MLA** | 1 | ❌ | ✅ | ✅ | ❌ |
| **FlashMLA** | 64 | ✅ | ✅ | ✅ | ❌ |
| **Cutlass MLA** | 128 | ✅ | ✅ | ✅ | ❌ |
| **TRTLLM MLA (Blackwell)** | 32 or 64 | ✅ | ✅ | ✅ | ❌ |
| **FA3 (FlashAttention 3)** | n/a | ❌ | ✅ | ✅ | ⚠️ (page_size=1 only) |
| **Triton** | n/a | ❌ | ❌ | ✅ | ⚠️ (page_size=1 only) |
| **FA4** | 1 | ❌ | ❌ | ❌ | ❌ |
| **Ascend MLA (NPU)** | 128 | ❌ | ❌ | ❌ | ❌ |
| **Backend** | **Native Page Sizes** | **FP8 KV Cache** | **FP4 KV Cache** | **Chunked Prefix Cache** | **Spec topk=1** | **Spec topk>1** |
|----------------------------|---------------------------|------------------|------------------|--------------------------|-----------------|-----------------|
| **FlashInfer MLA** | 1 | ❌ | ✅ | ✅ | ✅ | ❌ |
| **FlashMLA** | 64 | ✅ | ❌ | ✅ | ✅ | ❌ |
| **Cutlass MLA** | 128 | ✅ | ✅ | ✅ | ✅ | ❌ |
| **TRTLLM MLA (Blackwell)** | 32 or 64 | ✅ | ✅ | ✅ | ✅ | ❌ |
| **FA3 (FlashAttention 3)** | n/a | ❌ | ❌ | ✅ | ✅ | ⚠️ (page_size=1 only) |
| **Triton** | n/a | ❌ | ❌ | ❌ | ✅ | ⚠️ (page_size=1 only) |
| **FA4** | 1 | ❌ | ✅ | ❌ | ❌ | ❌ |
| **Ascend MLA (NPU)** | 128 | ❌ | ❌ | ❌ | ❌ | ❌ |

```{note}
Multimodal attention is selected by `--mm-attention-backend`. The "MultiModal" column indicates whether a corresponding multimodal implementation exists for that backend family.
Expand All @@ -54,6 +54,10 @@ Multimodal attention is selected by `--mm-attention-backend`. The "MultiModal" c
Speculative decoding topk: `topk` is the number of draft tokens sampled per step from the draft model. `topk = 1` follows classic EAGLE; `topk > 1` explores multiple branches and requires backend support in both draft and verification paths.
```
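
As a rough illustration of why `topk > 1` needs extra backend support, the sketch below (toy code, not SGLang's draft-construction logic) compares the number of draft positions a verifier may have to score for a chain (`topk = 1`) versus a branching tree (`topk > 1`):

```python
# Toy illustration only -- not SGLang's draft construction code.
# topk = 1: the draft model proposes a single chain of tokens.
# topk > 1: each draft step keeps several candidate tokens, so the
# verification pass must score a small tree instead of a chain.

def draft_positions(num_steps: int, topk: int) -> int:
    """Upper bound on the number of draft positions the verifier may see."""
    if topk == 1:
        return num_steps  # one token per step (classic EAGLE chain)
    return sum(topk**i for i in range(1, num_steps + 1))  # full topk-ary tree bound

print(draft_positions(num_steps=3, topk=1))  # 3
print(draft_positions(num_steps=3, topk=4))  # 84 (real implementations cap this with a draft-token budget)
```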

```{note}
In the KV4 (FP4 KV cache) + FA4 scenario, FA4 must be paired with a different `--decode-attention-backend` to run. Apart from `trtllm_mha`, which is incompatible with FA4, all other decode backends behave as shown in the table.
```
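
As a quick reference for that scenario, the accepted pairings can be summarized as a small lookup. This is a sketch derived from the table above and the `_handle_kv4_compatibility` check added in this PR, not an importable SGLang API:

```python
# Sketch: decode backends that the PR's KV4 check accepts when prefill is FA4.
# Derived from _handle_kv4_compatibility below; not an importable SGLang API.
KV4_FA4_DECODE_BACKENDS = {
    "mla": ["cutlass_mla", "flashinfer", "trtllm_mla"],   # FA4 prefill + MLA models
    "mha": ["triton", "torch_native", "flex_attention"],  # FA4 prefill + MHA models
}

def is_supported(attn_type: str, decode_backend: str) -> bool:
    """Return True if the decode backend is accepted for KV4 with FA4 prefill."""
    return decode_backend in KV4_FA4_DECODE_BACKENDS[attn_type]

assert is_supported("mha", "triton")
assert not is_supported("mha", "trtllm_mha")  # trtllm_mha is incompatible with FA4
```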

Note: Many backends that do not natively operate on pages can emulate `page_size > 1` at the wrapper layer by expanding page tables to per-token indices. The "Page Size > 1 (native)" column indicates true in-kernel paging. Some backends require fixed native page sizes that cannot be reduced or emulated differently: TRTLLM MHA (16/32/64), TRTLLM MLA (32/64), FlashMLA (64), Cutlass MLA (128), Ascend (128).
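
For intuition, a minimal sketch of that wrapper-level expansion is shown below; the function name and tensor shapes are illustrative assumptions, not SGLang's actual wrapper code:

```python
import torch

def expand_page_table(page_table: torch.Tensor, page_size: int) -> torch.Tensor:
    """Expand page ids [batch, max_pages] into per-token KV indices [batch, max_pages * page_size]."""
    offsets = torch.arange(page_size, device=page_table.device)  # [page_size]
    token_idx = page_table[:, :, None] * page_size + offsets     # [batch, pages, page_size]
    return token_idx.flatten(1)                                  # [batch, pages * page_size]

pages = torch.tensor([[3, 7]])  # two pages of a single request
print(expand_page_table(pages, page_size=4))
# tensor([[12, 13, 14, 15, 28, 29, 30, 31]])
```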

MLA page-size constraints:
77 changes: 77 additions & 0 deletions python/sglang/srt/model_executor/model_runner.py
@@ -353,6 +353,9 @@ def __init__(
        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
            deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)

        # Check FP4 KV cache compatibility with the attention backend
        self._handle_kv4_compatibility()

        # Initialize the model runner
        self.initialize(min_per_gpu_memory)

@@ -384,6 +387,80 @@ def init_mindspore_runner(self):
            port=self.dist_port,
        )

    def _handle_kv4_compatibility(self):
**Collaborator:** Can we move this function to server_args.py, somewhere after `_handle_attention_backend_compatibility`?

**Contributor (author):** No problem. I will fix this as soon as I can. Thank you for your feedback!
        self.prefill_attention_backend_str, self.decode_attention_backend_str = (
            self.server_args.get_attention_backends()
        )

        if is_cuda():
            if (
                self.prefill_attention_backend_str != self.decode_attention_backend_str
                and self.prefill_attention_backend_str != "fa4"
            ):  # Take care of prefill=fa4 later
                logger.warning(
                    f"Attention: Using KV4 with PREFILL = {self.prefill_attention_backend_str} "
                    f"and DECODE = {self.decode_attention_backend_str}. "
                    f"Compatibility issues are unlikely, but may occur in rare edge cases."
                )
            else:
                if self.prefill_attention_backend_str == "fa4":
                    if self.use_mla_backend:  # FA4 + MLA
                        KV4_FA4_MLA_BACKEND_CHOICES = [
                            "cutlass_mla",
                            "flashinfer",
                            "trtllm_mla",
                        ]
                        assert (
                            self.decode_attention_backend_str
                            in KV4_FA4_MLA_BACKEND_CHOICES
                        ), (
                            f"KV4 FA4 MLA expects decode_attention_backend to be one of "
                            f"{KV4_FA4_MLA_BACKEND_CHOICES}, but got {self.decode_attention_backend_str}"
                        )
                    else:  # FA4 + MHA
                        KV4_FA4_MHA_BACKEND_CHOICES = [
                            "triton",
                            "torch_native",
                            "flex_attention",
                        ]
                        assert (
                            self.decode_attention_backend_str
                            in KV4_FA4_MHA_BACKEND_CHOICES
                        ), (
                            f"KV4 FA4 MHA expects decode_attention_backend to be one of "
                            f"{KV4_FA4_MHA_BACKEND_CHOICES}, but got {self.decode_attention_backend_str}"
                        )
                else:
                    if self.use_mla_backend:  # !FA4 + MLA
                        KV4_ATTENTION_MLA_BACKEND_CHOICES = [
                            "cutlass_mla",
                            "flashinfer",
                            "trtllm_mla",
                        ]
                        assert (
                            self.server_args.attention_backend
                            in KV4_ATTENTION_MLA_BACKEND_CHOICES
                        ), (
                            f"KV4 MLA expects attention_backend to be one of "
                            f"{KV4_ATTENTION_MLA_BACKEND_CHOICES}, but got {self.server_args.attention_backend}"
                        )
                    else:  # !FA4 + MHA
                        KV4_ATTENTION_MHA_BACKEND_CHOICES = [
                            "triton",
                            "torch_native",
                            "flex_attention",
                            "trtllm_mha",
                        ]
                        assert (
                            self.server_args.attention_backend
                            in KV4_ATTENTION_MHA_BACKEND_CHOICES
                        ), (
                            f"KV4 MHA expects attention_backend to be one of "
                            f"{KV4_ATTENTION_MHA_BACKEND_CHOICES}, but got {self.server_args.attention_backend}"
                        )
        else:
            raise RuntimeError("KV4 is not tested on non-CUDA platforms.")

    def initialize(self, min_per_gpu_memory: float):
        server_args = self.server_args
