-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathcache.py
1478 lines (1252 loc) · 54.6 KB
/
cache.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import math
import regex as re
from abc import ABC, abstractmethod
from collections import Counter
from prompt_compression import get_prompt_compressor_constructor
from quantization_utils import quantize_tensor, dequantize_tensor
import argparse
import torch
import torch.nn as nn
def add_cache_arguments(parser: argparse.ArgumentParser):
    """Register every KV-cache related command-line option on ``parser``.

    Most options live in an argument group titled ``"cache_args"``; ``--feed_long_prompts``
    and ``--min_recovery_frac`` are registered directly on ``parser`` (outside the group).

    Args:
        parser: Parser that is extended in place.
    """
    cache_group = parser.add_argument_group("cache_args")
    add = cache_group.add_argument

    # --- Cache size and precision ---
    add(
        "--max_cache_length",
        type=float,
        default=[1.0],
        nargs="+",
        help="Cache size per layer. If len < n layers, the values are tiled. Must have len divisible by n layers. "
        "If 0 < x <= 1, it is percent of |prompt| + max new tokens. Otherwise, if > 1, its the maximum size.",
    )
    add(
        "--cache_bits",
        default=None,
        type=int,
        choices=[2, 4, 8],
        help="Quantize the cache to reduce memory usage.",
    )
    # Per-layer size schedules: ScissorHands (https://arxiv.org/abs/2305.17118) favours larger
    # caches in higher layers ("funnel"); PyramidKV (https://arxiv.org/abs/2406.02069) the opposite.
    add(
        "--cache_length_pattern",
        default="tile",
        choices=["tile", "repeat", "funnel", "pyramid"],
    )

    # --- Eviction strategy ---
    base_strategies = [
        "full",
        "random",
        "recent_global",
        "heavy_hitter",
        "l2",
        "hybrid",
        "keep_it_odd",
    ]
    # Every strategy also has a "debug_" twin.
    strategy_choices = base_strategies + ["debug_" + name for name in base_strategies]
    add(
        "--cache_strategy",
        default=["full"],
        nargs="+",
        choices=strategy_choices,
    )
    add(
        "--cache_strategy_pattern",
        default="tile",
        choices=["tile", "repeat"],
        help="How to apply the cache_strategy across layers.",
    )

    # --- Long prompts (this flag is registered on the parser itself, not the group) ---
    parser.add_argument(
        "--feed_long_prompts",
        default=False,
        action="store_true",
        help="If True and |prompt| > max_cache_length, prefill with prompt[:max_cache_length], and feed prompt[max_cache_length:] sequentially.",
    )
    add(
        "--prompt_compression_strategy",  # Irrelevant when --feed_long_prompts is set.
        default=["recent_global"],
        nargs="+",
        help="If |prompt| exceeds max_cache_length, we need to specify a strategy for compressing it to max_cache_length.",
    )

    # --- Strategy-dependent options ---
    add(
        "--global_tokens",
        default=1,
        type=int,
        help="The number of initial tokens to always include in the KV-Cache. "
        "If using recent_global strategy, the actual window size becomes max_cache_length - global_tokens.",
    )
    add(
        "--recent_window",  # KVCacheRecentGlobal implicitly uses max_cache_length - global_tokens instead.
        default=10,  # ScissorHands default ("r" in Algorithm 2).
        type=float,  # Per the original note: values < 1 mean a fraction of max_cache_length.
        help="The number of recently generated tokens to always spare from eviction.",
    )
    # Heavy-hitter options (--cache_strategy heavy_hitter), cf. Algorithm 2 of ScissorHands.
    add(
        "--history_window_size",  # "m" in Algorithm 2 (the paper uses 400).
        default=1,  # 1 => the whole history is accumulated in one slot (an unbounded window).
        type=int,
        help="The number of past tokens to consider when computing 'Heavy Hitters' in the KV-Cache.",
    )
    add(
        "--attn_thresholding",
        default=False,
        action="store_true",
        help="Whether to accumulate number of times a token was unimportant (binary) versus raw un-normalized probabilities. If true, more memory efficient.",
    )
    # Hybrid (FastGen-style) option, registered on the parser itself.
    parser.add_argument(
        "--min_recovery_frac",
        default=0.9,
        type=float,
        help="Mininum fraction of recovered attentions (|compressed_attn - uncompressed_attn| < epsilon). The lower the value, the higher the compression.",
    )
def cache_compatibility(args):
    """Sanity-check that the per-layer cache arguments are mutually compatible.

    ``args.max_cache_length``, ``args.cache_strategy`` and ``args.prompt_compression_strategy``
    are per-layer lists. A single-element list denotes the same value for every layer (under
    any tiling pattern), so it is broadcast to the longest list before pairing. Previously the
    lists were ``zip``-ped directly, which silently truncated to the shortest one: with the
    default single-value ``--max_cache_length`` only the first cache strategy was checked.
    Lists of equal length are paired element-wise exactly as before; lists of different
    lengths > 1 are still paired up to the shortest (their layer pairing depends on the
    tiling pattern, which is not known here).

    Args:
        args: Parsed arguments; ``args.compile`` is read only when a "hybrid" strategy is present.

    Raises:
        AssertionError: If an unsupported combination is found.
    """
    per_layer = [
        list(args.max_cache_length),
        list(args.cache_strategy),
        list(args.prompt_compression_strategy),
    ]
    n_pairs = max(len(values) for values in per_layer)
    # Broadcast single-value lists so they are checked against every entry of the others.
    per_layer = [
        values * n_pairs if len(values) == 1 else values for values in per_layer
    ]

    for length, cache_strat, prompt_strat in zip(*per_layer):
        if cache_strat == "heavy_hitter":
            assert (
                prompt_strat == "heavy_hitter"
            ), "Heavy Hitter cache strategy currently must be run with --prompt_compression_strategy heavy_hitter to return attention."
        if cache_strat == "hybrid":
            assert (
                not args.compile
            ), "Hybrid cache strategy is currently not supported with compile=True."
        if cache_strat in {"full", "hybrid"}:
            assert (
                length == 1.0
            ), f"{cache_strat} cache strategy only supports max_cache_length=1.0."
    print("The cache argument values you provided appear compatible with each other!")
