From d2eeaacb1e07572a401a3957145d88eaf9f3edea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Wed, 26 Jun 2024 09:18:31 +0200 Subject: [PATCH] Use symmetric quantization in the `quantize` subcommand Packing of asymmetric quantization is broken, all (q)zeros values of `0` get reset to `1`, resulting in a loss of accuracy. So instead use symmetric quantization. To be able to distinguish models with symmetric and asymmetric quantization, a new config tensor `gptq_sym` is added. If this tensor is not present, we assume `sym=False`. --- server/text_generation_server/cli.py | 1 + .../layers/gptq/quantize.py | 3 ++ .../text_generation_server/utils/weights.py | 34 ++++++++++++------- 3 files changed, 26 insertions(+), 12 deletions(-) diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index 68ae95dd7e0..71ad18f7920 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -353,6 +353,7 @@ def quantize( upload_to_model_id=upload_to_model_id, percdamp=percdamp, act_order=act_order, + sym=True, ) diff --git a/server/text_generation_server/layers/gptq/quantize.py b/server/text_generation_server/layers/gptq/quantize.py index 8d029817a39..ff840f36c8d 100644 --- a/server/text_generation_server/layers/gptq/quantize.py +++ b/server/text_generation_server/layers/gptq/quantize.py @@ -869,6 +869,7 @@ def quantize( upload_to_model_id: Optional[str], percdamp: float, act_order: bool, + sym: bool, ): print("loading model") config = AutoConfig.from_pretrained( @@ -943,6 +944,7 @@ def _unload(): percdamp=percdamp, act_order=act_order, hooks=hooks, + sym=sym, ) print(time.time() - tick) @@ -954,6 +956,7 @@ def _unload(): state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()} state_dict["gptq_bits"] = torch.LongTensor([bits]) state_dict["gptq_groupsize"] = torch.LongTensor([groupsize]) + state_dict["gptq_sym"] = torch.BoolTensor([sym]) max_shard_size = "10GB" shards, index = shard_checkpoint( diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 348d215cdbc..c2558c722af 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -79,6 +79,13 @@ def _get_slice(self, tensor_name: str): slice_ = f.get_slice(tensor_name) return slice_ + def _has_tensor(self, tensor_name: str): + try: + self.get_filename(tensor_name) + except Exception: + return False + return True + def get_shape(self, tensor_name: str): return self._get_slice(tensor_name).get_shape() @@ -749,23 +756,26 @@ def get_multi_weights_row(self, prefix: str, quantize: str): return weight def _get_gptq_params(self) -> _GPTQParams: - try: + if self._has_tensor("gptq_bits") and self._has_tensor("gptq_groupsize"): bits = self.get_tensor("gptq_bits").item() groupsize = self.get_tensor("gptq_groupsize").item() checkpoint_format = getattr(self, "gptq_checkpoint_format", None) desc_act = False - sym = True + # `server quantize` used asymmetric quantization unconditionally + # before the `gptq_sym` setting tensor was added. + sym = ( + self.get_tensor("gptq_sym").item() + if self._has_tensor("gptq_sym") + else False + ) quant_method = "gptq" - except (SafetensorError, RuntimeError) as e: - try: - bits = self.gptq_bits - groupsize = self.gptq_groupsize - checkpoint_format = getattr(self, "gptq_checkpoint_format", None) - desc_act = getattr(self, "gptq_desc_act", False) - quant_method = getattr(self, "quant_method", "gptq") - sym = getattr(self, "sym", True) - except Exception: - raise e + else: + bits = self.gptq_bits + groupsize = self.gptq_groupsize + checkpoint_format = getattr(self, "gptq_checkpoint_format", None) + desc_act = getattr(self, "gptq_desc_act", False) + quant_method = getattr(self, "quant_method", "gptq") + sym = getattr(self, "sym", True) return _GPTQParams( bits=bits,