Commit 7942fa7 (parent ae501ee)

Whisper Redesigned Solution (#23549)
### Description

This PR re-designs how Whisper is created and supported in ONNX Runtime. The new solution leverages [previous optimization work](#15473), and it is designed to be used in conjunction with [this work](microsoft/onnxruntime-genai#1229) in ONNX Runtime GenAI.

Some of the added changes include:

- Re-designed export that creates new ONNX models without needing a `WhisperBeamSearch` op
- Creates one encoder model that also pre-computes the cross-attention KV caches (since they only need to be calculated once)
- Creates one decoder model that can be used during pre-fill and token generation
- Creates one jump-times model that can be used for word-level timestamps
- Removes the need for a `WhisperBeamSearch` op to chain the encoder and decoder subgraphs (see the decoding sketch below)
- Removes the need to duplicate the decoder's weights in memory
  - The previous solution with the `WhisperBeamSearch` op created an encoder-decoder-init model and a decoder-with-past model, so the decoder was duplicated, with one copy in each
- Removes the need for separate logic to export the PyTorch model coming from OpenAI vs. the PyTorch model coming from Hugging Face
- Re-factors common parameters and logic used in the CPU and CUDA attention kernels
- Adds `DUMP_STRING` to enable easy logging of intermediate information when debugging a problem in debug mode. This info is not printed in release mode, so it will not impact performance.
- Integrates `DecoderMaskedMultiHeadAttention` into `MultiHeadAttention`
- Enables past-present buffer sharing in the `MultiHeadAttention` op for improved performance
- Adds `cache_indirection` and `past_sequence_length` as new optional inputs to `MultiHeadAttention`
- Adds `output_qk` as a new optional output to `MultiHeadAttention`
- Enables calculating the `output_qk` tensor with FP16 or FP32 precision, regardless of the model's precision
- Adds CI tests that run end-to-end across various flag combinations that are used by many customers internally and externally

The existing solutions are still available if desired.

### Known Issues

- The FP32 CPU model with the `WhisperBeamSearch` op and output QK is currently disabled. This is because ONNX Runtime doesn't currently support output-QK kernels on CPU, only on CUDA.
- The `DecoderMaskedMultiHeadAttention` CPU kernel has a parity mismatch with the `DecoderMaskedMultiHeadAttention` CUDA kernel.
- Using `DecoderMaskedMultiHeadAttention` for the FP32 CPU model is not enabled. Currently, it uses `MultiHeadAttention` to avoid the parity mismatch issue.

### Motivation and Context

Using the beam search op has made it more difficult to debug and fix errors as they are encountered. This new approach is more flexible and more customizable for users (e.g. by running with ONNX Runtime GenAI). It also helps [this issue](#18216).

---------

Co-authored-by: mindest <[email protected]>
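Because the redesigned export produces a standalone encoder and a standalone decoder, a plain ONNX Runtime driver can chain them directly. Below is a minimal, hedged sketch of that flow; the model file names, I/O names, and token id are illustrative assumptions, not the PR's actual export contract.

```python
import numpy as np
import onnxruntime as ort

# Hypothetical file names; the actual export may name these differently.
encoder = ort.InferenceSession("whisper_encoder.onnx")
decoder = ort.InferenceSession("whisper_decoder.onnx")

# 1) Run the encoder once. Per the PR, it also pre-computes the
#    cross-attention KV caches, so they are never recomputed later.
audio_features = np.zeros((1, 80, 3000), dtype=np.float32)  # log-mel input
cross_kv = encoder.run(None, {"audio_features": audio_features})

# 2) Reuse the single decoder for pre-fill and token generation instead of
#    chaining subgraphs inside a WhisperBeamSearch op (greedy, for brevity).
tokens = [50258]  # assumed start-of-transcript token id
for _ in range(32):
    feeds = {"input_ids": np.array([tokens], dtype=np.int64)}
    # The real decoder also takes the self- and cross-attention KV caches;
    # their names depend on the export, so they are elided here.
    logits = decoder.run(["logits"], feeds)[0]
    tokens.append(int(logits[0, -1].argmax()))
```

ONNX Runtime GenAI wraps this same loop (plus beam search and sampling), which is why the PR is designed to pair with microsoft/onnxruntime-genai#1229.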

File tree

95 files changed, +6352 −2674 lines changed


docs/ContribOperators.md

+17 −9
@@ -1191,17 +1191,17 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
 <dd>present state for key with shape (batch_size, num_heads, total_sequence_length, head_size). If past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size), while effective_seq_length = (past_sequence_length + kv_sequence_length).</dd>
 <dt><tt>present_value</tt> (optional) : T</dt>
 <dd>present state for value with shape (batch_size, num_heads, total_sequence_length, head_size). If past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size), while effective_seq_length = (past_sequence_length + kv_sequence_length).</dd>
-<dt><tt>qk</tt> (optional) : V</dt>
+<dt><tt>qk</tt> (optional) : QK</dt>
 <dd>normalized Q * K, of shape (batch_size, num_heads, 1, total_sequence_length).</dd>
 </dl>

 #### Type Constraints

 <dl>
-<dt><tt>V</tt> : tensor(float)</dt>
-<dd>Constrain qk output types to float32 tensors.</dd>
 <dt><tt>T</tt> : tensor(float), tensor(float16)</dt>
 <dd>Constrain input and output types to float tensors.</dd>
+<dt><tt>QK</tt> : tensor(float), tensor(float16)</dt>
+<dd>Constrain QK output to float32 or float16 tensors, independent of input type or output type.</dd>
 <dt><tt>M</tt> : tensor(int32)</dt>
 <dd>Constrain mask index to integer types</dd>
 </dl>
@@ -3203,7 +3203,7 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
 <dd>Whether every token can only attend to previous tokens. Default value is 0.</dd>
 </dl>

-#### Inputs (1 - 8)
+#### Inputs (1 - 10)

 <dl>
 <dt><tt>query</tt> : T</dt>
@@ -3219,27 +3219,35 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
 <dt><tt>attention_bias</tt> (optional) : T</dt>
 <dd>bias added to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)</dd>
 <dt><tt>past_key</tt> (optional) : T</dt>
-<dd>past state for self attention key with shape (batch_size, num_heads, past_sequence_length, head_size)</dd>
+<dd>past state for key with shape (batch_size, num_heads, past_sequence_length, head_size) or (batch_size, num_heads, max_sequence_length, head_size) when buffer sharing is used</dd>
 <dt><tt>past_value</tt> (optional) : T</dt>
-<dd>past state for self attention value with shape (batch_size, num_heads, past_sequence_length, head_size)</dd>
+<dd>past state for value with shape (batch_size, num_heads, past_sequence_length, head_size) or (batch_size, num_heads, max_sequence_length, head_size) when buffer sharing is used</dd>
+<dt><tt>past_sequence_length</tt> (optional) : M</dt>
+<dd>The past sequence length that buffer sharing is used with</dd>
+<dt><tt>cache_indirection</tt> (optional) : M</dt>
+<dd>A buffer of shape [batch_size, beam_width, max_sequence_length] where an [i, j, k] entry specifies which beam the 'k'-th token came from for the 'j'-th beam of batch 'i' in the current iteration</dd>
 </dl>

-#### Outputs (1 - 3)
+#### Outputs (1 - 4)

 <dl>
 <dt><tt>output</tt> : T</dt>
 <dd>3D output tensor with shape (batch_size, sequence_length, v_hidden_size)</dd>
 <dt><tt>present_key</tt> (optional) : T</dt>
-<dd>present state for cross attention key with shape (batch_size, num_heads, kv_sequence_length, head_size) or present state for self attention key with shape (batch_size, num_heads, total_sequence_length, head_size)</dd>
+<dd>present state for key with shape (batch_size, num_heads, total_sequence_length, head_size) or (batch_size, num_heads, max_sequence_length, head_size) when buffer sharing is used</dd>
 <dt><tt>present_value</tt> (optional) : T</dt>
-<dd>present state for cross attention value with shape (batch_size, num_heads, kv_sequence_length, head_size) or present state for self attention value with shape (batch_size, num_heads, total_sequence_length, head_size)</dd>
+<dd>present state for value with shape (batch_size, num_heads, total_sequence_length, head_size) or (batch_size, num_heads, max_sequence_length, head_size) when buffer sharing is used</dd>
+<dt><tt>qk</tt> (optional) : QK</dt>
+<dd>normalized Q * K, of shape (batch_size, num_heads, sequence_length, total_sequence_length).</dd>
 </dl>

 #### Type Constraints

 <dl>
 <dt><tt>T</tt> : tensor(float), tensor(float16)</dt>
 <dd>Constrain input and output to float tensors.</dd>
+<dt><tt>QK</tt> : tensor(float), tensor(float16)</dt>
+<dd>Constrain QK output to float32 or float16 tensors, independent of input type or output type.</dd>
 <dt><tt>M</tt> : tensor(int32)</dt>
 <dd>Constrain mask to integer types</dd>
 </dl>
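Taken together, the updated `MultiHeadAttention` spec has ten inputs (the last two being `past_sequence_length` and `cache_indirection`) and four outputs (the last being `qk`). Below is a minimal sketch of declaring the op with `onnx.helper`; the tensor names, dimensions, and opset versions are illustrative assumptions, and empty strings skip unused optional inputs. Note that `qk` may be declared float32 even when `T` is float16.

```python
from onnx import TensorProto, helper

B, H, D = "batch", 8, 64  # symbolic batch; assumed num_heads and head_size
HID = H * D

def tvi(name, dtype, shape):
    return helper.make_tensor_value_info(name, dtype, shape)

node = helper.make_node(
    "MultiHeadAttention",
    inputs=["query", "key", "value",
            "", "", "",              # bias, key_padding_mask, attention_bias omitted
            "past_key", "past_value",
            "past_sequence_length",  # new optional input (M: int32)
            "cache_indirection"],    # new optional input (M: int32)
    outputs=["output", "present_key", "present_value", "qk"],
    domain="com.microsoft",
    num_heads=H,
)

graph = helper.make_graph(
    [node], "mha_demo",
    inputs=[
        tvi("query", TensorProto.FLOAT16, [B, 1, HID]),
        tvi("key", TensorProto.FLOAT16, [B, 1, HID]),
        tvi("value", TensorProto.FLOAT16, [B, 1, HID]),
        tvi("past_key", TensorProto.FLOAT16, [B, H, "max_seq", D]),
        tvi("past_value", TensorProto.FLOAT16, [B, H, "max_seq", D]),
        tvi("past_sequence_length", TensorProto.INT32, [1]),
        tvi("cache_indirection", TensorProto.INT32, [B, "beams", "max_seq"]),
    ],
    outputs=[
        tvi("output", TensorProto.FLOAT16, [B, 1, HID]),
        tvi("present_key", TensorProto.FLOAT16, [B, H, "max_seq", D]),
        tvi("present_value", TensorProto.FLOAT16, [B, H, "max_seq", D]),
        # The QK type constraint allows float32 here even though T is float16.
        tvi("qk", TensorProto.FLOAT, [B, H, 1, "total_seq"]),
    ],
)

model = helper.make_model(graph, opset_imports=[
    helper.make_opsetid("", 17), helper.make_opsetid("com.microsoft", 1)])
```

In a beam-search driver, the `[i, j, k]` entry of `cache_indirection` names the beam that token `k` of beam `j` in batch `i` came from, which lets the kernel read the correct KV-cache rows without physically reordering shared buffers between iterations.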

docs/OperatorKernels.md

+5 −5
@@ -504,7 +504,7 @@ Do not modify directly.*
 |CDist|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(double), tensor(float)|
 |ConvTransposeWithDynamicPads|*in* X:**T**<br> *in* W:**T**<br> *in* Pads:**tensor(int64)**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
 |CropAndResize|*in* X:**T1**<br> *in* rois:**T1**<br> *in* batch_indices:**T2**<br> *in* crop_size:**T2**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(int32)|
-|DecoderMaskedMultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* mask_index:**M**<br> *in* attention_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* past_sequence_length:**M**<br> *in* beam_width:**M**<br> *in* cache_indirection:**M**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**<br> *out* qk:**V**|1+|**T** = tensor(float)|
+|DecoderMaskedMultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* mask_index:**M**<br> *in* attention_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* past_sequence_length:**M**<br> *in* beam_width:**M**<br> *in* cache_indirection:**M**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**<br> *out* qk:**QK**|1+|**T** = tensor(float)|
 |DequantizeLinear|*in* x:**T1**<br> *in* x_scale:**T2**<br> *in* x_zero_point:**T1**<br> *out* y:**T2**|1+|**T1** = tensor(int16), tensor(int32), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)<br/> **T2** = tensor(float)|
 |DynamicQuantizeLSTM|*in* X:**T**<br> *in* W:**T2**<br> *in* R:**T2**<br> *in* B:**T**<br> *in* sequence_lens:**T1**<br> *in* initial_h:**T**<br> *in* initial_c:**T**<br> *in* P:**T**<br> *in* W_scale:**T**<br> *in* W_zero_point:**T2**<br> *in* R_scale:**T**<br> *in* R_zero_point:**T2**<br> *out* Y:**T**<br> *out* Y_h:**T**<br> *out* Y_c:**T**|1+|**T** = tensor(float)<br/> **T1** = tensor(int32)<br/> **T2** = tensor(int8), tensor(uint8)|
 |DynamicQuantizeMatMul|*in* A:**T1**<br> *in* B:**T2**<br> *in* b_scale:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(int8), tensor(uint8)|
@@ -528,7 +528,7 @@ Do not modify directly.*
 |MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float)|
 |MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T3**<br> *in* g_idx:**T4**<br> *in* bias:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(float), tensor(float16), tensor(uint8)<br/> **T4** = tensor(int32)|
 |MaxpoolWithMask|*in* X:**T**<br> *in* M:**tensor(int32)**<br> *out* Y:**T**|1+|**T** = tensor(float)|
-|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* attention_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**T** = tensor(float)|
+|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* attention_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* past_sequence_length:**M**<br> *in* cache_indirection:**M**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**<br> *out* qk:**QK**|1+|**T** = tensor(float)|
 |MurmurHash3|*in* X:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string), tensor(uint32), tensor(uint64)<br/> **T2** = tensor(int32), tensor(uint32)|
 |NGramRepeatBlock|*in* input_ids:**Tid**<br> *in* scores:**T**<br> *out* scores_out:**T**|1+|**T** = tensor(float)<br/> **Tid** = tensor(int64)|
 |NhwcMaxPool|*in* x:**T**<br> *out* y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
@@ -906,7 +906,7 @@ Do not modify directly.*
 |ComplexMulConj|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(float), tensor(float16)|
 |ConvTransposeWithDynamicPads|*in* X:**T**<br> *in* W:**T**<br> *in* Pads:**tensor(int64)**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
 |DecoderAttention|*in* query:**T**<br> *in* key:**T**<br> *in* q_weight:**T**<br> *in* kv_weight:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**B**<br> *in* key_cache:**T**<br> *in* value_cache:**T**<br> *in* static_kv:**B**<br> *in* use_past:**B**<br> *in* has_layer_state:**B**<br> *in* has_key_padding_mask:**B**<br> *out* output:**T**<br> *out* new_key_cache:**T**<br> *out* new_value_cache:**T**|1+|**T** = tensor(float), tensor(float16)|
-|DecoderMaskedMultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* mask_index:**M**<br> *in* attention_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* past_sequence_length:**M**<br> *in* beam_width:**M**<br> *in* cache_indirection:**M**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**<br> *out* qk:**V**|1+|**T** = tensor(float), tensor(float16)|
+|DecoderMaskedMultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* mask_index:**M**<br> *in* attention_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* past_sequence_length:**M**<br> *in* beam_width:**M**<br> *in* cache_indirection:**M**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**<br> *out* qk:**QK**|1+|**QK** = tensor(float), tensor(float16)<br/> **T** = tensor(float), tensor(float16)|
 |DecoderMaskedSelfAttention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* attention_bias:**T**<br> *in* past_sequence_length:**M**<br> *in* beam_width:**M**<br> *in* cache_indirection:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float), tensor(float16)|
 |DequantizeLinear|*in* x:**T1**<br> *in* x_scale:**T2**<br> *in* x_zero_point:**T1**<br> *out* y:**T2**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(float16)|
 |DequantizeWithOrder|*in* input:**Q**<br> *in* scale_input:**S**<br> *out* output:**F**|1+|**F** = tensor(float), tensor(float16)<br/> **Q** = tensor(int8)<br/> **S** = tensor(float)|
@@ -929,7 +929,7 @@ Do not modify directly.*
 |MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
 |MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T3**<br> *in* g_idx:**T4**<br> *in* bias:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
 |MoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_weights:**T**<br> *in* fc2_experts_bias:**T**<br> *in* fc3_experts_weights:**T**<br> *in* fc3_experts_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
-|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* attention_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)|
+|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* attention_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* past_sequence_length:**M**<br> *in* cache_indirection:**M**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**<br> *out* qk:**QK**|1+|**QK** = tensor(float), tensor(float16)<br/> **T** = tensor(float), tensor(float16)|
 |NGramRepeatBlock|*in* input_ids:**Tid**<br> *in* scores:**T**<br> *out* scores_out:**T**|1+|**T** = tensor(float)<br/> **Tid** = tensor(int64)|
 |NhwcConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
 |PackedAttention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* token_offset:**M**<br> *in* cumulative_sequence_length:**M**<br> *in* attention_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
@@ -1402,7 +1402,7 @@ Do not modify directly.*
 |GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
 |MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float), tensor(float16)|
 |MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T3**<br> *in* g_idx:**T4**<br> *in* bias:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
-|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* attention_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
+|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* attention_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* past_sequence_length:**M**<br> *in* cache_indirection:**M**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**<br> *out* qk:**QK**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
 |NhwcConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
 |QAttention|*in* input:**T1**<br> *in* weight:**T2**<br> *in* bias:**T3**<br> *in* input_scale:**T3**<br> *in* weight_scale:**T3**<br> *in* mask_index:**T4**<br> *in* input_zero_point:**T1**<br> *in* weight_zero_point:**T2**<br> *in* past:**T3**<br> *out* output:**T3**<br> *out* present:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float), tensor(float16)<br/> **T4** = tensor(int32)|
 |QLinearAdd|*in* A:**T**<br> *in* A_scale:**tensor(float)**<br> *in* A_zero_point:**T**<br> *in* B:**T**<br> *in* B_scale:**tensor(float)**<br> *in* B_zero_point:**T**<br> *in* C_scale:**tensor(float)**<br> *in* C_zero_point:**T**<br> *out* C:**T**|1+|**T** = tensor(int8), tensor(uint8)|
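The `past_sequence_length` input and the max_sequence_length-shaped KV states in the rows above exist to support past-present buffer sharing: the caller allocates the KV cache once at maximum length and the kernel appends in place, with `past_sequence_length` marking the valid prefix. A hedged sketch of exploiting this with ORT's `IOBinding` (the model path and tensor names are assumptions):

```python
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("decoder.onnx", providers=["CPUExecutionProvider"])

# Allocate the KV cache once at max_sequence_length (assumed dims).
batch, heads, max_seq, head_size = 1, 8, 448, 64
kv = ort.OrtValue.ortvalue_from_numpy(
    np.zeros((batch, heads, max_seq, head_size), dtype=np.float32))

binding = sess.io_binding()
binding.bind_ortvalue_input("past_key", kv)      # assumed input name
binding.bind_ortvalue_output("present_key", kv)  # same buffer bound as output
binding.bind_ortvalue_input(
    "past_sequence_length",
    ort.OrtValue.ortvalue_from_numpy(np.array([0], dtype=np.int32)))
# ... bind query/key/value, past/present value, and remaining I/O similarly ...
sess.run_with_iobinding(binding)
```

Each decoding step then only bumps `past_sequence_length` rather than reallocating and copying growing present tensors.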

onnxruntime/contrib_ops/cpu/bert/attention.cc

+1 −1
@@ -335,7 +335,7 @@ Status Attention<T>::Compute(OpKernelContext* context) const {

   // Compute the attention score and apply the score to V
   return ApplyAttention(Q, K, V, mask_index, past, nullptr /* past_key */, nullptr /* past_value */,
-                        output, nullptr /* present_key */, nullptr /* present_value */,
+                        output, nullptr /* present_key */, nullptr /* present_value */, nullptr /* output_qk */,
                         batch_size, sequence_length, sequence_length,
                         parameters.head_size, parameters.v_head_size, parameters.v_hidden_size,
                         attention_bias, context);
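The extra `nullptr /* output_qk */` argument threads `MultiHeadAttention`'s new optional output through the shared `ApplyAttention` helper; the base `Attention` op simply opts out. As a shape reference, here is a small NumPy sketch of the tensor a non-null `output_qk` would hold, assuming "normalized Q * K" means the scaled product (one reading of the spec):

```python
import numpy as np

batch, num_heads, seq_len, total_seq, head_size = 2, 8, 1, 16, 64

Q = np.random.randn(batch, num_heads, seq_len, head_size).astype(np.float32)
K = np.random.randn(batch, num_heads, total_seq, head_size).astype(np.float32)

# (batch_size, num_heads, sequence_length, total_sequence_length),
# matching the qk shape documented in ContribOperators.md above.
qk = (Q @ K.transpose(0, 1, 3, 2)) / np.sqrt(head_size)
assert qk.shape == (batch, num_heads, seq_len, total_seq)

# Per the new QK type constraint, this output may stay in float32 even
# when the rest of the model runs in float16.
```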

onnxruntime/contrib_ops/cpu/bert/attention_base.h

+1
@@ -7,6 +7,7 @@
 #include "core/common/common.h"
 #include "core/framework/op_kernel.h"
 #include "contrib_ops/cpu/bert/attention_common.h"
+#include "contrib_ops/cpu/bert/attention_parameters.h"

 namespace onnxruntime {
 namespace contrib {
