From 60e41ef3a679ce972de78e7598509b16a8b23698 Mon Sep 17 00:00:00 2001 From: Dan Aloni Date: Mon, 8 Sep 2025 18:24:19 +0300 Subject: [PATCH] Draft: kvbm - bigger page size for offload block pools only This change allows the host and disk pools to have larger blocks compared to the device block pool, with a fixed block size ratio between these two pool classes. During offload and onboarding of blocks, we need to unite small device blocks into large offload blocks, or split large offloaded blocks back into small device blocks. The advantage conferred by this change is more efficient management of offloaded blocks: specifically, larger I/Os can always be used when transferring back and forth from disk. To do: - Fix onboarding from disk to device - Fix support for TRTLLM - Add unit tests Signed-off-by: Dan Aloni --- Cargo.lock | 4 + lib/bindings/python/Cargo.lock | 5 + lib/bindings/python/Cargo.toml | 1 + lib/bindings/python/rust/llm/block_manager.rs | 24 ++- .../llm/block_manager/distributed/worker.rs | 3 +- .../python/rust/llm/block_manager/vllm.rs | 35 +++-- .../block_manager/vllm/connector/leader.rs | 41 +++-- .../vllm/connector/leader/recorder.rs | 9 +- .../vllm/connector/leader/slot.rs | 142 +++++++++++------ .../vllm/connector/trtllm_leader.rs | 33 ++-- .../vllm/connector/trtllm_worker.rs | 3 +- .../block_manager/vllm/connector/worker.rs | 7 +- .../llm/vllm_integration/connector_leader.py | 2 +- lib/llm/Cargo.toml | 1 + lib/llm/src/block_manager.rs | 32 +++- lib/llm/src/block_manager/block/data.rs | 51 ++++++ lib/llm/src/block_manager/block/data/local.rs | 106 ++++++++++++- .../src/block_manager/block/data/logical.rs | 30 +++- .../block_manager/block/transfer/context.rs | 2 + .../src/block_manager/block/transfer/cuda.rs | 148 ++++++++++++++---- lib/llm/src/block_manager/config.rs | 6 + .../src/block_manager/distributed/transfer.rs | 58 ++++--- .../src/block_manager/distributed/utils.rs | 23 ++- .../src/block_manager/distributed/worker.rs | 40 +++-- lib/llm/src/block_manager/offload.rs | 3 + lib/llm/src/block_manager/state.rs | 4 + lib/llm/src/block_manager/state/resources.rs | 2 +- 27 files changed, 621 insertions(+), 194 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 42094da04f..b149e33ed5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2178,6 +2178,7 @@ dependencies = [ "serde", "serde_json", "serial_test", + "smallvec", "strum", "temp-env", "tempfile", @@ -7622,6 +7623,9 @@ name = "smallvec" version = "1.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" +dependencies = [ + "serde", +] [[package]] name = "socket2" diff --git a/lib/bindings/python/Cargo.lock b/lib/bindings/python/Cargo.lock index bea47ac6e3..78d42ded5d 100644 --- a/lib/bindings/python/Cargo.lock +++ b/lib/bindings/python/Cargo.lock @@ -1510,6 +1510,7 @@ dependencies = [ "rustls", "serde", "serde_json", + "smallvec", "strum", "tempfile", "thiserror 2.0.16", @@ -1579,6 +1580,7 @@ dependencies = [ "rstest", "serde", "serde_json", + "smallvec", "socket2 0.6.0", "thiserror 2.0.16", "tokio", @@ -5686,6 +5688,9 @@ name = "smallvec" version = "1.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" +dependencies = [ + "serde", +] [[package]] name = "socket2" diff --git a/lib/bindings/python/Cargo.toml b/lib/bindings/python/Cargo.toml index 0c85134531..04246c1ab7 100644 --- a/lib/bindings/python/Cargo.toml +++
b/lib/bindings/python/Cargo.toml @@ -44,6 +44,7 @@ rand = { version = "0.9" } socket2 = { version = "0.6" } serde = { version = "1" } serde_json = { version = "1.0.138" } +smallvec = { version = "1.15.1", features = ["serde"] } thiserror = { version = "2.0" } tokio = { version = "1.46.0", features = ["full"] } tokio-stream = { version = "0" } diff --git a/lib/bindings/python/rust/llm/block_manager.rs b/lib/bindings/python/rust/llm/block_manager.rs index f19a17cd70..32a3800e7e 100644 --- a/lib/bindings/python/rust/llm/block_manager.rs +++ b/lib/bindings/python/rust/llm/block_manager.rs @@ -161,8 +161,12 @@ impl BlockManager { }) } - fn block_size(&self) -> usize { - self.inner.block_size() + fn engine_block_size(&self) -> usize { + self.inner.engine_block_size() + } + + fn offload_block_size(&self) -> usize { + self.inner.offload_block_size() } fn init_controller(&mut self, component: Component) -> PyResult<()> { @@ -214,14 +218,16 @@ impl BlockManager { pub struct BlockManagerBuilder { worker_id: u64, leader: Option, - page_size: usize, + offload_page_size: usize, + engine_page_size: usize, disable_device_pool: bool, } impl BlockManagerBuilder { pub fn new() -> Self { Self { - page_size: 32, // default consistent with BlockManager::new + engine_page_size: 32, // default consistent with BlockManager::new + offload_page_size: 1024, // default consistent with BlockManager::new ..Default::default() } } @@ -230,8 +236,12 @@ impl BlockManagerBuilder { self.worker_id = id; self } - pub fn page_size(mut self, ps: usize) -> Self { - self.page_size = ps; + pub fn engine_page_size(mut self, ps: usize) -> Self { + self.engine_page_size = ps; + self + } + pub fn offload_page_size(mut self, ps: usize) -> Self { + self.offload_page_size = ps; self } pub fn leader(mut self, l: distributed::KvbmLeader) -> Self { @@ -267,7 +277,7 @@ impl BlockManagerBuilder { let model_config = dynamo_llm::block_manager::KvManagerModelConfig::builder() .num_layers(1) .outer_dim(1) - .page_size(self.page_size) + .page_size(self.engine_page_size) .inner_dim(1) .build()?; diff --git a/lib/bindings/python/rust/llm/block_manager/distributed/worker.rs b/lib/bindings/python/rust/llm/block_manager/distributed/worker.rs index 1cf58185bf..b3c3e2940b 100644 --- a/lib/bindings/python/rust/llm/block_manager/distributed/worker.rs +++ b/lib/bindings/python/rust/llm/block_manager/distributed/worker.rs @@ -176,7 +176,8 @@ impl KvbmWorker { let config = KvbmWorkerConfig::builder() .drt(drt) .num_device_blocks(num_device_blocks) - .page_size(page_size) + .offload_page_size(page_size) + .engine_page_size(page_size) .tensors(vllm_tensors) .device_id(device_id) .dtype_width_bytes(dtype_width_bytes) diff --git a/lib/bindings/python/rust/llm/block_manager/vllm.rs b/lib/bindings/python/rust/llm/block_manager/vllm.rs index 524af640ae..9fd5118a6d 100644 --- a/lib/bindings/python/rust/llm/block_manager/vllm.rs +++ b/lib/bindings/python/rust/llm/block_manager/vllm.rs @@ -83,7 +83,10 @@ impl KvbmCacheManager { #[new] #[pyo3(signature = (block_manager))] pub fn new(block_manager: PyBlockManager) -> PyResult { - let slot_manager = Mutex::new(SlotManager::new(block_manager.block_size())); + let slot_manager = Mutex::new(SlotManager::new( + block_manager.engine_block_size(), + block_manager.offload_block_size(), + )); Ok(Self { block_manager, slot_manager, @@ -286,7 +289,7 @@ pub struct GenericSlotUpdate { pub num_new_tokens: usize, /// The number of new computed tokens in the request. 
- /// The `num_new_tokens / block_size` should be equal to the length of the `new_computed_blocks`, + /// The `num_new_tokens / engine_block_size` should be equal to the length of the `new_computed_blocks`, /// it may have a remainder for the partial block state. /// Note: this field is solely tied to the `new_computed_blocks` field and not used when `tokens_to_append` is provided. /// The name might be confusing, but the name matched the vLLM implementation. @@ -401,15 +404,17 @@ impl SlotError { pub struct SlotManager { slots: HashMap>>, - block_size: usize, + engine_block_size: usize, + offload_block_size: usize, } impl SlotManager { /// Creates a new slot manager. - pub fn new(block_size: usize) -> Self { + pub fn new(engine_block_size: usize, offload_block_size: usize) -> Self { Self { slots: HashMap::new(), - block_size, + engine_block_size, + offload_block_size, } } @@ -436,7 +441,7 @@ impl SlotManager { if !self.slots.contains_key(request_id) { self.slots.insert( request_id.clone(), - Slot::new(tokens.into(), self.block_size, salt_hash), + Slot::new(tokens.into(), self.engine_block_size, salt_hash), ); tracing::debug!( request_id, @@ -498,7 +503,7 @@ impl SlotManager { tracing::debug!( request_id, "applying {} cache-hit tokens", - blocks.len() * self.block_size + blocks.len() * self.engine_block_size ); slot.initialize_with_device_matches(blocks)?; } @@ -566,9 +571,9 @@ impl SlotManager { match self.slots.remove(request_id) { Some(slot) => { let isl = slot.num_tokens(SlotPosition::Prefill); - let isl_device = slot.num_blocks_cached_from_device() * self.block_size; - let isl_host = slot.num_blocks_cached_from_host() * self.block_size; - let isl_disk = slot.num_blocks_cached_from_disk() * self.block_size; + let isl_device = slot.num_blocks_cached_from_device() * self.engine_block_size; + let isl_host = slot.num_blocks_cached_from_host() * self.offload_block_size; + let isl_disk = slot.num_blocks_cached_from_disk() * self.offload_block_size; tracing::info!( request_id, "request complete isl: {} - cache hits: device: {}, host: {}, disk: {} - prefilled: {}", @@ -603,14 +608,14 @@ impl SlotManager { assert!(num_computed_tokens <= request_num_tokens); // early exit if we cannot match full block - if (request_num_tokens - num_computed_tokens) < self.block_size { + if (request_num_tokens - num_computed_tokens) < self.engine_block_size { return Ok((0, false)); } // num_computed_tokens represents the number of tokens already on the device // this much be a multiple of the block size - let num_device_blocks = num_computed_tokens / self.block_size; - debug_assert_eq!(num_computed_tokens % self.block_size, 0); + let num_device_blocks = num_computed_tokens / self.engine_block_size; + debug_assert_eq!(num_computed_tokens % self.engine_block_size, 0); // get the sequence hashes for the device matched tokens let sequence_hashes = slot.sequence_hashes(SlotPosition::All); @@ -661,7 +666,7 @@ impl SlotManager { return Ok((0, false)); } - let mut num_new_matched_tokens = num_matched_blocks * self.block_size; + let mut num_new_matched_tokens = num_matched_blocks * self.engine_block_size; // we are on a block boundary, so we need to throw away the last block if num_computed_tokens + num_new_matched_tokens == request_num_tokens { @@ -681,7 +686,7 @@ impl SlotManager { } // decrement the number of new matched tokens by the block size - num_new_matched_tokens -= self.block_size; + num_new_matched_tokens -= self.engine_block_size; } slot.store_onboard_blocks(host_blocks, disk_blocks); diff --git 
a/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader.rs b/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader.rs index 9d7f10c840..3318044c34 100644 --- a/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader.rs +++ b/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader.rs @@ -79,7 +79,8 @@ pub trait Leader: Send + Sync + std::fmt::Debug { #[derive(Debug)] pub struct KvConnectorLeader { slot_manager: Arc>>, - block_size: usize, + engine_page_size: usize, + offload_page_size: usize, inflight_requests: HashSet, onboarding_slots: HashSet, iteration_counter: u64, @@ -90,7 +91,8 @@ impl KvConnectorLeader { fn new( worker_id: String, drt: PyDistributedRuntime, - page_size: usize, + engine_page_size: usize, + offload_page_size: usize, leader_py: PyKvbmLeader, ) -> Self { tracing::info!( @@ -127,7 +129,8 @@ impl KvConnectorLeader { let block_manager = match BlockManagerBuilder::new() .worker_id(0) .leader(leader_py) - .page_size(page_size) + .engine_page_size(engine_page_size) + .offload_page_size(offload_page_size) .disable_device_pool(false) .build() .await @@ -169,7 +172,8 @@ impl KvConnectorLeader { Self { slot_manager: slot_manager_cell, - block_size: page_size, + engine_page_size, + offload_page_size, inflight_requests: HashSet::new(), onboarding_slots: HashSet::new(), iteration_counter: 0, @@ -204,7 +208,7 @@ impl Leader for KvConnectorLeader { ); // the number of device matched tokens should be less than or equal to the number of tokens in the request - debug_assert!(num_computed_tokens % self.block_size == 0); + debug_assert!(num_computed_tokens % self.engine_page_size == 0); let shared_slot = self.slot_manager().get_slot(&request_id)?; let mut slot = shared_slot @@ -234,7 +238,7 @@ impl Leader for KvConnectorLeader { } // early exit if we cannot match full block - if (slot.sequence().total_tokens() - num_computed_tokens) < self.block_size { + if (slot.sequence().total_tokens() - num_computed_tokens) < self.offload_page_size { return Ok((0, false)); } @@ -245,7 +249,9 @@ impl Leader for KvConnectorLeader { // return the number of external tokens that are ready for onboarding // we always return true here as we always asynchronously onboard matched blocks if let SlotState::OnboardStaged(num_external_tokens) = slot.state() { - debug_assert!((num_computed_tokens + num_external_tokens) % self.block_size == 0); + debug_assert!( + (num_computed_tokens + num_external_tokens) % self.offload_page_size == 0 + ); tracing::debug!( request_id = request_id, "scheduling onboarding for {} external tokens", @@ -289,7 +295,7 @@ impl Leader for KvConnectorLeader { // the second call will show num_external_tokens == 0 // this call is just letting us know the other blocks that are being used for the remainder of the prefill if num_external_tokens > 0 { - let num_computed_tokens = block_ids.len() * self.block_size - num_external_tokens; + let num_computed_tokens = block_ids.len() * self.engine_page_size - num_external_tokens; slot.record_cached_device_tokens(num_computed_tokens); slot.advance_computed_position(num_computed_tokens)?; @@ -549,11 +555,12 @@ pub struct PyKvConnectorLeader { #[pymethods] impl PyKvConnectorLeader { #[new] - #[pyo3(signature = (worker_id, drt, page_size, leader))] + #[pyo3(signature = (worker_id, drt, engine_page_size, offload_page_size, leader))] pub fn new( worker_id: String, drt: PyDistributedRuntime, - page_size: usize, + engine_page_size: usize, + offload_page_size: usize, leader: PyKvbmLeader, ) -> Self { let enable_kvbm_record 
= std::env::var("ENABLE_KVBM_RECORD") @@ -562,10 +569,20 @@ impl PyKvConnectorLeader { let connector_leader: Box = if enable_kvbm_record { Box::new(recorder::KvConnectorLeaderRecorder::new( - worker_id, drt, page_size, leader, + worker_id, + drt, + engine_page_size, + offload_page_size, + leader, )) } else { - Box::new(KvConnectorLeader::new(worker_id, drt, page_size, leader)) + Box::new(KvConnectorLeader::new( + worker_id, + drt, + engine_page_size, + offload_page_size, + leader, + )) }; Self { connector_leader } } diff --git a/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader/recorder.rs b/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader/recorder.rs index 9c267a2f95..c19b332cac 100644 --- a/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader/recorder.rs +++ b/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader/recorder.rs @@ -88,7 +88,8 @@ impl KvConnectorLeaderRecorder { pub fn new( worker_id: String, drt: PyDistributedRuntime, - page_size: usize, + engine_page_size: usize, + offload_page_size: usize, leader_py: PyKvbmLeader, ) -> Self { tracing::info!( @@ -143,7 +144,8 @@ impl KvConnectorLeaderRecorder { let block_manager = match BlockManagerBuilder::new() .worker_id(0) .leader(leader_py) - .page_size(page_size) + .engine_page_size(engine_page_size) + .offload_page_size(offload_page_size) .disable_device_pool(false) .build() .await @@ -185,7 +187,8 @@ impl KvConnectorLeaderRecorder { let connector_leader = KvConnectorLeader { slot_manager: slot_manager_cell, - block_size: page_size, + engine_page_size, + offload_page_size, inflight_requests: HashSet::new(), onboarding_slots: HashSet::new(), iteration_counter: 0, diff --git a/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader/slot.rs b/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader/slot.rs index 2397112267..d35b4b4a0a 100644 --- a/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader/slot.rs +++ b/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader/slot.rs @@ -13,6 +13,7 @@ use dynamo_llm::{ tokens::TokenBlock, }; use dynamo_runtime::utils::task::CriticalTaskExecutionHandle; +use smallvec::SmallVec; use tokio_util::sync::CancellationToken; use super::*; @@ -194,8 +195,9 @@ impl ConnectorSlotManager { kvbm_metrics: KvbmMetrics, ) -> Self { tracing::debug!( - "creating slot manager with block size: {}", - block_manager.block_size() + "creating slot manager with engine block size: {}, offload block size: {}", + block_manager.engine_block_size(), + block_manager.offload_block_size(), ); let (xfer_tx, xfer_rx) = mpsc::unbounded_channel(); @@ -317,7 +319,9 @@ pub struct VllmConnectorSlot { /// Phantom data to ensure the storage type is correct. 
block_manager: VllmBlockManager, - block_size: usize, + offload_block_size: usize, + engine_block_size: usize, + offload_block_size_ratio: usize, iteration_first_scheduled: Option, @@ -345,15 +349,20 @@ impl VllmConnectorSlot { xfer_tx: mpsc::UnboundedSender, ) -> Self { assert!(!tokens.is_empty(), "tokens must be non-empty"); - let block_size = block_manager.block_size(); - debug_assert!(block_size.is_power_of_two() && block_size <= 1024); - let sequence = TokenBlockSequence::new(tokens, block_size as u32, Some(salt_hash)); + let offload_block_size = block_manager.offload_block_size(); + let engine_block_size = block_manager.engine_block_size(); + let offload_block_size_ratio = block_manager.offload_block_size_ratio(); + debug_assert!(offload_block_size.is_power_of_two() && offload_block_size <= 1024); + debug_assert!(engine_block_size.is_power_of_two() && engine_block_size <= 1024); + let sequence = TokenBlockSequence::new(tokens, offload_block_size as u32, Some(salt_hash)); Self { request_id, sequence, block_manager, - block_size, + engine_block_size, + offload_block_size, + offload_block_size_ratio, xfer_tx, // default values state: SlotState::Initialized, @@ -501,12 +510,13 @@ impl Slot for VllmConnectorSlot { // we should have enough device blocks to cover the newly scheduled tokens let next_position = self.current_position + num_scheduled_tokens; + assert!( - next_position <= self.device_blocks.len() * self.block_size, - "next_position: {} > device_blocks.len() {} * block_size {}", + next_position <= self.device_blocks.len() * self.engine_block_size, + "next_position: {} > device_blocks.len() {} * engine_block_size {}", next_position, self.device_blocks.len(), - self.block_size + self.engine_block_size ); if next_position > self.sequence.total_tokens() { @@ -529,9 +539,9 @@ impl Slot for VllmConnectorSlot { // TODO(ryan) - apply policy let next_position = self.current_position + num_scheduled_tokens; - debug_assert!(next_position / self.block_size >= self.evaluated_blocks); + debug_assert!(next_position / self.engine_block_size >= self.evaluated_blocks); - let num_candidate_blocks = (next_position / self.block_size) - self.evaluated_blocks; + let num_candidate_blocks = (next_position / self.engine_block_size) - self.evaluated_blocks; tracing::debug!( "evaluating policy with the following parameters: state: {:?}; current_position: {}; num_candidate_blocks: {}; num_scheduled_tokens: {}", @@ -541,20 +551,22 @@ impl Slot for VllmConnectorSlot { num_scheduled_tokens ); - if num_candidate_blocks != 0 { + if num_candidate_blocks / self.offload_block_size_ratio != 0 { // do we have a mechanism for skipping gpu cache hit blocks? not sure yet. 
// for now, offload all the blocks to the host + let aligned_candidates = (num_candidate_blocks / self.offload_block_size_ratio) + * self.offload_block_size_ratio; let offload_block_ids: Vec = self .device_blocks .iter() .skip(self.evaluated_blocks) - .take(num_candidate_blocks) + .take(aligned_candidates) .copied() .collect::>(); assert_eq!( offload_block_ids.len(), - num_candidate_blocks, + aligned_candidates, "device block overflow - candidate blocks exceed block count at offset {}", self.evaluated_blocks ); @@ -563,15 +575,15 @@ impl Slot for VllmConnectorSlot { .sequence .blocks() .iter() - .skip(self.evaluated_blocks) - .take(num_candidate_blocks) + .skip(self.evaluated_blocks / self.offload_block_size_ratio) + .take(aligned_candidates / self.offload_block_size_ratio) .cloned() .collect::>(); self.offload_blocks(&offload_block_ids, &offload_token_blocks) .expect("failed to offload blocks"); - self.evaluated_blocks += num_candidate_blocks; + self.evaluated_blocks += aligned_candidates; } // done applying policy @@ -640,11 +652,11 @@ impl Slot for VllmConnectorSlot { is_new_request && computed_position > 0 && self.evaluated_blocks == 0; if maybe_have_device_matched_blocks { - self.evaluated_blocks = (computed_position + 1) / self.block_size; + self.evaluated_blocks = (computed_position + 1) / self.offload_block_size; } - let num_candidate_blocks = - ((computed_position + 1) / self.block_size).saturating_sub(self.evaluated_blocks); + let num_candidate_blocks = ((computed_position + 1) / self.offload_block_size) + .saturating_sub(self.evaluated_blocks); if num_candidate_blocks > 0 { // do we have a mechanism for skipping gpu cache hit blocks? not sure yet. @@ -745,7 +757,7 @@ impl Slot for VllmConnectorSlot { tracing::info!("slot is in the Preempted state; we get another chance to match"); } - let block_size = self.block_manager.block_size(); + let block_size = self.block_manager.offload_block_size(); let num_computed_blocks = num_computed_tokens / block_size; debug_assert!(num_computed_tokens % block_size == 0); @@ -863,28 +875,40 @@ impl Slot for VllmConnectorSlot { } debug_assert_eq!(self.evaluated_blocks, 0); - debug_assert_eq!(self.current_position % self.block_size, 0); - debug_assert_eq!(num_external_tokens % self.block_size, 0); + debug_assert_eq!(self.current_position % self.engine_block_size, 0); + debug_assert_eq!(num_external_tokens % self.engine_block_size, 0); - let num_computed_blocks = self.current_position / self.block_size; + let num_computed_blocks = self.current_position / self.offload_block_size; // shift the evaluated blocks position to the end of the computed/cached blocks - self.evaluated_blocks = num_computed_blocks; + self.evaluated_blocks = num_computed_blocks * self.offload_block_size_ratio; + + tracing::debug!( + "trigger_onboarding: self.device_blocks.len()={:?}", + self.device_blocks.len() + ); // match the host / disk blocks to the newly assigned mutable device blocks if let Some(host_blocks) = self.staging_from_host.take() { let num_host_blocks = host_blocks.len(); + tracing::debug!( + "trigger_onboarding: host_blocks.len()={:?}", + num_host_blocks + ); // get device block ids let dst_block_ids = self .device_blocks .iter() .skip(self.evaluated_blocks) - .take(num_host_blocks) + .take(num_host_blocks * self.offload_block_size_ratio) .copied() .collect::>(); - debug_assert_eq!(dst_block_ids.len(), num_host_blocks); + debug_assert_eq!( + dst_block_ids.len(), + num_host_blocks * self.offload_block_size_ratio + ); // construct offload requests - transfer 
engine + worker let src_blocks = Box::new(AnyImmutableBlocks::::new(host_blocks)); @@ -892,22 +916,30 @@ impl Slot for VllmConnectorSlot { self.onboard_blocks(src_blocks, dst_block_ids)?; // shift the evaluated blocks position to the end of the computed/cached blocks - self.evaluated_blocks += num_host_blocks; + self.evaluated_blocks += num_host_blocks * self.offload_block_size_ratio; } if let Some(disk_blocks) = self.staging_from_disk.take() { let num_disk_blocks = disk_blocks.len(); + tracing::debug!( + "trigger_onboarding: dist_blocks.len()={:?}", + num_disk_blocks + ); + // get device block ids let dst_block_ids = self .device_blocks .iter() - .skip(self.evaluated_blocks) + .skip(self.evaluated_blocks * self.offload_block_size_ratio) .take(num_disk_blocks) .copied() .collect::>(); - debug_assert_eq!(dst_block_ids.len(), num_disk_blocks); + debug_assert_eq!( + dst_block_ids.len(), + num_disk_blocks * self.offload_block_size_ratio + ); // construct offload requests - transfer engine + worker let src_blocks = Box::new(AnyImmutableBlocks::::new(disk_blocks)); @@ -915,7 +947,7 @@ impl Slot for VllmConnectorSlot { self.onboard_blocks(src_blocks, dst_block_ids)?; // shift the evaluated blocks position to the end of the computed/cached blocks - self.evaluated_blocks += num_disk_blocks; + self.evaluated_blocks += num_disk_blocks * self.offload_block_size_ratio; } self.state = SlotState::Onboarding(num_external_tokens); @@ -977,7 +1009,7 @@ impl VllmConnectorSlot { block_ids: &[BlockId], token_blocks: &[TokenBlock], ) -> Result<(), SlotError> { - assert!(block_ids.len() == token_blocks.len()); + assert!(block_ids.len() == token_blocks.len() * self.offload_block_size_ratio); let operation_id = uuid::Uuid::new_v4(); let xfer_req = LocalTransferRequest::Offload(LocalOffloadRequest::new( @@ -1007,7 +1039,7 @@ impl VllmConnectorSlot { tracing::debug!( request_id = self.request_id, operation_id = %operation_id, - "offloading {} blocks to host", + "offloading {} device blocks to host", block_ids.len() ); @@ -1166,6 +1198,7 @@ impl LocalTransferEngine { // Clone resources needed for tasks let block_manager_offload = self.block_manager.clone(); + let block_manager_onboard = self.block_manager.clone(); let leader_offload = Arc::clone(&self.leader); let leader_onboard = Arc::clone(&self.leader); @@ -1179,9 +1212,13 @@ impl LocalTransferEngine { tracing::debug!("LocalOnboardTask: received cancellation signal"); break; } - if let Err(e) = - process_onboard_request(req, &leader_onboard, kvbm_metrics_onboard.clone()) - .await + if let Err(e) = process_onboard_request( + req, + &block_manager_onboard, + &leader_onboard, + kvbm_metrics_onboard.clone(), + ) + .await { tracing::error!("LocalOnboardTask: error processing request: {:?}", e); } @@ -1285,24 +1322,30 @@ async fn process_offload_request( let request_id = &offload_req.request_id; let operation_id = &offload_req.operation_id; + let engine_blocks_per_offload_block = + block_manager.offload_block_size() / block_manager.engine_block_size(); tracing::debug!( - "Processing offload request for {} blocks", - offload_req.block_ids.len() + "Processing offload request for {} blocks, engine_blocks_per_offload_block = {} ({}/{})", + offload_req.block_ids.len(), + engine_blocks_per_offload_block, + block_manager.offload_block_size(), + block_manager.engine_block_size(), ); // 1. 
Acquire mutable host blocks let host_blocks = block_manager .host() .unwrap() - .allocate_blocks(offload_req.block_ids.len()) + .allocate_blocks(offload_req.block_ids.len() / engine_blocks_per_offload_block) .await?; let token_blocks = offload_req.token_blocks; let host_block_ids: Vec = host_blocks.iter().map(|b| b.block_id()).collect(); - let block_pairs: Vec<(usize, usize)> = offload_req + let block_pairs: Vec<_> = offload_req .block_ids - .into_iter() + .chunks(engine_blocks_per_offload_block) .zip(host_block_ids.into_iter()) + .map(|(src, dst)| (SmallVec::from(src), SmallVec::from([dst]))) .collect(); tracing::debug!( @@ -1386,6 +1429,7 @@ async fn process_offload_request( async fn process_onboard_request( onboard_req: LocalOnboardRequest, + block_manager: &VllmBlockManager, leader: &Arc, kvbm_metrics: KvbmMetrics, ) -> anyhow::Result<()> { @@ -1402,16 +1446,22 @@ async fn process_onboard_request( let request_id = &onboard_req.request_id; let operation_id = &onboard_req.operation_id; + let engine_blocks_per_offload_block = + block_manager.offload_block_size() / block_manager.engine_block_size(); // extract source block ids let src_block_ids = onboard_req.src_blocks.block_ids(); // create block pairs - let block_pairs = src_block_ids + let block_pairs: Vec<_> = src_block_ids .iter() - .zip(onboard_req.dst_block_ids.iter()) - .map(|(src, dst)| (*src, *dst)) - .collect::>(); + .zip( + onboard_req + .dst_block_ids + .chunks(engine_blocks_per_offload_block), + ) + .map(|(src, dst)| (SmallVec::from([*src]), SmallVec::from(dst))) + .collect(); // create transfer request let block_xfer_req = BlockTransferRequest { diff --git a/lib/bindings/python/rust/llm/block_manager/vllm/connector/trtllm_leader.rs b/lib/bindings/python/rust/llm/block_manager/vllm/connector/trtllm_leader.rs index 7f3b51d6dd..3f9803aecc 100644 --- a/lib/bindings/python/rust/llm/block_manager/vllm/connector/trtllm_leader.rs +++ b/lib/bindings/python/rust/llm/block_manager/vllm/connector/trtllm_leader.rs @@ -54,7 +54,9 @@ pub trait Leader: Send + Sync + std::fmt::Debug { #[derive(Debug)] pub struct KvConnectorLeader { slot_manager: Arc>>, - block_size: usize, + engine_page_size: usize, + #[allow(unused)] + offload_page_size: usize, inflight_requests: HashSet, onboarding_slots: HashSet, iteration_counter: u64, @@ -66,7 +68,8 @@ impl KvConnectorLeader { fn new( worker_id: u64, drt: PyDistributedRuntime, - page_size: usize, + engine_page_size: usize, + offload_page_size: usize, leader_py: PyKvbmLeader, ) -> Self { tracing::info!( @@ -103,7 +106,8 @@ impl KvConnectorLeader { let block_manager = match BlockManagerBuilder::new() .worker_id(0) .leader(leader_py) - .page_size(page_size) + .engine_page_size(engine_page_size) + .offload_page_size(offload_page_size) .disable_device_pool(false) .build() .await @@ -134,7 +138,8 @@ impl KvConnectorLeader { Self { slot_manager: slot_manager_cell, - block_size: page_size, + engine_page_size, + offload_page_size, inflight_requests: HashSet::new(), onboarding_slots: HashSet::new(), iteration_counter: 0, @@ -171,7 +176,7 @@ impl Leader for KvConnectorLeader { // TRTLLM could match partial blocks if enable_partial_reuse = True, // immediately return 0 to simplify things. 
- if num_computed_tokens % self.block_size != 0 { + if num_computed_tokens % self.engine_page_size != 0 { return Ok((0, false)); } @@ -181,7 +186,7 @@ impl Leader for KvConnectorLeader { .map_err(|e| anyhow::anyhow!("Failed to lock slot: {}", e))?; // early exit if we cannot match full block - if (slot.sequence().total_tokens() - num_computed_tokens) < self.block_size { + if (slot.sequence().total_tokens() - num_computed_tokens) < self.engine_page_size { let total_tokens = slot.sequence().total_tokens(); tracing::debug!( "total_tokens in sequence: {total_tokens}; num_computed_tokens: {num_computed_tokens}; can not match full block." @@ -196,7 +201,7 @@ impl Leader for KvConnectorLeader { // return the number of external tokens that are ready for onboarding // we always return true here as we always asynchronously onboard matched blocks if let SlotState::OnboardStaged(num_external_tokens) = slot.state() { - debug_assert!((num_computed_tokens + num_external_tokens) % self.block_size == 0); + debug_assert!((num_computed_tokens + num_external_tokens) % self.engine_page_size == 0); tracing::debug!( request_id = request_id, "scheduling onboarding for {} external tokens", @@ -447,15 +452,21 @@ pub struct PyTrtllmKvConnectorLeader { #[pymethods] impl PyTrtllmKvConnectorLeader { #[new] - #[pyo3(signature = (worker_id, drt, page_size, leader))] + #[pyo3(signature = (worker_id, drt, engine_page_size, offload_page_size, leader))] pub fn new( worker_id: u64, drt: PyDistributedRuntime, - page_size: usize, + engine_page_size: usize, + offload_page_size: usize, leader: PyKvbmLeader, ) -> Self { - let connector_leader: Box = - Box::new(KvConnectorLeader::new(worker_id, drt, page_size, leader)); + let connector_leader: Box = Box::new(KvConnectorLeader::new( + worker_id, + drt, + engine_page_size, + offload_page_size, + leader, + )); Self { connector_leader } } diff --git a/lib/bindings/python/rust/llm/block_manager/vllm/connector/trtllm_worker.rs b/lib/bindings/python/rust/llm/block_manager/vllm/connector/trtllm_worker.rs index 5b017db93a..0a7ad81033 100644 --- a/lib/bindings/python/rust/llm/block_manager/vllm/connector/trtllm_worker.rs +++ b/lib/bindings/python/rust/llm/block_manager/vllm/connector/trtllm_worker.rs @@ -131,7 +131,8 @@ impl Worker for KvConnectorWorker { let config = KvbmWorkerConfig::builder() .drt(self.drt.clone()) .num_device_blocks(num_device_blocks) - .page_size(page_size) + .offload_page_size(page_size) + .engine_page_size(page_size) .tensors(kv_cache_tensors) .device_id(device_id) .dtype_width_bytes(dtype_width_bytes) diff --git a/lib/bindings/python/rust/llm/block_manager/vllm/connector/worker.rs b/lib/bindings/python/rust/llm/block_manager/vllm/connector/worker.rs index 3c498a3d8e..60f5046329 100644 --- a/lib/bindings/python/rust/llm/block_manager/vllm/connector/worker.rs +++ b/lib/bindings/python/rust/llm/block_manager/vllm/connector/worker.rs @@ -73,6 +73,9 @@ pub struct KvConnectorWorker { /// cuda events created by the python side layer_events: Vec, + + /// Ratio between offload block size and engine block size + offload_block_size_ratio: usize, } impl KvConnectorWorker { @@ -111,6 +114,7 @@ impl KvConnectorWorker { layers_complete: 0, kv_cache_layers: Vec::new(), layer_events: Vec::new(), + offload_block_size_ratio: 32, // Default value }) } } @@ -196,7 +200,8 @@ impl Worker for KvConnectorWorker { let config = KvbmWorkerConfig::builder() .drt(self.drt.clone()) .num_device_blocks(num_device_blocks) - .page_size(page_size) + .offload_page_size(page_size * 
self.offload_block_size_ratio) + .engine_page_size(page_size) .tensors(vllm_tensors) .device_id(device_id) .dtype_width_bytes(dtype_width_bytes) diff --git a/lib/bindings/python/src/dynamo/llm/vllm_integration/connector_leader.py b/lib/bindings/python/src/dynamo/llm/vllm_integration/connector_leader.py index 8c69909c48..19546a3856 100644 --- a/lib/bindings/python/src/dynamo/llm/vllm_integration/connector_leader.py +++ b/lib/bindings/python/src/dynamo/llm/vllm_integration/connector_leader.py @@ -64,7 +64,7 @@ def __init__(self, vllm_config: "VllmConfig", engine_id: str, **kwargs): print(f"KvConnectorLeader initialized with engine_id: {engine_id}") self._connector = RustKvConnectorLeader( - engine_id, self.drt, vllm_config.cache_config.block_size, leader + engine_id, self.drt, vllm_config.cache_config.block_size, 1024, leader ) # KV Connector diff --git a/lib/llm/Cargo.toml b/lib/llm/Cargo.toml index d81cd42bf8..86e8a74a15 100644 --- a/lib/llm/Cargo.toml +++ b/lib/llm/Cargo.toml @@ -141,6 +141,7 @@ json-five = { version = "0.3" } zeromq = "0.4.1" rmp-serde = "1.3" ahash = "0.8.12" +smallvec = { version = "1.15.1", features = ["serde"] } [dev-dependencies] approx = "0.5" diff --git a/lib/llm/src/block_manager.rs b/lib/llm/src/block_manager.rs index 1878d68e30..d3c2f165cc 100644 --- a/lib/llm/src/block_manager.rs +++ b/lib/llm/src/block_manager.rs @@ -99,7 +99,9 @@ impl Drop for CancelOnLastDrop { pub struct KvBlockManager { state: Arc>, _cancellation_token: Arc, - block_size: usize, + offload_block_size: usize, + engine_block_size: usize, + offload_block_size_ratio: usize, } impl Clone @@ -109,15 +111,27 @@ impl Clone Self { state: self.state.clone(), _cancellation_token: self._cancellation_token.clone(), - block_size: self.block_size, + offload_block_size: self.offload_block_size, + engine_block_size: self.engine_block_size, + offload_block_size_ratio: self.offload_block_size_ratio, } } } impl KvBlockManager { /// Get the block size - pub fn block_size(&self) -> usize { - self.block_size + pub fn engine_block_size(&self) -> usize { + self.engine_block_size + } + + /// Get the block size + pub fn offload_block_size(&self) -> usize { + self.offload_block_size + } + + /// Get the offload block size ratio + pub fn offload_block_size_ratio(&self) -> usize { + self.offload_block_size_ratio } /// Get a reference to the disk block pool @@ -171,6 +185,7 @@ impl KvBlockManager { let _cancellation_token = build_cancel_token(&mut config); let block_size = config.model.page_size; + let offload_block_size_ratio = config.offload_block_size_ratio; // Create the internal state let state = state::KvBlockManagerState::::new(config).await?; @@ -178,7 +193,9 @@ impl KvBlockManager { Ok(Self { state, _cancellation_token, - block_size, + engine_block_size: block_size, + offload_block_size: block_size * offload_block_size_ratio, + offload_block_size_ratio, }) } @@ -215,6 +232,7 @@ impl KvBlockManager { impl KvBlockManager, Metadata> { pub async fn new(mut config: KvBlockManagerConfig, logical_resources: R) -> Result { let block_size = config.model.page_size; + let offload_block_size_ratio = config.offload_block_size_ratio; let _cancellation_token = build_cancel_token(&mut config); @@ -227,7 +245,9 @@ impl KvBlockManager: Send + Sync + 'static + std::fmt::Debug { /// The index of the block in the block set fn block_id(&self) -> BlockId; + /// The index of the block in the block set + fn fragment_block_id(&self, idx: usize) -> BlockId; + /// The identifier of the block set within the worker fn block_set_id(&self) -> 
usize; @@ -28,6 +31,12 @@ pub trait BlockDataExt: Send + Sync + 'static + std::fmt::Debug { /// Whether the block is fully contiguous fn is_fully_contiguous(&self) -> bool; + /// Is the block fragmented + fn is_fragmented(&self) -> bool; + + /// Returns the number of fragments + fn num_fragments(&self) -> usize; + /// Returns the number of layers in the block fn num_layers(&self) -> usize; @@ -57,6 +66,19 @@ pub trait BlockDataExt: Send + Sync + 'static + std::fmt::Debug { } } + /// Get a read-only view of this block's storage for a layer + fn layer_view_fragment( + &self, + fragment_idx: usize, + layer_idx: usize, + outer_idx: usize, + ) -> BlockResult> { + match self.is_local() { + Some(views) => views.local_layer_view_fragment(fragment_idx, layer_idx, outer_idx), + None => Err(BlockError::ViewsNotAvailableOnLogicalBlocks), + } + } + /// Get a mutable view of this block's storage for a layer fn layer_view_mut( &mut self, @@ -69,6 +91,19 @@ pub trait BlockDataExt: Send + Sync + 'static + std::fmt::Debug { } } + /// Get a mutable view of this block's storage for a layer + fn layer_view_fragment_mut( + &mut self, + fragment_idx: usize, + layer_idx: usize, + outer_idx: usize, + ) -> BlockResult> { + match self.is_local_mut() { + Some(views) => views.local_layer_view_fragment_mut(fragment_idx, layer_idx, outer_idx), + None => Err(BlockError::ViewsNotAvailableOnLogicalBlocks), + } + } + /// Get a read-only view of this block's storage fn block_view(&self) -> BlockResult> { match self.is_local() { @@ -94,6 +129,14 @@ pub trait BlockDataViews { outer_idx: usize, ) -> BlockResult>; + /// Get a read-only view of this block's fragment storage for a layer + fn local_layer_view_fragment( + &self, + fragment_idx: usize, + layer_idx: usize, + outer_idx: usize, + ) -> BlockResult>; + /// Get a mutable view of this block's storage for a layer fn local_layer_view_mut( &mut self, @@ -101,6 +144,14 @@ pub trait BlockDataViews { outer_idx: usize, ) -> BlockResult>; + /// Get a mutable view of this block's fragment storage for a layer + fn local_layer_view_fragment_mut( + &mut self, + fragment_idx: usize, + layer_idx: usize, + outer_idx: usize, + ) -> BlockResult>; + /// Get a read-only view of this block's storage fn local_block_view(&self) -> BlockResult>; diff --git a/lib/llm/src/block_manager/block/data/local.rs b/lib/llm/src/block_manager/block/data/local.rs index f1679f5eac..f88ec57d99 100644 --- a/lib/llm/src/block_manager/block/data/local.rs +++ b/lib/llm/src/block_manager/block/data/local.rs @@ -2,12 +2,15 @@ // SPDX-License-Identifier: Apache-2.0 use super::*; +use smallvec::SmallVec; /// Individual block storage #[derive(Debug)] pub struct LocalBlockData { layout: Arc>, - block_idx: usize, + // The special case of multiple block IDx will only happen for small KV cache blocks on device. + // offloaded blocks will be large and `block_idxs.len() == 1` will hold for them. 
+ block_idxs: SmallVec<[usize; 1]>, block_set_idx: usize, worker_id: WorkerID, } @@ -16,7 +19,34 @@ impl Clone for LocalBlockData { fn clone(&self) -> Self { Self { layout: self.layout.clone(), - block_idx: self.block_idx, + block_idxs: self.block_idxs.clone(), + block_set_idx: self.block_set_idx, + worker_id: self.worker_id, + } + } +} + +pub struct LocalBlockDataBase { + layout: Arc>, + block_set_idx: usize, + worker_id: WorkerID, +} + +impl Clone for LocalBlockDataBase { + fn clone(&self) -> Self { + Self { + layout: self.layout.clone(), + block_set_idx: self.block_set_idx, + worker_id: self.worker_id, + } + } +} + +impl LocalBlockDataBase { + pub(crate) fn get_data(&self, block_idxs: SmallVec<[usize; 1]>) -> LocalBlockData { + LocalBlockData { + layout: self.layout.clone(), + block_idxs, block_set_idx: self.block_set_idx, worker_id: self.worker_id, } @@ -36,11 +66,19 @@ where ) -> Self { Self { layout, - block_idx, + block_idxs: [block_idx].iter().map(|x| *x).collect(), block_set_idx, worker_id, } } + + pub(crate) fn base(&self) -> LocalBlockDataBase { + LocalBlockDataBase { + layout: self.layout.clone(), + block_set_idx: self.block_set_idx, + worker_id: self.worker_id.clone(), + } + } } impl BlockDataExt for LocalBlockData @@ -49,7 +87,22 @@ where { #[inline(always)] fn block_id(&self) -> BlockId { - self.block_idx + if self.block_idxs.len() == 1 { + self.block_idxs[0] + } else { + tracing::error!("Backtrace: {}", std::backtrace::Backtrace::force_capture()); + panic!("used LocalBlockData::block_id() for fragmented block"); + } + } + + #[inline(always)] + fn fragment_block_id(&self, idx: usize) -> BlockId { + if self.block_idxs.len() != 1 { + self.block_idxs[idx] + } else { + tracing::error!("Backtrace: {}", std::backtrace::Backtrace::force_capture()); + panic!("used LocalBlockData::fragment_block_id() for non--fragmented block"); + } } #[inline(always)] @@ -71,6 +124,14 @@ where self.layout.layout_type() == LayoutType::FullyContiguous } + fn is_fragmented(&self) -> bool { + self.block_idxs.len() != 1 + } + + fn num_fragments(&self) -> usize { + self.block_idxs.len() + } + fn num_layers(&self) -> usize { self.layout.num_layers() } @@ -104,7 +165,22 @@ impl BlockDataViews for LocalBlockData { ) -> BlockResult> { let mr = self .layout - .memory_region(self.block_idx, layer_idx, outer_idx)?; + .memory_region(self.block_id(), layer_idx, outer_idx)?; + let storage_type = mr.storage_type(); + unsafe { view::LayerView::new(self, mr.addr(), mr.size(), storage_type) } + } + + fn local_layer_view_fragment( + &self, + fragment_idx: usize, + layer_idx: usize, + outer_idx: usize, + ) -> BlockResult> { + let mr = self.layout.memory_region( + self.fragment_block_id(fragment_idx), + layer_idx, + outer_idx, + )?; let storage_type = mr.storage_type(); unsafe { view::LayerView::new(self, mr.addr(), mr.size(), storage_type) } } @@ -116,13 +192,27 @@ impl BlockDataViews for LocalBlockData { ) -> BlockResult> { let mr = self .layout - .memory_region(self.block_idx, layer_idx, outer_idx)?; + .memory_region(self.block_id(), layer_idx, outer_idx)?; + unsafe { view::LayerViewMut::new(self, mr.addr(), mr.size(), mr.storage_type()) } + } + + fn local_layer_view_fragment_mut( + &mut self, + fragment_idx: usize, + layer_idx: usize, + outer_idx: usize, + ) -> BlockResult> { + let mr = self.layout.memory_region( + self.fragment_block_id(fragment_idx), + layer_idx, + outer_idx, + )?; unsafe { view::LayerViewMut::new(self, mr.addr(), mr.size(), mr.storage_type()) } } fn local_block_view(&self) -> BlockResult> { if 
self.is_fully_contiguous() { - let mr = self.layout.memory_region(self.block_idx, 0, 0)?; + let mr = self.layout.memory_region(self.block_id(), 0, 0)?; let offset = mr.addr(); let size = mr.size() .checked_mul(self.num_layers()) @@ -144,7 +234,7 @@ impl BlockDataViews for LocalBlockData { fn local_block_view_mut(&mut self) -> BlockResult> { if self.is_fully_contiguous() { - let mr = self.layout.memory_region(self.block_idx, 0, 0)?; + let mr = self.layout.memory_region(self.block_id(), 0, 0)?; let offset = mr.addr(); let size = mr.size() .checked_mul(self.num_layers()) diff --git a/lib/llm/src/block_manager/block/data/logical.rs b/lib/llm/src/block_manager/block/data/logical.rs index 3aee97bc3e..81d9357953 100644 --- a/lib/llm/src/block_manager/block/data/logical.rs +++ b/lib/llm/src/block_manager/block/data/logical.rs @@ -5,13 +5,13 @@ use super::*; pub mod distributed_leader_worker; pub mod null; - use crate::block_manager::block::{ BlockDataProvider, ReadableBlock, WritableBlock, transfer::{TransferContext, TransferError, WriteToStrategy}, }; use crate::block_manager::locality::Logical; use crate::block_manager::storage::{self, nixl::NixlDescriptor}; +use smallvec::SmallVec; use tokio::sync::oneshot; pub enum LogicalKinds { @@ -37,7 +37,7 @@ pub trait LogicalResources: Clone + Send + Sync + 'static + std::fmt::Debug { /// Individual block storage - cannot be cloned to ensure uniqueness #[derive(Debug)] pub struct LogicalBlockData { - block_id: BlockId, + block_ids: SmallVec<[BlockId; 1]>, block_set_id: usize, worker_id: WorkerID, resources: Arc, @@ -56,7 +56,7 @@ impl LogicalBlockData { page_size: usize, ) -> Self { Self { - block_id, + block_ids: [block_id].iter().map(|x| *x).collect(), block_set_id, worker_id, resources, @@ -73,7 +73,21 @@ impl LogicalBlockData { impl BlockDataExt for LogicalBlockData { fn block_id(&self) -> BlockId { - self.block_id + if self.block_ids.len() == 1 { + self.block_ids[0] + } else { + panic!("used LocalBlockData::block_id() for fragmented block"); + } + } + + #[inline(always)] + fn fragment_block_id(&self, idx: usize) -> BlockId { + if self.block_ids.len() != 1 { + self.block_ids[idx] + } else { + tracing::error!("Backtrace: {}", std::backtrace::Backtrace::force_capture()); + panic!("used LocalBlockData::fragment_block_id() for non-fragmented block"); + } } fn block_set_id(&self) -> usize { @@ -92,6 +106,14 @@ impl BlockDataExt for LogicalBlockData unimplemented!() } + fn is_fragmented(&self) -> bool { + unimplemented!() + } + + fn num_fragments(&self) -> usize { + unimplemented!() + } + fn num_layers(&self) -> usize { unimplemented!() } diff --git a/lib/llm/src/block_manager/block/transfer/context.rs b/lib/llm/src/block_manager/block/transfer/context.rs index 36ad83a4c0..7327cf9615 100644 --- a/lib/llm/src/block_manager/block/transfer/context.rs +++ b/lib/llm/src/block_manager/block/transfer/context.rs @@ -161,6 +161,7 @@ pub struct PoolConfig { pub max_concurrent_transfers: usize, pub max_transfer_batch_size: usize, pub num_outer_components: usize, + pub offload_block_size_ratio: usize, pub num_layers: usize, } @@ -200,6 +201,7 @@ impl TransferContext { let buffer_size = max_blocks_per_transfer * config.num_outer_components * config.num_layers + * config.offload_block_size_ratio * std::mem::size_of::(); tracing::info!( diff --git a/lib/llm/src/block_manager/block/transfer/cuda.rs b/lib/llm/src/block_manager/block/transfer/cuda.rs index fdc345c2ff..c8bff9980a 100644 --- a/lib/llm/src/block_manager/block/transfer/cuda.rs +++ 
b/lib/llm/src/block_manager/block/transfer/cuda.rs @@ -75,6 +75,7 @@ fn collect_kv_addresses( destinations: &[Destination], num_layers: usize, num_outer_dims: usize, + num_fragments: usize, ) -> Result<(Vec, Vec), TransferError> where Source: BlockDataProvider, @@ -86,7 +87,7 @@ where )); } - let total_address_pairs = sources.len() * num_layers * num_outer_dims; + let total_address_pairs = sources.len() * num_layers * num_outer_dims * num_fragments; let mut src_addresses = Vec::with_capacity(total_address_pairs); let mut dst_addresses = Vec::with_capacity(total_address_pairs); @@ -99,12 +100,46 @@ where for (src_data, dst_data) in src_block_data.iter().zip(dst_block_data.iter()) { for layer_idx in 0..num_layers { for outer_idx in 0..num_outer_dims { - let src_view = src_data.layer_view(layer_idx, outer_idx)?; - let dst_view = dst_data.layer_view(layer_idx, outer_idx)?; + if src_data.is_fragmented() { + let dst_view = dst_data.layer_view(layer_idx, outer_idx)?; + let mut dst_ptr = unsafe { dst_view.as_ptr() }; + let n = src_data.num_fragments(); - unsafe { - src_addresses.push(src_view.as_ptr() as u64); - dst_addresses.push(dst_view.as_ptr() as u64); + for i in 0..n { + let src_view = src_data.layer_view_fragment(i, layer_idx, outer_idx)?; + debug_assert_eq!(src_view.size() * n, dst_view.size()); + + unsafe { + src_addresses.push(src_view.as_ptr() as u64); + dst_addresses.push(dst_view.as_ptr() as u64); + + dst_ptr = dst_ptr.add(src_view.size()); + } + } + } else if dst_data.is_fragmented() { + let src_view = src_data.layer_view(layer_idx, outer_idx)?; + let mut src_ptr = unsafe { src_view.as_ptr() }; + let n = dst_data.num_fragments(); + + for i in 0..n { + let dst_view = dst_data.layer_view_fragment(i, layer_idx, outer_idx)?; + debug_assert_eq!(src_view.size(), dst_view.size() * n); + + unsafe { + src_addresses.push(src_view.as_ptr() as u64); + dst_addresses.push(dst_view.as_ptr() as u64); + + src_ptr = src_ptr.add(dst_view.size()); + } + } + } else { + let src_view = src_data.layer_view(layer_idx, outer_idx)?; + let dst_view = dst_data.layer_view(layer_idx, outer_idx)?; + + unsafe { + src_addresses.push(src_view.as_ptr() as u64); + dst_addresses.push(dst_view.as_ptr() as u64); + } } } } @@ -183,31 +218,48 @@ struct CachedBlockDimensions { num_layers: usize, num_outer_dims: usize, layer_size: usize, + num_fragments: usize, } static BLOCK_DIMENSIONS_CACHE: OnceLock = OnceLock::new(); -fn get_cached_block_dimensions( - block: &T, +fn get_cached_block_dimensions( + src_block: &T, + dst_block: &S, ) -> Result { Ok(*BLOCK_DIMENSIONS_CACHE - .get_or_init(|| calculate_block_dimensions_from_layout(block).unwrap())) + .get_or_init(|| calculate_block_dimensions_from_layout(src_block, dst_block).unwrap())) } -fn calculate_block_dimensions_from_layout( - block: &T, +fn calculate_block_dimensions_from_layout( + src_block: &T, + dst_block: &S, ) -> Result { - let block_data = block.block_data(); + let src_block_data = src_block.block_data(); + let dst_block_data = dst_block.block_data(); // Get dimensions directly from layout (pre-computed values) - let num_layers = block_data.num_layers(); - let num_outer_dims = block_data.num_outer_dims(); - let layer_size = block_data.layer_view(0, 0).map(|v| v.size()).unwrap_or(0); + let num_layers = src_block_data.num_layers(); + let num_outer_dims = src_block_data.num_outer_dims(); + let num_fragments = src_block_data.num_fragments(); + let layer_size = if dst_block_data.is_fragmented() { + src_block_data + .layer_view(0, 0) + .map(|v| v.size() / 
dst_block_data.num_fragments()) + } else if src_block_data.is_fragmented() { + src_block_data + .layer_view_fragment(0, 0, 0) + .map(|v| v.size()) + } else { + src_block_data.layer_view(0, 0).map(|v| v.size()) + } + .unwrap_or(0); Ok(CachedBlockDimensions { num_layers, num_outer_dims, layer_size, + num_fragments, }) } @@ -223,18 +275,24 @@ where { let _context_guard = stream.context().bind_to_thread(); // Get cached dimensions (calculated once per program lifetime!) - let dims = get_cached_block_dimensions(&sources[0])?; + let dims = get_cached_block_dimensions(&sources[0], &destinations[0])?; // Use cached dimensions - let (src_addresses, dst_addresses) = - collect_kv_addresses(sources, destinations, dims.num_layers, dims.num_outer_dims)?; + let (src_addresses, dst_addresses) = collect_kv_addresses( + sources, + destinations, + dims.num_layers, + dims.num_outer_dims, + dims.num_fragments, + )?; tracing::debug!( - "Using vectorized_copy for {} blocks [{}L×{}O×{}B], {} address pairs", + "Using vectorized_copy for {} blocks [{}L×{}O×{}Bx{}F], {} address pairs", sources.len(), dims.num_layers, dims.num_outer_dims, dims.layer_size, + dims.num_fragments, src_addresses.len() ); @@ -343,18 +401,50 @@ where for layer_idx in layer_range { for outer_idx in 0..src_data.num_outer_dims() { - let src_view = src_data.layer_view(layer_idx, outer_idx)?; - let mut dst_view = dst_data.layer_view_mut(layer_idx, outer_idx)?; + if src_data.is_fragmented() { + let mut dst_view = dst_data.layer_view_mut(layer_idx, outer_idx)?; + let mut dst_ptr = unsafe { dst_view.as_mut_ptr() }; + let n = src_data.num_fragments(); - debug_assert_eq!(src_view.size(), dst_view.size()); + for i in 0..n { + let src_view = src_data.layer_view_fragment(i, layer_idx, outer_idx)?; + debug_assert_eq!(src_view.size() * n, dst_view.size()); - unsafe { - memcpy_fn( - src_view.as_ptr(), - dst_view.as_mut_ptr(), - src_view.size(), - stream, - )?; + unsafe { + memcpy_fn(src_view.as_ptr(), dst_ptr, src_view.size(), stream)?; + + dst_ptr = dst_ptr.add(src_view.size()); + } + } + } else if dst_data.is_fragmented() { + let src_view = src_data.layer_view(layer_idx, outer_idx)?; + let mut src_ptr = unsafe { src_view.as_ptr() }; + let n = dst_data.num_fragments(); + + for i in 0..n { + let mut dst_view = dst_data.layer_view_fragment_mut(i, layer_idx, outer_idx)?; + debug_assert_eq!(src_view.size(), dst_view.size() * n); + + unsafe { + memcpy_fn(src_ptr, dst_view.as_mut_ptr(), dst_view.size(), stream)?; + + src_ptr = src_ptr.add(dst_view.size()); + } + } + } else { + let src_view = src_data.layer_view(layer_idx, outer_idx)?; + let mut dst_view = dst_data.layer_view_mut(layer_idx, outer_idx)?; + + debug_assert_eq!(src_view.size(), dst_view.size()); + + unsafe { + memcpy_fn( + src_view.as_ptr(), + dst_view.as_mut_ptr(), + src_view.size(), + stream, + )?; + } } } } diff --git a/lib/llm/src/block_manager/config.rs b/lib/llm/src/block_manager/config.rs index d9aa5e70be..4c344d4079 100644 --- a/lib/llm/src/block_manager/config.rs +++ b/lib/llm/src/block_manager/config.rs @@ -194,6 +194,12 @@ pub struct KvBlockManagerConfig { /// Channel to reset the block manager to a specific cache level #[builder(default)] pub block_reset_channel: Option, + + /// Ratio between offload block size and engine block size + /// Offload blocks are larger to enable more efficient I/O operations + #[builder(default = "32")] + #[validate(range(min = 1))] + pub offload_block_size_ratio: usize, } impl KvBlockManagerConfig { diff --git 
a/lib/llm/src/block_manager/distributed/transfer.rs b/lib/llm/src/block_manager/distributed/transfer.rs index fb1c7f452c..d566304e3a 100644 --- a/lib/llm/src/block_manager/distributed/transfer.rs +++ b/lib/llm/src/block_manager/distributed/transfer.rs @@ -15,6 +15,7 @@ use crate::block_manager::{ block::{ Block, BlockDataProvider, BlockDataProviderMut, ReadableBlock, WritableBlock, data::local::LocalBlockData, + data::local::LocalBlockDataBase, locality, transfer::{TransferContext, WriteTo, WriteToStrategy}, }, @@ -28,7 +29,6 @@ use async_trait::async_trait; use std::{any::Any, sync::Arc}; type LocalBlock = Block; -type LocalBlockDataList = Vec>; /// A batching wrapper for connector transfers to prevent resource exhaustion. /// Splits large transfers into smaller batches that can be handled by the resource pools. @@ -86,9 +86,9 @@ impl ConnectorTransferBatcher { /// A handler for all block transfers. Wraps a group of [`BlockTransferPoolManager`]s. #[derive(Clone)] pub struct BlockTransferHandler { - device: Option>, - host: Option>, - disk: Option>, + device: Option>, + host: Option>, + disk: Option>, context: Arc, scheduler_client: Option, batcher: ConnectorTransferBatcher, @@ -105,38 +105,36 @@ impl BlockTransferHandler { // add worker-connector scheduler client here ) -> Result { Ok(Self { - device: Self::get_local_data(device_blocks), - host: Self::get_local_data(host_blocks), - disk: Self::get_local_data(disk_blocks), + device: Self::get_local_data_base(device_blocks), + host: Self::get_local_data_base(host_blocks), + disk: Self::get_local_data_base(disk_blocks), context, scheduler_client, batcher: ConnectorTransferBatcher::new(), }) } - fn get_local_data( + fn get_local_data_base( blocks: Option>>, - ) -> Option> { - blocks.map(|blocks| { - blocks - .into_iter() - .map(|b| { - let block_data = b.block_data() as &dyn Any; - - block_data - .downcast_ref::>() - .unwrap() - .clone() - }) - .collect() - }) + ) -> Option> { + let Some(vec) = blocks else { return None }; + if let Some(b) = vec.first() { + let block_data = b.block_data() as &dyn Any; + let block_data = block_data + .downcast_ref::>() + .unwrap() + .clone(); + Some(block_data.base()) + } else { + None + } } /// Initiate a transfer between two pools. async fn begin_transfer( &self, - source_pool_list: &Option>, - target_pool_list: &Option>, + source_pool_base: &Option>, + target_pool_base: &Option>, request: BlockTransferRequest, ) -> Result> where @@ -150,23 +148,23 @@ impl BlockTransferHandler { LocalBlockData: BlockDataProvider, LocalBlockData: BlockDataProviderMut, { - let Some(source_pool_list) = source_pool_list else { + let Some(source_pool_base) = source_pool_base else { return Err(anyhow::anyhow!("Source pool manager not initialized")); }; - let Some(target_pool_list) = target_pool_list else { + let Some(target_pool_base) = target_pool_base else { return Err(anyhow::anyhow!("Target pool manager not initialized")); }; // Extract the `from` and `to` indices from the request. - let source_idxs = request.blocks().iter().map(|(from, _)| *from); - let target_idxs = request.blocks().iter().map(|(_, to)| *to); + let source_idxs = request.blocks().iter().map(|(from, _)| from); + let target_idxs = request.blocks().iter().map(|(_, to)| to); // Get the blocks corresponding to the indices. 
let sources: Vec> = source_idxs - .map(|idx| source_pool_list[idx].clone()) + .map(|idx| source_pool_base.get_data(idx.clone())) .collect(); let mut targets: Vec> = target_idxs - .map(|idx| target_pool_list[idx].clone()) + .map(|idx| target_pool_base.get_data(idx.clone())) .collect(); // Perform the transfer, and return the notifying channel. diff --git a/lib/llm/src/block_manager/distributed/utils.rs b/lib/llm/src/block_manager/distributed/utils.rs index 5798a87fbb..573de49ed5 100644 --- a/lib/llm/src/block_manager/distributed/utils.rs +++ b/lib/llm/src/block_manager/distributed/utils.rs @@ -3,6 +3,7 @@ use derive_getters::Getters; use serde::{Deserialize, Serialize}; +use smallvec::SmallVec; use crate::block_manager::connector::protocol::LeaderTransferRequest; @@ -33,7 +34,7 @@ pub struct ConnectorRequestLeader { pub struct BlockTransferRequest { pub from_pool: BlockTransferPool, pub to_pool: BlockTransferPool, - pub blocks: Vec<(usize, usize)>, + pub blocks: Vec<(SmallVec<[usize; 1]>, SmallVec<[usize; 1]>)>, #[serde(skip_serializing_if = "Option::is_none")] pub connector_req: Option, @@ -49,7 +50,15 @@ impl BlockTransferRequest { Self { from_pool, to_pool, - blocks, + blocks: blocks + .into_iter() + .map(|(src, dst)| { + ( + [src].iter().map(|x| *x).collect(), + [dst].iter().map(|x| *x).collect(), + ) + }) + .collect(), connector_req: None, } } @@ -63,7 +72,15 @@ impl BlockTransferRequest { Self { from_pool, to_pool, - blocks, + blocks: blocks + .into_iter() + .map(|(src, dst)| { + ( + [src].iter().map(|x| *x).collect(), + [dst].iter().map(|x| *x).collect(), + ) + }) + .collect(), connector_req: Some(connector_req), } } diff --git a/lib/llm/src/block_manager/distributed/worker.rs b/lib/llm/src/block_manager/distributed/worker.rs index 578461a700..1da7e57f89 100644 --- a/lib/llm/src/block_manager/distributed/worker.rs +++ b/lib/llm/src/block_manager/distributed/worker.rs @@ -90,7 +90,10 @@ pub struct KvbmWorkerConfig { num_device_blocks: usize, #[builder(default = "32")] - page_size: usize, + engine_page_size: usize, + + #[builder(default = "32")] + offload_page_size: usize, #[builder(default = "Vec::new()")] tensors: Vec>, @@ -143,9 +146,10 @@ pub struct KvbmWorker { impl KvbmWorker { pub async fn new(config: KvbmWorkerConfig, layout_blocking: bool) -> anyhow::Result { tracing::info!( - "Initializing KvbmWorker with params: num_device_blocks={}, page_size={}, dtype_width_bytes={}", + "Initializing KvbmWorker with params: num_device_blocks={}, engine_page_size={}, offload_page_size={}, dtype_width_bytes={}", config.num_device_blocks, - config.page_size, + config.engine_page_size, + config.offload_page_size, config.dtype_width_bytes ); @@ -166,12 +170,13 @@ impl KvbmWorker { LayoutType::FullyContiguous => { let num_layers = shape[1]; let outer_dim = shape[2]; - let inner_dim = shape[3..].iter().product::() / config.page_size; + let inner_dim = shape[3..].iter().product::() / config.engine_page_size; tracing::info!( - "Inferred layout: num_layers={}, outer_dim={}, page_size={}, inner_dim={}", + "Inferred layout: num_layers={}, outer_dim={}, page_size={}/{}, inner_dim={}", num_layers, outer_dim, - config.page_size, + config.engine_page_size, + config.offload_page_size, inner_dim ); @@ -194,14 +199,15 @@ impl KvbmWorker { }; let num_layers = device_tensors.len(); - let inner_dim = shape[2..].iter().product::() / config.page_size; + let inner_dim = shape[2..].iter().product::() / config.engine_page_size; tracing::info!( - "Inferred layout: num_layers={}, outer_dim={}, outer_contiguous={}, 
page_size={}, inner_dim={}", + "Inferred layout: num_layers={}, outer_dim={}, outer_contiguous={}, page_size={}/{}, inner_dim={}", num_layers, outer_dim, outer_contiguous, - config.page_size, + config.engine_page_size, + config.offload_page_size, inner_dim ); @@ -209,14 +215,17 @@ impl KvbmWorker { } }; - let bytes_per_block = - num_layers * outer_dim * config.page_size * inner_dim * config.dtype_width_bytes; + let offload_bytes_per_block = num_layers + * outer_dim + * config.offload_page_size + * inner_dim + * config.dtype_width_bytes; let mut layout_builder_instance = LayoutConfigBuilder::default(); let layout_builder = layout_builder_instance .num_layers(num_layers) .outer_dim(outer_dim) - .page_size(config.page_size) + .page_size(config.engine_page_size) .inner_dim(inner_dim) .dtype_width_bytes(config.dtype_width_bytes); @@ -225,12 +234,12 @@ impl KvbmWorker { .build()? .create_layout(layout_type, device_tensors)?; - let layout_builder = layout_builder.clone(); + let layout_builder = layout_builder.page_size(config.offload_page_size).clone(); let (task, handler_rx) = if layout_blocking { Self::run_blocking_layout_initialization( config, - bytes_per_block, + offload_bytes_per_block, device_layout, layout_builder, layout_type, @@ -239,7 +248,7 @@ impl KvbmWorker { } else { Self::run_non_blocking_layout_initialization( config, - bytes_per_block, + offload_bytes_per_block, device_layout, layout_builder, layout_type, @@ -579,6 +588,7 @@ impl KvbmWorker { max_transfer_batch_size: MAX_TRANSFER_BATCH_SIZE, num_outer_components: device_layout.config().outer_dim, num_layers: device_layout.config().num_layers, + offload_block_size_ratio: config.offload_page_size / config.engine_page_size, }; let transfer_context = Arc::new(TransferContext::new( diff --git a/lib/llm/src/block_manager/offload.rs b/lib/llm/src/block_manager/offload.rs index 94e51751bb..445411f461 100644 --- a/lib/llm/src/block_manager/offload.rs +++ b/lib/llm/src/block_manager/offload.rs @@ -74,6 +74,7 @@ pub struct OffloadManagerConfig { pub metrics: Arc, pub cancellation_token: CancellationToken, pub model_config: KvManagerModelConfig, + pub offload_block_size_ratio: usize, } /// The offload manager handles all block transfers between different cache levels. @@ -131,6 +132,7 @@ impl max_transfer_batch_size: MAX_TRANSFER_BATCH_SIZE, num_outer_components: config.model_config.outer_dim, num_layers: config.model_config.num_layers, + offload_block_size_ratio: config.offload_block_size_ratio, }; // We want cuda offloads to happen in parallel with host onboards, so we need to use a different stream. 
@@ -765,6 +767,7 @@ mod tests { metrics: BlockManagerMetrics::new(&Arc::new(Registry::new()))?, cancellation_token: CancellationToken::new(), model_config: minimal_config, + offload_block_size_ratio: 1, }; let manager = OffloadManager::new( diff --git a/lib/llm/src/block_manager/state.rs b/lib/llm/src/block_manager/state.rs index 5cf62bff4d..1f63c92dff 100644 --- a/lib/llm/src/block_manager/state.rs +++ b/lib/llm/src/block_manager/state.rs @@ -104,6 +104,7 @@ impl { pub async fn new(config: KvBlockManagerConfig, logical_resources: R) -> Result> { let model_config = config.model.clone(); + let offload_block_size_ratio = config.offload_block_size_ratio; let mut resources = Resources::new(config)?; let block_data_factories = logical::LogicalBlockFactories::new(&mut resources, logical_resources)?; @@ -152,6 +153,7 @@ impl metrics: resources.metrics.clone(), cancellation_token: resources.cancellation_token.clone(), model_config, + offload_block_size_ratio, }; let offload_manager = OffloadManager::new( @@ -213,6 +215,7 @@ impl impl KvBlockManagerState { pub async fn new(config: KvBlockManagerConfig) -> Result> { let model_config = config.model.clone(); + let offload_block_size_ratio = config.offload_block_size_ratio; let mut resources = Resources::new(config)?; let block_data_factories = local::LocalBlockDataFactories::new(&mut resources)?; @@ -267,6 +270,7 @@ impl KvBlockManagerState { metrics: resources.metrics.clone(), cancellation_token: resources.cancellation_token.clone(), model_config, + offload_block_size_ratio, }; let offload_manager = OffloadManager::new( diff --git a/lib/llm/src/block_manager/state/resources.rs b/lib/llm/src/block_manager/state/resources.rs index 1a17228b41..391d66a3a0 100644 --- a/lib/llm/src/block_manager/state/resources.rs +++ b/lib/llm/src/block_manager/state/resources.rs @@ -89,7 +89,7 @@ impl Resources { layout_builder .num_layers(model.num_layers) .outer_dim(model.outer_dim) - .page_size(model.page_size) + .page_size(model.page_size * self.config.offload_block_size_ratio) .inner_dim(model.inner_dim) .dtype_width_bytes(model.dtype_width_bytes);