Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 12 additions & 18 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,21 @@
<img width="300" src="assets/logo.png">
</p>

<p align="center">
<a href="https://trendshift.io/repositories/15323" target="_blank"><img src="https://trendshift.io/api/badge/repositories/15323" alt="GeeeekExplorer%2Fnano-vllm | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
</p>

# Nano-vLLM

A lightweight vLLM implementation built from scratch.
# Nano-vLLM-kv-compression

An improved implementation based on Nano-vLLM featuring int8 KV cache compression, head-major memory layout for coalesced access, and asynchronous stream pipelining that hides KV store latency behind attention computation.
## Key Features

* 🚀 **Fast offline inference** - Comparable inference speeds to vLLM
* 📖 **Readable codebase** - Clean implementation in ~ 1,200 lines of Python code
* ⚡ **Optimization Suite** - Prefix caching, Tensor Parallelism, Torch compilation, CUDA graph, etc.
* ⚡ **Int8 KV Cache Compression** — 50% memory reduction via dynamic per-head quantization
* 🔄 **Coalesced Layout** — Head-major reordering for warp-level memory coalescing
* 🎯 **GQA-Optimized Flash Attention** — group Q-head CTA mapping eliminates redundant KV loads
* 🔗 **Async KV Store Pipeline** — Multi-stream architecture overlaps KV quantization and cache writeback with attention computation

## Installation

```bash
pip install git+https://github.com/GeeeekExplorer/nano-vllm.git
pip install git+https://github.com/naalo2/nano-vLLM-kv-compression.git
```

## Model Download
Expand Down Expand Up @@ -48,19 +45,16 @@ outputs[0]["text"]
See `bench.py` for benchmark.

**Test Configuration:**
- Hardware: RTX 4070 Laptop (8GB)
- Hardware: RTX 3090 (24GB)
- Model: Qwen3-0.6B
- Total Requests: 256 sequences
- Input Length: Randomly sampled between 100–1024 tokens
- Output Length: Randomly sampled between 100–1024 tokens

**Performance Results:**
| Inference Engine | Output Tokens | Time (s) | Throughput (tokens/s) |
|----------------|-------------|----------|-----------------------|
| vLLM | 133,966 | 98.37 | 1361.84 |
| Nano-vLLM | 133,966 | 93.41 | 1434.13 |

| Inference Engine | Output Tokens | Time (s) | Throughput (tokens/s) |
|---------------------------|---------------|----------|-----------------------|
| Nano-vLLM | 133,966 | 33.05 | 4052.56 |
| Nano-vLLM-kv-compression | 133,966 | 27.00 | 4962.21 |

## Star History

