### Description
This PR re-designs how Whisper is exported and supported in ONNX Runtime.
The new solution leverages [previous optimization
work](#15473), and it is
designed to be used in conjunction with [this
work](microsoft/onnxruntime-genai#1229) in ONNX
Runtime GenAI.
Some of the changes include:
- Re-designed export that creates new ONNX models without needing a
`WhisperBeamSearch` op (see the sketch below for how the resulting models fit together)
  - Creates one encoder model that also pre-computes the cross-attention KV caches (since they only need to be calculated once)
  - Creates one decoder model that can be used during pre-fill and token generation
  - Creates one jump-times model that can be used for word-level timestamps
- Removes the need for a `WhisperBeamSearch` op to chain the encoder and
decoder subgraphs
- Removes the need to duplicate the decoder's weights in memory
  - The previous solution with the `WhisperBeamSearch` op created an encoder-decoder-init model and a decoder-with-past model, so the decoder's weights were stored twice, once in each model.
- Removes the need for separate logic to export the PyTorch model coming
from OpenAI vs. the PyTorch model coming from Hugging Face
- Refactors common parameters and logic used in the CPU and CUDA attention
kernels
- Adds `DUMP_STRING` to enable easy logging of intermediate information when
running in debug mode. This info is not printed in release mode, so it will
not impact performance.
- Integrates `DecoderMaskedMultiHeadAttention` into `MultiHeadAttention`
- Enables past-present buffer sharing in the `MultiHeadAttention` op for
improved performance
- Adds `cache_indirection` and `past_sequence_length` as new optional
inputs to `MultiHeadAttention`
- Adds `output_qk` as new optional output to `MultiHeadAttention`
- Enables calculating `output_qk` tensor with FP16 or FP32 precision,
regardless of the model's precision
- Adds CI tests that run end-to-end across various flag combinations that are
used by many customers internally and externally
The existing solutions are still available if desired.
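To make the new model layout concrete, here is a minimal sketch of how the separately exported encoder and decoder could be chained for greedy decoding with plain `onnxruntime` in Python. It is not the ONNX Runtime GenAI implementation: the file names and every input/output name (`audio_features`, `input_ids`, the `past_key_*`/`cross_past_key_*` cache names) are assumptions for illustration, and beam search, buffer sharing, and timestamps are omitted.

```python
# Minimal sketch of chaining the separately exported Whisper encoder and decoder
# ONNX models with greedy decoding. This is NOT the ONNX Runtime GenAI
# implementation; the file names and all input/output names below are
# illustrative assumptions.
import numpy as np
import onnxruntime as ort

encoder = ort.InferenceSession("whisper_encoder.onnx")  # hypothetical file name
decoder = ort.InferenceSession("whisper_decoder.onnx")  # hypothetical file name


def transcribe(mel: np.ndarray, start_tokens: list, eos_token_id: int, max_length: int = 128):
    # Run the encoder once. It also pre-computes the cross-attention KV caches,
    # which never change during decoding.
    enc_out = encoder.run(None, {"audio_features": mel})  # assumed input name
    cross_kv = enc_out[1:]  # assume outputs 1..N are per-layer cross K/V caches

    tokens = list(start_tokens)
    self_kv = []  # self-attention KV caches returned by the decoder

    for _ in range(max_length):
        # The same decoder model handles pre-fill (all start tokens at once)
        # and token generation (one new token per step).
        ids = [tokens] if not self_kv else [[tokens[-1]]]
        feeds = {"input_ids": np.array(ids, dtype=np.int64)}  # assumed input name

        # Placeholder cache input names; the real names come from the export script.
        for i in range(0, len(cross_kv), 2):
            feeds[f"cross_past_key_{i // 2}"] = cross_kv[i]
            feeds[f"cross_past_value_{i // 2}"] = cross_kv[i + 1]
        for i in range(0, len(self_kv), 2):
            feeds[f"past_key_{i // 2}"] = self_kv[i]
            feeds[f"past_value_{i // 2}"] = self_kv[i + 1]

        logits, *self_kv = decoder.run(None, feeds)
        next_token = int(np.argmax(logits[0, -1]))
        tokens.append(next_token)
        if next_token == eos_token_id:
            break
    return tokens
```

In practice, ONNX Runtime GenAI implements this loop for you, adding beam search, past-present buffer sharing, and word-level timestamps via the jump-times model.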
### Known Issues
- The FP32 CPU model with the `WhisperBeamSearch` op and output QK is
currently disabled because ONNX Runtime does not yet support output QK
kernels on CPU, only on CUDA.
- The `DecoderMaskedMultiHeadAttention` CPU kernel has a parity mismatch
with the `DecoderMaskedMultiHeadAttention` CUDA kernel.
- Using `DecoderMaskedMultiHeadAttention` for the FP32 CPU model is not
enabled. That model currently uses `MultiHeadAttention` instead to avoid the
parity mismatch issue.
### Motivation and Context
Using the beam search op has made it more difficult to debug and fix errors
as they are encountered. This new approach is more flexible and more
customizable for users (e.g. by running with ONNX Runtime GenAI). It also
helps address [this
issue](#18216).
---------
Co-authored-by: mindest <[email protected]>
**docs/ContribOperators.md** (+17 -9, unchanged lines collapsed)

```diff
@@ -1191,17 +1191,17 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
 <dd>present state for key with shape (batch_size, num_heads, total_sequence_length, head_size). If past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size), while effective_seq_length = (past_sequence_length + kv_sequence_length).</dd>
 <dt><tt>present_value</tt> (optional) : T</dt>
 <dd>present state for value with shape (batch_size, num_heads, total_sequence_length, head_size). If past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size), while effective_seq_length = (past_sequence_length + kv_sequence_length).</dd>
+<dd>Constrain QK output to float32 or float16 tensors, independent of input type or output type.</dd>
 <dt><tt>M</tt> : tensor(int32)</dt>
 <dd>Constrain mask index to integer types</dd>
 </dl>
@@ -3203,7 +3203,7 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
 <dd>Whether every token can only attend to previous tokens. Default value is 0.</dd>
 </dl>

-#### Inputs (1 - 8)
+#### Inputs (1 - 10)

 <dl>
 <dt><tt>query</tt> : T</dt>
@@ -3219,27 +3219,35 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
 <dt><tt>attention_bias</tt> (optional) : T</dt>
 <dd>bias added to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)</dd>
 <dt><tt>past_key</tt> (optional) : T</dt>
-<dd>past state for self attention key with shape (batch_size, num_heads, past_sequence_length, head_size)</dd>
+<dd>past state for key with shape (batch_size, num_heads, past_sequence_length, head_size) or (batch_size, num_heads, max_sequence_length, head_size) when buffer sharing is used</dd>
 <dt><tt>past_value</tt> (optional) : T</dt>
-<dd>past state for self attention value with shape (batch_size, num_heads, past_sequence_length, head_size)</dd>
+<dd>past state for value with shape (batch_size, num_heads, past_sequence_length, head_size) or (batch_size, num_heads, max_sequence_length, head_size) when buffer sharing is used</dd>
+<dd>A buffer of shape [batch_size, beam_width, max_sequence_length] where an [i, j, k] entry specifies which beam the 'k'-th token came from for the 'j'-th beam for batch 'i' in the current iteration</dd>
 </dl>

-#### Outputs (1 - 3)
+#### Outputs (1 - 4)

 <dl>
 <dt><tt>output</tt> : T</dt>
 <dd>3D output tensor with shape (batch_size, sequence_length, v_hidden_size)</dd>
 <dt><tt>present_key</tt> (optional) : T</dt>
-<dd>present state for cross attention key with shape (batch_size, num_heads, kv_sequence_length, head_size) or present state for self attention key with shape (batch_size, num_heads, total_sequence_length, head_size)</dd>
+<dd>present state for key with shape (batch_size, num_heads, total_sequence_length, head_size) or (batch_size, num_heads, max_sequence_length, head_size) when buffer sharing is used</dd>
 <dt><tt>present_value</tt> (optional) : T</dt>
-<dd>present state for cross attention value with shape (batch_size, num_heads, kv_sequence_length, head_size) or present state for self attention value with shape (batch_size, num_heads, total_sequence_length, head_size)</dd>
+<dd>present state for value with shape (batch_size, num_heads, total_sequence_length, head_size) or (batch_size, num_heads, max_sequence_length, head_size) when buffer sharing is used</dd>
```
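For reference, below is a schematic sketch of how a `com.microsoft.MultiHeadAttention` node could be built with the new optional inputs and output using the `onnx` Python helpers. The positions of `past_sequence_length` and `cache_indirection` after `past_value`, the `output_qk` shape, and the attribute set shown are assumptions based on the PR description and the spec diff above, not a verified signature.

```python
# Schematic sketch (assumed input/output ordering, not a verified spec) of a
# MultiHeadAttention node that exercises the new optional inputs and output.
from onnx import TensorProto, helper

num_heads = 8

mha_node = helper.make_node(
    "MultiHeadAttention",
    inputs=[
        "query",                 # (batch_size, sequence_length, hidden_size)
        "key",
        "value",
        "",                      # bias (omitted optional input)
        "",                      # key_padding_mask (omitted optional input)
        "",                      # attention_bias (omitted optional input)
        "past_key",              # (batch_size, num_heads, max_sequence_length, head_size) with buffer sharing
        "past_value",
        "past_sequence_length",  # new optional input; assumed to be a scalar int32
        "cache_indirection",     # new optional input: [batch_size, beam_width, max_sequence_length]
    ],
    outputs=[
        "output",
        "present_key",           # shares the past_key buffer when buffer sharing is enabled
        "present_value",
        "output_qk",             # new optional output holding the QxK' scores
    ],
    domain="com.microsoft",
    num_heads=num_heads,
)

# output_qk may be float32 or float16 independent of the model's precision;
# the shape below is an illustrative guess, not taken from the spec.
output_qk_info = helper.make_tensor_value_info(
    "output_qk",
    TensorProto.FLOAT,
    ["batch_size", num_heads, "sequence_length", "total_sequence_length"],
)
```

With past-present buffer sharing enabled, `past_key`/`past_value` and `present_key`/`present_value` refer to the same `max_sequence_length`-sized buffers, and `cache_indirection` tells the kernel which beam each cached token came from in the current iteration.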