Skip to content

Whisper Redesigned Solution #1229

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 25 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
c2c8745
Rename Whisper encoder input to audio features
kunal-vaishnavi Sep 30, 2024
1d5f4f0
Initial commit for new export
kunal-vaishnavi Oct 22, 2024
5bf4628
Fix KV cache initialization and runtime bugs
kunal-vaishnavi Nov 2, 2024
3cb936e
Add another check for alignment heads input
kunal-vaishnavi Nov 5, 2024
b648f58
Dump logits in ORT GenAI
kunal-vaishnavi Nov 7, 2024
2a5b762
Fix cross QK update
kunal-vaishnavi Nov 14, 2024
e24db74
Fix finalize cross QK
kunal-vaishnavi Nov 15, 2024
e4c838e
Save checkpoint for working solution
kunal-vaishnavi Nov 15, 2024
3a548a1
Clean up code
kunal-vaishnavi Nov 17, 2024
4d9af67
Remove unneeded template instantiations
kunal-vaishnavi Nov 21, 2024
1d9161d
Fixes: update crossQK copy for first step;
mindest Nov 27, 2024
97be76a
Enable getting model inputs to user
kunal-vaishnavi Dec 4, 2024
1bcd264
Add additional check for cache indirection
kunal-vaishnavi Dec 6, 2024
c35a73d
Add audio processing unit test
kunal-vaishnavi Dec 18, 2024
1d5da61
Fix Whisper GenAI config
kunal-vaishnavi Dec 18, 2024
efd0199
Save checkpoint for working solution
kunal-vaishnavi Dec 21, 2024
fbebe68
Merge branch 'main' into kvaishnavi/whisper
kunal-vaishnavi Dec 21, 2024
ef955e7
Merge branch 'main' into kvaishnavi/whisper
kunal-vaishnavi Feb 5, 2025
e869d02
Squashed commit of the following:
kunal-vaishnavi Feb 6, 2025
32c48d2
Initial changes to work with main
kunal-vaishnavi Feb 17, 2025
e4a8b5f
Merge branch 'main' into kvaishnavi/whisper
kunal-vaishnavi Feb 17, 2025
323028a
Merge branch 'main' into kvaishnavi/whisper
kunal-vaishnavi Mar 24, 2025
7756a86
Resolving build errors after merging main
kunal-vaishnavi Mar 25, 2025
a167add
Fix prompt length and get input
kunal-vaishnavi Mar 28, 2025
8782b47
Merge branch 'main' into kvaishnavi/whisper
kunal-vaishnavi Mar 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
232 changes: 163 additions & 69 deletions src/config.cpp

Large diffs are not rendered by default.

71 changes: 46 additions & 25 deletions src/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,27 +21,32 @@ struct Config {
static constexpr std::string_view PresentKeyName = "present.%d.key";
static constexpr std::string_view PresentValueName = "present.%d.value";

static constexpr std::string_view InputsEmbedsName = "inputs_embeds";
static constexpr std::string_view CurrentSequenceLengthName = "current_sequence_length";
static constexpr std::string_view PastSequenceLengthName = "past_sequence_length";
static constexpr std::string_view promptTemplate = "{Content}";
static constexpr std::string_view TotalSequenceLengthName = "total_sequence_length";
// Speech encoder names
static constexpr std::string_view AudioAttentionMaskName = "audio_attention_mask";
static constexpr std::string_view AudioSizesName = "audio_sizes";
static constexpr std::string_view AudioProjectionModeName = "audio_projection_mode";
static constexpr std::string_view AudioFeaturesName = "audio_features";
static constexpr std::string_view EncoderHiddenStatesName = "encoder_hidden_states";
static constexpr std::string_view NumAudioTokens = "num_audio_tokens";

// Vision names
// Vision encoder names
static constexpr std::string_view PixelValuesName = "pixel_values";
static constexpr std::string_view ImageSizesName = "image_sizes";
static constexpr std::string_view ImageFeaturesName = "image_features";
static constexpr std::string_view ImageAttentionMaskName = "image_attention_mask";
static constexpr std::string_view ImageFeaturesName = "image_features";
static constexpr std::string_view NumImageTokens = "num_image_tokens";

// Speech names
static constexpr std::string_view InputFeaturesName = "encoder_input_ids";
// Embedding names
static constexpr std::string_view AudioEmbedsName = "audio_embeds";
static constexpr std::string_view AudioAttentionMaskName = "audio_attention_mask";
static constexpr std::string_view AudioSizesName = "audio_sizes";
static constexpr std::string_view AudioProjectionModeName = "audio_projection_mode";
static constexpr std::string_view AudioFeaturesName = "audio_features";
static constexpr std::string_view NumAudioTokens = "num_audio_tokens";
static constexpr std::string_view InputsEmbedsName = "inputs_embeds";

// Generation names
static constexpr std::string_view PastSequenceLengthName = "past_sequence_length";
static constexpr std::string_view CurrentSequenceLengthName = "current_sequence_length";
static constexpr std::string_view TotalSequenceLengthName = "total_sequence_length";
static constexpr std::string_view CacheIndirectionName = "cache_indirection";

static constexpr std::string_view PromptTemplateName = "{Content}";
};

