
Patch NixlConnector for kvcached PD disaggregation (closes #302)#313

Open
AAbouzeid wants to merge 3 commits into ovg-project:main from AAbouzeid:fix/pd-disagg-nixl-connector-minimal

Conversation


@AAbouzeid AAbouzeid commented Apr 21, 2026

TL;DR

vLLM's NixlConnector has two incompatibilities with kvcached under prefill/decode disaggregation:

  1. Layout: NixlConnector.get_required_kvcache_layout returns "HND", which triggers set_stride() on the KV tensor. kvcached's torch.from_blob tensors on a CUDA VMM region don't support set_stride, so engine init crashes.
  2. Assertion: NixlConnectorWorker.register_kv_caches asserts tensor.shape[blocks_dim] == self.num_blocks, but kvcached over-allocates virtual capacity, so the tensor's blocks dimension (_num_blocks_per_layer) exceeds vLLM's physical budget (self.num_blocks) and registration fails.

The Fix

Both patches live in kvcached/integration/vllm/autopatch.py:_patch_nixl_connector:

  • Override get_required_kvcache_layout to return None (use NHD). Eagerly applied because it is a classmethod consulted during create_engine_config, before deferred module patches fire.
  • Wrap register_kv_caches to set self.num_blocks equal to _num_blocks_per_layer before the original runs (see the sketch after this list). Only active when ENABLE_KVCACHED=true; passthrough otherwise.
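
A minimal sketch of what _patch_nixl_connector amounts to, collapsed into one function for readability. In the actual patch the layout override is applied eagerly (before deferred module patches), and the vLLM import path, guard, and error handling may differ from what is assumed here:

```python
# Sketch only -- approximates the patch described above; the vLLM import path,
# the ENABLE_KVCACHED guard, and error handling are assumptions, not the
# verbatim contents of autopatch.py.
import os

from kvcached.integration.vllm import interfaces


def _patch_nixl_connector():
    from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
        NixlConnector, NixlConnectorWorker)

    # 1) Return None (default NHD layout) so vLLM never calls set_stride()
    #    on kvcached's torch.from_blob tensors.
    NixlConnector.get_required_kvcache_layout = classmethod(
        lambda cls, *args, **kwargs: None)

    # 2) Align the worker's block count with kvcached's per-layer budget
    #    before the original registration asserts on the tensor shape.
    orig_register = NixlConnectorWorker.register_kv_caches

    def register_kv_caches(self, kv_caches):
        if (os.environ.get("ENABLE_KVCACHED", "false").lower() == "true"
                and interfaces._num_blocks_per_layer is not None):
            self.num_blocks = interfaces._num_blocks_per_layer
        return orig_register(self, kv_caches)

    NixlConnectorWorker.register_kv_caches = register_kv_caches
```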

kvcached/integration/vllm/interfaces.py exposes _num_blocks_per_layer as a module global so the NixlConnector patch can read it without touching vLLM internals. The value is set inside alloc_kv_cache whenever kvcached sizes its per-layer block budget.
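
On the interfaces side, the change amounts to the two-line global write quoted in the review below; roughly as follows (all parameters of the real alloc_kv_cache other than num_blocks_per_layer are omitted in this sketch):

```python
# Sketch of the relevant part of kvcached/integration/vllm/interfaces.py.
from typing import Optional

# Module-level global read by the NixlConnector patch in autopatch.py.
_num_blocks_per_layer: Optional[int] = None


def alloc_kv_cache(num_blocks_per_layer: int, *args, **kwargs):
    """Record kvcached's per-layer block budget, then allocate as before."""
    global _num_blocks_per_layer
    _num_blocks_per_layer = num_blocks_per_layer
    # ... existing allocation logic, unchanged ...
```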

Validation

Validated on RunPod 2xA100 80GB, Qwen2.5-1.5B-Instruct, vLLM 0.19.0. Greedy outputs are byte-identical with and without kvcached for four fixed prompts.

Experiment Branch

The full equivalence harness, layout dumps, and SHA fingerprint checks (including experiments/09_equivalence.sh, experiments/06_test_fix1.sh, and kvcached/integration/vllm/debug_layout.py) can be found in PR #312.
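
For reference, a byte-identical check over greedy outputs can be as simple as the following standalone sketch (hypothetical file layout; the actual harness is experiments/09_equivalence.sh in #312):

```python
# Hypothetical equivalence check: hash greedy completions produced with and
# without kvcached and compare the digests.
import hashlib
import json
import sys


def fingerprint(path: str) -> str:
    """SHA-256 over the ordered list of completion strings in a JSON dump."""
    with open(path) as f:
        outputs = json.load(f)  # expected: list[str], one entry per prompt
    h = hashlib.sha256()
    for text in outputs:
        h.update(text.encode("utf-8"))
    return h.hexdigest()


if __name__ == "__main__":
    baseline, kvcached_run = sys.argv[1], sys.argv[2]
    a, b = fingerprint(baseline), fingerprint(kvcached_run)
    print(a, b)
    assert a == b, "greedy outputs diverge between baseline and kvcached"
```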

@cui36
Collaborator

cui36 commented Apr 25, 2026

/gemini review

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request implements monkey-patches for vLLM's NixlConnector to ensure compatibility with kvcached during PD disaggregation. The changes force an NHD layout and synchronize the physical block count in NixlConnectorWorker with kvcached's internal allocation. A review comment identifies that the current block count synchronization only handles over-allocation and suggests updating the logic to handle any mismatch to avoid assertion failures.

Comment thread on kvcached/integration/vllm/autopatch.py (outdated): "When GPU memory is not enough"
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@AAbouzeid
Author

/gemini review

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces patches for vLLM's NixlConnector to ensure compatibility with kvcached PD disaggregation. It addresses layout incompatibilities by forcing NHD and resolves block count assertion failures by dynamically updating the connector's block count using a new tracking variable. Review feedback suggests that using a module-level global variable for this state is brittle and may cause issues in multi-engine or multi-tenant scenarios, recommending a more explicit configuration management approach instead.

Comment on lines +211 to +212
global _num_blocks_per_layer
_num_blocks_per_layer = num_blocks_per_layer

Severity: medium

The use of a module-level global variable _num_blocks_per_layer to communicate the physical block count to the NixlConnector patch is a bit brittle. If alloc_kv_cache is called multiple times with different block counts (e.g., for different KV cache groups or in a multi-engine scenario within the same process), the global will only store the last value. While vLLM currently tends to have a uniform block count across layers, consider if this state should be managed more explicitly (e.g., as an attribute on a shared configuration object) to avoid potential issues in more complex multi-tenant or hybrid model scenarios.
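
A minimal sketch of the more explicit approach the comment suggests (KVCachedConfig is a hypothetical name, not part of this PR):

```python
# Hypothetical alternative to the module-level global: carry the per-layer
# block budget on an explicit, shared config object instead.
from dataclasses import dataclass
from typing import Optional


@dataclass
class KVCachedConfig:
    """Shared state the NixlConnector patch can read without module globals."""
    num_blocks_per_layer: Optional[int] = None


# One instance per engine (or per KV cache group) rather than one per process.
kvcached_config = KVCachedConfig()


def alloc_kv_cache(num_blocks_per_layer: int,
                   config: KVCachedConfig = kvcached_config):
    config.num_blocks_per_layer = num_blocks_per_layer
    # ... existing allocation logic ...
```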

