209 changes: 148 additions & 61 deletions lib/llm/src/kv_router/indexer.rs
@@ -41,7 +41,7 @@ use prometheus::{IntCounterVec, Opts};
use serde::{Deserialize, Serialize};
use std::{
cell::RefCell,
collections::{HashMap, HashSet, VecDeque},
collections::{HashMap, VecDeque},
iter,
rc::Rc,
sync::{Arc, OnceLock},
@@ -200,8 +200,9 @@ impl RouterEvent {
struct RadixBlock {
/// A map of child blocks, keyed by their local block hash.
children: HashMap<LocalBlockHash, SharedRadixBlock>,
/// A set of worker IDs associated with this block.
workers: HashSet<WorkerId>,
/// A map of worker IDs to their external sequence block hash for this block.
/// The external hash is preserved to speed up snapshotting.
workers: HashMap<WorkerId, ExternalSequenceBlockHash>,
/// A buffer of times that this block was last traversed
recent_uses: VecDeque<Instant>,
}
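
Switching workers from a HashSet to a HashMap keyed by worker id is what makes snapshotting cheaper: the external hash for a (worker, block) pair can now be read directly instead of being recovered by scanning that worker's lookup table and comparing Rc pointers, as the old dump path did. A minimal sketch of the lookup this enables, using simplified stand-in types rather than the crate's own:

use std::collections::HashMap;

type WorkerId = u64;

#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
struct ExternalSequenceBlockHash(u64);

struct Block {
    // New layout: each worker's external hash is stored alongside its id.
    workers: HashMap<WorkerId, ExternalSequenceBlockHash>,
}

fn external_hash_for(block: &Block, worker: WorkerId) -> Option<ExternalSequenceBlockHash> {
    // O(1) read; the previous layout required a reverse scan over the worker's
    // lookup table (comparing Rc pointers) to recover this value.
    block.workers.get(&worker).copied()
}

fn main() {
    let block = Block {
        workers: HashMap::from([(7, ExternalSequenceBlockHash(0xAB))]),
    };
    assert_eq!(external_hash_for(&block, 7), Some(ExternalSequenceBlockHash(0xAB)));
    assert_eq!(external_hash_for(&block, 8), None);
}
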
@@ -215,7 +216,7 @@ impl RadixBlock {
pub fn new() -> Self {
Self {
children: HashMap::new(),
workers: HashSet::new(),
workers: HashMap::new(),
recent_uses: VecDeque::new(),
}
}
Expand Down Expand Up @@ -289,7 +290,7 @@ impl RadixTree {
current_borrow.children.get(block_hash).cloned()
};
if let Some(block) = next_block {
scores.update_scores(&block.borrow().workers);
scores.update_scores(block.borrow().workers.keys());

if let Some(expiration_duration) = self.expiration_duration {
let mut block_mut = block.borrow_mut();
@@ -380,8 +381,11 @@ impl RadixTree {
}
};

// add our worker_id to the block
block.borrow_mut().workers.insert(worker_id);
// add our worker_id to the block with its external hash
block
.borrow_mut()
.workers
.insert(worker_id, block_id.block_hash);

// add the block to the worker_id lookup table
worker_lookup.insert(block_id.block_hash, block.clone());
@@ -417,7 +421,7 @@ impl RadixTree {
let mut guard = entry.borrow_mut();
guard.workers.remove(&worker_id);
if guard.workers.is_empty() {
// if no worker are using this block, that is true for all children
// if no workers are using this block, that is true for all children
guard.children.clear();
}
// remove the block from the lookup table
@@ -436,6 +440,10 @@ impl RadixTree {
if let Some((_, blocks)) = self.lookup.remove_entry(&worker) {
blocks.iter().for_each(|(_, block)| {
block.borrow_mut().workers.remove(&worker);
// If no workers are using this block, that is true for all children
if block.borrow().workers.is_empty() {
block.borrow_mut().children.clear();
}
});
}
}
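
Clearing children here leans on the invariant stated in the comment: a worker that stores a child block also stores its parent, so a node whose workers map is empty cannot have any descendant still in use, and dropping the children map releases the subtree's remaining Rc references. A small standalone sketch (hypothetical simplified node type, not the crate's RadixBlock) of that release:

use std::{cell::RefCell, collections::HashMap, rc::Rc};

type Node = Rc<RefCell<NodeInner>>;

#[derive(Default)]
struct NodeInner {
    children: HashMap<u64, Node>,
}

fn main() {
    let parent: Node = Rc::new(RefCell::new(NodeInner::default()));
    let child: Node = Rc::new(RefCell::new(NodeInner::default()));
    parent.borrow_mut().children.insert(1, child.clone());

    assert_eq!(Rc::strong_count(&child), 2); // held by the tree and by our local handle
    parent.borrow_mut().children.clear();    // the subtree's last tree-side reference is gone
    assert_eq!(Rc::strong_count(&child), 1); // only the local handle remains
}
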
@@ -445,14 +453,18 @@ impl RadixTree {
if let Some(blocks) = self.lookup.get(&worker) {
let blocks_to_clear: Vec<_> = blocks.values().collect();

// Remove the worker from each block's workers set
// Remove the worker from each block's workers map
blocks_to_clear.iter().for_each(|block| {
block.borrow_mut().workers.remove(&worker);
// If no workers are using this block, that is true for all children
if block.borrow().workers.is_empty() {
block.borrow_mut().children.clear();
}
});

// Clear the worker's blocks
if let Some(worker_blocks) = self.lookup.get_mut(&worker) {
worker_blocks.clear();
if let Some(worker_lookup) = self.lookup.get_mut(&worker) {
worker_lookup.clear();
}
}
}
Expand All @@ -461,71 +473,68 @@ impl RadixTree {
/// Uses BFS traversal to ensure that the tree reconstruction is unique,
/// though the exact event ordering will be lost.
pub fn dump_tree_as_events(&self) -> Vec<RouterEvent> {
tracing::debug!(
"Dumping radix tree as events (contains information about {:?} workers)",
self.lookup.len()
);

let mut events = Vec::new();
let mut event_id = 0u64;

// BFS queue: (current_block, parent_external_hash, tokens_hash)
let mut queue = VecDeque::new();
// BFS queue: (current_block, parent_hashes_per_worker, tokens_hash)
// parent_hashes_per_worker maps WorkerId -> ExternalSequenceBlockHash
let mut queue: VecDeque<(
SharedRadixBlock,
HashMap<WorkerId, ExternalSequenceBlockHash>,
LocalBlockHash,
)> = VecDeque::new();

// Process root's children first
let root_borrow = self.root.borrow();
for (tokens_hash, child_block) in &root_borrow.children {
queue.push_back((child_block.clone(), None, *tokens_hash));
queue.push_back((child_block.clone(), HashMap::new(), *tokens_hash));
}
drop(root_borrow);

while let Some((current_block, parent_external_hash, tokens_hash)) = queue.pop_front() {
while let Some((current_block, parent_hashes, tokens_hash)) = queue.pop_front() {
let current_borrow = current_block.borrow();

// Closure to find external hash for a block in a worker's lookup
let find_external_hash = |worker_id: &WorkerId| {
self.lookup.get(worker_id).and_then(|worker_blocks| {
worker_blocks
.iter()
.find(|(_, block)| Rc::ptr_eq(block, &current_block))
.map(|(hash, _)| *hash)
})
};
// Map of this block's external hashes per worker (for children to use as parent)
let mut current_external_hashes = HashMap::new();

// For each worker that has this block
for worker_id in &current_borrow.workers {
// Find the external hash for this block from the worker's lookup
let external_hash = find_external_hash(worker_id);

if let Some(block_hash) = external_hash {
// Create a store event for this worker
let event = RouterEvent {
worker_id: *worker_id,
event: KvCacheEvent {
event_id,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: parent_external_hash,
blocks: vec![KvCacheStoredBlockData {
block_hash,
tokens_hash,
}],
}),
},
};
events.push(event);
event_id += 1;
}
}
for (worker_id, external_hash) in &current_borrow.workers {
// Get the correct parent hash for this worker
let parent_hash = parent_hashes.get(worker_id).copied();

// Create a store event for this worker
let event = RouterEvent {
worker_id: *worker_id,
event: KvCacheEvent {
event_id,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash,
blocks: vec![KvCacheStoredBlockData {
block_hash: *external_hash,
tokens_hash,
}],
}),
},
};
events.push(event);
event_id += 1;

// Add children to queue for BFS traversal
// We need to find any external hash for this block to use as parent
let any_external_hash = if !current_borrow.workers.is_empty() {
current_borrow
.workers
.iter()
.next()
.and_then(find_external_hash)
} else {
None
};
// Track this block's external hash for this worker
current_external_hashes.insert(*worker_id, *external_hash);
}

// Enqueue children with per-worker parent hashes
for (child_tokens_hash, child_block) in &current_borrow.children {
queue.push_back((child_block.clone(), any_external_hash, *child_tokens_hash));
queue.push_back((
child_block.clone(),
current_external_hashes.clone(),
*child_tokens_hash,
));
}
}
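
The essential change in this loop is that the BFS queue carries a per-worker map of parent hashes rather than a single value: two workers can register the same token path under different external hashes, so each worker's Stored event must point at the parent hash that the same worker reported. A stripped-down sketch of that propagation, with simplified stand-in types:

use std::collections::{HashMap, VecDeque};

type WorkerId = u64;
type ExternalHash = u64;

struct Node {
    workers: HashMap<WorkerId, ExternalHash>,
    children: Vec<Node>,
}

/// Returns (worker, parent hash as seen by that worker, this block's hash for that worker).
fn dump(root: &Node) -> Vec<(WorkerId, Option<ExternalHash>, ExternalHash)> {
    let mut events = Vec::new();
    let mut queue: VecDeque<(&Node, HashMap<WorkerId, ExternalHash>)> = VecDeque::new();
    queue.push_back((root, HashMap::new()));

    while let Some((node, parent_hashes)) = queue.pop_front() {
        for (worker, hash) in &node.workers {
            // Each worker's event points at *its own* parent hash, which may differ
            // from another worker's hash for the very same radix path.
            events.push((*worker, parent_hashes.get(worker).copied(), *hash));
        }
        for child in &node.children {
            queue.push_back((child, node.workers.clone()));
        }
    }
    events
}
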

@@ -657,8 +666,11 @@ impl OverlapScores {
///
/// ### Arguments
///
/// * `workers` - A reference to a `HashSet` of `WorkerId`s.
pub fn update_scores(&mut self, workers: &HashSet<WorkerId>) {
/// * `workers` - An iterator over `WorkerId` references.
pub fn update_scores<'a, I>(&mut self, workers: I)
where
I: IntoIterator<Item = &'a WorkerId>,
{
for worker in workers {
let score = self.scores.entry(*worker).or_insert(0);
*score += 1;
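
Accepting any IntoIterator<Item = &WorkerId> lets find_matches pass the new map's keys() directly, as it now does, while existing callers holding a &HashSet<WorkerId> still satisfy the same bound. A hedged usage sketch with simplified stand-in types (the real OverlapScores carries more state):

use std::collections::{HashMap, HashSet};

type WorkerId = u64;

#[derive(Default)]
struct OverlapScores {
    scores: HashMap<WorkerId, u32>,
}

impl OverlapScores {
    fn update_scores<'a, I>(&mut self, workers: I)
    where
        I: IntoIterator<Item = &'a WorkerId>,
    {
        for worker in workers {
            *self.scores.entry(*worker).or_insert(0) += 1;
        }
    }
}

fn main() {
    let mut scores = OverlapScores::default();

    // A HashMap's keys() iterator satisfies the bound...
    let workers: HashMap<WorkerId, u64> = HashMap::from([(0, 100), (1, 200)]);
    scores.update_scores(workers.keys());

    // ...and so does a borrowed HashSet, matching the old call sites.
    let set: HashSet<WorkerId> = HashSet::from([1]);
    scores.update_scores(&set);

    assert_eq!(scores.scores[&1], 2);
}
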
@@ -2171,4 +2183,79 @@ mod tests {
1
);
}

#[test]
fn test_remove_worker_verifies_hash_removal() {
setup();
let mut trie = RadixTree::new();

let worker_0 = 0;
let worker_1 = 1;
let worker_2 = 2;

// Add blocks for multiple workers
trie.apply_event(create_store_event(worker_0, 0, vec![1, 2, 3], None))
.unwrap();
trie.apply_event(create_store_event(worker_1, 0, vec![1, 2, 3], None))
.unwrap();
trie.apply_event(create_store_event(worker_2, 0, vec![1, 4, 5], None))
.unwrap();

// Verify worker_0 has 3 blocks in lookup
assert_eq!(trie.lookup.get(&worker_0).unwrap().len(), 3);

// Verify that blocks have the correct workers
let block_1 = trie
.lookup
.get(&worker_0)
.unwrap()
.get(&ExternalSequenceBlockHash(100))
.unwrap();
assert_eq!(block_1.borrow().workers.len(), 3); // worker_0, worker_1, and worker_2 (all have hash 1)
assert!(block_1.borrow().workers.contains_key(&worker_0));
assert!(block_1.borrow().workers.contains_key(&worker_1));
assert!(block_1.borrow().workers.contains_key(&worker_2));

// Remove worker_0
trie.remove_worker(worker_0);

// Verify worker_0 is completely removed from lookup table
assert!(!trie.lookup.contains_key(&worker_0));
assert_eq!(trie.lookup.len(), 2);

// Verify that worker_0's hash is removed from the workers set
let block_1 = trie
.lookup
.get(&worker_1)
.unwrap()
.get(&ExternalSequenceBlockHash(100))
.unwrap();
assert_eq!(block_1.borrow().workers.len(), 2); // worker_1 and worker_2 remain
assert!(!block_1.borrow().workers.contains_key(&worker_0));
assert!(block_1.borrow().workers.contains_key(&worker_1));
assert!(block_1.borrow().workers.contains_key(&worker_2));

// Verify that blocks with no remaining workers have their children cleared
// This tests the optimization where empty blocks clear their children
let block_2 = trie
.lookup
.get(&worker_1)
.unwrap()
.get(&ExternalSequenceBlockHash(200))
.unwrap();
assert_eq!(block_2.borrow().workers.len(), 1); // only worker_1
assert!(block_2.borrow().workers.contains_key(&worker_1));

// Verify match results no longer include worker_0
let result = trie
.find_matches(
vec![LocalBlockHash(1), LocalBlockHash(2), LocalBlockHash(3)],
false,
)
.scores;
assert_eq!(result.len(), 2);
assert!(!result.contains_key(&worker_0));
assert!(result.contains_key(&worker_1));
assert!(result.contains_key(&worker_2));
}
}
7 changes: 5 additions & 2 deletions lib/llm/src/kv_router/subscriber.rs
@@ -363,6 +363,8 @@ async fn purge_then_snapshot(
// Purge before snapshot ensures new/warm-restarted routers won't replay already-acknowledged messages.
// Since KV events are idempotent, this ordering reduces unnecessary reprocessing while maintaining
// at-least-once delivery guarantees. The snapshot will capture the clean state after purge.
tracing::info!("Purging acknowledged messages and performing snapshot of radix tree");
let start_time = std::time::Instant::now();

// First, purge acknowledged messages from the stream
nats_queue.purge_acknowledged().await?;
@@ -395,9 +397,10 @@ async fn purge_then_snapshot(
.map_err(|e| anyhow::anyhow!("Failed to upload snapshot: {e:?}"))?;

tracing::info!(
"Successfully uploaded radix tree snapshot with {} events to bucket {}",
"Successfully performed snapshot of radix tree with {} events to bucket {} in {}ms",
events.len(),
resources.bucket_name
resources.bucket_name,
start_time.elapsed().as_millis()
);

Ok(())
1 change: 1 addition & 0 deletions lib/runtime/src/transports/nats.rs
@@ -257,6 +257,7 @@ impl Client {
tokio::io::copy(&mut obj_reader, &mut buffer)
.await
.map_err(|e| anyhow::anyhow!("Failed reading object data: {e}"))?;
tracing::debug!("Downloaded {} bytes from {bucket_name}/{key}", buffer.len());

// Deserialize from bincode
let data = bincode::deserialize(&buffer)