kvpress

Deploying long-context LLMs is costly due to the linear growth of the key-value (KV) cache in transformer models. For example, handling 1M tokens with Llama 3.1-70B in float16 requires up to 330GB of memory. This repository implements multiple KV cache pruning methods and benchmarks using 🤗 transformers, aiming to simplify the development of new methods for researchers and developers in this field.
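As a rough sanity check of that figure, the cache size can be estimated directly from the model configuration. The sketch below assumes the published Llama 3.1-70B settings (80 layers, 8 KV heads via grouped-query attention, head dimension 128) and float16 storage; adjust the numbers for other models.

# Back-of-the-envelope KV cache size estimate (assumed Llama 3.1-70B configuration)
num_layers = 80        # transformer layers
num_kv_heads = 8       # KV heads (grouped-query attention)
head_dim = 128         # dimension per head
bytes_per_value = 2    # float16
num_tokens = 1_000_000

# Keys and values are both cached, hence the factor of 2
cache_bytes = 2 * num_layers * num_kv_heads * head_dim * bytes_per_value * num_tokens
print(f"{cache_bytes / 1e9:.0f} GB")  # ~328 GB, in line with the ~330GB quoted above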

A custom implementation of AdaKV under the NVIDIA/kvpress open-source project!

In this fork, we have implemented AdaKV under kvpress with a custom CUDA kernel, enabling easy customization of head-specific compression. Additionally, the official [NVIDIA/kvpress](https://github.com/NVIDIA/kvpress) repository provides a simpler way to simulate AdaKV's performance. The key difference lies in whether actual compression is performed: the official code offers a fast and convenient starting point, while this repository lets you measure the practical compression benefits, such as peak memory usage and decoding latency. There are also other implementations of AdaKV available; for example, Cloudflare provides an AdaKV implementation integrated into vLLM, alongside the [original AdaKV code](https://github.com/FFY0/AdaKV). We encourage everyone to explore these versions and hope they are helpful to your work.
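To illustrate the simulated route, the official kvpress package ships a head-adaptive wrapper around its scorer presses. The snippet below is a minimal sketch assuming that wrapper is named AdaKVPress and accepts a scorer press such as ExpectedAttentionPress; check the installed version of kvpress for the exact class name and signature.

from transformers import pipeline
from kvpress import ExpectedAttentionPress
# AdaKVPress is assumed to be the head-adaptive wrapper from the official
# NVIDIA/kvpress package; verify the name in the version you install.
from kvpress import AdaKVPress

pipe = pipeline("kv-press-text-generation", model="microsoft/Phi-3.5-mini-instruct", device="cuda:0", torch_dtype="auto")

# Wrap a per-head scorer press so the token budget is allocated adaptively across heads
press = AdaKVPress(ExpectedAttentionPress(compression_ratio=0.4))
answer = pipe("A very long context ...", question="A question about it", press=press)["answer"]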

Custom Evaluation

RULER

Install

pip install kvpress

We recommend using flash attention if possible:

pip install flash-attn --no-build-isolation

Usage

This repository provides a set of "presses" that compress the KV cache. A press is applied only during the pre-filling phase and is characterized by a compression_ratio parameter that controls how much of the cache is pruned. The easiest way to use a press is through our custom KVPressTextGenerationPipeline, which is automatically registered as a transformers pipeline under the name "kv-press-text-generation" when kvpress is imported. It handles chat templates and tokenization for you:

from kvpress import ExpectedAttentionPress
from transformers import pipeline

device = "cuda:0"
model = "microsoft/Phi-3.5-mini-instruct"
pipe = pipeline("kv-press-text-generation", model=model, device=device, torch_dtype="auto", model_kwargs={"attn_implementation":"flash_attention_2"})

context = "A very long text you want to compress once and for all"
question = "\nA question about the compressed context" # optional
    
press = ExpectedAttentionPress(compression_ratio=0.4)
answer = pipe(context, question=question, press=press)["answer"]

In the snippet above, the compression is only applied on the context tokens so that you can evaluate the compression for different questions. Check the Wikipedia notebook demo for a more detailed example.
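For instance, you can evaluate the same compressed context against several questions. This is a small sketch reusing the pipe, context, and press objects defined above; each call compresses the context tokens and then answers the given question.

# Ask several questions about the same context (same pipe/press as above)
questions = [
    "\nFirst question about the context",
    "\nSecond question about the context",
]
for q in questions:
    print(pipe(context, question=q, press=press)["answer"])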

Important

We focus on compression during the pre-filling phase, as the KV cache becomes a bottleneck for long-context sequences (100k - 1M tokens), which are essentially long context prompts. This typically applies to improving prompt caching systems.

Note

To use the ObservedAttentionPress, pass model_kwargs={"attn_implementation":"eager"} so that the attention weights are materialized (this method is not compatible with flash attention).
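A minimal sketch of this setup, reusing the context and question defined above but loading the model with eager attention:

from kvpress import ObservedAttentionPress
from transformers import pipeline

# Eager attention is required so the attention weights are materialized
pipe = pipeline(
    "kv-press-text-generation",
    model="microsoft/Phi-3.5-mini-instruct",
    device="cuda:0",
    torch_dtype="auto",
    model_kwargs={"attn_implementation": "eager"},
)
press = ObservedAttentionPress(compression_ratio=0.4)
answer = pipe(context, question=question, press=press)["answer"]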
