@@ -230,6 +230,22 @@ fn setup(
230230 Ok ( ( enabled_batch, disabled_batch, enabled_batch_unpadded) )
231231}
232232
233+ fn cosine_similarity ( v1 : & [ f32 ] , v2 : & [ f32 ] ) -> f32 {
234+ assert_eq ! ( v1. len( ) , v2. len( ) ) ;
235+
236+ let mut sumxx = 0.0 ;
237+ let mut sumyy = 0.0 ;
238+ let mut sumxy = 0.0 ;
239+
240+ for ( x, y) in v1. iter ( ) . zip ( v2. iter ( ) ) {
241+ sumxx += x * x;
242+ sumyy += y * y;
243+ sumxy += x * y;
244+ }
245+
246+ sumxy / ( sumxx * sumyy) . sqrt ( )
247+ }
248+
233249/// The main benchmark function.
234250fn bench_radix_mlp ( c : & mut Criterion ) {
235251 // 1. Setup backend
@@ -245,8 +261,8 @@ fn bench_radix_mlp(c: &mut Criterion) {
245261 . expect ( "Could not start backend" ) ;
246262 println ! ( "Backend initialized" ) ;
247263
248- let batch_size = 32 ;
249- let size_configs = [ ( 512 , 256 ) , ( 512 , 512 ) , ( 1024 , 1024 ) ] ;
264+ let batch_size = 16 ;
265+ let size_configs = [ ( 32 , 512 ) , ( 256 , 512 ) , ( 512 , 32 ) , ( 512 , 256 ) , ( 512 , 512 ) , ( 512 , 1024 ) ] ;
250266
251267 for ( shared_prefix_len, unique_suffix_len) in size_configs {
252268 let ( enabled_batch, disabled_batch, enabled_batch_unpadded) = setup (
@@ -260,6 +276,9 @@ fn bench_radix_mlp(c: &mut Criterion) {
260276 // --- Correctness Check ---
261277 let radix_result = backend. embed ( enabled_batch. clone ( ) . into ( ) ) . unwrap ( ) ;
262278 let regular_result = backend. embed ( disabled_batch. clone ( ) . into ( ) ) . unwrap ( ) ;
279+ let radix_unpadded_result = backend
280+ . embed ( enabled_batch_unpadded. clone ( ) . into ( ) )
281+ . unwrap ( ) ;
263282
264283 let radix_vecs: Vec < Vec < f32 > > = ( 0 ..batch_size)
265284 . map ( |i| match radix_result. get ( & i) . unwrap ( ) {
@@ -273,25 +292,47 @@ fn bench_radix_mlp(c: &mut Criterion) {
273292 text_embeddings_backend_core:: Embedding :: All ( vecs) => vecs. last ( ) . unwrap ( ) . clone ( ) ,
274293 } )
275294 . collect ( ) ;
295+ let radix_unpadded_vecs: Vec < Vec < f32 > > = ( 0 ..batch_size)
296+ . map ( |i| match radix_unpadded_result. get ( & i) . unwrap ( ) {
297+ text_embeddings_backend_core:: Embedding :: Pooled ( v) => v. clone ( ) ,
298+ text_embeddings_backend_core:: Embedding :: All ( vecs) => vecs. last ( ) . unwrap ( ) . clone ( ) ,
299+ } )
300+ . collect ( ) ;
276301
277302 assert_eq ! ( radix_vecs. len( ) , regular_vecs. len( ) ) ;
303+ assert_eq ! ( radix_unpadded_vecs. len( ) , regular_vecs. len( ) ) ;
304+
278305 for i in 0 ..radix_vecs. len ( ) {
279306 let diff: f32 = radix_vecs[ i]
280307 . iter ( )
281308 . zip ( regular_vecs[ i] . iter ( ) )
282309 . map ( |( a, b) | ( a - b) . abs ( ) )
283310 . sum ( ) ;
284- assert ! (
285- diff < 1e-2 ,
286- "Correctness check failed for size ({}, {}): Embeddings for item {} differ by {}" ,
287- shared_prefix_len,
288- unique_suffix_len,
289- i,
290- diff
291- ) ;
311+ let cos_sim = cosine_similarity ( & radix_vecs[ i] , & regular_vecs[ i] ) ;
312+ let cos_sim_unpadded =
313+ cosine_similarity ( & radix_unpadded_vecs[ i] , & regular_vecs[ i] ) ;
314+
315+ let passed = diff < 1e-2 && cos_sim > 0.999 && cos_sim_unpadded > 0.999 ;
316+
317+ if !passed {
318+ println ! (
319+ "Item {}: Abs Diff: {:.4}, Cosine Sim (Padded): {:.6}, Cosine Sim (Unpadded): {:.6}" ,
320+ i,
321+ diff,
322+ 1.0 - cos_sim,
323+ 1.0 - cos_sim_unpadded
324+ ) ;
325+ println ! (
326+ "Correctness check FAILED for size ({}, {}), item {}" ,
327+ shared_prefix_len, unique_suffix_len, i
328+ ) ;
329+ println ! ( "Regular: {:?}" , & regular_vecs[ i] [ ..8 ] ) ;
330+ println ! ( "Padded: {:?}" , & radix_vecs[ i] [ ..8 ] ) ;
331+ println ! ( "Unpadded:{:?}" , & radix_unpadded_vecs[ i] [ ..8 ] ) ;
332+ }
292333 }
293334 println ! (
294- "Correctness check passed for size ({}, {}). Starting benchmark..." ,
335+ "Correctness check for size ({}, {}) complete . Starting benchmark..." ,
295336 shared_prefix_len, unique_suffix_len
296337 ) ;
297338 // --- End Correctness Check ---
0 commit comments