Commit 116d97d

Authored by happierpig
feat: add functional per-head FP8 quantization for FA3 (#1033)
This PR adds FP8 support to FA3 to speed up compute-bound prefill kernels. It follows up on #869.

## 1. Bug fixes

- Fixed a deadlock, illegal memory accesses, and wrong results for varied `num_heads` and `seq_len`.
- Covered by [unit tests](https://github.com/happierpig/flashinfer-ai/blob/082cf2402fe08332a147f447ef54b75332a20f57/tests/test_hopper_fp8_attention.py).

## 2. New features

- Enabled the FP8 in-kernel transpose logic in `mainloop_sparse.cuh`.
- FP8 now works in:
  - `BatchPrefillWithPagedKVCache`
  - `BlockSparseAttentionWrapper`: supports _**sparse**_ and _**quantized**_ attention

## 3. Python JIT interface

- Exposed kernels to Python:
  - `BlockSparseAttentionWrapper`
  - `single_prefill_with_kv_cache`
- Migrated tests and benchmarks to Python scripts:
  - [tests/test_hopper_fp8_attention.py](https://github.com/happierpig/flashinfer-ai/blob/082cf2402fe08332a147f447ef54b75332a20f57/tests/test_hopper_fp8_attention.py)
  - [benchmarks/bench_hopper_fp8_attention.py](https://github.com/happierpig/flashinfer-ai/blob/082cf2402fe08332a147f447ef54b75332a20f57/benchmarks/bench_hopper_fp8_attention.py)

**Note:** Performance is on par with #869; further tuning is still needed.

cc @yzh119

Co-authored-by: happierpig <[email protected]>
1 parent 3f76969 commit 116d97d

23 files changed: +1417 -1691 lines
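Before the per-file diffs, here is a minimal usage sketch of the FP8 FA3 path now exposed to Python. It mirrors the call made in the benchmark added by this PR (`single_prefill_with_kv_cache_return_lse` with `backend="fa3"` and `o_dtype=torch.half`); the shapes and the cast from random half-precision data are illustrative only.

```python
import torch
import flashinfer

seq_len, num_heads, head_dim = 4096, 32, 128

# FP8 (e4m3) query/key/value; a real workload would quantize per head
# instead of casting random half-precision data as done here for brevity.
q = torch.randn(seq_len, num_heads, head_dim, dtype=torch.half, device="cuda").to(torch.float8_e4m3fn)
k = torch.randn(seq_len, num_heads, head_dim, dtype=torch.half, device="cuda").to(torch.float8_e4m3fn)
v = torch.randn(seq_len, num_heads, head_dim, dtype=torch.half, device="cuda").to(torch.float8_e4m3fn)

# FA3 (SM90) backend with FP8 inputs and a half-precision output.
o, lse = flashinfer.single_prefill_with_kv_cache_return_lse(
    q, k, v, causal=True, backend="fa3", o_dtype=torch.half
)
print(o.shape, o.dtype)  # (4096, 32, 128) torch.float16
```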
Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
import torch
import triton

import flashinfer


def bench_single_prefill(seq_len, num_heads, causal, head_dim):
    num_qo_heads = num_kv_heads = num_heads
    q = torch.randn(seq_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda")
    k = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda")
    v = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda")

    sm80_ms, sm90_ms = (
        triton.testing.do_bench(
            lambda: flashinfer.single_prefill_with_kv_cache_return_lse(
                q, k, v, causal=causal, backend=backend
            ),
            warmup=100,
            rep=1000,
        )
        for backend in ["fa2", "fa3"]
    )

    q = torch.randn(
        seq_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda"
    ).to(dtype=torch.float8_e4m3fn)
    k = torch.randn(
        seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda"
    ).to(dtype=torch.float8_e4m3fn)
    v = torch.randn(
        seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda"
    ).to(dtype=torch.float8_e4m3fn)

    fp8_sm90_ms = triton.testing.do_bench(
        lambda: flashinfer.single_prefill_with_kv_cache_return_lse(
            q, k, v, causal=causal, backend="fa3", o_dtype=torch.half
        ),
        warmup=100,
        rep=1000,
    )

    def flops(ms):
        if causal:
            return seq_len * seq_len * num_qo_heads * head_dim * 2 / ms / 1e9
        else:
            return seq_len * seq_len * num_qo_heads * head_dim * 4 / ms / 1e9

    print(
        f"bench_single_prefill (seq_len={seq_len}, num_heads={num_heads}, causal={causal}, head_dim={head_dim}), fa2-template: {flops(sm80_ms):.3f} TFLOPs/s, fa3-template: {flops(sm90_ms):.3f} TFLOPs/s, fa3-fp8: {flops(fp8_sm90_ms):.3f} TFLOPs/s"
    )


if __name__ == "__main__":
    for seq_len in [4096, 8192, 16384]:
        for num_heads in [24, 32]:
            for causal in [True, False]:
                for head_dim in [64, 128, 256]:
                    bench_single_prefill(seq_len, num_heads, causal, head_dim)
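A note on the accounting in `flops()`: attention runs two matmuls per head (QK^T and PV), each costing roughly 2 * seq_len^2 * head_dim FLOPs, hence the factor 4 across `num_qo_heads` in the non-causal case; a causal mask skips about half of the score matrix, so the causal case uses a factor 2. Dividing the FLOP count by the runtime in milliseconds and by 1e9 yields TFLOPs/s.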
Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
#include <flashinfer/attention/hopper/quantization/prefill_sm90.cuh>
#include "batch_prefill_sm90_config.inc"

namespace flashinfer {

{% for same_scheduler_for_all_heads in ["true", "false"] %}
template cudaError_t BatchFP8PrefillWithPagedKVCacheDispatched
    <{{ head_dim_qk }},
     {{ mask_mode }},
     /*USE_SLIDING_WINDOW=*/{{ use_sliding_window }},
     /*SAME_SCHEDULER_FOR_ALL_HEADS=*/{{ same_scheduler_for_all_heads }},
     {{ variant_name }}, PagedParams>(PagedParams& params, cudaStream_t stream);
{% endfor %}

};  // namespace flashinfer
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
// TODO: Not implemented yet

csrc/batch_prefill_fp8_sm90.cu

Lines changed: 179 additions & 0 deletions
@@ -0,0 +1,179 @@
/*
 * Copyright (c) 2023 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <flashinfer/attention/mask.cuh>
#include <flashinfer/attention/scheduler.cuh>
#include <flashinfer/layout.cuh>
#include <flashinfer/math.cuh>
#include <optional>

#include "batch_prefill_sm90_config.inc"
#include "pytorch_conversion_utils.h"
#include "pytorch_extension_utils.h"

namespace flashinfer {

template <uint32_t HEAD_DIM, MaskMode MASK_MODE, bool LEFT_SLIDING_WINDOW,
          bool SAME_SCHEDULE_FOR_ALL_HEADS, typename AttentionVariant, typename Params>
cudaError_t BatchFP8PrefillWithPagedKVCacheDispatched(Params& params, cudaStream_t stream);

}  // namespace flashinfer

using namespace flashinfer;

at::Tensor BatchPrefillWithKVCacheSM90Plan(
    at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
    at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr,
    at::Tensor kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads,
    int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk,
    int64_t head_dim_vo, bool causal) {
  size_t float_workspace_size_in_bytes =
      float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
  size_t int_workspace_size_in_bytes =
      int_workspace_buffer.size(0) * int_workspace_buffer.element_size();

  flashinfer::PrefillPlanSM90Info plan_info;

  const c10::cuda::OptionalCUDAGuard device_guard(float_workspace_buffer.device());
  cudaStream_t stream = c10::cuda::getCurrentCUDAStream();

  cudaError_t status =
      PrefillSM90Plan(float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes,
                      int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(),
                      int_workspace_size_in_bytes, plan_info, qo_indptr.data_ptr<IdType>(),
                      kv_indptr.data_ptr<IdType>(), kv_len_arr.data_ptr<IdType>(), total_num_rows,
                      batch_size, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo, page_size,
                      causal, enable_cuda_graph, /*sizeof_dtype_o=*/2, stream);

  TORCH_CHECK(status == cudaSuccess,
              "PrefillSM90Plan failed with error: ", cudaGetErrorString(status));

  return vec_to_tensor(plan_info.ToVector());
}

void BatchPrefillWithRaggedKVCacheSM90Run(at::Tensor float_workspace_buffer,
                                          at::Tensor int_workspace_buffer, at::Tensor plan_info_vec,
                                          at::Tensor q, at::Tensor k, at::Tensor v,
                                          at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor o,
                                          std::optional<at::Tensor> maybe_lse,
                                          int64_t mask_mode_code, int64_t layout,
                                          int64_t window_left ADDITIONAL_FUNC_PARAMS) {
  return;  // TODO: Implement this function
}

void BatchPrefillWithPagedKVCacheSM90Run(
    at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, at::Tensor plan_info_vec,
    at::Tensor q, at::Tensor paged_k_cache, at::Tensor paged_v_cache, at::Tensor qo_indptr,
    at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len,
    at::Tensor o, std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code, int64_t layout,
    int64_t window_left ADDITIONAL_FUNC_PARAMS) {
  PrefillPlanSM90Info plan_info;
  plan_info.FromVector(tensor_to_vec(plan_info_vec));

  if (maybe_lse) {
    const auto& lse = *maybe_lse;
    TORCH_CHECK(lse.size(0) == q.size(0), lse.size(0), q.size(0));
    TORCH_CHECK(lse.size(1) == q.size(1), lse.size(1), q.size(1));
  }
  QKVLayout kv_layout = static_cast<QKVLayout>(layout);
  int64_t num_kv_heads, page_size;
  int64_t head_dim_qk = q.size(2);
  int64_t head_dim_vo = paged_v_cache.size(3);
  if (kv_layout == QKVLayout::kHND) {
    num_kv_heads = paged_k_cache.size(1);
    page_size = paged_k_cache.size(2);
  } else {
    page_size = paged_k_cache.size(1);
    num_kv_heads = paged_k_cache.size(2);
  }

  void* float_buffer_ptr = float_workspace_buffer.data_ptr();
  void* int_buffer_ptr = int_workspace_buffer.data_ptr();

  auto q_scalar_type = q.scalar_type();
  auto kv_scalar_type = paged_k_cache.scalar_type();

  const c10::cuda::OptionalCUDAGuard device_guard(float_workspace_buffer.device());
  cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
  const MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code);
  bool use_swa = window_left != -1;

  DISPATCH_context(
      DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, USE_SLIDING_WINDOW,
      USE_LOGITS_SOFT_CAP, AttentionVariant, RaggedParams, PagedParams, [&] {
        PagedParams params;

        params.q_ptr = static_cast<DTypeQ*>(q.data_ptr());
        params.k_ptr = static_cast<DTypeKV*>(paged_k_cache.data_ptr());
        params.v_ptr = static_cast<DTypeKV*>(paged_v_cache.data_ptr());
        params.o_ptr = static_cast<DTypeO*>(o.data_ptr());
        params.lse_ptr = maybe_lse ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr;
        params.q_stride_n = q.stride(0);
        params.q_stride_h = q.stride(1);
        params.o_stride_n = o.stride(0);
        params.o_stride_h = o.stride(1);
        if (kv_layout == QKVLayout::kNHD) {
          // (num_pages, page_size, num_heads, head_dim)
          params.k_stride_n = paged_k_cache.stride(1);
          params.k_stride_h = paged_k_cache.stride(2);
          params.v_stride_n = paged_v_cache.stride(1);
          params.v_stride_h = paged_v_cache.stride(2);
        } else {
          // (num_pages, num_heads, page_size, head_dim)
          params.k_stride_h = paged_k_cache.stride(1);
          params.k_stride_n = paged_k_cache.stride(2);
          params.v_stride_h = paged_v_cache.stride(1);
          params.v_stride_n = paged_v_cache.stride(2);
        }
        params.nnz_qo = q.size(0);
        params.num_qo_heads = q.size(1);
        params.num_kv_heads = num_kv_heads;
        params.group_size = params.num_qo_heads / num_kv_heads;
        params.page_size = page_size;
        params.window_left = window_left;
        params.causal = mask_mode_code == 1;
        params.qo_tile_indices =
            GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.qo_tile_indices_offset);
        params.qo_indptr = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.qo_indptr_offset);
        params.kv_indptr = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.kv_indptr_offset);
        params.qo_lens = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.qo_len_offset);
        params.kv_lens = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.kv_len_offset);
        params.head_indices =
            GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.head_indices_offset);
        params.work_indptr =
            GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.work_indptr_offset);
        params.kv_indices = static_cast<IdType*>(paged_kv_indices.data_ptr());

        ADDITIONAL_PARAMS_SETTER

        // Differing head_dim_qk / head_dim_vo are not supported yet.
        static_assert(HEAD_DIM_QK == HEAD_DIM_VO, "head_dim_qk and head_dim_vo should be the same");
        // Currently only the same quantization precision for Q and KV is supported.
        static_assert(std::is_same_v<DTypeQ, DTypeKV>);

        bool same_schedule_for_all_heads = plan_info.same_schedule_for_all_heads;
        DISPATCH_BOOL(same_schedule_for_all_heads, SAME_SCHEDULER_FOR_ALL_HEADS, [&] {
          cudaError_t status =
              BatchFP8PrefillWithPagedKVCacheDispatched<HEAD_DIM_QK, MASK_MODE, USE_SLIDING_WINDOW,
                                                        SAME_SCHEDULER_FOR_ALL_HEADS,
                                                        AttentionVariant>(params, stream);
          TORCH_CHECK(status == cudaSuccess,
                      "BatchPrefillWithPagedKVCacheSM90Run failed with error: ",
                      cudaGetErrorString(status));
          return true;
        });
      });
}

csrc/single_prefill_fp8_sm90.cu

Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
/*
 * Copyright (c) 2024 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <flashinfer/attention/mask.cuh>
#include <flashinfer/layout.cuh>
#include <flashinfer/math.cuh>
#include <optional>

#include "pytorch_extension_utils.h"
#include "single_prefill_sm90_config.inc"

namespace flashinfer {

template <uint32_t HEAD_DIM, MaskMode MASK_MODE, bool LEFT_SLIDING_WINDOW,
          typename AttentionVariant, typename Params>
cudaError_t SingleFP8PrefillWithKVCacheDispatched(Params& params, cudaStream_t stream);

}  // namespace flashinfer

using namespace flashinfer;

void single_prefill_with_kv_cache_sm90(at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor tmp,
                                       at::Tensor o, std::optional<at::Tensor> maybe_lse,
                                       int64_t mask_mode_code, int64_t layout,
                                       int64_t window_left ADDITIONAL_FUNC_PARAMS) {
  unsigned int head_dim_qk = q.size(2);
  unsigned int head_dim_vo = v.size(2);
  unsigned int num_qo_heads = q.size(1);
  unsigned int qo_len = q.size(0);

  auto q_scalar_type = q.scalar_type();
  auto kv_scalar_type = k.scalar_type();

  QKVLayout kv_layout = static_cast<QKVLayout>(layout);
  const c10::cuda::OptionalCUDAGuard device_guard(q.device());
  const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
  const MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code);

  DISPATCH_context(
      DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, USE_SLIDING_WINDOW,
      USE_LOGITS_SOFT_CAP, AttentionVariant, Params, [&] {
        Params params;
        params.q_ptr = static_cast<DTypeQ*>(q.data_ptr());
        params.k_ptr = static_cast<DTypeKV*>(k.data_ptr());
        params.v_ptr = static_cast<DTypeKV*>(v.data_ptr());
        params.o_ptr = static_cast<DTypeO*>(o.data_ptr());
        params.lse_ptr = maybe_lse ? (static_cast<float*>(maybe_lse->data_ptr())) : nullptr;
        params.q_stride_n = q.stride(0);
        params.q_stride_h = q.stride(1);
        params.o_stride_n = o.stride(0);
        params.o_stride_h = o.stride(1);
        if (kv_layout == QKVLayout::kNHD) {
          params.k_stride_n = k.stride(0);
          params.k_stride_h = k.stride(1);
          params.v_stride_n = v.stride(0);
          params.v_stride_h = v.stride(1);
        } else {
          params.k_stride_h = k.stride(0);
          params.k_stride_n = k.stride(1);
          params.v_stride_h = v.stride(0);
          params.v_stride_n = v.stride(1);
        }
        params.qo_len = q.size(0);
        params.kv_len = k.size(0);
        params.num_qo_heads = q.size(1);
        params.num_kv_heads = k.size(1);
        params.causal = mask_mode == MaskMode::kCausal;
        params.group_size = params.num_qo_heads / params.num_kv_heads;

        // Note(Yilong): this should be checked on the Python side.
        // Only window_left == 0 is supported for now.
        params.window_left = window_left;

        // Note(Yilong): all quantization parameters are set in additional_params.
        ADDITIONAL_PARAMS_SETTER

        // Differing head_dim_qk / head_dim_vo are not supported yet.
        static_assert(HEAD_DIM_QK == HEAD_DIM_VO, "head_dim_qk and head_dim_vo should be the same");
        // Currently only the same quantization precision for Q and KV is supported.
        static_assert(std::is_same_v<DTypeQ, DTypeKV>);

        cudaError_t status =
            SingleFP8PrefillWithKVCacheDispatched<HEAD_DIM_QK, MASK_MODE, USE_SLIDING_WINDOW,
                                                  AttentionVariant>(params, stream);
        TORCH_CHECK(status == cudaSuccess, "single_prefill_with_kv_cache_sm90 failed with error: " +
                                               std::string(cudaGetErrorString(status)));
        return true;
      });
}
Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
#include <flashinfer/attention/hopper/quantization/prefill_sm90.cuh>
#include "single_prefill_sm90_config.inc"

using namespace flashinfer;

namespace flashinfer {

template cudaError_t SingleFP8PrefillWithKVCacheDispatched
    <{{ head_dim_qk }}, {{ mask_mode }}, /*USE_SLIDING_WINDOW=*/{{ use_sliding_window }}, {{ variant_name }}, Params>(
    Params& params, cudaStream_t stream);

};

flashinfer/decode.py

Lines changed: 6 additions & 0 deletions
@@ -530,6 +530,9 @@ def single_decode_with_kv_cache(
         _get_cache_alibi_slopes_buf(num_qo_heads, q.device),
         logits_soft_cap,
         sm_scale,
+        None,  # scale_q, not supported yet
+        None,  # scale_k
+        None,  # scale_v
         rope_scale,
         rope_theta,
     )
@@ -1169,6 +1172,9 @@ def run(
         _get_cache_alibi_slopes_buf(q.shape[1], q.device),
         logits_soft_cap,
         sm_scale,
+        None,  # scale_q, not supported yet
+        None,  # scale_k
+        None,  # scale_v
         rope_scale,
         rope_theta,
     ]
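The `scale_q` / `scale_k` / `scale_v` slots threaded through above are placeholders for per-head quantization scales. As a minimal, hypothetical sketch of what per-head FP8 (e4m3) quantization looks like on the tensor side — the helper name and the symmetric absmax scheme below are illustrative assumptions, not the exact recipe used by this PR:

```python
import torch


def per_head_fp8_quantize(x: torch.Tensor):
    """Symmetric per-head quantization of a (seq_len, num_heads, head_dim) tensor to e4m3."""
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3
    scale = x.abs().amax(dim=(0, 2)).float() / fp8_max  # one scale per head
    scale = scale.clamp(min=1e-12)  # guard against all-zero heads
    x_fp8 = (x.float() / scale.view(1, -1, 1)).to(torch.float8_e4m3fn)
    return x_fp8, scale  # dequantize via x_fp8.float() * scale.view(1, -1, 1)


# Example: quantize a half-precision query tensor before handing it to the FP8 kernels.
q = torch.randn(4096, 32, 128, dtype=torch.half, device="cuda")
q_fp8, scale_q = per_head_fp8_quantize(q)
```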
