@@ -230,6 +230,15 @@ fn setup(
230230 Ok ( ( enabled_batch, disabled_batch, enabled_batch_unpadded) )
231231}
232232
233+ fn normalize ( v : & [ f32 ] ) -> Vec < f32 > {
234+ let norm = ( v. iter ( ) . map ( |& val| val * val) . sum :: < f32 > ( ) ) . sqrt ( ) ;
235+ if norm > 0.0 {
236+ v. iter ( ) . map ( |& val| val / norm) . collect ( )
237+ } else {
238+ v. to_vec ( )
239+ }
240+ }
241+
233242fn cosine_similarity ( v1 : & [ f32 ] , v2 : & [ f32 ] ) -> f32 {
234243 assert_eq ! ( v1. len( ) , v2. len( ) ) ;
235244
@@ -262,7 +271,7 @@ fn bench_radix_mlp(c: &mut Criterion) {
262271 println ! ( "Backend initialized" ) ;
263272
264273 let batch_size = 32 ;
265- let size_configs = [ ( 32 , 512 ) , ( 256 , 512 ) , ( 512 , 32 ) , ( 512 , 256 ) , ( 512 , 512 ) , ( 512 , 1024 ) ] ;
274+ let size_configs = [ ( 32 , 512 ) , ( 256 , 512 ) , ( 512 , 32 ) , ( 512 , 256 ) , ( 512 , 512 ) , ( 512 , 1024 ) ] ;
266275
267276 for ( shared_prefix_len, unique_suffix_len) in size_configs {
268277 let ( enabled_batch, disabled_batch, enabled_batch_unpadded) = setup (
@@ -299,20 +308,30 @@ fn bench_radix_mlp(c: &mut Criterion) {
299308 } )
300309 . collect ( ) ;
301310
311+ let normalized_radix_vecs: Vec < Vec < f32 > > = radix_vecs. iter ( ) . map ( |v| normalize ( v) ) . collect ( ) ;
312+ let normalized_regular_vecs: Vec < Vec < f32 > > =
313+ regular_vecs. iter ( ) . map ( |v| normalize ( v) ) . collect ( ) ;
314+ let normalized_radix_unpadded_vecs: Vec < Vec < f32 > > =
315+ radix_unpadded_vecs. iter ( ) . map ( |v| normalize ( v) ) . collect ( ) ;
316+
302317 assert_eq ! ( radix_vecs. len( ) , regular_vecs. len( ) ) ;
303318 assert_eq ! ( radix_unpadded_vecs. len( ) , regular_vecs. len( ) ) ;
304319
305320 for i in 0 ..radix_vecs. len ( ) {
306- let diff: f32 = radix_vecs [ i]
321+ let diff: f32 = normalized_radix_vecs [ i]
307322 . iter ( )
308- . zip ( regular_vecs [ i] . iter ( ) )
323+ . zip ( normalized_regular_vecs [ i] . iter ( ) )
309324 . map ( |( a, b) | ( a - b) . abs ( ) )
310- . sum ( ) ;
311- let cos_sim = cosine_similarity ( & radix_vecs[ i] , & regular_vecs[ i] ) ;
312- let cos_sim_unpadded =
313- cosine_similarity ( & radix_unpadded_vecs[ i] , & regular_vecs[ i] ) ;
325+ . reduce ( f32:: max)
326+ . unwrap_or ( 0.0 ) ;
327+ let cos_sim =
328+ cosine_similarity ( & normalized_radix_vecs[ i] , & normalized_regular_vecs[ i] ) ;
329+ let cos_sim_unpadded = cosine_similarity (
330+ & normalized_radix_unpadded_vecs[ i] ,
331+ & normalized_regular_vecs[ i] ,
332+ ) ;
314333
315- let passed = diff < 1e-2 && cos_sim > 0.999 && cos_sim_unpadded > 0.999 ;
334+ let passed = diff < 1e-4 && cos_sim > 0.999 && cos_sim_unpadded > 0.999 ;
316335
317336 if !passed {
318337 println ! (
@@ -326,9 +345,12 @@ fn bench_radix_mlp(c: &mut Criterion) {
326345 "Correctness check FAILED for size ({}, {}), item {}" ,
327346 shared_prefix_len, unique_suffix_len, i
328347 ) ;
329- println ! ( "Regular: {:?}" , & regular_vecs[ i] [ ..8 ] ) ;
330- println ! ( "Padded: {:?}" , & radix_vecs[ i] [ ..8 ] ) ;
331- println ! ( "Unpadded:{:?}" , & radix_unpadded_vecs[ i] [ ..8 ] ) ;
348+ println ! ( "Regular (normalized): {:?}" , & normalized_regular_vecs[ i] [ ..8 ] ) ;
349+ println ! ( "Padded (normalized): {:?}" , & normalized_radix_vecs[ i] [ ..8 ] ) ;
350+ println ! (
351+ "Unpadded (normalized):{:?}" ,
352+ & normalized_radix_unpadded_vecs[ i] [ ..8 ]
353+ ) ;
332354 }
333355 }
334356 println ! (
0 commit comments