Add Static FP8 KV Support #737

Merged
merged 16 commits on Aug 21, 2025
15 changes: 13 additions & 2 deletions auto_round/autoround.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import copy
import os
import re
@@ -296,6 +297,12 @@ def __init__(
f" match the specified 'act_bits' setting. Resetting 'act_bits' to {tmp_act_bits}."
)

# kv cache
static_kv_dtype = kwargs.pop("static_kv_dtype", None)
self.static_kv_dtype = static_kv_dtype
if self.static_kv_dtype is not None:
logger.warning("The static kv is experimental and currently has limited support.")

self.sampler = sampler
self.not_use_best_mse = not_use_best_mse
self.dynamic_max_gap = dynamic_max_gap
@@ -742,8 +749,13 @@ def quantize_and_save(self, output_dir: str = "tmp_autoround", format: str = "au
kwargs.pop("inplace", None)

# Perform model quantization
model, _ = self.quantize()
if self.static_kv_dtype is not None:
from auto_round.experimental.kv_cache import kvcache_quant_context

with kvcache_quant_context(self.model, static_kv_dtype=self.static_kv_dtype):
model, _ = self.quantize()
else:
model, _ = self.quantize()
# Save the quantized model in the specified format_list
folders = []
for format in format_list:
@@ -1334,7 +1346,6 @@ def quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str]) -
add_hook_to_module(m, hook, True)
else:
block = block.to(self.device)

input_ids = self.get_block_outputs(
block,
input_ids,
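For reference, here is a minimal usage sketch of the new `static_kv_dtype` kwarg wired in above. It mirrors the CPU test added later in this PR; the model name and the other quantization arguments are illustrative assumptions, not prescribed by the change itself.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round import AutoRound

model_name = "facebook/opt-125m"  # placeholder model for illustration
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# static_kv_dtype is the new experimental kwarg; when it is set,
# quantize_and_save() runs quantize() inside kvcache_quant_context()
# so per-layer k_scale / v_scale get calibrated and exported.
autoround = AutoRound(
    model,
    tokenizer,
    act_data_type="fp8",   # static FP8 activation settings, as in the test below
    act_dynamic=False,
    act_group_size=0,
    static_kv_dtype="fp8",
)
autoround.quantize_and_save(output_dir="./saved", format="auto_round")
```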
304 changes: 304 additions & 0 deletions auto_round/experimental/kv_cache.py
@@ -0,0 +1,304 @@
# Copyright (c) 2025 Red Hat AI, vLLM Project and Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# NOTICE: The design adapted from:
# https://github.com/vllm-project/llm-compressor/blob/main/src/llmcompressor/modifiers/quantization/cache.py


import contextlib
from enum import Enum
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from transformers.cache_utils import DynamicCache

from auto_round.utils import logger

__all__ = [
"initialize_quantized_kv_cache",
"prep_attention_module_for_calibration",
"freeze_module_quantization_",
"kvcache_quant_context",
]


def freeze_module_quantization_(module: torch.nn.Module):
"""
Delete observers when calibration is complete.

Apply to the full model with `model.apply(freeze_module_quantization_)`.

:param module: module to freeze quantization for
"""

# remove observers if needed
for name in ("input", "weight", "output"):
obs_name = f"{name}_observer"
if hasattr(module, obs_name):
delattr(module, obs_name)

# remove quantized kv_cache
kv_cache = getattr(module, "kv_cache", None)
if isinstance(kv_cache, QuantizedKVParameterCache):
delattr(module, "kv_cache")


class KVCacheScaleType(Enum):
KEY = "k_scale"
VALUE = "v_scale"


# NOTE: The trailing _ denotes that `lst` is modified in place
def _pad_and_append_at_idx_(lst: List, idx: int, val: Any) -> list:
"""
Append value val to list lst at index idx, right padding if necessary
Needed because the user may ignore some layers in the configuration, meaning
len(lst) may be <= idx - 1.

>>> _pad_and_append_at_idx_([0,1,2], 5, 5)
[0, 1, 2, None, None, 5]
>>> _pad_and_append_at_idx_([0,1,2], 3, 8)
[0, 1, 2, 8]
>>> _pad_and_append_at_idx_([0,1,2], 1, 5)
[0, 5, 2]
"""
num_to_pad = idx - len(lst) + 1
if num_to_pad > 0:
lst += [None] * num_to_pad
lst[idx] = val
return lst


def fp8_per_tensor_qdq(tensor):
from auto_round.data_type.fp8 import quant_fp8_sym

qdq_tensor, scale, _ = quant_fp8_sym(tensor, max_scale=1.0, tensor_max=None, group_size=0, v=0)
return qdq_tensor, scale


class QuantizedKVParameterCache(DynamicCache):
"""
Quantized KV cache used in the forward call based on HF's dynamic cache.
Singleton, so that the same cache gets reused in all forward calls of self_attn.
Each time forward is called, .update() is called, and ._quant_dequant() gets called appropriately.
The size of each cached tensor is
`[batch_size, num_heads, seq_len - residual_length, head_dim]`.

"""

_instance = None
_initialized = False

def __new__(cls, *args, **kwargs):
"""Singleton"""
if cls._instance is None:
cls._instance = super(QuantizedKVParameterCache, cls).__new__(cls)
return cls._instance

def __init__(self, dtype: torch.dtype = torch.float8_e4m3fn):

assert dtype == torch.float8_e4m3fn, "Only fp8_e4m3fn is supported for now."
if not self._initialized:
super().__init__()

# each index corresponds to layer_idx of the attention layer
self.k_scales: List[torch.Tensor] = []
self.v_scales: List[torch.Tensor] = []
self._initialized = True

def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Get the k_scale and v_scale and output the quant-dequant key_states and value_states
"""
qdq_key_states = self._quant_dequant(key_states.contiguous(), KVCacheScaleType.KEY, layer_idx)
qdq_value_states = self._quant_dequant(value_states.contiguous(), KVCacheScaleType.VALUE, layer_idx)

keys_to_return, values_to_return = qdq_key_states, qdq_value_states

return keys_to_return, values_to_return

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""
Returns the sequence length of the cached states.
A layer index can be optionally passed.
"""
if len(self.key_cache) <= layer_idx:
return 0
# since we cannot get the seq_length of each layer directly and
# rely on `_seen_tokens` which is updated every "layer_idx" == 0,
# this is a hack to get the actual seq_length for the given layer_idx
# this part of code otherwise fails when used to
# verify attn_weight shape in some models
return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1

def reset_states(self):
"""reset the kv states (used in calibration)"""
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
# Used in `generate` to keep tally of how many tokens the cache has seen
self._seen_tokens = 0
self._quantized_key_cache: List[torch.Tensor] = []
self._quantized_value_cache: List[torch.Tensor] = []

def reset(self):
"""
Reset the instantiation, create new instance on init
"""
QuantizedKVParameterCache._instance = None
QuantizedKVParameterCache._initialized = False

def _quant_dequant(self, tensor: torch.Tensor, kv_type: KVCacheScaleType, layer_idx: int):
"""Quantizes a key/value using a defined quantization method."""
if kv_type == KVCacheScaleType.KEY: # key type
scales = self.k_scales
else:
assert kv_type == KVCacheScaleType.VALUE
scales = self.v_scales

qdq_tensor, scale = fp8_per_tensor_qdq(tensor)
_pad_and_append_at_idx_(scales, layer_idx, scale)
return qdq_tensor


def initialize_quantized_kv_cache(module: torch.nn.Module, dtype=torch.float8_e4m3fn):
"""
Initialize a quantized kv_cache on a module (analogous to initializing an observer)
"""
if not is_attention_module(module):
return
existing_kv_cache = getattr(module, "kv_cache", None)

if isinstance(existing_kv_cache, QuantizedKVParameterCache):
return

quantized_kv_cache = QuantizedKVParameterCache(dtype=dtype)
setattr(module, "kv_cache", quantized_kv_cache)
logger.debug(f"Initialized quantized kv_cache for {module.__class__.__name__} {getattr(module, 'layer_idx', None)}")


def is_attention_module(module: torch.nn.Module):
# FIXME: Handle this better.
return "attention" in module.__class__.__name__.lower() and (
hasattr(module, "k_proj") or hasattr(module, "v_proj") or hasattr(module, "qkv_proj")
)


def calibrate_kv_cache_input_hook(
module: torch.nn.Module, args: Any, kwargs: Dict[str, Any]
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
"""
Hook to update inputs to attention layers when running
kv_cache quantization. Replaces the passed-in
kv_cache with the singleton QuantizedKVParameterCache.
"""
logger.debug(f"calibrate kv_cache input hook for {module.__class__.__name__} {getattr(module, 'layer_idx', None)}")
kv_cache = getattr(module, "kv_cache")
# Starting from transformers 4.55.2, `past_key_value` was renamed to `past_key_values`.
# https://github.com/huggingface/transformers/blob/52c6c1bb6e27ca87c4faede34a4c2a7404c17c4d/src/transformers/models/llama/modeling_llama.py#L279-L280
if "past_key_values" in kwargs:
kwargs["past_key_values"] = kv_cache
else:
kwargs["past_key_value"] = kv_cache
kwargs["use_cache"] = False
return args, kwargs


def update_parameter_data(module: torch.nn.Module, new_val: torch.Tensor, name: str):
"""
Update the data of a parameter in a module.
If the parameter does not exist, it will be created.
"""
if hasattr(module, name):
param = getattr(module, name)
if isinstance(param, torch.nn.Parameter):
param.data = new_val
else:
module.register_parameter(name, torch.nn.Parameter(new_val))
else:
logger.warning(
"Parameter %s not found in module %s, creating new parameter."
% (name, module.__class__.__name__ + str(getattr(module, "layer_idx", "")))
)
module.register_parameter(name, torch.nn.Parameter(new_val))


def calibrate_kv_cache_output_hook(module: torch.nn.Module, _args: Any, _output: torch.Tensor):
"""
Hook to update k_scale and v_scale parameters when running kv_cache quantization.
"""
logger.debug(
"Calibrate kv_cache output hook for %s %s"
% (module.__class__.__name__, str(getattr(module, "layer_idx", None)))
)
kv_cache = getattr(module, "kv_cache")
k_scale = kv_cache.k_scales[module.layer_idx]
v_scale = kv_cache.v_scales[module.layer_idx]
update_parameter_data(module, k_scale, KVCacheScaleType.KEY.value)
update_parameter_data(module, v_scale, KVCacheScaleType.VALUE.value)


def prep_attention_module_for_calibration(module: torch.nn.Module):
if is_attention_module(module):
module.register_forward_pre_hook(calibrate_kv_cache_input_hook, with_kwargs=True)
module.register_forward_hook(calibrate_kv_cache_output_hook)


def normalize_static_kv_dtype(static_kv_dtype: Union[str, torch.dtype]) -> torch.dtype:
valid_dtype_name_lst = ["float16", "bfloat16", "fp8", "float32", "float"]
valid_torch_dtype = {
"float16": torch.float16,
"bfloat16": torch.bfloat16,
"fp8": torch.float8_e4m3fn,
"float8_e4m3fn": torch.float8_e4m3fn,
"float32": torch.float32,
"float": torch.float32, # Alias for float32
}
if static_kv_dtype in valid_dtype_name_lst:
new_dtype = valid_torch_dtype[static_kv_dtype]
elif static_kv_dtype in valid_torch_dtype.values():
new_dtype = static_kv_dtype
else:
raise ValueError(
f"Invalid static kv dtype: {static_kv_dtype}. "
f"Valid options are: {', '.join(valid_dtype_name_lst + list(valid_torch_dtype.values()))}."
)
return new_dtype


@contextlib.contextmanager
def kvcache_quant_context(model: torch.nn.Module, static_kv_dtype=torch.float8_e4m3fn):
"""Context manager for FP8 KV cache quantization operations."""
try:
# Setup phase: Initialize KV cache for quantization
static_kv_dtype = normalize_static_kv_dtype(static_kv_dtype)
if static_kv_dtype != torch.float8_e4m3fn:
logger.warning(f"Ignoring static kv dtype {static_kv_dtype}, only fp8_e4m3fn is supported.")
else:
initialize_fn = partial(initialize_quantized_kv_cache, dtype=static_kv_dtype)
model.apply(initialize_fn)
model.apply(prep_attention_module_for_calibration)

# Provide the model to the with block
yield model

finally:
# Cleanup phase: Freeze quantization parameters
model.apply(freeze_module_quantization_)
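For clarity, a sketch of how this context manager is expected to be used during calibration (hedged: `model` and `calib_batches` below are placeholders; the only entry point exercised by this PR is the `quantize_and_save` path shown earlier).

```python
import torch
from auto_round.experimental.kv_cache import kvcache_quant_context

# `model` is assumed to be a Hugging Face causal LM and `calib_batches`
# an iterable of tokenized inputs; both are placeholders for this sketch.
with kvcache_quant_context(model, static_kv_dtype=torch.float8_e4m3fn):
    # Inside the context, every attention module gets a forward pre-hook that
    # swaps the incoming cache for the singleton QuantizedKVParameterCache,
    # and a forward hook that copies the observed scales into per-layer
    # k_scale / v_scale parameters on the module.
    with torch.no_grad():
        for batch in calib_batches:
            model(**batch)
# On exit, freeze_module_quantization_ removes the temporary kv_cache and any
# observers, leaving only the frozen k_scale / v_scale parameters for export.
```

Because `quant_fp8_sym` is called with `group_size=0`, the quantization is per-tensor, so each layer ends up with scalar `k_scale` / `v_scale` values; these are the `[1, 1]`-shaped tensors asserted in the test below.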
1 change: 1 addition & 0 deletions test/test_cpu/requirements.txt
@@ -3,3 +3,4 @@ modelscope
gguf
torchvision
compressed-tensors
parameterized
13 changes: 12 additions & 1 deletion test/test_cpu/test_export.py
@@ -2,6 +2,8 @@
import sys
import unittest

from parameterized import parameterized

sys.path.insert(0, "../..")
import torch
from transformers import AutoModelForCausalLM, AutoRoundConfig, AutoTokenizer
@@ -199,7 +201,8 @@ def test_autoround_3bit_sym_format(self):
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0]))
shutil.rmtree(quantized_model_path, ignore_errors=True)

def test_static_afp8_export(self):
@parameterized.expand([(None,), ("fp8",), ("float16",)])
def test_static_afp8_export(self, static_kv_dtype):
import os

from safetensors import safe_open
@@ -218,6 +221,7 @@ def test_static_afp8_export(self):
act_data_type="fp8",
act_dynamic=False,
act_group_size=0,
static_kv_dtype=static_kv_dtype,
)
quantized_model_path = "./saved"
autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round")
@@ -226,6 +230,13 @@ def test_static_afp8_export(self):
self.assertIn("model.decoder.layers.8.self_attn.k_proj.weight_scale", f.keys())
self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.v_proj.input_scale").shape, torch.Size([1, 1]))
self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.v_proj.weight").dtype, torch.float8_e4m3fn)

if static_kv_dtype == "fp8":
self.assertIn("model.decoder.layers.8.self_attn.k_scale", f.keys())
self.assertIn("model.decoder.layers.8.self_attn.v_scale", f.keys())
self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.v_scale").shape, torch.Size([1, 1]))
self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.k_scale").shape, torch.Size([1, 1]))
self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.k_scale").dtype, torch.float32)
shutil.rmtree(quantized_model_path, ignore_errors=True)

model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True)