Commit 0c93a3b

committed
zero kv
1 parent ee1b719 commit 0c93a3b

2 files changed: +8 / -5 lines

csrc/flash_attn/flash_api.cpp

Lines changed: 1 addition & 1 deletion

@@ -37,7 +37,7 @@ std::vector<at::Tensor> mha_varlen_fwd(
     out = *out_;
   }
   else {
-    out = torch::zeros_like(q);
+    out = torch::zeros_like(q).to(torch::kFloat32);
   }

   cutlass_chunk_prefill_impl(
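
For context, the fallback branch now allocates the output buffer in float32 instead of q's dtype. A minimal standalone sketch of the same allocation, assuming the libtorch C++ API; the shape and the out2 variant are illustrative, not part of the commit:

    #include <torch/torch.h>

    int main() {
      // Query stand-in in half precision, as the varlen forward path typically sees it.
      at::Tensor q = torch::zeros({8, 16, 64}, torch::dtype(torch::kHalf));

      // Behaviour after this commit: zero-filled buffer matching q, then cast to float32.
      at::Tensor out = torch::zeros_like(q).to(torch::kFloat32);

      // Equivalent one-step allocation that skips the intermediate half-precision buffer.
      at::Tensor out2 = torch::zeros_like(q, q.options().dtype(torch::kFloat32));

      TORCH_CHECK(out.dtype() == torch::kFloat32 && out2.dtype() == torch::kFloat32);
      return 0;
    }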

csrc/xpu/cutlass_kernels/chunk_prefill.hpp

Lines changed: 7 additions & 4 deletions

@@ -32,6 +32,7 @@ struct chunk_prefill_args_t {
   void* num_blocks_per_seq;
   void* cu_seqlens_q;
   void* cu_seqlens_k;
+  void* cu_seqlens_k_zeros;
   int max_queries;
   int max_keys;
   int total_seqlen_q;
@@ -88,8 +89,8 @@ template <class FMHAChunkPrefillKernel, bool isVarLen> struct KernelLauncher {
       args.num_heads_q,
       args.num_heads_k,
       cutlass::fmha::collective::VariableLength{args.max_queries},  // cu_q
-      cutlass::fmha::collective::VariableLength{args.max_keys},  // cu_k
-      cutlass::fmha::collective::VariableLength{args.max_keys},  // cu_v
+      cutlass::fmha::collective::VariableLength{0},  // cu_kv
+      cutlass::fmha::collective::VariableLength{args.max_keys},  // cu_kv_cache
       args.head_size,
       args.head_size);
   auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo] = problem_shape;
@@ -106,7 +107,7 @@ template <class FMHAChunkPrefillKernel, bool isVarLen> struct KernelLauncher {
   stride_O = cutlass::make_cute_packed_stride(StrideO{}, cute::make_shape(seq_len_qo * group_q_size, group_q_num * head_size_vo, batch));

   get<3>(problem_shape_out).cumulative_length = reinterpret_cast<int*>(args.cu_seqlens_q);
-  get<4>(problem_shape_out).cumulative_length = reinterpret_cast<int*>(args.cu_seqlens_k);
+  get<4>(problem_shape_out).cumulative_length = reinterpret_cast<int*>(args.cu_seqlens_k_zeros);
   get<5>(problem_shape_out).cumulative_length = reinterpret_cast<int*>(args.cu_seqlens_k);

   return problem_shape_out;
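
For intuition on the cumulative_length swap above: a cu_seqlens array stores prefix sums, so per-sequence lengths are the differences of adjacent entries, and an all-zero array makes every sequence's segment empty. A minimal sketch of that arithmetic, independent of the CUTLASS types and using made-up lengths:

    #include <cstdio>
    #include <vector>

    int main() {
      // cu_seqlens_k for 3 sequences of lengths 5, 3 and 8 (prefix sums with a leading 0).
      std::vector<int> cu_seqlens_k = {0, 5, 8, 16};
      // The zeroed variant added by this commit: every per-sequence length collapses to 0.
      std::vector<int> cu_seqlens_k_zeros(cu_seqlens_k.size(), 0);

      for (size_t i = 0; i + 1 < cu_seqlens_k.size(); ++i) {
        int len_from_real  = cu_seqlens_k[i + 1] - cu_seqlens_k[i];              // 5, 3, 8
        int len_from_zeros = cu_seqlens_k_zeros[i + 1] - cu_seqlens_k_zeros[i];  // always 0
        std::printf("seq %zu: real len=%d, zeroed len=%d\n", i, len_from_real, len_from_zeros);
      }
      return 0;
    }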
@@ -256,7 +257,7 @@ template<typename chunk_policy>
 void policy_dispatch(
     CutlassType cuType,
     const chunk_prefill_args_t& args) {
-  const int PipelineStages = 2;
+  const int PipelineStages = 0;
   if(cuType == CutlassType::half) {
     FMHAKernel<typename chunk_policy::ShapeQK, typename chunk_policy::ShapePV,
       typename chunk_policy::ShapeOutPut, typename chunk_policy::SubgroupLayout, PipelineStages,
@@ -303,6 +304,7 @@ void cutlass_chunk_prefill_impl(
   int total_seqlen_q = query.size(0);
   int total_seqlen_k = num_block * block_size;
   at::Tensor num_blocks_per_seq = torch::div(cu_seqlens_k, block_size);
+  at::Tensor cu_seqlens_k_zeros = torch::zeros_like(cu_seqlens_k);

   chunk_prefill_args_t args = {
     query.data_ptr(),
@@ -313,6 +315,7 @@ void cutlass_chunk_prefill_impl(
     num_blocks_per_seq.data_ptr(),
     cu_seqlens_q.data_ptr(),
     cu_seqlens_k.data_ptr(),
+    cu_seqlens_k_zeros.data_ptr(),
     max_seqlen_q,
     max_seqlen_k,
     total_seqlen_q,
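
One plumbing detail worth noting: chunk_prefill_args_t is filled with a positional brace initializer, so inserting cu_seqlens_k_zeros between cu_seqlens_k and max_queries only stays correct if every initializer list adds the new pointer at that same position, which is what the hunk above does. A small sketch of the hazard, with an illustrative stand-in struct rather than the real definition:

    #include <cstdint>

    // Illustrative stand-in for chunk_prefill_args_t (not the real layout).
    struct args_t {
      void* cu_seqlens_q;
      void* cu_seqlens_k;
      void* cu_seqlens_k_zeros;  // newly inserted member
      int max_queries;
      int max_keys;
    };

    int main() {
      int32_t cu_q[2] = {0, 4}, cu_k[2] = {0, 7}, cu_k_zeros[2] = {0, 0};
      // Positional aggregate init: entries must follow the member order exactly,
      // so the new pointer has to be slotted in before max_queries at every call site.
      args_t args = {cu_q, cu_k, cu_k_zeros, 4, 7};
      return args.max_queries == 4 ? 0 : 1;
    }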