def create_window_attention_mask(seq_len, window_size, device, global_tokens: int = 4):
    """Build a causal sliding-window attention mask with always-visible global tokens.

    Query ``i`` may attend to key ``j`` iff ``j <= i`` (causal) and either ``j`` is one of the
    last ``window_size`` positions up to and including ``i``, or ``j < global_tokens``.

    Previously the global-token columns were switched on for *every* row, so queries with
    ``i < global_tokens - 1`` could attend to future global tokens; those entries are now
    causally masked. The per-row Python loop is also replaced by a vectorized comparison.

    Args:
        seq_len: Number of query/key positions.
        window_size: Local window length, counting the query position itself.
        device: Device on which the mask is allocated.
        global_tokens: Number of leading positions visible to every (later or equal) query.

    Returns:
        torch.Tensor: Boolean ``(seq_len, seq_len)`` mask; ``True`` means "may attend".
    """
    positions = torch.arange(seq_len, device=device)
    query = positions.unsqueeze(1)  # (seq_len, 1)
    key = positions.unsqueeze(0)  # (1, seq_len)
    causal = key <= query
    in_window = key > query - window_size
    is_global = key < global_tokens
    return causal & (in_window | is_global)
class KVCache(ABC, nn.Module):
    """Abstract base class for a fixed-capacity key/value cache, optionally quantized.

    Buffers registered in ``__init__``:

    * ``k_cache`` / ``v_cache``: ``(max_batch_size, n_heads, max_cache_length, head_dim)``;
      when ``cache_bits`` is set they are kept in quantized form between updates;
    * ``pos``: original sequence position held by each slot (``-1`` marks an empty slot);
      its head dimension is ``n_heads`` if ``head_specific`` else ``1``;
    * ``cache_cts``: number of filled slots (one counter per head if ``variable_length``);
    * ``mask``: ``(max_batch_size, n_heads, 1, max_cache_length)`` boolean mask of filled slots.

    Subclasses implement ``_fill``. Decoding-time eviction goes through ``_eviction_idx``,
    which by default relies on a subclass-provided ``_token_importances``.
    """

    # Define which hyperparameters are relevant for the cache.
    # Override as needed for sub-classes.
    relevant_kwargs = [
        "max_cache_length",
        "global_tokens",
        "max_seq_length",
        "cache_bits",
    ]

    def __init__(
        self,
        max_batch_size,
        n_heads,
        head_dim,
        dtype=torch.bfloat16,
        head_specific=False,  # IFF True, heads can contain different tokens, e.g., cache evictions are "head_specific".
        variable_length=False,  # IFF True, the number of tokens inserted can vary across heads. Only true for KVCacheHybrid.
        **kwargs,
    ):
        """Allocate the cache buffers.

        Args:
            max_batch_size: Leading dimension of every buffer.
            n_heads: Number of attention heads.
            head_dim: Size of each key/value vector.
            dtype: dtype of the (unquantized) key/value buffers.
            head_specific: Store ``pos`` per head (heads may hold different tokens).
            variable_length: Keep one ``cache_cts`` counter per head.
            **kwargs: Every entry is set as an attribute (e.g. ``max_cache_length``,
                ``cache_bits``, ``global_tokens``). ``max_cache_length`` and ``cache_bits``
                are read below, so they must be supplied (or already set as attributes);
                ``cache_bits=None`` disables quantization.
        """
        super().__init__()
        # Assign each kwarg as an attribute of the class
        for key, value in kwargs.items():
            setattr(self, key, value)
        self.cache_shape = (max_batch_size, n_heads, self.max_cache_length, head_dim)
        # Quantization: 2, 4, 8 bits supported.
        self.quantize = self.cache_bits is not None
        self.n_bit = self.cache_bits
        self.quantization_axis = 2  # Quantize the cache along the sequence length axis
        k_cache = torch.zeros(self.cache_shape, dtype=dtype)
        v_cache = torch.zeros(self.cache_shape, dtype=dtype)
        if self.quantize:
            # quantize_tensor (quantization_utils) is unpacked as (quantized, scales, zero_points).
            k_cache, k_scales, k_zeros = quantize_tensor(
                k_cache, n_bit=self.n_bit, axis=self.quantization_axis
            )
            v_cache, v_scales, v_zeros = quantize_tensor(
                v_cache, n_bit=self.n_bit, axis=self.quantization_axis
            )
            self.register_buffer("k_scales", k_scales)
            self.register_buffer("v_scales", v_scales)
            self.register_buffer("k_zero_points", k_zeros)
            self.register_buffer("v_zero_points", v_zeros)
        self.register_buffer("k_cache", k_cache)
        self.register_buffer("v_cache", v_cache)
        # Can we evict different tokens for different heads?
        # If the answer is yes, we need to store self.pos for each head.
        self.n_heads = n_heads
        self.head_specific = head_specific
        self.register_buffer(
            "pos",  # Track pos to keep track of the original positions of each item in cache.
            torch.full(
                (
                    max_batch_size,
                    n_heads if head_specific else 1,
                    self.max_cache_length,
                ),
                -1,  # -1 marks an empty slot.
                dtype=torch.int,
            ),
        )
        # Number of filled slots: a single shared counter, or one per head when variable_length.
        self.register_buffer(
            "cache_cts",
            torch.zeros((n_heads if variable_length else 1), dtype=torch.int),
        )
        # We need to use a mask since not all heads have same number of tokens. We can't simply truncate.
        # 1 dimension stands for query dimension, which will always be 1 (next token) for KV cache attention.
        kv_mask_shape = (max_batch_size, n_heads, 1, self.max_cache_length)
        self.register_buffer("mask", torch.zeros(kv_mask_shape, dtype=torch.bool))

    def reset(self):
        """
        Resets the cache to its initial state for a new example.
        Zeroes k_cache, v_cache, mask and cache_cts, and marks every slot empty (pos = -1).
        NB: quantization scales / zero points (if any) are not reset here.
        """
        self.k_cache.zero_()
        self.v_cache.zero_()
        self.mask.zero_()
        self.cache_cts.zero_()
        self.pos.fill_(-1)

    def return_attn(self):
        """
        Returns whether the cache requires attention weights for cache management.
        The base class never does; subclasses override this.
        """
        return False

    def memory_usage(self):
        """
        Approximate memory (GiB) of the tensors reachable from this object's attributes:
        tensor-valued attributes plus tensors stored inside dict-valued attributes
        (for an nn.Module, registered buffers live in the ``_buffers`` dict).
        """
        tensors = []
        for obj in vars(self).values():
            if torch.is_tensor(obj):
                tensors.append(obj)
            elif isinstance(obj, dict):
                for vv in obj.values():
                    if torch.is_tensor(vv):
                        tensors.append(vv)
        return sum([t.element_size() * t.numel() for t in tensors]) / (1024**3)

    def compute_statistics(self, seq_len):
        """
        Computes statistics about the cache.
        Returns:
            dict: {"compression_ratio": float, "cache_memory_gb": float}.
        """
        return {
            "compression_ratio": self.compression_ratio(seq_len).item(),
            "cache_memory_gb": self.memory_usage(),
        }

    def compression_ratio(self, seq_len):
        """
        Returns the compression ratio of the cache as a 0-d tensor:
        the fraction of the ``seq_len - 1`` processed tokens that are NOT held in the cache,
        averaged over the ``cache_cts`` counters. When quantized, the cached amount is scaled
        by n_bit / 16 (i.e. relative to a 16-bit baseline).
        NB: requires seq_len >= 2, otherwise this divides by zero.
        """
        # Final token isn't passed to cache so must -1 from seq_len
        n = seq_len - 1
        assert torch.all(self.cache_cts <= self.max_cache_length)
        cache_size = self.cache_cts.clone().float()
        if self.n_bit is not None:
            cache_size *= self.n_bit / 16.0
        return ((n - cache_size) / n).mean()

    def quantize_cache(self):
        """Re-quantize k_cache / v_cache in place of the attributes (no-op without cache_bits)."""
        if self.quantize:
            self.k_cache, self.k_scales, self.k_zero_points = quantize_tensor(
                self.k_cache, n_bit=self.n_bit, axis=self.quantization_axis
            )
            self.v_cache, self.v_scales, self.v_zero_points = quantize_tensor(
                self.v_cache, n_bit=self.n_bit, axis=self.quantization_axis
            )

    def dequantize_cache(self):
        """Replace k_cache / v_cache with their dequantized versions (no-op without cache_bits)."""
        if self.quantize:
            self.k_cache = dequantize_tensor(
                self.k_cache,
                self.k_scales,
                self.k_zero_points,
                self.cache_shape,
                n_bit=self.n_bit,
                axis=self.quantization_axis,
            )
            self.v_cache = dequantize_tensor(
                self.v_cache,
                self.v_scales,
                self.v_zero_points,
                self.cache_shape,
                n_bit=self.n_bit,
                axis=self.quantization_axis,
            )

    def return_kv_cache(self):
        """Tensors used for attention; subclasses may override."""
        return self.k_cache, self.v_cache, self.mask

    def update_kv(self, input_pos, k_val, v_val, is_prefill, **kwargs):
        """
        Cache update logic.
        Dequantizes the cache (if quantized), writes k_val / v_val for input_pos via
        _prefill_update (prefill) or _decoding_update (decoding), adds the returned
        per-head insertion counts to self.cache_cts, then re-quantizes.
        Modifies self.pos, self.k_cache, self.v_cache and self.mask in place.
        Returns:
            (k, v, mask) from return_kv_cache(), captured before re-quantization
            (presumably the dequantized tensors, since quantize_cache rebinds the attributes).
        """
        # Dequantize the cache before updating
        self.dequantize_cache()
        if is_prefill:
            num_insertions = self._prefill_update(input_pos, k_val, v_val, **kwargs)
        else:
            num_insertions = self._decoding_update(input_pos, k_val, v_val, **kwargs)
        # Slice so a per-head count can be added to a single shared counter.
        self.cache_cts += num_insertions[: len(self.cache_cts)]
        # Fetch the tensors used for attention. NB: update_state is not called here.
        k, v, mask = (
            self.return_kv_cache()
        )  # By default, just returns self.k_cache, self.v_cache, self.mask
        # Quantize the cache after updating
        self.quantize_cache()
        return k, v, mask

    def update_state(self, *args, **kwargs):
        """
        Optional method to update cache-specific internal state (excludes self.k_cache, self.v_cache, and self.pos).
        """
        pass

    def _decoding_update(self, input_pos, k_val, v_val, **kwargs):
        """
        Decoding logic for the cache: pick a slot with _eviction_idx, write the new token there.
        Returns an int tensor with a 1 for each chosen slot that was empty (pos == -1), else 0.
        """
        eviction_idx = self._eviction_idx(input_pos)
        # Num insertions means we inserted into an unfilled slot (previous pos == -1)
        # They should be all the same unless variable_length = True
        num_insertions = (
            (self.pos.gather(2, eviction_idx.view(1, -1, 1)).squeeze() == -1)
            .int()
            .view(-1)
        )
        self._fill(input_pos, k_val, v_val, fill_idxs=eviction_idx)
        return num_insertions

    def _eviction_idx(self, input_pos):
        """
        Default eviction: the slot with the lowest _token_importances score (subclass-provided),
        never a global slot (index < global_tokens), always an empty slot first.
        """
        scores = self._token_importances(input_pos)
        if scores.ndim == 1:
            scores = scores.unsqueeze(0)
        # Protect global tokens
        scores[:, : self.global_tokens] = float("inf")
        # Evict unfilled slots (pos == -1)
        scores.masked_fill_(self.pos.view(scores.shape) == -1, float("-inf"))
        # Evict least important token
        return torch.argmin(scores, dim=-1)

    def _prefill_update(self, input_pos, k_val, v_val, **kwargs):
        """Write the prompt into slots 0..len(input_pos)-1; returns the count as a [1] tensor."""
        input_pos = input_pos.int()
        fill_idxs = torch.arange(input_pos.shape[-1], device=input_pos.device)
        self._fill_contiguous(input_pos, k_val, v_val, fill_idxs=fill_idxs)
        # Saves a fraction of time to return as a tensor rather than integer
        return torch.tensor(
            [input_pos.shape[-1]], dtype=torch.int, device=input_pos.device
        )

    def _fill_contiguous(
        self, input_pos, k_val, v_val, fill_idxs: torch.Tensor | int, **kwargs
    ):
        """
        A simple utility to fill the cache and pos.
        Writes the same slot indices for every head; unless update_mask=False is passed,
        also marks those slots as valid in self.mask.
        """
        self.pos[:, :, fill_idxs] = input_pos
        self.k_cache[:, :, fill_idxs, :] = k_val
        self.v_cache[:, :, fill_idxs, :] = v_val
        update_mask = kwargs.get("update_mask", True)
        if update_mask:
            self.mask[:, :, :, fill_idxs] = True

    @abstractmethod
    def _fill(self, input_pos, k_val, v_val, fill_idxs: torch.Tensor | int, **kwargs):
        """
        Modifies the cache in-place with key-value pairs at given fill_indices.
        Args:
            fill_indices (torch.Tensor): The indices specifying the positions to fill in the cache.
            input_pos (torch.Tensor): The input positions corresponding to the fill_indices.
            k_val (torch.Tensor): The key values to fill in the fill_indices slots.
            v_val (torch.Tensor): The value values to fill in the fill_indices slots.
        Returns:
            None
        """
        raise NotImplementedError

    def update_attn_history(self, attn):
        """
        Update the attention history with the most recent attention weights.
        The base class raises; subclasses that set return_attn() to True must override.
        """
        raise Exception(
            f"{self.__class__.__name__} requested return_attn=True but has not yet implemented a update_attn_history function."
        )
