Skip to content

Commit 957cea0

Browse files
author
Nathan
committed
fix
1 parent 59751da commit 957cea0

20 files changed

+6952
-2
lines changed

README.md

+79-2
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,79 @@
1-
# qfilters
2-
Repository for the Q-Filters method
1+
# Q-Filters: Leveraging Query-Key Geometry for Efficient Key-Value Cache Compression
2+
[![arXiv](https://img.shields.io/badge/arXiv-1234.56789-b31b1b.svg)](https://arxiv.org/abs/1234.56789)
3+
4+
![Q-Filters Demo GIF](assets/qfilters_demo.gif)
5+
6+
## Setup
7+
1. Install required libraries in a virtual environment:
8+
```bash
9+
python -m virtualenv venv
10+
source venv/bin/activate
11+
pip install -r requirements.txt
12+
```
13+
2. Configure HuggingFace's environment:
14+
```bash
15+
export HF_DATASETS_CACHE=<path_to_hf_cache>
16+
export HF_HOME=<path_to_hf_cache>
17+
export HF_TOKEN=<hf_token>
18+
```
19+
20+
## Generate with Q-Filters
21+
Here is an example of how to use Q-Filters in a generation setup:
22+
```python
23+
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
24+
from src.hf_cache import QFiltersCache
25+
from datasets import load_dataset
26+
27+
model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
28+
model = AutoModelForCausalLM.from_pretrained(
29+
model_name,
30+
device_map="auto",
31+
low_cpu_mem_usage=True,
32+
torch_dtype="bfloat16"
33+
)
34+
35+
tokenizer = AutoTokenizer.from_pretrained(model_name)
36+
streamer = TextStreamer(tokenizer)
37+
38+
question = """What is the probability of two integers selected at random having a greatest common divisor of 1."""
39+
input_text = f"<|User|>{question}<|Assistant|><think>\n"
40+
41+
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
42+
43+
past_key_values = QFiltersCache(
44+
window_length=64,
45+
max_length=128,
46+
model_name=model_name
47+
)
48+
49+
out = model.generate(
50+
**inputs,
51+
do_sample=True,
52+
temperature=0.5,
53+
max_new_tokens=4096,
54+
past_key_values=past_key_values,
55+
streamer=streamer
56+
)
57+
```
58+
59+
## Compute Q-Filters for a new model
60+
1. Verify that the target model does not already have [pre-computed Q-Filters](https://huggingface.co/collections/nthngdy/q-filters-67a4994dcb302a3d37f3d119).
61+
2. Use the `make_filters.py` script to generate the filters. For instance:
62+
```bash
63+
python make_filters.py \
64+
--model_name deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \
65+
--model_cls Qwen2ForCausalLM \
66+
--max_seq_len 2048 \
67+
--num_sequences 10 \
68+
--num_svd_samples 3000 \
69+
--dataset_name PatrickHaller/fineweb-1B \
70+
--save_mode disk \
--save_dir ../filters
# Alternatives: use `--save_mode hub` or `--save_mode hub+disk` instead of `disk`,
# and add `--hf_user_id <your_hf_username>` (e.g. nthngdy) when pushing to the Hub.
75+
```
76+
3. For Q-Filters saved on disk, you can upload them later using this command:
77+
```bash
78+
huggingface-cli upload path_to_hf_repo path_to_local_qfilters .
79+
```

example.py

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
from src.hf_cache import QFiltersCache, KNormCache
from datasets import load_dataset

# Example: stream a generation from a DeepSeek-R1 distilled model while the
# KV cache is compressed on the fly (window_length recent tokens are always
# kept; the cache is capped at max_length entries per layer).
model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    low_cpu_mem_usage=True,
    torch_dtype="bfloat16",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
streamer = TextStreamer(tokenizer)

# Build a DeepSeek-R1-style chat prompt with an opened <think> block.
question = """What is the probability of two integers selected at random having a greatest common divisor of 1."""
prompt = f"<|User|>{question}<|Assistant|><think>\n"
model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Swap in QFiltersCache below to compress with Q-Filters instead of the
# K-norm baseline:
# past_key_values = QFiltersCache(
#     window_length=64,
#     max_length=128,
#     model_name=model_name
# )
past_key_values = KNormCache(
    window_length=64,
    max_length=128,
)

generated = model.generate(
    **model_inputs,
    do_sample=True,
    temperature=0.5,
    max_new_tokens=4096,
    past_key_values=past_key_values,
    streamer=streamer,
)

make_filters.py

+112
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
from tqdm import tqdm
import pickle
import os
import argparse

from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import torch

import modeling as M
from utils import QFilters


# Compute Q-Filters for a model: collect query states over long sequences,
# take the leading right-singular vector of the queries per head via SVD,
# and save the resulting filters to disk and/or the HuggingFace Hub.

device = "cuda" if torch.cuda.is_available() else "cpu"

parser = argparse.ArgumentParser()

parser.add_argument("--model_name")
parser.add_argument("--model_cls")
parser.add_argument("--max_seq_len", type=int, default=2048)
parser.add_argument("--num_sequences", type=int, default=20)
parser.add_argument("--num_svd_samples", type=int, default=3000)
parser.add_argument("--filter_suffix", default="")
parser.add_argument("--torch_dtype", default="bfloat16")

parser.add_argument("--dataset_name")
parser.add_argument("--dataset_config", default="default")
parser.add_argument("--dataset_split", default="train[:1000]")

parser.add_argument("--save_mode", default="disk")
parser.add_argument("--save_dir", default="")
parser.add_argument("--hf_user_id", default="")


args = parser.parse_args()

model_name = args.model_name
model_cls = getattr(M, args.model_cls)
max_seq_len = args.max_seq_len
num_sequences = args.num_sequences
num_svd_samples = args.num_svd_samples
filter_suffix = args.filter_suffix
torch_dtype = args.torch_dtype

dataset_name = args.dataset_name
dataset_config = args.dataset_config
dataset_split = args.dataset_split

save_mode = args.save_mode
save_dir = args.save_dir
hf_user_id = args.hf_user_id

# Fail fast on inconsistent save options before loading any model weights.
if "disk" in save_mode and not save_dir:
    raise ValueError("In 'disk' or 'disk+hub' save modes, a '--save_dir' must be provided.")

if "hub" in save_mode and not hf_user_id:
    raise ValueError("In 'hub' or 'disk+hub' save modes, a '--hf_user_id' must be provided.")


tokenizer = AutoTokenizer.from_pretrained(model_name)
# The patched modeling classes (see `modeling/`) expose query states through
# `past_key_values`, which is why a custom `model_cls` is required here.
model = model_cls.from_pretrained(
    model_name, attn_implementation="flash_attention_2", device_map="auto", low_cpu_mem_usage=True, torch_dtype=torch_dtype)

model = model.eval()

dataset = load_dataset(dataset_name, dataset_config, split=dataset_split)


with torch.no_grad():
    # Locate the decoder stack; GPT-NeoX models expose it as `gpt_neox`,
    # most other architectures as `model`.
    decoder = getattr(model, "gpt_neox", getattr(model, 'model', None))
    svd_filters = [[] for _ in range(len(decoder.layers))]
    sample_count = 0
    num_k_heads = None

    for sample in tqdm(dataset):
        tokens = tokenizer(sample["text"], return_tensors="pt")
        # Only use documents long enough to fill the full context window.
        if tokens.input_ids.shape[-1] < max_seq_len:
            continue
        sample_count += 1
        input_ids = tokens.input_ids[:, :max_seq_len].to(device)
        with torch.autocast(device_type=device, dtype=torch.bfloat16):
            out_repr = model(input_ids).past_key_values
        for j, (query, key) in enumerate(out_repr):
            num_k_heads = key.shape[1]
            # Flatten (batch, heads) and move to CPU to bound GPU memory.
            svd_filters[j].append(query.flatten(0, 1).cpu())
        # Fix: the original checked `sample_count < num_sequences` BEFORE
        # processing, which silently collected only num_sequences - 1
        # sequences. Process first, then stop once the quota is reached.
        if sample_count >= num_sequences:
            break

if sample_count == 0:
    raise ValueError(
        f"No dataset sample reached max_seq_len={max_seq_len} tokens; "
        "lower --max_seq_len or use a dataset with longer documents."
    )

# Free the model before the (memory-hungry) SVD step.
del model

for f_id, el in enumerate(svd_filters):
    # Stack per-sequence query states into (heads, tokens, head_dim).
    stacked_el = torch.stack(el, 1).flatten(1, 2)
    # Randomly subsample token positions to keep the SVD tractable.
    idx = torch.argsort(torch.rand(stacked_el.shape[1], device=stacked_el.device))[:num_svd_samples]
    # Fix: the original hard-coded `.cuda()` here, crashing on CPU-only
    # machines even though `device` falls back to "cpu" above.
    stacked_el = stacked_el[:, idx].to(device)
    u, s, vh = torch.linalg.svd(stacked_el.float())
    # Orient each head's leading right-singular vector consistently: flip it
    # so that the corresponding left-singular coefficients are mostly
    # positive, then negate (Q-Filters point against the dominant query
    # direction).
    svd_sign = ((u[..., 0] > 0).float().mean(-1) > 0.5).float() * 2 - 1
    svd_filter_q = -svd_sign[:, None] * vh[..., 0, :]
    # Average query-head filters within each KV group — presumably to share
    # one filter per key head under GQA (TODO confirm against the cache code).
    svd_filters[f_id] = svd_filter_q.reshape(num_k_heads, -1, svd_filter_q.shape[-1]).mean(-2)

svd_filters = torch.nn.Parameter(torch.stack(svd_filters))
q_filters = QFilters(*svd_filters.shape)
q_filters.q_filters = svd_filters

model_suffix = model_name.split("/")[-1]
filter_savename = f"{model_suffix}_qfilt{'_' + filter_suffix if filter_suffix else ''}"
if "disk" in save_mode:
    q_filters.save_pretrained(f"{save_dir}/{filter_savename}")
if "hub" in save_mode:
    q_filters.push_to_hub(f"{hf_user_id}/{filter_savename}")

modeling/__init__.py

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from modeling.modeling_llama import LlamaForCausalLM
2+
from modeling.modeling_qwen2 import Qwen2ForCausalLM
3+
from modeling.modeling_olmo2 import Olmo2ForCausalLM
4+
from modeling.modeling_phi3 import Phi3ForCausalLM
5+
from modeling.modeling_mistral import MistralForCausalLM
585 Bytes
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 commit comments

Comments
 (0)