[![Star History Chart](https://api.star-history.com/svg?repos=GeeeekExplorer/nano-vllm&type=Date)](https://www.star-history.com/#GeeeekExplorer/nano-vllm&Date)
Binary file modified assets/logo.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
14 changes: 10 additions & 4 deletions bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import time
from random import randint, seed
from nanovllm import LLM, SamplingParams
from nanovllm.config import Config

# from vllm import LLM, SamplingParams


Expand All @@ -10,18 +12,22 @@ def main():
num_seqs = 256
max_input_len = 1024
max_ouput_len = 1024

path = os.path.expanduser("~/huggingface/Qwen3-0.6B/")
llm = LLM(path, enforce_eager=False, max_model_len=4096)
enforce_eager = False
print(f"use cuda graph:{not enforce_eager}")
path = os.path.expanduser("/YOUR/MODEL/PATH")
llm = LLM(path, enforce_eager=enforce_eager, max_model_len=max(4096, max_ouput_len+max_input_len))

prompt_token_ids = [[randint(0, 10000) for _ in range(randint(100, max_input_len))] for _ in range(num_seqs)]
sampling_params = [SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=randint(100, max_ouput_len)) for _ in range(num_seqs)]
# prompt_token_ids = [[randint(0, 10000) for _ in range(max_input_len)] for _ in range(num_seqs)]
# sampling_params = [SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=max_ouput_len) for _ in range(num_seqs)]

# uncomment the following line for vllm
# prompt_token_ids = [dict(prompt_token_ids=p) for p in prompt_token_ids]

llm.generate(["Benchmark: "], SamplingParams())
t = time.time()
llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
llm.generate(prompt_token_ids, sampling_params, use_tqdm=True)
t = (time.time() - t)
total_tokens = sum(sp.max_tokens for sp in sampling_params)
throughput = total_tokens / t
Expand Down
6 changes: 6 additions & 0 deletions nanovllm/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import torch
from dataclasses import dataclass
from transformers import AutoConfig

Expand All @@ -16,10 +17,15 @@ class Config:
eos: int = -1
kvcache_block_size: int = 256
num_kvcache_blocks: int = -1
kv_quant: bool = True
kvcache_quant_dtype = torch.int8
kvscale_dtype = torch.bfloat16
group_num: int = 2

def __post_init__(self):
assert os.path.isdir(self.model)
assert self.kvcache_block_size % 256 == 0
assert 1 <= self.tensor_parallel_size <= 8
self.hf_config = AutoConfig.from_pretrained(self.model)
self.max_model_len = min(self.max_model_len, self.hf_config.max_position_embeddings)
assert self.max_num_batched_tokens >= self.max_model_len
159 changes: 120 additions & 39 deletions nanovllm/engine/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class ModelRunner:
def __init__(self, config: Config, rank: int, event: Event | list[Event]):
self.config = config
hf_config = config.hf_config
Sequence.set_block_size(config.kvcache_block_size)
self.block_size = config.kvcache_block_size
self.enforce_eager = config.enforce_eager
self.world_size = config.tensor_parallel_size
Expand All @@ -26,12 +27,18 @@ def __init__(self, config: Config, rank: int, event: Event | list[Event]):
dist.init_process_group("nccl", "tcp://localhost:2333", world_size=self.world_size, rank=rank)
torch.cuda.set_device(rank)
default_dtype = torch.get_default_dtype()
torch.set_default_dtype(hf_config.dtype)
torch.set_default_dtype(hf_config.torch_dtype)
torch.set_default_device("cuda")
self.model = Qwen3ForCausalLM(hf_config)
load_model(self.model, config.model)
self.sampler = Sampler()
if self.config.kv_quant:
self.kv_store_stream = torch.cuda.Stream()
self.kv_store_event = torch.cuda.Event()
self.has_pedding = False
self.warmup_model()
if self.config.kv_quant:
self.has_pedding = False
self.allocate_kv_cache()
if not self.enforce_eager:
self.capture_cudagraph()
Expand All @@ -53,6 +60,7 @@ def exit(self):
dist.barrier()
if self.rank == 0:
self.shm.unlink()

if not self.enforce_eager:
del self.graphs, self.graph_pool
torch.cuda.synchronize()
Expand Down Expand Up @@ -92,11 +100,8 @@ def warmup_model(self):
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
max_num_batched_tokens, max_model_len = self.config.max_num_batched_tokens, self.config.max_model_len
seq_len = min(max_num_batched_tokens, max_model_len)
num_seqs = min(max_num_batched_tokens // seq_len, self.config.max_num_seqs)
seqs = [Sequence([0] * seq_len) for _ in range(num_seqs)]
for seq in seqs:
seq.num_scheduled_tokens = seq_len
num_seqs = min(max_num_batched_tokens // max_model_len, self.config.max_num_seqs)
seqs = [Sequence([0] * max_model_len) for _ in range(num_seqs)]
self.run(seqs, True)
torch.cuda.empty_cache()

Expand All @@ -109,16 +114,49 @@ def allocate_kv_cache(self):
current = torch.cuda.memory_stats()["allocated_bytes.all.current"]
num_kv_heads = hf_config.num_key_value_heads // self.world_size
head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.dtype.itemsize
config.num_kvcache_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes
assert config.num_kvcache_blocks > 0
self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, head_dim)
layer_id = 0
for module in self.model.modules():
if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
module.k_cache = self.kv_cache[0, layer_id]
module.v_cache = self.kv_cache[1, layer_id]
layer_id += 1
available_bytes = int(total * config.gpu_memory_utilization - used - peak + current)
if config.kv_quant:
kv_bytes = config.kvcache_quant_dtype.itemsize
scale_bytes = config.kvscale_dtype.itemsize
block_bytes = (
2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads *
(head_dim * kv_bytes + scale_bytes)
)
config.num_kvcache_blocks = available_bytes // block_bytes
assert config.num_kvcache_blocks > 0
self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, num_kv_heads, config.num_kvcache_blocks,
self.block_size, head_dim, dtype=config.kvcache_quant_dtype)
self.kv_scale_cache = torch.empty(2,hf_config.num_hidden_layers, num_kv_heads, config.num_kvcache_blocks,
self.block_size, dtype=config.kvscale_dtype,)
num_slots = config.num_kvcache_blocks * self.block_size
layer_id = 0
for module in self.model.modules():
if (hasattr(module, "k_cache")
and hasattr(module, "v_cache")):
module.k_cache = self.kv_cache[0, layer_id].reshape(num_kv_heads, num_slots, head_dim)
module.v_cache = self.kv_cache[1, layer_id].reshape(num_kv_heads, num_slots, head_dim)
module.k_scale_cache = self.kv_scale_cache[0, layer_id].reshape(num_kv_heads, num_slots)
module.v_scale_cache = self.kv_scale_cache[1, layer_id].reshape(num_kv_heads, num_slots)
module.block_size = self.block_size
module.kv_quant = True
module.GROUP_NUM = config.group_num
module.kv_store_stream = self.kv_store_stream
module.kv_store_event = self.kv_store_event
layer_id += 1
else:
block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.torch_dtype.itemsize
config.num_kvcache_blocks = available_bytes // block_bytes
assert config.num_kvcache_blocks > 0
self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, config.num_kvcache_blocks,
self.block_size, num_kv_heads, head_dim, )
layer_id = 0
for module in self.model.modules():
if (hasattr(module, "k_cache")
and hasattr(module, "v_cache")):
module.k_cache = self.kv_cache[0, layer_id]
module.v_cache = self.kv_cache[1, layer_id]
module.kv_quant = False
layer_id += 1