class KVCacheHeadConstant(KVCache):
    """KV cache whose heads all hold the same token positions.

    ``pos`` is therefore stored once (``head_specific=False``), and every write targets the
    same slot indices for all heads.
    """

    def __init__(
        self, max_batch_size, n_heads, head_dim, dtype=torch.bfloat16, **kwargs
    ):
        super().__init__(
            max_batch_size,
            n_heads,
            head_dim,
            dtype=dtype,
            head_specific=False,
            **kwargs,
        )

    def _fill(self, input_pos, k_val, v_val, fill_idxs: torch.Tensor | int, **kwargs):
        # Slots are shared by all heads, so a plain slice assignment suffices.
        self._fill_contiguous(input_pos, k_val, v_val, fill_idxs, **kwargs)
class KVCacheHeadSpecific(KVCache):
    """KV cache whose heads may each hold different token positions.

    ``pos`` is stored per head (``head_specific=True``), and ``_fill`` scatters one slot index
    per head.
    """

    def __init__(
        self,
        max_batch_size,
        n_heads,
        head_dim,
        dtype=torch.bfloat16,
        variable_length=False,
        **kwargs,
    ):
        super().__init__(
            max_batch_size,
            n_heads,
            head_dim,
            dtype,
            head_specific=True,
            variable_length=variable_length,
            **kwargs,
        )

    def _fill(self, input_pos, k_val, v_val, fill_idxs: torch.Tensor | int, **kwargs):
        """
        Modifies the cache in-place, writing one new key/value pair per head.

        Args:
            input_pos (torch.Tensor): Position(s) of the incoming token, ``[seq_len]`` or
                ``[num_heads, seq_len]``; the last dim must equal ``k_val.shape[2]``.
            k_val (torch.Tensor): Keys, ``[batch_size, n_heads, seq_len, head_dim]``.
            v_val (torch.Tensor): Values, same shape as ``k_val``.
            fill_idxs (torch.Tensor | int): Slot to write for each head (``[num_heads]``), or a
                single slot shared by every head (``[1]``, a 0-d tensor, or an ``int``).
            update_mask (bool, keyword): Whether to mark the written slots as valid (default True).
        Returns:
            None
        """
        assert input_pos.shape[-1] == k_val.shape[2] == v_val.shape[2]
        n_heads, head_dim = k_val.shape[1], k_val.shape[-1]

        # Normalise fill_idxs to a 1-D int64 tensor with one entry per head. A single shared
        # index used to update k/v for every head (via expand) but pos/mask only for head 0;
        # repeating it keeps all four buffers consistent. 0-d tensors and ints are accepted too.
        fill_idxs = torch.as_tensor(
            fill_idxs, dtype=torch.int64, device=self.pos.device
        ).view(-1)
        if fill_idxs.numel() == 1:
            fill_idxs = fill_idxs.repeat(n_heads)

        pos_fill_indices = fill_idxs.view(1, -1, 1)
        cache_fill_indices = fill_idxs.view(1, -1, 1, 1).expand(1, n_heads, 1, head_dim)
        input_pos = input_pos.view(1, -1, 1).expand(1, n_heads, 1).int()

        self.pos.scatter_(2, pos_fill_indices, input_pos)
        self.k_cache.scatter_(2, cache_fill_indices, k_val)
        self.v_cache.scatter_(2, cache_fill_indices, v_val)
        if kwargs.get("update_mask", True):
            self.mask.scatter_(3, fill_idxs.view(1, -1, 1, 1), True)
class KVCacheFull(KVCacheHeadConstant):
    """Uncompressed cache: each new token is written to the first empty slot."""

    def __init__(
        self, max_batch_size, n_heads, head_dim, dtype=torch.bfloat16, **kwargs
    ):
        # Nothing is evicted, so no slot needs protecting as a "global" token.
        # NB: a ``global_tokens`` kwarg, if passed, is set as an attribute by KVCache.__init__
        # and overrides this value (this class's _eviction_idx does not read it).
        self.global_tokens = 0
        super().__init__(max_batch_size, n_heads, head_dim, dtype, **kwargs)

    def _eviction_idx(self, input_pos):
        # Empty slots hold pos == -1 (the minimum); argmin returns the first such slot.
        first_free_slot = torch.argmin(self.pos[0, 0])
        return first_free_slot.view(1)
class KVCacheRandom(KVCacheHeadConstant):
    """Baseline cache that evicts a uniformly random unprotected token."""

    relevant_kwargs = [
        "max_cache_length",
        "max_seq_length",
        "cache_bits",
        "global_tokens",
        "recent_window",
    ]

    def __init__(
        self, max_batch_size, n_heads, head_dim, dtype=torch.bfloat16, **kwargs
    ):
        super().__init__(max_batch_size, n_heads, head_dim, dtype, **kwargs)

    def _token_importances(self, input_pos):
        """Random score per slot; slots holding recent positions score ``inf`` (kept)."""
        random_scores = torch.rand(self.max_cache_length, device=input_pos.device)
        is_recent = self.pos[0, 0] >= input_pos - self.recent_window
        random_scores.masked_fill_(is_recent, float("inf"))
        return random_scores
class KVCacheRecentGlobal(KVCacheHeadConstant):
    """Keeps the first ``global_tokens`` slots plus a sliding window of the most recent tokens."""

    relevant_kwargs = [
        "max_cache_length",
        "max_seq_length",
        "cache_bits",
        "global_tokens",
        # NB: "recent_window" is ignored as a relevant kwarg. It is fixed to self.max_cache_length - self.global_tokens.
    ]

    def __init__(
        self,
        max_batch_size,
        n_heads,
        head_dim,
        dtype=torch.bfloat16,
        **kwargs,
    ):
        super().__init__(max_batch_size, n_heads, head_dim, dtype, **kwargs)

    def _eviction_idx(self, input_pos):
        # Only slots after the global prefix are candidates; among them an empty slot
        # (pos == -1) is picked first, otherwise the one holding the oldest position.
        n_global = self.global_tokens
        oldest_after_prefix = self.pos[:, :, n_global:].argmin(dim=-1) + n_global
        return oldest_after_prefix.view(1)
class KVCacheL2(KVCacheHeadSpecific):
    """Head-specific cache that evicts the cached key with the largest L2 norm.

    Keys with small norms are treated as important. The recent window is always kept, and the
    base eviction logic additionally protects the global slots.
    """

    relevant_kwargs = [
        "max_cache_length",
        "max_seq_length",
        "cache_bits",
        "global_tokens",
        "recent_window",
    ]

    def __init__(
        self, max_batch_size, n_heads, head_dim, dtype=torch.bfloat16, **kwargs
    ):
        super().__init__(max_batch_size, n_heads, head_dim, dtype, **kwargs)
        # L2 norm of the key stored in each slot, per head.
        self.register_buffer(
            "key_norm",
            torch.zeros((max_batch_size, n_heads, self.max_cache_length), dtype=dtype),
        )

    def reset(self):
        super().reset()
        self.key_norm.zero_()

    def _decoding_update(self, input_pos, k_val, v_val, **kwargs):
        """Head-specific decoding update that also records the norms of the new keys."""
        slots = self._eviction_idx(input_pos)
        # A write counts as an insertion when the chosen slot was empty (pos == -1).
        previous_pos = self.pos.gather(2, slots.view(1, -1, 1)).squeeze()
        num_insertions = previous_pos.eq(-1).int().view(-1)
        self._fill(input_pos, k_val, v_val, fill_idxs=slots)
        # Remember the norm of each newly written key for later eviction decisions.
        new_norms = torch.linalg.vector_norm(k_val, ord=2, dim=-1)
        self.key_norm.scatter_(2, slots.view(1, -1, 1), new_norms)
        return num_insertions

    def _token_importances(self, input_pos):
        # Low norm => high score (kept). Subtracting from the max keeps every score >= 0,
        # and slots holding recent positions score +inf so they are never evicted.
        distance_from_max = self.key_norm.max() - self.key_norm
        is_recent = self.pos >= input_pos - self.recent_window
        return distance_from_max.masked_fill(is_recent, float("inf")).squeeze(0)

    def update_state(self, input_pos, k_val, v_val, is_prefill, attn, **kwargs):
        """Cache key norms after prefill.

        Decoding-time norms are written in _decoding_update, which knows the slot indices of
        the inserted tokens.
        NOTE(review): this reads self.k_cache; if called after update_kv with cache_bits set,
        k_cache would be in quantized form — confirm the call order.
        """
        if is_prefill:
            self.key_norm.copy_(torch.linalg.vector_norm(self.k_cache, ord=2, dim=-1))
