@@ -95,6 +95,7 @@ def __init__(self, model_config_path):
         self.mlp_method = self.MLPMethod.NORMAL  # Currently no benefit to fused MLP
         self.device_map = ExLlamaDeviceMap(self.num_hidden_layers)
         self.auto_map = None  # List of ints with memory allocation in GB, per CUDA device, overrides device_map
+        self.dequant = None  # Number of layers (per GPU) to de-quantize at load time


     # Parse and set list of GPU VRAM allocations
@@ -105,6 +106,14 @@ def set_auto_map(self, map_string):
         else: self.auto_map = [float(alloc) for alloc in map_string.split(",")]


+    # Parse and set number of layers to de-quantize at load, per GPU
+
+    def set_dequant(self, dq_string):
+
+        if dq_string is None: self.dequant = None
+        else: self.dequant = [int(alloc) for alloc in dq_string.split(",")]
+
+
 def _dump_tensor(t, name):

     if t is None:
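
For context, a minimal usage sketch of the new setting from caller code, assuming the ExLlamaConfig entry point these setters belong to; the path and all the numbers are placeholders:

    config = ExLlamaConfig("/path/to/config.json")  # placeholder path
    config.set_auto_map("10,24")                    # GB of VRAM to use per CUDA device
    config.set_dequant("2,0")                       # de-quantize the first two layers placed on cuda:0, none on cuda:1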
@@ -146,11 +155,12 @@ def _mlp_switch(config, x):

 class Ex4bitLinear(nn.Module):

-    def __init__(self, config, in_features, out_features, has_bias, tensors, key):
+    def __init__(self, config, in_features, out_features, has_bias, tensors, key, dequant = False):
         super().__init__()

         self.config = config
         self.key = key
+        self.dequant = dequant

         self.in_features = in_features
         self.out_features = out_features
@@ -210,6 +220,17 @@ def __init__(self, config, in_features, out_features, has_bias, tensors, key):

         if has_bias: self.bias = tensors[key + ".bias"]

+        # Optionally dequantize layer at init time
+
+        if self.dequant:
+
+            self.qweight_dequant = cuda_ext.dequantize_q4v2(self.quant_args())
+            self.qweight = None
+            self.scales = None
+            self.zeros = None
+            self.seq_g_idx = None
+            self.x_map = None
+

     def quant_args(self):

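
Worth spelling out the VRAM trade this makes: qweight packs eight 4-bit weights per int32 element, so the fp16 copy produced by dequantize_q4v2 is roughly four times the size of qweight alone, while the scales, zeros, g_idx and x_map references are dropped from the module in exchange (whether that memory is actually released depends on what else still holds those tensors). The same arithmetic reappears in the auto-map sizing further down as tensor.numel() * 8 * half_element_size. A rough sketch with an illustrative shape:

    import torch

    qweight = torch.zeros((512, 4096), dtype = torch.int32)  # packed: 8 x 4-bit weights per int32
    packed_bytes = qweight.numel() * qweight.element_size()  # 8 MiB
    fp16_bytes = qweight.numel() * 8 * torch.tensor([], dtype = torch.float16).element_size()  # 32 MiB
    print(packed_bytes >> 20, fp16_bytes >> 20)  # 8 32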
@@ -268,20 +289,26 @@ def load_streaming(self):

     def forward(self, x):

-        if torch.is_grad_enabled():
+        if self.dequant:

-            # Untested
-            out = cuda_ext.ExAutogradMatmul4bitCuda.apply(x, self.qweight, self.scales, self.qzeros, self.groupsize, self.bits, self.maxq)
+            out = torch.matmul(x, self.qweight_dequant)

         else:

-            out = cuda_ext.matmul_q4v2(x, self.quant_args(), _matmul_switch(self.config, x))
-            if self.bias is not None: out += self.bias
+            if torch.is_grad_enabled():

-            # if self.key == "model.layers.0.mlp.gate_proj":
-            #
-            #     _dump_tensor(x, "cuda_test/model.layers.0.mlp.gate_proj.x")
-            #     sys.exit()
+                # Untested
+                out = cuda_ext.ExAutogradMatmul4bitCuda.apply(x, self.qweight, self.scales, self.qzeros, self.groupsize, self.bits, self.maxq)
+
+            else:
+
+                out = cuda_ext.matmul_q4v2(x, self.quant_args(), _matmul_switch(self.config, x))
+                if self.bias is not None: out += self.bias
+
+                # if self.key == "model.layers.0.mlp.gate_proj":
+                #
+                #     _dump_tensor(x, "cuda_test/model.layers.0.mlp.gate_proj.x")
+                #     sys.exit()

         return out

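
A small shape sketch of the new de-quantized path (sizes illustrative, a CUDA device assumed): for torch.matmul(x, self.qweight_dequant) to work, dequantize_q4v2 has to hand back the weight laid out as (in_features, out_features) in the activation dtype:

    import torch

    x = torch.randn(1, 7, 4096, dtype = torch.float16, device = "cuda:0")   # (batch, seq, in_features)
    w = torch.randn(4096, 11008, dtype = torch.float16, device = "cuda:0")  # stand-in for qweight_dequant
    out = torch.matmul(x, w)                                                # (1, 7, out_features)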
@@ -300,20 +327,22 @@ def dump(self, filename):

 class ExLlamaMLP(nn.Module):

-    def __init__(self, config, tensors, key):
+    def __init__(self, config, tensors, key, dequant = False):
         super().__init__()

         self.config = config
+        self.dequant = dequant

-        self.gate_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.intermediate_size, False, tensors, key + ".gate_proj")
-        self.up_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.intermediate_size, False, tensors, key + ".up_proj")
-        self.down_proj = Ex4bitLinear(config, self.config.intermediate_size, self.config.hidden_size, False, tensors, key + ".down_proj")
+        self.gate_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.intermediate_size, False, tensors, key + ".gate_proj", dequant = dequant)
+        self.up_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.intermediate_size, False, tensors, key + ".up_proj", dequant = dequant)
+        self.down_proj = Ex4bitLinear(config, self.config.intermediate_size, self.config.hidden_size, False, tensors, key + ".down_proj", dequant = dequant)

         self.act_fn = nn.SiLU()


     def forward_fused(self, x, rms_norm_weight, buffer):

+        assert not self.dequant
         x = cuda_ext.mlp_q4v2(x,
                               buffer.x_temp,
                               buffer.x_col_temp,
@@ -367,18 +396,18 @@ def forward(self, hidden_states, buffer):

 class ExLlamaAttention(nn.Module):

-    def __init__(self, config, tensors, key, sin, cos, index):
+    def __init__(self, config, tensors, key, sin, cos, index, dequant = False):
         super().__init__()

         self.config = config
         self.sin = sin
         self.cos = cos
         self.index = index

-        self.q_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.num_attention_heads * self.config.head_dim, False, tensors, key + ".q_proj")
-        self.k_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.num_attention_heads * self.config.head_dim, False, tensors, key + ".k_proj")
-        self.v_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.num_attention_heads * self.config.head_dim, False, tensors, key + ".v_proj")
-        self.o_proj = Ex4bitLinear(config, self.config.num_attention_heads * self.config.head_dim, self.config.hidden_size, False, tensors, key + ".o_proj")
+        self.q_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.num_attention_heads * self.config.head_dim, False, tensors, key + ".q_proj", dequant = dequant)
+        self.k_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.num_attention_heads * self.config.head_dim, False, tensors, key + ".k_proj", dequant = dequant)
+        self.v_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.num_attention_heads * self.config.head_dim, False, tensors, key + ".v_proj", dequant = dequant)
+        self.o_proj = Ex4bitLinear(config, self.config.num_attention_heads * self.config.head_dim, self.config.hidden_size, False, tensors, key + ".o_proj", dequant = dequant)


     def forward(self, hidden_states, cache, buffer):
@@ -467,14 +496,14 @@ def rotate_half(x):

 class ExLlamaDecoderLayer(nn.Module):

-    def __init__(self, config, tensors, key, index, sin, cos):
+    def __init__(self, config, tensors, key, index, sin, cos, dequant = False):
         super().__init__()

         self.config = config
         self.index = index

-        self.self_attn = ExLlamaAttention(self.config, tensors, key + ".self_attn", sin, cos, self.index)
-        self.mlp = ExLlamaMLP(self.config, tensors, key + ".mlp")
+        self.self_attn = ExLlamaAttention(self.config, tensors, key + ".self_attn", sin, cos, self.index, dequant = dequant)
+        self.mlp = ExLlamaMLP(self.config, tensors, key + ".mlp", dequant = dequant)

         self.input_layernorm = ExLlamaRMSNorm(self.config, tensors, key + ".input_layernorm.weight")
         self.post_attention_layernorm = ExLlamaRMSNorm(self.config, tensors, key + ".post_attention_layernorm.weight")
@@ -487,7 +516,9 @@ def forward(self, hidden_states, cache, buffer):
         hidden_states = self.self_attn(hidden_states, cache, buffer)
         hidden_states = residual + hidden_states

-        if _mlp_switch(self.config, hidden_states):
+        # TODO: Support dequantized layer in fused MLP. Also, finish implementing fused MLP
+
+        if self.mlp.dequant or _mlp_switch(self.config, hidden_states):

             residual = hidden_states
             hidden_states = self.post_attention_layernorm(hidden_states, buffer)
@@ -741,6 +772,17 @@ def to(self, device):
         return new


+def _device_to_int(device):
+
+    return int(device[device.find(":") + 1:])
+
+def _skip_key(key):
+
+    if key.endswith("_proj.bias"): return True
+    if key.endswith(".rotary_emb.inv_freq"): return True
+    return False
+
+
 class ExLlama(nn.Module):

     def __init__(self, config):
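
The two new module-level helpers centralize checks the loader already did inline; their behaviour follows directly from the definitions above:

    assert _device_to_int("cuda:1") == 1
    assert _skip_key("model.layers.0.self_attn.q_proj.bias")          # unused, empty bias tensors
    assert _skip_key("model.layers.0.self_attn.rotary_emb.inv_freq")  # recomputed during init anyway
    assert not _skip_key("model.layers.0.self_attn.q_proj.qweight")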
@@ -762,8 +804,10 @@ def __init__(self, config):
             # Begin auto mapping if enabled

             decoder_size = 0
+            decoder_dq_size = 0
             norm_size = 0
             head_size = 0
+            half_element_size = torch.tensor([], dtype = torch.float16).element_size()

             if self.config.auto_map is not None:

@@ -772,9 +816,15 @@ def __init__(self, config):

                 for key in f.keys():

+                    if _skip_key(key): continue
+
                     if key.startswith("model.layers.0."):
                         tensor = f.get_tensor(key)
                         decoder_size += tensor.numel() * tensor.element_size()
+                        if key.endswith(".weight"):
+                            decoder_dq_size += tensor.numel() * tensor.element_size()
+                        if key.endswith(".qweight"):
+                            decoder_dq_size += tensor.numel() * 8 * half_element_size


                     if key.startswith("model.norm."):
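
To make the two suffix checks concrete (key names follow the usual Llama safetensors layout): .weight matches tensors that are stored unquantized and keep their on-disk size, while .qweight matches the packed 4-bit matrices, which are counted at their future fp16 size:

    for k in ["model.layers.0.input_layernorm.weight",
              "model.layers.0.self_attn.q_proj.qweight",
              "model.layers.0.self_attn.q_proj.scales"]:
        print(k, k.endswith(".weight"), k.endswith(".qweight"))
    # -> True False / False True / False False: only the first two add to decoder_dq_size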
@@ -784,37 +834,40 @@ def __init__(self, config):
                         tensor = f.get_tensor(key)
                         head_size += tensor.numel() * tensor.element_size()

-                # Assign layers automatically
+                # Assign layers automatically

-                device_usage = 0
-                device_index = 0
-                max_usage = self.config.auto_map[device_index] * (1024 ** 3)
+                device_usage = 0
+                device_index = 0
+                layer_index_device = 0
+                max_usage = self.config.auto_map[device_index] * (1024 ** 3)

-                for layer in range(self.config.num_hidden_layers + 2):
+                for layer in range(self.config.num_hidden_layers + 2):

-                    this_layer_size = decoder_size
-                    if layer == self.config.num_hidden_layers + 0: this_layer_size = norm_size
-                    if layer == self.config.num_hidden_layers + 1: this_layer_size = head_size
+                    this_layer_size = decoder_size
+                    if layer == self.config.num_hidden_layers + 0: this_layer_size = norm_size
+                    elif layer == self.config.num_hidden_layers + 1: this_layer_size = head_size
+                    elif self.config.dequant is not None and layer_index_device < self.config.dequant[device_index]: this_layer_size = decoder_dq_size

-                    while device_usage + this_layer_size > max_usage:
-                        device_index += 1
-                        device_usage = 0
-                        max_usage = self.config.auto_map[device_index] * (1024 ** 3)
-                        if device_index >= len(self.config.auto_map): raise ValueError("Model too large for device allocation scheme.")
+                    while device_usage + this_layer_size > max_usage:
+                        device_index += 1
+                        device_usage = 0
+                        layer_index_device = 0
+                        max_usage = self.config.auto_map[device_index] * (1024 ** 3)
+                        if device_index >= len(self.config.auto_map): raise ValueError("Model too large for device allocation scheme.")

-                    target = f"cuda:{device_index}"
-                    if layer == self.config.num_hidden_layers + 0: self.config.device_map.norm = target
-                    elif layer == self.config.num_hidden_layers + 1: self.config.device_map.lm_head = target
-                    else: self.config.device_map.layers[layer] = f"cuda:{device_index}"
+                    target = f"cuda:{device_index}"
+                    if layer == self.config.num_hidden_layers + 0: self.config.device_map.norm = target
+                    elif layer == self.config.num_hidden_layers + 1: self.config.device_map.lm_head = target
+                    else: self.config.device_map.layers[layer] = f"cuda:{device_index}"

-                    device_usage += this_layer_size
+                    device_usage += this_layer_size
+                    layer_index_device += 1

-            # Load tensors to
+            # Load tensors, move to device(s)

             for key in f.keys():

-                if key.endswith("_proj.bias"): continue  # Skip loading unused, empty bias tensors
-                if key.endswith(".rotary_emb.inv_freq"): continue  # This is always precomputed during init anyway
+                if _skip_key(key): continue

                 device = self.config.device_map.map(key, loading = True)
                 tensor = f.get_tensor(key)
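
A standalone sketch of the placement logic above, with the safetensors bookkeeping stripped out; the function name and all sizes are made up, and the norm/lm_head entries are omitted. It just shows how the first dequant[d] decoder layers placed on device d get budgeted at the larger de-quantized size:

    def plan_devices(num_layers, layer_size, layer_dq_size, auto_map_gb, dequant):
        # Greedy placement in the spirit of the loop above
        plan, device, used, on_device = [], 0, 0, 0
        for _ in range(num_layers):
            dq = dequant is not None and on_device < dequant[device]
            size = layer_dq_size if dq else layer_size
            while used + size > auto_map_gb[device] * 1024 ** 3:
                device += 1
                used = 0
                on_device = 0
                if device >= len(auto_map_gb): raise ValueError("Model too large for device allocation scheme.")
            plan.append(device)
            used += size
            on_device += 1
        return plan

    # 32 layers of 400 MiB packed / 1600 MiB de-quantized, a 6 GB and a 24 GB card,
    # first two layers on cuda:0 de-quantized:
    print(plan_devices(32, 400 * 2 ** 20, 1600 * 2 ** 20, [6, 24], [2, 0]))
    # -> nine layers land on device 0, the remaining 23 on device 1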
@@ -845,8 +898,10 @@ def __init__(self, config):

         # Prepare position embeddings for max seq length

+        devs = self.config.device_map.get_layers_devs()
+
         self.sincos = {}
-        for device in self.config.device_map.get_layers_devs():
+        for device in devs:

             inv_freq = 1.0 / (self.config.rotary_embedding_base ** (torch.arange(0, self.config.head_dim, 2, device = device).float() / self.config.head_dim))
             t = torch.arange(self.config.max_seq_len, device = device, dtype = torch.float32)
@@ -863,12 +918,21 @@ def __init__(self, config):
         layer_streaming = self.config.stream_layer_interval > 0

         modules = []
+        device_layer_index = [0] * len(devs)
+
         for i in range(self.config.num_hidden_layers):

             device = self.config.device_map.layers[i]
             sin, cos = self.sincos[device]

-            layer = ExLlamaDecoderLayer(self.config, tensors, f"model.layers.{i}", i, sin, cos)
+            dequant = False
+            if self.config.dequant is not None:
+                device_idx = _device_to_int(device)
+                device_layer = device_layer_index[device_idx]
+                device_layer_index[device_idx] += 1
+                if device_layer < self.config.dequant[device_idx]: dequant = True
+
+            layer = ExLlamaDecoderLayer(self.config, tensors, f"model.layers.{i}", i, sin, cos, dequant = dequant)

             if layer_streaming and i > 0 and (i + 1) % self.config.stream_layer_interval == 0:
                 if self.stream_buffer is None: self.stream_buffer = ExLlamaStreamer(self.config, layer)  # Use first layer as prototype
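
For review purposes, a hypothetical check of which layers end up de-quantized, reusing the same per-device counting as the loop above (the device assignments and counts are made up): the first dequant[d] layers assigned to device d, in load order, are the ones built with dequant = True:

    layers_devs = ["cuda:0"] * 20 + ["cuda:1"] * 12
    dequant_per_dev = [3, 1]
    seen = [0, 0]
    dequantized = []
    for i, dev in enumerate(layers_devs):
        d = int(dev.split(":")[1])
        if seen[d] < dequant_per_dev[d]: dequantized.append(i)
        seen[d] += 1
    print(dequantized)  # [0, 1, 2, 20]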