Skip to content

Commit 15895fd

Browse files
authored
[ENH] BM25 support for Rust client (#5688)
## Description of changes _Summarize the changes made by this PR._ - Improvements & Bug fixes - Make `EmbeddingFunction` trait a bit more generic - New functionality - Implement `BM25SparseEmbeddingFunction` and related stuff ## Test plan _How are these changes tested?_ - [ ] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust ## Migration plan _Are there any migrations, or any forwards/backwards compatibility changes needed in order to make sure this change deploys reliably?_ ## Observability plan _What is the plan to instrument and monitor this change?_ ## Documentation Changes _Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the_ [_docs section](https://github.com/chroma-core/chroma/tree/main/docs/docs.trychroma.com)?_
1 parent df5c044 commit 15895fd

File tree

9 files changed

+541
-23
lines changed

9 files changed

+541
-23
lines changed

Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ futures-core = "0.3"
1818
http-body-util = "0.1.3"
1919
lazy_static = { version = "1.4" }
2020
lexical-core = "1.0"
21+
murmur3 = "0.5.2"
2122
num_cpus = "1.16.0"
2223
once_cell = "1.21.3"
2324
opentelemetry = { version = "0.27.0", default-features = false, features = ["trace", "metrics"] }

rust/chroma/Cargo.toml

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,23 @@ description = "Client for Chroma, the AI-native cloud database."
66
license = "Apache-2.0"
77

88
[dependencies]
9-
async-trait = { workspace = true }
9+
async-trait.workspace = true
1010
backon = "1.5"
1111
bon = "3.8.1"
12-
reqwest = { version = ">=0.12.0,<0.13.0", features = ["json", "charset", "system-proxy", "http2"], default-features = false }
12+
murmur3.workspace = true
1313
opentelemetry = { version = ">=0.27,<0.32", optional = true }
14+
parking_lot.workspace = true
15+
reqwest = { version = ">=0.12.0,<0.13.0", features = ["json", "charset", "system-proxy", "http2"], default-features = false }
16+
rust-stemmers = "1.2"
1417
serde.workspace = true
1518
serde_json.workspace = true
1619
tokio = { version = ">=1,<2", features = ["sync"] }
1720
thiserror.workspace = true
1821
tracing.workspace = true
19-
parking_lot.workspace = true
2022

21-
chroma-api-types = { workspace = true }
22-
chroma-error = { workspace = true }
23-
chroma-types = { workspace = true }
23+
chroma-api-types.workspace = true
24+
chroma-error.workspace = true
25+
chroma-types.workspace = true
2426

2527
[features]
2628
default = ["rustls", "ollama"]

rust/chroma/src/embed/bm25.rs

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
use chroma_types::SparseVector;
2+
use thiserror::Error;
3+
4+
use crate::embed::bm25_tokenizer::Bm25Tokenizer;
5+
use crate::embed::murmur3_abs_hasher::Murmur3AbsHasher;
6+
use crate::embed::{EmbeddingFunction, TokenHasher, Tokenizer};
7+
8+
/// Error type for BM25 sparse embedding.
9+
///
10+
/// This is an empty enum (uninhabited type), meaning it can never be constructed.
11+
/// BM25 encoding with infallible tokenizers and hashers cannot fail.
12+
#[derive(Debug, Error)]
13+
pub enum BM25SparseEmbeddingError {}
14+
15+
/// BM25 sparse embedding function parameterized by tokenizer and hasher.
16+
///
17+
/// The BM25 formula used:
18+
/// score = tf * (k + 1) / (tf + k * (1 - b + b * doc_len / avg_len))
19+
///
20+
/// Where:
21+
/// - tf: term frequency (count of token in document)
22+
/// - doc_len: document length in tokens (not characters)
23+
/// - k, b, avg_len: BM25 parameters
24+
///
25+
/// Type parameters:
26+
/// - T: Tokenizer implementation (e.g., Bm25Tokenizer)
27+
/// - H: TokenHasher implementation (e.g., Murmur3AbsHasher)
28+
pub struct BM25SparseEmbeddingFunction<T, H>
29+
where
30+
T: Tokenizer,
31+
H: TokenHasher,
32+
{
33+
/// Tokenizer for converting text into tokens.
34+
pub tokenizer: T,
35+
/// Hasher for converting tokens into u32 identifiers.
36+
pub hasher: H,
37+
/// BM25 saturation parameter (typically 1.2).
38+
pub k: f32,
39+
/// BM25 length normalization parameter (typically 0.75).
40+
pub b: f32,
41+
/// Average document length in tokens for normalization.
42+
pub avg_len: f32,
43+
}
44+
45+
impl BM25SparseEmbeddingFunction<Bm25Tokenizer, Murmur3AbsHasher> {
46+
/// Create BM25 with default Bm25Tokenizer and Murmur3AbsHasher.
47+
///
48+
/// This is the standard configuration matching Python's fastembed BM25.
49+
///
50+
/// Default parameters:
51+
/// - k: 1.2 (BM25 saturation parameter)
52+
/// - b: 0.75 (length normalization parameter)
53+
/// - avg_len: 256.0 (average document length in tokens)
54+
/// - tokenizer: English stemmer with 179 stopwords, 40 char token limit
55+
/// - hasher: Murmur3 with seed 0, abs() behavior
56+
pub fn default_murmur3_abs() -> Self {
57+
Self {
58+
tokenizer: Bm25Tokenizer::default(),
59+
hasher: Murmur3AbsHasher::default(),
60+
k: 1.2,
61+
b: 0.75,
62+
avg_len: 256.0,
63+
}
64+
}
65+
}
66+
67+
impl<T, H> BM25SparseEmbeddingFunction<T, H>
68+
where
69+
T: Tokenizer,
70+
H: TokenHasher,
71+
{
72+
/// Encode a single text string into a sparse vector.
73+
pub fn encode(&self, text: &str) -> Result<SparseVector, BM25SparseEmbeddingError> {
74+
let tokens = self.tokenizer.tokenize(text);
75+
76+
let doc_len = tokens.len() as f32;
77+
78+
let mut token_ids = Vec::with_capacity(tokens.len());
79+
for token in tokens {
80+
let id = self.hasher.hash(&token);
81+
token_ids.push(id);
82+
}
83+
84+
token_ids.sort_unstable();
85+
86+
let sparse_pairs = token_ids.chunk_by(|a, b| a == b).map(|chunk| {
87+
let id = chunk[0];
88+
let tf = chunk.len() as f32;
89+
90+
// BM25 formula
91+
let score = tf * (self.k + 1.0)
92+
/ (tf + self.k * (1.0 - self.b + self.b * doc_len / self.avg_len));
93+
94+
(id, score)
95+
});
96+
97+
Ok(SparseVector::from_pairs(sparse_pairs))
98+
}
99+
}
100+
101+
#[async_trait::async_trait]
102+
impl<T, H> EmbeddingFunction for BM25SparseEmbeddingFunction<T, H>
103+
where
104+
T: Tokenizer + Send + Sync + 'static,
105+
H: TokenHasher + Send + Sync + 'static,
106+
{
107+
type Embedding = SparseVector;
108+
type Error = BM25SparseEmbeddingError;
109+
110+
async fn embed_strs(&self, batches: &[&str]) -> Result<Vec<Self::Embedding>, Self::Error> {
111+
batches.iter().map(|text| self.encode(text)).collect()
112+
}
113+
}
114+
115+
#[cfg(test)]
116+
mod tests {
117+
use super::*;
118+
119+
/// Tests comprehensive tokenization covering:
120+
/// - Possessive forms (Bolt's)
121+
/// - Special characters (~, parentheses)
122+
/// - Numbers (27.8, 44.72)
123+
/// - Mixed case and abbreviations (mph, km/h)
124+
/// - Hyphens in compound words
125+
/// - Maximum token variety (12 unique tokens after processing)
126+
#[test]
127+
fn test_bm25_comprehensive_tokenization() {
128+
let bm25 = BM25SparseEmbeddingFunction::default_murmur3_abs();
129+
let text = "Usain Bolt's top speed reached ~27.8 mph (44.72 km/h)";
130+
131+
let result = bm25.encode(text).unwrap();
132+
133+
let expected_indices = vec![
134+
230246813, 395514983, 458027949, 488165615, 729632045, 734978415, 997512866,
135+
1114505193, 1381820790, 1501587190, 1649421877, 1837285388,
136+
];
137+
let expected_value = 1.6391153;
138+
139+
assert_eq!(result.indices.len(), 12);
140+
assert_eq!(result.indices, expected_indices);
141+
142+
for &value in &result.values {
143+
assert!((value - expected_value).abs() < 1e-5);
144+
}
145+
}
146+
147+
/// Tests tokenizer's handling of:
148+
/// - Stopword filtering ("The" is filtered out)
149+
/// - Multiple consecutive spaces
150+
/// - Hyphens in compound words (space-time)
151+
/// - Full uppercase words (WARPS)
152+
/// - Trailing punctuation (...)
153+
/// - Stemming (objects -> object)
154+
#[test]
155+
fn test_bm25_stopwords_and_punctuation() {
156+
let bm25 = BM25SparseEmbeddingFunction::default_murmur3_abs();
157+
let text = "The space-time continuum WARPS near massive objects...";
158+
159+
let result = bm25.encode(text).unwrap();
160+
161+
let expected_indices = vec![
162+
90097469, 519064992, 737893654, 1110755108, 1950894484, 2031641008, 2058513491,
163+
];
164+
let expected_value = 1.660867;
165+
166+
assert_eq!(result.indices.len(), 7);
167+
assert_eq!(result.indices, expected_indices);
168+
169+
for &value in &result.values {
170+
assert!((value - expected_value).abs() < 1e-5);
171+
}
172+
}
173+
}

0 commit comments

Comments
 (0)