
Commit 97a0414

shanjiaz committed (with Dipika Sikka and gemini-code-assist[bot])
make transformer fix backward compatible (#1794)
SUMMARY: Use similar logic in the code base to determine which kv parameter name to use.

TEST PLAN: Tested locally; kv cache tests pass : )

---------

Signed-off-by: shanjiaz <[email protected]>
Co-authored-by: Dipika Sikka <[email protected]>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 505ae88 commit 97a0414

File tree

1 file changed (+11 −1 lines changed)


src/llmcompressor/modifiers/quantization/calibration.py

Lines changed: 11 additions & 1 deletion
@@ -1,3 +1,4 @@
+import inspect
 from typing import Any, Dict, Optional, Tuple
 
 import torch
@@ -247,7 +248,16 @@ def calibrate_kv_cache_input_hook(
     kv_cache to singleton QuantizedKVParameterCache.
     """
     kv_cache = getattr(module, "kv_cache")
-    kwargs["past_key_values"] = kv_cache
+    if not hasattr(module, "_past_kv_name"):
+        # Determine which past KV parameter name to use once and cache it
+        # TODO: Find a better place to cache this
+        module._past_kv_name = (
+            "past_key_value"  # transformers#39956
+            if "past_key_value" in inspect.signature(module.forward).parameters
+            else "past_key_values"
+        )
+
+    kwargs[module._past_kv_name] = kv_cache
     kwargs["use_cache"] = False
     return args, kwargs
 
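The backward-compatibility trick in the diff above is to inspect the module's `forward` signature once and cache the result, since newer transformers versions (transformers#39956) renamed the `past_key_values` kwarg to `past_key_value` on attention modules. A minimal, standalone sketch of that detection logic (the `NewStyleAttention` and `OldStyleAttention` classes below are hypothetical stand-ins for transformers modules, not part of llm-compressor):

```python
import inspect


class NewStyleAttention:
    # Hypothetical module whose forward() uses the newer kwarg name.
    def forward(self, hidden_states, past_key_value=None, use_cache=False):
        return hidden_states


class OldStyleAttention:
    # Hypothetical module whose forward() still uses the older kwarg name.
    def forward(self, hidden_states, past_key_values=None, use_cache=False):
        return hidden_states


def past_kv_param_name(module):
    """Pick whichever past-KV kwarg name this module's forward() accepts."""
    params = inspect.signature(module.forward).parameters
    return "past_key_value" if "past_key_value" in params else "past_key_values"


print(past_kv_param_name(NewStyleAttention()))  # past_key_value
print(past_kv_param_name(OldStyleAttention()))  # past_key_values
```

Caching the answer on the module (as the patch does with `module._past_kv_name`) avoids re-running `inspect.signature` on every forward hook invocation, which matters because the hook fires once per calibration batch.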