Skip to content

Commit b690326

Browse files
committed
Support overriding logic for hybrid kv cache padding
Signed-off-by: Kyuyeun Kim <[email protected]>
1 parent 0ab84a4 commit b690326

File tree

1 file changed

+18
-13
lines changed

1 file changed

+18
-13
lines changed

tpu_inference/worker/tpu_worker.py

Lines changed: 18 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,8 @@
1818
from vllm.lora.request import LoRARequest
1919
from vllm.tasks import SupportedTask
2020
from vllm.v1 import utils as vllm_utils
21-
from vllm.v1.core.kv_cache_utils import get_num_blocks, get_uniform_page_size
21+
from vllm.v1.core.kv_cache_utils import (get_kv_cache_groups, get_num_blocks,
22+
get_uniform_page_size)
2223
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
2324
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
2425
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
@@ -382,33 +383,37 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
382383
# responsible for this translation. When vLLM can be modified, this
383384
# method should be changed to return `dict[str, AbstractKVCacheSpec]`,
384385
# and the vLLM side should be updated to handle the translation.
385-
kv_cache_specs = self.model_runner.get_kv_cache_spec()
386+
kv_cache_spec = self.model_runner.get_kv_cache_spec()
386387

387-
if len(kv_cache_specs) == 0:
388-
return kv_cache_specs
388+
if len(kv_cache_spec) == 0:
389+
return kv_cache_spec
389390

390391
# TODO(kyuyeunk): Instead of checking page_size_bytes here, introduce
391392
# feature that allows overriding page_size_bytes of KVCacheSpec.
392393
vllm_page_size_bytes = get_uniform_page_size(
393-
list(kv_cache_specs.values()))
394+
list(kv_cache_spec.values()))
394395
attention_page_size_bytes = get_attention_page_size_bytes(
395-
self.model_runner.mesh, kv_cache_specs)
396+
self.model_runner.mesh, kv_cache_spec)
396397

397398
if vllm_page_size_bytes != attention_page_size_bytes:
398399
logger.info(
399-
f"KV cache page size calculated by vLLM "
400-
f"({vllm_page_size_bytes} Bytes) does not match with actual "
401-
f"page size used by Attention kernel ({attention_page_size_bytes} Bytes). "
402-
f"Recalculating number of KV blocks using actual page size.")
403-
400+
f"Page size calculated by vLLM ({vllm_page_size_bytes} Bytes) "
401+
f"does not match with actual page size used by the kernel "
402+
f"({attention_page_size_bytes} Bytes). Recalculating number of "
403+
f"KV blocks using actual page size.")
404+
405+
kv_cache_groups = get_kv_cache_groups(self.vllm_config,
406+
kv_cache_spec)
407+
group_size = max(
408+
len(group.layer_names) for group in kv_cache_groups)
404409
available_memory = self.determine_available_memory()
405-
num_blocks = get_num_blocks(self.vllm_config, len(kv_cache_specs),
410+
num_blocks = get_num_blocks(self.vllm_config, group_size,
406411
available_memory,
407412
attention_page_size_bytes)
408413
cache_config = self.vllm_config.cache_config
409414
cache_config.num_gpu_blocks_override = num_blocks
410415

411-
return kv_cache_specs
416+
return kv_cache_spec
412417

413418
def initialize_from_config(
414419
self,

0 commit comments

Comments (0)