diff --git a/Sources/WhisperKit/Core/LogitsFilter.swift b/Sources/WhisperKit/Core/LogitsFilter.swift
index 724785c3..244ebeb1 100644
--- a/Sources/WhisperKit/Core/LogitsFilter.swift
+++ b/Sources/WhisperKit/Core/LogitsFilter.swift
@@ -278,3 +278,37 @@ open class LanguageLogitsFilter: LogitsFiltering {
         return indexes
     }
 }
+
+@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
+open class SilenceLogitsFilter: LogitsFiltering {
+    let silenceToken: Int
+    let logitsDim: Int
+    let sampleBegin: Int
+    let nonSilenceTokenIndexes: [[NSNumber]]
+
+    public init(silenceToken: Int, logitsDim: Int, sampleBegin: Int) {
+        self.silenceToken = silenceToken
+        self.logitsDim = logitsDim
+        self.sampleBegin = sampleBegin
+        self.nonSilenceTokenIndexes = SilenceLogitsFilter.getNonSilenceTokenIndexes(logitsDim: self.logitsDim, silenceToken: self.silenceToken)
+    }
+
+    /// Retain the logits that correspond to the silence token and suppress all non-silence tokens
+    public func filterLogits(_ logits: MLMultiArray, withTokens tokens: [Int]) -> MLMultiArray {
+        guard tokens.count == sampleBegin else {
+            return logits
+        }
+        logits.fill(indexes: nonSilenceTokenIndexes, with: -FloatType.infinity)
+        return logits
+    }
+
+    private static func getNonSilenceTokenIndexes(logitsDim: Int, silenceToken: Int) -> [[NSNumber]] {
+        var indexes: [[NSNumber]] = []
+        for i in 0..<logitsDim {
+            if i != silenceToken {
+                indexes.append([0, 0, i as NSNumber])
+            }
+        }
+        return indexes
+    }
+}
diff --git a/Sources/WhisperKit/Core/TextDecoding.swift b/Sources/WhisperKit/Core/TextDecoding.swift
--- a/Sources/WhisperKit/Core/TextDecoding.swift
+++ b/Sources/WhisperKit/Core/TextDecoding.swift
@@ ... @@ public protocol TextDecoding {
         sampler tokenSampler: TokenSampling,
         options: DecodingOptions,
         temperature: FloatType
     ) async throws -> DecodingResult
+
+    func detectSilence(
+        from encoderOutput: MLMultiArray,
+        using decoderInputs: DecodingInputs,
+        sampler tokenSampler: TokenSampling,
+        options: DecodingOptions,
+        temperature: FloatType
+    ) async throws -> Float
 
     @available(*, deprecated, message: "Subject to removal in a future version. Use `detectLanguage(from:using:sampler:options:temperature:) async throws -> DecodingResult` instead.")
     @_disfavoredOverload
