diff --git a/.swiftpm/xcode/xcshareddata/xcschemes/whisperkit-Package.xcscheme b/.swiftpm/xcode/xcshareddata/xcschemes/whisperkit-Package.xcscheme index da13128a..80b47f47 100644 --- a/.swiftpm/xcode/xcshareddata/xcschemes/whisperkit-Package.xcscheme +++ b/.swiftpm/xcode/xcshareddata/xcschemes/whisperkit-Package.xcscheme @@ -48,6 +48,20 @@ ReferencedContainer = "container:"> + + + + + + + + diff --git a/Examples/TTS/SpeakAX/SpeakAX.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved b/Examples/TTS/SpeakAX/SpeakAX.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved new file mode 100644 index 00000000..083d195a --- /dev/null +++ b/Examples/TTS/SpeakAX/SpeakAX.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved @@ -0,0 +1,42 @@ +{ + "originHash" : "105884455b6729deaf95d3dd16df7029fd7929c924fcab55200562b52b64dbc2", + "pins" : [ + { + "identity" : "swift-argument-parser", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-argument-parser.git", + "state" : { + "revision" : "c5d11a805e765f52ba34ec7284bd4fcd6ba68615", + "version" : "1.7.0" + } + }, + { + "identity" : "swift-collections", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-collections.git", + "state" : { + "revision" : "7b847a3b7008b2dc2f47ca3110d8c782fb2e5c7e", + "version" : "1.3.0" + } + }, + { + "identity" : "swift-jinja", + "kind" : "remoteSourceControl", + "location" : "https://github.com/huggingface/swift-jinja.git", + "state" : { + "revision" : "d81197f35f41445bc10e94600795e68c6f5e94b0", + "version" : "2.3.1" + } + }, + { + "identity" : "swift-transformers", + "kind" : "remoteSourceControl", + "location" : "https://github.com/huggingface/swift-transformers.git", + "state" : { + "revision" : "573e5c9036c2f136b3a8a071da8e8907322403d0", + "version" : "1.1.6" + } + } + ], + "version" : 3 +} diff --git a/Examples/TTS/SpeakAX/SpeakAX/AudioMetadata.swift 
b/Examples/TTS/SpeakAX/SpeakAX/AudioMetadata.swift new file mode 100644 index 00000000..de35b715 --- /dev/null +++ b/Examples/TTS/SpeakAX/SpeakAX/AudioMetadata.swift @@ -0,0 +1,147 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright © 2026 Argmax, Inc. All rights reserved. + +import AVFoundation +import Foundation + +/// Metadata for a single TTS generation, embedded inside its `.m4a` file. +/// +/// The full struct is serialized as JSON and stored in the iTunes `©cmt` atom so +/// the complete generation context is recoverable from the file alone. +/// Human-readable fields (title, artist, album) are also embedded +struct AudioMetadata: Codable, Sendable { + static let metadataTitleMaxLength = 80 + let id: UUID + let text: String + let speaker: String + let language: String + let instruction: String + let modelName: String + let realTimeFactor: Double + let speedFactor: Double + let stepsPerSecond: Double + let timeToFirstBuffer: TimeInterval + let date: Date + + init( + id: UUID = UUID(), + text: String, + speaker: String, + language: String, + instruction: String, + modelName: String, + realTimeFactor: Double, + speedFactor: Double, + stepsPerSecond: Double, + timeToFirstBuffer: TimeInterval, + date: Date = Date() + ) { + self.id = id + self.text = text + self.speaker = speaker + self.language = language + self.instruction = instruction + self.modelName = modelName + self.realTimeFactor = realTimeFactor + self.speedFactor = speedFactor + self.stepsPerSecond = stepsPerSecond + self.timeToFirstBuffer = timeToFirstBuffer + self.date = date + } + + // MARK: - Filename + + private static let filenameDateFormatter: DateFormatter = { + let f = DateFormatter() + f.dateFormat = "yyyyMMdd'T'HHmmss'Z'" + f.locale = Locale(identifier: "en_US_POSIX") + f.timeZone = TimeZone(identifier: "UTC") + return f + }() + + var suggestedFileName: String { + let slug = modelName + .lowercased() + .replacingOccurrences(of: " ", with: "-") + return 
"\(speaker)_\(slug)_\(Self.filenameDateFormatter.string(from: date))" + } + + // MARK: - AVFoundation metadata + + /// Build the `[AVMetadataItem]` array for `AVAssetExportSession`. + /// + /// All fields use `commonIdentifier` variants so they work across every + /// container format AVFoundation supports (M4A, MOV, MP3, AIFF...). + /// + /// Fields verified to survive `AVAssetExportSession` M4A export (confirmed via ffprobe): + /// - `©nam` title - first 80 chars of the generated text + /// - `©ART` artist - speaker name + /// - `©alb` album - model (e.g. "qwen3_tts_12hz-1.7b-customvoice") + /// - `©lyr` lyrics - full untruncated text + /// - `©too` encoder - "TTSKit v{appVersion}" (shows as `encoder` in ffprobe) + /// - `©cmt` comment - JSON blob; the app reads this back to reconstruct the generation + /// + /// `.commonIdentifierCreator` and `.commonIdentifierDescription` were tried but do not + /// survive M4A export - iTunes-specific atoms are used for those two slots instead. + func avMetadataItems() throws -> [AVMetadataItem] { + let encoder = JSONEncoder() + encoder.dateEncodingStrategy = .iso8601 + let json = try String(data: encoder.encode(self), encoding: .utf8) ?? "" + + let appVersion = Bundle.main.infoDictionary?["CFBundleShortVersionString"] as? String ?? "1.0" + let maxLen = Self.metadataTitleMaxLength + let title = text.count > maxLen ? String(text.prefix(maxLen - 3)) + "..." 
: text + let langTag = Self.bcp47Tag(for: language) + + func item(_ identifier: AVMetadataIdentifier, _ value: String, lang: String = "und") -> AVMetadataItem { + let i = AVMutableMetadataItem() + i.identifier = identifier + i.value = value as NSString + i.extendedLanguageTag = lang + return i + } + + return [ + item(.commonIdentifierTitle, title, lang: langTag), + item(.commonIdentifierArtist, speaker.capitalized), + item(.commonIdentifierAlbumName, modelName), + item(.iTunesMetadataLyrics, text, lang: langTag), + item(.iTunesMetadataEncodingTool, "TTSKit v\(appVersion)"), + item(.iTunesMetadataUserComment, json), + ] + } + + private static func bcp47Tag(for language: String) -> String { + switch language.lowercased() { + case "english": return "en" + case "chinese": return "zh" + case "japanese": return "ja" + case "korean": return "ko" + case "german": return "de" + case "french": return "fr" + case "russian": return "ru" + case "portuguese": return "pt" + case "spanish": return "es" + case "italian": return "it" + default: return "und" + } + } + + // MARK: - Loading from file + + /// Reconstruct metadata from the `©cmt` atom of an `.m4a` file. + /// Returns `nil` if the file has no TTSKit metadata. + static func load(from url: URL) async throws -> AudioMetadata? { + let asset = AVURLAsset(url: url) + let items = try await asset.load(.metadata) + guard let commentItem = items.first(where: { + $0.identifier == .iTunesMetadataUserComment + }), + let json = try await commentItem.load(.stringValue), + let data = json.data(using: .utf8) else { return nil } + + let decoder = JSONDecoder() + decoder.dateDecodingStrategy = .iso8601 + return try? 
decoder.decode(AudioMetadata.self, from: data) + } +} diff --git a/Examples/TTS/SpeakAX/SpeakAX/ComputeUnitsView.swift b/Examples/TTS/SpeakAX/SpeakAX/ComputeUnitsView.swift new file mode 100644 index 00000000..457f1fed --- /dev/null +++ b/Examples/TTS/SpeakAX/SpeakAX/ComputeUnitsView.swift @@ -0,0 +1,94 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright © 2026 Argmax, Inc. All rights reserved. + +import CoreML +import SwiftUI +import TTSKit + +/// Collapsible sidebar section for configuring per-component ML compute units. +/// Matches the pattern from WhisperAX's computeUnitsView. +struct ComputeUnitsView: View { + @Environment(ViewModel.self) private var viewModel + @State private var isExpanded = false + + var body: some View { + @Bindable var vm = viewModel + + DisclosureGroup(isExpanded: $isExpanded) { + VStack(spacing: 8) { + computeRow( + label: "Embedders", + units: vm.embedderComputeUnits, + onChange: { vm.embedderComputeUnits = $0; reloadIfNeeded() } + ) + computeRow( + label: "Code Decoder", + units: vm.codeDecoderComputeUnits, + onChange: { vm.codeDecoderComputeUnits = $0; reloadIfNeeded() } + ) + computeRow( + label: "Multi-Code Decoder", + units: vm.multiCodeDecoderComputeUnits, + onChange: { vm.multiCodeDecoderComputeUnits = $0; reloadIfNeeded() } + ) + computeRow( + label: "Speech Decoder", + units: vm.speechDecoderComputeUnits, + onChange: { vm.speechDecoderComputeUnits = $0; reloadIfNeeded() } + ) + } + .padding(.top, 4) + } label: { + Button { + isExpanded.toggle() + } label: { + Text("Compute Units") + .font(.headline) + } + .buttonStyle(.plain) + } + .disabled(viewModel.modelState.isBusy) + .padding(.horizontal, 12) + .padding(.vertical, 6) + } + + @ViewBuilder + private func computeRow( + label: String, + units: MLComputeUnits, + onChange: @escaping (MLComputeUnits) -> Void + ) -> some View { + HStack(spacing: 8) { + Text(label) + .font(.caption) + .foregroundStyle(.secondary) + .lineLimit(1) + .minimumScaleFactor(0.85) + 
.layoutPriority(1) + + Spacer(minLength: 4) + + Picker("", selection: Binding( + get: { units }, + set: { onChange($0) } + )) { + Text("CPU").tag(MLComputeUnits.cpuOnly) + Text("GPU").tag(MLComputeUnits.cpuAndGPU) + #if os(iOS) + // Abbreviated on iOS to prevent overflow + Text("NE").tag(MLComputeUnits.cpuAndNeuralEngine) + #else + Text("Neural Engine").tag(MLComputeUnits.cpuAndNeuralEngine) + #endif + } + .labelsHidden() + .pickerStyle(.menu) + .fixedSize() + } + } + + private func reloadIfNeeded() { + guard viewModel.modelState == .loaded else { return } + viewModel.reloadModelForComputeUnitChange() + } +} diff --git a/Examples/TTS/SpeakAX/SpeakAX/ContentView.swift b/Examples/TTS/SpeakAX/SpeakAX/ContentView.swift new file mode 100644 index 00000000..7c5b6d10 --- /dev/null +++ b/Examples/TTS/SpeakAX/SpeakAX/ContentView.swift @@ -0,0 +1,63 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright © 2026 Argmax, Inc. All rights reserved. + +import SwiftUI + +struct ContentView: View { + @Environment(ViewModel.self) private var viewModel + var body: some View { + NavigationSplitView { + SidebarView() + } detail: { + DetailView() + } + .navigationSplitViewStyle(.balanced) + #if os(macOS) + // SwiftUI frame constraints don't prevent NavigationSplitView from entering + // overlay mode. Setting minSize directly on NSWindow is a reliable fix. + .background(WindowMinSizeEnforcer(minSize: CGSize(width: 840, height: 560))) + #endif + } +} + +// MARK: - macOS window minimum size enforcer + +#if os(macOS) +/// Enforces `NSWindow.minSize` at the AppKit level. +/// +/// SwiftUI's `.frame(minWidth:)` and `windowResizability` both fail to prevent +/// `NavigationSplitView` from shrinking into overlay mode. Going through AppKit +/// directly is a reliable way to clamp the resize handle. +/// +/// A custom `NSView` subclass is used instead of `DispatchQueue.main.async` because +/// the view's `window` property is `nil` during `makeNSView`. 
+private struct WindowMinSizeEnforcer: NSViewRepresentable { + let minSize: CGSize + + func makeNSView(context: Context) -> MinSizeView { + let view = MinSizeView() + view.minSize = minSize + return view + } + + func updateNSView(_ nsView: MinSizeView, context: Context) { + nsView.minSize = minSize + } +} + +private final class MinSizeView: NSView { + var minSize: CGSize = .zero { + didSet { window?.minSize = minSize } + } + + override func viewDidMoveToWindow() { + super.viewDidMoveToWindow() + window?.minSize = minSize + } +} +#endif + +#Preview { + ContentView() + .environment(ViewModel()) +} diff --git a/Examples/TTS/SpeakAX/SpeakAX/DetailView.swift b/Examples/TTS/SpeakAX/SpeakAX/DetailView.swift new file mode 100644 index 00000000..1aa680e3 --- /dev/null +++ b/Examples/TTS/SpeakAX/SpeakAX/DetailView.swift @@ -0,0 +1,584 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright © 2026 Argmax, Inc. All rights reserved. + +import ArgmaxCore +import SwiftUI +import TTSKit +import UniformTypeIdentifiers +#if canImport(UIKit) +import UIKit +#elseif canImport(AppKit) +import AppKit +#endif + +struct DetailView: View { + @Environment(ViewModel.self) private var viewModel + + @State private var isImportingFile = false + @State private var fileImportError: String? 
= nil + + var body: some View { + @Bindable var vm = viewModel + + VStack(spacing: 0) { + // Waveform + playback controls (~top half) + waveformSection + .frame(minHeight: 140) + + Divider() + + // Text input + controls (bottom half) + VStack(spacing: 16) { + textInputSection + metricsRow + controlsBar + } + .padding(20) + } + .frame(maxWidth: .infinity, maxHeight: .infinity) + .navigationTitle("TTSKit Example") + #if os(iOS) + .navigationBarTitleDisplayMode(.inline) + #endif + .toolbar { + toolbarContent + } + .sheet(isPresented: $vm.showGenerationSettings) { + GenerationSettingsView() + .environment(viewModel) + .presentationDetents([.medium, .large]) + .presentationBackgroundInteraction(.enabled) + .presentationContentInteraction(.scrolls) + } + .onAppear { + viewModel.onAppear() + } + .onChange(of: viewModel.selectedGenerationID) { + guard !viewModel.isStreaming else { return } + guard let gen = viewModel.selectedGeneration else { return } + viewModel.loadWaveform(for: gen) + // Auto-populate input fields so the user can play with the settings and regenerate + viewModel.loadInputs(from: gen) + } + } + + // MARK: - Waveform Section + + @ViewBuilder + private var waveformSection: some View { + VStack(spacing: 12) { + Spacer(minLength: 8) + + // All three states live in the layout simultaneously so the height + // never changes - only opacity transitions between them. + ZStack { + // Idle empty placeholder + VStack(spacing: 8) { + Image(systemName: "waveform") + .font(.system(size: 40)) + .foregroundStyle(.quaternary) + Text("Waveform will appear here") + .font(.subheadline) + .foregroundStyle(.tertiary) + } + .frame(maxWidth: .infinity, maxHeight: .infinity) + .opacity(viewModel.currentWaveform.isEmpty && viewModel.generationState != .generating ? 
1 : 0) + + // Generating progress + VStack(spacing: 8) { + if viewModel.totalSteps > 0 { + ProgressView( + value: Double(viewModel.stepsCompleted), + total: Double(viewModel.totalSteps) + ) + .progressViewStyle(.linear) + .frame(maxWidth: 200) + + let pct = Int(Double(viewModel.stepsCompleted) / Double(viewModel.totalSteps) * 100) + let detail = viewModel.chunksTotal > 1 ? " (\(viewModel.chunksTotal) chunks)" : "" + Text("Generating... \(pct)%\(detail)") + .font(.subheadline) + .foregroundStyle(.tertiary) + .contentTransition(.numericText()) + } else { + ProgressView() + .controlSize(.small) + Text(viewModel.statusMessage) + .font(.subheadline) + .foregroundStyle(.tertiary) + } + } + .frame(maxWidth: .infinity, maxHeight: .infinity) + .opacity(viewModel.currentWaveform.isEmpty && viewModel.generationState == .generating ? 1 : 0) + + // Live / completed waveform + WaveformView( + samples: viewModel.currentWaveform, + playbackTime: viewModel.playbackTime, + totalDuration: viewModel.currentDuration + ) + .padding(.horizontal, 0) + .opacity(viewModel.currentWaveform.isEmpty ? 0 : 1) + } + .frame(maxWidth: .infinity, maxHeight: .infinity) + + // Time labels + playback controls + HStack { + HStack(spacing: 4) { + Text(formatTime(viewModel.playbackTime)) + .font(.caption.monospacedDigit()) + .foregroundStyle(.secondary) + + let bufferRemaining = viewModel.silentBufferRemaining + if bufferRemaining > 0.4 { + Text("(\(formatTime(bufferRemaining)))") + .font(.caption.monospacedDigit()) + .foregroundStyle(.tertiary) + } + } + + Spacer() + + playbackControls + + Spacer() + + Text(formatTime(viewModel.currentDuration)) + .font(.caption.monospacedDigit()) + .foregroundStyle(.secondary) + } + .padding(.horizontal, 24) + + Spacer(minLength: 8) + } + .background(.quaternary.opacity(0.3)) + } + + private var playbackControls: some View { + let hasAudio = viewModel.selectedGeneration?.audioFileName != nil + let icon = viewModel.isPlaying ? 
"stop.fill" : "play.fill" + return Button { + if viewModel.isPlaying { + viewModel.stopPlayback() + } else if let gen = viewModel.selectedGeneration { + viewModel.playGeneration(gen) + } + } label: { + Image(systemName: icon) + .font(.title2) + } + .buttonStyle(.plain) + .keyboardShortcut(.space, modifiers: []) + .disabled(!hasAudio || viewModel.generationState == .generating) + .opacity(hasAudio ? 1 : 0) + .accessibilityLabel(viewModel.isPlaying ? "Stop playback" : "Play audio") + .accessibilityHint(viewModel.isPlaying ? "Stops the current audio" : "Plays the selected generation") + } + + // MARK: - Text Input + + @ViewBuilder + private var textInputSection: some View { + @Bindable var vm = viewModel + + VStack(alignment: .leading, spacing: 8) { + let promptForInput = viewModel.modelState == .loaded + && viewModel.inputText.isEmpty + && viewModel.generationState == .idle + + ZStack(alignment: .topLeading) { + TextEditor(text: $vm.inputText) + .font(.body) + .scrollContentBackground(.hidden) + .padding(.horizontal, 8) + .padding(.vertical, 8) + .background(.quaternary.opacity(0.5), in: RoundedRectangle(cornerRadius: 8)) + .overlay { + RoundedRectangle(cornerRadius: 8) + .stroke(Color.accentColor, lineWidth: promptForInput ? 
1.5 : 0) + .animation(.easeInOut(duration: 0.25), value: promptForInput) + } + .frame(minHeight: 80, maxHeight: 160) +// .overlay(alignment: .topTrailing) { +// Button { +// isImportingFile = true +// } label: { +// Image(systemName: "square.and.arrow.down") +// .font(.caption) +// .padding(6) +// .background(.thinMaterial, in: RoundedRectangle(cornerRadius: 6)) +// } +// .buttonStyle(.plain) +// .padding(6) +// .accessibilityLabel("Import text file") +// .accessibilityHint("Opens a file picker to import a .txt or .pdf file into the text editor") +// } + + if vm.inputText.isEmpty { + Text("Enter text to speak...") + .foregroundStyle(.tertiary) + .padding(.horizontal, 12) + .padding(.vertical, 8) + .allowsHitTesting(false) + } + } + .fileImporter( + isPresented: $isImportingFile, + allowedContentTypes: [.plainText, .pdf], + allowsMultipleSelection: false + ) { result in + switch result { + case .success(let urls): + guard let url = urls.first else { return } + // Security-scoped access is required for files picked outside the sandbox + let accessing = url.startAccessingSecurityScopedResource() + defer { if accessing { url.stopAccessingSecurityScopedResource() } } + if let text = FileUtilities.readTextContent(at: url) { + vm.inputText = text + } else { + fileImportError = "Could not read text from \"\(url.lastPathComponent)\"." + } + case .failure(let error): + fileImportError = error.localizedDescription + } + } + .alert("Import Failed", isPresented: .init( + get: { fileImportError != nil }, + set: { if !$0 { fileImportError = nil } } + )) { + Button("OK", role: .cancel) { fileImportError = nil } + } message: { + Text(fileImportError ?? 
"") + } + + HStack(spacing: 8) { + Picker("Voice", selection: $vm.selectedSpeaker) { + ForEach(Qwen3Speaker.allCases, id: \.self) { speaker in + Text("\(speaker.displayName) · \(speaker.nativeLanguage)").tag(speaker) + } + } + .fixedSize() + .accessibilityLabel("Voice") + .accessibilityHint("Choose the speaker voice for synthesis") + + Picker("Language", selection: $vm.selectedLanguage) { + ForEach(Qwen3Language.allCases, id: \.self) { lang in + Text(lang.rawValue.capitalized).tag(lang) + } + } + .fixedSize() + .accessibilityLabel("Language") + .accessibilityHint("Choose the output language for the voice") + + Picker("Playback", selection: $vm.playbackStrategyTag) { + Text("Auto").tag("auto") + Text("Stream").tag("stream") + Text("Generate First").tag("generateFirst") + } + .fixedSize() + .accessibilityLabel("Playback strategy") + .accessibilityHint("Auto buffers adaptively; Stream plays frame-by-frame; Generate First waits for the full audio") + + Spacer(minLength: 0) + + Text(vm.inputTokenCount.map { "\($0) tok" } ?? "~\(vm.inputText.unicodeScalars.count / 5) tok") + .font(.caption) + .foregroundStyle(.tertiary) + .lineLimit(1) + } + + // Selected speaker description + Text(vm.selectedSpeaker.voiceDescription) + .font(.caption) + .foregroundStyle(.secondary) + .frame(maxWidth: .infinity, alignment: .leading) + .transition(.opacity) + .animation(.easeInOut(duration: 0.2), value: vm.selectedSpeaker) + + // Instruction / style prompt (1.7B only) + let instructionSupported = vm.selectedPreset.supportsVoiceDirection + if instructionSupported { + HStack(spacing: 6) { + Image(systemName: "theatermasks") + .font(.caption) + .foregroundStyle(.secondary) + TextField( + "Style instruction (e.g. 
cheerful and energetic)...", + text: $vm.instruction + ) + .font(.callout) + .textFieldStyle(.plain) + } + .padding(.horizontal, 10) + .padding(.vertical, 7) + .background(.quaternary.opacity(0.3), in: RoundedRectangle(cornerRadius: 8)) + .transition(.opacity.combined(with: .move(edge: .top))) + .animation(.easeInOut(duration: 0.2), value: instructionSupported) + } + } + } + + // MARK: - Metrics Row + + private var metricsRow: some View { + let hasMetrics = viewModel.currentRTF > 0 + return HStack(spacing: 0) { + metricItem( + value: hasMetrics ? String(format: "%.1f×", viewModel.currentSpeedFactor) : "-", + label: "Speed Factor" + ) + Divider().frame(height: 24) + metricItem( + value: hasMetrics ? String(format: "%.0f", viewModel.currentStepsPerSecond) : "-", + label: "steps/s" + ) + Divider().frame(height: 24) + metricItem( + value: hasMetrics ? String(format: "%.2fs", viewModel.currentTimeToFirstBuffer) : "-", + label: "First Buffer" + ) + } + .frame(maxWidth: .infinity) + .padding(.vertical, 6) + .background(.quaternary.opacity(0.3), in: RoundedRectangle(cornerRadius: 8)) + .opacity(hasMetrics ? 1 : 0.4) + .animation(.easeInOut(duration: 0.2), value: hasMetrics) + } + + private func metricItem(value: String, label: String) -> some View { + VStack(spacing: 2) { + Text(value) + .font(.system(.body, design: .monospaced)) + .lineLimit(1) + Text(label) + .font(.caption2) + .foregroundStyle(.secondary) + } + .frame(maxWidth: .infinity) + } + + // MARK: - Controls Bar + + @ViewBuilder + private var controlsBar: some View { + #if os(iOS) + // On iOS: generate button spans full width; secondary controls on the row below. + // This prevents the button label from being clipped on narrow screens. 
+ VStack(spacing: 10) { + Button { + if viewModel.generationState == .idle { + viewModel.startGeneration() + } else { + viewModel.cancelGeneration() + } + } label: { + Label(generateButtonTitle, systemImage: generateButtonIcon) + .frame(maxWidth: .infinity) + } + .glassButton(prominent: true) + .tint(viewModel.generationState == .idle ? .accentColor : .red) + .controlSize(.large) + .disabled(!viewModel.canGenerate && viewModel.generationState == .idle) + .accessibilityLabel(viewModel.generationState == .idle ? generateButtonTitle : "Cancel generation") + .accessibilityHint(viewModel.generationState == .idle + ? "Synthesizes the entered text using the loaded model" + : "Stops the current generation immediately") + + HStack { + Button { + viewModel.clearInput() + } label: { + Label("Clear", systemImage: "xmark.circle") + } + .glassButton() + .controlSize(.regular) + .disabled(viewModel.inputText.isEmpty) + .opacity(viewModel.generationState == .idle ? 1 : 0) + .accessibilityLabel("Clear input") + .accessibilityHint("Clears the text input and resets the waveform") + + Spacer() + + Text(viewModel.statusMessage) + .font(.caption) + .foregroundStyle(.secondary) + .lineLimit(2) + .multilineTextAlignment(.trailing) + } + } + #else + HStack(spacing: 12) { + // Generate / Cancel button + Button { + if viewModel.generationState == .idle { + viewModel.startGeneration() + } else { + viewModel.cancelGeneration() + } + } label: { + Label( + generateButtonTitle, + systemImage: generateButtonIcon + ) + .frame(maxWidth: 200) + } + .glassButton(prominent: true) + .tint(viewModel.generationState == .idle ? .accentColor : .red) + .controlSize(.large) + .disabled(!viewModel.canGenerate && viewModel.generationState == .idle) + .accessibilityLabel(viewModel.generationState == .idle ? generateButtonTitle : "Cancel generation") + .accessibilityHint(viewModel.generationState == .idle + ? 
"Synthesizes the entered text using the loaded model" + : "Stops the current generation immediately") + + Button { + viewModel.clearInput() + } label: { + Label("Clear", systemImage: "xmark.circle") + } + .glassButton() + .controlSize(.large) + .disabled(viewModel.inputText.isEmpty) + .opacity(viewModel.generationState == .idle ? 1 : 0) + .accessibilityLabel("Clear input") + .accessibilityHint("Clears the text input and resets the waveform") + + Spacer() + + // Status + Text(viewModel.statusMessage) + .font(.caption) + .foregroundStyle(.secondary) + .lineLimit(1) + .truncationMode(.tail) + } + #endif + } + + // MARK: - Toolbar + + @ToolbarContentBuilder + private var toolbarContent: some ToolbarContent { + #if os(macOS) + ToolbarItem { + Button(action: newGeneration) { + Label("New Generation", systemImage: "plus") + } + .keyboardShortcut("n") + } + #endif + + ToolbarItem(placement: .primaryAction) { + Button { + viewModel.showGenerationSettings = true + } label: { + Label("Generation Settings", systemImage: "slider.horizontal.3") + } + } + + // Share / export audio file + if let gen = viewModel.selectedGeneration, + let url = viewModel.audioFileURL(for: gen) + { + ToolbarItem { + AudioCopyButton(url: url) + } + } + } + + private func newGeneration() { + viewModel.clearInput() + viewModel.currentWaveform = [] + viewModel.selectedGenerationID = ViewModel.newGenerationSentinel + } + + // MARK: - Helpers + + private var generateButtonTitle: String { + switch viewModel.generationState { + case .generating: return "Cancel" + case .idle: + switch viewModel.modelState { + case .downloading: return "Downloading..." + case .prewarming: return "Specializing..." + case .loading: return "Loading..." + default: return viewModel.selectedGeneration != nil ? 
"Regenerate" : "Generate" + } + } + } + + private var generateButtonIcon: String { + switch viewModel.generationState { + case .generating: return "stop.fill" + case .idle: return "play.fill" + } + } + + private func formatTime(_ seconds: TimeInterval) -> String { + let m = Int(seconds) / 60 + let s = Int(seconds) % 60 + return String(format: "%d:%02d", m, s) + } +} + +// MARK: - Audio file transferable + +// MARK: - Copy to clipboard button + +/// Copies the audio file to the system clipboard so it can be pasted into Mail, Finder, etc. +/// Uses `NSPasteboard` on macOS and `UIPasteboard` on iOS - both accept a file URL directly, +/// letting the receiving app decide how to handle the file. +struct AudioCopyButton: View { + let url: URL + @State private var copied = false + + var body: some View { + Button { + copyToClipboard() + copied = true + DispatchQueue.main.asyncAfter(deadline: .now() + 1.5) { copied = false } + } label: { + Label(copied ? "Copied!" : "Copy Audio", systemImage: copied ? "checkmark" : "doc.on.doc") + } + .accessibilityLabel(copied ? "Copied to clipboard" : "Copy audio file") + .accessibilityHint("Copies the audio file to the clipboard so you can paste it into other apps") + } + + private func copyToClipboard() { + #if os(macOS) + NSPasteboard.general.clearContents() + NSPasteboard.general.writeObjects([url as NSURL]) + #else + guard let data = try? Data(contentsOf: url) else { return } + let uti = UTType(filenameExtension: url.pathExtension) ?? .audio + UIPasteboard.general.setData(data, forPasteboardType: uti.identifier) + #endif + } +} + +// MARK: - Glass Button Style + +/// Applies `.glass` on iOS/macOS 26+ and falls back to `.borderedProminent` / +/// `.bordered` on older OS versions. 
+struct GlassButtonModifier: ViewModifier { + var prominent: Bool = false + + func body(content: Content) -> some View { + if #available(iOS 26, macOS 26, watchOS 26, visionOS 26, *) { + content + .buttonStyle(.glass) + } else if prominent { + content + .buttonStyle(.borderedProminent) + } else { + content + .buttonStyle(.bordered) + } + } +} + +extension View { + func glassButton(prominent: Bool = false) -> some View { + modifier(GlassButtonModifier(prominent: prominent)) + } +} diff --git a/Examples/TTS/SpeakAX/SpeakAX/GenerationSettingsView.swift b/Examples/TTS/SpeakAX/SpeakAX/GenerationSettingsView.swift new file mode 100644 index 00000000..3c73c5b8 --- /dev/null +++ b/Examples/TTS/SpeakAX/SpeakAX/GenerationSettingsView.swift @@ -0,0 +1,236 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright © 2026 Argmax, Inc. All rights reserved. + +import SwiftUI +import TTSKit + +/// Sheet for configuring GenerationOptions. +struct GenerationSettingsView: View { + @Environment(ViewModel.self) private var viewModel + @Environment(\.dismiss) private var dismiss + + var body: some View { + #if os(iOS) + NavigationStack { + settingsForm + } + #else + VStack(spacing: 0) { + // Explicit header with title and close button - toolbar items + // don't render reliably inside macOS sheet NavigationStacks. + HStack { + Text("Generation Options") + .font(.title2) + .fontWeight(.semibold) + Spacer() + Button { + dismiss() + } label: { + Image(systemName: "xmark.circle.fill") + .font(.title2) + .symbolRenderingMode(.hierarchical) + .foregroundStyle(.secondary) + } + .buttonStyle(.plain) + } + .padding(.horizontal) + .padding(.vertical, 12) + + Divider() + + settingsForm + .frame(minWidth: 480, minHeight: 500) + } + #endif + } + + private var settingsForm: some View { + @Bindable var vm = viewModel + + return Form { + // MARK: Sampling + + Section("Sampling") { + sliderRow( + label: "Temperature", + info: "Controls the randomness of token selection. 
Higher values produce more varied speech; lower values are more deterministic.", + value: $vm.temperature, + range: 0...1, + step: 0.05, + display: String(format: "%.2f", vm.temperature) + ) + + sliderRow( + label: "Top-K", + info: "Limits token selection to the top-K most probable tokens at each step. Lower values make output more focused.", + value: $vm.topK, + range: 1...100, + step: 1, + display: Int(vm.topK).formatted() + ) + + sliderRow( + label: "Repetition Penalty", + info: "Penalises repeating the same tokens. Values above 1.0 discourage repetition; 1.0 disables the penalty.", + value: $vm.repetitionPenalty, + range: 1.0...1.5, + step: 0.01, + display: String(format: "%.2f", vm.repetitionPenalty) + ) + + sliderRow( + label: "Max New Tokens", + info: "Upper bound on tokens generated per chunk. Longer values allow longer output but increase max generation time.", + value: $vm.maxNewTokens, + range: 50...500, + step: 10, + display: Int(vm.maxNewTokens).formatted() + ) + } + + // MARK: Chunking + + Section("Chunking") { + // Strategy row: label + info on the left, segmented picker on the right. + // Using LabeledContent avoids macOS Form's two-column layout pushing + // a plain Spacer-separated HStack to extreme edges. + LabeledContent { + Picker("", selection: $vm.chunkingStrategyTag) { + Text("None").tag("none") + Text("Sentence").tag("sentence") + } + .pickerStyle(.segmented) + .frame(width: 160) + } label: { + HStack(spacing: 4) { + Text("Strategy") + InfoButton("Controls how long text is split before generation. Sentence-based chunking produces more natural prosody at chunk boundaries.") + } + } + + sliderRow( + label: "Target Chunk Size", + info: "Maximum tokens per chunk when using sentence chunking. Chunks break at the nearest sentence boundary at or before this token count. 
Default (50).", + value: $vm.targetChunkSize, + range: 10...100, + step: 1, + display: "\(Int(vm.targetChunkSize)) tok", + displayWidth: 60 + ) + .disabled(vm.chunkingStrategyTag == "none") + + sliderRow( + label: "Min Chunk Size", + info: "Minimum tokens per chunk. Short trailing segments are merged into the previous chunk to avoid tiny segments with poor prosody.", + value: $vm.minChunkSize, + range: 1...30, + step: 1, + display: "\(Int(vm.minChunkSize)) tok", + displayWidth: 60 + ) + .disabled(vm.chunkingStrategyTag == "none") + } + + // MARK: Concurrency + + Section("Concurrency") { + sliderRow( + label: "Concurrent Workers", + info: "How many chunks to generate in parallel. 0 = all chunks at once (fastest). 1 = sequential (best for streaming playback).", + value: $vm.concurrentWorkerCount, + range: 0...16, + step: 1, + display: vm.concurrentWorkerCount == 0 ? "Max" : Int(vm.concurrentWorkerCount).formatted() + ) + } + + // MARK: Reset + + Section { + Button(role: .destructive) { + viewModel.resetGenerationSettings() + } label: { + HStack { + Spacer() + Text("Reset to Defaults") + Spacer() + } + } + } + } + // Use grouped style on both platforms so macOS renders as a scrollable list + // rather than its default two-column layout, which misaligns our custom rows. + .formStyle(.grouped) + #if os(iOS) + .navigationTitle("Generation Options") + .navigationBarTitleDisplayMode(.inline) + .toolbar { + ToolbarItem(placement: .primaryAction) { + Button { + dismiss() + } label: { + Label("Done", systemImage: "xmark.circle.fill") + .foregroundStyle(.primary) + } + } + } + #endif + } + + // MARK: - Helpers + + /// Reusable label + info + value display + slider row. 
+ private func sliderRow( + label: String, + info: String, + value: Binding, + range: ClosedRange, + step: Double, + display: String, + displayWidth: CGFloat = 44 + ) -> some View { + VStack(alignment: .leading, spacing: 4) { + HStack { + Text(label) + InfoButton(info) + Spacer() + Text(display) + .monospacedDigit() + .foregroundStyle(.secondary) + .frame(width: displayWidth, alignment: .trailing) + } + Slider(value: value, in: range, step: step) + } + .padding(.vertical, 2) + } +} + +// MARK: - Info Button + +/// Small info button that shows a popover with explanatory text. +/// Matches the InfoButton pattern from WhisperAX. +struct InfoButton: View { + let text: String + @State private var isShowing = false + + init(_ text: String) { + self.text = text + } + + var body: some View { + Button { + isShowing = true + } label: { + Image(systemName: "info.circle") + .foregroundStyle(.blue) + } + .buttonStyle(.borderless) + .popover(isPresented: $isShowing) { + Text(text) + .multilineTextAlignment(.leading) + .fixedSize(horizontal: false, vertical: true) + .padding() + .frame(width: 260) + } + } +} diff --git a/Examples/TTS/SpeakAX/SpeakAX/ModelManagementView.swift b/Examples/TTS/SpeakAX/SpeakAX/ModelManagementView.swift new file mode 100644 index 00000000..21f7fe0c --- /dev/null +++ b/Examples/TTS/SpeakAX/SpeakAX/ModelManagementView.swift @@ -0,0 +1,190 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright © 2026 Argmax, Inc. All rights reserved. + +import SwiftUI +import TTSKit + +/// Sidebar section for model selection, download, load/unload, and deletion. +/// Mirrors the model management pattern from WhisperAX / Argmax Playground. 
+struct ModelManagementView: View { + @Environment(ViewModel.self) private var viewModel + + var body: some View { + @Bindable var vm = viewModel + + VStack(alignment: .leading, spacing: 8) { + // Status row: dot + state label + accessory buttons + HStack(spacing: 6) { + Image(systemName: "circle.fill") + .font(.system(size: 8)) + .foregroundStyle(vm.modelState.color) + .symbolEffect(.variableColor, isActive: vm.modelState.isBusy) + + Text(vm.modelState.label) + .font(.caption) + .foregroundStyle(.secondary) + + Spacer() + + #if os(macOS) + // On macOS the picker fits in the status row alongside the accessory buttons + modelPicker + #endif + + accessoryButtons + } + + #if os(iOS) + // On iOS the picker gets its own row so it isn't squeezed by the accessory buttons + modelPicker + .frame(maxWidth: .infinity, alignment: .leading) + #endif + + // Size estimate + Text(viewModel.modelDiskSize(for: vm.selectedPreset) ?? vm.selectedPreset.sizeEstimate) + .font(.caption2) + .foregroundStyle(.tertiary) + + // Action slot: button OR progress - same height, no layout shift + if vm.modelState == .loaded && vm.loadedPreset == vm.selectedPreset { + Button { viewModel.unloadModel() } label: { + Label("Unload Model", systemImage: "eject") + .frame(maxWidth: .infinity) + .frame(height: 28) + } + .glassButton() + .controlSize(.small) + .accessibilityLabel("Unload model") + .accessibilityHint("Releases the model from memory. You can reload it later.") + } else if vm.modelState == .downloading { + VStack(spacing: 2) { + HStack { + ProgressView(value: vm.downloadProgress) + .progressViewStyle(.linear) + Text(String(format: "%.1f%%", vm.downloadProgress * 100)) + .font(.caption) + .foregroundStyle(.secondary) + .monospacedDigit() + } + } + .frame(height: 28) + } else if vm.modelState == .prewarming || vm.modelState == .loading { + VStack(alignment: .leading, spacing: 2) { + ProgressView() + .progressViewStyle(.linear) + Text(vm.modelState == .prewarming + ? 
"Specializing for your device..." + : "Loading...") + .font(.caption2) + .foregroundStyle(.secondary) + } + .frame(height: 28) + } else { + Button { Task { await viewModel.loadModel() } } label: { + Label( + vm.isModelDownloaded ? "Load Model" : "Download & Load", + systemImage: vm.isModelDownloaded ? "arrow.up.circle" : "arrow.down.circle" + ) + .frame(maxWidth: .infinity) + .frame(height: 28) + } + .glassButton(prominent: true) + .tint(.accentColor) + .controlSize(.small) + .accessibilityLabel(vm.isModelDownloaded ? "Load model" : "Download and load model") + .accessibilityHint(vm.isModelDownloaded + ? "Loads the downloaded model into memory" + : "Downloads the model (~\(vm.selectedPreset.sizeEstimate)) and loads it") + } + } + .padding(.horizontal, 12) + .padding(.vertical, 8) + } + + // MARK: - Model Picker + + private var modelPicker: some View { + @Bindable var vm = viewModel + return Picker("Model", selection: $vm.selectedPreset) { + ForEach(TTSModelVariant.allCases, id: \.self) { preset in + HStack { + let downloaded = vm.localModelPaths[preset] != nil + Image(systemName: downloaded ? "checkmark.circle" : "arrow.down.circle.dotted") + if preset.isAvailableOnCurrentPlatform { + Text(preset.displayName) + } else { + Text("\(preset.displayName) - Mac only") + .foregroundStyle(.tertiary) + } + } + .tag(preset) + } + } + .labelsHidden() + .pickerStyle(.menu) + .accessibilityLabel("Select model") + .accessibilityHint("Choose the TTS model to use for generation") + .onChange(of: vm.selectedPreset) { + if !vm.selectedPreset.isAvailableOnCurrentPlatform { + vm.selectedPreset = .defaultForCurrentPlatform + vm.statusMessage = "\(vm.selectedPreset.displayName) requires macOS" + return + } + if vm.loadedPreset == vm.selectedPreset { return } + if vm.localModelPaths[vm.selectedPreset] != nil { + vm.statusMessage = vm.modelState == .loaded + ? 
"Different model loaded" + : "Downloaded" + } else { + vm.statusMessage = "Not downloaded" + } + } + } + + // MARK: - Accessory Buttons (trash, reveal, HF link) + + @ViewBuilder + private var accessoryButtons: some View { + let vm = viewModel + + // Delete + Button { + viewModel.deleteModel() + } label: { + Image(systemName: "trash") + } + .buttonStyle(.borderless) + .disabled(!vm.isModelDownloaded || vm.modelState.isBusy) + .help("Delete downloaded model files") + .accessibilityLabel("Delete model") + .accessibilityHint("Permanently deletes the downloaded model files from disk") + + #if os(macOS) + // Reveal in Finder + if let path = vm.localModelPaths[vm.selectedPreset] { + Button { + NSWorkspace.shared.selectFile(nil, inFileViewerRootedAtPath: path) + } label: { + Image(systemName: "folder") + } + .buttonStyle(.borderless) + .help("Reveal in Finder") + } + #endif + + // Open on HuggingFace + Button { + if let url = viewModel.modelRepoURL { + #if os(macOS) + NSWorkspace.shared.open(url) + #else + UIApplication.shared.open(url) + #endif + } + } label: { + Image(systemName: "link.circle") + } + .buttonStyle(.borderless) + .help("Open model on HuggingFace") + } +} diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/AppIcon.icon/Assets/1024.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/AppIcon.icon/Assets/1024.png new file mode 100644 index 00000000..55f3dedf Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/AppIcon.icon/Assets/1024.png differ diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/AppIcon.icon/icon.json b/Examples/TTS/SpeakAX/SpeakAX/Resources/AppIcon.icon/icon.json new file mode 100644 index 00000000..621c8067 --- /dev/null +++ b/Examples/TTS/SpeakAX/SpeakAX/Resources/AppIcon.icon/icon.json @@ -0,0 +1,37 @@ +{ + "fill-specializations" : [ + { + "value" : "automatic" + }, + { + "appearance" : "dark", + "value" : "automatic" + } + ], + "groups" : [ + { + "layers" : [ + { + "glass" : false, + "hidden" : false, + "image-name" : "1024.png", 
+ "name" : "1024" + } + ], + "shadow" : { + "kind" : "neutral", + "opacity" : 0.5 + }, + "translucency" : { + "enabled" : true, + "value" : 0.5 + } + } + ], + "supported-platforms" : { + "circles" : [ + "watchOS" + ], + "squares" : "shared" + } +} \ No newline at end of file diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/100.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/100.png new file mode 100644 index 00000000..95cb96dd Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/100.png differ diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/102.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/102.png new file mode 100644 index 00000000..e81683a2 Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/102.png differ diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/1024 1.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/1024 1.png new file mode 100644 index 00000000..55f3dedf Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/1024 1.png differ diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/1024 2.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/1024 2.png new file mode 100644 index 00000000..55f3dedf Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/1024 2.png differ diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/1024.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/1024.png new file mode 100644 index 00000000..a084825d Binary files /dev/null and 
b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/1024.png differ diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/108.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/108.png new file mode 100644 index 00000000..275e890f Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/108.png differ diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/114.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/114.png new file mode 100644 index 00000000..0cc74726 Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/114.png differ diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/120 1.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/120 1.png new file mode 100644 index 00000000..4ddd0681 Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/120 1.png differ diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/120.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/120.png new file mode 100644 index 00000000..4ddd0681 Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/120.png differ diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/128 1.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/128 1.png new file mode 100644 index 00000000..f4e54b88 Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/128 1.png differ diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/128.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/128.png 
new file mode 100644 index 00000000..7a629226 Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/128.png differ diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/136.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/136.png new file mode 100644 index 00000000..1952cada Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/136.png differ diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/152.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/152.png new file mode 100644 index 00000000..4b3b3e42 Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/152.png differ diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/16.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/16.png new file mode 100644 index 00000000..bc8da580 Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/16.png differ diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/167.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/167.png new file mode 100644 index 00000000..b3ffac1b Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/167.png differ diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/172.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/172.png new file mode 100644 index 00000000..c4639557 Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/172.png differ diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/180.png 
b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/180.png new file mode 100644 index 00000000..81467c9b Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/180.png differ diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/192.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/192.png new file mode 100644 index 00000000..5813d1a9 Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/192.png differ diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/196.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/196.png new file mode 100644 index 00000000..469fde22 Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/196.png differ diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/216.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/216.png new file mode 100644 index 00000000..a5a87744 Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/216.png differ diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/234.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/234.png new file mode 100644 index 00000000..24d36305 Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/234.png differ diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/256.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/256.png new file mode 100644 index 00000000..4c18e6de Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/256.png differ diff --git 
a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/258.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/258.png new file mode 100644 index 00000000..023c9c32 Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/258.png differ diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/32.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/32.png new file mode 100644 index 00000000..a39f420b Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/32.png differ diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/40.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/40.png new file mode 100644 index 00000000..dad76cac Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/40.png differ diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/44.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/44.png new file mode 100644 index 00000000..46be1514 Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/44.png differ diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/48.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/48.png new file mode 100644 index 00000000..ade9ec3b Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/48.png differ diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/512.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/512.png new file mode 100644 index 00000000..54ccc013 Binary files /dev/null and 
b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/512.png differ diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/55.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/55.png new file mode 100644 index 00000000..3b67b790 Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/55.png differ diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/58 1.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/58 1.png new file mode 100644 index 00000000..5fa76358 Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/58 1.png differ diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/58.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/58.png new file mode 100644 index 00000000..5fa76358 Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/58.png differ diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/60 1.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/60 1.png new file mode 100644 index 00000000..9fc760d4 Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/60 1.png differ diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/60.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/60.png new file mode 100644 index 00000000..a27515cf Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/60.png differ diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/64 1.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/64 1.png new file mode 
100644 index 00000000..7cff860d Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/64 1.png differ diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/64.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/64.png new file mode 100644 index 00000000..e23c66d7 Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/64.png differ diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/66.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/66.png new file mode 100644 index 00000000..a4e383f3 Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/66.png differ diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/76.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/76.png new file mode 100644 index 00000000..1a1259ee Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/76.png differ diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/80 1.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/80 1.png new file mode 100644 index 00000000..210fd08b Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/80 1.png differ diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/80.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/80.png new file mode 100644 index 00000000..210fd08b Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/80.png differ diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/87 1.png 
b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/87 1.png new file mode 100644 index 00000000..e3744a97 Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/87 1.png differ diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/87.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/87.png new file mode 100644 index 00000000..e3744a97 Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/87.png differ diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/88.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/88.png new file mode 100644 index 00000000..262c010d Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/88.png differ diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/92.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/92.png new file mode 100644 index 00000000..8952531b Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/92.png differ diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/Contents.json b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/Contents.json new file mode 100644 index 00000000..d2c94f1c --- /dev/null +++ b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/AppIcon.appiconset/Contents.json @@ -0,0 +1,318 @@ +{ + "images" : [ + { + "filename" : "40.png", + "idiom" : "universal", + "platform" : "ios", + "scale" : "2x", + "size" : "20x20" + }, + { + "filename" : "60.png", + "idiom" : "universal", + "platform" : "ios", + "scale" : "3x", + "size" : "20x20" + }, + { + "filename" : "58 1.png", + "idiom" : "universal", + "platform" : "ios", + "scale" : "2x", + "size" : 
"29x29" + }, + { + "filename" : "87 1.png", + "idiom" : "universal", + "platform" : "ios", + "scale" : "3x", + "size" : "29x29" + }, + { + "filename" : "76.png", + "idiom" : "universal", + "platform" : "ios", + "scale" : "2x", + "size" : "38x38" + }, + { + "filename" : "114.png", + "idiom" : "universal", + "platform" : "ios", + "scale" : "3x", + "size" : "38x38" + }, + { + "filename" : "80 1.png", + "idiom" : "universal", + "platform" : "ios", + "scale" : "2x", + "size" : "40x40" + }, + { + "filename" : "120.png", + "idiom" : "universal", + "platform" : "ios", + "scale" : "3x", + "size" : "40x40" + }, + { + "filename" : "120 1.png", + "idiom" : "universal", + "platform" : "ios", + "scale" : "2x", + "size" : "60x60" + }, + { + "filename" : "180.png", + "idiom" : "universal", + "platform" : "ios", + "scale" : "3x", + "size" : "60x60" + }, + { + "filename" : "128 1.png", + "idiom" : "universal", + "platform" : "ios", + "scale" : "2x", + "size" : "64x64" + }, + { + "filename" : "192.png", + "idiom" : "universal", + "platform" : "ios", + "scale" : "3x", + "size" : "64x64" + }, + { + "filename" : "136.png", + "idiom" : "universal", + "platform" : "ios", + "scale" : "2x", + "size" : "68x68" + }, + { + "filename" : "152.png", + "idiom" : "universal", + "platform" : "ios", + "scale" : "2x", + "size" : "76x76" + }, + { + "filename" : "167.png", + "idiom" : "universal", + "platform" : "ios", + "scale" : "2x", + "size" : "83.5x83.5" + }, + { + "filename" : "1024 1.png", + "idiom" : "universal", + "platform" : "ios", + "size" : "1024x1024" + }, + { + "filename" : "16.png", + "idiom" : "mac", + "scale" : "1x", + "size" : "16x16" + }, + { + "filename" : "32.png", + "idiom" : "mac", + "scale" : "2x", + "size" : "16x16" + }, + { + "filename" : "32.png", + "idiom" : "mac", + "scale" : "1x", + "size" : "32x32" + }, + { + "filename" : "64.png", + "idiom" : "mac", + "scale" : "2x", + "size" : "32x32" + }, + { + "filename" : "128.png", + "idiom" : "mac", + "scale" : "1x", + "size" : 
"128x128" + }, + { + "filename" : "256.png", + "idiom" : "mac", + "scale" : "2x", + "size" : "128x128" + }, + { + "filename" : "256.png", + "idiom" : "mac", + "scale" : "1x", + "size" : "256x256" + }, + { + "filename" : "512.png", + "idiom" : "mac", + "scale" : "2x", + "size" : "256x256" + }, + { + "filename" : "512.png", + "idiom" : "mac", + "scale" : "1x", + "size" : "512x512" + }, + { + "filename" : "1024.png", + "idiom" : "mac", + "scale" : "2x", + "size" : "512x512" + }, + { + "filename" : "44.png", + "idiom" : "universal", + "platform" : "watchos", + "scale" : "2x", + "size" : "22x22" + }, + { + "filename" : "48.png", + "idiom" : "universal", + "platform" : "watchos", + "scale" : "2x", + "size" : "24x24" + }, + { + "filename" : "55.png", + "idiom" : "universal", + "platform" : "watchos", + "scale" : "2x", + "size" : "27.5x27.5" + }, + { + "filename" : "58.png", + "idiom" : "universal", + "platform" : "watchos", + "scale" : "2x", + "size" : "29x29" + }, + { + "filename" : "60 1.png", + "idiom" : "universal", + "platform" : "watchos", + "scale" : "2x", + "size" : "30x30" + }, + { + "filename" : "64 1.png", + "idiom" : "universal", + "platform" : "watchos", + "scale" : "2x", + "size" : "32x32" + }, + { + "filename" : "66.png", + "idiom" : "universal", + "platform" : "watchos", + "scale" : "2x", + "size" : "33x33" + }, + { + "filename" : "80.png", + "idiom" : "universal", + "platform" : "watchos", + "scale" : "2x", + "size" : "40x40" + }, + { + "filename" : "87.png", + "idiom" : "universal", + "platform" : "watchos", + "scale" : "2x", + "size" : "43.5x43.5" + }, + { + "filename" : "88.png", + "idiom" : "universal", + "platform" : "watchos", + "scale" : "2x", + "size" : "44x44" + }, + { + "filename" : "92.png", + "idiom" : "universal", + "platform" : "watchos", + "scale" : "2x", + "size" : "46x46" + }, + { + "filename" : "100.png", + "idiom" : "universal", + "platform" : "watchos", + "scale" : "2x", + "size" : "50x50" + }, + { + "filename" : "102.png", + "idiom" : 
"universal", + "platform" : "watchos", + "scale" : "2x", + "size" : "51x51" + }, + { + "filename" : "108.png", + "idiom" : "universal", + "platform" : "watchos", + "scale" : "2x", + "size" : "54x54" + }, + { + "filename" : "172.png", + "idiom" : "universal", + "platform" : "watchos", + "scale" : "2x", + "size" : "86x86" + }, + { + "filename" : "196.png", + "idiom" : "universal", + "platform" : "watchos", + "scale" : "2x", + "size" : "98x98" + }, + { + "filename" : "216.png", + "idiom" : "universal", + "platform" : "watchos", + "scale" : "2x", + "size" : "108x108" + }, + { + "filename" : "234.png", + "idiom" : "universal", + "platform" : "watchos", + "scale" : "2x", + "size" : "117x117" + }, + { + "filename" : "258.png", + "idiom" : "universal", + "platform" : "watchos", + "scale" : "2x", + "size" : "129x129" + }, + { + "filename" : "1024 2.png", + "idiom" : "universal", + "platform" : "watchos", + "size" : "1024x1024" + } + ], + "info" : { + "author" : "xcode", + "version" : 1 + } +} diff --git a/Examples/WhisperAX/WhisperAXWatchApp/Assets.xcassets/AccentColor.colorset/Contents.json b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/Contents.json similarity index 51% rename from Examples/WhisperAX/WhisperAXWatchApp/Assets.xcassets/AccentColor.colorset/Contents.json rename to Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/Contents.json index eb878970..73c00596 100644 --- a/Examples/WhisperAX/WhisperAXWatchApp/Assets.xcassets/AccentColor.colorset/Contents.json +++ b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/Contents.json @@ -1,9 +1,4 @@ { - "colors" : [ - { - "idiom" : "universal" - } - ], "info" : { "author" : "xcode", "version" : 1 diff --git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/argmaxLogo.png b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/argmaxLogo.png new file mode 100644 index 00000000..023c9c32 Binary files /dev/null and b/Examples/TTS/SpeakAX/SpeakAX/Resources/Assets.xcassets/argmaxLogo.png differ diff 
--git a/Examples/TTS/SpeakAX/SpeakAX/Resources/Info.plist b/Examples/TTS/SpeakAX/SpeakAX/Resources/Info.plist new file mode 100644 index 00000000..0c67376e --- /dev/null +++ b/Examples/TTS/SpeakAX/SpeakAX/Resources/Info.plist @@ -0,0 +1,5 @@ + + + + + diff --git a/Examples/TTS/SpeakAX/SpeakAX/SidebarView.swift b/Examples/TTS/SpeakAX/SpeakAX/SidebarView.swift new file mode 100644 index 00000000..179ebd59 --- /dev/null +++ b/Examples/TTS/SpeakAX/SpeakAX/SidebarView.swift @@ -0,0 +1,225 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright © 2026 Argmax, Inc. All rights reserved. + +import SwiftUI + +struct SidebarView: View { + @Environment(ViewModel.self) private var viewModel + + var body: some View { + @Bindable var vm = viewModel + + VStack(spacing: 0) { + // Model management at the top of the sidebar + ModelManagementView() + + Divider() + + // Compute units configuration + ComputeUnitsView() + .disabled(vm.modelState.isBusy) + + Divider() + + // Generation history list + List(selection: $vm.selectedGenerationID) { + if !vm.favoriteGenerations.isEmpty { + Section("Favorites") { + ForEach(vm.favoriteGenerations) { gen in + GenerationRow(generation: gen) + .tag(gen.id) + .contextMenu { rowContextMenu(for: gen) } + #if os(iOS) + .swipeActions(edge: .trailing) { rowSwipeActions(for: gen) } + #endif + } + } + } + + Section("Recents") { + ForEach(vm.generations) { gen in + GenerationRow(generation: gen) + .tag(gen.id) + .contextMenu { rowContextMenu(for: gen) } + #if os(iOS) + .swipeActions(edge: .trailing) { rowSwipeActions(for: gen) } + #endif + } + } + } + .overlay { + if vm.generations.isEmpty { + ContentUnavailableView( + "No Generations Yet", + systemImage: "waveform", + description: Text("Generated speech will appear here.") + ) + } + } + + Divider() + + // App & device info footer + appInfoFooter + } + .navigationTitle("TTSKit Example") + #if os(macOS) + .navigationSplitViewColumnWidth(min: 260, ideal: 280, max: 400) + #elseif os(iOS) + .toolbar 
{ + ToolbarItemGroup(placement: .bottomBar) { + Spacer() + Button(action: newGeneration) { + Image(systemName: "plus") + .font(.title2) + .bold() + } + } + } + #endif + } + + // MARK: - New Generation + + private func newGeneration() { + viewModel.clearInput() + viewModel.currentWaveform = [] + // Use a sentinel UUID to show the detail in "new generation" mode. + // This drives navigation uniformly on all platforms via List(selection:). + viewModel.selectedGenerationID = ViewModel.newGenerationSentinel + } + + // MARK: - App Info Footer + + private var appInfoFooter: some View { + VStack(alignment: .leading, spacing: 4) { + let version = Bundle.main.infoDictionary?["CFBundleShortVersionString"] as? String ?? "Unknown" + let build = Bundle.main.infoDictionary?["CFBundleVersion"] as? String ?? "Unknown" + Text("App Version: \(version) (\(build))") + #if os(iOS) + Text("Device: \(UIDevice.current.model)") + Text("OS: \(UIDevice.current.systemVersion)") + #elseif os(macOS) + Text("OS: \(ProcessInfo.processInfo.operatingSystemVersionString)") + #endif + } + .font(.system(.caption2, design: .monospaced)) + .foregroundStyle(.tertiary) + .frame(maxWidth: .infinity, alignment: .leading) + .padding(.horizontal, 12) + .padding(.vertical, 8) + } + + // MARK: - Context Menu / Swipe Actions + + @ViewBuilder + private func rowContextMenu(for generation: Generation) -> some View { + Button { + viewModel.playGeneration(generation) + } label: { + Label("Play", systemImage: "play.fill") + } + + Button { + viewModel.toggleFavorite(generation.id) + } label: { + Label( + generation.isFavorite ? "Unfavorite" : "Favorite", + systemImage: generation.isFavorite ? 
"star.slash" : "star" + ) + } + + if let url = viewModel.audioFileURL(for: generation) { + AudioCopyButton(url: url) + + #if os(macOS) + Button { + NSWorkspace.shared.activateFileViewerSelecting([url]) + } label: { + Label("Show in Finder", systemImage: "folder") + } + #endif + } + + Divider() + + Button(role: .destructive) { + viewModel.deleteGeneration(generation.id) + } label: { + Label("Delete", systemImage: "trash") + } + } + + #if os(iOS) + @ViewBuilder + private func rowSwipeActions(for generation: Generation) -> some View { + Button(role: .destructive) { + viewModel.deleteGeneration(generation.id) + } label: { + Label("Delete", systemImage: "trash") + } + } + #endif +} + +// MARK: - Row + +struct GenerationRow: View { + @Environment(ViewModel.self) private var viewModel + let generation: Generation + @State private var isHovered = false + + var body: some View { + HStack(spacing: 10) { + if let samples = generation.waveformSamples, !samples.isEmpty { + WaveformThumbnail(samples: samples) + } else { + Image(systemName: "waveform") + .frame(width: 48, height: 24) + .foregroundStyle(.tertiary) + } + + VStack(alignment: .leading, spacing: 2) { + Text(generation.title) + .font(.body) + .lineLimit(1) + .truncationMode(.tail) + + HStack(spacing: 6) { + Text(generation.date, style: .date) + Text("·") + Text(String(format: "%.1fs", generation.audioDuration)) + Text("·") + Text(generation.speaker.capitalized) + } + .font(.caption) + .foregroundStyle(.secondary) + } + + Spacer() + + if generation.isFavorite { + Image(systemName: "star.fill") + .font(.caption) + .foregroundStyle(.yellow) + } + + #if os(macOS) + Button { + viewModel.deleteGeneration(generation.id) + } label: { + Image(systemName: "trash") + .font(.caption) + .foregroundStyle(.secondary) + } + .buttonStyle(.plain) + .opacity(isHovered ? 
1 : 0) + .animation(.easeInOut(duration: 0.15), value: isHovered) + #endif + } + .padding(.vertical, 2) + #if os(macOS) + .onHover { isHovered = $0 } + #endif + } +} diff --git a/Examples/TTS/SpeakAX/SpeakAX/SpeakAXApp.swift b/Examples/TTS/SpeakAX/SpeakAX/SpeakAXApp.swift new file mode 100644 index 00000000..bc5e52ed --- /dev/null +++ b/Examples/TTS/SpeakAX/SpeakAX/SpeakAXApp.swift @@ -0,0 +1,87 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright © 2026 Argmax, Inc. All rights reserved. + +import SwiftUI + +@main +struct SpeakAXApp: App { + @State private var viewModel = ViewModel() + + var body: some Scene { + WindowGroup { + ContentView() + .environment(viewModel) + #if os(iOS) + .onAppear { installKeyboardDismissGesture() } + #endif + .onReceive(NotificationCenter.default.publisher( + for: { + #if os(iOS) + UIApplication.willResignActiveNotification + #else + NSApplication.willTerminateNotification + #endif + }() + )) { _ in + viewModel.cancelAllTasks() + } + } + .defaultSize(width: 1100, height: 700) + } +} + +// MARK: - iOS keyboard dismiss on tap outside + +#if os(iOS) +import UIKit + +/// Installs a window-level tap gesture recognizer that dismisses the keyboard whenever +/// the user taps outside a UITextView or UITextField. Uses `cancelsTouchesInView = false` +/// and a delegate check so text selection and all other controls continue to work normally. +private func installKeyboardDismissGesture() { + guard let window = UIApplication.shared + .connectedScenes + .compactMap({ $0 as? 
UIWindowScene }) + .flatMap({ $0.windows }) + .first(where: { $0.isKeyWindow }) else { return } + + // Only install once + if window.gestureRecognizers?.contains(where: { $0 is KeyboardDismissTapRecognizer }) == true { + return + } + + window.addGestureRecognizer(KeyboardDismissTapRecognizer()) +} + +private final class KeyboardDismissTapRecognizer: UITapGestureRecognizer, UIGestureRecognizerDelegate { + init() { + super.init(target: nil, action: nil) + cancelsTouchesInView = false + delegate = self + addTarget(self, action: #selector(handleTap)) + } + + @objc private func handleTap() { + UIApplication.shared.sendAction( + #selector(UIResponder.resignFirstResponder), + to: nil, from: nil, for: nil + ) + } + + /// Only fire when the touch lands outside a text input view. + func gestureRecognizer( + _ gestureRecognizer: UIGestureRecognizer, + shouldReceive touch: UITouch + ) -> Bool { + !(touch.view is UITextView || touch.view is UITextField) + } + + /// Always allow simultaneous recognition with every other gesture (buttons, scrolls, etc.) + func gestureRecognizer( + _ gestureRecognizer: UIGestureRecognizer, + shouldRecognizeSimultaneouslyWith other: UIGestureRecognizer + ) -> Bool { + true + } +} +#endif diff --git a/Examples/TTS/SpeakAX/SpeakAX/ViewModel.swift b/Examples/TTS/SpeakAX/SpeakAX/ViewModel.swift new file mode 100644 index 00000000..8497f6ad --- /dev/null +++ b/Examples/TTS/SpeakAX/SpeakAX/ViewModel.swift @@ -0,0 +1,1124 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright © 2026 Argmax, Inc. All rights reserved. + +@preconcurrency import AVFoundation +import CoreML +import Foundation +import Observation +import SwiftUI +import Tokenizers +import TTSKit +import WhisperKit +#if os(macOS) +import AppKit +#endif + +// MARK: - Data Model + +/// An in-memory representation of a generated audio clip. +/// +/// All persistent fields live inside the M4A file's metadata - there is no +/// companion JSON or database. 
`Generation` is reconstructed on launch by +/// scanning the documents directory for `.m4a` files and calling +/// `AudioOutput.loadMetadata(from:)` on each one. +/// +/// `isFavorite` is the only mutable bit of state not in the file itself; it +/// is stored as a `Set` in `UserDefaults` so we never need to re-encode +/// the M4A. +@MainActor +struct Generation: Identifiable { + let id: UUID + let title: String + let text: String + let speaker: String + let language: String + let instruction: String + let modelName: String + let audioDuration: TimeInterval + let realTimeFactor: Double + let speedFactor: Double + let stepsPerSecond: Double + let timeToFirstBuffer: TimeInterval + let date: Date + var isFavorite: Bool + let audioFileName: String + var waveformSamples: [Float]? + + init( + metadata: AudioMetadata, + audioFileName: String, + audioDuration: TimeInterval, + isFavorite: Bool + ) { + self.id = metadata.id + self.title = String(metadata.text.prefix(ViewModel.titleMaxLength)) + self.text = metadata.text + self.speaker = metadata.speaker + self.language = metadata.language + self.instruction = metadata.instruction + self.modelName = metadata.modelName + self.audioDuration = audioDuration + self.realTimeFactor = metadata.realTimeFactor + self.speedFactor = metadata.speedFactor + self.stepsPerSecond = metadata.stepsPerSecond + self.timeToFirstBuffer = metadata.timeToFirstBuffer + self.date = metadata.date + self.isFavorite = isFavorite + self.audioFileName = audioFileName + self.waveformSamples = nil + } +} + +// MARK: - Model State + +enum ModelState: Equatable { + case unloaded + case downloading + case prewarming + case loading + case loaded + case error(String) + + var label: String { + switch self { + case .unloaded: return "Unloaded" + case .downloading: return "Downloading..." + case .prewarming: return "Specializing..." + case .loading: return "Loading..." 
+ case .loaded: return "Loaded" + case let .error(msg): return "Error: \(msg)" + } + } + + var color: Color { + switch self { + case .unloaded: return .red + case .downloading, .prewarming, .loading: return .yellow + case .loaded: return .green + case .error: return .red + } + } + + var isBusy: Bool { + self == .downloading || self == .prewarming || self == .loading + } +} + +enum GenerationState: Equatable { + case idle + case generating +} + +// MARK: - Settings Storage + +/// Holds all persisted settings using @AppStorage. +/// Lives inside ViewModel and is read once at init time to seed the +/// observable stored properties. Each stored property's didSet writes +/// the new value back here so UserDefaults stays in sync. +@MainActor +private class Settings { + /// Model selection + @AppStorage("selectedPreset") var selectedPresetRaw: String = TTSModelVariant.defaultForCurrentPlatform.rawValue + @AppStorage("selectedSpeaker") var selectedSpeakerRaw: String = Qwen3Speaker.ryan.rawValue + @AppStorage("selectedLanguage") var selectedLanguageRaw: String = Qwen3Language.english.rawValue + @AppStorage("playbackStrategyTag") var playbackStrategyTag: String = "auto" + + /// Generation options + @AppStorage("genTemperature") var temperature: Double = .init(GenerationOptions.defaultTemperature) + @AppStorage("genTopK") var topK: Double = .init(GenerationOptions.defaultTopK) + @AppStorage("genRepetitionPenalty") var repetitionPenalty: Double = .init(GenerationOptions.defaultRepetitionPenalty) + @AppStorage("genMaxNewTokens") var maxNewTokens: Double = .init(GenerationOptions.defaultMaxNewTokens) + @AppStorage("genConcurrentWorkerCount") var concurrentWorkerCount: Double = 0 + @AppStorage("genChunkingStrategy") var chunkingStrategyTag: String = "sentence" + @AppStorage("genTargetChunkSize") var targetChunkSize: Double = .init(TextChunker.defaultTargetChunkSize) + @AppStorage("genMinChunkSize") var minChunkSize: Double = .init(TextChunker.defaultMinChunkSize) + + /// Compute 
units (stored as Int matching MLComputeUnits.rawValue) + @AppStorage("embedderComputeUnits") var embedderComputeUnitsRaw: Int = MLComputeUnits.cpuOnly.rawValue + @AppStorage("codeDecoderComputeUnits") var codeDecoderComputeUnitsRaw: Int = MLComputeUnits.cpuAndNeuralEngine.rawValue + @AppStorage("multiCodeDecoderComputeUnits") var multiCodeDecoderComputeUnitsRaw: Int = MLComputeUnits.cpuAndNeuralEngine.rawValue + @AppStorage("speechDecoderComputeUnits") var speechDecoderComputeUnitsRaw: Int = MLComputeUnits.cpuAndNeuralEngine.rawValue +} + +// MARK: - View Model + +@MainActor +@Observable +final class ViewModel: @unchecked Sendable { + // MARK: - Constants + + fileprivate static let titleMaxLength = 40 + private static let historyLimit = 20 + /// Poll interval for playback position updates (~30 fps). + private static let playbackPollIntervalMs = 33 + /// Debounce delay before running a token count after a keystroke. + private static let tokenCountDebounceMs = 250 + + // MARK: - Settings storage + + private let settings = Settings() + + // MARK: - Model management + + var modelState: ModelState = .unloaded + var downloadProgress: Double = 0 + var localModelPaths: [TTSModelVariant: String] = [:] + + // MARK: - Persisted: model selection + + var selectedPreset: TTSModelVariant { + didSet { settings.selectedPresetRaw = selectedPreset.rawValue } + } + + // MARK: - Generation state + + var generationState: GenerationState = .idle + var statusMessage = "Select a model to get started" + + // MARK: - Persisted: input defaults + + var selectedSpeaker: Qwen3Speaker { + didSet { settings.selectedSpeakerRaw = selectedSpeaker.rawValue } + } + + var selectedLanguage: Qwen3Language { + didSet { settings.selectedLanguageRaw = selectedLanguage.rawValue } + } + + var playbackStrategyTag: String { + didSet { settings.playbackStrategyTag = playbackStrategyTag } + } + + var selectedPlaybackStrategy: PlaybackStrategy { + switch playbackStrategyTag { + case "stream": return .stream + case 
"generateFirst": return .generateFirst + default: return .auto + } + } + + var instruction: String = "" + + // MARK: - Persisted: generation options + + var temperature: Double { didSet { settings.temperature = temperature } } + var topK: Double { didSet { settings.topK = topK } } + var repetitionPenalty: Double { didSet { settings.repetitionPenalty = repetitionPenalty } } + var maxNewTokens: Double { didSet { settings.maxNewTokens = maxNewTokens } } + var concurrentWorkerCount: Double { didSet { settings.concurrentWorkerCount = concurrentWorkerCount } } + var chunkingStrategyTag: String { didSet { settings.chunkingStrategyTag = chunkingStrategyTag } } + + var chunkingStrategy: TextChunkingStrategy { + TextChunkingStrategy(rawValue: chunkingStrategyTag) ?? .sentence + } + + var targetChunkSize: Double { didSet { settings.targetChunkSize = targetChunkSize } } + var minChunkSize: Double { didSet { settings.minChunkSize = minChunkSize } } + + // MARK: - Persisted: compute units + + var embedderComputeUnits: MLComputeUnits { + didSet { settings.embedderComputeUnitsRaw = embedderComputeUnits.rawValue } + } + + var codeDecoderComputeUnits: MLComputeUnits { + didSet { settings.codeDecoderComputeUnitsRaw = codeDecoderComputeUnits.rawValue } + } + + var multiCodeDecoderComputeUnits: MLComputeUnits { + didSet { settings.multiCodeDecoderComputeUnitsRaw = multiCodeDecoderComputeUnits.rawValue } + } + + var speechDecoderComputeUnits: MLComputeUnits { + didSet { settings.speechDecoderComputeUnitsRaw = speechDecoderComputeUnits.rawValue } + } + + var computeOptions: ComputeOptions { + ComputeOptions( + embedderComputeUnits: embedderComputeUnits, + codeDecoderComputeUnits: codeDecoderComputeUnits, + multiCodeDecoderComputeUnits: multiCodeDecoderComputeUnits, + speechDecoderComputeUnits: speechDecoderComputeUnits + ) + } + + // MARK: - Init + + init() { + // Seed all persisted properties from @AppStorage backing store + // Resolve the persisted preset, falling back to the platform 
default. + // If a previously saved preset is no longer available on this platform (e.g. 0.6B + // saved on macOS, then opened on iOS), quietly switch to the platform default. + let savedPreset = TTSModelVariant(rawValue: settings.selectedPresetRaw) + selectedPreset = (savedPreset?.isAvailableOnCurrentPlatform == true) + ? savedPreset! + : .defaultForCurrentPlatform + selectedSpeaker = Qwen3Speaker(rawValue: settings.selectedSpeakerRaw) ?? .ryan + selectedLanguage = Qwen3Language(rawValue: settings.selectedLanguageRaw) ?? .english + playbackStrategyTag = settings.playbackStrategyTag + temperature = settings.temperature + topK = settings.topK + repetitionPenalty = settings.repetitionPenalty + maxNewTokens = settings.maxNewTokens + concurrentWorkerCount = settings.concurrentWorkerCount + chunkingStrategyTag = settings.chunkingStrategyTag + targetChunkSize = settings.targetChunkSize + minChunkSize = settings.minChunkSize + embedderComputeUnits = MLComputeUnits(rawValue: settings.embedderComputeUnitsRaw) ?? .cpuOnly + codeDecoderComputeUnits = MLComputeUnits(rawValue: settings.codeDecoderComputeUnitsRaw) ?? .cpuAndNeuralEngine + multiCodeDecoderComputeUnits = MLComputeUnits(rawValue: settings.multiCodeDecoderComputeUnitsRaw) ?? .cpuAndNeuralEngine + speechDecoderComputeUnits = MLComputeUnits(rawValue: settings.speechDecoderComputeUnitsRaw) ?? 
.cpuAndNeuralEngine + } + + // MARK: - Generation output + + var currentWaveform: [Float] = [] + var currentAudioSamples: [Float] = [] + var currentDuration: TimeInterval = 0 + var currentRTF: Double = 0 + var currentSpeedFactor: Double = 0 + var currentStepsPerSecond: Double = 0 + var currentTimeToFirstBuffer: TimeInterval = 0 + + // MARK: - Playback & streaming + + var isPlaying = false + /// Current playback position in seconds (works for both streaming and replay) + var playbackTime: TimeInterval = 0 + /// True while we're in a live generate-and-play session + var isStreaming = false + /// Reference to the active audio output during streaming, for querying playback position + private var activeAudioOutput: AudioOutput? + + // MARK: - Generation progress (generateFirst mode) + + /// Running count of decoder steps completed across all chunks. + var stepsCompleted: Int = 0 + /// Estimated total steps for the full request (maxNewTokens × totalChunks). + var totalSteps: Int = 0 + /// Total number of text chunks in the current request. + var chunksTotal: Int = 0 + + /// Seconds of audio still accumulating in the pre-buffer before the next chunk + /// flushes and playback resumes. Non-zero only while actively buffering mid-stream. + var silentBufferRemaining: TimeInterval { + guard isStreaming, let audioOut = activeAudioOutput else { return 0 } + return audioOut.silentBufferRemaining + } + + /// Total real audio scheduled to the player so far (excludes silent sentinel buffers). + var scheduledAudioDuration: TimeInterval { + guard isStreaming, let audioOut = activeAudioOutput else { return 0 } + return audioOut.scheduledAudioDuration + } + + private var playbackUpdateTask: Task? + private var generationTask: Task? + + // MARK: - History + + /// A sentinel that drives the detail view into "new generation" mode without + /// requiring a separate navigation Bool. 
Setting `selectedGenerationID` to this + /// value shows the detail with empty inputs while no real generation is selected. + static let newGenerationSentinel = UUID(uuidString: "00000000-0000-0000-0000-000000000000")! + + var generations: [Generation] = [] + var selectedGenerationID: UUID? + + // MARK: - Search + + var searchText = "" + + // MARK: - Sheet presentation + + var showGenerationSettings = false + + // MARK: - Private + + private var tts: TTSKit? + private(set) var loadedPreset: TTSModelVariant? + private var audioPlayer: AVAudioPlayer? + private var tokenCountTask: Task? + + // MARK: - Computed + + var selectedGeneration: Generation? { + guard let id = selectedGenerationID else { return nil } + return generations.first { $0.id == id } + } + + var filteredGenerations: [Generation] { + if searchText.isEmpty { return generations } + let query = searchText.lowercased() + return generations.filter { + $0.title.lowercased().contains(query) + || $0.text.lowercased().contains(query) + || $0.speaker.lowercased().contains(query) + } + } + + var favoriteGenerations: [Generation] { + filteredGenerations.filter { $0.isFavorite } + } + + var recentGenerations: [Generation] { + Array(filteredGenerations.prefix(Self.historyLimit)) + } + + var canGenerate: Bool { + !inputText.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty + && generationState == .idle + && !modelState.isBusy + } + + var isModelDownloaded: Bool { + localModelPaths[selectedPreset] != nil + } + + var inputText = "" { + didSet { + guard inputText != oldValue else { return } + scheduleTokenCount() + } + } + + /// Token count of `inputText` as measured by the loaded tokenizer. + /// Updated with a short debounce as the user types; falls back to a character + /// approximation in the UI when the tokenizer is not yet loaded. + var inputTokenCount: Int? 
+ + // MARK: - Lifecycle + + func onAppear() { + Task { await loadGenerations() } + scanLocalModels() + } + + // MARK: - Model Management + + /// Hub storage path relative to Documents, matching the pattern used by WhisperAX / HubApi. + /// The Hub library stores downloads at: Documents/huggingface/models/{repo}/ + private static let modelStorageBase = "huggingface/models" + + /// HuggingFace model repo for the selected preset + var modelRepoURL: URL? { + let config = TTSKitConfig(model: selectedPreset) + return URL(string: "https://huggingface.co/\(config.modelRepo)") + } + + /// Local path to the Hub-managed repo folder for a given config + private func localRepoURL(for config: TTSKitConfig) -> URL { + documentsDirectory + .appendingPathComponent(Self.modelStorageBase) + .appendingPathComponent(config.modelRepo) + } + + /// Scan for previously downloaded models by checking the Hub cache in Documents + func scanLocalModels() { + for preset in TTSModelVariant.allCases { + let config = TTSKitConfig(model: preset) + let repoURL = localRepoURL(for: config) + let repoPath = repoURL.path + + guard FileManager.default.fileExists(atPath: repoPath) else { continue } + + // Verify at least one component's version directory exists + if hasModelFiles(at: repoPath, config: config) { + localModelPaths[preset] = repoPath + } + } + + if isModelDownloaded, modelState == .unloaded { + statusMessage = "Model downloaded" + } + } + + /// Check if any model component directories exist at a given repo path + private func hasModelFiles(at basePath: String, config: TTSKitConfig) -> Bool { + let baseURL = URL(fileURLWithPath: basePath) + return config.componentDirectories(in: baseURL).contains { + FileManager.default.fileExists(atPath: $0.path) + } + } + + /// Download the selected model preset without loading it + func downloadModel() async { + guard !modelState.isBusy else { return } + modelState = .downloading + downloadProgress = 0 + statusMessage = "Downloading 
\(selectedPreset.rawValue) model..." + + do { + // Set a HuggingFace token via TTSKitConfig(token:) if the model repo is private. + // For the public argmaxinc/ttskit-coreml repo no token is required. + let config = TTSKitConfig( + model: selectedPreset, + verbose: true + ) + let folder = try await TTSKit.download(config: config) { [weak self] progress in + Task { @MainActor in + self?.downloadProgress = progress.fractionCompleted + } + } + localModelPaths[selectedPreset] = folder.path + modelState = .unloaded + downloadProgress = 1.0 + statusMessage = "Downloaded" + } catch { + modelState = .error(error.localizedDescription) + statusMessage = "Download failed: \(error.localizedDescription)" + } + } + + /// Load the selected model preset (downloads first if needed) + func loadModel() async { + guard !modelState.isBusy else { return } + + // If switching presets while loaded, unload first + if loadedPreset != nil, loadedPreset != selectedPreset { + unloadModel() + } + + // Download if needed + if localModelPaths[selectedPreset] == nil { + await downloadModel() + guard localModelPaths[selectedPreset] != nil else { return } + } + + do { + let ttsConfig = TTSKitConfig( + model: selectedPreset, + modelFolder: localModelPaths[selectedPreset].map { URL(fileURLWithPath: $0) }, + computeOptions: computeOptions, + verbose: true + ) + + // Create TTSKit without loading - we drive load/prewarm explicitly + ttsConfig.load = false + let kit = try await TTSKit(ttsConfig) + tts = kit + + // Prewarm: compile each CoreML model sequentially, then discard. + // This prevents memory exhaustion from concurrent compilation on first launch. + // On subsequent launches the compiled cache is already on disk, so this is fast. + modelState = .prewarming + statusMessage = "Specializing \(selectedPreset.rawValue) for your device\nThis may take a few minutes on first load" + do { + try await kit.prewarmModels() + } catch { + // Prewarm failures are non-fatal (model may already be specialized). 
+ // Surface to the user so they know something unexpected happened, + // but continue - the full load often succeeds regardless. + let msg = error.localizedDescription + print("Prewarm warning: \(msg) - continuing to load") + statusMessage = "Prewarm warning: \(msg)\nContinuing..." + } + + // Load: compiled artifacts are cached, concurrent load is safe + modelState = .loading + statusMessage = "Loading \(selectedPreset.rawValue) model..." + try await kit.loadModels() + + loadedPreset = selectedPreset + modelState = .loaded + statusMessage = "Ready - \(selectedPreset.rawValue) loaded" + AccessibilityNotification.Announcement("\(selectedPreset.rawValue) model loaded and ready").post() + // Refresh token count now that the tokenizer is available. + scheduleTokenCount() + } catch { + modelState = .error(error.localizedDescription) + statusMessage = "Load failed: \(error.localizedDescription)" + AccessibilityNotification.Announcement("Model load failed: \(error.localizedDescription)").post() + } + } + + /// Reload the current model with updated compute options + func reloadModelForComputeUnitChange() { + guard modelState == .loaded, let preset = loadedPreset else { return } + unloadModel() + selectedPreset = preset + Task { [weak self] in await self?.loadModel() } + } + + /// Unload the current model from memory + func unloadModel() { + cancelAllTasks() + let oldTTS = tts + tts = nil + Task { await oldTTS?.unloadModels() } + loadedPreset = nil + modelState = .unloaded + statusMessage = isModelDownloaded ? "Model unloaded" : "Select a model to get started" + } + + /// Reset all generation options to their factory defaults. 
+ func resetGenerationSettings() { + temperature = Double(GenerationOptions.defaultTemperature) + topK = Double(GenerationOptions.defaultTopK) + repetitionPenalty = Double(GenerationOptions.defaultRepetitionPenalty) + maxNewTokens = Double(GenerationOptions.defaultMaxNewTokens) + concurrentWorkerCount = 0 + chunkingStrategyTag = "sentence" + targetChunkSize = Double(TextChunker.defaultTargetChunkSize) + minChunkSize = Double(TextChunker.defaultMinChunkSize) + } + + /// Delete downloaded model files for a specific variant only, + /// leaving other variants in the shared repo untouched. + func deleteModel(preset: TTSModelVariant? = nil) { + let target = preset ?? selectedPreset + + if loadedPreset == target { + unloadModel() + } + + let config = TTSKitConfig(model: target) + let repoURL = localRepoURL(for: config) + for dir in config.componentDirectories(in: repoURL) { + guard FileManager.default.fileExists(atPath: dir.path) else { continue } + try? FileManager.default.removeItem(at: dir) + } + localModelPaths.removeValue(forKey: target) + + if target == selectedPreset { + modelState = .unloaded + statusMessage = "Model deleted" + } + } + + /// Disk size of downloaded model files for a specific variant only. + func modelDiskSize(for preset: TTSModelVariant) -> String? { + guard localModelPaths[preset] != nil else { return nil } + let config = TTSKitConfig(model: preset) + let repoURL = localRepoURL(for: config) + var total: UInt64 = 0 + for dir in config.componentDirectories(in: repoURL) { + if let size = directorySize(at: dir) { + total += size + } + } + guard total > 0 else { return nil } + return ByteCountFormatter.string(fromByteCount: Int64(total), countStyle: .file) + } + + private func directorySize(at url: URL) -> UInt64? { + let fm = FileManager.default + guard let enumerator = fm.enumerator(at: url, includingPropertiesForKeys: [.fileSizeKey]) else { + return nil + } + var total: UInt64 = 0 + for case let fileURL as URL in enumerator { + if let size = try? 
fileURL.resourceValues(forKeys: [.fileSizeKey]).fileSize { + total += UInt64(size) + } + } + return total + } + + // MARK: - Generation + + /// Start generation in a tracked task that can be cancelled. + /// If no model is loaded yet, automatically loads the selected (or default) model first. + func startGeneration() { + generationTask?.cancel() + // Don't touch selectedGenerationID here. On iPhone, NavigationSplitView collapses + // to a NavigationStack driven by List(selection:) - setting the selection to nil + // immediately pops the detail view. The selection updates naturally when the new + // generation is saved at the end of generate(). + generationTask = Task { [weak self] in + guard let self else { return } + if modelState != .loaded { + await loadModel() + guard !Task.isCancelled, modelState == .loaded else { return } + } + guard !Task.isCancelled else { return } + await generate() + } + } + + /// Cancel all background tasks (generation, playback updates, token counting). + /// Safe to call from any lifecycle event (view disappear, scene deactivation, etc.). + func cancelAllTasks() { + generationTask?.cancel() + generationTask = nil + tokenCountTask?.cancel() + tokenCountTask = nil + stopPlaybackUpdates() + if activeAudioOutput != nil { + let audioOut = activeAudioOutput + activeAudioOutput = nil + Task { await audioOut?.stopPlayback(waitForCompletion: false) } + } + audioPlayer?.stop() + audioPlayer = nil + isPlaying = false + isStreaming = false + if generationState == .generating { + generationState = .idle + statusMessage = "Cancelled" + } + } + + /// Cancel any in-progress generation and stop audio immediately. + func cancelGeneration() { + cancelAllTasks() + generationState = .idle + statusMessage = "Cancelled" + } + + private func generate() async { + guard canGenerate, let tts else { return } + + generationState = .generating + statusMessage = "Generating speech..." 
+ currentWaveform = [] + currentAudioSamples = [] + currentDuration = 0 + playbackTime = 0 + stepsCompleted = 0 + totalSteps = 0 + chunksTotal = 0 + + let strategy = selectedPlaybackStrategy + + switch strategy { + case .generateFirst: + break + case .auto, .stream, .buffered: + isStreaming = true + activeAudioOutput = tts.audioOutput + startPlaybackUpdates() + } + + do { + let result: SpeechResult + + switch strategy { + case .generateFirst: + result = try await generateFirstGeneration(tts: tts) + currentAudioSamples = result.audio + currentDuration = result.audioDuration + currentWaveform = peaksPerToken(from: result.audio) + try await finalizeGeneration(result: result) + if let gen = generations.first { + playGeneration(gen) + } + case .auto, .stream, .buffered: + result = try await streamGeneration(tts: tts) + stopPlaybackUpdates() + activeAudioOutput = nil + isStreaming = false + try await finalizeGeneration(result: result) + } + } catch { + stopPlaybackUpdates() + activeAudioOutput = nil + isStreaming = false + stepsCompleted = 0 + totalSteps = 0 + chunksTotal = 0 + statusMessage = "Error: \(error.localizedDescription)" + } + + if generationState == .generating { + generationState = .idle + AccessibilityNotification.Announcement("Generation cancelled").post() + } + } + + /// Build the generation options from current UI state. + private func buildOptions() -> GenerationOptions { + let workerCount = Int(concurrentWorkerCount) + var options = GenerationOptions( + temperature: Float(temperature), + topK: Int(topK), + repetitionPenalty: Float(repetitionPenalty), + maxNewTokens: Int(maxNewTokens), + concurrentWorkerCount: workerCount, + chunkingStrategy: chunkingStrategy, + targetChunkSize: Int(targetChunkSize), + minChunkSize: Int(minChunkSize) + ) + if !instruction.isEmpty { + options.instruction = instruction + } + return options + } + + /// Generate all audio up front using `tts.generate()`, tracking step-level progress. 
+ private func generateFirstGeneration(tts: TTSKit) async throws -> SpeechResult { + let options = buildOptions() + + let result = try await tts.generate( + text: inputText, + speaker: selectedSpeaker, + language: selectedLanguage, + options: options, + callback: { [weak self] progress in + let steps = progress.stepsCompleted ?? 0 + let maxSteps = progress.totalSteps ?? 1 + let chunks = progress.totalChunks ?? 1 + Task { @MainActor [weak self] in + guard let self else { return } + stepsCompleted = steps + totalSteps = maxSteps + chunksTotal = chunks + } + return nil + } + ) + + stepsCompleted = totalSteps + currentRTF = result.timings.realTimeFactor + currentSpeedFactor = result.timings.speedFactor + currentStepsPerSecond = result.timings.tokensPerSecond + currentTimeToFirstBuffer = result.timings.timeToFirstBuffer + return result + } + + /// Run `play`, streaming waveform peaks and audio samples back to the main actor. + private func streamGeneration(tts: TTSKit) async throws -> SpeechResult { + let options = buildOptions() + let sampleRate = Double(Qwen3TTSConstants.sampleRate) + + let result = try await tts.play( + text: inputText, + speaker: selectedSpeaker, + language: selectedLanguage, + options: options, + playbackStrategy: selectedPlaybackStrategy, + callback: { [weak self] progress in + let samples = progress.audio + let peak = samples.reduce(Float(0)) { max($0, abs($1)) } + Task { @MainActor [weak self] in + guard let self else { return } + currentAudioSamples.append(contentsOf: samples) + currentDuration = Double(currentAudioSamples.count) / sampleRate + currentWaveform.append(peak) + } + return nil + } + ) + + currentAudioSamples = result.audio + currentDuration = result.audioDuration + currentRTF = result.timings.realTimeFactor + currentSpeedFactor = result.timings.speedFactor + currentStepsPerSecond = result.timings.tokensPerSecond + currentTimeToFirstBuffer = result.timings.timeToFirstBuffer + playbackTime = 0 + return result + } + + /// Persist the 
generation to disk as M4A, insert it into the history, and update UI state. + private func finalizeGeneration(result: SpeechResult) async throws { + let meta = AudioMetadata( + text: inputText, + speaker: selectedSpeaker.rawValue, + language: selectedLanguage.rawValue, + instruction: instruction, + modelName: "\(Qwen3TTSConstants.modelFamilyDir)_\((loadedPreset ?? selectedPreset).versionDir)", + realTimeFactor: result.timings.realTimeFactor, + speedFactor: result.timings.speedFactor, + stepsPerSecond: result.timings.tokensPerSecond, + timeToFirstBuffer: result.timings.timeToFirstBuffer + ) + let savedURL = try await AudioOutput.saveAudio( + result.audio, + toFolder: documentsDirectory, + filename: meta.suggestedFileName, + sampleRate: result.sampleRate, + metadataProvider: meta.avMetadataItems + ) + + var gen = Generation( + metadata: meta, + audioFileName: savedURL.lastPathComponent, + audioDuration: result.audioDuration, + isFavorite: false + ) + gen.waveformSamples = currentWaveform + generations.insert(gen, at: 0) + selectedGenerationID = gen.id + + generationState = .idle + statusMessage = String( + format: "Done generating %.1fs of audio, RTF %.2f", + result.audioDuration, + result.timings.realTimeFactor + ) + AccessibilityNotification.Announcement( + String(format: "Generation complete. %.1f seconds of audio.", result.audioDuration) + ).post() + } + + + // MARK: - Playback position updates + + /// Starts a MainActor task that polls the audio engine's actual playback position at ~30fps. + /// Using a Task instead of Timer ensures it always runs on the main thread. 
+ private func startPlaybackUpdates() { + playbackUpdateTask?.cancel() + playbackUpdateTask = Task { @MainActor [weak self] in + while !Task.isCancelled { + guard let self else { break } + + if self.isStreaming, let audioOut = self.activeAudioOutput { + self.playbackTime = audioOut.currentPlaybackTime + } else if self.isPlaying, let player = self.audioPlayer { + if player.isPlaying { + self.playbackTime = player.currentTime + } else { + // Replay finished + self.isPlaying = false + self.playbackTime = 0 + break + } + } else { + break + } + + try? await Task.sleep(for: .milliseconds(Self.playbackPollIntervalMs)) + } + } + } + + private func stopPlaybackUpdates() { + playbackUpdateTask?.cancel() + playbackUpdateTask = nil + } + + // MARK: - Playback (replay saved audio) + + /// Populate the input fields from a past generation so the user can edit and re-generate. + /// Called automatically whenever a generation is selected from history. + func loadInputs(from generation: Generation) { + inputText = generation.text + selectedSpeaker = Qwen3Speaker(rawValue: generation.speaker) ?? .ryan + selectedLanguage = Qwen3Language(rawValue: generation.language) ?? 
.english + instruction = generation.instruction + currentRTF = generation.realTimeFactor + currentSpeedFactor = generation.speedFactor + currentStepsPerSecond = generation.stepsPerSecond + currentTimeToFirstBuffer = generation.timeToFirstBuffer + } + + func playGeneration(_ generation: Generation) { + let url = documentsDirectory.appendingPathComponent(generation.audioFileName) + guard FileManager.default.fileExists(atPath: url.path) else { return } + + do { + #if os(iOS) + let session = AVAudioSession.sharedInstance() + try session.setCategory(.playback, mode: .default, options: []) + try session.setActive(true) + #endif + + audioPlayer?.stop() + audioPlayer = try AVAudioPlayer(contentsOf: url) + audioPlayer?.play() + isPlaying = true + playbackTime = 0 + + loadWaveform(for: generation) + startPlaybackUpdates() + } catch { + statusMessage = "Playback error: \(error.localizedDescription)" + } + } + + func stopPlayback() { + audioPlayer?.stop() + isPlaying = false + playbackTime = 0 + stopPlaybackUpdates() + } + + /// URL for the audio file of a generation (for sharing/exporting) + func audioFileURL(for generation: Generation) -> URL? { + let url = documentsDirectory.appendingPathComponent(generation.audioFileName) + return FileManager.default.fileExists(atPath: url.path) ? url : nil + } + + // MARK: - History Management + + func toggleFavorite(_ id: UUID) { + guard let idx = generations.firstIndex(where: { $0.id == id }) else { return } + generations[idx].isFavorite.toggle() + saveFavorites() + } + + func deleteGeneration(_ id: UUID) { + guard let idx = generations.firstIndex(where: { $0.id == id }) else { return } + let url = documentsDirectory.appendingPathComponent(generations[idx].audioFileName) + try? 
FileManager.default.removeItem(at: url) + generations.remove(at: idx) + if selectedGenerationID == id { + selectedGenerationID = generations.first?.id + } + } + + func clearAllGenerations() { + for gen in generations { + let url = documentsDirectory.appendingPathComponent(gen.audioFileName) + try? FileManager.default.removeItem(at: url) + } + generations.removeAll() + selectedGenerationID = nil + UserDefaults.standard.removeObject(forKey: Self.favoritesKey) + } + + /// Debounced token counter: waits 250ms after the last keystroke before encoding. + /// Uses the loaded tokenizer when available; silently skips if none is loaded yet. + /// When the model loads for the first time, the next edit will trigger a real count. + private func scheduleTokenCount() { + tokenCountTask?.cancel() + let text = inputText + guard !text.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty else { + inputTokenCount = nil + return + } + tokenCountTask = Task { [weak self] in + try? await Task.sleep(for: .milliseconds(Self.tokenCountDebounceMs)) + guard !Task.isCancelled, let self else { return } + guard let count = self.tts?.tokenizer?.encode(text: text).count else { return } + await MainActor.run { self.inputTokenCount = count } + } + } + + func clearInput() { + inputText = "" + instruction = "" + inputTokenCount = nil + currentWaveform = [] + currentAudioSamples = [] + currentDuration = 0 + statusMessage = modelState == .loaded ? "Ready" : "Select a model to get started" + } + + // MARK: - Persistence + + var documentsDirectory: URL { + FileManager.default.urls(for: .documentDirectory, in: .userDomainMask)[0] + } + + // MARK: Generations - derived entirely from .m4a files on disk + + /// Scan the Documents directory for `.m4a` files, read embedded metadata from + /// each one, and rebuild the in-memory `generations` array. No JSON sidecar needed. + func loadGenerations() async { + let fm = FileManager.default + guard let files = try? 
fm.contentsOfDirectory( + at: documentsDirectory, + includingPropertiesForKeys: nil + ) else { return } + + let m4aFiles = files + .filter { $0.pathExtension == "m4a" } + .sorted { $0.lastPathComponent > $1.lastPathComponent } // newest first by name + + let favorites = loadedFavoriteIDs() + var loaded: [Generation] = [] + + for url in m4aFiles { + do { + guard let meta = try await AudioMetadata.load(from: url) else { + print("SpeakAX: skipping \(url.lastPathComponent) - no TTSKit metadata found") + continue + } + let dur = (try? await AudioOutput.duration(of: url)) ?? 0 + var gen = Generation( + metadata: meta, + audioFileName: url.lastPathComponent, + audioDuration: dur, + isFavorite: favorites.contains(meta.id) + ) + gen.waveformSamples = waveformPeaks(from: url) + loaded.append(gen) + } catch { + print("SpeakAX: failed to load \(url.lastPathComponent) - \(error.localizedDescription)") + } + } + + generations = loaded + } + + // MARK: Favorites - stored in UserDefaults (only mutable state not in the file) + + private static let favoritesKey = "ttskit_favoriteGenerationIDs" + + private func loadedFavoriteIDs() -> Set { + let strings = UserDefaults.standard.stringArray(forKey: Self.favoritesKey) ?? [] + return Set(strings.compactMap { UUID(uuidString: $0) }) + } + + private func saveFavorites() { + let ids = generations.filter(\.isFavorite).map(\.id.uuidString) + UserDefaults.standard.set(ids, forKey: Self.favoritesKey) + } + + // MARK: - Waveform + + /// Read an audio file from disk and return waveform peaks at token density. + /// Returns `nil` if the file can't be read (missing, corrupt, etc.). + func waveformPeaks(from url: URL) -> [Float]? { + guard FileManager.default.fileExists(atPath: url.path), + let file = try? AVAudioFile(forReading: url) else { return nil } + let frameCount = AVAudioFrameCount(file.length) + guard let buffer = AVAudioPCMBuffer(pcmFormat: file.processingFormat, frameCapacity: frameCount), + (try? 
file.read(into: buffer)) != nil, + let channelData = buffer.floatChannelData else { return nil } + let samples = Array(UnsafeBufferPointer(start: channelData[0], count: Int(buffer.frameLength))) + return peaksPerToken(from: samples) + } + + /// Resample raw audio into 1 peak per token (~80ms). + /// Matches the fixed bar width used by WaveformView. + func peaksPerToken(from audioSamples: [Float]) -> [Float] { + let samplesPerBar = Int(WaveformView.secondsPerBar * Double(Qwen3TTSConstants.sampleRate)) + guard samplesPerBar > 0, !audioSamples.isEmpty else { return [] } + let barCount = (audioSamples.count + samplesPerBar - 1) / samplesPerBar + return (0.. 0 else { return } + + let centerX = size.width / 2 + let midY = size.height / 2 + let maxBarHeight = size.height * 0.85 + + // How many pixels displayedTime shifts the waveform left + let playbackOffsetPx = CGFloat(displayedTime / Self.secondsPerBar) * barStep + + for i in 0.. 0 && x < size.width else { continue } + + let amplitude = CGFloat(min(abs(samples[i]), 1.0)) + let barHeight = max(1, amplitude * maxBarHeight) + let rect = CGRect( + x: x, + y: midY - barHeight / 2, + width: barWidth, + height: barHeight + ) + + let barTime = Double(i) * Self.secondsPerBar + let isPlayed = barTime < displayedTime + let color = isPlayed ? 
accentColor : accentColor.opacity(0.3) + context.fill(Path(roundedRect: rect, cornerRadius: 1), with: .color(color)) + } + + // Fixed playhead line at center + let line = Path { p in + p.move(to: CGPoint(x: centerX, y: 0)) + p.addLine(to: CGPoint(x: centerX, y: size.height)) + } + context.stroke(line, with: .color(Color.accentColor), lineWidth: 1.5) + + // Playhead dot + let dot = Path(ellipseIn: CGRect(x: centerX - 4, y: -1, width: 8, height: 8)) + context.fill(dot, with: .color(Color.accentColor)) + } + .onChange(of: playbackTime) { _, newTime in + if newTime == 0 { + // Explicit reset - new session starting + displayedTime = 0 + } else { + // Clamp to the last generated bar: playhead can't outrun visible audio + let maxTime = Double(samples.count) * Self.secondsPerBar + displayedTime = max(displayedTime, min(newTime, maxTime)) + } + } + .onChange(of: samples.count) { _, count in + // Waveform cleared for a new generation + if count == 0 { displayedTime = 0 } + } + .accessibilityLabel("Audio waveform") + .accessibilityValue( + playbackTime > 0 + ? "Playback at \(Int(playbackTime)) seconds of \(Int(totalDuration))" + : samples.isEmpty ? 
"No audio" : "\(Int(totalDuration)) seconds" + ) + } +} + +// MARK: - Thumbnail (sidebar rows) + +struct WaveformThumbnail: View { + var samples: [Float] + var color: Color = .secondary + + var body: some View { + Canvas { context, size in + let count = samples.count + guard count > 0 else { return } + + let barWidth = size.width / CGFloat(count) + let midY = size.height / 2 + + for (i, sample) in samples.enumerated() { + let amp = CGFloat(min(abs(sample), 1.0)) + let h = max(0.5, amp * size.height * 0.85) + let x = CGFloat(i) * barWidth + let rect = CGRect(x: x, y: midY - h / 2, width: max(0.5, barWidth - 0.3), height: h) + context.fill(Path(roundedRect: rect, cornerRadius: 0.3), with: .color(color)) + } + } + .frame(width: 48, height: 24) + .accessibilityHidden(true) + } +} + +// MARK: - Previews + +#Preview("Pre-recorded - start") { + let samples: [Float] = (0..<200).map { _ in Float.random(in: 0...1) } + WaveformView(samples: samples, playbackTime: 0, totalDuration: 10.0) + .frame(width: 600, height: 120).padding() +} + +#Preview("Pre-recorded - mid playback") { + let samples: [Float] = (0..<200).map { _ in Float.random(in: 0...1) } + WaveformView(samples: samples, playbackTime: 5.0, totalDuration: 10.0) + .frame(width: 600, height: 120).padding() +} + +#Preview("Pre-recorded - near end") { + let samples: [Float] = (0..<200).map { _ in Float.random(in: 0...1) } + WaveformView(samples: samples, playbackTime: 9.0, totalDuration: 10.0) + .frame(width: 600, height: 120).padding() +} diff --git a/Examples/WhisperAX/WhisperAX.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved b/Examples/WhisperAX/WhisperAX.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved index 95f8598d..dc9ac9fe 100644 --- a/Examples/WhisperAX/WhisperAX.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved +++ b/Examples/WhisperAX/WhisperAX.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved @@ -1,5 +1,5 @@ { - "originHash" : 
"831ad63194a5262b2549d58e383a520f9cbbc80b4a75660fbbcc56d65edfdab4", + "originHash" : "105884455b6729deaf95d3dd16df7029fd7929c924fcab55200562b52b64dbc2", "pins" : [ { "identity" : "swift-argument-parser", @@ -10,13 +10,31 @@ "version" : "1.3.0" } }, + { + "identity" : "swift-collections", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-collections.git", + "state" : { + "revision" : "7b847a3b7008b2dc2f47ca3110d8c782fb2e5c7e", + "version" : "1.3.0" + } + }, + { + "identity" : "swift-jinja", + "kind" : "remoteSourceControl", + "location" : "https://github.com/huggingface/swift-jinja.git", + "state" : { + "revision" : "d81197f35f41445bc10e94600795e68c6f5e94b0", + "version" : "2.3.1" + } + }, { "identity" : "swift-transformers", "kind" : "remoteSourceControl", "location" : "https://github.com/huggingface/swift-transformers.git", "state" : { - "revision" : "fc6543263e4caed9bf6107466d625cfae9357f08", - "version" : "0.1.8" + "revision" : "573e5c9036c2f136b3a8a071da8e8907322403d0", + "version" : "1.1.6" } } ], diff --git a/Makefile b/Makefile index 5a8534a0..26e73ea6 100644 --- a/Makefile +++ b/Makefile @@ -8,6 +8,8 @@ PYTHON_COMMAND := python3 # Define model repository and directories MODEL_REPO := argmaxinc/whisperkit-coreml MODEL_REPO_DIR := ./Models/whisperkit-coreml +TTS_MODEL_REPO := argmaxinc/ttskit-coreml +TTS_MODEL_REPO_DIR := ./Models/ttskit-coreml BASE_COMPILED_DIR := ./Models GIT_HASH := $(shell git rev-parse --short HEAD) @@ -31,20 +33,24 @@ setup: @echo "Checking for fastlane" @which fastlane > /dev/null || (echo "Installing fastlane..." && brew install fastlane) @echo "fastlane is installed." - @$(MAKE) generate-whisperax-xcconfig + @$(MAKE) generate-xcconfigs @echo "Done 🚀" -generate-whisperax-xcconfig: - @echo "Updating DEVELOPMENT_TEAM in Examples/WhisperAX/Debug.xcconfig..." 
+generate-xcconfigs: @TEAM_ID=$$(defaults read com.apple.dt.Xcode IDEProvisioningTeams | plutil -convert json -r -o - -- - | jq -r 'to_entries[0].value | sort_by(.teamType == "Individual") | .[0].teamID' 2>/dev/null); \ if [ -z "$$TEAM_ID" ]; then \ echo "Error: No Development Team ID found. Please log into Xcode with your Apple ID and select a team."; \ else \ echo "DEVELOPMENT_TEAM=$$TEAM_ID" > Examples/WhisperAX/Debug.xcconfig; \ - echo "DEVELOPMENT_TEAM has been updated in Examples/WhisperAX/Debug.xcconfig with your Development Team ID: $$TEAM_ID"; \ + echo "Updated Examples/WhisperAX/Debug.xcconfig with Development Team ID: $$TEAM_ID"; \ + echo "DEVELOPMENT_TEAM=$$TEAM_ID" > Examples/TTS/SpeakAX/Debug.xcconfig; \ + echo "Updated Examples/TTS/SpeakAX/Debug.xcconfig with Development Team ID: $$TEAM_ID"; \ fi +generate-whisperax-xcconfig: generate-xcconfigs +generate-speakax-xcconfig: generate-xcconfigs + setup-huggingface-cli: @if huggingface-cli whoami; then \ @@ -74,6 +80,19 @@ setup-model-repo: git clone https://huggingface.co/$(MODEL_REPO) $(MODEL_REPO_DIR); \ fi +setup-tts-model-repo: + @echo "Setting up TTS repository..." + @mkdir -p $(BASE_COMPILED_DIR) + @if [ -d "$(TTS_MODEL_REPO_DIR)/.git" ]; then \ + echo "Repository exists, resetting..."; \ + export GIT_LFS_SKIP_SMUDGE=1; \ + cd $(TTS_MODEL_REPO_DIR) && git fetch --all && git reset --hard origin/main && git clean -fdx; \ + else \ + echo "Repository not found, initializing..."; \ + export GIT_LFS_SKIP_SMUDGE=1; \ + git clone https://huggingface.co/$(TTS_MODEL_REPO) $(TTS_MODEL_REPO_DIR); \ + fi + # Download all models download-models: setup-model-repo @@ -94,6 +113,24 @@ download-model: @cd $(MODEL_REPO_DIR) && \ git lfs pull --include="openai_whisper-$(MODEL)/*" +download-tts-models: setup-tts-model-repo + @echo "Downloading all TTS models..." 
+ @cd $(TTS_MODEL_REPO_DIR) && \ + git lfs pull --include="qwen3_tts/**" + +# Download a specific TTS model size +# Usage: make download-tts-model MODEL=0.6b +# make download-tts-model MODEL=1.7b +download-tts-model: setup-tts-model-repo + @if [ -z "$(MODEL)" ]; then \ + echo "Error: MODEL not set. Usage: make download-tts-model MODEL=0.6b"; \ + echo "Available models: 0.6b, 1.7b"; \ + exit 1; \ + fi + @echo "Downloading TTS model $(MODEL)..." + @cd $(TTS_MODEL_REPO_DIR) && \ + git lfs pull --include="qwen3_tts/*/12hz-$(MODEL)-customvoice/**" + build: @echo "Building WhisperKit..." @swift build -v @@ -103,6 +140,7 @@ build-cli: @echo "Building WhisperKit CLI..." @swift build -c release --product whisperkit-cli + test: @echo "Running tests..." @swift test -v diff --git a/Package.resolved b/Package.resolved index fbd008f0..955602b1 100644 --- a/Package.resolved +++ b/Package.resolved @@ -32,8 +32,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/huggingface/swift-transformers.git", "state" : { - "revision" : "d363e83a77bafe144808a3d01556139fe67cd8bc", - "version" : "1.1.2" + "revision" : "573e5c9036c2f136b3a8a071da8e8907322403d0", + "version" : "1.1.6" } } ], diff --git a/Package.swift b/Package.swift index 1b7f4ad8..be181b74 100644 --- a/Package.swift +++ b/Package.swift @@ -17,13 +17,17 @@ let package = Package( name: "WhisperKit", targets: ["WhisperKit"] ), + .library( + name: "TTSKit", + targets: ["TTSKit"] + ), .executable( name: "whisperkit-cli", targets: ["WhisperKitCLI"] ) ], dependencies: [ - .package(url: "https://github.com/huggingface/swift-transformers.git", .upToNextMinor(from: "1.1.2")), + .package(url: "https://github.com/huggingface/swift-transformers.git", .upToNextMinor(from: "1.1.6")), .package(url: "https://github.com/apple/swift-argument-parser.git", from: "1.3.0"), ] + (isServerEnabled() ? 
[ .package(url: "https://github.com/vapor/vapor.git", from: "4.115.1"), @@ -33,13 +37,26 @@ let package = Package( ] : []), targets: [ + .target( + name: "ArgmaxCore" + ), .target( name: "WhisperKit", dependencies: [ + "ArgmaxCore", .product(name: "Hub", package: "swift-transformers"), .product(name: "Tokenizers", package: "swift-transformers"), ] ), + .target( + name: "TTSKit", + dependencies: [ + "ArgmaxCore", + .product(name: "Tokenizers", package: "swift-transformers"), + .product(name: "Hub", package: "swift-transformers"), + ], + swiftSettings: [.enableExperimentalFeature("StrictConcurrency")] + ), .testTarget( name: "WhisperKitTests", dependencies: [ @@ -47,22 +64,28 @@ let package = Package( .product(name: "Hub", package: "swift-transformers"), .product(name: "Tokenizers", package: "swift-transformers"), ], - path: "Tests", + exclude: ["UnitTestsPlan.xctestplan"], resources: [ - .process("WhisperKitTests/Resources"), + .process("Resources"), + ] + ), + .testTarget( + name: "TTSKitTests", + dependencies: [ + "TTSKit" ] ), .executableTarget( name: "WhisperKitCLI", dependencies: [ "WhisperKit", + "TTSKit", .product(name: "ArgumentParser", package: "swift-argument-parser"), ] + (isServerEnabled() ? [ .product(name: "Vapor", package: "vapor"), .product(name: "OpenAPIRuntime", package: "swift-openapi-runtime"), .product(name: "OpenAPIVapor", package: "swift-openapi-vapor"), ] : []), - path: "Sources/WhisperKitCLI", exclude: (isServerEnabled() ? [] : ["Server"]), swiftSettings: (isServerEnabled() ? 
[.define("BUILD_SERVER_CLI")] : []) ) diff --git a/README.md b/README.md index 0d7dde71..d93ef8b2 100644 --- a/README.md +++ b/README.md @@ -39,7 +39,27 @@ WhisperKit is an [Argmax](https://www.takeargmax.com) framework for deploying st - [Model Selection](#model-selection) - [Generating Models](#generating-models) - [Swift CLI](#swift-cli) -- [WhisperKit Local Server](#whisperkit-local-server) + - [WhisperKit Local Server](#whisperkit-local-server) + - [Building the Server](#building-the-server) + - [Starting the Server](#starting-the-server) + - [API Endpoints](#api-endpoints) + - [Supported Parameters](#supported-parameters) + - [Client Examples](#client-examples) + - [Generating the API Specification](#generating-the-api-specification) + - [Client Generation](#client-generation) + - [API Limitations](#api-limitations) + - [Fully Supported Features](#fully-supported-features) +- [TTSKit](#ttskit) + - [Quick Example](#quick-example-1) + - [Model Selection](#model-selection-1) + - [Custom Voices](#custom-voices) + - [Real-Time Streaming Playback](#real-time-streaming-playback) + - [Generation Options](#generation-options) + - [Style Instructions (1.7B only)](#style-instructions-17b-only) + - [Saving Audio](#saving-audio) + - [Progress Callbacks](#progress-callbacks) + - [Swift CLI](#swift-cli-1) + - [Demo App](#demo-app) - [Contributing \& Roadmap](#contributing--roadmap) - [License](#license) - [Citation](#citation) @@ -48,12 +68,12 @@ WhisperKit is an [Argmax](https://www.takeargmax.com) framework for deploying st ### Swift Package Manager -WhisperKit can be integrated into your Swift project using the Swift Package Manager. +WhisperKit and TTSKit are separate library products in the same Swift package. Add the package once and pick the products you need. ### Prerequisites - macOS 14.0 or later. -- Xcode 15.0 or later. +- Xcode 16.0 or later. ### Xcode Steps @@ -61,11 +81,11 @@ WhisperKit can be integrated into your Swift project using the Swift Package Man 2. 
Navigate to `File` > `Add Package Dependencies...`. 3. Enter the package repository URL: `https://github.com/argmaxinc/whisperkit`. 4. Choose the version range or specific version. -5. Click `Finish` to add WhisperKit to your project. +5. When prompted to choose library products, select **WhisperKit**, **TTSKit**, or both. ### Package.swift -If you're using WhisperKit as part of a swift package, you can include it in your Package.swift dependencies as follows: +If you're using WhisperKit or TTSKit as part of a swift package, you can include it in your Package.swift dependencies as follows: ```swift dependencies: [ @@ -73,12 +93,15 @@ dependencies: [ ], ``` -Then add `WhisperKit` as a dependency for your target: +Then add the products you need as target dependencies: ```swift .target( name: "YourApp", - dependencies: ["WhisperKit"] + dependencies: [ + "WhisperKit", // speech-to-text + "TTSKit", // text-to-speech + ] ), ``` @@ -308,6 +331,154 @@ The local server fully supports these OpenAI API features: - **Temperature control**: Sampling temperature for transcription randomness - **Prompt text**: Text guidance for transcription style and context +## TTSKit + +TTSKit is an on-device text-to-speech framework built on Core ML. It runs [Qwen3 TTS](https://github.com/QwenLM/Qwen3-TTS) models entirely on Apple silicon with real-time streaming playback, no server required. + +- macOS 15.0 or later. +- iOS 18.0 or later. + +### Quick Example + +This example demonstrates how to generate speech from text: + +```swift +import TTSKit + +Task { + let tts = try await TTSKit() + let result = try await tts.generate(text: "Hello from TTSKit!") + print("Generated \(result.audioDuration)s of audio at \(result.sampleRate)Hz") +} +``` + +`TTSKit()` automatically downloads the default 0.6B model on first run, loads the tokenizer and six CoreML models concurrently, and is ready to generate. + +### Model Selection + +TTSKit ships two model sizes. 
You can select the model by passing a variant to `TTSKitConfig`: + +```swift +// Fast, runs on all platforms (~1 GB download) +let tts = try await TTSKit(TTSKitConfig(model: .qwen3TTS_0_6b)) + +// Higher quality, macOS only (~2.2 GB download, supports style instructions) +let tts = try await TTSKit(TTSKitConfig(model: .qwen3TTS_1_7b)) +``` + +Models are hosted on [HuggingFace](https://huggingface.co/argmaxinc/ttskit-coreml) and cached locally after the first download. + +#### Custom Voices + +You can choose from 9 built-in voices and 10 languages: + +```swift +let result = try await tts.generate( + text: "こんにちは世界", + speaker: .onoAnna, + language: .japanese +) +``` + +**Voices:** `.ryan`, `.aiden`, `.onoAnna` (`"ono-anna"`), `.sohee`, `.eric`, `.dylan`, `.serena`, `.vivian`, `.uncleFu` (`"uncle-fu"`) + +**Languages:** `.english`, `.chinese`, `.japanese`, `.korean`, `.german`, `.french`, `.russian`, `.portuguese`, `.spanish`, `.italian` + +#### Real-Time Streaming Playback + +`play` streams audio to the device speakers frame-by-frame as it is generated: + +```swift +try await tts.play(text: "This starts playing before generation finishes.") +``` + +You can control how much audio is buffered before playback begins. The default `.auto` strategy measures the first generation step and pre-buffers just enough to avoid underruns: + +```swift +try await tts.play( + text: "Long passage...", + playbackStrategy: .auto +) +``` + +Other strategies include `.stream` (immediate, no buffer), `.buffered(seconds:)` (fixed pre-buffer), and `.generateFirst` (generate all audio first, then play). 
+ +### Generation Options + +You can customize sampling, chunking, and concurrency via `GenerationOptions`: + +```swift +// Defaults recommended by Qwen +var options = GenerationOptions() +options.temperature = 0.9 +options.topK = 50 +options.repetitionPenalty = 1.05 +options.maxNewTokens = 245 + +// Long text is automatically split at sentence boundaries +options.chunkingStrategy = .sentence +options.concurrentWorkerCount = nil // nil = all chunks run concurrently with a good default for the device + +let result = try await tts.generate(text: longArticle, options: options) +``` + +#### Style Instructions (1.7B only) + +The 1.7B model accepts a natural-language style instruction that controls prosody: + +```swift +var options = GenerationOptions() +options.instruction = "Speak slowly and warmly, like a storyteller." + +let result = try await tts.generate( + text: "Once upon a time...", + speaker: .ryan, + options: options +) +``` + +### Saving Audio + +Generated audio can be saved to WAV or M4A: + +```swift +let result = try await tts.generate(text: "Save me!") +let outputDir = FileManager.default.urls(for: .documentDirectory, in: .userDomainMask)[0] + +// Save as .wav or .m4a (AAC) +try await AudioOutput.saveAudio(result.audio, toFolder: outputDir, filename: "output", format: .m4a) +``` + +### Progress Callbacks + +You can receive per-step audio during generation. Return `false` from the callback to cancel early: + +```swift +let result = try await tts.generate(text: "Hello!") { progress in + print("Audio chunk: \(progress.audio.count) samples") + if let stepTime = progress.stepTime { + print("First step took \(stepTime)s") + } + return true // return false to cancel +} +``` + +### Swift CLI + +The TTS command is available through the same `whisperkit-cli` tool. 
You can generate speech and optionally play it back in real time: + +```bash +swift run whisperkit-cli tts --text "Hello from the command line" --play +swift run whisperkit-cli tts --text "Save to file" --output-path output.wav +swift run whisperkit-cli tts --text "日本語テスト" --speaker ono-anna --language japanese +swift run whisperkit-cli tts --text-file article.txt --model 1.7b --instruction "Read cheerfully" +swift run whisperkit-cli tts --help +``` + +### Demo App + +The [SpeakAX](Examples/TTS/SpeakAX/) example app showcases real-time streaming, model management, waveform visualization, and generation history on macOS and iOS. See the [SpeakAX README](Examples/TTS/SpeakAX/README.md) for build instructions. + ## Contributing & Roadmap Our goal is to make WhisperKit better and better over time and we'd love your help! Just search the code for "TODO" for a variety of features that are yet to be built. Please refer to our [contribution guidelines](CONTRIBUTING.md) for submitting issues, pull requests, and coding standards, where we also have a public roadmap of features we are looking forward to building in the future. diff --git a/Sources/ArgmaxCore/ConcurrencyUtilities.swift b/Sources/ArgmaxCore/ConcurrencyUtilities.swift new file mode 100644 index 00000000..507a5e6a --- /dev/null +++ b/Sources/ArgmaxCore/ConcurrencyUtilities.swift @@ -0,0 +1,112 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright © 2024 Argmax, Inc. All rights reserved. + +import Foundation +import os.lock + +// MARK: - Concurrency Utilities + +public struct ConcurrencyUtilities { + private init() {} + + /// Number of active processors on this device. + public static var activeProcessorCount: Int { + ProcessInfo.processInfo.activeProcessorCount + } +} + +// MARK: - Unfair Lock + +/// Thin wrapper around `os_unfair_lock` that exposes a Swift-friendly +/// `withLock` helper. 
This lock is non-reentrant and optimized for low +/// contention, matching the semantics of Core Foundation's unfair lock. +public final class UnfairLock: @unchecked Sendable { + var lock = os_unfair_lock() + + public init() {} + + public func withLock(_ body: () throws -> T) rethrows -> T { + os_unfair_lock_lock(&lock) + defer { os_unfair_lock_unlock(&lock) } + return try body() + } +} + +// MARK: - Property Lock + +/// Serializes whole-value reads and writes with an `os_unfair_lock`. +/// Useful for properties on types marked `@unchecked Sendable`. +/// +/// **Known limitation — only whole-value replacement is safe:** +/// +/// ```swift +/// holder.ref = otherRef // safe: single locked write +/// holder.count = newCount // safe: single locked write +/// +/// holder.ref.count += 1 // NOT safe: gets the reference (locked), +/// // then mutates .count with no lock held +/// holder.count += 1 // NOT safe: get and set are two separate +/// // lock acquisitions with a gap between them +/// ``` +/// +/// For read-modify-write safety, callers must replace the entire value +/// atomically or use their own external synchronisation. +@propertyWrapper +public struct PropertyLock: Sendable, Codable { + private let lock: UnfairLock + private var value: Value + + public init(wrappedValue: Value) { + self.lock = UnfairLock() + self.value = wrappedValue + } + + public init(from decoder: Swift.Decoder) throws { + self.lock = UnfairLock() + self.value = try Value(from: decoder) + } + + public func encode(to encoder: Encoder) throws { + try lock.withLock { + try value.encode(to: encoder) + } + } + + public var wrappedValue: Value { + get { + lock.withLock { + return value + } + } + set { + lock.withLock { + value = newValue + } + } + } +} + +// MARK: - Early Stop Actor + +/// An actor that provides thread-safe early stopping functionality using UUIDs as keys. 
+public actor EarlyStopActor { + private var shouldStop = [UUID: Bool]() + + public init() {} + + /// Sets the stop flag for a given UUID + public func set(_ value: Bool, for uuid: UUID) { + shouldStop[uuid] = value + } + + /// Gets the stop flag for a given UUID + public func get(for uuid: UUID) -> Bool { + return shouldStop[uuid] ?? false + } + + /// Removes and returns the stop flag for a given UUID + @discardableResult + public func remove(for uuid: UUID) -> Bool? { + return shouldStop.removeValue(forKey: uuid) + } +} diff --git a/Sources/ArgmaxCore/FileUtilities.swift b/Sources/ArgmaxCore/FileUtilities.swift new file mode 100644 index 00000000..a5957c8b --- /dev/null +++ b/Sources/ArgmaxCore/FileUtilities.swift @@ -0,0 +1,55 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright © 2026 Argmax, Inc. All rights reserved. + +import Foundation + +#if canImport(PDFKit) +import PDFKit +#endif + +// MARK: - File Text Extraction + +/// Utility for extracting plain text from files on disk. +/// +/// Supports: +/// - Plain text files (`.txt`, and any UTF-8 readable file) +/// - PDF files (via PDFKit, where available) +public enum FileUtilities { + /// Reads and returns the text content of the file at `url`. + /// + /// - Returns: The extracted string, or `nil` if the file cannot be read or has no + /// readable text content. + public static func readTextContent(at url: URL) -> String? { + switch url.pathExtension.lowercased() { + case "pdf": + return readPDF(at: url) + default: + // Try UTF-8 first, fall back to Latin-1 for legacy files + if let text = try? String(contentsOf: url, encoding: .utf8) { + return text.isEmpty ? nil : text + } + if let text = try? String(contentsOf: url, encoding: .isoLatin1) { + return text.isEmpty ? nil : text + } + return nil + } + } + + // MARK: - PDF + + #if canImport(PDFKit) + private static func readPDF(at url: URL) -> String? 
{ + guard let document = PDFDocument(url: url) else { return nil } + var pages: [String] = [] + for i in 0.. String? { nil } + #endif +} diff --git a/Sources/ArgmaxCore/FloatType.swift b/Sources/ArgmaxCore/FloatType.swift new file mode 100644 index 00000000..4fac0d81 --- /dev/null +++ b/Sources/ArgmaxCore/FloatType.swift @@ -0,0 +1,18 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright © 2024 Argmax, Inc. All rights reserved. + +import Accelerate +import CoreML + +// MARK: - Platform-aware Float Type + +#if !((os(macOS) || targetEnvironment(macCatalyst)) && arch(x86_64)) +public typealias FloatType = Float16 +#else +public typealias FloatType = Float +#endif + +#if (os(macOS) || targetEnvironment(macCatalyst)) && arch(arm64) && compiler(<6) +extension Float16: BNNSScalar {} +extension Float16: MLShapedArrayScalar {} +#endif diff --git a/Sources/ArgmaxCore/FoundationExtensions.swift b/Sources/ArgmaxCore/FoundationExtensions.swift new file mode 100644 index 00000000..b680b2c5 --- /dev/null +++ b/Sources/ArgmaxCore/FoundationExtensions.swift @@ -0,0 +1,139 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright © 2024 Argmax, Inc. All rights reserved. + +import Foundation + +// MARK: - Float + +public extension Float { + /// Rounds to the specified number of decimal places. + func rounded(_ decimalPlaces: Int) -> Float { + let divisor = pow(10.0, Float(decimalPlaces)) + return (self * divisor).rounded() / divisor + } +} + +// MARK: - FileManager + +public extension FileManager { + /// Resolves an input path to an absolute path, expanding tilde and resolving + /// relative paths against the current working directory. 
+ static func resolveAbsolutePath(_ inputPath: String) -> String { + let fileManager = FileManager.default + + let pathWithTildeExpanded = NSString(string: inputPath).expandingTildeInPath + + if pathWithTildeExpanded.hasPrefix("/") { + return pathWithTildeExpanded + } + + if let cwd = fileManager.currentDirectoryPath as String? { + let resolvedPath = URL(fileURLWithPath: cwd).appendingPathComponent(pathWithTildeExpanded).path + return resolvedPath + } + + return inputPath + } +} + +// MARK: - Array + +extension Array { + /// Splits the array into batches of the given size. + public func batched(into size: Int) -> [[Element]] { + return stride(from: 0, to: count, by: size).map { + Array(self[$0..() + return self.filter { element in + if seen.contains(element) { + return false + } else { + seen.insert(element) + return true + } + } + } +} + +// MARK: - String + +extension String { + /// Returns the text up to and including the last natural boundary in the string. + /// + /// Boundaries are tested in priority order: sentence enders (. ! ? \n), clause + /// enders (, ; : - –), then word boundaries (space). A candidate is only accepted + /// when its encoded token count reaches `minTokenCount`. + /// + /// - Parameters: + /// - minTokenCount: Minimum number of tokens the candidate must contain. + /// - encode: Closure that tokenizes a string and returns its token IDs. + /// - Returns: The trimmed substring up to the last qualifying boundary, or `nil`. + public func lastNaturalBoundary(minTokenCount: Int, encode: (String) -> [Int]) -> String? 
{ + let sentenceEnders: [Character] = [".", "!", "?", "\n"] + let clauseEnders: [Character] = [",", ";", ":", "-", "–"] + + for enders in [sentenceEnders, clauseEnders] { + if let idx = lastIndex(where: { enders.contains($0) }) { + let candidate = String(self[...idx]).trimmingCharacters(in: .whitespacesAndNewlines) + if encode(candidate).count >= minTokenCount { + return candidate + } + } + } + + if let idx = lastIndex(of: " ") { + let candidate = String(self[..= minTokenCount { + return candidate + } + } + + return nil + } + + /// Trims up to `upto` occurrences of `character` from the end of the string. + public func trimmingFromEnd(character: Character = " ", upto: Int) -> String { + var result = self + var trimmed = 0 + while trimmed < upto && result.last == character { + result.removeLast() + trimmed += 1 + } + return result + } +} + +extension [String] { + /// Filters strings matching a glob pattern using `fnmatch`. + public func matching(glob: String) -> [String] { + filter { fnmatch(glob, $0, 0) == 0 } + } +} + +// MARK: - ProcessInfo (macOS) + +#if os(macOS) || targetEnvironment(simulator) +public extension ProcessInfo { + static func stringFromSysctl(named name: String) -> String { + var size: size_t = 0 + sysctlbyname(name, nil, &size, nil, 0) + var machineModel = [CChar](repeating: 0, count: Int(size)) + sysctlbyname(name, &machineModel, &size, nil, 0) + return String(cString: machineModel) + } + + static let processor = stringFromSysctl(named: "machdep.cpu.brand_string") + static let cores = stringFromSysctl(named: "machdep.cpu.core_count") + static let threads = stringFromSysctl(named: "machdep.cpu.thread_count") + static let vendor = stringFromSysctl(named: "machdep.cpu.vendor") + static let family = stringFromSysctl(named: "machdep.cpu.family") + static let hwModel = stringFromSysctl(named: "hw.model") +} +#endif diff --git a/Sources/ArgmaxCore/Logging.swift b/Sources/ArgmaxCore/Logging.swift new file mode 100644 index 00000000..ad5407f4 --- 
/dev/null +++ b/Sources/ArgmaxCore/Logging.swift @@ -0,0 +1,116 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright © 2024 Argmax, Inc. All rights reserved. + +import OSLog + +/// Shared logger for all Argmax frameworks (WhisperKit, TTSKit, etc.). +/// +/// Configure the log level once at startup: +/// ```swift +/// Logging.shared.logLevel = .debug +/// ``` +/// or via a config object, following the WhisperKit pattern: +/// ```swift +/// Logging.shared.logLevel = config.verbose ? config.logLevel : .none +/// ``` +open class Logging { + public static let shared = Logging() + public var logLevel: LogLevel = .none + + public typealias LoggingCallback = (_ message: String) -> Void + public var loggingCallback: LoggingCallback? + + private let logger = OSLog( + subsystem: Bundle.main.bundleIdentifier ?? "com.argmax.argmaxcore", + category: "Argmax" + ) + + @frozen + public enum LogLevel: Int { + case debug = 1 + case info = 2 + case error = 3 + case none = 4 + + func shouldLog(level: LogLevel) -> Bool { + return self.rawValue <= level.rawValue + } + } + + private init() {} + + public func log(_ items: Any..., separator: String = " ", terminator: String = "\n", type: OSLogType) { + let message = items.map { "\($0)" }.joined(separator: separator) + if let logger = loggingCallback { + logger(message) + } else { + os_log("%{public}@", log: logger, type: type, message) + } + } + + public static func debug(_ items: Any..., separator: String = " ", terminator: String = "\n") { + if shared.logLevel.shouldLog(level: .debug) { + shared.log(items, separator: separator, terminator: terminator, type: .debug) + } + } + + public static func info(_ items: Any..., separator: String = " ", terminator: String = "\n") { + if shared.logLevel.shouldLog(level: .info) { + shared.log(items, separator: separator, terminator: terminator, type: .info) + } + } + + public static func error(_ items: Any..., separator: String = " ", terminator: String = "\n") { + if 
shared.logLevel.shouldLog(level: .error) { + shared.log(items, separator: separator, terminator: terminator, type: .error) + } + } +} + +public extension Logging { + /// Format a timing entry as a human-readable string with per-run average and percentage. + /// + /// Output format: ` 123.45 ms / 100 runs ( 1.23 ms/run) 45.67%` + /// + /// - Parameters: + /// - time: Duration in seconds. + /// - runs: Number of calls / iterations (used for per-run average). + /// - fullPipelineDuration: Total pipeline duration in **milliseconds** (for percentage). + static func formatTimeWithPercentage(_ time: Double, _ runs: Double, _ fullPipelineDuration: Double) -> String { + let percentage = (time * 1000 / fullPipelineDuration) * 100 + let runTime = runs > 0 ? time * 1000 / Double(runs) : 0 + return String(format: "%8.2f ms / %6.0f runs (%8.2f ms/run) %5.2f%%", time * 1000, runs, runTime, percentage) + } + + static func logCurrentMemoryUsage(_ message: String) { + let memoryUsage = getMemoryUsage() + Logging.debug("\(message) - Memory usage: \(memoryUsage) MB") + } + + static func getMemoryUsage() -> UInt64 { + var info = mach_task_basic_info() + var count = mach_msg_type_number_t(MemoryLayout.size) / 4 + + let kerr: kern_return_t = withUnsafeMutablePointer(to: &info) { + $0.withMemoryRebound(to: integer_t.self, capacity: 1) { + task_info(mach_task_self_, task_flavor_t(MACH_TASK_BASIC_INFO), $0, &count) + } + } + + guard kerr == KERN_SUCCESS else { + return 0 + } + + return info.resident_size / 1024 / 1024 + } +} + +@available(*, deprecated, message: "Subject to removal in a future version. Use `Logging.logCurrentMemoryUsage(_:)` instead.") +public func logCurrentMemoryUsage(_ message: String) { + Logging.logCurrentMemoryUsage(message) +} + +@available(*, deprecated, message: "Subject to removal in a future version. 
Use `Logging.getMemoryUsage()` instead.") +public func getMemoryUsage() -> UInt64 { + return Logging.getMemoryUsage() +} diff --git a/Sources/ArgmaxCore/MLModelExtensions.swift b/Sources/ArgmaxCore/MLModelExtensions.swift new file mode 100644 index 00000000..f9abac50 --- /dev/null +++ b/Sources/ArgmaxCore/MLModelExtensions.swift @@ -0,0 +1,64 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright © 2024 Argmax, Inc. All rights reserved. + +import CoreML + +// MARK: - Async Prediction + +public extension MLModel { + /// Async wrapper for `MLModel.prediction` that uses native async prediction + /// on macOS 14+ / iOS 17+ and falls back to a Task-wrapped call on older OS. + func asyncPrediction( + from input: MLFeatureProvider, + options: MLPredictionOptions = MLPredictionOptions() + ) async throws -> MLFeatureProvider { + if #available(macOS 14, iOS 17, watchOS 10, visionOS 1, *) { + return try await prediction(from: input, options: options) + } else { + return try await Task { + try prediction(from: input, options: options) + }.value + } + } + + /// Async prediction with MLState for stateful models. + /// MLState requires macOS 15+ where native async prediction is available. + @available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *) + func asyncPrediction( + from input: MLFeatureProvider, + using state: MLState, + options: MLPredictionOptions = MLPredictionOptions() + ) async throws -> MLFeatureProvider { + try await prediction(from: input, using: state, options: options) + } +} + +// MARK: - Compute Units Description + +public extension MLComputeUnits { + var description: String { + switch self { + case .cpuOnly: + return "cpuOnly" + case .cpuAndGPU: + return "cpuAndGPU" + case .all: + return "all" + case .cpuAndNeuralEngine: + return "cpuAndNeuralEngine" + @unknown default: + return "unknown" + } + } + + /// Human-readable display name suitable for UI presentation. 
+ var displayName: String { + switch self { + case .cpuOnly: return "CPU" + case .cpuAndGPU: return "GPU" + case .cpuAndNeuralEngine: return "Neural Engine" + case .all: return "All" + @unknown default: return "Unknown" + } + } +} diff --git a/Sources/ArgmaxCore/MLModelLoading.swift b/Sources/ArgmaxCore/MLModelLoading.swift new file mode 100644 index 00000000..17de7412 --- /dev/null +++ b/Sources/ArgmaxCore/MLModelLoading.swift @@ -0,0 +1,31 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright © 2026 Argmax, Inc. All rights reserved. + +import CoreML +import Foundation + +// MARK: - CoreML model loading protocol + +/// Shared lifecycle contract for any CoreML-backed model component. +/// +/// Conform to this protocol to get a unified `loadModel`/`unloadModel` interface. +/// `prewarmMode` compiles the model on-device then immediately discards it to +/// serialize compilation and cap peak memory before a final concurrent load. +public protocol MLModelLoading { + /// Load the CoreML model bundle at `url` using the specified compute units. + /// + /// When `prewarmMode` is `true` the model is compiled on-device and then + /// immediately discarded. This serializes compilation to cap peak memory before + /// the final concurrent load. + func loadModel(at url: URL, computeUnits: MLComputeUnits, prewarmMode: Bool) async throws + + /// Release the loaded model weights from memory. + func unloadModel() +} + +public extension MLModelLoading { + /// Convenience overload — loads with `prewarmMode: false`. + func loadModel(at url: URL, computeUnits: MLComputeUnits) async throws { + try await loadModel(at: url, computeUnits: computeUnits, prewarmMode: false) + } +} diff --git a/Sources/ArgmaxCore/MLMultiArrayExtensions.swift b/Sources/ArgmaxCore/MLMultiArrayExtensions.swift new file mode 100644 index 00000000..5f056d80 --- /dev/null +++ b/Sources/ArgmaxCore/MLMultiArrayExtensions.swift @@ -0,0 +1,141 @@ +// For licensing see accompanying LICENSE.md file. 
+// Copyright © 2024 Argmax, Inc. All rights reserved. + +import CoreML + +// MARK: - MLMultiArray Creation + +public extension MLMultiArray { + /// Creates an MLMultiArray pre-filled with an initial value. + /// Uses IOSurface-backed storage for float16 arrays. + convenience init(shape: [NSNumber], dataType: MLMultiArrayDataType, initialValue: Any) throws { + switch dataType { + case .float16: + guard let pixelBuffer = Self.pixelBuffer(for: shape) else { + throw MLMultiArrayCreationError.pixelBufferFailed + } + self.init(pixelBuffer: pixelBuffer, shape: shape) + default: + try self.init(shape: shape, dataType: dataType) + } + + switch dataType { + case .double: + if let value = initialValue as? Double { + let typedPointer = dataPointer.bindMemory(to: Double.self, capacity: count) + typedPointer.initialize(repeating: value, count: count) + } + case .float32: + if let value = initialValue as? Float { + let typedPointer = dataPointer.bindMemory(to: Float.self, capacity: count) + typedPointer.initialize(repeating: value, count: count) + } + case .float16: + if let value = initialValue as? FloatType { + let typedPointer = dataPointer.bindMemory(to: FloatType.self, capacity: count) + typedPointer.initialize(repeating: value, count: count) + } + case .int32: + if let value = initialValue as? Int32 { + let typedPointer = dataPointer.bindMemory(to: Int32.self, capacity: count) + typedPointer.initialize(repeating: value, count: count) + } + #if compiler(>=6.2) + case .int8: + if #available(macOS 26.0, iOS 26.0, watchOS 26.0, visionOS 26.0, tvOS 26.0, *), + let value = initialValue as? Int8 { + let typedPointer = dataPointer.bindMemory(to: Int8.self, capacity: count) + typedPointer.initialize(repeating: value, count: count) + } + #endif + @unknown default: + break + } + } + + /// Creates an MLMultiArray from an [Int] array. + /// Values are stored in the last dimension (default is dims=1). 
+ static func from(_ array: [Int], dims: Int = 1) throws -> MLMultiArray { + var shape = Array(repeating: 1, count: dims) + shape[shape.count - 1] = array.count + let output = try MLMultiArray(shape: shape as [NSNumber], dataType: .int32) + let pointer = UnsafeMutablePointer(OpaquePointer(output.dataPointer)) + for (i, item) in array.enumerated() { + pointer[i] = Int32(item) + } + return output + } +} + +// MARK: - MLMultiArray Indexing & Fill + +public extension MLMultiArray { + /// Computes the linear offset from multi-dimensional indices using strides. + @inline(__always) + func linearOffset(for index: [NSNumber], strides strideInts: [Int]? = nil) -> Int { + var linearOffset = 0 + let strideInts = strideInts ?? strides.map { $0.intValue } + for (dimension, stride) in zip(index, strideInts) { + linearOffset += dimension.intValue * stride + } + return linearOffset + } + + /// Fills a range of indices in the last dimension with a value. + /// Requires shape [1, 1, n]. + func fillLastDimension(indexes: Range, with value: FloatType) { + precondition(shape.count == 3 && shape[0] == 1 && shape[1] == 1, "Must have [1, 1, n] shape") + withUnsafeMutableBufferPointer(ofType: FloatType.self) { ptr, strides in + for index in indexes { + ptr[index * strides[2]] = value + } + } + } + + /// Fills specific multi-dimensional indices with a value. + func fill(indexes: [[NSNumber]], with value: Value) { + let pointer = UnsafeMutablePointer(OpaquePointer(dataPointer)) + let strideInts = strides.map { $0.intValue } + for index in indexes { + let linearOffset = linearOffset(for: index, strides: strideInts) + pointer[linearOffset] = value + } + } +} + +// MARK: - IOSurface-backed Pixel Buffer + +extension MLMultiArray { + /// Creates a CVPixelBuffer suitable for float16 IOSurface-backed MLMultiArrays. + public class func pixelBuffer(for shape: [NSNumber]) -> CVPixelBuffer? { + guard let width = shape.last?.intValue else { return nil } + let height = shape[0.. 
[Int] { + await shapedArray(of: Int32.self).scalars.map { Int($0) } + } + + func toFloatArray() async -> [Float] { + switch scalarType { + case is Float32.Type: + return await shapedArray(of: Float32.self).scalars.map { Float($0) } + case is FloatType.Type: + return await shapedArray(of: FloatType.self).scalars.map { Float($0) } + case is Float.Type: + return await shapedArray(of: Float.self).scalars + case is Int32.Type: + return await shapedArray(of: Int32.self).scalars.map { Float($0) } + default: + fatalError("Unsupported scalar type: \(scalarType)") + } + } + + func toMLMultiArray() async -> MLMultiArray { + switch scalarType { + case is Float32.Type: + return MLMultiArray(await shapedArray(of: Float32.self)) + case is FloatType.Type: + return MLMultiArray(await shapedArray(of: FloatType.self)) + case is Float.Type: + return MLMultiArray(await shapedArray(of: Float.self)) + case is Int32.Type: + return MLMultiArray(await shapedArray(of: Int32.self)) + default: + fatalError("Unsupported scalar type: \(scalarType)") + } + } + + // MARK: Sync (legacy — uses DispatchSemaphore, unsafe in concurrent async contexts) + + @available(*, deprecated, message: "Use await toIntArray() instead.") + func asIntArray() -> [Int] { + let semaphore = DispatchSemaphore(value: 0) + var result: [Int] = [] + Task(priority: .high) { + result = await self.toIntArray() + semaphore.signal() + } + semaphore.wait() + return result + } + + @available(*, deprecated, message: "Use await toFloatArray() instead.") + func asFloatArray() -> [Float] { + let semaphore = DispatchSemaphore(value: 0) + var result: [Float] = [] + Task(priority: .high) { + result = await self.toFloatArray() + semaphore.signal() + } + semaphore.wait() + return result + } + + @available(*, deprecated, message: "Use await toMLMultiArray() instead.") + func asMLMultiArray() -> MLMultiArray { + let semaphore = DispatchSemaphore(value: 0) + var result = try! 
MLMultiArray(shape: [1], dataType: .float16, initialValue: 0.0) + Task(priority: .high) { + result = await self.toMLMultiArray() + semaphore.signal() + } + semaphore.wait() + return result + } +} +#endif diff --git a/Sources/ArgmaxCore/ModelState.swift b/Sources/ArgmaxCore/ModelState.swift new file mode 100644 index 00000000..872c915a --- /dev/null +++ b/Sources/ArgmaxCore/ModelState.swift @@ -0,0 +1,53 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright © 2026 Argmax, Inc. All rights reserved. + +import Foundation + +// MARK: - ModelState + +/// Lifecycle state of a loaded ML model pipeline. +/// +/// Shared by both WhisperKit and TTSKit so that UI components, callbacks, and +/// utilities can reference a single canonical type. +/// +/// State machine: +/// ``` +/// unloaded → downloading → downloaded → loading → loaded +/// unloaded → prewarming → prewarmed +/// loaded → unloading → unloaded +/// ``` +@frozen +public enum ModelState: CustomStringConvertible { + case unloading + case unloaded + case loading + case loaded + case prewarming + case prewarmed + case downloading + case downloaded + + public var description: String { + switch self { + case .unloading: return "Unloading" + case .unloaded: return "Unloaded" + case .loading: return "Loading" + case .loaded: return "Loaded" + case .prewarming: return "Specializing" + case .prewarmed: return "Specialized" + case .downloading: return "Downloading" + case .downloaded: return "Downloaded" + } + } + + /// Returns `true` when a loading or downloading operation is in progress. + public var isBusy: Bool { + switch self { + case .loading, .prewarming, .downloading, .unloading: return true + default: return false + } + } +} + +/// Callback invoked when the pipeline's `modelState` changes. 
+public typealias ModelStateCallback = (_ oldState: ModelState?, _ newState: ModelState) -> Void diff --git a/Sources/ArgmaxCore/ModelUtilities.swift b/Sources/ArgmaxCore/ModelUtilities.swift new file mode 100644 index 00000000..c67cb792 --- /dev/null +++ b/Sources/ArgmaxCore/ModelUtilities.swift @@ -0,0 +1,73 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright © 2024 Argmax, Inc. All rights reserved. + +import CoreML +import Foundation + +public struct ModelUtilities { + private init() {} + + // MARK: - Model Dimension Introspection + + /// Read a dimension from a model input's multiarray constraint shape. + public static func getModelInputDimension(_ model: MLModel?, named: String, position: Int) -> Int? { + guard let inputDescription = model?.modelDescription.inputDescriptionsByName[named] else { return nil } + guard inputDescription.type == .multiArray else { return nil } + guard let shapeConstraint = inputDescription.multiArrayConstraint else { return nil } + let shape = shapeConstraint.shape.map { $0.intValue } + return shape[position] + } + + /// Read a dimension from a model output's multiarray constraint shape. + public static func getModelOutputDimension(_ model: MLModel?, named: String, position: Int) -> Int? { + guard let inputDescription = model?.modelDescription.outputDescriptionsByName[named] else { return nil } + guard inputDescription.type == .multiArray else { return nil } + guard let shapeConstraint = inputDescription.multiArrayConstraint else { return nil } + let shape = shapeConstraint.shape.map { $0.intValue } + return shape[position] + } + + @available(*, deprecated, renamed: "getModelInputDimension") + public static func getModelInputDimention(_ model: MLModel?, named: String, position: Int) -> Int? 
{ + getModelInputDimension(model, named: named, position: position) + } + + @available(*, deprecated, renamed: "getModelOutputDimension") + public static func getModelOutputDimention(_ model: MLModel?, named: String, position: Int) -> Int? { + getModelOutputDimension(model, named: named, position: position) + } + + // MARK: - Model URL Detection + + /// Detects the best available CoreML model URL in a folder, preferring + /// compiled `.mlmodelc` over `.mlpackage`. + public static func detectModelURL(inFolder path: URL, named modelName: String) -> URL { + let compiledUrl = path.appending(path: "\(modelName).mlmodelc") + let packageUrl = path.appending(path: "\(modelName).mlpackage/Data/com.apple.CoreML/model.mlmodel") + + let compiledModelExists = FileManager.default.fileExists(atPath: compiledUrl.path) + let packageModelExists = FileManager.default.fileExists(atPath: packageUrl.path) + + if packageModelExists && !compiledModelExists { + return packageUrl + } + return compiledUrl + } + + /// Scans a folder for the first CoreML model bundle when the filename is not known in advance. + /// + /// Prefers a compiled `.mlmodelc` bundle over an `.mlpackage` source bundle. + public static func detectModelURL(inFolder path: URL) -> URL? { + guard let contents = try? FileManager.default.contentsOfDirectory( + at: path, includingPropertiesForKeys: nil + ) else { return nil } + + if let compiled = contents.first(where: { $0.pathExtension == "mlmodelc" }) { + return compiled + } + if let package = contents.first(where: { $0.pathExtension == "mlpackage" }) { + return package.appending(path: "Data/com.apple.CoreML/model.mlmodel") + } + return nil + } +} diff --git a/Sources/TTSKit/Configurations.swift b/Sources/TTSKit/Configurations.swift new file mode 100644 index 00000000..a4196233 --- /dev/null +++ b/Sources/TTSKit/Configurations.swift @@ -0,0 +1,39 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright © 2026 Argmax, Inc. All rights reserved. 
+ +import CoreML +import Foundation + +// MARK: - Compute Options + +/// Per-component CoreML compute unit configuration. +/// +/// Used by `TTSKitConfig` to specify which hardware accelerators each model component +/// should use. This struct is model-agnostic; any backend with multiple CoreML components +/// that need different compute targets can adopt it. +public struct ComputeOptions: Sendable { + /// Compute units for embedding lookup models (TextProjector, CodeEmbedder, MultiCodeEmbedder). + /// Defaults to CPU-only since these are simple table lookups with minimal compute. + public var embedderComputeUnits: MLComputeUnits + + /// Compute units for CodeDecoder. Defaults to CPU + Neural Engine. + public var codeDecoderComputeUnits: MLComputeUnits + + /// Compute units for MultiCodeDecoder. Defaults to CPU + Neural Engine. + public var multiCodeDecoderComputeUnits: MLComputeUnits + + /// Compute units for SpeechDecoder. Defaults to CPU + Neural Engine. + public var speechDecoderComputeUnits: MLComputeUnits + + public init( + embedderComputeUnits: MLComputeUnits = .cpuOnly, + codeDecoderComputeUnits: MLComputeUnits = .cpuAndNeuralEngine, + multiCodeDecoderComputeUnits: MLComputeUnits = .cpuAndNeuralEngine, + speechDecoderComputeUnits: MLComputeUnits = .cpuAndNeuralEngine + ) { + self.embedderComputeUnits = embedderComputeUnits + self.codeDecoderComputeUnits = codeDecoderComputeUnits + self.multiCodeDecoderComputeUnits = multiCodeDecoderComputeUnits + self.speechDecoderComputeUnits = speechDecoderComputeUnits + } +} diff --git a/Sources/TTSKit/Generating.swift b/Sources/TTSKit/Generating.swift new file mode 100644 index 00000000..1a87aec6 --- /dev/null +++ b/Sources/TTSKit/Generating.swift @@ -0,0 +1,61 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright © 2026 Argmax, Inc. All rights reserved. + +import Foundation + +// MARK: - SpeechGenerating + +/// Protocol for a single-chunk speech synthesis task. 
+/// +/// `TTSKit` creates tasks via `setupGenerateTask(...)` and calls `run(...)` on each one. +/// Each task instance is independent - fresh KV caches, own RNG - so multiple tasks +/// can run concurrently without data races. +/// +/// The interface uses plain `String` for `voice` and `language` to stay model-agnostic. +/// Each implementation maps these to its internal representation (e.g. token IDs). +/// +/// **Reference implementation:** `Qwen3GenerateTask` in `Sources/TTSKit/Qwen3TTS/Qwen3GenerateTask.swift`. +public protocol SpeechGenerating: Sendable { + /// Default voice identifier for this model (e.g. `"ryan"` for Qwen3 TTS). + /// + /// Returned by `TTSKit.generate()` and `play()` when `voice` is `nil`. + /// Each model implementation provides its own sensible default. + var defaultVoice: String { get } + + /// Default language identifier for this model (e.g. `"english"` for Qwen3 TTS). + /// + /// Returned by `TTSKit.generate()` and `play()` when `language` is `nil`. + var defaultLanguage: String { get } + + /// Output sample rate in Hz (e.g. 24000 for Qwen3 TTS). + var sampleRate: Int { get } + + /// Number of PCM samples produced per decoded frame (e.g. 1920 for Qwen3 TTS). + var samplesPerFrame: Int { get } + + /// Minimum pre-buffer duration (seconds) in `.auto` playback mode. + var minimumBufferDuration: TimeInterval { get } + + /// Generate speech for a single text chunk, delivering per-step audio via `callback`. + /// + /// - Parameters: + /// - text: The text chunk to synthesize. + /// - voice: Voice identifier (model-specific, e.g. `Qwen3Speaker.rawValue`). + /// - language: Language identifier (model-specific, e.g. `Qwen3Language.rawValue`). + /// - options: Generation options (temperature, top-k, chunking, concurrency, etc.) + /// - callback: Called on every decoded step with audio samples and running timings. 
+ /// `SpeechProgress.stepTime` is non-nil only on the first step, allowing + /// adaptive playback buffer configuration before audio starts. + /// Return `false` to cancel; `nil` or `true` to continue. + /// - prefixCache: Optional cached prefix state to skip invariant prefill tokens. + /// - Returns: The assembled `SpeechResult` for this text chunk. + /// - Throws: `TTSError` on generation failure or task cancellation. + func run( + text: String, + voice: String, + language: String, + options: GenerationOptions, + callback: SpeechCallback, + prefixCache: TTSPromptCache? + ) async throws -> SpeechResult +} diff --git a/Sources/TTSKit/Models.swift b/Sources/TTSKit/Models.swift new file mode 100644 index 00000000..6fc53b45 --- /dev/null +++ b/Sources/TTSKit/Models.swift @@ -0,0 +1,532 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright © 2026 Argmax, Inc. All rights reserved. + +@_exported import ArgmaxCore +import CoreML +import Foundation + +// MARK: - Generation progress + +/// Per-step progress information delivered to a `SpeechCallback` during generation. +/// +/// A single callback receives both streaming audio and timing metadata on every +/// generation step. +/// +/// **Typical usage:** +/// ```swift +/// let result = try await tts.generate(text: "Hello!") { progress in +/// player.enqueue(progress.audio) // stream audio +/// if let t = progress.stepTime { configureBufferingStrategy(t) } // first-step setup +/// return nil // nil or true = continue, false = cancel +/// } +/// ``` +public struct SpeechProgress: Sendable { + /// Audio samples produced by this generation step (~80ms at 24 kHz). + public var audio: [Float] + + /// Timings accumulated so far in the current chunk. + public var timings: SpeechTimings + + /// Wall-clock duration of this generation step (seconds). 
+ /// + /// Non-`nil` **only on the first step** - use it to configure adaptive playback + /// buffers (see `PlaybackStrategy.auto`) before the first audio frame is played. + public var stepTime: TimeInterval? + + /// Zero-based index of the chunk currently being generated (for multi-chunk requests). + public var chunkIndex: Int? + + /// Total number of text chunks in this generation request. + public var totalChunks: Int? + + /// Running count of generation steps completed so far across all chunks. + /// Increases by one for every decoder step, regardless of which chunk produced it. + public var stepsCompleted: Int? + + /// Estimated total steps for the entire request (`maxNewTokens × totalChunks`). + /// Actual steps may be fewer if chunks finish early (EOS). + public var totalSteps: Int? + + public init( + audio: [Float], + timings: SpeechTimings, + stepTime: TimeInterval? = nil, + chunkIndex: Int? = nil, + totalChunks: Int? = nil, + stepsCompleted: Int? = nil, + totalSteps: Int? = nil + ) { + self.audio = audio + self.timings = timings + self.stepTime = stepTime + self.chunkIndex = chunkIndex + self.totalChunks = totalChunks + self.stepsCompleted = stepsCompleted + self.totalSteps = totalSteps + } +} + +/// A closure invoked on each generation step with decoded audio and timing info. +/// +/// Return `false` to cancel generation early; return `nil` or `true` to continue. +/// Matches the `TranscriptionCallback` pattern in WhisperKit. +public typealias SpeechCallback = (@Sendable (SpeechProgress) -> Bool?)? + +// MARK: - SpeechModel Protocol + +/// The stable public contract that every TTS model family must satisfy. +/// +/// Conform to this protocol to add a new TTS model family to TTSKit. 
///
/// **Minimum implementation:**
/// ```swift
/// public class MyTTSModel: SpeechModel {
///     public var sampleRate: Int { 24000 }
///
///     public func generate(text: String, voice: String, language: String,
///                          options: GenerationOptions, callback: SpeechCallback) async throws -> SpeechResult { ... }
///
///     public func play(text: String, voice: String, language: String,
///                      options: GenerationOptions, playbackStrategy: PlaybackStrategy,
///                      callback: SpeechCallback) async throws -> SpeechResult { ... }
/// }
/// ```
public protocol SpeechModel: AnyObject, Sendable {
    /// Output sample rate in Hz for all audio produced by this model.
    var sampleRate: Int { get }

    /// Generate speech from text and return the complete audio result.
    ///
    /// - Parameters:
    ///   - text: The text to synthesize.
    ///   - voice: Voice/speaker identifier. `nil` uses the model's default voice.
    ///   - language: Language identifier. `nil` uses the model's default language.
    ///   - options: Sampling and generation options. Model-specific fields are ignored by models
    ///     that do not support them.
    ///   - callback: Optional per-step callback receiving decoded audio chunks.
    ///     Return `false` to cancel; `nil` or `true` to continue.
    /// - Returns: The assembled `SpeechResult` containing all audio and timing data.
    /// - Throws: `TTSError` on generation failure or task cancellation.
    func generate(
        text: String,
        voice: String?,
        language: String?,
        options: GenerationOptions,
        callback: SpeechCallback
    ) async throws -> SpeechResult

    /// Generate speech and stream it through the device audio output in real time.
    ///
    /// Default implementation calls `generate` and plays via `AudioOutput`. Override
    /// if the model has its own streaming playback path.
    ///
    /// - Parameter playbackStrategy: Pre-buffering behavior before playback starts;
    ///   see `PlaybackStrategy`. Remaining parameters, return value, and thrown
    ///   errors match `generate(text:voice:language:options:callback:)`.
    func play(
        text: String,
        voice: String?,
        language: String?,
        options: GenerationOptions,
        playbackStrategy: PlaybackStrategy,
        callback: SpeechCallback
    ) async throws -> SpeechResult
}

// MARK: - Playback Strategy

/// Controls how `play()` buffers audio before starting playback.
///
/// The TTS generation loop produces one RVQ frame (~80ms of audio) per step.
/// On fast devices, steps complete well under 80ms and audio can stream immediately.
/// On slower devices (or in debug builds), steps may exceed 80ms, causing the
/// playback buffer to drain and produce choppy audio.
///
/// `.auto` (default) measures the first generation step - which runs the full
/// pipeline (MultiCodeDecoder + SpeechDecoder + CodeDecoder) before any audio
/// is emitted - and uses that exact timing to compute the pre-buffer needed so
/// playback is less likely to underrun. The buffer is re-assessed at the start of each chunk.
public enum PlaybackStrategy: Sendable {
    /// Automatically determine buffer duration from the first step's measured
    /// wall-clock time. On fast devices this resolves to a small minimum (~240ms)
    /// to absorb per-step variance. On slower devices it pre-buffers just enough
    /// to avoid underruns for the full generation. Re-assessed per chunk.
    case auto

    /// Always stream frame-by-frame with no pre-buffer.
    /// May produce choppy audio if the device can't generate faster than real-time.
    case stream

    /// Pre-buffer a fixed duration of audio before starting playback.
    case buffered(seconds: Double)

    /// Generate all audio for each chunk first, then play.
    /// Highest latency but guaranteed smooth playback.
    case generateFirst

    // MARK: - Buffer calculation

    /// Default minimum buffer duration (seconds) applied in `.auto` mode when no
    /// model-specific value is available.
    /// 80ms ≈ 1 audio frame at 24 kHz / 1920 spf - provides headroom for scheduling jitter.
    /// The `SpeechDecoding` protocol exposes `minimumBufferDuration` so each model can
    /// override this value; pass it to `requiredBuffer(minimumBuffer:)` at the call site.
    public static let minimumBufferDuration: TimeInterval = 0.08

    /// Audio duration produced by a single generation step (seconds).
    ///
    /// - Parameters:
    ///   - samplesPerFrame: PCM samples produced per decode step (model-specific).
    ///   - sampleRate: Output sample rate in Hz (model-specific).
    /// - Returns: Audio duration in seconds for one generation step.
    public static func audioPerStep(samplesPerFrame: Int, sampleRate: Int) -> Double {
        Double(samplesPerFrame) / Double(sampleRate)
    }

    /// Compute the required pre-buffer duration (seconds) given a measured step time.
    ///
    /// Called on every generation step so the buffer duration converges toward
    /// the minimum as the measured speed approaches or exceeds real-time.
    /// No additional safety margin is applied; `maxNewTokens` is already an
    /// upper bound on actual generation length, providing natural conservatism.
    ///
    /// - Parameters:
    ///   - stepTime: Measured wall-clock time of one generation step (seconds).
    ///   - maxNewTokens: Upper bound on remaining generation steps.
    ///   - samplesPerFrame: PCM samples produced per decode step (model-specific).
    ///   - sampleRate: Output sample rate in Hz (model-specific).
    ///   - minimumBuffer: Floor on the returned buffer (defaults to
    ///     `PlaybackStrategy.minimumBufferDuration`; pass
    ///     `speechDecoder.minimumBufferDuration` for a model-specific override).
    /// - Returns: Buffer duration in seconds (≥ `minimumBuffer`).
    public static func requiredBuffer(
        stepTime: TimeInterval,
        maxNewTokens: Int,
        samplesPerFrame: Int,
        sampleRate: Int,
        minimumBuffer: TimeInterval = PlaybackStrategy.minimumBufferDuration
    ) -> TimeInterval {
        let perStep = audioPerStep(samplesPerFrame: samplesPerFrame, sampleRate: sampleRate)
        // Ratio > 1 means generation outpaces real-time. A stepTime of 0 yields
        // +infinity here (IEEE 754 Double division), which collapses `deficit`
        // to 0 below, so the function safely falls back to `minimumBuffer`.
        let speedRatio = perStep / stepTime
        let deficit = max(0.0, 1.0 - speedRatio)
        let maxAudioDuration = Double(maxNewTokens) * perStep
        let deficitBuffer = maxAudioDuration * deficit
        return max(minimumBuffer, deficitBuffer)
    }
}

// MARK: - Generation Options

/// Options that control the speech synthesis pipeline.
///
/// Mirrors `DecodingOptions` in WhisperKit: all fields have sensible defaults so
/// the zero-argument initializer works for most use cases.
public struct GenerationOptions: Codable, Sendable {
    // MARK: - Defaults

    public static let defaultTemperature: Float = 0.9
    public static let defaultTopK: Int = 50
    public static let defaultRepetitionPenalty: Float = 1.05
    public static let defaultMaxNewTokens: Int = 245

    /// Sampling temperature (default `0.9`).
    public var temperature: Float
    /// Top-k sampling cutoff (default `50`).
    public var topK: Int
    /// Repetition penalty applied during sampling (default `1.05`).
    public var repetitionPenalty: Float
    /// Upper bound on autoregressive generation steps (default `245`).
    public var maxNewTokens: Int

    /// Number of concurrent workers for multi-chunk generation.
    /// - `0`: all chunks run concurrently in one batch (default, fastest for non-streaming use cases).
    /// - `1`: sequential - one chunk at a time; required for real-time `play` streaming.
    /// - `N`: at most N chunks run concurrently.
    public var concurrentWorkerCount: Int

    /// How to split long text into chunks. Defaults to `.sentence`.
    /// Set to `.none` to force a single-pass generation without sentence splitting.
    public var chunkingStrategy: TextChunkingStrategy?

    /// Target chunk size in tokens for sentence chunking.
    /// `nil` resolves to `TextChunker.defaultTargetChunkSize` at the call site.
    public var targetChunkSize: Int?

    /// Minimum chunk size in tokens.
    /// `nil` resolves to `TextChunker.defaultMinChunkSize` at the call site.
    public var minChunkSize: Int?

    /// Optional style instruction for controlling speech characteristics
    /// (e.g., `"Very happy"`). Prepended as a text-only user prompt before the main
    /// TTS segment. For Qwen3, this is only supported by the 1.7B model variant.
    public var instruction: String?

    /// Force the legacy `[FloatType]` inference path even on macOS 15+ / iOS 18+.
    /// When `false` (default), the MLTensor path is taken on supported OS versions.
    /// Set to `true` in tests to exercise the pre-macOS-15 code path on current hardware.
    // TODO: Remove forking logic with package with min os version upgrade
    public var forceLegacyEmbedPath: Bool

    public init(
        temperature: Float = GenerationOptions.defaultTemperature,
        topK: Int = GenerationOptions.defaultTopK,
        repetitionPenalty: Float = GenerationOptions.defaultRepetitionPenalty,
        maxNewTokens: Int = GenerationOptions.defaultMaxNewTokens,
        concurrentWorkerCount: Int = 0,
        chunkingStrategy: TextChunkingStrategy? = nil,
        targetChunkSize: Int? = nil,
        minChunkSize: Int? = nil,
        instruction: String? = nil,
        forceLegacyEmbedPath: Bool = false
    ) {
        self.temperature = temperature
        self.topK = topK
        self.repetitionPenalty = repetitionPenalty
        self.maxNewTokens = maxNewTokens
        self.concurrentWorkerCount = concurrentWorkerCount
        self.chunkingStrategy = chunkingStrategy
        self.targetChunkSize = targetChunkSize
        self.minChunkSize = minChunkSize
        self.instruction = instruction
        self.forceLegacyEmbedPath = forceLegacyEmbedPath
    }
}

// MARK: - Timings

/// All timing values are stored in seconds, matching `TranscriptionTimings` in WhisperKit.
public struct SpeechTimings: Codable, Sendable {
    // MARK: - Model loading

    // Wall-clock load durations; set once at startup, not accumulated by `merge`.
    public var modelLoading: TimeInterval = 0
    public var tokenizerLoadTime: TimeInterval = 0

    // MARK: - Pipeline phases

    public var tokenize: TimeInterval = 0
    public var prefill: TimeInterval = 0
    public var timeToFirstBuffer: TimeInterval = 0
    public var fullPipeline: TimeInterval = 0

    // MARK: - Generation loop

    /// Total wall-clock time spent in the autoregressive decoding loop.
    /// Mirrors `decodingLoop` in `TranscriptionTimings`.
    public var decodingLoop: TimeInterval = 0

    // MARK: - CodeDecoder

    /// Sum of `model.prediction()` call durations for the main autoregressive decoder.
    /// Mirrors `decodingPredictions` in `TranscriptionTimings`.
    public var decodingPredictions: TimeInterval = 0

    // MARK: - MultiCodeDecoder

    /// Total wall-clock time for MultiCodeDecoder across all steps.
    public var multiCodeDecoder: TimeInterval = 0
    /// Sum of `model.prediction()` calls inside MultiCodeDecoder.
    public var multiCodeDecoderPredictions: TimeInterval = 0
    public var multiCodeDecoderSampling: TimeInterval = 0
    public var multiCodeDecoderEmbedding: TimeInterval = 0
    /// Sum of KV cache update calls inside MultiCodeDecoder.
    /// Mirrors `decodingKvCaching` in `TranscriptionTimings`.
    public var decodingKvCaching: TimeInterval = 0
    // Count (not a duration) of MultiCodeDecoder prediction runs; used as the
    // per-item divisor for the "- Predictions" row in `logTimings`.
    public var totalMultiCodeDecoderPredictions: Double = 0

    // MARK: - SpeechDecoder

    /// Total wall-clock time for SpeechDecoder across all steps.
    public var speechDecoder: TimeInterval = 0
    /// Sum of `model.prediction()` calls inside SpeechDecoder.
    public var speechDecoderPredictions: TimeInterval = 0

    // MARK: - Non-model overhead

    /// Outer CodeDecoder KV cache update time.
    public var kvCacheUpdate: TimeInterval = 0
    /// CodeEmbedder lookup time per step.
    public var codeEmbed: TimeInterval = 0
    /// MultiCodeEmbedder + sum for next step.
    public var codecHidden: TimeInterval = 0
    /// TextProjector + combine-embeddings per step.
    public var textProjection: TimeInterval = 0
    /// Outer code-0 sampling time.
    /// Mirrors `decodingSampling` in `TranscriptionTimings`.
    public var decodingSampling: TimeInterval = 0

    // MARK: - Counters

    public var prefillTokens: Double = 0
    /// Total autoregressive steps. Mirrors `totalDecodingLoops` in `TranscriptionTimings`.
    public var totalDecodingLoops: Double = 0
    /// Duration of audio produced (seconds). Mirrors `inputAudioSeconds` in `TranscriptionTimings`.
    public var inputAudioSeconds: TimeInterval = 0

    public init() {}

    // MARK: - Computed properties

    /// Total wall-clock time in `model.prediction()` calls across all three models.
    /// Note: with async parallelism this sum may exceed `decodingLoop` (overlap).
    public var totalPredictions: TimeInterval {
        decodingPredictions + multiCodeDecoderPredictions + speechDecoderPredictions
    }

    /// Intra-step parallel overlap: SpeechDecoder time that ran concurrently with
    /// CodeDecoder within each generation step.
    public var concurrentStepOverlap: TimeInterval {
        max(0, totalPredictions - decodingLoop)
    }

    /// Non-inference overhead on the main path (excludes overlapped SpeechDecoder time).
    /// NOTE(review): can go negative if main-path predictions overlap each other —
    /// confirm whether clamping to 0 is desired, as done in `concurrentStepOverlap`.
    public var totalNonPrediction: TimeInterval {
        let mainPathPredictions = decodingPredictions + multiCodeDecoderPredictions
        return decodingLoop - mainPathPredictions
    }

    // 0 when `fullPipeline` has not been set, avoiding division by zero.
    public var tokensPerSecond: Double {
        fullPipeline > 0 ? totalDecodingLoops / fullPipeline : 0
    }

    /// Processing time per second of output audio; < 1 means faster than real-time.
    public var realTimeFactor: Double {
        inputAudioSeconds > 0 ? fullPipeline / inputAudioSeconds : 0
    }

    /// Inverse of `realTimeFactor`; > 1 means faster than real-time.
    public var speedFactor: Double {
        inputAudioSeconds > 0 ? inputAudioSeconds / fullPipeline : 0
    }

    // MARK: - Merge

    /// Accumulate all per-chunk timing fields from `other` into this timing.
    ///
    /// `modelLoading`, `tokenizerLoadTime`, `timeToFirstBuffer`, and `fullPipeline`
    /// are intentionally excluded - callers set those separately on the combined struct.
    public mutating func merge(_ other: SpeechTimings) {
        tokenize += other.tokenize
        prefill += other.prefill
        prefillTokens += other.prefillTokens
        decodingLoop += other.decodingLoop
        decodingPredictions += other.decodingPredictions
        multiCodeDecoder += other.multiCodeDecoder
        multiCodeDecoderPredictions += other.multiCodeDecoderPredictions
        multiCodeDecoderSampling += other.multiCodeDecoderSampling
        multiCodeDecoderEmbedding += other.multiCodeDecoderEmbedding
        decodingKvCaching += other.decodingKvCaching
        totalMultiCodeDecoderPredictions += other.totalMultiCodeDecoderPredictions
        speechDecoder += other.speechDecoder
        speechDecoderPredictions += other.speechDecoderPredictions
        kvCacheUpdate += other.kvCacheUpdate
        codeEmbed += other.codeEmbed
        codecHidden += other.codecHidden
        textProjection += other.textProjection
        decodingSampling += other.decodingSampling
        totalDecodingLoops += other.totalDecodingLoops
    }
}

// MARK: - Result

/// The complete output of a speech synthesis request.
///
/// Mirrors `TranscriptionResult` in WhisperKit: holds the audio samples, timing
/// breakdown, and sample rate. Conforms to `Codable` so results can be serialized
/// for caching or logging.
public struct SpeechResult: Codable, Sendable {
    /// Raw float audio samples (mono PCM).
    public let audio: [Float]

    /// Generation performance timings.
    public let timings: SpeechTimings

    /// Sample rate of the audio in Hz.
    public let sampleRate: Int

    /// Audio duration in seconds.
    public var audioDuration: Double {
        Double(audio.count) / Double(sampleRate)
    }

    public init(audio: [Float], timings: SpeechTimings, sampleRate: Int) {
        self.audio = audio
        self.timings = timings
        self.sampleRate = sampleRate
    }

    /// Log detailed timing breakdown to console, matching WhisperKit's format.
    public func logTimings() {
        let totalLoops = timings.totalDecodingLoops
        let mcdPredRuns = timings.totalMultiCodeDecoderPredictions
        // *1000: the percentage formatter presumably expects milliseconds —
        // TODO(review): confirm against Logging.formatTimeWithPercentage.
        let fullPipelineDuration = max(timings.decodingLoop, timings.fullPipeline) * 1000
        let formatTime = { (duration: TimeInterval, count: Double) in
            Logging.formatTimeWithPercentage(duration, count, fullPipelineDuration)
        }

        Logging.info(
            """
            ---- Speech Timings ----
            Tokenize: \(formatTime(timings.tokenize, 1))
            Prefill: \(formatTime(timings.prefill, timings.prefillTokens))
            All Predictions: \(formatTime(timings.totalPredictions, totalLoops))
            Non-inference (main): \(formatTime(timings.totalNonPrediction, totalLoops))
            Concurrent Step Overlap: \(formatTime(timings.concurrentStepOverlap, totalLoops))
            CodeDecoder: \(formatTime(timings.decodingPredictions, totalLoops))
            MultiCodeDecoder: \(formatTime(timings.multiCodeDecoder, totalLoops))
            - Predictions: \(formatTime(timings.multiCodeDecoderPredictions, mcdPredRuns))
            - Sampling: \(formatTime(timings.multiCodeDecoderSampling, totalLoops))
            - Embedding: \(formatTime(timings.multiCodeDecoderEmbedding, totalLoops))
            - KV Caching: \(formatTime(timings.decodingKvCaching, totalLoops))
            SpeechDecoder: \(formatTime(timings.speechDecoder, totalLoops))
            - Predictions: \(formatTime(timings.speechDecoderPredictions, totalLoops))
            KV Cache (outer): \(formatTime(timings.kvCacheUpdate, totalLoops))
            Code Embed (outer): \(formatTime(timings.codeEmbed, totalLoops))
            Codec Hidden: \(formatTime(timings.codecHidden, totalLoops))
            Text Projection: \(formatTime(timings.textProjection, totalLoops))
            Sampling (outer): \(formatTime(timings.decodingSampling, totalLoops))
            Decoding Loop: \(formatTime(timings.decodingLoop, totalLoops))
            -------------------------------
            Model Load Time: \(String(format: "%.2f", timings.modelLoading)) seconds
            - Tokenizer: \(String(format: "%.2f", timings.tokenizerLoadTime)) seconds
            Inference Duration (Global): \(String(format: "%.2f", timings.fullPipeline)) seconds
            Time to first buffer: \(String(format: "%.2f", timings.timeToFirstBuffer)) seconds
            Total Steps: \(Int(totalLoops))
            Steps per Second: \(String(format: "%.2f", timings.tokensPerSecond)) steps/s
            Real Time Factor: \(String(format: "%.3f", timings.realTimeFactor))
            Speed Factor: \(String(format: "%.3f", timings.speedFactor))
            Audio Duration: \(String(format: "%.2f", audioDuration)) seconds
            """
        )
    }
}

// MARK: - Decoder Result Types

/// Result from MultiCodeDecoder.generateMultiCodes
public struct MultiCodeGenerationResult {
    public let codes: [Int32]
    public let timings: SpeechTimings
    /// Pre-computed embeddings of codes 1-14 with position offsets (offsetIndex 0-13).
    /// Produced as a side effect of multi-code generation and reused for codec-hidden
    /// computation, saving 14 embed model calls per step.
    /// The embedding of code15 (offset 14) is not included and must be fetched separately.
    public let offsetCodeEmbeds: [[FloatType]]
    /// MLTensor variant of offsetCodeEmbeds - populated by the async tensor path,
    /// avoiding the [FloatType] -> MLTensor round-trip.
    @available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *)
    public var offsetCodeEmbedTensors: [MLTensor]? { _offsetCodeEmbedTensors as? [MLTensor] }

    // Type-erased storage: stored properties cannot be availability-gated, so the
    // tensors are kept as `Any?` and recovered via the gated computed accessor above.
    let _offsetCodeEmbedTensors: Any?

    public init(codes: [Int32], timings: SpeechTimings, offsetCodeEmbeds: [[FloatType]]) {
        self.codes = codes
        self.timings = timings
        self.offsetCodeEmbeds = offsetCodeEmbeds
        self._offsetCodeEmbedTensors = nil
    }

    @available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *)
    public init(codes: [Int32], timings: SpeechTimings, offsetCodeEmbedTensors: [MLTensor]) {
        self.codes = codes
        self.timings = timings
        self.offsetCodeEmbeds = []
        self._offsetCodeEmbedTensors = offsetCodeEmbedTensors
    }
}

/// Result from SpeechDecoder.decodeFrameAsync
public struct SpeechDecoderTimedResult: Sendable {
    public let samples: [Float]
    public let timings: SpeechTimings
}
diff --git a/Sources/TTSKit/Protocols.swift b/Sources/TTSKit/Protocols.swift
new file mode 100644
index 00000000..c4d16884
--- /dev/null
+++ b/Sources/TTSKit/Protocols.swift
@@ -0,0 +1,135 @@
// For licensing see accompanying LICENSE.md file.
// Copyright © 2026 Argmax, Inc. All rights reserved.

import ArgmaxCore
import CoreML
import Foundation

// MARK: - Text Projector Protocol

/// Projects text token IDs into the shared embedding space.
public protocol TextProjecting: MLModelLoading {
    var model: MLModel? { get }
    func project(tokenId: Int32) async throws -> [FloatType]
}

// MARK: - Code Embedder Protocol

/// Embeds codec-0 tokens into the shared embedding space.
public protocol CodeEmbedding: MLModelLoading {
    var model: MLModel? { get }
    func embed(tokenId: Int32) async throws -> [FloatType]
}

// MARK: - Multi-Code Embedder Protocol

/// Embeds codec-1..15 tokens (with position offsets) into the shared embedding space.
/// The two `embed` overloads differ only in return type; the MLTensor overload is
/// the macOS 15+ / iOS 18+ fast path.
public protocol MultiCodeEmbedding: MLModelLoading {
    var model: MLModel? { get }
    func embed(tokenId: Int32) async throws -> [FloatType]
    @available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *)
    func embed(tokenId: Int32) async throws -> MLTensor
}

// MARK: - Code Decoder Protocol

/// Autoregressive decoder that generates codec-0 tokens from combined embeddings.
public protocol CodeDecoding: MLModelLoading {
    var model: MLModel? { get }
    var isStateful: Bool { get }
    var kvCacheEmbedDim: Int { get }
    var kvCacheMaxSequenceLength: Int { get }
    var embedSize: Int { get }
    func decode(inputEmbeds: any EmbedInputType, cache: KVCache, state: Any?) async throws -> CodeDecoderOutput
    func makeState() -> Any?
}

// MARK: - Code Decoder Output

public struct CodeDecoderOutput {
    /// Codec-0 logits - MLTensor (macOS 15+ async path) or MLMultiArray (sync path).
    public let logits: any EmbedTensorType // [1, 1, vocabSize]
    /// Hidden states - `[FloatType]` (legacy path) or MLTensor (async). Passed to MultiCodeDecoder.
    public let hiddenStates: any EmbedTensorType
    /// KV cache updates - populated by the sync path; nil when the async decoder updates the cache internally.
    public let keyCacheUpdates: MLMultiArray?
    public let valueCacheUpdates: MLMultiArray?
    /// Time spent on KV cache update inside the decoder (async path only). Lets callers
    /// subtract this from total decode time to isolate pure prediction cost.
    public var internalCacheUpdateTime: TimeInterval = 0
}

// MARK: - Multi-Code Decoder Output

public struct MultiCodeDecoderOutput {
    /// All-head logits - MLTensor (macOS 15+ async path) or MLMultiArray (sync path).
    public let allLogits: any EmbedTensorType // [1, 15, vocabSize]
    /// KV cache updates - populated by the sync path; nil when the async decoder updates the cache internally.
    public let keyCacheUpdates: MLMultiArray?
    public let valueCacheUpdates: MLMultiArray?
}

// MARK: - Multi-Code Decoder Protocol

/// Generates codec-1..15 tokens for a single RVQ frame given hidden states and codec-0 embedding.
public protocol MultiCodeDecoding: MLModelLoading {
    var model: MLModel? { get }
    var isStateful: Bool { get }

    // MARK: - Cache geometry (read after loadModel)

    var kvCacheEmbedDim: Int { get }
    var kvCacheMaxSequenceLength: Int { get }
    var codecVocabSize: Int { get }

    // MARK: - Decoding

    func decode(inputEmbeds: any EmbedInputType, cache: KVCache, state: Any?) async throws -> MultiCodeDecoderOutput
    func makeState() -> Any?

    /// Generate codes 1-15 for one RVQ frame.
    func generateMultiCodes(
        hiddenStates: [FloatType],
        code0Embed: [FloatType],
        multiCodeEmbedder: any MultiCodeEmbedding,
        sampler: any TokenSampling,
        options: GenerationOptions
    ) async throws -> MultiCodeGenerationResult
}

// MARK: - Speech Decoder Protocol

/// Decodes RVQ code frames into audio waveform samples.
public protocol SpeechDecoding: MLModelLoading {
    var model: MLModel? { get }

    // MARK: - Audio format (model-specific)

    /// Output sample rate in Hz (e.g. 24000 for Qwen3 TTS).
    var sampleRate: Int { get }
    /// Number of PCM samples produced per decoded RVQ frame (e.g. 1920 for Qwen3 TTS).
    var samplesPerFrame: Int { get }
    /// Minimum pre-buffer duration (seconds) in `.auto` playback mode.
    var minimumBufferDuration: TimeInterval { get }

    // MARK: - Cache geometry (read after loadModel)

    var kvCacheEmbedDim: Int { get }
    var kvCacheMaxSequenceLength: Int { get }
    var hiddenDim: Int { get }
    var hiddenContextLen: Int { get }

    // MARK: - Decoding

    /// Decode a single RVQ frame (16 codes) into audio samples.
    func decodeFrame(
        codes: [Int32],
        cache: SpeechDecoderCache
    ) async throws -> [Float]

    /// Async decode that returns audio samples with wall-clock timing.
    func decodeFrameAsync(
        codes: [Int32],
        cache: SpeechDecoderCache
    ) async throws -> SpeechDecoderTimedResult
}
diff --git a/Sources/TTSKit/Qwen3TTS/Qwen3CodeDecoder.swift b/Sources/TTSKit/Qwen3TTS/Qwen3CodeDecoder.swift
new file mode 100644
index 00000000..c06e75c7
--- /dev/null
+++ b/Sources/TTSKit/Qwen3TTS/Qwen3CodeDecoder.swift
@@ -0,0 +1,190 @@
// For licensing see accompanying LICENSE.md file.
// Copyright © 2026 Argmax, Inc. All rights reserved.

import ArgmaxCore
import CoreML
import Foundation

// MARK: - Implementation

/// Autoregressive codec-0 decoder backed by a CoreML model.
///
/// Thread safety: mutable state (`model`, dimension properties) is set once during
/// `loadModel()` and read-only thereafter. `MLModel.prediction()` is thread-safe.
/// Per-generation `MLState` is created via `makeState()` and owned by each task,
/// never stored on this shared instance.
public class Qwen3CodeDecoder: CodeDecoding, @unchecked Sendable {
    public var model: MLModel?
    /// KV cache embedding dimension, detected from model at load time
    public private(set) var kvCacheEmbedDim: Int = Qwen3TTSConstants.cdCacheDim
    /// KV cache max sequence length, detected from model at load time
    public private(set) var kvCacheMaxSequenceLength: Int = Qwen3TTSConstants.cdMaxSeq
    /// Input embedding dimension
    public private(set) var embedSize: Int = Qwen3TTSConstants.embedDim

    public init() {}

    /// Load (and compile if needed) the CoreML model at `url`.
    ///
    /// - Parameters:
    ///   - url: Location of the compiled `.mlmodelc` bundle.
    ///   - computeUnits: CoreML compute-unit preference applied to the model configuration.
    ///   - prewarmMode: When `true`, the model is loaded (forcing compilation) and then
    ///     discarded without assigning `self.model` or detecting dimensions.
    public func loadModel(at url: URL, computeUnits: MLComputeUnits, prewarmMode: Bool = false) async throws {
        let modelConfig = MLModelConfiguration()
        modelConfig.computeUnits = computeUnits
        let loaded = try await MLModel.load(contentsOf: url, configuration: modelConfig)

        // In prewarm mode, compilation is complete - discard to free memory before next model compiles
        guard !prewarmMode else { return }

        self.model = loaded

        // Detect dimensions from model description
        // key_cache_updates output shape: [1, cacheDim, 1, 1]
        if let dim = ModelUtilities.getModelOutputDimension(model, named: "key_cache_updates", position: 1) {
            self.kvCacheEmbedDim = dim
        }
        // key_padding_mask input shape: [1, maxSeqLen]
        if let seq = ModelUtilities.getModelInputDimension(model, named: "key_padding_mask", position: 1) {
            self.kvCacheMaxSequenceLength = seq
        }
        // input_embeds input shape: [1, embedDim, 1, 1]
        if let dim = ModelUtilities.getModelInputDimension(model, named: "input_embeds", position: 1) {
            self.embedSize = dim
        }
    }

    /// `true` when the loaded model declares CoreML state (macOS 15+ only);
    /// always `false` before `loadModel()` or on older OS versions.
    public var isStateful: Bool {
        guard let model else { return false }
        if #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *) {
            return !model.modelDescription.stateDescriptionsByName.isEmpty
        }
        return false
    }

    /// Create a fresh MLState for a new generation session (stateful models only).
    /// Returns nil for non-stateful models. The caller owns the returned state.
    public func makeState() -> Any? {
        guard isStateful, let model else { return nil }
        if #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *) {
            return model.makeState()
        }
        return nil
    }

    /// Run one decode step, dispatching to the MLTensor path (macOS 15+ with an
    /// `MLTensor` input) or the legacy MLMultiArray path.
    /// - Throws: `TTSError.generationFailed` if the model is not loaded or the
    ///   input embed type is unsupported.
    public func decode(inputEmbeds: any EmbedInputType, cache: KVCache, state: Any? = nil) async throws -> CodeDecoderOutput {
        guard let model else { throw TTSError.generationFailed("CodeDecoder model not loaded") }

        if #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *), let tensor = inputEmbeds as? MLTensor {
            return try await decodeWithTensor(tensor, model: model, cache: cache, state: state)
        }

        guard let array = inputEmbeds as? MLMultiArray else {
            throw TTSError.generationFailed("CodeDecoder: unsupported embed input type \(type(of: inputEmbeds))")
        }
        return try await decodeWithMultiArray(array, model: model, cache: cache, state: state)
    }

    /// MLTensor path: passes `[String: MLTensor]` directly - no FeatureProvider boxing.
    @available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *)
    private func decodeWithTensor(_ inputEmbeds: MLTensor, model: MLModel, cache: KVCache, state: Any?) async throws -> CodeDecoderOutput {
        var inputs: [String: MLTensor] = [
            "input_embeds": inputEmbeds,
            "cache_length": cache.cacheLengthTensor,
            "kv_cache_update_mask": cache.kvCacheUpdateMaskTensor,
            "key_padding_mask": cache.keyPaddingMaskTensor
        ]
        // Non-stateful models carry the KV cache explicitly as model inputs.
        if !isStateful, let keyCacheTensor = cache.keyCacheTensor, let valueCacheTensor = cache.valueCacheTensor {
            inputs["key_cache"] = keyCacheTensor
            inputs["value_cache"] = valueCacheTensor
        }

        let outputs: [String: MLTensor]
        if let mlState = state as? MLState {
            outputs = try await model.prediction(from: inputs, using: mlState)
        } else {
            outputs = try await model.prediction(from: inputs)
        }

        guard let keyTensor = outputs["key_cache_updates"],
              let valueTensor = outputs["value_cache_updates"]
        else {
            throw TTSError.generationFailed("CodeDecoder: missing key/value cache update tensors")
        }

        // NOTE(review): when stateful, the update is written both into the MLState
        // and into the external KVCache mirror below — confirm the double write is intentional.
        if let mlState = state as? MLState, isStateful {
            await KVCache.updateStateCache(
                state: mlState,
                keyTensor: keyTensor,
                valueTensor: valueTensor,
                position: Int(cache.cacheLength)
            )
        }

        // Only the external-cache update is timed here; the MLState write above is
        // NOT included in `internalCacheUpdateTime`.
        let cacheUpdateStart = CFAbsoluteTimeGetCurrent()
        await cache.update(keyTensor: keyTensor, valueTensor: valueTensor)
        let cacheTime = CFAbsoluteTimeGetCurrent() - cacheUpdateStart

        guard let logitsTensor = outputs["logits"],
              let hiddenStatesTensor = outputs["hidden_states"]
        else {
            throw TTSError.generationFailed("CodeDecoder: missing logits or hidden_states tensor")
        }
        return CodeDecoderOutput(
            logits: logitsTensor,
            hiddenStates: hiddenStatesTensor,
            keyCacheUpdates: nil,
            valueCacheUpdates: nil,
            internalCacheUpdateTime: cacheTime
        )
    }

    /// MLMultiArray path: FeatureProvider-based prediction for older OS versions.
    private func decodeWithMultiArray(_ inputEmbeds: MLMultiArray, model: MLModel, cache: KVCache, state: Any?) async throws -> CodeDecoderOutput {
        var dict: [String: MLFeatureValue] = try [
            "input_embeds": MLFeatureValue(multiArray: inputEmbeds),
            "cache_length": MLFeatureValue(multiArray: cache.makeCacheLengthArray()),
            "kv_cache_update_mask": MLFeatureValue(multiArray: cache.kvCacheUpdateMask),
            "key_padding_mask": MLFeatureValue(multiArray: cache.keyPaddingMask)
        ]
        // Non-stateful models carry the KV cache explicitly as model inputs.
        if !isStateful, let keyCache = cache.keyCache, let valueCache = cache.valueCache {
            dict["key_cache"] = MLFeatureValue(multiArray: keyCache)
            dict["value_cache"] = MLFeatureValue(multiArray: valueCache)
        }

        let input = try MLDictionaryFeatureProvider(dictionary: dict)
        let output: MLFeatureProvider
        if #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *), let mlState = state as? MLState {
            output = try await model.asyncPrediction(from: input, using: mlState)
        } else {
            output = try await model.asyncPrediction(from: input)
        }

        guard let keyCacheUpdates = output.featureValue(for: "key_cache_updates")?.multiArrayValue,
              let valueCacheUpdates = output.featureValue(for: "value_cache_updates")?.multiArrayValue
        else {
            throw TTSError.generationFailed("CodeDecoder: missing key/value cache update arrays")
        }

        // Unlike the tensor path, this path does not update the external KVCache —
        // the updates are returned to the caller via CodeDecoderOutput instead.
        if #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *), let mlState = state as? MLState, isStateful {
            KVCache.updateStateCache(
                state: mlState,
                keyCacheUpdates: keyCacheUpdates,
                valueCacheUpdates: valueCacheUpdates,
                position: Int(cache.cacheLength)
            )
        }

        guard let logitsArray = output.featureValue(for: "logits")?.multiArrayValue,
              let hiddenStatesArray = output.featureValue(for: "hidden_states")?.multiArrayValue
        else {
            throw TTSError.generationFailed("CodeDecoder: missing logits or hidden_states array")
        }
        return CodeDecoderOutput(
            logits: logitsArray,
            hiddenStates: EmbedUtilities.extractEmbed(from: hiddenStatesArray),
            keyCacheUpdates: keyCacheUpdates,
            valueCacheUpdates: valueCacheUpdates
        )
    }

    /// Release the CoreML model; `loadModel(at:computeUnits:prewarmMode:)` must be
    /// called again before the next `decode`.
    public func unloadModel() {
        model = nil
    }
}
diff --git a/Sources/TTSKit/Qwen3TTS/Qwen3Config.swift b/Sources/TTSKit/Qwen3TTS/Qwen3Config.swift
new file mode 100644
index 00000000..78a0d90c
--- /dev/null
+++ b/Sources/TTSKit/Qwen3TTS/Qwen3Config.swift
@@ -0,0 +1,372 @@
// For licensing see accompanying LICENSE.md file.
// Copyright © 2026 Argmax, Inc. All rights reserved.

import ArgmaxCore
import CoreML
import Foundation

// MARK: - Model Family

/// Identifies the TTS model architecture. Used by `setupPipeline`, `loadModels`,
/// and `setupGenerateTask` to dispatch to the correct model-specific code paths.
@frozen
public enum TTSModelFamily: String, Sendable {
    case qwen3
    // case kokoro // future
}

// MARK: - Model Variants

/// Pre-configured TTS model sizes with matching variant defaults.
///
/// Pass a variant to `TTSKitConfig(model:)` to select the desired model size.
/// Mirrors `ModelVariant` in WhisperKit: `@frozen`, `CustomStringConvertible`, `CaseIterable`.
@frozen
public enum TTSModelVariant: String, CustomStringConvertible, CaseIterable, Sendable {
    case qwen3TTS_0_6b = "0.6b"
    case qwen3TTS_1_7b = "1.7b"

    /// The model architecture family this variant belongs to.
+ public var family: TTSModelFamily { + switch self { + case .qwen3TTS_0_6b, .qwen3TTS_1_7b: return .qwen3 + } + } + + public var description: String { + switch self { + case .qwen3TTS_0_6b: return "Qwen3-TTS-0.6B" + case .qwen3TTS_1_7b: return "Qwen3-TTS-1.7B" + } + } + + /// Display name suitable for UI presentation. + public var displayName: String { + switch self { + case .qwen3TTS_0_6b: return "Qwen3 TTS 0.6B" + case .qwen3TTS_1_7b: return "Qwen3 TTS 1.7B" + } + } + + /// Whether this model supports the voice instruction prompt. + /// Only the 1.7B variant has the capacity to follow style instructions. + public var supportsVoiceDirection: Bool { + switch self { + case .qwen3TTS_0_6b: return false + case .qwen3TTS_1_7b: return true + } + } + + /// Returns `true` when this variant can run on the current platform. + /// + /// The 1.7B model requires more peak memory during CoreML compilation than + /// iOS/iPadOS devices can reliably provide, so it is restricted to macOS. + public var isAvailableOnCurrentPlatform: Bool { + #if os(macOS) + return true + #else + switch self { + case .qwen3TTS_0_6b: return true + case .qwen3TTS_1_7b: return false + } + #endif + } + + /// The best default variant for the current platform. + public static var defaultForCurrentPlatform: TTSModelVariant { + return .qwen3TTS_0_6b + } + + public var versionDir: String { + switch self { + case .qwen3TTS_0_6b: return "12hz-0.6b-customvoice" + case .qwen3TTS_1_7b: return "12hz-1.7b-customvoice" + } + } + + /// Recommended code decoder variant for this model size. + public var codeDecoderVariant: String { Qwen3VariantDefaults.codeDecoder } + /// Recommended multi-code decoder variant for this model size. + public var multiCodeDecoderVariant: String { Qwen3VariantDefaults.multiCodeDecoder } + /// Code embedder variant (same across model sizes). + public var codeEmbedderVariant: String { Qwen3VariantDefaults.codeEmbedder } + /// Multi-code embedder variant (same across model sizes). 
+ public var multiCodeEmbedderVariant: String { Qwen3VariantDefaults.multiCodeEmbedder } + /// Text projector variant (same across model sizes). + public var textProjectorVariant: String { Qwen3VariantDefaults.textProjector } + /// Recommended speech decoder variant for this model size. + public var speechDecoderVariant: String { Qwen3VariantDefaults.speechDecoder } + /// HuggingFace repo used to load the tokenizer for this model size. + public var tokenizerRepo: String { Qwen3TTSConstants.defaultTokenizerRepo } +} + +// MARK: - Variant Defaults + +/// Default quantization variant strings matching the standard model repository layout. +public enum Qwen3VariantDefaults { + public static let codeDecoder = "W8A16-stateful" + public static let multiCodeDecoder = "W8A16" + public static let codeEmbedder = "W16A16" + public static let multiCodeEmbedder = "W16A16" + public static let textProjector = "W8A16" + public static let speechDecoder = "W8A16" +} + +// MARK: - TTSKit Configuration + +/// Configuration for initializing a `TTSKit` instance. +/// +/// Mirrors `WhisperKitConfig`: an `open class` so you can subclass for custom +/// configurations without modifying `TTSKit` itself. +/// +/// **Minimal usage** (auto-downloads models and tokenizer): +/// ```swift +/// let tts = try await TTSKit() +/// ``` +/// +/// **Local models** (skip download): +/// ```swift +/// let config = TTSKitConfig(modelFolder: URL(fileURLWithPath: "/path/to/models")) +/// let tts = try await TTSKit(config) +/// ``` +/// +/// **Expected local directory layout:** +/// ``` +/// modelFolder/ +/// └── qwen3_tts/ +/// ├── code_decoder///*.mlmodelc +/// ├── multi_code_decoder///*.mlmodelc +/// ├── code_embedder///*.mlmodelc +/// ├── multi_code_embedder///*.mlmodelc +/// ├── text_projector///*.mlmodelc +/// └── speech_decoder///*.mlmodelc +/// ``` +open class TTSKitConfig { + // MARK: - Model selection + + /// Model variant that determines default versionDir, component variants, and tokenizer. 
    public var model: TTSModelVariant

    // MARK: - Model location

    /// Local URL to the model repository directory.
    /// If `nil`, models are downloaded from `modelRepo` on HuggingFace Hub.
    public var modelFolder: URL?

    /// Base URL for downloading and caching models.
    /// `nil` uses the Hub library's default cache directory.
    public var downloadBase: URL?

    /// HuggingFace repo ID for auto-downloading models.
    public var modelRepo: String

    // MARK: - Tokenizer

    /// Local folder URL for the tokenizer files.
    /// When `nil`, the tokenizer is resolved from `model`'s default HuggingFace repo
    /// (see `tokenizerSource`).
    public var tokenizerFolder: URL?

    // MARK: - Authentication

    /// HuggingFace API token for private repos (or set the `HF_TOKEN` env var).
    public var modelToken: String?

    /// HuggingFace Hub endpoint URL.
    ///
    /// Override to point at a regional mirror or an on-premise Hub instance.
    /// Mirrors `WhisperKitConfig` (via `Constants.defaultRemoteEndpoint`).
    public var modelEndpoint: String

    // MARK: - Component variants

    /// Version directory shared across all components (resolved from `model` by default).
    public var versionDir: String

    /// Per-component quantization variant (resolved from `model` by default).
    public var codeDecoderVariant: String
    public var multiCodeDecoderVariant: String
    public var codeEmbedderVariant: String
    public var multiCodeEmbedderVariant: String
    public var textProjectorVariant: String
    public var speechDecoderVariant: String

    // MARK: - Compute

    /// Compute unit configuration per model component.
    public var computeOptions: ComputeOptions

    // MARK: - Logging

    /// Whether to emit diagnostic logs during loading and generation.
    public var verbose: Bool

    /// Logging level when `verbose` is `true`. Defaults to `.info` (see `init`).
    public var logLevel: Logging.LogLevel

    // MARK: - Download options

    /// Specific git revision (commit SHA, tag, or branch) to download from the Hub.
    /// `nil` (default) resolves to the repo's default branch head.
    /// Mirrors `WhisperKit.download(variant:revision:)`.
    public var downloadRevision: String?

    /// Additional glob patterns to include during model download, appended to the
    /// patterns generated from the configured component variants.
    public var downloadAdditionalPatterns: [String]

    /// Use a background `URLSession` for model downloads.
    /// Mirrors `WhisperKitConfig.useBackgroundDownloadSession`.
    public var useBackgroundDownloadSession: Bool

    /// Download models if not already available locally.
    /// When `true` (default), `loadModels()` will trigger a download if `modelFolder` is nil.
    public var download: Bool

    // MARK: - Lifecycle flags

    /// Enable model prewarming.
    ///
    /// Prewarming compiles each CoreML model sequentially then discards weights,
    /// minimizing peak memory during compilation. Call before `loadModels()` on first
    /// launch or after a model update. Mirrors `WhisperKitConfig.prewarm`.
    public var prewarm: Bool?

    /// Load models immediately after init.
    /// `nil` loads when `modelFolder` is non-nil, matching WhisperKit's default.
    public var load: Bool?

    // MARK: - Generation

    /// Optional seed for reproducible generation.
    /// Each concurrent task receives a derived seed (`seed ^ taskIndex`).
    public var seed: UInt64?

    // MARK: - Component overrides

    /// Set any of these to substitute a custom implementation for that model component.
    /// `nil` means TTSKit will use the default Qwen3 TTS class for that component.
    ///
    /// Example:
    /// ```swift
    /// let config = TTSKitConfig()
    /// config.codeDecoder = MyCodeDecoder()
    /// let tts = try await TTSKit(config)
    /// ```
    public var textProjector: (any TextProjecting)?
    public var codeEmbedder: (any CodeEmbedding)?
    public var multiCodeEmbedder: (any MultiCodeEmbedding)?
    public var codeDecoder: (any CodeDecoding)?
    public var multiCodeDecoder: (any MultiCodeDecoding)?
    public var speechDecoder: (any SpeechDecoding)?

    /// Creates a configuration.
    ///
    /// Every parameter has a sensible default; pass only what you need to override.
    /// `nil` for `versionDir` or any `*Variant` parameter resolves to the value
    /// recommended by `model` (see `TTSModelVariant`).
    public init(
        model: TTSModelVariant = .qwen3TTS_0_6b,
        modelFolder: URL? = nil,
        downloadBase: URL? = nil,
        modelRepo: String = Qwen3TTSConstants.defaultModelRepo,
        tokenizerFolder: URL? = nil,
        modelToken: String? = nil,
        modelEndpoint: String = Qwen3TTSConstants.defaultEndpoint,
        versionDir: String? = nil,
        codeDecoderVariant: String? = nil,
        multiCodeDecoderVariant: String? = nil,
        codeEmbedderVariant: String? = nil,
        multiCodeEmbedderVariant: String? = nil,
        textProjectorVariant: String? = nil,
        speechDecoderVariant: String? = nil,
        computeOptions: ComputeOptions = ComputeOptions(),
        verbose: Bool = true,
        logLevel: Logging.LogLevel = .info,
        downloadRevision: String? = nil,
        downloadAdditionalPatterns: [String] = [],
        useBackgroundDownloadSession: Bool = false,
        download: Bool = true,
        prewarm: Bool? = nil,
        load: Bool? = nil,
        seed: UInt64? = nil
    ) {
        self.model = model
        self.modelFolder = modelFolder
        self.downloadBase = downloadBase
        self.modelRepo = modelRepo
        self.tokenizerFolder = tokenizerFolder
        self.modelToken = modelToken
        self.modelEndpoint = modelEndpoint
        // nil overrides fall back to the model variant's recommended values.
        self.versionDir = versionDir ?? model.versionDir
        self.codeDecoderVariant = codeDecoderVariant ?? model.codeDecoderVariant
        self.multiCodeDecoderVariant = multiCodeDecoderVariant ?? model.multiCodeDecoderVariant
        self.codeEmbedderVariant = codeEmbedderVariant ?? model.codeEmbedderVariant
        self.multiCodeEmbedderVariant = multiCodeEmbedderVariant ?? model.multiCodeEmbedderVariant
        self.textProjectorVariant = textProjectorVariant ?? model.textProjectorVariant
        self.speechDecoderVariant = speechDecoderVariant ?? model.speechDecoderVariant
        self.computeOptions = computeOptions
        self.verbose = verbose
        self.logLevel = logLevel
        self.downloadRevision = downloadRevision
        self.downloadAdditionalPatterns = downloadAdditionalPatterns
        self.useBackgroundDownloadSession = useBackgroundDownloadSession
        self.download = download
        self.prewarm = prewarm
        self.load = load
        self.seed = seed
    }

    // MARK: - Path resolution

    /// Resolve the full path to a component's model bundle.
    ///
    /// Requires `modelFolder` to be set (either directly or via download).
    /// Delegates to `ModelUtilities.detectModelURL(inFolder:)`, which prefers
    /// a compiled `.mlmodelc` bundle and falls back to `.mlpackage`.
    public func modelURL(component: String, variant: String) -> URL? {
        guard let modelFolder else { return nil }

        // Layout: modelFolder/qwen3_tts/<component>/<versionDir>/<variant>
        let variantDir = modelFolder.appending(path: Qwen3TTSConstants.modelFamilyDir)
            .appending(path: component).appending(path: versionDir)
            .appending(path: variant)

        return ModelUtilities.detectModelURL(inFolder: variantDir)
    }

    /// The effective tokenizer source: local `tokenizerFolder` if set, otherwise the
    /// model's default HuggingFace repo ID.
    public var tokenizerSource: String {
        tokenizerFolder?.path ?? model.tokenizerRepo
    }

    /// Component names in the model layout.
    public static let componentNames = [
        "text_projector", "code_embedder", "multi_code_embedder",
        "code_decoder", "multi_code_decoder", "speech_decoder"
    ]

    /// Version-specific directory for each component inside `modelFolder`.
    ///
    /// e.g. `modelFolder/qwen3_tts/code_decoder/12hz-0.6b-customvoice`
    ///
    /// Useful for targeted deletion or disk-size calculation of a single variant.
    /// Returns `[]` when neither `folder` nor `modelFolder` is available.
    public func componentDirectories(in folder: URL? = nil) -> [URL] {
        guard let base = folder ?? modelFolder else { return [] }
        return Self.componentNames.map { component in
            base
                .appending(path: Qwen3TTSConstants.modelFamilyDir)
                .appending(path: component)
                .appending(path: versionDir)
        }
    }

    /// Glob patterns used to download only the files needed for the configured variants.
    public var downloadPatterns: [String] {
        // Kept in sync with `componentNames`; each entry pairs a component directory
        // with its configured quantization variant.
        let components: [(String, String)] = [
            ("text_projector", textProjectorVariant),
            ("code_embedder", codeEmbedderVariant),
            ("multi_code_embedder", multiCodeEmbedderVariant),
            ("code_decoder", codeDecoderVariant),
            ("multi_code_decoder", multiCodeDecoderVariant),
            ("speech_decoder", speechDecoderVariant)
        ]
        return components.map {
            "\(Qwen3TTSConstants.modelFamilyDir)/\($0.0)/\(versionDir)/\($0.1)/**"
        }
    }
}
diff --git a/Sources/TTSKit/Qwen3TTS/Qwen3Embedders.swift b/Sources/TTSKit/Qwen3TTS/Qwen3Embedders.swift
new file mode 100644
index 00000000..9085a5e8
--- /dev/null
+++ b/Sources/TTSKit/Qwen3TTS/Qwen3Embedders.swift
@@ -0,0 +1,103 @@
// For licensing see accompanying LICENSE.md file.
// Copyright © 2026 Argmax, Inc. All rights reserved.

import ArgmaxCore
import CoreML
import Foundation

// MARK: - Code Embedder Implementation

/// Codec-0 token embedder backed by a CoreML model.
///
/// Thread safety: per-call input tensors - safe for concurrent use from multiple tasks.
public class Qwen3CodeEmbedder: CodeEmbedding, @unchecked Sendable {
    // Loaded CoreML model; nil until `loadModel(at:computeUnits:prewarmMode:)` succeeds.
    public var model: MLModel?

    public init() {}

    /// Loads (and compiles if needed) the CoreML model at `url`.
    ///
    /// In `prewarmMode` the model is loaded to force CoreML compilation/caching and
    /// then intentionally discarded (`self.model` stays nil), minimizing resident memory.
    public func loadModel(at url: URL, computeUnits: MLComputeUnits, prewarmMode: Bool = false) async throws {
        let modelConfig = MLModelConfiguration()
        modelConfig.computeUnits = computeUnits
        let loaded = try await MLModel.load(contentsOf: url, configuration: modelConfig)

        // Prewarm: compile + cache only, don't retain weights.
        guard !prewarmMode else { return }

        self.model = loaded
    }

    /// Legacy path: embed a single codec-0 token id via MLMultiArray feature providers.
    /// - Throws: `TTSError.generationFailed` when the model is not loaded or the
    ///   `input_embeds` output is missing.
    public func embed(tokenId: Int32) async throws -> [FloatType] {
        guard let model else { throw TTSError.generationFailed("CodeEmbedder model not loaded") }
        let ids = try EmbedUtilities.makeInt32Array([tokenId])
        let provider = try MLDictionaryFeatureProvider(dictionary: ["input_ids": MLFeatureValue(multiArray: ids)])
        let output = try await model.asyncPrediction(from: provider)
        guard let embedArray = output.featureValue(for: "input_embeds")?.multiArrayValue else {
            throw TTSError.generationFailed("CodeEmbedder: missing input_embeds output")
        }
        return EmbedUtilities.extractEmbed(from: embedArray)
    }

    /// Optimised async path: passes `[String: MLTensor]` directly - no FeatureProvider boxing.
    @available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *)
    public func embed(tokenId: Int32) async throws -> MLTensor {
        guard let model else { throw TTSError.generationFailed("CodeEmbedder model not loaded") }
        let outputs = try await model.prediction(from: [
            "input_ids": MLTensor(shape: [1], scalars: [tokenId])
        ])
        guard let embedTensor = outputs["input_embeds"] else {
            throw TTSError.generationFailed("CodeEmbedder: missing input_embeds tensor output")
        }
        return embedTensor
    }

    /// Releases the CoreML model; `embed` calls will throw until reloaded.
    public func unloadModel() {
        model = nil
    }
}

// MARK: - Multi-Code Embedder Implementation

/// Codec-1..15 token embedder backed by a CoreML model.
///
/// Thread safety: same as Qwen3CodeEmbedder - per-call input tensors, safe for concurrent use.
/// NOTE(review): implementation is a line-for-line duplicate of `Qwen3CodeEmbedder`
/// apart from error-message prefixes — consider a shared base once protocols settle.
public class Qwen3MultiCodeEmbedder: MultiCodeEmbedding, @unchecked Sendable {
    // Loaded CoreML model; nil until `loadModel(at:computeUnits:prewarmMode:)` succeeds.
    public var model: MLModel?

    public init() {}

    /// Loads (and compiles if needed) the CoreML model at `url`.
    ///
    /// In `prewarmMode` the model is loaded to force CoreML compilation/caching and
    /// then intentionally discarded (`self.model` stays nil), minimizing resident memory.
    public func loadModel(at url: URL, computeUnits: MLComputeUnits, prewarmMode: Bool = false) async throws {
        let modelConfig = MLModelConfiguration()
        modelConfig.computeUnits = computeUnits
        let loaded = try await MLModel.load(contentsOf: url, configuration: modelConfig)

        // Prewarm: compile + cache only, don't retain weights.
        guard !prewarmMode else { return }

        self.model = loaded
    }

    /// Legacy path: embed a single (offset) multi-code token id via MLMultiArray providers.
    /// - Throws: `TTSError.generationFailed` when the model is not loaded or the
    ///   `input_embeds` output is missing.
    public func embed(tokenId: Int32) async throws -> [FloatType] {
        guard let model else { throw TTSError.generationFailed("MultiCodeEmbedder model not loaded") }
        let ids = try EmbedUtilities.makeInt32Array([tokenId])
        let provider = try MLDictionaryFeatureProvider(dictionary: ["input_ids": MLFeatureValue(multiArray: ids)])
        let output = try await model.asyncPrediction(from: provider)
        guard let embedArray = output.featureValue(for: "input_embeds")?.multiArrayValue else {
            throw TTSError.generationFailed("MultiCodeEmbedder: missing input_embeds output")
        }
        return EmbedUtilities.extractEmbed(from: embedArray)
    }

    /// Optimised async path: passes `[String: MLTensor]` directly - no FeatureProvider boxing.
    @available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *)
    public func embed(tokenId: Int32) async throws -> MLTensor {
        guard let model else { throw TTSError.generationFailed("MultiCodeEmbedder model not loaded") }
        let outputs = try await model.prediction(from: [
            "input_ids": MLTensor(shape: [1], scalars: [tokenId])
        ])
        guard let embedTensor = outputs["input_embeds"] else {
            throw TTSError.generationFailed("MultiCodeEmbedder: missing input_embeds tensor output")
        }
        return embedTensor
    }

    /// Releases the CoreML model; `embed` calls will throw until reloaded.
    public func unloadModel() {
        model = nil
    }
}
diff --git a/Sources/TTSKit/Qwen3TTS/Qwen3GenerateTask.swift b/Sources/TTSKit/Qwen3TTS/Qwen3GenerateTask.swift
new file mode 100644
index 00000000..240ab322
--- /dev/null
+++ b/Sources/TTSKit/Qwen3TTS/Qwen3GenerateTask.swift
@@ -0,0 +1,838 @@
// For licensing see accompanying LICENSE.md file.
// Copyright © 2026 Argmax, Inc. All rights reserved.

import ArgmaxCore
import CoreML
import Foundation
import Tokenizers

// MARK: - Internal phase-result types

/// Output of the tokenize phase - fed into prefill and the generation loop.
struct TokenizeResult {
    let textTokenIds: [Int32]
    let trailingTextTokens: [Int32]
    let firstTextEmbed: [FloatType]
    let variableEmbed: [FloatType]
    let textPadEmbed: [FloatType]
    let timings: SpeechTimings
}

/// Output of the prefill phase - fed into the generation loop.
struct PrefillResult {
    let cdCache: KVCache
    let lastCdOutput: CodeDecoderOutput
    let timings: SpeechTimings
}

/// Output of the autoregressive generation loop.
struct GenerationLoopResult {
    let audio: [Float]
    let steps: Int
    let timings: SpeechTimings
}

// MARK: - Qwen3GenerateTask

/// Qwen3 TTS single-chunk generation task.
///
/// The core building block for Qwen3 TTS generation, analogous to `TranscribeTask`
/// in WhisperKit. Each task creates its own KV caches, MLState, and sampler, then
/// runs a complete prefill + autoregressive decode cycle.
///
/// Conforms to `SpeechGenerating` so it can be returned by
/// `TTSKit.setupGenerateTask(...)` and consumed by `TTSKit`'s generic
/// orchestration layer.
///
/// Thread safety: all stored properties are `let` (immutable after init). Each task
/// owns its own sampler (derived seed) so concurrent tasks don't share RNG state.
/// Model components are shared read-only references - `MLModel.prediction()` is
/// thread-safe. The class is `@unchecked Sendable` to permit `open` subclassing.
open class Qwen3GenerateTask: @unchecked Sendable, SpeechGenerating {
    /// Model components - concrete Qwen3 types for correct async method dispatch.
    /// Using `any Protocol` existentials would cause async extension methods to dispatch
    /// to the protocol default (sync path) instead of the Qwen3-specific MLTensor path.
    public let textProjector: Qwen3TextProjector
    public let codeEmbedder: Qwen3CodeEmbedder
    public let multiCodeEmbedder: Qwen3MultiCodeEmbedder
    public let codeDecoder: Qwen3CodeDecoder
    public let multiCodeDecoder: Qwen3MultiCodeDecoder
    public let speechDecoder: Qwen3SpeechDecoder
    public let sampler: any TokenSampling
    public let tokenizer: any Tokenizer
    /// Token ids excluded from codec-0 sampling.
    /// FIX: bare `Set` (no generic argument) does not compile; restored as `Set<Int32>`
    /// to match the `Int32` token ids used throughout this file.
    /// NOTE(review): confirm element type against `TokenSampling.sampleCodec0`.
    public let suppressTokenIds: Set<Int32>

    /// Timings captured at model-load time (modelLoading + tokenizerLoading populated).
    public let loadTimings: SpeechTimings

    /// Progress object for tracking generation. `totalUnitCount` is set to
    /// `maxNewTokens` at the start of generation; `completedUnitCount` is
    /// updated after each decoding step.
    public let progress: Progress

    // MARK: - Initialization

    public init(
        textProjector: Qwen3TextProjector,
        codeEmbedder: Qwen3CodeEmbedder,
        multiCodeEmbedder: Qwen3MultiCodeEmbedder,
        codeDecoder: Qwen3CodeDecoder,
        multiCodeDecoder: Qwen3MultiCodeDecoder,
        speechDecoder: Qwen3SpeechDecoder,
        sampler: any TokenSampling,
        tokenizer: any Tokenizer,
        suppressTokenIds: Set<Int32>,
        loadTimings: SpeechTimings = SpeechTimings(),
        progress: Progress? = nil
    ) {
        self.textProjector = textProjector
        self.codeEmbedder = codeEmbedder
        self.multiCodeEmbedder = multiCodeEmbedder
        self.codeDecoder = codeDecoder
        self.multiCodeDecoder = multiCodeDecoder
        self.speechDecoder = speechDecoder
        self.sampler = sampler
        self.tokenizer = tokenizer
        self.suppressTokenIds = suppressTokenIds
        self.loadTimings = loadTimings
        self.progress = progress ?? Progress()
    }

    // MARK: - SpeechGenerating defaults

    /// Default voice for Qwen3 TTS. Matches `Qwen3Speaker.ryan`.
    public var defaultVoice: String { Qwen3Speaker.ryan.rawValue }

    /// Default language for Qwen3 TTS. Matches `Qwen3Language.english`.
    public var defaultLanguage: String { Qwen3Language.english.rawValue }

    // MARK: - Audio format (forwarded from speechDecoder)

    /// Output sample rate in Hz, as reported by the speech decoder.
    public var sampleRate: Int { speechDecoder.sampleRate }
    /// Audio samples produced per decoded RVQ frame.
    public var samplesPerFrame: Int { speechDecoder.samplesPerFrame }
    /// Minimum buffered duration before playback should start.
    public var minimumBufferDuration: TimeInterval { speechDecoder.minimumBufferDuration }

    // MARK: - Run

    /// Generate speech for a single text segment.
    ///
    /// Creates fresh KV caches, runs prefill, then autoregressive generation with
    /// interleaved SpeechDecoder audio output. Safe to call concurrently from
    /// multiple tasks against the same model instances.
    ///
    /// - Parameters:
    ///   - text: The text to synthesize.
    ///   - voice: Raw string matching `Qwen3Speaker.rawValue`; falls back to `.ryan`.
    ///   - language: Raw string matching `Qwen3Language.rawValue`; falls back to `.english`.
    ///   - options: Generation options (temperature, top-k, etc.)
    ///   - callback: Per-step callback receiving decoded audio and running timings.
    ///     `TTSProgress.stepTime` is non-nil only on the first step.
    ///     Return `false` to cancel; `nil` or `true` to continue.
    ///   - prefixCache: Optional cached prefix state to skip invariant prefill tokens.
    /// - Returns: A `SpeechResult` containing the complete audio and timings for this chunk.
    /// - Throws: `TTSError` on generation failure or task cancellation.
    open func run(
        text: String,
        voice: String,
        language: String,
        options: GenerationOptions,
        callback: SpeechCallback,
        prefixCache: TTSPromptCache? = nil
    ) async throws -> SpeechResult {
        // Unknown voice/language strings degrade gracefully to the defaults.
        let qwen3Speaker = Qwen3Speaker(rawValue: voice) ?? .ryan
        let lang = Qwen3Language(rawValue: language) ?? .english

        var timings = loadTimings
        let pipelineStart = CFAbsoluteTimeGetCurrent()

        progress.totalUnitCount = Int64(options.maxNewTokens)
        progress.completedUnitCount = 0

        // Create task-local MLState for stateful decoders (nil for non-stateful)
        let cdState = codeDecoder.makeState()

        // Phase 1: Tokenize text and build initial embeddings
        let tokenizeResult = try await tokenizeAndBuildEmbeds(text: text)
        timings.merge(tokenizeResult.timings)

        // Phase 2: Prefill the CodeDecoder with the prompt prefix
        let prefillResult = try await prefillCodeDecoder(
            tokenizeResult: tokenizeResult,
            speaker: qwen3Speaker, lang: lang,
            options: options,
            prefixCache: prefixCache,
            voice: voice, language: language,
            cdState: cdState
        )
        timings.merge(prefillResult.timings)

        // Phase 3: Autoregressive RVQ generation with interleaved audio decode
        let loopResult = try await runGenerationLoop(
            tokenizeResult: tokenizeResult,
            prefillResult: prefillResult,
            cdState: cdState,
            options: options,
            pipelineStart: pipelineStart,
            callback: callback,
            baseTimings: timings
        )
        timings.merge(loopResult.timings)
        // NOTE(review): explicit assignment after merge suggests `SpeechTimings.merge`
        // does not propagate timeToFirstBuffer — confirm against its implementation.
        timings.timeToFirstBuffer = loopResult.timings.timeToFirstBuffer

        timings.fullPipeline = CFAbsoluteTimeGetCurrent() - pipelineStart
        timings.inputAudioSeconds = Double(loopResult.audio.count) / Double(speechDecoder.sampleRate)

        progress.completedUnitCount = progress.totalUnitCount

        let genMs = timings.decodingLoop * 1000
        let avgMs = loopResult.steps > 0 ? genMs / Double(loopResult.steps) : 0
        let stepsPerSec = loopResult.steps > 0 ? Double(loopResult.steps) / timings.decodingLoop : 0
        Logging.info(
            String(
                format: "Generation: %d frames in %.1fms (%.1fms/step, %.1f frames/s)",
                loopResult.steps, genMs, avgMs, stepsPerSec
            ))

        return SpeechResult(audio: loopResult.audio, timings: timings, sampleRate: speechDecoder.sampleRate)
    }

    // MARK: - Phase 1: Tokenize

    /// Tokenize `text` and pre-compute the initial embeddings needed for prefill and decoding.
    /// - Throws: `TTSError.emptyText` when tokenization yields no tokens.
    private func tokenizeAndBuildEmbeds(
        text: String
    ) async throws -> TokenizeResult {
        let start = CFAbsoluteTimeGetCurrent()

        let textTokenIds = tokenizer.encode(text: text).map { Int32($0) }
        guard !textTokenIds.isEmpty else { throw TTSError.emptyText }

        // The first text token is combined with the codec BOS embedding to form the
        // "variable" embedding decoded after the (cacheable) invariant prompt prefix.
        let firstTextEmbed = try await textProjector.project(tokenId: textTokenIds[0])
        let codecBOSEmbed = try await codeEmbedder.embed(tokenId: Qwen3TTSConstants.codecBOS)
        let variableEmbed = EmbedUtilities.addEmbeddings(firstTextEmbed, codecBOSEmbed)
        // PAD embedding reused each step once the text tokens are exhausted.
        let textPadEmbed = try await textProjector.project(tokenId: Qwen3TTSConstants.textPAD)

        var phaseTimings = SpeechTimings()
        phaseTimings.tokenize = CFAbsoluteTimeGetCurrent() - start

        return TokenizeResult(
            textTokenIds: textTokenIds,
            trailingTextTokens: Array(textTokenIds.dropFirst()),
            firstTextEmbed: firstTextEmbed,
            variableEmbed: variableEmbed,
            textPadEmbed: textPadEmbed,
            timings: phaseTimings
        )
    }

    // MARK: - Phase 2: Prefill

    /// Prefill the CodeDecoder KV cache with the invariant prompt prefix.
    ///
    /// If `prefixCache` matches the current voice/language/instruction, restores the
    /// cached state and only decodes the variable token. Otherwise runs a full prefill.
    private func prefillCodeDecoder(
        tokenizeResult: TokenizeResult,
        speaker: Qwen3Speaker,
        lang: Qwen3Language,
        options: GenerationOptions,
        prefixCache: TTSPromptCache?,
        voice: String,
        language: String,
        cdState: Any?
    ) async throws -> PrefillResult {
        let start = CFAbsoluteTimeGetCurrent()

        let cdCache = try KVCache(
            cacheDim: codeDecoder.kvCacheEmbedDim,
            maxSeqLength: codeDecoder.kvCacheMaxSequenceLength,
            isStateful: codeDecoder.isStateful
        )

        // Cache hit requires an exact voice/language/instruction match.
        let usedCache = prefixCache?.matches(voice: voice, language: language, instruction: options.instruction) == true
        var totalPrefillTokens: Int
        var lastCdOutput: CodeDecoderOutput?

        if usedCache, let prefixCache {
            // Restore both the explicit KV cache and (when available) the MLState snapshot,
            // then decode only the variable (first-text + codec-BOS) embedding.
            cdCache.restore(from: prefixCache.kvSnapshot)
            if let stateData = prefixCache.stateData {
                // MLState restore is only possible on OS versions that expose MLState.
                if #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *), let mlState = cdState as? MLState {
                    mlState.restore(from: stateData)
                }
            }
            // +1 accounts for the variable token decoded below.
            totalPrefillTokens = prefixCache.prefixLength + 1

            // TODO: Remove forking logic with package with min os version upgrade
            if #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *), !options.forceLegacyEmbedPath {
                lastCdOutput = try await codeDecoder.decode(
                    inputEmbeds: tokenizeResult.variableEmbed.asMLTensor(), cache: cdCache, state: cdState
                )
            } else {
                let embedArr = try EmbedUtilities.createEmbedMLArray(tokenizeResult.variableEmbed)
                lastCdOutput = try await codeDecoder.decode(inputEmbeds: embedArr, cache: cdCache, state: cdState)
            }
        } else {
            // Full prefill: build the complete prompt-prefix embedding sequence and
            // feed it token by token.
            let embedDim = codeDecoder.embedSize
            let combinedEmbeds = try await buildCombinedEmbeddings(
                speaker: speaker, lang: lang,
                instruction: options.instruction,
                firstTextEmbed: tokenizeResult.firstTextEmbed,
                embedDim: embedDim
            )
            totalPrefillTokens = combinedEmbeds.count

            // TODO: Remove forking logic with package with min os version upgrade
            if #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *), !options.forceLegacyEmbedPath {
                for embed in combinedEmbeds {
                    lastCdOutput = try await codeDecoder.decode(inputEmbeds: embed.asMLTensor(), cache: cdCache, state: cdState)
                }
            } else {
                // Legacy path: KV updates from step i are applied to the explicit cache
                // before decoding step i+1 (no update needed before the first embed).
                for (embedIndex, embed) in combinedEmbeds.enumerated() {
                    if embedIndex > 0, let keyUpdates = lastCdOutput?.keyCacheUpdates, let valueUpdates = lastCdOutput?.valueCacheUpdates {
                        cdCache.update(keyCacheUpdates: keyUpdates, valueCacheUpdates: valueUpdates)
                    }
                    let embedArr = try EmbedUtilities.createEmbedMLArray(embed)
                    lastCdOutput = try await codeDecoder.decode(inputEmbeds: embedArr, cache: cdCache, state: cdState)
                }
            }
        }

        guard let resolvedLastCdOutput = lastCdOutput else {
            throw TTSError.generationFailed("Prefill produced no decoder output")
        }

        var phaseTimings = SpeechTimings()
        phaseTimings.prefill = CFAbsoluteTimeGetCurrent() - start
        phaseTimings.prefillTokens = Double(totalPrefillTokens)

        let prefillMs = phaseTimings.prefill * 1000
        let prefillTokS = phaseTimings.prefill > 0 ? Double(totalPrefillTokens) / phaseTimings.prefill : 0
        let cacheTag = usedCache ? " (cache hit, restored \(prefixCache?.prefixLength ?? 0) tokens)" : ""
        Logging.info(
            String(
                format: "Prefill: %.1fms (%d tokens, %.1f tok/s)%@",
                prefillMs, totalPrefillTokens, prefillTokS, cacheTag
            ))

        return PrefillResult(cdCache: cdCache, lastCdOutput: resolvedLastCdOutput, timings: phaseTimings)
    }

    // MARK: - Phase 3: Generation loop

    /// Run the autoregressive RVQ generation loop, delivering audio frames via `callback`.
    ///
    /// `baseTimings` carries the tokenize + prefill phase timings and is used to build
    /// accurate cumulative `SpeechProgress` values for callbacks.
    /// Returns the assembled audio, the number of steps completed, and the loop-phase timings.
+ private func runGenerationLoop( + tokenizeResult: TokenizeResult, + prefillResult: PrefillResult, + cdState: Any?, + options: GenerationOptions, + pipelineStart: CFAbsoluteTime, + callback: SpeechCallback, + baseTimings: SpeechTimings + ) async throws -> GenerationLoopResult { + let cdCache = prefillResult.cdCache + var lastCdOutput = prefillResult.lastCdOutput + var timings = baseTimings + + let roleTokenIds = tokenizer.encode(text: "<|im_start|>assistant\n").map { Int32($0) } + let maxStepsByPrefill = 8 * (roleTokenIds.count + tokenizeResult.textTokenIds.count) + + let sdCache = try SpeechDecoderCache( + cacheDim: speechDecoder.kvCacheEmbedDim, + maxSeqLength: speechDecoder.kvCacheMaxSequenceLength, + hiddenDim: speechDecoder.hiddenDim, + hiddenContextLen: speechDecoder.hiddenContextLen + ) + + var generatedTokens: [Int32] = [] + var code0 = await sampler.sampleCodec0( + logits: lastCdOutput.logits, + temperature: options.temperature, topK: options.topK, + generatedTokens: generatedTokens, + repetitionPenalty: options.repetitionPenalty, + suppressTokenIds: suppressTokenIds + ) + generatedTokens.append(code0) + + var allAudio: [Float] = [] + var stepIndex = 0 + var firstBufferEmitted = false + + // TODO: Remove forking logic with package with min os version upgrade + if #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *), !options.forceLegacyEmbedPath { + let textPadEmbedTensor: MLTensor = try await textProjector.project(tokenId: Qwen3TTSConstants.textPAD) + + while code0 != Qwen3TTSConstants.codecEOS + && !cdCache.isFull + && stepIndex < options.maxNewTokens + && stepIndex < maxStepsByPrefill + { + try Task.checkCancellation() + let stepStart = CFAbsoluteTimeGetCurrent() + + let codeEmbedStart = CFAbsoluteTimeGetCurrent() + let code0EmbedTensor: MLTensor = try await codeEmbedder.embed(tokenId: code0) + timings.codeEmbed += CFAbsoluteTimeGetCurrent() - codeEmbedStart + + let mcdStart = CFAbsoluteTimeGetCurrent() + guard let hiddenStatesTensor = 
lastCdOutput.hiddenStates as? MLTensor else { + throw TTSError.generationFailed("Expected MLTensor hidden states on async path") + } + let mcdResult = try await multiCodeDecoder.generateMultiCodes( + hiddenStatesTensor: hiddenStatesTensor, + code0EmbedTensor: code0EmbedTensor, + multiCodeEmbedder: multiCodeEmbedder, + sampler: sampler, options: options + ) + timings.multiCodeDecoder += CFAbsoluteTimeGetCurrent() - mcdStart + timings.multiCodeDecoderPredictions += mcdResult.timings.multiCodeDecoderPredictions + timings.multiCodeDecoderSampling += mcdResult.timings.multiCodeDecoderSampling + timings.multiCodeDecoderEmbedding += mcdResult.timings.multiCodeDecoderEmbedding + timings.decodingKvCaching += mcdResult.timings.decodingKvCaching + timings.totalMultiCodeDecoderPredictions += mcdResult.timings.totalMultiCodeDecoderPredictions + + let rvqFrame = [code0] + mcdResult.codes + + if !firstBufferEmitted { + // First step: give SpeechDecoder exclusive compute access for minimum TTFB. + // Emit the buffer immediately, then do the remaining step work (codec hidden, + // text projection, CodeDecoder) which only affects the *next* step's readiness. 
+ let sdResult = try await speechDecoder.decodeFrameAsync(codes: rvqFrame, cache: sdCache) + let stepTime = CFAbsoluteTimeGetCurrent() - stepStart + timings.speechDecoderPredictions += sdResult.timings.speechDecoderPredictions + timings.speechDecoder += sdResult.timings.speechDecoderPredictions + allAudio.append(contentsOf: sdResult.samples) + + timings.timeToFirstBuffer = CFAbsoluteTimeGetCurrent() - pipelineStart + firstBufferEmitted = true + + if callback?( + SpeechProgress( + audio: sdResult.samples, + timings: timings, + stepTime: stepTime + ) + ) == false { + break + } + + let codecHiddenStart = CFAbsoluteTimeGetCurrent() + guard let lastMcdCode = mcdResult.codes.last else { + throw TTSError.generationFailed("Multi-code generation result has no codes") + } + let code15OffsetId = lastMcdCode + Int32(multiCodeDecoder.codecVocabSize * 14) + let code15EmbedTensor: MLTensor = try await multiCodeEmbedder.embed(tokenId: code15OffsetId) + var allCodeEmbedTensors: [MLTensor] = [code0EmbedTensor] + if let tensorEmbeds = mcdResult.offsetCodeEmbedTensors { + allCodeEmbedTensors += tensorEmbeds + } else { + allCodeEmbedTensors += mcdResult.offsetCodeEmbeds.map { $0.asMLTensor() } + } + allCodeEmbedTensors.append(code15EmbedTensor) + let codecHiddenTensor = EmbedUtilities.sumEmbeddings(allCodeEmbedTensors) + timings.codecHidden += CFAbsoluteTimeGetCurrent() - codecHiddenStart + + let textProjStart = CFAbsoluteTimeGetCurrent() + let textEmbedTensor: MLTensor = + stepIndex < tokenizeResult.trailingTextTokens.count + ? 
try await textProjector.project(tokenId: tokenizeResult.trailingTextTokens[stepIndex]) + : textPadEmbedTensor + let combinedTensor = EmbedUtilities.addEmbeddings(codecHiddenTensor, textEmbedTensor) + timings.textProjection += CFAbsoluteTimeGetCurrent() - textProjStart + + let decodingStart = CFAbsoluteTimeGetCurrent() + lastCdOutput = try await codeDecoder.decode(inputEmbeds: combinedTensor, cache: cdCache, state: cdState) + timings.decodingPredictions += CFAbsoluteTimeGetCurrent() - decodingStart - lastCdOutput.internalCacheUpdateTime + timings.kvCacheUpdate += lastCdOutput.internalCacheUpdateTime + } else { + // Subsequent steps: overlap SpeechDecoder with CodeDecoder for throughput + async let speechResult = speechDecoder.decodeFrameAsync(codes: rvqFrame, cache: sdCache) + + let codecHiddenStart = CFAbsoluteTimeGetCurrent() + guard let lastMcdCode = mcdResult.codes.last else { + throw TTSError.generationFailed("Multi-code generation result has no codes") + } + let code15OffsetId = lastMcdCode + Int32(multiCodeDecoder.codecVocabSize * 14) + let code15EmbedTensor: MLTensor = try await multiCodeEmbedder.embed(tokenId: code15OffsetId) + var allCodeEmbedTensors: [MLTensor] = [code0EmbedTensor] + if let tensorEmbeds = mcdResult.offsetCodeEmbedTensors { + allCodeEmbedTensors += tensorEmbeds + } else { + allCodeEmbedTensors += mcdResult.offsetCodeEmbeds.map { $0.asMLTensor() } + } + allCodeEmbedTensors.append(code15EmbedTensor) + let codecHiddenTensor = EmbedUtilities.sumEmbeddings(allCodeEmbedTensors) + timings.codecHidden += CFAbsoluteTimeGetCurrent() - codecHiddenStart + + let textProjStart = CFAbsoluteTimeGetCurrent() + let textEmbedTensor: MLTensor = + stepIndex < tokenizeResult.trailingTextTokens.count + ? 
try await textProjector.project(tokenId: tokenizeResult.trailingTextTokens[stepIndex]) + : textPadEmbedTensor + let combinedTensor = EmbedUtilities.addEmbeddings(codecHiddenTensor, textEmbedTensor) + timings.textProjection += CFAbsoluteTimeGetCurrent() - textProjStart + + let decodingStart = CFAbsoluteTimeGetCurrent() + lastCdOutput = try await codeDecoder.decode(inputEmbeds: combinedTensor, cache: cdCache, state: cdState) + timings.decodingPredictions += CFAbsoluteTimeGetCurrent() - decodingStart - lastCdOutput.internalCacheUpdateTime + timings.kvCacheUpdate += lastCdOutput.internalCacheUpdateTime + + let sdResult = try await speechResult + timings.speechDecoderPredictions += sdResult.timings.speechDecoderPredictions + timings.speechDecoder += sdResult.timings.speechDecoderPredictions + allAudio.append(contentsOf: sdResult.samples) + + if callback?( + SpeechProgress( + audio: sdResult.samples, + timings: { + var merged = baseTimings; merged.merge(timings); return merged + }(), stepTime: nil)) == false + { + break + } + } + + let samplingStart = CFAbsoluteTimeGetCurrent() + code0 = await sampler.sampleCodec0( + logits: lastCdOutput.logits, + temperature: options.temperature, topK: options.topK, + generatedTokens: generatedTokens, + repetitionPenalty: options.repetitionPenalty, + suppressTokenIds: suppressTokenIds + ) + generatedTokens.append(code0) + timings.decodingSampling += CFAbsoluteTimeGetCurrent() - samplingStart + + timings.decodingLoop += CFAbsoluteTimeGetCurrent() - stepStart + stepIndex += 1 + progress.completedUnitCount = Int64(stepIndex) + + if stepIndex == 1 || stepIndex % 10 == 0 { + let stepMs = (CFAbsoluteTimeGetCurrent() - stepStart) * 1000 + Logging.debug( + String( + format: " Step %d: %.1fms (avg %.1fms/step)", + stepIndex, stepMs, timings.decodingLoop * 1000 / Double(stepIndex))) + } + } + } else { + // Legacy path (older OS) + while code0 != Qwen3TTSConstants.codecEOS && !cdCache.isFull && stepIndex < options.maxNewTokens && stepIndex < 
maxStepsByPrefill { + try Task.checkCancellation() + let stepStart = CFAbsoluteTimeGetCurrent() + + let cacheUpdateStart = CFAbsoluteTimeGetCurrent() + if let keyUpdates = lastCdOutput.keyCacheUpdates, let valueUpdates = lastCdOutput.valueCacheUpdates { + cdCache.update(keyCacheUpdates: keyUpdates, valueCacheUpdates: valueUpdates) + } + timings.kvCacheUpdate += CFAbsoluteTimeGetCurrent() - cacheUpdateStart + + let codeEmbedStart = CFAbsoluteTimeGetCurrent() + let code0Embed = try await codeEmbedder.embed(tokenId: code0) + timings.codeEmbed += CFAbsoluteTimeGetCurrent() - codeEmbedStart + + let mcdStart = CFAbsoluteTimeGetCurrent() + guard let hiddenStates = lastCdOutput.hiddenStates as? [FloatType] else { + throw TTSError.generationFailed("Expected [FloatType] hidden states on legacy path") + } + let mcdResult = try await multiCodeDecoder.generateMultiCodes( + hiddenStates: hiddenStates, code0Embed: code0Embed, + multiCodeEmbedder: multiCodeEmbedder, sampler: sampler, options: options + ) + timings.multiCodeDecoder += CFAbsoluteTimeGetCurrent() - mcdStart + timings.multiCodeDecoderPredictions += mcdResult.timings.multiCodeDecoderPredictions + timings.multiCodeDecoderSampling += mcdResult.timings.multiCodeDecoderSampling + timings.multiCodeDecoderEmbedding += mcdResult.timings.multiCodeDecoderEmbedding + timings.decodingKvCaching += mcdResult.timings.decodingKvCaching + timings.totalMultiCodeDecoderPredictions += mcdResult.timings.totalMultiCodeDecoderPredictions + + let rvqFrame = [code0] + mcdResult.codes + + if !firstBufferEmitted { + let sdResult = try await speechDecoder.decodeFrameAsync(codes: rvqFrame, cache: sdCache) + timings.speechDecoderPredictions += sdResult.timings.speechDecoderPredictions + timings.speechDecoder += sdResult.timings.speechDecoderPredictions + allAudio.append(contentsOf: sdResult.samples) + + timings.timeToFirstBuffer = CFAbsoluteTimeGetCurrent() - pipelineStart + firstBufferEmitted = true + let stepTime = CFAbsoluteTimeGetCurrent() - 
stepStart + + if callback?( + SpeechProgress( + audio: sdResult.samples, + timings: { + var merged = baseTimings; merged.merge(timings); return merged + }(), stepTime: stepTime)) == false + { + break + } + + let codecHiddenStart = CFAbsoluteTimeGetCurrent() + guard let lastMcdCode = mcdResult.codes.last else { + throw TTSError.generationFailed("Multi-code generation result has no codes") + } + let code15OffsetId = lastMcdCode + Int32(multiCodeDecoder.codecVocabSize * 14) + var allCodeEmbeds: [[FloatType]] = [code0Embed] + allCodeEmbeds += mcdResult.offsetCodeEmbeds + try await allCodeEmbeds.append(multiCodeEmbedder.embed(tokenId: code15OffsetId)) + let codecHidden = EmbedUtilities.sumEmbeddings(allCodeEmbeds) + timings.codecHidden += CFAbsoluteTimeGetCurrent() - codecHiddenStart + + let textProjStart = CFAbsoluteTimeGetCurrent() + let textTokenEmbed: [FloatType] = + stepIndex < tokenizeResult.trailingTextTokens.count + ? try await textProjector.project(tokenId: tokenizeResult.trailingTextTokens[stepIndex]) + : tokenizeResult.textPadEmbed + let combinedArr = try EmbedUtilities.createEmbedMLArray(EmbedUtilities.addEmbeddings(codecHidden, textTokenEmbed)) + timings.textProjection += CFAbsoluteTimeGetCurrent() - textProjStart + + let decodingStart = CFAbsoluteTimeGetCurrent() + lastCdOutput = try await codeDecoder.decode(inputEmbeds: combinedArr, cache: cdCache, state: cdState) + timings.decodingPredictions += CFAbsoluteTimeGetCurrent() - decodingStart + } else { + async let speechResult = speechDecoder.decodeFrameAsync(codes: rvqFrame, cache: sdCache) + + let codecHiddenStart = CFAbsoluteTimeGetCurrent() + guard let lastMcdCode = mcdResult.codes.last else { + throw TTSError.generationFailed("Multi-code generation result has no codes") + } + let code15OffsetId = lastMcdCode + Int32(multiCodeDecoder.codecVocabSize * 14) + var allCodeEmbeds: [[FloatType]] = [code0Embed] + allCodeEmbeds += mcdResult.offsetCodeEmbeds + try await 
allCodeEmbeds.append(multiCodeEmbedder.embed(tokenId: code15OffsetId)) + let codecHidden = EmbedUtilities.sumEmbeddings(allCodeEmbeds) + timings.codecHidden += CFAbsoluteTimeGetCurrent() - codecHiddenStart + + let textProjStart = CFAbsoluteTimeGetCurrent() + let textTokenEmbed: [FloatType] = + stepIndex < tokenizeResult.trailingTextTokens.count + ? try await textProjector.project(tokenId: tokenizeResult.trailingTextTokens[stepIndex]) + : tokenizeResult.textPadEmbed + let combinedArr = try EmbedUtilities.createEmbedMLArray(EmbedUtilities.addEmbeddings(codecHidden, textTokenEmbed)) + timings.textProjection += CFAbsoluteTimeGetCurrent() - textProjStart + + let decodingStart = CFAbsoluteTimeGetCurrent() + lastCdOutput = try await codeDecoder.decode(inputEmbeds: combinedArr, cache: cdCache, state: cdState) + timings.decodingPredictions += CFAbsoluteTimeGetCurrent() - decodingStart + + let sdResult = try await speechResult + timings.speechDecoderPredictions += sdResult.timings.speechDecoderPredictions + timings.speechDecoder += sdResult.timings.speechDecoderPredictions + allAudio.append(contentsOf: sdResult.samples) + + if callback?( + SpeechProgress( + audio: sdResult.samples, + timings: { + var merged = baseTimings; merged.merge(timings); return merged + }(), stepTime: nil)) == false + { + break + } + } + + let samplingStart = CFAbsoluteTimeGetCurrent() + code0 = await sampler.sampleCodec0( + logits: lastCdOutput.logits, + temperature: options.temperature, topK: options.topK, + generatedTokens: generatedTokens, + repetitionPenalty: options.repetitionPenalty, + suppressTokenIds: suppressTokenIds + ) + generatedTokens.append(code0) + timings.decodingSampling += CFAbsoluteTimeGetCurrent() - samplingStart + + timings.decodingLoop += CFAbsoluteTimeGetCurrent() - stepStart + stepIndex += 1 + progress.completedUnitCount = Int64(stepIndex) + + if stepIndex == 1 || stepIndex % 10 == 0 { + let stepMs = (CFAbsoluteTimeGetCurrent() - stepStart) * 1000 + Logging.debug( + String( + 
format: " Step %d: %.1fms (avg %.1fms/step)", + stepIndex, stepMs, timings.decodingLoop * 1000 / Double(stepIndex))) + } + } + } + + let stopReason: String + if code0 == Qwen3TTSConstants.codecEOS { + stopReason = "EOS token" + } else if cdCache.isFull { + stopReason = "KV cache full (\(cdCache.cacheLength)/\(cdCache.maxSeqLength))" + } else if stepIndex >= maxStepsByPrefill { + stopReason = "Audio token ratio limit (\(stepIndex)/\(maxStepsByPrefill) steps)" + } else { + stopReason = "maxNewTokens limit (\(options.maxNewTokens))" + } + Logging.info("Loop stopped: \(stopReason) after \(stepIndex) steps") + + timings.totalDecodingLoops = Double(stepIndex) + return GenerationLoopResult(audio: allAudio, steps: stepIndex, timings: timings) + } + + // MARK: - Embedding Helpers + + /// Build the full combined embedding sequence (text track + codec track) for prefill. + /// The returned array includes both the invariant prefix and the variable last token. + func buildCombinedEmbeddings( + speaker: Qwen3Speaker, + lang: Qwen3Language, + instruction: String?, + firstTextEmbed: [FloatType], + embedDim: Int + ) async throws -> [[FloatType]] { + let zeroCodecEmbed = EmbedUtilities.zeroEmbed(dim: embedDim) + + var instructTextEmbeds: [[FloatType]] = [] + var instructCodecEmbeds: [[FloatType]] = [] + if let instruction, !instruction.isEmpty { + let instructPrompt = "<|im_start|>user\n\(instruction)<|im_end|>\n" + let instructTokenIds = tokenizer.encode(text: instructPrompt).map { Int32($0) } + for tokenId in instructTokenIds { + try await instructTextEmbeds.append(textProjector.project(tokenId: tokenId)) + instructCodecEmbeds.append(zeroCodecEmbed) + } + Logging.debug("Instruction: \(instructTokenIds.count) tokens") + } + + let rolePrefix = "<|im_start|>assistant\n" + let roleTokenIds = tokenizer.encode(text: rolePrefix).map { Int32($0) } + + var textTrackEmbeds: [[FloatType]] = instructTextEmbeds + for tokenId in roleTokenIds { + try await 
textTrackEmbeds.append(textProjector.project(tokenId: tokenId)) + } + let textPadEmbed = try await textProjector.project(tokenId: Qwen3TTSConstants.textPAD) + let textBosEmbed = try await textProjector.project(tokenId: Qwen3TTSConstants.textBOS) + + let codecIds: [Int32] = [ + Qwen3TTSConstants.codecThink, + Qwen3TTSConstants.codecThinkBos, + lang.tokenID, + Qwen3TTSConstants.codecThinkEos, + speaker.tokenID, + Qwen3TTSConstants.codecPAD, + Qwen3TTSConstants.codecBOS + ] + var codecTrackEmbeds: [[FloatType]] = [] + for codecId in codecIds { + try await codecTrackEmbeds.append(codeEmbedder.embed(tokenId: codecId)) + } + + let numPads = codecIds.count - 2 + for _ in 0.. TTSPromptCache { + let qwen3Speaker = Qwen3Speaker(rawValue: voice) ?? .ryan + let lang = Qwen3Language(rawValue: language) ?? .english + let embedDim = codeDecoder.embedSize + + // Build invariant embeddings (everything except the last variable token). + // Use a dummy firstTextEmbed since we drop the last element. + let dummyFirstTextEmbed = EmbedUtilities.zeroEmbed(dim: embedDim) + let allEmbeds = try await buildCombinedEmbeddings( + speaker: qwen3Speaker, + lang: lang, + instruction: instruction, + firstTextEmbed: dummyFirstTextEmbed, + embedDim: embedDim + ) + let invariantEmbeds = Array(allEmbeds.dropLast()) + + // Pre-initialize the MultiCodeDecoder ANE pipeline concurrently with the + // CodeDecoder prefill loop below. Cache build is the right place for this + // one-time cost: it absorbs the ~150ms without affecting TTFB, and the + // warmed pipeline persists for all subsequent generation calls. + let mcdWarmupTask: Task? + if #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *) { + let mcd = multiCodeDecoder + mcdWarmupTask = Task { try? 
await mcd.prewarmInference() } + } else { + mcdWarmupTask = nil + } + + let cdState = codeDecoder.makeState() + let cdCache = try KVCache( + cacheDim: codeDecoder.kvCacheEmbedDim, + maxSeqLength: codeDecoder.kvCacheMaxSequenceLength, + isStateful: codeDecoder.isStateful + ) + + var lastCdOutput: CodeDecoderOutput? + if #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *) { + // Async path: decoder updates cache internally + for embed in invariantEmbeds { + lastCdOutput = try await codeDecoder.decode(inputEmbeds: embed.asMLTensor(), cache: cdCache, state: cdState) + } + } else { + for (embedIndex, embed) in invariantEmbeds.enumerated() { + if embedIndex > 0, let keyUpdates = lastCdOutput?.keyCacheUpdates, let valueUpdates = lastCdOutput?.valueCacheUpdates { + cdCache.update(keyCacheUpdates: keyUpdates, valueCacheUpdates: valueUpdates) + } + let embedArr = try EmbedUtilities.createEmbedMLArray(embed) + lastCdOutput = try await codeDecoder.decode(inputEmbeds: embedArr, cache: cdCache, state: cdState) + } + // Commit the last pending KV update so the snapshot is fully self-contained + if let keyUpdates = lastCdOutput?.keyCacheUpdates, let valueUpdates = lastCdOutput?.valueCacheUpdates { + cdCache.update(keyCacheUpdates: keyUpdates, valueCacheUpdates: valueUpdates) + } + } + + // Snapshot MLState for stateful models + var stateData: KVStateData? + if #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *), let mlState = cdState as? MLState { + stateData = mlState.snapshot() + } + + // Ensure warmup is done before returning - by this point the CodeDecoder + // loop has run (~2700ms), so this await is a no-op in practice. 
+ await mcdWarmupTask?.value + + Logging.info("Built prompt cache: \(invariantEmbeds.count) invariant tokens, isStateful=\(codeDecoder.isStateful) for \(voice)/\(language)") + + return TTSPromptCache( + voice: voice, + language: language, + instruction: instruction, + prefixLength: invariantEmbeds.count, + kvSnapshot: cdCache.snapshot(), + stateData: stateData + ) + } +} diff --git a/Sources/TTSKit/Qwen3TTS/Qwen3Models.swift b/Sources/TTSKit/Qwen3TTS/Qwen3Models.swift new file mode 100644 index 00000000..001b09d0 --- /dev/null +++ b/Sources/TTSKit/Qwen3TTS/Qwen3Models.swift @@ -0,0 +1,174 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright © 2026 Argmax, Inc. All rights reserved. + +import Foundation + +// MARK: - Qwen3 TTS Constants + +/// Qwen3 TTS model-specific constants: codec token IDs, vocabulary sizes, +/// model dimensions, cache geometry, and HuggingFace source locations. +/// +/// These values are derived from the Qwen3 TTS architecture and must be updated +/// if the model is retrained with a different configuration. +/// +/// Audio-format constants (`sampleRate`, `samplesPerFrame`) are here because +/// they match Qwen3's output format. Future model families with different audio +/// parameters should expose them via `SpeechDecoding.sampleRate` / +/// `SpeechDecoding.samplesPerFrame` rather than adding a second constants enum. +public enum Qwen3TTSConstants { + // MARK: Codec track special tokens + + public static let codecPAD: Int32 = 2148 + public static let codecBOS: Int32 = 2149 + public static let codecEOS: Int32 = 2150 + public static let codecThink: Int32 = 2154 + public static let codecThinkBos: Int32 = 2156 + public static let codecThinkEos: Int32 = 2157 + + // MARK: Text track special tokens + + public static let textPAD: Int32 = 151_671 + public static let textBOS: Int32 = 151_672 + + // MARK: Vocabulary sizes + + /// Vocabulary size for the multi-code decoder heads (codes 1-15). 
+ public static let codecVocabSize: Int = 2048 + + // MARK: Audio format + + public static let sampleRate: Int = 24000 + public static let samplesPerFrame: Int = 1920 + + // MARK: Model dimensions + + /// Shared embedding dimension for all projectors and embedders. + public static let embedDim: Int = 1024 + + // MARK: KV cache geometry + + public static let cdCacheDim: Int = 28672 + public static let cdMaxSeq: Int = 256 + public static let mcdCacheDim: Int = 5120 + public static let mcdMaxSeq: Int = 64 + public static let sdCacheDim: Int = 8192 + public static let sdMaxSeq: Int = 256 + public static let sdHiddenDim: Int = 1024 + public static let sdHiddenContextLen: Int = 16 + + // MARK: Default HuggingFace sources + + /// Tokenizer-compatible model on HuggingFace (must contain `tokenizer.json`). + public static let defaultTokenizerRepo = "Qwen/Qwen3-0.6B" + /// HuggingFace repo hosting the pre-compiled CoreML TTS models. + public static let defaultModelRepo = "argmaxinc/ttskit-coreml" + /// Default HuggingFace Hub endpoint. Override to point at a mirror or on-premise instance. + public static let defaultEndpoint = "https://huggingface.co" + /// Intermediate subdirectory inside the model repo grouping all Qwen3 TTS components. + public static let modelFamilyDir = "qwen3_tts" + /// Default model version directory, equivalent to `TTSModelPreset.qwen3TTS_0_6B.versionDir`. + /// Provided as a stable exported constant; `TTSKitConfig` uses the preset's value directly. + public static let defaultVersionDir = "12hz-0.6b-customvoice" + + // MARK: Token suppression + + /// Codec-0 token IDs suppressed during sampling: [2048, 3072) except EOS (2150). + public static let suppressTokenIds: Set = { + var ids = Set() + for tokenId in 2048..<3072 where tokenId != Int(Qwen3TTSConstants.codecEOS) { + ids.insert(tokenId) + } + return ids + }() +} + +// MARK: - Speaker + +/// Qwen3 TTS speaker voices with their corresponding codec token IDs. 
public enum Qwen3Speaker: String, CaseIterable, Sendable {
    case ryan
    case aiden
    case onoAnna = "ono-anna"
    case sohee
    case eric
    case dylan
    case serena
    case vivian
    case uncleFu = "uncle-fu"

    /// Codec token ID that selects this voice in the prompt's codec track.
    public var tokenID: Int32 {
        switch self {
        case .ryan: return 3061
        case .aiden: return 2861
        case .onoAnna: return 2873
        case .sohee: return 2864
        case .eric: return 2875
        case .dylan: return 2878
        case .serena: return 3066
        case .vivian: return 3065
        case .uncleFu: return 3010
        }
    }

    /// Human-readable display name (handles hyphenated raw values).
    public var displayName: String {
        switch self {
        case .ryan: return "Ryan"
        case .aiden: return "Aiden"
        case .onoAnna: return "Ono Anna"
        case .sohee: return "Sohee"
        case .eric: return "Eric"
        case .dylan: return "Dylan"
        case .serena: return "Serena"
        case .vivian: return "Vivian"
        case .uncleFu: return "Uncle Fu"
        }
    }

    /// Short description of the voice character and quality.
    public var voiceDescription: String {
        switch self {
        case .ryan: return "Dynamic male voice with strong rhythmic drive."
        case .aiden: return "Sunny American male voice with a clear midrange."
        case .onoAnna: return "Playful Japanese female voice with a light, nimble timbre."
        case .sohee: return "Warm Korean female voice with rich emotion."
        case .eric: return "Lively Chengdu male voice with a slightly husky brightness."
        case .dylan: return "Youthful Beijing male voice with a clear, natural timbre."
        case .serena: return "Warm, gentle young female voice."
        case .vivian: return "Bright, slightly edgy young female voice."
        case .uncleFu: return "Seasoned male voice with a low, mellow timbre."
        }
    }

    /// The speaker's native language (best quality when used with this language).
    public var nativeLanguage: String {
        switch self {
        case .ryan, .aiden: return "English"
        case .onoAnna: return "Japanese"
        case .sohee: return "Korean"
        case .eric: return "Chinese (Sichuan)"
        case .dylan: return "Chinese (Beijing)"
        case .serena, .vivian, .uncleFu: return "Chinese"
        }
    }
}

// MARK: - Language

/// Qwen3 TTS supported languages with their corresponding codec token IDs.
public enum Qwen3Language: String, CaseIterable, Sendable {
    case english, chinese, japanese, korean, german
    case french, russian, portuguese, spanish, italian

    /// Codec token ID for this language, placed between the think-BOS and
    /// think-EOS tokens of the prompt's codec track.
    public var tokenID: Int32 {
        switch self {
        case .english: return 2050
        case .chinese: return 2055
        case .japanese: return 2058
        case .korean: return 2064
        case .german: return 2053
        case .french: return 2061
        case .russian: return 2069
        case .portuguese: return 2071
        case .spanish: return 2054
        case .italian: return 2070
        }
    }
}
diff --git a/Sources/TTSKit/Qwen3TTS/Qwen3MultiCodeDecoder.swift b/Sources/TTSKit/Qwen3TTS/Qwen3MultiCodeDecoder.swift
new file mode 100644
index 00000000..aa9191a1
--- /dev/null
+++ b/Sources/TTSKit/Qwen3TTS/Qwen3MultiCodeDecoder.swift
// For licensing see accompanying LICENSE.md file.
// Copyright © 2026 Argmax, Inc. All rights reserved.

import ArgmaxCore
import CoreML
import Foundation

// MARK: - Supporting Types

/// Update and padding masks for a single MLTensor decode step.
@available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *)
struct MLTensorMasks {
    let updateMask: MLTensor
    let paddingMask: MLTensor
}

/// Result of a single MLTensor prediction step, including the updated KV cache.
@available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *)
struct MLTensorStepResult {
    let outputs: [String: MLTensor]
    let keyCache: MLTensor
    let valueCache: MLTensor
    let cachePosition: Int32
    let predictionTime: TimeInterval
    let cacheUpdateTime: TimeInterval
}

// MARK: - Implementation

/// Multi-code RVQ decoder backed by a CoreML model.
///
/// Thread safety: mutable state (`model` and the detected dimension properties)
/// is written exactly once inside `loadModel()` and only read afterwards;
/// `MLModel.prediction()` is itself thread-safe. All per-call state is local to
/// `generateMultiCodes()` and never stored on this shared instance.
public class Qwen3MultiCodeDecoder: MultiCodeDecoding, @unchecked Sendable {
    public var model: MLModel?

    /// KV cache embedding dimension, detected from model at load time
    public private(set) var kvCacheEmbedDim: Int = Qwen3TTSConstants.mcdCacheDim
    /// KV cache max sequence length, detected from model at load time
    public private(set) var kvCacheMaxSequenceLength: Int = Qwen3TTSConstants.mcdMaxSeq
    /// Codec vocabulary size per head (codes 1-15), detected from model output
    public private(set) var codecVocabSize: Int = Qwen3TTSConstants.codecVocabSize

    public init() {}

    /// Load (and optionally just pre-compile) the CoreML model, then read the
    /// cache/vocab dimensions directly out of the model description.
    public func loadModel(at url: URL, computeUnits: MLComputeUnits, prewarmMode: Bool = false) async throws {
        let configuration = MLModelConfiguration()
        configuration.computeUnits = computeUnits
        let compiled = try await MLModel.load(contentsOf: url, configuration: configuration)

        // Prewarm mode only forces compilation; drop the instance so memory is
        // free before the next model compiles.
        guard !prewarmMode else { return }

        self.model = compiled

        // key_cache_updates shape, position 1, carries the cache embedding dim.
        if let cacheDim = ModelUtilities.getModelOutputDimension(model, named: "key_cache_updates", position: 1) {
            kvCacheEmbedDim = cacheDim
        }
        if let maxSeq = ModelUtilities.getModelInputDimension(model, named: "key_padding_mask", position: 1) {
            kvCacheMaxSequenceLength = maxSeq
        }
        // all_logits output shape: [1, 15, codecVocabSize]
        if let vocab = ModelUtilities.getModelOutputDimension(model, named: "all_logits", position: 2) {
            codecVocabSize = vocab
        }
        // input_embeds shape: [1, embedDim, 1, 1]
        if let detectedEmbedDim = ModelUtilities.getModelInputDimension(model, named: "input_embeds", position: 1) {
            inputEmbedDim = detectedEmbedDim
        }
    }

    /// Embedding dimension for `input_embeds`, detected from the model at load time.
    public private(set) var inputEmbedDim: Int = Qwen3TTSConstants.embedDim

    /// Whether the loaded model keeps its KV cache in an `MLState`
    /// (only possible on OS versions that support stateful models).
    public var isStateful: Bool {
        guard let model else { return false }
        if #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *) {
            return !model.modelDescription.stateDescriptionsByName.isEmpty
        }
        return false
    }

    /// Create a fresh MLState for a new RVQ frame (stateful models only).
    /// Returns nil for non-stateful models. The caller owns the returned state.
    public func makeState() -> Any? {
        guard isStateful, let model else { return nil }
        if #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *) {
            return model.makeState()
        }
        return nil
    }

    /// Pre-initialize the ANE pipeline by running dummy predictions before the first
    /// real generation step. Without this, the first `generateMultiCodes` call per
    /// generation is ~7x slower than steady state due to lazy ANE pipeline setup.
    ///
    /// Run concurrently with CodeDecoder prefill so there is no net TTFB cost.
    /// Replicates the exact loop pattern used in `generateMultiCodes` (4 passes
    /// × 16 predictions each) to match what the ANE needs to pipeline.
+ @available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *) + public func prewarmInference() async throws { + guard let model else { return } + let sequenceLength = kvCacheMaxSequenceLength + let dummyInput = MLTensor(zeros: [1, inputEmbedDim, 1, 1], scalarType: FloatType.self) + for _ in 0..<4 { + var keyCache = MLTensor(zeros: [1, kvCacheEmbedDim, 1, sequenceLength], scalarType: FloatType.self) + var valueCache = MLTensor(zeros: [1, kvCacheEmbedDim, 1, sequenceLength], scalarType: FloatType.self) + var cachePosition: Int32 = 0 + for _ in 0..<16 { + let stepResult = try await predictMLTensorStep( + inputEmbeds: dummyInput, model: model, + keyCache: keyCache, valueCache: valueCache, + cachePosition: cachePosition, sequenceLength: sequenceLength + ) + keyCache = stepResult.keyCache + valueCache = stepResult.valueCache + cachePosition = stepResult.cachePosition + } + } + } + + public func decode(inputEmbeds: any EmbedInputType, cache: KVCache, state: Any? = nil) async throws -> MultiCodeDecoderOutput { + guard let model else { + throw TTSError.generationFailed("MultiCodeDecoder model not loaded") + } + guard let array = inputEmbeds as? 
MLMultiArray else { + throw TTSError.generationFailed("MultiCodeDecoder: unsupported embed input type \(type(of: inputEmbeds))") + } + + var dict: [String: MLFeatureValue] = try [ + "input_embeds": MLFeatureValue(multiArray: array), + "cache_length": MLFeatureValue(multiArray: cache.makeCacheLengthArray()), + "kv_cache_update_mask": MLFeatureValue(multiArray: cache.kvCacheUpdateMask), + "key_padding_mask": MLFeatureValue(multiArray: cache.keyPaddingMask) + ] + + // Only pass external KV cache for non-stateful models + if !isStateful, let keyCache = cache.keyCache, let valueCache = cache.valueCache { + dict["key_cache"] = MLFeatureValue(multiArray: keyCache) + dict["value_cache"] = MLFeatureValue(multiArray: valueCache) + } + + let input = try MLDictionaryFeatureProvider(dictionary: dict) + + let output: MLFeatureProvider + if #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *), let mlState = state as? MLState { + output = try await model.asyncPrediction(from: input, using: mlState) + } else { + output = try await model.asyncPrediction(from: input) + } + + guard let keyCacheUpdates = output.featureValue(for: "key_cache_updates")?.multiArrayValue, + let valueCacheUpdates = output.featureValue(for: "value_cache_updates")?.multiArrayValue + else { + throw TTSError.generationFailed("MultiCodeDecoder: missing key/value cache update arrays") + } + + if #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *), let mlState = state as? 
MLState, isStateful { + KVCache.updateStateCache( + state: mlState, + keyCacheUpdates: keyCacheUpdates, + valueCacheUpdates: valueCacheUpdates, + position: Int(cache.cacheLength) + ) + } + + guard let allLogitsArray = output.featureValue(for: "all_logits")?.multiArrayValue else { + throw TTSError.generationFailed("MultiCodeDecoder: missing all_logits array") + } + return MultiCodeDecoderOutput( + allLogits: allLogitsArray, + keyCacheUpdates: keyCacheUpdates, + valueCacheUpdates: valueCacheUpdates + ) + } + + /// Pure MLTensor path - cache lives as tensors, updated via element-wise masking. + /// No MLMultiArray round-trip: prediction takes/returns MLTensor, cache updates + /// are lazy tensor ops, and only the logits are materialized (by the sampler). + /// Build the update mask and padding mask tensors for a given cache position. + @available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *) + func buildMasks(position: Int, sequenceLength: Int) -> MLTensorMasks { + var updateData = [FloatType](repeating: 0, count: sequenceLength) + var paddingData = [FloatType](repeating: -10000, count: sequenceLength) + if position < sequenceLength { updateData[position] = 1 } + for index in 0...min(position, sequenceLength - 1) { + paddingData[index] = 0 + } + return MLTensorMasks( + updateMask: MLTensor(shape: [1, sequenceLength], scalars: updateData), + paddingMask: MLTensor(shape: [1, sequenceLength], scalars: paddingData) + ) + } + + /// Run one MLTensor prediction step and return the outputs with an updated KV cache. + /// + /// Cache updates are performed in tensor space via element-wise masking - + /// no MLMultiArray round-trip occurs. The model must be compiled for + /// single-token `[1, embedDim, 1, 1]` input. 
+ @available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *) + func predictMLTensorStep( + inputEmbeds: MLTensor, + model: MLModel, + keyCache: MLTensor, + valueCache: MLTensor, + cachePosition: Int32, + sequenceLength: Int + ) async throws -> MLTensorStepResult { + let masks = buildMasks(position: Int(cachePosition), sequenceLength: sequenceLength) + let predictionStart = CFAbsoluteTimeGetCurrent() + let outputs = try await model.prediction(from: [ + "input_embeds": inputEmbeds, + "cache_length": MLTensor(shape: [1], scalars: [cachePosition]), + "kv_cache_update_mask": masks.updateMask, + "key_padding_mask": masks.paddingMask, + "key_cache": keyCache, + "value_cache": valueCache + ]) + let predictionTime = CFAbsoluteTimeGetCurrent() - predictionStart + + let cacheUpdateStart = CFAbsoluteTimeGetCurrent() + var positionMaskData = [FloatType](repeating: 0, count: sequenceLength) + positionMaskData[Int(cachePosition)] = 1 + let positionMask = MLTensor(shape: [1, 1, 1, sequenceLength], scalars: positionMaskData) + let invertedMask = MLTensor(repeating: FloatType(1), shape: [1, 1, 1, sequenceLength]) - positionMask + guard let keyCacheOutput = outputs["key_cache_updates"], + let valueCacheOutput = outputs["value_cache_updates"] + else { + throw TTSError.generationFailed("MultiCodeDecoder: missing key/value cache update tensors") + } + let updatedKeyCache = keyCache * invertedMask + keyCacheOutput * positionMask + let updatedValueCache = valueCache * invertedMask + valueCacheOutput * positionMask + let cacheUpdateTime = CFAbsoluteTimeGetCurrent() - cacheUpdateStart + + return MLTensorStepResult( + outputs: outputs, + keyCache: updatedKeyCache, + valueCache: updatedValueCache, + cachePosition: cachePosition + 1, + predictionTime: predictionTime, + cacheUpdateTime: cacheUpdateTime + ) + } + + @available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *) + public func generateMultiCodes( + hiddenStatesTensor: MLTensor, + code0EmbedTensor: MLTensor, + 
multiCodeEmbedder: any MultiCodeEmbedding, + sampler: any TokenSampling, + options: GenerationOptions + ) async throws -> MultiCodeGenerationResult { + guard let model else { throw TTSError.generationFailed("MultiCodeDecoder model not loaded") } + + let sequenceLength = kvCacheMaxSequenceLength + var keyCache = MLTensor(zeros: [1, kvCacheEmbedDim, 1, sequenceLength], scalarType: FloatType.self) + var valueCache = MLTensor(zeros: [1, kvCacheEmbedDim, 1, sequenceLength], scalarType: FloatType.self) + var cachePosition: Int32 = 0 + var timings = SpeechTimings() + + // Prefill: hidden_states, then code_0_embed. First call's logits are discarded. + var stepResult = try await predictMLTensorStep( + inputEmbeds: hiddenStatesTensor, model: model, + keyCache: keyCache, valueCache: valueCache, + cachePosition: cachePosition, sequenceLength: sequenceLength + ) + keyCache = stepResult.keyCache + valueCache = stepResult.valueCache + cachePosition = stepResult.cachePosition + timings.multiCodeDecoderPredictions += stepResult.predictionTime + timings.totalMultiCodeDecoderPredictions += 1 + timings.decodingKvCaching += stepResult.cacheUpdateTime + + stepResult = try await predictMLTensorStep( + inputEmbeds: code0EmbedTensor, model: model, + keyCache: keyCache, valueCache: valueCache, + cachePosition: cachePosition, sequenceLength: sequenceLength + ) + keyCache = stepResult.keyCache + valueCache = stepResult.valueCache + cachePosition = stepResult.cachePosition + timings.multiCodeDecoderPredictions += stepResult.predictionTime + timings.totalMultiCodeDecoderPredictions += 1 + timings.decodingKvCaching += stepResult.cacheUpdateTime + + var stepIndex = 0 + let samplingStart = CFAbsoluteTimeGetCurrent() + guard let firstStepLogits = stepResult.outputs["all_logits"] else { + throw TTSError.generationFailed("MultiCodeDecoder: missing all_logits tensor on step 0") + } + let code1 = await sampler.sampleMultiHead( + allLogits: firstStepLogits, + headIndex: stepIndex, + temperature: 
options.temperature, + topK: options.topK + ) + timings.multiCodeDecoderSampling += CFAbsoluteTimeGetCurrent() - samplingStart + var codes: [Int32] = [code1] + + var offsetCodeEmbedTensors: [MLTensor] = [] + offsetCodeEmbedTensors.reserveCapacity(14) + + for _ in 0..<14 { + let embeddingStart = CFAbsoluteTimeGetCurrent() + guard let lastCode = codes.last else { + throw TTSError.generationFailed("MultiCodeDecoder: codes array is empty in MLTensor loop") + } + let offsetId = lastCode + Int32(codecVocabSize * stepIndex) + let embedTensor: MLTensor = try await multiCodeEmbedder.embed(tokenId: offsetId) + offsetCodeEmbedTensors.append(embedTensor) + timings.multiCodeDecoderEmbedding += CFAbsoluteTimeGetCurrent() - embeddingStart + + stepResult = try await predictMLTensorStep( + inputEmbeds: embedTensor, model: model, + keyCache: keyCache, valueCache: valueCache, + cachePosition: cachePosition, sequenceLength: sequenceLength + ) + keyCache = stepResult.keyCache + valueCache = stepResult.valueCache + cachePosition = stepResult.cachePosition + timings.multiCodeDecoderPredictions += stepResult.predictionTime + timings.totalMultiCodeDecoderPredictions += 1 + timings.decodingKvCaching += stepResult.cacheUpdateTime + + stepIndex += 1 + + let nextSamplingStart = CFAbsoluteTimeGetCurrent() + guard let nextStepLogits = stepResult.outputs["all_logits"] else { + throw TTSError.generationFailed("MultiCodeDecoder: missing all_logits tensor on step \(stepIndex)") + } + let code = await sampler.sampleMultiHead( + allLogits: nextStepLogits, + headIndex: stepIndex, + temperature: options.temperature, + topK: options.topK + ) + timings.multiCodeDecoderSampling += CFAbsoluteTimeGetCurrent() - nextSamplingStart + codes.append(code) + } + + return MultiCodeGenerationResult(codes: codes, timings: timings, offsetCodeEmbedTensors: offsetCodeEmbedTensors) + } + + /// Legacy path - kept for OS compatibility (pre-macOS 15). 
+    // TODO: Remove forking logic with package with min os version upgrade
+    /// Runs two prefill steps (the CodeDecoder hidden state, then the code-0
+    /// embedding), samples the head-0 code, then autoregressively decodes 14 more
+    /// codebook tokens for a total of 15 codes.
+    /// - Returns: Sampled codes, per-stage timings, and the offset-code embedding
+    ///   buffers computed along the way.
+    public func generateMultiCodes(
+        hiddenStates: [FloatType],
+        code0Embed: [FloatType],
+        multiCodeEmbedder: any MultiCodeEmbedding,
+        sampler: any TokenSampling,
+        options: GenerationOptions
+    ) async throws -> MultiCodeGenerationResult {
+        var timings = SpeechTimings()
+        let mcdCache = try KVCache(
+            cacheDim: kvCacheEmbedDim,
+            maxSeqLength: kvCacheMaxSequenceLength,
+            isStateful: isStateful
+        )
+
+        let frameState = makeState()
+        // Single scratch array reused by every decodeEmbedBuffer call below.
+        // NOTE(review): sized from hiddenStates.count — assumes code0Embed and all
+        // multi-code embeddings share this dimension; confirm against the embedder models.
+        let embedDim = hiddenStates.count
+        let reuseArray = try MLMultiArray(shape: [1, NSNumber(value: embedDim), 1, 1], dataType: .float16)
+
+        // Prefill step 1: hiddenStates from CodeDecoder
+        let prefillStart = CFAbsoluteTimeGetCurrent()
+        var mcdOutput = try await decodeEmbedBuffer(hiddenStates, reuseArray: reuseArray, cache: mcdCache, state: frameState)
+        timings.multiCodeDecoderPredictions += CFAbsoluteTimeGetCurrent() - prefillStart
+        timings.totalMultiCodeDecoderPredictions += 1
+
+        let cacheUpdateStart = CFAbsoluteTimeGetCurrent()
+        guard let prefillKeyUpdates = mcdOutput.keyCacheUpdates,
+              let prefillValueUpdates = mcdOutput.valueCacheUpdates
+        else {
+            throw TTSError.generationFailed("MultiCodeDecoder: missing cache updates after prefill step 1")
+        }
+        mcdCache.update(keyCacheUpdates: prefillKeyUpdates, valueCacheUpdates: prefillValueUpdates)
+        timings.decodingKvCaching += CFAbsoluteTimeGetCurrent() - cacheUpdateStart
+
+        // Prefill step 2: code0 embedding
+        // (Its cache updates are applied at the top of the first loop iteration below.)
+        let prefill2Start = CFAbsoluteTimeGetCurrent()
+        mcdOutput = try await decodeEmbedBuffer(code0Embed, reuseArray: reuseArray, cache: mcdCache, state: frameState)
+        timings.multiCodeDecoderPredictions += CFAbsoluteTimeGetCurrent() - prefill2Start
+        timings.totalMultiCodeDecoderPredictions += 1
+
+        var stepIndex = 0
+        let samplingStart = CFAbsoluteTimeGetCurrent()
+        // Head 0 is sampled from the logits produced by prefill step 2.
+        let code1 = await sampler.sampleMultiHead(
+            allLogits: mcdOutput.allLogits,
+            headIndex: stepIndex,
+            temperature: options.temperature,
+            topK: options.topK
+        )
+        timings.multiCodeDecoderSampling += CFAbsoluteTimeGetCurrent() - samplingStart
+        var codes: [Int32] = [code1]
+
+        var offsetCodeEmbeds: [[FloatType]] = []
+        offsetCodeEmbeds.reserveCapacity(14)
+
+        for _ in 0..<14 {
+            // Apply the cache updates from the previous prediction before decoding again.
+            let cacheStep = CFAbsoluteTimeGetCurrent()
+            guard let loopKeyUpdates = mcdOutput.keyCacheUpdates,
+                  let loopValueUpdates = mcdOutput.valueCacheUpdates
+            else {
+                throw TTSError.generationFailed("MultiCodeDecoder: missing cache updates in generation loop")
+            }
+            mcdCache.update(keyCacheUpdates: loopKeyUpdates, valueCacheUpdates: loopValueUpdates)
+            timings.decodingKvCaching += CFAbsoluteTimeGetCurrent() - cacheStep
+
+            let embeddingStart = CFAbsoluteTimeGetCurrent()
+            guard let lastCode = codes.last else {
+                throw TTSError.generationFailed("MultiCodeDecoder: codes array is empty in legacy loop")
+            }
+            // Offset the raw code into the flat multi-codebook vocabulary for this head.
+            let offsetId = lastCode + Int32(codecVocabSize * stepIndex)
+            let codeEmbedBuf = try await multiCodeEmbedder.embed(tokenId: offsetId)
+            offsetCodeEmbeds.append(codeEmbedBuf)
+            timings.multiCodeDecoderEmbedding += CFAbsoluteTimeGetCurrent() - embeddingStart
+
+            let decodingStart = CFAbsoluteTimeGetCurrent()
+            mcdOutput = try await decodeEmbedBuffer(codeEmbedBuf, reuseArray: reuseArray, cache: mcdCache, state: frameState)
+            timings.multiCodeDecoderPredictions += CFAbsoluteTimeGetCurrent() - decodingStart
+            timings.totalMultiCodeDecoderPredictions += 1
+
+            stepIndex += 1
+
+            let nextSamplingStart = CFAbsoluteTimeGetCurrent()
+            let code = await sampler.sampleMultiHead(
+                allLogits: mcdOutput.allLogits,
+                headIndex: stepIndex,
+                temperature: options.temperature,
+                topK: options.topK
+            )
+            timings.multiCodeDecoderSampling += CFAbsoluteTimeGetCurrent() - nextSamplingStart
+            codes.append(code)
+        }
+
+        return MultiCodeGenerationResult(
+            codes: codes,
+            timings: timings,
+            offsetCodeEmbeds: offsetCodeEmbeds
+        )
+    }
+
+    /// Copy `embed` into a pre-allocated reuse array and run a single decode step.
+    /// Reusing `reuseArray` avoids per-step MLMultiArray allocation.
+    /// - Parameters:
+    ///   - embed: Embedding values copied into `reuseArray` before prediction.
+    ///   - reuseArray: Scratch MLMultiArray; its contents are overwritten on every call.
+    ///   - cache: KV cache threaded through successive decode steps.
+    ///   - state: Opaque per-frame model state passed straight through to `decode`.
+    func decodeEmbedBuffer(
+        _ embed: [FloatType],
+        reuseArray: MLMultiArray,
+        cache: KVCache,
+        state: Any?
+    ) async throws -> MultiCodeDecoderOutput {
+        // NOTE(review): no bounds check here — assumes embed.count never exceeds the
+        // capacity reuseArray was allocated with; confirm at the call sites.
+        let reusePointer = reuseArray.dataPointer.bindMemory(to: FloatType.self, capacity: embed.count)
+        embed.withUnsafeBufferPointer { src in
+            guard let baseAddress = src.baseAddress else { return }
+            reusePointer.update(from: baseAddress, count: embed.count)
+        }
+        return try await decode(inputEmbeds: reuseArray, cache: cache, state: state)
+    }
+
+    /// Release the CoreML model so its weights can be freed.
+    public func unloadModel() {
+        model = nil
+    }
+}
diff --git a/Sources/TTSKit/Qwen3TTS/Qwen3SpeechDecoder.swift b/Sources/TTSKit/Qwen3TTS/Qwen3SpeechDecoder.swift
new file mode 100644
index 00000000..e9ae3789
--- /dev/null
+++ b/Sources/TTSKit/Qwen3TTS/Qwen3SpeechDecoder.swift
@@ -0,0 +1,187 @@
+// For licensing see accompanying LICENSE.md file.
+// Copyright © 2026 Argmax, Inc. All rights reserved.
+
+import Accelerate
+import ArgmaxCore
+import CoreML
+import Foundation
+
+// MARK: - Implementation
+
+/// RVQ-to-audio waveform decoder backed by a CoreML model.
+///
+/// Thread safety: mutable state (`model`, dimension properties) is set once during
+/// `loadModel()` and read-only thereafter. `MLModel.prediction()` is thread-safe.
+public class Qwen3SpeechDecoder: SpeechDecoding, @unchecked Sendable {
+    public var model: MLModel?
+
+    // MARK: - Audio format
+
+    public let sampleRate: Int = Qwen3TTSConstants.sampleRate
+    public let samplesPerFrame: Int = Qwen3TTSConstants.samplesPerFrame
+    /// Minimum pre-buffer: 80ms ≈ 2 audio frames at 24 kHz / 1920 spf.
+    public let minimumBufferDuration: TimeInterval = 0.08
+
+    /// Detected from model metadata at load time
+    public private(set) var hiddenContextLen: Int = Qwen3TTSConstants.sdHiddenContextLen
+    /// KV cache embedding dimension
+    public private(set) var kvCacheEmbedDim: Int = Qwen3TTSConstants.sdCacheDim
+    /// KV cache max sequence length
+    public private(set) var kvCacheMaxSequenceLength: Int = Qwen3TTSConstants.sdMaxSeq
+    /// Hidden state dimension
+    public private(set) var hiddenDim: Int = Qwen3TTSConstants.sdHiddenDim
+    public init() {}
+
+    /// Load and configure the CoreML speech decoder.
+    /// - Parameters:
+    ///   - url: Location of the compiled `.mlmodelc` bundle.
+    ///   - computeUnits: CoreML compute units to run on.
+    ///   - prewarmMode: When `true`, compile only and discard the model to cap memory.
+    public func loadModel(at url: URL, computeUnits: MLComputeUnits, prewarmMode: Bool = false) async throws {
+        let modelConfig = MLModelConfiguration()
+        modelConfig.computeUnits = computeUnits
+        let loaded = try await MLModel.load(contentsOf: url, configuration: modelConfig)
+
+        // In prewarm mode, compilation is complete - discard to free memory before next model compiles
+        guard !prewarmMode else { return }
+
+        self.model = loaded
+
+        // Detect dimensions from model description
+        // hidden_context input shape: [1, hiddenDim, 1, contextLen]
+        if let dim = ModelUtilities.getModelInputDimension(model, named: "hidden_context", position: 1) {
+            self.hiddenDim = dim
+        }
+        if let ctxLen = ModelUtilities.getModelInputDimension(model, named: "hidden_context", position: 3) {
+            self.hiddenContextLen = ctxLen
+        }
+        // key_cache input shape: [1, cacheDim, 1, maxSeqLen]
+        if let dim = ModelUtilities.getModelInputDimension(model, named: "key_cache", position: 1) {
+            self.kvCacheEmbedDim = dim
+        }
+        if let seq = ModelUtilities.getModelInputDimension(model, named: "key_cache", position: 3) {
+            self.kvCacheMaxSequenceLength = seq
+        }
+    }
+
+    /// Decode one frame of RVQ codes into PCM samples (legacy MLMultiArray path).
+    public func decodeFrame(
+        codes: [Int32],
+        cache: SpeechDecoderCache
+    ) async throws -> [Float] {
+        guard let model else {
+            throw TTSError.generationFailed("SpeechDecoder model not loaded")
+        }
+
+        // NOTE(review): fixed 16-codebook layout; assumes codes.count >= 16 — confirm callers.
+        let codesArr = try MLMultiArray(shape: [1, 16, 1], dataType: .int32)
+        let codesPtr =
codesArr.dataPointer.bindMemory(to: Int32.self, capacity: 16) + for i in 0..<16 { + codesPtr[i] = codes[i] + } + + guard let keyCache = cache.keyCache, let valueCache = cache.valueCache else { + throw TTSError.generationFailed("SpeechDecoder: KV cache not initialized") + } + let input = try MLDictionaryFeatureProvider(dictionary: [ + "audio_codes": MLFeatureValue(multiArray: codesArr), + "cache_length": MLFeatureValue(multiArray: cache.makeCacheLengthArray()), + "key_cache": MLFeatureValue(multiArray: keyCache), + "value_cache": MLFeatureValue(multiArray: valueCache), + "kv_cache_update_mask": MLFeatureValue(multiArray: cache.kvCacheUpdateMask), + "key_padding_mask": MLFeatureValue(multiArray: cache.keyPaddingMask), + "hidden_context": MLFeatureValue(multiArray: cache.hiddenContext) + ]) + + let output = try await model.asyncPrediction(from: input) + cache.updateWithHiddenContext(output: output) + + guard let audioArr = output.featureValue(for: "audio")?.multiArrayValue else { + throw TTSError.generationFailed("SpeechDecoder: missing audio output array") + } + let sampleCount = audioArr.count + let audioPtr = audioArr.dataPointer.bindMemory(to: FloatType.self, capacity: sampleCount) + var samples = [Float](repeating: 0, count: sampleCount) + for i in 0.. 
SpeechDecoderTimedResult { + guard let model else { + throw TTSError.generationFailed("SpeechDecoder model not loaded") + } + + var timings = SpeechTimings() + + // TODO: Remove forking logic with package with min os version upgrade + if #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *) { + guard let keyCacheTensor = cache.keyCacheTensor, + let valueCacheTensor = cache.valueCacheTensor + else { + throw TTSError.generationFailed("SpeechDecoder: KV cache tensors not initialized") + } + let inputs: [String: MLTensor] = [ + "audio_codes": MLTensor(shape: [1, codes.count, 1], scalars: codes), + "cache_length": cache.cacheLengthTensor, + "key_cache": keyCacheTensor, + "value_cache": valueCacheTensor, + "kv_cache_update_mask": cache.kvCacheUpdateMaskTensor, + "key_padding_mask": cache.keyPaddingMaskTensor, + "hidden_context": cache.hiddenContextTensor + ] + + let predictionStart = CFAbsoluteTimeGetCurrent() + let outputs = try await model.prediction(from: inputs) + timings.speechDecoderPredictions += CFAbsoluteTimeGetCurrent() - predictionStart + + await cache.updateWithHiddenContext(tensorOutputs: outputs) + + guard let audioTensor = outputs["audio"] else { + throw TTSError.generationFailed("SpeechDecoder: missing audio tensor output") + } + let samples = await audioTensor.toFloatArray() + return SpeechDecoderTimedResult(samples: samples, timings: timings) + } else { + let codesArr = try MLMultiArray(shape: [1, 16, 1], dataType: .int32) + let codesPtr = codesArr.dataPointer.bindMemory(to: Int32.self, capacity: 16) + for i in 0..<16 { + codesPtr[i] = codes[i] + } + guard let keyCacheFallback = cache.keyCache, let valueCacheFallback = cache.valueCache else { + throw TTSError.generationFailed("SpeechDecoder: KV cache not initialized (legacy path)") + } + let input = try MLDictionaryFeatureProvider(dictionary: [ + "audio_codes": MLFeatureValue(multiArray: codesArr), + "cache_length": MLFeatureValue(multiArray: cache.makeCacheLengthArray()), + "key_cache": 
MLFeatureValue(multiArray: keyCacheFallback), + "value_cache": MLFeatureValue(multiArray: valueCacheFallback), + "kv_cache_update_mask": MLFeatureValue(multiArray: cache.kvCacheUpdateMask), + "key_padding_mask": MLFeatureValue(multiArray: cache.keyPaddingMask), + "hidden_context": MLFeatureValue(multiArray: cache.hiddenContext) + ]) + + let predictionStart = CFAbsoluteTimeGetCurrent() + let output = try await model.asyncPrediction(from: input, options: MLPredictionOptions()) + timings.speechDecoderPredictions += CFAbsoluteTimeGetCurrent() - predictionStart + + cache.updateWithHiddenContext(output: output) + guard let audioArr = output.featureValue(for: "audio")?.multiArrayValue else { + throw TTSError.generationFailed("SpeechDecoder: missing audio output array (legacy path)") + } + let sampleCount = audioArr.count + let audioPtr = audioArr.dataPointer.bindMemory(to: FloatType.self, capacity: sampleCount) + var samples = [Float](repeating: 0, count: sampleCount) + for i in 0.. [FloatType] { + guard let model else { throw TTSError.generationFailed("TextProjector model not loaded") } + let ids = try EmbedUtilities.makeInt32Array([tokenId]) + let provider = try MLDictionaryFeatureProvider(dictionary: ["input_ids": MLFeatureValue(multiArray: ids)]) + let output = try await model.asyncPrediction(from: provider) + guard let embedArray = output.featureValue(for: "input_embeds")?.multiArrayValue else { + throw TTSError.generationFailed("TextProjector: missing input_embeds output") + } + return EmbedUtilities.extractEmbed(from: embedArray) + } + + /// Optimised async path: passes `[String: MLTensor]` directly - no FeatureProvider boxing. 
+    @available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *)
+    public func project(tokenId: Int32) async throws -> MLTensor {
+        guard let model else { throw TTSError.generationFailed("TextProjector model not loaded") }
+        // Single-token batch; the model emits the projected embedding as "input_embeds".
+        let outputs = try await model.prediction(from: [
+            "input_ids": MLTensor(shape: [1], scalars: [tokenId])
+        ])
+        guard let embedTensor = outputs["input_embeds"] else {
+            throw TTSError.generationFailed("TextProjector: missing input_embeds tensor output")
+        }
+        return embedTensor
+    }
+
+    /// Release the CoreML model so its weights can be freed.
+    public func unloadModel() {
+        model = nil
+    }
+}
diff --git a/Sources/TTSKit/TTSKit.swift b/Sources/TTSKit/TTSKit.swift
new file mode 100644
index 00000000..d053234c
--- /dev/null
+++ b/Sources/TTSKit/TTSKit.swift
@@ -0,0 +1,1147 @@
+// For licensing see accompanying LICENSE.md file.
+// Copyright © 2026 Argmax, Inc. All rights reserved.
+
+import ArgmaxCore
+import CoreML
+import Foundation
+import Hub
+import Tokenizers
+import os
+
+// MARK: - Callback Typealiases
+
+// MARK: - TTSKit
+
+/// Generic TTS orchestrator: text chunking, concurrent generation, crossfade, and audio playback.
+///
+/// Following the WhisperKit pattern, `TTSKit` exposes each model component as a
+/// protocol-typed `public var`. Swap any component at runtime to change behaviour:
+/// ```swift
+/// let config = TTSKitConfig(load: false)
+/// let tts = try await TTSKit(config)
+/// tts.codeDecoder = MyOptimisedCodeDecoder()
+/// try await tts.loadModels()
+/// ```
+///
+/// The default implementation uses Qwen3 TTS components (`Sources/TTSKit/Qwen3TTS/`).
+/// Components from entirely different model families can be plugged in by conforming
+/// to the same component protocols, or by implementing `SpeechModel` directly.
+///
+/// `setupGenerateTask(...)` returns an `any SpeechGenerating` - override it to use a
+/// completely different generation algorithm while keeping the chunking, concurrency,
+/// crossfade, and playback orchestration provided by `generate` and `play`.
+/// Mirrors `WhisperKit.setupTranscribeTask(...)`. +open class TTSKit: @unchecked Sendable { + // MARK: - Model components (protocol-typed, swappable) + + /// Text token -> embedding. Conforms to `TextProjecting`. + public var textProjector: any TextProjecting = Qwen3TextProjector() + /// Codec-0 token -> embedding. Conforms to `CodeEmbedding`. + public var codeEmbedder: any CodeEmbedding = Qwen3CodeEmbedder() + /// Multi-code token -> embedding. Conforms to `MultiCodeEmbedding`. + public var multiCodeEmbedder: any MultiCodeEmbedding = Qwen3MultiCodeEmbedder() + /// Autoregressive code-0 decoder. Conforms to `CodeDecoding`. + public var codeDecoder: any CodeDecoding = Qwen3CodeDecoder() + /// Per-frame decoder. Conforms to `MultiCodeDecoding`. + public var multiCodeDecoder: any MultiCodeDecoding = Qwen3MultiCodeDecoder() + /// RVQ codes -> audio waveform. Conforms to `SpeechDecoding`. + public var speechDecoder: any SpeechDecoding = Qwen3SpeechDecoder() + /// Tokenizer. `nil` before the first `loadModels()` call or after `unloadModels()`. + public var tokenizer: (any Tokenizer)? + + // MARK: - Model state + + /// Current lifecycle state of the loaded models. + /// Mirrors `WhisperKit.modelState`. Transitions: + /// `.unloaded` -> `.downloading` -> `.downloaded` -> `.loading` -> `.loaded` + /// `.unloaded` -> `.prewarming` -> `.prewarmed` + public private(set) var modelState: ModelState = .unloaded { + didSet { modelStateCallback?(oldValue, modelState) } + } + + // MARK: - Configuration & timing + + public var config: TTSKitConfig + + /// Direct accessor for the resolved local model folder. + /// + /// Mirrors `WhisperKit.modelFolder`. Backed by `config.modelFolder`; set by + /// `setupModels()` and may also be assigned directly for offline usage. + public var modelFolder: URL? { + get { config.modelFolder } + set { config.modelFolder = newValue } + } + + /// Whether to use a background `URLSession` for model downloads. 
+ /// + /// Mirrors `WhisperKit.useBackgroundDownloadSession`. Backed by + /// `config.useBackgroundDownloadSession`. + public var useBackgroundDownloadSession: Bool { + get { config.useBackgroundDownloadSession } + set { config.useBackgroundDownloadSession = newValue } + } + + /// Cumulative timings for the most recent pipeline run. + /// `modelLoading` and `tokenizerLoadTime` are populated after `loadModels()`. + public private(set) var currentTimings = SpeechTimings() + + /// Wall-clock seconds for the most recent full model load. + public var modelLoadTime: TimeInterval { currentTimings.modelLoading } + /// Wall-clock seconds for the most recent tokenizer load. + public var tokenizerLoadTime: TimeInterval { currentTimings.tokenizerLoadTime } + + // MARK: - Audio output + + /// Audio output used by `play`. + /// `AudioOutput` is playback-only; WhisperKit's `AudioProcessor` is capture-only. + /// They serve complementary roles and do not need to be merged. + public let audioOutput = AudioOutput() + + // MARK: - Prompt cache + + /// Cached prefix state for the most recently used voice/language/instruction. + /// Automatically built on the first `generate` call and reused for subsequent + /// calls with the same parameters. Set to `nil` to force a full prefill. + public var promptCache: TTSPromptCache? + + // MARK: - Callbacks + + /// Invoked whenever `modelState` changes. + public var modelStateCallback: ModelStateCallback? + + // MARK: - Seed + + public let seed: UInt64? + private var taskCounter: UInt64 = 0 + + // MARK: - Initialization + + /// Create a `TTSKit` instance from a `TTSKitConfig`. + /// + /// Uses the component overrides in `config` if set; otherwise instantiates the default + /// components for the selected model family. Components can also be replaced after init. + /// + /// - Parameter config: Pipeline configuration (model variant, paths, compute units, + /// component overrides, lifecycle flags). 
+    /// - Throws: `TTSError` if the model family is unsupported or component instantiation fails.
+    public init(_ config: TTSKitConfig = TTSKitConfig()) async throws {
+        self.config = config
+        self.seed = config.seed
+
+        // Verbose off silences all TTSKit logging regardless of the configured level.
+        Logging.shared.logLevel = config.verbose ? config.logLevel : .none
+
+        setupPipeline(for: config.model, config: config)
+
+        // Resolve or download the model folder so that the load condition below
+        // can use `config.modelFolder != nil` as its auto-load sentinel.
+        try await setupModels(
+            model: config.model,
+            downloadBase: config.downloadBase,
+            modelRepo: config.modelRepo,
+            modelToken: config.modelToken,
+            modelFolder: config.modelFolder,
+            download: config.download,
+            endpoint: config.modelEndpoint
+        )
+
+        if let prewarm = config.prewarm, prewarm {
+            Logging.info("Prewarming models...")
+            try await prewarmModels()
+        }
+
+        // Load if explicitly requested, or if a local folder is now available
+        // (either provided directly or populated by setupModels above).
+        if config.load ?? (config.modelFolder != nil) {
+            Logging.info("Loading models...")
+            try await loadModels()
+        }
+    }
+
+    /// Convenience initializer that exposes all configuration fields as individual parameters.
+    ///
+    /// Mirrors `WhisperKit.init(model:modelFolder:...)`. Constructs a `TTSKitConfig` and
+    /// delegates to `init(_ config:)`.
+    public convenience init(
+        model: TTSModelVariant = .qwen3TTS_0_6b,
+        modelFolder: URL? = nil,
+        downloadBase: URL? = nil,
+        modelRepo: String = Qwen3TTSConstants.defaultModelRepo,
+        tokenizerFolder: URL? = nil,
+        modelToken: String? = nil,
+        computeOptions: ComputeOptions? = nil,
+        textProjector: (any TextProjecting)? = nil,
+        codeEmbedder: (any CodeEmbedding)? = nil,
+        multiCodeEmbedder: (any MultiCodeEmbedding)? = nil,
+        codeDecoder: (any CodeDecoding)? = nil,
+        multiCodeDecoder: (any MultiCodeDecoding)? = nil,
+        speechDecoder: (any SpeechDecoding)? = nil,
+        verbose: Bool = false,
+        logLevel: Logging.LogLevel = .debug,
+        prewarm: Bool? = nil,
+        load: Bool? = nil,
+        download: Bool = true,
+        useBackgroundDownloadSession: Bool = false,
+        seed: UInt64? = nil
+    ) async throws {
+        let config = TTSKitConfig(
+            model: model,
+            modelFolder: modelFolder,
+            downloadBase: downloadBase,
+            modelRepo: modelRepo,
+            tokenizerFolder: tokenizerFolder,
+            modelToken: modelToken,
+            computeOptions: computeOptions ?? ComputeOptions(),
+            verbose: verbose,
+            logLevel: logLevel,
+            useBackgroundDownloadSession: useBackgroundDownloadSession,
+            download: download,
+            prewarm: prewarm,
+            load: load,
+            seed: seed
+        )
+        // Component overrides are assigned after construction so the designated
+        // initializer's setupPipeline(for:config:) picks them up.
+        config.textProjector = textProjector
+        config.codeEmbedder = codeEmbedder
+        config.multiCodeEmbedder = multiCodeEmbedder
+        config.codeDecoder = codeDecoder
+        config.multiCodeDecoder = multiCodeDecoder
+        config.speechDecoder = speechDecoder
+        try await self.init(config)
+    }
+
+    // MARK: - Pipeline setup
+
+    /// Configure the model-specific component properties for the active model family.
+    ///
+    /// Uses the component overrides in `config` if set; otherwise instantiates the
+    /// default components for the given variant's model family. Called from `init` and
+    /// can be called again to reconfigure the pipeline for a different variant.
+    ///
+    /// Mirrors how WhisperKit configures its encoder/decoder components based on the
+    /// selected model.
+    open func setupPipeline(for variant: TTSModelVariant, config: TTSKitConfig) {
+        switch variant.family {
+        case .qwen3:
+            self.textProjector = config.textProjector ?? Qwen3TextProjector()
+            self.codeEmbedder = config.codeEmbedder ?? Qwen3CodeEmbedder()
+            self.multiCodeEmbedder = config.multiCodeEmbedder ?? Qwen3MultiCodeEmbedder()
+            self.codeDecoder = config.codeDecoder ?? Qwen3CodeDecoder()
+            self.multiCodeDecoder = config.multiCodeDecoder ?? Qwen3MultiCodeDecoder()
+            self.speechDecoder = config.speechDecoder ?? Qwen3SpeechDecoder()
+        }
+    }
+
+    // MARK: - Model Discovery
+
+    /// Returns the recommended model variant for the current platform.
+    ///
+    /// Mirrors `WhisperKit.recommendedModels()`.
+    public static func recommendedModels() -> TTSModelVariant {
+        return TTSModelVariant.defaultForCurrentPlatform
+    }
+
+    /// Fetch all available model variants from the HuggingFace Hub.
+    ///
+    /// Mirrors `WhisperKit.fetchAvailableModels(from:matching:downloadBase:token:)`.
+    ///
+    /// - Parameters:
+    ///   - repo: HuggingFace repo ID to query. Defaults to the standard Qwen3 TTS repo.
+    ///   - matching: Glob patterns to filter returned variant names. Defaults to `["*"]` (all variants).
+    ///   - downloadBase: Optional base URL for Hub downloads.
+    ///   - token: HuggingFace API token (or set `HF_TOKEN` env var).
+    ///   - endpoint: HuggingFace Hub endpoint URL.
+    /// - Returns: Display names of available model variants matching the given patterns.
+    /// - Throws: `TTSError` if the Hub request fails.
+    public static func fetchAvailableModels(
+        from repo: String = Qwen3TTSConstants.defaultModelRepo,
+        matching: [String] = ["*"],
+        downloadBase: URL? = nil,
+        token: String? = nil,
+        endpoint: String = Qwen3TTSConstants.defaultEndpoint
+    ) async throws -> [String] {
+        let hubApi = HubApi(downloadBase: downloadBase, hfToken: token, endpoint: endpoint)
+        let hubRepo = Hub.Repo(id: repo, type: .models)
+        let files = try await hubApi.getFilenames(from: hubRepo, matching: ["qwen3_tts/**"])
+        var variants: [String] = []
+        // A variant is "available" when its code_decoder version directory exists in the repo.
+        for variant in TTSModelVariant.allCases {
+            let prefix = "qwen3_tts/code_decoder/\(variant.versionDir)/"
+            if files.contains(where: { $0.hasPrefix(prefix) }) {
+                variants.append(variant.displayName)
+            }
+        }
+        // Fall back to all known variants when the repo listing yields nothing.
+        let allVariants = variants.isEmpty ? TTSModelVariant.allCases.map(\.displayName) : variants
+        // NOTE(review): element type annotation looks truncated in this extraction
+        // (expected `Set<String>`) — verify against the original diff.
+        var filteredVariants: Set = []
+        for glob in matching {
+            filteredVariants = filteredVariants.union(allVariants.matching(glob: glob))
+        }
+        return Array(filteredVariants)
+    }
+
+    // MARK: - Download
+
+    /// Download models for a specific variant from HuggingFace Hub.
+    ///
+    /// Mirrors `WhisperKit.download(variant:downloadBase:from:token:progressCallback:)`.
+    ///
+    /// - Parameters:
+    ///   - variant: The model variant to download.
+    ///   - downloadBase: Base URL for the local cache. Defaults to the Hub library default.
+    ///   - useBackgroundSession: Use a background `URLSession` for the download.
+    ///   - repo: HuggingFace repo ID. Defaults to the standard Qwen3 TTS repo.
+    ///   - token: HuggingFace API token (or set `HF_TOKEN` env var).
+    ///   - endpoint: HuggingFace Hub endpoint URL.
+    ///   - revision: Specific git revision (commit SHA, tag, or branch) to download.
+    ///   - additionalPatterns: Extra glob patterns to include alongside the default component patterns.
+    ///   - progressCallback: Optional closure receiving download progress updates.
+    /// - Returns: Local URL of the downloaded model folder.
+    /// - Throws: `TTSError` if the Hub download fails.
+    open class func download(
+        variant: TTSModelVariant = .defaultForCurrentPlatform,
+        downloadBase: URL? = nil,
+        useBackgroundSession: Bool = false,
+        from repo: String = Qwen3TTSConstants.defaultModelRepo,
+        token: String? = nil,
+        endpoint: String = Qwen3TTSConstants.defaultEndpoint,
+        revision: String? = nil,
+        additionalPatterns: [String] = [],
+        progressCallback: (@Sendable (Progress) -> Void)? = nil
+    ) async throws -> URL {
+        // Bundle the individual parameters into a config and delegate to download(config:).
+        let config = TTSKitConfig(
+            model: variant,
+            downloadBase: downloadBase,
+            modelRepo: repo,
+            modelToken: token,
+            modelEndpoint: endpoint,
+            downloadRevision: revision,
+            downloadAdditionalPatterns: additionalPatterns,
+            useBackgroundDownloadSession: useBackgroundSession
+        )
+        return try await download(config: config, progressCallback: progressCallback)
+    }
+
+    /// Download models using a full `TTSKitConfig`.
+    ///
+    /// Downloads only the files matching the configured component variants.
+    /// Files are cached locally by the Hub library.
+    ///
+    /// - Parameters:
+    ///   - config: Pipeline configuration containing `modelRepo`, `modelToken`,
+    ///     `downloadRevision`, `downloadAdditionalPatterns`, and variant settings.
+    ///   - progressCallback: Optional closure receiving download progress updates.
+    /// - Returns: Local URL of the downloaded model folder.
+    /// - Throws: `TTSError` if the Hub download fails.
+    open class func download(
+        config: TTSKitConfig = TTSKitConfig(),
+        progressCallback: (@Sendable (Progress) -> Void)? = nil
+    ) async throws -> URL {
+        let hubApi = HubApi(downloadBase: config.downloadBase, hfToken: config.modelToken, endpoint: config.modelEndpoint)
+        let hubRepo = Hub.Repo(id: config.modelRepo, type: .models)
+        let patterns = config.downloadPatterns + config.downloadAdditionalPatterns
+
+        do {
+            return try await hubApi.snapshot(
+                from: hubRepo,
+                revision: config.downloadRevision ?? "main",
+                matching: patterns
+            ) { progress in
+                progressCallback?(progress)
+            }
+        } catch {
+            // Wrap the Hub error with actionable context for the caller.
+            throw TTSError.generationFailed(
+                "Failed to download models from \(config.modelRepo). Check that the repo exists and you have access. Error: \(error.localizedDescription)"
+            )
+        }
+    }
+
+    // MARK: - Model lifecycle
+
+    /// Resolve the local model folder, downloading from HuggingFace Hub if needed.
+    ///
+    /// Mirrors `WhisperKit.setupModels(model:downloadBase:modelRepo:...)`. Populates
+    /// `config.modelFolder` so that `loadModels()` can be called immediately after.
+    /// Separated from `loadModels()` so callers can call setup once and load separately,
+    /// or override just the resolution logic without touching the load path.
+    ///
+    /// - Parameters:
+    ///   - model: Model variant to download. `nil` uses `config.model`.
+    ///   - downloadBase: Base URL for Hub cache. `nil` uses the Hub library default.
+    ///   - modelRepo: HuggingFace repo ID. `nil` uses `config.modelRepo`.
+    ///   - modelToken: HuggingFace API token. `nil` uses `config.modelToken`.
+    ///   - modelFolder: Explicit local folder URL. When non-nil the download is skipped.
+    ///   - download: When `true` and `modelFolder` is nil, download from the resolved repo.
+    ///   - endpoint: HuggingFace Hub endpoint URL. Defaults to `Qwen3TTSConstants.defaultEndpoint`.
+    /// - Throws: `TTSError` if the download fails or the model folder cannot be resolved.
+    open func setupModels(
+        model: TTSModelVariant? = nil,
+        downloadBase: URL? = nil,
+        modelRepo: String? = nil,
+        modelToken: String? = nil,
+        modelFolder: URL? = nil,
+        download: Bool,
+        endpoint: String = Qwen3TTSConstants.defaultEndpoint
+    ) async throws {
+        if let folder = modelFolder {
+            // An explicit local folder wins; no network activity.
+            config.modelFolder = folder
+        } else if download {
+            let resolvedModel = model ?? config.model
+            let resolvedRepo = modelRepo ?? config.modelRepo
+            let resolvedToken = modelToken ?? config.modelToken
+
+            let downloadConfig = TTSKitConfig(
+                model: resolvedModel,
+                downloadBase: downloadBase ?? config.downloadBase,
+                modelRepo: resolvedRepo,
+                modelToken: resolvedToken,
+                modelEndpoint: endpoint,
+                downloadRevision: config.downloadRevision,
+                downloadAdditionalPatterns: config.downloadAdditionalPatterns,
+                useBackgroundDownloadSession: config.useBackgroundDownloadSession
+            )
+
+            modelState = .downloading
+            Logging.info("Downloading models from \(resolvedRepo)...")
+            do {
+                let folder = try await TTSKit.download(config: downloadConfig) { progress in
+                    let percent = Int(progress.fractionCompleted * 100)
+                    Logging.debug(" Download: \(percent)%")
+                }
+                config.modelFolder = folder
+                modelState = .downloaded
+                Logging.info("Models cached at \(folder.path)")
+            } catch {
+                // Roll the state machine back so a retry starts from a clean slate.
+                modelState = .unloaded
+                throw TTSError.modelNotFound(
+                    "Model download failed. Check the repo name and try again. Error: \(error)"
+                )
+            }
+        }
+    }
+
+    /// Prewarm all CoreML models by compiling them sequentially, then discarding weights.
+    ///
+    /// Serializes CoreML compilation to cap peak memory. Call before `loadModels()` on
+    /// first launch or after a model update.
    /// Mirrors `WhisperKit.prewarmModels()`.
    open func prewarmModels() async throws {
        // Prewarm is just loadModels in compile-only mode: weights are discarded.
        try await loadModels(prewarmMode: true)
    }

    /// Load all models and the tokenizer.
    ///
    /// Expects `config.modelFolder` to be set (call `setupModels` first if needed).
    /// Mirrors `WhisperKit.loadModels(prewarmMode:)`.
    ///
    /// - Parameter prewarmMode: When `true`, compile models one at a time and discard weights
    ///   to limit peak memory (prewarm). When `false` (default), load all concurrently.
    /// - Throws: `TTSError` if model compilation or tokenizer loading fails.
    open func loadModels(prewarmMode: Bool = false) async throws {
        // Publish the in-progress state first so observers can react immediately.
        modelState = prewarmMode ? .prewarming : .loading

        // Per-component compute-unit preferences from the configuration.
        let embedUnits = config.computeOptions.embedderComputeUnits
        let cdUnits = config.computeOptions.codeDecoderComputeUnits
        let mcdUnits = config.computeOptions.multiCodeDecoderComputeUnits
        let sdUnits = config.computeOptions.speechDecoderComputeUnits

        // Fail fast when the model folder is unset or missing on disk; roll the
        // state machine back to .unloaded so a later retry starts clean.
        guard let modelFolder = config.modelFolder,
              FileManager.default.fileExists(atPath: modelFolder.path)
        else {
            modelState = .unloaded
            throw TTSError.modelNotFound(config.modelFolder?.path ?? "")
        }

        // Resolve all six component URLs. A nil result means the .mlmodelc bundle is
        // missing from disk - surface this immediately rather than crashing later.
        func requireURL(_ component: String, _ variant: String) throws -> URL {
            guard let url = config.modelURL(component: component, variant: variant) else {
                throw TTSError.invalidConfiguration(
                    "No .mlmodelc found at \(component)/\(config.versionDir)/\(variant) inside \(modelFolder.path)."
                )
            }
            return url
        }
        let tpURL = try requireURL("text_projector", config.textProjectorVariant)
        let ceURL = try requireURL("code_embedder", config.codeEmbedderVariant)
        let mceURL = try requireURL("multi_code_embedder", config.multiCodeEmbedderVariant)
        let cdURL = try requireURL("code_decoder", config.codeDecoderVariant)
        let mcdURL = try requireURL("multi_code_decoder", config.multiCodeDecoderVariant)
        let sdURL = try requireURL("speech_decoder", config.speechDecoderVariant)

        // Load tokenizer (skipped in prewarm - only CoreML compilation needed).
        if !prewarmMode {
            try await loadTokenizerIfNeeded()
        }

        // Load the six CoreML models.
        // Prewarm: sequential to serialize compilation -> lower peak memory.
        // Normal: concurrent since compiled artifacts are already cached.
        let modelLoadStart = CFAbsoluteTimeGetCurrent()

        if prewarmMode {
            Logging.info("Prewarming 6 CoreML models sequentially (serializing compilation)...")
            try await textProjector.loadModel(at: tpURL, computeUnits: embedUnits, prewarmMode: true)
            try await codeEmbedder.loadModel(at: ceURL, computeUnits: embedUnits, prewarmMode: true)
            try await multiCodeEmbedder.loadModel(at: mceURL, computeUnits: embedUnits, prewarmMode: true)
            try await codeDecoder.loadModel(at: cdURL, computeUnits: cdUnits, prewarmMode: true)
            try await multiCodeDecoder.loadModel(at: mcdURL, computeUnits: mcdUnits, prewarmMode: true)
            try await speechDecoder.loadModel(at: sdURL, computeUnits: sdUnits, prewarmMode: true)
            Logging.info(String(format: "Prewarm complete in %.2fs", CFAbsoluteTimeGetCurrent() - modelLoadStart))
            // NOTE(review): unlike the branch below, the prewarm path never records
            // currentTimings.modelLoading - confirm that is intentional.
            modelState = .prewarmed
        } else {
            Logging.info("Loading 6 CoreML models concurrently...")
            Logging.debug("  TextProjector: \(tpURL.lastPathComponent) compute: \(embedUnits.description)")
            Logging.debug("  CodeEmbedder: \(ceURL.lastPathComponent) compute: \(embedUnits.description)")
            Logging.debug("  MultiCodeEmbedder: \(mceURL.lastPathComponent) compute: \(embedUnits.description)")
            Logging.debug("  CodeDecoder: \(cdURL.lastPathComponent) (\(config.codeDecoderVariant), compute: \(cdUnits.description))")
            Logging.debug("  MultiCodeDecoder: \(mcdURL.lastPathComponent) (\(config.multiCodeDecoderVariant), compute: \(mcdUnits.description))")
            Logging.debug("  SpeechDecoder: \(sdURL.lastPathComponent) (\(config.speechDecoderVariant), compute: \(sdUnits.description))")

            // Structured concurrency: all six loads run in parallel; the tuple await
            // rethrows the first failure and implicitly cancels the siblings.
            async let loadTP: Void = textProjector.loadModel(at: tpURL, computeUnits: embedUnits)
            async let loadCE: Void = codeEmbedder.loadModel(at: ceURL, computeUnits: embedUnits)
            async let loadMCE: Void = multiCodeEmbedder.loadModel(at: mceURL, computeUnits: embedUnits)
            async let loadCD: Void = codeDecoder.loadModel(at: cdURL, computeUnits: cdUnits)
            async let loadMCD: Void = multiCodeDecoder.loadModel(at: mcdURL, computeUnits: mcdUnits)
            async let loadSD: Void = speechDecoder.loadModel(at: sdURL, computeUnits: sdUnits)
            _ = try await (loadTP, loadCE, loadMCE, loadCD, loadMCD, loadSD)

            currentTimings.modelLoading = CFAbsoluteTimeGetCurrent() - modelLoadStart

            // Sync audio output sample rate to the loaded speech decoder.
            audioOutput.configure(sampleRate: speechDecoder.sampleRate)

            // NOTE(review): `modelLoadTime` is declared outside this excerpt -
            // presumably it mirrors currentTimings.modelLoading; confirm.
            Logging.info(String(format: "Total model load: %.2fs", modelLoadTime))
            modelState = .loaded
        }
    }

    /// Load the tokenizer only if it has not been loaded yet.
    ///
    /// Mirrors `WhisperKit.loadTokenizerIfNeeded()`. Skips loading when `tokenizer` is
    /// already set, avoiding redundant network calls or file-system work on repeated
    /// `loadModels()` calls.
    open func loadTokenizerIfNeeded() async throws {
        guard tokenizer == nil else {
            Logging.debug("Tokenizer already loaded, skipping")
            return
        }
        self.tokenizer = try await loadTokenizer()
    }

    /// Load the tokenizer from `config.tokenizerSource`.
    ///
    /// Checks for a local `tokenizer.json` file first; falls back to downloading from
    /// the Hugging Face Hub if no local file is found.
    /// Updates `currentTimings.tokenizerLoadTime`.
    ///
    /// Override this method to plug in a custom tokenizer loading strategy (e.g. fully
    /// offline from a bundled path) without touching the rest of `loadModels()`.
    open func loadTokenizer() async throws -> any Tokenizer {
        let start = CFAbsoluteTimeGetCurrent()
        Logging.info("Loading tokenizer from \(config.tokenizerSource)...")
        // Treat the source as a local folder first; only fall back to the Hub
        // when no tokenizer.json exists at that path.
        let tokenizerURL = URL(fileURLWithPath: config.tokenizerSource)
        let tokenizer: any Tokenizer
        if FileManager.default.fileExists(atPath: tokenizerURL.appending(path: "tokenizer.json").path) {
            tokenizer = try await AutoTokenizer.from(modelFolder: tokenizerURL)
        } else {
            tokenizer = try await AutoTokenizer.from(pretrained: config.tokenizerSource)
        }
        currentTimings.tokenizerLoadTime = CFAbsoluteTimeGetCurrent() - start
        // NOTE(review): logs `tokenizerLoadTime` (declared outside this excerpt)
        // rather than the local delta - presumably they match; confirm.
        Logging.info(String(format: "Tokenizer loaded in %.2fs", tokenizerLoadTime))
        return tokenizer
    }

    /// Release all model weights and the tokenizer from memory.
    ///
    /// Mirrors `WhisperKit.unloadModels()`. Transitions through `.unloading` before
    /// reaching `.unloaded` so observers can distinguish the in-progress state.
    open func unloadModels() async {
        modelState = .unloading
        textProjector.unloadModel()
        codeEmbedder.unloadModel()
        multiCodeEmbedder.unloadModel()
        codeDecoder.unloadModel()
        multiCodeDecoder.unloadModel()
        speechDecoder.unloadModel()
        tokenizer = nil
        modelState = .unloaded
        Logging.info("Unloaded all models")
    }

    /// Reset all accumulated timing statistics.
    ///
    /// Mirrors `WhisperKit.clearState()`. Call between generation runs when you want
    /// fresh per-run timing data without triggering a full reload.
    open func clearState() {
        currentTimings = SpeechTimings()
    }

    deinit {
        // Capture only the audio output (never self) so the detached task can
        // outlive deinit and halt any in-flight playback without waiting.
        Task { [audioOutput] in
            await audioOutput.stopPlayback(waitForCompletion: false)
        }
    }

    /// Register a custom log sink for all `Logging` output from TTSKit.
    ///
    /// Mirrors `WhisperKit.loggingCallback(_:)`. Pass `nil` to restore the default
    /// print-based logger.
    open func loggingCallback(_ callback: Logging.LoggingCallback?) {
        Logging.shared.loggingCallback = callback
    }

    // MARK: - Prompt cache management

    /// Build a prompt cache for the given voice/language/instruction combination.
    ///
    /// Pre-computes the invariant prefix embeddings and prefills them through the
    /// CodeDecoder, returning a reusable cache that eliminates ~90% of prefill cost
    /// on subsequent `generate` calls.
    ///
    /// The cache is stored on `self.promptCache` for automatic reuse. Delegates to
    /// `Qwen3GenerateTask.buildPromptCache` on the task returned by
    /// `setupGenerateTask(...)`, so Qwen3 models get prompt caching automatically.
    ///
    /// - Parameters:
    ///   - voice: Voice/speaker identifier. `nil` uses the model's `defaultVoice`.
    ///   - language: Language identifier. `nil` uses the model's `defaultLanguage`.
    ///   - instruction: Optional style instruction prepended to the TTS prompt.
    /// - Returns: The built `TTSPromptCache` that can be passed to subsequent `generate` calls.
    /// - Throws: `TTSError` if the model is not loaded or prompt caching is unsupported.
    @discardableResult
    open func buildPromptCache(
        voice: String? = nil,
        language: String? = nil,
        instruction: String? = nil
    ) async throws -> TTSPromptCache {
        let task = try createTask()
        // Resolve model defaults before the family check so the cache key uses the
        // same resolution rules as generate(...).
        let resolvedVoice = voice ?? task.defaultVoice
        let resolvedLanguage = language ?? task.defaultLanguage
        guard let qwen3Task = task as? Qwen3GenerateTask else {
            throw TTSError.generationFailed("Prompt caching is not supported by this model family.")
        }
        let cache = try await qwen3Task.buildPromptCache(voice: resolvedVoice, language: resolvedLanguage, instruction: instruction)
        self.promptCache = cache
        return cache
    }

    /// Save the current prompt cache to disk under the model's embeddings directory.
    ///
    /// The file is saved under `embeddings/` inside the model folder, named from the
    /// cache's `cacheFileName` with a `.promptcache` extension.
    /// NOTE(review): the original comment's angle-bracketed path placeholders were
    /// lost to extraction; the wording above is reconstructed from `promptCacheURL(for:)`.
    public func savePromptCache() throws {
        // No cache built yet is not an error - saving is best-effort.
        guard let cache = promptCache else { return }
        guard let url = promptCacheURL(for: cache) else {
            throw TTSError.generationFailed("Cannot determine prompt cache path (modelFolder not set)")
        }
        try cache.save(to: url)
        Logging.info("Saved prompt cache to \(url.path)")
    }

    /// Load a prompt cache from disk if one exists for the given parameters.
    ///
    /// Returns `nil` if no cached file exists. Also stores the loaded cache
    /// on `self.promptCache` for automatic reuse.
    ///
    /// - Parameters:
    ///   - voice: Voice/speaker identifier.
    ///   - language: Language identifier.
    ///   - instruction: Optional style instruction.
    /// - Returns: The loaded cache, or `nil` if not found.
    @discardableResult
    public func loadPromptCache(
        voice: String,
        language: String,
        instruction: String? = nil
    ) -> TTSPromptCache? {
        // Throwaway cache value built purely to derive the on-disk file name
        // (promptCacheURL keys off cache.cacheFileName). The zeroed KV snapshot
        // fields are dummies and never read.
        let probe = TTSPromptCache(
            voice: voice, language: language, instruction: instruction,
            prefixLength: 0,
            kvSnapshot: KVCacheSnapshot(
                isStateful: false, cacheDim: 0, maxSeqLength: 0, cacheLength: 0,
                keyCacheData: Data(), valueCacheData: Data(),
                updateMaskData: Data(), paddingMaskData: Data()
            ),
            stateData: nil
        )
        guard let url = promptCacheURL(for: probe),
              FileManager.default.fileExists(atPath: url.path)
        else { return nil }
        do {
            let cache = try TTSPromptCache.load(from: url)
            self.promptCache = cache
            Logging.info("Loaded prompt cache from \(url.path)")
            return cache
        } catch {
            // Corrupt or incompatible cache files degrade gracefully to a rebuild.
            Logging.error("Failed to load prompt cache: \(error)")
            return nil
        }
    }

    // Cache file location: <modelFolder>/embeddings/<cacheFileName>.
    private func promptCacheURL(for cache: TTSPromptCache) -> URL? {
        guard let modelFolder = config.modelFolder else { return nil }
        return
            modelFolder
            .appendingPathComponent("embeddings")
            .appendingPathComponent(cache.cacheFileName)
    }

    // MARK: - Task factory

    /// Setup the generate task used for speech synthesis.
    /// Subclasses may override to provide custom behavior.
    ///
    /// Mirrors `WhisperKit.setupTranscribeTask(...)`. Model-agnostic params are passed
    /// explicitly; model-specific components are accessed from `self` (configured by
    /// `setupPipeline`).
    open func setupGenerateTask(
        currentTimings: SpeechTimings,
        progress: Progress,
        tokenizer: any Tokenizer,
        sampler: any TokenSampling
    ) throws -> any SpeechGenerating {
        switch config.model.family {
        case .qwen3:
            // All six components must be the Qwen3-specific concrete types;
            // a mismatch means setupPipeline wired the wrong components.
            guard let qwen3TextProjector = textProjector as? Qwen3TextProjector,
                  let qwen3CodeEmbedder = codeEmbedder as? Qwen3CodeEmbedder,
                  let qwen3MultiCodeEmbedder = multiCodeEmbedder as? Qwen3MultiCodeEmbedder,
                  let qwen3CodeDecoder = codeDecoder as? Qwen3CodeDecoder,
                  let qwen3MultiCodeDecoder = multiCodeDecoder as? Qwen3MultiCodeDecoder,
                  let qwen3SpeechDecoder = speechDecoder as? Qwen3SpeechDecoder
            else {
                throw TTSError.generationFailed("Qwen3 model family requires Qwen3-specific model components")
            }
            return Qwen3GenerateTask(
                textProjector: qwen3TextProjector,
                codeEmbedder: qwen3CodeEmbedder,
                multiCodeEmbedder: qwen3MultiCodeEmbedder,
                codeDecoder: qwen3CodeDecoder,
                multiCodeDecoder: qwen3MultiCodeDecoder,
                speechDecoder: qwen3SpeechDecoder,
                sampler: sampler,
                tokenizer: tokenizer,
                suppressTokenIds: Qwen3TTSConstants.suppressTokenIds,
                loadTimings: currentTimings,
                progress: progress
            )
        }
    }

    /// Create a fresh generation task with the guard/seed/counter boilerplate.
    ///
    /// Each call returns an independent task with its own sampler seed and per-task
    /// buffers. Delegates to `setupGenerateTask(...)` for the actual construction.
    open func createTask(progress: Progress? = nil) throws -> any SpeechGenerating {
        guard let tokenizer else {
            throw TTSError.tokenizerUnavailable("Tokenizer is not loaded. Call loadModels() before generating speech.")
        }
        // XOR the base seed with a per-task counter so concurrent chunk tasks get
        // distinct-but-deterministic sampling streams.
        let derivedSeed: UInt64? = seed.map { $0 ^ taskCounter }
        taskCounter += 1
        return try setupGenerateTask(
            currentTimings: currentTimings,
            progress: progress ?? Progress(),
            tokenizer: tokenizer,
            sampler: GreedyTokenSampler(seed: derivedSeed)
        )
    }

    // MARK: - Speech generation

    /// Synthesize speech from text and return the complete audio result.
    ///
    /// Mirrors `WhisperKit.transcribe(audioPath:decodeOptions:callback:)`.
    /// Handles text chunking, optional prompt caching, and concurrent multi-chunk generation.
    ///
    /// - Parameters:
    ///   - text: The text to synthesize.
    ///   - voice: Voice/speaker identifier. Format is model-specific (e.g. `"ryan"` for Qwen3 TTS).
    ///   - language: Language identifier. Format is model-specific (e.g. `"english"` for Qwen3 TTS).
    ///   - options: Sampling and generation options.
    ///   - callback: Optional per-step callback receiving decoded audio chunks.
    ///     Return `false` to cancel; `nil` or `true` to continue.
    /// - Returns: A `SpeechResult` containing the raw audio samples and timing breakdown.
    /// - Throws: `TTSError` if text is empty, models are not loaded, or generation fails.
    open func generate(
        text: String,
        voice: String? = nil,
        language: String? = nil,
        options: GenerationOptions = GenerationOptions(),
        callback: SpeechCallback = nil
    ) async throws -> SpeechResult {
        // Auto-load models if they have not been loaded yet, mirroring WhisperKit's
        // runTranscribeTask which calls loadModels() when modelState != .loaded.
        if modelState != .loaded {
            try await loadModels()
        }

        try Task.checkCancellation()

        // Create the primary task to resolve model-specific defaults for voice/language.
        // This task is also reused for the single-chunk fast path to avoid a second allocation.
        let primaryTask = try createTask()
        let resolvedVoice = voice ?? primaryTask.defaultVoice
        let resolvedLanguage = language ?? primaryTask.defaultLanguage

        // Build prompt cache ahead of time if none exists or current doesn't match.
        let cache: TTSPromptCache?
        if let existing = promptCache, existing.matches(voice: resolvedVoice, language: resolvedLanguage, instruction: options.instruction) {
            cache = existing
        } else if tokenizer != nil {
            cache = try await buildPromptCache(voice: resolvedVoice, language: resolvedLanguage, instruction: options.instruction)
        } else {
            cache = nil
        }

        // Split the text per the chunking strategy; fall back to a single chunk
        // when chunking is disabled or no tokenizer is available for sizing.
        let effectiveStrategy = options.chunkingStrategy ?? .sentence
        let textChunks: [String]
        if effectiveStrategy == .none || tokenizer == nil {
            textChunks = [text]
        } else {
            guard let tokenizer else {
                throw TTSError.tokenizerUnavailable("Tokenizer is not loaded. Call loadModels() before generating speech.")
            }
            let chunker = TextChunker(
                targetChunkSize: options.targetChunkSize ?? TextChunker.defaultTargetChunkSize,
                minChunkSize: options.minChunkSize ?? TextChunker.defaultMinChunkSize,
                tokenizer: tokenizer
            )
            let chunks = chunker.chunk(text)
            textChunks = chunks.isEmpty ? [text] : chunks
        }

        // Single-chunk fast path: reuse primaryTask (already allocated above).
        if textChunks.count == 1 {
            return try await primaryTask.run(
                text: textChunks[0],
                voice: resolvedVoice,
                language: resolvedLanguage,
                options: options,
                callback: callback,
                prefixCache: cache
            )
        }

        let workerDesc = options.concurrentWorkerCount == 0 ? "max" : "\(options.concurrentWorkerCount)"
        Logging.info("Chunked TTS: \(textChunks.count) chunks, concurrency=\(workerDesc)")
        for (i, chunk) in textChunks.enumerated() {
            let truncated = chunk.count > 60 ? "\(chunk.prefix(60))..." : chunk
            Logging.debug("  Chunk \(i): \"\(truncated)\" (\(chunk.count) chars)")
        }

        let pipelineStart = CFAbsoluteTimeGetCurrent()
        var combinedTimings = SpeechTimings()
        // Carry over one-time load costs so the combined result reports them.
        combinedTimings.modelLoading = currentTimings.modelLoading
        combinedTimings.tokenizerLoadTime = currentTimings.tokenizerLoadTime

        let crossfadeSamples = primaryTask.sampleRate / 10 // 100ms crossfade
        var chunkAudioArrays = [[Float]](repeating: [], count: textChunks.count)

        let totalChunks = textChunks.count

        // Progress budget assumes every chunk may use its full token allowance.
        let maxSteps = totalChunks * options.maxNewTokens

        if options.concurrentWorkerCount == 1 {
            // Sequential path: chunks generated in order, progress advances by a
            // full maxNewTokens budget per chunk regardless of actual steps used.
            var stepsSoFar = 0
            for (i, chunkText) in textChunks.enumerated() {
                Logging.debug(String(format: "  Generating chunk %d/%d...", i + 1, totalChunks))
                let chunkStepBase = stepsSoFar
                // Rewrite per-chunk progress into whole-run coordinates.
                let wrappedCallback: SpeechCallback = callback.map { cb in
                    { @Sendable progress in
                        var p = progress
                        p.chunkIndex = i
                        p.totalChunks = totalChunks
                        p.stepsCompleted = chunkStepBase + Int(progress.timings.totalDecodingLoops)
                        p.totalSteps = maxSteps
                        return cb(p)
                    }
                }
                let chunkResult = try await (createTask()).run(
                    text: chunkText, voice: resolvedVoice, language: resolvedLanguage,
                    options: options, callback: wrappedCallback, prefixCache: cache
                )
                stepsSoFar += options.maxNewTokens
                chunkAudioArrays[i] = chunkResult.audio
                combinedTimings.merge(chunkResult.timings)
                // Time-to-first-buffer is meaningful only for the first chunk.
                if i == 0 { combinedTimings.timeToFirstBuffer = chunkResult.timings.timeToFirstBuffer }
                Logging.debug(
                    String(
                        format: "  Chunk %d done: %.2fs audio (%d steps)",
                        i + 1, chunkResult.audioDuration, Int(chunkResult.timings.totalDecodingLoops))
                )
            }
        } else {
            // Concurrent path: chunks run in batches of `effectiveWorkers` tasks.
            let indexedChunks = textChunks.enumerated().map { (index: $0.offset, text: $0.element) }

            let effectiveWorkers = options.concurrentWorkerCount == 0 ? indexedChunks.count : options.concurrentWorkerCount

            let batchedChunks: [[(index: Int, text: String)]]
            batchedChunks = stride(from: 0, to: indexedChunks.count, by: effectiveWorkers).map {
                Array(indexedChunks[$0..<min($0 + effectiveWorkers, indexedChunks.count)])
            }

            // NOTE(review): the declarations between `batchedChunks` and the
            // `SpeechProgress` construction below were corrupted in the copy under
            // review (an extraction artifact swallowed the text). The code below is
            // reconstructed from the surviving usages (`taskItems`, `chunkCount`,
            // `stepCounter`, `workerCallback`, `unwrappedCallback`) - confirm
            // against version control before relying on it.
            for batch in batchedChunks {
                // One fresh task per chunk in this batch (independent seeds/buffers).
                let taskItems = try batch.map { (index: $0.index, text: $0.text, task: try createTask()) }
                let chunkCount = totalChunks
                // Shared step counter across the batch's workers; the exact lock
                // wrapper type was lost - it must expose `withLock` over an `inout Int`.
                let stepCounter = OSAllocatedUnfairLock(initialState: 0)
                // Progress across concurrent workers is approximate: one shared
                // counter incremented per step, regardless of which chunk stepped.
                let workerCallback: SpeechCallback = callback.map { unwrappedCallback in
                    { @Sendable progress in
                        let steps = stepCounter.withLock { state -> Int in
                            state += 1
                            return state
                        }
                        let stepProgress = SpeechProgress(
                            audio: [], timings: progress.timings,
                            totalChunks: chunkCount,
                            stepsCompleted: steps, totalSteps: maxSteps
                        )
                        return unwrappedCallback(stepProgress)
                    }
                }

                let maxNewTokens = options.maxNewTokens
                let batchResults: [(index: Int, result: SpeechResult)] = try await withThrowingTaskGroup(
                    of: (index: Int, result: SpeechResult).self
                ) { group in
                    for item in taskItems {
                        group.addTask {
                            Logging.debug(String(format: "  Starting chunk %d/%d...", item.index + 1, chunkCount))
                            let chunkResult = try await item.task.run(
                                text: item.text, voice: resolvedVoice, language: resolvedLanguage,
                                options: options, callback: workerCallback, prefixCache: cache
                            )
                            // Snap progress forward to the full budget for this chunk
                            let actualSteps = Int(chunkResult.timings.totalDecodingLoops)
                            let remaining = maxNewTokens - actualSteps
                            if remaining > 0 {
                                stepCounter.withLock { $0 += remaining }
                            }
                            Logging.debug(
                                String(
                                    format: "  Chunk %d done: %.2fs audio (%d steps)",
                                    item.index + 1, chunkResult.audioDuration, actualSteps))
                            return (index: item.index, result: chunkResult)
                        }
                    }
                    var results = [(index: Int, result: SpeechResult)]()
                    for try await result in group {
                        results.append(result)
                    }
                    return results
                }

                // Merge batch results back into index order.
                for entry in batchResults {
                    chunkAudioArrays[entry.index] = entry.result.audio
                    combinedTimings.merge(entry.result.timings)
                    if entry.index == 0 { combinedTimings.timeToFirstBuffer = entry.result.timings.timeToFirstBuffer }
                }
            }

            // Deliver audio in order via callback after concurrent batch completes.
            if let callback {
                for (i, chunkAudio) in chunkAudioArrays.enumerated() {
                    let progress = SpeechProgress(
                        audio: chunkAudio, timings: combinedTimings,
                        stepTime: i == 0 ? 0 : nil,
                        chunkIndex: i, totalChunks: totalChunks
                    )
                    if callback(progress) == false { break }
                }
            }
        }

        // Crossfade consecutive chunks and assemble final audio.
        let allAudio = AudioOutput.crossfade(chunkAudioArrays, fadeLength: crossfadeSamples)

        combinedTimings.fullPipeline = CFAbsoluteTimeGetCurrent() - pipelineStart
        let sampleRate = primaryTask.sampleRate
        combinedTimings.inputAudioSeconds = Double(allAudio.count) / Double(sampleRate)

        let steps = Int(combinedTimings.totalDecodingLoops)
        let avgMs = steps > 0 ? combinedTimings.decodingLoop * 1000 / Double(steps) : 0
        Logging.info(
            String(
                format: "Chunked TTS: %d chunks, %d steps, %.1fms avg/step, %.2fs audio",
                textChunks.count, steps, avgMs, Double(allAudio.count) / Double(sampleRate)
            ))

        return SpeechResult(audio: allAudio, timings: combinedTimings, sampleRate: sampleRate)
    }

    // MARK: - Play Speech

    /// Generate speech and stream it through the audio output in real time.
    ///
    /// Generates speech and plays it back.
    ///
    /// For streaming strategies (auto, stream, buffered) chunking is forced to
    /// sequential (`concurrentWorkerCount = 1`) so frames can be enqueued in
    /// order. `generateFirst` respects the caller's concurrency setting so the
    /// full file can be generated with parallel workers before playback begins.
    ///
    /// - Parameters:
    ///   - text: The text to synthesize.
    ///   - voice: Voice/speaker identifier.
    ///   - language: Language identifier.
    ///   - options: Sampling and generation options.
    ///   - playbackStrategy: Controls how audio is buffered before playback begins.
    ///   - callback: Optional per-step callback.
    /// - Returns: A `SpeechResult` with the complete audio and timing breakdown.
    /// - Throws: `TTSError` on generation failure or task cancellation.
+ open func play( + text: String, + voice: String? = nil, + language: String? = nil, + options: GenerationOptions = GenerationOptions(), + playbackStrategy: PlaybackStrategy = .auto, + callback: SpeechCallback = nil + ) async throws -> SpeechResult { + var playOptions = options + + let audioOut = audioOutput + let maxTokens = playOptions.maxNewTokens + + // Pre-resolve audio format from the task so the playback closure doesn't + // reach into model-specific components (keeps TTSKit model-agnostic). + let formatTask = try createTask() + let samplesPerFrame = formatTask.samplesPerFrame + let sampleRate = formatTask.sampleRate + let minBuffer = formatTask.minimumBufferDuration + + if case .generateFirst = playbackStrategy { + let result = try await generate( + text: text, voice: voice, language: language, + options: playOptions, callback: callback + ) + try audioOut.startPlayback() + audioOut.setBufferDuration(0) + audioOut.enqueueAudioChunk(result.audio) + await audioOut.stopPlayback(waitForCompletion: true) + return result + } + + // Streaming requires sequential generation to preserve chunk order. 
+ playOptions.concurrentWorkerCount = 1 + + try audioOut.startPlayback(deferEngineStart: true) + switch playbackStrategy { + case .stream: audioOut.setBufferDuration(0) + case let .buffered(secs): audioOut.setBufferDuration(secs) + case .auto: break + case .generateFirst: break + } + + let result = try await generate( + text: text, voice: voice, language: language, + options: playOptions, + callback: { progress in + if let stepTime = progress.stepTime, case .auto = playbackStrategy { + let buffer = PlaybackStrategy.requiredBuffer( + stepTime: stepTime, + maxNewTokens: maxTokens, + samplesPerFrame: samplesPerFrame, + sampleRate: sampleRate, + minimumBuffer: minBuffer + ) + audioOut.setBufferDuration(buffer) + let speedRatio = PlaybackStrategy.audioPerStep(samplesPerFrame: samplesPerFrame, sampleRate: sampleRate) / stepTime + Logging.info( + String( + format: "Playback: step %.1fms (%.2fx real-time) -> buffer %.2fs", + stepTime * 1000, speedRatio, buffer)) + } + audioOut.enqueueAudioChunk(progress.audio) + return callback?(progress) + } + ) + + await audioOut.stopPlayback(waitForCompletion: true) + return result + } + + // MARK: - Qwen3-typed convenience API + + /// Build a prompt cache using typed Qwen3 speaker and language enums. + /// + /// - Parameters: + /// - speaker: The `Qwen3Speaker` to pre-warm the cache for. + /// - language: The `Qwen3Language` to pre-warm the cache for. + /// - instruction: Optional style instruction (1.7B only). + /// - Returns: A `TTSPromptCache` for the given parameters. + /// - Throws: `TTSError` on generation failure. + @discardableResult + open func buildPromptCache( + speaker: Qwen3Speaker, + language: Qwen3Language, + instruction: String? = nil + ) async throws -> TTSPromptCache { + try await buildPromptCache( + voice: speaker.rawValue, + language: language.rawValue, + instruction: instruction + ) + } + + /// Generate speech from text using typed Qwen3 speaker and language enums. 
+ /// + /// - Parameters: + /// - text: Input text to synthesise. + /// - speaker: The `Qwen3Speaker` voice to use. + /// - language: The `Qwen3Language` to synthesise in. + /// - options: Generation options controlling sampling, chunking, and concurrency. + /// - callback: Per-step callback receiving decoded audio chunks. Return `false` to cancel. + /// - Returns: The assembled `SpeechResult`. + /// - Throws: `TTSError` on generation failure or task cancellation. + open func generate( + text: String, + speaker: Qwen3Speaker, + language: Qwen3Language = .english, + options: GenerationOptions = GenerationOptions(), + callback: SpeechCallback = nil + ) async throws -> SpeechResult { + try await generate( + text: text, + voice: speaker.rawValue, + language: language.rawValue, + options: options, + callback: callback + ) + } + + /// Generate speech and stream playback using typed Qwen3 speaker and language enums. + /// + /// - Parameters: + /// - text: Input text to synthesise. + /// - speaker: The `Qwen3Speaker` voice to use. + /// - language: The `Qwen3Language` to synthesise in. + /// - options: Generation options controlling sampling, chunking, and concurrency. + /// - playbackStrategy: Controls how much audio is buffered before playback begins. + /// - callback: Per-step callback receiving decoded audio chunks. Return `false` to cancel. + /// - Returns: The assembled `SpeechResult`. + /// - Throws: `TTSError` on generation failure or task cancellation. 
+ open func play( + text: String, + speaker: Qwen3Speaker, + language: Qwen3Language = .english, + options: GenerationOptions = GenerationOptions(), + playbackStrategy: PlaybackStrategy = .auto, + callback: SpeechCallback = nil + ) async throws -> SpeechResult { + try await play( + text: text, + voice: speaker.rawValue, + language: language.rawValue, + options: options, + playbackStrategy: playbackStrategy, + callback: callback + ) + } +} + +// MARK: - SpeechModel conformance + +extension TTSKit: SpeechModel { + /// The output sample rate of the currently loaded speech decoder. + public var sampleRate: Int { speechDecoder.sampleRate } +} diff --git a/Sources/TTSKit/Utilities/AudioOutput.swift b/Sources/TTSKit/Utilities/AudioOutput.swift new file mode 100644 index 00000000..4370c4cf --- /dev/null +++ b/Sources/TTSKit/Utilities/AudioOutput.swift @@ -0,0 +1,734 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright © 2026 Argmax, Inc. All rights reserved. + +import AVFoundation +import Accelerate +import ArgmaxCore +import Foundation + +// MARK: - Audio Output + +/// Handles audio export to file and real-time streaming playback via AVAudioEngine. +/// +/// Supports adaptive pre-buffering and edge-fading to prevent audible clicks: +/// +/// **Pre-buffering:** When a buffer duration is configured via `setBufferDuration(_:)`, +/// incoming audio frames accumulate until the threshold is reached, then flush to the +/// player all at once. This prevents underruns on slower devices. +/// +/// **Edge-fading:** Fades are applied only at actual audio discontinuities: +/// - Fade-in on the first frame of a session, chunk, or after a detected underrun. +/// - Fade-out on the last frame of a session, chunk, or before a detected underrun. +/// Interior frames of contiguous playback are untouched. 
///
/// Underrun detection uses wall-clock timing: if the current time exceeds the
/// expected playback end of all scheduled audio, the player has drained and the
/// next frame needs fade-in (and the previous tail gets fade-out).
///
/// **Buffer lifecycle:**
/// 1. `startPlayback()` - resets all state; frames accumulate until configured.
/// 2. `setBufferDuration(_:)` - configures threshold (call after start).
/// 3. `enqueueAudioChunk(_:)` - pushes frames through the buffer/tail pipeline.
/// 4. `stopPlayback()` - commits the tail with fade-out, waits, tears down.
///
/// Thread safety: used only from `TTSKit.play()` which forces sequential mode
/// (concurrency=1). `TTSKit` ensures serialized access.
/// Subclassing is intentionally not supported. Use the `TTSKit` component override
/// mechanism (`TTSKitConfig`) to swap in an alternative audio backend.
public class AudioOutput: @unchecked Sendable {
    // Engine + player exist only between startPlayback and stopPlayback.
    private var audioEngine: AVAudioEngine?
    private var playerNode: AVAudioPlayerNode?
    private var engineStartDeferred: Bool = false

    /// Pre-buffer threshold in seconds. `nil` means not yet configured - frames
    /// accumulate in `pendingFrames` until `setBufferDuration` is called.
    /// `0` means stream immediately. `> 0` means buffer that duration first.
    private var bufferDuration: TimeInterval?

    /// Accumulated frames waiting to be flushed to the player.
    private var pendingFrames: [[Float]] = []

    /// Total duration (seconds) of audio in `pendingFrames`.
    private var pendingDuration: TimeInterval = 0

    /// Whether the initial buffer threshold has been met and frames flow directly
    /// to the player without accumulating.
    private var bufferThresholdMet: Bool = false

    /// The most recently received frame, held back so we can apply fade-out
    /// when we learn it's the last frame before a gap or end of session.
    private var tailFrame: [Float]?

    /// Whether the next frame scheduled to the player needs a fade-in
    /// (start of session, start of chunk, or after an underrun).
    private var needsFadeIn: Bool = true

    /// Wall-clock time at which all currently scheduled audio is expected to
    /// finish playing. Used to detect underruns: if `now > expectedPlaybackEnd`,
    /// the player has drained and the next frame follows a gap.
    private var expectedPlaybackEnd: CFAbsoluteTime = 0

    /// Player time (seconds) when the first buffer was scheduled. The player
    /// node's clock starts at `.play()` but no audio plays until buffers are
    /// queued, so this offset must be subtracted from `playerTime.sampleTime`
    /// to get the true audio-out position.
    private var playbackTimeOffset: TimeInterval = 0

    /// Cumulative duration (seconds) of real audio that has been scheduled via
    /// `scheduleWithFades`. The silent sentinel buffer used for drain detection
    /// is not included. Used to clamp `currentPlaybackTime` so the reported
    /// position never advances into silence gaps between chunks or past the end
    /// of generated audio.
    public private(set) var scheduledAudioDuration: TimeInterval = 0

    /// Number of samples for the fade-in/fade-out ramp.
    /// 256 samples at 24kHz ≈ 10.7ms - imperceptible on contiguous audio
    /// but smoothly eliminates clicks at discontinuities.
    public static let fadeLengthSamples: Int = 256

    /// Output sample rate in Hz. Defaults to 24000 (Qwen3 TTS).
    /// Updated by `TTSKit.loadModels()` to match the loaded speech decoder's actual sample rate.
    public private(set) var sampleRate: Int

    /// The audio format used for playback and export (derived from `sampleRate`).
    public private(set) var audioFormat: AVAudioFormat

    public init(sampleRate: Int = 24000) {
        self.sampleRate = sampleRate
        // Format creation only fails for invalid parameter combinations, which
        // would be a programming error - crash loudly rather than limp along.
        guard let format = AVAudioFormat(commonFormat: .pcmFormatFloat32, sampleRate: Double(sampleRate), channels: 1, interleaved: false) else {
            preconditionFailure("AVAudioFormat init failed for PCM Float32 \(sampleRate)Hz mono - this is an invariant violation")
        }
        self.audioFormat = format
    }

    /// Update the sample rate to match the loaded speech decoder.
    /// Must be called before `startPlayback()`.
    public func configure(sampleRate newRate: Int) {
        // No-op when the rate is unchanged - avoids rebuilding the format.
        guard newRate != sampleRate else { return }
        sampleRate = newRate
        guard let format = AVAudioFormat(commonFormat: .pcmFormatFloat32, sampleRate: Double(newRate), channels: 1, interleaved: false) else {
            preconditionFailure("AVAudioFormat init failed for PCM Float32 \(newRate)Hz mono - this is an invariant violation")
        }
        audioFormat = format
    }

    /// Returns how many seconds of audio still need to accumulate in the pre-buffer before
    /// the next chunk flushes and playback resumes. Non-zero only while in buffering
    /// mode (`bufferThresholdMet == false` and a positive `bufferDuration` is set).
    public var silentBufferRemaining: TimeInterval {
        guard !bufferThresholdMet, let duration = bufferDuration, duration > 0 else { return 0 }
        return max(0, duration - pendingDuration)
    }

    /// Current playback position in seconds, based on the audio engine's render timeline.
    /// Returns 0 if the player is not active, no audio has been scheduled yet, or
    /// the player hasn't started rendering.
    ///
    /// Clamped to `scheduledAudioDuration` so the position never advances into
    /// silence gaps between chunks, the silent drain sentinel, or the hardware
    /// pipeline tail after the last real audio frame.
    public var currentPlaybackTime: TimeInterval {
        // Guard against the engine being torn down (stopPlayback nullifies these).
        // Checking audioEngine?.isRunning prevents accessing a detached node.
        guard expectedPlaybackEnd > 0,
              audioEngine?.isRunning == true,
              let player = playerNode,
              let nodeTime = player.lastRenderTime,
              nodeTime.isSampleTimeValid,
              let playerTime = player.playerTime(forNodeTime: nodeTime)
        else { return 0 }
        // Subtract the schedule-time offset, then clamp to real scheduled audio.
        let rawTime = max(0, Double(playerTime.sampleTime) / playerTime.sampleRate - playbackTimeOffset)
        return min(rawTime, scheduledAudioDuration)
    }

    // MARK: - Buffer Configuration

    /// Configure the pre-buffer duration. Call after `startPlayback()`.
    ///
    /// - If `seconds == 0`: immediately flushes any pending frames and switches
    ///   to direct streaming (no buffering).
    /// - If `seconds > 0`: sets the threshold. If enough audio has already
    ///   accumulated, flushes immediately.
    /// - Can be called multiple times (e.g., per-chunk reassessment). Any held
    ///   tail frame from the previous chunk is committed with fade-out first.
    ///
    /// - Parameter seconds: Duration of audio to accumulate before flushing.
    ///   Pass 0 for immediate streaming (fast devices).
    public func setBufferDuration(_ seconds: TimeInterval) {
        // Negative inputs are treated as "stream immediately".
        let duration = max(0, seconds)
        bufferDuration = duration

        // Commit any held tail as the last frame of the previous chunk
        if let tail = tailFrame {
            scheduleWithFades(tail, fadeIn: needsFadeIn, fadeOut: true)
            needsFadeIn = true // next chunk starts fresh with fade-in
            tailFrame = nil
        }

        if duration == 0 {
            bufferThresholdMet = true
            if !pendingFrames.isEmpty {
                flushPendingFrames()
            }
        } else {
            // Re-enter buffering mode so the new chunk accumulates frames
            // before flushing - lets the model catch up between chunks.
            bufferThresholdMet = false
            if pendingDuration >= duration {
                flushPendingFrames()
            }
        }
    }

    // MARK: - File Export

    /// Supported audio export formats.
public enum AudioFileFormat: String, Sendable {
    case m4a
    case wav

    /// File extension matching the raw value ("m4a" / "wav").
    public var fileExtension: String { rawValue }

    /// Resolve the effective format for the current platform.
    /// On watchOS, M4A is not supported so falls back to WAV with a warning.
    public static func resolve(_ preferred: AudioFileFormat = .m4a) -> AudioFileFormat {
        #if os(watchOS)
        if preferred == .m4a {
            Logging.info("[Warning] M4A export is not available on watchOS, falling back to WAV")
            return .wav
        }
        #endif
        return preferred
    }
}

/// Save audio samples to a file.
///
/// For M4A with metadata: writes PCM -> AAC to a temp file, then uses
/// `AVAssetExportSession` passthrough to remux with embedded metadata atoms
/// (no re-encode). For WAV or metadata-free M4A: writes directly.
/// On watchOS, `.m4a` automatically falls back to `.wav`.
///
/// Any extension already present in `filename` is stripped before writing.
/// The output format is resolved in this order: explicit `format` parameter,
/// then extension found in `filename` if it matches a supported format,
/// then `.m4a` as the default.
///
/// - Parameters:
///   - samples: Mono Float32 PCM samples.
///   - folder: Destination directory. Created if it doesn't exist.
///   - filename: File name, with or without extension.
///   - sampleRate: Sample rate in Hz.
///   - format: Output format. Inferred from `filename` extension when `nil`.
///   - metadataProvider: Optional metadata callback for items to embed into the file container for m4a formats.
/// - Returns: The URL of the written file.
/// - Throws: `TTSError` if audio encoding or export fails.
@discardableResult
public static func saveAudio(
    _ samples: [Float],
    toFolder folder: URL,
    filename: String,
    sampleRate: Int = 24000,
    format: AudioFileFormat? = nil,
    metadataProvider: (@Sendable () throws -> [AVMetadataItem])? = nil
) async throws -> URL {
    guard !samples.isEmpty else {
        throw TTSError.audioOutputFailed("No audio samples to export")
    }

    // Strip any extension the caller included; the resolved format decides it.
    let filenameURL = URL(fileURLWithPath: filename)
    let baseName = filenameURL.deletingPathExtension().lastPathComponent
    let inferredFormat = AudioFileFormat(rawValue: filenameURL.pathExtension.lowercased())
    let resolvedFormat = AudioFileFormat.resolve(format ?? inferredFormat ?? .m4a)
    let outputURL = folder
        .appendingPathComponent(baseName)
        .appendingPathExtension(resolvedFormat.fileExtension)

    if !FileManager.default.fileExists(atPath: folder.path) {
        try FileManager.default.createDirectory(at: folder, withIntermediateDirectories: true)
    }
    // Best-effort removal of a stale file; a missing file is not an error.
    try? FileManager.default.removeItem(at: outputURL)

    let pcmBuffer = try createPCMBuffer(from: samples, sampleRate: sampleRate)

    switch resolvedFormat {
    case .m4a:
        #if os(watchOS)
        // Defensive: resolve(_:) above already maps .m4a -> .wav on watchOS.
        throw TTSError.audioOutputFailed("M4A should have been resolved to WAV on watchOS")
        #else
        if let metadataProvider {
            let metadata = try metadataProvider()
            try await writeM4AWithMetadata(pcmBuffer, to: outputURL, sampleRate: sampleRate, metadata: metadata)
        } else {
            try writeM4A(pcmBuffer, to: outputURL, sampleRate: sampleRate)
        }
        #endif
    case .wav:
        try writeWAV(pcmBuffer, to: outputURL, sampleRate: sampleRate)
    }

    return outputURL
}

/// Return the playback duration of an audio file in seconds.
public static func duration(of url: URL) async throws -> TimeInterval {
    let asset = AVURLAsset(url: url)
    let cmDuration = try await asset.load(.duration)
    return CMTimeGetSeconds(cmDuration)
}

// MARK: - Crossfade assembly

/// Assemble multiple audio chunks into one array with equal-power crossfades at each boundary.
///
/// Uses `cos(t*pi/2)` fade-out and `sin(t*pi/2)` fade-in so that energy is
/// preserved through the overlap region.
Fade curves are pre-computed once + /// via Accelerate (`vDSP_vramp` + `vvcosf`/`vvsinf`) and reused at every + /// chunk boundary; the per-boundary mix uses `vDSP_vmul` + `vDSP_vma`. + /// + /// - Parameters: + /// - chunks: Ordered audio chunks to concatenate. + /// - fadeLength: Number of overlap samples for each crossfade. + /// - Returns: Single concatenated audio array with crossfades applied at chunk boundaries. + public static func crossfade(_ chunks: [[Float]], fadeLength: Int) -> [Float] { + guard !chunks.isEmpty else { return [] } + guard chunks.count > 1 else { return chunks[0] } + + let (fadeOut, fadeIn) = equalPowerCurves(length: fadeLength) + + var result = [Float]() + result.reserveCapacity(chunks.reduce(0) { $0 + $1.count }) + result.append(contentsOf: chunks[0]) + + for i in 1.. ([Float], [Float]) { + guard length > 0 else { return ([], []) } + + var ramp = [Float](repeating: 0, count: length) + var start: Float = 0 + var step = Float.pi / 2.0 / Float(max(length - 1, 1)) + vDSP_vramp(&start, &step, &ramp, 1, vDSP_Length(length)) + + var fadeOut = [Float](repeating: 0, count: length) + var fadeIn = [Float](repeating: 0, count: length) + var n = Int32(length) + vvcosf(&fadeOut, ramp, &n) + vvsinf(&fadeIn, ramp, &n) + + return (fadeOut, fadeIn) + } + + // MARK: - Internal helpers + + private static func createPCMBuffer(from samples: [Float], sampleRate: Int) throws -> AVAudioPCMBuffer { + guard + let pcmFormat = AVAudioFormat( + commonFormat: .pcmFormatFloat32, + sampleRate: Double(sampleRate), + channels: 1, + interleaved: false + ) + else { + throw TTSError.audioOutputFailed("Failed to create PCM format for sampleRate \(sampleRate)") + } + + guard let buffer = AVAudioPCMBuffer(pcmFormat: pcmFormat, frameCapacity: AVAudioFrameCount(samples.count)) else { + throw TTSError.audioOutputFailed("Failed to create PCM buffer") + } + buffer.frameLength = AVAudioFrameCount(samples.count) + + guard let channelData = buffer.floatChannelData else { + throw 
TTSError.audioOutputFailed("Failed to access buffer channel data") + } + samples.withUnsafeBufferPointer { src in + guard let srcBase = src.baseAddress else { return } + channelData[0].update(from: srcBase, count: samples.count) + } + + return buffer + } + + #if !os(watchOS) + /// Write PCM samples as AAC in an M4A container. + /// + /// `AVAudioFile` accepts PCM input via `write(from:)` and internally + /// encodes to AAC when the file settings specify a compressed format, + /// so no temp file or explicit `AVAudioConverter` is needed. + private static func writeM4A( + _ pcmBuffer: AVAudioPCMBuffer, + to url: URL, + sampleRate: Int + ) throws { + let aacSettings: [String: Any] = [ + AVFormatIDKey: kAudioFormatMPEG4AAC, + AVSampleRateKey: sampleRate, + AVNumberOfChannelsKey: 1, + AVEncoderBitRateKey: 64000, + AVEncoderAudioQualityKey: AVAudioQuality.medium.rawValue + ] + + let file = try AVAudioFile( + forWriting: url, + settings: aacSettings, + commonFormat: .pcmFormatFloat32, + interleaved: false + ) + try file.write(from: pcmBuffer) + } + + /// Write AAC/M4A with embedded metadata atoms. + /// + /// `AVAudioFile` cannot embed metadata, so this does a two-step write: + /// encode PCM -> AAC into a temp file, then use `AVAssetExportSession` + /// passthrough to remux the bitstream into the final file with metadata + /// atoms attached. No audio re-encoding occurs. + private static func writeM4AWithMetadata( + _ pcmBuffer: AVAudioPCMBuffer, + to url: URL, + sampleRate: Int, + metadata: [AVMetadataItem] + ) async throws { + let tempURL = url.deletingLastPathComponent() + .appendingPathComponent(UUID().uuidString) + .appendingPathExtension("m4a") + defer { try? 
FileManager.default.removeItem(at: tempURL) } + + try writeM4A(pcmBuffer, to: tempURL, sampleRate: sampleRate) + + let asset = AVURLAsset(url: tempURL) + guard let exportSession = AVAssetExportSession(asset: asset, presetName: AVAssetExportPresetPassthrough) else { + throw TTSError.audioOutputFailed("AVAssetExportSession could not be created for metadata export") + } + exportSession.outputURL = url + exportSession.outputFileType = .m4a + exportSession.metadata = metadata + + await exportSession.export() + + if let exportError = exportSession.error { + throw exportError + } + guard exportSession.status == .completed else { + throw TTSError.audioOutputFailed( + "AVAssetExportSession finished with unexpected status \(exportSession.status.rawValue)") + } + } + #endif + + private static func writeWAV( + _ pcmBuffer: AVAudioPCMBuffer, + to url: URL, + sampleRate: Int + ) throws { + let settings: [String: Any] = [ + AVFormatIDKey: kAudioFormatLinearPCM, + AVSampleRateKey: sampleRate, + AVNumberOfChannelsKey: 1, + AVLinearPCMBitDepthKey: 32, + AVLinearPCMIsBigEndianKey: false, + AVLinearPCMIsFloatKey: true + ] + + let file = try AVAudioFile( + forWriting: url, + settings: settings, + commonFormat: .pcmFormatFloat32, + interleaved: false + ) + try file.write(from: pcmBuffer) + } + + // MARK: - Streaming Playback + + /// Start the audio engine for streaming playback. + /// + /// Resets all buffering, fade, and timing state. After calling this, + /// configure the buffer threshold via `setBufferDuration(_:)`. + /// + /// - Parameter deferEngineStart: When `true`, the audio engine is created and + /// connected but not started. The engine will start automatically on the first + /// `enqueueAudioChunk` call. This avoids the render thread contending with + /// model predictions during the critical time-to-first-buffer path. + /// - Throws: `TTSError` if the audio engine fails to start. 
+ public func startPlayback(deferEngineStart: Bool = false) throws { + pendingFrames.removeAll() + pendingDuration = 0 + bufferThresholdMet = false + bufferDuration = nil + tailFrame = nil + needsFadeIn = true + expectedPlaybackEnd = 0 + playbackTimeOffset = 0 + scheduledAudioDuration = 0 + engineStartDeferred = false + + // On iOS, AVAudioEngine requires an active audio session with a playback + // category. Without this, engine.start() may silently fail or route to + // the wrong output (e.g., airpods instead of main speaker). + #if os(iOS) + let session = AVAudioSession.sharedInstance() + try session.setCategory(.playback, mode: .default, options: []) + try session.setActive(true) + #endif + + let engine = AVAudioEngine() + let player = AVAudioPlayerNode() + let format = audioFormat + + engine.attach(player) + engine.connect(player, to: engine.mainMixerNode, format: format) + + if deferEngineStart { + engineStartDeferred = true + } else { + try engine.start() + player.play() + } + + self.audioEngine = engine + self.playerNode = player + } + + /// Enqueue a chunk of audio samples for playback. + /// + /// In streaming mode, detects underruns via wall-clock timing: if the player + /// has drained since the last buffer, the held tail is committed with fade-out + /// (it was the last frame before the gap) and the incoming frame is marked for + /// fade-in. On contiguous playback, no fades are applied to interior frames. + public func enqueueAudioChunk(_ samples: [Float]) { + guard let engine = audioEngine, let player = playerNode else { return } + + if engineStartDeferred { + engineStartDeferred = false + do { + try engine.start() + player.play() + } catch { + Logging.error("AudioOutput: deferred engine start failed: \(error)") + return + } + } + + if bufferThresholdMet { + // Detect underrun: has the player drained since the last schedule? 
+ let playerDrained = CFAbsoluteTimeGetCurrent() > expectedPlaybackEnd + + if let tail = tailFrame { + // If player drained, this tail was the last frame before silence - + // apply fade-out to smooth the audio -> silence transition, and + // mark the next frame for fade-in (silence -> audio transition). + let fadeOut = playerDrained + let fadeIn = needsFadeIn || playerDrained + scheduleWithFades(tail, fadeIn: fadeIn, fadeOut: fadeOut) + needsFadeIn = playerDrained + } else if playerDrained { + needsFadeIn = true + } + tailFrame = samples + } else { + pendingFrames.append(samples) + pendingDuration += Double(samples.count) / Double(sampleRate) + + if let threshold = bufferDuration, pendingDuration >= threshold { + flushPendingFrames() + } + } + } + + /// Flush all accumulated frames to the player. + /// + /// All frames except the last are scheduled immediately (they're contiguous). + /// The last frame becomes the tail, held back until the next frame arrives + /// or `stopPlayback` commits it with fade-out. + private func flushPendingFrames() { + guard !pendingFrames.isEmpty else { + bufferThresholdMet = true + return + } + + // For subsequent chunks (scheduledAudioDuration > 0), the player may have + // been idle during the inter-chunk gap while the model generated the next + // buffer. The raw player clock kept advancing through that silence, so + // rawTime = playbackTimeOffset + scheduledAudioDuration + gap. + // Absorb the gap into playbackTimeOffset so currentPlaybackTime resumes + // from the end of the previous chunk rather than jumping ahead by gap. 
+ if scheduledAudioDuration > 0, + let player = playerNode, + let nodeTime = player.lastRenderTime, + nodeTime.isSampleTimeValid, + let playerTime = player.playerTime(forNodeTime: nodeTime) + { + let currentRawTime = Double(playerTime.sampleTime) / playerTime.sampleRate + let expectedRawTime = playbackTimeOffset + scheduledAudioDuration + let gap = currentRawTime - expectedRawTime + if gap > 0.01 { + playbackTimeOffset += gap + } + } + + for i in 0.. 0 { + let data = channelData[0] + let invFade = 1.0 / Float(fadeLen) + + if fadeIn { + for i in 0..) in + let format = audioFormat + if let silentBuffer = AVAudioPCMBuffer(pcmFormat: format, frameCapacity: 1) { + silentBuffer.frameLength = 1 + silentBuffer.floatChannelData?[0][0] = 0 + player.scheduleBuffer(silentBuffer) { + continuation.resume() + } + } else { + continuation.resume() + } + } + + // Wait ~80ms after the sentinel so the hardware pipeline drains (~1-2 render cycles). + // Prevents tail clip and CoreAudio "out of order"/overload errors across devices. + try? await Task.sleep(for: .milliseconds(80)) + } + + let engine = audioEngine + playerNode = nil + audioEngine = nil + + // Stop the engine only + engine?.stop() + + pendingFrames.removeAll() + pendingDuration = 0 + bufferThresholdMet = false + bufferDuration = nil + tailFrame = nil + needsFadeIn = true + expectedPlaybackEnd = 0 + playbackTimeOffset = 0 + } +} diff --git a/Sources/TTSKit/Utilities/EmbedTypes.swift b/Sources/TTSKit/Utilities/EmbedTypes.swift new file mode 100644 index 00000000..a63991b1 --- /dev/null +++ b/Sources/TTSKit/Utilities/EmbedTypes.swift @@ -0,0 +1,185 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright © 2026 Argmax, Inc. All rights reserved. + +import Accelerate +import ArgmaxCore +import CoreML + +// MARK: - TTS Embed Type Protocols + +// These protocols allow embed tensors to be represented as either MLMultiArray (all platforms) +// or MLTensor (macOS 15+ / iOS 18+) without changing the calling convention. 
/// Marker protocol for a raw TTS embedding tensor emitted by a CoreML model.
public protocol EmbedTensorType {}

/// Marker protocol for a TTS embedding value that can be used as CoreML model input.
public protocol EmbedInputType {}

extension MLMultiArray: EmbedTensorType {}
extension MLMultiArray: EmbedInputType {}

@available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *)
extension MLTensor: EmbedTensorType {}

@available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *)
extension MLTensor: EmbedInputType {}

extension Array: EmbedTensorType where Element == FloatType {}

// MARK: - Embedding Utilities

/// Static helpers for creating, combining, and extracting TTS embedding vectors.
///
/// Mirrors the `ModelUtilities` pattern in ArgmaxCore: all operations are pure
/// static functions with no shared state.
public enum EmbedUtilities {
    // MARK: - Float16 / FloatType helpers

    /// Element-wise addition of two equal-length embedding vectors via vDSP.
    ///
    /// - Parameters:
    ///   - a: First embedding vector.
    ///   - b: Second embedding vector; must have the same length as `a`.
    /// - Returns: `a[i] + b[i]` for every index, converted back to `FloatType`.
    public static func addEmbeddings(_ a: [FloatType], _ b: [FloatType]) -> [FloatType] {
        // Give the precondition a message so a length mismatch is diagnosable in crash logs.
        precondition(a.count == b.count, "Embedding length mismatch: \(a.count) vs \(b.count)")
        // Fast path: nothing to add, skip the buffer allocations below.
        guard !a.isEmpty else { return [] }
        var result = [Float](repeating: 0, count: a.count)
        a.withUnsafeBufferPointer { aPtr in
            b.withUnsafeBufferPointer { bPtr in
                // Widen both operands to Float32 once, then add vectorized.
                var aFloat = [Float](repeating: 0, count: a.count)
                var bFloat = [Float](repeating: 0, count: a.count)
                convertToFloat(aPtr, to: &aFloat)
                convertToFloat(bPtr, to: &bFloat)
                vDSP_vadd(aFloat, 1, bFloat, 1, &result, 1, vDSP_Length(a.count))
            }
        }
        return result.map { FloatType($0) }
    }

    /// Accumulates embeddings into a Float32 buffer to avoid repeated type conversions,
    /// reusing a single intermediate buffer rather than allocating one per embed.
+ public static func sumEmbeddings(_ embeddings: [[FloatType]]) -> [FloatType] { + guard let first = embeddings.first else { return [] } + let count = first.count + var accum = [Float](repeating: 0, count: count) + var floatBuf = [Float](repeating: 0, count: count) + for embed in embeddings { + embed.withUnsafeBufferPointer { ptr in + convertToFloat(ptr, to: &floatBuf) + vDSP_vadd(accum, 1, floatBuf, 1, &accum, 1, vDSP_Length(count)) + } + } + var result = [FloatType](repeating: 0, count: count) + result.withUnsafeMutableBufferPointer { dst in + accum.withUnsafeBufferPointer { src in + // vDSP Float16<->Float conversion is available on iOS 14+, macOS 11+, visionOS 1+ + // but not on watchOS. Fall back to scalar on watchOS and x86_64. + #if arch(arm64) && !os(watchOS) + vDSP.convertElements(of: src, to: &dst) + #else + for i in 0.. [FloatType] { + let dim: Int + if arr.shape.count == 4 { + dim = arr.shape[1].intValue + } else { + dim = arr.count + } + let ptr = arr.dataPointer.bindMemory(to: FloatType.self, capacity: arr.count) + var result = [FloatType](repeating: 0, count: dim) + + if arr.shape.count == 4 { + let stride1 = arr.strides[1].intValue + if stride1 == 1 { + // Contiguous layout ([1, D, 1, 1] with unit stride) - direct buffer copy + result.withUnsafeMutableBufferPointer { dst in + guard let dstBase = dst.baseAddress else { return } + dstBase.update(from: ptr, count: dim) + } + } else { + for d in 0.. MLMultiArray { + let dim = values.count + let arr = try MLMultiArray(shape: [1, NSNumber(value: dim), 1, 1], dataType: .float16) + let ptr = arr.dataPointer.bindMemory(to: FloatType.self, capacity: dim) + values.withUnsafeBufferPointer { buf in + guard let bufBase = buf.baseAddress else { return } + ptr.update(from: bufBase, count: dim) + } + return arr + } + + /// Create a zero-filled embedding vector. + /// - Parameter dim: Embedding dimension (match the actual model's embed size). + /// - Returns: A `[FloatType]` array of length `dim` filled with zeros. 
+ public static func zeroEmbed(dim: Int = 1024) -> [FloatType] { + [FloatType](repeating: FloatType(0), count: dim) + } + + /// Pack an `[Int32]` token-id array into a flat CoreML `MLMultiArray`. + public static func makeInt32Array(_ values: [Int32]) throws -> MLMultiArray { + let arr = try MLMultiArray(shape: [NSNumber(value: values.count)], dataType: .int32) + let ptr = arr.dataPointer.bindMemory(to: Int32.self, capacity: values.count) + for (index, value) in values.enumerated() { + ptr[index] = value + } + return arr + } + + // MARK: - MLTensor helpers + + /// Element-wise addition of two MLTensor embeddings. No data copy - deferred until materialised. + @available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *) + @inline(__always) + public static func addEmbeddings(_ a: MLTensor, _ b: MLTensor) -> MLTensor { a + b } + + /// Sum a list of MLTensor embeddings element-wise. + @available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *) + public static func sumEmbeddings(_ tensors: [MLTensor]) -> MLTensor { + tensors.dropFirst().reduce(tensors[0], +) + } + + // MARK: - Internal conversion helper + + /// Platform-safe conversion from FloatType buffer to Float array. + /// On arm64 iOS/macOS/visionOS (Float16), uses vDSP for vectorized conversion. + @inline(__always) + static func convertToFloat(_ source: UnsafeBufferPointer, to dest: inout [Float]) { + #if arch(arm64) && !os(watchOS) + vDSP.convertElements(of: source, to: &dest) + #else + for i in 0.. convenience + +@available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *) +public extension Array where Element == FloatType { + /// Wrap this embedding vector as a `[1, count, 1, 1]` Float16 MLTensor (one memcpy). 
+ func asMLTensor() -> MLTensor { + MLTensor(MLShapedArray(scalars: self, shape: [1, count, 1, 1])) + } +} diff --git a/Sources/TTSKit/Utilities/KVCache.swift b/Sources/TTSKit/Utilities/KVCache.swift new file mode 100644 index 00000000..2e5aa5ed --- /dev/null +++ b/Sources/TTSKit/Utilities/KVCache.swift @@ -0,0 +1,306 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright © 2026 Argmax, Inc. All rights reserved. + +import ArgmaxCore +import CoreML +import Foundation + +// MARK: - KV Cache + +/// KV cache for autoregressive decoder models. +/// +/// Manages external cache arrays (`keyCache`/`valueCache`) for non-stateful models, +/// or position tracking and attention masks only for stateful models whose weights +/// are managed internally by CoreML via `MLState`. +/// +/// Thread safety: each `Qwen3GenerateTask` creates its own cache instance. +/// Caches are never shared across concurrent tasks. +public class KVCache: @unchecked Sendable { + public var cacheLength: Int32 = 0 + public let maxSeqLength: Int + public let cacheDim: Int + public let isStateful: Bool + + /// External key cache -- nil for stateful models + public let keyCache: MLMultiArray? + /// External value cache -- nil for stateful models + public let valueCache: MLMultiArray? 
+ public let kvCacheUpdateMask: MLMultiArray + public let keyPaddingMask: MLMultiArray + + public init(cacheDim: Int, maxSeqLength: Int, isStateful: Bool = false) throws { + self.cacheDim = cacheDim + self.maxSeqLength = maxSeqLength + self.isStateful = isStateful + + if isStateful { + // Stateful models manage KV cache internally via MLState + keyCache = nil + valueCache = nil + } else { + keyCache = try MLMultiArray( + shape: [1, NSNumber(value: cacheDim), 1, NSNumber(value: maxSeqLength)], + dataType: .float16 + ) + valueCache = try MLMultiArray( + shape: [1, NSNumber(value: cacheDim), 1, NSNumber(value: maxSeqLength)], + dataType: .float16 + ) + } + + kvCacheUpdateMask = try MLMultiArray( + shape: [1, NSNumber(value: maxSeqLength)], + dataType: .float16 + ) + keyPaddingMask = try MLMultiArray( + shape: [1, NSNumber(value: maxSeqLength)], + dataType: .float16 + ) + + reset() + } + + public func reset() { + cacheLength = 0 + + // Zero-fill external KV caches (stateful models don't have them) + if let keyCache, let valueCache { + memset(keyCache.dataPointer, 0, cacheDim * maxSeqLength * MemoryLayout.size) + memset(valueCache.dataPointer, 0, cacheDim * maxSeqLength * MemoryLayout.size) + } + + // Initialize masks: first position active, rest masked + let updatePtr = kvCacheUpdateMask.dataPointer.bindMemory(to: FloatType.self, capacity: maxSeqLength) + let paddingPtr = keyPaddingMask.dataPointer.bindMemory(to: FloatType.self, capacity: maxSeqLength) + for i in 0..= maxSeqLength - 1 } + + /// How many free positions remain before the cache is full + public var freePositions: Int { maxSeqLength - 1 - Int(cacheLength) } + + public func makeCacheLengthArray() throws -> MLMultiArray { + let arr = try MLMultiArray(shape: [1], dataType: .int32) + arr[0] = NSNumber(value: cacheLength) + return arr + } +} + +// MARK: - MLTensor Access + +@available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *) +public extension KVCache { + /// Cache position as a `[1]` Int32 tensor. 
    /// Cache position as a `[1]` Int32 tensor.
    var cacheLengthTensor: MLTensor { MLTensor(shape: [1], scalars: [Int32(cacheLength)]) }
    /// Update-mask as a `[1, maxSeqLength]` Float16 tensor.
    // NOTE(review): `MLShapedArray` initializers below may have lost an explicit
    // generic parameter (e.g. `MLShapedArray<FloatType>`) in transit — confirm upstream.
    var kvCacheUpdateMaskTensor: MLTensor { MLTensor(MLShapedArray(kvCacheUpdateMask)) }
    /// Padding-mask as a `[1, maxSeqLength]` Float16 tensor.
    var keyPaddingMaskTensor: MLTensor { MLTensor(MLShapedArray(keyPaddingMask)) }
    /// External key-cache tensor - `nil` for stateful models.
    var keyCacheTensor: MLTensor? { keyCache.map { MLTensor(MLShapedArray($0)) } }
    /// External value-cache tensor - `nil` for stateful models.
    var valueCacheTensor: MLTensor? { valueCache.map { MLTensor(MLShapedArray($0)) } }

    /// Async update from MLTensor outputs - materializes without blocking the cooperative pool.
    func update(keyTensor: MLTensor, valueTensor: MLTensor) async {
        // Materialize tensors off the cooperative pool, then reuse the MLMultiArray path.
        let keyArr = await keyTensor.toMLMultiArray()
        let valArr = await valueTensor.toMLMultiArray()
        update(keyCacheUpdates: keyArr, valueCacheUpdates: valArr)
    }
}

// MARK: - Speech Decoder Cache

/// Extended KV cache for SpeechDecoder with a rolling hidden context buffer.
///
/// The hidden context window length varies by quantization variant
/// (e.g. 4 for W8A16, 16 for W16A16) and is read from the model at load time.
/// Task-local, never shared across concurrent tasks.
+public class SpeechDecoderCache: KVCache, @unchecked Sendable { + public let hiddenContext: MLMultiArray // [1, hiddenDim, 1, contextLen] + public let hiddenDim: Int + public let hiddenContextLen: Int + + public init( + cacheDim: Int = Qwen3TTSConstants.sdCacheDim, + maxSeqLength: Int = Qwen3TTSConstants.sdMaxSeq, + hiddenDim: Int = Qwen3TTSConstants.sdHiddenDim, + hiddenContextLen: Int = Qwen3TTSConstants.sdHiddenContextLen + ) throws { + self.hiddenDim = hiddenDim + self.hiddenContextLen = hiddenContextLen + hiddenContext = try MLMultiArray( + shape: [1, NSNumber(value: hiddenDim), 1, NSNumber(value: hiddenContextLen)], + dataType: .float16 + ) + memset(hiddenContext.dataPointer, 0, hiddenDim * hiddenContextLen * MemoryLayout.size) + // Speech decoder is currently non-stateful + try super.init(cacheDim: cacheDim, maxSeqLength: maxSeqLength, isStateful: false) + } + + /// Update KV cache and roll hidden context left, appending new hidden state + public func updateWithHiddenContext(output: MLFeatureProvider) { + guard let keyCU = output.featureValue(for: "key_cache_updates")?.multiArrayValue, + let valCU = output.featureValue(for: "value_cache_updates")?.multiArrayValue + else { + return + } + super.update(keyCacheUpdates: keyCU, valueCacheUpdates: valCU) + + // Roll hidden context left and append new state + let hidDim = hiddenDim + let contextLen = hiddenContextLen + let hiddenContextPtr = hiddenContext.dataPointer.bindMemory(to: FloatType.self, capacity: hidDim * contextLen) + guard let updateArr = output.featureValue(for: "hidden_context_update")?.multiArrayValue else { return } + let hiddenUpdatePtr = updateArr.dataPointer.bindMemory(to: FloatType.self, capacity: updateArr.count) + let hiddenUpdateStride = updateArr.strides[1].intValue + + for dim in 0..(hiddenContext)) } + + /// Update KV cache and rolling hidden context from `[String: MLTensor]` prediction outputs. + /// Materializes tensors asynchronously to avoid blocking the cooperative thread pool. 
+ func updateWithHiddenContext(tensorOutputs: [String: MLTensor]) async { + guard let keyUpdateTensor = tensorOutputs["key_cache_updates"], + let valueUpdateTensor = tensorOutputs["value_cache_updates"] + else { + return + } + await super.update(keyTensor: keyUpdateTensor, valueTensor: valueUpdateTensor) + + let hidDim = hiddenDim + let contextLen = hiddenContextLen + let ctxPtr = hiddenContext.dataPointer.bindMemory(to: FloatType.self, capacity: hidDim * contextLen) + guard let hiddenUpdateTensor = tensorOutputs["hidden_context_update"] else { return } + let updateArr = await hiddenUpdateTensor.toMLMultiArray() + let updatePtr = updateArr.dataPointer.bindMemory(to: FloatType.self, capacity: updateArr.count) + let updateStride = updateArr.strides[1].intValue + for dim in 0...size + + let keyUpdatePtr = keyCacheUpdates.dataPointer.bindMemory(to: FloatType.self, capacity: keyCacheUpdates.count) + let keyUpdateStride = keyCacheUpdates.strides[1].intValue + let valueUpdatePtr = valueCacheUpdates.dataPointer.bindMemory(to: FloatType.self, capacity: valueCacheUpdates.count) + let valueUpdateStride = valueCacheUpdates.strides[1].intValue + + state.withMultiArray(for: "self_attn_key_cache") { keyStateCache in + let embedDim = keyStateCache.shape[1].intValue + keyStateCache.withUnsafeMutableBytes { cachePtr, cacheStrides in + guard let baseAddress = cachePtr.baseAddress else { return } + for dim in 0.. Bool { + self.voice == voice && self.language == language && self.instruction == instruction + } +} + +// MARK: - KV Cache Snapshot + +/// Serializable snapshot of a `KVCache` state (masks, position counter, and +/// optionally key/value cache data for non-stateful models). +public struct KVCacheSnapshot: Sendable { + public let isStateful: Bool + public let cacheDim: Int + public let maxSeqLength: Int + public let cacheLength: Int32 + + /// Raw bytes from keyCache/valueCache (non-stateful models). + /// Empty `Data()` for stateful models where KV data lives in MLState. 
    // Raw bytes of the external key cache; empty Data() for stateful models.
    public let keyCacheData: Data
    // Raw bytes of the external value cache; empty Data() for stateful models.
    public let valueCacheData: Data

    // Raw bytes of the kv-cache update mask ([1, maxSeqLength] Float16).
    public let updateMaskData: Data
    // Raw bytes of the key padding mask ([1, maxSeqLength] Float16).
    public let paddingMaskData: Data
}

/// Serializable snapshot of MLState KV buffers (stateful models only).
public struct KVStateData: Sendable {
    public let keyData: Data
    public let valueData: Data
}

// MARK: - KVCache Snapshot/Restore

public extension KVCache {
    /// Create a serializable snapshot of the current cache state.
    func snapshot() -> KVCacheSnapshot {
        // NOTE(review): the `MemoryLayout<FloatType>` generic parameters in this
        // extension were lost in transit and restored by inference (the caches
        // and masks are bound to FloatType elsewhere) — confirm against upstream.
        let maskBytes = maxSeqLength * MemoryLayout<FloatType>.size

        return KVCacheSnapshot(
            isStateful: isStateful,
            cacheDim: cacheDim,
            maxSeqLength: maxSeqLength,
            cacheLength: cacheLength,
            keyCacheData: keyCache.map { Data(bytes: $0.dataPointer, count: cacheDim * maxSeqLength * MemoryLayout<FloatType>.size) } ?? Data(),
            valueCacheData: valueCache.map { Data(bytes: $0.dataPointer, count: cacheDim * maxSeqLength * MemoryLayout<FloatType>.size) } ?? Data(),
            updateMaskData: Data(bytes: kvCacheUpdateMask.dataPointer, count: maskBytes),
            paddingMaskData: Data(bytes: keyPaddingMask.dataPointer, count: maskBytes)
        )
    }

    /// Restore cache state from a snapshot. The snapshot must have matching geometry.
    func restore(from snapshot: KVCacheSnapshot) {
        precondition(
            snapshot.cacheDim == cacheDim && snapshot.maxSeqLength == maxSeqLength,
            "Cache geometry mismatch: expected (\(cacheDim), \(maxSeqLength)), got (\(snapshot.cacheDim), \(snapshot.maxSeqLength))"
        )

        cacheLength = snapshot.cacheLength

        let maskBytes = maxSeqLength * MemoryLayout<FloatType>.size

        // External KV data exists only for non-stateful models (empty otherwise).
        if let keyCache, let valueCache, !snapshot.keyCacheData.isEmpty {
            let kvBytes = cacheDim * maxSeqLength * MemoryLayout<FloatType>.size
            snapshot.keyCacheData.copyBytes(to: keyCache.dataPointer.assumingMemoryBound(to: UInt8.self), count: min(snapshot.keyCacheData.count, kvBytes))
            snapshot.valueCacheData.copyBytes(to: valueCache.dataPointer.assumingMemoryBound(to: UInt8.self), count: min(snapshot.valueCacheData.count, kvBytes))
        }

        snapshot.updateMaskData.copyBytes(to: kvCacheUpdateMask.dataPointer.assumingMemoryBound(to: UInt8.self), count: min(snapshot.updateMaskData.count, maskBytes))
        snapshot.paddingMaskData.copyBytes(to: keyPaddingMask.dataPointer.assumingMemoryBound(to: UInt8.self), count: min(snapshot.paddingMaskData.count, maskBytes))
    }
}

// MARK: - MLState Snapshot/Restore

@available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *)
public extension MLState {
    /// Capture the current `self_attn_key_cache` and `self_attn_value_cache` buffers as raw `Data`.
+ func snapshot() -> KVStateData { + var keyData = Data() + withMultiArray(for: "self_attn_key_cache") { arr in + arr.withUnsafeMutableBytes { buf, _ in + guard let baseAddress = buf.baseAddress else { return } + keyData = Data(bytes: baseAddress, count: buf.count) + } + } + var valueData = Data() + withMultiArray(for: "self_attn_value_cache") { arr in + arr.withUnsafeMutableBytes { buf, _ in + guard let baseAddress = buf.baseAddress else { return } + valueData = Data(bytes: baseAddress, count: buf.count) + } + } + return KVStateData(keyData: keyData, valueData: valueData) + } + + /// Overwrite the `self_attn_key_cache` and `self_attn_value_cache` buffers from a snapshot. + func restore(from data: KVStateData) { + withMultiArray(for: "self_attn_key_cache") { arr in + arr.withUnsafeMutableBytes { buf, _ in + guard let bufBase = buf.baseAddress else { return } + data.keyData.withUnsafeBytes { src in + guard let srcBase = src.baseAddress else { return } + bufBase.copyMemory(from: srcBase, byteCount: min(src.count, buf.count)) + } + } + } + withMultiArray(for: "self_attn_value_cache") { arr in + arr.withUnsafeMutableBytes { buf, _ in + guard let bufBase = buf.baseAddress else { return } + data.valueData.withUnsafeBytes { src in + guard let srcBase = src.baseAddress else { return } + bufBase.copyMemory(from: srcBase, byteCount: min(src.count, buf.count)) + } + } + } + } +} + +// MARK: - Disk Persistence + +public extension TTSPromptCache { + /// Save this prompt cache to disk as a property list. + /// + /// The file can be reloaded with `TTSPromptCache.load(from:)` as long as + /// the model variant and cache geometry haven't changed. 
+ func save(to url: URL) throws { + let container = CacheContainer( + voice: voice, + language: language, + instruction: instruction, + prefixLength: prefixLength, + isStateful: kvSnapshot.isStateful, + cacheDim: kvSnapshot.cacheDim, + maxSeqLength: kvSnapshot.maxSeqLength, + cacheLength: kvSnapshot.cacheLength, + keyCacheData: kvSnapshot.keyCacheData, + valueCacheData: kvSnapshot.valueCacheData, + updateMaskData: kvSnapshot.updateMaskData, + paddingMaskData: kvSnapshot.paddingMaskData, + stateKeyData: stateData?.keyData, + stateValueData: stateData?.valueData + ) + let data = try PropertyListEncoder().encode(container) + try FileManager.default.createDirectory(at: url.deletingLastPathComponent(), withIntermediateDirectories: true) + try data.write(to: url) + } + + /// Load a previously saved prompt cache from disk. + static func load(from url: URL) throws -> TTSPromptCache { + let data = try Data(contentsOf: url) + let container = try PropertyListDecoder().decode(CacheContainer.self, from: data) + let snapshot = KVCacheSnapshot( + isStateful: container.isStateful, + cacheDim: container.cacheDim, + maxSeqLength: container.maxSeqLength, + cacheLength: container.cacheLength, + keyCacheData: container.keyCacheData, + valueCacheData: container.valueCacheData, + updateMaskData: container.updateMaskData, + paddingMaskData: container.paddingMaskData + ) + let stateData: KVStateData? + if let keyData = container.stateKeyData, let valueData = container.stateValueData { + stateData = KVStateData(keyData: keyData, valueData: valueData) + } else { + stateData = nil + } + return TTSPromptCache( + voice: container.voice, + language: container.language, + instruction: container.instruction, + prefixLength: container.prefixLength, + kvSnapshot: snapshot, + stateData: stateData + ) + } + + /// File name for this cache based on voice/language/instruction. 
+ var cacheFileName: String { + var name = "\(voice)_\(language)" + if let instruction, !instruction.isEmpty { + let hash = instruction.utf8.reduce(into: UInt64(5381)) { $0 = $0 &* 33 &+ UInt64($1) } + name += "_\(String(hash, radix: 16))" + } + return name + ".promptcache" + } +} + +/// Codable container for plist serialization (Data fields are stored as binary, not base64). +private struct CacheContainer: Codable { + let voice: String + let language: String + let instruction: String? + let prefixLength: Int + let isStateful: Bool + let cacheDim: Int + let maxSeqLength: Int + let cacheLength: Int32 + let keyCacheData: Data + let valueCacheData: Data + let updateMaskData: Data + let paddingMaskData: Data + let stateKeyData: Data? + let stateValueData: Data? +} diff --git a/Sources/TTSKit/Utilities/Sampling.swift b/Sources/TTSKit/Utilities/Sampling.swift new file mode 100644 index 00000000..69ebfa97 --- /dev/null +++ b/Sources/TTSKit/Utilities/Sampling.swift @@ -0,0 +1,384 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright © 2026 Argmax, Inc. All rights reserved. + +import Accelerate +import ArgmaxCore +import CoreML +import Foundation + +// MARK: - Token Sampling + +/// Protocol for TTS token sampling strategies. +public protocol TokenSampling { + /// Sample a single token from codec-0 logits. + /// `logits` is an MLTensor (macOS 15+ async path) or MLMultiArray (sync path). + func sampleCodec0( + logits: any EmbedTensorType, + temperature: Float, + topK: Int, + generatedTokens: [Int32], + repetitionPenalty: Float, + suppressTokenIds: Set + ) async -> Int32 + + /// Sample a single token from one head of multi-code logits. + /// `allLogits` is an MLTensor (macOS 15+ async path) or MLMultiArray (sync path). + func sampleMultiHead( + allLogits: any EmbedTensorType, + headIndex: Int, + temperature: Float, + topK: Int + ) async -> Int32 +} + +// MARK: - Greedy / Top-k Sampler + +/// Greedy / top-k / temperature token sampler with a seedable RNG. 
///
/// Thread safety: each `Qwen3GenerateTask` owns its own `GreedyTokenSampler`
/// instance (created per-task in `TTSKit.createTask()` with a derived seed).
/// The `var rng` is never accessed concurrently because it's single-owner.
public class GreedyTokenSampler: TokenSampling, @unchecked Sendable {
    private var rng: any RandomNumberGenerator

    /// Create a sampler with an optional seed for reproducibility.
    /// - Parameter seed: If provided, uses a deterministic RNG. If nil, uses system RNG.
    public init(seed: UInt64? = nil) {
        if let seed {
            self.rng = SeededRandomNumberGenerator(seed: seed)
        } else {
            self.rng = SystemRandomNumberGenerator()
        }
    }

    public func sampleCodec0(
        logits: any EmbedTensorType,
        temperature: Float,
        topK: Int,
        generatedTokens: [Int32],
        repetitionPenalty: Float,
        suppressTokenIds: Set<Int>
    ) async -> Int32 {
        if #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *) {
            let tensor: MLTensor
            let vocabSize: Int
            if let logitsTensor = logits as? MLTensor {
                vocabSize = logitsTensor.shape.last ?? 0
                tensor = logitsTensor
            } else {
                guard let logitsArray = logits as? MLMultiArray else {
                    // Unknown tensor type: fall back to the codec BOS token.
                    return Int32(Qwen3TTSConstants.codecBOS)
                }
                vocabSize = logitsArray.shape.last?.intValue ?? 0
                // NOTE(review): the generic argument was lost in transit; FloatType
                // matches the element type used elsewhere in this file — TODO confirm.
                tensor = MLTensor(MLShapedArray<FloatType>(logitsArray))
            }
            return await sampleCodec0WithMLTensor(
                logitsTensor: tensor,
                vocabSize: vocabSize,
                temperature: temperature,
                topK: topK,
                generatedTokens: generatedTokens,
                repetitionPenalty: repetitionPenalty,
                suppressTokenIds: suppressTokenIds
            )
        }
        guard let logitsArray = logits as? MLMultiArray else {
            return Int32(Qwen3TTSConstants.codecBOS)
        }
        return sampleCodec0WithVDSP(
            logits: logitsArray,
            temperature: temperature,
            topK: topK,
            generatedTokens: generatedTokens,
            repetitionPenalty: repetitionPenalty,
            suppressTokenIds: suppressTokenIds
        )
    }

    public func sampleMultiHead(
        allLogits: any EmbedTensorType,
        headIndex: Int,
        temperature: Float,
        topK: Int
    ) async -> Int32 {
        if #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *),
           let tensor = allLogits as? MLTensor
        {
            // Extract a single head lazily via gathering - the full [1, 15, vocabSize] tensor
            // stays on device; only the top-k results (~400 B) are downloaded to CPU.
            return await sampleMultiHeadFromTensor(tensor, headIndex: headIndex, temperature: temperature, topK: topK)
        }

        // MLMultiArray path: pointer arithmetic for zero-copy head extraction
        guard let allLogitsArray = allLogits as? MLMultiArray else {
            return Int32(Qwen3TTSConstants.codecBOS)
        }
        let vocabSize = allLogitsArray.shape[2].intValue
        let ptr = allLogitsArray.dataPointer.bindMemory(to: FloatType.self, capacity: allLogitsArray.count)
        let stride1 = allLogitsArray.strides[1].intValue
        let stride2 = allLogitsArray.strides[2].intValue
        var logitsF = [Float](repeating: 0, count: vocabSize)
        let base = headIndex * stride1

        if stride2 == 1 {
            // Contiguous innermost dimension: bulk conversion to Float32.
            let src = UnsafeBufferPointer(start: ptr.advanced(by: base), count: vocabSize)
            EmbedUtilities.convertToFloat(src, to: &logitsF)
        } else {
            // NOTE(review): this strided-gather span was reconstructed after extraction
            // damage; verify the stride math against the original source.
            for i in 0..<vocabSize {
                logitsF[i] = Float(ptr[base + i * stride2])
            }
        }
        return sampleFromLogits(logitsF, temperature: temperature, topK: topK)
    }

    /// vDSP-based codec-0 sampler (pre-macOS 15 sync path).
    /// NOTE(review): this signature was reconstructed from its call site after
    /// extraction damage; verify parameter order against the original source.
    private func sampleCodec0WithVDSP(
        logits: MLMultiArray,
        temperature: Float,
        topK: Int,
        generatedTokens: [Int32],
        repetitionPenalty: Float,
        suppressTokenIds: Set<Int>
    ) -> Int32 {
        let vocabSize = logits.shape.last?.intValue ?? 0
        var logitsF = extractFloat32Logits(logits, count: vocabSize)
        // Hard-mask suppressed tokens so they can never be sampled.
        for id in suppressTokenIds where id < vocabSize {
            logitsF[id] = -.infinity
        }
        if repetitionPenalty != 1.0 && !generatedTokens.isEmpty {
            // CTRL-style repetition penalty: divide positive logits, multiply negative ones.
            for tokenId in Set(generatedTokens) {
                let tokenIndex = Int(tokenId)
                guard tokenIndex < vocabSize else { continue }
                logitsF[tokenIndex] = logitsF[tokenIndex] > 0 ? logitsF[tokenIndex] / repetitionPenalty : logitsF[tokenIndex] * repetitionPenalty
            }
        }
        return sampleFromLogits(logitsF, temperature: temperature, topK: topK)
    }

    /// MLTensor-based codec-0 sampler (macOS 15+).
    /// `logitsTensor` arrives directly from the model output - no MLMultiArray conversion needed.
    /// Uses `MLTensor.topK()` - O(n) partial selection - instead of vDSP_vsort's O(n log n) full sort.
    @available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *)
    private func sampleCodec0WithMLTensor(
        logitsTensor: MLTensor,
        vocabSize: Int,
        temperature: Float,
        topK: Int,
        generatedTokens: [Int32],
        repetitionPenalty: Float,
        suppressTokenIds: Set<Int>
    ) async -> Int32 {
        let needsCPUPass = !suppressTokenIds.isEmpty || (repetitionPenalty != 1.0 && !generatedTokens.isEmpty)

        let processedTensor: MLTensor
        if needsCPUPass {
            // Materialise to Float32, apply scalar modifications, re-wrap as [1, vocabSize] tensor.
            var logitsF = await logitsTensor.reshaped(to: [1, vocabSize]).cast(to: Float.self).toFloatArray()
            for id in suppressTokenIds where id < vocabSize {
                logitsF[id] = -.infinity
            }
            if repetitionPenalty != 1.0 && !generatedTokens.isEmpty {
                for tokenId in Set(generatedTokens) {
                    let tokenIndex = Int(tokenId)
                    guard tokenIndex < vocabSize else { continue }
                    logitsF[tokenIndex] = logitsF[tokenIndex] > 0 ? logitsF[tokenIndex] / repetitionPenalty : logitsF[tokenIndex] * repetitionPenalty
                }
            }
            processedTensor = MLTensor(shape: [1, vocabSize], scalars: logitsF, scalarType: Float.self)
        } else {
            // Fully lazy path: cast + reshape stay on device until argmax/topK materializes them.
            // [1, vocabSize] shape is required - argmax on a 1D tensor yields a 0D scalar.
            processedTensor = logitsTensor.reshaped(to: [1, vocabSize]).cast(to: Float.self)
        }

        if temperature == 0 {
            return Int32(await processedTensor.argmax(alongAxis: -1).toIntArray()[0])
        }

        let probs = (processedTensor / temperature).softmax(alongAxis: -1)
        return await sampleFromProbs(probs, vocabSize: vocabSize, topK: topK)
    }

    /// Extract a single head from the multi-code logit tensor and sample from it (macOS 15+).
    /// The full [1, numHeads, vocabSize] tensor stays on device; `gathering` selects the head
    /// lazily so that only the top-k results (~400 B) need to be downloaded to CPU.
    @available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *)
    private func sampleMultiHeadFromTensor(
        _ allLogits: MLTensor,
        headIndex: Int,
        temperature: Float,
        topK: Int
    ) async -> Int32 {
        let vocabSize = allLogits.shape[2]
        // gathering along axis 1: [1, numHeads, vocabSize] -> [1, 1, vocabSize]
        let headLogits = allLogits.gathering(
            atIndices: MLTensor(shape: [1], scalars: [Int32(headIndex)], scalarType: Int32.self),
            alongAxis: 1
        )
        if temperature == 0 {
            return Int32(await headLogits.cast(to: Float.self).argmax(alongAxis: -1).toIntArray()[0])
        }
        let probs = (headLogits.cast(to: Float.self) / temperature).softmax(alongAxis: -1)
        return await sampleFromProbs(probs, vocabSize: vocabSize, topK: topK)
    }

    /// Shared topK multinomial sampler over an already-softmaxed probability MLTensor.
    @available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *)
    private func sampleFromProbs(_ probs: MLTensor, vocabSize: Int, topK: Int) async -> Int32 {
        if topK > 0 && topK < vocabSize {
            // Partial selection: O(n) vs O(n log n) full sort
            let (topKProbs, topKIndices) = probs.topK(topK)
            let probsArray = await topKProbs.toFloatArray()
            let idxArray = await topKIndices.toIntArray()
            // Drawing in [0, probSum) renormalizes the truncated distribution implicitly.
            // NOTE(review): the draw/accumulate span was reconstructed after extraction damage.
            let probSum = probsArray.reduce(0, +)
            let randomValue = Float.random(in: 0..<probSum, using: &rng)
            var cumulativeSum: Float = 0
            for i in 0..<probsArray.count {
                cumulativeSum += probsArray[i]
                if cumulativeSum >= randomValue { return Int32(idxArray[i]) }
            }
            return idxArray.last.map(Int32.init) ?? Int32(vocabSize - 1)
        } else {
            let probsArray = await probs.toFloatArray()
            let randomValue = Float.random(in: 0..<1, using: &rng)
            var cumulativeSum: Float = 0
            for (i, probability) in probsArray.enumerated() {
                cumulativeSum += probability
                if cumulativeSum >= randomValue { return Int32(i) }
            }
            return Int32(vocabSize - 1)
        }
    }

    /// MLTensor sampling from a pre-extracted Float32 logits array (macOS 15+).
    @available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *)
    private func sampleFromLogitsWithMLTensor(_ logits: [Float], temperature: Float, topK: Int) async -> Int32 {
        let vocabSize = logits.count

        if temperature == 0 {
            var maxValue: Float = 0
            var maxIndex: vDSP_Length = 0
            vDSP_maxvi(logits, 1, &maxValue, &maxIndex, vDSP_Length(vocabSize))
            return Int32(maxIndex)
        }

        // [1, vocabSize] keeps topK/softmax results at least 2D so toFloatArray()/toIntArray() are safe
        let logitsTensor = MLTensor(shape: [1, vocabSize], scalars: logits, scalarType: Float.self)
        let probs = (logitsTensor / temperature).softmax(alongAxis: -1)
        return await sampleFromProbs(probs, vocabSize: vocabSize, topK: topK)
    }

    // MARK: - Private helpers

    /// Copy an MLMultiArray's logits into a dense Float32 array, honoring the last-axis stride.
    private func extractFloat32Logits(_ arr: MLMultiArray, count: Int) -> [Float] {
        let ptr = arr.dataPointer.bindMemory(to: FloatType.self, capacity: arr.count)
        let lastStride = arr.strides.last?.intValue ?? 1
        var result = [Float](repeating: 0, count: count)
        if lastStride == 1 {
            let src = UnsafeBufferPointer(start: ptr, count: count)
            EmbedUtilities.convertToFloat(src, to: &result)
        } else {
            // NOTE(review): strided fallback reconstructed after extraction damage.
            for i in 0..<count {
                result[i] = Float(ptr[i * lastStride])
            }
        }
        return result
    }

    /// CPU (vDSP) temperature/top-k multinomial sampler shared by the sync paths.
    private func sampleFromLogits(_ logits: [Float], temperature: Float, topK: Int) -> Int32 {
        var mutableLogits = logits
        let vocabSize = mutableLogits.count

        if temperature == 0 {
            // Greedy: plain argmax.
            var maxValue: Float = 0
            var maxIndex: vDSP_Length = 0
            vDSP_maxvi(mutableLogits, 1, &maxValue, &maxIndex, vDSP_Length(vocabSize))
            return Int32(maxIndex)
        }

        var temperatureScalar = temperature
        vDSP_vsdiv(mutableLogits, 1, &temperatureScalar, &mutableLogits, 1, vDSP_Length(vocabSize))

        if topK > 0 && topK < vocabSize {
            var sorted = mutableLogits
            vDSP_vsort(&sorted, vDSP_Length(vocabSize), -1)
            let threshold = sorted[topK - 1]
            // NOTE(review): the masking and softmax span below was reconstructed after
            // extraction damage; verify against the original source.
            for i in 0..<vocabSize where mutableLogits[i] < threshold {
                mutableLogits[i] = -.infinity
            }
        }

        // Numerically-stable softmax: subtract max, exponentiate, normalize.
        var maxLogit: Float = 0
        vDSP_maxv(mutableLogits, 1, &maxLogit, vDSP_Length(vocabSize))
        var negMax = -maxLogit
        vDSP_vsadd(mutableLogits, 1, &negMax, &mutableLogits, 1, vDSP_Length(vocabSize))
        var elementCount = Int32(vocabSize)
        vvexpf(&mutableLogits, mutableLogits, &elementCount)
        var sum: Float = 0
        vDSP_sve(mutableLogits, 1, &sum, vDSP_Length(vocabSize))
        if sum > 0 {
            vDSP_vsdiv(mutableLogits, 1, &sum, &mutableLogits, 1, vDSP_Length(vocabSize))
        }

        let randomValue = Float.random(in: 0..<1, using: &rng)
        var cumulativeSum: Float = 0
        for i in 0..<vocabSize {
            cumulativeSum += mutableLogits[i]
            if cumulativeSum >= randomValue { return Int32(i) }
        }
        return Int32(vocabSize - 1)
    }
}

// MARK: - Seedable RNG

/// A seedable random number generator using the xoshiro256** algorithm.
/// Produces deterministic sequences for a given seed.
public struct SeededRandomNumberGenerator: RandomNumberGenerator {
    // xoshiro256** internal state, seeded with splitmix64-style scrambling.
    private var state: (UInt64, UInt64, UInt64, UInt64)

    /// Seed the four state words by running the splitmix64 finalizer over
    /// `seed + k * gamma` for k = 1...4, which decorrelates nearby seeds.
    public init(seed: UInt64) {
        func scramble(_ input: UInt64) -> UInt64 {
            var word = input
            word = (word ^ (word >> 30)) &* 0xBF58_476D_1CE4_E5B9
            word = (word ^ (word >> 27)) &* 0x94D0_49BB_1331_11EB
            return word ^ (word >> 31)
        }
        let gamma: UInt64 = 0x9E37_79B9_7F4A_7C15
        state = (
            scramble(seed &+ gamma),
            scramble(seed &+ 2 &* gamma),
            scramble(seed &+ 3 &* gamma),
            scramble(seed &+ 4 &* gamma)
        )
    }

    /// One xoshiro256** step: star-star output function, then the linear state transition.
    public mutating func next() -> UInt64 {
        let output = rotl(state.1 &* 5, 7) &* 9
        let shifted = state.1 << 17
        state.2 ^= state.0
        state.3 ^= state.1
        state.1 ^= state.2
        state.0 ^= state.3
        state.2 ^= shifted
        state.3 = rotl(state.3, 45)
        return output
    }

    // Left-rotate; `k` is always a constant in 1...63 here, so the shifts are defined.
    private func rotl(_ x: UInt64, _ k: Int) -> UInt64 {
        (x << k) | (x >> (64 - k))
    }
}
diff --git a/Sources/TTSKit/Utilities/TTSError.swift b/Sources/TTSKit/Utilities/TTSError.swift
new file mode 100644
index 00000000..19a0391e
--- /dev/null
+++ b/Sources/TTSKit/Utilities/TTSError.swift
@@ -0,0 +1,27 @@
// For licensing see accompanying LICENSE.md file.
// Copyright © 2026 Argmax, Inc. All rights reserved.

import Foundation

/// Errors surfaced by TTSKit, with human-readable descriptions via `LocalizedError`.
@frozen
public enum TTSError: Error, LocalizedError, Equatable {
    case emptyText
    case modelNotFound(String)
    case generationFailed(String)
    case tokenizerUnavailable(String)
    case audioOutputFailed(String)
    /// A required component URL could not be resolved from the config.
    /// Thrown at `loadModels()` time so callers fail early with a clear message.
    case invalidConfiguration(String)

    public var errorDescription: String? {
        switch self {
        case .emptyText:
            return "Input text is empty"
        case let .modelNotFound(path):
            return "Model directory not found: \(path)"
        case let .generationFailed(msg):
            return "Generation failed: \(msg)"
        case let .tokenizerUnavailable(msg):
            return "Tokenizer unavailable: \(msg)"
        case let .audioOutputFailed(msg):
            return "Audio output failed: \(msg)"
        case let .invalidConfiguration(msg):
            return "Invalid TTSKit configuration: \(msg)"
        }
    }
}
diff --git a/Sources/TTSKit/Utilities/TextChunker.swift b/Sources/TTSKit/Utilities/TextChunker.swift
new file mode 100644
index 00000000..b5249340
--- /dev/null
+++ b/Sources/TTSKit/Utilities/TextChunker.swift
@@ -0,0 +1,115 @@
// For licensing see accompanying LICENSE.md file.
// Copyright © 2026 Argmax, Inc. All rights reserved.

import Foundation
import Tokenizers

/// Strategy for splitting long text into chunks.
@frozen
public enum TextChunkingStrategy: String, Codable, CaseIterable, Sendable {
    /// No chunking, generate the full text in a single pass.
    case none
    /// Split at sentence boundaries using TextChunker.
    case sentence
}

/// Splits long text into sentence-bounded chunks suitable for chunked TTS generation.
///
/// Uses natural sentence boundaries (. ! ? and newlines) to avoid splitting mid-sentence,
/// with a fallback to clause boundaries (, ; :) and then word boundaries for very long sentences.
///
/// Algorithm: tokenize the full text, decode windows of `targetChunkSize` tokens,
/// find the last natural boundary in each decoded window, then re-tokenize the accepted
/// chunk to advance by the exact token count - no character estimation needed.
///
/// `defaultTargetChunkSize` / `defaultMinChunkSize` are the single source of truth for
/// these defaults - all other call sites (TTSGenerationOptions, SpeakAX) reference them.
public struct TextChunker {
    /// Default target chunk size (tokens).
    public static let defaultTargetChunkSize: Int = 42

    /// Default minimum chunk size (tokens).
    public static let defaultMinChunkSize: Int = 10

    /// Target chunk size in tokens. Actual chunks may be shorter (sentence boundary)
    /// or slightly longer (no boundary found within the window).
    public let targetChunkSize: Int

    /// Minimum chunk size in tokens. Short trailing segments are merged into
    /// the previous chunk to avoid tiny segments with poor prosody.
    public let minChunkSize: Int

    // Encode/decode closures decouple the chunker from any concrete tokenizer type.
    private let encodeText: (String) -> [Int]
    private let decodeTokens: ([Int]) -> String

    /// Create a chunker backed by a real tokenizer.
    public init(
        targetChunkSize: Int = TextChunker.defaultTargetChunkSize,
        minChunkSize: Int = TextChunker.defaultMinChunkSize,
        tokenizer: any Tokenizer
    ) {
        self.targetChunkSize = targetChunkSize
        self.minChunkSize = minChunkSize
        self.encodeText = { tokenizer.encode(text: $0) }
        self.decodeTokens = { tokenizer.decode(tokens: $0) }
    }

    /// Internal init for testing - accepts raw encode/decode functions so tests
    /// don't need a heavyweight `Tokenizer` conformance.
    init(
        targetChunkSize: Int = TextChunker.defaultTargetChunkSize,
        minChunkSize: Int = TextChunker.defaultMinChunkSize,
        encode: @escaping (String) -> [Int],
        decode: @escaping ([Int]) -> String
    ) {
        self.targetChunkSize = targetChunkSize
        self.minChunkSize = minChunkSize
        self.encodeText = encode
        self.decodeTokens = decode
    }

    /// Split text into chunks at natural boundaries, respecting the token budget.
    /// Returns an array of non-empty text chunks.
    public func chunk(_ text: String) -> [String] {
        let trimmed = text.trimmingCharacters(in: .whitespacesAndNewlines)
        guard !trimmed.isEmpty else { return [] }

        var tokens = encodeText(trimmed)
        // Short input: a single chunk, no boundary search needed.
        guard tokens.count > targetChunkSize else { return [trimmed] }

        var chunks: [String] = []

        while !tokens.isEmpty {
            if tokens.count <= targetChunkSize {
                let tail = decodeTokens(tokens)
                    .trimmingCharacters(in: .whitespacesAndNewlines)
                if !tail.isEmpty {
                    // Merge tiny trailing segment with previous chunk to avoid poor prosody
                    if encodeText(tail).count < minChunkSize, let last = chunks.last {
                        chunks[chunks.count - 1] = last + " " + tail
                    } else {
                        chunks.append(tail)
                    }
                }
                break
            }

            // Decode a window of targetChunkSize tokens, then find the best boundary within it
            let window = Array(tokens.prefix(targetChunkSize))
            let windowText = decodeTokens(window)
                .trimmingCharacters(in: .whitespacesAndNewlines)

            let accepted = windowText.lastNaturalBoundary(minTokenCount: minChunkSize, encode: encodeText) ??
windowText + + if !accepted.isEmpty { + chunks.append(accepted) + } + + // Re-tokenize the accepted text to advance by its exact token count, + // avoiding drift from imperfect encode/decode round-trips + let consumed = encodeText(accepted).count + tokens.removeFirst(min(max(consumed, 1), tokens.count)) + } + + return chunks + } +} diff --git a/Sources/WhisperKit/Core/Audio/AudioChunker.swift b/Sources/WhisperKit/Core/Audio/AudioChunker.swift index 79e286ff..e005bab8 100644 --- a/Sources/WhisperKit/Core/Audio/AudioChunker.swift +++ b/Sources/WhisperKit/Core/Audio/AudioChunker.swift @@ -26,7 +26,7 @@ public extension AudioChunking { let updatedSegment = TranscriptionUtilities.updateSegmentTimings(segment: segment, seekTime: seekTime) updatedSegments.append(updatedSegment) } - var updatedResult = result + let updatedResult = result updatedResult.seekTime = seekTime updatedResult.segments = updatedSegments updatedTranscriptionResults.append(updatedResult) diff --git a/Sources/WhisperKit/Core/Configurations.swift b/Sources/WhisperKit/Core/Configurations.swift index a52cf28d..fa59ca9f 100644 --- a/Sources/WhisperKit/Core/Configurations.swift +++ b/Sources/WhisperKit/Core/Configurations.swift @@ -13,7 +13,8 @@ open class WhisperKitConfig { public var modelRepo: String? /// Token for downloading models from repo (if required) public var modelToken: String? - + /// HuggingFace Hub compatible endpoint URL + public var modelEndpoint: String? /// Folder to store models public var modelFolder: String? /// Folder to store tokenizers @@ -53,11 +54,11 @@ open class WhisperKitConfig { /// model gets loaded sequentially and unloaded immediately to trigger specialization if necessary. /// /// **Trade-offs** - /// - **Pro** — The peak memory usage during compilation is reduced because + /// - **Pro** - The peak memory usage during compilation is reduced because /// only one model is kept in memory at any given point. 
Otherwise, the /// peak memory will bloat to all model weights combined plus the peak /// compilation memory (higher than model weights). - /// - **Con** — The load time will be multiplied by 2 (usually <1s when cache is hit) + /// - **Con** - The load time will be multiplied by 2 (usually <1s when cache is hit) /// because of the load-unload-load pattern when the specialized model file cache is /// actually hit and prewarm does not trigger specialization /// @@ -75,6 +76,7 @@ open class WhisperKitConfig { downloadBase: URL? = nil, modelRepo: String? = nil, modelToken: String? = nil, + modelEndpoint: String? = nil, modelFolder: String? = nil, tokenizerFolder: URL? = nil, computeOptions: ModelComputeOptions? = nil, @@ -97,6 +99,7 @@ open class WhisperKitConfig { self.downloadBase = downloadBase self.modelRepo = modelRepo self.modelToken = modelToken + self.modelEndpoint = modelEndpoint self.modelFolder = modelFolder self.tokenizerFolder = tokenizerFolder self.computeOptions = computeOptions diff --git a/Sources/WhisperKit/Core/Models.swift b/Sources/WhisperKit/Core/Models.swift index 7712979f..1eee6beb 100644 --- a/Sources/WhisperKit/Core/Models.swift +++ b/Sources/WhisperKit/Core/Models.swift @@ -7,17 +7,7 @@ import CoreML import Hub import NaturalLanguage import Tokenizers - -#if !((os(macOS) || targetEnvironment(macCatalyst)) && arch(x86_64)) -public typealias FloatType = Float16 -#else -public typealias FloatType = Float -#endif - -#if (os(macOS) || targetEnvironment(macCatalyst)) && arch(arm64) && compiler(<6) -extension Float16: BNNSScalar {} -extension Float16: MLShapedArrayScalar {} -#endif +@_exported import ArgmaxCore // MARK: - CoreML @@ -101,38 +91,7 @@ public enum ModelVariant: CustomStringConvertible, CaseIterable { } } -@frozen -public enum ModelState: CustomStringConvertible { - case unloading - case unloaded - case loading - case loaded - case prewarming - case prewarmed - case downloading - case downloaded - - public var description: String { - 
switch self { - case .unloading: - return "Unloading" - case .unloaded: - return "Unloaded" - case .loading: - return "Loading" - case .loaded: - return "Loaded" - case .prewarming: - return "Specializing" - case .prewarmed: - return "Specialized" - case .downloading: - return "Downloading" - case .downloaded: - return "Downloaded" - } - } -} +// ModelState is defined in ArgmaxCore/ModelState.swift and re-exported here. public struct ModelComputeOptions: Sendable { public var melCompute: MLComputeUnits @@ -502,10 +461,11 @@ public struct DecodingResult { } /// Reference-type container for transcription output. -/// The stored properties stay thread-safe because each one uses -/// `TranscriptionPropertyLock`, so reads/writes hop through a private `NSLock` -/// before the value is accessed, making this shared `@unchecked Sendable` class -/// safe to hand across concurrent contexts. +/// +/// Each property is protected by its own `TranscriptionPropertyLock`, which +/// serializes whole-value reads and writes. Atomic whole-value replacement is +/// thread-safe; read-modify-write operations (e.g. `result.segments.append(...)`) +/// are not - callers must use external synchronisation. open class TranscriptionResult: Codable, @unchecked Sendable { @TranscriptionPropertyLock public var text: String @TranscriptionPropertyLock public var segments: [TranscriptionSegment] @@ -732,11 +692,6 @@ public struct TranscriptionProgress: Sendable { public typealias SegmentDiscoveryCallback = (_ segments: [TranscriptionSegment]) -> Void /// A callback that reports changes in the model's state. -/// - Parameters: -/// - oldState: The previous state of the model, if any -/// - newState: The current state of the model -public typealias ModelStateCallback = (_ oldState: ModelState?, _ newState: ModelState) -> Void - /// A callback that reports changes in the transcription process. 
/// - Parameter state: The current `TranscriptionState` of the transcription process public typealias TranscriptionStateCallback = (_ state: TranscriptionState) -> Void diff --git a/Sources/WhisperKit/Core/TextDecoder.swift b/Sources/WhisperKit/Core/TextDecoder.swift index 4d717b6e..65815366 100644 --- a/Sources/WhisperKit/Core/TextDecoder.swift +++ b/Sources/WhisperKit/Core/TextDecoder.swift @@ -2,6 +2,7 @@ // Copyright © 2024 Argmax, Inc. All rights reserved. import Accelerate +import ArgmaxCore import CoreML import Tokenizers diff --git a/Sources/WhisperKit/Core/WhisperKit.swift b/Sources/WhisperKit/Core/WhisperKit.swift index 5349ee99..0cc30004 100644 --- a/Sources/WhisperKit/Core/WhisperKit.swift +++ b/Sources/WhisperKit/Core/WhisperKit.swift @@ -77,7 +77,8 @@ open class WhisperKit { modelRepo: config.modelRepo, modelToken: config.modelToken, modelFolder: config.modelFolder, - download: config.download + download: config.download, + endpoint: config.modelEndpoint ?? Constants.defaultRemoteEndpoint ) if let prewarm = config.prewarm, prewarm { @@ -173,6 +174,7 @@ open class WhisperKit { remoteConfigName: remoteConfigName, endpoint: endpoint ) + return ModelUtilities.modelSupport(for: deviceName, from: config) } diff --git a/Sources/WhisperKit/Utilities/Concurrency.swift b/Sources/WhisperKit/Utilities/Concurrency.swift index 92b82f74..57bdcafe 100644 --- a/Sources/WhisperKit/Utilities/Concurrency.swift +++ b/Sources/WhisperKit/Utilities/Concurrency.swift @@ -1,89 +1,8 @@ // For licensing see accompanying LICENSE.md file. // Copyright © 2024 Argmax, Inc. All rights reserved. 
-import Foundation -import os.lock +import ArgmaxCore -/// An actor that provides thread-safe early stopping functionality using UUIDs as keys -public actor EarlyStopActor { - private var shouldStop = [UUID: Bool]() - - public init() {} - - /// Sets the stop flag for a given UUID - /// - Parameters: - /// - value: The boolean value to set - /// - uuid: The UUID key - public func set(_ value: Bool, for uuid: UUID) { - shouldStop[uuid] = value - } - - /// Gets the stop flag for a given UUID - /// - Parameter uuid: The UUID key - /// - Returns: The current stop flag value, or false if not set - public func get(for uuid: UUID) -> Bool { - return shouldStop[uuid] ?? false - } - - /// Removes and returns the stop flag for a given UUID - /// - Parameter uuid: The UUID key - /// - Returns: The removed stop flag value, if it existed - public func remove(for uuid: UUID) -> Bool? { - return shouldStop.removeValue(forKey: uuid) - } -} - -/// Serializes access to a value with an `os_unfair_lock` so mutation stays -/// thread-safe. The wrapper is used by `TranscriptionResult`, which is marked -/// `@unchecked Sendable`; guarding each property with this lock helps keep the -/// result instance safe when shared across concurrent contexts. -@propertyWrapper -public struct TranscriptionPropertyLock: Sendable, Codable { - private let lock: UnfairLock - private var value: Value - - public init(wrappedValue: Value) { - self.lock = UnfairLock() - self.value = wrappedValue - } - public init(from decoder: Swift.Decoder) throws { - self.lock = UnfairLock() - self.value = try Value(from: decoder) - } - - public func encode(to encoder: Encoder) throws { - try lock.withLock { - try value.encode(to: encoder) - } - - } - - public var wrappedValue: Value { - get { - lock.withLock { - return value - } - } - set { - lock.withLock { - value = newValue - } - } - } -} - -/// Thin wrapper around `os_unfair_lock` that exposes a Swift-friendly -/// `withLock` helper. 
This lock is non-reentrant and optimized for low -/// contention, matching the semantics of Core Foundation’s unfair lock. -@usableFromInline -final class UnfairLock: @unchecked Sendable { - @usableFromInline - var lock = os_unfair_lock() - - @inlinable - func withLock(_ body: () throws -> T) rethrows -> T { - os_unfair_lock_lock(&lock) - defer { os_unfair_lock_unlock(&lock) } - return try body() - } -} +// Backward compatibility: TranscriptionPropertyLock is now PropertyLock in ArgmaxCore. +// Existing consumers can continue using the old name. +public typealias TranscriptionPropertyLock = PropertyLock diff --git a/Sources/WhisperKit/Utilities/Extensions+Internal.swift b/Sources/WhisperKit/Utilities/Extensions+Internal.swift index b99af3bf..0f4432c1 100644 --- a/Sources/WhisperKit/Utilities/Extensions+Internal.swift +++ b/Sources/WhisperKit/Utilities/Extensions+Internal.swift @@ -4,32 +4,6 @@ import AVFoundation import CoreML -extension MLMultiArray { - /// All values will be stored in the last dimension of the MLMultiArray (default is dims=1) - static func from(_ array: [Int], dims: Int = 1) throws -> MLMultiArray { - var shape = Array(repeating: 1, count: dims) - shape[shape.count - 1] = array.count - /// Examples: - /// dims=1 : [arr.count] - /// dims=2 : [1, arr.count] - /// - let output = try MLMultiArray(shape: shape as [NSNumber], dataType: .int32) - let pointer = UnsafeMutablePointer(OpaquePointer(output.dataPointer)) - for (i, item) in array.enumerated() { - pointer[i] = Int32(item) - } - return output - } -} - -extension Array { - func batched(into size: Int) -> [[Element]] { - return stride(from: 0, to: count, by: size).map { - Array(self[$0.. { /// Convenience method to convert the `Result` object into an array of optional arrays of `TranscriptionResult`. /// - Returns: An array of optional arrays containing `TranscriptionResult`. 
@@ -38,41 +12,6 @@ extension Array where Element == Result<[TranscriptionResult], Swift.Error> { } } -extension Array where Element: Hashable { - /// Returns an array with duplicates removed, preserving the original order. - var orderedSet: [Element] { - var seen = Set() - return self.filter { element in - if seen.contains(element) { - return false - } else { - seen.insert(element) - return true - } - } - } -} - -extension String { - /// Reference: https://github.com/huggingface/swift-transformers/blob/94610577e4af9bbc267060af1e25e977604dd796/Sources/Tokenizers/Decoder.swift#L267-L275 - func trimmingFromEnd(character: Character = " ", upto: Int) -> String { - var result = self - var trimmed = 0 - while trimmed < upto && result.last == character { - result.removeLast() - trimmed += 1 - } - return result - } -} - -extension [String] { - /// Reference: https://github.com/huggingface/swift-transformers/blob/94610577e4af9bbc267060af1e25e977604dd796/Sources/Hub/HubApi.swift#L983-L987 - func matching(glob: String) -> [String] { - filter { fnmatch(glob, $0, 0) == 0 } - } -} - extension AVAudioPCMBuffer { /// Converts the buffer to a float array func asFloatArray() throws -> [Float] { diff --git a/Sources/WhisperKit/Utilities/Extensions+Public.swift b/Sources/WhisperKit/Utilities/Extensions+Public.swift index 0cca123e..c109c3be 100644 --- a/Sources/WhisperKit/Utilities/Extensions+Public.swift +++ b/Sources/WhisperKit/Utilities/Extensions+Public.swift @@ -20,13 +20,6 @@ public extension WhisperKit { } } -public extension Float { - func rounded(_ decimalPlaces: Int) -> Float { - let divisor = pow(10.0, Float(decimalPlaces)) - return (self * divisor).rounded() / divisor - } -} - public extension String { var normalized: String { // Convert to lowercase @@ -52,245 +45,6 @@ public extension String { } } -// MARK: CoreML - -public extension MLMultiArray { - convenience init(shape: [NSNumber], dataType: MLMultiArrayDataType, initialValue: Any) throws { - switch dataType { - case 
.float16: - // IOSurface-backed arrays are implicitly float16. They can - // reduce buffer copies for some OS:compute unit combinations. - guard let pixelBuffer = Self.pixelBuffer(for: shape) else { - throw WhisperError.initializationError("MLMultiArray: Failed to initialize PixelBuffer") - } - self.init(pixelBuffer: pixelBuffer, shape: shape) - default: - try self.init(shape: shape, dataType: dataType) - } - - switch dataType { - case .double: - if let value = initialValue as? Double { - let typedPointer = dataPointer.bindMemory(to: Double.self, capacity: count) - typedPointer.initialize(repeating: value, count: count) - } - case .float32: - if let value = initialValue as? Float { - let typedPointer = dataPointer.bindMemory(to: Float.self, capacity: count) - typedPointer.initialize(repeating: value, count: count) - } - case .float16: - if let value = initialValue as? FloatType { - let typedPointer = dataPointer.bindMemory(to: FloatType.self, capacity: count) - typedPointer.initialize(repeating: value, count: count) - } - case .int32: - if let value = initialValue as? Int32 { - let typedPointer = dataPointer.bindMemory(to: Int32.self, capacity: count) - typedPointer.initialize(repeating: value, count: count) - } - @unknown default: - fatalError("Unsupported data type") - } - } - - /// Calculate the linear offset by summing the products of each dimension's index with the dimension's stride. - /// More info [here](https://developer.apple.com/documentation/coreml/mlmultiarray/2879231-subscript) - /// - Parameters: - /// - index: The index of the element - /// - strides: The precomputed strides of the multi-array, if not provided, it will be computed. It's a performance optimization to avoid recomputing the strides every time when accessing the multi-array with multiple indexes. - @inline(__always) - func linearOffset(for index: [NSNumber], strides strideInts: [Int]? = nil) -> Int { - var linearOffset = 0 - let strideInts = strideInts ?? 
strides.map { $0.intValue } - for (dimension, stride) in zip(index, strideInts) { - linearOffset += dimension.intValue * stride - } - return linearOffset - } - - func fillLastDimension(indexes: Range, with value: FloatType) { - precondition(shape.count == 3 && shape[0] == 1 && shape[1] == 1, "Must have [1, 1, n] shape") - withUnsafeMutableBufferPointer(ofType: FloatType.self) { ptr, strides in - for index in indexes { - ptr[index * strides[2]] = value - } - } - } - - func fill(indexes: [[NSNumber]], with value: Value) { - let pointer = UnsafeMutablePointer(OpaquePointer(dataPointer)) - let strideInts = strides.map { $0.intValue } - for index in indexes { - let linearOffset = linearOffset(for: index, strides: strideInts) - pointer[linearOffset] = value - } - } - - private class func pixelBuffer(for shape: [NSNumber]) -> CVPixelBuffer? { - guard let width = shape.last?.intValue else { return nil } - let height = shape[0.. [Int] { - let semaphore = DispatchSemaphore(value: 0) - var result: [Int] = [] - - Task(priority: .high) { - result = await self.shapedArray(of: Int32.self).scalars.map { Int($0) } - semaphore.signal() - } - - semaphore.wait() - return result - } - - func asFloatArray() -> [Float] { - let semaphore = DispatchSemaphore(value: 0) - let tensorType = self.scalarType - - var result: [Float] = [] - - Task(priority: .high) { - switch tensorType { - case is Float32.Type: - result = await self.shapedArray(of: Float32.self).scalars.map { Float($0) } - case is FloatType.Type: - result = await self.shapedArray(of: FloatType.self).scalars.map { Float($0) } - case is Float.Type: - result = await self.shapedArray(of: Float.self).scalars.map { Float($0) } - case is Int32.Type: - result = await self.shapedArray(of: Int32.self).scalars.map { Float($0) } - default: - fatalError("Unsupported data type") - } - semaphore.signal() - } - - semaphore.wait() - return result - } - - func asMLMultiArray() -> MLMultiArray { - let semaphore = DispatchSemaphore(value: 0) - let 
tensorType = self.scalarType - - var result = try! MLMultiArray(shape: [1], dataType: .float16, initialValue: 0.0) - - Task(priority: .high) { - switch tensorType { - case is Float32.Type: - result = MLMultiArray(await self.shapedArray(of: Float32.self)) - case is FloatType.Type: - result = MLMultiArray(await self.shapedArray(of: FloatType.self)) - case is Float.Type: - result = MLMultiArray(await self.shapedArray(of: Float.self)) - case is Int32.Type: - result = MLMultiArray(await self.shapedArray(of: Int32.self)) - default: - fatalError("Unsupported data type") - } - semaphore.signal() - } - - semaphore.wait() - return result - } -} -#endif - -public extension MLModel { - func asyncPrediction( - from input: MLFeatureProvider, - options: MLPredictionOptions - ) async throws -> MLFeatureProvider { - if #available(macOS 14, iOS 17, watchOS 10, visionOS 1, *) { - return try await prediction(from: input, options: options) - } else { - return try await Task { - try prediction(from: input, options: options) - }.value - } - } -} - -public extension MLComputeUnits { - var description: String { - switch self { - case .cpuOnly: - return "cpuOnly" - case .cpuAndGPU: - return "cpuAndGPU" - case .all: - return "all" - case .cpuAndNeuralEngine: - return "cpuAndNeuralEngine" - @unknown default: - return "unknown" - } - } -} - -#if os(macOS) || targetEnvironment(simulator) -// From: https://stackoverflow.com/a/71726663 -public extension ProcessInfo { - static func stringFromSysctl(named name: String) -> String { - var size: size_t = 0 - sysctlbyname(name, nil, &size, nil, 0) - var machineModel = [CChar](repeating: 0, count: Int(size)) - sysctlbyname(name, &machineModel, &size, nil, 0) - return String(cString: machineModel) - } - - static let processor = stringFromSysctl(named: "machdep.cpu.brand_string") - static let cores = stringFromSysctl(named: "machdep.cpu.core_count") - static let threads = stringFromSysctl(named: "machdep.cpu.thread_count") - static let vendor = 
stringFromSysctl(named: "machdep.cpu.vendor") - static let family = stringFromSysctl(named: "machdep.cpu.family") - static let hwModel = stringFromSysctl(named: "hw.model") -} -#endif - -// MARK: FileManager - -public extension FileManager { - static func resolveAbsolutePath(_ inputPath: String) -> String { - let fileManager = FileManager.default - - // Expanding tilde if present - let pathWithTildeExpanded = NSString(string: inputPath).expandingTildeInPath - - // If the path is already absolute, return it - if pathWithTildeExpanded.hasPrefix("/") { - return pathWithTildeExpanded - } - - // Resolving relative path based on the current working directory - if let cwd = fileManager.currentDirectoryPath as String? { - let resolvedPath = URL(fileURLWithPath: cwd).appendingPathComponent(pathWithTildeExpanded).path - return resolvedPath - } - - return inputPath - } -} - @available(*, deprecated, message: "Subject to removal in a future version. Use `FileManager.resolveAbsolutePath(_:)` instead.") public func resolveAbsolutePath(_ inputPath: String) -> String { return FileManager.resolveAbsolutePath(inputPath) diff --git a/Sources/WhisperKit/Utilities/Logging.swift b/Sources/WhisperKit/Utilities/Logging.swift index bea4baeb..5819f278 100644 --- a/Sources/WhisperKit/Utilities/Logging.swift +++ b/Sources/WhisperKit/Utilities/Logging.swift @@ -1,92 +1,10 @@ // For licensing see accompanying LICENSE.md file. // Copyright © 2024 Argmax, Inc. All rights reserved. +import ArgmaxCore import OSLog -open class Logging { - public static let shared = Logging() - public var logLevel: LogLevel = .none - - public typealias LoggingCallback = (_ message: String) -> Void - public var loggingCallback: LoggingCallback? - - private let logger = OSLog(subsystem: Bundle.main.bundleIdentifier ?? 
"com.argmax.whisperkit", category: "WhisperKit") - - @frozen - public enum LogLevel: Int { - case debug = 1 - case info = 2 - case error = 3 - case none = 4 - - func shouldLog(level: LogLevel) -> Bool { - return self.rawValue <= level.rawValue - } - } - - private init() {} - - public func log(_ items: Any..., separator: String = " ", terminator: String = "\n", type: OSLogType) { - let message = items.map { "\($0)" }.joined(separator: separator) - if let logger = loggingCallback { - logger(message) - } else { - os_log("%{public}@", log: logger, type: type, message) - } - } - - public static func debug(_ items: Any..., separator: String = " ", terminator: String = "\n") { - if shared.logLevel.shouldLog(level: .debug) { - shared.log(items, separator: separator, terminator: terminator, type: .debug) - } - } - - public static func info(_ items: Any..., separator: String = " ", terminator: String = "\n") { - if shared.logLevel.shouldLog(level: .info) { - shared.log(items, separator: separator, terminator: terminator, type: .info) - } - } - - public static func error(_ items: Any..., separator: String = " ", terminator: String = "\n") { - if shared.logLevel.shouldLog(level: .error) { - shared.log(items, separator: separator, terminator: terminator, type: .error) - } - } -} - -public extension Logging { - static func logCurrentMemoryUsage(_ message: String) { - let memoryUsage = getMemoryUsage() - Logging.debug("\(message) - Memory usage: \(memoryUsage) MB") - } - - static func getMemoryUsage() -> UInt64 { - var info = mach_task_basic_info() - var count = mach_msg_type_number_t(MemoryLayout.size) / 4 - - let kerr: kern_return_t = withUnsafeMutablePointer(to: &info) { - $0.withMemoryRebound(to: integer_t.self, capacity: 1) { - task_info(mach_task_self_, task_flavor_t(MACH_TASK_BASIC_INFO), $0, &count) - } - } - - guard kerr == KERN_SUCCESS else { - return 0 // If the call fails, return 0 - } - - return info.resident_size / 1024 / 1024 // Convert to MB - } -} - 
-@available(*, deprecated, message: "Subject to removal in a future version. Use `Logging.logCurrentMemoryUsage(_:)` instead.") -public func logCurrentMemoryUsage(_ message: String) { - Logging.logCurrentMemoryUsage(message) -} - -@available(*, deprecated, message: "Subject to removal in a future version. Use `Logging.getMemoryUsage()` instead.") -public func getMemoryUsage() -> UInt64 { - return Logging.getMemoryUsage() -} +// MARK: - WhisperKit signpost categories extension Logging { enum AudioEncoding { @@ -133,11 +51,5 @@ extension Logging { return String(format: "%.2f", timestamp) } - static func formatTimeWithPercentage(_ time: Double, _ runs: Double, _ fullPipelineDuration: Double) -> String { - let percentage = (time * 1000 / fullPipelineDuration) * 100 // Convert to percentage - let runTime = runs > 0 ? time * 1000 / Double(runs) : 0 - let formattedString = String(format: "%8.2f ms / %6.0f runs (%8.2f ms/run) %5.2f%%", time * 1000, runs, runTime, percentage) - return formattedString - } +// formatTimeWithPercentage is defined in ArgmaxCore/Logging.swift and re-exported here. } - diff --git a/Sources/WhisperKit/Utilities/ModelUtilities.swift b/Sources/WhisperKit/Utilities/ModelUtilities.swift index 5af4d950..be5296cd 100644 --- a/Sources/WhisperKit/Utilities/ModelUtilities.swift +++ b/Sources/WhisperKit/Utilities/ModelUtilities.swift @@ -1,15 +1,14 @@ // For licensing see accompanying LICENSE.md file. // Copyright © 2024 Argmax, Inc. All rights reserved. +import ArgmaxCore import CoreML import Hub import Tokenizers -public struct ModelUtilities { +extension ModelUtilities { - private init() {} - - // MARK: Public + // MARK: - WhisperKit Model Support public static func modelSupport(for deviceName: String, from config: ModelSupportConfig? = nil) -> ModelSupport { let config = config ?? 
Constants.fallbackModelSupportConfig @@ -85,22 +84,6 @@ public struct ModelUtilities { at: hubTokenizerFolder ) } - - public static func detectModelURL(inFolder path: URL, named modelName: String) -> URL { - let compiledUrl = path.appending(path: "\(modelName).mlmodelc") - let packageUrl = path.appending(path: "\(modelName).mlpackage/Data/com.apple.CoreML/model.mlmodel") - - let compiledModelExists: Bool = FileManager.default.fileExists(atPath: compiledUrl.path) - let packageModelExists: Bool = FileManager.default.fileExists(atPath: packageUrl.path) - - // Swap to mlpackage only if the following is true: we found the mlmodel within the mlpackage, and we did not find a .mlmodelc - var modelURL = compiledUrl - if packageModelExists && !compiledModelExists { - modelURL = packageUrl - } - - return modelURL - } /// Formats and sorts model file names based on model variants /// @@ -198,40 +181,6 @@ public struct ModelUtilities { return modelVariant } - static func getModelInputDimention(_ model: MLModel?, named: String, position: Int) -> Int? { - guard let inputDescription = model?.modelDescription.inputDescriptionsByName[named] else { return nil } - guard inputDescription.type == .multiArray else { return nil } - guard let shapeConstraint = inputDescription.multiArrayConstraint else { return nil } - let shape = shapeConstraint.shape.map { $0.intValue } - return shape[position] - } - - static func getModelOutputDimention(_ model: MLModel?, named: String, position: Int) -> Int? { - guard let inputDescription = model?.modelDescription.outputDescriptionsByName[named] else { return nil } - guard inputDescription.type == .multiArray else { return nil } - guard let shapeConstraint = inputDescription.multiArrayConstraint else { return nil } - let shape = shapeConstraint.shape.map { $0.intValue } - return shape[position] - } - - func getModelInputDimention(_ model: MLModel?, named: String, position: Int) -> Int? 
{ - guard let inputDescription = model?.modelDescription.inputDescriptionsByName[named] else { return nil } - guard inputDescription.type == .multiArray else { return nil } - guard let shapeConstraint = inputDescription.multiArrayConstraint else { return nil } - let shape = shapeConstraint.shape.map { $0.intValue } - return shape[position] - } - - func getModelOutputDimention(_ model: MLModel?, named: String, position: Int) -> Int? { - guard let inputDescription = model?.modelDescription.outputDescriptionsByName[named] else { return nil } - guard inputDescription.type == .multiArray else { return nil } - guard let shapeConstraint = inputDescription.multiArrayConstraint else { return nil } - let shape = shapeConstraint.shape.map { $0.intValue } - return shape[position] - } - - // MARK: Private - internal static func tokenizerNameForVariant(_ variant: ModelVariant) -> String { var tokenizerName: String switch variant { diff --git a/Sources/WhisperKitCLI/TTSCLI.swift b/Sources/WhisperKitCLI/TTSCLI.swift new file mode 100644 index 00000000..7594328a --- /dev/null +++ b/Sources/WhisperKitCLI/TTSCLI.swift @@ -0,0 +1,279 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright © 2024 Argmax, Inc. All rights reserved. + +import ArgumentParser +import CoreML +import Foundation +import TTSKit +import WhisperKit + +// MARK: - CLI-only conformances for ArgumentParser + +extension Qwen3Speaker: ExpressibleByArgument {} +extension Qwen3Language: ExpressibleByArgument {} +extension TTSModelVariant: ExpressibleByArgument {} + +// MARK: - CLI Command + +struct TTSCLI: AsyncParsableCommand { + static let configuration = CommandConfiguration( + commandName: "tts", + abstract: "Generate speech from text using Qwen3-TTS" + ) + + // MARK: - Required (one of text or text-file) + + @Option(name: .long, help: "Text to synthesize") + var text: String? + + @Option(name: .long, help: "Read text from a file (supports .txt and .md)") + var textFile: String? 
+ + // MARK: - Common options + + @Option(name: .long, help: "Speaker voice (aiden, ryan, ono-anna, sohee, eric, dylan, serena, vivian, uncle-fu)") + var speaker: Qwen3Speaker = .aiden + + @Option(name: .long, help: "Language (english, chinese, japanese, korean)") + var language: Qwen3Language = .english + + @Option(name: .long, help: "Output audio file path") + var outputPath: String = "output" + + @Option(name: .long, help: "Output audio format (m4a or wav)") + var outputFormat: String = "m4a" + + @Flag(name: .long, help: "Play audio through speakers in real time") + var play: Bool = false + + @Flag(name: .long, help: "Enable verbose output") + var verbose: Bool = false + + // MARK: - Generation options + + @Option(name: .long, help: "Sampling temperature (0.0 for greedy)") + var temperature: Float = 0.9 + + @Option(name: .long, help: "Top-k sampling (0 to disable)") + var topK: Int = 50 + + @Option(name: .long, help: "Max RVQ frames to generate") + var maxNewTokens: Int = 245 + + @Option(name: .long, help: "Concurrent chunk workers (0=max, 1=sequential, N=batch size). Defaults to 1 with --play, 0 otherwise.") + var concurrentWorkerCount: Int? + + @Option(name: .long, help: "Target chunk size in characters for sentence splitting") + var targetChunkSize: Int = TextChunker.defaultTargetChunkSize + + @Option(name: .long, help: "Minimum chunk size in characters (short tails merge into previous chunk)") + var minChunkSize: Int = TextChunker.defaultMinChunkSize + + @Option(name: .long, help: "Style instruction (e.g., \"Speak slowly and softly\"). Only supported by the 1.7B model.") + var instruction: String? + + @Option(name: .long, help: "Random seed for reproducible output") + var seed: UInt64? + + // MARK: - Model selection + + @Option(name: .long, help: "Model preset (0.6b, 1.7b). 
Auto-configures version dir and variant defaults.") + var model: TTSModelVariant = .qwen3TTS_0_6b + + // MARK: - Advanced options (auto-configured by preset, can be overridden) + + @Option(name: .long, help: "Local model directory (skips download if provided)") + var modelsPath: String? + + @Option(name: .long, help: "HuggingFace repo for model download") + var modelRepo: String = Qwen3TTSConstants.defaultModelRepo + + @Option(name: .long, help: "Model version directory (overrides --model preset)") + var versionDir: String? + + @Option(name: .long, help: "HuggingFace tokenizer repo or local path") + var tokenizer: String? + + @Option(name: .long, help: "HuggingFace API token (for private repos, or set HF_TOKEN env var)") + var token: String? + + @Option(name: .long, help: "HuggingFace Hub compatible endpoint URL") + var endpoint: String? + + @Option(name: .long, help: "CodeDecoder variant (overrides --model preset)") + var codeDecoderVariant: String? + + @Option(name: .long, help: "MultiCodeDecoder variant (overrides --model preset)") + var multiCodeDecoderVariant: String? + + @Option(name: .long, help: "SpeechDecoder variant (overrides --model preset)") + var speechDecoderVariant: String? 
+ + // MARK: - Compute unit options + + @Option(name: .long, help: "Compute units for embedders (TextProjector, CodeEmbedder, MultiCodeEmbedder) {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine}") + var embedderComputeUnits: ComputeUnits = .cpuOnly + + @Option(name: .long, help: "Compute units for CodeDecoder {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine}") + var codeDecoderComputeUnits: ComputeUnits = .cpuAndNeuralEngine + + @Option(name: .long, help: "Compute units for MultiCodeDecoder {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine}") + var multiCodeDecoderComputeUnits: ComputeUnits = .cpuAndNeuralEngine + + @Option(name: .long, help: "Compute units for SpeechDecoder {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine}") + var speechDecoderComputeUnits: ComputeUnits = .cpuAndNeuralEngine + + func run() async throws { + if verbose { + Logging.shared.loggingCallback = { + print("[TTSKit] \($0)") + } + } + // Resolve text from --text or --text-file + let inputText: String + if let textFile { + let resolvedPath = FileManager.resolveAbsolutePath(textFile) + inputText = try String(contentsOfFile: resolvedPath, encoding: .utf8) + .trimmingCharacters(in: .whitespacesAndNewlines) + } else if let text { + inputText = text + } else { + throw ValidationError("Either --text or --text-file must be provided") + } + + guard !inputText.isEmpty else { + throw ValidationError("Input text is empty") + } + + // Resolve local models path if provided + let resolvedModelFolder: URL? = modelsPath.map { + URL(fileURLWithPath: FileManager.resolveAbsolutePath($0)) + } + // Resolve tokenizer: pass as URL — loadTokenizer falls back to HF Hub if the path doesn't exist + let resolvedTokenizerFolder: URL? = tokenizer.map { + URL(fileURLWithPath: $0) + } + + // Build config from preset — explicit CLI overrides replace preset defaults. 
+ let config = TTSKitConfig( + model: model, + modelFolder: resolvedModelFolder, + modelRepo: modelRepo, + tokenizerFolder: resolvedTokenizerFolder, + modelToken: token, + modelEndpoint: endpoint ?? Qwen3TTSConstants.defaultEndpoint, + versionDir: versionDir, + codeDecoderVariant: codeDecoderVariant, + multiCodeDecoderVariant: multiCodeDecoderVariant, + speechDecoderVariant: speechDecoderVariant, + computeOptions: ComputeOptions( + embedderComputeUnits: embedderComputeUnits.asMLComputeUnits, + codeDecoderComputeUnits: codeDecoderComputeUnits.asMLComputeUnits, + multiCodeDecoderComputeUnits: multiCodeDecoderComputeUnits.asMLComputeUnits, + speechDecoderComputeUnits: speechDecoderComputeUnits.asMLComputeUnits + ), + verbose: verbose + ) + + // Default: --play uses sequential (1), file output uses unlimited (0). + let effectiveWorkerCount = concurrentWorkerCount ?? (play ? 1 : 0) + + // Always use a seed for reproducibility -- generate one if not provided + let effectiveSeed = seed ?? UInt64.random(in: 0...UInt64(UInt32.max)) + + // Warn if instruction is used with a model that doesn't support it + var effectiveInstruction = instruction + if let instruction = effectiveInstruction, !instruction.isEmpty, model == .qwen3TTS_0_6b { + print("Warning: --instruction is only supported by the 1.7B model variant. Ignoring instruction for \(model.rawValue).") + effectiveInstruction = nil + } + + if verbose { + print("Qwen3-TTS Pipeline") + if textFile != nil { + print(" Text file: \(textFile!)") + } + print(" Text: \"\(inputText.prefix(80))\(inputText.count > 80 ? "..." 
: "")\"") + print(" Speaker: \(speaker.rawValue)") + print(" Language: \(language.rawValue)") + print(" Model: \(model.rawValue)") + if let inst = effectiveInstruction { + print(" Instruction: \"\(inst)\"") + } + if let folder = resolvedModelFolder { + print(" Models: \(folder.path)") + } else { + print(" Models: \(config.modelRepo) (auto-download)") + } + print(" Version: \(config.versionDir)") + print(" CodeDecoder: \(config.codeDecoderVariant)") + print(" MultiCodeDecoder: \(config.multiCodeDecoderVariant)") + print(" SpeechDecoder: \(config.speechDecoderVariant)") + print(" Embedder compute: \(embedderComputeUnits.rawValue)") + print(" CodeDecoder compute: \(codeDecoderComputeUnits.rawValue)") + print(" MultiCodeDecoder compute: \(multiCodeDecoderComputeUnits.rawValue)") + print(" SpeechDecoder compute: \(speechDecoderComputeUnits.rawValue)") + print(" Output: \(outputPath).\(outputFormat.lowercased())") + print(" Temperature: \(temperature)") + print(" Top-k: \(topK)") + print(" Play: \(play)") + let workerDesc = effectiveWorkerCount == 0 ? 
"max" : "\(effectiveWorkerCount)" + print(" Concurrency: \(workerDesc) (chunking: sentence)") + print(" Seed: \(effectiveSeed)") + } + + // Initialize pipeline (downloads if needed, loads tokenizer + 6 models concurrently) + config.seed = effectiveSeed + let tts = try await TTSKit(config) + + let options = GenerationOptions( + temperature: temperature, + topK: topK, + repetitionPenalty: 1.05, + maxNewTokens: maxNewTokens, + concurrentWorkerCount: effectiveWorkerCount, + targetChunkSize: targetChunkSize, + minChunkSize: minChunkSize, + instruction: effectiveInstruction + ) + + let result: SpeechResult + if play { + result = try await tts.play( + text: inputText, + speaker: speaker, + language: language, + options: options + ) + } else { + result = try await tts.generate( + text: inputText, + speaker: speaker, + language: language, + options: options + ) + } + + let format = AudioOutput.AudioFileFormat(rawValue: outputFormat.lowercased()) ?? .m4a + let outputURL = URL(fileURLWithPath: outputPath) + let outputFolder = outputURL.deletingLastPathComponent().path == "." + ? 
URL(fileURLWithPath: FileManager.default.currentDirectoryPath) + : outputURL.deletingLastPathComponent() + + let savedURL = try await AudioOutput.saveAudio( + result.audio, + toFolder: outputFolder, + filename: outputURL.lastPathComponent, + sampleRate: result.sampleRate, + format: format + ) + + result.logTimings() + + if verbose { + print(String(format: "Generated %.2fs of audio -> %@", result.audioDuration, savedURL.path)) + } else { + print(savedURL.path) + } + } +} diff --git a/Sources/WhisperKitCLI/TranscribeCLI.swift b/Sources/WhisperKitCLI/TranscribeCLI.swift index 0ce09a4f..3701d725 100644 --- a/Sources/WhisperKitCLI/TranscribeCLI.swift +++ b/Sources/WhisperKitCLI/TranscribeCLI.swift @@ -5,6 +5,7 @@ import ArgumentParser import CoreML import Foundation import WhisperKit +import TTSKit struct TranscribeCLI: AsyncParsableCommand { static let configuration = CommandConfiguration( diff --git a/Sources/WhisperKitCLI/TranscribeCLIArguments.swift b/Sources/WhisperKitCLI/TranscribeCLIArguments.swift index eb5d1d47..0f54f1be 100644 --- a/Sources/WhisperKitCLI/TranscribeCLIArguments.swift +++ b/Sources/WhisperKitCLI/TranscribeCLIArguments.swift @@ -25,6 +25,9 @@ struct TranscribeCLIArguments: ParsableArguments { @Option(help: "Path to save the downloaded tokenizer files") var downloadTokenizerPath: String? + @Option(name: .long, help: "HuggingFace Hub compatible endpoint URL") + var endpoint: String? 
+ @Option(help: "Compute units for audio encoder model with {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine,random}") var audioEncoderComputeUnits: ComputeUnits = .cpuAndNeuralEngine diff --git a/Sources/WhisperKitCLI/TranscribeCLIUtils.swift b/Sources/WhisperKitCLI/TranscribeCLIUtils.swift index 98dbfe41..586429b0 100644 --- a/Sources/WhisperKitCLI/TranscribeCLIUtils.swift +++ b/Sources/WhisperKitCLI/TranscribeCLIUtils.swift @@ -31,6 +31,7 @@ internal class TranscribeCLIUtils { return WhisperKitConfig( model: modelName, downloadBase: downloadModelFolder, + modelEndpoint: arguments.endpoint, modelFolder: arguments.modelPath, tokenizerFolder: downloadTokenizerFolder, computeOptions: computeOptions, diff --git a/Sources/WhisperKitCLI/WhisperKitCLI.swift b/Sources/WhisperKitCLI/WhisperKitCLI.swift index 3ceec64b..5ed6cc50 100644 --- a/Sources/WhisperKitCLI/WhisperKitCLI.swift +++ b/Sources/WhisperKitCLI/WhisperKitCLI.swift @@ -8,9 +8,9 @@ let VERSION: String = "development" var subcommands: [ParsableCommand.Type] { #if BUILD_SERVER_CLI - [TranscribeCLI.self, ServeCLI.self] + [TranscribeCLI.self, TTSCLI.self, ServeCLI.self] #else - [TranscribeCLI.self] + [TranscribeCLI.self, TTSCLI.self] #endif } diff --git a/Tests/TTSKitTests/TTSKitIntegrationTests.swift b/Tests/TTSKitTests/TTSKitIntegrationTests.swift new file mode 100644 index 00000000..48f7b23a --- /dev/null +++ b/Tests/TTSKitTests/TTSKitIntegrationTests.swift @@ -0,0 +1,495 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright © 2024 Argmax, Inc. All rights reserved. + +import AVFoundation +import CoreML +import Foundation +@testable import TTSKit +import XCTest + +// MARK: - Helpers + +extension XCTestCase { + /// Create a `TTSKit` instance for integration tests. + /// + /// Downloads models from the Hub if not already cached on disk, matching the + /// same pattern as `tinyModelPath()` in WhisperKit's test suite. The test + /// will fail (not skip) if the download fails. 
+ func makeCachedTTS( + model: TTSModelVariant = .qwen3TTS_0_6b, + seed: UInt64 = 42 + ) async throws -> TTSKit { + let config = TTSKitConfig(model: model, verbose: true, logLevel: .debug, seed: seed) + return try await TTSKit(config) + } +} + +// MARK: - Integration Tests + +final class TTSKitIntegrationTests: XCTestCase { + + // MARK: - Basic generation + + /// Short sentence -> non-empty audio at 24kHz. + func testBasicShortGeneration() async throws { + let tts = try await makeCachedTTS(seed: 42) + let result = try await tts.generate( + text: "Hello, this is a basic smoke test of the WhisperKit TTS pipeline.", + speaker: .ryan, + language: .english, + options: GenerationOptions(temperature: 0.9, topK: 50, maxNewTokens: 245) + ) + + XCTAssertGreaterThan(result.audio.count, 0, "Audio samples should be non-empty") + XCTAssertGreaterThan(result.audioDuration, 1.0, "Expect at least 1s of speech") + XCTAssertLessThan(result.audioDuration, 30.0, "Short sentence should stay under 30s") + XCTAssertEqual(result.sampleRate, 24000, "Sample rate should be 24kHz") + } + + // MARK: - Long text with sentence chunking + + /// 200+ word multi-paragraph text - chunked, all audio produced, no truncation crash. + func testLongTextSequentialChunks() async throws { + let tts = try await makeCachedTTS(seed: 42) + let longText = """ + The history of artificial intelligence begins in antiquity, with myths, stories, \ + and rumors of artificial beings endowed with intelligence by master craftsmen. \ + The seeds of modern AI were planted by classical philosophers who attempted to \ + describe the process of human thinking as the mechanical manipulation of symbols. \ + In 1950, Alan Turing published a landmark paper in which he speculated about the \ + possibility of creating machines that think. Fast forward to the 2020s, and large \ + language models have transformed every domain from science to storytelling. 
\ + Text-to-speech systems now produce voices so natural that listeners struggle to \ + distinguish them from real human speakers. + """ + + let result = try await tts.generate( + text: longText, + speaker: .aiden, + language: .english, + options: GenerationOptions(concurrentWorkerCount: 1) + ) + + XCTAssertGreaterThan(result.audioDuration, 20.0, "Long text should produce substantial audio") + XCTAssertGreaterThan(result.timings.totalDecodingLoops, 100, "Should require many decode steps") + } + + /// Same long text with unlimited concurrent workers + func testLongTextUnlimitedConcurrentWorkers() async throws { + let tts = try await makeCachedTTS(seed: 42) + let longText = """ + First paragraph about science. It covers many interesting topics in depth. \ + Second paragraph switches to history. Many events happened over the centuries. \ + Third paragraph explores music theory. Harmony and rhythm define the art form. + """ + + let result = try await tts.generate( + text: longText, + speaker: .aiden, + language: .english, + options: GenerationOptions(concurrentWorkerCount: 0) + ) + + XCTAssertGreaterThan(result.audioDuration, 5.0) + } + + // MARK: - Concurrent workers + + /// workers=2 on a 3-chunk text - verifies the batching loop handles partial batches. + func testConcurrentWorkersBatching() async throws { + let tts = try await makeCachedTTS(seed: 7) + let threeChunkText = """ + First sentence for chunk one. Second sentence for chunk two. Third sentence for chunk three. + """ + let result = try await tts.generate( + text: threeChunkText, + speaker: .serena, + language: .english, + options: GenerationOptions( + concurrentWorkerCount: 2, + targetChunkSize: 10, + minChunkSize: 2 + ) + ) + + XCTAssertGreaterThan(result.audio.count, 0) + XCTAssertGreaterThan(result.audioDuration, 2.0) + } + + // MARK: - Reproducibility + + /// Two runs with the same seed and temperature=0 should produce byte-identical audio. 
+    func testDeterministicOutputWithFixedSeed() async throws {
+        let text = "Reproducibility test with a fixed seed."
+        let options = GenerationOptions(temperature: 0.0, topK: 0, maxNewTokens: 100)
+
+        let tts1 = try await makeCachedTTS(seed: 42)
+        let result1 = try await tts1.generate(
+            text: text, speaker: .ryan, language: .english, options: options
+        )
+
+        let tts2 = try await makeCachedTTS(seed: 42)
+        let result2 = try await tts2.generate(
+            text: text, speaker: .ryan, language: .english, options: options
+        )
+
+        XCTAssertEqual(result1.audio.count, result2.audio.count, "Sample counts must match")
+        // Compare raw float bytes for strict reproducibility
+        let data1 = Data(bytes: result1.audio, count: result1.audio.count * MemoryLayout<Float>.size)
+        let data2 = Data(bytes: result2.audio, count: result2.audio.count * MemoryLayout<Float>.size)
+        XCTAssertEqual(data1, data2, "Audio samples must be byte-identical with the same seed and greedy decoding")
+    }
+
+    // MARK: - Speaker voices
+
+    /// All 9 Qwen3Speaker voices should produce non-empty audio without error.
+    func testAllSpeakerVoices() async throws {
+        let tts = try await makeCachedTTS(seed: 1)
+        let text = "Testing this voice."
+        let options = GenerationOptions(temperature: 0.0, topK: 0, maxNewTokens: 60)
+
+        for speaker in Qwen3Speaker.allCases {
+            let result = try await tts.generate(
+                text: text, speaker: speaker, language: .english, options: options
+            )
+            XCTAssertGreaterThan(result.audio.count, 0, "Speaker \(speaker.rawValue) produced no audio")
+        }
+    }
+
+    // MARK: - Language support
+
+    /// Korean text with a Korean-native speaker should produce non-empty audio.
+    func testKoreanLanguage() async throws {
+        let tts = try await makeCachedTTS(seed: 7)
+        let result = try await tts.generate(
+            text: "안녕하세요. 
저는 한국어를 말할 수 있습니다.", + speaker: .sohee, + language: .korean, + options: GenerationOptions(temperature: 0.9, maxNewTokens: 120) + ) + + XCTAssertGreaterThan(result.audio.count, 0) + XCTAssertGreaterThan(result.audioDuration, 0.5) + } + + // MARK: - Edge cases + + /// Single-word input should not crash and should produce short audio. + func testSingleWordInput() async throws { + let tts = try await makeCachedTTS(seed: 1) + let result = try await tts.generate( + text: "Hello.", + speaker: .ryan, + language: .english, + options: GenerationOptions(temperature: 0.0, topK: 0, maxNewTokens: 60) + ) + + XCTAssertGreaterThan(result.audio.count, 0) + XCTAssertLessThan(result.audioDuration, 5.0, "Single word should be short") + } + + /// Text with numbers, punctuation, currency, and percentages should not crash. + func testNumbersAndPunctuation() async throws { + let tts = try await makeCachedTTS(seed: 3) + let result = try await tts.generate( + text: "On January 1st, 2025, the price was $4.99 - a 12% discount from $5.67.", + speaker: .dylan, + language: .english, + options: GenerationOptions(maxNewTokens: 245) + ) + + XCTAssertGreaterThan(result.audio.count, 0) + } + + /// Text with emoji and mixed scripts (Latin + accented) should not crash. + func testUnicodeAndEmoji() async throws { + let tts = try await makeCachedTTS(seed: 5) + let result = try await tts.generate( + text: "Today's meeting is at 3pm 🎉. The café serves café au lait. Résumé updated.", + speaker: .ryan, + language: .english, + options: GenerationOptions(maxNewTokens: 200) + ) + + XCTAssertGreaterThan(result.audio.count, 0) + } + + // MARK: - WAV export + + /// saveAudio round-trip: write WAV then verify the file exists and is readable. 
+ func testSaveAudioRoundTrip() async throws { + let tts = try await makeCachedTTS(seed: 42) + let result = try await tts.generate( + text: "Audio export round trip test.", + speaker: .ryan, + language: .english, + options: GenerationOptions(temperature: 0.0, topK: 0, maxNewTokens: 80) + ) + + let tmpFolder = FileManager.default.temporaryDirectory + let filename = "tts_roundtrip_\(UUID().uuidString)" + let savedURL = try await AudioOutput.saveAudio( + result.audio, + toFolder: tmpFolder, + filename: filename, + sampleRate: result.sampleRate, + format: .wav + ) + defer { try? FileManager.default.removeItem(at: savedURL) } + + XCTAssertTrue(FileManager.default.fileExists(atPath: savedURL.path), "WAV file should exist after saveAudio") + + let duration = try await AudioOutput.duration(of: savedURL) + XCTAssertEqual(duration, result.audioDuration, accuracy: 0.1, "WAV duration should match result.audioDuration") + } + + // MARK: - SpeechModel protocol + + /// TTSKit should satisfy the SpeechModel protocol; generate via the protocol entry point. + func testTTSModelProtocolConformance() async throws { + let tts = try await makeCachedTTS(seed: 42) + let model: any SpeechModel = tts + + XCTAssertEqual(model.sampleRate, 24000) + + let result = try await model.generate( + text: "Protocol conformance check.", + voice: Qwen3Speaker.ryan.rawValue, + language: Qwen3Language.english.rawValue, + options: GenerationOptions(temperature: 0.0, topK: 0, maxNewTokens: 80), + callback: nil + ) + + XCTAssertGreaterThan(result.audio.count, 0) + XCTAssertEqual(result.sampleRate, 24000) + } + + // MARK: - Performance + + /// Verify timings are populated and the generation loop completed within a reasonable ceiling. 
+ func testTimingsArePopulated() async throws { + let tts = try await makeCachedTTS(seed: 42) + let result = try await tts.generate( + text: "Performance check: this sentence measures how fast the model runs on device.", + speaker: .ryan, + language: .english, + options: GenerationOptions(temperature: 0.0, topK: 0, maxNewTokens: 200) + ) + + XCTAssertGreaterThan(result.timings.totalDecodingLoops, 0, "Should have completed at least one decode step") + XCTAssertGreaterThan(result.timings.fullPipeline, 0, "Full pipeline duration should be non-zero") + XCTAssertGreaterThan(result.audioDuration, 0, "Audio duration should be non-zero") + } + + // MARK: - Prompt caching + + /// Build a prompt cache and verify it stores the expected metadata. + func testBuildPromptCache() async throws { + let tts = try await makeCachedTTS(seed: 42) + let cache = try await tts.buildPromptCache(speaker: .ryan, language: .english) + + XCTAssertEqual(cache.voice, "ryan") + XCTAssertEqual(cache.language, "english") + XCTAssertNil(cache.instruction) + XCTAssertGreaterThan(cache.prefixLength, 0, "Cache should contain at least one invariant token") + XCTAssertTrue(cache.matches(voice: "ryan", language: "english", instruction: nil)) + XCTAssertFalse(cache.matches(voice: "aiden", language: "english", instruction: nil)) + } + + /// Cached prefill should be faster than uncached since it only processes the variable token. + func testPromptCacheSpeedup() async throws { + let tts = try await makeCachedTTS(seed: 42) + let text = "Prompt cache speed test with enough words to generate meaningful audio output." 
+ let options = GenerationOptions(temperature: 0.9, topK: 50, maxNewTokens: 245) + + // Run without cache — full prefill of all tokens + let uncachedResult = try await tts.createTask().run( + text: text, voice: "ryan", language: "english", + options: options, callback: nil, prefixCache: nil + ) + let uncachedPrefill = uncachedResult.timings.prefill + + // Build cache, then run with it — only the variable token is prefilled + let cache = try await tts.buildPromptCache(speaker: .ryan, language: .english) + let cachedResult = try await tts.createTask().run( + text: text, voice: "ryan", language: "english", + options: options, callback: nil, prefixCache: cache + ) + let cachedPrefill = cachedResult.timings.prefill + + XCTAssertGreaterThan(uncachedResult.audioDuration, 1.0, + "Uncached run should produce at least 1s of audio") + XCTAssertGreaterThan(cachedResult.audioDuration, 1.0, + "Cached run should produce at least 1s of audio") + XCTAssertLessThan(cachedPrefill, uncachedPrefill, + "Cached prefill (\(cachedPrefill * 1000)ms) should be faster than uncached (\(uncachedPrefill * 1000)ms)") + + // Verify auto-build on generate works + tts.promptCache = nil + _ = try await tts.generate( + text: text, speaker: .ryan, language: .english, options: options + ) + XCTAssertNotNil(tts.promptCache, "Cache should be auto-built after generate") + } + + /// Two consecutive cached runs should produce valid audio of similar duration. + func testPromptCacheDeterminism() async throws { + let text = "Cache determinism test with a longer sentence so the model generates real speech." 
+ let options = GenerationOptions(temperature: 0.9, topK: 50, maxNewTokens: 245) + + let tts = try await makeCachedTTS(seed: 42) + let cache = try await tts.buildPromptCache(speaker: .ryan, language: .english) + + let result1 = try await tts.createTask().run( + text: text, voice: "ryan", language: "english", + options: options, callback: nil, prefixCache: cache + ) + let result2 = try await tts.createTask().run( + text: text, voice: "ryan", language: "english", + options: options, callback: nil, prefixCache: cache + ) + + XCTAssertGreaterThan(result1.audioDuration, 1.0, "First cached run should produce at least 1s") + XCTAssertGreaterThan(result2.audioDuration, 1.0, "Second cached run should produce at least 1s") + } + + /// Cache auto-invalidates when voice changes. + func testPromptCacheInvalidationOnVoiceChange() async throws { + let tts = try await makeCachedTTS(seed: 42) + let options = GenerationOptions(temperature: 0.9, topK: 50, maxNewTokens: 245) + + let result1 = try await tts.generate( + text: "First voice generates meaningful speech output.", speaker: .ryan, language: .english, options: options + ) + XCTAssertEqual(tts.promptCache?.voice, "ryan") + XCTAssertGreaterThan(result1.audio.count, 0, "First voice should produce audio") + + let result2 = try await tts.generate( + text: "Second voice also generates meaningful speech output.", speaker: .aiden, language: .english, options: options + ) + XCTAssertEqual(tts.promptCache?.voice, "aiden", + "Cache should auto-rebuild when voice changes") + XCTAssertGreaterThan(result2.audio.count, 0, "Second voice should produce audio") + } + + /// Prompt cache save/load round-trip produces identical generation output. 
+ func testPromptCacheDiskPersistence() async throws { + let tts = try await makeCachedTTS(seed: 42) + let cache = try await tts.buildPromptCache(speaker: .ryan, language: .english) + + let tmpURL = FileManager.default.temporaryDirectory + .appendingPathComponent("tts_cache_test_\(UUID().uuidString).promptcache") + defer { try? FileManager.default.removeItem(at: tmpURL) } + + try cache.save(to: tmpURL) + XCTAssertTrue(FileManager.default.fileExists(atPath: tmpURL.path), "Cache file should exist on disk") + + let loaded = try TTSPromptCache.load(from: tmpURL) + XCTAssertEqual(loaded.voice, cache.voice) + XCTAssertEqual(loaded.language, cache.language) + XCTAssertEqual(loaded.instruction, cache.instruction) + XCTAssertEqual(loaded.prefixLength, cache.prefixLength) + + // Use the loaded cache for generation — should produce valid audio + tts.promptCache = loaded + let result = try await tts.generate( + text: "Disk cache persistence test with enough text for real audio.", + speaker: .ryan, language: .english, + options: GenerationOptions(temperature: 0.9, topK: 50, maxNewTokens: 245) + ) + XCTAssertGreaterThan(result.audioDuration, 1.0, "Loaded cache should produce at least 1s of audio") + } + + /// Chunked generation should benefit from prompt caching across all chunks. + func testPromptCacheWithChunkedGeneration() async throws { + let tts = try await makeCachedTTS(seed: 42) + try await tts.buildPromptCache(speaker: .ryan, language: .english) + + let multiChunkText = """ + First sentence is fairly long to make a chunk. \ + Second sentence adds more content for another chunk. \ + Third sentence provides even more text for splitting. 
+ """ + + let result = try await tts.generate( + text: multiChunkText, + speaker: .ryan, language: .english, + options: GenerationOptions( + concurrentWorkerCount: 1, + targetChunkSize: 15, + minChunkSize: 5 + ) + ) + + XCTAssertGreaterThan(result.audio.count, 0) + XCTAssertGreaterThan(result.audioDuration, 2.0) + } + + // MARK: - Dual inference path + + /// MLTensor path (default, macOS 15+) produces valid audio. + func testMLTensorPathGeneration() async throws { + guard #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *) else { + throw XCTSkip("MLTensor path requires macOS 15+ / iOS 18+") + } + let tts = try await makeCachedTTS(seed: 42) + var opts = GenerationOptions(maxNewTokens: 100) + opts.forceLegacyEmbedPath = false + + let result = try await tts.generate( + text: "Testing the MLTensor inference path.", + speaker: .ryan, language: .english, + options: opts + ) + + XCTAssertGreaterThan(result.audio.count, 0, "MLTensor path should produce audio") + XCTAssertGreaterThan(result.audioDuration, 0.5, "Should produce at least 0.5s of speech") + } + + /// Legacy [FloatType] path (forced via forceLegacyEmbedPath) produces valid audio. + func testLegacyEmbedPathGeneration() async throws { + let tts = try await makeCachedTTS(seed: 42) + var opts = GenerationOptions(maxNewTokens: 100) + opts.forceLegacyEmbedPath = true + + let result = try await tts.generate( + text: "Testing the legacy embed inference path.", + speaker: .ryan, language: .english, + options: opts + ) + + XCTAssertGreaterThan(result.audio.count, 0, "Legacy path should produce audio") + XCTAssertGreaterThan(result.audioDuration, 0.5, "Should produce at least 0.5s of speech") + } + + /// Both paths with the same seed should produce audio of similar duration. + /// Floating-point differences between paths mean samples may not be bit-identical. 
+ func testBothPathsProduceSimilarAudioDuration() async throws { + guard #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *) else { + throw XCTSkip("Dual-path comparison requires macOS 15+ / iOS 18+") + } + let testText = "Comparing inference paths for audio duration consistency." + let maxNewTokens = 80 + + let ttsMLTensor = try await makeCachedTTS(seed: 42) + var mlTensorOpts = GenerationOptions(maxNewTokens: maxNewTokens) + mlTensorOpts.forceLegacyEmbedPath = false + let mlTensorResult = try await ttsMLTensor.generate( + text: testText, speaker: .ryan, language: .english, options: mlTensorOpts + ) + + let ttsLegacy = try await makeCachedTTS(seed: 42) + var legacyOpts = GenerationOptions(maxNewTokens: maxNewTokens) + legacyOpts.forceLegacyEmbedPath = true + let legacyResult = try await ttsLegacy.generate( + text: testText, speaker: .ryan, language: .english, options: legacyOpts + ) + + XCTAssertGreaterThan(mlTensorResult.audioDuration, 0) + XCTAssertGreaterThan(legacyResult.audioDuration, 0) + // Durations should be within 20% of each other (same number of tokens -> same frames) + let durationRatio = mlTensorResult.audioDuration / legacyResult.audioDuration + XCTAssertGreaterThan(durationRatio, 0.8, "Paths should produce similar audio duration") + XCTAssertLessThan(durationRatio, 1.2, "Paths should produce similar audio duration") + } +} diff --git a/Tests/TTSKitTests/TTSKitUnitTests.swift b/Tests/TTSKitTests/TTSKitUnitTests.swift new file mode 100644 index 00000000..49970e12 --- /dev/null +++ b/Tests/TTSKitTests/TTSKitUnitTests.swift @@ -0,0 +1,942 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright © 2024 Argmax, Inc. All rights reserved. 
+ +import ArgmaxCore +import CoreML +@testable import TTSKit +import XCTest + +final class TTSKitUnitTests: XCTestCase { + + // MARK: - Configuration + + func testTTSKitConfigDefaults() { + let config = TTSKitConfig() + XCTAssertNil(config.modelFolder) + XCTAssertEqual(config.modelRepo, Qwen3TTSConstants.defaultModelRepo) + XCTAssertEqual(config.versionDir, Qwen3TTSConstants.defaultVersionDir) + XCTAssertEqual(config.tokenizerSource, Qwen3TTSConstants.defaultTokenizerRepo) + XCTAssertEqual(config.codeDecoderVariant, Qwen3VariantDefaults.codeDecoder) + XCTAssertEqual(config.multiCodeDecoderVariant, Qwen3VariantDefaults.multiCodeDecoder) + XCTAssertEqual(config.speechDecoderVariant, Qwen3VariantDefaults.speechDecoder) + XCTAssertTrue(config.verbose) + } + + func testTTSComputeOptionsDefaults() { + let opts = ComputeOptions() + XCTAssertEqual(opts.embedderComputeUnits, .cpuOnly) + XCTAssertEqual(opts.codeDecoderComputeUnits, .cpuAndNeuralEngine) + XCTAssertEqual(opts.multiCodeDecoderComputeUnits, .cpuAndNeuralEngine) + XCTAssertEqual(opts.speechDecoderComputeUnits, .cpuAndNeuralEngine) + } + + func testModelPresetResolvesInConfig() { + let small = TTSKitConfig(model: .qwen3TTS_0_6b) + XCTAssertEqual(small.versionDir, TTSModelVariant.qwen3TTS_0_6b.versionDir) + XCTAssertEqual(small.multiCodeDecoderVariant, Qwen3VariantDefaults.multiCodeDecoder) + + let large = TTSKitConfig(model: .qwen3TTS_1_7b) + XCTAssertEqual(large.versionDir, TTSModelVariant.qwen3TTS_1_7b.versionDir) + XCTAssertEqual(large.multiCodeDecoderVariant, Qwen3VariantDefaults.multiCodeDecoder) + } + + func testGenerationOptionsDefaults() { + let opts = GenerationOptions() + XCTAssertEqual(opts.temperature, 0.9) + XCTAssertEqual(opts.topK, 50) + XCTAssertEqual(opts.repetitionPenalty, 1.05) + XCTAssertEqual(opts.maxNewTokens, 245) + XCTAssertNil(opts.chunkingStrategy) + XCTAssertNil(opts.instruction) + XCTAssertNotNil(opts.concurrentWorkerCount) + } + + func testDownloadPatterns() { + let config = 
TTSKitConfig() + let patterns = config.downloadPatterns + XCTAssertEqual(patterns.count, 6) + XCTAssertTrue(patterns.allSatisfy { $0.hasSuffix("/**") }) + XCTAssertTrue(patterns.contains { $0.contains("code_decoder/") }) + XCTAssertTrue(patterns.contains { $0.contains("speech_decoder/") }) + XCTAssertTrue(patterns.contains { $0.contains("text_projector/") }) + } + + // MARK: - Speaker & Language Enums + + func testSpeakerTokenIds() { + // Each speaker should have a unique token ID + let ids = Qwen3Speaker.allCases.map { $0.tokenID } + XCTAssertEqual(Set(ids).count, Qwen3Speaker.allCases.count, "Speaker token IDs should be unique") + } + + func testLanguageTokenIds() { + let ids = Qwen3Language.allCases.map { $0.tokenID } + XCTAssertEqual(Set(ids).count, Qwen3Language.allCases.count, "Language token IDs should be unique") + XCTAssertEqual(Qwen3Language.english.tokenID, 2050) + } + + // MARK: - Text Chunker + + /// Char-level tokenizer helper: 1 unicode scalar = 1 token, perfectly round-trips. + /// Sizes are exact, making test assertions easy to reason about. + private func makeChunker( + targetChunkSize: Int = TextChunker.defaultTargetChunkSize, + minChunkSize: Int = TextChunker.defaultMinChunkSize + ) -> TextChunker { + TextChunker( + targetChunkSize: targetChunkSize, + minChunkSize: minChunkSize, + encode: { $0.unicodeScalars.map { Int($0.value) } }, + decode: { String($0.compactMap { Unicode.Scalar($0) }.map { Character($0) }) } + ) + } + + func testChunkerShortText() { + let chunker = makeChunker(targetChunkSize: 50, minChunkSize: 5) + XCTAssertEqual(chunker.chunk("Hello world."), ["Hello world."]) + } + + func testChunkerEmptyText() { + let chunker = makeChunker() + XCTAssertTrue(chunker.chunk("").isEmpty) + XCTAssertTrue(chunker.chunk(" ").isEmpty) + } + + func testChunkerSentenceSplitting() { + // With char-level tokenizer (1 char = 1 token), targetChunkSize: 30 gives a 30-char window. + // "This is the first sentence." 
= 27 chars, so each window contains one sentence boundary. + let chunker = makeChunker(targetChunkSize: 30, minChunkSize: 5) + let text = "This is the first sentence. This is the second sentence. And here is a third one." + let chunks = chunker.chunk(text) + XCTAssertGreaterThan(chunks.count, 1, "Should split into multiple chunks") + let recombined = chunks.joined(separator: " ") + XCTAssertTrue(recombined.contains("first sentence")) + XCTAssertTrue(recombined.contains("third one")) + } + + func testChunkerMergesTinyTrailing() { + // "A reasonably long sentence that is here." = exactly 40 chars = 40 tokens (char-level). + // targetChunkSize: 40 captures the full sentence in one window at the "." boundary. + // The trailing "X." = 2 tokens is below minChunkSize: 5, so it merges with the previous chunk. + let chunker = makeChunker(targetChunkSize: 40, minChunkSize: 5) + let text = "A reasonably long sentence that is here. X." + let chunks = chunker.chunk(text) + XCTAssertEqual(chunks.count, 1, "Tiny trailing chunk should merge with previous") + } + + // MARK: - Embedding Math + + func testZeroEmbed() { + let embed = EmbedUtilities.zeroEmbed(dim: 16) + XCTAssertEqual(embed.count, 16) + XCTAssertTrue(embed.allSatisfy { $0 == 0 }) + } + + func testAddEmbeddings() { + let a: [FloatType] = [FloatType(1.0), FloatType(2.0), FloatType(3.0)] + let b: [FloatType] = [FloatType(4.0), FloatType(5.0), FloatType(6.0)] + let result = EmbedUtilities.addEmbeddings(a, b) + XCTAssertEqual(result.count, 3) + XCTAssertEqual(Float(result[0]), 5.0, accuracy: 0.01) + XCTAssertEqual(Float(result[1]), 7.0, accuracy: 0.01) + XCTAssertEqual(Float(result[2]), 9.0, accuracy: 0.01) + } + + func testSumEmbeddings() { + let embeds: [[FloatType]] = [ + [FloatType(1.0), FloatType(2.0)], + [FloatType(3.0), FloatType(4.0)], + [FloatType(5.0), FloatType(6.0)], + ] + let result = EmbedUtilities.sumEmbeddings(embeds) + XCTAssertEqual(result.count, 2) + XCTAssertEqual(Float(result[0]), 9.0, accuracy: 0.01) + 
XCTAssertEqual(Float(result[1]), 12.0, accuracy: 0.01) + } + + func testSumEmbeddingsEmpty() { + let result = EmbedUtilities.sumEmbeddings([]) + XCTAssertTrue(result.isEmpty) + } + + func testCreateAndExtractEmbed() throws { + let original: [FloatType] = [FloatType(1.5), FloatType(2.5), FloatType(3.5), FloatType(4.5)] + let arr = try EmbedUtilities.createEmbedMLArray(original) + XCTAssertEqual(arr.shape, [1, 4, 1, 1] as [NSNumber]) + + let extracted = EmbedUtilities.extractEmbed(from: arr) + XCTAssertEqual(extracted.count, 4) + for i in 0..<4 { + XCTAssertEqual(Float(extracted[i]), Float(original[i]), accuracy: 0.01) + } + } + + // MARK: - KV Cache + + func testKVCacheInit() throws { + let cache = try KVCache(cacheDim: 128, maxSeqLength: 32) + XCTAssertEqual(cache.cacheLength, 0) + XCTAssertEqual(cache.maxSeqLength, 32) + XCTAssertEqual(cache.cacheDim, 128) + XCTAssertFalse(cache.isStateful) + XCTAssertFalse(cache.isFull) + XCTAssertEqual(cache.freePositions, 31) // maxSeqLength - 1 + XCTAssertNotNil(cache.keyCache) + XCTAssertNotNil(cache.valueCache) + } + + func testKVCacheStatefulInit() throws { + let cache = try KVCache(cacheDim: 128, maxSeqLength: 32, isStateful: true) + XCTAssertTrue(cache.isStateful) + XCTAssertNil(cache.keyCache) + XCTAssertNil(cache.valueCache) + } + + func testKVCacheUpdateAdvancesPosition() throws { + let cache = try KVCache(cacheDim: 4, maxSeqLength: 8) + XCTAssertEqual(cache.cacheLength, 0) + + cache.update() + XCTAssertEqual(cache.cacheLength, 1) + + cache.update() + XCTAssertEqual(cache.cacheLength, 2) + XCTAssertEqual(cache.freePositions, 5) // 8 - 1 - 2 + } + + func testKVCacheIsFull() throws { + let cache = try KVCache(cacheDim: 4, maxSeqLength: 4) + // maxSeqLength=4, isFull when cacheLength >= 3 (maxSeqLength - 1) + cache.update() + cache.update() + XCTAssertFalse(cache.isFull) + cache.update() + XCTAssertTrue(cache.isFull) + } + + func testKVCacheReset() throws { + let cache = try KVCache(cacheDim: 4, maxSeqLength: 8) + 
cache.update()
+        cache.update()
+        XCTAssertEqual(cache.cacheLength, 2)
+
+        cache.reset()
+        XCTAssertEqual(cache.cacheLength, 0)
+    }
+
+    func testSpeechDecoderCacheInit() throws {
+        let cache = try SpeechDecoderCache(
+            cacheDim: 64, maxSeqLength: 16, hiddenDim: 32, hiddenContextLen: 4
+        )
+        XCTAssertEqual(cache.hiddenDim, 32)
+        XCTAssertEqual(cache.hiddenContextLen, 4)
+        XCTAssertEqual(cache.hiddenContext.shape, [1, 32, 1, 4] as [NSNumber])
+    }
+
+    // MARK: - Sampler
+
+    func testGreedySamplerDeterministic() async throws {
+        // Two samplers with the same seed should produce the same result
+        let sampler1 = GreedyTokenSampler(seed: 42)
+        let sampler2 = GreedyTokenSampler(seed: 42)
+
+        let vocabSize = 32
+        let logits = try MLMultiArray(shape: [1, 1, NSNumber(value: vocabSize)], dataType: .float16)
+        let ptr = logits.dataPointer.bindMemory(to: FloatType.self, capacity: vocabSize)
+        for i in 0..<vocabSize {
+            ptr[i] = FloatType(Float(i) * 0.01)
+        }
+
+        // NOTE(review): the lines below were reconstructed — the original span was
+        // lost to markup-stripping in this diff (everything between the "<" of
+        // "0..<" and the next ">" was eaten). Verify the sampler call signature
+        // and the constants-test name against the original commit. Intent:
+        // identical seeds + identical logits must yield identical sampled tokens.
+        let token1 = try sampler1.sample(from: logits)
+        let token2 = try sampler2.sample(from: logits)
+        XCTAssertEqual(token1, token2, "Identical seeds and logits should sample identical tokens")
+    }
+
+    // MARK: - Constants
+
+    func testCodecEOSTokenRange() {
+        XCTAssertTrue(Qwen3TTSConstants.codecEOS >= 2048 && Qwen3TTSConstants.codecEOS < 3072)
+    }
+
+    // MARK: - Unloaded Model Errors
+
+    func testCodeDecoderWithoutModel() {
+        let decoder = Qwen3CodeDecoder()
+        XCTAssertNil(decoder.model)
+        XCTAssertFalse(decoder.isStateful)
+        XCTAssertNil(decoder.makeState())
+    }
+
+    func testMultiCodeDecoderWithoutModel() {
+        let decoder = Qwen3MultiCodeDecoder()
+        XCTAssertNil(decoder.model)
+        XCTAssertFalse(decoder.isStateful)
+        XCTAssertNil(decoder.makeState())
+    }
+
+    func testSpeechDecoderWithoutModel() {
+        let decoder = Qwen3SpeechDecoder()
+        XCTAssertNil(decoder.model)
+    }
+
+    // MARK: - Chunking Strategy
+
+    func testChunkingStrategyEnum() {
+        XCTAssertEqual(TextChunkingStrategy.allCases.count, 2)
+        XCTAssertEqual(TextChunkingStrategy.none.rawValue, "none")
+        XCTAssertEqual(TextChunkingStrategy.sentence.rawValue, "sentence")
+    }
+
+    // MARK: - SeededRNG Determinism
+
+    func testSeededRNGDeterminism() {
+        var rng1 = SeededRandomNumberGenerator(seed: 99)
+        var rng2 = SeededRandomNumberGenerator(seed: 99)
+        for _ in 0..<100 {
+            XCTAssertEqual(rng1.next(), rng2.next())
+        }
+    }
+
+    func 
testSeededRNGDifferentSeeds() { + var rng1 = SeededRandomNumberGenerator(seed: 1) + var rng2 = SeededRandomNumberGenerator(seed: 2) + // Very unlikely for all 10 values to match with different seeds + var allMatch = true + for _ in 0..<10 { + if rng1.next() != rng2.next() { + allMatch = false + break + } + } + XCTAssertFalse(allMatch, "Different seeds should produce different sequences") + } + + // MARK: - SpeechProgress + + func testSpeechProgressInit() { + let samples: [Float] = [0.1, 0.2, 0.3] + let timings = SpeechTimings() + let progress = SpeechProgress(audio: samples, timings: timings, stepTime: 0.08) + XCTAssertEqual(progress.audio, samples) + XCTAssertEqual(progress.timings.fullPipeline, 0) + XCTAssertEqual(progress.stepTime ?? 0, 0.08, accuracy: 0.0001) + } + + func testSpeechProgressStepTimeNilByDefault() { + let progress = SpeechProgress(audio: [], timings: SpeechTimings()) + XCTAssertNil(progress.stepTime) + } + + func testSpeechProgressFirstStepSemantics() { + // stepTime non-nil signals first step; subsequent steps have nil + let first = SpeechProgress(audio: [0.1], timings: SpeechTimings(), stepTime: 0.05) + let subsequent = SpeechProgress(audio: [0.2], timings: SpeechTimings(), stepTime: nil) + XCTAssertNotNil(first.stepTime) + XCTAssertNil(subsequent.stepTime) + } + + // MARK: - PlaybackStrategy + + func testAudioPerStep() { + let spf = Qwen3TTSConstants.samplesPerFrame + let sr = Qwen3TTSConstants.sampleRate + let expected = Double(spf) / Double(sr) + XCTAssertEqual(PlaybackStrategy.audioPerStep(samplesPerFrame: spf, sampleRate: sr), expected, accuracy: 0.0001) + // ~80ms per frame at 24kHz / 1920 samples + XCTAssertEqual(PlaybackStrategy.audioPerStep(samplesPerFrame: spf, sampleRate: sr), 0.08, accuracy: 0.001) + } + + func testRequiredBufferFastDevice() { + // Step completes in half the frame duration -> device is 2x real-time + // deficit = max(0, 1 - 2.0) = 0, so result = minimumBufferDuration + let spf = Qwen3TTSConstants.samplesPerFrame + 
let sr = Qwen3TTSConstants.sampleRate + let stepTime = PlaybackStrategy.audioPerStep(samplesPerFrame: spf, sampleRate: sr) / 2.0 + let buffer = PlaybackStrategy.requiredBuffer(stepTime: stepTime, maxNewTokens: 100, samplesPerFrame: spf, sampleRate: sr) + XCTAssertEqual(buffer, PlaybackStrategy.minimumBufferDuration, accuracy: 0.001) + } + + func testRequiredBufferSlowDevice() { + // Step takes 2x the frame duration -> device is at 0.5x real-time + // speedRatio = 0.5, deficit = 0.5, maxAudio = 100 * 0.08 = 8s + // deficitBuffer = 8 * 0.5 = 4s > minimumBufferDuration + let spf = Qwen3TTSConstants.samplesPerFrame + let sr = Qwen3TTSConstants.sampleRate + let stepTime = PlaybackStrategy.audioPerStep(samplesPerFrame: spf, sampleRate: sr) * 2.0 + let buffer = PlaybackStrategy.requiredBuffer(stepTime: stepTime, maxNewTokens: 100, samplesPerFrame: spf, sampleRate: sr) + XCTAssertGreaterThan(buffer, PlaybackStrategy.minimumBufferDuration) + // Exact: 100 * 0.08 * 0.5 = 4.0s + XCTAssertEqual(buffer, 4.0, accuracy: 0.01) + } + + func testRequiredBufferAtExactRealTime() { + // Step equals frame duration -> speedRatio = 1, deficit = 0 -> minimum clamp applies + let spf = Qwen3TTSConstants.samplesPerFrame + let sr = Qwen3TTSConstants.sampleRate + let stepTime = PlaybackStrategy.audioPerStep(samplesPerFrame: spf, sampleRate: sr) + let buffer = PlaybackStrategy.requiredBuffer(stepTime: stepTime, maxNewTokens: 50, samplesPerFrame: spf, sampleRate: sr) + XCTAssertEqual(buffer, PlaybackStrategy.minimumBufferDuration, accuracy: 0.001) + } + + func testRequiredBufferMinimumNeverExceeded() { + // Even a very fast device should never return less than the minimum + let buffer = PlaybackStrategy.requiredBuffer( + stepTime: 0.001, maxNewTokens: 200, + samplesPerFrame: Qwen3TTSConstants.samplesPerFrame, sampleRate: Qwen3TTSConstants.sampleRate + ) + XCTAssertGreaterThanOrEqual(buffer, PlaybackStrategy.minimumBufferDuration) + } + + // MARK: - TTSModelVariant + + func 
testModelPresetDisplayNames() { + XCTAssertEqual(TTSModelVariant.qwen3TTS_0_6b.displayName, "Qwen3 TTS 0.6B") + XCTAssertEqual(TTSModelVariant.qwen3TTS_1_7b.displayName, "Qwen3 TTS 1.7B") + } + + func testModelPresetSupportsVoiceDirection() { + XCTAssertFalse(TTSModelVariant.qwen3TTS_0_6b.supportsVoiceDirection) + XCTAssertTrue(TTSModelVariant.qwen3TTS_1_7b.supportsVoiceDirection) + } + + func testModelPresetVersionDirsDiffer() { + XCTAssertNotEqual( + TTSModelVariant.qwen3TTS_0_6b.versionDir, + TTSModelVariant.qwen3TTS_1_7b.versionDir + ) + } + + func testModelPresetAvailabilityOnMacOS() { + // On macOS, all presets should be available + #if os(macOS) + for preset in TTSModelVariant.allCases { + XCTAssertTrue(preset.isAvailableOnCurrentPlatform, "\(preset) should be available on macOS") + } + #else + XCTAssertTrue(TTSModelVariant.qwen3TTS_0_6b.isAvailableOnCurrentPlatform) + XCTAssertFalse(TTSModelVariant.qwen3TTS_1_7b.isAvailableOnCurrentPlatform) + #endif + } + + func testModelPresetVariantDefaultsConsistent() { + // Both presets share the same variant strings (all quantization is size-independent) + XCTAssertEqual(TTSModelVariant.qwen3TTS_0_6b.codeDecoderVariant, Qwen3VariantDefaults.codeDecoder) + XCTAssertEqual(TTSModelVariant.qwen3TTS_1_7b.codeDecoderVariant, Qwen3VariantDefaults.codeDecoder) + XCTAssertEqual(TTSModelVariant.qwen3TTS_0_6b.speechDecoderVariant, Qwen3VariantDefaults.speechDecoder) + } + + // MARK: - TTSKitConfig Component Overrides + + func testTTSKitConfigComponentOverridesNilByDefault() { + let config = TTSKitConfig() + XCTAssertNil(config.textProjector) + XCTAssertNil(config.codeEmbedder) + XCTAssertNil(config.multiCodeEmbedder) + XCTAssertNil(config.codeDecoder) + XCTAssertNil(config.multiCodeDecoder) + XCTAssertNil(config.speechDecoder) + } + + // MARK: - GenerationOptions Additional Defaults + + func testGenerationOptionsChunkingDefaults() { + let opts = GenerationOptions() + // nil defers to TextChunker.defaultTargetChunkSize / 
defaultMinChunkSize at call site + XCTAssertNil(opts.targetChunkSize) + XCTAssertNil(opts.minChunkSize) + // Verify the canonical defaults live in TextChunker + XCTAssertEqual(TextChunker.defaultTargetChunkSize, 42) + XCTAssertEqual(TextChunker.defaultMinChunkSize, 10) + } + + // MARK: - Speaker & Language Round-trips + + func testSpeakerRawValueRoundTrip() { + for speaker in Qwen3Speaker.allCases { + let roundTripped = Qwen3Speaker(rawValue: speaker.rawValue) + XCTAssertEqual(roundTripped, speaker, "rawValue round-trip failed for \(speaker)") + } + } + + func testLanguageRawValueRoundTrip() { + for lang in Qwen3Language.allCases { + let roundTripped = Qwen3Language(rawValue: lang.rawValue) + XCTAssertEqual(roundTripped, lang, "rawValue round-trip failed for \(lang)") + } + } + + func testUnrecognisedSpeakerFallsBack() { + XCTAssertNil(Qwen3Speaker(rawValue: "nonexistent_speaker")) + } + + func testUnrecognisedLanguageFallsBack() { + XCTAssertNil(Qwen3Language(rawValue: "klingon")) + } + + // MARK: - SpeechTimings.merge + + func testMergeTimingsAccumulates() { + var combined = SpeechTimings() + combined.decodingLoop = 1.0 + combined.totalDecodingLoops = 10 + combined.decodingPredictions = 0.5 + + var chunk = SpeechTimings() + chunk.decodingLoop = 2.0 + chunk.totalDecodingLoops = 20 + chunk.decodingPredictions = 0.3 + chunk.multiCodeDecoderPredictions = 0.1 + chunk.speechDecoderPredictions = 0.2 + + combined.merge(chunk) + + XCTAssertEqual(combined.decodingLoop, 3.0, accuracy: 0.001) + XCTAssertEqual(combined.totalDecodingLoops, 30, accuracy: 0.001) + XCTAssertEqual(combined.decodingPredictions, 0.8, accuracy: 0.001) + XCTAssertEqual(combined.multiCodeDecoderPredictions, 0.1, accuracy: 0.001) + XCTAssertEqual(combined.speechDecoderPredictions, 0.2, accuracy: 0.001) + } + + func testMergeTimingsIdentity() async throws { + var combined = SpeechTimings() + combined.decodingLoop = 5.0 + let empty = SpeechTimings() + combined.merge(empty) + 
XCTAssertEqual(combined.decodingLoop, 5.0, accuracy: 0.001) + } + + // MARK: - TextChunker Edge Cases + + func testChunkerPreservesAllText() { + // Each "Sentence number N is here." ≈ 26-27 chars; targetChunkSize: 50 spans ~2 sentences. + let chunker = makeChunker(targetChunkSize: 50, minChunkSize: 5) + let sentences = (1...10).map { "Sentence number \($0) is here." } + let text = sentences.joined(separator: " ") + let chunks = chunker.chunk(text) + let rejoined = chunks.joined(separator: " ") + for sentence in sentences { + XCTAssertTrue(rejoined.contains(sentence.dropLast(1)), // drop period; joins may vary + "Missing content: \(sentence)") + } + } + + func testChunkerWordBoundaryFallback() { + // No punctuation in text - chunker must fall back to word-boundary splits. + // With char-level tokenizer, targetChunkSize: 15 gives 15-char windows. + // Note: multi-word phrase checks are intentionally avoided here because the + // re-encode advance can leave a stray char when the token stream has a leading + // space (e.g. "very long" → 9 tokens, but the stream starts with " very lon"). + // This is a char-level mock artifact; BPE tokenizers fold whitespace into the + // next word token, so no drift occurs in production. 
+ let chunker = makeChunker(targetChunkSize: 15, minChunkSize: 3) + let text = "This is a very long text with no punctuation inside it at all here" + let chunks = chunker.chunk(text) + XCTAssertGreaterThan(chunks.count, 1, "Long text without punctuation should split at word boundaries") + let rejoined = chunks.joined(separator: " ") + for word in ["long", "text", "punctuation", "here"] { + XCTAssertTrue(rejoined.contains(word), "Missing word in output: \(word)") + } + } + + // MARK: - Prompt Cache + + func testPromptCacheMatching() { + let cache = TTSPromptCache( + voice: "ryan", language: "english", instruction: nil, + prefixLength: 9, + kvSnapshot: KVCacheSnapshot( + isStateful: false, cacheDim: 1, maxSeqLength: 1, cacheLength: 0, + keyCacheData: Data(), valueCacheData: Data(), + updateMaskData: Data(), paddingMaskData: Data() + ), + stateData: nil + ) + + XCTAssertTrue(cache.matches(voice: "ryan", language: "english", instruction: nil)) + XCTAssertFalse(cache.matches(voice: "aiden", language: "english", instruction: nil)) + XCTAssertFalse(cache.matches(voice: "ryan", language: "korean", instruction: nil)) + XCTAssertFalse(cache.matches(voice: "ryan", language: "english", instruction: "Speak softly")) + } + + func testPromptCacheMatchingWithInstruction() { + let cache = TTSPromptCache( + voice: "ryan", language: "english", instruction: "Speak softly", + prefixLength: 20, + kvSnapshot: KVCacheSnapshot( + isStateful: false, cacheDim: 1, maxSeqLength: 1, cacheLength: 0, + keyCacheData: Data(), valueCacheData: Data(), + updateMaskData: Data(), paddingMaskData: Data() + ), + stateData: nil + ) + + XCTAssertTrue(cache.matches(voice: "ryan", language: "english", instruction: "Speak softly")) + XCTAssertFalse(cache.matches(voice: "ryan", language: "english", instruction: nil)) + XCTAssertFalse(cache.matches(voice: "ryan", language: "english", instruction: "Speak loudly")) + } + + func testPromptCacheFileName() { + let cache1 = TTSPromptCache( + voice: "ryan", language: 
"english", instruction: nil, + prefixLength: 9, + kvSnapshot: KVCacheSnapshot( + isStateful: false, cacheDim: 1, maxSeqLength: 1, cacheLength: 0, + keyCacheData: Data(), valueCacheData: Data(), + updateMaskData: Data(), paddingMaskData: Data() + ), + stateData: nil + ) + XCTAssertEqual(cache1.cacheFileName, "ryan_english.promptcache") + + let cache2 = TTSPromptCache( + voice: "aiden", language: "korean", instruction: "Speak slowly", + prefixLength: 20, + kvSnapshot: KVCacheSnapshot( + isStateful: false, cacheDim: 1, maxSeqLength: 1, cacheLength: 0, + keyCacheData: Data(), valueCacheData: Data(), + updateMaskData: Data(), paddingMaskData: Data() + ), + stateData: nil + ) + XCTAssertTrue(cache2.cacheFileName.hasPrefix("aiden_korean_")) + XCTAssertTrue(cache2.cacheFileName.hasSuffix(".promptcache")) + } + + func testKVCacheSnapshotRoundTrip() throws { + let dim = 4 + let seq = 8 + let cache = try KVCache(cacheDim: dim, maxSeqLength: seq, isStateful: false) + + // Simulate 3 prefill steps by advancing position and writing dummy data + for step in 0..<3 { + if let keyCache = cache.keyCache { + let ptr = keyCache.dataPointer.bindMemory(to: FloatType.self, capacity: dim * seq) + for d in 0..