class KVCacheHeavyHitter(KVCacheHeadSpecific):
    """Head-specific cache that evicts the token with the lowest average received attention.

    Mostly follows ScissorHands (https://arxiv.org/abs/2305.17118) and is close to other
    Heavy Hitter methods (H2O, PyramidKV, etc.). Attention received by each slot is recorded
    in ``update_state`` and averaged over the recorded history in ``_eviction_idx``.
    """

    relevant_kwargs = [
        "max_cache_length",
        "max_seq_length",
        "cache_bits",
        "global_tokens",
        "history_window_size",
        "recent_window",
        "attn_thresholding",
    ]

    def __init__(
        self,
        max_batch_size,
        n_heads,
        head_dim,
        dtype=torch.bfloat16,
        variable_length=False,
        **kwargs,
    ):
        super().__init__(
            max_batch_size,
            n_heads,
            head_dim,
            dtype,
            variable_length,
            **kwargs,
        )
        # Attention history per slot: one column per step of the history window.
        history_num_shape = (
            max_batch_size,
            n_heads,
            self.max_cache_length,
            self.history_window_size,
        )
        # How many times each slot's history has been recorded (reset when the slot is reused).
        history_denom_shape = (
            max_batch_size,
            n_heads,
            self.max_cache_length,
        )
        # attn_thresholding: store whether attention >= uniform attention; otherwise raw values.
        # With history_window_size == 1 the full history is *accumulated* in a single column,
        # so the buffer must hold a running total: an int32 count when thresholding (a bool
        # buffer cannot count, and an in-place `bool += int` fails type promotion), float64
        # sums otherwise for range.
        if self.history_window_size == 1:
            history_num_dtype = torch.int32 if self.attn_thresholding else torch.float64
        else:
            history_num_dtype = torch.bool if self.attn_thresholding else dtype
        self.register_buffer(
            "attn_history_num",
            torch.zeros(history_num_shape, dtype=history_num_dtype),
        )
        # Ideally, we could use the self.pos to track the number of times a token has been attended to
        # But any change to cache management or how self.pos is stored would break this.
        self.register_buffer(
            "attn_history_denom", torch.zeros(history_denom_shape, dtype=torch.int32)
        )
        self.register_buffer("attn_counter", torch.zeros((1,), dtype=torch.int64))

    def reset(self):
        super().reset()
        self.attn_history_num.zero_()
        self.attn_history_denom.zero_()
        self.attn_counter.zero_()

    def return_attn(self) -> bool:
        return True

    def update_state(self, input_pos, k_val, v_val, is_prefill, attn, **kwargs):
        """
        Insert the most recent attention into the history buffer.

        Args:
            input_pos: Positions of the tokens just processed; during prefill it normalises the
                query-summed attention (assumes positions count from 0 — TODO confirm).
            is_prefill: Whether ``attn`` comes from the prefill pass.
            attn: Attention weights. During prefill a 4-D map (presumably
                ``[1, n_heads, q_len, k_len]``) is averaged over queries; the result is reshaped
                to ``[1, n_heads, k_len, 1]``.

        If self.attn_thresholding = True, insert a binary indicator of whether the attention
        >= uniform attention (1 / number of cached tokens).
        """
        # Resize attn to be max cache length with zero padding if need be
        seq_len = attn.shape[-1]
        if is_prefill and attn.ndim == 4:
            # Average across queries; normalise by input_pos to only count non-zero
            # attentions because of the causal mask.
            attn = attn.squeeze(0).sum(dim=1) / (seq_len - input_pos)
        attn = attn.view(1, self.n_heads, -1, 1)
        if self.attn_thresholding:
            # View cache_cts as [1, n, 1, 1] so a per-head counter lines up with the head
            # dimension (a single shared counter broadcasts exactly as before).
            uniform_attn = 1 / self.cache_cts.view(1, -1, 1, 1)
            attn = (attn >= uniform_attn).int()
        # torch.compile doesn't support dynamic slicing so we zero-pad to full dimension.
        padding = max(self.max_cache_length - seq_len, 0)
        pad_attn = torch.zeros(
            1, self.n_heads, padding, 1, dtype=attn.dtype, device=attn.device
        )
        attn = torch.cat([attn, pad_attn], dim=2)
        history_idx = self.attn_counter % self.history_window_size
        if self.history_window_size == 1:  # Accumulate the full history in one column.
            self.attn_history_num[:, :, :, history_idx] += attn
        else:  # Circular buffer over the last history_window_size steps.
            self.attn_history_num[:, :, :, history_idx] = attn
        self.attn_history_denom += 1
        self.attn_counter += 1

    def _eviction_idx(self, input_pos):
        """Choose, per head, the slot with the lowest average attention and clear its history.

        Global tokens (pos < global_tokens) and the recent window score +inf and are never
        chosen over an unprotected slot; empty slots (pos == -1) score -inf and are always
        chosen first. This is the same convention as KVCache._eviction_idx. Previously these
        were 1.0 / 0.0, which could tie with real average attentions and evict a protected
        token (or a real token while empty slots remained).

        Returns:
            torch.Tensor: 1-D int64 tensor with one slot index per head (also for n_heads == 1).
        """
        # Identify the token with consistently "lowest" attention
        numerator = self.attn_history_num.sum(dim=-1).float()
        if self.history_window_size == 1:
            # Full history: no clamping around a fixed window.
            denominator = self.attn_history_denom.clamp_min(1)
        else:
            # Only the most recent history_window_size scores are kept, so clamp the count.
            denominator = self.attn_history_denom.clamp(1, self.history_window_size)
        avg_attn = numerator / denominator
        # Save the global & most recent tokens from being evicted
        protected = torch.logical_or(
            self.pos < self.global_tokens,
            self.pos >= input_pos - self.recent_window,
        )
        avg_attn.masked_fill_(protected, float("inf"))
        # Must come after the line above: empty slots (pos == -1) also satisfy pos < global_tokens.
        avg_attn.masked_fill_(self.pos == -1, float("-inf"))
        # view(-1) rather than squeeze(): keeps a 1-D index even when n_heads == 1.
        fill_idxs = avg_attn.argmin(dim=-1).view(-1)
        # Zero-out the attention history for these newly inserted slots
        num_fill = fill_idxs.view(1, -1, 1, 1).expand(
            1, -1, 1, self.attn_history_num.shape[-1]
        )
        denom_fill = fill_idxs.view(1, -1, 1)
        self.attn_history_num.scatter_(
            2, num_fill, torch.zeros_like(num_fill, dtype=self.attn_history_num.dtype)
        )
        self.attn_history_denom.scatter_(
            2, denom_fill, torch.zeros_like(denom_fill, dtype=torch.int32)
        )
        return fill_idxs
