Skip to content

Commit 16d6cf6

Browse files
committed
improve snapshot perf
Signed-off-by: Brian Larson <[email protected]>
1 parent 0f63a05 commit 16d6cf6

File tree

3 files changed

+135
-58
lines changed

3 files changed

+135
-58
lines changed

lib/llm/src/kv_router/indexer.rs

Lines changed: 129 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ use prometheus::{IntCounterVec, Opts};
4141
use serde::{Deserialize, Serialize};
4242
use std::{
4343
cell::RefCell,
44-
collections::{HashMap, HashSet, VecDeque},
44+
collections::{HashMap, VecDeque},
4545
iter,
4646
rc::Rc,
4747
sync::{Arc, OnceLock},
@@ -200,8 +200,9 @@ impl RouterEvent {
200200
struct RadixBlock {
201201
/// A map of child blocks, keyed by their local block hash.
202202
children: HashMap<LocalBlockHash, SharedRadixBlock>,
203-
/// A set of worker IDs associated with this block.
204-
workers: HashSet<WorkerId>,
203+
/// A map of worker IDs to their external sequence block hash for this block.
204+
/// The external hash is preserved to speed up snapshotting.
205+
workers: HashMap<WorkerId, ExternalSequenceBlockHash>,
205206
/// A buffer of times that this block was last traversed
206207
recent_uses: VecDeque<Instant>,
207208
}
@@ -215,7 +216,7 @@ impl RadixBlock {
215216
pub fn new() -> Self {
216217
Self {
217218
children: HashMap::new(),
218-
workers: HashSet::new(),
219+
workers: HashMap::new(),
219220
recent_uses: VecDeque::new(),
220221
}
221222
}
@@ -289,7 +290,7 @@ impl RadixTree {
289290
current_borrow.children.get(block_hash).cloned()
290291
};
291292
if let Some(block) = next_block {
292-
scores.update_scores(&block.borrow().workers);
293+
scores.update_scores(block.borrow().workers.keys());
293294

294295
if let Some(expiration_duration) = self.expiration_duration {
295296
let mut block_mut = block.borrow_mut();
@@ -380,8 +381,11 @@ impl RadixTree {
380381
}
381382
};
382383

383-
// add our worker_id to the block
384-
block.borrow_mut().workers.insert(worker_id);
384+
// add our worker_id to the block with its external hash
385+
block
386+
.borrow_mut()
387+
.workers
388+
.insert(worker_id, block_id.block_hash);
385389

386390
// add the block to the worker_id lookup table
387391
worker_lookup.insert(block_id.block_hash, block.clone());
@@ -417,7 +421,7 @@ impl RadixTree {
417421
let mut guard = entry.borrow_mut();
418422
guard.workers.remove(&worker_id);
419423
if guard.workers.is_empty() {
420-
// if no worker are using this block, that is true for all children
424+
// if no workers are using this block, that is true for all children
421425
guard.children.clear();
422426
}
423427
// remove the block from the lookup table
@@ -436,6 +440,10 @@ impl RadixTree {
436440
if let Some((_, blocks)) = self.lookup.remove_entry(&worker) {
437441
blocks.iter().for_each(|(_, block)| {
438442
block.borrow_mut().workers.remove(&worker);
443+
// If no workers are using this block, that is true for all children
444+
if block.borrow().workers.is_empty() {
445+
block.borrow_mut().children.clear();
446+
}
439447
});
440448
}
441449
}
@@ -445,14 +453,18 @@ impl RadixTree {
445453
if let Some(blocks) = self.lookup.get(&worker) {
446454
let blocks_to_clear: Vec<_> = blocks.values().collect();
447455

448-
// Remove the worker from each block's workers set
456+
// Remove the worker from each block's workers map
449457
blocks_to_clear.iter().for_each(|block| {
450458
block.borrow_mut().workers.remove(&worker);
459+
// If no workers are using this block, that is true for all children
460+
if block.borrow().workers.is_empty() {
461+
block.borrow_mut().children.clear();
462+
}
451463
});
452464

453465
// Clear the worker's blocks
454-
if let Some(worker_blocks) = self.lookup.get_mut(&worker) {
455-
worker_blocks.clear();
466+
if let Some(worker_lookup) = self.lookup.get_mut(&worker) {
467+
worker_lookup.clear();
456468
}
457469
}
458470
}
@@ -461,6 +473,11 @@ impl RadixTree {
461473
/// Uses BFS traversal to ensure that the tree reconstruction is unique,
462474
/// though the exact event ordering will be lost.
463475
pub fn dump_tree_as_events(&self) -> Vec<RouterEvent> {
476+
tracing::debug!(
477+
"Dumping radix tree as events (contains information about {:?} workers)",
478+
self.lookup.len()
479+
);
480+
464481
let mut events = Vec::new();
465482
let mut event_id = 0u64;
466483

@@ -477,53 +494,31 @@ impl RadixTree {
477494
while let Some((current_block, parent_external_hash, tokens_hash)) = queue.pop_front() {
478495
let current_borrow = current_block.borrow();
479496

480-
// Closure to find external hash for a block in a worker's lookup
481-
let find_external_hash = |worker_id: &WorkerId| {
482-
self.lookup.get(worker_id).and_then(|worker_blocks| {
483-
worker_blocks
484-
.iter()
485-
.find(|(_, block)| Rc::ptr_eq(block, &current_block))
486-
.map(|(hash, _)| *hash)
487-
})
488-
};
497+
// We need to find any external hash for this block to use as parent
498+
// when we enqueue the children.
499+
let mut any_external_hash: Option<ExternalSequenceBlockHash> = None;
489500

490501
// For each worker that has this block
491-
for worker_id in &current_borrow.workers {
492-
// Find the external hash for this block from the worker's lookup
493-
let external_hash = find_external_hash(worker_id);
494-
495-
if let Some(block_hash) = external_hash {
496-
// Create a store event for this worker
497-
let event = RouterEvent {
498-
worker_id: *worker_id,
499-
event: KvCacheEvent {
500-
event_id,
501-
data: KvCacheEventData::Stored(KvCacheStoreData {
502-
parent_hash: parent_external_hash,
503-
blocks: vec![KvCacheStoredBlockData {
504-
block_hash,
505-
tokens_hash,
506-
}],
507-
}),
508-
},
509-
};
510-
events.push(event);
511-
event_id += 1;
512-
}
502+
for (worker_id, external_hash) in &current_borrow.workers {
503+
// Create a store event for this worker
504+
let event = RouterEvent {
505+
worker_id: *worker_id,
506+
event: KvCacheEvent {
507+
event_id,
508+
data: KvCacheEventData::Stored(KvCacheStoreData {
509+
parent_hash: parent_external_hash,
510+
blocks: vec![KvCacheStoredBlockData {
511+
block_hash: *external_hash,
512+
tokens_hash,
513+
}],
514+
}),
515+
},
516+
};
517+
events.push(event);
518+
event_id += 1;
519+
any_external_hash = Some(*external_hash);
513520
}
514521

515-
// Add children to queue for BFS traversal
516-
// We need to find any external hash for this block to use as parent
517-
let any_external_hash = if !current_borrow.workers.is_empty() {
518-
current_borrow
519-
.workers
520-
.iter()
521-
.next()
522-
.and_then(find_external_hash)
523-
} else {
524-
None
525-
};
526-
527522
for (child_tokens_hash, child_block) in &current_borrow.children {
528523
queue.push_back((child_block.clone(), any_external_hash, *child_tokens_hash));
529524
}
@@ -657,8 +652,11 @@ impl OverlapScores {
657652
///
658653
/// ### Arguments
659654
///
660-
/// * `workers` - A reference to a `HashSet` of `WorkerId`s.
661-
pub fn update_scores(&mut self, workers: &HashSet<WorkerId>) {
655+
/// * `workers` - An iterator over `WorkerId` references.
656+
pub fn update_scores<'a, I>(&mut self, workers: I)
657+
where
658+
I: IntoIterator<Item = &'a WorkerId>,
659+
{
662660
for worker in workers {
663661
let score = self.scores.entry(*worker).or_insert(0);
664662
*score += 1;
@@ -2171,4 +2169,79 @@ mod tests {
21712169
1
21722170
);
21732171
}
2172+
2173+
#[test]
2174+
fn test_remove_worker_verifies_hash_removal() {
2175+
setup();
2176+
let mut trie = RadixTree::new();
2177+
2178+
let worker_0 = 0;
2179+
let worker_1 = 1;
2180+
let worker_2 = 2;
2181+
2182+
// Add blocks for multiple workers
2183+
trie.apply_event(create_store_event(worker_0, 0, vec![1, 2, 3], None))
2184+
.unwrap();
2185+
trie.apply_event(create_store_event(worker_1, 0, vec![1, 2, 3], None))
2186+
.unwrap();
2187+
trie.apply_event(create_store_event(worker_2, 0, vec![1, 4, 5], None))
2188+
.unwrap();
2189+
2190+
// Verify worker_0 has 3 blocks in lookup
2191+
assert_eq!(trie.lookup.get(&worker_0).unwrap().len(), 3);
2192+
2193+
// Verify that blocks have the correct workers
2194+
let block_1 = trie
2195+
.lookup
2196+
.get(&worker_0)
2197+
.unwrap()
2198+
.get(&ExternalSequenceBlockHash(100))
2199+
.unwrap();
2200+
assert_eq!(block_1.borrow().workers.len(), 3); // worker_0, worker_1, and worker_2 (all have hash 1)
2201+
assert!(block_1.borrow().workers.contains_key(&worker_0));
2202+
assert!(block_1.borrow().workers.contains_key(&worker_1));
2203+
assert!(block_1.borrow().workers.contains_key(&worker_2));
2204+
2205+
// Remove worker_0
2206+
trie.remove_worker(worker_0);
2207+
2208+
// Verify worker_0 is completely removed from lookup table
2209+
assert!(!trie.lookup.contains_key(&worker_0));
2210+
assert_eq!(trie.lookup.len(), 2);
2211+
2212+
// Verify that worker_0's hash is removed from the workers set
2213+
let block_1 = trie
2214+
.lookup
2215+
.get(&worker_1)
2216+
.unwrap()
2217+
.get(&ExternalSequenceBlockHash(100))
2218+
.unwrap();
2219+
assert_eq!(block_1.borrow().workers.len(), 2); // worker_1 and worker_2 remain
2220+
assert!(!block_1.borrow().workers.contains_key(&worker_0));
2221+
assert!(block_1.borrow().workers.contains_key(&worker_1));
2222+
assert!(block_1.borrow().workers.contains_key(&worker_2));
2223+
2224+
// Verify that blocks with no remaining workers have their children cleared
2225+
// This tests the optimization where empty blocks clear their children
2226+
let block_2 = trie
2227+
.lookup
2228+
.get(&worker_1)
2229+
.unwrap()
2230+
.get(&ExternalSequenceBlockHash(200))
2231+
.unwrap();
2232+
assert_eq!(block_2.borrow().workers.len(), 1); // only worker_1
2233+
assert!(block_2.borrow().workers.contains_key(&worker_1));
2234+
2235+
// Verify match results no longer include worker_0
2236+
let result = trie
2237+
.find_matches(
2238+
vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
2239+
false,
2240+
)
2241+
.scores;
2242+
assert_eq!(result.len(), 2);
2243+
assert!(!result.contains_key(&worker_0));
2244+
assert!(result.contains_key(&worker_1));
2245+
assert!(result.contains_key(&worker_2));
2246+
}
21742247
}

lib/llm/src/kv_router/subscriber.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,8 @@ async fn purge_then_snapshot(
363363
// Purge before snapshot ensures new/warm-restarted routers won't replay already-acknowledged messages.
364364
// Since KV events are idempotent, this ordering reduces unnecessary reprocessing while maintaining
365365
// at-least-once delivery guarantees. The snapshot will capture the clean state after purge.
366+
tracing::info!("Purging acknowledged messages and performing snapshot of radix tree");
367+
let start_time = std::time::Instant::now();
366368

367369
// First, purge acknowledged messages from the stream
368370
nats_queue.purge_acknowledged().await?;
@@ -395,9 +397,10 @@ async fn purge_then_snapshot(
395397
.map_err(|e| anyhow::anyhow!("Failed to upload snapshot: {e:?}"))?;
396398

397399
tracing::info!(
398-
"Successfully uploaded radix tree snapshot with {} events to bucket {}",
400+
"Successfully performed snapshot of radix tree with {} events to bucket {} in {}ms",
399401
events.len(),
400-
resources.bucket_name
402+
resources.bucket_name,
403+
start_time.elapsed().as_millis()
401404
);
402405

403406
Ok(())

lib/runtime/src/transports/nats.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,7 @@ impl Client {
257257
tokio::io::copy(&mut obj_reader, &mut buffer)
258258
.await
259259
.map_err(|e| anyhow::anyhow!("Failed reading object data: {e}"))?;
260+
tracing::debug!("Downloaded {} bytes from {bucket_name}/{key}", buffer.len());
260261

261262
// Deserialize from bincode
262263
let data = bincode::deserialize(&buffer)

0 commit comments

Comments (0)