Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions Sources/FluidAudio/Diarizer/Core/DiarizerTypes.swift
Original file line number Diff line number Diff line change
Expand Up @@ -118,22 +118,69 @@ public struct PipelineTimings: Sendable, Codable {
}
}

/// Per-chunk speaker embedding produced during offline diarization, surfaced for
/// downstream consumers that need finer-grained data than `TimedSpeakerSegment`
/// (e.g. cluster-level contamination correction). One entry per (segmentation
/// chunk, powerset speaker slot) emitted by the embedding extraction step.
///
/// The `embedding256` field carries the L2-normalized speaker embedding;
/// `rho128` carries the PLDA-whitened representation when a PLDA model is
/// loaded (un-normalized; magnitude carries confidence). `rho128` is empty
/// when no PLDA model is available.
public struct ChunkEmbedding: Sendable, Codable {

    // MARK: Cluster assignment

    /// Cluster identifier matching `DiarizationResult.segments[*].speakerId`
    /// (formatted as "S1", "S2", ...). Compare by string equality to map a
    /// chunk back to the cluster it was assigned to.
    public let speakerId: String

    // MARK: Position

    /// Index of the segmentation chunk this entry was produced from.
    public let chunkIndex: Int

    /// Powerset speaker slot within that chunk.
    public let speakerIndex: Int

    /// Start time of the entry, in seconds.
    public let startTimeSeconds: Double

    /// End time of the entry, in seconds.
    public let endTimeSeconds: Double

    // MARK: Payload

    /// L2-normalized speaker embedding.
    public let embedding256: [Float]

    /// PLDA-whitened representation (un-normalized). Empty when no PLDA
    /// model was available.
    public let rho128: [Double]

    /// Memberwise initializer.
    ///
    /// `rho128` defaults to an empty array, which represents "no PLDA
    /// payload"; all other fields are required and stored verbatim.
    public init(
        speakerId: String,
        chunkIndex: Int,
        speakerIndex: Int,
        startTimeSeconds: Double,
        endTimeSeconds: Double,
        embedding256: [Float],
        rho128: [Double] = []
    ) {
        // Identity and position.
        self.speakerId = speakerId
        self.chunkIndex = chunkIndex
        self.speakerIndex = speakerIndex

        // Time span.
        self.startTimeSeconds = startTimeSeconds
        self.endTimeSeconds = endTimeSeconds

        // Vectors.
        self.embedding256 = embedding256
        self.rho128 = rho128
    }
}

/// Output of a diarization run: the speaker-attributed segments plus optional
/// by-products (speaker database, per-chunk embeddings, performance timings).
public struct DiarizationResult: Sendable {

    /// Speaker-attributed segments.
    public let segments: [TimedSpeakerSegment]

    /// Speaker database with embeddings, populated by offline pipelines for
    /// downstream use. `nil` when not provided.
    // NOTE(review): keys are presumably speaker IDs — confirm against the producer.
    public let speakerDatabase: [String: [Float]]?

    /// Per-chunk speaker embeddings with cluster assignments. Populated by
    /// offline pipelines when `OfflineDiarizerConfig.exposeChunkEmbeddings`
    /// is enabled; `nil` otherwise. See `ChunkEmbedding` for field semantics.
    public let chunkEmbeddings: [ChunkEmbedding]?

    /// Performance timings collected during diarization. `nil` when not provided.
    public let timings: PipelineTimings?

    /// Creates a result.
    ///
    /// Only `segments` is required; every optional by-product defaults to `nil`.
    public init(
        segments: [TimedSpeakerSegment],
        speakerDatabase: [String: [Float]]? = nil,
        chunkEmbeddings: [ChunkEmbedding]? = nil,
        timings: PipelineTimings? = nil
    ) {
        self.segments = segments
        self.speakerDatabase = speakerDatabase
        self.chunkEmbeddings = chunkEmbeddings
        self.timings = timings
    }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,15 @@ public final class OfflineDiarizerManager {
)
}

let publicChunkEmbeddings: [ChunkEmbedding]? =
config.exposeChunkEmbeddings
? Self.buildPublicChunkEmbeddings(
timedEmbeddings: timedEmbeddings,
assignments: assignments,
logger: logger
)
: nil

