diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index d02178d6cd3..557656e786e 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -513,7 +513,13 @@ def get_multi_weights_row(self, prefix: str, quantize: str): "Cannot load `marlin` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" ) - s = self.get_sharded(f"{prefix}.s", dim=0) + num_groups = self._get_slice(f"{prefix}.s").get_shape()[0] + if num_groups == 1: + # The number of groups is 1 when group_size == -1. share + # scales between all shards in this case. + s = self.get_tensor(f"{prefix}.s") + else: + s = self.get_sharded(f"{prefix}.s", dim=0) weight = MarlinWeight(B=B, s=s) else: