@@ -124,24 +124,13 @@ impl From<Batch> for text_embeddings_backend_core::Batch {
124124}
125125
126126/// Builds the three benchmark batches for the given batch size and prefix/suffix lengths; the backend itself is created by the caller.
127- fn setup ( ) -> Result < ( CandleBackend , Batch , Batch , Batch ) > {
128- // 1. Setup backend
129- let model_root = download_artifacts ( "Qwen/Qwen3-Embedding-4B" , None ) ?;
130- println ! ( "Model downloaded to {:?}" , model_root) ;
131- let backend = CandleBackend :: new (
132- & model_root,
133- "float16" . to_string ( ) ,
134- ModelType :: Embedding ( Pool :: LastToken ) ,
135- None ,
136- ) ?;
137- println ! ( "Backend initialized" ) ;
138-
127+ fn setup (
128+ _backend : & CandleBackend ,
129+ batch_size : usize ,
130+ shared_prefix_len : usize ,
131+ unique_suffix_len : usize ,
132+ ) -> Result < ( Batch , Batch , Batch ) > {
139133 // Create benchmark batch: `batch_size` sequences, each made of the shared prefix followed by its own unique suffix
140- // Batch size of 32, 500 shared prefix, 500 unique suffix per sequence
141- // Radix tree structure: 500x1 (shared), then 32x500 (unique tails)
142- let batch_size: usize = 32 ;
143- let shared_prefix_len: usize = 500 ;
144- let unique_suffix_len: usize = 500 ;
145134 let shared_prefix_ids: Vec < u32 > = vec ! [ 1 ; shared_prefix_len] ;
146135
147136 let mut all_input_ids = Vec :: new ( ) ;
@@ -183,7 +172,9 @@ fn setup() -> Result<(CandleBackend, Batch, Batch, Batch)> {
183172 ) ;
184173
185174 println ! (
186- "RadixMLP compression: {} original tokens -> {} compact tokens ({:.1}% reduction)" ,
175+ "RadixMLP compression (prefix={}, suffix={}): {} original tokens -> {} compact tokens ({:.1}% reduction)" ,
176+ shared_prefix_len,
177+ unique_suffix_len,
187178 all_input_ids. len( ) ,
188179 compact_input_ids. len( ) ,
189180 ( 1.0 - compact_input_ids. len( ) as f64 / all_input_ids. len( ) as f64 ) * 100.0
@@ -236,78 +227,105 @@ fn setup() -> Result<(CandleBackend, Batch, Batch, Batch)> {
236227 fold_gather : None ,
237228 } ;
238229
239- Ok ( (
240- backend,
241- enabled_batch,
242- disabled_batch,
243- enabled_batch_unpadded,
244- ) )
230+ Ok ( ( enabled_batch, disabled_batch, enabled_batch_unpadded) )
245231}
246232
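246233/// The main benchmark function: creates the backend once, then for each (prefix, suffix) size checks that RadixMLP-enabled and -disabled embeddings agree and benchmarks the enabled, enabled-unpadded and disabled batches.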
248234fn bench_radix_mlp ( c : & mut Criterion ) {
249- let ( backend, enabled_batch, disabled_batch, enabled_batch_unpadded) =
250- setup ( ) . expect ( "Failed to set up benchmark" ) ;
251-
252- // --- Correctness Check ---
253- // Run once before benchmarking to ensure outputs are identical.
254- let radix_result = backend. embed ( enabled_batch. clone ( ) . into ( ) ) . unwrap ( ) ;
255- let regular_result = backend. embed ( disabled_batch. clone ( ) . into ( ) ) . unwrap ( ) ;
256-
257- // Extract embeddings from the results (IntMap<usize, Embedding>)
258- let radix_vecs: Vec < Vec < f32 > > = ( 0 ..16 )
259- . map ( |i| match radix_result. get ( & i) . unwrap ( ) {
260- text_embeddings_backend_core:: Embedding :: Pooled ( v) => v. clone ( ) ,
261- text_embeddings_backend_core:: Embedding :: All ( vecs) => vecs. last ( ) . unwrap ( ) . clone ( ) ,
262- } )
263- . collect ( ) ;
264- let regular_vecs: Vec < Vec < f32 > > = ( 0 ..16 )
265- . map ( |i| match regular_result. get ( & i) . unwrap ( ) {
266- text_embeddings_backend_core:: Embedding :: Pooled ( v) => v. clone ( ) ,
267- text_embeddings_backend_core:: Embedding :: All ( vecs) => vecs. last ( ) . unwrap ( ) . clone ( ) ,
268- } )
269- . collect ( ) ;
270-
271- assert_eq ! ( radix_vecs. len( ) , regular_vecs. len( ) ) ;
272- for i in 0 ..radix_vecs. len ( ) {
273- let diff: f32 = radix_vecs[ i]
274- . iter ( )
275- . zip ( regular_vecs[ i] . iter ( ) )
276- . map ( |( a, b) | ( a - b) . abs ( ) )
277- . sum ( ) ;
278- assert ! (
279- diff < 1e-2 ,
280- "Correctness check failed: Embeddings for item {} differ by {}" ,
281- i,
282- diff
235+ // Set up the backend once; it is reused for every size configuration below
236+ let model_root = download_artifacts ( "Qwen/Qwen3-Embedding-0.6B" , None )
237+ . expect ( "Failed to download artifacts" ) ;
238+ println ! ( "Model downloaded to {:?}" , model_root) ;
239+ let backend = CandleBackend :: new (
240+ & model_root,
241+ "float16" . to_string ( ) ,
242+ ModelType :: Embedding ( Pool :: LastToken ) ,
243+ None ,
244+ )
245+ . expect ( "Could not start backend" ) ;
246+ println ! ( "Backend initialized" ) ;
247+
248+ let batch_size = 32 ;
249+ let size_configs = [ ( 512 , 256 ) , ( 512 , 512 ) , ( 1024 , 1024 ) ] ;
250+
251+ for ( shared_prefix_len, unique_suffix_len) in size_configs {
252+ let ( enabled_batch, disabled_batch, enabled_batch_unpadded) = setup (
253+ & backend,
254+ batch_size,
255+ shared_prefix_len,
256+ unique_suffix_len,
257+ )
258+ . expect ( "Failed to set up benchmark" ) ;
259+
260+ // --- Correctness Check ---
261+ let radix_result = backend. embed ( enabled_batch. clone ( ) . into ( ) ) . unwrap ( ) ;
262+ let regular_result = backend. embed ( disabled_batch. clone ( ) . into ( ) ) . unwrap ( ) ;
263+
264+ let radix_vecs: Vec < Vec < f32 > > = ( 0 ..batch_size)
265+ . map ( |i| match radix_result. get ( & i) . unwrap ( ) {
266+ text_embeddings_backend_core:: Embedding :: Pooled ( v) => v. clone ( ) ,
267+ text_embeddings_backend_core:: Embedding :: All ( vecs) => vecs. last ( ) . unwrap ( ) . clone ( ) ,
268+ } )
269+ . collect ( ) ;
270+ let regular_vecs: Vec < Vec < f32 > > = ( 0 ..batch_size)
271+ . map ( |i| match regular_result. get ( & i) . unwrap ( ) {
272+ text_embeddings_backend_core:: Embedding :: Pooled ( v) => v. clone ( ) ,
273+ text_embeddings_backend_core:: Embedding :: All ( vecs) => vecs. last ( ) . unwrap ( ) . clone ( ) ,
274+ } )
275+ . collect ( ) ;
276+
277+ assert_eq ! ( radix_vecs. len( ) , regular_vecs. len( ) ) ;
278+ for i in 0 ..radix_vecs. len ( ) {
279+ let diff: f32 = radix_vecs[ i]
280+ . iter ( )
281+ . zip ( regular_vecs[ i] . iter ( ) )
282+ . map ( |( a, b) | ( a - b) . abs ( ) )
283+ . sum ( ) ;
284+ assert ! (
285+ diff < 1e-2 ,
286+ "Correctness check failed for size ({}, {}): Embeddings for item {} differ by {}" ,
287+ shared_prefix_len,
288+ unique_suffix_len,
289+ i,
290+ diff
291+ ) ;
292+ }
293+ println ! (
294+ "Correctness check passed for size ({}, {}). Starting benchmark..." ,
295+ shared_prefix_len, unique_suffix_len
283296 ) ;
297+ // --- End Correctness Check ---
298+
299+ let mut group = c. benchmark_group ( & format ! (
300+ "RadixMLP Speedup (prefix: {}, suffix: {})" ,
301+ shared_prefix_len, unique_suffix_len
302+ ) ) ;
303+ group
304+ . sample_size ( 10 )
305+ . warm_up_time ( std:: time:: Duration :: from_secs ( 3 ) )
306+ . measurement_time ( std:: time:: Duration :: from_secs ( 15 ) ) ;
307+
308+ // Benchmark WITH RadixMLP enabled (uses shared prefix computation)
309+ group. bench_function ( "radix_mlp_enabled" , |b| {
310+ b. iter ( || backend. embed ( enabled_batch. clone ( ) . into ( ) ) . unwrap ( ) )
311+ } ) ;
312+
313+ // Benchmark WITH RadixMLP enabled but without padding (uses shared prefix computation)
314+ group. bench_function ( "radix_mlp_enabled_unpadded" , |b| {
315+ b. iter ( || {
316+ backend
317+ . embed ( enabled_batch_unpadded. clone ( ) . into ( ) )
318+ . unwrap ( )
319+ } )
320+ } ) ;
321+
322+ // Benchmark WITHOUT RadixMLP (standard full computation)
323+ group. bench_function ( "radix_mlp_disabled" , |b| {
324+ b. iter ( || backend. embed ( disabled_batch. clone ( ) . into ( ) ) . unwrap ( ) )
325+ } ) ;
326+
327+ group. finish ( ) ;
284328 }
285- println ! ( "Correctness check passed. Starting benchmark..." ) ;
286- // --- End Correctness Check ---
287-
288- let mut group = c. benchmark_group ( "RadixMLP Speedup" ) ;
289- group. sample_size ( 25 ) ;
290-
291- // Benchmark WITH RadixMLP enabled (uses shared prefix computation)
292- group. bench_function ( "radix_mlp_enabled" , |b| {
293- b. iter ( || backend. embed ( enabled_batch. clone ( ) . into ( ) ) . unwrap ( ) )
294- } ) ;
295-
296- // Benchmark WITH RadixMLP enabled but without padding (uses shared prefix computation)
297- group. bench_function ( "radix_mlp_enabled_unpadded" , |b| {
298- b. iter ( || {
299- backend
300- . embed ( enabled_batch_unpadded. clone ( ) . into ( ) )
301- . unwrap ( )
302- } )
303- } ) ;
304-
305- // Benchmark WITHOUT RadixMLP (standard full computation)
306- group. bench_function ( "radix_mlp_disabled" , |b| {
307- b. iter ( || backend. embed ( disabled_batch. clone ( ) . into ( ) ) . unwrap ( ) )
308- } ) ;
309-
310- group. finish ( ) ;
311329}
312330
313331criterion_group ! ( benches, bench_radix_mlp) ;
0 commit comments