@@ -340,6 +348,58 @@ public class TextDecoderContextPrefill: WhisperMLModel {
 
 @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
 open class TextDecoder: TextDecoding, WhisperMLModel {
+    func softmax(_ logits: MLMultiArray) -> [Float] {
+        let count = logits.count
+        var expValues = [Float](repeating: 0.0, count: count)
+        var sumExpValues: Float = 0.0
+
+        for i in 0..<count {
+            expValues[i] = exp(logits[i].floatValue)
+            sumExpValues += expValues[i]
+        }
+
+        for i in 0..<count {
+            expValues[i] /= sumExpValues
+        }
+
+        return expValues
+    }
+
+    func calculateNoSpeechProb(logits: MLMultiArray, noSpeechTokenIndex: Int) -> Float {
+        let softmaxProbs = softmax(logits)
+        let noSpeechProb = softmaxProbs[noSpeechTokenIndex]
+
+        return noSpeechProb
+    }
+
+    public func detectSilence(
+        from encoderOutput: MLMultiArray,
+        using decoderInputs: DecodingInputs,
+        sampler tokenSampler: TokenSampling,
+        options: DecodingOptions,
+        temperature: FloatType
+    ) async throws -> Float {
+        let noSpeechTokenIndex = 50362
+
+        let predictedLogits = try await self.predictLogits(
+            inputIds: decoderInputs.inputIds,
+            cacheLength: decoderInputs.cacheLength,
+            keyCache: decoderInputs.keyCache,
+            valueCache: decoderInputs.valueCache,
+            kvCacheUpdateMask: decoderInputs.kvCacheUpdateMask,
+            encoderOutputEmbeds: encoderOutput,
+            decoderKeyPaddingMask: decoderInputs.decoderKeyPaddingMask
+        )
+
+        guard let logitsArray = predictedLogits?.logits else {
+            throw WhisperError.decodingLogitsFailed("Unable to decode logits")
+        }
+
+        let noSpeechProb = calculateNoSpeechProb(logits: logitsArray, noSpeechTokenIndex: noSpeechTokenIndex)
+
+        return noSpeechProb
+    }
+
     public var model: MLModel?
     public var tokenizer: WhisperTokenizer?
     public var prefillData: WhisperMLModel?
@@ -549,6 +609,7 @@ open class TextDecoder: TextDecoding, WhisperMLModel {
         var currentTokens: [Int] = decoderInputs.initialPrompt
         var nextToken: Int = decoderInputs.initialPrompt.last!
         var logProbs: [Float] = Array(repeating: 0, count: currentTokens.count)
+        var noSpeechProb: Float = 0.0
 
         // Logits filters
         var logitsFilters: [any LogitsFiltering] = []
@@ -641,6 +702,34 @@ open class TextDecoder: TextDecoding, WhisperMLModel {
             for filter in logitsFilters {
                 logits = filter.filterLogits(logits, withTokens: currentTokens)
             }
+
+            if tokenIndex == intialPromptIndex {
+                // The tokenizer reports noSpeechToken as 50257, but the model's logits appear to use index 50362 for the "no speech" token.
+                let noSpeechTokenIndex = 50362
+                noSpeechProb = calculateNoSpeechProb(logits: logits, noSpeechTokenIndex: noSpeechTokenIndex)
+
+                let avgLogProb = logProbs.reduce(0, +) / Float(logProbs.count)
+
+                if let threshold = options.noSpeechThreshold, noSpeechProb > threshold {
+                    if options.logProbThreshold == nil || avgLogProb < options.logProbThreshold! {
+                        print("Detected silence with noSpeechProb \(noSpeechProb) and avgLogProb \(avgLogProb), skipping segment.")
+                        return DecodingResult(
+                            language: Constants.defaultLanguageCode,
+                            languageProbs: [:],
+                            tokens: [],
+                            tokenLogProbs: [],
+                            text: "",
+                            avgLogProb: avgLogProb,
+                            noSpeechProb: noSpeechProb,
+                            temperature: 0.0,
+                            compressionRatio: 0.0,
+                            cache: nil,
+                            timings: TranscriptionTimings(),
+                            fallback: nil
+                        )
+                    }
+                }
+            }
 
             let filteringTime = Date().timeIntervalSince(nonInferenceStartTime)
             timings.decodingFiltering += filteringTime
@@ -794,8 +883,6 @@ open class TextDecoder: TextDecoding, WhisperMLModel {
             temperature = Float(sampler.temperature).rounded(3)
         }
 
-        let noSpeechProb: Float = 0 // TODO: implement no speech prob
-
         // If language is still nil here, check language can be inferred from tokens
         var language = options.language ?? Constants.defaultLanguageCode
         var languageProbs = [String: Float]()
diff --git a/Sources/WhisperKit/Core/TranscribeTask.swift b/Sources/WhisperKit/Core/TranscribeTask.swift
index 4c0fd02f..54a1a514 100644
--- a/Sources/WhisperKit/Core/TranscribeTask.swift
+++ b/Sources/WhisperKit/Core/TranscribeTask.swift
@@ -158,6 +158,11 @@ final class TranscribeTask {
 
                 // Send to decoder to predict text tokens with fallback
                 let decodingResult = try await decodeWithFallback(encoderSegment: encoderOutput, decodingOptions: options, callback: decodingCallback)
+                if decodingResult.noSpeechProb > (options.noSpeechThreshold ?? 0.6) && decodingResult.avgLogProb < (options.logProbThreshold ?? -1.0) {
+                    seek += segmentSize
+                    continue
+                }
+
                 // MARK: Windowing
 
                 // At this point we have a completed window aka segment
@@ -269,6 +274,7 @@ final class TranscribeTask {
         let tokenSampler = GreedyTokenSampler(temperature: temp, eotToken: tokenizer.specialTokens.endToken, decodingOptions: options)
 
         var currentDecodingOptions = options
+
         // For a multilingual model, if language is not passed and detectLanguage is true, detect language and set in options
         if textDecoder.isModelMultilingual, options.language == nil, options.detectLanguage {
             let languageDecodingResult: DecodingResult? = try? await textDecoder.detectLanguage(
diff --git a/Sources/WhisperKit/Core/WhisperKit.swift b/Sources/WhisperKit/Core/WhisperKit.swift
index d3d6c5bc..6352bc81 100644
--- a/Sources/WhisperKit/Core/WhisperKit.swift
+++ b/Sources/WhisperKit/Core/WhisperKit.swift
@@ -434,6 +434,50 @@ open class WhisperKit {
         return (language: languageDecodingResult.language, langProbs: languageDecodingResult.languageProbs)
     }
+
+
+    /// Detects silence in the provided audio samples.
+    ///
+    /// - Parameter audioArray: An array of audio samples.
+    /// - Returns: The probability of silence in the audio.
+    public func detectSilence(audioArray: [Float]) async throws -> Float {
+        if modelState != .loaded {
+            try await loadModels()
+        }
+
+        guard let tokenizer else {
+            throw WhisperError.tokenizerUnavailable()
+        }
+
+        let options = DecodingOptions()
+        let decoderInputs = try textDecoder.prepareDecoderInputs(withPrompt: [tokenizer.specialTokens.startOfTranscriptToken])
+        decoderInputs.kvCacheUpdateMask[0] = 1.0
+        decoderInputs.decoderKeyPaddingMask[0] = 0.0
+
+        guard let audioSamples = AudioProcessor.padOrTrimAudio(fromArray: audioArray, startAt: 0, toLength: WhisperKit.windowSamples) else {
+            throw WhisperError.transcriptionFailed("Audio samples are nil")
+        }
+
+        guard let melOutput = try await featureExtractor.logMelSpectrogram(fromAudio: audioSamples) else {
+            throw WhisperError.transcriptionFailed("Mel output is nil")
+        }
+
+        guard let encoderOutput = try await audioEncoder.encodeFeatures(melOutput) else {
+            throw WhisperError.transcriptionFailed("Encoder output is nil")
+        }
+
+        let tokenSampler = GreedyTokenSampler(temperature: 0, eotToken: tokenizer.specialTokens.endToken, decodingOptions: options)
+        let noSpeechProb = try await textDecoder.detectSilence(
+            from: encoderOutput,
+            using: decoderInputs,
+            sampler: tokenSampler,
+            options: options,
+            temperature: 0
+        )
+
+        return noSpeechProb
+    }
+
     // MARK: - Transcribe multiple audio files
diff --git a/Tests/WhisperKitTests/Resources/continuous_speech.wav b/Tests/WhisperKitTests/Resources/continuous_speech.wav
new file mode 100644
index 00000000..f4a62a1e
Binary files /dev/null and b/Tests/WhisperKitTests/Resources/continuous_speech.wav differ
diff --git a/Tests/WhisperKitTests/Resources/initial_silence_speech.m4a b/Tests/WhisperKitTests/Resources/initial_silence_speech.m4a
new file mode 100644
index 00000000..26d9bc4b
Binary files /dev/null and b/Tests/WhisperKitTests/Resources/initial_silence_speech.m4a differ
diff --git a/Tests/WhisperKitTests/Resources/silent_audio.mp3 b/Tests/WhisperKitTests/Resources/silent_audio.mp3
new file mode 100644
index 00000000..acb5f988
Binary files /dev/null and b/Tests/WhisperKitTests/Resources/silent_audio.mp3 differ
diff --git a/Tests/WhisperKitTests/UnitTests.swift b/Tests/WhisperKitTests/UnitTests.swift
index f14029a1..579cf9ab 100644
--- a/Tests/WhisperKitTests/UnitTests.swift
+++ b/Tests/WhisperKitTests/UnitTests.swift
@@ -666,7 +666,37 @@ final class UnitTests: XCTestCase {
             XCTAssertEqual(result.language, language)
         }
     }
-    
+
+    func testDetectSilenceHelperMethod() async throws {
+        let whisperKit = try await WhisperKit(
+            modelFolder: tinyModelPath(),
+            verbose: true,
+            logLevel: .debug
+        )
+
+        let silentAudioSamples: [Float] = [Float](repeating: 0.0, count: 16000) // 1 second of silence at 16kHz
+        let jfkAudioSamples = try XCTUnwrap(loadAudioSamples(forResource: "ted_60", withExtension: "m4a"))
+
+        let testAudioFiles: [(String, [Float], Bool)] = [
+            ("silent_clip", silentAudioSamples, false), // Not expecting speech
+            ("non_silent_clip", jfkAudioSamples, true) // Expecting speech
+        ]
+
+        for (audioFileName, audioSamples, expectingSpeech) in testAudioFiles {
+            let silenceProbability = try await whisperKit.detectSilence(audioArray: audioSamples)
+
+            //print("Test case: \(audioFileName), Expecting speech: \(expectingSpeech), Calculated silence probability: \(silenceProbability)")
+            // Calculated noSpeechProb values for the silent and non-silent clips are 0.002598221 and 0.26186648.
+            // Given these values, the default threshold of 0.6 appears too high to distinguish between
+            // silence and speech. Based on the debug values, a threshold of 0.2 is used here.
+            if expectingSpeech {
+                XCTAssertGreaterThan(silenceProbability, 0.2, "Expected speech, but detected silence for \(audioFileName) with probability \(silenceProbability)")
+            } else {
+                XCTAssertLessThanOrEqual(silenceProbability, 0.2, "Expected silence, but detected speech for \(audioFileName) with probability \(silenceProbability)")
+            }
+        }
+    }
+
     func testNoTimestamps() async throws {
 
         let options = DecodingOptions(withoutTimestamps: true)
@@ -708,7 +738,70 @@ final class UnitTests: XCTestCase {
 
         XCTAssertNotNil(result.text)
     }
+
+
+    func testSilentAudio() async throws {
+        let whisperKit = try await WhisperKit(modelFolder: tinyModelPath(), verbose: true, logLevel: .debug)
+
+        let silentAudioSamples: [Float] = [Float](repeating: 0.0, count: 16000)
+
+        let options = DecodingOptions(usePrefillPrompt: false, skipSpecialTokens: false)
+
+        let result: [TranscriptionResult] = try await whisperKit.transcribe(audioArray: silentAudioSamples, decodeOptions: options)
+
+        XCTAssertTrue(result.first?.segments.isEmpty ?? false, "Expected no segments for silent audio")
+    }
+    func testInitialSilenceFollowedBySpeech() async throws {
+        let whisperKit = try await WhisperKit(modelFolder: tinyModelPath(), verbose: true, logLevel: .debug)
+
+        let initialSilenceSpeechSamples: [Float] = loadAudioSamples(forResource: "initial_silence_speech", withExtension: "m4a")
+
+        let options = DecodingOptions(usePrefillPrompt: false, skipSpecialTokens: false, noSpeechThreshold: 0.8)
+
+        let result: [TranscriptionResult] = try await whisperKit.transcribe(audioArray: initialSilenceSpeechSamples, decodeOptions: options)
+
+        if let transcription = result.first?.segments.first?.text {
+            print("Transcription: \(transcription)")
+        } else {
+            print("No transcription found.")
+        }
+
+        let transcription = result.first?.segments.first?.text
+        XCTAssertNotNil(transcription, "Expected transcription for audio with initial silence followed by speech")
+
+        XCTAssertTrue(transcription?.contains("Hey") ?? false, "Expected 'Hey' in transcription")
+    }
+    func testContinuousSpeechAudio() async throws {
+        let whisperKit = try await WhisperKit(modelFolder: tinyModelPath(), verbose: true, logLevel: .debug)
+
+        let continuousSpeechSamples: [Float] = loadAudioSamples(forResource: "continuous_speech", withExtension: "wav")
+        let options = DecodingOptions(usePrefillPrompt: false, skipSpecialTokens: false)
+
+        let result: [TranscriptionResult] = try await whisperKit.transcribe(audioArray: continuousSpeechSamples, decodeOptions: options)
+
+        let transcription = result.first?.segments.first?.text
+        XCTAssertNotNil(transcription, "Expected transcription for continuous speech audio")
+        XCTAssertFalse(transcription?.isEmpty ?? true, "Expected non-empty transcription for continuous speech audio")
true, "Expected non-empty transcription for continuous speech audio") + } + + // MARK: - Helper Function + + func loadAudioSamples(forResource resource: String, withExtension ext: String) -> [Float] { + guard let audioFileURL = Bundle.module.url(forResource: resource, withExtension: ext) else { + XCTFail("Audio file not found") + return [] + } + + do { + let audioBuffer = try AudioProcessor.loadAudio(fromPath: audioFileURL.path) + return AudioProcessor.convertBufferToArray(buffer: audioBuffer) + } catch { + XCTFail("Failed to load audio samples: \(error.localizedDescription)") + return [] + } + } + func testSilence() async throws { let whisperKit = try await WhisperKit(modelFolder: tinyModelPath(), verbose: true, logLevel: .debug) let audioSamples = [Float](repeating: 0.0, count: 30 * 16000) @@ -920,7 +1013,7 @@ final class UnitTests: XCTestCase { let result2 = tokensFilter2.filterLogits(logits2, withTokens: [1]) XCTAssertEqual(result2.data(for: 2), [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]) } - + func testTimestampRulesFilter() throws { // NOTE: for non-multilingual models we supress tokens immediately let tokensFilter1 = TimestampRulesFilter(