From 2066ffbf393bdda533bbec91761c2553acc38af3 Mon Sep 17 00:00:00 2001 From: happierpig Date: Wed, 23 Apr 2025 04:55:42 +0000 Subject: [PATCH 01/14] add jit for single prefill fp8 fa3 --- csrc/single_prefill_fp8_sm90.cu | 101 ++++++++++++++++++ .../single_prefill_fp8_sm90_kernel_inst.jinja | 11 ++ flashinfer/jit/attention/pytorch.py | 43 ++++++-- flashinfer/prefill.py | 96 ++++++++++++++--- 4 files changed, 225 insertions(+), 26 deletions(-) create mode 100644 csrc/single_prefill_fp8_sm90.cu create mode 100644 csrc/single_prefill_fp8_sm90_kernel_inst.jinja diff --git a/csrc/single_prefill_fp8_sm90.cu b/csrc/single_prefill_fp8_sm90.cu new file mode 100644 index 000000000..2e4eb9133 --- /dev/null +++ b/csrc/single_prefill_fp8_sm90.cu @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2024 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include + +#include "pytorch_extension_utils.h" +#include "single_prefill_sm90_config.inc" + +namespace flashinfer { + +template +cudaError_t SingleFP8PrefillWithKVCacheDispatched(Params& params, cudaStream_t stream); + +} // namespace flashinfer + +using namespace flashinfer; + +void single_prefill_with_kv_cache_sm90(at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor tmp, + at::Tensor o, std::optional maybe_lse, + int64_t mask_mode_code, int64_t layout, + int64_t window_left ADDITIONAL_FUNC_PARAMS) { + unsigned int head_dim_qk = q.size(2); + unsigned int head_dim_vo = v.size(2); + unsigned int num_qo_heads = q.size(1); + unsigned int qo_len = q.size(0); + + auto q_scalar_type = q.scalar_type(); + auto kv_scalar_type = k.scalar_type(); + + QKVLayout kv_layout = static_cast(layout); + const c10::cuda::OptionalCUDAGuard device_guard(q.device()); + const cudaStream_t stream = c10::cuda::getCurrentCUDAStream(); + const MaskMode mask_mode = static_cast(mask_mode_code); + + DISPATCH_context( + DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, USE_SLIDING_WINDOW, + USE_LOGITS_SOFT_CAP, AttentionVariant, Params, [&] { + Params params; + params.q_ptr = static_cast(q.data_ptr()); + params.k_ptr = static_cast(k.data_ptr()); + params.v_ptr = static_cast(v.data_ptr()); + params.o_ptr = static_cast(o.data_ptr()); + params.lse_ptr = maybe_lse ? 
(static_cast(maybe_lse->data_ptr())) : nullptr; + params.q_stride_n = q.stride(0); + params.q_stride_h = q.stride(1); + params.o_stride_n = o.stride(0); + params.o_stride_h = o.stride(1); + if (kv_layout == QKVLayout::kNHD) { + params.k_stride_n = k.stride(0); + params.k_stride_h = k.stride(1); + params.v_stride_n = v.stride(0); + params.v_stride_h = v.stride(1); + } else { + params.k_stride_h = k.stride(0); + params.k_stride_n = k.stride(1); + params.v_stride_h = v.stride(0); + params.v_stride_n = v.stride(1); + } + params.qo_len = q.size(0); + params.kv_len = k.size(0); + params.num_qo_heads = q.size(1); + params.num_kv_heads = k.size(1); + params.causal = mask_mode == MaskMode::kCausal; + params.group_size = params.num_qo_heads / params.num_kv_heads; + + // Note(Yilong): this should be checked on Python Side + // Only support window_left == 0 for now + params.window_left = window_left; + + // Note(Yilong): all quantization parameters are set in additional_params + ADDITIONAL_PARAMS_SETTER + + // Not support various head_dim for now + static_assert(HEAD_DIM_QK == HEAD_DIM_VO, "head_dim_qk and head_dim_vo should be the same"); + // Currently only support same quantization precision + static_assert(std::is_same_v); + + cudaError_t status = + SingleFP8PrefillWithKVCacheDispatched(params, stream); + TORCH_CHECK(status == cudaSuccess, "single_prefill_with_kv_cache_sm90 failed with error: " + + std::string(cudaGetErrorString(status))); + return true; + }); +} diff --git a/csrc/single_prefill_fp8_sm90_kernel_inst.jinja b/csrc/single_prefill_fp8_sm90_kernel_inst.jinja new file mode 100644 index 000000000..be0178dbd --- /dev/null +++ b/csrc/single_prefill_fp8_sm90_kernel_inst.jinja @@ -0,0 +1,11 @@ +#include +#include "single_prefill_sm90_config.inc" + +using namespace flashinfer; + +namespace flashinfer { + +template cudaError_t SingleFP8PrefillWithKVCacheDispatched + <{{ head_dim_qk }}, {{ mask_mode }}, /*USE_SLIDING_WINDOW=*/{{ use_sliding_window }}, {{ variant_name }}, Params>( + Params& params, cudaStream_t stream); +}; diff --git a/flashinfer/jit/attention/pytorch.py b/flashinfer/jit/attention/pytorch.py index 660f34649..fc497377c 100644 --- a/flashinfer/jit/attention/pytorch.py +++ b/flashinfer/jit/attention/pytorch.py @@ -452,7 +452,12 @@ def gen_single_prefill_module( use_logits_soft_cap, use_fp16_qk_reduction, ) + + # use `fp8_enabled` flag to use separate kernel template + fp8_enabled = "e4m3" in uri or "e5m2" in uri + if backend == "fa2": + assert not fp8_enabled, "fp8 is not supported in fa2 backend" additional_tensor_names = ["maybe_custom_mask", "maybe_alibi_slopes"] additional_tensor_dtypes = ["uint8_t", "float"] additional_scalar_names = [ @@ -465,12 +470,20 @@ def gen_single_prefill_module( variant_name = f"DefaultAttention" variant_decl = f"#include" else: - additional_tensor_names = [] - additional_tensor_dtypes = [] - additional_scalar_names = ["logits_soft_cap", "sm_scale"] - additional_scalar_dtypes = ["double", "double"] - variant_name = f"DefaultAttention<{str(use_logits_soft_cap).lower()}>" - variant_decl = f"#include" + if not fp8_enabled: + additional_tensor_names = [] + additional_tensor_dtypes = [] + additional_scalar_names = ["logits_soft_cap", "sm_scale"] + additional_scalar_dtypes = ["double", "double"] + variant_name = f"DefaultAttention<{str(use_logits_soft_cap).lower()}>" + variant_decl = f"#include" + else: + additional_tensor_names = ["scale_q", "scale_k", "scale_v"] + additional_tensor_dtypes = ["float", "float", "float"] + additional_scalar_names = 
["sm_scale"] + additional_scalar_dtypes = ["double"] + variant_name = f"DefaultFP8Attention" + variant_decl = f"#include" return gen_customize_single_prefill_module( backend, @@ -490,6 +503,7 @@ def gen_single_prefill_module( use_sliding_window=use_sliding_window, use_logits_soft_cap=use_logits_soft_cap, use_fp16_qk_reduction=use_fp16_qk_reduction, + fp8_enabled=fp8_enabled, ) @@ -884,6 +898,7 @@ def gen_customize_single_prefill_module( use_sliding_window: bool = False, use_logits_soft_cap: bool = False, use_fp16_qk_reduction: bool = False, + fp8_enabled: bool = False, ): kwargs = { "variant_decl": variant_decl, @@ -967,12 +982,18 @@ def gen_customize_single_prefill_module( ) ) - with open( - FLASHINFER_CSRC_DIR / "single_prefill_sm90_customize_config.jinja" - ) as f: + _file_config = "single_prefill_sm90_customize_config.jinja" + if fp8_enabled: + _file_kernel_inst = "single_prefill_fp8_sm90_kernel_inst.jinja" + _file_csrc = "single_prefill_fp8_sm90.cu" + else: + _file_kernel_inst = "single_prefill_sm90_kernel_inst.jinja" + _file_csrc = "single_prefill_sm90.cu" + + with open(FLASHINFER_CSRC_DIR / _file_config) as f: config_templ = jinja2.Template(f.read()) - with open(FLASHINFER_CSRC_DIR / "single_prefill_sm90_kernel_inst.jinja") as f: + with open(FLASHINFER_CSRC_DIR / _file_kernel_inst) as f: kernel_inst_templ = jinja2.Template(f.read()) kwargs |= { @@ -998,7 +1019,7 @@ def gen_customize_single_prefill_module( write_if_different(dest_path, source) for filename in [ - "single_prefill_sm90.cu", + _file_csrc, "single_prefill_sm90_jit_pybind.cu", ]: src_path = FLASHINFER_CSRC_DIR / filename diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index aa2481f8f..69ce84d21 100644 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -99,23 +99,44 @@ def run_single_prefill( maybe_alibi_slopes: Optional[torch.Tensor], logits_soft_cap: float, sm_scale: float, + scale_q: Optional[torch.Tensor], + scale_k: Optional[torch.Tensor], + scale_v: Optional[torch.Tensor], rope_scale: float, rope_theta: float, ) -> None: if backend == "fa3": - run_func( - q, - k, - v, - tmp, - o, - maybe_lse, - mask_mode, - layout, - window_left, - logits_soft_cap, - sm_scale, - ) + if not is_float8(q): + run_func( + q, + k, + v, + tmp, + o, + maybe_lse, + mask_mode, + layout, + window_left, + logits_soft_cap, + sm_scale, + ) + else: + # FP8 enabled + run_func( + q, + k, + v, + tmp, + o, + maybe_lse, + mask_mode, + layout, + window_left, + scale_q, + scale_k, + scale_v, + sm_scale, + ) else: run_func( q, @@ -621,6 +642,10 @@ def single_prefill_with_kv_cache( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + scale_q: Optional[torch.Tensor] = None, + scale_k: Optional[torch.Tensor] = None, + scale_v: Optional[torch.Tensor] = None, + o_dtype: Optional[torch.dtype] = None, custom_mask: Optional[torch.Tensor] = None, packed_custom_mask: Optional[torch.Tensor] = None, causal: bool = False, @@ -642,6 +667,10 @@ def single_prefill_with_kv_cache( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + scale_q: Optional[torch.Tensor] = None, + scale_k: Optional[torch.Tensor] = None, + scale_v: Optional[torch.Tensor] = None, + o_dtype: Optional[torch.dtype] = None, custom_mask: Optional[torch.Tensor] = None, packed_custom_mask: Optional[torch.Tensor] = None, causal: bool = False, @@ -662,6 +691,10 @@ def single_prefill_with_kv_cache( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + scale_q: Optional[torch.Tensor] = None, + scale_k: Optional[torch.Tensor] = None, + scale_v: Optional[torch.Tensor] = None, + o_dtype: 
Optional[torch.dtype] = None, custom_mask: Optional[torch.Tensor] = None, packed_custom_mask: Optional[torch.Tensor] = None, causal: bool = False, @@ -691,6 +724,18 @@ def single_prefill_with_kv_cache( The key tensor, shape: ``[kv_len, num_kv_heads, head_dim_vo]`` if :attr:`kv_layout` is ``NHD``, ``[num_kv_heads, kv_len, head_dim_vo]`` if :attr:`kv_layout` is ``HND``. + scale_q : Optional[torch.Tensor] + The scale tensor for query, per-head quantization with shape: ``[num_qo_heads]``. + Used with FP8 Quantization. If not provided, will be set to ``1.0``. + scale_k : Optional[torch.Tensor] + The scale tensor for key, per-head quantization with shape: ``[num_kv_heads]``. + Used with FP8 Quantization. If not provided, will be set to ``1.0``. + scale_v : Optional[torch.Tensor] + The scale tensor for value, per-head quantization with shape: ``[num_kv_heads]``. + Used with FP8 Quantization. If not provided, will be set to ``1.0``. + o_dtype : Optional[torch.dtype] + The output tensor data type, if not provided, will be set to the same as the q. + This is necessary as output dtype cannot be automatically inferred in quant. custom_mask : Optional[torch.Tensor] The custom boolean mask tensor, shape: ``[qo_len, kv_len]``. The elements in the mask tensor should be either ``True`` or ``False``, @@ -814,6 +859,20 @@ def single_prefill_with_kv_cache( if return_lse: lse = torch.empty((q.size(0), q.size(1)), dtype=torch.float32, device=q.device) + if is_float8(q): + # FP8 quant enabled, do sanity check: + # 1. unsupported feature + # 2. dtype check + assert window_left == -1 + assert q.dtype == k.dtype == v.dtype + assert q.shape[-1] == k.shape[-1] == v.shape[-1] + if scale_q is None: + scale_q = torch.ones(q.shape[1], dtype=torch.float32, device=q.device) + if scale_k is None: + scale_k = torch.ones(k.shape[1], dtype=torch.float32, device=q.device) + if scale_v is None: + scale_v = torch.ones(v.shape[1], dtype=torch.float32, device=q.device) + if backend == "auto": backend = determine_attention_backend( q.device, @@ -825,11 +884,15 @@ def single_prefill_with_kv_cache( ) module_getter = get_single_prefill_module(backend) - out = torch.empty(q.shape[:-1] + v.shape[-1:], dtype=q.dtype, device=q.device) + # o_dtype should be provided for FP8 attention + if o_dtype is None: + o_dtype = q.dtype + out = torch.empty(q.shape[:-1] + v.shape[-1:], dtype=o_dtype, device=q.device) + module_getter( q.dtype, k.dtype, - q.dtype, + out.dtype, q.shape[-1], # head_dim_qk v.shape[-1], # head_dim_vo PosEncodingMode[pos_encoding_mode].value, @@ -850,6 +913,9 @@ def single_prefill_with_kv_cache( _get_cache_alibi_slopes_buf(q.shape[1], q.device), logits_soft_cap, sm_scale, + scale_q, + scale_k, + scale_v, rope_scale, rope_theta, ) From 1c5c9cc33052ad8ee9ea1b5faebafcd7557dab6f Mon Sep 17 00:00:00 2001 From: happierpig Date: Wed, 23 Apr 2025 05:24:07 +0000 Subject: [PATCH 02/14] upd benchmarks --- benchmarks/bench_hopper_fp8_attention.py | 58 +++++++++++++++++++ .../hopper/quantization/prefill_sm90.cuh | 2 +- 2 files changed, 59 insertions(+), 1 deletion(-) create mode 100644 benchmarks/bench_hopper_fp8_attention.py diff --git a/benchmarks/bench_hopper_fp8_attention.py b/benchmarks/bench_hopper_fp8_attention.py new file mode 100644 index 000000000..a59ef8bb4 --- /dev/null +++ b/benchmarks/bench_hopper_fp8_attention.py @@ -0,0 +1,58 @@ +import torch +import triton + +import flashinfer + + +def bench_single_prefill(seq_len, num_heads, causal, head_dim): + num_qo_heads = num_kv_heads = num_heads + q = torch.randn(seq_len, 
num_qo_heads, head_dim, dtype=torch.half, device="cuda") + k = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda") + v = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda") + + sm80_ms, sm90_ms = ( + triton.testing.do_bench( + lambda: flashinfer.single_prefill_with_kv_cache_return_lse( + q, k, v, causal=causal, backend=backend + ), + warmup=100, + rep=1000, + ) + for backend in ["fa2", "fa3"] + ) + + q = torch.randn( + seq_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda" + ).to(dtype=torch.float8_e4m3fn) + k = torch.randn( + seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda" + ).to(dtype=torch.float8_e4m3fn) + v = torch.randn( + seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda" + ).to(dtype=torch.float8_e4m3fn) + + fp8_sm90_ms = triton.testing.do_bench( + lambda: flashinfer.single_prefill_with_kv_cache_return_lse( + q, k, v, causal=causal, backend="fa3", o_dtype=torch.half + ), + warmup=100, + rep=1000, + ) + + def flops(ms): + if causal: + return seq_len * seq_len * num_qo_heads * head_dim * 2 / ms / 1e9 + else: + return seq_len * seq_len * num_qo_heads * head_dim * 4 / ms / 1e9 + + print( + f"bench_single_prefill (seq_len={seq_len}, num_heads={num_heads}, causal={causal}, head_dim={head_dim}), fa2-template: {flops(sm80_ms):.3f} TFLOPs/s, fa3-template: {flops(sm90_ms):.3f} TFLOPs/s, fa3-fp8: {flops(fp8_sm90_ms):.3f} TFLOPs/s" + ) + + +if __name__ == "__main__": + for seq_len in [4096, 8192, 16384]: + for num_heads in [24, 32]: + for causal in [True, False]: + for head_dim in [64, 128, 256]: + bench_single_prefill(seq_len, num_heads, causal, head_dim) diff --git a/include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh b/include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh index 4d2dd4978..458b7ac17 100644 --- a/include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh +++ b/include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh @@ -332,7 +332,7 @@ cudaError_t SingleFP8PrefillWithKVCacheDispatched(Params& params, cudaStream_t s SingleFP8PrefillWithKVCacheKernelTraitsDispatched< FP8AttentionKernelTraits, From 65170bda82db2e18ff3200a60ecce2b66690b3f3 Mon Sep 17 00:00:00 2001 From: happierpig Date: Wed, 23 Apr 2025 06:07:49 +0000 Subject: [PATCH 03/14] add test mse --- tests/test_hopper_fp8_attention.py | 82 ++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) create mode 100644 tests/test_hopper_fp8_attention.py diff --git a/tests/test_hopper_fp8_attention.py b/tests/test_hopper_fp8_attention.py new file mode 100644 index 000000000..2e90cf6d2 --- /dev/null +++ b/tests/test_hopper_fp8_attention.py @@ -0,0 +1,82 @@ +from typing import Tuple + +import torch + +import flashinfer + + +def per_head_symmetric_quant( + x: torch.Tensor, quant_dtype: torch.dtype +) -> Tuple[torch.Tensor, torch.Tensor]: + # x: [seq_len, num_heads, head_dim] + assert quant_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] + + def get_dtype_minmax(dtype: torch.dtype) -> Tuple[float, float]: + if dtype == torch.float8_e4m3fn: + return -448.0, 448.0 + elif dtype == torch.float8_e5m2: + return -57344, 57344 + else: + raise ValueError(f"Unsupported quantization dtype: {dtype}") + + o_min_val, o_max_val = get_dtype_minmax(quant_dtype) + x_max_val = x.abs().amax(dim=(0, 2)).to(dtype=torch.float32) + + s_out = torch.clamp(x_max_val / o_max_val, min=1e-6) + s_out_broadcast = s_out.view(1, -1, 1) + + q_x_out = torch.clamp( + x / s_out_broadcast, + min=o_min_val, + max=o_max_val, + 
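+        # clamp into the representable range of the target FP8 dtype before the cast below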
).to(dtype=quant_dtype) + + assert not torch.any(torch.isnan(q_x_out)) + assert not torch.any(torch.isnan(s_out)) + + return q_x_out, s_out + + +def test_single_prefill(seq_len, num_heads, causal, head_dim, dtype): + o_dtype = torch.half + num_qo_heads = num_kv_heads = num_heads + + q = torch.randn(seq_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda") + k = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda") + v = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda") + o_ref = flashinfer.single_prefill_with_kv_cache( + q, k, v, causal=causal, backend="fa3" + ) + + q_fp8, s_q = per_head_symmetric_quant(q, quant_dtype=dtype) + k_fp8, s_k = per_head_symmetric_quant(k, quant_dtype=dtype) + v_fp8, s_v = per_head_symmetric_quant(v, quant_dtype=dtype) + o_fp8 = flashinfer.single_prefill_with_kv_cache( + q_fp8, + k_fp8, + v_fp8, + s_q, + s_k, + s_v, + causal=causal, + backend="fa3", + o_dtype=o_dtype, + ) + + assert not torch.any(torch.isnan(o_fp8)) + assert not torch.any(torch.isnan(o_ref)) + + # MSE + mse = torch.mean((o_ref - o_fp8) ** 2) + print( + f"test_single_prefill (seq_len={seq_len}, num_heads={num_heads}, causal={causal}, head_dim={head_dim}, dtype={dtype}), MSE: {mse:.3f}" + ) + + +if __name__ == "__main__": + for seq_len in [512, 1024, 4096, 8192]: + for num_heads in [24, 32]: + for causal in [False]: + for head_dim in [64, 128, 256]: + for dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: + test_single_prefill(seq_len, num_heads, causal, head_dim, dtype) From 6ab01634f7f601f963aa9d52085d3bdad71c0f48 Mon Sep 17 00:00:00 2001 From: happierpig Date: Wed, 23 Apr 2025 06:08:12 +0000 Subject: [PATCH 04/14] clean code --- src/fp8-dev/CMakeLists.txt | 105 ------ src/fp8-dev/bench_single_prefill_sm90.cu | 363 -------------------- src/fp8-dev/cpu_reference.h | 195 ----------- src/fp8-dev/flashattention_ops.h | 326 ------------------ src/fp8-dev/flashinfer_ops.cu | 155 --------- src/fp8-dev/test_single_prefill_fa3_sm90.cu | 154 --------- src/fp8-dev/utils.h | 206 ----------- 7 files changed, 1504 deletions(-) delete mode 100644 src/fp8-dev/CMakeLists.txt delete mode 100644 src/fp8-dev/bench_single_prefill_sm90.cu delete mode 100644 src/fp8-dev/cpu_reference.h delete mode 100644 src/fp8-dev/flashattention_ops.h delete mode 100644 src/fp8-dev/flashinfer_ops.cu delete mode 100644 src/fp8-dev/test_single_prefill_fa3_sm90.cu delete mode 100644 src/fp8-dev/utils.h diff --git a/src/fp8-dev/CMakeLists.txt b/src/fp8-dev/CMakeLists.txt deleted file mode 100644 index 8fdc6c1d5..000000000 --- a/src/fp8-dev/CMakeLists.txt +++ /dev/null @@ -1,105 +0,0 @@ -cmake_minimum_required(VERSION 3.23.1) -project(flashinfer CUDA CXX) - -set(CMAKE_CUDA_STANDARD 17) -set(CMAKE_CXX_STANDARD 17) -set(CMAKE_CUDA_ARCHITECTURES 90a) - -# ########### Modify the following paths to your own paths ############ -set(FA3_INCLUDE_DIR /home/ylzhao/flash-attention/hopper) -set(LIBTORCH_INCLUDE_DIR /home/ylzhao/libtorch) -# ########### Modify the above paths to your own paths ############ - -find_package(Python3 REQUIRED) -if(NOT Python3_FOUND) - message(FATAL_ERROR "Python3 not found.") -endif() - -list(APPEND CMAKE_PREFIX_PATH ${LIBTORCH_INCLUDE_DIR}) -find_package(Torch REQUIRED) -find_package(Thrust REQUIRED) -find_package(Python3 REQUIRED COMPONENTS Interpreter Development) - -add_subdirectory(${CMAKE_SOURCE_DIR}/../../3rdparty/nvbench - ${CMAKE_BINARY_DIR}/nvbench_build) -add_subdirectory(${CMAKE_SOURCE_DIR}/../../3rdparty/googletest - 
${CMAKE_BINARY_DIR}/googletest_build) - -set(FLASHINFER_INCLUDE_DIR ${CMAKE_SOURCE_DIR}/../../include) -set(CUTLASS_INCLUDE_DIR - ${CMAKE_SOURCE_DIR}/../../3rdparty/cutlass/include - ${CMAKE_SOURCE_DIR}/../../3rdparty/cutlass/tools/util/include) -set(NVBENCH_INCLUDE_DIR ${CMAKE_SOURCE_DIR}/../../3rdparty/nvbench) - -# build flashinfer library -add_library(FLASHINFER_LIB STATIC ${CMAKE_SOURCE_DIR}/flashinfer_ops.cu) -set_target_properties(FLASHINFER_LIB PROPERTIES CMAKE_CUDA_ARCHITECTURES "90a") -target_compile_options( - FLASHINFER_LIB - PRIVATE $<$: - --expt-extended-lambda - --use_fast_math - --compiler-options - -fPIC - --generate-code=arch=compute_90a,code=sm_90a - >) -target_include_directories(FLASHINFER_LIB PRIVATE ${FLASHINFER_INCLUDE_DIR}) -target_include_directories(FLASHINFER_LIB PRIVATE ${CUTLASS_INCLUDE_DIR}) - -# build FA3 library -file(GLOB FA3_IMPL_FILES ${FA3_INCLUDE_DIR}/flash_fwd_*.cu) -add_library(FA3_LIB STATIC ${FA3_IMPL_FILES}) -set_target_properties(FA3_LIB PROPERTIES CMAKE_CUDA_ARCHITECTURES "90a") -target_compile_options( - FA3_LIB - PRIVATE $<$: - --expt-extended-lambda - --use_fast_math - --compiler-options - -fPIC - --generate-code=arch=compute_90a,code=sm_90a - >) -target_include_directories(FA3_LIB PRIVATE ${CUTLASS_INCLUDE_DIR} - ${FA3_INCLUDE_DIR}) - -# build benchmark FA3/FlashInfer/FP16 -add_executable(bench_single_prefill - ${CMAKE_SOURCE_DIR}/bench_single_prefill_sm90.cu) -target_compile_options( - bench_single_prefill - PRIVATE $<$:--disable-warnings> # Disables warnings in - # CUDA - $<$:-w> # Disables warnings in C++ - $<$: - --use_fast_math> - $<$:--expt-relaxed-constexpr - --generate-code=arch=compute_90a,code=sm_90a>) -target_include_directories( - bench_single_prefill PRIVATE ${FLASHINFER_INCLUDE_DIR} ${TORCH_INCLUDE_DIRS} - ${Python3_INCLUDE_DIRS}) -target_include_directories(bench_single_prefill PRIVATE ${CUTLASS_INCLUDE_DIR}) -target_include_directories(bench_single_prefill PRIVATE ${NVBENCH_INCLUDE_DIR}) -target_include_directories(bench_single_prefill PRIVATE ${FA3_INCLUDE_DIR}) -target_link_libraries(bench_single_prefill PRIVATE nvbench::main - ${TORCH_LIBRARIES} FA3_LIB) - -# build test FA3 / FlashInfer / FP16 test cases -add_executable(test_single_prefill_fa3_sm90 - ${CMAKE_SOURCE_DIR}/test_single_prefill_fa3_sm90.cu) -target_compile_options( - test_single_prefill_fa3_sm90 - PRIVATE $<$:--disable-warnings> # Disables warnings in - # CUDA - $<$:-w> # Disables warnings in C++ - $<$:--expt-relaxed-constexpr - --generate-code=arch=compute_90a,code=sm_90a>) -target_include_directories(test_single_prefill_fa3_sm90 - PRIVATE ${CUTLASS_INCLUDE_DIR}) -target_include_directories( - test_single_prefill_fa3_sm90 - PRIVATE ${FLASHINFER_INCLUDE_DIR} ${TORCH_INCLUDE_DIRS} - ${Python3_INCLUDE_DIRS}) -target_include_directories(test_single_prefill_fa3_sm90 - PRIVATE ${FA3_INCLUDE_DIR}) -target_link_libraries(test_single_prefill_fa3_sm90 - PRIVATE ${TORCH_LIBRARIES} FA3_LIB FLASHINFER_LIB) diff --git a/src/fp8-dev/bench_single_prefill_sm90.cu b/src/fp8-dev/bench_single_prefill_sm90.cu deleted file mode 100644 index 90ff2c446..000000000 --- a/src/fp8-dev/bench_single_prefill_sm90.cu +++ /dev/null @@ -1,363 +0,0 @@ -/* - * Copyright (c) 2024 by FlashInfer team. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "flashattention_ops.h" -#include "utils.h" - -namespace flashinfer { - -template -cudaError_t SingleFP8PrefillWithKVCacheDispatched(Params& params, cudaStream_t stream); - -template -cudaError_t SinglePrefillWithKVCacheDispatched(Params& params, cudaStream_t stream); -} // namespace flashinfer - -using namespace flashinfer; - -void single_fp8_prefill_with_kv_cache_sm90(nvbench::state& state) { - size_t qo_len = state.get_int64("seq_len"); - size_t kv_len = state.get_int64("seq_len"); - size_t num_qo_heads = state.get_int64("num_qo_heads"); - size_t num_kv_heads = state.get_int64("num_kv_heads"); - size_t head_dim = state.get_int64("head_dim"); - QKVLayout kv_layout = QKVLayout(state.get_int64("kv_layout")); - MaskMode mask_mode = MaskMode(state.get_int64("mask_mode")); - - if (qo_len > kv_len) { - state.skip("qo_len should be less than kv_len"); - } - - using DTypeQ = cutlass::float_e4m3_t; - using DTypeKV = cutlass::float_e4m3_t; - using DTypeO = cutlass::half_t; - using IdType = int32_t; - - constexpr auto USE_SLIDING_WINDOW = false; - - using Params = SinglePrefillParams; - using AttentionVariant = DefaultFP8Attention; - - thrust::device_vector q(qo_len * num_qo_heads * head_dim); - thrust::device_vector k(kv_len * num_kv_heads * head_dim); - thrust::device_vector v(kv_len * num_kv_heads * head_dim); - thrust::device_vector o(qo_len * num_qo_heads * head_dim); - - thrust::device_vector scale_q(num_qo_heads); - thrust::device_vector scale_k(num_kv_heads); - thrust::device_vector scale_v(num_kv_heads); - - Params params; - params.q_ptr = static_cast(thrust::raw_pointer_cast(q.data())); - params.k_ptr = static_cast(thrust::raw_pointer_cast(k.data())); - params.v_ptr = static_cast(thrust::raw_pointer_cast(v.data())); - params.o_ptr = static_cast(thrust::raw_pointer_cast(o.data())); - params.lse_ptr = nullptr; - // q NHD - params.q_stride_n = num_qo_heads * head_dim; - params.q_stride_h = head_dim; - params.o_stride_n = num_qo_heads * head_dim; - params.o_stride_h = head_dim; - if (kv_layout == QKVLayout::kNHD) { - params.k_stride_n = num_kv_heads * head_dim; - params.k_stride_h = head_dim; - params.v_stride_n = num_kv_heads * head_dim; - params.v_stride_h = head_dim; - } else { - // k HND - params.k_stride_h = kv_len * head_dim; - params.k_stride_n = head_dim; - params.v_stride_h = kv_len * head_dim; - params.v_stride_n = head_dim; - } - params.qo_len = qo_len; - params.kv_len = kv_len; - params.num_qo_heads = num_qo_heads; - params.num_kv_heads = num_kv_heads; - params.causal = mask_mode == MaskMode::kCausal; - params.group_size = params.num_qo_heads / params.num_kv_heads; - params.window_left = 0; - - params.additional_params.scale_q = thrust::raw_pointer_cast(scale_q.data()); - params.additional_params.scale_k = thrust::raw_pointer_cast(scale_k.data()); - params.additional_params.scale_v = thrust::raw_pointer_cast(scale_v.data()); - params.additional_params.sm_scale = 1.f / std::sqrt(float(head_dim)); - - 
state.add_global_memory_reads( - (qo_len * num_qo_heads + 2 * kv_len * num_kv_heads) * sizeof(DTypeQ) * head_dim, "Read"); - state.add_global_memory_writes(qo_len * num_qo_heads * head_dim, "Write"); - - state.exec(nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) { - timer.start(); - - cudaError_t status; - DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { - DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { - status = SingleFP8PrefillWithKVCacheDispatched( - params, launch.get_stream()); - }); - }); - if (status != cudaSuccess) { - state.skip("CUDA error: " + std::string(cudaGetErrorString(status))); - } - timer.stop(); - cudaDeviceSynchronize(); - }); - - const auto measured_mean = static_cast( - state.get_summary("nv/cold/time/gpu/mean").get_float64("value")); - auto& summ = state.add_summary("nv/tflops"); - summ.set_string("description", "Achieved TFlops/s"); - summ.set_string("name", "TFlops/s"); - float tflops; - if (params.causal) { - tflops = qo_len * (2 * kv_len - qo_len) * 2 * num_kv_heads * head_dim / measured_mean / 1e12; - } else { - tflops = qo_len * kv_len * 4 * num_kv_heads * head_dim / measured_mean / 1e12; - } - summ.set_float64("value", tflops); -} - -void single_fp16_prefill_with_kv_cache_sm90(nvbench::state& state) { - size_t qo_len = state.get_int64("seq_len"); - size_t kv_len = state.get_int64("seq_len"); - size_t num_qo_heads = state.get_int64("num_qo_heads"); - size_t num_kv_heads = state.get_int64("num_kv_heads"); - size_t head_dim = state.get_int64("head_dim"); - QKVLayout kv_layout = QKVLayout(state.get_int64("kv_layout")); - MaskMode mask_mode = MaskMode(state.get_int64("mask_mode")); - - if (qo_len > kv_len) { - state.skip("qo_len should be less than kv_len"); - } - - using DTypeQ = cutlass::half_t; - using DTypeKV = cutlass::half_t; - using DTypeO = cutlass::half_t; - using IdType = int32_t; - - constexpr auto USE_SLIDING_WINDOW = false; - - using Params = SinglePrefillParams; - using AttentionVariant = StandardAttention; - - thrust::device_vector q(qo_len * num_qo_heads * head_dim); - thrust::device_vector k(kv_len * num_kv_heads * head_dim); - thrust::device_vector v(kv_len * num_kv_heads * head_dim); - thrust::device_vector o(qo_len * num_qo_heads * head_dim); - - Params params; - params.q_ptr = static_cast(thrust::raw_pointer_cast(q.data())); - params.k_ptr = static_cast(thrust::raw_pointer_cast(k.data())); - params.v_ptr = static_cast(thrust::raw_pointer_cast(v.data())); - params.o_ptr = static_cast(thrust::raw_pointer_cast(o.data())); - params.lse_ptr = nullptr; - // q NHD - params.q_stride_n = num_qo_heads * head_dim; - params.q_stride_h = head_dim; - params.o_stride_n = num_qo_heads * head_dim; - params.o_stride_h = head_dim; - if (kv_layout == QKVLayout::kNHD) { - params.k_stride_n = num_kv_heads * head_dim; - params.k_stride_h = head_dim; - params.v_stride_n = num_kv_heads * head_dim; - params.v_stride_h = head_dim; - } else { - // k HND - params.k_stride_h = kv_len * head_dim; - params.k_stride_n = head_dim; - params.v_stride_h = kv_len * head_dim; - params.v_stride_n = head_dim; - } - params.qo_len = qo_len; - params.kv_len = kv_len; - params.num_qo_heads = num_qo_heads; - params.num_kv_heads = num_kv_heads; - params.causal = mask_mode == MaskMode::kCausal; - params.group_size = params.num_qo_heads / params.num_kv_heads; - params.window_left = 0; - params.additional_params.sm_scale = 1.f / std::sqrt(float(head_dim)); - - state.add_global_memory_reads( - (qo_len * num_qo_heads + 2 * kv_len * num_kv_heads) * sizeof(DTypeQ) * head_dim, "Read"); - 
state.add_global_memory_writes(qo_len * num_qo_heads * head_dim, "Write"); - - state.exec(nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) { - timer.start(); - - cudaError_t status; - DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { - DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { - status = SinglePrefillWithKVCacheDispatched( - params, launch.get_stream()); - }); - }); - - if (status != cudaSuccess) { - state.skip("CUDA error: " + std::string(cudaGetErrorString(status))); - } - timer.stop(); - }); - - const auto measured_mean = static_cast( - state.get_summary("nv/cold/time/gpu/mean").get_float64("value")); - auto& summ = state.add_summary("nv/tflops"); - summ.set_string("description", "Achieved TFlops/s"); - summ.set_string("name", "TFlops/s"); - float tflops; - if (params.causal) { - tflops = qo_len * (2 * kv_len - qo_len) * 2 * num_kv_heads * head_dim / measured_mean / 1e12; - } else { - tflops = qo_len * kv_len * 4 * num_kv_heads * head_dim / measured_mean / 1e12; - } - summ.set_float64("value", tflops); -} - -void single_fp8_fa3_prefill_with_kv_cache_sm90(nvbench::state& state) { - size_t qo_len = state.get_int64("seq_len"); - size_t kv_len = state.get_int64("seq_len"); - size_t num_qo_heads = state.get_int64("num_qo_heads"); - size_t num_kv_heads = state.get_int64("num_kv_heads"); - size_t head_dim = state.get_int64("head_dim"); - MaskMode mask_mode = MaskMode(state.get_int64("mask_mode")); - - bool is_causal = mask_mode == MaskMode::kCausal; - float sm_scale = 1.f / std::sqrt(float(head_dim)); - - if (qo_len > kv_len) { - state.skip("qo_len should be less than kv_len"); - } - - using DTypeQ = cutlass::float_e4m3_t; - using DTypeKV = cutlass::float_e4m3_t; - using DTypeO = cutlass::half_t; - using IdType = int32_t; - using DTypeScale = float; - - std::vector q(qo_len * num_qo_heads * head_dim); - std::vector k(kv_len * num_kv_heads * head_dim); - std::vector v(kv_len * num_kv_heads * head_dim); - std::vector o(qo_len * num_qo_heads * head_dim); - std::vector scale_q(1); - std::vector scale_k(1); - std::vector scale_v(1); - - utils::vec_normal_(q); - utils::vec_normal_(k); - utils::vec_normal_(v); - - utils::vec_zero_(o); - utils::vec_normal_(scale_q); - utils::vec_normal_(scale_k); - utils::vec_normal_(scale_v); - - auto device = torch::Device(torch::kCUDA, 0); - auto q_t = - torch::from_blob(q.data(), {1, uint32_t(qo_len), uint32_t(num_qo_heads), uint32_t(head_dim)}, - torch::kFloat8_e4m3fn) - .clone() - .to(device); - auto k_t = - torch::from_blob(k.data(), {1, uint32_t(kv_len), uint32_t(num_kv_heads), uint32_t(head_dim)}, - torch::kFloat8_e4m3fn) - .clone() - .to(device); - auto v_t = - torch::from_blob(v.data(), {1, uint32_t(kv_len), uint32_t(num_kv_heads), uint32_t(head_dim)}, - torch::kFloat8_e4m3fn) - .clone() - .to(device); - auto scale_q_t = std::optional( - torch::from_blob(scale_q.data(), {1}, torch::kFloat).clone().to(device)); - auto scale_k_t = std::optional( - torch::from_blob(scale_k.data(), {1}, torch::kFloat).clone().to(device)); - auto scale_v_t = std::optional( - torch::from_blob(scale_v.data(), {1}, torch::kFloat).clone().to(device)); - auto out_t = std::optional{}; - - state.add_global_memory_reads( - (qo_len * num_qo_heads + 2 * kv_len * num_kv_heads) * sizeof(DTypeQ) * head_dim, "Read"); - state.add_global_memory_writes(qo_len * num_qo_heads * head_dim, "Write"); - - state.exec(nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) { - timer.start(); - - auto o_t_ref = - mha_fwd(q_t, k_t, v_t, out_t, sm_scale, scale_q_t, scale_k_t, 
scale_v_t, is_causal)[0]; - - timer.stop(); - cudaDeviceSynchronize(); - }); - - const auto measured_mean = static_cast( - state.get_summary("nv/cold/time/gpu/mean").get_float64("value")); - auto& summ = state.add_summary("nv/tflops"); - summ.set_string("description", "Achieved TFlops/s"); - summ.set_string("name", "TFlops/s"); - float tflops; - if (is_causal) { - tflops = qo_len * (2 * kv_len - qo_len) * 2 * num_kv_heads * head_dim / measured_mean / 1e12; - } else { - tflops = qo_len * kv_len * 4 * num_kv_heads * head_dim / measured_mean / 1e12; - } - summ.set_float64("value", tflops); -} - -NVBENCH_BENCH(single_fp8_prefill_with_kv_cache_sm90) - .set_name(("single_fp8_prefill_with_kv_cache_sm90")) - .add_int64_axis("seq_len", {2048, 4096, 8192, 16384}) - .add_int64_axis("num_qo_heads", {32}) - .add_int64_axis("num_kv_heads", {32}) - .add_int64_axis("head_dim", {64, 128, 256}) - .add_int64_axis("mask_mode", {0, 1}) - .add_int64_axis("kv_layout", {0}); - -NVBENCH_BENCH(single_fp16_prefill_with_kv_cache_sm90) - .set_name(("single_fp16_prefill_with_kv_cache_sm90")) - .add_int64_axis("seq_len", {2048, 4096, 8192, 16384}) - .add_int64_axis("num_qo_heads", {32}) - .add_int64_axis("num_kv_heads", {32}) - .add_int64_axis("head_dim", {64, 128, 256}) - .add_int64_axis("mask_mode", {0, 1}) - .add_int64_axis("kv_layout", {0}); - -NVBENCH_BENCH(single_fp8_fa3_prefill_with_kv_cache_sm90) - .set_name(("single_fp8_fa3_prefill_with_kv_cache_sm90")) - .add_int64_axis("seq_len", {2048, 4096, 8192, 16384}) - .add_int64_axis("num_qo_heads", {32}) - .add_int64_axis("num_kv_heads", {32}) - .add_int64_axis("head_dim", {64, 128, 256}) - .add_int64_axis("mask_mode", {0, 1}); diff --git a/src/fp8-dev/cpu_reference.h b/src/fp8-dev/cpu_reference.h deleted file mode 100644 index 279b4cc9d..000000000 --- a/src/fp8-dev/cpu_reference.h +++ /dev/null @@ -1,195 +0,0 @@ -/* - * Copyright (c) 2023 by FlashInfer team. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#pragma once - -#include - -#include -#include - -#include "utils.h" - -namespace cpu_reference { - -using namespace flashinfer; - -template -void sym_quant_per_head(const std::vector& x_in, std::vector& x_out, - std::vector& s_out, size_t len, size_t num_heads, - size_t head_dim, QKVLayout kv_layout, bool is_q = false) { - assert(x_in.size() == x_out.size()); - assert(s_out.size() == num_heads); - assert(x_in.size() == len * num_heads * head_dim); - - float o_max_val = std::numeric_limits::max(); - float o_min_val = std::numeric_limits::lowest(); - - tensor_info_t info(len, len, num_heads, num_heads, kv_layout, head_dim); - auto offset = [&](size_t token_idx, size_t head_idx, size_t feat_idx) { - if (is_q) { - return info.get_q_elem_offset(token_idx, head_idx, feat_idx); - } else { - return info.get_kv_elem_offset(token_idx, head_idx, feat_idx); - } - }; - for (size_t head_idx = 0; head_idx < num_heads; ++head_idx) { - float max_val = 0; - for (size_t token_idx = 0; token_idx < len; ++token_idx) { - for (size_t feat_idx = 0; feat_idx < head_dim; ++feat_idx) { - max_val = std::max(max_val, std::abs(x_in[offset(token_idx, head_idx, feat_idx)])); - } - } - s_out[head_idx] = dtype_scale(max_val / o_max_val); - for (size_t token_idx = 0; token_idx < len; ++token_idx) { - for (size_t feat_idx = 0; feat_idx < head_dim; ++feat_idx) { - float q_x = float(x_in[offset(token_idx, head_idx, feat_idx)]) / float(s_out[head_idx]); - q_x = std::clamp(q_x, o_min_val, o_max_val); - x_out[offset(token_idx, head_idx, feat_idx)] = dtype_out(q_x); - } - } - } -} - -template -std::vector single_mha(const std::vector& q, const std::vector& k, - const std::vector& v, size_t qo_len, size_t kv_len, - size_t num_q_heads, size_t num_kv_heads, size_t head_dim, - float sm_scale, bool causal = true, - QKVLayout kv_layout = QKVLayout::kHND, float rope_scale = 1.f, - float rope_theta = 1e4) { - assert(qo_len <= kv_len); - assert(num_q_heads % num_kv_heads == 0); - - size_t group_size = num_q_heads / num_kv_heads; - std::vector o(qo_len * num_q_heads * head_dim); - std::vector att(kv_len); - - tensor_info_t info(qo_len, kv_len, num_q_heads, num_kv_heads, kv_layout, head_dim); - for (size_t qo_head_idx = 0; qo_head_idx < num_q_heads; ++qo_head_idx) { - const size_t kv_head_idx = qo_head_idx / group_size; - for (size_t q_idx = 0; q_idx < qo_len; ++q_idx) { - float max_val = -5e4; - for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { - att[kv_idx] = 0.; - for (size_t feat_idx = 0; feat_idx < head_dim; ++feat_idx) { - att[kv_idx] += float(q[info.get_q_elem_offset(q_idx, qo_head_idx, feat_idx)]) * - float(k[info.get_kv_elem_offset(kv_idx, kv_head_idx, feat_idx)]); - } - // apply mask - if (causal && kv_idx > kv_len + q_idx - qo_len) { - att[kv_idx] = -5e4; - } - max_val = std::max(max_val, att[kv_idx]); - } - // exp minus max - float denom = 0; - for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { - att[kv_idx] = std::exp(att[kv_idx] * sm_scale - max_val * sm_scale); - denom += att[kv_idx]; - } - - // divide by denom - for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { - att[kv_idx] /= denom; - } - - for (size_t feat_idx = 0; feat_idx < head_dim; ++feat_idx) { - float o_float = 0.; - for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { - o_float += att[kv_idx] * float(v[info.get_kv_elem_offset(kv_idx, kv_head_idx, feat_idx)]); - } - o[info.get_o_elem_offset(q_idx, qo_head_idx, feat_idx)] = dtype_out(o_float); - } - } - } - return std::move(o); -} - -template -std::vector single_fp8_mha( - const std::vector& q, const 
std::vector& k, const std::vector& v, - const std::vector& q_scale, const std::vector& k_scale, - const std::vector& v_scale, size_t qo_len, size_t kv_len, size_t num_q_heads, - size_t num_kv_heads, size_t head_dim, float sm_scale, bool causal = true, - QKVLayout kv_layout = QKVLayout::kHND, float rope_scale = 1.f, float rope_theta = 1e4) { - static_assert(sizeof(dtype_in) == 1); - float p_fp8_scale = std::numeric_limits::max(); - - assert(qo_len <= kv_len); - assert(num_q_heads % num_kv_heads == 0); - assert(q_scale.size() == num_q_heads); - assert(k_scale.size() == num_kv_heads); - assert(v_scale.size() == num_kv_heads); - - size_t group_size = num_q_heads / num_kv_heads; - std::vector o(qo_len * num_q_heads * head_dim); - std::vector att(kv_len); - - tensor_info_t info(qo_len, kv_len, num_q_heads, num_kv_heads, kv_layout, head_dim); - for (size_t qo_head_idx = 0; qo_head_idx < num_q_heads; ++qo_head_idx) { - const size_t kv_head_idx = qo_head_idx / group_size; - for (size_t q_idx = 0; q_idx < qo_len; ++q_idx) { - float max_val = -5e4; - for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { - att[kv_idx] = 0.; - for (size_t feat_idx = 0; feat_idx < head_dim; ++feat_idx) { - att[kv_idx] += float(q[info.get_q_elem_offset(q_idx, qo_head_idx, feat_idx)]) * - float(k[info.get_kv_elem_offset(kv_idx, kv_head_idx, feat_idx)]); - } - - // apply mask - if (causal && kv_idx > kv_len + q_idx - qo_len) { - att[kv_idx] = -5e4; - } - max_val = std::max(max_val, att[kv_idx]); - } - // exp minus max - float denom = 0; - float sm_scale_fused_dequantize_log2 = - sm_scale * float(q_scale[qo_head_idx]) * float(k_scale[kv_head_idx]); - for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { - att[kv_idx] = std::exp(att[kv_idx] * sm_scale_fused_dequantize_log2 - - max_val * sm_scale_fused_dequantize_log2); - denom += att[kv_idx]; - } - - // divide by denom - for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { - att[kv_idx] /= denom; - } - - // Requantize - for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { - att[kv_idx] *= float(p_fp8_scale); - att[kv_idx] = float(dtype_in(att[kv_idx])); - } - - for (size_t feat_idx = 0; feat_idx < head_dim; ++feat_idx) { - float o_float = 0.; - for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { - o_float += att[kv_idx] * float(v[info.get_kv_elem_offset(kv_idx, kv_head_idx, feat_idx)]); - } - // Dequantize - o_float *= float(v_scale[kv_head_idx]) / float(p_fp8_scale); - o[info.get_o_elem_offset(q_idx, qo_head_idx, feat_idx)] = dtype_out(o_float); - } - } - } - - return std::move(o); -} - -} // namespace cpu_reference diff --git a/src/fp8-dev/flashattention_ops.h b/src/fp8-dev/flashattention_ops.h deleted file mode 100644 index 3aa90a1fe..000000000 --- a/src/fp8-dev/flashattention_ops.h +++ /dev/null @@ -1,326 +0,0 @@ -#include "flash.h" -#include "static_switch.h" - -// Ref from FlashAttention3-official Repo -// https://github.com/Dao-AILab/flash-attention/blob/bdf733be55f0b323a8cf7cc6745a81c3f43cd7f0/hopper/flash_api.cpp - -#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") -#define CHECK_SHAPE(x, ...) 
\ - TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), \ - #x " must have shape (" #__VA_ARGS__ ")") -#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") - -void set_params_fprop(Flash_fwd_params& params, - // sizes - const size_t b, const size_t seqlen_q, const size_t seqlen_k, - const size_t seqlen_q_rounded, const size_t seqlen_k_rounded, const size_t h, - const size_t h_k, const size_t d, const size_t d_rounded, - // device pointers - const at::Tensor q, const at::Tensor k, const at::Tensor v, at::Tensor out, - void* cu_seqlens_q_d, void* cu_seqlens_k_d, void* seqused_k, void* p_d, - void* softmax_lse_d, float p_dropout, float softmax_scale, - int window_size_left, int window_size_right, - bool seqlenq_ngroups_swapped = false, bool unpadded_lse = false) { - // Reset the parameters - params = {}; - - params.is_bf16 = q.dtype() == torch::kBFloat16; - params.is_e4m3 = q.dtype() == torch::kFloat8_e4m3fn; - - // Set the pointers and strides. - params.q_ptr = q.data_ptr(); - params.k_ptr = k.data_ptr(); - params.v_ptr = v.data_ptr(); - // All stride are in elements, not bytes. - params.q_row_stride = q.stride(-3); - params.k_row_stride = k.stride(-3); - params.v_row_stride = v.stride(-3); - params.q_head_stride = q.stride(-2); - params.k_head_stride = k.stride(-2); - params.v_head_stride = v.stride(-2); - params.o_ptr = out.data_ptr(); - params.o_row_stride = out.stride(-3); - params.o_head_stride = out.stride(-2); - - if (cu_seqlens_q_d == nullptr) { - params.q_batch_stride = q.stride(0); - params.k_batch_stride = k.stride(0); - params.v_batch_stride = v.stride(0); - params.o_batch_stride = out.stride(0); - if (seqlenq_ngroups_swapped) { - params.q_batch_stride *= seqlen_q; - params.o_batch_stride *= seqlen_q; - } - } - - params.cu_seqlens_q = static_cast(cu_seqlens_q_d); - params.cu_seqlens_k = static_cast(cu_seqlens_k_d); - params.seqused_k = static_cast(seqused_k); - - TORCH_CHECK(bool(params.cu_seqlens_q) == bool(params.cu_seqlens_k), - "cu_seqlens_q and cu_seqlens_k must be both null or non-null"); - - // P = softmax(QK^T) - params.p_ptr = p_d; - - // Softmax sum - params.softmax_lse_ptr = softmax_lse_d; - - // Set the dimensions. - params.b = b; - params.h = h; - params.h_k = h_k; - params.h_h_k_ratio = h / h_k; - params.seqlen_q = seqlen_q; - params.seqlen_k = seqlen_k; - params.seqlen_q_rounded = seqlen_q_rounded; - params.seqlen_k_rounded = seqlen_k_rounded; - params.d = d; - params.d_rounded = d_rounded; - - // Set the different scale values. - constexpr float log2e = 1.44269504088896340736f; - params.scale_softmax = softmax_scale; - params.scale_softmax_log2 = softmax_scale * log2e; - __half scale_softmax_log2_half = __float2half(params.scale_softmax_log2); - __half2 scale_softmax_log2_half2 = __half2(scale_softmax_log2_half, scale_softmax_log2_half); - params.scale_softmax_log2_half2 = reinterpret_cast(scale_softmax_log2_half2); - - // Set this to probability of keeping an element to simplify things. - params.p_dropout = 1.f - p_dropout; - // Convert p from float to int so we don't have to convert the random uint to float to compare. 
- // [Minor] We want to round down since when we do the comparison we use <= instead of < - // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0)); - // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0)); - params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0)); - params.rp_dropout = 1.f / params.p_dropout; - params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax; - TORCH_CHECK(p_dropout < 1.f); -#ifdef FLASHATTENTION_DISABLE_DROPOUT - TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout."); -#endif - - // Causal is the special case where window_size_right == 0 and window_size_left < 0. - // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. - params.is_causal = window_size_left < 0 && window_size_right == 0; - - if (window_size_left < 0 && window_size_right >= 0) { - window_size_left = seqlen_k; - } - if (window_size_left >= 0 && window_size_right < 0) { - window_size_right = seqlen_k; - } - params.window_size_left = window_size_left; - params.window_size_right = window_size_right; - -#ifdef FLASHATTENTION_DISABLE_LOCAL - TORCH_CHECK(params.is_causal || (window_size_left < 0 && window_size_right < 0), - "This flash attention build does not support local attention."); -#endif - - params.is_seqlens_k_cumulative = true; - -#ifdef FLASHATTENTION_DISABLE_UNEVEN_K - TORCH_CHECK(d == d_rounded, - "This flash attention build does not support headdim not being a multiple of 32."); -#endif - - params.unpadded_lse = unpadded_lse; -} - -void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream, bool force_split_kernel = false) { - // HEADDIM_SWITCH(params.d, [&] { - // run_mha_fwd_(params, stream); - // }); - if (!params.is_e4m3) { - if (params.is_bf16) { - if (params.d == 64) { - run_mha_fwd_(params, stream); - } else if (params.d == 128) { - run_mha_fwd_(params, stream); - } else { - run_mha_fwd_(params, stream); - } - } else { - if (params.d == 64) { - run_mha_fwd_(params, stream); - } else if (params.d == 128) { - run_mha_fwd_(params, stream); - } else { - run_mha_fwd_(params, stream); - } - } - } else { - if (params.d == 64) { - run_mha_fwd_(params, stream); - } else if (params.d == 128) { - run_mha_fwd_(params, stream); - } else if (params.d == 256) { - run_mha_fwd_(params, stream); - } - } -} - -std::vector mha_fwd( - at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size - const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size - const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size - c10::optional& out_, // batch_size x seqlen_q x num_heads x head_size - const float softmax_scale, - c10::optional& descale_q_, // 1 - c10::optional& descale_k_, // 1 - c10::optional& descale_v_, // 1 - bool is_causal) { - auto dprops = at::cuda::getCurrentDeviceProperties(); - bool is_sm90 = dprops->major == 9 && dprops->minor == 0; - TORCH_CHECK(is_sm90, "FlashAttention only supports Hopper GPUs or newer."); - - auto q_dtype = q.dtype(); - // TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, - // "FlashAttention only support fp16 and bf16 data type for now"); - // TODO: will add e4m3 later - // TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kFloat8_e4m3fn, - // "FlashAttention only support fp16 and bf16 data type"); - // "FlashAttention only support fp16 and fp8 (e4m3) data type for now"); - TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); 
- TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); - - CHECK_DEVICE(q); - CHECK_DEVICE(k); - CHECK_DEVICE(v); - - TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - - const auto sizes = q.sizes(); - - const int batch_size = sizes[0]; - int seqlen_q = sizes[1]; - int num_heads = sizes[2]; - const int head_size_og = sizes[3]; - const int seqlen_k = k.size(1); - const int num_heads_k = k.size(2); - TORCH_CHECK(batch_size > 0, "batch size must be positive"); - TORCH_CHECK(head_size_og <= 256, - "FlashAttention forward only supports head dimension at most 256"); - TORCH_CHECK(num_heads % num_heads_k == 0, - "Number of heads in key/value must divide number of heads in query"); - - TORCH_CHECK(head_size_og == 64 || head_size_og == 128 || head_size_og == 256, - "Only support head size 64, 128, and 256 for now"); - - CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og); - CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og); - CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og); - - at::Tensor q_padded, k_padded, v_padded; - if (head_size_og % 8 != 0) { - q_padded = torch::nn::functional::pad( - q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); - k_padded = torch::nn::functional::pad( - k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); - v_padded = torch::nn::functional::pad( - v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); - } else { - q_padded = q; - k_padded = k; - v_padded = v; - } - - at::Tensor out; - if (out_.has_value()) { - out = out_.value(); - // TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); - TORCH_CHECK(q_dtype == at::ScalarType::Float8_e4m3fn ? 
(out.dtype() == at::kHalf) - : (out.dtype() == q_dtype), - "Output must have the same dtype as input dtype if dtype is " - "not fp8, or fp16 for fp8 input."); - CHECK_DEVICE(out); - TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); - CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og); - if (head_size_og % 8 != 0) { - out = torch::empty_like(q_padded); - } - } else { - if (q_dtype == at::ScalarType::Float8_e4m3fn) - out = torch::empty_like(q_padded, at::kHalf); - else - out = torch::empty_like(q_padded); - } - - auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - const int head_size = round_multiple(head_size_og, 8); - const int head_size_rounded = round_multiple(head_size, 32); - const int seqlen_q_rounded = round_multiple(seqlen_q, 128); - const int seqlen_k_rounded = round_multiple(seqlen_k, 128); - - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - at::cuda::CUDAGuard device_guard{(char)q.get_device()}; - - auto opts = q.options(); - - auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); - at::Tensor p; - - Flash_fwd_params params; - set_params_fprop(params, batch_size, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, - num_heads, num_heads_k, head_size, head_size_rounded, q_padded, k_padded, - v_padded, out, - /*cu_seqlens_q_d=*/nullptr, - /*cu_seqlens_k_d=*/nullptr, - /*seqused_k=*/nullptr, nullptr, softmax_lse.data_ptr(), - /*p_dropout=*/0.f, softmax_scale, - /*window_size_left=*/-1, - /*window_size_right=*/is_causal ? 0 : -1); - - auto tile_count_semaphore = is_causal ? torch::zeros({1}, opts.dtype(torch::kInt32)) - : torch::empty({1}, opts.dtype(torch::kInt32)); - params.tile_count_semaphore = tile_count_semaphore.data_ptr(); - - if (q_dtype == at::ScalarType::Float8_e4m3fn) { - at::Tensor descale_q, descale_k, descale_v; - if (descale_q_.has_value() && descale_k_.has_value() && descale_k_.has_value()) { - descale_q = descale_q_.value(); - descale_k = descale_k_.value(); - descale_v = descale_v_.value(); - CHECK_DEVICE(descale_q); - CHECK_DEVICE(descale_k); - CHECK_DEVICE(descale_v); - CHECK_SHAPE(descale_q, 1); - CHECK_SHAPE(descale_k, 1); - CHECK_SHAPE(descale_v, 1); - } else { - descale_q = torch::ones({1}, opts.dtype(at::kFloat)); - descale_k = torch::ones({1}, opts.dtype(at::kFloat)); - descale_v = torch::ones({1}, opts.dtype(at::kFloat)); - } - params.descale_q_ptr = descale_q.data_ptr(); - params.descale_k_ptr = descale_k.data_ptr(); - params.descale_v_ptr = descale_v.data_ptr(); - } else { - params.descale_q_ptr = nullptr; - params.descale_k_ptr = nullptr; - params.descale_v_ptr = nullptr; - } - - if (seqlen_k > 0) { - auto stream = at::cuda::getCurrentCUDAStream().stream(); - run_mha_fwd(params, stream); - } else { - // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. 
- out.zero_(); - softmax_lse.fill_(std::numeric_limits::infinity()); - } - - at::Tensor out_padded = out; - if (head_size_og % 8 != 0) { - out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); - if (out_.has_value()) { - out_.value().copy_(out); - } - } - - return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p}; -} diff --git a/src/fp8-dev/flashinfer_ops.cu b/src/fp8-dev/flashinfer_ops.cu deleted file mode 100644 index caf2d383c..000000000 --- a/src/fp8-dev/flashinfer_ops.cu +++ /dev/null @@ -1,155 +0,0 @@ -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -using namespace flashinfer; - -template -void run_fwd_flashinfer(thrust::device_vector& q_d, thrust::device_vector& k_d, - thrust::device_vector& v_d, thrust::device_vector& o_d, - thrust::device_vector& scale_q_d, - thrust::device_vector& scale_k_d, - thrust::device_vector& scale_v_d, int32_t qo_len, int32_t kv_len, - int32_t num_qo_heads, int32_t num_kv_heads, int32_t head_dim, - float sm_scale, MaskMode mask_mode, QKVLayout kv_layout) { - using IdType = int32_t; - using DTypeScale = float; - - constexpr bool USE_SLIDING_WINDOW = false; - using Params = SinglePrefillParams; - using AttentionVariant = DefaultFP8Attention; - - Params params; - params.q_ptr = static_cast(thrust::raw_pointer_cast(q_d.data())); - params.k_ptr = static_cast(thrust::raw_pointer_cast(k_d.data())); - params.v_ptr = static_cast(thrust::raw_pointer_cast(v_d.data())); - params.o_ptr = static_cast(thrust::raw_pointer_cast(o_d.data())); - params.lse_ptr = nullptr; - // q NHD - params.q_stride_n = num_qo_heads * head_dim; - params.q_stride_h = head_dim; - params.o_stride_n = num_qo_heads * head_dim; - params.o_stride_h = head_dim; - if (kv_layout == QKVLayout::kNHD) { - params.k_stride_n = num_kv_heads * head_dim; - params.k_stride_h = head_dim; - params.v_stride_n = num_kv_heads * head_dim; - params.v_stride_h = head_dim; - } else { - // k HND - params.k_stride_h = kv_len * head_dim; - params.k_stride_n = head_dim; - params.v_stride_h = kv_len * head_dim; - params.v_stride_n = head_dim; - } - params.qo_len = qo_len; - params.kv_len = kv_len; - params.num_qo_heads = num_qo_heads; - params.num_kv_heads = num_kv_heads; - params.causal = (mask_mode == MaskMode::kCausal); - params.group_size = params.num_qo_heads / params.num_kv_heads; - params.window_left = 0; - - params.additional_params.scale_q = thrust::raw_pointer_cast(scale_q_d.data()); - params.additional_params.scale_k = thrust::raw_pointer_cast(scale_k_d.data()); - params.additional_params.scale_v = thrust::raw_pointer_cast(scale_v_d.data()); - params.additional_params.sm_scale = sm_scale; - - cudaError_t status; - cudaStream_t stream = 0; - DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { - DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { - status = SingleFP8PrefillWithKVCacheDispatched(params, stream); - }); - }); - if (status != cudaSuccess) { - throw std::runtime_error("Failed to run SingleFP8PrefillWithKVCacheDispatched"); - } -} - -void run_fwd(thrust::device_vector& q_d, - thrust::device_vector& k_d, - thrust::device_vector& v_d, - thrust::device_vector& o_d, thrust::device_vector& scale_q_d, - thrust::device_vector& scale_k_d, thrust::device_vector& scale_v_d, - int32_t qo_len, int32_t kv_len, int32_t num_qo_heads, int32_t num_kv_heads, - int32_t head_dim, float sm_scale, MaskMode mask_mode, QKVLayout kv_layout) { - run_fwd_flashinfer( - q_d, k_d, v_d, o_d, scale_q_d, scale_k_d, scale_v_d, qo_len, kv_len, num_qo_heads, - 
num_kv_heads, head_dim, sm_scale, mask_mode, kv_layout); -} - -template -void run_fwd_flashinfer(thrust::device_vector& q_d, thrust::device_vector& k_d, - thrust::device_vector& v_d, thrust::device_vector& o_d, - int32_t qo_len, int32_t kv_len, int32_t num_qo_heads, int32_t num_kv_heads, - int32_t head_dim, float sm_scale, MaskMode mask_mode, QKVLayout kv_layout) { - using IdType = int32_t; - - constexpr bool USE_SLIDING_WINDOW = false; - using Params = SinglePrefillParams; - using AttentionVariant = StandardAttention; - - Params params; - params.q_ptr = static_cast(thrust::raw_pointer_cast(q_d.data())); - params.k_ptr = static_cast(thrust::raw_pointer_cast(k_d.data())); - params.v_ptr = static_cast(thrust::raw_pointer_cast(v_d.data())); - params.o_ptr = static_cast(thrust::raw_pointer_cast(o_d.data())); - params.lse_ptr = nullptr; - // q NHD - params.q_stride_n = num_qo_heads * head_dim; - params.q_stride_h = head_dim; - params.o_stride_n = num_qo_heads * head_dim; - params.o_stride_h = head_dim; - if (kv_layout == QKVLayout::kNHD) { - params.k_stride_n = num_kv_heads * head_dim; - params.k_stride_h = head_dim; - params.v_stride_n = num_kv_heads * head_dim; - params.v_stride_h = head_dim; - } else { - // k HND - params.k_stride_h = kv_len * head_dim; - params.k_stride_n = head_dim; - params.v_stride_h = kv_len * head_dim; - params.v_stride_n = head_dim; - } - params.qo_len = qo_len; - params.kv_len = kv_len; - params.num_qo_heads = num_qo_heads; - params.num_kv_heads = num_kv_heads; - params.causal = (mask_mode == MaskMode::kCausal); - params.group_size = params.num_qo_heads / params.num_kv_heads; - params.window_left = 0; - params.additional_params.sm_scale = sm_scale; - - cudaError_t status; - cudaStream_t stream = 0; - DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { - DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { - status = SinglePrefillWithKVCacheDispatched(params, stream); - }); - }); - if (status != cudaSuccess) { - throw std::runtime_error("Failed to run SingleFP16PrefillWithKVCacheDispatched"); - } -} - -void run_fwd(thrust::device_vector& q_d, - thrust::device_vector& k_d, - thrust::device_vector& v_d, - thrust::device_vector& o_d, int32_t qo_len, int32_t kv_len, - int32_t num_qo_heads, int32_t num_kv_heads, int32_t head_dim, float sm_scale, - MaskMode mask_mode, QKVLayout kv_layout) { - run_fwd_flashinfer( - q_d, k_d, v_d, o_d, qo_len, kv_len, num_qo_heads, num_kv_heads, head_dim, sm_scale, mask_mode, - kv_layout); -} diff --git a/src/fp8-dev/test_single_prefill_fa3_sm90.cu b/src/fp8-dev/test_single_prefill_fa3_sm90.cu deleted file mode 100644 index 4f8d3c616..000000000 --- a/src/fp8-dev/test_single_prefill_fa3_sm90.cu +++ /dev/null @@ -1,154 +0,0 @@ -/* - * Copyright (c) 2023 by FlashInfer team. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
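The deleted driver above fills the K/V stride fields by hand for the two supported layouts; that convention can be summarized with a small sketch (the struct and function names are illustrative, not part of the deleted file):

#include <cstdint>

struct KVStrides {
  int64_t stride_n;  // step between consecutive tokens within one head
  int64_t stride_h;  // step between consecutive heads
};

// NHD: [kv_len, num_kv_heads, head_dim] -> stride_n = num_kv_heads * head_dim, stride_h = head_dim
// HND: [num_kv_heads, kv_len, head_dim] -> stride_h = kv_len * head_dim,       stride_n = head_dim
static KVStrides make_kv_strides(bool is_nhd, int kv_len, int num_kv_heads, int head_dim) {
  if (is_nhd) return {int64_t(num_kv_heads) * head_dim, head_dim};
  return {head_dim, int64_t(kv_len) * head_dim};
}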
- */ -#include -#include -#include -#include -#include - -#include -#include -#include - -#include "cpu_reference.h" -#include "flashattention_ops.h" -#include "utils.h" -using namespace flashinfer; - -void run_fwd(thrust::device_vector& q_d, - thrust::device_vector& k_d, - thrust::device_vector& v_d, - thrust::device_vector& o_d, thrust::device_vector& scale_q_d, - thrust::device_vector& scale_k_d, thrust::device_vector& scale_v_d, - int32_t qo_len, int32_t kv_len, int32_t num_qo_heads, int32_t num_kv_heads, - int32_t head_dim, float sm_scale, MaskMode mask_mode, QKVLayout kv_layout); - -template -void _TestSingleFP8PrefillKernelCorrectness(int32_t qo_len, int32_t kv_len, int32_t num_qo_heads, - int32_t num_kv_heads, int32_t head_dim, - MaskMode mask_mode, QKVLayout kv_layout, size_t seed, - float rtol = 1e-3, float atol = 1e-3) { - using DTypeScale = float; - float sm_scale = 1.f / std::sqrt(float(head_dim)); - - std::vector q(qo_len * num_qo_heads * head_dim); - std::vector k(kv_len * num_kv_heads * head_dim); - std::vector v(kv_len * num_kv_heads * head_dim); - std::vector o(qo_len * num_qo_heads * head_dim); - std::vector scale_q(num_qo_heads); - std::vector scale_k(num_kv_heads); - std::vector scale_v(num_kv_heads); - - std::vector q_fp32(qo_len * num_qo_heads * head_dim); - std::vector k_fp32(kv_len * num_kv_heads * head_dim); - std::vector v_fp32(kv_len * num_kv_heads * head_dim); - - utils::vec_normal_(q_fp32, 0, 16, seed); - utils::vec_normal_(k_fp32, 0, 16, seed); - utils::vec_normal_(v_fp32, 0, 16, seed); - utils::vec_zero_(o); - - cpu_reference::sym_quant_per_head(q_fp32, q, scale_q, qo_len, num_qo_heads, head_dim, kv_layout, - true); - cpu_reference::sym_quant_per_head(k_fp32, k, scale_k, kv_len, num_kv_heads, head_dim, kv_layout, - false); - cpu_reference::sym_quant_per_head(v_fp32, v, scale_v, kv_len, num_kv_heads, head_dim, kv_layout, - false); - - thrust::device_vector q_d(q); - thrust::device_vector k_d(k); - thrust::device_vector v_d(v); - thrust::device_vector o_d(o); - thrust::device_vector scale_q_d(scale_q); - thrust::device_vector scale_k_d(scale_k); - thrust::device_vector scale_v_d(scale_v); - - run_fwd(q_d, k_d, v_d, o_d, scale_q_d, scale_k_d, scale_v_d, qo_len, kv_len, num_qo_heads, - num_kv_heads, head_dim, sm_scale, mask_mode, kv_layout); - - thrust::host_vector o_flashinfer_copy(o_d); - std::vector o_flashinfer(o_flashinfer_copy.begin(), o_flashinfer_copy.end()); - - /* - Below is FA3 implementation API call - */ - auto device = torch::Device(torch::kCUDA, 0); - auto q_t = torch::from_blob(q.data(), {1, qo_len, num_qo_heads, head_dim}, torch::kFloat8_e4m3fn) - .clone() - .to(device); - auto k_t = torch::from_blob(k.data(), {1, kv_len, num_kv_heads, head_dim}, torch::kFloat8_e4m3fn) - .clone() - .to(device); - auto v_t = torch::from_blob(v.data(), {1, kv_len, num_kv_heads, head_dim}, torch::kFloat8_e4m3fn) - .clone() - .to(device); - auto scale_q_t = std::optional( - torch::from_blob(scale_q.data(), {1}, torch::kFloat).clone().to(device)); - auto scale_k_t = std::optional( - torch::from_blob(scale_k.data(), {1}, torch::kFloat).clone().to(device)); - auto scale_v_t = std::optional( - torch::from_blob(scale_v.data(), {1}, torch::kFloat).clone().to(device)); - auto out_t = std::optional{}; - - bool is_causal = mask_mode == MaskMode::kCausal; - auto o_t_ref = - mha_fwd(q_t, k_t, v_t, out_t, sm_scale, scale_q_t, scale_k_t, scale_v_t, is_causal)[0].to( - torch::Device(torch::kCPU)); - DTypeO* o_ref_ptr = static_cast(o_t_ref.data_ptr()); - std::vector 
o_fa3(o_ref_ptr, o_ref_ptr + o_t_ref.numel()); - - std::vector o_cpu = cpu_reference::single_fp8_mha( - q, k, v, scale_q, scale_k, scale_v, qo_len, kv_len, num_qo_heads, num_kv_heads, head_dim, - sm_scale, is_causal, kv_layout); - - std::vector o_gold = cpu_reference::single_mha( - q_fp32, k_fp32, v_fp32, qo_len, kv_len, num_qo_heads, num_kv_heads, head_dim, sm_scale, - is_causal, kv_layout); - - float fa3_mse = utils::vec_cal_mse_(o_fa3, o_gold); - float finfer_mse = utils::vec_cal_mse_(o_flashinfer, o_gold); - float fcpu_mse = utils::vec_cal_mse_(o_cpu, o_gold); - printf("%d,%d,%d,%d,%d,%d,%f,%f,%f\n", num_qo_heads, num_kv_heads, qo_len, kv_len, head_dim, - int(mask_mode), fcpu_mse, fa3_mse, finfer_mse); -} - -template -void TestSingleFP8PrefillKernelLongContextCorrectness() { - for (size_t seq_len : {128, 256, 512, 1024, 2048, 4096}) { - for (size_t num_qo_heads : {1}) { - for (size_t num_kv_heads : {1}) { - for (size_t head_dim : {64, 128, 256}) { - for (size_t kv_layout : {0}) { - for (size_t mask_mode : {0, 1}) { - for (size_t seed : {600}) - _TestSingleFP8PrefillKernelCorrectness( - seq_len, seq_len, num_qo_heads, num_kv_heads, head_dim, MaskMode(mask_mode), - QKVLayout(kv_layout), seed); - } - } - } - } - } - } -} - -int main() { - printf( - "num_qo_heads,num_kv_heads,qo_len,kv_len,head_dim,mask_mode,cpu_mse,fa3_mse," - "flashinfer_mse\n"); - TestSingleFP8PrefillKernelLongContextCorrectness(); - return 0; -} diff --git a/src/fp8-dev/utils.h b/src/fp8-dev/utils.h deleted file mode 100644 index 083ae6b8a..000000000 --- a/src/fp8-dev/utils.h +++ /dev/null @@ -1,206 +0,0 @@ -/* - * Copyright (c) 2023 by FlashInfer team. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
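The test above quantizes the fp32 inputs with cpu_reference::sym_quant_per_head, which lives in cpu_reference.h and is not shown in this patch. The sketch below is one common formulation of per-head symmetric FP8 (E4M3) quantization, given as an assumption about what that helper does rather than its exact implementation:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Per-head symmetric quantization of an NHD tensor [len, num_heads, head_dim]:
// scale_h = amax_h / 448 (largest finite E4M3 value), q = x / scale_h.
// A real implementation would additionally cast q to the FP8 storage type.
void sym_quant_per_head_sketch(const std::vector<float>& x, std::vector<float>& q,
                               std::vector<float>& scale, int len, int num_heads, int head_dim) {
  constexpr float kE4M3Max = 448.f;
  scale.assign(num_heads, 0.f);
  q.assign(x.size(), 0.f);
  for (int h = 0; h < num_heads; ++h) {
    float amax = 0.f;
    for (int n = 0; n < len; ++n)
      for (int d = 0; d < head_dim; ++d)
        amax = std::max(amax, std::fabs(x[(std::size_t(n) * num_heads + h) * head_dim + d]));
    scale[h] = std::max(amax, 1e-6f) / kE4M3Max;  // guard against all-zero heads
    for (int n = 0; n < len; ++n)
      for (int d = 0; d < head_dim; ++d) {
        const std::size_t i = (std::size_t(n) * num_heads + h) * head_dim + d;
        q[i] = x[i] / scale[h];
      }
  }
}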
- */ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -namespace utils { - -template -float vec_cal_mse_(std::vector& x, std::vector& y) { - float mse = 0.0; - for (size_t i = 0; i < x.size(); ++i) { - float diff = float(x[i]) - float(y[i]); - mse += diff * diff; - } - return mse / x.size(); -} - -template -void vec_normal_(std::vector& vec, float mean = 0.f, float std = 1.f, int seed = 327) { - std::mt19937 gen{seed}; - // std::random_device rd{}; - // std::mt19937 gen{rd()}; - - std::normal_distribution d{mean, std}; - for (size_t i = 0; i < vec.size(); ++i) { - vec[i] = T(cutlass::float_e4m3_t(d(gen))); - } - - if constexpr (sizeof(T) == 1) { - // view vec as uint8_t - uint8_t* data = reinterpret_cast(vec.data()); - // random shuffle - std::shuffle(data, data + vec.size(), gen); - } else { - // random shuffle - std::shuffle(vec.begin(), vec.end(), gen); - } -} - -template -void vec_uniform_(std::vector& vec, float a = 0.f, float b = 1.f) { - std::random_device rd{}; - std::mt19937 gen{rd()}; - std::uniform_real_distribution d{a, b}; - for (size_t i = 0; i < vec.size(); ++i) { - vec[i] = T(d(gen)); - } -} - -template -void vec_zero_(std::vector& vec) { - std::fill(vec.begin(), vec.end(), T(0)); -} - -template -void vec_fill_(std::vector& vec, T val) { - std::fill(vec.begin(), vec.end(), T(val)); -} - -template -void vec_randint_(std::vector& vec, int low, int high) { - std::random_device rd{}; - std::mt19937 gen{rd()}; - std::uniform_int_distribution d{low, high}; - for (size_t i = 0; i < vec.size(); ++i) { - vec[i] = T(d(gen)); - } -} - -template -void vec_randbin_(std::vector& vec) { - std::random_device rd{}; - std::mt19937 gen{rd()}; - std::uniform_int_distribution d{0, 1024}; - for (size_t i = 0; i < vec.size(); ++i) { - vec[i] = T(d(gen) % 2 ? 
-2 : 2); - } -} - -template -size_t vec_bytes(const T& vec) { - return vec.size() * sizeof(typename T::value_type); -} - -template -bool isclose(T a, T b, float rtol = 1e-5, float atol = 1e-8) { - return fabs(a - b) <= (atol + rtol * fabs(b)); -} - -template -std::tuple>, std::vector>> -create_shared_prefix_testcase_data(size_t batch_size, size_t shared_prefix_length, - size_t unique_kv_length, size_t qo_append_length, - size_t num_qo_heads, size_t num_kv_heads, size_t head_dim, - size_t page_size) { - uint32_t num_pages = ((shared_prefix_length + unique_kv_length * batch_size) / page_size); - std::vector shared_k_h(shared_prefix_length * num_kv_heads * head_dim); - std::vector shared_v_h(shared_prefix_length * num_kv_heads * head_dim); - std::vector q_h((batch_size * qo_append_length) * num_qo_heads * head_dim); - - utils::vec_normal_(shared_k_h); - utils::vec_normal_(shared_v_h); - utils::vec_normal_(q_h); - - std::vector qo_indptr{0}; - std::vector kv_indptr_combined_h{0}; - std::vector kv_indptr_unique_h{0}; - std::vector kv_last_page_len_combined_h; - std::vector kv_last_page_len_unique_h; - - for (uint32_t request_id = 0; request_id < batch_size; ++request_id) { - qo_indptr.push_back(qo_indptr.back() + qo_append_length); - kv_indptr_combined_h.push_back(kv_indptr_combined_h.back() + - (shared_prefix_length + unique_kv_length) / page_size); - kv_indptr_unique_h.push_back(kv_indptr_unique_h.back() + unique_kv_length / page_size); - kv_last_page_len_combined_h.push_back(page_size); - kv_last_page_len_unique_h.push_back(page_size); - } - - std::vector kv_indices_combined_h(kv_indptr_combined_h.back()); - std::vector kv_indices_unique_h(kv_indptr_unique_h.back()); - - std::vector kv_data_h(num_pages * 2 * num_kv_heads * page_size * head_dim); - uint32_t page_id = 0; - - for (; page_id < (shared_prefix_length / page_size); page_id++) { - for (uint32_t entry_idx = 0; entry_idx < page_size; entry_idx++) { - for (uint32_t head_idx = 0; head_idx < num_kv_heads; head_idx++) { - std::copy( - shared_k_h.begin() + - ((page_id * page_size + entry_idx) * num_kv_heads + head_idx) * head_dim, - shared_k_h.begin() + - ((page_id * page_size + entry_idx) * num_kv_heads + head_idx + 1) * head_dim, - kv_data_h.begin() + - (((page_id * 2 + 0) * num_kv_heads + head_idx) * page_size + entry_idx) * head_dim); - std::copy( - shared_v_h.begin() + - ((page_id * page_size + entry_idx) * num_kv_heads + head_idx) * head_dim, - shared_v_h.begin() + - ((page_id * page_size + entry_idx) * num_kv_heads + head_idx + 1) * head_dim, - kv_data_h.begin() + - (((page_id * 2 + 1) * num_kv_heads + head_idx) * page_size + entry_idx) * head_dim); - } - } - for (uint32_t request_id = 0; request_id < batch_size; ++request_id) { - kv_indices_combined_h[request_id * ((shared_prefix_length + unique_kv_length) / page_size) + - page_id] = page_id; - } - } - - for (uint32_t request_id = 0; request_id < batch_size; ++request_id) { - for (uint32_t page_iter = 0; page_iter < (unique_kv_length / page_size); - ++page_iter, ++page_id) { - for (uint32_t entry_idx = 0; entry_idx < page_size; entry_idx++) { - for (uint32_t head_idx = 0; head_idx < num_kv_heads; head_idx++) { - std::vector k(head_dim), v(head_dim); - utils::vec_normal_(k); - utils::vec_normal_(v); - std::copy(k.begin(), k.end(), - kv_data_h.begin() + - (((page_id * 2 + 0) * num_kv_heads + head_idx) * page_size + entry_idx) * - head_dim); - std::copy(v.begin(), v.end(), - kv_data_h.begin() + - (((page_id * 2 + 1) * num_kv_heads + head_idx) * page_size + entry_idx) * - head_dim); - 
} - } - kv_indices_combined_h[request_id * ((shared_prefix_length + unique_kv_length) / page_size) + - (shared_prefix_length / page_size) + page_iter] = page_id; - kv_indices_unique_h[request_id * (unique_kv_length / page_size) + page_iter] = page_id; - } - } - return std::make_tuple>, std::vector>>( - {std::move(q_h), std::move(shared_k_h), std::move(shared_v_h), std::move(kv_data_h)}, - {std::move(qo_indptr), std::move(kv_indices_combined_h), std::move(kv_indices_unique_h), - std::move(kv_indptr_combined_h), std::move(kv_indptr_unique_h), - std::move(kv_last_page_len_combined_h), std::move(kv_last_page_len_unique_h)}); -} - -} // namespace utils From 903622518309988e1d3275e7b95ab3daa93f54a6 Mon Sep 17 00:00:00 2001 From: happierpig Date: Wed, 23 Apr 2025 23:10:42 +0000 Subject: [PATCH 05/14] fix: use stsm to directly do column permutation --- .../attention/hopper/quantization/epilogue.cuh | 15 ++++++++------- .../hopper/quantization/kernel_traits.cuh | 2 +- .../hopper/quantization/mainloop_load.cuh | 4 ++-- tests/test_hopper_fp8_attention.py | 8 ++++---- 4 files changed, 15 insertions(+), 14 deletions(-) diff --git a/include/flashinfer/attention/hopper/quantization/epilogue.cuh b/include/flashinfer/attention/hopper/quantization/epilogue.cuh index 501e84b23..88063d1cc 100644 --- a/include/flashinfer/attention/hopper/quantization/epilogue.cuh +++ b/include/flashinfer/attention/hopper/quantization/epilogue.cuh @@ -113,18 +113,20 @@ struct FP8CollectiveEpilogue { auto [qo_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len] = block_coord; - TiledCopyrO rmem_tiled_copy_O; - Tensor sOacc = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutrO{}); - auto rmem_thr_copy_O = rmem_tiled_copy_O.get_thread_slice(thread_idx); + // No need for FP8 column permutation + // as it has been done in the Transpose Phase. 
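The column permutation mentioned here is the 8-bit de-interleave that the V-transpose phase performs with __byte_perm selectors 0x6420 and 0x7531 (the mapping from (0 1 16 17)/(128 129 144 145) to (0 16 128 144)/(1 17 129 145) quoted later in this series). A host-side C++ emulation of that selector arithmetic, for illustration only; the sign-replication mode selected by a nibble's high bit is not modeled:

#include <cstdint>
#include <cstdio>

// Emulates CUDA's __byte_perm(x, y, s): the byte pool is x (bytes 0-3, LSB first)
// followed by y (bytes 4-7); each selector nibble picks one pool byte for the
// corresponding result byte.
static uint32_t byte_perm(uint32_t x, uint32_t y, uint32_t s) {
  uint8_t pool[8];
  for (int i = 0; i < 4; ++i) {
    pool[i] = uint8_t(x >> (8 * i));
    pool[i + 4] = uint8_t(y >> (8 * i));
  }
  uint32_t out = 0;
  for (int i = 0; i < 4; ++i) out |= uint32_t(pool[(s >> (4 * i)) & 0x7]) << (8 * i);
  return out;
}

int main() {
  // Two registers holding interleaved 8-bit elements (0, 1, 16, 17) and (128, 129, 144, 145).
  const uint32_t upper = 0u | (1u << 8) | (16u << 16) | (17u << 24);
  const uint32_t lower = 128u | (129u << 8) | (144u << 16) | (145u << 24);
  const uint32_t evens = byte_perm(upper, lower, 0x6420);  // bytes (0, 16, 128, 144)
  const uint32_t odds = byte_perm(upper, lower, 0x7531);   // bytes (1, 17, 129, 145)
  printf("0x%08x 0x%08x\n", evens, odds);                  // 0x90801000 0x91811101
  return 0;
}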
+ Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{}); + auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma); + auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx); - Tensor taccOsO = rmem_thr_copy_O.partition_D(sOacc); Tensor tOrO_out = convert_type(tOrO); - Tensor taccOrO = make_tensor(tOrO_out.data(), shape(taccOsO)); + Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) // Make sure all WGs have finished reading V cutlass::arch::NamedBarrier::sync(NUM_MMA_THREADS, /*id=*/static_cast(NamedBarriers::kValueEmpty)); - cute::copy(rmem_tiled_copy_O, taccOrO, taccOsO); + cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.lse_ptr), epilogue_params.layout_LSE); Tensor gLSE = get_lse_local_tile_tensor(mLSE, Shape>{}, qo_head_idx, qo_indptr, @@ -151,7 +153,6 @@ struct FP8CollectiveEpilogue { int write_warp_idx = NUM_WARPS - 1; TiledCopyO gmem_tiled_copy_O; - Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{}); write_O(epilogue_params.O_ptr, gmem_tiled_copy_O, epilogue_params.layout_O, select<0, 2>(TileShape_QKD{}), sO, thread_idx, qo_tile_idx, qo_head_idx, qo_indptr, qo_len, write_warp_idx); diff --git a/include/flashinfer/attention/hopper/quantization/kernel_traits.cuh b/include/flashinfer/attention/hopper/quantization/kernel_traits.cuh index 527424457..9bb09f320 100644 --- a/include/flashinfer/attention/hopper/quantization/kernel_traits.cuh +++ b/include/flashinfer/attention/hopper/quantization/kernel_traits.cuh @@ -57,7 +57,7 @@ struct TranposeTraits_64x64 { static_assert(cutlass::sizeof_bits_v == 8); using SmemShapeLDSM = Shape, Shape<_16, _4>>; - using SmemShapeSTSM = Shape, Shape<_8, _8>>; + using SmemShapeSTSM = Shape, Shape<_16, _4>>; using SmemLayoutAtomV = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom{}, TransposeShapeAtom_{})); diff --git a/include/flashinfer/attention/hopper/quantization/mainloop_load.cuh b/include/flashinfer/attention/hopper/quantization/mainloop_load.cuh index d909ac584..82142d488 100644 --- a/include/flashinfer/attention/hopper/quantization/mainloop_load.cuh +++ b/include/flashinfer/attention/hopper/quantization/mainloop_load.cuh @@ -48,8 +48,8 @@ struct SmemTransposeFP8_64x64 { TiledCopyLDSM tiled_copy_ldsm; using stsm_thread_shape = Shape<_4, _1, _8, _4>; - using stsm_value_shape = Shape<_4, _4, _1, _2>; - using stsm_value_stride = Stride<_1, _8, _0, _4>; + using stsm_value_shape = Shape<_4, _4, _2, _1>; + using stsm_value_stride = Stride<_1, _8, _4, _0>; using TiledCopySTSM = decltype(make_tiled_copy(Copy_Atom{}, Layout{}, diff --git a/tests/test_hopper_fp8_attention.py b/tests/test_hopper_fp8_attention.py index 2e90cf6d2..124db3d03 100644 --- a/tests/test_hopper_fp8_attention.py +++ b/tests/test_hopper_fp8_attention.py @@ -67,16 +67,16 @@ def test_single_prefill(seq_len, num_heads, causal, head_dim, dtype): assert not torch.any(torch.isnan(o_ref)) # MSE - mse = torch.mean((o_ref - o_fp8) ** 2) + mse = torch.mean((o_ref.float() - o_fp8.float()) ** 2) print( - f"test_single_prefill (seq_len={seq_len}, num_heads={num_heads}, causal={causal}, head_dim={head_dim}, dtype={dtype}), MSE: {mse:.3f}" + f"test_single_prefill (seq_len={seq_len}, num_heads={num_heads}, causal={causal}, head_dim={head_dim}, 
dtype={dtype}), MSE: {mse:.5f}" ) if __name__ == "__main__": - for seq_len in [512, 1024, 4096, 8192]: + for seq_len in [117, 509, 1011, 2372, 7777]: for num_heads in [24, 32]: - for causal in [False]: + for causal in [True, False]: for head_dim in [64, 128, 256]: for dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: test_single_prefill(seq_len, num_heads, causal, head_dim, dtype) From a7874a4e2fcf2a972cb99b132d9e183b7bbcbe36 Mon Sep 17 00:00:00 2001 From: happierpig Date: Thu, 24 Apr 2025 00:16:05 +0000 Subject: [PATCH 06/14] format --- .../hopper/quantization/kernel_traits.cuh | 76 ++++ .../hopper/quantization/mainloop_load.cuh | 77 +---- .../quantization/mainloop_sparse_load.cuh | 326 ++++++++++++++++++ 3 files changed, 403 insertions(+), 76 deletions(-) create mode 100644 include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh diff --git a/include/flashinfer/attention/hopper/quantization/kernel_traits.cuh b/include/flashinfer/attention/hopper/quantization/kernel_traits.cuh index 9bb09f320..da5b3da96 100644 --- a/include/flashinfer/attention/hopper/quantization/kernel_traits.cuh +++ b/include/flashinfer/attention/hopper/quantization/kernel_traits.cuh @@ -87,6 +87,82 @@ struct TranposeTraits_64x64 { decltype(composition(SmemLayoutDivideVt{}, make_layout(FactoringShapeVt{}))); }; +/* + In-kernel Transpose of smemV into smemVt with ldmatrix.trans & stmatrix. + Note that all magic number corresponds to the /quantization/kernel_traits.cuh setup. + This transpose is not a general transpose, but a specific one for the FP8 MMA_PV: + 1. K-dimension: (2,2,4,4):(1,8,2,16), which adheres to the accum_P's layout + 2. N-dimension: (8,2,4):(2,1,16), which needs repermutation when rmemO -> smemO +*/ +template +struct SmemTransposeFP8_64x64 { + using Element = typename Ktraits::DTypeKV; + using SmemLayoutVTransposeSrc = typename Ktraits::SmemLayoutVTransposeSrc; + using SmemLayoutVtTransposeTgt = typename Ktraits::SmemLayoutVtTransposeTgt; + static_assert(cutlass::sizeof_bits_v == 8); + + using ldsm_thread_shape = Shape<_4, _1, _8, _4>; + using ldsm_value_shape = Shape<_2, _8, _2, _1>; + using ldsm_value_stride = Stride<_2, _4, _1, _0>; + // use trans to do 16bits transpose + // which needs permutation to separate 8bits row and column + using TiledCopyLDSM = + decltype(make_tiled_copy(Copy_Atom{}, Layout{}, + Layout{})); + TiledCopyLDSM tiled_copy_ldsm; + + using stsm_thread_shape = Shape<_4, _1, _8, _4>; + using stsm_value_shape = Shape<_4, _4, _2, _1>; + using stsm_value_stride = Stride<_1, _8, _4, _0>; + + using TiledCopySTSM = + decltype(make_tiled_copy(Copy_Atom{}, Layout{}, + Layout{})); + TiledCopySTSM tiled_copy_stsm; + + template + CUTLASS_DEVICE void _tranpose(SmemTensor&& s_in, SmemTensorOut&& s_out) { + using namespace cute; + + auto tid = threadIdx.x; + auto thr_copy_ldsm = tiled_copy_ldsm.get_thread_slice(tid); + auto thr_copy_stsm = tiled_copy_stsm.get_thread_slice(tid); + + auto tXsX = thr_copy_ldsm.partition_S(s_in); + auto tXrX = make_tensor(shape(tXsX)); + auto tXsX_out = thr_copy_stsm.partition_D(s_out); + + cute::copy(tiled_copy_ldsm, tXsX, tXrX); + auto data = tXrX.data(); + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size(tXrX); n += 8) { + uint32_t* data_32bit = reinterpret_cast(&data[n]); + auto upper = data_32bit[0]; + auto lower = data_32bit[1]; + // select row-major elements. 
+ // from (0 1 16 17) (128 129 144 145) to (0 16 128 144) (1 17 129 145) + // which is (0 1 8 9) + data_32bit[0] = __byte_perm(upper, lower, 0x6420); + data_32bit[1] = __byte_perm(upper, lower, 0x7531); + } + cute::copy(tiled_copy_stsm, tXrX, tXsX_out); + } + + template + CUTLASS_DEVICE void do_transpose(SmemTensor& s_in, SmemTensorOut& s_out, int stage_idx) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < shape<2>(SmemLayoutVTransposeSrc{}); ++j) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < shape<1>(SmemLayoutVTransposeSrc{}); ++i) { + this->_tranpose(flatten(s_in(_, i, j, stage_idx)), flatten(s_out(_, i, j, stage_idx))); + } + } + // For FP8 kernel, all WG threads will arrive for issuing ldmatrix + cutlass::arch::NamedBarrier::sync(Ktraits::NUM_PRODUCER_THREADS, + static_cast(NamedBarriers::kProducerWG) /*id*/); + } +}; + template diff --git a/include/flashinfer/attention/hopper/quantization/mainloop_load.cuh b/include/flashinfer/attention/hopper/quantization/mainloop_load.cuh index 82142d488..c5211ee5e 100644 --- a/include/flashinfer/attention/hopper/quantization/mainloop_load.cuh +++ b/include/flashinfer/attention/hopper/quantization/mainloop_load.cuh @@ -18,87 +18,12 @@ #include "cute/tensor.hpp" #include "cutlass/gemm/collective/collective_builder.hpp" #include "cutlass/pipeline/pipeline.hpp" +#include "kernel_traits.cuh" namespace flashinfer { using namespace cute; -/* - In-kernel Transpose of smemV into smemVt with ldmatrix.trans & stmatrix. - Note that all magic number corresponds to the /quantization/kernel_traits.cuh setup. - This transpose is not a general transpose, but a specific one for the FP8 MMA_PV: - 1. K-dimension: (2,2,4,4):(1,8,2,16), which adheres to the accum_P's layout - 2. N-dimension: (8,2,4):(2,1,16), which needs repermutation when rmemO -> smemO -*/ -template -struct SmemTransposeFP8_64x64 { - using Element = typename Ktraits::DTypeKV; - using SmemLayoutVTransposeSrc = typename Ktraits::SmemLayoutVTransposeSrc; - using SmemLayoutVtTransposeTgt = typename Ktraits::SmemLayoutVtTransposeTgt; - static_assert(cutlass::sizeof_bits_v == 8); - - using ldsm_thread_shape = Shape<_4, _1, _8, _4>; - using ldsm_value_shape = Shape<_2, _8, _2, _1>; - using ldsm_value_stride = Stride<_2, _4, _1, _0>; - // use trans to do 16bits transpose - // which needs permutation to separate 8bits row and column - using TiledCopyLDSM = - decltype(make_tiled_copy(Copy_Atom{}, Layout{}, - Layout{})); - TiledCopyLDSM tiled_copy_ldsm; - - using stsm_thread_shape = Shape<_4, _1, _8, _4>; - using stsm_value_shape = Shape<_4, _4, _2, _1>; - using stsm_value_stride = Stride<_1, _8, _4, _0>; - - using TiledCopySTSM = - decltype(make_tiled_copy(Copy_Atom{}, Layout{}, - Layout{})); - TiledCopySTSM tiled_copy_stsm; - - template - CUTLASS_DEVICE void _tranpose(SmemTensor&& s_in, SmemTensorOut&& s_out) { - using namespace cute; - - auto tid = threadIdx.x; - auto thr_copy_ldsm = tiled_copy_ldsm.get_thread_slice(tid); - auto thr_copy_stsm = tiled_copy_stsm.get_thread_slice(tid); - - auto tXsX = thr_copy_ldsm.partition_S(s_in); - auto tXrX = make_tensor(shape(tXsX)); - auto tXsX_out = thr_copy_stsm.partition_D(s_out); - - cute::copy(tiled_copy_ldsm, tXsX, tXrX); - auto data = tXrX.data(); - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < size(tXrX); n += 8) { - uint32_t* data_32bit = reinterpret_cast(&data[n]); - auto upper = data_32bit[0]; - auto lower = data_32bit[1]; - // select row-major elements. 
- // from (0 1 16 17) (128 129 144 145) to (0 16 128 144) (1 17 129 145) - // which is (0 1 8 9) - data_32bit[0] = __byte_perm(upper, lower, 0x6420); - data_32bit[1] = __byte_perm(upper, lower, 0x7531); - } - cute::copy(tiled_copy_stsm, tXrX, tXsX_out); - } - - template - CUTLASS_DEVICE void do_transpose(SmemTensor& s_in, SmemTensorOut& s_out, int stage_idx) { - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < shape<2>(SmemLayoutVTransposeSrc{}); ++j) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < shape<1>(SmemLayoutVTransposeSrc{}); ++i) { - this->_tranpose(flatten(s_in(_, i, j, stage_idx)), flatten(s_out(_, i, j, stage_idx))); - } - } - // For FP8 kernel, all WG threads will arrive for issuing ldmatrix - cutlass::arch::NamedBarrier::sync(Ktraits::NUM_PRODUCER_THREADS, - static_cast(NamedBarriers::kProducerWG) /*id*/); - } -}; - template struct FP8CollectiveMainloop { using DTypeQ = typename Ktraits::DTypeQ; diff --git a/include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh b/include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh new file mode 100644 index 000000000..6a99a8c73 --- /dev/null +++ b/include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh @@ -0,0 +1,326 @@ +/* + * Copyright (c) 2024 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
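The new FP8 sparse mainloop in mainloop_sparse_load.cuh loads K/V with cp.async through BlockSparseIndexedGather instead of TMA. The addressing it relies on can be sketched on the host as follows; the function name is illustrative and the real gather is built with make_block_sparse_tensor from block_sparse_gather.cuh:

#include <cstdint>

// Resolve one element of a logical KV row for a given request and head.
// kv_indices (already offset by the request's kv_indptr) maps logical row ->
// physical row in the paged pool; the head offset is folded into the base,
// mirroring `K_ptr + kv_head_idx * k_stride_h` plus `kv_indices[i] * k_stride_n`.
template <typename T>
const T* gather_kv_element(const T* kv_base_for_head, const int32_t* kv_indices,
                           int64_t stride_n, int logical_row, int dim_idx) {
  const int32_t physical_row = kv_indices[logical_row];
  return kv_base_for_head + int64_t(physical_row) * stride_n + dim_idx;
}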
+ */ +#ifndef FLASHINFER_ATTENTION_HOPPER_FP8_SPARSE_MAINLOOP_CUH_ +#define FLASHINFER_ATTENTION_HOPPER_FP8_SPARSE_MAINLOOP_CUH_ + +#include +#include +#include +#include + +#include +#include +#include + +#include "../../../math.cuh" +#include "../block_sparse_gather.cuh" +#include "../named_barrier.cuh" +#include "../utils.cuh" +#include "kernel_traits.cuh" + +namespace flashinfer { + +using namespace cute; + +template +struct SparseCollectiveMainloop { + using DTypeQ = typename Ktraits::DTypeQ; + using DTypeKV = typename Ktraits::DTypeKV; + using IdType = typename Ktraits::IdType; + using TileShape_QKD = typename Ktraits::TileShape_QKD; + static constexpr int CTA_Q = get<0>(TileShape_QKD{}); + static constexpr int CTA_KV = get<1>(TileShape_QKD{}); + + static constexpr int NUM_STAGES = Ktraits::NUM_STAGES; + static constexpr int HEAD_DIM = Ktraits::HEAD_DIM; + static constexpr int NUM_COPY_THREADS = cutlass::NumThreadsPerWarpGroup; + + using GmemTiledCopyQ = cute::SM90_TMA_LOAD; + static constexpr auto AlignmentKV = 128 / cutlass::sizeof_bits::value; + using AlignmentTypeKV = cute::uint_byte_t(sizeof(DTypeKV)) * AlignmentKV>; + + using GmemCopyAtomKV = cute::Copy_Atom, DTypeKV>; + using GmemTiledCopyKV = + decltype(cutlass::gemm::collective::detail::make_simt_gmem_tiled_copy< + GmemCopyAtomKV, NUM_COPY_THREADS, AlignmentKV, + cutlass::detail::TagToStrideB_t, + decltype(cute::get<1>(TileShape_QKD{})), decltype(cute::get<2>(TileShape_QKD{}))>()); + + using SmemLayoutQ = typename Ktraits::SmemLayoutQ; + using SmemLayoutK = typename Ktraits::SmemLayoutK; + using SmemLayoutV = typename Ktraits::SmemLayoutV; + using SmemLayoutVt = typename Ktraits::SmemLayoutVt; + + using ShapeT = cute::Shape; + using StrideT = cute::Shape; // (N, D, H) + using LayoutT = cute::Layout; + + using ShapeLseT = cute::Shape; + using StrideLseT = cute::Shape<_1, int64_t>; + using LayoutLseT = cute::Layout; + + using TMA_Q = decltype(make_tma_copy( + GmemTiledCopyQ{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), + repeat_like(StrideT{}, int32_t(0)), StrideT{}), + SmemLayoutQ{}, select<0, 2>(TileShape_QKD{}), _1{})); // no mcast for Q + + static constexpr bool USE_TMA_LOAD_KV = false; + static constexpr int NUM_MMA_THREADS = size(typename Ktraits::TiledMmaQK{}); + using MainloopPipeline = typename Ktraits::MainloopPipeline; + using PipelineParams = typename MainloopPipeline::Params; + using PipelineState = typename MainloopPipeline::PipelineState; + + static constexpr uint32_t TmaTransactionBytesQ = + static_cast(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v / 8); + + static constexpr bool UseSchedulerBarrier = + cutlass::sizeof_bits_v == 8 ? 
HEAD_DIM >= 128 : HEAD_DIM <= 128; + using WarpScheduler = WarpScheduler; + + // Host side kernel arguments + struct Arguments { + DTypeQ const* Q_ptr; + LayoutT layout_Q; + DTypeKV const* K_ptr; + LayoutT layout_K; + DTypeKV const* V_ptr; + LayoutT layout_V; + IdType const* kv_indices; + int window_left; + AdditionalParams additional_params; + }; + + // Device side kernel params + struct Params { + LayoutT layout_Q; + LayoutT layout_K; + LayoutT layout_V; + TMA_Q tma_load_Q; + DTypeKV* K_ptr; + DTypeKV* V_ptr; + IdType* kv_indices; + int window_left; + AdditionalParams additional_params; + }; + + static Params to_underlying_arguments(Arguments const& args) { + Tensor mQ = make_tensor(make_gmem_ptr(args.Q_ptr), args.layout_Q); + TMA_Q tma_load_Q = + make_tma_copy(GmemTiledCopyQ{}, mQ, SmemLayoutQ{}, select<0, 2>(TileShape_QKD{}), _1{}); + return {args.layout_Q, + args.layout_K, + args.layout_V, + tma_load_Q, + const_cast(args.K_ptr), + const_cast(args.V_ptr), + const_cast(args.kv_indices), + args.window_left, + args.additional_params}; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_Q.get_tma_descriptor()); + } + + CUTLASS_DEVICE + int get_num_kv_tiles(Params const& mainloop_params, int q_tile_idx, const int qo_len, + const int kv_len) { + static constexpr int CTA_Q = get<0>(TileShape_QKD{}); + static constexpr int CTA_KV = get<1>(TileShape_QKD{}); + int num_kv_tiles = cute::ceil_div(kv_len, CTA_KV); + if constexpr (CAUSAL) { + num_kv_tiles = std::min(num_kv_tiles, + cute::ceil_div((q_tile_idx + 1) * CTA_Q + kv_len - qo_len, CTA_KV)); + } + + return num_kv_tiles; + } + + template + CUTLASS_DEVICE void load(Params const& mainloop_params, MainloopPipeline pipeline_k, + MainloopPipeline pipeline_v, PipelineState& smem_pipe_write_k, + PipelineState& smem_pipe_write_v, SharedStorage& shared_storage, + Scheduler& scheduler, typename Scheduler::Params const& scheduler_params, + typename Scheduler::WorkTileInfo& work_tile_info, + BlockCoord const& block_coord, int work_idx) { + int thread_idx = threadIdx.x; + int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (thread_idx / 32) % 4, 0); + Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{}); + + Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.layout_Q.shape()); + + auto [q_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len] = block_coord; + + // Prepare the TMA loads + Tensor gQ = get_local_tile_tensor(mQ, select<0, 2>(TileShape_QKD{}), qo_head_idx, qo_indptr, + qo_len)(_, _, q_tile_idx); // (Q, D) + + Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{})); + Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{})); + auto [tQgQ, tQsQ] = + tma_partition(mainloop_params.tma_load_Q, _0{}, Layout<_1>{}, group_modes<0, 2>(sQ_x), + group_modes<0, 2>(gQ_x)); // (TMA), (TMA) + + int num_kv_tiles = get_num_kv_tiles(mainloop_params, q_tile_idx, qo_len, kv_len); + int kv_tile_idx = num_kv_tiles - 1; + int swa_begin_kv_tile_idx = 0; + if constexpr (LEFT_SLIDING_WINDOW) { + swa_begin_kv_tile_idx = get_swa_begin_kv_tile_idx(mainloop_params.window_left, + q_tile_idx, qo_len, kv_len); + } + + constexpr int HEAD_DIM = get<2>(TileShape_QKD{}); + constexpr int CTA_KV = 
get<1>(TileShape_QKD{}); + auto indexed_gather = BlockSparseIndexedGather(mainloop_params.kv_indices + kv_indptr); + + Tensor mK = make_block_sparse_tensor( // (kv_len, D) + make_gmem_ptr(mainloop_params.K_ptr + kv_head_idx * stride<2>(mainloop_params.layout_K)), + make_shape(kv_len, HEAD_DIM), stride<0>(mainloop_params.layout_K), indexed_gather); + Tensor mV = make_block_sparse_tensor( // (kv_len, D) + make_gmem_ptr(mainloop_params.V_ptr + kv_head_idx * stride<2>(mainloop_params.layout_V)), + make_shape(kv_len, HEAD_DIM), stride<0>(mainloop_params.layout_V), indexed_gather); + + Tensor gK = local_tile(mK, select<1, 2>(TileShape_QKD{}), make_coord(_, _0{})); // (KV, D, kv) + Tensor gV = local_tile(mV, select<1, 2>(TileShape_QKD{}), make_coord(_, _0{})); // (KV, D, kv) + Tensor cKV = cute::make_identity_tensor(gK.shape()); + + GmemTiledCopyKV gmem_tiled_copy_kv; + auto gmem_thr_copy_kv = gmem_tiled_copy_kv.get_slice(thread_idx); + + Tensor tKgK = gmem_thr_copy_kv.partition_S(gK); // (CPY, CPY_KV, CPY_D, kv) + Tensor tKsK = gmem_thr_copy_kv.partition_D(sK); // (CPY, CPY_KV, CPY_D, PIPE) + Tensor tVgV = gmem_thr_copy_kv.partition_S(gV); // (CPY, CPY_KV, CPY_D, kv) + Tensor tVsV = gmem_thr_copy_kv.partition_D(sV); // (CPY, CPY_KV, CPY_D, PIPE) + Tensor tKVcKV = gmem_thr_copy_kv.partition_D(cKV); // (CPY, CPY_KV, CPY_D) + Tensor tKVcKVGroup = flatten_1(tKVcKV); // (CPY, (CPY_KV, CPY_D)) + + int valid_last_kv_tile_size = std::min(kv_len - kv_tile_idx * CTA_KV, CTA_KV); + auto predicate_fn = [&](auto coords) { + auto s_coords = tKVcKVGroup(_0{}, coords); + return elem_less(get<0>(s_coords), valid_last_kv_tile_size); + }; + + // load last k-tile + { + pipeline_k.producer_acquire(smem_pipe_write_k); + Tensor tKgKiGroup = flatten_1(tKgK(_, _, _, kv_tile_idx)); // (CPY, (CPY_KV, CPY_D)) + Tensor tKsKiGroup = + flatten_1(tKsK(_, _, _, smem_pipe_write_k.index())); // (CPY, (CPY_KV, CPY_D)) + copy_if(gmem_tiled_copy_kv, predicate_fn, tKgKiGroup, tKsKiGroup); + + pipeline_k.producer_commit(smem_pipe_write_k, cutlass::arch::cpasync_barrier_arrive); + ++smem_pipe_write_k; + } + + // load Q tile + if (warp_idx_in_warpgroup == 0) { + cutlass::arch::NamedBarrier::sync(NUM_MMA_THREADS + cutlass::NumThreadsPerWarp, + static_cast(NamedBarriers::kQueryEmpty)); + + int lane_predicate = cute::elect_one_sync(); + if (lane_predicate) { + shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ); + copy(mainloop_params.tma_load_Q.with( + reinterpret_cast( + shared_storage.barrier_Q), + /*mcast_mask=*/0), + tQgQ, tQsQ); + } + } + + shared_storage.barrier_O.wait((work_idx + 1) % 2); + + if (kv_tile_idx == swa_begin_kv_tile_idx) { + pipeline_v.producer_acquire(smem_pipe_write_v); + Tensor tVgViGroup = flatten_1(tVgV(_, _, _, kv_tile_idx)); // (CPY, (CPY_KV, CPY_D)) + Tensor tVsViGroup = + flatten_1(tVsV(_, _, _, smem_pipe_write_v.index())); // (CPY, (CPY_KV, CPY_D)) + copy_if(gmem_tiled_copy_kv, predicate_fn, tVgViGroup, tVsViGroup); + + pipeline_v.producer_commit(smem_pipe_write_v, cutlass::arch::cpasync_barrier_arrive); + ++smem_pipe_write_v; + } else { + // load second last k-tile and last v-tile + pipeline_k.producer_acquire(smem_pipe_write_k); + Tensor tKgKi = tKgK(_, _, _, kv_tile_idx - 1); // (CPY, CPY_KV, CPY_D) + Tensor tKsKi = tKsK(_, _, _, smem_pipe_write_k.index()); // (CPY, CPY_KV, CPY_D) + copy(gmem_tiled_copy_kv, tKgKi, tKsKi); + + pipeline_k.producer_commit(smem_pipe_write_k, cutlass::arch::cpasync_barrier_arrive); + ++smem_pipe_write_k; + + pipeline_v.producer_acquire(smem_pipe_write_v); + 
Tensor tVgViGroup = flatten_1(tVgV(_, _, _, kv_tile_idx)); // (CPY, (CPY_KV, CPY_D)) + Tensor tVsViGroup = + flatten_1(tVsV(_, _, _, smem_pipe_write_v.index())); // (CPY, (CPY_KV, CPY_D)) + copy_if(gmem_tiled_copy_kv, predicate_fn, tVgViGroup, tVsViGroup); + + pipeline_v.producer_commit(smem_pipe_write_v, cutlass::arch::cpasync_barrier_arrive); + --kv_tile_idx; + ++smem_pipe_write_v; + + // load remaining k/v tiles +#pragma unroll 2 + for (; kv_tile_idx > swa_begin_kv_tile_idx; --kv_tile_idx) { + pipeline_k.producer_acquire(smem_pipe_write_k); + + Tensor tKgKi = tKgK(_, _, _, kv_tile_idx - 1); // (CPY, CPY_KV, CPY_D) + Tensor tKsKi = tKsK(_, _, _, smem_pipe_write_k.index()); // (CPY, CPY_KV, CPY_D) + copy(gmem_tiled_copy_kv, tKgKi, tKsKi); + + pipeline_k.producer_commit(smem_pipe_write_k, cutlass::arch::cpasync_barrier_arrive); + ++smem_pipe_write_k; + + pipeline_v.producer_acquire(smem_pipe_write_v); + Tensor tVgVi = tVgV(_, _, _, kv_tile_idx); // (CPY, CPY_KV, CPY_D) + Tensor tVsVi = tVsV(_, _, _, smem_pipe_write_v.index()); // (CPY, CPY_KV, CPY_D) + copy(gmem_tiled_copy_kv, tVgVi, tVsVi); + + pipeline_v.producer_commit(smem_pipe_write_v, cutlass::arch::cpasync_barrier_arrive); + ++smem_pipe_write_v; + } + scheduler.prefetch_next_work(scheduler_params, work_tile_info); + + // load first v tile + { + pipeline_v.producer_acquire(smem_pipe_write_v); + Tensor tVgVi = tVgV(_, _, _, 0); // (CPY, (CPY_KV, CPY_D)) + Tensor tVsVi = tVsV(_, _, _, smem_pipe_write_v.index()); // (CPY, (CPY_KV, CPY_D)) + copy(gmem_tiled_copy_kv, tVgVi, tVsVi); + pipeline_v.producer_commit(smem_pipe_write_v, cutlass::arch::cpasync_barrier_arrive); + ++smem_pipe_write_v; + } + } + + scheduler.broadcast_next_work(work_tile_info); + } + + CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline_k, MainloopPipeline pipeline_v, + PipelineState& smem_pipe_write_k, + PipelineState& smem_pipe_write_v) { + pipeline_k.producer_tail(smem_pipe_write_k); + pipeline_v.producer_tail(smem_pipe_write_v); + } +}; + +} // namespace flashinfer + +#endif // FLASHINFER_ATTENTION_HOPPER_FP8_SPARSE_MAINLOOP_CUH_ From 3ba5c57c035f60109b528814710bb331ac9139b9 Mon Sep 17 00:00:00 2001 From: happierpig Date: Thu, 24 Apr 2025 05:34:25 +0000 Subject: [PATCH 07/14] add fp8 v tranpose into sparse mainloop --- ...h_prefill_fp8_paged_sm90_kernel_inst.jinja | 15 ++ ..._prefill_fp8_ragged_sm90_kernel_inst.jinja | 1 + csrc/batch_prefill_fp8_sm90.cu | 179 ++++++++++++++++++ flashinfer/jit/attention/pytorch.py | 50 +++-- flashinfer/prefill.py | 65 +++++-- flashinfer/sparse.py | 47 ++++- .../quantization/mainloop_sparse_load.cuh | 167 +++++++++------- .../hopper/quantization/prefill_sm90.cuh | 120 +++++++++++- 8 files changed, 535 insertions(+), 109 deletions(-) create mode 100644 csrc/batch_prefill_fp8_paged_sm90_kernel_inst.jinja create mode 100644 csrc/batch_prefill_fp8_ragged_sm90_kernel_inst.jinja create mode 100644 csrc/batch_prefill_fp8_sm90.cu diff --git a/csrc/batch_prefill_fp8_paged_sm90_kernel_inst.jinja b/csrc/batch_prefill_fp8_paged_sm90_kernel_inst.jinja new file mode 100644 index 000000000..d7ead0bdf --- /dev/null +++ b/csrc/batch_prefill_fp8_paged_sm90_kernel_inst.jinja @@ -0,0 +1,15 @@ +#include +#include "batch_prefill_sm90_config.inc" + +namespace flashinfer { + +{% for same_scheduler_for_all_heads in ["true", "false"] %} +template cudaError_t BatchFP8PrefillWithPagedKVCacheDispatched + <{{ head_dim_qk }}, + {{ mask_mode }}, + /*USE_SLIDING_WINDOW=*/{{ use_sliding_window }}, + /*SAME_SCHEDULER_FOR_ALL_HEADS=*/{{ 
same_scheduler_for_all_heads }}, + {{ variant_name }}, PagedParams>(PagedParams& params, cudaStream_t stream); +{% endfor %} + +}; // namespace flashinfer diff --git a/csrc/batch_prefill_fp8_ragged_sm90_kernel_inst.jinja b/csrc/batch_prefill_fp8_ragged_sm90_kernel_inst.jinja new file mode 100644 index 000000000..8225edbb0 --- /dev/null +++ b/csrc/batch_prefill_fp8_ragged_sm90_kernel_inst.jinja @@ -0,0 +1 @@ +// TODO: Not implemented yet diff --git a/csrc/batch_prefill_fp8_sm90.cu b/csrc/batch_prefill_fp8_sm90.cu new file mode 100644 index 000000000..6bf8643dd --- /dev/null +++ b/csrc/batch_prefill_fp8_sm90.cu @@ -0,0 +1,179 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include + +#include "batch_prefill_sm90_config.inc" +#include "pytorch_conversion_utils.h" +#include "pytorch_extension_utils.h" + +namespace flashinfer { + +template +cudaError_t BatchFP8PrefillWithPagedKVCacheDispatched(Params& params, cudaStream_t stream); + +} // namespace flashinfer + +using namespace flashinfer; + +at::Tensor BatchPrefillWithKVCacheSM90Plan( + at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, + at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr, + at::Tensor kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads, + int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk, + int64_t head_dim_vo, bool causal) { + size_t float_workspace_size_in_bytes = + float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); + size_t int_workspace_size_in_bytes = + int_workspace_buffer.size(0) * int_workspace_buffer.element_size(); + + flashinfer::PrefillPlanSM90Info plan_info; + + const c10::cuda::OptionalCUDAGuard device_guard(float_workspace_buffer.device()); + cudaStream_t stream = c10::cuda::getCurrentCUDAStream(); + + cudaError_t status = + PrefillSM90Plan(float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes, + int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(), + int_workspace_size_in_bytes, plan_info, qo_indptr.data_ptr(), + kv_indptr.data_ptr(), kv_len_arr.data_ptr(), total_num_rows, + batch_size, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo, page_size, + causal, enable_cuda_graph, /*sizeof_dtype_o=*/2, stream); + + TORCH_CHECK(status == cudaSuccess, + "PrefillSM90Plan failed with error: ", cudaGetErrorString(status)); + + return vec_to_tensor(plan_info.ToVector()); +} + +void BatchPrefillWithRaggedKVCacheSM90Run(at::Tensor float_workspace_buffer, + at::Tensor int_workspace_buffer, at::Tensor plan_info_vec, + at::Tensor q, at::Tensor k, at::Tensor v, + at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor o, + std::optional maybe_lse, + int64_t mask_mode_code, int64_t layout, + int64_t window_left ADDITIONAL_FUNC_PARAMS) { + return; // TODO: Implement this function +} + +void BatchPrefillWithPagedKVCacheSM90Run( + at::Tensor 
float_workspace_buffer, at::Tensor int_workspace_buffer, at::Tensor plan_info_vec, + at::Tensor q, at::Tensor paged_k_cache, at::Tensor paged_v_cache, at::Tensor qo_indptr, + at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, + at::Tensor o, std::optional maybe_lse, int64_t mask_mode_code, int64_t layout, + int64_t window_left ADDITIONAL_FUNC_PARAMS) { + PrefillPlanSM90Info plan_info; + plan_info.FromVector(tensor_to_vec(plan_info_vec)); + + if (maybe_lse) { + const auto& lse = *maybe_lse; + TORCH_CHECK(lse.size(0) == q.size(0), lse.size(0), q.size(0)); + TORCH_CHECK(lse.size(1) == q.size(1), lse.size(1), q.size(1)); + } + QKVLayout kv_layout = static_cast(layout); + int64_t num_kv_heads, page_size; + int64_t head_dim_qk = q.size(2); + int64_t head_dim_vo = paged_v_cache.size(3); + if (kv_layout == QKVLayout::kHND) { + num_kv_heads = paged_k_cache.size(1); + page_size = paged_k_cache.size(2); + } else { + page_size = paged_k_cache.size(1); + num_kv_heads = paged_k_cache.size(2); + } + + void* float_buffer_ptr = float_workspace_buffer.data_ptr(); + void* int_buffer_ptr = int_workspace_buffer.data_ptr(); + + auto q_scalar_type = q.scalar_type(); + auto kv_scalar_type = paged_k_cache.scalar_type(); + + const c10::cuda::OptionalCUDAGuard device_guard(float_workspace_buffer.device()); + cudaStream_t stream = c10::cuda::getCurrentCUDAStream(); + const MaskMode mask_mode = static_cast(mask_mode_code); + bool use_swa = window_left != -1; + + DISPATCH_context( + DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, USE_SLIDING_WINDOW, + USE_LOGITS_SOFT_CAP, AttentionVariant, RaggedParams, PagedParams, [&] { + PagedParams params; + + params.q_ptr = static_cast(q.data_ptr()); + params.k_ptr = static_cast(paged_k_cache.data_ptr()); + params.v_ptr = static_cast(paged_v_cache.data_ptr()); + params.o_ptr = static_cast(o.data_ptr()); + params.lse_ptr = maybe_lse ? 
static_cast(maybe_lse->data_ptr()) : nullptr; + params.q_stride_n = q.stride(0); + params.q_stride_h = q.stride(1); + params.o_stride_n = o.stride(0); + params.o_stride_h = o.stride(1); + if (kv_layout == QKVLayout::kNHD) { + // (num_pages, page_size, num_heads, head_dim) + params.k_stride_n = paged_k_cache.stride(1); + params.k_stride_h = paged_k_cache.stride(2); + params.v_stride_n = paged_v_cache.stride(1); + params.v_stride_h = paged_v_cache.stride(2); + } else { + // (num_pages, num_heads, page_size, head_dim) + params.k_stride_h = paged_k_cache.stride(1); + params.k_stride_n = paged_k_cache.stride(2); + params.v_stride_h = paged_v_cache.stride(1); + params.v_stride_n = paged_v_cache.stride(2); + } + params.nnz_qo = q.size(0); + params.num_qo_heads = q.size(1); + params.num_kv_heads = num_kv_heads; + params.group_size = params.num_qo_heads / num_kv_heads; + params.page_size = page_size; + params.window_left = window_left; + params.causal = mask_mode_code == 1; + params.qo_tile_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_tile_indices_offset); + params.qo_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_indptr_offset); + params.kv_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_indptr_offset); + params.qo_lens = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_len_offset); + params.kv_lens = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_len_offset); + params.head_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.head_indices_offset); + params.work_indptr = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.work_indptr_offset); + params.kv_indices = static_cast(paged_kv_indices.data_ptr()); + + ADDITIONAL_PARAMS_SETTER + + // Not support various head_dim for now + static_assert(HEAD_DIM_QK == HEAD_DIM_VO, "head_dim_qk and head_dim_vo should be the same"); + // Currently only support same quantization precision + static_assert(std::is_same_v); + + bool same_schedule_for_all_heads = plan_info.same_schedule_for_all_heads; + DISPATCH_BOOL(same_schedule_for_all_heads, SAME_SCHEDULER_FOR_ALL_HEADS, [&] { + cudaError_t status = + BatchFP8PrefillWithPagedKVCacheDispatched(params, stream); + TORCH_CHECK(status == cudaSuccess, + "BatchPrefillWithPagedKVCacheSM90Run failed with error: ", + cudaGetErrorString(status)); + return true; + }); + }); +} diff --git a/flashinfer/jit/attention/pytorch.py b/flashinfer/jit/attention/pytorch.py index fc497377c..39cc6cc75 100644 --- a/flashinfer/jit/attention/pytorch.py +++ b/flashinfer/jit/attention/pytorch.py @@ -746,7 +746,11 @@ def gen_batch_prefill_module( use_fp16_qk_reduction, ) + # use `fp8_enabled` flag to use separate kernel template + fp8_enabled = "e4m3" in uri or "e5m2" in uri + if backend == "fa2": + assert not fp8_enabled, "fp8 is not supported in fa2 backend" additional_tensor_names = [ "maybe_custom_mask", "maybe_mask_indptr", @@ -767,12 +771,20 @@ def gen_batch_prefill_module( variant_name = f"DefaultAttention" variant_decl = f"#include" else: - additional_tensor_names = [] - additional_tensor_dtypes = [] - additional_scalar_names = ["logits_soft_cap", "sm_scale"] - additional_scalar_dtypes = ["double", "double"] - variant_name = f"DefaultAttention<{str(use_logits_soft_cap).lower()}>" - variant_decl = f"#include" + if not fp8_enabled: + additional_tensor_names = [] + additional_tensor_dtypes = [] + additional_scalar_names = ["logits_soft_cap", "sm_scale"] + additional_scalar_dtypes = ["double", "double"] + variant_name = f"DefaultAttention<{str(use_logits_soft_cap).lower()}>" + variant_decl = 
f"#include" + else: + additional_tensor_names = ["scale_q", "scale_k", "scale_v"] + additional_tensor_dtypes = ["float", "float", "float"] + additional_scalar_names = ["sm_scale"] + additional_scalar_dtypes = ["double"] + variant_name = f"DefaultFP8Attention" + variant_decl = f"#include" return gen_customize_batch_prefill_module( backend, @@ -793,6 +805,7 @@ def gen_batch_prefill_module( use_sliding_window=use_sliding_window, use_logits_soft_cap=use_logits_soft_cap, use_fp16_qk_reduction=use_fp16_qk_reduction, + fp8_enabled=fp8_enabled, ) @@ -1142,6 +1155,7 @@ def gen_customize_batch_prefill_module( use_sliding_window: bool = False, use_logits_soft_cap: bool = False, use_fp16_qk_reduction: bool = False, + fp8_enabled: bool = False, ): kwargs = { "variant_decl": variant_decl, @@ -1241,19 +1255,23 @@ def gen_customize_batch_prefill_module( ) ) - with open( - FLASHINFER_CSRC_DIR / "batch_prefill_sm90_customize_config.jinja" - ) as f: + _file_config = "batch_prefill_sm90_customize_config.jinja" + if fp8_enabled: + _file_paged_kernel_inst = "batch_prefill_fp8_paged_sm90_kernel_inst.jinja" + _file_ragged_kernel_inst = "batch_prefill_fp8_ragged_sm90_kernel_inst.jinja" + _file_csrc = "batch_prefill_fp8_sm90.cu" + else: + _file_paged_kernel_inst = "batch_prefill_paged_sm90_kernel_inst.jinja" + _file_ragged_kernel_inst = "batch_prefill_ragged_sm90_kernel_inst.jinja" + _file_csrc = "batch_prefill_sm90.cu" + + with open(FLASHINFER_CSRC_DIR / _file_config) as f: config_templ = jinja2.Template(f.read()) - with open( - FLASHINFER_CSRC_DIR / "batch_prefill_paged_sm90_kernel_inst.jinja" - ) as f: + with open(FLASHINFER_CSRC_DIR / _file_paged_kernel_inst) as f: paged_kernel_inst_templ = jinja2.Template(f.read()) - with open( - FLASHINFER_CSRC_DIR / "batch_prefill_ragged_sm90_kernel_inst.jinja" - ) as f: + with open(FLASHINFER_CSRC_DIR / _file_ragged_kernel_inst) as f: ragged_kernel_inst_templ = jinja2.Template(f.read()) kwargs |= { @@ -1284,7 +1302,7 @@ def gen_customize_batch_prefill_module( write_if_different(dest_path, source) for filename in [ - "batch_prefill_sm90.cu", + _file_csrc, "batch_prefill_sm90_jit_pybind.cu", ]: src_path = FLASHINFER_CSRC_DIR / filename diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index 69ce84d21..bb463cf03 100644 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -358,10 +358,14 @@ def paged_run( maybe_alibi_slopes: Optional[torch.Tensor], logits_soft_cap: float, sm_scale: float, + scale_q: Optional[torch.Tensor], + scale_k: Optional[torch.Tensor], + scale_v: Optional[torch.Tensor], rope_scale: float, rope_theta: float, ) -> None: if backend == "fa2": + assert not is_float8(q) paged_run_func( float_workspace_buffer, int_workspace_buffer, @@ -387,25 +391,48 @@ def paged_run( 1.0 / rope_theta, # rope_rcp_theta ) else: - paged_run_func( - float_workspace_buffer, - int_workspace_buffer, - plan_info_vec, - q, - paged_k_cache, - paged_v_cache, - qo_indptr, - paged_kv_indptr, - paged_kv_indices, - paged_kv_last_page_len, - o, - maybe_lse, - mask_mode, - layout, - window_left, - logits_soft_cap, - sm_scale, - ) + if not is_float8(q): + paged_run_func( + float_workspace_buffer, + int_workspace_buffer, + plan_info_vec, + q, + paged_k_cache, + paged_v_cache, + qo_indptr, + paged_kv_indptr, + paged_kv_indices, + paged_kv_last_page_len, + o, + maybe_lse, + mask_mode, + layout, + window_left, + logits_soft_cap, + sm_scale, + ) + else: + paged_run_func( + float_workspace_buffer, + int_workspace_buffer, + plan_info_vec, + q, + paged_k_cache, + paged_v_cache, + 
qo_indptr, + paged_kv_indptr, + paged_kv_indices, + paged_kv_last_page_len, + o, + maybe_lse, + mask_mode, + layout, + window_left, + scale_q, + scale_k, + scale_v, + sm_scale, + ) return o @register_fake_op(f"flashinfer::{uri}_paged_run") diff --git a/flashinfer/sparse.py b/flashinfer/sparse.py index 1215be38b..891593d1c 100644 --- a/flashinfer/sparse.py +++ b/flashinfer/sparse.py @@ -32,6 +32,7 @@ _get_cache_alibi_slopes_buf, canonicalize_torch_dtype, determine_attention_backend, + is_float8, ) @@ -206,6 +207,7 @@ def plan( rope_theta: Optional[float] = None, q_data_type: Union[str, torch.dtype] = "float16", kv_data_type: Optional[Union[str, torch.dtype]] = None, + o_data_type: Union[str, torch.dtype] = "float16", non_blocking: bool = True, ) -> None: r"""Create auxiliary data structures for block sparse attention. @@ -268,6 +270,9 @@ def plan( The data type of the query tensor. kv_data_type : Optional[Union[str, torch.dtype]] The data type of the key/value tensor. If None, will be set to :attr:`q_data_type`. + o_data_type : str, optional + The data type of the output tensor. Default is ``half``. As output dtype cannot + be inferred by input dtype in quantization non_blocking : bool Whether to copy the input tensors to the device asynchronously, defaults to ``True``. @@ -284,6 +289,7 @@ def plan( if kv_data_type is None: kv_data_type = q_data_type kv_data_type = canonicalize_torch_dtype(kv_data_type) + self._o_dtype = canonicalize_torch_dtype(o_data_type) if logits_soft_cap is None: logits_soft_cap = 0.0 @@ -352,7 +358,7 @@ def plan( self._cached_module = get_batch_decode_module( q_data_type, kv_data_type, - q_data_type, + o_data_type, indptr.dtype, head_dim, head_dim, @@ -395,7 +401,7 @@ def plan( get_module_args = ( q_data_type, kv_data_type, - q_data_type, + self._o_dtype, indptr.dtype, head_dim, # head_dim_qk head_dim, # head_dim_vo @@ -459,6 +465,9 @@ def forward( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + scale_q: Optional[torch.Tensor] = None, + scale_k: Optional[torch.Tensor] = None, + scale_v: Optional[torch.Tensor] = None, pos_encoding_mode: str = "NONE", use_fp16_qk_reduction: bool = False, logits_soft_cap: Optional[float] = None, @@ -473,13 +482,16 @@ def forward( self._sm_scale = sm_scale self._rope_scale = rope_scale self._rope_theta = rope_theta - return self.run(q, k, v) + return self.run(q, k, v, scale_q, scale_k, scale_v) def run( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + scale_q: Optional[torch.Tensor] = None, + scale_k: Optional[torch.Tensor] = None, + scale_v: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, return_lse: bool = False, @@ -494,6 +506,15 @@ def run( The key tensor with shape ``(N, num_kv_heads, head_dim)``. v : torch.Tensor The value tensor with shape ``(N, num_kv_heads, head_dim)``. + scale_q : Optional[torch.Tensor] + The scale tensor for query, per-head quantization with shape: ``[num_qo_heads]``. + Used with FP8 Quantization. If not provided, will be set to ``1.0``. + scale_k : Optional[torch.Tensor] + The scale tensor for key, per-head quantization with shape: ``[num_kv_heads]``. + Used with FP8 Quantization. If not provided, will be set to ``1.0``. + scale_v : Optional[torch.Tensor] + The scale tensor for value, per-head quantization with shape: ``[num_kv_heads]``. + Used with FP8 Quantization. If not provided, will be set to ``1.0``. out : Optional[torch.Tensor] The output tensor, if not provided, will be allocated internally. 
lse : Optional[torch.Tensor] @@ -541,9 +562,22 @@ def run( ) if out is None: - out = torch.empty_like(q) + out = torch.empty_like(q, dtype=self._o_dtype) else: - _check_shape_dtype_device(out, q.shape, q.dtype, q.device, "out") + _check_shape_dtype_device(out, q.shape, self._o_dtype, q.device, "out") + + if is_float8(q): + assert q.dtype == k.dtype == v.dtype + assert q.shape[-1] == k.shape[-1] == v.shape[-1] + assert self._backend == "fa3" and self._use_tensor_cores + + if scale_q is None: + scale_q = torch.ones(q.shape[1], dtype=torch.float32, device=q.device) + if scale_k is None: + scale_k = torch.ones(k.shape[1], dtype=torch.float32, device=q.device) + if scale_v is None: + scale_v = torch.ones(v.shape[1], dtype=torch.float32, device=q.device) + if self._use_tensor_cores: if self._backend == "fa3": sparse_indices = block_sparse_indices_to_vector_sparse_offsets( @@ -582,6 +616,9 @@ def run( _get_cache_alibi_slopes_buf(q.shape[1], self.device), logits_soft_cap, sm_scale, + scale_q, + scale_k, + scale_v, rope_scale, rope_theta, ) diff --git a/include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh b/include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh index 6a99a8c73..d48af2e12 100644 --- a/include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh +++ b/include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh @@ -36,7 +36,7 @@ namespace flashinfer { using namespace cute; template -struct SparseCollectiveMainloop { +struct FP8SparseCollectiveMainloop { using DTypeQ = typename Ktraits::DTypeQ; using DTypeKV = typename Ktraits::DTypeKV; using IdType = typename Ktraits::IdType; @@ -46,16 +46,17 @@ struct SparseCollectiveMainloop { static constexpr int NUM_STAGES = Ktraits::NUM_STAGES; static constexpr int HEAD_DIM = Ktraits::HEAD_DIM; - static constexpr int NUM_COPY_THREADS = cutlass::NumThreadsPerWarpGroup; + static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS; using GmemTiledCopyQ = cute::SM90_TMA_LOAD; static constexpr auto AlignmentKV = 128 / cutlass::sizeof_bits::value; using AlignmentTypeKV = cute::uint_byte_t(sizeof(DTypeKV)) * AlignmentKV>; + // Use ZFILL for out-of-bound V loading (avoid nan) using GmemCopyAtomKV = cute::Copy_Atom, DTypeKV>; using GmemTiledCopyKV = decltype(cutlass::gemm::collective::detail::make_simt_gmem_tiled_copy< - GmemCopyAtomKV, NUM_COPY_THREADS, AlignmentKV, + GmemCopyAtomKV, Ktraits::NUM_PRODUCER_THREADS, AlignmentKV, cutlass::detail::TagToStrideB_t, decltype(cute::get<1>(TileShape_QKD{})), decltype(cute::get<2>(TileShape_QKD{}))>()); @@ -78,11 +79,13 @@ struct SparseCollectiveMainloop { repeat_like(StrideT{}, int32_t(0)), StrideT{}), SmemLayoutQ{}, select<0, 2>(TileShape_QKD{}), _1{})); // no mcast for Q + // for sparse loading, we use cp.async static constexpr bool USE_TMA_LOAD_KV = false; - static constexpr int NUM_MMA_THREADS = size(typename Ktraits::TiledMmaQK{}); using MainloopPipeline = typename Ktraits::MainloopPipeline; using PipelineParams = typename MainloopPipeline::Params; using PipelineState = typename MainloopPipeline::PipelineState; + using MainloopPipelineVt = typename Ktraits::MainloopPipelineNoTMA; + using PipelineParamsVt = typename MainloopPipelineVt::Params; static constexpr uint32_t TmaTransactionBytesQ = static_cast(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v / 8); @@ -115,6 +118,7 @@ struct SparseCollectiveMainloop { IdType* kv_indices; int window_left; AdditionalParams additional_params; + using DTypeKV = typename Ktraits::DTypeKV; }; static Params 
to_underlying_arguments(Arguments const& args) { @@ -154,19 +158,33 @@ struct SparseCollectiveMainloop { template CUTLASS_DEVICE void load(Params const& mainloop_params, MainloopPipeline pipeline_k, - MainloopPipeline pipeline_v, PipelineState& smem_pipe_write_k, - PipelineState& smem_pipe_write_v, SharedStorage& shared_storage, - Scheduler& scheduler, typename Scheduler::Params const& scheduler_params, + MainloopPipeline pipeline_v, MainloopPipelineVt pipeline_vt, + PipelineState& smem_pipe_write, PipelineState& smem_pipe_read, + SharedStorage& shared_storage, Scheduler& scheduler, + typename Scheduler::Params const& scheduler_params, typename Scheduler::WorkTileInfo& work_tile_info, BlockCoord const& block_coord, int work_idx) { int thread_idx = threadIdx.x; int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (thread_idx / 32) % 4, 0); + bool issue_tma_thread = (warp_idx_in_warpgroup == 0) && (elect_one_sync() == 1); + Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{}); Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{}); Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{}); Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.layout_Q.shape()); + // *** Prepare In-kernel V Transpose *** + using SmemLayoutVTransposeSrc = typename Ktraits::SmemLayoutVTransposeSrc; + using SmemLayoutVtTransposeTgt = typename Ktraits::SmemLayoutVtTransposeTgt; + + Tensor sV_src = as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutVTransposeSrc{})); + Tensor sVt_tgt = as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(shared_storage.smem_vt.data()), SmemLayoutVtTransposeTgt{})); + auto v_tranposer = SmemTransposeFP8_64x64(); + /* ----- V Transpose ---- */ + auto [q_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len] = block_coord; // Prepare the TMA loads @@ -219,94 +237,112 @@ struct SparseCollectiveMainloop { }; // load last k-tile + // all threads are issuing as TMA is disabled { - pipeline_k.producer_acquire(smem_pipe_write_k); + pipeline_k.producer_acquire(smem_pipe_write); Tensor tKgKiGroup = flatten_1(tKgK(_, _, _, kv_tile_idx)); // (CPY, (CPY_KV, CPY_D)) Tensor tKsKiGroup = - flatten_1(tKsK(_, _, _, smem_pipe_write_k.index())); // (CPY, (CPY_KV, CPY_D)) + flatten_1(tKsK(_, _, _, smem_pipe_write.index())); // (CPY, (CPY_KV, CPY_D)) copy_if(gmem_tiled_copy_kv, predicate_fn, tKgKiGroup, tKsKiGroup); - - pipeline_k.producer_commit(smem_pipe_write_k, cutlass::arch::cpasync_barrier_arrive); - ++smem_pipe_write_k; + pipeline_k.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive); } + // Wait for the MMA warpgroups to say that smem_q is ready + cutlass::arch::NamedBarrier::sync(NUM_MMA_THREADS + cutlass::NumThreadsPerWarp, + static_cast(NamedBarriers::kQueryEmpty)); // load Q tile - if (warp_idx_in_warpgroup == 0) { - cutlass::arch::NamedBarrier::sync(NUM_MMA_THREADS + cutlass::NumThreadsPerWarp, - static_cast(NamedBarriers::kQueryEmpty)); - - int lane_predicate = cute::elect_one_sync(); - if (lane_predicate) { - shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ); - copy(mainloop_params.tma_load_Q.with( - reinterpret_cast( - shared_storage.barrier_Q), - /*mcast_mask=*/0), - tQgQ, tQsQ); - } + if (issue_tma_thread) { + shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ); + copy(mainloop_params.tma_load_Q.with( + reinterpret_cast( + 
shared_storage.barrier_Q), + /*mcast_mask=*/0), + tQgQ, tQsQ); } shared_storage.barrier_O.wait((work_idx + 1) % 2); if (kv_tile_idx == swa_begin_kv_tile_idx) { - pipeline_v.producer_acquire(smem_pipe_write_v); + // first tile is the last tile + pipeline_v.producer_acquire(smem_pipe_write); Tensor tVgViGroup = flatten_1(tVgV(_, _, _, kv_tile_idx)); // (CPY, (CPY_KV, CPY_D)) Tensor tVsViGroup = - flatten_1(tVsV(_, _, _, smem_pipe_write_v.index())); // (CPY, (CPY_KV, CPY_D)) + flatten_1(tVsV(_, _, _, smem_pipe_write.index())); // (CPY, (CPY_KV, CPY_D)) copy_if(gmem_tiled_copy_kv, predicate_fn, tVgViGroup, tVsViGroup); - - pipeline_v.producer_commit(smem_pipe_write_v, cutlass::arch::cpasync_barrier_arrive); - ++smem_pipe_write_v; + pipeline_v.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive); + + // Transpose V + pipeline_v.consumer_wait(smem_pipe_read); + v_tranposer.do_transpose(sV_src, sVt_tgt, smem_pipe_read.index()); + pipeline_vt.producer_commit(smem_pipe_write); // ping MMA consumer + pipeline_v.consumer_release(smem_pipe_read); // release V loading consumer + ++smem_pipe_read; + ++smem_pipe_write; // update state, as K is loaded 1 step faster } else { // load second last k-tile and last v-tile - pipeline_k.producer_acquire(smem_pipe_write_k); - Tensor tKgKi = tKgK(_, _, _, kv_tile_idx - 1); // (CPY, CPY_KV, CPY_D) - Tensor tKsKi = tKsK(_, _, _, smem_pipe_write_k.index()); // (CPY, CPY_KV, CPY_D) - copy(gmem_tiled_copy_kv, tKgKi, tKsKi); - - pipeline_k.producer_commit(smem_pipe_write_k, cutlass::arch::cpasync_barrier_arrive); - ++smem_pipe_write_k; - - pipeline_v.producer_acquire(smem_pipe_write_v); + pipeline_v.producer_acquire(smem_pipe_write); Tensor tVgViGroup = flatten_1(tVgV(_, _, _, kv_tile_idx)); // (CPY, (CPY_KV, CPY_D)) Tensor tVsViGroup = - flatten_1(tVsV(_, _, _, smem_pipe_write_v.index())); // (CPY, (CPY_KV, CPY_D)) + flatten_1(tVsV(_, _, _, smem_pipe_write.index())); // (CPY, (CPY_KV, CPY_D)) copy_if(gmem_tiled_copy_kv, predicate_fn, tVgViGroup, tVsViGroup); + pipeline_v.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive); + + // Transpose V + pipeline_v.consumer_wait(smem_pipe_read); + v_tranposer.do_transpose(sV_src, sVt_tgt, smem_pipe_read.index()); + pipeline_vt.producer_commit(smem_pipe_write); // ping MMA consumer + pipeline_v.consumer_release(smem_pipe_read); // release V loading consumer + ++smem_pipe_read; + ++smem_pipe_write; // update state, as K is loaded 1 step faster + + pipeline_k.producer_acquire(smem_pipe_write); + Tensor tKgKi = tKgK(_, _, _, kv_tile_idx - 1); // (CPY, CPY_KV, CPY_D) + Tensor tKsKi = tKsK(_, _, _, smem_pipe_write.index()); // (CPY, CPY_KV, CPY_D) + copy(gmem_tiled_copy_kv, tKgKi, tKsKi); + pipeline_k.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive); - pipeline_v.producer_commit(smem_pipe_write_v, cutlass::arch::cpasync_barrier_arrive); --kv_tile_idx; - ++smem_pipe_write_v; // load remaining k/v tiles #pragma unroll 2 for (; kv_tile_idx > swa_begin_kv_tile_idx; --kv_tile_idx) { - pipeline_k.producer_acquire(smem_pipe_write_k); - - Tensor tKgKi = tKgK(_, _, _, kv_tile_idx - 1); // (CPY, CPY_KV, CPY_D) - Tensor tKsKi = tKsK(_, _, _, smem_pipe_write_k.index()); // (CPY, CPY_KV, CPY_D) - copy(gmem_tiled_copy_kv, tKgKi, tKsKi); - - pipeline_k.producer_commit(smem_pipe_write_k, cutlass::arch::cpasync_barrier_arrive); - ++smem_pipe_write_k; - - pipeline_v.producer_acquire(smem_pipe_write_v); - Tensor tVgVi = tVgV(_, _, _, kv_tile_idx); // (CPY, CPY_KV, CPY_D) - Tensor tVsVi = 
tVsV(_, _, _, smem_pipe_write_v.index()); // (CPY, CPY_KV, CPY_D) + pipeline_v.producer_acquire(smem_pipe_write); + Tensor tVgVi = tVgV(_, _, _, kv_tile_idx); // (CPY, CPY_KV, CPY_D) + Tensor tVsVi = tVsV(_, _, _, smem_pipe_write.index()); // (CPY, CPY_KV, CPY_D) copy(gmem_tiled_copy_kv, tVgVi, tVsVi); - - pipeline_v.producer_commit(smem_pipe_write_v, cutlass::arch::cpasync_barrier_arrive); - ++smem_pipe_write_v; + pipeline_v.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive); + + // Transpose V + pipeline_v.consumer_wait(smem_pipe_read); + v_tranposer.do_transpose(sV_src, sVt_tgt, smem_pipe_read.index()); + pipeline_vt.producer_commit(smem_pipe_write); // ping MMA consumer + pipeline_v.consumer_release(smem_pipe_read); // release V loading consumer + ++smem_pipe_read; + ++smem_pipe_write; // update state, as K is loaded 1 step faster + + pipeline_k.producer_acquire(smem_pipe_write); + Tensor tKgKi = tKgK(_, _, _, kv_tile_idx - 1); // (CPY, CPY_KV, CPY_D) + Tensor tKsKi = tKsK(_, _, _, smem_pipe_write.index()); // (CPY, CPY_KV, CPY_D) + copy(gmem_tiled_copy_kv, tKgKi, tKsKi); + pipeline_k.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive); } scheduler.prefetch_next_work(scheduler_params, work_tile_info); // load first v tile { - pipeline_v.producer_acquire(smem_pipe_write_v); - Tensor tVgVi = tVgV(_, _, _, 0); // (CPY, (CPY_KV, CPY_D)) - Tensor tVsVi = tVsV(_, _, _, smem_pipe_write_v.index()); // (CPY, (CPY_KV, CPY_D)) + pipeline_v.producer_acquire(smem_pipe_write); + Tensor tVgVi = tVgV(_, _, _, 0); // (CPY, (CPY_KV, CPY_D)) + Tensor tVsVi = tVsV(_, _, _, smem_pipe_write.index()); // (CPY, (CPY_KV, CPY_D)) copy(gmem_tiled_copy_kv, tVgVi, tVsVi); - pipeline_v.producer_commit(smem_pipe_write_v, cutlass::arch::cpasync_barrier_arrive); - ++smem_pipe_write_v; + pipeline_v.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive); + + // Transpose V + pipeline_v.consumer_wait(smem_pipe_read); + v_tranposer.do_transpose(sV_src, sVt_tgt, smem_pipe_read.index()); + pipeline_vt.producer_commit(smem_pipe_write); // ping MMA consumer + pipeline_v.consumer_release(smem_pipe_read); // release V loading consumer + ++smem_pipe_read; + ++smem_pipe_write; // update state, as K is loaded 1 step faster } } @@ -314,10 +350,9 @@ struct SparseCollectiveMainloop { } CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline_k, MainloopPipeline pipeline_v, - PipelineState& smem_pipe_write_k, - PipelineState& smem_pipe_write_v) { - pipeline_k.producer_tail(smem_pipe_write_k); - pipeline_v.producer_tail(smem_pipe_write_v); + PipelineState& smem_pipe_write) { + pipeline_k.producer_tail(smem_pipe_write); + pipeline_v.producer_tail(smem_pipe_write); } }; diff --git a/include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh b/include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh index 458b7ac17..450c63191 100644 --- a/include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh +++ b/include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh @@ -31,6 +31,7 @@ #include "kernel_traits.cuh" #include "mainloop_load.cuh" #include "mainloop_mma.cuh" +#include "mainloop_sparse_load.cuh" namespace flashinfer { @@ -55,7 +56,7 @@ __global__ void __launch_bounds__(Ktraits::NUM_WARPS* cutlass::NumThreadsPerWarp static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS; // We always assign one WG as producer // For FP8 kernel, all 4 warps collectively process ldmatrix with ldmatrix - static constexpr int NUM_COPY_THREADS = 
cutlass::NumThreadsPerWarpGroup; + static constexpr int NUM_COPY_THREADS = Ktraits::NUM_PRODUCER_THREADS; static constexpr int CTA_Q = Ktraits::CTA_Q; static constexpr int CTA_KV = Ktraits::CTA_KV; @@ -112,19 +113,21 @@ __global__ void __launch_bounds__(Ktraits::NUM_WARPS* cutlass::NumThreadsPerWarp }(); MainloopPipeline pipeline_v = [&] { + // specialized for shared memory of V transpose pipeline_params.role = MainloopPipeline::ThreadCategory::ProducerConsumer; - pipeline_params.num_consumers = Ktraits::NUM_PRODUCER_THREADS; if constexpr (use_tma_load_kv) { + pipeline_params.num_consumers = NUM_COPY_THREADS; return MainloopPipeline(shared_storage.pipeline_v, pipeline_params, /*cluster_shape=*/Shape<_1, _1, _1>{}); } else { + pipeline_params.consumer_arv_count = NUM_COPY_THREADS; return MainloopPipeline(shared_storage.pipeline_v, pipeline_params); } }(); // Init pipeline_vt for transpose and consumed by mma PipelineParamsVt pipeline_params_vt; - pipeline_params_vt.producer_arv_count = Ktraits::NUM_PRODUCER_THREADS; + pipeline_params_vt.producer_arv_count = NUM_COPY_THREADS; pipeline_params_vt.consumer_arv_count = NUM_MMA_THREADS; MainloopPipelineVt pipeline_vt(shared_storage.pipeline_vt, pipeline_params_vt); @@ -309,6 +312,74 @@ cudaError_t SingleFP8PrefillWithKVCacheKernelTraitsDispatched(Params& params, cu return cudaSuccess; } +template +cudaError_t BatchFP8PrefillWithPagedKVCacheKernelTraitsDispatched(Params& params, + cudaStream_t stream) { + using DTypeQ = typename KernelTraits::DTypeQ; + using DTypeKV = typename KernelTraits::DTypeKV; + using DTypeO = typename KernelTraits::DTypeO; + using IdType = typename KernelTraits::IdType; + + using CollectiveMainloop = + FP8SparseCollectiveMainloop; + using CollectiveEpilogue = FP8CollectiveEpilogue; + using Scheduler = + std::conditional_t, + BatchPrefillPersistentTileScheduler>; + + typename CollectiveMainloop::Params mainloop_params = CollectiveMainloop::to_underlying_arguments( + {params.q_ptr, + get_gmem_layout(params.nnz_qo, params.num_qo_heads, KernelTraits::HEAD_DIM, + params.q_stride_n, + params.q_stride_h), // layout_Q + params.k_ptr, + // NOTE(Zihao): nnz was useless here, we can just pass 0 + get_gmem_layout(/*nnz=*/0, params.num_kv_heads, KernelTraits::HEAD_DIM, params.k_stride_n, + params.k_stride_h), // layout_K + params.v_ptr, + get_gmem_layout(/*nnz=*/0, params.num_kv_heads, KernelTraits::HEAD_DIM, params.v_stride_n, + params.v_stride_h), // layout_V + params.kv_indices, params.window_left, params.additional_params}); + typename CollectiveEpilogue::Params epilogue_params = + CollectiveEpilogue::to_underlying_arguments({ + params.o_ptr, + get_gmem_layout(params.nnz_qo, params.num_qo_heads, KernelTraits::HEAD_DIM, + params.o_stride_n, + params.o_stride_h), // layout_O + params.lse_ptr, get_lse_gmem_layout(params.nnz_qo, params.num_qo_heads), // layout_LSE + }); + + typename Scheduler::Arguments scheduler_args = { + params.work_indptr, params.head_indices, + params.qo_tile_indices, params.qo_indptr, + params.kv_indptr, params.qo_lens, + params.kv_lens, cutlass::FastDivmod(params.num_qo_heads / params.num_kv_heads), + params.num_qo_heads}; + typename Scheduler::Params scheduler_params = Scheduler::to_underlying_arguments(scheduler_args); + + // Get the ptr to kernel function. 
+ auto kernel = + (void*)FP8PrefillWithKVCacheKernel; + int smem_size = sizeof(typename KernelTraits::SharedStorage); + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + int device; + cudaGetDevice(&device); + int multiprocessor_count; + FLASHINFER_CUDA_CALL( + cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device)); + dim3 grid_dims = Scheduler::get_grid_dim(scheduler_args, multiprocessor_count); + static constexpr int ctaSize = KernelTraits::NUM_WARPS * 32; + dim3 block_dims(ctaSize); + void* args[] = {&mainloop_params, &epilogue_params, &scheduler_params}; + FLASHINFER_CUDA_CALL(cudaLaunchKernel(kernel, grid_dims, block_dims, args, smem_size, stream)); + + return cudaSuccess; +} + template cudaError_t SingleFP8PrefillWithKVCacheDispatched(Params& params, cudaStream_t stream) { @@ -352,6 +423,49 @@ cudaError_t SingleFP8PrefillWithKVCacheDispatched(Params& params, cudaStream_t s return status; } +template +cudaError_t BatchFP8PrefillWithPagedKVCacheDispatched(Params& params, cudaStream_t stream) { + static_assert(HEAD_DIM == 64 || HEAD_DIM == 128 || HEAD_DIM == 256); + if (MASK_MODE == MaskMode::kCustom) { + return cudaErrorNotSupported; // Not supported yet. + } + constexpr bool CAUSAL = MASK_MODE == MaskMode::kCausal; + if constexpr (HEAD_DIM == 64) { + // NOTE(Zihao): CTA_KV not tuned for HEAD_DIM == 64, need to optimize later + BatchFP8PrefillWithPagedKVCacheKernelTraitsDispatched< + FP8AttentionKernelTraits, + LEFT_SLIDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS>(params, stream); + } else if constexpr (HEAD_DIM == 128) { + BatchFP8PrefillWithPagedKVCacheKernelTraitsDispatched< + FP8AttentionKernelTraits, + LEFT_SLIDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS>(params, stream); + } else { + // HEAD_DIM == 256; + // NOTE(Zihao): CTA_KV not tuned for HEAD_DIM == 256, need to optimize later + BatchFP8PrefillWithPagedKVCacheKernelTraitsDispatched< + FP8AttentionKernelTraits, + LEFT_SLIDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS>(params, stream); + } + cudaError_t status = cudaGetLastError(); + return status; +}; + } // namespace flashinfer #endif // FLASHINFER_ATTENTION_HOPPER_FP8_PREFILL_SM90_CUH_ From 1211b4c06e22c3d9045a3330f45ae170919c9fba Mon Sep 17 00:00:00 2001 From: happierpig Date: Thu, 24 Apr 2025 06:02:24 +0000 Subject: [PATCH 08/14] fix: fix deadlock --- .../quantization/mainloop_sparse_load.cuh | 2 +- .../hopper/quantization/prefill_sm90.cuh | 60 +++++++++---------- 2 files changed, 31 insertions(+), 31 deletions(-) diff --git a/include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh b/include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh index d48af2e12..b7865d3d0 100644 --- a/include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh +++ b/include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh @@ -248,7 +248,7 @@ struct FP8SparseCollectiveMainloop { } // Wait for the MMA warpgroups to say that smem_q is ready - cutlass::arch::NamedBarrier::sync(NUM_MMA_THREADS + cutlass::NumThreadsPerWarp, + cutlass::arch::NamedBarrier::sync(NUM_MMA_THREADS + Ktraits::NUM_PRODUCER_THREADS, static_cast(NamedBarriers::kQueryEmpty)); // load Q tile if (issue_tma_thread) { diff --git a/include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh b/include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh index 450c63191..522183f6e 100644 --- 
a/include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh +++ b/include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh @@ -145,38 +145,38 @@ __global__ void __launch_bounds__(Ktraits::NUM_WARPS* cutlass::NumThreadsPerWarp cutlass::arch::warpgroup_reg_dealloc<72>(); } - if constexpr (use_tma_load_kv) { // Load Q, K, V - PipelineState smem_pipe_write = cutlass::make_producer_start_state(); - PipelineState smem_pipe_read; - - int work_idx = 0; - - TileScheduler scheduler; - for (auto work_tile_info = scheduler.get_initial_work(scheduler_params); - work_tile_info.is_valid(scheduler_params); - work_tile_info = scheduler.template get_next_work( - scheduler_params, work_tile_info)) { - auto block_coord = work_tile_info.get_block_coord(scheduler_params); - auto [q_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len] = - block_coord; - - if (q_tile_idx * CTA_Q >= qo_len) { - continue; - } - int num_kv_tiles = - collective_mainloop.get_num_kv_tiles(mainloop_params, q_tile_idx, qo_len, kv_len); - if (num_kv_tiles <= 0) { - scheduler.prefetch_next_work(scheduler_params, work_tile_info); - scheduler.broadcast_next_work(work_tile_info); - continue; - } - collective_mainloop.load( - mainloop_params, pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write, smem_pipe_read, - shared_storage, scheduler, scheduler_params, work_tile_info, block_coord, work_idx); - ++work_idx; + // Here no condition as the entire warp group is used as producer + PipelineState smem_pipe_write = cutlass::make_producer_start_state(); + PipelineState smem_pipe_read; + + int work_idx = 0; + + TileScheduler scheduler; + for (auto work_tile_info = scheduler.get_initial_work(scheduler_params); + work_tile_info.is_valid(scheduler_params); + work_tile_info = scheduler.template get_next_work(scheduler_params, + work_tile_info)) { + auto block_coord = work_tile_info.get_block_coord(scheduler_params); + auto [q_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len] = + block_coord; + + if (q_tile_idx * CTA_Q >= qo_len) { + continue; } - collective_mainloop.load_tail(pipeline_k, pipeline_v, smem_pipe_write); + int num_kv_tiles = + collective_mainloop.get_num_kv_tiles(mainloop_params, q_tile_idx, qo_len, kv_len); + if (num_kv_tiles <= 0) { + scheduler.prefetch_next_work(scheduler_params, work_tile_info); + scheduler.broadcast_next_work(work_tile_info); + continue; + } + collective_mainloop.load( + mainloop_params, pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write, smem_pipe_read, + shared_storage, scheduler, scheduler_params, work_tile_info, block_coord, work_idx); + ++work_idx; } + collective_mainloop.load_tail(pipeline_k, pipeline_v, smem_pipe_write); + } else { // Consumer if constexpr (use_tma_load_kv) { cutlass::arch::warpgroup_reg_alloc(); From 23022e2e23feaf2101ad599a18b5eec5398d8b81 Mon Sep 17 00:00:00 2001 From: happierpig Date: Thu, 24 Apr 2025 17:40:49 +0000 Subject: [PATCH 09/14] upd test cases --- flashinfer/sparse.py | 1 + tests/test_hopper_fp8_attention.py | 129 +++++++++++++++++++++++++++-- 2 files changed, 123 insertions(+), 7 deletions(-) diff --git a/flashinfer/sparse.py b/flashinfer/sparse.py index 891593d1c..cea948f5f 100644 --- a/flashinfer/sparse.py +++ b/flashinfer/sparse.py @@ -352,6 +352,7 @@ def plan( if ( R * (num_qo_heads // num_kv_heads) < 4 and mask_mode != MaskMode.CUSTOM.value + and not q_data_type in [torch.float8_e4m3fn, torch.float8_e5m2] ): # If the operation is not compute-bound, we use the cuda-core implementation self._use_tensor_cores = False 
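Note on the per-head scales consumed by the FP8 path: the tests below quantize q/k/v with a per-head symmetric scheme (per_head_symmetric_quant) and pass the resulting float32 scale vectors as scale_q/scale_k/scale_v; when the scales are omitted, the sparse run() wrapper above defaults them to all-ones tensors. The helper itself is not reproduced in this hunk, so the following is only a minimal Python sketch of the idea (absmax per head, mapped onto the float8 range); the name per_head_symmetric_quant_sketch and the exact clamping/epsilon handling are illustrative, not the repository implementation.

import torch

def per_head_symmetric_quant_sketch(x: torch.Tensor,
                                    quant_dtype: torch.dtype = torch.float8_e4m3fn):
    # x: (seq_len, num_heads, head_dim) in float16/bfloat16
    finfo = torch.finfo(quant_dtype)
    # per-head max magnitude over the sequence and head_dim axes
    amax = x.abs().amax(dim=(0, 2)).float().clamp(min=1e-6)   # (num_heads,)
    scale = amax / finfo.max                                   # dequant scale, (num_heads,)
    x_fp8 = (x.float() / scale.view(1, -1, 1)).clamp(finfo.min, finfo.max).to(quant_dtype)
    return x_fp8, scale                                        # mirrors the (q_fp8, s_q) usage below
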
diff --git a/tests/test_hopper_fp8_attention.py b/tests/test_hopper_fp8_attention.py index 124db3d03..3df19aaf7 100644 --- a/tests/test_hopper_fp8_attention.py +++ b/tests/test_hopper_fp8_attention.py @@ -1,5 +1,8 @@ from typing import Tuple +import numpy as np +import pytest +import scipy as sp import torch import flashinfer @@ -37,20 +40,52 @@ def get_dtype_minmax(dtype: torch.dtype) -> Tuple[float, float]: return q_x_out, s_out +def bsr_attention_ref( + q, + k, + v, + indptr, + indices, + mask_data, +): + M = q.shape[0] + N = k.shape[0] + bsr = sp.sparse.bsr_matrix( + (mask_data.cpu().numpy(), indices.cpu().numpy(), indptr.cpu().numpy()), + shape=(M, N), + ) + dense_mask = torch.tensor(bsr.toarray(), dtype=bool, device=q.device) + o = flashinfer.prefill.single_prefill_with_kv_cache( + q, k, v, custom_mask=dense_mask, backend="fa2" + ) + return o + + +# Test single_prefill correctness: MSE should be below threshold +@pytest.mark.parametrize("seq_len", [117, 509, 1011, 2372, 7777]) +@pytest.mark.parametrize("num_heads", [24, 32]) +@pytest.mark.parametrize("causal", [True, False]) +@pytest.mark.parametrize("head_dim", [64, 128, 256]) +@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) def test_single_prefill(seq_len, num_heads, causal, head_dim, dtype): + # Prepare inputs o_dtype = torch.half num_qo_heads = num_kv_heads = num_heads - q = torch.randn(seq_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda") k = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda") v = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda") + + # Reference output o_ref = flashinfer.single_prefill_with_kv_cache( q, k, v, causal=causal, backend="fa3" ) + # Quantize q_fp8, s_q = per_head_symmetric_quant(q, quant_dtype=dtype) k_fp8, s_k = per_head_symmetric_quant(k, quant_dtype=dtype) v_fp8, s_v = per_head_symmetric_quant(v, quant_dtype=dtype) + + # FP8 output o_fp8 = flashinfer.single_prefill_with_kv_cache( q_fp8, k_fp8, @@ -63,14 +98,74 @@ def test_single_prefill(seq_len, num_heads, causal, head_dim, dtype): o_dtype=o_dtype, ) - assert not torch.any(torch.isnan(o_fp8)) - assert not torch.any(torch.isnan(o_ref)) - - # MSE + # Compute MSE and assert + # NOTE: This is not a strict correctness guarantee mse = torch.mean((o_ref.float() - o_fp8.float()) ** 2) - print( - f"test_single_prefill (seq_len={seq_len}, num_heads={num_heads}, causal={causal}, head_dim={head_dim}, dtype={dtype}), MSE: {mse:.5f}" + assert mse < 1.0, f"MSE too high: {mse.item()}" + + +# Test block sparse attention correctness: MSE should be below threshold +@pytest.mark.parametrize("R", [1, 4, 16]) +@pytest.mark.parametrize("C", [1, 4, 16]) +@pytest.mark.parametrize("M", [256, 512, 1024]) +@pytest.mark.parametrize("N", [256, 512, 1024]) +@pytest.mark.parametrize("num_heads", [24, 32]) +@pytest.mark.parametrize("head_dim", [64, 128, 256]) +@pytest.mark.parametrize("mask_inside_block", [False]) +@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) +def test_block_sparse_attention( + R, C, M, N, num_heads, head_dim, mask_inside_block, dtype +): + # Build sparse mask + MB = M // R + NB = N // C + rng = np.random.default_rng() + S = sp.sparse.random(MB, NB, density=0.25, random_state=rng).tocsr() + indptr = torch.from_numpy(S.indptr).cuda() + indices = torch.from_numpy(S.indices).cuda() + nnz = S.nnz + if mask_inside_block: + data_mask = (torch.rand((nnz, R, C)) > 0.5).to(torch.bool).cuda() + else: + data_mask = torch.ones((nnz, R, C), 
dtype=torch.bool, device="cuda") + + # Random inputs + q = torch.randn((M, num_heads, head_dim), dtype=torch.float16, device="cuda") + k = torch.randn((N, num_heads, head_dim), dtype=torch.float16, device="cuda") + v = torch.randn((N, num_heads, head_dim), dtype=torch.float16, device="cuda") + + # Reference output via dense mask + o_ref = bsr_attention_ref(q, k, v, indptr, indices, data_mask) + + # Plan and run BlockSparseAttention + workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.uint8, device="cuda") + sparse_wrapper = flashinfer.sparse.BlockSparseAttentionWrapper( + workspace_buffer, backend="fa3" + ) + sparse_wrapper.plan( + indptr, + indices, + M, + N, + R, + C, + num_heads, + num_heads, + head_dim, + mask=data_mask if mask_inside_block else None, + q_data_type=dtype, + kv_data_type=dtype, + o_data_type=torch.float16, ) + q_fp8, s_q = per_head_symmetric_quant(q, quant_dtype=dtype) + k_fp8, s_k = per_head_symmetric_quant(k, quant_dtype=dtype) + v_fp8, s_v = per_head_symmetric_quant(v, quant_dtype=dtype) + o = sparse_wrapper.run(q_fp8, k_fp8, v_fp8, s_q, s_k, s_v) + + # Compute MSE and assert + # NOTE: This is not a strict correctness guarantee + mse = torch.mean((o_ref.float() - o.float()) ** 2) + assert mse < 1.0, f"Block sparse MSE too high: {mse.item()}" if __name__ == "__main__": @@ -80,3 +175,23 @@ def test_single_prefill(seq_len, num_heads, causal, head_dim, dtype): for head_dim in [64, 128, 256]: for dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: test_single_prefill(seq_len, num_heads, causal, head_dim, dtype) + + for R in [1, 4, 16]: + for C in [1, 4, 16]: + for M in [64, 128, 256]: + for N in [64, 128, 256]: + for num_heads in [32]: + for head_dim in [128, 256]: + for mask_inside_block in [False]: + for dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: + test_block_sparse_attention( + R, + C, + M, + N, + num_heads, + num_heads, + head_dim, + mask_inside_block, + dtype, + ) From 7fa3a82f27770e92138d448067008e15066e0daa Mon Sep 17 00:00:00 2001 From: happierpig Date: Thu, 24 Apr 2025 19:28:02 +0000 Subject: [PATCH 10/14] fix: fix deadlock by adding vt pipeline producer sync --- .../attention/hopper/quantization/mainloop_load.cuh | 4 ++-- .../attention/hopper/quantization/mainloop_sparse_load.cuh | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/include/flashinfer/attention/hopper/quantization/mainloop_load.cuh b/include/flashinfer/attention/hopper/quantization/mainloop_load.cuh index c5211ee5e..0fe91760f 100644 --- a/include/flashinfer/attention/hopper/quantization/mainloop_load.cuh +++ b/include/flashinfer/attention/hopper/quantization/mainloop_load.cuh @@ -244,7 +244,7 @@ struct FP8CollectiveMainloop { shared_storage.barrier_O.wait((work_idx + 1) % 2); pipeline_v.consumer_wait(smem_pipe_read); - // pipeline_vt.producer_acquire(smem_pipe_write); + pipeline_vt.producer_acquire(smem_pipe_write); v_tranposer.do_transpose(sV_src, sVt_tgt, smem_pipe_read.index()); pipeline_vt.producer_commit(smem_pipe_write); pipeline_v.consumer_release(smem_pipe_read); @@ -269,7 +269,7 @@ struct FP8CollectiveMainloop { } pipeline_v.consumer_wait(smem_pipe_read); - // pipeline_vt.producer_acquire(smem_pipe_write); + pipeline_vt.producer_acquire(smem_pipe_write); v_tranposer.do_transpose(sV_src, sVt_tgt, smem_pipe_read.index()); pipeline_vt.producer_commit(smem_pipe_write); pipeline_v.consumer_release(smem_pipe_read); diff --git a/include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh 
b/include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh index b7865d3d0..49ab44590 100644 --- a/include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh +++ b/include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh @@ -273,6 +273,7 @@ struct FP8SparseCollectiveMainloop { // Transpose V pipeline_v.consumer_wait(smem_pipe_read); + pipeline_vt.producer_acquire(smem_pipe_write); v_tranposer.do_transpose(sV_src, sVt_tgt, smem_pipe_read.index()); pipeline_vt.producer_commit(smem_pipe_write); // ping MMA consumer pipeline_v.consumer_release(smem_pipe_read); // release V loading consumer @@ -289,6 +290,7 @@ struct FP8SparseCollectiveMainloop { // Transpose V pipeline_v.consumer_wait(smem_pipe_read); + pipeline_vt.producer_acquire(smem_pipe_write); v_tranposer.do_transpose(sV_src, sVt_tgt, smem_pipe_read.index()); pipeline_vt.producer_commit(smem_pipe_write); // ping MMA consumer pipeline_v.consumer_release(smem_pipe_read); // release V loading consumer @@ -314,6 +316,7 @@ struct FP8SparseCollectiveMainloop { // Transpose V pipeline_v.consumer_wait(smem_pipe_read); + pipeline_vt.producer_acquire(smem_pipe_write); v_tranposer.do_transpose(sV_src, sVt_tgt, smem_pipe_read.index()); pipeline_vt.producer_commit(smem_pipe_write); // ping MMA consumer pipeline_v.consumer_release(smem_pipe_read); // release V loading consumer @@ -338,6 +341,7 @@ struct FP8SparseCollectiveMainloop { // Transpose V pipeline_v.consumer_wait(smem_pipe_read); + pipeline_vt.producer_acquire(smem_pipe_write); v_tranposer.do_transpose(sV_src, sVt_tgt, smem_pipe_read.index()); pipeline_vt.producer_commit(smem_pipe_write); // ping MMA consumer pipeline_v.consumer_release(smem_pipe_read); // release V loading consumer From 082cf2402fe08332a147f447ef54b75332a20f57 Mon Sep 17 00:00:00 2001 From: happierpig Date: Thu, 24 Apr 2025 21:25:24 +0000 Subject: [PATCH 11/14] fix: add memory barrier before WG_MMA write_o to allow STSM finish. 
--- .../hopper/quantization/epilogue.cuh | 6 ++- tests/test_hopper_fp8_attention.py | 38 +++++++++---------- 2 files changed, 24 insertions(+), 20 deletions(-) diff --git a/include/flashinfer/attention/hopper/quantization/epilogue.cuh b/include/flashinfer/attention/hopper/quantization/epilogue.cuh index 88063d1cc..0d4883fc5 100644 --- a/include/flashinfer/attention/hopper/quantization/epilogue.cuh +++ b/include/flashinfer/attention/hopper/quantization/epilogue.cuh @@ -151,8 +151,12 @@ struct FP8CollectiveEpilogue { } } - int write_warp_idx = NUM_WARPS - 1; + // make sure all WG finish STSM o + cutlass::arch::NamedBarrier::sync(NUM_MMA_THREADS, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + TiledCopyO gmem_tiled_copy_O; + int write_warp_idx = NUM_WARPS - 1; write_O(epilogue_params.O_ptr, gmem_tiled_copy_O, epilogue_params.layout_O, select<0, 2>(TileShape_QKD{}), sO, thread_idx, qo_tile_idx, qo_head_idx, qo_indptr, qo_len, write_warp_idx); diff --git a/tests/test_hopper_fp8_attention.py b/tests/test_hopper_fp8_attention.py index 3df19aaf7..08ecf26df 100644 --- a/tests/test_hopper_fp8_attention.py +++ b/tests/test_hopper_fp8_attention.py @@ -62,7 +62,7 @@ def bsr_attention_ref( # Test single_prefill correctness: MSE should be below threshold -@pytest.mark.parametrize("seq_len", [117, 509, 1011, 2372, 7777]) +@pytest.mark.parametrize("seq_len", [117, 509, 1011, 2372, 7777, 12315]) @pytest.mark.parametrize("num_heads", [24, 32]) @pytest.mark.parametrize("causal", [True, False]) @pytest.mark.parametrize("head_dim", [64, 128, 256]) @@ -107,19 +107,27 @@ def test_single_prefill(seq_len, num_heads, causal, head_dim, dtype): # Test block sparse attention correctness: MSE should be below threshold @pytest.mark.parametrize("R", [1, 4, 16]) @pytest.mark.parametrize("C", [1, 4, 16]) -@pytest.mark.parametrize("M", [256, 512, 1024]) -@pytest.mark.parametrize("N", [256, 512, 1024]) -@pytest.mark.parametrize("num_heads", [24, 32]) +@pytest.mark.parametrize("M", [256, 512, 1024, 4096]) +@pytest.mark.parametrize("N", [256, 512, 1024, 4096]) +@pytest.mark.parametrize("num_heads", [1, 8, 24, 32]) @pytest.mark.parametrize("head_dim", [64, 128, 256]) @pytest.mark.parametrize("mask_inside_block", [False]) @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) def test_block_sparse_attention( R, C, M, N, num_heads, head_dim, mask_inside_block, dtype ): + # print args + print( + f"Testing block sparse attention with R={R}, C={C}, M={M}, N={N}, num_heads={num_heads}, " + f"head_dim={head_dim}, mask_inside_block={mask_inside_block}, dtype={dtype}" + ) + # setup random seed for reproducibility + torch.manual_seed(0) + np.random.seed(0) # Build sparse mask MB = M // R NB = N // C - rng = np.random.default_rng() + rng = np.random.default_rng(seed=0) S = sp.sparse.random(MB, NB, density=0.25, random_state=rng).tocsr() indptr = torch.from_numpy(S.indptr).cuda() indices = torch.from_numpy(S.indices).cuda() @@ -169,19 +177,12 @@ def test_block_sparse_attention( if __name__ == "__main__": - for seq_len in [117, 509, 1011, 2372, 7777]: - for num_heads in [24, 32]: - for causal in [True, False]: - for head_dim in [64, 128, 256]: - for dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: - test_single_prefill(seq_len, num_heads, causal, head_dim, dtype) - - for R in [1, 4, 16]: - for C in [1, 4, 16]: - for M in [64, 128, 256]: - for N in [64, 128, 256]: - for num_heads in [32]: - for head_dim in [128, 256]: + for R in [4]: + for C in [1]: + for M in [1024]: + for N in [512]: + for num_heads in 
[8]: + for head_dim in [256]: for mask_inside_block in [False]: for dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: test_block_sparse_attention( @@ -190,7 +191,6 @@ def test_block_sparse_attention( M, N, num_heads, - num_heads, head_dim, mask_inside_block, dtype, From 3bd14cc422614af63cec9411db533529e7a708d9 Mon Sep 17 00:00:00 2001 From: happierpig Date: Tue, 29 Apr 2025 18:34:05 +0000 Subject: [PATCH 12/14] fix typo --- flashinfer/sparse.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flashinfer/sparse.py b/flashinfer/sparse.py index cea948f5f..44e1bb23b 100644 --- a/flashinfer/sparse.py +++ b/flashinfer/sparse.py @@ -359,7 +359,7 @@ def plan( self._cached_module = get_batch_decode_module( q_data_type, kv_data_type, - o_data_type, + self._o_dtype, indptr.dtype, head_dim, head_dim, From 90469a12424a632ab87164bb3e46bf2739eea14a Mon Sep 17 00:00:00 2001 From: happierpig Date: Tue, 29 Apr 2025 19:56:22 +0000 Subject: [PATCH 13/14] upd --- flashinfer/jit/attention/pytorch.py | 12 ++++++++---- flashinfer/prefill.py | 3 +++ 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/flashinfer/jit/attention/pytorch.py b/flashinfer/jit/attention/pytorch.py index 39cc6cc75..5f4a8a9bc 100644 --- a/flashinfer/jit/attention/pytorch.py +++ b/flashinfer/jit/attention/pytorch.py @@ -454,10 +454,12 @@ def gen_single_prefill_module( ) # use `fp8_enabled` flag to use separate kernel template - fp8_enabled = "e4m3" in uri or "e5m2" in uri + # this is used for fp8 tensor core computation + # KV-only quant is not influenced by this flag + fp8_enabled = dtype_q in [torch.float8_e4m3fn, torch.float8_e5m2] if backend == "fa2": - assert not fp8_enabled, "fp8 is not supported in fa2 backend" + assert not fp8_enabled, "fp8 tensor core is not supported in fa2 backend" additional_tensor_names = ["maybe_custom_mask", "maybe_alibi_slopes"] additional_tensor_dtypes = ["uint8_t", "float"] additional_scalar_names = [ @@ -747,10 +749,12 @@ def gen_batch_prefill_module( ) # use `fp8_enabled` flag to use separate kernel template - fp8_enabled = "e4m3" in uri or "e5m2" in uri + # this is used for fp8 tensor core computation + # KV-only quant is not influenced by this flag + fp8_enabled = dtype_q in [torch.float8_e4m3fn, torch.float8_e5m2] if backend == "fa2": - assert not fp8_enabled, "fp8 is not supported in fa2 backend" + assert not fp8_enabled, "fp8 tensor core is not supported in fa2 backend" additional_tensor_names = [ "maybe_custom_mask", "maybe_mask_indptr", diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index bb463cf03..21c5edf65 100644 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -1759,6 +1759,9 @@ def run( _get_cache_alibi_slopes_buf(q.shape[1], q.device), logits_soft_cap, sm_scale, + None, # scale_q, not supported yet + None, # scale_k + None, # scale_v rope_scale, rope_theta, ] From 589fa3e822e19475c3de3a1133d25a8584ca76c8 Mon Sep 17 00:00:00 2001 From: happierpig Date: Tue, 29 Apr 2025 20:20:45 +0000 Subject: [PATCH 14/14] upd --- flashinfer/decode.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/flashinfer/decode.py b/flashinfer/decode.py index e66fde478..32cda4c91 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -530,6 +530,9 @@ def single_decode_with_kv_cache( _get_cache_alibi_slopes_buf(num_qo_heads, q.device), logits_soft_cap, sm_scale, + None, # scale_q, not supported yet + None, # scale_k + None, # scale_v rope_scale, rope_theta, ) @@ -1169,6 +1172,9 @@ def run( _get_cache_alibi_slopes_buf(q.shape[1], q.device), logits_soft_cap, 
sm_scale, + None, # scale_q, not supported yet + None, # scale_k + None, # scale_v rope_scale, rope_theta, ]
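
End-to-end, the FP8 path added by this series is driven from Python as in tests/test_hopper_fp8_attention.py: quantize q/k/v per head, then call the fa3 backend with explicit scales and an explicit output dtype (which cannot be inferred from float8 inputs). A condensed sketch follows; it assumes an SM90 GPU, reuses the hypothetical per_head_symmetric_quant_sketch helper from the earlier note, and the scale_q/scale_k/scale_v keyword names are written to match the run()/paged_run signatures introduced above — treat them as illustrative rather than the definitive public signature of single_prefill_with_kv_cache.

import torch
import flashinfer

q = torch.randn(1024, 32, 128, dtype=torch.half, device="cuda")
k = torch.randn(1024, 32, 128, dtype=torch.half, device="cuda")
v = torch.randn(1024, 32, 128, dtype=torch.half, device="cuda")

q_fp8, s_q = per_head_symmetric_quant_sketch(q)   # hypothetical helper, see earlier sketch
k_fp8, s_k = per_head_symmetric_quant_sketch(k)
v_fp8, s_v = per_head_symmetric_quant_sketch(v)

# FP8 tensor-core prefill on Hopper; the scale keyword names are assumed here.
o = flashinfer.single_prefill_with_kv_cache(
    q_fp8, k_fp8, v_fp8,
    scale_q=s_q, scale_k=s_k, scale_v=s_v,
    causal=True, backend="fa3", o_dtype=torch.half,
)

# Loose sanity check in the spirit of the tests: compare against the fp16 reference.
o_ref = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True, backend="fa3")
mse = torch.mean((o_ref.float() - o.float()) ** 2)
assert mse < 1.0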