class KVCacheHybrid(KVCacheHeavyHitter):
# This class mostly follows the logic in FastGen (https://arxiv.org/abs/2310.01801)
# Yet, it allows for a wider set of hybrid strategies to be considered during profiling.
relevant_kwargs = [
"max_cache_length",
"max_seq_length",
"cache_bits",
"global_tokens",
"token_ids",
"min_recovery_frac",
"hybrid_strategies",
]
def __init__(
    self,
    max_batch_size,
    n_heads,
    head_dim,
    dtype=torch.bfloat16,
    **kwargs,
):
    """Set up a FastGen-style hybrid cache where each head follows one hybrid strategy.

    Args:
        max_batch_size, n_heads, head_dim, dtype: As for the base KV cache.
        **kwargs: Stored as attributes by the base constructor. Must include
            ``hybrid_strategies`` (list of dicts with a ``"strategy"`` name and, depending on
            the name, ``"recent_window"`` / ``"heavy_hitter_frac"`` fractions) and, when a
            strategy name contains "special" or "punc", ``token_ids`` with a ``"special"``
            entry (list of id sequences) and/or a ``"punctuation"`` entry (list of ids).
    """
    # Fixed heavy-hitter settings; set before super().__init__, whose heavy-hitter
    # constructor reads history_window_size and attn_thresholding.
    self.attn_thresholding = False
    self.history_window_size = 400  # Default value for ScissorHands
    self.recent_window = None  # Dummy value: recent windows are defined per attention head
    super().__init__(
        max_batch_size,
        n_heads,
        head_dim,
        dtype,
        variable_length=True,
        **kwargs,
    )
    strategy_names = [strat["strategy"] for strat in self.hybrid_strategies]
    mask_shape = (max_batch_size, n_heads, self.max_cache_length)

    self.requires_special = any("special" in name for name in strategy_names)
    if self.requires_special:
        special_ids = [torch.tensor(ids) for ids in kwargs["token_ids"]["special"]]
        self.register_buffer("special_ids", torch.nested.nested_tensor(special_ids))
        # Mask of where special ids sit in the KV cache, kept to avoid recomputing it every
        # step (and having to store all input_ids).
        self.register_buffer("special_mask", torch.zeros(mask_shape, dtype=torch.bool))
        self.register_buffer("num_special", torch.zeros((1,), dtype=torch.int))

    self.requires_punc = any("punc" in name for name in strategy_names)
    if self.requires_punc:
        # Punctuation vocabulary ids. An integer dtype (not torch.Tensor's default float)
        # matches the integer input_ids they are compared against with torch.isin.
        punc_ids = torch.tensor(kwargs["token_ids"]["punctuation"], dtype=torch.long)
        self.register_buffer("punc_ids", punc_ids)
        # Mask of where punctuation ids sit in the KV cache (avoids storing input_ids).
        self.register_buffer("punc_mask", torch.zeros(mask_shape, dtype=torch.bool))
        self.register_buffer("num_punc", torch.zeros((1,), dtype=torch.int))

    self.requires_heavy_hitter = self._init_requires_heavy_hitter()
    # NB: the per-head (max_batch_size, n_heads, 1, max_cache_length) boolean ``mask`` buffer
    # is already registered, with identical shape and dtype, by the base constructor, so it
    # is not re-registered here.
def return_attn(self):
    """Attention weights are only needed when some head uses a heavy-hitter strategy."""
    needs_attention = self.requires_heavy_hitter
    return needs_attention
def _init_requires_heavy_hitter(self):
    """Return True if any configured hybrid strategy has a heavy-hitter component."""
    strategy_names = [strat["strategy"] for strat in self.hybrid_strategies]
    return any("heavy_hitter" in name for name in strategy_names)
