@@ -194,12 +194,33 @@ def get_kv_cache_torch_dtype(
     return torch_dtype
 
 
+def get_kv_cache_quant_algo_dtype(quant_cfg: dict[str, Any]) -> torch.dtype | None:
+    quant_method = quant_cfg.get("quant_method", "")
+    if quant_method.startswith("modelopt"):
+        quantization_inner = quant_cfg.get("quantization", quant_cfg)
+        # Check if quant config is specified and use kv cache quant algo
+        kv_algo = quantization_inner.get("kv_cache_quant_algo") or quant_cfg.get(
+            "kv_cache_quant_algo"
+        )
+        if isinstance(kv_algo, str):
+            return STR_DTYPE_TO_TORCH_DTYPE[kv_algo.lower()]
+    return None
+
+
 def kv_cache_dtype_str_to_dtype(
     kv_cache_dtype: str, model_config: ModelConfig
 ) -> torch.dtype:
+    # Model config may not be specified for unit tests, default to float16
+    dtype = model_config.dtype if model_config else torch.half
     if kv_cache_dtype == "auto":
-        # Model config may not be specified for unit tests, default to float16
-        return model_config.dtype if model_config else torch.half
+        hf_cfg = getattr(model_config, "hf_config", None)
+        if hf_cfg is not None:
+            quant_cfg = getattr(hf_cfg, "quantization_config", None)
+            if quant_cfg is not None:
+                kv_algo_dtype = get_kv_cache_quant_algo_dtype(quant_cfg)
+                return kv_algo_dtype if kv_algo_dtype is not None else dtype
+        return dtype
+
     return STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
 
 
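For context, here is a minimal usage sketch (not part of the diff) showing how the new helper behaves. The import path is a placeholder for whichever module the hunk above lives in, and the `"FP8"` / `"awq"` values are illustrative; the lookup assumes `STR_DTYPE_TO_TORCH_DTYPE` contains an entry for the lower-cased algo string.

```python
# Hypothetical usage sketch; import path is a placeholder, adjust to the patched module.
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_kv_cache_quant_algo_dtype

# modelopt-style config with a nested "quantization" section.
quant_cfg = {
    "quant_method": "modelopt",
    "quantization": {"kv_cache_quant_algo": "FP8"},
}
# Resolves to STR_DTYPE_TO_TORCH_DTYPE["fp8"], assuming that key exists in the mapping.
print(get_kv_cache_quant_algo_dtype(quant_cfg))

# Non-modelopt configs return None, so kv_cache_dtype_str_to_dtype falls back
# to the model dtype (or torch.half when no model config is given).
assert get_kv_cache_quant_algo_dtype({"quant_method": "awq"}) is None
```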