// Review notes for this patch (free text placed ahead of the first `diff --git` header).
//
// Purpose: replaces the blocking DispatchSemaphore + Task(priority: .high) bridging inside
// MLTensor.asIntArray / asFloatArray / asMLMultiArray with direct `await shapedArray(of:)`
// calls, and makes the token-sampling path that uses those helpers `async`.
//
// Changes by file:
//  - AudioProcessor.swift: `import AVFoundation` becomes `@preconcurrency import AVFoundation`.
//  - TokenSampler.swift: `TokenSampling.update` gains `async`, as do the GreedyTokenSampler and
//    BeamSearchTokenSampler implementations and the private `sampleWithMLTensor`, which now
//    starts both conversions with `async let` and awaits element 0 of each result.
//    `finalize` and the `sampleWithBNNS` call stay synchronous.
//  - TextDecoder.swift: both `tokenSampler.update(...)` call sites add `await`.
//  - Extensions+Public.swift: `import CoreML` becomes `@preconcurrency import CoreML`; the three
//    MLTensor helpers become `async`, each with a single-expression body.
//  - MLTensorExtensionsTests.swift (new file): async XCTest cases for the three helpers.
//
// NOTE(review): `TokenSampling`, both `update` implementations and the `MLTensor` extension are
//   declared `public`, so callers outside this patch will have to add `await` — confirm the
//   source break is intended.
// NOTE(review): `asFloatArray` and `asMLMultiArray` list both `case is Float32.Type` and
//   `case is Float.Type`. `Float32` is presumably the standard-library alias of `Float`, which
//   would make the later arm unreachable — confirm. The removed code had the same pair.
// NOTE(review): the `default:` arms still call `fatalError("Unsupported data type")`, as the
//   removed code did; any other scalar type traps at runtime.
// NOTE(review): the new tests build tensors from untyped literals (e.g. `[0.25, -1.5, 2.0]`,
//   `[1, -2, 42]`); confirm these infer the scalar type each test name claims to cover.
diff --git a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift index e19ea45..14e5b68 100644 --- a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift +++ b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift @@ -2,7 +2,7 @@ // Copyright © 2024 Argmax, Inc. All rights reserved. import Accelerate -import AVFoundation +@preconcurrency import AVFoundation import CoreAudio import CoreML diff --git a/Sources/WhisperKit/Core/Text/TokenSampler.swift b/Sources/WhisperKit/Core/Text/TokenSampler.swift index 8d71051..57b08ed 100644 --- a/Sources/WhisperKit/Core/Text/TokenSampler.swift +++ b/Sources/WhisperKit/Core/Text/TokenSampler.swift @@ -6,7 +6,7 @@ import CoreML import Foundation public protocol TokenSampling { - func update(tokens: [Int], logits: MLMultiArray, logProbs: [Float]) -> SamplingResult + func update(tokens: [Int], logits: MLMultiArray, logProbs: [Float]) async -> SamplingResult func finalize(tokens: [Int], logProbs: [Float]) -> SamplingResult } @@ -39,7 +39,7 @@ open class GreedyTokenSampler: TokenSampling { #if canImport(CoreML.MLState) @available(macOS 15, iOS 18, watchOS 11, visionOS 2, *) - private func sampleWithMLTensor(logits: MLMultiArray) -> (token: Int, logprob: Float) { + private func sampleWithMLTensor(logits: MLMultiArray) async -> (token: Int, logprob: Float) { // Use MLTensor operations if available for sampling // Reference: https://github.com/huggingface/swift-transformers/blob/preview/Sources/Generation/Decoders.swift var logitsTensor = MLTensor(MLShapedArray(logits)).cast(to: Float.self) @@ -76,9 +76,11 @@ open class GreedyTokenSampler: TokenSampling { nextLogprobTensor = softmaxScores.gathering(atIndices: nextTokenTensor, alongAxis: -1).log() } + async let nextTokenArray = nextTokenTensor.asIntArray() + async let nextLogprobArray = nextLogprobTensor.asFloatArray() return ( - token: nextTokenTensor.asIntArray()[0], - logprob: nextLogprobTensor.asFloatArray()[0] + token: await 
nextTokenArray[0], + logprob: await nextLogprobArray[0] ) } #endif @@ -212,7 +214,7 @@ open class GreedyTokenSampler: TokenSampling { return (token: nextToken!, logprob: nextLogprob) } - public func update(tokens: [Int], logits: MLMultiArray, logProbs: [Float]) -> SamplingResult { + public func update(tokens: [Int], logits: MLMultiArray, logProbs: [Float]) async -> SamplingResult { var nextTokens = tokens var nextLogprobs = logProbs var completed = false @@ -220,7 +222,7 @@ open class GreedyTokenSampler: TokenSampling { var result: (token: Int, logprob: Float) #if canImport(CoreML.MLState) if #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *) { - result = sampleWithMLTensor(logits: logits) + result = await sampleWithMLTensor(logits: logits) } else { result = sampleWithBNNS(logits: logits) } @@ -278,7 +280,7 @@ open class BeamSearchTokenSampler: TokenSampling { finishedSequences = [] } - public func update(tokens: [Int], logits: MLMultiArray, logProbs: [Float]) -> SamplingResult { + public func update(tokens: [Int], logits: MLMultiArray, logProbs: [Float]) async -> SamplingResult { // TODO: Implement fatalError("Not implemented: \(#function)") } diff --git a/Sources/WhisperKit/Core/TextDecoder.swift b/Sources/WhisperKit/Core/TextDecoder.swift index efdacec..6fcedcf 100644 --- a/Sources/WhisperKit/Core/TextDecoder.swift +++ b/Sources/WhisperKit/Core/TextDecoder.swift @@ -686,7 +686,7 @@ open class TextDecoder: TextDecoding, WhisperMLModel { let samplingStartTime = Date() - let sampleResult = tokenSampler.update(tokens: currentTokens, logits: logits, logProbs: logProbs) + let sampleResult = await tokenSampler.update(tokens: currentTokens, logits: logits, logProbs: logProbs) nextToken = sampleResult.tokens.last! 
logProbs = sampleResult.logProbs @@ -838,7 +838,7 @@ open class TextDecoder: TextDecoding, WhisperMLModel { let samplingStartTime = Date() - let sampleResult = tokenSampler.update(tokens: currentTokens, logits: logits, logProbs: logProbs) + let sampleResult = await tokenSampler.update(tokens: currentTokens, logits: logits, logProbs: logProbs) nextToken = sampleResult.tokens.last! let nextTokenLogProb = sampleResult.logProbs.last! diff --git a/Sources/WhisperKit/Utilities/Extensions+Public.swift b/Sources/WhisperKit/Utilities/Extensions+Public.swift index ae04c79..e636ff6 100644 --- a/Sources/WhisperKit/Utilities/Extensions+Public.swift +++ b/Sources/WhisperKit/Utilities/Extensions+Public.swift @@ -2,7 +2,8 @@ // Copyright © 2024 Argmax, Inc. All rights reserved. import AVFoundation -import CoreML +// TODO: Should be able to remove `@preconcurrency` once we drop support for iOS 16 and macOS 13. +@preconcurrency import CoreML public extension Array where Element == TranscriptionSegment { func contains(segment: TranscriptionSegment) -> Bool { @@ -159,69 +160,38 @@ public extension MLMultiArray { #if canImport(CoreML.MLState) @available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *) public extension MLTensor { - func asIntArray() -> [Int] { - let semaphore = DispatchSemaphore(value: 0) - var result: [Int] = [] - - Task(priority: .high) { - result = await self.shapedArray(of: Int32.self).scalars.map { Int($0) } - semaphore.signal() - } - - semaphore.wait() - return result + func asIntArray() async -> [Int] { + await shapedArray(of: Int32.self).scalars.map { Int($0) } } - func asFloatArray() -> [Float] { - let semaphore = DispatchSemaphore(value: 0) - let tensorType = self.scalarType - - var result: [Float] = [] - - Task(priority: .high) { - switch tensorType { - case is Float32.Type: - result = await self.shapedArray(of: Float32.self).scalars.map { Float($0) } - case is FloatType.Type: - result = await self.shapedArray(of: FloatType.self).scalars.map { Float($0) } 
- case is Float.Type: - result = await self.shapedArray(of: Float.self).scalars.map { Float($0) } - case is Int32.Type: - result = await self.shapedArray(of: Int32.self).scalars.map { Float($0) } - default: - fatalError("Unsupported data type") - } - semaphore.signal() + func asFloatArray() async -> [Float] { + switch scalarType { + case is Float32.Type: + await shapedArray(of: Float32.self).scalars.map { Float($0) } + case is FloatType.Type: + await shapedArray(of: FloatType.self).scalars.map { Float($0) } + case is Float.Type: + await shapedArray(of: Float.self).scalars.map { Float($0) } + case is Int32.Type: + await shapedArray(of: Int32.self).scalars.map { Float($0) } + default: + fatalError("Unsupported data type") } - - semaphore.wait() - return result } - func asMLMultiArray() -> MLMultiArray { - let semaphore = DispatchSemaphore(value: 0) - let tensorType = self.scalarType - - var result = try! MLMultiArray(shape: [1], dataType: .float16, initialValue: 0.0) - - Task(priority: .high) { - switch tensorType { - case is Float32.Type: - result = MLMultiArray(await self.shapedArray(of: Float32.self)) - case is FloatType.Type: - result = MLMultiArray(await self.shapedArray(of: FloatType.self)) - case is Float.Type: - result = MLMultiArray(await self.shapedArray(of: Float.self)) - case is Int32.Type: - result = MLMultiArray(await self.shapedArray(of: Int32.self)) - default: - fatalError("Unsupported data type") - } - semaphore.signal() + func asMLMultiArray() async -> MLMultiArray { + switch scalarType { + case is Float32.Type: + MLMultiArray(await shapedArray(of: Float32.self)) + case is FloatType.Type: + MLMultiArray(await shapedArray(of: FloatType.self)) + case is Float.Type: + MLMultiArray(await shapedArray(of: Float.self)) + case is Int32.Type: + MLMultiArray(await shapedArray(of: Int32.self)) + default: + fatalError("Unsupported data type") } - - semaphore.wait() - return result } } #endif diff --git a/Tests/WhisperKitTests/MLTensorExtensionsTests.swift 
b/Tests/WhisperKitTests/MLTensorExtensionsTests.swift new file mode 100644 index 0000000..39599af --- /dev/null +++ b/Tests/WhisperKitTests/MLTensorExtensionsTests.swift @@ -0,0 +1,77 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright © 2024 Argmax, Inc. All rights reserved. + +#if canImport(CoreML.MLState) +import CoreML +@testable import WhisperKit +import XCTest + +@available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *) +final class MLTensorExtensionsTests: XCTestCase { + func testAsIntArrayReturnsExpectedScalars() async { + let tensor = MLTensor(MLShapedArray(scalars: [1, -2, 42], shape: [3])) + + let result = await tensor.asIntArray() + + XCTAssertEqual(result, [1, -2, 42]) + } + + func testAsFloatArraySupportsFloat32Tensor() async { + let tensor = MLTensor(MLShapedArray(scalars: [0.25, -1.5, 2.0], shape: [3])) + + let result = await tensor.asFloatArray() + + assertEqual(result, [0.25, -1.5, 2.0], accuracy: 0.0001) + } + + func testAsFloatArraySupportsFloatTypeTensor() async { + let expected = [FloatType(0.125), FloatType(-0.75), FloatType(3.5)] + let tensor = MLTensor(MLShapedArray(scalars: expected, shape: [3])) + + let result = await tensor.asFloatArray() + let expectedFloats: [Float] = expected.map { Float($0) } + + assertEqual(result, expectedFloats, accuracy: 0.0001) + } + + func testAsFloatArraySupportsInt32Tensor() async { + let tensor = MLTensor(MLShapedArray(scalars: [-3, 0, 7], shape: [3])) + + let result = await tensor.asFloatArray() + + assertEqual(result, [-3, 0, 7], accuracy: 0.0001) + } + + func testAsMLMultiArrayRoundTripsFloatTypeTensor() async { + let expected = [FloatType(1.25), FloatType(-0.5), FloatType(3.75)] + let tensor = MLTensor(MLShapedArray(scalars: expected, shape: [3])) + + let result = await tensor.asMLMultiArray() + let shapedArray = MLShapedArray(result) + let resultFloats: [Float] = shapedArray.scalars.map { Float($0) } + let expectedFloats: [Float] = expected.map { Float($0) } + + 
XCTAssertEqual(result.shape, [3]) + XCTAssertEqual(shapedArray.scalars.count, expected.count) + assertEqual(resultFloats, expectedFloats, accuracy: 0.0001) + } + + func testAsMLMultiArrayRoundTripsInt32Tensor() async { + let expected: [Int32] = [-9, 4, 12] + let tensor = MLTensor(MLShapedArray(scalars: expected, shape: [3])) + + let result = await tensor.asMLMultiArray() + let shapedArray = MLShapedArray(result) + + XCTAssertEqual(result.shape, [3]) + XCTAssertEqual(shapedArray.scalars, expected) + } + + private func assertEqual(_ lhs: [Float], _ rhs: [Float], accuracy: Float) { + XCTAssertEqual(lhs.count, rhs.count) + for (actual, expected) in zip(lhs, rhs) { + XCTAssertEqual(actual, expected, accuracy: accuracy) + } + } +} +#endif