Whisper Redesigned Solution #23549

Status: Open. Wants to merge 38 commits into base: main.

Changes from all commits (38 commits):
f314287  Add support for creating optimized whisper ONNX models without beam s… (kunal-vaishnavi, Apr 26, 2024)
6a44f72  Fix incorrect dynamic axes labels (kunal-vaishnavi, Apr 26, 2024)
58ec5eb  Fix fusion breaks for OpenAI implementation of Whisper (kunal-vaishnavi, May 3, 2024)
4c228ea  Merge branch 'main' into kvaishnavi/whisper-separate-export (kunal-vaishnavi, Jun 12, 2024)
dd20876  Merge branch 'main' into kvaishnavi/whisper-separate-export (kunal-vaishnavi, Jul 23, 2024)
b13cb22  Comment out DMMHA case temporarily (kunal-vaishnavi, Jul 23, 2024)
31db1a0  Replace MHA with DMMHA (kunal-vaishnavi, Jul 29, 2024)
3b92432  Merge branch 'main' into kvaishnavi/whisper-separate-export (kunal-vaishnavi, Aug 26, 2024)
7bb79f3  Debugging beam search output (kunal-vaishnavi, Sep 6, 2024)
14b7e77  Initial commit for new export (kunal-vaishnavi, Oct 22, 2024)
fa345fe  Add parity check after export and optimization (kunal-vaishnavi, Oct 22, 2024)
e050dea  Fix multiple attention kernel invocations (kunal-vaishnavi, Nov 2, 2024)
bf87062  Make output Q*K values optional (kunal-vaishnavi, Nov 4, 2024)
17fa0ab  Fix batch size check for cache indirection (kunal-vaishnavi, Nov 6, 2024)
52aeb58  Save checkpoint for working solution (kunal-vaishnavi, Nov 15, 2024)
240fe3b  Clean up code (kunal-vaishnavi, Nov 17, 2024)
ae98085  Fix string dumping (kunal-vaishnavi, Nov 20, 2024)
3d2c8fe  Fix out_qk dtype issue for half input case. (mindest, Nov 20, 2024)
287151f  Remove type cast for output QK (kunal-vaishnavi, Nov 21, 2024)
0805d1d  Enable release mode build (kunal-vaishnavi, Dec 4, 2024)
b629903  Make QK output dtype independent of attention dtype (kunal-vaishnavi, Dec 9, 2024)
648b389  Add batched jump times export (kunal-vaishnavi, Dec 9, 2024)
a6c6ee8  Get batched jump times ONNX model with parity check (kunal-vaishnavi, Dec 12, 2024)
c0a6ce4  Save checkpoint for working solution (kunal-vaishnavi, Dec 21, 2024)
008eeb9  Merge branch 'main' into kvaishnavi/whisper-separate-export (kunal-vaishnavi, Dec 22, 2024)
158d0a8  Fix build after merge (kunal-vaishnavi, Dec 22, 2024)
02cb5be  Fix model with beam search op (kunal-vaishnavi, Dec 23, 2024)
2acd593  Get model impl and beam search op export combinations working (kunal-vaishnavi, Dec 25, 2024)
612eb0c  Enable separate export of encoder and decoder init (kunal-vaishnavi, Dec 25, 2024)
f2d78fd  Add tests for multiple export types to CIs (kunal-vaishnavi, Dec 25, 2024)
cb93517  Update folder and file names in Whisper README (kunal-vaishnavi, Dec 25, 2024)
6da11ec  Add FP32 CPU DMMHA support (kunal-vaishnavi, Dec 28, 2024)
9640736  Add unit tests (kunal-vaishnavi, Jan 8, 2025)
75a342a  Merge branch 'main' into kvaishnavi/whisper-separate-export (kunal-vaishnavi, Jan 24, 2025)
7fe6b05  Change debug message for PrepareQkv (kunal-vaishnavi, Jan 25, 2025)
8620168  Fix seqlens_k after merge (kunal-vaishnavi, Jan 29, 2025)
b0a732b  Merge branch 'main' into kvaishnavi/whisper-separate-export (kunal-vaishnavi, Jan 29, 2025)
23808f7  Add changes suggested by linter (kunal-vaishnavi, Jan 31, 2025)
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cpu/bert/attention.cc
@@ -335,7 +335,7 @@ Status Attention<T>::Compute(OpKernelContext* context) const {

// Compute the attention score and apply the score to V
return ApplyAttention(Q, K, V, mask_index, past, nullptr /* past_key */, nullptr /* past_value */,
output, nullptr /* present_key */, nullptr /* present_value */,
output, nullptr /* present_key */, nullptr /* present_value */, nullptr /* output_qk */,
batch_size, sequence_length, sequence_length,
parameters.head_size, parameters.v_head_size, parameters.v_hidden_size,
attention_bias, context);
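Not part of the diff, for context: the only change in this call is the new trailing nullptr /* output_qk */ argument, which matches the commit messages above about making the Q*K output optional. A minimal hypothetical sketch of that optional-output pattern (illustrative names only, not the real ApplyAttention signature):

// Hypothetical sketch: an optional output_qk pointer lets callers request the
// raw Q*K^T scores; passing nullptr, as the CPU Attention kernel does above,
// skips the extra copy entirely.
#include <algorithm>
#include <cstddef>

template <typename T>
void ApplyAttentionSketch(const T* qk_scores, size_t qk_count,
                          T* output, T* output_qk /* may be nullptr */) {
  if (output_qk != nullptr) {
    std::copy(qk_scores, qk_scores + qk_count, output_qk);  // expose raw scores
  }
  // ... softmax over qk_scores and the weighted sum with V would follow here,
  // writing the final attention result into `output`.
}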
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_base.cc
@@ -3,6 +3,7 @@

#include "contrib_ops/cpu/bert/attention_base.h"
#include "contrib_ops/cpu/bert/multihead_attention_helper.h"
#include "contrib_ops/cpu/utils/dump_tensor.h"
Contributor comment on this include:

It might not be needed; I did not see code in this file that dumps tensors.

#include "core/providers/common.h"

namespace onnxruntime {
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_base.h
@@ -7,6 +7,7 @@
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "contrib_ops/cpu/bert/attention_common.h"
#include "contrib_ops/cpu/bert/attention_parameters.h"

namespace onnxruntime {
namespace contrib {
144 changes: 5 additions & 139 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -49,148 +49,10 @@ enum AttentionKernelType {
AttentionKernel_FlashAttention,
AttentionKernel_CudnnFlashAttention,
AttentionKernel_LeanAttention,
AttentionKernel_FtCausalAttention,
AttentionKernel_Default
};

// Parameters deduced from node attributes and inputs/outputs.
struct AttentionParameters {
int batch_size;
int sequence_length;
int kv_sequence_length; // input sequence length of K or V
int past_sequence_length; // sequence length in past state of K or V
int total_sequence_length; // total sequence length of K or V
int max_sequence_length; // max sequence length from 4D mask
int input_hidden_size; // first dimension of weights for input projection
int hidden_size; // hidden size of Q or K
int head_size; // hidden size per head of Q or K
int v_hidden_size; // hidden size of V
int v_head_size; // hidden size per head of V
int num_heads;
int rotary_embedding;
bool is_unidirectional;
bool past_present_share_buffer;
bool do_rotary;
bool broadcast_attn_bias_dim_0;
bool broadcast_attn_bias_dim_1;
float mask_filter_value;
float scale;
bool use_tf32;
AttentionMaskType mask_type;
AttentionQkvFormat qkv_format;
};

struct DecoderMaskedMultiHeadAttentionParams : AttentionParameters {
int beam_width = 1;

// Only NeoX style rotary embedding is supported
int rotary_embedding_dim = 0;
int t_step = 0;

// Whether to use multihead attention(excludes matmul and bias)
bool is_mha = false;
bool is_cross_attention = false;
bool is_packed_qkv = false;

// Useful to better use global memory bandwidth on certain CUDA architectures.
// Turned off by default for now until we fully understand performance implications
// for all types of workloads.
// Can be turned on by appropriate environment variable (see attention_common.h).
bool kv_data_in_flight = false;

void* q = nullptr;
void* q_bias = nullptr;

void* k = nullptr;
void* k_bias = nullptr;

void* v = nullptr;
void* v_bias = nullptr;

void* attention_bias = nullptr;

void* k_cache = nullptr;
void* v_cache = nullptr;

void* out = nullptr;
void* out_qk = nullptr;

const int32_t* cache_indir = nullptr;
const int32_t* mask = nullptr; // [B, total_sequence_length]
};

// Parameters deduced from node attributes and inputs/outputs.
struct PackedAttentionParameters {
int batch_size;
int sequence_length;
int input_hidden_size; // hidden size of input
int hidden_size; // hidden size of Q or K
int head_size; // hidden size per head of Q or K
int v_hidden_size; // hidden size of V
int v_head_size; // hidden size per head of V
int num_heads;
float scale;
int token_count;
bool broadcast_attn_bias_dim_0;
bool broadcast_attn_bias_dim_1;
bool use_tf32;
};

// Parameters deduced from node attributes and inputs/outputs.
struct GroupQueryAttentionParameters {
int batch_size;
int sequence_length; // sequence length of input query, key, value
int seqlen_past_kv_cache; // sequence length of past kv tensor
int seqlen_present_kv_cache; // sequence length of present kv tensor
int total_sequence_length; // maximum total sequence length (past_sequence_length + sequence_length) among keys
int hidden_size;
int num_heads;
int head_size;
int kv_hidden_size;
int kv_num_heads;
int num_splits; // number of splits for splitkv
int rotary_dim; // rotary embedding dimension
bool is_unidirectional; // causal
int local_window_size;
bool kv_share_buffer;
bool is_packed_qkv;
bool is_subsequent_prompt; // indicates whether we have past context and seqlen > 1
bool is_first_prompt; // indicates whether this is first decoding step
bool do_rotary;
bool rotary_interleaved;
bool use_smooth_softmax;
float scale;
float softcap;
AttentionQkvFormat qkv_format;
AttentionQkvFormat past_kv_format;
int zeros_count;
int* zero_ptr;
};

// Parameters for sparse attention.
struct SparseAttentionParameters {
int batch_size; // batch size
int sequence_length; // sequence length of input query, key, value
int hidden_size; // hidden size of query
int num_heads; // number of heads of query
int head_size; // hidden size per head of query, key or value
int kv_hidden_size; // hidden size of key or value
int kv_num_heads; // number of heads of key or value
bool do_rotary; // whether to use rotary embedding
bool rotary_interleaved; // whether to use interleaved rotary embedding
int rotary_dim; // rotary embedding dimension
int sparse_block_size; // block size for sparse attention
int num_sparse_layout; // number of sparse layout
int stride_col_indices; // shape of block_col_indices is [num_sparse_layout, stride_col_indices]
int stride_row_indices; // shape of block_row_indices is [num_sparse_layout, stride_row_indices]
float scale; // scaling factor applied prior to softmax
bool is_packed_qkv; // whether qkv is packed
int total_sequence_length; // maximum total sequence length (past_sequence_length + sequence_length) among keys
int max_sequence_length; // max sequence length for sparse layout
int max_rotary_sequence_length; // max sequence length for rotary cos/sin cache
int max_cache_sequence_length; // max sequence length for kv cache buffer
bool past_present_share_buffer; // whether past_key and present_key share buffer, so is past_value and present_value
};

constexpr bool LAYOUT_BSNH = false;
constexpr bool LAYOUT_BNSH = true;

@@ -215,6 +77,7 @@ enum class AttentionBackend : int {

// Experimental kernels
LEAN_ATTENTION = 256,
FT_CAUSAL_ATTENTION = 512, // FasterTransformer's decoder masked multihead attention
};

// Environment variable to enable debug information of attention kernel to be printed. Default is 0 (disabled).
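Not part of the diff, for context: the AttentionBackend values are distinct powers of two (the new FT_CAUSAL_ATTENTION = 512 sits next to LEAN_ATTENTION = 256), so a set of enabled backends can be carried in a single integer bitmask. A minimal hypothetical sketch of testing one backend bit (assumed helper, not code from this PR):

// Hypothetical sketch: each backend occupies its own bit, so a kernel can check
// whether it was requested with a simple bitwise AND.
enum class AttentionBackendSketch : int {
  LEAN_ATTENTION = 256,
  FT_CAUSAL_ATTENTION = 512,  // FasterTransformer's decoder masked multihead attention
};

inline bool IsBackendRequested(int enabled_backends, AttentionBackendSketch backend) {
  return (enabled_backends & static_cast<int>(backend)) != 0;
}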
@@ -245,6 +108,9 @@ constexpr const char* kDisableFlashAttention = "ORT_DISABLE_FLASH_ATTENTION";
// Environment variable to enable or disable lean attention. Default is 0 (disabled).
constexpr const char* kEnableLeanAttention = "ORT_ENABLE_LEAN_ATTENTION";

// Environment variable to enable or disable FasterTransformer's decoder masked multi-head attention. Default is 0 (enabled).
constexpr const char* kDisableFtCausalAttention = "ORT_DISABLE_FT_CAUSAL_ATTENTION";

// Minimum sequence length to prefer memory efficient attention when data type is float32
constexpr const char* kMinSeqLenForEfficientAttentionFp32 = "ORT_MIN_SEQ_LEN_EFFICIENT_ATTENTION_FP32";

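Not part of the diff: like the other kill-switch variables in this header, ORT_DISABLE_FT_CAUSAL_ATTENTION defaults to 0, so the FasterTransformer causal attention path stays enabled unless the variable is set. A minimal hypothetical sketch of reading such a flag (the real code presumably goes through onnxruntime's environment-variable helpers rather than raw getenv):

// Hypothetical sketch: a "disable" flag that defaults to 0 (feature enabled).
// Setting ORT_DISABLE_FT_CAUSAL_ATTENTION=1 turns the kernel off.
#include <cstdlib>
#include <string>

inline bool IsFtCausalAttentionDisabled() {
  const char* value = std::getenv("ORT_DISABLE_FT_CAUSAL_ATTENTION");
  return value != nullptr && std::string(value) != "0";
}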