Skip to content

Commit dc5d666

Browse files
committed
better benchmark
1 parent cfedb27 commit dc5d666

File tree

1 file changed

+52
-11
lines changed

1 file changed

+52
-11
lines changed

backends/candle-bench/benches/radix_mlp_benchmark.rs

Lines changed: 52 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,22 @@ fn setup(
230230
Ok((enabled_batch, disabled_batch, enabled_batch_unpadded))
231231
}
232232

233+
fn cosine_similarity(v1: &[f32], v2: &[f32]) -> f32 {
234+
assert_eq!(v1.len(), v2.len());
235+
236+
let mut sumxx = 0.0;
237+
let mut sumyy = 0.0;
238+
let mut sumxy = 0.0;
239+
240+
for (x, y) in v1.iter().zip(v2.iter()) {
241+
sumxx += x * x;
242+
sumyy += y * y;
243+
sumxy += x * y;
244+
}
245+
246+
sumxy / (sumxx * sumyy).sqrt()
247+
}
248+
233249
/// The main benchmark function.
234250
fn bench_radix_mlp(c: &mut Criterion) {
235251
// 1. Setup backend
@@ -245,8 +261,8 @@ fn bench_radix_mlp(c: &mut Criterion) {
245261
.expect("Could not start backend");
246262
println!("Backend initialized");
247263

248-
let batch_size = 32;
249-
let size_configs = [(512, 256), (512, 512), (1024, 1024)];
264+
let batch_size = 16;
265+
let size_configs = [(32,512), (256, 512), (512, 32), (512, 256), (512, 512), (512, 1024)];
250266

251267
for (shared_prefix_len, unique_suffix_len) in size_configs {
252268
let (enabled_batch, disabled_batch, enabled_batch_unpadded) = setup(
@@ -260,6 +276,9 @@ fn bench_radix_mlp(c: &mut Criterion) {
260276
// --- Correctness Check ---
261277
let radix_result = backend.embed(enabled_batch.clone().into()).unwrap();
262278
let regular_result = backend.embed(disabled_batch.clone().into()).unwrap();
279+
let radix_unpadded_result = backend
280+
.embed(enabled_batch_unpadded.clone().into())
281+
.unwrap();
263282

264283
let radix_vecs: Vec<Vec<f32>> = (0..batch_size)
265284
.map(|i| match radix_result.get(&i).unwrap() {
@@ -273,25 +292,47 @@ fn bench_radix_mlp(c: &mut Criterion) {
273292
text_embeddings_backend_core::Embedding::All(vecs) => vecs.last().unwrap().clone(),
274293
})
275294
.collect();
295+
let radix_unpadded_vecs: Vec<Vec<f32>> = (0..batch_size)
296+
.map(|i| match radix_unpadded_result.get(&i).unwrap() {
297+
text_embeddings_backend_core::Embedding::Pooled(v) => v.clone(),
298+
text_embeddings_backend_core::Embedding::All(vecs) => vecs.last().unwrap().clone(),
299+
})
300+
.collect();
276301

277302
assert_eq!(radix_vecs.len(), regular_vecs.len());
303+
assert_eq!(radix_unpadded_vecs.len(), regular_vecs.len());
304+
278305
for i in 0..radix_vecs.len() {
279306
let diff: f32 = radix_vecs[i]
280307
.iter()
281308
.zip(regular_vecs[i].iter())
282309
.map(|(a, b)| (a - b).abs())
283310
.sum();
284-
assert!(
285-
diff < 1e-2,
286-
"Correctness check failed for size ({}, {}): Embeddings for item {} differ by {}",
287-
shared_prefix_len,
288-
unique_suffix_len,
289-
i,
290-
diff
291-
);
311+
let cos_sim = cosine_similarity(&radix_vecs[i], &regular_vecs[i]);
312+
let cos_sim_unpadded =
313+
cosine_similarity(&radix_unpadded_vecs[i], &regular_vecs[i]);
314+
315+
let passed = diff < 1e-2 && cos_sim > 0.999 && cos_sim_unpadded > 0.999;
316+
317+
if !passed {
318+
println!(
319+
"Item {}: Abs Diff: {:.4}, Cosine Sim (Padded): {:.6}, Cosine Sim (Unpadded): {:.6}",
320+
i,
321+
diff,
322+
1.0 - cos_sim,
323+
1.0 - cos_sim_unpadded
324+
);
325+
println!(
326+
"Correctness check FAILED for size ({}, {}), item {}",
327+
shared_prefix_len, unique_suffix_len, i
328+
);
329+
println!("Regular: {:?}", &regular_vecs[i][..8]);
330+
println!("Padded: {:?}", &radix_vecs[i][..8]);
331+
println!("Unpadded:{:?}", &radix_unpadded_vecs[i][..8]);
332+
}
292333
}
293334
println!(
294-
"Correctness check passed for size ({}, {}). Starting benchmark...",
335+
"Correctness check for size ({}, {}) complete. Starting benchmark...",
295336
shared_prefix_len, unique_suffix_len
296337
);
297338
// --- End Correctness Check ---

0 commit comments

Comments
 (0)