diff --git a/src/whisper.cpp b/src/whisper.cpp index cb887d4593b..af33e5478e7 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ -868,6 +868,11 @@ struct whisper_aheads_masks { ggml_backend_buffer_t buffer = nullptr; }; +struct vad_time_mapping { + int64_t processed_time; // Time in processed (VAD) audio + int64_t original_time; // Corresponding time in original audio +}; + struct whisper_state { int64_t t_sample_us = 0; int64_t t_encode_us = 0; @@ -957,13 +962,15 @@ struct whisper_state { whisper_vad_context * vad_context = nullptr; struct vad_segment_info { - float orig_start; - float orig_end; - float vad_start; - float vad_end; + int64_t orig_start; + int64_t orig_end; + int64_t vad_start; + int64_t vad_end; }; std::vector vad_segments; bool has_vad_segments = false; + + std::vector vad_mapping_table; }; struct whisper_context { @@ -4420,8 +4427,8 @@ struct whisper_vad_model { }; struct whisper_vad_segment { - float start; // Start time in seconds - float end; // End time in seconds + int64_t start; + int64_t end; }; struct whisper_vad_segments { @@ -4469,6 +4476,15 @@ struct whisper_vad_params whisper_vad_default_params(void) { return result; } +// Time conversion utility functions for whisper VAD +static int cs_to_samples(int64_t cs) { + return (int)((cs / 100.0) * WHISPER_SAMPLE_RATE + 0.5); +} + +static int64_t samples_to_cs(int samples) { + return (int64_t)((samples / (double)WHISPER_SAMPLE_RATE) * 100.0 + 0.5); +} + static bool weight_buft_supported(const whisper_vad_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) { bool op_supported = true; @@ -5413,12 +5429,12 @@ struct whisper_vad_segments * whisper_vad_segments_from_probs( (speeches[i].end + speech_pad_samples) : audio_length_samples; } - // Convert from samples to seconds and copy to final segments - segments[i].start = (float)speeches[i].start / sample_rate; - segments[i].end = (float)speeches[i].end / sample_rate; + // Convert from samples to centiseconds + segments[i].start = samples_to_cs(speeches[i].start); + segments[i].end = samples_to_cs(speeches[i].end); WHISPER_LOG_INFO("%s: VAD segment %d: start = %.2f, end = %.2f (duration: %.2f)\n", - __func__, i, segments[i].start, segments[i].end, segments[i].end - segments[i].start); + __func__, i, segments[i].start/100.0, segments[i].end/100.0, (segments[i].end - segments[i].start)/100.0); } whisper_vad_segments * vad_segments = new whisper_vad_segments; @@ -6615,10 +6631,13 @@ static bool whisper_vad( struct whisper_full_params params, const float * samples, int n_samples, - std::vector & filtered_samples, - int & filtered_n_samples) { - WHISPER_LOG_INFO("%s: VAD is enabled, processing speach segments only\n", __func__); - filtered_n_samples = 0; + std::vector & filtered_samples) { + WHISPER_LOG_INFO("%s: VAD is enabled, processing speech segments only\n", __func__); + int filtered_n_samples = 0; + + // Clear any existing mapping table + state->vad_mapping_table.clear(); + state->has_vad_segments = false; if (state->vad_context == nullptr) { struct whisper_vad_context_params vad_ctx_params = whisper_vad_default_context_params(); @@ -6640,13 +6659,17 @@ static bool whisper_vad( ctx->state->vad_segments.clear(); ctx->state->vad_segments.reserve(vad_segments->data.size()); + // Initialize the time mapping table + state->vad_mapping_table.clear(); + state->vad_mapping_table.reserve(vad_segments->data.size() * 4); + WHISPER_LOG_INFO("%s: detected %d speech segments\n", __func__, (int)vad_segments->data.size()); float overlap_seconds = vad_params.samples_overlap; int overlap_samples = overlap_seconds * WHISPER_SAMPLE_RATE; for (int i = 0; i < (int)vad_segments->data.size(); i++) { - int segment_start_samples = vad_segments->data[i].start * WHISPER_SAMPLE_RATE; - int segment_end_samples = vad_segments->data[i].end * WHISPER_SAMPLE_RATE; + int segment_start_samples = cs_to_samples(vad_segments->data[i].start); + int segment_end_samples = cs_to_samples(vad_segments->data[i].end); if (i < (int)vad_segments->data.size() - 1) { segment_end_samples += overlap_samples; @@ -6655,9 +6678,9 @@ static bool whisper_vad( filtered_n_samples += (segment_end_samples - segment_start_samples); WHISPER_LOG_INFO("%s: Including segment %d: %.2f - %.2f (duration: %.2f)\n", - __func__, i, vad_segments->data[i].start, - vad_segments->data[i].end + (i < (int)vad_segments->data.size() - 1 ? overlap_seconds : 0), - (vad_segments->data[i].end - vad_segments->data[i].start) + + __func__, i, vad_segments->data[i].start/100.0, + (vad_segments->data[i].end/100.0 + (i < (int)vad_segments->data.size() - 1 ? overlap_seconds : 0)), + (vad_segments->data[i].end - vad_segments->data[i].start)/100.0 + (i < (int)vad_segments->data.size() - 1 ? overlap_seconds : 0)); } @@ -6679,8 +6702,8 @@ static bool whisper_vad( int offset = 0; for (int i = 0; i < (int)vad_segments->data.size(); i++) { - int segment_start_samples = vad_segments->data[i].start * WHISPER_SAMPLE_RATE; - int segment_end_samples = vad_segments->data[i].end * WHISPER_SAMPLE_RATE; + int segment_start_samples = cs_to_samples(vad_segments->data[i].start); + int segment_end_samples = cs_to_samples(vad_segments->data[i].end); if (i < (int)vad_segments->data.size() - 1) { segment_end_samples += overlap_samples; @@ -6689,18 +6712,47 @@ static bool whisper_vad( segment_start_samples = std::min(segment_start_samples, n_samples - 1); segment_end_samples = std::min(segment_end_samples, n_samples); int segment_length = segment_end_samples - segment_start_samples; - if (segment_length > 0) { whisper_state::vad_segment_info segment; segment.orig_start = vad_segments->data[i].start; segment.orig_end = vad_segments->data[i].end; - segment.vad_start = offset / (float)WHISPER_SAMPLE_RATE; - segment.vad_end = (offset + segment_length) / (float)WHISPER_SAMPLE_RATE; + segment.vad_start = samples_to_cs(offset); + segment.vad_end = samples_to_cs(offset + segment_length); + + // Add segment boundaries to mapping table + vad_time_mapping start_mapping = {segment.vad_start, segment.orig_start}; + vad_time_mapping end_mapping = {segment.vad_end, segment.orig_end}; + + state->vad_mapping_table.push_back(start_mapping); + state->vad_mapping_table.push_back(end_mapping); + + // Add intermediate points for longer segments to improve interpolation accuracy + const int64_t min_segment_length = 100; // 1 second + const int64_t point_interval = 20; // Add a point every 200ms + + if (segment.vad_end - segment.vad_start > min_segment_length) { + int64_t segment_duration = segment.vad_end - segment.vad_start; + int num_points = (int)(segment_duration / point_interval) - 1; + + for (int j = 1; j <= num_points; j++) { + int64_t vad_time = segment.vad_start + j * point_interval; + + if (vad_time >= segment.vad_end) continue; + + int64_t vad_elapsed = vad_time - segment.vad_start; + int64_t vad_total = segment.vad_end - segment.vad_start; + int64_t orig_total = segment.orig_end - segment.orig_start; + int64_t orig_time = segment.orig_start + (vad_elapsed * orig_total) / vad_total; + + vad_time_mapping intermediate_mapping = {vad_time, orig_time}; + state->vad_mapping_table.push_back(intermediate_mapping); + } + } WHISPER_LOG_INFO("%s: vad_segment_info: orig_start: %.2f, orig_end: %.2f, vad_start: %.2f, vad_end: %.2f\n", - __func__, segment.orig_start, segment.orig_end, segment.vad_start, segment.vad_end); + __func__, segment.orig_start/100.0, segment.orig_end/100.0, segment.vad_start/100.0, segment.vad_end/100.0); ctx->state->vad_segments.push_back(segment); // Copy this speech segment @@ -6709,6 +6761,17 @@ static bool whisper_vad( // Add silence after this segment (except after the last segment) if (i < (int)vad_segments->data.size() - 1) { + // Calculate the start and end time of the silence gap in processed audio + int64_t silence_start_vad = samples_to_cs(offset); + int64_t silence_end_vad = samples_to_cs(offset + silence_samples); + // Calculate the corresponding original times + int64_t orig_silence_start = segment.orig_end; + int64_t orig_silence_end = vad_segments->data[i+1].start; + + // Add mapping points for silence boundaries + state->vad_mapping_table.push_back({silence_start_vad, orig_silence_start}); + state->vad_mapping_table.push_back({silence_end_vad, orig_silence_end}); + // Fill with zeros (silence) memset(filtered_samples.data() + offset, 0, silence_samples * sizeof(float)); offset += silence_samples; @@ -6716,6 +6779,24 @@ static bool whisper_vad( } } + // Sort the mapping table by processed time + std::sort(state->vad_mapping_table.begin(), state->vad_mapping_table.end(), + [](const vad_time_mapping& a, const vad_time_mapping& b) { + return a.processed_time < b.processed_time; + }); + + // Remove any duplicate processed times to ensure monotonicity which is + // needed for binary search and interpolation later. + if (!state->vad_mapping_table.empty()) { + auto last = std::unique(state->vad_mapping_table.begin(), state->vad_mapping_table.end(), + [](const vad_time_mapping& a, const vad_time_mapping& b) { + return a.processed_time == b.processed_time; + }); + state->vad_mapping_table.erase(last, state->vad_mapping_table.end()); + } + + WHISPER_LOG_INFO("%s: Created time mapping table with %d points\n", __func__, (int)state->vad_mapping_table.size()); + filtered_n_samples = offset; WHISPER_LOG_INFO("%s: Reduced audio from %d to %d samples (%.1f%% reduction)\n", __func__, n_samples, filtered_n_samples, 100.0f * (1.0f - (float)filtered_n_samples / n_samples)); @@ -6735,27 +6816,9 @@ int whisper_full_with_state( result_all.clear(); - const float * process_samples = samples; - int n_process_samples = n_samples; - std::vector vad_samples; - - if (params.vad) { - WHISPER_LOG_INFO("%s: VAD is enabled, processing speech segments only\n", __func__); - int vad_n_samples; - if (!whisper_vad(ctx, state, params, samples, n_samples, vad_samples, vad_n_samples)) { - WHISPER_LOG_ERROR("%s: failed to compute VAD\n", __func__); - return -1; - } - if (vad_n_samples == 0) { - return 0; - } - process_samples = vad_samples.data(); - n_process_samples = vad_n_samples; - } - - if (n_process_samples > 0) { + if (n_samples > 0) { // compute log mel spectrogram - if (whisper_pcm_to_mel_with_state(ctx, state, process_samples, n_process_samples, params.n_threads) != 0) { + if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) { WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__); return -2; } @@ -7665,6 +7728,20 @@ int whisper_full( struct whisper_full_params params, const float * samples, int n_samples) { + + std::vector vad_samples; + if (params.vad) { + WHISPER_LOG_INFO("%s: VAD is enabled, processing speech segments only\n", __func__); + if (!whisper_vad(ctx, ctx->state, params, samples, n_samples, vad_samples)) { + WHISPER_LOG_ERROR("%s: failed to compute VAD\n", __func__); + return -1; + } + if (vad_samples.empty()) { + return 0; + } + samples = vad_samples.data(); + n_samples = vad_samples.size(); + } return whisper_full_with_state(ctx, ctx->state, params, samples, n_samples); } @@ -7674,9 +7751,24 @@ int whisper_full_parallel( const float * samples, int n_samples, int n_processors) { + if (n_processors == 1) { return whisper_full(ctx, params, samples, n_samples); } + + std::vector vad_samples; + if (params.vad) { + WHISPER_LOG_INFO("%s: VAD is enabled, processing speech segments only\n", __func__); + if (!whisper_vad(ctx, ctx->state, params, samples, n_samples, vad_samples)) { + WHISPER_LOG_ERROR("%s: failed to compute VAD\n", __func__); + return -1; + } + if (vad_samples.empty()) { + return 0; + } + samples = vad_samples.data(); + n_samples = vad_samples.size(); + } int ret = 0; // prepare separate states for each thread @@ -7799,130 +7891,89 @@ int whisper_full_lang_id(struct whisper_context * ctx) { return ctx->state->lang_id; } -int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment) { - // If VAD wasn't used, return the original timestamp - if (!state->has_vad_segments || state->vad_segments.empty()) { - return state->result_all[i_segment].t0; +static int64_t map_processed_to_original_time(int64_t processed_time, const std::vector & mapping_table) { + if (mapping_table.empty()) { + return processed_time; } - // Get the start timestamp produced by whisper_full. whisper_full processes - // only the speech segments in this case so we need to map these timestamps - // back to the original audio. - float t0 = state->result_all[i_segment].t0 / 100.0f; + if (processed_time <= mapping_table.front().processed_time) { + return mapping_table.front().original_time; // Before first mapping point + } - // Find which VAD segment this timestamp belongs. - // TODO(danbev) This could be optimized by using a binary search if the number - // of segments exceed a certain limit. Also we might be able to assume that - // the access pattern is sequential and optimized for that too. - for (size_t i = 0; i < state->vad_segments.size(); i++) { - const auto & segment = state->vad_segments[i]; + if (processed_time >= mapping_table.back().processed_time) { + return mapping_table.back().original_time; // After last mapping point + } - // Check if the timestamp falls within this segment. - if (t0 >= segment.vad_start && t0 <= segment.vad_end) { - float proportion = 0.0f; - if (segment.vad_end > segment.vad_start) { - proportion = (t0 - segment.vad_start) / (segment.vad_end - segment.vad_start); - } - float orig_t0 = segment.orig_start + proportion * (segment.orig_end - segment.orig_start); - return (int64_t)(orig_t0 * 100); + // Binary search over the time map that finds the first entry that has a + // processed time greater than or equal to the current processed time. + auto upper = std::lower_bound(mapping_table.begin(), mapping_table.end(), processed_time, + [](const vad_time_mapping & entry, int64_t time) { + return entry.processed_time < time; } + ); + + // If exact match found + if (upper->processed_time == processed_time) { + return upper->original_time; } - // Check if the timestamp falls between two segments. - for (size_t i = 0; i < state->vad_segments.size() - 1; i++) { - const auto & curr = state->vad_segments[i]; - const auto & next = state->vad_segments[i + 1]; + // Need to interpolate between two points + auto lower = upper - 1; - if (t0 > curr.vad_end && t0 < next.vad_start) { - // Calculate how far we are through the gap as a proportion - float gap_proportion = 0.0f; - if (next.vad_start > curr.vad_end) { - gap_proportion = (t0 - curr.vad_end) / (next.vad_start - curr.vad_end); - } - float orig_t0 = curr.orig_end + gap_proportion * (next.orig_start - curr.orig_end); - return (int64_t)(orig_t0 * 100); - } - } + int64_t processed_diff = upper->processed_time - lower->processed_time; + int64_t original_diff = upper->original_time - lower->original_time; + int64_t offset = processed_time - lower->processed_time; - // Handle the case where the timestamp is after the last segment. - if (t0 > state->vad_segments.back().vad_end) { - // For timestamps after the last segment, add the extra time to the end of the last segment - const auto& last = state->vad_segments.back(); - // Calculate how far beyond the last segment - float extra_time = t0 - last.vad_end; - // Add this extra time to the original end time - float orig_t0 = last.orig_end + extra_time; - return (int64_t)(orig_t0 * 100); + if (processed_diff == 0) { + return lower->original_time; } - WHISPER_LOG_WARN("%s: Could not map t0 = %f to a VAD segment\n", __func__, t0); - return t0; + // Perform linear interpolation + return lower->original_time + (offset * original_diff) / processed_diff; } -int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) { - return whisper_full_get_segment_t0_from_state(ctx->state, i_segment); +// Function to get the starting timestamp of a segment +int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment) { + // If VAD wasn't used, return the original timestamp + if (!state->has_vad_segments || state->vad_mapping_table.empty()) { + return state->result_all[i_segment].t0; + } + + // Get the processed timestamp + int64_t t0 = state->result_all[i_segment].t0; + + // Map to original time using the mapping table + return map_processed_to_original_time(t0, state->vad_mapping_table); } +// Function to get the ending timestamp of a segment int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment) { // If VAD wasn't used, return the original timestamp - if (!state->has_vad_segments || state->vad_segments.empty()) { + if (!state->has_vad_segments || state->vad_mapping_table.empty()) { return state->result_all[i_segment].t1; } - // Get the end timestamp produced by whisper_full. whisper_full processes - // only the speech segments in this case so we need to map these timestamps - // back to the original audio. - float t1 = state->result_all[i_segment].t1 / 100.0f; - - // Find which VAD segment this timestamp belongs. - // TODO(danbev) This could be optimized by using a binary search if the number - // of segments exceed a certain limit. Also we might be able to assume that - // the access pattern is sequential and optimized for that too. - for (size_t i = 0; i < state->vad_segments.size(); i++) { - const auto& segment = state->vad_segments[i]; - - // Check if the timestamp falls within this segment. - if (t1 >= segment.vad_start && t1 <= segment.vad_end) { - // Calculate the proportion through the filtered segment. - float proportion = 0.0f; - if (segment.vad_end > segment.vad_start) { - proportion = (t1 - segment.vad_start) / (segment.vad_end - segment.vad_start); - } - float orig_t1 = segment.orig_start + proportion * (segment.orig_end - segment.orig_start); - return (int64_t)(orig_t1 * 100); - } - } + // Get the processed timestamp + int64_t t1 = state->result_all[i_segment].t1; - // Check if the timestamp falls between two segments. - for (size_t i = 0; i < state->vad_segments.size() - 1; i++) { - const auto & curr = state->vad_segments[i]; - const auto & next = state->vad_segments[i + 1]; + // Map to original time using the mapping table + int64_t orig_t1 = map_processed_to_original_time(t1, state->vad_mapping_table); - if (t1 > curr.vad_end && t1 < next.vad_start) { - // Calculate how far we are through the gap as a proportion - float gap_proportion = 0.0f; - if (next.vad_start > curr.vad_end) { - gap_proportion = (t1 - curr.vad_end) / (next.vad_start - curr.vad_end); - } - // Map to the corresponding position in the original gap - float orig_t1 = curr.orig_end + gap_proportion * (next.orig_start - curr.orig_end); - return (int64_t)(orig_t1 * 100); - } - } + // Get the corresponding t0 for this segment + int64_t orig_t0 = whisper_full_get_segment_t0_from_state(state, i_segment); - // Handle the case where the timestamp is after the last segment - if (t1 > state->vad_segments.back().vad_end) { - // For the last segment, use the end of the last VAD segment - const auto& last = state->vad_segments.back(); - // Calculate how far beyond the last segment - float extra_time = t1 - last.vad_end; - // Add this extra time to the original end time - float orig_t1 = last.orig_end + extra_time; - return (int64_t)(orig_t1 * 100); + // Ensure minimum duration to prevent zero-length segments + const int64_t min_duration = 10; // 10ms minimum + if (orig_t1 - orig_t0 < min_duration) { + orig_t1 = orig_t0 + min_duration; } - WHISPER_LOG_WARN("%s: Could not map t1 = %f to a VAD segment\n", __func__, t1); - return t1; + return orig_t1; +} + + +int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) { + return whisper_full_get_segment_t0_from_state(ctx->state, i_segment); } int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) {