
Commit a99b655

blarson-b10 and PeaBrane authored and committed
perf: Improve performance of snapshot using a reverse lookup from block -> external hash (#3370)
Signed-off-by: Brian Larson <[email protected]>
Signed-off-by: PeaBrane <[email protected]>
Co-authored-by: PeaBrane <[email protected]>
1 parent 4fbc247 commit a99b655

File tree

lib/llm/src/kv_router/indexer.rs
lib/llm/src/kv_router/subscriber.rs
lib/runtime/src/transports/nats.rs

3 files changed: +154, -63 lines changed
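Before the diffs, a minimal sketch of the data-structure change behind the speedup. It uses simplified stand-in newtypes rather than the real types from lib/llm/src/kv_router/indexer.rs: each radix block now remembers, per worker, the external hash it was stored under, so the snapshot path can read that hash directly instead of scanning the worker's whole lookup table to recover it.

    use std::collections::HashMap;

    #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
    struct WorkerId(u64);

    #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
    struct ExternalSequenceBlockHash(u64);

    struct RadixBlockSketch {
        // Before this commit the field was a HashSet<WorkerId>, and the external
        // hash had to be found by searching every (hash, block) pair in the
        // worker's lookup table (O(blocks) per block during a snapshot).
        // Storing the hash alongside the worker id makes the lookup O(1).
        workers: HashMap<WorkerId, ExternalSequenceBlockHash>,
    }

    impl RadixBlockSketch {
        fn external_hash_for(&self, worker: WorkerId) -> Option<ExternalSequenceBlockHash> {
            self.workers.get(&worker).copied()
        }
    }

    fn main() {
        let mut block = RadixBlockSketch { workers: HashMap::new() };
        block.workers.insert(WorkerId(0), ExternalSequenceBlockHash(100));
        assert_eq!(
            block.external_hash_for(WorkerId(0)),
            Some(ExternalSequenceBlockHash(100))
        );
    }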

lib/llm/src/kv_router/indexer.rs

Lines changed: 148 additions & 61 deletions
@@ -41,7 +41,7 @@ use prometheus::{IntCounterVec, Opts};
 use serde::{Deserialize, Serialize};
 use std::{
     cell::RefCell,
-    collections::{HashMap, HashSet, VecDeque},
+    collections::{HashMap, VecDeque},
     iter,
     rc::Rc,
     sync::{Arc, OnceLock},
@@ -200,8 +200,9 @@ impl RouterEvent {
 struct RadixBlock {
     /// A map of child blocks, keyed by their local block hash.
     children: HashMap<LocalBlockHash, SharedRadixBlock>,
-    /// A set of worker IDs associated with this block.
-    workers: HashSet<WorkerId>,
+    /// A map of worker IDs to their external sequence block hash for this block.
+    /// The external hash is preserved to speed up snapshotting.
+    workers: HashMap<WorkerId, ExternalSequenceBlockHash>,
     /// A buffer of times that this block was last traversed
     recent_uses: VecDeque<Instant>,
 }
@@ -215,7 +216,7 @@ impl RadixBlock {
     pub fn new() -> Self {
         Self {
             children: HashMap::new(),
-            workers: HashSet::new(),
+            workers: HashMap::new(),
             recent_uses: VecDeque::new(),
         }
     }
@@ -289,7 +290,7 @@ impl RadixTree {
                 current_borrow.children.get(block_hash).cloned()
             };
             if let Some(block) = next_block {
-                scores.update_scores(&block.borrow().workers);
+                scores.update_scores(block.borrow().workers.keys());

                 if let Some(expiration_duration) = self.expiration_duration {
                     let mut block_mut = block.borrow_mut();
@@ -380,8 +381,11 @@ impl RadixTree {
                 }
             };

-            // add our worker_id to the block
-            block.borrow_mut().workers.insert(worker_id);
+            // add our worker_id to the block with its external hash
+            block
+                .borrow_mut()
+                .workers
+                .insert(worker_id, block_id.block_hash);

             // add the block to the worker_id lookup table
             worker_lookup.insert(block_id.block_hash, block.clone());
@@ -417,7 +421,7 @@ impl RadixTree {
                 let mut guard = entry.borrow_mut();
                 guard.workers.remove(&worker_id);
                 if guard.workers.is_empty() {
-                    // if no worker are using this block, that is true for all children
+                    // if no workers are using this block, that is true for all children
                     guard.children.clear();
                 }
                 // remove the block from the lookup table
@@ -436,6 +440,10 @@ impl RadixTree {
         if let Some((_, blocks)) = self.lookup.remove_entry(&worker) {
             blocks.iter().for_each(|(_, block)| {
                 block.borrow_mut().workers.remove(&worker);
+                // If no workers are using this block, that is true for all children
+                if block.borrow().workers.is_empty() {
+                    block.borrow_mut().children.clear();
+                }
             });
         }
     }
@@ -445,14 +453,18 @@ impl RadixTree {
         if let Some(blocks) = self.lookup.get(&worker) {
             let blocks_to_clear: Vec<_> = blocks.values().collect();

-            // Remove the worker from each block's workers set
+            // Remove the worker from each block's workers map
             blocks_to_clear.iter().for_each(|block| {
                 block.borrow_mut().workers.remove(&worker);
+                // If no workers are using this block, that is true for all children
+                if block.borrow().workers.is_empty() {
+                    block.borrow_mut().children.clear();
+                }
             });

             // Clear the worker's blocks
-            if let Some(worker_blocks) = self.lookup.get_mut(&worker) {
-                worker_blocks.clear();
+            if let Some(worker_lookup) = self.lookup.get_mut(&worker) {
+                worker_lookup.clear();
             }
         }
     }
@@ -461,71 +473,68 @@ impl RadixTree {
     /// Uses BFS traversal to ensure that the tree reconstruction is unique,
     /// though the exact event ordering will be lost.
     pub fn dump_tree_as_events(&self) -> Vec<RouterEvent> {
+        tracing::debug!(
+            "Dumping radix tree as events (contains information about {:?} workers)",
+            self.lookup.len()
+        );
+
         let mut events = Vec::new();
         let mut event_id = 0u64;

-        // BFS queue: (current_block, parent_external_hash, tokens_hash)
-        let mut queue = VecDeque::new();
+        // BFS queue: (current_block, parent_hashes_per_worker, tokens_hash)
+        // parent_hashes_per_worker maps WorkerId -> ExternalSequenceBlockHash
+        let mut queue: VecDeque<(
+            SharedRadixBlock,
+            HashMap<WorkerId, ExternalSequenceBlockHash>,
+            LocalBlockHash,
+        )> = VecDeque::new();

         // Process root's children first
         let root_borrow = self.root.borrow();
         for (tokens_hash, child_block) in &root_borrow.children {
-            queue.push_back((child_block.clone(), None, *tokens_hash));
+            queue.push_back((child_block.clone(), HashMap::new(), *tokens_hash));
         }
         drop(root_borrow);

-        while let Some((current_block, parent_external_hash, tokens_hash)) = queue.pop_front() {
+        while let Some((current_block, parent_hashes, tokens_hash)) = queue.pop_front() {
             let current_borrow = current_block.borrow();

-            // Closure to find external hash for a block in a worker's lookup
-            let find_external_hash = |worker_id: &WorkerId| {
-                self.lookup.get(worker_id).and_then(|worker_blocks| {
-                    worker_blocks
-                        .iter()
-                        .find(|(_, block)| Rc::ptr_eq(block, &current_block))
-                        .map(|(hash, _)| *hash)
-                })
-            };
+            // Map of this block's external hashes per worker (for children to use as parent)
+            let mut current_external_hashes = HashMap::new();

             // For each worker that has this block
-            for worker_id in &current_borrow.workers {
-                // Find the external hash for this block from the worker's lookup
-                let external_hash = find_external_hash(worker_id);
-
-                if let Some(block_hash) = external_hash {
-                    // Create a store event for this worker
-                    let event = RouterEvent {
-                        worker_id: *worker_id,
-                        event: KvCacheEvent {
-                            event_id,
-                            data: KvCacheEventData::Stored(KvCacheStoreData {
-                                parent_hash: parent_external_hash,
-                                blocks: vec![KvCacheStoredBlockData {
-                                    block_hash,
-                                    tokens_hash,
-                                }],
-                            }),
-                        },
-                    };
-                    events.push(event);
-                    event_id += 1;
-                }
-            }
+            for (worker_id, external_hash) in &current_borrow.workers {
+                // Get the correct parent hash for this worker
+                let parent_hash = parent_hashes.get(worker_id).copied();
+
+                // Create a store event for this worker
+                let event = RouterEvent {
+                    worker_id: *worker_id,
+                    event: KvCacheEvent {
+                        event_id,
+                        data: KvCacheEventData::Stored(KvCacheStoreData {
+                            parent_hash,
+                            blocks: vec![KvCacheStoredBlockData {
+                                block_hash: *external_hash,
+                                tokens_hash,
+                            }],
+                        }),
+                    },
+                };
+                events.push(event);
+                event_id += 1;

-            // Add children to queue for BFS traversal
-            // We need to find any external hash for this block to use as parent
-            let any_external_hash = if !current_borrow.workers.is_empty() {
-                current_borrow
-                    .workers
-                    .iter()
-                    .next()
-                    .and_then(find_external_hash)
-            } else {
-                None
-            };
+                // Track this block's external hash for this worker
+                current_external_hashes.insert(*worker_id, *external_hash);
+            }

+            // Enqueue children with per-worker parent hashes
             for (child_tokens_hash, child_block) in &current_borrow.children {
-                queue.push_back((child_block.clone(), any_external_hash, *child_tokens_hash));
+                queue.push_back((
+                    child_block.clone(),
+                    current_external_hashes.clone(),
+                    *child_tokens_hash,
+                ));
             }
         }

@@ -657,8 +666,11 @@ impl OverlapScores {
     ///
     /// ### Arguments
     ///
-    /// * `workers` - A reference to a `HashSet` of `WorkerId`s.
-    pub fn update_scores(&mut self, workers: &HashSet<WorkerId>) {
+    /// * `workers` - An iterator over `WorkerId` references.
+    pub fn update_scores<'a, I>(&mut self, workers: I)
+    where
+        I: IntoIterator<Item = &'a WorkerId>,
+    {
         for worker in workers {
             let score = self.scores.entry(*worker).or_insert(0);
             *score += 1;
@@ -2171,4 +2183,79 @@ mod tests {
             1
         );
     }
+
+    #[test]
+    fn test_remove_worker_verifies_hash_removal() {
+        setup();
+        let mut trie = RadixTree::new();
+
+        let worker_0 = 0;
+        let worker_1 = 1;
+        let worker_2 = 2;
+
+        // Add blocks for multiple workers
+        trie.apply_event(create_store_event(worker_0, 0, vec![1, 2, 3], None))
+            .unwrap();
+        trie.apply_event(create_store_event(worker_1, 0, vec![1, 2, 3], None))
+            .unwrap();
+        trie.apply_event(create_store_event(worker_2, 0, vec![1, 4, 5], None))
+            .unwrap();
+
+        // Verify worker_0 has 3 blocks in lookup
+        assert_eq!(trie.lookup.get(&worker_0).unwrap().len(), 3);
+
+        // Verify that blocks have the correct workers
+        let block_1 = trie
+            .lookup
+            .get(&worker_0)
+            .unwrap()
+            .get(&ExternalSequenceBlockHash(100))
+            .unwrap();
+        assert_eq!(block_1.borrow().workers.len(), 3); // worker_0, worker_1, and worker_2 (all have hash 1)
+        assert!(block_1.borrow().workers.contains_key(&worker_0));
+        assert!(block_1.borrow().workers.contains_key(&worker_1));
+        assert!(block_1.borrow().workers.contains_key(&worker_2));
+
+        // Remove worker_0
+        trie.remove_worker(worker_0);
+
+        // Verify worker_0 is completely removed from lookup table
+        assert!(!trie.lookup.contains_key(&worker_0));
+        assert_eq!(trie.lookup.len(), 2);
+
+        // Verify that worker_0's hash is removed from the workers map
+        let block_1 = trie
+            .lookup
+            .get(&worker_1)
+            .unwrap()
+            .get(&ExternalSequenceBlockHash(100))
+            .unwrap();
+        assert_eq!(block_1.borrow().workers.len(), 2); // worker_1 and worker_2 remain
+        assert!(!block_1.borrow().workers.contains_key(&worker_0));
+        assert!(block_1.borrow().workers.contains_key(&worker_1));
+        assert!(block_1.borrow().workers.contains_key(&worker_2));
+
+        // Verify that blocks with no remaining workers have their children cleared
+        // This tests the optimization where empty blocks clear their children
+        let block_2 = trie
+            .lookup
+            .get(&worker_1)
+            .unwrap()
+            .get(&ExternalSequenceBlockHash(200))
+            .unwrap();
+        assert_eq!(block_2.borrow().workers.len(), 1); // only worker_1
+        assert!(block_2.borrow().workers.contains_key(&worker_1));
+
+        // Verify match results no longer include worker_0
+        let result = trie
+            .find_matches(
+                vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
+                false,
+            )
+            .scores;
+        assert_eq!(result.len(), 2);
+        assert!(!result.contains_key(&worker_0));
+        assert!(result.contains_key(&worker_1));
+        assert!(result.contains_key(&worker_2));
+    }
 }
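
The `update_scores` change above generalizes the argument from `&HashSet<WorkerId>` to any iterator of `&WorkerId`, which is what lets `find_matches` pass `workers.keys()` from the new `HashMap` without allocating. A minimal sketch of that signature with simplified stand-in types (`OverlapScoresSketch` here is hypothetical, not the real `OverlapScores`):

    use std::collections::{HashMap, HashSet};

    type WorkerId = u64;

    #[derive(Default)]
    struct OverlapScoresSketch {
        scores: HashMap<WorkerId, u32>,
    }

    impl OverlapScoresSketch {
        // Same shape as the new signature in the diff: any iterator of &WorkerId
        // is accepted, so callers can pass a HashSet reference or HashMap::keys().
        fn update_scores<'a, I>(&mut self, workers: I)
        where
            I: IntoIterator<Item = &'a WorkerId>,
        {
            for worker in workers {
                *self.scores.entry(*worker).or_insert(0) += 1;
            }
        }
    }

    fn main() {
        let mut scores = OverlapScoresSketch::default();

        let set: HashSet<WorkerId> = [0, 1].into_iter().collect();
        scores.update_scores(&set); // &HashSet<WorkerId> yields &WorkerId items

        let map: HashMap<WorkerId, &str> = [(1, "a"), (2, "b")].into_iter().collect();
        scores.update_scores(map.keys()); // HashMap::keys() also yields &WorkerId

        assert_eq!(scores.scores[&1], 2); // worker 1 was counted by both calls
    }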

lib/llm/src/kv_router/subscriber.rs

Lines changed: 5 additions & 2 deletions
@@ -365,6 +365,8 @@ async fn purge_then_snapshot(
     // Purge before snapshot ensures new/warm-restarted routers won't replay already-acknowledged messages.
     // Since KV events are idempotent, this ordering reduces unnecessary reprocessing while maintaining
     // at-least-once delivery guarantees. The snapshot will capture the clean state after purge.
+    tracing::info!("Purging acknowledged messages and performing snapshot of radix tree");
+    let start_time = std::time::Instant::now();

     // First, purge acknowledged messages from the stream
     nats_queue.purge_acknowledged().await?;
@@ -397,9 +399,10 @@ async fn purge_then_snapshot(
         .map_err(|e| anyhow::anyhow!("Failed to upload snapshot: {e:?}"))?;

     tracing::info!(
-        "Successfully uploaded radix tree snapshot with {} events to bucket {}",
+        "Successfully performed snapshot of radix tree with {} events to bucket {} in {}ms",
         events.len(),
-        resources.bucket_name
+        resources.bucket_name,
+        start_time.elapsed().as_millis()
     );

     Ok(())
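
The comment at the top of this hunk carries the ordering argument: purge first so a warm-restarted router does not replay already-acknowledged messages, then snapshot the clean post-purge state, timing the whole operation. A minimal, self-contained sketch of that control flow; `purge_acknowledged`, `dump_tree_as_events`, and `upload_snapshot` below are hypothetical stand-ins, not the real dynamo APIs:

    use std::time::Instant;

    fn purge_acknowledged() -> Result<(), String> {
        Ok(()) // stand-in: drop messages every router has already acknowledged
    }

    fn dump_tree_as_events() -> Vec<u64> {
        vec![1, 2, 3] // stand-in: the radix tree serialized as store events
    }

    fn upload_snapshot(events: &[u64]) -> Result<(), String> {
        let _ = events;
        Ok(()) // stand-in: write the snapshot to the object store bucket
    }

    fn purge_then_snapshot() -> Result<(), String> {
        let start_time = Instant::now();

        // Purge first; the snapshot then captures the clean post-purge state,
        // and idempotent KV events keep at-least-once delivery safe.
        purge_acknowledged()?;
        let events = dump_tree_as_events();
        upload_snapshot(&events)?;

        println!(
            "snapshot of {} events took {}ms",
            events.len(),
            start_time.elapsed().as_millis()
        );
        Ok(())
    }

    fn main() {
        purge_then_snapshot().unwrap();
    }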

lib/runtime/src/transports/nats.rs

Lines changed: 1 addition & 0 deletions
@@ -257,6 +257,7 @@ impl Client {
         tokio::io::copy(&mut obj_reader, &mut buffer)
             .await
             .map_err(|e| anyhow::anyhow!("Failed reading object data: {e}"))?;
+        tracing::debug!("Downloaded {} bytes from {bucket_name}/{key}", buffer.len());

         // Deserialize from bincode
         let data = bincode::deserialize(&buffer)