let totalProcessing = Date().timeIntervalSince(totalStart)
let timings = PipelineTimings(
modelCompilationSeconds: models.compilationDuration,
Expand All @@ -328,10 +337,48 @@ public final class OfflineDiarizerManager {
return DiarizationResult(
segments: segments,
speakerDatabase: speakerDatabase,
chunkEmbeddings: publicChunkEmbeddings,
timings: timings
)
}

/// Map the internal `[TimedEmbedding] + assignments` pair to the public
/// `[ChunkEmbedding]` representation. Speaker IDs follow the same
/// "S\(cluster + 1)" convention used by `OfflineReconstruction.buildSegments`,
/// so chunk embeddings can be aligned to `DiarizationResult.segments[*].speakerId`
/// by string equality.
///
/// Returns `[]` if the input arrays disagree on length — this is treated as
/// a logged invariant violation so an unexpected mismatch surfaces in
/// production logs rather than silently breaking the public API contract.
///
/// `internal` so unit tests in `OfflineModuleTests` can exercise the
/// mapping without needing a full pipeline run.
static func buildPublicChunkEmbeddings(
    timedEmbeddings: [TimedEmbedding],
    assignments: [Int],
    logger: AppLogger
) -> [ChunkEmbedding] {
    // The two arrays are parallel: assignments[i] is the cluster of
    // timedEmbeddings[i]. A length mismatch means the pairing is unknown, so
    // log it and hand back nothing rather than a partial mapping.
    if timedEmbeddings.count != assignments.count {
        logger.warning(
            "buildPublicChunkEmbeddings: timedEmbeddings.count (\(timedEmbeddings.count)) "
                + "!= assignments.count (\(assignments.count)); chunkEmbeddings will be empty"
        )
        return []
    }

    var publicChunks: [ChunkEmbedding] = []
    publicChunks.reserveCapacity(timedEmbeddings.count)

    for (timed, clusterIndex) in zip(timedEmbeddings, assignments) {
        // Cluster indices are zero-based; public speaker IDs are one-based ("S1", "S2", ...).
        let publicSpeakerId = "S\(clusterIndex + 1)"
        let publicChunk = ChunkEmbedding(
            speakerId: publicSpeakerId,
            chunkIndex: timed.chunkIndex,
            speakerIndex: timed.speakerIndex,
            startTimeSeconds: timed.startTime,
            endTimeSeconds: timed.endTime,
            embedding256: timed.embedding256,
            rho128: timed.rho128
        )
        publicChunks.append(publicChunk)
    }

    return publicChunks
}

private func purgeDiarizerRepo(at baseDirectory: URL) throws {
let repoDirectory = baseDirectory.appendingPathComponent(
Repo.diarizer.folderName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,20 +216,28 @@ public struct OfflineDiarizerConfig: Sendable {
public var postProcessing: PostProcessing
public var export: Export

/// When true, populate `DiarizationResult.chunkEmbeddings` with per-chunk
/// speaker embeddings + cluster assignments. Off by default to avoid the
/// extra memory footprint (~1–2 MB per hour of audio for the embedding +
/// PLDA payload) for callers that don't need them.
public var exposeChunkEmbeddings: Bool

/// Creates an offline diarization configuration.
///
/// Every argument has a default, so `OfflineDiarizerConfig()` is valid and
/// leaves chunk-embedding exposure disabled.
///
/// - Parameters:
///   - segmentation: Segmentation settings. Defaults to `.community`.
///   - embedding: Embedding settings. Defaults to `.community`.
///   - clustering: Clustering settings. Defaults to `.community`.
///   - vbx: VBx settings. Defaults to `.community`.
///   - postProcessing: Post-processing settings. Defaults to `.community`.
///   - export: Export settings. Defaults to `.none`.
///   - exposeChunkEmbeddings: When true, `DiarizationResult.chunkEmbeddings`
///     is populated with per-chunk embeddings and cluster assignments.
///     Defaults to `false` so existing callers keep their current behavior.
public init(
    segmentation: Segmentation = .community,
    embedding: Embedding = .community,
    clustering: Clustering = .community,
    vbx: VBx = .community,
    postProcessing: PostProcessing = .community,
    // Declared exactly once: a second `export` parameter (and the missing
    // comma between the two) would not compile.
    export: Export = .none,
    exposeChunkEmbeddings: Bool = false
) {
    self.segmentation = segmentation
    self.embedding = embedding
    self.clustering = clustering
    self.vbx = vbx
    self.postProcessing = postProcessing
    self.export = export
    self.exposeChunkEmbeddings = exposeChunkEmbeddings
}

public init(
Expand Down
183 changes: 183 additions & 0 deletions Tests/FluidAudioTests/Diarizer/Offline/OfflineModuleTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,189 @@ final class OfflineTypesTests: XCTestCase {
}
}

/// Coverage for the per-chunk embedding exposure API, which provides
/// fine-grained data for downstream cluster-purity correction. Exercises the
/// public API contract (default off, opt-in on) plus the internal mapping
/// helper without requiring a full pipeline run.
@available(macOS 13.0, iOS 16.0, *)
final class ChunkEmbeddingExposureTests: XCTestCase {

    // MARK: - Configuration flag

    /// The flag must be opt-in so existing callers do not pay the memory cost.
    func testExposeChunkEmbeddingsDefaultsToFalse() {
        let config = OfflineDiarizerConfig()
        XCTAssertFalse(
            config.exposeChunkEmbeddings,
            "Per-chunk embedding exposure must be opt-in to avoid imposing the "
                + "memory cost on existing callers."
        )
    }

    /// The flag is a mutable property: setting it must stick.
    func testExposeChunkEmbeddingsCanBeEnabledAndPersisted() {
        var config = OfflineDiarizerConfig()
        XCTAssertFalse(config.exposeChunkEmbeddings)
        config.exposeChunkEmbeddings = true
        XCTAssertTrue(config.exposeChunkEmbeddings)
    }

    // MARK: - DiarizationResult

    /// A result built from segments alone leaves every optional by-product nil.
    func testDiarizationResultChunkEmbeddingsDefaultsToNil() {
        let result = DiarizationResult(segments: [])
        XCTAssertNil(result.chunkEmbeddings)
        XCTAssertNil(result.speakerDatabase)
        XCTAssertNil(result.timings)
        XCTAssertTrue(result.segments.isEmpty)
    }

    /// Chunk embeddings handed to the initializer come back unchanged.
    func testDiarizationResultPreservesProvidedChunkEmbeddings() throws {
        let chunk = ChunkEmbedding(
            speakerId: "S1",
            chunkIndex: 0,
            speakerIndex: 0,
            startTimeSeconds: 0.0,
            endTimeSeconds: 1.6,
            embedding256: [Float](repeating: 0.1, count: 256)
            // rho128 omitted — defaults to [], representing "no PLDA available"
        )

        let result = DiarizationResult(
            segments: [],
            speakerDatabase: nil,
            chunkEmbeddings: [chunk],
            timings: nil
        )

        // Unwrap once so a missing payload fails here with a clear message
        // instead of every assertion below comparing against nil.
        let stored = try XCTUnwrap(result.chunkEmbeddings)
        XCTAssertEqual(stored.count, 1)

        let first = try XCTUnwrap(stored.first)
        XCTAssertEqual(first.speakerId, "S1")
        XCTAssertEqual(first.startTimeSeconds, 0.0)
        XCTAssertEqual(first.endTimeSeconds, 1.6)
        XCTAssertEqual(first.embedding256.count, 256)
        XCTAssertEqual(
            first.rho128, [],
            "Default rho128 should be empty when no PLDA payload is provided."
        )
    }

    // MARK: - ChunkEmbedding

    /// Encoding to JSON and decoding back yields the same field values.
    func testChunkEmbeddingCodableRoundTrip() throws {
        let original = ChunkEmbedding(
            speakerId: "S3",
            chunkIndex: 7,
            speakerIndex: 2,
            startTimeSeconds: 5.5,
            endTimeSeconds: 7.1,
            embedding256: (0..<256).map { Float($0) / 255.0 },
            rho128: (0..<128).map { Double($0) / 127.0 }
        )

        let encoded = try JSONEncoder().encode(original)
        let decoded = try JSONDecoder().decode(ChunkEmbedding.self, from: encoded)

        XCTAssertEqual(decoded.speakerId, original.speakerId)
        XCTAssertEqual(decoded.chunkIndex, original.chunkIndex)
        XCTAssertEqual(decoded.speakerIndex, original.speakerIndex)
        XCTAssertEqual(decoded.startTimeSeconds, original.startTimeSeconds)
        XCTAssertEqual(decoded.endTimeSeconds, original.endTimeSeconds)
        XCTAssertEqual(decoded.embedding256, original.embedding256)
        XCTAssertEqual(decoded.rho128, original.rho128)
    }

    /// The initializer stores every argument in the matching property.
    func testChunkEmbeddingFieldsRoundTrip() {
        let embedding = (0..<256).map { Float($0) / 255.0 }
        let rho = (0..<128).map { Double($0) / 127.0 }
        let chunk = ChunkEmbedding(
            speakerId: "S2",
            chunkIndex: 42,
            speakerIndex: 1,
            startTimeSeconds: 12.34,
            endTimeSeconds: 13.94,
            embedding256: embedding,
            rho128: rho
        )

        XCTAssertEqual(chunk.speakerId, "S2")
        XCTAssertEqual(chunk.chunkIndex, 42)
        XCTAssertEqual(chunk.speakerIndex, 1)
        XCTAssertEqual(chunk.startTimeSeconds, 12.34, accuracy: 1e-9)
        XCTAssertEqual(chunk.endTimeSeconds, 13.94, accuracy: 1e-9)
        XCTAssertEqual(chunk.embedding256, embedding)
        XCTAssertEqual(chunk.rho128, rho)
    }

    // MARK: - buildPublicChunkEmbeddings

    /// Cluster index `k` maps to speaker ID "S(k + 1)", and each chunk-level
    /// field of the source embedding is carried into the public value.
    func testBuildPublicChunkEmbeddingsAssignsSpeakerIdsViaClusterPlusOne() {
        let logger = AppLogger(category: "ChunkEmbeddingExposureTests")
        let timed: [TimedEmbedding] = [
            TimedEmbedding(
                chunkIndex: 0, speakerIndex: 0, startFrame: 0, endFrame: 588,
                frameWeights: [], startTime: 0.0, endTime: 1.6,
                embedding256: [Float](repeating: 0.0, count: 4),
                rho128: []
            ),
            TimedEmbedding(
                chunkIndex: 1, speakerIndex: 0, startFrame: 0, endFrame: 588,
                frameWeights: [], startTime: 1.6, endTime: 3.2,
                embedding256: [Float](repeating: 0.5, count: 4),
                rho128: []
            ),
            TimedEmbedding(
                chunkIndex: 2, speakerIndex: 1, startFrame: 0, endFrame: 588,
                frameWeights: [], startTime: 3.2, endTime: 4.8,
                embedding256: [Float](repeating: 1.0, count: 4),
                rho128: []
            ),
        ]
        let assignments = [0, 2, 1]

        let result = OfflineDiarizerManager.buildPublicChunkEmbeddings(
            timedEmbeddings: timed,
            assignments: assignments,
            logger: logger
        )

        // XCTAssertEqual does not stop the test, so subscripting after a failed
        // count check would trap and take down the whole test process. Bail out
        // with a regular failure instead.
        guard result.count == timed.count else {
            XCTFail("Expected \(timed.count) chunk embeddings, got \(result.count)")
            return
        }

        XCTAssertEqual(result.map(\.speakerId), ["S1", "S3", "S2"])
        XCTAssertEqual(result[0].startTimeSeconds, 0.0, accuracy: 1e-9)
        XCTAssertEqual(result[1].startTimeSeconds, 1.6, accuracy: 1e-9)
        XCTAssertEqual(result[2].startTimeSeconds, 3.2, accuracy: 1e-9)
        XCTAssertEqual(result.map(\.chunkIndex), [0, 1, 2])
        XCTAssertEqual(result.map(\.speakerIndex), [0, 0, 1])

        // Confirm all chunk-level fields propagate, including the ones the
        // literal checks above do not cover (end time and both vectors).
        for (mapped, source) in zip(result, timed) {
            XCTAssertEqual(mapped.chunkIndex, source.chunkIndex)
            XCTAssertEqual(mapped.speakerIndex, source.speakerIndex)
            XCTAssertEqual(mapped.startTimeSeconds, source.startTime, accuracy: 1e-9)
            XCTAssertEqual(mapped.endTimeSeconds, source.endTime, accuracy: 1e-9)
            XCTAssertEqual(mapped.embedding256, source.embedding256)
            XCTAssertEqual(mapped.rho128, source.rho128)
        }
    }

    /// Arrays of different length cannot be paired, so nothing is returned.
    func testBuildPublicChunkEmbeddingsReturnsEmptyOnLengthMismatch() {
        let logger = AppLogger(category: "ChunkEmbeddingExposureTests")
        let timed: [TimedEmbedding] = [
            TimedEmbedding(
                chunkIndex: 0, speakerIndex: 0, startFrame: 0, endFrame: 588,
                frameWeights: [], startTime: 0.0, endTime: 1.6,
                embedding256: [], rho128: []
            )
        ]
        // Mismatched: 1 timed embedding but 2 assignments
        let assignments = [0, 0]

        let result = OfflineDiarizerManager.buildPublicChunkEmbeddings(
            timedEmbeddings: timed,
            assignments: assignments,
            logger: logger
        )

        XCTAssertTrue(
            result.isEmpty,
            "Mismatched lengths should produce an empty result rather than a partial mapping."
        )
    }

    /// Empty inputs agree on length and map to an empty result.
    func testBuildPublicChunkEmbeddingsHandlesEmptyInput() {
        let logger = AppLogger(category: "ChunkEmbeddingExposureTests")
        let result = OfflineDiarizerManager.buildPublicChunkEmbeddings(
            timedEmbeddings: [],
            assignments: [],
            logger: logger
        )
        XCTAssertTrue(result.isEmpty)
    }
}

@available(macOS 13.0, iOS 16.0, *)
final class ModelWarmupTests: XCTestCase {

Expand Down
Loading