From 08baaa7aca2bb609e38069b52e9f1c616799f061 Mon Sep 17 00:00:00 2001
From: Danny
Date: Sun, 19 Jan 2025 22:42:38 +0200
Subject: [PATCH] [SW-212057] Enable QDQ (#101)

---
 .../quantization_config/maxabs_quant_qdq.json          | 9 +++++++++
 .../habana/transformers/models/llama/modeling_llama.py | 4 ++++
 2 files changed, 13 insertions(+)
 create mode 100644 examples/text-generation/quantization_config/maxabs_quant_qdq.json

diff --git a/examples/text-generation/quantization_config/maxabs_quant_qdq.json b/examples/text-generation/quantization_config/maxabs_quant_qdq.json
new file mode 100644
index 0000000000..7b87c0d8d8
--- /dev/null
+++ b/examples/text-generation/quantization_config/maxabs_quant_qdq.json
@@ -0,0 +1,9 @@
+{
+    "method": "HOOKS",
+    "mode": "QUANTIZE",
+    "observer": "maxabs",
+    "scale_method": "maxabs_hw",
+    "scale_format": "SCALAR",
+    "dump_stats_path": "./hqt_output/measure",
+    "use_qdq": "True"
+}
\ No newline at end of file
diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py
index 0afcfbe05a..a443ef68ed 100755
--- a/optimum/habana/transformers/models/llama/modeling_llama.py
+++ b/optimum/habana/transformers/models/llama/modeling_llama.py
@@ -490,6 +490,10 @@ def get_k_proj_weight_dtype(self):
         Scales tensor gets the weight dtype."""
         if hasattr(self.k_proj, "qweight"):
             return self.k_proj.scales.dtype
+        elif hasattr(self.k_proj, "use_qdq") and self.k_proj.use_qdq:
+            return self.k_proj.dequant_weights.hp_dtype
+        elif isinstance(self.k_cache, KVCache) and "float8" in str(self.k_proj.weight.dtype):
+            return self.k_proj.scale_weight.dtype
         return self.k_proj.weight.dtype
 
     def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
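
Reviewer note (not part of the patch): the new elif branches extend the dtype lookup used when allocating the KV cache. Below is a minimal standalone sketch of the resulting selection order; the attribute names (qweight, scales, use_qdq, dequant_weights.hp_dtype, scale_weight) are taken directly from the diff, while the free-standing function and its signature are illustrative assumptions, not code from the repository.

    def k_proj_weight_dtype(k_proj, k_cache_is_kvcache):
        # 1) GPTQ 4-bit: k_proj.weight is replaced by qweight, so the
        #    scales tensor is the one carrying a usable dtype.
        if hasattr(k_proj, "qweight"):
            return k_proj.scales.dtype
        # 2) QDQ (added by this patch): weights stay quantized and are
        #    dequantized on the fly, so report the high-precision dtype
        #    of the dequant step rather than the stored weight dtype.
        if getattr(k_proj, "use_qdq", False):
            return k_proj.dequant_weights.hp_dtype
        # 3) FP8 weights alongside a KVCache (added by this patch): the
        #    stored weight dtype is a float8 variant, so fall back to the
        #    scale tensor's dtype.
        if k_cache_is_kvcache and "float8" in str(k_proj.weight.dtype):
            return k_proj.scale_weight.dtype
        # 4) Default: the plain weight dtype.
        return k_proj.weight.dtype

As with the other files under examples/text-generation/quantization_config/, the new maxabs_quant_qdq.json would presumably be consumed by pointing the QUANT_CONFIG environment variable at it, after a measurement run has populated dump_stats_path.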