@@ -271,16 +271,29 @@ fn bench_radix_mlp(c: &mut Criterion) {
271271 println ! ( "Backend initialized" ) ;
272272
273273 let batch_size = 32 ;
274- let size_configs = [ ( 32 , 512 ) , ( 256 , 512 ) , ( 512 , 32 ) , ( 512 , 256 ) , ( 512 , 512 ) , ( 512 , 1024 ) ] ;
274+ let size_configs = [
275+ // 256 suffix sizes
276+ ( 1 , 256 ) ,
277+ ( 32 , 256 ) ,
278+ ( 128 , 256 ) ,
279+ ( 256 , 256 ) ,
280+ ( 512 , 256 ) ,
281+ ( 1024 , 256 ) ,
282+ ( 2048 , 256 ) ,
283+ // 1024 suffix sizes
284+ ( 1 , 1024 ) ,
285+ ( 32 , 1024 ) ,
286+ ( 128 , 1024 ) ,
287+ ( 256 , 1024 ) ,
288+ ( 512 , 1024 ) ,
289+ ( 1024 , 1024 ) ,
290+ ( 2048 , 1024 ) ,
291+ ] ;
275292
276293 for ( shared_prefix_len, unique_suffix_len) in size_configs {
277- let ( enabled_batch, disabled_batch, enabled_batch_unpadded) = setup (
278- & backend,
279- batch_size,
280- shared_prefix_len,
281- unique_suffix_len,
282- )
283- . expect ( "Failed to set up benchmark" ) ;
294+ let ( enabled_batch, disabled_batch, enabled_batch_unpadded) =
295+ setup ( & backend, batch_size, shared_prefix_len, unique_suffix_len)
296+ . expect ( "Failed to set up benchmark" ) ;
284297
285298 // --- Correctness Check ---
286299 let radix_result = backend. embed ( enabled_batch. clone ( ) . into ( ) ) . unwrap ( ) ;
@@ -308,7 +321,8 @@ fn bench_radix_mlp(c: &mut Criterion) {
308321 } )
309322 . collect ( ) ;
310323
311- let normalized_radix_vecs: Vec < Vec < f32 > > = radix_vecs. iter ( ) . map ( |v| normalize ( v) ) . collect ( ) ;
324+ let normalized_radix_vecs: Vec < Vec < f32 > > =
325+ radix_vecs. iter ( ) . map ( |v| normalize ( v) ) . collect ( ) ;
312326 let normalized_regular_vecs: Vec < Vec < f32 > > =
313327 regular_vecs. iter ( ) . map ( |v| normalize ( v) ) . collect ( ) ;
314328 let normalized_radix_unpadded_vecs: Vec < Vec < f32 > > =
@@ -324,8 +338,7 @@ fn bench_radix_mlp(c: &mut Criterion) {
324338 . map ( |( a, b) | ( a - b) . abs ( ) )
325339 . reduce ( f32:: max)
326340 . unwrap_or ( 0.0 ) ;
327- let cos_sim =
328- cosine_similarity ( & normalized_radix_vecs[ i] , & normalized_regular_vecs[ i] ) ;
341+ let cos_sim = cosine_similarity ( & normalized_radix_vecs[ i] , & normalized_regular_vecs[ i] ) ;
329342 let cos_sim_unpadded = cosine_similarity (
330343 & normalized_radix_unpadded_vecs[ i] ,
331344 & normalized_regular_vecs[ i] ,
@@ -345,7 +358,10 @@ fn bench_radix_mlp(c: &mut Criterion) {
345358 "Correctness check FAILED for size ({}, {}), item {}" ,
346359 shared_prefix_len, unique_suffix_len, i
347360 ) ;
348- println ! ( "Regular (normalized): {:?}" , & normalized_regular_vecs[ i] [ ..8 ] ) ;
361+ println ! (
362+ "Regular (normalized): {:?}" ,
363+ & normalized_regular_vecs[ i] [ ..8 ]
364+ ) ;
349365 println ! ( "Padded (normalized): {:?}" , & normalized_radix_vecs[ i] [ ..8 ] ) ;
350366 println ! (
351367 "Unpadded (normalized):{:?}" ,
0 commit comments