@@ -41,7 +41,7 @@ use prometheus::{IntCounterVec, Opts};
41
41
use serde:: { Deserialize , Serialize } ;
42
42
use std:: {
43
43
cell:: RefCell ,
44
- collections:: { HashMap , HashSet , VecDeque } ,
44
+ collections:: { HashMap , VecDeque } ,
45
45
iter,
46
46
rc:: Rc ,
47
47
sync:: { Arc , OnceLock } ,
@@ -200,8 +200,9 @@ impl RouterEvent {
200
200
struct RadixBlock {
201
201
/// A map of child blocks, keyed by their local block hash.
202
202
children : HashMap < LocalBlockHash , SharedRadixBlock > ,
203
- /// A set of worker IDs associated with this block.
204
- workers : HashSet < WorkerId > ,
203
+ /// A map of worker IDs to their external sequence block hash for this block.
204
+ /// The external hash is preserved to speed up snapshotting.
205
+ workers : HashMap < WorkerId , ExternalSequenceBlockHash > ,
205
206
/// A buffer of times that this block was last traversed
206
207
recent_uses : VecDeque < Instant > ,
207
208
}
@@ -215,7 +216,7 @@ impl RadixBlock {
215
216
pub fn new ( ) -> Self {
216
217
Self {
217
218
children : HashMap :: new ( ) ,
218
- workers : HashSet :: new ( ) ,
219
+ workers : HashMap :: new ( ) ,
219
220
recent_uses : VecDeque :: new ( ) ,
220
221
}
221
222
}
@@ -289,7 +290,7 @@ impl RadixTree {
289
290
current_borrow. children . get ( block_hash) . cloned ( )
290
291
} ;
291
292
if let Some ( block) = next_block {
292
- scores. update_scores ( & block. borrow ( ) . workers ) ;
293
+ scores. update_scores ( block. borrow ( ) . workers . keys ( ) ) ;
293
294
294
295
if let Some ( expiration_duration) = self . expiration_duration {
295
296
let mut block_mut = block. borrow_mut ( ) ;
@@ -380,8 +381,11 @@ impl RadixTree {
380
381
}
381
382
} ;
382
383
383
- // add our worker_id to the block
384
- block. borrow_mut ( ) . workers . insert ( worker_id) ;
384
+ // add our worker_id to the block with its external hash
385
+ block
386
+ . borrow_mut ( )
387
+ . workers
388
+ . insert ( worker_id, block_id. block_hash ) ;
385
389
386
390
// add the block to the worker_id lookup table
387
391
worker_lookup. insert ( block_id. block_hash , block. clone ( ) ) ;
@@ -417,7 +421,7 @@ impl RadixTree {
417
421
let mut guard = entry. borrow_mut ( ) ;
418
422
guard. workers . remove ( & worker_id) ;
419
423
if guard. workers . is_empty ( ) {
420
- // if no worker are using this block, that is true for all children
424
+ // if no workers are using this block, that is true for all children
421
425
guard. children . clear ( ) ;
422
426
}
423
427
// remove the block from the lookup table
@@ -436,6 +440,10 @@ impl RadixTree {
436
440
if let Some ( ( _, blocks) ) = self . lookup . remove_entry ( & worker) {
437
441
blocks. iter ( ) . for_each ( |( _, block) | {
438
442
block. borrow_mut ( ) . workers . remove ( & worker) ;
443
+ // If no workers are using this block, that is true for all children
444
+ if block. borrow ( ) . workers . is_empty ( ) {
445
+ block. borrow_mut ( ) . children . clear ( ) ;
446
+ }
439
447
} ) ;
440
448
}
441
449
}
@@ -445,14 +453,18 @@ impl RadixTree {
445
453
if let Some ( blocks) = self . lookup . get ( & worker) {
446
454
let blocks_to_clear: Vec < _ > = blocks. values ( ) . collect ( ) ;
447
455
448
- // Remove the worker from each block's workers set
456
+ // Remove the worker from each block's workers map
449
457
blocks_to_clear. iter ( ) . for_each ( |block| {
450
458
block. borrow_mut ( ) . workers . remove ( & worker) ;
459
+ // If no workers are using this block, that is true for all children
460
+ if block. borrow ( ) . workers . is_empty ( ) {
461
+ block. borrow_mut ( ) . children . clear ( ) ;
462
+ }
451
463
} ) ;
452
464
453
465
// Clear the worker's blocks
454
- if let Some ( worker_blocks ) = self . lookup . get_mut ( & worker) {
455
- worker_blocks . clear ( ) ;
466
+ if let Some ( worker_lookup ) = self . lookup . get_mut ( & worker) {
467
+ worker_lookup . clear ( ) ;
456
468
}
457
469
}
458
470
}
@@ -461,6 +473,11 @@ impl RadixTree {
461
473
/// Uses BFS traversal to ensure that the tree reconstruction is unique,
462
474
/// though the exact event ordering will be lost.
463
475
pub fn dump_tree_as_events ( & self ) -> Vec < RouterEvent > {
476
+ tracing:: debug!(
477
+ "Dumping radix tree as events (contains information about {:?} workers)" ,
478
+ self . lookup. len( )
479
+ ) ;
480
+
464
481
let mut events = Vec :: new ( ) ;
465
482
let mut event_id = 0u64 ;
466
483
@@ -477,53 +494,31 @@ impl RadixTree {
477
494
while let Some ( ( current_block, parent_external_hash, tokens_hash) ) = queue. pop_front ( ) {
478
495
let current_borrow = current_block. borrow ( ) ;
479
496
480
- // Closure to find external hash for a block in a worker's lookup
481
- let find_external_hash = |worker_id : & WorkerId | {
482
- self . lookup . get ( worker_id) . and_then ( |worker_blocks| {
483
- worker_blocks
484
- . iter ( )
485
- . find ( |( _, block) | Rc :: ptr_eq ( block, & current_block) )
486
- . map ( |( hash, _) | * hash)
487
- } )
488
- } ;
497
+ // We need to find any external hash for this block to use as parent
498
+ // when we enqueue the children.
499
+ let mut any_external_hash: Option < ExternalSequenceBlockHash > = None ;
489
500
490
501
// For each worker that has this block
491
- for worker_id in & current_borrow. workers {
492
- // Find the external hash for this block from the worker's lookup
493
- let external_hash = find_external_hash ( worker_id) ;
494
-
495
- if let Some ( block_hash) = external_hash {
496
- // Create a store event for this worker
497
- let event = RouterEvent {
498
- worker_id : * worker_id,
499
- event : KvCacheEvent {
500
- event_id,
501
- data : KvCacheEventData :: Stored ( KvCacheStoreData {
502
- parent_hash : parent_external_hash,
503
- blocks : vec ! [ KvCacheStoredBlockData {
504
- block_hash,
505
- tokens_hash,
506
- } ] ,
507
- } ) ,
508
- } ,
509
- } ;
510
- events. push ( event) ;
511
- event_id += 1 ;
512
- }
502
+ for ( worker_id, external_hash) in & current_borrow. workers {
503
+ // Create a store event for this worker
504
+ let event = RouterEvent {
505
+ worker_id : * worker_id,
506
+ event : KvCacheEvent {
507
+ event_id,
508
+ data : KvCacheEventData :: Stored ( KvCacheStoreData {
509
+ parent_hash : parent_external_hash,
510
+ blocks : vec ! [ KvCacheStoredBlockData {
511
+ block_hash: * external_hash,
512
+ tokens_hash,
513
+ } ] ,
514
+ } ) ,
515
+ } ,
516
+ } ;
517
+ events. push ( event) ;
518
+ event_id += 1 ;
519
+ any_external_hash = Some ( * external_hash) ;
513
520
}
514
521
515
- // Add children to queue for BFS traversal
516
- // We need to find any external hash for this block to use as parent
517
- let any_external_hash = if !current_borrow. workers . is_empty ( ) {
518
- current_borrow
519
- . workers
520
- . iter ( )
521
- . next ( )
522
- . and_then ( find_external_hash)
523
- } else {
524
- None
525
- } ;
526
-
527
522
for ( child_tokens_hash, child_block) in & current_borrow. children {
528
523
queue. push_back ( ( child_block. clone ( ) , any_external_hash, * child_tokens_hash) ) ;
529
524
}
@@ -657,8 +652,11 @@ impl OverlapScores {
657
652
///
658
653
/// ### Arguments
659
654
///
660
- /// * `workers` - A reference to a `HashSet` of `WorkerId`s.
661
- pub fn update_scores ( & mut self , workers : & HashSet < WorkerId > ) {
655
+ /// * `workers` - An iterator over `WorkerId` references.
656
+ pub fn update_scores < ' a , I > ( & mut self , workers : I )
657
+ where
658
+ I : IntoIterator < Item = & ' a WorkerId > ,
659
+ {
662
660
for worker in workers {
663
661
let score = self . scores . entry ( * worker) . or_insert ( 0 ) ;
664
662
* score += 1 ;
@@ -2171,4 +2169,79 @@ mod tests {
2171
2169
1
2172
2170
) ;
2173
2171
}
2172
+
2173
+ #[ test]
2174
+ fn test_remove_worker_verifies_hash_removal ( ) {
2175
+ setup ( ) ;
2176
+ let mut trie = RadixTree :: new ( ) ;
2177
+
2178
+ let worker_0 = 0 ;
2179
+ let worker_1 = 1 ;
2180
+ let worker_2 = 2 ;
2181
+
2182
+ // Add blocks for multiple workers
2183
+ trie. apply_event ( create_store_event ( worker_0, 0 , vec ! [ 1 , 2 , 3 ] , None ) )
2184
+ . unwrap ( ) ;
2185
+ trie. apply_event ( create_store_event ( worker_1, 0 , vec ! [ 1 , 2 , 3 ] , None ) )
2186
+ . unwrap ( ) ;
2187
+ trie. apply_event ( create_store_event ( worker_2, 0 , vec ! [ 1 , 4 , 5 ] , None ) )
2188
+ . unwrap ( ) ;
2189
+
2190
+ // Verify worker_0 has 3 blocks in lookup
2191
+ assert_eq ! ( trie. lookup. get( & worker_0) . unwrap( ) . len( ) , 3 ) ;
2192
+
2193
+ // Verify that blocks have the correct workers
2194
+ let block_1 = trie
2195
+ . lookup
2196
+ . get ( & worker_0)
2197
+ . unwrap ( )
2198
+ . get ( & ExternalSequenceBlockHash ( 100 ) )
2199
+ . unwrap ( ) ;
2200
+ assert_eq ! ( block_1. borrow( ) . workers. len( ) , 3 ) ; // worker_0, worker_1, and worker_2 (all have hash 1)
2201
+ assert ! ( block_1. borrow( ) . workers. contains_key( & worker_0) ) ;
2202
+ assert ! ( block_1. borrow( ) . workers. contains_key( & worker_1) ) ;
2203
+ assert ! ( block_1. borrow( ) . workers. contains_key( & worker_2) ) ;
2204
+
2205
+ // Remove worker_0
2206
+ trie. remove_worker ( worker_0) ;
2207
+
2208
+ // Verify worker_0 is completely removed from lookup table
2209
+ assert ! ( !trie. lookup. contains_key( & worker_0) ) ;
2210
+ assert_eq ! ( trie. lookup. len( ) , 2 ) ;
2211
+
2212
+ // Verify that worker_0's hash is removed from the workers map
2213
+ let block_1 = trie
2214
+ . lookup
2215
+ . get ( & worker_1)
2216
+ . unwrap ( )
2217
+ . get ( & ExternalSequenceBlockHash ( 100 ) )
2218
+ . unwrap ( ) ;
2219
+ assert_eq ! ( block_1. borrow( ) . workers. len( ) , 2 ) ; // worker_1 and worker_2 remain
2220
+ assert ! ( !block_1. borrow( ) . workers. contains_key( & worker_0) ) ;
2221
+ assert ! ( block_1. borrow( ) . workers. contains_key( & worker_1) ) ;
2222
+ assert ! ( block_1. borrow( ) . workers. contains_key( & worker_2) ) ;
2223
+
2224
+ // Verify that the block worker_1 stores under hash 200 now lists only worker_1
2225
+ // (note: this test does not assert that emptied blocks clear their children)
2226
+ let block_2 = trie
2227
+ . lookup
2228
+ . get ( & worker_1)
2229
+ . unwrap ( )
2230
+ . get ( & ExternalSequenceBlockHash ( 200 ) )
2231
+ . unwrap ( ) ;
2232
+ assert_eq ! ( block_2. borrow( ) . workers. len( ) , 1 ) ; // only worker_1
2233
+ assert ! ( block_2. borrow( ) . workers. contains_key( & worker_1) ) ;
2234
+
2235
+ // Verify match results no longer include worker_0
2236
+ let result = trie
2237
+ . find_matches (
2238
+ vec ! [ LocalBlockHash ( 1 ) , LocalBlockHash ( 2 ) , LocalBlockHash ( 3 ) ] ,
2239
+ false ,
2240
+ )
2241
+ . scores ;
2242
+ assert_eq ! ( result. len( ) , 2 ) ;
2243
+ assert ! ( !result. contains_key( & worker_0) ) ;
2244
+ assert ! ( result. contains_key( & worker_1) ) ;
2245
+ assert ! ( result. contains_key( & worker_2) ) ;
2246
+ }
2174
2247
}
0 commit comments