Commit 64f32bc

Option to dequantize at load-time
1 parent b585395 commit 64f32bc

5 files changed: +169, -55 lines

README.md

Lines changed: 8 additions & 2 deletions
@@ -125,7 +125,7 @@ slower as well over time.
 - [ ] Options for trading off memory usage for more performance (e.g. float32 tensors)
 - [ ] Provide alternative backend to allow layers on CPU
 - [ ] Fused QKV projection and fused MLP
-- [ ] Support for de-quantizing select matrices at load time
+- [x] Support for de-quantizing select matrices at load time
 - [ ] A web interface maybe?
 - [x] Memory-efficient beam search implementation
 - [ ] Optimized beam search
@@ -156,4 +156,10 @@ models. Noticeably faster now.
 **2023-05-21**: Added beam search implementation. It doesn't process beams in parallel which saves a lot of VRAM but
 does slow it down a bit. There should be ways to mitigate the slowdown. It's not clear how much better beam search
 performs in practice, but it's at least theoretically superior and there are other features coming which will build
-on it, like multi-token repetition penalties and (de-)censoring.
+on it, like multi-token repetition penalties and (de-)censoring.
+
+**2023-05-22**: Added option to auto-split layers across multiple GPUs based on VRAM allocation.
+
+**2023-05-22**: Added option to dequantize layers at load-time which _should_ speed up inference, but it turns out
+Torch's fp16 matmul is actually slower than the quantized matmul. Maybe bandwidth is the only bottleneck right now?
+Need to experiment some more.
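One rough way to probe the bandwidth hypothesis in the note above is to time a plain fp16 matmul at single-token shapes and compare the effective bytes-per-second against the GPU's peak memory bandwidth. The sketch below is illustrative only: the shapes (4096 x 11008, roughly a LLaMA-7B MLP projection) and iteration counts are assumptions, and it times only the fp16 side of the comparison, not the quantized kernel.

# Illustrative sketch (not from the diff above): time a plain fp16 matmul at an
# assumed single-token LLaMA-7B projection shape and estimate effective bandwidth.

import time
import torch

hidden, inter = 4096, 11008                            # assumed LLaMA-7B shapes
x = torch.randn(1, 1, hidden, dtype = torch.float16, device = "cuda")
w = torch.randn(hidden, inter, dtype = torch.float16, device = "cuda")

for _ in range(10): torch.matmul(x, w)                 # warm-up
torch.cuda.synchronize()

iters = 1000
t0 = time.time()
for _ in range(iters): torch.matmul(x, w)
torch.cuda.synchronize()
elapsed = time.time() - t0

# At batch size 1 each call has to stream the whole fp16 weight matrix from VRAM,
# so effective bandwidth is a reasonable proxy for whether the op is memory-bound.
bytes_read = w.numel() * w.element_size() * iters
print(f"{elapsed / iters * 1e6:.1f} us/matmul, ~{bytes_read / elapsed / 1e9:.0f} GB/s effective")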

cuda_ext.py

Lines changed: 24 additions & 4 deletions
@@ -70,10 +70,9 @@ def _matmul_q4v2_matmul(x, w, scales, zeros, seq_g_idx, x_map):
     return output.reshape(outshape)


-def _matmul_q4v2_recons(x, w, scales, zeros, seq_g_idx, x_map, transpose = False):
+def _matmul_q4v2_recons(x, w, scales, zeros, seq_g_idx, x_map):

-    if not transpose: assert w.shape[0] * 8 == x.shape[-1]
-    else: assert w.shape[1] == x.shape[-1]
+    assert w.shape[0] * 8 == x.shape[-1]

     qweight_recons = torch.empty((w.shape[0] * 8, w.shape[1]), dtype = torch.float16, device = w.device)
     q4v2_recons(w, qweight_recons, scales, zeros, seq_g_idx if seq_g_idx is not None else none_tensor)
@@ -88,11 +87,32 @@ def _matmul_q4v2_recons(x, w, scales, zeros, seq_g_idx, x_map, transpose = False
         column_remap(x, x_mapped, x_map)
         x = x_mapped.reshape(x_shape)

-    output = torch.matmul(x, qweight_recons.T if transpose else qweight_recons)
+    output = torch.matmul(x, qweight_recons)

     return output


+# Reconstruct fp16 matrix from 4-bit matrix
+
+def dequantize_q4v2(quant_args):
+
+    w = quant_args["qweight"]
+    scales = quant_args["scales"]
+    zeros = quant_args["zeros"]
+    seq_g_idx = quant_args["seq_g_idx"]
+    x_map = quant_args["x_map"]
+
+    qweight_recons = torch.empty((w.shape[0] * 8, w.shape[1]), dtype = torch.float16, device = w.device)
+    q4v2_recons(w, qweight_recons, scales, zeros, seq_g_idx if seq_g_idx is not None else none_tensor)
+
+    if x_map is not None:
+
+        # TODO: un-unshuffle rows in qweight_recons
+        raise ValueError("Not implemented yet.")
+
+    return qweight_recons
+
+
 # Matrix multiplication, returns x @ 4-bit matrix (qweight, scales, zeros, g_idx)

 def matmul_q4v2(x, quant_args, switch):
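For context on the `w.shape[0] * 8` factor used above: with GPTQ-style 4-bit packing, each int32 element of `qweight` holds eight 4-bit weights, so the reconstructed fp16 matrix has eight times as many rows as the packed tensor and takes roughly four times the VRAM (scales and zeros add a little on top of the packed figure). A small sketch with assumed shapes:

# Hypothetical size check, assuming a 4096 x 4096 projection packed GPTQ-style
# (eight 4-bit weights per int32 element), matching the reconstruction above.

import torch

in_features, out_features = 4096, 4096
qweight = torch.zeros((in_features // 8, out_features), dtype = torch.int32)
recons = torch.empty((qweight.shape[0] * 8, qweight.shape[1]), dtype = torch.float16)

packed_mb = qweight.numel() * qweight.element_size() / 1024 ** 2   # 4 bits per weight
recons_mb = recons.numel() * recons.element_size() / 1024 ** 2     # 16 bits per weight
print(packed_mb, recons_mb, recons_mb / packed_mb)                 # 8.0 MB, 32.0 MB, 4.0x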

model.py

Lines changed: 110 additions & 46 deletions
@@ -95,6 +95,7 @@ def __init__(self, model_config_path):
         self.mlp_method = self.MLPMethod.NORMAL # Currently no benefit to fused MLP
         self.device_map = ExLlamaDeviceMap(self.num_hidden_layers)
         self.auto_map = None # List of ints with memory allocation in GB, per CUDA device, overrides device_map
+        self.dequant = None # Number of layers (per GPU) to de-quantize at load time


     # Parse and set list of GPU VRAM allocations
@@ -105,6 +106,14 @@ def set_auto_map(self, map_string):
         else: self.auto_map = [float(alloc) for alloc in map_string.split(",")]


+    # Parse and set number of layers to de-quantize at load, per GPU
+
+    def set_dequant(self, dq_string):
+
+        if dq_string is None: self.dequant = None
+        else: self.dequant = [int(alloc) for alloc in dq_string.split(",")]
+
+
 def _dump_tensor(t, name):

     if t is None:
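A hypothetical usage of the two setters above. The class name ExLlamaConfig, the config path and the split values are placeholder assumptions rather than part of the diff; only the parsing behaviour is taken from the code. `set_auto_map` caps per-GPU VRAM in GB, and `set_dequant` asks for the first N layers assigned to each GPU to be stored as fp16 instead of 4-bit.

# Hypothetical example values; the constructor argument is assumed.
config = ExLlamaConfig("/path/to/model/config.json")
config.set_auto_map("10,24")    # config.auto_map -> [10.0, 24.0]  (GB per CUDA device)
config.set_dequant("4,0")       # config.dequant  -> [4, 0]        (layers per device)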
@@ -146,11 +155,12 @@ def _mlp_switch(config, x):

 class Ex4bitLinear(nn.Module):

-    def __init__(self, config, in_features, out_features, has_bias, tensors, key):
+    def __init__(self, config, in_features, out_features, has_bias, tensors, key, dequant = False):
         super().__init__()

         self.config = config
         self.key = key
+        self.dequant = dequant

         self.in_features = in_features
         self.out_features = out_features
@@ -210,6 +220,17 @@ def __init__(self, config, in_features, out_features, has_bias, tensors, key):

         if has_bias: self.bias = tensors[key + ".bias"]

+        # Optionally dequantize layer at init time
+
+        if self.dequant:
+
+            self.qweight_dequant = cuda_ext.dequantize_q4v2(self.quant_args())
+            self.qweight = None
+            self.scales = None
+            self.zeros = None
+            self.seq_g_idx = None
+            self.x_map = None
+

     def quant_args(self):
@@ -268,20 +289,26 @@ def load_streaming(self):

     def forward(self, x):

-        if torch.is_grad_enabled():
+        if self.dequant:

-            # Untested
-            out = cuda_ext.ExAutogradMatmul4bitCuda.apply(x, self.qweight, self.scales, self.qzeros, self.groupsize, self.bits, self.maxq)
+            out = torch.matmul(x, self.qweight_dequant)

         else:

-            out = cuda_ext.matmul_q4v2(x, self.quant_args(), _matmul_switch(self.config, x))
-        if self.bias is not None: out += self.bias
+            if torch.is_grad_enabled():

-        # if self.key == "model.layers.0.mlp.gate_proj":
-        #
-        #     _dump_tensor(x, "cuda_test/model.layers.0.mlp.gate_proj.x")
-        #     sys.exit()
+                # Untested
+                out = cuda_ext.ExAutogradMatmul4bitCuda.apply(x, self.qweight, self.scales, self.qzeros, self.groupsize, self.bits, self.maxq)
+
+            else:
+
+                out = cuda_ext.matmul_q4v2(x, self.quant_args(), _matmul_switch(self.config, x))
+            if self.bias is not None: out += self.bias
+
+            # if self.key == "model.layers.0.mlp.gate_proj":
+            #
+            #     _dump_tensor(x, "cuda_test/model.layers.0.mlp.gate_proj.x")
+            #     sys.exit()

         return out
@@ -300,20 +327,22 @@ def dump(self, filename):

 class ExLlamaMLP(nn.Module):

-    def __init__(self, config, tensors, key):
+    def __init__(self, config, tensors, key, dequant = False):
         super().__init__()

         self.config = config
+        self.dequant = dequant

-        self.gate_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.intermediate_size, False, tensors, key + ".gate_proj")
-        self.up_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.intermediate_size, False, tensors, key + ".up_proj")
-        self.down_proj = Ex4bitLinear(config, self.config.intermediate_size, self.config.hidden_size, False, tensors, key + ".down_proj")
+        self.gate_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.intermediate_size, False, tensors, key + ".gate_proj", dequant = dequant)
+        self.up_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.intermediate_size, False, tensors, key + ".up_proj", dequant = dequant)
+        self.down_proj = Ex4bitLinear(config, self.config.intermediate_size, self.config.hidden_size, False, tensors, key + ".down_proj", dequant = dequant)

         self.act_fn = nn.SiLU()


     def forward_fused(self, x, rms_norm_weight, buffer):

+        assert not self.dequant
         x = cuda_ext.mlp_q4v2(x,
                               buffer.x_temp,
                               buffer.x_col_temp,
@@ -367,18 +396,18 @@ def forward(self, hidden_states, buffer):

 class ExLlamaAttention(nn.Module):

-    def __init__(self, config, tensors, key, sin, cos, index):
+    def __init__(self, config, tensors, key, sin, cos, index, dequant = False):
         super().__init__()

         self.config = config
         self.sin = sin
         self.cos = cos
         self.index = index

-        self.q_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.num_attention_heads * self.config.head_dim, False, tensors, key + ".q_proj")
-        self.k_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.num_attention_heads * self.config.head_dim, False, tensors, key + ".k_proj")
-        self.v_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.num_attention_heads * self.config.head_dim, False, tensors, key + ".v_proj")
-        self.o_proj = Ex4bitLinear(config, self.config.num_attention_heads * self.config.head_dim, self.config.hidden_size, False, tensors, key + ".o_proj")
+        self.q_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.num_attention_heads * self.config.head_dim, False, tensors, key + ".q_proj", dequant = dequant)
+        self.k_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.num_attention_heads * self.config.head_dim, False, tensors, key + ".k_proj", dequant = dequant)
+        self.v_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.num_attention_heads * self.config.head_dim, False, tensors, key + ".v_proj", dequant = dequant)
+        self.o_proj = Ex4bitLinear(config, self.config.num_attention_heads * self.config.head_dim, self.config.hidden_size, False, tensors, key + ".o_proj", dequant = dequant)


     def forward(self, hidden_states, cache, buffer):
@@ -467,14 +496,14 @@ def rotate_half(x):

 class ExLlamaDecoderLayer(nn.Module):

-    def __init__(self, config, tensors, key, index, sin, cos):
+    def __init__(self, config, tensors, key, index, sin, cos, dequant = False):
         super().__init__()

         self.config = config
         self.index = index

-        self.self_attn = ExLlamaAttention(self.config, tensors, key + ".self_attn", sin, cos, self.index)
-        self.mlp = ExLlamaMLP(self.config, tensors, key + ".mlp")
+        self.self_attn = ExLlamaAttention(self.config, tensors, key + ".self_attn", sin, cos, self.index, dequant = dequant)
+        self.mlp = ExLlamaMLP(self.config, tensors, key + ".mlp", dequant = dequant)

         self.input_layernorm = ExLlamaRMSNorm(self.config, tensors, key + ".input_layernorm.weight")
         self.post_attention_layernorm = ExLlamaRMSNorm(self.config, tensors, key + ".post_attention_layernorm.weight")
@@ -487,7 +516,9 @@ def forward(self, hidden_states, cache, buffer):
         hidden_states = self.self_attn(hidden_states, cache, buffer)
         hidden_states = residual + hidden_states

-        if _mlp_switch(self.config, hidden_states):
+        # TODO: Support dequantized layer in fused MLP. Also, finish implementing fused MLP
+
+        if self.mlp.dequant or _mlp_switch(self.config, hidden_states):

             residual = hidden_states
             hidden_states = self.post_attention_layernorm(hidden_states, buffer)
@@ -741,6 +772,17 @@ def to(self, device):
         return new


+def _device_to_int(device):
+
+    return int(device[device.find(":") + 1:])
+
+def _skip_key(key):
+
+    if key.endswith("_proj.bias"): return True
+    if key.endswith(".rotary_emb.inv_freq"): return True
+    return False
+
+
 class ExLlama(nn.Module):

     def __init__(self, config):
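A quick illustration of what the two new helpers return; the example key strings are chosen here for illustration, not taken from the diff.

assert _device_to_int("cuda:1") == 1                                # strips the "cuda:" prefix
assert _skip_key("model.layers.0.self_attn.q_proj.bias")            # unused, empty bias tensors
assert _skip_key("model.layers.0.self_attn.rotary_emb.inv_freq")    # recomputed during init anyway
assert not _skip_key("model.layers.0.self_attn.q_proj.qweight")     # everything else is loaded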
@@ -762,8 +804,10 @@ def __init__(self, config):
         # Begin auto mapping if enabled

         decoder_size = 0
+        decoder_dq_size = 0
         norm_size = 0
         head_size = 0
+        half_element_size = torch.tensor([], dtype = torch.float16).element_size()

         if self.config.auto_map is not None:
@@ -772,9 +816,15 @@

                 for key in f.keys():

+                    if _skip_key(key): continue
+
                     if key.startswith("model.layers.0."):
                         tensor = f.get_tensor(key)
                         decoder_size += tensor.numel() * tensor.element_size()
+                        if key.endswith(".weight"):
+                            decoder_dq_size += tensor.numel() * tensor.element_size()
+                        if key.endswith(".qweight"):
+                            decoder_dq_size += tensor.numel() * 8 * half_element_size

                     if key.startswith("model.norm."):
                         tensor = f.get_tensor(key)
@@ -784,37 +834,40 @@
                         tensor = f.get_tensor(key)
                         head_size += tensor.numel() * tensor.element_size()

-            # Assign layers automatically
+                # Assign layers automatically

-            device_usage = 0
-            device_index = 0
-            max_usage = self.config.auto_map[device_index] * (1024 ** 3)
+                device_usage = 0
+                device_index = 0
+                layer_index_device = 0
+                max_usage = self.config.auto_map[device_index] * (1024 ** 3)

-            for layer in range(self.config.num_hidden_layers + 2):
+                for layer in range(self.config.num_hidden_layers + 2):

-                this_layer_size = decoder_size
-                if layer == self.config.num_hidden_layers + 0: this_layer_size = norm_size
-                if layer == self.config.num_hidden_layers + 1: this_layer_size = head_size
+                    this_layer_size = decoder_size
+                    if layer == self.config.num_hidden_layers + 0: this_layer_size = norm_size
+                    elif layer == self.config.num_hidden_layers + 1: this_layer_size = head_size
+                    elif self.config.dequant is not None and layer_index_device < self.config.dequant[device_index]: this_layer_size = decoder_dq_size

-                while device_usage + this_layer_size > max_usage:
-                    device_index += 1
-                    device_usage = 0
-                    max_usage = self.config.auto_map[device_index] * (1024 ** 3)
-                    if device_index >= len(self.config.auto_map): raise ValueError("Model too large for device allocation scheme.")
+                    while device_usage + this_layer_size > max_usage:
+                        device_index += 1
+                        device_usage = 0
+                        layer_index_device = 0
+                        max_usage = self.config.auto_map[device_index] * (1024 ** 3)
+                        if device_index >= len(self.config.auto_map): raise ValueError("Model too large for device allocation scheme.")

-                target = f"cuda:{device_index}"
-                if layer == self.config.num_hidden_layers + 0: self.config.device_map.norm = target
-                elif layer == self.config.num_hidden_layers + 1: self.config.device_map.lm_head = target
-                else: self.config.device_map.layers[layer] = f"cuda:{device_index}"
+                    target = f"cuda:{device_index}"
+                    if layer == self.config.num_hidden_layers + 0: self.config.device_map.norm = target
+                    elif layer == self.config.num_hidden_layers + 1: self.config.device_map.lm_head = target
+                    else: self.config.device_map.layers[layer] = f"cuda:{device_index}"

-                device_usage += this_layer_size
+                    device_usage += this_layer_size
+                    layer_index_device += 1

-        # Load tensors to
+            # Load tensors, move to device(s)

             for key in f.keys():

-                if key.endswith("_proj.bias"): continue # Skip loading unused, empty bias tensors
-                if key.endswith(".rotary_emb.inv_freq"): continue # This is always precomputed during init anyway
+                if _skip_key(key): continue

                 device = self.config.device_map.map(key, loading = True)
                 tensor = f.get_tensor(key)
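To make the sizing above concrete, here is a back-of-the-envelope sketch of how the auto-map loop budgets a device. All numbers are invented placeholders, not measurements from this commit; the point is only that a dequantized layer is counted at decoder_dq_size rather than decoder_size.

# Hypothetical allocation arithmetic mirroring the loop above; sizes are invented.
GB = 1024 ** 3
decoder_size    = 0.22 * GB      # assumed per-layer footprint with packed 4-bit weights
decoder_dq_size = 0.85 * GB      # assumed footprint once .qweight is expanded to fp16
auto_map = [10, 24]              # GB per device, as set by set_auto_map("10,24")
dequant  = [4, 0]                # layers to dequantize per device, as set by set_dequant("4,0")

budget = auto_map[0] * GB
used_by_dequant = dequant[0] * decoder_dq_size
quantized_fit = int((budget - used_by_dequant) // decoder_size)
print(f"cuda:0 holds {dequant[0]} dequantized layers plus ~{quantized_fit} quantized ones")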
@@ -845,8 +898,10 @@ def __init__(self, config):

         # Prepare position embeddings for max seq length

+        devs = self.config.device_map.get_layers_devs()
+
         self.sincos = {}
-        for device in self.config.device_map.get_layers_devs():
+        for device in devs:

             inv_freq = 1.0 / (self.config.rotary_embedding_base ** (torch.arange(0, self.config.head_dim, 2, device = device).float() / self.config.head_dim))
             t = torch.arange(self.config.max_seq_len, device = device, dtype = torch.float32)
@@ -863,12 +918,21 @@ def __init__(self, config):
         layer_streaming = self.config.stream_layer_interval > 0

         modules = []
+        device_layer_index = [0] * len(devs)
+
         for i in range(self.config.num_hidden_layers):

             device = self.config.device_map.layers[i]
             sin, cos = self.sincos[device]

-            layer = ExLlamaDecoderLayer(self.config, tensors, f"model.layers.{i}", i, sin, cos)
+            dequant = False
+            if self.config.dequant is not None:
+                device_idx = _device_to_int(device)
+                device_layer = device_layer_index[device_idx]
+                device_layer_index[device_idx] += 1
+                if device_layer < self.config.dequant[device_idx]: dequant = True
+
+            layer = ExLlamaDecoderLayer(self.config, tensors, f"model.layers.{i}", i, sin, cos, dequant = dequant)

             if layer_streaming and i > 0 and (i + 1) % self.config.stream_layer_interval == 0:
                 if self.stream_buffer is None: self.stream_buffer = ExLlamaStreamer(self.config, layer) # Use first layer as prototype
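The per-device counting above means config.dequant is interpreted as "the first N layers placed on each device", not absolute layer indices. A small standalone sketch of the same bookkeeping, with invented inputs:

# Mirrors the device_layer_index bookkeeping above; layer/device assignments are invented.
layers = ["cuda:0"] * 4 + ["cuda:1"] * 4       # pretend device_map.layers for 8 layers
dequant = [3, 1]                               # dequantize first 3 on cuda:0, first 1 on cuda:1

counts = [0, 0]
flags = []
for device in layers:
    idx = int(device.split(":")[1])
    flags.append(counts[idx] < dequant[idx])
    counts[idx] += 1

print(flags)   # [True, True, True, False, True, False, False, False]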
