
Commit 6ec0d8d

[Fix] Load kv-cache dtype from hf_quant_config.json automatically (vllm-project#29980)
Signed-off-by: Daniel Afrimi <[email protected]>
1 parent 9693dd0 commit 6ec0d8d

File tree

1 file changed: +23 -2 lines changed


vllm/utils/torch_utils.py

Lines changed: 23 additions & 2 deletions
@@ -194,12 +194,33 @@ def get_kv_cache_torch_dtype(
         return torch_dtype
 
 
+def get_kv_cache_quant_algo_dtype(quant_cfg: dict[str, Any]) -> torch.dtype | None:
+    quant_method = quant_cfg.get("quant_method", "")
+    if quant_method.startswith("modelopt"):
+        quantization_inner = quant_cfg.get("quantization", quant_cfg)
+        # Check if quant config is specified and use kv cache quant algo
+        kv_algo = quantization_inner.get("kv_cache_quant_algo") or quant_cfg.get(
+            "kv_cache_quant_algo"
+        )
+        if isinstance(kv_algo, str):
+            return STR_DTYPE_TO_TORCH_DTYPE[kv_algo.lower()]
+    return None
+
+
 def kv_cache_dtype_str_to_dtype(
     kv_cache_dtype: str, model_config: ModelConfig
 ) -> torch.dtype:
+    # Model config may not be specified for unit tests, default to float16
+    dtype = model_config.dtype if model_config else torch.half
     if kv_cache_dtype == "auto":
-        # Model config may not be specified for unit tests, default to float16
-        return model_config.dtype if model_config else torch.half
+        hf_cfg = getattr(model_config, "hf_config", None)
+        if hf_cfg is not None:
+            quant_cfg = getattr(hf_cfg, "quantization_config", None)
+            if quant_cfg is not None:
+                kv_algo_dtype = get_kv_cache_quant_algo_dtype(quant_cfg)
+                return kv_algo_dtype if kv_algo_dtype is not None else dtype
+        return dtype
+
     return STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
 
 
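For context, a minimal sketch of how the new helper resolves the kv-cache dtype from a ModelOpt-style quantization config. The quant_cfg dict below is an illustrative example of what hf_quant_config.json may expose as hf_config.quantization_config, not a config taken from this commit, and the returned torch dtype depends on what STR_DTYPE_TO_TORCH_DTYPE maps "fp8" to in the installed vLLM version.

from vllm.utils.torch_utils import get_kv_cache_quant_algo_dtype

# Illustrative ModelOpt-style quantization config (hypothetical values).
quant_cfg = {
    "quant_method": "modelopt",
    "quantization": {
        "quant_algo": "FP8",
        "kv_cache_quant_algo": "FP8",  # lower-cased to "fp8" before the dtype lookup
    },
}

# Returns the torch dtype registered for "fp8" in STR_DTYPE_TO_TORCH_DTYPE.
print(get_kv_cache_quant_algo_dtype(quant_cfg))

# Non-ModelOpt methods, or configs without kv_cache_quant_algo, return None,
# so kv_cache_dtype_str_to_dtype("auto", model_config) falls back to the model dtype.
print(get_kv_cache_quant_algo_dtype({"quant_method": "awq"}))  # None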