diff --git a/.swiftpm/xcode/xcshareddata/xcschemes/whisperkit-Package.xcscheme b/.swiftpm/xcode/xcshareddata/xcschemes/whisperkit-Package.xcscheme
index da13128a..80b47f47 100644
--- a/.swiftpm/xcode/xcshareddata/xcschemes/whisperkit-Package.xcscheme
+++ b/.swiftpm/xcode/xcshareddata/xcschemes/whisperkit-Package.xcscheme
@@ -48,6 +48,20 @@
ReferencedContainer = "container:">
+
+
+
+
+
+
+
+
diff --git a/Examples/TTS/SpeakAX/SpeakAX.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved b/Examples/TTS/SpeakAX/SpeakAX.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved
new file mode 100644
index 00000000..083d195a
--- /dev/null
+++ b/Examples/TTS/SpeakAX/SpeakAX.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved
@@ -0,0 +1,42 @@
+{
+ "originHash" : "105884455b6729deaf95d3dd16df7029fd7929c924fcab55200562b52b64dbc2",
+ "pins" : [
+ {
+ "identity" : "swift-argument-parser",
+ "kind" : "remoteSourceControl",
+ "location" : "https://github.com/apple/swift-argument-parser.git",
+ "state" : {
+ "revision" : "c5d11a805e765f52ba34ec7284bd4fcd6ba68615",
+ "version" : "1.7.0"
+ }
+ },
+ {
+ "identity" : "swift-collections",
+ "kind" : "remoteSourceControl",
+ "location" : "https://github.com/apple/swift-collections.git",
+ "state" : {
+ "revision" : "7b847a3b7008b2dc2f47ca3110d8c782fb2e5c7e",
+ "version" : "1.3.0"
+ }
+ },
+ {
+ "identity" : "swift-jinja",
+ "kind" : "remoteSourceControl",
+ "location" : "https://github.com/huggingface/swift-jinja.git",
+ "state" : {
+ "revision" : "d81197f35f41445bc10e94600795e68c6f5e94b0",
+ "version" : "2.3.1"
+ }
+ },
+ {
+ "identity" : "swift-transformers",
+ "kind" : "remoteSourceControl",
+ "location" : "https://github.com/huggingface/swift-transformers.git",
+ "state" : {
+ "revision" : "573e5c9036c2f136b3a8a071da8e8907322403d0",
+ "version" : "1.1.6"
+ }
+ }
+ ],
+ "version" : 3
+}
diff --git a/Examples/TTS/SpeakAX/SpeakAX/AudioMetadata.swift b/Examples/TTS/SpeakAX/SpeakAX/AudioMetadata.swift
new file mode 100644
index 00000000..de35b715
--- /dev/null
+++ b/Examples/TTS/SpeakAX/SpeakAX/AudioMetadata.swift
@@ -0,0 +1,147 @@
+// For licensing see accompanying LICENSE.md file.
+// Copyright © 2026 Argmax, Inc. All rights reserved.
+
+import AVFoundation
+import Foundation
+
+/// Metadata for a single TTS generation, embedded inside its `.m4a` file.
+///
+/// The full struct is serialized as JSON and stored in the iTunes `©cmt` atom so
+/// the complete generation context is recoverable from the file alone.
+/// Human-readable fields (title, artist, album) are also embedded.
+struct AudioMetadata: Codable, Sendable {
+ static let metadataTitleMaxLength = 80
+ let id: UUID
+ let text: String
+ let speaker: String
+ let language: String
+ let instruction: String
+ let modelName: String
+ let realTimeFactor: Double
+ let speedFactor: Double
+ let stepsPerSecond: Double
+ let timeToFirstBuffer: TimeInterval
+ let date: Date
+
+ init(
+ id: UUID = UUID(),
+ text: String,
+ speaker: String,
+ language: String,
+ instruction: String,
+ modelName: String,
+ realTimeFactor: Double,
+ speedFactor: Double,
+ stepsPerSecond: Double,
+ timeToFirstBuffer: TimeInterval,
+ date: Date = Date()
+ ) {
+ self.id = id
+ self.text = text
+ self.speaker = speaker
+ self.language = language
+ self.instruction = instruction
+ self.modelName = modelName
+ self.realTimeFactor = realTimeFactor
+ self.speedFactor = speedFactor
+ self.stepsPerSecond = stepsPerSecond
+ self.timeToFirstBuffer = timeToFirstBuffer
+ self.date = date
+ }
+
+ // MARK: - Filename
+
+ private static let filenameDateFormatter: DateFormatter = {
+ let f = DateFormatter()
+ f.dateFormat = "yyyyMMdd'T'HHmmss'Z'"
+ f.locale = Locale(identifier: "en_US_POSIX")
+ f.timeZone = TimeZone(identifier: "UTC")
+ return f
+ }()
+
+    /// Filesystem-safe base name: `{speaker}_{model-slug}_{UTC timestamp}`.
+    var suggestedFileName: String {
+        // Lowercase the model name and replace spaces AND path separators so
+        // HF-style model ids (e.g. "org/model") cannot produce an invalid path.
+        let slug = modelName.lowercased().replacingOccurrences(of: " ", with: "-").replacingOccurrences(of: "/", with: "-")
+        return "\(speaker)_\(slug)_\(Self.filenameDateFormatter.string(from: date))"
+    }
+
+ // MARK: - AVFoundation metadata
+
+ /// Build the `[AVMetadataItem]` array for `AVAssetExportSession`.
+ ///
+ /// All fields use `commonIdentifier` variants so they work across every
+ /// container format AVFoundation supports (M4A, MOV, MP3, AIFF...).
+ ///
+ /// Fields verified to survive `AVAssetExportSession` M4A export (confirmed via ffprobe):
+ /// - `©nam` title - first 80 chars of the generated text
+ /// - `©ART` artist - speaker name
+ /// - `©alb` album - model (e.g. "qwen3_tts_12hz-1.7b-customvoice")
+ /// - `©lyr` lyrics - full untruncated text
+ /// - `©too` encoder - "TTSKit v{appVersion}" (shows as `encoder` in ffprobe)
+ /// - `©cmt` comment - JSON blob; the app reads this back to reconstruct the generation
+ ///
+ /// `.commonIdentifierCreator` and `.commonIdentifierDescription` were tried but do not
+ /// survive M4A export - iTunes-specific atoms are used for those two slots instead.
+ func avMetadataItems() throws -> [AVMetadataItem] {
+ let encoder = JSONEncoder()
+ encoder.dateEncodingStrategy = .iso8601
+ let json = try String(data: encoder.encode(self), encoding: .utf8) ?? ""
+
+ let appVersion = Bundle.main.infoDictionary?["CFBundleShortVersionString"] as? String ?? "1.0"
+ let maxLen = Self.metadataTitleMaxLength
+ let title = text.count > maxLen ? String(text.prefix(maxLen - 3)) + "..." : text
+ let langTag = Self.bcp47Tag(for: language)
+
+ func item(_ identifier: AVMetadataIdentifier, _ value: String, lang: String = "und") -> AVMetadataItem {
+ let i = AVMutableMetadataItem()
+ i.identifier = identifier
+ i.value = value as NSString
+ i.extendedLanguageTag = lang
+ return i
+ }
+
+ return [
+ item(.commonIdentifierTitle, title, lang: langTag),
+ item(.commonIdentifierArtist, speaker.capitalized),
+ item(.commonIdentifierAlbumName, modelName),
+ item(.iTunesMetadataLyrics, text, lang: langTag),
+ item(.iTunesMetadataEncodingTool, "TTSKit v\(appVersion)"),
+ item(.iTunesMetadataUserComment, json),
+ ]
+ }
+
+    /// Maps a spoken-language display name (e.g. "English") to its BCP-47
+    /// primary language subtag; unknown names map to "und" (undetermined).
+    private static func bcp47Tag(for language: String) -> String {
+        let tags: [String: String] = [
+            "english": "en", "chinese": "zh",
+            "japanese": "ja", "korean": "ko",
+            "german": "de", "french": "fr",
+            "russian": "ru", "portuguese": "pt",
+            "spanish": "es", "italian": "it",
+        ]
+        // Lowercasing keeps the lookup case-insensitive, matching the
+        // original switch over `language.lowercased()`.
+        let key = language.lowercased()
+        return tags[key] ?? "und"
+    }
+
+ // MARK: - Loading from file
+
+ /// Reconstruct metadata from the `©cmt` atom of an `.m4a` file.
+ /// Returns `nil` if the file has no TTSKit metadata.
+ static func load(from url: URL) async throws -> AudioMetadata? {
+ let asset = AVURLAsset(url: url)
+ let items = try await asset.load(.metadata)
+ guard let commentItem = items.first(where: {
+ $0.identifier == .iTunesMetadataUserComment
+ }),
+ let json = try await commentItem.load(.stringValue),
+ let data = json.data(using: .utf8) else { return nil }
+
+ let decoder = JSONDecoder()
+ decoder.dateDecodingStrategy = .iso8601
+ return try? decoder.decode(AudioMetadata.self, from: data)
+ }
+}
diff --git a/Examples/TTS/SpeakAX/SpeakAX/ComputeUnitsView.swift b/Examples/TTS/SpeakAX/SpeakAX/ComputeUnitsView.swift
new file mode 100644
index 00000000..457f1fed
--- /dev/null
+++ b/Examples/TTS/SpeakAX/SpeakAX/ComputeUnitsView.swift
@@ -0,0 +1,94 @@
+// For licensing see accompanying LICENSE.md file.
+// Copyright © 2026 Argmax, Inc. All rights reserved.
+
+import CoreML
+import SwiftUI
+import TTSKit
+
+/// Collapsible sidebar section for configuring per-component ML compute units.
+/// Matches the pattern from WhisperAX's computeUnitsView.
+struct ComputeUnitsView: View {
+ @Environment(ViewModel.self) private var viewModel
+ @State private var isExpanded = false
+
+ var body: some View {
+ @Bindable var vm = viewModel
+
+ DisclosureGroup(isExpanded: $isExpanded) {
+ VStack(spacing: 8) {
+ computeRow(
+ label: "Embedders",
+ units: vm.embedderComputeUnits,
+ onChange: { vm.embedderComputeUnits = $0; reloadIfNeeded() }
+ )
+ computeRow(
+ label: "Code Decoder",
+ units: vm.codeDecoderComputeUnits,
+ onChange: { vm.codeDecoderComputeUnits = $0; reloadIfNeeded() }
+ )
+ computeRow(
+ label: "Multi-Code Decoder",
+ units: vm.multiCodeDecoderComputeUnits,
+ onChange: { vm.multiCodeDecoderComputeUnits = $0; reloadIfNeeded() }
+ )
+ computeRow(
+ label: "Speech Decoder",
+ units: vm.speechDecoderComputeUnits,
+ onChange: { vm.speechDecoderComputeUnits = $0; reloadIfNeeded() }
+ )
+ }
+ .padding(.top, 4)
+ } label: {
+ Button {
+ isExpanded.toggle()
+ } label: {
+ Text("Compute Units")
+ .font(.headline)
+ }
+ .buttonStyle(.plain)
+ }
+ .disabled(viewModel.modelState.isBusy)
+ .padding(.horizontal, 12)
+ .padding(.vertical, 6)
+ }
+
+ @ViewBuilder
+ private func computeRow(
+ label: String,
+ units: MLComputeUnits,
+ onChange: @escaping (MLComputeUnits) -> Void
+ ) -> some View {
+ HStack(spacing: 8) {
+ Text(label)
+ .font(.caption)
+ .foregroundStyle(.secondary)
+ .lineLimit(1)
+ .minimumScaleFactor(0.85)
+ .layoutPriority(1)
+
+ Spacer(minLength: 4)
+
+ Picker("", selection: Binding(
+ get: { units },
+ set: { onChange($0) }
+ )) {
+ Text("CPU").tag(MLComputeUnits.cpuOnly)
+ Text("GPU").tag(MLComputeUnits.cpuAndGPU)
+ #if os(iOS)
+ // Abbreviated on iOS to prevent overflow
+ Text("NE").tag(MLComputeUnits.cpuAndNeuralEngine)
+ #else
+ Text("Neural Engine").tag(MLComputeUnits.cpuAndNeuralEngine)
+ #endif
+ }
+ .labelsHidden()
+ .pickerStyle(.menu)
+ .fixedSize()
+ }
+ }
+
+ private func reloadIfNeeded() {
+ guard viewModel.modelState == .loaded else { return }
+ viewModel.reloadModelForComputeUnitChange()
+ }
+}
diff --git a/Examples/TTS/SpeakAX/SpeakAX/ContentView.swift b/Examples/TTS/SpeakAX/SpeakAX/ContentView.swift
new file mode 100644
index 00000000..7c5b6d10
--- /dev/null
+++ b/Examples/TTS/SpeakAX/SpeakAX/ContentView.swift
@@ -0,0 +1,63 @@
+// For licensing see accompanying LICENSE.md file.
+// Copyright © 2026 Argmax, Inc. All rights reserved.
+
+import SwiftUI
+
+struct ContentView: View {
+ @Environment(ViewModel.self) private var viewModel
+ var body: some View {
+ NavigationSplitView {
+ SidebarView()
+ } detail: {
+ DetailView()
+ }
+ .navigationSplitViewStyle(.balanced)
+ #if os(macOS)
+ // SwiftUI frame constraints don't prevent NavigationSplitView from entering
+ // overlay mode. Setting minSize directly on NSWindow is a reliable fix.
+ .background(WindowMinSizeEnforcer(minSize: CGSize(width: 840, height: 560)))
+ #endif
+ }
+}
+
+// MARK: - macOS window minimum size enforcer
+
+#if os(macOS)
+/// Enforces `NSWindow.minSize` at the AppKit level.
+///
+/// SwiftUI's `.frame(minWidth:)` and `windowResizability` both fail to prevent
+/// `NavigationSplitView` from shrinking into overlay mode. Going through AppKit
+/// directly is a reliable way to clamp the resize handle.
+///
+/// A custom `NSView` subclass is used instead of `DispatchQueue.main.async` because
+/// the view's `window` property is `nil` during `makeNSView`.
+private struct WindowMinSizeEnforcer: NSViewRepresentable {
+ let minSize: CGSize
+
+ func makeNSView(context: Context) -> MinSizeView {
+ let view = MinSizeView()
+ view.minSize = minSize
+ return view
+ }
+
+ func updateNSView(_ nsView: MinSizeView, context: Context) {
+ nsView.minSize = minSize
+ }
+}
+
+private final class MinSizeView: NSView {
+ var minSize: CGSize = .zero {
+ didSet { window?.minSize = minSize }
+ }
+
+ override func viewDidMoveToWindow() {
+ super.viewDidMoveToWindow()
+ window?.minSize = minSize
+ }
+}
+#endif
+
+#Preview {
+ ContentView()
+ .environment(ViewModel())
+}
diff --git a/Examples/TTS/SpeakAX/SpeakAX/DetailView.swift b/Examples/TTS/SpeakAX/SpeakAX/DetailView.swift
new file mode 100644
index 00000000..1aa680e3
--- /dev/null
+++ b/Examples/TTS/SpeakAX/SpeakAX/DetailView.swift
@@ -0,0 +1,584 @@
+// For licensing see accompanying LICENSE.md file.
+// Copyright © 2026 Argmax, Inc. All rights reserved.
+
+import ArgmaxCore
+import SwiftUI
+import TTSKit
+import UniformTypeIdentifiers
+#if canImport(UIKit)
+import UIKit
+#elseif canImport(AppKit)
+import AppKit
+#endif
+
+struct DetailView: View {
+ @Environment(ViewModel.self) private var viewModel
+
+ @State private var isImportingFile = false
+ @State private var fileImportError: String? = nil
+
+ var body: some View {
+ @Bindable var vm = viewModel
+
+ VStack(spacing: 0) {
+ // Waveform + playback controls (~top half)
+ waveformSection
+ .frame(minHeight: 140)
+
+ Divider()
+
+ // Text input + controls (bottom half)
+ VStack(spacing: 16) {
+ textInputSection
+ metricsRow
+ controlsBar
+ }
+ .padding(20)
+ }
+ .frame(maxWidth: .infinity, maxHeight: .infinity)
+ .navigationTitle("TTSKit Example")
+ #if os(iOS)
+ .navigationBarTitleDisplayMode(.inline)
+ #endif
+ .toolbar {
+ toolbarContent
+ }
+ .sheet(isPresented: $vm.showGenerationSettings) {
+ GenerationSettingsView()
+ .environment(viewModel)
+ .presentationDetents([.medium, .large])
+ .presentationBackgroundInteraction(.enabled)
+ .presentationContentInteraction(.scrolls)
+ }
+ .onAppear {
+ viewModel.onAppear()
+ }
+ .onChange(of: viewModel.selectedGenerationID) {
+ guard !viewModel.isStreaming else { return }
+ guard let gen = viewModel.selectedGeneration else { return }
+ viewModel.loadWaveform(for: gen)
+ // Auto-populate input fields so the user can play with the settings and regenerate
+ viewModel.loadInputs(from: gen)
+ }
+ }
+
+ // MARK: - Waveform Section
+
+ @ViewBuilder
+ private var waveformSection: some View {
+ VStack(spacing: 12) {
+ Spacer(minLength: 8)
+
+ // All three states live in the layout simultaneously so the height
+ // never changes - only opacity transitions between them.
+ ZStack {
+ // Idle empty placeholder
+ VStack(spacing: 8) {
+ Image(systemName: "waveform")
+ .font(.system(size: 40))
+ .foregroundStyle(.quaternary)
+ Text("Waveform will appear here")
+ .font(.subheadline)
+ .foregroundStyle(.tertiary)
+ }
+ .frame(maxWidth: .infinity, maxHeight: .infinity)
+ .opacity(viewModel.currentWaveform.isEmpty && viewModel.generationState != .generating ? 1 : 0)
+
+ // Generating progress
+ VStack(spacing: 8) {
+ if viewModel.totalSteps > 0 {
+ ProgressView(
+ value: Double(viewModel.stepsCompleted),
+ total: Double(viewModel.totalSteps)
+ )
+ .progressViewStyle(.linear)
+ .frame(maxWidth: 200)
+
+ let pct = Int(Double(viewModel.stepsCompleted) / Double(viewModel.totalSteps) * 100)
+ let detail = viewModel.chunksTotal > 1 ? " (\(viewModel.chunksTotal) chunks)" : ""
+ Text("Generating... \(pct)%\(detail)")
+ .font(.subheadline)
+ .foregroundStyle(.tertiary)
+ .contentTransition(.numericText())
+ } else {
+ ProgressView()
+ .controlSize(.small)
+ Text(viewModel.statusMessage)
+ .font(.subheadline)
+ .foregroundStyle(.tertiary)
+ }
+ }
+ .frame(maxWidth: .infinity, maxHeight: .infinity)
+ .opacity(viewModel.currentWaveform.isEmpty && viewModel.generationState == .generating ? 1 : 0)
+
+ // Live / completed waveform
+ WaveformView(
+ samples: viewModel.currentWaveform,
+ playbackTime: viewModel.playbackTime,
+ totalDuration: viewModel.currentDuration
+ )
+ .padding(.horizontal, 0)
+ .opacity(viewModel.currentWaveform.isEmpty ? 0 : 1)
+ }
+ .frame(maxWidth: .infinity, maxHeight: .infinity)
+
+ // Time labels + playback controls
+ HStack {
+ HStack(spacing: 4) {
+ Text(formatTime(viewModel.playbackTime))
+ .font(.caption.monospacedDigit())
+ .foregroundStyle(.secondary)
+
+ let bufferRemaining = viewModel.silentBufferRemaining
+ if bufferRemaining > 0.4 {
+ Text("(\(formatTime(bufferRemaining)))")
+ .font(.caption.monospacedDigit())
+ .foregroundStyle(.tertiary)
+ }
+ }
+
+ Spacer()
+
+ playbackControls
+
+ Spacer()
+
+ Text(formatTime(viewModel.currentDuration))
+ .font(.caption.monospacedDigit())
+ .foregroundStyle(.secondary)
+ }
+ .padding(.horizontal, 24)
+
+ Spacer(minLength: 8)
+ }
+ .background(.quaternary.opacity(0.3))
+ }
+
+ private var playbackControls: some View {
+ let hasAudio = viewModel.selectedGeneration?.audioFileName != nil
+ let icon = viewModel.isPlaying ? "stop.fill" : "play.fill"
+ return Button {
+ if viewModel.isPlaying {
+ viewModel.stopPlayback()
+ } else if let gen = viewModel.selectedGeneration {
+ viewModel.playGeneration(gen)
+ }
+ } label: {
+ Image(systemName: icon)
+ .font(.title2)
+ }
+ .buttonStyle(.plain)
+ .keyboardShortcut(.space, modifiers: [])
+ .disabled(!hasAudio || viewModel.generationState == .generating)
+ .opacity(hasAudio ? 1 : 0)
+ .accessibilityLabel(viewModel.isPlaying ? "Stop playback" : "Play audio")
+ .accessibilityHint(viewModel.isPlaying ? "Stops the current audio" : "Plays the selected generation")
+ }
+
+ // MARK: - Text Input
+
+ @ViewBuilder
+ private var textInputSection: some View {
+ @Bindable var vm = viewModel
+
+ VStack(alignment: .leading, spacing: 8) {
+ let promptForInput = viewModel.modelState == .loaded
+ && viewModel.inputText.isEmpty
+ && viewModel.generationState == .idle
+
+ ZStack(alignment: .topLeading) {
+ TextEditor(text: $vm.inputText)
+ .font(.body)
+ .scrollContentBackground(.hidden)
+ .padding(.horizontal, 8)
+ .padding(.vertical, 8)
+ .background(.quaternary.opacity(0.5), in: RoundedRectangle(cornerRadius: 8))
+ .overlay {
+ RoundedRectangle(cornerRadius: 8)
+ .stroke(Color.accentColor, lineWidth: promptForInput ? 1.5 : 0)
+ .animation(.easeInOut(duration: 0.25), value: promptForInput)
+ }
+ .frame(minHeight: 80, maxHeight: 160)
+// .overlay(alignment: .topTrailing) {
+// Button {
+// isImportingFile = true
+// } label: {
+// Image(systemName: "square.and.arrow.down")
+// .font(.caption)
+// .padding(6)
+// .background(.thinMaterial, in: RoundedRectangle(cornerRadius: 6))
+// }
+// .buttonStyle(.plain)
+// .padding(6)
+// .accessibilityLabel("Import text file")
+// .accessibilityHint("Opens a file picker to import a .txt or .pdf file into the text editor")
+// }
+
+ if vm.inputText.isEmpty {
+ Text("Enter text to speak...")
+ .foregroundStyle(.tertiary)
+ .padding(.horizontal, 12)
+ .padding(.vertical, 8)
+ .allowsHitTesting(false)
+ }
+ }
+ .fileImporter(
+ isPresented: $isImportingFile,
+ allowedContentTypes: [.plainText, .pdf],
+ allowsMultipleSelection: false
+ ) { result in
+ switch result {
+ case .success(let urls):
+ guard let url = urls.first else { return }
+ // Security-scoped access is required for files picked outside the sandbox
+ let accessing = url.startAccessingSecurityScopedResource()
+ defer { if accessing { url.stopAccessingSecurityScopedResource() } }
+ if let text = FileUtilities.readTextContent(at: url) {
+ vm.inputText = text
+ } else {
+ fileImportError = "Could not read text from \"\(url.lastPathComponent)\"."
+ }
+ case .failure(let error):
+ fileImportError = error.localizedDescription
+ }
+ }
+ .alert("Import Failed", isPresented: .init(
+ get: { fileImportError != nil },
+ set: { if !$0 { fileImportError = nil } }
+ )) {
+ Button("OK", role: .cancel) { fileImportError = nil }
+ } message: {
+ Text(fileImportError ?? "")
+ }
+
+ HStack(spacing: 8) {
+ Picker("Voice", selection: $vm.selectedSpeaker) {
+ ForEach(Qwen3Speaker.allCases, id: \.self) { speaker in
+ Text("\(speaker.displayName) · \(speaker.nativeLanguage)").tag(speaker)
+ }
+ }
+ .fixedSize()
+ .accessibilityLabel("Voice")
+ .accessibilityHint("Choose the speaker voice for synthesis")
+
+ Picker("Language", selection: $vm.selectedLanguage) {
+ ForEach(Qwen3Language.allCases, id: \.self) { lang in
+ Text(lang.rawValue.capitalized).tag(lang)
+ }
+ }
+ .fixedSize()
+ .accessibilityLabel("Language")
+ .accessibilityHint("Choose the output language for the voice")
+
+ Picker("Playback", selection: $vm.playbackStrategyTag) {
+ Text("Auto").tag("auto")
+ Text("Stream").tag("stream")
+ Text("Generate First").tag("generateFirst")
+ }
+ .fixedSize()
+ .accessibilityLabel("Playback strategy")
+ .accessibilityHint("Auto buffers adaptively; Stream plays frame-by-frame; Generate First waits for the full audio")
+
+ Spacer(minLength: 0)
+
+ Text(vm.inputTokenCount.map { "\($0) tok" } ?? "~\(vm.inputText.unicodeScalars.count / 5) tok")
+ .font(.caption)
+ .foregroundStyle(.tertiary)
+ .lineLimit(1)
+ }
+
+ // Selected speaker description
+ Text(vm.selectedSpeaker.voiceDescription)
+ .font(.caption)
+ .foregroundStyle(.secondary)
+ .frame(maxWidth: .infinity, alignment: .leading)
+ .transition(.opacity)
+ .animation(.easeInOut(duration: 0.2), value: vm.selectedSpeaker)
+
+ // Instruction / style prompt (1.7B only)
+ let instructionSupported = vm.selectedPreset.supportsVoiceDirection
+ if instructionSupported {
+ HStack(spacing: 6) {
+ Image(systemName: "theatermasks")
+ .font(.caption)
+ .foregroundStyle(.secondary)
+ TextField(
+ "Style instruction (e.g. cheerful and energetic)...",
+ text: $vm.instruction
+ )
+ .font(.callout)
+ .textFieldStyle(.plain)
+ }
+ .padding(.horizontal, 10)
+ .padding(.vertical, 7)
+ .background(.quaternary.opacity(0.3), in: RoundedRectangle(cornerRadius: 8))
+ .transition(.opacity.combined(with: .move(edge: .top)))
+ .animation(.easeInOut(duration: 0.2), value: instructionSupported)
+ }
+ }
+ }
+
+ // MARK: - Metrics Row
+
+ private var metricsRow: some View {
+ let hasMetrics = viewModel.currentRTF > 0
+ return HStack(spacing: 0) {
+ metricItem(
+ value: hasMetrics ? String(format: "%.1f×", viewModel.currentSpeedFactor) : "-",
+ label: "Speed Factor"
+ )
+ Divider().frame(height: 24)
+ metricItem(
+ value: hasMetrics ? String(format: "%.0f", viewModel.currentStepsPerSecond) : "-",
+ label: "steps/s"
+ )
+ Divider().frame(height: 24)
+ metricItem(
+ value: hasMetrics ? String(format: "%.2fs", viewModel.currentTimeToFirstBuffer) : "-",
+ label: "First Buffer"
+ )
+ }
+ .frame(maxWidth: .infinity)
+ .padding(.vertical, 6)
+ .background(.quaternary.opacity(0.3), in: RoundedRectangle(cornerRadius: 8))
+ .opacity(hasMetrics ? 1 : 0.4)
+ .animation(.easeInOut(duration: 0.2), value: hasMetrics)
+ }
+
+ private func metricItem(value: String, label: String) -> some View {
+ VStack(spacing: 2) {
+ Text(value)
+ .font(.system(.body, design: .monospaced))
+ .lineLimit(1)
+ Text(label)
+ .font(.caption2)
+ .foregroundStyle(.secondary)
+ }
+ .frame(maxWidth: .infinity)
+ }
+
+ // MARK: - Controls Bar
+
+ @ViewBuilder
+ private var controlsBar: some View {
+ #if os(iOS)
+ // On iOS: generate button spans full width; secondary controls on the row below.
+ // This prevents the button label from being clipped on narrow screens.
+ VStack(spacing: 10) {
+ Button {
+ if viewModel.generationState == .idle {
+ viewModel.startGeneration()
+ } else {
+ viewModel.cancelGeneration()
+ }
+ } label: {
+ Label(generateButtonTitle, systemImage: generateButtonIcon)
+ .frame(maxWidth: .infinity)
+ }
+ .glassButton(prominent: true)
+ .tint(viewModel.generationState == .idle ? .accentColor : .red)
+ .controlSize(.large)
+ .disabled(!viewModel.canGenerate && viewModel.generationState == .idle)
+ .accessibilityLabel(viewModel.generationState == .idle ? generateButtonTitle : "Cancel generation")
+ .accessibilityHint(viewModel.generationState == .idle
+ ? "Synthesizes the entered text using the loaded model"
+ : "Stops the current generation immediately")
+
+ HStack {
+ Button {
+ viewModel.clearInput()
+ } label: {
+ Label("Clear", systemImage: "xmark.circle")
+ }
+ .glassButton()
+ .controlSize(.regular)
+ .disabled(viewModel.inputText.isEmpty)
+ .opacity(viewModel.generationState == .idle ? 1 : 0)
+ .accessibilityLabel("Clear input")
+ .accessibilityHint("Clears the text input and resets the waveform")
+
+ Spacer()
+
+ Text(viewModel.statusMessage)
+ .font(.caption)
+ .foregroundStyle(.secondary)
+ .lineLimit(2)
+ .multilineTextAlignment(.trailing)
+ }
+ }
+ #else
+ HStack(spacing: 12) {
+ // Generate / Cancel button
+ Button {
+ if viewModel.generationState == .idle {
+ viewModel.startGeneration()
+ } else {
+ viewModel.cancelGeneration()
+ }
+ } label: {
+ Label(
+ generateButtonTitle,
+ systemImage: generateButtonIcon
+ )
+ .frame(maxWidth: 200)
+ }
+ .glassButton(prominent: true)
+ .tint(viewModel.generationState == .idle ? .accentColor : .red)
+ .controlSize(.large)
+ .disabled(!viewModel.canGenerate && viewModel.generationState == .idle)
+ .accessibilityLabel(viewModel.generationState == .idle ? generateButtonTitle : "Cancel generation")
+ .accessibilityHint(viewModel.generationState == .idle
+ ? "Synthesizes the entered text using the loaded model"
+ : "Stops the current generation immediately")
+
+ Button {
+ viewModel.clearInput()
+ } label: {
+ Label("Clear", systemImage: "xmark.circle")
+ }
+ .glassButton()
+ .controlSize(.large)
+ .disabled(viewModel.inputText.isEmpty)
+ .opacity(viewModel.generationState == .idle ? 1 : 0)
+ .accessibilityLabel("Clear input")
+ .accessibilityHint("Clears the text input and resets the waveform")
+
+ Spacer()
+
+ // Status
+ Text(viewModel.statusMessage)
+ .font(.caption)
+ .foregroundStyle(.secondary)
+ .lineLimit(1)
+ .truncationMode(.tail)
+ }
+ #endif
+ }
+
+ // MARK: - Toolbar
+
+ @ToolbarContentBuilder
+ private var toolbarContent: some ToolbarContent {
+ #if os(macOS)
+ ToolbarItem {
+ Button(action: newGeneration) {
+ Label("New Generation", systemImage: "plus")
+ }
+ .keyboardShortcut("n")
+ }
+ #endif
+
+ ToolbarItem(placement: .primaryAction) {
+ Button {
+ viewModel.showGenerationSettings = true
+ } label: {
+ Label("Generation Settings", systemImage: "slider.horizontal.3")
+ }
+ }
+
+ // Share / export audio file
+ if let gen = viewModel.selectedGeneration,
+ let url = viewModel.audioFileURL(for: gen)
+ {
+ ToolbarItem {
+ AudioCopyButton(url: url)
+ }
+ }
+ }
+
+ private func newGeneration() {
+ viewModel.clearInput()
+ viewModel.currentWaveform = []
+ viewModel.selectedGenerationID = ViewModel.newGenerationSentinel
+ }
+
+ // MARK: - Helpers
+
+ private var generateButtonTitle: String {
+ switch viewModel.generationState {
+ case .generating: return "Cancel"
+ case .idle:
+ switch viewModel.modelState {
+ case .downloading: return "Downloading..."
+ case .prewarming: return "Specializing..."
+ case .loading: return "Loading..."
+ default: return viewModel.selectedGeneration != nil ? "Regenerate" : "Generate"
+ }
+ }
+ }
+
+    /// SF Symbol for the primary action button: stop while generating, play when idle.
+    private var generateButtonIcon: String {
+        viewModel.generationState == .generating
+            ? "stop.fill"
+            : "play.fill"
+    }
+
+    /// Renders a duration as `m:ss` (e.g. 83.4 -> "1:23"); fractions are truncated.
+    private func formatTime(_ seconds: TimeInterval) -> String {
+        let whole = Int(seconds)
+        return String(format: "%d:%02d", whole / 60, whole % 60)
+    }
+}
+
+// MARK: - Audio file transferable
+
+// MARK: - Copy to clipboard button
+
+/// Copies the audio file to the system clipboard so it can be pasted into Mail, Finder, etc.
+/// Uses `NSPasteboard` on macOS and `UIPasteboard` on iOS - both accept a file URL directly,
+/// letting the receiving app decide how to handle the file.
+struct AudioCopyButton: View {
+ let url: URL
+ @State private var copied = false
+
+ var body: some View {
+ Button {
+ copyToClipboard()
+ copied = true
+ DispatchQueue.main.asyncAfter(deadline: .now() + 1.5) { copied = false }
+ } label: {
+ Label(copied ? "Copied!" : "Copy Audio", systemImage: copied ? "checkmark" : "doc.on.doc")
+ }
+ .accessibilityLabel(copied ? "Copied to clipboard" : "Copy audio file")
+ .accessibilityHint("Copies the audio file to the clipboard so you can paste it into other apps")
+ }
+
+ private func copyToClipboard() {
+ #if os(macOS)
+ NSPasteboard.general.clearContents()
+ NSPasteboard.general.writeObjects([url as NSURL])
+ #else
+ guard let data = try? Data(contentsOf: url) else { return }
+ let uti = UTType(filenameExtension: url.pathExtension) ?? .audio
+ UIPasteboard.general.setData(data, forPasteboardType: uti.identifier)
+ #endif
+ }
+}
+
+// MARK: - Glass Button Style
+
+/// Applies `.glass` on iOS/macOS 26+ and falls back to `.borderedProminent` /
+/// `.bordered` on older OS versions.
+struct GlassButtonModifier: ViewModifier {
+    /// When true, use the prominent variant on every OS path.
+    var prominent: Bool = false
+
+    func body(content: Content) -> some View {
+        if #available(iOS 26, macOS 26, watchOS 26, visionOS 26, *) {
+            // Honor `prominent` on OS 26+ too (the original dropped it here),
+            // mirroring the borderedProminent/bordered legacy fallback below.
+            if prominent { content.buttonStyle(.glassProminent) }
+            else { content.buttonStyle(.glass) }
+        } else if prominent {
+            content.buttonStyle(.borderedProminent)
+        } else {
+            content.buttonStyle(.bordered)
+        }
+    }
+}
+
+extension View {
+ func glassButton(prominent: Bool = false) -> some View {
+ modifier(GlassButtonModifier(prominent: prominent))
+ }
+}
diff --git a/Examples/TTS/SpeakAX/SpeakAX/GenerationSettingsView.swift b/Examples/TTS/SpeakAX/SpeakAX/GenerationSettingsView.swift
new file mode 100644
index 00000000..3c73c5b8
--- /dev/null
+++ b/Examples/TTS/SpeakAX/SpeakAX/GenerationSettingsView.swift
@@ -0,0 +1,236 @@
+// For licensing see accompanying LICENSE.md file.
+// Copyright © 2026 Argmax, Inc. All rights reserved.
+
+import SwiftUI
+import TTSKit
+
+/// Sheet for configuring GenerationOptions.
+struct GenerationSettingsView: View {
+ @Environment(ViewModel.self) private var viewModel
+ @Environment(\.dismiss) private var dismiss
+
+ var body: some View {
+ #if os(iOS)
+ NavigationStack {
+ settingsForm
+ }
+ #else
+ VStack(spacing: 0) {
+ // Explicit header with title and close button - toolbar items
+ // don't render reliably inside macOS sheet NavigationStacks.
+ HStack {
+ Text("Generation Options")
+ .font(.title2)
+ .fontWeight(.semibold)
+ Spacer()
+ Button {
+ dismiss()
+ } label: {
+ Image(systemName: "xmark.circle.fill")
+ .font(.title2)
+ .symbolRenderingMode(.hierarchical)
+ .foregroundStyle(.secondary)
+ }
+ .buttonStyle(.plain)
+ }
+ .padding(.horizontal)
+ .padding(.vertical, 12)
+
+ Divider()
+
+ settingsForm
+ .frame(minWidth: 480, minHeight: 500)
+ }
+ #endif
+ }
+
+ private var settingsForm: some View {
+ @Bindable var vm = viewModel
+
+ return Form {
+ // MARK: Sampling
+
+ Section("Sampling") {
+ sliderRow(
+ label: "Temperature",
+ info: "Controls the randomness of token selection. Higher values produce more varied speech; lower values are more deterministic.",
+ value: $vm.temperature,
+ range: 0...1,
+ step: 0.05,
+ display: String(format: "%.2f", vm.temperature)
+ )
+
+ sliderRow(
+ label: "Top-K",
+ info: "Limits token selection to the top-K most probable tokens at each step. Lower values make output more focused.",
+ value: $vm.topK,
+ range: 1...100,
+ step: 1,
+ display: Int(vm.topK).formatted()
+ )
+
+ sliderRow(
+ label: "Repetition Penalty",
+ info: "Penalises repeating the same tokens. Values above 1.0 discourage repetition; 1.0 disables the penalty.",
+ value: $vm.repetitionPenalty,
+ range: 1.0...1.5,
+ step: 0.01,
+ display: String(format: "%.2f", vm.repetitionPenalty)
+ )
+
+ sliderRow(
+ label: "Max New Tokens",
+ info: "Upper bound on tokens generated per chunk. Longer values allow longer output but increase max generation time.",
+ value: $vm.maxNewTokens,
+ range: 50...500,
+ step: 10,
+ display: Int(vm.maxNewTokens).formatted()
+ )
+ }
+
+ // MARK: Chunking
+
+ Section("Chunking") {
+ // Strategy row: label + info on the left, segmented picker on the right.
+ // Using LabeledContent avoids macOS Form's two-column layout pushing
+ // a plain Spacer-separated HStack to extreme edges.
+ LabeledContent {
+ Picker("", selection: $vm.chunkingStrategyTag) {
+ Text("None").tag("none")
+ Text("Sentence").tag("sentence")
+ }
+ .pickerStyle(.segmented)
+ .frame(width: 160)
+ } label: {
+ HStack(spacing: 4) {
+ Text("Strategy")
+ InfoButton("Controls how long text is split before generation. Sentence-based chunking produces more natural prosody at chunk boundaries.")
+ }
+ }
+
+ sliderRow(
+ label: "Target Chunk Size",
+ info: "Maximum tokens per chunk when using sentence chunking. Chunks break at the nearest sentence boundary at or before this token count. Default (50).",
+ value: $vm.targetChunkSize,
+ range: 10...100,
+ step: 1,
+ display: "\(Int(vm.targetChunkSize)) tok",
+ displayWidth: 60
+ )
+ .disabled(vm.chunkingStrategyTag == "none")
+
+ sliderRow(
+ label: "Min Chunk Size",
+ info: "Minimum tokens per chunk. Short trailing segments are merged into the previous chunk to avoid tiny segments with poor prosody.",
+ value: $vm.minChunkSize,
+ range: 1...30,
+ step: 1,
+ display: "\(Int(vm.minChunkSize)) tok",
+ displayWidth: 60
+ )
+ .disabled(vm.chunkingStrategyTag == "none")
+ }
+
+ // MARK: Concurrency
+
+ Section("Concurrency") {
+ sliderRow(
+ label: "Concurrent Workers",
+ info: "How many chunks to generate in parallel. 0 = all chunks at once (fastest). 1 = sequential (best for streaming playback).",
+ value: $vm.concurrentWorkerCount,
+ range: 0...16,
+ step: 1,
+ display: vm.concurrentWorkerCount == 0 ? "Max" : Int(vm.concurrentWorkerCount).formatted()
+ )
+ }
+
+ // MARK: Reset
+
+ Section {
+ Button(role: .destructive) {
+ viewModel.resetGenerationSettings()
+ } label: {
+ HStack {
+ Spacer()
+ Text("Reset to Defaults")
+ Spacer()
+ }
+ }
+ }
+ }
+ // Use grouped style on both platforms so macOS renders as a scrollable list
+ // rather than its default two-column layout, which misaligns our custom rows.
+ .formStyle(.grouped)
+ #if os(iOS)
+ .navigationTitle("Generation Options")
+ .navigationBarTitleDisplayMode(.inline)
+ .toolbar {
+ ToolbarItem(placement: .primaryAction) {
+ Button {
+ dismiss()
+ } label: {
+ Label("Done", systemImage: "xmark.circle.fill")
+ .foregroundStyle(.primary)
+ }
+ }
+ }
+ #endif
+ }
+
+ // MARK: - Helpers
+
+ /// Reusable label + info + value display + slider row.
+    private func sliderRow(
+        label: String,
+        info: String,
+        value: Binding<Double>,
+        range: ClosedRange<Double>,
+        step: Double,
+        display: String,
+        displayWidth: CGFloat = 44
+    ) -> some View {
+        // Fixed-width trailing value label keeps sliders vertically aligned
+        // across rows regardless of the displayed number's width.
+        VStack(alignment: .leading, spacing: 4) {
+            HStack {
+                Text(label)
+                InfoButton(info)
+                Spacer()
+                Text(display)
+                    .monospacedDigit()
+                    .foregroundStyle(.secondary)
+                    .frame(width: displayWidth, alignment: .trailing)
+            }
+            Slider(value: value, in: range, step: step)
+        }
+        .padding(.vertical, 2)
+    }
+}
+
+// MARK: - Info Button
+
+/// Small info button that shows a popover with explanatory text.
+/// Matches the InfoButton pattern from WhisperAX.
+struct InfoButton: View {
+    let text: String
+    @State private var isPresentingPopover = false
+
+    init(_ text: String) {
+        self.text = text
+    }
+
+    var body: some View {
+        Button(action: { isPresentingPopover = true }) {
+            Image(systemName: "info.circle")
+                .foregroundStyle(.blue)
+        }
+        .buttonStyle(.borderless)
+        .popover(isPresented: $isPresentingPopover) {
+            // Fixed width with flexible height so long explanations wrap
+            // vertically instead of stretching the popover horizontally.
+            Text(text)
+                .multilineTextAlignment(.leading)
+                .fixedSize(horizontal: false, vertical: true)
+                .padding()
+                .frame(width: 260)
+        }
+    }
+}
diff --git a/Examples/TTS/SpeakAX/SpeakAX/ModelManagementView.swift b/Examples/TTS/SpeakAX/SpeakAX/ModelManagementView.swift
new file mode 100644
index 00000000..21f7fe0c
--- /dev/null
+++ b/Examples/TTS/SpeakAX/SpeakAX/ModelManagementView.swift
@@ -0,0 +1,190 @@
+// For licensing see accompanying LICENSE.md file.
+// Copyright © 2026 Argmax, Inc. All rights reserved.
+
+import SwiftUI
+import TTSKit
+
+/// Sidebar section for model selection, download, load/unload, and deletion.
+/// Mirrors the model management pattern from WhisperAX / Argmax Playground.
+struct ModelManagementView: View {
+ @Environment(ViewModel.self) private var viewModel
+
+ var body: some View {
+ @Bindable var vm = viewModel
+
+ VStack(alignment: .leading, spacing: 8) {
+ // Status row: dot + state label + accessory buttons
+ HStack(spacing: 6) {
+ Image(systemName: "circle.fill")
+ .font(.system(size: 8))
+ .foregroundStyle(vm.modelState.color)
+ .symbolEffect(.variableColor, isActive: vm.modelState.isBusy)
+
+ Text(vm.modelState.label)
+ .font(.caption)
+ .foregroundStyle(.secondary)
+
+ Spacer()
+
+ #if os(macOS)
+ // On macOS the picker fits in the status row alongside the accessory buttons
+ modelPicker
+ #endif
+
+ accessoryButtons
+ }
+
+ #if os(iOS)
+ // On iOS the picker gets its own row so it isn't squeezed by the accessory buttons
+ modelPicker
+ .frame(maxWidth: .infinity, alignment: .leading)
+ #endif
+
+ // Size estimate
+ Text(viewModel.modelDiskSize(for: vm.selectedPreset) ?? vm.selectedPreset.sizeEstimate)
+ .font(.caption2)
+ .foregroundStyle(.tertiary)
+
+ // Action slot: button OR progress - same height, no layout shift
+ if vm.modelState == .loaded && vm.loadedPreset == vm.selectedPreset {
+ Button { viewModel.unloadModel() } label: {
+ Label("Unload Model", systemImage: "eject")
+ .frame(maxWidth: .infinity)
+ .frame(height: 28)
+ }
+ .glassButton()
+ .controlSize(.small)
+ .accessibilityLabel("Unload model")
+ .accessibilityHint("Releases the model from memory. You can reload it later.")
+ } else if vm.modelState == .downloading {
+ VStack(spacing: 2) {
+ HStack {
+ ProgressView(value: vm.downloadProgress)
+ .progressViewStyle(.linear)
+ Text(String(format: "%.1f%%", vm.downloadProgress * 100))
+ .font(.caption)
+ .foregroundStyle(.secondary)
+ .monospacedDigit()
+ }
+ }
+ .frame(height: 28)
+ } else if vm.modelState == .prewarming || vm.modelState == .loading {
+ VStack(alignment: .leading, spacing: 2) {
+ ProgressView()
+ .progressViewStyle(.linear)
+ Text(vm.modelState == .prewarming
+ ? "Specializing for your device..."
+ : "Loading...")
+ .font(.caption2)
+ .foregroundStyle(.secondary)
+ }
+ .frame(height: 28)
+ } else {
+ Button { Task { await viewModel.loadModel() } } label: {
+ Label(
+ vm.isModelDownloaded ? "Load Model" : "Download & Load",
+ systemImage: vm.isModelDownloaded ? "arrow.up.circle" : "arrow.down.circle"
+ )
+ .frame(maxWidth: .infinity)
+ .frame(height: 28)
+ }
+ .glassButton(prominent: true)
+ .tint(.accentColor)
+ .controlSize(.small)
+ .accessibilityLabel(vm.isModelDownloaded ? "Load model" : "Download and load model")
+ .accessibilityHint(vm.isModelDownloaded
+ ? "Loads the downloaded model into memory"
+ : "Downloads the model (~\(vm.selectedPreset.sizeEstimate)) and loads it")
+ }
+ }
+ .padding(.horizontal, 12)
+ .padding(.vertical, 8)
+ }
+
+ // MARK: - Model Picker
+
+ private var modelPicker: some View {
+ @Bindable var vm = viewModel
+ return Picker("Model", selection: $vm.selectedPreset) {
+ ForEach(TTSModelVariant.allCases, id: \.self) { preset in
+ HStack {
+ let downloaded = vm.localModelPaths[preset] != nil
+ Image(systemName: downloaded ? "checkmark.circle" : "arrow.down.circle.dotted")
+ if preset.isAvailableOnCurrentPlatform {
+ Text(preset.displayName)
+ } else {
+ Text("\(preset.displayName) - Mac only")
+ .foregroundStyle(.tertiary)
+ }
+ }
+ .tag(preset)
+ }
+ }
+ .labelsHidden()
+ .pickerStyle(.menu)
+ .accessibilityLabel("Select model")
+ .accessibilityHint("Choose the TTS model to use for generation")
+ .onChange(of: vm.selectedPreset) {
+ if !vm.selectedPreset.isAvailableOnCurrentPlatform {
+ // Capture the name BEFORE reverting the selection, otherwise the
+ // message would name the default model instead of the unavailable one.
+ let unavailableName = vm.selectedPreset.displayName
+ vm.selectedPreset = .defaultForCurrentPlatform
+ vm.statusMessage = "\(unavailableName) requires macOS"
+ // NOTE(review): this reassignment re-triggers onChange, which may then
+ // overwrite statusMessage with "Downloaded"/"Not downloaded" - confirm UX.
+ return
+ }
+ if vm.loadedPreset == vm.selectedPreset { return }
+ if vm.localModelPaths[vm.selectedPreset] != nil {
+ vm.statusMessage = vm.modelState == .loaded
+ ? "Different model loaded"
+ : "Downloaded"
+ } else {
+ vm.statusMessage = "Not downloaded"
+ }
+ }
+ }
+
+ // MARK: - Accessory Buttons (trash, reveal, HF link)
+
+ @ViewBuilder
+ private var accessoryButtons: some View {
+ let vm = viewModel
+
+ // Delete
+ Button {
+ viewModel.deleteModel()
+ } label: {
+ Image(systemName: "trash")
+ }
+ .buttonStyle(.borderless)
+ .disabled(!vm.isModelDownloaded || vm.modelState.isBusy)
+ .help("Delete downloaded model files")
+ .accessibilityLabel("Delete model")
+ .accessibilityHint("Permanently deletes the downloaded model files from disk")
+
+ #if os(macOS)
+ // Reveal in Finder
+ if let path = vm.localModelPaths[vm.selectedPreset] {
+ Button {
+ NSWorkspace.shared.selectFile(nil, inFileViewerRootedAtPath: path)
+ } label: {
+ Image(systemName: "folder")
+ }
+ .buttonStyle(.borderless)
+ .help("Reveal in Finder")
+ }
+ #endif
+
+ // Open on HuggingFace
+ Button {
+ if let url = viewModel.modelRepoURL {
+ #if os(macOS)
+ NSWorkspace.shared.open(url)
+ #else
+ UIApplication.shared.open(url)
+ #endif
+ }
+ } label: {
+ Image(systemName: "link.circle")
+ }
+ .buttonStyle(.borderless)
+ .help("Open model on HuggingFace")
+ }
+}
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/AppIcon.icon/Assets/1024.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/AppIcon.icon/Assets/1024.png
new file mode 100644
index 00000000..55f3dedf
Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/AppIcon.icon/Assets/1024.png differ
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/AppIcon.icon/icon.json b/Examples/TTS/SpeakAX/SpeakAX/Resources/AppIcon.icon/icon.json
new file mode 100644
index 00000000..621c8067
--- /dev/null
+++ b/Examples/TTS/SpeakAX/SpeakAX/Resources/AppIcon.icon/icon.json
@@ -0,0 +1,37 @@
+{
+ "fill-specializations" : [
+ {
+ "value" : "automatic"
+ },
+ {
+ "appearance" : "dark",
+ "value" : "automatic"
+ }
+ ],
+ "groups" : [
+ {
+ "layers" : [
+ {
+ "glass" : false,
+ "hidden" : false,
+ "image-name" : "1024.png",
+ "name" : "1024"
+ }
+ ],
+ "shadow" : {
+ "kind" : "neutral",
+ "opacity" : 0.5
+ },
+ "translucency" : {
+ "enabled" : true,
+ "value" : 0.5
+ }
+ }
+ ],
+ "supported-platforms" : {
+ "circles" : [
+ "watchOS"
+ ],
+ "squares" : "shared"
+ }
+}
\ No newline at end of file
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/100.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/100.png
new file mode 100644
index 00000000..95cb96dd
Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/100.png differ
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/102.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/102.png
new file mode 100644
index 00000000..e81683a2
Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/102.png differ
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/1024 1.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/1024 1.png
new file mode 100644
index 00000000..55f3dedf
Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/1024 1.png differ
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/1024 2.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/1024 2.png
new file mode 100644
index 00000000..55f3dedf
Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/1024 2.png differ
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/1024.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/1024.png
new file mode 100644
index 00000000..a084825d
Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/1024.png differ
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/108.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/108.png
new file mode 100644
index 00000000..275e890f
Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/108.png differ
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/114.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/114.png
new file mode 100644
index 00000000..0cc74726
Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/114.png differ
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/120 1.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/120 1.png
new file mode 100644
index 00000000..4ddd0681
Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/120 1.png differ
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/120.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/120.png
new file mode 100644
index 00000000..4ddd0681
Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/120.png differ
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/128 1.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/128 1.png
new file mode 100644
index 00000000..f4e54b88
Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/128 1.png differ
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/128.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/128.png
new file mode 100644
index 00000000..7a629226
Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/128.png differ
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/136.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/136.png
new file mode 100644
index 00000000..1952cada
Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/136.png differ
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/152.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/152.png
new file mode 100644
index 00000000..4b3b3e42
Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/152.png differ
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/16.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/16.png
new file mode 100644
index 00000000..bc8da580
Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/16.png differ
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/167.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/167.png
new file mode 100644
index 00000000..b3ffac1b
Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/167.png differ
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/172.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/172.png
new file mode 100644
index 00000000..c4639557
Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/172.png differ
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/180.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/180.png
new file mode 100644
index 00000000..81467c9b
Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/180.png differ
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/192.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/192.png
new file mode 100644
index 00000000..5813d1a9
Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/192.png differ
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/196.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/196.png
new file mode 100644
index 00000000..469fde22
Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/196.png differ
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/216.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/216.png
new file mode 100644
index 00000000..a5a87744
Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/216.png differ
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/234.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/234.png
new file mode 100644
index 00000000..24d36305
Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/234.png differ
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/256.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/256.png
new file mode 100644
index 00000000..4c18e6de
Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/256.png differ
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/258.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/258.png
new file mode 100644
index 00000000..023c9c32
Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/258.png differ
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/32.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/32.png
new file mode 100644
index 00000000..a39f420b
Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/32.png differ
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/40.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/40.png
new file mode 100644
index 00000000..dad76cac
Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/40.png differ
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/44.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/44.png
new file mode 100644
index 00000000..46be1514
Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/44.png differ
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/48.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/48.png
new file mode 100644
index 00000000..ade9ec3b
Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/48.png differ
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/512.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/512.png
new file mode 100644
index 00000000..54ccc013
Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/512.png differ
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/55.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/55.png
new file mode 100644
index 00000000..3b67b790
Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/55.png differ
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/58 1.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/58 1.png
new file mode 100644
index 00000000..5fa76358
Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/58 1.png differ
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/58.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/58.png
new file mode 100644
index 00000000..5fa76358
Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/58.png differ
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/60 1.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/60 1.png
new file mode 100644
index 00000000..9fc760d4
Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/60 1.png differ
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/60.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/60.png
new file mode 100644
index 00000000..a27515cf
Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/60.png differ
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/64 1.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/64 1.png
new file mode 100644
index 00000000..7cff860d
Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/64 1.png differ
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/64.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/64.png
new file mode 100644
index 00000000..e23c66d7
Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/64.png differ
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/66.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/66.png
new file mode 100644
index 00000000..a4e383f3
Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/66.png differ
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/76.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/76.png
new file mode 100644
index 00000000..1a1259ee
Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/76.png differ
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/80 1.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/80 1.png
new file mode 100644
index 00000000..210fd08b
Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/80 1.png differ
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/80.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/80.png
new file mode 100644
index 00000000..210fd08b
Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/80.png differ
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/87 1.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/87 1.png
new file mode 100644
index 00000000..e3744a97
Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/87 1.png differ
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/87.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/87.png
new file mode 100644
index 00000000..e3744a97
Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/87.png differ
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/88.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/88.png
new file mode 100644
index 00000000..262c010d
Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/88.png differ
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/92.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/92.png
new file mode 100644
index 00000000..8952531b
Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/92.png differ
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/Contents.json b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/Contents.json
new file mode 100644
index 00000000..d2c94f1c
--- /dev/null
+++ b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/Contents.json
@@ -0,0 +1,318 @@
+{
+ "images" : [
+ {
+ "filename" : "40.png",
+ "idiom" : "universal",
+ "platform" : "ios",
+ "scale" : "2x",
+ "size" : "20x20"
+ },
+ {
+ "filename" : "60.png",
+ "idiom" : "universal",
+ "platform" : "ios",
+ "scale" : "3x",
+ "size" : "20x20"
+ },
+ {
+ "filename" : "58 1.png",
+ "idiom" : "universal",
+ "platform" : "ios",
+ "scale" : "2x",
+ "size" : "29x29"
+ },
+ {
+ "filename" : "87 1.png",
+ "idiom" : "universal",
+ "platform" : "ios",
+ "scale" : "3x",
+ "size" : "29x29"
+ },
+ {
+ "filename" : "76.png",
+ "idiom" : "universal",
+ "platform" : "ios",
+ "scale" : "2x",
+ "size" : "38x38"
+ },
+ {
+ "filename" : "114.png",
+ "idiom" : "universal",
+ "platform" : "ios",
+ "scale" : "3x",
+ "size" : "38x38"
+ },
+ {
+ "filename" : "80 1.png",
+ "idiom" : "universal",
+ "platform" : "ios",
+ "scale" : "2x",
+ "size" : "40x40"
+ },
+ {
+ "filename" : "120.png",
+ "idiom" : "universal",
+ "platform" : "ios",
+ "scale" : "3x",
+ "size" : "40x40"
+ },
+ {
+ "filename" : "120 1.png",
+ "idiom" : "universal",
+ "platform" : "ios",
+ "scale" : "2x",
+ "size" : "60x60"
+ },
+ {
+ "filename" : "180.png",
+ "idiom" : "universal",
+ "platform" : "ios",
+ "scale" : "3x",
+ "size" : "60x60"
+ },
+ {
+ "filename" : "128 1.png",
+ "idiom" : "universal",
+ "platform" : "ios",
+ "scale" : "2x",
+ "size" : "64x64"
+ },
+ {
+ "filename" : "192.png",
+ "idiom" : "universal",
+ "platform" : "ios",
+ "scale" : "3x",
+ "size" : "64x64"
+ },
+ {
+ "filename" : "136.png",
+ "idiom" : "universal",
+ "platform" : "ios",
+ "scale" : "2x",
+ "size" : "68x68"
+ },
+ {
+ "filename" : "152.png",
+ "idiom" : "universal",
+ "platform" : "ios",
+ "scale" : "2x",
+ "size" : "76x76"
+ },
+ {
+ "filename" : "167.png",
+ "idiom" : "universal",
+ "platform" : "ios",
+ "scale" : "2x",
+ "size" : "83.5x83.5"
+ },
+ {
+ "filename" : "1024 1.png",
+ "idiom" : "universal",
+ "platform" : "ios",
+ "size" : "1024x1024"
+ },
+ {
+ "filename" : "16.png",
+ "idiom" : "mac",
+ "scale" : "1x",
+ "size" : "16x16"
+ },
+ {
+ "filename" : "32.png",
+ "idiom" : "mac",
+ "scale" : "2x",
+ "size" : "16x16"
+ },
+ {
+ "filename" : "32.png",
+ "idiom" : "mac",
+ "scale" : "1x",
+ "size" : "32x32"
+ },
+ {
+ "filename" : "64.png",
+ "idiom" : "mac",
+ "scale" : "2x",
+ "size" : "32x32"
+ },
+ {
+ "filename" : "128.png",
+ "idiom" : "mac",
+ "scale" : "1x",
+ "size" : "128x128"
+ },
+ {
+ "filename" : "256.png",
+ "idiom" : "mac",
+ "scale" : "2x",
+ "size" : "128x128"
+ },
+ {
+ "filename" : "256.png",
+ "idiom" : "mac",
+ "scale" : "1x",
+ "size" : "256x256"
+ },
+ {
+ "filename" : "512.png",
+ "idiom" : "mac",
+ "scale" : "2x",
+ "size" : "256x256"
+ },
+ {
+ "filename" : "512.png",
+ "idiom" : "mac",
+ "scale" : "1x",
+ "size" : "512x512"
+ },
+ {
+ "filename" : "1024.png",
+ "idiom" : "mac",
+ "scale" : "2x",
+ "size" : "512x512"
+ },
+ {
+ "filename" : "44.png",
+ "idiom" : "universal",
+ "platform" : "watchos",
+ "scale" : "2x",
+ "size" : "22x22"
+ },
+ {
+ "filename" : "48.png",
+ "idiom" : "universal",
+ "platform" : "watchos",
+ "scale" : "2x",
+ "size" : "24x24"
+ },
+ {
+ "filename" : "55.png",
+ "idiom" : "universal",
+ "platform" : "watchos",
+ "scale" : "2x",
+ "size" : "27.5x27.5"
+ },
+ {
+ "filename" : "58.png",
+ "idiom" : "universal",
+ "platform" : "watchos",
+ "scale" : "2x",
+ "size" : "29x29"
+ },
+ {
+ "filename" : "60 1.png",
+ "idiom" : "universal",
+ "platform" : "watchos",
+ "scale" : "2x",
+ "size" : "30x30"
+ },
+ {
+ "filename" : "64 1.png",
+ "idiom" : "universal",
+ "platform" : "watchos",
+ "scale" : "2x",
+ "size" : "32x32"
+ },
+ {
+ "filename" : "66.png",
+ "idiom" : "universal",
+ "platform" : "watchos",
+ "scale" : "2x",
+ "size" : "33x33"
+ },
+ {
+ "filename" : "80.png",
+ "idiom" : "universal",
+ "platform" : "watchos",
+ "scale" : "2x",
+ "size" : "40x40"
+ },
+ {
+ "filename" : "87.png",
+ "idiom" : "universal",
+ "platform" : "watchos",
+ "scale" : "2x",
+ "size" : "43.5x43.5"
+ },
+ {
+ "filename" : "88.png",
+ "idiom" : "universal",
+ "platform" : "watchos",
+ "scale" : "2x",
+ "size" : "44x44"
+ },
+ {
+ "filename" : "92.png",
+ "idiom" : "universal",
+ "platform" : "watchos",
+ "scale" : "2x",
+ "size" : "46x46"
+ },
+ {
+ "filename" : "100.png",
+ "idiom" : "universal",
+ "platform" : "watchos",
+ "scale" : "2x",
+ "size" : "50x50"
+ },
+ {
+ "filename" : "102.png",
+ "idiom" : "universal",
+ "platform" : "watchos",
+ "scale" : "2x",
+ "size" : "51x51"
+ },
+ {
+ "filename" : "108.png",
+ "idiom" : "universal",
+ "platform" : "watchos",
+ "scale" : "2x",
+ "size" : "54x54"
+ },
+ {
+ "filename" : "172.png",
+ "idiom" : "universal",
+ "platform" : "watchos",
+ "scale" : "2x",
+ "size" : "86x86"
+ },
+ {
+ "filename" : "196.png",
+ "idiom" : "universal",
+ "platform" : "watchos",
+ "scale" : "2x",
+ "size" : "98x98"
+ },
+ {
+ "filename" : "216.png",
+ "idiom" : "universal",
+ "platform" : "watchos",
+ "scale" : "2x",
+ "size" : "108x108"
+ },
+ {
+ "filename" : "234.png",
+ "idiom" : "universal",
+ "platform" : "watchos",
+ "scale" : "2x",
+ "size" : "117x117"
+ },
+ {
+ "filename" : "258.png",
+ "idiom" : "universal",
+ "platform" : "watchos",
+ "scale" : "2x",
+ "size" : "129x129"
+ },
+ {
+ "filename" : "1024 2.png",
+ "idiom" : "universal",
+ "platform" : "watchos",
+ "size" : "1024x1024"
+ }
+ ],
+ "info" : {
+ "author" : "xcode",
+ "version" : 1
+ }
+}
diff --git a/Examples/WhisperAX/WhisperAXWatchApp/Assets.xcassets/AccentColor.colorset/Contents.json b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/Contents.json
similarity index 51%
rename from Examples/WhisperAX/WhisperAXWatchApp/Assets.xcassets/AccentColor.colorset/Contents.json
rename to Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/Contents.json
index eb878970..73c00596 100644
--- a/Examples/WhisperAX/WhisperAXWatchApp/Assets.xcassets/AccentColor.colorset/Contents.json
+++ b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/Contents.json
@@ -1,9 +1,4 @@
{
- "colors" : [
- {
- "idiom" : "universal"
- }
- ],
"info" : {
"author" : "xcode",
"version" : 1
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/argmaxLogo.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/argmaxLogo.png
new file mode 100644
index 00000000..023c9c32
Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/argmaxLogo.png differ
diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Info.plist b/Examples/TTS/SpeakAX/SpeakAX/Resources/Info.plist
new file mode 100644
index 00000000..0c67376e
--- /dev/null
+++ b/Examples/TTS/SpeakAX/SpeakAX/Resources/Info.plist
@@ -0,0 +1,5 @@
+
+
+
+
+
diff --git a/Examples/TTS/SpeakAX/SpeakAX/SidebarView.swift b/Examples/TTS/SpeakAX/SpeakAX/SidebarView.swift
new file mode 100644
index 00000000..179ebd59
--- /dev/null
+++ b/Examples/TTS/SpeakAX/SpeakAX/SidebarView.swift
@@ -0,0 +1,225 @@
+// For licensing see accompanying LICENSE.md file.
+// Copyright © 2026 Argmax, Inc. All rights reserved.
+
+import SwiftUI
+
+struct SidebarView: View {
+ @Environment(ViewModel.self) private var viewModel
+
+ var body: some View {
+ @Bindable var vm = viewModel
+
+ VStack(spacing: 0) {
+ // Model management at the top of the sidebar
+ ModelManagementView()
+
+ Divider()
+
+ // Compute units configuration
+ ComputeUnitsView()
+ .disabled(vm.modelState.isBusy)
+
+ Divider()
+
+ // Generation history list
+ List(selection: $vm.selectedGenerationID) {
+ if !vm.favoriteGenerations.isEmpty {
+ Section("Favorites") {
+ ForEach(vm.favoriteGenerations) { gen in
+ GenerationRow(generation: gen)
+ .tag(gen.id)
+ .contextMenu { rowContextMenu(for: gen) }
+ #if os(iOS)
+ .swipeActions(edge: .trailing) { rowSwipeActions(for: gen) }
+ #endif
+ }
+ }
+ }
+
+ Section("Recents") {
+ ForEach(vm.generations) { gen in
+ GenerationRow(generation: gen)
+ .tag(gen.id)
+ .contextMenu { rowContextMenu(for: gen) }
+ #if os(iOS)
+ .swipeActions(edge: .trailing) { rowSwipeActions(for: gen) }
+ #endif
+ }
+ }
+ }
+ .overlay {
+ if vm.generations.isEmpty {
+ ContentUnavailableView(
+ "No Generations Yet",
+ systemImage: "waveform",
+ description: Text("Generated speech will appear here.")
+ )
+ }
+ }
+
+ Divider()
+
+ // App & device info footer
+ appInfoFooter
+ }
+ .navigationTitle("TTSKit Example")
+ #if os(macOS)
+ .navigationSplitViewColumnWidth(min: 260, ideal: 280, max: 400)
+ #elseif os(iOS)
+ .toolbar {
+ ToolbarItemGroup(placement: .bottomBar) {
+ Spacer()
+ Button(action: newGeneration) {
+ Image(systemName: "plus")
+ .font(.title2)
+ .bold()
+ }
+ }
+ }
+ #endif
+ }
+
+ // MARK: - New Generation
+
+ private func newGeneration() {
+ viewModel.clearInput()
+ viewModel.currentWaveform = []
+ // Use a sentinel UUID to show the detail in "new generation" mode.
+ // This drives navigation uniformly on all platforms via List(selection:).
+ viewModel.selectedGenerationID = ViewModel.newGenerationSentinel
+ }
+
+ // MARK: - App Info Footer
+
+ private var appInfoFooter: some View {
+ VStack(alignment: .leading, spacing: 4) {
+ let version = Bundle.main.infoDictionary?["CFBundleShortVersionString"] as? String ?? "Unknown"
+ let build = Bundle.main.infoDictionary?["CFBundleVersion"] as? String ?? "Unknown"
+ Text("App Version: \(version) (\(build))")
+ #if os(iOS)
+ Text("Device: \(UIDevice.current.model)")
+ Text("OS: \(UIDevice.current.systemVersion)")
+ #elseif os(macOS)
+ Text("OS: \(ProcessInfo.processInfo.operatingSystemVersionString)")
+ #endif
+ }
+ .font(.system(.caption2, design: .monospaced))
+ .foregroundStyle(.tertiary)
+ .frame(maxWidth: .infinity, alignment: .leading)
+ .padding(.horizontal, 12)
+ .padding(.vertical, 8)
+ }
+
+ // MARK: - Context Menu / Swipe Actions
+
+ @ViewBuilder
+ private func rowContextMenu(for generation: Generation) -> some View {
+ Button {
+ viewModel.playGeneration(generation)
+ } label: {
+ Label("Play", systemImage: "play.fill")
+ }
+
+ Button {
+ viewModel.toggleFavorite(generation.id)
+ } label: {
+ Label(
+ generation.isFavorite ? "Unfavorite" : "Favorite",
+ systemImage: generation.isFavorite ? "star.slash" : "star"
+ )
+ }
+
+ if let url = viewModel.audioFileURL(for: generation) {
+ AudioCopyButton(url: url)
+
+ #if os(macOS)
+ Button {
+ NSWorkspace.shared.activateFileViewerSelecting([url])
+ } label: {
+ Label("Show in Finder", systemImage: "folder")
+ }
+ #endif
+ }
+
+ Divider()
+
+ Button(role: .destructive) {
+ viewModel.deleteGeneration(generation.id)
+ } label: {
+ Label("Delete", systemImage: "trash")
+ }
+ }
+
+ #if os(iOS)
+ @ViewBuilder
+ private func rowSwipeActions(for generation: Generation) -> some View {
+ Button(role: .destructive) {
+ viewModel.deleteGeneration(generation.id)
+ } label: {
+ Label("Delete", systemImage: "trash")
+ }
+ }
+ #endif
+}
+
+// MARK: - Row
+
+struct GenerationRow: View {
+ @Environment(ViewModel.self) private var viewModel
+ let generation: Generation
+ @State private var isHovered = false
+
+ var body: some View {
+ HStack(spacing: 10) {
+ if let samples = generation.waveformSamples, !samples.isEmpty {
+ WaveformThumbnail(samples: samples)
+ } else {
+ Image(systemName: "waveform")
+ .frame(width: 48, height: 24)
+ .foregroundStyle(.tertiary)
+ }
+
+ VStack(alignment: .leading, spacing: 2) {
+ Text(generation.title)
+ .font(.body)
+ .lineLimit(1)
+ .truncationMode(.tail)
+
+ HStack(spacing: 6) {
+ Text(generation.date, style: .date)
+ Text("·")
+ Text(String(format: "%.1fs", generation.audioDuration))
+ Text("·")
+ Text(generation.speaker.capitalized)
+ }
+ .font(.caption)
+ .foregroundStyle(.secondary)
+ }
+
+ Spacer()
+
+ if generation.isFavorite {
+ Image(systemName: "star.fill")
+ .font(.caption)
+ .foregroundStyle(.yellow)
+ }
+
+ #if os(macOS)
+ Button {
+ viewModel.deleteGeneration(generation.id)
+ } label: {
+ Image(systemName: "trash")
+ .font(.caption)
+ .foregroundStyle(.secondary)
+ }
+ .buttonStyle(.plain)
+ .opacity(isHovered ? 1 : 0)
+ .animation(.easeInOut(duration: 0.15), value: isHovered)
+ #endif
+ }
+ .padding(.vertical, 2)
+ #if os(macOS)
+ .onHover { isHovered = $0 }
+ #endif
+ }
+}
diff --git a/Examples/TTS/SpeakAX/SpeakAX/SpeakAXApp.swift b/Examples/TTS/SpeakAX/SpeakAX/SpeakAXApp.swift
new file mode 100644
index 00000000..bc5e52ed
--- /dev/null
+++ b/Examples/TTS/SpeakAX/SpeakAX/SpeakAXApp.swift
@@ -0,0 +1,87 @@
+// For licensing see accompanying LICENSE.md file.
+// Copyright © 2026 Argmax, Inc. All rights reserved.
+
+import SwiftUI
+
+@main
+struct SpeakAXApp: App {
+    /// Single shared view model, injected into the environment for the whole scene.
+    @State private var viewModel = ViewModel()
+
+    var body: some Scene {
+        WindowGroup {
+            ContentView()
+                .environment(viewModel)
+            #if os(iOS)
+                .onAppear { installKeyboardDismissGesture() }
+            #endif
+                // Cancel in-flight generation/playback work when the app is about to
+                // lose foreground (iOS) or terminate (macOS). The notification name is
+                // chosen by an immediately-invoked closure so one `.onReceive` serves
+                // both platforms.
+                .onReceive(NotificationCenter.default.publisher(
+                    for: {
+                        #if os(iOS)
+                        UIApplication.willResignActiveNotification
+                        #else
+                        NSApplication.willTerminateNotification
+                        #endif
+                    }()
+                )) { _ in
+                    viewModel.cancelAllTasks()
+                }
+        }
+        // Initial window size on macOS; ignored on iOS.
+        .defaultSize(width: 1100, height: 700)
+    }
+}
+
+// MARK: - iOS keyboard dismiss on tap outside
+
+#if os(iOS)
+import UIKit
+
+/// Installs a window-level tap gesture recognizer that dismisses the keyboard whenever
+/// the user taps outside a UITextView or UITextField. Uses `cancelsTouchesInView = false`
+/// and a delegate check so text selection and all other controls continue to work normally.
+private func installKeyboardDismissGesture() {
+    let keyWindow = UIApplication.shared
+        .connectedScenes
+        .compactMap { $0 as? UIWindowScene }
+        .flatMap(\.windows)
+        .first { $0.isKeyWindow }
+
+    guard let window = keyWindow else { return }
+
+    // Idempotent: do nothing if one of our recognizers is already attached.
+    let alreadyInstalled = window.gestureRecognizers?
+        .contains { $0 is KeyboardDismissTapRecognizer } ?? false
+    guard !alreadyInstalled else { return }
+
+    window.addGestureRecognizer(KeyboardDismissTapRecognizer())
+}
+
+private final class KeyboardDismissTapRecognizer: UITapGestureRecognizer, UIGestureRecognizerDelegate {
+    init() {
+        // No external target/action: the recognizer targets itself below.
+        super.init(target: nil, action: nil)
+        // Let touches continue to the views underneath, so taps still activate controls.
+        cancelsTouchesInView = false
+        delegate = self
+        addTarget(self, action: #selector(handleTap))
+    }
+
+    /// Sends `resignFirstResponder` up the responder chain, hiding the keyboard
+    /// regardless of which text input currently owns it.
+    @objc private func handleTap() {
+        UIApplication.shared.sendAction(
+            #selector(UIResponder.resignFirstResponder),
+            to: nil, from: nil, for: nil
+        )
+    }
+
+    /// Only fire when the touch lands outside a text input view.
+    func gestureRecognizer(
+        _ gestureRecognizer: UIGestureRecognizer,
+        shouldReceive touch: UITouch
+    ) -> Bool {
+        !(touch.view is UITextView || touch.view is UITextField)
+    }
+
+    /// Always allow simultaneous recognition with every other gesture (buttons, scrolls, etc.)
+    func gestureRecognizer(
+        _ gestureRecognizer: UIGestureRecognizer,
+        shouldRecognizeSimultaneouslyWith other: UIGestureRecognizer
+    ) -> Bool {
+        true
+    }
+}
+#endif
diff --git a/Examples/TTS/SpeakAX/SpeakAX/ViewModel.swift b/Examples/TTS/SpeakAX/SpeakAX/ViewModel.swift
new file mode 100644
index 00000000..8497f6ad
--- /dev/null
+++ b/Examples/TTS/SpeakAX/SpeakAX/ViewModel.swift
@@ -0,0 +1,1124 @@
+// For licensing see accompanying LICENSE.md file.
+// Copyright © 2026 Argmax, Inc. All rights reserved.
+
+@preconcurrency import AVFoundation
+import CoreML
+import Foundation
+import Observation
+import SwiftUI
+import Tokenizers
+import TTSKit
+import WhisperKit
+#if os(macOS)
+import AppKit
+#endif
+
+// MARK: - Data Model
+
+/// An in-memory representation of a generated audio clip.
+///
+/// All persistent fields live inside the M4A file's metadata - there is no
+/// companion JSON or database. `Generation` is reconstructed on launch by
+/// scanning the documents directory for `.m4a` files and calling
+/// `AudioOutput.loadMetadata(from:)` on each one.
+///
+/// `isFavorite` is the only mutable bit of state not in the file itself; it
+/// is stored as a `Set` in `UserDefaults` so we never need to re-encode
+/// the M4A.
+@MainActor
+struct Generation: Identifiable {
+    let id: UUID
+    /// Display title: the first `ViewModel.titleMaxLength` characters of `text`.
+    let title: String
+    let text: String
+    let speaker: String
+    let language: String
+    let instruction: String
+    let modelName: String
+    /// Duration of the rendered audio, in seconds.
+    let audioDuration: TimeInterval
+    // Performance metrics captured at generation time (copied from metadata).
+    let realTimeFactor: Double
+    let speedFactor: Double
+    let stepsPerSecond: Double
+    let timeToFirstBuffer: TimeInterval
+    let date: Date
+    // Mutable: persisted separately in UserDefaults (see type doc above).
+    var isFavorite: Bool
+    /// File name (not a full path) of the saved M4A inside the documents directory.
+    let audioFileName: String
+    /// Cached waveform peaks for thumbnails; nil until loaded/computed.
+    var waveformSamples: [Float]?
+
+    /// Builds a `Generation` from metadata read out of a saved M4A file.
+    init(
+        metadata: AudioMetadata,
+        audioFileName: String,
+        audioDuration: TimeInterval,
+        isFavorite: Bool
+    ) {
+        self.id = metadata.id
+        // Truncate for list display; the full text is kept in `text`.
+        self.title = String(metadata.text.prefix(ViewModel.titleMaxLength))
+        self.text = metadata.text
+        self.speaker = metadata.speaker
+        self.language = metadata.language
+        self.instruction = metadata.instruction
+        self.modelName = metadata.modelName
+        self.audioDuration = audioDuration
+        self.realTimeFactor = metadata.realTimeFactor
+        self.speedFactor = metadata.speedFactor
+        self.stepsPerSecond = metadata.stepsPerSecond
+        self.timeToFirstBuffer = metadata.timeToFirstBuffer
+        self.date = metadata.date
+        self.isFavorite = isFavorite
+        self.audioFileName = audioFileName
+        self.waveformSamples = nil
+    }
+}
+
+// MARK: - Model State
+
+/// Lifecycle of the currently selected model, surfaced to the UI as a
+/// status label, an indicator color, and a busy flag.
+enum ModelState: Equatable {
+    case unloaded
+    case downloading
+    case prewarming
+    case loading
+    case loaded
+    case error(String)
+
+    /// Human-readable status text for the model indicator.
+    var label: String {
+        switch self {
+        case .unloaded:
+            return "Unloaded"
+        case .downloading:
+            return "Downloading..."
+        case .prewarming:
+            return "Specializing..."
+        case .loading:
+            return "Loading..."
+        case .loaded:
+            return "Loaded"
+        case let .error(msg):
+            return "Error: \(msg)"
+        }
+    }
+
+    /// Indicator color: green when ready, yellow while transitioning, red otherwise.
+    var color: Color {
+        switch self {
+        case .loaded:
+            return .green
+        case .downloading, .prewarming, .loading:
+            return .yellow
+        case .unloaded, .error:
+            return .red
+        }
+    }
+
+    /// True while a transition (download, prewarm, or load) is in flight.
+    var isBusy: Bool {
+        switch self {
+        case .downloading, .prewarming, .loading:
+            return true
+        case .unloaded, .loaded, .error:
+            return false
+        }
+    }
+}
+
+/// Whether a speech-generation request is currently running.
+enum GenerationState: Equatable {
+    case idle
+    case generating
+}
+
+// MARK: - Settings Storage
+
+/// Holds all persisted settings using @AppStorage.
+/// Lives inside ViewModel and is read once at init time to seed the
+/// observable stored properties. Each stored property's didSet writes
+/// the new value back here so UserDefaults stays in sync.
+@MainActor
+private class Settings {
+    /// Model selection
+    // Enum selections are stored by rawValue string so UserDefaults stays plain.
+    @AppStorage("selectedPreset") var selectedPresetRaw: String = TTSModelVariant.defaultForCurrentPlatform.rawValue
+    @AppStorage("selectedSpeaker") var selectedSpeakerRaw: String = Qwen3Speaker.ryan.rawValue
+    @AppStorage("selectedLanguage") var selectedLanguageRaw: String = Qwen3Language.english.rawValue
+    @AppStorage("playbackStrategyTag") var playbackStrategyTag: String = "auto"
+
+    /// Generation options
+    // Numeric options are stored as Double (slider-friendly) and narrowed at use sites.
+    @AppStorage("genTemperature") var temperature: Double = .init(GenerationOptions.defaultTemperature)
+    @AppStorage("genTopK") var topK: Double = .init(GenerationOptions.defaultTopK)
+    @AppStorage("genRepetitionPenalty") var repetitionPenalty: Double = .init(GenerationOptions.defaultRepetitionPenalty)
+    @AppStorage("genMaxNewTokens") var maxNewTokens: Double = .init(GenerationOptions.defaultMaxNewTokens)
+    @AppStorage("genConcurrentWorkerCount") var concurrentWorkerCount: Double = 0
+    @AppStorage("genChunkingStrategy") var chunkingStrategyTag: String = "sentence"
+    @AppStorage("genTargetChunkSize") var targetChunkSize: Double = .init(TextChunker.defaultTargetChunkSize)
+    @AppStorage("genMinChunkSize") var minChunkSize: Double = .init(TextChunker.defaultMinChunkSize)
+
+    /// Compute units (stored as Int matching MLComputeUnits.rawValue)
+    @AppStorage("embedderComputeUnits") var embedderComputeUnitsRaw: Int = MLComputeUnits.cpuOnly.rawValue
+    @AppStorage("codeDecoderComputeUnits") var codeDecoderComputeUnitsRaw: Int = MLComputeUnits.cpuAndNeuralEngine.rawValue
+    @AppStorage("multiCodeDecoderComputeUnits") var multiCodeDecoderComputeUnitsRaw: Int = MLComputeUnits.cpuAndNeuralEngine.rawValue
+    @AppStorage("speechDecoderComputeUnits") var speechDecoderComputeUnitsRaw: Int = MLComputeUnits.cpuAndNeuralEngine.rawValue
+}
+
+// MARK: - View Model
+
+@MainActor
+@Observable
+final class ViewModel: @unchecked Sendable {
+    // MARK: - Constants
+
+    /// Maximum characters copied from the source text into `Generation.title`.
+    fileprivate static let titleMaxLength = 40
+    /// Maximum number of rows shown in the "recent" history section.
+    private static let historyLimit = 20
+    /// Poll interval for playback position updates (~30 fps).
+    private static let playbackPollIntervalMs = 33
+    /// Debounce delay before running a token count after a keystroke.
+    private static let tokenCountDebounceMs = 250
+
+    // MARK: - Settings storage
+
+    // Backing store for all persisted values; seeded from in init, written back via didSet.
+    private let settings = Settings()
+
+    // MARK: - Model management
+
+    var modelState: ModelState = .unloaded
+    // 0...1 fraction of the current model download.
+    var downloadProgress: Double = 0
+    // Variant -> absolute path of the downloaded repo folder on disk.
+    var localModelPaths: [TTSModelVariant: String] = [:]
+
+    // MARK: - Persisted: model selection
+
+    var selectedPreset: TTSModelVariant {
+        didSet { settings.selectedPresetRaw = selectedPreset.rawValue }
+    }
+
+    // MARK: - Generation state
+
+    var generationState: GenerationState = .idle
+    var statusMessage = "Select a model to get started"
+
+    // MARK: - Persisted: input defaults
+
+    var selectedSpeaker: Qwen3Speaker {
+        didSet { settings.selectedSpeakerRaw = selectedSpeaker.rawValue }
+    }
+
+    var selectedLanguage: Qwen3Language {
+        didSet { settings.selectedLanguageRaw = selectedLanguage.rawValue }
+    }
+
+    var playbackStrategyTag: String {
+        didSet { settings.playbackStrategyTag = playbackStrategyTag }
+    }
+
+    /// Maps the persisted string tag onto the TTSKit playback strategy.
+    /// Unknown tags fall back to `.auto`.
+    var selectedPlaybackStrategy: PlaybackStrategy {
+        switch playbackStrategyTag {
+        case "stream": return .stream
+        case "generateFirst": return .generateFirst
+        default: return .auto
+        }
+    }
+
+    // Optional free-form instruction; deliberately NOT persisted across launches.
+    var instruction: String = ""
+
+    // MARK: - Persisted: generation options
+
+    var temperature: Double { didSet { settings.temperature = temperature } }
+    var topK: Double { didSet { settings.topK = topK } }
+    var repetitionPenalty: Double { didSet { settings.repetitionPenalty = repetitionPenalty } }
+    var maxNewTokens: Double { didSet { settings.maxNewTokens = maxNewTokens } }
+    var concurrentWorkerCount: Double { didSet { settings.concurrentWorkerCount = concurrentWorkerCount } }
+    var chunkingStrategyTag: String { didSet { settings.chunkingStrategyTag = chunkingStrategyTag } }
+
+    /// Typed view of the persisted chunking-strategy tag; unknown tags -> `.sentence`.
+    var chunkingStrategy: TextChunkingStrategy {
+        TextChunkingStrategy(rawValue: chunkingStrategyTag) ?? .sentence
+    }
+
+    var targetChunkSize: Double { didSet { settings.targetChunkSize = targetChunkSize } }
+    var minChunkSize: Double { didSet { settings.minChunkSize = minChunkSize } }
+
+    // MARK: - Persisted: compute units
+
+    var embedderComputeUnits: MLComputeUnits {
+        didSet { settings.embedderComputeUnitsRaw = embedderComputeUnits.rawValue }
+    }
+
+    var codeDecoderComputeUnits: MLComputeUnits {
+        didSet { settings.codeDecoderComputeUnitsRaw = codeDecoderComputeUnits.rawValue }
+    }
+
+    var multiCodeDecoderComputeUnits: MLComputeUnits {
+        didSet { settings.multiCodeDecoderComputeUnitsRaw = multiCodeDecoderComputeUnits.rawValue }
+    }
+
+    var speechDecoderComputeUnits: MLComputeUnits {
+        didSet { settings.speechDecoderComputeUnitsRaw = speechDecoderComputeUnits.rawValue }
+    }
+
+    /// Snapshot of the four compute-unit selections in the form TTSKit expects.
+    var computeOptions: ComputeOptions {
+        ComputeOptions(
+            embedderComputeUnits: embedderComputeUnits,
+            codeDecoderComputeUnits: codeDecoderComputeUnits,
+            multiCodeDecoderComputeUnits: multiCodeDecoderComputeUnits,
+            speechDecoderComputeUnits: speechDecoderComputeUnits
+        )
+    }
+
+    // MARK: - Init
+
+    /// Seeds every persisted stored property from the @AppStorage backing store.
+    /// Note: Swift does not run `didSet` observers during `init`, so these
+    /// assignments do not redundantly write the same values back to UserDefaults.
+    init() {
+        // Seed all persisted properties from @AppStorage backing store
+        // Resolve the persisted preset, falling back to the platform default.
+        // If a previously saved preset is no longer available on this platform (e.g. 0.6B
+        // saved on macOS, then opened on iOS), quietly switch to the platform default.
+        let savedPreset = TTSModelVariant(rawValue: settings.selectedPresetRaw)
+        selectedPreset = (savedPreset?.isAvailableOnCurrentPlatform == true)
+            ? savedPreset!
+            : .defaultForCurrentPlatform
+        // Unknown raw values (e.g. removed enum cases) fall back to sensible defaults.
+        selectedSpeaker = Qwen3Speaker(rawValue: settings.selectedSpeakerRaw) ?? .ryan
+        selectedLanguage = Qwen3Language(rawValue: settings.selectedLanguageRaw) ?? .english
+        playbackStrategyTag = settings.playbackStrategyTag
+        temperature = settings.temperature
+        topK = settings.topK
+        repetitionPenalty = settings.repetitionPenalty
+        maxNewTokens = settings.maxNewTokens
+        concurrentWorkerCount = settings.concurrentWorkerCount
+        chunkingStrategyTag = settings.chunkingStrategyTag
+        targetChunkSize = settings.targetChunkSize
+        minChunkSize = settings.minChunkSize
+        embedderComputeUnits = MLComputeUnits(rawValue: settings.embedderComputeUnitsRaw) ?? .cpuOnly
+        codeDecoderComputeUnits = MLComputeUnits(rawValue: settings.codeDecoderComputeUnitsRaw) ?? .cpuAndNeuralEngine
+        multiCodeDecoderComputeUnits = MLComputeUnits(rawValue: settings.multiCodeDecoderComputeUnitsRaw) ?? .cpuAndNeuralEngine
+        speechDecoderComputeUnits = MLComputeUnits(rawValue: settings.speechDecoderComputeUnitsRaw) ?? .cpuAndNeuralEngine
+    }
+
+    // MARK: - Generation output
+
+    // Live outputs for the most recent (or in-progress) generation.
+    var currentWaveform: [Float] = []
+    var currentAudioSamples: [Float] = []
+    var currentDuration: TimeInterval = 0
+    var currentRTF: Double = 0
+    var currentSpeedFactor: Double = 0
+    var currentStepsPerSecond: Double = 0
+    var currentTimeToFirstBuffer: TimeInterval = 0
+
+    // MARK: - Playback & streaming
+
+    var isPlaying = false
+    /// Current playback position in seconds (works for both streaming and replay)
+    var playbackTime: TimeInterval = 0
+    /// True while we're in a live generate-and-play session
+    var isStreaming = false
+    /// Reference to the active audio output during streaming, for querying playback position
+    private var activeAudioOutput: AudioOutput?
+
+    // MARK: - Generation progress (generateFirst mode)
+
+    /// Running count of decoder steps completed across all chunks.
+    var stepsCompleted: Int = 0
+    /// Estimated total steps for the full request (maxNewTokens × totalChunks).
+    var totalSteps: Int = 0
+    /// Total number of text chunks in the current request.
+    var chunksTotal: Int = 0
+
+    /// Seconds of audio still accumulating in the pre-buffer before the next chunk
+    /// flushes and playback resumes. Non-zero only while actively buffering mid-stream.
+    var silentBufferRemaining: TimeInterval {
+        guard isStreaming, let audioOut = activeAudioOutput else { return 0 }
+        return audioOut.silentBufferRemaining
+    }
+
+    /// Total real audio scheduled to the player so far (excludes silent sentinel buffers).
+    var scheduledAudioDuration: TimeInterval {
+        guard isStreaming, let audioOut = activeAudioOutput else { return 0 }
+        return audioOut.scheduledAudioDuration
+    }
+
+    // NOTE(review): the generic arguments appear stripped in this copy — these are
+    // presumably `Task<Void, Never>?`; confirm against the original file.
+    private var playbackUpdateTask: Task?
+    private var generationTask: Task?
+
+    // MARK: - History
+
+    /// A sentinel that drives the detail view into "new generation" mode without
+    /// requiring a separate navigation Bool. Setting `selectedGenerationID` to this
+    /// value shows the detail with empty inputs while no real generation is selected.
+    static let newGenerationSentinel = UUID(uuidString: "00000000-0000-0000-0000-000000000000")!
+
+    var generations: [Generation] = []
+    var selectedGenerationID: UUID?
+
+    // MARK: - Search
+
+    var searchText = ""
+
+    // MARK: - Sheet presentation
+
+    var showGenerationSettings = false
+
+    // MARK: - Private
+
+    // Loaded TTS pipeline; nil while no model is in memory.
+    private var tts: TTSKit?
+    private(set) var loadedPreset: TTSModelVariant?
+    // Player used for replaying saved clips (not live streaming).
+    private var audioPlayer: AVAudioPlayer?
+    // NOTE(review): generic arguments appear stripped here as well — confirm.
+    private var tokenCountTask: Task?
+
+    // MARK: - Computed
+
+    /// The generation currently selected in the history list, if any.
+    var selectedGeneration: Generation? {
+        guard let id = selectedGenerationID else { return nil }
+        return generations.first { $0.id == id }
+    }
+
+    /// History filtered by the search field (case-insensitive, matches title,
+    /// full text, or speaker). Returns everything when the query is empty.
+    var filteredGenerations: [Generation] {
+        if searchText.isEmpty { return generations }
+        let query = searchText.lowercased()
+        return generations.filter {
+            $0.title.lowercased().contains(query)
+                || $0.text.lowercased().contains(query)
+                || $0.speaker.lowercased().contains(query)
+        }
+    }
+
+    var favoriteGenerations: [Generation] {
+        filteredGenerations.filter { $0.isFavorite }
+    }
+
+    /// The most recent `historyLimit` filtered generations.
+    var recentGenerations: [Generation] {
+        Array(filteredGenerations.prefix(Self.historyLimit))
+    }
+
+    /// Enabled only with non-blank input, no generation running, and no model transition.
+    var canGenerate: Bool {
+        !inputText.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty
+            && generationState == .idle
+            && !modelState.isBusy
+    }
+
+    var isModelDownloaded: Bool {
+        localModelPaths[selectedPreset] != nil
+    }
+
+    var inputText = "" {
+        didSet {
+            // Re-count tokens (debounced) only when the text actually changed.
+            guard inputText != oldValue else { return }
+            scheduleTokenCount()
+        }
+    }
+
+    /// Token count of `inputText` as measured by the loaded tokenizer.
+    /// Updated with a short debounce as the user types; falls back to a character
+    /// approximation in the UI when the tokenizer is not yet loaded.
+    var inputTokenCount: Int?
+
+ // MARK: - Lifecycle
+
+    /// Entry point called from the root view's `onAppear`: restores saved history
+    /// asynchronously and synchronously detects already-downloaded models.
+    func onAppear() {
+        Task { await self.loadGenerations() }
+        scanLocalModels()
+    }
+
+ // MARK: - Model Management
+
+    /// Hub storage path relative to Documents, matching the pattern used by WhisperAX / HubApi.
+    /// The Hub library stores downloads at: Documents/huggingface/models/{repo}/
+    private static let modelStorageBase = "huggingface/models"
+
+    /// HuggingFace model repo for the selected preset
+    var modelRepoURL: URL? {
+        let config = TTSKitConfig(model: selectedPreset)
+        return URL(string: "https://huggingface.co/\(config.modelRepo)")
+    }
+
+    /// Local path to the Hub-managed repo folder for a given config
+    // `documentsDirectory` is defined elsewhere in this file/module.
+    private func localRepoURL(for config: TTSKitConfig) -> URL {
+        documentsDirectory
+            .appendingPathComponent(Self.modelStorageBase)
+            .appendingPathComponent(config.modelRepo)
+    }
+
+    /// Scan for previously downloaded models by checking the Hub cache in Documents.
+    func scanLocalModels() {
+        let fileManager = FileManager.default
+
+        for variant in TTSModelVariant.allCases {
+            let config = TTSKitConfig(model: variant)
+            let repoPath = localRepoURL(for: config).path
+
+            // Record the variant only when the repo folder exists AND at least one
+            // model component directory is present inside it.
+            if fileManager.fileExists(atPath: repoPath),
+               hasModelFiles(at: repoPath, config: config) {
+                localModelPaths[variant] = repoPath
+            }
+        }
+
+        if isModelDownloaded, modelState == .unloaded {
+            statusMessage = "Model downloaded"
+        }
+    }
+
+    /// Returns true when at least one of the config's component directories
+    /// exists under `basePath`.
+    private func hasModelFiles(at basePath: String, config: TTSKitConfig) -> Bool {
+        let root = URL(fileURLWithPath: basePath)
+        for componentDir in config.componentDirectories(in: root)
+            where FileManager.default.fileExists(atPath: componentDir.path) {
+            return true
+        }
+        return false
+    }
+
+    /// Download the selected model preset without loading it
+    // Progress callbacks arrive off the main actor and are hopped back via Task.
+    func downloadModel() async {
+        guard !modelState.isBusy else { return }
+        modelState = .downloading
+        downloadProgress = 0
+        statusMessage = "Downloading \(selectedPreset.rawValue) model..."
+
+        do {
+            // Set a HuggingFace token via TTSKitConfig(token:) if the model repo is private.
+            // For the public argmaxinc/ttskit-coreml repo no token is required.
+            let config = TTSKitConfig(
+                model: selectedPreset,
+                verbose: true
+            )
+            let folder = try await TTSKit.download(config: config) { [weak self] progress in
+                Task { @MainActor in
+                    self?.downloadProgress = progress.fractionCompleted
+                }
+            }
+            localModelPaths[selectedPreset] = folder.path
+            // Back to .unloaded on purpose: a download does not imply a load.
+            modelState = .unloaded
+            downloadProgress = 1.0
+            statusMessage = "Downloaded"
+        } catch {
+            modelState = .error(error.localizedDescription)
+            statusMessage = "Download failed: \(error.localizedDescription)"
+        }
+    }
+
+    /// Load the selected model preset (downloads first if needed)
+    // State machine: (download if missing) -> prewarm -> load -> .loaded,
+    // with .error on any fatal failure. Prewarm failures are deliberately non-fatal.
+    func loadModel() async {
+        guard !modelState.isBusy else { return }
+
+        // If switching presets while loaded, unload first
+        if loadedPreset != nil, loadedPreset != selectedPreset {
+            unloadModel()
+        }
+
+        // Download if needed
+        if localModelPaths[selectedPreset] == nil {
+            await downloadModel()
+            // downloadModel() surfaces its own error state; just bail here.
+            guard localModelPaths[selectedPreset] != nil else { return }
+        }
+
+        do {
+            let ttsConfig = TTSKitConfig(
+                model: selectedPreset,
+                modelFolder: localModelPaths[selectedPreset].map { URL(fileURLWithPath: $0) },
+                computeOptions: computeOptions,
+                verbose: true
+            )
+
+            // Create TTSKit without loading - we drive load/prewarm explicitly
+            // NOTE(review): `ttsConfig` is mutated after `let`, so TTSKitConfig is
+            // presumably a reference type — confirm in TTSKit.
+            ttsConfig.load = false
+            let kit = try await TTSKit(ttsConfig)
+            tts = kit
+
+            // Prewarm: compile each CoreML model sequentially, then discard.
+            // This prevents memory exhaustion from concurrent compilation on first launch.
+            // On subsequent launches the compiled cache is already on disk, so this is fast.
+            modelState = .prewarming
+            statusMessage = "Specializing \(selectedPreset.rawValue) for your device\nThis may take a few minutes on first load"
+            do {
+                try await kit.prewarmModels()
+            } catch {
+                // Prewarm failures are non-fatal (model may already be specialized).
+                // Surface to the user so they know something unexpected happened,
+                // but continue - the full load often succeeds regardless.
+                let msg = error.localizedDescription
+                print("Prewarm warning: \(msg) - continuing to load")
+                statusMessage = "Prewarm warning: \(msg)\nContinuing..."
+            }
+
+            // Load: compiled artifacts are cached, concurrent load is safe
+            modelState = .loading
+            statusMessage = "Loading \(selectedPreset.rawValue) model..."
+            try await kit.loadModels()
+
+            loadedPreset = selectedPreset
+            modelState = .loaded
+            statusMessage = "Ready - \(selectedPreset.rawValue) loaded"
+            // Announce readiness for VoiceOver users.
+            AccessibilityNotification.Announcement("\(selectedPreset.rawValue) model loaded and ready").post()
+            // Refresh token count now that the tokenizer is available.
+            scheduleTokenCount()
+        } catch {
+            modelState = .error(error.localizedDescription)
+            statusMessage = "Load failed: \(error.localizedDescription)"
+            AccessibilityNotification.Announcement("Model load failed: \(error.localizedDescription)").post()
+        }
+    }
+
+    /// Tear down and reload the currently loaded model so freshly changed
+    /// compute-unit options take effect. No-op unless a model is fully loaded.
+    func reloadModelForComputeUnitChange() {
+        guard modelState == .loaded else { return }
+        guard let activePreset = loadedPreset else { return }
+        unloadModel()
+        selectedPreset = activePreset
+        Task { [weak self] in
+            await self?.loadModel()
+        }
+    }
+
+    /// Unload the current model from memory
+    // Cancels any running work first, then releases the kit asynchronously; the
+    // local strong reference keeps it alive until unloadModels() finishes.
+    func unloadModel() {
+        cancelAllTasks()
+        let oldTTS = tts
+        tts = nil
+        Task { await oldTTS?.unloadModels() }
+        loadedPreset = nil
+        modelState = .unloaded
+        statusMessage = isModelDownloaded ? "Model unloaded" : "Select a model to get started"
+    }
+
+    /// Reset all generation options to their factory defaults.
+    /// Each assignment also persists via the property's didSet observer.
+    func resetGenerationSettings() {
+        chunkingStrategyTag = "sentence"
+        concurrentWorkerCount = 0
+        temperature = Double(GenerationOptions.defaultTemperature)
+        topK = Double(GenerationOptions.defaultTopK)
+        repetitionPenalty = Double(GenerationOptions.defaultRepetitionPenalty)
+        maxNewTokens = Double(GenerationOptions.defaultMaxNewTokens)
+        targetChunkSize = Double(TextChunker.defaultTargetChunkSize)
+        minChunkSize = Double(TextChunker.defaultMinChunkSize)
+    }
+
+    /// Delete downloaded model files for a specific variant only,
+    /// leaving other variants in the shared repo untouched.
+    func deleteModel(preset: TTSModelVariant? = nil) {
+        let victim = preset ?? selectedPreset
+
+        // Unload first when the variant being removed is the one in memory.
+        if loadedPreset == victim { unloadModel() }
+
+        let config = TTSKitConfig(model: victim)
+        let repoURL = localRepoURL(for: config)
+        let fileManager = FileManager.default
+        config.componentDirectories(in: repoURL)
+            .filter { fileManager.fileExists(atPath: $0.path) }
+            .forEach { try? fileManager.removeItem(at: $0) }
+        localModelPaths.removeValue(forKey: victim)
+
+        if victim == selectedPreset {
+            modelState = .unloaded
+            statusMessage = "Model deleted"
+        }
+    }
+
+    /// Disk size of downloaded model files for a specific variant only.
+    /// Returns a formatted string (e.g. "1.2 GB") or nil when nothing is on disk.
+    func modelDiskSize(for preset: TTSModelVariant) -> String? {
+        guard localModelPaths[preset] != nil else { return nil }
+        let config = TTSKitConfig(model: preset)
+        let repoURL = localRepoURL(for: config)
+        let totalBytes = config.componentDirectories(in: repoURL)
+            .compactMap { directorySize(at: $0) }
+            .reduce(UInt64(0), +)
+        guard totalBytes > 0 else { return nil }
+        return ByteCountFormatter.string(fromByteCount: Int64(totalBytes), countStyle: .file)
+    }
+
+    /// Recursively sums the byte sizes of all files under `url`.
+    /// Returns nil when the directory cannot be enumerated at all.
+    private func directorySize(at url: URL) -> UInt64? {
+        guard let walker = FileManager.default.enumerator(
+            at: url,
+            includingPropertiesForKeys: [.fileSizeKey]
+        ) else { return nil }
+
+        var bytes: UInt64 = 0
+        while let entry = walker.nextObject() as? URL {
+            // Directories report no fileSize; treat them as zero.
+            let size = (try? entry.resourceValues(forKeys: [.fileSizeKey]))?.fileSize ?? 0
+            bytes += UInt64(size)
+        }
+        return bytes
+    }
+
+ // MARK: - Generation
+
+    /// Start generation in a tracked task that can be cancelled.
+    /// If no model is loaded yet, automatically loads the selected (or default) model first.
+    func startGeneration() {
+        // Supersede any generation already in flight.
+        generationTask?.cancel()
+        // Don't touch selectedGenerationID here. On iPhone, NavigationSplitView collapses
+        // to a NavigationStack driven by List(selection:) - setting the selection to nil
+        // immediately pops the detail view. The selection updates naturally when the new
+        // generation is saved at the end of generate().
+        generationTask = Task { [weak self] in
+            guard let self else { return }
+            // Lazy-load the model on first use; abort if the load failed or we were cancelled.
+            if modelState != .loaded {
+                await loadModel()
+                guard !Task.isCancelled, modelState == .loaded else { return }
+            }
+            guard !Task.isCancelled else { return }
+            await generate()
+        }
+    }
+
+    /// Cancel all background tasks (generation, playback updates, token counting).
+    /// Safe to call from any lifecycle event (view disappear, scene deactivation, etc.).
+    func cancelAllTasks() {
+        generationTask?.cancel()
+        generationTask = nil
+        tokenCountTask?.cancel()
+        tokenCountTask = nil
+        stopPlaybackUpdates()
+        // Detach the streaming output first, then stop it asynchronously; the local
+        // strong reference keeps it alive until stopPlayback completes.
+        if activeAudioOutput != nil {
+            let audioOut = activeAudioOutput
+            activeAudioOutput = nil
+            Task { await audioOut?.stopPlayback(waitForCompletion: false) }
+        }
+        audioPlayer?.stop()
+        audioPlayer = nil
+        isPlaying = false
+        isStreaming = false
+        // Only rewrite the status when we actually interrupted a generation.
+        if generationState == .generating {
+            generationState = .idle
+            statusMessage = "Cancelled"
+        }
+    }
+
+    /// Cancel any in-progress generation and stop audio immediately.
+    // Resets state unconditionally (unlike cancelAllTasks), so the status always
+    // reads "Cancelled" after an explicit user cancel.
+    func cancelGeneration() {
+        cancelAllTasks()
+        generationState = .idle
+        statusMessage = "Cancelled"
+    }
+
+    /// Run one generation request end-to-end: reset per-run state, generate
+    /// (streaming or generate-first), persist the result, and restore idle state.
+    ///
+    /// Fix: the error path previously left `generationState == .generating`, so the
+    /// trailing cancellation check reset it and posted a misleading "Generation
+    /// cancelled" accessibility announcement even when the run failed. Errors now
+    /// reset state themselves and announce the failure; the trailing check only
+    /// handles cooperative cancellation that did not throw.
+    private func generate() async {
+        guard canGenerate, let tts else { return }
+
+        // Reset per-run outputs and progress counters.
+        generationState = .generating
+        statusMessage = "Generating speech..."
+        currentWaveform = []
+        currentAudioSamples = []
+        currentDuration = 0
+        playbackTime = 0
+        stepsCompleted = 0
+        totalSteps = 0
+        chunksTotal = 0
+
+        let strategy = selectedPlaybackStrategy
+
+        switch strategy {
+        case .generateFirst:
+            break
+        case .auto, .stream, .buffered:
+            // Streaming modes play as they generate; expose the live audio output
+            // so the UI can poll playback position and buffer state.
+            isStreaming = true
+            activeAudioOutput = tts.audioOutput
+            startPlaybackUpdates()
+        }
+
+        do {
+            let result: SpeechResult
+
+            switch strategy {
+            case .generateFirst:
+                result = try await generateFirstGeneration(tts: tts)
+                currentAudioSamples = result.audio
+                currentDuration = result.audioDuration
+                currentWaveform = peaksPerToken(from: result.audio)
+                try await finalizeGeneration(result: result)
+                // Auto-play the freshly saved clip (finalize inserts it at index 0).
+                if let gen = generations.first {
+                    playGeneration(gen)
+                }
+            case .auto, .stream, .buffered:
+                result = try await streamGeneration(tts: tts)
+                stopPlaybackUpdates()
+                activeAudioOutput = nil
+                isStreaming = false
+                try await finalizeGeneration(result: result)
+            }
+        } catch {
+            stopPlaybackUpdates()
+            activeAudioOutput = nil
+            isStreaming = false
+            stepsCompleted = 0
+            totalSteps = 0
+            chunksTotal = 0
+            if error is CancellationError {
+                statusMessage = "Cancelled"
+            } else {
+                statusMessage = "Error: \(error.localizedDescription)"
+                // Announce the actual failure; previously the trailing check fired
+                // here and incorrectly announced "Generation cancelled".
+                AccessibilityNotification.Announcement(
+                    "Generation failed: \(error.localizedDescription)"
+                ).post()
+            }
+            generationState = .idle
+        }
+
+        // Cooperative cancellation that did not throw: finalize state and announce.
+        if generationState == .generating {
+            generationState = .idle
+            AccessibilityNotification.Announcement("Generation cancelled").post()
+        }
+    }
+
+    /// Build the generation options from current UI state.
+    /// Slider values are stored as Double and narrowed here to the types TTSKit expects.
+    private func buildOptions() -> GenerationOptions {
+        var options = GenerationOptions(
+            temperature: Float(temperature),
+            topK: Int(topK),
+            repetitionPenalty: Float(repetitionPenalty),
+            maxNewTokens: Int(maxNewTokens),
+            concurrentWorkerCount: Int(concurrentWorkerCount),
+            chunkingStrategy: chunkingStrategy,
+            targetChunkSize: Int(targetChunkSize),
+            minChunkSize: Int(minChunkSize)
+        )
+        // Leave the default instruction untouched unless the user typed one.
+        if !instruction.isEmpty {
+            options.instruction = instruction
+        }
+        return options
+    }
+
+    /// Generate all audio up front using `tts.generate()`, tracking step-level progress.
+    // The progress callback may arrive off the main actor; values are hopped back
+    // via a @MainActor Task before touching observable state.
+    private func generateFirstGeneration(tts: TTSKit) async throws -> SpeechResult {
+        let options = buildOptions()
+
+        let result = try await tts.generate(
+            text: inputText,
+            speaker: selectedSpeaker,
+            language: selectedLanguage,
+            options: options,
+            callback: { [weak self] progress in
+                let steps = progress.stepsCompleted ?? 0
+                let maxSteps = progress.totalSteps ?? 1
+                let chunks = progress.totalChunks ?? 1
+                Task { @MainActor [weak self] in
+                    guard let self else { return }
+                    stepsCompleted = steps
+                    totalSteps = maxSteps
+                    chunksTotal = chunks
+                }
+                // NOTE(review): returning nil presumably means "continue generating" —
+                // confirm against TTSKit's callback contract.
+                return nil
+            }
+        )
+
+        // Snap the progress bar to 100% and publish final timing metrics.
+        stepsCompleted = totalSteps
+        currentRTF = result.timings.realTimeFactor
+        currentSpeedFactor = result.timings.speedFactor
+        currentStepsPerSecond = result.timings.tokensPerSecond
+        currentTimeToFirstBuffer = result.timings.timeToFirstBuffer
+        return result
+    }
+
+    /// Run `play`, streaming waveform peaks and audio samples back to the main actor.
+    // One peak (max |sample|) is recorded per callback chunk to build the live waveform.
+    private func streamGeneration(tts: TTSKit) async throws -> SpeechResult {
+        let options = buildOptions()
+        let sampleRate = Double(Qwen3TTSConstants.sampleRate)
+
+        let result = try await tts.play(
+            text: inputText,
+            speaker: selectedSpeaker,
+            language: selectedLanguage,
+            options: options,
+            playbackStrategy: selectedPlaybackStrategy,
+            callback: { [weak self] progress in
+                let samples = progress.audio
+                // Peak amplitude of this chunk, computed off the main actor.
+                let peak = samples.reduce(Float(0)) { max($0, abs($1)) }
+                Task { @MainActor [weak self] in
+                    guard let self else { return }
+                    currentAudioSamples.append(contentsOf: samples)
+                    currentDuration = Double(currentAudioSamples.count) / sampleRate
+                    currentWaveform.append(peak)
+                }
+                return nil
+            }
+        )
+
+        // Replace the incrementally-accumulated audio with the authoritative result.
+        currentAudioSamples = result.audio
+        currentDuration = result.audioDuration
+        currentRTF = result.timings.realTimeFactor
+        currentSpeedFactor = result.timings.speedFactor
+        currentStepsPerSecond = result.timings.tokensPerSecond
+        currentTimeToFirstBuffer = result.timings.timeToFirstBuffer
+        playbackTime = 0
+        return result
+    }
+
+    /// Persist the generation to disk as M4A, insert it into the history, and update UI state.
+    // All metadata is embedded in the M4A itself (see `Generation` doc); the file name
+    // comes from the metadata's suggestedFileName.
+    private func finalizeGeneration(result: SpeechResult) async throws {
+        let meta = AudioMetadata(
+            text: inputText,
+            speaker: selectedSpeaker.rawValue,
+            language: selectedLanguage.rawValue,
+            instruction: instruction,
+            modelName: "\(Qwen3TTSConstants.modelFamilyDir)_\((loadedPreset ?? selectedPreset).versionDir)",
+            realTimeFactor: result.timings.realTimeFactor,
+            speedFactor: result.timings.speedFactor,
+            stepsPerSecond: result.timings.tokensPerSecond,
+            timeToFirstBuffer: result.timings.timeToFirstBuffer
+        )
+        let savedURL = try await AudioOutput.saveAudio(
+            result.audio,
+            toFolder: documentsDirectory,
+            filename: meta.suggestedFileName,
+            sampleRate: result.sampleRate,
+            metadataProvider: meta.avMetadataItems
+        )
+
+        var gen = Generation(
+            metadata: meta,
+            audioFileName: savedURL.lastPathComponent,
+            audioDuration: result.audioDuration,
+            isFavorite: false
+        )
+        // Attach the live waveform so the history row can render a thumbnail immediately.
+        gen.waveformSamples = currentWaveform
+        // Newest first; selecting it drives navigation to the detail view.
+        generations.insert(gen, at: 0)
+        selectedGenerationID = gen.id
+
+        generationState = .idle
+        statusMessage = String(
+            format: "Done generating %.1fs of audio, RTF %.2f",
+            result.audioDuration,
+            result.timings.realTimeFactor
+        )
+        AccessibilityNotification.Announcement(
+            String(format: "Generation complete. %.1f seconds of audio.", result.audioDuration)
+        ).post()
+    }
+
+
+ // MARK: - Playback position updates
+
+    /// Starts a MainActor task that polls the audio engine's actual playback position at ~30fps.
+    /// Using a Task instead of Timer ensures it always runs on the main thread.
+    // The loop exits on its own when neither streaming nor replay is active,
+    // or when the weak self reference goes away.
+    private func startPlaybackUpdates() {
+        playbackUpdateTask?.cancel()
+        playbackUpdateTask = Task { @MainActor [weak self] in
+            while !Task.isCancelled {
+                // Re-bind strongly once per iteration; released across the sleep.
+                guard let self else { break }
+
+                if self.isStreaming, let audioOut = self.activeAudioOutput {
+                    self.playbackTime = audioOut.currentPlaybackTime
+                } else if self.isPlaying, let player = self.audioPlayer {
+                    if player.isPlaying {
+                        self.playbackTime = player.currentTime
+                    } else {
+                        // Replay finished
+                        self.isPlaying = false
+                        self.playbackTime = 0
+                        break
+                    }
+                } else {
+                    break
+                }
+
+                try? await Task.sleep(for: .milliseconds(Self.playbackPollIntervalMs))
+            }
+        }
+    }
+
+    /// Stops the ~30fps playback-position polling task.
+    private func stopPlaybackUpdates() {
+        playbackUpdateTask?.cancel()
+        playbackUpdateTask = nil
+    }
+
+ // MARK: - Playback (replay saved audio)
+
+ /// Populate the input fields from a past generation so the user can edit and re-generate.
+ /// Called automatically whenever a generation is selected from history.
+ func loadInputs(from generation: Generation) {
+ inputText = generation.text
+ // Fall back to defaults if the stored raw values no longer map to a known case.
+ selectedSpeaker = Qwen3Speaker(rawValue: generation.speaker) ?? .ryan
+ selectedLanguage = Qwen3Language(rawValue: generation.language) ?? .english
+ instruction = generation.instruction
+ // Surface the stored performance stats for this generation in the UI.
+ currentRTF = generation.realTimeFactor
+ currentSpeedFactor = generation.speedFactor
+ currentStepsPerSecond = generation.stepsPerSecond
+ currentTimeToFirstBuffer = generation.timeToFirstBuffer
+ }
+
/// Replays a previously saved generation's audio file from the Documents directory.
/// Silently returns if the file no longer exists on disk.
func playGeneration(_ generation: Generation) {
    let url = documentsDirectory.appendingPathComponent(generation.audioFileName)
    guard FileManager.default.fileExists(atPath: url.path) else { return }

    do {
        #if os(iOS)
        // Use the playback category so audio plays regardless of the silent switch.
        let session = AVAudioSession.sharedInstance()
        try session.setCategory(.playback, mode: .default, options: [])
        try session.setActive(true)
        #endif

        // Stop any previous replay before swapping in the new player.
        audioPlayer?.stop()
        let player = try AVAudioPlayer(contentsOf: url)
        audioPlayer = player

        // Fix: `play()` returns false when playback cannot start; previously the
        // UI entered the "playing" state even though no audio was produced.
        guard player.play() else {
            isPlaying = false
            statusMessage = "Playback error: could not start playback"
            return
        }
        isPlaying = true
        playbackTime = 0

        loadWaveform(for: generation)
        startPlaybackUpdates()
    } catch {
        statusMessage = "Playback error: \(error.localizedDescription)"
    }
}
+
/// Stops any active replay, resets the playhead, and halts position polling.
func stopPlayback() {
    stopPlaybackUpdates()
    audioPlayer?.stop()
    isPlaying = false
    playbackTime = 0
}
+
+ /// URL for the audio file of a generation (for sharing/exporting)
/// URL for the audio file of a generation (for sharing/exporting).
/// Returns `nil` when the file is no longer present on disk.
func audioFileURL(for generation: Generation) -> URL? {
    let candidate = documentsDirectory.appendingPathComponent(generation.audioFileName)
    guard FileManager.default.fileExists(atPath: candidate.path) else { return nil }
    return candidate
}
+
+ // MARK: - History Management
+
+ /// Flips the favorite flag of the generation with `id` and persists the new favorites set.
+ /// No-op when no generation with that id exists.
+ func toggleFavorite(_ id: UUID) {
+ guard let idx = generations.firstIndex(where: { $0.id == id }) else { return }
+ generations[idx].isFavorite.toggle()
+ saveFavorites()
+ }
+
/// Removes the generation with `id`: deletes its audio file from disk, drops it
/// from the in-memory list, fixes up the current selection, and re-persists the
/// favorites set. No-op when no generation with that id exists.
func deleteGeneration(_ id: UUID) {
    guard let idx = generations.firstIndex(where: { $0.id == id }) else { return }
    let url = documentsDirectory.appendingPathComponent(generations[idx].audioFileName)
    // Best-effort removal: the file may already be gone; the history entry is dropped regardless.
    try? FileManager.default.removeItem(at: url)
    generations.remove(at: idx)
    if selectedGenerationID == id {
        selectedGenerationID = generations.first?.id
    }
    // Fix: previously a deleted favorite's UUID remained in the persisted
    // favorites (UserDefaults) forever, accumulating stale entries.
    saveFavorites()
}
+
/// Deletes every generation's audio file from disk, empties the history,
/// clears the selection, and wipes the persisted favorites.
func clearAllGenerations() {
    let fileManager = FileManager.default
    for generation in generations {
        let fileURL = documentsDirectory.appendingPathComponent(generation.audioFileName)
        try? fileManager.removeItem(at: fileURL)
    }
    generations = []
    selectedGenerationID = nil
    UserDefaults.standard.removeObject(forKey: Self.favoritesKey)
}
+
+ /// Debounced token counter: waits 250ms after the last keystroke before encoding.
+ /// Uses the loaded tokenizer when available; silently skips if none is loaded yet.
+ /// When the model loads for the first time, the next edit will trigger a real count.
+ private func scheduleTokenCount() {
+ // Cancel the in-flight count; only the latest text matters.
+ tokenCountTask?.cancel()
+ // Snapshot the text now so the delayed task counts what the user last typed.
+ let text = inputText
+ guard !text.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty else {
+ inputTokenCount = nil
+ return
+ }
+ tokenCountTask = Task { [weak self] in
+ try? await Task.sleep(for: .milliseconds(Self.tokenCountDebounceMs))
+ // Bail if a newer keystroke superseded this task during the debounce window.
+ guard !Task.isCancelled, let self else { return }
+ // Silently skip when no tokenizer is loaded yet (see method doc above).
+ guard let count = self.tts?.tokenizer?.encode(text: text).count else { return }
+ await MainActor.run { self.inputTokenCount = count }
+ }
+ }
+
/// Resets all input fields and transient per-generation state back to empty.
func clearInput() {
    inputText = ""
    instruction = ""
    inputTokenCount = nil
    currentWaveform = []
    currentAudioSamples = []
    currentDuration = 0
    // Keep the status line consistent with whether a model is loaded.
    if modelState == .loaded {
        statusMessage = "Ready"
    } else {
        statusMessage = "Select a model to get started"
    }
}
+
+ // MARK: - Persistence
+
/// The user's Documents directory, where generated audio files are stored.
var documentsDirectory: URL {
    let urls = FileManager.default.urls(for: .documentDirectory, in: .userDomainMask)
    // Always present in the user domain on Apple platforms; index 0 is the standard location.
    return urls[0]
}
+
+ // MARK: Generations - derived entirely from .m4a files on disk
+
+ /// Scan the Documents directory for `.m4a` files, read embedded metadata from
+ /// each one, and rebuild the in-memory `generations` array. No JSON sidecar needed.
+ func loadGenerations() async {
+ let fm = FileManager.default
+ guard let files = try? fm.contentsOfDirectory(
+ at: documentsDirectory,
+ includingPropertiesForKeys: nil
+ ) else { return }
+
+ let m4aFiles = files
+ .filter { $0.pathExtension == "m4a" }
+ .sorted { $0.lastPathComponent > $1.lastPathComponent } // newest first by name
+ // NOTE(review): "newest first" relies on filenames sorting lexicographically
+ // by date — confirm the naming scheme (meta.suggestedFileName) guarantees this.
+
+ let favorites = loadedFavoriteIDs()
+ var loaded: [Generation] = []
+
+ for url in m4aFiles {
+ do {
+ // Files without embedded TTSKit metadata are not ours; skip them.
+ guard let meta = try await AudioMetadata.load(from: url) else {
+ print("SpeakAX: skipping \(url.lastPathComponent) - no TTSKit metadata found")
+ continue
+ }
+ // Duration read is best-effort; fall back to 0 rather than dropping the entry.
+ let dur = (try? await AudioOutput.duration(of: url)) ?? 0
+ var gen = Generation(
+ metadata: meta,
+ audioFileName: url.lastPathComponent,
+ audioDuration: dur,
+ isFavorite: favorites.contains(meta.id)
+ )
+ gen.waveformSamples = waveformPeaks(from: url)
+ loaded.append(gen)
+ } catch {
+ // A single unreadable file should not abort the whole scan.
+ print("SpeakAX: failed to load \(url.lastPathComponent) - \(error.localizedDescription)")
+ }
+ }
+
+ generations = loaded
+ }
+
+ // MARK: Favorites - stored in UserDefaults (only mutable state not in the file)
+
+ private static let favoritesKey = "ttskit_favoriteGenerationIDs"
+
/// Reads the persisted favorite generation IDs from UserDefaults.
/// Unparseable UUID strings are silently dropped.
/// Note: the return type's generic argument (`Set<UUID>`) was lost in the diff
/// extraction and is reconstructed from usage (`favorites.contains(meta.id)`).
private func loadedFavoriteIDs() -> Set<UUID> {
    let strings = UserDefaults.standard.stringArray(forKey: Self.favoritesKey) ?? []
    return Set(strings.compactMap { UUID(uuidString: $0) })
}
+
/// Persists the IDs of all currently-favorited generations to UserDefaults.
private func saveFavorites() {
    let favoriteIDs = generations
        .filter(\.isFavorite)
        .map { $0.id.uuidString }
    UserDefaults.standard.set(favoriteIDs, forKey: Self.favoritesKey)
}
+
+ // MARK: - Waveform
+
+ /// Read an audio file from disk and return waveform peaks at token density.
+ /// Returns `nil` if the file can't be read (missing, corrupt, etc.).
/// Read an audio file from disk and return waveform peaks at token density.
/// Returns `nil` if the file can't be read (missing, corrupt, etc.).
func waveformPeaks(from url: URL) -> [Float]? {
    guard FileManager.default.fileExists(atPath: url.path) else { return nil }
    guard let file = try? AVAudioFile(forReading: url) else { return nil }

    let capacity = AVAudioFrameCount(file.length)
    guard let buffer = AVAudioPCMBuffer(pcmFormat: file.processingFormat, frameCapacity: capacity) else {
        return nil
    }
    do {
        try file.read(into: buffer)
    } catch {
        return nil
    }
    guard let channels = buffer.floatChannelData else { return nil }

    // Use the first channel only; generated audio is mono.
    let mono = Array(UnsafeBufferPointer(start: channels[0], count: Int(buffer.frameLength)))
    return peaksPerToken(from: mono)
}
+
+ /// Resample raw audio into 1 peak per token (~80ms).
+ /// Matches the fixed bar width used by WaveformView.
+ func peaksPerToken(from audioSamples: [Float]) -> [Float] {
+ let samplesPerBar = Int(WaveformView.secondsPerBar * Double(Qwen3TTSConstants.sampleRate))
+ guard samplesPerBar > 0, !audioSamples.isEmpty else { return [] }
+ let barCount = (audioSamples.count + samplesPerBar - 1) / samplesPerBar
+ return (0.. 0 else { return }
+
+ let centerX = size.width / 2
+ let midY = size.height / 2
+ let maxBarHeight = size.height * 0.85
+
+ // How many pixels displayedTime shifts the waveform left
+ let playbackOffsetPx = CGFloat(displayedTime / Self.secondsPerBar) * barStep
+
+ for i in 0.. 0 && x < size.width else { continue }
+
+ let amplitude = CGFloat(min(abs(samples[i]), 1.0))
+ let barHeight = max(1, amplitude * maxBarHeight)
+ let rect = CGRect(
+ x: x,
+ y: midY - barHeight / 2,
+ width: barWidth,
+ height: barHeight
+ )
+
+ let barTime = Double(i) * Self.secondsPerBar
+ let isPlayed = barTime < displayedTime
+ let color = isPlayed ? accentColor : accentColor.opacity(0.3)
+ context.fill(Path(roundedRect: rect, cornerRadius: 1), with: .color(color))
+ }
+
+ // Fixed playhead line at center
+ let line = Path { p in
+ p.move(to: CGPoint(x: centerX, y: 0))
+ p.addLine(to: CGPoint(x: centerX, y: size.height))
+ }
+ context.stroke(line, with: .color(Color.accentColor), lineWidth: 1.5)
+
+ // Playhead dot
+ let dot = Path(ellipseIn: CGRect(x: centerX - 4, y: -1, width: 8, height: 8))
+ context.fill(dot, with: .color(Color.accentColor))
+ }
+ .onChange(of: playbackTime) { _, newTime in
+ if newTime == 0 {
+ // Explicit reset - new session starting
+ displayedTime = 0
+ } else {
+ // Clamp to the last generated bar: playhead can't outrun visible audio
+ let maxTime = Double(samples.count) * Self.secondsPerBar
+ displayedTime = max(displayedTime, min(newTime, maxTime))
+ }
+ }
+ .onChange(of: samples.count) { _, count in
+ // Waveform cleared for a new generation
+ if count == 0 { displayedTime = 0 }
+ }
+ .accessibilityLabel("Audio waveform")
+ .accessibilityValue(
+ playbackTime > 0
+ ? "Playback at \(Int(playbackTime)) seconds of \(Int(totalDuration))"
+ : samples.isEmpty ? "No audio" : "\(Int(totalDuration)) seconds"
+ )
+ }
+}
+
+// MARK: - Thumbnail (sidebar rows)
+
/// Miniature static waveform used in sidebar history rows.
/// Renders one bar per sample, scaled to a fixed 48×24pt frame, and is
/// hidden from accessibility (the row itself carries the description).
struct WaveformThumbnail: View {
    var samples: [Float]
    var color: Color = .secondary

    var body: some View {
        Canvas { ctx, size in
            guard !samples.isEmpty else { return }

            let barWidth = size.width / CGFloat(samples.count)
            let centerY = size.height / 2

            for (index, value) in samples.enumerated() {
                // Clamp amplitude to [0, 1] and enforce a minimum visible bar height.
                let amplitude = CGFloat(min(abs(value), 1.0))
                let barHeight = max(0.5, amplitude * size.height * 0.85)
                let bar = CGRect(
                    x: CGFloat(index) * barWidth,
                    y: centerY - barHeight / 2,
                    width: max(0.5, barWidth - 0.3),
                    height: barHeight
                )
                ctx.fill(Path(roundedRect: bar, cornerRadius: 0.3), with: .color(color))
            }
        }
        .frame(width: 48, height: 24)
        .accessibilityHidden(true)
    }
}
+
+// MARK: - Previews
+
+#Preview("Pre-recorded - start") {
+ // 200 random peaks ≈ a short clip at token density; playhead at the start.
+ let samples: [Float] = (0..<200).map { _ in Float.random(in: 0...1) }
+ WaveformView(samples: samples, playbackTime: 0, totalDuration: 10.0)
+ .frame(width: 600, height: 120).padding()
+}
+
+#Preview("Pre-recorded - mid playback") {
+ // Playhead halfway through the clip: roughly half the bars render as "played".
+ let samples: [Float] = (0..<200).map { _ in Float.random(in: 0...1) }
+ WaveformView(samples: samples, playbackTime: 5.0, totalDuration: 10.0)
+ .frame(width: 600, height: 120).padding()
+}
+
+#Preview("Pre-recorded - near end") {
+ // Playhead near the end of the clip: nearly all bars render as "played".
+ let samples: [Float] = (0..<200).map { _ in Float.random(in: 0...1) }
+ WaveformView(samples: samples, playbackTime: 9.0, totalDuration: 10.0)
+ .frame(width: 600, height: 120).padding()
+}
diff --git a/Examples/WhisperAX/WhisperAX.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved b/Examples/WhisperAX/WhisperAX.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved
index 95f8598d..dc9ac9fe 100644
--- a/Examples/WhisperAX/WhisperAX.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved
+++ b/Examples/WhisperAX/WhisperAX.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved
@@ -1,5 +1,5 @@
{
- "originHash" : "831ad63194a5262b2549d58e383a520f9cbbc80b4a75660fbbcc56d65edfdab4",
+ "originHash" : "105884455b6729deaf95d3dd16df7029fd7929c924fcab55200562b52b64dbc2",
"pins" : [
{
"identity" : "swift-argument-parser",
@@ -10,13 +10,31 @@
"version" : "1.3.0"
}
},
+ {
+ "identity" : "swift-collections",
+ "kind" : "remoteSourceControl",
+ "location" : "https://github.com/apple/swift-collections.git",
+ "state" : {
+ "revision" : "7b847a3b7008b2dc2f47ca3110d8c782fb2e5c7e",
+ "version" : "1.3.0"
+ }
+ },
+ {
+ "identity" : "swift-jinja",
+ "kind" : "remoteSourceControl",
+ "location" : "https://github.com/huggingface/swift-jinja.git",
+ "state" : {
+ "revision" : "d81197f35f41445bc10e94600795e68c6f5e94b0",
+ "version" : "2.3.1"
+ }
+ },
{
"identity" : "swift-transformers",
"kind" : "remoteSourceControl",
"location" : "https://github.com/huggingface/swift-transformers.git",
"state" : {
- "revision" : "fc6543263e4caed9bf6107466d625cfae9357f08",
- "version" : "0.1.8"
+ "revision" : "573e5c9036c2f136b3a8a071da8e8907322403d0",
+ "version" : "1.1.6"
}
}
],
diff --git a/Makefile b/Makefile
index 5a8534a0..26e73ea6 100644
--- a/Makefile
+++ b/Makefile
@@ -8,6 +8,8 @@ PYTHON_COMMAND := python3
# Define model repository and directories
MODEL_REPO := argmaxinc/whisperkit-coreml
MODEL_REPO_DIR := ./Models/whisperkit-coreml
+TTS_MODEL_REPO := argmaxinc/ttskit-coreml
+TTS_MODEL_REPO_DIR := ./Models/ttskit-coreml
BASE_COMPILED_DIR := ./Models
GIT_HASH := $(shell git rev-parse --short HEAD)
@@ -31,20 +33,24 @@ setup:
@echo "Checking for fastlane"
@which fastlane > /dev/null || (echo "Installing fastlane..." && brew install fastlane)
@echo "fastlane is installed."
- @$(MAKE) generate-whisperax-xcconfig
+ @$(MAKE) generate-xcconfigs
@echo "Done 🚀"
-generate-whisperax-xcconfig:
- @echo "Updating DEVELOPMENT_TEAM in Examples/WhisperAX/Debug.xcconfig..."
+generate-xcconfigs:
@TEAM_ID=$$(defaults read com.apple.dt.Xcode IDEProvisioningTeams | plutil -convert json -r -o - -- - | jq -r 'to_entries[0].value | sort_by(.teamType == "Individual") | .[0].teamID' 2>/dev/null); \
if [ -z "$$TEAM_ID" ]; then \
echo "Error: No Development Team ID found. Please log into Xcode with your Apple ID and select a team."; \
else \
echo "DEVELOPMENT_TEAM=$$TEAM_ID" > Examples/WhisperAX/Debug.xcconfig; \
- echo "DEVELOPMENT_TEAM has been updated in Examples/WhisperAX/Debug.xcconfig with your Development Team ID: $$TEAM_ID"; \
+ echo "Updated Examples/WhisperAX/Debug.xcconfig with Development Team ID: $$TEAM_ID"; \
+ echo "DEVELOPMENT_TEAM=$$TEAM_ID" > Examples/TTS/SpeakAX/Debug.xcconfig; \
+ echo "Updated Examples/TTS/SpeakAX/Debug.xcconfig with Development Team ID: $$TEAM_ID"; \
fi
+generate-whisperax-xcconfig: generate-xcconfigs
+generate-speakax-xcconfig: generate-xcconfigs
+
setup-huggingface-cli:
@if huggingface-cli whoami; then \
@@ -74,6 +80,19 @@ setup-model-repo:
git clone https://huggingface.co/$(MODEL_REPO) $(MODEL_REPO_DIR); \
fi
+setup-tts-model-repo:
+ @echo "Setting up TTS repository..."
+ @mkdir -p $(BASE_COMPILED_DIR)
+ @if [ -d "$(TTS_MODEL_REPO_DIR)/.git" ]; then \
+ echo "Repository exists, resetting..."; \
+ export GIT_LFS_SKIP_SMUDGE=1; \
+ cd $(TTS_MODEL_REPO_DIR) && git fetch --all && git reset --hard origin/main && git clean -fdx; \
+ else \
+ echo "Repository not found, initializing..."; \
+ export GIT_LFS_SKIP_SMUDGE=1; \
+ git clone https://huggingface.co/$(TTS_MODEL_REPO) $(TTS_MODEL_REPO_DIR); \
+ fi
+
# Download all models
download-models: setup-model-repo
@@ -94,6 +113,24 @@ download-model:
@cd $(MODEL_REPO_DIR) && \
git lfs pull --include="openai_whisper-$(MODEL)/*"
+download-tts-models: setup-tts-model-repo
+ @echo "Downloading all TTS models..."
+ @cd $(TTS_MODEL_REPO_DIR) && \
+ git lfs pull --include="qwen3_tts/**"
+
+# Download a specific TTS model size
+# Usage: make download-tts-model MODEL=0.6b
+# make download-tts-model MODEL=1.7b
+download-tts-model: setup-tts-model-repo
+ @if [ -z "$(MODEL)" ]; then \
+ echo "Error: MODEL not set. Usage: make download-tts-model MODEL=0.6b"; \
+ echo "Available models: 0.6b, 1.7b"; \
+ exit 1; \
+ fi
+ @echo "Downloading TTS model $(MODEL)..."
+ @cd $(TTS_MODEL_REPO_DIR) && \
+ git lfs pull --include="qwen3_tts/*/12hz-$(MODEL)-customvoice/**"
+
build:
@echo "Building WhisperKit..."
@swift build -v
@@ -103,6 +140,7 @@ build-cli:
@echo "Building WhisperKit CLI..."
@swift build -c release --product whisperkit-cli
+
test:
@echo "Running tests..."
@swift test -v
diff --git a/Package.resolved b/Package.resolved
index fbd008f0..955602b1 100644
--- a/Package.resolved
+++ b/Package.resolved
@@ -32,8 +32,8 @@
"kind" : "remoteSourceControl",
"location" : "https://github.com/huggingface/swift-transformers.git",
"state" : {
- "revision" : "d363e83a77bafe144808a3d01556139fe67cd8bc",
- "version" : "1.1.2"
+ "revision" : "573e5c9036c2f136b3a8a071da8e8907322403d0",
+ "version" : "1.1.6"
}
}
],
diff --git a/Package.swift b/Package.swift
index 1b7f4ad8..be181b74 100644
--- a/Package.swift
+++ b/Package.swift
@@ -17,13 +17,17 @@ let package = Package(
name: "WhisperKit",
targets: ["WhisperKit"]
),
+ .library(
+ name: "TTSKit",
+ targets: ["TTSKit"]
+ ),
.executable(
name: "whisperkit-cli",
targets: ["WhisperKitCLI"]
)
],
dependencies: [
- .package(url: "https://github.com/huggingface/swift-transformers.git", .upToNextMinor(from: "1.1.2")),
+ .package(url: "https://github.com/huggingface/swift-transformers.git", .upToNextMinor(from: "1.1.6")),
.package(url: "https://github.com/apple/swift-argument-parser.git", from: "1.3.0"),
] + (isServerEnabled() ? [
.package(url: "https://github.com/vapor/vapor.git", from: "4.115.1"),
@@ -33,13 +37,26 @@ let package = Package(
] : []),
targets: [
+ .target(
+ name: "ArgmaxCore"
+ ),
.target(
name: "WhisperKit",
dependencies: [
+ "ArgmaxCore",
.product(name: "Hub", package: "swift-transformers"),
.product(name: "Tokenizers", package: "swift-transformers"),
]
),
+ .target(
+ name: "TTSKit",
+ dependencies: [
+ "ArgmaxCore",
+ .product(name: "Tokenizers", package: "swift-transformers"),
+ .product(name: "Hub", package: "swift-transformers"),
+ ],
+ swiftSettings: [.enableExperimentalFeature("StrictConcurrency")]
+ ),
.testTarget(
name: "WhisperKitTests",
dependencies: [
@@ -47,22 +64,28 @@ let package = Package(
.product(name: "Hub", package: "swift-transformers"),
.product(name: "Tokenizers", package: "swift-transformers"),
],
- path: "Tests",
+ exclude: ["UnitTestsPlan.xctestplan"],
resources: [
- .process("WhisperKitTests/Resources"),
+ .process("Resources"),
+ ]
+ ),
+ .testTarget(
+ name: "TTSKitTests",
+ dependencies: [
+ "TTSKit"
]
),
.executableTarget(
name: "WhisperKitCLI",
dependencies: [
"WhisperKit",
+ "TTSKit",
.product(name: "ArgumentParser", package: "swift-argument-parser"),
] + (isServerEnabled() ? [
.product(name: "Vapor", package: "vapor"),
.product(name: "OpenAPIRuntime", package: "swift-openapi-runtime"),
.product(name: "OpenAPIVapor", package: "swift-openapi-vapor"),
] : []),
- path: "Sources/WhisperKitCLI",
exclude: (isServerEnabled() ? [] : ["Server"]),
swiftSettings: (isServerEnabled() ? [.define("BUILD_SERVER_CLI")] : [])
)
diff --git a/README.md b/README.md
index 0d7dde71..d93ef8b2 100644
--- a/README.md
+++ b/README.md
@@ -39,7 +39,27 @@ WhisperKit is an [Argmax](https://www.takeargmax.com) framework for deploying st
- [Model Selection](#model-selection)
- [Generating Models](#generating-models)
- [Swift CLI](#swift-cli)
-- [WhisperKit Local Server](#whisperkit-local-server)
+ - [WhisperKit Local Server](#whisperkit-local-server)
+ - [Building the Server](#building-the-server)
+ - [Starting the Server](#starting-the-server)
+ - [API Endpoints](#api-endpoints)
+ - [Supported Parameters](#supported-parameters)
+ - [Client Examples](#client-examples)
+ - [Generating the API Specification](#generating-the-api-specification)
+ - [Client Generation](#client-generation)
+ - [API Limitations](#api-limitations)
+ - [Fully Supported Features](#fully-supported-features)
+- [TTSKit](#ttskit)
+ - [Quick Example](#quick-example-1)
+ - [Model Selection](#model-selection-1)
+ - [Custom Voices](#custom-voices)
+ - [Real-Time Streaming Playback](#real-time-streaming-playback)
+ - [Generation Options](#generation-options)
+ - [Style Instructions (1.7B only)](#style-instructions-17b-only)
+ - [Saving Audio](#saving-audio)
+ - [Progress Callbacks](#progress-callbacks)
+ - [Swift CLI](#swift-cli-1)
+ - [Demo App](#demo-app)
- [Contributing \& Roadmap](#contributing--roadmap)
- [License](#license)
- [Citation](#citation)
@@ -48,12 +68,12 @@ WhisperKit is an [Argmax](https://www.takeargmax.com) framework for deploying st
### Swift Package Manager
-WhisperKit can be integrated into your Swift project using the Swift Package Manager.
+WhisperKit and TTSKit are separate library products in the same Swift package. Add the package once and pick the products you need.
### Prerequisites
- macOS 14.0 or later.
-- Xcode 15.0 or later.
+- Xcode 16.0 or later.
### Xcode Steps
@@ -61,11 +81,11 @@ WhisperKit can be integrated into your Swift project using the Swift Package Man
2. Navigate to `File` > `Add Package Dependencies...`.
3. Enter the package repository URL: `https://github.com/argmaxinc/whisperkit`.
4. Choose the version range or specific version.
-5. Click `Finish` to add WhisperKit to your project.
+5. When prompted to choose library products, select **WhisperKit**, **TTSKit**, or both.
### Package.swift
-If you're using WhisperKit as part of a swift package, you can include it in your Package.swift dependencies as follows:
+If you're using WhisperKit or TTSKit as part of a swift package, you can include it in your Package.swift dependencies as follows:
```swift
dependencies: [
@@ -73,12 +93,15 @@ dependencies: [
],
```
-Then add `WhisperKit` as a dependency for your target:
+Then add the products you need as target dependencies:
```swift
.target(
name: "YourApp",
- dependencies: ["WhisperKit"]
+ dependencies: [
+ "WhisperKit", // speech-to-text
+ "TTSKit", // text-to-speech
+ ]
),
```
@@ -308,6 +331,154 @@ The local server fully supports these OpenAI API features:
- **Temperature control**: Sampling temperature for transcription randomness
- **Prompt text**: Text guidance for transcription style and context
+## TTSKit
+
+TTSKit is an on-device text-to-speech framework built on Core ML. It runs [Qwen3 TTS](https://github.com/QwenLM/Qwen3-TTS) models entirely on Apple silicon with real-time streaming playback, no server required.
+
+- macOS 15.0 or later.
+- iOS 18.0 or later.
+
+### Quick Example
+
+This example demonstrates how to generate speech from text:
+
+```swift
+import TTSKit
+
+Task {
+ let tts = try await TTSKit()
+ let result = try await tts.generate(text: "Hello from TTSKit!")
+ print("Generated \(result.audioDuration)s of audio at \(result.sampleRate)Hz")
+}
+```
+
+`TTSKit()` automatically downloads the default 0.6B model on first run, loads the tokenizer and six Core ML models concurrently, and is ready to generate.
+
+### Model Selection
+
+TTSKit ships two model sizes. You can select the model by passing a variant to `TTSKitConfig`:
+
+```swift
+// Fast, runs on all platforms (~1 GB download)
+let tts = try await TTSKit(TTSKitConfig(model: .qwen3TTS_0_6b))
+
+// Higher quality, macOS only (~2.2 GB download, supports style instructions)
+let tts = try await TTSKit(TTSKitConfig(model: .qwen3TTS_1_7b))
+```
+
+Models are hosted on [HuggingFace](https://huggingface.co/argmaxinc/ttskit-coreml) and cached locally after the first download.
+
+#### Custom Voices
+
+You can choose from 9 built-in voices and 10 languages:
+
+```swift
+let result = try await tts.generate(
+ text: "こんにちは世界",
+ speaker: .onoAnna,
+ language: .japanese
+)
+```
+
+**Voices:** `.ryan`, `.aiden`, `.onoAnna` (`"ono-anna"`), `.sohee`, `.eric`, `.dylan`, `.serena`, `.vivian`, `.uncleFu` (`"uncle-fu"`)
+
+**Languages:** `.english`, `.chinese`, `.japanese`, `.korean`, `.german`, `.french`, `.russian`, `.portuguese`, `.spanish`, `.italian`
+
+#### Real-Time Streaming Playback
+
+`play` streams audio to the device speakers frame-by-frame as it is generated:
+
+```swift
+try await tts.play(text: "This starts playing before generation finishes.")
+```
+
+You can control how much audio is buffered before playback begins. The default `.auto` strategy measures the first generation step and pre-buffers just enough to avoid underruns:
+
+```swift
+try await tts.play(
+ text: "Long passage...",
+ playbackStrategy: .auto
+)
+```
+
+Other strategies include `.stream` (immediate, no buffer), `.buffered(seconds:)` (fixed pre-buffer), and `.generateFirst` (generate all audio first, then play).
+
+### Generation Options
+
+You can customize sampling, chunking, and concurrency via `GenerationOptions`:
+
+```swift
+// Defaults recommended by Qwen
+var options = GenerationOptions()
+options.temperature = 0.9
+options.topK = 50
+options.repetitionPenalty = 1.05
+options.maxNewTokens = 245
+
+// Long text is automatically split at sentence boundaries
+options.chunkingStrategy = .sentence
+options.concurrentWorkerCount = nil // nil = run chunks concurrently using a sensible default for the device
+
+let result = try await tts.generate(text: longArticle, options: options)
+```
+
+#### Style Instructions (1.7B only)
+
+The 1.7B model accepts a natural-language style instruction that controls prosody:
+
+```swift
+var options = GenerationOptions()
+options.instruction = "Speak slowly and warmly, like a storyteller."
+
+let result = try await tts.generate(
+ text: "Once upon a time...",
+ speaker: .ryan,
+ options: options
+)
+```
+
+### Saving Audio
+
+Generated audio can be saved to WAV or M4A:
+
+```swift
+let result = try await tts.generate(text: "Save me!")
+let outputDir = FileManager.default.urls(for: .documentDirectory, in: .userDomainMask)[0]
+
+// Save as .wav or .m4a (AAC)
+try await AudioOutput.saveAudio(result.audio, toFolder: outputDir, filename: "output", format: .m4a)
+```
+
+### Progress Callbacks
+
+You can receive per-step audio during generation. Return `false` from the callback to cancel early:
+
+```swift
+let result = try await tts.generate(text: "Hello!") { progress in
+ print("Audio chunk: \(progress.audio.count) samples")
+ if let stepTime = progress.stepTime {
+ print("First step took \(stepTime)s")
+ }
+ return true // return false to cancel
+}
+```
+
+### Swift CLI
+
+The TTS command is available through the same `whisperkit-cli` tool. You can generate speech and optionally play it back in real time:
+
+```bash
+swift run whisperkit-cli tts --text "Hello from the command line" --play
+swift run whisperkit-cli tts --text "Save to file" --output-path output.wav
+swift run whisperkit-cli tts --text "日本語テスト" --speaker ono-anna --language japanese
+swift run whisperkit-cli tts --text-file article.txt --model 1.7b --instruction "Read cheerfully"
+swift run whisperkit-cli tts --help
+```
+
+### Demo App
+
+The [SpeakAX](Examples/TTS/SpeakAX/) example app showcases real-time streaming, model management, waveform visualization, and generation history on macOS and iOS. See the [SpeakAX README](Examples/TTS/SpeakAX/README.md) for build instructions.
+
## Contributing & Roadmap
Our goal is to make WhisperKit better and better over time and we'd love your help! Just search the code for "TODO" for a variety of features that are yet to be built. Please refer to our [contribution guidelines](CONTRIBUTING.md) for submitting issues, pull requests, and coding standards, where we also have a public roadmap of features we are looking forward to building in the future.
diff --git a/Sources/ArgmaxCore/ConcurrencyUtilities.swift b/Sources/ArgmaxCore/ConcurrencyUtilities.swift
new file mode 100644
index 00000000..507a5e6a
--- /dev/null
+++ b/Sources/ArgmaxCore/ConcurrencyUtilities.swift
@@ -0,0 +1,112 @@
+// For licensing see accompanying LICENSE.md file.
+// Copyright © 2024 Argmax, Inc. All rights reserved.
+
+import Foundation
+import os.lock
+
+// MARK: - Concurrency Utilities
+
/// Namespace for small device-concurrency helpers.
public struct ConcurrencyUtilities {
    /// Not instantiable; this type is a pure namespace.
    private init() {}

    /// Number of active processors on this device.
    public static var activeProcessorCount: Int {
        ProcessInfo.processInfo.activeProcessorCount
    }
}
+
+// MARK: - Unfair Lock
+
+/// Thin wrapper around `os_unfair_lock` that exposes a Swift-friendly
+/// `withLock` helper. This lock is non-reentrant and optimized for low
+/// contention, matching the semantics of Darwin's `os_unfair_lock` primitive.
/// Thin wrapper around `os_unfair_lock` that exposes a Swift-friendly
/// `withLock` helper. This lock is non-reentrant and optimized for low
/// contention.
///
/// Fixes two issues with the previous version:
/// - The generic parameter on `withLock` was missing (lost in extraction);
///   restored as `<T>`.
/// - `os_unfair_lock` was stored inline as a class property and addressed via
///   `&lock`. Swift does not guarantee a stable address for inout-to-pointer
///   conversions, which `os_unfair_lock` requires; the lock now lives in an
///   explicit heap allocation with a stable address for the object's lifetime.
///   (Note: this removes the previously internal `lock` stored property.)
public final class UnfairLock: @unchecked Sendable {
    private let lockPointer: UnsafeMutablePointer<os_unfair_lock>

    public init() {
        lockPointer = UnsafeMutablePointer<os_unfair_lock>.allocate(capacity: 1)
        lockPointer.initialize(to: os_unfair_lock())
    }

    deinit {
        lockPointer.deinitialize(count: 1)
        lockPointer.deallocate()
    }

    /// Runs `body` while holding the lock, releasing it on every exit path.
    public func withLock<T>(_ body: () throws -> T) rethrows -> T {
        os_unfair_lock_lock(lockPointer)
        defer { os_unfair_lock_unlock(lockPointer) }
        return try body()
    }
}
+
+// MARK: - Property Lock
+
+/// Serializes whole-value reads and writes with an `os_unfair_lock`.
+/// Useful for properties on types marked `@unchecked Sendable`.
+///
+/// **Known limitation — only whole-value replacement is safe:**
+///
+/// ```swift
+/// holder.ref = otherRef // safe: single locked write
+/// holder.count = newCount // safe: single locked write
+///
+/// holder.ref.count += 1 // NOT safe: gets the reference (locked),
+/// // then mutates .count with no lock held
+/// holder.count += 1 // NOT safe: get and set are two separate
+/// // lock acquisitions with a gap between them
+/// ```
+///
+/// For read-modify-write safety, callers must replace the entire value
+/// atomically or use their own external synchronization.
@propertyWrapper
public struct PropertyLock<Value: Codable & Sendable>: Sendable, Codable {
    // NOTE(review): the generic clause was lost in the diff extraction
    // (rendered as `public struct PropertyLock: Sendable, Codable`);
    // `<Value: Codable & Sendable>` is reconstructed from usage — the Codable
    // passthrough requires `Value: Codable`, and the (non-@unchecked) Sendable
    // conformance requires `Value: Sendable`. Confirm against the original.
    private let lock: UnfairLock
    private var value: Value

    public init(wrappedValue: Value) {
        self.lock = UnfairLock()
        self.value = wrappedValue
    }

    /// Decodes the wrapped value transparently, as if the property were unwrapped.
    public init(from decoder: Swift.Decoder) throws {
        self.lock = UnfairLock()
        self.value = try Value(from: decoder)
    }

    /// Encodes the wrapped value transparently, holding the lock for the read.
    public func encode(to encoder: Encoder) throws {
        try lock.withLock {
            try value.encode(to: encoder)
        }
    }

    /// Whole-value get and set, each performed as a single locked operation.
    /// See the type-level docs: read-modify-write is NOT atomic across the two.
    public var wrappedValue: Value {
        get {
            lock.withLock {
                return value
            }
        }
        set {
            lock.withLock {
                value = newValue
            }
        }
    }
}
+
+// MARK: - Early Stop Actor
+
+/// An actor that provides thread-safe early stopping functionality using UUIDs as keys.
/// An actor that provides thread-safe early stopping functionality using UUIDs as keys.
public actor EarlyStopActor {
    // Per-caller stop flags, keyed by the caller-supplied UUID.
    private var flags: [UUID: Bool] = [:]

    public init() {}

    /// Sets the stop flag for a given UUID
    public func set(_ value: Bool, for uuid: UUID) {
        flags[uuid] = value
    }

    /// Gets the stop flag for a given UUID; `false` when no flag was ever set.
    public func get(for uuid: UUID) -> Bool {
        flags[uuid] ?? false
    }

    /// Removes and returns the stop flag for a given UUID
    @discardableResult
    public func remove(for uuid: UUID) -> Bool? {
        flags.removeValue(forKey: uuid)
    }
}
diff --git a/Sources/ArgmaxCore/FileUtilities.swift b/Sources/ArgmaxCore/FileUtilities.swift
new file mode 100644
index 00000000..a5957c8b
--- /dev/null
+++ b/Sources/ArgmaxCore/FileUtilities.swift
@@ -0,0 +1,55 @@
+// For licensing see accompanying LICENSE.md file.
+// Copyright © 2026 Argmax, Inc. All rights reserved.
+
+import Foundation
+
+#if canImport(PDFKit)
+import PDFKit
+#endif
+
+// MARK: - File Text Extraction
+
+/// Utility for extracting plain text from files on disk.
+///
+/// Supports:
+/// - Plain text files (`.txt`, and any UTF-8 readable file)
+/// - PDF files (via PDFKit, where available)
public enum FileUtilities {
    /// Reads and returns the text content of the file at `url`.
    ///
    /// Plain-text files are decoded as UTF-8 first, then ISO Latin-1 as a
    /// legacy fallback; PDFs are extracted with PDFKit where available.
    ///
    /// - Returns: The extracted string, or `nil` if the file cannot be read or has no
    ///   readable text content.
    public static func readTextContent(at url: URL) -> String? {
        switch url.pathExtension.lowercased() {
        case "pdf":
            return readPDF(at: url)
        default:
            // Try UTF-8 first, fall back to Latin-1 for legacy files
            if let text = try? String(contentsOf: url, encoding: .utf8) {
                return text.isEmpty ? nil : text
            }
            if let text = try? String(contentsOf: url, encoding: .isoLatin1) {
                return text.isEmpty ? nil : text
            }
            return nil
        }
    }

    // MARK: - PDF

    #if canImport(PDFKit)
    /// Extracts the text of every page of the PDF at `url`.
    /// NOTE(review): the original loop body was garbled in the diff extraction;
    /// this reconstruction joins per-page text with newlines and returns `nil`
    /// for documents with no extractable text — confirm against the original.
    private static func readPDF(at url: URL) -> String? {
        guard let document = PDFDocument(url: url) else { return nil }
        var pages: [String] = []
        for pageIndex in 0..<document.pageCount {
            if let pageText = document.page(at: pageIndex)?.string {
                pages.append(pageText)
            }
        }
        let joined = pages.joined(separator: "\n")
        return joined.isEmpty ? nil : joined
    }
    #else
    /// PDFKit is unavailable on this platform; PDF text extraction is unsupported.
    private static func readPDF(at url: URL) -> String? { nil }
    #endif
}
diff --git a/Sources/ArgmaxCore/FloatType.swift b/Sources/ArgmaxCore/FloatType.swift
new file mode 100644
index 00000000..4fac0d81
--- /dev/null
+++ b/Sources/ArgmaxCore/FloatType.swift
@@ -0,0 +1,18 @@
+// For licensing see accompanying LICENSE.md file.
+// Copyright © 2024 Argmax, Inc. All rights reserved.
+
+import Accelerate
+import CoreML
+
+// MARK: - Platform-aware Float Type
+
+// Float16 is natively supported everywhere except x86_64 (Intel) Macs,
+// which fall back to Float32.
+#if !((os(macOS) || targetEnvironment(macCatalyst)) && arch(x86_64))
+public typealias FloatType = Float16
+#else
+public typealias FloatType = Float
+#endif
+
+// Pre-Swift-6 compilers on Apple-silicon macOS do not provide these conformances;
+// declare them retroactively so Float16 works with BNNS and MLShapedArray.
+#if (os(macOS) || targetEnvironment(macCatalyst)) && arch(arm64) && compiler(<6)
+extension Float16: BNNSScalar {}
+extension Float16: MLShapedArrayScalar {}
+#endif
diff --git a/Sources/ArgmaxCore/FoundationExtensions.swift b/Sources/ArgmaxCore/FoundationExtensions.swift
new file mode 100644
index 00000000..b680b2c5
--- /dev/null
+++ b/Sources/ArgmaxCore/FoundationExtensions.swift
@@ -0,0 +1,139 @@
+// For licensing see accompanying LICENSE.md file.
+// Copyright © 2024 Argmax, Inc. All rights reserved.
+
+import Foundation
+
+// MARK: - Float
+
public extension Float {
    /// Rounds to the specified number of decimal places.
    func rounded(_ decimalPlaces: Int) -> Float {
        // Scale up, round to the nearest integer, then scale back down.
        let scale = pow(10.0, Float(decimalPlaces))
        let scaled = (self * scale).rounded()
        return scaled / scale
    }
}
+
+// MARK: - FileManager
+
public extension FileManager {
    /// Resolves an input path to an absolute path.
    ///
    /// Tilde prefixes are expanded first; already-absolute paths are returned
    /// unchanged, and relative paths are resolved against the process's current
    /// working directory.
    ///
    /// - Parameter inputPath: Absolute, relative, or `~`-prefixed path.
    /// - Returns: An absolute filesystem path.
    static func resolveAbsolutePath(_ inputPath: String) -> String {
        let expanded = NSString(string: inputPath).expandingTildeInPath

        if expanded.hasPrefix("/") {
            return expanded
        }

        // `currentDirectoryPath` is a non-optional String, so no fallback branch
        // is needed (the original `as String?` cast could never fail, leaving an
        // unreachable `return inputPath`).
        let cwd = FileManager.default.currentDirectoryPath
        return URL(fileURLWithPath: cwd).appendingPathComponent(expanded).path
    }
}
+
+// MARK: - Array
+
extension Array {
    /// Splits the array into batches of the given size.
    ///
    /// The final batch is shorter when `count` is not a multiple of `size`.
    /// An empty array yields no batches.
    public func batched(into size: Int) -> [[Element]] {
        return stride(from: 0, to: count, by: size).map {
            // Reconstructed: the upper bound was lost to angle-bracket stripping.
            Array(self[$0..<Swift.min($0 + size, count)])
        }
    }
}

extension Array where Element: Hashable {
    /// Returns the elements with duplicates removed, preserving the order of
    /// each element's first occurrence.
    public func removingDuplicates() -> [Element] {
        var seen = Set<Element>()
        return self.filter { element in
            if seen.contains(element) {
                return false
            } else {
                seen.insert(element)
                return true
            }
        }
    }
}
+
+// MARK: - String
+
extension String {
    /// Returns the text up to and including the last natural boundary in the string.
    ///
    /// Boundaries are tested in priority order: sentence enders (. ! ? \n), clause
    /// enders (, ; : - –), then word boundaries (space). A candidate is only accepted
    /// when its encoded token count reaches `minTokenCount`.
    ///
    /// - Parameters:
    ///   - minTokenCount: Minimum number of tokens the candidate must contain.
    ///   - encode: Closure that tokenizes a string and returns its token IDs.
    /// - Returns: The trimmed substring up to the last qualifying boundary, or `nil`.
    public func lastNaturalBoundary(minTokenCount: Int, encode: (String) -> [Int]) -> String? {
        let sentenceEnders: [Character] = [".", "!", "?", "\n"]
        let clauseEnders: [Character] = [",", ";", ":", "-", "–"]

        for enders in [sentenceEnders, clauseEnders] {
            if let idx = lastIndex(where: { enders.contains($0) }) {
                // Keep the boundary character itself (closed range).
                let candidate = String(self[...idx]).trimmingCharacters(in: .whitespacesAndNewlines)
                if encode(candidate).count >= minTokenCount {
                    return candidate
                }
            }
        }

        // Word-boundary fallback: exclude the space itself (half-open range).
        // Reconstructed: this span was lost to angle-bracket stripping.
        if let idx = lastIndex(of: " ") {
            let candidate = String(self[..<idx]).trimmingCharacters(in: .whitespacesAndNewlines)
            if encode(candidate).count >= minTokenCount {
                return candidate
            }
        }

        return nil
    }

    /// Trims up to `upto` occurrences of `character` from the end of the string.
    public func trimmingFromEnd(character: Character = " ", upto: Int) -> String {
        var result = self
        var trimmed = 0
        while trimmed < upto && result.last == character {
            result.removeLast()
            trimmed += 1
        }
        return result
    }
}
+
extension [String] {
    /// Filters strings matching a glob pattern using POSIX `fnmatch`.
    public func matching(glob pattern: String) -> [String] {
        // fnmatch returns 0 on a successful match.
        filter { candidate in
            fnmatch(pattern, candidate, 0) == 0
        }
    }
}
+
+// MARK: - ProcessInfo (macOS)
+
#if os(macOS) || targetEnvironment(simulator)
public extension ProcessInfo {
    /// Reads a string value from `sysctlbyname`.
    ///
    /// - Parameter name: The sysctl key, e.g. `"machdep.cpu.brand_string"`.
    /// - Returns: The value, or `""` when the key is unavailable on this hardware
    ///   (e.g. `machdep.cpu.vendor` does not exist on Apple Silicon). Guarding the
    ///   failure avoids constructing a String from an empty, unterminated buffer,
    ///   which traps.
    static func stringFromSysctl(named name: String) -> String {
        var size: size_t = 0
        guard sysctlbyname(name, nil, &size, nil, 0) == 0, size > 0 else {
            return ""
        }
        var buffer = [CChar](repeating: 0, count: Int(size))
        guard sysctlbyname(name, &buffer, &size, nil, 0) == 0 else {
            return ""
        }
        return String(cString: buffer)
    }

    static let processor = stringFromSysctl(named: "machdep.cpu.brand_string")
    static let cores = stringFromSysctl(named: "machdep.cpu.core_count")
    static let threads = stringFromSysctl(named: "machdep.cpu.thread_count")
    static let vendor = stringFromSysctl(named: "machdep.cpu.vendor")
    static let family = stringFromSysctl(named: "machdep.cpu.family")
    static let hwModel = stringFromSysctl(named: "hw.model")
}
#endif
diff --git a/Sources/ArgmaxCore/Logging.swift b/Sources/ArgmaxCore/Logging.swift
new file mode 100644
index 00000000..ad5407f4
--- /dev/null
+++ b/Sources/ArgmaxCore/Logging.swift
@@ -0,0 +1,116 @@
+// For licensing see accompanying LICENSE.md file.
+// Copyright © 2024 Argmax, Inc. All rights reserved.
+
+import OSLog
+
+/// Shared logger for all Argmax frameworks (WhisperKit, TTSKit, etc.).
+///
+/// Configure the log level once at startup:
+/// ```swift
+/// Logging.shared.logLevel = .debug
+/// ```
+/// or via a config object, following the WhisperKit pattern:
+/// ```swift
+/// Logging.shared.logLevel = config.verbose ? config.logLevel : .none
+/// ```
open class Logging {
    public static let shared = Logging()

    /// Threshold below which messages are suppressed. Defaults to `.none` (silent).
    public var logLevel: LogLevel = .none

    public typealias LoggingCallback = (_ message: String) -> Void

    /// When set, formatted messages go to this callback instead of os_log.
    public var loggingCallback: LoggingCallback?

    private let logger = OSLog(
        subsystem: Bundle.main.bundleIdentifier ?? "com.argmax.argmaxcore",
        category: "Argmax"
    )

    @frozen
    public enum LogLevel: Int {
        case debug = 1
        case info = 2
        case error = 3
        case none = 4

        /// `true` when this threshold permits a message at `level`.
        func shouldLog(level: LogLevel) -> Bool {
            return self.rawValue <= level.rawValue
        }
    }

    private init() {}

    /// Formats `items` print-style (joined by `separator`) and emits the result.
    /// `terminator` is accepted for print-parity but not appended: both os_log
    /// and typical logging callbacks are line-oriented already.
    public func log(_ items: Any..., separator: String = " ", terminator: String = "\n", type: OSLogType) {
        emit(items.map { "\($0)" }.joined(separator: separator), type: type)
    }

    /// Sends an already-formatted message to the callback when set, else os_log.
    private func emit(_ message: String, type: OSLogType) {
        if let callback = loggingCallback {
            callback(message)
        } else {
            os_log("%{public}@", log: logger, type: type, message)
        }
    }

    public static func debug(_ items: Any..., separator: String = " ", terminator: String = "\n") {
        if shared.logLevel.shouldLog(level: .debug) {
            // Join here rather than forwarding to the variadic `log`: re-passing
            // `items` would nest the array as a single element and render "[a, b]".
            shared.emit(items.map { "\($0)" }.joined(separator: separator), type: .debug)
        }
    }

    public static func info(_ items: Any..., separator: String = " ", terminator: String = "\n") {
        if shared.logLevel.shouldLog(level: .info) {
            shared.emit(items.map { "\($0)" }.joined(separator: separator), type: .info)
        }
    }

    public static func error(_ items: Any..., separator: String = " ", terminator: String = "\n") {
        if shared.logLevel.shouldLog(level: .error) {
            shared.emit(items.map { "\($0)" }.joined(separator: separator), type: .error)
        }
    }
}
+
public extension Logging {
    /// Format a timing entry as a human-readable string with per-run average and percentage.
    ///
    /// Output format: ` 123.45 ms / 100 runs ( 1.23 ms/run) 45.67%`
    ///
    /// - Parameters:
    ///   - time: Duration in seconds.
    ///   - runs: Number of calls / iterations (used for per-run average).
    ///   - fullPipelineDuration: Total pipeline duration in **milliseconds** (for percentage).
    static func formatTimeWithPercentage(_ time: Double, _ runs: Double, _ fullPipelineDuration: Double) -> String {
        // Guard both divisions so zero runs / zero duration report 0 instead of inf/nan.
        let percentage = fullPipelineDuration > 0 ? (time * 1000 / fullPipelineDuration) * 100 : 0
        let runTime = runs > 0 ? time * 1000 / Double(runs) : 0
        return String(format: "%8.2f ms / %6.0f runs (%8.2f ms/run) %5.2f%%", time * 1000, runs, runTime, percentage)
    }

    /// Logs `message` together with the current resident memory size at debug level.
    static func logCurrentMemoryUsage(_ message: String) {
        let memoryUsage = getMemoryUsage()
        Logging.debug("\(message) - Memory usage: \(memoryUsage) MB")
    }

    /// Resident memory of the current task in MB, or 0 when the mach call fails.
    static func getMemoryUsage() -> UInt64 {
        var info = mach_task_basic_info()
        // MACH_TASK_BASIC_INFO_COUNT: struct size in units of integer_t (4 bytes).
        // Restores the generic parameter lost to angle-bracket stripping.
        var count = mach_msg_type_number_t(MemoryLayout<mach_task_basic_info>.size) / 4

        let kerr: kern_return_t = withUnsafeMutablePointer(to: &info) {
            $0.withMemoryRebound(to: integer_t.self, capacity: 1) {
                task_info(mach_task_self_, task_flavor_t(MACH_TASK_BASIC_INFO), $0, &count)
            }
        }

        guard kerr == KERN_SUCCESS else {
            return 0
        }

        return info.resident_size / 1024 / 1024
    }
}
+
+// Free-function shims kept only for source compatibility with earlier releases;
+// they forward verbatim to the `Logging` namespace equivalents.
+@available(*, deprecated, message: "Subject to removal in a future version. Use `Logging.logCurrentMemoryUsage(_:)` instead.")
+public func logCurrentMemoryUsage(_ message: String) {
+ Logging.logCurrentMemoryUsage(message)
+}
+
+@available(*, deprecated, message: "Subject to removal in a future version. Use `Logging.getMemoryUsage()` instead.")
+public func getMemoryUsage() -> UInt64 {
+ return Logging.getMemoryUsage()
+}
diff --git a/Sources/ArgmaxCore/MLModelExtensions.swift b/Sources/ArgmaxCore/MLModelExtensions.swift
new file mode 100644
index 00000000..f9abac50
--- /dev/null
+++ b/Sources/ArgmaxCore/MLModelExtensions.swift
@@ -0,0 +1,64 @@
+// For licensing see accompanying LICENSE.md file.
+// Copyright © 2024 Argmax, Inc. All rights reserved.
+
+import CoreML
+
+// MARK: - Async Prediction
+
+public extension MLModel {
+ /// Async wrapper for `MLModel.prediction` that uses native async prediction
+ /// on macOS 14+ / iOS 17+ and falls back to a Task-wrapped call on older OS.
+ func asyncPrediction(
+ from input: MLFeatureProvider,
+ options: MLPredictionOptions = MLPredictionOptions()
+ ) async throws -> MLFeatureProvider {
+ if #available(macOS 14, iOS 17, watchOS 10, visionOS 1, *) {
+ return try await prediction(from: input, options: options)
+ } else {
+ // NOTE(review): the legacy path runs the synchronous prediction inside an
+ // unstructured Task, so it occupies a thread for the call's duration and
+ // there is no cancellation point mid-prediction — acceptable as a shim.
+ return try await Task {
+ try prediction(from: input, options: options)
+ }.value
+ }
+ }
+
+ /// Async prediction with MLState for stateful models.
+ /// MLState requires macOS 15+ where native async prediction is available.
+ ///
+ /// - Parameter state: Mutable per-session state threaded through successive calls.
+ @available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *)
+ func asyncPrediction(
+ from input: MLFeatureProvider,
+ using state: MLState,
+ options: MLPredictionOptions = MLPredictionOptions()
+ ) async throws -> MLFeatureProvider {
+ try await prediction(from: input, using: state, options: options)
+ }
+}
+
+// MARK: - Compute Units Description
+
public extension MLComputeUnits {
    /// Debug-style name mirroring the CoreML case name.
    var description: String {
        switch self {
        case .cpuOnly: return "cpuOnly"
        case .cpuAndGPU: return "cpuAndGPU"
        case .all: return "all"
        case .cpuAndNeuralEngine: return "cpuAndNeuralEngine"
        @unknown default: return "unknown"
        }
    }

    /// Human-readable display name suitable for UI presentation.
    var displayName: String {
        switch self {
        case .cpuOnly:
            return "CPU"
        case .cpuAndGPU:
            return "GPU"
        case .cpuAndNeuralEngine:
            return "Neural Engine"
        case .all:
            return "All"
        @unknown default:
            return "Unknown"
        }
    }
}
diff --git a/Sources/ArgmaxCore/MLModelLoading.swift b/Sources/ArgmaxCore/MLModelLoading.swift
new file mode 100644
index 00000000..17de7412
--- /dev/null
+++ b/Sources/ArgmaxCore/MLModelLoading.swift
@@ -0,0 +1,31 @@
+// For licensing see accompanying LICENSE.md file.
+// Copyright © 2026 Argmax, Inc. All rights reserved.
+
+import CoreML
+import Foundation
+
+// MARK: - CoreML model loading protocol
+
+/// Shared lifecycle contract for any CoreML-backed model component.
+///
+/// Conform to this protocol to get a unified `loadModel`/`unloadModel` interface.
+/// `prewarmMode` compiles the model on-device then immediately discards it to
+/// serialize compilation and cap peak memory before a final concurrent load.
+public protocol MLModelLoading {
+ /// Load the CoreML model bundle at `url` using the specified compute units.
+ ///
+ /// When `prewarmMode` is `true` the model is compiled on-device and then
+ /// immediately discarded. This serializes compilation to cap peak memory before
+ /// the final concurrent load.
+ ///
+ /// - Parameters:
+ ///   - url: On-disk location of the model bundle to load.
+ ///   - computeUnits: Hardware targets CoreML may schedule the model on.
+ ///   - prewarmMode: Compile-then-discard mode described above.
+ /// - Throws: Whatever error the conforming implementation's load surfaces.
+ func loadModel(at url: URL, computeUnits: MLComputeUnits, prewarmMode: Bool) async throws
+
+ /// Release the loaded model weights from memory.
+ func unloadModel()
+}
+
+public extension MLModelLoading {
+ /// Convenience overload — loads with `prewarmMode: false`.
+ // Protocols cannot declare default parameter values, hence this forwarding overload.
+ func loadModel(at url: URL, computeUnits: MLComputeUnits) async throws {
+ try await loadModel(at: url, computeUnits: computeUnits, prewarmMode: false)
+ }
+}
diff --git a/Sources/ArgmaxCore/MLMultiArrayExtensions.swift b/Sources/ArgmaxCore/MLMultiArrayExtensions.swift
new file mode 100644
index 00000000..5f056d80
--- /dev/null
+++ b/Sources/ArgmaxCore/MLMultiArrayExtensions.swift
@@ -0,0 +1,141 @@
+// For licensing see accompanying LICENSE.md file.
+// Copyright © 2024 Argmax, Inc. All rights reserved.
+
+import CoreML
+
+// MARK: - MLMultiArray Creation
+
public extension MLMultiArray {
    /// Creates an MLMultiArray pre-filled with an initial value.
    ///
    /// Float16 arrays are backed by an IOSurface pixel buffer (see `pixelBuffer(for:)`);
    /// all other data types use standard storage. If `initialValue`'s runtime type does
    /// not match `dataType`, the contents are left at their allocation default.
    ///
    /// - Throws: `MLMultiArrayCreationError.pixelBufferFailed` when the IOSurface
    ///   buffer cannot be allocated, or any error from `MLMultiArray.init`.
    convenience init(shape: [NSNumber], dataType: MLMultiArrayDataType, initialValue: Any) throws {
        switch dataType {
        case .float16:
            guard let pixelBuffer = Self.pixelBuffer(for: shape) else {
                throw MLMultiArrayCreationError.pixelBufferFailed
            }
            self.init(pixelBuffer: pixelBuffer, shape: shape)
        default:
            try self.init(shape: shape, dataType: dataType)
        }

        // Bulk-fill the backing storage with the requested value.
        switch dataType {
        case .double:
            if let value = initialValue as? Double {
                let typedPointer = dataPointer.bindMemory(to: Double.self, capacity: count)
                typedPointer.initialize(repeating: value, count: count)
            }
        case .float32:
            if let value = initialValue as? Float {
                let typedPointer = dataPointer.bindMemory(to: Float.self, capacity: count)
                typedPointer.initialize(repeating: value, count: count)
            }
        case .float16:
            if let value = initialValue as? FloatType {
                let typedPointer = dataPointer.bindMemory(to: FloatType.self, capacity: count)
                typedPointer.initialize(repeating: value, count: count)
            }
        case .int32:
            if let value = initialValue as? Int32 {
                let typedPointer = dataPointer.bindMemory(to: Int32.self, capacity: count)
                typedPointer.initialize(repeating: value, count: count)
            }
        #if compiler(>=6.2)
        case .int8:
            if #available(macOS 26.0, iOS 26.0, watchOS 26.0, visionOS 26.0, tvOS 26.0, *),
               let value = initialValue as? Int8 {
                let typedPointer = dataPointer.bindMemory(to: Int8.self, capacity: count)
                typedPointer.initialize(repeating: value, count: count)
            }
        #endif
        @unknown default:
            break
        }
    }

    /// Creates an int32 MLMultiArray from an `[Int]` array.
    /// Values are stored in the last dimension (default is dims=1).
    static func from(_ array: [Int], dims: Int = 1) throws -> MLMultiArray {
        var shape = Array(repeating: 1, count: dims)
        shape[shape.count - 1] = array.count
        let output = try MLMultiArray(shape: shape as [NSNumber], dataType: .int32)
        // Restores the element type lost to angle-bracket stripping: the backing
        // store is .int32, so view the data pointer as Int32.
        let pointer = UnsafeMutablePointer<Int32>(OpaquePointer(output.dataPointer))
        for (i, item) in array.enumerated() {
            pointer[i] = Int32(item)
        }
        return output
    }
}
+
+// MARK: - MLMultiArray Indexing & Fill
+
public extension MLMultiArray {
    /// Computes the linear offset from multi-dimensional indices using strides.
    ///
    /// - Parameter strideInts: Precomputed integer strides; pass these when calling
    ///   in a loop to avoid re-reading `strides` on every iteration.
    @inline(__always)
    func linearOffset(for index: [NSNumber], strides strideInts: [Int]? = nil) -> Int {
        var linearOffset = 0
        let strideInts = strideInts ?? strides.map { $0.intValue }
        for (dimension, stride) in zip(index, strideInts) {
            linearOffset += dimension.intValue * stride
        }
        return linearOffset
    }

    /// Fills a range of indices in the last dimension with a value.
    /// Requires shape [1, 1, n]. (Restores `Range<Int>` lost to angle-bracket stripping.)
    func fillLastDimension(indexes: Range<Int>, with value: FloatType) {
        precondition(shape.count == 3 && shape[0] == 1 && shape[1] == 1, "Must have [1, 1, n] shape")
        withUnsafeMutableBufferPointer(ofType: FloatType.self) { ptr, strides in
            for index in indexes {
                ptr[index * strides[2]] = value
            }
        }
    }

    /// Fills specific multi-dimensional indices with a value.
    /// (Restores the generic parameter `<Value>` lost to angle-bracket stripping;
    /// `Value` must match the array's scalar storage type.)
    func fill<Value>(indexes: [[NSNumber]], with value: Value) {
        let pointer = UnsafeMutablePointer<Value>(OpaquePointer(dataPointer))
        let strideInts = strides.map { $0.intValue }
        for index in indexes {
            let linearOffset = linearOffset(for: index, strides: strideInts)
            pointer[linearOffset] = value
        }
    }
}
+
+// MARK: - IOSurface-backed Pixel Buffer
+
+extension MLMultiArray {
+ /// Creates a CVPixelBuffer suitable for float16 IOSurface-backed MLMultiArrays.
+ public class func pixelBuffer(for shape: [NSNumber]) -> CVPixelBuffer? {
+ guard let width = shape.last?.intValue else { return nil }
+ let height = shape[0.. [Int] {
+ await shapedArray(of: Int32.self).scalars.map { Int($0) }
+ }
+
+ /// Materializes the tensor's scalars as `[Float]`, widening from the stored dtype.
+ ///
+ /// NOTE(review): an unanticipated scalar type hits `fatalError` and crashes the
+ /// process rather than degrading — callers are expected to pass known dtypes only.
+ /// When `FloatType == Float` (Intel macs, per the FloatType typealias) the
+ /// `is FloatType.Type` and `is Float.Type` cases overlap; the earlier case wins.
+ func toFloatArray() async -> [Float] {
+ switch scalarType {
+ case is Float32.Type:
+ return await shapedArray(of: Float32.self).scalars.map { Float($0) }
+ case is FloatType.Type:
+ return await shapedArray(of: FloatType.self).scalars.map { Float($0) }
+ case is Float.Type:
+ return await shapedArray(of: Float.self).scalars
+ case is Int32.Type:
+ return await shapedArray(of: Int32.self).scalars.map { Float($0) }
+ default:
+ fatalError("Unsupported scalar type: \(scalarType)")
+ }
+ }
+
+ /// Materializes the tensor as an `MLMultiArray`, preserving the stored dtype.
+ /// Same dispatch and `fatalError` caveats as `toFloatArray()`.
+ func toMLMultiArray() async -> MLMultiArray {
+ switch scalarType {
+ case is Float32.Type:
+ return MLMultiArray(await shapedArray(of: Float32.self))
+ case is FloatType.Type:
+ return MLMultiArray(await shapedArray(of: FloatType.self))
+ case is Float.Type:
+ return MLMultiArray(await shapedArray(of: Float.self))
+ case is Int32.Type:
+ return MLMultiArray(await shapedArray(of: Int32.self))
+ default:
+ fatalError("Unsupported scalar type: \(scalarType)")
+ }
+ }
+
+ // MARK: Sync (legacy — uses DispatchSemaphore, unsafe in concurrent async contexts)
+
+ /// Blocking bridge for pre-async call sites: spawns a high-priority Task and
+ /// parks the calling thread on a semaphore until the Task signals.
+ /// NOTE(review): calling this from a cooperative-pool thread can starve or
+ /// deadlock the pool — hence the deprecation in favor of the async variants.
+ @available(*, deprecated, message: "Use await toIntArray() instead.")
+ func asIntArray() -> [Int] {
+ let semaphore = DispatchSemaphore(value: 0)
+ var result: [Int] = []
+ Task(priority: .high) {
+ result = await self.toIntArray()
+ semaphore.signal()
+ }
+ semaphore.wait()
+ return result
+ }
+
+ /// Blocking bridge; same semaphore mechanism and hazards as `asIntArray()`.
+ @available(*, deprecated, message: "Use await toFloatArray() instead.")
+ func asFloatArray() -> [Float] {
+ let semaphore = DispatchSemaphore(value: 0)
+ var result: [Float] = []
+ Task(priority: .high) {
+ result = await self.toFloatArray()
+ semaphore.signal()
+ }
+ semaphore.wait()
+ return result
+ }
+
+ /// Blocking bridge; same semaphore mechanism and hazards as `asIntArray()`.
+ @available(*, deprecated, message: "Use await toMLMultiArray() instead.")
+ func asMLMultiArray() -> MLMultiArray {
+ let semaphore = DispatchSemaphore(value: 0)
+ // Placeholder [1]-shaped array; overwritten by the Task before wait() returns.
+ // NOTE(review): `try!` traps if even this tiny allocation fails.
+ var result = try! MLMultiArray(shape: [1], dataType: .float16, initialValue: 0.0)
+ Task(priority: .high) {
+ result = await self.toMLMultiArray()
+ semaphore.signal()
+ }
+ semaphore.wait()
+ return result
+ }
+}
+#endif
diff --git a/Sources/ArgmaxCore/ModelState.swift b/Sources/ArgmaxCore/ModelState.swift
new file mode 100644
index 00000000..872c915a
--- /dev/null
+++ b/Sources/ArgmaxCore/ModelState.swift
@@ -0,0 +1,53 @@
+// For licensing see accompanying LICENSE.md file.
+// Copyright © 2026 Argmax, Inc. All rights reserved.
+
+import Foundation
+
+// MARK: - ModelState
+
+/// Lifecycle state of a loaded ML model pipeline.
+///
+/// Shared by both WhisperKit and TTSKit so that UI components, callbacks, and
+/// utilities can reference a single canonical type.
+///
+/// State machine:
+/// ```
+/// unloaded → downloading → downloaded → loading → loaded
+/// unloaded → prewarming → prewarmed
+/// loaded → unloading → unloaded
+/// ```
@frozen
public enum ModelState: CustomStringConvertible {
    case unloading
    case unloaded
    case loading
    case loaded
    case prewarming
    case prewarmed
    case downloading
    case downloaded

    /// User-facing label for the state ("Specializing"/"Specialized" are the
    /// presentation names for the prewarm phases).
    public var description: String {
        switch self {
        case .unloading: return "Unloading"
        case .unloaded: return "Unloaded"
        case .loading: return "Loading"
        case .loaded: return "Loaded"
        case .prewarming: return "Specializing"
        case .prewarmed: return "Specialized"
        case .downloading: return "Downloading"
        case .downloaded: return "Downloaded"
        }
    }

    /// Returns `true` when a loading or downloading operation is in progress.
    public var isBusy: Bool {
        // Exhaustive switch instead of `default:` — adding a new case forces an
        // explicit busy/idle decision here rather than silently defaulting.
        switch self {
        case .loading, .prewarming, .downloading, .unloading:
            return true
        case .loaded, .unloaded, .prewarmed, .downloaded:
            return false
        }
    }
}
+
+/// Callback invoked when the pipeline's `modelState` changes.
+public typealias ModelStateCallback = (_ oldState: ModelState?, _ newState: ModelState) -> Void
diff --git a/Sources/ArgmaxCore/ModelUtilities.swift b/Sources/ArgmaxCore/ModelUtilities.swift
new file mode 100644
index 00000000..c67cb792
--- /dev/null
+++ b/Sources/ArgmaxCore/ModelUtilities.swift
@@ -0,0 +1,73 @@
+// For licensing see accompanying LICENSE.md file.
+// Copyright © 2024 Argmax, Inc. All rights reserved.
+
+import CoreML
+import Foundation
+
public struct ModelUtilities {
    /// Namespace-only type — no instances.
    private init() {}

    // MARK: - Model Dimension Introspection

    /// Reads a dimension from a model input's multiarray constraint shape.
    ///
    /// - Parameters:
    ///   - model: Model to inspect; `nil` yields `nil`.
    ///   - named: Name of the input feature.
    ///   - position: Index into the constraint's shape.
    /// - Returns: The dimension, or `nil` when the input is missing, is not a
    ///   multiarray, or `position` is out of bounds (bounds-checked so a bad
    ///   position no longer traps).
    public static func getModelInputDimension(_ model: MLModel?, named: String, position: Int) -> Int? {
        guard let inputDescription = model?.modelDescription.inputDescriptionsByName[named],
              inputDescription.type == .multiArray,
              let shapeConstraint = inputDescription.multiArrayConstraint
        else { return nil }
        let shape = shapeConstraint.shape.map { $0.intValue }
        guard shape.indices.contains(position) else { return nil }
        return shape[position]
    }

    /// Reads a dimension from a model output's multiarray constraint shape.
    /// Same semantics as `getModelInputDimension(_:named:position:)`.
    public static func getModelOutputDimension(_ model: MLModel?, named: String, position: Int) -> Int? {
        guard let outputDescription = model?.modelDescription.outputDescriptionsByName[named],
              outputDescription.type == .multiArray,
              let shapeConstraint = outputDescription.multiArrayConstraint
        else { return nil }
        let shape = shapeConstraint.shape.map { $0.intValue }
        guard shape.indices.contains(position) else { return nil }
        return shape[position]
    }

    @available(*, deprecated, renamed: "getModelInputDimension")
    public static func getModelInputDimention(_ model: MLModel?, named: String, position: Int) -> Int? {
        getModelInputDimension(model, named: named, position: position)
    }

    @available(*, deprecated, renamed: "getModelOutputDimension")
    public static func getModelOutputDimention(_ model: MLModel?, named: String, position: Int) -> Int? {
        getModelOutputDimension(model, named: named, position: position)
    }

    // MARK: - Model URL Detection

    /// Detects the best available CoreML model URL in a folder, preferring
    /// compiled `.mlmodelc` over `.mlpackage`.
    ///
    /// Always returns a URL: when neither bundle exists, the compiled URL is
    /// returned so the caller's subsequent load fails with a clear missing-file error.
    public static func detectModelURL(inFolder path: URL, named modelName: String) -> URL {
        let compiledUrl = path.appending(path: "\(modelName).mlmodelc")
        let packageUrl = path.appending(path: "\(modelName).mlpackage/Data/com.apple.CoreML/model.mlmodel")

        let compiledModelExists = FileManager.default.fileExists(atPath: compiledUrl.path)
        let packageModelExists = FileManager.default.fileExists(atPath: packageUrl.path)

        // Only fall back to the package when the compiled bundle is absent.
        if packageModelExists && !compiledModelExists {
            return packageUrl
        }
        return compiledUrl
    }

    /// Scans a folder for the first CoreML model bundle when the filename is not known in advance.
    ///
    /// Prefers a compiled `.mlmodelc` bundle over an `.mlpackage` source bundle.
    /// - Returns: The model URL, or `nil` when the folder is unreadable or contains no bundle.
    public static func detectModelURL(inFolder path: URL) -> URL? {
        guard let contents = try? FileManager.default.contentsOfDirectory(
            at: path, includingPropertiesForKeys: nil
        ) else { return nil }

        if let compiled = contents.first(where: { $0.pathExtension == "mlmodelc" }) {
            return compiled
        }
        if let package = contents.first(where: { $0.pathExtension == "mlpackage" }) {
            return package.appending(path: "Data/com.apple.CoreML/model.mlmodel")
        }
        return nil
    }
}
diff --git a/Sources/TTSKit/Configurations.swift b/Sources/TTSKit/Configurations.swift
new file mode 100644
index 00000000..a4196233
--- /dev/null
+++ b/Sources/TTSKit/Configurations.swift
@@ -0,0 +1,39 @@
+// For licensing see accompanying LICENSE.md file.
+// Copyright © 2026 Argmax, Inc. All rights reserved.
+
+import CoreML
+import Foundation
+
+// MARK: - Compute Options
+
+/// Per-component CoreML compute unit configuration.
+///
+/// Used by `TTSKitConfig` to specify which hardware accelerators each model component
+/// should use. This struct is model-agnostic; any backend with multiple CoreML components
+/// that need different compute targets can adopt it.
+public struct ComputeOptions: Sendable {
+ /// Compute units for embedding lookup models (TextProjector, CodeEmbedder, MultiCodeEmbedder).
+ /// Defaults to CPU-only since these are simple table lookups with minimal compute.
+ public var embedderComputeUnits: MLComputeUnits
+
+ /// Compute units for CodeDecoder. Defaults to CPU + Neural Engine.
+ public var codeDecoderComputeUnits: MLComputeUnits
+
+ /// Compute units for MultiCodeDecoder. Defaults to CPU + Neural Engine.
+ public var multiCodeDecoderComputeUnits: MLComputeUnits
+
+ /// Compute units for SpeechDecoder. Defaults to CPU + Neural Engine.
+ public var speechDecoderComputeUnits: MLComputeUnits
+
+ /// Creates per-component compute options. Defaults follow the per-property
+ /// documentation above: embedders on CPU, all decoders on CPU + Neural Engine.
+ public init(
+ embedderComputeUnits: MLComputeUnits = .cpuOnly,
+ codeDecoderComputeUnits: MLComputeUnits = .cpuAndNeuralEngine,
+ multiCodeDecoderComputeUnits: MLComputeUnits = .cpuAndNeuralEngine,
+ speechDecoderComputeUnits: MLComputeUnits = .cpuAndNeuralEngine
+ ) {
+ self.embedderComputeUnits = embedderComputeUnits
+ self.codeDecoderComputeUnits = codeDecoderComputeUnits
+ self.multiCodeDecoderComputeUnits = multiCodeDecoderComputeUnits
+ self.speechDecoderComputeUnits = speechDecoderComputeUnits
+ }
+}
diff --git a/Sources/TTSKit/Generating.swift b/Sources/TTSKit/Generating.swift
new file mode 100644
index 00000000..1a87aec6
--- /dev/null
+++ b/Sources/TTSKit/Generating.swift
@@ -0,0 +1,61 @@
+// For licensing see accompanying LICENSE.md file.
+// Copyright © 2026 Argmax, Inc. All rights reserved.
+
+import Foundation
+
+// MARK: - SpeechGenerating
+
+/// Protocol for a single-chunk speech synthesis task.
+///
+/// `TTSKit` creates tasks via `setupGenerateTask(...)` and calls `run(...)` on each one.
+/// Each task instance is independent - fresh KV caches, own RNG - so multiple tasks
+/// can run concurrently without data races.
+///
+/// The interface uses plain `String` for `voice` and `language` to stay model-agnostic.
+/// Each implementation maps these to its internal representation (e.g. token IDs).
+///
+/// **Reference implementation:** `Qwen3GenerateTask` in `Sources/TTSKit/Qwen3TTS/Qwen3GenerateTask.swift`.
+public protocol SpeechGenerating: Sendable {
+ /// Default voice identifier for this model (e.g. `"ryan"` for Qwen3 TTS).
+ ///
+ /// Returned by `TTSKit.generate()` and `play()` when `voice` is `nil`.
+ /// Each model implementation provides its own sensible default.
+ var defaultVoice: String { get }
+
+ /// Default language identifier for this model (e.g. `"english"` for Qwen3 TTS).
+ ///
+ /// Returned by `TTSKit.generate()` and `play()` when `language` is `nil`.
+ var defaultLanguage: String { get }
+
+ /// Output sample rate in Hz (e.g. 24000 for Qwen3 TTS).
+ var sampleRate: Int { get }
+
+ /// Number of PCM samples produced per decoded frame (e.g. 1920 for Qwen3 TTS).
+ /// (`samplesPerFrame / sampleRate` is the audio duration per step — 80 ms for
+ /// the Qwen3 example values.)
+ var samplesPerFrame: Int { get }
+
+ /// Minimum pre-buffer duration (seconds) in `.auto` playback mode.
+ var minimumBufferDuration: TimeInterval { get }
+
+ /// Generate speech for a single text chunk, delivering per-step audio via `callback`.
+ ///
+ /// Conformance to `Sendable` means a task instance may be invoked from
+ /// concurrent contexts; implementations keep per-run state (KV caches, RNG)
+ /// local to each `run` invocation.
+ ///
+ /// - Parameters:
+ ///   - text: The text chunk to synthesize.
+ ///   - voice: Voice identifier (model-specific, e.g. `Qwen3Speaker.rawValue`).
+ ///   - language: Language identifier (model-specific, e.g. `Qwen3Language.rawValue`).
+ ///   - options: Generation options (temperature, top-k, chunking, concurrency, etc.)
+ ///   - callback: Called on every decoded step with audio samples and running timings.
+ ///     `SpeechProgress.stepTime` is non-nil only on the first step, allowing
+ ///     adaptive playback buffer configuration before audio starts.
+ ///     Return `false` to cancel; `nil` or `true` to continue.
+ ///   - prefixCache: Optional cached prefix state to skip invariant prefill tokens.
+ /// - Returns: The assembled `SpeechResult` for this text chunk.
+ /// - Throws: `TTSError` on generation failure or task cancellation.
+ func run(
+ text: String,
+ voice: String,
+ language: String,
+ options: GenerationOptions,
+ callback: SpeechCallback,
+ prefixCache: TTSPromptCache?
+ ) async throws -> SpeechResult
+}
diff --git a/Sources/TTSKit/Models.swift b/Sources/TTSKit/Models.swift
new file mode 100644
index 00000000..6fc53b45
--- /dev/null
+++ b/Sources/TTSKit/Models.swift
@@ -0,0 +1,532 @@
+// For licensing see accompanying LICENSE.md file.
+// Copyright © 2026 Argmax, Inc. All rights reserved.
+
+@_exported import ArgmaxCore
+import CoreML
+import Foundation
+
+// MARK: - Generation progress
+
+/// Per-step progress information delivered to a `SpeechCallback` during generation.
+///
+/// A single callback receives both streaming audio and timing metadata on every
+/// generation step.
+///
+/// **Typical usage:**
+/// ```swift
+/// let result = try await tts.generate(text: "Hello!") { progress in
+///     player.enqueue(progress.audio)                                 // stream audio
+///     if let t = progress.stepTime { configureBufferingStrategy(t) } // first-step setup
+///     return nil                                                     // nil or true = continue, false = cancel
+/// }
+/// ```
+public struct SpeechProgress: Sendable {
+    /// Audio samples produced by this generation step (~80ms at 24 kHz).
+    public var audio: [Float]
+
+    /// Timings accumulated so far in the current chunk.
+    public var timings: SpeechTimings
+
+    /// Wall-clock duration of this generation step (seconds).
+    ///
+    /// Non-`nil` **only on the first step** - use it to configure adaptive playback
+    /// buffers (see `PlaybackStrategy.auto`) before the first audio frame is played.
+    public var stepTime: TimeInterval?
+
+    /// Zero-based index of the chunk currently being generated (for multi-chunk requests).
+    public var chunkIndex: Int?
+
+    /// Total number of text chunks in this generation request.
+    public var totalChunks: Int?
+
+    /// Running count of generation steps completed so far across all chunks.
+    /// Increases by one for every decoder step, regardless of which chunk produced it.
+    public var stepsCompleted: Int?
+
+    /// Estimated total steps for the entire request (`maxNewTokens × totalChunks`).
+    /// Actual steps may be fewer if chunks finish early (EOS).
+    public var totalSteps: Int?
+
+    /// Memberwise initializer; the chunk/step counters default to `nil` for
+    /// single-chunk or non-tracking callers.
+    public init(
+        audio: [Float],
+        timings: SpeechTimings,
+        stepTime: TimeInterval? = nil,
+        chunkIndex: Int? = nil,
+        totalChunks: Int? = nil,
+        stepsCompleted: Int? = nil,
+        totalSteps: Int? = nil
+    ) {
+        self.audio = audio
+        self.timings = timings
+        self.stepTime = stepTime
+        self.chunkIndex = chunkIndex
+        self.totalChunks = totalChunks
+        self.stepsCompleted = stepsCompleted
+        self.totalSteps = totalSteps
+    }
+}
+
+/// A closure invoked on each generation step with decoded audio and timing info.
+///
+/// Return `false` to cancel generation early; return `nil` or `true` to continue.
+/// The typealias itself is optional, so APIs can accept `nil` for "no callback".
+/// Matches the `TranscriptionCallback` pattern in WhisperKit.
+public typealias SpeechCallback = (@Sendable (SpeechProgress) -> Bool?)?
+
+// MARK: - SpeechModel Protocol
+
+/// The stable public contract that every TTS model family must satisfy.
+///
+/// Conform to this protocol to add a new TTS model family to TTSKit.
+///
+/// **Minimum implementation:**
+/// ```swift
+/// public class MyTTSModel: SpeechModel {
+///     public var sampleRate: Int { 24000 }
+///
+///     public func generate(text: String, voice: String?, language: String?,
+///                          options: GenerationOptions, callback: SpeechCallback) async throws -> SpeechResult { ... }
+///
+///     public func play(text: String, voice: String?, language: String?,
+///                      options: GenerationOptions, playbackStrategy: PlaybackStrategy,
+///                      callback: SpeechCallback) async throws -> SpeechResult { ... }
+/// }
+/// ```
+public protocol SpeechModel: AnyObject, Sendable {
+    /// Output sample rate in Hz for all audio produced by this model.
+    var sampleRate: Int { get }
+
+    /// Generate speech from text and return the complete audio result.
+    ///
+    /// - Parameters:
+    ///   - text: The text to synthesize.
+    ///   - voice: Voice/speaker identifier. `nil` uses the model's default voice.
+    ///   - language: Language identifier. `nil` uses the model's default language.
+    ///   - options: Sampling and generation options. Model-specific fields are ignored by models
+    ///     that do not support them.
+    ///   - callback: Optional per-step callback receiving decoded audio chunks.
+    ///     Return `false` to cancel; `nil` or `true` to continue.
+    /// - Returns: The assembled `SpeechResult` containing all audio and timing data.
+    /// - Throws: `TTSError` on generation failure or task cancellation.
+    func generate(
+        text: String,
+        voice: String?,
+        language: String?,
+        options: GenerationOptions,
+        callback: SpeechCallback
+    ) async throws -> SpeechResult
+
+    /// Generate speech and stream it through the device audio output in real time.
+    ///
+    /// Default implementation calls `generate` and plays via `AudioOutput`. Override
+    /// if the model has its own streaming playback path.
+    func play(
+        text: String,
+        voice: String?,
+        language: String?,
+        options: GenerationOptions,
+        playbackStrategy: PlaybackStrategy,
+        callback: SpeechCallback
+    ) async throws -> SpeechResult
+}
+
+// MARK: - Playback Strategy
+
+/// Controls how `play()` buffers audio before starting playback.
+///
+/// The TTS generation loop produces one RVQ frame (~80ms of audio) per step.
+/// On fast devices, steps complete well under 80ms and audio can stream immediately.
+/// On slower devices (or in debug builds), steps may exceed 80ms, causing the
+/// playback buffer to drain and produce choppy audio.
+///
+/// `.auto` (default) measures the first generation step - which runs the full
+/// pipeline (MultiCodeDecoder + SpeechDecoder + CodeDecoder) before any audio
+/// is emitted - and uses that exact timing to compute the pre-buffer needed so
+/// playback is less likely to underrun. The buffer is re-assessed at the start of each chunk.
+public enum PlaybackStrategy: Sendable {
+    /// Automatically determine buffer duration from the first step's measured
+    /// wall-clock time. On fast devices this resolves to a small minimum (~240ms)
+    /// to absorb per-step variance. On slower devices it pre-buffers just enough
+    /// to avoid underruns for the full generation. Re-assessed per chunk.
+    case auto
+
+    /// Always stream frame-by-frame with no pre-buffer.
+    /// May produce choppy audio if the device can't generate faster than real-time.
+    case stream
+
+    /// Pre-buffer a fixed duration of audio before starting playback.
+    case buffered(seconds: Double)
+
+    /// Generate all audio for each chunk first, then play.
+    /// Highest latency but guaranteed smooth playback.
+    case generateFirst
+
+    // MARK: - Buffer calculation
+
+    /// Default minimum buffer duration (seconds) applied in `.auto` mode when no
+    /// model-specific value is available.
+    /// 80ms ≈ 1 audio frame at 24 kHz / 1920 spf - provides headroom for scheduling jitter.
+    /// The `SpeechDecoding` protocol exposes `minimumBufferDuration` so each model can
+    /// override this value; pass it to `requiredBuffer(minimumBuffer:)` at the call site.
+    public static let minimumBufferDuration: TimeInterval = 0.08
+
+    /// Audio duration produced by a single generation step (seconds).
+    ///
+    /// - Parameters:
+    ///   - samplesPerFrame: PCM samples produced per decode step (model-specific).
+    ///   - sampleRate: Output sample rate in Hz (model-specific).
+    /// - Returns: Audio duration in seconds for one generation step.
+    public static func audioPerStep(samplesPerFrame: Int, sampleRate: Int) -> Double {
+        Double(samplesPerFrame) / Double(sampleRate)
+    }
+
+    /// Compute the required pre-buffer duration (seconds) given a measured step time.
+    ///
+    /// Called on every generation step so the buffer duration converges toward
+    /// the minimum as the measured speed approaches or exceeds real-time.
+    /// No additional safety margin is applied; `maxNewTokens` is already an
+    /// upper bound on actual generation length, providing natural conservatism.
+    ///
+    /// - Parameters:
+    ///   - stepTime: Measured wall-clock time of one generation step (seconds).
+    ///     Non-finite or non-positive values fall back to `minimumBuffer`.
+    ///   - maxNewTokens: Upper bound on remaining generation steps.
+    ///   - samplesPerFrame: PCM samples produced per decode step (model-specific).
+    ///   - sampleRate: Output sample rate in Hz (model-specific).
+    ///   - minimumBuffer: Floor on the returned buffer (defaults to
+    ///     `PlaybackStrategy.minimumBufferDuration`; pass
+    ///     `speechDecoder.minimumBufferDuration` for a model-specific override).
+    /// - Returns: Buffer duration in seconds (≥ `minimumBuffer`).
+    public static func requiredBuffer(
+        stepTime: TimeInterval,
+        maxNewTokens: Int,
+        samplesPerFrame: Int,
+        sampleRate: Int,
+        minimumBuffer: TimeInterval = PlaybackStrategy.minimumBufferDuration
+    ) -> TimeInterval {
+        // Defensive: a zero, negative, or non-finite step time cannot yield a
+        // meaningful speed ratio (the division below would produce ±inf/NaN and
+        // a negative stepTime would inflate the buffer), so fall back to the floor.
+        guard stepTime.isFinite, stepTime > 0 else { return minimumBuffer }
+        let perStep = audioPerStep(samplesPerFrame: samplesPerFrame, sampleRate: sampleRate)
+        // Fraction of real-time we fall short by (0 when generation keeps up).
+        let speedRatio = perStep / stepTime
+        let deficit = max(0.0, 1.0 - speedRatio)
+        // Worst-case audio still to be generated, scaled by the shortfall.
+        let maxAudioDuration = Double(maxNewTokens) * perStep
+        let deficitBuffer = maxAudioDuration * deficit
+        return max(minimumBuffer, deficitBuffer)
+    }
+}
+
+// MARK: - Generation Options
+
+/// Options that control the speech synthesis pipeline.
+///
+/// Mirrors `DecodingOptions` in WhisperKit: all fields have sensible defaults so
+/// the zero-argument initializer works for most use cases.
+public struct GenerationOptions: Codable, Sendable {
+    // MARK: - Defaults
+
+    /// Default sampling temperature.
+    public static let defaultTemperature: Float = 0.9
+    /// Default top-k sampling cutoff.
+    public static let defaultTopK: Int = 50
+    /// Default repetition penalty applied during sampling.
+    public static let defaultRepetitionPenalty: Float = 1.05
+    /// Default upper bound on generated tokens per chunk.
+    public static let defaultMaxNewTokens: Int = 245
+
+    /// Sampling temperature.
+    public var temperature: Float
+    /// Top-k sampling cutoff.
+    public var topK: Int
+    /// Repetition penalty applied during sampling.
+    public var repetitionPenalty: Float
+    /// Upper bound on generated tokens per chunk.
+    public var maxNewTokens: Int
+
+    /// Number of concurrent workers for multi-chunk generation.
+    /// - `0`: all chunks run concurrently in one batch (default, fastest for non-streaming use cases).
+    /// - `1`: sequential - one chunk at a time; required for real-time `play` streaming.
+    /// - `N`: at most N chunks run concurrently.
+    public var concurrentWorkerCount: Int
+
+    /// How to split long text into chunks. Defaults to `.sentence`.
+    /// Set to `.none` to force a single-pass generation without sentence splitting.
+    /// NOTE(review): because this property is Optional, a bare `.none` assignment
+    /// resolves to `Optional.none` (nil) - spell it `TextChunkingStrategy.none` if
+    /// the strategy enum has its own `none` case; confirm against TextChunkingStrategy.
+    public var chunkingStrategy: TextChunkingStrategy?
+
+    /// Target chunk size in tokens for sentence chunking.
+    /// `nil` resolves to `TextChunker.defaultTargetChunkSize` at the call site.
+    public var targetChunkSize: Int?
+
+    /// Minimum chunk size in tokens.
+    /// `nil` resolves to `TextChunker.defaultMinChunkSize` at the call site.
+    public var minChunkSize: Int?
+
+    /// Optional style instruction for controlling speech characteristics
+    /// (e.g., `"Very happy"`). Prepended as a text-only user prompt before the main
+    /// TTS segment. For Qwen3, this is only supported by the 1.7B model variant.
+    public var instruction: String?
+
+    /// Force the legacy `[FloatType]` inference path even on macOS 15+ / iOS 18+.
+    /// When `false` (default), the MLTensor path is taken on supported OS versions.
+    /// Set to `true` in tests to exercise the pre-macOS-15 code path on current hardware.
+    // TODO: Remove forking logic with package with min os version upgrade
+    public var forceLegacyEmbedPath: Bool
+
+    /// Creates options; every parameter falls back to the documented default.
+    public init(
+        temperature: Float = GenerationOptions.defaultTemperature,
+        topK: Int = GenerationOptions.defaultTopK,
+        repetitionPenalty: Float = GenerationOptions.defaultRepetitionPenalty,
+        maxNewTokens: Int = GenerationOptions.defaultMaxNewTokens,
+        concurrentWorkerCount: Int = 0,
+        chunkingStrategy: TextChunkingStrategy? = nil,
+        targetChunkSize: Int? = nil,
+        minChunkSize: Int? = nil,
+        instruction: String? = nil,
+        forceLegacyEmbedPath: Bool = false
+    ) {
+        self.temperature = temperature
+        self.topK = topK
+        self.repetitionPenalty = repetitionPenalty
+        self.maxNewTokens = maxNewTokens
+        self.concurrentWorkerCount = concurrentWorkerCount
+        self.chunkingStrategy = chunkingStrategy
+        self.targetChunkSize = targetChunkSize
+        self.minChunkSize = minChunkSize
+        self.instruction = instruction
+        self.forceLegacyEmbedPath = forceLegacyEmbedPath
+    }
+}
+
+// MARK: - Timings
+
+/// All timing values are stored in seconds, matching `TranscriptionTimings` in WhisperKit.
+public struct SpeechTimings: Codable, Sendable {
+    // MARK: - Model loading
+
+    public var modelLoading: TimeInterval = 0
+    public var tokenizerLoadTime: TimeInterval = 0
+
+    // MARK: - Pipeline phases
+
+    public var tokenize: TimeInterval = 0
+    public var prefill: TimeInterval = 0
+    public var timeToFirstBuffer: TimeInterval = 0
+    public var fullPipeline: TimeInterval = 0
+
+    // MARK: - Generation loop
+
+    /// Total wall-clock time spent in the autoregressive decoding loop.
+    /// Mirrors `decodingLoop` in `TranscriptionTimings`.
+    public var decodingLoop: TimeInterval = 0
+
+    // MARK: - CodeDecoder
+
+    /// Sum of `model.prediction()` call durations for the main autoregressive decoder.
+    /// Mirrors `decodingPredictions` in `TranscriptionTimings`.
+    public var decodingPredictions: TimeInterval = 0
+
+    // MARK: - MultiCodeDecoder
+
+    /// Total wall-clock time for MultiCodeDecoder across all steps.
+    public var multiCodeDecoder: TimeInterval = 0
+    /// Sum of `model.prediction()` calls inside MultiCodeDecoder.
+    public var multiCodeDecoderPredictions: TimeInterval = 0
+    public var multiCodeDecoderSampling: TimeInterval = 0
+    public var multiCodeDecoderEmbedding: TimeInterval = 0
+    /// Sum of KV cache update calls inside MultiCodeDecoder.
+    /// Mirrors `decodingKvCaching` in `TranscriptionTimings`.
+    public var decodingKvCaching: TimeInterval = 0
+    public var totalMultiCodeDecoderPredictions: Double = 0
+
+    // MARK: - SpeechDecoder
+
+    /// Total wall-clock time for SpeechDecoder across all steps.
+    public var speechDecoder: TimeInterval = 0
+    /// Sum of `model.prediction()` calls inside SpeechDecoder.
+    public var speechDecoderPredictions: TimeInterval = 0
+
+    // MARK: - Non-model overhead
+
+    /// Outer CodeDecoder KV cache update time.
+    public var kvCacheUpdate: TimeInterval = 0
+    /// CodeEmbedder lookup time per step.
+    public var codeEmbed: TimeInterval = 0
+    /// MultiCodeEmbedder + sum for next step.
+    public var codecHidden: TimeInterval = 0
+    /// TextProjector + combine-embeddings per step.
+    public var textProjection: TimeInterval = 0
+    /// Outer code-0 sampling time.
+    /// Mirrors `decodingSampling` in `TranscriptionTimings`.
+    public var decodingSampling: TimeInterval = 0
+
+    // MARK: - Counters
+
+    public var prefillTokens: Double = 0
+    /// Total autoregressive steps. Mirrors `totalDecodingLoops` in `TranscriptionTimings`.
+    public var totalDecodingLoops: Double = 0
+    /// Duration of audio produced (seconds). Mirrors `inputAudioSeconds` in `TranscriptionTimings`.
+    public var inputAudioSeconds: TimeInterval = 0
+
+    public init() {}
+
+    // MARK: - Computed properties
+
+    /// Total wall-clock time in `model.prediction()` calls across all three models.
+    /// Note: with async parallelism this sum may exceed `decodingLoop` (overlap).
+    public var totalPredictions: TimeInterval {
+        decodingPredictions + multiCodeDecoderPredictions + speechDecoderPredictions
+    }
+
+    /// Intra-step parallel overlap: SpeechDecoder time that ran concurrently with
+    /// CodeDecoder within each generation step.
+    public var concurrentStepOverlap: TimeInterval {
+        max(0, totalPredictions - decodingLoop)
+    }
+
+    /// Non-inference overhead on the main path (excludes overlapped SpeechDecoder time).
+    public var totalNonPrediction: TimeInterval {
+        let mainPathPredictions = decodingPredictions + multiCodeDecoderPredictions
+        // Clamp at zero for consistency with `concurrentStepOverlap`: prediction
+        // time can exceed loop wall-clock when async work overlaps, and a
+        // negative "overhead" is meaningless.
+        return max(0, decodingLoop - mainPathPredictions)
+    }
+
+    /// Generation steps per second of wall-clock pipeline time (0 when unset).
+    public var tokensPerSecond: Double {
+        fullPipeline > 0 ? totalDecodingLoops / fullPipeline : 0
+    }
+
+    /// Pipeline time divided by produced audio duration (< 1 is faster than real-time).
+    public var realTimeFactor: Double {
+        inputAudioSeconds > 0 ? fullPipeline / inputAudioSeconds : 0
+    }
+
+    /// Produced audio duration divided by pipeline time (inverse of `realTimeFactor`).
+    public var speedFactor: Double {
+        inputAudioSeconds > 0 ? inputAudioSeconds / fullPipeline : 0
+    }
+
+    // MARK: - Merge
+
+    /// Accumulate all per-chunk timing fields from `other` into this timing.
+    ///
+    /// `modelLoading`, `tokenizerLoadTime`, `timeToFirstBuffer`, and `fullPipeline`
+    /// are intentionally excluded - callers set those separately on the combined struct.
+    public mutating func merge(_ other: SpeechTimings) {
+        tokenize += other.tokenize
+        prefill += other.prefill
+        prefillTokens += other.prefillTokens
+        decodingLoop += other.decodingLoop
+        decodingPredictions += other.decodingPredictions
+        multiCodeDecoder += other.multiCodeDecoder
+        multiCodeDecoderPredictions += other.multiCodeDecoderPredictions
+        multiCodeDecoderSampling += other.multiCodeDecoderSampling
+        multiCodeDecoderEmbedding += other.multiCodeDecoderEmbedding
+        decodingKvCaching += other.decodingKvCaching
+        totalMultiCodeDecoderPredictions += other.totalMultiCodeDecoderPredictions
+        speechDecoder += other.speechDecoder
+        speechDecoderPredictions += other.speechDecoderPredictions
+        kvCacheUpdate += other.kvCacheUpdate
+        codeEmbed += other.codeEmbed
+        codecHidden += other.codecHidden
+        textProjection += other.textProjection
+        decodingSampling += other.decodingSampling
+        totalDecodingLoops += other.totalDecodingLoops
+    }
+}
+
+// MARK: - Result
+
+/// The complete output of a speech synthesis request.
+///
+/// Mirrors `TranscriptionResult` in WhisperKit: holds the audio samples, timing
+/// breakdown, and sample rate. Conforms to `Codable` so results can be serialized
+/// for caching or logging.
+public struct SpeechResult: Codable, Sendable {
+    /// Raw float audio samples (mono PCM).
+    public let audio: [Float]
+
+    /// Generation performance timings.
+    public let timings: SpeechTimings
+
+    /// Sample rate of the audio in Hz.
+    public let sampleRate: Int
+
+    /// Audio duration in seconds.
+    public var audioDuration: Double {
+        Double(audio.count) / Double(sampleRate)
+    }
+
+    public init(audio: [Float], timings: SpeechTimings, sampleRate: Int) {
+        self.audio = audio
+        self.timings = timings
+        self.sampleRate = sampleRate
+    }
+
+    /// Log detailed timing breakdown to console, matching WhisperKit's format.
+    public func logTimings() {
+        let totalLoops = timings.totalDecodingLoops
+        let mcdPredRuns = timings.totalMultiCodeDecoderPredictions
+        // NOTE(review): converted to milliseconds here - presumably
+        // `Logging.formatTimeWithPercentage` expects the total in ms; confirm units.
+        let fullPipelineDuration = max(timings.decodingLoop, timings.fullPipeline) * 1000
+        let formatTime = { (duration: TimeInterval, count: Double) in
+            Logging.formatTimeWithPercentage(duration, count, fullPipelineDuration)
+        }
+
+        Logging.info(
+            """
+            ---- Speech Timings ----
+            Tokenize:                \(formatTime(timings.tokenize, 1))
+            Prefill:                 \(formatTime(timings.prefill, timings.prefillTokens))
+            All Predictions:         \(formatTime(timings.totalPredictions, totalLoops))
+            Non-inference (main):    \(formatTime(timings.totalNonPrediction, totalLoops))
+            Concurrent Step Overlap: \(formatTime(timings.concurrentStepOverlap, totalLoops))
+            CodeDecoder:             \(formatTime(timings.decodingPredictions, totalLoops))
+            MultiCodeDecoder:        \(formatTime(timings.multiCodeDecoder, totalLoops))
+            - Predictions:           \(formatTime(timings.multiCodeDecoderPredictions, mcdPredRuns))
+            - Sampling:              \(formatTime(timings.multiCodeDecoderSampling, totalLoops))
+            - Embedding:             \(formatTime(timings.multiCodeDecoderEmbedding, totalLoops))
+            - KV Caching:            \(formatTime(timings.decodingKvCaching, totalLoops))
+            SpeechDecoder:           \(formatTime(timings.speechDecoder, totalLoops))
+            - Predictions:           \(formatTime(timings.speechDecoderPredictions, totalLoops))
+            KV Cache (outer):        \(formatTime(timings.kvCacheUpdate, totalLoops))
+            Code Embed (outer):      \(formatTime(timings.codeEmbed, totalLoops))
+            Codec Hidden:            \(formatTime(timings.codecHidden, totalLoops))
+            Text Projection:         \(formatTime(timings.textProjection, totalLoops))
+            Sampling (outer):        \(formatTime(timings.decodingSampling, totalLoops))
+            Decoding Loop:           \(formatTime(timings.decodingLoop, totalLoops))
+            -------------------------------
+            Model Load Time:         \(String(format: "%.2f", timings.modelLoading)) seconds
+            - Tokenizer:             \(String(format: "%.2f", timings.tokenizerLoadTime)) seconds
+            Inference Duration (Global): \(String(format: "%.2f", timings.fullPipeline)) seconds
+            Time to first buffer:    \(String(format: "%.2f", timings.timeToFirstBuffer)) seconds
+            Total Steps:             \(Int(totalLoops))
+            Steps per Second:        \(String(format: "%.2f", timings.tokensPerSecond)) steps/s
+            Real Time Factor:        \(String(format: "%.3f", timings.realTimeFactor))
+            Speed Factor:            \(String(format: "%.3f", timings.speedFactor))
+            Audio Duration:          \(String(format: "%.2f", audioDuration)) seconds
+            """
+        )
+    }
+}
+
+// MARK: - Decoder Result Types
+
+/// Result from MultiCodeDecoder.generateMultiCodes
+public struct MultiCodeGenerationResult {
+    /// Generated codec codes for this frame.
+    public let codes: [Int32]
+    /// Timings accumulated during multi-code generation.
+    public let timings: SpeechTimings
+    /// Pre-computed embeddings of codes 1-14 with position offsets (offsetIndex 0-13).
+    /// Produced as a side effect of multi-code generation and reused for codec-hidden
+    /// computation, saving 14 embed model calls per step.
+    /// The embedding of code15 (offset 14) is not included and must be fetched separately.
+    public let offsetCodeEmbeds: [[FloatType]]
+    /// MLTensor variant of offsetCodeEmbeds - populated by the async tensor path,
+    /// avoiding the [FloatType] -> MLTensor round-trip.
+    /// `nil` when the struct was built with the legacy array initializer.
+    @available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *)
+    public var offsetCodeEmbedTensors: [MLTensor]? { _offsetCodeEmbedTensors as? [MLTensor] }
+
+    // Type-erased storage for the tensor variant; `Any` avoids putting an
+    // availability requirement on the stored property / whole struct.
+    let _offsetCodeEmbedTensors: Any?
+
+    /// Legacy-path initializer: embeddings provided as arrays; tensor variant left unset.
+    public init(codes: [Int32], timings: SpeechTimings, offsetCodeEmbeds: [[FloatType]]) {
+        self.codes = codes
+        self.timings = timings
+        self.offsetCodeEmbeds = offsetCodeEmbeds
+        self._offsetCodeEmbedTensors = nil
+    }
+
+    /// Tensor-path initializer: `offsetCodeEmbeds` is left empty.
+    @available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *)
+    public init(codes: [Int32], timings: SpeechTimings, offsetCodeEmbedTensors: [MLTensor]) {
+        self.codes = codes
+        self.timings = timings
+        self.offsetCodeEmbeds = []
+        self._offsetCodeEmbedTensors = offsetCodeEmbedTensors
+    }
+}
+
+/// Result from SpeechDecoder.decodeFrameAsync
+public struct SpeechDecoderTimedResult: Sendable {
+    /// Decoded PCM samples for one frame.
+    public let samples: [Float]
+    /// Wall-clock timings measured during the decode.
+    public let timings: SpeechTimings
+}
diff --git a/Sources/TTSKit/Protocols.swift b/Sources/TTSKit/Protocols.swift
new file mode 100644
index 00000000..c4d16884
--- /dev/null
+++ b/Sources/TTSKit/Protocols.swift
@@ -0,0 +1,135 @@
+// For licensing see accompanying LICENSE.md file.
+// Copyright © 2026 Argmax, Inc. All rights reserved.
+
+import ArgmaxCore
+import CoreML
+import Foundation
+
+// MARK: - Text Projector Protocol
+
+/// Projects text token IDs into the shared embedding space.
+public protocol TextProjecting: MLModelLoading {
+    /// Backing CoreML model; `nil` until loaded.
+    var model: MLModel? { get }
+    /// Project a single text token ID to its embedding vector.
+    func project(tokenId: Int32) async throws -> [FloatType]
+}
+
+// MARK: - Code Embedder Protocol
+
+/// Embeds codec-0 tokens into the shared embedding space.
+public protocol CodeEmbedding: MLModelLoading {
+    /// Backing CoreML model; `nil` until loaded.
+    var model: MLModel? { get }
+    /// Embed a single codec-0 token ID.
+    func embed(tokenId: Int32) async throws -> [FloatType]
+}
+
+// MARK: - Multi-Code Embedder Protocol
+
+/// Embeds codec-1..15 tokens (with position offsets) into the shared embedding space.
+public protocol MultiCodeEmbedding: MLModelLoading {
+    /// Backing CoreML model; `nil` until loaded.
+    var model: MLModel? { get }
+    /// Embed a token ID, returning the embedding as a flat array.
+    func embed(tokenId: Int32) async throws -> [FloatType]
+    /// Overload differing only in return type; used by the MLTensor path on
+    /// newer OS versions to avoid array round-trips. Callers disambiguate by context.
+    @available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *)
+    func embed(tokenId: Int32) async throws -> MLTensor
+}
+
+// MARK: - Code Decoder Protocol
+
+/// Autoregressive decoder that generates codec-0 tokens from combined embeddings.
+public protocol CodeDecoding: MLModelLoading {
+    /// Backing CoreML model; `nil` until loaded.
+    var model: MLModel? { get }
+    /// Whether the model declares CoreML state for its KV cache.
+    var isStateful: Bool { get }
+    /// KV cache embedding dimension (read after `loadModel`).
+    var kvCacheEmbedDim: Int { get }
+    /// KV cache maximum sequence length (read after `loadModel`).
+    var kvCacheMaxSequenceLength: Int { get }
+    /// Input embedding dimension.
+    var embedSize: Int { get }
+    /// Run one autoregressive decode step over `inputEmbeds` using `cache` (and optional per-generation state).
+    func decode(inputEmbeds: any EmbedInputType, cache: KVCache, state: Any?) async throws -> CodeDecoderOutput
+    /// Create a fresh per-generation state; `nil` for non-stateful models.
+    func makeState() -> Any?
+}
+
+// MARK: - Code Decoder Output
+
+/// Single-step output of a `CodeDecoding` implementation.
+public struct CodeDecoderOutput {
+    /// Codec-0 logits - MLTensor (macOS 15+ async path) or MLMultiArray (sync path).
+    public let logits: any EmbedTensorType // [1, 1, vocabSize]
+    /// Hidden states - `[FloatType]` (legacy path) or MLTensor (async). Passed to MultiCodeDecoder.
+    public let hiddenStates: any EmbedTensorType
+    /// KV cache updates - populated by the sync path; nil when the async decoder updates the cache internally.
+    public let keyCacheUpdates: MLMultiArray?
+    /// Value-cache counterpart of `keyCacheUpdates`; same population rules.
+    public let valueCacheUpdates: MLMultiArray?
+    /// Time spent on KV cache update inside the decoder (async path only). Lets callers
+    /// subtract this from total decode time to isolate pure prediction cost.
+    public var internalCacheUpdateTime: TimeInterval = 0
+}
+
+// MARK: - Multi-Code Decoder Output
+
+/// Single-step output of a `MultiCodeDecoding` implementation.
+public struct MultiCodeDecoderOutput {
+    /// All-head logits - MLTensor (macOS 15+ async path) or MLMultiArray (sync path).
+    public let allLogits: any EmbedTensorType // [1, 15, vocabSize]
+    /// KV cache updates - populated by the sync path; nil when the async decoder updates the cache internally.
+    public let keyCacheUpdates: MLMultiArray?
+    /// Value-cache counterpart of `keyCacheUpdates`; same population rules.
+    public let valueCacheUpdates: MLMultiArray?
+}
+
+// MARK: - Multi-Code Decoder Protocol
+
+/// Generates codec-1..15 tokens for a single RVQ frame given hidden states and codec-0 embedding.
+public protocol MultiCodeDecoding: MLModelLoading {
+    /// Backing CoreML model; `nil` until loaded.
+    var model: MLModel? { get }
+    /// Whether the model declares CoreML state for its KV cache.
+    var isStateful: Bool { get }
+
+    // MARK: - Cache geometry (read after loadModel)
+
+    /// KV cache embedding dimension.
+    var kvCacheEmbedDim: Int { get }
+    /// KV cache maximum sequence length.
+    var kvCacheMaxSequenceLength: Int { get }
+    /// Size of the codec token vocabulary.
+    var codecVocabSize: Int { get }
+
+    // MARK: - Decoding
+
+    /// Run one decode step over `inputEmbeds` using `cache` (and optional per-generation state).
+    func decode(inputEmbeds: any EmbedInputType, cache: KVCache, state: Any?) async throws -> MultiCodeDecoderOutput
+    /// Create a fresh per-generation state; `nil` for non-stateful models.
+    func makeState() -> Any?
+
+    /// Generate codes 1-15 for one RVQ frame.
+    ///
+    /// - Parameters:
+    ///   - hiddenStates: Hidden states produced by the CodeDecoder step.
+    ///   - code0Embed: Embedding of the sampled codec-0 token.
+    ///   - multiCodeEmbedder: Embedder used for offset code embeddings.
+    ///   - sampler: Token sampling strategy.
+    ///   - options: Sampling options (temperature, top-k, repetition penalty).
+    func generateMultiCodes(
+        hiddenStates: [FloatType],
+        code0Embed: [FloatType],
+        multiCodeEmbedder: any MultiCodeEmbedding,
+        sampler: any TokenSampling,
+        options: GenerationOptions
+    ) async throws -> MultiCodeGenerationResult
+}
+
+// MARK: - Speech Decoder Protocol
+
+/// Decodes RVQ code frames into audio waveform samples.
+public protocol SpeechDecoding: MLModelLoading {
+    /// Backing CoreML model; `nil` until loaded.
+    var model: MLModel? { get }
+
+    // MARK: - Audio format (model-specific)
+
+    /// Output sample rate in Hz (e.g. 24000 for Qwen3 TTS).
+    var sampleRate: Int { get }
+    /// Number of PCM samples produced per decoded RVQ frame (e.g. 1920 for Qwen3 TTS).
+    var samplesPerFrame: Int { get }
+    /// Minimum pre-buffer duration (seconds) in `.auto` playback mode.
+    var minimumBufferDuration: TimeInterval { get }
+
+    // MARK: - Cache geometry (read after loadModel)
+
+    /// KV cache embedding dimension.
+    var kvCacheEmbedDim: Int { get }
+    /// KV cache maximum sequence length.
+    var kvCacheMaxSequenceLength: Int { get }
+    /// Hidden-state dimension.
+    var hiddenDim: Int { get }
+    /// Hidden-state context length.
+    var hiddenContextLen: Int { get }
+
+    // MARK: - Decoding
+
+    /// Decode a single RVQ frame (16 codes) into audio samples.
+    func decodeFrame(
+        codes: [Int32],
+        cache: SpeechDecoderCache
+    ) async throws -> [Float]
+
+    /// Async decode that returns audio samples with wall-clock timing.
+    func decodeFrameAsync(
+        codes: [Int32],
+        cache: SpeechDecoderCache
+    ) async throws -> SpeechDecoderTimedResult
+}
diff --git a/Sources/TTSKit/Qwen3TTS/Qwen3CodeDecoder.swift b/Sources/TTSKit/Qwen3TTS/Qwen3CodeDecoder.swift
new file mode 100644
index 00000000..c06e75c7
--- /dev/null
+++ b/Sources/TTSKit/Qwen3TTS/Qwen3CodeDecoder.swift
@@ -0,0 +1,190 @@
+// For licensing see accompanying LICENSE.md file.
+// Copyright © 2026 Argmax, Inc. All rights reserved.
+
+import ArgmaxCore
+import CoreML
+import Foundation
+
+// MARK: - Implementation
+
+/// Autoregressive codec-0 decoder backed by a CoreML model.
+///
+/// Thread safety: mutable state (`model`, dimension properties) is set once during
+/// `loadModel()` and read-only thereafter. `MLModel.prediction()` is thread-safe.
+/// Per-generation `MLState` is created via `makeState()` and owned by each task,
+/// never stored on this shared instance.
+public class Qwen3CodeDecoder: CodeDecoding, @unchecked Sendable {
+    public var model: MLModel?
+
+    /// KV cache embedding dimension, detected from model at load time
+    public private(set) var kvCacheEmbedDim: Int = Qwen3TTSConstants.cdCacheDim
+    /// KV cache max sequence length, detected from model at load time
+    public private(set) var kvCacheMaxSequenceLength: Int = Qwen3TTSConstants.cdMaxSeq
+    /// Input embedding dimension
+    public private(set) var embedSize: Int = Qwen3TTSConstants.embedDim
+
+    public init() {}
+
+    /// Load the CoreML model at `url`.
+    ///
+    /// - Parameters:
+    ///   - url: On-disk location of the model.
+    ///   - computeUnits: CoreML compute units to target.
+    ///   - prewarmMode: When `true`, the model is loaded only to trigger CoreML
+    ///     compilation/caching and then discarded; `self.model` stays `nil` and
+    ///     the dimension properties keep their defaults.
+    public func loadModel(at url: URL, computeUnits: MLComputeUnits, prewarmMode: Bool = false) async throws {
+        let modelConfig = MLModelConfiguration()
+        modelConfig.computeUnits = computeUnits
+        let loaded = try await MLModel.load(contentsOf: url, configuration: modelConfig)
+
+        // In prewarm mode, compilation is complete - discard to free memory before next model compiles
+        guard !prewarmMode else { return }
+
+        self.model = loaded
+
+        // Detect dimensions from model description
+        // key_cache_updates output shape: [1, cacheDim, 1, 1]
+        if let dim = ModelUtilities.getModelOutputDimension(model, named: "key_cache_updates", position: 1) {
+            self.kvCacheEmbedDim = dim
+        }
+        // key_padding_mask input shape: [1, maxSeqLen]
+        if let seq = ModelUtilities.getModelInputDimension(model, named: "key_padding_mask", position: 1) {
+            self.kvCacheMaxSequenceLength = seq
+        }
+        // input_embeds input shape: [1, embedDim, 1, 1]
+        if let dim = ModelUtilities.getModelInputDimension(model, named: "input_embeds", position: 1) {
+            self.embedSize = dim
+        }
+    }
+
+    /// `true` when the loaded model declares CoreML state (stateful KV cache,
+    /// available on macOS 15+/iOS 18+). Always `false` before `loadModel` succeeds.
+    public var isStateful: Bool {
+        guard let model else { return false }
+        if #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *) {
+            return !model.modelDescription.stateDescriptionsByName.isEmpty
+        }
+        return false
+    }
+
+    /// Create a fresh MLState for a new generation session (stateful models only).
+    /// Returns nil for non-stateful models. The caller owns the returned state.
+    public func makeState() -> Any? {
+        guard isStateful, let model else { return nil }
+        if #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *) {
+            return model.makeState()
+        }
+        return nil
+    }
+
+    /// Run one decode step, dispatching to the MLTensor path when `inputEmbeds`
+    /// is an `MLTensor` on a supported OS, otherwise to the MLMultiArray path.
+    /// - Throws: `TTSError.generationFailed` if the model is not loaded or the
+    ///   input type is neither `MLTensor` nor `MLMultiArray`.
+    public func decode(inputEmbeds: any EmbedInputType, cache: KVCache, state: Any? = nil) async throws -> CodeDecoderOutput {
+        guard let model else { throw TTSError.generationFailed("CodeDecoder model not loaded") }
+
+        if #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *), let tensor = inputEmbeds as? MLTensor {
+            return try await decodeWithTensor(tensor, model: model, cache: cache, state: state)
+        }
+
+        guard let array = inputEmbeds as? MLMultiArray else {
+            throw TTSError.generationFailed("CodeDecoder: unsupported embed input type \(type(of: inputEmbeds))")
+        }
+        return try await decodeWithMultiArray(array, model: model, cache: cache, state: state)
+    }
+
+    /// MLTensor path: passes `[String: MLTensor]` directly - no FeatureProvider boxing.
+    @available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *)
+    private func decodeWithTensor(_ inputEmbeds: MLTensor, model: MLModel, cache: KVCache, state: Any?) async throws -> CodeDecoderOutput {
+        var inputs: [String: MLTensor] = [
+            "input_embeds": inputEmbeds,
+            "cache_length": cache.cacheLengthTensor,
+            "kv_cache_update_mask": cache.kvCacheUpdateMaskTensor,
+            "key_padding_mask": cache.keyPaddingMaskTensor
+        ]
+        // Non-stateful models carry the KV cache explicitly as model inputs.
+        if !isStateful, let keyCacheTensor = cache.keyCacheTensor, let valueCacheTensor = cache.valueCacheTensor {
+            inputs["key_cache"] = keyCacheTensor
+            inputs["value_cache"] = valueCacheTensor
+        }
+
+        let outputs: [String: MLTensor]
+        if let mlState = state as? MLState {
+            outputs = try await model.prediction(from: inputs, using: mlState)
+        } else {
+            outputs = try await model.prediction(from: inputs)
+        }
+
+        guard let keyTensor = outputs["key_cache_updates"],
+              let valueTensor = outputs["value_cache_updates"]
+        else {
+            throw TTSError.generationFailed("CodeDecoder: missing key/value cache update tensors")
+        }
+
+        // NOTE(review): on the stateful path, BOTH the MLState and the external
+        // KVCache are updated below - presumably the external cache mirrors the
+        // state for bookkeeping; confirm the double update is intentional.
+        if let mlState = state as? MLState, isStateful {
+            await KVCache.updateStateCache(
+                state: mlState,
+                keyTensor: keyTensor,
+                valueTensor: valueTensor,
+                position: Int(cache.cacheLength)
+            )
+        }
+
+        // Only the external cache update is timed; the MLState update above is
+        // not included in `internalCacheUpdateTime`.
+        let cacheUpdateStart = CFAbsoluteTimeGetCurrent()
+        await cache.update(keyTensor: keyTensor, valueTensor: valueTensor)
+        let cacheTime = CFAbsoluteTimeGetCurrent() - cacheUpdateStart
+
+        guard let logitsTensor = outputs["logits"],
+              let hiddenStatesTensor = outputs["hidden_states"]
+        else {
+            throw TTSError.generationFailed("CodeDecoder: missing logits or hidden_states tensor")
+        }
+        return CodeDecoderOutput(
+            logits: logitsTensor,
+            hiddenStates: hiddenStatesTensor,
+            keyCacheUpdates: nil,
+            valueCacheUpdates: nil,
+            internalCacheUpdateTime: cacheTime
+        )
+    }
+
+    /// MLMultiArray path: FeatureProvider-based prediction for older OS versions.
+    /// KV cache updates are returned to the caller (not applied to `cache` here),
+    /// except the MLState update on stateful models when available.
+    private func decodeWithMultiArray(_ inputEmbeds: MLMultiArray, model: MLModel, cache: KVCache, state: Any?) async throws -> CodeDecoderOutput {
+        var dict: [String: MLFeatureValue] = try [
+            "input_embeds": MLFeatureValue(multiArray: inputEmbeds),
+            "cache_length": MLFeatureValue(multiArray: cache.makeCacheLengthArray()),
+            "kv_cache_update_mask": MLFeatureValue(multiArray: cache.kvCacheUpdateMask),
+            "key_padding_mask": MLFeatureValue(multiArray: cache.keyPaddingMask)
+        ]
+        // Non-stateful models carry the KV cache explicitly as model inputs.
+        if !isStateful, let keyCache = cache.keyCache, let valueCache = cache.valueCache {
+            dict["key_cache"] = MLFeatureValue(multiArray: keyCache)
+            dict["value_cache"] = MLFeatureValue(multiArray: valueCache)
+        }
+
+        let input = try MLDictionaryFeatureProvider(dictionary: dict)
+        let output: MLFeatureProvider
+        if #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *), let mlState = state as? MLState {
+            output = try await model.asyncPrediction(from: input, using: mlState)
+        } else {
+            output = try await model.asyncPrediction(from: input)
+        }
+
+        guard let keyCacheUpdates = output.featureValue(for: "key_cache_updates")?.multiArrayValue,
+              let valueCacheUpdates = output.featureValue(for: "value_cache_updates")?.multiArrayValue
+        else {
+            throw TTSError.generationFailed("CodeDecoder: missing key/value cache update arrays")
+        }
+
+        if #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *), let mlState = state as? MLState, isStateful {
+            KVCache.updateStateCache(
+                state: mlState,
+                keyCacheUpdates: keyCacheUpdates,
+                valueCacheUpdates: valueCacheUpdates,
+                position: Int(cache.cacheLength)
+            )
+        }
+
+        guard let logitsArray = output.featureValue(for: "logits")?.multiArrayValue,
+              let hiddenStatesArray = output.featureValue(for: "hidden_states")?.multiArrayValue
+        else {
+            throw TTSError.generationFailed("CodeDecoder: missing logits or hidden_states array")
+        }
+        return CodeDecoderOutput(
+            logits: logitsArray,
+            hiddenStates: EmbedUtilities.extractEmbed(from: hiddenStatesArray),
+            keyCacheUpdates: keyCacheUpdates,
+            valueCacheUpdates: valueCacheUpdates
+        )
+    }
+
+    /// Release the loaded model (e.g. to free memory); `loadModel` may be called again later.
+    public func unloadModel() {
+        model = nil
+    }
+}
diff --git a/Sources/TTSKit/Qwen3TTS/Qwen3Config.swift b/Sources/TTSKit/Qwen3TTS/Qwen3Config.swift
new file mode 100644
index 00000000..78a0d90c
--- /dev/null
+++ b/Sources/TTSKit/Qwen3TTS/Qwen3Config.swift
@@ -0,0 +1,372 @@
+// For licensing see accompanying LICENSE.md file.
+// Copyright © 2026 Argmax, Inc. All rights reserved.
+
+import ArgmaxCore
+import CoreML
+import Foundation
+
+// MARK: - Model Family
+
+/// Identifies the TTS model architecture. Used by `setupPipeline`, `loadModels`,
+/// and `setupGenerateTask` to dispatch to the correct model-specific code paths.
+@frozen
+public enum TTSModelFamily: String, Sendable {
+ /// Qwen3-based TTS architecture (currently the only supported family).
+ case qwen3
+ // case kokoro // future
+}
+
+// MARK: - Model Variants
+
+/// Pre-configured TTS model sizes with matching variant defaults.
+///
+/// Pass a variant to `TTSKitConfig(model:)` to select the desired model size.
+/// Mirrors `ModelVariant` in WhisperKit: `@frozen`, `CustomStringConvertible`, `CaseIterable`.
+@frozen
+public enum TTSModelVariant: String, CustomStringConvertible, CaseIterable, Sendable {
+ case qwen3TTS_0_6b = "0.6b"
+ case qwen3TTS_1_7b = "1.7b"
+
+ /// The model architecture family this variant belongs to.
+ public var family: TTSModelFamily {
+ // All currently shipped variants are Qwen3 models.
+ return .qwen3
+ }
+
+ /// Canonical model name, e.g. "Qwen3-TTS-0.6B".
+ public var description: String {
+ return self == .qwen3TTS_0_6b ? "Qwen3-TTS-0.6B" : "Qwen3-TTS-1.7B"
+ }
+
+ /// Display name suitable for UI presentation.
+ public var displayName: String {
+ return self == .qwen3TTS_0_6b ? "Qwen3 TTS 0.6B" : "Qwen3 TTS 1.7B"
+ }
+
+ /// Whether this model supports the voice instruction prompt.
+ /// Only the 1.7B variant has the capacity to follow style instructions.
+ public var supportsVoiceDirection: Bool {
+ return self == .qwen3TTS_1_7b
+ }
+
+ /// Returns `true` when this variant can run on the current platform.
+ ///
+ /// The 1.7B model requires more peak memory during CoreML compilation than
+ /// iOS/iPadOS devices can reliably provide, so it is restricted to macOS.
+ public var isAvailableOnCurrentPlatform: Bool {
+ #if os(macOS)
+ return true
+ #else
+ return self == .qwen3TTS_0_6b
+ #endif
+ }
+
+ /// The best default variant for the current platform.
+ public static var defaultForCurrentPlatform: TTSModelVariant {
+ .qwen3TTS_0_6b
+ }
+
+ /// Version directory name; derived from the size raw value,
+ /// e.g. "12hz-0.6b-customvoice".
+ public var versionDir: String {
+ "12hz-\(rawValue)-customvoice"
+ }
+
+ /// Recommended code decoder variant for this model size.
+ public var codeDecoderVariant: String { Qwen3VariantDefaults.codeDecoder }
+ /// Recommended multi-code decoder variant for this model size.
+ public var multiCodeDecoderVariant: String { Qwen3VariantDefaults.multiCodeDecoder }
+ /// Code embedder variant (same across model sizes).
+ public var codeEmbedderVariant: String { Qwen3VariantDefaults.codeEmbedder }
+ /// Multi-code embedder variant (same across model sizes).
+ public var multiCodeEmbedderVariant: String { Qwen3VariantDefaults.multiCodeEmbedder }
+ /// Text projector variant (same across model sizes).
+ public var textProjectorVariant: String { Qwen3VariantDefaults.textProjector }
+ /// Recommended speech decoder variant for this model size.
+ public var speechDecoderVariant: String { Qwen3VariantDefaults.speechDecoder }
+ /// HuggingFace repo used to load the tokenizer for this model size.
+ public var tokenizerRepo: String { Qwen3TTSConstants.defaultTokenizerRepo }
+}
+
+// MARK: - Variant Defaults
+
+/// Default quantization variant strings matching the standard model repository layout.
+///
+/// NOTE(review): the `W8A16`/`W16A16` names presumably encode weight/activation
+/// bit widths, and `-stateful` presumably selects an MLState-backed build -
+/// confirm against the model repository's naming convention.
+public enum Qwen3VariantDefaults {
+ public static let codeDecoder = "W8A16-stateful"
+ public static let multiCodeDecoder = "W8A16"
+ public static let codeEmbedder = "W16A16"
+ public static let multiCodeEmbedder = "W16A16"
+ public static let textProjector = "W8A16"
+ public static let speechDecoder = "W8A16"
+}
+
+// MARK: - TTSKit Configuration
+
+/// Configuration for initializing a `TTSKit` instance.
+///
+/// Mirrors `WhisperKitConfig`: an `open class` so you can subclass for custom
+/// configurations without modifying `TTSKit` itself.
+///
+/// **Minimal usage** (auto-downloads models and tokenizer):
+/// ```swift
+/// let tts = try await TTSKit()
+/// ```
+///
+/// **Local models** (skip download):
+/// ```swift
+/// let config = TTSKitConfig(modelFolder: URL(fileURLWithPath: "/path/to/models"))
+/// let tts = try await TTSKit(config)
+/// ```
+///
+/// **Expected local directory layout:**
+/// ```
+/// modelFolder/
+/// └── qwen3_tts/
+/// ├── code_decoder/<versionDir>/<variant>/*.mlmodelc
+/// ├── multi_code_decoder/<versionDir>/<variant>/*.mlmodelc
+/// ├── code_embedder/<versionDir>/<variant>/*.mlmodelc
+/// ├── multi_code_embedder/<versionDir>/<variant>/*.mlmodelc
+/// ├── text_projector/<versionDir>/<variant>/*.mlmodelc
+/// └── speech_decoder/<versionDir>/<variant>/*.mlmodelc
+/// ```
+open class TTSKitConfig {
+ // MARK: - Model selection
+
+ /// Model variant that determines default versionDir, component variants, and tokenizer.
+ public var model: TTSModelVariant
+
+ // MARK: - Model location
+
+ /// Local URL to the model repository directory.
+ /// If `nil`, models are downloaded from `modelRepo` on HuggingFace Hub.
+ public var modelFolder: URL?
+
+ /// Base URL for downloading and caching models.
+ /// `nil` uses the Hub library's default cache directory.
+ public var downloadBase: URL?
+
+ /// HuggingFace repo ID for auto-downloading models.
+ public var modelRepo: String
+
+ // MARK: - Tokenizer
+
+ /// HuggingFace repo ID or local folder path for the tokenizer
+ /// (resolved from `model` by default).
+ public var tokenizerFolder: URL?
+
+ // MARK: - Authentication
+
+ /// HuggingFace API token for private repos (or set the `HF_TOKEN` env var).
+ public var modelToken: String?
+
+ /// HuggingFace Hub endpoint URL.
+ ///
+ /// Override to point at a regional mirror or an on-premise Hub instance.
+ /// Mirrors `WhisperKitConfig` (via `Constants.defaultRemoteEndpoint`).
+ public var modelEndpoint: String
+
+ // MARK: - Component variants
+
+ /// Version directory shared across all components (resolved from `model` by default).
+ public var versionDir: String
+
+ /// Per-component quantization variant (resolved from `model` by default).
+ public var codeDecoderVariant: String
+ public var multiCodeDecoderVariant: String
+ public var codeEmbedderVariant: String
+ public var multiCodeEmbedderVariant: String
+ public var textProjectorVariant: String
+ public var speechDecoderVariant: String
+
+ // MARK: - Compute
+
+ /// Compute unit configuration per model component.
+ public var computeOptions: ComputeOptions
+
+ // MARK: - Logging
+
+ /// Whether to emit diagnostic logs during loading and generation.
+ public var verbose: Bool
+
+ /// Logging level when `verbose` is `true`. Defaults to `.info`
+ /// (the initializer's default argument).
+ public var logLevel: Logging.LogLevel
+
+ // MARK: - Download options
+
+ /// Specific git revision (commit SHA, tag, or branch) to download from the Hub.
+ /// `nil` (default) resolves to the repo's default branch head.
+ /// Mirrors `WhisperKit.download(variant:revision:)`.
+ public var downloadRevision: String?
+
+ /// Additional glob patterns to include during model download, appended to the
+ /// patterns generated from the configured component variants.
+ public var downloadAdditionalPatterns: [String]
+
+ /// Use a background `URLSession` for model downloads.
+ /// Mirrors `WhisperKitConfig.useBackgroundDownloadSession`.
+ public var useBackgroundDownloadSession: Bool
+
+ /// Download models if not already available locally.
+ /// When `true` (default), `loadModels()` will trigger a download if `modelFolder` is nil.
+ public var download: Bool
+
+ // MARK: - Lifecycle flags
+
+ /// Enable model prewarming.
+ ///
+ /// Prewarming compiles each CoreML model sequentially then discards weights,
+ /// minimizing peak memory during compilation. Call before `loadModels()` on first
+ /// launch or after a model update. Mirrors `WhisperKitConfig.prewarm`.
+ public var prewarm: Bool?
+
+ /// Load models immediately after init.
+ /// `nil` loads when `modelFolder` is non-nil, matching WhisperKit's default.
+ public var load: Bool?
+
+ // MARK: - Generation
+
+ /// Optional seed for reproducible generation.
+ /// Each concurrent task receives a derived seed (`seed ^ taskIndex`).
+ public var seed: UInt64?
+
+ // MARK: - Component overrides
+
+ /// Set any of these to substitute a custom implementation for that model component.
+ /// `nil` means TTSKit will use the default Qwen3 TTS class for that component.
+ ///
+ /// Example:
+ /// ```swift
+ /// let config = TTSKitConfig()
+ /// config.codeDecoder = MyCodeDecoder()
+ /// let tts = try await TTSKit(config)
+ /// ```
+ public var textProjector: (any TextProjecting)?
+ public var codeEmbedder: (any CodeEmbedding)?
+ public var multiCodeEmbedder: (any MultiCodeEmbedding)?
+ public var codeDecoder: (any CodeDecoding)?
+ public var multiCodeDecoder: (any MultiCodeDecoding)?
+ public var speechDecoder: (any SpeechDecoding)?
+
+ /// Creates a configuration. Every parameter has a sensible default; `nil`
+ /// variant/versionDir arguments resolve to the selected `model`'s defaults.
+ public init(
+ model: TTSModelVariant = .qwen3TTS_0_6b,
+ modelFolder: URL? = nil,
+ downloadBase: URL? = nil,
+ modelRepo: String = Qwen3TTSConstants.defaultModelRepo,
+ tokenizerFolder: URL? = nil,
+ modelToken: String? = nil,
+ modelEndpoint: String = Qwen3TTSConstants.defaultEndpoint,
+ versionDir: String? = nil,
+ codeDecoderVariant: String? = nil,
+ multiCodeDecoderVariant: String? = nil,
+ codeEmbedderVariant: String? = nil,
+ multiCodeEmbedderVariant: String? = nil,
+ textProjectorVariant: String? = nil,
+ speechDecoderVariant: String? = nil,
+ computeOptions: ComputeOptions = ComputeOptions(),
+ verbose: Bool = true,
+ logLevel: Logging.LogLevel = .info,
+ downloadRevision: String? = nil,
+ downloadAdditionalPatterns: [String] = [],
+ useBackgroundDownloadSession: Bool = false,
+ download: Bool = true,
+ prewarm: Bool? = nil,
+ load: Bool? = nil,
+ seed: UInt64? = nil
+ ) {
+ self.model = model
+ self.modelFolder = modelFolder
+ self.downloadBase = downloadBase
+ self.modelRepo = modelRepo
+ self.tokenizerFolder = tokenizerFolder
+ self.modelToken = modelToken
+ self.modelEndpoint = modelEndpoint
+ self.versionDir = versionDir ?? model.versionDir
+ self.codeDecoderVariant = codeDecoderVariant ?? model.codeDecoderVariant
+ self.multiCodeDecoderVariant = multiCodeDecoderVariant ?? model.multiCodeDecoderVariant
+ self.codeEmbedderVariant = codeEmbedderVariant ?? model.codeEmbedderVariant
+ self.multiCodeEmbedderVariant = multiCodeEmbedderVariant ?? model.multiCodeEmbedderVariant
+ self.textProjectorVariant = textProjectorVariant ?? model.textProjectorVariant
+ self.speechDecoderVariant = speechDecoderVariant ?? model.speechDecoderVariant
+ self.computeOptions = computeOptions
+ self.verbose = verbose
+ self.logLevel = logLevel
+ self.downloadRevision = downloadRevision
+ self.downloadAdditionalPatterns = downloadAdditionalPatterns
+ self.useBackgroundDownloadSession = useBackgroundDownloadSession
+ self.download = download
+ self.prewarm = prewarm
+ self.load = load
+ self.seed = seed
+ }
+
+ // MARK: - Path resolution
+
+ /// Resolve the full path to a component's model bundle.
+ ///
+ /// Requires `modelFolder` to be set (either directly or via download).
+ /// Delegates to `ModelUtilities.detectModelURL(inFolder:)`, which prefers
+ /// a compiled `.mlmodelc` bundle and falls back to `.mlpackage`.
+ public func modelURL(component: String, variant: String) -> URL? {
+ guard let modelFolder else { return nil }
+
+ // Layout: modelFolder/<familyDir>/<component>/<versionDir>/<variant>
+ let variantDir = modelFolder.appending(path: Qwen3TTSConstants.modelFamilyDir)
+ .appending(path: component).appending(path: versionDir)
+ .appending(path: variant)
+
+ return ModelUtilities.detectModelURL(inFolder: variantDir)
+ }
+
+ /// The effective tokenizer source: local `tokenizerFolder` if set, otherwise the
+ /// model's default HuggingFace repo ID.
+ public var tokenizerSource: String {
+ tokenizerFolder?.path ?? model.tokenizerRepo
+ }
+
+ /// Component names in the model layout.
+ public static let componentNames = [
+ "text_projector", "code_embedder", "multi_code_embedder",
+ "code_decoder", "multi_code_decoder", "speech_decoder"
+ ]
+
+ /// Version-specific directory for each component inside `modelFolder`.
+ ///
+ /// e.g. `modelFolder/qwen3_tts/code_decoder/12hz-0.6b-customvoice`
+ ///
+ /// Useful for targeted deletion or disk-size calculation of a single variant.
+ public func componentDirectories(in folder: URL? = nil) -> [URL] {
+ guard let base = folder ?? modelFolder else { return [] }
+ return Self.componentNames.map { component in
+ base
+ .appending(path: Qwen3TTSConstants.modelFamilyDir)
+ .appending(path: component)
+ .appending(path: versionDir)
+ }
+ }
+
+ /// Glob patterns used to download only the files needed for the configured variants.
+ public var downloadPatterns: [String] {
+ let components: [(String, String)] = [
+ ("text_projector", textProjectorVariant),
+ ("code_embedder", codeEmbedderVariant),
+ ("multi_code_embedder", multiCodeEmbedderVariant),
+ ("code_decoder", codeDecoderVariant),
+ ("multi_code_decoder", multiCodeDecoderVariant),
+ ("speech_decoder", speechDecoderVariant)
+ ]
+ return components.map {
+ "\(Qwen3TTSConstants.modelFamilyDir)/\($0.0)/\(versionDir)/\($0.1)/**"
+ }
+ }
+}
diff --git a/Sources/TTSKit/Qwen3TTS/Qwen3Embedders.swift b/Sources/TTSKit/Qwen3TTS/Qwen3Embedders.swift
new file mode 100644
index 00000000..9085a5e8
--- /dev/null
+++ b/Sources/TTSKit/Qwen3TTS/Qwen3Embedders.swift
@@ -0,0 +1,103 @@
+// For licensing see accompanying LICENSE.md file.
+// Copyright © 2026 Argmax, Inc. All rights reserved.
+
+import ArgmaxCore
+import CoreML
+import Foundation
+
+// MARK: - Code Embedder Implementation
+
+/// Codec-0 token embedder backed by a CoreML model.
+///
+/// Thread safety: every call builds its own input features, so one instance can
+/// serve multiple concurrent tasks.
+public class Qwen3CodeEmbedder: CodeEmbedding, @unchecked Sendable {
+ public var model: MLModel?
+
+ public init() {}
+
+ /// Loads (and compiles) the CoreML model at `url`.
+ /// In `prewarmMode` the compiled model is discarded - the load runs only for
+ /// its compilation side effect.
+ public func loadModel(at url: URL, computeUnits: MLComputeUnits, prewarmMode: Bool = false) async throws {
+ let configuration = MLModelConfiguration()
+ configuration.computeUnits = computeUnits
+ let compiled = try await MLModel.load(contentsOf: url, configuration: configuration)
+ if prewarmMode { return }
+ model = compiled
+ }
+
+ /// Embeds a single codec-0 token id, returning the embedding as a flat array.
+ public func embed(tokenId: Int32) async throws -> [FloatType] {
+ guard let loadedModel = model else {
+ throw TTSError.generationFailed("CodeEmbedder model not loaded")
+ }
+ let inputIds = try EmbedUtilities.makeInt32Array([tokenId])
+ let features = ["input_ids": MLFeatureValue(multiArray: inputIds)]
+ let provider = try MLDictionaryFeatureProvider(dictionary: features)
+ let prediction = try await loadedModel.asyncPrediction(from: provider)
+ guard let embeds = prediction.featureValue(for: "input_embeds")?.multiArrayValue else {
+ throw TTSError.generationFailed("CodeEmbedder: missing input_embeds output")
+ }
+ return EmbedUtilities.extractEmbed(from: embeds)
+ }
+
+ /// Optimised async path: feeds `[String: MLTensor]` straight to the model,
+ /// avoiding FeatureProvider boxing.
+ @available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *)
+ public func embed(tokenId: Int32) async throws -> MLTensor {
+ guard let loadedModel = model else {
+ throw TTSError.generationFailed("CodeEmbedder model not loaded")
+ }
+ let inputs = ["input_ids": MLTensor(shape: [1], scalars: [tokenId])]
+ let outputs = try await loadedModel.prediction(from: inputs)
+ guard let tensor = outputs["input_embeds"] else {
+ throw TTSError.generationFailed("CodeEmbedder: missing input_embeds tensor output")
+ }
+ return tensor
+ }
+
+ /// Drops the CoreML model reference so its memory can be reclaimed.
+ public func unloadModel() {
+ model = nil
+ }
+}
+
+// MARK: - Multi-Code Embedder Implementation
+
+/// Codec-1..15 token embedder backed by a CoreML model.
+///
+/// Thread safety: same as `Qwen3CodeEmbedder` - inputs are created per call, so
+/// one instance can serve multiple concurrent tasks.
+public class Qwen3MultiCodeEmbedder: MultiCodeEmbedding, @unchecked Sendable {
+ public var model: MLModel?
+
+ public init() {}
+
+ /// Loads (and compiles) the CoreML model at `url`.
+ /// In `prewarmMode` the compiled model is discarded - the load runs only for
+ /// its compilation side effect.
+ public func loadModel(at url: URL, computeUnits: MLComputeUnits, prewarmMode: Bool = false) async throws {
+ let configuration = MLModelConfiguration()
+ configuration.computeUnits = computeUnits
+ let compiled = try await MLModel.load(contentsOf: url, configuration: configuration)
+ if prewarmMode { return }
+ model = compiled
+ }
+
+ /// Embeds a single multi-codec token id, returning the embedding as a flat array.
+ public func embed(tokenId: Int32) async throws -> [FloatType] {
+ guard let loadedModel = model else {
+ throw TTSError.generationFailed("MultiCodeEmbedder model not loaded")
+ }
+ let inputIds = try EmbedUtilities.makeInt32Array([tokenId])
+ let features = ["input_ids": MLFeatureValue(multiArray: inputIds)]
+ let provider = try MLDictionaryFeatureProvider(dictionary: features)
+ let prediction = try await loadedModel.asyncPrediction(from: provider)
+ guard let embeds = prediction.featureValue(for: "input_embeds")?.multiArrayValue else {
+ throw TTSError.generationFailed("MultiCodeEmbedder: missing input_embeds output")
+ }
+ return EmbedUtilities.extractEmbed(from: embeds)
+ }
+
+ /// Optimised async path: feeds `[String: MLTensor]` straight to the model,
+ /// avoiding FeatureProvider boxing.
+ @available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *)
+ public func embed(tokenId: Int32) async throws -> MLTensor {
+ guard let loadedModel = model else {
+ throw TTSError.generationFailed("MultiCodeEmbedder model not loaded")
+ }
+ let inputs = ["input_ids": MLTensor(shape: [1], scalars: [tokenId])]
+ let outputs = try await loadedModel.prediction(from: inputs)
+ guard let tensor = outputs["input_embeds"] else {
+ throw TTSError.generationFailed("MultiCodeEmbedder: missing input_embeds tensor output")
+ }
+ return tensor
+ }
+
+ /// Drops the CoreML model reference so its memory can be reclaimed.
+ public func unloadModel() {
+ model = nil
+ }
+}
diff --git a/Sources/TTSKit/Qwen3TTS/Qwen3GenerateTask.swift b/Sources/TTSKit/Qwen3TTS/Qwen3GenerateTask.swift
new file mode 100644
index 00000000..240ab322
--- /dev/null
+++ b/Sources/TTSKit/Qwen3TTS/Qwen3GenerateTask.swift
@@ -0,0 +1,838 @@
+// For licensing see accompanying LICENSE.md file.
+// Copyright © 2026 Argmax, Inc. All rights reserved.
+
+import ArgmaxCore
+import CoreML
+import Foundation
+import Tokenizers
+
+// MARK: - Internal phase-result types
+
+/// Output of the tokenize phase - fed into prefill and the generation loop.
+struct TokenizeResult {
+ /// All text token ids for this chunk, in order.
+ let textTokenIds: [Int32]
+ /// `textTokenIds` minus the first element (`dropFirst()`).
+ let trailingTextTokens: [Int32]
+ /// Text projection of the first text token.
+ let firstTextEmbed: [FloatType]
+ /// First variable decoder input: `firstTextEmbed` summed with the codec-BOS embedding.
+ let variableEmbed: [FloatType]
+ /// Text projection of `Qwen3TTSConstants.textPAD`.
+ let textPadEmbed: [FloatType]
+ /// Phase timings; only `tokenize` is populated here.
+ let timings: SpeechTimings
+}
+
+/// Output of the prefill phase - fed into the generation loop.
+struct PrefillResult {
+ /// CodeDecoder KV cache populated during prefill.
+ let cdCache: KVCache
+ /// Decoder output of the final prefill step; its logits seed the first codec-0 sample.
+ let lastCdOutput: CodeDecoderOutput
+ /// Phase timings; `prefill` and `prefillTokens` are populated here.
+ let timings: SpeechTimings
+}
+
+/// Output of the autoregressive generation loop.
+struct GenerationLoopResult {
+ /// All decoded PCM samples produced by the loop, in order.
+ let audio: [Float]
+ /// Number of decoding steps completed.
+ let steps: Int
+ /// Loop-phase timings, including `timeToFirstBuffer`.
+ let timings: SpeechTimings
+}
+
+// MARK: - Qwen3GenerateTask
+
+/// Qwen3 TTS single-chunk generation task.
+///
+/// The core building block for Qwen3 TTS generation, analogous to `TranscribeTask`
+/// in WhisperKit. Each task creates its own KV caches, MLState, and sampler, then
+/// runs a complete prefill + autoregressive decode cycle.
+///
+/// Conforms to `SpeechGenerating` so it can be returned by
+/// `TTSKit.setupGenerateTask(...)` and consumed by `TTSKit`'s generic
+/// orchestration layer.
+///
+/// Thread safety: all stored properties are `let` (immutable after init). Each task
+/// owns its own sampler (derived seed) so concurrent tasks don't share RNG state.
+/// Model components are shared read-only references - `MLModel.prediction()` is
+/// thread-safe. The class is `@unchecked Sendable` to permit `open` subclassing.
+open class Qwen3GenerateTask: @unchecked Sendable, SpeechGenerating {
+ /// Model components - concrete Qwen3 types for correct async method dispatch.
+ /// Using `any Protocol` existentials would cause async extension methods to dispatch
+ /// to the protocol default (sync path) instead of the Qwen3-specific MLTensor path.
+ public let textProjector: Qwen3TextProjector
+ public let codeEmbedder: Qwen3CodeEmbedder
+ public let multiCodeEmbedder: Qwen3MultiCodeEmbedder
+ public let codeDecoder: Qwen3CodeDecoder
+ public let multiCodeDecoder: Qwen3MultiCodeDecoder
+ public let speechDecoder: Qwen3SpeechDecoder
+ public let sampler: any TokenSampling
+ public let tokenizer: any Tokenizer
+ public let suppressTokenIds: Set
+
+ /// Timings captured at model-load time (modelLoading + tokenizerLoading populated).
+ public let loadTimings: SpeechTimings
+
+ /// Progress object for tracking generation. `totalUnitCount` is set to
+ /// `maxNewTokens` at the start of generation; `completedUnitCount` is
+ /// updated after each decoding step.
+ public let progress: Progress
+
+ // MARK: - Initialization
+
+ /// Creates a generation task from already-loaded model components.
+ ///
+ /// - Parameters:
+ /// - sampler: Token sampler. Give each concurrent task its own (seeded)
+ /// sampler so RNG state is not shared across tasks.
+ /// - suppressTokenIds: Token ids excluded during codec-0 sampling.
+ /// NOTE(review): the bare `Set` type reads like a stripped generic
+ /// parameter (e.g. `Set<Int>`) - confirm against the stored property's
+ /// declared element type.
+ /// - loadTimings: Timings captured at model-load time, merged into each result.
+ /// - progress: Externally-owned progress object; a fresh one is created when `nil`.
+ public init(
+ textProjector: Qwen3TextProjector,
+ codeEmbedder: Qwen3CodeEmbedder,
+ multiCodeEmbedder: Qwen3MultiCodeEmbedder,
+ codeDecoder: Qwen3CodeDecoder,
+ multiCodeDecoder: Qwen3MultiCodeDecoder,
+ speechDecoder: Qwen3SpeechDecoder,
+ sampler: any TokenSampling,
+ tokenizer: any Tokenizer,
+ suppressTokenIds: Set,
+ loadTimings: SpeechTimings = SpeechTimings(),
+ progress: Progress? = nil
+ ) {
+ self.textProjector = textProjector
+ self.codeEmbedder = codeEmbedder
+ self.multiCodeEmbedder = multiCodeEmbedder
+ self.codeDecoder = codeDecoder
+ self.multiCodeDecoder = multiCodeDecoder
+ self.speechDecoder = speechDecoder
+ self.sampler = sampler
+ self.tokenizer = tokenizer
+ self.suppressTokenIds = suppressTokenIds
+ self.loadTimings = loadTimings
+ self.progress = progress ?? Progress()
+ }
+
+ // MARK: - SpeechGenerating defaults
+
+ /// Default voice for Qwen3 TTS. Matches `Qwen3Speaker.ryan`.
+ public var defaultVoice: String { Qwen3Speaker.ryan.rawValue }
+
+ /// Default language for Qwen3 TTS. Matches `Qwen3Language.english`.
+ public var defaultLanguage: String { Qwen3Language.english.rawValue }
+
+ // MARK: - Audio format (forwarded from speechDecoder)
+
+ /// Output sample rate (forwarded from the speech decoder).
+ public var sampleRate: Int { speechDecoder.sampleRate }
+ /// Samples produced per decoded frame (forwarded from the speech decoder).
+ public var samplesPerFrame: Int { speechDecoder.samplesPerFrame }
+ /// Minimum buffer duration (forwarded from the speech decoder).
+ public var minimumBufferDuration: TimeInterval { speechDecoder.minimumBufferDuration }
+
+ // MARK: - Run
+
+ /// Generate speech for a single text segment.
+ ///
+ /// Creates fresh KV caches, runs prefill, then autoregressive generation with
+ /// interleaved SpeechDecoder audio output. Safe to call concurrently from
+ /// multiple tasks against the same model instances.
+ ///
+ /// - Parameters:
+ /// - text: The text to synthesize.
+ /// - voice: Raw string matching `Qwen3Speaker.rawValue`; falls back to `.ryan`.
+ /// - language: Raw string matching `Qwen3Language.rawValue`; falls back to `.english`.
+ /// - options: Generation options (temperature, top-k, etc.)
+ /// - callback: Per-step callback receiving decoded audio and running timings.
+ /// `TTSProgress.stepTime` is non-nil only on the first step.
+ /// Return `false` to cancel; `nil` or `true` to continue.
+ /// - prefixCache: Optional cached prefix state to skip invariant prefill tokens.
+ /// - Returns: A `SpeechResult` containing the complete audio and timings for this chunk.
+ /// - Throws: `TTSError` on generation failure or task cancellation.
+ open func run(
+ text: String,
+ voice: String,
+ language: String,
+ options: GenerationOptions,
+ callback: SpeechCallback,
+ prefixCache: TTSPromptCache? = nil
+ ) async throws -> SpeechResult {
+ // Unknown voice/language strings silently fall back to the defaults.
+ let qwen3Speaker = Qwen3Speaker(rawValue: voice) ?? .ryan
+ let lang = Qwen3Language(rawValue: language) ?? .english
+
+ // Start from the load-time timings so the result carries model/tokenizer load costs.
+ var timings = loadTimings
+ let pipelineStart = CFAbsoluteTimeGetCurrent()
+
+ progress.totalUnitCount = Int64(options.maxNewTokens)
+ progress.completedUnitCount = 0
+
+ // Create task-local MLState for stateful decoders (nil for non-stateful)
+ let cdState = codeDecoder.makeState()
+
+ // Phase 1: Tokenize text and build initial embeddings
+ let tokenizeResult = try await tokenizeAndBuildEmbeds(text: text)
+ timings.merge(tokenizeResult.timings)
+
+ // Phase 2: Prefill the CodeDecoder with the prompt prefix
+ let prefillResult = try await prefillCodeDecoder(
+ tokenizeResult: tokenizeResult,
+ speaker: qwen3Speaker, lang: lang,
+ options: options,
+ prefixCache: prefixCache,
+ voice: voice, language: language,
+ cdState: cdState
+ )
+ timings.merge(prefillResult.timings)
+
+ // Phase 3: Autoregressive RVQ generation with interleaved audio decode
+ let loopResult = try await runGenerationLoop(
+ tokenizeResult: tokenizeResult,
+ prefillResult: prefillResult,
+ cdState: cdState,
+ options: options,
+ pipelineStart: pipelineStart,
+ callback: callback,
+ baseTimings: timings
+ )
+ timings.merge(loopResult.timings)
+ // NOTE(review): timeToFirstBuffer is taken directly rather than via merge -
+ // presumably merge accumulates values while TTFB is an absolute offset from
+ // pipelineStart; confirm SpeechTimings.merge semantics.
+ timings.timeToFirstBuffer = loopResult.timings.timeToFirstBuffer
+
+ timings.fullPipeline = CFAbsoluteTimeGetCurrent() - pipelineStart
+ // NOTE(review): despite the name, this stores the duration of the *generated*
+ // audio for this chunk - confirm the intended meaning of inputAudioSeconds.
+ timings.inputAudioSeconds = Double(loopResult.audio.count) / Double(speechDecoder.sampleRate)
+
+ // Force completion even when the loop exited early (EOS, callback cancel, caps).
+ progress.completedUnitCount = progress.totalUnitCount
+
+ let genMs = timings.decodingLoop * 1000
+ let avgMs = loopResult.steps > 0 ? genMs / Double(loopResult.steps) : 0
+ let stepsPerSec = loopResult.steps > 0 ? Double(loopResult.steps) / timings.decodingLoop : 0
+ Logging.info(
+ String(
+ format: "Generation: %d frames in %.1fms (%.1fms/step, %.1f frames/s)",
+ loopResult.steps, genMs, avgMs, stepsPerSec
+ ))
+
+ return SpeechResult(audio: loopResult.audio, timings: timings, sampleRate: speechDecoder.sampleRate)
+ }
+
+ // MARK: - Phase 1: Tokenize
+
+ /// Tokenize `text` and pre-compute the embeddings required by prefill and decoding.
+ private func tokenizeAndBuildEmbeds(
+ text: String
+ ) async throws -> TokenizeResult {
+ let phaseStart = CFAbsoluteTimeGetCurrent()
+
+ let tokenIds = tokenizer.encode(text: text).map { Int32($0) }
+ guard let firstToken = tokenIds.first else { throw TTSError.emptyText }
+
+ // The first variable decoder input combines the first text token's projection
+ // with the codec BOS embedding; the PAD projection is precomputed for later use.
+ let firstEmbed = try await textProjector.project(tokenId: firstToken)
+ let bosEmbed = try await codeEmbedder.embed(tokenId: Qwen3TTSConstants.codecBOS)
+ let combinedEmbed = EmbedUtilities.addEmbeddings(firstEmbed, bosEmbed)
+ let padEmbed = try await textProjector.project(tokenId: Qwen3TTSConstants.textPAD)
+
+ var phaseTimings = SpeechTimings()
+ phaseTimings.tokenize = CFAbsoluteTimeGetCurrent() - phaseStart
+
+ return TokenizeResult(
+ textTokenIds: tokenIds,
+ trailingTextTokens: Array(tokenIds.dropFirst()),
+ firstTextEmbed: firstEmbed,
+ variableEmbed: combinedEmbed,
+ textPadEmbed: padEmbed,
+ timings: phaseTimings
+ )
+ }
+
+ // MARK: - Phase 2: Prefill
+
+ /// Prefill the CodeDecoder KV cache with the invariant prompt prefix.
+ ///
+ /// If `prefixCache` matches the current voice/language/instruction, restores the
+ /// cached state and only decodes the variable token. Otherwise runs a full prefill.
+ private func prefillCodeDecoder(
+ tokenizeResult: TokenizeResult,
+ speaker: Qwen3Speaker,
+ lang: Qwen3Language,
+ options: GenerationOptions,
+ prefixCache: TTSPromptCache?,
+ voice: String,
+ language: String,
+ cdState: Any?
+ ) async throws -> PrefillResult {
+ let start = CFAbsoluteTimeGetCurrent()
+
+ let cdCache = try KVCache(
+ cacheDim: codeDecoder.kvCacheEmbedDim,
+ maxSeqLength: codeDecoder.kvCacheMaxSequenceLength,
+ isStateful: codeDecoder.isStateful
+ )
+
+ // Cache hit requires voice, language, AND instruction to all match the snapshot.
+ let usedCache = prefixCache?.matches(voice: voice, language: language, instruction: options.instruction) == true
+ var totalPrefillTokens: Int
+ var lastCdOutput: CodeDecoderOutput?
+
+ if usedCache, let prefixCache {
+ // Restore the KV snapshot (and MLState where available), then decode only
+ // the variable token (first text token + codec BOS).
+ cdCache.restore(from: prefixCache.kvSnapshot)
+ if let stateData = prefixCache.stateData {
+ if #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *), let mlState = cdState as? MLState {
+ mlState.restore(from: stateData)
+ }
+ }
+ totalPrefillTokens = prefixCache.prefixLength + 1
+
+ // TODO: Remove forking logic with package with min os version upgrade
+ if #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *), !options.forceLegacyEmbedPath {
+ lastCdOutput = try await codeDecoder.decode(
+ inputEmbeds: tokenizeResult.variableEmbed.asMLTensor(), cache: cdCache, state: cdState
+ )
+ } else {
+ let embedArr = try EmbedUtilities.createEmbedMLArray(tokenizeResult.variableEmbed)
+ lastCdOutput = try await codeDecoder.decode(inputEmbeds: embedArr, cache: cdCache, state: cdState)
+ }
+ } else {
+ // Cold path: build and decode every prompt-prefix embedding from scratch.
+ let embedDim = codeDecoder.embedSize
+ let combinedEmbeds = try await buildCombinedEmbeddings(
+ speaker: speaker, lang: lang,
+ instruction: options.instruction,
+ firstTextEmbed: tokenizeResult.firstTextEmbed,
+ embedDim: embedDim
+ )
+ totalPrefillTokens = combinedEmbeds.count
+
+ // TODO: Remove forking logic with package with min os version upgrade
+ if #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *), !options.forceLegacyEmbedPath {
+ for embed in combinedEmbeds {
+ lastCdOutput = try await codeDecoder.decode(inputEmbeds: embed.asMLTensor(), cache: cdCache, state: cdState)
+ }
+ } else {
+ // Legacy path: step i's key/value updates are applied to the cache just
+ // before step i+1 (no update before the first step - the cache is empty).
+ for (embedIndex, embed) in combinedEmbeds.enumerated() {
+ if embedIndex > 0, let keyUpdates = lastCdOutput?.keyCacheUpdates, let valueUpdates = lastCdOutput?.valueCacheUpdates {
+ cdCache.update(keyCacheUpdates: keyUpdates, valueCacheUpdates: valueUpdates)
+ }
+ let embedArr = try EmbedUtilities.createEmbedMLArray(embed)
+ lastCdOutput = try await codeDecoder.decode(inputEmbeds: embedArr, cache: cdCache, state: cdState)
+ }
+ }
+ }
+
+ // NOTE(review): the final step's key/value cache updates are NOT applied to
+ // cdCache here - confirm the generation loop applies them before its next
+ // legacy-path decode.
+ guard let resolvedLastCdOutput = lastCdOutput else {
+ throw TTSError.generationFailed("Prefill produced no decoder output")
+ }
+
+ var phaseTimings = SpeechTimings()
+ phaseTimings.prefill = CFAbsoluteTimeGetCurrent() - start
+ phaseTimings.prefillTokens = Double(totalPrefillTokens)
+
+ let prefillMs = phaseTimings.prefill * 1000
+ let prefillTokS = phaseTimings.prefill > 0 ? Double(totalPrefillTokens) / phaseTimings.prefill : 0
+ let cacheTag = usedCache ? " (cache hit, restored \(prefixCache?.prefixLength ?? 0) tokens)" : ""
+ Logging.info(
+ String(
+ format: "Prefill: %.1fms (%d tokens, %.1f tok/s)%@",
+ prefillMs, totalPrefillTokens, prefillTokS, cacheTag
+ ))
+
+ return PrefillResult(cdCache: cdCache, lastCdOutput: resolvedLastCdOutput, timings: phaseTimings)
+ }
+
+ // MARK: - Phase 3: Generation loop
+
+ /// Run the autoregressive RVQ generation loop, delivering audio frames via `callback`.
+ ///
+ /// `baseTimings` carries the tokenize + prefill phase timings and is used to build
+ /// accurate cumulative `SpeechProgress` values for callbacks.
+ /// Returns the assembled audio, the number of steps completed, and the loop-phase timings.
+ private func runGenerationLoop(
+ tokenizeResult: TokenizeResult,
+ prefillResult: PrefillResult,
+ cdState: Any?,
+ options: GenerationOptions,
+ pipelineStart: CFAbsoluteTime,
+ callback: SpeechCallback,
+ baseTimings: SpeechTimings
+ ) async throws -> GenerationLoopResult {
+ let cdCache = prefillResult.cdCache
+ var lastCdOutput = prefillResult.lastCdOutput
+ // Local accumulator seeded from the prefill-phase timings.
+ var timings = baseTimings
+
+ // Step cap proportional to prompt length (role prefix + text tokens);
+ // presumably allows ~8 audio frames per text token — TODO confirm ratio.
+ let roleTokenIds = tokenizer.encode(text: "<|im_start|>assistant\n").map { Int32($0) }
+ let maxStepsByPrefill = 8 * (roleTokenIds.count + tokenizeResult.textTokenIds.count)
+
+ let sdCache = try SpeechDecoderCache(
+ cacheDim: speechDecoder.kvCacheEmbedDim,
+ maxSeqLength: speechDecoder.kvCacheMaxSequenceLength,
+ hiddenDim: speechDecoder.hiddenDim,
+ hiddenContextLen: speechDecoder.hiddenContextLen
+ )
+
+ // Sample the first codec-0 token from the prefill logits before the loop runs.
+ var generatedTokens: [Int32] = []
+ var code0 = await sampler.sampleCodec0(
+ logits: lastCdOutput.logits,
+ temperature: options.temperature, topK: options.topK,
+ generatedTokens: generatedTokens,
+ repetitionPenalty: options.repetitionPenalty,
+ suppressTokenIds: suppressTokenIds
+ )
+ generatedTokens.append(code0)
+
+ var allAudio: [Float] = []
+ var stepIndex = 0
+ var firstBufferEmitted = false
+
+ // TODO: Remove forking logic with package with min os version upgrade
+ if #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *), !options.forceLegacyEmbedPath {
+ let textPadEmbedTensor: MLTensor = try await textProjector.project(tokenId: Qwen3TTSConstants.textPAD)
+
+ // Exit conditions: EOS token, CodeDecoder KV cache full, or either step cap.
+ while code0 != Qwen3TTSConstants.codecEOS
+ && !cdCache.isFull
+ && stepIndex < options.maxNewTokens
+ && stepIndex < maxStepsByPrefill
+ {
+ try Task.checkCancellation()
+ let stepStart = CFAbsoluteTimeGetCurrent()
+
+ let codeEmbedStart = CFAbsoluteTimeGetCurrent()
+ let code0EmbedTensor: MLTensor = try await codeEmbedder.embed(tokenId: code0)
+ timings.codeEmbed += CFAbsoluteTimeGetCurrent() - codeEmbedStart
+
+ let mcdStart = CFAbsoluteTimeGetCurrent()
+ guard let hiddenStatesTensor = lastCdOutput.hiddenStates as? MLTensor else {
+ throw TTSError.generationFailed("Expected MLTensor hidden states on async path")
+ }
+ let mcdResult = try await multiCodeDecoder.generateMultiCodes(
+ hiddenStatesTensor: hiddenStatesTensor,
+ code0EmbedTensor: code0EmbedTensor,
+ multiCodeEmbedder: multiCodeEmbedder,
+ sampler: sampler, options: options
+ )
+ timings.multiCodeDecoder += CFAbsoluteTimeGetCurrent() - mcdStart
+ timings.multiCodeDecoderPredictions += mcdResult.timings.multiCodeDecoderPredictions
+ timings.multiCodeDecoderSampling += mcdResult.timings.multiCodeDecoderSampling
+ timings.multiCodeDecoderEmbedding += mcdResult.timings.multiCodeDecoderEmbedding
+ timings.decodingKvCaching += mcdResult.timings.decodingKvCaching
+ timings.totalMultiCodeDecoderPredictions += mcdResult.timings.totalMultiCodeDecoderPredictions
+
+ // Full RVQ frame: codec-0 token plus the multi-code decoder's codes 1-15.
+ let rvqFrame = [code0] + mcdResult.codes
+
+ if !firstBufferEmitted {
+ // First step: give SpeechDecoder exclusive compute access for minimum TTFB.
+ // Emit the buffer immediately, then do the remaining step work (codec hidden,
+ // text projection, CodeDecoder) which only affects the *next* step's readiness.
+ let sdResult = try await speechDecoder.decodeFrameAsync(codes: rvqFrame, cache: sdCache)
+ let stepTime = CFAbsoluteTimeGetCurrent() - stepStart
+ // speechDecoder total is accounted as its prediction time here — presumably
+ // intentional since there is no overlap on the first step; confirm.
+ timings.speechDecoderPredictions += sdResult.timings.speechDecoderPredictions
+ timings.speechDecoder += sdResult.timings.speechDecoderPredictions
+ allAudio.append(contentsOf: sdResult.samples)
+
+ timings.timeToFirstBuffer = CFAbsoluteTimeGetCurrent() - pipelineStart
+ firstBufferEmitted = true
+
+ // NOTE(review): this callback passes `timings` directly, while every other
+ // callback in this function merges into a copy of `baseTimings` first.
+ // Since `timings` already starts as a copy of `baseTimings` (top of function),
+ // the merge pattern elsewhere may double-count the base values — confirm
+ // SpeechTimings.merge semantics and make the callsites consistent.
+ if callback?(
+ SpeechProgress(
+ audio: sdResult.samples,
+ timings: timings,
+ stepTime: stepTime
+ )
+ ) == false {
+ break
+ }
+
+ // Code 15 embedding uses a per-head vocabulary offset (head index 14 of codes 1-15).
+ let codecHiddenStart = CFAbsoluteTimeGetCurrent()
+ guard let lastMcdCode = mcdResult.codes.last else {
+ throw TTSError.generationFailed("Multi-code generation result has no codes")
+ }
+ let code15OffsetId = lastMcdCode + Int32(multiCodeDecoder.codecVocabSize * 14)
+ let code15EmbedTensor: MLTensor = try await multiCodeEmbedder.embed(tokenId: code15OffsetId)
+ var allCodeEmbedTensors: [MLTensor] = [code0EmbedTensor]
+ if let tensorEmbeds = mcdResult.offsetCodeEmbedTensors {
+ allCodeEmbedTensors += tensorEmbeds
+ } else {
+ allCodeEmbedTensors += mcdResult.offsetCodeEmbeds.map { $0.asMLTensor() }
+ }
+ allCodeEmbedTensors.append(code15EmbedTensor)
+ let codecHiddenTensor = EmbedUtilities.sumEmbeddings(allCodeEmbedTensors)
+ timings.codecHidden += CFAbsoluteTimeGetCurrent() - codecHiddenStart
+
+ // Text track: consume remaining trailing text tokens, then switch to PAD.
+ let textProjStart = CFAbsoluteTimeGetCurrent()
+ let textEmbedTensor: MLTensor =
+ stepIndex < tokenizeResult.trailingTextTokens.count
+ ? try await textProjector.project(tokenId: tokenizeResult.trailingTextTokens[stepIndex])
+ : textPadEmbedTensor
+ let combinedTensor = EmbedUtilities.addEmbeddings(codecHiddenTensor, textEmbedTensor)
+ timings.textProjection += CFAbsoluteTimeGetCurrent() - textProjStart
+
+ // Async path: decoder updates the KV cache internally; subtract that time
+ // from prediction and account it separately.
+ let decodingStart = CFAbsoluteTimeGetCurrent()
+ lastCdOutput = try await codeDecoder.decode(inputEmbeds: combinedTensor, cache: cdCache, state: cdState)
+ timings.decodingPredictions += CFAbsoluteTimeGetCurrent() - decodingStart - lastCdOutput.internalCacheUpdateTime
+ timings.kvCacheUpdate += lastCdOutput.internalCacheUpdateTime
+ } else {
+ // Subsequent steps: overlap SpeechDecoder with CodeDecoder for throughput
+ async let speechResult = speechDecoder.decodeFrameAsync(codes: rvqFrame, cache: sdCache)
+
+ let codecHiddenStart = CFAbsoluteTimeGetCurrent()
+ guard let lastMcdCode = mcdResult.codes.last else {
+ throw TTSError.generationFailed("Multi-code generation result has no codes")
+ }
+ let code15OffsetId = lastMcdCode + Int32(multiCodeDecoder.codecVocabSize * 14)
+ let code15EmbedTensor: MLTensor = try await multiCodeEmbedder.embed(tokenId: code15OffsetId)
+ var allCodeEmbedTensors: [MLTensor] = [code0EmbedTensor]
+ if let tensorEmbeds = mcdResult.offsetCodeEmbedTensors {
+ allCodeEmbedTensors += tensorEmbeds
+ } else {
+ allCodeEmbedTensors += mcdResult.offsetCodeEmbeds.map { $0.asMLTensor() }
+ }
+ allCodeEmbedTensors.append(code15EmbedTensor)
+ let codecHiddenTensor = EmbedUtilities.sumEmbeddings(allCodeEmbedTensors)
+ timings.codecHidden += CFAbsoluteTimeGetCurrent() - codecHiddenStart
+
+ let textProjStart = CFAbsoluteTimeGetCurrent()
+ let textEmbedTensor: MLTensor =
+ stepIndex < tokenizeResult.trailingTextTokens.count
+ ? try await textProjector.project(tokenId: tokenizeResult.trailingTextTokens[stepIndex])
+ : textPadEmbedTensor
+ let combinedTensor = EmbedUtilities.addEmbeddings(codecHiddenTensor, textEmbedTensor)
+ timings.textProjection += CFAbsoluteTimeGetCurrent() - textProjStart
+
+ let decodingStart = CFAbsoluteTimeGetCurrent()
+ lastCdOutput = try await codeDecoder.decode(inputEmbeds: combinedTensor, cache: cdCache, state: cdState)
+ timings.decodingPredictions += CFAbsoluteTimeGetCurrent() - decodingStart - lastCdOutput.internalCacheUpdateTime
+ timings.kvCacheUpdate += lastCdOutput.internalCacheUpdateTime
+
+ // Join the overlapped SpeechDecoder task before emitting this step's audio.
+ let sdResult = try await speechResult
+ timings.speechDecoderPredictions += sdResult.timings.speechDecoderPredictions
+ timings.speechDecoder += sdResult.timings.speechDecoderPredictions
+ allAudio.append(contentsOf: sdResult.samples)
+
+ if callback?(
+ SpeechProgress(
+ audio: sdResult.samples,
+ timings: {
+ var merged = baseTimings; merged.merge(timings); return merged
+ }(), stepTime: nil)) == false
+ {
+ break
+ }
+ }
+
+ // Sample the next codec-0 token from this step's CodeDecoder logits.
+ let samplingStart = CFAbsoluteTimeGetCurrent()
+ code0 = await sampler.sampleCodec0(
+ logits: lastCdOutput.logits,
+ temperature: options.temperature, topK: options.topK,
+ generatedTokens: generatedTokens,
+ repetitionPenalty: options.repetitionPenalty,
+ suppressTokenIds: suppressTokenIds
+ )
+ generatedTokens.append(code0)
+ timings.decodingSampling += CFAbsoluteTimeGetCurrent() - samplingStart
+
+ timings.decodingLoop += CFAbsoluteTimeGetCurrent() - stepStart
+ stepIndex += 1
+ progress.completedUnitCount = Int64(stepIndex)
+
+ if stepIndex == 1 || stepIndex % 10 == 0 {
+ let stepMs = (CFAbsoluteTimeGetCurrent() - stepStart) * 1000
+ Logging.debug(
+ String(
+ format: " Step %d: %.1fms (avg %.1fms/step)",
+ stepIndex, stepMs, timings.decodingLoop * 1000 / Double(stepIndex)))
+ }
+ }
+ } else {
+ // Legacy path (older OS)
+ while code0 != Qwen3TTSConstants.codecEOS && !cdCache.isFull && stepIndex < options.maxNewTokens && stepIndex < maxStepsByPrefill {
+ try Task.checkCancellation()
+ let stepStart = CFAbsoluteTimeGetCurrent()
+
+ // Legacy decoder does not update the KV cache internally; commit the
+ // previous step's key/value updates manually before decoding.
+ let cacheUpdateStart = CFAbsoluteTimeGetCurrent()
+ if let keyUpdates = lastCdOutput.keyCacheUpdates, let valueUpdates = lastCdOutput.valueCacheUpdates {
+ cdCache.update(keyCacheUpdates: keyUpdates, valueCacheUpdates: valueUpdates)
+ }
+ timings.kvCacheUpdate += CFAbsoluteTimeGetCurrent() - cacheUpdateStart
+
+ let codeEmbedStart = CFAbsoluteTimeGetCurrent()
+ let code0Embed = try await codeEmbedder.embed(tokenId: code0)
+ timings.codeEmbed += CFAbsoluteTimeGetCurrent() - codeEmbedStart
+
+ let mcdStart = CFAbsoluteTimeGetCurrent()
+ guard let hiddenStates = lastCdOutput.hiddenStates as? [FloatType] else {
+ throw TTSError.generationFailed("Expected [FloatType] hidden states on legacy path")
+ }
+ let mcdResult = try await multiCodeDecoder.generateMultiCodes(
+ hiddenStates: hiddenStates, code0Embed: code0Embed,
+ multiCodeEmbedder: multiCodeEmbedder, sampler: sampler, options: options
+ )
+ timings.multiCodeDecoder += CFAbsoluteTimeGetCurrent() - mcdStart
+ timings.multiCodeDecoderPredictions += mcdResult.timings.multiCodeDecoderPredictions
+ timings.multiCodeDecoderSampling += mcdResult.timings.multiCodeDecoderSampling
+ timings.multiCodeDecoderEmbedding += mcdResult.timings.multiCodeDecoderEmbedding
+ timings.decodingKvCaching += mcdResult.timings.decodingKvCaching
+ timings.totalMultiCodeDecoderPredictions += mcdResult.timings.totalMultiCodeDecoderPredictions
+
+ let rvqFrame = [code0] + mcdResult.codes
+
+ if !firstBufferEmitted {
+ // First step: decode serially so the first audio buffer ships as fast as possible.
+ let sdResult = try await speechDecoder.decodeFrameAsync(codes: rvqFrame, cache: sdCache)
+ timings.speechDecoderPredictions += sdResult.timings.speechDecoderPredictions
+ timings.speechDecoder += sdResult.timings.speechDecoderPredictions
+ allAudio.append(contentsOf: sdResult.samples)
+
+ timings.timeToFirstBuffer = CFAbsoluteTimeGetCurrent() - pipelineStart
+ firstBufferEmitted = true
+ let stepTime = CFAbsoluteTimeGetCurrent() - stepStart
+
+ if callback?(
+ SpeechProgress(
+ audio: sdResult.samples,
+ timings: {
+ var merged = baseTimings; merged.merge(timings); return merged
+ }(), stepTime: stepTime)) == false
+ {
+ break
+ }
+
+ let codecHiddenStart = CFAbsoluteTimeGetCurrent()
+ guard let lastMcdCode = mcdResult.codes.last else {
+ throw TTSError.generationFailed("Multi-code generation result has no codes")
+ }
+ let code15OffsetId = lastMcdCode + Int32(multiCodeDecoder.codecVocabSize * 14)
+ var allCodeEmbeds: [[FloatType]] = [code0Embed]
+ allCodeEmbeds += mcdResult.offsetCodeEmbeds
+ try await allCodeEmbeds.append(multiCodeEmbedder.embed(tokenId: code15OffsetId))
+ let codecHidden = EmbedUtilities.sumEmbeddings(allCodeEmbeds)
+ timings.codecHidden += CFAbsoluteTimeGetCurrent() - codecHiddenStart
+
+ let textProjStart = CFAbsoluteTimeGetCurrent()
+ let textTokenEmbed: [FloatType] =
+ stepIndex < tokenizeResult.trailingTextTokens.count
+ ? try await textProjector.project(tokenId: tokenizeResult.trailingTextTokens[stepIndex])
+ : tokenizeResult.textPadEmbed
+ let combinedArr = try EmbedUtilities.createEmbedMLArray(EmbedUtilities.addEmbeddings(codecHidden, textTokenEmbed))
+ timings.textProjection += CFAbsoluteTimeGetCurrent() - textProjStart
+
+ let decodingStart = CFAbsoluteTimeGetCurrent()
+ lastCdOutput = try await codeDecoder.decode(inputEmbeds: combinedArr, cache: cdCache, state: cdState)
+ timings.decodingPredictions += CFAbsoluteTimeGetCurrent() - decodingStart
+ } else {
+ // Subsequent steps: overlap SpeechDecoder with CodeDecoder, as in the async path.
+ async let speechResult = speechDecoder.decodeFrameAsync(codes: rvqFrame, cache: sdCache)
+
+ let codecHiddenStart = CFAbsoluteTimeGetCurrent()
+ guard let lastMcdCode = mcdResult.codes.last else {
+ throw TTSError.generationFailed("Multi-code generation result has no codes")
+ }
+ let code15OffsetId = lastMcdCode + Int32(multiCodeDecoder.codecVocabSize * 14)
+ var allCodeEmbeds: [[FloatType]] = [code0Embed]
+ allCodeEmbeds += mcdResult.offsetCodeEmbeds
+ try await allCodeEmbeds.append(multiCodeEmbedder.embed(tokenId: code15OffsetId))
+ let codecHidden = EmbedUtilities.sumEmbeddings(allCodeEmbeds)
+ timings.codecHidden += CFAbsoluteTimeGetCurrent() - codecHiddenStart
+
+ let textProjStart = CFAbsoluteTimeGetCurrent()
+ let textTokenEmbed: [FloatType] =
+ stepIndex < tokenizeResult.trailingTextTokens.count
+ ? try await textProjector.project(tokenId: tokenizeResult.trailingTextTokens[stepIndex])
+ : tokenizeResult.textPadEmbed
+ let combinedArr = try EmbedUtilities.createEmbedMLArray(EmbedUtilities.addEmbeddings(codecHidden, textTokenEmbed))
+ timings.textProjection += CFAbsoluteTimeGetCurrent() - textProjStart
+
+ let decodingStart = CFAbsoluteTimeGetCurrent()
+ lastCdOutput = try await codeDecoder.decode(inputEmbeds: combinedArr, cache: cdCache, state: cdState)
+ timings.decodingPredictions += CFAbsoluteTimeGetCurrent() - decodingStart
+
+ let sdResult = try await speechResult
+ timings.speechDecoderPredictions += sdResult.timings.speechDecoderPredictions
+ timings.speechDecoder += sdResult.timings.speechDecoderPredictions
+ allAudio.append(contentsOf: sdResult.samples)
+
+ if callback?(
+ SpeechProgress(
+ audio: sdResult.samples,
+ timings: {
+ var merged = baseTimings; merged.merge(timings); return merged
+ }(), stepTime: nil)) == false
+ {
+ break
+ }
+ }
+
+ let samplingStart = CFAbsoluteTimeGetCurrent()
+ code0 = await sampler.sampleCodec0(
+ logits: lastCdOutput.logits,
+ temperature: options.temperature, topK: options.topK,
+ generatedTokens: generatedTokens,
+ repetitionPenalty: options.repetitionPenalty,
+ suppressTokenIds: suppressTokenIds
+ )
+ generatedTokens.append(code0)
+ timings.decodingSampling += CFAbsoluteTimeGetCurrent() - samplingStart
+
+ timings.decodingLoop += CFAbsoluteTimeGetCurrent() - stepStart
+ stepIndex += 1
+ progress.completedUnitCount = Int64(stepIndex)
+
+ if stepIndex == 1 || stepIndex % 10 == 0 {
+ let stepMs = (CFAbsoluteTimeGetCurrent() - stepStart) * 1000
+ Logging.debug(
+ String(
+ format: " Step %d: %.1fms (avg %.1fms/step)",
+ stepIndex, stepMs, timings.decodingLoop * 1000 / Double(stepIndex)))
+ }
+ }
+ }
+
+ // NOTE(review): an early exit caused by a callback returning false falls
+ // through to the final else branch and is reported as "maxNewTokens limit",
+ // which can be misleading in logs.
+ let stopReason: String
+ if code0 == Qwen3TTSConstants.codecEOS {
+ stopReason = "EOS token"
+ } else if cdCache.isFull {
+ stopReason = "KV cache full (\(cdCache.cacheLength)/\(cdCache.maxSeqLength))"
+ } else if stepIndex >= maxStepsByPrefill {
+ stopReason = "Audio token ratio limit (\(stepIndex)/\(maxStepsByPrefill) steps)"
+ } else {
+ stopReason = "maxNewTokens limit (\(options.maxNewTokens))"
+ }
+ Logging.info("Loop stopped: \(stopReason) after \(stepIndex) steps")
+
+ timings.totalDecodingLoops = Double(stepIndex)
+ return GenerationLoopResult(audio: allAudio, steps: stepIndex, timings: timings)
+ }
+
+ // MARK: - Embedding Helpers
+
+ /// Build the full combined embedding sequence (text track + codec track) for prefill.
+ /// The returned array includes both the invariant prefix and the variable last token.
+ func buildCombinedEmbeddings(
+ speaker: Qwen3Speaker,
+ lang: Qwen3Language,
+ instruction: String?,
+ firstTextEmbed: [FloatType],
+ embedDim: Int
+ ) async throws -> [[FloatType]] {
+ let zeroCodecEmbed = EmbedUtilities.zeroEmbed(dim: embedDim)
+
+ // Optional instruction turn: text-track embeddings paired with zero codec embeddings.
+ var instructTextEmbeds: [[FloatType]] = []
+ var instructCodecEmbeds: [[FloatType]] = []
+ if let instruction, !instruction.isEmpty {
+ let instructPrompt = "<|im_start|>user\n\(instruction)<|im_end|>\n"
+ let instructTokenIds = tokenizer.encode(text: instructPrompt).map { Int32($0) }
+ for tokenId in instructTokenIds {
+ try await instructTextEmbeds.append(textProjector.project(tokenId: tokenId))
+ instructCodecEmbeds.append(zeroCodecEmbed)
+ }
+ Logging.debug("Instruction: \(instructTokenIds.count) tokens")
+ }
+
+ // Assistant role prefix on the text track.
+ let rolePrefix = "<|im_start|>assistant\n"
+ let roleTokenIds = tokenizer.encode(text: rolePrefix).map { Int32($0) }
+
+ var textTrackEmbeds: [[FloatType]] = instructTextEmbeds
+ for tokenId in roleTokenIds {
+ try await textTrackEmbeds.append(textProjector.project(tokenId: tokenId))
+ }
+ let textPadEmbed = try await textProjector.project(tokenId: Qwen3TTSConstants.textPAD)
+ let textBosEmbed = try await textProjector.project(tokenId: Qwen3TTSConstants.textBOS)
+
+ // Codec control preamble: think/lang/speaker selection followed by PAD and BOS.
+ let codecIds: [Int32] = [
+ Qwen3TTSConstants.codecThink,
+ Qwen3TTSConstants.codecThinkBos,
+ lang.tokenID,
+ Qwen3TTSConstants.codecThinkEos,
+ speaker.tokenID,
+ Qwen3TTSConstants.codecPAD,
+ Qwen3TTSConstants.codecBOS
+ ]
+ var codecTrackEmbeds: [[FloatType]] = []
+ for codecId in codecIds {
+ try await codecTrackEmbeds.append(codeEmbedder.embed(tokenId: codecId))
+ }
+
+ let numPads = codecIds.count - 2
+ // NOTE(review): the line below is corrupted — the text between "0.." and
+ // " TTSPromptCache {" (the pad-loop body, the tail of buildCombinedEmbeddings,
+ // and the full signature of the prompt-cache-building function that returns
+ // TTSPromptCache) appears to have been lost, plausibly by "<...>" angle-bracket
+ // stripping of "0..<numPads" and "-> TTSPromptCache". Recover this region from
+ // version control before relying on it; everything from here to the snapshot
+ // return belongs to the cache builder, not to buildCombinedEmbeddings.
+ for _ in 0.. TTSPromptCache {
+ let qwen3Speaker = Qwen3Speaker(rawValue: voice) ?? .ryan
+ let lang = Qwen3Language(rawValue: language) ?? .english
+ let embedDim = codeDecoder.embedSize
+
+ // Build invariant embeddings (everything except the last variable token).
+ // Use a dummy firstTextEmbed since we drop the last element.
+ let dummyFirstTextEmbed = EmbedUtilities.zeroEmbed(dim: embedDim)
+ let allEmbeds = try await buildCombinedEmbeddings(
+ speaker: qwen3Speaker,
+ lang: lang,
+ instruction: instruction,
+ firstTextEmbed: dummyFirstTextEmbed,
+ embedDim: embedDim
+ )
+ let invariantEmbeds = Array(allEmbeds.dropLast())
+
+ // Pre-initialize the MultiCodeDecoder ANE pipeline concurrently with the
+ // CodeDecoder prefill loop below. Cache build is the right place for this
+ // one-time cost: it absorbs the ~150ms without affecting TTFB, and the
+ // warmed pipeline persists for all subsequent generation calls.
+ // NOTE(review): generic arguments appear stripped here too — probably
+ // `Task<Void, Never>?`, given the non-throwing `await ... .value` below.
+ let mcdWarmupTask: Task?
+ if #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *) {
+ let mcd = multiCodeDecoder
+ mcdWarmupTask = Task { try? await mcd.prewarmInference() }
+ } else {
+ mcdWarmupTask = nil
+ }
+
+ let cdState = codeDecoder.makeState()
+ let cdCache = try KVCache(
+ cacheDim: codeDecoder.kvCacheEmbedDim,
+ maxSeqLength: codeDecoder.kvCacheMaxSequenceLength,
+ isStateful: codeDecoder.isStateful
+ )
+
+ // Prefill the CodeDecoder over all invariant embeddings to populate the KV cache.
+ var lastCdOutput: CodeDecoderOutput?
+ if #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *) {
+ // Async path: decoder updates cache internally
+ for embed in invariantEmbeds {
+ lastCdOutput = try await codeDecoder.decode(inputEmbeds: embed.asMLTensor(), cache: cdCache, state: cdState)
+ }
+ } else {
+ // Legacy path: commit each step's KV updates manually before the next decode.
+ for (embedIndex, embed) in invariantEmbeds.enumerated() {
+ if embedIndex > 0, let keyUpdates = lastCdOutput?.keyCacheUpdates, let valueUpdates = lastCdOutput?.valueCacheUpdates {
+ cdCache.update(keyCacheUpdates: keyUpdates, valueCacheUpdates: valueUpdates)
+ }
+ let embedArr = try EmbedUtilities.createEmbedMLArray(embed)
+ lastCdOutput = try await codeDecoder.decode(inputEmbeds: embedArr, cache: cdCache, state: cdState)
+ }
+ // Commit the last pending KV update so the snapshot is fully self-contained
+ if let keyUpdates = lastCdOutput?.keyCacheUpdates, let valueUpdates = lastCdOutput?.valueCacheUpdates {
+ cdCache.update(keyCacheUpdates: keyUpdates, valueCacheUpdates: valueUpdates)
+ }
+ }
+
+ // Snapshot MLState for stateful models
+ var stateData: KVStateData?
+ if #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *), let mlState = cdState as? MLState {
+ stateData = mlState.snapshot()
+ }
+
+ // Ensure warmup is done before returning - by this point the CodeDecoder
+ // loop has run (~2700ms), so this await is a no-op in practice.
+ await mcdWarmupTask?.value
+
+ Logging.info("Built prompt cache: \(invariantEmbeds.count) invariant tokens, isStateful=\(codeDecoder.isStateful) for \(voice)/\(language)")
+
+ return TTSPromptCache(
+ voice: voice,
+ language: language,
+ instruction: instruction,
+ prefixLength: invariantEmbeds.count,
+ kvSnapshot: cdCache.snapshot(),
+ stateData: stateData
+ )
+ }
+}
diff --git a/Sources/TTSKit/Qwen3TTS/Qwen3Models.swift b/Sources/TTSKit/Qwen3TTS/Qwen3Models.swift
new file mode 100644
index 00000000..001b09d0
--- /dev/null
+++ b/Sources/TTSKit/Qwen3TTS/Qwen3Models.swift
@@ -0,0 +1,174 @@
+// For licensing see accompanying LICENSE.md file.
+// Copyright © 2026 Argmax, Inc. All rights reserved.
+
+import Foundation
+
+// MARK: - Qwen3 TTS Constants
+
+/// Qwen3 TTS model-specific constants: codec token IDs, vocabulary sizes,
+/// model dimensions, cache geometry, and HuggingFace source locations.
+///
+/// These values are derived from the Qwen3 TTS architecture and must be updated
+/// if the model is retrained with a different configuration.
+///
+/// Audio-format constants (`sampleRate`, `samplesPerFrame`) are here because
+/// they match Qwen3's output format. Future model families with different audio
+/// parameters should expose them via `SpeechDecoding.sampleRate` /
+/// `SpeechDecoding.samplesPerFrame` rather than adding a second constants enum.
+public enum Qwen3TTSConstants {
+ // MARK: Codec track special tokens
+
+ public static let codecPAD: Int32 = 2148
+ public static let codecBOS: Int32 = 2149
+ public static let codecEOS: Int32 = 2150
+ public static let codecThink: Int32 = 2154
+ public static let codecThinkBos: Int32 = 2156
+ public static let codecThinkEos: Int32 = 2157
+
+ // MARK: Text track special tokens
+
+ public static let textPAD: Int32 = 151_671
+ public static let textBOS: Int32 = 151_672
+
+ // MARK: Vocabulary sizes
+
+ /// Vocabulary size for the multi-code decoder heads (codes 1-15).
+ public static let codecVocabSize: Int = 2048
+
+ // MARK: Audio format
+
+ public static let sampleRate: Int = 24000
+ public static let samplesPerFrame: Int = 1920
+
+ // MARK: Model dimensions
+
+ /// Shared embedding dimension for all projectors and embedders.
+ public static let embedDim: Int = 1024
+
+ // MARK: KV cache geometry
+
+ public static let cdCacheDim: Int = 28672
+ public static let cdMaxSeq: Int = 256
+ public static let mcdCacheDim: Int = 5120
+ public static let mcdMaxSeq: Int = 64
+ public static let sdCacheDim: Int = 8192
+ public static let sdMaxSeq: Int = 256
+ public static let sdHiddenDim: Int = 1024
+ public static let sdHiddenContextLen: Int = 16
+
+ // MARK: Default HuggingFace sources
+
+ /// Tokenizer-compatible model on HuggingFace (must contain `tokenizer.json`).
+ public static let defaultTokenizerRepo = "Qwen/Qwen3-0.6B"
+ /// HuggingFace repo hosting the pre-compiled CoreML TTS models.
+ public static let defaultModelRepo = "argmaxinc/ttskit-coreml"
+ /// Default HuggingFace Hub endpoint. Override to point at a mirror or on-premise instance.
+ public static let defaultEndpoint = "https://huggingface.co"
+ /// Intermediate subdirectory inside the model repo grouping all Qwen3 TTS components.
+ public static let modelFamilyDir = "qwen3_tts"
+ /// Default model version directory, equivalent to `TTSModelPreset.qwen3TTS_0_6B.versionDir`.
+ /// Provided as a stable exported constant; `TTSKitConfig` uses the preset's value directly.
+ public static let defaultVersionDir = "12hz-0.6b-customvoice"
+
+ // MARK: Token suppression
+
+ /// Codec-0 token IDs suppressed during sampling: [2048, 3072) except EOS (2150).
+ /// Fix: restored the `Int` generic argument on `Set` — the bare `Set` type
+ /// annotation does not compile (the element type was apparently stripped).
+ public static let suppressTokenIds: Set<Int> = {
+ var ids = Set<Int>()
+ for tokenId in 2048..<3072 where tokenId != Int(Qwen3TTSConstants.codecEOS) {
+ ids.insert(tokenId)
+ }
+ return ids
+ }()
+}
+
+// MARK: - Speaker
+
+/// Qwen3 TTS speaker voices with their corresponding codec token IDs.
+public enum Qwen3Speaker: String, CaseIterable, Sendable {
+ case ryan, aiden
+ case onoAnna = "ono-anna"
+ case sohee, eric, dylan, serena, vivian
+ case uncleFu = "uncle-fu"
+
+ /// Codec token ID placed on the codec track to select this voice.
+ public var tokenID: Int32 {
+ switch self {
+ case .ryan: 3061
+ case .aiden: 2861
+ case .onoAnna: 2873
+ case .sohee: 2864
+ case .eric: 2875
+ case .dylan: 2878
+ case .serena: 3066
+ case .vivian: 3065
+ case .uncleFu: 3010
+ }
+ }
+
+ /// Presentation name for UI; spells out hyphenated raw values.
+ public var displayName: String {
+ switch self {
+ case .ryan: "Ryan"
+ case .aiden: "Aiden"
+ case .onoAnna: "Ono Anna"
+ case .sohee: "Sohee"
+ case .eric: "Eric"
+ case .dylan: "Dylan"
+ case .serena: "Serena"
+ case .vivian: "Vivian"
+ case .uncleFu: "Uncle Fu"
+ }
+ }
+
+ /// One-line characterization of the voice's timbre and style.
+ public var voiceDescription: String {
+ switch self {
+ case .ryan: "Dynamic male voice with strong rhythmic drive."
+ case .aiden: "Sunny American male voice with a clear midrange."
+ case .onoAnna: "Playful Japanese female voice with a light, nimble timbre."
+ case .sohee: "Warm Korean female voice with rich emotion."
+ case .eric: "Lively Chengdu male voice with a slightly husky brightness."
+ case .dylan: "Youthful Beijing male voice with a clear, natural timbre."
+ case .serena: "Warm, gentle young female voice."
+ case .vivian: "Bright, slightly edgy young female voice."
+ case .uncleFu: "Seasoned male voice with a low, mellow timbre."
+ }
+ }
+
+ /// Language the speaker was trained on; output quality is best in this language.
+ public var nativeLanguage: String {
+ switch self {
+ case .ryan: "English"
+ case .aiden: "English"
+ case .onoAnna: "Japanese"
+ case .sohee: "Korean"
+ case .eric: "Chinese (Sichuan)"
+ case .dylan: "Chinese (Beijing)"
+ case .serena: "Chinese"
+ case .vivian: "Chinese"
+ case .uncleFu: "Chinese"
+ }
+ }
+}
+
+// MARK: - Language
+
+/// Qwen3 TTS supported languages with their corresponding codec token IDs.
+public enum Qwen3Language: String, CaseIterable, Sendable {
+ case english, chinese, japanese, korean, german, french, russian, portuguese, spanish, italian
+
+ /// Codec token ID placed on the codec track to select this language.
+ public var tokenID: Int32 {
+ switch self {
+ case .english: 2050
+ case .chinese: 2055
+ case .japanese: 2058
+ case .korean: 2064
+ case .german: 2053
+ case .french: 2061
+ case .russian: 2069
+ case .portuguese: 2071
+ case .spanish: 2054
+ case .italian: 2070
+ }
+ }
+}
diff --git a/Sources/TTSKit/Qwen3TTS/Qwen3MultiCodeDecoder.swift b/Sources/TTSKit/Qwen3TTS/Qwen3MultiCodeDecoder.swift
new file mode 100644
index 00000000..aa9191a1
--- /dev/null
+++ b/Sources/TTSKit/Qwen3TTS/Qwen3MultiCodeDecoder.swift
@@ -0,0 +1,465 @@
+// For licensing see accompanying LICENSE.md file.
+// Copyright © 2026 Argmax, Inc. All rights reserved.
+
+import ArgmaxCore
+import CoreML
+import Foundation
+
+// MARK: - Supporting Types
+
+/// Update and padding masks for a single MLTensor decode step.
+@available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *)
+struct MLTensorMasks {
+ /// One-hot mask selecting the KV-cache slot to write at this step (see `buildMasks`).
+ let updateMask: MLTensor
+ /// Additive attention mask: 0 for positions up to the current one, -10000 for padding.
+ let paddingMask: MLTensor
+}
+
+/// Result of a single MLTensor prediction step, including the updated KV cache.
+@available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *)
+struct MLTensorStepResult {
+ /// Raw model outputs keyed by output name (e.g. "all_logits").
+ let outputs: [String: MLTensor]
+ /// Key cache with this step's update already masked in.
+ let keyCache: MLTensor
+ /// Value cache with this step's update already masked in.
+ let valueCache: MLTensor
+ /// Next cache position (previous position + 1).
+ let cachePosition: Int32
+ /// Wall-clock time spent inside `model.prediction`.
+ let predictionTime: TimeInterval
+ /// Wall-clock time spent building masks and merging cache updates.
+ let cacheUpdateTime: TimeInterval
+}
+
+// MARK: - Implementation
+
+/// Multi-code RVQ decoder backed by a CoreML model.
+///
+/// Thread safety: mutable state (`model`, dimension properties) is set once during
+/// `loadModel()` and read-only thereafter. `MLModel.prediction()` is thread-safe.
+/// Per-call state is created locally within `generateMultiCodes()` and never stored
+/// on this shared instance.
+public class Qwen3MultiCodeDecoder: MultiCodeDecoding, @unchecked Sendable {
+ /// The loaded CoreML model; nil until `loadModel` succeeds or after `unloadModel`.
+ public var model: MLModel?
+
+ /// KV cache embedding dimension, detected from model at load time
+ public private(set) var kvCacheEmbedDim: Int = Qwen3TTSConstants.mcdCacheDim
+ /// KV cache max sequence length, detected from model at load time
+ public private(set) var kvCacheMaxSequenceLength: Int = Qwen3TTSConstants.mcdMaxSeq
+ /// Codec vocabulary size per head (codes 1-15), detected from model output
+ public private(set) var codecVocabSize: Int = Qwen3TTSConstants.codecVocabSize
+
+ /// Dimension properties start from `Qwen3TTSConstants` defaults and are
+ /// overridden by model introspection in `loadModel`.
+ public init() {}
+
+ /// Load the CoreML model and detect cache/embedding dimensions from its description.
+ ///
+ /// - Parameters:
+ ///   - url: Location of the compiled CoreML model.
+ ///   - computeUnits: Compute units to request (CPU/GPU/ANE).
+ ///   - prewarmMode: When true, the model is loaded (forcing compilation) and
+ ///     immediately discarded to free memory; `model` stays nil.
+ /// - Note: Any dimension that cannot be detected keeps its `Qwen3TTSConstants` default.
+ public func loadModel(at url: URL, computeUnits: MLComputeUnits, prewarmMode: Bool = false) async throws {
+ let modelConfig = MLModelConfiguration()
+ modelConfig.computeUnits = computeUnits
+ let loaded = try await MLModel.load(contentsOf: url, configuration: modelConfig)
+
+ // In prewarm mode, compilation is complete - discard to free memory before next model compiles
+ guard !prewarmMode else { return }
+
+ self.model = loaded
+
+ // Detect dimensions from model description
+ if let dim = ModelUtilities.getModelOutputDimension(model, named: "key_cache_updates", position: 1) {
+ self.kvCacheEmbedDim = dim
+ }
+ if let seq = ModelUtilities.getModelInputDimension(model, named: "key_padding_mask", position: 1) {
+ self.kvCacheMaxSequenceLength = seq
+ }
+ // all_logits output shape: [1, 15, codecVocabSize]
+ if let vocab = ModelUtilities.getModelOutputDimension(model, named: "all_logits", position: 2) {
+ self.codecVocabSize = vocab
+ }
+ // input_embeds shape: [1, embedDim, 1, 1]
+ if let embedDim = ModelUtilities.getModelInputDimension(model, named: "input_embeds", position: 1) {
+ self.inputEmbedDim = embedDim
+ }
+ }
+
+ /// Embedding dimension for `input_embeds`, detected from the model at load time.
+ public private(set) var inputEmbedDim: Int = Qwen3TTSConstants.embedDim
+
+ /// Whether the loaded model manages its own KV cache via MLState.
+ /// False when no model is loaded or the OS predates MLState support.
+ public var isStateful: Bool {
+ guard let model else { return false }
+ if #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *) {
+ return !model.modelDescription.stateDescriptionsByName.isEmpty
+ }
+ return false
+ }
+
+ /// Create a fresh MLState for a new RVQ frame (stateful models only).
+ /// Returns nil for non-stateful models. The caller owns the returned state.
+ // Returned as `Any?` so the protocol stays usable on OS versions without MLState.
+ public func makeState() -> Any? {
+ guard isStateful, let model else { return nil }
+ if #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *) {
+ return model.makeState()
+ }
+ return nil
+ }
+
+ /// Pre-initialize the ANE pipeline by running dummy predictions before the first
+ /// real generation step. Without this, the first `generateMultiCodes` call per
+ /// generation is ~7x slower than steady state due to lazy ANE pipeline setup.
+ ///
+ /// Run concurrently with CodeDecoder prefill so there is no net TTFB cost.
+ /// Replicates the exact loop pattern used in `generateMultiCodes` (4 passes
+ /// × 16 predictions each) to match what the ANE needs to pipeline.
+ ///
+ /// Silently returns when no model is loaded; outputs are discarded.
+ @available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *)
+ public func prewarmInference() async throws {
+ guard let model else { return }
+ let sequenceLength = kvCacheMaxSequenceLength
+ // Zero embedding stands in for a real token; only the compute graph matters here.
+ let dummyInput = MLTensor(zeros: [1, inputEmbedDim, 1, 1], scalarType: FloatType.self)
+ for _ in 0..<4 {
+ // Fresh caches per pass, mirroring one full frame of generateMultiCodes.
+ var keyCache = MLTensor(zeros: [1, kvCacheEmbedDim, 1, sequenceLength], scalarType: FloatType.self)
+ var valueCache = MLTensor(zeros: [1, kvCacheEmbedDim, 1, sequenceLength], scalarType: FloatType.self)
+ var cachePosition: Int32 = 0
+ for _ in 0..<16 {
+ let stepResult = try await predictMLTensorStep(
+ inputEmbeds: dummyInput, model: model,
+ keyCache: keyCache, valueCache: valueCache,
+ cachePosition: cachePosition, sequenceLength: sequenceLength
+ )
+ keyCache = stepResult.keyCache
+ valueCache = stepResult.valueCache
+ cachePosition = stepResult.cachePosition
+ }
+ }
+ }
+
+ /// Run one decode step through the CoreML model using MLMultiArray I/O (legacy path).
+ ///
+ /// - Parameters:
+ ///   - inputEmbeds: Single-token embedding; must be an `MLMultiArray`.
+ ///   - cache: External KV cache supplying masks, cache length, and (for
+ ///     non-stateful models) the key/value cache arrays.
+ ///   - state: An `MLState` obtained from `makeState()` for stateful models, or nil.
+ /// - Returns: All-head logits plus the key/value cache update arrays.
+ /// - Throws: `TTSError.generationFailed` when the model is not loaded, the
+ ///   input type is unsupported, or an expected output is missing.
+ public func decode(inputEmbeds: any EmbedInputType, cache: KVCache, state: Any? = nil) async throws -> MultiCodeDecoderOutput {
+ guard let model else {
+ throw TTSError.generationFailed("MultiCodeDecoder model not loaded")
+ }
+ guard let array = inputEmbeds as? MLMultiArray else {
+ throw TTSError.generationFailed("MultiCodeDecoder: unsupported embed input type \(type(of: inputEmbeds))")
+ }
+
+ var dict: [String: MLFeatureValue] = try [
+ "input_embeds": MLFeatureValue(multiArray: array),
+ "cache_length": MLFeatureValue(multiArray: cache.makeCacheLengthArray()),
+ "kv_cache_update_mask": MLFeatureValue(multiArray: cache.kvCacheUpdateMask),
+ "key_padding_mask": MLFeatureValue(multiArray: cache.keyPaddingMask)
+ ]
+
+ // Only pass external KV cache for non-stateful models
+ if !isStateful, let keyCache = cache.keyCache, let valueCache = cache.valueCache {
+ dict["key_cache"] = MLFeatureValue(multiArray: keyCache)
+ dict["value_cache"] = MLFeatureValue(multiArray: valueCache)
+ }
+
+ let input = try MLDictionaryFeatureProvider(dictionary: dict)
+
+ // Stateful prediction when an MLState was supplied; plain prediction otherwise.
+ let output: MLFeatureProvider
+ if #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *), let mlState = state as? MLState {
+ output = try await model.asyncPrediction(from: input, using: mlState)
+ } else {
+ output = try await model.asyncPrediction(from: input)
+ }
+
+ guard let keyCacheUpdates = output.featureValue(for: "key_cache_updates")?.multiArrayValue,
+ let valueCacheUpdates = output.featureValue(for: "value_cache_updates")?.multiArrayValue
+ else {
+ throw TTSError.generationFailed("MultiCodeDecoder: missing key/value cache update arrays")
+ }
+
+ // Stateful models: write the updates back into the MLState at the current position.
+ if #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *), let mlState = state as? MLState, isStateful {
+ KVCache.updateStateCache(
+ state: mlState,
+ keyCacheUpdates: keyCacheUpdates,
+ valueCacheUpdates: valueCacheUpdates,
+ position: Int(cache.cacheLength)
+ )
+ }
+
+ guard let allLogitsArray = output.featureValue(for: "all_logits")?.multiArrayValue else {
+ throw TTSError.generationFailed("MultiCodeDecoder: missing all_logits array")
+ }
+ return MultiCodeDecoderOutput(
+ allLogits: allLogitsArray,
+ keyCacheUpdates: keyCacheUpdates,
+ valueCacheUpdates: valueCacheUpdates
+ )
+ }
+
+ /// Pure MLTensor path - cache lives as tensors, updated via element-wise masking.
+ /// No MLMultiArray round-trip: prediction takes/returns MLTensor, cache updates
+ /// are lazy tensor ops, and only the logits are materialized (by the sampler).
+ /// Build the update mask and padding mask tensors for a given cache position.
+ ///
+ /// - Parameters:
+ ///   - position: Current cache write position (clamped: no write if >= sequenceLength).
+ ///   - sequenceLength: Full KV-cache sequence length; both masks are [1, sequenceLength].
+ /// - Returns: One-hot update mask at `position`, and an additive padding mask
+ ///   that is 0 up to `position` and -10000 beyond.
+ @available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *)
+ func buildMasks(position: Int, sequenceLength: Int) -> MLTensorMasks {
+ var updateData = [FloatType](repeating: 0, count: sequenceLength)
+ var paddingData = [FloatType](repeating: -10000, count: sequenceLength)
+ if position < sequenceLength { updateData[position] = 1 }
+ // Unmask every position up to and including the current one.
+ for index in 0...min(position, sequenceLength - 1) {
+ paddingData[index] = 0
+ }
+ return MLTensorMasks(
+ updateMask: MLTensor(shape: [1, sequenceLength], scalars: updateData),
+ paddingMask: MLTensor(shape: [1, sequenceLength], scalars: paddingData)
+ )
+ }
+
+ /// Run one MLTensor prediction step and return the outputs with an updated KV cache.
+ ///
+ /// Cache updates are performed in tensor space via element-wise masking -
+ /// no MLMultiArray round-trip occurs. The model must be compiled for
+ /// single-token `[1, embedDim, 1, 1]` input.
+ ///
+ /// - Returns: Model outputs, merged key/value caches, the incremented cache
+ ///   position, and per-phase timings.
+ /// - Throws: `TTSError.generationFailed` when cache-update outputs are missing.
+ @available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *)
+ func predictMLTensorStep(
+ inputEmbeds: MLTensor,
+ model: MLModel,
+ keyCache: MLTensor,
+ valueCache: MLTensor,
+ cachePosition: Int32,
+ sequenceLength: Int
+ ) async throws -> MLTensorStepResult {
+ let masks = buildMasks(position: Int(cachePosition), sequenceLength: sequenceLength)
+ let predictionStart = CFAbsoluteTimeGetCurrent()
+ let outputs = try await model.prediction(from: [
+ "input_embeds": inputEmbeds,
+ "cache_length": MLTensor(shape: [1], scalars: [cachePosition]),
+ "kv_cache_update_mask": masks.updateMask,
+ "key_padding_mask": masks.paddingMask,
+ "key_cache": keyCache,
+ "value_cache": valueCache
+ ])
+ let predictionTime = CFAbsoluteTimeGetCurrent() - predictionStart
+
+ // Merge this step's K/V update into the caches: new = old*(1-mask) + update*mask.
+ // NOTE(review): unlike buildMasks, this subscript has no bounds check - it
+ // traps if cachePosition >= sequenceLength. Callers run at most 16 steps per
+ // frame; confirm sequenceLength always covers that.
+ let cacheUpdateStart = CFAbsoluteTimeGetCurrent()
+ var positionMaskData = [FloatType](repeating: 0, count: sequenceLength)
+ positionMaskData[Int(cachePosition)] = 1
+ let positionMask = MLTensor(shape: [1, 1, 1, sequenceLength], scalars: positionMaskData)
+ let invertedMask = MLTensor(repeating: FloatType(1), shape: [1, 1, 1, sequenceLength]) - positionMask
+ guard let keyCacheOutput = outputs["key_cache_updates"],
+ let valueCacheOutput = outputs["value_cache_updates"]
+ else {
+ throw TTSError.generationFailed("MultiCodeDecoder: missing key/value cache update tensors")
+ }
+ let updatedKeyCache = keyCache * invertedMask + keyCacheOutput * positionMask
+ let updatedValueCache = valueCache * invertedMask + valueCacheOutput * positionMask
+ let cacheUpdateTime = CFAbsoluteTimeGetCurrent() - cacheUpdateStart
+
+ return MLTensorStepResult(
+ outputs: outputs,
+ keyCache: updatedKeyCache,
+ valueCache: updatedValueCache,
+ cachePosition: cachePosition + 1,
+ predictionTime: predictionTime,
+ cacheUpdateTime: cacheUpdateTime
+ )
+ }
+
+ /// Generate the 15 residual codec codes (codes 1-15) for one RVQ frame using
+ /// the pure MLTensor path: two prefill steps (hidden states, then the code-0
+ /// embedding) followed by 14 autoregressive embed→predict→sample iterations.
+ ///
+ /// - Parameters:
+ ///   - hiddenStatesTensor: Hidden state from the CodeDecoder for this frame.
+ ///   - code0EmbedTensor: Embedding of the already-sampled code 0.
+ ///   - multiCodeEmbedder: Maps offset code IDs to embeddings for the next step.
+ ///   - sampler: Per-head token sampler.
+ ///   - options: Temperature/top-k sampling options.
+ /// - Returns: The 15 sampled codes, timings, and the offset-code embeddings
+ ///   produced along the way (reused by the caller).
+ @available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *)
+ public func generateMultiCodes(
+ hiddenStatesTensor: MLTensor,
+ code0EmbedTensor: MLTensor,
+ multiCodeEmbedder: any MultiCodeEmbedding,
+ sampler: any TokenSampling,
+ options: GenerationOptions
+ ) async throws -> MultiCodeGenerationResult {
+ guard let model else { throw TTSError.generationFailed("MultiCodeDecoder model not loaded") }
+
+ // Fresh tensor-space KV cache per frame.
+ let sequenceLength = kvCacheMaxSequenceLength
+ var keyCache = MLTensor(zeros: [1, kvCacheEmbedDim, 1, sequenceLength], scalarType: FloatType.self)
+ var valueCache = MLTensor(zeros: [1, kvCacheEmbedDim, 1, sequenceLength], scalarType: FloatType.self)
+ var cachePosition: Int32 = 0
+ var timings = SpeechTimings()
+
+ // Prefill: hidden_states, then code_0_embed. First call's logits are discarded.
+ var stepResult = try await predictMLTensorStep(
+ inputEmbeds: hiddenStatesTensor, model: model,
+ keyCache: keyCache, valueCache: valueCache,
+ cachePosition: cachePosition, sequenceLength: sequenceLength
+ )
+ keyCache = stepResult.keyCache
+ valueCache = stepResult.valueCache
+ cachePosition = stepResult.cachePosition
+ timings.multiCodeDecoderPredictions += stepResult.predictionTime
+ timings.totalMultiCodeDecoderPredictions += 1
+ timings.decodingKvCaching += stepResult.cacheUpdateTime
+
+ stepResult = try await predictMLTensorStep(
+ inputEmbeds: code0EmbedTensor, model: model,
+ keyCache: keyCache, valueCache: valueCache,
+ cachePosition: cachePosition, sequenceLength: sequenceLength
+ )
+ keyCache = stepResult.keyCache
+ valueCache = stepResult.valueCache
+ cachePosition = stepResult.cachePosition
+ timings.multiCodeDecoderPredictions += stepResult.predictionTime
+ timings.totalMultiCodeDecoderPredictions += 1
+ timings.decodingKvCaching += stepResult.cacheUpdateTime
+
+ // Sample code 1 (head 0) from the second prefill step's logits.
+ var stepIndex = 0
+ let samplingStart = CFAbsoluteTimeGetCurrent()
+ guard let firstStepLogits = stepResult.outputs["all_logits"] else {
+ throw TTSError.generationFailed("MultiCodeDecoder: missing all_logits tensor on step 0")
+ }
+ let code1 = await sampler.sampleMultiHead(
+ allLogits: firstStepLogits,
+ headIndex: stepIndex,
+ temperature: options.temperature,
+ topK: options.topK
+ )
+ timings.multiCodeDecoderSampling += CFAbsoluteTimeGetCurrent() - samplingStart
+ var codes: [Int32] = [code1]
+
+ var offsetCodeEmbedTensors: [MLTensor] = []
+ offsetCodeEmbedTensors.reserveCapacity(14)
+
+ // Autoregressive loop: embed the previous code (offset into its head's
+ // vocabulary range), run one step, then sample the next head.
+ for _ in 0..<14 {
+ let embeddingStart = CFAbsoluteTimeGetCurrent()
+ guard let lastCode = codes.last else {
+ throw TTSError.generationFailed("MultiCodeDecoder: codes array is empty in MLTensor loop")
+ }
+ let offsetId = lastCode + Int32(codecVocabSize * stepIndex)
+ let embedTensor: MLTensor = try await multiCodeEmbedder.embed(tokenId: offsetId)
+ offsetCodeEmbedTensors.append(embedTensor)
+ timings.multiCodeDecoderEmbedding += CFAbsoluteTimeGetCurrent() - embeddingStart
+
+ stepResult = try await predictMLTensorStep(
+ inputEmbeds: embedTensor, model: model,
+ keyCache: keyCache, valueCache: valueCache,
+ cachePosition: cachePosition, sequenceLength: sequenceLength
+ )
+ keyCache = stepResult.keyCache
+ valueCache = stepResult.valueCache
+ cachePosition = stepResult.cachePosition
+ timings.multiCodeDecoderPredictions += stepResult.predictionTime
+ timings.totalMultiCodeDecoderPredictions += 1
+ timings.decodingKvCaching += stepResult.cacheUpdateTime
+
+ stepIndex += 1
+
+ let nextSamplingStart = CFAbsoluteTimeGetCurrent()
+ guard let nextStepLogits = stepResult.outputs["all_logits"] else {
+ throw TTSError.generationFailed("MultiCodeDecoder: missing all_logits tensor on step \(stepIndex)")
+ }
+ let code = await sampler.sampleMultiHead(
+ allLogits: nextStepLogits,
+ headIndex: stepIndex,
+ temperature: options.temperature,
+ topK: options.topK
+ )
+ timings.multiCodeDecoderSampling += CFAbsoluteTimeGetCurrent() - nextSamplingStart
+ codes.append(code)
+ }
+
+ return MultiCodeGenerationResult(codes: codes, timings: timings, offsetCodeEmbedTensors: offsetCodeEmbedTensors)
+ }
+
+ /// Legacy path - kept for OS compatibility (pre-macOS 15).
+ /// Mirrors the MLTensor overload step-for-step, but uses MLMultiArray I/O,
+ /// an external/stateful `KVCache`, and `[FloatType]` embeddings.
+ // TODO: Remove forking logic with package with min os version upgrade
+ public func generateMultiCodes(
+ hiddenStates: [FloatType],
+ code0Embed: [FloatType],
+ multiCodeEmbedder: any MultiCodeEmbedding,
+ sampler: any TokenSampling,
+ options: GenerationOptions
+ ) async throws -> MultiCodeGenerationResult {
+ var timings = SpeechTimings()
+ let mcdCache = try KVCache(
+ cacheDim: kvCacheEmbedDim,
+ maxSeqLength: kvCacheMaxSequenceLength,
+ isStateful: isStateful
+ )
+
+ // Per-frame MLState (nil for non-stateful models) and a single reusable
+ // input array to avoid per-step MLMultiArray allocation.
+ let frameState = makeState()
+ let embedDim = hiddenStates.count
+ let reuseArray = try MLMultiArray(shape: [1, NSNumber(value: embedDim), 1, 1], dataType: .float16)
+
+ // Prefill step 1: hiddenStates from CodeDecoder
+ let prefillStart = CFAbsoluteTimeGetCurrent()
+ var mcdOutput = try await decodeEmbedBuffer(hiddenStates, reuseArray: reuseArray, cache: mcdCache, state: frameState)
+ timings.multiCodeDecoderPredictions += CFAbsoluteTimeGetCurrent() - prefillStart
+ timings.totalMultiCodeDecoderPredictions += 1
+
+ let cacheUpdateStart = CFAbsoluteTimeGetCurrent()
+ guard let prefillKeyUpdates = mcdOutput.keyCacheUpdates,
+ let prefillValueUpdates = mcdOutput.valueCacheUpdates
+ else {
+ throw TTSError.generationFailed("MultiCodeDecoder: missing cache updates after prefill step 1")
+ }
+ mcdCache.update(keyCacheUpdates: prefillKeyUpdates, valueCacheUpdates: prefillValueUpdates)
+ timings.decodingKvCaching += CFAbsoluteTimeGetCurrent() - cacheUpdateStart
+
+ // Prefill step 2: code0 embedding
+ let prefill2Start = CFAbsoluteTimeGetCurrent()
+ mcdOutput = try await decodeEmbedBuffer(code0Embed, reuseArray: reuseArray, cache: mcdCache, state: frameState)
+ timings.multiCodeDecoderPredictions += CFAbsoluteTimeGetCurrent() - prefill2Start
+ timings.totalMultiCodeDecoderPredictions += 1
+
+ // Sample code 1 (head 0) from the second prefill's logits.
+ var stepIndex = 0
+ let samplingStart = CFAbsoluteTimeGetCurrent()
+ let code1 = await sampler.sampleMultiHead(
+ allLogits: mcdOutput.allLogits,
+ headIndex: stepIndex,
+ temperature: options.temperature,
+ topK: options.topK
+ )
+ timings.multiCodeDecoderSampling += CFAbsoluteTimeGetCurrent() - samplingStart
+ var codes: [Int32] = [code1]
+
+ var offsetCodeEmbeds: [[FloatType]] = []
+ offsetCodeEmbeds.reserveCapacity(14)
+
+ // Autoregressive loop: apply the previous step's cache update, embed the
+ // previous code (offset into its head's vocab range), decode, then sample.
+ for _ in 0..<14 {
+ let cacheStep = CFAbsoluteTimeGetCurrent()
+ guard let loopKeyUpdates = mcdOutput.keyCacheUpdates,
+ let loopValueUpdates = mcdOutput.valueCacheUpdates
+ else {
+ throw TTSError.generationFailed("MultiCodeDecoder: missing cache updates in generation loop")
+ }
+ mcdCache.update(keyCacheUpdates: loopKeyUpdates, valueCacheUpdates: loopValueUpdates)
+ timings.decodingKvCaching += CFAbsoluteTimeGetCurrent() - cacheStep
+
+ let embeddingStart = CFAbsoluteTimeGetCurrent()
+ guard let lastCode = codes.last else {
+ throw TTSError.generationFailed("MultiCodeDecoder: codes array is empty in legacy loop")
+ }
+ let offsetId = lastCode + Int32(codecVocabSize * stepIndex)
+ let codeEmbedBuf = try await multiCodeEmbedder.embed(tokenId: offsetId)
+ offsetCodeEmbeds.append(codeEmbedBuf)
+ timings.multiCodeDecoderEmbedding += CFAbsoluteTimeGetCurrent() - embeddingStart
+
+ let decodingStart = CFAbsoluteTimeGetCurrent()
+ mcdOutput = try await decodeEmbedBuffer(codeEmbedBuf, reuseArray: reuseArray, cache: mcdCache, state: frameState)
+ timings.multiCodeDecoderPredictions += CFAbsoluteTimeGetCurrent() - decodingStart
+ timings.totalMultiCodeDecoderPredictions += 1
+
+ stepIndex += 1
+
+ let nextSamplingStart = CFAbsoluteTimeGetCurrent()
+ let code = await sampler.sampleMultiHead(
+ allLogits: mcdOutput.allLogits,
+ headIndex: stepIndex,
+ temperature: options.temperature,
+ topK: options.topK
+ )
+ timings.multiCodeDecoderSampling += CFAbsoluteTimeGetCurrent() - nextSamplingStart
+ codes.append(code)
+ }
+
+ return MultiCodeGenerationResult(
+ codes: codes,
+ timings: timings,
+ offsetCodeEmbeds: offsetCodeEmbeds
+ )
+ }
+
+ /// Copy `embed` into a pre-allocated reuse array and run a single decode step.
+ /// Reusing `reuseArray` avoids per-step MLMultiArray allocation.
+ ///
+ /// NOTE(review): assumes `reuseArray` holds at least `embed.count` FloatType
+ /// elements - callers allocate it from the first embed's count; confirm all
+ /// embeds in a frame share that size.
+ func decodeEmbedBuffer(
+ _ embed: [FloatType],
+ reuseArray: MLMultiArray,
+ cache: KVCache,
+ state: Any?
+ ) async throws -> MultiCodeDecoderOutput {
+ let reusePointer = reuseArray.dataPointer.bindMemory(to: FloatType.self, capacity: embed.count)
+ embed.withUnsafeBufferPointer { src in
+ guard let baseAddress = src.baseAddress else { return }
+ reusePointer.update(from: baseAddress, count: embed.count)
+ }
+ return try await decode(inputEmbeds: reuseArray, cache: cache, state: state)
+ }
+
+ /// Release the CoreML model so its memory can be reclaimed.
+ public func unloadModel() {
+ model = nil
+ }
+}
diff --git a/Sources/TTSKit/Qwen3TTS/Qwen3SpeechDecoder.swift b/Sources/TTSKit/Qwen3TTS/Qwen3SpeechDecoder.swift
new file mode 100644
index 00000000..e9ae3789
--- /dev/null
+++ b/Sources/TTSKit/Qwen3TTS/Qwen3SpeechDecoder.swift
@@ -0,0 +1,187 @@
+// For licensing see accompanying LICENSE.md file.
+// Copyright © 2026 Argmax, Inc. All rights reserved.
+
+import Accelerate
+import ArgmaxCore
+import CoreML
+import Foundation
+
+// MARK: - Implementation
+
+/// RVQ-to-audio waveform decoder backed by a CoreML model.
+///
+/// Thread safety: mutable state (`model`, dimension properties) is set once during
+/// `loadModel()` and read-only thereafter. `MLModel.prediction()` is thread-safe.
+public class Qwen3SpeechDecoder: SpeechDecoding, @unchecked Sendable {
+ /// The loaded CoreML model; nil until `loadModel` succeeds or after unload.
+ public var model: MLModel?
+
+ // MARK: - Audio format
+
+ /// Output sample rate in Hz.
+ public let sampleRate: Int = Qwen3TTSConstants.sampleRate
+ /// Audio samples produced per decoded RVQ frame.
+ public let samplesPerFrame: Int = Qwen3TTSConstants.samplesPerFrame
+ /// Minimum pre-buffer: 80ms ≈ 2 audio frames at 24 kHz / 1920 spf.
+ public let minimumBufferDuration: TimeInterval = 0.08
+
+ /// Detected from model metadata at load time
+ public private(set) var hiddenContextLen: Int = Qwen3TTSConstants.sdHiddenContextLen
+ /// KV cache embedding dimension
+ public private(set) var kvCacheEmbedDim: Int = Qwen3TTSConstants.sdCacheDim
+ /// KV cache max sequence length
+ public private(set) var kvCacheMaxSequenceLength: Int = Qwen3TTSConstants.sdMaxSeq
+ /// Hidden state dimension
+ public private(set) var hiddenDim: Int = Qwen3TTSConstants.sdHiddenDim
+
+ /// Dimension properties start from `Qwen3TTSConstants` defaults and are
+ /// overridden by model introspection in `loadModel`.
+ public init() {}
+
+ /// Load the CoreML model and detect hidden/cache dimensions from its description.
+ ///
+ /// - Parameters:
+ ///   - url: Location of the compiled CoreML model.
+ ///   - computeUnits: Compute units to request (CPU/GPU/ANE).
+ ///   - prewarmMode: When true, the model is loaded (forcing compilation) and
+ ///     immediately discarded to free memory; `model` stays nil.
+ /// - Note: Any dimension that cannot be detected keeps its `Qwen3TTSConstants` default.
+ public func loadModel(at url: URL, computeUnits: MLComputeUnits, prewarmMode: Bool = false) async throws {
+ let modelConfig = MLModelConfiguration()
+ modelConfig.computeUnits = computeUnits
+ let loaded = try await MLModel.load(contentsOf: url, configuration: modelConfig)
+
+ // In prewarm mode, compilation is complete - discard to free memory before next model compiles
+ guard !prewarmMode else { return }
+
+ self.model = loaded
+
+ // Detect dimensions from model description
+ // hidden_context input shape: [1, hiddenDim, 1, contextLen]
+ if let dim = ModelUtilities.getModelInputDimension(model, named: "hidden_context", position: 1) {
+ self.hiddenDim = dim
+ }
+ if let ctxLen = ModelUtilities.getModelInputDimension(model, named: "hidden_context", position: 3) {
+ self.hiddenContextLen = ctxLen
+ }
+ // key_cache input shape: [1, cacheDim, 1, maxSeqLen]
+ if let dim = ModelUtilities.getModelInputDimension(model, named: "key_cache", position: 1) {
+ self.kvCacheEmbedDim = dim
+ }
+ if let seq = ModelUtilities.getModelInputDimension(model, named: "key_cache", position: 3) {
+ self.kvCacheMaxSequenceLength = seq
+ }
+ }
+
+ public func decodeFrame(
+ codes: [Int32],
+ cache: SpeechDecoderCache
+ ) async throws -> [Float] {
+ guard let model else {
+ throw TTSError.generationFailed("SpeechDecoder model not loaded")
+ }
+
+ let codesArr = try MLMultiArray(shape: [1, 16, 1], dataType: .int32)
+ let codesPtr = codesArr.dataPointer.bindMemory(to: Int32.self, capacity: 16)
+ for i in 0..<16 {
+ codesPtr[i] = codes[i]
+ }
+
+ guard let keyCache = cache.keyCache, let valueCache = cache.valueCache else {
+ throw TTSError.generationFailed("SpeechDecoder: KV cache not initialized")
+ }
+ let input = try MLDictionaryFeatureProvider(dictionary: [
+ "audio_codes": MLFeatureValue(multiArray: codesArr),
+ "cache_length": MLFeatureValue(multiArray: cache.makeCacheLengthArray()),
+ "key_cache": MLFeatureValue(multiArray: keyCache),
+ "value_cache": MLFeatureValue(multiArray: valueCache),
+ "kv_cache_update_mask": MLFeatureValue(multiArray: cache.kvCacheUpdateMask),
+ "key_padding_mask": MLFeatureValue(multiArray: cache.keyPaddingMask),
+ "hidden_context": MLFeatureValue(multiArray: cache.hiddenContext)
+ ])
+
+ let output = try await model.asyncPrediction(from: input)
+ cache.updateWithHiddenContext(output: output)
+
+ guard let audioArr = output.featureValue(for: "audio")?.multiArrayValue else {
+ throw TTSError.generationFailed("SpeechDecoder: missing audio output array")
+ }
+ let sampleCount = audioArr.count
+ let audioPtr = audioArr.dataPointer.bindMemory(to: FloatType.self, capacity: sampleCount)
+ var samples = [Float](repeating: 0, count: sampleCount)
+ for i in 0.. SpeechDecoderTimedResult {
+ guard let model else {
+ throw TTSError.generationFailed("SpeechDecoder model not loaded")
+ }
+
+ var timings = SpeechTimings()
+
+ // TODO: Remove forking logic with package with min os version upgrade
+ if #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *) {
+ guard let keyCacheTensor = cache.keyCacheTensor,
+ let valueCacheTensor = cache.valueCacheTensor
+ else {
+ throw TTSError.generationFailed("SpeechDecoder: KV cache tensors not initialized")
+ }
+ let inputs: [String: MLTensor] = [
+ "audio_codes": MLTensor(shape: [1, codes.count, 1], scalars: codes),
+ "cache_length": cache.cacheLengthTensor,
+ "key_cache": keyCacheTensor,
+ "value_cache": valueCacheTensor,
+ "kv_cache_update_mask": cache.kvCacheUpdateMaskTensor,
+ "key_padding_mask": cache.keyPaddingMaskTensor,
+ "hidden_context": cache.hiddenContextTensor
+ ]
+
+ let predictionStart = CFAbsoluteTimeGetCurrent()
+ let outputs = try await model.prediction(from: inputs)
+ timings.speechDecoderPredictions += CFAbsoluteTimeGetCurrent() - predictionStart
+
+ await cache.updateWithHiddenContext(tensorOutputs: outputs)
+
+ guard let audioTensor = outputs["audio"] else {
+ throw TTSError.generationFailed("SpeechDecoder: missing audio tensor output")
+ }
+ let samples = await audioTensor.toFloatArray()
+ return SpeechDecoderTimedResult(samples: samples, timings: timings)
+ } else {
+ let codesArr = try MLMultiArray(shape: [1, 16, 1], dataType: .int32)
+ let codesPtr = codesArr.dataPointer.bindMemory(to: Int32.self, capacity: 16)
+ for i in 0..<16 {
+ codesPtr[i] = codes[i]
+ }
+ guard let keyCacheFallback = cache.keyCache, let valueCacheFallback = cache.valueCache else {
+ throw TTSError.generationFailed("SpeechDecoder: KV cache not initialized (legacy path)")
+ }
+ let input = try MLDictionaryFeatureProvider(dictionary: [
+ "audio_codes": MLFeatureValue(multiArray: codesArr),
+ "cache_length": MLFeatureValue(multiArray: cache.makeCacheLengthArray()),
+ "key_cache": MLFeatureValue(multiArray: keyCacheFallback),
+ "value_cache": MLFeatureValue(multiArray: valueCacheFallback),
+ "kv_cache_update_mask": MLFeatureValue(multiArray: cache.kvCacheUpdateMask),
+ "key_padding_mask": MLFeatureValue(multiArray: cache.keyPaddingMask),
+ "hidden_context": MLFeatureValue(multiArray: cache.hiddenContext)
+ ])
+
+ let predictionStart = CFAbsoluteTimeGetCurrent()
+ let output = try await model.asyncPrediction(from: input, options: MLPredictionOptions())
+ timings.speechDecoderPredictions += CFAbsoluteTimeGetCurrent() - predictionStart
+
+ cache.updateWithHiddenContext(output: output)
+ guard let audioArr = output.featureValue(for: "audio")?.multiArrayValue else {
+ throw TTSError.generationFailed("SpeechDecoder: missing audio output array (legacy path)")
+ }
+ let sampleCount = audioArr.count
+ let audioPtr = audioArr.dataPointer.bindMemory(to: FloatType.self, capacity: sampleCount)
+ var samples = [Float](repeating: 0, count: sampleCount)
+ for i in 0.. [FloatType] {
+ guard let model else { throw TTSError.generationFailed("TextProjector model not loaded") }
+ let ids = try EmbedUtilities.makeInt32Array([tokenId])
+ let provider = try MLDictionaryFeatureProvider(dictionary: ["input_ids": MLFeatureValue(multiArray: ids)])
+ let output = try await model.asyncPrediction(from: provider)
+ guard let embedArray = output.featureValue(for: "input_embeds")?.multiArrayValue else {
+ throw TTSError.generationFailed("TextProjector: missing input_embeds output")
+ }
+ return EmbedUtilities.extractEmbed(from: embedArray)
+ }
+
+ /// Optimised async path: passes `[String: MLTensor]` directly - no FeatureProvider boxing.
+ ///
+ /// - Parameter tokenId: Text token ID to embed.
+ /// - Returns: The `input_embeds` tensor produced by the model for this token.
+ /// - Throws: `TTSError.generationFailed` when the model is not loaded or the
+ ///   expected output is missing.
+ @available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *)
+ public func project(tokenId: Int32) async throws -> MLTensor {
+ guard let model else { throw TTSError.generationFailed("TextProjector model not loaded") }
+ let outputs = try await model.prediction(from: [
+ "input_ids": MLTensor(shape: [1], scalars: [tokenId])
+ ])
+ guard let embedTensor = outputs["input_embeds"] else {
+ throw TTSError.generationFailed("TextProjector: missing input_embeds tensor output")
+ }
+ return embedTensor
+ }
+
+ /// Release the CoreML model so its memory can be reclaimed.
+ public func unloadModel() {
+ model = nil
+ }
+}
diff --git a/Sources/TTSKit/TTSKit.swift b/Sources/TTSKit/TTSKit.swift
new file mode 100644
index 00000000..d053234c
--- /dev/null
+++ b/Sources/TTSKit/TTSKit.swift
@@ -0,0 +1,1147 @@
+// For licensing see accompanying LICENSE.md file.
+// Copyright © 2026 Argmax, Inc. All rights reserved.
+
+import ArgmaxCore
+import CoreML
+import Foundation
+import Hub
+import Tokenizers
+import os
+
+// MARK: - Callback Typealiases
+
+// MARK: - TTSKit
+
+/// Generic TTS orchestrator: text chunking, concurrent generation, crossfade, and audio playback.
+///
+/// Following the WhisperKit pattern, `TTSKit` exposes each model component as a
+/// protocol-typed `public var`. Swap any component at runtime to change behaviour:
+/// ```swift
+/// let config = TTSKitConfig(load: false)
+/// let tts = try await TTSKit(config)
+/// tts.codeDecoder = MyOptimisedCodeDecoder()
+/// try await tts.loadModels()
+/// ```
+///
+/// The default implementation uses Qwen3 TTS components (`Sources/TTSKit/Qwen3TTS/`).
+/// Components from entirely different model families can be plugged in by conforming
+/// to the same component protocols, or by implementing `SpeechModel` directly.
+///
+/// `setupGenerateTask(...)` returns an `any SpeechGenerating` - override it to use a
+/// completely different generation algorithm while keeping the chunking, concurrency,
+/// crossfade, and playback orchestration provided by `generate` and `play`.
+/// Mirrors `WhisperKit.setupTranscribeTask(...)`.
+open class TTSKit: @unchecked Sendable {
+ // MARK: - Model components (protocol-typed, swappable)
+
+ /// Text token -> embedding. Conforms to `TextProjecting`.
+ public var textProjector: any TextProjecting = Qwen3TextProjector()
+ /// Codec-0 token -> embedding. Conforms to `CodeEmbedding`.
+ public var codeEmbedder: any CodeEmbedding = Qwen3CodeEmbedder()
+ /// Multi-code token -> embedding. Conforms to `MultiCodeEmbedding`.
+ public var multiCodeEmbedder: any MultiCodeEmbedding = Qwen3MultiCodeEmbedder()
+ /// Autoregressive code-0 decoder. Conforms to `CodeDecoding`.
+ public var codeDecoder: any CodeDecoding = Qwen3CodeDecoder()
+ /// Per-frame decoder. Conforms to `MultiCodeDecoding`.
+ public var multiCodeDecoder: any MultiCodeDecoding = Qwen3MultiCodeDecoder()
+ /// RVQ codes -> audio waveform. Conforms to `SpeechDecoding`.
+ public var speechDecoder: any SpeechDecoding = Qwen3SpeechDecoder()
+ /// Tokenizer. `nil` before the first `loadModels()` call or after `unloadModels()`.
+ public var tokenizer: (any Tokenizer)?
+
+ // MARK: - Model state
+
+ /// Current lifecycle state of the loaded models.
+ /// Mirrors `WhisperKit.modelState`. Transitions:
+ /// `.unloaded` -> `.downloading` -> `.downloaded` -> `.loading` -> `.loaded`
+ /// `.unloaded` -> `.prewarming` -> `.prewarmed`
+ public private(set) var modelState: ModelState = .unloaded {
+ // Observers receive (old, new) on every transition.
+ didSet { modelStateCallback?(oldValue, modelState) }
+ }
+
+ // MARK: - Configuration & timing
+
+ /// Active pipeline configuration. Mutated in place by lifecycle methods
+ /// (e.g. `setupModels()` writes the resolved `modelFolder` back here).
+ public var config: TTSKitConfig
+
+ /// Direct accessor for the resolved local model folder.
+ ///
+ /// Mirrors `WhisperKit.modelFolder`. Backed by `config.modelFolder`; set by
+ /// `setupModels()` and may also be assigned directly for offline usage.
+ public var modelFolder: URL? {
+ get { config.modelFolder }
+ set { config.modelFolder = newValue }
+ }
+
+ /// Whether to use a background `URLSession` for model downloads.
+ ///
+ /// Mirrors `WhisperKit.useBackgroundDownloadSession`. Backed by
+ /// `config.useBackgroundDownloadSession`.
+ public var useBackgroundDownloadSession: Bool {
+ get { config.useBackgroundDownloadSession }
+ set { config.useBackgroundDownloadSession = newValue }
+ }
+
+ /// Cumulative timings for the most recent pipeline run.
+ /// `modelLoading` and `tokenizerLoadTime` are populated after `loadModels()`.
+ public private(set) var currentTimings = SpeechTimings()
+
+ /// Wall-clock seconds for the most recent full model load.
+ public var modelLoadTime: TimeInterval { currentTimings.modelLoading }
+ /// Wall-clock seconds for the most recent tokenizer load.
+ public var tokenizerLoadTime: TimeInterval { currentTimings.tokenizerLoadTime }
+
+ // MARK: - Audio output
+
+ /// Audio output used by `play`.
+ /// `AudioOutput` is playback-only; WhisperKit's `AudioProcessor` is capture-only.
+ /// They serve complementary roles and do not need to be merged.
+ public let audioOutput = AudioOutput()
+
+ // MARK: - Prompt cache
+
+ /// Cached prefix state for the most recently used voice/language/instruction.
+ /// Automatically built on the first `generate` call and reused for subsequent
+ /// calls with the same parameters. Set to `nil` to force a full prefill.
+ public var promptCache: TTSPromptCache?
+
+ // MARK: - Callbacks
+
+ /// Invoked whenever `modelState` changes.
+ public var modelStateCallback: ModelStateCallback?
+
+ // MARK: - Seed
+
+ /// Base RNG seed for deterministic sampling; `nil` means non-deterministic sampling.
+ public let seed: UInt64?
+ // Incremented on every createTask() call and XORed into `seed` so each task
+ // derives a distinct sampler seed.
+ // NOTE(review): read-modify-write is not synchronized despite `@unchecked Sendable`
+ // - confirm task creation is confined to a single actor/thread.
+ private var taskCounter: UInt64 = 0
+
+ // MARK: - Initialization
+
+ /// Create a `TTSKit` instance from a `TTSKitConfig`.
+ ///
+ /// Uses the component overrides in `config` if set; otherwise instantiates the default
+ /// components for the selected model family. Components can also be replaced after init.
+ ///
+ /// - Parameter config: Pipeline configuration (model variant, paths, compute units,
+ /// component overrides, lifecycle flags).
+ /// - Throws: `TTSError` if the model family is unsupported or component instantiation fails.
+ public init(_ config: TTSKitConfig = TTSKitConfig()) async throws {
+ self.config = config
+ self.seed = config.seed
+
+ // Silence all logging unless verbose output was requested.
+ Logging.shared.logLevel = config.verbose ? config.logLevel : .none
+
+ // Install the per-family components before any model resolution happens.
+ setupPipeline(for: config.model, config: config)
+
+ // Resolve or download the model folder so that the load condition below
+ // can use `config.modelFolder != nil` as its auto-load sentinel.
+ try await setupModels(
+ model: config.model,
+ downloadBase: config.downloadBase,
+ modelRepo: config.modelRepo,
+ modelToken: config.modelToken,
+ modelFolder: config.modelFolder,
+ download: config.download,
+ endpoint: config.modelEndpoint
+ )
+
+ // Optional prewarm pass: compiles models and discards weights (state -> .prewarmed).
+ if let prewarm = config.prewarm, prewarm {
+ Logging.info("Prewarming models...")
+ try await prewarmModels()
+ }
+
+ // Load if explicitly requested, or if a local folder is now available
+ // (either provided directly or populated by setupModels above).
+ if config.load ?? (config.modelFolder != nil) {
+ Logging.info("Loading models...")
+ try await loadModels()
+ }
+ }
+
+ /// Convenience initializer mirroring `WhisperKit.init(model:modelFolder:...)`:
+ /// every configuration field is exposed as an individual parameter.
+ ///
+ /// Builds a `TTSKitConfig`, applies the optional component overrides, then
+ /// delegates to `init(_ config:)`.
+ public convenience init(
+ model: TTSModelVariant = .qwen3TTS_0_6b,
+ modelFolder: URL? = nil,
+ downloadBase: URL? = nil,
+ modelRepo: String = Qwen3TTSConstants.defaultModelRepo,
+ tokenizerFolder: URL? = nil,
+ modelToken: String? = nil,
+ computeOptions: ComputeOptions? = nil,
+ textProjector: (any TextProjecting)? = nil,
+ codeEmbedder: (any CodeEmbedding)? = nil,
+ multiCodeEmbedder: (any MultiCodeEmbedding)? = nil,
+ codeDecoder: (any CodeDecoding)? = nil,
+ multiCodeDecoder: (any MultiCodeDecoding)? = nil,
+ speechDecoder: (any SpeechDecoding)? = nil,
+ verbose: Bool = false,
+ logLevel: Logging.LogLevel = .debug,
+ prewarm: Bool? = nil,
+ load: Bool? = nil,
+ download: Bool = true,
+ useBackgroundDownloadSession: Bool = false,
+ seed: UInt64? = nil
+ ) async throws {
+ let pipelineConfig = TTSKitConfig(
+ model: model,
+ modelFolder: modelFolder,
+ downloadBase: downloadBase,
+ modelRepo: modelRepo,
+ tokenizerFolder: tokenizerFolder,
+ modelToken: modelToken,
+ computeOptions: computeOptions ?? ComputeOptions(),
+ verbose: verbose,
+ logLevel: logLevel,
+ useBackgroundDownloadSession: useBackgroundDownloadSession,
+ download: download,
+ prewarm: prewarm,
+ load: load,
+ seed: seed
+ )
+ // Component overrides ride along on the config so `setupPipeline` can apply them.
+ pipelineConfig.speechDecoder = speechDecoder
+ pipelineConfig.multiCodeDecoder = multiCodeDecoder
+ pipelineConfig.codeDecoder = codeDecoder
+ pipelineConfig.multiCodeEmbedder = multiCodeEmbedder
+ pipelineConfig.codeEmbedder = codeEmbedder
+ pipelineConfig.textProjector = textProjector
+ try await self.init(pipelineConfig)
+ }
+
+ // MARK: - Pipeline setup
+
+ /// Install the per-family component set on this instance.
+ ///
+ /// Any component supplied on `config` wins; otherwise a fresh default for the
+ /// variant's model family is created. Safe to call again to retarget the
+ /// pipeline at a different variant. Mirrors WhisperKit's encoder/decoder wiring.
+ open func setupPipeline(for variant: TTSModelVariant, config: TTSKitConfig) {
+ switch variant.family {
+ case .qwen3:
+ textProjector = config.textProjector ?? Qwen3TextProjector()
+ codeEmbedder = config.codeEmbedder ?? Qwen3CodeEmbedder()
+ multiCodeEmbedder = config.multiCodeEmbedder ?? Qwen3MultiCodeEmbedder()
+ codeDecoder = config.codeDecoder ?? Qwen3CodeDecoder()
+ multiCodeDecoder = config.multiCodeDecoder ?? Qwen3MultiCodeDecoder()
+ speechDecoder = config.speechDecoder ?? Qwen3SpeechDecoder()
+ }
+ }
+
+ // MARK: - Model Discovery
+
+ /// The model variant best suited to the current platform.
+ ///
+ /// Mirrors `WhisperKit.recommendedModels()`.
+ public static func recommendedModels() -> TTSModelVariant {
+ TTSModelVariant.defaultForCurrentPlatform
+ }
+
+ /// Fetch all available model variants from the HuggingFace Hub.
+ ///
+ /// Mirrors `WhisperKit.fetchAvailableModels(from:matching:downloadBase:token:)`.
+ /// Lists files under `qwen3_tts/**` and reports a variant as available when its
+ /// `code_decoder/<versionDir>/` folder is present; when no variant folder is found
+ /// at all, every known variant is reported so glob filtering still applies.
+ ///
+ /// - Parameters:
+ /// - repo: HuggingFace repo ID to query. Defaults to the standard Qwen3 TTS repo.
+ /// - matching: Glob patterns to filter returned variant names. Defaults to `["*"]` (all variants).
+ /// - downloadBase: Optional base URL for Hub downloads.
+ /// - token: HuggingFace API token (or set `HF_TOKEN` env var).
+ /// - endpoint: HuggingFace Hub endpoint URL.
+ /// - Returns: Display names of available model variants matching the given patterns.
+ /// Order is unspecified (results are deduplicated through a `Set`).
+ /// - Throws: `TTSError` if the Hub request fails.
+ public static func fetchAvailableModels(
+ from repo: String = Qwen3TTSConstants.defaultModelRepo,
+ matching: [String] = ["*"],
+ downloadBase: URL? = nil,
+ token: String? = nil,
+ endpoint: String = Qwen3TTSConstants.defaultEndpoint
+ ) async throws -> [String] {
+ let hubApi = HubApi(downloadBase: downloadBase, hfToken: token, endpoint: endpoint)
+ let hubRepo = Hub.Repo(id: repo, type: .models)
+ let files = try await hubApi.getFilenames(from: hubRepo, matching: ["qwen3_tts/**"])
+ var variants: [String] = []
+ for variant in TTSModelVariant.allCases {
+ let prefix = "qwen3_tts/code_decoder/\(variant.versionDir)/"
+ if files.contains(where: { $0.hasPrefix(prefix) }) {
+ variants.append(variant.displayName)
+ }
+ }
+ let allVariants = variants.isEmpty ? TTSModelVariant.allCases.map(\.displayName) : variants
+ // Fix: the original declared `var filteredVariants: Set = []` with no element
+ // type, which is not inferable from the initializer. Spell the type out and
+ // accumulate in place with `formUnion` instead of rebinding via `union`.
+ var filteredVariants = Set<String>()
+ for glob in matching {
+ filteredVariants.formUnion(allVariants.matching(glob: glob))
+ }
+ return Array(filteredVariants)
+ }
+
+ // MARK: - Download
+
+ /// Download models for a specific variant from HuggingFace Hub.
+ ///
+ /// Mirrors `WhisperKit.download(variant:downloadBase:from:token:progressCallback:)`.
+ /// This overload packs the individual parameters into a `TTSKitConfig` and
+ /// forwards to `download(config:progressCallback:)`.
+ ///
+ /// - Parameters:
+ /// - variant: The model variant to download.
+ /// - downloadBase: Base URL for the local cache. Defaults to the Hub library default.
+ /// - useBackgroundSession: Use a background `URLSession` for the download.
+ /// - repo: HuggingFace repo ID. Defaults to the standard Qwen3 TTS repo.
+ /// - token: HuggingFace API token (or set `HF_TOKEN` env var).
+ /// - endpoint: HuggingFace Hub endpoint URL.
+ /// - revision: Specific git revision (commit SHA, tag, or branch) to download.
+ /// - additionalPatterns: Extra glob patterns to include alongside the default component patterns.
+ /// - progressCallback: Optional closure receiving download progress updates.
+ /// - Returns: Local URL of the downloaded model folder.
+ /// - Throws: `TTSError` if the Hub download fails.
+ open class func download(
+ variant: TTSModelVariant = .defaultForCurrentPlatform,
+ downloadBase: URL? = nil,
+ useBackgroundSession: Bool = false,
+ from repo: String = Qwen3TTSConstants.defaultModelRepo,
+ token: String? = nil,
+ endpoint: String = Qwen3TTSConstants.defaultEndpoint,
+ revision: String? = nil,
+ additionalPatterns: [String] = [],
+ progressCallback: (@Sendable (Progress) -> Void)? = nil
+ ) async throws -> URL {
+ let downloadConfig = TTSKitConfig(
+ model: variant,
+ downloadBase: downloadBase,
+ modelRepo: repo,
+ modelToken: token,
+ modelEndpoint: endpoint,
+ downloadRevision: revision,
+ downloadAdditionalPatterns: additionalPatterns,
+ useBackgroundDownloadSession: useBackgroundSession
+ )
+ return try await download(config: downloadConfig, progressCallback: progressCallback)
+ }
+
+ /// Download models using a full `TTSKitConfig`.
+ ///
+ /// Only files matching the configured component variants (plus any additional
+ /// patterns) are fetched; the Hub library caches them locally.
+ ///
+ /// - Parameters:
+ /// - config: Pipeline configuration containing `modelRepo`, `modelToken`,
+ /// `downloadRevision`, `downloadAdditionalPatterns`, and variant settings.
+ /// - progressCallback: Optional closure receiving download progress updates.
+ /// - Returns: Local URL of the downloaded model folder.
+ /// - Throws: `TTSError.generationFailed` wrapping any Hub error.
+ open class func download(
+ config: TTSKitConfig = TTSKitConfig(),
+ progressCallback: (@Sendable (Progress) -> Void)? = nil
+ ) async throws -> URL {
+ let api = HubApi(downloadBase: config.downloadBase, hfToken: config.modelToken, endpoint: config.modelEndpoint)
+ let repo = Hub.Repo(id: config.modelRepo, type: .models)
+ let globs = config.downloadPatterns + config.downloadAdditionalPatterns
+
+ do {
+ return try await api.snapshot(
+ from: repo,
+ revision: config.downloadRevision ?? "main",
+ matching: globs
+ ) { progress in
+ progressCallback?(progress)
+ }
+ } catch {
+ throw TTSError.generationFailed(
+ "Failed to download models from \(config.modelRepo). Check that the repo exists and you have access. Error: \(error.localizedDescription)"
+ )
+ }
+ }
+
+ // MARK: - Model lifecycle
+
+ /// Resolve the local model folder, downloading from HuggingFace Hub if needed.
+ ///
+ /// Mirrors `WhisperKit.setupModels(model:downloadBase:modelRepo:...)`. Populates
+ /// `config.modelFolder` so that `loadModels()` can be called immediately after.
+ /// Separated from `loadModels()` so callers can call setup once and load separately,
+ /// or override just the resolution logic without touching the load path.
+ ///
+ /// - Parameters:
+ /// - model: Model variant to download. `nil` uses `config.model`.
+ /// - downloadBase: Base URL for Hub cache. `nil` uses the Hub library default.
+ /// - modelRepo: HuggingFace repo ID. `nil` uses `config.modelRepo`.
+ /// - modelToken: HuggingFace API token. `nil` uses `config.modelToken`.
+ /// - modelFolder: Explicit local folder URL. When non-nil the download is skipped.
+ /// - download: When `true` and `modelFolder` is nil, download from the resolved repo.
+ /// - endpoint: HuggingFace Hub endpoint URL. Defaults to `Qwen3TTSConstants.defaultEndpoint`.
+ /// - Throws: `TTSError` if the download fails or the model folder cannot be resolved.
+ open func setupModels(
+ model: TTSModelVariant? = nil,
+ downloadBase: URL? = nil,
+ modelRepo: String? = nil,
+ modelToken: String? = nil,
+ modelFolder: URL? = nil,
+ download: Bool,
+ endpoint: String = Qwen3TTSConstants.defaultEndpoint
+ ) async throws {
+ if let folder = modelFolder {
+ // Explicit local folder: trust it as-is, no network access.
+ config.modelFolder = folder
+ } else if download {
+ // Per-call arguments override the stored config for this download only.
+ let resolvedModel = model ?? config.model
+ let resolvedRepo = modelRepo ?? config.modelRepo
+ let resolvedToken = modelToken ?? config.modelToken
+
+ let downloadConfig = TTSKitConfig(
+ model: resolvedModel,
+ downloadBase: downloadBase ?? config.downloadBase,
+ modelRepo: resolvedRepo,
+ modelToken: resolvedToken,
+ modelEndpoint: endpoint,
+ downloadRevision: config.downloadRevision,
+ downloadAdditionalPatterns: config.downloadAdditionalPatterns,
+ useBackgroundDownloadSession: config.useBackgroundDownloadSession
+ )
+
+ // State machine: .downloading -> .downloaded on success, back to .unloaded on failure.
+ modelState = .downloading
+ Logging.info("Downloading models from \(resolvedRepo)...")
+ do {
+ let folder = try await TTSKit.download(config: downloadConfig) { progress in
+ let percent = Int(progress.fractionCompleted * 100)
+ Logging.debug(" Download: \(percent)%")
+ }
+ config.modelFolder = folder
+ modelState = .downloaded
+ Logging.info("Models cached at \(folder.path)")
+ } catch {
+ // Re-wrap so callers see a TTSError; state rolls back for retry.
+ modelState = .unloaded
+ throw TTSError.modelNotFound(
+ "Model download failed. Check the repo name and try again. Error: \(error)"
+ )
+ }
+ }
+ // When `modelFolder` is nil and `download` is false, `config.modelFolder` is left
+ // unchanged (possibly nil); a later `loadModels()` then throws `TTSError.modelNotFound`.
+ }
+
+ /// Compile every CoreML model one at a time, then drop the weights.
+ ///
+ /// Sequential compilation keeps peak memory low; run this before the first
+ /// `loadModels()` or after a model update. Mirrors `WhisperKit.prewarmModels()`.
+ /// Equivalent to `loadModels(prewarmMode: true)`.
+ open func prewarmModels() async throws {
+ try await loadModels(prewarmMode: true)
+ }
+
+ /// Load all models and the tokenizer.
+ ///
+ /// Expects `config.modelFolder` to be set (call `setupModels` first if needed).
+ /// Mirrors `WhisperKit.loadModels(prewarmMode:)`.
+ ///
+ /// - Parameter prewarmMode: When `true`, compile models one at a time and discard weights
+ /// to limit peak memory (prewarm). When `false` (default), load all concurrently.
+ /// - Throws: `TTSError` if model compilation or tokenizer loading fails.
+ open func loadModels(prewarmMode: Bool = false) async throws {
+ modelState = prewarmMode ? .prewarming : .loading
+
+ // Per-component compute-unit selections from the active configuration.
+ let embedUnits = config.computeOptions.embedderComputeUnits
+ let cdUnits = config.computeOptions.codeDecoderComputeUnits
+ let mcdUnits = config.computeOptions.multiCodeDecoderComputeUnits
+ let sdUnits = config.computeOptions.speechDecoderComputeUnits
+
+ guard let modelFolder = config.modelFolder,
+ FileManager.default.fileExists(atPath: modelFolder.path)
+ else {
+ modelState = .unloaded
+ throw TTSError.modelNotFound(config.modelFolder?.path ?? "")
+ }
+
+ // Resolve all six component URLs. A nil result means the .mlmodelc bundle is
+ // missing from disk - surface this immediately rather than crashing later.
+ func requireURL(_ component: String, _ variant: String) throws -> URL {
+ guard let url = config.modelURL(component: component, variant: variant) else {
+ throw TTSError.invalidConfiguration(
+ "No .mlmodelc found at \(component)/\(config.versionDir)/\(variant) inside \(modelFolder.path)."
+ )
+ }
+ return url
+ }
+ let tpURL = try requireURL("text_projector", config.textProjectorVariant)
+ let ceURL = try requireURL("code_embedder", config.codeEmbedderVariant)
+ let mceURL = try requireURL("multi_code_embedder", config.multiCodeEmbedderVariant)
+ let cdURL = try requireURL("code_decoder", config.codeDecoderVariant)
+ let mcdURL = try requireURL("multi_code_decoder", config.multiCodeDecoderVariant)
+ let sdURL = try requireURL("speech_decoder", config.speechDecoderVariant)
+
+ // Load tokenizer (skipped in prewarm - only CoreML compilation needed).
+ if !prewarmMode {
+ try await loadTokenizerIfNeeded()
+ }
+
+ // Load the six CoreML models.
+ // Prewarm: sequential to serialize compilation -> lower peak memory.
+ // Normal: concurrent since compiled artifacts are already cached.
+ let modelLoadStart = CFAbsoluteTimeGetCurrent()
+
+ if prewarmMode {
+ Logging.info("Prewarming 6 CoreML models sequentially (serializing compilation)...")
+ try await textProjector.loadModel(at: tpURL, computeUnits: embedUnits, prewarmMode: true)
+ try await codeEmbedder.loadModel(at: ceURL, computeUnits: embedUnits, prewarmMode: true)
+ try await multiCodeEmbedder.loadModel(at: mceURL, computeUnits: embedUnits, prewarmMode: true)
+ try await codeDecoder.loadModel(at: cdURL, computeUnits: cdUnits, prewarmMode: true)
+ try await multiCodeDecoder.loadModel(at: mcdURL, computeUnits: mcdUnits, prewarmMode: true)
+ try await speechDecoder.loadModel(at: sdURL, computeUnits: sdUnits, prewarmMode: true)
+ Logging.info(String(format: "Prewarm complete in %.2fs", CFAbsoluteTimeGetCurrent() - modelLoadStart))
+ // NOTE(review): the prewarm path logs its duration but never writes it to
+ // `currentTimings.modelLoading` (only the normal path below does) - confirm intended.
+ modelState = .prewarmed
+ } else {
+ Logging.info("Loading 6 CoreML models concurrently...")
+ Logging.debug(" TextProjector: \(tpURL.lastPathComponent) compute: \(embedUnits.description)")
+ Logging.debug(" CodeEmbedder: \(ceURL.lastPathComponent) compute: \(embedUnits.description)")
+ Logging.debug(" MultiCodeEmbedder: \(mceURL.lastPathComponent) compute: \(embedUnits.description)")
+ Logging.debug(" CodeDecoder: \(cdURL.lastPathComponent) (\(config.codeDecoderVariant), compute: \(cdUnits.description))")
+ Logging.debug(" MultiCodeDecoder: \(mcdURL.lastPathComponent) (\(config.multiCodeDecoderVariant), compute: \(mcdUnits.description))")
+ Logging.debug(" SpeechDecoder: \(sdURL.lastPathComponent) (\(config.speechDecoderVariant), compute: \(sdUnits.description))")
+
+ // Structured concurrency: the six loads run in parallel; a thrown error
+ // from any child propagates out of the awaited tuple.
+ async let loadTP: Void = textProjector.loadModel(at: tpURL, computeUnits: embedUnits)
+ async let loadCE: Void = codeEmbedder.loadModel(at: ceURL, computeUnits: embedUnits)
+ async let loadMCE: Void = multiCodeEmbedder.loadModel(at: mceURL, computeUnits: embedUnits)
+ async let loadCD: Void = codeDecoder.loadModel(at: cdURL, computeUnits: cdUnits)
+ async let loadMCD: Void = multiCodeDecoder.loadModel(at: mcdURL, computeUnits: mcdUnits)
+ async let loadSD: Void = speechDecoder.loadModel(at: sdURL, computeUnits: sdUnits)
+ _ = try await (loadTP, loadCE, loadMCE, loadCD, loadMCD, loadSD)
+
+ currentTimings.modelLoading = CFAbsoluteTimeGetCurrent() - modelLoadStart
+
+ // Sync audio output sample rate to the loaded speech decoder.
+ audioOutput.configure(sampleRate: speechDecoder.sampleRate)
+
+ Logging.info(String(format: "Total model load: %.2fs", modelLoadTime))
+ modelState = .loaded
+ }
+ }
+
+ /// Load the tokenizer unless one is already present.
+ ///
+ /// Mirrors `WhisperKit.loadTokenizerIfNeeded()`. A non-nil `tokenizer` short-circuits,
+ /// so repeated `loadModels()` calls do no redundant network or file-system work.
+ open func loadTokenizerIfNeeded() async throws {
+ if tokenizer != nil {
+ Logging.debug("Tokenizer already loaded, skipping")
+ return
+ }
+ tokenizer = try await loadTokenizer()
+ }
+
+ /// Load the tokenizer from `config.tokenizerSource`.
+ ///
+ /// Prefers a local `tokenizer.json` under the source path; otherwise falls back
+ /// to a Hugging Face Hub download. Records the elapsed time in
+ /// `currentTimings.tokenizerLoadTime`.
+ ///
+ /// Override to plug in a custom tokenizer loading strategy (e.g. fully offline
+ /// from a bundled path) without touching the rest of `loadModels()`.
+ open func loadTokenizer() async throws -> any Tokenizer {
+ let loadStart = CFAbsoluteTimeGetCurrent()
+ Logging.info("Loading tokenizer from \(config.tokenizerSource)...")
+ let folderURL = URL(fileURLWithPath: config.tokenizerSource)
+ let localJSONPath = folderURL.appending(path: "tokenizer.json").path
+ let loaded: any Tokenizer
+ if FileManager.default.fileExists(atPath: localJSONPath) {
+ loaded = try await AutoTokenizer.from(modelFolder: folderURL)
+ } else {
+ loaded = try await AutoTokenizer.from(pretrained: config.tokenizerSource)
+ }
+ currentTimings.tokenizerLoadTime = CFAbsoluteTimeGetCurrent() - loadStart
+ Logging.info(String(format: "Tokenizer loaded in %.2fs", tokenizerLoadTime))
+ return loaded
+ }
+
+ /// Release all model weights and the tokenizer from memory.
+ ///
+ /// Mirrors `WhisperKit.unloadModels()`. Observers see `.unloading` first, then
+ /// `.unloaded` once every component has dropped its weights.
+ open func unloadModels() async {
+ modelState = .unloading
+ // Unload in reverse pipeline order; each component is independent.
+ speechDecoder.unloadModel()
+ multiCodeDecoder.unloadModel()
+ codeDecoder.unloadModel()
+ multiCodeEmbedder.unloadModel()
+ codeEmbedder.unloadModel()
+ textProjector.unloadModel()
+ tokenizer = nil
+ modelState = .unloaded
+ Logging.info("Unloaded all models")
+ }
+
+ /// Discard all accumulated timing statistics.
+ ///
+ /// Mirrors `WhisperKit.clearState()`. Use between generation runs for fresh
+ /// per-run timing data without triggering a model reload.
+ open func clearState() {
+ currentTimings = .init()
+ }
+
+ deinit {
+ // Capture list takes `audioOutput` by value so the async cleanup never
+ // references `self` after deallocation; playback is stopped without
+ // waiting for queued audio to finish.
+ Task { [audioOutput] in
+ await audioOutput.stopPlayback(waitForCompletion: false)
+ }
+ }
+
+ /// Register a custom log sink for all `Logging` output from TTSKit.
+ ///
+ /// Mirrors `WhisperKit.loggingCallback(_:)`. Pass `nil` to restore the default
+ /// print-based logger.
+ open func loggingCallback(_ callback: Logging.LoggingCallback?) {
+ // Stored on the shared `Logging` singleton, so this affects every TTSKit
+ // instance in the process, not just this one.
+ Logging.shared.loggingCallback = callback
+ }
+
+ // MARK: - Prompt cache management
+
+ /// Build and store a prompt cache for the given voice/language/instruction combination.
+ ///
+ /// Pre-computes the invariant prefix embeddings and prefills them through the
+ /// CodeDecoder, producing a reusable cache that eliminates most of the prefill
+ /// cost on subsequent `generate` calls. The result is kept on `self.promptCache`
+ /// for automatic reuse. Delegates to `Qwen3GenerateTask.buildPromptCache` on the
+ /// task produced by `setupGenerateTask(...)`.
+ ///
+ /// - Parameters:
+ /// - voice: Voice/speaker identifier. `nil` uses the model's `defaultVoice`.
+ /// - language: Language identifier. `nil` uses the model's `defaultLanguage`.
+ /// - instruction: Optional style instruction prepended to the TTS prompt.
+ /// - Returns: The built `TTSPromptCache` for reuse in later `generate` calls.
+ /// - Throws: `TTSError` if the model is not loaded or prompt caching is unsupported.
+ @discardableResult
+ open func buildPromptCache(
+ voice: String? = nil,
+ language: String? = nil,
+ instruction: String? = nil
+ ) async throws -> TTSPromptCache {
+ let generateTask = try createTask()
+ guard let qwen3Task = generateTask as? Qwen3GenerateTask else {
+ throw TTSError.generationFailed("Prompt caching is not supported by this model family.")
+ }
+ let cache = try await qwen3Task.buildPromptCache(
+ voice: voice ?? generateTask.defaultVoice,
+ language: language ?? generateTask.defaultLanguage,
+ instruction: instruction
+ )
+ promptCache = cache
+ return cache
+ }
+
+ /// Persist the current prompt cache under the model's embeddings directory.
+ ///
+ /// Writes to `<modelFolder>/embeddings/<cacheFileName>`. A `nil` `promptCache`
+ /// is a silent no-op.
+ public func savePromptCache() throws {
+ guard let cache = promptCache else { return }
+ guard let destination = promptCacheURL(for: cache) else {
+ throw TTSError.generationFailed("Cannot determine prompt cache path (modelFolder not set)")
+ }
+ try cache.save(to: destination)
+ Logging.info("Saved prompt cache to \(destination.path)")
+ }
+
+ /// Load a previously saved prompt cache for the given parameters, if one exists.
+ ///
+ /// Returns `nil` when no cached file is found or loading fails. On success the
+ /// cache is also stored on `self.promptCache` for automatic reuse.
+ ///
+ /// - Parameters:
+ /// - voice: Voice/speaker identifier.
+ /// - language: Language identifier.
+ /// - instruction: Optional style instruction.
+ /// - Returns: The loaded cache, or `nil` if not found.
+ @discardableResult
+ public func loadPromptCache(
+ voice: String,
+ language: String,
+ instruction: String? = nil
+ ) -> TTSPromptCache? {
+ // This throwaway cache value exists only to derive the on-disk file name
+ // for the voice/language/instruction combination.
+ let emptySnapshot = KVCacheSnapshot(
+ isStateful: false, cacheDim: 0, maxSeqLength: 0, cacheLength: 0,
+ keyCacheData: Data(), valueCacheData: Data(),
+ updateMaskData: Data(), paddingMaskData: Data()
+ )
+ let probe = TTSPromptCache(
+ voice: voice, language: language, instruction: instruction,
+ prefixLength: 0,
+ kvSnapshot: emptySnapshot,
+ stateData: nil
+ )
+ guard let url = promptCacheURL(for: probe), FileManager.default.fileExists(atPath: url.path) else {
+ return nil
+ }
+ do {
+ let loaded = try TTSPromptCache.load(from: url)
+ promptCache = loaded
+ Logging.info("Loaded prompt cache from \(url.path)")
+ return loaded
+ } catch {
+ Logging.error("Failed to load prompt cache: \(error)")
+ return nil
+ }
+ }
+
+ /// On-disk location for a prompt cache: `<modelFolder>/embeddings/<cacheFileName>`.
+ /// Returns `nil` when no model folder has been resolved yet.
+ private func promptCacheURL(for cache: TTSPromptCache) -> URL? {
+ guard let root = config.modelFolder else { return nil }
+ return root
+ .appendingPathComponent("embeddings")
+ .appendingPathComponent(cache.cacheFileName)
+ }
+
+ // MARK: - Task factory
+
+ /// Build the speech-generation task for the configured model family.
+ /// Subclasses may override to substitute a custom generation algorithm.
+ ///
+ /// Mirrors `WhisperKit.setupTranscribeTask(...)`: model-agnostic inputs arrive as
+ /// parameters while model-specific components come from `self` (installed by
+ /// `setupPipeline`).
+ open func setupGenerateTask(
+ currentTimings: SpeechTimings,
+ progress: Progress,
+ tokenizer: any Tokenizer,
+ sampler: any TokenSampling
+ ) throws -> any SpeechGenerating {
+ switch config.model.family {
+ case .qwen3:
+ // All six components must be the Qwen3 concrete types; one guard keeps
+ // the failure mode (and its message) uniform across the casts.
+ guard let tp = textProjector as? Qwen3TextProjector,
+ let ce = codeEmbedder as? Qwen3CodeEmbedder,
+ let mce = multiCodeEmbedder as? Qwen3MultiCodeEmbedder,
+ let cd = codeDecoder as? Qwen3CodeDecoder,
+ let mcd = multiCodeDecoder as? Qwen3MultiCodeDecoder,
+ let sd = speechDecoder as? Qwen3SpeechDecoder
+ else {
+ throw TTSError.generationFailed("Qwen3 model family requires Qwen3-specific model components")
+ }
+ return Qwen3GenerateTask(
+ textProjector: tp,
+ codeEmbedder: ce,
+ multiCodeEmbedder: mce,
+ codeDecoder: cd,
+ multiCodeDecoder: mcd,
+ speechDecoder: sd,
+ sampler: sampler,
+ tokenizer: tokenizer,
+ suppressTokenIds: Qwen3TTSConstants.suppressTokenIds,
+ loadTimings: currentTimings,
+ progress: progress
+ )
+ }
+ }
+
+ /// Create a fresh generation task with the guard/seed/counter boilerplate.
+ ///
+ /// Each call yields an independent task with its own derived sampler seed and
+ /// per-task buffers; construction itself is delegated to `setupGenerateTask(...)`.
+ open func createTask(progress: Progress? = nil) throws -> any SpeechGenerating {
+ guard let tokenizer else {
+ throw TTSError.tokenizerUnavailable("Tokenizer is not loaded. Call loadModels() before generating speech.")
+ }
+ // Bump the counter on every exit path so consecutive tasks never share a seed.
+ defer { taskCounter += 1 }
+ let taskSeed: UInt64? = seed.map { $0 ^ taskCounter }
+ return try setupGenerateTask(
+ currentTimings: currentTimings,
+ progress: progress ?? Progress(),
+ tokenizer: tokenizer,
+ sampler: GreedyTokenSampler(seed: taskSeed)
+ )
+ }
+
+ // MARK: - Speech generation
+
+ /// Synthesize speech from text and return the complete audio result.
+ ///
+ /// Mirrors `WhisperKit.transcribe(audioPath:decodeOptions:callback:)`.
+ /// Handles text chunking, optional prompt caching, and concurrent multi-chunk generation.
+ ///
+ /// - Parameters:
+ /// - text: The text to synthesize.
+ /// - voice: Voice/speaker identifier. Format is model-specific (e.g. `"ryan"` for Qwen3 TTS).
+ /// - language: Language identifier. Format is model-specific (e.g. `"english"` for Qwen3 TTS).
+ /// - options: Sampling and generation options.
+ /// - callback: Optional per-step callback receiving decoded audio chunks.
+ /// Return `false` to cancel; `nil` or `true` to continue.
+ /// - Returns: A `SpeechResult` containing the raw audio samples and timing breakdown.
+ /// - Throws: `TTSError` if text is empty, models are not loaded, or generation fails.
+ open func generate(
+ text: String,
+ voice: String? = nil,
+ language: String? = nil,
+ options: GenerationOptions = GenerationOptions(),
+ callback: SpeechCallback = nil
+ ) async throws -> SpeechResult {
+ // Auto-load models if they have not been loaded yet, mirroring WhisperKit's
+ // runTranscribeTask which calls loadModels() when modelState != .loaded.
+ if modelState != .loaded {
+ try await loadModels()
+ }
+
+ try Task.checkCancellation()
+
+ // Create the primary task to resolve model-specific defaults for voice/language.
+ // This task is also reused for the single-chunk fast path to avoid a second allocation.
+ let primaryTask = try createTask()
+ let resolvedVoice = voice ?? primaryTask.defaultVoice
+ let resolvedLanguage = language ?? primaryTask.defaultLanguage
+
+ // Build prompt cache ahead of time if none exists or current doesn't match.
+ // Falls through to nil when no tokenizer is available (cache cannot be built).
+ let cache: TTSPromptCache?
+ if let existing = promptCache, existing.matches(voice: resolvedVoice, language: resolvedLanguage, instruction: options.instruction) {
+ cache = existing
+ } else if tokenizer != nil {
+ cache = try await buildPromptCache(voice: resolvedVoice, language: resolvedLanguage, instruction: options.instruction)
+ } else {
+ cache = nil
+ }
+
+ // Split the input per the chunking strategy (default: sentence). With no
+ // tokenizer or an explicit .none strategy, the whole text is one chunk.
+ let effectiveStrategy = options.chunkingStrategy ?? .sentence
+ let textChunks: [String]
+ if effectiveStrategy == .none || tokenizer == nil {
+ textChunks = [text]
+ } else {
+ guard let tokenizer else {
+ throw TTSError.tokenizerUnavailable("Tokenizer is not loaded. Call loadModels() before generating speech.")
+ }
+ let chunker = TextChunker(
+ targetChunkSize: options.targetChunkSize ?? TextChunker.defaultTargetChunkSize,
+ minChunkSize: options.minChunkSize ?? TextChunker.defaultMinChunkSize,
+ tokenizer: tokenizer
+ )
+ let chunks = chunker.chunk(text)
+ // Guard against a chunker returning nothing for degenerate input.
+ textChunks = chunks.isEmpty ? [text] : chunks
+ }
+
+ // Single-chunk fast path: reuse primaryTask (already allocated above).
+ if textChunks.count == 1 {
+ return try await primaryTask.run(
+ text: textChunks[0],
+ voice: resolvedVoice,
+ language: resolvedLanguage,
+ options: options,
+ callback: callback,
+ prefixCache: cache
+ )
+ }
+
+ let workerDesc = options.concurrentWorkerCount == 0 ? "max" : "\(options.concurrentWorkerCount)"
+ Logging.info("Chunked TTS: \(textChunks.count) chunks, concurrency=\(workerDesc)")
+ for (i, chunk) in textChunks.enumerated() {
+ let truncated = chunk.count > 60 ? "\(chunk.prefix(60))..." : chunk
+ Logging.debug(" Chunk \(i): \"\(truncated)\" (\(chunk.count) chars)")
+ }
+
+ let pipelineStart = CFAbsoluteTimeGetCurrent()
+ var combinedTimings = SpeechTimings()
+ combinedTimings.modelLoading = currentTimings.modelLoading
+ combinedTimings.tokenizerLoadTime = currentTimings.tokenizerLoadTime
+
+ let crossfadeSamples = primaryTask.sampleRate / 10 // 100ms crossfade
+ var chunkAudioArrays = [[Float]](repeating: [], count: textChunks.count)
+
+ let totalChunks = textChunks.count
+
+ // Progress budget assumes every chunk may use its full token allowance.
+ let maxSteps = totalChunks * options.maxNewTokens
+
+ // Sequential path: a single worker generates chunks in order, so per-chunk
+ // progress can be streamed through the caller's callback as it happens.
+ if options.concurrentWorkerCount == 1 {
+ var stepsSoFar = 0
+ for (i, chunkText) in textChunks.enumerated() {
+ Logging.debug(String(format: " Generating chunk %d/%d...", i + 1, totalChunks))
+ let chunkStepBase = stepsSoFar
+ let wrappedCallback: SpeechCallback = callback.map { cb in
+ { @Sendable progress in
+ var p = progress
+ p.chunkIndex = i
+ p.totalChunks = totalChunks
+ p.stepsCompleted = chunkStepBase + Int(progress.timings.totalDecodingLoops)
+ p.totalSteps = maxSteps
+ return cb(p)
+ }
+ }
+ let chunkResult = try await (createTask()).run(
+ text: chunkText, voice: resolvedVoice, language: resolvedLanguage,
+ options: options, callback: wrappedCallback, prefixCache: cache
+ )
+ // Advance the step base by the full budget (not actual steps) so the
+ // reported progress fraction stays monotonic across chunks.
+ stepsSoFar += options.maxNewTokens
+ chunkAudioArrays[i] = chunkResult.audio
+ combinedTimings.merge(chunkResult.timings)
+ if i == 0 { combinedTimings.timeToFirstBuffer = chunkResult.timings.timeToFirstBuffer }
+ Logging.debug(
+ String(
+ format: " Chunk %d done: %.2fs audio (%d steps)",
+ i + 1, chunkResult.audioDuration, Int(chunkResult.timings.totalDecodingLoops))
+ )
+ }
+ } else {
+ // Concurrent path: distribute chunks across worker tasks.
+ let indexedChunks = textChunks.enumerated().map { (index: $0.offset, text: $0.element) }
+
+ let effectiveWorkers = options.concurrentWorkerCount == 0 ? indexedChunks.count : options.concurrentWorkerCount
+
+ // NOTE(review): the lines below appear garbled by text extraction (the "..<"
+ // range and several declarations - taskItems, stepCounter, workerCallback,
+ // chunkCount, unwrappedCallback - are missing or fused). Verify this region
+ // against the original commit before relying on this listing.
+ let batchedChunks: [[(index: Int, text: String)]]
+ batchedChunks = stride(from: 0, to: indexedChunks.count, by: effectiveWorkers).map {
+ Array(indexedChunks[$0.. Int in
+ state += 1
+ return state
+ }
+ let stepProgress = SpeechProgress(
+ audio: [], timings: progress.timings,
+ totalChunks: chunkCount,
+ stepsCompleted: steps, totalSteps: maxSteps
+ )
+ return unwrappedCallback(stepProgress)
+ }
+ }
+
+ let maxNewTokens = options.maxNewTokens
+ // Run each chunk on its own child task; completions may arrive out of order
+ // and are re-indexed below.
+ let batchResults: [(index: Int, result: SpeechResult)] = try await withThrowingTaskGroup(
+ of: (index: Int, result: SpeechResult).self
+ ) { group in
+ for item in taskItems {
+ group.addTask {
+ Logging.debug(String(format: " Starting chunk %d/%d...", item.index + 1, chunkCount))
+ let chunkResult = try await item.task.run(
+ text: item.text, voice: resolvedVoice, language: resolvedLanguage,
+ options: options, callback: workerCallback, prefixCache: cache
+ )
+ // Snap progress forward to the full budget for this chunk
+ let actualSteps = Int(chunkResult.timings.totalDecodingLoops)
+ let remaining = maxNewTokens - actualSteps
+ if remaining > 0 {
+ stepCounter.withLock { $0 += remaining }
+ }
+ Logging.debug(
+ String(
+ format: " Chunk %d done: %.2fs audio (%d steps)",
+ item.index + 1, chunkResult.audioDuration, actualSteps))
+ return (index: item.index, result: chunkResult)
+ }
+ }
+ var results = [(index: Int, result: SpeechResult)]()
+ for try await result in group {
+ results.append(result)
+ }
+ return results
+ }
+
+ // Reassemble results in original chunk order regardless of completion order.
+ for entry in batchResults {
+ chunkAudioArrays[entry.index] = entry.result.audio
+ combinedTimings.merge(entry.result.timings)
+ if entry.index == 0 { combinedTimings.timeToFirstBuffer = entry.result.timings.timeToFirstBuffer }
+ }
+ }
+
+ // Deliver audio in order via callback after concurrent batch completes.
+ if let callback {
+ for (i, chunkAudio) in chunkAudioArrays.enumerated() {
+ let progress = SpeechProgress(
+ audio: chunkAudio, timings: combinedTimings,
+ stepTime: i == 0 ? 0 : nil,
+ chunkIndex: i, totalChunks: totalChunks
+ )
+ // A false return from the callback cancels delivery of remaining chunks.
+ if callback(progress) == false { break }
+ }
+ }
+ }
+
+ // Crossfade consecutive chunks and assemble final audio.
+ let allAudio = AudioOutput.crossfade(chunkAudioArrays, fadeLength: crossfadeSamples)
+
+ combinedTimings.fullPipeline = CFAbsoluteTimeGetCurrent() - pipelineStart
+ let sampleRate = primaryTask.sampleRate
+ combinedTimings.inputAudioSeconds = Double(allAudio.count) / Double(sampleRate)
+
+ let steps = Int(combinedTimings.totalDecodingLoops)
+ let avgMs = steps > 0 ? combinedTimings.decodingLoop * 1000 / Double(steps) : 0
+ Logging.info(
+ String(
+ format: "Chunked TTS: %d chunks, %d steps, %.1fms avg/step, %.2fs audio",
+ textChunks.count, steps, avgMs, Double(allAudio.count) / Double(sampleRate)
+ ))
+
+ return SpeechResult(audio: allAudio, timings: combinedTimings, sampleRate: sampleRate)
+ }
+
+ // MARK: - Play Speech
+
+ /// Generate speech and stream it through the audio output in real time.
+ ///
+ /// Generates speech and plays it back.
+ ///
+ /// For streaming strategies (auto, stream, buffered) chunking is forced to
+ /// sequential (`concurrentWorkerCount = 1`) so frames can be enqueued in
+ /// order. `generateFirst` respects the caller's concurrency setting so the
+ /// full file can be generated with parallel workers before playback begins.
+ ///
+ /// - Parameters:
+ /// - text: The text to synthesize.
+ /// - voice: Voice/speaker identifier.
+ /// - language: Language identifier.
+ /// - options: Sampling and generation options.
+ /// - playbackStrategy: Controls how audio is buffered before playback begins.
+ /// - callback: Optional per-step callback.
+ /// - Returns: A `SpeechResult` with the complete audio and timing breakdown.
+ /// - Throws: `TTSError` on generation failure or task cancellation.
+ open func play(
+ text: String,
+ voice: String? = nil,
+ language: String? = nil,
+ options: GenerationOptions = GenerationOptions(),
+ playbackStrategy: PlaybackStrategy = .auto,
+ callback: SpeechCallback = nil
+ ) async throws -> SpeechResult {
+ var playOptions = options
+
+ // Capture locals up front so the @Sendable generation callback below does
+ // not need to reach back into self.
+ let audioOut = audioOutput
+ let maxTokens = playOptions.maxNewTokens
+
+ // Pre-resolve audio format from the task so the playback closure doesn't
+ // reach into model-specific components (keeps TTSKit model-agnostic).
+ let formatTask = try createTask()
+ let samplesPerFrame = formatTask.samplesPerFrame
+ let sampleRate = formatTask.sampleRate
+ let minBuffer = formatTask.minimumBufferDuration
+
+ // generateFirst: synthesize everything (possibly with concurrent workers),
+ // then play the finished audio in one shot with no pre-buffering.
+ if case .generateFirst = playbackStrategy {
+ let result = try await generate(
+ text: text, voice: voice, language: language,
+ options: playOptions, callback: callback
+ )
+ try audioOut.startPlayback()
+ audioOut.setBufferDuration(0)
+ audioOut.enqueueAudioChunk(result.audio)
+ await audioOut.stopPlayback(waitForCompletion: true)
+ return result
+ }
+
+ // Streaming requires sequential generation to preserve chunk order.
+ playOptions.concurrentWorkerCount = 1
+
+ // Defer the engine start so the render thread doesn't compete with model
+ // predictions before the first audio buffer exists.
+ try audioOut.startPlayback(deferEngineStart: true)
+ // Fixed strategies configure the buffer immediately; .auto waits for the
+ // first measured step time inside the callback below.
+ switch playbackStrategy {
+ case .stream: audioOut.setBufferDuration(0)
+ case let .buffered(secs): audioOut.setBufferDuration(secs)
+ case .auto: break
+ case .generateFirst: break // already returned above; listed for exhaustiveness
+ }
+
+ let result = try await generate(
+ text: text, voice: voice, language: language,
+ options: playOptions,
+ callback: { progress in
+ // First timed step under .auto sizes the adaptive pre-buffer from the
+ // measured generation speed vs. real-time audio rate.
+ if let stepTime = progress.stepTime, case .auto = playbackStrategy {
+ let buffer = PlaybackStrategy.requiredBuffer(
+ stepTime: stepTime,
+ maxNewTokens: maxTokens,
+ samplesPerFrame: samplesPerFrame,
+ sampleRate: sampleRate,
+ minimumBuffer: minBuffer
+ )
+ audioOut.setBufferDuration(buffer)
+ let speedRatio = PlaybackStrategy.audioPerStep(samplesPerFrame: samplesPerFrame, sampleRate: sampleRate) / stepTime
+ Logging.info(
+ String(
+ format: "Playback: step %.1fms (%.2fx real-time) -> buffer %.2fs",
+ stepTime * 1000, speedRatio, buffer))
+ }
+ audioOut.enqueueAudioChunk(progress.audio)
+ // Forward to the caller's callback; nil means continue.
+ return callback?(progress)
+ }
+ )
+
+ await audioOut.stopPlayback(waitForCompletion: true)
+ return result
+ }
+
+ // MARK: - Qwen3-typed convenience API
+
+ /// Build a prompt cache using typed Qwen3 speaker and language enums.
+ ///
+ /// - Parameters:
+ /// - speaker: The `Qwen3Speaker` to pre-warm the cache for.
+ /// - language: The `Qwen3Language` to pre-warm the cache for.
+ /// - instruction: Optional style instruction (1.7B only).
+ /// - Returns: A `TTSPromptCache` for the given parameters.
+ /// - Throws: `TTSError` on generation failure.
+ @discardableResult
+ open func buildPromptCache(
+ speaker: Qwen3Speaker,
+ language: Qwen3Language,
+ instruction: String? = nil
+ ) async throws -> TTSPromptCache {
+ // Thin typed wrapper: forwards raw-value strings to the string-based API.
+ try await buildPromptCache(
+ voice: speaker.rawValue,
+ language: language.rawValue,
+ instruction: instruction
+ )
+ }
+
+ /// Generate speech from text using typed Qwen3 speaker and language enums.
+ ///
+ /// - Parameters:
+ /// - text: Input text to synthesise.
+ /// - speaker: The `Qwen3Speaker` voice to use.
+ /// - language: The `Qwen3Language` to synthesise in.
+ /// - options: Generation options controlling sampling, chunking, and concurrency.
+ /// - callback: Per-step callback receiving decoded audio chunks. Return `false` to cancel.
+ /// - Returns: The assembled `SpeechResult`.
+ /// - Throws: `TTSError` on generation failure or task cancellation.
+ open func generate(
+ text: String,
+ speaker: Qwen3Speaker,
+ language: Qwen3Language = .english,
+ options: GenerationOptions = GenerationOptions(),
+ callback: SpeechCallback = nil
+ ) async throws -> SpeechResult {
+ // Thin typed wrapper: forwards raw-value strings to the string-based API.
+ try await generate(
+ text: text,
+ voice: speaker.rawValue,
+ language: language.rawValue,
+ options: options,
+ callback: callback
+ )
+ }
+
+ /// Generate speech and stream playback using typed Qwen3 speaker and language enums.
+ ///
+ /// - Parameters:
+ /// - text: Input text to synthesise.
+ /// - speaker: The `Qwen3Speaker` voice to use.
+ /// - language: The `Qwen3Language` to synthesise in.
+ /// - options: Generation options controlling sampling, chunking, and concurrency.
+ /// - playbackStrategy: Controls how much audio is buffered before playback begins.
+ /// - callback: Per-step callback receiving decoded audio chunks. Return `false` to cancel.
+ /// - Returns: The assembled `SpeechResult`.
+ /// - Throws: `TTSError` on generation failure or task cancellation.
+ open func play(
+ text: String,
+ speaker: Qwen3Speaker,
+ language: Qwen3Language = .english,
+ options: GenerationOptions = GenerationOptions(),
+ playbackStrategy: PlaybackStrategy = .auto,
+ callback: SpeechCallback = nil
+ ) async throws -> SpeechResult {
+ // Thin typed wrapper: forwards raw-value strings to the string-based API.
+ try await play(
+ text: text,
+ voice: speaker.rawValue,
+ language: language.rawValue,
+ options: options,
+ playbackStrategy: playbackStrategy,
+ callback: callback
+ )
+ }
+}
+
+// MARK: - SpeechModel conformance
+
+extension TTSKit: SpeechModel {
+ /// The output sample rate of the currently loaded speech decoder.
+ /// NOTE(review): accesses `speechDecoder` unconditionally - presumably it is
+ /// non-optional with a usable default before loadModels(); confirm.
+ public var sampleRate: Int { speechDecoder.sampleRate }
+}
diff --git a/Sources/TTSKit/Utilities/AudioOutput.swift b/Sources/TTSKit/Utilities/AudioOutput.swift
new file mode 100644
index 00000000..4370c4cf
--- /dev/null
+++ b/Sources/TTSKit/Utilities/AudioOutput.swift
@@ -0,0 +1,734 @@
+// For licensing see accompanying LICENSE.md file.
+// Copyright © 2026 Argmax, Inc. All rights reserved.
+
+import AVFoundation
+import Accelerate
+import ArgmaxCore
+import Foundation
+
+// MARK: - Audio Output
+
+/// Handles audio export to file and real-time streaming playback via AVAudioEngine.
+///
+/// Supports adaptive pre-buffering and edge-fading to prevent audible clicks:
+///
+/// **Pre-buffering:** When a buffer duration is configured via `setBufferDuration(_:)`,
+/// incoming audio frames accumulate until the threshold is reached, then flush to the
+/// player all at once. This prevents underruns on slower devices.
+///
+/// **Edge-fading:** Fades are applied only at actual audio discontinuities:
+/// - Fade-in on the first frame of a session, chunk, or after a detected underrun.
+/// - Fade-out on the last frame of a session, chunk, or before a detected underrun.
+/// Interior frames of contiguous playback are untouched.
+///
+/// Underrun detection uses wall-clock timing: if the current time exceeds the
+/// expected playback end of all scheduled audio, the player has drained and the
+/// next frame needs fade-in (and the previous tail gets fade-out).
+///
+/// **Buffer lifecycle:**
+/// 1. `startPlayback()` - resets all state; frames accumulate until configured.
+/// 2. `setBufferDuration(_:)` - configures threshold (call after start).
+/// 3. `enqueueAudioChunk(_:)` - pushes frames through the buffer/tail pipeline.
+/// 4. `stopPlayback()` - commits the tail with fade-out, waits, tears down.
+///
+/// Thread safety: used only from `TTSKit.play()` which forces sequential mode
+/// (concurrency=1). `TTSKit` ensures serialized access.
+/// Subclassing is intentionally not supported. Use the `TTSKit` component override
+/// mechanism (`TTSKitConfig`) to swap in an alternative audio backend.
+public class AudioOutput: @unchecked Sendable {
+ private var audioEngine: AVAudioEngine?
+ private var playerNode: AVAudioPlayerNode?
+ private var engineStartDeferred: Bool = false
+
+ /// Pre-buffer threshold in seconds. `nil` means not yet configured - frames
+ /// accumulate in `pendingFrames` until `setBufferDuration` is called.
+ /// `0` means stream immediately. `> 0` means buffer that duration first.
+ private var bufferDuration: TimeInterval?
+
+ /// Accumulated frames waiting to be flushed to the player.
+ private var pendingFrames: [[Float]] = []
+
+ /// Total duration (seconds) of audio in `pendingFrames`.
+ private var pendingDuration: TimeInterval = 0
+
+ /// Whether the initial buffer threshold has been met and frames flow directly
+ /// to the player without accumulating.
+ private var bufferThresholdMet: Bool = false
+
+ /// The most recently received frame, held back so we can apply fade-out
+ /// when we learn it's the last frame before a gap or end of session.
+ private var tailFrame: [Float]?
+
+ /// Whether the next frame scheduled to the player needs a fade-in
+ /// (start of session, start of chunk, or after an underrun).
+ private var needsFadeIn: Bool = true
+
+ /// Wall-clock time at which all currently scheduled audio is expected to
+ /// finish playing. Used to detect underruns: if `now > expectedPlaybackEnd`,
+ /// the player has drained and the next frame follows a gap.
+ private var expectedPlaybackEnd: CFAbsoluteTime = 0
+
+ /// Player time (seconds) when the first buffer was scheduled. The player
+ /// node's clock starts at `.play()` but no audio plays until buffers are
+ /// queued, so this offset must be subtracted from `playerTime.sampleTime`
+ /// to get the true audio-out position.
+ private var playbackTimeOffset: TimeInterval = 0
+
+ /// Cumulative duration (seconds) of real audio that has been scheduled via
+ /// `scheduleWithFades`. The silent sentinel buffer used for drain detection
+ /// is not included. Used to clamp `currentPlaybackTime` so the reported
+ /// position never advances into silence gaps between chunks or past the end
+ /// of generated audio.
+ public private(set) var scheduledAudioDuration: TimeInterval = 0
+
+ /// Number of samples for the fade-in/fade-out ramp.
+ /// 256 samples at 24kHz ≈ 10.7ms - imperceptible on contiguous audio
+ /// but smoothly eliminates clicks at discontinuities.
+ public static let fadeLengthSamples: Int = 256
+
+ /// Output sample rate in Hz. Defaults to 24000 (Qwen3 TTS).
+ /// Updated by `TTSKit.loadModels()` to match the loaded speech decoder's actual sample rate.
+ public private(set) var sampleRate: Int
+
+ /// The audio format used for playback and export (derived from `sampleRate`).
+ public private(set) var audioFormat: AVAudioFormat
+
+ /// Create an audio output with the given sample rate (defaults to 24kHz, Qwen3 TTS).
+ /// Traps if AVAudioFormat cannot be built - only possible with an invalid rate.
+ public init(sampleRate: Int = 24000) {
+ self.sampleRate = sampleRate
+ guard let format = AVAudioFormat(commonFormat: .pcmFormatFloat32, sampleRate: Double(sampleRate), channels: 1, interleaved: false) else {
+ preconditionFailure("AVAudioFormat init failed for PCM Float32 \(sampleRate)Hz mono - this is an invariant violation")
+ }
+ self.audioFormat = format
+ }
+
+ /// Update the sample rate to match the loaded speech decoder.
+ /// Must be called before `startPlayback()`.
+ public func configure(sampleRate newRate: Int) {
+ // No-op when the rate is unchanged, avoiding a needless format rebuild.
+ guard newRate != sampleRate else { return }
+ sampleRate = newRate
+ guard let format = AVAudioFormat(commonFormat: .pcmFormatFloat32, sampleRate: Double(newRate), channels: 1, interleaved: false) else {
+ preconditionFailure("AVAudioFormat init failed for PCM Float32 \(newRate)Hz mono - this is an invariant violation")
+ }
+ audioFormat = format
+ }
+
+ /// Returns how many seconds of audio still need to accumulate in the pre-buffer before
+ /// the next chunk flushes and playback resumes. Non-zero only while in buffering
+ /// mode (`bufferThresholdMet == false` and a positive `bufferDuration` is set).
+ public var silentBufferRemaining: TimeInterval {
+ guard !bufferThresholdMet, let duration = bufferDuration, duration > 0 else { return 0 }
+ return max(0, duration - pendingDuration)
+ }
+
+ /// Current playback position in seconds, based on the audio engine's render timeline.
+ /// Returns 0 if the player is not active, no audio has been scheduled yet, or
+ /// the player hasn't started rendering.
+ ///
+ /// Clamped to `scheduledAudioDuration` so the position never advances into
+ /// silence gaps between chunks, the silent drain sentinel, or the hardware
+ /// pipeline tail after the last real audio frame.
+ public var currentPlaybackTime: TimeInterval {
+ // Guard against the engine being torn down (stopPlayback nullifies these).
+ // Checking audioEngine?.isRunning prevents accessing a detached node.
+ guard expectedPlaybackEnd > 0,
+ audioEngine?.isRunning == true,
+ let player = playerNode,
+ let nodeTime = player.lastRenderTime,
+ nodeTime.isSampleTimeValid,
+ let playerTime = player.playerTime(forNodeTime: nodeTime)
+ else { return 0 }
+ // Subtract the schedule offset: the player clock starts at .play(), before
+ // any audio is actually queued.
+ let rawTime = max(0, Double(playerTime.sampleTime) / playerTime.sampleRate - playbackTimeOffset)
+ return min(rawTime, scheduledAudioDuration)
+ }
+
+ // MARK: - Buffer Configuration
+
+ /// Configure the pre-buffer duration. Call after `startPlayback()`.
+ ///
+ /// - If `seconds == 0`: immediately flushes any pending frames and switches
+ /// to direct streaming (no buffering).
+ /// - If `seconds > 0`: sets the threshold. If enough audio has already
+ /// accumulated, flushes immediately.
+ /// - Can be called multiple times (e.g., per-chunk reassessment). Any held
+ /// tail frame from the previous chunk is committed with fade-out first.
+ ///
+ /// - Parameter seconds: Duration of audio to accumulate before flushing.
+ /// Pass 0 for immediate streaming (fast devices).
+ public func setBufferDuration(_ seconds: TimeInterval) {
+ // Negative inputs are treated as 0 (immediate streaming).
+ let duration = max(0, seconds)
+ bufferDuration = duration
+
+ // Commit any held tail as the last frame of the previous chunk
+ if let tail = tailFrame {
+ scheduleWithFades(tail, fadeIn: needsFadeIn, fadeOut: true)
+ needsFadeIn = true // next chunk starts fresh with fade-in
+ tailFrame = nil
+ }
+
+ if duration == 0 {
+ // Zero threshold: switch to direct streaming and drain anything pending.
+ bufferThresholdMet = true
+ if !pendingFrames.isEmpty {
+ flushPendingFrames()
+ }
+ } else {
+ // Re-enter buffering mode so the new chunk accumulates frames
+ // before flushing - lets the model catch up between chunks.
+ bufferThresholdMet = false
+ if pendingDuration >= duration {
+ flushPendingFrames()
+ }
+ }
+ }
+
+ // MARK: - File Export
+
+ /// Supported audio export formats.
+ public enum AudioFileFormat: String, Sendable {
+ case m4a
+ case wav
+
+ /// The file extension equals the raw value ("m4a" / "wav").
+ public var fileExtension: String { rawValue }
+
+ /// Resolve the effective format for the current platform.
+ /// On watchOS, M4A is not supported so falls back to WAV with a warning.
+ public static func resolve(_ preferred: AudioFileFormat = .m4a) -> AudioFileFormat {
+ #if os(watchOS)
+ if preferred == .m4a {
+ Logging.info("[Warning] M4A export is not available on watchOS, falling back to WAV")
+ return .wav
+ }
+ #endif
+ return preferred
+ }
+ }
+
+ /// Save audio samples to a file.
+ ///
+ /// For M4A with metadata: writes PCM -> AAC to a temp file, then uses
+ /// `AVAssetExportSession` passthrough to remux with embedded metadata atoms
+ /// (no re-encode). For WAV or metadata-free M4A: writes directly.
+ /// On watchOS, `.m4a` automatically falls back to `.wav`.
+ ///
+ /// Any extension already present in `filename` is stripped before writing.
+ /// The output format is resolved in this order: explicit `format` parameter,
+ /// then extension found in `filename` if it matches a supported format,
+ /// then `.m4a` as the default.
+ ///
+ /// - Parameters:
+ /// - samples: Mono Float32 PCM samples.
+ /// - folder: Destination directory. Created if it doesn't exist.
+ /// - filename: File name, with or without extension.
+ /// - sampleRate: Sample rate in Hz.
+ /// - format: Output format. Inferred from `filename` extension when `nil`.
+ /// - metadataProvider: Optional metadata callback for items to embed into the file container for m4a formats.
+ /// - Returns: The URL of the written file.
+ /// - Throws: `TTSError` if audio encoding or export fails.
+ @discardableResult
+ public static func saveAudio(
+ _ samples: [Float],
+ toFolder folder: URL,
+ filename: String,
+ sampleRate: Int = 24000,
+ format: AudioFileFormat? = nil,
+ metadataProvider: (@Sendable () throws -> [AVMetadataItem])? = nil
+ ) async throws -> URL {
+ guard !samples.isEmpty else {
+ throw TTSError.audioOutputFailed("No audio samples to export")
+ }
+
+ // Resolve format precedence: explicit parameter > filename extension > .m4a.
+ // Any extension on `filename` is stripped; the resolved one is re-applied.
+ let filenameURL = URL(fileURLWithPath: filename)
+ let baseName = filenameURL.deletingPathExtension().lastPathComponent
+ let inferredFormat = AudioFileFormat(rawValue: filenameURL.pathExtension.lowercased())
+ let resolvedFormat = AudioFileFormat.resolve(format ?? inferredFormat ?? .m4a)
+ let outputURL = folder
+ .appendingPathComponent(baseName)
+ .appendingPathExtension(resolvedFormat.fileExtension)
+
+ if !FileManager.default.fileExists(atPath: folder.path) {
+ try FileManager.default.createDirectory(at: folder, withIntermediateDirectories: true)
+ }
+ // Best-effort removal of a previous file; overwrite semantics.
+ try? FileManager.default.removeItem(at: outputURL)
+
+ let pcmBuffer = try createPCMBuffer(from: samples, sampleRate: sampleRate)
+
+ switch resolvedFormat {
+ case .m4a:
+ #if os(watchOS)
+ // resolve() maps .m4a to .wav on watchOS, so this branch is unreachable.
+ throw TTSError.audioOutputFailed("M4A should have been resolved to WAV on watchOS")
+ #else
+ if let metadataProvider {
+ let metadata = try metadataProvider()
+ try await writeM4AWithMetadata(pcmBuffer, to: outputURL, sampleRate: sampleRate, metadata: metadata)
+ } else {
+ try writeM4A(pcmBuffer, to: outputURL, sampleRate: sampleRate)
+ }
+ #endif
+ case .wav:
+ try writeWAV(pcmBuffer, to: outputURL, sampleRate: sampleRate)
+ }
+
+ return outputURL
+ }
+
+ /// Return the playback duration of an audio file in seconds.
+ /// Uses the async AVAsset property loader; throws if the duration cannot be read.
+ public static func duration(of url: URL) async throws -> TimeInterval {
+ let asset = AVURLAsset(url: url)
+ let cmDuration = try await asset.load(.duration)
+ return CMTimeGetSeconds(cmDuration)
+ }
+
+ // MARK: - Crossfade assembly
+
+ /// Assemble multiple audio chunks into one array with equal-power crossfades at each boundary.
+ ///
+ /// Uses `cos(t*pi/2)` fade-out and `sin(t*pi/2)` fade-in so that energy is
+ /// preserved through the overlap region. Fade curves are pre-computed once
+ /// via Accelerate (`vDSP_vramp` + `vvcosf`/`vvsinf`) and reused at every
+ /// chunk boundary; the per-boundary mix uses `vDSP_vmul` + `vDSP_vma`.
+ ///
+ /// - Parameters:
+ /// - chunks: Ordered audio chunks to concatenate.
+ /// - fadeLength: Number of overlap samples for each crossfade.
+ /// - Returns: Single concatenated audio array with crossfades applied at chunk boundaries.
+ public static func crossfade(_ chunks: [[Float]], fadeLength: Int) -> [Float] {
+ guard !chunks.isEmpty else { return [] }
+ guard chunks.count > 1 else { return chunks[0] }
+
+ let (fadeOut, fadeIn) = equalPowerCurves(length: fadeLength)
+
+ var result = [Float]()
+ result.reserveCapacity(chunks.reduce(0) { $0 + $1.count })
+ result.append(contentsOf: chunks[0])
+
+ for i in 1.. ([Float], [Float]) {
+ guard length > 0 else { return ([], []) }
+
+ var ramp = [Float](repeating: 0, count: length)
+ var start: Float = 0
+ var step = Float.pi / 2.0 / Float(max(length - 1, 1))
+ vDSP_vramp(&start, &step, &ramp, 1, vDSP_Length(length))
+
+ var fadeOut = [Float](repeating: 0, count: length)
+ var fadeIn = [Float](repeating: 0, count: length)
+ var n = Int32(length)
+ vvcosf(&fadeOut, ramp, &n)
+ vvsinf(&fadeIn, ramp, &n)
+
+ return (fadeOut, fadeIn)
+ }
+
+ // MARK: - Internal helpers
+
+ /// Wrap mono Float32 samples in an `AVAudioPCMBuffer` at the given rate.
+ /// - Throws: `TTSError.audioOutputFailed` if the format or buffer cannot be created.
+ private static func createPCMBuffer(from samples: [Float], sampleRate: Int) throws -> AVAudioPCMBuffer {
+ guard
+ let pcmFormat = AVAudioFormat(
+ commonFormat: .pcmFormatFloat32,
+ sampleRate: Double(sampleRate),
+ channels: 1,
+ interleaved: false
+ )
+ else {
+ throw TTSError.audioOutputFailed("Failed to create PCM format for sampleRate \(sampleRate)")
+ }
+
+ guard let buffer = AVAudioPCMBuffer(pcmFormat: pcmFormat, frameCapacity: AVAudioFrameCount(samples.count)) else {
+ throw TTSError.audioOutputFailed("Failed to create PCM buffer")
+ }
+ buffer.frameLength = AVAudioFrameCount(samples.count)
+
+ guard let channelData = buffer.floatChannelData else {
+ throw TTSError.audioOutputFailed("Failed to access buffer channel data")
+ }
+ // Bulk-copy the samples into channel 0 (mono) without per-element bridging.
+ samples.withUnsafeBufferPointer { src in
+ guard let srcBase = src.baseAddress else { return }
+ channelData[0].update(from: srcBase, count: samples.count)
+ }
+
+ return buffer
+ }
+
+ #if !os(watchOS)
+ /// Write PCM samples as AAC in an M4A container.
+ ///
+ /// `AVAudioFile` accepts PCM input via `write(from:)` and internally
+ /// encodes to AAC when the file settings specify a compressed format,
+ /// so no temp file or explicit `AVAudioConverter` is needed.
+ private static func writeM4A(
+ _ pcmBuffer: AVAudioPCMBuffer,
+ to url: URL,
+ sampleRate: Int
+ ) throws {
+ // File settings declare AAC; AVAudioFile encodes the Float32 PCM input on write.
+ let aacSettings: [String: Any] = [
+ AVFormatIDKey: kAudioFormatMPEG4AAC,
+ AVSampleRateKey: sampleRate,
+ AVNumberOfChannelsKey: 1,
+ AVEncoderBitRateKey: 64000, // 64kbps mono - adequate for speech
+ AVEncoderAudioQualityKey: AVAudioQuality.medium.rawValue
+ ]
+
+ // commonFormat here describes the *processing* (input) format, not the file.
+ let file = try AVAudioFile(
+ forWriting: url,
+ settings: aacSettings,
+ commonFormat: .pcmFormatFloat32,
+ interleaved: false
+ )
+ try file.write(from: pcmBuffer)
+ }
+
+ /// Write AAC/M4A with embedded metadata atoms.
+ ///
+ /// `AVAudioFile` cannot embed metadata, so this does a two-step write:
+ /// encode PCM -> AAC into a temp file, then use `AVAssetExportSession`
+ /// passthrough to remux the bitstream into the final file with metadata
+ /// atoms attached. No audio re-encoding occurs.
+ private static func writeM4AWithMetadata(
+ _ pcmBuffer: AVAudioPCMBuffer,
+ to url: URL,
+ sampleRate: Int,
+ metadata: [AVMetadataItem]
+ ) async throws {
+ // Temp file lives next to the destination (same volume); cleaned up on all exits.
+ let tempURL = url.deletingLastPathComponent()
+ .appendingPathComponent(UUID().uuidString)
+ .appendingPathExtension("m4a")
+ defer { try? FileManager.default.removeItem(at: tempURL) }
+
+ // Step 1: encode PCM -> AAC into the temp file.
+ try writeM4A(pcmBuffer, to: tempURL, sampleRate: sampleRate)
+
+ // Step 2: passthrough remux (no re-encode) into the final file with metadata.
+ let asset = AVURLAsset(url: tempURL)
+ guard let exportSession = AVAssetExportSession(asset: asset, presetName: AVAssetExportPresetPassthrough) else {
+ throw TTSError.audioOutputFailed("AVAssetExportSession could not be created for metadata export")
+ }
+ exportSession.outputURL = url
+ exportSession.outputFileType = .m4a
+ exportSession.metadata = metadata
+
+ await exportSession.export()
+
+ // Surface the session's own error first; otherwise require .completed.
+ if let exportError = exportSession.error {
+ throw exportError
+ }
+ guard exportSession.status == .completed else {
+ throw TTSError.audioOutputFailed(
+ "AVAssetExportSession finished with unexpected status \(exportSession.status.rawValue)")
+ }
+ }
+ #endif
+
+ /// Write Float32 PCM samples as an uncompressed little-endian WAV file.
+ private static func writeWAV(
+ _ pcmBuffer: AVAudioPCMBuffer,
+ to url: URL,
+ sampleRate: Int
+ ) throws {
+ // 32-bit little-endian float, matching the in-memory buffer exactly (no conversion).
+ let settings: [String: Any] = [
+ AVFormatIDKey: kAudioFormatLinearPCM,
+ AVSampleRateKey: sampleRate,
+ AVNumberOfChannelsKey: 1,
+ AVLinearPCMBitDepthKey: 32,
+ AVLinearPCMIsBigEndianKey: false,
+ AVLinearPCMIsFloatKey: true
+ ]
+
+ let file = try AVAudioFile(
+ forWriting: url,
+ settings: settings,
+ commonFormat: .pcmFormatFloat32,
+ interleaved: false
+ )
+ try file.write(from: pcmBuffer)
+ }
+
+ // MARK: - Streaming Playback
+
+ /// Start the audio engine for streaming playback.
+ ///
+ /// Resets all buffering, fade, and timing state. After calling this,
+ /// configure the buffer threshold via `setBufferDuration(_:)`.
+ ///
+ /// - Parameter deferEngineStart: When `true`, the audio engine is created and
+ /// connected but not started. The engine will start automatically on the first
+ /// `enqueueAudioChunk` call. This avoids the render thread contending with
+ /// model predictions during the critical time-to-first-buffer path.
+ /// - Throws: `TTSError` if the audio engine fails to start.
+ public func startPlayback(deferEngineStart: Bool = false) throws {
+ // Reset all buffering / fade / timing state from any previous session.
+ pendingFrames.removeAll()
+ pendingDuration = 0
+ bufferThresholdMet = false
+ bufferDuration = nil
+ tailFrame = nil
+ needsFadeIn = true
+ expectedPlaybackEnd = 0
+ playbackTimeOffset = 0
+ scheduledAudioDuration = 0
+ engineStartDeferred = false
+
+ // On iOS, AVAudioEngine requires an active audio session with a playback
+ // category. Without this, engine.start() may silently fail or route to
+ // the wrong output (e.g., airpods instead of main speaker).
+ #if os(iOS)
+ let session = AVAudioSession.sharedInstance()
+ try session.setCategory(.playback, mode: .default, options: [])
+ try session.setActive(true)
+ #endif
+
+ // Fresh engine + player per session; the previous pair (if any) was torn
+ // down by stopPlayback.
+ let engine = AVAudioEngine()
+ let player = AVAudioPlayerNode()
+ let format = audioFormat
+
+ engine.attach(player)
+ engine.connect(player, to: engine.mainMixerNode, format: format)
+
+ if deferEngineStart {
+ // Engine starts lazily on the first enqueueAudioChunk call.
+ engineStartDeferred = true
+ } else {
+ try engine.start()
+ player.play()
+ }
+
+ self.audioEngine = engine
+ self.playerNode = player
+ }
+
+ /// Enqueue a chunk of audio samples for playback.
+ ///
+ /// In streaming mode, detects underruns via wall-clock timing: if the player
+ /// has drained since the last buffer, the held tail is committed with fade-out
+ /// (it was the last frame before the gap) and the incoming frame is marked for
+ /// fade-in. On contiguous playback, no fades are applied to interior frames.
+ public func enqueueAudioChunk(_ samples: [Float]) {
+ // Silently drop audio if playback was never started or already stopped.
+ guard let engine = audioEngine, let player = playerNode else { return }
+
+ // Deferred engine start (see startPlayback): spin up on first audio.
+ if engineStartDeferred {
+ engineStartDeferred = false
+ do {
+ try engine.start()
+ player.play()
+ } catch {
+ Logging.error("AudioOutput: deferred engine start failed: \(error)")
+ return
+ }
+ }
+
+ if bufferThresholdMet {
+ // Detect underrun: has the player drained since the last schedule?
+ let playerDrained = CFAbsoluteTimeGetCurrent() > expectedPlaybackEnd
+
+ if let tail = tailFrame {
+ // If player drained, this tail was the last frame before silence -
+ // apply fade-out to smooth the audio -> silence transition, and
+ // mark the next frame for fade-in (silence -> audio transition).
+ let fadeOut = playerDrained
+ let fadeIn = needsFadeIn || playerDrained
+ scheduleWithFades(tail, fadeIn: fadeIn, fadeOut: fadeOut)
+ needsFadeIn = playerDrained
+ } else if playerDrained {
+ needsFadeIn = true
+ }
+ // Hold the new frame back as the tail until we know what follows it.
+ tailFrame = samples
+ } else {
+ // Buffering mode: accumulate until the configured threshold is reached.
+ pendingFrames.append(samples)
+ pendingDuration += Double(samples.count) / Double(sampleRate)
+
+ if let threshold = bufferDuration, pendingDuration >= threshold {
+ flushPendingFrames()
+ }
+ }
+ }
+
+ /// Flush all accumulated frames to the player.
+ ///
+ /// All frames except the last are scheduled immediately (they're contiguous).
+ /// The last frame becomes the tail, held back until the next frame arrives
+ /// or `stopPlayback` commits it with fade-out.
+ private func flushPendingFrames() {
+ guard !pendingFrames.isEmpty else {
+ bufferThresholdMet = true
+ return
+ }
+
+ // For subsequent chunks (scheduledAudioDuration > 0), the player may have
+ // been idle during the inter-chunk gap while the model generated the next
+ // buffer. The raw player clock kept advancing through that silence, so
+ // rawTime = playbackTimeOffset + scheduledAudioDuration + gap.
+ // Absorb the gap into playbackTimeOffset so currentPlaybackTime resumes
+ // from the end of the previous chunk rather than jumping ahead by gap.
+ if scheduledAudioDuration > 0,
+ let player = playerNode,
+ let nodeTime = player.lastRenderTime,
+ nodeTime.isSampleTimeValid,
+ let playerTime = player.playerTime(forNodeTime: nodeTime)
+ {
+ let currentRawTime = Double(playerTime.sampleTime) / playerTime.sampleRate
+ let expectedRawTime = playbackTimeOffset + scheduledAudioDuration
+ let gap = currentRawTime - expectedRawTime
+ if gap > 0.01 {
+ playbackTimeOffset += gap
+ }
+ }
+
+ for i in 0.. 0 {
+ let data = channelData[0]
+ let invFade = 1.0 / Float(fadeLen)
+
+ if fadeIn {
+ for i in 0..) in
+ let format = audioFormat
+ if let silentBuffer = AVAudioPCMBuffer(pcmFormat: format, frameCapacity: 1) {
+ silentBuffer.frameLength = 1
+ silentBuffer.floatChannelData?[0][0] = 0
+ player.scheduleBuffer(silentBuffer) {
+ continuation.resume()
+ }
+ } else {
+ continuation.resume()
+ }
+ }
+
+ // Wait ~80ms after the sentinel so the hardware pipeline drains (~1-2 render cycles).
+ // Prevents tail clip and CoreAudio "out of order"/overload errors across devices.
+ try? await Task.sleep(for: .milliseconds(80))
+ }
+
+ let engine = audioEngine
+ playerNode = nil
+ audioEngine = nil
+
+ // Stop the engine only
+ engine?.stop()
+
+ pendingFrames.removeAll()
+ pendingDuration = 0
+ bufferThresholdMet = false
+ bufferDuration = nil
+ tailFrame = nil
+ needsFadeIn = true
+ expectedPlaybackEnd = 0
+ playbackTimeOffset = 0
+ }
+}
diff --git a/Sources/TTSKit/Utilities/EmbedTypes.swift b/Sources/TTSKit/Utilities/EmbedTypes.swift
new file mode 100644
index 00000000..a63991b1
--- /dev/null
+++ b/Sources/TTSKit/Utilities/EmbedTypes.swift
@@ -0,0 +1,185 @@
+// For licensing see accompanying LICENSE.md file.
+// Copyright © 2026 Argmax, Inc. All rights reserved.
+
+import Accelerate
+import ArgmaxCore
+import CoreML
+
+// MARK: - TTS Embed Type Protocols
+
+// These protocols allow embed tensors to be represented as either MLMultiArray (all platforms)
+// or MLTensor (macOS 15+ / iOS 18+) without changing the calling convention.
+
+/// Marker protocol for a raw TTS embedding tensor emitted by a CoreML model.
+public protocol EmbedTensorType {}
+
+/// Marker protocol for a TTS embedding value that can be used as CoreML model input.
+public protocol EmbedInputType {}
+
+// MLMultiArray is available on all platforms and serves as both model output and input.
+extension MLMultiArray: EmbedTensorType {}
+extension MLMultiArray: EmbedInputType {}
+
+// MLTensor participates only on OS versions that ship the MLTensor API.
+@available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *)
+extension MLTensor: EmbedTensorType {}
+
+@available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *)
+extension MLTensor: EmbedInputType {}
+
+// Plain FloatType vectors are the CPU-side tensor representation. Note they
+// conform to EmbedTensorType only - not EmbedInputType - so they cannot be
+// passed to a model directly without packing (see EmbedUtilities).
+extension Array: EmbedTensorType where Element == FloatType {}
+
+// MARK: - Embedding Utilities
+
+/// Static helpers for creating, combining, and extracting TTS embedding vectors.
+///
+/// Mirrors the `ModelUtilities` pattern in ArgmaxCore: all operations are pure
+/// static functions with no shared state.
+public enum EmbedUtilities {
+ // MARK: - Float16 / FloatType helpers
+
+ /// Element-wise sum of two embedding vectors.
+ ///
+ /// Both inputs are widened to Float32, added with `vDSP_vadd`, and the sum
+ /// narrowed back to `FloatType`. Inputs must have equal length.
+ public static func addEmbeddings(_ a: [FloatType], _ b: [FloatType]) -> [FloatType] {
+ precondition(a.count == b.count)
+ let n = a.count
+ var lhsFloat = [Float](repeating: 0, count: n)
+ var rhsFloat = [Float](repeating: 0, count: n)
+ a.withUnsafeBufferPointer { convertToFloat($0, to: &lhsFloat) }
+ b.withUnsafeBufferPointer { convertToFloat($0, to: &rhsFloat) }
+ var sum = [Float](repeating: 0, count: n)
+ vDSP_vadd(lhsFloat, 1, rhsFloat, 1, &sum, 1, vDSP_Length(n))
+ return sum.map { FloatType($0) }
+ }
+
+ /// Accumulates embeddings into a Float32 buffer to avoid repeated type conversions,
+ /// reusing a single intermediate buffer rather than allocating one per embed.
+ public static func sumEmbeddings(_ embeddings: [[FloatType]]) -> [FloatType] {
+ guard let first = embeddings.first else { return [] }
+ let count = first.count
+ var accum = [Float](repeating: 0, count: count)
+ var floatBuf = [Float](repeating: 0, count: count)
+ for embed in embeddings {
+ embed.withUnsafeBufferPointer { ptr in
+ convertToFloat(ptr, to: &floatBuf)
+ vDSP_vadd(accum, 1, floatBuf, 1, &accum, 1, vDSP_Length(count))
+ }
+ }
+ var result = [FloatType](repeating: 0, count: count)
+ result.withUnsafeMutableBufferPointer { dst in
+ accum.withUnsafeBufferPointer { src in
+ // vDSP Float16<->Float conversion is available on iOS 14+, macOS 11+, visionOS 1+
+ // but not on watchOS. Fall back to scalar on watchOS and x86_64.
+ #if arch(arm64) && !os(watchOS)
+ vDSP.convertElements(of: src, to: &dst)
+ #else
+ for i in 0.. [FloatType] {
+ let dim: Int
+ if arr.shape.count == 4 {
+ dim = arr.shape[1].intValue
+ } else {
+ dim = arr.count
+ }
+ let ptr = arr.dataPointer.bindMemory(to: FloatType.self, capacity: arr.count)
+ var result = [FloatType](repeating: 0, count: dim)
+
+ if arr.shape.count == 4 {
+ let stride1 = arr.strides[1].intValue
+ if stride1 == 1 {
+ // Contiguous layout ([1, D, 1, 1] with unit stride) - direct buffer copy
+ result.withUnsafeMutableBufferPointer { dst in
+ guard let dstBase = dst.baseAddress else { return }
+ dstBase.update(from: ptr, count: dim)
+ }
+ } else {
+ for d in 0.. MLMultiArray {
+ let dim = values.count
+ let arr = try MLMultiArray(shape: [1, NSNumber(value: dim), 1, 1], dataType: .float16)
+ let ptr = arr.dataPointer.bindMemory(to: FloatType.self, capacity: dim)
+ values.withUnsafeBufferPointer { buf in
+ guard let bufBase = buf.baseAddress else { return }
+ ptr.update(from: bufBase, count: dim)
+ }
+ return arr
+ }
+
+ /// Create a zero-filled embedding vector.
+ /// - Parameter dim: Embedding dimension (match the actual model's embed size).
+ /// - Returns: A `[FloatType]` array of length `dim` filled with zeros.
+ public static func zeroEmbed(dim: Int = 1024) -> [FloatType] {
+ Array(repeating: FloatType(0), count: dim)
+ }
+
+ /// Pack an `[Int32]` token-id array into a flat CoreML `MLMultiArray`.
+ ///
+ /// The result has shape `[count]` and dtype `.int32`; values are copied in
+ /// one bulk pointer update rather than element by element.
+ public static func makeInt32Array(_ values: [Int32]) throws -> MLMultiArray {
+ let array = try MLMultiArray(shape: [NSNumber(value: values.count)], dataType: .int32)
+ let dst = array.dataPointer.bindMemory(to: Int32.self, capacity: values.count)
+ values.withUnsafeBufferPointer { src in
+ guard let base = src.baseAddress else { return }
+ dst.update(from: base, count: values.count)
+ }
+ return array
+ }
+
+ // MARK: - MLTensor helpers
+
+ /// Element-wise addition of two MLTensor embeddings. No data copy - deferred until materialised.
+ @available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *)
+ @inline(__always)
+ public static func addEmbeddings(_ a: MLTensor, _ b: MLTensor) -> MLTensor { a + b }
+
+ /// Sum a list of MLTensor embeddings element-wise.
+ ///
+ /// - Parameter tensors: Must be non-empty. Like `addEmbeddings`, the sum is
+ ///   lazy; nothing is materialised here.
+ /// - Returns: The element-wise sum of all input tensors.
+ @available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *)
+ public static func sumEmbeddings(_ tensors: [MLTensor]) -> MLTensor {
+ // Fail fast with a clear message instead of the opaque out-of-bounds trap
+ // that `tensors[0]` would produce on an empty list. (An MLTensor "empty sum"
+ // has no well-defined shape, so trapping is the only coherent behavior.)
+ precondition(!tensors.isEmpty, "sumEmbeddings requires at least one tensor")
+ return tensors.dropFirst().reduce(tensors[0], +)
+ }
+
+ // MARK: - Internal conversion helper
+
+ /// Platform-safe conversion from FloatType buffer to Float array.
+ /// On arm64 iOS/macOS/visionOS (Float16), uses vDSP for vectorized conversion.
+ @inline(__always)
+ static func convertToFloat(_ source: UnsafeBufferPointer, to dest: inout [Float]) {
+ #if arch(arm64) && !os(watchOS)
+ vDSP.convertElements(of: source, to: &dest)
+ #else
+ for i in 0.. convenience
+
+@available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *)
+public extension Array where Element == FloatType {
+ /// Wrap this embedding vector as a `[1, count, 1, 1]` Float16 MLTensor (one memcpy).
+ func asMLTensor() -> MLTensor {
+ let shaped = MLShapedArray(scalars: self, shape: [1, count, 1, 1])
+ return MLTensor(shaped)
+ }
+}
diff --git a/Sources/TTSKit/Utilities/KVCache.swift b/Sources/TTSKit/Utilities/KVCache.swift
new file mode 100644
index 00000000..2e5aa5ed
--- /dev/null
+++ b/Sources/TTSKit/Utilities/KVCache.swift
@@ -0,0 +1,306 @@
+// For licensing see accompanying LICENSE.md file.
+// Copyright © 2026 Argmax, Inc. All rights reserved.
+
+import ArgmaxCore
+import CoreML
+import Foundation
+
+// MARK: - KV Cache
+
+/// KV cache for autoregressive decoder models.
+///
+/// Manages external cache arrays (`keyCache`/`valueCache`) for non-stateful models,
+/// or position tracking and attention masks only for stateful models whose weights
+/// are managed internally by CoreML via `MLState`.
+///
+/// Thread safety: each `Qwen3GenerateTask` creates its own cache instance.
+/// Caches are never shared across concurrent tasks.
+public class KVCache: @unchecked Sendable {
+ public var cacheLength: Int32 = 0
+ public let maxSeqLength: Int
+ public let cacheDim: Int
+ public let isStateful: Bool
+
+ /// External key cache -- nil for stateful models
+ public let keyCache: MLMultiArray?
+ /// External value cache -- nil for stateful models
+ public let valueCache: MLMultiArray?
+ public let kvCacheUpdateMask: MLMultiArray
+ public let keyPaddingMask: MLMultiArray
+
+ /// Create a cache for a decoder with the given geometry.
+ ///
+ /// - Parameters:
+ ///   - cacheDim: Channel dimension of the key/value cache arrays.
+ ///   - maxSeqLength: Maximum sequence length (last axis of caches and masks).
+ ///   - isStateful: When true, no external key/value arrays are allocated;
+ ///     the model holds its KV cache internally via `MLState`.
+ /// - Throws: If any backing `MLMultiArray` cannot be allocated.
+ public init(cacheDim: Int, maxSeqLength: Int, isStateful: Bool = false) throws {
+ self.cacheDim = cacheDim
+ self.maxSeqLength = maxSeqLength
+ self.isStateful = isStateful
+
+ if isStateful {
+ // Stateful models manage KV cache internally via MLState
+ keyCache = nil
+ valueCache = nil
+ } else {
+ // Shape [1, cacheDim, 1, maxSeqLength], Float16 - matches the model's KV layout.
+ keyCache = try MLMultiArray(
+ shape: [1, NSNumber(value: cacheDim), 1, NSNumber(value: maxSeqLength)],
+ dataType: .float16
+ )
+ valueCache = try MLMultiArray(
+ shape: [1, NSNumber(value: cacheDim), 1, NSNumber(value: maxSeqLength)],
+ dataType: .float16
+ )
+ }
+
+ // Attention masks are always allocated, even for stateful models.
+ kvCacheUpdateMask = try MLMultiArray(
+ shape: [1, NSNumber(value: maxSeqLength)],
+ dataType: .float16
+ )
+ keyPaddingMask = try MLMultiArray(
+ shape: [1, NSNumber(value: maxSeqLength)],
+ dataType: .float16
+ )
+
+ // Zero the caches and put masks into their initial single-position state.
+ reset()
+ }
+
+ public func reset() {
+ cacheLength = 0
+
+ // Zero-fill external KV caches (stateful models don't have them)
+ if let keyCache, let valueCache {
+ memset(keyCache.dataPointer, 0, cacheDim * maxSeqLength * MemoryLayout.size)
+ memset(valueCache.dataPointer, 0, cacheDim * maxSeqLength * MemoryLayout.size)
+ }
+
+ // Initialize masks: first position active, rest masked
+ let updatePtr = kvCacheUpdateMask.dataPointer.bindMemory(to: FloatType.self, capacity: maxSeqLength)
+ let paddingPtr = keyPaddingMask.dataPointer.bindMemory(to: FloatType.self, capacity: maxSeqLength)
+ for i in 0..= maxSeqLength - 1 }
+
+ /// How many free positions remain before the cache is full
+ public var freePositions: Int { maxSeqLength - 1 - Int(cacheLength) }
+
+ /// Current cache position packed as a `[1]` Int32 `MLMultiArray` model input.
+ public func makeCacheLengthArray() throws -> MLMultiArray {
+ let lengthArray = try MLMultiArray(shape: [1], dataType: .int32)
+ lengthArray[0] = NSNumber(value: cacheLength)
+ return lengthArray
+ }
+}
+
+// MARK: - MLTensor Access
+
+@available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *)
+public extension KVCache {
+ // NOTE(review): these properties wrap the backing MLMultiArrays at call
+ // time; presumably the MLShapedArray initializer copies/snapshots the data -
+ // confirm before mutating masks while a prediction holds one of these tensors.
+ /// Cache position as a `[1]` Int32 tensor.
+ var cacheLengthTensor: MLTensor { MLTensor(shape: [1], scalars: [Int32(cacheLength)]) }
+ /// Update-mask as a `[1, maxSeqLength]` Float16 tensor.
+ var kvCacheUpdateMaskTensor: MLTensor { MLTensor(MLShapedArray(kvCacheUpdateMask)) }
+ /// Padding-mask as a `[1, maxSeqLength]` Float16 tensor.
+ var keyPaddingMaskTensor: MLTensor { MLTensor(MLShapedArray(keyPaddingMask)) }
+ /// External key-cache tensor - `nil` for stateful models.
+ var keyCacheTensor: MLTensor? { keyCache.map { MLTensor(MLShapedArray($0)) } }
+ /// External value-cache tensor - `nil` for stateful models.
+ var valueCacheTensor: MLTensor? { valueCache.map { MLTensor(MLShapedArray($0)) } }
+
+ /// Async update from MLTensor outputs - materializes without blocking the cooperative pool.
+ func update(keyTensor: MLTensor, valueTensor: MLTensor) async {
+ // Materialize both tensors to MLMultiArray, then delegate to the
+ // synchronous MLMultiArray update path.
+ let keyArr = await keyTensor.toMLMultiArray()
+ let valArr = await valueTensor.toMLMultiArray()
+ update(keyCacheUpdates: keyArr, valueCacheUpdates: valArr)
+ }
+}
+
+// MARK: - Speech Decoder Cache
+
+/// Extended KV cache for SpeechDecoder with a rolling hidden context buffer.
+///
+/// The hidden context window length varies by quantization variant
+/// (e.g. 4 for W8A16, 16 for W16A16) and is read from the model at load time.
+/// Task-local, never shared across concurrent tasks.
+public class SpeechDecoderCache: KVCache, @unchecked Sendable {
+ public let hiddenContext: MLMultiArray // [1, hiddenDim, 1, contextLen]
+ public let hiddenDim: Int
+ public let hiddenContextLen: Int
+
+ /// Create a speech-decoder cache with a rolling hidden-context buffer.
+ ///
+ /// - Parameters:
+ ///   - cacheDim: KV cache channel dimension.
+ ///   - maxSeqLength: Maximum KV sequence length.
+ ///   - hiddenDim: Hidden-state channel dimension of the context buffer.
+ ///   - hiddenContextLen: Window length of the rolling hidden context.
+ /// - Throws: If the backing `MLMultiArray` cannot be allocated.
+ public init(
+ cacheDim: Int = Qwen3TTSConstants.sdCacheDim,
+ maxSeqLength: Int = Qwen3TTSConstants.sdMaxSeq,
+ hiddenDim: Int = Qwen3TTSConstants.sdHiddenDim,
+ hiddenContextLen: Int = Qwen3TTSConstants.sdHiddenContextLen
+ ) throws {
+ self.hiddenDim = hiddenDim
+ self.hiddenContextLen = hiddenContextLen
+ hiddenContext = try MLMultiArray(
+ shape: [1, NSNumber(value: hiddenDim), 1, NSNumber(value: hiddenContextLen)],
+ dataType: .float16
+ )
+ // Zero-fill; FloatType is the element type backing the .float16 array,
+ // so the byte count is elementCount * MemoryLayout<FloatType>.size.
+ memset(hiddenContext.dataPointer, 0, hiddenDim * hiddenContextLen * MemoryLayout<FloatType>.size)
+ // Speech decoder is currently non-stateful
+ try super.init(cacheDim: cacheDim, maxSeqLength: maxSeqLength, isStateful: false)
+ }
+
+ /// Update KV cache and roll hidden context left, appending new hidden state
+ public func updateWithHiddenContext(output: MLFeatureProvider) {
+ guard let keyCU = output.featureValue(for: "key_cache_updates")?.multiArrayValue,
+ let valCU = output.featureValue(for: "value_cache_updates")?.multiArrayValue
+ else {
+ return
+ }
+ super.update(keyCacheUpdates: keyCU, valueCacheUpdates: valCU)
+
+ // Roll hidden context left and append new state
+ let hidDim = hiddenDim
+ let contextLen = hiddenContextLen
+ let hiddenContextPtr = hiddenContext.dataPointer.bindMemory(to: FloatType.self, capacity: hidDim * contextLen)
+ guard let updateArr = output.featureValue(for: "hidden_context_update")?.multiArrayValue else { return }
+ let hiddenUpdatePtr = updateArr.dataPointer.bindMemory(to: FloatType.self, capacity: updateArr.count)
+ let hiddenUpdateStride = updateArr.strides[1].intValue
+
+ for dim in 0..(hiddenContext)) }
+
+ /// Update KV cache and rolling hidden context from `[String: MLTensor]` prediction outputs.
+ /// Materializes tensors asynchronously to avoid blocking the cooperative thread pool.
+ func updateWithHiddenContext(tensorOutputs: [String: MLTensor]) async {
+ guard let keyUpdateTensor = tensorOutputs["key_cache_updates"],
+ let valueUpdateTensor = tensorOutputs["value_cache_updates"]
+ else {
+ return
+ }
+ await super.update(keyTensor: keyUpdateTensor, valueTensor: valueUpdateTensor)
+
+ let hidDim = hiddenDim
+ let contextLen = hiddenContextLen
+ let ctxPtr = hiddenContext.dataPointer.bindMemory(to: FloatType.self, capacity: hidDim * contextLen)
+ guard let hiddenUpdateTensor = tensorOutputs["hidden_context_update"] else { return }
+ let updateArr = await hiddenUpdateTensor.toMLMultiArray()
+ let updatePtr = updateArr.dataPointer.bindMemory(to: FloatType.self, capacity: updateArr.count)
+ let updateStride = updateArr.strides[1].intValue
+ for dim in 0...size
+
+ let keyUpdatePtr = keyCacheUpdates.dataPointer.bindMemory(to: FloatType.self, capacity: keyCacheUpdates.count)
+ let keyUpdateStride = keyCacheUpdates.strides[1].intValue
+ let valueUpdatePtr = valueCacheUpdates.dataPointer.bindMemory(to: FloatType.self, capacity: valueCacheUpdates.count)
+ let valueUpdateStride = valueCacheUpdates.strides[1].intValue
+
+ state.withMultiArray(for: "self_attn_key_cache") { keyStateCache in
+ let embedDim = keyStateCache.shape[1].intValue
+ keyStateCache.withUnsafeMutableBytes { cachePtr, cacheStrides in
+ guard let baseAddress = cachePtr.baseAddress else { return }
+ for dim in 0.. Bool {
+ self.voice == voice && self.language == language && self.instruction == instruction
+ }
+}
+
+// MARK: - KV Cache Snapshot
+
+/// Serializable snapshot of a `KVCache` state (masks, position counter, and
+/// optionally key/value cache data for non-stateful models).
+public struct KVCacheSnapshot: Sendable {
+ /// Mirrors `KVCache.isStateful` at snapshot time.
+ public let isStateful: Bool
+ /// Cache geometry - restore requires an exact match on both fields.
+ public let cacheDim: Int
+ public let maxSeqLength: Int
+ /// Number of occupied cache positions at snapshot time.
+ public let cacheLength: Int32
+
+ /// Raw bytes from keyCache/valueCache (non-stateful models).
+ /// Empty `Data()` for stateful models where KV data lives in MLState.
+ public let keyCacheData: Data
+ public let valueCacheData: Data
+
+ // Raw bytes of the kvCacheUpdateMask / keyPaddingMask arrays.
+ public let updateMaskData: Data
+ public let paddingMaskData: Data
+}
+
+/// Serializable snapshot of MLState KV buffers (stateful models only).
+public struct KVStateData: Sendable {
+ // Raw bytes of the model's internal `self_attn_key_cache` / `self_attn_value_cache` buffers.
+ public let keyData: Data
+ public let valueData: Data
+}
+
+// MARK: - KVCache Snapshot/Restore
+
+public extension KVCache {
+ /// Create a serializable snapshot of the current cache state.
+ ///
+ /// Non-stateful caches have their key/value buffers copied out wholesale;
+ /// stateful caches contribute empty `Data` (their KV data lives in `MLState`).
+ func snapshot() -> KVCacheSnapshot {
+ // All masks/caches are .float16 arrays, so the element size is
+ // MemoryLayout<FloatType>.size (the generic argument was required here).
+ let maskBytes = maxSeqLength * MemoryLayout<FloatType>.size
+
+ return KVCacheSnapshot(
+ isStateful: isStateful,
+ cacheDim: cacheDim,
+ maxSeqLength: maxSeqLength,
+ cacheLength: cacheLength,
+ keyCacheData: keyCache.map { Data(bytes: $0.dataPointer, count: cacheDim * maxSeqLength * MemoryLayout<FloatType>.size) } ?? Data(),
+ valueCacheData: valueCache.map { Data(bytes: $0.dataPointer, count: cacheDim * maxSeqLength * MemoryLayout<FloatType>.size) } ?? Data(),
+ updateMaskData: Data(bytes: kvCacheUpdateMask.dataPointer, count: maskBytes),
+ paddingMaskData: Data(bytes: keyPaddingMask.dataPointer, count: maskBytes)
+ )
+ }
+
+ /// Restore cache state from a snapshot. The snapshot must have matching geometry.
+ ///
+ /// - Parameter snapshot: A snapshot produced by `snapshot()` on a cache with
+ ///   the same `cacheDim`/`maxSeqLength`; a mismatch is a programmer error.
+ func restore(from snapshot: KVCacheSnapshot) {
+ precondition(
+ snapshot.cacheDim == cacheDim && snapshot.maxSeqLength == maxSeqLength,
+ "Cache geometry mismatch: expected (\(cacheDim), \(maxSeqLength)), got (\(snapshot.cacheDim), \(snapshot.maxSeqLength))"
+ )
+
+ cacheLength = snapshot.cacheLength
+
+ let maskBytes = maxSeqLength * MemoryLayout<FloatType>.size
+
+ // Key/value bytes are only present for non-stateful caches.
+ if let keyCache, let valueCache, !snapshot.keyCacheData.isEmpty {
+ let kvBytes = cacheDim * maxSeqLength * MemoryLayout<FloatType>.size
+ // min() guards against a snapshot serialized with a different byte count.
+ snapshot.keyCacheData.copyBytes(to: keyCache.dataPointer.assumingMemoryBound(to: UInt8.self), count: min(snapshot.keyCacheData.count, kvBytes))
+ snapshot.valueCacheData.copyBytes(to: valueCache.dataPointer.assumingMemoryBound(to: UInt8.self), count: min(snapshot.valueCacheData.count, kvBytes))
+ }
+
+ snapshot.updateMaskData.copyBytes(to: kvCacheUpdateMask.dataPointer.assumingMemoryBound(to: UInt8.self), count: min(snapshot.updateMaskData.count, maskBytes))
+ snapshot.paddingMaskData.copyBytes(to: keyPaddingMask.dataPointer.assumingMemoryBound(to: UInt8.self), count: min(snapshot.paddingMaskData.count, maskBytes))
+ }
+}
+
+// MARK: - MLState Snapshot/Restore
+
+@available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *)
+public extension MLState {
+ /// Capture the current `self_attn_key_cache` and `self_attn_value_cache` buffers as raw `Data`.
+ func snapshot() -> KVStateData {
+ KVStateData(
+ keyData: copyStateBuffer(named: "self_attn_key_cache"),
+ valueData: copyStateBuffer(named: "self_attn_value_cache")
+ )
+ }
+
+ /// Overwrite the `self_attn_key_cache` and `self_attn_value_cache` buffers from a snapshot.
+ func restore(from data: KVStateData) {
+ writeStateBuffer(named: "self_attn_key_cache", from: data.keyData)
+ writeStateBuffer(named: "self_attn_value_cache", from: data.valueData)
+ }
+
+ /// Copy the named state buffer's raw bytes into a fresh `Data` value.
+ private func copyStateBuffer(named name: String) -> Data {
+ var copied = Data()
+ withMultiArray(for: name) { array in
+ array.withUnsafeMutableBytes { bytes, _ in
+ guard let base = bytes.baseAddress else { return }
+ copied = Data(bytes: base, count: bytes.count)
+ }
+ }
+ return copied
+ }
+
+ /// Overwrite the named state buffer from `source`, clamped to the buffer's size.
+ private func writeStateBuffer(named name: String, from source: Data) {
+ withMultiArray(for: name) { array in
+ array.withUnsafeMutableBytes { bytes, _ in
+ guard let dstBase = bytes.baseAddress else { return }
+ source.withUnsafeBytes { src in
+ guard let srcBase = src.baseAddress else { return }
+ dstBase.copyMemory(from: srcBase, byteCount: min(src.count, bytes.count))
+ }
+ }
+ }
+ }
+}
+
+// MARK: - Disk Persistence
+
+public extension TTSPromptCache {
+ /// Save this prompt cache to disk as a property list.
+ ///
+ /// The file can be reloaded with `TTSPromptCache.load(from:)` as long as
+ /// the model variant and cache geometry haven't changed.
+ /// - Parameter url: Destination file URL; intermediate directories are created as needed.
+ /// - Throws: Encoding or file-system errors.
+ func save(to url: URL) throws {
+ let container = CacheContainer(
+ voice: voice,
+ language: language,
+ instruction: instruction,
+ prefixLength: prefixLength,
+ isStateful: kvSnapshot.isStateful,
+ cacheDim: kvSnapshot.cacheDim,
+ maxSeqLength: kvSnapshot.maxSeqLength,
+ cacheLength: kvSnapshot.cacheLength,
+ keyCacheData: kvSnapshot.keyCacheData,
+ valueCacheData: kvSnapshot.valueCacheData,
+ updateMaskData: kvSnapshot.updateMaskData,
+ paddingMaskData: kvSnapshot.paddingMaskData,
+ stateKeyData: stateData?.keyData,
+ stateValueData: stateData?.valueData
+ )
+ let encoder = PropertyListEncoder()
+ // Binary plists store Data fields as raw bytes. The encoder's default XML
+ // format would base64-encode the (potentially multi-megabyte) KV buffers,
+ // inflating the file ~33% and slowing encode/decode. PropertyListDecoder
+ // auto-detects the format, so previously written XML caches still load.
+ encoder.outputFormat = .binary
+ let data = try encoder.encode(container)
+ try FileManager.default.createDirectory(at: url.deletingLastPathComponent(), withIntermediateDirectories: true)
+ try data.write(to: url)
+ }
+
+ /// Load a previously saved prompt cache from disk.
+ ///
+ /// - Parameter url: File previously written by `save(to:)`.
+ /// - Throws: Read or property-list decoding errors.
+ static func load(from url: URL) throws -> TTSPromptCache {
+ let raw = try Data(contentsOf: url)
+ let container = try PropertyListDecoder().decode(CacheContainer.self, from: raw)
+
+ let kvSnapshot = KVCacheSnapshot(
+ isStateful: container.isStateful,
+ cacheDim: container.cacheDim,
+ maxSeqLength: container.maxSeqLength,
+ cacheLength: container.cacheLength,
+ keyCacheData: container.keyCacheData,
+ valueCacheData: container.valueCacheData,
+ updateMaskData: container.updateMaskData,
+ paddingMaskData: container.paddingMaskData
+ )
+
+ // MLState buffers are present only for stateful-model caches.
+ var state: KVStateData? = nil
+ if let keyData = container.stateKeyData, let valueData = container.stateValueData {
+ state = KVStateData(keyData: keyData, valueData: valueData)
+ }
+
+ return TTSPromptCache(
+ voice: container.voice,
+ language: container.language,
+ instruction: container.instruction,
+ prefixLength: container.prefixLength,
+ kvSnapshot: kvSnapshot,
+ stateData: state
+ )
+ }
+
+ /// File name for this cache based on voice/language/instruction.
+ ///
+ /// A djb2-style hash of the instruction bytes (when present) keeps distinct
+ /// instructions in distinct files without embedding arbitrary text in the name.
+ var cacheFileName: String {
+ var stem = "\(voice)_\(language)"
+ if let instruction, !instruction.isEmpty {
+ var hash: UInt64 = 5381
+ for byte in instruction.utf8 {
+ hash = hash &* 33 &+ UInt64(byte)
+ }
+ stem += "_\(String(hash, radix: 16))"
+ }
+ return stem + ".promptcache"
+ }
+}
+
+/// Codable container mirroring `TTSPromptCache` + `KVCacheSnapshot` fields
+/// for property-list serialization (see `save(to:)` / `load(from:)`).
+private struct CacheContainer: Codable {
+ // Prompt identity.
+ let voice: String
+ let language: String
+ let instruction: String?
+ let prefixLength: Int
+ // KV snapshot geometry and counters.
+ let isStateful: Bool
+ let cacheDim: Int
+ let maxSeqLength: Int
+ let cacheLength: Int32
+ // Raw snapshot buffers (empty for stateful models' key/value data).
+ let keyCacheData: Data
+ let valueCacheData: Data
+ let updateMaskData: Data
+ let paddingMaskData: Data
+ // MLState buffers; nil for non-stateful models.
+ let stateKeyData: Data?
+ let stateValueData: Data?
+}
diff --git a/Sources/TTSKit/Utilities/Sampling.swift b/Sources/TTSKit/Utilities/Sampling.swift
new file mode 100644
index 00000000..69ebfa97
--- /dev/null
+++ b/Sources/TTSKit/Utilities/Sampling.swift
@@ -0,0 +1,384 @@
+// For licensing see accompanying LICENSE.md file.
+// Copyright © 2026 Argmax, Inc. All rights reserved.
+
+import Accelerate
+import ArgmaxCore
+import CoreML
+import Foundation
+
+// MARK: - Token Sampling
+
+/// Protocol for TTS token sampling strategies.
+public protocol TokenSampling {
+ /// Sample a single token from codec-0 logits.
+ /// `logits` is an MLTensor (macOS 15+ async path) or MLMultiArray (sync path).
+ func sampleCodec0(
+ logits: any EmbedTensorType,
+ temperature: Float,
+ topK: Int,
+ generatedTokens: [Int32],
+ repetitionPenalty: Float,
+ suppressTokenIds: Set
+ ) async -> Int32
+
+ /// Sample a single token from one head of multi-code logits.
+ /// `allLogits` is an MLTensor (macOS 15+ async path) or MLMultiArray (sync path).
+ func sampleMultiHead(
+ allLogits: any EmbedTensorType,
+ headIndex: Int,
+ temperature: Float,
+ topK: Int
+ ) async -> Int32
+}
+
+// MARK: - Greedy / Top-k Sampler
+
+/// Greedy / top-k / temperature token sampler with a seedable RNG.
+///
+/// Thread safety: each `Qwen3GenerateTask` owns its own `GreedyTokenSampler`
+/// instance (created per-task in `TTSKit.createTask()` with a derived seed).
+/// The `var rng` is never accessed concurrently because it's single-owner.
+public class GreedyTokenSampler: TokenSampling, @unchecked Sendable {
+ private var rng: any RandomNumberGenerator
+
+ /// Create a sampler with an optional seed for reproducibility.
+ /// - Parameter seed: If provided, uses a deterministic RNG. If nil, uses system RNG.
+ public init(seed: UInt64? = nil) {
+ switch seed {
+ case .some(let value):
+ rng = SeededRandomNumberGenerator(seed: value)
+ case .none:
+ rng = SystemRandomNumberGenerator()
+ }
+ }
+
+ public func sampleCodec0(
+ logits: any EmbedTensorType,
+ temperature: Float,
+ topK: Int,
+ generatedTokens: [Int32],
+ repetitionPenalty: Float,
+ suppressTokenIds: Set
+ ) async -> Int32 {
+ if #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *) {
+ let tensor: MLTensor
+ let vocabSize: Int
+ if let logitsTensor = logits as? MLTensor {
+ vocabSize = logitsTensor.shape.last ?? 0
+ tensor = logitsTensor
+ } else {
+ guard let logitsArray = logits as? MLMultiArray else {
+ return Int32(Qwen3TTSConstants.codecBOS)
+ }
+ vocabSize = logitsArray.shape.last?.intValue ?? 0
+ tensor = MLTensor(MLShapedArray(logitsArray))
+ }
+ return await sampleCodec0WithMLTensor(
+ logitsTensor: tensor,
+ vocabSize: vocabSize,
+ temperature: temperature,
+ topK: topK,
+ generatedTokens: generatedTokens,
+ repetitionPenalty: repetitionPenalty,
+ suppressTokenIds: suppressTokenIds
+ )
+ }
+ guard let logitsArray = logits as? MLMultiArray else {
+ return Int32(Qwen3TTSConstants.codecBOS)
+ }
+ return sampleCodec0WithVDSP(
+ logits: logitsArray,
+ temperature: temperature,
+ topK: topK,
+ generatedTokens: generatedTokens,
+ repetitionPenalty: repetitionPenalty,
+ suppressTokenIds: suppressTokenIds
+ )
+ }
+
+ public func sampleMultiHead(
+ allLogits: any EmbedTensorType,
+ headIndex: Int,
+ temperature: Float,
+ topK: Int
+ ) async -> Int32 {
+ if #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *),
+ let tensor = allLogits as? MLTensor
+ {
+ // Extract a single head lazily via gathering - the full [1, 15, vocabSize] tensor
+ // stays on device; only the top-k results (~400 B) are downloaded to CPU.
+ return await sampleMultiHeadFromTensor(tensor, headIndex: headIndex, temperature: temperature, topK: topK)
+ }
+
+ // MLMultiArray path: pointer arithmetic for zero-copy head extraction
+ guard let allLogitsArray = allLogits as? MLMultiArray else {
+ return Int32(Qwen3TTSConstants.codecBOS)
+ }
+ let vocabSize = allLogitsArray.shape[2].intValue
+ let ptr = allLogitsArray.dataPointer.bindMemory(to: FloatType.self, capacity: allLogitsArray.count)
+ let stride1 = allLogitsArray.strides[1].intValue
+ let stride2 = allLogitsArray.strides[2].intValue
+ var logitsF = [Float](repeating: 0, count: vocabSize)
+ let base = headIndex * stride1
+
+ if stride2 == 1 {
+ let src = UnsafeBufferPointer(start: ptr.advanced(by: base), count: vocabSize)
+ EmbedUtilities.convertToFloat(src, to: &logitsF)
+ } else {
+ for i in 0..
+ ) -> Int32 {
+ let vocabSize = logits.shape.last?.intValue ?? 0
+ var logitsF = extractFloat32Logits(logits, count: vocabSize)
+ for id in suppressTokenIds where id < vocabSize {
+ logitsF[id] = -.infinity
+ }
+ if repetitionPenalty != 1.0 && !generatedTokens.isEmpty {
+ for tokenId in Set(generatedTokens) {
+ let tokenIndex = Int(tokenId)
+ guard tokenIndex < vocabSize else { continue }
+ logitsF[tokenIndex] = logitsF[tokenIndex] > 0 ? logitsF[tokenIndex] / repetitionPenalty : logitsF[tokenIndex] * repetitionPenalty
+ }
+ }
+ return sampleFromLogits(logitsF, temperature: temperature, topK: topK)
+ }
+
+ /// MLTensor-based codec-0 sampler (macOS 15+).
+ /// `logitsTensor` arrives directly from the model output - no MLMultiArray conversion needed.
+ /// Uses `MLTensor.topK()` - O(n) partial selection - instead of vDSP_vsort's O(n log n) full sort.
+ @available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *)
+ private func sampleCodec0WithMLTensor(
+ logitsTensor: MLTensor,
+ vocabSize: Int,
+ temperature: Float,
+ topK: Int,
+ generatedTokens: [Int32],
+ repetitionPenalty: Float,
+ suppressTokenIds: Set
+ ) async -> Int32 {
+ let needsCPUPass = !suppressTokenIds.isEmpty || (repetitionPenalty != 1.0 && !generatedTokens.isEmpty)
+
+ let processedTensor: MLTensor
+ if needsCPUPass {
+ // Materialise to Float32, apply scalar modifications, re-wrap as [1, vocabSize] tensor.
+ var logitsF = await logitsTensor.reshaped(to: [1, vocabSize]).cast(to: Float.self).toFloatArray()
+ for id in suppressTokenIds where id < vocabSize {
+ logitsF[id] = -.infinity
+ }
+ if repetitionPenalty != 1.0 && !generatedTokens.isEmpty {
+ for tokenId in Set(generatedTokens) {
+ let tokenIndex = Int(tokenId)
+ guard tokenIndex < vocabSize else { continue }
+ logitsF[tokenIndex] = logitsF[tokenIndex] > 0 ? logitsF[tokenIndex] / repetitionPenalty : logitsF[tokenIndex] * repetitionPenalty
+ }
+ }
+ processedTensor = MLTensor(shape: [1, vocabSize], scalars: logitsF, scalarType: Float.self)
+ } else {
+ // Fully lazy path: cast + reshape stay on device until argmax/topK materializes them.
+ // [1, vocabSize] shape is required - argmax on a 1D tensor yields a 0D scalar.
+ processedTensor = logitsTensor.reshaped(to: [1, vocabSize]).cast(to: Float.self)
+ }
+
+ if temperature == 0 {
+ return Int32(await processedTensor.argmax(alongAxis: -1).toIntArray()[0])
+ }
+
+ let probs = (processedTensor / temperature).softmax(alongAxis: -1)
+ return await sampleFromProbs(probs, vocabSize: vocabSize, topK: topK)
+ }
+
+ /// Extract a single head from the multi-code logit tensor and sample from it (macOS 15+).
+ /// The full [1, numHeads, vocabSize] tensor stays on device; `gathering` selects the head
+ /// lazily so that only the top-k results (~400 B) need to be downloaded to CPU.
+ @available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *)
+ private func sampleMultiHeadFromTensor(
+ _ allLogits: MLTensor,
+ headIndex: Int,
+ temperature: Float,
+ topK: Int
+ ) async -> Int32 {
+ let vocabSize = allLogits.shape[2]
+ // gathering along axis 1: [1, numHeads, vocabSize] -> [1, 1, vocabSize]
+ let headLogits = allLogits.gathering(
+ atIndices: MLTensor(shape: [1], scalars: [Int32(headIndex)], scalarType: Int32.self),
+ alongAxis: 1
+ )
+ if temperature == 0 {
+ return Int32(await headLogits.cast(to: Float.self).argmax(alongAxis: -1).toIntArray()[0])
+ }
+ let probs = (headLogits.cast(to: Float.self) / temperature).softmax(alongAxis: -1)
+ return await sampleFromProbs(probs, vocabSize: vocabSize, topK: topK)
+ }
+
+ /// Shared topK multinomial sampler over an already-softmaxed probability MLTensor.
+ @available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *)
+ private func sampleFromProbs(_ probs: MLTensor, vocabSize: Int, topK: Int) async -> Int32 {
+ if topK > 0 && topK < vocabSize {
+ // Partial selection: O(n) vs O(n log n) full sort
+ let (topKProbs, topKIndices) = probs.topK(topK)
+ let probsArray = await topKProbs.toFloatArray()
+ let idxArray = await topKIndices.toIntArray()
+ let probSum = probsArray.reduce(0, +)
+ let randomValue = Float.random(in: 0..= randomValue { return Int32(idxArray[i]) }
+ }
+ return idxArray.last.map(Int32.init) ?? Int32(vocabSize - 1)
+ } else {
+ let probsArray = await probs.toFloatArray()
+ let randomValue = Float.random(in: 0..<1, using: &rng)
+ var cumulativeSum: Float = 0
+ for (i, probability) in probsArray.enumerated() {
+ cumulativeSum += probability
+ if cumulativeSum >= randomValue { return Int32(i) }
+ }
+ return Int32(vocabSize - 1)
+ }
+ }
+
+ /// MLTensor sampling from a pre-extracted Float32 logits array (macOS 15+).
+ @available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *)
+ private func sampleFromLogitsWithMLTensor(_ logits: [Float], temperature: Float, topK: Int) async -> Int32 {
+ let vocabSize = logits.count
+
+ // Greedy path: a plain vDSP argmax, no tensor round-trip needed.
+ guard temperature != 0 else {
+ var peak: Float = 0
+ var peakIndex: vDSP_Length = 0
+ vDSP_maxvi(logits, 1, &peak, &peakIndex, vDSP_Length(vocabSize))
+ return Int32(peakIndex)
+ }
+
+ // [1, vocabSize] keeps topK/softmax results at least 2D so toFloatArray()/toIntArray() are safe
+ let logitsTensor = MLTensor(shape: [1, vocabSize], scalars: logits, scalarType: Float.self)
+ let scaled = logitsTensor / temperature
+ return await sampleFromProbs(scaled.softmax(alongAxis: -1), vocabSize: vocabSize, topK: topK)
+ }
+
+ // MARK: - Private helpers
+
+ private func extractFloat32Logits(_ arr: MLMultiArray, count: Int) -> [Float] {
+ let ptr = arr.dataPointer.bindMemory(to: FloatType.self, capacity: arr.count)
+ let lastStride = arr.strides.last?.intValue ?? 1
+ var result = [Float](repeating: 0, count: count)
+ if lastStride == 1 {
+ let src = UnsafeBufferPointer(start: ptr, count: count)
+ EmbedUtilities.convertToFloat(src, to: &result)
+ } else {
+ for i in 0.. Int32 {
+ var mutableLogits = logits
+ let vocabSize = mutableLogits.count
+
+ if temperature == 0 {
+ var maxValue: Float = 0
+ var maxIndex: vDSP_Length = 0
+ vDSP_maxvi(mutableLogits, 1, &maxValue, &maxIndex, vDSP_Length(vocabSize))
+ return Int32(maxIndex)
+ }
+
+ var temperatureScalar = temperature
+ vDSP_vsdiv(mutableLogits, 1, &temperatureScalar, &mutableLogits, 1, vDSP_Length(vocabSize))
+
+ if topK > 0 && topK < vocabSize {
+ var sorted = mutableLogits
+ vDSP_vsort(&sorted, vDSP_Length(vocabSize), -1)
+ let threshold = sorted[topK - 1]
+ for i in 0.. 0 {
+ vDSP_vsdiv(mutableLogits, 1, &sum, &mutableLogits, 1, vDSP_Length(vocabSize))
+ }
+
+ let randomValue = Float.random(in: 0..<1, using: &rng)
+ var cumulativeSum: Float = 0
+ for i in 0..= randomValue { return Int32(i) }
+ }
+ return Int32(vocabSize - 1)
+ }
+}
+
+// MARK: - Seedable RNG
+
+/// A seedable random number generator using the xoshiro256** algorithm.
+/// Produces deterministic sequences for a given seed.
+public struct SeededRandomNumberGenerator: RandomNumberGenerator {
+ private var state: (UInt64, UInt64, UInt64, UInt64)
+
+ /// splitmix64-style finalizer used to expand the seed into state words.
+ private static func mix(_ value: UInt64) -> UInt64 {
+ var z = value
+ z = (z ^ (z >> 30)) &* 0xBF58_476D_1CE4_E5B9
+ z = (z ^ (z >> 27)) &* 0x94D0_49BB_1331_11EB
+ return z ^ (z >> 31)
+ }
+
+ public init(seed: UInt64) {
+ // Derive four state words from the seed at increasing multiples of the
+ // golden-ratio increment, each passed through the mixer above.
+ let increment: UInt64 = 0x9E37_79B9_7F4A_7C15
+ state = (
+ Self.mix(seed &+ increment),
+ Self.mix(seed &+ 2 &* increment),
+ Self.mix(seed &+ 3 &* increment),
+ Self.mix(seed &+ 4 &* increment)
+ )
+ }
+
+ /// One xoshiro256** step: scrambled output from state.1, then state transition.
+ public mutating func next() -> UInt64 {
+ let output = rotl(state.1 &* 5, 7) &* 9
+ let t = state.1 << 17
+ state.2 ^= state.0
+ state.3 ^= state.1
+ state.1 ^= state.2
+ state.0 ^= state.3
+ state.2 ^= t
+ state.3 = rotl(state.3, 45)
+ return output
+ }
+
+ /// Left-rotate `value` by `amount` bits (0 < amount < 64 at all call sites).
+ private func rotl(_ value: UInt64, _ amount: Int) -> UInt64 {
+ (value << amount) | (value >> (64 - amount))
+ }
+}
diff --git a/Sources/TTSKit/Utilities/TTSError.swift b/Sources/TTSKit/Utilities/TTSError.swift
new file mode 100644
index 00000000..19a0391e
--- /dev/null
+++ b/Sources/TTSKit/Utilities/TTSError.swift
@@ -0,0 +1,27 @@
+// For licensing see accompanying LICENSE.md file.
+// Copyright © 2026 Argmax, Inc. All rights reserved.
+
+import Foundation
+
+@frozen
+/// Errors thrown by TTSKit's synthesis pipeline. `Equatable` so tests can
+/// assert specific failures; `LocalizedError` for user-presentable messages.
+public enum TTSError: Error, LocalizedError, Equatable {
+ /// The input text was empty after preprocessing.
+ case emptyText
+ /// The model directory at the associated path does not exist.
+ case modelNotFound(String)
+ /// Audio generation failed; the associated value describes the cause.
+ case generationFailed(String)
+ /// The tokenizer could not be loaded or used; details in the associated value.
+ case tokenizerUnavailable(String)
+ /// Writing or playing the synthesized audio failed; details in the associated value.
+ case audioOutputFailed(String)
+ /// A required component URL could not be resolved from the config.
+ /// Thrown at `loadModels()` time so callers fail early with a clear message.
+ case invalidConfiguration(String)
+
+ /// Human-readable description used by `LocalizedError`.
+ public var errorDescription: String? {
+ switch self {
+ case .emptyText: return "Input text is empty"
+ case let .modelNotFound(path): return "Model directory not found: \(path)"
+ case let .generationFailed(msg): return "Generation failed: \(msg)"
+ case let .tokenizerUnavailable(msg): return "Tokenizer unavailable: \(msg)"
+ case let .audioOutputFailed(msg): return "Audio output failed: \(msg)"
+ case let .invalidConfiguration(msg): return "Invalid TTSKit configuration: \(msg)"
+ }
+ }
+}
diff --git a/Sources/TTSKit/Utilities/TextChunker.swift b/Sources/TTSKit/Utilities/TextChunker.swift
new file mode 100644
index 00000000..b5249340
--- /dev/null
+++ b/Sources/TTSKit/Utilities/TextChunker.swift
@@ -0,0 +1,115 @@
+// For licensing see accompanying LICENSE.md file.
+// Copyright © 2026 Argmax, Inc. All rights reserved.
+
+import Foundation
+import Tokenizers
+
+/// Strategy for splitting long text into chunks.
+@frozen
+public enum TextChunkingStrategy: String, Codable, CaseIterable, Sendable {
+ /// No chunking, generate the full text in a single pass.
+ case none
+ /// Split at sentence boundaries using TextChunker.
+ case sentence
+}
+
+/// Splits long text into sentence-bounded chunks suitable for chunked TTS generation.
+///
+/// Uses natural sentence boundaries (. ! ? and newlines) to avoid splitting mid-sentence,
+/// with a fallback to clause boundaries (, ; :) and then word boundaries for very long sentences.
+///
+/// Algorithm: tokenize the full text, decode windows of `targetChunkSize` tokens,
+/// find the last natural boundary in each decoded window, then re-tokenize the accepted
+/// chunk to advance by the exact token count - no character estimation needed.
+///
+/// `defaultTargetChunkSize` / `defaultMinChunkSize` are the single source of truth for
+/// these defaults - all other call sites (TTSGenerationOptions, SpeakAX) reference them.
+public struct TextChunker {
+ /// Default target chunk size (tokens).
+ public static let defaultTargetChunkSize: Int = 42
+
+ /// Default minimum chunk size (tokens).
+ public static let defaultMinChunkSize: Int = 10
+
+ /// Target chunk size in tokens. Actual chunks may be shorter (sentence boundary)
+ /// or slightly longer (no boundary found within the window).
+ public let targetChunkSize: Int
+
+ /// Minimum chunk size in tokens. Short trailing segments are merged into
+ /// the previous chunk to avoid tiny segments with poor prosody.
+ public let minChunkSize: Int
+
+ // Tokenizer round-trip captured as closures so the struct does not retain
+ // the tokenizer type directly and the test init can inject plain functions.
+ private let encodeText: (String) -> [Int]
+ private let decodeTokens: ([Int]) -> String
+
+ /// Creates a chunker backed by a real tokenizer.
+ /// - Parameters:
+ ///   - targetChunkSize: Token budget per chunk (default `defaultTargetChunkSize`).
+ ///   - minChunkSize: Smallest acceptable chunk before merging (default `defaultMinChunkSize`).
+ ///   - tokenizer: Supplies `encode(text:)` / `decode(tokens:)` for budgeting.
+ public init(
+ targetChunkSize: Int = TextChunker.defaultTargetChunkSize,
+ minChunkSize: Int = TextChunker.defaultMinChunkSize,
+ tokenizer: any Tokenizer
+ ) {
+ self.targetChunkSize = targetChunkSize
+ self.minChunkSize = minChunkSize
+ self.encodeText = { tokenizer.encode(text: $0) }
+ self.decodeTokens = { tokenizer.decode(tokens: $0) }
+ }
+
+ /// Internal init for testing - accepts raw encode/decode functions so tests
+ /// don't need a heavyweight `Tokenizer` conformance.
+ init(
+ targetChunkSize: Int = TextChunker.defaultTargetChunkSize,
+ minChunkSize: Int = TextChunker.defaultMinChunkSize,
+ encode: @escaping (String) -> [Int],
+ decode: @escaping ([Int]) -> String
+ ) {
+ self.targetChunkSize = targetChunkSize
+ self.minChunkSize = minChunkSize
+ self.encodeText = encode
+ self.decodeTokens = decode
+ }
+
+ /// Split text into chunks at natural boundaries, respecting the token budget.
+ /// Returns an array of non-empty text chunks.
+ public func chunk(_ text: String) -> [String] {
+ let trimmed = text.trimmingCharacters(in: .whitespacesAndNewlines)
+ guard !trimmed.isEmpty else { return [] }
+
+ var tokens = encodeText(trimmed)
+ // Fast path: text already fits in a single chunk.
+ guard tokens.count > targetChunkSize else { return [trimmed] }
+
+ var chunks: [String] = []
+
+ while !tokens.isEmpty {
+ // Final window: everything that's left fits the budget.
+ if tokens.count <= targetChunkSize {
+ let tail = decodeTokens(tokens)
+ .trimmingCharacters(in: .whitespacesAndNewlines)
+ if !tail.isEmpty {
+ // Merge tiny trailing segment with previous chunk to avoid poor prosody
+ if encodeText(tail).count < minChunkSize, let last = chunks.last {
+ chunks[chunks.count - 1] = last + " " + tail
+ } else {
+ chunks.append(tail)
+ }
+ }
+ break
+ }
+
+ // Decode a window of targetChunkSize tokens, then find the best boundary within it
+ let window = Array(tokens.prefix(targetChunkSize))
+ let windowText = decodeTokens(window)
+ .trimmingCharacters(in: .whitespacesAndNewlines)
+
+ // NOTE(review): `lastNaturalBoundary(minTokenCount:encode:)` is a String
+ // extension defined outside this struct - confirm it lives in this file.
+ // Falls back to the whole window when no boundary qualifies.
+ let accepted = windowText.lastNaturalBoundary(minTokenCount: minChunkSize, encode: encodeText) ?? windowText
+
+ if !accepted.isEmpty {
+ chunks.append(accepted)
+ }
+
+ // Re-tokenize the accepted text to advance by its exact token count,
+ // avoiding drift from imperfect encode/decode round-trips
+ // (`max(consumed, 1)` guarantees forward progress, so the loop terminates).
+ let consumed = encodeText(accepted).count
+ tokens.removeFirst(min(max(consumed, 1), tokens.count))
+ }
+
+ return chunks
+ }
+}
diff --git a/Sources/WhisperKit/Core/Audio/AudioChunker.swift b/Sources/WhisperKit/Core/Audio/AudioChunker.swift
index 79e286ff..e005bab8 100644
--- a/Sources/WhisperKit/Core/Audio/AudioChunker.swift
+++ b/Sources/WhisperKit/Core/Audio/AudioChunker.swift
@@ -26,7 +26,7 @@ public extension AudioChunking {
let updatedSegment = TranscriptionUtilities.updateSegmentTimings(segment: segment, seekTime: seekTime)
updatedSegments.append(updatedSegment)
}
- var updatedResult = result
+ let updatedResult = result
updatedResult.seekTime = seekTime
updatedResult.segments = updatedSegments
updatedTranscriptionResults.append(updatedResult)
diff --git a/Sources/WhisperKit/Core/Configurations.swift b/Sources/WhisperKit/Core/Configurations.swift
index a52cf28d..fa59ca9f 100644
--- a/Sources/WhisperKit/Core/Configurations.swift
+++ b/Sources/WhisperKit/Core/Configurations.swift
@@ -13,7 +13,8 @@ open class WhisperKitConfig {
public var modelRepo: String?
/// Token for downloading models from repo (if required)
public var modelToken: String?
-
+ /// HuggingFace Hub compatible endpoint URL
+ public var modelEndpoint: String?
/// Folder to store models
public var modelFolder: String?
/// Folder to store tokenizers
@@ -53,11 +54,11 @@ open class WhisperKitConfig {
/// model gets loaded sequentially and unloaded immediately to trigger specialization if necessary.
///
/// **Trade-offs**
- /// - **Pro** — The peak memory usage during compilation is reduced because
+ /// - **Pro** - The peak memory usage during compilation is reduced because
/// only one model is kept in memory at any given point. Otherwise, the
/// peak memory will bloat to all model weights combined plus the peak
/// compilation memory (higher than model weights).
- /// - **Con** — The load time will be multiplied by 2 (usually <1s when cache is hit)
+ /// - **Con** - The load time will be multiplied by 2 (usually <1s when cache is hit)
/// because of the load-unload-load pattern when the specialized model file cache is
/// actually hit and prewarm does not trigger specialization
///
@@ -75,6 +76,7 @@ open class WhisperKitConfig {
downloadBase: URL? = nil,
modelRepo: String? = nil,
modelToken: String? = nil,
+ modelEndpoint: String? = nil,
modelFolder: String? = nil,
tokenizerFolder: URL? = nil,
computeOptions: ModelComputeOptions? = nil,
@@ -97,6 +99,7 @@ open class WhisperKitConfig {
self.downloadBase = downloadBase
self.modelRepo = modelRepo
self.modelToken = modelToken
+ self.modelEndpoint = modelEndpoint
self.modelFolder = modelFolder
self.tokenizerFolder = tokenizerFolder
self.computeOptions = computeOptions
diff --git a/Sources/WhisperKit/Core/Models.swift b/Sources/WhisperKit/Core/Models.swift
index 7712979f..1eee6beb 100644
--- a/Sources/WhisperKit/Core/Models.swift
+++ b/Sources/WhisperKit/Core/Models.swift
@@ -7,17 +7,7 @@ import CoreML
import Hub
import NaturalLanguage
import Tokenizers
-
-#if !((os(macOS) || targetEnvironment(macCatalyst)) && arch(x86_64))
-public typealias FloatType = Float16
-#else
-public typealias FloatType = Float
-#endif
-
-#if (os(macOS) || targetEnvironment(macCatalyst)) && arch(arm64) && compiler(<6)
-extension Float16: BNNSScalar {}
-extension Float16: MLShapedArrayScalar {}
-#endif
+@_exported import ArgmaxCore
// MARK: - CoreML
@@ -101,38 +91,7 @@ public enum ModelVariant: CustomStringConvertible, CaseIterable {
}
}
-@frozen
-public enum ModelState: CustomStringConvertible {
- case unloading
- case unloaded
- case loading
- case loaded
- case prewarming
- case prewarmed
- case downloading
- case downloaded
-
- public var description: String {
- switch self {
- case .unloading:
- return "Unloading"
- case .unloaded:
- return "Unloaded"
- case .loading:
- return "Loading"
- case .loaded:
- return "Loaded"
- case .prewarming:
- return "Specializing"
- case .prewarmed:
- return "Specialized"
- case .downloading:
- return "Downloading"
- case .downloaded:
- return "Downloaded"
- }
- }
-}
+// ModelState is defined in ArgmaxCore/ModelState.swift and re-exported here.
public struct ModelComputeOptions: Sendable {
public var melCompute: MLComputeUnits
@@ -502,10 +461,11 @@ public struct DecodingResult {
}
/// Reference-type container for transcription output.
-/// The stored properties stay thread-safe because each one uses
-/// `TranscriptionPropertyLock`, so reads/writes hop through a private `NSLock`
-/// before the value is accessed, making this shared `@unchecked Sendable` class
-/// safe to hand across concurrent contexts.
+///
+/// Each property is protected by its own `TranscriptionPropertyLock`, which
+/// serializes whole-value reads and writes. Atomic whole-value replacement is
+/// thread-safe; read-modify-write operations (e.g. `result.segments.append(...)`)
+/// are not - callers must use external synchronisation.
open class TranscriptionResult: Codable, @unchecked Sendable {
@TranscriptionPropertyLock public var text: String
@TranscriptionPropertyLock public var segments: [TranscriptionSegment]
@@ -732,11 +692,6 @@ public struct TranscriptionProgress: Sendable {
public typealias SegmentDiscoveryCallback = (_ segments: [TranscriptionSegment]) -> Void
/// A callback that reports changes in the model's state.
-/// - Parameters:
-/// - oldState: The previous state of the model, if any
-/// - newState: The current state of the model
-public typealias ModelStateCallback = (_ oldState: ModelState?, _ newState: ModelState) -> Void
-
/// A callback that reports changes in the transcription process.
/// - Parameter state: The current `TranscriptionState` of the transcription process
public typealias TranscriptionStateCallback = (_ state: TranscriptionState) -> Void
diff --git a/Sources/WhisperKit/Core/TextDecoder.swift b/Sources/WhisperKit/Core/TextDecoder.swift
index 4d717b6e..65815366 100644
--- a/Sources/WhisperKit/Core/TextDecoder.swift
+++ b/Sources/WhisperKit/Core/TextDecoder.swift
@@ -2,6 +2,7 @@
// Copyright © 2024 Argmax, Inc. All rights reserved.
import Accelerate
+import ArgmaxCore
import CoreML
import Tokenizers
diff --git a/Sources/WhisperKit/Core/WhisperKit.swift b/Sources/WhisperKit/Core/WhisperKit.swift
index 5349ee99..0cc30004 100644
--- a/Sources/WhisperKit/Core/WhisperKit.swift
+++ b/Sources/WhisperKit/Core/WhisperKit.swift
@@ -77,7 +77,8 @@ open class WhisperKit {
modelRepo: config.modelRepo,
modelToken: config.modelToken,
modelFolder: config.modelFolder,
- download: config.download
+ download: config.download,
+ endpoint: config.modelEndpoint ?? Constants.defaultRemoteEndpoint
)
if let prewarm = config.prewarm, prewarm {
@@ -173,6 +174,7 @@ open class WhisperKit {
remoteConfigName: remoteConfigName,
endpoint: endpoint
)
+
return ModelUtilities.modelSupport(for: deviceName, from: config)
}
diff --git a/Sources/WhisperKit/Utilities/Concurrency.swift b/Sources/WhisperKit/Utilities/Concurrency.swift
index 92b82f74..57bdcafe 100644
--- a/Sources/WhisperKit/Utilities/Concurrency.swift
+++ b/Sources/WhisperKit/Utilities/Concurrency.swift
@@ -1,89 +1,8 @@
// For licensing see accompanying LICENSE.md file.
// Copyright © 2024 Argmax, Inc. All rights reserved.
-import Foundation
-import os.lock
+import ArgmaxCore
-/// An actor that provides thread-safe early stopping functionality using UUIDs as keys
-public actor EarlyStopActor {
- private var shouldStop = [UUID: Bool]()
-
- public init() {}
-
- /// Sets the stop flag for a given UUID
- /// - Parameters:
- /// - value: The boolean value to set
- /// - uuid: The UUID key
- public func set(_ value: Bool, for uuid: UUID) {
- shouldStop[uuid] = value
- }
-
- /// Gets the stop flag for a given UUID
- /// - Parameter uuid: The UUID key
- /// - Returns: The current stop flag value, or false if not set
- public func get(for uuid: UUID) -> Bool {
- return shouldStop[uuid] ?? false
- }
-
- /// Removes and returns the stop flag for a given UUID
- /// - Parameter uuid: The UUID key
- /// - Returns: The removed stop flag value, if it existed
- public func remove(for uuid: UUID) -> Bool? {
- return shouldStop.removeValue(forKey: uuid)
- }
-}
-
-/// Serializes access to a value with an `os_unfair_lock` so mutation stays
-/// thread-safe. The wrapper is used by `TranscriptionResult`, which is marked
-/// `@unchecked Sendable`; guarding each property with this lock helps keep the
-/// result instance safe when shared across concurrent contexts.
-@propertyWrapper
-public struct TranscriptionPropertyLock<Value: Codable & Sendable>: Sendable, Codable {
- private let lock: UnfairLock
- private var value: Value
-
- public init(wrappedValue: Value) {
- self.lock = UnfairLock()
- self.value = wrappedValue
- }
- public init(from decoder: Swift.Decoder) throws {
- self.lock = UnfairLock()
- self.value = try Value(from: decoder)
- }
-
- public func encode(to encoder: Encoder) throws {
- try lock.withLock {
- try value.encode(to: encoder)
- }
-
- }
-
- public var wrappedValue: Value {
- get {
- lock.withLock {
- return value
- }
- }
- set {
- lock.withLock {
- value = newValue
- }
- }
- }
-}
-
-/// Thin wrapper around `os_unfair_lock` that exposes a Swift-friendly
-/// `withLock` helper. This lock is non-reentrant and optimized for low
-/// contention, matching the semantics of Core Foundation’s unfair lock.
-@usableFromInline
-final class UnfairLock: @unchecked Sendable {
- @usableFromInline
- var lock = os_unfair_lock()
-
- @inlinable
-    func withLock<T>(_ body: () throws -> T) rethrows -> T {
- os_unfair_lock_lock(&lock)
- defer { os_unfair_lock_unlock(&lock) }
- return try body()
- }
-}
+// Backward compatibility: TranscriptionPropertyLock is now PropertyLock in ArgmaxCore.
+// Existing consumers can continue using the old name.
+public typealias TranscriptionPropertyLock<Value: Codable & Sendable> = PropertyLock<Value>
diff --git a/Sources/WhisperKit/Utilities/Extensions+Internal.swift b/Sources/WhisperKit/Utilities/Extensions+Internal.swift
index b99af3bf..0f4432c1 100644
--- a/Sources/WhisperKit/Utilities/Extensions+Internal.swift
+++ b/Sources/WhisperKit/Utilities/Extensions+Internal.swift
@@ -4,32 +4,6 @@
import AVFoundation
import CoreML
-extension MLMultiArray {
- /// All values will be stored in the last dimension of the MLMultiArray (default is dims=1)
- static func from(_ array: [Int], dims: Int = 1) throws -> MLMultiArray {
- var shape = Array(repeating: 1, count: dims)
- shape[shape.count - 1] = array.count
- /// Examples:
- /// dims=1 : [arr.count]
- /// dims=2 : [1, arr.count]
- ///
- let output = try MLMultiArray(shape: shape as [NSNumber], dataType: .int32)
-        let pointer = UnsafeMutablePointer<Int32>(OpaquePointer(output.dataPointer))
- for (i, item) in array.enumerated() {
- pointer[i] = Int32(item)
- }
- return output
- }
-}
-
-extension Array {
- func batched(into size: Int) -> [[Element]] {
- return stride(from: 0, to: count, by: size).map {
-            Array(self[$0..<Swift.min($0 + size, count)])
-        }
-    }
-}
-
 extension Array where Element == Result<[TranscriptionResult], Swift.Error> {
/// Convenience method to convert the `Result` object into an array of optional arrays of `TranscriptionResult`.
/// - Returns: An array of optional arrays containing `TranscriptionResult`.
@@ -38,41 +12,6 @@ extension Array where Element == Result<[TranscriptionResult], Swift.Error> {
}
}
-extension Array where Element: Hashable {
- /// Returns an array with duplicates removed, preserving the original order.
- var orderedSet: [Element] {
- var seen = Set()
- return self.filter { element in
- if seen.contains(element) {
- return false
- } else {
- seen.insert(element)
- return true
- }
- }
- }
-}
-
-extension String {
- /// Reference: https://github.com/huggingface/swift-transformers/blob/94610577e4af9bbc267060af1e25e977604dd796/Sources/Tokenizers/Decoder.swift#L267-L275
- func trimmingFromEnd(character: Character = " ", upto: Int) -> String {
- var result = self
- var trimmed = 0
- while trimmed < upto && result.last == character {
- result.removeLast()
- trimmed += 1
- }
- return result
- }
-}
-
-extension [String] {
- /// Reference: https://github.com/huggingface/swift-transformers/blob/94610577e4af9bbc267060af1e25e977604dd796/Sources/Hub/HubApi.swift#L983-L987
- func matching(glob: String) -> [String] {
- filter { fnmatch(glob, $0, 0) == 0 }
- }
-}
-
extension AVAudioPCMBuffer {
/// Converts the buffer to a float array
func asFloatArray() throws -> [Float] {
diff --git a/Sources/WhisperKit/Utilities/Extensions+Public.swift b/Sources/WhisperKit/Utilities/Extensions+Public.swift
index 0cca123e..c109c3be 100644
--- a/Sources/WhisperKit/Utilities/Extensions+Public.swift
+++ b/Sources/WhisperKit/Utilities/Extensions+Public.swift
@@ -20,13 +20,6 @@ public extension WhisperKit {
}
}
-public extension Float {
- func rounded(_ decimalPlaces: Int) -> Float {
- let divisor = pow(10.0, Float(decimalPlaces))
- return (self * divisor).rounded() / divisor
- }
-}
-
public extension String {
var normalized: String {
// Convert to lowercase
@@ -52,245 +45,6 @@ public extension String {
}
}
-// MARK: CoreML
-
-public extension MLMultiArray {
- convenience init(shape: [NSNumber], dataType: MLMultiArrayDataType, initialValue: Any) throws {
- switch dataType {
- case .float16:
- // IOSurface-backed arrays are implicitly float16. They can
- // reduce buffer copies for some OS:compute unit combinations.
- guard let pixelBuffer = Self.pixelBuffer(for: shape) else {
- throw WhisperError.initializationError("MLMultiArray: Failed to initialize PixelBuffer")
- }
- self.init(pixelBuffer: pixelBuffer, shape: shape)
- default:
- try self.init(shape: shape, dataType: dataType)
- }
-
- switch dataType {
- case .double:
- if let value = initialValue as? Double {
- let typedPointer = dataPointer.bindMemory(to: Double.self, capacity: count)
- typedPointer.initialize(repeating: value, count: count)
- }
- case .float32:
- if let value = initialValue as? Float {
- let typedPointer = dataPointer.bindMemory(to: Float.self, capacity: count)
- typedPointer.initialize(repeating: value, count: count)
- }
- case .float16:
- if let value = initialValue as? FloatType {
- let typedPointer = dataPointer.bindMemory(to: FloatType.self, capacity: count)
- typedPointer.initialize(repeating: value, count: count)
- }
- case .int32:
- if let value = initialValue as? Int32 {
- let typedPointer = dataPointer.bindMemory(to: Int32.self, capacity: count)
- typedPointer.initialize(repeating: value, count: count)
- }
- @unknown default:
- fatalError("Unsupported data type")
- }
- }
-
- /// Calculate the linear offset by summing the products of each dimension's index with the dimension's stride.
- /// More info [here](https://developer.apple.com/documentation/coreml/mlmultiarray/2879231-subscript)
- /// - Parameters:
- /// - index: The index of the element
- /// - strides: The precomputed strides of the multi-array, if not provided, it will be computed. It's a performance optimization to avoid recomputing the strides every time when accessing the multi-array with multiple indexes.
- @inline(__always)
- func linearOffset(for index: [NSNumber], strides strideInts: [Int]? = nil) -> Int {
- var linearOffset = 0
- let strideInts = strideInts ?? strides.map { $0.intValue }
- for (dimension, stride) in zip(index, strideInts) {
- linearOffset += dimension.intValue * stride
- }
- return linearOffset
- }
-
-    func fillLastDimension(indexes: Range<Int>, with value: FloatType) {
- precondition(shape.count == 3 && shape[0] == 1 && shape[1] == 1, "Must have [1, 1, n] shape")
- withUnsafeMutableBufferPointer(ofType: FloatType.self) { ptr, strides in
- for index in indexes {
- ptr[index * strides[2]] = value
- }
- }
- }
-
-    func fill<Value>(indexes: [[NSNumber]], with value: Value) {
-        let pointer = UnsafeMutablePointer<Value>(OpaquePointer(dataPointer))
- let strideInts = strides.map { $0.intValue }
- for index in indexes {
- let linearOffset = linearOffset(for: index, strides: strideInts)
- pointer[linearOffset] = value
- }
- }
-
- private class func pixelBuffer(for shape: [NSNumber]) -> CVPixelBuffer? {
- guard let width = shape.last?.intValue else { return nil }
-        let height = shape[0..<shape.count - 1].reduce(1) { $0 * $1.intValue }
-
-        var pixelBuffer: CVPixelBuffer?
-        let createReturn = CVPixelBufferCreate(
-            kCFAllocatorDefault,
-            width,
-            height,
-            kCVPixelFormatType_OneComponent16Half,
-            nil,
-            &pixelBuffer
-        )
-        guard createReturn == kCVReturnSuccess else { return nil }
-        return pixelBuffer
-    }
-}
-
-#if !((os(macOS) || targetEnvironment(macCatalyst)) && arch(x86_64))
-@available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *)
-public extension MLTensor {
-    func asIntArray() -> [Int] {
- let semaphore = DispatchSemaphore(value: 0)
- var result: [Int] = []
-
- Task(priority: .high) {
- result = await self.shapedArray(of: Int32.self).scalars.map { Int($0) }
- semaphore.signal()
- }
-
- semaphore.wait()
- return result
- }
-
- func asFloatArray() -> [Float] {
- let semaphore = DispatchSemaphore(value: 0)
- let tensorType = self.scalarType
-
- var result: [Float] = []
-
- Task(priority: .high) {
- switch tensorType {
- case is Float32.Type:
- result = await self.shapedArray(of: Float32.self).scalars.map { Float($0) }
- case is FloatType.Type:
- result = await self.shapedArray(of: FloatType.self).scalars.map { Float($0) }
- case is Float.Type:
- result = await self.shapedArray(of: Float.self).scalars.map { Float($0) }
- case is Int32.Type:
- result = await self.shapedArray(of: Int32.self).scalars.map { Float($0) }
- default:
- fatalError("Unsupported data type")
- }
- semaphore.signal()
- }
-
- semaphore.wait()
- return result
- }
-
- func asMLMultiArray() -> MLMultiArray {
- let semaphore = DispatchSemaphore(value: 0)
- let tensorType = self.scalarType
-
- var result = try! MLMultiArray(shape: [1], dataType: .float16, initialValue: 0.0)
-
- Task(priority: .high) {
- switch tensorType {
- case is Float32.Type:
- result = MLMultiArray(await self.shapedArray(of: Float32.self))
- case is FloatType.Type:
- result = MLMultiArray(await self.shapedArray(of: FloatType.self))
- case is Float.Type:
- result = MLMultiArray(await self.shapedArray(of: Float.self))
- case is Int32.Type:
- result = MLMultiArray(await self.shapedArray(of: Int32.self))
- default:
- fatalError("Unsupported data type")
- }
- semaphore.signal()
- }
-
- semaphore.wait()
- return result
- }
-}
-#endif
-
-public extension MLModel {
- func asyncPrediction(
- from input: MLFeatureProvider,
- options: MLPredictionOptions
- ) async throws -> MLFeatureProvider {
- if #available(macOS 14, iOS 17, watchOS 10, visionOS 1, *) {
- return try await prediction(from: input, options: options)
- } else {
- return try await Task {
- try prediction(from: input, options: options)
- }.value
- }
- }
-}
-
-public extension MLComputeUnits {
- var description: String {
- switch self {
- case .cpuOnly:
- return "cpuOnly"
- case .cpuAndGPU:
- return "cpuAndGPU"
- case .all:
- return "all"
- case .cpuAndNeuralEngine:
- return "cpuAndNeuralEngine"
- @unknown default:
- return "unknown"
- }
- }
-}
-
-#if os(macOS) || targetEnvironment(simulator)
-// From: https://stackoverflow.com/a/71726663
-public extension ProcessInfo {
- static func stringFromSysctl(named name: String) -> String {
- var size: size_t = 0
- sysctlbyname(name, nil, &size, nil, 0)
- var machineModel = [CChar](repeating: 0, count: Int(size))
- sysctlbyname(name, &machineModel, &size, nil, 0)
- return String(cString: machineModel)
- }
-
- static let processor = stringFromSysctl(named: "machdep.cpu.brand_string")
- static let cores = stringFromSysctl(named: "machdep.cpu.core_count")
- static let threads = stringFromSysctl(named: "machdep.cpu.thread_count")
- static let vendor = stringFromSysctl(named: "machdep.cpu.vendor")
- static let family = stringFromSysctl(named: "machdep.cpu.family")
- static let hwModel = stringFromSysctl(named: "hw.model")
-}
-#endif
-
-// MARK: FileManager
-
-public extension FileManager {
- static func resolveAbsolutePath(_ inputPath: String) -> String {
- let fileManager = FileManager.default
-
- // Expanding tilde if present
- let pathWithTildeExpanded = NSString(string: inputPath).expandingTildeInPath
-
- // If the path is already absolute, return it
- if pathWithTildeExpanded.hasPrefix("/") {
- return pathWithTildeExpanded
- }
-
- // Resolving relative path based on the current working directory
- if let cwd = fileManager.currentDirectoryPath as String? {
- let resolvedPath = URL(fileURLWithPath: cwd).appendingPathComponent(pathWithTildeExpanded).path
- return resolvedPath
- }
-
- return inputPath
- }
-}
-
@available(*, deprecated, message: "Subject to removal in a future version. Use `FileManager.resolveAbsolutePath(_:)` instead.")
public func resolveAbsolutePath(_ inputPath: String) -> String {
return FileManager.resolveAbsolutePath(inputPath)
diff --git a/Sources/WhisperKit/Utilities/Logging.swift b/Sources/WhisperKit/Utilities/Logging.swift
index bea4baeb..5819f278 100644
--- a/Sources/WhisperKit/Utilities/Logging.swift
+++ b/Sources/WhisperKit/Utilities/Logging.swift
@@ -1,92 +1,10 @@
// For licensing see accompanying LICENSE.md file.
// Copyright © 2024 Argmax, Inc. All rights reserved.
+import ArgmaxCore
import OSLog
-open class Logging {
- public static let shared = Logging()
- public var logLevel: LogLevel = .none
-
- public typealias LoggingCallback = (_ message: String) -> Void
- public var loggingCallback: LoggingCallback?
-
- private let logger = OSLog(subsystem: Bundle.main.bundleIdentifier ?? "com.argmax.whisperkit", category: "WhisperKit")
-
- @frozen
- public enum LogLevel: Int {
- case debug = 1
- case info = 2
- case error = 3
- case none = 4
-
- func shouldLog(level: LogLevel) -> Bool {
- return self.rawValue <= level.rawValue
- }
- }
-
- private init() {}
-
- public func log(_ items: Any..., separator: String = " ", terminator: String = "\n", type: OSLogType) {
- let message = items.map { "\($0)" }.joined(separator: separator)
- if let logger = loggingCallback {
- logger(message)
- } else {
- os_log("%{public}@", log: logger, type: type, message)
- }
- }
-
- public static func debug(_ items: Any..., separator: String = " ", terminator: String = "\n") {
- if shared.logLevel.shouldLog(level: .debug) {
- shared.log(items, separator: separator, terminator: terminator, type: .debug)
- }
- }
-
- public static func info(_ items: Any..., separator: String = " ", terminator: String = "\n") {
- if shared.logLevel.shouldLog(level: .info) {
- shared.log(items, separator: separator, terminator: terminator, type: .info)
- }
- }
-
- public static func error(_ items: Any..., separator: String = " ", terminator: String = "\n") {
- if shared.logLevel.shouldLog(level: .error) {
- shared.log(items, separator: separator, terminator: terminator, type: .error)
- }
- }
-}
-
-public extension Logging {
- static func logCurrentMemoryUsage(_ message: String) {
- let memoryUsage = getMemoryUsage()
- Logging.debug("\(message) - Memory usage: \(memoryUsage) MB")
- }
-
- static func getMemoryUsage() -> UInt64 {
- var info = mach_task_basic_info()
-        var count = mach_msg_type_number_t(MemoryLayout<mach_task_basic_info>.size) / 4
-
- let kerr: kern_return_t = withUnsafeMutablePointer(to: &info) {
- $0.withMemoryRebound(to: integer_t.self, capacity: 1) {
- task_info(mach_task_self_, task_flavor_t(MACH_TASK_BASIC_INFO), $0, &count)
- }
- }
-
- guard kerr == KERN_SUCCESS else {
- return 0 // If the call fails, return 0
- }
-
- return info.resident_size / 1024 / 1024 // Convert to MB
- }
-}
-
-@available(*, deprecated, message: "Subject to removal in a future version. Use `Logging.logCurrentMemoryUsage(_:)` instead.")
-public func logCurrentMemoryUsage(_ message: String) {
- Logging.logCurrentMemoryUsage(message)
-}
-
-@available(*, deprecated, message: "Subject to removal in a future version. Use `Logging.getMemoryUsage()` instead.")
-public func getMemoryUsage() -> UInt64 {
- return Logging.getMemoryUsage()
-}
+// MARK: - WhisperKit signpost categories
extension Logging {
enum AudioEncoding {
@@ -133,11 +51,5 @@ extension Logging {
return String(format: "%.2f", timestamp)
}
- static func formatTimeWithPercentage(_ time: Double, _ runs: Double, _ fullPipelineDuration: Double) -> String {
- let percentage = (time * 1000 / fullPipelineDuration) * 100 // Convert to percentage
- let runTime = runs > 0 ? time * 1000 / Double(runs) : 0
- let formattedString = String(format: "%8.2f ms / %6.0f runs (%8.2f ms/run) %5.2f%%", time * 1000, runs, runTime, percentage)
- return formattedString
- }
+// formatTimeWithPercentage is defined in ArgmaxCore/Logging.swift and re-exported here.
}
-
diff --git a/Sources/WhisperKit/Utilities/ModelUtilities.swift b/Sources/WhisperKit/Utilities/ModelUtilities.swift
index 5af4d950..be5296cd 100644
--- a/Sources/WhisperKit/Utilities/ModelUtilities.swift
+++ b/Sources/WhisperKit/Utilities/ModelUtilities.swift
@@ -1,15 +1,14 @@
// For licensing see accompanying LICENSE.md file.
// Copyright © 2024 Argmax, Inc. All rights reserved.
+import ArgmaxCore
import CoreML
import Hub
import Tokenizers
-public struct ModelUtilities {
+extension ModelUtilities {
- private init() {}
-
- // MARK: Public
+ // MARK: - WhisperKit Model Support
public static func modelSupport(for deviceName: String, from config: ModelSupportConfig? = nil) -> ModelSupport {
let config = config ?? Constants.fallbackModelSupportConfig
@@ -85,22 +84,6 @@ public struct ModelUtilities {
at: hubTokenizerFolder
)
}
-
- public static func detectModelURL(inFolder path: URL, named modelName: String) -> URL {
- let compiledUrl = path.appending(path: "\(modelName).mlmodelc")
- let packageUrl = path.appending(path: "\(modelName).mlpackage/Data/com.apple.CoreML/model.mlmodel")
-
- let compiledModelExists: Bool = FileManager.default.fileExists(atPath: compiledUrl.path)
- let packageModelExists: Bool = FileManager.default.fileExists(atPath: packageUrl.path)
-
- // Swap to mlpackage only if the following is true: we found the mlmodel within the mlpackage, and we did not find a .mlmodelc
- var modelURL = compiledUrl
- if packageModelExists && !compiledModelExists {
- modelURL = packageUrl
- }
-
- return modelURL
- }
/// Formats and sorts model file names based on model variants
///
@@ -198,40 +181,6 @@ public struct ModelUtilities {
return modelVariant
}
- static func getModelInputDimention(_ model: MLModel?, named: String, position: Int) -> Int? {
- guard let inputDescription = model?.modelDescription.inputDescriptionsByName[named] else { return nil }
- guard inputDescription.type == .multiArray else { return nil }
- guard let shapeConstraint = inputDescription.multiArrayConstraint else { return nil }
- let shape = shapeConstraint.shape.map { $0.intValue }
- return shape[position]
- }
-
- static func getModelOutputDimention(_ model: MLModel?, named: String, position: Int) -> Int? {
- guard let inputDescription = model?.modelDescription.outputDescriptionsByName[named] else { return nil }
- guard inputDescription.type == .multiArray else { return nil }
- guard let shapeConstraint = inputDescription.multiArrayConstraint else { return nil }
- let shape = shapeConstraint.shape.map { $0.intValue }
- return shape[position]
- }
-
- func getModelInputDimention(_ model: MLModel?, named: String, position: Int) -> Int? {
- guard let inputDescription = model?.modelDescription.inputDescriptionsByName[named] else { return nil }
- guard inputDescription.type == .multiArray else { return nil }
- guard let shapeConstraint = inputDescription.multiArrayConstraint else { return nil }
- let shape = shapeConstraint.shape.map { $0.intValue }
- return shape[position]
- }
-
- func getModelOutputDimention(_ model: MLModel?, named: String, position: Int) -> Int? {
- guard let inputDescription = model?.modelDescription.outputDescriptionsByName[named] else { return nil }
- guard inputDescription.type == .multiArray else { return nil }
- guard let shapeConstraint = inputDescription.multiArrayConstraint else { return nil }
- let shape = shapeConstraint.shape.map { $0.intValue }
- return shape[position]
- }
-
- // MARK: Private
-
internal static func tokenizerNameForVariant(_ variant: ModelVariant) -> String {
var tokenizerName: String
switch variant {
diff --git a/Sources/WhisperKitCLI/TTSCLI.swift b/Sources/WhisperKitCLI/TTSCLI.swift
new file mode 100644
index 00000000..7594328a
--- /dev/null
+++ b/Sources/WhisperKitCLI/TTSCLI.swift
@@ -0,0 +1,279 @@
+// For licensing see accompanying LICENSE.md file.
+// Copyright © 2024 Argmax, Inc. All rights reserved.
+
+import ArgumentParser
+import CoreML
+import Foundation
+import TTSKit
+import WhisperKit
+
+// MARK: - CLI-only conformances for ArgumentParser
+
+extension Qwen3Speaker: ExpressibleByArgument {}
+extension Qwen3Language: ExpressibleByArgument {}
+extension TTSModelVariant: ExpressibleByArgument {}
+
+// MARK: - CLI Command
+
+struct TTSCLI: AsyncParsableCommand {
+ static let configuration = CommandConfiguration(
+ commandName: "tts",
+ abstract: "Generate speech from text using Qwen3-TTS"
+ )
+
+ // MARK: - Required (one of text or text-file)
+
+ @Option(name: .long, help: "Text to synthesize")
+ var text: String?
+
+ @Option(name: .long, help: "Read text from a file (supports .txt and .md)")
+ var textFile: String?
+
+ // MARK: - Common options
+
+ @Option(name: .long, help: "Speaker voice (aiden, ryan, ono-anna, sohee, eric, dylan, serena, vivian, uncle-fu)")
+ var speaker: Qwen3Speaker = .aiden
+
+ @Option(name: .long, help: "Language (english, chinese, japanese, korean)")
+ var language: Qwen3Language = .english
+
+ @Option(name: .long, help: "Output audio file path")
+ var outputPath: String = "output"
+
+ @Option(name: .long, help: "Output audio format (m4a or wav)")
+ var outputFormat: String = "m4a"
+
+ @Flag(name: .long, help: "Play audio through speakers in real time")
+ var play: Bool = false
+
+ @Flag(name: .long, help: "Enable verbose output")
+ var verbose: Bool = false
+
+ // MARK: - Generation options
+
+ @Option(name: .long, help: "Sampling temperature (0.0 for greedy)")
+ var temperature: Float = 0.9
+
+ @Option(name: .long, help: "Top-k sampling (0 to disable)")
+ var topK: Int = 50
+
+ @Option(name: .long, help: "Max RVQ frames to generate")
+ var maxNewTokens: Int = 245
+
+ @Option(name: .long, help: "Concurrent chunk workers (0=max, 1=sequential, N=batch size). Defaults to 1 with --play, 0 otherwise.")
+ var concurrentWorkerCount: Int?
+
+ @Option(name: .long, help: "Target chunk size in characters for sentence splitting")
+ var targetChunkSize: Int = TextChunker.defaultTargetChunkSize
+
+ @Option(name: .long, help: "Minimum chunk size in characters (short tails merge into previous chunk)")
+ var minChunkSize: Int = TextChunker.defaultMinChunkSize
+
+ @Option(name: .long, help: "Style instruction (e.g., \"Speak slowly and softly\"). Only supported by the 1.7B model.")
+ var instruction: String?
+
+ @Option(name: .long, help: "Random seed for reproducible output")
+ var seed: UInt64?
+
+ // MARK: - Model selection
+
+ @Option(name: .long, help: "Model preset (0.6b, 1.7b). Auto-configures version dir and variant defaults.")
+ var model: TTSModelVariant = .qwen3TTS_0_6b
+
+ // MARK: - Advanced options (auto-configured by preset, can be overridden)
+
+ @Option(name: .long, help: "Local model directory (skips download if provided)")
+ var modelsPath: String?
+
+ @Option(name: .long, help: "HuggingFace repo for model download")
+ var modelRepo: String = Qwen3TTSConstants.defaultModelRepo
+
+ @Option(name: .long, help: "Model version directory (overrides --model preset)")
+ var versionDir: String?
+
+ @Option(name: .long, help: "HuggingFace tokenizer repo or local path")
+ var tokenizer: String?
+
+ @Option(name: .long, help: "HuggingFace API token (for private repos, or set HF_TOKEN env var)")
+ var token: String?
+
+ @Option(name: .long, help: "HuggingFace Hub compatible endpoint URL")
+ var endpoint: String?
+
+ @Option(name: .long, help: "CodeDecoder variant (overrides --model preset)")
+ var codeDecoderVariant: String?
+
+ @Option(name: .long, help: "MultiCodeDecoder variant (overrides --model preset)")
+ var multiCodeDecoderVariant: String?
+
+ @Option(name: .long, help: "SpeechDecoder variant (overrides --model preset)")
+ var speechDecoderVariant: String?
+
+ // MARK: - Compute unit options
+
+ @Option(name: .long, help: "Compute units for embedders (TextProjector, CodeEmbedder, MultiCodeEmbedder) {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine}")
+ var embedderComputeUnits: ComputeUnits = .cpuOnly
+
+ @Option(name: .long, help: "Compute units for CodeDecoder {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine}")
+ var codeDecoderComputeUnits: ComputeUnits = .cpuAndNeuralEngine
+
+ @Option(name: .long, help: "Compute units for MultiCodeDecoder {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine}")
+ var multiCodeDecoderComputeUnits: ComputeUnits = .cpuAndNeuralEngine
+
+ @Option(name: .long, help: "Compute units for SpeechDecoder {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine}")
+ var speechDecoderComputeUnits: ComputeUnits = .cpuAndNeuralEngine
+
+ func run() async throws {
+ if verbose {
+ Logging.shared.loggingCallback = {
+ print("[TTSKit] \($0)")
+ }
+ }
+ // Resolve text from --text or --text-file
+ let inputText: String
+ if let textFile {
+ let resolvedPath = FileManager.resolveAbsolutePath(textFile)
+ inputText = try String(contentsOfFile: resolvedPath, encoding: .utf8)
+ .trimmingCharacters(in: .whitespacesAndNewlines)
+ } else if let text {
+ inputText = text
+ } else {
+ throw ValidationError("Either --text or --text-file must be provided")
+ }
+
+ guard !inputText.isEmpty else {
+ throw ValidationError("Input text is empty")
+ }
+
+ // Resolve local models path if provided
+ let resolvedModelFolder: URL? = modelsPath.map {
+ URL(fileURLWithPath: FileManager.resolveAbsolutePath($0))
+ }
+ // Resolve tokenizer: pass as URL — loadTokenizer falls back to HF Hub if the path doesn't exist
+ let resolvedTokenizerFolder: URL? = tokenizer.map {
+ URL(fileURLWithPath: $0)
+ }
+
+ // Build config from preset — explicit CLI overrides replace preset defaults.
+ let config = TTSKitConfig(
+ model: model,
+ modelFolder: resolvedModelFolder,
+ modelRepo: modelRepo,
+ tokenizerFolder: resolvedTokenizerFolder,
+ modelToken: token,
+ modelEndpoint: endpoint ?? Qwen3TTSConstants.defaultEndpoint,
+ versionDir: versionDir,
+ codeDecoderVariant: codeDecoderVariant,
+ multiCodeDecoderVariant: multiCodeDecoderVariant,
+ speechDecoderVariant: speechDecoderVariant,
+ computeOptions: ComputeOptions(
+ embedderComputeUnits: embedderComputeUnits.asMLComputeUnits,
+ codeDecoderComputeUnits: codeDecoderComputeUnits.asMLComputeUnits,
+ multiCodeDecoderComputeUnits: multiCodeDecoderComputeUnits.asMLComputeUnits,
+ speechDecoderComputeUnits: speechDecoderComputeUnits.asMLComputeUnits
+ ),
+ verbose: verbose
+ )
+
+ // Default: --play uses sequential (1), file output uses unlimited (0).
+ let effectiveWorkerCount = concurrentWorkerCount ?? (play ? 1 : 0)
+
+ // Always use a seed for reproducibility -- generate one if not provided
+ let effectiveSeed = seed ?? UInt64.random(in: 0...UInt64(UInt32.max))
+
+ // Warn if instruction is used with a model that doesn't support it
+ var effectiveInstruction = instruction
+ if let instruction = effectiveInstruction, !instruction.isEmpty, model == .qwen3TTS_0_6b {
+ print("Warning: --instruction is only supported by the 1.7B model variant. Ignoring instruction for \(model.rawValue).")
+ effectiveInstruction = nil
+ }
+
+ if verbose {
+ print("Qwen3-TTS Pipeline")
+ if textFile != nil {
+ print(" Text file: \(textFile!)")
+ }
+ print(" Text: \"\(inputText.prefix(80))\(inputText.count > 80 ? "..." : "")\"")
+ print(" Speaker: \(speaker.rawValue)")
+ print(" Language: \(language.rawValue)")
+ print(" Model: \(model.rawValue)")
+ if let inst = effectiveInstruction {
+ print(" Instruction: \"\(inst)\"")
+ }
+ if let folder = resolvedModelFolder {
+ print(" Models: \(folder.path)")
+ } else {
+ print(" Models: \(config.modelRepo) (auto-download)")
+ }
+ print(" Version: \(config.versionDir)")
+ print(" CodeDecoder: \(config.codeDecoderVariant)")
+ print(" MultiCodeDecoder: \(config.multiCodeDecoderVariant)")
+ print(" SpeechDecoder: \(config.speechDecoderVariant)")
+ print(" Embedder compute: \(embedderComputeUnits.rawValue)")
+ print(" CodeDecoder compute: \(codeDecoderComputeUnits.rawValue)")
+ print(" MultiCodeDecoder compute: \(multiCodeDecoderComputeUnits.rawValue)")
+ print(" SpeechDecoder compute: \(speechDecoderComputeUnits.rawValue)")
+ print(" Output: \(outputPath).\(outputFormat.lowercased())")
+ print(" Temperature: \(temperature)")
+ print(" Top-k: \(topK)")
+ print(" Play: \(play)")
+ let workerDesc = effectiveWorkerCount == 0 ? "max" : "\(effectiveWorkerCount)"
+ print(" Concurrency: \(workerDesc) (chunking: sentence)")
+ print(" Seed: \(effectiveSeed)")
+ }
+
+ // Initialize pipeline (downloads if needed, loads tokenizer + 6 models concurrently)
+ config.seed = effectiveSeed
+ let tts = try await TTSKit(config)
+
+ let options = GenerationOptions(
+ temperature: temperature,
+ topK: topK,
+ repetitionPenalty: 1.05,
+ maxNewTokens: maxNewTokens,
+ concurrentWorkerCount: effectiveWorkerCount,
+ targetChunkSize: targetChunkSize,
+ minChunkSize: minChunkSize,
+ instruction: effectiveInstruction
+ )
+
+ let result: SpeechResult
+ if play {
+ result = try await tts.play(
+ text: inputText,
+ speaker: speaker,
+ language: language,
+ options: options
+ )
+ } else {
+ result = try await tts.generate(
+ text: inputText,
+ speaker: speaker,
+ language: language,
+ options: options
+ )
+ }
+
+ let format = AudioOutput.AudioFileFormat(rawValue: outputFormat.lowercased()) ?? .m4a
+ let outputURL = URL(fileURLWithPath: outputPath)
+ let outputFolder = outputURL.deletingLastPathComponent().path == "."
+ ? URL(fileURLWithPath: FileManager.default.currentDirectoryPath)
+ : outputURL.deletingLastPathComponent()
+
+ let savedURL = try await AudioOutput.saveAudio(
+ result.audio,
+ toFolder: outputFolder,
+ filename: outputURL.lastPathComponent,
+ sampleRate: result.sampleRate,
+ format: format
+ )
+
+ result.logTimings()
+
+ if verbose {
+ print(String(format: "Generated %.2fs of audio -> %@", result.audioDuration, savedURL.path))
+ } else {
+ print(savedURL.path)
+ }
+ }
+}
diff --git a/Sources/WhisperKitCLI/TranscribeCLI.swift b/Sources/WhisperKitCLI/TranscribeCLI.swift
index 0ce09a4f..3701d725 100644
--- a/Sources/WhisperKitCLI/TranscribeCLI.swift
+++ b/Sources/WhisperKitCLI/TranscribeCLI.swift
@@ -5,6 +5,7 @@ import ArgumentParser
import CoreML
import Foundation
import WhisperKit
+import TTSKit
struct TranscribeCLI: AsyncParsableCommand {
static let configuration = CommandConfiguration(
diff --git a/Sources/WhisperKitCLI/TranscribeCLIArguments.swift b/Sources/WhisperKitCLI/TranscribeCLIArguments.swift
index eb5d1d47..0f54f1be 100644
--- a/Sources/WhisperKitCLI/TranscribeCLIArguments.swift
+++ b/Sources/WhisperKitCLI/TranscribeCLIArguments.swift
@@ -25,6 +25,9 @@ struct TranscribeCLIArguments: ParsableArguments {
@Option(help: "Path to save the downloaded tokenizer files")
var downloadTokenizerPath: String?
+ @Option(name: .long, help: "HuggingFace Hub compatible endpoint URL")
+ var endpoint: String?
+
@Option(help: "Compute units for audio encoder model with {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine,random}")
var audioEncoderComputeUnits: ComputeUnits = .cpuAndNeuralEngine
diff --git a/Sources/WhisperKitCLI/TranscribeCLIUtils.swift b/Sources/WhisperKitCLI/TranscribeCLIUtils.swift
index 98dbfe41..586429b0 100644
--- a/Sources/WhisperKitCLI/TranscribeCLIUtils.swift
+++ b/Sources/WhisperKitCLI/TranscribeCLIUtils.swift
@@ -31,6 +31,7 @@ internal class TranscribeCLIUtils {
return WhisperKitConfig(
model: modelName,
downloadBase: downloadModelFolder,
+ modelEndpoint: arguments.endpoint,
modelFolder: arguments.modelPath,
tokenizerFolder: downloadTokenizerFolder,
computeOptions: computeOptions,
diff --git a/Sources/WhisperKitCLI/WhisperKitCLI.swift b/Sources/WhisperKitCLI/WhisperKitCLI.swift
index 3ceec64b..5ed6cc50 100644
--- a/Sources/WhisperKitCLI/WhisperKitCLI.swift
+++ b/Sources/WhisperKitCLI/WhisperKitCLI.swift
@@ -8,9 +8,9 @@ let VERSION: String = "development"
var subcommands: [ParsableCommand.Type] {
#if BUILD_SERVER_CLI
- [TranscribeCLI.self, ServeCLI.self]
+ [TranscribeCLI.self, TTSCLI.self, ServeCLI.self]
#else
- [TranscribeCLI.self]
+ [TranscribeCLI.self, TTSCLI.self]
#endif
}
diff --git a/Tests/TTSKitTests/TTSKitIntegrationTests.swift b/Tests/TTSKitTests/TTSKitIntegrationTests.swift
new file mode 100644
index 00000000..48f7b23a
--- /dev/null
+++ b/Tests/TTSKitTests/TTSKitIntegrationTests.swift
@@ -0,0 +1,495 @@
+// For licensing see accompanying LICENSE.md file.
+// Copyright © 2024 Argmax, Inc. All rights reserved.
+
+import AVFoundation
+import CoreML
+import Foundation
+@testable import TTSKit
+import XCTest
+
+// MARK: - Helpers
+
+extension XCTestCase {
+ /// Create a `TTSKit` instance for integration tests.
+ ///
+ /// Downloads models from the Hub if not already cached on disk, matching the
+ /// same pattern as `tinyModelPath()` in WhisperKit's test suite. The test
+ /// will fail (not skip) if the download fails.
+ func makeCachedTTS(
+ model: TTSModelVariant = .qwen3TTS_0_6b,
+ seed: UInt64 = 42
+ ) async throws -> TTSKit {
+ let config = TTSKitConfig(model: model, verbose: true, logLevel: .debug, seed: seed)
+ return try await TTSKit(config)
+ }
+}
+
+// MARK: - Integration Tests
+
+final class TTSKitIntegrationTests: XCTestCase {
+
+ // MARK: - Basic generation
+
+ /// Short sentence -> non-empty audio at 24kHz.
+ func testBasicShortGeneration() async throws {
+ let tts = try await makeCachedTTS(seed: 42)
+ let result = try await tts.generate(
+ text: "Hello, this is a basic smoke test of the WhisperKit TTS pipeline.",
+ speaker: .ryan,
+ language: .english,
+ options: GenerationOptions(temperature: 0.9, topK: 50, maxNewTokens: 245)
+ )
+
+ XCTAssertGreaterThan(result.audio.count, 0, "Audio samples should be non-empty")
+ XCTAssertGreaterThan(result.audioDuration, 1.0, "Expect at least 1s of speech")
+ XCTAssertLessThan(result.audioDuration, 30.0, "Short sentence should stay under 30s")
+ XCTAssertEqual(result.sampleRate, 24000, "Sample rate should be 24kHz")
+ }
+
+ // MARK: - Long text with sentence chunking
+
+ /// 200+ word multi-paragraph text - chunked, all audio produced, no truncation crash.
+ func testLongTextSequentialChunks() async throws {
+ let tts = try await makeCachedTTS(seed: 42)
+ let longText = """
+ The history of artificial intelligence begins in antiquity, with myths, stories, \
+ and rumors of artificial beings endowed with intelligence by master craftsmen. \
+ The seeds of modern AI were planted by classical philosophers who attempted to \
+ describe the process of human thinking as the mechanical manipulation of symbols. \
+ In 1950, Alan Turing published a landmark paper in which he speculated about the \
+ possibility of creating machines that think. Fast forward to the 2020s, and large \
+ language models have transformed every domain from science to storytelling. \
+ Text-to-speech systems now produce voices so natural that listeners struggle to \
+ distinguish them from real human speakers.
+ """
+
+ let result = try await tts.generate(
+ text: longText,
+ speaker: .aiden,
+ language: .english,
+ options: GenerationOptions(concurrentWorkerCount: 1)
+ )
+
+ XCTAssertGreaterThan(result.audioDuration, 20.0, "Long text should produce substantial audio")
+ XCTAssertGreaterThan(result.timings.totalDecodingLoops, 100, "Should require many decode steps")
+ }
+
+ /// Same long text with unlimited concurrent workers
+ func testLongTextUnlimitedConcurrentWorkers() async throws {
+ let tts = try await makeCachedTTS(seed: 42)
+ let longText = """
+ First paragraph about science. It covers many interesting topics in depth. \
+ Second paragraph switches to history. Many events happened over the centuries. \
+ Third paragraph explores music theory. Harmony and rhythm define the art form.
+ """
+
+ let result = try await tts.generate(
+ text: longText,
+ speaker: .aiden,
+ language: .english,
+ options: GenerationOptions(concurrentWorkerCount: 0)
+ )
+
+ XCTAssertGreaterThan(result.audioDuration, 5.0)
+ }
+
+ // MARK: - Concurrent workers
+
+ /// workers=2 on a 3-chunk text - verifies the batching loop handles partial batches.
+ func testConcurrentWorkersBatching() async throws {
+ let tts = try await makeCachedTTS(seed: 7)
+ let threeChunkText = """
+ First sentence for chunk one. Second sentence for chunk two. Third sentence for chunk three.
+ """
+ let result = try await tts.generate(
+ text: threeChunkText,
+ speaker: .serena,
+ language: .english,
+ options: GenerationOptions(
+ concurrentWorkerCount: 2,
+ targetChunkSize: 10,
+ minChunkSize: 2
+ )
+ )
+
+ XCTAssertGreaterThan(result.audio.count, 0)
+ XCTAssertGreaterThan(result.audioDuration, 2.0)
+ }
+
+ // MARK: - Reproducibility
+
+ /// Two runs with the same seed and temperature=0 should produce byte-identical audio.
+ func testDeterministicOutputWithFixedSeed() async throws {
+ let text = "Reproducibility test with a fixed seed."
+ let options = GenerationOptions(temperature: 0.0, topK: 0, maxNewTokens: 100)
+
+ let tts1 = try await makeCachedTTS(seed: 42)
+ let result1 = try await tts1.generate(
+ text: text, speaker: .ryan, language: .english, options: options
+ )
+
+ let tts2 = try await makeCachedTTS(seed: 42)
+ let result2 = try await tts2.generate(
+ text: text, speaker: .ryan, language: .english, options: options
+ )
+
+ XCTAssertEqual(result1.audio.count, result2.audio.count, "Step counts must match")
+ // Compare raw float bytes for strict reproducibility
+ let data1 = Data(bytes: result1.audio, count: result1.audio.count * MemoryLayout<Float>.size)
+ let data2 = Data(bytes: result2.audio, count: result2.audio.count * MemoryLayout<Float>.size)
+ XCTAssertEqual(data1, data2, "Audio samples must be byte-identical with the same seed and greedy decoding")
+ }
+
+ // MARK: - Speaker voices
+
+ /// All 9 Qwen3Speaker voices should produce non-empty audio without error.
+ func testAllSpeakerVoices() async throws {
+ let tts = try await makeCachedTTS(seed: 1)
+ let text = "Testing this voice."
+ let options = GenerationOptions(temperature: 0.0, topK: 0, maxNewTokens: 60)
+
+ for speaker in Qwen3Speaker.allCases {
+ let result = try await tts.generate(
+ text: text, speaker: speaker, language: .english, options: options
+ )
+ XCTAssertGreaterThan(result.audio.count, 0, "Speaker \(speaker.rawValue) produced no audio")
+ }
+ }
+
+ // MARK: - Language support
+
+ /// Korean text with a Korean-native speaker should produce non-empty audio.
+ func testKoreanLanguage() async throws {
+ let tts = try await makeCachedTTS(seed: 7)
+ let result = try await tts.generate(
+ text: "안녕하세요. 저는 한국어를 말할 수 있습니다.",
+ speaker: .sohee,
+ language: .korean,
+ options: GenerationOptions(temperature: 0.9, maxNewTokens: 120)
+ )
+
+ XCTAssertGreaterThan(result.audio.count, 0)
+ XCTAssertGreaterThan(result.audioDuration, 0.5)
+ }
+
+ // MARK: - Edge cases
+
+ /// Single-word input should not crash and should produce short audio.
+ func testSingleWordInput() async throws {
+ let tts = try await makeCachedTTS(seed: 1)
+ let result = try await tts.generate(
+ text: "Hello.",
+ speaker: .ryan,
+ language: .english,
+ options: GenerationOptions(temperature: 0.0, topK: 0, maxNewTokens: 60)
+ )
+
+ XCTAssertGreaterThan(result.audio.count, 0)
+ XCTAssertLessThan(result.audioDuration, 5.0, "Single word should be short")
+ }
+
+ /// Text with numbers, punctuation, currency, and percentages should not crash.
+ func testNumbersAndPunctuation() async throws {
+ let tts = try await makeCachedTTS(seed: 3)
+ let result = try await tts.generate(
+ text: "On January 1st, 2025, the price was $4.99 - a 12% discount from $5.67.",
+ speaker: .dylan,
+ language: .english,
+ options: GenerationOptions(maxNewTokens: 245)
+ )
+
+ XCTAssertGreaterThan(result.audio.count, 0)
+ }
+
+ /// Text with emoji and mixed scripts (Latin + accented) should not crash.
+ func testUnicodeAndEmoji() async throws {
+ let tts = try await makeCachedTTS(seed: 5)
+ let result = try await tts.generate(
+ text: "Today's meeting is at 3pm 🎉. The café serves café au lait. Résumé updated.",
+ speaker: .ryan,
+ language: .english,
+ options: GenerationOptions(maxNewTokens: 200)
+ )
+
+ XCTAssertGreaterThan(result.audio.count, 0)
+ }
+
+ // MARK: - WAV export
+
+ /// saveAudio round-trip: write WAV then verify the file exists and is readable.
+ func testSaveAudioRoundTrip() async throws {
+ let tts = try await makeCachedTTS(seed: 42)
+ let result = try await tts.generate(
+ text: "Audio export round trip test.",
+ speaker: .ryan,
+ language: .english,
+ options: GenerationOptions(temperature: 0.0, topK: 0, maxNewTokens: 80)
+ )
+
+ let tmpFolder = FileManager.default.temporaryDirectory
+ let filename = "tts_roundtrip_\(UUID().uuidString)"
+ let savedURL = try await AudioOutput.saveAudio(
+ result.audio,
+ toFolder: tmpFolder,
+ filename: filename,
+ sampleRate: result.sampleRate,
+ format: .wav
+ )
+ defer { try? FileManager.default.removeItem(at: savedURL) }
+
+ XCTAssertTrue(FileManager.default.fileExists(atPath: savedURL.path), "WAV file should exist after saveAudio")
+
+ let duration = try await AudioOutput.duration(of: savedURL)
+ XCTAssertEqual(duration, result.audioDuration, accuracy: 0.1, "WAV duration should match result.audioDuration")
+ }
+
+ // MARK: - SpeechModel protocol
+
+ /// TTSKit should satisfy the SpeechModel protocol; generate via the protocol entry point.
+ func testTTSModelProtocolConformance() async throws {
+ let tts = try await makeCachedTTS(seed: 42)
+ let model: any SpeechModel = tts
+
+ XCTAssertEqual(model.sampleRate, 24000)
+
+ let result = try await model.generate(
+ text: "Protocol conformance check.",
+ voice: Qwen3Speaker.ryan.rawValue,
+ language: Qwen3Language.english.rawValue,
+ options: GenerationOptions(temperature: 0.0, topK: 0, maxNewTokens: 80),
+ callback: nil
+ )
+
+ XCTAssertGreaterThan(result.audio.count, 0)
+ XCTAssertEqual(result.sampleRate, 24000)
+ }
+
+ // MARK: - Performance
+
+ /// Verify timings are populated and the generation loop completed within a reasonable ceiling.
+ func testTimingsArePopulated() async throws {
+ let tts = try await makeCachedTTS(seed: 42)
+ let result = try await tts.generate(
+ text: "Performance check: this sentence measures how fast the model runs on device.",
+ speaker: .ryan,
+ language: .english,
+ options: GenerationOptions(temperature: 0.0, topK: 0, maxNewTokens: 200)
+ )
+
+ XCTAssertGreaterThan(result.timings.totalDecodingLoops, 0, "Should have completed at least one decode step")
+ XCTAssertGreaterThan(result.timings.fullPipeline, 0, "Full pipeline duration should be non-zero")
+ XCTAssertGreaterThan(result.audioDuration, 0, "Audio duration should be non-zero")
+ }
+
+ // MARK: - Prompt caching
+
+ /// Build a prompt cache and verify it stores the expected metadata.
+ func testBuildPromptCache() async throws {
+ let tts = try await makeCachedTTS(seed: 42)
+ let cache = try await tts.buildPromptCache(speaker: .ryan, language: .english)
+
+ XCTAssertEqual(cache.voice, "ryan")
+ XCTAssertEqual(cache.language, "english")
+ XCTAssertNil(cache.instruction)
+ XCTAssertGreaterThan(cache.prefixLength, 0, "Cache should contain at least one invariant token")
+ XCTAssertTrue(cache.matches(voice: "ryan", language: "english", instruction: nil))
+ XCTAssertFalse(cache.matches(voice: "aiden", language: "english", instruction: nil))
+ }
+
+ /// Cached prefill should be faster than uncached since it only processes the variable token.
+ func testPromptCacheSpeedup() async throws {
+ let tts = try await makeCachedTTS(seed: 42)
+ let text = "Prompt cache speed test with enough words to generate meaningful audio output."
+ let options = GenerationOptions(temperature: 0.9, topK: 50, maxNewTokens: 245)
+
+ // Run without cache — full prefill of all tokens
+ let uncachedResult = try await tts.createTask().run(
+ text: text, voice: "ryan", language: "english",
+ options: options, callback: nil, prefixCache: nil
+ )
+ let uncachedPrefill = uncachedResult.timings.prefill
+
+ // Build cache, then run with it — only the variable token is prefilled
+ let cache = try await tts.buildPromptCache(speaker: .ryan, language: .english)
+ let cachedResult = try await tts.createTask().run(
+ text: text, voice: "ryan", language: "english",
+ options: options, callback: nil, prefixCache: cache
+ )
+ let cachedPrefill = cachedResult.timings.prefill
+
+ XCTAssertGreaterThan(uncachedResult.audioDuration, 1.0,
+ "Uncached run should produce at least 1s of audio")
+ XCTAssertGreaterThan(cachedResult.audioDuration, 1.0,
+ "Cached run should produce at least 1s of audio")
+ XCTAssertLessThan(cachedPrefill, uncachedPrefill,
+ "Cached prefill (\(cachedPrefill * 1000)ms) should be faster than uncached (\(uncachedPrefill * 1000)ms)")
+
+ // Verify auto-build on generate works
+ tts.promptCache = nil
+ _ = try await tts.generate(
+ text: text, speaker: .ryan, language: .english, options: options
+ )
+ XCTAssertNotNil(tts.promptCache, "Cache should be auto-built after generate")
+ }
+
+ /// Two consecutive cached runs should produce valid audio of similar duration.
+ func testPromptCacheDeterminism() async throws {
+ let text = "Cache determinism test with a longer sentence so the model generates real speech."
+ let options = GenerationOptions(temperature: 0.9, topK: 50, maxNewTokens: 245)
+
+ let tts = try await makeCachedTTS(seed: 42)
+ let cache = try await tts.buildPromptCache(speaker: .ryan, language: .english)
+
+ let result1 = try await tts.createTask().run(
+ text: text, voice: "ryan", language: "english",
+ options: options, callback: nil, prefixCache: cache
+ )
+ let result2 = try await tts.createTask().run(
+ text: text, voice: "ryan", language: "english",
+ options: options, callback: nil, prefixCache: cache
+ )
+
+ XCTAssertGreaterThan(result1.audioDuration, 1.0, "First cached run should produce at least 1s")
+ XCTAssertGreaterThan(result2.audioDuration, 1.0, "Second cached run should produce at least 1s")
+ }
+
+ /// Cache auto-invalidates when voice changes.
+ func testPromptCacheInvalidationOnVoiceChange() async throws {
+ let tts = try await makeCachedTTS(seed: 42)
+ let options = GenerationOptions(temperature: 0.9, topK: 50, maxNewTokens: 245)
+
+ let result1 = try await tts.generate(
+ text: "First voice generates meaningful speech output.", speaker: .ryan, language: .english, options: options
+ )
+ XCTAssertEqual(tts.promptCache?.voice, "ryan")
+ XCTAssertGreaterThan(result1.audio.count, 0, "First voice should produce audio")
+
+ let result2 = try await tts.generate(
+ text: "Second voice also generates meaningful speech output.", speaker: .aiden, language: .english, options: options
+ )
+ XCTAssertEqual(tts.promptCache?.voice, "aiden",
+ "Cache should auto-rebuild when voice changes")
+ XCTAssertGreaterThan(result2.audio.count, 0, "Second voice should produce audio")
+ }
+
+ /// Prompt cache save/load round-trip produces identical generation output.
+ func testPromptCacheDiskPersistence() async throws {
+ let tts = try await makeCachedTTS(seed: 42)
+ let cache = try await tts.buildPromptCache(speaker: .ryan, language: .english)
+
+ let tmpURL = FileManager.default.temporaryDirectory
+ .appendingPathComponent("tts_cache_test_\(UUID().uuidString).promptcache")
+ defer { try? FileManager.default.removeItem(at: tmpURL) }
+
+ try cache.save(to: tmpURL)
+ XCTAssertTrue(FileManager.default.fileExists(atPath: tmpURL.path), "Cache file should exist on disk")
+
+ let loaded = try TTSPromptCache.load(from: tmpURL)
+ XCTAssertEqual(loaded.voice, cache.voice)
+ XCTAssertEqual(loaded.language, cache.language)
+ XCTAssertEqual(loaded.instruction, cache.instruction)
+ XCTAssertEqual(loaded.prefixLength, cache.prefixLength)
+
+ // Use the loaded cache for generation — should produce valid audio
+ tts.promptCache = loaded
+ let result = try await tts.generate(
+ text: "Disk cache persistence test with enough text for real audio.",
+ speaker: .ryan, language: .english,
+ options: GenerationOptions(temperature: 0.9, topK: 50, maxNewTokens: 245)
+ )
+ XCTAssertGreaterThan(result.audioDuration, 1.0, "Loaded cache should produce at least 1s of audio")
+ }
+
+ /// Chunked generation should benefit from prompt caching across all chunks.
+ func testPromptCacheWithChunkedGeneration() async throws {
+ let tts = try await makeCachedTTS(seed: 42)
+ try await tts.buildPromptCache(speaker: .ryan, language: .english)
+
+ let multiChunkText = """
+ First sentence is fairly long to make a chunk. \
+ Second sentence adds more content for another chunk. \
+ Third sentence provides even more text for splitting.
+ """
+
+ let result = try await tts.generate(
+ text: multiChunkText,
+ speaker: .ryan, language: .english,
+ options: GenerationOptions(
+ concurrentWorkerCount: 1,
+ targetChunkSize: 15,
+ minChunkSize: 5
+ )
+ )
+
+ XCTAssertGreaterThan(result.audio.count, 0)
+ XCTAssertGreaterThan(result.audioDuration, 2.0)
+ }
+
+ // MARK: - Dual inference path
+
+ /// MLTensor path (default, macOS 15+) produces valid audio.
+ func testMLTensorPathGeneration() async throws {
+ guard #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *) else {
+ throw XCTSkip("MLTensor path requires macOS 15+ / iOS 18+")
+ }
+ let tts = try await makeCachedTTS(seed: 42)
+ var opts = GenerationOptions(maxNewTokens: 100)
+ opts.forceLegacyEmbedPath = false
+
+ let result = try await tts.generate(
+ text: "Testing the MLTensor inference path.",
+ speaker: .ryan, language: .english,
+ options: opts
+ )
+
+ XCTAssertGreaterThan(result.audio.count, 0, "MLTensor path should produce audio")
+ XCTAssertGreaterThan(result.audioDuration, 0.5, "Should produce at least 0.5s of speech")
+ }
+
+ /// Legacy [FloatType] path (forced via forceLegacyEmbedPath) produces valid audio.
+ func testLegacyEmbedPathGeneration() async throws {
+ let tts = try await makeCachedTTS(seed: 42)
+ var opts = GenerationOptions(maxNewTokens: 100)
+ opts.forceLegacyEmbedPath = true
+
+ let result = try await tts.generate(
+ text: "Testing the legacy embed inference path.",
+ speaker: .ryan, language: .english,
+ options: opts
+ )
+
+ XCTAssertGreaterThan(result.audio.count, 0, "Legacy path should produce audio")
+ XCTAssertGreaterThan(result.audioDuration, 0.5, "Should produce at least 0.5s of speech")
+ }
+
+ /// Both paths with the same seed should produce audio of similar duration.
+ /// Floating-point differences between paths mean samples may not be bit-identical.
+ func testBothPathsProduceSimilarAudioDuration() async throws {
+ guard #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *) else {
+ throw XCTSkip("Dual-path comparison requires macOS 15+ / iOS 18+")
+ }
+ let testText = "Comparing inference paths for audio duration consistency."
+ let maxNewTokens = 80
+
+ let ttsMLTensor = try await makeCachedTTS(seed: 42)
+ var mlTensorOpts = GenerationOptions(maxNewTokens: maxNewTokens)
+ mlTensorOpts.forceLegacyEmbedPath = false
+ let mlTensorResult = try await ttsMLTensor.generate(
+ text: testText, speaker: .ryan, language: .english, options: mlTensorOpts
+ )
+
+ let ttsLegacy = try await makeCachedTTS(seed: 42)
+ var legacyOpts = GenerationOptions(maxNewTokens: maxNewTokens)
+ legacyOpts.forceLegacyEmbedPath = true
+ let legacyResult = try await ttsLegacy.generate(
+ text: testText, speaker: .ryan, language: .english, options: legacyOpts
+ )
+
+ XCTAssertGreaterThan(mlTensorResult.audioDuration, 0)
+ XCTAssertGreaterThan(legacyResult.audioDuration, 0)
+ // Durations should be within 20% of each other (same number of tokens -> same frames)
+ let durationRatio = mlTensorResult.audioDuration / legacyResult.audioDuration
+ XCTAssertGreaterThan(durationRatio, 0.8, "Paths should produce similar audio duration")
+ XCTAssertLessThan(durationRatio, 1.2, "Paths should produce similar audio duration")
+ }
+}
diff --git a/Tests/TTSKitTests/TTSKitUnitTests.swift b/Tests/TTSKitTests/TTSKitUnitTests.swift
new file mode 100644
index 00000000..49970e12
--- /dev/null
+++ b/Tests/TTSKitTests/TTSKitUnitTests.swift
@@ -0,0 +1,942 @@
+// For licensing see accompanying LICENSE.md file.
+// Copyright © 2024 Argmax, Inc. All rights reserved.
+
+import ArgmaxCore
+import CoreML
+@testable import TTSKit
+import XCTest
+
+final class TTSKitUnitTests: XCTestCase {
+
+ // MARK: - Configuration
+
+    /// A default-initialized config must mirror the Qwen3 constants and variant defaults.
+    func testTTSKitConfigDefaults() {
+        let config = TTSKitConfig()
+        XCTAssertNil(config.modelFolder)
+        XCTAssertEqual(config.modelRepo, Qwen3TTSConstants.defaultModelRepo)
+        XCTAssertEqual(config.versionDir, Qwen3TTSConstants.defaultVersionDir)
+        XCTAssertEqual(config.tokenizerSource, Qwen3TTSConstants.defaultTokenizerRepo)
+        XCTAssertEqual(config.codeDecoderVariant, Qwen3VariantDefaults.codeDecoder)
+        XCTAssertEqual(config.multiCodeDecoderVariant, Qwen3VariantDefaults.multiCodeDecoder)
+        XCTAssertEqual(config.speechDecoderVariant, Qwen3VariantDefaults.speechDecoder)
+        XCTAssertTrue(config.verbose)
+    }
+
+    /// Default compute-unit placement: embedder on CPU, decoders on CPU+ANE.
+    func testTTSComputeOptionsDefaults() {
+        let opts = ComputeOptions()
+        XCTAssertEqual(opts.embedderComputeUnits, .cpuOnly)
+        XCTAssertEqual(opts.codeDecoderComputeUnits, .cpuAndNeuralEngine)
+        XCTAssertEqual(opts.multiCodeDecoderComputeUnits, .cpuAndNeuralEngine)
+        XCTAssertEqual(opts.speechDecoderComputeUnits, .cpuAndNeuralEngine)
+    }
+
+    /// Initializing a config from a model preset must resolve the preset's version
+    /// directory while keeping the shared variant defaults.
+    func testModelPresetResolvesInConfig() {
+        let small = TTSKitConfig(model: .qwen3TTS_0_6b)
+        XCTAssertEqual(small.versionDir, TTSModelVariant.qwen3TTS_0_6b.versionDir)
+        XCTAssertEqual(small.multiCodeDecoderVariant, Qwen3VariantDefaults.multiCodeDecoder)
+
+        let large = TTSKitConfig(model: .qwen3TTS_1_7b)
+        XCTAssertEqual(large.versionDir, TTSModelVariant.qwen3TTS_1_7b.versionDir)
+        XCTAssertEqual(large.multiCodeDecoderVariant, Qwen3VariantDefaults.multiCodeDecoder)
+    }
+
+    /// Pins the documented sampling defaults so accidental changes are caught.
+    func testGenerationOptionsDefaults() {
+        let opts = GenerationOptions()
+        XCTAssertEqual(opts.temperature, 0.9)
+        XCTAssertEqual(opts.topK, 50)
+        XCTAssertEqual(opts.repetitionPenalty, 1.05)
+        XCTAssertEqual(opts.maxNewTokens, 245)
+        XCTAssertNil(opts.chunkingStrategy)
+        XCTAssertNil(opts.instruction)
+        XCTAssertNotNil(opts.concurrentWorkerCount)
+    }
+
+    /// Download patterns must be recursive globs covering each model component directory.
+    func testDownloadPatterns() {
+        let config = TTSKitConfig()
+        let patterns = config.downloadPatterns
+        // 6 component directories expected; update if a component is added/removed.
+        XCTAssertEqual(patterns.count, 6)
+        XCTAssertTrue(patterns.allSatisfy { $0.hasSuffix("/**") })
+        XCTAssertTrue(patterns.contains { $0.contains("code_decoder/") })
+        XCTAssertTrue(patterns.contains { $0.contains("speech_decoder/") })
+        XCTAssertTrue(patterns.contains { $0.contains("text_projector/") })
+    }
+
+ // MARK: - Speaker & Language Enums
+
+    /// Every speaker must map to a distinct token ID.
+    func testSpeakerTokenIds() {
+        let uniqueIDs = Set(Qwen3Speaker.allCases.map(\.tokenID))
+        XCTAssertEqual(uniqueIDs.count, Qwen3Speaker.allCases.count, "Speaker token IDs should be unique")
+    }
+
+    /// Language token IDs are unique; English is pinned to its known ID.
+    func testLanguageTokenIds() {
+        let ids = Qwen3Language.allCases.map { $0.tokenID }
+        XCTAssertEqual(Set(ids).count, Qwen3Language.allCases.count, "Language token IDs should be unique")
+        XCTAssertEqual(Qwen3Language.english.tokenID, 2050)
+    }
+
+ // MARK: - Text Chunker
+
+    /// Char-level tokenizer helper: 1 unicode scalar = 1 token, perfectly round-trips.
+    /// Sizes are exact, making test assertions easy to reason about.
+    /// - Parameters:
+    ///   - targetChunkSize: token budget per chunk (defaults to the production default).
+    ///   - minChunkSize: below this, a trailing chunk merges into its predecessor.
+    private func makeChunker(
+        targetChunkSize: Int = TextChunker.defaultTargetChunkSize,
+        minChunkSize: Int = TextChunker.defaultMinChunkSize
+    ) -> TextChunker {
+        TextChunker(
+            targetChunkSize: targetChunkSize,
+            minChunkSize: minChunkSize,
+            // encode: one token per unicode scalar; decode: exact inverse.
+            encode: { $0.unicodeScalars.map { Int($0.value) } },
+            decode: { String($0.compactMap { Unicode.Scalar($0) }.map { Character($0) }) }
+        )
+    }
+
+    /// Text shorter than the target size comes back as a single chunk.
+    func testChunkerShortText() {
+        let shortChunker = makeChunker(targetChunkSize: 50, minChunkSize: 5)
+        let result = shortChunker.chunk("Hello world.")
+        XCTAssertEqual(result, ["Hello world."])
+    }
+
+    /// Empty and whitespace-only input yields no chunks at all.
+    func testChunkerEmptyText() {
+        let defaultChunker = makeChunker()
+        for blank in ["", " "] {
+            XCTAssertTrue(defaultChunker.chunk(blank).isEmpty)
+        }
+    }
+
+    /// Multi-sentence text wider than the window should split at sentence boundaries.
+    func testChunkerSentenceSplitting() {
+        // With char-level tokenizer (1 char = 1 token), targetChunkSize: 30 gives a 30-char window.
+        // "This is the first sentence." = 27 chars, so each window contains one sentence boundary.
+        let chunker = makeChunker(targetChunkSize: 30, minChunkSize: 5)
+        let text = "This is the first sentence. This is the second sentence. And here is a third one."
+        let chunks = chunker.chunk(text)
+        XCTAssertGreaterThan(chunks.count, 1, "Should split into multiple chunks")
+        // Content from both ends must survive the split.
+        let recombined = chunks.joined(separator: " ")
+        XCTAssertTrue(recombined.contains("first sentence"))
+        XCTAssertTrue(recombined.contains("third one"))
+    }
+
+    /// A trailing chunk below minChunkSize must fold into the previous chunk.
+    func testChunkerMergesTinyTrailing() {
+        // "A reasonably long sentence that is here." = exactly 40 chars = 40 tokens (char-level).
+        // targetChunkSize: 40 captures the full sentence in one window at the "." boundary.
+        // The trailing "X." = 2 tokens is below minChunkSize: 5, so it merges with the previous chunk.
+        let chunker = makeChunker(targetChunkSize: 40, minChunkSize: 5)
+        let text = "A reasonably long sentence that is here. X."
+        let chunks = chunker.chunk(text)
+        XCTAssertEqual(chunks.count, 1, "Tiny trailing chunk should merge with previous")
+    }
+
+ // MARK: - Embedding Math
+
+    /// zeroEmbed returns a dim-length vector whose entries are all zero.
+    func testZeroEmbed() {
+        let vector = EmbedUtilities.zeroEmbed(dim: 16)
+        XCTAssertEqual(vector.count, 16)
+        XCTAssertFalse(vector.contains { $0 != 0 })
+    }
+
+    /// Element-wise addition of two equal-length embeddings.
+    func testAddEmbeddings() {
+        let lhs: [FloatType] = [FloatType(1.0), FloatType(2.0), FloatType(3.0)]
+        let rhs: [FloatType] = [FloatType(4.0), FloatType(5.0), FloatType(6.0)]
+        let sum = EmbedUtilities.addEmbeddings(lhs, rhs)
+        XCTAssertEqual(sum.count, 3)
+        for (index, expected) in [Float(5.0), Float(7.0), Float(9.0)].enumerated() {
+            XCTAssertEqual(Float(sum[index]), expected, accuracy: 0.01)
+        }
+    }
+
+    /// Summing several embeddings reduces element-wise across the outer list.
+    func testSumEmbeddings() {
+        let embeds: [[FloatType]] = [
+            [FloatType(1.0), FloatType(2.0)],
+            [FloatType(3.0), FloatType(4.0)],
+            [FloatType(5.0), FloatType(6.0)],
+        ]
+        let result = EmbedUtilities.sumEmbeddings(embeds)
+        XCTAssertEqual(result.count, 2)
+        // Column sums: 1+3+5 = 9, 2+4+6 = 12.
+        XCTAssertEqual(Float(result[0]), 9.0, accuracy: 0.01)
+        XCTAssertEqual(Float(result[1]), 12.0, accuracy: 0.01)
+    }
+
+    /// Summing an empty list of embeddings yields an empty vector.
+    func testSumEmbeddingsEmpty() {
+        let summed = EmbedUtilities.sumEmbeddings([])
+        XCTAssertTrue(summed.isEmpty)
+    }
+
+    /// Round-trips a vector through the MLMultiArray embed layout [1, dim, 1, 1].
+    func testCreateAndExtractEmbed() throws {
+        let original: [FloatType] = [FloatType(1.5), FloatType(2.5), FloatType(3.5), FloatType(4.5)]
+        let arr = try EmbedUtilities.createEmbedMLArray(original)
+        XCTAssertEqual(arr.shape, [1, 4, 1, 1] as [NSNumber])
+
+        // Extraction must recover every element (within fp16 tolerance).
+        let extracted = EmbedUtilities.extractEmbed(from: arr)
+        XCTAssertEqual(extracted.count, 4)
+        for i in 0..<4 {
+            XCTAssertEqual(Float(extracted[i]), Float(original[i]), accuracy: 0.01)
+        }
+    }
+
+ // MARK: - KV Cache
+
+    /// A fresh non-stateful KVCache starts empty with backing arrays allocated.
+    func testKVCacheInit() throws {
+        let cache = try KVCache(cacheDim: 128, maxSeqLength: 32)
+        XCTAssertEqual(cache.cacheLength, 0)
+        XCTAssertEqual(cache.maxSeqLength, 32)
+        XCTAssertEqual(cache.cacheDim, 128)
+        XCTAssertFalse(cache.isStateful)
+        XCTAssertFalse(cache.isFull)
+        XCTAssertEqual(cache.freePositions, 31) // maxSeqLength - 1
+        // Non-stateful mode keeps explicit key/value arrays.
+        XCTAssertNotNil(cache.keyCache)
+        XCTAssertNotNil(cache.valueCache)
+    }
+
+    /// Stateful mode delegates storage to Core ML state, so no explicit arrays exist.
+    func testKVCacheStatefulInit() throws {
+        let cache = try KVCache(cacheDim: 128, maxSeqLength: 32, isStateful: true)
+        XCTAssertTrue(cache.isStateful)
+        XCTAssertNil(cache.keyCache)
+        XCTAssertNil(cache.valueCache)
+    }
+
+    /// Each update() advances the cache position by exactly one.
+    func testKVCacheUpdateAdvancesPosition() throws {
+        let cache = try KVCache(cacheDim: 4, maxSeqLength: 8)
+        XCTAssertEqual(cache.cacheLength, 0)
+
+        cache.update()
+        XCTAssertEqual(cache.cacheLength, 1)
+
+        cache.update()
+        XCTAssertEqual(cache.cacheLength, 2)
+        XCTAssertEqual(cache.freePositions, 5) // 8 - 1 - 2
+    }
+
+    /// isFull triggers one position before maxSeqLength (the last slot is reserved).
+    func testKVCacheIsFull() throws {
+        let cache = try KVCache(cacheDim: 4, maxSeqLength: 4)
+        // maxSeqLength=4, isFull when cacheLength >= 3 (maxSeqLength - 1)
+        cache.update()
+        cache.update()
+        XCTAssertFalse(cache.isFull)
+        cache.update()
+        XCTAssertTrue(cache.isFull)
+    }
+
+    /// reset() returns the cache position to zero.
+    func testKVCacheReset() throws {
+        let cache = try KVCache(cacheDim: 4, maxSeqLength: 8)
+        cache.update()
+        cache.update()
+        XCTAssertEqual(cache.cacheLength, 2)
+
+        cache.reset()
+        XCTAssertEqual(cache.cacheLength, 0)
+    }
+
+    /// SpeechDecoderCache allocates a hidden-context buffer shaped [1, hiddenDim, 1, hiddenContextLen].
+    func testSpeechDecoderCacheInit() throws {
+        let cache = try SpeechDecoderCache(
+            cacheDim: 64, maxSeqLength: 16, hiddenDim: 32, hiddenContextLen: 4
+        )
+        XCTAssertEqual(cache.hiddenDim, 32)
+        XCTAssertEqual(cache.hiddenContextLen, 4)
+        XCTAssertEqual(cache.hiddenContext.shape, [1, 32, 1, 4] as [NSNumber])
+    }
+
+ // MARK: - Sampler
+
+    /// Two samplers seeded identically must make identical decisions on the same logits.
+    func testGreedySamplerDeterministic() async throws {
+        // Two samplers with the same seed should produce the same result
+        let sampler1 = GreedyTokenSampler(seed: 42)
+        let sampler2 = GreedyTokenSampler(seed: 42)
+
+        let vocabSize = 32
+        let logits = try MLMultiArray(shape: [1, 1, NSNumber(value: vocabSize)], dataType: .float16)
+        let ptr = logits.dataPointer.bindMemory(to: FloatType.self, capacity: vocabSize)
+        // FIXME(review): the remainder of this test was garbled during extraction —
+        // the `for i in 0..<vocabSize { ... }` fill loop and the sampling assertions
+        // were collapsed onto a single line. Reconstructed below; verify the exact
+        // sampler call signature against the original diff before merging.
+        for i in 0..<vocabSize {
+            ptr[i] = FloatType(Float(i) / Float(vocabSize))
+        }
+
+        let result1 = sampler1.update(tokens: [], logits: logits, logProbs: [])
+        let result2 = sampler2.update(tokens: [], logits: logits, logProbs: [])
+        XCTAssertEqual(result1.tokens, result2.tokens, "Same seed should produce the same tokens")
+    }
+
+    // MARK: - Constants
+
+    /// Reconstructed from a garbled remnant ("= 2048 && Qwen3TTSConstants.codecEOS < 3072"):
+    /// the codec EOS token must fall within [2048, 3072) — TODO confirm against original diff.
+    func testCodecEOSTokenRange() {
+        XCTAssertTrue(Qwen3TTSConstants.codecEOS >= 2048 && Qwen3TTSConstants.codecEOS < 3072)
+    }
+
+ // MARK: - Unloaded Model Errors
+
+    /// Before loading, the code decoder exposes no model, no statefulness, no state.
+    func testCodeDecoderWithoutModel() {
+        let decoder = Qwen3CodeDecoder()
+        XCTAssertNil(decoder.model)
+        XCTAssertFalse(decoder.isStateful)
+        XCTAssertNil(decoder.makeState())
+    }
+
+    /// Before loading, the multi-code decoder exposes no model, no statefulness, no state.
+    func testMultiCodeDecoderWithoutModel() {
+        let decoder = Qwen3MultiCodeDecoder()
+        XCTAssertNil(decoder.model)
+        XCTAssertFalse(decoder.isStateful)
+        XCTAssertNil(decoder.makeState())
+    }
+
+    /// Before loading, the speech decoder has no backing model.
+    func testSpeechDecoderWithoutModel() {
+        let decoder = Qwen3SpeechDecoder()
+        XCTAssertNil(decoder.model)
+    }
+
+ // MARK: - Chunking Strategy
+
+    /// Pins the chunking-strategy cases and their raw values (serialized externally).
+    func testChunkingStrategyEnum() {
+        XCTAssertEqual(TextChunkingStrategy.allCases.count, 2)
+        XCTAssertEqual(TextChunkingStrategy.none.rawValue, "none")
+        XCTAssertEqual(TextChunkingStrategy.sentence.rawValue, "sentence")
+    }
+
+ // MARK: - SeededRNG Determinism
+
+    /// Identical seeds must produce identical 100-value streams.
+    func testSeededRNGDeterminism() {
+        var lhs = SeededRandomNumberGenerator(seed: 99)
+        var rhs = SeededRandomNumberGenerator(seed: 99)
+        let lhsStream = (0..<100).map { _ in lhs.next() }
+        let rhsStream = (0..<100).map { _ in rhs.next() }
+        XCTAssertEqual(lhsStream, rhsStream)
+    }
+
+    /// Different seeds should diverge within the first 10 draws.
+    func testSeededRNGDifferentSeeds() {
+        var rngA = SeededRandomNumberGenerator(seed: 1)
+        var rngB = SeededRandomNumberGenerator(seed: 2)
+        // Very unlikely for all 10 values to match with different seeds;
+        // contains(where:) short-circuits on the first mismatch.
+        let diverged = (0..<10).contains { _ in rngA.next() != rngB.next() }
+        XCTAssertTrue(diverged, "Different seeds should produce different sequences")
+    }
+
+ // MARK: - SpeechProgress
+
+    /// SpeechProgress stores audio, timings, and the optional step time verbatim.
+    func testSpeechProgressInit() {
+        let samples: [Float] = [0.1, 0.2, 0.3]
+        let timings = SpeechTimings()
+        let progress = SpeechProgress(audio: samples, timings: timings, stepTime: 0.08)
+        XCTAssertEqual(progress.audio, samples)
+        XCTAssertEqual(progress.timings.fullPipeline, 0)
+        XCTAssertEqual(progress.stepTime ?? 0, 0.08, accuracy: 0.0001)
+    }
+
+    /// stepTime defaults to nil when not supplied.
+    func testSpeechProgressStepTimeNilByDefault() {
+        let progress = SpeechProgress(audio: [], timings: SpeechTimings())
+        XCTAssertNil(progress.stepTime)
+    }
+
+    /// Convention check: a non-nil stepTime marks the first step of a generation.
+    func testSpeechProgressFirstStepSemantics() {
+        // stepTime non-nil signals first step; subsequent steps have nil
+        let first = SpeechProgress(audio: [0.1], timings: SpeechTimings(), stepTime: 0.05)
+        let subsequent = SpeechProgress(audio: [0.2], timings: SpeechTimings(), stepTime: nil)
+        XCTAssertNotNil(first.stepTime)
+        XCTAssertNil(subsequent.stepTime)
+    }
+
+ // MARK: - PlaybackStrategy
+
+    /// audioPerStep = samplesPerFrame / sampleRate; ≈80ms for 1920 samples at 24kHz.
+    func testAudioPerStep() {
+        let spf = Qwen3TTSConstants.samplesPerFrame
+        let sr = Qwen3TTSConstants.sampleRate
+        let expected = Double(spf) / Double(sr)
+        XCTAssertEqual(PlaybackStrategy.audioPerStep(samplesPerFrame: spf, sampleRate: sr), expected, accuracy: 0.0001)
+        // ~80ms per frame at 24kHz / 1920 samples
+        XCTAssertEqual(PlaybackStrategy.audioPerStep(samplesPerFrame: spf, sampleRate: sr), 0.08, accuracy: 0.001)
+    }
+
+    /// A faster-than-real-time device only needs the minimum buffer.
+    func testRequiredBufferFastDevice() {
+        // Step completes in half the frame duration -> device is 2x real-time
+        // deficit = max(0, 1 - 2.0) = 0, so result = minimumBufferDuration
+        let spf = Qwen3TTSConstants.samplesPerFrame
+        let sr = Qwen3TTSConstants.sampleRate
+        let stepTime = PlaybackStrategy.audioPerStep(samplesPerFrame: spf, sampleRate: sr) / 2.0
+        let buffer = PlaybackStrategy.requiredBuffer(stepTime: stepTime, maxNewTokens: 100, samplesPerFrame: spf, sampleRate: sr)
+        XCTAssertEqual(buffer, PlaybackStrategy.minimumBufferDuration, accuracy: 0.001)
+    }
+
+    /// A half-real-time device must pre-buffer half of the maximum audio length.
+    func testRequiredBufferSlowDevice() {
+        // Step takes 2x the frame duration -> device is at 0.5x real-time
+        // speedRatio = 0.5, deficit = 0.5, maxAudio = 100 * 0.08 = 8s
+        // deficitBuffer = 8 * 0.5 = 4s > minimumBufferDuration
+        let spf = Qwen3TTSConstants.samplesPerFrame
+        let sr = Qwen3TTSConstants.sampleRate
+        let stepTime = PlaybackStrategy.audioPerStep(samplesPerFrame: spf, sampleRate: sr) * 2.0
+        let buffer = PlaybackStrategy.requiredBuffer(stepTime: stepTime, maxNewTokens: 100, samplesPerFrame: spf, sampleRate: sr)
+        XCTAssertGreaterThan(buffer, PlaybackStrategy.minimumBufferDuration)
+        // Exact: 100 * 0.08 * 0.5 = 4.0s
+        XCTAssertEqual(buffer, 4.0, accuracy: 0.01)
+    }
+
+    /// At exactly real-time speed the deficit is zero, so the minimum clamp applies.
+    func testRequiredBufferAtExactRealTime() {
+        // Step equals frame duration -> speedRatio = 1, deficit = 0 -> minimum clamp applies
+        let spf = Qwen3TTSConstants.samplesPerFrame
+        let sr = Qwen3TTSConstants.sampleRate
+        let stepTime = PlaybackStrategy.audioPerStep(samplesPerFrame: spf, sampleRate: sr)
+        let buffer = PlaybackStrategy.requiredBuffer(stepTime: stepTime, maxNewTokens: 50, samplesPerFrame: spf, sampleRate: sr)
+        XCTAssertEqual(buffer, PlaybackStrategy.minimumBufferDuration, accuracy: 0.001)
+    }
+
+    /// The returned buffer never drops below minimumBufferDuration, however fast the device.
+    func testRequiredBufferMinimumNeverExceeded() {
+        // Even a very fast device should never return less than the minimum
+        let buffer = PlaybackStrategy.requiredBuffer(
+            stepTime: 0.001, maxNewTokens: 200,
+            samplesPerFrame: Qwen3TTSConstants.samplesPerFrame, sampleRate: Qwen3TTSConstants.sampleRate
+        )
+        XCTAssertGreaterThanOrEqual(buffer, PlaybackStrategy.minimumBufferDuration)
+    }
+
+ // MARK: - TTSModelVariant
+
+    /// Human-readable display names for each model preset.
+    func testModelPresetDisplayNames() {
+        let expectations: [(TTSModelVariant, String)] = [
+            (.qwen3TTS_0_6b, "Qwen3 TTS 0.6B"),
+            (.qwen3TTS_1_7b, "Qwen3 TTS 1.7B"),
+        ]
+        for (preset, expectedName) in expectations {
+            XCTAssertEqual(preset.displayName, expectedName)
+        }
+    }
+
+    /// Only the 1.7B preset supports voice direction (instructions).
+    func testModelPresetSupportsVoiceDirection() {
+        XCTAssertFalse(TTSModelVariant.qwen3TTS_0_6b.supportsVoiceDirection)
+        XCTAssertTrue(TTSModelVariant.qwen3TTS_1_7b.supportsVoiceDirection)
+    }
+
+    /// Presets must resolve to distinct version directories so downloads never collide.
+    func testModelPresetVersionDirsDiffer() {
+        let smallDir = TTSModelVariant.qwen3TTS_0_6b.versionDir
+        let largeDir = TTSModelVariant.qwen3TTS_1_7b.versionDir
+        XCTAssertNotEqual(smallDir, largeDir)
+    }
+
+    /// Platform gating: every preset is available on macOS; on other platforms
+    /// only the 0.6B model is offered.
+    func testModelPresetAvailabilityOnMacOS() {
+        // On macOS, all presets should be available
+        #if os(macOS)
+        for preset in TTSModelVariant.allCases {
+            XCTAssertTrue(preset.isAvailableOnCurrentPlatform, "\(preset) should be available on macOS")
+        }
+        #else
+        XCTAssertTrue(TTSModelVariant.qwen3TTS_0_6b.isAvailableOnCurrentPlatform)
+        XCTAssertFalse(TTSModelVariant.qwen3TTS_1_7b.isAvailableOnCurrentPlatform)
+        #endif
+    }
+
+    /// Variant strings are shared across presets.
+    func testModelPresetVariantDefaultsConsistent() {
+        // Both presets share the same variant strings (all quantization is size-independent)
+        XCTAssertEqual(TTSModelVariant.qwen3TTS_0_6b.codeDecoderVariant, Qwen3VariantDefaults.codeDecoder)
+        XCTAssertEqual(TTSModelVariant.qwen3TTS_1_7b.codeDecoderVariant, Qwen3VariantDefaults.codeDecoder)
+        XCTAssertEqual(TTSModelVariant.qwen3TTS_0_6b.speechDecoderVariant, Qwen3VariantDefaults.speechDecoder)
+    }
+
+ // MARK: - TTSKitConfig Component Overrides
+
+    /// All per-component override hooks start nil on a default config.
+    func testTTSKitConfigComponentOverridesNilByDefault() {
+        let config = TTSKitConfig()
+        XCTAssertNil(config.textProjector)
+        XCTAssertNil(config.codeEmbedder)
+        XCTAssertNil(config.multiCodeEmbedder)
+        XCTAssertNil(config.codeDecoder)
+        XCTAssertNil(config.multiCodeDecoder)
+        XCTAssertNil(config.speechDecoder)
+    }
+
+ // MARK: - GenerationOptions Additional Defaults
+
+    /// Chunk-size options default to nil; TextChunker owns the canonical defaults.
+    func testGenerationOptionsChunkingDefaults() {
+        let opts = GenerationOptions()
+        // nil defers to TextChunker.defaultTargetChunkSize / defaultMinChunkSize at call site
+        XCTAssertNil(opts.targetChunkSize)
+        XCTAssertNil(opts.minChunkSize)
+        // Verify the canonical defaults live in TextChunker
+        XCTAssertEqual(TextChunker.defaultTargetChunkSize, 42)
+        XCTAssertEqual(TextChunker.defaultMinChunkSize, 10)
+    }
+
+ // MARK: - Speaker & Language Round-trips
+
+    /// rawValue → init(rawValue:) must round-trip for every speaker case.
+    func testSpeakerRawValueRoundTrip() {
+        for speaker in Qwen3Speaker.allCases {
+            XCTAssertEqual(Qwen3Speaker(rawValue: speaker.rawValue), speaker,
+                           "rawValue round-trip failed for \(speaker)")
+        }
+    }
+
+    /// rawValue → init(rawValue:) must round-trip for every language case.
+    func testLanguageRawValueRoundTrip() {
+        for lang in Qwen3Language.allCases {
+            XCTAssertEqual(Qwen3Language(rawValue: lang.rawValue), lang,
+                           "rawValue round-trip failed for \(lang)")
+        }
+    }
+
+    /// Unknown raw values produce nil (failable initializer), not a fallback case.
+    /// NOTE(review): the name says "FallsBack" but the assertion checks nil — consider renaming.
+    func testUnrecognisedSpeakerFallsBack() {
+        XCTAssertNil(Qwen3Speaker(rawValue: "nonexistent_speaker"))
+    }
+
+    /// Unknown raw values produce nil (failable initializer), not a fallback case.
+    /// NOTE(review): the name says "FallsBack" but the assertion checks nil — consider renaming.
+    func testUnrecognisedLanguageFallsBack() {
+        XCTAssertNil(Qwen3Language(rawValue: "klingon"))
+    }
+
+ // MARK: - SpeechTimings.merge
+
+    /// merge(_:) adds each timing field of the argument into the receiver.
+    func testMergeTimingsAccumulates() {
+        var combined = SpeechTimings()
+        combined.decodingLoop = 1.0
+        combined.totalDecodingLoops = 10
+        combined.decodingPredictions = 0.5
+
+        var chunk = SpeechTimings()
+        chunk.decodingLoop = 2.0
+        chunk.totalDecodingLoops = 20
+        chunk.decodingPredictions = 0.3
+        chunk.multiCodeDecoderPredictions = 0.1
+        chunk.speechDecoderPredictions = 0.2
+
+        combined.merge(chunk)
+
+        // Every field is the sum of the two inputs (fields absent on one side add 0).
+        XCTAssertEqual(combined.decodingLoop, 3.0, accuracy: 0.001)
+        XCTAssertEqual(combined.totalDecodingLoops, 30, accuracy: 0.001)
+        XCTAssertEqual(combined.decodingPredictions, 0.8, accuracy: 0.001)
+        XCTAssertEqual(combined.multiCodeDecoderPredictions, 0.1, accuracy: 0.001)
+        XCTAssertEqual(combined.speechDecoderPredictions, 0.2, accuracy: 0.001)
+    }
+
+    /// Merging an empty SpeechTimings leaves the accumulator unchanged.
+    /// (Dropped the spurious `async throws` — the body has no await or try.)
+    func testMergeTimingsIdentity() {
+        var combined = SpeechTimings()
+        combined.decodingLoop = 5.0
+        let empty = SpeechTimings()
+        combined.merge(empty)
+        XCTAssertEqual(combined.decodingLoop, 5.0, accuracy: 0.001)
+    }
+
+ // MARK: - TextChunker Edge Cases
+
+    /// No sentence content may be dropped while chunking a long multi-sentence text.
+    func testChunkerPreservesAllText() {
+        // Each "Sentence number N is here." ≈ 26-27 chars; targetChunkSize: 50 spans ~2 sentences.
+        let chunker = makeChunker(targetChunkSize: 50, minChunkSize: 5)
+        let sentences = (1...10).map { "Sentence number \($0) is here." }
+        let text = sentences.joined(separator: " ")
+        let chunks = chunker.chunk(text)
+        let rejoined = chunks.joined(separator: " ")
+        for sentence in sentences {
+            XCTAssertTrue(rejoined.contains(sentence.dropLast(1)), // drop period; joins may vary
+                          "Missing content: \(sentence)")
+        }
+    }
+
+    /// Punctuation-free text must still split, falling back to word boundaries.
+    func testChunkerWordBoundaryFallback() {
+        // No punctuation in text - chunker must fall back to word-boundary splits.
+        // With char-level tokenizer, targetChunkSize: 15 gives 15-char windows.
+        // Note: multi-word phrase checks are intentionally avoided here because the
+        // re-encode advance can leave a stray char when the token stream has a leading
+        // space (e.g. "very long" → 9 tokens, but the stream starts with " very lon").
+        // This is a char-level mock artifact; BPE tokenizers fold whitespace into the
+        // next word token, so no drift occurs in production.
+        let chunker = makeChunker(targetChunkSize: 15, minChunkSize: 3)
+        let text = "This is a very long text with no punctuation inside it at all here"
+        let chunks = chunker.chunk(text)
+        XCTAssertGreaterThan(chunks.count, 1, "Long text without punctuation should split at word boundaries")
+        let rejoined = chunks.joined(separator: " ")
+        for word in ["long", "text", "punctuation", "here"] {
+            XCTAssertTrue(rejoined.contains(word), "Missing word in output: \(word)")
+        }
+    }
+
+ // MARK: - Prompt Cache
+
+    /// A cache entry matches only when voice, language, AND instruction all agree.
+    func testPromptCacheMatching() {
+        let cache = TTSPromptCache(
+            voice: "ryan", language: "english", instruction: nil,
+            prefixLength: 9,
+            // Snapshot contents are irrelevant to matching; an empty one suffices.
+            kvSnapshot: KVCacheSnapshot(
+                isStateful: false, cacheDim: 1, maxSeqLength: 1, cacheLength: 0,
+                keyCacheData: Data(), valueCacheData: Data(),
+                updateMaskData: Data(), paddingMaskData: Data()
+            ),
+            stateData: nil
+        )
+
+        XCTAssertTrue(cache.matches(voice: "ryan", language: "english", instruction: nil))
+        XCTAssertFalse(cache.matches(voice: "aiden", language: "english", instruction: nil))
+        XCTAssertFalse(cache.matches(voice: "ryan", language: "korean", instruction: nil))
+        XCTAssertFalse(cache.matches(voice: "ryan", language: "english", instruction: "Speak softly"))
+    }
+
+    /// With an instruction set, matching requires the exact same instruction — nil
+    /// or a different instruction must miss.
+    func testPromptCacheMatchingWithInstruction() {
+        let cache = TTSPromptCache(
+            voice: "ryan", language: "english", instruction: "Speak softly",
+            prefixLength: 20,
+            kvSnapshot: KVCacheSnapshot(
+                isStateful: false, cacheDim: 1, maxSeqLength: 1, cacheLength: 0,
+                keyCacheData: Data(), valueCacheData: Data(),
+                updateMaskData: Data(), paddingMaskData: Data()
+            ),
+            stateData: nil
+        )
+
+        XCTAssertTrue(cache.matches(voice: "ryan", language: "english", instruction: "Speak softly"))
+        XCTAssertFalse(cache.matches(voice: "ryan", language: "english", instruction: nil))
+        XCTAssertFalse(cache.matches(voice: "ryan", language: "english", instruction: "Speak loudly"))
+    }
+
+    /// Builds an empty KV snapshot; its contents are irrelevant to file-name tests.
+    /// Extracted to remove the duplicated literal construction.
+    private func makeEmptySnapshot() -> KVCacheSnapshot {
+        KVCacheSnapshot(
+            isStateful: false, cacheDim: 1, maxSeqLength: 1, cacheLength: 0,
+            keyCacheData: Data(), valueCacheData: Data(),
+            updateMaskData: Data(), paddingMaskData: Data()
+        )
+    }
+
+    /// Cache file names: "voice_language.promptcache" without an instruction, and
+    /// "voice_language_<suffix>.promptcache" with one (suffix is instruction-derived —
+    /// presumably a hash; confirm against TTSPromptCache).
+    func testPromptCacheFileName() {
+        let cache1 = TTSPromptCache(
+            voice: "ryan", language: "english", instruction: nil,
+            prefixLength: 9,
+            kvSnapshot: makeEmptySnapshot(),
+            stateData: nil
+        )
+        XCTAssertEqual(cache1.cacheFileName, "ryan_english.promptcache")
+
+        let cache2 = TTSPromptCache(
+            voice: "aiden", language: "korean", instruction: "Speak slowly",
+            prefixLength: 20,
+            kvSnapshot: makeEmptySnapshot(),
+            stateData: nil
+        )
+        XCTAssertTrue(cache2.cacheFileName.hasPrefix("aiden_korean_"))
+        XCTAssertTrue(cache2.cacheFileName.hasSuffix(".promptcache"))
+    }
+
+ func testKVCacheSnapshotRoundTrip() throws {
+ let dim = 4
+ let seq = 8
+ let cache = try KVCache(cacheDim: dim, maxSeqLength: seq, isStateful: false)
+
+ // Simulate 3 prefill steps by advancing position and writing dummy data
+ for step in 0..<3 {
+ if let keyCache = cache.keyCache {
+ let ptr = keyCache.dataPointer.bindMemory(to: FloatType.self, capacity: dim * seq)
+ for d in 0..