|
| 1 | +/* |
| 2 | + * Copyright (c) 2023 by FlashInfer team. |
| 3 | + * |
| 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + * you may not use this file except in compliance with the License. |
| 6 | + * You may obtain a copy of the License at |
| 7 | + * |
| 8 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | + * |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governing permissions and |
| 14 | + * limitations under the License. |
| 15 | + */ |
| 16 | + |
| 17 | +#include <flashinfer/attention/mask.cuh> |
| 18 | +#include <flashinfer/attention/scheduler.cuh> |
| 19 | +#include <flashinfer/layout.cuh> |
| 20 | +#include <flashinfer/math.cuh> |
| 21 | +#include <optional> |
| 22 | + |
| 23 | +#include "batch_prefill_sm90_config.inc" |
| 24 | +#include "pytorch_conversion_utils.h" |
| 25 | +#include "pytorch_extension_utils.h" |
| 26 | + |
| 27 | +namespace flashinfer { |
| 28 | + |
| 29 | +template <uint32_t HEAD_DIM, MaskMode MASK_MODE, bool LEFT_SLIDING_WINDOW, |
| 30 | + bool SAME_SCHEDULE_FOR_ALL_HEADS, typename AttentionVariant, typename Params> |
| 31 | +cudaError_t BatchFP8PrefillWithPagedKVCacheDispatched(Params& params, cudaStream_t stream); |
| 32 | + |
| 33 | +} // namespace flashinfer |
| 34 | + |
| 35 | +using namespace flashinfer; |
| 36 | + |
| 37 | +at::Tensor BatchPrefillWithKVCacheSM90Plan( |
| 38 | + at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, |
| 39 | + at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr, |
| 40 | + at::Tensor kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads, |
| 41 | + int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk, |
| 42 | + int64_t head_dim_vo, bool causal) { |
| 43 | + size_t float_workspace_size_in_bytes = |
| 44 | + float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); |
| 45 | + size_t int_workspace_size_in_bytes = |
| 46 | + int_workspace_buffer.size(0) * int_workspace_buffer.element_size(); |
| 47 | + |
| 48 | + flashinfer::PrefillPlanSM90Info plan_info; |
| 49 | + |
| 50 | + const c10::cuda::OptionalCUDAGuard device_guard(float_workspace_buffer.device()); |
| 51 | + cudaStream_t stream = c10::cuda::getCurrentCUDAStream(); |
| 52 | + |
| 53 | + cudaError_t status = |
| 54 | + PrefillSM90Plan(float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes, |
| 55 | + int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(), |
| 56 | + int_workspace_size_in_bytes, plan_info, qo_indptr.data_ptr<IdType>(), |
| 57 | + kv_indptr.data_ptr<IdType>(), kv_len_arr.data_ptr<IdType>(), total_num_rows, |
| 58 | + batch_size, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo, page_size, |
| 59 | + causal, enable_cuda_graph, /*sizeof_dtype_o=*/2, stream); |
| 60 | + |
| 61 | + TORCH_CHECK(status == cudaSuccess, |
| 62 | + "PrefillSM90Plan failed with error: ", cudaGetErrorString(status)); |
| 63 | + |
| 64 | + return vec_to_tensor(plan_info.ToVector()); |
| 65 | +} |
| 66 | + |
| 67 | +void BatchPrefillWithRaggedKVCacheSM90Run(at::Tensor float_workspace_buffer, |
| 68 | + at::Tensor int_workspace_buffer, at::Tensor plan_info_vec, |
| 69 | + at::Tensor q, at::Tensor k, at::Tensor v, |
| 70 | + at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor o, |
| 71 | + std::optional<at::Tensor> maybe_lse, |
| 72 | + int64_t mask_mode_code, int64_t layout, |
| 73 | + int64_t window_left ADDITIONAL_FUNC_PARAMS) { |
| 74 | + return; // TODO: Implement this function |
| 75 | +} |
| 76 | + |
| 77 | +void BatchPrefillWithPagedKVCacheSM90Run( |
| 78 | + at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, at::Tensor plan_info_vec, |
| 79 | + at::Tensor q, at::Tensor paged_k_cache, at::Tensor paged_v_cache, at::Tensor qo_indptr, |
| 80 | + at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, |
| 81 | + at::Tensor o, std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code, int64_t layout, |
| 82 | + int64_t window_left ADDITIONAL_FUNC_PARAMS) { |
| 83 | + PrefillPlanSM90Info plan_info; |
| 84 | + plan_info.FromVector(tensor_to_vec(plan_info_vec)); |
| 85 | + |
| 86 | + if (maybe_lse) { |
| 87 | + const auto& lse = *maybe_lse; |
| 88 | + TORCH_CHECK(lse.size(0) == q.size(0), lse.size(0), q.size(0)); |
| 89 | + TORCH_CHECK(lse.size(1) == q.size(1), lse.size(1), q.size(1)); |
| 90 | + } |
| 91 | + QKVLayout kv_layout = static_cast<QKVLayout>(layout); |
| 92 | + int64_t num_kv_heads, page_size; |
| 93 | + int64_t head_dim_qk = q.size(2); |
| 94 | + int64_t head_dim_vo = paged_v_cache.size(3); |
| 95 | + if (kv_layout == QKVLayout::kHND) { |
| 96 | + num_kv_heads = paged_k_cache.size(1); |
| 97 | + page_size = paged_k_cache.size(2); |
| 98 | + } else { |
| 99 | + page_size = paged_k_cache.size(1); |
| 100 | + num_kv_heads = paged_k_cache.size(2); |
| 101 | + } |
| 102 | + |
| 103 | + void* float_buffer_ptr = float_workspace_buffer.data_ptr(); |
| 104 | + void* int_buffer_ptr = int_workspace_buffer.data_ptr(); |
| 105 | + |
| 106 | + auto q_scalar_type = q.scalar_type(); |
| 107 | + auto kv_scalar_type = paged_k_cache.scalar_type(); |
| 108 | + |
| 109 | + const c10::cuda::OptionalCUDAGuard device_guard(float_workspace_buffer.device()); |
| 110 | + cudaStream_t stream = c10::cuda::getCurrentCUDAStream(); |
| 111 | + const MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code); |
| 112 | + bool use_swa = window_left != -1; |
| 113 | + |
| 114 | + DISPATCH_context( |
| 115 | + DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, USE_SLIDING_WINDOW, |
| 116 | + USE_LOGITS_SOFT_CAP, AttentionVariant, RaggedParams, PagedParams, [&] { |
| 117 | + PagedParams params; |
| 118 | + |
| 119 | + params.q_ptr = static_cast<DTypeQ*>(q.data_ptr()); |
| 120 | + params.k_ptr = static_cast<DTypeKV*>(paged_k_cache.data_ptr()); |
| 121 | + params.v_ptr = static_cast<DTypeKV*>(paged_v_cache.data_ptr()); |
| 122 | + params.o_ptr = static_cast<DTypeO*>(o.data_ptr()); |
| 123 | + params.lse_ptr = maybe_lse ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr; |
| 124 | + params.q_stride_n = q.stride(0); |
| 125 | + params.q_stride_h = q.stride(1); |
| 126 | + params.o_stride_n = o.stride(0); |
| 127 | + params.o_stride_h = o.stride(1); |
| 128 | + if (kv_layout == QKVLayout::kNHD) { |
| 129 | + // (num_pages, page_size, num_heads, head_dim) |
| 130 | + params.k_stride_n = paged_k_cache.stride(1); |
| 131 | + params.k_stride_h = paged_k_cache.stride(2); |
| 132 | + params.v_stride_n = paged_v_cache.stride(1); |
| 133 | + params.v_stride_h = paged_v_cache.stride(2); |
| 134 | + } else { |
| 135 | + // (num_pages, num_heads, page_size, head_dim) |
| 136 | + params.k_stride_h = paged_k_cache.stride(1); |
| 137 | + params.k_stride_n = paged_k_cache.stride(2); |
| 138 | + params.v_stride_h = paged_v_cache.stride(1); |
| 139 | + params.v_stride_n = paged_v_cache.stride(2); |
| 140 | + } |
| 141 | + params.nnz_qo = q.size(0); |
| 142 | + params.num_qo_heads = q.size(1); |
| 143 | + params.num_kv_heads = num_kv_heads; |
| 144 | + params.group_size = params.num_qo_heads / num_kv_heads; |
| 145 | + params.page_size = page_size; |
| 146 | + params.window_left = window_left; |
| 147 | + params.causal = mask_mode_code == 1; |
| 148 | + params.qo_tile_indices = |
| 149 | + GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.qo_tile_indices_offset); |
| 150 | + params.qo_indptr = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.qo_indptr_offset); |
| 151 | + params.kv_indptr = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.kv_indptr_offset); |
| 152 | + params.qo_lens = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.qo_len_offset); |
| 153 | + params.kv_lens = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.kv_len_offset); |
| 154 | + params.head_indices = |
| 155 | + GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.head_indices_offset); |
| 156 | + params.work_indptr = |
| 157 | + GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.work_indptr_offset); |
| 158 | + params.kv_indices = static_cast<IdType*>(paged_kv_indices.data_ptr()); |
| 159 | + |
| 160 | + ADDITIONAL_PARAMS_SETTER |
| 161 | + |
| 162 | + // Not support various head_dim for now |
| 163 | + static_assert(HEAD_DIM_QK == HEAD_DIM_VO, "head_dim_qk and head_dim_vo should be the same"); |
| 164 | + // Currently only support same quantization precision |
| 165 | + static_assert(std::is_same_v<DTypeQ, DTypeKV>); |
| 166 | + |
| 167 | + bool same_schedule_for_all_heads = plan_info.same_schedule_for_all_heads; |
| 168 | + DISPATCH_BOOL(same_schedule_for_all_heads, SAME_SCHEDULER_FOR_ALL_HEADS, [&] { |
| 169 | + cudaError_t status = |
| 170 | + BatchFP8PrefillWithPagedKVCacheDispatched<HEAD_DIM_QK, MASK_MODE, USE_SLIDING_WINDOW, |
| 171 | + SAME_SCHEDULER_FOR_ALL_HEADS, |
| 172 | + AttentionVariant>(params, stream); |
| 173 | + TORCH_CHECK(status == cudaSuccess, |
| 174 | + "BatchPrefillWithPagedKVCacheSM90Run failed with error: ", |
| 175 | + cudaGetErrorString(status)); |
| 176 | + return true; |
| 177 | + }); |
| 178 | + }); |
| 179 | +} |
0 commit comments