Skip to content

Commit 337afb2

Browse files
fix(server): fix bnb quantization for CausalLM models (#385)
1 parent 87dc034 commit 337afb2

File tree

5 files changed

+15
-3
lines changed

5 files changed

+15
-3
lines changed

server/text_generation_server/models/bloom.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -245,6 +245,8 @@ def linear(input, weight, bias):
245245
return linear
246246

247247
module.linear = replace_linear(state)
248+
else:
249+
tensor = tensor.to(device)
248250
elif quantize == "gptq":
249251
raise NotImplementedError("`gptq` is not implemented for now")
250252
elif quantize is None:

server/text_generation_server/models/galactica.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -364,6 +364,8 @@ def linear(input, weight, bias):
364364
return linear
365365

366366
module.linear = replace_linear(state)
367+
else:
368+
tensor = tensor.to(device)
367369
elif quantize == "gptq":
368370
raise NotImplementedError("`gptq` is not implemented for now")
369371
elif quantize is None:

server/text_generation_server/models/gpt_neox.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -210,6 +210,8 @@ def linear(input, weight, bias):
210210
return linear
211211

212212
module.linear = replace_linear(state)
213+
else:
214+
tensor = tensor.to(device)
213215
elif quantize == "gptq":
214216
raise NotImplementedError("`gptq` is not implemented for now")
215217
elif quantize is None:

server/text_generation_server/models/opt.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -166,7 +166,7 @@ def load_weights(
166166

167167
tensor = tensor.contiguous().to(dtype)
168168

169-
if quantize:
169+
if quantize == "bitsandbytes":
170170
if not HAS_BITS_AND_BYTES:
171171
raise ImportError(
172172
"bitsandbytes is not available on your machine either because it is not installed "
@@ -216,9 +216,14 @@ def linear(input, weight, bias):
216216
return linear
217217

218218
module.linear = replace_linear(state)
219-
220219
else:
221220
tensor = tensor.to(device)
221+
elif quantize == "gptq":
222+
raise NotImplementedError("`gptq` is not implemented for now")
223+
elif quantize is None:
224+
tensor = tensor.to(device)
225+
else:
226+
raise ValueError(f"Unexpected quantize `{quantize}`")
222227

223228
module._parameters[param_name] = tensor
224229
if name == "model.decoder.embed_tokens.weight":

server/text_generation_server/models/t5.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -222,7 +222,8 @@ def linear(input, weight, bias):
222222
return linear
223223

224224
module.linear = replace_linear(state)
225-
225+
else:
226+
tensor = tensor.to(device)
226227
elif quantize == "gptq" and not module_name.endswith("wo"):
227228
raise NotImplementedError("`gptq` is not implemented for now")
228229
elif quantize is None or module_name.endswith("wo"):

0 commit comments

Comments (0)