Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions Tests/WhisperKitTests/TestUtils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,19 @@ extension XCTestCase {
}

func trackForMemoryLeaks(on instance: AnyObject, file: StaticString = #filePath, line: UInt = #line) {
addTeardownBlock { [weak instance] in
XCTAssertNil(instance, "Detected potential memory leak", file: file, line: line)
/// Stores only a weak reference for teardown leak assertions.
/// `XCTestCase.addTeardownBlock` uses a sending closure, so this wrapper must be Sendable.
final class LeakRefWrapper: @unchecked Sendable {
weak var object: AnyObject?

init(object: AnyObject) {
self.object = object
}
}

let wrapper = LeakRefWrapper(object: instance)
addTeardownBlock { [wrapper, file, line] in
XCTAssertNil(wrapper.object, "Detected potential memory leak", file: file, line: line)
}
}

Expand Down
25 changes: 15 additions & 10 deletions Tests/WhisperKitTests/UnitTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import Combine
import CoreML
import Hub
import NaturalLanguage
import os.lock
import Tokenizers
@testable import WhisperKit
import XCTest
Expand Down Expand Up @@ -1735,7 +1736,12 @@ final class UnitTests: XCTestCase {
guard #available(macOS 15, iOS 18, watchOS 11, visionOS 2, *) else {
throw XCTSkip("Disabled on macOS 14 and below due to swift concurrency flakiness")
}


let audioFilePath = try XCTUnwrap(
Bundle.current(for: self).path(forResource: "jfk", ofType: "wav"),
"Audio file not found"
)

let callbackTestTask = Task(priority: .userInitiated) {
let config = WhisperKitConfig(
model: "tiny",
Expand All @@ -1746,10 +1752,6 @@ final class UnitTests: XCTestCase {
let whisperKit = try await WhisperKit(config)

try await whisperKit.loadModels()
let audioFilePath = try XCTUnwrap(
Bundle.current(for: self).path(forResource: "jfk", ofType: "wav"),
"Audio file not found"
)

let earlyStopTokenCount = 10
let continuationCallback: TranscriptionCallback = { (progress: TranscriptionProgress) -> Bool? in
Expand Down Expand Up @@ -1808,12 +1810,14 @@ final class UnitTests: XCTestCase {
segmentExpectation.expectedFulfillmentCount = 3 // Expect at least 3 segment callbacks

// Keep track of discovered segments
var discoveredSegments = [[TranscriptionSegment]]()
let discoveredSegments = OSAllocatedUnfairLock(initialState: [[TranscriptionSegment]]())

// Set up segment discovery callback
whisperKit.segmentDiscoveryCallback = { segments in
Logging.debug("Segments discovered with ids: \(segments.map { $0.id }) and seek: \(segments.map { $0.seek })")
discoveredSegments.append(segments)
discoveredSegments.withLock { state in
state.append(segments)
}
segmentExpectation.fulfill()
}

Expand All @@ -1834,10 +1838,11 @@ final class UnitTests: XCTestCase {
await fulfillment(of: [segmentExpectation], timeout: 1)

// Verify segments were discovered across multiple chunks
XCTAssertGreaterThanOrEqual(discoveredSegments.count, 3, "Should have discovered segments in multiple chunks")
let discoveredSegmentsSnapshot = discoveredSegments.withLock { $0 }
XCTAssertGreaterThanOrEqual(discoveredSegmentsSnapshot.count, 3, "Should have discovered segments in multiple chunks")

// Verify that segments have different seek positions
let allSeekPositions = Set(discoveredSegments.flatMap { $0.map { $0.seek } })
let allSeekPositions = Set(discoveredSegmentsSnapshot.flatMap { $0.map { $0.seek } })
XCTAssertGreaterThanOrEqual(allSeekPositions.count, 3, "Should have segments with different seek positions")

// Verify that the seek positions are sensible (not negative)
Expand All @@ -1846,7 +1851,7 @@ final class UnitTests: XCTestCase {
}

// Verify that segment times are reasonable
let allSegments = discoveredSegments.flatMap { $0 }
let allSegments = discoveredSegmentsSnapshot.flatMap { $0 }
for segment in allSegments {
XCTAssertGreaterThanOrEqual(segment.start, 0, "Segment start time should not be negative")
XCTAssertGreaterThan(segment.end, segment.start, "Segment end time should be greater than start time")
Expand Down