diff --git a/tpu_inference/worker/tpu_worker.py b/tpu_inference/worker/tpu_worker.py
index 092dc26f0..25f48ce7c 100644
--- a/tpu_inference/worker/tpu_worker.py
+++ b/tpu_inference/worker/tpu_worker.py
@@ -18,7 +18,8 @@
 from vllm.lora.request import LoRARequest
 from vllm.tasks import SupportedTask
 from vllm.v1 import utils as vllm_utils
-from vllm.v1.core.kv_cache_utils import get_num_blocks, get_uniform_page_size
+from vllm.v1.core.kv_cache_utils import (get_kv_cache_groups, get_num_blocks,
+                                         get_uniform_page_size)
 from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
@@ -382,33 +383,37 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
         # responsible for this translation. When vLLM can be modified, this
         # method should be changed to return `dict[str, AbstractKVCacheSpec]`,
         # and the vLLM side should be updated to handle the translation.
-        kv_cache_specs = self.model_runner.get_kv_cache_spec()
+        kv_cache_spec = self.model_runner.get_kv_cache_spec()
 
-        if len(kv_cache_specs) == 0:
-            return kv_cache_specs
+        if len(kv_cache_spec) == 0:
+            return kv_cache_spec
 
         # TODO(kyuyeunk): Instead of checking page_size_bytes here, introduce
         # feature that allows overriding page_size_bytes of KVCacheSpec.
         vllm_page_size_bytes = get_uniform_page_size(
-            list(kv_cache_specs.values()))
+            list(kv_cache_spec.values()))
         attention_page_size_bytes = get_attention_page_size_bytes(
-            self.model_runner.mesh, kv_cache_specs)
+            self.model_runner.mesh, kv_cache_spec)
 
         if vllm_page_size_bytes != attention_page_size_bytes:
             logger.info(
-                f"KV cache page size calculated by vLLM "
-                f"({vllm_page_size_bytes} Bytes) does not match with actual "
-                f"page size used by Attention kernel ({attention_page_size_bytes} Bytes). "
-                f"Recalculating number of KV blocks using actual page size.")
-
+                f"Page size calculated by vLLM ({vllm_page_size_bytes} Bytes) "
+                f"does not match with actual page size used by the kernel "
+                f"({attention_page_size_bytes} Bytes). Recalculating number of "
+                f"KV blocks using actual page size.")
+
+            kv_cache_groups = get_kv_cache_groups(self.vllm_config,
+                                                  kv_cache_spec)
+            group_size = max(
+                len(group.layer_names) for group in kv_cache_groups)
             available_memory = self.determine_available_memory()
-            num_blocks = get_num_blocks(self.vllm_config, len(kv_cache_specs),
+            num_blocks = get_num_blocks(self.vllm_config, group_size,
                                         available_memory,
                                         attention_page_size_bytes)
             cache_config = self.vllm_config.cache_config
             cache_config.num_gpu_blocks_override = num_blocks
 
-        return kv_cache_specs
+        return kv_cache_spec
 
     def initialize_from_config(
         self,