From 147dc1b801a5bef20a21170e64cb02e6b9aceb6a Mon Sep 17 00:00:00 2001 From: richardhuo-nv Date: Wed, 13 Aug 2025 23:05:14 -0700 Subject: [PATCH 01/17] trtllm integration connector api fix fix fix fix interace fix interace fix add logs async leader fix fix fix fix fix scheduled tokens fix fix fix fix add logs add logs fix fix and log fix and log fix fix fix layout fix fix fix fix fmt fix fix comments fmt fix comment Signed-off-by: richardhuo-nv --- lib/bindings/python/rust/llm/block_manager.rs | 111 ++++ .../rust/llm/block_manager/distributed.rs | 2 +- .../llm/block_manager/distributed/leader.rs | 56 +- .../llm/block_manager/distributed/utils.rs | 4 +- .../llm/block_manager/distributed/worker.rs | 6 +- .../python/rust/llm/block_manager/vllm.rs | 3 + .../rust/llm/block_manager/vllm/connector.rs | 2 + .../vllm/connector/leader/slot.rs | 107 +++- .../vllm/connector/trtllm_leader.rs | 484 ++++++++++++++++++ .../vllm/connector/trtllm_worker.rs | 435 ++++++++++++++++ .../block_manager/vllm/connector/worker.rs | 6 +- .../dynamo/llm/trtllm_integration/__init__.py | 2 + .../trtllm_integration/connector/__init__.py | 7 + .../connector/kvbm_connector_leader.py | 133 +++++ .../connector/kvbm_connector_worker.py | 127 +++++ .../src/dynamo/llm/trtllm_integration/rust.py | 45 ++ lib/llm/src/block_manager/distributed.rs | 22 +- .../src/block_manager/distributed/leader.rs | 349 +++++++++++-- .../src/block_manager/distributed/worker.rs | 232 +++++++-- lib/llm/src/block_manager/pool/managed.rs | 1 + 20 files changed, 2011 insertions(+), 123 deletions(-) create mode 100644 lib/bindings/python/rust/llm/block_manager/vllm/connector/trtllm_leader.rs create mode 100644 lib/bindings/python/rust/llm/block_manager/vllm/connector/trtllm_worker.rs create mode 100644 lib/bindings/python/src/dynamo/llm/trtllm_integration/__init__.py create mode 100644 lib/bindings/python/src/dynamo/llm/trtllm_integration/connector/__init__.py create mode 100644 lib/bindings/python/src/dynamo/llm/trtllm_integration/connector/kvbm_connector_leader.py create mode 100644 lib/bindings/python/src/dynamo/llm/trtllm_integration/connector/kvbm_connector_worker.py create mode 100644 lib/bindings/python/src/dynamo/llm/trtllm_integration/rust.py diff --git a/lib/bindings/python/rust/llm/block_manager.rs b/lib/bindings/python/rust/llm/block_manager.rs index eca0f873f7..346b1bcce3 100644 --- a/lib/bindings/python/rust/llm/block_manager.rs +++ b/lib/bindings/python/rust/llm/block_manager.rs @@ -14,6 +14,7 @@ // limitations under the License. use super::*; +use anyhow::Result; use dynamo_llm::block_manager::block::{ data::logical::distributed_leader_worker::DistributedLeaderWorkerResources, locality::Logical, }; @@ -220,3 +221,113 @@ impl BlockManager { &self.inner } } + +#[derive(Default)] +pub struct BlockManagerBuilder { + worker_id: u64, + leader: Option, + page_size: usize, + disable_device_pool: bool, +} + +impl BlockManagerBuilder { + pub fn new() -> Self { + Self { + page_size: 0, + ..Default::default() + } + } + + pub fn worker_id(mut self, id: u64) -> Self { + self.worker_id = id; + self + } + pub fn page_size(mut self, ps: usize) -> Self { + self.page_size = ps; + self + } + pub fn leader(mut self, l: distributed::KvbmLeader) -> Self { + self.leader = Some(l); + self + } + pub fn disable_device_pool(mut self, yes: bool) -> Self { + self.disable_device_pool = yes; + self + } + + /// Async build (call from an async context). 
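+    ///
+    /// A minimal usage sketch (the `leader` handle and `page_size` are assumed to come
+    /// from the caller, as in the TRT-LLM connector leader added later in this patch):
+    /// ```ignore
+    /// let block_manager = BlockManagerBuilder::new()
+    ///     .worker_id(0)
+    ///     .leader(leader)
+    ///     .page_size(page_size)
+    ///     .disable_device_pool(false)
+    ///     .build()
+    ///     .await?;
+    /// ```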
+ pub async fn build(self) -> Result { + let worker_id = self.worker_id; + let leader = self.leader.ok_or_else(|| { + anyhow::anyhow!("leader is required (runtime is always taken from leader)") + })?; + + // Get (inner leader handle, runtime) from the provided leader. + let (leader_inner, drt) = leader.dissolve(); + + let cancel_token = CancellationToken::new(); + + // Runtime & model config + let runtime_config = dynamo_llm::block_manager::KvManagerRuntimeConfig::builder() + .worker_id(worker_id) + .cancellation_token(cancel_token.clone()) + .build()?; + + let mut config = + dynamo_llm::block_manager::KvBlockManagerConfig::builder().runtime(runtime_config); + + let model_config = dynamo_llm::block_manager::KvManagerModelConfig::builder() + .num_layers(1) + .outer_dim(1) + .page_size(self.page_size) + .inner_dim(1) + .build()?; + + config = config.model(model_config); + + // Layouts derived from leader’s counts + if !self.disable_device_pool { + config = config.device_layout( + dynamo_llm::block_manager::KvManagerLayoutConfig::builder() + .num_blocks(leader_inner.num_device_blocks()) + .logical(Some(BlockParallelismStrategy::LeaderWorkerSharded)) + .build()?, + ); + } + + if leader_inner.num_host_blocks() > 0 { + config = config.host_layout( + dynamo_llm::block_manager::KvManagerLayoutConfig::builder() + .num_blocks(leader_inner.num_host_blocks()) + .logical(Some(BlockParallelismStrategy::LeaderWorkerSharded)) + .build()?, + ); + } + + if leader_inner.num_disk_blocks() > 0 { + config = config.disk_layout( + dynamo_llm::block_manager::KvManagerLayoutConfig::builder() + .num_blocks(leader_inner.num_disk_blocks()) + .logical(Some(BlockParallelismStrategy::LeaderWorkerSharded)) + .build()?, + ); + } + + let config = config.build()?; + + let resources = + DistributedLeaderWorkerResources::new(Some(leader_inner), cancel_token.child_token())?; + + let inner = dynamo_llm::block_manager::KvBlockManager::< + Logical, + BasicMetadata, + >::new(config, resources) + .await?; + + Ok(BlockManager { + inner, + drt, + _controller: None, + }) + } +} diff --git a/lib/bindings/python/rust/llm/block_manager/distributed.rs b/lib/bindings/python/rust/llm/block_manager/distributed.rs index 69248360c4..5b7d810ab3 100644 --- a/lib/bindings/python/rust/llm/block_manager/distributed.rs +++ b/lib/bindings/python/rust/llm/block_manager/distributed.rs @@ -8,5 +8,5 @@ mod utils; mod worker; pub use leader::KvbmLeader; -pub use utils::get_barrier_id; +pub use utils::get_barrier_id_prefix; pub use worker::{KvbmWorker, VllmTensor}; diff --git a/lib/bindings/python/rust/llm/block_manager/distributed/leader.rs b/lib/bindings/python/rust/llm/block_manager/distributed/leader.rs index 4d938bb1e6..f565f4fc9a 100644 --- a/lib/bindings/python/rust/llm/block_manager/distributed/leader.rs +++ b/lib/bindings/python/rust/llm/block_manager/distributed/leader.rs @@ -2,10 +2,12 @@ // SPDX-License-Identifier: Apache-2.0 use super::*; -use utils::get_barrier_id; +use utils::get_barrier_id_prefix; use derive_getters::Dissolve; -use llm_rs::block_manager::distributed::{KvbmLeader as KvbmLeaderImpl, KvbmLeaderConfig}; +use llm_rs::block_manager::distributed::{ + KvbmLeader as KvbmLeaderImpl, KvbmLeaderConfig, KvbmLeaderNumBlocksConfig, +}; const CPU_CACHE: &str = "DYN_KVBM_CPU_CACHE_GB"; const CPU_CACHE_OVERRIDE: &str = "DYN_KVBM_CPU_CACHE_OVERRIDE_NUM_BLOCKS"; @@ -16,15 +18,32 @@ const DISK_CACHE_OVERRIDE: &str = "DYN_KVBM_DISK_CACHE_OVERRIDE_NUM_BLOCKS"; const LEADER_WORKER_INIT_TIMEOUT_SECS: &str = 
"DYN_KVBM_LEADER_WORKER_INIT_TIMEOUT_SECS"; const DEFAULT_INIT_TIMEOUT_SECS: u64 = 120; -fn compute_num_blocks(cache_size_key: &str, override_key: &str, bytes_per_block: usize) -> usize { - if let Ok(override_num_blocks) = std::env::var(override_key) { - override_num_blocks.parse::().unwrap_or(0) - } else { - let cache_size_gb = std::env::var(cache_size_key) - .unwrap_or_default() - .parse::() - .unwrap_or(0.0); - ((cache_size_gb * 1_000_000_000.0) / bytes_per_block as f64) as usize +fn read_env_usize(key: &str) -> Option { + std::env::var(key).ok()?.trim().parse::().ok() +} + +fn read_cache_size_float(key: &str) -> f64 { + std::env::var(key) + .unwrap_or_default() + .parse::() + .unwrap_or(0.0) +} + +fn get_blocks_config(cache_size_key: &str, override_key: &str) -> KvbmLeaderNumBlocksConfig { + if let Some(nblocks) = read_env_usize(override_key) { + // Optional: still read cache size for observability, but override takes precedence. + let cache_gb: f64 = read_cache_size_float(cache_size_key); + return KvbmLeaderNumBlocksConfig { + cache_size_in_gb: cache_gb, + num_blocks_overriden: nblocks, + }; + } + + // No override -> compute from cache size (in GB) + let cache_gb: f64 = read_cache_size_float(cache_size_key); + KvbmLeaderNumBlocksConfig { + cache_size_in_gb: cache_gb, + num_blocks_overriden: 0, } } @@ -51,22 +70,19 @@ impl KvbmLeader { #[pymethods] impl KvbmLeader { #[new] - #[pyo3(signature = (bytes_per_block, world_size, drt))] - fn new(bytes_per_block: usize, world_size: usize, drt: DistributedRuntime) -> PyResult { - let num_host_blocks = compute_num_blocks(CPU_CACHE, CPU_CACHE_OVERRIDE, bytes_per_block); - let num_disk_blocks = compute_num_blocks(DISK_CACHE, DISK_CACHE_OVERRIDE, bytes_per_block); - - let barrier_id = get_barrier_id(); + #[pyo3(signature = (world_size, drt))] + fn new(world_size: usize, drt: DistributedRuntime) -> PyResult { + let barrier_id_prefix = get_barrier_id_prefix(); let leader_init_timeout_sec: u64 = get_leader_init_timeout_secs(LEADER_WORKER_INIT_TIMEOUT_SECS); let config = KvbmLeaderConfig::builder() - .barrier_id(barrier_id) - .num_host_blocks(num_host_blocks) - .num_disk_blocks(num_disk_blocks) + .barrier_id_prefix(barrier_id_prefix) .world_size(world_size) .leader_init_timeout_secs(leader_init_timeout_sec) .drt(drt.inner().clone()) + .host_blocks_config(get_blocks_config(CPU_CACHE, CPU_CACHE_OVERRIDE)) + .disk_blocks_config(get_blocks_config(DISK_CACHE, DISK_CACHE_OVERRIDE)) .build() .map_err(to_pyerr)?; diff --git a/lib/bindings/python/rust/llm/block_manager/distributed/utils.rs b/lib/bindings/python/rust/llm/block_manager/distributed/utils.rs index 8520e3025f..2777260fb4 100644 --- a/lib/bindings/python/rust/llm/block_manager/distributed/utils.rs +++ b/lib/bindings/python/rust/llm/block_manager/distributed/utils.rs @@ -1,6 +1,6 @@ // SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
// SPDX-License-Identifier: Apache-2.0 -pub fn get_barrier_id() -> String { - std::env::var("DYN_KVBM_BARRIER_ID").unwrap_or("kvbm".to_string()) +pub fn get_barrier_id_prefix() -> String { + std::env::var("DYN_KVBM_BARRIER_ID_PREFIX").unwrap_or("kvbm".to_string()) } diff --git a/lib/bindings/python/rust/llm/block_manager/distributed/worker.rs b/lib/bindings/python/rust/llm/block_manager/distributed/worker.rs index d4fc329fec..b21ca01c42 100644 --- a/lib/bindings/python/rust/llm/block_manager/distributed/worker.rs +++ b/lib/bindings/python/rust/llm/block_manager/distributed/worker.rs @@ -4,7 +4,7 @@ use super::*; use std::sync::Arc; -use utils::get_barrier_id; +use utils::get_barrier_id_prefix; use llm_rs::block_manager::distributed::{ BlockTransferHandler as RustBlockTransferHandler, KvbmWorker as KvbmWorkerImpl, @@ -131,7 +131,7 @@ impl KvbmWorker { vllm_tensors.push(Arc::new(vllm_tensor)); } - let barrier_id = get_barrier_id(); + let barrier_id_prefix = get_barrier_id_prefix(); let config = KvbmWorkerConfig::builder() .drt(drt) @@ -140,7 +140,7 @@ impl KvbmWorker { .tensors(vllm_tensors) .device_id(device_id) .dtype_width_bytes(dtype_width_bytes) - .barrier_id(barrier_id) + .barrier_id_prefix(barrier_id_prefix) .build() .map_err(to_pyerr)?; diff --git a/lib/bindings/python/rust/llm/block_manager/vllm.rs b/lib/bindings/python/rust/llm/block_manager/vllm.rs index c41680eaef..56bd675558 100644 --- a/lib/bindings/python/rust/llm/block_manager/vllm.rs +++ b/lib/bindings/python/rust/llm/block_manager/vllm.rs @@ -50,6 +50,9 @@ fn _vllm_integration(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + // TODO: use TRTLLM own integration module + m.add_class::()?; + m.add_class::()?; Ok(()) } diff --git a/lib/bindings/python/rust/llm/block_manager/vllm/connector.rs b/lib/bindings/python/rust/llm/block_manager/vllm/connector.rs index ab87880cad..391eaac7ec 100644 --- a/lib/bindings/python/rust/llm/block_manager/vllm/connector.rs +++ b/lib/bindings/python/rust/llm/block_manager/vllm/connector.rs @@ -7,6 +7,8 @@ use dynamo_llm::block_manager::{ }; pub mod leader; +pub mod trtllm_leader; +pub mod trtllm_worker; pub mod worker; use pyo3::prelude::*; diff --git a/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader/slot.rs b/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader/slot.rs index 6ef91aefb6..ccca6f94d7 100644 --- a/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader/slot.rs +++ b/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader/slot.rs @@ -106,6 +106,17 @@ pub trait Slot: std::fmt::Debug { num_scheduled_tokens: usize, ) -> Result<(), SlotError>; + // TRT-LLM does not include scheduled tokens in the scheduler output. + // Ideally, we should have a dedicated implementation for the TRT-LLM slot. + // However, since only this single function needs to be rewritten for now, + // we keep it as a separate function in Slot. 
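+    //
+    // Illustrative call site (mirroring the TRT-LLM leader added later in this patch):
+    // the caller derives a computed position from the scheduler output instead of a
+    // scheduled-token count, e.g.
+    //   slot.apply_scheduler_output_with_computed_position(
+    //       &new_req.prompt_token_ids,
+    //       &new_req.block_ids,
+    //       new_req.num_computed_tokens - 1,
+    //   )?;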
+ fn apply_scheduler_output_with_computed_position( + &mut self, + tokens: &[u32], + block_ids: &[usize], + computed_position: usize, + ) -> Result<(), SlotError>; + fn record_start_iteration(&mut self, iteration: u64) -> Result<(), SlotError>; fn mark_as_prefilling(&mut self, iteration: u64) -> Result<(), SlotError>; @@ -228,6 +239,11 @@ impl SlotManager for ConnectorSlotManager { tokens: Vec, salt_hash: SaltHash, ) -> Result<(), SlotError> { + tracing::debug!( + "creating slot with request_id: {}, num_tokens: {}", + request_id, + tokens.len() + ); let slot = VllmConnectorSlot::new( request_id.to_string(), tokens.into(), @@ -566,6 +582,93 @@ impl Slot for VllmConnectorSlot { Ok(()) } + #[tracing::instrument(level = "debug", skip_all, fields(request_id = self.request_id.as_str()))] + fn apply_scheduler_output_with_computed_position( + &mut self, + tokens: &[u32], + block_ids: &[usize], + computed_position: usize, + ) -> Result<(), SlotError> { + // TRTLLM's KV Connector Manager will have (computed_position - external matches) + // in onborading case + if computed_position < self.current_position { + tracing::debug!( + "computed_position={} <= current_position={}, so we are onboarding during prefilling phase", + computed_position, self.current_position + ); + return Ok(()); + } + + // now we decide what we should do for the new computed tokens + + if computed_position < self.sequence.total_tokens() { + // no need to apply new tokens, since it's applied when created the slot during prefilling + self.state = SlotState::Prefilling; + } else { + tracing::debug!( + "appending {} newly decoded tokens to sequence", + tokens.len() + ); + self.sequence.extend(tokens.into()).unwrap(); + self.state = SlotState::Decoding; + } + + // apply new block_ids, this should be applied for both prefilling and decoding + // because this is unknown when creating the slot + if !block_ids.is_empty() { + tracing::debug!("assigning {} new device blocks slot", block_ids.len()); + self.device_blocks.extend(block_ids); + } + + let num_candidate_blocks = + ((computed_position + 1) / self.block_size) - self.evaluated_blocks; + + if num_candidate_blocks != 0 { + // do we have a mechanism for skipping gpu cache hit blocks? not sure yet. 
+ // for now, offload all the blocks to the host + let offload_block_ids: Vec = self + .device_blocks + .iter() + .skip(self.evaluated_blocks) + .take(num_candidate_blocks) + .copied() + .collect::>(); + + assert_eq!( + offload_block_ids.len(), + num_candidate_blocks, + "device block overflow - candidate blocks exceed block count at offset {}", + self.evaluated_blocks + ); + + let offload_token_blocks: Vec = self + .sequence + .blocks() + .iter() + .skip(self.evaluated_blocks) + .take(num_candidate_blocks) + .cloned() + .collect::>(); + + self.offload_blocks(&offload_block_ids, &offload_token_blocks) + .expect("failed to offload blocks"); + + self.evaluated_blocks += num_candidate_blocks; + } + + // done applying policy + tracing::debug!( + "done applying kv cache policy at current_position: {}; computed_position: {}", + self.current_position, + computed_position, + ); + + // advance current position to computed position + self.current_position = computed_position; + + Ok(()) + } + fn record_start_iteration(&mut self, iteration: u64) -> Result<(), SlotError> { if self.iteration_first_scheduled.is_none() { self.iteration_first_scheduled = Some(iteration); @@ -676,7 +779,7 @@ impl Slot for VllmConnectorSlot { let num_matched_blocks = num_matched_host_blocks + num_matched_disk_blocks; tracing::debug!( - "matched {} host blocks and {} disk blocks; {} total blocks", + "successfully matched {} host blocks and {} disk blocks; {} total blocks", num_matched_host_blocks, num_matched_disk_blocks, num_matched_blocks @@ -925,7 +1028,7 @@ impl VllmConnectorSlot { tracing::debug!( request_id = self.request_id, operation_id = %operation_id, - "onboarding {} blocks from {:?} to device", + "start onboarding {} blocks from {:?} to device", num_blocks, src_storage_pool, ); diff --git a/lib/bindings/python/rust/llm/block_manager/vllm/connector/trtllm_leader.rs b/lib/bindings/python/rust/llm/block_manager/vllm/connector/trtllm_leader.rs new file mode 100644 index 0000000000..bb8d15cccd --- /dev/null +++ b/lib/bindings/python/rust/llm/block_manager/vllm/connector/trtllm_leader.rs @@ -0,0 +1,484 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +use super::*; + +use crate::llm::block_manager::vllm::connector::leader::slot::{ + ConnectorSlotManager, SlotManager, SlotState, +}; +use crate::llm::block_manager::BlockManagerBuilder; +use crate::llm::block_manager::{distributed::KvbmLeader as PyKvbmLeader, vllm::KvbmRequest}; +use crate::DistributedRuntime as PyDistributedRuntime; +use anyhow; +use std::collections::HashSet; +use std::sync::{Arc, OnceLock}; +use tokio::runtime::Handle; + +pub trait Leader: Send + Sync + std::fmt::Debug { + fn get_num_new_matched_tokens( + &mut self, + request_id: String, + request_num_tokens: usize, + num_computed_tokens: usize, + ) -> anyhow::Result<(usize, bool)>; + + fn update_state_after_alloc( + &mut self, + request_id: String, + block_ids: Vec, + context_current_position: usize, + ) -> anyhow::Result<()>; + + fn build_connector_metadata( + &mut self, + scheduler_output: SchedulerOutput, + ) -> anyhow::Result>; + + fn request_finished( + &mut self, + request_id: String, + block_ids: Vec, + ) -> anyhow::Result; + + fn has_slot(&self, request_id: String) -> bool; + + fn create_slot(&mut self, request: KvbmRequest, tokens: Vec) -> anyhow::Result<()>; + + fn slot_manager(&self) -> &ConnectorSlotManager; +} + +#[derive(Debug)] +pub struct KvConnectorLeader { + slot_manager: Arc>>, + block_size: usize, + inflight_requests: HashSet, + onboarding_slots: HashSet, + iteration_counter: u64, + inflight_request_to_num_external_tokens: HashMap, +} + +impl KvConnectorLeader { + fn new( + worker_id: u64, + drt: PyDistributedRuntime, + page_size: usize, + leader_py: PyKvbmLeader, + ) -> Self { + tracing::info!( + "KvConnectorLeader initialized with worker_id: {}", + worker_id + ); + + let leader = leader_py.get_inner().clone(); + let drt = drt.inner().clone(); + let handle: Handle = drt.runtime().primary(); + + let slot_manager_cell = Arc::new(OnceLock::new()); + + { + let slot_manager_cell = slot_manager_cell.clone(); + + handle.spawn(async move { + let ready = leader.wait_worker_sync_ready().await; + if !ready { + tracing::error!( + "KvConnectorLeader init aborted: leader worker barrier not ready!", + ); + return; + } + + let block_manager = match BlockManagerBuilder::new() + .worker_id(worker_id) + .leader(leader_py) // your distributed::KvbmLeader + .page_size(page_size) + .disable_device_pool(false) + .build() + .await + { + Ok(bm) => bm, + Err(e) => { + tracing::error!("Failed to build BlockManager: {}", e); + return; + } + }; + + // Create the slot manager now that everything is ready + let sm = ConnectorSlotManager::new( + block_manager.get_block_manager().clone(), + leader.clone(), + drt.clone(), + ); + + let _ = slot_manager_cell.set(sm); + + // another barrier sync to make sure worker init won't return before leader is ready + leader.spawn_leader_readiness_barrier(drt); + + tracing::info!("KvConnectorLeader init complete."); + }); + } + + Self { + slot_manager: slot_manager_cell, + block_size: page_size, + inflight_requests: HashSet::new(), + onboarding_slots: HashSet::new(), + iteration_counter: 0, + inflight_request_to_num_external_tokens: HashMap::new(), + } + } +} + +impl Leader for KvConnectorLeader { + #[inline] + fn slot_manager(&self) -> &ConnectorSlotManager { + self.slot_manager + .get() + .expect("slot_manager not initialized") + } + + /// Match the tokens in the request with the available block pools. + /// Note: the necessary details of the request are captured prior to this call. 
For trtllm, + /// we make a create slot call prior to this call, so a slot is guaranteed to exist. + /// + /// To align with the connector interface, we must ensure that if no blocks are matched, we return (0, false). + /// In our implementation, if we match any block, we return (num_matched_tokens, true). + #[tracing::instrument(level = "debug", skip(self, request_num_tokens, num_computed_tokens))] + fn get_num_new_matched_tokens( + &mut self, + request_id: String, + request_num_tokens: usize, + num_computed_tokens: usize, + ) -> anyhow::Result<(usize, bool)> { + tracing::debug!( + "request_num_tokens: {request_num_tokens}; num_computed_tokens: {num_computed_tokens}" + ); + + // TRTLLM could match partial blocks if enable_partial_reuse = True, + // immediately return 0 to simplify things. + if num_computed_tokens % self.block_size != 0 { + return Ok((0, false)); + } + + let shared_slot = self.slot_manager().get_slot(&request_id)?; + let mut slot = shared_slot + .lock() + .map_err(|e| anyhow::anyhow!("Failed to lock slot: {}", e))?; + + if slot.state() == SlotState::Prefilling { + tracing::warn!("slot is in the Prefilled state; this seems like we need to reset the slot and start over"); + slot.reset(); + } + + // early exit if we cannot match full block + if (slot.sequence().total_tokens() - num_computed_tokens) < self.block_size { + let total_tokens = slot.sequence().total_tokens(); + tracing::debug!( + "total_tokens in sequence: {total_tokens}; num_computed_tokens: {num_computed_tokens}; can not match full block." + ); + return Ok((0, false)); + } + + // find matches for any remaining tokens + // this will advance the computed position and hold any newly matched blocks in the slot + slot.acquire_local_matches(num_computed_tokens)?; + + // return the number of external tokens that are ready for onboarding + // we always return true here as we always asynchronously onboard matched blocks + if let SlotState::OnboardStaged(num_external_tokens) = slot.state() { + debug_assert!((num_computed_tokens + num_external_tokens) % self.block_size == 0); + tracing::debug!( + request_id = request_id, + "scheduling onboarding for {} external tokens", + num_external_tokens + ); + // Add to the map so that onboarding can be triggered in update_state_after_alloc. + self.inflight_request_to_num_external_tokens + .insert(request_id, num_external_tokens); + Ok((num_external_tokens, true)) + } else { + Ok((0, false)) + } + } + + /// Note: TRTLLM will not provide any scheduler output data for requests that are onboarding. it is entirely + /// on the connector's implementation to handle this case. + #[tracing::instrument(level = "debug", skip_all, fields(request_id))] + fn update_state_after_alloc( + &mut self, + request_id: String, + block_ids: Vec, + context_current_position: usize, + ) -> anyhow::Result<()> { + tracing::debug!(request_id, "num_device_blocks: {}", block_ids.len(),); + + let shared_slot = self.slot_manager().get_slot(&request_id)?; + let mut slot = shared_slot + .lock() + .map_err(|e| anyhow::anyhow!("Failed to lock slot: {}", e))?; + + // we have not yet advanced the computed position, but now we can, since we have an indication that we have + // necessary gpu blocks into which we will load the external tokens. 
+ + slot.append_mutable_device_blocks(&block_ids)?; + + if let Some(&num_external_tokens) = self + .inflight_request_to_num_external_tokens + .get(&request_id) + { + if num_external_tokens > 0 { + let num_computed_tokens = (context_current_position + 1) - num_external_tokens; + slot.record_cached_device_tokens(num_computed_tokens); + slot.advance_computed_position(num_computed_tokens)?; + + tracing::debug!( + request_id = request_id, + "triggering onboarding for {} external tokens", + num_external_tokens + ); + slot.trigger_onboarding(num_external_tokens)?; + self.onboarding_slots.insert(request_id.clone()); + } + + self.inflight_request_to_num_external_tokens + .remove(&request_id); + } + + Ok(()) + } + + #[tracing::instrument(level = "debug", skip_all, fields(iteration = self.iteration_counter + 1))] + fn build_connector_metadata( + &mut self, + scheduler_output: SchedulerOutput, + ) -> anyhow::Result> { + // the iteration counter is used to track the number of times we have built the connector metadata + // all connetor operations have the iteration counter at which they were issued. + // this allows operations to be lazily enqueued to the transfer engine + // the worker side of the connector will track all operations for completion before the request is + // allowed to be marked as finished. + self.iteration_counter += 1; + let iteration = self.iteration_counter; + + tracing::debug!("Building connector metadata"); + tracing::debug!("SchedulerOutput: {scheduler_output:#?}"); + + let mut inflight_requests = self.inflight_requests.clone(); + let mut md = ConnectorMetadata::new(iteration); + + let onboarding_slots = std::mem::take(&mut self.onboarding_slots); + + // Worker-side - we create a request slot for onboarding, then delete it when onboarding is finished, then + // recreate it again when we start the prefill/decode phase. + // + // This is kind of a nice abstraction as it keeps the events simplier; however, we now create the request-slot + // once for onboarding (this loop), then again for prefill/decode (new_requests loop). + for request_id in onboarding_slots.iter() { + let shared_slot = self.slot_manager().get_slot(request_id)?; + let mut slot = shared_slot + .lock() + .map_err(|e| anyhow::anyhow!("Failed to lock slot: {}", e))?; + + md.create_slot(request_id.clone()); + + if let Some(pending_ops) = slot.take_pending_operations() { + tracing::debug!("adding {} pending onboarding operations", pending_ops.len()); + md.add_operations(pending_ops); + } + } + + // todo: update the code and abstraction to account for this two-phase lifecycle. 
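+        //
+        // Assumed timeline for a request that onboards from host/disk (based on the two
+        // loops in this function):
+        //   iteration N:    md.create_slot(req) + onboarding operations (onboarding_slots loop above)
+        //   iteration N+k:  md.create_slot(req) again, now applying the prefill/decode scheduler
+        //                   output via apply_scheduler_output_with_computed_position
+        //                   (new_requests loop below)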
+ for new_req in &scheduler_output.new_requests { + let request_id = &new_req.request_id; + assert!( + inflight_requests.remove(request_id), + "request_id {request_id} not found in inflight_requests: " + ); + + let shared_slot = self.slot_manager().get_slot(request_id)?; + let mut slot = shared_slot + .lock() + .map_err(|e| anyhow::anyhow!("Failed to lock slot: {}", e))?; + + // inform the worker that a new request-slot should be created + md.create_slot(new_req.request_id.clone()); + + slot.record_start_iteration(iteration)?; + + debug_assert!( + matches!( + slot.state(), + SlotState::Initialized | SlotState::Onboarding(_) + ), + "current slot state: {:?}", + slot.state() + ); + + slot.apply_scheduler_output_with_computed_position( + &new_req.prompt_token_ids, + &new_req.block_ids, + new_req.num_computed_tokens - 1, + )?; + + if let Some(pending_ops) = slot.take_pending_operations() { + tracing::debug!( + "adding {} pending operations for slot {}", + pending_ops.len(), + new_req.request_id + ); + md.add_operations(pending_ops); + } + } + + for cached_req in &scheduler_output.cached_requests { + let request_id = &cached_req.request_id; + + // note: evicition might trigger this assert + assert!( + inflight_requests.remove(request_id), + "request_id {request_id} not found in inflight_requests: " + ); + + let shared_slot = self.slot_manager().get_slot(request_id)?; + let mut slot = shared_slot + .lock() + .map_err(|e| anyhow::anyhow!("Failed to lock slot: {}", e))?; + + slot.apply_scheduler_output_with_computed_position( + &cached_req.new_token_ids, + &cached_req.new_block_ids, + cached_req.num_computed_tokens - 1, + )?; + + if let Some(pending_ops) = slot.take_pending_operations() { + tracing::debug!( + "adding {} pending operations for slot {}", + pending_ops.len(), + request_id + ); + md.add_operations(pending_ops); + } + } + + tracing::debug!("metadata: {md:#?}"); + serde_json::to_vec(&md) + .map_err(|e| anyhow::anyhow!("Failed to serialize connector metadata: {}", e)) + } + + fn request_finished( + &mut self, + request_id: String, + block_ids: Vec, + ) -> anyhow::Result { + tracing::debug!("Request finished: {request_id}; block_ids: {block_ids:?}"); + // grab the slot + let shared_slot = self.slot_manager().get_slot(&request_id)?; + + // mark the slot as finished + let mut slot = shared_slot + .lock() + .map_err(|e| anyhow::anyhow!("Failed to lock slot: {}", e))?; + slot.mark_as_finished(self.iteration_counter)?; + + // todo: allow the request to resolve when it should exit + // the request may have some outstanding operations + // we would like to inform it to shutdown, then have it signal to the work that is officially gone, + // then we can remove the slot and trigger the worker to clean up as well. + + // remove it from the manager as we will never use it again + self.slot_manager().remove_slot(&request_id)?; + self.inflight_request_to_num_external_tokens + .remove(&request_id); + + // if the slot has finished, we can return false to trtllm, indicating all gpu blocks are free to be reused + // otherwise, we return false, which means there are still outstanding operations on gpu blocks which + // must be awaited before the gpu blocks can be reused. if we return true, then it is the worker side + // of the connector api which will be used to inform trtllm that the request is finished. 
+ if let SlotState::Finished = slot.state() { + Ok(false) + } else { + debug_assert!(matches!(slot.state(), SlotState::Finishing)); + Ok(true) + } + } + + fn has_slot(&self, request_id: String) -> bool { + self.slot_manager().has_slot(&request_id) + } + + /// Create a new slot for the given request ID. + /// This is used to create a new slot for the request. + fn create_slot(&mut self, request: KvbmRequest, tokens: Vec) -> anyhow::Result<()> { + self.slot_manager() + .create_slot(&request.request_id, tokens, request.salt_hash)?; + + self.inflight_requests.insert(request.request_id); + + Ok(()) + } +} + +#[pyclass] +pub struct PyTrtllmKvConnectorLeader { + connector_leader: Box, +} + +#[pymethods] +impl PyTrtllmKvConnectorLeader { + #[new] + #[pyo3(signature = (worker_id, drt, page_size, leader))] + pub fn new( + worker_id: u64, + drt: PyDistributedRuntime, + page_size: usize, + leader: PyKvbmLeader, + ) -> Self { + let connector_leader: Box = + Box::new(KvConnectorLeader::new(worker_id, drt, page_size, leader)); + Self { connector_leader } + } + + fn get_num_new_matched_tokens( + &mut self, + request_id: String, + request_num_tokens: usize, + num_computed_tokens: usize, + ) -> PyResult<(usize, bool)> { + self.connector_leader + .get_num_new_matched_tokens(request_id, request_num_tokens, num_computed_tokens) + .map_err(to_pyerr) + } + + fn update_state_after_alloc( + &mut self, + request_id: String, + block_ids: Vec, + context_current_position: usize, + ) -> PyResult<()> { + self.connector_leader + .update_state_after_alloc(request_id, block_ids, context_current_position) + .map_err(to_pyerr) + } + + fn build_connector_metadata(&mut self, scheduler_output: SchedulerOutput) -> PyResult> { + self.connector_leader + .build_connector_metadata(scheduler_output) + .map_err(to_pyerr) + } + + fn request_finished(&mut self, request_id: &str, block_ids: Vec) -> PyResult { + self.connector_leader + .request_finished(request_id.to_string(), block_ids) + .map_err(to_pyerr) + } + + fn has_slot(&self, request_id: &str) -> bool { + self.connector_leader.has_slot(request_id.to_string()) + } + + fn create_slot(&mut self, request: KvbmRequest, tokens: Vec) -> PyResult<()> { + self.connector_leader + .create_slot(request, tokens) + .map_err(to_pyerr) + } +} diff --git a/lib/bindings/python/rust/llm/block_manager/vllm/connector/trtllm_worker.rs b/lib/bindings/python/rust/llm/block_manager/vllm/connector/trtllm_worker.rs new file mode 100644 index 0000000000..b967c2b653 --- /dev/null +++ b/lib/bindings/python/rust/llm/block_manager/vllm/connector/trtllm_worker.rs @@ -0,0 +1,435 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +use dynamo_llm::block_manager::connector::protocol::TransferType; +use dynamo_llm::block_manager::connector::scheduler::{ + Scheduler, TransferSchedulerClient, WorkerSchedulerClient, +}; + +use std::collections::HashSet; +use std::sync::{Arc, OnceLock}; + +use super::*; +use crate::llm::block_manager::distributed::get_barrier_id_prefix; +use crate::llm::block_manager::vllm::connector::worker::event_sync_blocking; +use crate::{ + llm::block_manager::distributed::VllmTensor, to_pyerr, + DistributedRuntime as PyDistributedRuntime, +}; + +use anyhow; +use dynamo_llm::block_manager::distributed::{KvbmWorker, KvbmWorkerConfig}; +use dynamo_llm::block_manager::storage::torch::TorchTensor; +use dynamo_runtime::utils::task::CriticalTaskExecutionHandle; +use dynamo_runtime::DistributedRuntime; + +pub trait Worker: Send + Sync { + fn register_kv_caches( + &mut self, + num_device_blocks: usize, + page_size: usize, + device_id: usize, + dtype_width_bytes: usize, + kv_cache_tensor: Arc, + raw_event_handles: Vec, + ) -> anyhow::Result<()>; + + fn bind_connector_meta(&mut self, metadata: Vec) -> anyhow::Result<()>; + + fn start_load_kv(&mut self) -> anyhow::Result<()>; + + fn save_kv_layer(&mut self, layer_idx: usize) -> anyhow::Result<()>; + + fn get_finished( + &mut self, + finished_gen_req_ids: Vec, + started_loading_req_ids: Vec, + ) -> (Vec, Vec); +} + +pub struct KvConnectorWorker { + drt: DistributedRuntime, + kvbm_worker: OnceLock, + connector: WorkerSchedulerClient, + transfer_client: TransferSchedulerClient, + + /// Map of request id to inflight load requests + maybe_finished_onboarding: HashSet, + + /// Map of request id to inflight finished requests + maybe_finished_offloading: HashSet, + + onboarding_operations: Vec, + offloading_operations: Vec, + + bound: bool, + iteration: u64, + layers_complete: usize, + + /// cuda events created by the python side + layer_events: Vec, +} + +impl KvConnectorWorker { + fn new(py_drt: PyDistributedRuntime, trtllm_rank: String) -> anyhow::Result { + let drt = py_drt.inner.clone(); + let runtime = drt.runtime().primary(); + + let (scheduler, worker_client, transfer_client) = Scheduler::new(drt.primary_token()); + + CriticalTaskExecutionHandle::new_with_runtime( + move |_| { + let mut scheduler = scheduler; + async move { scheduler.run().await } + }, + drt.primary_token(), + "kv-connector-scheduler-task", + &runtime, + )? 
+ .detach(); + + tracing::info!( + "KvConnectorWorker initialized with worker_rank: {}", + trtllm_rank + ); + + Ok(Self { + drt, + kvbm_worker: OnceLock::new(), + connector: worker_client, + transfer_client, + maybe_finished_onboarding: HashSet::new(), + maybe_finished_offloading: HashSet::new(), + onboarding_operations: Vec::new(), + offloading_operations: Vec::new(), + bound: false, + iteration: 0, + layers_complete: 0, + layer_events: Vec::new(), + }) + } +} + +impl Worker for KvConnectorWorker { + fn register_kv_caches( + &mut self, + num_device_blocks: usize, + page_size: usize, + device_id: usize, + dtype_width_bytes: usize, + kv_cache_tensor: Arc, + raw_event_handles: Vec, + ) -> anyhow::Result<()> { + if self.kvbm_worker.get().is_some() { + tracing::warn!("kvbm worker already registered"); + return Err(anyhow::anyhow!("kvbm worker already registered")); + } + + let kv_cache_tensors = vec![kv_cache_tensor as Arc]; + + let config = KvbmWorkerConfig::builder() + .drt(self.drt.clone()) + .num_device_blocks(num_device_blocks) + .page_size(page_size) + .tensors(kv_cache_tensors) + .device_id(device_id) + .dtype_width_bytes(dtype_width_bytes) + .is_fully_contiguous_layout(true) + .barrier_id_prefix(get_barrier_id_prefix()) + .scheduler_client(Some(self.transfer_client.clone())) + .build()?; + + self.layer_events = raw_event_handles; + + let worker = self.drt.runtime().primary().block_on(async move { + let worker = KvbmWorker::new(config).await?; + anyhow::Ok(worker) + })?; + + self.kvbm_worker + .set(worker) + .map_err(|_| anyhow::anyhow!("failed to set kvbm worker"))?; + + Ok(()) + } + + fn bind_connector_meta(&mut self, metadata: Vec) -> anyhow::Result<()> { + let metadata: ConnectorMetadata = serde_json::from_slice(&metadata)?; + self.bound = true; + self.iteration = metadata.iteration; + self.layers_complete = 0; + tracing::debug!( + iteration = self.iteration, + "bound new metadata: {metadata:#?}" + ); + + self.connector.start_next_iteration()?; + + debug_assert_eq!( + self.connector.iteration(), + metadata.iteration, + "iteration mismatch" + ); + + // local actions + // - create a request slot for each new request + // - for each action in the metadata, add the action to the request slot + // - send the list of actions to the engine to track completion + + for slot in metadata.new_slots { + debug_assert!(!self.connector.has_slot(&slot), "slot already exists"); + self.connector.create_slot(slot)?; + } + + let mut onboarding_operations = Vec::new(); + let mut offloading_operations = Vec::new(); + + for operation in metadata.operations { + tracing::debug!( + request_id = operation.request_id, operation_id = %operation.uuid, + "adding operation to slot: {operation:#?}" + ); + + match operation.transfer_type { + TransferType::Load => onboarding_operations.push(operation), + TransferType::Store => offloading_operations.push(operation), + } + } + + debug_assert!( + self.onboarding_operations.is_empty(), + "onboarding operations should be empty" + ); + self.onboarding_operations = onboarding_operations; + + debug_assert!( + self.offloading_operations.is_empty(), + "offloading operations should be empty" + ); + self.offloading_operations = offloading_operations; + + Ok(()) + } + + fn save_kv_layer(&mut self, _layer_idx: usize) -> anyhow::Result<()> { + self.layers_complete += 1; + if self.layers_complete == self.layer_events.len() { + let offloading_operations = std::mem::take(&mut self.offloading_operations); + // block on the the completion of the last layer + // todo(ryan): capture the 
context, pass this to the scheduler to do the await on another thread + // or put the event on a stream and use stream waits to keep it all on device. + event_sync_blocking(self.layer_events[self.layers_complete - 1]); + for operation in offloading_operations { + self.connector.enqueue_request(operation); + } + } + Ok(()) + } + + fn start_load_kv(&mut self) -> anyhow::Result<()> { + let onboarding_operations = self.onboarding_operations.clone(); + for operation in onboarding_operations { + let request_id = operation.request_id.clone(); + self.connector.enqueue_request(operation); + self.maybe_finished_onboarding.insert(request_id); + } + Ok(()) + } + + fn get_finished( + &mut self, + finished_gen_req_ids: Vec, + started_loading_req_ids: Vec, + ) -> (Vec, Vec) { + // we do not have to visit every slot on every pass, just slots we are waiting on + // + // there are two conditions where we would be waiting: + // 1. if we have requested a load, we need to wait for it to complete + // - the load request would come in via the metadata this is processsed in the bind + // 2. if we have requested a finished event, then we need to await for all outstanding + // operations to complete -- either by finishing or being cancelled + // - the finish request is triggered by this function, it is not seen in the metadata + // + // under each scenario, we mark the `maybe_finished_onboarding` and `maybe_finished_offloading` hashsets with + // the request id + // + // on each forward pass we visit the maybe slots to see if they are finished + let mut is_finished_offloading = HashSet::new(); + let mut is_finished_onboarding = HashSet::new(); + + // before we process the maybes, add any newly annotated finished requests + // to the maybe finished set + for request_id in finished_gen_req_ids { + tracing::debug!(request_id, "marking request as finished"); + + if !self.connector.has_slot(&request_id.to_string()) { + tracing::warn!( + request_id, + "finished request received for unknown request_id; assuming never started" + ); + continue; + } + + if self + .maybe_finished_offloading + .contains(&request_id.to_string()) + { + tracing::warn!(request_id, "possibly got a duplicate finished request; request_id already in the maybe_finished_offloading set"); + } else { + tracing::debug!( + request_id, + "received finished request; adding to maybe_finished_offloading set" + ); + self.maybe_finished_offloading + .insert(request_id.to_string()); + } + } + + for request_id in started_loading_req_ids { + tracing::debug!(request_id, "marking request as finished"); + + if !self.connector.has_slot(&request_id.to_string()) { + tracing::warn!( + request_id, + "finished request received for unknown request_id; assuming never started" + ); + continue; + } + + if self + .maybe_finished_onboarding + .contains(&request_id.to_string()) + { + tracing::warn!(request_id, "possibly got a duplicate finished request; request_id already in the maybe_finished_onboarding set"); + } + } + + // visit each request slot in the maybe finished set + for request_id in self.maybe_finished_offloading.iter() { + if self.connector.has_slot(request_id) { + if self.connector.is_complete(request_id) { + tracing::debug!(request_id, "request slot is finished offloading"); + is_finished_offloading.insert(request_id.to_string()); + } else { + tracing::debug!(request_id, "request slot is not finished offloading"); + } + } else { + // made this condition more strict slot existence checks were added as a prerequesite + // to be added to the maybe_finished_offloading 
set. + panic!("request slot missing for {request_id}; however, it was present when added to the maybe finished offloading set"); + } + } + + // remove the finished requests from the maybe finished set + // note: when storing is finished we also remove the request from the engine state + for request_id in &is_finished_offloading { + self.maybe_finished_offloading.remove(request_id); + + // currently chomping the error as the engine is closed and we are shutting down + if self.connector.has_slot(request_id) { + self.connector.remove_slot(request_id); + } else { + tracing::debug!(request_id, "is_finished_offloading: request slot is not found - likely aborted, removing from is finished offloading set"); + } + } + + // visit each request slot in the maybe finished set to see if it is finished + for request_id in self.maybe_finished_onboarding.iter() { + if self.connector.has_slot(request_id) { + if self.connector.is_complete(request_id) { + tracing::debug!(request_id, "request slot is finished onboarding"); + is_finished_onboarding.insert(request_id.clone()); + } else { + tracing::debug!(request_id, "request slot is not finished onboarding"); + } + } else { + panic!("request slot missing for {request_id}; however, it was present when added to the maybe finished onboarding set"); + } + } + + // remove the finished requests from the maybe finished set + for request_id in &is_finished_onboarding { + self.maybe_finished_onboarding.remove(request_id); + if self.connector.has_slot(request_id) { + self.connector.remove_slot(request_id); + } + } + + let finished_offloading: Vec = is_finished_offloading + .iter() + .filter_map(|s| s.parse::().ok()) // parse String -> u64 + .collect(); + + let finished_onboarding: Vec = is_finished_onboarding + .iter() + .filter_map(|s| s.parse::().ok()) // parse String -> u64 + .collect(); + + (finished_offloading, finished_onboarding) + } +} + +#[pyclass] +pub struct PyTrtllmKvConnectorWorker { + connector_worker: Box, +} + +#[pymethods] +impl PyTrtllmKvConnectorWorker { + #[new] + #[pyo3(signature = (py_drt, trtllm_rank))] + pub fn new(py_drt: PyDistributedRuntime, trtllm_rank: String) -> PyResult { + let connector_worker: Box = + Box::new(KvConnectorWorker::new(py_drt, trtllm_rank).map_err(to_pyerr)?); + Ok(Self { connector_worker }) + } + + pub fn register_kv_caches( + &mut self, + num_device_blocks: usize, + page_size: usize, + device_id: usize, + dtype_width_bytes: usize, + kv_cache_tensor: Py, + raw_event_handles: Vec, + ) -> PyResult<()> { + // Convert Python tensor to Rust VllmTensor objects + let rust_kv_cache_tensor = Arc::new(VllmTensor::new(kv_cache_tensor).map_err(to_pyerr)?); + + self.connector_worker + .register_kv_caches( + num_device_blocks, + page_size, + device_id, + dtype_width_bytes, + rust_kv_cache_tensor, + raw_event_handles, + ) + .map_err(to_pyerr) + } + + pub fn bind_connector_meta(&mut self, metadata: Vec) -> PyResult<()> { + self.connector_worker + .bind_connector_meta(metadata) + .map_err(to_pyerr) + } + + pub fn save_kv_layer(&mut self, layer_idx: usize) -> PyResult<()> { + self.connector_worker + .save_kv_layer(layer_idx) + .map_err(to_pyerr) + } + + pub fn start_load_kv(&mut self) -> PyResult<()> { + self.connector_worker.start_load_kv().map_err(to_pyerr) + } + + pub fn get_finished( + &mut self, + finished_gen_req_ids: Vec, + started_loading_req_ids: Vec, + ) -> (Vec, Vec) { + self.connector_worker + .get_finished(finished_gen_req_ids, started_loading_req_ids) + } +} diff --git 
a/lib/bindings/python/rust/llm/block_manager/vllm/connector/worker.rs b/lib/bindings/python/rust/llm/block_manager/vllm/connector/worker.rs index ba25fad0f9..cbf27a5ce9 100644 --- a/lib/bindings/python/rust/llm/block_manager/vllm/connector/worker.rs +++ b/lib/bindings/python/rust/llm/block_manager/vllm/connector/worker.rs @@ -11,7 +11,7 @@ use std::collections::HashSet; use std::sync::{Arc, OnceLock}; use super::*; -use crate::llm::block_manager::distributed::get_barrier_id; +use crate::llm::block_manager::distributed::get_barrier_id_prefix; use crate::{ llm::block_manager::distributed::VllmTensor, to_pyerr, DistributedRuntime as PyDistributedRuntime, @@ -166,7 +166,7 @@ impl Worker for KvConnectorWorker { .tensors(vllm_tensors) .device_id(device_id) .dtype_width_bytes(dtype_width_bytes) - .barrier_id(get_barrier_id()) + .barrier_id_prefix(get_barrier_id_prefix()) .scheduler_client(Some(self.transfer_client.clone())) .build()?; @@ -477,7 +477,7 @@ fn _get_current_context() -> CUcontext { ctx } -fn event_sync_blocking(event: u64) { +pub fn event_sync_blocking(event: u64) { let status = unsafe { cuEventSynchronize(event as CUevent) }; assert_eq!( status, diff --git a/lib/bindings/python/src/dynamo/llm/trtllm_integration/__init__.py b/lib/bindings/python/src/dynamo/llm/trtllm_integration/__init__.py new file mode 100644 index 0000000000..1a8431c3e3 --- /dev/null +++ b/lib/bindings/python/src/dynamo/llm/trtllm_integration/__init__.py @@ -0,0 +1,2 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 diff --git a/lib/bindings/python/src/dynamo/llm/trtllm_integration/connector/__init__.py b/lib/bindings/python/src/dynamo/llm/trtllm_integration/connector/__init__.py new file mode 100644 index 0000000000..f019457936 --- /dev/null +++ b/lib/bindings/python/src/dynamo/llm/trtllm_integration/connector/__init__.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from .kvbm_connector_leader import DynamoKVBMConnectorLeader +from .kvbm_connector_worker import DynamoKVBMConnectorWorker + +__all__ = ["DynamoKVBMConnectorLeader", "DynamoKVBMConnectorWorker"] diff --git a/lib/bindings/python/src/dynamo/llm/trtllm_integration/connector/kvbm_connector_leader.py b/lib/bindings/python/src/dynamo/llm/trtllm_integration/connector/kvbm_connector_leader.py new file mode 100644 index 0000000000..1ecd8406d5 --- /dev/null +++ b/lib/bindings/python/src/dynamo/llm/trtllm_integration/connector/kvbm_connector_leader.py @@ -0,0 +1,133 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + + +from typing import List + +from tensorrt_llm._torch.pyexecutor.kv_cache_connector import ( + KvCacheConnectorScheduler, + SchedulerOutput, +) +from tensorrt_llm.bindings.executor import ExecutorConfig +from tensorrt_llm.bindings.internal.batch_manager import LlmRequest + +from dynamo.llm import KvbmLeader +from dynamo.llm.trtllm_integration.rust import KvbmRequest +from dynamo.llm.trtllm_integration.rust import ( + KvConnectorLeader as RustKvConnectorLeader, +) +from dynamo.llm.trtllm_integration.rust import SchedulerOutput as RustSchedulerOutput +from dynamo.runtime import DistributedRuntime + + +class DynamoKVBMConnectorLeader(KvCacheConnectorScheduler): + def __init__(self, executor_config: ExecutorConfig): + super().__init__(executor_config) + self.drt = DistributedRuntime.detached() + + world_size = self._config.mapping.world_size + self.block_size = self._config.tokens_per_block + + # Set bytes_per_block to 0, because we will retrieve the actual value from the worker side. + leader = KvbmLeader(world_size, drt=self.drt) + + print( + f"KvConnectorLeader initialized with rank: {executor_config.mapping.rank}" + ) + self._connector = RustKvConnectorLeader( + executor_config.mapping.rank, self.drt, self.block_size, leader + ) + + def build_connector_meta(self, scheduler_output: SchedulerOutput) -> bytes: + """ + Build the metadata for the worker. + This is called by the KV Cache Manager when adding a sequence. + Args: + scheduler_output: The data for all inflight requests. + Returns: + The metadata for the workers. + """ + output = RustSchedulerOutput() + + for req in scheduler_output.new_requests: + output.add_new_request( + str(req.request_id), + req.new_tokens, + req.new_block_ids, + req.computed_position + 1, + ) + + resumed_from_preemption = False + for req in scheduler_output.cached_requests: + output.add_cached_request( + str(req.request_id), + resumed_from_preemption, + req.new_tokens, + req.new_block_ids, + req.computed_position + 1, + ) + + return self._connector.build_connector_metadata(output) + + def get_num_new_matched_tokens( + self, request: LlmRequest, num_computed_tokens: int + ) -> tuple[int, bool]: + """ + Get the number of tokens that can be loaded from remote KV cache. + This does not include the tokens already matched on device (indicated by `num_computed_tokens`). + Args: + request: The request to get the number of tokens for. + num_computed_tokens: The number of tokens already matched on device. + Returns: + The number of tokens that can be loaded from remote KV cache. + Whether the tokens will be loaded asynchronously. + """ + self._create_slot(request) + return self._connector.get_num_new_matched_tokens( + str(request.request_id), + len(request.get_tokens(0)), + num_computed_tokens, + ) + + def update_state_after_alloc(self, request: LlmRequest, block_ids: List[int]): + """ + Called after get_num_new_matched_tokens is called to provide the block ids to the scheduler. + Args: + request: The request that was allocated resources. + block_ids: The KV cacheblock IDs that were allocated. + """ + self._connector.update_state_after_alloc( + str(request.request_id), block_ids, request.context_current_position + ) + + def request_finished(self, request: LlmRequest, cache_block_ids: list[int]) -> bool: + """ + Called when a request is finished generating tokens. + Args: + request: The request that finished generating tokens. + Returns: + Whether the request is performing asynchronous saving operations. 
+ If true, this indicates that the kv cache manager should wait to deallocate the blocks until the saving has completed (determined by `get_finished` on the workers). + """ + is_async_saving = self._connector.request_finished( + str(request.request_id), cache_block_ids + ) + return is_async_saving + + def _create_slot(self, request: LlmRequest) -> None: + """Create a slot for the request""" + + if self._connector.has_slot(str(request.request_id)): + return None + + if bool(request.multimodal_positions): + raise ValueError("Unsupported request - requires mm extra keys") + + all_token_ids = request.get_tokens(0) + + # extract the critial aspects of the request that effect how the tokens are hashed + request = KvbmRequest( + request_id=str(request.request_id), lora_name=None, salt_hash=None + ) + + self._connector.create_slot(request, all_token_ids) diff --git a/lib/bindings/python/src/dynamo/llm/trtllm_integration/connector/kvbm_connector_worker.py b/lib/bindings/python/src/dynamo/llm/trtllm_integration/connector/kvbm_connector_worker.py new file mode 100644 index 0000000000..ead3b67fe0 --- /dev/null +++ b/lib/bindings/python/src/dynamo/llm/trtllm_integration/connector/kvbm_connector_worker.py @@ -0,0 +1,127 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import torch +from tensorrt_llm import logger +from tensorrt_llm._torch.pyexecutor.kv_cache_connector import KvCacheConnectorWorker +from tensorrt_llm.bindings.executor import ExecutorConfig + +from dynamo.llm.trtllm_integration.rust import ( + KvConnectorWorker as RustKvConnectorWorker, +) +from dynamo.runtime import DistributedRuntime + + +class DynamoKVBMConnectorWorker(KvCacheConnectorWorker): + def __init__(self, executor_config: ExecutorConfig): + super().__init__(executor_config) + + self.drt = DistributedRuntime.detached() + + self.rank = executor_config.mapping.rank + + self._connector = RustKvConnectorWorker( + self.drt, str(executor_config.mapping.rank) + ) + + def register_kv_caches(self, kv_cache_tensor: torch.Tensor): + """ + Register the KV cache tensors to the worker. + This can be used for something like NIXL registration. + Args: + kv_cache_tensor: The contiguous KV cache tensor. + """ + print(f"Register KV Caches on rank {self.rank}") + logger.info( + f"KvConnectorWorker started registering the kv caches on rank {self._config.mapping.rank}" + ) + + num_device_blocks = kv_cache_tensor.shape[0] + page_size = self._config.tokens_per_block + device_id = kv_cache_tensor.device.index + kv_cache_dtype = kv_cache_tensor.dtype + + num_cache_layers = kv_cache_tensor.shape[1] + self.events = [ + torch.cuda.Event(enable_timing=False, interprocess=False) + for _ in range(num_cache_layers) + ] + + for event in self.events: + event.record(torch.cuda.current_stream(device_id)) + + raw_event_handles = [event.cuda_event for event in self.events] + + self._connector.register_kv_caches( + num_device_blocks, + page_size, + device_id, + kv_cache_dtype.itemsize, + kv_cache_tensor, + raw_event_handles, + ) + + def bind_connector_meta(self, metadata: object): + """Set the connector metadata from the scheduler. + + This function should be called by the model runner every time + before the model execution. The metadata will be used for runtime + KV cache loading and saving. + + Args: + metadata (bytes): the connector metadata. 
+ """ + super().bind_connector_meta(metadata) + self._connector.bind_connector_meta(metadata) + + def start_load_kv(self, stream: torch.cuda.Stream): + """ + Begin loading the KV cache in preparation for the next forward pass. + Specific blocks to transfer are indicated by the scheduler's metadata. + """ + self._connector.start_load_kv() + + def wait_for_save(self, stream: torch.cuda.Stream): + """ + Block until all synchronous saving operations are complete. Called at the end of the forward pass. + """ + pass + + def wait_for_layer_load(self, layer_idx: int, stream: torch.cuda.Stream): + """ + Wait for a layer to finish being loaded before proceeding with the forward pass on the layer. + Note: This function is called immediately before the layer's work is enqueued into the stream. + Args: + layer_idx: The index of the layer to wait for. + stream: The stream the forward pass is being executed on. + """ + pass + + def save_kv_layer(self, layer_idx: int, stream: torch.cuda.Stream): + """ + Begin saving the KV cache for a layer. + Note: This function is called immediately after the layer's work is enqueued into the stream. + Args: + layer_idx: The index of the layer to save. + stream: The stream the forward pass is being executed on. + """ + self.events[layer_idx].record(stream) + self._connector.save_kv_layer(layer_idx) + + def get_finished( + self, finished_gen_req_ids: list[int], started_loading_req_ids: list[int] + ) -> tuple[list[int], list[int]]: + """ + Get the requests that have finished loading and saving. + Args: + finished_gen_req_ids: The IDs of the requests that have finished generating tokens, and are now asynchronously saving. + started_loading_req_ids: The IDs of the requests that have started asynchronously loading. + Returns: + The IDs of the requests that have finished saving. + The IDs of the requests that have finished loading. + Note: IDs may only be returned from this call after they've been provided in the `finished_gen_req_ids` and `started_loading_req_ids` arguments. + Additionally, the runtime will only take action based on these returned IDs once they've been returned by ALL workers. This allows some workers to take longer than others to complete the operations. + """ + return self._connector.get_finished( + finished_gen_req_ids, started_loading_req_ids + ) diff --git a/lib/bindings/python/src/dynamo/llm/trtllm_integration/rust.py b/lib/bindings/python/src/dynamo/llm/trtllm_integration/rust.py new file mode 100644 index 0000000000..cd486e8527 --- /dev/null +++ b/lib/bindings/python/src/dynamo/llm/trtllm_integration/rust.py @@ -0,0 +1,45 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +""" +Loader for the Rust-based TensorRT-LLM integration objects, using objects from _vllm_integration for now +""" + +try: + # TODO: use TRTLLM own integration module + from dynamo._core import _vllm_integration + + # Runtime - dynamically loaded classes from Rust extension + KvbmRequest = getattr(_vllm_integration, "KvbmRequest") + KvbmBlockList = getattr(_vllm_integration, "KvbmBlockList") + BlockState = getattr(_vllm_integration, "BlockState") + BlockStates = getattr(_vllm_integration, "BlockStates") + SlotUpdate = getattr(_vllm_integration, "SlotUpdate") + + KvConnectorWorker = getattr(_vllm_integration, "PyTrtllmKvConnectorWorker") + KvConnectorLeader = getattr(_vllm_integration, "PyTrtllmKvConnectorLeader") + SchedulerOutput = getattr(_vllm_integration, "SchedulerOutput") + +except ImportError: + print( + "Failed to import Dynamo KVBM. TensorRT-LLM integration will not be available." + ) + KvbmRequest = None + KvbmBlockList = None + BlockState = None + BlockStates = None + SlotUpdate = None + KvConnectorWorker = None + KvConnectorLeader = None + SchedulerOutput = None + +__all__ = [ + "KvbmRequest", + "KvbmBlockList", + "BlockState", + "BlockStates", + "SlotUpdate", + "KvConnectorWorker", + "KvConnectorLeader", + "SchedulerOutput", +] diff --git a/lib/llm/src/block_manager/distributed.rs b/lib/llm/src/block_manager/distributed.rs index ee338166eb..9aa72c7b8f 100644 --- a/lib/llm/src/block_manager/distributed.rs +++ b/lib/llm/src/block_manager/distributed.rs @@ -8,7 +8,7 @@ mod zmq; mod leader; mod worker; -pub use leader::{KvbmLeader, KvbmLeaderConfig}; +pub use leader::{KvbmLeader, KvbmLeaderConfig, KvbmLeaderNumBlocksConfig}; pub use transfer::BlockTransferHandler; pub use utils::{ BlockTransferPool, BlockTransferRequest, ConnectorRequestLeader, ConnectorTransferType, @@ -130,21 +130,31 @@ mod tests { vec![Arc::new(MockTensor::new(vec![2, NUM_BLOCKS, 4096]))]; let config = KvbmWorkerConfig::builder() - .barrier_id(barrier_id.clone()) + .barrier_id_prefix(barrier_id.clone()) .num_device_blocks(NUM_BLOCKS) .tensors(tensors) - .worker_id(i) + .device_id(i) .build()?; let worker = KvbmWorker::new(config).await?; workers.push(worker); } + let host_blocks = KvbmLeaderNumBlocksConfig { + cache_size_in_gb: 1.0, + num_blocks_overriden: NUM_BLOCKS, + }; + + let disk_blocks = KvbmLeaderNumBlocksConfig { + cache_size_in_gb: 1.0, + num_blocks_overriden: NUM_BLOCKS, + }; + let leader_config = KvbmLeaderConfig::builder() - .barrier_id(barrier_id) + .barrier_id_prefix(barrier_id) .world_size(num_workers) - .num_host_blocks(NUM_BLOCKS) - .num_disk_blocks(NUM_BLOCKS) + .host_blocks_config(host_blocks) + .disk_blocks_config(disk_blocks) .build()?; // When/if this returns, we know that all the workers were also successful. diff --git a/lib/llm/src/block_manager/distributed/leader.rs b/lib/llm/src/block_manager/distributed/leader.rs index 90d9eb457e..5ec0ad6097 100644 --- a/lib/llm/src/block_manager/distributed/leader.rs +++ b/lib/llm/src/block_manager/distributed/leader.rs @@ -11,10 +11,13 @@ use dynamo_runtime::utils::leader_worker_barrier::LeaderBarrier; use derive_builder::Builder; use serde::{Deserialize, Serialize}; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::Arc; use std::time::Duration; use tokio::sync::oneshot; -use tokio_util::sync::CancellationToken; +use tokio::sync::Notify; +use tokio::sync::OnceCell; +use tokio::time::sleep; /// Data that is sent to workers over ETCD to establish a ZMQ connection. 
#[derive(Debug, Clone, Serialize, Deserialize)] @@ -25,17 +28,31 @@ pub struct KvbmLeaderData { pub num_disk_blocks: usize, } -#[derive(Builder, Clone, Debug)] -pub struct KvbmLeaderConfig { - #[builder(default = "0")] - num_host_blocks: usize, +#[derive(Builder, Clone, Debug, Default)] +pub struct KvbmLeaderNumBlocksConfig { + #[builder(default = "0.0")] + pub cache_size_in_gb: f64, #[builder(default = "0")] - num_disk_blocks: usize, + pub num_blocks_overriden: usize, +} +fn compute_num_blocks( + num_blocks_config: &KvbmLeaderNumBlocksConfig, + bytes_per_block: usize, +) -> usize { + if num_blocks_config.num_blocks_overriden > 0 { + num_blocks_config.num_blocks_overriden + } else { + ((num_blocks_config.cache_size_in_gb * 1_000_000_000.0) / bytes_per_block as f64) as usize + } +} + +#[derive(Builder, Clone, Debug)] +pub struct KvbmLeaderConfig { /// The barrier id to use for syncing with workers. #[builder(default = "String::from(\"kvbm\")")] - barrier_id: String, + barrier_id_prefix: String, /// The world size. #[builder(default = "1")] @@ -47,6 +64,12 @@ pub struct KvbmLeaderConfig { #[builder(setter(strip_option))] drt: Option, + + #[builder(default = "KvbmLeaderNumBlocksConfig::default()")] + host_blocks_config: KvbmLeaderNumBlocksConfig, + + #[builder(default = "KvbmLeaderNumBlocksConfig::default()")] + disk_blocks_config: KvbmLeaderNumBlocksConfig, } impl KvbmLeaderConfig { @@ -55,6 +78,14 @@ impl KvbmLeaderConfig { } } +#[derive(Debug, Default)] +pub struct KvbmLeaderState { + pub num_device_blocks: Arc, + pub num_host_blocks: Arc, + pub num_disk_blocks: Arc, + pub workers_allocation_ready: Arc, +} + /// The leader of the KVBM. /// /// This is responsible for: @@ -62,9 +93,13 @@ impl KvbmLeaderConfig { /// - Syncing the leader barrier with workers. /// - Sending messages to workers. 
pub struct KvbmLeader { - num_device_blocks: usize, - zmq_leader: ZmqActiveMessageLeader, + state: Arc, + zmq_leader: Arc>, config: KvbmLeaderConfig, + //readiness flags + workers_sync_ready: Arc, + workers_sync_ready_notify: Arc, + workers_sync_done: Arc, } impl KvbmLeader { @@ -76,34 +111,106 @@ impl KvbmLeader { } }; - tracing::info!( - "Syncing leader barrier with {} workers on barrier id {}", - config.world_size, - config.barrier_id + let leader_sockets = new_leader_sockets("tcp://127.0.0.1")?; + + let leader = Self { + state: Arc::new(KvbmLeaderState::default()), + zmq_leader: Arc::new(tokio::sync::OnceCell::new()), + config, + workers_sync_ready: Arc::new(AtomicBool::new(false)), + workers_sync_ready_notify: Arc::new(Notify::new()), + workers_sync_done: Arc::new(AtomicBool::new(false)), + }; + + let cancel_token = tokio_util::sync::CancellationToken::new(); + leader.spawn_barrier_task( + drt, + leader_sockets.pub_url.clone(), + leader_sockets.ack_url.clone(), ); + leader.spawn_zmq_task(leader_sockets, cancel_token); - let leader_sockets = new_leader_sockets("tcp://127.0.0.1")?; + Ok(leader) + } + + fn spawn_barrier_task( + &self, + drt: DistributedRuntime, + leader_sockets_pub_url: String, + leader_sockets_ack_url: String, + ) { + let state = self.state.clone(); + let leader_config = self.config.clone(); + let ready = Arc::clone(&self.workers_sync_ready); + let notify = Arc::clone(&self.workers_sync_ready_notify); + let done = Arc::clone(&self.workers_sync_done); + + tokio::spawn(async move { + match KvbmLeader::run_barrier_sync( + drt, + leader_sockets_pub_url, + leader_sockets_ack_url, + leader_config, + ) + .await + { + Ok((num_device_blocks, num_host_blocks, num_disk_blocks)) => { + // write back results + state + .num_device_blocks + .store(num_device_blocks, Ordering::Release); + state + .num_host_blocks + .store(num_host_blocks, Ordering::Release); + state + .num_disk_blocks + .store(num_disk_blocks, Ordering::Release); + ready.store(true, Ordering::Release); + done.store(true, Ordering::Release); + notify.notify_waiters(); + } + Err(e) => { + tracing::error!("Barrier sync failed: {e:?}"); + done.store(true, Ordering::Release); + notify.notify_waiters(); + } + } + }); + } - let zmq_data = Arc::new(KvbmLeaderData { - pub_url: leader_sockets.pub_url.clone(), - ack_url: leader_sockets.ack_url.clone(), - num_host_blocks: config.num_host_blocks, - num_disk_blocks: config.num_disk_blocks, + async fn run_barrier_sync( + drt: DistributedRuntime, + leader_sockets_pub_url: String, + leader_sockets_ack_url: String, + leader_config: KvbmLeaderConfig, + ) -> anyhow::Result<(usize, usize, usize)> { + let barrier_id_worker_to_leader = + format!("{}{}", leader_config.barrier_id_prefix, "-worker-to-leader"); + tracing::info!( + "Syncing leader barrier with {} workers on barrier id {}", + leader_config.world_size, + barrier_id_worker_to_leader + ); + let zmq_data_worker_to_leader: Arc = Arc::new(KvbmLeaderData { + pub_url: leader_sockets_pub_url.clone(), + ack_url: leader_sockets_ack_url.clone(), + num_host_blocks: 0, // doesn't matter for worker to leader sync + num_disk_blocks: 0, // doesn't matter for worker to leader sync }); // Build our leader barrier and publish the data. 
// TODO: Use a separate timeout parameter from the ZMQ connection timeout - let leader_barrier: LeaderBarrier = + let worker_to_leader_barrier: LeaderBarrier = LeaderBarrier::new( - config.barrier_id.clone(), - config.world_size, - Some(Duration::from_secs(config.leader_init_timeout_secs)), + barrier_id_worker_to_leader.clone(), + leader_config.world_size, + Some(Duration::from_secs(leader_config.leader_init_timeout_secs)), ); - let worker_data = leader_barrier - .sync(&drt, zmq_data.as_ref()) + let worker_data = worker_to_leader_barrier + .sync(&drt, zmq_data_worker_to_leader.as_ref()) .await - .map_err(|e| anyhow::anyhow!("Failed to sync leader barrier: {:?}", e))?; + .map_err(|e| anyhow::anyhow!("Failed to sync worker to leader barrier: {:?}", e))?; let num_device_blocks = worker_data .values() @@ -111,46 +218,190 @@ impl KvbmLeader { .min() .unwrap(); - tracing::info!("Leader barrier synced with {} workers", config.world_size); + let bytes_per_block: usize = worker_data.values().map(|d| d.bytes_per_block).sum(); + + assert!( + bytes_per_block > 0, + "bytes_per_block must be greater than 0" + ); + + tracing::info!( + "Worker to leader barrier synced with {} workers", + leader_config.world_size + ); tracing::debug!("Worker data: {:?}", worker_data); - // Now, create our active message leader. - // This also blocks until a ZMQ connection has been established. - let cancel_token = CancellationToken::new(); - let zmq_leader = ZmqActiveMessageLeader::new( - leader_sockets, - config.world_size, - Duration::from_secs(config.leader_init_timeout_secs), - cancel_token.clone(), - ) - .await?; - - Ok(Self { - num_device_blocks, - zmq_leader, - config, - }) + let num_host_blocks = + compute_num_blocks(&leader_config.host_blocks_config, bytes_per_block); + let num_disk_blocks = + compute_num_blocks(&leader_config.disk_blocks_config, bytes_per_block); + + // Start the second sync to transfer num_host_blocks and num_disk_blocks to worker + let barrier_id_leader_to_worker = + format!("{}{}", leader_config.barrier_id_prefix, "-leader-to-worker"); + tracing::info!( + "Syncing leader barrier with {} workers on barrier id {}", + leader_config.world_size, + barrier_id_leader_to_worker + ); + + let zmq_data_leader_to_worker = Arc::new(KvbmLeaderData { + pub_url: leader_sockets_pub_url.clone(), + ack_url: leader_sockets_ack_url.clone(), + num_host_blocks, + num_disk_blocks, + }); + + let leader_to_worker_barrier: LeaderBarrier = + LeaderBarrier::new( + barrier_id_leader_to_worker.clone(), + leader_config.world_size, + Some(Duration::from_secs(leader_config.leader_init_timeout_secs)), + ); + + let _worker_data = leader_to_worker_barrier + .sync(&drt, zmq_data_leader_to_worker.as_ref()) + .await + .map_err(|e| anyhow::anyhow!("Failed to sync leader to worker barrier: {:?}", e))?; + + tracing::info!( + "Worker to leader barrier synced with {} workers", + leader_config.world_size + ); + Ok((num_device_blocks, num_host_blocks, num_disk_blocks)) + } + + fn spawn_zmq_task( + &self, + leader_sockets: LeaderSockets, + cancel: tokio_util::sync::CancellationToken, + ) { + let cell = self.zmq_leader.clone(); + let state = self.state.clone(); + let world_size = self.config.world_size; + let timeout = self.config.leader_init_timeout_secs; + + tokio::spawn(async move { + let res = ZmqActiveMessageLeader::new( + leader_sockets, + world_size, + std::time::Duration::from_secs(timeout), + cancel, + ) + .await; + + match res { + Ok(zmq) => { + let _ = cell.set(zmq); + // mark ready + state + .workers_allocation_ready + 
.store(true, Ordering::Release); + } + Err(e) => { + tracing::error!("ZMQ init failed: {e:?}"); + } + } + }); + } + + pub fn spawn_leader_readiness_barrier(&self, drt: DistributedRuntime) { + let leader_config = self.config.clone(); + let handle = drt.runtime().primary(); + handle.spawn(async move { + match KvbmLeader::run_leader_readiness(drt, leader_config).await { + Ok(()) => { + tracing::info!("leader readiness barrier synced!"); + } + Err(e) => { + tracing::error!("leader readiness barrier failed: {e:?}"); + } + } + }); + } + + async fn run_leader_readiness( + drt: DistributedRuntime, + leader_config: KvbmLeaderConfig, + ) -> anyhow::Result<()> { + let barrier_id_leader_ready = + format!("{}{}", leader_config.barrier_id_prefix, "-leader-ready"); + tracing::info!( + "Syncing leader readiness barrier with {} workers on barrier id {}", + leader_config.world_size, + barrier_id_leader_ready + ); + + let leader_readiness_barrier: LeaderBarrier<(), ()> = LeaderBarrier::new( + barrier_id_leader_ready.clone(), + leader_config.world_size, + Some(Duration::from_secs(leader_config.leader_init_timeout_secs)), + ); + + let _ = leader_readiness_barrier + .sync(&drt, &()) + .await + .map_err(|e| { + anyhow::anyhow!("Failed to sync leader readiness barrier on leader: {:?}", e) + })?; + + Ok(()) } pub async fn transfer_blocks_request( &self, request: BlockTransferRequest, ) -> anyhow::Result> { + let zmq = self + .zmq_leader + .get() + .ok_or_else(|| anyhow::anyhow!("ZMQ leader not ready"))?; let data = vec![serde_json::to_vec(&request)?]; - self.zmq_leader - .broadcast(ZMQ_TRANSFER_BLOCKS_MESSAGE, data) - .await + zmq.broadcast(ZMQ_TRANSFER_BLOCKS_MESSAGE, data).await + } + + pub fn is_worker_sync_ready(&self) -> bool { + self.workers_sync_ready.load(Ordering::Acquire) + } + + pub fn is_worker_sync_done(&self) -> bool { + self.workers_sync_done.load(Ordering::Acquire) } pub fn num_device_blocks(&self) -> usize { - self.num_device_blocks + self.state.num_device_blocks.load(Ordering::Acquire) } pub fn num_host_blocks(&self) -> usize { - self.config.num_host_blocks + self.state.num_host_blocks.load(Ordering::Acquire) } pub fn num_disk_blocks(&self) -> usize { - self.config.num_disk_blocks + self.state.num_disk_blocks.load(Ordering::Acquire) + } + + pub async fn wait_worker_sync_ready(&self) -> bool { + if self.is_worker_sync_ready() { + return true; + } + if self.is_worker_sync_done() { + return false; + } + + let notified = self.workers_sync_ready_notify.notified(); + if self.is_worker_sync_ready() { + return true; + } + if self.is_worker_sync_done() { + return false; + } + + // bounded wait + tokio::select! 
{ + _ = notified => { + self.is_worker_sync_ready() + } + _ = sleep(Duration::from_secs(self.config.leader_init_timeout_secs)) => false, + } } } diff --git a/lib/llm/src/block_manager/distributed/worker.rs b/lib/llm/src/block_manager/distributed/worker.rs index fc4c9a8232..3ce729f4ab 100644 --- a/lib/llm/src/block_manager/distributed/worker.rs +++ b/lib/llm/src/block_manager/distributed/worker.rs @@ -35,6 +35,7 @@ use dynamo_runtime::{ #[derive(Debug, Clone, Serialize, Deserialize)] pub struct KvbmWorkerData { pub num_device_blocks: usize, + pub bytes_per_block: usize, } pub fn load_and_validate_tensors( @@ -82,7 +83,7 @@ pub fn load_and_validate_tensors( Ok((device_tensors, shape.unwrap())) } -#[derive(Builder)] +#[derive(Builder, Clone)] #[builder(pattern = "owned")] pub struct KvbmWorkerConfig { drt: DistributedRuntime, @@ -101,8 +102,11 @@ pub struct KvbmWorkerConfig { #[builder(default = "2")] dtype_width_bytes: usize, + #[builder(default = false)] + is_fully_contiguous_layout: bool, + #[builder(default = "String::from(\"kvbm\")")] - barrier_id: String, + barrier_id_prefix: String, #[builder(default = "None")] scheduler_client: Option, @@ -153,37 +157,57 @@ impl KvbmWorker { ))); } - let (outer_contiguous, outer_dim) = if shape[0] >= config.num_device_blocks { - (false, shape[1]) - } else if shape[1] >= config.num_device_blocks { - (true, shape[0]) + let layout_type: LayoutType; + let mut outer_dim = 1; + let num_layers; + let inner_dim; + if !config.is_fully_contiguous_layout { + let (outer_contiguous, outer_dim) = if shape[0] >= config.num_device_blocks { + (false, shape[1]) + } else if shape[1] >= config.num_device_blocks { + (true, shape[0]) + } else { + return Err(anyhow::anyhow!(format!( + "Unsupported kv cache layout. Got shape: {:?}", + shape + ))); + }; + layout_type = LayoutType::LayerSeparate { outer_contiguous }; + num_layers = device_tensors.len(); + inner_dim = shape[2..].iter().product::() / config.page_size; + + tracing::info!( + "Inferred layout: num_layers={}, outer_dim={}, page_size={}, inner_dim={}", + device_tensors.len(), + outer_dim, + config.page_size, + inner_dim + ); } else { - return Err(anyhow::anyhow!(format!( - "Unsupported kv cache layout. Got shape: {:?}", - shape - ))); - }; - - let inner_dim = shape[2..].iter().product::() / config.page_size; + layout_type = LayoutType::FullyContiguous; + num_layers = shape[1]; + outer_dim = shape[2]; + inner_dim = shape[3..].iter().product::() / config.page_size; + tracing::info!( + "Inferred layout: num_layers={}, outer_dim={}, page_size={}, inner_dim={}", + num_layers, + outer_dim, + config.page_size, + inner_dim + ); + } - tracing::info!( - "Inferred layout: num_layers={}, outer_dim={}, page_size={}, inner_dim={}", - device_tensors.len(), - outer_dim, - config.page_size, - inner_dim - ); + let bytes_per_block = + num_layers * outer_dim * config.page_size * inner_dim * config.dtype_width_bytes; let mut layout_builder_instance = LayoutConfigBuilder::default(); let layout_builder = layout_builder_instance - .num_layers(device_tensors.len()) + .num_layers(num_layers) .outer_dim(outer_dim) .page_size(config.page_size) .inner_dim(inner_dim) .dtype_width_bytes(config.dtype_width_bytes); - let layout_type = LayoutType::LayerSeparate { outer_contiguous }; - let device_layout = layout_builder .num_blocks(config.num_device_blocks) .build()? 
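The hunk above infers the KV cache layout from the registered tensor shape and derives `bytes_per_block`, which the leader later uses to size the host and disk pools. A minimal sketch of that arithmetic for the fully contiguous case, assuming a hypothetical shape and fp16 dtype width (the helper name and all numbers are illustrative, not taken from the patch):

```python
# Sketch only: mirrors the fully-contiguous branch of KvbmWorker's layout inference.
# shape = [num_blocks, num_layers, outer_dim, page_size, inner dims...]
def infer_fully_contiguous(shape, page_size, dtype_width_bytes):
    num_layers = shape[1]
    outer_dim = shape[2]
    inner_dim = 1
    for d in shape[3:]:
        inner_dim *= d
    inner_dim //= page_size  # shape[3..] still contains the page_size axis
    bytes_per_block = num_layers * outer_dim * page_size * inner_dim * dtype_width_bytes
    return num_layers, outer_dim, inner_dim, bytes_per_block

# Hypothetical: 4096 blocks, 32 layers, K+V, 64 tokens/block, 1024 inner dim, fp16
print(infer_fully_contiguous([4096, 32, 2, 64, 1024], page_size=64, dtype_width_bytes=2))
# (32, 2, 1024, 8388608) -> roughly 8 MiB per block
```

For the layer-separate case the same product is taken from the per-layer tensor shape, with `outer_dim` read from whichever leading axis is not the block axis.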
@@ -195,20 +219,42 @@ impl KvbmWorker { // let scheduler = KvbmWorkerScheduler::new(config.scheduler.clone()); let cancel_token = config.drt.primary_token().clone(); + // barrier sync with leader to get the leader data + let leader_data = tokio::task::block_in_place(|| { + // This is now synchronous blocking code + // We need a separate current-thread runtime to block_on async calls here + let rt = tokio::runtime::Handle::current(); + rt.block_on(async { + KvbmWorker::leader_barrier_sync( + config.clone(), + cancel_token.clone(), + bytes_per_block, + ) + .await + }) + })?; + // establish a oneshot channel to get back the raw BlockTransferHandler let (handler_tx, handler_rx) = oneshot::channel(); + // establish a oneshot channel to block on the main routine to wait for layout allocation readiness + let (layout_ready_tx, layout_ready_rx) = oneshot::channel::(); + let scheduler_client = config.scheduler_client.clone(); + let worker_config = config.clone(); + // start background worker task to do layout allocation for host or disk let task = CriticalTaskExecutionHandle::new( move |cancel_token| { KvbmWorker::worker_task( device_layout, layout_builder_clone, + leader_data, layout_type, - config, + worker_config, cancel_token, handler_tx, + layout_ready_tx, scheduler_client, ) }, @@ -216,6 +262,24 @@ impl KvbmWorker { "kvbm-worker-task", )?; + let worker_config = config.clone(); + let cancel_for_barrier = cancel_token.clone(); + // wait until the leader finished the initialization of all components + tokio::task::block_in_place(|| { + // This is now synchronous blocking code + // We need a separate current-thread runtime to block_on async calls here + let rt = tokio::runtime::Handle::current(); + rt.block_on(async { + KvbmWorker::leader_readiness_sync(worker_config, cancel_for_barrier).await + }) + })?; + + // waiting for the worker layout allocation ready + match layout_ready_rx.await { + Ok(_) => tracing::info!("worker layout allocation finished."), + Err(_) => tracing::error!("Worker layout dropped without sending"), + } + Ok(Self { task: Some(task), block_transfer_handler_rx: Some(handler_rx), @@ -248,15 +312,11 @@ impl KvbmWorker { Ok(blocks) } - async fn worker_task( - device_layout: Box>, - mut layout_builder: LayoutConfigBuilder, - layout_type: LayoutType, + async fn leader_barrier_sync( config: KvbmWorkerConfig, cancel_token: CancellationToken, - handler_tx: oneshot::Sender, - scheduler_client: Option, - ) -> anyhow::Result<()> { + bytes_per_block: usize, + ) -> anyhow::Result { let drt = config.drt.clone(); let worker_id = drt @@ -266,30 +326,63 @@ impl KvbmWorker { ))? .id() as usize; + let barrier_id_worker_to_leader = + format!("{}{}", config.barrier_id_prefix, "-worker-to-leader"); tracing::info!( "Worker {} waiting on barrier {}", worker_id, - config.barrier_id + barrier_id_worker_to_leader ); - let worker_barrier = WorkerBarrier::::new( - config.barrier_id, + let worker_to_leader_barrier = WorkerBarrier::::new( + barrier_id_worker_to_leader, worker_id.to_string(), ); let worker_data = KvbmWorkerData { num_device_blocks: config.num_device_blocks, + bytes_per_block, }; + // leader_data is not important in the worker to leader phase + let _leader_data = tokio::select! 
{ + _ = cancel_token.cancelled() => { + return Err(anyhow::anyhow!("Cancelled")) + } + _leader_data = worker_to_leader_barrier.sync(&drt, &worker_data) => { + _leader_data + } + } + .map_err(|e| anyhow::anyhow!("Failed to sync worker to leader barrier: {:?}", e))?; + + tracing::debug!( + "Worker {} received leader data: {:?} in worker to leader phase", + worker_id, + _leader_data + ); + + let barrier_id_leader_to_worker = + format!("{}{}", config.barrier_id_prefix, "-leader-to-worker"); + tracing::info!( + "Worker {} waiting on barrier {}", + worker_id, + barrier_id_leader_to_worker + ); + + let leader_to_worker_barrier = WorkerBarrier::::new( + barrier_id_leader_to_worker, + worker_id.to_string(), + ); + let leader_data = tokio::select! { _ = cancel_token.cancelled() => { - return Ok(()) + return Err(anyhow::anyhow!("Cancelled")) } - leader_data = worker_barrier.sync(&drt, &worker_data) => { + leader_data = leader_to_worker_barrier.sync(&drt, &worker_data) => { leader_data } } - .map_err(|e| anyhow::anyhow!("Failed to sync worker barrier: {:?}", e))?; + .map_err(|e| anyhow::anyhow!("Failed to sync worker to leader barrier: {:?}", e))?; tracing::info!( "Worker {} received leader data: {:?}", @@ -297,6 +390,67 @@ impl KvbmWorker { leader_data ); + Ok(leader_data) + } + + async fn leader_readiness_sync( + config: KvbmWorkerConfig, + cancel_token: CancellationToken, + ) -> anyhow::Result<()> { + let drt = config.drt.clone(); + + let worker_id = drt + .primary_lease() + .ok_or(anyhow::anyhow!( + "unable to get primary lease; check that drt is not static" + ))? + .id() as usize; + + let barrier_id_leader_readiness = + format!("{}{}", config.barrier_id_prefix, "-leader-ready"); + tracing::info!( + "Worker {} waiting on barrier {}", + worker_id, + barrier_id_leader_readiness + ); + + let leader_readiness_barrier = + WorkerBarrier::<(), ()>::new(barrier_id_leader_readiness, worker_id.to_string()); + + // leader_data is not important in the leader readiness case + tokio::select! { + _ = cancel_token.cancelled() => { + return Err(anyhow::anyhow!("Cancelled")) + } + _leader_data = leader_readiness_barrier.sync(&drt, &()) => { + _leader_data + } + } + .map_err(|e| anyhow::anyhow!("Failed to sync leader readiness barrier: {:?}", e))?; + + Ok(()) + } + + async fn worker_task( + device_layout: Box>, + mut layout_builder: LayoutConfigBuilder, + leader_data: KvbmLeaderData, + layout_type: LayoutType, + config: KvbmWorkerConfig, + cancel_token: CancellationToken, + handler_tx: oneshot::Sender, + layout_ready_tx: oneshot::Sender, + scheduler_client: Option, + ) -> anyhow::Result<()> { + let drt = config.drt.clone(); + + let worker_id = drt + .primary_lease() + .ok_or(anyhow::anyhow!( + "unable to get primary lease; check that drt is not static" + ))? + .id() as usize; + let agent = build_agent(worker_id, leader_data.num_disk_blocks > 0)?; let transfer_context = Arc::new(TransferContext::new( @@ -380,6 +534,10 @@ impl KvbmWorker { cancel_token.clone(), )?; + if layout_ready_tx.send("finished".to_string()).is_err() { + tracing::error!("worker receiver dropped before result was sent"); + } + // TODO: Some sort of fancy loop here. // For now, just wait for cancellation. 
cancel_token.cancelled().await; diff --git a/lib/llm/src/block_manager/pool/managed.rs b/lib/llm/src/block_manager/pool/managed.rs index ab6dce3856..aca3c5e3a2 100644 --- a/lib/llm/src/block_manager/pool/managed.rs +++ b/lib/llm/src/block_manager/pool/managed.rs @@ -456,6 +456,7 @@ impl BlockPool &self, sequence_hashes: &[SequenceHash], ) -> BlockPoolResult> { + tracing::debug!("find matching for sequence_hashes: {:?}", sequence_hashes); self._match_sequence_hashes(sequence_hashes)? .blocking_recv() .map_err(|_| BlockPoolError::ProgressEngineShutdown)? From 1bdfd374b303b70fd047d108562727a8ae38b009 Mon Sep 17 00:00:00 2001 From: richardhuo-nv Date: Thu, 21 Aug 2025 23:35:34 -0700 Subject: [PATCH 02/17] fix computed tokens fix fix position fix fix fix fix fix fix Signed-off-by: richardhuo-nv --- .../rust/llm/block_manager/vllm/connector/leader/slot.rs | 8 ++++++-- .../llm/block_manager/vllm/connector/trtllm_leader.rs | 8 ++++---- .../trtllm_integration/connector/kvbm_connector_leader.py | 4 ++-- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader/slot.rs b/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader/slot.rs index ccca6f94d7..56e18b0d58 100644 --- a/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader/slot.rs +++ b/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader/slot.rs @@ -593,13 +593,17 @@ impl Slot for VllmConnectorSlot { // in onborading case if computed_position < self.current_position { tracing::debug!( - "computed_position={} <= current_position={}, so we are onboarding during prefilling phase", + "computed_position={} < current_position={}, so we are onboarding during prefilling phase", computed_position, self.current_position ); return Ok(()); } // now we decide what we should do for the new computed tokens + tracing::debug!( + "applying scheduler output, computed_position={}, sequence_total_tokens={}", + computed_position, self.sequence.total_tokens() + ); if computed_position < self.sequence.total_tokens() { // no need to apply new tokens, since it's applied when created the slot during prefilling @@ -621,7 +625,7 @@ impl Slot for VllmConnectorSlot { } let num_candidate_blocks = - ((computed_position + 1) / self.block_size) - self.evaluated_blocks; + (computed_position / self.block_size) - self.evaluated_blocks; if num_candidate_blocks != 0 { // do we have a mechanism for skipping gpu cache hit blocks? not sure yet. 
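The computed-position bookkeeping above, together with the follow-up tweak later in this series, hinges on whether `computed_position` is a 0-based index of the last computed token or a count of computed tokens. A small sketch with made-up numbers, assuming a block size of 4:

```python
# Sketch only: counting fully computed blocks from a computed position.
block_size = 4

def full_blocks_from_index(last_token_index: int) -> int:
    # the position is a 0-based index, so the prefix length is index + 1
    return (last_token_index + 1) // block_size

def full_blocks_from_count(num_computed_tokens: int) -> int:
    return num_computed_tokens // block_size

assert full_blocks_from_index(7) == 2   # tokens 0..7 cover two full blocks
assert full_blocks_from_count(8) == 2   # the same prefix expressed as a count
assert full_blocks_from_index(6) == 1   # a partial block never counts
```

The number of new candidate blocks to evaluate for offload is then this value minus `evaluated_blocks`, as in the slot code.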
diff --git a/lib/bindings/python/rust/llm/block_manager/vllm/connector/trtllm_leader.rs b/lib/bindings/python/rust/llm/block_manager/vllm/connector/trtllm_leader.rs index bb8d15cccd..012a1691a5 100644 --- a/lib/bindings/python/rust/llm/block_manager/vllm/connector/trtllm_leader.rs +++ b/lib/bindings/python/rust/llm/block_manager/vllm/connector/trtllm_leader.rs @@ -210,7 +210,7 @@ impl Leader for KvConnectorLeader { block_ids: Vec, context_current_position: usize, ) -> anyhow::Result<()> { - tracing::debug!(request_id, "num_device_blocks: {}", block_ids.len(),); + tracing::debug!(request_id, "num_device_blocks: {}, context_current_position: {}", block_ids.len(), context_current_position); let shared_slot = self.slot_manager().get_slot(&request_id)?; let mut slot = shared_slot @@ -227,7 +227,7 @@ impl Leader for KvConnectorLeader { .get(&request_id) { if num_external_tokens > 0 { - let num_computed_tokens = (context_current_position + 1) - num_external_tokens; + let num_computed_tokens = context_current_position - num_external_tokens; slot.record_cached_device_tokens(num_computed_tokens); slot.advance_computed_position(num_computed_tokens)?; @@ -317,7 +317,7 @@ impl Leader for KvConnectorLeader { slot.apply_scheduler_output_with_computed_position( &new_req.prompt_token_ids, &new_req.block_ids, - new_req.num_computed_tokens - 1, + new_req.num_computed_tokens, )?; if let Some(pending_ops) = slot.take_pending_operations() { @@ -347,7 +347,7 @@ impl Leader for KvConnectorLeader { slot.apply_scheduler_output_with_computed_position( &cached_req.new_token_ids, &cached_req.new_block_ids, - cached_req.num_computed_tokens - 1, + cached_req.num_computed_tokens, )?; if let Some(pending_ops) = slot.take_pending_operations() { diff --git a/lib/bindings/python/src/dynamo/llm/trtllm_integration/connector/kvbm_connector_leader.py b/lib/bindings/python/src/dynamo/llm/trtllm_integration/connector/kvbm_connector_leader.py index 1ecd8406d5..336925af2e 100644 --- a/lib/bindings/python/src/dynamo/llm/trtllm_integration/connector/kvbm_connector_leader.py +++ b/lib/bindings/python/src/dynamo/llm/trtllm_integration/connector/kvbm_connector_leader.py @@ -54,7 +54,7 @@ def build_connector_meta(self, scheduler_output: SchedulerOutput) -> bytes: str(req.request_id), req.new_tokens, req.new_block_ids, - req.computed_position + 1, + req.computed_position, ) resumed_from_preemption = False @@ -64,7 +64,7 @@ def build_connector_meta(self, scheduler_output: SchedulerOutput) -> bytes: resumed_from_preemption, req.new_tokens, req.new_block_ids, - req.computed_position + 1, + req.computed_position, ) return self._connector.build_connector_metadata(output) From 85a75638fcb677fc08f1f7b6bfbea164169c77ce Mon Sep 17 00:00:00 2001 From: richardhuo-nv Date: Sat, 23 Aug 2025 14:10:45 -0700 Subject: [PATCH 03/17] fix layout Signed-off-by: richardhuo-nv --- .../llm/block_manager/vllm/connector/leader/slot.rs | 10 +++++----- lib/llm/src/block_manager/layout.rs | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader/slot.rs b/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader/slot.rs index 56e18b0d58..1c8fbdb816 100644 --- a/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader/slot.rs +++ b/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader/slot.rs @@ -625,7 +625,7 @@ impl Slot for VllmConnectorSlot { } let num_candidate_blocks = - (computed_position / self.block_size) - self.evaluated_blocks; + ((computed_position + 1) / 
self.block_size) - self.evaluated_blocks; if num_candidate_blocks != 0 { // do we have a mechanism for skipping gpu cache hit blocks? not sure yet. @@ -1334,10 +1334,10 @@ async fn process_offload_request( // 4. Wait for the offload request to complete match notify_receiver.await { Ok(_) => { - tracing::debug!("Transfer completed successfully"); + tracing::debug!("Offloading transfer completed successfully"); } Err(_) => { - return Err(anyhow::anyhow!("Transfer completion notification failed")); + return Err(anyhow::anyhow!("Offloading transfer completion notification failed")); } } tracing::debug!( @@ -1408,10 +1408,10 @@ async fn process_onboard_request( match notify_receiver.await { Ok(_) => { - tracing::debug!("Transfer completed successfully"); + tracing::debug!("Onboarding transfer completed successfully"); } Err(_) => { - return Err(anyhow::anyhow!("Transfer completion notification failed")); + return Err(anyhow::anyhow!("Onboarding transfer completion notification failed")); } } diff --git a/lib/llm/src/block_manager/layout.rs b/lib/llm/src/block_manager/layout.rs index 6032c415c2..53f347b9d0 100644 --- a/lib/llm/src/block_manager/layout.rs +++ b/lib/llm/src/block_manager/layout.rs @@ -334,9 +334,9 @@ impl FullyContiguousConfig { config.validate()?; let alignment = config.alignment; - let memory_region_size = config.page_size * config.inner_dim * config.dtype_width_bytes; - let outer_dim_stride_in_bytes = memory_region_size; + let outer_dim_stride_in_bytes = config.page_size * config.inner_dim * config.dtype_width_bytes; let layer_stride_in_bytes = outer_dim_stride_in_bytes * config.outer_dim; + let memory_region_size = layer_stride_in_bytes; let natural_block_stride = config.num_layers * layer_stride_in_bytes; let block_stride_in_bytes = if alignment > 1 { From af20c6f1a2bde6228a47984278b73a03adfa1914 Mon Sep 17 00:00:00 2001 From: richardhuo-nv Date: Mon, 25 Aug 2025 14:54:33 -0700 Subject: [PATCH 04/17] fmt and rebase Signed-off-by: richardhuo-nv --- .../vllm/connector/leader/slot.rs | 15 ++++++++---- .../vllm/connector/trtllm_leader.rs | 23 ++++++++++++++++++- .../src/block_manager/distributed/leader.rs | 4 ++-- lib/llm/src/block_manager/layout.rs | 3 ++- 4 files changed, 36 insertions(+), 9 deletions(-) diff --git a/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader/slot.rs b/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader/slot.rs index 1c8fbdb816..90163e4604 100644 --- a/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader/slot.rs +++ b/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader/slot.rs @@ -601,9 +601,10 @@ impl Slot for VllmConnectorSlot { // now we decide what we should do for the new computed tokens tracing::debug!( - "applying scheduler output, computed_position={}, sequence_total_tokens={}", - computed_position, self.sequence.total_tokens() - ); + "applying scheduler output, computed_position={}, sequence_total_tokens={}", + computed_position, + self.sequence.total_tokens() + ); if computed_position < self.sequence.total_tokens() { // no need to apply new tokens, since it's applied when created the slot during prefilling @@ -1337,7 +1338,9 @@ async fn process_offload_request( tracing::debug!("Offloading transfer completed successfully"); } Err(_) => { - return Err(anyhow::anyhow!("Offloading transfer completion notification failed")); + return Err(anyhow::anyhow!( + "Offloading transfer completion notification failed" + )); } } tracing::debug!( @@ -1411,7 +1414,9 @@ async fn process_onboard_request( 
tracing::debug!("Onboarding transfer completed successfully"); } Err(_) => { - return Err(anyhow::anyhow!("Onboarding transfer completion notification failed")); + return Err(anyhow::anyhow!( + "Onboarding transfer completion notification failed" + )); } } diff --git a/lib/bindings/python/rust/llm/block_manager/vllm/connector/trtllm_leader.rs b/lib/bindings/python/rust/llm/block_manager/vllm/connector/trtllm_leader.rs index 012a1691a5..43305b9d62 100644 --- a/lib/bindings/python/rust/llm/block_manager/vllm/connector/trtllm_leader.rs +++ b/lib/bindings/python/rust/llm/block_manager/vllm/connector/trtllm_leader.rs @@ -10,6 +10,8 @@ use crate::llm::block_manager::BlockManagerBuilder; use crate::llm::block_manager::{distributed::KvbmLeader as PyKvbmLeader, vllm::KvbmRequest}; use crate::DistributedRuntime as PyDistributedRuntime; use anyhow; +use dynamo_llm::block_manager::metrics_kvbm::KvbmMetrics; +use dynamo_runtime::metrics::prometheus_names::kvbm_connector; use std::collections::HashSet; use std::sync::{Arc, OnceLock}; use tokio::runtime::Handle; @@ -55,6 +57,7 @@ pub struct KvConnectorLeader { onboarding_slots: HashSet, iteration_counter: u64, inflight_request_to_num_external_tokens: HashMap, + kvbm_metrics: KvbmMetrics, } impl KvConnectorLeader { @@ -73,6 +76,13 @@ impl KvConnectorLeader { let drt = drt.inner().clone(); let handle: Handle = drt.runtime().primary(); + let ns = drt + .namespace(kvbm_connector::KVBM_CONNECTOR_LEADER) + .unwrap(); + + let kvbm_metrics = KvbmMetrics::new(&ns); + let kvbm_metrics_clone = kvbm_metrics.clone(); + let slot_manager_cell = Arc::new(OnceLock::new()); { @@ -107,6 +117,7 @@ impl KvConnectorLeader { block_manager.get_block_manager().clone(), leader.clone(), drt.clone(), + kvbm_metrics_clone.clone(), ); let _ = slot_manager_cell.set(sm); @@ -125,6 +136,7 @@ impl KvConnectorLeader { onboarding_slots: HashSet::new(), iteration_counter: 0, inflight_request_to_num_external_tokens: HashMap::new(), + kvbm_metrics, } } } @@ -195,6 +207,10 @@ impl Leader for KvConnectorLeader { // Add to the map so that onboarding can be triggered in update_state_after_alloc. 
self.inflight_request_to_num_external_tokens .insert(request_id, num_external_tokens); + + self.kvbm_metrics + .matched_tokens + .inc_by(num_external_tokens as u64); Ok((num_external_tokens, true)) } else { Ok((0, false)) @@ -210,7 +226,12 @@ impl Leader for KvConnectorLeader { block_ids: Vec, context_current_position: usize, ) -> anyhow::Result<()> { - tracing::debug!(request_id, "num_device_blocks: {}, context_current_position: {}", block_ids.len(), context_current_position); + tracing::debug!( + request_id, + "num_device_blocks: {}, context_current_position: {}", + block_ids.len(), + context_current_position + ); let shared_slot = self.slot_manager().get_slot(&request_id)?; let mut slot = shared_slot diff --git a/lib/llm/src/block_manager/distributed/leader.rs b/lib/llm/src/block_manager/distributed/leader.rs index 5ec0ad6097..e1e6368973 100644 --- a/lib/llm/src/block_manager/distributed/leader.rs +++ b/lib/llm/src/block_manager/distributed/leader.rs @@ -11,12 +11,12 @@ use dynamo_runtime::utils::leader_worker_barrier::LeaderBarrier; use derive_builder::Builder; use serde::{Deserialize, Serialize}; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::time::Duration; -use tokio::sync::oneshot; use tokio::sync::Notify; use tokio::sync::OnceCell; +use tokio::sync::oneshot; use tokio::time::sleep; /// Data that is sent to workers over ETCD to establish a ZMQ connection. diff --git a/lib/llm/src/block_manager/layout.rs b/lib/llm/src/block_manager/layout.rs index 53f347b9d0..e137a15173 100644 --- a/lib/llm/src/block_manager/layout.rs +++ b/lib/llm/src/block_manager/layout.rs @@ -334,7 +334,8 @@ impl FullyContiguousConfig { config.validate()?; let alignment = config.alignment; - let outer_dim_stride_in_bytes = config.page_size * config.inner_dim * config.dtype_width_bytes; + let outer_dim_stride_in_bytes = + config.page_size * config.inner_dim * config.dtype_width_bytes; let layer_stride_in_bytes = outer_dim_stride_in_bytes * config.outer_dim; let memory_region_size = layer_stride_in_bytes; let natural_block_stride = config.num_layers * layer_stride_in_bytes; From cf771cec058f1467503419b3aa5eeb93901e1644 Mon Sep 17 00:00:00 2001 From: richardhuo-nv Date: Mon, 25 Aug 2025 15:19:37 -0700 Subject: [PATCH 05/17] fix fmt Signed-off-by: richardhuo-nv --- lib/llm/src/block_manager/distributed/worker.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/lib/llm/src/block_manager/distributed/worker.rs b/lib/llm/src/block_manager/distributed/worker.rs index 3ce729f4ab..be7b651e39 100644 --- a/lib/llm/src/block_manager/distributed/worker.rs +++ b/lib/llm/src/block_manager/distributed/worker.rs @@ -431,6 +431,7 @@ impl KvbmWorker { Ok(()) } + #[allow(clippy::too_many_arguments)] async fn worker_task( device_layout: Box>, mut layout_builder: LayoutConfigBuilder, From cba8a85ab212ec67560b67173d6151f56f4516b8 Mon Sep 17 00:00:00 2001 From: richardhuo-nv Date: Tue, 26 Aug 2025 21:11:54 -0700 Subject: [PATCH 06/17] fix tests Signed-off-by: richardhuo-nv --- .github/workflows/container-validation-dynamo.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/container-validation-dynamo.yml b/.github/workflows/container-validation-dynamo.yml index 1f9d593352..508ba43853 100644 --- a/.github/workflows/container-validation-dynamo.yml +++ b/.github/workflows/container-validation-dynamo.yml @@ -66,7 +66,7 @@ jobs: docker run -v ${{ github.workspace }}:/workspace -w /workspace \ --name 
${{ env.CONTAINER_ID }}_pytest \ ${{ steps.define_image_tag.outputs.image_tag }} \ - bash -c "pytest --basetemp=/tmp --junitxml=${{ env.PYTEST_XML_FILE }} -m \"${{ env.PYTEST_MARKS }}\"" + bash -c "pytest --basetemp=/tmp --junitxml=${{ env.PYTEST_XML_FILE }} -m \"${{ env.PYTEST_MARKS }}\" --ignore /workspace/lib/bindings/python/src/dynamo/llm/trtllm_integration/connector " - name: Copy test report from test Container if: always() run: | @@ -89,4 +89,4 @@ jobs: uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0 with: name: Event File - path: ${{ github.event_path }} \ No newline at end of file + path: ${{ github.event_path }} From 63aa8f3ebb6c0e6d6860f93ff051af4f600cbe4d Mon Sep 17 00:00:00 2001 From: richardhuo-nv Date: Tue, 26 Aug 2025 21:47:19 -0700 Subject: [PATCH 07/17] resolve comments Signed-off-by: richardhuo-nv --- .../vllm/connector/trtllm_leader.rs | 9 ++--- .../src/block_manager/distributed/leader.rs | 36 +++++++++---------- .../src/block_manager/distributed/worker.rs | 14 ++++---- 3 files changed, 25 insertions(+), 34 deletions(-) diff --git a/lib/bindings/python/rust/llm/block_manager/vllm/connector/trtllm_leader.rs b/lib/bindings/python/rust/llm/block_manager/vllm/connector/trtllm_leader.rs index 43305b9d62..3728013b81 100644 --- a/lib/bindings/python/rust/llm/block_manager/vllm/connector/trtllm_leader.rs +++ b/lib/bindings/python/rust/llm/block_manager/vllm/connector/trtllm_leader.rs @@ -99,7 +99,7 @@ impl KvConnectorLeader { let block_manager = match BlockManagerBuilder::new() .worker_id(worker_id) - .leader(leader_py) // your distributed::KvbmLeader + .leader(leader_py) .page_size(page_size) .disable_device_pool(false) .build() @@ -177,11 +177,6 @@ impl Leader for KvConnectorLeader { .lock() .map_err(|e| anyhow::anyhow!("Failed to lock slot: {}", e))?; - if slot.state() == SlotState::Prefilling { - tracing::warn!("slot is in the Prefilled state; this seems like we need to reset the slot and start over"); - slot.reset(); - } - // early exit if we cannot match full block if (slot.sequence().total_tokens() - num_computed_tokens) < self.block_size { let total_tokens = slot.sequence().total_tokens(); @@ -412,7 +407,7 @@ impl Leader for KvConnectorLeader { .remove(&request_id); // if the slot has finished, we can return false to trtllm, indicating all gpu blocks are free to be reused - // otherwise, we return false, which means there are still outstanding operations on gpu blocks which + // otherwise, we return true, which means there are still outstanding operations on gpu blocks which // must be awaited before the gpu blocks can be reused. if we return true, then it is the worker side // of the connector api which will be used to inform trtllm that the request is finished. 
if let SlotState::Finished = slot.state() { diff --git a/lib/llm/src/block_manager/distributed/leader.rs b/lib/llm/src/block_manager/distributed/leader.rs index e1e6368973..12fd7a63b7 100644 --- a/lib/llm/src/block_manager/distributed/leader.rs +++ b/lib/llm/src/block_manager/distributed/leader.rs @@ -123,11 +123,17 @@ impl KvbmLeader { }; let cancel_token = tokio_util::sync::CancellationToken::new(); - leader.spawn_barrier_task( - drt, + + // The leader_sockets struct cannot be cloned, + // so we use a tuple to "struct" the two urls + let leader_urls = ( leader_sockets.pub_url.clone(), leader_sockets.ack_url.clone(), ); + leader.spawn_barrier_task( + drt, + leader_urls + ); leader.spawn_zmq_task(leader_sockets, cancel_token); Ok(leader) @@ -136,8 +142,7 @@ impl KvbmLeader { fn spawn_barrier_task( &self, drt: DistributedRuntime, - leader_sockets_pub_url: String, - leader_sockets_ack_url: String, + leader_urls: (String, String), ) { let state = self.state.clone(); let leader_config = self.config.clone(); @@ -148,8 +153,7 @@ impl KvbmLeader { tokio::spawn(async move { match KvbmLeader::run_barrier_sync( drt, - leader_sockets_pub_url, - leader_sockets_ack_url, + leader_urls, leader_config, ) .await @@ -180,8 +184,7 @@ impl KvbmLeader { async fn run_barrier_sync( drt: DistributedRuntime, - leader_sockets_pub_url: String, - leader_sockets_ack_url: String, + leader_urls: (String, String), leader_config: KvbmLeaderConfig, ) -> anyhow::Result<(usize, usize, usize)> { let barrier_id_worker_to_leader = @@ -191,16 +194,10 @@ impl KvbmLeader { leader_config.world_size, barrier_id_worker_to_leader ); - let zmq_data_worker_to_leader: Arc = Arc::new(KvbmLeaderData { - pub_url: leader_sockets_pub_url.clone(), - ack_url: leader_sockets_ack_url.clone(), - num_host_blocks: 0, // doesn't matter for worker to leader sync - num_disk_blocks: 0, // doesn't matter for worker to leader sync - }); // Build our leader barrier and publish the data. 
// TODO: Use a separate timeout parameter from the ZMQ connection timeout - let worker_to_leader_barrier: LeaderBarrier = + let worker_to_leader_barrier: LeaderBarrier<(), worker::KvbmWorkerData> = LeaderBarrier::new( barrier_id_worker_to_leader.clone(), leader_config.world_size, @@ -208,7 +205,7 @@ impl KvbmLeader { ); let worker_data = worker_to_leader_barrier - .sync(&drt, zmq_data_worker_to_leader.as_ref()) + .sync(&drt, &()) .await .map_err(|e| anyhow::anyhow!("Failed to sync worker to leader barrier: {:?}", e))?; @@ -245,14 +242,15 @@ impl KvbmLeader { barrier_id_leader_to_worker ); + let (leader_pub_url, leader_ack_url) = leader_urls; let zmq_data_leader_to_worker = Arc::new(KvbmLeaderData { - pub_url: leader_sockets_pub_url.clone(), - ack_url: leader_sockets_ack_url.clone(), + pub_url: leader_pub_url, + ack_url: leader_ack_url, num_host_blocks, num_disk_blocks, }); - let leader_to_worker_barrier: LeaderBarrier = + let leader_to_worker_barrier: LeaderBarrier = LeaderBarrier::new( barrier_id_leader_to_worker.clone(), leader_config.world_size, diff --git a/lib/llm/src/block_manager/distributed/worker.rs b/lib/llm/src/block_manager/distributed/worker.rs index be7b651e39..ca93b83ba9 100644 --- a/lib/llm/src/block_manager/distributed/worker.rs +++ b/lib/llm/src/block_manager/distributed/worker.rs @@ -334,7 +334,7 @@ impl KvbmWorker { barrier_id_worker_to_leader ); - let worker_to_leader_barrier = WorkerBarrier::::new( + let worker_to_leader_barrier = WorkerBarrier::<(), KvbmWorkerData>::new( barrier_id_worker_to_leader, worker_id.to_string(), ); @@ -344,8 +344,7 @@ impl KvbmWorker { bytes_per_block, }; - // leader_data is not important in the worker to leader phase - let _leader_data = tokio::select! { + tokio::select! { _ = cancel_token.cancelled() => { return Err(anyhow::anyhow!("Cancelled")) } @@ -356,9 +355,8 @@ impl KvbmWorker { .map_err(|e| anyhow::anyhow!("Failed to sync worker to leader barrier: {:?}", e))?; tracing::debug!( - "Worker {} received leader data: {:?} in worker to leader phase", - worker_id, - _leader_data + "Worker {} sent the worker data in worker to leader phase", + worker_id ); let barrier_id_leader_to_worker = @@ -369,7 +367,7 @@ impl KvbmWorker { barrier_id_leader_to_worker ); - let leader_to_worker_barrier = WorkerBarrier::::new( + let leader_to_worker_barrier = WorkerBarrier::::new( barrier_id_leader_to_worker, worker_id.to_string(), ); @@ -378,7 +376,7 @@ impl KvbmWorker { _ = cancel_token.cancelled() => { return Err(anyhow::anyhow!("Cancelled")) } - leader_data = leader_to_worker_barrier.sync(&drt, &worker_data) => { + leader_data = leader_to_worker_barrier.sync(&drt, &()) => { leader_data } } From 7b5aed3d94eee2ce3e1892990bc5bd88f9336268 Mon Sep 17 00:00:00 2001 From: richardhuo-nv Date: Tue, 26 Aug 2025 21:56:30 -0700 Subject: [PATCH 08/17] fix fmt Signed-off-by: richardhuo-nv --- .../src/block_manager/distributed/leader.rs | 30 +++++-------------- 1 file changed, 8 insertions(+), 22 deletions(-) diff --git a/lib/llm/src/block_manager/distributed/leader.rs b/lib/llm/src/block_manager/distributed/leader.rs index 12fd7a63b7..5630b14fc5 100644 --- a/lib/llm/src/block_manager/distributed/leader.rs +++ b/lib/llm/src/block_manager/distributed/leader.rs @@ -130,20 +130,13 @@ impl KvbmLeader { leader_sockets.pub_url.clone(), leader_sockets.ack_url.clone(), ); - leader.spawn_barrier_task( - drt, - leader_urls - ); + leader.spawn_barrier_task(drt, leader_urls); leader.spawn_zmq_task(leader_sockets, cancel_token); Ok(leader) } - fn spawn_barrier_task( - &self, - 
drt: DistributedRuntime, - leader_urls: (String, String), - ) { + fn spawn_barrier_task(&self, drt: DistributedRuntime, leader_urls: (String, String)) { let state = self.state.clone(); let leader_config = self.config.clone(); let ready = Arc::clone(&self.workers_sync_ready); @@ -151,13 +144,7 @@ impl KvbmLeader { let done = Arc::clone(&self.workers_sync_done); tokio::spawn(async move { - match KvbmLeader::run_barrier_sync( - drt, - leader_urls, - leader_config, - ) - .await - { + match KvbmLeader::run_barrier_sync(drt, leader_urls, leader_config).await { Ok((num_device_blocks, num_host_blocks, num_disk_blocks)) => { // write back results state @@ -250,12 +237,11 @@ impl KvbmLeader { num_disk_blocks, }); - let leader_to_worker_barrier: LeaderBarrier = - LeaderBarrier::new( - barrier_id_leader_to_worker.clone(), - leader_config.world_size, - Some(Duration::from_secs(leader_config.leader_init_timeout_secs)), - ); + let leader_to_worker_barrier: LeaderBarrier = LeaderBarrier::new( + barrier_id_leader_to_worker.clone(), + leader_config.world_size, + Some(Duration::from_secs(leader_config.leader_init_timeout_secs)), + ); let _worker_data = leader_to_worker_barrier .sync(&drt, zmq_data_leader_to_worker.as_ref()) From 48ae78ed09e0c1f4eb976f69ac12af0b2a726707 Mon Sep 17 00:00:00 2001 From: richardhuo-nv Date: Wed, 27 Aug 2025 21:55:27 -0700 Subject: [PATCH 09/17] integrate vllm integrate vllm add doc fix blocking fix fix fix fix fix fix fmt fix Signed-off-by: richardhuo-nv --- docs/guides/run_kvbm_in_trtllm.md | 115 +++++++++++ .../llm/block_manager/distributed/worker.rs | 5 +- .../block_manager/vllm/connector/leader.rs | 138 ++++++++----- .../vllm/connector/leader/recorder.rs | 107 ++++++++--- .../vllm/connector/trtllm_leader.rs | 2 +- .../vllm/connector/trtllm_worker.rs | 2 +- .../block_manager/vllm/connector/worker.rs | 2 +- .../llm/vllm_integration/connector_leader.py | 20 +- lib/llm/src/block_manager/distributed.rs | 2 +- .../src/block_manager/distributed/leader.rs | 17 ++ .../src/block_manager/distributed/worker.rs | 174 ++++++++++++++--- tests/kvbm/README.md | 14 +- tests/kvbm/test_determinism.py | 181 ++++++++++++++---- 13 files changed, 611 insertions(+), 168 deletions(-) create mode 100644 docs/guides/run_kvbm_in_trtllm.md diff --git a/docs/guides/run_kvbm_in_trtllm.md b/docs/guides/run_kvbm_in_trtllm.md new file mode 100644 index 0000000000..4d31e35778 --- /dev/null +++ b/docs/guides/run_kvbm_in_trtllm.md @@ -0,0 +1,115 @@ + + +# Running KVBM in TensorRT-LLM + +This guide explains how to leverage KVBM (KV Block Manager) to mange KV cache and do KV offloading in TensorRT-LLM (trtllm). + +To learn what KVBM is, please check [here](https://docs.nvidia.com/dynamo/latest/architecture/kvbm_intro.html) + +> [!Note] +> - Ensure that `etcd` is running before starting. +> - KVBM does not currently support CUDA graphs in TensorRT-LLM. +> - KVBM only supports TensorRT-LLM’s PyTorch backend. 
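The CPU and disk cache sizes used in the Quick Start below are specified in gigabytes; the leader converts them into a number of KV blocks using the per-block byte size reported by the workers (mirroring `compute_num_blocks` in the KVBM leader). A minimal sketch of that conversion, assuming a hypothetical 8 MiB block:

```python
# Sketch only: how a GB-sized cache setting maps to a KV block count.
def compute_num_blocks(cache_size_in_gb: float, bytes_per_block: int,
                       num_blocks_override: int = 0) -> int:
    if num_blocks_override > 0:
        return num_blocks_override  # an explicit override wins
    return int((cache_size_in_gb * 1_000_000_000) / bytes_per_block)

bytes_per_block = 8_388_608  # illustrative value; depends on model and page size
print(compute_num_blocks(60, bytes_per_block))  # e.g. DYN_KVBM_CPU_CACHE_GB=60 -> 7152 blocks
print(compute_num_blocks(20, bytes_per_block))  # e.g. DYN_KVBM_DISK_CACHE_GB=20 -> 2384 blocks
```

When an override block count is configured it takes precedence over the GB-based calculation.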
+ +## Quick Start + +To use KVBM in TensorRT-LLM, you can follow the steps below: + +```bash +# start up etcd for KVBM leader/worker registration and discovery +docker compose -f deploy/docker-compose.yml up -d + +# build a container containing trtllm and kvbm, note that KVBM integration is only availiable on TensorRT-LLM commit: TBD +./container/build.sh --framework trtllm --tensorrtllm-commit TBD --enable-kvbm + +# launch the container +./container/run.sh --framework trtllm -it --mount-workspace --use-nixl-gds + +# enable kv offloading to CPU memory +# 4 means 4GB of pinned CPU memory would be used +export DYN_KVBM_CPU_CACHE_GB=60 + +# enable kv offloading to disk +# 8 means 8GB of disk would be used +export DYN_KVBM_DISK_CACHE_GB=20 +``` + +```bash +# write an example LLM API config +cat > "/tmp/kvbm_llm_api_config.yaml" < "/tmp/kvbm_llm_api_config.yaml" <, + layout_blocking: bool, ) -> PyResult { let py_drt = drt.ok_or_else(|| { pyo3::exceptions::PyValueError::new_err("DistributedRuntime (drt) must be provided") @@ -146,7 +147,7 @@ impl KvbmWorker { let worker = rt .block_on(async move { - let kvbm_worker = KvbmWorkerImpl::new(config).await?; + let kvbm_worker = KvbmWorkerImpl::new(config, layout_blocking).await?; anyhow::Ok(kvbm_worker) }) .map_err(to_pyerr)?; diff --git a/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader.rs b/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader.rs index 3449847016..6a86caa5be 100644 --- a/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader.rs +++ b/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader.rs @@ -9,7 +9,7 @@ use dynamo_llm::block_manager::metrics_kvbm::KvbmMetrics; use dynamo_runtime::DistributedRuntime; use slot::{ConnectorSlotManager, SlotError, SlotManager, SlotState}; -use crate::llm::block_manager::BlockManager as PyBlockManager; +use crate::llm::block_manager::BlockManagerBuilder; use crate::llm::block_manager::{ distributed::KvbmLeader as PyKvbmLeader, vllm::connector::leader::slot::VllmConnectorSlot, vllm::KvbmRequest, VllmBlockManager, @@ -26,10 +26,12 @@ use dynamo_llm::block_manager::{ BasicMetadata, DiskStorage, ImmutableBlock, PinnedStorage, }; use dynamo_llm::tokens::{SaltHash, TokenBlockSequence, Tokens}; - +use std::sync::{Arc, OnceLock}; use std::{collections::HashSet, sync::Mutex}; use tokio; +use tokio::runtime::Handle; use tokio::sync::mpsc; +use tokio::sync::oneshot; type VllmLocality = Logical; @@ -71,11 +73,13 @@ pub trait Leader: Send + Sync + std::fmt::Debug { fn has_slot(&self, request_id: String) -> bool; fn create_slot(&mut self, request: KvbmRequest, tokens: Vec) -> anyhow::Result<()>; + + fn slot_manager(&self) -> &ConnectorSlotManager; } #[derive(Debug)] pub struct KvConnectorLeader { - slot_manager: ConnectorSlotManager, + slot_manager: Arc>>, block_size: usize, inflight_requests: HashSet, onboarding_slots: HashSet, @@ -87,37 +91,86 @@ impl KvConnectorLeader { fn new( worker_id: String, drt: PyDistributedRuntime, - block_manager: PyBlockManager, - leader: PyKvbmLeader, + page_size: usize, + leader_py: PyKvbmLeader, ) -> Self { tracing::info!( "KvConnectorLeader initialized with worker_id: {}", worker_id ); - // if drt is none, then we must construct a runtime and distributed runtime - let block_manager = block_manager.get_block_manager().clone(); - let block_size = block_manager.block_size(); - - let leader = leader.get_inner(); - - // if we need a drt, get it from here + let leader = leader_py.get_inner().clone(); let drt = drt.inner().clone(); + let 
handle: Handle = drt.runtime().primary(); let ns = drt .namespace(kvbm_connector::KVBM_CONNECTOR_LEADER) .unwrap(); let kvbm_metrics = KvbmMetrics::new(&ns); + let kvbm_metrics_clone = kvbm_metrics.clone(); + + let slot_manager_cell = Arc::new(OnceLock::new()); + let (leader_ready_tx, leader_ready_rx) = oneshot::channel::(); + + { + let slot_manager_cell = slot_manager_cell.clone(); + + handle.spawn(async move { + let ready = leader.wait_worker_sync_ready().await; + if !ready { + tracing::error!( + "KvConnectorLeader init aborted: leader worker barrier not ready!", + ); + return; + } + + let block_manager = match BlockManagerBuilder::new() + .worker_id(0) + .leader(leader_py) + .page_size(page_size) + .disable_device_pool(false) + .build() + .await + { + Ok(bm) => bm, + Err(e) => { + tracing::error!("Failed to build BlockManager: {}", e); + return; + } + }; + + // Create the slot manager now that everything is ready + let sm = ConnectorSlotManager::new( + block_manager.get_block_manager().clone(), + leader.clone(), + drt.clone(), + kvbm_metrics_clone.clone(), + ); + + let _ = slot_manager_cell.set(sm); + + // another barrier sync to make sure worker init won't return before leader is ready + let _ = leader.run_leader_readiness_barrier_blocking(drt); + + if leader_ready_tx.send("finished".to_string()).is_err() { + tracing::error!("main routine receiver dropped before result was sent"); + } + }); + } + + tokio::task::block_in_place(|| { + handle.block_on(async { + match leader_ready_rx.await { + Ok(_) => tracing::info!("KvConnectorLeader init complete."), + Err(_) => tracing::warn!("KvConnectorLeader init channel dropped"), + } + }); + }); Self { - slot_manager: ConnectorSlotManager::new( - block_manager.clone(), - leader, - drt.clone(), - kvbm_metrics.clone(), - ), - block_size, + slot_manager: slot_manager_cell, + block_size: page_size, inflight_requests: HashSet::new(), onboarding_slots: HashSet::new(), iteration_counter: 0, @@ -127,6 +180,13 @@ impl KvConnectorLeader { } impl Leader for KvConnectorLeader { + #[inline] + fn slot_manager(&self) -> &ConnectorSlotManager { + self.slot_manager + .get() + .expect("slot_manager not initialized") + } + /// Match the tokens in the request with the available block pools. /// Note: the necessary details of the request are captured prior to this call. For vllm, /// we make a create slot call prior to this call, so a slot is guaranteed to exist. @@ -147,7 +207,7 @@ impl Leader for KvConnectorLeader { // the number of device matched tokens should be less than or equal to the number of tokens in the request debug_assert!(num_computed_tokens % self.block_size == 0); - let shared_slot = self.slot_manager.get_slot(&request_id)?; + let shared_slot = self.slot_manager().get_slot(&request_id)?; let mut slot = shared_slot .lock() .map_err(|e| anyhow::anyhow!("Failed to lock slot: {}", e))?; @@ -215,7 +275,7 @@ impl Leader for KvConnectorLeader { num_external_tokens ); - let shared_slot = self.slot_manager.get_slot(&request_id)?; + let shared_slot = self.slot_manager().get_slot(&request_id)?; let mut slot = shared_slot .lock() .map_err(|e| anyhow::anyhow!("Failed to lock slot: {}", e))?; @@ -271,7 +331,7 @@ impl Leader for KvConnectorLeader { // This is kind of a nice abstraction as it keeps the events simplier; however, we now create the request-slot // once for onboarding (this loop), then again for prefill/decode (new_requests loop). 
for request_id in onboarding_slots.iter() { - let shared_slot = self.slot_manager.get_slot(request_id)?; + let shared_slot = self.slot_manager().get_slot(request_id)?; let mut slot = shared_slot .lock() .map_err(|e| anyhow::anyhow!("Failed to lock slot: {}", e))?; @@ -300,7 +360,7 @@ impl Leader for KvConnectorLeader { "request_id {request_id} not found in inflight_requests: " ); - let shared_slot = self.slot_manager.get_slot(request_id)?; + let shared_slot = self.slot_manager().get_slot(request_id)?; let mut slot = shared_slot .lock() .map_err(|e| anyhow::anyhow!("Failed to lock slot: {}", e))?; @@ -343,7 +403,7 @@ impl Leader for KvConnectorLeader { // we really do not know what to expect here: // first let's try to get the slot, it might fail because maybe preemption put us thru // a finished cycle -- who knows - let shared_slot = self.slot_manager.get_slot(request_id); + let shared_slot = self.slot_manager().get_slot(request_id); match &shared_slot { Ok(_) => { tracing::info!("after preemption, slot is still alive"); @@ -371,7 +431,7 @@ impl Leader for KvConnectorLeader { "request_id {request_id} not found in inflight_requests: " ); - let shared_slot = self.slot_manager.get_slot(request_id)?; + let shared_slot = self.slot_manager().get_slot(request_id)?; let mut slot = shared_slot .lock() .map_err(|e| anyhow::anyhow!("Failed to lock slot: {}", e))?; @@ -399,7 +459,7 @@ impl Leader for KvConnectorLeader { } for unscheduled_req in inflight_requests.iter() { - let shared_slot = self.slot_manager.get_slot(unscheduled_req)?; + let shared_slot = self.slot_manager().get_slot(unscheduled_req)?; let mut slot_guard = shared_slot .lock() .map_err(|e| anyhow::anyhow!("Failed to lock slot: {}", e))?; @@ -424,7 +484,7 @@ impl Leader for KvConnectorLeader { ) -> anyhow::Result { tracing::debug!("Request finished: {request_id}; block_ids: {block_ids:?}"); - if !self.slot_manager.has_slot(&request_id) { + if !self.slot_manager().has_slot(&request_id) { tracing::warn!( "request_finished called for request_id: {request_id} but slot is not found" ); @@ -433,7 +493,7 @@ impl Leader for KvConnectorLeader { } // grab the slot - let shared_slot = self.slot_manager.get_slot(&request_id)?; + let shared_slot = self.slot_manager().get_slot(&request_id)?; // mark the slot as finished let mut slot = shared_slot @@ -450,7 +510,7 @@ impl Leader for KvConnectorLeader { self.inflight_requests.remove(&request_id); // remove it from the manager as we will never use it again - self.slot_manager.remove_slot(&request_id)?; + self.slot_manager().remove_slot(&request_id)?; // if the slot has finished, we can return false to vllm, indicating all gpu blocks are free to be reused // otherwise, we return true, which means there are still outstanding operations on gpu blocks which @@ -465,13 +525,13 @@ impl Leader for KvConnectorLeader { } fn has_slot(&self, request_id: String) -> bool { - self.slot_manager.has_slot(&request_id) + self.slot_manager().has_slot(&request_id) } /// Create a new slot for the given request ID. /// This is used to create a new slot for the request. 
fn create_slot(&mut self, request: KvbmRequest, tokens: Vec) -> anyhow::Result<()> { - self.slot_manager + self.slot_manager() .create_slot(&request.request_id, tokens, request.salt_hash)?; self.inflight_requests.insert(request.request_id); @@ -488,11 +548,11 @@ pub struct PyKvConnectorLeader { #[pymethods] impl PyKvConnectorLeader { #[new] - #[pyo3(signature = (worker_id, drt, block_manager, leader))] + #[pyo3(signature = (worker_id, drt, page_size, leader))] pub fn new( worker_id: String, drt: PyDistributedRuntime, - block_manager: PyBlockManager, + page_size: usize, leader: PyKvbmLeader, ) -> Self { let enable_kvbm_record = std::env::var("ENABLE_KVBM_RECORD") @@ -501,18 +561,10 @@ impl PyKvConnectorLeader { let connector_leader: Box = if enable_kvbm_record { Box::new(recorder::KvConnectorLeaderRecorder::new( - worker_id, - drt, - block_manager, - leader, + worker_id, drt, page_size, leader, )) } else { - Box::new(KvConnectorLeader::new( - worker_id, - drt, - block_manager, - leader, - )) + Box::new(KvConnectorLeader::new(worker_id, drt, page_size, leader)) }; Self { connector_leader } } diff --git a/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader/recorder.rs b/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader/recorder.rs index 29f62c1ceb..a7647df35c 100644 --- a/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader/recorder.rs +++ b/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader/recorder.rs @@ -88,51 +88,35 @@ impl KvConnectorLeaderRecorder { pub fn new( worker_id: String, drt: PyDistributedRuntime, - block_manager: PyBlockManager, - leader: PyKvbmLeader, + page_size: usize, + leader_py: PyKvbmLeader, ) -> Self { tracing::info!( "KvConnectorLeaderRecorder initialized with worker_id: {}", worker_id ); - // if drt is none, then we must construct a runtime and distributed runtime - let block_manager = block_manager.get_block_manager().clone(); - let block_size = block_manager.block_size(); + let leader = leader_py.get_inner().clone(); + let drt = drt.inner().clone(); + let handle: Handle = drt.runtime().primary(); - let leader = leader.get_inner(); + let ns = drt + .namespace(kvbm_connector::KVBM_CONNECTOR_LEADER) + .unwrap(); - // if we need a drt, get it from here - let drt = drt.inner().clone(); + let kvbm_metrics = KvbmMetrics::new(&ns); + let kvbm_metrics_clone = kvbm_metrics.clone(); let token = CancellationToken::new(); let output_path = "/tmp/records.jsonl"; tracing::info!("recording events to {}", output_path); - let ns = drt.namespace("kvbm_connector_leader").unwrap(); - - let kvbm_metrics = KvbmMetrics::new(&ns); - let recorder = drt .runtime() .primary() .block_on(async { Recorder::new(token, &output_path, None, None, None).await }) .unwrap(); - let connector_leader = KvConnectorLeader { - slot_manager: ConnectorSlotManager::new( - block_manager.clone(), - leader, - drt.clone(), - kvbm_metrics.clone(), - ), - block_size, - inflight_requests: HashSet::new(), - onboarding_slots: HashSet::new(), - iteration_counter: 0, - kvbm_metrics, - }; - let (unbounded_tx, unbounded_rx) = mpsc::unbounded_channel(); let recorder_tx = recorder.event_sender(); @@ -141,6 +125,73 @@ impl KvConnectorLeaderRecorder { .primary() .spawn(Self::forward_unbounded_to_sender(unbounded_rx, recorder_tx)); + let slot_manager_cell = Arc::new(OnceLock::new()); + let (leader_ready_tx, leader_ready_rx) = oneshot::channel::(); + + { + let slot_manager_cell = slot_manager_cell.clone(); + + handle.spawn(async move { + let ready = 
leader.wait_worker_sync_ready().await; + if !ready { + tracing::error!( + "KvConnectorLeader init aborted: leader worker barrier not ready!", + ); + return; + } + + let block_manager = match BlockManagerBuilder::new() + .worker_id(0) + .leader(leader_py) + .page_size(page_size) + .disable_device_pool(false) + .build() + .await + { + Ok(bm) => bm, + Err(e) => { + tracing::error!("Failed to build BlockManager: {}", e); + return; + } + }; + + // Create the slot manager now that everything is ready + let sm = ConnectorSlotManager::new( + block_manager.get_block_manager().clone(), + leader.clone(), + drt.clone(), + kvbm_metrics_clone.clone(), + ); + + let _ = slot_manager_cell.set(sm); + + // another barrier sync to make sure worker init won't return before leader is ready + leader.spawn_leader_readiness_barrier(drt); + + if leader_ready_tx.send("finished".to_string()).is_err() { + tracing::error!("main routine receiver dropped before result was sent"); + } + }); + } + + tokio::task::block_in_place(|| { + handle.block_on(async { + match leader_ready_rx.await { + Ok(_) => tracing::info!("KvConnectorLeader init complete."), + Err(_) => tracing::warn!("KvConnectorLeader init channel dropped"), + } + }); + }); + + let connector_leader = KvConnectorLeader { + slot_manager: slot_manager_cell, + block_size: page_size, + inflight_requests: HashSet::new(), + onboarding_slots: HashSet::new(), + iteration_counter: 0, + kvbm_metrics, + }; + Self { _recorder: recorder, unbounded_tx, @@ -161,6 +212,10 @@ impl KvConnectorLeaderRecorder { } impl Leader for KvConnectorLeaderRecorder { + #[inline] + fn slot_manager(&self) -> &ConnectorSlotManager { + self.connector_leader.slot_manager() + } /// Match the tokens in the request with the available block pools. /// Note: the necessary details of the request are captured prior to this call. For vllm, /// we make a create slot call prior to this call, so a slot is guaranteed to exist. 
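
Both `KvConnectorLeader` and `KvConnectorLeaderRecorder` above now defer building the `ConnectorSlotManager` until the leader/worker barrier has completed: the manager lives in an `Arc<OnceLock<…>>`, a spawned async task fills it in and signals a oneshot, and the synchronous constructor blocks until that signal arrives. The sketch below illustrates the same deferred-initialization pattern in isolation; the type names are placeholders rather than the real KVBM types, and the wait is done directly on the runtime handle instead of through `tokio::task::block_in_place`, since the sketch runs from a plain thread.

```rust
use std::sync::{Arc, OnceLock};
use tokio::{runtime::Handle, sync::oneshot};

#[derive(Debug)]
struct SlotManager; // placeholder for the real ConnectorSlotManager

struct Leader {
    slots: Arc<OnceLock<SlotManager>>,
}

impl Leader {
    /// Synchronous constructor that blocks until async initialization has finished.
    fn new(handle: Handle) -> Self {
        let cell = Arc::new(OnceLock::new());
        let (ready_tx, ready_rx) = oneshot::channel::<()>();

        {
            let cell = cell.clone();
            handle.spawn(async move {
                // Stand-in for: wait for the worker barrier, build the block manager, ...
                let _ = cell.set(SlotManager);
                let _ = ready_tx.send(());
            });
        }

        // The connector wraps this wait in `tokio::task::block_in_place` because it can be
        // called from a runtime worker thread; from a plain thread, blocking on the handle
        // directly is sufficient.
        handle
            .block_on(ready_rx)
            .expect("init task dropped before signalling readiness");

        Self { slots: cell }
    }

    fn slots(&self) -> &SlotManager {
        self.slots.get().expect("slot manager not initialized")
    }
}

fn main() {
    let rt = tokio::runtime::Runtime::new().expect("failed to build runtime");
    let leader = Leader::new(rt.handle().clone());
    println!("initialized: {:?}", leader.slots());
}
```

Keeping the constructor synchronous this way lets the Python bindings create the connector eagerly while the barrier sync and block-manager allocation still happen on the runtime.
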
diff --git a/lib/bindings/python/rust/llm/block_manager/vllm/connector/trtllm_leader.rs b/lib/bindings/python/rust/llm/block_manager/vllm/connector/trtllm_leader.rs index 3728013b81..2890e1673c 100644 --- a/lib/bindings/python/rust/llm/block_manager/vllm/connector/trtllm_leader.rs +++ b/lib/bindings/python/rust/llm/block_manager/vllm/connector/trtllm_leader.rs @@ -98,7 +98,7 @@ impl KvConnectorLeader { } let block_manager = match BlockManagerBuilder::new() - .worker_id(worker_id) + .worker_id(0) .leader(leader_py) .page_size(page_size) .disable_device_pool(false) diff --git a/lib/bindings/python/rust/llm/block_manager/vllm/connector/trtllm_worker.rs b/lib/bindings/python/rust/llm/block_manager/vllm/connector/trtllm_worker.rs index b967c2b653..e7095e9e57 100644 --- a/lib/bindings/python/rust/llm/block_manager/vllm/connector/trtllm_worker.rs +++ b/lib/bindings/python/rust/llm/block_manager/vllm/connector/trtllm_worker.rs @@ -142,7 +142,7 @@ impl Worker for KvConnectorWorker { self.layer_events = raw_event_handles; let worker = self.drt.runtime().primary().block_on(async move { - let worker = KvbmWorker::new(config).await?; + let worker = KvbmWorker::new(config, true).await?; anyhow::Ok(worker) })?; diff --git a/lib/bindings/python/rust/llm/block_manager/vllm/connector/worker.rs b/lib/bindings/python/rust/llm/block_manager/vllm/connector/worker.rs index cbf27a5ce9..94bc37dc49 100644 --- a/lib/bindings/python/rust/llm/block_manager/vllm/connector/worker.rs +++ b/lib/bindings/python/rust/llm/block_manager/vllm/connector/worker.rs @@ -171,7 +171,7 @@ impl Worker for KvConnectorWorker { .build()?; let worker = self.drt.runtime().primary().block_on(async move { - let worker = KvbmWorker::new(config).await?; + let worker = KvbmWorker::new(config, false).await?; anyhow::Ok(worker) })?; diff --git a/lib/bindings/python/src/dynamo/llm/vllm_integration/connector_leader.py b/lib/bindings/python/src/dynamo/llm/vllm_integration/connector_leader.py index e730acafc0..55fcf653e9 100644 --- a/lib/bindings/python/src/dynamo/llm/vllm_integration/connector_leader.py +++ b/lib/bindings/python/src/dynamo/llm/vllm_integration/connector_leader.py @@ -14,7 +14,6 @@ from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.request import Request -from vllm.worker.cache_engine import CacheEngine if TYPE_CHECKING: from vllm.v1.core.kv_cache_manager import KVCacheBlocks @@ -29,7 +28,7 @@ # ) # from dynamo.llm.vllm_integration.rust import SchedulerOutput as RustSchedulerOutput -from dynamo.llm import BlockManager, KvbmLeader +from dynamo.llm import KvbmLeader from dynamo.llm.vllm_integration.kv_cache_utils import ( find_and_set_available_port_from_env, ) @@ -64,25 +63,12 @@ def __init__(self, vllm_config: "VllmConfig", engine_id: str, **kwargs): self.vllm_config = vllm_config world_size = vllm_config.parallel_config.world_size - bytes_per_block = CacheEngine.get_cache_block_size( - vllm_config.cache_config, - vllm_config.model_config, - vllm_config.parallel_config, - ) - total_bytes = bytes_per_block * world_size - - leader = KvbmLeader(total_bytes, world_size, drt=self.drt) - block_manager = BlockManager( - 0, - leader, - vllm_config.cache_config.block_size, - disable_device_pool=True, - ) + leader = KvbmLeader(world_size, drt=self.drt) print(f"KvConnectorLeader initialized with engine_id: {engine_id}") self._connector = RustKvConnectorLeader( - engine_id, self.drt, block_manager, leader + engine_id, self.drt, vllm_config.cache_config.block_size, leader ) # 
KV Connector diff --git a/lib/llm/src/block_manager/distributed.rs b/lib/llm/src/block_manager/distributed.rs index 9aa72c7b8f..0f8759f12e 100644 --- a/lib/llm/src/block_manager/distributed.rs +++ b/lib/llm/src/block_manager/distributed.rs @@ -136,7 +136,7 @@ mod tests { .device_id(i) .build()?; - let worker = KvbmWorker::new(config).await?; + let worker = KvbmWorker::new(config, false).await?; workers.push(worker); } diff --git a/lib/llm/src/block_manager/distributed/leader.rs b/lib/llm/src/block_manager/distributed/leader.rs index 5630b14fc5..2346e4b2be 100644 --- a/lib/llm/src/block_manager/distributed/leader.rs +++ b/lib/llm/src/block_manager/distributed/leader.rs @@ -9,6 +9,7 @@ use zmq::*; use dynamo_runtime::utils::leader_worker_barrier::LeaderBarrier; +use anyhow::Context; use derive_builder::Builder; use serde::{Deserialize, Serialize}; use std::sync::Arc; @@ -289,6 +290,7 @@ impl KvbmLeader { }); } + // This is supposed to be used in non-blocking leader initialization pub fn spawn_leader_readiness_barrier(&self, drt: DistributedRuntime) { let leader_config = self.config.clone(); let handle = drt.runtime().primary(); @@ -304,6 +306,21 @@ impl KvbmLeader { }); } + // This is supposed to be used in blocking leader initialization + pub fn run_leader_readiness_barrier_blocking( + &self, + drt: DistributedRuntime, + ) -> anyhow::Result<()> { + let leader_config = self.config.clone(); + let fut = KvbmLeader::run_leader_readiness(drt, leader_config); + + tokio::task::block_in_place(|| { + tokio::runtime::Handle::current() + .block_on(fut) + .context("leader readiness barrier failed") + }) + } + async fn run_leader_readiness( drt: DistributedRuntime, leader_config: KvbmLeaderConfig, diff --git a/lib/llm/src/block_manager/distributed/worker.rs b/lib/llm/src/block_manager/distributed/worker.rs index ca93b83ba9..717f8b6e48 100644 --- a/lib/llm/src/block_manager/distributed/worker.rs +++ b/lib/llm/src/block_manager/distributed/worker.rs @@ -136,7 +136,7 @@ pub struct KvbmWorker { } impl KvbmWorker { - pub async fn new(config: KvbmWorkerConfig) -> anyhow::Result { + pub async fn new(config: KvbmWorkerConfig, layout_blocking: bool) -> anyhow::Result { tracing::info!( "Initializing KvbmWorker with params: num_device_blocks={}, page_size={}, dtype_width_bytes={}", config.num_device_blocks, @@ -157,11 +157,8 @@ impl KvbmWorker { ))); } - let layout_type: LayoutType; - let mut outer_dim = 1; - let num_layers; - let inner_dim; - if !config.is_fully_contiguous_layout { + let (layout_type, num_layers, outer_dim, inner_dim) = if !config.is_fully_contiguous_layout + { let (outer_contiguous, outer_dim) = if shape[0] >= config.num_device_blocks { (false, shape[1]) } else if shape[1] >= config.num_device_blocks { @@ -172,9 +169,8 @@ impl KvbmWorker { shape ))); }; - layout_type = LayoutType::LayerSeparate { outer_contiguous }; - num_layers = device_tensors.len(); - inner_dim = shape[2..].iter().product::() / config.page_size; + let num_layers = device_tensors.len(); + let inner_dim = shape[2..].iter().product::() / config.page_size; tracing::info!( "Inferred layout: num_layers={}, outer_dim={}, page_size={}, inner_dim={}", @@ -183,11 +179,17 @@ impl KvbmWorker { config.page_size, inner_dim ); + + ( + LayoutType::LayerSeparate { outer_contiguous }, + num_layers, + outer_dim, + inner_dim, + ) } else { - layout_type = LayoutType::FullyContiguous; - num_layers = shape[1]; - outer_dim = shape[2]; - inner_dim = shape[3..].iter().product::() / config.page_size; + let num_layers = shape[1]; + let outer_dim = 
shape[2]; + let inner_dim = shape[3..].iter().product::() / config.page_size; tracing::info!( "Inferred layout: num_layers={}, outer_dim={}, page_size={}, inner_dim={}", num_layers, @@ -195,7 +197,14 @@ impl KvbmWorker { config.page_size, inner_dim ); - } + + ( + LayoutType::FullyContiguous, + num_layers, + outer_dim, + inner_dim, + ) + }; let bytes_per_block = num_layers * outer_dim * config.page_size * inner_dim * config.dtype_width_bytes; @@ -213,10 +222,44 @@ impl KvbmWorker { .build()? .create_layout(layout_type, device_tensors)?; - let layout_builder_clone = layout_builder.clone(); + let layout_builder = layout_builder.clone(); + + let (task, handler_rx) = if layout_blocking { + Self::run_blocking_layout_initialization( + config, + bytes_per_block, + device_layout, + layout_builder, + layout_type, + ) + .await? + } else { + Self::run_non_blocking_layout_initialization( + config, + bytes_per_block, + device_layout, + layout_builder, + layout_type, + ) + .await? + }; - // add worker-connector scheduler here - // let scheduler = KvbmWorkerScheduler::new(config.scheduler.clone()); + Ok(Self { + task: Some(task), + block_transfer_handler_rx: Some(handler_rx), + }) + } + + async fn run_blocking_layout_initialization( + config: KvbmWorkerConfig, + bytes_per_block: usize, + device_layout: Box>, + layout_builder: LayoutConfigBuilder, + layout_type: LayoutType, + ) -> anyhow::Result<( + CriticalTaskExecutionHandle, + oneshot::Receiver, + )> { let cancel_token = config.drt.primary_token().clone(); // barrier sync with leader to get the leader data @@ -248,7 +291,7 @@ impl KvbmWorker { move |cancel_token| { KvbmWorker::worker_task( device_layout, - layout_builder_clone, + layout_builder, leader_data, layout_type, worker_config, @@ -262,6 +305,12 @@ impl KvbmWorker { "kvbm-worker-task", )?; + // waiting for the worker layout allocation ready + match layout_ready_rx.await { + Ok(_) => tracing::info!("worker layout allocation finished."), + Err(_) => tracing::error!("Worker layout dropped without sending"), + } + let worker_config = config.clone(); let cancel_for_barrier = cancel_token.clone(); // wait until the leader finished the initialization of all components @@ -274,18 +323,87 @@ impl KvbmWorker { }) })?; - // waiting for the worker layout allocation ready - match layout_ready_rx.await { - Ok(_) => tracing::info!("worker layout allocation finished."), - Err(_) => tracing::error!("Worker layout dropped without sending"), - } - - Ok(Self { - task: Some(task), - block_transfer_handler_rx: Some(handler_rx), - }) + Ok((task, handler_rx)) } + async fn run_non_blocking_layout_initialization( + config: KvbmWorkerConfig, + bytes_per_block: usize, + device_layout: Box + Send + 'static>, + layout_builder: LayoutConfigBuilder, + layout_type: LayoutType, + ) -> anyhow::Result<( + CriticalTaskExecutionHandle, + oneshot::Receiver, + )> { + let cancel_token = config.drt.primary_token().clone(); + let scheduler_client = config.scheduler_client.clone(); + + // channel to get BlockTransferHandler back to the caller + let (handler_tx, handler_rx) = oneshot::channel::(); + + // channel that the worker will use to signal layout readiness + let (layout_ready_tx, layout_ready_rx) = oneshot::channel::(); + + // clone what we need inside the orchestrator + let worker_config = config.clone(); + let cancel_token_for_task = cancel_token.clone(); + + // Single task that orchestrates everything in-order. 
+ let task = CriticalTaskExecutionHandle::new( + move |ct| { + let cfg = worker_config.clone(); + let scheduler = scheduler_client.clone(); + + async move { + // 1) barrier (must finish before worker_task starts) + let leader_data = + KvbmWorker::leader_barrier_sync(cfg.clone(), ct.clone(), bytes_per_block) + .await?; + + // 2) start the long-running worker (after barrier) + // Spawn it so the orchestrator can continue with readiness + waiting. + let dev_layout = device_layout; // moved in + let lb = layout_builder; // moved in + let lt = layout_type; // moved in + + let worker_fut = KvbmWorker::worker_task( + dev_layout, + lb, + leader_data, + lt, + cfg.clone(), + ct.clone(), + handler_tx, + layout_ready_tx, + scheduler, + ); + + // If worker_task returns Result, handle/log it inside the spawned task. + tokio::spawn(async move { + if let Err(e) = worker_fut.await { + tracing::error!("worker_task exited with error: {e:#}"); + } + }); + + // 3) wait for the worker’s layout allocation readiness + match layout_ready_rx.await { + Ok(_) => tracing::info!("worker layout allocation finished."), + Err(_) => tracing::warn!("worker layout readiness channel dropped"), + } + + // 4) wait for leader to finish its side of initialization + KvbmWorker::leader_readiness_sync(cfg.clone(), ct.clone()).await?; + + Ok::<(), anyhow::Error>(()) + } + }, + cancel_token_for_task, + "kvbm-worker-task", + )?; + + Ok((task, handler_rx)) + } /// One-time use method to extract the block transfer handler from the worker. /// /// This is a bit of a hack. Improve the API design around this in the future. diff --git a/tests/kvbm/README.md b/tests/kvbm/README.md index 5a97ca07dd..887bf36d34 100644 --- a/tests/kvbm/README.md +++ b/tests/kvbm/README.md @@ -2,11 +2,11 @@ ## Overview -This suite validates determinism properties of the API-backed LLM under fixed sampling parameters and optionally across prefix cache resets. The tests can automatically start a local vLLM server, warm it up, and compare responses for identical prompts over multiple iterations. +This suite validates the determinism properties of the API-backed LLM under fixed sampling parameters and, optionally, across prefix cache resets. The tests can automatically start a local LLM server—either a vLLM server or a TensorRT-LLM server—warm it up, and compare responses for identical prompts over multiple iterations. The suite also automatically detects whether the vLLM or TensorRT-LLM wheel is installed and starts the corresponding server. ## Files -- `test_determinism.py` — comprehensive determinism tests with automatic vLLM server lifecycle and warmup. +- `test_determinism.py` — comprehensive determinism tests with automatic LLM server lifecycle and warmup. - `test_determinism_with_cache_reset` — run test with warmup, reset cache, then run again without warmup to test determinism across cache reset boundary - `test_concurrent_determinism_with_ifeval` — send parametrized number of IFEval prompts (default: 120) with controlled concurrency, with warmup, then reset cache and test again without warmup to validate determinism across cache reset @@ -19,7 +19,7 @@ This suite validates determinism properties of the API-backed LLM under fixed sa ## How It Works -- A `VLLMServerManager` fixture (`vllm_server`) launches `vllm serve` with the Dynamo connector and optional cache block overrides. +- A `LLMServerManager` fixture (`llm_server`) launches `vllm serve` or `trtllm-serve` with the Dynamo connector and optional cache block overrides. 
- A `tester` fixture binds the test client to the running server's base URL. - The test performs a comprehensive warmup across prompts, then executes repeated requests and checks that responses are identical (deterministic). An optional cache reset phase re-validates determinism across the reset boundary. @@ -43,8 +43,8 @@ Environment variables control server settings and test load: - Server/model - `KVBM_MODEL_ID` (default: `deepseek-ai/DeepSeek-R1-Distill-Llama-8B`) - - `KVBM_VLLM_PORT` (default: `8000`) - - `KVBM_VLLM_START_TIMEOUT` (default: `300` seconds) + - `KVBM_SERVER_PORT` (default: `8000`) + - `KVBM_SERVER_START_TIMEOUT` (default: `300` seconds) - Cache size overrides - `KVBM_CPU_BLOCKS` (used via test parametrization; default: `10000`) @@ -90,5 +90,5 @@ pytest -v -m "kvbm" -s - Warmup is critical to avoid initialization effects impacting determinism. - For faster local iteration, reduce `KVBM_MAX_ITERATIONS` and/or increase intervals. -- Logs are written under the per-test directory created by `tests/conftest.py` and include the vLLM server stdout/stderr. -- Tests use the static port defined by `KVBM_VLLM_PORT` for vLLM server communication. \ No newline at end of file +- Logs are written under the per-test directory created by `tests/conftest.py` and include the LLM server stdout/stderr. +- Tests use the static port defined by `KVBM_SERVER_PORT` for LLM server communication. diff --git a/tests/kvbm/test_determinism.py b/tests/kvbm/test_determinism.py index 75dd454195..18bdf50e88 100755 --- a/tests/kvbm/test_determinism.py +++ b/tests/kvbm/test_determinism.py @@ -13,6 +13,7 @@ impact determinism measurements. """ +import importlib.util import logging import os import signal @@ -21,6 +22,7 @@ from collections import defaultdict from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime +from enum import Enum from pathlib import Path from typing import Dict, List, Optional, TextIO, Tuple @@ -38,8 +40,13 @@ ] -class VLLMServerManager: - """Manages vLLM server lifecycle for determinism testing.""" +class ServerType(str, Enum): + vllm = "vllm" + trtllm = "trtllm" + + +class LLMServerManager: + """Manages LLM server lifecycle for determinism testing.""" def __init__( self, @@ -48,8 +55,10 @@ def __init__( cpu_cache_blocks: Optional[int] = None, gpu_cache_blocks: Optional[int] = None, log_dir: Optional[Path] = None, + server_type: Optional[str] = ServerType.vllm, ): - self.port = port or int(os.environ.get("KVBM_VLLM_PORT", "8000")) + self.server_type = server_type + self.port = port or int(os.environ.get("KVBM_SERVER_PORT", "8000")) self.base_url = base_url or f"http://localhost:{self.port}" self.process: Optional[subprocess.Popen] = None self.cpu_cache_blocks = cpu_cache_blocks @@ -63,11 +72,41 @@ def __init__( f"cpu{cpu_cache_blocks or 'default'}_gpu{gpu_cache_blocks or 'default'}" ) self.server_log_file = ( - self.log_dir / f"vllm_server_{config_str}_{timestamp}.log" + self.log_dir / f"{self.server_type}_server_{config_str}_{timestamp}.log" ) self.server_stdout_file: Optional[TextIO] = None self.server_stderr_file: Optional[TextIO] = None + # Environment for the process + self.env = os.environ.copy() + self.env.update( + { + "RUST_BACKTRACE": "1", + "DYN_LOG": os.environ.get( + "DYN_LOG", "debug,dynamo_llm::block_manager::layout=error" + ), + # DynamoConnector connection settings + "NATS_SERVER": "nats://localhost:4222", + "ETCD_ENDPOINTS": "http://localhost:2379", + } + ) + + # CPU cache blocks override via env + if cpu_cache_blocks is not None: 
+ self.env["DYN_KVBM_CPU_CACHE_OVERRIDE_NUM_BLOCKS"] = str(cpu_cache_blocks) + + if self.server_type == ServerType.vllm: + self._set_up_vllm_config(gpu_cache_blocks) + elif self.server_type == ServerType.trtllm: + self._set_up_trtllm_config(gpu_cache_blocks) + else: + raise ValueError( + f"{self.server_type} is not supported yet in the KVBM test suite" + ) + + def _set_up_vllm_config(self, gpu_cache_blocks): + self.env["VLLM_SERVER_DEV_MODE"] = "1" + # Construct serve command self.server_cmd = [ "vllm", @@ -85,27 +124,52 @@ def __init__( if gpu_cache_blocks is not None: self.server_cmd.extend(["--num-gpu-blocks-override", str(gpu_cache_blocks)]) - # Environment for the process - self.env = os.environ.copy() - self.env.update( - { - "RUST_BACKTRACE": "1", - "DYN_LOG": os.environ.get( - "DYN_LOG", "debug,dynamo_llm::block_manager::layout=error" - ), - "VLLM_SERVER_DEV_MODE": "1", - # DynamoConnector connection settings - "NATS_SERVER": "nats://localhost:4222", - "ETCD_ENDPOINTS": "http://localhost:2379", - } + def _set_up_trtllm_config(self, gpu_cache_blocks): + config_path = os.environ.get( + "KVBM_TRTLLM_LLMAPI_CONFIG_PATH", "/tmp/kvbm_llm_api_config.yaml" ) + llm_api_config = {} + llm_api_config[ + "cuda_graph_config" + ] = None # explicitly disable CUDA graph since Connector API doesn't support CUDA graph yet in TRTLLM + llm_api_config["kv_cache_config"] = { + "enable_partial_reuse": False, + "free_gpu_memory_fraction": 0.10, # Set a small GPU fraction so that we can evict/reset the on-device kv cache faster + } + llm_api_config["kv_connector_config"] = { + "connector_module": "dynamo.llm.trtllm_integration.connector", + "connector_scheduler_class": "DynamoKVBMConnectorLeader", + "connector_worker_class": "DynamoKVBMConnectorWorker", + } - # CPU cache blocks override via env - if cpu_cache_blocks is not None: - self.env["DYN_KVBM_CPU_CACHE_OVERRIDE_NUM_BLOCKS"] = str(cpu_cache_blocks) + # GPU blocks override + if gpu_cache_blocks is not None: + del llm_api_config["kv_cache_config"]["free_gpu_memory_fraction"] + llm_api_config["kv_cache_config"]["max_tokens"] = ( + gpu_cache_blocks * 32 + ) # TRTLLM defaults 32 tokens per block + + # Construct serve command + self.server_cmd = [ + "trtllm-serve", + os.environ.get("KVBM_MODEL_ID", "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"), + "--host", + "localhost", + "--port", + str(self.port), + "--backend", + "pytorch", + "--extra_llm_api_options", + config_path, + ] + + import yaml + + with open(config_path, "w") as f: + yaml.dump(llm_api_config, f, default_flow_style=False, sort_keys=False) def start_server(self, timeout: int = 300) -> bool: - """Start vLLM server and wait for readiness.""" + """Start LLM server and wait for readiness.""" if self.is_server_running(): self.stop_server() time.sleep(2) @@ -119,7 +183,7 @@ def start_server(self, timeout: int = 300) -> bool: ) if self.server_stdout_file is not None: self.server_stdout_file.write( - f"=== vLLM Server Started at {datetime.now()} ===\nCommand: {' '.join(self.server_cmd)}\n" + f"=== {self.server_type} Server Started at {datetime.now()} ===\nCommand: {' '.join(self.server_cmd)}\n" ) self.server_stdout_file.flush() @@ -147,7 +211,7 @@ def start_server(self, timeout: int = 300) -> bool: return False def stop_server(self): - """Stop vLLM server and close logs.""" + """Stop LLM server and close logs.""" if self.process: try: os.killpg(os.getpgid(self.process.pid), signal.SIGTERM) @@ -205,7 +269,12 @@ def is_server_running(self) -> bool: class DeterminismTester: """Test class for model determinism 
validation.""" - def __init__(self, base_url: Optional[str] = None, model_id: Optional[str] = None): + def __init__( + self, + base_url: Optional[str] = None, + model_id: Optional[str] = None, + server_type: Optional[str] = ServerType.vllm, + ): # Allow environment override for flexibility in CI/local runs self.base_url = ( base_url or os.environ.get("DYNAMO_API_BASE_URL") or "http://localhost:8000" @@ -215,6 +284,7 @@ def __init__(self, base_url: Optional[str] = None, model_id: Optional[str] = Non or os.environ.get("KVBM_MODEL_ID") or "deepseek-ai/DeepSeek-R1-Distill-Llama-8B" ) + self.server_type = server_type self.shakespeare_file = Path("t8.shakespeare.txt") self.max_iterations = int(os.environ.get("KVBM_MAX_ITERATIONS", "500")) @@ -298,11 +368,28 @@ def make_request(self, content: str) -> str: def reset_prefix_cache(self): """Reset the prefix cache.""" print("Resetting prefix cache...") - response = requests.post( - f"{self.base_url}/reset_prefix_cache", - timeout=int(os.environ.get("KVBM_HTTP_TIMEOUT", "30")), - ) - response.raise_for_status() + if self.server_type == ServerType.trtllm: + # TRTLLM doesn't support reset_prefix_cache endpoint API + # 300 shakespeare content could evict the 0.1 x 80G (~1700 blocks) on-device cache + shakespeare_count = 300 + for seq_idx in range(1, shakespeare_count + 1): + start_word = (seq_idx - 1) * self.word_count + content = self.get_shakespeare_content(start_word) + + if content: + print( + f"Resetting Shakespeare sequence {seq_idx} (words {start_word}-{start_word + self.word_count - 1})..." + ) + try: + self.make_request(content) + except Exception as e: + print(f"Resetting request failed: {e}") + else: + response = requests.post( + f"{self.base_url}/reset_prefix_cache", + timeout=int(os.environ.get("KVBM_HTTP_TIMEOUT", "30")), + ) + response.raise_for_status() print("Cache reset done") def warmup_server(self): @@ -623,11 +710,11 @@ def make_concurrent_request(task): @pytest.fixture(scope="function") -def vllm_server(request, runtime_services): - """Start and stop vLLM server for each test with optional cache block overrides. +def llm_server(request, runtime_services): + """Start and stop a LLM server for each test with optional cache block overrides. To parametrize, use: - @pytest.mark.parametrize("vllm_server", [{"cpu_blocks": 10000, "gpu_blocks": 2048}], indirect=True) + @pytest.mark.parametrize("llm_server", [{"cpu_blocks": 10000, "gpu_blocks": 2048}], indirect=True) """ logger = logging.getLogger("pytest") logger.setLevel(logging.INFO) @@ -639,17 +726,27 @@ def vllm_server(request, runtime_services): # Put logs in the per-test directory set up by tests/conftest.py log_dir = Path(request.node.name) - server_manager = VLLMServerManager( + if importlib.util.find_spec("vllm") is not None: + server_type = ServerType.vllm + elif importlib.util.find_spec("tensorrt_llm") is not None: + server_type = ServerType.trtllm + else: + raise Exception( + "Neither the vllm nor the tensorrt_llm module is available in the current environment." 
+ ) + + server_manager = LLMServerManager( port=port, cpu_cache_blocks=cpu_blocks, gpu_cache_blocks=gpu_blocks, log_dir=log_dir, + server_type=server_type, ) - start_timeout = int(os.environ.get("KVBM_VLLM_START_TIMEOUT", "300")) + start_timeout = int(os.environ.get("KVBM_SERVER_START_TIMEOUT", "300")) if not server_manager.start_server(timeout=start_timeout): pytest.fail( - f"Failed to start vLLM server (cpu_blocks={cpu_blocks}, gpu_blocks={gpu_blocks}, port={server_manager.port})" + f"Failed to start {server_type} server (cpu_blocks={cpu_blocks}, gpu_blocks={gpu_blocks}, port={server_manager.port})" ) yield server_manager @@ -658,9 +755,11 @@ def vllm_server(request, runtime_services): @pytest.fixture(scope="function") -def tester(vllm_server): +def tester(llm_server): """Create determinism tester bound to the running server's base URL.""" - t = DeterminismTester(base_url=vllm_server.base_url) + t = DeterminismTester( + base_url=llm_server.base_url, server_type=llm_server.server_type + ) t.download_shakespeare_text() return t @@ -669,13 +768,13 @@ class TestDeterminism: """Test class for determinism validation.""" @pytest.mark.parametrize( - "vllm_server", + "llm_server", [ {"cpu_blocks": int(os.environ.get("KVBM_CPU_BLOCKS", "10000"))}, ], indirect=True, ) - def test_determinism_with_cache_reset(self, tester, vllm_server, runtime_services): + def test_determinism_with_cache_reset(self, tester, llm_server, runtime_services): """Test determinism across cache reset: run test with warmup, reset cache, run again without warmup.""" print("\n" + "=" * 70) print("STARTING DETERMINISM TEST (WITH CACHE RESET)") @@ -797,7 +896,7 @@ def test_determinism_with_cache_reset(self, tester, vllm_server, runtime_service ), f"Model is not deterministic across cache reset: {total_failed} comparisons failed" @pytest.mark.parametrize( - "vllm_server", + "llm_server", [ {"cpu_blocks": int(os.environ.get("KVBM_CPU_BLOCKS", "20000"))}, ], @@ -818,7 +917,7 @@ def test_determinism_with_cache_reset(self, tester, vllm_server, runtime_service def test_concurrent_determinism_with_ifeval( self, tester, - vllm_server, + llm_server, runtime_services, num_concurrent, max_tokens, From e9651ee11a8a74929602cf6f3b673ed1675b8ddc Mon Sep 17 00:00:00 2001 From: richardhuo-nv Date: Thu, 28 Aug 2025 17:29:46 -0700 Subject: [PATCH 10/17] fix tests Signed-off-by: richardhuo-nv --- tests/kvbm/test_determinism.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/kvbm/test_determinism.py b/tests/kvbm/test_determinism.py index 18bdf50e88..6903e1ad2f 100755 --- a/tests/kvbm/test_determinism.py +++ b/tests/kvbm/test_determinism.py @@ -24,7 +24,7 @@ from datetime import datetime from enum import Enum from pathlib import Path -from typing import Dict, List, Optional, TextIO, Tuple +from typing import Any, Dict, List, Optional, TextIO, Tuple import pytest import requests @@ -128,7 +128,7 @@ def _set_up_trtllm_config(self, gpu_cache_blocks): config_path = os.environ.get( "KVBM_TRTLLM_LLMAPI_CONFIG_PATH", "/tmp/kvbm_llm_api_config.yaml" ) - llm_api_config = {} + llm_api_config: dict[str, Any] = {} llm_api_config[ "cuda_graph_config" ] = None # explicitly disable CUDA graph since Connector API doesn't support CUDA graph yet in TRTLLM @@ -146,7 +146,7 @@ def _set_up_trtllm_config(self, gpu_cache_blocks): if gpu_cache_blocks is not None: del llm_api_config["kv_cache_config"]["free_gpu_memory_fraction"] llm_api_config["kv_cache_config"]["max_tokens"] = ( - gpu_cache_blocks * 32 + int(gpu_cache_blocks) * 32 ) # TRTLLM 
defaults 32 tokens per block # Construct serve command From 5bf7c50fcb334eaea363af0cf9a3cdf0466531d7 Mon Sep 17 00:00:00 2001 From: richardhuo-nv Date: Thu, 28 Aug 2025 20:30:38 -0700 Subject: [PATCH 11/17] resolve comments Signed-off-by: richardhuo-nv --- docs/guides/run_kvbm_in_trtllm.md | 18 ++++++++++++------ lib/bindings/python/rust/llm/block_manager.rs | 2 +- .../llm/block_manager/distributed/utils.rs | 5 ++++- 3 files changed, 17 insertions(+), 8 deletions(-) diff --git a/docs/guides/run_kvbm_in_trtllm.md b/docs/guides/run_kvbm_in_trtllm.md index 4d31e35778..ff5b8dc7fd 100644 --- a/docs/guides/run_kvbm_in_trtllm.md +++ b/docs/guides/run_kvbm_in_trtllm.md @@ -25,6 +25,7 @@ To learn what KVBM is, please check [here](https://docs.nvidia.com/dynamo/latest > - Ensure that `etcd` is running before starting. > - KVBM does not currently support CUDA graphs in TensorRT-LLM. > - KVBM only supports TensorRT-LLM’s PyTorch backend. +> - KVBM requires TensorRT-LLM at commit ce580ce4f52af3ad0043a800b3f9469e1f1109f6 or newer. ## Quick Start @@ -34,19 +35,24 @@ To use KVBM in TensorRT-LLM, you can follow the steps below: # start up etcd for KVBM leader/worker registration and discovery docker compose -f deploy/docker-compose.yml up -d -# build a container containing trtllm and kvbm, note that KVBM integration is only availiable on TensorRT-LLM commit: TBD -./container/build.sh --framework trtllm --tensorrtllm-commit TBD --enable-kvbm +# Build a container that includes TensorRT-LLM and KVBM. Note: KVBM integration is only available in TensorRT-LLM commit ce580ce4f52af3ad0043a800b3f9469e1f1109f6. +./container/build.sh --framework trtllm --tensorrtllm-commit ce580ce4f52af3ad0043a800b3f9469e1f1109f6 --enable-kvbm # launch the container ./container/run.sh --framework trtllm -it --mount-workspace --use-nixl-gds # enable kv offloading to CPU memory -# 4 means 4GB of pinned CPU memory would be used +# 60 means 60GB of pinned CPU memory would be used export DYN_KVBM_CPU_CACHE_GB=60 # enable kv offloading to disk -# 8 means 8GB of disk would be used +# 20 means 20GB of disk would be used export DYN_KVBM_DISK_CACHE_GB=20 + +# Allocating memory and disk storage can take some time. +# We recommend setting a higher timeout for leader–worker initialization. +# 1200 means 1200 seconds timeout +export DYN_KVBM_LEADER_WORKER_INIT_TIMEOUT_SECS=1200 ``` ```bash @@ -103,8 +109,8 @@ EOF # serve an example LLM model trtllm-serve deepseek-ai/DeepSeek-R1-Distill-Llama-8B --host localhost --port 8000 --backend pytorch --extra_llm_api_options /tmp/kvbm_llm_api_config.yaml -# start vllm with DYN_SYSTEM_ENABLED set to true and DYN_SYSTEM_PORT port to 6880. -# NOTE: Make sure port 6880 (for KVBM worker metrics) and port 6881 (for KVBM leader metrics) are available. +# start trtllm-serve with DYN_SYSTEM_ENABLED set to true and DYN_SYSTEM_PORT set to 6880 +# NOTE: Ensure ports 6880 (KVBM worker metrics) and 6881 (KVBM leader metrics) are available. 
DYN_SYSTEM_ENABLED=true DYN_SYSTEM_PORT=6880 trtllm-serve deepseek-ai/DeepSeek-R1-Distill-Llama-8B --host localhost --port 8000 --backend pytorch --extra_llm_api_options /tmp/kvbm_llm_api_config.yaml # optional if firewall blocks KVBM metrics ports to send prometheus metrics diff --git a/lib/bindings/python/rust/llm/block_manager.rs b/lib/bindings/python/rust/llm/block_manager.rs index 346b1bcce3..d62dd1700c 100644 --- a/lib/bindings/python/rust/llm/block_manager.rs +++ b/lib/bindings/python/rust/llm/block_manager.rs @@ -233,7 +233,7 @@ pub struct BlockManagerBuilder { impl BlockManagerBuilder { pub fn new() -> Self { Self { - page_size: 0, + page_size: 32, // default consistent with BlockManager::new ..Default::default() } } diff --git a/lib/bindings/python/rust/llm/block_manager/distributed/utils.rs b/lib/bindings/python/rust/llm/block_manager/distributed/utils.rs index 2777260fb4..84288c6f94 100644 --- a/lib/bindings/python/rust/llm/block_manager/distributed/utils.rs +++ b/lib/bindings/python/rust/llm/block_manager/distributed/utils.rs @@ -2,5 +2,8 @@ // SPDX-License-Identifier: Apache-2.0 pub fn get_barrier_id_prefix() -> String { - std::env::var("DYN_KVBM_BARRIER_ID_PREFIX").unwrap_or("kvbm".to_string()) + std::env::var("DYN_KVBM_BARRIER_ID_PREFIX") + .ok() + .filter(|s| !s.trim().is_empty()) + .unwrap_or_else(|| "kvbm".to_string()) } From 8727bbce22661713bb08bb1b8d9b64416c95a81c Mon Sep 17 00:00:00 2001 From: richardhuo-nv Date: Thu, 28 Aug 2025 20:51:07 -0700 Subject: [PATCH 12/17] fix repo checkout Signed-off-by: richardhuo-nv --- .github/workflows/docs-link-check.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/docs-link-check.yml b/.github/workflows/docs-link-check.yml index 0702a660a9..1bb96be6fc 100644 --- a/.github/workflows/docs-link-check.yml +++ b/.github/workflows/docs-link-check.yml @@ -15,6 +15,11 @@ jobs: steps: - name: Check out repository uses: actions/checkout@v4 + with: + # For pull_request events, use the PR head (commit from the contributor's branch/repo) + repository: ${{ github.event.pull_request.head.repo.full_name || github.repository }} + ref: ${{ github.event.pull_request.head.sha || github.sha }} + fetch-depth: 0 # Cache lychee results (e.g. to avoid hitting rate limits) # https://lychee.cli.rs/github_action_recipes/caching/ From 8c7dd172660c32e98c772e9a2178eb9f814f57a0 Mon Sep 17 00:00:00 2001 From: richardhuo-nv Date: Thu, 28 Aug 2025 21:30:22 -0700 Subject: [PATCH 13/17] fix doc Signed-off-by: richardhuo-nv --- docs/guides/run_kvbm_in_trtllm.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/guides/run_kvbm_in_trtllm.md b/docs/guides/run_kvbm_in_trtllm.md index ff5b8dc7fd..707d31d8a3 100644 --- a/docs/guides/run_kvbm_in_trtllm.md +++ b/docs/guides/run_kvbm_in_trtllm.md @@ -35,7 +35,7 @@ To use KVBM in TensorRT-LLM, you can follow the steps below: # start up etcd for KVBM leader/worker registration and discovery docker compose -f deploy/docker-compose.yml up -d -# Build a container that includes TensorRT-LLM and KVBM. Note: KVBM integration is only available in TensorRT-LLM commit ce580ce4f52af3ad0043a800b3f9469e1f1109f6. +# Build a container that includes TensorRT-LLM and KVBM. Note: KVBM integration is only available in TensorRT-LLM commit ce580ce4f52af3ad0043a800b3f9469e1f1109f6 or newer. 
./container/build.sh --framework trtllm --tensorrtllm-commit ce580ce4f52af3ad0043a800b3f9469e1f1109f6 --enable-kvbm # launch the container From 04c3cdb96bee65624d0ee3c8447b1d810c2c13cb Mon Sep 17 00:00:00 2001 From: richardhuo-nv Date: Fri, 29 Aug 2025 14:57:17 -0700 Subject: [PATCH 14/17] fix Signed-off-by: richardhuo-nv --- docs/guides/run_kvbm_in_trtllm.md | 2 +- .../src/block_manager/distributed/leader.rs | 51 +++++++++++++++++-- 2 files changed, 49 insertions(+), 4 deletions(-) diff --git a/docs/guides/run_kvbm_in_trtllm.md b/docs/guides/run_kvbm_in_trtllm.md index 707d31d8a3..7cf7abd3ee 100644 --- a/docs/guides/run_kvbm_in_trtllm.md +++ b/docs/guides/run_kvbm_in_trtllm.md @@ -22,7 +22,7 @@ This guide explains how to leverage KVBM (KV Block Manager) to mange KV cache an To learn what KVBM is, please check [here](https://docs.nvidia.com/dynamo/latest/architecture/kvbm_intro.html) > [!Note] -> - Ensure that `etcd` is running before starting. +> - Ensure that `etcd` and 'nats' are running before starting. > - KVBM does not currently support CUDA graphs in TensorRT-LLM. > - KVBM only supports TensorRT-LLM’s PyTorch backend. > - KVBM requires TensorRT-LLM at commit ce580ce4f52af3ad0043a800b3f9469e1f1109f6 or newer. diff --git a/lib/llm/src/block_manager/distributed/leader.rs b/lib/llm/src/block_manager/distributed/leader.rs index 2346e4b2be..7280ee87d4 100644 --- a/lib/llm/src/block_manager/distributed/leader.rs +++ b/lib/llm/src/block_manager/distributed/leader.rs @@ -9,7 +9,7 @@ use zmq::*; use dynamo_runtime::utils::leader_worker_barrier::LeaderBarrier; -use anyhow::Context; +use anyhow::{Context, anyhow}; use derive_builder::Builder; use serde::{Deserialize, Serialize}; use std::sync::Arc; @@ -85,6 +85,7 @@ pub struct KvbmLeaderState { pub num_host_blocks: Arc, pub num_disk_blocks: Arc, pub workers_allocation_ready: Arc, + pub workers_ready_notify: Arc, } /// The leader of the KVBM. @@ -203,6 +204,7 @@ impl KvbmLeader { .min() .unwrap(); + // TODO: this works for TP, need to redefine bytes_per_block when we enable the DP/PP let bytes_per_block: usize = worker_data.values().map(|d| d.bytes_per_block).sum(); assert!( @@ -282,6 +284,7 @@ impl KvbmLeader { state .workers_allocation_ready .store(true, Ordering::Release); + state.workers_ready_notify.notify_waiters(); } Err(e) => { tracing::error!("ZMQ init failed: {e:?}"); @@ -292,9 +295,31 @@ impl KvbmLeader { // This is supposed to be used in non-blocking leader initialization pub fn spawn_leader_readiness_barrier(&self, drt: DistributedRuntime) { + let timeout_secs = self.config.leader_init_timeout_secs; + let state = self.state.clone(); let leader_config = self.config.clone(); let handle = drt.runtime().primary(); handle.spawn(async move { + if !state.workers_allocation_ready.load(Ordering::Acquire) { + // Wait until ZMQ marks ready or we time out. + let waited = tokio::time::timeout( + Duration::from_secs(timeout_secs), + state.workers_ready_notify.notified(), + ) + .await; + if waited.is_err() { + tracing::error!( + "leader readiness barrier wait timed out after {timeout_secs} seconds" + ); + return; + } + // Double-check the flag (Acquire) after wakeup. 
+ if !state.workers_allocation_ready.load(Ordering::Acquire) { + tracing::error!("leader readiness notify fired but flag not set; aborting"); + return; + } + } + match KvbmLeader::run_leader_readiness(drt, leader_config).await { Ok(()) => { tracing::info!("leader readiness barrier synced!"); @@ -311,12 +336,32 @@ impl KvbmLeader { &self, drt: DistributedRuntime, ) -> anyhow::Result<()> { + let state = self.state.clone(); + let timeout_secs = self.config.leader_init_timeout_secs; let leader_config = self.config.clone(); - let fut = KvbmLeader::run_leader_readiness(drt, leader_config); tokio::task::block_in_place(|| { tokio::runtime::Handle::current() - .block_on(fut) + .block_on(async move { + // Create the future *before* checking the flag to avoid a lost-notify race. + let notified = state.workers_ready_notify.notified(); + + if !state.workers_allocation_ready.load(Ordering::Acquire) { + // Wait (with timeout) until ZMQ task marks ready. + tokio::time::timeout(Duration::from_secs(timeout_secs), notified) + .await + .map_err(|_| anyhow!("timed out waiting for workers_allocation_ready after {timeout_secs} seconds"))?; + + // Double-check after wake to ensure the flag is actually set. + if !state.workers_allocation_ready.load(Ordering::Acquire) { + return Err(anyhow!( + "notified but workers_allocation_ready is still false" + )); + } + } + + KvbmLeader::run_leader_readiness(drt, leader_config).await + }) .context("leader readiness barrier failed") }) } From c68ecc8dc9133d1948a4a6f4a3d10ca7f1686ade Mon Sep 17 00:00:00 2001 From: richardhuo-nv Date: Fri, 29 Aug 2025 15:27:30 -0700 Subject: [PATCH 15/17] fix doc Signed-off-by: richardhuo-nv --- docs/guides/run_kvbm_in_trtllm.md | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/docs/guides/run_kvbm_in_trtllm.md b/docs/guides/run_kvbm_in_trtllm.md index 7cf7abd3ee..cf3ccd5eff 100644 --- a/docs/guides/run_kvbm_in_trtllm.md +++ b/docs/guides/run_kvbm_in_trtllm.md @@ -22,9 +22,11 @@ This guide explains how to leverage KVBM (KV Block Manager) to mange KV cache an To learn what KVBM is, please check [here](https://docs.nvidia.com/dynamo/latest/architecture/kvbm_intro.html) > [!Note] -> - Ensure that `etcd` and 'nats' are running before starting. +> - Ensure that `etcd` and `nats` are running before starting. > - KVBM does not currently support CUDA graphs in TensorRT-LLM. > - KVBM only supports TensorRT-LLM’s PyTorch backend. +> - To enable disk cache offloading, you must first enable a CPU memory cache offloading. +> - Disable partial reuse `enable_partial_reuse: false` in the LLM API config’s `kv_connector_config` to increase offloading cache hits. > - KVBM requires TensorRT-LLM at commit ce580ce4f52af3ad0043a800b3f9469e1f1109f6 or newer. ## Quick Start @@ -45,7 +47,7 @@ docker compose -f deploy/docker-compose.yml up -d # 60 means 60GB of pinned CPU memory would be used export DYN_KVBM_CPU_CACHE_GB=60 -# enable kv offloading to disk +# enable kv offloading to disk. Note: To enable disk cache offloading, you must first enable a CPU memory cache offloading. # 20 means 20GB of disk would be used export DYN_KVBM_DISK_CACHE_GB=20 @@ -57,6 +59,7 @@ export DYN_KVBM_LEADER_WORKER_INIT_TIMEOUT_SECS=1200 ```bash # write an example LLM API config +# Note: Disable partial reuse "enable_partial_reuse: false" in the LLM API config’s "kv_connector_config" to increase offloading cache hits. 
cat > "/tmp/kvbm_llm_api_config.yaml" < Date: Fri, 29 Aug 2025 16:23:40 -0700 Subject: [PATCH 16/17] remove metrics in readme Signed-off-by: richardhuo-nv --- docs/guides/run_kvbm_in_trtllm.md | 35 +------------------------------ 1 file changed, 1 insertion(+), 34 deletions(-) diff --git a/docs/guides/run_kvbm_in_trtllm.md b/docs/guides/run_kvbm_in_trtllm.md index cf3ccd5eff..40a6114cfb 100644 --- a/docs/guides/run_kvbm_in_trtllm.md +++ b/docs/guides/run_kvbm_in_trtllm.md @@ -28,6 +28,7 @@ To learn what KVBM is, please check [here](https://docs.nvidia.com/dynamo/latest > - To enable disk cache offloading, you must first enable a CPU memory cache offloading. > - Disable partial reuse `enable_partial_reuse: false` in the LLM API config’s `kv_connector_config` to increase offloading cache hits. > - KVBM requires TensorRT-LLM at commit ce580ce4f52af3ad0043a800b3f9469e1f1109f6 or newer. +> - Enabling KVBM metrics with TensorRT-LLM is still a work in progress. ## Quick Start @@ -88,37 +89,3 @@ curl localhost:8000/v1/chat/completions -H "Content-Type: application/json" "max_tokens": 30 }' ``` - -## Enable and View KVBM Metrics - -Follow below steps to enable metrics collection and view via Grafana dashboard: -```bash -# Start the basic services (etcd & natsd), along with Prometheus and Grafana -docker compose -f deploy/docker-compose.yml --profile metrics up -d - -# write an example LLM API config -cat > "/tmp/kvbm_llm_api_config.yaml" < Date: Fri, 29 Aug 2025 17:30:45 -0700 Subject: [PATCH 17/17] use dynamo in the readme example Signed-off-by: richardhuo-nv --- docs/guides/run_kvbm_in_trtllm.md | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/docs/guides/run_kvbm_in_trtllm.md b/docs/guides/run_kvbm_in_trtllm.md index 40a6114cfb..3ac865a1a2 100644 --- a/docs/guides/run_kvbm_in_trtllm.md +++ b/docs/guides/run_kvbm_in_trtllm.md @@ -73,8 +73,14 @@ kv_connector_config: connector_worker_class: DynamoKVBMConnectorWorker EOF -# serve an example LLM model -trtllm-serve deepseek-ai/DeepSeek-R1-Distill-Llama-8B --host localhost --port 8000 --backend pytorch --extra_llm_api_options /tmp/kvbm_llm_api_config.yaml +# start dynamo frontend +python3 -m dynamo.frontend --http-port 8000 & + +# To serve an LLM model with dynamo +python3 -m dynamo.trtllm \ + --model-path deepseek-ai/DeepSeek-R1-Distill-Llama-8B \ + --served-model-name deepseek-ai/DeepSeek-R1-Distill-Llama-8B \ + --extra-engine-args /tmp/kvbm_llm_api_config.yaml & # make a call to LLM curl localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{ @@ -88,4 +94,8 @@ curl localhost:8000/v1/chat/completions -H "Content-Type: application/json" "stream":false, "max_tokens": 30 }' + +# Optionally, we could also serve an LLM with trtllm-serve to utilize the KVBM feature. +trtllm-serve deepseek-ai/DeepSeek-R1-Distill-Llama-8B --host localhost --port 8001 --backend pytorch --extra_llm_api_options /tmp/kvbm_llm_api_config.yaml + ```