@@ -32,6 +32,7 @@ struct chunk_prefill_args_t {
   void* num_blocks_per_seq;
   void* cu_seqlens_q;
   void* cu_seqlens_k;
+  void* cu_seqlens_k_zeros;
   int max_queries;
   int max_keys;
   int total_seqlen_q;
@@ -88,8 +89,8 @@ template <class FMHAChunkPrefillKernel, bool isVarLen> struct KernelLauncher {
         args.num_heads_q,
         args.num_heads_k,
         cutlass::fmha::collective::VariableLength{args.max_queries},  // cu_q
-        cutlass::fmha::collective::VariableLength{args.max_keys},     // cu_k
-        cutlass::fmha::collective::VariableLength{args.max_keys},     // cu_v
+        cutlass::fmha::collective::VariableLength{0},                 // cu_kv
+        cutlass::fmha::collective::VariableLength{args.max_keys},     // cu_kv_cache
         args.head_size,
         args.head_size);
     auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo] = problem_shape;
@@ -106,7 +107,7 @@ template <class FMHAChunkPrefillKernel, bool isVarLen> struct KernelLauncher {
     stride_O = cutlass::make_cute_packed_stride(StrideO{}, cute::make_shape(seq_len_qo * group_q_size, group_q_num * head_size_vo, batch));

     get<3>(problem_shape_out).cumulative_length = reinterpret_cast<int*>(args.cu_seqlens_q);
-    get<4>(problem_shape_out).cumulative_length = reinterpret_cast<int*>(args.cu_seqlens_k);
+    get<4>(problem_shape_out).cumulative_length = reinterpret_cast<int*>(args.cu_seqlens_k_zeros);
     get<5>(problem_shape_out).cumulative_length = reinterpret_cast<int*>(args.cu_seqlens_k);

     return problem_shape_out;
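Note on the cumulative-length wiring above: per the structured binding in the previous hunk, slot 4 of the problem shape is seq_len_kv (the new, non-cached keys) and slot 5 is seq_len_kv_cache, so pointing slot 4 at an all-zeros array makes every sequence report zero new KV length while cu_seqlens_k describes only the KV cache. A minimal host-side sketch of the two layouts, assuming the usual flash-attention convention that cu_seqlens is a zero-prefixed prefix sum of per-sequence lengths (helper names are hypothetical, not part of this commit):

#include <numeric>
#include <vector>

// Standard cu_seqlens: lengths {3, 5, 4} -> {0, 3, 8, 12}.
std::vector<int> make_cu_seqlens(const std::vector<int>& seq_lens_k) {
  std::vector<int> cu(seq_lens_k.size() + 1, 0);
  std::partial_sum(seq_lens_k.begin(), seq_lens_k.end(), cu.begin() + 1);
  return cu;
}

// Zeroed counterpart fed to slot 4: every sequence contributes 0 new keys.
std::vector<int> make_cu_seqlens_zeros(const std::vector<int>& seq_lens_k) {
  return std::vector<int>(seq_lens_k.size() + 1, 0);
}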
@@ -256,7 +257,7 @@ template<typename chunk_policy>
 void policy_dispatch(
     CutlassType cuType,
     const chunk_prefill_args_t& args) {
-  const int PipelineStages = 2;
+  const int PipelineStages = 0;
   if (cuType == CutlassType::half) {
     FMHAKernel<typename chunk_policy::ShapeQK, typename chunk_policy::ShapePV,
         typename chunk_policy::ShapeOutPut, typename chunk_policy::SubgroupLayout, PipelineStages,
@@ -303,6 +304,7 @@ void cutlass_chunk_prefill_impl(
   int total_seqlen_q = query.size(0);
   int total_seqlen_k = num_block * block_size;
   at::Tensor num_blocks_per_seq = torch::div(cu_seqlens_k, block_size);
+  at::Tensor cu_seqlens_k_zeros = torch::zeros_like(cu_seqlens_k);
 
   chunk_prefill_args_t args = {
       query.data_ptr(),
@@ -313,6 +315,7 @@ void cutlass_chunk_prefill_impl(
       num_blocks_per_seq.data_ptr(),
       cu_seqlens_q.data_ptr(),
       cu_seqlens_k.data_ptr(),
+      cu_seqlens_k_zeros.data_ptr(),
       max_seqlen_q,
       max_seqlen_k,
       total_seqlen_q,
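For context on the host-side hunks above, a hedged libtorch sketch of how the new tensor relates to the existing ones; the concrete values and the wrapper function are made up for illustration, and only the torch::zeros_like and torch::div calls mirror the diff:

#include <torch/torch.h>

void prepare_chunk_prefill_inputs_example() {
  // Cumulative key lengths for a hypothetical batch of 3 sequences.
  at::Tensor cu_seqlens_k =
      torch::tensor({0, 128, 192, 448}, torch::dtype(torch::kInt32));
  int block_size = 64;

  // Same shape/dtype/device as cu_seqlens_k but all zeros, so the kernel's
  // "new KV" slot sees zero-length segments (see the get<4> hunk above).
  at::Tensor cu_seqlens_k_zeros = torch::zeros_like(cu_seqlens_k);

  // Mirrors the torch::div call in the surrounding code.
  at::Tensor num_blocks_per_seq = torch::div(cu_seqlens_k, block_size);

  // chunk_prefill_args_t stores raw data_ptr()s, so these tensors must stay
  // alive until the kernel has consumed them.
  (void)cu_seqlens_k_zeros;
  (void)num_blocks_per_seq;
}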