Skip to content

Commit

Permalink
marlin: support tp>1 when group_size==-1
Browse files Browse the repository at this point in the history
  • Loading branch information
danieldk authored and Daniël de Kok committed Jun 6, 2024
1 parent 4594e6f commit 0d96468
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion server/text_generation_server/utils/weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,13 @@ def get_multi_weights_row(self, prefix: str, quantize: str):
"Cannot load `marlin` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)

s = self.get_sharded(f"{prefix}.s", dim=0)
num_groups = self._get_slice(f"{prefix}.s").get_shape()[0]
if num_groups == 1:
# The number of groups is 1 when group_size == -1. share
# scales between all shards in this case.
s = self.get_tensor(f"{prefix}.s")
else:
s = self.get_sharded(f"{prefix}.s", dim=0)
weight = MarlinWeight(B=B, s=s)

else:
Expand Down

0 comments on commit 0d96468

Please sign in to comment.