3.5-bit KV cache compression for vLLM, combining PolarQuant (MSE-optimal scalar quantization) with QJL (1-bit Quantized Johnson-Lindenstrauss residual correction).
Reference: TurboQuant paper (arxiv 2504.19874)
TurboQuant compresses the transformer KV cache from 16 bits/channel (FP16) → ~3.5 bits/channel with near-zero accuracy loss on LongBench benchmarks. The speedup comes from replacing FP32 Q×K matrix multiply with integer table lookups (LUT).
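The lookup-table idea can be illustrated with a minimal pure-Python sketch. This is not the plugin's kernel: the four-entry codebook `CENTROIDS` and all function names are hypothetical, chosen only to show the mechanism. For each query, a small table of `query[d] * centroid` products is built once; scoring any cached key then needs only lookups indexed by its 2-bit codes, plus additions.

```python
# Hypothetical 2-bit codebook (assumption); the real PolarQuant centroids
# are not specified in this README.
CENTROIDS = [-1.5, -0.5, 0.5, 1.5]


def quantize(vec):
    """Map each coordinate to the index of its nearest centroid."""
    return [
        min(range(len(CENTROIDS)), key=lambda c: abs(x - CENTROIDS[c]))
        for x in vec
    ]


def dequantize(codes):
    """Turn integer codes back into centroid values."""
    return [CENTROIDS[c] for c in codes]


def build_lut(query):
    """Precompute query[d] * centroid for every coordinate d and every code."""
    return [[q * c for c in CENTROIDS] for q in query]


def lut_score(lut, codes):
    """Approximate <query, key> with table lookups and additions only."""
    return sum(lut[d][code] for d, code in enumerate(codes))


query = [0.2, -1.0, 0.7, 0.1]
key = [1.4, -0.6, 0.3, -1.2]

codes = quantize(key)                      # integer codes stored in the cache
score = lut_score(build_lut(query), codes)  # no multiplies per cached key

# The LUT score equals the dot product with the dequantized key.
reference = sum(q * k for q, k in zip(query, dequantize(codes)))
```

The table costs `head_dim * 4` multiplies per query, which is amortized across every key in the cache.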
| Stage | Bits | Purpose |
|---|---|---|
| PolarQuant | 2 | MSE-optimal scalar quantization after randomized Hadamard rotation |
| QJL | 1 | Unbiased residual correction via sign quantization |
| Total | 3 | ~4.9× compression vs FP16 |
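The QJL row can be made concrete with a small sketch of the sign-based estimator. It relies on a standard identity: for a Gaussian vector `s`, `E[(s·q) * sign(s·r)] = sqrt(2/pi) * <q, r> / ||r||`, so storing only the signs of the projected residual plus its norm gives an unbiased estimate of `<q, r>`. The function names are hypothetical, the PolarQuant stage is omitted (the residual is simply given), and the sketch uses many more projections than the one bit per coordinate listed above, purely so that the convergence is visible.

```python
import math
import random


def make_projection(m, d, seed=0):
    """m x d matrix with i.i.d. standard normal entries."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(m)]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def qjl_encode(residual, proj):
    """Keep one sign bit per projection row, plus the residual norm."""
    signs = [1 if dot(row, residual) >= 0 else -1 for row in proj]
    norm = math.sqrt(dot(residual, residual))
    return signs, norm


def qjl_inner_product(query, signs, norm, proj):
    """Unbiased estimate of <query, residual> from the sign bits."""
    acc = sum(dot(row, query) * sg for row, sg in zip(proj, signs))
    return math.sqrt(math.pi / 2.0) * norm * acc / len(proj)


query = [0.5, -1.0, 0.25, 0.75, -0.5, 1.0, 0.0, -0.25]
residual = [0.1, -0.2, 0.05, 0.3, -0.1, 0.2, -0.05, 0.15]

# Assumption for illustration: 4000 projections instead of one per coordinate.
proj = make_projection(m=4000, d=len(residual))
signs, norm = qjl_encode(residual, proj)

estimate = qjl_inner_product(query, signs, norm, proj)
exact = dot(query, residual)
```

With one projection per coordinate the estimate is noisier but still unbiased.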
git clone https://github.com/OnlyTerp/turboquant.git
cd turboquant
pip install -e .

Or with vLLM as an extra dependency:

pip install -e ".[vllm]"

- Python ≥ 3.10
- PyTorch ≥ 2.1
- vLLM ≥ 0.4.0 (optional for standalone testing)
- CUDA-capable GPU (recommended)
# Install the plugin
cd scripts/turboquant && pip install -e .
# Serve a model with TurboQuant KV cache compression
vllm serve meta-llama/Llama-3-8B-Instruct --attention-backend turboquant

The plugin is auto-discovered by vLLM through the vllm.platform_plugins entry point. No manual imports needed.
All parameters can be set via TQ_* environment variables before launching vLLM:
# Core parameters
export TQ_B_MSE=2 # PolarQuant bits per coordinate (default: 2)
export TQ_B_QJL=1 # QJL bits per coordinate (default: 1)
export TQ_FLUSH_INTERVAL=128 # Raw buffer size before compression (default: 128)
# Model-specific (auto-detected when vLLM provides model config)
export TQ_NUM_LAYERS=32
export TQ_NUM_HEADS=32
export TQ_NUM_KV_HEADS=8 # For GQA models (e.g., Llama-3)
export TQ_HEAD_DIM=128
# Device
export TQ_DEVICE=cuda # "cuda" or "cpu" (default: cuda)
# Then launch
vllm serve meta-llama/Llama-3-8B-Instruct --attention-backend turboquant

from vllm_plugin import TurboQuantConfig, TurboQuantAttentionImpl
# Configure
config = TurboQuantConfig(
    num_layers=32,
    num_heads=32,
    num_kv_heads=8,  # Llama-3 GQA
    head_dim=128,
    flush_interval=64,
    b_mse=2,
    b_qjl=1,
    device="cuda",
)
# Create attention implementation
impl = TurboQuantAttentionImpl(
    num_heads=32,
    head_size=128,
    scale=1.0 / (128 ** 0.5),
    num_kv_heads=8,
    tq_config=config,
    layer_idx=0,
)
# Forward pass (same interface as vLLM AttentionImpl)
output = impl.forward(query, key, value)

vllm_plugin/
├── __init__.py # Package init, version, public API
├── config.py # TurboQuantConfig dataclass with env-var overrides
├── attention.py # TurboQuantAttentionBackend + TurboQuantAttentionImpl
├── platform.py # TurboQuantPlatform plugin (entry point target)
└── README.md # This file
setup.py # Entry point registration
src/
└── cache.py # TurboQuantCache (PolarQuant + QJL encode/decode)
- Prefill: Query, Key, Value tensors arrive → all KV tokens are batch-encoded into TurboQuantCache → attention computed via PQ scores + QJL correction
- Decode: Single new KV token buffered in raw FP → when buffer hits flush_interval, batch-compress into TQ cache → attention computed over compressed + buffered tokens with unified softmax
- GQA: Query heads are mapped to KV heads using heads_per_kv ratio (num_heads // num_kv_heads)
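The decode and GQA steps above can be sketched in plain Python. This is a toy model under stated assumptions, not the plugin's code: `BufferedKVStore`, `unified_softmax`, and `kv_head_for` are hypothetical names, the `compress` callable stands in for the TurboQuantCache encoder, and the contiguous head grouping (`q_head // heads_per_kv`) is the common GQA convention rather than something confirmed here.

```python
import math


def kv_head_for(q_head, num_heads, num_kv_heads):
    """Map a query head to its KV head, assuming contiguous grouping."""
    heads_per_kv = num_heads // num_kv_heads
    return q_head // heads_per_kv


class BufferedKVStore:
    """Holds new tokens raw and compresses them in batches of flush_interval."""

    def __init__(self, flush_interval, compress):
        self.flush_interval = flush_interval
        self.compress = compress  # placeholder for the real encoder
        self.compressed = []
        self.raw = []

    def append(self, token):
        self.raw.append(token)
        if len(self.raw) >= self.flush_interval:
            # Batch-compress the whole buffer, then start a fresh one.
            self.compressed.extend(self.compress(t) for t in self.raw)
            self.raw.clear()


def unified_softmax(compressed_scores, raw_scores):
    """One softmax over both score sets, so the weights share a normalizer."""
    scores = list(compressed_scores) + list(raw_scores)
    peak = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


store = BufferedKVStore(flush_interval=4, compress=lambda t: ("tq", t))
for token in range(10):
    store.append(token)

weights = unified_softmax([0.1, 0.4], [0.3])
```

Normalizing once over the concatenated scores keeps compressed and buffered tokens on the same scale; two separate softmaxes would not sum to one jointly.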
| Env Variable | Parameter | Default | Description |
|---|---|---|---|
| TQ_NUM_LAYERS | num_layers | 32 | Transformer layer count |
| TQ_NUM_HEADS | num_heads | 32 | Query attention heads |
| TQ_NUM_KV_HEADS | num_kv_heads | 32 | KV heads (≤ num_heads) |
| TQ_HEAD_DIM | head_dim | 128 | Head dimension (power of 2) |
| TQ_MAX_SEQ_LEN | max_seq_len | 4096 | Max sequence length |
| TQ_FLUSH_INTERVAL | flush_interval | 128 | Buffer size before TQ flush |
| TQ_B_MSE | b_mse | 2 | PolarQuant bits/coord |
| TQ_B_QJL | b_qjl | 1 | QJL bits/coord |
| TQ_DEVICE | device | cuda | Torch device |
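As a rough illustration of how a dataclass can pick up `TQ_*` overrides, here is a minimal sketch. It is not the actual `TurboQuantConfig`: the class name `SketchConfig`, the `from_env` and `validate` methods, and the validation rules are assumptions, with defaults and constraints copied from the table above.

```python
import os
from dataclasses import dataclass, fields


@dataclass
class SketchConfig:
    """Hypothetical stand-in for TurboQuantConfig; defaults follow the table."""

    num_layers: int = 32
    num_heads: int = 32
    num_kv_heads: int = 32
    head_dim: int = 128
    max_seq_len: int = 4096
    flush_interval: int = 128
    b_mse: int = 2
    b_qjl: int = 1
    device: str = "cuda"

    @classmethod
    def from_env(cls, env=None):
        """Build a config, overriding each field with TQ_<FIELD> if it is set."""
        env = os.environ if env is None else env
        overrides = {}
        for f in fields(cls):
            raw = env.get(f"TQ_{f.name.upper()}")
            if raw is not None:
                # Cast the string using the type of the field's default value.
                overrides[f.name] = type(f.default)(raw)
        cfg = cls(**overrides)
        cfg.validate()
        return cfg

    def validate(self):
        """Check the two constraints stated in the table."""
        if self.head_dim <= 0 or self.head_dim & (self.head_dim - 1):
            raise ValueError("head_dim must be a power of 2")
        if self.num_kv_heads > self.num_heads:
            raise ValueError("num_kv_heads must be <= num_heads")


cfg = SketchConfig.from_env({"TQ_NUM_KV_HEADS": "8", "TQ_DEVICE": "cpu"})
```

Passing a plain dict instead of `os.environ` keeps the sketch testable without touching the process environment.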
# Install with dev dependencies
pip install -e ".[dev]"
# Run tests (standalone — no vLLM required)
python -m pytest tests/ -v
# Type check
mypy vllm_plugin/
# Lint
ruff check vllm_plugin/

The plugin uses try/except for vLLM imports, so tests can run without vLLM installed using mock stubs.
- Triton kernels not yet integrated: Current implementation uses pure PyTorch. Triton kernels for FWHT and fused dequant+attention are planned.
- Accuracy in progress: A QJL projection transpose bug was recently fixed in the reference kernels; end-to-end quality still needs broader validation.
- Paged cache bypassed: TurboQuant manages its own compressed storage; vLLM's paged KV cache allocation is not used for KV data.
- Attention sinks: First few tokens are not yet preserved in higher precision (planned).
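For readers unfamiliar with the FWHT named in the first limitation, a plain-Python reference is short enough to show here. This is a sketch for orientation only, not the plugin's PyTorch code or the planned Triton kernel; the function names and the seeded sign flips in `randomized_hadamard` are assumptions. The power-of-2 length check mirrors the `head_dim` constraint in the configuration table.

```python
import math
import random


def fwht(vec):
    """Orthonormal fast Walsh-Hadamard transform in O(n log n) operations."""
    n = len(vec)
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a power of 2")
    out = list(vec)
    h = 1
    while h < n:
        # Butterfly step: combine elements that are h apart.
        for start in range(0, n, 2 * h):
            for i in range(start, start + h):
                a, b = out[i], out[i + h]
                out[i], out[i + h] = a + b, a - b
        h *= 2
    scale = 1.0 / math.sqrt(n)  # makes the transform its own inverse
    return [x * scale for x in out]


def randomized_hadamard(vec, seed=0):
    """Flip each coordinate's sign at random (seeded), then apply the FWHT."""
    rng = random.Random(seed)
    signs = [rng.choice((-1.0, 1.0)) for _ in vec]
    return fwht([s * x for s, x in zip(signs, vec)])


vec = [3.0, -1.0, 0.5, 2.0, 0.0, -4.0, 1.5, 1.0]
rotated = randomized_hadamard(vec)
round_trip = fwht(fwht(vec))
```

Because the transform is orthonormal, the rotation preserves vector norms and inner products, which is why it can be applied before scalar quantization without changing attention scores.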
MIT