Skip to content
12 changes: 11 additions & 1 deletion src/llmcompressor/modifiers/quantization/calibration.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
from typing import Any, Dict, Optional, Tuple

import torch
Expand Down Expand Up @@ -247,7 +248,16 @@ def calibrate_kv_cache_input_hook(
kv_cache to singleton QuantizedKVParameterCache.
"""
kv_cache = getattr(module, "kv_cache")
kwargs["past_key_values"] = kv_cache
if not hasattr(module, "_past_kv_name"):
# Resolve, once per module, which keyword name forward() accepts for the KV
# cache ("past_key_value" or "past_key_values"), and cache it on the module
# TODO: Find a better place to cache this
module._past_kv_name = (
"past_key_value" # transformers#39956
if "past_key_value" in inspect.signature(module.forward).parameters
else "past_key_values"
)

kwargs[module._past_kv_name] = kv_cache
kwargs["use_cache"] = False
return args, kwargs

Expand Down