|
| 1 | +use chroma_types::SparseVector; |
| 2 | +use thiserror::Error; |
| 3 | + |
| 4 | +use crate::embed::bm25_tokenizer::Bm25Tokenizer; |
| 5 | +use crate::embed::murmur3_abs_hasher::Murmur3AbsHasher; |
| 6 | +use crate::embed::{EmbeddingFunction, TokenHasher, Tokenizer}; |
| 7 | + |
| 8 | +/// Error type for BM25 sparse embedding. |
| 9 | +/// |
| 10 | +/// This is an empty enum (uninhabited type), meaning it can never be constructed. |
| 11 | +/// BM25 encoding with infallible tokenizers and hashers cannot fail. |
| 12 | +#[derive(Debug, Error)] |
| 13 | +pub enum BM25SparseEmbeddingError {} |
| 14 | + |
| 15 | +/// BM25 sparse embedding function parameterized by tokenizer and hasher. |
| 16 | +/// |
| 17 | +/// The BM25 formula used: |
| 18 | +/// score = tf * (k + 1) / (tf + k * (1 - b + b * doc_len / avg_len)) |
| 19 | +/// |
| 20 | +/// Where: |
| 21 | +/// - tf: term frequency (count of token in document) |
| 22 | +/// - doc_len: document length in tokens (not characters) |
| 23 | +/// - k, b, avg_len: BM25 parameters |
| 24 | +/// |
| 25 | +/// Type parameters: |
| 26 | +/// - T: Tokenizer implementation (e.g., Bm25Tokenizer) |
| 27 | +/// - H: TokenHasher implementation (e.g., Murmur3AbsHasher) |
| 28 | +pub struct BM25SparseEmbeddingFunction<T, H> |
| 29 | +where |
| 30 | + T: Tokenizer, |
| 31 | + H: TokenHasher, |
| 32 | +{ |
| 33 | + /// Tokenizer for converting text into tokens. |
| 34 | + pub tokenizer: T, |
| 35 | + /// Hasher for converting tokens into u32 identifiers. |
| 36 | + pub hasher: H, |
| 37 | + /// BM25 saturation parameter (typically 1.2). |
| 38 | + pub k: f32, |
| 39 | + /// BM25 length normalization parameter (typically 0.75). |
| 40 | + pub b: f32, |
| 41 | + /// Average document length in tokens for normalization. |
| 42 | + pub avg_len: f32, |
| 43 | +} |
| 44 | + |
| 45 | +impl BM25SparseEmbeddingFunction<Bm25Tokenizer, Murmur3AbsHasher> { |
| 46 | + /// Create BM25 with default Bm25Tokenizer and Murmur3AbsHasher. |
| 47 | + /// |
| 48 | + /// This is the standard configuration matching Python's fastembed BM25. |
| 49 | + /// |
| 50 | + /// Default parameters: |
| 51 | + /// - k: 1.2 (BM25 saturation parameter) |
| 52 | + /// - b: 0.75 (length normalization parameter) |
| 53 | + /// - avg_len: 256.0 (average document length in tokens) |
| 54 | + /// - tokenizer: English stemmer with 179 stopwords, 40 char token limit |
| 55 | + /// - hasher: Murmur3 with seed 0, abs() behavior |
| 56 | + pub fn default_murmur3_abs() -> Self { |
| 57 | + Self { |
| 58 | + tokenizer: Bm25Tokenizer::default(), |
| 59 | + hasher: Murmur3AbsHasher::default(), |
| 60 | + k: 1.2, |
| 61 | + b: 0.75, |
| 62 | + avg_len: 256.0, |
| 63 | + } |
| 64 | + } |
| 65 | +} |
| 66 | + |
| 67 | +impl<T, H> BM25SparseEmbeddingFunction<T, H> |
| 68 | +where |
| 69 | + T: Tokenizer, |
| 70 | + H: TokenHasher, |
| 71 | +{ |
| 72 | + /// Encode a single text string into a sparse vector. |
| 73 | + pub fn encode(&self, text: &str) -> Result<SparseVector, BM25SparseEmbeddingError> { |
| 74 | + let tokens = self.tokenizer.tokenize(text); |
| 75 | + |
| 76 | + let doc_len = tokens.len() as f32; |
| 77 | + |
| 78 | + let mut token_ids = Vec::with_capacity(tokens.len()); |
| 79 | + for token in tokens { |
| 80 | + let id = self.hasher.hash(&token); |
| 81 | + token_ids.push(id); |
| 82 | + } |
| 83 | + |
| 84 | + token_ids.sort_unstable(); |
| 85 | + |
| 86 | + let sparse_pairs = token_ids.chunk_by(|a, b| a == b).map(|chunk| { |
| 87 | + let id = chunk[0]; |
| 88 | + let tf = chunk.len() as f32; |
| 89 | + |
| 90 | + // BM25 formula |
| 91 | + let score = tf * (self.k + 1.0) |
| 92 | + / (tf + self.k * (1.0 - self.b + self.b * doc_len / self.avg_len)); |
| 93 | + |
| 94 | + (id, score) |
| 95 | + }); |
| 96 | + |
| 97 | + Ok(SparseVector::from_pairs(sparse_pairs)) |
| 98 | + } |
| 99 | +} |
| 100 | + |
| 101 | +#[async_trait::async_trait] |
| 102 | +impl<T, H> EmbeddingFunction for BM25SparseEmbeddingFunction<T, H> |
| 103 | +where |
| 104 | + T: Tokenizer + Send + Sync + 'static, |
| 105 | + H: TokenHasher + Send + Sync + 'static, |
| 106 | +{ |
| 107 | + type Embedding = SparseVector; |
| 108 | + type Error = BM25SparseEmbeddingError; |
| 109 | + |
| 110 | + async fn embed_strs(&self, batches: &[&str]) -> Result<Vec<Self::Embedding>, Self::Error> { |
| 111 | + batches.iter().map(|text| self.encode(text)).collect() |
| 112 | + } |
| 113 | +} |
| 114 | + |
| 115 | +#[cfg(test)] |
| 116 | +mod tests { |
| 117 | + use super::*; |
| 118 | + |
| 119 | + /// Tests comprehensive tokenization covering: |
| 120 | + /// - Possessive forms (Bolt's) |
| 121 | + /// - Special characters (~, parentheses) |
| 122 | + /// - Numbers (27.8, 44.72) |
| 123 | + /// - Mixed case and abbreviations (mph, km/h) |
| 124 | + /// - Hyphens in compound words |
| 125 | + /// - Maximum token variety (12 unique tokens after processing) |
| 126 | + #[test] |
| 127 | + fn test_bm25_comprehensive_tokenization() { |
| 128 | + let bm25 = BM25SparseEmbeddingFunction::default_murmur3_abs(); |
| 129 | + let text = "Usain Bolt's top speed reached ~27.8 mph (44.72 km/h)"; |
| 130 | + |
| 131 | + let result = bm25.encode(text).unwrap(); |
| 132 | + |
| 133 | + let expected_indices = vec![ |
| 134 | + 230246813, 395514983, 458027949, 488165615, 729632045, 734978415, 997512866, |
| 135 | + 1114505193, 1381820790, 1501587190, 1649421877, 1837285388, |
| 136 | + ]; |
| 137 | + let expected_value = 1.6391153; |
| 138 | + |
| 139 | + assert_eq!(result.indices.len(), 12); |
| 140 | + assert_eq!(result.indices, expected_indices); |
| 141 | + |
| 142 | + for &value in &result.values { |
| 143 | + assert!((value - expected_value).abs() < 1e-5); |
| 144 | + } |
| 145 | + } |
| 146 | + |
| 147 | + /// Tests tokenizer's handling of: |
| 148 | + /// - Stopword filtering ("The" is filtered out) |
| 149 | + /// - Multiple consecutive spaces |
| 150 | + /// - Hyphens in compound words (space-time) |
| 151 | + /// - Full uppercase words (WARPS) |
| 152 | + /// - Trailing punctuation (...) |
| 153 | + /// - Stemming (objects -> object) |
| 154 | + #[test] |
| 155 | + fn test_bm25_stopwords_and_punctuation() { |
| 156 | + let bm25 = BM25SparseEmbeddingFunction::default_murmur3_abs(); |
| 157 | + let text = "The space-time continuum WARPS near massive objects..."; |
| 158 | + |
| 159 | + let result = bm25.encode(text).unwrap(); |
| 160 | + |
| 161 | + let expected_indices = vec![ |
| 162 | + 90097469, 519064992, 737893654, 1110755108, 1950894484, 2031641008, 2058513491, |
| 163 | + ]; |
| 164 | + let expected_value = 1.660867; |
| 165 | + |
| 166 | + assert_eq!(result.indices.len(), 7); |
| 167 | + assert_eq!(result.indices, expected_indices); |
| 168 | + |
| 169 | + for &value in &result.values { |
| 170 | + assert!((value - expected_value).abs() < 1e-5); |
| 171 | + } |
| 172 | + } |
| 173 | +} |
0 commit comments