diff --git a/auto_round/export/export_to_autoround/export_to_fp8_woq.py b/auto_round/export/export_to_autoround/export_to_fp8_woq.py
index cca0626f..e7b47359 100644
--- a/auto_round/export/export_to_autoround/export_to_fp8_woq.py
+++ b/auto_round/export/export_to_autoround/export_to_fp8_woq.py
@@ -16,6 +16,7 @@
 import json
 import os
 from concurrent.futures import ThreadPoolExecutor
+from typing import Optional, Union
 
 import threadpoolctl as tctl
 import torch
@@ -67,6 +68,95 @@ def __init__(
         self.register_buffer("input_scale", input_scale.to(dtype))
 
 
+def quant_tensor_with_scale(tensor, scale):
+    FULL_RANGE = torch.finfo(torch.float8_e4m3fn).max
+    qtensor = tensor / scale
+    clipped_qtensor = torch.clamp(qtensor, -FULL_RANGE, FULL_RANGE)
+    clipped_qtensor_fp8 = clipped_qtensor.to(torch.float8_e4m3fn)
+    return scale, clipped_qtensor_fp8
+
+
+class WeightFP8ActFP8StaticQuantLinear(torch.nn.Module):
+    hp_dtype = torch.bfloat16
+    fp8_dtype = torch.float8_e4m3fn
+
+    def __init__(
+        self,
+        in_features,
+        out_features,
+        weight: Optional[torch.Tensor] = None,
+        weight_scale: Optional[torch.Tensor] = None,
+        bias: Union[torch.Tensor, bool, None] = None,
+        weight_zp: Optional[torch.Tensor] = None,
+        input_scale: Optional[torch.Tensor] = None,
+        dtype=torch.bfloat16,
+    ):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        init_weight = torch.zeros((out_features, in_features), dtype=dtype) if weight is None else weight
+        self.weight = torch.nn.Parameter(init_weight, requires_grad=False)
+        self.dtype = dtype
+        if bias is not None:
+            if isinstance(bias, bool):
+                bias = torch.zeros((out_features,), dtype=dtype)
+            self.bias = torch.nn.Parameter(bias, requires_grad=False)
+        else:
+            self.register_parameter("bias", None)
+        init_weight_scale = torch.empty((out_features,), dtype=dtype) if weight_scale is None else weight_scale
+        self.register_buffer("weight_scale", init_weight_scale.to(dtype))
+
+        init_weight_zp = torch.zeros((out_features, 1), dtype=dtype) if weight_zp is None else weight_zp
+        if weight_zp is not None:
+            self.register_buffer("weight_zp", init_weight_zp.to(dtype))
+
+        init_input_scale = torch.zeros((1,), dtype=dtype) if input_scale is None else input_scale
+        self.register_buffer("input_scale", init_input_scale.to(dtype))
+        self.pre_dequantized = False
+
+    @classmethod
+    def from_original(cls, config, original_layer):
+        """
+        Create a WeightFP8ActFP8StaticQuantLinear layer from an original linear layer.
+ """ + device = original_layer.weight.device + with torch.device(device): + qdq_linear = cls( + in_features=original_layer.in_features, + out_features=original_layer.out_features, + bias=original_layer.bias, + ) + return qdq_linear + + def dequant_weight_online(self): + if self.pre_dequantized: + return self.weight + fp8_weight = self.weight + qdq_weight = fp8_weight.to(self.dtype) * self.weight_scale.unsqueeze(1) + return qdq_weight + + def pre_dequantize(self): + if self.pre_dequantized: + return + dequant_weight = self.dequant_weight_online() + del self.weight + del self.weight_scale + self.weight = torch.nn.Parameter(dequant_weight, requires_grad=False) + self.pre_dequantized = True + + def qdq_input(self, bf16_input: torch.Tensor): + input_scale, input_fp8 = quant_tensor_with_scale(bf16_input, self.input_scale.data) + qdq_input_bf16 = input_fp8.to(self.dtype) * input_scale + return qdq_input_bf16 + + @torch.no_grad() + def forward(self, bf16_input: torch.Tensor) -> torch.Tensor: + qdq_input = self.qdq_input(bf16_input) + qdq_weight = self.dequant_weight_online() + out = torch.nn.functional.linear(qdq_input, qdq_weight, self.bias) + return out + + def pack_layer(layer_name, model, data_type, packing_device=None): """ Packs a model layer for quantization based on its type and configuration. diff --git a/auto_round/inference/backend.py b/auto_round/inference/backend.py index ea66a038..4e3f4286 100644 --- a/auto_round/inference/backend.py +++ b/auto_round/inference/backend.py @@ -413,7 +413,18 @@ def check_compatible( return True -def dynamic_import_inference_linear(backend, bits, group_size, sym): +def is_weight_fp8_activation_static_fp8(config): + bits, group_size, sym, data_type, act_dynamic = ( + config["bits"], + config["group_size"], + config["sym"], + config["data_type"], + config["act_dynamic"], + ) + return bits == 8 and group_size == -1 and sym and data_type == "fp8" and not act_dynamic + + +def dynamic_import_inference_linear(backend, config): """Dynamically imports and returns the appropriate QuantLinear class based on the given backend. This function dynamically loads the correct `QuantLinear` class based on the backend and quantization @@ -438,6 +449,13 @@ def dynamic_import_inference_linear(backend, bits, group_size, sym): ImportError: If required modules are missing for a backend (e.g., Intel Extension, GPTQ, auto_awq). 
""" + bits, group_size, sym = config["bits"], config["group_size"], config["sym"] + + if is_weight_fp8_activation_static_fp8(config): + from auto_round.export.export_to_autoround.export_to_fp8_woq import WeightFP8ActFP8StaticQuantLinear + + return WeightFP8ActFP8StaticQuantLinear + if "qbits" in backend: try: from intel_extension_for_transformers import qbits # pylint: disable=E0401 diff --git a/auto_round/inference/convert_model.py b/auto_round/inference/convert_model.py index bd6dde83..bd8b4621 100644 --- a/auto_round/inference/convert_model.py +++ b/auto_round/inference/convert_model.py @@ -27,6 +27,7 @@ find_backend, get_highest_priority_backend, get_layer_backend, + is_weight_fp8_activation_static_fp8, process_requirement, ) from auto_round.utils import ( @@ -61,7 +62,7 @@ def skip_not_convert_modules(model, quantization_config, layer_names, layer_conf try: # transformers new api modules_to_not_convert = get_modules_to_not_convert(model, modules_to_not_convert, add_default_skips=True) except: - modules_to_not_convert = get_modules_to_not_convert(model, modules_to_not_convert) + modules_to_not_convert = _get_modules_to_not_convert(model, modules_to_not_convert) if modules_to_not_convert: for layer_name in layer_names: if any([re.search(re.compile(n), layer_name) for n in modules_to_not_convert]): @@ -219,6 +220,7 @@ def get_layer_config(model, quantization_config): - group_size (int): Group size for weight quantization. - data_type (str, optional): Data type for quantization (default: "int"). - sym (bool): Whether to use symmetric quantization. + - act_dynamic (bool, optional): Whether to use dynamic activation quantization (default: False). - quant_block_list (list, optional): Predefined list of blocks to quantize. - to_quant_block_names (list or str, optional): Blocks to quantize (if quant_block_list is None). - extra_config (dict, optional): Per-layer overrides for quantization settings. @@ -231,13 +233,14 @@ def get_layer_config(model, quantization_config): - "group_size" (int): Group size for quantization. - "data_type" (str): Data type used for quantization. - "sym" (bool): Whether symmetric quantization is applied. + - "act_dynamic" (bool): Whether dynamic activation quantization is used. - "clip" (bool): Whether weight clipping is enabled. 
""" bits = quantization_config.bits group_size = quantization_config.group_size data_type = getattr(quantization_config, "data_type", "int") # Default to "int" if not specified sym = quantization_config.sym - + act_dynamic = getattr(quantization_config, "act_dynamic", False) # Determine the quantization block list quant_block_list = getattr(quantization_config, "quant_block_list", None) if quant_block_list is None: @@ -290,11 +293,11 @@ def get_layer_config(model, quantization_config): "group_size": extra_config.get(layer_name, {}).get("group_size", group_size), "data_type": extra_config.get(layer_name, {}).get("data_type", data_type), "sym": extra_config.get(layer_name, {}).get("sym", sym), + "act_dynamic": extra_config.get(layer_name, {}).get("act_dynamic", act_dynamic), "clip": extra_config.get(layer_name, {}).get("clip", False), } for layer_name in layer_names } - return layer_configs @@ -415,7 +418,7 @@ def _import_exllamav2_kernels(): def _create_quant_layer(layer, layer_backend, config, in_features, out_features): """Creates a quantized layer using the appropriate class.""" - QuantLinear = dynamic_import_inference_linear(layer_backend, config["bits"], config["group_size"], config["sym"]) + QuantLinear = dynamic_import_inference_linear(layer_backend, config) bias = layer.bias is not None # Special handling for AWQ layers @@ -437,6 +440,8 @@ def _create_quant_layer(layer, layer_backend, config, in_features, out_features) out_features=out_features, bias=bias, ) + elif is_weight_fp8_activation_static_fp8(config): + return QuantLinear.from_original(config, layer) # Default quantized layer creation try: return QuantLinear( @@ -588,7 +593,6 @@ def convert_hf_model(model: nn.Module, target_device="cpu"): backend = backend[len("auto_round:") :] used_backends = _replace_by_quant_layers(model, layer_configs, backend, target_device, orig_backend) - if backend == "auto" or backend == "": best_backend = get_highest_priority_backend( quantization_config.bits, diff --git a/test/test_cpu/test_export.py b/test/test_cpu/test_export.py index 72a1e68c..bb291a2a 100644 --- a/test/test_cpu/test_export.py +++ b/test/test_cpu/test_export.py @@ -230,6 +230,31 @@ def test_static_afp8_export(self, static_kv_dtype): self.assertIn("model.decoder.layers.8.self_attn.k_proj.weight_scale", f.keys()) self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.v_proj.input_scale").shape, torch.Size([1, 1])) self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.v_proj.weight").dtype, torch.float8_e4m3fn) + with torch.no_grad(): + import transformers + + model = transformers.AutoModelForCausalLM.from_pretrained( + quantized_model_path, + torch_dtype="auto", + low_cpu_mem_usage=True, + trust_remote_code=True, + ) + model.eval() + assert ( + model.model.decoder.layers[0].self_attn.k_proj.__class__.__name__ == "WeightFP8ActFP8StaticQuantLinear" + ), f"Expected WeightFP8ActFP8StaticQuantLinear, got {model.model.decoder.layers[0].self_attn.k_proj.__class__.__name__}" + tokenizer = transformers.AutoTokenizer.from_pretrained(quantized_model_path) + prompt = "AI is " + encode = tokenizer.encode(prompt, return_tensors="pt") + with torch.no_grad(): + output_tokens = model.generate( + encode, + max_length=10, + ) + output = tokenizer.decode(output_tokens[0], skip_special_tokens=True) + print(f"Prompt: {prompt}") + print(f"Output: {output}") + assert output is not None, "Output should not be None" if static_kv_dtype == "fp8": self.assertIn("model.decoder.layers.8.self_attn.k_scale", f.keys())