diff --git a/Cargo.lock b/Cargo.lock index 42094da04f..b149e33ed5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2178,6 +2178,7 @@ dependencies = [ "serde", "serde_json", "serial_test", + "smallvec", "strum", "temp-env", "tempfile", @@ -7622,6 +7623,9 @@ name = "smallvec" version = "1.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" +dependencies = [ + "serde", +] [[package]] name = "socket2" diff --git a/lib/bindings/python/Cargo.lock b/lib/bindings/python/Cargo.lock index bea47ac6e3..78d42ded5d 100644 --- a/lib/bindings/python/Cargo.lock +++ b/lib/bindings/python/Cargo.lock @@ -1510,6 +1510,7 @@ dependencies = [ "rustls", "serde", "serde_json", + "smallvec", "strum", "tempfile", "thiserror 2.0.16", @@ -1579,6 +1580,7 @@ dependencies = [ "rstest", "serde", "serde_json", + "smallvec", "socket2 0.6.0", "thiserror 2.0.16", "tokio", @@ -5686,6 +5688,9 @@ name = "smallvec" version = "1.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" +dependencies = [ + "serde", +] [[package]] name = "socket2" diff --git a/lib/bindings/python/Cargo.toml b/lib/bindings/python/Cargo.toml index 0c85134531..04246c1ab7 100644 --- a/lib/bindings/python/Cargo.toml +++ b/lib/bindings/python/Cargo.toml @@ -44,6 +44,7 @@ rand = { version = "0.9" } socket2 = { version = "0.6" } serde = { version = "1" } serde_json = { version = "1.0.138" } +smallvec = { version = "1.15.1", features = ["serde"] } thiserror = { version = "2.0" } tokio = { version = "1.46.0", features = ["full"] } tokio-stream = { version = "0" } diff --git a/lib/bindings/python/rust/llm/block_manager.rs b/lib/bindings/python/rust/llm/block_manager.rs index f19a17cd70..32a3800e7e 100644 --- a/lib/bindings/python/rust/llm/block_manager.rs +++ b/lib/bindings/python/rust/llm/block_manager.rs @@ -161,8 +161,12 @@ impl BlockManager { }) } - fn block_size(&self) -> usize { - self.inner.block_size() + fn engine_block_size(&self) -> usize { + self.inner.engine_block_size() + } + + fn offload_block_size(&self) -> usize { + self.inner.offload_block_size() } fn init_controller(&mut self, component: Component) -> PyResult<()> { @@ -214,14 +218,16 @@ impl BlockManager { pub struct BlockManagerBuilder { worker_id: u64, leader: Option, - page_size: usize, + offload_page_size: usize, + engine_page_size: usize, disable_device_pool: bool, } impl BlockManagerBuilder { pub fn new() -> Self { Self { - page_size: 32, // default consistent with BlockManager::new + engine_page_size: 32, // default consistent with BlockManager::new + offload_page_size: 1024, // default consistent with BlockManager::new ..Default::default() } } @@ -230,8 +236,12 @@ impl BlockManagerBuilder { self.worker_id = id; self } - pub fn page_size(mut self, ps: usize) -> Self { - self.page_size = ps; + pub fn engine_page_size(mut self, ps: usize) -> Self { + self.engine_page_size = ps; + self + } + pub fn offload_page_size(mut self, ps: usize) -> Self { + self.offload_page_size = ps; self } pub fn leader(mut self, l: distributed::KvbmLeader) -> Self { @@ -267,7 +277,7 @@ impl BlockManagerBuilder { let model_config = dynamo_llm::block_manager::KvManagerModelConfig::builder() .num_layers(1) .outer_dim(1) - .page_size(self.page_size) + .page_size(self.engine_page_size) .inner_dim(1) .build()?; diff --git a/lib/bindings/python/rust/llm/block_manager/distributed/worker.rs 
b/lib/bindings/python/rust/llm/block_manager/distributed/worker.rs index 1cf58185bf..b3c3e2940b 100644 --- a/lib/bindings/python/rust/llm/block_manager/distributed/worker.rs +++ b/lib/bindings/python/rust/llm/block_manager/distributed/worker.rs @@ -176,7 +176,8 @@ impl KvbmWorker { let config = KvbmWorkerConfig::builder() .drt(drt) .num_device_blocks(num_device_blocks) - .page_size(page_size) + .offload_page_size(page_size) + .engine_page_size(page_size) .tensors(vllm_tensors) .device_id(device_id) .dtype_width_bytes(dtype_width_bytes) diff --git a/lib/bindings/python/rust/llm/block_manager/vllm.rs b/lib/bindings/python/rust/llm/block_manager/vllm.rs index 524af640ae..9fd5118a6d 100644 --- a/lib/bindings/python/rust/llm/block_manager/vllm.rs +++ b/lib/bindings/python/rust/llm/block_manager/vllm.rs @@ -83,7 +83,10 @@ impl KvbmCacheManager { #[new] #[pyo3(signature = (block_manager))] pub fn new(block_manager: PyBlockManager) -> PyResult { - let slot_manager = Mutex::new(SlotManager::new(block_manager.block_size())); + let slot_manager = Mutex::new(SlotManager::new( + block_manager.engine_block_size(), + block_manager.offload_block_size(), + )); Ok(Self { block_manager, slot_manager, @@ -286,7 +289,7 @@ pub struct GenericSlotUpdate { pub num_new_tokens: usize, /// The number of new computed tokens in the request. - /// The `num_new_tokens / block_size` should be equal to the length of the `new_computed_blocks`, + /// The `num_new_tokens / engine_block_size` should be equal to the length of the `new_computed_blocks`, /// it may have a remainder for the partial block state. /// Note: this field is solely tied to the `new_computed_blocks` field and not used when `tokens_to_append` is provided. /// The name might be confusing, but the name matched the vLLM implementation. @@ -401,15 +404,17 @@ impl SlotError { pub struct SlotManager { slots: HashMap>>, - block_size: usize, + engine_block_size: usize, + offload_block_size: usize, } impl SlotManager { /// Creates a new slot manager. 
- pub fn new(block_size: usize) -> Self { + pub fn new(engine_block_size: usize, offload_block_size: usize) -> Self { Self { slots: HashMap::new(), - block_size, + engine_block_size, + offload_block_size, } } @@ -436,7 +441,7 @@ impl SlotManager { if !self.slots.contains_key(request_id) { self.slots.insert( request_id.clone(), - Slot::new(tokens.into(), self.block_size, salt_hash), + Slot::new(tokens.into(), self.engine_block_size, salt_hash), ); tracing::debug!( request_id, @@ -498,7 +503,7 @@ impl SlotManager { tracing::debug!( request_id, "applying {} cache-hit tokens", - blocks.len() * self.block_size + blocks.len() * self.engine_block_size ); slot.initialize_with_device_matches(blocks)?; } @@ -566,9 +571,9 @@ impl SlotManager { match self.slots.remove(request_id) { Some(slot) => { let isl = slot.num_tokens(SlotPosition::Prefill); - let isl_device = slot.num_blocks_cached_from_device() * self.block_size; - let isl_host = slot.num_blocks_cached_from_host() * self.block_size; - let isl_disk = slot.num_blocks_cached_from_disk() * self.block_size; + let isl_device = slot.num_blocks_cached_from_device() * self.engine_block_size; + let isl_host = slot.num_blocks_cached_from_host() * self.offload_block_size; + let isl_disk = slot.num_blocks_cached_from_disk() * self.offload_block_size; tracing::info!( request_id, "request complete isl: {} - cache hits: device: {}, host: {}, disk: {} - prefilled: {}", @@ -603,14 +608,14 @@ impl SlotManager { assert!(num_computed_tokens <= request_num_tokens); // early exit if we cannot match full block - if (request_num_tokens - num_computed_tokens) < self.block_size { + if (request_num_tokens - num_computed_tokens) < self.engine_block_size { return Ok((0, false)); } // num_computed_tokens represents the number of tokens already on the device // this much be a multiple of the block size - let num_device_blocks = num_computed_tokens / self.block_size; - debug_assert_eq!(num_computed_tokens % self.block_size, 0); + let num_device_blocks = num_computed_tokens / self.engine_block_size; + debug_assert_eq!(num_computed_tokens % self.engine_block_size, 0); // get the sequence hashes for the device matched tokens let sequence_hashes = slot.sequence_hashes(SlotPosition::All); @@ -661,7 +666,7 @@ impl SlotManager { return Ok((0, false)); } - let mut num_new_matched_tokens = num_matched_blocks * self.block_size; + let mut num_new_matched_tokens = num_matched_blocks * self.engine_block_size; // we are on a block boundary, so we need to throw away the last block if num_computed_tokens + num_new_matched_tokens == request_num_tokens { @@ -681,7 +686,7 @@ impl SlotManager { } // decrement the number of new matched tokens by the block size - num_new_matched_tokens -= self.block_size; + num_new_matched_tokens -= self.engine_block_size; } slot.store_onboard_blocks(host_blocks, disk_blocks); diff --git a/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader.rs b/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader.rs index 9d7f10c840..3318044c34 100644 --- a/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader.rs +++ b/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader.rs @@ -79,7 +79,8 @@ pub trait Leader: Send + Sync + std::fmt::Debug { #[derive(Debug)] pub struct KvConnectorLeader { slot_manager: Arc>>, - block_size: usize, + engine_page_size: usize, + offload_page_size: usize, inflight_requests: HashSet, onboarding_slots: HashSet, iteration_counter: u64, @@ -90,7 +91,8 @@ impl KvConnectorLeader { fn new( worker_id: String, drt: 
PyDistributedRuntime, - page_size: usize, + engine_page_size: usize, + offload_page_size: usize, leader_py: PyKvbmLeader, ) -> Self { tracing::info!( @@ -127,7 +129,8 @@ impl KvConnectorLeader { let block_manager = match BlockManagerBuilder::new() .worker_id(0) .leader(leader_py) - .page_size(page_size) + .engine_page_size(engine_page_size) + .offload_page_size(offload_page_size) .disable_device_pool(false) .build() .await @@ -169,7 +172,8 @@ impl KvConnectorLeader { Self { slot_manager: slot_manager_cell, - block_size: page_size, + engine_page_size, + offload_page_size, inflight_requests: HashSet::new(), onboarding_slots: HashSet::new(), iteration_counter: 0, @@ -204,7 +208,7 @@ impl Leader for KvConnectorLeader { ); // the number of device matched tokens should be less than or equal to the number of tokens in the request - debug_assert!(num_computed_tokens % self.block_size == 0); + debug_assert!(num_computed_tokens % self.engine_page_size == 0); let shared_slot = self.slot_manager().get_slot(&request_id)?; let mut slot = shared_slot @@ -234,7 +238,7 @@ impl Leader for KvConnectorLeader { } // early exit if we cannot match full block - if (slot.sequence().total_tokens() - num_computed_tokens) < self.block_size { + if (slot.sequence().total_tokens() - num_computed_tokens) < self.offload_page_size { return Ok((0, false)); } @@ -245,7 +249,9 @@ impl Leader for KvConnectorLeader { // return the number of external tokens that are ready for onboarding // we always return true here as we always asynchronously onboard matched blocks if let SlotState::OnboardStaged(num_external_tokens) = slot.state() { - debug_assert!((num_computed_tokens + num_external_tokens) % self.block_size == 0); + debug_assert!( + (num_computed_tokens + num_external_tokens) % self.offload_page_size == 0 + ); tracing::debug!( request_id = request_id, "scheduling onboarding for {} external tokens", @@ -289,7 +295,7 @@ impl Leader for KvConnectorLeader { // the second call will show num_external_tokens == 0 // this call is just letting us know the other blocks that are being used for the remainder of the prefill if num_external_tokens > 0 { - let num_computed_tokens = block_ids.len() * self.block_size - num_external_tokens; + let num_computed_tokens = block_ids.len() * self.engine_page_size - num_external_tokens; slot.record_cached_device_tokens(num_computed_tokens); slot.advance_computed_position(num_computed_tokens)?; @@ -549,11 +555,12 @@ pub struct PyKvConnectorLeader { #[pymethods] impl PyKvConnectorLeader { #[new] - #[pyo3(signature = (worker_id, drt, page_size, leader))] + #[pyo3(signature = (worker_id, drt, engine_page_size, offload_page_size, leader))] pub fn new( worker_id: String, drt: PyDistributedRuntime, - page_size: usize, + engine_page_size: usize, + offload_page_size: usize, leader: PyKvbmLeader, ) -> Self { let enable_kvbm_record = std::env::var("ENABLE_KVBM_RECORD") @@ -562,10 +569,20 @@ impl PyKvConnectorLeader { let connector_leader: Box = if enable_kvbm_record { Box::new(recorder::KvConnectorLeaderRecorder::new( - worker_id, drt, page_size, leader, + worker_id, + drt, + engine_page_size, + offload_page_size, + leader, )) } else { - Box::new(KvConnectorLeader::new(worker_id, drt, page_size, leader)) + Box::new(KvConnectorLeader::new( + worker_id, + drt, + engine_page_size, + offload_page_size, + leader, + )) }; Self { connector_leader } } diff --git a/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader/recorder.rs b/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader/recorder.rs 
index 9c267a2f95..c19b332cac 100644 --- a/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader/recorder.rs +++ b/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader/recorder.rs @@ -88,7 +88,8 @@ impl KvConnectorLeaderRecorder { pub fn new( worker_id: String, drt: PyDistributedRuntime, - page_size: usize, + engine_page_size: usize, + offload_page_size: usize, leader_py: PyKvbmLeader, ) -> Self { tracing::info!( @@ -143,7 +144,8 @@ impl KvConnectorLeaderRecorder { let block_manager = match BlockManagerBuilder::new() .worker_id(0) .leader(leader_py) - .page_size(page_size) + .engine_page_size(engine_page_size) + .offload_page_size(offload_page_size) .disable_device_pool(false) .build() .await @@ -185,7 +187,8 @@ impl KvConnectorLeaderRecorder { let connector_leader = KvConnectorLeader { slot_manager: slot_manager_cell, - block_size: page_size, + engine_page_size, + offload_page_size, inflight_requests: HashSet::new(), onboarding_slots: HashSet::new(), iteration_counter: 0, diff --git a/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader/slot.rs b/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader/slot.rs index 2397112267..d35b4b4a0a 100644 --- a/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader/slot.rs +++ b/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader/slot.rs @@ -13,6 +13,7 @@ use dynamo_llm::{ tokens::TokenBlock, }; use dynamo_runtime::utils::task::CriticalTaskExecutionHandle; +use smallvec::SmallVec; use tokio_util::sync::CancellationToken; use super::*; @@ -194,8 +195,9 @@ impl ConnectorSlotManager { kvbm_metrics: KvbmMetrics, ) -> Self { tracing::debug!( - "creating slot manager with block size: {}", - block_manager.block_size() + "creating slot manager with engine block size: {}, offload block size: {}", + block_manager.engine_block_size(), + block_manager.offload_block_size(), ); let (xfer_tx, xfer_rx) = mpsc::unbounded_channel(); @@ -317,7 +319,9 @@ pub struct VllmConnectorSlot { /// Phantom data to ensure the storage type is correct. 
block_manager: VllmBlockManager, - block_size: usize, + offload_block_size: usize, + engine_block_size: usize, + offload_block_size_ratio: usize, iteration_first_scheduled: Option, @@ -345,15 +349,20 @@ impl VllmConnectorSlot { xfer_tx: mpsc::UnboundedSender, ) -> Self { assert!(!tokens.is_empty(), "tokens must be non-empty"); - let block_size = block_manager.block_size(); - debug_assert!(block_size.is_power_of_two() && block_size <= 1024); - let sequence = TokenBlockSequence::new(tokens, block_size as u32, Some(salt_hash)); + let offload_block_size = block_manager.offload_block_size(); + let engine_block_size = block_manager.engine_block_size(); + let offload_block_size_ratio = block_manager.offload_block_size_ratio(); + debug_assert!(offload_block_size.is_power_of_two() && offload_block_size <= 1024); + debug_assert!(engine_block_size.is_power_of_two() && engine_block_size <= 1024); + let sequence = TokenBlockSequence::new(tokens, offload_block_size as u32, Some(salt_hash)); Self { request_id, sequence, block_manager, - block_size, + engine_block_size, + offload_block_size, + offload_block_size_ratio, xfer_tx, // default values state: SlotState::Initialized, @@ -501,12 +510,13 @@ impl Slot for VllmConnectorSlot { // we should have enough device blocks to cover the newly scheduled tokens let next_position = self.current_position + num_scheduled_tokens; + assert!( - next_position <= self.device_blocks.len() * self.block_size, - "next_position: {} > device_blocks.len() {} * block_size {}", + next_position <= self.device_blocks.len() * self.engine_block_size, + "next_position: {} > device_blocks.len() {} * engine_block_size {}", next_position, self.device_blocks.len(), - self.block_size + self.engine_block_size ); if next_position > self.sequence.total_tokens() { @@ -529,9 +539,9 @@ impl Slot for VllmConnectorSlot { // TODO(ryan) - apply policy let next_position = self.current_position + num_scheduled_tokens; - debug_assert!(next_position / self.block_size >= self.evaluated_blocks); + debug_assert!(next_position / self.engine_block_size >= self.evaluated_blocks); - let num_candidate_blocks = (next_position / self.block_size) - self.evaluated_blocks; + let num_candidate_blocks = (next_position / self.engine_block_size) - self.evaluated_blocks; tracing::debug!( "evaluating policy with the following parameters: state: {:?}; current_position: {}; num_candidate_blocks: {}; num_scheduled_tokens: {}", @@ -541,20 +551,22 @@ impl Slot for VllmConnectorSlot { num_scheduled_tokens ); - if num_candidate_blocks != 0 { + if num_candidate_blocks / self.offload_block_size_ratio != 0 { // do we have a mechanism for skipping gpu cache hit blocks? not sure yet. 
// for now, offload all the blocks to the host + let aligned_candidates = (num_candidate_blocks / self.offload_block_size_ratio) + * self.offload_block_size_ratio; let offload_block_ids: Vec = self .device_blocks .iter() .skip(self.evaluated_blocks) - .take(num_candidate_blocks) + .take(aligned_candidates) .copied() .collect::>(); assert_eq!( offload_block_ids.len(), - num_candidate_blocks, + aligned_candidates, "device block overflow - candidate blocks exceed block count at offset {}", self.evaluated_blocks ); @@ -563,15 +575,15 @@ impl Slot for VllmConnectorSlot { .sequence .blocks() .iter() - .skip(self.evaluated_blocks) - .take(num_candidate_blocks) + .skip(self.evaluated_blocks / self.offload_block_size_ratio) + .take(aligned_candidates / self.offload_block_size_ratio) .cloned() .collect::>(); self.offload_blocks(&offload_block_ids, &offload_token_blocks) .expect("failed to offload blocks"); - self.evaluated_blocks += num_candidate_blocks; + self.evaluated_blocks += aligned_candidates; } // done applying policy @@ -640,11 +652,11 @@ impl Slot for VllmConnectorSlot { is_new_request && computed_position > 0 && self.evaluated_blocks == 0; if maybe_have_device_matched_blocks { - self.evaluated_blocks = (computed_position + 1) / self.block_size; + self.evaluated_blocks = (computed_position + 1) / self.offload_block_size; } - let num_candidate_blocks = - ((computed_position + 1) / self.block_size).saturating_sub(self.evaluated_blocks); + let num_candidate_blocks = ((computed_position + 1) / self.offload_block_size) + .saturating_sub(self.evaluated_blocks); if num_candidate_blocks > 0 { // do we have a mechanism for skipping gpu cache hit blocks? not sure yet. @@ -745,7 +757,7 @@ impl Slot for VllmConnectorSlot { tracing::info!("slot is in the Preempted state; we get another chance to match"); } - let block_size = self.block_manager.block_size(); + let block_size = self.block_manager.offload_block_size(); let num_computed_blocks = num_computed_tokens / block_size; debug_assert!(num_computed_tokens % block_size == 0); @@ -863,28 +875,40 @@ impl Slot for VllmConnectorSlot { } debug_assert_eq!(self.evaluated_blocks, 0); - debug_assert_eq!(self.current_position % self.block_size, 0); - debug_assert_eq!(num_external_tokens % self.block_size, 0); + debug_assert_eq!(self.current_position % self.engine_block_size, 0); + debug_assert_eq!(num_external_tokens % self.engine_block_size, 0); - let num_computed_blocks = self.current_position / self.block_size; + let num_computed_blocks = self.current_position / self.offload_block_size; // shift the evaluated blocks position to the end of the computed/cached blocks - self.evaluated_blocks = num_computed_blocks; + self.evaluated_blocks = num_computed_blocks * self.offload_block_size_ratio; + + tracing::debug!( + "trigger_onboarding: self.device_blocks.len()={:?}", + self.device_blocks.len() + ); // match the host / disk blocks to the newly assigned mutable device blocks if let Some(host_blocks) = self.staging_from_host.take() { let num_host_blocks = host_blocks.len(); + tracing::debug!( + "trigger_onboarding: host_blocks.len()={:?}", + num_host_blocks + ); // get device block ids let dst_block_ids = self .device_blocks .iter() .skip(self.evaluated_blocks) - .take(num_host_blocks) + .take(num_host_blocks * self.offload_block_size_ratio) .copied() .collect::>(); - debug_assert_eq!(dst_block_ids.len(), num_host_blocks); + debug_assert_eq!( + dst_block_ids.len(), + num_host_blocks * self.offload_block_size_ratio + ); // construct offload requests - transfer 
engine + worker let src_blocks = Box::new(AnyImmutableBlocks::::new(host_blocks)); @@ -892,22 +916,30 @@ impl Slot for VllmConnectorSlot { self.onboard_blocks(src_blocks, dst_block_ids)?; // shift the evaluated blocks position to the end of the computed/cached blocks - self.evaluated_blocks += num_host_blocks; + self.evaluated_blocks += num_host_blocks * self.offload_block_size_ratio; } if let Some(disk_blocks) = self.staging_from_disk.take() { let num_disk_blocks = disk_blocks.len(); + tracing::debug!( + "trigger_onboarding: dist_blocks.len()={:?}", + num_disk_blocks + ); + // get device block ids let dst_block_ids = self .device_blocks .iter() - .skip(self.evaluated_blocks) + .skip(self.evaluated_blocks * self.offload_block_size_ratio) .take(num_disk_blocks) .copied() .collect::>(); - debug_assert_eq!(dst_block_ids.len(), num_disk_blocks); + debug_assert_eq!( + dst_block_ids.len(), + num_disk_blocks * self.offload_block_size_ratio + ); // construct offload requests - transfer engine + worker let src_blocks = Box::new(AnyImmutableBlocks::::new(disk_blocks)); @@ -915,7 +947,7 @@ impl Slot for VllmConnectorSlot { self.onboard_blocks(src_blocks, dst_block_ids)?; // shift the evaluated blocks position to the end of the computed/cached blocks - self.evaluated_blocks += num_disk_blocks; + self.evaluated_blocks += num_disk_blocks * self.offload_block_size_ratio; } self.state = SlotState::Onboarding(num_external_tokens); @@ -977,7 +1009,7 @@ impl VllmConnectorSlot { block_ids: &[BlockId], token_blocks: &[TokenBlock], ) -> Result<(), SlotError> { - assert!(block_ids.len() == token_blocks.len()); + assert!(block_ids.len() == token_blocks.len() * self.offload_block_size_ratio); let operation_id = uuid::Uuid::new_v4(); let xfer_req = LocalTransferRequest::Offload(LocalOffloadRequest::new( @@ -1007,7 +1039,7 @@ impl VllmConnectorSlot { tracing::debug!( request_id = self.request_id, operation_id = %operation_id, - "offloading {} blocks to host", + "offloading {} device blocks to host", block_ids.len() ); @@ -1166,6 +1198,7 @@ impl LocalTransferEngine { // Clone resources needed for tasks let block_manager_offload = self.block_manager.clone(); + let block_manager_onboard = self.block_manager.clone(); let leader_offload = Arc::clone(&self.leader); let leader_onboard = Arc::clone(&self.leader); @@ -1179,9 +1212,13 @@ impl LocalTransferEngine { tracing::debug!("LocalOnboardTask: received cancellation signal"); break; } - if let Err(e) = - process_onboard_request(req, &leader_onboard, kvbm_metrics_onboard.clone()) - .await + if let Err(e) = process_onboard_request( + req, + &block_manager_onboard, + &leader_onboard, + kvbm_metrics_onboard.clone(), + ) + .await { tracing::error!("LocalOnboardTask: error processing request: {:?}", e); } @@ -1285,24 +1322,30 @@ async fn process_offload_request( let request_id = &offload_req.request_id; let operation_id = &offload_req.operation_id; + let engine_blocks_per_offload_block = + block_manager.offload_block_size() / block_manager.engine_block_size(); tracing::debug!( - "Processing offload request for {} blocks", - offload_req.block_ids.len() + "Processing offload request for {} blocks, engine_blocks_per_offload_block = {} ({}/{})", + offload_req.block_ids.len(), + engine_blocks_per_offload_block, + block_manager.offload_block_size(), + block_manager.engine_block_size(), ); // 1. 
Acquire mutable host blocks let host_blocks = block_manager .host() .unwrap() - .allocate_blocks(offload_req.block_ids.len()) + .allocate_blocks(offload_req.block_ids.len() / engine_blocks_per_offload_block) .await?; let token_blocks = offload_req.token_blocks; let host_block_ids: Vec = host_blocks.iter().map(|b| b.block_id()).collect(); - let block_pairs: Vec<(usize, usize)> = offload_req + let block_pairs: Vec<_> = offload_req .block_ids - .into_iter() + .chunks(engine_blocks_per_offload_block) .zip(host_block_ids.into_iter()) + .map(|(src, dst)| (SmallVec::from(src), SmallVec::from([dst]))) .collect(); tracing::debug!( @@ -1386,6 +1429,7 @@ async fn process_offload_request( async fn process_onboard_request( onboard_req: LocalOnboardRequest, + block_manager: &VllmBlockManager, leader: &Arc, kvbm_metrics: KvbmMetrics, ) -> anyhow::Result<()> { @@ -1402,16 +1446,22 @@ async fn process_onboard_request( let request_id = &onboard_req.request_id; let operation_id = &onboard_req.operation_id; + let engine_blocks_per_offload_block = + block_manager.offload_block_size() / block_manager.engine_block_size(); // extract source block ids let src_block_ids = onboard_req.src_blocks.block_ids(); // create block pairs - let block_pairs = src_block_ids + let block_pairs: Vec<_> = src_block_ids .iter() - .zip(onboard_req.dst_block_ids.iter()) - .map(|(src, dst)| (*src, *dst)) - .collect::>(); + .zip( + onboard_req + .dst_block_ids + .chunks(engine_blocks_per_offload_block), + ) + .map(|(src, dst)| (SmallVec::from([*src]), SmallVec::from(dst))) + .collect(); // create transfer request let block_xfer_req = BlockTransferRequest { diff --git a/lib/bindings/python/rust/llm/block_manager/vllm/connector/trtllm_leader.rs b/lib/bindings/python/rust/llm/block_manager/vllm/connector/trtllm_leader.rs index 7f3b51d6dd..3f9803aecc 100644 --- a/lib/bindings/python/rust/llm/block_manager/vllm/connector/trtllm_leader.rs +++ b/lib/bindings/python/rust/llm/block_manager/vllm/connector/trtllm_leader.rs @@ -54,7 +54,9 @@ pub trait Leader: Send + Sync + std::fmt::Debug { #[derive(Debug)] pub struct KvConnectorLeader { slot_manager: Arc>>, - block_size: usize, + engine_page_size: usize, + #[allow(unused)] + offload_page_size: usize, inflight_requests: HashSet, onboarding_slots: HashSet, iteration_counter: u64, @@ -66,7 +68,8 @@ impl KvConnectorLeader { fn new( worker_id: u64, drt: PyDistributedRuntime, - page_size: usize, + engine_page_size: usize, + offload_page_size: usize, leader_py: PyKvbmLeader, ) -> Self { tracing::info!( @@ -103,7 +106,8 @@ impl KvConnectorLeader { let block_manager = match BlockManagerBuilder::new() .worker_id(0) .leader(leader_py) - .page_size(page_size) + .engine_page_size(engine_page_size) + .offload_page_size(offload_page_size) .disable_device_pool(false) .build() .await @@ -134,7 +138,8 @@ impl KvConnectorLeader { Self { slot_manager: slot_manager_cell, - block_size: page_size, + engine_page_size, + offload_page_size, inflight_requests: HashSet::new(), onboarding_slots: HashSet::new(), iteration_counter: 0, @@ -171,7 +176,7 @@ impl Leader for KvConnectorLeader { // TRTLLM could match partial blocks if enable_partial_reuse = True, // immediately return 0 to simplify things. 
- if num_computed_tokens % self.block_size != 0 { + if num_computed_tokens % self.engine_page_size != 0 { return Ok((0, false)); } @@ -181,7 +186,7 @@ impl Leader for KvConnectorLeader { .map_err(|e| anyhow::anyhow!("Failed to lock slot: {}", e))?; // early exit if we cannot match full block - if (slot.sequence().total_tokens() - num_computed_tokens) < self.block_size { + if (slot.sequence().total_tokens() - num_computed_tokens) < self.engine_page_size { let total_tokens = slot.sequence().total_tokens(); tracing::debug!( "total_tokens in sequence: {total_tokens}; num_computed_tokens: {num_computed_tokens}; can not match full block." @@ -196,7 +201,7 @@ impl Leader for KvConnectorLeader { // return the number of external tokens that are ready for onboarding // we always return true here as we always asynchronously onboard matched blocks if let SlotState::OnboardStaged(num_external_tokens) = slot.state() { - debug_assert!((num_computed_tokens + num_external_tokens) % self.block_size == 0); + debug_assert!((num_computed_tokens + num_external_tokens) % self.engine_page_size == 0); tracing::debug!( request_id = request_id, "scheduling onboarding for {} external tokens", @@ -447,15 +452,21 @@ pub struct PyTrtllmKvConnectorLeader { #[pymethods] impl PyTrtllmKvConnectorLeader { #[new] - #[pyo3(signature = (worker_id, drt, page_size, leader))] + #[pyo3(signature = (worker_id, drt, engine_page_size, offload_page_size, leader))] pub fn new( worker_id: u64, drt: PyDistributedRuntime, - page_size: usize, + engine_page_size: usize, + offload_page_size: usize, leader: PyKvbmLeader, ) -> Self { - let connector_leader: Box = - Box::new(KvConnectorLeader::new(worker_id, drt, page_size, leader)); + let connector_leader: Box = Box::new(KvConnectorLeader::new( + worker_id, + drt, + engine_page_size, + offload_page_size, + leader, + )); Self { connector_leader } } diff --git a/lib/bindings/python/rust/llm/block_manager/vllm/connector/trtllm_worker.rs b/lib/bindings/python/rust/llm/block_manager/vllm/connector/trtllm_worker.rs index 5b017db93a..0a7ad81033 100644 --- a/lib/bindings/python/rust/llm/block_manager/vllm/connector/trtllm_worker.rs +++ b/lib/bindings/python/rust/llm/block_manager/vllm/connector/trtllm_worker.rs @@ -131,7 +131,8 @@ impl Worker for KvConnectorWorker { let config = KvbmWorkerConfig::builder() .drt(self.drt.clone()) .num_device_blocks(num_device_blocks) - .page_size(page_size) + .offload_page_size(page_size) + .engine_page_size(page_size) .tensors(kv_cache_tensors) .device_id(device_id) .dtype_width_bytes(dtype_width_bytes) diff --git a/lib/bindings/python/rust/llm/block_manager/vllm/connector/worker.rs b/lib/bindings/python/rust/llm/block_manager/vllm/connector/worker.rs index 3c498a3d8e..60f5046329 100644 --- a/lib/bindings/python/rust/llm/block_manager/vllm/connector/worker.rs +++ b/lib/bindings/python/rust/llm/block_manager/vllm/connector/worker.rs @@ -73,6 +73,9 @@ pub struct KvConnectorWorker { /// cuda events created by the python side layer_events: Vec, + + /// Ratio between offload block size and engine block size + offload_block_size_ratio: usize, } impl KvConnectorWorker { @@ -111,6 +114,7 @@ impl KvConnectorWorker { layers_complete: 0, kv_cache_layers: Vec::new(), layer_events: Vec::new(), + offload_block_size_ratio: 32, // Default value }) } } @@ -196,7 +200,8 @@ impl Worker for KvConnectorWorker { let config = KvbmWorkerConfig::builder() .drt(self.drt.clone()) .num_device_blocks(num_device_blocks) - .page_size(page_size) + .offload_page_size(page_size * 
self.offload_block_size_ratio) + .engine_page_size(page_size) .tensors(vllm_tensors) .device_id(device_id) .dtype_width_bytes(dtype_width_bytes) diff --git a/lib/bindings/python/src/dynamo/llm/vllm_integration/connector_leader.py b/lib/bindings/python/src/dynamo/llm/vllm_integration/connector_leader.py index 8c69909c48..19546a3856 100644 --- a/lib/bindings/python/src/dynamo/llm/vllm_integration/connector_leader.py +++ b/lib/bindings/python/src/dynamo/llm/vllm_integration/connector_leader.py @@ -64,7 +64,7 @@ def __init__(self, vllm_config: "VllmConfig", engine_id: str, **kwargs): print(f"KvConnectorLeader initialized with engine_id: {engine_id}") self._connector = RustKvConnectorLeader( - engine_id, self.drt, vllm_config.cache_config.block_size, leader + engine_id, self.drt, vllm_config.cache_config.block_size, 1024, leader ) # KV Connector diff --git a/lib/llm/Cargo.toml b/lib/llm/Cargo.toml index d81cd42bf8..86e8a74a15 100644 --- a/lib/llm/Cargo.toml +++ b/lib/llm/Cargo.toml @@ -141,6 +141,7 @@ json-five = { version = "0.3" } zeromq = "0.4.1" rmp-serde = "1.3" ahash = "0.8.12" +smallvec = { version = "1.15.1", features = ["serde"] } [dev-dependencies] approx = "0.5" diff --git a/lib/llm/src/block_manager.rs b/lib/llm/src/block_manager.rs index 1878d68e30..d3c2f165cc 100644 --- a/lib/llm/src/block_manager.rs +++ b/lib/llm/src/block_manager.rs @@ -99,7 +99,9 @@ impl Drop for CancelOnLastDrop { pub struct KvBlockManager { state: Arc>, _cancellation_token: Arc, - block_size: usize, + offload_block_size: usize, + engine_block_size: usize, + offload_block_size_ratio: usize, } impl Clone @@ -109,15 +111,27 @@ impl Clone Self { state: self.state.clone(), _cancellation_token: self._cancellation_token.clone(), - block_size: self.block_size, + offload_block_size: self.offload_block_size, + engine_block_size: self.engine_block_size, + offload_block_size_ratio: self.offload_block_size_ratio, } } } impl KvBlockManager { /// Get the block size - pub fn block_size(&self) -> usize { - self.block_size + pub fn engine_block_size(&self) -> usize { + self.engine_block_size + } + + /// Get the block size + pub fn offload_block_size(&self) -> usize { + self.offload_block_size + } + + /// Get the offload block size ratio + pub fn offload_block_size_ratio(&self) -> usize { + self.offload_block_size_ratio } /// Get a reference to the disk block pool @@ -171,6 +185,7 @@ impl KvBlockManager { let _cancellation_token = build_cancel_token(&mut config); let block_size = config.model.page_size; + let offload_block_size_ratio = config.offload_block_size_ratio; // Create the internal state let state = state::KvBlockManagerState::::new(config).await?; @@ -178,7 +193,9 @@ impl KvBlockManager { Ok(Self { state, _cancellation_token, - block_size, + engine_block_size: block_size, + offload_block_size: block_size * offload_block_size_ratio, + offload_block_size_ratio, }) } @@ -215,6 +232,7 @@ impl KvBlockManager { impl KvBlockManager, Metadata> { pub async fn new(mut config: KvBlockManagerConfig, logical_resources: R) -> Result { let block_size = config.model.page_size; + let offload_block_size_ratio = config.offload_block_size_ratio; let _cancellation_token = build_cancel_token(&mut config); @@ -227,7 +245,9 @@ impl KvBlockManager: Send + Sync + 'static + std::fmt::Debug { /// The index of the block in the block set fn block_id(&self) -> BlockId; + /// The index of the block in the block set + fn fragment_block_id(&self, idx: usize) -> BlockId; + /// The identifier of the block set within the worker fn block_set_id(&self) -> 
usize; @@ -28,6 +31,12 @@ pub trait BlockDataExt: Send + Sync + 'static + std::fmt::Debug { /// Whether the block is fully contiguous fn is_fully_contiguous(&self) -> bool; + /// Is the block fragmented + fn is_fragmented(&self) -> bool; + + /// Returns the number of fragments + fn num_fragments(&self) -> usize; + /// Returns the number of layers in the block fn num_layers(&self) -> usize; @@ -57,6 +66,19 @@ pub trait BlockDataExt: Send + Sync + 'static + std::fmt::Debug { } } + /// Get a read-only view of this block's storage for a layer + fn layer_view_fragment( + &self, + fragment_idx: usize, + layer_idx: usize, + outer_idx: usize, + ) -> BlockResult> { + match self.is_local() { + Some(views) => views.local_layer_view_fragment(fragment_idx, layer_idx, outer_idx), + None => Err(BlockError::ViewsNotAvailableOnLogicalBlocks), + } + } + /// Get a mutable view of this block's storage for a layer fn layer_view_mut( &mut self, @@ -69,6 +91,19 @@ pub trait BlockDataExt: Send + Sync + 'static + std::fmt::Debug { } } + /// Get a mutable view of this block's storage for a layer + fn layer_view_fragment_mut( + &mut self, + fragment_idx: usize, + layer_idx: usize, + outer_idx: usize, + ) -> BlockResult> { + match self.is_local_mut() { + Some(views) => views.local_layer_view_fragment_mut(fragment_idx, layer_idx, outer_idx), + None => Err(BlockError::ViewsNotAvailableOnLogicalBlocks), + } + } + /// Get a read-only view of this block's storage fn block_view(&self) -> BlockResult> { match self.is_local() { @@ -94,6 +129,14 @@ pub trait BlockDataViews { outer_idx: usize, ) -> BlockResult>; + /// Get a read-only view of this block's fragment storage for a layer + fn local_layer_view_fragment( + &self, + fragment_idx: usize, + layer_idx: usize, + outer_idx: usize, + ) -> BlockResult>; + /// Get a mutable view of this block's storage for a layer fn local_layer_view_mut( &mut self, @@ -101,6 +144,14 @@ pub trait BlockDataViews { outer_idx: usize, ) -> BlockResult>; + /// Get a mutable view of this block's fragment storage for a layer + fn local_layer_view_fragment_mut( + &mut self, + fragment_idx: usize, + layer_idx: usize, + outer_idx: usize, + ) -> BlockResult>; + /// Get a read-only view of this block's storage fn local_block_view(&self) -> BlockResult>; diff --git a/lib/llm/src/block_manager/block/data/local.rs b/lib/llm/src/block_manager/block/data/local.rs index f1679f5eac..f88ec57d99 100644 --- a/lib/llm/src/block_manager/block/data/local.rs +++ b/lib/llm/src/block_manager/block/data/local.rs @@ -2,12 +2,15 @@ // SPDX-License-Identifier: Apache-2.0 use super::*; +use smallvec::SmallVec; /// Individual block storage #[derive(Debug)] pub struct LocalBlockData { layout: Arc>, - block_idx: usize, + // The special case of multiple block IDx will only happen for small KV cache blocks on device. + // offloaded blocks will be large and `block_idxs.len() == 1` will hold for them. 
+ block_idxs: SmallVec<[usize; 1]>, block_set_idx: usize, worker_id: WorkerID, } @@ -16,7 +19,34 @@ impl Clone for LocalBlockData { fn clone(&self) -> Self { Self { layout: self.layout.clone(), - block_idx: self.block_idx, + block_idxs: self.block_idxs.clone(), + block_set_idx: self.block_set_idx, + worker_id: self.worker_id, + } + } +} + +pub struct LocalBlockDataBase { + layout: Arc>, + block_set_idx: usize, + worker_id: WorkerID, +} + +impl Clone for LocalBlockDataBase { + fn clone(&self) -> Self { + Self { + layout: self.layout.clone(), + block_set_idx: self.block_set_idx, + worker_id: self.worker_id, + } + } +} + +impl LocalBlockDataBase { + pub(crate) fn get_data(&self, block_idxs: SmallVec<[usize; 1]>) -> LocalBlockData { + LocalBlockData { + layout: self.layout.clone(), + block_idxs, block_set_idx: self.block_set_idx, worker_id: self.worker_id, } @@ -36,11 +66,19 @@ where ) -> Self { Self { layout, - block_idx, + block_idxs: [block_idx].iter().map(|x| *x).collect(), block_set_idx, worker_id, } } + + pub(crate) fn base(&self) -> LocalBlockDataBase { + LocalBlockDataBase { + layout: self.layout.clone(), + block_set_idx: self.block_set_idx, + worker_id: self.worker_id.clone(), + } + } } impl BlockDataExt for LocalBlockData @@ -49,7 +87,22 @@ where { #[inline(always)] fn block_id(&self) -> BlockId { - self.block_idx + if self.block_idxs.len() == 1 { + self.block_idxs[0] + } else { + tracing::error!("Backtrace: {}", std::backtrace::Backtrace::force_capture()); + panic!("used LocalBlockData::block_id() for fragmented block"); + } + } + + #[inline(always)] + fn fragment_block_id(&self, idx: usize) -> BlockId { + if self.block_idxs.len() != 1 { + self.block_idxs[idx] + } else { + tracing::error!("Backtrace: {}", std::backtrace::Backtrace::force_capture()); + panic!("used LocalBlockData::fragment_block_id() for non--fragmented block"); + } } #[inline(always)] @@ -71,6 +124,14 @@ where self.layout.layout_type() == LayoutType::FullyContiguous } + fn is_fragmented(&self) -> bool { + self.block_idxs.len() != 1 + } + + fn num_fragments(&self) -> usize { + self.block_idxs.len() + } + fn num_layers(&self) -> usize { self.layout.num_layers() } @@ -104,7 +165,22 @@ impl BlockDataViews for LocalBlockData { ) -> BlockResult> { let mr = self .layout - .memory_region(self.block_idx, layer_idx, outer_idx)?; + .memory_region(self.block_id(), layer_idx, outer_idx)?; + let storage_type = mr.storage_type(); + unsafe { view::LayerView::new(self, mr.addr(), mr.size(), storage_type) } + } + + fn local_layer_view_fragment( + &self, + fragment_idx: usize, + layer_idx: usize, + outer_idx: usize, + ) -> BlockResult> { + let mr = self.layout.memory_region( + self.fragment_block_id(fragment_idx), + layer_idx, + outer_idx, + )?; let storage_type = mr.storage_type(); unsafe { view::LayerView::new(self, mr.addr(), mr.size(), storage_type) } } @@ -116,13 +192,27 @@ impl BlockDataViews for LocalBlockData { ) -> BlockResult> { let mr = self .layout - .memory_region(self.block_idx, layer_idx, outer_idx)?; + .memory_region(self.block_id(), layer_idx, outer_idx)?; + unsafe { view::LayerViewMut::new(self, mr.addr(), mr.size(), mr.storage_type()) } + } + + fn local_layer_view_fragment_mut( + &mut self, + fragment_idx: usize, + layer_idx: usize, + outer_idx: usize, + ) -> BlockResult> { + let mr = self.layout.memory_region( + self.fragment_block_id(fragment_idx), + layer_idx, + outer_idx, + )?; unsafe { view::LayerViewMut::new(self, mr.addr(), mr.size(), mr.storage_type()) } } fn local_block_view(&self) -> BlockResult> { if 
self.is_fully_contiguous() { - let mr = self.layout.memory_region(self.block_idx, 0, 0)?; + let mr = self.layout.memory_region(self.block_id(), 0, 0)?; let offset = mr.addr(); let size = mr.size() .checked_mul(self.num_layers()) @@ -144,7 +234,7 @@ impl BlockDataViews for LocalBlockData { fn local_block_view_mut(&mut self) -> BlockResult> { if self.is_fully_contiguous() { - let mr = self.layout.memory_region(self.block_idx, 0, 0)?; + let mr = self.layout.memory_region(self.block_id(), 0, 0)?; let offset = mr.addr(); let size = mr.size() .checked_mul(self.num_layers()) diff --git a/lib/llm/src/block_manager/block/data/logical.rs b/lib/llm/src/block_manager/block/data/logical.rs index 3aee97bc3e..81d9357953 100644 --- a/lib/llm/src/block_manager/block/data/logical.rs +++ b/lib/llm/src/block_manager/block/data/logical.rs @@ -5,13 +5,13 @@ use super::*; pub mod distributed_leader_worker; pub mod null; - use crate::block_manager::block::{ BlockDataProvider, ReadableBlock, WritableBlock, transfer::{TransferContext, TransferError, WriteToStrategy}, }; use crate::block_manager::locality::Logical; use crate::block_manager::storage::{self, nixl::NixlDescriptor}; +use smallvec::SmallVec; use tokio::sync::oneshot; pub enum LogicalKinds { @@ -37,7 +37,7 @@ pub trait LogicalResources: Clone + Send + Sync + 'static + std::fmt::Debug { /// Individual block storage - cannot be cloned to ensure uniqueness #[derive(Debug)] pub struct LogicalBlockData { - block_id: BlockId, + block_ids: SmallVec<[BlockId; 1]>, block_set_id: usize, worker_id: WorkerID, resources: Arc, @@ -56,7 +56,7 @@ impl LogicalBlockData { page_size: usize, ) -> Self { Self { - block_id, + block_ids: [block_id].iter().map(|x| *x).collect(), block_set_id, worker_id, resources, @@ -73,7 +73,21 @@ impl LogicalBlockData { impl BlockDataExt for LogicalBlockData { fn block_id(&self) -> BlockId { - self.block_id + if self.block_ids.len() == 1 { + self.block_ids[0] + } else { + panic!("used LocalBlockData::block_id() for fragmented block"); + } + } + + #[inline(always)] + fn fragment_block_id(&self, idx: usize) -> BlockId { + if self.block_ids.len() != 1 { + self.block_ids[idx] + } else { + tracing::error!("Backtrace: {}", std::backtrace::Backtrace::force_capture()); + panic!("used LocalBlockData::fragment_block_id() for non-fragmented block"); + } } fn block_set_id(&self) -> usize { @@ -92,6 +106,14 @@ impl BlockDataExt for LogicalBlockData unimplemented!() } + fn is_fragmented(&self) -> bool { + unimplemented!() + } + + fn num_fragments(&self) -> usize { + unimplemented!() + } + fn num_layers(&self) -> usize { unimplemented!() } diff --git a/lib/llm/src/block_manager/block/transfer/context.rs b/lib/llm/src/block_manager/block/transfer/context.rs index 36ad83a4c0..7327cf9615 100644 --- a/lib/llm/src/block_manager/block/transfer/context.rs +++ b/lib/llm/src/block_manager/block/transfer/context.rs @@ -161,6 +161,7 @@ pub struct PoolConfig { pub max_concurrent_transfers: usize, pub max_transfer_batch_size: usize, pub num_outer_components: usize, + pub offload_block_size_ratio: usize, pub num_layers: usize, } @@ -200,6 +201,7 @@ impl TransferContext { let buffer_size = max_blocks_per_transfer * config.num_outer_components * config.num_layers + * config.offload_block_size_ratio * std::mem::size_of::(); tracing::info!( diff --git a/lib/llm/src/block_manager/block/transfer/cuda.rs b/lib/llm/src/block_manager/block/transfer/cuda.rs index fdc345c2ff..c8bff9980a 100644 --- a/lib/llm/src/block_manager/block/transfer/cuda.rs +++ 
b/lib/llm/src/block_manager/block/transfer/cuda.rs @@ -75,6 +75,7 @@ fn collect_kv_addresses( destinations: &[Destination], num_layers: usize, num_outer_dims: usize, + num_fragments: usize, ) -> Result<(Vec, Vec), TransferError> where Source: BlockDataProvider, @@ -86,7 +87,7 @@ where )); } - let total_address_pairs = sources.len() * num_layers * num_outer_dims; + let total_address_pairs = sources.len() * num_layers * num_outer_dims * num_fragments; let mut src_addresses = Vec::with_capacity(total_address_pairs); let mut dst_addresses = Vec::with_capacity(total_address_pairs); @@ -99,12 +100,46 @@ where for (src_data, dst_data) in src_block_data.iter().zip(dst_block_data.iter()) { for layer_idx in 0..num_layers { for outer_idx in 0..num_outer_dims { - let src_view = src_data.layer_view(layer_idx, outer_idx)?; - let dst_view = dst_data.layer_view(layer_idx, outer_idx)?; + if src_data.is_fragmented() { + let dst_view = dst_data.layer_view(layer_idx, outer_idx)?; + let mut dst_ptr = unsafe { dst_view.as_ptr() }; + let n = src_data.num_fragments(); - unsafe { - src_addresses.push(src_view.as_ptr() as u64); - dst_addresses.push(dst_view.as_ptr() as u64); + for i in 0..n { + let src_view = src_data.layer_view_fragment(i, layer_idx, outer_idx)?; + debug_assert_eq!(src_view.size() * n, dst_view.size()); + + unsafe { + src_addresses.push(src_view.as_ptr() as u64); + dst_addresses.push(dst_view.as_ptr() as u64); + + dst_ptr = dst_ptr.add(src_view.size()); + } + } + } else if dst_data.is_fragmented() { + let src_view = src_data.layer_view(layer_idx, outer_idx)?; + let mut src_ptr = unsafe { src_view.as_ptr() }; + let n = dst_data.num_fragments(); + + for i in 0..n { + let dst_view = dst_data.layer_view_fragment(i, layer_idx, outer_idx)?; + debug_assert_eq!(src_view.size(), dst_view.size() * n); + + unsafe { + src_addresses.push(src_view.as_ptr() as u64); + dst_addresses.push(dst_view.as_ptr() as u64); + + src_ptr = src_ptr.add(dst_view.size()); + } + } + } else { + let src_view = src_data.layer_view(layer_idx, outer_idx)?; + let dst_view = dst_data.layer_view(layer_idx, outer_idx)?; + + unsafe { + src_addresses.push(src_view.as_ptr() as u64); + dst_addresses.push(dst_view.as_ptr() as u64); + } } } } @@ -183,31 +218,48 @@ struct CachedBlockDimensions { num_layers: usize, num_outer_dims: usize, layer_size: usize, + num_fragments: usize, } static BLOCK_DIMENSIONS_CACHE: OnceLock = OnceLock::new(); -fn get_cached_block_dimensions( - block: &T, +fn get_cached_block_dimensions( + src_block: &T, + dst_block: &S, ) -> Result { Ok(*BLOCK_DIMENSIONS_CACHE - .get_or_init(|| calculate_block_dimensions_from_layout(block).unwrap())) + .get_or_init(|| calculate_block_dimensions_from_layout(src_block, dst_block).unwrap())) } -fn calculate_block_dimensions_from_layout( - block: &T, +fn calculate_block_dimensions_from_layout( + src_block: &T, + dst_block: &S, ) -> Result { - let block_data = block.block_data(); + let src_block_data = src_block.block_data(); + let dst_block_data = dst_block.block_data(); // Get dimensions directly from layout (pre-computed values) - let num_layers = block_data.num_layers(); - let num_outer_dims = block_data.num_outer_dims(); - let layer_size = block_data.layer_view(0, 0).map(|v| v.size()).unwrap_or(0); + let num_layers = src_block_data.num_layers(); + let num_outer_dims = src_block_data.num_outer_dims(); + let num_fragments = src_block_data.num_fragments(); + let layer_size = if dst_block_data.is_fragmented() { + src_block_data + .layer_view(0, 0) + .map(|v| v.size() / 
dst_block_data.num_fragments()) + } else if src_block_data.is_fragmented() { + src_block_data + .layer_view_fragment(0, 0, 0) + .map(|v| v.size()) + } else { + src_block_data.layer_view(0, 0).map(|v| v.size()) + } + .unwrap_or(0); Ok(CachedBlockDimensions { num_layers, num_outer_dims, layer_size, + num_fragments, }) } @@ -223,18 +275,24 @@ where { let _context_guard = stream.context().bind_to_thread(); // Get cached dimensions (calculated once per program lifetime!) - let dims = get_cached_block_dimensions(&sources[0])?; + let dims = get_cached_block_dimensions(&sources[0], &destinations[0])?; // Use cached dimensions - let (src_addresses, dst_addresses) = - collect_kv_addresses(sources, destinations, dims.num_layers, dims.num_outer_dims)?; + let (src_addresses, dst_addresses) = collect_kv_addresses( + sources, + destinations, + dims.num_layers, + dims.num_outer_dims, + dims.num_fragments, + )?; tracing::debug!( - "Using vectorized_copy for {} blocks [{}L×{}O×{}B], {} address pairs", + "Using vectorized_copy for {} blocks [{}L×{}O×{}Bx{}F], {} address pairs", sources.len(), dims.num_layers, dims.num_outer_dims, dims.layer_size, + dims.num_fragments, src_addresses.len() ); @@ -343,18 +401,50 @@ where for layer_idx in layer_range { for outer_idx in 0..src_data.num_outer_dims() { - let src_view = src_data.layer_view(layer_idx, outer_idx)?; - let mut dst_view = dst_data.layer_view_mut(layer_idx, outer_idx)?; + if src_data.is_fragmented() { + let mut dst_view = dst_data.layer_view_mut(layer_idx, outer_idx)?; + let mut dst_ptr = unsafe { dst_view.as_mut_ptr() }; + let n = src_data.num_fragments(); - debug_assert_eq!(src_view.size(), dst_view.size()); + for i in 0..n { + let src_view = src_data.layer_view_fragment(i, layer_idx, outer_idx)?; + debug_assert_eq!(src_view.size() * n, dst_view.size()); - unsafe { - memcpy_fn( - src_view.as_ptr(), - dst_view.as_mut_ptr(), - src_view.size(), - stream, - )?; + unsafe { + memcpy_fn(src_view.as_ptr(), dst_ptr, src_view.size(), stream)?; + + dst_ptr = dst_ptr.add(src_view.size()); + } + } + } else if dst_data.is_fragmented() { + let src_view = src_data.layer_view(layer_idx, outer_idx)?; + let mut src_ptr = unsafe { src_view.as_ptr() }; + let n = dst_data.num_fragments(); + + for i in 0..n { + let mut dst_view = dst_data.layer_view_fragment_mut(i, layer_idx, outer_idx)?; + debug_assert_eq!(src_view.size(), dst_view.size() * n); + + unsafe { + memcpy_fn(src_ptr, dst_view.as_mut_ptr(), dst_view.size(), stream)?; + + src_ptr = src_ptr.add(dst_view.size()); + } + } + } else { + let src_view = src_data.layer_view(layer_idx, outer_idx)?; + let mut dst_view = dst_data.layer_view_mut(layer_idx, outer_idx)?; + + debug_assert_eq!(src_view.size(), dst_view.size()); + + unsafe { + memcpy_fn( + src_view.as_ptr(), + dst_view.as_mut_ptr(), + src_view.size(), + stream, + )?; + } } } } diff --git a/lib/llm/src/block_manager/config.rs b/lib/llm/src/block_manager/config.rs index d9aa5e70be..4c344d4079 100644 --- a/lib/llm/src/block_manager/config.rs +++ b/lib/llm/src/block_manager/config.rs @@ -194,6 +194,12 @@ pub struct KvBlockManagerConfig { /// Channel to reset the block manager to a specific cache level #[builder(default)] pub block_reset_channel: Option, + + /// Ratio between offload block size and engine block size + /// Offload blocks are larger to enable more efficient I/O operations + #[builder(default = "32")] + #[validate(range(min = 1))] + pub offload_block_size_ratio: usize, } impl KvBlockManagerConfig { diff --git 
a/lib/llm/src/block_manager/distributed/transfer.rs b/lib/llm/src/block_manager/distributed/transfer.rs index fb1c7f452c..d566304e3a 100644 --- a/lib/llm/src/block_manager/distributed/transfer.rs +++ b/lib/llm/src/block_manager/distributed/transfer.rs @@ -15,6 +15,7 @@ use crate::block_manager::{ block::{ Block, BlockDataProvider, BlockDataProviderMut, ReadableBlock, WritableBlock, data::local::LocalBlockData, + data::local::LocalBlockDataBase, locality, transfer::{TransferContext, WriteTo, WriteToStrategy}, }, @@ -28,7 +29,6 @@ use async_trait::async_trait; use std::{any::Any, sync::Arc}; type LocalBlock = Block; -type LocalBlockDataList = Vec>; /// A batching wrapper for connector transfers to prevent resource exhaustion. /// Splits large transfers into smaller batches that can be handled by the resource pools. @@ -86,9 +86,9 @@ impl ConnectorTransferBatcher { /// A handler for all block transfers. Wraps a group of [`BlockTransferPoolManager`]s. #[derive(Clone)] pub struct BlockTransferHandler { - device: Option>, - host: Option>, - disk: Option>, + device: Option>, + host: Option>, + disk: Option>, context: Arc, scheduler_client: Option, batcher: ConnectorTransferBatcher, @@ -105,38 +105,36 @@ impl BlockTransferHandler { // add worker-connector scheduler client here ) -> Result { Ok(Self { - device: Self::get_local_data(device_blocks), - host: Self::get_local_data(host_blocks), - disk: Self::get_local_data(disk_blocks), + device: Self::get_local_data_base(device_blocks), + host: Self::get_local_data_base(host_blocks), + disk: Self::get_local_data_base(disk_blocks), context, scheduler_client, batcher: ConnectorTransferBatcher::new(), }) } - fn get_local_data( + fn get_local_data_base( blocks: Option>>, - ) -> Option> { - blocks.map(|blocks| { - blocks - .into_iter() - .map(|b| { - let block_data = b.block_data() as &dyn Any; - - block_data - .downcast_ref::>() - .unwrap() - .clone() - }) - .collect() - }) + ) -> Option> { + let Some(vec) = blocks else { return None }; + if let Some(b) = vec.first() { + let block_data = b.block_data() as &dyn Any; + let block_data = block_data + .downcast_ref::>() + .unwrap() + .clone(); + Some(block_data.base()) + } else { + None + } } /// Initiate a transfer between two pools. async fn begin_transfer( &self, - source_pool_list: &Option>, - target_pool_list: &Option>, + source_pool_base: &Option>, + target_pool_base: &Option>, request: BlockTransferRequest, ) -> Result> where @@ -150,23 +148,23 @@ impl BlockTransferHandler { LocalBlockData: BlockDataProvider, LocalBlockData: BlockDataProviderMut, { - let Some(source_pool_list) = source_pool_list else { + let Some(source_pool_base) = source_pool_base else { return Err(anyhow::anyhow!("Source pool manager not initialized")); }; - let Some(target_pool_list) = target_pool_list else { + let Some(target_pool_base) = target_pool_base else { return Err(anyhow::anyhow!("Target pool manager not initialized")); }; // Extract the `from` and `to` indices from the request. - let source_idxs = request.blocks().iter().map(|(from, _)| *from); - let target_idxs = request.blocks().iter().map(|(_, to)| *to); + let source_idxs = request.blocks().iter().map(|(from, _)| from); + let target_idxs = request.blocks().iter().map(|(_, to)| to); // Get the blocks corresponding to the indices. 
let sources: Vec> = source_idxs - .map(|idx| source_pool_list[idx].clone()) + .map(|idx| source_pool_base.get_data(idx.clone())) .collect(); let mut targets: Vec> = target_idxs - .map(|idx| target_pool_list[idx].clone()) + .map(|idx| target_pool_base.get_data(idx.clone())) .collect(); // Perform the transfer, and return the notifying channel. diff --git a/lib/llm/src/block_manager/distributed/utils.rs b/lib/llm/src/block_manager/distributed/utils.rs index 5798a87fbb..573de49ed5 100644 --- a/lib/llm/src/block_manager/distributed/utils.rs +++ b/lib/llm/src/block_manager/distributed/utils.rs @@ -3,6 +3,7 @@ use derive_getters::Getters; use serde::{Deserialize, Serialize}; +use smallvec::SmallVec; use crate::block_manager::connector::protocol::LeaderTransferRequest; @@ -33,7 +34,7 @@ pub struct ConnectorRequestLeader { pub struct BlockTransferRequest { pub from_pool: BlockTransferPool, pub to_pool: BlockTransferPool, - pub blocks: Vec<(usize, usize)>, + pub blocks: Vec<(SmallVec<[usize; 1]>, SmallVec<[usize; 1]>)>, #[serde(skip_serializing_if = "Option::is_none")] pub connector_req: Option, @@ -49,7 +50,15 @@ impl BlockTransferRequest { Self { from_pool, to_pool, - blocks, + blocks: blocks + .into_iter() + .map(|(src, dst)| { + ( + [src].iter().map(|x| *x).collect(), + [dst].iter().map(|x| *x).collect(), + ) + }) + .collect(), connector_req: None, } } @@ -63,7 +72,15 @@ impl BlockTransferRequest { Self { from_pool, to_pool, - blocks, + blocks: blocks + .into_iter() + .map(|(src, dst)| { + ( + [src].iter().map(|x| *x).collect(), + [dst].iter().map(|x| *x).collect(), + ) + }) + .collect(), connector_req: Some(connector_req), } } diff --git a/lib/llm/src/block_manager/distributed/worker.rs b/lib/llm/src/block_manager/distributed/worker.rs index 578461a700..1da7e57f89 100644 --- a/lib/llm/src/block_manager/distributed/worker.rs +++ b/lib/llm/src/block_manager/distributed/worker.rs @@ -90,7 +90,10 @@ pub struct KvbmWorkerConfig { num_device_blocks: usize, #[builder(default = "32")] - page_size: usize, + engine_page_size: usize, + + #[builder(default = "32")] + offload_page_size: usize, #[builder(default = "Vec::new()")] tensors: Vec>, @@ -143,9 +146,10 @@ pub struct KvbmWorker { impl KvbmWorker { pub async fn new(config: KvbmWorkerConfig, layout_blocking: bool) -> anyhow::Result { tracing::info!( - "Initializing KvbmWorker with params: num_device_blocks={}, page_size={}, dtype_width_bytes={}", + "Initializing KvbmWorker with params: num_device_blocks={}, engine_page_size={}, offload_page_size={}, dtype_width_bytes={}", config.num_device_blocks, - config.page_size, + config.engine_page_size, + config.offload_page_size, config.dtype_width_bytes ); @@ -166,12 +170,13 @@ impl KvbmWorker { LayoutType::FullyContiguous => { let num_layers = shape[1]; let outer_dim = shape[2]; - let inner_dim = shape[3..].iter().product::() / config.page_size; + let inner_dim = shape[3..].iter().product::() / config.engine_page_size; tracing::info!( - "Inferred layout: num_layers={}, outer_dim={}, page_size={}, inner_dim={}", + "Inferred layout: num_layers={}, outer_dim={}, page_size={}/{}, inner_dim={}", num_layers, outer_dim, - config.page_size, + config.engine_page_size, + config.offload_page_size, inner_dim ); @@ -194,14 +199,15 @@ impl KvbmWorker { }; let num_layers = device_tensors.len(); - let inner_dim = shape[2..].iter().product::() / config.page_size; + let inner_dim = shape[2..].iter().product::() / config.engine_page_size; tracing::info!( - "Inferred layout: num_layers={}, outer_dim={}, outer_contiguous={}, 
page_size={}, inner_dim={}", + "Inferred layout: num_layers={}, outer_dim={}, outer_contiguous={}, page_size={}/{}, inner_dim={}", num_layers, outer_dim, outer_contiguous, - config.page_size, + config.engine_page_size, + config.offload_page_size, inner_dim ); @@ -209,14 +215,17 @@ impl KvbmWorker { } }; - let bytes_per_block = - num_layers * outer_dim * config.page_size * inner_dim * config.dtype_width_bytes; + let offload_bytes_per_block = num_layers + * outer_dim + * config.offload_page_size + * inner_dim + * config.dtype_width_bytes; let mut layout_builder_instance = LayoutConfigBuilder::default(); let layout_builder = layout_builder_instance .num_layers(num_layers) .outer_dim(outer_dim) - .page_size(config.page_size) + .page_size(config.engine_page_size) .inner_dim(inner_dim) .dtype_width_bytes(config.dtype_width_bytes); @@ -225,12 +234,12 @@ impl KvbmWorker { .build()? .create_layout(layout_type, device_tensors)?; - let layout_builder = layout_builder.clone(); + let layout_builder = layout_builder.page_size(config.offload_page_size).clone(); let (task, handler_rx) = if layout_blocking { Self::run_blocking_layout_initialization( config, - bytes_per_block, + offload_bytes_per_block, device_layout, layout_builder, layout_type, @@ -239,7 +248,7 @@ impl KvbmWorker { } else { Self::run_non_blocking_layout_initialization( config, - bytes_per_block, + offload_bytes_per_block, device_layout, layout_builder, layout_type, @@ -579,6 +588,7 @@ impl KvbmWorker { max_transfer_batch_size: MAX_TRANSFER_BATCH_SIZE, num_outer_components: device_layout.config().outer_dim, num_layers: device_layout.config().num_layers, + offload_block_size_ratio: config.offload_page_size / config.engine_page_size, }; let transfer_context = Arc::new(TransferContext::new( diff --git a/lib/llm/src/block_manager/offload.rs b/lib/llm/src/block_manager/offload.rs index 94e51751bb..445411f461 100644 --- a/lib/llm/src/block_manager/offload.rs +++ b/lib/llm/src/block_manager/offload.rs @@ -74,6 +74,7 @@ pub struct OffloadManagerConfig { pub metrics: Arc, pub cancellation_token: CancellationToken, pub model_config: KvManagerModelConfig, + pub offload_block_size_ratio: usize, } /// The offload manager handles all block transfers between different cache levels. @@ -131,6 +132,7 @@ impl max_transfer_batch_size: MAX_TRANSFER_BATCH_SIZE, num_outer_components: config.model_config.outer_dim, num_layers: config.model_config.num_layers, + offload_block_size_ratio: config.offload_block_size_ratio, }; // We want cuda offloads to happen in parallel with host onboards, so we need to use a different stream. 
@@ -765,6 +767,7 @@ mod tests { metrics: BlockManagerMetrics::new(&Arc::new(Registry::new()))?, cancellation_token: CancellationToken::new(), model_config: minimal_config, + offload_block_size_ratio: 1, }; let manager = OffloadManager::new( diff --git a/lib/llm/src/block_manager/state.rs b/lib/llm/src/block_manager/state.rs index 5cf62bff4d..1f63c92dff 100644 --- a/lib/llm/src/block_manager/state.rs +++ b/lib/llm/src/block_manager/state.rs @@ -104,6 +104,7 @@ impl { pub async fn new(config: KvBlockManagerConfig, logical_resources: R) -> Result> { let model_config = config.model.clone(); + let offload_block_size_ratio = config.offload_block_size_ratio; let mut resources = Resources::new(config)?; let block_data_factories = logical::LogicalBlockFactories::new(&mut resources, logical_resources)?; @@ -152,6 +153,7 @@ impl metrics: resources.metrics.clone(), cancellation_token: resources.cancellation_token.clone(), model_config, + offload_block_size_ratio, }; let offload_manager = OffloadManager::new( @@ -213,6 +215,7 @@ impl impl KvBlockManagerState { pub async fn new(config: KvBlockManagerConfig) -> Result> { let model_config = config.model.clone(); + let offload_block_size_ratio = config.offload_block_size_ratio; let mut resources = Resources::new(config)?; let block_data_factories = local::LocalBlockDataFactories::new(&mut resources)?; @@ -267,6 +270,7 @@ impl KvBlockManagerState { metrics: resources.metrics.clone(), cancellation_token: resources.cancellation_token.clone(), model_config, + offload_block_size_ratio, }; let offload_manager = OffloadManager::new( diff --git a/lib/llm/src/block_manager/state/resources.rs b/lib/llm/src/block_manager/state/resources.rs index 1a17228b41..391d66a3a0 100644 --- a/lib/llm/src/block_manager/state/resources.rs +++ b/lib/llm/src/block_manager/state/resources.rs @@ -89,7 +89,7 @@ impl Resources { layout_builder .num_layers(model.num_layers) .outer_dim(model.outer_dim) - .page_size(model.page_size) + .page_size(model.page_size * self.config.offload_block_size_ratio) .inner_dim(model.inner_dim) .dtype_width_bytes(model.dtype_width_bytes);
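
The changes above split what used to be a single page_size/block_size into two sizes: engine_block_size (the granularity the inference engine schedules and allocates device blocks in; BlockManagerBuilder defaults it to 32) and offload_block_size (the granularity used for host and disk offload; the builder defaults it to 1024, connector_leader.py passes 1024 explicitly, and KvBlockManager::new derives it as the engine size times KvBlockManagerConfig::offload_block_size_ratio, which defaults to 32). Below is a minimal, standalone sketch of that relationship; the type and field names are illustrative, not the crate's API.

/// Illustrative only: mirrors the relationship the diff establishes in
/// `KvBlockManager::new` (offload block = engine block * ratio).
#[derive(Debug, Clone, Copy)]
struct BlockSizes {
    engine_block_size: usize,
    offload_block_size: usize,
    offload_block_size_ratio: usize,
}

impl BlockSizes {
    fn new(engine_block_size: usize, offload_block_size_ratio: usize) -> Self {
        assert!(engine_block_size > 0 && offload_block_size_ratio >= 1);
        Self {
            engine_block_size,
            offload_block_size: engine_block_size * offload_block_size_ratio,
            offload_block_size_ratio,
        }
    }
}

fn main() {
    // Defaults implied by the diff: 32-token engine pages, ratio 32 -> 1024-token offload pages.
    let sizes = BlockSizes::new(32, 32);
    assert_eq!(sizes.offload_block_size, 1024);
    println!("{sizes:?}");
}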
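
SlotManager now tracks both sizes and uses them for cache-hit accounting when a request completes: device matches are counted in engine blocks, while host and disk matches are counted in the larger offload blocks. A small sketch of that arithmetic, with illustrative names, assuming the 32/1024 defaults above:

/// Sketch of the accounting in `SlotManager`: device hits in engine blocks,
/// host/disk hits in offload blocks. Not the crate's API.
fn cached_tokens(
    device_blocks: usize,
    host_blocks: usize,
    disk_blocks: usize,
    engine_block_size: usize,
    offload_block_size: usize,
) -> (usize, usize, usize) {
    (
        device_blocks * engine_block_size,
        host_blocks * offload_block_size,
        disk_blocks * offload_block_size,
    )
}

fn main() {
    // 16 device blocks, 2 host blocks and 1 disk block account for 512, 2048 and 1024 tokens.
    assert_eq!(cached_tokens(16, 2, 1, 32, 1024), (512, 2048, 1024));
}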
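
In VllmConnectorSlot the offload policy still counts candidate blocks in engine-block units, but only whole offload blocks are ever offloaded: the candidate count is rounded down to a multiple of offload_block_size_ratio, and because the token sequence is now chunked at offload_block_size, the matching token blocks are selected by dividing the engine-block offsets by the ratio. A sketch of the rounding, assuming a ratio of 32:

/// Sketch of the alignment applied before offloading device blocks. Counts are
/// in engine blocks; offloads happen only in whole offload blocks.
fn aligned_offload_candidates(candidate_engine_blocks: usize, ratio: usize) -> usize {
    (candidate_engine_blocks / ratio) * ratio
}

fn main() {
    let ratio = 32;
    // 70 newly computed engine blocks -> only 64 (two full offload blocks) are offloaded now;
    // the remaining 6 wait until enough further blocks are computed.
    assert_eq!(aligned_offload_candidates(70, ratio), 64);
    assert_eq!(aligned_offload_candidates(31, ratio), 0);
}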
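
process_offload_request turns one offload request into ratio-to-one block pairs: the device (engine) block IDs are chunked into groups of engine_blocks_per_offload_block = offload_block_size / engine_block_size and each group is paired with a single freshly allocated host block, which is why BlockTransferRequest::blocks now holds (SmallVec<[usize; 1]>, SmallVec<[usize; 1]>) pairs rather than (usize, usize). A dependency-free sketch of the pairing (plain Vec stands in for SmallVec; names are illustrative):

/// Device block IDs are chunked into groups of `ratio` and each group is
/// paired with one host block ID, as in `process_offload_request`.
fn offload_pairs(
    device_block_ids: &[usize],
    host_block_ids: &[usize],
    ratio: usize,
) -> Vec<(Vec<usize>, Vec<usize>)> {
    device_block_ids
        .chunks(ratio)
        .zip(host_block_ids.iter())
        .map(|(src, &dst)| (src.to_vec(), vec![dst]))
        .collect()
}

fn main() {
    // Two offload blocks' worth of device blocks (ratio = 4 for brevity).
    let pairs = offload_pairs(&[0, 1, 2, 3, 8, 9, 10, 11], &[40, 41], 4);
    assert_eq!(pairs[0], (vec![0, 1, 2, 3], vec![40]));
    assert_eq!(pairs[1], (vec![8, 9, 10, 11], vec![41]));
}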
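
Onboarding is the mirror image: in trigger_onboarding and process_onboard_request each staged host or disk block fans out to offload_block_size_ratio device blocks, so the destination device IDs are consumed in chunks of that size. A sketch of that direction, again with Vec standing in for SmallVec:

/// Each source offload block maps to `ratio` destination device blocks,
/// as in `process_onboard_request`.
fn onboard_pairs(
    src_offload_block_ids: &[usize],
    dst_device_block_ids: &[usize],
    ratio: usize,
) -> Vec<(Vec<usize>, Vec<usize>)> {
    src_offload_block_ids
        .iter()
        .zip(dst_device_block_ids.chunks(ratio))
        .map(|(&src, dst)| (vec![src], dst.to_vec()))
        .collect()
}

fn main() {
    let pairs = onboard_pairs(&[40, 41], &[0, 1, 2, 3, 8, 9, 10, 11], 4);
    assert_eq!(pairs[0], (vec![40], vec![0, 1, 2, 3]));
    assert_eq!(pairs[1], (vec![41], vec![8, 9, 10, 11]));
}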
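
To let a single transfer endpoint address several engine blocks, LocalBlockData (and LogicalBlockData) now carry block_idxs: SmallVec<[usize; 1]> instead of one index: block_id() is only valid when the block is unfragmented, fragment_block_id(i) only when it is fragmented, and layer_view_fragment/_mut expose per-fragment views; BlockTransferHandler correspondingly keeps one LocalBlockDataBase per pool and materializes LocalBlockData from whatever index set a request carries. The accessor discipline, modeled on its own in a few lines (illustrative type, no layout or worker fields):

/// Minimal model of the fragmented-vs-unfragmented accessor split added to
/// `LocalBlockData`/`LogicalBlockData`.
struct BlockIds(Vec<usize>);

impl BlockIds {
    fn is_fragmented(&self) -> bool {
        self.0.len() != 1
    }
    fn num_fragments(&self) -> usize {
        self.0.len()
    }
    /// Valid only for unfragmented blocks (one underlying engine block).
    fn block_id(&self) -> usize {
        assert!(!self.is_fragmented(), "block_id() called on a fragmented block");
        self.0[0]
    }
    /// Valid only for fragmented blocks (one engine block per fragment).
    fn fragment_block_id(&self, idx: usize) -> usize {
        assert!(self.is_fragmented(), "fragment_block_id() called on an unfragmented block");
        self.0[idx]
    }
}

fn main() {
    let host = BlockIds(vec![7]);
    let device = BlockIds(vec![0, 1, 2, 3]);
    assert_eq!(host.block_id(), 7);
    assert_eq!(device.num_fragments(), 4);
    assert_eq!(device.fragment_block_id(2), 2);
}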
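
collect_kv_addresses in transfer/cuda.rs expands a layer copy into one address pair per fragment whenever either side is fragmented, walking the contiguous side in fragment-sized strides from its base. The standalone sketch below uses plain u64 addresses and pushes the advanced offset for the contiguous side, which appears to be the intent of the dst_ptr/src_ptr bookkeeping in the fragmented branches (the diff advances those pointers but pushes the base view address on every iteration); treat this as an illustration of the intended interleaving, not the crate's code.

/// Fragment-aware address gathering: each source fragment is paired with a
/// fragment-sized stride into the contiguous destination layer.
fn gather_pairs(
    src_fragment_addrs: &[u64], // one address per source fragment
    dst_base_addr: u64,         // base of the contiguous destination layer
    fragment_size: u64,
) -> (Vec<u64>, Vec<u64>) {
    let mut src_addresses = Vec::with_capacity(src_fragment_addrs.len());
    let mut dst_addresses = Vec::with_capacity(src_fragment_addrs.len());
    let mut dst_ptr = dst_base_addr;
    for &src in src_fragment_addrs {
        src_addresses.push(src);
        dst_addresses.push(dst_ptr);
        dst_ptr += fragment_size; // advance one fragment per pair
    }
    (src_addresses, dst_addresses)
}

fn main() {
    // Four 2 KiB device fragments copied into one 8 KiB host layer.
    let (src, dst) = gather_pairs(&[0x1000, 0x9000, 0x3000, 0x7000], 0x10_0000, 0x800);
    assert_eq!(dst, vec![0x10_0000, 0x10_0800, 0x10_1000, 0x10_1800]);
    assert_eq!(src.len(), dst.len());
}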