diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index 68ae95dd7e0..71ad18f7920 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -353,6 +353,7 @@ def quantize(
         upload_to_model_id=upload_to_model_id,
         percdamp=percdamp,
         act_order=act_order,
+        sym=True,
     )
 
 
diff --git a/server/text_generation_server/layers/gptq/quantize.py b/server/text_generation_server/layers/gptq/quantize.py
index 8d029817a39..ff840f36c8d 100644
--- a/server/text_generation_server/layers/gptq/quantize.py
+++ b/server/text_generation_server/layers/gptq/quantize.py
@@ -869,6 +869,7 @@ def quantize(
     upload_to_model_id: Optional[str],
     percdamp: float,
     act_order: bool,
+    sym: bool,
 ):
     print("loading model")
     config = AutoConfig.from_pretrained(
@@ -943,6 +944,7 @@ def _unload():
         percdamp=percdamp,
         act_order=act_order,
         hooks=hooks,
+        sym=sym,
     )
     print(time.time() - tick)
 
@@ -954,6 +956,7 @@ def _unload():
     state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}
     state_dict["gptq_bits"] = torch.LongTensor([bits])
     state_dict["gptq_groupsize"] = torch.LongTensor([groupsize])
+    state_dict["gptq_sym"] = torch.BoolTensor([sym])
 
     max_shard_size = "10GB"
     shards, index = shard_checkpoint(
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 348d215cdbc..c2558c722af 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -79,6 +79,13 @@ def _get_slice(self, tensor_name: str):
         slice_ = f.get_slice(tensor_name)
         return slice_
 
+    def _has_tensor(self, tensor_name: str):
+        try:
+            self.get_filename(tensor_name)
+        except Exception:
+            return False
+        return True
+
     def get_shape(self, tensor_name: str):
         return self._get_slice(tensor_name).get_shape()
 
@@ -749,23 +756,26 @@ def get_multi_weights_row(self, prefix: str, quantize: str):
         return weight
 
     def _get_gptq_params(self) -> _GPTQParams:
-        try:
+        if self._has_tensor("gptq_bits") and self._has_tensor("gptq_groupsize"):
             bits = self.get_tensor("gptq_bits").item()
             groupsize = self.get_tensor("gptq_groupsize").item()
             checkpoint_format = getattr(self, "gptq_checkpoint_format", None)
             desc_act = False
-            sym = True
+            # `server quantize` used asymmetric quantization unconditionally
+            # before the `gptq_sym` setting tensor was added.
+            sym = (
+                self.get_tensor("gptq_sym").item()
+                if self._has_tensor("gptq_sym")
+                else False
+            )
             quant_method = "gptq"
-        except (SafetensorError, RuntimeError) as e:
-            try:
-                bits = self.gptq_bits
-                groupsize = self.gptq_groupsize
-                checkpoint_format = getattr(self, "gptq_checkpoint_format", None)
-                desc_act = getattr(self, "gptq_desc_act", False)
-                quant_method = getattr(self, "quant_method", "gptq")
-                sym = getattr(self, "sym", True)
-            except Exception:
-                raise e
+        else:
+            bits = self.gptq_bits
+            groupsize = self.gptq_groupsize
+            checkpoint_format = getattr(self, "gptq_checkpoint_format", None)
+            desc_act = getattr(self, "gptq_desc_act", False)
+            quant_method = getattr(self, "quant_method", "gptq")
+            sym = getattr(self, "sym", True)
 
         return _GPTQParams(
             bits=bits,
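
For reference, a minimal sketch (not part of the diff) of how the new `gptq_sym` marker round-trips through a safetensors checkpoint; the file name `model.safetensors` is illustrative:

```python
import torch
from safetensors.torch import load_file, save_file

# Writing: `quantize` now records whether symmetric quantization was used,
# next to the existing `gptq_bits`/`gptq_groupsize` marker tensors.
state_dict = {
    "gptq_bits": torch.LongTensor([4]),
    "gptq_groupsize": torch.LongTensor([128]),
    "gptq_sym": torch.BoolTensor([True]),
}
save_file(state_dict, "model.safetensors")  # illustrative path

# Reading: mirrors the fallback in `_get_gptq_params`. Checkpoints written
# before this change lack `gptq_sym`, and `server quantize` quantized
# asymmetrically back then, so a missing tensor means sym=False.
loaded = load_file("model.safetensors")
sym = loaded["gptq_sym"].item() if "gptq_sym" in loaded else False
```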