From 75da6b616c8c67a6f65d3c078c319d1c97fbeb09 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Thu, 14 Aug 2025 22:43:10 -0400 Subject: [PATCH 01/15] add fp8 kv Signed-off-by: yiliu30 --- auto_round/autoround.py | 1 - .../experimental/auto_round_fp8_kv_example.py | 108 ++++ auto_round/experimental/fp8_kv_cache.py | 540 ++++++++++++++++++ 3 files changed, 648 insertions(+), 1 deletion(-) create mode 100644 auto_round/experimental/auto_round_fp8_kv_example.py create mode 100644 auto_round/experimental/fp8_kv_cache.py diff --git a/auto_round/autoround.py b/auto_round/autoround.py index 3c0b0724..69942334 100644 --- a/auto_round/autoround.py +++ b/auto_round/autoround.py @@ -1306,7 +1306,6 @@ def quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str]) - add_hook_to_module(m, hook, True) else: block = block.to(self.device) - input_ids = self.get_block_outputs( block, input_ids, diff --git a/auto_round/experimental/auto_round_fp8_kv_example.py b/auto_round/experimental/auto_round_fp8_kv_example.py new file mode 100644 index 00000000..66cbfb52 --- /dev/null +++ b/auto_round/experimental/auto_round_fp8_kv_example.py @@ -0,0 +1,108 @@ +# Copyright (c) 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys + +import torch +from fp8_kv_cache import ( + freeze_module_quantization_, + initialize_quantized_kv_cache, + prep_attention_module_for_calibration, +) +from loguru import logger + +logger.add(sys.stderr, level="TRACE") + + +# Example +from transformers import AutoModelForCausalLM, AutoTokenizer + +# Select model and load it. +MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" + + +MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" +MODEL_ID = "/data5/yliu7/HF_HOME/meta-llama/Llama-3.2-1B-Instruct/" +model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto") +tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) + +model.eval() + +model.apply(initialize_quantized_kv_cache) +model.apply(prep_attention_module_for_calibration) + + +from transformers import AutoModelForCausalLM, AutoTokenizer + +from auto_round import AutoRound + +autoround = AutoRound( + model, + tokenizer, + bits=8, + group_size=-1, + iters=0, + act_bits=8, + nsamples=2, + data_type="fp8", + act_data_type="fp8", + act_dynamic=False, +) +model, qconfig = autoround.quantize() +assert model is not None, "Expected q_model to be not None" + + +model.apply(freeze_module_quantization_) + +for name, param in model.named_parameters(): + if "k_scale" in name or "v_scale" in name: + print(f"{name}: {param.shape}, {param.dtype}, {param.item()}") + + +################### + +# # Example +# from transformers import AutoModelForCausalLM, AutoTokenizer + +# # Select model and load it. 
+# MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" + + +# os.environ["LLM_COMPRESSOR_LOG_LEVEL"] = "DEBUG" + + +# MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" +# MODEL_ID = "/data5/yliu7/HF_HOME/meta-llama/Llama-3.2-1B-Instruct/" +# model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto") +# tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) + +# model.eval() + +# model.apply(initialize_quantized_kv_cache) +# model.apply(prep_attention_module_for_calibration) + +# sample = { +# name: torch.ones((1, 32)).long() +# for name in ["input_ids", "attention_mask", "labels"] +# } + +# with torch.no_grad(): +# _ = model(**sample) + +# breakpoint() +# model.apply(freeze_module_quantization_) + +# for name, param in model.named_parameters(): +# if "k_scale" in name or "v_scale" in name: +# print(f"{name}: {param.shape}, {param.dtype}, {param.item()}"): diff --git a/auto_round/experimental/fp8_kv_cache.py b/auto_round/experimental/fp8_kv_cache.py new file mode 100644 index 00000000..61795f1f --- /dev/null +++ b/auto_round/experimental/fp8_kv_cache.py @@ -0,0 +1,540 @@ +# Copyright (c) 2025 Red Hat AI, vLLM Project and Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# NOTICE: The design adapted from: +# https://github.com/vllm-project/llm-compressor/blob/main/src/llmcompressor/modifiers/quantization/cache.py + + +import os +from enum import Enum +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +import torch +from loguru import logger +from torch import FloatTensor, IntTensor, Tensor +from torch.nn import Module +from transformers.cache_utils import DynamicCache + +__all__ = [ + "initialize_quantized_kv_cache", + "prep_attention_module_for_calibration", + "freeze_module_quantization_", +] + +import functools +import sys + +logger.add(sys.stderr, level="TRACE") + +import packaging + + +def is_greater_or_equal_version(cur_version, deprecated_version_str): + deprecated_version = packaging.version.parse(deprecated_version_str) + current_version = packaging.version.parse(cur_version) + return current_version >= deprecated_version + + +def freeze_module_quantization_(module: Module): + """ + deletes observers when calibration is complete. 
+ + apply to full model with `model.apply(freeze_module_quantization_)` + + :param module: module to freeze quantization for + """ + + # remove observers + for name in ("input", "weight", "output"): + obs_name = f"{name}_observer" + if hasattr(module, obs_name): + delattr(module, obs_name) + + # remove quantized kv_cache + kv_cache = getattr(module, "kv_cache", None) + if isinstance(kv_cache, QuantizedKVParameterCache): + delattr(module, "kv_cache") + + +class KVCacheScaleType(Enum): + KEY = "k_scale" + VALUE = "v_scale" + + +def calculate_qparams( + min_vals: Tensor, + max_vals: Tensor, +) -> Tuple[FloatTensor, IntTensor]: + """ + :param min_vals: tensor of min value(s) to calculate scale(s) and zero point(s) + from + :param max_vals: tensor of max value(s) to calculate scale(s) and zero point(s) + from + :param quantization_args: settings to quantization + :param global_scale: additional global scale to scale the locally generated scale + currently only applied/supported for Fp4 + + :return: tuple of the calculated scale(s) and zero point(s). For FP4, the calculated + scale is of dtype FP8 + """ + # based on the implementations for consuming quantized values, + # 0.0 must always be representable within the quantized range + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + + device = min_vals.device + + # bit_min, bit_max = calculate_range(quantization_args, device) + fp8_info = torch.finfo(torch.float8_e4m3fn) + bit_min, bit_max = fp8_info.min, fp8_info.max + + bit_range = bit_max - bit_min + zp_dtype = min_vals.dtype + + max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals)) + scales = max_val_pos / (float(bit_range) / 2) + scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps) + # TODO: in the case of MoEs, the global_scale may also be 0/need to be clamped + + zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype) + + scales = (max_vals - min_vals) / float(bit_range) + scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps) + zero_points = bit_min - (min_vals / scales) + zero_points = torch.clamp(zero_points, bit_min, bit_max) + + return scales, zero_points + + +class MinMaxObserver(Module): + def __init__( + self, + name: str = "_observer", + averaging_constant: float = 0.01, + ): + super().__init__() + self.name = name + self._scale = None + self._zero_point = None + self.min_val = {} + self.max_val = {} + self.averaging_constant = averaging_constant + + @torch.no_grad() + def forward( + self, + observed: Tensor, + global_scale: Optional[Tensor] = None, + ) -> Tuple[FloatTensor, IntTensor]: + """ + maps directly to get_qparams + :param observed: optional observed tensor from which to calculate + quantization parameters + :param global_scale: optional scale to further scale local quantization scales + :return: tuple of scale and zero point based on last observed value + """ + return self.get_qparams( + observed=observed, + global_scale=global_scale, + ) + + def calculate_updated_min_max( + self, + observed: torch.Tensor, + reduce_dims: Optional[Tuple[int]] = None, + tensor_id: Optional[Any] = None, + ): + """ + Updates the observed min and max using a moving average smoothed by the + averaging_constant. Set the averaging_constant to 1.0 to disable averaging. 
+ + :param observed: observed tensor to calculate quantization parameters for + :param reduce_dims: optional tuple of dimensions to reduce along, + returned scale and zero point will be shaped (1,) along the + reduced dimensions + :param tensor_id: Optional id if different ranges of observed tensors are + passed, useful for sharding tensors by group_size + :return: updated min and max values + """ + tensor_id = tensor_id or "default" + + if not reduce_dims: + min_val, max_val = torch.aminmax(observed) + else: + min_val = torch.amin(observed, dim=reduce_dims, keepdims=True) + max_val = torch.amax(observed, dim=reduce_dims, keepdims=True) + + # early stopping, save some computation and memory + # if self.averaging_constant == 1.0: + # return min_val, max_val + + running_min_val = self.min_val.get(tensor_id, None) + running_max_val = self.max_val.get(tensor_id, None) + + if running_min_val is None or running_max_val is None: + updated_min_val = min_val + updated_max_val = max_val + else: + updated_min_val = running_min_val + self.averaging_constant * (min_val - running_min_val) + updated_max_val = running_max_val + self.averaging_constant * (max_val - running_max_val) + + self.min_val[tensor_id] = updated_min_val + self.max_val[tensor_id] = updated_max_val + return updated_min_val, updated_max_val + + def calculate_qparams( + self, + observed: torch.Tensor, + reduce_dims: Optional[Tuple[int]] = None, + tensor_id: Optional[Any] = None, + global_scale: Optional[torch.Tensor] = None, + ) -> Tuple[torch.FloatTensor, torch.IntTensor]: + """ + Generate a scale and zero-point using the observed min and max. + + :param observed: observed tensor to calculate quantization parameters for + :param reduce_dims: optional tuple of dimensions to reduce along, + returned scale and zero point will be shaped (1,) along the + reduced dimensions + :param tensor_id: Optional id if different ranges of observed tensors are + passed, useful for sharding tensors by group_size + :param global_scale: optional scale to further scale local quantization scales + :return: tuple of scale and zero point derived from the observed tensor + """ + + updated_min_val, updated_max_val = self.calculate_updated_min_max( + observed=observed, tensor_id=tensor_id, reduce_dims=reduce_dims + ) + return calculate_qparams( + min_vals=updated_min_val, + max_vals=updated_max_val, + ) + + def post_calculate_qparams(self) -> None: + """ + Run any logic specific to its observers after running calculate_qparams + """ + + def get_qparams( + self, + observed: Optional[Tensor] = None, + g_idx: Optional[Tensor] = None, + global_scale: Optional[Tensor] = None, + ) -> Tuple[FloatTensor, IntTensor]: + """ + Convenience function to wrap overwritten calculate_qparams + adds support to make observed tensor optional and support for tracking latest + calculated scale and zero point + + :param observed: optional observed tensor to calculate quantization parameters + from + :param g_idx: optional mapping from column index to group index + :param global_scale: optional scale to further scale local quantization scales + :return: tuple of scale and zero point based on last observed value + """ + self._scale, self._zero_point = self.calculate_qparams(observed) + return self._scale, self._zero_point + + def get_qparams_along_dim( + self, + observed, + dim: Union[int, Iterable[int]], + tensor_id: Optional[Any] = None, + global_scale: Optional[Tensor] = None, + ): + if isinstance(dim, int): + dim = [dim] + dim = set(dim) + + reduce_dims = tuple(idx for idx in 
range(observed.ndim) if idx not in dim) + return self.calculate_qparams( + observed, + reduce_dims=reduce_dims, + tensor_id=tensor_id, + global_scale=global_scale, + ) + + def reset(self): + """ + Reset the state of the observer + """ + self._scale = None + self._zero_point = None + + +# NOTE: Using _ suffix to denote l is modified in place +def _pad_and_append_at_idx_(lst: List, idx: int, val: Any) -> list: + """ + Append value val to list lst at index idx, right padding if necessary + Needed because user may ignore some layers in configuration, meaning + len(lst) <= idx-1 + + >>> _pad_and_append_at_idx_([0,1,2], 5, 5) + [0, 1, 2, None, None, 5] + >>> _pad_and_append_at_idx_([0,1,2], 3, 8) + [0, 1, 2, 8] + >>> _pad_and_append_at_idx_([0,1,2], 1, 5) + [0, 5, 2] + """ + num_to_pad = idx - len(lst) + 1 + if num_to_pad > 0: + lst += [None] * num_to_pad + lst[idx] = val + return lst + + +class QuantizedKVParameterCache(DynamicCache): + """ + Quantized KV cache used in the forward call based on HF's dynamic cache. + Quantization strategy (tensor, group, channel) set from Quantization arg's strategy + Singleton, so that the same cache gets reused in all forward call of self_attn. + Each time forward is called, .update() is called, and ._quantize(), ._dequantize() + gets called appropriately. + The size of tensor is + `[batch_size, num_heads, seq_len - residual_length, head_dim]`. + + + # TODO: Triggered by adding kv_cache_scheme in ... + + """ + + _instance = None + _initialized = False + + def __new__(cls, *args, **kwargs): + """Singleton""" + if cls._instance is None: + cls._instance = super(QuantizedKVParameterCache, cls).__new__(cls) + return cls._instance + + def __init__(self, quantization_args=None): + if not self._initialized: + super().__init__() + + self.quantization_args = quantization_args + + self.k_observers: List[MinMaxObserver] = [] + self.v_observers: List[MinMaxObserver] = [] + + # each index corresponds to layer_idx of the attention layer + self.k_scales: List[Tensor] = [] + self.v_scales: List[Tensor] = [] + + self.k_zps: List[Tensor] = [] + self.v_zps: List[Tensor] = [] + + self._initialized = True + + def update( + self, + key_states: Tensor, + value_states: Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[Tensor, Tensor]: + """ + Get the k_scale and v_scale and output the + fakequant-ed key_states and value_states + """ + + if len(self.k_observers) <= layer_idx: + k_observer = MinMaxObserver(f"k_observer_{layer_idx}") + v_observer = MinMaxObserver(f"v_observer_{layer_idx}") + + # NOTE: User may ignore some layers in configuration, + # meaning len(self.k_observers) <= layer_idx-1 + # Must account for that case by padding list so that + # index of lists corresponds to layer_idx + _pad_and_append_at_idx_(self.k_observers, layer_idx, k_observer) + _pad_and_append_at_idx_(self.v_observers, layer_idx, v_observer) + # FIXME: Should we append the key_states/value_states to the cache? 
+ q_key_states = self._quantize(key_states.contiguous(), KVCacheScaleType.KEY, layer_idx) + q_value_states = self._quantize(value_states.contiguous(), KVCacheScaleType.VALUE, layer_idx) + + qdq_key_states = self._dequantize(q_key_states, KVCacheScaleType.KEY, layer_idx) + qdq_value_states = self._dequantize(q_value_states, KVCacheScaleType.VALUE, layer_idx) + + keys_to_return, values_to_return = qdq_key_states, qdq_value_states + + return keys_to_return, values_to_return + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """ + Returns the sequence length of the cached states. + A layer index can be optionally passed. + """ + if len(self.key_cache) <= layer_idx: + return 0 + # since we cannot get the seq_length of each layer directly and + # rely on `_seen_tokens` which is updated every "layer_idx" == 0, + # this is a hack to get the actual seq_length for the given layer_idx + # this part of code otherwise fails when used to + # verify attn_weight shape in some models + return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1 + + def reset_states(self): + """reset the kv states (used in calibration)""" + self.key_cache: List[Tensor] = [] + self.value_cache: List[Tensor] = [] + # Used in `generate` to keep tally of how many tokens the cache has seen + self._seen_tokens = 0 + self._quantized_key_cache: List[Tensor] = [] + self._quantized_value_cache: List[Tensor] = [] + + def reset(self): + """ + Reset the instantiation, create new instance on init + """ + QuantizedKVParameterCache._instance = None + QuantizedKVParameterCache._initialized = False + + def _quantize(self, tensor, kv_type, layer_idx): + """Quantizes a key/value using a defined quantization method.""" + # from compressed_tensors.quantization.lifecycle.forward import quantize + if kv_type == KVCacheScaleType.KEY: # key type + observer = self.k_observers[layer_idx] + scales = self.k_scales + zps = self.k_zps + else: + assert kv_type == KVCacheScaleType.VALUE + observer = self.v_observers[layer_idx] + scales = self.v_scales + zps = self.v_zps + + scale, zp = observer(tensor) + _pad_and_append_at_idx_(scales, layer_idx, scale) + _pad_and_append_at_idx_(zps, layer_idx, zp) + + def quantize_fp8_tensor(x, scale, zero_point): + fp8_info = torch.finfo(torch.float8_e4m3fn) + q_min, q_max = fp8_info.min, fp8_info.max + + scaled = x / scale + # clamp first because cast isn't guaranteed to be saturated (ie for fp8) + clamped_value = torch.clamp( + scaled, + q_min, + q_max, + ) + + # round + quantized_value = clamped_value.to(torch.float8_e4m3fn) + return quantized_value + + q_tensor = quantize_fp8_tensor(x=tensor, scale=scale, zero_point=zp) + return q_tensor + + def _dequantize(self, qtensor, kv_type, layer_idx): + """Dequantizes back the tensor that was quantized by `self._quantize()`""" + from compressed_tensors.quantization.lifecycle.forward import dequantize + + if kv_type == KVCacheScaleType.KEY: + scale = self.k_scales[layer_idx] + zp = self.k_zps[layer_idx] + else: + assert kv_type == KVCacheScaleType.VALUE + scale = self.v_scales[layer_idx] + zp = self.v_zps[layer_idx] + + qdq_tensor = dequantize( + x_q=qtensor, + scale=scale, + zero_point=zp, + args=self.quantization_args, + ) + return qdq_tensor + + +def initialize_quantized_kv_cache(module: Module): + """ + Initialize a quantized kv_cache on a module (analogous to initializing an observer) + """ + if not is_attention_module(module): + return + existing_kv_cache = getattr(module, "kv_cache", None) + + if isinstance(existing_kv_cache, 
QuantizedKVParameterCache): + return + + quantized_kv_cache = QuantizedKVParameterCache() + setattr(module, "kv_cache", quantized_kv_cache) + logger.trace(f"Initialized quantized kv_cache for {module.__class__.__name__} {getattr(module, 'layer_idx', None)}") + + +def is_attention_module(module: Module): + # FIXME: Handle this better. + return "attention" in module.__class__.__name__.lower() and ( + hasattr(module, "k_proj") or hasattr(module, "v_proj") or hasattr(module, "qkv_proj") + ) + + +def calibrate_kv_cache_input_hook( + module: Module, args: Any, kwargs: Dict[str, Any] +) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: + """ + Hook to update inputs to attention layers when running + kv_cache quantization. Will update the passed in + kv_cache to singleton QuantizedKVParameterCache. + """ + logger.trace(f"calibrate kv_cache input hook for {module.__class__.__name__} {getattr(module, 'layer_idx', None)}") + # breakpoint() + kv_cache = getattr(module, "kv_cache") + # Start from transformers 4.55.2, the `past_key_value` was renamed to `past_key_values`. + # https://github.com/huggingface/transformers/blob/52c6c1bb6e27ca87c4faede34a4c2a7404c17c4d/src/transformers/models/llama/modeling_llama.py#L279-L280 + if "past_key_values" in kwargs: + kwargs["past_key_values"] = kv_cache + else: + kwargs["past_key_value"] = kv_cache + kwargs["use_cache"] = False + return args, kwargs + + +def update_parameter_data(module, new_val, name: str): + """ + Update the data of a parameter in a module. + If the parameter does not exist, it will be created. + """ + if hasattr(module, name): + param = getattr(module, name) + if isinstance(param, torch.nn.Parameter): + param.data = new_val + else: + module.register_parameter(name, torch.nn.Parameter(new_val)) + else: + logger.warning( + "Parameter %s not found in module %s, creating new parameter." + % (name, module.__class__.__name__ + str(getattr(module, "layer_idx", ""))) + ) + module.register_parameter(name, torch.nn.Parameter(new_val)) + + +def calibrate_kv_cache_output_hook(module: Module, _args: Any, _output: torch.Tensor): + """ + Hook to update k_scale and v_scale parameters when running kv_cache quantization. + """ + logger.trace( + "Calibrate kv_cache output hook for %s %s" + % (module.__class__.__name__, str(getattr(module, "layer_idx", None))) + ) + kv_cache = getattr(module, "kv_cache") + k_scale = kv_cache.k_scales[module.layer_idx] + v_scale = kv_cache.v_scales[module.layer_idx] + update_parameter_data(module, k_scale, KVCacheScaleType.KEY.value) + update_parameter_data(module, v_scale, KVCacheScaleType.VALUE.value) + + +def prep_attention_module_for_calibration(module: torch.nn.Module): + if is_attention_module(module): + module.register_forward_pre_hook(calibrate_kv_cache_input_hook, with_kwargs=True) + module.register_forward_hook(calibrate_kv_cache_output_hook) From 4b3f36ade008b071bd285cb496b5437f9860e88f Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Thu, 14 Aug 2025 23:58:49 -0400 Subject: [PATCH 02/15] refactor Signed-off-by: yiliu30 --- auto_round/autoround.py | 16 +- auto_round/experimental/fp8_kv_cache.py | 323 +++--------------------- test/test_cpu/test_export.py | 13 +- 3 files changed, 65 insertions(+), 287 deletions(-) diff --git a/auto_round/autoround.py b/auto_round/autoround.py index 69942334..51135588 100644 --- a/auto_round/autoround.py +++ b/auto_round/autoround.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import contextlib import copy import os import re @@ -170,6 +171,7 @@ def __init__( act_sym: bool = None, act_data_type: str = None, act_dynamic: bool = True, + enable_fp8_kv: bool = False, enable_torch_compile: bool = False, device_map: Union[str, dict] = None, disable_opt_rtn: bool = False, @@ -293,6 +295,10 @@ def __init__( f" match the specified 'act_bits' setting. Resetting 'act_bits' to {tmp_act_bits}." ) + # kv cache + self.enable_fp8_kv = enable_fp8_kv + logger.warning("The `enable_fp8_kv` feature is experimental and currently has limited support.") + self.sampler = sampler self.not_use_best_mse = not_use_best_mse self.dynamic_max_gap = dynamic_max_gap @@ -732,7 +738,15 @@ def quantize_and_save(self, output_dir: str = "tmp_autoround", format: str = "au kwargs.pop("inplace", None) # Perform model quantization - model, _ = self.quantize() + if self.enable_fp8_kv: + from auto_round.experimental.fp8_kv_cache import fp8_kv_context + + quant_ctx = fp8_kv_context + else: + quant_ctx = contextlib.nullcontext + + with quant_ctx(self.model): + model, _ = self.quantize() # Save the quantized model in the specified format_list folders = [] diff --git a/auto_round/experimental/fp8_kv_cache.py b/auto_round/experimental/fp8_kv_cache.py index 61795f1f..eb9fbb01 100644 --- a/auto_round/experimental/fp8_kv_cache.py +++ b/auto_round/experimental/fp8_kv_cache.py @@ -17,6 +17,7 @@ import os +import sys from enum import Enum from typing import Any, Dict, Iterable, List, Optional, Tuple, Union @@ -26,25 +27,15 @@ from torch.nn import Module from transformers.cache_utils import DynamicCache +logger.add(sys.stderr, level="TRACE") + __all__ = [ "initialize_quantized_kv_cache", "prep_attention_module_for_calibration", "freeze_module_quantization_", + "fp8_kv_context", ] -import functools -import sys - -logger.add(sys.stderr, level="TRACE") - -import packaging - - -def is_greater_or_equal_version(cur_version, deprecated_version_str): - deprecated_version = packaging.version.parse(deprecated_version_str) - current_version = packaging.version.parse(cur_version) - return current_version >= deprecated_version - def freeze_module_quantization_(module: Module): """ @@ -55,7 +46,7 @@ def freeze_module_quantization_(module: Module): :param module: module to freeze quantization for """ - # remove observers + # remove observers if needed for name in ("input", "weight", "output"): obs_name = f"{name}_observer" if hasattr(module, obs_name): @@ -72,207 +63,6 @@ class KVCacheScaleType(Enum): VALUE = "v_scale" -def calculate_qparams( - min_vals: Tensor, - max_vals: Tensor, -) -> Tuple[FloatTensor, IntTensor]: - """ - :param min_vals: tensor of min value(s) to calculate scale(s) and zero point(s) - from - :param max_vals: tensor of max value(s) to calculate scale(s) and zero point(s) - from - :param quantization_args: settings to quantization - :param global_scale: additional global scale to scale the locally generated scale - currently only applied/supported for Fp4 - - :return: tuple of the calculated scale(s) and zero point(s). 
For FP4, the calculated - scale is of dtype FP8 - """ - # based on the implementations for consuming quantized values, - # 0.0 must always be representable within the quantized range - min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) - max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) - - device = min_vals.device - - # bit_min, bit_max = calculate_range(quantization_args, device) - fp8_info = torch.finfo(torch.float8_e4m3fn) - bit_min, bit_max = fp8_info.min, fp8_info.max - - bit_range = bit_max - bit_min - zp_dtype = min_vals.dtype - - max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals)) - scales = max_val_pos / (float(bit_range) / 2) - scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps) - # TODO: in the case of MoEs, the global_scale may also be 0/need to be clamped - - zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype) - - scales = (max_vals - min_vals) / float(bit_range) - scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps) - zero_points = bit_min - (min_vals / scales) - zero_points = torch.clamp(zero_points, bit_min, bit_max) - - return scales, zero_points - - -class MinMaxObserver(Module): - def __init__( - self, - name: str = "_observer", - averaging_constant: float = 0.01, - ): - super().__init__() - self.name = name - self._scale = None - self._zero_point = None - self.min_val = {} - self.max_val = {} - self.averaging_constant = averaging_constant - - @torch.no_grad() - def forward( - self, - observed: Tensor, - global_scale: Optional[Tensor] = None, - ) -> Tuple[FloatTensor, IntTensor]: - """ - maps directly to get_qparams - :param observed: optional observed tensor from which to calculate - quantization parameters - :param global_scale: optional scale to further scale local quantization scales - :return: tuple of scale and zero point based on last observed value - """ - return self.get_qparams( - observed=observed, - global_scale=global_scale, - ) - - def calculate_updated_min_max( - self, - observed: torch.Tensor, - reduce_dims: Optional[Tuple[int]] = None, - tensor_id: Optional[Any] = None, - ): - """ - Updates the observed min and max using a moving average smoothed by the - averaging_constant. Set the averaging_constant to 1.0 to disable averaging. 
- - :param observed: observed tensor to calculate quantization parameters for - :param reduce_dims: optional tuple of dimensions to reduce along, - returned scale and zero point will be shaped (1,) along the - reduced dimensions - :param tensor_id: Optional id if different ranges of observed tensors are - passed, useful for sharding tensors by group_size - :return: updated min and max values - """ - tensor_id = tensor_id or "default" - - if not reduce_dims: - min_val, max_val = torch.aminmax(observed) - else: - min_val = torch.amin(observed, dim=reduce_dims, keepdims=True) - max_val = torch.amax(observed, dim=reduce_dims, keepdims=True) - - # early stopping, save some computation and memory - # if self.averaging_constant == 1.0: - # return min_val, max_val - - running_min_val = self.min_val.get(tensor_id, None) - running_max_val = self.max_val.get(tensor_id, None) - - if running_min_val is None or running_max_val is None: - updated_min_val = min_val - updated_max_val = max_val - else: - updated_min_val = running_min_val + self.averaging_constant * (min_val - running_min_val) - updated_max_val = running_max_val + self.averaging_constant * (max_val - running_max_val) - - self.min_val[tensor_id] = updated_min_val - self.max_val[tensor_id] = updated_max_val - return updated_min_val, updated_max_val - - def calculate_qparams( - self, - observed: torch.Tensor, - reduce_dims: Optional[Tuple[int]] = None, - tensor_id: Optional[Any] = None, - global_scale: Optional[torch.Tensor] = None, - ) -> Tuple[torch.FloatTensor, torch.IntTensor]: - """ - Generate a scale and zero-point using the observed min and max. - - :param observed: observed tensor to calculate quantization parameters for - :param reduce_dims: optional tuple of dimensions to reduce along, - returned scale and zero point will be shaped (1,) along the - reduced dimensions - :param tensor_id: Optional id if different ranges of observed tensors are - passed, useful for sharding tensors by group_size - :param global_scale: optional scale to further scale local quantization scales - :return: tuple of scale and zero point derived from the observed tensor - """ - - updated_min_val, updated_max_val = self.calculate_updated_min_max( - observed=observed, tensor_id=tensor_id, reduce_dims=reduce_dims - ) - return calculate_qparams( - min_vals=updated_min_val, - max_vals=updated_max_val, - ) - - def post_calculate_qparams(self) -> None: - """ - Run any logic specific to its observers after running calculate_qparams - """ - - def get_qparams( - self, - observed: Optional[Tensor] = None, - g_idx: Optional[Tensor] = None, - global_scale: Optional[Tensor] = None, - ) -> Tuple[FloatTensor, IntTensor]: - """ - Convenience function to wrap overwritten calculate_qparams - adds support to make observed tensor optional and support for tracking latest - calculated scale and zero point - - :param observed: optional observed tensor to calculate quantization parameters - from - :param g_idx: optional mapping from column index to group index - :param global_scale: optional scale to further scale local quantization scales - :return: tuple of scale and zero point based on last observed value - """ - self._scale, self._zero_point = self.calculate_qparams(observed) - return self._scale, self._zero_point - - def get_qparams_along_dim( - self, - observed, - dim: Union[int, Iterable[int]], - tensor_id: Optional[Any] = None, - global_scale: Optional[Tensor] = None, - ): - if isinstance(dim, int): - dim = [dim] - dim = set(dim) - - reduce_dims = tuple(idx for idx in 
range(observed.ndim) if idx not in dim) - return self.calculate_qparams( - observed, - reduce_dims=reduce_dims, - tensor_id=tensor_id, - global_scale=global_scale, - ) - - def reset(self): - """ - Reset the state of the observer - """ - self._scale = None - self._zero_point = None - - # NOTE: Using _ suffix to denote l is modified in place def _pad_and_append_at_idx_(lst: List, idx: int, val: Any) -> list: """ @@ -294,6 +84,13 @@ def _pad_and_append_at_idx_(lst: List, idx: int, val: Any) -> list: return lst +def fp8_per_tensor_qdq(tensor): + from auto_round.data_type.fp8 import quant_fp8_sym + + qdq_tensor, scale, _ = quant_fp8_sym(tensor, max_scale=1.0, tensor_max=None, group_size=0, v=0) + return qdq_tensor, scale + + class QuantizedKVParameterCache(DynamicCache): """ Quantized KV cache used in the forward call based on HF's dynamic cache. @@ -324,16 +121,9 @@ def __init__(self, quantization_args=None): self.quantization_args = quantization_args - self.k_observers: List[MinMaxObserver] = [] - self.v_observers: List[MinMaxObserver] = [] - # each index corresponds to layer_idx of the attention layer self.k_scales: List[Tensor] = [] self.v_scales: List[Tensor] = [] - - self.k_zps: List[Tensor] = [] - self.v_zps: List[Tensor] = [] - self._initialized = True def update( @@ -344,26 +134,15 @@ def update( cache_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[Tensor, Tensor]: """ - Get the k_scale and v_scale and output the - fakequant-ed key_states and value_states + Get the k_scale and v_scale and output the quant-dequant key_states and value_states """ - if len(self.k_observers) <= layer_idx: - k_observer = MinMaxObserver(f"k_observer_{layer_idx}") - v_observer = MinMaxObserver(f"v_observer_{layer_idx}") - - # NOTE: User may ignore some layers in configuration, - # meaning len(self.k_observers) <= layer_idx-1 - # Must account for that case by padding list so that - # index of lists corresponds to layer_idx - _pad_and_append_at_idx_(self.k_observers, layer_idx, k_observer) - _pad_and_append_at_idx_(self.v_observers, layer_idx, v_observer) # FIXME: Should we append the key_states/value_states to the cache? 
- q_key_states = self._quantize(key_states.contiguous(), KVCacheScaleType.KEY, layer_idx) - q_value_states = self._quantize(value_states.contiguous(), KVCacheScaleType.VALUE, layer_idx) + # q_key_states = self._quantize(key_states.contiguous(), KVCacheScaleType.KEY, layer_idx) + # q_value_states = self._quantize(value_states.contiguous(), KVCacheScaleType.VALUE, layer_idx) - qdq_key_states = self._dequantize(q_key_states, KVCacheScaleType.KEY, layer_idx) - qdq_value_states = self._dequantize(q_value_states, KVCacheScaleType.VALUE, layer_idx) + qdq_key_states = self._quant_dequant(key_states.contiguous(), KVCacheScaleType.KEY, layer_idx) + qdq_value_states = self._quant_dequant(value_states.contiguous(), KVCacheScaleType.VALUE, layer_idx) keys_to_return, values_to_return = qdq_key_states, qdq_value_states @@ -399,60 +178,16 @@ def reset(self): QuantizedKVParameterCache._instance = None QuantizedKVParameterCache._initialized = False - def _quantize(self, tensor, kv_type, layer_idx): + def _quant_dequant(self, tensor, kv_type, layer_idx): """Quantizes a key/value using a defined quantization method.""" - # from compressed_tensors.quantization.lifecycle.forward import quantize if kv_type == KVCacheScaleType.KEY: # key type - observer = self.k_observers[layer_idx] scales = self.k_scales - zps = self.k_zps else: assert kv_type == KVCacheScaleType.VALUE - observer = self.v_observers[layer_idx] scales = self.v_scales - zps = self.v_zps - scale, zp = observer(tensor) + qdq_tensor, scale = fp8_per_tensor_qdq(tensor) _pad_and_append_at_idx_(scales, layer_idx, scale) - _pad_and_append_at_idx_(zps, layer_idx, zp) - - def quantize_fp8_tensor(x, scale, zero_point): - fp8_info = torch.finfo(torch.float8_e4m3fn) - q_min, q_max = fp8_info.min, fp8_info.max - - scaled = x / scale - # clamp first because cast isn't guaranteed to be saturated (ie for fp8) - clamped_value = torch.clamp( - scaled, - q_min, - q_max, - ) - - # round - quantized_value = clamped_value.to(torch.float8_e4m3fn) - return quantized_value - - q_tensor = quantize_fp8_tensor(x=tensor, scale=scale, zero_point=zp) - return q_tensor - - def _dequantize(self, qtensor, kv_type, layer_idx): - """Dequantizes back the tensor that was quantized by `self._quantize()`""" - from compressed_tensors.quantization.lifecycle.forward import dequantize - - if kv_type == KVCacheScaleType.KEY: - scale = self.k_scales[layer_idx] - zp = self.k_zps[layer_idx] - else: - assert kv_type == KVCacheScaleType.VALUE - scale = self.v_scales[layer_idx] - zp = self.v_zps[layer_idx] - - qdq_tensor = dequantize( - x_q=qtensor, - scale=scale, - zero_point=zp, - args=self.quantization_args, - ) return qdq_tensor @@ -488,7 +223,6 @@ def calibrate_kv_cache_input_hook( kv_cache to singleton QuantizedKVParameterCache. """ logger.trace(f"calibrate kv_cache input hook for {module.__class__.__name__} {getattr(module, 'layer_idx', None)}") - # breakpoint() kv_cache = getattr(module, "kv_cache") # Start from transformers 4.55.2, the `past_key_value` was renamed to `past_key_values`. 
# https://github.com/huggingface/transformers/blob/52c6c1bb6e27ca87c4faede34a4c2a7404c17c4d/src/transformers/models/llama/modeling_llama.py#L279-L280 @@ -538,3 +272,22 @@ def prep_attention_module_for_calibration(module: torch.nn.Module): if is_attention_module(module): module.register_forward_pre_hook(calibrate_kv_cache_input_hook, with_kwargs=True) module.register_forward_hook(calibrate_kv_cache_output_hook) + + +import contextlib + + +@contextlib.contextmanager +def fp8_kv_context(model): + """Context manager for FP8 KV cache quantization operations.""" + try: + # Setup phase: Initialize KV cache for quantization + model.apply(initialize_quantized_kv_cache) + model.apply(prep_attention_module_for_calibration) + + # Provide the model to the with block + yield model + + finally: + # Cleanup phase: Freeze quantization parameters + model.apply(freeze_module_quantization_) diff --git a/test/test_cpu/test_export.py b/test/test_cpu/test_export.py index bbce4036..2e889dd0 100644 --- a/test/test_cpu/test_export.py +++ b/test/test_cpu/test_export.py @@ -2,6 +2,8 @@ import sys import unittest +from parameterized import parameterized + sys.path.insert(0, "../..") import torch from transformers import AutoModelForCausalLM, AutoRoundConfig, AutoTokenizer @@ -199,7 +201,8 @@ def test_autoround_3bit_sym_format(self): print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0])) shutil.rmtree(quantized_model_path, ignore_errors=True) - def test_static_afp8_export(self): + @parameterized.expand([(True,), (False,)]) + def test_static_afp8_export(self, enable_fp8_kv): import os from safetensors import safe_open @@ -218,6 +221,7 @@ def test_static_afp8_export(self): act_data_type="fp8", act_dynamic=False, act_group_size=0, + enable_fp8_kv=enable_fp8_kv, ) quantized_model_path = "./saved" autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") @@ -226,6 +230,13 @@ def test_static_afp8_export(self): self.assertIn("model.decoder.layers.8.self_attn.k_proj.weight_scale", f.keys()) self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.v_proj.input_scale").shape, torch.Size([1, 1])) self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.v_proj.weight").dtype, torch.float8_e4m3fn) + + if enable_fp8_kv: + self.assertIn("model.decoder.layers.8.self_attn.k_scale", f.keys()) + self.assertIn("model.decoder.layers.8.self_attn.v_scale", f.keys()) + self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.v_scale").shape, torch.Size([1, 1])) + self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.k_scale").shape, torch.Size([1, 1])) + self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.k_scale").dtype, torch.float32) shutil.rmtree(quantized_model_path, ignore_errors=True) model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True) From eff6a346a51e6b9b6abd24518b721b0d7e9a8b01 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Thu, 14 Aug 2025 23:59:25 -0400 Subject: [PATCH 03/15] clean code Signed-off-by: yiliu30 --- .../experimental/auto_round_fp8_kv_example.py | 108 ------------------ 1 file changed, 108 deletions(-) delete mode 100644 auto_round/experimental/auto_round_fp8_kv_example.py diff --git a/auto_round/experimental/auto_round_fp8_kv_example.py b/auto_round/experimental/auto_round_fp8_kv_example.py deleted file mode 100644 index 66cbfb52..00000000 --- a/auto_round/experimental/auto_round_fp8_kv_example.py +++ /dev/null @@ -1,108 +0,0 @@ -# Copyright (c) 2025 Intel Corporation -# -# Licensed under 
the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import sys - -import torch -from fp8_kv_cache import ( - freeze_module_quantization_, - initialize_quantized_kv_cache, - prep_attention_module_for_calibration, -) -from loguru import logger - -logger.add(sys.stderr, level="TRACE") - - -# Example -from transformers import AutoModelForCausalLM, AutoTokenizer - -# Select model and load it. -MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" - - -MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" -MODEL_ID = "/data5/yliu7/HF_HOME/meta-llama/Llama-3.2-1B-Instruct/" -model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto") -tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) - -model.eval() - -model.apply(initialize_quantized_kv_cache) -model.apply(prep_attention_module_for_calibration) - - -from transformers import AutoModelForCausalLM, AutoTokenizer - -from auto_round import AutoRound - -autoround = AutoRound( - model, - tokenizer, - bits=8, - group_size=-1, - iters=0, - act_bits=8, - nsamples=2, - data_type="fp8", - act_data_type="fp8", - act_dynamic=False, -) -model, qconfig = autoround.quantize() -assert model is not None, "Expected q_model to be not None" - - -model.apply(freeze_module_quantization_) - -for name, param in model.named_parameters(): - if "k_scale" in name or "v_scale" in name: - print(f"{name}: {param.shape}, {param.dtype}, {param.item()}") - - -################### - -# # Example -# from transformers import AutoModelForCausalLM, AutoTokenizer - -# # Select model and load it. 
-# MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" - - -# os.environ["LLM_COMPRESSOR_LOG_LEVEL"] = "DEBUG" - - -# MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" -# MODEL_ID = "/data5/yliu7/HF_HOME/meta-llama/Llama-3.2-1B-Instruct/" -# model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto") -# tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) - -# model.eval() - -# model.apply(initialize_quantized_kv_cache) -# model.apply(prep_attention_module_for_calibration) - -# sample = { -# name: torch.ones((1, 32)).long() -# for name in ["input_ids", "attention_mask", "labels"] -# } - -# with torch.no_grad(): -# _ = model(**sample) - -# breakpoint() -# model.apply(freeze_module_quantization_) - -# for name, param in model.named_parameters(): -# if "k_scale" in name or "v_scale" in name: -# print(f"{name}: {param.shape}, {param.dtype}, {param.item()}"): From b13c1cc6456b71c642d2a817609f2ba6a2bb4a55 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Fri, 15 Aug 2025 04:21:50 -0400 Subject: [PATCH 04/15] update name Signed-off-by: yiliu30 --- auto_round/autoround.py | 8 ++++---- test/test_cpu/test_export.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/auto_round/autoround.py b/auto_round/autoround.py index 51135588..1edb54c5 100644 --- a/auto_round/autoround.py +++ b/auto_round/autoround.py @@ -171,7 +171,7 @@ def __init__( act_sym: bool = None, act_data_type: str = None, act_dynamic: bool = True, - enable_fp8_kv: bool = False, + enable_static_fp8_kv: bool = False, enable_torch_compile: bool = False, device_map: Union[str, dict] = None, disable_opt_rtn: bool = False, @@ -296,8 +296,8 @@ def __init__( ) # kv cache - self.enable_fp8_kv = enable_fp8_kv - logger.warning("The `enable_fp8_kv` feature is experimental and currently has limited support.") + self.enable_static_fp8_kv = enable_static_fp8_kv + logger.warning("The `enable_static_fp8_kv` feature is experimental and currently has limited support.") self.sampler = sampler self.not_use_best_mse = not_use_best_mse @@ -738,7 +738,7 @@ def quantize_and_save(self, output_dir: str = "tmp_autoround", format: str = "au kwargs.pop("inplace", None) # Perform model quantization - if self.enable_fp8_kv: + if self.enable_static_fp8_kv: from auto_round.experimental.fp8_kv_cache import fp8_kv_context quant_ctx = fp8_kv_context diff --git a/test/test_cpu/test_export.py b/test/test_cpu/test_export.py index 2e889dd0..36a7fe90 100644 --- a/test/test_cpu/test_export.py +++ b/test/test_cpu/test_export.py @@ -202,7 +202,7 @@ def test_autoround_3bit_sym_format(self): shutil.rmtree(quantized_model_path, ignore_errors=True) @parameterized.expand([(True,), (False,)]) - def test_static_afp8_export(self, enable_fp8_kv): + def test_static_afp8_export(self, enable_static_fp8_kv): import os from safetensors import safe_open @@ -221,7 +221,7 @@ def test_static_afp8_export(self, enable_fp8_kv): act_data_type="fp8", act_dynamic=False, act_group_size=0, - enable_fp8_kv=enable_fp8_kv, + enable_static_fp8_kv=enable_static_fp8_kv, ) quantized_model_path = "./saved" autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") @@ -231,7 +231,7 @@ def test_static_afp8_export(self, enable_fp8_kv): self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.v_proj.input_scale").shape, torch.Size([1, 1])) self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.v_proj.weight").dtype, torch.float8_e4m3fn) - if enable_fp8_kv: + if enable_static_fp8_kv: self.assertIn("model.decoder.layers.8.self_attn.k_scale", f.keys()) 
self.assertIn("model.decoder.layers.8.self_attn.v_scale", f.keys()) self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.v_scale").shape, torch.Size([1, 1])) From 0cf5807766bba7cf0f2ec627e9ceac1cfff37a35 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Wed, 20 Aug 2025 01:42:40 -0400 Subject: [PATCH 05/15] format code Signed-off-by: yiliu30 --- auto_round/experimental/fp8_kv_cache.py | 43 +++++++++++-------------- 1 file changed, 19 insertions(+), 24 deletions(-) diff --git a/auto_round/experimental/fp8_kv_cache.py b/auto_round/experimental/fp8_kv_cache.py index eb9fbb01..74c9bda5 100644 --- a/auto_round/experimental/fp8_kv_cache.py +++ b/auto_round/experimental/fp8_kv_cache.py @@ -16,18 +16,13 @@ # https://github.com/vllm-project/llm-compressor/blob/main/src/llmcompressor/modifiers/quantization/cache.py -import os -import sys from enum import Enum -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple import torch -from loguru import logger -from torch import FloatTensor, IntTensor, Tensor -from torch.nn import Module from transformers.cache_utils import DynamicCache -logger.add(sys.stderr, level="TRACE") +from auto_round.utils import logger __all__ = [ "initialize_quantized_kv_cache", @@ -37,7 +32,7 @@ ] -def freeze_module_quantization_(module: Module): +def freeze_module_quantization_(module: torch.nn.Module): """ deletes observers when calibration is complete. @@ -122,17 +117,17 @@ def __init__(self, quantization_args=None): self.quantization_args = quantization_args # each index corresponds to layer_idx of the attention layer - self.k_scales: List[Tensor] = [] - self.v_scales: List[Tensor] = [] + self.k_scales: List[torch.Tensor] = [] + self.v_scales: List[torch.Tensor] = [] self._initialized = True def update( self, - key_states: Tensor, - value_states: Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, layer_idx: int, cache_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[Tensor, Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor]: """ Get the k_scale and v_scale and output the quant-dequant key_states and value_states """ @@ -164,12 +159,12 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: def reset_states(self): """reset the kv states (used in calibration)""" - self.key_cache: List[Tensor] = [] - self.value_cache: List[Tensor] = [] + self.key_cache: List[torch.Tensor] = [] + self.value_cache: List[torch.Tensor] = [] # Used in `generate` to keep tally of how many tokens the cache has seen self._seen_tokens = 0 - self._quantized_key_cache: List[Tensor] = [] - self._quantized_value_cache: List[Tensor] = [] + self._quantized_key_cache: List[torch.Tensor] = [] + self._quantized_value_cache: List[torch.Tensor] = [] def reset(self): """ @@ -191,7 +186,7 @@ def _quant_dequant(self, tensor, kv_type, layer_idx): return qdq_tensor -def initialize_quantized_kv_cache(module: Module): +def initialize_quantized_kv_cache(module: torch.nn.Module): """ Initialize a quantized kv_cache on a module (analogous to initializing an observer) """ @@ -204,10 +199,10 @@ def initialize_quantized_kv_cache(module: Module): quantized_kv_cache = QuantizedKVParameterCache() setattr(module, "kv_cache", quantized_kv_cache) - logger.trace(f"Initialized quantized kv_cache for {module.__class__.__name__} {getattr(module, 'layer_idx', None)}") + logger.debug(f"Initialized quantized kv_cache for {module.__class__.__name__} {getattr(module, 'layer_idx', None)}") -def is_attention_module(module: Module): +def 
is_attention_module(module: torch.nn.Module): # FIXME: Handle this better. return "attention" in module.__class__.__name__.lower() and ( hasattr(module, "k_proj") or hasattr(module, "v_proj") or hasattr(module, "qkv_proj") @@ -215,14 +210,14 @@ def is_attention_module(module: Module): def calibrate_kv_cache_input_hook( - module: Module, args: Any, kwargs: Dict[str, Any] + module: torch.nn.Module, args: Any, kwargs: Dict[str, Any] ) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: """ Hook to update inputs to attention layers when running kv_cache quantization. Will update the passed in kv_cache to singleton QuantizedKVParameterCache. """ - logger.trace(f"calibrate kv_cache input hook for {module.__class__.__name__} {getattr(module, 'layer_idx', None)}") + logger.debug(f"calibrate kv_cache input hook for {module.__class__.__name__} {getattr(module, 'layer_idx', None)}") kv_cache = getattr(module, "kv_cache") # Start from transformers 4.55.2, the `past_key_value` was renamed to `past_key_values`. # https://github.com/huggingface/transformers/blob/52c6c1bb6e27ca87c4faede34a4c2a7404c17c4d/src/transformers/models/llama/modeling_llama.py#L279-L280 @@ -253,11 +248,11 @@ def update_parameter_data(module, new_val, name: str): module.register_parameter(name, torch.nn.Parameter(new_val)) -def calibrate_kv_cache_output_hook(module: Module, _args: Any, _output: torch.Tensor): +def calibrate_kv_cache_output_hook(module: torch.nn.Module, _args: Any, _output: torch.Tensor): """ Hook to update k_scale and v_scale parameters when running kv_cache quantization. """ - logger.trace( + logger.debug( "Calibrate kv_cache output hook for %s %s" % (module.__class__.__name__, str(getattr(module, "layer_idx", None))) ) From f6b058a315981eb5b393d96d64d85a640db85fc6 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Wed, 20 Aug 2025 01:44:05 -0400 Subject: [PATCH 06/15] update Signed-off-by: yiliu30 --- auto_round/experimental/fp8_kv_cache.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/auto_round/experimental/fp8_kv_cache.py b/auto_round/experimental/fp8_kv_cache.py index 74c9bda5..656ba612 100644 --- a/auto_round/experimental/fp8_kv_cache.py +++ b/auto_round/experimental/fp8_kv_cache.py @@ -96,9 +96,6 @@ class QuantizedKVParameterCache(DynamicCache): The size of tensor is `[batch_size, num_heads, seq_len - residual_length, head_dim]`. - - # TODO: Triggered by adding kv_cache_scheme in ... - """ _instance = None From 3c9749a7a206d45ea7d48dbe08a99300933435db Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Wed, 20 Aug 2025 01:45:17 -0400 Subject: [PATCH 07/15] correct docs Signed-off-by: yiliu30 --- auto_round/experimental/fp8_kv_cache.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/auto_round/experimental/fp8_kv_cache.py b/auto_round/experimental/fp8_kv_cache.py index 656ba612..8940e029 100644 --- a/auto_round/experimental/fp8_kv_cache.py +++ b/auto_round/experimental/fp8_kv_cache.py @@ -91,8 +91,7 @@ class QuantizedKVParameterCache(DynamicCache): Quantized KV cache used in the forward call based on HF's dynamic cache. Quantization strategy (tensor, group, channel) set from Quantization arg's strategy Singleton, so that the same cache gets reused in all forward call of self_attn. - Each time forward is called, .update() is called, and ._quantize(), ._dequantize() - gets called appropriately. + Each time forward is called, .update() is called, and ._quant_dequant(), gets called appropriately. The size of tensor is `[batch_size, num_heads, seq_len - residual_length, head_dim]`. 
@@ -128,11 +127,6 @@ def update( """ Get the k_scale and v_scale and output the quant-dequant key_states and value_states """ - - # FIXME: Should we append the key_states/value_states to the cache? - # q_key_states = self._quantize(key_states.contiguous(), KVCacheScaleType.KEY, layer_idx) - # q_value_states = self._quantize(value_states.contiguous(), KVCacheScaleType.VALUE, layer_idx) - qdq_key_states = self._quant_dequant(key_states.contiguous(), KVCacheScaleType.KEY, layer_idx) qdq_value_states = self._quant_dequant(value_states.contiguous(), KVCacheScaleType.VALUE, layer_idx) From 6e905ddb63c9d61a37214d9b9526c5fc25bbfc7d Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Wed, 20 Aug 2025 01:50:52 -0400 Subject: [PATCH 08/15] add type hints Signed-off-by: yiliu30 --- auto_round/experimental/fp8_kv_cache.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/auto_round/experimental/fp8_kv_cache.py b/auto_round/experimental/fp8_kv_cache.py index 8940e029..12f5f4a9 100644 --- a/auto_round/experimental/fp8_kv_cache.py +++ b/auto_round/experimental/fp8_kv_cache.py @@ -91,7 +91,7 @@ class QuantizedKVParameterCache(DynamicCache): Quantized KV cache used in the forward call based on HF's dynamic cache. Quantization strategy (tensor, group, channel) set from Quantization arg's strategy Singleton, so that the same cache gets reused in all forward call of self_attn. - Each time forward is called, .update() is called, and ._quant_dequant(), gets called appropriately. + Each time forward is called, .update() is called, and ._quant_dequant() gets called appropriately. The size of tensor is `[batch_size, num_heads, seq_len - residual_length, head_dim]`. @@ -106,12 +106,10 @@ def __new__(cls, *args, **kwargs): cls._instance = super(QuantizedKVParameterCache, cls).__new__(cls) return cls._instance - def __init__(self, quantization_args=None): + def __init__(self): if not self._initialized: super().__init__() - self.quantization_args = quantization_args - # each index corresponds to layer_idx of the attention layer self.k_scales: List[torch.Tensor] = [] self.v_scales: List[torch.Tensor] = [] @@ -164,7 +162,7 @@ def reset(self): QuantizedKVParameterCache._instance = None QuantizedKVParameterCache._initialized = False - def _quant_dequant(self, tensor, kv_type, layer_idx): + def _quant_dequant(self, tensor: torch.Tensor, kv_type: KVCacheScaleType, layer_idx: int): """Quantizes a key/value using a defined quantization method.""" if kv_type == KVCacheScaleType.KEY: # key type scales = self.k_scales @@ -220,7 +218,7 @@ def calibrate_kv_cache_input_hook( return args, kwargs -def update_parameter_data(module, new_val, name: str): +def update_parameter_data(module: torch.nn.Module, new_val: torch.Tensor, name: str): """ Update the data of a parameter in a module. If the parameter does not exist, it will be created. 
@@ -264,7 +262,7 @@ def prep_attention_module_for_calibration(module: torch.nn.Module): @contextlib.contextmanager -def fp8_kv_context(model): +def fp8_kv_context(model: torch.nn.Module): """Context manager for FP8 KV cache quantization operations.""" try: # Setup phase: Initialize KV cache for quantization From 7e44306c1467d7efcaff312debe7f2bd57e5a5f7 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Wed, 20 Aug 2025 01:52:29 -0400 Subject: [PATCH 09/15] correct docs Signed-off-by: yiliu30 --- auto_round/experimental/fp8_kv_cache.py | 1 - 1 file changed, 1 deletion(-) diff --git a/auto_round/experimental/fp8_kv_cache.py b/auto_round/experimental/fp8_kv_cache.py index 12f5f4a9..27ded472 100644 --- a/auto_round/experimental/fp8_kv_cache.py +++ b/auto_round/experimental/fp8_kv_cache.py @@ -89,7 +89,6 @@ def fp8_per_tensor_qdq(tensor): class QuantizedKVParameterCache(DynamicCache): """ Quantized KV cache used in the forward call based on HF's dynamic cache. - Quantization strategy (tensor, group, channel) set from Quantization arg's strategy Singleton, so that the same cache gets reused in all forward call of self_attn. Each time forward is called, .update() is called, and ._quant_dequant() gets called appropriately. The size of tensor is From 866f32b8c0f4dbca73b7259888883eec3ce3b4a0 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Wed, 20 Aug 2025 02:05:09 -0400 Subject: [PATCH 10/15] updated Signed-off-by: yiliu30 --- auto_round/autoround.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/auto_round/autoround.py b/auto_round/autoround.py index e7731121..97bee3b9 100644 --- a/auto_round/autoround.py +++ b/auto_round/autoround.py @@ -174,7 +174,6 @@ def __init__( act_sym: bool = None, act_data_type: str = None, act_dynamic: bool = True, - enable_static_fp8_kv: bool = False, enable_torch_compile: bool = False, device_map: Union[str, dict] = None, disable_opt_rtn: bool = False, @@ -299,6 +298,7 @@ def __init__( ) # kv cache + enable_static_fp8_kv = kwargs.pop("enable_static_fp8_kv", False) self.enable_static_fp8_kv = enable_static_fp8_kv logger.warning("The `enable_static_fp8_kv` feature is experimental and currently has limited support.") From 4a8e05affd544c9a925f00e403c21603c0872338 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Wed, 20 Aug 2025 02:06:18 -0400 Subject: [PATCH 11/15] hide enable_static_fp8_kv Signed-off-by: yiliu30 --- auto_round/autoround.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/auto_round/autoround.py b/auto_round/autoround.py index 97bee3b9..6d692d56 100644 --- a/auto_round/autoround.py +++ b/auto_round/autoround.py @@ -300,7 +300,8 @@ def __init__( # kv cache enable_static_fp8_kv = kwargs.pop("enable_static_fp8_kv", False) self.enable_static_fp8_kv = enable_static_fp8_kv - logger.warning("The `enable_static_fp8_kv` feature is experimental and currently has limited support.") + if self.enable_static_fp8_kv: + logger.warning("The `enable_static_fp8_kv` feature is experimental and currently has limited support.") self.sampler = sampler self.not_use_best_mse = not_use_best_mse From e421585add49304b6b1ec46314a215eb88060e27 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Wed, 20 Aug 2025 04:34:46 -0400 Subject: [PATCH 12/15] rename arg name Signed-off-by: yiliu30 --- auto_round/autoround.py | 17 ++++------ auto_round/experimental/fp8_kv_cache.py | 43 ++++++++++++++++++++----- test/test_cpu/test_export.py | 8 ++--- 3 files changed, 46 insertions(+), 22 deletions(-) diff --git a/auto_round/autoround.py b/auto_round/autoround.py index 
6d692d56..5c3f3a9f 100644 --- a/auto_round/autoround.py +++ b/auto_round/autoround.py @@ -298,10 +298,10 @@ def __init__( ) # kv cache - enable_static_fp8_kv = kwargs.pop("enable_static_fp8_kv", False) - self.enable_static_fp8_kv = enable_static_fp8_kv - if self.enable_static_fp8_kv: - logger.warning("The `enable_static_fp8_kv` feature is experimental and currently has limited support.") + static_kv_dtype = kwargs.pop("static_kv_dtype", None) + self.static_kv_dtype = static_kv_dtype + if self.static_kv_dtype is not None: + logger.warning("The static kv is experimental and currently has limited support.") self.sampler = sampler self.not_use_best_mse = not_use_best_mse @@ -749,16 +749,13 @@ def quantize_and_save(self, output_dir: str = "tmp_autoround", format: str = "au kwargs.pop("inplace", None) # Perform model quantization - if self.enable_static_fp8_kv: + if self.static_kv_dtype is not None: from auto_round.experimental.fp8_kv_cache import fp8_kv_context - quant_ctx = fp8_kv_context + with fp8_kv_context(self.model, static_kv_dtype=self.static_kv_dtype): + model, _ = self.quantize() else: - quant_ctx = contextlib.nullcontext - - with quant_ctx(self.model): model, _ = self.quantize() - # Save the quantized model in the specified format_list folders = [] for format in format_list: diff --git a/auto_round/experimental/fp8_kv_cache.py b/auto_round/experimental/fp8_kv_cache.py index 27ded472..497acd6e 100644 --- a/auto_round/experimental/fp8_kv_cache.py +++ b/auto_round/experimental/fp8_kv_cache.py @@ -16,8 +16,10 @@ # https://github.com/vllm-project/llm-compressor/blob/main/src/llmcompressor/modifiers/quantization/cache.py +import contextlib from enum import Enum -from typing import Any, Dict, List, Optional, Tuple +from functools import partial +from typing import Any, Dict, List, Optional, Tuple, Union import torch from transformers.cache_utils import DynamicCache @@ -105,7 +107,9 @@ def __new__(cls, *args, **kwargs): cls._instance = super(QuantizedKVParameterCache, cls).__new__(cls) return cls._instance - def __init__(self): + def __init__(self, dtype: torch.dtype = torch.float8_e4m3fn): + + assert dtype == torch.float8_e4m3fn, "Only fp8_e4m3fn is supported for now." 
if not self._initialized: super().__init__() @@ -174,7 +178,7 @@ def _quant_dequant(self, tensor: torch.Tensor, kv_type: KVCacheScaleType, layer_ return qdq_tensor -def initialize_quantized_kv_cache(module: torch.nn.Module): +def initialize_quantized_kv_cache(module: torch.nn.Module, dtype=torch.float8_e4m3fn): """ Initialize a quantized kv_cache on a module (analogous to initializing an observer) """ @@ -185,7 +189,7 @@ def initialize_quantized_kv_cache(module: torch.nn.Module): if isinstance(existing_kv_cache, QuantizedKVParameterCache): return - quantized_kv_cache = QuantizedKVParameterCache() + quantized_kv_cache = QuantizedKVParameterCache(dtype=dtype) setattr(module, "kv_cache", quantized_kv_cache) logger.debug(f"Initialized quantized kv_cache for {module.__class__.__name__} {getattr(module, 'layer_idx', None)}") @@ -257,16 +261,39 @@ def prep_attention_module_for_calibration(module: torch.nn.Module): module.register_forward_hook(calibrate_kv_cache_output_hook) -import contextlib +def normalize_static_kv_dtype(static_kv_dtype: Union[str, torch.dtype]) -> torch.dtype: + valid_dtype_name_lst = ["float16", "bfloat16", "fp8", "float32", "float"] + valid_torch_dtype = { + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "fp8": torch.float8_e4m3fn, + "float32": torch.float32, + "float": torch.float32, # Alias for float32 + } + if static_kv_dtype in valid_dtype_name_lst: + new_dtype = valid_torch_dtype[static_kv_dtype] + elif static_kv_dtype in valid_torch_dtype.values(): + new_dtype = static_kv_dtype + else: + raise ValueError( + f"Invalid static kv dtype: {static_kv_dtype}. " + f"Valid options are: {', '.join(valid_dtype_name_lst + list(valid_torch_dtype.values()))}." + ) + return new_dtype @contextlib.contextmanager -def fp8_kv_context(model: torch.nn.Module): +def fp8_kv_context(model: torch.nn.Module, static_kv_dtype=torch.float8_e4m3fn): """Context manager for FP8 KV cache quantization operations.""" try: # Setup phase: Initialize KV cache for quantization - model.apply(initialize_quantized_kv_cache) - model.apply(prep_attention_module_for_calibration) + static_kv_dtype = normalize_static_kv_dtype(static_kv_dtype) + if static_kv_dtype != torch.float8_e4m3fn: + logger.warning(f"Ignoring static kv dtype {static_kv_dtype}, only fp8_e4m3fn is supported.") + else: + initialize_fn = partial(initialize_quantized_kv_cache, dtype=static_kv_dtype) + model.apply(initialize_fn) + model.apply(prep_attention_module_for_calibration) # Provide the model to the with block yield model diff --git a/test/test_cpu/test_export.py b/test/test_cpu/test_export.py index a67ad0da..9b4593e4 100644 --- a/test/test_cpu/test_export.py +++ b/test/test_cpu/test_export.py @@ -201,8 +201,8 @@ def test_autoround_3bit_sym_format(self): print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0])) shutil.rmtree(quantized_model_path, ignore_errors=True) - @parameterized.expand([(True,), (False,)]) - def test_static_afp8_export(self, enable_static_fp8_kv): + @parameterized.expand([(None,), ("fp8",), ("float16")]) + def test_static_afp8_export(self, static_kv_dtype): import os from safetensors import safe_open @@ -221,7 +221,7 @@ def test_static_afp8_export(self, enable_static_fp8_kv): act_data_type="fp8", act_dynamic=False, act_group_size=0, - enable_static_fp8_kv=enable_static_fp8_kv, + static_kv_dtype=static_kv_dtype, ) quantized_model_path = "./saved" autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") @@ -231,7 +231,7 @@ def test_static_afp8_export(self, 
enable_static_fp8_kv): self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.v_proj.input_scale").shape, torch.Size([1, 1])) self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.v_proj.weight").dtype, torch.float8_e4m3fn) - if enable_static_fp8_kv: + if static_kv_dtype: self.assertIn("model.decoder.layers.8.self_attn.k_scale", f.keys()) self.assertIn("model.decoder.layers.8.self_attn.v_scale", f.keys()) self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.v_scale").shape, torch.Size([1, 1])) From 30c62552f707db398f7bcef99c9041bdba02e040 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Wed, 20 Aug 2025 04:44:54 -0400 Subject: [PATCH 13/15] fix ut Signed-off-by: yiliu30 --- auto_round/experimental/fp8_kv_cache.py | 1 + test/test_cpu/test_export.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/auto_round/experimental/fp8_kv_cache.py b/auto_round/experimental/fp8_kv_cache.py index 497acd6e..2536407e 100644 --- a/auto_round/experimental/fp8_kv_cache.py +++ b/auto_round/experimental/fp8_kv_cache.py @@ -267,6 +267,7 @@ def normalize_static_kv_dtype(static_kv_dtype: Union[str, torch.dtype]) -> torch "float16": torch.float16, "bfloat16": torch.bfloat16, "fp8": torch.float8_e4m3fn, + "float8_e4m3fn": torch.float8_e4m3fn, "float32": torch.float32, "float": torch.float32, # Alias for float32 } diff --git a/test/test_cpu/test_export.py b/test/test_cpu/test_export.py index 9b4593e4..72a1e68c 100644 --- a/test/test_cpu/test_export.py +++ b/test/test_cpu/test_export.py @@ -231,7 +231,7 @@ def test_static_afp8_export(self, static_kv_dtype): self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.v_proj.input_scale").shape, torch.Size([1, 1])) self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.v_proj.weight").dtype, torch.float8_e4m3fn) - if static_kv_dtype: + if static_kv_dtype == "fp8": self.assertIn("model.decoder.layers.8.self_attn.k_scale", f.keys()) self.assertIn("model.decoder.layers.8.self_attn.v_scale", f.keys()) self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.v_scale").shape, torch.Size([1, 1])) From 9309c74b136e6b21b767fc0d29aa941c601f8ca5 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Wed, 20 Aug 2025 04:47:50 -0400 Subject: [PATCH 14/15] rename file Signed-off-by: yiliu30 --- auto_round/autoround.py | 4 ++-- auto_round/experimental/{fp8_kv_cache.py => kv_cache.py} | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) rename auto_round/experimental/{fp8_kv_cache.py => kv_cache.py} (98%) diff --git a/auto_round/autoround.py b/auto_round/autoround.py index 5c3f3a9f..68363d9d 100644 --- a/auto_round/autoround.py +++ b/auto_round/autoround.py @@ -750,9 +750,9 @@ def quantize_and_save(self, output_dir: str = "tmp_autoround", format: str = "au # Perform model quantization if self.static_kv_dtype is not None: - from auto_round.experimental.fp8_kv_cache import fp8_kv_context + from auto_round.experimental.kv_cache import kvcache_quant_context - with fp8_kv_context(self.model, static_kv_dtype=self.static_kv_dtype): + with kvcache_quant_context(self.model, static_kv_dtype=self.static_kv_dtype): model, _ = self.quantize() else: model, _ = self.quantize() diff --git a/auto_round/experimental/fp8_kv_cache.py b/auto_round/experimental/kv_cache.py similarity index 98% rename from auto_round/experimental/fp8_kv_cache.py rename to auto_round/experimental/kv_cache.py index 2536407e..8a49f307 100644 --- a/auto_round/experimental/fp8_kv_cache.py +++ b/auto_round/experimental/kv_cache.py @@ -30,7 +30,7 @@ 
"initialize_quantized_kv_cache", "prep_attention_module_for_calibration", "freeze_module_quantization_", - "fp8_kv_context", + "kvcache_quant_context", ] @@ -284,7 +284,7 @@ def normalize_static_kv_dtype(static_kv_dtype: Union[str, torch.dtype]) -> torch @contextlib.contextmanager -def fp8_kv_context(model: torch.nn.Module, static_kv_dtype=torch.float8_e4m3fn): +def kvcache_quant_context(model: torch.nn.Module, static_kv_dtype=torch.float8_e4m3fn): """Context manager for FP8 KV cache quantization operations.""" try: # Setup phase: Initialize KV cache for quantization From 90edf9a117a7441530a9c0fedf499db1ddfea07e Mon Sep 17 00:00:00 2001 From: Yi Liu Date: Wed, 20 Aug 2025 19:19:39 +0800 Subject: [PATCH 15/15] Update requirements.txt --- test/test_cpu/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_cpu/requirements.txt b/test/test_cpu/requirements.txt index 265ba7e9..9ce86165 100644 --- a/test/test_cpu/requirements.txt +++ b/test/test_cpu/requirements.txt @@ -3,3 +3,4 @@ modelscope gguf torchvision compressed-tensors +parameterized