diff --git a/README.md b/README.md index eb468f33b..9d1a9fbd8 100644 --- a/README.md +++ b/README.md @@ -2,24 +2,21 @@

-

-GeeeekExplorer%2Fnano-vllm | Trendshift -

-# Nano-vLLM - -A lightweight vLLM implementation built from scratch. +# Nano-vLLM-kv-compression +An improved implementation based on Nano-vLLM featuring int8 KV cache compression, head-major memory layout for coalesced access, and asynchronous stream pipelining that hides KV store latency behind attention computation. ## Key Features -* 🚀 **Fast offline inference** - Comparable inference speeds to vLLM -* 📖 **Readable codebase** - Clean implementation in ~ 1,200 lines of Python code -* ⚡ **Optimization Suite** - Prefix caching, Tensor Parallelism, Torch compilation, CUDA graph, etc. +* ⚡ **Int8 KV Cache Compression** — 50% memory reduction via dynamic per-head quantization +* 🔄 **Coalesced Layout** — Head-major reordering for warp-level memory coalescing +* 🎯 **GQA-Optimized Flash Attention** — Grouped Q-head CTA mapping eliminates redundant KV loads +* 🔗 **Async KV Store Pipeline** — Multi-stream architecture overlaps KV quantization and cache writeback with attention computation ## Installation ```bash -pip install git+https://github.com/GeeeekExplorer/nano-vllm.git +pip install git+https://github.com/naalo2/nano-vLLM-kv-compression.git ``` ## Model Download @@ -48,19 +45,16 @@ outputs[0]["text"] See `bench.py` for benchmark. 
**Test Configuration:** -- Hardware: RTX 4070 Laptop (8GB) +- Hardware: RTX 3090 (24GB) - Model: Qwen3-0.6B - Total Requests: 256 sequences - Input Length: Randomly sampled between 100–1024 tokens - Output Length: Randomly sampled between 100–1024 tokens **Performance Results:** -| Inference Engine | Output Tokens | Time (s) | Throughput (tokens/s) | -|----------------|-------------|----------|-----------------------| -| vLLM | 133,966 | 98.37 | 1361.84 | -| Nano-vLLM | 133,966 | 93.41 | 1434.13 | - +| Inference Engine | Output Tokens | Time (s) | Throughput (tokens/s) | +|---------------------------|---------------|----------|-----------------------| +| Nano-vLLM | 133,966 | 33.05 | 4052.56 | +| Nano-vLLM-kv-compression | 133,966 | 27.00 | 4962.21 | -## Star History -[![Star History Chart](https://api.star-history.com/svg?repos=GeeeekExplorer/nano-vllm&type=Date)](https://www.star-history.com/#GeeeekExplorer/nano-vllm&Date) \ No newline at end of file diff --git a/assets/logo.png b/assets/logo.png index ac0b8fd6e..b2c41996d 100644 Binary files a/assets/logo.png and b/assets/logo.png differ diff --git a/bench.py b/bench.py index 8e61d6545..3468a7fdb 100644 --- a/bench.py +++ b/bench.py @@ -2,6 +2,8 @@ import time from random import randint, seed from nanovllm import LLM, SamplingParams +from nanovllm.config import Config + # from vllm import LLM, SamplingParams @@ -10,18 +12,22 @@ def main(): num_seqs = 256 max_input_len = 1024 max_ouput_len = 1024 - - path = os.path.expanduser("~/huggingface/Qwen3-0.6B/") - llm = LLM(path, enforce_eager=False, max_model_len=4096) + enforce_eager = False + print(f"use cuda graph:{not enforce_eager}") + path = os.path.expanduser("/YOUR/MODEL/PATH") + llm = LLM(path, enforce_eager=enforce_eager, max_model_len=max(4096, max_ouput_len+max_input_len)) prompt_token_ids = [[randint(0, 10000) for _ in range(randint(100, max_input_len))] for _ in range(num_seqs)] sampling_params = [SamplingParams(temperature=0.6, ignore_eos=True, 
max_tokens=randint(100, max_ouput_len)) for _ in range(num_seqs)] + # prompt_token_ids = [[randint(0, 10000) for _ in range(max_input_len)] for _ in range(num_seqs)] + # sampling_params = [SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=max_ouput_len) for _ in range(num_seqs)] + # uncomment the following line for vllm # prompt_token_ids = [dict(prompt_token_ids=p) for p in prompt_token_ids] llm.generate(["Benchmark: "], SamplingParams()) t = time.time() - llm.generate(prompt_token_ids, sampling_params, use_tqdm=False) + llm.generate(prompt_token_ids, sampling_params, use_tqdm=True) t = (time.time() - t) total_tokens = sum(sp.max_tokens for sp in sampling_params) throughput = total_tokens / t diff --git a/nanovllm/config.py b/nanovllm/config.py index 7066cbeb5..3e40e37f6 100644 --- a/nanovllm/config.py +++ b/nanovllm/config.py @@ -1,4 +1,5 @@ import os +import torch from dataclasses import dataclass from transformers import AutoConfig @@ -16,6 +17,10 @@ class Config: eos: int = -1 kvcache_block_size: int = 256 num_kvcache_blocks: int = -1 + kv_quant: bool = True + kvcache_quant_dtype = torch.int8 + kvscale_dtype = torch.bfloat16 + group_num: int = 2 def __post_init__(self): assert os.path.isdir(self.model) @@ -23,3 +28,4 @@ def __post_init__(self): assert 1 <= self.tensor_parallel_size <= 8 self.hf_config = AutoConfig.from_pretrained(self.model) self.max_model_len = min(self.max_model_len, self.hf_config.max_position_embeddings) + assert self.max_num_batched_tokens >= self.max_model_len diff --git a/nanovllm/engine/model_runner.py b/nanovllm/engine/model_runner.py index 71d9883c9..d4eba1204 100644 --- a/nanovllm/engine/model_runner.py +++ b/nanovllm/engine/model_runner.py @@ -17,6 +17,7 @@ class ModelRunner: def __init__(self, config: Config, rank: int, event: Event | list[Event]): self.config = config hf_config = config.hf_config + Sequence.set_block_size(config.kvcache_block_size) self.block_size = config.kvcache_block_size self.enforce_eager = 
config.enforce_eager self.world_size = config.tensor_parallel_size @@ -26,12 +27,18 @@ def __init__(self, config: Config, rank: int, event: Event | list[Event]): dist.init_process_group("nccl", "tcp://localhost:2333", world_size=self.world_size, rank=rank) torch.cuda.set_device(rank) default_dtype = torch.get_default_dtype() - torch.set_default_dtype(hf_config.dtype) + torch.set_default_dtype(hf_config.torch_dtype) torch.set_default_device("cuda") self.model = Qwen3ForCausalLM(hf_config) load_model(self.model, config.model) self.sampler = Sampler() + if self.config.kv_quant: + self.kv_store_stream = torch.cuda.Stream() + self.kv_store_event = torch.cuda.Event() + self.has_pedding = False self.warmup_model() + if self.config.kv_quant: + self.has_pedding = False self.allocate_kv_cache() if not self.enforce_eager: self.capture_cudagraph() @@ -53,6 +60,7 @@ def exit(self): dist.barrier() if self.rank == 0: self.shm.unlink() + if not self.enforce_eager: del self.graphs, self.graph_pool torch.cuda.synchronize() @@ -92,11 +100,8 @@ def warmup_model(self): torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() max_num_batched_tokens, max_model_len = self.config.max_num_batched_tokens, self.config.max_model_len - seq_len = min(max_num_batched_tokens, max_model_len) - num_seqs = min(max_num_batched_tokens // seq_len, self.config.max_num_seqs) - seqs = [Sequence([0] * seq_len) for _ in range(num_seqs)] - for seq in seqs: - seq.num_scheduled_tokens = seq_len + num_seqs = min(max_num_batched_tokens // max_model_len, self.config.max_num_seqs) + seqs = [Sequence([0] * max_model_len) for _ in range(num_seqs)] self.run(seqs, True) torch.cuda.empty_cache() @@ -109,16 +114,49 @@ def allocate_kv_cache(self): current = torch.cuda.memory_stats()["allocated_bytes.all.current"] num_kv_heads = hf_config.num_key_value_heads // self.world_size head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads) - block_bytes = 2 * 
hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.dtype.itemsize - config.num_kvcache_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes - assert config.num_kvcache_blocks > 0 - self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, head_dim) - layer_id = 0 - for module in self.model.modules(): - if hasattr(module, "k_cache") and hasattr(module, "v_cache"): - module.k_cache = self.kv_cache[0, layer_id] - module.v_cache = self.kv_cache[1, layer_id] - layer_id += 1 + available_bytes = int(total * config.gpu_memory_utilization - used - peak + current) + if config.kv_quant: + kv_bytes = config.kvcache_quant_dtype.itemsize + scale_bytes = config.kvscale_dtype.itemsize + block_bytes = ( + 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * + (head_dim * kv_bytes + scale_bytes) + ) + config.num_kvcache_blocks = available_bytes // block_bytes + assert config.num_kvcache_blocks > 0 + self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, num_kv_heads, config.num_kvcache_blocks, + self.block_size, head_dim, dtype=config.kvcache_quant_dtype) + self.kv_scale_cache = torch.empty(2,hf_config.num_hidden_layers, num_kv_heads, config.num_kvcache_blocks, + self.block_size, dtype=config.kvscale_dtype,) + num_slots = config.num_kvcache_blocks * self.block_size + layer_id = 0 + for module in self.model.modules(): + if (hasattr(module, "k_cache") + and hasattr(module, "v_cache")): + module.k_cache = self.kv_cache[0, layer_id].reshape(num_kv_heads, num_slots, head_dim) + module.v_cache = self.kv_cache[1, layer_id].reshape(num_kv_heads, num_slots, head_dim) + module.k_scale_cache = self.kv_scale_cache[0, layer_id].reshape(num_kv_heads, num_slots) + module.v_scale_cache = self.kv_scale_cache[1, layer_id].reshape(num_kv_heads, num_slots) + module.block_size = self.block_size + module.kv_quant = True + module.GROUP_NUM = config.group_num 
+ module.kv_store_stream = self.kv_store_stream + module.kv_store_event = self.kv_store_event + layer_id += 1 + else: + block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.torch_dtype.itemsize + config.num_kvcache_blocks = available_bytes // block_bytes + assert config.num_kvcache_blocks > 0 + self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, + self.block_size, num_kv_heads, head_dim, ) + layer_id = 0 + for module in self.model.modules(): + if (hasattr(module, "k_cache") + and hasattr(module, "v_cache")): + module.k_cache = self.kv_cache[0, layer_id] + module.v_cache = self.kv_cache[1, layer_id] + module.kv_quant = False + layer_id += 1 def prepare_block_tables(self, seqs: list[Sequence]): max_len = max(len(seq.block_table) for seq in seqs) @@ -131,42 +169,41 @@ def prepare_prefill(self, seqs: list[Sequence]): positions = [] cu_seqlens_q = [0] cu_seqlens_k = [0] + context_lens = [] max_seqlen_q = 0 max_seqlen_k = 0 slot_mapping = [] block_tables = None for seq in seqs: - start = seq.num_cached_tokens - seqlen_q = seq.num_scheduled_tokens - end = start + seqlen_q - seqlen_k = end - input_ids.extend(seq[start:end]) - positions.extend(range(start, end)) + seqlen = len(seq) + context_lens.append(seqlen) + input_ids.extend(seq[seq.num_cached_tokens:]) + positions.extend(list(range(seq.num_cached_tokens, seqlen))) + seqlen_q = seqlen - seq.num_cached_tokens + seqlen_k = seqlen cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q) cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k) max_seqlen_q = max(seqlen_q, max_seqlen_q) max_seqlen_k = max(seqlen_k, max_seqlen_k) - if not seq.block_table: # warmup + if not seq.block_table: continue - start_block = start // self.block_size - end_block = (end + self.block_size - 1) // self.block_size - for i in range(start_block, end_block): - slot_start = seq.block_table[i] * self.block_size - if i == start_block: - slot_start += start % self.block_size - if 
i != end_block - 1: - slot_end = seq.block_table[i] * self.block_size + self.block_size + for i in range(seq.num_cached_blocks, seq.num_blocks): + start = seq.block_table[i] * self.block_size + if i != seq.num_blocks - 1: + end = start + self.block_size else: - slot_end = seq.block_table[i] * self.block_size + end - i * self.block_size - slot_mapping.extend(range(slot_start, slot_end)) - if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # prefix cache + end = start + seq.last_block_num_tokens + slot_mapping.extend(list(range(start, end))) + if cu_seqlens_k[-1] > cu_seqlens_q[-1]: block_tables = self.prepare_block_tables(seqs) input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) + if self.config.kv_quant: + context_lens = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) - set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables) + set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables) return input_ids, positions def prepare_decode(self, seqs: list[Sequence]): @@ -179,21 +216,28 @@ def prepare_decode(self, seqs: list[Sequence]): positions.append(len(seq) - 1) context_lens.append(len(seq)) slot_mapping.append(seq.block_table[-1] * self.block_size + seq.last_block_num_tokens - 1) + max_context_len = max(context_lens) if context_lens else 0 input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) positions = torch.tensor(positions, dtype=torch.int64, 
pin_memory=True).cuda(non_blocking=True) slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) context_lens = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) block_tables = self.prepare_block_tables(seqs) - set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables) + set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, + block_tables=block_tables, max_context_len=max_context_len) return input_ids, positions def prepare_sample(self, seqs: list[Sequence]): - temperatures = [seq.temperature for seq in seqs] + temperatures = [] + for seq in seqs: + temperatures.append(seq.temperature) temperatures = torch.tensor(temperatures, dtype=torch.float32, pin_memory=True).cuda(non_blocking=True) return temperatures @torch.inference_mode() def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool): + if self.config.kv_quant and getattr(self, "has_pedding", True): + torch.cuda.current_stream().wait_event(self.kv_store_event) + self.has_pedding = False if is_prefill or self.enforce_eager or input_ids.size(0) > 512: return self.model.compute_logits(self.model(input_ids, positions)) else: @@ -215,6 +259,8 @@ def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]: input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs) temperatures = self.prepare_sample(seqs) if self.rank == 0 else None logits = self.run_model(input_ids, positions, is_prefill) + if self.config.kv_quant: + self.has_pedding = True token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None reset_context() return token_ids @@ -235,12 +281,16 @@ def capture_cudagraph(self): self.graphs = {} self.graph_pool = None + if self.config.kv_quant: + self.capture_quant(slot_mapping, context_lens, block_tables, outputs, input_ids, positions) + return + for bs in 
reversed(self.graph_bs): graph = torch.cuda.CUDAGraph() set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs]) - outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # warmup + outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) with torch.cuda.graph(graph, self.graph_pool): - outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # capture + outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) if self.graph_pool is None: self.graph_pool = graph.pool() self.graphs[bs] = graph @@ -255,3 +305,34 @@ def capture_cudagraph(self): block_tables=block_tables, outputs=outputs, ) + + def capture_quant(self, slot_mapping, context_lens, block_tables, outputs, input_ids, positions): + main_stream = torch.cuda.current_stream() + for bs in reversed(self.graph_bs): + graph = torch.cuda.CUDAGraph() + self.kv_store_stream.synchronize() + set_context( + False, + slot_mapping=slot_mapping[:bs], + context_lens=context_lens[:bs], + block_tables=block_tables[:bs], + ) + outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) + main_stream.wait_event(self.kv_store_event) + torch.cuda.synchronize() + with torch.cuda.graph(graph, self.graph_pool): + outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) + torch.cuda.current_stream().wait_event(self.kv_store_event) + if self.graph_pool is None: + self.graph_pool = graph.pool() + self.graphs[bs] = graph + torch.cuda.synchronize() + reset_context() + self.graph_vars = dict( + input_ids=input_ids, + positions=positions, + slot_mapping=slot_mapping, + context_lens=context_lens, + block_tables=block_tables, + outputs=outputs, + ) \ No newline at end of file diff --git a/nanovllm/engine/sequence.py b/nanovllm/engine/sequence.py index 4decfce5f..316888828 100644 --- a/nanovllm/engine/sequence.py +++ b/nanovllm/engine/sequence.py @@ -12,9 +12,13 @@ class SequenceStatus(Enum): class Sequence: - block_size = 256 + block_size: int = 0 # invalid value, will be set by 
set_block_size counter = count() + @classmethod + def set_block_size(cls, block_size: int): + cls.block_size = block_size + def __init__(self, token_ids: list[int], sampling_params = SamplingParams()): self.seq_id = next(Sequence.counter) self.status = SequenceStatus.WAITING @@ -23,12 +27,11 @@ def __init__(self, token_ids: list[int], sampling_params = SamplingParams()): self.num_tokens = len(self.token_ids) self.num_prompt_tokens = len(token_ids) self.num_cached_tokens = 0 - self.num_scheduled_tokens = 0 - self.is_prefill = True self.block_table = [] self.temperature = sampling_params.temperature self.max_tokens = sampling_params.max_tokens self.ignore_eos = sampling_params.ignore_eos + self.prefilled = False def __len__(self): return self.num_tokens @@ -52,6 +55,10 @@ def prompt_token_ids(self): def completion_token_ids(self): return self.token_ids[self.num_prompt_tokens:] + @property + def num_cached_blocks(self): + return self.num_cached_tokens // self.block_size + @property def num_blocks(self): return (self.num_tokens + self.block_size - 1) // self.block_size @@ -68,16 +75,15 @@ def append_token(self, token_id: int): self.token_ids.append(token_id) self.last_token = token_id self.num_tokens += 1 + self.prefilled = True def __getstate__(self): - last_state = self.last_token if not self.is_prefill else self.token_ids - return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.num_scheduled_tokens, self.block_table, last_state) + return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table, self.prefilled, + self.last_token if self.prefilled else self.token_ids) def __setstate__(self, state): - self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.num_scheduled_tokens, self.block_table, last_state = state - if isinstance(last_state, list): - self.token_ids = last_state - self.last_token = self.token_ids[-1] + self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table, 
self.prefilled = state[:-1] + if self.prefilled: + self.last_token = state[-1] else: - self.token_ids = [] - self.last_token = last_state + self.token_ids = state[-1] diff --git a/nanovllm/layers/attention.py b/nanovllm/layers/attention.py index e416139ea..cb07826bf 100644 --- a/nanovllm/layers/attention.py +++ b/nanovllm/layers/attention.py @@ -2,8 +2,9 @@ from torch import nn import triton import triton.language as tl - from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache +from nanovllm.tools.quant_attn_kvhead_based import decode_attn_quantkv_direct, \ + prefill_attn_quantkv_direct from nanovllm.utils.context import get_context @@ -40,6 +41,101 @@ def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, store_kvcache_kernel[(N,)](key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D) +@triton.jit +def _store_quantkv_kernel( + key_ptr, + key_stride_b, + key_stride_h, + key_stride_d, + value_ptr, + value_stride_b, + value_stride_h, + value_stride_d, + k_cache_ptr, + v_cache_ptr, + k_scale_ptr, + v_scale_ptr, + slot_mapping_ptr, + stride_cache_h, + stride_cache_s, + stride_cache_d, + stride_scale_h, + stride_scale_s, + EPS: tl.constexpr, + BLOCK_H: tl.constexpr, + BLOCK_D: tl.constexpr, + TARGET_DTYPE: tl.constexpr, +): + token_idx = tl.program_id(0) + slot = tl.load(slot_mapping_ptr + token_idx) + if slot == -1: + return + offs_head = tl.arange(0, BLOCK_H) + offs_dim = tl.arange(0, BLOCK_D) + # k + kv_offsets = token_idx * key_stride_b + offs_head[:, None] * key_stride_h + offs_dim[None, :] * key_stride_d + kv = tl.load(key_ptr + kv_offsets) + kv_abs_max = tl.max(tl.abs(kv), axis=1) + kv_scale = tl.maximum(kv_abs_max / 127.0, EPS).to(TARGET_DTYPE) + kv = tl.extra.cuda.libdevice.llrint(kv / kv_scale[:, None]) + kv = tl.maximum(tl.minimum(kv, 127), -127).to(tl.int8) + cache_offsets = offs_head[:, None] * stride_cache_h + slot * stride_cache_s + offs_dim[None, :] * stride_cache_d + tl.store(k_cache_ptr 
+ cache_offsets, kv) + scale_offsets = offs_head * stride_scale_h + slot * stride_scale_s + tl.store(k_scale_ptr + scale_offsets, kv_scale) + # v + kv_offsets = token_idx * value_stride_b + offs_head[:, None] * value_stride_h + offs_dim[None, :] * value_stride_d + kv = tl.load(value_ptr + kv_offsets) + kv_abs_max = tl.max(tl.abs(kv), axis=1) + kv_scale = tl.maximum(kv_abs_max/127.0, EPS).to(TARGET_DTYPE) + kv = tl.extra.cuda.libdevice.llrint(kv / kv_scale[:, None]) + kv = tl.maximum(tl.minimum(kv, 127), -127).to(tl.int8) + tl.store(v_cache_ptr + cache_offsets, kv) + tl.store(v_scale_ptr + scale_offsets, kv_scale) + +def store_quantkv( + key: torch.Tensor, + value: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + k_scale_cache: torch.Tensor, + v_scale_cache: torch.Tensor, + slot_mapping: torch.Tensor, +): + n, num_kv_heads, head_dim = key.shape + if key.dtype == torch.float16: + target_dtype = tl.float16 + elif key.dtype == torch.bfloat16: + target_dtype = tl.bfloat16 + else: + target_dtype = tl.float32 + _store_quantkv_kernel[(n,)]( + key, + key.stride(0), + key.stride(1), + key.stride(2), + value, + value.stride(0), + value.stride(1), + value.stride(2), + k_cache, + v_cache, + k_scale_cache, + v_scale_cache, + slot_mapping, + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_scale_cache.stride(0), + k_scale_cache.stride(1), + EPS=1e-8, + BLOCK_H=num_kv_heads, + BLOCK_D=head_dim, + TARGET_DTYPE=target_dtype, + num_warps = 1, + num_stages = 1 + ) + class Attention(nn.Module): def __init__( @@ -55,21 +151,72 @@ def __init__( self.scale = scale self.num_kv_heads = num_kv_heads self.k_cache = self.v_cache = torch.tensor([]) + self.k_scale_cache = self.v_scale_cache = torch.tensor([]) + self.block_size = 256 + self.kv_quant = True + self.GROUP_NUM = 2 + self.kv_store_stream = None + self.kv_store_event = None def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): + if self.kv_quant: + o = self._forward_quantized(q,k,v) + 
else: + o = self._forward_unquantized(q,k,v) + return o + + def _forward_unquantized(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): context = get_context() k_cache, v_cache = self.k_cache, self.v_cache if k_cache.numel() and v_cache.numel(): store_kvcache(k, v, k_cache, v_cache, context.slot_mapping) if context.is_prefill: - if context.block_tables is not None: # prefix cache + if context.block_tables is not None: k, v = k_cache, v_cache o = flash_attn_varlen_func(q, k, v, max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q, max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k, softmax_scale=self.scale, causal=True, block_table=context.block_tables) - else: # decode + else: o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache, - cache_seqlens=context.context_lens, block_table=context.block_tables, + cache_seqlens=context.context_lens, block_table=context.block_tables, softmax_scale=self.scale, causal=True) return o + + def _forward_quantized(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): + context = get_context() + k_cache, v_cache = self.k_cache, self.v_cache + k_scale_cache, v_scale_cache = self.k_scale_cache, self.v_scale_cache + if context.is_prefill: + if context.block_tables is not None: + with torch.cuda.nvtx.range("PreFlashQuantDir"): + o = prefill_attn_quantkv_direct(q, k, v, k_cache, v_cache, k_scale_cache, v_scale_cache, + max_seqlen_q=context.max_seqlen_q, + cu_seqlens_q=context.cu_seqlens_q, + context_lens=context.context_lens, + block_table=context.block_tables, block_size=self.block_size, + GROUP_NUM=self.GROUP_NUM, + softmax_scale=self.scale, causal=True) + else: + with torch.cuda.nvtx.range("PreFlashOri"): + o = flash_attn_varlen_func(q, k, v, + max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q, + max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k, + softmax_scale=self.scale, causal=True, block_table=context.block_tables) + else: # decode + with 
torch.cuda.nvtx.range("DeFlashQuantDir"): + o = decode_attn_quantkv_direct(q, k, v, k_cache, v_cache, k_scale_cache, v_scale_cache, + context_lens=context.context_lens, block_table=context.block_tables, + block_size=self.block_size, + GROUP_NUM=self.GROUP_NUM, + softmax_scale=self.scale, ) + if k_cache.numel() and v_cache.numel(): + if torch.cuda.is_current_stream_capturing(): + self.kv_store_stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(self.kv_store_stream): + with torch.cuda.nvtx.range("DeStore"): + store_quantkv(k, v, k_cache, v_cache, k_scale_cache, v_scale_cache, context.slot_mapping, ) + self.kv_store_event.record(self.kv_store_stream) + return o + + diff --git a/nanovllm/tools/quant_attn_kvhead_based.py b/nanovllm/tools/quant_attn_kvhead_based.py new file mode 100644 index 000000000..5a3d6114b --- /dev/null +++ b/nanovllm/tools/quant_attn_kvhead_based.py @@ -0,0 +1,415 @@ +import torch +import triton +import triton.language as tl + +decode_num_warps = 1 +decode_num_stages = 1 +prefill_num_warps = 4 +prefill_num_stages = 1 +prefill_num_tileQ = 32 +TILE_KV_PREFILL = 32 +TILE_KV_DECODE = 32 +assert prefill_num_tileQ == TILE_KV_PREFILL +NUM_KV_TILES_PER_BLOCK_PREFILL = 256 // TILE_KV_PREFILL # 256/32 = 8 +NUM_KV_TILES_PER_BLOCK_DECODE = 256 // TILE_KV_DECODE # 256/32 = 8 + +@triton.jit +def _decode_quant_direct_kernel( + q_ptr, # [seqs_num, q_head_num, head_dim] + new_k_ptr, # [seqs_num, kv_head_num, head_dim] + new_v_ptr, # [seqs_num, kv_head_num, head_dim] + stride_new_b, + stride_new_h, + stride_new_d, + k_cache_ptr, # [kv_head_num, num_slots, head_dim] + v_cache_ptr, # [kv_head_num, num_slots, head_dim] + k_scale_ptr, # [kv_head_num, num_slots] + v_scale_ptr, # [kv_head_num, num_slots] + context_lens_ptr, # [seqs_num] + block_table_ptr, + out_ptr, # [seqs_num, q_head_num, head_dim] + stride_q_b, # q_head_num * head_dim + stride_q_h, # head_dim + stride_q_d, # 1 + stride_cache_h, # num_slots * head_dim + stride_cache_s, # head_dim 
+ stride_cache_d, # 1 + stride_scale_h, # num_slots + stride_scale_s, # 1 + stride_bt_b, # max_block_table_len + stride_bt_blk, # 1 + stride_out_b, # = stride_q_b + stride_out_h, # head_dim + stride_out_d, # 1 + softmax_scale, # 1/sqrt(dim) + block_size, + TILE_KV: tl.constexpr, # tile_K + BLOCK_DIM_MODEL: tl.constexpr, # head_dim + TARGET_DTYPE: tl.constexpr, + NUM_KV_TILES_PER_BLOCK: tl.constexpr, + GROUP_NUM: tl.constexpr, # GQA +): + batch_idx = tl.program_id(0) + kv_head_idx = tl.program_id(1) + q_head_start = kv_head_idx * GROUP_NUM + q_head_offsets = q_head_start + tl.arange(0, GROUP_NUM) + ctx_len = tl.load(context_lens_ptr + batch_idx) + ctx_len -= 1 + num_total_tiles = tl.cdiv(ctx_len, TILE_KV) + if num_total_tiles <= 0 or ctx_len <= 0: + return + offs_dim = tl.arange(0, BLOCK_DIM_MODEL) + offs_kv = tl.arange(0, TILE_KV) + q_ptrs = q_ptr + batch_idx * stride_q_b + q_head_offsets[:, None] * stride_q_h + offs_dim[None, :] * stride_q_d + # [GROUP_NUM, head_dim] + q = tl.load(q_ptrs) + m_i = tl.full((GROUP_NUM,), float("-inf"), tl.float32) + li = tl.full((GROUP_NUM,), 0.0, tl.float32) + oi = tl.zeros((GROUP_NUM, BLOCK_DIM_MODEL), dtype=tl.float32) + curr_kv_block_idx = 0 + physical_block_idx = tl.load(block_table_ptr + batch_idx * stride_bt_b + curr_kv_block_idx * stride_bt_blk) + for tile_idx in tl.range(0, num_total_tiles - 1): + kv_block_idx = tile_idx // NUM_KV_TILES_PER_BLOCK + kv_tile_idx = tile_idx % NUM_KV_TILES_PER_BLOCK + if kv_block_idx != curr_kv_block_idx: + curr_kv_block_idx = kv_block_idx + physical_block_idx = tl.load(block_table_ptr + batch_idx * stride_bt_b + curr_kv_block_idx * stride_bt_blk) + logic_in_block_offs = kv_tile_idx * TILE_KV + token_in_block = logic_in_block_offs + offs_kv + slot = physical_block_idx * block_size + token_in_block + scale_offs = kv_head_idx * stride_scale_h + slot * stride_scale_s + cache_offs = kv_head_idx * stride_cache_h + slot[:, None] * stride_cache_s + offs_dim[None, :] * stride_cache_d + kvj = 
tl.load(k_cache_ptr + cache_offs).to(TARGET_DTYPE) + kv_scale = tl.load(k_scale_ptr + scale_offs).to(TARGET_DTYPE) + kvj = kvj * kv_scale[:, None] + scores = tl.sum(q[:, None, :] * kvj[None, :, :], axis=2) * softmax_scale + m_ij = tl.max(scores, axis=1) # [GROUP_NUM] + mj_new = tl.maximum(m_i, m_ij) # [GROUP_NUM] + alpha = tl.exp(m_i - mj_new) # [GROUP_NUM] + scores = tl.exp(scores - mj_new[:, None]) # [GROUP_NUM, tileK] + li = alpha * li + tl.sum(scores, axis=1) # [GROUP_NUM] + kvj = tl.load(v_cache_ptr + cache_offs).to(TARGET_DTYPE) + kv_scale = tl.load(v_scale_ptr + scale_offs).to(TARGET_DTYPE) + kvj = kvj * kv_scale[:, None] + # [GROUP_NUM, head_dim] + oij = tl.sum(scores[:, :, None] * kvj[None, :, :], axis=1) + oi = oi * alpha[:, None] + oij + m_i = mj_new + last_tile_idx = num_total_tiles - 1 + last_kv_block_idx = last_tile_idx // NUM_KV_TILES_PER_BLOCK + last_kv_tile_idx = last_tile_idx % NUM_KV_TILES_PER_BLOCK + if last_kv_block_idx != curr_kv_block_idx: + physical_block_idx = tl.load(block_table_ptr + batch_idx * stride_bt_b + last_kv_block_idx * stride_bt_blk) + logic_block_offs = last_kv_block_idx * block_size + logic_in_block_offs = last_kv_tile_idx * TILE_KV + token_in_block = logic_in_block_offs + offs_kv + logical_token_idx = logic_block_offs + token_in_block + valid = logical_token_idx < ctx_len + slot = physical_block_idx * block_size + token_in_block + scale_offs = kv_head_idx * stride_scale_h + slot * stride_scale_s + cache_offs = kv_head_idx * stride_cache_h + slot[:, None] * stride_cache_s + offs_dim[None, :] * stride_cache_d + kvj = tl.load(k_cache_ptr + cache_offs, mask=valid[:, None], other=0).to(TARGET_DTYPE) + kv_scale = tl.load(k_scale_ptr + scale_offs, mask=valid, other=0.0).to(TARGET_DTYPE) + kvj = kvj * kv_scale[:, None] + scores = tl.sum(q[:, None, :] * kvj[None, :, :], axis=2) * softmax_scale + scores = tl.where(valid[None, :], scores, float("-inf")) + m_ij = tl.max(scores, axis=1) + mj_new = tl.maximum(m_i, m_ij) + alpha = 
tl.exp(m_i - mj_new) + p = tl.exp(scores - mj_new[:, None]) + li = alpha * li + tl.sum(p, axis=1) + kvj = tl.load(v_cache_ptr + cache_offs, mask=valid[:, None], other=0).to(TARGET_DTYPE) + kv_scale = tl.load(v_scale_ptr + scale_offs, mask=valid, other=0.0).to(TARGET_DTYPE) + kvj = kvj * kv_scale[:, None] + oij = tl.sum(p[:, :, None] * kvj[None, :, :], axis=1) + oi = oi * alpha[:, None] + oij + m_i = mj_new + cache_offs = batch_idx * stride_new_b + kv_head_idx * stride_new_h + offs_dim * stride_new_d + new_kv = tl.load(new_k_ptr + cache_offs) + scores = tl.sum(q * new_kv[None, :], axis=1) + scores = scores * softmax_scale + mj_new = tl.maximum(m_i, scores) + alpha = tl.exp(m_i - mj_new) + p = tl.exp(scores - mj_new) + li = alpha * li + p + new_kv = tl.load(new_v_ptr + cache_offs) + oi = oi * alpha[:, None] + p[:, None] * new_kv[None, :] + oi = oi / li[:, None] + # [GROUP_NUM, head_dim] + out_ptrs = out_ptr + batch_idx * stride_out_b + q_head_offsets[:, None] * stride_out_h + offs_dim[None, :] * stride_out_d + tl.store(out_ptrs, oi.to(TARGET_DTYPE)) + +def decode_attn_quantkv_direct( + q: torch.Tensor, + new_k: torch.Tensor, + new_v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + context_lens: torch.Tensor, + block_table: torch.Tensor, + block_size: int, + GROUP_NUM: int, + softmax_scale: float | None = None, +): + batch, num_q_heads, head_dim = q.shape + num_kv_heads, num_slots, _ = k_cache.shape + if softmax_scale is None: + softmax_scale = head_dim ** -0.5 + out = torch.empty_like(q) + if q.dtype == torch.float16: + target_dtype = tl.float16 + elif q.dtype == torch.bfloat16: + target_dtype = tl.bfloat16 + else: + target_dtype = tl.float32 + + _decode_quant_direct_kernel[(batch, num_kv_heads, )]( + q, + new_k, + new_v, + new_k.stride(0), + new_k.stride(1), + new_k.stride(2), + k_cache, + v_cache, + k_scale, + v_scale, + context_lens, + block_table, + out, + q.stride(0), + q.stride(1), + 
@triton.jit
def _prefill_quant_direct_kernel(
    q_ptr,  # [total_q, q_head_num, head_dim] without prefix part
    k_ptr,  # [total_k, kv_head_num, head_dim] total_q == total_k == total_v
    v_ptr,  # [total_v, kv_head_num, head_dim]
    stride_kv_b,  # kv_head_num * head_dim
    stride_kv_h,  # head_dim
    stride_kv_d,  # 1
    k_cache_ptr,  # [kv_head_num, num_slots, head_dim]
    v_cache_ptr,  # [kv_head_num, num_slots, head_dim]
    k_scale_ptr,  # [kv_head_num, num_slots]
    v_scale_ptr,  # [kv_head_num, num_slots]
    cu_seqlens_q_ptr,  # [batch + 1]
    context_lens_ptr,  # [batch]
    block_table_ptr,  # [batch, max_num_blocks_per_seq]
    out_ptr,  # [total_q, q_head_num, head_dim]
    stride_q_b,  # q_head_num * head_dim
    stride_q_h,  # head_dim
    stride_q_d,  # 1
    stride_cache_h,  # num_slots * head_dim
    stride_cache_s,  # head_dim
    stride_cache_d,  # 1
    stride_scale_h,  # num_slots
    stride_scale_s,  # 1
    stride_bt_b,  # max_block_table_len
    stride_bt_blk,  # 1
    stride_out_b,  # = stride_q_b
    stride_out_h,  # head_dim
    stride_out_d,  # 1
    softmax_scale,  # = 1/sqrt(dim)
    block_size,
    TILE_Q: tl.constexpr,
    TILE_KV: tl.constexpr,
    TILE_DIM_MODEL: tl.constexpr,  # head_dim
    TARGET_DTYPE: tl.constexpr,  # bfloat16
    NUM_KV_TILES_PER_BLOCK: tl.constexpr,
    GROUP_NUM: tl.constexpr,  # GQA
):
    """Causal prefill attention over a quantized paged prefix plus the new K/V.

    Grid: axis 0 = query tile, axis 1 = KV head, axis 2 = sequence in the batch.
    Each program handles TILE_Q query positions for all GROUP_NUM query heads
    that share one KV head, flattened into GROUP_NUM * TILE_Q rows.

    Keys/values come from two sources, combined with a running (online)
    softmax held in m_i / l_i / oi:
      1. the first ``prefix_len = context_len - q_len`` tokens are read from
         k_cache / v_cache through the block table and multiplied by their
         per-token scale;
      2. the ``q_len`` new tokens are read directly from k_ptr / v_ptr with a
         causal mask.

    NOTE(review): the prefix loop maps tile_idx to (block, tile-in-block) via
    NUM_KV_TILES_PER_BLOCK, which presumably requires
    block_size == NUM_KV_TILES_PER_BLOCK * TILE_KV; confirm against the
    module constants.
    """
    batch_idx = tl.program_id(2)
    kv_head_idx = tl.program_id(1)
    q_tile_idx = tl.program_id(0)

    q_start = tl.load(cu_seqlens_q_ptr + batch_idx)
    q_end = tl.load(cu_seqlens_q_ptr + batch_idx + 1)
    q_len = q_end - q_start
    k_len = tl.load(context_lens_ptr + batch_idx)
    prefix_len = k_len - q_len  # tokens already resident in the paged cache

    q_tile_start = q_tile_idx * TILE_Q
    # The grid is sized for the longest sequence; shorter ones exit early.
    if q_tile_start >= q_len or q_len <= 0:
        return

    offs_dim = tl.arange(0, TILE_DIM_MODEL)
    offs_kv = tl.arange(0, TILE_KV)
    # Rows are laid out group-major: row = group * TILE_Q + position-in-tile.
    offs_rows = tl.arange(0, GROUP_NUM * TILE_Q)
    offs_group = offs_rows // TILE_Q
    offs_q_in_tile = offs_rows % TILE_Q
    logic_offs_q = q_tile_start + offs_q_in_tile  # [GROUP_NUM * TILE_Q]
    q_valid = logic_offs_q < q_len  # [GROUP_NUM * TILE_Q]
    logic_offs_q_in_total = q_start + logic_offs_q
    q_head_offsets = kv_head_idx * GROUP_NUM + offs_group

    q_ptrs = (
        q_ptr
        + logic_offs_q_in_total[:, None] * stride_q_b
        + q_head_offsets[:, None] * stride_q_h
        + offs_dim[None, :] * stride_q_d
    )
    q = tl.load(q_ptrs, mask=q_valid[:, None], other=0.0)  # [GROUP_NUM * TILE_Q, head_dim]

    # Padding rows start at 0 instead of -inf so exp(m_i - m_new) stays finite.
    m_i = tl.where(q_valid, float("-inf"), 0.0).to(tl.float32)
    l_i = tl.zeros((GROUP_NUM * TILE_Q,), dtype=tl.float32)
    oi = tl.zeros((GROUP_NUM * TILE_Q, TILE_DIM_MODEL), dtype=tl.float32)

    # ---- Phase 1: prefix K/V from the head-major quantized cache ----
    num_prefix_tiles = tl.cdiv(prefix_len, TILE_KV)
    for tile_idx in tl.range(0, num_prefix_tiles):
        kv_block_idx = tile_idx // NUM_KV_TILES_PER_BLOCK
        kv_tile_idx = tile_idx % NUM_KV_TILES_PER_BLOCK
        physical_block_idx = tl.load(
            block_table_ptr + batch_idx * stride_bt_b + kv_block_idx * stride_bt_blk
        )
        logic_block_offs = kv_block_idx * block_size
        logic_in_block_offs = kv_tile_idx * TILE_KV
        token_in_block = logic_in_block_offs + offs_kv
        logical_k = logic_block_offs + token_in_block
        kv_valid = logical_k < prefix_len
        slot = physical_block_idx * block_size + token_in_block

        # No causal mask needed: every prefix key precedes every new query.
        attn_mask = q_valid[:, None] & kv_valid[None, :]

        scale_offs = kv_head_idx * stride_scale_h + slot * stride_scale_s
        cache_offs = (
            kv_head_idx * stride_cache_h
            + slot[:, None] * stride_cache_s
            + offs_dim[None, :] * stride_cache_d
        )

        kj = tl.load(k_cache_ptr + cache_offs, mask=kv_valid[:, None], other=0).to(TARGET_DTYPE)
        k_scale = tl.load(k_scale_ptr + scale_offs, mask=kv_valid, other=0.0).to(TARGET_DTYPE)
        # kj is [TILE_KV, head_dim] and k_scale is [TILE_KV] (one scale per
        # token), so the scale must broadcast along head_dim. The previous
        # `k_scale[None, :]` scaled along the wrong axis.
        kj = kj * k_scale[:, None]

        # [GROUP_NUM * TILE_Q, TILE_KV]
        scores = tl.dot(q, tl.trans(kj)) * softmax_scale
        scores = tl.where(attn_mask, scores, float("-inf"))

        # Online-softmax update; all of these are [GROUP_NUM * TILE_Q].
        m_ij = tl.max(scores, axis=1)
        m_new = tl.maximum(m_i, m_ij)
        alpha = tl.exp(m_i - m_new)
        p = tl.where(attn_mask, tl.exp(scores - m_new[:, None]), 0.0).to(TARGET_DTYPE)
        l_i = l_i * alpha + tl.sum(p, axis=1)

        vj = tl.load(v_cache_ptr + cache_offs, mask=kv_valid[:, None], other=0).to(TARGET_DTYPE)
        v_scale = tl.load(v_scale_ptr + scale_offs, mask=kv_valid, other=0.0).to(TARGET_DTYPE)
        vj = vj * v_scale[:, None]

        # [GROUP_NUM * TILE_Q, head_dim]
        oi = oi * alpha[:, None] + tl.dot(p, vj)
        m_i = m_new

    # ---- Phase 2: new K/V straight from k_ptr / v_ptr, with causal mask ----
    num_post_tiles = tl.cdiv(q_len, TILE_KV)
    for tile_idx in tl.range(0, num_post_tiles):
        post_logic_k = tile_idx * TILE_KV + offs_kv
        post_valid = post_logic_k < q_len
        # A query at in-sequence position i may attend to new keys 0..i.
        attn_mask = (
            q_valid[:, None]
            & post_valid[None, :]
            & (post_logic_k[None, :] <= logic_offs_q[:, None])
        )
        postfix_global_k = q_start + post_logic_k
        cache_offs = (
            postfix_global_k[:, None] * stride_kv_b
            + kv_head_idx * stride_kv_h
            + offs_dim[None, :] * stride_kv_d
        )
        kj = tl.load(k_ptr + cache_offs, mask=post_valid[:, None], other=0.0)

        # [GROUP_NUM * TILE_Q, TILE_KV]
        scores = tl.dot(q, tl.trans(kj))
        scores = scores * softmax_scale
        scores = tl.where(attn_mask, scores, float("-inf"))

        # Online-softmax update; all of these are [GROUP_NUM * TILE_Q].
        m_ij = tl.max(scores, axis=1)
        m_new = tl.maximum(m_i, m_ij)
        alpha = tl.exp(m_i - m_new)
        p = tl.where(attn_mask, tl.exp(scores - m_new[:, None]), 0.0).to(TARGET_DTYPE)
        l_i = l_i * alpha + tl.sum(p, axis=1)

        vj = tl.load(v_ptr + cache_offs, mask=post_valid[:, None], other=0.0)
        oi = oi * alpha[:, None] + tl.dot(p, vj)
        m_i = m_new

    # Final normalisation. Padding rows divide 0 by 0 but are never stored.
    oi = oi / l_i[:, None]

    # [GROUP_NUM * TILE_Q, head_dim]
    out_ptrs = (
        out_ptr
        + logic_offs_q_in_total[:, None] * stride_out_b
        + q_head_offsets[:, None] * stride_out_h
        + offs_dim[None, :] * stride_out_d
    )
    tl.store(out_ptrs, oi.to(TARGET_DTYPE), mask=q_valid[:, None])
