Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Sources/WhisperKit/Core/Audio/AudioProcessor.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Copyright © 2024 Argmax, Inc. All rights reserved.

import Accelerate
import AVFoundation
@preconcurrency import AVFoundation
import CoreAudio
import CoreML

Expand Down
16 changes: 9 additions & 7 deletions Sources/WhisperKit/Core/Text/TokenSampler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import CoreML
import Foundation

public protocol TokenSampling {
func update(tokens: [Int], logits: MLMultiArray, logProbs: [Float]) -> SamplingResult
func update(tokens: [Int], logits: MLMultiArray, logProbs: [Float]) async -> SamplingResult
func finalize(tokens: [Int], logProbs: [Float]) -> SamplingResult
}

Expand Down Expand Up @@ -39,7 +39,7 @@ open class GreedyTokenSampler: TokenSampling {

#if canImport(CoreML.MLState)
@available(macOS 15, iOS 18, watchOS 11, visionOS 2, *)
private func sampleWithMLTensor(logits: MLMultiArray) -> (token: Int, logprob: Float) {
private func sampleWithMLTensor(logits: MLMultiArray) async -> (token: Int, logprob: Float) {
// Use MLTensor operations if available for sampling
// Reference: https://github.com/huggingface/swift-transformers/blob/preview/Sources/Generation/Decoders.swift
var logitsTensor = MLTensor(MLShapedArray<FloatType>(logits)).cast(to: Float.self)
Expand Down Expand Up @@ -76,9 +76,11 @@ open class GreedyTokenSampler: TokenSampling {
nextLogprobTensor = softmaxScores.gathering(atIndices: nextTokenTensor, alongAxis: -1).log()
}

async let nextTokenArray = nextTokenTensor.asIntArray()
async let nextLogprobArray = nextLogprobTensor.asFloatArray()
return (
token: nextTokenTensor.asIntArray()[0],
logprob: nextLogprobTensor.asFloatArray()[0]
token: await nextTokenArray[0],
logprob: await nextLogprobArray[0]
)
}
#endif
Expand Down Expand Up @@ -212,15 +214,15 @@ open class GreedyTokenSampler: TokenSampling {
return (token: nextToken!, logprob: nextLogprob)
}

public func update(tokens: [Int], logits: MLMultiArray, logProbs: [Float]) -> SamplingResult {
public func update(tokens: [Int], logits: MLMultiArray, logProbs: [Float]) async -> SamplingResult {
var nextTokens = tokens
var nextLogprobs = logProbs
var completed = false

var result: (token: Int, logprob: Float)
#if canImport(CoreML.MLState)
if #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *) {
result = sampleWithMLTensor(logits: logits)
result = await sampleWithMLTensor(logits: logits)
} else {
result = sampleWithBNNS(logits: logits)
}
Expand Down Expand Up @@ -278,7 +280,7 @@ open class BeamSearchTokenSampler: TokenSampling {
finishedSequences = []
}

public func update(tokens: [Int], logits: MLMultiArray, logProbs: [Float]) -> SamplingResult {
public func update(tokens: [Int], logits: MLMultiArray, logProbs: [Float]) async -> SamplingResult {
// TODO: Implement
fatalError("Not implemented: \(#function)")
}
Expand Down
4 changes: 2 additions & 2 deletions Sources/WhisperKit/Core/TextDecoder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -686,7 +686,7 @@ open class TextDecoder: TextDecoding, WhisperMLModel {

let samplingStartTime = Date()

let sampleResult = tokenSampler.update(tokens: currentTokens, logits: logits, logProbs: logProbs)
let sampleResult = await tokenSampler.update(tokens: currentTokens, logits: logits, logProbs: logProbs)

nextToken = sampleResult.tokens.last!
logProbs = sampleResult.logProbs
Expand Down Expand Up @@ -838,7 +838,7 @@ open class TextDecoder: TextDecoding, WhisperMLModel {

let samplingStartTime = Date()

let sampleResult = tokenSampler.update(tokens: currentTokens, logits: logits, logProbs: logProbs)
let sampleResult = await tokenSampler.update(tokens: currentTokens, logits: logits, logProbs: logProbs)

nextToken = sampleResult.tokens.last!
let nextTokenLogProb = sampleResult.logProbs.last!
Expand Down
86 changes: 28 additions & 58 deletions Sources/WhisperKit/Utilities/Extensions+Public.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
// Copyright © 2024 Argmax, Inc. All rights reserved.

import AVFoundation
import CoreML
// TODO: Should be able to remove `@preconcurrency` once we drop support for iOS 16, macOS 14.
@preconcurrency import CoreML

public extension Array where Element == TranscriptionSegment {
func contains(segment: TranscriptionSegment) -> Bool {
Expand Down Expand Up @@ -159,69 +160,38 @@ public extension MLMultiArray {
#if canImport(CoreML.MLState)
@available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *)
public extension MLTensor {
func asIntArray() -> [Int] {
let semaphore = DispatchSemaphore(value: 0)
var result: [Int] = []

Task(priority: .high) {
result = await self.shapedArray(of: Int32.self).scalars.map { Int($0) }
semaphore.signal()
}

semaphore.wait()
return result
func asIntArray() async -> [Int] {
await shapedArray(of: Int32.self).scalars.map { Int($0) }
}

func asFloatArray() -> [Float] {
let semaphore = DispatchSemaphore(value: 0)
let tensorType = self.scalarType

var result: [Float] = []

Task(priority: .high) {
switch tensorType {
case is Float32.Type:
result = await self.shapedArray(of: Float32.self).scalars.map { Float($0) }
case is FloatType.Type:
result = await self.shapedArray(of: FloatType.self).scalars.map { Float($0) }
case is Float.Type:
result = await self.shapedArray(of: Float.self).scalars.map { Float($0) }
case is Int32.Type:
result = await self.shapedArray(of: Int32.self).scalars.map { Float($0) }
default:
fatalError("Unsupported data type")
}
semaphore.signal()
func asFloatArray() async -> [Float] {
switch scalarType {
case is Float32.Type:
await shapedArray(of: Float32.self).scalars.map { Float($0) }
case is FloatType.Type:
await shapedArray(of: FloatType.self).scalars.map { Float($0) }
case is Float.Type:
await shapedArray(of: Float.self).scalars.map { Float($0) }
case is Int32.Type:
await shapedArray(of: Int32.self).scalars.map { Float($0) }
default:
fatalError("Unsupported data type")
}

semaphore.wait()
return result
}

func asMLMultiArray() -> MLMultiArray {
let semaphore = DispatchSemaphore(value: 0)
let tensorType = self.scalarType

var result = try! MLMultiArray(shape: [1], dataType: .float16, initialValue: 0.0)

Task(priority: .high) {
switch tensorType {
case is Float32.Type:
result = MLMultiArray(await self.shapedArray(of: Float32.self))
case is FloatType.Type:
result = MLMultiArray(await self.shapedArray(of: FloatType.self))
case is Float.Type:
result = MLMultiArray(await self.shapedArray(of: Float.self))
case is Int32.Type:
result = MLMultiArray(await self.shapedArray(of: Int32.self))
default:
fatalError("Unsupported data type")
}
semaphore.signal()
func asMLMultiArray() async -> MLMultiArray {
switch scalarType {
case is Float32.Type:
MLMultiArray(await shapedArray(of: Float32.self))
case is FloatType.Type:
MLMultiArray(await shapedArray(of: FloatType.self))
case is Float.Type:
MLMultiArray(await shapedArray(of: Float.self))
case is Int32.Type:
MLMultiArray(await shapedArray(of: Int32.self))
default:
fatalError("Unsupported data type")
}

semaphore.wait()
return result
}
}
#endif
Expand Down
74 changes: 74 additions & 0 deletions Tests/WhisperKitTests/MLTensorExtensionsTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// For licensing see accompanying LICENSE.md file.
// Copyright © 2024 Argmax, Inc. All rights reserved.

#if canImport(CoreML.MLState)
import CoreML
@testable import WhisperKit
import XCTest

@available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *)
final class MLTensorExtensionsTests: XCTestCase {
func testAsIntArrayReturnsExpectedScalars() async {
let tensor = MLTensor(MLShapedArray<Int32>(scalars: [1, -2, 42], shape: [3]))

let result = await tensor.asIntArray()

XCTAssertEqual(result, [1, -2, 42])
}

func testAsFloatArraySupportsFloat32Tensor() async {
let tensor = MLTensor(MLShapedArray<Float32>(scalars: [0.25, -1.5, 2.0], shape: [3]))

let result = await tensor.asFloatArray()

assertEqual(result, [0.25, -1.5, 2.0], accuracy: 0.0001)
}

func testAsFloatArraySupportsFloatTypeTensor() async {
let expected = [FloatType(0.125), FloatType(-0.75), FloatType(3.5)]
let tensor = MLTensor(MLShapedArray<FloatType>(scalars: expected, shape: [3]))

let result = await tensor.asFloatArray()

assertEqual(result, expected.map(Float.init), accuracy: 0.0001)
}

func testAsFloatArraySupportsInt32Tensor() async {
let tensor = MLTensor(MLShapedArray<Int32>(scalars: [-3, 0, 7], shape: [3]))

let result = await tensor.asFloatArray()

assertEqual(result, [-3, 0, 7], accuracy: 0.0001)
}

func testAsMLMultiArrayRoundTripsFloatTypeTensor() async {
let expected = [FloatType(1.25), FloatType(-0.5), FloatType(3.75)]
let tensor = MLTensor(MLShapedArray<FloatType>(scalars: expected, shape: [3]))

let result = await tensor.asMLMultiArray()
let shapedArray = MLShapedArray<FloatType>(result)

XCTAssertEqual(result.shape, [3])
XCTAssertEqual(shapedArray.scalars.count, expected.count)
assertEqual(shapedArray.scalars.map(Float.init), expected.map(Float.init), accuracy: 0.0001)
}

func testAsMLMultiArrayRoundTripsInt32Tensor() async {
let expected: [Int32] = [-9, 4, 12]
let tensor = MLTensor(MLShapedArray<Int32>(scalars: expected, shape: [3]))

let result = await tensor.asMLMultiArray()
let shapedArray = MLShapedArray<Int32>(result)

XCTAssertEqual(result.shape, [3])
XCTAssertEqual(shapedArray.scalars, expected)
}

private func assertEqual(_ lhs: [Float], _ rhs: [Float], accuracy: Float) {
XCTAssertEqual(lhs.count, rhs.count)
for (actual, expected) in zip(lhs, rhs) {
XCTAssertEqual(actual, expected, accuracy: accuracy)
}
}
}
#endif
Loading