diff --git a/auto_round/autoround.py b/auto_round/autoround.py
index 2c30b704..68363d9d 100644
--- a/auto_round/autoround.py
+++ b/auto_round/autoround.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import contextlib
 import copy
 import os
 import re
@@ -296,6 +297,12 @@ def __init__(
                 f" match the specified 'act_bits' setting. Resetting 'act_bits' to {tmp_act_bits}."
             )
 
+        # kv cache
+        static_kv_dtype = kwargs.pop("static_kv_dtype", None)
+        self.static_kv_dtype = static_kv_dtype
+        if self.static_kv_dtype is not None:
+            logger.warning("Static kv cache quantization is experimental and currently has limited support.")
+
         self.sampler = sampler
         self.not_use_best_mse = not_use_best_mse
         self.dynamic_max_gap = dynamic_max_gap
@@ -742,8 +749,13 @@ def quantize_and_save(self, output_dir: str = "tmp_autoround", format: str = "au
         kwargs.pop("inplace", None)
 
         # Perform model quantization
-        model, _ = self.quantize()
+        if self.static_kv_dtype is not None:
+            from auto_round.experimental.kv_cache import kvcache_quant_context
+            with kvcache_quant_context(self.model, static_kv_dtype=self.static_kv_dtype):
+                model, _ = self.quantize()
+        else:
+            model, _ = self.quantize()
 
         # Save the quantized model in the specified format_list
         folders = []
         for format in format_list:
@@ -1334,7 +1346,6 @@ def quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str]) -
                 add_hook_to_module(m, hook, True)
             else:
                 block = block.to(self.device)
-
             input_ids = self.get_block_outputs(
                 block,
                 input_ids,
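
Illustrative usage (not part of the patch): a minimal sketch of how the new `static_kv_dtype` kwarg is expected to flow into `__init__` via `**kwargs` and enable the KV-cache calibration path in `quantize_and_save`. The model choice and the surrounding quantization arguments are assumptions modeled on the static-FP8 export test further below, not a confirmed recipe.

from transformers import AutoModelForCausalLM, AutoTokenizer

from auto_round import AutoRound

model_name = "facebook/opt-125m"  # placeholder model, used only for illustration
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

autoround = AutoRound(
    model,
    tokenizer,
    bits=8,                 # assumed static-FP8 settings, mirroring the test below
    act_bits=8,
    data_type="fp8",
    act_data_type="fp8",
    act_dynamic=False,
    act_group_size=0,
    static_kv_dtype="fp8",  # new kwarg; "fp8" normalizes to torch.float8_e4m3fn
)
autoround.quantize_and_save(output_dir="./saved", format="auto_round")
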
diff --git a/auto_round/experimental/kv_cache.py b/auto_round/experimental/kv_cache.py
new file mode 100644
index 00000000..8a49f307
--- /dev/null
+++ b/auto_round/experimental/kv_cache.py
@@ -0,0 +1,304 @@
+# Copyright (c) 2025 Red Hat AI, vLLM Project and Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# NOTICE: The design is adapted from:
+# https://github.com/vllm-project/llm-compressor/blob/main/src/llmcompressor/modifiers/quantization/cache.py
+
+
+import contextlib
+from enum import Enum
+from functools import partial
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from transformers.cache_utils import DynamicCache
+
+from auto_round.utils import logger
+
+__all__ = [
+    "initialize_quantized_kv_cache",
+    "prep_attention_module_for_calibration",
+    "freeze_module_quantization_",
+    "kvcache_quant_context",
+]
+
+
+def freeze_module_quantization_(module: torch.nn.Module):
+    """
+    Delete observers from a module once calibration is complete.
+
+    Apply to the full model with `model.apply(freeze_module_quantization_)`.
+
+    :param module: module to freeze quantization for
+    """
+    # remove observers if needed
+    for name in ("input", "weight", "output"):
+        obs_name = f"{name}_observer"
+        if hasattr(module, obs_name):
+            delattr(module, obs_name)
+
+    # remove quantized kv_cache
+    kv_cache = getattr(module, "kv_cache", None)
+    if isinstance(kv_cache, QuantizedKVParameterCache):
+        delattr(module, "kv_cache")
+
+
+class KVCacheScaleType(Enum):
+    KEY = "k_scale"
+    VALUE = "v_scale"
+
+
+# NOTE: The trailing underscore denotes that `lst` is modified in place.
+def _pad_and_append_at_idx_(lst: List, idx: int, val: Any) -> list:
+    """
+    Append value `val` to list `lst` at index `idx`, right-padding if necessary.
+    Needed because the user may ignore some layers in the configuration, meaning
+    len(lst) <= idx - 1.
+
+    >>> _pad_and_append_at_idx_([0, 1, 2], 5, 5)
+    [0, 1, 2, None, None, 5]
+    >>> _pad_and_append_at_idx_([0, 1, 2], 3, 8)
+    [0, 1, 2, 8]
+    >>> _pad_and_append_at_idx_([0, 1, 2], 1, 5)
+    [0, 5, 2]
+    """
+    num_to_pad = idx - len(lst) + 1
+    if num_to_pad > 0:
+        lst += [None] * num_to_pad
+    lst[idx] = val
+    return lst
+
+
+def fp8_per_tensor_qdq(tensor):
+    from auto_round.data_type.fp8 import quant_fp8_sym
+
+    qdq_tensor, scale, _ = quant_fp8_sym(tensor, max_scale=1.0, tensor_max=None, group_size=0, v=0)
+    return qdq_tensor, scale
+
+
+class QuantizedKVParameterCache(DynamicCache):
+    """
+    Quantized KV cache used in the forward call, based on HF's DynamicCache.
+    Singleton, so the same cache is reused across all forward calls of self_attn.
+    Each time forward is called, .update() is called and ._quant_dequant() is applied.
+    The tensor size is
+    `[batch_size, num_heads, seq_len - residual_length, head_dim]`.
+    """
+
+    _instance = None
+    _initialized = False
+
+    def __new__(cls, *args, **kwargs):
+        """Singleton"""
+        if cls._instance is None:
+            cls._instance = super(QuantizedKVParameterCache, cls).__new__(cls)
+        return cls._instance
+
+    def __init__(self, dtype: torch.dtype = torch.float8_e4m3fn):
+        assert dtype == torch.float8_e4m3fn, "Only fp8_e4m3fn is supported for now."
+        if not self._initialized:
+            super().__init__()
+            # each index corresponds to the layer_idx of the attention layer
+            self.k_scales: List[torch.Tensor] = []
+            self.v_scales: List[torch.Tensor] = []
+            self._initialized = True
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Compute k_scale and v_scale and return the quant-dequant key_states and value_states.
+        """
+        qdq_key_states = self._quant_dequant(key_states.contiguous(), KVCacheScaleType.KEY, layer_idx)
+        qdq_value_states = self._quant_dequant(value_states.contiguous(), KVCacheScaleType.VALUE, layer_idx)
+
+        keys_to_return, values_to_return = qdq_key_states, qdq_value_states
+
+        return keys_to_return, values_to_return
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """
+        Return the sequence length of the cached states.
+        A layer index can optionally be passed.
+        """
+        if len(self.key_cache) <= layer_idx:
+            return 0
+        # Since we cannot get the seq_length of each layer directly and instead
+        # rely on `_seen_tokens`, which is updated every time "layer_idx" == 0,
+        # this is a hack to get the actual seq_length for the given layer_idx.
+        # This otherwise fails when used to verify the attn_weight shape in some models.
+        return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1
+
+    def reset_states(self):
+        """Reset the kv states (used in calibration)."""
+        self.key_cache: List[torch.Tensor] = []
+        self.value_cache: List[torch.Tensor] = []
+        # Used in `generate` to keep a tally of how many tokens the cache has seen
+        self._seen_tokens = 0
+        self._quantized_key_cache: List[torch.Tensor] = []
+        self._quantized_value_cache: List[torch.Tensor] = []
+
+    def reset(self):
+        """
+        Reset the instantiation; a new instance is created on the next init.
+        """
+        QuantizedKVParameterCache._instance = None
+        QuantizedKVParameterCache._initialized = False
+
+    def _quant_dequant(self, tensor: torch.Tensor, kv_type: KVCacheScaleType, layer_idx: int):
+        """Quantize-dequantize a key/value tensor and record its per-tensor scale."""
+        if kv_type == KVCacheScaleType.KEY:  # key type
+            scales = self.k_scales
+        else:
+            assert kv_type == KVCacheScaleType.VALUE
+            scales = self.v_scales
+
+        qdq_tensor, scale = fp8_per_tensor_qdq(tensor)
+        _pad_and_append_at_idx_(scales, layer_idx, scale)
+        return qdq_tensor
+
+
+def initialize_quantized_kv_cache(module: torch.nn.Module, dtype=torch.float8_e4m3fn):
+    """
+    Initialize a quantized kv_cache on a module (analogous to initializing an observer).
+    """
+    if not is_attention_module(module):
+        return
+    existing_kv_cache = getattr(module, "kv_cache", None)
+
+    if isinstance(existing_kv_cache, QuantizedKVParameterCache):
+        return
+
+    quantized_kv_cache = QuantizedKVParameterCache(dtype=dtype)
+    setattr(module, "kv_cache", quantized_kv_cache)
+    logger.debug(
+        f"Initialized quantized kv_cache for {module.__class__.__name__} {getattr(module, 'layer_idx', None)}"
+    )
+
+
+def is_attention_module(module: torch.nn.Module):
+    # FIXME: Handle this better.
+    return "attention" in module.__class__.__name__.lower() and (
+        hasattr(module, "k_proj") or hasattr(module, "v_proj") or hasattr(module, "qkv_proj")
+    )
+
+
+def calibrate_kv_cache_input_hook(
+    module: torch.nn.Module, args: Any, kwargs: Dict[str, Any]
+) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
+    """
+    Hook to update the inputs to attention layers when running kv_cache quantization.
+    Replaces the passed-in kv_cache with the singleton QuantizedKVParameterCache.
+    """
+    logger.debug(
+        f"calibrate kv_cache input hook for {module.__class__.__name__} {getattr(module, 'layer_idx', None)}"
+    )
+    kv_cache = getattr(module, "kv_cache")
+    # Starting from transformers 4.55.2, `past_key_value` was renamed to `past_key_values`.
+    # https://github.com/huggingface/transformers/blob/52c6c1bb6e27ca87c4faede34a4c2a7404c17c4d/src/transformers/models/llama/modeling_llama.py#L279-L280
+    if "past_key_values" in kwargs:
+        kwargs["past_key_values"] = kv_cache
+    else:
+        kwargs["past_key_value"] = kv_cache
+    kwargs["use_cache"] = False
+    return args, kwargs
+
+
+def update_parameter_data(module: torch.nn.Module, new_val: torch.Tensor, name: str):
+    """
+    Update the data of a parameter in a module.
+    If the parameter does not exist, it will be created.
+    """
+    if hasattr(module, name):
+        param = getattr(module, name)
+        if isinstance(param, torch.nn.Parameter):
+            param.data = new_val
+        else:
+            module.register_parameter(name, torch.nn.Parameter(new_val))
+    else:
+        logger.warning(
+            "Parameter %s not found in module %s, creating new parameter."
+            % (name, module.__class__.__name__ + str(getattr(module, "layer_idx", "")))
+        )
+        module.register_parameter(name, torch.nn.Parameter(new_val))
+
+
+def calibrate_kv_cache_output_hook(module: torch.nn.Module, _args: Any, _output: torch.Tensor):
+    """
+    Hook to update the k_scale and v_scale parameters when running kv_cache quantization.
+    """
+    logger.debug(
+        "Calibrate kv_cache output hook for %s %s"
+        % (module.__class__.__name__, str(getattr(module, "layer_idx", None)))
+    )
+    kv_cache = getattr(module, "kv_cache")
+    k_scale = kv_cache.k_scales[module.layer_idx]
+    v_scale = kv_cache.v_scales[module.layer_idx]
+    update_parameter_data(module, k_scale, KVCacheScaleType.KEY.value)
+    update_parameter_data(module, v_scale, KVCacheScaleType.VALUE.value)
+
+
+def prep_attention_module_for_calibration(module: torch.nn.Module):
+    if is_attention_module(module):
+        module.register_forward_pre_hook(calibrate_kv_cache_input_hook, with_kwargs=True)
+        module.register_forward_hook(calibrate_kv_cache_output_hook)
+
+
+def normalize_static_kv_dtype(static_kv_dtype: Union[str, torch.dtype]) -> torch.dtype:
+    valid_dtype_name_lst = ["float16", "bfloat16", "fp8", "float32", "float"]
+    valid_torch_dtype = {
+        "float16": torch.float16,
+        "bfloat16": torch.bfloat16,
+        "fp8": torch.float8_e4m3fn,
+        "float8_e4m3fn": torch.float8_e4m3fn,
+        "float32": torch.float32,
+        "float": torch.float32,  # Alias for float32
+    }
+    if static_kv_dtype in valid_dtype_name_lst:
+        new_dtype = valid_torch_dtype[static_kv_dtype]
+    elif static_kv_dtype in valid_torch_dtype.values():
+        new_dtype = static_kv_dtype
+    else:
+        raise ValueError(
+            f"Invalid static kv dtype: {static_kv_dtype}. "
+            f"Valid options are: {', '.join(valid_dtype_name_lst + [str(d) for d in valid_torch_dtype.values()])}."
+        )
+    return new_dtype
+
+
+@contextlib.contextmanager
+def kvcache_quant_context(model: torch.nn.Module, static_kv_dtype=torch.float8_e4m3fn):
+    """Context manager for FP8 KV cache quantization operations."""
+    try:
+        # Setup phase: initialize the KV cache for quantization
+        static_kv_dtype = normalize_static_kv_dtype(static_kv_dtype)
+        if static_kv_dtype != torch.float8_e4m3fn:
+            logger.warning(f"Ignoring static kv dtype {static_kv_dtype}, only fp8_e4m3fn is supported.")
+        else:
+            initialize_fn = partial(initialize_quantized_kv_cache, dtype=static_kv_dtype)
+            model.apply(initialize_fn)
+            model.apply(prep_attention_module_for_calibration)
+
+        # Provide the model to the with block
+        yield model
+
+    finally:
+        # Cleanup phase: freeze quantization parameters
+        model.apply(freeze_module_quantization_)
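
Illustrative usage (not part of the patch): a minimal sketch of calling `kvcache_quant_context` directly. Inside the context, forward passes run with the singleton `QuantizedKVParameterCache` injected by the pre-forward hook, so each attention module collects per-tensor `k_scale` / `v_scale` parameters; on exit, `freeze_module_quantization_` removes the temporary `kv_cache` attributes but leaves the scales in place. The model name and the attribute path in the final print are placeholders that assume an OPT-style layout.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from auto_round.experimental.kv_cache import kvcache_quant_context

model_name = "facebook/opt-125m"  # placeholder
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
batch = tokenizer("There is a girl who likes adventure,", return_tensors="pt")

with kvcache_quant_context(model, static_kv_dtype="fp8") as calib_model:
    with torch.no_grad():
        calib_model(**batch)  # calibration pass records per-layer k/v scales

print(model.model.decoder.layers[0].self_attn.k_scale)  # assumes OPT-style module naming
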
diff --git a/test/test_cpu/requirements.txt b/test/test_cpu/requirements.txt
index 265ba7e9..9ce86165 100644
--- a/test/test_cpu/requirements.txt
+++ b/test/test_cpu/requirements.txt
@@ -3,3 +3,4 @@ modelscope
 gguf
 torchvision
 compressed-tensors
+parameterized
+ ) + return new_dtype + + +@contextlib.contextmanager +def kvcache_quant_context(model: torch.nn.Module, static_kv_dtype=torch.float8_e4m3fn): + """Context manager for FP8 KV cache quantization operations.""" + try: + # Setup phase: Initialize KV cache for quantization + static_kv_dtype = normalize_static_kv_dtype(static_kv_dtype) + if static_kv_dtype != torch.float8_e4m3fn: + logger.warning(f"Ignoring static kv dtype {static_kv_dtype}, only fp8_e4m3fn is supported.") + else: + initialize_fn = partial(initialize_quantized_kv_cache, dtype=static_kv_dtype) + model.apply(initialize_fn) + model.apply(prep_attention_module_for_calibration) + + # Provide the model to the with block + yield model + + finally: + # Cleanup phase: Freeze quantization parameters + model.apply(freeze_module_quantization_) diff --git a/test/test_cpu/requirements.txt b/test/test_cpu/requirements.txt index 265ba7e9..9ce86165 100644 --- a/test/test_cpu/requirements.txt +++ b/test/test_cpu/requirements.txt @@ -3,3 +3,4 @@ modelscope gguf torchvision compressed-tensors +parameterized diff --git a/test/test_cpu/test_export.py b/test/test_cpu/test_export.py index b60c87f2..72a1e68c 100644 --- a/test/test_cpu/test_export.py +++ b/test/test_cpu/test_export.py @@ -2,6 +2,8 @@ import sys import unittest +from parameterized import parameterized + sys.path.insert(0, "../..") import torch from transformers import AutoModelForCausalLM, AutoRoundConfig, AutoTokenizer @@ -199,7 +201,8 @@ def test_autoround_3bit_sym_format(self): print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0])) shutil.rmtree(quantized_model_path, ignore_errors=True) - def test_static_afp8_export(self): + @parameterized.expand([(None,), ("fp8",), ("float16")]) + def test_static_afp8_export(self, static_kv_dtype): import os from safetensors import safe_open @@ -218,6 +221,7 @@ def test_static_afp8_export(self): act_data_type="fp8", act_dynamic=False, act_group_size=0, + static_kv_dtype=static_kv_dtype, ) quantized_model_path = "./saved" autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") @@ -226,6 +230,13 @@ def test_static_afp8_export(self): self.assertIn("model.decoder.layers.8.self_attn.k_proj.weight_scale", f.keys()) self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.v_proj.input_scale").shape, torch.Size([1, 1])) self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.v_proj.weight").dtype, torch.float8_e4m3fn) + + if static_kv_dtype == "fp8": + self.assertIn("model.decoder.layers.8.self_attn.k_scale", f.keys()) + self.assertIn("model.decoder.layers.8.self_attn.v_scale", f.keys()) + self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.v_scale").shape, torch.Size([1, 1])) + self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.k_scale").shape, torch.Size([1, 1])) + self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.k_scale").dtype, torch.float32) shutil.rmtree(quantized_model_path, ignore_errors=True) model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True)