fix(engine): implement late merge to deduplicate blocks for same-step prefix sharing #243

Open: fishAndShrimp wants to merge 1 commit into GeeeekExplorer:main from fishAndShrimp:fix/same-step-cache-dup

🔗 Related Issue

Closes #219

🧭 Context & Problem

As discussed in #219, when multiple identical prompts are scheduled in the same prefill step, each one misses the cache and is allocated its own physical blocks. During postprocess(), the sequences then overwrite each other's entries in hash_to_block_id, so the duplicate blocks stay allocated and waste VRAM.
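The overwrite can be illustrated with a toy model. This is a minimal sketch, not nano-vllm's actual BlockManager: `allocate_block` and `register_hash` are hypothetical helpers, and only the names `hash_to_block_id` and `used_block_ids` come from the description above.

```python
from itertools import count

# Toy state; the names mirror the PR text, the structure is an assumption.
hash_to_block_id: dict[int, int] = {}
used_block_ids: set[int] = set()
_next_block_id = count()


def allocate_block() -> int:
    """Hand out the next physical block id (hypothetical allocator)."""
    block_id = next(_next_block_id)
    used_block_ids.add(block_id)
    return block_id


def register_hash(block_hash: int, block_id: int) -> None:
    """Pre-fix behaviour as described: unconditionally overwrite the mapping."""
    hash_to_block_id[block_hash] = block_id


PROMPT_BLOCK_HASH = 0xABC  # stand-in for the hash of one full block of tokens

# Two identical prompts in the same prefill step: both miss, both allocate.
first = allocate_block()
second = allocate_block()
register_hash(PROMPT_BLOCK_HASH, first)
register_hash(PROMPT_BLOCK_HASH, second)  # replaces the entry for `first`

print(len(used_block_ids))                  # 2: two blocks hold identical content
print(hash_to_block_id[PROMPT_BLOCK_HASH])  # 1: only the second block is mapped
```

Both blocks stay allocated, but only one of them can ever be found through the hash mapping.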

Furthermore, due to the engine's lazy deletion mechanism, merely finding a hash in hash_to_block_id does not guarantee a valid cache hit. A block might have been freed (ref_count == 0) and returned to the free queue, but its hash mapping is left dangling. Blindly reusing a dangling hash without checking its active status would corrupt the free block queue.
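A dangling mapping can be reproduced in a few lines. The sketch below is a simplified stand-in for the engine's bookkeeping, assuming a free queue and a set of used block ids; `is_valid_hit` is a hypothetical helper name used only for illustration.

```python
from collections import deque
from dataclasses import dataclass


@dataclass
class Block:
    block_id: int
    ref_count: int = 0
    hash: int = -1


blocks = [Block(0), Block(1)]
free_block_ids = deque([1])   # block 1 starts free
used_block_ids = {0}          # block 0 starts in use
hash_to_block_id: dict[int, int] = {}

# Block 0 is hashed while in use.
blocks[0].ref_count = 1
blocks[0].hash = 0xABC
hash_to_block_id[0xABC] = 0

# Lazy deletion: the block is freed, but its hash mapping is left in place.
blocks[0].ref_count -= 1
used_block_ids.remove(0)
free_block_ids.append(0)


def is_valid_hit(block_hash: int) -> bool:
    """A hit counts only if the mapped block is still actively used."""
    block_id = hash_to_block_id.get(block_hash, -1)
    return block_id != -1 and block_id in used_block_ids


print(0xABC in hash_to_block_id)  # True: the mapping still exists
print(is_valid_hit(0xABC))        # False: the block sits in the free queue
```

Treating the first check alone as a hit would hand out a block that is still queued as free, which is the corruption described above.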


🛠️ The Fix (Late Merge)

This PR implements the "Late Merge" approach proposed in the issue. It deduplicates blocks during hash_blocks by strictly validating cache hits against used_block_ids and token contents.

Specifically, the logic now:

  1. Validates the Hit: Checks if the computed hash exists in hash_to_block_id AND verifies the corresponding block is actively alive (exists in used_block_ids).
  2. Executes Late Merge: If a valid hit is found and the tokens exactly match (existing_block.token_ids == token_ids), it deallocates the redundant block, updates the sequence's block table to the shared block, and increments the ref_count.
  3. Graceful Fallback on Hash Collisions: If a hash collision occurs (hash matches but token_ids differ), the logic naturally falls through to the original behavior. It simply overwrites the hash_to_block_id mapping with the new block. The previous block remains perfectly safe and active in its sequence; it merely loses its global mapping (graceful degradation).
  4. Safely Handles Dangling Mappings: Similarly, dangling hashes (freed blocks) are treated as misses and safely overwritten in the mapping without mutating the free block queue.
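The four cases above can be sketched on a toy block manager. This is an illustrative approximation under stated assumptions, not the code in this PR: `ToyBlockManager`, `hash_block`, and the `block_table` list are hypothetical, and only `hash_to_block_id`, `used_block_ids`, `token_ids`, and `ref_count` are taken from the description.

```python
from collections import deque
from dataclasses import dataclass, field


@dataclass
class Block:
    block_id: int
    ref_count: int = 0
    hash: int = -1
    token_ids: list[int] = field(default_factory=list)


class ToyBlockManager:
    def __init__(self, num_blocks: int) -> None:
        self.blocks = [Block(i) for i in range(num_blocks)]
        self.free_block_ids = deque(range(num_blocks))
        self.used_block_ids: set[int] = set()
        self.hash_to_block_id: dict[int, int] = {}

    def allocate(self) -> int:
        block_id = self.free_block_ids.popleft()
        block = self.blocks[block_id]
        block.ref_count, block.hash, block.token_ids = 1, -1, []
        self.used_block_ids.add(block_id)
        return block_id

    def deallocate(self, block_id: int) -> None:
        # Lazy deletion: any hash mapping for this block is left untouched.
        self.used_block_ids.remove(block_id)
        self.free_block_ids.append(block_id)

    def hash_block(self, block_table: list[int], index: int,
                   block_hash: int, token_ids: list[int]) -> None:
        new_id = block_table[index]
        existing_id = self.hash_to_block_id.get(block_hash, -1)
        valid_hit = (
            existing_id != -1
            and existing_id != new_id
            and existing_id in self.used_block_ids           # step 1: alive
            and self.blocks[existing_id].token_ids == token_ids
        )
        if valid_hit:
            # Step 2: late merge onto the shared block.
            self.blocks[new_id].ref_count = 0
            self.deallocate(new_id)
            block_table[index] = existing_id
            self.blocks[existing_id].ref_count += 1
            return
        # Steps 3 and 4: miss, collision, or dangling mapping. Overwrite only.
        block = self.blocks[new_id]
        block.hash, block.token_ids = block_hash, list(token_ids)
        self.hash_to_block_id[block_hash] = new_id


bm = ToyBlockManager(num_blocks=4)
tokens = [1, 2, 3, 4]
h = hash(tuple(tokens))
seq_a, seq_b = [bm.allocate()], [bm.allocate()]   # blocks 0 and 1
bm.hash_block(seq_a, 0, h, tokens)                # miss: registers block 0
bm.hash_block(seq_b, 0, h, tokens)                # valid hit: merges into block 0
print(seq_a, seq_b, bm.blocks[0].ref_count, sorted(bm.used_block_ids))
# [0] [0] 2 [0]
```

In this sketch the redundant block returns to the free queue during the merge, and the collision and dangling paths never touch the free queue at all.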

🧪 Testing & Verification

I have verified this fix using an end-to-end Minimal Reproducible Example (MRE) that sends concurrent identical requests and monitors the BlockManager state step-by-step.

Here is a quick summary of the engine state during the Prefill Phase (Step 1) before and after the fix:

| Metric | Before Fix | After Fix |
| --- | --- | --- |
| Total Used Blocks | 8 | 5 |
| Unique Valid Hashes | 3 | 3 |
| Result | ❌ Redundant blocks allocated | ✅ Deduplicated (ref_count: 2) |
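The block counts are consistent with a simple accounting, shown here as a short check. It assumes each prompt fills three full (hashed) blocks plus one partially filled, unhashed tail block; that split is inferred from the numbers above, not stated by the PR.

```python
num_sequences = 2
full_blocks_per_seq = 3     # assumed: matches the 3 unique valid hashes
partial_blocks_per_seq = 1  # assumed: unhashed tail block, private to each sequence

# Before the fix every sequence keeps its own copy of every block.
before = num_sequences * (full_blocks_per_seq + partial_blocks_per_seq)

# After the fix the full blocks are shared; only the tail blocks are per sequence.
after = full_blocks_per_seq + num_sequences * partial_blocks_per_seq

print(before, after)  # 8 5
```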

MRE Code & Execution Logs

Below are the exact script and logs used to verify the correct behavior of the Late Merge deduplication and reference counting.

💻 MRE Script
from pathlib import Path
from nanovllm import LLM, SamplingParams
from transformers import AutoTokenizer
import collections

MODEL_PATH = str(
    # Matches the default download path in Nano-vLLM's README
    Path("~/huggingface/Qwen3-0.6B/").expanduser().resolve()
)

def main():
    print("🚀 Initializing engine...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    # Enforce eager mode for deterministic debugging
    llm = LLM(MODEL_PATH, enforce_eager=True, tensor_parallel_size=1)

    # 1. Construct a prompt long enough to occupy multiple blocks
    dummy_text = "This is a concurrent cache test prompt. " * 100
    token_ids = tokenizer.encode(dummy_text)

    # Set max_tokens to 3 to verify E2E decode stability without long logs
    sampling_params = SamplingParams(temperature=1.0, max_tokens=3)

    # 2. Trigger: Send two identical requests concurrently
    print("\n📦 Sending 2 identical concurrent requests...")
    llm.add_request(token_ids, sampling_params)
    llm.add_request(token_ids, sampling_params)

    # 3. Execute steps and monitor BlockManager state
    print("\n🏃‍♂️ Executing steps and monitoring KV Cache...")
    step_count = 1
    while not llm.is_finished():
        llm.step()
        
        bm = llm.scheduler.block_manager
        used_blocks = [bm.blocks[x] for x in bm.used_block_ids]
        
        # Filter out unhashed blocks (assuming None or -1)
        valid_hashes = [b.hash for b in used_blocks if b.hash not in (None, -1)]
        counter = collections.Counter(valid_hashes)
        duplicates = {h: count for h, count in counter.items() if count > 1}
        
        phase = "Prefill" if step_count == 1 else "Decode"
        print(f"\n[Step {step_count}] Phase: {phase}")
        print(f"   Total used blocks: {len(used_blocks)}")
        print(f"   Unique valid hashes: {len(counter)}")
        
        if duplicates:
            print(f"   ❌ BUG DETECTED: {len(duplicates)} hashes have redundant physical blocks!")
            # Print just the first duplicate as an example
            sample_dup_h = list(duplicates.keys())[0]
            print(f"      e.g., Hash {str(sample_dup_h)[:10]}... occupies {duplicates[sample_dup_h]} blocks")
        else:
            # Clear, bulleted explanation for the PASS state
            print("   ✅ PASS: No redundant blocks detected. This is due to one of two reasons:")
            print("      1. Late Merge successfully deduplicated the blocks.")
            print("      2. The generation has finished, so all blocks have been freed.")
            
            if valid_hashes:
                sample_h = valid_hashes[0]
                block_id = bm.hash_to_block_id.get(sample_h)
                if block_id is not None:
                    ref_count = bm.blocks[block_id].ref_count
                    print(f"   🔍 Sample Hash {str(sample_h)[:10]}... -> ref_count: {ref_count}")

        step_count += 1

    print(f"\n🎉 Generation completed successfully in {step_count - 1} steps without crashing!")

if __name__ == "__main__":
    main()
❌ Log BEFORE Fix (Bug Present)
🚀 Initializing engine...

📦 Sending 2 identical concurrent requests...

🏃‍♂️ Executing steps and monitoring KV Cache...

[Step 1] Phase: Prefill
   Total used blocks: 8
   Unique valid hashes: 3
   ❌ BUG DETECTED: 3 hashes have redundant physical blocks!
      e.g., Hash 1191164273... occupies 2 blocks

[Step 2] Phase: Decode
   Total used blocks: 8
   Unique valid hashes: 3
   ❌ BUG DETECTED: 3 hashes have redundant physical blocks!
      e.g., Hash 1191164273... occupies 2 blocks

[Step 3] Phase: Decode
   Total used blocks: 0
   Unique valid hashes: 0
   ✅ PASS: No redundant blocks detected. This is due to one of two reasons:
      1. Late Merge successfully deduplicated the blocks.
      2. The generation has finished, so all blocks have been freed.

🎉 Generation completed successfully in 3 steps without crashing!

✅ Log AFTER Fix (Bug Resolved)
🚀 Initializing engine...

📦 Sending 2 identical concurrent requests...

🏃‍♂️ Executing steps and monitoring KV Cache...

[Step 1] Phase: Prefill
   Total used blocks: 5
   Unique valid hashes: 3
   ✅ PASS: No redundant blocks detected. This is due to one of two reasons:
      1. Late Merge successfully deduplicated the blocks.
      2. The generation has finished, so all blocks have been freed.
   🔍 Sample Hash 1191164273... -> ref_count: 2

[Step 2] Phase: Decode
   Total used blocks: 5
   Unique valid hashes: 3
   ✅ PASS: No redundant blocks detected. This is due to one of two reasons:
      1. Late Merge successfully deduplicated the blocks.
      2. The generation has finished, so all blocks have been freed.
   🔍 Sample Hash 1191164273... -> ref_count: 2

[Step 3] Phase: Decode
   Total used blocks: 0
   Unique valid hashes: 0
   ✅ PASS: No redundant blocks detected. This is due to one of two reasons:
      1. Late Merge successfully deduplicated the blocks.
      2. The generation has finished, so all blocks have been freed.

🎉 Generation completed successfully in 3 steps without crashing!

Issue: When multiple identical prompts are scheduled in the same prefill step, they overwrite each other's hash mapping, leading to duplicate physical blocks and memory waste.

Furthermore, due to the engine's lazy deletion mechanism, merely finding a hash in `hash_to_block_id` does not guarantee a valid cache hit. A block might have been freed (ref_count == 0) and returned to the free queue, but its hash mapping is left dangling for potential future reuse. Blindly reusing a dangling hash without checking its active status would corrupt the free block queue.

Fix: Deduplicate blocks during `hash_blocks` by strictly validating cache hits against `used_block_ids`. This safely handles dangling hash mappings without mutating the free block queue.

Specifically, this commit introduces robust cache hit handling:
- Checks if the computed hash exists in `hash_to_block_id` AND verifies the corresponding block is actively alive (exists in `used_block_ids`).
- If it is a valid hit, safely deallocates the redundant block newly assigned to the sequence.
- Updates the sequence's block table to point to the existing shared block.
- Increments the reference count of the shared block.
- If the hash exists but the block is not in `used_block_ids` (a dangling mapping), it treats it as a miss and safely updates the mapping to the newly allocated block.

Development

Successfully merging this pull request may close these issues.

[Discussion] Duplicate blocks for same-step prefix sharing