18 | 18 | from vllm.lora.request import LoRARequest |
19 | 19 | from vllm.tasks import SupportedTask |
20 | 20 | from vllm.v1 import utils as vllm_utils |
21 | | -from vllm.v1.core.kv_cache_utils import get_num_blocks, get_uniform_page_size |
| 21 | +from vllm.v1.core.kv_cache_utils import (get_kv_cache_groups, get_num_blocks, |
| 22 | + get_uniform_page_size) |
22 | 23 | from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput |
23 | 24 | from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec |
24 | 25 | from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput |
@@ -382,33 +383,37 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: |
382 | 383 | # responsible for this translation. When vLLM can be modified, this |
383 | 384 | # method should be changed to return `dict[str, AbstractKVCacheSpec]`, |
384 | 385 | # and the vLLM side should be updated to handle the translation. |
385 | | - kv_cache_specs = self.model_runner.get_kv_cache_spec() |
| 386 | + kv_cache_spec = self.model_runner.get_kv_cache_spec() |
386 | 387 |
387 | | - if len(kv_cache_specs) == 0: |
388 | | - return kv_cache_specs |
| 388 | + if len(kv_cache_spec) == 0: |
| 389 | + return kv_cache_spec |
389 | 390 |
390 | 391 | # TODO(kyuyeunk): Instead of checking page_size_bytes here, introduce |
391 | 392 | # feature that allows overriding page_size_bytes of KVCacheSpec. |
392 | 393 | vllm_page_size_bytes = get_uniform_page_size( |
393 | | - list(kv_cache_specs.values())) |
| 394 | + list(kv_cache_spec.values())) |
394 | 395 | attention_page_size_bytes = get_attention_page_size_bytes( |
395 | | - self.model_runner.mesh, kv_cache_specs) |
| 396 | + self.model_runner.mesh, kv_cache_spec) |
396 | 397 |
397 | 398 | if vllm_page_size_bytes != attention_page_size_bytes: |
398 | 399 | logger.info( |
399 | | - f"KV cache page size calculated by vLLM " |
400 | | - f"({vllm_page_size_bytes} Bytes) does not match with actual " |
401 | | - f"page size used by Attention kernel ({attention_page_size_bytes} Bytes). " |
402 | | - f"Recalculating number of KV blocks using actual page size.") |
403 | | - |
| 400 | + f"Page size calculated by vLLM ({vllm_page_size_bytes} Bytes) " |
| 401 | + f"does not match the actual page size used by the kernel "
| 402 | + f"({attention_page_size_bytes} Bytes). Recalculating number of " |
| 403 | + f"KV blocks using actual page size.") |
| 404 | + |
| 405 | + kv_cache_groups = get_kv_cache_groups(self.vllm_config, |
| 406 | + kv_cache_spec) |
| 407 | + group_size = max( |
| 408 | + len(group.layer_names) for group in kv_cache_groups) |
404 | 409 | available_memory = self.determine_available_memory() |
405 | | - num_blocks = get_num_blocks(self.vllm_config, len(kv_cache_specs), |
| 410 | + num_blocks = get_num_blocks(self.vllm_config, group_size, |
406 | 411 | available_memory, |
407 | 412 | attention_page_size_bytes) |
408 | 413 | cache_config = self.vllm_config.cache_config |
409 | 414 | cache_config.num_gpu_blocks_override = num_blocks |
410 | 415 |
411 | | - return kv_cache_specs |
| 416 | + return kv_cache_spec |
412 | 417 |
413 | 418 | def initialize_from_config( |
414 | 419 | self, |
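
For reviewers, here is a standalone sketch of what the recalculation in this hunk amounts to. It is an approximation under stated assumptions, not the actual vLLM helpers: the real patch calls `get_kv_cache_groups()` and `get_num_blocks()`, while `recalc_num_blocks` and the sample numbers below are hypothetical, assuming `get_num_blocks` effectively divides the available memory across one page per layer of the largest KV cache group.

```python
# Sketch only: approximates the new logic in get_kv_cache_spec() when the
# attention kernel's page size differs from vLLM's uniform estimate.
# recalc_num_blocks and all sample values are hypothetical; the real code
# delegates to vllm.v1.core.kv_cache_utils.get_num_blocks().

def recalc_num_blocks(available_memory_bytes: int,
                      group_layer_counts: list[int],
                      attention_page_size_bytes: int) -> int:
    """Recompute the KV block count from the kernel's actual page size."""
    # Only the largest group matters: every group must hold the same
    # number of blocks, and each block costs one page per layer.
    group_size = max(group_layer_counts)
    bytes_per_block = group_size * attention_page_size_bytes
    return available_memory_bytes // bytes_per_block


if __name__ == "__main__":
    # Hypothetical inputs: 8 GiB free, groups of 16 and 8 layers, 2 MiB pages.
    blocks = recalc_num_blocks(8 * 1024**3, [16, 8], 2 * 1024**2)
    print(f"cache_config.num_gpu_blocks_override would be set to {blocks}")
```

The design point the patch encodes: when the kernel's page size disagrees with vLLM's uniform estimate, the block count has to be rederived from the page size the kernel actually uses and the largest KV cache group, then pinned via `cache_config.num_gpu_blocks_override` so vLLM's scheduler and the attention kernel agree on capacity.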