diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 38bfdc00f81..2ee0c179c1a 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -49,6 +49,30 @@ class ModelImpl(str, Enum): TRANSFORMERS = "transformers" +def is_deepseek_nsa(config: PretrainedConfig) -> bool: + return ( + config.architectures is not None + and config.architectures[0] + in ["DeepseekV3ForCausalLM", "DeepseekV32ForCausalLM"] + and getattr(config, "index_topk", None) is not None + ) + + +def get_nsa_index_head_dim(config: PretrainedConfig) -> int: + assert is_deepseek_nsa(config) + return config.index_head_dim + + +def get_nsa_index_topk(config: PretrainedConfig) -> int: + assert is_deepseek_nsa(config) + return config.index_topk + + +def get_nsa_index_n_heads(config: PretrainedConfig) -> int: + assert is_deepseek_nsa(config) + return config.index_n_heads + + class ModelConfig: def __init__( self, @@ -271,6 +295,7 @@ def _derive_model_shapes(self): # FIXME: temporary special judge for MLA architecture if ( "DeepseekV2ForCausalLM" in self.hf_config.architectures + or "DeepseekV32ForCausalLM" in self.hf_config.architectures or "DeepseekV3ForCausalLM" in self.hf_config.architectures or "DeepseekV3ForCausalLMNextN" in self.hf_config.architectures or "LongcatFlashForCausalLM" in self.hf_config.architectures @@ -283,6 +308,11 @@ def _derive_model_shapes(self): self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim self.v_head_dim = self.hf_config.v_head_dim + self.index_head_dim = ( + get_nsa_index_head_dim(self.hf_config) + if is_deepseek_nsa(self.hf_config) + else None + ) # Handle rope scaling with yarn self.scaling = 1 / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim) diff --git a/python/sglang/srt/disaggregation/ascend/transfer_engine.py b/python/sglang/srt/disaggregation/ascend/transfer_engine.py index 0ccffffd631..a1fe58ce605 100644 --- a/python/sglang/srt/disaggregation/ascend/transfer_engine.py +++ b/python/sglang/srt/disaggregation/ascend/transfer_engine.py @@ -2,9 +2,19 @@ import os from typing import List, Optional +import torch + from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine from sglang.srt.disaggregation.utils import DisaggregationMode +try: + from mf_adapter import TransferEngine + + import_error = None +except ImportError as e: + import_error = e + pass + logger = logging.getLogger(__name__) @@ -13,12 +23,11 @@ class AscendTransferEngine(MooncakeTransferEngine): def __init__( self, hostname: str, npu_id: int, disaggregation_mode: DisaggregationMode ): - try: - from mf_adapter import TransferEngine - except ImportError as e: - raise ImportError( + if import_error is not None: + logger.warning( "Please install mf_adapter, for details, see docs/backend/pd_disaggregation.md" - ) from e + ) + raise import_error self.engine = TransferEngine() self.hostname = hostname @@ -37,12 +46,29 @@ def __init__( self.initialize() def initialize(self) -> None: + from sglang.srt.layers.dp_attention import ( + get_tensor_model_parallel_world_size, + get_tp_group, + ) + + transfer_protocol = self._get_transfer_protocol() + if transfer_protocol is None or transfer_protocol == "sdma": + trans_op_type = TransferEngine.TransDataOpType.SDMA + else: + trans_op_type = TransferEngine.TransDataOpType.DEVICE_RDMA + """with device RDMA for PD transfer""" + tmp_tensor = torch.zeros(1, device="npu") + output_tensor_list = [ + 
torch.empty_like(tmp_tensor) + for _ in range(get_tensor_model_parallel_world_size()) + ] + # Initialize hccl in advance through all_gather to avoid conflicts with rdma initialization. + torch.distributed.all_gather( + output_tensor_list, tmp_tensor, group=get_tp_group().device_group + ) """Initialize the ascend transfer instance.""" ret_value = self.engine.initialize( - self.store_url, - self.session_id, - self.role, - self.npu_id, + self.store_url, self.session_id, self.role, self.npu_id, trans_op_type ) if ret_value != 0: logger.error("Ascend Transfer Engine initialization failed.") @@ -56,3 +82,15 @@ def batch_register(self, ptrs: List[int], lengths: List[int]): ret_value = -1 if ret_value != 0: logger.debug(f"Ascend memory registration for ptr {ptrs} failed.") + + @staticmethod + def _get_transfer_protocol(): + protocol = os.getenv("ASCEND_MF_TRANSFER_PROTOCOL") + allowed_protocols = {"device_rdma", "sdma"} + if protocol and protocol.lower() in allowed_protocols: + return protocol.lower() + else: + logger.warning( + "Invalid or no transfer protocol specified, using default protocol." + ) + return None diff --git a/python/sglang/srt/layers/attention/ascend_backend.py b/python/sglang/srt/layers/attention/ascend_backend.py index 2391e866400..65490b017f7 100644 --- a/python/sglang/srt/layers/attention/ascend_backend.py +++ b/python/sglang/srt/layers/attention/ascend_backend.py @@ -36,6 +36,8 @@ class ForwardMetadata: seq_lens_cpu_int: Optional[torch.Tensor] = None seq_lens_cpu_list: Optional[List[int]] = None seq_lens_list_cumsum: Optional[List[int]] = None + seq_lens: Optional[torch.Tensor] = None + actual_seq_lengths_q: Optional[torch.Tensor] = None class AscendAttnBackend(AttentionBackend): @@ -67,6 +69,9 @@ def __init__(self, model_runner: ModelRunner): if self.use_mla: self.kv_lora_rank = model_runner.model_config.kv_lora_rank self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim + self.q_head_dim = ( + self.qk_rope_head_dim + model_runner.model_config.qk_nope_head_dim + ) self.native_attn = TorchNativeAttnBackend(model_runner) self.graph_metadata = {} self.max_context_len = model_runner.model_config.context_len @@ -102,10 +107,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int() seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu) - if forward_batch.is_extend_in_batch: - seq_lens_list_cumsum[-1] = ( - (seq_lens_list_cumsum[-1] - 1) // tp_size + 1 - ) * tp_size self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum self.graph_mode = False @@ -133,6 +134,10 @@ def init_forward_metadata_capture_cuda_graph( metadata.block_tables = self.graph_metadata["block_tables"][:bs, :] metadata.seq_lens_cpu_list = seq_lens.cpu().int().tolist() + metadata.seq_lens = seq_lens + metadata.actual_seq_lengths_q = torch.tensor( + [1 + i * 1 for i in range(bs)], dtype=torch.int32, device=seq_lens.device + ) self.graph_metadata[bs] = metadata self.forward_metadata = metadata @@ -161,6 +166,8 @@ def init_forward_metadata_replay_cuda_graph( metadata.block_tables[:bs, max_seq_pages:].fill_(0) metadata.block_tables[bs:, :].fill_(0) + metadata.seq_lens[:bs].copy_(seq_lens[:bs]) + self.forward_metadata = metadata self.graph_mode = True @@ -168,6 +175,64 @@ def init_forward_metadata_replay_cuda_graph( def get_cuda_graph_seq_len_fill_value(self): return 0 + def forward_sparse( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: 
ForwardBatch, + save_kv_cache: bool = True, + # For multi_head latent attention + q_rope: Optional[torch.Tensor] = None, + k_rope: Optional[torch.Tensor] = None, + topk_indices: torch.Tensor = None, + ): + + is_prefill = forward_batch.forward_mode.is_extend() + + if save_kv_cache: + k = k.view(-1, layer.tp_k_head_num, self.kv_lora_rank) + k_rope = k_rope.view(-1, layer.tp_k_head_num, self.qk_rope_head_dim) + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, forward_batch.out_cache_loc, k, k_rope + ) + q_nope, q_pe = q, q_rope + k_nope, k_pe = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) + block_table = self.forward_metadata.block_tables + if is_prefill: + actual_seq_qlen = torch.cumsum(forward_batch.seq_lens, dim=0) + else: + if self.forward_metadata.actual_seq_lengths_q is None: + actual_seq_qlen = ( + torch.arange(1, q.shape[0] + 1).to(q.device).to(torch.int32) + ) + else: + actual_seq_qlen = self.forward_metadata.actual_seq_lengths_q + if self.forward_metadata.seq_lens_cpu_int is None: + actual_seq_lengths_kv = self.forward_metadata.seq_lens + else: + actual_seq_lengths_kv = self.forward_metadata.seq_lens_cpu_int + + attn_out = torch.ops.custom.npu_sparse_flash_attention( + query=q_nope, + key=k_nope, + value=k_nope, + query_rope=q_pe, + key_rope=k_pe, + sparse_indices=topk_indices, + scale_value=layer.scaling, + actual_seq_lengths_query=actual_seq_qlen.to(torch.int32), + actual_seq_lengths_kv=actual_seq_lengths_kv.to(q.device), + block_table=block_table, + sparse_block_size=1, + layout_query="TND", + layout_kv="PA_BSND", + sparse_mode=3, + ) + + return attn_out + def forward_extend( self, q, @@ -176,7 +241,23 @@ def forward_extend( layer: RadixAttention, forward_batch: ForwardBatch, save_kv_cache: bool = True, + # For multi_head latent attention + q_rope: Optional[torch.Tensor] = None, + k_rope: Optional[torch.Tensor] = None, + topk_indices: Optional[torch.Tensor] = None, ): + if topk_indices is not None: + return self.forward_sparse( + q, + k, + v, + layer, + forward_batch, + save_kv_cache, + q_rope, + k_rope, + topk_indices, + ) if not self.use_mla: if save_kv_cache: forward_batch.token_to_kv_pool.set_kv_buffer( @@ -437,10 +518,23 @@ def forward_decode( # For multi-head latent attention q_rope: Optional[torch.Tensor] = None, k_rope: Optional[torch.Tensor] = None, + topk_indices: Optional[torch.Tensor] = None, ): if is_mla_preprocess_enabled(): # MLAPO does saving kv_cache save_kv_cache = False + if topk_indices is not None: + return self.forward_sparse( + q, + k, + v, + layer, + forward_batch, + save_kv_cache, + q_rope, + k_rope, + topk_indices, + ) if self.graph_mode: return self.forward_decode_graph( diff --git a/python/sglang/srt/layers/attention/attention_registry.py b/python/sglang/srt/layers/attention/attention_registry.py index aa843685a9d..0ec435d6fb6 100644 --- a/python/sglang/srt/layers/attention/attention_registry.py +++ b/python/sglang/srt/layers/attention/attention_registry.py @@ -66,6 +66,13 @@ def create_ascend_backend(runner): return AscendAttnBackend(runner) +@register_attention_backend("nsa") +def create_nsa_backend(runner): + from sglang.srt.layers.attention.nsa_backend import NativeSparseAttnBackend + + return NativeSparseAttnBackend(runner) + + @register_attention_backend("triton") def create_triton_backend(runner): assert not runner.model_config.is_encoder_decoder, ( diff --git a/python/sglang/srt/layers/attention/base_attn_backend.py b/python/sglang/srt/layers/attention/base_attn_backend.py index 5b0377fcda4..fabdcb9e460 100644 --- 
a/python/sglang/srt/layers/attention/base_attn_backend.py +++ b/python/sglang/srt/layers/attention/base_attn_backend.py @@ -6,6 +6,7 @@ import torch if TYPE_CHECKING: + from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.speculative.spec_info import SpecInput @@ -115,3 +116,11 @@ def forward_extend( def support_triton(self): """Check if the current backend supports triton.""" return True + + def get_indexer_metadata( + self, + layer_id: int, + forward_batch: ForwardBatch, + ) -> Optional[BaseIndexerMetadata]: + """Get the indexer metadata. None means don't support indexer.""" + return None diff --git a/python/sglang/srt/layers/attention/hybrid_attn_backend.py b/python/sglang/srt/layers/attention/hybrid_attn_backend.py index f7f2c2193e5..7a78fd4d1c6 100644 --- a/python/sglang/srt/layers/attention/hybrid_attn_backend.py +++ b/python/sglang/srt/layers/attention/hybrid_attn_backend.py @@ -3,6 +3,7 @@ import torch from sglang.srt.layers.attention.base_attn_backend import AttentionBackend +from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_executor.model_runner import ModelRunner @@ -138,3 +139,9 @@ def forward_extend( return backend.forward_extend( q, k, v, layer, forward_batch, save_kv_cache, **kwargs ) + + def get_indexer_metadata( + self, layer_id: int, forward_batch: ForwardBatch + ) -> Optional[BaseIndexerMetadata]: + backend = self._select_backend(forward_batch.forward_mode) + return backend.get_indexer_metadata(layer_id, forward_batch) diff --git a/python/sglang/srt/layers/attention/npu_ops/mla_preprocess.py b/python/sglang/srt/layers/attention/npu_ops/mla_preprocess.py index 84efe2ce443..06a55254529 100644 --- a/python/sglang/srt/layers/attention/npu_ops/mla_preprocess.py +++ b/python/sglang/srt/layers/attention/npu_ops/mla_preprocess.py @@ -76,12 +76,14 @@ def __init__( self.rotary_emb = rotary_emb self.layer_id = layer_id self.has_preprocess_weights = False + self.dtype = None self.q_lora_rank = self.q_b_proj.input_size # 1536 self.kv_lora_rank = self.kv_a_layernorm.hidden_size # 512 self.num_local_heads = num_local_heads # tp self.qk_nope_head_dim = qk_nope_head_dim # 128 self.qk_rope_head_dim = qk_rope_head_dim # 64 + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim def preprocess_weights(self, hidden_states): self.dummy = torch.empty( @@ -236,7 +238,83 @@ def get_kv_cache_and_cache_idx(self, forward_batch): slot_mapping = forward_batch.out_cache_loc.to(dtype=torch.int32) return k_cache, v_cache, slot_mapping - def forward(self, positions, hidden_states, forward_batch, zero_allocator): + def forward_absorb_prepare_npu_rms_norm_cache( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch, + zero_allocator, + ): + bsz, _ = hidden_states.view(-1, hidden_states.shape[-1]).shape + self.dtype = hidden_states.dtype + self.cos, self.sin = self.get_sin_cos(positions) + self.kvCache, self.kvCacheRope, self.slotmapping = ( + self.get_kv_cache_and_cache_idx(forward_batch) + ) + + if not self.has_preprocess_weights: + self.has_preprocess_weights = True + + cos, sin = self.cos, self.sin + + if self.q_lora_rank is not None: + fused_qkv_a_proj_out = self.qkv_a_proj(hidden_states)[0] + q_lowrank, 
latent_cache = fused_qkv_a_proj_out.split( + [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1 + ) + q = self.q_a_layernorm(q_lowrank) + q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim) + else: + q = self.q_proj(hidden_states)[0].view( + -1, self.num_local_heads, self.qk_head_dim + ) + latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0] + + q_nope, q_pe = torch.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) # b*s,n,d + + q_nope = q_nope.view(-1, self.num_local_heads, self.qk_nope_head_dim) + q_nope = torch.matmul(q_nope.transpose(0, 1), self.w_kc).transpose(0, 1) + + q_pe = q_pe.view(-1, self.num_local_heads, 1, self.qk_rope_head_dim) + cos = cos.view(-1, 1, 1, self.qk_rope_head_dim) + sin = sin.view(-1, 1, 1, self.qk_rope_head_dim) + q_pe = torch_npu.npu_interleave_rope(q_pe, cos, sin) # (B,N,S,D) + q_pe = q_pe.view(cos.shape[0], self.num_local_heads, self.qk_rope_head_dim) + + latent_cache = latent_cache.view( + -1, 1, 1, self.kv_lora_rank + self.qk_rope_head_dim + ) # (B*S,N,1,D) + + cache_mode = "PA_BNSD" + self.kvCache = self.kvCache.view( + -1, + forward_batch.attn_backend.page_size, + 1, + forward_batch.attn_backend.kv_lora_rank, + ) + self.kvCacheRope = self.kvCacheRope.view( + -1, + forward_batch.attn_backend.page_size, + 1, + forward_batch.attn_backend.qk_rope_head_dim, + ) + k_rope, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache( + latent_cache, + self.kv_a_layernorm.weight, + cos, + sin, + self.slotmapping.to(torch.int64), + self.kvCacheRope, + self.kvCache, + epsilon=self.kv_a_layernorm.variance_epsilon, + cache_mode=cache_mode, + ) + + return (q_pe, k_rope, q_nope, k_nope, forward_batch, zero_allocator, positions) + + def forward_mlapo(self, positions, hidden_states, forward_batch, zero_allocator): input_dtype = hidden_states.dtype if not self.has_preprocess_weights: self.preprocess_weights(hidden_states) @@ -298,3 +376,18 @@ def forward(self, positions, hidden_states, forward_batch, zero_allocator): zero_allocator, positions, ) + + def forward(self, positions, hidden_states, forward_batch, zero_allocator): + _is_w8a8 = ( + hasattr(self.qkv_a_proj.quant_method, "quantization_config") + and self.qkv_a_proj.quant_method.quantization_config.get_name() + == "w8a8_int8" + ) + if _is_w8a8: + return self.forward_mlapo( + positions, hidden_states, forward_batch, zero_allocator + ) + else: + return self.forward_absorb_prepare_npu_rms_norm_cache( + positions, hidden_states, forward_batch, zero_allocator + ) diff --git a/python/sglang/srt/layers/attention/nsa/dequant_k_cache.py b/python/sglang/srt/layers/attention/nsa/dequant_k_cache.py new file mode 100644 index 00000000000..b6c2269f5b2 --- /dev/null +++ b/python/sglang/srt/layers/attention/nsa/dequant_k_cache.py @@ -0,0 +1,163 @@ +import torch +import triton +import triton.language as tl + +from sglang.srt.layers.attention.nsa.utils import NSA_DEQUANT_K_CACHE_FAST + + +def dequantize_k_cache(quant_k_cache): + if NSA_DEQUANT_K_CACHE_FAST: + return _dequantize_k_cache_fast_wrapped(quant_k_cache) + else: + return _dequantize_k_cache_slow(quant_k_cache) + + +def _dequantize_k_cache_slow( + quant_k_cache: torch.Tensor, # (num_blocks, block_size, 1, bytes_per_token) + dv: int = 512, + tile_size: int = 128, + d: int = 576, +) -> torch.Tensor: + """ + De-quantize the k-cache + """ + assert dv % tile_size == 0 + num_tiles = dv // tile_size + num_blocks, block_size, h_k, _ = quant_k_cache.shape + assert h_k == 1 + result = torch.empty( + (num_blocks, block_size, d), 
dtype=torch.bfloat16, device=quant_k_cache.device + ) + + quant_k_cache = quant_k_cache.view(num_blocks, block_size, -1) + + input_nope = quant_k_cache[..., :dv] + input_scale = quant_k_cache[..., dv : dv + num_tiles * 4].view(torch.float32) + input_rope = quant_k_cache[..., dv + num_tiles * 4 :].view(torch.bfloat16) + result[..., dv:] = input_rope + + for tile_idx in range(0, num_tiles): + cur_nope = input_nope[ + ..., tile_idx * tile_size : (tile_idx + 1) * tile_size + ].to(torch.float32) + cur_scales = input_scale[..., tile_idx].unsqueeze(-1) + result[..., tile_idx * tile_size : (tile_idx + 1) * tile_size] = ( + cur_nope * cur_scales + ) + + result = result.view(num_blocks, block_size, 1, d) + return result + + +def _dequantize_k_cache_fast_wrapped( + quant_k_cache: torch.Tensor, + dv: int = 512, + tile_size: int = 128, +) -> torch.Tensor: + # TODO the final API may be 2D instead of 4D, thus we convert them here + num_blocks, block_size, _, dim_quant = quant_k_cache.shape + assert dv == 512 + assert dim_quant == 656 + assert tile_size == 128 + quant_k_cache = quant_k_cache.view((-1, dim_quant)) + + output = _dequantize_k_cache_fast(quant_k_cache) + + return output.view(num_blocks, block_size, 1, -1) + + +def _dequantize_k_cache_fast(quant_k_cache, group_size: int = 128): + num_tokens, dim_quant = quant_k_cache.shape + + assert quant_k_cache.dtype == torch.float8_e4m3fn + dim_nope = 512 + dim_rope = 64 + num_tiles = dim_nope // group_size + assert dim_quant == 656 + + output = torch.empty( + (num_tokens, dim_nope + dim_rope), + dtype=torch.bfloat16, + device=quant_k_cache.device, + ) + + num_blocks_per_token = triton.cdiv(dim_nope + dim_rope, group_size) + assert num_blocks_per_token == 5 + + assert dim_nope % group_size == 0 + NUM_NOPE_BLOCKS = dim_nope // group_size + + input_nope_q = quant_k_cache[:, :dim_nope] + input_nope_s = quant_k_cache[:, dim_nope : dim_nope + num_tiles * 4].view( + torch.float32 + ) + input_rope = quant_k_cache[:, dim_nope + num_tiles * 4 :].view(torch.bfloat16) + + _dequantize_k_cache_fast_kernel[(num_tokens, num_blocks_per_token)]( + output, + input_nope_q, + input_nope_s, + input_rope, + output.stride(0), + input_nope_q.stride(0), + input_nope_s.stride(0), + input_rope.stride(0), + NUM_NOPE_BLOCKS=NUM_NOPE_BLOCKS, + GROUP_SIZE=group_size, + DIM_NOPE=dim_nope, + DIM_ROPE=dim_rope, + ) + + return output + + +@triton.jit +def _dequantize_k_cache_fast_kernel( + output_ptr, + input_nope_q_ptr, + input_nope_s_ptr, + input_rope_ptr, + output_stride_0: int, + input_nope_q_stride_0: int, + input_nope_s_stride_0: int, + input_rope_stride_0: int, + NUM_NOPE_BLOCKS: tl.constexpr, + GROUP_SIZE: tl.constexpr, + DIM_NOPE: tl.constexpr, + DIM_ROPE: tl.constexpr, +): + token_id = tl.program_id(0) + raw_block_id = tl.program_id(1) + + if raw_block_id < NUM_NOPE_BLOCKS: + # a. dequant nope + effective_block_id = raw_block_id + + offs_q = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE) + mask = offs_q < DIM_NOPE + ptr_q = input_nope_q_ptr + token_id * input_nope_q_stride_0 + offs_q + ptr_s = input_nope_s_ptr + token_id * input_nope_s_stride_0 + effective_block_id + + y_q = tl.load(ptr_q, mask=mask, other=0.0).to(tl.float32) + y_s = tl.load(ptr_s) + + y = (y_q * y_s).to(output_ptr.dtype.element_ty) + + dst_ptr = output_ptr + token_id * output_stride_0 + offs_q + tl.store(dst_ptr, y, mask=mask) + else: + # b. 
copy rope + effective_block_id = raw_block_id - NUM_NOPE_BLOCKS + + offs = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE) + mask = offs < DIM_ROPE + + src_ptr = input_rope_ptr + token_id * input_rope_stride_0 + offs + dst_ptr = output_ptr + token_id * output_stride_0 + DIM_NOPE + offs + + data = tl.load(src_ptr, mask=mask).to(tl.bfloat16) + tl.store(dst_ptr, data, mask=mask) + + +if __name__ == "__main__": + raise Exception("UT is in quant_k_cache.py") diff --git a/python/sglang/srt/layers/attention/nsa/index_buf_accessor.py b/python/sglang/srt/layers/attention/nsa/index_buf_accessor.py new file mode 100644 index 00000000000..d887cfddd49 --- /dev/null +++ b/python/sglang/srt/layers/attention/nsa/index_buf_accessor.py @@ -0,0 +1,354 @@ +from typing import TYPE_CHECKING + +import torch +import triton +import triton.language as tl + +if TYPE_CHECKING: + from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool + +""" +k: data, 128 item per token, fp8 +s: scale, 1 item per token, fp32 +""" + + +class GetK: + @classmethod + def execute(cls, *args, **kwargs): + return cls.torch_fast(*args, **kwargs) + + @classmethod + def slow( + cls, pool: "NSATokenToKVPool", buf, seq_len: int, page_indices: torch.Tensor + ): + num_pages = (seq_len + pool.page_size - 1) // pool.page_size + seq_len_ = num_pages * pool.page_size + index_k_fp8 = torch.empty( + (seq_len_, pool.index_head_dim), + dtype=torch.uint8, + device=pool.device, + ) + for i in range(num_pages): + page_index = page_indices[i] + index_k_fp8[i * pool.page_size : (i + 1) * pool.page_size] = buf[ + page_index + ][: pool.page_size * pool.index_head_dim].view(-1, pool.index_head_dim) + + return index_k_fp8[:seq_len] + + @classmethod + def torch_fast( + cls, pool: "NSATokenToKVPool", buf, seq_len: int, page_indices: torch.Tensor + ): + """ + :param page_indices: (num_pages,), int32 + :return: (seq_len, index_head_dim), uint8 + """ + + # can handle per 128B instead of per element + + # page_indices: (num_pages,), element := a page index + buf_numel_per_page = buf.shape[1] + + num_k_bytes_per_page = pool.page_size * pool.index_head_dim + num_k_bytes_per_token = pool.index_head_dim + + # buf: (num_pages, page_size 64 * head_dim 128 + page_size 64 * fp32_nbytes 4), uint8 + # flat_buf: (whatever,), uint8 + flat_buf = buf.flatten() + + # flat_indices: (num_pages, num_k_bytes_per_page), int32, element := an index into flat_buf that we want to access + flat_indices = (page_indices * buf_numel_per_page)[:, None] + torch.arange( + num_k_bytes_per_page, dtype=torch.int32, device="cuda" + )[None, :] + flat_indices = flat_indices.flatten()[: seq_len * num_k_bytes_per_token] + + out = flat_buf[flat_indices] + return out.view(-1, 128) + + +class GetS: + @classmethod + def execute(cls, *args, **kwargs): + return cls.torch_fast(*args, **kwargs) + + @classmethod + def slow( + cls, pool: "NSATokenToKVPool", buf, seq_len: int, page_indices: torch.Tensor + ): + num_pages = (seq_len + pool.page_size - 1) // pool.page_size + seq_len_ = num_pages * pool.page_size + assert pool.index_head_dim // pool.quant_block_size == 1 + index_k_scale_fp8 = torch.empty( + (seq_len_, 4), + dtype=torch.uint8, + device=pool.device, + ) + for i in range(num_pages): + page_index = page_indices[i] + index_k_scale_fp8[i * pool.page_size : (i + 1) * pool.page_size] = buf[ + page_index + ][pool.page_size * pool.index_head_dim :].view(-1, 4) + return index_k_scale_fp8[:seq_len] + + @classmethod + def torch_fast( + cls, pool: "NSATokenToKVPool", buf, seq_len: int, page_indices: 
torch.Tensor + ): + """ + :param page_indices: (num_pages,), int32 + :return: (seq_len, index_head_dim // quant_block_size), uint8 + """ + buf_numel_per_page = buf.shape[1] + + num_s_bytes_per_page = buf.shape[1] - pool.page_size * pool.index_head_dim + num_s_bytes_per_token = pool.index_head_dim // pool.quant_block_size * 4 + s_offset_in_page = pool.page_size * pool.index_head_dim + + flat_buf = buf.flatten() + flat_indices = ( + (page_indices * buf_numel_per_page)[:, None] + + torch.arange(num_s_bytes_per_page, dtype=torch.int32, device="cuda")[ + None, : + ] + + s_offset_in_page + ) + flat_indices = flat_indices.flatten()[: seq_len * num_s_bytes_per_token] + + out = flat_buf[flat_indices] + return out.view(-1, 4) + + +class SetK: + @classmethod + def execute(cls, *args, buf, **kwargs): + return cls.torch_fast(*args, **kwargs, buf=buf) + + @classmethod + def slow( + cls, + pool: "NSATokenToKVPool", + buf: torch.Tensor, + loc: torch.Tensor, + index_k: torch.Tensor, + ): + for i in range(len(loc)): + page_index = loc[i] // pool.page_size + offset = loc[i] % pool.page_size + buf[ + page_index, + offset * pool.index_head_dim : (offset + 1) * pool.index_head_dim, + ] = index_k[i].view(torch.uint8) + + @classmethod + def torch_fast( + cls, + pool: "NSATokenToKVPool", + buf: torch.Tensor, + loc: torch.Tensor, + index_k: torch.Tensor, + ): + (num_tokens_to_write,) = loc.shape + buf_numel_per_page = buf.shape[1] + num_k_bytes_per_token = pool.index_head_dim + + # loc: (num_tokens_to_write,), int32, element := the token index to write to + loc_page_index = loc // pool.page_size + loc_token_offset_in_page = loc % pool.page_size + + flat_buf = buf.flatten() + flat_indices = ( + (loc_page_index * buf_numel_per_page)[:, None] + + (loc_token_offset_in_page * num_k_bytes_per_token)[:, None] + + torch.arange(num_k_bytes_per_token, dtype=torch.int32, device="cuda")[ + None, : + ] + ) + num_k_bytes_total = num_tokens_to_write * num_k_bytes_per_token + flat_indices = flat_indices.flatten()[:num_k_bytes_total] + flat_buf[flat_indices] = index_k.view(torch.uint8).flatten() + + +class SetS: + @classmethod + def execute(cls, *args, buf, **kwargs): + return cls.torch_fast(*args, **kwargs, buf=buf) + + @classmethod + def slow( + cls, + pool: "NSATokenToKVPool", + buf: torch.Tensor, + loc: torch.Tensor, + index_k_scale: torch.Tensor, + ): + for i in range(len(loc)): + page_index = loc[i] // pool.page_size + offset = loc[i] % pool.page_size + start = pool.page_size * pool.index_head_dim + buf[page_index, start + offset * 4 : start + (offset + 1) * 4] = ( + index_k_scale[i].view(torch.uint8) + ) + + @classmethod + def torch_fast( + cls, + pool: "NSATokenToKVPool", + buf: torch.Tensor, + loc: torch.Tensor, + index_k_scale: torch.Tensor, + ): + (num_tokens_to_write,) = loc.shape + buf_numel_per_page = buf.shape[1] + num_s_bytes_per_token = 4 + s_offset_in_page = pool.page_size * pool.index_head_dim + + # loc: (num_tokens_to_write,), int32, element := the token index to write to + loc_page_index = loc // pool.page_size + loc_token_offset_in_page = loc % pool.page_size + + flat_buf = buf.flatten() + flat_indices = ( + (loc_page_index * buf_numel_per_page)[:, None] + + s_offset_in_page + + (loc_token_offset_in_page * num_s_bytes_per_token)[:, None] + + torch.arange(num_s_bytes_per_token, dtype=torch.int32, device="cuda")[ + None, : + ] + ) + number_s_bytes_total = num_tokens_to_write * num_s_bytes_per_token + flat_indices = flat_indices.flatten()[:number_s_bytes_total] + flat_buf[flat_indices] = 
index_k_scale.view(torch.uint8).flatten() + + +class SetKAndS: + @classmethod + def execute(cls, *args, buf, **kwargs): + if 0: + # print("SetK, SetS comparison test") + buf_cloned = buf.clone() + cls.vanilla(*args, **kwargs, buf=buf) + cls.triton(*args, **kwargs, buf=buf_cloned) + + def _clear_token_0(target): + target[0, :128] = target[0, 64 * 128 : 64 * 128 + 4] = 0 + + _clear_token_0(buf) + _clear_token_0(buf_cloned) + + assert torch.all( + buf == buf_cloned + ), f"{buf=} {buf_cloned=} {kwargs['loc'].to_list()=}" + return + + cls.triton(*args, **kwargs, buf=buf) + + @classmethod + def vanilla(cls, pool, buf, loc, index_k, index_k_scale): + SetK.execute(pool=pool, buf=buf, loc=loc, index_k=index_k) + SetS.execute(pool=pool, buf=buf, loc=loc, index_k_scale=index_k_scale) + + @classmethod + def triton(cls, pool, buf, loc, index_k, index_k_scale): + _set_k_and_s_triton( + buf=buf, + loc=loc, + index_k=index_k, + index_k_scale=index_k_scale, + page_size=pool.page_size, + ) + + +def _set_k_and_s_triton( + buf: torch.Tensor, + loc: torch.Tensor, + index_k: torch.Tensor, + index_k_scale: torch.Tensor, + page_size: int, +): + """ + :param buf: (num_pages, page_size 64 * (128B data + 4B scale)), uint8 + :param loc: (num_tokens_to_write,), int, element := the token index to write to + :param index_k: (num_tokens_to_write, 128 elem), fp8 + :param index_k_scale: (num_tokens_to_write, 1 elem), fp32 + :return: + """ + num_pages, buf_numel_per_page = buf.shape + (num_tokens_to_write,) = loc.shape + num_tokens_to_write_, index_head_dim = index_k.shape + num_tokens_to_write__, scale_dim = index_k_scale.shape + assert buf_numel_per_page == 64 * (128 + 4) + assert num_tokens_to_write == num_tokens_to_write_ == num_tokens_to_write__ + assert index_head_dim == 128 + assert scale_dim == 1 + assert page_size == 64 + + assert buf.dtype == torch.uint8 + assert loc.dtype == torch.int64, f"{loc.dtype=}" # can be int32 + assert index_k.dtype == torch.float8_e4m3fn + assert index_k_scale.dtype == torch.float32 + + assert buf.is_contiguous() + assert loc.is_contiguous() + assert index_k.is_contiguous() + assert index_k_scale.is_contiguous() + + buf_fp8 = buf.view(torch.float8_e4m3fn) + buf_fp32 = buf.view(torch.float32) + + _set_k_and_s_triton_kernel[(num_tokens_to_write,)]( + buf_fp8, + buf_fp32, + loc, + index_k, + index_k_scale, + index_k.stride(0), + PAGE_SIZE=page_size, + BUF_NUMEL_PER_PAGE=buf_numel_per_page, + NUM_K_ELEMS_PER_TOKEN=index_head_dim, + S_OFFSET_NBYTES_IN_PAGE=page_size * index_head_dim, + ) + + +@triton.jit +def _set_k_and_s_triton_kernel( + buf_fp8_ptr, + buf_fp32_ptr, + loc_ptr, + index_k_ptr, + index_k_scale_ptr, + index_k_ptr_stride_0, + PAGE_SIZE: tl.constexpr, + BUF_NUMEL_PER_PAGE: tl.constexpr, + NUM_K_ELEMS_PER_TOKEN: tl.constexpr, + S_OFFSET_NBYTES_IN_PAGE: tl.constexpr, +): + token_id = tl.program_id(0) + + loc = tl.load(loc_ptr + token_id) + + in_k_offsets = token_id * index_k_ptr_stride_0 + tl.arange(0, NUM_K_ELEMS_PER_TOKEN) + + # no need for `mask`, since we read 128B for k and 4B for scale, both pow of 2 + k = tl.load(index_k_ptr + in_k_offsets) + k_scale = tl.load(index_k_scale_ptr + token_id) + + loc_page_index = loc // PAGE_SIZE + loc_token_offset_in_page = loc % PAGE_SIZE + + out_k_offsets = ( + loc_page_index * BUF_NUMEL_PER_PAGE + + loc_token_offset_in_page * NUM_K_ELEMS_PER_TOKEN + + tl.arange(0, NUM_K_ELEMS_PER_TOKEN) + ) + + # "//4" b/c it is fp32 instead of uint8 + out_s_offset = ( + loc_page_index * BUF_NUMEL_PER_PAGE // 4 + + S_OFFSET_NBYTES_IN_PAGE // 4 + + 
loc_token_offset_in_page + ) + + tl.store(buf_fp8_ptr + out_k_offsets, k) + tl.store(buf_fp32_ptr + out_s_offset, k_scale) diff --git a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py new file mode 100644 index 00000000000..2bc6771ab8f --- /dev/null +++ b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py @@ -0,0 +1,761 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple + +import torch +import torch.nn.functional as F +from einops import rearrange +from torch import nn + +from sglang.srt.custom_op import CustomOp +from sglang.srt.utils import add_prefix, align, is_cuda, is_hip, is_npu + +if is_cuda(): + import deep_gemm + +from sglang.srt.layers.attention.nsa.utils import NSA_DUAL_STREAM, NSA_USE_REAL_INDEXER +from sglang.srt.layers.dp_attention import get_attention_tp_group +from sglang.srt.layers.linear import ReplicatedLinear +from sglang.srt.layers.quantization import deep_gemm_wrapper +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.rotary_embedding import get_rope_wrapper +from sglang.srt.managers.schedule_batch import global_server_args_dict +from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode +from sglang.srt.model_executor.forward_batch_info import ForwardBatch + +if TYPE_CHECKING: + from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool + +DUAL_STREAM_TOKEN_THRESHOLD = 1024 if is_cuda() else 0 + + +class BaseIndexerMetadata(ABC): + @abstractmethod + def get_seqlens_int32(self) -> torch.Tensor: + """ + Return: (batch_size,) int32 tensor + """ + + @abstractmethod + def get_page_table_64(self) -> torch.Tensor: + """ + Return: (batch_size, num_blocks) int32, page table. + The page size of the table is 64. + """ + + @abstractmethod + def get_seqlens_expanded(self) -> torch.Tensor: + """ + Return: (sum_extend_seq_len,) int32 tensor + """ + + @abstractmethod + def topk_transform( + self, + logits: torch.Tensor, + topk: int, + ) -> torch.Tensor: + """ + Perform topk selection on the logits and possibly transform the result. + + NOTE that attention backend may override this function to do some + transformation, which means the result of this topk_transform may not + be the topk indices of the input logits. + + Return: Anything, since it will be passed to the attention backend + for further processing on sparse attention computation. + Don't assume it is the topk indices of the input logits. + """ + + +def rotate_activation(x: torch.Tensor) -> torch.Tensor: + assert x.dtype == torch.bfloat16 + from fast_hadamard_transform import hadamard_transform + + hidden_size = x.size(-1) + assert ( + hidden_size & (hidden_size - 1) + ) == 0, "Hidden size must be a power of 2 for Hadamard transform." + return hadamard_transform(x, scale=hidden_size**-0.5) + + +class V32LayerNorm(nn.Module): + """ + Layer Normalization. 
+ """ + + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: torch.Tensor): + return F.layer_norm( + x.float(), (self.dim,), self.weight, self.bias, self.eps + ).type_as(x) + + +class Indexer(CustomOp): + def __init__( + self, + hidden_size: int, + index_n_heads: int, + index_head_dim: int, + rope_head_dim: int, + index_topk: int, + q_lora_rank: int, + max_position_embeddings: int, + rope_theta: float, + layer_id: int, + scale_fmt: Optional[str], + block_size: int = 128, + rope_scaling: Optional[Dict[str, Any]] = None, + prefix: str = "", + quant_config: Optional[QuantizationConfig] = None, + alt_stream: Optional[torch.cuda.Stream] = None, + ): + super().__init__() + self.hidden_size = hidden_size + self.n_heads = index_n_heads + self.head_dim = index_head_dim + self.rope_head_dim = rope_head_dim + self.index_topk = index_topk + self.q_lora_rank = q_lora_rank + self.layer_id = layer_id + self.alt_stream = alt_stream + if is_cuda(): + self.sm_count = deep_gemm.get_num_sms() + self.half_device_sm_count = align(self.sm_count // 2, 8) + + self.wq_b = ReplicatedLinear( + self.q_lora_rank, + self.n_heads * self.head_dim, + bias=False, + quant_config=quant_config, + prefix=add_prefix("wq_b", prefix), + ) + self.wk = ReplicatedLinear( + self.hidden_size, + self.head_dim, + bias=False, + quant_config=quant_config, + prefix=add_prefix("wk", prefix), + ) + self.k_norm = V32LayerNorm(self.head_dim) + # NOTE: weight_proj is not quantized + self.weights_proj = ReplicatedLinear( + self.hidden_size, + self.n_heads, + bias=False, + prefix=add_prefix("weights_proj", prefix), + ) + self.rotary_emb = get_rope_wrapper( + rope_head_dim, + rotary_dim=rope_head_dim, + max_position=max_position_embeddings, + base=rope_theta, # type: ignore + rope_scaling=rope_scaling, + is_neox_style=False, + device=global_server_args_dict["device"], + ) + self.block_size = block_size + self.scale_fmt = scale_fmt + self.softmax_scale = self.head_dim**-0.5 + + def _forward_fake( + self, + x: torch.Tensor, + q_lora: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + layer_id: int, + ): + bs = x.shape[0] + assert self.index_topk == 2048 + ans = torch.arange(0, self.index_topk, dtype=torch.int32, device=x.device)[ + None, ... 
+ ].repeat(bs, 1) + if forward_batch.forward_mode.is_extend(): + assert ( + forward_batch.extend_seq_lens_cpu is not None + and forward_batch.seq_lens_cpu is not None + ) + which = 0 + for i, (kv_len, qo_len) in enumerate( + zip( + forward_batch.seq_lens_cpu.tolist(), + forward_batch.extend_seq_lens_cpu, + strict=True, + ) + ): + for j in range(kv_len - qo_len, kv_len): + ans[which, j + 1 :] = -1 + which += 1 + assert which == ans.shape[0] + else: + assert forward_batch.seq_lens_cpu is not None + for i, seq_len in enumerate(forward_batch.seq_lens_cpu.tolist()): + ans[i, seq_len:] = -1 + + return ans + + def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor): + weights, _ = self.weights_proj(x) + weights = weights * self.n_heads**-0.5 + weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale + return weights + + def _get_q_k_bf16( + self, + q_lora: torch.Tensor, + x: torch.Tensor, + positions: torch.Tensor, + enable_dual_stream: bool, + ): + + if enable_dual_stream: + current_stream = torch.cuda.current_stream() + self.alt_stream.wait_stream(current_stream) + + with deep_gemm_wrapper.configure_deep_gemm_num_sms( + self.half_device_sm_count + ): + query, _ = self.wq_b(q_lora) + query = rearrange(query, "l (h d) -> l h d", d=self.head_dim) + q_rope, _ = torch.split( + query, + [self.rope_head_dim, self.head_dim - self.rope_head_dim], + dim=-1, + ) + with torch.cuda.stream(self.alt_stream): + # TODO we should also put DeepGEMM half SM here? + key, _ = self.wk(x) + key = self.k_norm(key) + + k_rope, _ = torch.split( + key, + [self.rope_head_dim, self.head_dim - self.rope_head_dim], + dim=-1, + ) + + current_stream.wait_stream(self.alt_stream) + else: + query, _ = self.wq_b(q_lora) + query = rearrange(query, "l (h d) -> l h d", d=self.head_dim) + + q_rope, _ = torch.split( + query, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1 + ) + + key, _ = self.wk(x) + key = self.k_norm(key) + k_rope, _ = torch.split( + key, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1 + ) + + q_rope, k_rope = self.rotary_emb(positions, q_rope, k_rope) + + query[..., : self.rope_head_dim] = q_rope + key[..., : self.rope_head_dim] = k_rope + + if enable_dual_stream: + current_stream = torch.cuda.current_stream() + self.alt_stream.wait_stream(current_stream) + query = rotate_activation(query) + + with torch.cuda.stream(self.alt_stream): + key = rotate_activation(key) + current_stream.wait_stream(self.alt_stream) + else: + query = rotate_activation(query) + key = rotate_activation(key) + + return query, key + + def _get_topk_paged( + self, + forward_batch: ForwardBatch, + layer_id: int, + q_fp8: torch.Tensor, + weights: torch.Tensor, + metadata: BaseIndexerMetadata, + ) -> torch.Tensor: + if TYPE_CHECKING: + assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool) + + page_size = forward_batch.token_to_kv_pool.page_size + # NOTE(dark): blocksize = 64 is hardcoded in deep_gemm + assert page_size == 64, "only support page size 64" + + # NOTE(dark): this support extend/decode/decode+graph + block_tables = metadata.get_page_table_64() + + max_seq_len = block_tables.shape[1] * page_size + kv_cache_fp8 = forward_batch.token_to_kv_pool.get_index_k_with_scale_buffer( + layer_id=layer_id + ) + + blocksize = page_size + seqlens_32 = metadata.get_seqlens_int32() + # NOTE(dark): 132 is SM count on H200/B200, not magic number + schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata( + seqlens_32, blocksize, self.sm_count + ) + + assert len(q_fp8.shape) == 3 + 
q_fp8 = q_fp8.unsqueeze(1) # the next_n dim is 1 now + assert len(kv_cache_fp8.shape) == 2 + block_kv = 64 + num_heads_kv = 1 + head_dim_with_sf = 132 + kv_cache_fp8 = kv_cache_fp8.view( + kv_cache_fp8.shape[0], block_kv, num_heads_kv, head_dim_with_sf + ) + assert len(weights.shape) == 3 + weights = weights.squeeze(2) + + logits = deep_gemm.fp8_paged_mqa_logits( + q_fp8, + kv_cache_fp8, + weights, + seqlens_32, + block_tables, + schedule_metadata, + max_seq_len, + clean_logits=False, + ) + + # NOTE(dark): logits should be cleaned in topk_transform + topk_result = metadata.topk_transform(logits, self.index_topk) + return topk_result + + def _get_topk_ragged( + self, + forward_batch: ForwardBatch, + layer_id: int, + q_fp8: torch.Tensor, + weights: torch.Tensor, + metadata: BaseIndexerMetadata, + ) -> torch.Tensor: + if TYPE_CHECKING: + assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool) + + page_size = forward_batch.token_to_kv_pool.page_size + assert page_size == 64, "only support page size 64" + assert len(weights.shape) == 3 + weights = weights.squeeze(-1) + k_fp8_list = [] + k_scale_list = [] + ks_list = [] + offset = 0 + + block_tables = metadata.get_page_table_64() + + assert ( + forward_batch.seq_lens_cpu is not None + and forward_batch.extend_seq_lens_cpu is not None + ) + + for i in range(forward_batch.batch_size): + seq_len = forward_batch.seq_lens_cpu[i].item() + assert isinstance(seq_len, int) + k_fp8 = forward_batch.token_to_kv_pool.get_index_k_continuous( + layer_id, + seq_len, + block_tables[i], + ) + k_scale = forward_batch.token_to_kv_pool.get_index_k_scale_continuous( + layer_id, + seq_len, + block_tables[i], + ) + extend_seq_len = forward_batch.extend_seq_lens_cpu[i] + ks = torch.full((extend_seq_len,), offset, dtype=torch.int32, device="cuda") + k_fp8_list.append(k_fp8) + k_scale_list.append(k_scale) + ks_list.append(ks) + offset += extend_seq_len + + k_fp8 = torch.cat(k_fp8_list, dim=0).view(torch.float8_e4m3fn) + k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).squeeze(-1) + kv_fp8 = (k_fp8, k_scale) + ks = torch.cat(ks_list, dim=0) + seq_lens_expanded = metadata.get_seqlens_expanded() + ke = ks + seq_lens_expanded + + logits = deep_gemm.fp8_mqa_logits( + q_fp8, + kv_fp8, + weights, + ks, + ke, + clean_logits=False, + ) + + assert logits.shape[0] == len(seq_lens_expanded) + topk_result = metadata.topk_transform(logits, self.index_topk) + + return topk_result + + def forward_indexer_bs_1( + self, + q_fp8: torch.Tensor, + weights: torch.Tensor, + forward_batch: ForwardBatch, + topk: int, + layer_id: int, + ) -> Optional[torch.Tensor]: + if not is_npu(): + from sglang.srt.layers.attention.nsa.tilelang_kernel import fp8_index + + page_size = forward_batch.token_to_kv_pool.page_size + assert page_size == 64, "only support page size 64" + + assert len(weights.shape) == 3 + weights = weights.squeeze(-1) + + # logits = deep_gemm.fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke) + k_fp8_list = [] + k_scale_list = [] + + topk_indices_list = [] + + block_tables = forward_batch.req_to_token_pool.req_to_token[ + forward_batch.req_pool_indices, : + ] + strided_indices = torch.arange( + 0, block_tables.shape[-1], page_size, device="cuda" + ) + block_tables = block_tables[:, strided_indices] // page_size + + q_len_start = 0 + + for i in range(forward_batch.batch_size): + seq_len = forward_batch.seq_lens[i].item() + q_len = ( + forward_batch.extend_seq_lens_cpu[i] + if forward_batch.forward_mode.is_extend() + else 1 + ) + q_len_end = q_len_start + q_len + + 
q_fp8_partial = q_fp8[q_len_start:q_len_end] + q_fp8_partial = q_fp8_partial.unsqueeze(0).contiguous() + + weights_partial = weights[q_len_start:q_len_end] + weights_partial = weights_partial.squeeze(-1).unsqueeze(0).contiguous() + + k_fp8 = forward_batch.token_to_kv_pool.get_index_k_continuous( + layer_id, + seq_len, + block_tables[i], + ) + k_scale = forward_batch.token_to_kv_pool.get_index_k_scale_continuous( + layer_id, + seq_len, + block_tables[i], + ) + + k_fp8 = k_fp8.view(torch.float8_e4m3fn).unsqueeze(0).contiguous() + k_scale = k_scale.view(torch.float32).squeeze(-1).unsqueeze(0).contiguous() + + index_score = fp8_index( + q_fp8_partial, + weights_partial, + k_fp8, + k_scale, + ) + end_pos = seq_len + topk_indices = index_score.topk(min(topk, end_pos), dim=-1)[1].squeeze(0) + + pad_len = align(topk_indices.shape[-1], 2048) - topk_indices.shape[-1] + topk_indices = torch.nn.functional.pad( + topk_indices, (0, pad_len), "constant", -1 + ) + + topk_indices_list.append(topk_indices) + + q_len_start = q_len_end + + topk_indices = torch.cat(topk_indices_list, dim=0) + + return topk_indices + + def forward_indexer( + self, + q_fp8: torch.Tensor, + weights: torch.Tensor, + forward_batch: ForwardBatch, + topk: int, + layer_id: int, + ) -> Optional[torch.Tensor]: + return self.forward_indexer_bs_1(q_fp8, weights, forward_batch, topk, layer_id) + + def _forward( + self, + x: torch.Tensor, + q_lora: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + layer_id: int, + ) -> Optional[torch.Tensor]: + if not is_npu(): + from sglang.srt.layers.attention.nsa.tilelang_kernel import act_quant + + if TYPE_CHECKING: + assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool) + + metadata = forward_batch.attn_backend.get_indexer_metadata( + layer_id, forward_batch + ) + + enable_dual_stream = ( + NSA_DUAL_STREAM + and self.alt_stream is not None + and get_is_capture_mode() + and q_lora.shape[0] > 0 + and q_lora.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD + ) + + # skip NSA if attention backend choose to skip this batch + if metadata is None: + return None + + if not NSA_USE_REAL_INDEXER: # temporary + return self._forward_fake(x, q_lora, positions, forward_batch, layer_id) + + query, key = self._get_q_k_bf16(q_lora, x, positions, enable_dual_stream) + + if enable_dual_stream: + current_stream = torch.cuda.current_stream() + self.alt_stream.wait_stream(current_stream) + + q_fp8, q_scale = act_quant(query, self.block_size, self.scale_fmt) + with torch.cuda.stream(self.alt_stream): + k_fp8, k_scale = act_quant(key, self.block_size, self.scale_fmt) + current_stream.wait_stream(self.alt_stream) + else: + q_fp8, q_scale = act_quant(query, self.block_size, self.scale_fmt) + k_fp8, k_scale = act_quant(key, self.block_size, self.scale_fmt) + + # k_fp8: (seq_len, head_dim) fp8_e4m3fn + # k_buffer: (num_total_tokens + page_size, head_dim) fp8_e4m3fn + # k_scale: (seq_len, head_dim // block_size = 1) fp8_e4m3fn + # k_scale_cache: (num_total_tokens + page_size, head_dim // block_size = 1) fp8_e4m3fn + forward_batch.token_to_kv_pool.set_index_k_and_scale_buffer( + layer_id=layer_id, + loc=forward_batch.out_cache_loc, + index_k=k_fp8, + index_k_scale=k_scale, + ) + + weights = self._get_logits_head_gate(x, q_scale) + + if is_cuda(): + assert forward_batch.seq_lens_cpu is not None + if len(forward_batch.seq_lens_cpu) == 0: + # this seems b/c max-pad, no worries? 
+ # if x.shape[0] != 0: + # print( + # "HACK: seq_lens empty but x not empty, hackily return all-invalid topk_result" + # ) + return torch.full( + (x.shape[0], self.index_topk), -1, dtype=torch.int, device="cuda" + ) + + if forward_batch.forward_mode.is_decode_or_idle(): + topk_result = self._get_topk_paged( + forward_batch, layer_id, q_fp8, weights, metadata + ) + else: + topk_result = self._get_topk_ragged( + forward_batch, layer_id, q_fp8, weights, metadata + ) + else: + topk_result = self.forward_indexer( + q_fp8.contiguous(), + weights, + forward_batch, + topk=self.index_topk, + layer_id=layer_id, + ) + + return topk_result + + def forward_cuda( + self, + x: torch.Tensor, + q_lora: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + layer_id: int, + ) -> Optional[torch.Tensor]: + return self._forward(x, q_lora, positions, forward_batch, layer_id) + + def forward_npu( + self, + x: torch.Tensor, + q_lora: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + layer_id: int, + ) -> torch.Tensor: + import custom_ops + import torch_npu + + from sglang.srt.layers.dp_attention import ( + get_attention_tp_rank, + get_attention_tp_size, + ) + from sglang.srt.utils import get_bool_env_var + + if forward_batch.attn_backend.forward_metadata.seq_lens_cpu_int is None: + actual_seq_lengths_kv = forward_batch.attn_backend.forward_metadata.seq_lens + else: + actual_seq_lengths_kv = ( + forward_batch.attn_backend.forward_metadata.seq_lens_cpu_int + ) + enable_index_cp = ( + get_bool_env_var("SGLANG_USE_AG_AFTER_QLORA") and layer_id >= 4 + ) + is_prefill = forward_batch.forward_mode.is_extend() + + attention_tp_rank = get_attention_tp_rank() + attention_tp_size = get_attention_tp_size() + + cos_sin = self.rotary_emb.cos_sin_cache[positions] + cos, sin = cos_sin.chunk(2, dim=-1) + cos = cos.repeat(1, 2).view(-1, 1, 1, self.rope_head_dim) + sin = sin.repeat(1, 2).view(-1, 1, 1, self.rope_head_dim) + if is_prefill and enable_index_cp: + slice_length = cos.shape[0] // attention_tp_size + cos = cos[ + slice_length + * attention_tp_rank : slice_length + * (attention_tp_rank + 1) + ] + sin = sin[ + slice_length + * attention_tp_rank : slice_length + * (attention_tp_rank + 1) + ] + + slot_mapping = forward_batch.out_cache_loc + block_table = forward_batch.attn_backend.forward_metadata.block_tables + + bs = x.shape[0] + + q = self.wq_b(q_lora)[0] # [bs, 1536] @ [1536, 64 * 128] = [bs, 64 * 128] + q = q.view(bs, self.n_heads, self.head_dim) # [bs, 64, 128] + q_pe, q_nope = torch.split( + q, + [self.rope_head_dim, self.head_dim - self.rope_head_dim], + dim=-1, + ) # [bs, 64, 64 + 64] + + q_pe = q_pe.view(bs, self.n_heads, 1, self.rope_head_dim) + q_pe = torch_npu.npu_interleave_rope(q_pe, cos, sin).view( + bs, self.n_heads, self.rope_head_dim + ) # [bs, n, d] + q = torch.cat([q_pe, q_nope], dim=-1) + + k_proj = self.wk(x)[0] # [b, s, 7168] @ [7168, 128] = [b, s, 128] + k = self.k_norm(k_proj) + k_pe, k_nope = torch.split( + k, + [self.rope_head_dim, self.head_dim - self.rope_head_dim], + dim=-1, + ) # [bs, 64 + 64] + + k_pe = k_pe.view(-1, 1, 1, self.rope_head_dim) + k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin).view( + bs, 1, self.rope_head_dim + ) # [bs, 1, d] + k = torch.cat([k_pe, k_nope.unsqueeze(1)], dim=-1) # [bs, 1, 128] + + if is_prefill and enable_index_cp: + k, local_k = ( + torch.empty( + (k.shape[0] * attention_tp_size, k.shape[1], k.shape[2]), + dtype=k.dtype, + device=k.device, + ), + k, + ) + get_attention_tp_group().all_gather_into_tensor(k, local_k) 
+ + forward_batch.token_to_kv_pool.set_index_k_buffer(layer_id, slot_mapping, k) + + indexer_input = {} + if is_prefill: + actual_seq_lengths_kv = forward_batch.seq_lens.to(device=q.device) + actual_seq_lengths_q = forward_batch.seq_lens.cumsum(dim=0).to( + device=q.device + ) + if enable_index_cp: + actual_seq_lengths_q -= bs * attention_tp_rank + actual_seq_lengths_q = torch.max( + actual_seq_lengths_q, + torch.zeros_like(actual_seq_lengths_q).to( + device=actual_seq_lengths_q.device + ), + ) + actual_seq_lengths_q = torch.min( + actual_seq_lengths_q, + torch.full(actual_seq_lengths_q.shape, bs).to( + device=actual_seq_lengths_q.device + ), + ) + + else: + if forward_batch.attn_backend.forward_metadata.actual_seq_lengths_q is None: + actual_seq_lengths_q = torch.tensor( + [1 + i * 1 for i in range(bs)], dtype=torch.int32, device=k.device + ) + else: + actual_seq_lengths_q = ( + forward_batch.attn_backend.forward_metadata.actual_seq_lengths_q + ) + + past_key_states = forward_batch.token_to_kv_pool.get_index_k_buffer(layer_id) + + x = x.view(-1, self.hidden_size) + weights = self.weights_proj(x)[0] + block_table = ( + block_table[: actual_seq_lengths_q.size()[0]] if is_prefill else block_table + ) + + topk_indices = torch.ops.custom.npu_lightning_indexer( + query=q.view(-1, self.n_heads, self.head_dim), + key=past_key_states, + weights=weights, + actual_seq_lengths_query=actual_seq_lengths_q.to(torch.int32), + actual_seq_lengths_key=actual_seq_lengths_kv.to(k.device).to(torch.int32), + block_table=block_table, + layout_query="TND", + layout_key="PA_BSND", + sparse_count=self.index_topk, + sparse_mode=3, + ) + + if is_prefill and enable_index_cp: + topk_indices, local_topk_indices = ( + torch.empty( + ( + topk_indices.shape[0] * attention_tp_size, + topk_indices.shape[1], + topk_indices.shape[2], + ), + dtype=topk_indices.dtype, + device=topk_indices.device, + ), + topk_indices, + ) + get_attention_tp_group().all_gather_into_tensor( + topk_indices, local_topk_indices + ) + + return topk_indices diff --git a/python/sglang/srt/layers/attention/nsa/quant_k_cache.py b/python/sglang/srt/layers/attention/nsa/quant_k_cache.py new file mode 100644 index 00000000000..1c7ae38b564 --- /dev/null +++ b/python/sglang/srt/layers/attention/nsa/quant_k_cache.py @@ -0,0 +1,255 @@ +import torch +import triton +import triton.language as tl + +from sglang.srt.layers.attention.nsa.utils import NSA_QUANT_K_CACHE_FAST + + +def quantize_k_cache(cache_k): + # TODO upstream can skip concat([k_nope, k_pe]) since we split them here + if NSA_QUANT_K_CACHE_FAST: + return _quantize_k_cache_fast_wrapped(cache_k) + else: + return _quantize_k_cache_slow(cache_k) + + +# Copied from original +def _quantize_k_cache_slow( + input_k_cache: torch.Tensor, # (num_blocks, block_size, h_k, d) + dv: int = 512, + tile_size: int = 128, +) -> torch.Tensor: + """ + Quantize the k-cache + Return a tensor with shape (num_blocks, block_size, h_k, dv + 4(dv/tile_size) + t(d-dv)) of dtype uint8_t, where t = input_k_cache.element_size() + For more detail about the layout of K/V, please refer to comments in flash_mla_interface.py or README.md + """ + assert dv % tile_size == 0 + num_tiles = dv // tile_size + num_blocks, block_size, h_k, d = input_k_cache.shape + assert h_k == 1 + input_k_cache = input_k_cache.squeeze(2) # [num_blocks, block_size, d] + input_elem_size = input_k_cache.element_size() + + result = torch.empty( + (num_blocks, block_size, dv + num_tiles * 4 + input_elem_size * (d - dv)), + dtype=torch.float8_e4m3fn, + 
device=input_k_cache.device, + ) + result_k_nope_part = result[..., :dv] + result_k_scale_factor = result[..., dv : dv + num_tiles * 4].view(torch.float32) + result_k_rope_part = result[..., dv + num_tiles * 4 :].view(input_k_cache.dtype) + result_k_rope_part[:] = input_k_cache[..., dv:] + + for tile_idx in range(0, num_tiles): + cur_scale_factors_inv = ( + torch.abs( + input_k_cache[..., tile_idx * tile_size : (tile_idx + 1) * tile_size] + ) + .max(dim=-1) + .values + / 448.0 + ) # [num_blocks, block_size] + result_k_scale_factor[:, :, tile_idx] = cur_scale_factors_inv + + cur_scale_factors_inv.unsqueeze_(-1) # [num_blocks, block_size, 1] + cur_quantized_nope = ( + input_k_cache[ + ..., tile_idx * tile_size : (tile_idx + 1) * tile_size + ].float() + / cur_scale_factors_inv.float() + ).to(torch.float8_e4m3fn) + result_k_nope_part[..., tile_idx * tile_size : (tile_idx + 1) * tile_size] = ( + cur_quantized_nope + ) + + result = result.view(num_blocks, block_size, 1, -1) + return result + + +def _quantize_k_cache_fast_wrapped( + input_k_cache: torch.Tensor, + dv: int = 512, + tile_size: int = 128, +) -> torch.Tensor: + # TODO the final API may be 2D instead of 4D, thus we convert them here + num_blocks, block_size, _, dim_nope_and_rope = input_k_cache.shape + assert dv == 512 + assert dim_nope_and_rope == 512 + 64 + assert tile_size == 128 + input_k_cache = input_k_cache.view((-1, dim_nope_and_rope)) + + # TODO deliberately split into two tensors, then upstream can provide the two tensors instead of concat into one + k_nope = input_k_cache[:, :dv] + k_rope = input_k_cache[:, dv:] + + output = _quantize_k_cache_fast(k_nope=k_nope, k_rope=k_rope) + + return output.view(num_blocks, block_size, 1, -1) + + +def _quantize_k_cache_fast(k_nope, k_rope, group_size: int = 128): + """ + :param k_nope: (num_tokens, dim_nope 512) + :param k_rope: (num_tokens, dim_rope 64) + """ + + assert k_nope.dtype == torch.bfloat16 + assert k_rope.dtype == torch.bfloat16 + + num_tokens, dim_nope = k_nope.shape + num_tokens_, dim_rope = k_rope.shape + assert num_tokens == num_tokens_ + assert dim_nope == 512 + assert dim_rope == 64 + assert k_nope.dtype == k_rope.dtype + num_tiles = dim_nope // group_size + + assert k_nope.stride(1) == 1 + assert k_rope.stride(1) == 1 + + output = torch.empty( + (num_tokens, dim_nope + num_tiles * 4 + k_rope.element_size() * dim_rope), + dtype=torch.float8_e4m3fn, + device=k_nope.device, + ) + output_nope_q = output[..., :dim_nope] + output_nope_s = output[..., dim_nope : dim_nope + num_tiles * 4].view(torch.float32) + output_rope = output[..., dim_nope + num_tiles * 4 :].view(torch.bfloat16) + + num_blocks_per_token = triton.cdiv(dim_nope + dim_rope, group_size) + assert num_blocks_per_token == 5 + + assert dim_nope % group_size == 0 + NUM_NOPE_BLOCKS = dim_nope // group_size + + _quantize_k_cache_fast_kernel[(num_tokens, num_blocks_per_token)]( + output_nope_q, + output_nope_s, + output_rope, + k_nope, + k_rope, + output_nope_q.stride(0), + output_nope_s.stride(0), + output_rope.stride(0), + k_nope.stride(0), + k_rope.stride(0), + NUM_NOPE_BLOCKS=NUM_NOPE_BLOCKS, + GROUP_SIZE=group_size, + DIM_NOPE=dim_nope, + DIM_ROPE=dim_rope, + FP8_MIN=torch.finfo(torch.float8_e4m3fn).min, + FP8_MAX=torch.finfo(torch.float8_e4m3fn).max, + ) + + return output + + +@triton.jit +def _quantize_k_cache_fast_kernel( + output_nope_q_ptr, + output_nope_s_ptr, + output_rope_ptr, + k_nope_ptr, + k_rope_ptr, + output_nope_q_stride_0: int, + output_nope_s_stride_0: int, + output_rope_stride_0: int, + 
k_nope_stride_0: int, + k_rope_stride_0: int, + NUM_NOPE_BLOCKS: tl.constexpr, + GROUP_SIZE: tl.constexpr, + DIM_NOPE: tl.constexpr, + DIM_ROPE: tl.constexpr, + FP8_MIN: tl.constexpr, + FP8_MAX: tl.constexpr, +): + token_id = tl.program_id(0) + raw_block_id = tl.program_id(1) + + if raw_block_id < NUM_NOPE_BLOCKS: + # a. quant nope + effective_block_id = raw_block_id + + offs = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE) + mask = offs < DIM_NOPE + ptr = k_nope_ptr + token_id * k_nope_stride_0 + offs + + y = tl.load(ptr, mask=mask, other=0.0).to(tl.float32) + + # the ref impl do not have a `tl.maximum(... eps)`, so we remove it here + y_s = tl.max(tl.abs(y)) / FP8_MAX + y_s_inv = 1.0 / y_s + y_q = tl.clamp(y * y_s_inv, FP8_MIN, FP8_MAX).to( + output_nope_q_ptr.dtype.element_ty + ) + + dst_q_ptr = output_nope_q_ptr + token_id * output_nope_q_stride_0 + offs + dst_s_ptr = ( + output_nope_s_ptr + token_id * output_nope_s_stride_0 + effective_block_id + ) + + tl.store(dst_q_ptr, y_q, mask=mask) + tl.store(dst_s_ptr, y_s) + else: + # b. copy rope + effective_block_id = raw_block_id - NUM_NOPE_BLOCKS + + offs = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE) + mask = offs < DIM_ROPE + + src_ptr = k_rope_ptr + token_id * k_rope_stride_0 + offs + dst_ptr = output_rope_ptr + token_id * output_rope_stride_0 + offs + + data = tl.load(src_ptr, mask=mask) + tl.store(dst_ptr, data, mask=mask) + + +if __name__ == "__main__": + for num_blocks, block_size in [ + (1, 1), + (10, 64), + ]: + dim_nope_and_rope = 512 + 64 + + input_k_cache = torch.randn( + (num_blocks, block_size, 1, dim_nope_and_rope), + dtype=torch.bfloat16, + device="cuda", + ) + # temp debug + # input_k_cache = (576 - torch.arange(num_blocks * block_size * 1 * dim_nope_and_rope, device="cuda")).to(torch.bfloat16).reshape(num_blocks, block_size, 1, dim_nope_and_rope) + + ref_quant = _quantize_k_cache_slow(input_k_cache) + actual_quant = _quantize_k_cache_fast_wrapped(input_k_cache) + # print(f"{input_k_cache=}") + # print(f"{ref_quant=}") + # print(f"{actual_quant=}") + # print(f"{ref_quant == actual_quant=}") + # print(f"{actual_quant.to(torch.float32) - ref_quant.to(torch.float32)=}") + # print(f"{ref_quant.view(torch.bfloat16)=}") + # print(f"{actual_quant.view(torch.bfloat16)=}") + # assert torch.all(ref_quant == actual_quant) + + import dequant_k_cache + + ref_ref_dequant = dequant_k_cache._dequantize_k_cache_slow(ref_quant) + ref_actual_dequant = dequant_k_cache._dequantize_k_cache_fast_wrapped(ref_quant) + actual_actual_dequant = dequant_k_cache._dequantize_k_cache_fast_wrapped( + actual_quant + ) + + print(f"{ref_ref_dequant=}") + print(f"{actual_actual_dequant=}") + print(f"{actual_actual_dequant - ref_ref_dequant=}") + print(f"{torch.mean(ref_ref_dequant - actual_actual_dequant)=}") + + # TODO too different? 
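+        # fp8 quantization is lossy and the fast kernel rounds slightly differently
+        # from the slow reference (multiply-by-reciprocal plus clamping vs. plain
+        # division), so the dequantized caches are compared with loose tolerances
+        # rather than bitwise equality.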
+ torch.testing.assert_close( + ref_ref_dequant, ref_actual_dequant, atol=0.2, rtol=0.2 + ) + torch.testing.assert_close( + ref_ref_dequant, actual_actual_dequant, atol=0.2, rtol=0.2 + ) + + print("Passed") diff --git a/python/sglang/srt/layers/attention/nsa/tilelang_kernel.py b/python/sglang/srt/layers/attention/nsa/tilelang_kernel.py new file mode 100644 index 00000000000..05266ee72af --- /dev/null +++ b/python/sglang/srt/layers/attention/nsa/tilelang_kernel.py @@ -0,0 +1,785 @@ +from typing import Optional, Tuple + +import tilelang +import tilelang.language as T +import torch + +from sglang.srt.utils import is_hip + +tilelang.set_log_level("WARNING") + +pass_configs = { + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_FAST_MATH: True, +} + +BF16 = "bfloat16" +FP8 = "float8_e4m3" +FP32 = "float32" + +_is_hip = is_hip() + + +def fast_log2_ceil(x): + bits_x = T.reinterpret("uint32", x) + exp_x = (bits_x >> 23) & 0xFF + man_bits = bits_x & ((1 << 23) - 1) + return T.Cast("int32", exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0)) + + +def fast_pow2(x): + bits_x = (x + 127) << 23 + return T.reinterpret("float32", bits_x) + + +def fast_round_scale(amax, fp8_max_inv): + return fast_pow2(fast_log2_ceil(amax * fp8_max_inv)) + + +@tilelang.jit(pass_configs=pass_configs) +def act_quant_kernel( + N, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32, round_scale=False +): + M = T.symbolic("M") + fp8_min = -448.0 + fp8_max = 448.0 + fp8_max_inv = 1 / fp8_max + num_stages = 0 if round_scale else 2 + blk_m = 32 + group_size = 128 + + @T.prim_func + def act_quant_kernel_( + X: T.Tensor[(M, N), in_dtype], + Y: T.Tensor[(M, N), out_dtype], + S: T.Tensor[(M, T.ceildiv(N, group_size)), scale_dtype], + ): + with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as ( + pid_m, + pid_n, + ): + x_shared = T.alloc_shared((blk_m, group_size), in_dtype) + x_local = T.alloc_fragment((blk_m, group_size), in_dtype) + amax_local = T.alloc_fragment((blk_m,), scale_dtype) + s_local = T.alloc_fragment((blk_m,), scale_dtype) + y_local = T.alloc_fragment((blk_m, group_size), out_dtype) + y_shared = T.alloc_shared((blk_m, group_size), out_dtype) + + for _ in T.Pipelined(1, num_stages=num_stages): + T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared) + T.copy(x_shared, x_local) + T.reduce_absmax(x_local, amax_local, dim=1) + for i in T.Parallel(blk_m): + amax_local[i] = T.max(amax_local[i], 1e-4) + if round_scale: + s_local[i] = fast_round_scale(amax_local[i], fp8_max_inv) + else: + s_local[i] = amax_local[i] * fp8_max_inv + for i, j in T.Parallel(blk_m, group_size): + y_local[i, j] = T.clamp( + x_local[i, j] / s_local[i], fp8_min, fp8_max + ) + for i in T.Parallel(blk_m): + S[pid_m * blk_m + i, pid_n] = s_local[i] + T.copy(y_local, y_shared) + T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size]) + + return act_quant_kernel_ + + +def act_quant( + x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantizes the input tensor `x` using block-wise quantization. + + Args: + x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`. + block_size (int, optional): The size of the blocks to be used for quantization. Default is 128. + scale_fmt (Optional[str], optional): The format of the scale. Default is None. 
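+            When provided, each group scale is rounded up to a power of two via
+            `fast_round_scale`; otherwise the scale is simply `amax / 448`.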
+ Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - The quantized tensor with dtype `torch.float8_e4m3fn`. + - A tensor of scaling factors with dtype `torch.float32`. + """ + assert x.is_contiguous(), "Input tensor must be contiguous" + assert ( + x.size(-1) % block_size == 0 + ), f"Last dimension size must be divisible by block_size (block_size={block_size})" + N = x.size(-1) + y = torch.empty_like(x, dtype=torch.float8_e4m3fn) + s = x.new_empty(*x.size()[:-1], N // block_size, dtype=torch.float32) + kernel = act_quant_kernel(N, round_scale=scale_fmt is not None) + kernel(x.view(-1, N), y.view(-1, N), s.view(-1, N // block_size)) + return y, s + + +@tilelang.jit(out_idx=[4], pass_configs=pass_configs) +def fp8_index_kernel(h: int, d: int, clear_accum=True): + b = T.symbolic("b") + m = T.symbolic("m") + n = T.symbolic("n") + + blk_n1 = 512 + blk_n2 = 128 + + @T.prim_func + def fp8_index_kernel_( + q: T.Tensor[(b, m, h, d), FP8], + q_s: T.Tensor[(b, m, h), FP32], + k: T.Tensor[(b, n, d), FP8], + k_s: T.Tensor[(b, n), FP32], + o: T.Tensor[(b, m, n), FP32], + ) -> None: + with T.Kernel(b, m, T.ceildiv(n, blk_n1)) as (i_b, i_m, i1_n): + q_smem = T.alloc_shared((h, d), FP8) + T.copy(q[i_b, i_m, 0, 0], q_smem) + + q_s_frag = T.alloc_fragment(h, FP32) + T.copy(q_s[i_b, i_m, 0], q_s_frag) + + for i2_n in T.Pipelined(blk_n1 // blk_n2, num_stages=2): + k_smem = T.alloc_shared((blk_n2, d), FP8) + T.copy(k[i_b, i1_n * blk_n1 + i2_n * blk_n2, 0], k_smem) + + k_s_frag = T.alloc_fragment(blk_n2, FP32) + T.copy(k_s[i_b, i1_n * blk_n1 + i2_n * blk_n2], k_s_frag) + + logits = T.alloc_fragment((blk_n2, h), FP32) + T.gemm( + k_smem, + q_smem, + logits, + transpose_A=False, + transpose_B=True, + clear_accum=clear_accum, + ) + + for i_h, i3_n in T.Parallel(h, blk_n2): + logits[i3_n, i_h] = T.max(logits[i3_n, i_h], 0) * q_s_frag[i_h] + + logits_sum = T.alloc_fragment(blk_n2, FP32) + T.reduce_sum(logits, logits_sum, dim=1) + + for i3_n in T.Parallel(blk_n2): + logits_sum[i3_n] *= k_s_frag[i3_n] + + T.copy(logits_sum, o[i_b, i_m, i1_n * blk_n1 + i2_n * blk_n2]) + + return fp8_index_kernel_ + + +def fp8_index( + q: torch.Tensor, + q_s: torch.Tensor, + k: torch.Tensor, + k_s: torch.Tensor, +) -> torch.Tensor: + """ + Perform index score using FP8 precision. + + Args: + q (torch.Tensor): The Q tensor, must be contiguous. + q_s (torch.Tensor): The scaling factor for Q (float), must be contiguous. + k (torch.Tensor): The K tensor, must be contiguous. + k_s (torch.Tensor): The scaling factor for K (e8m0 here), must be contiguous. 
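+
+    Returns:
+        torch.Tensor: Index scores of shape (b, m, n) in float32, computed as: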
+ + fp8 q @ fp8 k -> fp32 logits + relu(fp32 logits) * q_s (weights) -> fp32 logits + fp32 logits -> fp32 logits_sum + fp32 logits_sum * k_s (e8m0) -> fp32 index_score + """ + if _is_hip: + return fp8_index_kernel(q.shape[2], q.shape[3], False)(q, q_s, k, k_s) + else: + return fp8_index_kernel(q.shape[2], q.shape[3])(q, q_s, k, k_s) + + +@tilelang.jit( + out_idx=[-1], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, +) +def sparse_attention_fwd_kernel_v1( + num_heads, + dim, + tail_dim, + topk, + *, + kv_group=1, + sm_scale=None, + is_causal=True, + block_I=64, + num_stages=2, + threads=256, +): + assert dim == tilelang.math.next_power_of_2( + dim + ), f"haven't check padding correctness yet, dim={dim}" + assert tail_dim == tilelang.math.next_power_of_2( + tail_dim + ), f"haven't check padding correctness yet, dim={tail_dim}" + assert is_causal == True, "non-casual is not supported" + assert ( + topk % block_I == 0 + ), "otherwise will load some index=0 thus causing wrong kv to be loaded" + if sm_scale is None: + sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 * 1.44269504 # log2(e) + else: + sm_scale = sm_scale * 1.44269504 # log2(e) + + batch = T.symbolic("batch") + seq_len = T.symbolic("seq_len") + seq_len_kv = T.symbolic("seq_len_kv") + + head_kv = num_heads // kv_group + q_shape = [batch, seq_len, num_heads, dim + tail_dim] + kv_shape = [batch, seq_len_kv, kv_group, dim + tail_dim] + o_shape = [batch, seq_len, num_heads, dim] + indices_shape = [batch, seq_len, kv_group, topk] + indices_dtype = "int32" + dtype = "bfloat16" + accum_dtype = "float" + + H = head_kv + padded_H = max(tilelang.math.next_power_of_2(head_kv), 16) + if padded_H != H: + assert kv_group == 1 + BI = block_I + NI = tilelang.cdiv(topk, block_I) + D = dim + D_tail = tail_dim + + if head_kv > 64: + assert head_kv % 64 == 0, "head_kv should be a multiple of 64" + REPLICATE_H = head_kv // 64 + else: + REPLICATE_H = 1 + + H_per_block = padded_H if REPLICATE_H == 1 else 64 + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), # type: ignore + KV: T.Tensor(kv_shape, dtype), # type: ignore + Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore + Output: T.Tensor(o_shape, dtype), # type: ignore + ): + with T.Kernel(seq_len * REPLICATE_H, batch, kv_group, threads=threads) as ( + bx, + by, + bz, + ): + Q_shared = T.alloc_shared([H_per_block, D], dtype) + Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype) + KV_shared = T.alloc_shared([BI, D], dtype) + K_tail_shared = T.alloc_shared([BI, D_tail], dtype) + O_shared = T.alloc_shared([H_per_block, D], dtype) + mask = T.alloc_fragment([BI], "bool") + + acc_o = T.alloc_fragment([H_per_block, D], accum_dtype) + acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype) + S_shared = T.alloc_shared([H_per_block, BI], dtype) + sumexp = T.alloc_fragment([H_per_block], accum_dtype) + sumexp_i = T.alloc_fragment([H_per_block], accum_dtype) + alpha = T.alloc_fragment([H_per_block], accum_dtype) + m_i = T.alloc_fragment([H_per_block], accum_dtype) + m_i_prev = T.alloc_fragment([H_per_block], accum_dtype) + + T.fill(acc_o, 0) + T.fill(sumexp, 0) + T.fill(m_i, -(2**30)) # avoid -inf - inf to cause nan + + b_i, g_i = by, bz + s_i = bx if REPLICATE_H == 1 else (bx // REPLICATE_H) + q_i = s_i + max_kv_i = q_i + + H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64) + H1 = H0 + H_per_block + + T.copy(Q[b_i, s_i, H0:H1, :D], Q_shared) + T.copy(Q[b_i, s_i, H0:H1, D:], 
Q_tail_shared) + + for i_i in T.Pipelined(NI, num_stages=num_stages): + + for bi_i in T.Parallel(BI): + mask[bi_i] = Indices[b_i, s_i, g_i, i_i * BI + bi_i] >= 0 + + for bi_i, d_i in T.Parallel(BI, D): + KV_shared[bi_i, d_i] = KV[ + b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, d_i + ] + for bi_i, d_i in T.Parallel(BI, D_tail): + K_tail_shared[bi_i, d_i] = KV[ + b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, D + d_i + ] + + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.if_then_else( + mask[bi_i], 0, -T.infinity(acc_s.dtype) + ) + T.gemm( + Q_shared, + KV_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullCol, + ) + T.gemm( + Q_tail_shared, + K_tail_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullCol, + ) + T.copy(m_i, m_i_prev) + T.reduce_max(acc_s, m_i, dim=1, clear=False) + for h_i in T.Parallel(H_per_block): + alpha[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.exp2( + acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale + ) + T.reduce_sum(acc_s, sumexp_i, dim=1) # is this a accumulate operator? + for h_i in T.Parallel(H_per_block): + sumexp[h_i] = sumexp[h_i] * alpha[h_i] + sumexp_i[h_i] + for h_i, d_i in T.Parallel(H_per_block, D): + acc_o[h_i, d_i] = acc_o[h_i, d_i] * alpha[h_i] + + T.copy(acc_s, S_shared) + T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol) + + # Rescale + for h_i, d_i in T.Parallel(H_per_block, D): + acc_o[h_i, d_i] /= sumexp[h_i] + for h_i in T.Parallel(H_per_block): + sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale + + T.copy(acc_o, O_shared) + T.copy(acc_o, Output[b_i, s_i, H0:H1, :]) + + return main + + +@tilelang.jit( + out_idx=[-1], + compile_flags=[ + "-O3", + "-Wno-deprecated-declarations", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--ptxas-options=-v,--register-usage-level=10", + "-DNDEBUG", + ], +) # type: ignore +def sparse_attention_fwd_kernel_v2( + num_heads: int, + dim: int, + tail_dim: int, + topk: int, + *, + kv_group: int = 1, + sm_scale: Optional[float] = None, + block_I: int = 64, +): + assert dim == tilelang.math.next_power_of_2( + dim + ), f"haven't check padding correctness yet, dim={dim}" + assert tail_dim == tilelang.math.next_power_of_2( + tail_dim + ), f"haven't check padding correctness yet, dim={tail_dim}" + assert ( + topk % block_I == 0 + ), "otherwise will load some index=0 thus causing wrong kv to be loaded" + if sm_scale is None: + sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 * 1.44269504 # log2(e) + else: + sm_scale = sm_scale * 1.44269504 # log2(e) + threads = 384 + + batch = T.symbolic("batch") + qo_len = T.symbolic("seq_len") + num_pages = T.symbolic("num_pages") + + q_shape = [batch, qo_len, num_heads, dim + tail_dim] + kv_shape = [batch, num_pages, kv_group, dim + tail_dim] + o_shape = [batch, qo_len, num_heads, dim] + indices_shape = [batch, qo_len, kv_group, topk] + + indices_dtype = "int32" + dtype = "bfloat16" + accum_dtype = "float" + + H = num_heads + padded_H = max(tilelang.math.next_power_of_2(num_heads), 16) + if padded_H != H: + assert kv_group == 1 + BI = block_I + NI = tilelang.cdiv(topk, block_I) + assert NI % 2 == 0, "NI should be a multiple of 2" + D = dim + D_tail = tail_dim + if num_heads > 64: + assert num_heads % 64 == 0, "head_kv should be a multiple of 64" + REPLICATE_H = num_heads // 
64 + else: + REPLICATE_H = 1 + + H_per_block = padded_H if REPLICATE_H == 1 else 64 + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), # type: ignore + KV: T.Tensor(kv_shape, dtype), # type: ignore + Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore + Output: T.Tensor(o_shape, dtype), # type: ignore + ): + """ + Q: [b, qo_len, H, D + D_tail] (bfloat16) + KV: [b, num_pages, kv_group, D + D_tail] (bfloat16) + Indices: [b, qo_len, kv_group, topk] (int32) + """ + + with T.Kernel(qo_len * REPLICATE_H, batch, 1, threads=threads) as (bx, by, bz): # type: ignore + Q_shared_l = T.alloc_shared([H_per_block, D // 2], dtype) + Q_shared_r = T.alloc_shared([H_per_block, D // 2], dtype) + Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype) + KV_shared_0_l = T.alloc_shared([BI, D // 2], dtype) + KV_shared_0_r = T.alloc_shared([BI, D // 2], dtype) + KV_shared_1_l = T.alloc_shared([BI, D // 2], dtype) + KV_shared_1_r = T.alloc_shared([BI, D // 2], dtype) + K_tail_shared_0 = T.alloc_shared([BI, D_tail], dtype) + K_tail_shared_1 = T.alloc_shared([BI, D_tail], dtype) + O_shared_l = Q_shared_l + O_shared_r = Q_shared_r + is_kv_valid_0 = T.alloc_shared([BI], "bool", scope="shared") + is_kv_valid_1 = T.alloc_shared([BI], "bool", scope="shared") + + acc_o_l = T.alloc_fragment([H_per_block, D // 2], accum_dtype) + acc_o_r = T.alloc_fragment([H_per_block, D // 2], accum_dtype) + acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype) + S_shared = T.alloc_shared([H_per_block, BI], dtype) + sumexp = T.alloc_fragment([H_per_block], accum_dtype) + sum_exp_shared = T.alloc_shared([H_per_block], accum_dtype) + sumexp_i = T.alloc_fragment([H_per_block], accum_dtype) + alpha_shared = T.alloc_shared([H_per_block], accum_dtype, scope="shared") + alpha_local = T.alloc_fragment([H_per_block], accum_dtype) + m_i = T.alloc_fragment([H_per_block], accum_dtype) + m_i_prev = T.alloc_fragment([H_per_block], accum_dtype) + indices_local = T.alloc_local([1], indices_dtype) + indices_tmp = T.alloc_local([1], indices_dtype) + + bar_q = T.alloc_barrier(arrive_count=384) + bar_k_0_ready = T.alloc_barrier(arrive_count=128) + bar_k_1_ready = T.alloc_barrier(arrive_count=128) + bar_k_0_free = T.alloc_barrier(arrive_count=256) + bar_k_1_free = T.alloc_barrier(arrive_count=256) + bar_sScale_and_sS_ready = T.alloc_barrier(arrive_count=256) + bar_sScale_and_sS_free = T.alloc_barrier(arrive_count=256) + + bar_0_128 = T.alloc_barrier(arrive_count=128) + bar_1_128 = T.alloc_barrier(arrive_count=128) + bar_2_128 = T.alloc_barrier(arrive_count=128) + bar_final = T.alloc_barrier(arrive_count=128) + + b_i, g_i = by, bz + s_i = bx if REPLICATE_H == 1 else bx // REPLICATE_H + + H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64) + H1 = H0 + H_per_block + + tx = T.get_thread_binding() + + T.copy(Q[b_i, s_i, H0:H1, 0 : D // 2], Q_shared_l) + T.copy(Q[b_i, s_i, H0:H1, D // 2 : D], Q_shared_r) + T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared) + T.barrier_arrive(bar_q) + + if tx < 128: + T.set_max_nreg(240, 1) + T.fill(sumexp, 0) + T.fill(m_i, -(2**30)) # avoid -inf - inf to cause nan + T.fill(acc_o_l, 0) + T.barrier_wait(bar_q, 0) + + for i_i in T.serial(T.ceildiv(NI, 2)): + # Buffer 0 + # with sync_at(bar_0_128, 0): + T.barrier_wait(bar_k_0_ready[0], (i_i & 1)) + T.barrier_arrive(bar_0_128) + T.barrier_wait(bar_0_128, 0) + + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.if_then_else( + is_kv_valid_0[bi_i], 0, -T.infinity(acc_s.dtype) + ) + T.gemm( + Q_shared_l, KV_shared_0_l, acc_s, 
transpose_B=True, wg_wait=-1 + ) + T.gemm( + Q_shared_r, KV_shared_0_r, acc_s, transpose_B=True, wg_wait=-1 + ) + T.gemm( + Q_tail_shared, + K_tail_shared_0, + acc_s, + transpose_B=True, + wg_wait=-1, + ) + + T.wait_wgmma(0) + + if i_i != 0: + T.barrier_arrive(bar_sScale_and_sS_free) + T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2) & 1) ^ 1) + + T.copy(m_i, m_i_prev) + T.reduce_max(acc_s, m_i, dim=1, clear=False) + for h_i in T.Parallel(H_per_block): + alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.exp2( + acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale + ) + T.reduce_sum( + acc_s, sumexp_i, dim=1 + ) # is this a accumulate operator? + for h_i in T.Parallel(H_per_block): + sumexp[h_i] = sumexp[h_i] * alpha_local[h_i] + sumexp_i[h_i] + for h_i, d_i in T.Parallel(H_per_block, D // 2): + acc_o_l[h_i, d_i] *= alpha_local[h_i] + T.copy(alpha_local, alpha_shared) + + T.copy(acc_s, S_shared) + T.gemm(S_shared, KV_shared_0_l, acc_o_l) + + T.barrier_arrive(bar_sScale_and_sS_ready) + T.barrier_arrive(bar_k_0_free[0]) + + # Buffer 1 + T.barrier_wait(bar_k_1_ready[0], (i_i & 1)) + T.barrier_arrive(bar_0_128) + T.barrier_wait(bar_0_128, 1) + + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.if_then_else( + is_kv_valid_1[bi_i], 0, -T.infinity(acc_s.dtype) + ) + T.gemm( + Q_shared_l, KV_shared_1_l, acc_s, transpose_B=True, wg_wait=-1 + ) + T.gemm( + Q_shared_r, KV_shared_1_r, acc_s, transpose_B=True, wg_wait=-1 + ) + T.gemm( + Q_tail_shared, + K_tail_shared_1, + acc_s, + transpose_B=True, + wg_wait=-1, + ) + + T.wait_wgmma(0) + + T.barrier_arrive(bar_sScale_and_sS_free) + T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2 + 1) & 1) ^ 1) + + T.copy(m_i, m_i_prev) + T.reduce_max(acc_s, m_i, dim=1, clear=False) + for h_i in T.Parallel(H_per_block): + alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.exp2( + acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale + ) + T.reduce_sum( + acc_s, sumexp_i, dim=1 + ) # is this a accumulate operator? 
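+                    # sumexp_i holds only this tile's row sums; the running softmax
+                    # denominator is accumulated just below as
+                    #   sumexp = sumexp * alpha + sumexp_i
+                    # (the standard online-softmax rescale-and-add update).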
+ for h_i in T.Parallel(H_per_block): + sumexp[h_i] = sumexp[h_i] * alpha_local[h_i] + sumexp_i[h_i] + for h_i, d_i in T.Parallel(H_per_block, D // 2): + acc_o_l[h_i, d_i] *= alpha_local[h_i] + T.copy(alpha_local, alpha_shared) + + T.copy(acc_s, S_shared) + T.gemm(S_shared, KV_shared_1_l, acc_o_l) + + T.barrier_arrive(bar_sScale_and_sS_ready) + T.barrier_arrive(bar_k_1_free[0]) + + # Rescale + for h_i in T.Parallel(H_per_block): + sum_exp_shared[h_i] = sumexp[h_i] + T.barrier_arrive(bar_final) + for h_i, d_i in T.Parallel(H_per_block, D // 2): + acc_o_l[h_i, d_i] /= sumexp[h_i] + for h_i in T.Parallel(H_per_block): + sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale + T.copy(acc_o_l, O_shared_l) + T.copy(O_shared_l, Output[b_i, s_i, H0:H1, 0 : D // 2]) + elif tx >= 128 and tx < 256: + # T.set_max_nreg(168, 1) + T.fill(acc_o_r, 0) + for i_i in T.serial(T.ceildiv(NI, 2)): + # Buffer 0 + T.barrier_arrive(bar_sScale_and_sS_ready) + T.barrier_wait(bar_sScale_and_sS_ready, ((i_i * 2) & 1)) + T.barrier_arrive(bar_1_128) + T.barrier_wait(bar_1_128, 0) + for h_i, d_i in T.Parallel(H_per_block, D // 2): + acc_o_r[h_i, d_i] *= alpha_shared[h_i] + T.gemm(S_shared, KV_shared_0_r, acc_o_r) + T.barrier_arrive(bar_k_0_free[0]) + T.barrier_arrive(bar_sScale_and_sS_free) + + # Buffer 1 + T.barrier_arrive(bar_sScale_and_sS_ready) + T.barrier_wait(bar_sScale_and_sS_ready, ((i_i * 2 + 1) & 1)) + T.barrier_arrive(bar_1_128) + T.barrier_wait(bar_1_128, 1) + for h_i, d_i in T.Parallel(H_per_block, D // 2): + acc_o_r[h_i, d_i] *= alpha_shared[h_i] + T.gemm(S_shared, KV_shared_1_r, acc_o_r) + T.barrier_arrive(bar_k_1_free[0]) + if i_i != T.ceildiv(NI, 2) - 1: + T.barrier_arrive(bar_sScale_and_sS_free) + + # Rescale + T.barrier_wait(bar_final, 0) + for h_i, d_i in T.Parallel(H_per_block, D // 2): + acc_o_r[h_i, d_i] /= sum_exp_shared[h_i] + + T.copy(acc_o_r, O_shared_r) + T.copy(O_shared_r, Output[b_i, s_i, H0:H1, D // 2 : D]) + elif tx >= 256: + # producer + T.set_max_nreg(80, 0) + indices_local[0] = 0 + for i_i in T.serial(T.ceildiv(NI, 2)): + # Buffer 0 + T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1)) + T.barrier_arrive(bar_2_128) + T.barrier_wait(bar_2_128, 0) + + for r in T.serial(4): + indices_tmp[0] = Indices[ + b_i, s_i, g_i, (i_i * 2) * BI + r * 16 + (tx - 256) // 8 + ] + is_kv_valid_0[r * 16 + (tx - 256) // 8] = indices_tmp[0] >= 0 + if is_kv_valid_0[r * 16 + (tx - 256) // 8]: + indices_local[0] = indices_tmp[0] + + with T.attr("default", "async_scope", 1): # type: ignore + for u in T.serial(4): + for v in T.vectorized(8): + KV_shared_0_l[ + r * 16 + (tx - 256) // 8, + 64 * u + (tx - 256) % 8 * 8 + v, + ] = KV[ + b_i, + indices_local[0], + g_i, + 64 * u + (tx - 256) % 8 * 8 + v, + ] + KV_shared_0_r[ + r * 16 + (tx - 256) // 8, + 64 * u + (tx - 256) % 8 * 8 + v, + ] = KV[ + b_i, + indices_local[0], + g_i, + D // 2 + 64 * u + (tx - 256) % 8 * 8 + v, + ] + with T.attr("default", "async_scope", 1): # type: ignore + for v in T.vectorized(8): + K_tail_shared_0[ + r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v + ] = KV[ + b_i, + indices_local[0], + g_i, + D + (tx - 256) % 8 * 8 + v, + ] + + T.cp_async_barrier_noinc(bar_k_0_ready[0]) + + # Buffer 1 + T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1)) + T.barrier_arrive(bar_2_128) + T.barrier_wait(bar_2_128, 1) + + for r in T.serial(4): + indices_tmp[0] = Indices[ + b_i, s_i, g_i, (i_i * 2 + 1) * BI + r * 16 + (tx - 256) // 8 + ] + is_kv_valid_1[r * 16 + (tx - 256) // 8] = indices_tmp[0] >= 0 + if is_kv_valid_1[r * 16 + (tx - 256) // 8]: + indices_local[0] 
= indices_tmp[0] + + with T.attr("default", "async_scope", 1): # type: ignore + for u in T.serial(4): + for v in T.vectorized(8): + KV_shared_1_l[ + r * 16 + (tx - 256) // 8, + 64 * u + (tx - 256) % 8 * 8 + v, + ] = KV[ + b_i, + indices_local[0], + g_i, + 64 * u + (tx - 256) % 8 * 8 + v, + ] + KV_shared_1_r[ + r * 16 + (tx - 256) // 8, + 64 * u + (tx - 256) % 8 * 8 + v, + ] = KV[ + b_i, + indices_local[0], + g_i, + D // 2 + 64 * u + (tx - 256) % 8 * 8 + v, + ] + with T.attr("default", "async_scope", 1): # type: ignore + for v in T.vectorized(8): + K_tail_shared_1[ + r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v + ] = KV[ + b_i, + indices_local[0], + g_i, + D + (tx - 256) % 8 * 8 + v, + ] + + T.cp_async_barrier_noinc(bar_k_1_ready[0]) + + return main + + +def tilelang_sparse_fwd( + q: torch.Tensor, + kv: torch.Tensor, + indices: torch.Tensor, + sm_scale: float, + d_v: int = 512, +) -> torch.Tensor: + assert q.dim() == 3 and kv.dim() == 3 and indices.dim() == 3 + num_heads = q.shape[1] + dim = q.shape[2] + tail_dim = dim - d_v + topk = indices.shape[-1] + assert topk == 2048 + if _is_hip: + kernel = sparse_attention_fwd_kernel_v1( + num_heads, d_v, tail_dim, topk, sm_scale=sm_scale, num_stages=1 + ) + else: + kernel = sparse_attention_fwd_kernel_v2( + num_heads, d_v, tail_dim, topk, sm_scale=sm_scale + ) + return kernel(q.unsqueeze(0), kv.unsqueeze(0), indices.unsqueeze(0)) # type: ignore diff --git a/python/sglang/srt/layers/attention/nsa/transform_index.py b/python/sglang/srt/layers/attention/nsa/transform_index.py new file mode 100644 index 00000000000..442dd113d20 --- /dev/null +++ b/python/sglang/srt/layers/attention/nsa/transform_index.py @@ -0,0 +1,144 @@ +from typing import List, Optional + +import torch +import triton +import triton.language as tl + + +def transform_index_page_table_prefill(**kwargs): + return transform_index_page_table_prefill_ref(**kwargs) + + +def transform_index_page_table_decode(**kwargs): + return transform_index_page_table_decode_ref(**kwargs) + + +@triton.jit +def transform_index_page_table_decode_kernel( + page_table_ptr: torch.Tensor, + topk_indices_ptr: torch.Tensor, + result_ptr: torch.Tensor, + page_size: tl.constexpr, + max_seqlen_k: tl.constexpr, +): + TOPK: tl.constexpr = 2048 + req_id = tl.program_id(0) + page_table_ptr = page_table_ptr + req_id * max_seqlen_k + topk_indices_ptr = topk_indices_ptr + req_id * TOPK + result_ptr = result_ptr + req_id * TOPK + + offset = tl.arange(0, TOPK) # topk should be 2048 + loaded_topk_indices = tl.load(topk_indices_ptr + offset) + mask = loaded_topk_indices >= 0 + loaded_kv_indices = tl.load(page_table_ptr + loaded_topk_indices, mask=mask) + tl.store(result_ptr + offset, loaded_kv_indices, mask=mask) + tl.store(result_ptr + offset, -1, mask=~mask) + + +def transform_index_page_table_decode_fast( + page_table: torch.Tensor, + topk_indices: torch.Tensor, + result: Optional[torch.Tensor] = None, + page_size: int = 1, +) -> torch.Tensor: + """ + Transform the page table according to topk indices for sparse topk attention. + Args: + page_table: [qo_len, max_seqlen_k], the original page table + topk_indices: [qo_len, topk], the topk indices for each query position + Returns: + transformed_page_table: [qo_len, topk], the transformed page table + For out-of-bound indices in topk_indices, this should be filled with -1. 
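+    Note: only page_size == 1 is supported (asserted below), and topk must be 2048
+    to match the TOPK constant hard-coded in the Triton kernel.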
+ """ + assert page_size == 1 + assert page_table.shape[0] == topk_indices.shape[0] + assert topk_indices.shape[1] == 2048 + qo_len = topk_indices.shape[0] + max_seqlen_k = page_table.shape[1] + if result is None: + result = torch.empty_like(topk_indices, dtype=torch.int32) + # Launch triton kernel + grid = (qo_len,) + transform_index_page_table_decode_kernel[grid]( + page_table, + topk_indices, + result, + page_size, + max_seqlen_k=max_seqlen_k, + ) + return result + + +def transform_index_page_table_prefill_fast( + page_table: torch.Tensor, + topk_indices: torch.Tensor, + extend_lens_cpu: List[int], + page_size: int = 1, +) -> torch.Tensor: + # TODO(baizhou): can be implemented with another triton kernel + assert page_size == 1 + result = torch.empty_like(topk_indices, dtype=torch.int32) + assert len(extend_lens_cpu) == page_table.shape[0] + offset = 0 + for i, l in enumerate(extend_lens_cpu): + transform_index_page_table_decode_fast( + page_table[i].unsqueeze(0).expand(l, -1), + topk_indices[offset : offset + l], + result=result[offset : offset + l], + ) + offset += l + assert offset == topk_indices.shape[0] + return result + + +def transform_index_page_table_decode_ref( + page_table: torch.Tensor, + topk_indices: torch.Tensor, + result: Optional[torch.Tensor] = None, + page_size: int = 1, +) -> torch.Tensor: + assert page_size == 1 + assert page_table.shape[0] == topk_indices.shape[0] + if result is None: + result = torch.empty_like(topk_indices, dtype=torch.int32) + assert result.shape == topk_indices.shape + torch.gather( + page_table, + dim=1, + index=topk_indices.clamp(min=0), + out=result, + ) + result[topk_indices < 0] = -1 + return result + + +def transform_index_page_table_prefill_ref( + page_table: torch.Tensor, + topk_indices: torch.Tensor, + extend_lens_cpu: List[int], + page_size: int = 1, +) -> torch.Tensor: + assert page_size == 1 + result = torch.empty_like(topk_indices, dtype=torch.int32) + assert len(extend_lens_cpu) == page_table.shape[0] + offset = 0 + for i, l in enumerate(extend_lens_cpu): + transform_index_page_table_decode_ref( + page_table[i].unsqueeze(0).expand(l, -1), + topk_indices[offset : offset + l], + result=result[offset : offset + l], + ) + offset += l + assert offset == topk_indices.shape[0] + return result + + +if __name__ == "__main__": + bs, topk, max_seqlen = 10, 2048, 3000 + page_table = torch.randint(0, 100, (bs, max_seqlen), device="cuda") + topk_indices = torch.full((bs, topk), -1, device="cuda") + topk_indices[:, :1600] = torch.arange(1600).unsqueeze(0).repeat(bs, 1) + ref_result = transform_index_page_table_decode_ref(page_table, topk_indices) + result = transform_index_page_table_decode_fast(page_table, topk_indices) + assert torch.all(result == ref_result) + print("Passed") diff --git a/python/sglang/srt/layers/attention/nsa/utils.py b/python/sglang/srt/layers/attention/nsa/utils.py new file mode 100644 index 00000000000..348f1b73645 --- /dev/null +++ b/python/sglang/srt/layers/attention/nsa/utils.py @@ -0,0 +1,24 @@ +# temp NSA debugging environ +from sglang.srt.utils import get_bool_env_var + +NSA_USE_REAL_INDEXER = get_bool_env_var("SGLANG_NSA_USE_REAL_INDEXER", "true") +NSA_DUAL_STREAM = get_bool_env_var("SGLANG_NSA_DUAL_STREAM", "true") +NSA_FUSE_TOPK = get_bool_env_var("SGLANG_NSA_FUSE_TOPK", "true") + +NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8 = get_bool_env_var( + "SGLANG_NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8", "true" +) +NSA_QUANT_K_CACHE_FAST = get_bool_env_var("SGLANG_NSA_QUANT_K_CACHE_FAST", "true") +NSA_DEQUANT_K_CACHE_FAST = 
get_bool_env_var("SGLANG_NSA_DEQUANT_K_CACHE_FAST", "true") + + +def print_nsa_bool_env_vars(): + msg = "" + for k, v in globals().items(): + if k.startswith("NSA_") and isinstance(v, bool): + msg += f"{k}={v} " + print(msg, flush=True) + + +def compute_nsa_seqlens(original_seq_lens, nsa_index_topk: int): + return original_seq_lens.clamp(max=nsa_index_topk) diff --git a/python/sglang/srt/layers/attention/nsa_backend.py b/python/sglang/srt/layers/attention/nsa_backend.py new file mode 100644 index 00000000000..74d293fd310 --- /dev/null +++ b/python/sglang/srt/layers/attention/nsa_backend.py @@ -0,0 +1,887 @@ +from __future__ import annotations + +import sys +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Literal, Optional, TypeAlias + +import torch + +from sglang.srt.configs.model_config import get_nsa_index_topk, is_deepseek_nsa +from sglang.srt.layers.attention.base_attn_backend import AttentionBackend +from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata +from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache +from sglang.srt.layers.attention.nsa.transform_index import ( + transform_index_page_table_decode, + transform_index_page_table_prefill, +) +from sglang.srt.layers.attention.nsa.utils import ( + NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8, + NSA_FUSE_TOPK, + compute_nsa_seqlens, +) +from sglang.srt.layers.dp_attention import get_attention_tp_size +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.utils import is_hip + +# from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache + +if TYPE_CHECKING: + from sglang.srt.layers.radix_attention import RadixAttention + from sglang.srt.model_executor.model_runner import ModelRunner + from sglang.srt.speculative.spec_info import SpecInput + +_is_hip = is_hip() + +if _is_hip: + try: + from aiter import ( + flash_attn_varlen_func, + mha_batch_prefill_func, + paged_attention_ragged, + ) + from aiter.mla import mla_decode_fwd, mla_prefill_fwd + except ImportError: + print( + "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device." + ) +else: + from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache + + +@dataclass(frozen=True) +class NSAFlashMLAMetadata: + """Metadata only needed by FlashMLA""" + + flashmla_metadata: torch.Tensor + num_splits: torch.Tensor + + def slice(self, sli): + return NSAFlashMLAMetadata( + flashmla_metadata=self.flashmla_metadata, + num_splits=self.num_splits[sli], + ) + + def copy_(self, other: "NSAFlashMLAMetadata"): + self.flashmla_metadata.copy_(other.flashmla_metadata) + self.num_splits.copy_(other.num_splits) + + +@dataclass(frozen=True) +class NSAMetadata: + page_size: int + + # Sequence lengths for the forward batch + cache_seqlens_int32: torch.Tensor + # Maximum sequence length for query + max_seq_len_q: int + # Maximum sequence length for key + max_seq_len_k: int + # Cumulative sequence lengths for query + cu_seqlens_q: torch.Tensor + # Cumulative sequence lengths for key + cu_seqlens_k: torch.Tensor + # Page table, the index of KV Cache Tables/Blocks + # this table is always with page_size = 1 + page_table_1: torch.Tensor + + # NOTE(dark): This will property be used in: + # 1. dense decode/prefill, we use paged flash attention, need real_page_table + # 2. 
sparse decode/prefill, indexer need real_page_table to compute the score + real_page_table: torch.Tensor + + # NSA metadata (nsa prefill are expanded) + nsa_cache_seqlens_int32: torch.Tensor # this seqlens is clipped to `topk` + nsa_cu_seqlens_q: torch.Tensor # must be arange(0, len(nsa_cu_seqlens_k)) + nsa_cu_seqlens_k: torch.Tensor # cumsum of `nsa_cache_seqlens_int32` + nsa_extend_seq_lens_list: List[int] + nsa_seqlens_expanded: torch.Tensor # expanded, unclipped `seqlens` + nsa_max_seqlen_q: Literal[1] = 1 # always 1 for decode, variable for extend + + flashmla_metadata: Optional[NSAFlashMLAMetadata] = None + + +@dataclass(frozen=True) +class NSAIndexerMetadata(BaseIndexerMetadata): + attn_metadata: NSAMetadata + + def get_seqlens_int32(self) -> torch.Tensor: + return self.attn_metadata.cache_seqlens_int32 + + def get_page_table_64(self) -> torch.Tensor: + return self.attn_metadata.real_page_table + + def get_seqlens_expanded(self) -> torch.Tensor: + return self.attn_metadata.nsa_seqlens_expanded + + def topk_transform( + self, + logits: torch.Tensor, + topk: int, + ) -> torch.Tensor: + from sgl_kernel import fast_topk_transform_fused, fast_topk_v2 + + if not NSA_FUSE_TOPK: + return fast_topk_v2(logits, self.get_seqlens_expanded(), topk) + + # NOTE(dark): if fused, we return a transformed page table directly + return fast_topk_transform_fused( + score=logits, + lengths=self.get_seqlens_expanded(), + page_table_size_1=self.attn_metadata.page_table_1, + cu_seqlens_q=self.attn_metadata.cu_seqlens_q, + topk=topk, + ) + + +def compute_cu_seqlens(seqlens: torch.Tensor) -> torch.Tensor: + assert seqlens.dtype == torch.int32 and seqlens.is_cuda + return torch.nn.functional.pad( + torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0) + ) + + +_NSA_IMPL_T: TypeAlias = Literal[ + "flashmla_prefill", "flashmla_decode", "fa3", "tilelang" +] + +NSA_PREFILL_IMPL: _NSA_IMPL_T +NSA_DECODE_IMPL: _NSA_IMPL_T + + +class NativeSparseAttnBackend(AttentionBackend): + def __init__(self, model_runner: ModelRunner): + super().__init__() + self.forward_metadata: NSAMetadata + self.device = model_runner.device + assert isinstance(model_runner.page_size, int) + self.real_page_size = model_runner.page_size + self.num_splits = ( + 1 if model_runner.server_args.enable_deterministic_inference else 0 + ) + self.use_nsa = is_deepseek_nsa(model_runner.model_config.hf_config) + assert self.use_nsa, "NSA backend only supports DeepSeek NSA" + self.nsa_kv_cache_store_fp8 = ( + model_runner.token_to_kv_pool.nsa_kv_cache_store_fp8 + ) + self.nsa_index_topk = get_nsa_index_topk(model_runner.model_config.hf_config) + self.max_context_len = model_runner.model_config.context_len + self.num_q_heads = ( + model_runner.model_config.num_attention_heads // get_attention_tp_size() + ) + self.kv_cache_dim = model_runner.token_to_kv_pool.kv_cache_dim + + assert model_runner.req_to_token_pool is not None + self.req_to_token = model_runner.req_to_token_pool.req_to_token + + global NSA_PREFILL_IMPL, NSA_DECODE_IMPL + NSA_PREFILL_IMPL = model_runner.server_args.nsa_prefill + NSA_DECODE_IMPL = model_runner.server_args.nsa_decode + + self._arange_buf = torch.arange(16384, device=self.device, dtype=torch.int32) + + if _is_hip: + max_bs = model_runner.req_to_token_pool.size + + self.kv_indptr = torch.zeros( + (max_bs + 1,), dtype=torch.int32, device=model_runner.device + ) + + def get_device_int32_arange(self, l: int) -> torch.Tensor: + if l > len(self._arange_buf): + next_pow_of_2 = 1 << (l - 1).bit_length() + self._arange_buf = torch.arange( 
+ next_pow_of_2, device=self.device, dtype=torch.int32 + ) + return self._arange_buf[:l] + + def _transform_table_1_to_real(self, page_table: torch.Tensor) -> torch.Tensor: + page_size = self.real_page_size + if page_size == 1: + return page_table + max_seqlen_k = page_table.shape[1] + strided_indices = torch.arange( + 0, max_seqlen_k, page_size, device=page_table.device, dtype=torch.int32 + ) + return page_table[:, strided_indices] // page_size + + def init_forward_metadata(self, forward_batch: ForwardBatch): + """Init the metadata for a forward pass.""" + batch_size = forward_batch.batch_size + device = forward_batch.seq_lens.device + + assert ( + forward_batch.spec_info is None + ), "Spec decoding is not supported for NSA backend now" + cache_seqlens_int32 = forward_batch.seq_lens.to(torch.int32) + cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32) + assert forward_batch.seq_lens_cpu is not None + max_seqlen_k = int(forward_batch.seq_lens_cpu.max().item()) + page_table = forward_batch.req_to_token_pool.req_to_token[ + forward_batch.req_pool_indices, :max_seqlen_k + ] + + if forward_batch.forward_mode.is_decode_or_idle(): + extend_seq_lens_cpu = [1] * batch_size + max_seqlen_q = 1 + cu_seqlens_q = self.get_device_int32_arange(batch_size + 1) + seqlens_expanded = cache_seqlens_int32 + elif forward_batch.forward_mode.is_extend(): + assert ( + forward_batch.extend_seq_lens_cpu is not None + and forward_batch.extend_seq_lens is not None + and forward_batch.extend_prefix_lens_cpu is not None + ), "All of them must not be None" + extend_seq_lens_cpu = forward_batch.extend_seq_lens_cpu + assert forward_batch.extend_seq_lens is not None + if any(forward_batch.extend_prefix_lens_cpu): + max_seqlen_q = max(extend_seq_lens_cpu) + cu_seqlens_q = compute_cu_seqlens( + forward_batch.extend_seq_lens.to(torch.int32) + ) + else: + max_seqlen_q = max_seqlen_k + cu_seqlens_q = cu_seqlens_k + seqlens_expanded = torch.cat( + [ + torch.arange( + kv_len - qo_len + 1, + kv_len + 1, + dtype=torch.int32, + device=device, + ) + for qo_len, kv_len in zip( + forward_batch.extend_seq_lens_cpu, + forward_batch.seq_lens_cpu.tolist(), + strict=True, + ) + ] + ) + else: + assert False, f"Unsupported {forward_batch.forward_mode = }" + + # 1D, expanded seqlens (1D means cheap to compute, so always compute it) + nsa_cache_seqlens_int32 = compute_nsa_seqlens( + original_seq_lens=seqlens_expanded, + nsa_index_topk=self.nsa_index_topk, + ) + nsa_cu_seqlens_k = compute_cu_seqlens(nsa_cache_seqlens_int32) + nsa_cu_seqlens_q = self.get_device_int32_arange(len(nsa_cu_seqlens_k)) + + metadata = NSAMetadata( + page_size=self.real_page_size, + cache_seqlens_int32=cache_seqlens_int32, + max_seq_len_q=max_seqlen_q, + max_seq_len_k=max_seqlen_k, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + page_table_1=page_table, + flashmla_metadata=( + self._compute_flashmla_metadata( + cache_seqlens=nsa_cache_seqlens_int32, + seq_len_q=1, # TODO handle MTP which is not 1 + ) + if NSA_DECODE_IMPL == "flashmla_decode" + else None + ), + nsa_cache_seqlens_int32=nsa_cache_seqlens_int32, + nsa_cu_seqlens_q=nsa_cu_seqlens_q, + nsa_cu_seqlens_k=nsa_cu_seqlens_k, + nsa_seqlens_expanded=seqlens_expanded, + nsa_extend_seq_lens_list=extend_seq_lens_cpu, + real_page_table=self._transform_table_1_to_real(page_table), + ) + + self.forward_metadata = metadata + + def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int): + """Initialize CUDA graph state for the attention backend. 
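+
+        Only decode (or idle) batches are captured and replayed; extend/prefill
+        does not go through CUDA graph in this backend (see the assertions in
+        init_forward_metadata_capture_cuda_graph).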
+ + Args: + max_bs (int): Maximum batch size to support in CUDA graphs + + This creates fixed-size tensors that will be reused during CUDA graph replay + to avoid memory allocations. + """ + self.decode_cuda_graph_metadata: Dict = { + "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device), + "cu_seqlens_q": torch.arange( + 0, max_bs + 1, dtype=torch.int32, device=self.device + ), + "cu_seqlens_k": torch.zeros( + max_bs + 1, dtype=torch.int32, device=self.device + ), + # fake page_table for sparse_prefill + "page_table": torch.zeros( + max_bs, + self.max_context_len, + dtype=torch.int32, + device=self.device, + ), + "flashmla_metadata": ( + self._compute_flashmla_metadata( + cache_seqlens=torch.ones( + max_bs, dtype=torch.int32, device=self.device + ), + seq_len_q=1, # TODO handle MTP which is not 1 + ) + if NSA_DECODE_IMPL == "flashmla_decode" + else None + ), + } + + def init_forward_metadata_capture_cuda_graph( + self, + bs: int, + num_tokens: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[SpecInput], + ): + """Initialize forward metadata for capturing CUDA graph.""" + assert forward_mode.is_decode_or_idle(), "Only support decode for now" + assert ( + spec_info is None + ), "Speculative decoding is not supported for NSA backend now" + + # Normal Decode + # Get sequence information + cache_seqlens_int32 = seq_lens.to(torch.int32) + cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32) + + # Use max context length for seq_len_k + page_table_1 = self.decode_cuda_graph_metadata["page_table"][:bs, :] + max_seq_len_k = page_table_1.shape[1] + + # Precompute page table + # Precompute cumulative sequence lengths + + # NOTE(dark): this is always arange, since we are decoding + cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][: bs + 1] + nsa_cache_seqlens_int32 = compute_nsa_seqlens( + cache_seqlens_int32, nsa_index_topk=self.nsa_index_topk + ) + nsa_cu_seqlens_k = compute_cu_seqlens(nsa_cache_seqlens_int32) + nsa_cu_seqlens_q = self.get_device_int32_arange(len(nsa_cu_seqlens_k)) + real_page_table = self._transform_table_1_to_real(page_table_1) + + if NSA_DECODE_IMPL == "flashmla_decode": + flashmla_metadata = self.decode_cuda_graph_metadata[ + "flashmla_metadata" + ].slice(slice(0, bs + 1)) + flashmla_metadata.copy_( + self._compute_flashmla_metadata( + cache_seqlens=nsa_cache_seqlens_int32, + seq_len_q=1, # TODO handle MTP which is not 1 + ) + ) + else: + flashmla_metadata = None + + metadata = NSAMetadata( + page_size=self.real_page_size, + cache_seqlens_int32=cache_seqlens_int32, + max_seq_len_q=1, + max_seq_len_k=max_seq_len_k, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + page_table_1=page_table_1, + flashmla_metadata=flashmla_metadata, + nsa_cache_seqlens_int32=nsa_cache_seqlens_int32, + nsa_cu_seqlens_q=nsa_cu_seqlens_q, + nsa_cu_seqlens_k=nsa_cu_seqlens_k, + nsa_seqlens_expanded=cache_seqlens_int32, + real_page_table=real_page_table, + nsa_extend_seq_lens_list=[1] * bs, + ) + self.decode_cuda_graph_metadata[bs] = metadata + self.forward_metadata = metadata + + def init_forward_metadata_replay_cuda_graph( + self, + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_sum: int, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[SpecInput], + seq_lens_cpu: Optional[torch.Tensor], + out_cache_loc: Optional[torch.Tensor] = None, + ): + """Initialize forward metadata for 
replaying CUDA graph.""" + assert seq_lens_cpu is not None + assert forward_mode.is_decode_or_idle(), "Only support decode for now" + assert ( + spec_info is None + ), "Speculative decoding is not supported for NSA backend now" + seq_lens = seq_lens[:bs] + seq_lens_cpu = seq_lens_cpu[:bs] + req_pool_indices = req_pool_indices[:bs] + + # Normal Decode + metadata: NSAMetadata = self.decode_cuda_graph_metadata[bs] + max_len = int(seq_lens_cpu.max().item()) + + cache_seqlens = seq_lens.to(torch.int32) + metadata.cache_seqlens_int32.copy_(cache_seqlens) + metadata.cu_seqlens_k[1:].copy_( + torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32) + ) + page_indices = self.req_to_token[req_pool_indices, :max_len] + metadata.page_table_1[:, :max_len].copy_(page_indices) + assert ( + metadata.nsa_cache_seqlens_int32 is not None + and metadata.nsa_cu_seqlens_k is not None + and self.nsa_index_topk is not None + ) + nsa_cache_seqlens = compute_nsa_seqlens(cache_seqlens, self.nsa_index_topk) + metadata.nsa_cache_seqlens_int32.copy_(nsa_cache_seqlens) + metadata.nsa_cu_seqlens_k[1:].copy_( + torch.cumsum(nsa_cache_seqlens, dim=0, dtype=torch.int32) + ) + # NOTE(dark): (nsa-) cu_seqlens_q is always arange, no need to copy + + assert self.real_page_size == metadata.page_size + if self.real_page_size > 1: + real_table = self._transform_table_1_to_real(page_indices) + new_len = real_table.shape[1] + metadata.real_page_table[:, :new_len].copy_(real_table) + else: + assert metadata.real_page_table is metadata.page_table_1 + + if NSA_DECODE_IMPL == "flashmla_decode": + metadata.flashmla_metadata.copy_( + self._compute_flashmla_metadata( + cache_seqlens=nsa_cache_seqlens, + seq_len_q=1, # TODO handle MTP which is not 1 + ) + ) + + self.forward_metadata = metadata + + def forward_extend( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, + # For multi-head latent attention + q_rope: Optional[torch.Tensor] = None, + k_rope: Optional[torch.Tensor] = None, + topk_indices: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + assert ( + not forward_batch.forward_mode.is_target_verify() + and not forward_batch.forward_mode.is_draft_extend() + ), "NSA backend doesn't support speculative decoding" + if k is not None: + assert v is not None + if save_kv_cache: + cache_loc = ( + forward_batch.out_cache_loc + if not layer.is_cross_attention + else forward_batch.encoder_out_cache_loc + ) + forward_batch.token_to_kv_pool.set_mla_kv_buffer( # type: ignore + layer, + cache_loc, + k, + k_rope, + ) + + metadata = self.forward_metadata + causal = not layer.is_cross_attention + assert causal, "NSA is causal only" + + # For fa3 interface version compatibility, we put new fields into conditional keyword args + kwargs = {} + + # Do absorbed multi-latent attention + assert q_rope is not None + kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) + + # when store in fp8 and compute in fp8, no need to convert dtype + if not ( + NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8 and self.nsa_kv_cache_store_fp8 + ): + kv_cache = kv_cache.to(q.dtype) + + if q_rope is not None: + q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim) + q_rope = q_rope.view( + -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim + ) + else: + q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim) + q_nope = q_all[:, :, : layer.v_head_dim] + q_rope = q_all[:, :, layer.v_head_dim :] + + # NOTE(dark): here, we use page size = 1 + + if 
NSA_FUSE_TOPK: + page_table_1 = topk_indices + else: + assert metadata.nsa_extend_seq_lens_list is not None + page_table_1 = transform_index_page_table_prefill( + page_table=metadata.page_table_1, + topk_indices=topk_indices, + extend_lens_cpu=metadata.nsa_extend_seq_lens_list, + page_size=1, + ) + if NSA_PREFILL_IMPL == "tilelang": + if q_rope is not None: + q_all = torch.cat([q_nope, q_rope], dim=-1) + return self._forward_tilelang( + q_all=q_all, + kv_cache=kv_cache, + page_table_1=page_table_1, + sm_scale=layer.scaling, + v_head_dim=layer.v_head_dim, + ) + elif NSA_PREFILL_IMPL == "flashmla_prefill": + if q_rope is not None: + q_all = torch.cat([q_nope, q_rope], dim=-1) + return self._forward_flashmla_prefill( + q_all=q_all, + kv_cache=kv_cache, + page_table_1=page_table_1, + sm_scale=layer.scaling, + v_head_dim=layer.v_head_dim, + ) + elif NSA_PREFILL_IMPL == "flashmla_decode": + if q_rope is not None: + q_all = torch.cat([q_nope, q_rope], dim=-1) + return self._forward_flashmla_decode( + q_all=q_all, + kv_cache=kv_cache, + sm_scale=layer.scaling, + v_head_dim=layer.v_head_dim, + # TODO optimize args + layer=layer, + metadata=metadata, + page_table_1=page_table_1, + ) + elif NSA_PREFILL_IMPL == "fa3": + return self._forward_fa3( + q_rope=q_rope, + kv_cache=kv_cache, + v_head_dim=layer.v_head_dim, + q_nope=q_nope, + page_table=page_table_1, + cache_seqlens=metadata.nsa_cache_seqlens_int32, + cu_seqlens_q=metadata.nsa_cu_seqlens_q, + cu_seqlens_k=metadata.nsa_cu_seqlens_k, + max_seqlen_q=metadata.nsa_max_seqlen_q, + sm_scale=layer.scaling, + logit_cap=layer.logit_cap, + page_size=1, + ) + else: + raise ValueError(f"Unsupported {NSA_PREFILL_IMPL = }") + + def forward_decode( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, + # For multi-head latent attention + q_rope: Optional[torch.Tensor] = None, + k_rope: Optional[torch.Tensor] = None, + topk_indices: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if k is not None: + assert v is not None + if save_kv_cache: + cache_loc = ( + forward_batch.out_cache_loc + if not layer.is_cross_attention + else forward_batch.encoder_out_cache_loc + ) + forward_batch.token_to_kv_pool.set_mla_kv_buffer( # type: ignore + layer, + cache_loc, + k, + k_rope, + ) + + metadata = self.forward_metadata + causal = not layer.is_cross_attention + assert causal, "NSA is causal only" + + # Do absorbed multi-latent attention + kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) + if q_rope is not None: + q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim) + q_rope = q_rope.view( + -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim + ) + else: + q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim) + q_nope = q_all[:, :, : layer.v_head_dim] + q_rope = q_all[:, :, layer.v_head_dim :] + + if NSA_FUSE_TOPK: + page_table_1 = topk_indices + else: + page_table_1 = transform_index_page_table_decode( + page_table=metadata.page_table_1, + topk_indices=topk_indices, + page_size=1, + ) + + if NSA_DECODE_IMPL == "flashmla_prefill": + if q_rope is not None: + q_all = torch.cat([q_nope, q_rope], dim=-1) + return self._forward_flashmla_prefill( + q_all=q_all, + kv_cache=kv_cache, + page_table_1=page_table_1, + sm_scale=layer.scaling, + v_head_dim=layer.v_head_dim, + ) + elif NSA_DECODE_IMPL == "flashmla_decode": + if q_rope is not None: + q_all = torch.cat([q_nope, q_rope], dim=-1) + return self._forward_flashmla_decode( + q_all=q_all, 
+ kv_cache=kv_cache, + sm_scale=layer.scaling, + v_head_dim=layer.v_head_dim, + # TODO optimize args + layer=layer, + metadata=metadata, + page_table_1=page_table_1, + ) + elif NSA_DECODE_IMPL == "tilelang": + if q_rope is not None: + q_all = torch.cat([q_nope, q_rope], dim=-1) + return self._forward_tilelang( + q_all=q_all, + kv_cache=kv_cache, + page_table_1=page_table_1, + sm_scale=layer.scaling, + v_head_dim=layer.v_head_dim, + ) + elif NSA_DECODE_IMPL == "fa3": + return self._forward_fa3( + q_rope=q_rope, + kv_cache=kv_cache, + v_head_dim=layer.v_head_dim, + q_nope=q_nope, + page_table=page_table_1, + cache_seqlens=metadata.nsa_cache_seqlens_int32, + cu_seqlens_q=metadata.nsa_cu_seqlens_q, + cu_seqlens_k=metadata.nsa_cu_seqlens_k, + max_seqlen_q=metadata.nsa_max_seqlen_q, + sm_scale=layer.scaling, + logit_cap=layer.logit_cap, + page_size=1, + ) + elif NSA_DECODE_IMPL == "aiter": + if q_rope is not None: + q_all = torch.cat([q_nope, q_rope], dim=-1) + return self._forward_aiter( + q_all=q_all, + kv_cache=kv_cache, + page_table_1=page_table_1, + layer=layer, + metadata=metadata, + bs=forward_batch.batch_size, + ) + + else: + assert False, f"Unsupported {NSA_DECODE_IMPL = }" + + def _forward_fa3( + self, + q_rope: torch.Tensor, + kv_cache: torch.Tensor, + v_head_dim: int, + q_nope: torch.Tensor, + page_table: torch.Tensor, + cache_seqlens: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + sm_scale: float, + logit_cap: float, + page_size: int, + ) -> torch.Tensor: + k_rope_cache = kv_cache[:, :, v_head_dim:] + c_kv_cache = kv_cache[:, :, :v_head_dim] + qk_rope_dim = k_rope_cache.shape[-1] + k_rope_cache = k_rope_cache.view(-1, page_size, 1, qk_rope_dim) + c_kv_cache = c_kv_cache.view(-1, page_size, 1, v_head_dim) + o = flash_attn_with_kvcache( + q=q_rope, + k_cache=k_rope_cache, + v_cache=c_kv_cache, + qv=q_nope, + page_table=page_table, + cache_seqlens=cache_seqlens, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k_new=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + softmax_scale=sm_scale, + causal=True, + softcap=logit_cap, + return_softmax_lse=False, + num_splits=self.num_splits, + ) + return o # type: ignore + + def _forward_flashmla_prefill( + self, + q_all: torch.Tensor, + kv_cache: torch.Tensor, + v_head_dim: int, + page_table_1: torch.Tensor, + sm_scale: float, + ) -> torch.Tensor: + from flash_mla import flash_mla_sparse_fwd + + o, _, _ = flash_mla_sparse_fwd( + q=q_all, + kv=kv_cache, + indices=page_table_1.unsqueeze(1), + sm_scale=sm_scale, + d_v=v_head_dim, + ) + return o + + def _forward_flashmla_decode( + self, + q_all: torch.Tensor, + kv_cache: torch.Tensor, + v_head_dim: int, + sm_scale: float, + layer, + metadata: NSAMetadata, + page_table_1, + ) -> torch.Tensor: + from flash_mla import flash_mla_with_kvcache + + cache_seqlens = metadata.nsa_cache_seqlens_int32 + + # TODO the 2nd dim is seq_len_q, need to be >1 when MTP + q_all = q_all.view(-1, 1, layer.tp_q_head_num, layer.head_dim) + kv_cache = kv_cache.view(-1, self.real_page_size, 1, self.kv_cache_dim) + assert self.real_page_size == 64, "only page size 64 is supported" + + if NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8 and not self.nsa_kv_cache_store_fp8: + # inefficiently quantize the whole cache + kv_cache = quantize_k_cache(kv_cache) + + indices = page_table_1.unsqueeze(1) + assert ( + indices.shape[-1] == self.nsa_index_topk + ) # requirement of FlashMLA decode kernel + + o, _ = flash_mla_with_kvcache( + q=q_all, + k_cache=kv_cache, + cache_seqlens=cache_seqlens, + 
head_dim_v=v_head_dim, + tile_scheduler_metadata=metadata.flashmla_metadata.flashmla_metadata, + num_splits=metadata.flashmla_metadata.num_splits, + softmax_scale=sm_scale, + indices=indices, + # doc says it is not used, but if pass in None then error + block_table=torch.empty( + (q_all.shape[0], 0), dtype=torch.int32, device=q_all.device + ), + is_fp8_kvcache=NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8, + ) + return o + + def _forward_tilelang( + self, + q_all: torch.Tensor, + kv_cache: torch.Tensor, + v_head_dim: int, + page_table_1: torch.Tensor, + sm_scale: float, + ) -> torch.Tensor: + from sglang.srt.layers.attention.nsa.tilelang_kernel import tilelang_sparse_fwd + + return tilelang_sparse_fwd( + q=q_all, + kv=kv_cache, + indices=page_table_1.unsqueeze(1), + sm_scale=sm_scale, + d_v=v_head_dim, + ) + + def _forward_aiter( + self, + q_all: torch.Tensor, + kv_cache: torch.Tensor, + page_table_1: torch.Tensor, + layer: RadixAttention, + metadata: NSAMetadata, + bs: int, + ) -> torch.Tensor: + q = q_all.reshape(-1, layer.tp_q_head_num * layer.head_dim) + + if layer.head_dim != layer.v_head_dim: + o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim)) + else: + o = torch.empty_like(q) + + kv_indptr = self.kv_indptr + + non_minus1_mask = page_table_1 != -1 + non_minus1_counts = non_minus1_mask.sum(dim=1) + kv_indptr[1 : bs + 1] = torch.cumsum(non_minus1_counts, dim=0) + + kv_indices = page_table_1[page_table_1 != -1] + + mla_decode_fwd( + q.view(-1, layer.tp_q_head_num, layer.head_dim), + kv_cache.view(-1, 1, 1, layer.head_dim), + o.view(-1, layer.tp_q_head_num, layer.v_head_dim), + metadata.cu_seqlens_q, + kv_indptr, + kv_indices, + metadata.cu_seqlens_q, + metadata.max_seq_len_q, + layer.scaling, + layer.logit_cap, + ) + # kv_cache = kv_cache.view(-1, 1, layer.head_dim) + return o + + def get_cuda_graph_seq_len_fill_value(self): + """Get the fill value for sequence length in CUDA graph.""" + return 1 + + def get_indexer_metadata( + self, layer_id: int, forward_batch: ForwardBatch + ) -> NSAIndexerMetadata: + return NSAIndexerMetadata(attn_metadata=self.forward_metadata) + + def _compute_flashmla_metadata(self, cache_seqlens: torch.Tensor, seq_len_q: int): + from flash_mla import get_mla_metadata + + flashmla_metadata, num_splits = get_mla_metadata( + cache_seqlens=cache_seqlens, + # TODO doc says `num_q_tokens_per_q_seq * num_heads_q // num_heads_k` + # but the name looks like need seq_len_q? 
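+            # With num_heads_k == 1 (a single latent KV head under MLA), the documented
+            # formula num_q_tokens_per_q_seq * num_heads_q // num_heads_k reduces to
+            # seq_len_q * num_q_heads, which is what is passed below.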
+ num_q_tokens_per_head_k=seq_len_q * self.num_q_heads // 1, + num_heads_k=1, + num_heads_q=self.num_q_heads, + is_fp8_kvcache=NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8, + topk=self.nsa_index_topk, + ) + + return NSAFlashMLAMetadata( + flashmla_metadata=flashmla_metadata, + num_splits=num_splits, + ) diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 5f2813da842..76f48bc4b4e 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -813,45 +813,69 @@ def _forward_normal(dispatch_output: DeepEPNormalOutput): if isinstance(hidden_states, tuple): per_token_scale = hidden_states[1] hidden_states = hidden_states[0] - else: - # dynamic quant - hidden_states, per_token_scale = torch_npu.npu_dynamic_quant( - hidden_states - ) group_list = torch.tensor(num_recv_tokens_per_expert, dtype=torch.int64).to( hidden_states.device ) + if self.w13_weight.dtype != torch.int8: + # gmm1: gate_up_proj + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[self.w13_weight.permute(0, 2, 1)], + # per_token_scale=[per_token_scale], + split_item=2, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + output_dtype=output_dtype, + )[0] + hidden_states = torch_npu.npu_swiglu(hidden_states) + # gmm2: down_proj + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[self.w2_weight.permute(0, 2, 1)], + split_item=2, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + output_dtype=output_dtype, + )[0] + else: + if not get_bool_env_var("DEEP_NORMAL_MODE_USE_INT8_QUANT"): + hidden_states, per_token_scale = torch_npu.npu_dynamic_quant( + hidden_states + ) + # gmm1: gate_up_proj + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[self.w13_weight], + scale=[self.w13_weight_scale.to(output_dtype)], + per_token_scale=[per_token_scale], + split_item=2, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + output_dtype=output_dtype, + )[0] + + # act_fn: swiglu + hidden_states = torch_npu.npu_swiglu(hidden_states) + hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant( + hidden_states + ) - # gmm1: gate_up_proj - hidden_states = torch_npu.npu_grouped_matmul( - x=[hidden_states], - weight=[self.w13_weight], - scale=[self.w13_weight_scale.to(output_dtype)], - per_token_scale=[per_token_scale], - split_item=2, - group_list_type=group_list_type, - group_type=0, - group_list=group_list, - output_dtype=output_dtype, - )[0] - - # act_fn: swiglu - hidden_states = torch_npu.npu_swiglu(hidden_states) - hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states) - - # gmm2: down_proj - hidden_states = torch_npu.npu_grouped_matmul( - x=[hidden_states], - weight=[self.w2_weight], - scale=[self.w2_weight_scale.to(output_dtype)], - per_token_scale=[swiglu_out_scale], - split_item=2, - group_list_type=group_list_type, - group_type=0, - group_list=group_list, - output_dtype=output_dtype, - )[0] + # gmm2: down_proj + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[self.w2_weight], + scale=[self.w2_weight_scale.to(output_dtype)], + per_token_scale=[swiglu_out_scale], + split_item=2, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + output_dtype=output_dtype, + )[0] return hidden_states @@ -860,47 +884,72 @@ def _forward_ll(dispatch_output: DeepEPLLOutput): assert isinstance(dispatch_output, DeepEPLLOutput) 
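Note on the int8 branches above: torch_npu.npu_dynamic_quant is assumed to perform per-token symmetric int8 quantization, returning the quantized activations plus one float scale per token, which is what the grouped matmuls consume through per_token_scale. A minimal pure-PyTorch sketch of that assumed contract (illustrative reference only, not the torch_npu kernel):

```python
import torch

def dynamic_quant_per_token(x: torch.Tensor):
    # One symmetric int8 scale per row (token): scale = amax / 127.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return q, scale.squeeze(-1)  # (tokens, hidden) int8, (tokens,) fp32

x = torch.randn(4, 16)
q, s = dynamic_quant_per_token(x)
# Dequantization recovers the input up to half a quantization step per element.
assert torch.allclose(q.float() * s.unsqueeze(-1), x, atol=s.max().item())
```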
hidden_states, topk_idx, topk_weights, group_list, _ = dispatch_output - per_token_scale = hidden_states[1] - hidden_states = hidden_states[0] + if isinstance(hidden_states, tuple): + per_token_scale = hidden_states[1] + hidden_states = hidden_states[0] group_list = group_list.to(torch.int64) - # gmm1: gate_up_proj - hidden_states = torch_npu.npu_grouped_matmul( - x=[hidden_states], - weight=[self.w13_weight], - split_item=2, - group_list_type=group_list_type, - group_type=0, - group_list=group_list, - output_dtype=torch.int32, - )[0] - - # act_fn: swiglu - hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant( - x=hidden_states, - weight_scale=self.w13_weight_scale.to(torch.float32), - activation_scale=per_token_scale, - bias=None, - quant_scale=None, - quant_offset=None, - group_index=group_list, - activate_left=True, - quant_mode=1, - ) + if self.w13_weight.dtype != torch.int8: + # gmm1: gate_up_proj + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[self.w13_weight.permute(0, 2, 1)], + # per_token_scale=[per_token_scale], + split_item=2, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + output_dtype=output_dtype, + )[0] + hidden_states = torch_npu.npu_swiglu(hidden_states) + # gmm2: down_proj + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[self.w2_weight.permute(0, 2, 1)], + split_item=2, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + output_dtype=output_dtype, + )[0] + else: + # gmm1: gate_up_proj + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[self.w13_weight], + split_item=2, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + output_dtype=torch.int32, + )[0] + + # act_fn: swiglu + hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant( + x=hidden_states, + weight_scale=self.w13_weight_scale.to(torch.float32), + activation_scale=per_token_scale, + bias=None, + quant_scale=None, + quant_offset=None, + group_index=group_list, + activate_left=True, + quant_mode=1, + ) - # gmm2: down_proj - hidden_states = torch_npu.npu_grouped_matmul( - x=[hidden_states], - weight=[self.w2_weight], - scale=[self.w2_weight_scale.to(output_dtype)], - per_token_scale=[swiglu_out_scale], - split_item=2, - group_list_type=group_list_type, - group_type=0, - group_list=group_list, - output_dtype=output_dtype, - )[0] + # gmm2: down_proj + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[self.w2_weight], + scale=[self.w2_weight_scale.to(output_dtype)], + per_token_scale=[swiglu_out_scale], + split_item=2, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + output_dtype=output_dtype, + )[0] return hidden_states diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 7ec246e4bb3..fba0664339f 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -112,6 +112,8 @@ "enable_custom_logit_processor", "disaggregation_mode", "enable_deterministic_inference", + "nsa_prefill", + "nsa_decode", ] # Put some global args for easy access diff --git a/python/sglang/srt/mem_cache/allocator_ascend.py b/python/sglang/srt/mem_cache/allocator_ascend.py index 0bb1eaf0a5f..14fc1d1e362 100644 --- a/python/sglang/srt/mem_cache/allocator_ascend.py +++ b/python/sglang/srt/mem_cache/allocator_ascend.py @@ -76,35 +76,49 @@ def alloc_extend( (last_loc + 1) % self.page_size == 
prefix_lens % self.page_size ) - num_new_pages = get_num_new_pages( - seq_lens=seq_lens_cpu, - page_size=self.page_size, - prefix_lens=prefix_lens_cpu, - ) - if self.need_sort and num_new_pages > len(self.free_pages): + num_new_pages = ( + (seq_lens + self.page_size - 1) // self.page_size + - (prefix_lens + self.page_size - 1) // self.page_size + ).sum() + num_new_pages_item = num_new_pages.item() + if self.need_sort and num_new_pages_item > len(self.free_pages): self.merge_and_sort_free() - if num_new_pages > len(self.free_pages): + if num_new_pages_item > len(self.free_pages): return None out_indices = torch.empty( - (extend_num_tokens,), dtype=torch.int32, device=self.device + (extend_num_tokens,), dtype=torch.int64, device=self.device ) - alloc_extend_kernel_ascend( - prefix_lens, - seq_lens, - last_loc, - self.free_pages, - out_indices, - self.page_size, - self.device, - ) + if num_new_pages_item < 200: + import sgl_kernel_npu + + torch.ops.npu.alloc_extend( + prefix_lens, + seq_lens, + last_loc, + self.free_pages, + self.page_size, + out_indices, + num_new_pages, + ) + + else: + alloc_extend_kernel_ascend( + prefix_lens, + seq_lens, + last_loc, + self.free_pages, + out_indices, + self.page_size, + self.device, + ) if self.debug_mode: assert len(torch.unique(out_indices)) == len(out_indices) - self.free_pages = self.free_pages[num_new_pages:] + self.free_pages = self.free_pages[num_new_pages_item:] return out_indices def alloc_decode( diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 5b0f8a7141c..11249ff7906 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -15,6 +15,8 @@ from __future__ import annotations +from sglang.srt.layers.attention.nsa import index_buf_accessor +from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter """ @@ -1030,6 +1032,8 @@ def __init__( enable_memory_saver: bool, start_layer: Optional[int] = None, end_layer: Optional[int] = None, + use_nsa: bool = False, + override_kv_cache_dim: Optional[int] = None, ): super().__init__( size, @@ -1044,6 +1048,14 @@ def __init__( self.kv_lora_rank = kv_lora_rank self.qk_rope_head_dim = qk_rope_head_dim + self.use_nsa = use_nsa + self.nsa_kv_cache_store_fp8 = use_nsa and dtype == torch.float8_e4m3fn + # TODO do not hardcode + self.kv_cache_dim = ( + 656 + if self.use_nsa and self.nsa_kv_cache_store_fp8 + else (kv_lora_rank + qk_rope_head_dim) + ) # for disagg with nvlink self.enable_custom_mem_pool = get_bool_env_var( @@ -1067,7 +1079,7 @@ def __init__( # The padded slot 0 is used for writing dummy outputs from padded tokens. 
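The hard-coded kv_cache_dim = 656 above is flagged as a TODO; a plausible breakdown, assuming the FP8 NSA cache keeps the 512-dim latent in FP8 with one FP32 scale per 128-element block and the 64-dim rope part in BF16 (an inference from this diff, not a documented constant):

```python
kv_lora_rank, qk_rope_head_dim, quant_block = 512, 64, 128  # DeepSeek-V3-family shapes

fp8_nope_bytes = kv_lora_rank * 1                # latent c_kv stored as FP8
scale_bytes = (kv_lora_rank // quant_block) * 4  # four FP32 block scales
rope_bytes = qk_rope_head_dim * 2                # k_rope kept in BF16
assert fp8_nope_bytes + scale_bytes + rope_bytes == 656  # bytes per token
```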
self.kv_buffer = [ torch.zeros( - (size + page_size, 1, kv_lora_rank + qk_rope_head_dim), + (size + page_size, 1, self.kv_cache_dim), dtype=self.store_dtype, device=device, ) @@ -1130,6 +1142,7 @@ def set_kv_buffer( cache_v: torch.Tensor, ): layer_id = layer.layer_id + assert not (self.use_nsa and self.nsa_kv_cache_store_fp8) if cache_k.dtype != self.dtype: cache_k = cache_k.to(self.dtype) if self.store_dtype != self.dtype: @@ -1147,16 +1160,28 @@ def set_mla_kv_buffer( cache_k_rope: torch.Tensor, ): layer_id = layer.layer_id - if cache_k_nope.dtype != self.dtype: - cache_k_nope = cache_k_nope.to(self.dtype) - cache_k_rope = cache_k_rope.to(self.dtype) - if self.store_dtype != self.dtype: - cache_k_nope = cache_k_nope.view(self.store_dtype) - cache_k_rope = cache_k_rope.view(self.store_dtype) - set_mla_kv_buffer_triton( - self.kv_buffer[layer_id - self.start_layer], loc, cache_k_nope, cache_k_rope - ) + if self.use_nsa and self.nsa_kv_cache_store_fp8: + # original cache_k: (num_tokens, num_heads 1, hidden 576); we unsqueeze the page_size=1 dim here + # TODO no need to cat + cache_k = torch.cat([cache_k_nope, cache_k_rope], dim=-1) + cache_k = quantize_k_cache(cache_k.unsqueeze(1)).squeeze(1) + cache_k = cache_k.view(self.store_dtype) + self.kv_buffer[layer_id - self.start_layer][loc] = cache_k + else: + if cache_k_nope.dtype != self.dtype: + cache_k_nope = cache_k_nope.to(self.dtype) + cache_k_rope = cache_k_rope.to(self.dtype) + if self.store_dtype != self.dtype: + cache_k_nope = cache_k_nope.view(self.store_dtype) + cache_k_rope = cache_k_rope.view(self.store_dtype) + + set_mla_kv_buffer_triton( + self.kv_buffer[layer_id - self.start_layer], + loc, + cache_k_nope, + cache_k_rope, + ) def get_cpu_copy(self, indices): torch.cuda.synchronize() @@ -1186,6 +1211,103 @@ def load_cpu_copy(self, kv_cache_cpu, indices): torch.cuda.synchronize() +class NSATokenToKVPool(MLATokenToKVPool): + def __init__( + self, + size: int, + page_size: int, + kv_lora_rank: int, + dtype: torch.dtype, + qk_rope_head_dim: int, + layer_num: int, + device: str, + index_head_dim: int, + enable_memory_saver: bool, + start_layer: Optional[int] = None, + end_layer: Optional[int] = None, + ): + super().__init__( + size, + page_size, + dtype, + kv_lora_rank, + qk_rope_head_dim, + layer_num, + device, + enable_memory_saver, + start_layer, + end_layer, + use_nsa=True, + ) + # self.index_k_dtype = torch.float8_e4m3fn + # self.index_k_scale_dtype = torch.float32 + self.index_head_dim = index_head_dim + # num head == 1 and head dim == 128 for index_k in NSA + assert index_head_dim == 128 + + self.quant_block_size = 128 + + assert self.page_size == 64 + self.index_k_with_scale_buffer = [ + torch.zeros( + # Layout: + # ref: test_attention.py :: kv_cache_cast_to_fp8 + # shape: (num_pages, page_size 64 * head_dim 128 + page_size 64 * fp32_nbytes 4) + # data: for page i, + # * buf[i, :page_size * head_dim] for fp8 data + # * buf[i, page_size * head_dim:].view(float32) for scale + ( + (size + page_size + 1) // self.page_size, + self.page_size + * (index_head_dim + index_head_dim // self.quant_block_size * 4), + ), + dtype=torch.uint8, + device=device, + ) + for _ in range(layer_num) + ] + + def get_index_k_with_scale_buffer(self, layer_id: int) -> torch.Tensor: + if self.layer_transfer_counter is not None: + self.layer_transfer_counter.wait_until(layer_id - self.start_layer) + return self.index_k_with_scale_buffer[layer_id - self.start_layer] + + def get_index_k_continuous( + self, + layer_id: int, + seq_len: int, + page_indices: 
torch.Tensor, + ): + buf = self.index_k_with_scale_buffer[layer_id - self.start_layer] + return index_buf_accessor.GetK.execute( + self, buf, seq_len=seq_len, page_indices=page_indices + ) + + def get_index_k_scale_continuous( + self, + layer_id: int, + seq_len: int, + page_indices: torch.Tensor, + ): + buf = self.index_k_with_scale_buffer[layer_id - self.start_layer] + return index_buf_accessor.GetS.execute( + self, buf, seq_len=seq_len, page_indices=page_indices + ) + + # TODO rename later (currently use diff name to avoid confusion) + def set_index_k_and_scale_buffer( + self, + layer_id: int, + loc: torch.Tensor, + index_k: torch.Tensor, + index_k_scale: torch.Tensor, + ) -> None: + buf = self.index_k_with_scale_buffer[layer_id - self.start_layer] + index_buf_accessor.SetKAndS.execute( + pool=self, buf=buf, loc=loc, index_k=index_k, index_k_scale=index_k_scale + ) + + class AscendMLAPagedTokenToKVPool(MLATokenToKVPool): def __init__( self, @@ -1194,6 +1316,7 @@ def __init__( dtype: torch.dtype, kv_lora_rank: int, qk_rope_head_dim: int, + index_head_dim: Optional[int], layer_num: int, device: str, enable_memory_saver: bool, @@ -1213,6 +1336,7 @@ def __init__( self.kv_lora_rank = kv_lora_rank self.qk_rope_head_dim = qk_rope_head_dim + self.index_head_dim = index_head_dim self.custom_mem_pool = None @@ -1240,6 +1364,18 @@ def __init__( dtype=self.store_dtype, device=self.device, ) + if self.index_head_dim is not None: + self.index_k_buffer = torch.zeros( + ( + layer_num, + self.size // self.page_size + 1, + self.page_size, + 1, + self.index_head_dim, + ), + dtype=self.store_dtype, + device=self.device, + ) self._finalize_allocation_log(size) @@ -1251,6 +1387,10 @@ def get_kv_size_bytes(self): kv_size_bytes += get_tensor_size_bytes(k_cache) for v_cache in self.v_buffer: kv_size_bytes += get_tensor_size_bytes(v_cache) + if self.index_head_dim is not None: + assert hasattr(self, "index_k_buffer") + for index_k_cache in self.index_k_buffer: + kv_size_bytes += get_tensor_size_bytes(index_k_cache) return kv_size_bytes def get_kv_buffer(self, layer_id: int): @@ -1277,6 +1417,14 @@ def get_value_buffer(self, layer_id: int): return self.v_buffer[layer_id - self.start_layer].view(self.dtype) return self.v_buffer[layer_id - self.start_layer] + def get_index_k_buffer(self, layer_id: int): + if self.layer_transfer_counter is not None: + self.layer_transfer_counter.wait_until(layer_id - self.start_layer) + + if self.store_dtype != self.dtype: + return self.index_k_buffer[layer_id - self.start_layer].view(self.dtype) + return self.index_k_buffer[layer_id - self.start_layer] + # for disagg def get_contiguous_buf_infos(self): # MLA has only one kv_buffer, so only the information of this buffer needs to be returned. 
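The index_k_with_scale_buffer layout documented above (per page: page_size * head_dim FP8 bytes followed by page_size FP32 scales) can be read back roughly as in the sketch below; the real accessors live in index_buf_accessor and may differ in detail:

```python
import torch

PAGE_SIZE, HEAD_DIM = 64, 128  # values asserted in NSATokenToKVPool

def unpack_index_page(buf_page: torch.Tensor) -> torch.Tensor:
    # buf_page: uint8 tensor of shape (PAGE_SIZE * HEAD_DIM + PAGE_SIZE * 4,)
    k_fp8 = buf_page[: PAGE_SIZE * HEAD_DIM].view(torch.float8_e4m3fn)
    k_fp8 = k_fp8.view(PAGE_SIZE, HEAD_DIM)
    scales = buf_page[PAGE_SIZE * HEAD_DIM :].view(torch.float32).view(PAGE_SIZE, 1)
    return k_fp8.to(torch.float32) * scales  # dequantized (PAGE_SIZE, HEAD_DIM)

page = torch.zeros(PAGE_SIZE * HEAD_DIM + PAGE_SIZE * 4, dtype=torch.uint8)
print(unpack_index_page(page).shape)  # torch.Size([64, 128])
```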
@@ -1289,6 +1437,16 @@ def get_contiguous_buf_infos(self): kv_item_lens = [self.k_buffer[i][0].nbytes for i in range(self.layer_num)] + [ self.v_buffer[i][0].nbytes for i in range(self.layer_num) ] + if self.index_head_dim is not None: + kv_data_ptrs += [ + self.index_k_buffer[i].data_ptr() for i in range(self.layer_num) + ] + kv_data_lens += [ + self.index_k_buffer[i].nbytes for i in range(self.layer_num) + ] + kv_item_lens += [ + self.index_k_buffer[i][0].nbytes for i in range(self.layer_num) + ] return kv_data_ptrs, kv_data_lens, kv_item_lens def set_kv_buffer( @@ -1325,6 +1483,26 @@ def set_kv_buffer( cache_v.view(-1, 1, self.qk_rope_head_dim), ) + def set_index_k_buffer( + self, + layer_id: int, + loc: torch.Tensor, + index_k: torch.Tensor, + ): + if index_k.dtype != self.dtype: + index_k = index_k.to(self.dtype) + + if self.store_dtype != self.dtype: + index_k = index_k.view(self.store_dtype) + + torch_npu.npu_scatter_nd_update_( + self.index_k_buffer[layer_id - self.start_layer].view( + -1, 1, self.index_head_dim + ), + loc.view(-1, 1), + index_k.view(-1, 1, self.index_head_dim), + ) + class DoubleSparseTokenToKVPool(KVCache): def __init__( diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 3ad2f450c88..287efd1decd 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -522,6 +522,7 @@ def capture_one_batch_size(self, bs: int, forward: Callable): input_ids = self.input_ids[:num_tokens] req_pool_indices = self.req_pool_indices[:bs] seq_lens = self.seq_lens[:bs] + seq_lens_cpu = self.seq_lens_cpu[:bs] out_cache_loc = self.out_cache_loc[:num_tokens] positions = self.positions[:num_tokens] if self.is_encoder_decoder: @@ -592,6 +593,7 @@ def capture_one_batch_size(self, bs: int, forward: Callable): input_ids=input_ids, req_pool_indices=req_pool_indices, seq_lens=seq_lens, + seq_lens_cpu=seq_lens_cpu, next_token_logits_buffer=next_token_logits_buffer, orig_seq_lens=seq_lens, req_to_token_pool=self.model_runner.req_to_token_pool, diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index fce792a0432..4309f52118e 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -293,6 +293,7 @@ class ForwardBatch: # For padding padded_static_len: int = -1 # -1 if not padded num_token_non_padded: Optional[torch.Tensor] = None # scalar tensor + num_token_non_padded_cpu: int = None # For Qwen2-VL mrope_positions: torch.Tensor = None @@ -354,6 +355,7 @@ def init_new( ret.num_token_non_padded = torch.tensor( len(batch.input_ids), dtype=torch.int32 ).to(device, non_blocking=True) + ret.num_token_non_padded_cpu = len(batch.input_ids) # For MLP sync if batch.global_num_tokens is not None: diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 2d87ec6f6c9..92b5dfa0a52 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -31,7 +31,12 @@ from sglang.srt.configs.device_config import DeviceConfig from sglang.srt.configs.load_config import LoadConfig, LoadFormat -from sglang.srt.configs.model_config import AttentionArch, ModelConfig +from sglang.srt.configs.model_config import ( + AttentionArch, + ModelConfig, + get_nsa_index_head_dim, + is_deepseek_nsa, +) from 
sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS from sglang.srt.distributed import ( @@ -96,6 +101,7 @@ HybridReqToTokenPool, MHATokenToKVPool, MLATokenToKVPool, + NSATokenToKVPool, ReqToTokenPool, SWAKVPool, ) @@ -157,6 +163,7 @@ "cutlass_mla", "trtllm_mla", "ascend", + "nsa", ] @@ -1547,6 +1554,7 @@ def init_memory_pool( assert self.is_draft_worker # Initialize token_to_kv_pool + is_nsa_model = is_deepseek_nsa(self.model_config.hf_config) if self.server_args.attention_backend == "ascend": if self.use_mla_backend: self.token_to_kv_pool = AscendMLAPagedTokenToKVPool( @@ -1555,6 +1563,7 @@ def init_memory_pool( dtype=self.kv_cache_dtype, kv_lora_rank=self.model_config.kv_lora_rank, qk_rope_head_dim=self.model_config.qk_rope_head_dim, + index_head_dim=self.model_config.index_head_dim, layer_num=self.num_effective_layers, device=self.device, enable_memory_saver=self.server_args.enable_memory_saver, @@ -1574,7 +1583,22 @@ def init_memory_pool( device=self.device, enable_memory_saver=self.server_args.enable_memory_saver, ) + elif self.use_mla_backend and is_nsa_model: + self.token_to_kv_pool = NSATokenToKVPool( + self.max_total_num_tokens, + page_size=self.page_size, + dtype=self.kv_cache_dtype, + kv_lora_rank=self.model_config.kv_lora_rank, + qk_rope_head_dim=self.model_config.qk_rope_head_dim, + layer_num=self.num_effective_layers, + device=self.device, + enable_memory_saver=self.server_args.enable_memory_saver, + start_layer=self.start_layer, + end_layer=self.end_layer, + index_head_dim=get_nsa_index_head_dim(self.model_config.hf_config), + ) elif self.use_mla_backend: + assert not is_nsa_model self.token_to_kv_pool = MLATokenToKVPool( self.max_total_num_tokens, page_size=self.page_size, diff --git a/python/sglang/srt/model_executor/npu_graph_runner.py b/python/sglang/srt/model_executor/npu_graph_runner.py index d7619b2d7bc..67a31c62f92 100644 --- a/python/sglang/srt/model_executor/npu_graph_runner.py +++ b/python/sglang/srt/model_executor/npu_graph_runner.py @@ -75,11 +75,16 @@ def replay( self.positions[: self.raw_num_token].copy_(forward_batch.positions) # Replay - seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (self.bs - self.raw_bs) - thread = threading.Thread(target=self._update_inputs, args=(seq_lens,)) - thread.start() - self.graphs[self.bs].replay() - thread.join() + if self.model_runner.model_config.index_head_dim is None: + seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * ( + self.bs - self.raw_bs + ) + thread = threading.Thread(target=self._update_inputs, args=(seq_lens,)) + thread.start() + self.graphs[self.bs].replay() + thread.join() + else: + self.graphs[self.bs].replay() output = self.output_buffers[self.bs] if isinstance(output, LogitsProcessorOutput): diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 336a5b68cea..8877fe60222 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -15,6 +15,7 @@ # Adapted from: # https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py """Inference-only DeepseekV2 model.""" +from __future__ import annotations import concurrent.futures import logging @@ -25,10 +26,16 @@ import torch import torch.nn.functional as F from torch import nn -from tqdm import tqdm from transformers import PretrainedConfig from sglang.srt import single_batch_overlap +from sglang.srt.configs.model_config 
import ( + get_nsa_index_head_dim, + get_nsa_index_n_heads, + get_nsa_index_topk, + is_deepseek_nsa, +) +from sglang.srt.debug_utils.dumper import dumper from sglang.srt.distributed import ( get_moe_expert_parallel_world_size, get_pp_group, @@ -48,6 +55,7 @@ NPUFusedMLAPreprocess, is_mla_preprocess_enabled, ) +from sglang.srt.layers.attention.nsa.nsa_indexer import Indexer from sglang.srt.layers.communicator import ( LayerCommunicator, LayerScatterModes, @@ -172,10 +180,13 @@ from sglang.srt.layers.quantization.awq_triton import ( awq_dequantize_triton as awq_dequantize, ) +elif _is_npu: + import custom_ops + import sgl_kernel_npu + import torch_npu else: pass - _is_flashinfer_available = is_flashinfer_available() _is_sm100_supported = is_cuda() and is_sm100_supported() @@ -184,6 +195,7 @@ FORWARD_ABSORB_CORE_ATTENTION_BACKENDS = [ "fa3", + "nsa", "flashinfer", "cutlass_mla", "trtllm_mla", @@ -204,6 +216,9 @@ class AttnForwardMethod(IntEnum): # Use absorbed multi-latent attention MLA = auto() + # Use Deepseek V3.2 sparse multi-latent attention + NPU_MLA_SPARSE = auto() + # Use multi-head attention, but with KV cache chunked. # This method can avoid OOM when prefix lengths are long. MHA_CHUNKED_KV = auto() @@ -246,9 +261,15 @@ def handle_attention_ascend(attn, forward_batch): and not forward_batch.forward_mode.is_target_verify() and not forward_batch.forward_mode.is_draft_extend() ): - return AttnForwardMethod.MHA + if hasattr(attn, "indexer"): + return AttnForwardMethod.NPU_MLA_SPARSE + else: + return AttnForwardMethod.MHA else: - return AttnForwardMethod.MLA + if hasattr(attn, "indexer"): + return AttnForwardMethod.NPU_MLA_SPARSE + else: + return AttnForwardMethod.MLA def _get_sum_extend_prefix_lens(forward_batch): @@ -267,7 +288,9 @@ def _is_extend_without_speculative(forward_batch): ) -def _handle_attention_backend(attn, forward_batch, backend_name): +def _handle_attention_backend( + attn: DeepseekV2AttentionMLA, forward_batch, backend_name +): sum_extend_prefix_lens = _get_sum_extend_prefix_lens(forward_batch) disable_ragged = ( backend_name in ["flashinfer", "flashmla"] @@ -333,6 +356,10 @@ def handle_attention_aiter(attn, forward_batch): return AttnForwardMethod.MLA +def handle_attention_nsa(attn, forward_batch): + return AttnForwardMethod.MLA + + def handle_attention_triton(attn, forward_batch): if ( _is_extend_without_speculative(forward_batch) @@ -1005,6 +1032,10 @@ def __init__( self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings + # NOTE modification to rope_scaling must be done early enough, b/c e.g. 
Indexer needs it + if rope_scaling: + rope_scaling["rope_type"] = "deepseek_yarn" + # For tensor parallel attention if self.q_lora_rank is not None: self.fused_qkv_a_proj_with_mqa = ReplicatedLinear( @@ -1042,6 +1073,26 @@ def __init__( prefix=add_prefix("kv_a_proj_with_mqa", prefix), ) + self.use_nsa = is_deepseek_nsa(config) + if self.use_nsa: + self.indexer = Indexer( + hidden_size=hidden_size, + index_n_heads=get_nsa_index_n_heads(config), + index_head_dim=get_nsa_index_head_dim(config), + rope_head_dim=qk_rope_head_dim, + index_topk=get_nsa_index_topk(config), + q_lora_rank=q_lora_rank, + max_position_embeddings=max_position_embeddings, + rope_theta=rope_theta, + scale_fmt="ue8m0", + block_size=128, + rope_scaling=rope_scaling, + prefix=add_prefix("indexer", prefix), + quant_config=quant_config, + layer_id=layer_id, + alt_stream=alt_stream, + ) + self.kv_b_proj = ColumnParallelLinear( self.kv_lora_rank, self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), @@ -1064,9 +1115,6 @@ def __init__( ) self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps) - if rope_scaling: - rope_scaling["rope_type"] = "deepseek_yarn" - self.rotary_emb = get_rope_wrapper( qk_rope_head_dim, rotary_dim=qk_rope_head_dim, @@ -1193,8 +1241,8 @@ def __init__( self.is_mla_preprocess_enabled = is_mla_preprocess_enabled() if self.is_mla_preprocess_enabled: assert ( - quant_config.get_name() == "w8a8_int8" - ), "MLA Preprocess only works with W8A8Int8" + quant_config is None or quant_config.get_name() == "w8a8_int8" + ), "MLA Preprocess only works with Unquant or W8A8Int8" self.mla_preprocess = None def dispatch_attn_forward_method( @@ -1272,7 +1320,6 @@ def forward_prepare( return hidden_states, None, forward_batch, None attn_forward_method = self.dispatch_attn_forward_method(forward_batch) - if attn_forward_method == AttnForwardMethod.MHA: inner_state = self.forward_normal_prepare( positions, hidden_states, forward_batch, zero_allocator @@ -1304,6 +1351,10 @@ def forward_prepare( inner_state = self.mla_preprocess.forward( positions, hidden_states, forward_batch, zero_allocator ) + elif attn_forward_method == AttnForwardMethod.NPU_MLA_SPARSE: + inner_state = self.forward_npu_sparse_prepare( + positions, hidden_states, forward_batch, zero_allocator + ) elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE: inner_state = self.forward_absorb_fused_mla_rope_prepare( positions, hidden_states, forward_batch, zero_allocator @@ -1329,6 +1380,8 @@ def forward_core(self, intermediate_state): return self.forward_normal_chunked_kv_core(*inner_state) elif attn_forward_method == AttnForwardMethod.MLA: return self.forward_absorb_core(*inner_state) + elif attn_forward_method == AttnForwardMethod.NPU_MLA_SPARSE: + return self.forward_npu_sparse_core(*inner_state) elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE: return self.forward_absorb_fused_mla_rope_core(*inner_state) elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE_CPU: @@ -1424,6 +1477,7 @@ def forward_absorb_prepare( ): from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode + q_lora = None if self.q_lora_rank is not None: if ( (not isinstance(hidden_states, tuple)) @@ -1462,6 +1516,10 @@ def forward_absorb_prepare( q = self.q_a_layernorm(q) k_nope = self.kv_a_layernorm(k_nope) + # q_lora needed by indexer + if self.use_nsa: + q_lora = q + k_nope = k_nope.unsqueeze(1) q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim) else: @@ -1527,14 +1585,41 @@ def forward_absorb_prepare( 
q_nope_out = q_nope_out.transpose(0, 1) if not self._fuse_rope_for_trtllm_mla(forward_batch) and ( - not _use_aiter or not _is_gfx95_supported + not _use_aiter or not _is_gfx95_supported or self.use_nsa ): q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions + topk_indices = None + if q_lora is not None: + topk_indices = self.indexer( + x=hidden_states, + q_lora=q_lora, + positions=positions, + forward_batch=forward_batch, + layer_id=self.layer_id, + ) + + return ( + q_pe, + k_pe, + q_nope_out, + k_nope, + forward_batch, + zero_allocator, + positions, + topk_indices, + ) def forward_absorb_core( - self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions + self, + q_pe, + k_pe, + q_nope_out, + k_nope, + forward_batch, + zero_allocator, + positions, + topk_indices, ): if self.current_attention_backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS: extra_args = {} @@ -1543,6 +1628,7 @@ def forward_absorb_core( "cos_sin_cache": self.rotary_emb.cos_sin_cache, "is_neox": self.rotary_emb.is_neox_style, } + attn_output = self.attn_mqa( q_nope_out, k_nope, @@ -1551,6 +1637,7 @@ def forward_absorb_core( q_rope=q_pe, k_rope=k_pe, **extra_args, + **(dict(topk_indices=topk_indices) if topk_indices is not None else {}), ) else: if _use_aiter_gfx95: @@ -1570,7 +1657,13 @@ def forward_absorb_core( q = torch.cat([q_nope_out, q_pe], dim=-1) k = torch.cat([k_nope, k_pe], dim=-1) - attn_output = self.attn_mqa(q, k, k_nope, forward_batch) + attn_output = self.attn_mqa( + q, + k, + k_nope, + forward_batch, + **(dict(topk_indices=topk_indices) if topk_indices is not None else {}), + ) attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank) if self.use_deep_gemm_bmm: @@ -1652,6 +1745,221 @@ def forward_absorb_core( return output + def forward_npu_sparse_prepare( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + zero_allocator: BumpAllocator, + ): + """ + Reuse `self.q_lora_rank is not None` branch from forward_absorb_prepare + """ + if self.is_mla_preprocess_enabled and forward_batch.forward_mode.is_decode(): + if self.mla_preprocess is None: + self.mla_preprocess = NPUFusedMLAPreprocess( + self.fused_qkv_a_proj_with_mqa, + self.q_a_layernorm, + self.kv_a_layernorm, + self.q_b_proj, + self.w_kc, + self.rotary_emb, + self.layer_id, + self.num_local_heads, + self.qk_nope_head_dim, + self.qk_rope_head_dim, + ) + ( + q_pe, + k_pe, + q_nope_out, + k_nope, + forward_batch, + zero_allocator, + positions, + ) = self.mla_preprocess.forward( + positions, hidden_states, forward_batch, zero_allocator + ) + + fused_qkv_a_proj_out = self.fused_qkv_a_proj_with_mqa(hidden_states)[0] + q, _ = fused_qkv_a_proj_out.split( + [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1 + ) + q_lora = self.q_a_layernorm(q) + else: + from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode + + if ( + (not isinstance(hidden_states, tuple)) + and hidden_states.shape[0] <= 16 + and self.use_min_latency_fused_a_gemm + ): + fused_qkv_a_proj_out = dsv3_fused_a_gemm( + hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T + ) + else: + fused_qkv_a_proj_out = self.fused_qkv_a_proj_with_mqa(hidden_states)[0] + q, latent_cache = fused_qkv_a_proj_out.split( + [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1 + ) + k_nope = latent_cache[..., : self.kv_lora_rank] + + # overlap qk norm + if self.alt_stream is not None and 
get_is_capture_mode(): + current_stream = torch.cuda.current_stream() + self.alt_stream.wait_stream(current_stream) + q = self.q_a_layernorm(q) + with torch.cuda.stream(self.alt_stream): + k_nope = self.kv_a_layernorm(k_nope) + current_stream.wait_stream(self.alt_stream) + else: + if _use_aiter_gfx95 and self.q_b_proj.weight.dtype == torch.uint8: + q, k_nope = fused_rms_mxfp4_quant( + q, + self.q_a_layernorm.weight, + self.q_a_layernorm.variance_epsilon, + k_nope, + self.kv_a_layernorm.weight, + self.kv_a_layernorm.variance_epsilon, + ) + else: + q = self.q_a_layernorm(q) + k_nope = self.kv_a_layernorm(k_nope) + + q_lora = q.clone() # required for topk_indices + k_nope = k_nope.unsqueeze(1) + q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim) + + q_nope, q_pe = q.split( + [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1) + + if self.use_deep_gemm_bmm: + q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = ( + per_token_group_quant_mla_deep_gemm_masked_fp8( + q_nope.transpose(0, 1) + ) + ) + q_nope_out = q_nope.new_empty( + (self.num_local_heads, aligned_m, self.kv_lora_rank) + ) + deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked( + (q_nope_val, q_nope_scale), + (self.w_kc, self.w_scale_k), + q_nope_out, + masked_m, + expected_m, + ) + q_nope_out = q_nope_out[:, :expected_m, :] + elif _is_hip: + # TODO(haishaw): add bmm_fp8 to ROCm + if _use_aiter_gfx95 and self.w_kc.dtype == torch.uint8: + x = q_nope.transpose(0, 1) + q_nope_out = torch.empty( + x.shape[0], + x.shape[1], + self.w_kc.shape[2], + device=x.device, + dtype=torch.bfloat16, + ) + batched_gemm_afp4wfp4_pre_quant( + x, + self.w_kc.transpose(-2, -1), + self.w_scale_k.transpose(-2, -1), + torch.bfloat16, + q_nope_out, + ) + else: + q_nope_out = torch.bmm( + q_nope.to(torch.bfloat16).transpose(0, 1), + self.w_kc.to(torch.bfloat16) * self.w_scale, + ) + elif self.w_kc.dtype == torch.float8_e4m3fn: + q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8( + q_nope.transpose(0, 1), + zero_allocator.allocate(1), + ) + q_nope_out = bmm_fp8( + q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16 + ) + else: + q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc) + + q_nope_out = q_nope_out.transpose(0, 1) + + if not self._fuse_rope_for_trtllm_mla(forward_batch) and ( + not _use_aiter or not _is_gfx95_supported + ): + q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) + + # TODO: multi-stream indexer + topk_indices = self.indexer( + hidden_states, q_lora, positions, forward_batch, self.layer_id + ) + + return ( + q_pe, + k_pe, + q_nope_out, + k_nope, + topk_indices, + forward_batch, + zero_allocator, + positions, + ) + + def forward_npu_sparse_core( + self, + q_pe, + k_pe, + q_nope_out, + k_nope, + topk_indices, + forward_batch, + zero_allocator, + positions, + ): + attn_output = self.attn_mqa( + q_nope_out.contiguous(), + k_nope.contiguous(), + k_nope.contiguous(), + forward_batch, + save_kv_cache=True, # False if forward_batch.forward_mode.is_extend() else True, + q_rope=q_pe.contiguous(), + k_rope=k_pe.contiguous(), + topk_indices=topk_indices, + ) + attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank) + + attn_bmm_output = torch.empty( + (attn_output.shape[0], self.num_local_heads, self.v_head_dim), + dtype=attn_output.dtype, + device=attn_output.device, + ) + + if not forward_batch.forward_mode.is_decode(): + attn_output = attn_output.transpose(0, 1) + torch.bmm( + attn_output, + self.w_vc, + 
out=attn_bmm_output.view( + -1, self.num_local_heads, self.v_head_dim + ).transpose(0, 1), + ) + else: + attn_output = attn_output.contiguous() + torch.ops.npu.batch_matmul_transpose( + attn_output, self.w_vc, attn_bmm_output + ) + + attn_bmm_output = attn_bmm_output.reshape( + -1, self.num_local_heads * self.v_head_dim + ) + + output, _ = self.o_proj(attn_bmm_output) + return output + def forward_absorb_fused_mla_rope_prepare( self, positions: torch.Tensor, @@ -2134,7 +2442,6 @@ def forward( zero_allocator: BumpAllocator, gemm_output_zero_allocator: BumpAllocator = None, ) -> torch.Tensor: - quant_format = ( "mxfp4" if _is_gfx95_supported @@ -3099,6 +3406,7 @@ def get_model_config_for_expert_location(cls, config): AttentionBackendRegistry.register("fa4", handle_attention_fa4) AttentionBackendRegistry.register("trtllm_mla", handle_attention_trtllm_mla) AttentionBackendRegistry.register("aiter", handle_attention_aiter) +AttentionBackendRegistry.register("nsa", handle_attention_nsa) AttentionBackendRegistry.register("triton", handle_attention_triton) @@ -3106,4 +3414,8 @@ class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): pass -EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM] +class DeepseekV32ForCausalLM(DeepseekV2ForCausalLM): + pass + + +EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM, DeepseekV32ForCausalLM] diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index dfa4d8e8fb6..b27afcdaa4c 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -91,6 +91,7 @@ "triton", "torch_native", "flex_attention", + "nsa", # NVIDIA specific "cutlass_mla", "fa3", @@ -116,6 +117,8 @@ DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer", "fa3", "triton"] +NSA_CHOICES = ["flashmla_prefill", "flashmla_decode", "fa3", "tilelang", "aiter"] + RADIX_EVICTION_POLICY_CHOICES = ["lru", "lfu"] @@ -284,6 +287,8 @@ class ServerArgs: sampling_backend: Optional[str] = None grammar_backend: Optional[str] = None mm_attention_backend: Optional[str] = None + nsa_prefill: str = "flashmla_prefill" + nsa_decode: str = "fa3" # Speculative decoding speculative_algorithm: Optional[str] = None @@ -719,6 +724,8 @@ def _handle_cpu_backends(self): self.sampling_backend = "pytorch" def _handle_model_specific_adjustments(self): + from sglang.srt.configs.model_config import is_deepseek_nsa + if parse_connector_type(self.model_path) == ConnectorType.INSTANCE: return @@ -796,6 +803,48 @@ def _handle_model_specific_adjustments(self): ) self.disable_hybrid_swa_memory = True + if is_deepseek_nsa(hf_config): + if ( + self.attention_backend is None + and self.prefill_attention_backend is None + and self.decode_attention_backend is None + ): + self.attention_backend = "nsa" + logger.warning("Set nsa attention backend for DeepSeek NSA.") + + if not is_npu(): + self.enable_dp_attention = True + self.dp_size = self.tp_size + logger.warning("DP attention is enabled for DeepSeek NSA.") + + self.page_size = 64 + logger.warning("Setting page size to 64 for DeepSeek NSA.") + + self.mem_fraction_static = 0.8 + logger.warning("Setting mem fraction static to 0.8 for DeepSeek NSA.") + + # For Hopper, we support both bf16 and fp8 kv cache; for Blackwell, we support fp8 only currently + import torch + + major, _ = torch.cuda.get_device_capability() + if major >= 10: + self.kv_cache_dtype = "fp8_e4m3" + logger.warning("Setting KV cache dtype to fp8.") + + if self.kv_cache_dtype == "fp8_e4m3": + self.nsa_prefill = "flashmla_decode" + self.nsa_decode = "flashmla_decode" + 
logger.warning( + "Setting NSA backend to flashmla_decode for FP8 KV Cache." + ) + + # Logging env vars for NSA + from sglang.srt.layers.attention.nsa.utils import ( + print_nsa_bool_env_vars, + ) + + print_nsa_bool_env_vars() + def _handle_sampling_backend(self): if self.sampling_backend is None: self.sampling_backend = ( @@ -1023,6 +1072,7 @@ def _handle_speculative_decoding(self): model_arch = self.get_hf_config().architectures[0] if model_arch in [ + "DeepseekV32ForCausalLM", "DeepseekV3ForCausalLM", "Glm4MoeForCausalLM", "BailingMoeForCausalLM", @@ -1974,6 +2024,18 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.mm_attention_backend, help="Set multimodal attention backend.", ) + parser.add_argument( + "--nsa-prefill", + default=ServerArgs.nsa_prefill, + type=str, + choices=NSA_CHOICES, + ) + parser.add_argument( + "--nsa-decode", + default=ServerArgs.nsa_decode, + type=str, + choices=NSA_CHOICES, + ) # Speculative decoding parser.add_argument( @@ -3251,6 +3313,7 @@ def auto_choose_speculative_params(self: ServerArgs): # The default value for llama return (5, 4, 8) elif arch in [ + "DeepseekV32ForCausalLM", "DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM", "GptOssForCausalLM", diff --git a/python/sglang/srt/two_batch_overlap.py b/python/sglang/srt/two_batch_overlap.py index d67636aa433..49205d4c99d 100644 --- a/python/sglang/srt/two_batch_overlap.py +++ b/python/sglang/srt/two_batch_overlap.py @@ -705,6 +705,8 @@ def filter_batch( extend_num_tokens=extend_num_tokens, attn_backend=output_attn_backend, num_token_non_padded=out_num_token_non_padded, + # TODO: handle it when we need TBO + DeepSeek V3.2 + num_token_non_padded_cpu=None, tbo_split_seq_index=None, tbo_parent_token_range=(start_token_index, end_token_index), tbo_children=None, diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index 52f7a0fff32..0ab2783c359 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -471,7 +471,7 @@ def is_pin_memory_available() -> bool: class LayerFn(Protocol): - def __call__(self, layer_id: int, prefix: str) -> torch.nn.Module: ... + def __call__(self, idx: int, prefix: str) -> torch.nn.Module: ... 
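With the NSA defaults added in server_args above, launching a DeepSeek NSA checkpoint should only require the backend flags introduced in this diff, for example (the entry point and the --model-path/--attention-backend flags are the existing ones; the checkpoint path is a placeholder):

    python -m sglang.launch_server --model-path <deepseek-v3.2-checkpoint> \
        --attention-backend nsa --nsa-prefill flashmla_prefill --nsa-decode fa3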
def make_layers( @@ -482,7 +482,7 @@ def make_layers( prefix: str = "", return_tuple: bool = False, offloader_kwargs: Dict[str, Any] = {}, -) -> Tuple[int, int, torch.nn.ModuleList]: +) -> Tuple[torch.nn.Module, int, int]: """Make a list of layers with the given layer function""" # circula imports from sglang.srt.distributed import get_pp_indices diff --git a/python/sglang/srt/utils/hf_transformers_utils.py b/python/sglang/srt/utils/hf_transformers_utils.py index 68b0c4534ac..75e0a8b7559 100644 --- a/python/sglang/srt/utils/hf_transformers_utils.py +++ b/python/sglang/srt/utils/hf_transformers_utils.py @@ -123,6 +123,38 @@ def get_hf_text_config(config: PretrainedConfig): return config +# Temporary hack for DeepSeek-V3.2 model +def _load_deepseek_v32_model( + model_path: str, + trust_remote_code: bool = False, + revision: Optional[str] = None, + **kwargs, +): + # first get the local path + local_path = download_from_hf(model_path) + # then load the config file in json + config_file = os.path.join(local_path, "config.json") + if not os.path.exists(config_file): + raise RuntimeError(f"Can't find config file in {local_path}.") + + with open(config_file, "r") as f: + config_json = json.load(f) + + config_json["architectures"] = ["DeepseekV3ForCausalLM"] + config_json["model_type"] = "deepseek_v3" + + tmp_path = os.path.join(local_path, "_tmp_config_folder") + os.makedirs(tmp_path, exist_ok=True) + + unique_path = os.path.join(tmp_path, f"deepseek_v32_{os.getpid()}") + with open(unique_path, "w") as f: + json.dump(config_json, f) + + return AutoConfig.from_pretrained( + unique_path, trust_remote_code=trust_remote_code, revision=revision, **kwargs + ) + + @lru_cache_frozenset(maxsize=32) def get_config( model: str, @@ -144,9 +176,17 @@ def get_config( client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"]) model = client.get_local_dir() - config = AutoConfig.from_pretrained( - model, trust_remote_code=trust_remote_code, revision=revision, **kwargs - ) + try: + config = AutoConfig.from_pretrained( + model, trust_remote_code=trust_remote_code, revision=revision, **kwargs + ) + except ValueError as e: + if not "deepseek_v32" in str(e): + raise e + config = _load_deepseek_v32_model( + model, trust_remote_code=trust_remote_code, revision=revision, **kwargs + ) + if ( config.architectures is not None and config.architectures[0] == "Phi4MMForCausalLM" diff --git a/python/sglang/test/get_logits_ut.py b/python/sglang/test/get_logits_ut.py new file mode 100644 index 00000000000..17edf8a4f2a --- /dev/null +++ b/python/sglang/test/get_logits_ut.py @@ -0,0 +1,57 @@ +import torch +import torch.nn as nn + + +class DummyModel(nn.Module): + def __init__(self, d_in=2048, n_heads=128, softmax_scale=0.5): + super().__init__() + self.weights_proj = nn.Linear(d_in, 1024) + self.n_heads = n_heads + self.softmax_scale = softmax_scale + + def _get_logits_head_gate_orig(self, x: torch.Tensor, q_scale: torch.Tensor): + weights = self.weights_proj(x) + weights = weights * self.n_heads**-0.5 + q_scale = q_scale.unsqueeze(1) # (B,1,1) + weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale + return weights + + def _get_logits_head_gate_opt(self, x: torch.Tensor, q_scale: torch.Tensor): + weights = self.weights_proj(x) + q_scale = q_scale.unsqueeze(1) # (B,1,1) + scale_const = self.n_heads**-0.5 * q_scale * self.softmax_scale # (B,1,1) + weights = weights.unsqueeze(-1) * scale_const # (B,1024,1) + return weights + + +def main(): + torch.manual_seed(0) + model = DummyModel(d_in=2048, n_heads=128, 
softmax_scale=0.5) + x = torch.randn(128, 2048) # batch=128, d_in=2048 + q_scale = torch.randn(128, 1) + + import time + + start = time.time() + for _ in range(1000): + out_orig = model._get_logits_head_gate_orig(x, q_scale) + print("Original version time:", time.time() - start) + + start = time.time() + for _ in range(1000): + out_opt = model._get_logits_head_gate_opt(x, q_scale) + print("Optimized version time:", time.time() - start) + + print("Difference:", (out_orig - out_opt).abs().max().item()) + assert torch.allclose(out_orig, out_opt), "Mismatch between original and optimized" + + +if __name__ == "__main__": + main() + + +""" +Original version time: 0.49235057830810547 +Optimized version time: 0.4087331295013428 +Difference: 1.4901161193847656e-08 +"""
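(The two gate implementations benchmarked above compute the same thing; the optimized one only folds the scalar factors n_heads**-0.5, softmax_scale and q_scale into a single broadcast multiply, so the reported ~1e-8 difference is pure floating-point rounding order.)

Finally, a conceptual note on the sparse path that ties the pieces together: the topk_indices produced by the Indexer and threaded into attn_mqa select, per query token, which cached latent rows the decode kernels (flashmla_decode / fa3 / tilelang / aiter) actually attend to. A dense PyTorch reference of that contract, ignoring paging, -1 padding and FP8 storage (shapes are illustrative, not the kernel implementation):

```python
import torch

def sparse_topk_mla_attention(q, kv_cache, topk_indices, sm_scale, d_v):
    # q:            (T, H, D) absorbed queries (nope @ w_kc, concat rope)
    # kv_cache:     (N, D)    latent rows (kv_lora + rope), page_size=1 view
    # topk_indices: (T, k)    indices chosen by the indexer (no -1 padding here)
    k_sel = kv_cache[topk_indices]                        # (T, k, D) gather sparse set
    logits = torch.einsum("thd,tkd->thk", q, k_sel) * sm_scale
    probs = torch.softmax(logits, dim=-1)
    return torch.einsum("thk,tkd->thd", probs, k_sel[..., :d_v])  # values = first d_v dims

T, H, D, d_v, k, N = 2, 16, 576, 512, 8, 128
out = sparse_topk_mla_attention(
    torch.randn(T, H, D), torch.randn(N, D), torch.randint(0, N, (T, k)),
    sm_scale=D ** -0.5, d_v=d_v,
)
print(out.shape)  # torch.Size([2, 16, 512])
```

The production kernels additionally mask padded indices (page_table_1 == -1) and read the paged, optionally FP8-quantized cache managed by NSATokenToKVPool.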