diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index 1072af100..d2340b5c9 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -1484,7 +1484,6 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage static_assert(sizeof(DTypeOut) == 2); sm_scale *= (logits_post_hook == LogitsPostHook::kNone ? math::log2e : math::ptx_rcp(logits_soft_cap)); - auto block = cg::this_thread_block(); const uint32_t kv_chunk_size = *kv_chunk_size_ptr; const uint32_t bx = blockIdx.x, lane_idx = threadIdx.x, @@ -1492,14 +1491,21 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage if (block_valid_mask && !block_valid_mask[bx]) { return; } - const uint32_t num_kv_heads = gridDim.z, num_qo_heads = num_kv_heads * group_size; - float alibi_slopes[num_frags_x][2]; const uint32_t request_idx = request_indices[bx], qo_tile_idx = q_tile_indices[bx], kv_tile_idx = kv_tile_indices[bx]; + const uint32_t qo_len = q_indptr[request_idx + 1] - q_indptr[request_idx]; + + if (qo_len == 0) { + // Fail fast if query is empty. May happen with CUDA graphs. + return; + } + + const uint32_t num_kv_heads = gridDim.z, num_qo_heads = num_kv_heads * group_size; + float alibi_slopes[num_frags_x][2]; + constexpr uint32_t num_rows_per_cta = num_frags_x * num_warps_x * 16; - const uint32_t qo_len = q_indptr[request_idx + 1] - q_indptr[request_idx], - kv_len = (paged_kv.indptr[request_idx + 1] != paged_kv.indptr[request_idx]) + const uint32_t kv_len = (paged_kv.indptr[request_idx + 1] != paged_kv.indptr[request_idx]) ? (paged_kv.indptr[request_idx + 1] - paged_kv.indptr[request_idx] - 1) * paged_kv.page_size + paged_kv.last_page_len[request_idx] @@ -1514,6 +1520,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage constexpr uint32_t channel_size_128b_in = head_dim / num_elems_per_128b(); constexpr uint32_t channel_size_128b_out = head_dim / num_elems_per_128b(); + auto block = cg::this_thread_block(); extern __shared__ uint8_t smem[]; DTypeQKAccum s_frag[num_frags_x][num_frags_z][8];