Skip to content

Commit 89224ce

Browse files
committed
better bench
1 parent 088ba6b commit 89224ce

File tree

1 file changed

+28
-12
lines changed

1 file changed

+28
-12
lines changed

backends/candle-bench/benches/radix_mlp_benchmark.rs

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -271,16 +271,29 @@ fn bench_radix_mlp(c: &mut Criterion) {
271271
println!("Backend initialized");
272272

273273
let batch_size = 32;
274-
let size_configs = [(32, 512), (256, 512), (512, 32), (512, 256), (512, 512), (512, 1024)];
274+
let size_configs = [
275+
// 256 suffix sizes
276+
(1, 256),
277+
(32, 256),
278+
(128, 256),
279+
(256, 256),
280+
(512, 256),
281+
(1024, 256),
282+
(2048, 256),
283+
// 1024 suffix sizes
284+
(1, 1024),
285+
(32, 1024),
286+
(128, 1024),
287+
(256, 1024),
288+
(512, 1024),
289+
(1024, 1024),
290+
(2048, 1024),
291+
];
275292

276293
for (shared_prefix_len, unique_suffix_len) in size_configs {
277-
let (enabled_batch, disabled_batch, enabled_batch_unpadded) = setup(
278-
&backend,
279-
batch_size,
280-
shared_prefix_len,
281-
unique_suffix_len,
282-
)
283-
.expect("Failed to set up benchmark");
294+
let (enabled_batch, disabled_batch, enabled_batch_unpadded) =
295+
setup(&backend, batch_size, shared_prefix_len, unique_suffix_len)
296+
.expect("Failed to set up benchmark");
284297

285298
// --- Correctness Check ---
286299
let radix_result = backend.embed(enabled_batch.clone().into()).unwrap();
@@ -308,7 +321,8 @@ fn bench_radix_mlp(c: &mut Criterion) {
308321
})
309322
.collect();
310323

311-
let normalized_radix_vecs: Vec<Vec<f32>> = radix_vecs.iter().map(|v| normalize(v)).collect();
324+
let normalized_radix_vecs: Vec<Vec<f32>> =
325+
radix_vecs.iter().map(|v| normalize(v)).collect();
312326
let normalized_regular_vecs: Vec<Vec<f32>> =
313327
regular_vecs.iter().map(|v| normalize(v)).collect();
314328
let normalized_radix_unpadded_vecs: Vec<Vec<f32>> =
@@ -324,8 +338,7 @@ fn bench_radix_mlp(c: &mut Criterion) {
324338
.map(|(a, b)| (a - b).abs())
325339
.reduce(f32::max)
326340
.unwrap_or(0.0);
327-
let cos_sim =
328-
cosine_similarity(&normalized_radix_vecs[i], &normalized_regular_vecs[i]);
341+
let cos_sim = cosine_similarity(&normalized_radix_vecs[i], &normalized_regular_vecs[i]);
329342
let cos_sim_unpadded = cosine_similarity(
330343
&normalized_radix_unpadded_vecs[i],
331344
&normalized_regular_vecs[i],
@@ -345,7 +358,10 @@ fn bench_radix_mlp(c: &mut Criterion) {
345358
"Correctness check FAILED for size ({}, {}), item {}",
346359
shared_prefix_len, unique_suffix_len, i
347360
);
348-
println!("Regular (normalized): {:?}", &normalized_regular_vecs[i][..8]);
361+
println!(
362+
"Regular (normalized): {:?}",
363+
&normalized_regular_vecs[i][..8]
364+
);
349365
println!("Padded (normalized): {:?}", &normalized_radix_vecs[i][..8]);
350366
println!(
351367
"Unpadded (normalized):{:?}",

0 commit comments

Comments
 (0)