fs::path config_path; // Path of the config directory
Expand Down Expand Up @@ -86,14 +91,28 @@ struct Config {
int vocab_size{};
int context_length{};

// For models like whisper
struct EncoderDecoderInit {
struct Encoder {
std::string filename;
SessionOptions session_options;

int hidden_size{};
int num_attention_heads{};
int num_hidden_layers{};
int head_size{};

struct Inputs {
std::string input_features{Defaults::InputFeaturesName};
std::string input_ids{Defaults::InputIdsName};
std::string embeddings{Defaults::InputsEmbedsName};
std::string attention_mask{Defaults::AttentionMaskName};
std::string position_ids{Defaults::PositionIdsName};
std::string audio_features{Defaults::AudioFeaturesName};
} inputs;
} encoder_decoder_init;

// Names of the outputs produced by this model section. Each member's
// initializer is its default value.
struct Outputs {
// Output name for the hidden states; defaults to Defaults::EncoderHiddenStatesName.
std::string hidden_states{Defaults::EncoderHiddenStatesName};
// Name patterns for the cross-attention present key/value outputs. "%d" is a
// printf-style placeholder — presumably the layer index; confirm against the
// code that formats these names.
std::string cross_present_key_names{"present_key_cross_%d"}, cross_present_value_names{"present_value_cross_%d"};
} outputs;
} encoder;

struct Embedding {
std::string filename;
Expand Down Expand Up @@ -163,23 +182,25 @@ struct Config {
struct Inputs {
std::string input_ids{Defaults::InputIdsName};
std::string embeddings{Defaults::InputsEmbedsName};
std::string position_ids{Defaults::PositionIdsName};
std::string attention_mask{Defaults::AttentionMaskName};
std::string position_ids{Defaults::PositionIdsName};
std::string past_key_names{Defaults::PastKeyName};
std::string past_value_names{Defaults::PastValueName};
std::string past_names; // When key/value pairs are combined
std::string cross_past_key_names, cross_past_value_names;
std::string current_sequence_length{Defaults::CurrentSequenceLengthName};

std::string past_sequence_length{Defaults::PastSequenceLengthName};
std::string current_sequence_length{Defaults::CurrentSequenceLengthName};
std::string total_sequence_length{Defaults::TotalSequenceLengthName};
std::string cache_indirection{Defaults::CacheIndirectionName};
} inputs;

// Names of the outputs produced by this model section. Each member's
// initializer is its default value; members without an initializer start empty.
struct Outputs {
// Output name for the logits; defaults to Defaults::LogitsName.
std::string logits{Defaults::LogitsName};
// Name patterns for the per-layer present key/value outputs; they default to
// Defaults::PresentKeyName / Defaults::PresentValueName, which contain a "%d" placeholder.
std::string present_key_names{Defaults::PresentKeyName};
std::string present_value_names{Defaults::PresentValueName};
std::string present_names; // When key/value pairs are combined
// Cross-attention present key/value output names; no default is set here.
std::string cross_present_key_names, cross_present_value_names;
// Name pattern for the cross-QK outputs. "%d" is a printf-style placeholder —
// presumably the layer index; confirm against the code that formats it.
std::string output_cross_qk_names{"output_cross_qk_%d"};
} outputs;

struct PipelineModel {
Expand All @@ -203,10 +224,10 @@ struct Config {
} decoder;

struct PromptTemplates {
std::string assistant{Defaults::promptTemplate};
std::string prompt{Defaults::promptTemplate};
std::string system{Defaults::promptTemplate};
std::string user{Defaults::promptTemplate};
std::string assistant{Defaults::PromptTemplateName};
std::string prompt{Defaults::PromptTemplateName};
std::string system{Defaults::PromptTemplateName};
std::string user{Defaults::PromptTemplateName};
};
std::optional<PromptTemplates> prompt_templates;
} model;
Expand Down
20 changes: 14 additions & 6 deletions src/cuda/interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,23 +142,31 @@ struct CudaInterfaceImpl final : DeviceInterface {
return true;
}

void LaunchHandleEOSArray(float* batch_logits, int batch_beam_size, int vocab_size, const int32_t* eos_token_ids, int eos_token_ids_count) override {
// DeviceInterface override: forwards every argument unchanged to
// cuda::LaunchHandleEOSArray, adding the stream returned by GetStream() as the
// final argument.
void HandleEOSArray(float* batch_logits, int batch_beam_size, int vocab_size, const int32_t* eos_token_ids, int eos_token_ids_count) override {
cuda::LaunchHandleEOSArray(batch_logits, batch_beam_size, vocab_size, eos_token_ids, eos_token_ids_count, GetStream());
}

void UpdateCacheIndirectionKernelLauncher(int32_t* tgt_indir_cache, const int32_t* src_indir_cache, const int32_t* beam_ids, int batch_size, int beam_width, int input_seq_length, int max_seq_length, int current_length) override {
// DeviceInterface override: forwards every argument unchanged to
// cuda::UpdateCacheIndirectionKernelLauncher, adding the stream returned by
// GetStream() as the final argument.
void UpdateCacheIndirection(int32_t* tgt_indir_cache, const int32_t* src_indir_cache, const int32_t* beam_ids, int batch_size, int beam_width, int input_seq_length, int max_seq_length, int current_length) override {
cuda::UpdateCacheIndirectionKernelLauncher(tgt_indir_cache, src_indir_cache, beam_ids, batch_size, beam_width, input_seq_length, max_seq_length, current_length, GetStream());
}

void ReorderPastStatesKernelLauncher(void* out_buffer, const void* in_buffer, int batch_size, int num_heads, int max_length, int head_size, int chunk_size) override {
// DeviceInterface override: forwards every argument unchanged to
// cuda::ReorderPastStatesKernelLauncher, adding the stream returned by
// GetStream() as the final argument.
void ReorderPastStates(void* out_buffer, const void* in_buffer, int batch_size, int num_heads, int max_length, int head_size, int chunk_size) override {
cuda::ReorderPastStatesKernelLauncher(out_buffer, in_buffer, batch_size, num_heads, max_length, head_size, chunk_size, GetStream());
}

void LaunchCopyCrossQKSingleDecodeStep(float* cross_qk_buffer_data, float** qk_layer_pointers, int token_index, int batch_beam_size, int num_layers, int num_heads, int num_alignment_heads, const int* alignment_heads, int frames, int max_length) override {
cuda::LaunchCopyCrossQKSingleDecodeStep(GetStream(), cross_qk_buffer_data, qk_layer_pointers, token_index, batch_beam_size, num_layers, num_heads, num_alignment_heads, alignment_heads, frames, max_length);
// float overload of CopyCrossQK: forwards to cuda::LaunchCopyCrossQKSingleDecodeStep
// with the stream from GetStream() as the first argument and the remaining
// arguments in their original order. The element type T of the launcher is
// deduced from cross_qk_buffer_data (float here); qk_layer_pointers is passed
// through as an untyped void**.
void CopyCrossQK(float* cross_qk_buffer_data, void** qk_layer_pointers, int token_index, int batch_beam_size, int num_layers, int num_heads, int num_alignment_heads, const int* alignment_heads, int frames, int max_length, int sequence_length) override {
cuda::LaunchCopyCrossQKSingleDecodeStep(GetStream(), cross_qk_buffer_data, qk_layer_pointers, token_index, batch_beam_size, num_layers, num_heads, num_alignment_heads, alignment_heads, frames, max_length, sequence_length);
}

void LaunchFinalizeCrossQK(int iteration_number, int context_decoding_len, int batch_size, int num_beams, int max_length, int num_alignment_heads, int frames_of_k, const float* cross_qk_buffer_data, float* cross_qk_output, int num_return_sequences, const int* cache_indir_data) override {
// Ort::Float16_t overload of CopyCrossQK: forwards to
// cuda::LaunchCopyCrossQKSingleDecodeStep with the stream from GetStream() as
// the first argument and the remaining arguments in their original order.
// NOTE(review): the launcher's element type T is deduced from
// cross_qk_buffer_data, i.e. Ort::Float16_t. The launcher is a template that is
// explicitly instantiated in the .cu file — confirm an instantiation exists for
// this exact type (or that it is the same type as CUDA `half`), otherwise this
// call will not link.
void CopyCrossQK(Ort::Float16_t* cross_qk_buffer_data, void** qk_layer_pointers, int token_index, int batch_beam_size, int num_layers, int num_heads, int num_alignment_heads, const int* alignment_heads, int frames, int max_length, int sequence_length) override {
cuda::LaunchCopyCrossQKSingleDecodeStep(GetStream(), cross_qk_buffer_data, qk_layer_pointers, token_index, batch_beam_size, num_layers, num_heads, num_alignment_heads, alignment_heads, frames, max_length, sequence_length);
}

// float overload of FinalizeCrossQK: forwards to cuda::LaunchFinalizeCrossQK with
// the stream from GetStream() as the first argument and the remaining arguments
// in their original order.
// NOTE(review): the value passed here as batch_size lands in the launcher's
// third integer parameter — confirm whether the launcher expects a plain batch
// size or a batch*beam count, and that callers pass the matching value.
void FinalizeCrossQK(int iteration_number, int context_decoding_len, int batch_size, int num_beams, int max_length, int num_alignment_heads, int frames_of_k, const float* cross_qk_buffer_data, float* cross_qk_output, int num_return_sequences, const int* cache_indir_data) override {
cuda::LaunchFinalizeCrossQK(GetStream(), iteration_number, context_decoding_len, batch_size, num_beams, max_length, num_alignment_heads, frames_of_k, cross_qk_buffer_data, cross_qk_output, num_return_sequences, cache_indir_data);
}

// Ort::Float16_t overload of FinalizeCrossQK: forwards to
// cuda::LaunchFinalizeCrossQK with the stream from GetStream() as the first
// argument and the remaining arguments in their original order.
// NOTE(review): the launcher's element type T is deduced from the buffer
// pointers, i.e. Ort::Float16_t. The launcher is a template that is explicitly
// instantiated in the .cu file — confirm an instantiation exists for this exact
// type (or that it is the same type as CUDA `half`), otherwise this call will
// not link.
void FinalizeCrossQK(int iteration_number, int context_decoding_len, int batch_size, int num_beams, int max_length, int num_alignment_heads, int frames_of_k, const Ort::Float16_t* cross_qk_buffer_data, Ort::Float16_t* cross_qk_output, int num_return_sequences, const int* cache_indir_data) override {
cuda::LaunchFinalizeCrossQK(GetStream(), iteration_number, context_decoding_len, batch_size, num_beams, max_length, num_alignment_heads, frames_of_k, cross_qk_buffer_data, cross_qk_output, num_return_sequences, cache_indir_data);
}
};
Expand Down
7 changes: 4 additions & 3 deletions src/cuda/kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,21 +36,22 @@ void UpdateCacheIndirectionKernelLauncher(int32_t* tgt_indir_cache,
template <typename T>
void LaunchCopyCrossQKSingleDecodeStep(cudaStream_t stream,
T* cross_qk_buffer_data,
T** qk_layer_pointers,
void** qk_layer_pointers,
int token_index,
int batch_beam_size,
int num_layers,
int num_heads,
int num_alignment_heads,
const int* alignment_heads,
int frames,
int max_length);
int max_length,
int sequence_length);

template <typename T>
void LaunchFinalizeCrossQK(cudaStream_t stream,
int iteration_number,
int context_decoding_len,
int batch_size,
int batch_beam_size,
int num_beams,
int max_length,
int num_alignment_heads,
Expand Down
59 changes: 46 additions & 13 deletions src/cuda/model_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,8 @@ __global__ void CopyCrossQKSingleDecodeStepKernel(T* target, // shape [batch_be
int num_heads,
const int* alignment_heads,
int frames,
int max_length) {
int max_length,
int sequence_length) {
const int pair = blockIdx.x;
const int num_alignment_heads = gridDim.x;
const int bbm = blockIdx.y;
Expand All @@ -280,49 +281,68 @@ __global__ void CopyCrossQKSingleDecodeStepKernel(T* target, // shape [batch_be
const int head = *(alignment_heads + 1);

target += ((int64_t)bbm * num_alignment_heads + pair) * max_length * frames + ((int64_t)token_index * frames);
T* src = qk_layer_pointers[layer] + ((int64_t)bbm * num_heads + head) * frames;
T* src = reinterpret_cast<T*>(qk_layer_pointers[layer]) + ((int64_t)bbm * num_heads + head) * sequence_length * frames;

for (int tid = threadIdx.x; tid < frames; tid += blockDim.x) {
target[tid] = src[tid]; // use vectorized read write in future if needed
for (int i = 1; i < sequence_length; i++) {
target[i * frames + tid] = src[i * frames + tid];
}
}
}

template <typename T>
void LaunchCopyCrossQKSingleDecodeStep(cudaStream_t stream,
T* cross_qk_buffer_data,
T** qk_layer_pointers,
void** qk_layer_pointers,
int token_index,
int batch_beam_size,
int num_layers,
int num_heads,
int num_alignment_heads,
const int* alignment_heads,
int frames,
int max_length) {
int max_length,
int sequence_length) {
dim3 block(512);
dim3 grid(num_alignment_heads, batch_beam_size);

CopyCrossQKSingleDecodeStepKernel<<<grid, block, 0, stream>>>(cross_qk_buffer_data,
qk_layer_pointers,
reinterpret_cast<T**>(qk_layer_pointers),
token_index,
num_layers,
num_heads,
alignment_heads,
frames,
max_length);
max_length,
sequence_length);
}

template void LaunchCopyCrossQKSingleDecodeStep(cudaStream_t stream,
float* cross_qk_buffer_data,
float** qk_layer_pointers,
void** qk_layer_pointers,
int token_index,
int batch_beam_size,
int num_layers,
int num_heads,
int num_alignment_heads,
const int* alignment_heads,
int frames,
int max_length,
int sequence_length);

template void LaunchCopyCrossQKSingleDecodeStep(cudaStream_t stream,
half* cross_qk_buffer_data,
void** qk_layer_pointers,
int token_index,
int batch_beam_size,
int num_layers,
int num_heads,
int num_alignment_heads,
const int* alignment_heads,
int frames,
int max_length);
int max_length,
int sequence_length);

template <typename T>
__global__ void CopyDecoderCrossQKAllStepsKernel(int context_decoding_len,
Expand All @@ -341,7 +361,7 @@ __global__ void CopyDecoderCrossQKAllStepsKernel(int context_decoding_len,
const int batch = br / num_return_sequences;
const int ret_seq_id = br % num_return_sequences;

const int64_t offset_in_cache = ((int64_t)batch * num_return_sequences + ret_seq_id) * max_length + token_decoding_index + context_decoding_len;
const int64_t offset_in_cache = ((int64_t)batch * num_return_sequences + ret_seq_id) * max_length + token_decoding_index;
int bi_src = batch * num_beams + cache_indir_data[offset_in_cache];

T* target = cross_qk_output + (((int64_t)br * num_alignment_heads + (int64_t)pair) * total_decoding_length + token_decoding_index) * frames_of_k;
Expand All @@ -355,7 +375,7 @@ template <typename T>
void LaunchFinalizeCrossQK(cudaStream_t stream,
int iteration_number,
int context_decoding_len,
int batch_size,
int batch_beam_size,
int num_beams,
int max_length,
int num_alignment_heads,
Expand All @@ -364,10 +384,10 @@ void LaunchFinalizeCrossQK(cudaStream_t stream,
T* cross_qk_output,
int num_return_sequences,
const int* cache_indir_data) {
int64_t br = (int64_t)batch_size * num_return_sequences;
int64_t br = (int64_t)batch_beam_size;
assert(br < 65536L && num_alignment_heads < 65536);

const int total_decoding_length = iteration_number - 1;
const int total_decoding_length = iteration_number;
dim3 block(512);
dim3 grid(total_decoding_length, num_alignment_heads, (unsigned)br);

Expand All @@ -384,7 +404,7 @@ void LaunchFinalizeCrossQK(cudaStream_t stream,
template void LaunchFinalizeCrossQK(cudaStream_t stream,
int iteration_number,
int context_decoding_len,
int batch_size,
int batch_beam_size,
int num_beams,
int max_length,
int num_alignment_heads,
Expand All @@ -394,5 +414,18 @@ template void LaunchFinalizeCrossQK(cudaStream_t stream,
int num_return_sequences,
const int* cache_indir_data);

// Explicit instantiation of LaunchFinalizeCrossQK for T = half (deduced from the
// cross_qk_buffer_data / cross_qk_output pointer types), so the half version is
// emitted from this translation unit.
template void LaunchFinalizeCrossQK(cudaStream_t stream,
int iteration_number,
int context_decoding_len,
int batch_beam_size,
int num_beams,
int max_length,
int num_alignment_heads,
int frames_of_k,
const half* cross_qk_buffer_data,
half* cross_qk_output,
int num_return_sequences,
const int* cache_indir_data);

} // namespace cuda
} // namespace Generators
Loading
Loading