Support loading for static quant weight fp8 act fp8 #730

Open · wants to merge 9 commits into base: main
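
This PR adds a WeightFP8ActFP8StaticQuantLinear module and wires it into the inference backend so that checkpoints exported with static FP8 weight and FP8 activation quantization can be loaded back through transformers. A minimal sketch of that loading flow, mirroring the new test case in this PR (the model path and prompt are placeholders):

import torch
import transformers

# Placeholder path: a checkpoint previously exported with static FP8 weight/activation quantization.
quantized_model_path = "path/to/quantized_model"

model = transformers.AutoModelForCausalLM.from_pretrained(
    quantized_model_path,
    torch_dtype="auto",
    trust_remote_code=True,
)
model.eval()

tokenizer = transformers.AutoTokenizer.from_pretrained(quantized_model_path)
encode = tokenizer.encode("AI is ", return_tensors="pt")
with torch.no_grad():
    output_tokens = model.generate(encode, max_length=10)
print(tokenizer.decode(output_tokens[0], skip_special_tokens=True))
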
90 changes: 90 additions & 0 deletions auto_round/export/export_to_autoround/export_to_fp8_woq.py
@@ -16,6 +16,7 @@
import json
import os
from concurrent.futures import ThreadPoolExecutor
from typing import Optional, Union

import threadpoolctl as tctl
import torch
@@ -83,6 +84,95 @@ def __init__(
self.register_buffer("input_scale", input_scale.to(dtype))


def quant_tensor_with_scale(tensor, scale):
FULL_RANGE = torch.finfo(torch.float8_e4m3fn).max
qtensor = tensor / scale
clipped_qtensor = torch.clamp(qtensor, -FULL_RANGE, FULL_RANGE)
clipped_qtensor_fp8 = clipped_qtensor.to(torch.float8_e4m3fn)
return scale, clipped_qtensor_fp8


class WeightFP8ActFP8StaticQuantLinear(torch.nn.Module):
hp_dtype = torch.bfloat16
fp8_dtype = torch.float8_e4m3fn

def __init__(
self,
in_features,
out_features,
weight: Optional[torch.Tensor] = None,
weight_scale: Optional[torch.Tensor] = None,
bias: Union[torch.Tensor, bool, None] = None,
weight_zp: Optional[torch.Tensor] = None,
input_scale: Optional[torch.Tensor] = None,
dtype=torch.bfloat16,
):
super().__init__()
self.in_features = in_features
self.out_features = out_features
init_weight = torch.empty((out_features, in_features), dtype=dtype) if weight is None else weight
self.weight = torch.nn.Parameter(init_weight, requires_grad=False)
self.dtype = dtype
if bias is not None:
if isinstance(bias, bool):
bias = torch.zeros((out_features,), dtype=dtype)
self.bias = torch.nn.Parameter(bias, requires_grad=False)
else:
self.register_parameter("bias", None)
init_weight_scale = torch.empty((out_features, 1), dtype=dtype) if weight_scale is None else weight_scale
self.register_buffer("weight_scale", init_weight_scale.to(dtype))

init_weight_zp = torch.zeros((out_features, 1), dtype=dtype) if weight_zp is None else weight_zp
if weight_zp is not None:
self.register_buffer("weight_zp", init_weight_zp.to(dtype))

init_input_scale = torch.zeros((1, 1), dtype=dtype) if input_scale is None else input_scale
self.register_buffer("input_scale", init_input_scale.to(dtype))
self.pre_dequantized = False

@classmethod
def from_original(cls, config, original_layer):
"""
Create a WeightFP8ActFP8StaticQuantLinear layer from an original linear layer.
"""
device = original_layer.weight.device
with torch.device(device):
qdq_linear = cls(
in_features=original_layer.in_features,
out_features=original_layer.out_features,
bias=original_layer.bias,
)
return qdq_linear

def dequant_weight_online(self):
if self.pre_dequantized:
return self.weight
fp8_weight = self.weight
qdq_weight = fp8_weight.to(self.dtype) * self.weight_scale
return qdq_weight

def pre_dequantize(self):
if self.pre_dequantized:
return
dequant_weight = self.dequant_weight_online()
del self.weight
del self.weight_scale
self.weight = torch.nn.Parameter(dequant_weight, requires_grad=False)
self.pre_dequantized = True

def qdq_input(self, bf16_input: torch.Tensor):
input_scale, input_fp8 = quant_tensor_with_scale(bf16_input, self.input_scale.data)
qdq_input_bf16 = input_fp8.to(self.dtype) * input_scale
return qdq_input_bf16

@torch.no_grad()
def forward(self, bf16_input: torch.Tensor) -> torch.Tensor:
qdq_input = self.qdq_input(bf16_input)
qdq_weight = self.dequant_weight_online()
out = torch.nn.functional.linear(qdq_input, qdq_weight, self.bias)
return out


def pack_layer(layer_name, model, data_type, packing_device=None):
"""
Packs a model layer for quantization based on its type and configuration.
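
As a quick sanity check on the math above, here is a minimal, self-contained sketch of the quantize-dequantize (QDQ) round trip that quant_tensor_with_scale and dequant_weight_online implement: divide by a static scale, clamp to the float8_e4m3fn range, cast to fp8, then cast back and re-apply the scale. The input tensor and the way the scale is derived here are illustrative assumptions, not the PR's calibration logic.

import torch

fmax = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

x = torch.randn(4, 8, dtype=torch.bfloat16)      # stand-in activation/weight tensor
scale = x.abs().max().to(torch.float32) / fmax   # assumed static per-tensor scale

q = torch.clamp(x / scale, -fmax, fmax).to(torch.float8_e4m3fn)  # quantize
x_dq = q.to(torch.bfloat16) * scale  # dequantize, as dequant_weight_online/qdq_input do
print((x - x_dq).abs().max())        # round-trip error stays small
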
20 changes: 19 additions & 1 deletion auto_round/inference/backend.py
@@ -410,7 +410,18 @@ def check_compatible(
return True


def dynamic_import_inference_linear(backend, bits, group_size, sym):
def is_weight_fp8_activation_static_fp8(config):
bits, group_size, sym, data_type, act_dynamic = (
config["bits"],
config["group_size"],
config["sym"],
config["data_type"],
config["act_dynamic"],
)
return bits == 8 and group_size == -1 and sym and data_type == "fp8" and not act_dynamic


def dynamic_import_inference_linear(backend, config):
"""Dynamically imports and returns the appropriate QuantLinear class based on the given backend.
This function dynamically loads the correct `QuantLinear` class based on the backend and quantization
@@ -435,6 +446,13 @@ def dynamic_import_inference_linear(backend, bits, group_size, sym):
ImportError:
If required modules are missing for a backend (e.g., Intel Extension, GPTQ, auto_awq).
"""
bits, group_size, sym = config["bits"], config["group_size"], config["sym"]

if is_weight_fp8_activation_static_fp8(config):
from auto_round.export.export_to_autoround.export_to_fp8_woq import WeightFP8ActFP8StaticQuantLinear

return WeightFP8ActFP8StaticQuantLinear

if "qbits" in backend:
try:
from intel_extension_for_transformers import qbits # pylint: disable=E0401
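
For clarity, a hypothetical per-layer config that satisfies the new predicate and therefore makes dynamic_import_inference_linear return the fp8 linear class (key names follow the checks above):

config = {
    "bits": 8,             # 8-bit weights
    "group_size": -1,      # no grouping (per-channel / per-tensor scales)
    "sym": True,           # symmetric quantization
    "data_type": "fp8",    # fp8 weight data type
    "act_dynamic": False,  # static (not dynamic) activation quantization
}
# is_weight_fp8_activation_static_fp8(config) returns True,
# so dynamic_import_inference_linear returns WeightFP8ActFP8StaticQuantLinear.
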
14 changes: 9 additions & 5 deletions auto_round/inference/convert_model.py
@@ -27,6 +27,7 @@
find_backend,
get_highest_priority_backend,
get_layer_backend,
is_weight_fp8_activation_static_fp8,
process_requirement,
)
from auto_round.utils import (
@@ -61,7 +62,7 @@ def skip_not_convert_modules(model, quantization_config, layer_names, layer_conf
try: # transformers new api
modules_to_not_convert = get_modules_to_not_convert(model, modules_to_not_convert, add_default_skips=True)
except:
modules_to_not_convert = get_modules_to_not_convert(model, modules_to_not_convert)
modules_to_not_convert = _get_modules_to_not_convert(model, modules_to_not_convert)
if modules_to_not_convert:
for layer_name in layer_names:
if any([re.search(re.compile(n), layer_name) for n in modules_to_not_convert]):
@@ -219,6 +220,7 @@ def get_layer_config(model, quantization_config):
- group_size (int): Group size for weight quantization.
- data_type (str, optional): Data type for quantization (default: "int").
- sym (bool): Whether to use symmetric quantization.
- act_dynamic (bool, optional): Whether to use dynamic activation quantization (default: False).
- quant_block_list (list, optional): Predefined list of blocks to quantize.
- to_quant_block_names (list or str, optional): Blocks to quantize (if quant_block_list is None).
- extra_config (dict, optional): Per-layer overrides for quantization settings.
@@ -231,13 +233,14 @@
- "group_size" (int): Group size for quantization.
- "data_type" (str): Data type used for quantization.
- "sym" (bool): Whether symmetric quantization is applied.
- "act_dynamic" (bool): Whether dynamic activation quantization is used.
- "clip" (bool): Whether weight clipping is enabled.
"""
bits = quantization_config.bits
group_size = quantization_config.group_size
data_type = getattr(quantization_config, "data_type", "int") # Default to "int" if not specified
sym = quantization_config.sym

act_dynamic = getattr(quantization_config, "act_dynamic", False)
# Determine the quantization block list
quant_block_list = getattr(quantization_config, "quant_block_list", None)
if quant_block_list is None:
@@ -290,11 +293,11 @@
"group_size": extra_config.get(layer_name, {}).get("group_size", group_size),
"data_type": extra_config.get(layer_name, {}).get("data_type", data_type),
"sym": extra_config.get(layer_name, {}).get("sym", sym),
"act_dynamic": extra_config.get(layer_name, {}).get("act_dynamic", act_dynamic),
"clip": extra_config.get(layer_name, {}).get("clip", False),
}
for layer_name in layer_names
}

return layer_configs


@@ -415,7 +418,7 @@ def _import_exllamav2_kernels():

def _create_quant_layer(layer, layer_backend, config, in_features, out_features):
"""Creates a quantized layer using the appropriate class."""
QuantLinear = dynamic_import_inference_linear(layer_backend, config["bits"], config["group_size"], config["sym"])
QuantLinear = dynamic_import_inference_linear(layer_backend, config)
bias = layer.bias is not None

# Special handling for AWQ layers
@@ -437,6 +440,8 @@ def _create_quant_layer(layer, layer_backend, config, in_features, out_features)
out_features=out_features,
bias=bias,
)
elif is_weight_fp8_activation_static_fp8(config):
return QuantLinear.from_original(config, layer)
# Default quantized layer creation
try:
return QuantLinear(
@@ -588,7 +593,6 @@ def convert_hf_model(model: nn.Module, target_device="cpu"):
backend = backend[len("auto_round:") :]

used_backends = _replace_by_quant_layers(model, layer_configs, backend, target_device, orig_backend)

if backend == "auto" or backend == "":
best_backend = get_highest_priority_backend(
quantization_config.bits,
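
Tying the convert_model.py changes together: each entry of layer_configs now also carries act_dynamic, which is what lets _create_quant_layer detect the static fp8 case and build the layer through from_original. An illustrative entry is sketched below; the layer name and values are hypothetical.

layer_configs = {
    "model.decoder.layers.0.self_attn.k_proj": {
        "bits": 8,
        "group_size": -1,
        "data_type": "fp8",
        "sym": True,
        "act_dynamic": False,  # newly propagated from the quantization config
        "clip": False,
    },
}
# _create_quant_layer now passes the whole config dict to dynamic_import_inference_linear
# and, for configs like this one, calls WeightFP8ActFP8StaticQuantLinear.from_original(config, layer).
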
28 changes: 27 additions & 1 deletion test/test_cpu/test_export.py
@@ -199,7 +199,7 @@ def test_autoround_3bit_sym_format(self):
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0]))
shutil.rmtree(quantized_model_path, ignore_errors=True)

def test_static_afp8_export(self):
def test_static_afp8_export_and_load(self):
import os

from safetensors import safe_open
@@ -226,6 +226,32 @@ def test_static_afp8_export(self):
self.assertIn("model.decoder.layers.8.self_attn.k_proj.weight_scale", f.keys())
self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.v_proj.input_scale").shape, torch.Size([1, 1]))
self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.v_proj.weight").dtype, torch.float8_e4m3fn)
with torch.no_grad():
import transformers

model = transformers.AutoModelForCausalLM.from_pretrained(
quantized_model_path,
torch_dtype="auto",
low_cpu_mem_usage=True,
trust_remote_code=True,
)
model.eval()
assert (
model.model.decoder.layers[0].self_attn.k_proj.__class__.__name__ == "WeightFP8ActFP8StaticQuantLinear"
), f"Expected WeightFP8ActFP8StaticQuantLinear, got {model.model.decoder.layers[0].self_attn.k_proj.__class__.__name__}"
tokenizer = transformers.AutoTokenizer.from_pretrained(quantized_model_path)
prompt = "AI is "
encode = tokenizer.encode(prompt, return_tensors="pt")
with torch.no_grad():
output_tokens = model.generate(
encode,
max_length=10,
)
output = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
print(f"Prompt: {prompt}")
print(f"Output: {output}")
assert output is not None, "Output should not be None"

shutil.rmtree(quantized_model_path, ignore_errors=True)

model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True)