Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -182,4 +182,8 @@ tools/addlicense
tools/.addlicense.lock

.vscode/*
.claude
.claude

# macOS metadata files
.DS_Store
._*
93 changes: 93 additions & 0 deletions benchmarks/bench_alloc/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# `KVCacheManager` alloc/free microbenchmark

Times the `alloc(k) + free(handles)` hot path under three allocator implementations:

- **Python allocator** — baseline before PR #319.
- **C++ allocator** — PR #319 (`lianghao_c++`): allocator migrated to C++, `cudaMemGetInfo` dropped from `available_size()`, page-grouping moved into a single C++ call, `KVCacheBlock` object pool added.
- **C++ + restored resize** — PR #319 + `fix/pr319-restore-resize`: re-adds the elastic-resize poll and shm-name pin that PR #319 dropped.

NVIDIA GB10 (aarch64). All latency numbers are per call in μs (lower is better) unless noted.

For e2e vLLM serving numbers, see [`../bench_layout/README.md`](../bench_layout/README.md).

## Run

```bash
python bench_alloc.py
```

## Results

### 1. `available_size()` — the most frequently called allocator function

The Python path called `cudaMemGetInfo` on every invocation (~6 μs each). The C++ path skips it.

| | μs/call |
|--:|--:|
| Python | 6.52 |
| C++ | 0.52 |
| C++ + resize | 0.52 |

**12.5×.** Called once per scheduler step.

### 2. `group_indices_by_page` — called inside `free()`

Maps a list of N block indices to their owning pages. Python used a per-element loop + `defaultdict`; C++ replaces it with one call.

| N | Python | C++ | speedup |
|--:|--:|--:|--:|
| 64 | 3.4 | 1.3 | 2.6× |
| 1024 | 52.6 | 16.8 | 3.1× |
| 16384 | 834 | 292 | 2.9× |

**~3× across the range.** Restored-resize matches C++ within noise.

### 3. Slow-path alloc — `cuMemMap` per call

`KVCACHED_MIN/MAX_RESERVED_PAGES=0` forces every alloc to map a fresh 2 MB VMM page. `k` is blocks per alloc; k=128 ≈ one page.

| k | Python | C++ | C++ + resize |
|--:|--:|--:|--:|
| 128 | 4196 | 4023 | 4354 |
| 1024 | 33028 | 32488 | 34662 |
| 4096 | 134479 | 134430 | 137295 |

**All within 5%.** The CUDA driver syscall dominates; switching the surrounding code to C++ doesn't help.

### 4. Multi-thread throughput — Python contention dissolves

N Python threads, each in a tight `alloc(k=16) + free(h)` loop, `async_sched=True`. Aggregate ops/s (**higher is better**).

| threads | Python Kops/s | C++ Kops/s | C++ + resize Kops/s |
|--:|--:|--:|--:|
| 1 | 15.1 | 41.2 | 32.5 |
| 4 | 12.0 | 48.6 | 31.6 |
| 8 | 9.1 | 51.5 | 29.1 |

Python **degrades** under thread count (Python-level contention dominated the old hot path). C++ holds or improves. Restored-resize is flat ~30K — each alloc polls a resize shm descriptor that bare C++ skips.

(GIL is still held during C++ work, so gains come from shorter critical sections, not real parallelism. Real vLLM uses `async_sched=False` and doesn't exercise this path.)

### 5. `KVCacheBlock` object pool — C++ only

Pre-allocated pool of `KVCacheBlock` objects vs `new` per call. The Python baseline has no equivalent pool.

| N | no-pool | pool | speedup |
|--:|--:|--:|--:|
| 8 | 1.06 | 0.19 | 5.6× |
| 1024 | 147 | 17.4 | 8.5× |
| 4096 | 651 | 67.7 | 9.6× |

**5-10×**, speedup grows with N.

## Summary — what the C++ allocator delivers

- **12.5× on `available_size()`** — eliminates the per-scheduler-step `cudaMemGetInfo` cost.
- **~3× on `group_indices_by_page`** — flat across N from 64 to 16,384.
- **Multi-thread throughput scales** instead of degrading: 8 threads go from 9 Kops/s (Python) to 51 Kops/s (C++).
- **5-10× on `KVCacheBlock` allocation** via the new object pool (no Python equivalent).
- Slow-path `cuMemMap` is driver-bound and unaffected by the migration.

The restored-resize variant retains every gain except multi-thread (~70% of bare C++) because each alloc polls the resize shm descriptor.

These wins amortise to ~5% on e2e vLLM serving (per-token model forward dominates). The much larger e2e lever is unrelated — see [`../bench_layout/README.md`](../bench_layout/README.md).
63 changes: 63 additions & 0 deletions benchmarks/bench_alloc/bench_alloc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# SPDX-FileCopyrightText: Copyright contributors to the kvcached project
# SPDX-License-Identifier: Apache-2.0
"""Microbench for KVCacheManager.alloc() hot path.

Run on each branch and compare:
python bench_alloc.py
"""
import time

import torch

from kvcached.integration.vllm.interfaces import alloc_kv_cache, init_kvcached, shutdown_kvcached
from kvcached.kv_cache_manager import KVCacheManager
from kvcached.vmm_ops import kv_tensors_created

TP_RANK, TP_SIZE = 0, 1
NUM_LAYERS = 16
BLOCK_SIZE = 16
NUM_BLOCKS = 65536
DTYPE = torch.float16
DEVICE = f"cuda:{TP_RANK}"
KV_SHAPE = (2, NUM_BLOCKS, BLOCK_SIZE, 8, 64)


def setup():
torch.cuda.set_device(TP_RANK)
init_kvcached(tp_rank=TP_RANK, world_size=TP_SIZE, is_worker=True,
async_sched=False)
alloc_kv_cache(kvcache_shape=KV_SHAPE, block_size=BLOCK_SIZE, dtype=DTYPE,
device=DEVICE, num_layers=NUM_LAYERS)
t0 = time.time()
while not kv_tensors_created():
if time.time() - t0 > 10.0:
raise RuntimeError("KV tensors not created within 10s")
time.sleep(0.05)
return KVCacheManager(num_blocks=NUM_BLOCKS, block_size=BLOCK_SIZE,
cell_size=1024, num_layers=NUM_LAYERS,
world_size=TP_SIZE)


def bench_alloc_free(manager, k, iters):
# warm up
for _ in range(100):
h = manager.alloc(k)
manager.free(h)

t0 = time.perf_counter()
for _ in range(iters):
h = manager.alloc(k)
manager.free(h)
elapsed = time.perf_counter() - t0
per_op_us = elapsed / iters * 1e6
return per_op_us


if __name__ == "__main__":
manager = setup()
print(f"{'k':>6} {'iters':>8} {'us/alloc+free':>16}")
for k, iters in [(1, 50000), (4, 50000), (16, 50000), (64, 20000),
(256, 10000)]:
per_op = bench_alloc_free(manager, k, iters)
print(f"{k:>6} {iters:>8} {per_op:>16.2f}")
shutdown_kvcached()
96 changes: 96 additions & 0 deletions benchmarks/bench_layout/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# vLLM e2e + `KVCACHED_CONTIGUOUS_LAYOUT` overhead

Why kvcached is 30-50% slower than vanilla vLLM by default, and why `KVCACHED_CONTIGUOUS_LAYOUT=false` fixes it.

For the alloc/free microbench, see [`../bench_alloc/README.md`](../bench_alloc/README.md).

## Setup

GB10 (aarch64). `Qwen/Qwen3-0.6B` (28 layers, 8 KV heads, head_dim 128, bf16). `vllm serve --gpu-memory-utilization 0.5 --max-model-len 2048`. Bench: `vllm bench serve` random 512in/128out, 500 prompts, 3 seeds.

## Run

```bash
# E2E sweep (vanilla vs kvcached × LAYOUT × reserved pool)
bash run_sweep.sh && python parse_results.py sweep_results/

# Kernel-level profile under both layouts
bash run_nsys_layout.sh
python diff_nsys_kernels.py nsys_runs/layout_false.nsys-rep nsys_runs/layout_true.nsys-rep
```

Intermediate outputs aren't tracked in git — reproducible from the scripts.

## 1. The gap

500 prompts at `rate=inf`:

| | tput (req/s) | TTFT mean (ms) | TPOT mean (ms) |
|---|--:|--:|--:|
| vanilla | 14.21 | 11575 | 119.3 |
| kvcached (`LAYOUT=true`, default) | 9.87 (-31%) | 16555 | 177.5 |
| kvcached + `LAYOUT=false` | 14.17 (-1%) | 11642 | 119.0 |

`LAYOUT=false` matches vanilla on every metric, also at `rate=16` (sustained load). The C++ allocator from PR #319 only buys back ~5%; reserve-pool size doesn't help either. **It's all the layout.**

## 2. Where the gap actually is

### Stride math

`CONTIGUOUS_LAYOUT=true` lays out KV as `[num_blocks, num_layers, k/v, token, head, dim]` (`interfaces.py:282-289`). When you slice down to one layer, block n→n+1 stride is `num_layers × per_block_bytes`. For Qwen3-0.6B:

- per-block K+V, one layer = 16·8·128·2 = **64 KB**
- stride under `LAYOUT=true` = 28 × 64 KB = **1.75 MB** (≈ VMM page = 2 MB)
- stride under `LAYOUT=false` = **64 KB** (~32 blocks share a page)

So under contiguous, every FlashAttention block read lands on its own fresh 2 MB page. Non-contiguous packs 32 blocks per page. The attention kernel can't hide that.

### nsys per-kernel breakdown

Same workload as Section 1. Going from `LAYOUT=false → true` adds **+8,043 ms (+34.8%)** total GPU kernel time, all in one kernel:

| kernel | calls | LAYOUT=false ms | LAYOUT=true ms | Δ |
|---|--:|--:|--:|--:|
| `flash::flash_fwd_splitkv_kernel` (KV-read) | 3948 | 14,666 | 22,879 | **+8,213 (+56%)** |
| `vllm::reshape_and_cache_flash_kernel` (KV-write) | 3948 | 302 | 271 | -32 (-11%) |
| everything else | — | ~8,000 | ~8,000 | ~0 |

That one kernel exceeds the entire gap. Worth noting: the KV-*write* kernel isn't affected — only the multi-block read path is. Writes are sequential per-position so they never hit the cross-page stride.

### Scales with working set

Per-call attention time:

- 100 prompts: 1163 vs 851 μs (**+37%**)
- 500 prompts: 5795 vs 3715 μs (**+56%**)

More concurrent requests → larger working set → more distinct 2 MB pages touched → worse TLB/L2 hit rate. `LAYOUT=false` stays flat because 32 blocks share one page. Deeper models (Llama2-7B at 32 layers, Llama3-70B at 80) cross the page boundary even harder.

## 3. Where `LAYOUT=true` still wins

Three things to put on the other side of the scale.

**Hybrid linear / mamba: required.** Mamba state shares the KV buffer and indexes by virtual block across layers. `interfaces.py:138` outright refuses non-contiguous for hybrid-linear configs.

**Init time: ~1.4 s faster at server boot.** Contiguous reserves one big VM range; non-contiguous reserves `num_layers` separate ones. Measured `alloc_kv_cache` (16 layers, 1 GB/layer): 635 ms vs 2055 ms. ~99% of that 1.4 s is `FTensor::init_with_zero_()` mapping the zero-page over the entire VM range — contiguous uses a 64 MB compound page so it makes 1947 `cuMemMap` calls (~325 μs each); non-contiguous uses 2 MB pages and makes 62,304 calls (~33 μs each). CUDA driver per-call overhead is the dominant cost, and bigger pages amortise it better.

The gap stays roughly flat across `num_layers ∈ {8..80}` (1.3–1.5 s), one-shot at startup.

**Alloc/free hot path: ~2× faster.** Each page mapping under contiguous = 1 `cuMemMap`; under non-contiguous = `num_layers × (K+V)` FTensor `map()` calls. Cold path (`RESERVED=0`) shows a consistent 2.1× ratio; steady-state at small `k` is similar, collapsing to ~1× at `k=256`.

### When does the trade-off flip?

Attention overhead hits every decode step. Startup hits once. For the Section 1 workload:

- `LAYOUT=true` startup advantage: ~1.4 s
- `LAYOUT=false` throughput advantage: 14.17 vs 9.83 req/s, ≈ 31 ms/req

Break-even at **~45 requests**. Above that, non-contiguous wins on total wall-clock; below, contiguous's faster boot wins. Deeper models shift the break-even down further.

So contiguous still wins for: smoke tests, single-shot inference, request-level autoscaling, boot-SLA workloads, hybrid linear/mamba (forced). Everything else: non-contiguous.

## Summary

The kvcached default `CONTIGUOUS_LAYOUT=true` costs ~30% e2e throughput on standard MHA/GQA/MLA because every FlashAttention block read crosses a fresh 2 MB VMM page. Flipping to `LAYOUT=false` closes the gap entirely, at the price of ~1.4 s extra startup that's paid off in tens of requests.

The default should flip to `false` for non-hybrid models; `interfaces.py:138` already handles the hybrid-linear case where contiguous is mandatory.
101 changes: 101 additions & 0 deletions benchmarks/bench_layout/diff_nsys_kernels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# SPDX-FileCopyrightText: Copyright contributors to the kvcached project
# SPDX-License-Identifier: Apache-2.0
"""Diff per-kernel GPU time between two nsys traces.

Usage:
python diff_nsys_kernels.py <baseline.nsys-rep> <variant.nsys-rep>

Prints kernels ordered by absolute delta (variant - baseline). Useful for
identifying which kernels regress between KVCACHED_CONTIGUOUS_LAYOUT=true and
false.
"""
import csv
import subprocess
import sys
from collections import defaultdict
from io import StringIO


def kernel_sums(nsys_rep_path: str) -> dict:
"""Returns {kernel_name: (total_ns, instances)} for the given trace."""
result = subprocess.run(
[
"nsys",
"stats",
"--report",
"cuda_gpu_kern_sum",
"--format",
"csv",
"--force-overwrite=true",
nsys_rep_path,
],
capture_output=True,
text=True,
check=True,
)
# nsys stats emits header text, then a CSV table. Find the CSV.
text = result.stdout
# Find a line that looks like a header row.
lines = text.splitlines()
start = None
for i, line in enumerate(lines):
if line.startswith("Time (%)") or line.startswith("\"Time (%)"):
start = i
break
if start is None:
raise RuntimeError(
f"No CSV header found in nsys stats output for {nsys_rep_path}.\n"
f"First 500 chars:\n{text[:500]}")

csv_text = "\n".join(lines[start:])
reader = csv.DictReader(StringIO(csv_text))
sums: dict = defaultdict(lambda: [0, 0])
for row in reader:
name = row["Name"]
total_ns = int(row["Total Time (ns)"])
instances = int(row["Instances"])
sums[name][0] += total_ns
sums[name][1] += instances
return {k: tuple(v) for k, v in sums.items()}


def main():
if len(sys.argv) != 3:
print(__doc__)
sys.exit(1)
baseline_path, variant_path = sys.argv[1], sys.argv[2]
print(f"baseline = {baseline_path}", file=sys.stderr)
print(f"variant = {variant_path}", file=sys.stderr)

base = kernel_sums(baseline_path)
var = kernel_sums(variant_path)

all_kernels = set(base) | set(var)
rows = []
for k in all_kernels:
b_ns, b_n = base.get(k, (0, 0))
v_ns, v_n = var.get(k, (0, 0))
delta = v_ns - b_ns
rows.append((k, b_ns, b_n, v_ns, v_n, delta))

rows.sort(key=lambda r: r[5]) # by delta ascending (negative first = sped up)

total_b = sum(r[1] for r in rows)
total_v = sum(r[3] for r in rows)
print(
f"\nTotal kernel time: baseline={total_b/1e6:,.1f} ms "
f"variant={total_v/1e6:,.1f} ms "
f"delta={ (total_v-total_b)/1e6:+,.1f} ms ({(total_v-total_b)/total_b*100:+.1f}%)\n"
)

print(f"{'kernel':<80} {'base ms':>10} {'var ms':>10} {'delta ms':>10} {'delta %':>8} base_n var_n")
for k, b_ns, b_n, v_ns, v_n, d_ns in rows:
if abs(d_ns) < 1_000_000: # < 1 ms delta, skip noise
continue
pct = (d_ns / b_ns * 100) if b_ns else float("inf")
print(
f"{k[:80]:<80} {b_ns/1e6:>10.2f} {v_ns/1e6:>10.2f} {d_ns/1e6:>+10.2f} {pct:>+7.1f}% {b_n:>6} {v_n:>6}")


if __name__ == "__main__":
main()
Loading
Loading