From 0d96468ebb1ca0141d7a23b2fdfcef9a7ef7bb81 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Thu, 6 Jun 2024 11:51:52 +0000 Subject: [PATCH] marlin: support tp>1 when group_size==-1 --- server/text_generation_server/utils/weights.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index d02178d6cd3..557656e786e 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -513,7 +513,13 @@ def get_multi_weights_row(self, prefix: str, quantize: str): "Cannot load `marlin` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" ) - s = self.get_sharded(f"{prefix}.s", dim=0) + num_groups = self._get_slice(f"{prefix}.s").get_shape()[0] + if num_groups == 1: + # The number of groups is 1 when group_size == -1. share + # scales between all shards in this case. + s = self.get_tensor(f"{prefix}.s") + else: + s = self.get_sharded(f"{prefix}.s", dim=0) weight = MarlinWeight(B=B, s=s) else: