feat: add functional per-head FP8 quantization for FA3 #1033

Merged 14 commits on Apr 29, 2025

58 changes: 58 additions & 0 deletions benchmarks/bench_hopper_fp8_attention.py
@@ -0,0 +1,58 @@
import torch
import triton

import flashinfer


def bench_single_prefill(seq_len, num_heads, causal, head_dim):
num_qo_heads = num_kv_heads = num_heads
q = torch.randn(seq_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda")
k = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda")
v = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda")

sm80_ms, sm90_ms = (
triton.testing.do_bench(
lambda: flashinfer.single_prefill_with_kv_cache_return_lse(
q, k, v, causal=causal, backend=backend
),
warmup=100,
rep=1000,
)
for backend in ["fa2", "fa3"]
)

q = torch.randn(
seq_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda"
).to(dtype=torch.float8_e4m3fn)
k = torch.randn(
seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda"
).to(dtype=torch.float8_e4m3fn)
v = torch.randn(
seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda"
).to(dtype=torch.float8_e4m3fn)

fp8_sm90_ms = triton.testing.do_bench(
lambda: flashinfer.single_prefill_with_kv_cache_return_lse(
q, k, v, causal=causal, backend="fa3", o_dtype=torch.half
),
warmup=100,
rep=1000,
)

def flops(ms):
if causal:
return seq_len * seq_len * num_qo_heads * head_dim * 2 / ms / 1e9
else:
return seq_len * seq_len * num_qo_heads * head_dim * 4 / ms / 1e9

print(
f"bench_single_prefill (seq_len={seq_len}, num_heads={num_heads}, causal={causal}, head_dim={head_dim}), fa2-template: {flops(sm80_ms):.3f} TFLOPs/s, fa3-template: {flops(sm90_ms):.3f} TFLOPs/s, fa3-fp8: {flops(fp8_sm90_ms):.3f} TFLOPs/s"
)


if __name__ == "__main__":
for seq_len in [4096, 8192, 16384]:
for num_heads in [24, 32]:
for causal in [True, False]:
for head_dim in [64, 128, 256]:
bench_single_prefill(seq_len, num_heads, causal, head_dim)
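
A note on the benchmark above: it casts random half-precision tensors straight to `float8_e4m3fn` without any scaling, which is fine for measuring throughput but not how the per-head quantization this PR targets would be used. Below is a minimal sketch of computing one scale per head before the cast; the helper name `quantize_per_head` and the e4m3 choice are assumptions for illustration, not part of the diff.

```python
import torch


def quantize_per_head(x: torch.Tensor, dtype=torch.float8_e4m3fn):
    # x: (seq_len, num_heads, head_dim); compute one symmetric scale per head (dim=1)
    fp8_max = torch.finfo(dtype).max  # 448 for e4m3
    scale = x.abs().amax(dim=(0, 2)).clamp(min=1e-6) / fp8_max
    x_fp8 = (x / scale.view(1, -1, 1)).to(dtype)
    return x_fp8, scale  # scale has shape (num_heads,) and can serve as scale_q/k/v
```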
15 changes: 15 additions & 0 deletions csrc/batch_prefill_fp8_paged_sm90_kernel_inst.jinja
@@ -0,0 +1,15 @@
#include <flashinfer/attention/hopper/quantization/prefill_sm90.cuh>
#include "batch_prefill_sm90_config.inc"

namespace flashinfer {

{% for same_scheduler_for_all_heads in ["true", "false"] %}
template cudaError_t BatchFP8PrefillWithPagedKVCacheDispatched
<{{ head_dim_qk }},
{{ mask_mode }},
/*USE_SLIDING_WINDOW=*/{{ use_sliding_window }},
/*SAME_SCHEDULER_FOR_ALL_HEADS=*/{{ same_scheduler_for_all_heads }},
{{ variant_name }}, PagedParams>(PagedParams& params, cudaStream_t stream);
{% endfor %}

}; // namespace flashinfer
1 change: 1 addition & 0 deletions csrc/batch_prefill_fp8_ragged_sm90_kernel_inst.jinja
@@ -0,0 +1 @@
// TODO: Not implemented yet
179 changes: 179 additions & 0 deletions csrc/batch_prefill_fp8_sm90.cu
@@ -0,0 +1,179 @@
/*
* Copyright (c) 2023 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <flashinfer/attention/mask.cuh>
#include <flashinfer/attention/scheduler.cuh>
#include <flashinfer/layout.cuh>
#include <flashinfer/math.cuh>
#include <optional>

#include "batch_prefill_sm90_config.inc"
#include "pytorch_conversion_utils.h"
#include "pytorch_extension_utils.h"

namespace flashinfer {

template <uint32_t HEAD_DIM, MaskMode MASK_MODE, bool LEFT_SLIDING_WINDOW,
bool SAME_SCHEDULE_FOR_ALL_HEADS, typename AttentionVariant, typename Params>
cudaError_t BatchFP8PrefillWithPagedKVCacheDispatched(Params& params, cudaStream_t stream);

} // namespace flashinfer

using namespace flashinfer;

at::Tensor BatchPrefillWithKVCacheSM90Plan(
at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr,
at::Tensor kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads,
int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk,
int64_t head_dim_vo, bool causal) {
size_t float_workspace_size_in_bytes =
float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
size_t int_workspace_size_in_bytes =
int_workspace_buffer.size(0) * int_workspace_buffer.element_size();

flashinfer::PrefillPlanSM90Info plan_info;

const c10::cuda::OptionalCUDAGuard device_guard(float_workspace_buffer.device());
cudaStream_t stream = c10::cuda::getCurrentCUDAStream();

cudaError_t status =
PrefillSM90Plan(float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes,
int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(),
int_workspace_size_in_bytes, plan_info, qo_indptr.data_ptr<IdType>(),
kv_indptr.data_ptr<IdType>(), kv_len_arr.data_ptr<IdType>(), total_num_rows,
batch_size, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo, page_size,
causal, enable_cuda_graph, /*sizeof_dtype_o=*/2, stream);

TORCH_CHECK(status == cudaSuccess,
"PrefillSM90Plan failed with error: ", cudaGetErrorString(status));

return vec_to_tensor(plan_info.ToVector());
}

void BatchPrefillWithRaggedKVCacheSM90Run(at::Tensor float_workspace_buffer,
at::Tensor int_workspace_buffer, at::Tensor plan_info_vec,
at::Tensor q, at::Tensor k, at::Tensor v,
at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor o,
std::optional<at::Tensor> maybe_lse,
int64_t mask_mode_code, int64_t layout,
int64_t window_left ADDITIONAL_FUNC_PARAMS) {
return; // TODO: Implement this function
}

void BatchPrefillWithPagedKVCacheSM90Run(
at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, at::Tensor plan_info_vec,
at::Tensor q, at::Tensor paged_k_cache, at::Tensor paged_v_cache, at::Tensor qo_indptr,
at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len,
at::Tensor o, std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code, int64_t layout,
int64_t window_left ADDITIONAL_FUNC_PARAMS) {
PrefillPlanSM90Info plan_info;
plan_info.FromVector(tensor_to_vec(plan_info_vec));

if (maybe_lse) {
const auto& lse = *maybe_lse;
TORCH_CHECK(lse.size(0) == q.size(0), lse.size(0), q.size(0));
TORCH_CHECK(lse.size(1) == q.size(1), lse.size(1), q.size(1));
}
QKVLayout kv_layout = static_cast<QKVLayout>(layout);
int64_t num_kv_heads, page_size;
int64_t head_dim_qk = q.size(2);
int64_t head_dim_vo = paged_v_cache.size(3);
if (kv_layout == QKVLayout::kHND) {
num_kv_heads = paged_k_cache.size(1);
page_size = paged_k_cache.size(2);
} else {
page_size = paged_k_cache.size(1);
num_kv_heads = paged_k_cache.size(2);
}

void* float_buffer_ptr = float_workspace_buffer.data_ptr();
void* int_buffer_ptr = int_workspace_buffer.data_ptr();

auto q_scalar_type = q.scalar_type();
auto kv_scalar_type = paged_k_cache.scalar_type();

const c10::cuda::OptionalCUDAGuard device_guard(float_workspace_buffer.device());
cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
const MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code);
bool use_swa = window_left != -1;

DISPATCH_context(
DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, USE_SLIDING_WINDOW,
USE_LOGITS_SOFT_CAP, AttentionVariant, RaggedParams, PagedParams, [&] {
PagedParams params;

params.q_ptr = static_cast<DTypeQ*>(q.data_ptr());
params.k_ptr = static_cast<DTypeKV*>(paged_k_cache.data_ptr());
params.v_ptr = static_cast<DTypeKV*>(paged_v_cache.data_ptr());
params.o_ptr = static_cast<DTypeO*>(o.data_ptr());
params.lse_ptr = maybe_lse ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr;
params.q_stride_n = q.stride(0);
params.q_stride_h = q.stride(1);
params.o_stride_n = o.stride(0);
params.o_stride_h = o.stride(1);
if (kv_layout == QKVLayout::kNHD) {
// (num_pages, page_size, num_heads, head_dim)
params.k_stride_n = paged_k_cache.stride(1);
params.k_stride_h = paged_k_cache.stride(2);
params.v_stride_n = paged_v_cache.stride(1);
params.v_stride_h = paged_v_cache.stride(2);
} else {
// (num_pages, num_heads, page_size, head_dim)
params.k_stride_h = paged_k_cache.stride(1);
params.k_stride_n = paged_k_cache.stride(2);
params.v_stride_h = paged_v_cache.stride(1);
params.v_stride_n = paged_v_cache.stride(2);
}
params.nnz_qo = q.size(0);
params.num_qo_heads = q.size(1);
params.num_kv_heads = num_kv_heads;
params.group_size = params.num_qo_heads / num_kv_heads;
params.page_size = page_size;
params.window_left = window_left;
params.causal = mask_mode_code == 1;
params.qo_tile_indices =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.qo_tile_indices_offset);
params.qo_indptr = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.qo_indptr_offset);
params.kv_indptr = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.kv_indptr_offset);
params.qo_lens = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.qo_len_offset);
params.kv_lens = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.kv_len_offset);
params.head_indices =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.head_indices_offset);
params.work_indptr =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.work_indptr_offset);
params.kv_indices = static_cast<IdType*>(paged_kv_indices.data_ptr());

ADDITIONAL_PARAMS_SETTER

// Different head_dim for QK and VO is not supported for now
static_assert(HEAD_DIM_QK == HEAD_DIM_VO, "head_dim_qk and head_dim_vo should be the same");
// Currently only the same quantization precision for Q and KV is supported
static_assert(std::is_same_v<DTypeQ, DTypeKV>);

bool same_schedule_for_all_heads = plan_info.same_schedule_for_all_heads;
DISPATCH_BOOL(same_schedule_for_all_heads, SAME_SCHEDULER_FOR_ALL_HEADS, [&] {
cudaError_t status =
BatchFP8PrefillWithPagedKVCacheDispatched<HEAD_DIM_QK, MASK_MODE, USE_SLIDING_WINDOW,
SAME_SCHEDULER_FOR_ALL_HEADS,
AttentionVariant>(params, stream);
TORCH_CHECK(status == cudaSuccess,
"BatchPrefillWithPagedKVCacheSM90Run failed with error: ",
cudaGetErrorString(status));
return true;
});
});
}
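
A note on the layout branch in the run function above: `num_kv_heads` and `page_size` are read from different axes of the paged cache depending on `kv_layout`. A minimal illustration of the two conventions (tensor sizes are arbitrary, chosen only for the sketch):

```python
import torch

num_pages, page_size, num_kv_heads, head_dim = 32, 16, 8, 128

# kNHD: (num_pages, page_size, num_kv_heads, head_dim)
k_nhd = torch.empty(num_pages, page_size, num_kv_heads, head_dim)
assert (k_nhd.size(1), k_nhd.size(2)) == (page_size, num_kv_heads)

# kHND: (num_pages, num_kv_heads, page_size, head_dim)
k_hnd = torch.empty(num_pages, num_kv_heads, page_size, head_dim)
assert (k_hnd.size(1), k_hnd.size(2)) == (num_kv_heads, page_size)
```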
101 changes: 101 additions & 0 deletions csrc/single_prefill_fp8_sm90.cu
@@ -0,0 +1,101 @@
/*
* Copyright (c) 2024 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <flashinfer/attention/mask.cuh>
#include <flashinfer/layout.cuh>
#include <flashinfer/math.cuh>
#include <optional>

#include "pytorch_extension_utils.h"
#include "single_prefill_sm90_config.inc"

namespace flashinfer {

template <uint32_t HEAD_DIM, MaskMode MASK_MODE, bool LEFT_SLIDING_WINDOW,
typename AttentionVariant, typename Params>
cudaError_t SingleFP8PrefillWithKVCacheDispatched(Params& params, cudaStream_t stream);

} // namespace flashinfer

using namespace flashinfer;

void single_prefill_with_kv_cache_sm90(at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor tmp,
at::Tensor o, std::optional<at::Tensor> maybe_lse,
int64_t mask_mode_code, int64_t layout,
int64_t window_left ADDITIONAL_FUNC_PARAMS) {
unsigned int head_dim_qk = q.size(2);
unsigned int head_dim_vo = v.size(2);
unsigned int num_qo_heads = q.size(1);
unsigned int qo_len = q.size(0);

auto q_scalar_type = q.scalar_type();
auto kv_scalar_type = k.scalar_type();

QKVLayout kv_layout = static_cast<QKVLayout>(layout);
const c10::cuda::OptionalCUDAGuard device_guard(q.device());
const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
const MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code);

DISPATCH_context(
DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, USE_SLIDING_WINDOW,
USE_LOGITS_SOFT_CAP, AttentionVariant, Params, [&] {
Params params;
params.q_ptr = static_cast<DTypeQ*>(q.data_ptr());
params.k_ptr = static_cast<DTypeKV*>(k.data_ptr());
params.v_ptr = static_cast<DTypeKV*>(v.data_ptr());
params.o_ptr = static_cast<DTypeO*>(o.data_ptr());
params.lse_ptr = maybe_lse ? (static_cast<float*>(maybe_lse->data_ptr())) : nullptr;
params.q_stride_n = q.stride(0);
params.q_stride_h = q.stride(1);
params.o_stride_n = o.stride(0);
params.o_stride_h = o.stride(1);
if (kv_layout == QKVLayout::kNHD) {
params.k_stride_n = k.stride(0);
params.k_stride_h = k.stride(1);
params.v_stride_n = v.stride(0);
params.v_stride_h = v.stride(1);
} else {
params.k_stride_h = k.stride(0);
params.k_stride_n = k.stride(1);
params.v_stride_h = v.stride(0);
params.v_stride_n = v.stride(1);
}
params.qo_len = q.size(0);
params.kv_len = k.size(0);
params.num_qo_heads = q.size(1);
params.num_kv_heads = k.size(1);
params.causal = mask_mode == MaskMode::kCausal;
params.group_size = params.num_qo_heads / params.num_kv_heads;

// Note(Yilong): this should be checked on the Python side
// Only window_left == 0 is supported for now
params.window_left = window_left;

// Note(Yilong): all quantization parameters are set in additional_params
ADDITIONAL_PARAMS_SETTER

// Different head_dim for QK and VO is not supported for now
static_assert(HEAD_DIM_QK == HEAD_DIM_VO, "head_dim_qk and head_dim_vo should be the same");
// Currently only the same quantization precision for Q and KV is supported
static_assert(std::is_same_v<DTypeQ, DTypeKV>);

cudaError_t status =
SingleFP8PrefillWithKVCacheDispatched<HEAD_DIM_QK, MASK_MODE, USE_SLIDING_WINDOW,
AttentionVariant>(params, stream);
TORCH_CHECK(status == cudaSuccess, "single_prefill_with_kv_cache_sm90 failed with error: " +
std::string(cudaGetErrorString(status)));
return true;
});
}
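
For reference, a minimal Python call into this FA3 FP8 path, mirroring `benchmarks/bench_hopper_fp8_attention.py` above (shapes are arbitrary; `o_dtype` is passed explicitly since the inputs are FP8):

```python
import torch
import flashinfer

q = torch.randn(4096, 32, 128, dtype=torch.half, device="cuda").to(torch.float8_e4m3fn)
k = torch.randn(4096, 32, 128, dtype=torch.half, device="cuda").to(torch.float8_e4m3fn)
v = torch.randn(4096, 32, 128, dtype=torch.half, device="cuda").to(torch.float8_e4m3fn)

o, lse = flashinfer.single_prefill_with_kv_cache_return_lse(
    q, k, v, causal=True, backend="fa3", o_dtype=torch.half
)
```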
11 changes: 11 additions & 0 deletions csrc/single_prefill_fp8_sm90_kernel_inst.jinja
@@ -0,0 +1,11 @@
#include <flashinfer/attention/hopper/quantization/prefill_sm90.cuh>
#include "single_prefill_sm90_config.inc"

using namespace flashinfer;

namespace flashinfer {

template cudaError_t SingleFP8PrefillWithKVCacheDispatched
<{{ head_dim_qk }}, {{ mask_mode }}, /*USE_SLIDING_WINDOW=*/{{ use_sliding_window }}, {{ variant_name }}, Params>(
Params& params, cudaStream_t stream);
};
6 changes: 6 additions & 0 deletions flashinfer/decode.py
@@ -530,6 +530,9 @@ def single_decode_with_kv_cache(
_get_cache_alibi_slopes_buf(num_qo_heads, q.device),
logits_soft_cap,
sm_scale,
None, # scale_q, not supported yet
None, # scale_k
None, # scale_v
rope_scale,
rope_theta,
)
@@ -1169,6 +1172,9 @@ def run(
_get_cache_alibi_slopes_buf(q.shape[1], q.device),
logits_soft_cap,
sm_scale,
None, # scale_q, not supported yet
None, # scale_k
None, # scale_v
rope_scale,
rope_theta,
]
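
The three new `None` placeholders above reserve positions for per-head scales in the decode path, which this PR does not yet support. A hedged sketch of what they would presumably look like once wired up; the `(num_heads,)` shape is an assumption based on the per-head quantization added for prefill, not something the decode code confirms:

```python
# Assumed shapes for future scale_q/scale_k/scale_v arguments: one FP32 scale
# per attention head (currently passed as None, i.e. unsupported).
import torch

num_qo_heads, num_kv_heads = 32, 8
scale_q = torch.ones(num_qo_heads, dtype=torch.float32, device="cuda")
scale_k = torch.ones(num_kv_heads, dtype=torch.float32, device="cuda")
scale_v = torch.ones(num_kv_heads, dtype=torch.float32, device="cuda")
```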