Skip to content

Commit cfedb27

Browse files
committed
better bench
1 parent aba0825 commit cfedb27

File tree

1 file changed

+102
-84
lines changed

1 file changed

+102
-84
lines changed

backends/candle-bench/benches/radix_mlp_benchmark.rs

Lines changed: 102 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -124,24 +124,13 @@ impl From<Batch> for text_embeddings_backend_core::Batch {
124124
}
125125

126126
/// Sets up the backend and batch data needed for the benchmark.
127-
fn setup() -> Result<(CandleBackend, Batch, Batch, Batch)> {
128-
// 1. Setup backend
129-
let model_root = download_artifacts("Qwen/Qwen3-Embedding-4B", None)?;
130-
println!("Model downloaded to {:?}", model_root);
131-
let backend = CandleBackend::new(
132-
&model_root,
133-
"float16".to_string(),
134-
ModelType::Embedding(Pool::LastToken),
135-
None,
136-
)?;
137-
println!("Backend initialized");
138-
127+
fn setup(
128+
_backend: &CandleBackend,
129+
batch_size: usize,
130+
shared_prefix_len: usize,
131+
unique_suffix_len: usize,
132+
) -> Result<(Batch, Batch, Batch)> {
139133
// 2. Create benchmark batch
140-
// Batch size of 32, 500 shared prefix, 500 unique suffix per sequence
141-
// Radix tree structure: 500x1 (shared), then 32x500 (unique tails)
142-
let batch_size: usize = 32;
143-
let shared_prefix_len: usize = 500;
144-
let unique_suffix_len: usize = 500;
145134
let shared_prefix_ids: Vec<u32> = vec![1; shared_prefix_len];
146135

147136
let mut all_input_ids = Vec::new();
@@ -183,7 +172,9 @@ fn setup() -> Result<(CandleBackend, Batch, Batch, Batch)> {
183172
);
184173

185174
println!(
186-
"RadixMLP compression: {} original tokens -> {} compact tokens ({:.1}% reduction)",
175+
"RadixMLP compression (prefix={}, suffix={}): {} original tokens -> {} compact tokens ({:.1}% reduction)",
176+
shared_prefix_len,
177+
unique_suffix_len,
187178
all_input_ids.len(),
188179
compact_input_ids.len(),
189180
(1.0 - compact_input_ids.len() as f64 / all_input_ids.len() as f64) * 100.0
@@ -236,78 +227,105 @@ fn setup() -> Result<(CandleBackend, Batch, Batch, Batch)> {
236227
fold_gather: None,
237228
};
238229

239-
Ok((
240-
backend,
241-
enabled_batch,
242-
disabled_batch,
243-
enabled_batch_unpadded,
244-
))
230+
Ok((enabled_batch, disabled_batch, enabled_batch_unpadded))
245231
}
246232

247233
/// The main benchmark function.
248234
fn bench_radix_mlp(c: &mut Criterion) {
249-
let (backend, enabled_batch, disabled_batch, enabled_batch_unpadded) =
250-
setup().expect("Failed to set up benchmark");
251-
252-
// --- Correctness Check ---
253-
// Run once before benchmarking to ensure outputs are identical.
254-
let radix_result = backend.embed(enabled_batch.clone().into()).unwrap();
255-
let regular_result = backend.embed(disabled_batch.clone().into()).unwrap();
256-
257-
// Extract embeddings from the results (IntMap<usize, Embedding>)
258-
let radix_vecs: Vec<Vec<f32>> = (0..16)
259-
.map(|i| match radix_result.get(&i).unwrap() {
260-
text_embeddings_backend_core::Embedding::Pooled(v) => v.clone(),
261-
text_embeddings_backend_core::Embedding::All(vecs) => vecs.last().unwrap().clone(),
262-
})
263-
.collect();
264-
let regular_vecs: Vec<Vec<f32>> = (0..16)
265-
.map(|i| match regular_result.get(&i).unwrap() {
266-
text_embeddings_backend_core::Embedding::Pooled(v) => v.clone(),
267-
text_embeddings_backend_core::Embedding::All(vecs) => vecs.last().unwrap().clone(),
268-
})
269-
.collect();
270-
271-
assert_eq!(radix_vecs.len(), regular_vecs.len());
272-
for i in 0..radix_vecs.len() {
273-
let diff: f32 = radix_vecs[i]
274-
.iter()
275-
.zip(regular_vecs[i].iter())
276-
.map(|(a, b)| (a - b).abs())
277-
.sum();
278-
assert!(
279-
diff < 1e-2,
280-
"Correctness check failed: Embeddings for item {} differ by {}",
281-
i,
282-
diff
235+
// 1. Setup backend
236+
let model_root = download_artifacts("Qwen/Qwen3-Embedding-0.6B", None)
237+
.expect("Failed to download artifacts");
238+
println!("Model downloaded to {:?}", model_root);
239+
let backend = CandleBackend::new(
240+
&model_root,
241+
"float16".to_string(),
242+
ModelType::Embedding(Pool::LastToken),
243+
None,
244+
)
245+
.expect("Could not start backend");
246+
println!("Backend initialized");
247+
248+
let batch_size = 32;
249+
let size_configs = [(512, 256), (512, 512), (1024, 1024)];
250+
251+
for (shared_prefix_len, unique_suffix_len) in size_configs {
252+
let (enabled_batch, disabled_batch, enabled_batch_unpadded) = setup(
253+
&backend,
254+
batch_size,
255+
shared_prefix_len,
256+
unique_suffix_len,
257+
)
258+
.expect("Failed to set up benchmark");
259+
260+
// --- Correctness Check ---
261+
let radix_result = backend.embed(enabled_batch.clone().into()).unwrap();
262+
let regular_result = backend.embed(disabled_batch.clone().into()).unwrap();
263+
264+
let radix_vecs: Vec<Vec<f32>> = (0..batch_size)
265+
.map(|i| match radix_result.get(&i).unwrap() {
266+
text_embeddings_backend_core::Embedding::Pooled(v) => v.clone(),
267+
text_embeddings_backend_core::Embedding::All(vecs) => vecs.last().unwrap().clone(),
268+
})
269+
.collect();
270+
let regular_vecs: Vec<Vec<f32>> = (0..batch_size)
271+
.map(|i| match regular_result.get(&i).unwrap() {
272+
text_embeddings_backend_core::Embedding::Pooled(v) => v.clone(),
273+
text_embeddings_backend_core::Embedding::All(vecs) => vecs.last().unwrap().clone(),
274+
})
275+
.collect();
276+
277+
assert_eq!(radix_vecs.len(), regular_vecs.len());
278+
for i in 0..radix_vecs.len() {
279+
let diff: f32 = radix_vecs[i]
280+
.iter()
281+
.zip(regular_vecs[i].iter())
282+
.map(|(a, b)| (a - b).abs())
283+
.sum();
284+
assert!(
285+
diff < 1e-2,
286+
"Correctness check failed for size ({}, {}): Embeddings for item {} differ by {}",
287+
shared_prefix_len,
288+
unique_suffix_len,
289+
i,
290+
diff
291+
);
292+
}
293+
println!(
294+
"Correctness check passed for size ({}, {}). Starting benchmark...",
295+
shared_prefix_len, unique_suffix_len
283296
);
297+
// --- End Correctness Check ---
298+
299+
let mut group = c.benchmark_group(&format!(
300+
"RadixMLP Speedup (prefix: {}, suffix: {})",
301+
shared_prefix_len, unique_suffix_len
302+
));
303+
group
304+
.sample_size(10)
305+
.warm_up_time(std::time::Duration::from_secs(3))
306+
.measurement_time(std::time::Duration::from_secs(15));
307+
308+
// Benchmark WITH RadixMLP enabled (uses shared prefix computation)
309+
group.bench_function("radix_mlp_enabled", |b| {
310+
b.iter(|| backend.embed(enabled_batch.clone().into()).unwrap())
311+
});
312+
313+
// Benchmark WITH RadixMLP enabled but without padding (uses shared prefix computation)
314+
group.bench_function("radix_mlp_enabled_unpadded", |b| {
315+
b.iter(|| {
316+
backend
317+
.embed(enabled_batch_unpadded.clone().into())
318+
.unwrap()
319+
})
320+
});
321+
322+
// Benchmark WITHOUT RadixMLP (standard full computation)
323+
group.bench_function("radix_mlp_disabled", |b| {
324+
b.iter(|| backend.embed(disabled_batch.clone().into()).unwrap())
325+
});
326+
327+
group.finish();
284328
}
285-
println!("Correctness check passed. Starting benchmark...");
286-
// --- End Correctness Check ---
287-
288-
let mut group = c.benchmark_group("RadixMLP Speedup");
289-
group.sample_size(25);
290-
291-
// Benchmark WITH RadixMLP enabled (uses shared prefix computation)
292-
group.bench_function("radix_mlp_enabled", |b| {
293-
b.iter(|| backend.embed(enabled_batch.clone().into()).unwrap())
294-
});
295-
296-
// Benchmark WITH RadixMLP enabled but without padding (uses shared prefix computation)
297-
group.bench_function("radix_mlp_enabled_unpadded", |b| {
298-
b.iter(|| {
299-
backend
300-
.embed(enabled_batch_unpadded.clone().into())
301-
.unwrap()
302-
})
303-
});
304-
305-
// Benchmark WITHOUT RadixMLP (standard full computation)
306-
group.bench_function("radix_mlp_disabled", |b| {
307-
b.iter(|| backend.embed(disabled_batch.clone().into()).unwrap())
308-
});
309-
310-
group.finish();
311329
}
312330

313331
criterion_group!(benches, bench_radix_mlp);

0 commit comments

Comments
 (0)