Whisper Redesigned Solution #23549
base: main
Conversation
return model


def fix_past_sequence_length(model: ModelProto):
Check notice: Code scanning / CodeQL
Explicit returns mixed with implicit (fall through) returns (Note)
onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py (dismissed)
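A minimal sketch of how this note could be resolved by making the fall-through path return explicitly. The function body below is hypothetical and only illustrates the pattern, not the real logic in whisper_decoder.py:

```python
from onnx import ModelProto


def fix_past_sequence_length(model: ModelProto) -> ModelProto:
    # Hypothetical body for illustration only.
    for graph_input in model.graph.input:
        if graph_input.name == "past_sequence_length":
            # ... adjust the input as needed ...
            return model
    # An explicit return on the fall-through path satisfies the CodeQL rule and
    # makes the contract obvious: the (possibly unmodified) model is always returned.
    return model
```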
diff = np.abs(pt_outputs[i] - ort_outputs[i])
logger.warning(f"Comparing {output_name}...")
logger.warning(f"Max diff: {np.max(diff)}")
except:  # noqa: E722
Check notice: Code scanning / CodeQL
Except block handles 'BaseException' (Note)
onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py (two alerts, dismissed)
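A hedged sketch of how the bare `except:` above could be narrowed so that `KeyboardInterrupt` and `SystemExit` are not swallowed. The helper name and fallback message are assumptions; only the comparison lines come from the diff:

```python
import logging

import numpy as np

logger = logging.getLogger(__name__)


def compare_outputs(output_name: str, pt_output: np.ndarray, ort_output: np.ndarray) -> None:
    """Hypothetical helper mirroring the parity check in the diff above."""
    try:
        diff = np.abs(pt_output - ort_output)
        logger.warning(f"Comparing {output_name}...")
        logger.warning(f"Max diff: {np.max(diff)}")
    except Exception:
        # Catching Exception rather than using a bare except lets Ctrl+C and
        # interpreter shutdown propagate while still tolerating shape/dtype mismatches.
        logger.warning(f"Could not compare {output_name}")
```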
@@ -0,0 +1,195 @@
import numpy as np

import onnxruntime as ort

Check notice: Code scanning / CodeQL
Module is imported with 'import' and 'import from' (Note, test)
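A brief sketch of the consistent-import style the note points toward, assuming the test module currently mixes `import onnxruntime as ort` with `from onnxruntime import ...`:

```python
# Pick one import style per module: keep the aliased module import and reference
# members through it, instead of also importing names with `from onnxruntime import ...`.
import numpy as np
import onnxruntime as ort

session_options = ort.SessionOptions()  # instead of `from onnxruntime import SessionOptions`
dummy_input = np.zeros((1, 80, 3000), dtype=np.float32)  # shape is illustrative only
```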
@@ -3,6 +3,7 @@
#include "contrib_ops/cpu/bert/attention_base.h"
#include "contrib_ops/cpu/bert/multihead_attention_helper.h"
#include "contrib_ops/cpu/utils/dump_tensor.h"

Reviewer comment: This include might not be needed. I did not see code in this file that dumps tensors.
auto* tp = context->GetOperatorThreadPool();

int total_sequence_length = past_sequence_length + 1;
Reviewer comment: In general, total_sequence_length = past_sequence_length + sequence_length, so this function assumes sequence_length is 1? Maybe add a comment that it is for decoder masked MHA.
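A tiny illustrative sketch (names are mine, not the kernel's) of the relationship the comment describes: hard-coding `past_sequence_length + 1` encodes the decoder-masked-MHA assumption that exactly one new token is processed per step:

```python
def total_sequence_length(past_sequence_length: int, sequence_length: int = 1) -> int:
    # For decoder masked MHA during token-by-token generation, sequence_length == 1,
    # which reduces this to past_sequence_length + 1 as in the kernel code above.
    return past_sequence_length + sequence_length


assert total_sequence_length(past_sequence_length=31) == 32  # decoding one token
assert total_sequence_length(8, 4) == 12                     # general (multi-token) case
```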
if (output_qk != nullptr) {
  const ptrdiff_t attention_probs_size = SafeInt<ptrdiff_t>(batch_size * num_heads_ * sequence_length * total_sequence_length);
  const ptrdiff_t attention_probs_bytes = attention_probs_size * sizeof(T);
  memcpy(output_qk, attention_probs, attention_probs_bytes);
Reviewer comment: Is this duplicated, since line 316 has already copied the data?
// If optional outputs aren't needed, present_key, present_value, and output_qk will be null
std::vector<int64_t> present_key_shape({static_cast<int64_t>(batch_size),
                                        static_cast<int64_t>(num_heads_),
                                        static_cast<int64_t>(parameters.max_sequence_length),
Reviewer comment: I understand that this is to support a shared buffer for past and present. Previously, this dimension was the total sequence length (when there is no shared buffer). Is this change compatible with that, or is max_sequence_length set to the total sequence length when there is no shared buffer?
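A hedged sketch of the compatibility question, with all names hypothetical: size the present KV to max_sequence_length only when past and present share a buffer, and fall back to total_sequence_length (the previous convention) otherwise:

```python
def present_kv_shape(batch_size: int,
                     num_heads: int,
                     head_size: int,
                     total_sequence_length: int,
                     max_sequence_length: int,
                     past_present_share_buffer: bool) -> tuple[int, int, int, int]:
    # With buffer sharing, present reuses the preallocated max-length buffer; without it,
    # keep the previous convention of sizing to total_sequence_length (equivalently,
    # set max_sequence_length = total_sequence_length in that case).
    seq_dim = max_sequence_length if past_present_share_buffer else total_sequence_length
    return (batch_size, num_heads, seq_dim, head_size)
```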
)
if reshape_qkv_2_path is None:
    if reshape_qkv_path[-1].input[0] != matmul_qkv.output[0]:
Reviewer comment: reshape_qkv_2_path might be None.
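A small illustrative guard for the hazard the comment flags, assuming the pattern-match helper returns None when no match is found. The helper below is hypothetical and is not the fusion code itself:

```python
from typing import Optional

from onnx import NodeProto


def path_feeds_from(path: Optional[list], producer: NodeProto) -> bool:
    """Check whether the last node of a matched path consumes `producer`'s output,
    treating a failed match (None) explicitly instead of indexing it."""
    if path is None:
        return False
    return path[-1].input[0] == producer.output[0]
```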
@@ -541,6 +541,7 @@ def create_multihead_attention_node(
    output: str,
    key_padding_mask: str = "",
    add_qk: str = "",
    unidirectional: bool = False,
Reviewer comment: Please check whether this function is called in derived classes, and update the references if needed.
# From wheel:
$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --precision fp16 --provider cuda --use_gpu --use_external_data_format --optimize_onnx --no_beam_search_op --output_cross_qk
```

## Exporting Whisper with Beam Search
Reviewer comment: Is this still supported, deprecated, or should this variant be removed?
  rm -rf wtiny-fp16-cuda-oai ; \
  popd ; \
  '
displayName: 'Test Whisper export flag combinations'
Reviewer comment: Do you have, or plan to add, any functional tests that verify the forward-pass output of these converted models actually matches what is expected?
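A hedged sketch of the kind of functional check being asked about: run a converted model with ONNX Runtime and compare its outputs against precomputed reference outputs (for example, from the PyTorch model) within a tolerance. The model path, feeds, and tolerances are placeholders, not values from the pipeline:

```python
import numpy as np
import onnxruntime as ort


def check_parity(model_path: str,
                 feeds: dict[str, np.ndarray],
                 expected: dict[str, np.ndarray],
                 rtol: float = 1e-3,
                 atol: float = 1e-3) -> None:
    """Compare ONNX Runtime outputs for `model_path` against reference outputs."""
    session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
    output_names = [o.name for o in session.get_outputs()]
    results = session.run(output_names, feeds)
    for name, value in zip(output_names, results):
        if name in expected:
            np.testing.assert_allclose(value, expected[name], rtol=rtol, atol=atol)
```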
Description
This PR re-designs how Whisper is created and supported in ONNX Runtime. The new solution leverages previous optimization work, and it is designed to be used in conjunction with this work in ONNX Runtime GenAI.
Some of the added changes include:
- Re-designed export of the encoder and decoder models without the `WhisperBeamSearch` op. Previously, the `WhisperBeamSearch` op was used to chain the encoder and decoder subgraphs, and exporting with the `WhisperBeamSearch` op created an encoder-decoder-init model and a decoder-with-past model, so the decoder was duplicated twice, one in each.
- Added `DUMP_STRING` to enable easy logging of intermediate information when running in debug mode to debug a problem. This info is not printed in release mode so it will not impact performance.
- Integrated `DecoderMaskedMultiHeadAttention` into the `MultiHeadAttention` op for improved performance.
- Added `cache_indirection` and `past_sequence_length` as new optional inputs to `MultiHeadAttention`.
- Added `output_qk` as a new optional output of `MultiHeadAttention`.
- Ability to produce the `output_qk` tensor with FP16 or FP32 precision, regardless of the model's precision.

The existing solutions are still available if desired.
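A hedged sketch of how the separately exported encoder and decoder models could be inspected and driven with plain ONNX Runtime sessions; the file names are placeholders, and the intended consumer of these models is ONNX Runtime GenAI as noted above:

```python
import onnxruntime as ort

# Placeholder paths for models produced by convert_to_onnx with --no_beam_search_op;
# the actual file names depend on the export settings used.
encoder = ort.InferenceSession("whisper_encoder.onnx", providers=["CPUExecutionProvider"])
decoder = ort.InferenceSession("whisper_decoder.onnx", providers=["CPUExecutionProvider"])

# Inspect the exported I/O instead of hard-coding names, since the redesigned export
# defines the graph signatures (KV caches, cache_indirection, past_sequence_length, etc.).
print([i.name for i in encoder.get_inputs()])
print([i.name for i in decoder.get_inputs()])
print([o.name for o in decoder.get_outputs()])
```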
Known Issues
- Exporting with the `WhisperBeamSearch` op and output QK is currently disabled. This is because ONNX Runtime doesn't currently support output QK kernels on CPU, only on CUDA.
- There is an issue involving `Neg --> Shape` in the jump times model when exporting the model to contain the `WhisperBeamSearch` op.
- The `DecoderMaskedMultiHeadAttention` CPU kernel has a parity mismatch with the `DecoderMaskedMultiHeadAttention` CUDA kernel.
- `DecoderMaskedMultiHeadAttention` for the FP32 CPU model is not enabled. Currently, it uses `MultiHeadAttention` to avoid the parity mismatch issue.

Motivation and Context
Using the beam search op has made it more difficult to debug and fix errors that are encountered. This new approach is more flexible and more customizable for users (e.g. by running with ONNX Runtime GenAI). It also helps this issue.