SWAKVPool ignores enable_memory_saver, preventing KV cache release for SWA models #860

@euclidgame

Description

Problem

When using colocated training (SGLang + Megatron on the same GPUs), SGLang's /release_memory_occupation endpoint is called before training to free KV cache memory. For models with sliding window attention (SWA), this call silently does nothing, leaving the full KV cache resident on the GPUs. The result is an OOM crash during Megatron model initialization. Observed with the SGLang version shipped in the latest image.

Root Cause

Two issues in the SWA KV cache path:

  1. swa_memory_pool.py:51 — SWAKVPool.__init__ unconditionally overwrites enable_memory_saver in kwargs with False, regardless of what the caller passes:

Before (bug)

  kwargs["enable_memory_saver"] = False

After (fix)

  kwargs.setdefault("enable_memory_saver", False)
  2. model_runner_kv_cache_mixin.py:555 — The SWAKVPool(...) construction site never forwards server_args.enable_memory_saver, so even with the above fix, the flag would remain at its default of False for SWA models. Add this kwarg to the SWAKVPool(...) call:

  enable_memory_saver=self.server_args.enable_memory_saver,
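The difference between the buggy overwrite and the setdefault fix can be shown in isolation. A minimal sketch (standalone Python, not SGLang code; the function names are hypothetical stand-ins for the SWAKVPool constructor path):

```python
def make_pool_buggy(**kwargs):
    # Bug: unconditionally overwrites whatever the caller passed,
    # so enable_memory_saver=True from server_args is silently lost.
    kwargs["enable_memory_saver"] = False
    return kwargs["enable_memory_saver"]

def make_pool_fixed(**kwargs):
    # Fix: setdefault only fills in False when the caller omitted
    # the key, preserving an explicit True from the caller.
    kwargs.setdefault("enable_memory_saver", False)
    return kwargs["enable_memory_saver"]

print(make_pool_buggy(enable_memory_saver=True))  # False — caller's True is clobbered
print(make_pool_fixed(enable_memory_saver=True))  # True — caller's value preserved
print(make_pool_fixed())                          # False — default still applies
```

Both halves of the fix are needed: setdefault makes the constructor respect the flag, and forwarding server_args.enable_memory_saver at the call site actually supplies it.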

Impact

  • Only affects models with SWA (e.g. gpt-oss-120b, which alternates SWA and full attention layers). Models without SWA (e.g. Qwen3-30B-A3B) are unaffected.
  • Colocated setups (training + rollout on same GPUs) fail at startup with OOM when enable_memory_saver=True.
  • Non-colocated setups are unaffected since they never call release_memory_occupation.
