Skip to content

Commit

Permalink
Use symmetric quantization in the quantize subcommand
Browse files Browse the repository at this point in the history
Packing of asymmetric quantization is broken: all (q)zeros values
of `0` get reset to `1`, resulting in a loss of accuracy. So instead
use symmetric quantization. To be able to distinguish models with
symmetric and asymmetric quantization, a new config tensor `gptq_sym` is
added. If this tensor is not present, we assume `sym=False`.
  • Loading branch information
danieldk committed Jul 4, 2024
1 parent 245d3de commit 747a3c7
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 12 deletions.
1 change: 1 addition & 0 deletions server/text_generation_server/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,7 @@ def quantize(
upload_to_model_id=upload_to_model_id,
percdamp=percdamp,
act_order=act_order,
sym=True,
)


Expand Down
3 changes: 3 additions & 0 deletions server/text_generation_server/layers/gptq/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -869,6 +869,7 @@ def quantize(
upload_to_model_id: Optional[str],
percdamp: float,
act_order: bool,
sym: bool,
):
print("loading model")
config = AutoConfig.from_pretrained(
Expand Down Expand Up @@ -943,6 +944,7 @@ def _unload():
percdamp=percdamp,
act_order=act_order,
hooks=hooks,
sym=sym,
)
print(time.time() - tick)

Expand All @@ -954,6 +956,7 @@ def _unload():
state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}
state_dict["gptq_bits"] = torch.LongTensor([bits])
state_dict["gptq_groupsize"] = torch.LongTensor([groupsize])
state_dict["gptq_sym"] = torch.BoolTensor([sym])

max_shard_size = "10GB"
shards, index = shard_checkpoint(
Expand Down
34 changes: 22 additions & 12 deletions server/text_generation_server/utils/weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,13 @@ def _get_slice(self, tensor_name: str):
slice_ = f.get_slice(tensor_name)
return slice_

def _has_tensor(self, tensor_name: str):
try:
self.get_filename(tensor_name)
except Exception:
return False
return True

def get_shape(self, tensor_name: str):
return self._get_slice(tensor_name).get_shape()

Expand Down Expand Up @@ -717,23 +724,26 @@ def get_multi_weights_row(self, prefix: str, quantize: str):
return weight

def _get_gptq_params(self) -> GPTQParams:
try:
if self._has_tensor("gptq_bits") and self._has_tensor("gptq_groupsize"):
bits = self.get_tensor("gptq_bits").item()
groupsize = self.get_tensor("gptq_groupsize").item()
checkpoint_format = getattr(self, "gptq_checkpoint_format", None)
desc_act = False
sym = False
# `server quantize` used asymmetric quantization unconditionally
# before the `gptq_sym` setting tensor was added.
sym = (
self.get_tensor("gptq_sym").item()
if self._has_tensor("gptq_sym")
else False
)
quant_method = "gptq"
except (SafetensorError, RuntimeError) as e:
try:
bits = self.gptq_bits
groupsize = self.gptq_groupsize
checkpoint_format = getattr(self, "gptq_checkpoint_format", None)
desc_act = getattr(self, "gptq_desc_act", False)
quant_method = getattr(self, "quant_method", "gptq")
sym = getattr(self, "sym", True)
except Exception:
raise e
else:
bits = self.gptq_bits
groupsize = self.gptq_groupsize
checkpoint_format = getattr(self, "gptq_checkpoint_format", None)
desc_act = getattr(self, "gptq_desc_act", False)
quant_method = getattr(self, "quant_method", "gptq")
sym = getattr(self, "sym", True)

return GPTQParams(
bits=bits,
Expand Down

0 comments on commit 747a3c7

Please sign in to comment.