kvcached crashes during initialization when used with vLLM's PD (prefill-decode) disaggregation via NixlConnector. Two bugs were identified. A third concern (UCX + VMM memory compatibility for the actual transfer) appears to be addressed in recent UCX versions but remains untested.
PD disaggregation works correctly without kvcached using the same setup.
Related: #302
Environment
UCX transports: UCX_TLS=cuda_ipc,cuda_copy,tcp
Steps to reproduce
Start prefill and decode instances with kvcached + NixlConnector:
Prefill (GPU 0):
ENABLE_KVCACHED=true KVCACHED_AUTOPATCH=1 \
UCX_TLS=cuda_ipc,cuda_copy,tcp \
VLLM_NIXL_SIDE_CHANNEL_PORT=5600 \
CUDA_VISIBLE_DEVICES=0 vllm serve Qwen/Qwen2.5-1.5B-Instruct \
--host 0.0.0.0 --port 8100 --max-model-len 1024 --gpu-memory-utilization 0.8 \
--kv-transfer-config \
'{"kv_connector":"NixlConnector","kv_role":"kv_producer","kv_buffer_device":"cuda"}'
Decode (GPU 1):
ENABLE_KVCACHED=true KVCACHED_AUTOPATCH=1 \
UCX_TLS=cuda_ipc,cuda_copy,tcp \
VLLM_NIXL_SIDE_CHANNEL_PORT=5601 \
CUDA_VISIBLE_DEVICES=1 vllm serve Qwen/Qwen2.5-1.5B-Instruct \
--host 0.0.0.0 --port 8200 --max-model-len 1024 --gpu-memory-utilization 0.8 \
--kv-transfer-config \
'{"kv_connector":"NixlConnector","kv_role":"kv_consumer","kv_buffer_device":"cuda"}'
Proxy (port 9100): from vllm/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py
Test request:
curl http://localhost:9100/v1/completions \
-H "Content-Type: application/json" \
-d '{"model":"Qwen/Qwen2.5-1.5B-Instruct","prompt":"The capital of France is","max_tokens":30}'
Without kvcached: all four components work, responses are correct.
With kvcached: both prefill and decode crash during KV cache initialization before serving any requests.
Bug 1: Block count assertion mismatch
Error
AssertionError in nixl_connector.py:1667, register_kv_caches():
"All kv cache tensors must have the same number of blocks"
Crash path:
gpu_worker.py:536 → initialize_from_config()
nixl_connector.py:452 → register_kv_caches()
nixl_connector.py:1667 → assert cache.shape[0] == num_blocks ← FAILS
Root cause
NixlConnector sets self.num_blocks = kv_cache_config.num_blocks (line 1098) and
asserts that every KV cache tensor's shape[0] matches this value. kvcached's
alloc_kv_cache() (interfaces.py:159) computes num_blocks_per_layer from total GPU
memory and returns tensors with that larger shape (interfaces.py:171), so
cache.shape[0] (the ~187K-block virtual capacity) exceeds self.num_blocks (512).
Both sides are correct by design:
- kvcached over-provisions virtual blocks for elastic resize; physical pages are
mapped/unmapped on demand via CUDA VMM (cuMemAddressReserve + cuMemMap).
- NixlConnector expects the exact physical block count for RDMA descriptor
pre-registration.
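A toy repro of the failing check, with illustrative shapes taken from the numbers
above (the real tensors come from kvcached's alloc_kv_cache(); the meta device
keeps the example GPU-free):

```python
import torch

num_blocks = 512          # kv_cache_config.num_blocks: physically backed blocks
virtual_blocks = 187_000  # kvcached's virtual capacity (illustrative)

# Stand-in for one KV cache tensor; trailing dims are made up for the example.
cache = torch.empty(virtual_blocks, 2, 16, 2, 128, device="meta")

# The check at nixl_connector.py:1667 that fires during register_kv_caches():
assert cache.shape[0] == num_blocks, \
    "All kv cache tensors must have the same number of blocks"
```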
Why relaxing the assertion should be safe
NixlConnector pre-creates per-block transfer descriptors (block ID = array index).
During a transfer, only blocks assigned by the scheduler are copied - and the
scheduler only assigns blocks that kvcached has physically mapped. Descriptors for
unmapped blocks (backed by the shared zero page) exist but are never referenced in
transfers. Registering 187K descriptors instead of 512 wastes some memory (~252MB
for descriptor metadata) but is functionally safe.
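A quick back-of-the-envelope check of that overhead, using only the figures quoted
above (the per-descriptor size is implied by them, not measured):

```python
descriptors = 187_000  # one transfer descriptor per virtual block
overhead_mb = 252      # ~252MB of descriptor metadata quoted above
print(f"{overhead_mb * 1024 / descriptors:.2f} KB/descriptor")  # ~1.38 KB each
```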
Bug 2: KV cache layout incompatibility
This bug is hidden behind Bug 1 - it only appears after patching the assertion.
Error
RuntimeError: set_stride is not allowed on a Tensor created from .data or .detach().
Crash path:
gpu_worker.py:381 → determine_available_memory()
gpu_model_runner.py:5815 → _init_minimal_kv_cache_for_profiling()
gpu_model_runner.py:6718 → initialize_kv_cache_tensors()
→ set_stride CRASH
Root cause
NixlConnector forces HND (Head-Num-Dim) KV cache layout via get_required_kvcache_layout() returning "HND" for better transfer performance.
vLLM then tries to permute the KV cache tensors from NHD to HND, which calls set_stride().
kvcached creates its tensors via torch::from_blob() over CUDA virtual memory
(cuMemAddressReserve). PyTorch does not allow set_stride() on from_blob tensors,
whose data pointer is non-owning, and kvcached explicitly supports only the NHD
layout (interfaces.py:111-112).
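A minimal repro of the same class of PyTorch error, using .detach() as a stand-in
for kvcached's from_blob tensor (both yield tensors whose metadata PyTorch refuses
to change in place):

```python
import torch

buf = torch.zeros(16, 2, 16, 2, 128)  # stand-in for a KV cache tensor
alias = buf.detach()                  # non-owning alias: in-place metadata
                                      # changes are forbidden, as with from_blob

# In-place stride changes (what the NHD -> HND permutation needs) fail:
alias.transpose_(1, 2)
# RuntimeError: set_stride (exact method name varies by op/version) is not
# allowed on a Tensor created from .data or .detach()
```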
UCX + VMM memory compatibility
Even after fixing both bugs above, the actual GPU-to-GPU KV cache transfer has not
been tested with kvcached's VMM-allocated memory.
UCX's cuda_ipc transport historically used cuIpcGetMemHandle(), which does not
support cuMemCreate-allocated memory (see openucx/ucx#7110 for cuMemCreate and
cudaMallocAsync). However, recent UCX versions have added VMM support:
- cuda_ipc: a VMM handle type (UCT_CUDA_IPC_KEY_HANDLE_TYPE_VMM) using
cuMemExportToShareableHandle with fabric handles instead of legacy
cuIpcGetMemHandle (in UCX master, src/uct/cuda/cuda_ipc/cuda_ipc_md.c)
Whether the UCX version bundled with NIXL includes these changes is unknown. We have
not been able to get past the initialization crashes to attempt an actual transfer.
Other connectors tested
Other connectors handle the block count differently: at least one has the same
strict check (shape[0] == num_blocks), while another derives the count from the
tensor itself (num_block = kv_cache.shape[0]).
What we tried
We implemented a NixlConnectorPatch in kvcached's monkey-patch integration layer:
Bug 2 fix (working): Override get_required_kvcache_layout() to return None (use
the NHD default) when kvcached is active. Verified: the set_stride crash no
longer occurs.
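Conceptually the override is tiny. A sketch (module path and classmethod signature
follow vLLM's NixlConnector; the actual patch lives in kvcached's integration
layer, branch linked below):

```python
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
    NixlConnector,
)

@classmethod
def _nhd_only_layout(cls, vllm_config):
    # kvcached tensors are NHD-only; returning None keeps vLLM's default
    # layout instead of forcing HND, so set_stride() is never reached.
    return None

NixlConnector.get_required_kvcache_layout = _nhd_only_layout
```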
Bug 1 fix (in progress): Patch register_kv_caches() to adjust self.num_blocks to
match the tensor's shape[0]. The challenge is that the internal method
_sync_block_size_with_kernel() runs inside register_kv_caches() and can
overwrite the patched value. Still debugging the patch mechanism; a sketch of
the idea follows.
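One shape the patch can take, sketched under the assumption that the worker-side
class is NixlConnectorWorker and that register_kv_caches() receives the KV cache
tensors; the overwrite noted above is exactly what is still being debugged:

```python
from vllm.distributed.kv_transfer.kv_connector.v1 import nixl_connector

_orig_register = nixl_connector.NixlConnectorWorker.register_kv_caches

def _patched_register(self, kv_caches):
    # Pre-size the connector to the virtual capacity kvcached actually returns,
    # so the shape[0] == num_blocks assertion can pass ...
    self.num_blocks = next(iter(kv_caches.values())).shape[0]
    # ... but _sync_block_size_with_kernel() runs inside the original call and
    # can overwrite this value again, hence "in progress".
    return _orig_register(self, kv_caches)

nixl_connector.NixlConnectorWorker.register_kv_caches = _patched_register
```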
Branch: https://github.com/AAbouzeid/kvcached/tree/fix/pd-disagg-nixl-connector
Suggested fix directions
Recommended - relax the assertion in NixlConnector (vLLM side): Change line 1667
from assert cache.shape[0] == num_blocks to assert cache.shape[0] >= num_blocks,
and use cache.shape[0] for descriptor creation instead of
kv_cache_config.num_blocks. This is the smallest change and is safe because
transfers only touch scheduler-assigned (physically backed) blocks. A sketch
follows.
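Sketched as a fragment of register_kv_caches() (variable names taken from the
assertion text above, not from the actual vLLM source):

```python
# before (nixl_connector.py:1667):
#   assert cache.shape[0] == self.num_blocks, \
#       "All kv cache tensors must have the same number of blocks"

# after: tolerate a larger virtual capacity ...
assert cache.shape[0] >= self.num_blocks, \
    "KV cache tensors must have at least num_blocks blocks"
# ... and size the RDMA descriptor table from the tensor itself, so every
# virtual block gets a descriptor (unmapped ones are never referenced):
self.num_blocks = cache.shape[0]
```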
Alternative - fix the tensor shape in kvcached (interfaces.py:171): Return
tensor views with shape[0] matching vLLM's requested_num_blocks instead of the
full virtual capacity; the underlying virtual reservation remains large. Caveat:
elastic resize creates block IDs beyond requested_num_blocks, which would be out
of bounds on the smaller tensor. A sketch of the view follows.
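The view itself is cheap and needs no in-place metadata changes, so it is legal
even on kvcached's metadata-locked tensors (shapes illustrative; the meta device
keeps the example GPU-free):

```python
import torch

requested_num_blocks = 512  # what vLLM asked for
full = torch.empty(187_000, 2, 16, 2, 128, device="meta")  # full virtual capacity

# narrow() returns a fresh view; the underlying reservation stays fully sized.
trimmed = full.narrow(0, 0, requested_num_blocks)
print(trimmed.shape)  # torch.Size([512, 2, 16, 2, 128])
```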
Layout - add HND support to kvcached, or make it optional in NixlConnector:
Either support set_stride() on kvcached's VMM tensors, or make NixlConnector's
HND layout preference overridable when the KV cache allocator doesn't support it
(see the sketch below).
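For the second option, the escape hatch could be as small as this (the environment
variable is invented for illustration; vLLM has no such flag today):

```python
import os

class NixlConnectorSketch:  # stand-in for vLLM's NixlConnector
    @classmethod
    def get_required_kvcache_layout(cls, vllm_config):
        # Hypothetical opt-out so NHD-only allocators like kvcached work
        # without monkey-patching.
        if os.environ.get("VLLM_NIXL_LAYOUT_OVERRIDE") == "NHD":
            return None  # use vLLM's default layout
        return "HND"     # current behavior: force HND for transfer performance
```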
Happy to work on a fix and test it.