Fix/pr319 restore resize #323
Open
cui36 wants to merge 15 commits into
Conversation
- Bind PageAllocator::check_and_get_resize_target so kv_cache_manager can poll
- Pin KVCACHED_IPC_NAME from Python so the C++ MemInfoTracker uses the same shm
- Include pybind11/functional.h and pybind11/stl.h in torch_bindings
- Update test_kvcache_manager.py to use C++ public methods
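For orientation, a minimal sketch of how the Python side could consume these bindings, assuming the pybind11 module is importable as `kvcached._C` and that `check_and_get_resize_target` returns an optional page count (pybind11/stl.h converts `std::optional` to `None`/`int`). The default IPC name and all names below are illustrative, not the confirmed API:

```python
import os

# Pin the IPC name before any C++ code constructs MemInfoTracker, so the
# Python and C++ sides open the same shared-memory segment instead of each
# deriving its own default.
os.environ.setdefault("KVCACHED_IPC_NAME", "kvcached")  # name is illustrative

from kvcached import _C  # hypothetical pybind11 extension module


def poll_resize(allocator: "_C.PageAllocator") -> None:
    """Poll the C++ allocator once for a pending resize target."""
    target = allocator.check_and_get_resize_target()
    if target is not None:
        # kv_cache_manager would grow/shrink its block pool toward `target`
        # pages here; the actual reaction logic lives in the manager.
        pass
```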
Times alloc(k) + free(handles) cycles at varying k. Used to compare main vs C++ migration vs C++ + restored elastic resize.
Apply diff from 98d9bb3 -> 65a7d0a (lianghao208/kvcached:lianghao_c++):
- csrc/page_allocator.cpp: refactor free_page/free_pages/resize/trim to use scoped lock_guard blocks instead of manual lock/unlock, making the slow-path unmap exception-safe. Also drop the max_reserved_pages_ auto-expansion in alloc_page().
- kvcached/kv_cache_manager.py: cap available_size() by physical free pages (avail_physical + reserved) in addition to virtual free pages, so the capacity reported under memory pressure stays honest.

Functionally equivalent to the PR #319 head; the local override commits (restore-resize, bench scripts, overhead notes) sit on top.
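A Python-flavored sketch of both changes, since the pattern is easier to see without the C++ plumbing: a `with` block is Python's counterpart of `std::lock_guard`, and the attribute names on the manager below are illustrative rather than the actual kvcached fields:

```python
import threading

_lock = threading.Lock()


def free_pages(pages):
    # Scoped acquisition: the lock is released even if the slow-path unmap
    # raises, which manual lock()/unlock() pairs do not guarantee.
    with _lock:
        for page in pages:
            page.unmap()  # may raise; the lock is still released


def available_size(mgr) -> int:
    """Report capacity as the smaller of virtual and physical headroom."""
    virtual_free = mgr.num_free_blocks * mgr.block_mem_size
    # Physical headroom: pages the driver can still back, plus pages already
    # reserved but not yet handed out.
    physical_free = (mgr.avail_physical_pages + mgr.reserved_pages) * mgr.page_size
    # Under memory pressure, virtual_free can overstate what the GPU can
    # actually back; the physical cap keeps the reported number honest.
    return min(virtual_free, physical_free)
```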
…yout overhead

Restructure:
- bench_alloc/ now contains only the alloc/free microbench (bench_alloc.py) with a scoped bench_result.md (sections A, B, C, D, F).
- bench_layout/ (new) gathers the e2e vLLM sweep scripts (run_sweep.sh, run_kvcached_configs.sh, run_layout_retest.sh, parse_results.py) and the layout-overhead investigation: run_nsys_layout.sh, diff_nsys_kernels.py, run_ncu_attn.sh.
- Intermediate outputs (sweep_results/, sweep_logs/, nsys_runs/, nsys_logs/, ncu_runs/, ncu_logs/) are reproducible from the scripts and excluded via .git/info/exclude.

New evidence in bench_layout/bench_result.md G.3.2:
- An nsys per-kernel diff on the Section E workload (500 prompts, rate=inf) isolates the entire +8.2 s (+34.8%) GPU-time gap to one kernel: flash::flash_fwd_splitkv_kernel goes from 14,666 ms to 22,879 ms (+8,213 ms, +56%, 3,948 calls).
- The per-call attention slowdown scales with the working set: +37% at 100 prompts vs +56% at 500 prompts, matching the prediction from the per-block stride (1.75 MB under contiguous=true vs 64 KB under false; VMM page = 2 MB; see the arithmetic sketch below).
- The KV-write kernel (reshape_and_cache_flash_kernel) is unaffected, confirming the regression is in the multi-block-read path, not the write path.

An ncu memory-counter run (L2 hit rate / DRAM throughput on flash_fwd_splitkv_kernel) is gated on NVreg_RestrictProfilingToAdminUsers=0 and will be added once counter access is enabled.
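Back-of-envelope behind the stride figures above, assuming each block's per-layer KV slice is 64 KB and that the 1.75 MB stride comes from interleaving a 28-buffer layer group (28 is inferred here from 1.75 MB / 64 KB; the real model config may differ):

```python
BLOCK_BYTES = 64 * 1024        # one block's KV data for one layer (64 KB)
LAYER_GROUP = 28               # assumed: 1.75 MB / 64 KB = 28 buffers
VMM_PAGE = 2 * 1024 ** 2       # 2 MB driver page

# contiguous=false: a layer's blocks are packed back to back, so sequential
# block reads stride by 64 KB and ~32 blocks share each 2 MB page.
stride_false = BLOCK_BYTES

# contiguous=true: blocks of all layer buffers interleave, so consecutive
# blocks of one layer sit a full layer group apart.
stride_true = BLOCK_BYTES * LAYER_GROUP
assert stride_true == int(1.75 * 1024 ** 2)

# With the stride (1.75 MB) close to the page size (2 MB), nearly every
# block read under contiguous=true touches a distinct page -- consistent
# with the splitkv slowdown growing with the working set.
print(stride_false // 1024, "KB vs", stride_true / 1024 ** 2, "MB")
```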
Match the bench_map_parallelism / bench_tp_ipc convention (README.md with Quick start). Rewrite the content to be readable after the branch is merged:
- Replace internal labels (main/PR/fix) with descriptive variant names: Python allocator / C++ allocator / C++ + restored resize.
- Drop section-by-section detail that just restated the same speedup; keep the headline tables plus selected per-N detail where the pattern itself is informative (multi-thread shape, pool scaling with N).
- Reframe the "open question" as a "recommendation" so the takeaway about flipping the CONTIGUOUS_LAYOUT default reads as a concrete proposal.

Net diff: bench_alloc/README.md 121 -> 61 lines; bench_layout/README.md 168 -> 81 lines. Content is unchanged for the load-bearing measurements (headline e2e tables, stride math, nsys per-kernel attribution, scaling table).
- Replace the mixed-unit "headline numbers" table with five clearly scoped subsections, each with one-line context, a table with explicit units (per-call μs vs aggregate Kops/s) and direction (lower/higher is better), and a bold speedup line.
- Rewrite the summary to lead with what the C++ allocator delivers (5 bullets of concrete gains). The e2e-amortization note moves to a footer instead of the headline, so the alloc work is no longer framed primarily by its dilution.
Three measured advantages of the contiguous layout, all small relative to
the attention-read cost shown in Sections 1-2:
- 3.1 Hybrid linear / mamba: interfaces.py:138 forces the contiguous layout, so there is no choice.
- 3.2 Init time: alloc_kv_cache costs ~635 ms (contiguous) vs ~2055 ms
  (non-contiguous). The gap is a flat ~1.4 s across num_layers in {8..80},
  not linear in layer count. A breakdown via temporary timing prints in
  csrc/allocator.cpp and csrc/ftensor.cpp showed ~99% of the gap lives in
  FTensor::init_with_zero_(): contiguous uses a compound page of
  page_size * num_layers * num_kv_buffers and makes 1,947 cuMemMap calls
  at ~325 µs each (≈ 632 ms); non-contiguous uses 2 MB pages and makes
  62,304 cuMemMap calls at ~33 µs each (≈ 2,054 ms). Driver per-call
  overhead dominates, so fewer, larger calls win (cross-checked in the
  sketch after this list).
- 3.3 Alloc/free hot path: contiguous is ~2x faster on cold mapping
  (RESERVED=0) and ~2x faster at small k under the default reserve,
  collapsing to ~1x at k=256.
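The 3.2 call-count arithmetic cross-checks cleanly if one assumes NUM_LAYERS=16 and two KV buffers (K and V), matching the bench config quoted at the end of this PR; a small sketch:

```python
PAGE = 2 * 1024 ** 2                 # 2 MB VMM page
NUM_LAYERS, NUM_KV_BUFFERS = 16, 2   # assumed, per the bench config below

# Contiguous layout maps one compound page per call.
compound_page = PAGE * NUM_LAYERS * NUM_KV_BUFFERS      # 64 MB
calls_noncontig = 62_304                                # one call per 2 MB page
calls_contig = calls_noncontig // (NUM_LAYERS * NUM_KV_BUFFERS)
assert calls_contig == 1_947                            # matches the measurement

# Same total bytes mapped either way; only the call count differs.
assert calls_contig * compound_page == calls_noncontig * PAGE

# Per-call driver overhead dominates, so fewer, larger calls win:
print(f"{calls_contig * 325e-6:.2f} s vs {calls_noncontig * 33e-6:.2f} s")
# ~0.63 s vs ~2.06 s -> the flat ~1.4 s startup gap
```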
3.4 then frames the trade-off as startup vs steady-state with an explicit
break-even calculation: at the Section 1 workload, LAYOUT=true's 1.4 s
startup advantage is paid off after about 45 requests, so LAYOUT=false
strictly wins for any non-trivial serving workload. Concrete cases where
LAYOUT=true still pays off: smoke tests / single-shot inference (<~45
requests), frequent vLLM restarts, boot-time SLAs, and hybrid
linear / mamba models.
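Making the implied arithmetic explicit (the per-request penalty is back-derived here from the two numbers above, not independently measured):

$$
N_{\text{break-even}} = \frac{\Delta t_{\text{startup}}}{\Delta t_{\text{per-request}}}
\approx \frac{1.4\ \text{s}}{\sim 31\ \text{ms}} \approx 45\ \text{requests}
$$

so beyond ~45 requests, the slower attention reads under LAYOUT=true have consumed its startup saving.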
The recommendation section is tightened to acknowledge the startup cost
explicitly rather than dismissing it as something that "doesn't show up
at e2e".
PR #319 alloc/free microbench
bench_alloc.py: tight loop of `manager.alloc(k)` + `manager.free(h)` after a 100-iter warmup. GPU: NVIDIA GB10. NUM_LAYERS=16, NUM_BLOCKS=65536, BLOCK_SIZE=16, page_size=2 MB.
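A minimal sketch of that loop, assuming `KVCacheManager` is the class exported by kvcached/kv_cache_manager.py and that its constructor takes the config knobs named above (the actual bench_alloc.py and constructor signature may differ):

```python
import time

from kvcached.kv_cache_manager import KVCacheManager  # path per this PR

WARMUP, ITERS, K = 100, 1_000, 16

# Constructor arguments are assumptions mirroring the config line above.
manager = KVCacheManager(num_blocks=65_536, block_size=16, num_layers=16)


def cycle(k: int) -> None:
    handles = manager.alloc(k)   # allocate k blocks
    manager.free(handles)        # return them immediately


for _ in range(WARMUP):          # 100-iteration warmup, as in the bench
    cycle(K)

t0 = time.perf_counter()
for _ in range(ITERS):
    cycle(K)
elapsed = time.perf_counter() - t0
print(f"{elapsed / ITERS * 1e6:.1f} us per alloc+free pair at k={K}")
```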
Results (μs per alloc+free pair), across three variants:
- Python alloc + Python poll (incl. must-fix)
- 98d9bb3 (C++ alloc, no resize poll)
- (C++ alloc + restored resize poll)