@@ -41,7 +41,7 @@ use prometheus::{IntCounterVec, Opts};
4141use serde:: { Deserialize , Serialize } ;
4242use std:: {
4343 cell:: RefCell ,
44- collections:: { HashMap , HashSet , VecDeque } ,
44+ collections:: { HashMap , VecDeque } ,
4545 iter,
4646 rc:: Rc ,
4747 sync:: { Arc , OnceLock } ,
@@ -200,8 +200,9 @@ impl RouterEvent {
200200struct RadixBlock {
201201 /// A map of child blocks, keyed by their local block hash.
202202 children : HashMap < LocalBlockHash , SharedRadixBlock > ,
203- /// A set of worker IDs associated with this block.
204- workers : HashSet < WorkerId > ,
203+ /// A map of worker IDs to their external sequence block hash for this block.
204+ /// The external hash is preserved to speed up snapshotting.
205+ workers : HashMap < WorkerId , ExternalSequenceBlockHash > ,
205206 /// A buffer of times that this block was last traversed
206207 recent_uses : VecDeque < Instant > ,
207208}
@@ -215,7 +216,7 @@ impl RadixBlock {
215216 pub fn new ( ) -> Self {
216217 Self {
217218 children : HashMap :: new ( ) ,
218- workers : HashSet :: new ( ) ,
219+ workers : HashMap :: new ( ) ,
219220 recent_uses : VecDeque :: new ( ) ,
220221 }
221222 }
@@ -289,7 +290,7 @@ impl RadixTree {
289290 current_borrow. children . get ( block_hash) . cloned ( )
290291 } ;
291292 if let Some ( block) = next_block {
292- scores. update_scores ( & block. borrow ( ) . workers ) ;
293+ scores. update_scores ( block. borrow ( ) . workers . keys ( ) ) ;
293294
294295 if let Some ( expiration_duration) = self . expiration_duration {
295296 let mut block_mut = block. borrow_mut ( ) ;
@@ -380,8 +381,11 @@ impl RadixTree {
380381 }
381382 } ;
382383
383- // add our worker_id to the block
384- block. borrow_mut ( ) . workers . insert ( worker_id) ;
384+ // add our worker_id to the block with its external hash
385+ block
386+ . borrow_mut ( )
387+ . workers
388+ . insert ( worker_id, block_id. block_hash ) ;
385389
386390 // add the block to the worker_id lookup table
387391 worker_lookup. insert ( block_id. block_hash , block. clone ( ) ) ;
@@ -417,7 +421,7 @@ impl RadixTree {
417421 let mut guard = entry. borrow_mut ( ) ;
418422 guard. workers . remove ( & worker_id) ;
419423 if guard. workers . is_empty ( ) {
420- // if no worker are using this block, that is true for all children
424+ // if no workers are using this block, that is true for all children
421425 guard. children . clear ( ) ;
422426 }
423427 // remove the block from the lookup table
@@ -436,6 +440,10 @@ impl RadixTree {
436440 if let Some ( ( _, blocks) ) = self . lookup . remove_entry ( & worker) {
437441 blocks. iter ( ) . for_each ( |( _, block) | {
438442 block. borrow_mut ( ) . workers . remove ( & worker) ;
443+ // If no workers are using this block, that is true for all children
444+ if block. borrow ( ) . workers . is_empty ( ) {
445+ block. borrow_mut ( ) . children . clear ( ) ;
446+ }
439447 } ) ;
440448 }
441449 }
@@ -445,14 +453,18 @@ impl RadixTree {
445453 if let Some ( blocks) = self . lookup . get ( & worker) {
446454 let blocks_to_clear: Vec < _ > = blocks. values ( ) . collect ( ) ;
447455
448- // Remove the worker from each block's workers set
456+ // Remove the worker from each block's workers map
449457 blocks_to_clear. iter ( ) . for_each ( |block| {
450458 block. borrow_mut ( ) . workers . remove ( & worker) ;
459+ // If no workers are using this block, that is true for all children
460+ if block. borrow ( ) . workers . is_empty ( ) {
461+ block. borrow_mut ( ) . children . clear ( ) ;
462+ }
451463 } ) ;
452464
453465 // Clear the worker's blocks
454- if let Some ( worker_blocks ) = self . lookup . get_mut ( & worker) {
455- worker_blocks . clear ( ) ;
466+ if let Some ( worker_lookup ) = self . lookup . get_mut ( & worker) {
467+ worker_lookup . clear ( ) ;
456468 }
457469 }
458470 }
@@ -461,71 +473,68 @@ impl RadixTree {
461473 /// Uses BFS traversal to ensure that the tree reconstruction is unique,
462474 /// though the exact event ordering will be lost.
463475 pub fn dump_tree_as_events ( & self ) -> Vec < RouterEvent > {
476+ tracing:: debug!(
477+ "Dumping radix tree as events (contains information about {:?} workers)" ,
478+ self . lookup. len( )
479+ ) ;
480+
464481 let mut events = Vec :: new ( ) ;
465482 let mut event_id = 0u64 ;
466483
467- // BFS queue: (current_block, parent_external_hash, tokens_hash)
468- let mut queue = VecDeque :: new ( ) ;
484+ // BFS queue: (current_block, parent_hashes_per_worker, tokens_hash)
485+ // parent_hashes_per_worker maps WorkerId -> ExternalSequenceBlockHash
486+ let mut queue: VecDeque < (
487+ SharedRadixBlock ,
488+ HashMap < WorkerId , ExternalSequenceBlockHash > ,
489+ LocalBlockHash ,
490+ ) > = VecDeque :: new ( ) ;
469491
470492 // Process root's children first
471493 let root_borrow = self . root . borrow ( ) ;
472494 for ( tokens_hash, child_block) in & root_borrow. children {
473- queue. push_back ( ( child_block. clone ( ) , None , * tokens_hash) ) ;
495+ queue. push_back ( ( child_block. clone ( ) , HashMap :: new ( ) , * tokens_hash) ) ;
474496 }
475497 drop ( root_borrow) ;
476498
477- while let Some ( ( current_block, parent_external_hash , tokens_hash) ) = queue. pop_front ( ) {
499+ while let Some ( ( current_block, parent_hashes , tokens_hash) ) = queue. pop_front ( ) {
478500 let current_borrow = current_block. borrow ( ) ;
479501
480- // Closure to find external hash for a block in a worker's lookup
481- let find_external_hash = |worker_id : & WorkerId | {
482- self . lookup . get ( worker_id) . and_then ( |worker_blocks| {
483- worker_blocks
484- . iter ( )
485- . find ( |( _, block) | Rc :: ptr_eq ( block, & current_block) )
486- . map ( |( hash, _) | * hash)
487- } )
488- } ;
502+ // Map of this block's external hashes per worker (for children to use as parent)
503+ let mut current_external_hashes = HashMap :: new ( ) ;
489504
490505 // For each worker that has this block
491- for worker_id in & current_borrow. workers {
492- // Find the external hash for this block from the worker's lookup
493- let external_hash = find_external_hash ( worker_id) ;
494-
495- if let Some ( block_hash) = external_hash {
496- // Create a store event for this worker
497- let event = RouterEvent {
498- worker_id : * worker_id,
499- event : KvCacheEvent {
500- event_id,
501- data : KvCacheEventData :: Stored ( KvCacheStoreData {
502- parent_hash : parent_external_hash,
503- blocks : vec ! [ KvCacheStoredBlockData {
504- block_hash,
505- tokens_hash,
506- } ] ,
507- } ) ,
508- } ,
509- } ;
510- events. push ( event) ;
511- event_id += 1 ;
512- }
513- }
506+ for ( worker_id, external_hash) in & current_borrow. workers {
507+ // Get the correct parent hash for this worker
508+ let parent_hash = parent_hashes. get ( worker_id) . copied ( ) ;
509+
510+ // Create a store event for this worker
511+ let event = RouterEvent {
512+ worker_id : * worker_id,
513+ event : KvCacheEvent {
514+ event_id,
515+ data : KvCacheEventData :: Stored ( KvCacheStoreData {
516+ parent_hash,
517+ blocks : vec ! [ KvCacheStoredBlockData {
518+ block_hash: * external_hash,
519+ tokens_hash,
520+ } ] ,
521+ } ) ,
522+ } ,
523+ } ;
524+ events. push ( event) ;
525+ event_id += 1 ;
514526
515- // Add children to queue for BFS traversal
516- // We need to find any external hash for this block to use as parent
517- let any_external_hash = if !current_borrow. workers . is_empty ( ) {
518- current_borrow
519- . workers
520- . iter ( )
521- . next ( )
522- . and_then ( find_external_hash)
523- } else {
524- None
525- } ;
527+ // Track this block's external hash for this worker
528+ current_external_hashes. insert ( * worker_id, * external_hash) ;
529+ }
526530
531+ // Enqueue children with per-worker parent hashes
527532 for ( child_tokens_hash, child_block) in & current_borrow. children {
528- queue. push_back ( ( child_block. clone ( ) , any_external_hash, * child_tokens_hash) ) ;
533+ queue. push_back ( (
534+ child_block. clone ( ) ,
535+ current_external_hashes. clone ( ) ,
536+ * child_tokens_hash,
537+ ) ) ;
529538 }
530539 }
531540
@@ -657,8 +666,11 @@ impl OverlapScores {
657666 ///
658667 /// ### Arguments
659668 ///
660- /// * `workers` - A reference to a `HashSet` of `WorkerId`s.
661- pub fn update_scores ( & mut self , workers : & HashSet < WorkerId > ) {
669+ /// * `workers` - An iterator over `WorkerId` references.
670+ pub fn update_scores < ' a , I > ( & mut self , workers : I )
671+ where
672+ I : IntoIterator < Item = & ' a WorkerId > ,
673+ {
662674 for worker in workers {
663675 let score = self . scores . entry ( * worker) . or_insert ( 0 ) ;
664676 * score += 1 ;
@@ -2171,4 +2183,79 @@ mod tests {
21712183 1
21722184 ) ;
21732185 }
2186+
2187+ #[ test]
2188+ fn test_remove_worker_verifies_hash_removal ( ) {
2189+ setup ( ) ;
2190+ let mut trie = RadixTree :: new ( ) ;
2191+
2192+ let worker_0 = 0 ;
2193+ let worker_1 = 1 ;
2194+ let worker_2 = 2 ;
2195+
2196+ // Add blocks for multiple workers
2197+ trie. apply_event ( create_store_event ( worker_0, 0 , vec ! [ 1 , 2 , 3 ] , None ) )
2198+ . unwrap ( ) ;
2199+ trie. apply_event ( create_store_event ( worker_1, 0 , vec ! [ 1 , 2 , 3 ] , None ) )
2200+ . unwrap ( ) ;
2201+ trie. apply_event ( create_store_event ( worker_2, 0 , vec ! [ 1 , 4 , 5 ] , None ) )
2202+ . unwrap ( ) ;
2203+
2204+ // Verify worker_0 has 3 blocks in lookup
2205+ assert_eq ! ( trie. lookup. get( & worker_0) . unwrap( ) . len( ) , 3 ) ;
2206+
2207+ // Verify that blocks have the correct workers
2208+ let block_1 = trie
2209+ . lookup
2210+ . get ( & worker_0)
2211+ . unwrap ( )
2212+ . get ( & ExternalSequenceBlockHash ( 100 ) )
2213+ . unwrap ( ) ;
2214+ assert_eq ! ( block_1. borrow( ) . workers. len( ) , 3 ) ; // worker_0, worker_1, and worker_2 (all have hash 1)
2215+ assert ! ( block_1. borrow( ) . workers. contains_key( & worker_0) ) ;
2216+ assert ! ( block_1. borrow( ) . workers. contains_key( & worker_1) ) ;
2217+ assert ! ( block_1. borrow( ) . workers. contains_key( & worker_2) ) ;
2218+
2219+ // Remove worker_0
2220+ trie. remove_worker ( worker_0) ;
2221+
2222+ // Verify worker_0 is completely removed from lookup table
2223+ assert ! ( !trie. lookup. contains_key( & worker_0) ) ;
2224+ assert_eq ! ( trie. lookup. len( ) , 2 ) ;
2225+
2226+ // Verify that worker_0's hash is removed from the workers set
2227+ let block_1 = trie
2228+ . lookup
2229+ . get ( & worker_1)
2230+ . unwrap ( )
2231+ . get ( & ExternalSequenceBlockHash ( 100 ) )
2232+ . unwrap ( ) ;
2233+ assert_eq ! ( block_1. borrow( ) . workers. len( ) , 2 ) ; // worker_1 and worker_2 remain
2234+ assert ! ( !block_1. borrow( ) . workers. contains_key( & worker_0) ) ;
2235+ assert ! ( block_1. borrow( ) . workers. contains_key( & worker_1) ) ;
2236+ assert ! ( block_1. borrow( ) . workers. contains_key( & worker_2) ) ;
2237+
2238+ // Verify that blocks with no remaining workers have their children cleared
2239+ // This tests the optimization where empty blocks clear their children
2240+ let block_2 = trie
2241+ . lookup
2242+ . get ( & worker_1)
2243+ . unwrap ( )
2244+ . get ( & ExternalSequenceBlockHash ( 200 ) )
2245+ . unwrap ( ) ;
2246+ assert_eq ! ( block_2. borrow( ) . workers. len( ) , 1 ) ; // only worker_1
2247+ assert ! ( block_2. borrow( ) . workers. contains_key( & worker_1) ) ;
2248+
2249+ // Verify match results no longer include worker_0
2250+ let result = trie
2251+ . find_matches (
2252+ vec ! [ LocalBlockHash ( 1 ) , LocalBlockHash ( 2 ) , LocalBlockHash ( 3 ) ] ,
2253+ false ,
2254+ )
2255+ . scores ;
2256+ assert_eq ! ( result. len( ) , 2 ) ;
2257+ assert ! ( !result. contains_key( & worker_0) ) ;
2258+ assert ! ( result. contains_key( & worker_1) ) ;
2259+ assert ! ( result. contains_key( & worker_2) ) ;
2260+ }
21742261}
0 commit comments