def _eviction_idx_for_head(
    self,
    head_idx,
    input_pos,
    recent_window,
    apply_heavy_hitter=False,
    apply_window=False,
    apply_special=False,
    apply_punc=False,
):
    """Pick the slot to overwrite for a single head among its ``cache_cts[head_idx]`` filled slots.

    Scores are the average recorded attention (``apply_heavy_hitter``) or the stored position
    (evict the oldest otherwise). Global slots and, when requested, special / punctuation /
    recent-window tokens score ``inf`` so they are never chosen over an unprotected slot.

    The heavy-hitter denominator is clamped to >= 1, as in KVCacheHeavyHitter._eviction_idx.
    Previously a freshly reset slot (history count 0) produced 0 / 0 = NaN scores.

    Returns:
        torch.Tensor: Index (per batch row) of the slot with the lowest score.
    """
    n_filled = self.cache_cts[head_idx]
    if apply_heavy_hitter:
        numerator = self.attn_history_num[:, head_idx, :n_filled].sum(dim=-1).float()
        denominator = self.attn_history_denom[:, head_idx, :n_filled]
        if self.history_window_size == 1:  # Use full history
            denominator = denominator.clamp_min(1)
        else:
            # Only the most recent history_window_size scores are recorded, so clamp the count.
            denominator = denominator.clamp(1, self.history_window_size)
        score = numerator / denominator
    else:
        # Without attention statistics, evict the oldest position.
        score = self.pos[:, head_idx, :n_filled].clone().float()

    save_mask = torch.zeros_like(score, dtype=torch.bool)
    save_mask[:, : self.global_tokens] = 1
    if apply_special:
        save_mask |= self.special_mask[:, head_idx, :n_filled]
    if apply_punc:
        save_mask |= self.punc_mask[:, head_idx, :n_filled]
    if apply_window:
        # NOTE: strict ">" (unlike ">=" in KVCacheHeavyHitter) protects recent_window - 1 cached
        # tokens; presumably the incoming token completes the window — confirm before changing.
        save_mask |= self.pos[:, head_idx, :n_filled] > input_pos - recent_window
    score.masked_fill_(save_mask, float("inf"))
    return score.argmin(dim=-1)
def _select_fill_idx(self, strategy, head_idx, input_pos, is_punc: bool = False):
    """Decide where (and whether) head ``head_idx`` stores the incoming token.

    Args:
        strategy: Index into ``self.hybrid_strategies`` for this head.
        head_idx: Head being updated.
        input_pos: Position of the incoming token.
        is_punc: Whether the incoming token is punctuation.

    Returns:
        (fill_idx, eviction_required): ``fill_idx`` is the slot to write, or ``None`` when the
        token is not stored for this head; ``eviction_required`` is True when ``fill_idx``
        overwrites an existing token instead of appending.
    """
    config = self.hybrid_strategies[strategy]
    name = config["strategy"]

    def next_free_slot():
        # Clone: cache_cts is incremented later and the returned index must not alias it.
        return min(self.max_cache_length - 1, self.cache_cts[head_idx].clone())

    # Preserved punctuation tokens, and the "full" strategy, always append.
    if ("punc" in name and is_punc) or name == "full":
        return next_free_slot(), False

    # Token budget: global tokens plus the share of every component in the strategy name.
    budget = torch.tensor(
        [self.global_tokens], dtype=torch.int, device=input_pos.device
    )
    if "special" in name:
        budget += self.num_special
    if "punc" in name:
        budget += self.num_punc
    if "window" in name:
        budget += round(config["recent_window"] * self.max_cache_length)
    if "heavy_hitter" in name:
        budget += round(config["heavy_hitter_frac"] * self.max_cache_length)

    if self.cache_cts[head_idx] < budget:
        return next_free_slot(), False

    if "heavy_hitter" in name or "window" in name:
        window_tokens = round(config.get("recent_window", 0) * self.max_cache_length)
        slot = self._eviction_idx_for_head(
            head_idx,
            input_pos,
            recent_window=window_tokens,
            apply_heavy_hitter="heavy_hitter" in name,
            apply_window="window" in name,
            apply_punc="punc" in name,
            apply_special="special" in name,
        )
        return slot, True  # Eviction Required

    # Remaining strategies only preserve special / punctuation tokens: skip this token.
    assert "punc" in name or "special" in name, f"Invalid hybrid strategy {name}"
    return None, False
def reset(self):
    """Reset the cache and all hybrid bookkeeping (special/punctuation masks and counts)."""
    super().reset()
    # NOTE(review): this assigns an attribute named ``fill`` on cache_strategies; it does not
    # call ``fill_``. cache_strategies is defined outside this view — confirm the intent.
    self.cache_strategies.fill = None  # Free up memory temporarily
    self.requires_heavy_hitter = self._init_requires_heavy_hitter()
    # These buffers exist only when some strategy needs them.
    for mask_name, count_name in (
        ("special_mask", "num_special"),
        ("punc_mask", "num_punc"),
    ):
        if hasattr(self, mask_name):
            getattr(self, mask_name).zero_()
            getattr(self, count_name).zero_()
def _decoding_update(self, input_pos, k_val, v_val, **kwargs):
input_ids = kwargs.get("input_ids")
n_heads = k_val.shape[1]
is_punc = (
torch.isin(input_ids, self.punc_ids) if hasattr(self, "punc_ids") else False
)
# If fill idx is None we place value at the back (which is truncated for attention calculation anyway)
fill_indices = torch.full(
(n_heads,),
self.max_cache_length - 1,
dtype=torch.int64,
device=k_val.device,
)
cache_ct_incr = torch.zeros_like(fill_indices)
for head_idx, strategy in enumerate(self.cache_strategies):
fill_idx, eviction_required = self._select_fill_idx(
strategy, head_idx, input_pos, is_punc=is_punc
)
if fill_idx is None:
continue
fill_indices[head_idx] = fill_idx
if eviction_required:
if self.requires_heavy_hitter:
# Reset attention history since we've inserted a new token
self.attn_history_num[:, head_idx, fill_idx, :].fill_(0)
self.attn_history_denom[:, head_idx, fill_idx].fill_(0)
else:
# Increment cache_ct_incr for heads that have grown (no eviction)
cache_ct_incr[head_idx] = 1
# We can't use all fill indices to bulk assign mask because some fill_indices are dummies (self.max_cache_length - 1)