def prepare_block_tables(self, seqs: list[Sequence]):
max_len = max(len(seq.block_table) for seq in seqs)
Expand All @@ -131,42 +169,41 @@ def prepare_prefill(self, seqs: list[Sequence]):
positions = []
cu_seqlens_q = [0]
cu_seqlens_k = [0]
context_lens = []
max_seqlen_q = 0
max_seqlen_k = 0
slot_mapping = []
block_tables = None
for seq in seqs:
start = seq.num_cached_tokens
seqlen_q = seq.num_scheduled_tokens
end = start + seqlen_q
seqlen_k = end
input_ids.extend(seq[start:end])
positions.extend(range(start, end))
seqlen = len(seq)
context_lens.append(seqlen)
input_ids.extend(seq[seq.num_cached_tokens:])
positions.extend(list(range(seq.num_cached_tokens, seqlen)))
seqlen_q = seqlen - seq.num_cached_tokens
seqlen_k = seqlen
cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)
cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
max_seqlen_q = max(seqlen_q, max_seqlen_q)
max_seqlen_k = max(seqlen_k, max_seqlen_k)
if not seq.block_table: # warmup
if not seq.block_table:
continue
start_block = start // self.block_size
end_block = (end + self.block_size - 1) // self.block_size
for i in range(start_block, end_block):
slot_start = seq.block_table[i] * self.block_size
if i == start_block:
slot_start += start % self.block_size
if i != end_block - 1:
slot_end = seq.block_table[i] * self.block_size + self.block_size
for i in range(seq.num_cached_blocks, seq.num_blocks):
start = seq.block_table[i] * self.block_size
if i != seq.num_blocks - 1:
end = start + self.block_size
else:
slot_end = seq.block_table[i] * self.block_size + end - i * self.block_size
slot_mapping.extend(range(slot_start, slot_end))
if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # prefix cache
end = start + seq.last_block_num_tokens
slot_mapping.extend(list(range(start, end)))
if cu_seqlens_k[-1] > cu_seqlens_q[-1]:
block_tables = self.prepare_block_tables(seqs)
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
if self.config.kv_quant:
context_lens = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables)
set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables)
return input_ids, positions

