diff --git a/README.md b/README.md
index eb468f33b..9d1a9fbd8 100644
--- a/README.md
+++ b/README.md
@@ -2,24 +2,21 @@
-
-
-
-# Nano-vLLM
-
-A lightweight vLLM implementation built from scratch.
+# Nano-vLLM-kv-compression
+An improved implementation based on Nano-vLLM, featuring int8 KV cache compression, a head-major memory layout for coalesced access, and asynchronous stream pipelining that hides KV store latency behind attention computation.
## Key Features
-* 🚀 **Fast offline inference** - Comparable inference speeds to vLLM
-* 📖 **Readable codebase** - Clean implementation in ~ 1,200 lines of Python code
-* ⚡ **Optimization Suite** - Prefix caching, Tensor Parallelism, Torch compilation, CUDA graph, etc.
+* ⚡ **Int8 KV Cache Compression** — ~50% KV cache memory reduction via dynamic per-token, per-head quantization
+* 🔄 **Coalesced Layout** — Head-major reordering for warp-level memory coalescing
+* 🎯 **GQA-Optimized Flash Attention** — Grouped Q-head CTA mapping eliminates redundant KV loads
+* 🔗 **Async KV Store Pipeline** — Multi-stream architecture overlaps KV quantization and cache writeback with attention computation
## Installation
```bash
-pip install git+https://github.com/GeeeekExplorer/nano-vllm.git
+pip install git+https://github.com/naalo2/nano-vLLM-kv-compression.git
```
## Model Download
@@ -48,19 +45,16 @@ outputs[0]["text"]
See `bench.py` for benchmark.
**Test Configuration:**
-- Hardware: RTX 4070 Laptop (8GB)
+- Hardware: RTX 3090 (24GB)
- Model: Qwen3-0.6B
- Total Requests: 256 sequences
- Input Length: Randomly sampled between 100–1024 tokens
- Output Length: Randomly sampled between 100–1024 tokens
**Performance Results:**
-| Inference Engine | Output Tokens | Time (s) | Throughput (tokens/s) |
-|----------------|-------------|----------|-----------------------|
-| vLLM | 133,966 | 98.37 | 1361.84 |
-| Nano-vLLM | 133,966 | 93.41 | 1434.13 |
-
+| Inference Engine | Output Tokens | Time (s) | Throughput (tokens/s) |
+|---------------------------|---------------|----------|-----------------------|
+| Nano-vLLM | 133,966 | 33.05 | 4052.56 |
+| Nano-vLLM-kv-compression | 133,966 | 27.00 | 4962.21 |
-## Star History
-[](https://www.star-history.com/#GeeeekExplorer/nano-vllm&Date)
\ No newline at end of file
diff --git a/assets/logo.png b/assets/logo.png
index ac0b8fd6e..b2c41996d 100644
Binary files a/assets/logo.png and b/assets/logo.png differ
diff --git a/bench.py b/bench.py
index 8e61d6545..3468a7fdb 100644
--- a/bench.py
+++ b/bench.py
@@ -2,6 +2,8 @@
import time
from random import randint, seed
from nanovllm import LLM, SamplingParams
+from nanovllm.config import Config
+
# from vllm import LLM, SamplingParams
@@ -10,18 +12,22 @@ def main():
num_seqs = 256
max_input_len = 1024
max_ouput_len = 1024
-
- path = os.path.expanduser("~/huggingface/Qwen3-0.6B/")
- llm = LLM(path, enforce_eager=False, max_model_len=4096)
+ enforce_eager = False
+ print(f"use cuda graph:{not enforce_eager}")
+ path = os.path.expanduser("/YOUR/MODEL/PATH")
+ llm = LLM(path, enforce_eager=enforce_eager, max_model_len=max(4096, max_ouput_len+max_input_len))
prompt_token_ids = [[randint(0, 10000) for _ in range(randint(100, max_input_len))] for _ in range(num_seqs)]
sampling_params = [SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=randint(100, max_ouput_len)) for _ in range(num_seqs)]
+ # prompt_token_ids = [[randint(0, 10000) for _ in range(max_input_len)] for _ in range(num_seqs)]
+ # sampling_params = [SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=max_ouput_len) for _ in range(num_seqs)]
+
# uncomment the following line for vllm
# prompt_token_ids = [dict(prompt_token_ids=p) for p in prompt_token_ids]
llm.generate(["Benchmark: "], SamplingParams())
t = time.time()
- llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
+ llm.generate(prompt_token_ids, sampling_params, use_tqdm=True)
t = (time.time() - t)
total_tokens = sum(sp.max_tokens for sp in sampling_params)
throughput = total_tokens / t
diff --git a/nanovllm/config.py b/nanovllm/config.py
index 7066cbeb5..3e40e37f6 100644
--- a/nanovllm/config.py
+++ b/nanovllm/config.py
@@ -1,4 +1,5 @@
import os
+import torch
from dataclasses import dataclass
from transformers import AutoConfig
@@ -16,6 +17,10 @@ class Config:
eos: int = -1
kvcache_block_size: int = 256
num_kvcache_blocks: int = -1
+ kv_quant: bool = True
+ kvcache_quant_dtype = torch.int8
+ kvscale_dtype = torch.bfloat16
+ group_num: int = 2
def __post_init__(self):
assert os.path.isdir(self.model)
@@ -23,3 +28,4 @@ def __post_init__(self):
assert 1 <= self.tensor_parallel_size <= 8
self.hf_config = AutoConfig.from_pretrained(self.model)
self.max_model_len = min(self.max_model_len, self.hf_config.max_position_embeddings)
+ assert self.max_num_batched_tokens >= self.max_model_len
diff --git a/nanovllm/engine/model_runner.py b/nanovllm/engine/model_runner.py
index 71d9883c9..d4eba1204 100644
--- a/nanovllm/engine/model_runner.py
+++ b/nanovllm/engine/model_runner.py
@@ -17,6 +17,7 @@ class ModelRunner:
def __init__(self, config: Config, rank: int, event: Event | list[Event]):
self.config = config
hf_config = config.hf_config
+ Sequence.set_block_size(config.kvcache_block_size)
self.block_size = config.kvcache_block_size
self.enforce_eager = config.enforce_eager
self.world_size = config.tensor_parallel_size
@@ -26,12 +27,18 @@ def __init__(self, config: Config, rank: int, event: Event | list[Event]):
dist.init_process_group("nccl", "tcp://localhost:2333", world_size=self.world_size, rank=rank)
torch.cuda.set_device(rank)
default_dtype = torch.get_default_dtype()
- torch.set_default_dtype(hf_config.dtype)
+ torch.set_default_dtype(hf_config.torch_dtype)
torch.set_default_device("cuda")
self.model = Qwen3ForCausalLM(hf_config)
load_model(self.model, config.model)
self.sampler = Sampler()
+ if self.config.kv_quant:
+ self.kv_store_stream = torch.cuda.Stream()
+ self.kv_store_event = torch.cuda.Event()
+ self.has_pedding = False
self.warmup_model()
+ if self.config.kv_quant:
+ self.has_pedding = False
self.allocate_kv_cache()
if not self.enforce_eager:
self.capture_cudagraph()
@@ -53,6 +60,7 @@ def exit(self):
dist.barrier()
if self.rank == 0:
self.shm.unlink()
+
if not self.enforce_eager:
del self.graphs, self.graph_pool
torch.cuda.synchronize()
@@ -92,11 +100,8 @@ def warmup_model(self):
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
max_num_batched_tokens, max_model_len = self.config.max_num_batched_tokens, self.config.max_model_len
- seq_len = min(max_num_batched_tokens, max_model_len)
- num_seqs = min(max_num_batched_tokens // seq_len, self.config.max_num_seqs)
- seqs = [Sequence([0] * seq_len) for _ in range(num_seqs)]
- for seq in seqs:
- seq.num_scheduled_tokens = seq_len
+ num_seqs = min(max_num_batched_tokens // max_model_len, self.config.max_num_seqs)
+ seqs = [Sequence([0] * max_model_len) for _ in range(num_seqs)]
self.run(seqs, True)
torch.cuda.empty_cache()
@@ -109,16 +114,49 @@ def allocate_kv_cache(self):
current = torch.cuda.memory_stats()["allocated_bytes.all.current"]
num_kv_heads = hf_config.num_key_value_heads // self.world_size
head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
- block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.dtype.itemsize
- config.num_kvcache_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes
- assert config.num_kvcache_blocks > 0
- self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, head_dim)
- layer_id = 0
- for module in self.model.modules():
- if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
- module.k_cache = self.kv_cache[0, layer_id]
- module.v_cache = self.kv_cache[1, layer_id]
- layer_id += 1
+ available_bytes = int(total * config.gpu_memory_utilization - used - peak + current)
+ if config.kv_quant:
+ kv_bytes = config.kvcache_quant_dtype.itemsize
+ scale_bytes = config.kvscale_dtype.itemsize
+ block_bytes = (
+ 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads *
+ (head_dim * kv_bytes + scale_bytes)
+ )
+ config.num_kvcache_blocks = available_bytes // block_bytes
+ assert config.num_kvcache_blocks > 0
+ self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, num_kv_heads, config.num_kvcache_blocks,
+ self.block_size, head_dim, dtype=config.kvcache_quant_dtype)
+ self.kv_scale_cache = torch.empty(2,hf_config.num_hidden_layers, num_kv_heads, config.num_kvcache_blocks,
+ self.block_size, dtype=config.kvscale_dtype,)
+ num_slots = config.num_kvcache_blocks * self.block_size
+ layer_id = 0
+ for module in self.model.modules():
+ if (hasattr(module, "k_cache")
+ and hasattr(module, "v_cache")):
+ module.k_cache = self.kv_cache[0, layer_id].reshape(num_kv_heads, num_slots, head_dim)
+ module.v_cache = self.kv_cache[1, layer_id].reshape(num_kv_heads, num_slots, head_dim)
+ module.k_scale_cache = self.kv_scale_cache[0, layer_id].reshape(num_kv_heads, num_slots)
+ module.v_scale_cache = self.kv_scale_cache[1, layer_id].reshape(num_kv_heads, num_slots)
+ module.block_size = self.block_size
+ module.kv_quant = True
+ module.GROUP_NUM = config.group_num
+ module.kv_store_stream = self.kv_store_stream
+ module.kv_store_event = self.kv_store_event
+ layer_id += 1
+ else:
+ block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.torch_dtype.itemsize
+ config.num_kvcache_blocks = available_bytes // block_bytes
+ assert config.num_kvcache_blocks > 0
+ self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, config.num_kvcache_blocks,
+ self.block_size, num_kv_heads, head_dim, )
+ layer_id = 0
+ for module in self.model.modules():
+ if (hasattr(module, "k_cache")
+ and hasattr(module, "v_cache")):
+ module.k_cache = self.kv_cache[0, layer_id]
+ module.v_cache = self.kv_cache[1, layer_id]
+ module.kv_quant = False
+ layer_id += 1
def prepare_block_tables(self, seqs: list[Sequence]):
max_len = max(len(seq.block_table) for seq in seqs)
@@ -131,42 +169,41 @@ def prepare_prefill(self, seqs: list[Sequence]):
positions = []
cu_seqlens_q = [0]
cu_seqlens_k = [0]
+ context_lens = []
max_seqlen_q = 0
max_seqlen_k = 0
slot_mapping = []
block_tables = None
for seq in seqs:
- start = seq.num_cached_tokens
- seqlen_q = seq.num_scheduled_tokens
- end = start + seqlen_q
- seqlen_k = end
- input_ids.extend(seq[start:end])
- positions.extend(range(start, end))
+ seqlen = len(seq)
+ context_lens.append(seqlen)
+ input_ids.extend(seq[seq.num_cached_tokens:])
+ positions.extend(list(range(seq.num_cached_tokens, seqlen)))
+ seqlen_q = seqlen - seq.num_cached_tokens
+ seqlen_k = seqlen
cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)
cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
max_seqlen_q = max(seqlen_q, max_seqlen_q)
max_seqlen_k = max(seqlen_k, max_seqlen_k)
- if not seq.block_table: # warmup
+ if not seq.block_table:
continue
- start_block = start // self.block_size
- end_block = (end + self.block_size - 1) // self.block_size
- for i in range(start_block, end_block):
- slot_start = seq.block_table[i] * self.block_size
- if i == start_block:
- slot_start += start % self.block_size
- if i != end_block - 1:
- slot_end = seq.block_table[i] * self.block_size + self.block_size
+ for i in range(seq.num_cached_blocks, seq.num_blocks):
+ start = seq.block_table[i] * self.block_size
+ if i != seq.num_blocks - 1:
+ end = start + self.block_size
else:
- slot_end = seq.block_table[i] * self.block_size + end - i * self.block_size
- slot_mapping.extend(range(slot_start, slot_end))
- if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # prefix cache
+ end = start + seq.last_block_num_tokens
+ slot_mapping.extend(list(range(start, end)))
+ if cu_seqlens_k[-1] > cu_seqlens_q[-1]:
block_tables = self.prepare_block_tables(seqs)
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
+ if self.config.kv_quant:
+ context_lens = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
- set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables)
+ set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables)
return input_ids, positions
def prepare_decode(self, seqs: list[Sequence]):
@@ -179,21 +216,28 @@ def prepare_decode(self, seqs: list[Sequence]):
positions.append(len(seq) - 1)
context_lens.append(len(seq))
slot_mapping.append(seq.block_table[-1] * self.block_size + seq.last_block_num_tokens - 1)
+ max_context_len = max(context_lens) if context_lens else 0
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
context_lens = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
block_tables = self.prepare_block_tables(seqs)
- set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)
+ set_context(False, slot_mapping=slot_mapping, context_lens=context_lens,
+ block_tables=block_tables, max_context_len=max_context_len)
return input_ids, positions
def prepare_sample(self, seqs: list[Sequence]):
- temperatures = [seq.temperature for seq in seqs]
+ temperatures = []
+ for seq in seqs:
+ temperatures.append(seq.temperature)
temperatures = torch.tensor(temperatures, dtype=torch.float32, pin_memory=True).cuda(non_blocking=True)
return temperatures
@torch.inference_mode()
def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool):
+ if self.config.kv_quant and getattr(self, "has_pedding", True):
+ torch.cuda.current_stream().wait_event(self.kv_store_event)
+ self.has_pedding = False
if is_prefill or self.enforce_eager or input_ids.size(0) > 512:
return self.model.compute_logits(self.model(input_ids, positions))
else:
@@ -215,6 +259,8 @@ def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
logits = self.run_model(input_ids, positions, is_prefill)
+ if self.config.kv_quant:
+ self.has_pedding = True
token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
reset_context()
return token_ids
@@ -235,12 +281,16 @@ def capture_cudagraph(self):
self.graphs = {}
self.graph_pool = None
+ if self.config.kv_quant:
+ self.capture_quant(slot_mapping, context_lens, block_tables, outputs, input_ids, positions)
+ return
+
for bs in reversed(self.graph_bs):
graph = torch.cuda.CUDAGraph()
set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs])
- outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # warmup
+ outputs[:bs] = self.model(input_ids[:bs], positions[:bs])
with torch.cuda.graph(graph, self.graph_pool):
- outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # capture
+ outputs[:bs] = self.model(input_ids[:bs], positions[:bs])
if self.graph_pool is None:
self.graph_pool = graph.pool()
self.graphs[bs] = graph
@@ -255,3 +305,34 @@ def capture_cudagraph(self):
block_tables=block_tables,
outputs=outputs,
)
+
+ def capture_quant(self, slot_mapping, context_lens, block_tables, outputs, input_ids, positions):
+ main_stream = torch.cuda.current_stream()
+ for bs in reversed(self.graph_bs):
+ graph = torch.cuda.CUDAGraph()
+ self.kv_store_stream.synchronize()
+ set_context(
+ False,
+ slot_mapping=slot_mapping[:bs],
+ context_lens=context_lens[:bs],
+ block_tables=block_tables[:bs],
+ )
+ outputs[:bs] = self.model(input_ids[:bs], positions[:bs])
+ main_stream.wait_event(self.kv_store_event)
+ torch.cuda.synchronize()
+ with torch.cuda.graph(graph, self.graph_pool):
+ outputs[:bs] = self.model(input_ids[:bs], positions[:bs])
+ torch.cuda.current_stream().wait_event(self.kv_store_event)
+ if self.graph_pool is None:
+ self.graph_pool = graph.pool()
+ self.graphs[bs] = graph
+ torch.cuda.synchronize()
+ reset_context()
+ self.graph_vars = dict(
+ input_ids=input_ids,
+ positions=positions,
+ slot_mapping=slot_mapping,
+ context_lens=context_lens,
+ block_tables=block_tables,
+ outputs=outputs,
+ )
\ No newline at end of file
diff --git a/nanovllm/engine/sequence.py b/nanovllm/engine/sequence.py
index 4decfce5f..316888828 100644
--- a/nanovllm/engine/sequence.py
+++ b/nanovllm/engine/sequence.py
@@ -12,9 +12,13 @@ class SequenceStatus(Enum):
class Sequence:
- block_size = 256
+ block_size: int = 0 # invalid value, will be set by set_block_size
counter = count()
+ @classmethod
+ def set_block_size(cls, block_size: int):
+ cls.block_size = block_size
+
def __init__(self, token_ids: list[int], sampling_params = SamplingParams()):
self.seq_id = next(Sequence.counter)
self.status = SequenceStatus.WAITING
@@ -23,12 +27,11 @@ def __init__(self, token_ids: list[int], sampling_params = SamplingParams()):
self.num_tokens = len(self.token_ids)
self.num_prompt_tokens = len(token_ids)
self.num_cached_tokens = 0
- self.num_scheduled_tokens = 0
- self.is_prefill = True
self.block_table = []
self.temperature = sampling_params.temperature
self.max_tokens = sampling_params.max_tokens
self.ignore_eos = sampling_params.ignore_eos
+ self.prefilled = False
def __len__(self):
return self.num_tokens
@@ -52,6 +55,10 @@ def prompt_token_ids(self):
def completion_token_ids(self):
return self.token_ids[self.num_prompt_tokens:]
+ @property
+ def num_cached_blocks(self):
+ return self.num_cached_tokens // self.block_size
+
@property
def num_blocks(self):
return (self.num_tokens + self.block_size - 1) // self.block_size
@@ -68,16 +75,15 @@ def append_token(self, token_id: int):
self.token_ids.append(token_id)
self.last_token = token_id
self.num_tokens += 1
+ self.prefilled = True
def __getstate__(self):
- last_state = self.last_token if not self.is_prefill else self.token_ids
- return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.num_scheduled_tokens, self.block_table, last_state)
+ return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table, self.prefilled,
+ self.last_token if self.prefilled else self.token_ids)
def __setstate__(self, state):
- self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.num_scheduled_tokens, self.block_table, last_state = state
- if isinstance(last_state, list):
- self.token_ids = last_state
- self.last_token = self.token_ids[-1]
+ self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table, self.prefilled = state[:-1]
+ if self.prefilled:
+ self.last_token = state[-1]
else:
- self.token_ids = []
- self.last_token = last_state
+ self.token_ids = state[-1]
diff --git a/nanovllm/layers/attention.py b/nanovllm/layers/attention.py
index e416139ea..cb07826bf 100644
--- a/nanovllm/layers/attention.py
+++ b/nanovllm/layers/attention.py
@@ -2,8 +2,9 @@
from torch import nn
import triton
import triton.language as tl
-
from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+from nanovllm.tools.quant_attn_kvhead_based import decode_attn_quantkv_direct, \
+ prefill_attn_quantkv_direct
from nanovllm.utils.context import get_context
@@ -40,6 +41,101 @@ def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor,
store_kvcache_kernel[(N,)](key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D)
+@triton.jit
+def _store_quantkv_kernel(
+ key_ptr,
+ key_stride_b,
+ key_stride_h,
+ key_stride_d,
+ value_ptr,
+ value_stride_b,
+ value_stride_h,
+ value_stride_d,
+ k_cache_ptr,
+ v_cache_ptr,
+ k_scale_ptr,
+ v_scale_ptr,
+ slot_mapping_ptr,
+ stride_cache_h,
+ stride_cache_s,
+ stride_cache_d,
+ stride_scale_h,
+ stride_scale_s,
+ EPS: tl.constexpr,
+ BLOCK_H: tl.constexpr,
+ BLOCK_D: tl.constexpr,
+ TARGET_DTYPE: tl.constexpr,
+):
+ token_idx = tl.program_id(0)
+ slot = tl.load(slot_mapping_ptr + token_idx)
+ if slot == -1:
+ return
+ offs_head = tl.arange(0, BLOCK_H)
+ offs_dim = tl.arange(0, BLOCK_D)
+ # k
+ kv_offsets = token_idx * key_stride_b + offs_head[:, None] * key_stride_h + offs_dim[None, :] * key_stride_d
+ kv = tl.load(key_ptr + kv_offsets)
+ kv_abs_max = tl.max(tl.abs(kv), axis=1)
+ kv_scale = tl.maximum(kv_abs_max / 127.0, EPS).to(TARGET_DTYPE)
+ kv = tl.extra.cuda.libdevice.llrint(kv / kv_scale[:, None])
+ kv = tl.maximum(tl.minimum(kv, 127), -127).to(tl.int8)
+ cache_offsets = offs_head[:, None] * stride_cache_h + slot * stride_cache_s + offs_dim[None, :] * stride_cache_d
+ tl.store(k_cache_ptr + cache_offsets, kv)
+ scale_offsets = offs_head * stride_scale_h + slot * stride_scale_s
+ tl.store(k_scale_ptr + scale_offsets, kv_scale)
+ # v
+ kv_offsets = token_idx * value_stride_b + offs_head[:, None] * value_stride_h + offs_dim[None, :] * value_stride_d
+ kv = tl.load(value_ptr + kv_offsets)
+ kv_abs_max = tl.max(tl.abs(kv), axis=1)
+ kv_scale = tl.maximum(kv_abs_max/127.0, EPS).to(TARGET_DTYPE)
+ kv = tl.extra.cuda.libdevice.llrint(kv / kv_scale[:, None])
+ kv = tl.maximum(tl.minimum(kv, 127), -127).to(tl.int8)
+ tl.store(v_cache_ptr + cache_offsets, kv)
+ tl.store(v_scale_ptr + scale_offsets, kv_scale)
+
+def store_quantkv(
+ key: torch.Tensor,
+ value: torch.Tensor,
+ k_cache: torch.Tensor,
+ v_cache: torch.Tensor,
+ k_scale_cache: torch.Tensor,
+ v_scale_cache: torch.Tensor,
+ slot_mapping: torch.Tensor,
+):
+ n, num_kv_heads, head_dim = key.shape
+ if key.dtype == torch.float16:
+ target_dtype = tl.float16
+ elif key.dtype == torch.bfloat16:
+ target_dtype = tl.bfloat16
+ else:
+ target_dtype = tl.float32
+ _store_quantkv_kernel[(n,)](
+ key,
+ key.stride(0),
+ key.stride(1),
+ key.stride(2),
+ value,
+ value.stride(0),
+ value.stride(1),
+ value.stride(2),
+ k_cache,
+ v_cache,
+ k_scale_cache,
+ v_scale_cache,
+ slot_mapping,
+ k_cache.stride(0),
+ k_cache.stride(1),
+ k_cache.stride(2),
+ k_scale_cache.stride(0),
+ k_scale_cache.stride(1),
+ EPS=1e-8,
+ BLOCK_H=num_kv_heads,
+ BLOCK_D=head_dim,
+ TARGET_DTYPE=target_dtype,
+ num_warps = 1,
+ num_stages = 1
+ )
+
class Attention(nn.Module):
def __init__(
@@ -55,21 +151,72 @@ def __init__(
self.scale = scale
self.num_kv_heads = num_kv_heads
self.k_cache = self.v_cache = torch.tensor([])
+ self.k_scale_cache = self.v_scale_cache = torch.tensor([])
+ self.block_size = 256
+ self.kv_quant = True
+ self.GROUP_NUM = 2
+ self.kv_store_stream = None
+ self.kv_store_event = None
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
+ if self.kv_quant:
+ o = self._forward_quantized(q,k,v)
+ else:
+ o = self._forward_unquantized(q,k,v)
+ return o
+
+ def _forward_unquantized(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
context = get_context()
k_cache, v_cache = self.k_cache, self.v_cache
if k_cache.numel() and v_cache.numel():
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
if context.is_prefill:
- if context.block_tables is not None: # prefix cache
+ if context.block_tables is not None:
k, v = k_cache, v_cache
o = flash_attn_varlen_func(q, k, v,
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
softmax_scale=self.scale, causal=True, block_table=context.block_tables)
- else: # decode
+ else:
o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
- cache_seqlens=context.context_lens, block_table=context.block_tables,
+ cache_seqlens=context.context_lens, block_table=context.block_tables,
softmax_scale=self.scale, causal=True)
return o
+
+ def _forward_quantized(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
+ context = get_context()
+ k_cache, v_cache = self.k_cache, self.v_cache
+ k_scale_cache, v_scale_cache = self.k_scale_cache, self.v_scale_cache
+ if context.is_prefill:
+ if context.block_tables is not None:
+ with torch.cuda.nvtx.range("PreFlashQuantDir"):
+ o = prefill_attn_quantkv_direct(q, k, v, k_cache, v_cache, k_scale_cache, v_scale_cache,
+ max_seqlen_q=context.max_seqlen_q,
+ cu_seqlens_q=context.cu_seqlens_q,
+ context_lens=context.context_lens,
+ block_table=context.block_tables, block_size=self.block_size,
+ GROUP_NUM=self.GROUP_NUM,
+ softmax_scale=self.scale, causal=True)
+ else:
+ with torch.cuda.nvtx.range("PreFlashOri"):
+ o = flash_attn_varlen_func(q, k, v,
+ max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
+ max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
+ softmax_scale=self.scale, causal=True, block_table=context.block_tables)
+ else: # decode
+ with torch.cuda.nvtx.range("DeFlashQuantDir"):
+ o = decode_attn_quantkv_direct(q, k, v, k_cache, v_cache, k_scale_cache, v_scale_cache,
+ context_lens=context.context_lens, block_table=context.block_tables,
+ block_size=self.block_size,
+ GROUP_NUM=self.GROUP_NUM,
+ softmax_scale=self.scale, )
+ if k_cache.numel() and v_cache.numel():
+ if torch.cuda.is_current_stream_capturing():
+ self.kv_store_stream.wait_stream(torch.cuda.current_stream())
+ with torch.cuda.stream(self.kv_store_stream):
+ with torch.cuda.nvtx.range("DeStore"):
+ store_quantkv(k, v, k_cache, v_cache, k_scale_cache, v_scale_cache, context.slot_mapping, )
+ self.kv_store_event.record(self.kv_store_stream)
+ return o
+
+
diff --git a/nanovllm/tools/quant_attn_kvhead_based.py b/nanovllm/tools/quant_attn_kvhead_based.py
new file mode 100644
index 000000000..5a3d6114b
--- /dev/null
+++ b/nanovllm/tools/quant_attn_kvhead_based.py
@@ -0,0 +1,415 @@
+import torch
+import triton
+import triton.language as tl
+
+# Triton launch configuration for the decode kernel.
+decode_num_warps = 1
+decode_num_stages = 1
+# Triton launch configuration for the prefill kernel.
+prefill_num_warps = 4
+prefill_num_stages = 1
+# Query positions per prefill program (passed to the kernel as TILE_Q).
+prefill_num_tileQ = 32
+# KV tokens consumed per inner-loop iteration of each kernel.
+TILE_KV_PREFILL = 32
+TILE_KV_DECODE = 32
+assert prefill_num_tileQ == TILE_KV_PREFILL, "prefill_num_tileQ must equal TILE_KV_PREFILL"
+# KV-cache block size, in tokens, that the tile counts below are derived from.
+_ASSUMED_KV_BLOCK_SIZE = 256
+assert _ASSUMED_KV_BLOCK_SIZE % TILE_KV_PREFILL == 0, "block size must be a multiple of TILE_KV_PREFILL"
+assert _ASSUMED_KV_BLOCK_SIZE % TILE_KV_DECODE == 0, "block size must be a multiple of TILE_KV_DECODE"
+# Number of KV tiles that make up one cache block of _ASSUMED_KV_BLOCK_SIZE tokens.
+NUM_KV_TILES_PER_BLOCK_PREFILL = _ASSUMED_KV_BLOCK_SIZE // TILE_KV_PREFILL  # 256/32 = 8
+NUM_KV_TILES_PER_BLOCK_DECODE = _ASSUMED_KV_BLOCK_SIZE // TILE_KV_DECODE  # 256/32 = 8
+
+@triton.jit
+def _decode_quant_direct_kernel(
+ q_ptr, # [seqs_num, q_head_num, head_dim]
+ new_k_ptr, # [seqs_num, kv_head_num, head_dim]
+ new_v_ptr, # [seqs_num, kv_head_num, head_dim]
+ stride_new_b,
+ stride_new_h,
+ stride_new_d,
+ k_cache_ptr, # [kv_head_num, num_slots, head_dim]
+ v_cache_ptr, # [kv_head_num, num_slots, head_dim]
+ k_scale_ptr, # [kv_head_num, num_slots]
+ v_scale_ptr, # [kv_head_num, num_slots]
+ context_lens_ptr, # [seqs_num]
+ block_table_ptr,
+ out_ptr, # [seqs_num, q_head_num, head_dim]
+ stride_q_b, # q_head_num * head_dim
+ stride_q_h, # head_dim
+ stride_q_d, # 1
+ stride_cache_h, # num_slots * head_dim
+ stride_cache_s, # head_dim
+ stride_cache_d, # 1
+ stride_scale_h, # num_slots
+ stride_scale_s, # 1
+ stride_bt_b, # max_block_table_len
+ stride_bt_blk, # 1
+ stride_out_b, # = stride_q_b
+ stride_out_h, # head_dim
+ stride_out_d, # 1
+ softmax_scale, # 1/sqrt(dim)
+ block_size,
+ TILE_KV: tl.constexpr, # tile_K
+ BLOCK_DIM_MODEL: tl.constexpr, # head_dim
+ TARGET_DTYPE: tl.constexpr,
+ NUM_KV_TILES_PER_BLOCK: tl.constexpr,
+ GROUP_NUM: tl.constexpr, # GQA
+):
+ # Decode-phase attention over a quantized, paged KV cache.
+ # Each program handles one (batch_idx, kv_head_idx) pair and produces the
+ # output rows of the GROUP_NUM query heads that share that KV head.
+ # Cached keys/values are loaded tile by tile, multiplied by their per-slot
+ # scale and folded into a running (online) softmax; the newest token's K/V
+ # are read from new_k_ptr / new_v_ptr at the end instead of from the cache.
+ batch_idx = tl.program_id(0)
+ kv_head_idx = tl.program_id(1)
+ q_head_start = kv_head_idx * GROUP_NUM
+ q_head_offsets = q_head_start + tl.arange(0, GROUP_NUM)
+ ctx_len = tl.load(context_lens_ptr + batch_idx)
+ # Exclude the newest token from the cached range; it is handled after the cache loop.
+ ctx_len -= 1
+ num_total_tiles = tl.cdiv(ctx_len, TILE_KV)
+ # No cached tokens: leave out_ptr untouched for this program.
+ # NOTE(review): this also skips the new-token contribution when the context
+ # is exactly one token long - confirm callers never decode with context_lens == 1.
+ if num_total_tiles <= 0 or ctx_len <= 0:
+ return
+ offs_dim = tl.arange(0, BLOCK_DIM_MODEL)
+ offs_kv = tl.arange(0, TILE_KV)
+ q_ptrs = q_ptr + batch_idx * stride_q_b + q_head_offsets[:, None] * stride_q_h + offs_dim[None, :] * stride_q_d
+ # [GROUP_NUM, head_dim]
+ q = tl.load(q_ptrs)
+ # Online-softmax state per query head: running max, running denominator, running output.
+ m_i = tl.full((GROUP_NUM,), float("-inf"), tl.float32)
+ li = tl.full((GROUP_NUM,), 0.0, tl.float32)
+ oi = tl.zeros((GROUP_NUM, BLOCK_DIM_MODEL), dtype=tl.float32)
+ curr_kv_block_idx = 0
+ physical_block_idx = tl.load(block_table_ptr + batch_idx * stride_bt_b + curr_kv_block_idx * stride_bt_blk)
+ # Every tile except the last lies entirely below ctx_len, so these loads are unmasked.
+ for tile_idx in tl.range(0, num_total_tiles - 1):
+ # Tile index -> (logical block, tile inside the block). This mapping requires
+ # block_size == NUM_KV_TILES_PER_BLOCK * TILE_KV.
+ kv_block_idx = tile_idx // NUM_KV_TILES_PER_BLOCK
+ kv_tile_idx = tile_idx % NUM_KV_TILES_PER_BLOCK
+ # Re-read the block table only when the tile crosses into a new logical block.
+ if kv_block_idx != curr_kv_block_idx:
+ curr_kv_block_idx = kv_block_idx
+ physical_block_idx = tl.load(block_table_ptr + batch_idx * stride_bt_b + curr_kv_block_idx * stride_bt_blk)
+ logic_in_block_offs = kv_tile_idx * TILE_KV
+ token_in_block = logic_in_block_offs + offs_kv
+ slot = physical_block_idx * block_size + token_in_block
+ scale_offs = kv_head_idx * stride_scale_h + slot * stride_scale_s
+ cache_offs = kv_head_idx * stride_cache_h + slot[:, None] * stride_cache_s + offs_dim[None, :] * stride_cache_d
+ # Dequantize keys: cached value times the scale stored for the same slot.
+ kvj = tl.load(k_cache_ptr + cache_offs).to(TARGET_DTYPE)
+ kv_scale = tl.load(k_scale_ptr + scale_offs).to(TARGET_DTYPE)
+ kvj = kvj * kv_scale[:, None]
+ scores = tl.sum(q[:, None, :] * kvj[None, :, :], axis=2) * softmax_scale
+ m_ij = tl.max(scores, axis=1) # [GROUP_NUM]
+ mj_new = tl.maximum(m_i, m_ij) # [GROUP_NUM]
+ # alpha rescales the previous denominator/output to the new running max.
+ alpha = tl.exp(m_i - mj_new) # [GROUP_NUM]
+ # From here on `scores` holds the unnormalized probabilities.
+ scores = tl.exp(scores - mj_new[:, None]) # [GROUP_NUM, tileK]
+ li = alpha * li + tl.sum(scores, axis=1) # [GROUP_NUM]
+ # Dequantize values; kvj / kv_scale are reused for V.
+ kvj = tl.load(v_cache_ptr + cache_offs).to(TARGET_DTYPE)
+ kv_scale = tl.load(v_scale_ptr + scale_offs).to(TARGET_DTYPE)
+ kvj = kvj * kv_scale[:, None]
+ # [GROUP_NUM, head_dim]
+ oij = tl.sum(scores[:, :, None] * kvj[None, :, :], axis=1)
+ oi = oi * alpha[:, None] + oij
+ m_i = mj_new
+ # Last tile: may be only partially filled, so slots at or beyond ctx_len are masked.
+ last_tile_idx = num_total_tiles - 1
+ last_kv_block_idx = last_tile_idx // NUM_KV_TILES_PER_BLOCK
+ last_kv_tile_idx = last_tile_idx % NUM_KV_TILES_PER_BLOCK
+ if last_kv_block_idx != curr_kv_block_idx:
+ physical_block_idx = tl.load(block_table_ptr + batch_idx * stride_bt_b + last_kv_block_idx * stride_bt_blk)
+ logic_block_offs = last_kv_block_idx * block_size
+ logic_in_block_offs = last_kv_tile_idx * TILE_KV
+ token_in_block = logic_in_block_offs + offs_kv
+ logical_token_idx = logic_block_offs + token_in_block
+ valid = logical_token_idx < ctx_len
+ slot = physical_block_idx * block_size + token_in_block
+ scale_offs = kv_head_idx * stride_scale_h + slot * stride_scale_s
+ cache_offs = kv_head_idx * stride_cache_h + slot[:, None] * stride_cache_s + offs_dim[None, :] * stride_cache_d
+ kvj = tl.load(k_cache_ptr + cache_offs, mask=valid[:, None], other=0).to(TARGET_DTYPE)
+ kv_scale = tl.load(k_scale_ptr + scale_offs, mask=valid, other=0.0).to(TARGET_DTYPE)
+ kvj = kvj * kv_scale[:, None]
+ scores = tl.sum(q[:, None, :] * kvj[None, :, :], axis=2) * softmax_scale
+ # Masked slots get -inf so they contribute exp(-inf) = 0 to the softmax.
+ scores = tl.where(valid[None, :], scores, float("-inf"))
+ m_ij = tl.max(scores, axis=1)
+ mj_new = tl.maximum(m_i, m_ij)
+ alpha = tl.exp(m_i - mj_new)
+ p = tl.exp(scores - mj_new[:, None])
+ li = alpha * li + tl.sum(p, axis=1)
+ kvj = tl.load(v_cache_ptr + cache_offs, mask=valid[:, None], other=0).to(TARGET_DTYPE)
+ kv_scale = tl.load(v_scale_ptr + scale_offs, mask=valid, other=0.0).to(TARGET_DTYPE)
+ kvj = kvj * kv_scale[:, None]
+ oij = tl.sum(p[:, :, None] * kvj[None, :, :], axis=1)
+ oi = oi * alpha[:, None] + oij
+ m_i = mj_new
+ # Newest token: fold in the K/V passed to the kernel (no scale is applied to them).
+ # cache_offs is reused here as the offset into new_k / new_v for this batch entry and KV head.
+ cache_offs = batch_idx * stride_new_b + kv_head_idx * stride_new_h + offs_dim * stride_new_d
+ new_kv = tl.load(new_k_ptr + cache_offs)
+ scores = tl.sum(q * new_kv[None, :], axis=1)
+ scores = scores * softmax_scale
+ mj_new = tl.maximum(m_i, scores)
+ alpha = tl.exp(m_i - mj_new)
+ p = tl.exp(scores - mj_new)
+ li = alpha * li + p
+ new_kv = tl.load(new_v_ptr + cache_offs)
+ oi = oi * alpha[:, None] + p[:, None] * new_kv[None, :]
+ # Normalize by the softmax denominator and write the result in TARGET_DTYPE.
+ oi = oi / li[:, None]
+ # [GROUP_NUM, head_dim]
+ out_ptrs = out_ptr + batch_idx * stride_out_b + q_head_offsets[:, None] * stride_out_h + offs_dim[None, :] * stride_out_d
+ tl.store(out_ptrs, oi.to(TARGET_DTYPE))
+
+def decode_attn_quantkv_direct(
+    q: torch.Tensor,
+    new_k: torch.Tensor,
+    new_v: torch.Tensor,
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    k_scale: torch.Tensor,
+    v_scale: torch.Tensor,
+    context_lens: torch.Tensor,
+    block_table: torch.Tensor,
+    block_size: int,
+    GROUP_NUM: int,
+    softmax_scale: float | None = None,
+) -> torch.Tensor:
+    """Launch the decode attention kernel over the quantized, paged KV cache.
+
+    Args:
+        q: Queries, unpacked as ``(batch, num_q_heads, head_dim)``.
+        new_k: Keys of the token being decoded; read directly by the kernel.
+        new_v: Values of the token being decoded; read directly by the kernel.
+        k_cache: Quantized key cache, unpacked as
+            ``(num_kv_heads, num_slots, head_dim)``.
+        v_cache: Quantized value cache; indexed with the strides of ``k_cache``.
+        k_scale: Per-slot key scales, 2-D.
+        v_scale: Per-slot value scales; indexed with the strides of ``k_scale``.
+        context_lens: Context length per sequence, including the new token.
+        block_table: Logical-to-physical block mapping, 2-D.
+        block_size: Tokens per cache block; must be a positive multiple of
+            ``TILE_KV_DECODE``.
+        GROUP_NUM: Number of query heads that share one KV head.
+        softmax_scale: Score scale; defaults to ``head_dim ** -0.5``.
+
+    Returns:
+        A tensor shaped and typed like ``q`` holding the attention output.
+
+    Raises:
+        ValueError: If ``block_size`` is not a positive multiple of
+            ``TILE_KV_DECODE``.
+    """
+    batch, _num_q_heads, head_dim = q.shape
+    num_kv_heads, _num_slots, _ = k_cache.shape
+    # The kernel maps a tile index to (block, tile-in-block); that mapping is
+    # only correct when the tile count per block matches the real block size,
+    # so derive it from block_size instead of assuming 256-token blocks.
+    if block_size <= 0 or block_size % TILE_KV_DECODE != 0:
+        raise ValueError(
+            f"block_size ({block_size}) must be a positive multiple of "
+            f"TILE_KV_DECODE ({TILE_KV_DECODE})"
+        )
+    tiles_per_block = block_size // TILE_KV_DECODE
+    if softmax_scale is None:
+        softmax_scale = head_dim ** -0.5
+    out = torch.empty_like(q)
+    # Compute/output dtype inside the kernel follows q; anything that is not
+    # fp16/bf16 falls back to fp32.
+    target_dtype = {
+        torch.float16: tl.float16,
+        torch.bfloat16: tl.bfloat16,
+    }.get(q.dtype, tl.float32)
+
+    # One program per (sequence, KV head).
+    _decode_quant_direct_kernel[(batch, num_kv_heads, )](
+        q,
+        new_k,
+        new_v,
+        new_k.stride(0),
+        new_k.stride(1),
+        new_k.stride(2),
+        k_cache,
+        v_cache,
+        k_scale,
+        v_scale,
+        context_lens,
+        block_table,
+        out,
+        q.stride(0),
+        q.stride(1),
+        q.stride(2),
+        k_cache.stride(0),
+        k_cache.stride(1),
+        k_cache.stride(2),
+        k_scale.stride(0),
+        k_scale.stride(1),
+        block_table.stride(0),
+        block_table.stride(1),
+        out.stride(0),
+        out.stride(1),
+        out.stride(2),
+        softmax_scale,
+        block_size,
+        TILE_KV=TILE_KV_DECODE,
+        BLOCK_DIM_MODEL=head_dim,
+        TARGET_DTYPE=target_dtype,
+        NUM_KV_TILES_PER_BLOCK=tiles_per_block,
+        GROUP_NUM=GROUP_NUM,
+        num_warps=decode_num_warps,
+        num_stages=decode_num_stages,
+    )
+    return out
+
+@triton.jit
+def _prefill_quant_direct_kernel(
+    q_ptr,  # [total_q, q_head_num, head_dim] without prefix part
+    k_ptr,  # [total_k, kv_head_num, head_dim] total_q == total_k == total_v
+    v_ptr,  # [total_v, kv_head_num, head_dim]
+    stride_kv_b,  # kv_head_num * head_dim
+    stride_kv_h,  # head_dim
+    stride_kv_d,  # 1
+    k_cache_ptr,  # [kv_head_num, num_slots, head_dim]
+    v_cache_ptr,  # [kv_head_num, num_slots, head_dim]
+    k_scale_ptr,  # [kv_head_num, num_slots]
+    v_scale_ptr,  # [kv_head_num, num_slots]
+    cu_seqlens_q_ptr,  # [batch + 1]
+    context_lens_ptr,  # [batch]
+    block_table_ptr,  # [batch, max_num_blocks_per_seq]
+    out_ptr,  # [total_q, q_head_num, head_dim]
+    stride_q_b,  # q_head_num * head_dim
+    stride_q_h,  # head_dim
+    stride_q_d,  # 1
+    stride_cache_h,  # num_slots * head_dim
+    stride_cache_s,  # head_dim
+    stride_cache_d,  # 1
+    stride_scale_h,  # num_slots
+    stride_scale_s,  # 1
+    stride_bt_b,  # max_block_table_len
+    stride_bt_blk,  # 1
+    stride_out_b,  # = stride_q_b
+    stride_out_h,  # head_dim
+    stride_out_d,  # 1
+    softmax_scale,  # = 1/sqrt(dim)
+    block_size,
+    TILE_Q: tl.constexpr,
+    TILE_KV: tl.constexpr,
+    TILE_DIM_MODEL: tl.constexpr,  # head_dim
+    TARGET_DTYPE: tl.constexpr,  # bfloat16
+    NUM_KV_TILES_PER_BLOCK: tl.constexpr,
+    GROUP_NUM: tl.constexpr,  # GQA
+):
+    # Prefill attention with a quantized, paged prefix.
+    # Each program handles one (query tile, KV head, sequence) triple and
+    # computes TILE_Q query positions for all GROUP_NUM query heads that share
+    # the KV head. Keys/values come from two sources:
+    #   1. the prefix (context_len - q_len tokens), dequantized from the cache;
+    #   2. the new tokens, read from k_ptr / v_ptr with a causal mask.
+    # Both are folded into one online softmax.
+    batch_idx = tl.program_id(2)
+    kv_head_idx = tl.program_id(1)
+    q_tile_idx = tl.program_id(0)
+    q_start = tl.load(cu_seqlens_q_ptr + batch_idx)
+    q_end = tl.load(cu_seqlens_q_ptr + batch_idx + 1)
+    q_len = q_end - q_start
+    k_len = tl.load(context_lens_ptr + batch_idx)
+    prefix_len = k_len - q_len
+    q_tile_start = q_tile_idx * TILE_Q
+    # The grid is sized for the longest sequence; shorter ones have idle tiles.
+    if q_tile_start >= q_len or q_len <= 0:
+        return
+    offs_dim = tl.arange(0, TILE_DIM_MODEL)
+    offs_kv = tl.arange(0, TILE_KV)
+    # Rows are laid out group-major: row r belongs to query head
+    # (r // TILE_Q) of the group and query position (r % TILE_Q) of the tile.
+    offs_rows = tl.arange(0, GROUP_NUM * TILE_Q)
+    offs_group = offs_rows // TILE_Q
+    offs_q_in_tile = offs_rows % TILE_Q
+    logic_offs_q = q_tile_start + offs_q_in_tile  # [GROUP_NUM * TILE_Q]
+    q_valid = logic_offs_q < q_len  # [GROUP_NUM * TILE_Q]
+    logic_offs_q_in_total = q_start + logic_offs_q
+    q_head_offsets = kv_head_idx * GROUP_NUM + offs_group
+    q_ptrs = (q_ptr + logic_offs_q_in_total[:, None] * stride_q_b
+              + q_head_offsets[:, None] * stride_q_h
+              + offs_dim[None, :] * stride_q_d)
+    q = tl.load(q_ptrs, mask=q_valid[:, None], other=0.0)  # [GROUP_NUM * tileQ, head_dim]
+    # Padding rows start at 0 rather than -inf, which keeps exp(m_i - m_new)
+    # finite for rows whose scores are all masked.
+    m_i = tl.where(q_valid, float("-inf"), 0.0).to(tl.float32)
+    l_i = tl.zeros((GROUP_NUM * TILE_Q,), dtype=tl.float32)
+    oi = tl.zeros((GROUP_NUM * TILE_Q, TILE_DIM_MODEL), dtype=tl.float32)
+    # prefix KV, from head-major kv cache
+    num_prefix_tiles = tl.cdiv(prefix_len, TILE_KV)
+    for tile_idx in tl.range(0, num_prefix_tiles):
+        # Tile index -> (logical block, tile inside the block); requires
+        # block_size == NUM_KV_TILES_PER_BLOCK * TILE_KV.
+        kv_block_idx = tile_idx // NUM_KV_TILES_PER_BLOCK
+        kv_tile_idx = tile_idx % NUM_KV_TILES_PER_BLOCK
+        physical_block_idx = tl.load(block_table_ptr + batch_idx * stride_bt_b + kv_block_idx * stride_bt_blk)
+        logic_block_offs = kv_block_idx * block_size
+        logic_in_block_offs = kv_tile_idx * TILE_KV
+        token_in_block = logic_in_block_offs + offs_kv
+        logical_k = logic_block_offs + token_in_block
+        kv_valid = logical_k < prefix_len
+        slot = physical_block_idx * block_size + token_in_block
+        # No causal mask needed: every prefix key precedes every new query.
+        attn_mask = (q_valid[:, None] & kv_valid[None, :])
+        scale_offs = kv_head_idx * stride_scale_h + slot * stride_scale_s
+        cache_offs = kv_head_idx * stride_cache_h + slot[:, None] * stride_cache_s + offs_dim[None, :] * stride_cache_d
+        kj = tl.load(k_cache_ptr + cache_offs, mask=kv_valid[:, None], other=0).to(TARGET_DTYPE)
+        k_scale = tl.load(k_scale_ptr + scale_offs, mask=kv_valid, other=0.0).to(TARGET_DTYPE)
+        # kj is [TILE_KV, head_dim] and k_scale holds one scale per slot, so the
+        # scale must broadcast along head_dim (axis 1), exactly as for V below.
+        kj = kj * k_scale[:, None]
+        # [GROUP_NUM * tileQ, tileK]
+        scores = tl.dot(q, tl.trans(kj)) * softmax_scale
+        scores = tl.where(attn_mask, scores, float("-inf"))
+        # [GROUP_NUM * tileQ]
+        m_ij = tl.max(scores, axis=1)
+        m_new = tl.maximum(m_i, m_ij)
+        # alpha rescales the previous denominator/output to the new running max.
+        alpha = tl.exp(m_i - m_new)
+        p = tl.where(attn_mask, tl.exp(scores - m_new[:, None]), 0.0).to(TARGET_DTYPE)
+        # [GROUP_NUM * tileQ]
+        l_i = l_i * alpha + tl.sum(p, axis=1)
+        vj = tl.load(v_cache_ptr + cache_offs, mask=kv_valid[:, None], other=0).to(TARGET_DTYPE)
+        v_scale = tl.load(v_scale_ptr + scale_offs, mask=kv_valid, other=0.0).to(TARGET_DTYPE)
+        vj = vj * v_scale[:, None]
+        # [GROUP_NUM * tileQ, head_dim]
+        oi = oi * alpha[:, None] + tl.dot(p, vj)
+        m_i = m_new
+    # post KV: from new kv [total_k, kv_head_num, head_dim]
+    num_post_tiles = tl.cdiv(q_len, TILE_KV)
+    for tile_idx in tl.range(0, num_post_tiles):
+        post_logic_k = tile_idx * TILE_KV + offs_kv
+        post_valid = post_logic_k < q_len
+        # Causal mask: a query attends only to new keys at or before its own position.
+        attn_mask = q_valid[:, None] & post_valid[None, :] & (post_logic_k[None, :] <= logic_offs_q[:, None])
+        postfix_global_k = q_start + post_logic_k
+        cache_offs = postfix_global_k[:, None] * stride_kv_b + kv_head_idx * stride_kv_h + offs_dim[None, :] * stride_kv_d
+        kj = tl.load(k_ptr + cache_offs, mask=post_valid[:, None], other=0.0)
+        # [GROUP_NUM * tileQ, tileK]
+        scores = tl.dot(q, tl.trans(kj))
+        scores = scores * softmax_scale
+        scores = tl.where(attn_mask, scores, float("-inf"))
+        # [GROUP_NUM * tileQ]
+        m_ij = tl.max(scores, axis=1)
+        m_new = tl.maximum(m_i, m_ij)
+        alpha = tl.exp(m_i - m_new)
+        p = tl.where(attn_mask, tl.exp(scores - m_new[:, None]), 0.0).to(TARGET_DTYPE)
+        # [GROUP_NUM * tileQ]
+        l_i = l_i * alpha + tl.sum(p, axis=1)
+        vj = tl.load(v_ptr + cache_offs, mask=post_valid[:, None], other=0.0)
+        oi = oi * alpha[:, None] + tl.dot(p, vj)
+        m_i = m_new
+    # Normalize by the softmax denominator; padding rows are dropped by the store mask.
+    oi = oi / l_i[:, None]
+    # [GROUP_NUM * tileQ, head_dim]
+    out_ptrs = (out_ptr + logic_offs_q_in_total[:, None] * stride_out_b
+                + q_head_offsets[:, None] * stride_out_h
+                + offs_dim[None, :] * stride_out_d)
+    tl.store(out_ptrs, oi.to(TARGET_DTYPE), mask=q_valid[:, None])
+
+def prefill_attn_quantkv_direct(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    k_scale: torch.Tensor,
+    v_scale: torch.Tensor,
+    max_seqlen_q: int,
+    cu_seqlens_q: torch.Tensor,
+    context_lens: torch.Tensor,
+    block_table: torch.Tensor,
+    block_size: int,
+    GROUP_NUM: int,
+    softmax_scale: float | None = None,
+    causal: bool = True,
+) -> torch.Tensor:
+    """Launch the prefill attention kernel (quantized cached prefix + new tokens).
+
+    Args:
+        q: Queries of the new tokens, unpacked as
+            ``(total_q, num_q_heads, head_dim)``.
+        k: Keys of the new tokens; read directly by the kernel.
+        v: Values of the new tokens; indexed with the strides of ``k``.
+        k_cache: Quantized key cache, unpacked as
+            ``(num_kv_heads, num_slots, head_dim)``.
+        v_cache: Quantized value cache; indexed with the strides of ``k_cache``.
+        k_scale: Per-slot key scales, 2-D.
+        v_scale: Per-slot value scales; indexed with the strides of ``k_scale``.
+        max_seqlen_q: Longest new-token run in the batch; sizes the grid.
+        cu_seqlens_q: Cumulative new-token counts, ``batch + 1`` entries.
+        context_lens: Total context length per sequence (prefix + new tokens).
+        block_table: Logical-to-physical block mapping, 2-D.
+        block_size: Tokens per cache block; must be a positive multiple of
+            ``TILE_KV_PREFILL``.
+        GROUP_NUM: Number of query heads that share one KV head.
+        softmax_scale: Score scale; defaults to ``head_dim ** -0.5``.
+        causal: Must be true; only causal attention is implemented.
+
+    Returns:
+        A tensor shaped and typed like ``q`` holding the attention output.
+
+    Raises:
+        AssertionError: If ``causal`` is false.
+        ValueError: If ``block_size`` is not a positive multiple of
+            ``TILE_KV_PREFILL``.
+    """
+    assert causal, "prefill_attn_quantkv_direct only supports causal attention"
+    _total_q, _num_q_heads, head_dim = q.shape
+    num_kv_heads, _num_slots, _ = k_cache.shape
+    batch = cu_seqlens_q.numel() - 1
+    # The kernel maps a tile index to (block, tile-in-block); that mapping is
+    # only correct when the tile count per block matches the real block size,
+    # so derive it from block_size instead of assuming 256-token blocks.
+    if block_size <= 0 or block_size % TILE_KV_PREFILL != 0:
+        raise ValueError(
+            f"block_size ({block_size}) must be a positive multiple of "
+            f"TILE_KV_PREFILL ({TILE_KV_PREFILL})"
+        )
+    tiles_per_block = block_size // TILE_KV_PREFILL
+    if softmax_scale is None:
+        softmax_scale = head_dim ** -0.5
+    out = torch.empty_like(q)
+    # Compute/output dtype inside the kernel follows q; anything that is not
+    # fp16/bf16 falls back to fp32.
+    target_dtype = {
+        torch.float16: tl.float16,
+        torch.bfloat16: tl.bfloat16,
+    }.get(q.dtype, tl.float32)
+
+    if k_cache.numel():
+        stride_cache_h = k_cache.stride(0)
+        stride_cache_s = k_cache.stride(1)
+        stride_cache_d = k_cache.stride(2)
+        stride_scale_h = k_scale.stride(0)
+        stride_scale_s = k_scale.stride(1)
+    else:
+        # warmup: the cache holds no elements, so pass zero strides
+        stride_cache_h = 0
+        stride_cache_s = 0
+        stride_cache_d = 0
+        stride_scale_h = 0
+        stride_scale_s = 0
+    # One program per (query tile, KV head, sequence).
+    grid = (triton.cdiv(max_seqlen_q, prefill_num_tileQ), num_kv_heads, batch)
+    _prefill_quant_direct_kernel[grid](
+        q,
+        k,
+        v,
+        k.stride(0),
+        k.stride(1),
+        k.stride(2),
+        k_cache,
+        v_cache,
+        k_scale,
+        v_scale,
+        cu_seqlens_q,
+        context_lens,
+        block_table,
+        out,
+        q.stride(0),
+        q.stride(1),
+        q.stride(2),
+        stride_cache_h,
+        stride_cache_s,
+        stride_cache_d,
+        stride_scale_h,
+        stride_scale_s,
+        block_table.stride(0),
+        block_table.stride(1),
+        out.stride(0),
+        out.stride(1),
+        out.stride(2),
+        softmax_scale,
+        block_size,
+        TILE_Q=prefill_num_tileQ,
+        TILE_KV=TILE_KV_PREFILL,
+        TILE_DIM_MODEL=head_dim,
+        TARGET_DTYPE=target_dtype,
+        NUM_KV_TILES_PER_BLOCK=tiles_per_block,
+        GROUP_NUM=GROUP_NUM,
+        num_warps=prefill_num_warps,
+        num_stages=prefill_num_stages,
+    )
+    # The caller assigns this result to its attention output.
+    return out
diff --git a/nanovllm/utils/context.py b/nanovllm/utils/context.py
index 3b02a1d5d..e5881f536 100644
--- a/nanovllm/utils/context.py
+++ b/nanovllm/utils/context.py
@@ -12,15 +12,18 @@ class Context:
slot_mapping: torch.Tensor | None = None
context_lens: torch.Tensor | None = None
block_tables: torch.Tensor | None = None
+ max_context_len: int = 0
_CONTEXT = Context()
def get_context():
return _CONTEXT
-def set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0, max_seqlen_k=0, slot_mapping=None, context_lens=None, block_tables=None):
+def set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0, max_seqlen_k=0, slot_mapping=None,
+ context_lens=None, block_tables=None, max_context_len=0):
global _CONTEXT
- _CONTEXT = Context(is_prefill, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables)
+ _CONTEXT = Context(is_prefill, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens,
+ block_tables, max_context_len)
def reset_context():
global _CONTEXT
diff --git a/pyproject.toml b/pyproject.toml
index dc1399a10..d912c47e8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -3,13 +3,13 @@ requires = ["setuptools>=61"]
build-backend = "setuptools.build_meta"
[project]
-name = "nano-vllm"
+name = "nano-vllm-kv-compression"
version = "0.2.0"
-authors = [{ name = "Xingkai Yu" }]
+authors = [{ name = "Mingsong Jiang" }]
license = "MIT"
license-files = ["LICENSE"]
readme = "README.md"
-description = "a lightweight vLLM implementation built from scratch"
+description = "An improved implementation based on Nano-vLLM featuring int8 KV cache compression, head-major memory layout for coalesced access, and asynchronous stream pipelining that hides KV store latency behind attention computation."
requires-python = ">=3.10,<3.13"
dependencies = [
"torch>=2.4.0",
@@ -20,7 +20,7 @@ dependencies = [
]
[project.urls]
-Homepage="https://github.com/GeeeekExplorer/nano-vllm"
+Homepage="https://github.com/naalo2/nano-vLLM-kv-compression"
[tool.setuptools.packages.find]
where = ["."]