
Proposal: Integrate kvcached with Ollama #326

@ztang2370


Background

kvcached provides elastic GPU KV cache management by decoupling virtual address reservation from physical memory commitment, using CUDA's Virtual Memory Management API (cuMemAddressReserve, cuMemCreate, cuMemMap). Today it integrates with vLLM and SGLang, where it has demonstrated significant VRAM efficiency gains for multi-LLM serving on shared GPUs.
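
For context, the underlying VMM flow looks roughly like this (a minimal standalone sketch, not kvcached code; assumes an initialized CUDA driver context, error handling elided):

```c
// Minimal sketch of the CUDA VMM flow kvcached builds on (not kvcached code).
// Assumes cuInit()/cuCtxCreate() have already run; error handling elided.
#include <cuda.h>

CUdeviceptr reserve_then_commit_one_page(int device, size_t va_bytes) {
    CUmemAllocationProp prop = {0};
    prop.type          = CU_MEM_ALLOCATION_TYPE_PINNED;
    prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    prop.location.id   = device;

    size_t gran = 0;   // physical page granularity, typically 2 MiB
    cuMemGetAllocationGranularity(&gran, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM);

    CUdeviceptr va = 0;  // reserve VA only: consumes no physical VRAM yet
    cuMemAddressReserve(&va, va_bytes, gran, 0, 0);

    CUmemGenericAllocationHandle handle;  // commit one physical page on demand
    cuMemCreate(&handle, gran, &prop, 0);
    cuMemMap(va, gran, /*offset=*/0, handle, 0);

    CUmemAccessDesc acc = {0};  // mapped pages need access flags before use
    acc.location = prop.location;
    acc.flags    = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
    cuMemSetAccess(va, gran, &acc, 1);

    return va;  // kernels can now touch [va, va + gran); the rest stays virtual
}
```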

Ollama currently uses static, worst-case KV cache reservation. When a model loads, the runner pre-allocates K and V buffers sized for n_layers × n_kv_heads × head_dim × NumCtx × NumParallel elements each, regardless of actual usage. This memory is held via cudaMalloc and is invisible to other processes on the GPU even when the model is idle.
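
For scale, with illustrative numbers for a Llama-3-8B-class model (32 layers, 8 KV heads, head_dim 128, fp16; assumptions for arithmetic only, not Ollama defaults):

```
per-token KV  = 32 layers × 8 heads × 128 dim × 2 (K,V) × 2 B  = 128 KiB
per-slot KV   = 128 KiB × 8192 (NumCtx)                        = 1 GiB
reservation   = 1 GiB × 4 (NumParallel)                        = 4 GiB
```

All of it is held from load time, whether or not a request ever arrives.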

This proposal integrates kvcached with Ollama's new ollamarunner (the native ggml engine) so KV cache is allocated lazily via CUDA VMM. Idle models hold less physical VRAM; multiple Ollama instances on the same GPU can coordinate through kvcached's existing cooperative-shrink protocol — the same mechanism used in production with vLLM/SGLang today.

Motivation

Two concrete user-visible improvements:

1. Idle models free up VRAM for new loads.
Today: on a 24 GB GPU, instance A loads Llama-3-8B (reserving ~14 GB statically). A user starts instance B for Qwen-2-7B; B sees only 10 GB free via cudaMemGetInfo and refuses to load.

With kvcached: instance A reserves 14 GB of virtual address space but only commits ~6 GB of physical memory while idle. Instance B sees ~18 GB free, loads successfully.

2. Higher NumParallel without OOM.
Today: doubling NumParallel doubles upfront KV reservation, often pushing past available VRAM. With kvcached: KV is committed only as concurrent requests actually fill it; peak commitment scales with workload, not configuration.

How Ollama allocates today

Relevant code paths for context:

The integration point is one level below ml.Backend/ml.Context — at ggml's ggml_backend_buffer_type_i interface (ggml-backend-impl.h:17-33). This is a clean BYO-allocator interface (~7 function pointers) with existing precedent (ggml_backend_cuda_buffer_type in ggml-cuda.cu:821-838).
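
For orientation, that interface has roughly the following shape (paraphrased; member names and count vary across ggml versions, so the vendored header is authoritative):

```c
// Approximate shape of ggml_backend_buffer_type_i, paraphrased from
// ggml-backend-impl.h. The vendored header in Ollama's fork is authoritative.
struct ggml_backend_buffer_type_i {
    const char *          (*get_name)      (ggml_backend_buffer_type_t buft);
    ggml_backend_buffer_t (*alloc_buffer)  (ggml_backend_buffer_type_t buft, size_t size);
    size_t                (*get_alignment) (ggml_backend_buffer_type_t buft);
    size_t                (*get_max_size)  (ggml_backend_buffer_type_t buft);
    size_t                (*get_alloc_size)(ggml_backend_buffer_type_t buft,
                                            const struct ggml_tensor * tensor);
    bool                  (*is_host)       (ggml_backend_buffer_type_t buft);
};
```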

Architecture

┌──────────────────────────────────────────────────────────────┐
│  Ollama runner (Go)                                          │
│  - kvcache/causal.go: pre-Compute commit hooks               │
│  - ml/backend/ggml/ggml.go: route KV layers to kvcached buft │
└──────────────────────────────────────────────────────────────┘
                              │ CGo
                              ▼
┌──────────────────────────────────────────────────────────────┐
│  ggml backend (C/CUDA, in Ollama's vendored ggml fork)       │
│  - new file: ggml-kvcached.cu                                │
│    implements ggml_backend_buffer_type_i + buffer_i          │
└──────────────────────────────────────────────────────────────┘
                              │ C ABI
                              ▼
┌──────────────────────────────────────────────────────────────┐
│  kvcached C ABI (new in kvcached repo)                       │
│  - csrc/capi.{h,cpp}                                         │
│  - libkvcached_capi.so (no Python/torch dependency)          │
│  - wraps existing FTensorAllocator/FTensor/Page              │
└──────────────────────────────────────────────────────────────┘
                              │
                              ▼
                    CUDA VMM driver API
            (cuMemAddressReserve, cuMemMap, cuMemUnmap)


           Cross-instance coordination (Phase 2):

  ┌────────────────┐                        ┌────────────────┐
  │  Ollama-1      │ ◄── shm (advisory) ──► │  Ollama-2      │
  │  kvcached embed│                        │  kvcached embed│
  └────────────────┘                        └────────────────┘
                              ▲
                              │ writes limits
                              │
                       kvctl (existing CLI)

Cross-instance coordination uses kvcached's existing protocol: each instance creates a POSIX shared-memory segment holding a MemInfoStruct. An external operator (or policy script) writes new limits via kvctl; instances poll their own shm and call resize() to unmap pages, returning physical memory to the CUDA driver's free pool. This is the same mechanism used in production with vLLM/SGLang today — no daemon required.
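
A sketch of the instance-side polling loop, written against the C ABI proposed in Phase 1.1 below. The MemInfoStruct layout shown is hypothetical (the real definition lives in kvcached's sources); only the protocol shape matters here:

```c
// Sketch of the cooperative-shrink poll loop. HYPOTHETICAL MemInfoStruct
// layout; uses the kvc_* C ABI proposed in Phase 1.1. Error handling elided.
#include <fcntl.h>
#include <stdatomic.h>
#include <sys/mman.h>
#include <unistd.h>
#include "kvcached_capi.h"

typedef struct {                  // hypothetical layout, for illustration only
    _Atomic size_t limit_bytes;   // written externally, e.g. via kvctl
    _Atomic size_t used_bytes;    // published by this instance
} MemInfoStruct;

static void poll_and_shrink(const char *shm_name, kvc_pool_t *pool) {
    int fd = shm_open(shm_name, O_RDWR, 0666);
    MemInfoStruct *mi = mmap(NULL, sizeof *mi, PROT_READ | PROT_WRITE,
                             MAP_SHARED, fd, 0);
    for (;;) {
        size_t limit = atomic_load(&mi->limit_bytes);
        if (atomic_load(&mi->used_bytes) > limit) {
            kvc_pool_resize(pool, limit);  // unmaps pages; VRAM returns to the
        }                                  // CUDA driver's free pool
        usleep(100 * 1000);                // advisory protocol: poll ~10x/s
    }
}
```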

Phases

Phase 0 — De-risking spikes (~2 days, ~340 LOC throwaway)

| # | Spike | Pass criteria |
|---|---|---|
| 0.1 | cuMemAddressReserve + cuMemMap + trivial CUDA kernel write — verify on target driver | kernel runs, no CUDA_ERROR_INVALID_ADDRESS |
| 0.2 | Same flow callable from Go via CGo | go test passes |
| 0.3 | Build kvcached's csrc/ standalone, without pybind11/torch | libkvcached_core.so builds |

0.4 — ggml third-party buft smoke test (~1 day, ~150 LOC throwaway)

Build a stand-in buffer type that is functionally identical to ggml's existing CUDA backend, register it as a separate buft, and verify Ollama produces bit-exact output when the KV cache layer is routed through it.

What it does:

  1. Copy the buffer-type and buffer interface code from ggml-cuda.cu:821-838 into a new file ggml-myalloc.cu. Rename all symbols (ggml_backend_myalloc_*).
  2. Internally still call cudaMalloc — same behavior as the standard CUDA buft. The point is not to do anything different; the point is to register a separately named buft and verify ggml routes through it correctly.
  3. Add a myallocBufferType in Ollama's ml/backend/ggml/ggml.go, route only the KV cache layer to it under an env flag (OLLAMA_GGML_MYALLOC=1).
  4. Run inference twice with the same prompt, seed, and low-temperature decoding:
    • Run A: flag off → tokens via standard CUDA buft
    • Run B: flag on → tokens via the new buft
  5. Compare token sequences and nvidia-smi snapshots.

Pass criteria: runs A and B produce bit-identical token output, and VRAM behavior matches.
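
A sketch of the core of ggml-myalloc.cu (callback names and the ggml_backend_buffer_init call are paraphrased; match the vendored ggml headers. myalloc_buffer_iface stands for the buffer_i callbacks copied and renamed in step 1):

```c
// ggml-myalloc.cu (sketch). Behaviorally identical to the stock CUDA buft,
// still cudaMalloc underneath, but registered under its own buft identity so
// per-layer routing through a third-party buft is tested in isolation.
static const char * ggml_backend_myalloc_buft_get_name(
        ggml_backend_buffer_type_t buft) {
    return "MYALLOC";   // distinct name: visible in scheduler debug output
}

static ggml_backend_buffer_t ggml_backend_myalloc_buft_alloc_buffer(
        ggml_backend_buffer_type_t buft, size_t size) {
    void * dev_ptr = NULL;
    if (cudaMalloc(&dev_ptr, size) != cudaSuccess) {
        return NULL;
    }
    // myalloc_buffer_iface: the buffer_i callbacks copied from
    // ggml-cuda.cu:821-838 in step 1, symbols renamed
    return ggml_backend_buffer_init(buft, myalloc_buffer_iface, dev_ptr, size);
}
```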

What this proves:

  • ggml's per-layer buft routing actually works for third-party bufts (not just internal ones)
  • The graph scheduler correctly handles tensors split across two different bufts (KV in our buft, weights/activations in the standard CUDA buft)
  • set_tensor/get_tensor/init_tensor callbacks are wired correctly
  • Alignment and size-padding behave correctly

What this doesn't prove (intentionally):

  • That kvcached's VMM-backed memory works as a tensor backing store (deferred to Phase 1.2)
  • That commit-before-kernel works — no demand-paging is tested here, since cudaMalloc'd memory is always physically present
  • That uncommit (page release) works

This is a deliberate decoupling: the spike tests the integration interface, not the integration content. If 0.4 passes, the only remaining risk in Phase 1.2 is "does VMM-backed memory specifically work" — a much narrower question than "does any of this work." If 0.4 fails, kvcached is irrelevant to the failure; ggml's buft interface itself has the problem, and we discover that on day 2 instead of week 2.

Decision gate. Any spike fails → reassess or stop.


Phase 1 — Single-instance integration (~1 week, ~1,400 LOC)

Goal: one Ollama runner uses kvcached for KV cache. Bit-equivalent inference output, measurable VRAM reduction with NumParallel > 1. Independently shippable.

1.1 — kvcached C ABI wrapper (~1-2 days, ~400 LOC)

New files in kvcached/: csrc/capi.{h,cpp}, build target for libkvcached_capi.so.

```c
// csrc/inc/kvcached_capi.h (excerpt)
typedef struct kvc_pool   kvc_pool_t;
typedef struct kvc_buffer kvc_buffer_t;

int   kvc_pool_create(size_t va_bytes, int device, kvc_pool_t** out);
int   kvc_pool_destroy(kvc_pool_t* pool);
int   kvc_pool_resize(kvc_pool_t* pool, size_t new_limit_bytes);

int   kvc_buffer_alloc(kvc_pool_t* pool, size_t size, kvc_buffer_t** out);
void* kvc_buffer_get_base(kvc_buffer_t* buf);
int   kvc_buffer_commit(kvc_buffer_t* buf, size_t offset, size_t size);
int   kvc_buffer_uncommit(kvc_buffer_t* buf, size_t offset, size_t size);
int   kvc_buffer_free(kvc_buffer_t* buf);

const char* kvc_last_error(void);  // thread-local
```

Wraps existing FTensorAllocator/FTensor/Page C++ classes in extern "C" opaque-handle functions. No CUDA logic changes. Errors via int return + thread-local string.

The existing pybind11 module (vmm_ops.cpython-*.so) keeps working for vLLM/SGLang; the new C-ABI .so is additive.

Acceptance: unit tests linking from C and from Go (CGo) cover create/alloc/commit/uncommit/destroy.
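
The round trip the unit tests would exercise, sketched against the proposed header above (not final test code):

```c
// Round trip through the proposed C ABI (names as in the excerpt above).
// Shape of the Phase 1.1 unit test, not final code.
#include <assert.h>
#include "kvcached_capi.h"

int main(void) {
    kvc_pool_t *pool = NULL;
    assert(kvc_pool_create(1ULL << 30, /*device=*/0, &pool) == 0); // 1 GiB VA

    kvc_buffer_t *buf = NULL;
    assert(kvc_buffer_alloc(pool, 64u << 20, &buf) == 0);  // 64 MiB, VA-only
    void *base = kvc_buffer_get_base(buf);                 // no pages mapped yet

    assert(kvc_buffer_commit(buf, 0, 2u << 20) == 0);      // map the first 2 MiB
    // ... cudaMemcpy into [base, base + 2 MiB), run a kernel, read back ...
    assert(kvc_buffer_uncommit(buf, 0, 2u << 20) == 0);    // return the page

    assert(kvc_buffer_free(buf) == 0);
    assert(kvc_pool_destroy(pool) == 0);
    (void)base;
    return 0;
}
```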

1.2 — ggml buffer type (~2 days, ~500 LOC)

New file in Ollama's vendored ggml: ml/backend/ggml/ggml/src/ggml-kvcached.cu. Implements ggml_backend_buffer_type_i and ggml_backend_buffer_i over the C ABI. Pattern after ggml_backend_cuda_buffer_type_alloc_buffer.

Key choices:

  • get_alignment → 2 MiB (kvcached's CUDA VMM page granularity)
  • noalloc_buffer → "VA-only reserve," matches kvcached's lazy-commit semantics
  • set_tensor/get_tensor/memset_tensor/clear auto-commit pages they touch (host↔device transfers, called rarely)
  • cpy_tensor returns false → ggml falls back to its own copy path

Acceptance: ggml unit test allocates a tensor via the kvcached buft, fills via set_tensor, runs mulmat, reads back via get_tensor. Output matches the same graph on the CUDA buft.
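
To make the auto-commit choice concrete, a sketch of the set_tensor callback (ggml signature paraphrased; assumes kvc_buffer_commit is idempotent for already-committed pages):

```c
// Sketch of the kvcached buft's set_tensor callback: commit the 2 MiB pages
// the incoming write touches, then copy. Signature paraphrased from ggml.
#define KVC_PAGE_SIZE (2ull << 20)   // kvcached's CUDA VMM page granularity

static void ggml_backend_kvcached_buffer_set_tensor(
        ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
        const void * data, size_t offset, size_t size) {
    kvc_buffer_t * buf  = (kvc_buffer_t *) buffer->context;
    char *         base = (char *) kvc_buffer_get_base(buf);
    size_t         off  = (size_t)((char *) tensor->data + offset - base);

    size_t first = off & ~(KVC_PAGE_SIZE - 1);                    // round down
    size_t last  = (off + size + KVC_PAGE_SIZE - 1)
                   & ~(KVC_PAGE_SIZE - 1);                        // round up
    kvc_buffer_commit(buf, first, last - first);

    cudaMemcpy(base + off, data, size, cudaMemcpyHostToDevice);   // host→device
}
```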

1.3 — Go-side wiring (~1 day, ~250 LOC)

Modify ml/backend/ggml/ggml.go (per-layer buft assignment, lines 158-230) and add ml/backend/ggml/kvcached.go. Add a kvcachedBufferType parallel to gpuDeviceBufferTypes. Routing rule:

```go
if envconfig.KVCachedEnabled() && layer.kind == "cache" {
    return kvcachedBufferType
}
return gpuDeviceBufferTypes[i]
```

Today, bufts are assigned per layer but not per tensor role, so the Layer(i) context needs to be split so that cache tensors and activations can use different bufts (ggml.go:787-800). Weights and activations stay on the standard CUDA buft; only KV cache layers route to kvcached.

Acceptance: with OLLAMA_KVCACHED=1, model loads, runs inference, output matches without the flag. nvidia-smi shows expected VRAM behavior.

1.4 — Cache lifecycle hooks (~1 day, ~150 LOC)

Modify kvcache/causal.go, kvcache/recurrent.go, runner/ollamarunner/cache.go. Insert commit/uncommit calls — analogous to vLLM's existing block-free hooks in kvcached/integration/vllm/patches.py:511-522:

| Event | Operation | Location |
|---|---|---|
| Pre-Compute in Put | commit pages for curLoc..curLoc+batchSize | causal.go:478 |
| Pre-Compute in Get/SDPA | commit pages for [0, cachedSize) | causal.go:430-446 |
| Remove(seq, begin, end) | uncommit pages fully covered by the removed range | causal.go Remove impl |
| Close() | free buffers | existing path |

Helper to translate (layer, token_range) → (buffer, byte_range).
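
The translation is pure arithmetic once the per-token stride is fixed; a sketch in C for brevity (the real helper lives on the Go side, and all names here are hypothetical):

```c
// Hypothetical sketch of the (layer, token_range) -> (buffer, byte_range)
// helper. Real identifiers would differ on the Go side.
typedef struct {
    size_t bytes_per_token;  // n_kv_heads * head_dim * element_size (per K or V)
    size_t page;             // 2 MiB commit granularity
} kv_layout_t;

// Byte range inside one layer's K (or V) buffer covering tokens [t0, t1),
// widened to whole 2 MiB pages so it can be passed to kvc_buffer_commit.
static void kv_token_range_to_pages(const kv_layout_t * kl,
                                    size_t t0, size_t t1,
                                    size_t * out_off, size_t * out_len) {
    size_t begin = t0 * kl->bytes_per_token;
    size_t end   = t1 * kl->bytes_per_token;
    *out_off = begin - (begin % kl->page);                      // round down
    *out_len = ((end + kl->page - 1) / kl->page) * kl->page - *out_off;
}
```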

Acceptance: end-to-end multi-turn conversation. Bit-equivalent generation. With NumParallel=4 and 4K context, peak resident VRAM measurably lower than baseline.

1.5 — Bench + smoke tests (~0.5 day, ~100 LOC tests)

  • Long-context single-turn (8K, 32K)
  • Multi-turn conversation (10 turns, growing context)
  • Concurrent requests with NumParallel=8
  • Compare wall-clock + peak VRAM against baseline ggml-cuda

Phase 1 LOC: ~1,400 implementation + ~100 tests.


Phase 2 — Cross-instance coordination via shm (~3-4 days, ~600 LOC)

Reuse kvcached's existing shm + cooperative shrink protocol — same mechanism deployed with vLLM/SGLang today. No daemon to build.

2.1 — MemInfoTracker analog in C ABI / runner (~2 days, ~350 LOC)

  • C ABI extension: add kvc_pool_attach_shm(pool, ipc_name) to wire kvcached's existing MemInfoTracker into the C-callable surface. The polling loop and resize() logic exist in Python today; port the minimum needed to C++ so it works without a Python interpreter (~250 LOC).
  • Go-side init: each Ollama runner creates a shm with a deterministic name (/dev/shm/kvcached_<pid>_<gpu>), passes it to kvc_pool_attach_shm at startup. Existing kvctl and kvtop work against this format unmodified (~100 LOC).

Acceptance: running kvctl limit <runner-shm> 2G on a live Ollama runner causes it to unmap pages and shrink to ~2 GB. Verified via nvidia-smi.

2.2 — Two-instance demo + bench (~1-2 days, ~250 LOC tooling/tests)

  • Two Ollamas on one GPU: Llama-3-8B + Qwen-2-7B
  • Workload: alternating bursts
  • Demo: when load shifts, run kvctl (or a small auto-shrink script) to rebalance — same operator pattern as multi-vLLM deployments
  • Measure peak concurrent active requests, p50/p99 latency, OOM count vs. static partitioning

Acceptance: measurable improvement on at least one of {throughput, peak concurrency, OOM count} versus today's two-Ollama static deployment.

Phase 2 LOC: ~350 implementation + ~250 tools/tests.


Phase 3 — Hardening (~2-3 days, ~400 LOC) [stretch]

| Step | Work | LOC |
|---|---|---|
| 3.1 | Multi-GPU support — per-device pool, per-device shm | ~150 |
| 3.2 | Recurrent + SWA cache types — verify commit hooks; add for non-causal patterns | ~150 |
| 3.3 | Docs, env vars, ops runbook | ~100 |

Total

| Phase | Calendar | New code | Tests/tools |
|---|---|---|---|
| 0 — Spikes | ~2 days | ~340 (throwaway) | — |
| 1 — Single-instance | ~1 week | ~1,400 | ~100 |
| 2 — Cross-instance shm | ~3-4 days | ~350 | ~250 |
| 3 — Hardening (stretch) | ~2-3 days | ~400 | — |
| Total | ~2-2.5 weeks | ~2,150 | ~350 |

Estimates assume a single engineer familiar with kvcached internals.

Decision gates

  • End of Phase 0: all spikes pass → continue. Otherwise → reassess or stop.
  • End of Phase 1: measurable VRAM reduction with bit-equivalent output → continue, or ship Phase 1 alone as a standalone single-instance feature.
  • End of Phase 2: measurable multi-instance improvement on the two-Ollama bench → continue to hardening.

Scope

In scope (matches kvcached's current vLLM/SGLang support):

  • Linux + NVIDIA CUDA (compute capability ≥ 7.0 for VMM API)
  • New ollamarunner (native ggml engine)
  • Causal, Recurrent, Wrapper cache types
  • Multi-GPU within a single Ollama instance

Out of scope:

  • macOS / Metal and AMD / ROCm. kvcached itself only supports NVIDIA CUDA today, so the flag will be a no-op on non-CUDA builds; users on those platforms see no change. This matches kvcached's current platform support.
  • Old llamarunner (llama.cpp-based). Too entangled with upstream llama.cpp; would roughly double the C-side work.
  • Radix-style cross-conversation prefix sharing. Ollama's slot-based prefix cache is coarser than vLLM/SGLang's; closing that gap is structural to Ollama's slot model and out of scope here.

Risks

  1. ggml's tensor allocator may not cooperate with paged VA-backed buffers. Front-loaded as Spike 0.4. If it fails, the project scope expands materially (patching ggml internals vs. vendoring a single new file).

  2. 2 MiB page alignment vs. small KV tensors. kvcached pages at 2 MiB granularity, so small allocations waste tail space. Mitigated by routing only the KV-cache layers (which are large) to the kvcached buft; with the illustrative numbers above, a single layer's K tensor at 4K context is already 8 MiB, so tail waste is bounded by one 2 MiB page per tensor. Weights and activations stay on the standard CUDA buft. This is the same constraint kvcached operates under in vLLM/SGLang today.


Drafted for collaborator review. Phase 0 spikes are the first concrete commitment ask; everything after is contingent on Phase 0 results.
