From a128819048f119db275903b0f137a44091c9367c Mon Sep 17 00:00:00 2001 From: gc-fu Date: Wed, 11 Sep 2024 14:03:49 +0800 Subject: [PATCH] First-phase: add sdp kernel without mask --- .clang-format | 2 +- csrc/xpu/attention_xpu.cpp | 2314 ++++++++++++++------------ csrc/xpu/pybind.cpp | 3 + csrc/xpu/xpu_ops.h | 7 + vllm/attention/backends/ipex_attn.py | 127 +- vllm/worker/xpu_model_runner.py | 15 +- 6 files changed, 1291 insertions(+), 1177 deletions(-) diff --git a/.clang-format b/.clang-format index 7f9e6d720fa..aac76168d28 100644 --- a/.clang-format +++ b/.clang-format @@ -1,6 +1,6 @@ BasedOnStyle: Google UseTab: Never -IndentWidth: 2 +IndentWidth: 4 ColumnLimit: 80 # Force pointers to the type for C++. diff --git a/csrc/xpu/attention_xpu.cpp b/csrc/xpu/attention_xpu.cpp index 833f46eaaf7..98c796871f8 100644 --- a/csrc/xpu/attention_xpu.cpp +++ b/csrc/xpu/attention_xpu.cpp @@ -4,7 +4,8 @@ #endif #include #include - +#include +using namespace sycl::ext::intel::esimd; // clang-format on #include #include @@ -18,1237 +19,1332 @@ #define WARP_SIZE 32 #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b)) -#define DIVIDE_ROUND_UP(a, b) (((a) + (b)-1) / (b)) +#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) template struct Float_Trait { - using Type = T; + using Type = T; }; template <> struct Float_Trait { - using Type = uint16_t; + using Type = uint16_t; }; template <> struct Float_Trait { - using Type = sycl::ext::oneapi::bfloat16; + using Type = sycl::ext::oneapi::bfloat16; }; namespace vllm { +// Handle context attention with pure k/v. Currently, we do not have any +// context. +template +void context_attention_kernel_kv( + void* query, void* key, void* value, void* seq_lens, const float scale, + void* out, // output + const void* query_start_loc, const int batch_size, const int num_heads, + const int max_input_length, const int num_queries_per_group, + const int query_stride_tokens, const int query_stride_head, + const int key_stride_tokens, const int key_stride_head, + const int value_stride_tokens, const int value_stride_head, + const int out_stride_tokens, const int out_stride_head) { + static_assert(GS * HD * sizeof(scalar_t) * 2 < 64 * 1024); + + const size_t key_slm_offset = 0; + const size_t value_slm_offset = GS * HD * sizeof(scalar_t); + sycl::queue& queue = vllm::xpu::vllmGetQueue(); + + sycl::range<3> global_size(batch_size, num_heads, + (max_input_length + GS - 1) / GS * GS); + sycl::range<3> local_size(1, 1, GS); + + auto cgf = [&](sycl::handler& handle) { + handle.parallel_for( + sycl::nd_range<3>(global_size, local_size), + [=](sycl::nd_item<3> item) SYCL_ESIMD_KERNEL { + slm_init(); + + const size_t bsz_idx = item.get_global_id(0); + const size_t head_idx = item.get_global_id(1); + // Assuming we have 32 query head and 8 kv_heads. Then + // num_queries_per_group should be 4 For head_idx 13, then + // kv_head_idx = 13 / 4 = 3, which is correct + const size_t kv_head_idx = head_idx / num_queries_per_group; + const int32_t seq_idx = item.get_global_id(2); + const size_t gid = item.get_group(2); + const size_t tid = item.get_local_id(2); + + // const int64_t * seq_len = (const int64_t *) seq_lens; + const int32_t* seq_len = (const int32_t*)seq_lens; + int32_t seq_bound = seq_len[bsz_idx]; + + // TODO: check if this token_idx is correct or not... 
+ const int32_t* query_loc = (const int32_t*)query_start_loc; + // There is a possibility that the current token index pass over + // the seq_len, therefore: + int32_t token_idx = + query_loc[bsz_idx] + std::min(seq_idx, seq_bound - 1); + const scalar_t* query_head = (const scalar_t*)query + + token_idx * query_stride_tokens + + head_idx * query_stride_head; + const scalar_t* key_head = + (const scalar_t*)key + + query_loc[bsz_idx] * key_stride_tokens + + kv_head_idx * key_stride_head; + const scalar_t* value_head = + (const scalar_t*)value + + query_loc[bsz_idx] * value_stride_tokens + + kv_head_idx * value_stride_head; + // TODO: check out_head validness + // WARNING: out_head may be out of bound + scalar_t* out_head = + (scalar_t*)out + + (query_loc[bsz_idx] + seq_idx) * out_stride_tokens + + head_idx * out_stride_head; + + simd query_row = + block_load(query_head) * scale; + simd accv = 0; + simd softmaxv = 0; + scalar_t max_attn = -sycl::detail::max_v(); + + for (size_t group = 0; group < gid; ++group) { + simd key_row = block_load( + key_head + (group * GS + tid) * key_stride_tokens); + slm_block_store( + key_slm_offset + tid * HD * sizeof(scalar_t), key_row); + simd value_row = block_load( + value_head + (group * GS + tid) * value_stride_tokens); + slm_block_store( + value_slm_offset + tid * HD * sizeof(scalar_t), + value_row); + barrier(); + simd attnv; +#pragma unroll + for (size_t r = 0; r < GS; ++r) { + simd key_row = + slm_block_load( + key_slm_offset + r * HD * sizeof(scalar_t)); + scalar_t attn = sycl::ext::intel::esimd::detail::sum< + scalar_t, scalar_t, HD>(query_row * key_row); + attnv[r] = attn; + } + + scalar_t new_max_attn = + std::max(hmax(attnv), max_attn); + scalar_t attn_exp = exp(max_attn - new_max_attn); + accv = accv * attn_exp; + + softmaxv = softmaxv * attn_exp; + max_attn = new_max_attn; + const simd attn_expv = exp(attnv - max_attn); +#pragma unorll + for (size_t r = 0; r < GS; ++r) { + simd value_row = + slm_block_load( + value_slm_offset + r * HD * sizeof(scalar_t)); + accv += value_row * attn_expv[r]; + } + softmaxv += attn_expv; + barrier(); + } + + scalar_t softmax = + sycl::ext::intel::esimd::detail::sum(softmaxv); + + if (seq_idx < seq_bound) { + // key_head has already included the query_loc[bsz_idx] + // part, should only included seq_idx part + simd key_row = block_load( + key_head + seq_idx * key_stride_tokens); + slm_block_store( + key_slm_offset + tid * HD * sizeof(scalar_t), key_row); + simd value_row = block_load( + value_head + seq_idx * value_stride_tokens); + slm_block_store( + value_slm_offset + tid * HD * sizeof(scalar_t), + value_row); + } + barrier(); + + if (seq_idx < seq_bound) { + // handle last a few of tokens + for (size_t r = 0; r <= tid; ++r) { + simd key_row = + slm_block_load( + key_slm_offset + r * HD * sizeof(scalar_t)); + simd value_row = + slm_block_load( + value_slm_offset + r * HD * sizeof(scalar_t)); + scalar_t attn = sycl::ext::intel::esimd::detail::sum< + scalar_t, scalar_t, HD>(query_row * key_row); + if (attn <= max_attn) { + scalar_t attn_exp = + sycl::ext::intel::esimd::exp(attn - max_attn); + accv += value_row * attn_exp; + softmax += attn_exp; + } else { + scalar_t attn_exp = + sycl::ext::intel::esimd::exp(max_attn - attn); + accv = accv * attn_exp + value_row; + softmax = softmax * attn_exp + 1; + max_attn = attn; + } + } + + if (softmax > 0) { + simd result = accv / softmax; + block_store(out_head, result); + } else { + simd result = 0; + block_store(out_head, result); + } + } + }); + }; + queue.submit(cgf); +} 
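The kernel above accumulates attention with an online softmax: it keeps a running maximum of the scores, and whenever a larger score appears it rescales both the accumulated value vector (accv) and the running denominator (softmaxv) by exp(old_max - new_max) before folding in the new tile staged through SLM. The following scalar sketch shows the same recurrence; it is illustrative only and not part of the patch, and names such as streaming_attention, K, and V are hypothetical.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Scalar model of the accv/softmaxv/max_attn recurrence used in
// context_attention_kernel_kv: softmax(scale * q.K^T) * V for one query row,
// streamed over keys so that every exponential stays <= 1.
std::vector<float> streaming_attention(const std::vector<float>& q,
                                       const std::vector<std::vector<float>>& K,
                                       const std::vector<std::vector<float>>& V,
                                       float scale) {
    const size_t head_dim = q.size();
    std::vector<float> acc(head_dim, 0.0f);                     // accv
    float softmax_sum = 0.0f;                                   // sum(softmaxv)
    float max_score = -std::numeric_limits<float>::infinity();  // max_attn

    for (size_t t = 0; t < K.size(); ++t) {
        float score = 0.0f;
        for (size_t d = 0; d < head_dim; ++d) score += scale * q[d] * K[t][d];

        const float new_max = std::max(max_score, score);
        const float rescale = std::exp(max_score - new_max);  // exp(old - new)
        const float p = std::exp(score - new_max);
        for (size_t d = 0; d < head_dim; ++d)
            acc[d] = acc[d] * rescale + p * V[t][d];
        softmax_sum = softmax_sum * rescale + p;
        max_score = new_max;
    }
    // Mirror the kernel's `if (softmax > 0)` guard for the empty case.
    if (softmax_sum > 0.0f)
        for (size_t d = 0; d < head_dim; ++d) acc[d] /= softmax_sum;
    return acc;
}

The two-branch update in the kernel (attn <= max_attn vs. the rescale-and-add branch) is algebraically the same recurrence; the sketch uses the single-branch form for clarity.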
+ // Q*K^T operation. template -inline float qk_dot_( - const Vec* q, - const Vec* k, - const sycl::nd_item<3>& item_ct1) { - using A_vec = typename FloatVec::Type; - // Compute the parallel products for Q*K^T (treat vector lanes separately). - A_vec qk_vec = mul(q[0], k[0]); +inline float qk_dot_(const Vec* q, const Vec* k, + const sycl::nd_item<3>& item_ct1) { + using A_vec = typename FloatVec::Type; + // Compute the parallel products for Q*K^T (treat vector lanes separately). + A_vec qk_vec = mul(q[0], k[0]); #pragma unroll - for (int ii = 1; ii < N; ++ii) { - qk_vec = fma(q[ii], k[ii], qk_vec); - } + for (int ii = 1; ii < N; ++ii) { + qk_vec = fma(q[ii], k[ii], qk_vec); + } - // Finalize the reduction across lanes. - float qk = sum(qk_vec); + // Finalize the reduction across lanes. + float qk = sum(qk_vec); #pragma unroll - for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) { - - qk += dpct::permute_sub_group_by_xor( - item_ct1.get_sub_group(), qk, mask); - } - return qk; + for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) { + qk += + dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), qk, mask); + } + return qk; } template struct Qk_dot { - template - static inline float dot( - const Vec* q, - const Vec* k, - const sycl::nd_item<3>& item_ct1) { - return qk_dot_(q, k, item_ct1); - } + template + static inline float dot(const Vec* q, const Vec* k, + const sycl::nd_item<3>& item_ct1) { + return qk_dot_(q, k, item_ct1); + } }; template -inline float block_sum( - float* red_smem, - float sum, - const sycl::nd_item<3>& item_ct1) { - // Decompose the thread index into warp / lane. - int warp = item_ct1.get_local_id(2) / WARP_SIZE; - int lane = item_ct1.get_local_id(2) % WARP_SIZE; - - // Compute the sum per warp. +inline float block_sum(float* red_smem, float sum, + const sycl::nd_item<3>& item_ct1) { + // Decompose the thread index into warp / lane. + int warp = item_ct1.get_local_id(2) / WARP_SIZE; + int lane = item_ct1.get_local_id(2) % WARP_SIZE; + + // Compute the sum per warp. #pragma unroll - for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { - - /* - DPCT1096:42: The right-most dimension of the work-group used in the SYCL - kernel that calls this function may be less than "32". The function - "dpct::permute_sub_group_by_xor" may return an unexpected result on the CPU - device. Modify the size of the work-group to ensure that the value of the - right-most dimension is a multiple of "32". - */ - sum += dpct::permute_sub_group_by_xor( - item_ct1.get_sub_group(), sum, mask); - } + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + /* + DPCT1096:42: The right-most dimension of the work-group used in the SYCL + kernel that calls this function may be less than "32". The function + "dpct::permute_sub_group_by_xor" may return an unexpected result on the + CPU device. Modify the size of the work-group to ensure that the value + of the right-most dimension is a multiple of "32". + */ + sum += + dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), sum, mask); + } - // Warp leaders store the data to shared memory. - if (lane == 0) { - red_smem[warp] = sum; - } + // Warp leaders store the data to shared memory. + if (lane == 0) { + red_smem[warp] = sum; + } + + // Make sure the data is in shared memory. - // Make sure the data is in shared memory. - - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(sycl::access::fence_space::local_space); - // The warps compute the final sums. 
- if (lane < NUM_WARPS) { - sum = red_smem[lane]; - } + // The warps compute the final sums. + if (lane < NUM_WARPS) { + sum = red_smem[lane]; + } - // Parallel reduction inside the warp. + // Parallel reduction inside the warp. #pragma unroll - for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { - + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + /* + DPCT1096:43: The right-most dimension of the work-group used in the SYCL + kernel that calls this function may be less than "32". The function + "dpct::permute_sub_group_by_xor" may return an unexpected result on the + CPU device. Modify the size of the work-group to ensure that the value + of the right-most dimension is a multiple of "32". + */ + sum += + dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), sum, mask); + } + + // Broadcast to other threads. + /* - DPCT1096:43: The right-most dimension of the work-group used in the SYCL + DPCT1096:44: The right-most dimension of the work-group used in the SYCL kernel that calls this function may be less than "32". The function - "dpct::permute_sub_group_by_xor" may return an unexpected result on the CPU + "dpct::select_from_sub_group" may return an unexpected result on the CPU device. Modify the size of the work-group to ensure that the value of the right-most dimension is a multiple of "32". */ - sum += dpct::permute_sub_group_by_xor( - item_ct1.get_sub_group(), sum, mask); - } - - // Broadcast to other threads. - - /* - DPCT1096:44: The right-most dimension of the work-group used in the SYCL - kernel that calls this function may be less than "32". The function - "dpct::select_from_sub_group" may return an unexpected result on the CPU - device. Modify the size of the work-group to ensure that the value of the - right-most dimension is a multiple of "32". - */ - return dpct::select_from_sub_group( - item_ct1.get_sub_group(), sum, 0); + return dpct::select_from_sub_group(item_ct1.get_sub_group(), sum, 0); } -template < - typename scalar_t, - typename Q_Vec_t, - int HEAD_SIZE, - int BLOCK_SIZE, - int NUM_THREADS, - int VEC_SIZE, - int PARTITION_SIZE = 0> // Zero means no partitioning. +template // Zero means no partitioning. 
void paged_attention_kernel( - float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, - // head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, - // head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, - // head_size, block_size] - const int num_kv_heads, // [num_heads] + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, // [num_heads] const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, - const int kv_block_stride, - const int kv_head_stride, - const sycl::nd_item<3>& item_ct1, - uint8_t* dpct_local, - Q_Vec_t* q_vecs, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const sycl::nd_item<3>& item_ct1, uint8_t* dpct_local, Q_Vec_t* q_vecs, float* red_smem) { - const int seq_idx = item_ct1.get_group(1); - const int partition_idx = item_ct1.get_group(0); - const int max_num_partitions = item_ct1.get_group_range(0); - constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0; - const int context_len = context_lens[seq_idx]; - if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) { - // No work to do. Terminate the thread block. - return; - } - - const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); - const int num_blocks_per_partition = - USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks; - - // [start_block_idx, end_block_idx) is the range of blocks to process. - const int start_block_idx = - USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; - const int end_block_idx = - MIN(start_block_idx + num_blocks_per_partition, num_context_blocks); - const int num_blocks = end_block_idx - start_block_idx; - - // [start_token_idx, end_token_idx) is the range of tokens to process. 
- const int start_token_idx = start_block_idx * BLOCK_SIZE; - const int end_token_idx = - MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len); - const int num_tokens = end_token_idx - start_token_idx; - - constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); - constexpr int NUM_THREAD_GROUPS = - NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE - // divides NUM_THREADS - assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); - constexpr int NUM_TOKENS_PER_THREAD_GROUP = - DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); - constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - const int thread_idx = item_ct1.get_local_id(2); - const int warp_idx = thread_idx / WARP_SIZE; - const int lane = thread_idx % WARP_SIZE; - - const int head_idx = item_ct1.get_group(2); - const int num_heads = item_ct1.get_group_range(2); - const int num_queries_per_kv = num_heads / num_kv_heads; - - const int kv_head_idx = head_idx / num_queries_per_kv; - ; - const float alibi_slope = - alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; - - // A vector type to store a part of a key or a query. - // The vector size is configured in such a way that the threads in a thread - // group fetch or compute 16 bytes at a time. For example, if the size of a - // thread group is 4 and the data type is half, then the vector size is 16 / - // (4 * sizeof(half)) == 2. - - // constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), - // 1); - - constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; - constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; - - const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE; - const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE; - - // Load the query to registers. - // Each thread in a thread group has a different part of the query. - // For example, if the the thread group size is 4, then the first thread in - // the group has 0, 4, 8, ... th vectors of the query, and the second thread - // has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because - // q is split from a qkv tensor, it may not be contiguous. - const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; + const int seq_idx = item_ct1.get_group(1); + const int partition_idx = item_ct1.get_group(0); + const int max_num_partitions = item_ct1.get_group_range(0); + constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0; + const int context_len = context_lens[seq_idx]; + if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) { + // No work to do. Terminate the thread block. + return; + } -#pragma unroll - for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; - i += NUM_THREAD_GROUPS) { - const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; - q_vecs[thread_group_offset * NUM_VECS_PER_THREAD + i] = - *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); - } - /* - DPCT1065:5: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better - performance if there is no access to global memory. - */ - item_ct1.barrier(sycl::access::fence_space::local_space); // TODO(naed90): possible speedup if this is replaced with - // a memory wall right before we use q_vecs - - // Memory planning. - auto shared_mem = (char*)dpct_local; - // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy. - float* logits = reinterpret_cast(shared_mem); - // Workspace for reduction. 
- - // x == THREAD_GROUP_SIZE * VEC_SIZE - // Each thread group fetches x elements from the key at a time. - constexpr int x = 16 / sizeof(scalar_t); - float qk_max = -FLT_MAX; - - // Iterate over the key blocks. - // Each warp fetches a block of keys for each iteration. - // Each thread group in a warp fetches a key from the block, and computes - // dot product with the query. - const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; - - for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; - block_idx += NUM_WARPS) { - // NOTE(woosuk): The block number is stored in int32. However, we cast it to - // int64 because int32 can lead to overflow when this variable is multiplied - // by large numbers (e.g., kv_block_stride). - const int64_t physical_block_number = - static_cast(block_table[block_idx]); - - // Load a key to registers. - // Each thread in a thread group has a different part of the key. + const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); + const int num_blocks_per_partition = + USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks; + + // [start_block_idx, end_block_idx) is the range of blocks to process. + const int start_block_idx = + USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; + const int end_block_idx = + MIN(start_block_idx + num_blocks_per_partition, num_context_blocks); + const int num_blocks = end_block_idx - start_block_idx; + + // [start_token_idx, end_token_idx) is the range of tokens to process. + const int start_token_idx = start_block_idx * BLOCK_SIZE; + const int end_token_idx = + MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len); + const int num_tokens = end_token_idx - start_token_idx; + + constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); + constexpr int NUM_THREAD_GROUPS = + NUM_THREADS / + THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE + // divides NUM_THREADS + assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); + constexpr int NUM_TOKENS_PER_THREAD_GROUP = + DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int thread_idx = item_ct1.get_local_id(2); + const int warp_idx = thread_idx / WARP_SIZE; + const int lane = thread_idx % WARP_SIZE; + + const int head_idx = item_ct1.get_group(2); + const int num_heads = item_ct1.get_group_range(2); + const int num_queries_per_kv = num_heads / num_kv_heads; + + const int kv_head_idx = head_idx / num_queries_per_kv; + ; + const float alibi_slope = + alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; + + // A vector type to store a part of a key or a query. + // The vector size is configured in such a way that the threads in a thread + // group fetch or compute 16 bytes at a time. For example, if the size of a + // thread group is 4 and the data type is half, then the vector size is 16 / + // (4 * sizeof(half)) == 2. + + // constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), + // 1); + + constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; + constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; + + const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE; + const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE; + + // Load the query to registers. + // Each thread in a thread group has a different part of the query. // For example, if the the thread group size is 4, then the first thread in - // the group has 0, 4, 8, ... 
th vectors of the key, and the second thread - // has 1, 5, 9, ... th vectors of the key, and so on. - - for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { - const int physical_block_offset = - (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; - const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + // the group has 0, 4, 8, ... th vectors of the query, and the second thread + // has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): + // Because q is split from a qkv tensor, it may not be contiguous. + const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; - Q_Vec_t k_vecs[NUM_VECS_PER_THREAD]; +#pragma unroll + for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; + i += NUM_THREAD_GROUPS) { + const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; + q_vecs[thread_group_offset * NUM_VECS_PER_THREAD + i] = + *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); + } + /* + DPCT1065:5: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better + performance if there is no access to global memory. + */ + item_ct1.barrier( + sycl::access::fence_space:: + local_space); // TODO(naed90): possible speedup if this is replaced + // with a memory wall right before we use q_vecs + + // Memory planning. + auto shared_mem = (char*)dpct_local; + // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy. + float* logits = reinterpret_cast(shared_mem); + // Workspace for reduction. + + // x == THREAD_GROUP_SIZE * VEC_SIZE + // Each thread group fetches x elements from the key at a time. + constexpr int x = 16 / sizeof(scalar_t); + float qk_max = -FLT_MAX; + + // Iterate over the key blocks. + // Each warp fetches a block of keys for each iteration. + // Each thread group in a warp fetches a key from the block, and computes + // dot product with the query. + const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; + + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; + block_idx += NUM_WARPS) { + // NOTE(woosuk): The block number is stored in int32. However, we cast + // it to int64 because int32 can lead to overflow when this variable is + // multiplied by large numbers (e.g., kv_block_stride). + const int64_t physical_block_number = + static_cast(block_table[block_idx]); + + // Load a key to registers. + // Each thread in a thread group has a different part of the key. + // For example, if the the thread group size is 4, then the first thread + // in the group has 0, 4, 8, ... th vectors of the key, and the second + // thread has 1, 5, 9, ... th vectors of the key, and so on. + + for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { + const int physical_block_offset = + (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; + const int token_idx = + block_idx * BLOCK_SIZE + physical_block_offset; + + Q_Vec_t k_vecs[NUM_VECS_PER_THREAD]; #pragma unroll - for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { - const scalar_t* k_ptr = k_cache + - physical_block_number * kv_block_stride + - kv_head_idx * kv_head_stride + physical_block_offset * x; - - const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; - const int offset1 = (vec_idx * VEC_SIZE) / x; - const int offset2 = (vec_idx * VEC_SIZE) % x; - k_vecs[j] = *reinterpret_cast( - k_ptr + offset1 * BLOCK_SIZE * x + offset2); - } - - // Compute dot product. - // This includes a reduction across the threads in the same thread group. 
- // Q_Vec_t q_vec_[NUM_VECS_PER_THREAD] = q_vecs + thread_group_offset * - // THREAD_GROUP_SIZE; - float qk = scale * - Qk_dot:: - template dot( - q_vecs + thread_group_offset * NUM_VECS_PER_THREAD, - k_vecs, - item_ct1); - // Add the ALiBi bias if slopes are given. - qk += - (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0; - - if (thread_group_offset == 0) { - // Store the partial reductions to shared memory. - // NOTE(woosuk): It is required to zero out the masked logits. - const bool mask = token_idx >= context_len; - logits[token_idx - start_token_idx] = mask ? 0.f : qk; - // Update the max value. - qk_max = mask ? qk_max : sycl::fmax(qk_max, qk); - } + for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { + const scalar_t* k_ptr = + k_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride + physical_block_offset * x; + + const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; + const int offset1 = (vec_idx * VEC_SIZE) / x; + const int offset2 = (vec_idx * VEC_SIZE) % x; + k_vecs[j] = *reinterpret_cast( + k_ptr + offset1 * BLOCK_SIZE * x + offset2); + } + + // Compute dot product. + // This includes a reduction across the threads in the same thread + // group. Q_Vec_t q_vec_[NUM_VECS_PER_THREAD] = q_vecs + + // thread_group_offset * THREAD_GROUP_SIZE; + float qk = + scale * Qk_dot::template dot< + Q_Vec_t, NUM_VECS_PER_THREAD>( + q_vecs + thread_group_offset * NUM_VECS_PER_THREAD, + k_vecs, item_ct1); + // Add the ALiBi bias if slopes are given. + qk += (alibi_slope != 0) + ? alibi_slope * (token_idx - context_len + 1) + : 0; + + if (thread_group_offset == 0) { + // Store the partial reductions to shared memory. + // NOTE(woosuk): It is required to zero out the masked logits. + const bool mask = token_idx >= context_len; + logits[token_idx - start_token_idx] = mask ? 0.f : qk; + // Update the max value. + qk_max = mask ? qk_max : sycl::fmax(qk_max, qk); + } + } } - } - // Perform reduction across the threads in the same warp to get the - // max qk value for each "warp" (not across the thread block yet). - // The 0-th thread of each thread group already has its max qk value. + // Perform reduction across the threads in the same warp to get the + // max qk value for each "warp" (not across the thread block yet). + // The 0-th thread of each thread group already has its max qk value. #pragma unroll - for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { - - /* - DPCT1096:38: The right-most dimension of the work-group used in the SYCL - kernel that calls this function may be less than "32". The function - "dpct::permute_sub_group_by_xor" may return an unexpected result on the CPU - device. Modify the size of the work-group to ensure that the value of the - right-most dimension is a multiple of "32". - */ - qk_max = sycl::fmax( - qk_max, - dpct::permute_sub_group_by_xor( - item_ct1.get_sub_group(), qk_max, mask)); - } - if (lane == 0) { - red_smem[warp_idx] = qk_max; - } - - item_ct1.barrier(sycl::access::fence_space::local_space); - - // TODO(woosuk): Refactor this part. - // Get the max qk value for the sequence. - qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; + for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { + /* + DPCT1096:38: The right-most dimension of the work-group used in the SYCL + kernel that calls this function may be less than "32". The function + "dpct::permute_sub_group_by_xor" may return an unexpected result on the + CPU device. 
Modify the size of the work-group to ensure that the value + of the right-most dimension is a multiple of "32". + */ + qk_max = + sycl::fmax(qk_max, dpct::permute_sub_group_by_xor( + item_ct1.get_sub_group(), qk_max, mask)); + } + if (lane == 0) { + red_smem[warp_idx] = qk_max; + } + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // TODO(woosuk): Refactor this part. + // Get the max qk value for the sequence. + qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; #pragma unroll - for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { - + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + /* + DPCT1096:39: The right-most dimension of the work-group used in the SYCL + kernel that calls this function may be less than "32". The function + "dpct::permute_sub_group_by_xor" may return an unexpected result on the + CPU device. Modify the size of the work-group to ensure that the value + of the right-most dimension is a multiple of "32". + */ + qk_max = + sycl::fmax(qk_max, dpct::permute_sub_group_by_xor( + item_ct1.get_sub_group(), qk_max, mask)); + } + // Broadcast the max qk value to all threads. + /* - DPCT1096:39: The right-most dimension of the work-group used in the SYCL + DPCT1096:40: The right-most dimension of the work-group used in the SYCL kernel that calls this function may be less than "32". The function - "dpct::permute_sub_group_by_xor" may return an unexpected result on the CPU + "dpct::select_from_sub_group" may return an unexpected result on the CPU device. Modify the size of the work-group to ensure that the value of the right-most dimension is a multiple of "32". */ - qk_max = sycl::fmax( - qk_max, - dpct::permute_sub_group_by_xor( - item_ct1.get_sub_group(), qk_max, mask)); - } - // Broadcast the max qk value to all threads. - - /* - DPCT1096:40: The right-most dimension of the work-group used in the SYCL - kernel that calls this function may be less than "32". The function - "dpct::select_from_sub_group" may return an unexpected result on the CPU - device. Modify the size of the work-group to ensure that the value of the - right-most dimension is a multiple of "32". - */ - qk_max = dpct::select_from_sub_group( - item_ct1.get_sub_group(), qk_max, 0); - - // Get the sum of the exp values. - float exp_sum = 0.f; - for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { - float val = sycl::exp(logits[i] - qk_max); - logits[i] = val; - exp_sum += val; - } - exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum, item_ct1); - - // Compute softmax. - const float inv_sum = 1.f / (exp_sum + 1e-6f); - for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { - logits[i] *= inv_sum; - } - - item_ct1.barrier(sycl::access::fence_space::local_space); - - // If partitioning is enabled, store the max logit and exp_sum. - if (USE_PARTITIONING && thread_idx == 0) { - float* max_logits_ptr = max_logits + - seq_idx * num_heads * max_num_partitions + - head_idx * max_num_partitions + partition_idx; - *max_logits_ptr = qk_max; - float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + - head_idx * max_num_partitions + partition_idx; - *exp_sums_ptr = exp_sum; - } - - // Each thread will fetch 16 bytes from the value cache at a time. 
- constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); - using V_vec = typename Vec::Type; - using L_vec = typename Vec::Type; - using Float_L_vec = typename FloatVec::Type; - - constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; - constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; - constexpr int NUM_ROWS_PER_THREAD = - DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); - - // NOTE(woosuk): We use FP32 for the accumulator for better accuracy. - float accs[NUM_ROWS_PER_THREAD]; -#pragma unroll - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - accs[i] = 0.f; - } - - scalar_t zero_value; - zero(zero_value); - for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; - block_idx += NUM_WARPS) { - // NOTE(woosuk): The block number is stored in int32. However, we cast it to - // int64 because int32 can lead to overflow when this variable is multiplied - // by large numbers (e.g., kv_block_stride). - const int64_t physical_block_number = - static_cast(block_table[block_idx]); - const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; - const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; - L_vec logits_vec; - vllm::from_float( - logits_vec, - *reinterpret_cast(logits + token_idx - start_token_idx)); - - const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride + - kv_head_idx * kv_head_stride; + qk_max = dpct::select_from_sub_group(item_ct1.get_sub_group(), qk_max, 0); + + // Get the sum of the exp values. + float exp_sum = 0.f; + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + float val = sycl::exp(logits[i] - qk_max); + logits[i] = val; + exp_sum += val; + } + exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum, item_ct1); + + // Compute softmax. + const float inv_sum = 1.f / (exp_sum + 1e-6f); + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + logits[i] *= inv_sum; + } + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // If partitioning is enabled, store the max logit and exp_sum. + if (USE_PARTITIONING && thread_idx == 0) { + float* max_logits_ptr = max_logits + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + partition_idx; + *max_logits_ptr = qk_max; + float* exp_sums_ptr = exp_sums + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + partition_idx; + *exp_sums_ptr = exp_sum; + } + + // Each thread will fetch 16 bytes from the value cache at a time. + constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); + using V_vec = typename Vec::Type; + using L_vec = typename Vec::Type; + using Float_L_vec = typename FloatVec::Type; + + constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; + constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; + constexpr int NUM_ROWS_PER_THREAD = + DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); + + // NOTE(woosuk): We use FP32 for the accumulator for better accuracy. + float accs[NUM_ROWS_PER_THREAD]; #pragma unroll for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; - if (row_idx < HEAD_SIZE) { - const int offset = row_idx * BLOCK_SIZE + physical_block_offset; - V_vec v_vec = *reinterpret_cast(v_ptr + offset); - if (block_idx == num_context_blocks - 1) { - // NOTE(woosuk): When v_vec contains the tokens that are out of the - // context, we should explicitly zero out the values since they may - // contain NaNs. 
See - // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 - scalar_t* v_vec_ptr = reinterpret_cast(&v_vec); + accs[i] = 0.f; + } + + scalar_t zero_value; + zero(zero_value); + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; + block_idx += NUM_WARPS) { + // NOTE(woosuk): The block number is stored in int32. However, we cast + // it to int64 because int32 can lead to overflow when this variable is + // multiplied by large numbers (e.g., kv_block_stride). + const int64_t physical_block_number = + static_cast(block_table[block_idx]); + const int physical_block_offset = + (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + L_vec logits_vec; + vllm::from_float(logits_vec, *reinterpret_cast( + logits + token_idx - start_token_idx)); + + const scalar_t* v_ptr = v_cache + + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = + lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE) { + const int offset = row_idx * BLOCK_SIZE + physical_block_offset; + V_vec v_vec = *reinterpret_cast(v_ptr + offset); + if (block_idx == num_context_blocks - 1) { + // NOTE(woosuk): When v_vec contains the tokens that are out + // of the context, we should explicitly zero out the values + // since they may contain NaNs. See + // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 + scalar_t* v_vec_ptr = reinterpret_cast(&v_vec); #pragma unroll - for (int j = 0; j < V_VEC_SIZE; j++) { - v_vec_ptr[j] = - token_idx + j < context_len ? v_vec_ptr[j] : zero_value; - } + for (int j = 0; j < V_VEC_SIZE; j++) { + v_vec_ptr[j] = token_idx + j < context_len + ? v_vec_ptr[j] + : zero_value; + } + } + accs[i] += vllm::dot(logits_vec, v_vec); + } } - accs[i] += vllm::dot(logits_vec, v_vec); - } } - } - // Perform reduction within each warp. + // Perform reduction within each warp. #pragma unroll - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - float acc = accs[i]; + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + float acc = accs[i]; #pragma unroll - for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { - - /* - DPCT1096:41: The right-most dimension of the work-group used in the SYCL - kernel that calls this function may be less than "32". The function - "dpct::permute_sub_group_by_xor" may return an unexpected result on the - CPU device. Modify the size of the work-group to ensure that the value of - the right-most dimension is a multiple of "32". - */ - acc += dpct::permute_sub_group_by_xor( - item_ct1.get_sub_group(), acc, mask); + for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { + /* + DPCT1096:41: The right-most dimension of the work-group used in the + SYCL kernel that calls this function may be less than "32". The + function "dpct::permute_sub_group_by_xor" may return an unexpected + result on the CPU device. Modify the size of the work-group to + ensure that the value of the right-most dimension is a multiple of + "32". + */ + acc += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), acc, + mask); + } + accs[i] = acc; } - accs[i] = acc; - } - // NOTE(woosuk): A barrier is required because the shared memory space for - // logits is reused for the output. + // NOTE(woosuk): A barrier is required because the shared memory space for + // logits is reused for the output. 
- item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(sycl::access::fence_space::local_space); - // Perform reduction across warps. - float* out_smem = reinterpret_cast(shared_mem); + // Perform reduction across warps. + float* out_smem = reinterpret_cast(shared_mem); #pragma unroll - for (int i = NUM_WARPS; i > 1; i /= 2) { - int mid = i / 2; - // Upper warps write to shared memory. - if (warp_idx >= mid && warp_idx < i) { - float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; + for (int i = NUM_WARPS; i > 1; i /= 2) { + int mid = i / 2; + // Upper warps write to shared memory. + if (warp_idx >= mid && warp_idx < i) { + float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; #pragma unroll - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; - if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { - dst[row_idx] = accs[i]; + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = + lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + dst[row_idx] = accs[i]; + } + } } - } - } - - item_ct1.barrier(sycl::access::fence_space::local_space); - // Lower warps update the output. - if (warp_idx < mid) { - const float* src = &out_smem[warp_idx * HEAD_SIZE]; + item_ct1.barrier(sycl::access::fence_space::local_space); + + // Lower warps update the output. + if (warp_idx < mid) { + const float* src = &out_smem[warp_idx * HEAD_SIZE]; #pragma unroll - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; - if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { - accs[i] += src[row_idx]; + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = + lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + accs[i] += src[row_idx]; + } + } } - } + + item_ct1.barrier(sycl::access::fence_space::local_space); } - - item_ct1.barrier(sycl::access::fence_space::local_space); - } - // Write the final output. - if (warp_idx == 0) { - scalar_t* out_ptr = out + - seq_idx * num_heads * max_num_partitions * HEAD_SIZE + - head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE; + // Write the final output. + if (warp_idx == 0) { + scalar_t* out_ptr = + out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + + partition_idx * HEAD_SIZE; #pragma unroll - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; - if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { - vllm::from_float(*(out_ptr + row_idx), accs[i]); - } + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = + lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + vllm::from_float(*(out_ptr + row_idx), accs[i]); + } + } } - } } // Grid: (num_heads, num_seqs, 1). 
-template < - typename scalar_t, - typename Q_Vec_t, - int HEAD_SIZE, - int BLOCK_SIZE, - int NUM_THREADS, - int VEC_SIZE> +template void paged_attention_v1_kernel( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, - // head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, - // head_size, block_size] - const int num_kv_heads, // [num_heads] + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, // [num_heads] const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, - const int kv_block_stride, - const int kv_head_stride, - const sycl::nd_item<3>& item_ct1, - uint8_t* dpct_local, - Q_Vec_t* q_vecs, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const sycl::nd_item<3>& item_ct1, uint8_t* dpct_local, Q_Vec_t* q_vecs, float* red_smem) { - paged_attention_kernel< - scalar_t, - Q_Vec_t, - HEAD_SIZE, - BLOCK_SIZE, - NUM_THREADS, - VEC_SIZE>( - /* exp_sums */ nullptr, - /* max_logits */ nullptr, - out, - q, - k_cache, - v_cache, - num_kv_heads, - scale, - block_tables, - context_lens, - max_num_blocks_per_seq, - alibi_slopes, - q_stride, - kv_block_stride, - kv_head_stride, - item_ct1, - dpct_local, - q_vecs, - red_smem); + paged_attention_kernel( + /* exp_sums */ nullptr, + /* max_logits */ nullptr, out, q, k_cache, v_cache, num_kv_heads, scale, + block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes, + q_stride, kv_block_stride, kv_head_stride, item_ct1, dpct_local, q_vecs, + red_smem); } -#define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ - paged_attention_xpu_v1_impl::call( \ - out_ptr, \ - query_ptr, \ - key_cache_ptr, \ - value_cache_ptr, \ - num_kv_heads, \ - scale, \ - block_tables_ptr, \ - context_lens_ptr, \ - max_num_blocks_per_seq, \ - alibi_slopes_ptr, \ - q_stride, \ - kv_block_stride, \ - kv_head_stride, \ - num_seqs, \ - num_heads, \ - num_blocks); - -#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ - queue.submit([&](sycl::handler& cgh) { \ - sycl::local_accessor dpct_local_acc_ct1( \ - sycl::range<1>(shared_mem_size), cgh); \ - sycl::local_accessor q_vecs_acc_ct1( \ - sycl::range<1>(THREAD_GROUP_SIZE * num_vecs_per_thread), cgh); \ - sycl::local_accessor red_smem_acc_ct1( \ - sycl::range<1>(2 * NUM_WARPS), cgh); \ - \ - auto out_ptr_ct0 = out_ptr; \ - auto query_ptr_ct1 = query_ptr; \ - auto key_cache_ptr_ct2 = key_cache_ptr; \ - auto value_cache_ptr_ct3 = value_cache_ptr; \ - auto scale_ct5 = scale; \ - auto block_tables_ptr_ct6 = block_tables_ptr; \ - auto context_lens_ptr_ct7 = context_lens_ptr; \ - auto max_num_blocks_per_seq_ct8 = max_num_blocks_per_seq; \ - auto alibi_slopes_ptr_ct9 = alibi_slopes_ptr; \ - auto q_stride_ct10 = q_stride; \ - 
auto kv_block_stride_ct11 = kv_block_stride; \ - auto kv_head_stride_ct12 = kv_head_stride; \ - \ - cgh.parallel_for( \ - sycl::nd_range<3>(grid * block, block), \ - [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { \ - paged_attention_v1_kernel< \ - sycl_t, \ - Q_Vec, \ - HEAD_SIZE, \ - BLOCK_SIZE, \ - NUM_THREADS, \ - VEC_SIZE>( \ - out_ptr_ct0, \ - query_ptr_ct1, \ - key_cache_ptr_ct2, \ - value_cache_ptr_ct3, \ - num_kv_heads, \ - scale_ct5, \ - block_tables_ptr_ct6, \ - context_lens_ptr_ct7, \ - max_num_blocks_per_seq_ct8, \ - alibi_slopes_ptr_ct9, \ - q_stride_ct10, \ - kv_block_stride_ct11, \ - kv_head_stride_ct12, \ - item_ct1, \ - dpct_local_acc_ct1.get_pointer(), \ - q_vecs_acc_ct1.get_pointer(), \ - red_smem_acc_ct1.get_pointer()); \ - }); \ - }); +#define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ + paged_attention_xpu_v1_impl::call( \ + out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \ + scale, block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \ + alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \ + num_heads, num_blocks); + +#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ + queue.submit([&](sycl::handler& cgh) { \ + sycl::local_accessor dpct_local_acc_ct1( \ + sycl::range<1>(shared_mem_size), cgh); \ + sycl::local_accessor q_vecs_acc_ct1( \ + sycl::range<1>(THREAD_GROUP_SIZE * num_vecs_per_thread), cgh); \ + sycl::local_accessor red_smem_acc_ct1( \ + sycl::range<1>(2 * NUM_WARPS), cgh); \ + \ + auto out_ptr_ct0 = out_ptr; \ + auto query_ptr_ct1 = query_ptr; \ + auto key_cache_ptr_ct2 = key_cache_ptr; \ + auto value_cache_ptr_ct3 = value_cache_ptr; \ + auto scale_ct5 = scale; \ + auto block_tables_ptr_ct6 = block_tables_ptr; \ + auto context_lens_ptr_ct7 = context_lens_ptr; \ + auto max_num_blocks_per_seq_ct8 = max_num_blocks_per_seq; \ + auto alibi_slopes_ptr_ct9 = alibi_slopes_ptr; \ + auto q_stride_ct10 = q_stride; \ + auto kv_block_stride_ct11 = kv_block_stride; \ + auto kv_head_stride_ct12 = kv_head_stride; \ + \ + cgh.parallel_for( \ + sycl::nd_range<3>(grid * block, block), \ + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size( \ + 32)]] { \ + paged_attention_v1_kernel( \ + out_ptr_ct0, query_ptr_ct1, key_cache_ptr_ct2, \ + value_cache_ptr_ct3, num_kv_heads, scale_ct5, \ + block_tables_ptr_ct6, context_lens_ptr_ct7, \ + max_num_blocks_per_seq_ct8, alibi_slopes_ptr_ct9, \ + q_stride_ct10, kv_block_stride_ct11, kv_head_stride_ct12, \ + item_ct1, dpct_local_acc_ct1.get_pointer(), \ + q_vecs_acc_ct1.get_pointer(), \ + red_smem_acc_ct1.get_pointer()); \ + }); \ + }); template void paged_attention_xpu_v1_impl_launcher( - torch::Tensor& out, - torch::Tensor& query, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - int num_kv_heads, - float scale, - torch::Tensor& block_tables, - torch::Tensor& context_lens, - int max_context_len, - const c10::optional& alibi_slopes) { - int num_seqs = query.size(0); - int num_heads = query.size(1); - int head_size = query.size(2); - int max_num_blocks_per_seq = block_tables.size(1); - int q_stride = query.stride(0); - int kv_block_stride = key_cache.stride(0); - int kv_head_stride = key_cache.stride(1); - - constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); - constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(T)), 1); - using sycl_t = vllm::xpu::SyclTypeTrait::Type; - using Q_Vec = typename Vec::Type; - - int num_vecs_per_thread = head_size / THREAD_GROUP_SIZE / VEC_SIZE; - assert(head_size % THREAD_GROUP_SIZE == 0); - - // 
NOTE: alibi_slopes is optional. - const float* alibi_slopes_ptr = alibi_slopes - ? reinterpret_cast(alibi_slopes.value().data_ptr()) - : nullptr; - - sycl_t* out_ptr = reinterpret_cast(out.data_ptr()); - sycl_t* query_ptr = reinterpret_cast(query.data_ptr()); - sycl_t* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); - sycl_t* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); - int* block_tables_ptr = block_tables.data_ptr(); - int* context_lens_ptr = context_lens.data_ptr(); - - constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - int padded_max_context_len = - DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE; - - int logits_size = padded_max_context_len * sizeof(float); - int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); - // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len - // Keep that in sync with the logic here! - int shared_mem_size = std::max(logits_size, outputs_size); - - sycl::range<3> grid(1, num_seqs, num_heads); - sycl::range<3> block(1, 1, NUM_THREADS); - sycl::queue& queue = vllm::xpu::vllmGetQueue(); - - switch (head_size) { - // NOTE(woosuk): To reduce the compilation time, we only compile for the - // head sizes that we use in the model. However, we can easily extend this - // to support any head size which is a multiple of 16. - case 64: - LAUNCH_PAGED_ATTENTION_V1(64); - break; - case 80: - LAUNCH_PAGED_ATTENTION_V1(80); - break; - case 96: - LAUNCH_PAGED_ATTENTION_V1(96); - break; - case 112: - LAUNCH_PAGED_ATTENTION_V1(112); - break; - case 128: - LAUNCH_PAGED_ATTENTION_V1(128); - break; - case 256: - LAUNCH_PAGED_ATTENTION_V1(256); - break; - default: - TORCH_CHECK(false, "Unsupported head size: ", head_size); - break; - } - // queue.wait(); + torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& context_lens, + int max_context_len, const c10::optional& alibi_slopes) { + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); + constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(T)), 1); + using sycl_t = vllm::xpu::SyclTypeTrait::Type; + using Q_Vec = typename Vec::Type; + + int num_vecs_per_thread = head_size / THREAD_GROUP_SIZE / VEC_SIZE; + assert(head_size % THREAD_GROUP_SIZE == 0); + + // NOTE: alibi_slopes is optional. + const float* alibi_slopes_ptr = + alibi_slopes + ? reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + sycl_t* out_ptr = reinterpret_cast(out.data_ptr()); + sycl_t* query_ptr = reinterpret_cast(query.data_ptr()); + sycl_t* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + sycl_t* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* context_lens_ptr = context_lens.data_ptr(); + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + int padded_max_context_len = + DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE; + + int logits_size = padded_max_context_len * sizeof(float); + int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); + // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len + // Keep that in sync with the logic here! 
+ int shared_mem_size = std::max(logits_size, outputs_size); + + sycl::range<3> grid(1, num_seqs, num_heads); + sycl::range<3> block(1, 1, NUM_THREADS); + sycl::queue& queue = vllm::xpu::vllmGetQueue(); + + switch (head_size) { + // NOTE(woosuk): To reduce the compilation time, we only compile for the + // head sizes that we use in the model. However, we can easily extend + // this to support any head size which is a multiple of 16. + case 64: + LAUNCH_PAGED_ATTENTION_V1(64); + break; + case 80: + LAUNCH_PAGED_ATTENTION_V1(80); + break; + case 96: + LAUNCH_PAGED_ATTENTION_V1(96); + break; + case 112: + LAUNCH_PAGED_ATTENTION_V1(112); + break; + case 128: + LAUNCH_PAGED_ATTENTION_V1(128); + break; + case 256: + LAUNCH_PAGED_ATTENTION_V1(256); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; + } + // queue.wait(); } -#define CALL_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ - vllm::paged_attention_xpu_v1_impl_launcher( \ - out, \ - query, \ - key_cache, \ - value_cache, \ - num_kv_heads, \ - scale, \ - block_tables, \ - context_lens, \ - max_context_len, \ - alibi_slopes); - -#define CALL_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ - switch (block_size) { \ - case 16: \ - CALL_KERNEL_LAUNCHER(T, 16); \ - break; \ - case 32: \ - CALL_KERNEL_LAUNCHER(T, 32); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ - } +#define CALL_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ + vllm::paged_attention_xpu_v1_impl_launcher( \ + out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ + context_lens, max_context_len, alibi_slopes); + +#define CALL_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ + switch (block_size) { \ + case 16: \ + CALL_KERNEL_LAUNCHER(T, 16); \ + break; \ + case 32: \ + CALL_KERNEL_LAUNCHER(T, 32); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } // Grid: (num_heads, num_seqs). -template < - typename scalar_t, - int HEAD_SIZE, - int NUM_THREADS, - int PARTITION_SIZE> +template void paged_attention_v2_reduce_kernel( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - const float* __restrict__ exp_sums, // [num_seqs, num_heads, - // max_num_partitions] - const float* __restrict__ max_logits, // [num_seqs, num_heads, - // max_num_partitions] - const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, - // max_num_partitions, head_size] - const int* __restrict__ context_lens, // [num_seqs] - const int max_num_partitions, - const sycl::nd_item<3>& item_ct1, - uint8_t* dpct_local, - float* red_smem) { - const int num_heads = item_ct1.get_group_range(2); - const int head_idx = item_ct1.get_group(2); - const int seq_idx = item_ct1.get_group(1); - const int context_len = context_lens[seq_idx]; - const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); - if (num_partitions == 1) { - // No need to reduce. Only copy tmp_out to out. 
- scalar_t* out_ptr = - out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; - const scalar_t* tmp_out_ptr = tmp_out + - seq_idx * num_heads * max_num_partitions * HEAD_SIZE + - head_idx * max_num_partitions * HEAD_SIZE; - for (int i = item_ct1.get_local_id(2); i < HEAD_SIZE; + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const float* __restrict__ exp_sums, // [num_seqs, num_heads, + // max_num_partitions] + const float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_partitions, const sycl::nd_item<3>& item_ct1, + uint8_t* dpct_local, float* red_smem) { + const int num_heads = item_ct1.get_group_range(2); + const int head_idx = item_ct1.get_group(2); + const int seq_idx = item_ct1.get_group(1); + const int context_len = context_lens[seq_idx]; + const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); + if (num_partitions == 1) { + // No need to reduce. Only copy tmp_out to out. + scalar_t* out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + const scalar_t* tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; + for (int i = item_ct1.get_local_id(2); i < HEAD_SIZE; + i += item_ct1.get_local_range(2)) { + out_ptr[i] = tmp_out_ptr[i]; + } + // Terminate the thread block. + return; + } + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int warp_idx = item_ct1.get_local_id(2) / WARP_SIZE; + const int lane = item_ct1.get_local_id(2) % WARP_SIZE; + + // Size: 2 * num_partitions. + auto shared_mem = (char*)dpct_local; + // Workspace for reduction. + + // Load max logits to shared memory. + float* shared_max_logits = reinterpret_cast(shared_mem); + const float* max_logits_ptr = max_logits + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + float max_logit = -FLT_MAX; + for (int i = item_ct1.get_local_id(2); i < num_partitions; i += item_ct1.get_local_range(2)) { - out_ptr[i] = tmp_out_ptr[i]; + const float l = max_logits_ptr[i]; + shared_max_logits[i] = l; + max_logit = sycl::fmax(max_logit, (float)l); } - // Terminate the thread block. - return; - } - - constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - const int warp_idx = item_ct1.get_local_id(2) / WARP_SIZE; - const int lane = item_ct1.get_local_id(2) % WARP_SIZE; - - // Size: 2 * num_partitions. - auto shared_mem = (char*)dpct_local; - // Workspace for reduction. - - // Load max logits to shared memory. - float* shared_max_logits = reinterpret_cast(shared_mem); - const float* max_logits_ptr = max_logits + - seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions; - float max_logit = -FLT_MAX; - for (int i = item_ct1.get_local_id(2); i < num_partitions; - i += item_ct1.get_local_range(2)) { - const float l = max_logits_ptr[i]; - shared_max_logits[i] = l; - max_logit = sycl::fmax(max_logit, (float)l); - } - - item_ct1.barrier(sycl::access::fence_space::local_space); - - // Get the global max logit. - // Reduce within the warp. + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // Get the global max logit. + // Reduce within the warp. 
#pragma unroll - for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { - - /* - DPCT1096:45: The right-most dimension of the work-group used in the SYCL - kernel that calls this function may be less than "32". The function - "dpct::permute_sub_group_by_xor" may return an unexpected result on the CPU - device. Modify the size of the work-group to ensure that the value of the - right-most dimension is a multiple of "32". - */ - max_logit = sycl::fmax( - max_logit, - dpct::permute_sub_group_by_xor( - item_ct1.get_sub_group(), max_logit, mask)); - } - if (lane == 0) { - red_smem[warp_idx] = max_logit; - } - - item_ct1.barrier(sycl::access::fence_space::local_space); - // Reduce across warps. - max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + /* + DPCT1096:45: The right-most dimension of the work-group used in the SYCL + kernel that calls this function may be less than "32". The function + "dpct::permute_sub_group_by_xor" may return an unexpected result on the + CPU device. Modify the size of the work-group to ensure that the value + of the right-most dimension is a multiple of "32". + */ + max_logit = sycl::fmax( + max_logit, dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), + max_logit, mask)); + } + if (lane == 0) { + red_smem[warp_idx] = max_logit; + } + + item_ct1.barrier(sycl::access::fence_space::local_space); + // Reduce across warps. + max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; #pragma unroll - for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { - + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + /* + DPCT1096:46: The right-most dimension of the work-group used in the SYCL + kernel that calls this function may be less than "32". The function + "dpct::permute_sub_group_by_xor" may return an unexpected result on the + CPU device. Modify the size of the work-group to ensure that the value + of the right-most dimension is a multiple of "32". + */ + max_logit = sycl::fmax( + max_logit, dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), + max_logit, mask)); + } + // Broadcast the max value to all threads. + /* - DPCT1096:46: The right-most dimension of the work-group used in the SYCL + DPCT1096:47: The right-most dimension of the work-group used in the SYCL kernel that calls this function may be less than "32". The function - "dpct::permute_sub_group_by_xor" may return an unexpected result on the CPU + "dpct::select_from_sub_group" may return an unexpected result on the CPU device. Modify the size of the work-group to ensure that the value of the right-most dimension is a multiple of "32". */ - max_logit = sycl::fmax( - max_logit, - dpct::permute_sub_group_by_xor( - item_ct1.get_sub_group(), max_logit, mask)); - } - // Broadcast the max value to all threads. - - /* - DPCT1096:47: The right-most dimension of the work-group used in the SYCL - kernel that calls this function may be less than "32". The function - "dpct::select_from_sub_group" may return an unexpected result on the CPU - device. Modify the size of the work-group to ensure that the value of the - right-most dimension is a multiple of "32". - */ - max_logit = dpct::select_from_sub_group( - item_ct1.get_sub_group(), max_logit, 0); - - // Load rescaled exp sums to shared memory. 
- float* shared_exp_sums = - reinterpret_cast(shared_mem + sizeof(float) * num_partitions); - const float* exp_sums_ptr = exp_sums + - seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions; - float global_exp_sum = 0.0f; - for (int i = item_ct1.get_local_id(2); i < num_partitions; - i += item_ct1.get_local_range(2)) { - float l = shared_max_logits[i]; - float rescaled_exp_sum = exp_sums_ptr[i] * sycl::exp(l - max_logit); - global_exp_sum += rescaled_exp_sum; - shared_exp_sums[i] = rescaled_exp_sum; - } - - item_ct1.barrier(sycl::access::fence_space::local_space); - global_exp_sum = - block_sum(&red_smem[NUM_WARPS], global_exp_sum, item_ct1); - const float inv_global_exp_sum = 1.0f / (global_exp_sum + 1e-6f); - - // Aggregate tmp_out to out. - const scalar_t* tmp_out_ptr = tmp_out + - seq_idx * num_heads * max_num_partitions * HEAD_SIZE + - head_idx * max_num_partitions * HEAD_SIZE; - scalar_t* out_ptr = - out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + max_logit = + dpct::select_from_sub_group(item_ct1.get_sub_group(), max_logit, 0); + + // Load rescaled exp sums to shared memory. + float* shared_exp_sums = + reinterpret_cast(shared_mem + sizeof(float) * num_partitions); + const float* exp_sums_ptr = exp_sums + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + float global_exp_sum = 0.0f; + for (int i = item_ct1.get_local_id(2); i < num_partitions; + i += item_ct1.get_local_range(2)) { + float l = shared_max_logits[i]; + float rescaled_exp_sum = exp_sums_ptr[i] * sycl::exp(l - max_logit); + global_exp_sum += rescaled_exp_sum; + shared_exp_sums[i] = rescaled_exp_sum; + } + + item_ct1.barrier(sycl::access::fence_space::local_space); + global_exp_sum = + block_sum(&red_smem[NUM_WARPS], global_exp_sum, item_ct1); + const float inv_global_exp_sum = 1.0f / (global_exp_sum + 1e-6f); + + // Aggregate tmp_out to out. + const scalar_t* tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; + scalar_t* out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; #pragma unroll - for (int i = item_ct1.get_local_id(2); i < HEAD_SIZE; i += NUM_THREADS) { - float acc = 0.0f; - for (int j = 0; j < num_partitions; ++j) { - acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * - inv_global_exp_sum; + for (int i = item_ct1.get_local_id(2); i < HEAD_SIZE; i += NUM_THREADS) { + float acc = 0.0f; + for (int j = 0; j < num_partitions; ++j) { + acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * + shared_exp_sums[j] * inv_global_exp_sum; + } + from_float(out_ptr[i], acc); } - from_float(out_ptr[i], acc); - } } // Grid: (num_heads, num_seqs, max_num_partitions). 
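+// The split-K ("V2") scheme below relies on the reduce kernel above: each
+// partition p writes a partial output out_p to tmp_out (normalized by its own
+// exponential sum, as in the upstream vLLM V2 kernel) together with its
+// running maximum m_p (max_logits) and exponential sum s_p (exp_sums).
+// With m = max_p m_p, the exact attention output is recovered as
+// (scalar sketch; the symbols are illustrative, not names from this file):
+//
+//   out = sum_p out_p * s_p * exp(m_p - m)  /  sum_p s_p * exp(m_p - m)
+//
+// which is what the reduce kernel evaluates through
+// shared_exp_sums[p] = s_p * exp(m_p - m) and inv_global_exp_sum.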
-template < - typename scalar_t, - typename Q_Vec_t, - int HEAD_SIZE, - int BLOCK_SIZE, - int NUM_THREADS, - int VEC_SIZE, - int PARTITION_SIZE> +template void paged_attention_v2_kernel( - float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, - // head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, - // head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, - // head_size, block_size] - const int num_kv_heads, // [num_heads] + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, // [num_heads] const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, - const int kv_block_stride, - const int kv_head_stride, - const sycl::nd_item<3>& item_ct1, - uint8_t* dpct_local, - Q_Vec_t* q_vecs, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const sycl::nd_item<3>& item_ct1, uint8_t* dpct_local, Q_Vec_t* q_vecs, float* red_smem) { - paged_attention_kernel< - scalar_t, - Q_Vec_t, - HEAD_SIZE, - BLOCK_SIZE, - NUM_THREADS, - VEC_SIZE, - PARTITION_SIZE>( - exp_sums, - max_logits, - tmp_out, - q, - k_cache, - v_cache, - num_kv_heads, - scale, - block_tables, - context_lens, - max_num_blocks_per_seq, - alibi_slopes, - q_stride, - kv_block_stride, - kv_head_stride, - item_ct1, - dpct_local, - q_vecs, - red_smem); + paged_attention_kernel( + exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, + block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes, + q_stride, kv_block_stride, kv_head_stride, item_ct1, dpct_local, q_vecs, + red_smem); } -#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ - queue.submit([&](sycl::handler& cgh) { \ - sycl::local_accessor dpct_local_acc_ct1( \ - sycl::range<1>(shared_mem_size), cgh); \ - sycl::local_accessor q_vecs_acc_ct1( \ - sycl::range<1>(THREAD_GROUP_SIZE * num_vecs_per_thread), cgh); \ - sycl::local_accessor red_smem_acc_ct1( \ - sycl::range<1>(2 * NUM_WARPS), cgh); \ - \ - auto exp_sums_ptr_ct0 = exp_sums_ptr; \ - auto max_logits_ptr_ct1 = max_logits_ptr; \ - auto tmp_out_ptr_ct2 = tmp_out_ptr; \ - auto query_ptr_ct3 = query_ptr; \ - auto key_cache_ptr_ct4 = key_cache_ptr; \ - auto value_cache_ptr_ct5 = value_cache_ptr; \ - auto scale_ct7 = scale; \ - auto block_tables_ptr_ct8 = block_tables_ptr; \ - auto context_lens_ptr_ct9 = context_lens_ptr; \ - auto max_num_blocks_per_seq_ct10 = max_num_blocks_per_seq; \ - auto 
alibi_slopes_ptr_ct11 = alibi_slopes_ptr; \ - auto q_stride_ct12 = q_stride; \ - auto kv_block_stride_ct13 = kv_block_stride; \ - auto kv_head_stride_ct14 = kv_head_stride; \ - \ - cgh.parallel_for( \ - sycl::nd_range<3>(grid * block, block), \ - [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { \ - vllm::paged_attention_v2_kernel< \ - sycl_t, \ - Q_Vec, \ - HEAD_SIZE, \ - BLOCK_SIZE, \ - NUM_THREADS, \ - VEC_SIZE, \ - PARTITION_SIZE>( \ - exp_sums_ptr_ct0, \ - max_logits_ptr_ct1, \ - tmp_out_ptr_ct2, \ - query_ptr_ct3, \ - key_cache_ptr_ct4, \ - value_cache_ptr_ct5, \ - num_kv_heads, \ - scale_ct7, \ - block_tables_ptr_ct8, \ - context_lens_ptr_ct9, \ - max_num_blocks_per_seq_ct10, \ - alibi_slopes_ptr_ct11, \ - q_stride_ct12, \ - kv_block_stride_ct13, \ - kv_head_stride_ct14, \ - item_ct1, \ - dpct_local_acc_ct1.get_pointer(), \ - q_vecs_acc_ct1.get_pointer(), \ - red_smem_acc_ct1.get_pointer()); \ - }); \ - }); \ - queue.submit([&](sycl::handler& cgh) { \ - sycl::local_accessor dpct_local_acc_ct1( \ - sycl::range<1>(reduce_shared_mem_size), cgh); \ - sycl::local_accessor red_smem_acc_ct1( \ - sycl::range<1>(2 * NUM_WARPS), cgh); \ - \ - auto out_ptr_ct0 = out_ptr; \ - auto exp_sums_ptr_ct1 = exp_sums_ptr; \ - auto max_logits_ptr_ct2 = max_logits_ptr; \ - auto tmp_out_ptr_ct3 = tmp_out_ptr; \ - auto context_lens_ptr_ct4 = context_lens_ptr; \ - auto max_num_partitions_ct5 = max_num_partitions; \ - \ - cgh.parallel_for( \ - sycl::nd_range<3>(reduce_grid * block, block), \ - [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { \ - vllm::paged_attention_v2_reduce_kernel< \ - sycl_t, \ - HEAD_SIZE, \ - NUM_THREADS, \ - PARTITION_SIZE>( \ - out_ptr_ct0, \ - exp_sums_ptr_ct1, \ - max_logits_ptr_ct2, \ - tmp_out_ptr_ct3, \ - context_lens_ptr_ct4, \ - max_num_partitions_ct5, \ - item_ct1, \ - dpct_local_acc_ct1.get_pointer(), \ - red_smem_acc_ct1.get_pointer()); \ - }); \ - }); - -template < - typename T, - int BLOCK_SIZE, - int NUM_THREADS = 512, - int PARTITION_SIZE = 512> +#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ + queue.submit([&](sycl::handler& cgh) { \ + sycl::local_accessor dpct_local_acc_ct1( \ + sycl::range<1>(shared_mem_size), cgh); \ + sycl::local_accessor q_vecs_acc_ct1( \ + sycl::range<1>(THREAD_GROUP_SIZE * num_vecs_per_thread), cgh); \ + sycl::local_accessor red_smem_acc_ct1( \ + sycl::range<1>(2 * NUM_WARPS), cgh); \ + \ + auto exp_sums_ptr_ct0 = exp_sums_ptr; \ + auto max_logits_ptr_ct1 = max_logits_ptr; \ + auto tmp_out_ptr_ct2 = tmp_out_ptr; \ + auto query_ptr_ct3 = query_ptr; \ + auto key_cache_ptr_ct4 = key_cache_ptr; \ + auto value_cache_ptr_ct5 = value_cache_ptr; \ + auto scale_ct7 = scale; \ + auto block_tables_ptr_ct8 = block_tables_ptr; \ + auto context_lens_ptr_ct9 = context_lens_ptr; \ + auto max_num_blocks_per_seq_ct10 = max_num_blocks_per_seq; \ + auto alibi_slopes_ptr_ct11 = alibi_slopes_ptr; \ + auto q_stride_ct12 = q_stride; \ + auto kv_block_stride_ct13 = kv_block_stride; \ + auto kv_head_stride_ct14 = kv_head_stride; \ + \ + cgh.parallel_for( \ + sycl::nd_range<3>(grid * block, block), \ + [=](sycl::nd_item<3> item_ct1) \ + [[intel::reqd_sub_group_size(32)]] { \ + vllm::paged_attention_v2_kernel( \ + exp_sums_ptr_ct0, max_logits_ptr_ct1, tmp_out_ptr_ct2, \ + query_ptr_ct3, key_cache_ptr_ct4, value_cache_ptr_ct5, \ + num_kv_heads, scale_ct7, block_tables_ptr_ct8, \ + context_lens_ptr_ct9, max_num_blocks_per_seq_ct10, \ + alibi_slopes_ptr_ct11, q_stride_ct12, \ + kv_block_stride_ct13, kv_head_stride_ct14, item_ct1, \ + 
dpct_local_acc_ct1.get_pointer(), \ + q_vecs_acc_ct1.get_pointer(), \ + red_smem_acc_ct1.get_pointer()); \ + }); \ + }); \ + queue.submit([&](sycl::handler& cgh) { \ + sycl::local_accessor dpct_local_acc_ct1( \ + sycl::range<1>(reduce_shared_mem_size), cgh); \ + sycl::local_accessor red_smem_acc_ct1( \ + sycl::range<1>(2 * NUM_WARPS), cgh); \ + \ + auto out_ptr_ct0 = out_ptr; \ + auto exp_sums_ptr_ct1 = exp_sums_ptr; \ + auto max_logits_ptr_ct2 = max_logits_ptr; \ + auto tmp_out_ptr_ct3 = tmp_out_ptr; \ + auto context_lens_ptr_ct4 = context_lens_ptr; \ + auto max_num_partitions_ct5 = max_num_partitions; \ + \ + cgh.parallel_for( \ + sycl::nd_range<3>(reduce_grid * block, block), \ + [=](sycl::nd_item<3> item_ct1) \ + [[intel::reqd_sub_group_size(32)]] { \ + vllm::paged_attention_v2_reduce_kernel< \ + sycl_t, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE>( \ + out_ptr_ct0, exp_sums_ptr_ct1, max_logits_ptr_ct2, \ + tmp_out_ptr_ct3, context_lens_ptr_ct4, \ + max_num_partitions_ct5, item_ct1, \ + dpct_local_acc_ct1.get_pointer(), \ + red_smem_acc_ct1.get_pointer()); \ + }); \ + }); + +template void paged_attention_v2_launcher( - torch::Tensor& out, - torch::Tensor& exp_sums, - torch::Tensor& max_logits, - torch::Tensor& tmp_out, - torch::Tensor& query, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - int num_kv_heads, - float scale, - torch::Tensor& block_tables, - torch::Tensor& context_lens, - int max_context_len, - const c10::optional& alibi_slopes) { - int num_seqs = query.size(0); - int num_heads = query.size(1); - int head_size = query.size(2); - int max_num_blocks_per_seq = block_tables.size(1); - int q_stride = query.stride(0); - int kv_block_stride = key_cache.stride(0); - int kv_head_stride = key_cache.stride(1); - - constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); - assert(head_size % THREAD_GROUP_SIZE == 0); - constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(T)), 1); - using sycl_t = vllm::xpu::SyclTypeTrait::Type; - using Q_Vec = typename Vec::Type; - - int num_vecs_per_thread = head_size / THREAD_GROUP_SIZE / VEC_SIZE; - assert(head_size % THREAD_GROUP_SIZE == 0); - - // NOTE: alibi_slopes is optional. - const float* alibi_slopes_ptr = alibi_slopes - ? reinterpret_cast(alibi_slopes.value().data_ptr()) - : nullptr; - - sycl_t* out_ptr = reinterpret_cast(out.data_ptr()); - float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); - float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); - sycl_t* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); - sycl_t* query_ptr = reinterpret_cast(query.data_ptr()); - sycl_t* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); - sycl_t* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); - int* block_tables_ptr = block_tables.data_ptr(); - int* context_lens_ptr = context_lens.data_ptr(); - - constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); - - int logits_size = PARTITION_SIZE * sizeof(float); - int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); - - // For paged attention v2 kernel. - sycl::range<3> grid(max_num_partitions, num_seqs, num_heads); - int shared_mem_size = std::max(logits_size, outputs_size); - // For paged attention v2 reduce kernel. 
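+    // Worked example (illustrative numbers): with the template defaults
+    // NUM_THREADS = 512 and PARTITION_SIZE = 512, NUM_WARPS = 512 / 32 = 16;
+    // a max_context_len of 2000 gives max_num_partitions = ceil(2000 / 512)
+    // = 4, logits_size = 512 * 4 = 2048 bytes, and for head_size = 128
+    // outputs_size = (16 / 2) * 128 * 4 = 4096 bytes, so shared_mem_size is
+    // 4096 bytes and reduce_shared_mem_size is 2 * 4 * 4 = 32 bytes.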
- sycl::range<3> reduce_grid(1, num_seqs, num_heads); - - int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); - - sycl::range<3> block(1, 1, NUM_THREADS); - sycl::queue& queue = vllm::xpu::vllmGetQueue(); - switch (head_size) { - // NOTE(woosuk): To reduce the compilation time, we only compile for the - // head sizes that we use in the model. However, we can easily extend this - // to support any head size which is a multiple of 16. - case 64: - LAUNCH_PAGED_ATTENTION_V2(64); - break; - case 80: - LAUNCH_PAGED_ATTENTION_V2(80); - break; - case 96: - LAUNCH_PAGED_ATTENTION_V2(96); - break; - case 112: - LAUNCH_PAGED_ATTENTION_V2(112); - break; - case 128: - LAUNCH_PAGED_ATTENTION_V2(128); - break; - case 256: - LAUNCH_PAGED_ATTENTION_V2(256); - break; - default: - TORCH_CHECK(false, "Unsupported head size: ", head_size); - break; - } + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& context_lens, + int max_context_len, const c10::optional& alibi_slopes) { + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); + assert(head_size % THREAD_GROUP_SIZE == 0); + constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(T)), 1); + using sycl_t = vllm::xpu::SyclTypeTrait::Type; + using Q_Vec = typename Vec::Type; + + int num_vecs_per_thread = head_size / THREAD_GROUP_SIZE / VEC_SIZE; + assert(head_size % THREAD_GROUP_SIZE == 0); + + // NOTE: alibi_slopes is optional. + const float* alibi_slopes_ptr = + alibi_slopes + ? reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + sycl_t* out_ptr = reinterpret_cast(out.data_ptr()); + float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); + float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); + sycl_t* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); + sycl_t* query_ptr = reinterpret_cast(query.data_ptr()); + sycl_t* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + sycl_t* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* context_lens_ptr = context_lens.data_ptr(); + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); + + int logits_size = PARTITION_SIZE * sizeof(float); + int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); + + // For paged attention v2 kernel. + sycl::range<3> grid(max_num_partitions, num_seqs, num_heads); + int shared_mem_size = std::max(logits_size, outputs_size); + // For paged attention v2 reduce kernel. + sycl::range<3> reduce_grid(1, num_seqs, num_heads); + + int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); + + sycl::range<3> block(1, 1, NUM_THREADS); + sycl::queue& queue = vllm::xpu::vllmGetQueue(); + switch (head_size) { + // NOTE(woosuk): To reduce the compilation time, we only compile for the + // head sizes that we use in the model. However, we can easily extend + // this to support any head size which is a multiple of 16. 
+ case 64: + LAUNCH_PAGED_ATTENTION_V2(64); + break; + case 80: + LAUNCH_PAGED_ATTENTION_V2(80); + break; + case 96: + LAUNCH_PAGED_ATTENTION_V2(96); + break; + case 112: + LAUNCH_PAGED_ATTENTION_V2(112); + break; + case 128: + LAUNCH_PAGED_ATTENTION_V2(128); + break; + case 256: + LAUNCH_PAGED_ATTENTION_V2(256); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; + } } -#define CALL_V2_LAUNCHER(T, BLOCK_SIZE) \ - vllm::paged_attention_v2_launcher( \ - out, \ - exp_sums, \ - max_logits, \ - tmp_out, \ - query, \ - key_cache, \ - value_cache, \ - num_kv_heads, \ - scale, \ - block_tables, \ - context_lens, \ - max_context_len, \ - alibi_slopes); - -#define CALL_V2_LAUNCHER_BLOCK_SIZE(T) \ - switch (block_size) { \ - case 8: \ - CALL_V2_LAUNCHER(T, 8); \ - break; \ - case 16: \ - CALL_V2_LAUNCHER(T, 16); \ - break; \ - case 32: \ - CALL_V2_LAUNCHER(T, 32); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ - } - -} // namespace vllm +#define CALL_V2_LAUNCHER(T, BLOCK_SIZE) \ + vllm::paged_attention_v2_launcher( \ + out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ + num_kv_heads, scale, block_tables, context_lens, max_context_len, \ + alibi_slopes); + +#define CALL_V2_LAUNCHER_BLOCK_SIZE(T) \ + switch (block_size) { \ + case 8: \ + CALL_V2_LAUNCHER(T, 8); \ + break; \ + case 16: \ + CALL_V2_LAUNCHER(T, 16); \ + break; \ + case 32: \ + CALL_V2_LAUNCHER(T, 32); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } + +} // namespace vllm void paged_attention_v1( - torch::Tensor& out, - torch::Tensor& query, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - int num_kv_heads, - float scale, - torch::Tensor& block_tables, - torch::Tensor& context_lens, - int block_size, - int max_context_len, - const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, - const float kv_scale) { - VLLM_XPU_DISPATCH_FLOATING_TYPES_FLOAT_ONLY( - query.scalar_type(), "paged_attention_xpu_v1_impl", [&] { - CALL_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t); - }); + torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& context_lens, int block_size, + int max_context_len, const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, const float kv_scale) { + VLLM_XPU_DISPATCH_FLOATING_TYPES_FLOAT_ONLY( + query.scalar_type(), "paged_attention_xpu_v1_impl", + [&] { CALL_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t); }); } void paged_attention_v2( - torch::Tensor& out, - torch::Tensor& exp_sums, - torch::Tensor& max_logits, - torch::Tensor& tmp_out, - torch::Tensor& query, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - int num_kv_heads, - float scale, - torch::Tensor& block_tables, - torch::Tensor& context_lens, - int block_size, - int max_context_len, - const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, - const float kv_scale) { - VLLM_XPU_DISPATCH_FLOATING_TYPES_FLOAT_ONLY( - query.scalar_type(), "paged_attention_xpu_v2_impl", [&] { - CALL_V2_LAUNCHER_BLOCK_SIZE(scalar_t); - }); + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& context_lens, int block_size, + int max_context_len, const c10::optional& 
alibi_slopes,
+    const std::string& kv_cache_dtype, const float kv_scale) {
+    VLLM_XPU_DISPATCH_FLOATING_TYPES_FLOAT_ONLY(
+        query.scalar_type(), "paged_attention_xpu_v2_impl",
+        [&] { CALL_V2_LAUNCHER_BLOCK_SIZE(scalar_t); });
+}
+
+template 
+auto dispatch_context_attention(at::ScalarType it) {
+    switch (it) {
+        case at::ScalarType::Float:
+            throw std::runtime_error("Unsupported dtype float32");
+        case at::ScalarType::Half:
+            return vllm::context_attention_kernel_kv;
+        default:
+            throw std::runtime_error(
+                "unsupported dtype: only fp16 is supported");
+    }
+}
+
+torch::Tensor context_attention_forward(
+    torch::Tensor query,  // [num_tokens, num_heads, head_dim]
+    torch::Tensor key,    // [num_tokens, num_kv_heads, head_dim]
+    torch::Tensor value,  // [num_tokens, num_kv_heads, head_dim]
+    torch::Tensor query_start_loc, torch::Tensor seq_lens,
+    int max_input_length) {
+    // TODO: dispatch on query.scalar_type() if more dtypes are needed.
+    int64_t num_tokens = query.size(0);
+    int64_t num_heads = query.size(1);
+    int64_t head_dim = query.size(2);
+    int64_t batch_size = seq_lens.size(0);
+    int num_kv_heads = value.size(1);
+
+    int key_dimension = key.dim();
+    auto output = at::empty({query.size(0), query.size(1), query.size(2)},
+                            at::device(query.device()).dtype(query.dtype()));
+
+    // key is expected to have shape [num_tokens, num_kv_heads, head_dim].
+    assert(key_dimension == 3);
+    assert(query.scalar_type() == key.scalar_type() &&
+           query.scalar_type() == value.scalar_type());
+    assert(head_dim == 128 || head_dim == 64);
+    assert(query.scalar_type() == at::ScalarType::Half);
+
+    int query_stride_token = query.stride(0);
+    int query_stride_head = query.stride(1);
+    int query_stride_dim = query.stride(2);
+    const float attn_scale = 1 / std::sqrt((float)head_dim);
+
+    assert(num_heads % num_kv_heads == 0);
+    int num_queries_per_kv = num_heads / num_kv_heads;
+
+    assert(num_tokens == key.size(0) && key.size(0) == value.size(0));
+    int key_stride_token = key.stride(0);
+    int key_stride_head = key.stride(1);
+    // Likely identical to the query/value strides.
+    int key_stride_dim = key.stride(2);
+    int value_stride_bs = value.stride(0);
+    int value_stride_head = value.stride(1);
+    int value_stride_dim = value.stride(2);
+
+    // if (head_dim == 128) {
+    //     vllm::context_attention_kernel_kv(
+    //         query.data_ptr(), key.data_ptr(), value.data_ptr(),
+    //         seq_lens.data_ptr(), attn_scale, output.data_ptr(),
+    //         query_start_loc.data_ptr(), batch_size, num_heads,
+    //         max_input_length, num_queries_per_kv, query.stride(0),
+    //         query.stride(1), key.stride(0), key.stride(1), value.stride(0),
+    //         value.stride(1), output.stride(0), output.stride(1));
+    // } else if (head_dim == 64) {
+    //     vllm::context_attention_kernel_kv(
+    //         query.data_ptr(), key.data_ptr(), value.data_ptr(),
+    //         seq_lens.data_ptr(), attn_scale, output.data_ptr(),
+    //         query_start_loc.data_ptr(), batch_size, num_heads,
+    //         max_input_length, num_queries_per_kv, query.stride(0),
+    //         query.stride(1), key.stride(0), key.stride(1), value.stride(0),
+    //         value.stride(1), output.stride(0), output.stride(1));
+    // } else {
+    //     throw std::runtime_error(
+    //         "unsupported head_dim: only 128 and 64 are supported");
+    // }
+
+    auto func = [&]() {
+        switch (head_dim) {
+            case 128:
+                return vllm::context_attention_kernel_kv;
+            case 64:
+                return vllm::context_attention_kernel_kv;
+            default:
+                throw std::runtime_error(
+                    "unsupported head_dim: only 128 and 64 are supported");
+        }
+    }();
+    func(query.data_ptr(), key.data_ptr(), value.data_ptr(),
+        seq_lens.data_ptr(), 
attn_scale, output.data_ptr(),
+        query_start_loc.data_ptr(), batch_size, num_heads, max_input_length,
+        num_queries_per_kv, query.stride(0), query.stride(1), key.stride(0),
+        key.stride(1), value.stride(0), value.stride(1), output.stride(0),
+        output.stride(1));
+    return output;
+}
\ No newline at end of file
diff --git a/csrc/xpu/pybind.cpp b/csrc/xpu/pybind.cpp
index 4e7f2fa6bd8..62dfecf289b 100644
--- a/csrc/xpu/pybind.cpp
+++ b/csrc/xpu/pybind.cpp
@@ -75,4 +75,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
       "awq_dequantize",
       &awq_dequantize,
       "dequant method for awq");
+
+  ops.def("context_attention_forward", &context_attention_forward,
+          "Context attention forward");
 }
diff --git a/csrc/xpu/xpu_ops.h b/csrc/xpu/xpu_ops.h
index 6125b19ac80..891521522c6 100644
--- a/csrc/xpu/xpu_ops.h
+++ b/csrc/xpu/xpu_ops.h
@@ -50,6 +50,13 @@ void reshape_and_cache(torch::Tensor &key, torch::Tensor &value,
                        torch::Tensor &slot_mapping,
                        const std::string& kv_cache_dtype,
                        const float kv_scale);
+torch::Tensor context_attention_forward(
+    torch::Tensor query,  // [num_tokens, num_heads, head_dim]
+    torch::Tensor key,    // [num_tokens, num_kv_heads, head_dim]
+    torch::Tensor value,  // [num_tokens, num_kv_heads, head_dim]
+    torch::Tensor query_start_loc, torch::Tensor seq_lens,
+    int max_input_length);
+
 void moe_align_block_size(
   torch::Tensor topk_ids,
   int num_experts,
diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py
index 23b3e6ec8c0..bbf6e79ac5b 100644
--- a/vllm/attention/backends/ipex_attn.py
+++ b/vllm/attention/backends/ipex_attn.py
@@ -67,6 +67,7 @@ class IpexAttnMetadata(AttentionMetadata, PagedAttentionMetadata):
     seq_lens: Optional[List[int]]
     seqlen_q: Optional[torch.Tensor]
     max_seqlen: Optional[int]
+    query_start_loc: Optional[torch.Tensor]
 
     def __post_init__(self):
         # Set during the execution of the first attention op.
@@ -127,6 +128,16 @@ def use_sdp_causal(head_dim, query_states):
     )
 
 
+def use_vllm_sdp(head_dim, query_states):
+    if head_dim in [-1, 80, 96]:
+        print("WARNING: encountered a head_dim without a supported SDP kernel")
+    return (
+        head_dim in [64, 128]  # only these head sizes are supported for now
+        and query_states.device.type == "xpu"  # XPU device
+        and query_states.dtype in [torch.half]  # fp16 only
+    )
+
+
 class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
 
     def __init__(
@@ -196,6 +207,7 @@ def split_kv_cache(
         value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
         return key_cache, value_cache
 
+
     def forward(
         self,
         query: torch.Tensor,
@@ -226,9 +238,9 @@ def forward(
                 "IpexAttnBackendImpl")
         num_tokens, hidden_size = query.shape
         # Reshape the query, key, and value tensors.
- query = query.view(-1, self.num_heads, self.head_size) - key = key.view(-1, self.num_kv_heads, self.head_size) - value = value.view(-1, self.num_kv_heads, self.head_size) + query = query.view(-1, self.num_heads, self.head_size).contiguous() + key = key.view(-1, self.num_kv_heads, self.head_size).contiguous() + value = value.view(-1, self.num_kv_heads, self.head_size).contiguous() if kv_cache is not None: key_cache, value_cache = self.split_kv_cache( @@ -252,71 +264,54 @@ def forward( value = value.repeat_interleave(self.num_queries_per_kv, dim=1) - if attn_metadata.attn_bias is None: - if self.alibi_slopes is not None: - att_masks = _make_alibi_bias( - self.alibi_slopes, query.dtype, - attn_metadata.seq_lens) # type: ignore - elif self.sliding_window is not None: - att_masks = _make_sliding_window_bias( - attn_metadata.seq_lens, self.sliding_window, - query.dtype) # type: ignore - else: - att_masks = [None] * len(attn_metadata.seq_lens) - attn_metadata.attn_bias = att_masks - - # output = torch.empty( - # (num_tokens, self.num_heads, self.head_size), - # dtype=query.dtype, - # device=query.device) - # ipex_ops.varlen_attention(query, - # key, - # value, - # output, - # attn_metadata.seqlen_q, - # attn_metadata.seqlen_q, - # attn_metadata.max_seqlen, - # attn_metadata.max_seqlen, - # pdropout=0.0, - # softmax_scale=self.scale, - # zero_tensors=False, - # is_causal=True, - # return_softmax=False, - # gen_=None) - - output = torch.empty( + if use_vllm_sdp(self.head_size, query): + import vllm._C.ops + output = vllm._C.ops.context_attention_forward(query, key, value, attn_metadata.query_start_loc, attn_metadata.seq_lens_tensor, max(attn_metadata.seq_lens)) + else: + query = query.movedim(0, query.dim() - 2) + key = key.movedim(0, key.dim() - 2) + value = value.movedim(0, value.dim() - 2) + output = torch.empty( (num_tokens, self.num_heads, self.head_size), dtype=query.dtype, device=query.device) - query = query.movedim(0, query.dim() - 2) - key = key.movedim(0, key.dim() - 2) - value = value.movedim(0, value.dim() - 2) - - start = 0 - for seq_len, mask in zip(attn_metadata.seq_lens, - attn_metadata.attn_bias): - end = start + seq_len - if use_sdp_causal(self.head_size, query): - import xe_addons - if mask is not None: - mask = mask.unsqueeze(0) - sub_out = xe_addons.sdp_causal( - query[None, :, start:end, :].contiguous(), - key[None, :, start:end, :].contiguous(), - value[None, :, start:end, :].contiguous(), - mask).squeeze(0).movedim( - query.dim() - 2, 0) - else: - sub_out = torch.nn.functional.scaled_dot_product_attention( - query[None, :, start:end, :], - key[None, :, start:end, :], - value[None, :, start:end, :], - attn_mask=mask, - dropout_p=0.0, - is_causal=not self.need_mask, - scale=self.scale).squeeze(0).movedim( - query.dim() - 2, 0) - output[start:end, :, :] = sub_out - start = end + # Prepare attention_mask + if attn_metadata.attn_bias is None: + if self.alibi_slopes is not None: + att_masks = _make_alibi_bias( + self.alibi_slopes, query.dtype, + attn_metadata.seq_lens) # type: ignore + elif self.sliding_window is not None: + att_masks = _make_sliding_window_bias( + attn_metadata.seq_lens, self.sliding_window, + query.dtype) # type: ignore + else: + att_masks = [None] * len(attn_metadata.seq_lens) + attn_metadata.attn_bias = att_masks + + start = 0 + for seq_len, mask in zip(attn_metadata.seq_lens, + attn_metadata.attn_bias): + end = start + seq_len + if use_sdp_causal(self.head_size, query): + import xe_addons + if mask is not None: + mask = mask.unsqueeze(0) + sub_out = 
xe_addons.sdp_causal( + query[None, :, start:end, :].contiguous(), + key[None, :, start:end, :].contiguous(), + value[None, :, start:end, :].contiguous(), + mask).squeeze(0).movedim(query.dim() - 2, 0) + else: + sub_out = torch.nn.functional.scaled_dot_product_attention( + query[None, :, start:end, :], + key[None, :, start:end, :], + value[None, :, start:end, :], + attn_mask=mask, + dropout_p=0.0, + is_causal=not self.need_mask, + scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0) + output[start:end, :, :] = sub_out + start = end else: # prefix-enabled attention raise RuntimeError( diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 025449cfe48..454b50be38f 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -343,6 +343,7 @@ def _prepare_decode( num_decode_tokens=len(input_tokens), num_prefills=0, block_tables=block_tables, + query_start_loc=None ) return ( input_tokens, @@ -469,6 +470,17 @@ def _prepare_prompt( slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=self.device) # type: ignore + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=self.device) + # Prepare query_start_loc + query_start_loc = torch.zeros(len(seq_lens) + 1, + dtype=torch.int32, + device=self.device) + torch.cumsum(seq_lens_tensor, + dim=0, + dtype=query_start_loc.dtype, + out=query_start_loc[1:]) max_seqlen = max(seq_lens) tmp = [0] @@ -482,12 +494,13 @@ def _prepare_prompt( seq_lens=seq_lens, seqlen_q=seqlen_q, max_seqlen=max_seqlen, - seq_lens_tensor=None, + seq_lens_tensor=seq_lens_tensor, max_decode_seq_len=None, num_prefills=len(seq_lens), num_prefill_tokens=num_prompt_tokens, num_decode_tokens=0, block_tables=torch.tensor([], device=self.device, dtype=torch.int), + query_start_loc=query_start_loc, ) multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)
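
For reference, a minimal standalone sketch (illustrative lengths, plain CPU tensors) of how the query_start_loc prepared above indexes the flattened prefill batch that context_attention_forward consumes; the names mirror the metadata fields, the values are hypothetical:

import torch

seq_lens = [3, 5, 2]  # hypothetical prompt lengths for a 3-sequence batch
seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int)

# Same construction as in _prepare_prompt: exclusive prefix sums over seq_lens.
query_start_loc = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
torch.cumsum(seq_lens_tensor, dim=0, dtype=query_start_loc.dtype,
             out=query_start_loc[1:])
print(query_start_loc)  # tensor([ 0,  3,  8, 10], dtype=torch.int32)

# Tokens of sequence i occupy rows query_start_loc[i]:query_start_loc[i + 1]
# of the flattened [num_tokens, num_heads, head_size] query tensor, which is
# how the SDP kernel locates each sequence's rows.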