logic_offs_q_in_total[:, None] * stride_out_b + q_head_offsets[:, None] * stride_out_h + offs_dim[None, :] * stride_out_d + tl.store(out_ptrs, oi.to(TARGET_DTYPE), mask=q_valid[:, None]) + +def prefill_attn_quantkv_direct( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + max_seqlen_q: int, + cu_seqlens_q: torch.Tensor, + context_lens: torch.Tensor, + block_table: torch.Tensor, + block_size: int, + GROUP_NUM: int, + softmax_scale: float | None = None, + causal: bool = True, +): + assert causal == True + total_q, num_q_heads, head_dim = q.shape + num_kv_heads, num_slots, _ = k_cache.shape + batch = cu_seqlens_q.numel() - 1 + if softmax_scale is None: + softmax_scale = head_dim ** -0.5 + out = torch.empty_like(q) + if q.dtype == torch.float16: + target_dtype = tl.float16 + elif q.dtype == torch.bfloat16: + target_dtype = tl.bfloat16 + else: + target_dtype = tl.float32 + + if k_cache.numel(): + stride_cache_h = k_cache.stride(0) + stride_cache_s = k_cache.stride(1) + stride_cache_d = k_cache.stride(2) + stride_scale_h = k_scale.stride(0) + stride_scale_s = k_scale.stride(1) + else: + # warmup + stride_cache_h = 0 + stride_cache_s = 0 + stride_cache_d = 0 + stride_scale_h = 0 + stride_scale_s = 0 + grid = (triton.cdiv(max_seqlen_q, prefill_num_tileQ), num_kv_heads, batch ) + _prefill_quant_direct_kernel[grid]( + q, + k, + v, + k.stride(0), + k.stride(1), + k.stride(2), + k_cache, + v_cache, + k_scale, + v_scale, + cu_seqlens_q, + context_lens, + block_table, + out, + q.stride(0), + q.stride(1), + q.stride(2), + stride_cache_h, + stride_cache_s, + stride_cache_d, + stride_scale_h, + stride_scale_s, + block_table.stride(0), + block_table.stride(1), + out.stride(0), + out.stride(1), + out.stride(2), + softmax_scale, + block_size, + TILE_Q=prefill_num_tileQ, + TILE_KV=TILE_KV_PREFILL, + TILE_DIM_MODEL=head_dim, + TARGET_DTYPE=target_dtype, + 
NUM_KV_TILES_PER_BLOCK=NUM_KV_TILES_PER_BLOCK_PREFILL, + GROUP_NUM=GROUP_NUM, + num_warps=prefill_num_warps, + num_stages=prefill_num_stages, + ) diff --git a/nanovllm/utils/context.py b/nanovllm/utils/context.py index 3b02a1d5d..e5881f536 100644 --- a/nanovllm/utils/context.py +++ b/nanovllm/utils/context.py @@ -12,15 +12,18 @@ class Context: slot_mapping: torch.Tensor | None = None context_lens: torch.Tensor | None = None block_tables: torch.Tensor | None = None + max_context_len: int = 0 _CONTEXT = Context() def get_context(): return _CONTEXT -def set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0, max_seqlen_k=0, slot_mapping=None, context_lens=None, block_tables=None): +def set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0, max_seqlen_k=0, slot_mapping=None, + context_lens=None, block_tables=None, max_context_len=0): global _CONTEXT - _CONTEXT = Context(is_prefill, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables) + _CONTEXT = Context(is_prefill, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, + block_tables, max_context_len) def reset_context(): global _CONTEXT diff --git a/pyproject.toml b/pyproject.toml index dc1399a10..d912c47e8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,13 +3,13 @@ requires = ["setuptools>=61"] build-backend = "setuptools.build_meta" [project] -name = "nano-vllm" +name = "nano-vllm-kv-compression" version = "0.2.0" -authors = [{ name = "Xingkai Yu" }] +authors = [{ name = "Mingsong Jiang" }] license = "MIT" license-files = ["LICENSE"] readme = "README.md" -description = "a lightweight vLLM implementation built from scratch" +description = "An improved implementation based on Nano-vLLM featuring int8 KV cache compression, head-major memory layout for coalesced access, and asynchronous stream pipelining that hides KV store latency behind attention computation." 
requires-python = ">=3.10,<3.13" dependencies = [ "torch>=2.4.0", @@ -20,7 +20,7 @@ dependencies = [ ] [project.urls] -Homepage="https://github.com/GeeeekExplorer/nano-vllm" +Homepage="https://github.com/naalo2/nano-vLLM-kv-compression" [tool.setuptools.packages.find] where = ["."]