def prepare_decode(self, seqs: list[Sequence]):
Expand All @@ -179,21 +216,28 @@ def prepare_decode(self, seqs: list[Sequence]):
positions.append(len(seq) - 1)
context_lens.append(len(seq))
slot_mapping.append(seq.block_table[-1] * self.block_size + seq.last_block_num_tokens - 1)
max_context_len = max(context_lens) if context_lens else 0
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
context_lens = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
block_tables = self.prepare_block_tables(seqs)
set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)
set_context(False, slot_mapping=slot_mapping, context_lens=context_lens,
block_tables=block_tables, max_context_len=max_context_len)
return input_ids, positions

def prepare_sample(self, seqs: list[Sequence]):
temperatures = [seq.temperature for seq in seqs]
temperatures = []
for seq in seqs:
temperatures.append(seq.temperature)
temperatures = torch.tensor(temperatures, dtype=torch.float32, pin_memory=True).cuda(non_blocking=True)
return temperatures

@torch.inference_mode()
def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool):
if self.config.kv_quant and getattr(self, "has_pedding", True):
torch.cuda.current_stream().wait_event(self.kv_store_event)
self.has_pedding = False
if is_prefill or self.enforce_eager or input_ids.size(0) > 512:
return self.model.compute_logits(self.model(input_ids, positions))
else:
Expand All @@ -215,6 +259,8 @@ def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
logits = self.run_model(input_ids, positions, is_prefill)
if self.config.kv_quant:
self.has_pedding = True
token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
reset_context()
return token_ids
Expand All @@ -235,12 +281,16 @@ def capture_cudagraph(self):
self.graphs = {}
self.graph_pool = None

if self.config.kv_quant:
self.capture_quant(slot_mapping, context_lens, block_tables, outputs, input_ids, positions)
return

for bs in reversed(self.graph_bs):
graph = torch.cuda.CUDAGraph()
set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs])
outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # warmup
outputs[:bs] = self.model(input_ids[:bs], positions[:bs])
with torch.cuda.graph(graph, self.graph_pool):
outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # capture
outputs[:bs] = self.model(input_ids[:bs], positions[:bs])
if self.graph_pool is None:
self.graph_pool = graph.pool()
self.graphs[bs] = graph
Expand All @@ -255,3 +305,34 @@ def capture_cudagraph(self):
block_tables=block_tables,
outputs=outputs,
)

def capture_quant(self, slot_mapping, context_lens, block_tables, outputs, input_ids, positions):
main_stream = torch.cuda.current_stream()
for bs in reversed(self.graph_bs):
graph = torch.cuda.CUDAGraph()
self.kv_store_stream.synchronize()
set_context(
False,
slot_mapping=slot_mapping[:bs],
context_lens=context_lens[:bs],
block_tables=block_tables[:bs],
)
outputs[:bs] = self.model(input_ids[:bs], positions[:bs])
main_stream.wait_event(self.kv_store_event)
torch.cuda.synchronize()
with torch.cuda.graph(graph, self.graph_pool):
outputs[:bs] = self.model(input_ids[:bs], positions[:bs])
torch.cuda.current_stream().wait_event(self.kv_store_event)
if self.graph_pool is None:
self.graph_pool = graph.pool()
self.graphs[bs] = graph
torch.cuda.synchronize()
reset_context()
self.graph_vars = dict(
input_ids=input_ids,
positions=positions,
slot_mapping=slot_mapping,
context_lens=context_lens,
block_tables=block_tables,
outputs=outputs,
)
Loading