diff --git a/Sources/FluidAudio/Diarizer/Core/DiarizerTypes.swift b/Sources/FluidAudio/Diarizer/Core/DiarizerTypes.swift index 2c290e28e..92eab60db 100644 --- a/Sources/FluidAudio/Diarizer/Core/DiarizerTypes.swift +++ b/Sources/FluidAudio/Diarizer/Core/DiarizerTypes.swift @@ -118,22 +118,69 @@ public struct PipelineTimings: Sendable, Codable { } } +/// Per-chunk speaker embedding produced during offline diarization, surfaced for +/// downstream consumers that need finer-grained data than `TimedSpeakerSegment` +/// (e.g. cluster-level contamination correction). One entry per (segmentation +/// chunk, powerset speaker slot) emitted by the embedding extraction step. +/// +/// The `embedding256` field carries the L2-normalized speaker embedding; +/// `rho128` carries the PLDA-whitened representation when a PLDA model is +/// loaded (un-normalized; magnitude carries confidence). `rho128` is empty +/// when no PLDA model is available. +public struct ChunkEmbedding: Sendable, Codable { + /// Cluster identifier matching `DiarizationResult.segments[*].speakerId` + /// (formatted as "S1", "S2", ...). Use this to align chunks back to their + /// assigned cluster. + public let speakerId: String + public let chunkIndex: Int + public let speakerIndex: Int + public let startTimeSeconds: Double + public let endTimeSeconds: Double + public let embedding256: [Float] + public let rho128: [Double] + + public init( + speakerId: String, + chunkIndex: Int, + speakerIndex: Int, + startTimeSeconds: Double, + endTimeSeconds: Double, + embedding256: [Float], + rho128: [Double] = [] + ) { + self.speakerId = speakerId + self.chunkIndex = chunkIndex + self.speakerIndex = speakerIndex + self.startTimeSeconds = startTimeSeconds + self.endTimeSeconds = endTimeSeconds + self.embedding256 = embedding256 + self.rho128 = rho128 + } +} + public struct DiarizationResult: Sendable { public let segments: [TimedSpeakerSegment] /// Speaker database with embeddings (populated by offline pipelines for downstream use) public let speakerDatabase: [String: [Float]]? + /// Per-chunk speaker embeddings with cluster assignments. Populated by + /// offline pipelines when `OfflineDiarizerConfig.exposeChunkEmbeddings` + /// is enabled; nil otherwise. See `ChunkEmbedding` for field semantics. + public let chunkEmbeddings: [ChunkEmbedding]? + /// Performance timings collected during diarization public let timings: PipelineTimings? public init( segments: [TimedSpeakerSegment], speakerDatabase: [String: [Float]]? = nil, + chunkEmbeddings: [ChunkEmbedding]? = nil, timings: PipelineTimings? = nil ) { self.segments = segments self.speakerDatabase = speakerDatabase + self.chunkEmbeddings = chunkEmbeddings self.timings = timings } } diff --git a/Sources/FluidAudio/Diarizer/Offline/Core/OfflineDiarizerManager.swift b/Sources/FluidAudio/Diarizer/Offline/Core/OfflineDiarizerManager.swift index 14ceb14a7..e7f774582 100644 --- a/Sources/FluidAudio/Diarizer/Offline/Core/OfflineDiarizerManager.swift +++ b/Sources/FluidAudio/Diarizer/Offline/Core/OfflineDiarizerManager.swift @@ -315,6 +315,15 @@ public final class OfflineDiarizerManager { ) } + let publicChunkEmbeddings: [ChunkEmbedding]? = + config.exposeChunkEmbeddings + ? Self.buildPublicChunkEmbeddings( + timedEmbeddings: timedEmbeddings, + assignments: assignments, + logger: logger + ) + : nil + let totalProcessing = Date().timeIntervalSince(totalStart) let timings = PipelineTimings( modelCompilationSeconds: models.compilationDuration, @@ -328,10 +337,48 @@ public final class OfflineDiarizerManager { return DiarizationResult( segments: segments, speakerDatabase: speakerDatabase, + chunkEmbeddings: publicChunkEmbeddings, timings: timings ) } + /// Map the internal `[TimedEmbedding] + assignments` pair to the public + /// `[ChunkEmbedding]` representation. Speaker IDs follow the same + /// "S\(cluster + 1)" convention used by `OfflineReconstruction.buildSegments`, + /// so chunk embeddings can be aligned to `DiarizationResult.segments[*].speakerId` + /// by string equality. + /// + /// Returns `[]` if the input arrays disagree on length — this is treated as + /// a logged invariant violation so an unexpected mismatch surfaces in + /// production logs rather than silently breaking the public API contract. + /// + /// `internal` so unit tests in `OfflineModuleTests` can exercise the + /// mapping without needing a full pipeline run. + static func buildPublicChunkEmbeddings( + timedEmbeddings: [TimedEmbedding], + assignments: [Int], + logger: AppLogger + ) -> [ChunkEmbedding] { + guard timedEmbeddings.count == assignments.count else { + logger.warning( + "buildPublicChunkEmbeddings: timedEmbeddings.count (\(timedEmbeddings.count)) " + + "!= assignments.count (\(assignments.count)); chunkEmbeddings will be empty" + ) + return [] + } + return zip(timedEmbeddings, assignments).map { te, cluster in + ChunkEmbedding( + speakerId: "S\(cluster + 1)", + chunkIndex: te.chunkIndex, + speakerIndex: te.speakerIndex, + startTimeSeconds: te.startTime, + endTimeSeconds: te.endTime, + embedding256: te.embedding256, + rho128: te.rho128 + ) + } + } + private func purgeDiarizerRepo(at baseDirectory: URL) throws { let repoDirectory = baseDirectory.appendingPathComponent( Repo.diarizer.folderName, diff --git a/Sources/FluidAudio/Diarizer/Offline/Core/OfflineDiarizerTypes.swift b/Sources/FluidAudio/Diarizer/Offline/Core/OfflineDiarizerTypes.swift index ba3724390..b8f34408b 100644 --- a/Sources/FluidAudio/Diarizer/Offline/Core/OfflineDiarizerTypes.swift +++ b/Sources/FluidAudio/Diarizer/Offline/Core/OfflineDiarizerTypes.swift @@ -216,13 +216,20 @@ public struct OfflineDiarizerConfig: Sendable { public var postProcessing: PostProcessing public var export: Export + /// When true, populate `DiarizationResult.chunkEmbeddings` with per-chunk + /// speaker embeddings + cluster assignments. Off by default to avoid the + /// extra memory footprint (~1–2 MB per hour of audio for the embedding + + /// PLDA payload) for callers that don't need them. + public var exposeChunkEmbeddings: Bool + public init( segmentation: Segmentation = .community, embedding: Embedding = .community, clustering: Clustering = .community, vbx: VBx = .community, postProcessing: PostProcessing = .community, - export: Export = .none + export: Export = .none, + exposeChunkEmbeddings: Bool = false ) { self.segmentation = segmentation self.embedding = embedding @@ -230,6 +237,7 @@ public struct OfflineDiarizerConfig: Sendable { self.vbx = vbx self.postProcessing = postProcessing self.export = export + self.exposeChunkEmbeddings = exposeChunkEmbeddings } public init( diff --git a/Tests/FluidAudioTests/Diarizer/Offline/OfflineModuleTests.swift b/Tests/FluidAudioTests/Diarizer/Offline/OfflineModuleTests.swift index 852024fd5..03fe11eae 100644 --- a/Tests/FluidAudioTests/Diarizer/Offline/OfflineModuleTests.swift +++ b/Tests/FluidAudioTests/Diarizer/Offline/OfflineModuleTests.swift @@ -107,6 +107,189 @@ final class OfflineTypesTests: XCTestCase { } } +/// Coverage for the per-chunk embedding exposure surface added to surface +/// fine-grained data for downstream cluster-purity correction. Exercises the +/// public API contract (default off, opt-in on) plus the internal mapping +/// helper without requiring a full pipeline run. +@available(macOS 13.0, iOS 16.0, *) +final class ChunkEmbeddingExposureTests: XCTestCase { + + func testExposeChunkEmbeddingsDefaultsToFalse() { + let config = OfflineDiarizerConfig() + XCTAssertFalse( + config.exposeChunkEmbeddings, + "Per-chunk embedding exposure must be opt-in to avoid imposing the " + + "memory cost on existing callers." + ) + } + + func testExposeChunkEmbeddingsCanBeEnabledAndPersisted() { + var config = OfflineDiarizerConfig() + XCTAssertFalse(config.exposeChunkEmbeddings) + config.exposeChunkEmbeddings = true + XCTAssertTrue(config.exposeChunkEmbeddings) + } + + func testDiarizationResultChunkEmbeddingsDefaultsToNil() { + let result = DiarizationResult(segments: []) + XCTAssertNil(result.chunkEmbeddings) + XCTAssertNil(result.speakerDatabase) + XCTAssertNil(result.timings) + XCTAssertTrue(result.segments.isEmpty) + } + + func testDiarizationResultPreservesProvidedChunkEmbeddings() { + let chunk = ChunkEmbedding( + speakerId: "S1", + chunkIndex: 0, + speakerIndex: 0, + startTimeSeconds: 0.0, + endTimeSeconds: 1.6, + embedding256: [Float](repeating: 0.1, count: 256) + // rho128 omitted — defaults to [], representing "no PLDA available" + ) + + let result = DiarizationResult( + segments: [], + speakerDatabase: nil, + chunkEmbeddings: [chunk], + timings: nil + ) + + XCTAssertEqual(result.chunkEmbeddings?.count, 1) + XCTAssertEqual(result.chunkEmbeddings?.first?.speakerId, "S1") + XCTAssertEqual(result.chunkEmbeddings?.first?.startTimeSeconds, 0.0) + XCTAssertEqual(result.chunkEmbeddings?.first?.endTimeSeconds, 1.6) + XCTAssertEqual(result.chunkEmbeddings?.first?.embedding256.count, 256) + XCTAssertEqual( + result.chunkEmbeddings?.first?.rho128, [], + "Default rho128 should be empty when no PLDA payload is provided." + ) + } + + func testChunkEmbeddingCodableRoundTrip() throws { + let original = ChunkEmbedding( + speakerId: "S3", + chunkIndex: 7, + speakerIndex: 2, + startTimeSeconds: 5.5, + endTimeSeconds: 7.1, + embedding256: (0..<256).map { Float($0) / 255.0 }, + rho128: (0..<128).map { Double($0) / 127.0 } + ) + + let encoded = try JSONEncoder().encode(original) + let decoded = try JSONDecoder().decode(ChunkEmbedding.self, from: encoded) + + XCTAssertEqual(decoded.speakerId, original.speakerId) + XCTAssertEqual(decoded.chunkIndex, original.chunkIndex) + XCTAssertEqual(decoded.speakerIndex, original.speakerIndex) + XCTAssertEqual(decoded.startTimeSeconds, original.startTimeSeconds) + XCTAssertEqual(decoded.endTimeSeconds, original.endTimeSeconds) + XCTAssertEqual(decoded.embedding256, original.embedding256) + XCTAssertEqual(decoded.rho128, original.rho128) + } + + func testChunkEmbeddingFieldsRoundTrip() { + let embedding = (0..<256).map { Float($0) / 255.0 } + let rho = (0..<128).map { Double($0) / 127.0 } + let chunk = ChunkEmbedding( + speakerId: "S2", + chunkIndex: 42, + speakerIndex: 1, + startTimeSeconds: 12.34, + endTimeSeconds: 13.94, + embedding256: embedding, + rho128: rho + ) + + XCTAssertEqual(chunk.speakerId, "S2") + XCTAssertEqual(chunk.chunkIndex, 42) + XCTAssertEqual(chunk.speakerIndex, 1) + XCTAssertEqual(chunk.startTimeSeconds, 12.34, accuracy: 1e-9) + XCTAssertEqual(chunk.endTimeSeconds, 13.94, accuracy: 1e-9) + XCTAssertEqual(chunk.embedding256, embedding) + XCTAssertEqual(chunk.rho128, rho) + } + + func testBuildPublicChunkEmbeddingsAssignsSpeakerIdsViaClusterPlusOne() { + let logger = AppLogger(category: "ChunkEmbeddingExposureTests") + let timed: [TimedEmbedding] = [ + TimedEmbedding( + chunkIndex: 0, speakerIndex: 0, startFrame: 0, endFrame: 588, + frameWeights: [], startTime: 0.0, endTime: 1.6, + embedding256: [Float](repeating: 0.0, count: 4), + rho128: [] + ), + TimedEmbedding( + chunkIndex: 1, speakerIndex: 0, startFrame: 0, endFrame: 588, + frameWeights: [], startTime: 1.6, endTime: 3.2, + embedding256: [Float](repeating: 0.5, count: 4), + rho128: [] + ), + TimedEmbedding( + chunkIndex: 2, speakerIndex: 1, startFrame: 0, endFrame: 588, + frameWeights: [], startTime: 3.2, endTime: 4.8, + embedding256: [Float](repeating: 1.0, count: 4), + rho128: [] + ), + ] + let assignments = [0, 2, 1] + + let result = OfflineDiarizerManager.buildPublicChunkEmbeddings( + timedEmbeddings: timed, + assignments: assignments, + logger: logger + ) + + XCTAssertEqual(result.count, 3) + XCTAssertEqual(result.map { $0.speakerId }, ["S1", "S3", "S2"]) + XCTAssertEqual(result[0].startTimeSeconds, 0.0, accuracy: 1e-9) + XCTAssertEqual(result[1].startTimeSeconds, 1.6, accuracy: 1e-9) + XCTAssertEqual(result[2].startTimeSeconds, 3.2, accuracy: 1e-9) + // Confirm all chunk-level fields propagate + XCTAssertEqual(result[0].chunkIndex, 0) + XCTAssertEqual(result[1].chunkIndex, 1) + XCTAssertEqual(result[2].chunkIndex, 2) + XCTAssertEqual(result[0].speakerIndex, 0) + XCTAssertEqual(result[2].speakerIndex, 1) + } + + func testBuildPublicChunkEmbeddingsReturnsEmptyOnLengthMismatch() { + let logger = AppLogger(category: "ChunkEmbeddingExposureTests") + let timed: [TimedEmbedding] = [ + TimedEmbedding( + chunkIndex: 0, speakerIndex: 0, startFrame: 0, endFrame: 588, + frameWeights: [], startTime: 0.0, endTime: 1.6, + embedding256: [], rho128: [] + ) + ] + // Mismatched: 1 timed embedding but 2 assignments + let assignments = [0, 0] + + let result = OfflineDiarizerManager.buildPublicChunkEmbeddings( + timedEmbeddings: timed, + assignments: assignments, + logger: logger + ) + + XCTAssertTrue( + result.isEmpty, + "Mismatched lengths should produce an empty result rather than a partial mapping." + ) + } + + func testBuildPublicChunkEmbeddingsHandlesEmptyInput() { + let logger = AppLogger(category: "ChunkEmbeddingExposureTests") + let result = OfflineDiarizerManager.buildPublicChunkEmbeddings( + timedEmbeddings: [], + assignments: [], + logger: logger + ) + XCTAssertTrue(result.isEmpty) + } +} + @available(macOS 13.0, iOS 16.0, *) final class ModelWarmupTests: XCTestCase {