tpu_inference/worker/tpu_worker.py: 31 changes (18 additions, 13 deletions)
@@ -18,7 +18,8 @@
 from vllm.lora.request import LoRARequest
 from vllm.tasks import SupportedTask
 from vllm.v1 import utils as vllm_utils
-from vllm.v1.core.kv_cache_utils import get_num_blocks, get_uniform_page_size
+from vllm.v1.core.kv_cache_utils import (get_kv_cache_groups, get_num_blocks,
+                                         get_uniform_page_size)
 from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
@@ -382,33 +383,37 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
         # responsible for this translation. When vLLM can be modified, this
         # method should be changed to return `dict[str, AbstractKVCacheSpec]`,
         # and the vLLM side should be updated to handle the translation.
-        kv_cache_specs = self.model_runner.get_kv_cache_spec()
+        kv_cache_spec = self.model_runner.get_kv_cache_spec()
 
-        if len(kv_cache_specs) == 0:
-            return kv_cache_specs
+        if len(kv_cache_spec) == 0:
+            return kv_cache_spec
 
         # TODO(kyuyeunk): Instead of checking page_size_bytes here, introduce
         # feature that allows overriding page_size_bytes of KVCacheSpec.
         vllm_page_size_bytes = get_uniform_page_size(
-            list(kv_cache_specs.values()))
+            list(kv_cache_spec.values()))
         attention_page_size_bytes = get_attention_page_size_bytes(
-            self.model_runner.mesh, kv_cache_specs)
+            self.model_runner.mesh, kv_cache_spec)
 
         if vllm_page_size_bytes != attention_page_size_bytes:
             logger.info(
-                f"KV cache page size calculated by vLLM "
-                f"({vllm_page_size_bytes} Bytes) does not match with actual "
-                f"page size used by Attention kernel ({attention_page_size_bytes} Bytes). "
-                f"Recalculating number of KV blocks using actual page size.")
-
+                f"Page size calculated by vLLM ({vllm_page_size_bytes} Bytes) "
+                f"does not match with actual page size used by the kernel "
+                f"({attention_page_size_bytes} Bytes). Recalculating number of "
+                f"KV blocks using actual page size.")
+
+            kv_cache_groups = get_kv_cache_groups(self.vllm_config,
+                                                  kv_cache_spec)
+            group_size = max(
+                len(group.layer_names) for group in kv_cache_groups)
             available_memory = self.determine_available_memory()
-            num_blocks = get_num_blocks(self.vllm_config, len(kv_cache_specs),
+            num_blocks = get_num_blocks(self.vllm_config, group_size,
                                         available_memory,
                                         attention_page_size_bytes)
             cache_config = self.vllm_config.cache_config
             cache_config.num_gpu_blocks_override = num_blocks
 
-        return kv_cache_specs
+        return kv_cache_spec
 
     def initialize_from_config(
         self,
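For context on the recalculation in this hunk: when the page size computed by vLLM disagrees with the page size the TPU attention kernel actually allocates, the worker re-derives the number of KV cache blocks from the available memory, the kernel's page size, and the layer count of the largest KV cache group (rather than the raw number of specs, as before). Roughly speaking, each block costs one page per layer of that group, so the block budget scales with `page_size * group_size`. The snippet below is a minimal standalone sketch of that arithmetic under simplified assumptions; `estimate_num_blocks` and the literal sizes are illustrative only, not the real `get_num_blocks` helper or its signature.

```python
# Minimal standalone sketch of the block-count arithmetic behind this change.
# This is NOT the real vLLM `get_num_blocks` helper; `estimate_num_blocks`
# and the literal sizes below are illustrative assumptions only.

def estimate_num_blocks(available_memory: int, page_size_bytes: int,
                        group_size: int) -> int:
    """Roughly how many KV cache blocks fit into `available_memory`.

    Assumes each block costs `page_size_bytes` once per layer of the
    largest KV cache group, i.e. `page_size_bytes * group_size` in total.
    """
    return available_memory // (page_size_bytes * group_size)


if __name__ == "__main__":
    available_memory = 16 * 1024**3   # 16 GiB of free HBM (example value)
    vllm_page_size = 256 * 1024       # page size as computed by vLLM (example)
    kernel_page_size = 288 * 1024     # page size the TPU attention kernel
                                      # actually allocates, e.g. after padding
    group_size = 32                   # layers in the largest KV cache group

    # Sizing the pool with vLLM's page size would over-commit memory ...
    print(estimate_num_blocks(available_memory, vllm_page_size, group_size))    # 2048
    # ... so the worker recomputes with the kernel's real page size and pins
    # the result via `cache_config.num_gpu_blocks_override`.
    print(estimate_num_blocks(available_memory, kernel_page_size, group_size))  # 1820
```

For a model with a single uniform KV cache group, `group_size` should equal the old `len(kv_cache_specs)`, so this change is presumably a no-op there; it only matters when layers are split into groups of different sizes (e.g. hybrid attention models).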