Skip to content

Commit 088ba6b

Browse files
committed
normalized benchmark
1 parent 57c5566 commit 088ba6b

File tree

1 file changed

+33
-11
lines changed

1 file changed

+33
-11
lines changed

backends/candle-bench/benches/radix_mlp_benchmark.rs

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,15 @@ fn setup(
230230
Ok((enabled_batch, disabled_batch, enabled_batch_unpadded))
231231
}
232232

233+
fn normalize(v: &[f32]) -> Vec<f32> {
234+
let norm = (v.iter().map(|&val| val * val).sum::<f32>()).sqrt();
235+
if norm > 0.0 {
236+
v.iter().map(|&val| val / norm).collect()
237+
} else {
238+
v.to_vec()
239+
}
240+
}
241+
233242
fn cosine_similarity(v1: &[f32], v2: &[f32]) -> f32 {
234243
assert_eq!(v1.len(), v2.len());
235244

@@ -262,7 +271,7 @@ fn bench_radix_mlp(c: &mut Criterion) {
262271
println!("Backend initialized");
263272

264273
let batch_size = 32;
265-
let size_configs = [(32,512), (256, 512), (512, 32), (512, 256), (512, 512), (512, 1024)];
274+
let size_configs = [(32, 512), (256, 512), (512, 32), (512, 256), (512, 512), (512, 1024)];
266275

267276
for (shared_prefix_len, unique_suffix_len) in size_configs {
268277
let (enabled_batch, disabled_batch, enabled_batch_unpadded) = setup(
@@ -299,20 +308,30 @@ fn bench_radix_mlp(c: &mut Criterion) {
299308
})
300309
.collect();
301310

311+
let normalized_radix_vecs: Vec<Vec<f32>> = radix_vecs.iter().map(|v| normalize(v)).collect();
312+
let normalized_regular_vecs: Vec<Vec<f32>> =
313+
regular_vecs.iter().map(|v| normalize(v)).collect();
314+
let normalized_radix_unpadded_vecs: Vec<Vec<f32>> =
315+
radix_unpadded_vecs.iter().map(|v| normalize(v)).collect();
316+
302317
assert_eq!(radix_vecs.len(), regular_vecs.len());
303318
assert_eq!(radix_unpadded_vecs.len(), regular_vecs.len());
304319

305320
for i in 0..radix_vecs.len() {
306-
let diff: f32 = radix_vecs[i]
321+
let diff: f32 = normalized_radix_vecs[i]
307322
.iter()
308-
.zip(regular_vecs[i].iter())
323+
.zip(normalized_regular_vecs[i].iter())
309324
.map(|(a, b)| (a - b).abs())
310-
.sum();
311-
let cos_sim = cosine_similarity(&radix_vecs[i], &regular_vecs[i]);
312-
let cos_sim_unpadded =
313-
cosine_similarity(&radix_unpadded_vecs[i], &regular_vecs[i]);
325+
.reduce(f32::max)
326+
.unwrap_or(0.0);
327+
let cos_sim =
328+
cosine_similarity(&normalized_radix_vecs[i], &normalized_regular_vecs[i]);
329+
let cos_sim_unpadded = cosine_similarity(
330+
&normalized_radix_unpadded_vecs[i],
331+
&normalized_regular_vecs[i],
332+
);
314333

315-
let passed = diff < 1e-2 && cos_sim > 0.999 && cos_sim_unpadded > 0.999;
334+
let passed = diff < 1e-4 && cos_sim > 0.999 && cos_sim_unpadded > 0.999;
316335

317336
if !passed {
318337
println!(
@@ -326,9 +345,12 @@ fn bench_radix_mlp(c: &mut Criterion) {
326345
"Correctness check FAILED for size ({}, {}), item {}",
327346
shared_prefix_len, unique_suffix_len, i
328347
);
329-
println!("Regular: {:?}", &regular_vecs[i][..8]);
330-
println!("Padded: {:?}", &radix_vecs[i][..8]);
331-
println!("Unpadded:{:?}", &radix_unpadded_vecs[i][..8]);
348+
println!("Regular (normalized): {:?}", &normalized_regular_vecs[i][..8]);
349+
println!("Padded (normalized): {:?}", &normalized_radix_vecs[i][..8]);
350+
println!(
351+
"Unpadded (normalized):{:?}",
352+
&normalized_radix_unpadded_vecs[i][..8]
353+
);
332354
}
333355
}
334356
println!(

0 commit comments

Comments
 (0)