Skip to content

Commit 03dc351

Browse files
committed
Fix FlashAttentionDecodeSplitVx indirect dispatch input ordering
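Move SetIndirectDispatchTensor after all AddInput calls to ensure the indirect buffer is the last program input. Previously, when has_head_sink was set, head_sink was added as an input after SetIndirectDispatchTensor had already appended the indirect buffer, so the indirect buffer was no longer the last input and the shader variable types were swapped, causing a u32*f16 WGSL compilation error.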
1 parent 722743c commit 03dc351

1 file changed

Lines changed: 6 additions & 2 deletions

File tree

onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -344,14 +344,18 @@ Status ComputeFlashAttentionDecodeSplitVxScore(onnxruntime::webgpu::ComputeConte
344344
program.AddOutputs({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components}}); // [B, N, split_k, head_size]
345345
const uint32_t batch_heads = static_cast<uint32_t>(parameters.batch_size_ * parameters.num_heads_);
346346
if (use_indirect_dispatch) {
347-
program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None})
348-
.SetIndirectDispatchTensor(indirect_buffer);
347+
program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None});
349348
} else {
350349
program.SetDispatchGroupSize(batch_heads * num_total_seq_length_tile);
351350
}
352351
if (has_head_sink) {
353352
program.AddInput({head_sink, ProgramTensorMetadataDependency::Type});
354353
}
354+
// SetIndirectDispatchTensor must be called after all AddInput calls because it
355+
// appends the indirect buffer as the last program input.
356+
if (use_indirect_dispatch) {
357+
program.SetIndirectDispatchTensor(indirect_buffer);
358+
}
355359
program.CacheHint(tile_size, head_size_vec, use_indirect_dispatch, has_head_sink)
356360
.SetWorkgroupSize(64)
357361
.AddUniformVariables({{static_cast<uint32_t>(parameters.total_sequence_length_)},

0 commit comments

Comments
 (0)