-
Notifications
You must be signed in to change notification settings - Fork 405
/
Copy pathtriton_merge_attn_states.py
186 lines (168 loc) · 5.88 KB
/
triton_merge_attn_states.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
"""Adapted from: https://github.com/vllm-project/vllm/blob/main/vllm/attention/ops/triton_merge_attn_states.py"""
from typing import Optional
import torch
import triton
import triton.language as tl
# Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
# can be used to combine partial attention results (in the split-KV case)
def merge_attn_states_triton(
    output: torch.Tensor,
    prefix_output: torch.Tensor,
    prefix_lse: torch.Tensor,
    suffix_output: torch.Tensor,
    suffix_lse: torch.Tensor,
    output_lse: Optional[torch.Tensor] = None,
) -> None:
    """Merge two partial attention results into ``output`` in place.

    Launches ``merge_attn_states_kernel`` with one program per
    (token, head) pair. The grid and the head size are taken from the first
    three dimensions of ``output``.

    Args:
        output: Destination tensor, indexed as [num_tokens, num_heads,
            head_size]. Written by the kernel.
        prefix_output: First partial attention output, same layout as
            ``output``.
        prefix_lse: Log-sum-exp of the first partial result, indexed by the
            kernel as [num_heads, num_tokens].
        suffix_output: Second partial attention output, same layout as
            ``output``.
        suffix_lse: Log-sum-exp of the second partial result, indexed by the
            kernel as [num_heads, num_tokens].
        output_lse: Optional destination for the merged log-sum-exp,
            indexed as [num_heads, num_tokens]. Skipped when ``None``.

    Note:
        The kernel computes element addresses from dense row-major offsets
        and does not receive strides, so the tensors are expected to be
        contiguous. This function does not check that.
    """
    n_tokens, n_heads, head_dim = (
        output.shape[0],
        output.shape[1],
        output.shape[2],
    )
    want_lse = output_lse is not None

    # One program instance per (token, head) pair.
    launch_grid = (n_tokens, n_heads)
    merge_attn_states_kernel[launch_grid](
        output,
        output_lse,
        prefix_output,
        prefix_lse,
        suffix_output,
        suffix_lse,
        HEAD_SIZE=head_dim,
        # tl.arange needs a power-of-two extent; the kernel masks the tail.
        PADDED_HEAD_SIZE=triton.next_power_of_2(head_dim),
        OUTPUT_LSE=want_lse,
    )
@triton.jit
def merge_attn_states_kernel(
    output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    output_lse,  # [NUM_HEADS, NUM_TOKENS]
    prefix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    prefix_lse,  # [NUM_HEADS, NUM_TOKENS]
    suffix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    suffix_lse,  # [NUM_HEADS, NUM_TOKENS]
    HEAD_SIZE: tl.constexpr,
    PADDED_HEAD_SIZE: tl.constexpr,
    OUTPUT_LSE: tl.constexpr,
):
    # Merges the prefix and suffix partial attention results of a single
    # (token, head) pair. Expected launch grid: (NUM_TOKENS, NUM_HEADS).
    # All tensors are addressed with dense row-major offsets (no strides).
    #
    # HEAD_SIZE:        number of valid elements per head.
    # PADDED_HEAD_SIZE: power-of-two extent used for tl.arange; lanes at or
    #                   beyond HEAD_SIZE are masked out.
    # OUTPUT_LSE:       when true, the merged LSE is written to output_lse.

    # Widen the token index to 64 bits so that the element offset
    # token_idx * num_heads * HEAD_SIZE cannot wrap around in 32-bit
    # arithmetic for tensors with 2**31 or more elements.
    token_idx = tl.program_id(0).to(tl.int64)
    num_tokens = tl.num_programs(0)
    head_idx = tl.program_id(1)
    num_heads = tl.num_programs(1)

    lse_offset = head_idx * num_tokens + token_idx
    p_lse = tl.load(prefix_lse + lse_offset)
    s_lse = tl.load(suffix_lse + lse_offset)

    # FA2 and FA3 have different behavior for when the sum-exp is 0, this namely
    # arises with 0 len seqlens. FA3 returns -inf here while FA2 returns inf.
    # If we see an inf assume FA2 and convert inf to -inf for consistency
    # and correctness. Inf generally doesn't make sense in this context outside
    # of undefined-behavior/FA2-case, so this is considered a safe assumption.
    p_lse = float("-inf") if p_lse == float("inf") else p_lse
    s_lse = float("-inf") if s_lse == float("inf") else s_lse

    max_lse = tl.maximum(p_lse, s_lse)
    # When both partial results are empty (both LSEs are -inf), subtracting
    # max_lse would compute (-inf) - (-inf) = NaN and poison the output.
    # Shift by 0 in that case and fall back to the prefix output below.
    both_empty = max_lse == float("-inf")
    shift = tl.where(both_empty, 0.0, max_lse)
    p_lse = p_lse - shift
    s_lse = s_lse - shift
    out_se = tl.exp(p_lse) + tl.exp(s_lse)

    if OUTPUT_LSE:
        # In the both-empty case out_se is 0, so this evaluates to -inf.
        out_lse = tl.log(out_se) + shift
        tl.store(output_lse + lse_offset, out_lse)

    head_arange = tl.arange(0, PADDED_HEAD_SIZE)
    head_mask = head_arange < HEAD_SIZE
    out_offset = (
        token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + head_arange
    )
    p_out = tl.load(prefix_output + out_offset, mask=head_mask)
    s_out = tl.load(suffix_output + out_offset, mask=head_mask)

    # NOTE(woosuk): Be careful with the numerical stability.
    # We should compute the scale first, and then multiply it with the output.
    # Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse) directly.
    p_scale = tl.exp(p_lse) / out_se
    s_scale = tl.exp(s_lse) / out_se
    merged = p_out * p_scale + s_out * s_scale
    # In the both-empty case the scales are 0 / 0 = NaN; emit the prefix
    # output unchanged instead of the NaN-valued weighted sum.
    out = tl.where(both_empty, p_out, merged)
    tl.store(output + out_offset, out, mask=head_mask)
@triton.autotune(
    configs=[
        triton.Config({"num_warps": warps}, num_warps=warps)
        for warps in (2, 4, 8, 16, 32)
    ],
    key=["PADDED_HEAD_SIZE", "OUTPUT_LSE"],
    use_cuda_graph=False,
)
@triton.jit
def merge_attn_states_kernel_opt(
    output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    output_lse,  # [NUM_HEADS, NUM_TOKENS]
    prefix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    prefix_lse,  # [NUM_HEADS, NUM_TOKENS]
    suffix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    suffix_lse,  # [NUM_HEADS, NUM_TOKENS]
    HEAD_SIZE: tl.constexpr,
    PADDED_HEAD_SIZE: tl.constexpr,
    OUTPUT_LSE: tl.constexpr,
):
    # Autotuned variant of the merge kernel: merges the prefix and suffix
    # partial attention results of a single (token, head) pair. Expected
    # launch grid: (NUM_TOKENS, NUM_HEADS). It differs from the plain kernel
    # by loading inputs with the ".cg" cache modifier and by computing each
    # exponential only once.
    #
    # HEAD_SIZE:        number of valid elements per head.
    # PADDED_HEAD_SIZE: power-of-two extent used for tl.arange; lanes at or
    #                   beyond HEAD_SIZE are masked out.
    # OUTPUT_LSE:       when true, the merged LSE is written to output_lse.

    # Widen the token index to 64 bits so that the element offset
    # token_idx * num_heads * HEAD_SIZE cannot wrap around in 32-bit
    # arithmetic for tensors with 2**31 or more elements.
    token_idx = tl.program_id(0).to(tl.int64)
    num_tokens = tl.num_programs(0)
    head_idx = tl.program_id(1)
    num_heads = tl.num_programs(1)

    lse_offset = head_idx * num_tokens + token_idx
    p_lse = tl.load(prefix_lse + lse_offset, cache_modifier=".cg")
    s_lse = tl.load(suffix_lse + lse_offset, cache_modifier=".cg")

    # A +inf LSE is treated as the "empty" marker and normalised to -inf,
    # so that both conventions for a zero sum-exp are handled the same way.
    p_lse = float("-inf") if p_lse == float("inf") else p_lse
    s_lse = float("-inf") if s_lse == float("inf") else s_lse

    max_lse = tl.maximum(p_lse, s_lse)
    # When both partial results are empty (both LSEs are -inf), subtracting
    # max_lse would compute (-inf) - (-inf) = NaN and poison the output.
    # Shift by 0 in that case and fall back to the prefix output below.
    both_empty = max_lse == float("-inf")
    shift = tl.where(both_empty, 0.0, max_lse)
    p_lse = p_lse - shift
    s_lse = s_lse - shift
    # The exponentials are reused for the scale factors further down.
    p_se = tl.exp(p_lse)
    s_se = tl.exp(s_lse)
    out_se = p_se + s_se

    if OUTPUT_LSE:
        # In the both-empty case out_se is 0, so this evaluates to -inf.
        out_lse = tl.log(out_se) + shift
        tl.store(output_lse + lse_offset, out_lse)

    head_arange = tl.arange(0, PADDED_HEAD_SIZE)
    head_mask = head_arange < HEAD_SIZE
    out_offset = (
        token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + head_arange
    )
    p_out = tl.load(
        prefix_output + out_offset,
        mask=head_mask,
        cache_modifier=".cg",
    )
    s_out = tl.load(
        suffix_output + out_offset,
        mask=head_mask,
        cache_modifier=".cg",
    )

    # NOTE(woosuk): Be careful with the numerical stability.
    # We should compute the scale first, and then multiply it with the output.
    # Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse) directly.
    p_scale = p_se / out_se
    s_scale = s_se / out_se
    merged = p_out * p_scale + s_out * s_scale
    # In the both-empty case the scales are 0 / 0 = NaN; emit the prefix
    # output unchanged instead of the NaN-valued weighted sum.
    out = tl.where(both_empty, p_out, merged)
    tl.store(output + out_offset, out, mask=head_mask)