@@ -277,6 +277,7 @@ def __init__(
         self.params_dtype = params_dtype
         self.quant_config = quant_config
         self.prefix = prefix
+        self.allow_fp8_block_shape_mismatch = False
         if quant_config is None:
             self.quant_method: QuantizeMethodBase | None = UnquantizedLinearMethod()
         else:
@@ -475,6 +476,7 @@ def __init__(
             disable_tp=disable_tp,
         )

+        self._maybe_allow_fp8_block_shape_mismatch()
         self.gather_output = gather_output

         if output_sizes is None:
@@ -509,6 +511,33 @@ def __init__(
             self.register_parameter("bias", None)
         self.update_param_tp_status()

+    def _maybe_allow_fp8_block_shape_mismatch(self) -> None:
+        quant_config = getattr(self, "quant_config", None)
+        weight_block = getattr(quant_config, "weight_block_size", None)
+        if (
+            weight_block is None
+            or len(weight_block) < 1
+            or len(self.output_partition_sizes) <= 1
+        ):
+            return
+
+        try:
+            block_n = int(weight_block[0])
+        except (ValueError, TypeError):
+            return
+
+        if block_n <= 0:
+            return
+
+        if any(size % block_n != 0 for size in self.output_partition_sizes):
+            self.allow_fp8_block_shape_mismatch = True
+            logger.debug(
+                "Allowing FP8 block shape mismatch for %s (block_n=%d, partitions=%s)",
+                getattr(self, "prefix", "<unknown>"),
+                block_n,
+                self.output_partition_sizes,
+            )
+
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         output_dim = getattr(param, "output_dim", None)

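The helper only flips the flag when a block-quantized FP8 config is present and at least one fused output partition does not divide evenly by the block size along the output dimension. A minimal standalone sketch of that condition; the function name and the example shapes are illustrative, not part of the diff, and `weight_block_size` is assumed to follow the usual `[block_n, block_k]` convention:

def needs_block_shape_mismatch(weight_block_size, output_partition_sizes):
    # Mirror of the condition in _maybe_allow_fp8_block_shape_mismatch:
    # only fused layers (more than one partition) where some partition is
    # not a multiple of block_n need the relaxed handling.
    if not weight_block_size or len(output_partition_sizes) <= 1:
        return False
    try:
        block_n = int(weight_block_size[0])
    except (TypeError, ValueError):
        return False
    if block_n <= 0:
        return False
    return any(size % block_n != 0 for size in output_partition_sizes)

# Illustrative shapes: block_n=128 divides [1536, 1536] evenly,
# but not [1536, 1000].
assert needs_block_shape_mismatch([128, 128], [1536, 1536]) is False
assert needs_block_shape_mismatch([128, 128], [1536, 1000]) is True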
@@ -906,9 +935,11 @@ def __init__(
         *,
         return_bias: bool = True,
         disable_tp: bool = False,
+        v_head_size: int | None = None,
     ):
         self.hidden_size = hidden_size
         self.head_size = head_size
+        self.v_head_size = v_head_size if v_head_size is not None else head_size
         self.total_num_heads = total_num_heads
         if total_num_kv_heads is None:
             total_num_kv_heads = total_num_heads
@@ -924,12 +955,14 @@ def __init__(
             self.num_kv_head_replicas = 1
         input_size = self.hidden_size
         output_size = (
-            (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size
-        )
+            self.num_heads * self.head_size
+            + self.num_kv_heads * self.head_size
+            + self.num_kv_heads * self.v_head_size
+        ) * tp_size
         self.output_sizes = [
             self.num_heads * self.head_size * tp_size,  # q_proj
             self.num_kv_heads * self.head_size * tp_size,  # k_proj
-            self.num_kv_heads * self.head_size * tp_size,  # v_proj
+            self.num_kv_heads * self.v_head_size * tp_size,  # v_proj
         ]

         super().__init__(
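The rewritten `output_size` drops the old assumption that value heads share `head_size`: it sums the per-rank q, k, and v widths and scales by `tp_size`, so it stays equal to `sum(output_sizes)`. A short worked example with illustrative per-rank shapes (these numbers are not from this PR):

# Illustrative per-rank shapes: 16 q heads / 4 kv heads per rank,
# q/k head size 128, v head size 192, tensor-parallel size 2.
num_heads, num_kv_heads = 16, 4
head_size, v_head_size = 128, 192
tp_size = 2

output_size = (
    num_heads * head_size          # 2048 (q, per rank)
    + num_kv_heads * head_size     # 512  (k, per rank)
    + num_kv_heads * v_head_size   # 768  (v, per rank)
) * tp_size                        # 6656 fused width

output_sizes = [
    num_heads * head_size * tp_size,       # 4096 q_proj
    num_kv_heads * head_size * tp_size,    # 1024 k_proj
    num_kv_heads * v_head_size * tp_size,  # 1536 v_proj
]
assert output_size == sum(output_sizes)
# The old formula, (num_heads + 2 * num_kv_heads) * tp_size * head_size,
# would give 6144 here and undersize the v block.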
@@ -950,15 +983,16 @@ def _get_shard_offset_mapping(self, loaded_shard_id: str):
             "q": 0,
             "k": self.num_heads * self.head_size,
             "v": (self.num_heads + self.num_kv_heads) * self.head_size,
-            "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size,
+            "total": (self.num_heads + self.num_kv_heads) * self.head_size
+            + self.num_kv_heads * self.v_head_size,
         }
         return shard_offset_mapping.get(loaded_shard_id)

     def _get_shard_size_mapping(self, loaded_shard_id: str):
         shard_size_mapping = {
             "q": self.num_heads * self.head_size,
             "k": self.num_kv_heads * self.head_size,
-            "v": self.num_kv_heads * self.head_size,
+            "v": self.num_kv_heads * self.v_head_size,
         }
         return shard_size_mapping.get(loaded_shard_id)

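Only the v entry and the total width change: the per-rank fused weight keeps its [q | k | v] layout, so the q and k offsets are untouched and the v shard simply grows to `num_kv_heads * v_head_size`. Continuing the illustrative shapes above:

# Per-rank [q | k | v] layout with the same illustrative shapes as above.
q_offset, q_size = 0, num_heads * head_size                          # 0, 2048
k_offset, k_size = num_heads * head_size, num_kv_heads * head_size   # 2048, 512
v_offset = (num_heads + num_kv_heads) * head_size                    # 2560
v_size = num_kv_heads * v_head_size                                  # 768
total = v_offset + v_size                                            # 3328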
@@ -985,7 +1019,7 @@ def _load_fused_module_from_checkpoint(
             (
                 "v",
                 (self.total_num_heads + self.total_num_kv_heads) * self.head_size,
-                self.total_num_kv_heads * self.head_size,
+                self.total_num_kv_heads * self.v_head_size,
             ),
         ]

@@ -1110,7 +1144,7 @@ def weight_loader(
                 (
                     "v",
                     (self.total_num_heads + self.total_num_kv_heads) * self.head_size,
-                    self.total_num_kv_heads * self.head_size,
+                    self.total_num_kv_heads * self.v_head_size,
                 ),
             ]
             use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
@@ -1139,11 +1173,12 @@ def weight_loader(
                     "v": (
                         (self.total_num_heads + self.total_num_kv_heads)
                         * self.head_size,
-                        self.total_num_kv_heads * self.head_size,
+                        self.total_num_kv_heads * self.v_head_size,
                     ),
                     "total": (
-                        (self.total_num_heads + 2 * self.total_num_kv_heads)
-                        * self.head_size,
+                        (self.total_num_heads + self.total_num_kv_heads)
+                        * self.head_size
+                        + self.total_num_kv_heads * self.v_head_size,
                         0,
                     ),
                 }
@@ -1170,7 +1205,7 @@ def weight_loader(
                 shard_size = self.num_kv_heads * self.head_size
             elif loaded_shard_id == "v":
                 shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size
-                shard_size = self.num_kv_heads * self.head_size
+                shard_size = self.num_kv_heads * self.v_head_size
             # Special case for Quantized Weights.
             # If quantized, we need to adjust the offset and size to account
             # for the packing.
@@ -1199,10 +1234,11 @@ def weight_loader(
                     ),
                     "v": (
                         (self.num_heads + self.num_kv_heads) * self.head_size,
-                        self.num_kv_heads * self.head_size,
+                        self.num_kv_heads * self.v_head_size,
                     ),
                     "total": (
-                        (self.num_heads + 2 * self.num_kv_heads) * self.head_size,
+                        (self.num_heads + self.num_kv_heads) * self.head_size
+                        + self.num_kv_heads * self.v_head_size,
                         0,
                     ),
                 }
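A hedged usage sketch of the widened constructor, assuming this is vLLM's `model_executor/layers/linear.py`. Only the `v_head_size` keyword comes from this diff; the remaining keyword names and the shapes are assumptions based on the existing `QKVParallelLinear` signature and may differ:

# Hypothetical call site for attention whose value heads are wider than
# its query/key heads; a sketch, not the canonical API.
from vllm.model_executor.layers.linear import QKVParallelLinear

qkv_proj = QKVParallelLinear(
    hidden_size=4096,
    head_size=128,            # q/k head width
    total_num_heads=32,
    total_num_kv_heads=8,
    bias=False,
    quant_config=None,
    prefix="model.layers.0.self_attn.qkv_proj",
    v_head_size=192,          # new: v head width, defaults to head_size
)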