Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 60 additions & 1 deletion kvcached/integration/vllm/autopatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from wrapt.importer import when_imported

from kvcached.integration.patch_base import PatchManager, log_patch_results
from kvcached.integration.patch_base import PatchManager, enable_kvcached, log_patch_results
from kvcached.integration.vllm.patches import (
VLLM_ALL_RANGE,
VLLM_V8_RANGE,
Expand Down Expand Up @@ -50,5 +50,64 @@ def _patch_vllm(_vllm: types.ModuleType) -> None:
# Apply all patches
results = patch_manager.apply_all_patches()

# Patch NixlConnector for kvcached compatibility (two bugs).
# Done eagerly here because get_required_kvcache_layout() is a classmethod
# called during create_engine_config(), before deferred module patches fire.
_patch_nixl_connector()

# Log results
log_patch_results("vllm", results)


def _patch_nixl_connector() -> None:
"""Patch NixlConnector for kvcached PD disaggregation compatibility.

Bug 1: NixlConnector forces HND layout, but kvcached's from_blob tensors
don't support set_stride (needed for NHD->HND permutation).
Fix: override get_required_kvcache_layout() to return None (use NHD).

Bug 2: NixlConnectorWorker.register_kv_caches asserts
tensor.shape[blocks_dim] == self.num_blocks. kvcached allocates a larger
virtual blocks dimension (_num_blocks_per_layer) than vLLM's physical
budget (self.num_blocks), so the assertion fires.
Fix: rewrite self.num_blocks to match _num_blocks_per_layer before the
original register_kv_caches runs.
"""
try:
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
NixlConnector,
NixlConnectorWorker,
)
except ImportError:
return # NIXL not installed

# Bug 1: force NHD layout
@classmethod # type: ignore[misc]
def _kvcached_layout(cls, *args, **kwargs):
if not enable_kvcached():
return NixlConnector._original_get_layout(*args, **kwargs)
logger.info("kvcached: NixlConnector layout overridden to NHD")
return None

NixlConnector._original_get_layout = NixlConnector.get_required_kvcache_layout
NixlConnector.get_required_kvcache_layout = _kvcached_layout

# Bug 2: relax block count assertion
_original_register = NixlConnectorWorker.register_kv_caches

def _patched_register(self, kv_caches, *args, **kwargs):
if not enable_kvcached():
return _original_register(self, kv_caches, *args, **kwargs)

from kvcached.integration.vllm.interfaces import _num_blocks_per_layer
if _num_blocks_per_layer > 0 and _num_blocks_per_layer != self.num_blocks:
logger.info(
"kvcached: NixlConnector num_blocks %d -> %d",
self.num_blocks, _num_blocks_per_layer,
)
self.num_blocks = _num_blocks_per_layer

return _original_register(self, kv_caches, *args, **kwargs)

NixlConnectorWorker.register_kv_caches = _patched_register
logger.info("Patched NixlConnector for kvcached PD disagg compatibility")
6 changes: 6 additions & 0 deletions kvcached/integration/vllm/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
_pp_rank: int = 0
_contiguous_layout: bool = CONTIGUOUS_LAYOUT
_is_worker: bool = False
# Per-layer block count most recently computed by alloc_kv_cache(): the
# page-aligned per-layer K-or-V byte budget divided by block_mem_bytes.
# Stays 0 until alloc_kv_cache() has run.
# Read by the NixlConnector patch in autopatch.py, which copies it onto the
# worker's num_blocks so the connector's
# shape[blocks_dim] == self.num_blocks assertion holds.
# NOTE(review): this is a single module-level value, so a later
# alloc_kv_cache() call overwrites it -- confirm every allocation in one
# process yields the same block count.
_num_blocks_per_layer: int = 0


def should_use_worker_ipc() -> bool:
Expand Down Expand Up @@ -204,6 +208,8 @@ def alloc_kv_cache(
gpu_mem_bytes_per_layer_k_or_v = (gpu_mem_bytes_per_layer_k_or_v // PAGE_SIZE) * PAGE_SIZE

num_blocks_per_layer = gpu_mem_bytes_per_layer_k_or_v // block_mem_bytes
global _num_blocks_per_layer
_num_blocks_per_layer = num_blocks_per_layer
Comment on lines +211 to +212
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The use of a module-level global variable _num_blocks_per_layer to communicate the physical block count to the NixlConnector patch is a bit brittle. If alloc_kv_cache is called multiple times with different block counts (e.g., for different KV cache groups or in a multi-engine scenario within the same process), the global will only store the last value. While vLLM currently tends to have a uniform block count across layers, consider if this state should be managed more explicitly (e.g., as an attribute on a shared configuration object) to avoid potential issues in more complex multi-tenant or hybrid model scenarios.

if requested_num_blocks > num_blocks_per_layer:
logger.warning(
f"Requested {requested_num_blocks} blocks, but only {num_blocks_per_layer} blocks are available."
Expand Down
Loading