@@ -64,6 +64,18 @@ def is_fp4(quantization_args: QuantizationArgs):
         and quantization_args.type == QuantizationType.FLOAT
     )
 
+def get_power_of_two(x):
+    # Snap each element to the nearest value in {0, 1, 2, 4, 8, 16, 32, 64, 128}
+    # (0 is included as a floor even though it is not a power of two)
+    powers = torch.tensor(
+        [0, 1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8, device=x.device
+    )
+
+    # Expand and compute distances; cast to int16 first so the uint8
+    # subtraction cannot underflow
+    diff = (x.unsqueeze(-1).to(torch.int16) - powers.to(torch.int16)).abs()
+
+    # Find nearest index
+    nearest_idx = diff.argmin(dim=-1)
+
+    return powers[nearest_idx]
+
+
 
 def calculate_qparams(
     min_vals: Tensor,
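
As a quick illustration of the rounding behavior, here is a self-contained sketch of what `get_power_of_two` computes (the input values are made up; the logic mirrors the helper above):

```python
import torch

# Candidate values the helper snaps to
powers = torch.tensor([0, 1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8)

x = torch.tensor([0, 5, 10, 20, 100, 250], dtype=torch.uint8)

# int16 cast avoids uint8 underflow when subtracting
diff = (x.unsqueeze(-1).to(torch.int16) - powers.to(torch.int16)).abs()
rounded = powers[diff.argmin(dim=-1)]
print(rounded)  # tensor([  0,   4,   8,  16, 128, 128], dtype=torch.uint8)
```

Exactly equidistant inputs (e.g. 3, between 2 and 4) resolve to the smaller candidate, since `argmin` returns the index of the first minimal value.
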
@@ -94,33 +106,50 @@ def calculate_qparams(
     bit_range = bit_max - bit_min
 
     if is_fp4(quantization_args=quantization_args):
-        zp_dtype = FP8_E4M3_DATA.dtype
+        if quantization_args.group_size == 16:
+            zp_dtype = FP8_E4M3_DATA.dtype
+        else:
+            # other group sizes (e.g. 32) store zero points as uint8
+            zp_dtype = torch.uint8
     else:
         zp_dtype = quantization_args.pytorch_dtype()
 
     if quantization_args.symmetric:
         max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
 
-        if is_fp4(quantization_args=quantization_args) and global_scale is not None:
-            # Conditionally scale the generated local scale by a global_scale
-            scales = global_scale * (max_val_pos / FP4_E2M1_DATA.max)
-            scales = torch.clamp(scales, max=FP8_E4M3_DATA.max, min=FP8_E4M3_DATA.min)
-            scales = scales.to(FP8_E4M3_DATA.dtype)
+        if is_fp4(quantization_args=quantization_args):
+            if global_scale is not None:
+                # Conditionally scale the generated local scale by a global_scale
+                scales = global_scale * (max_val_pos / FP4_E2M1_DATA.max)
+                scales = torch.clamp(
+                    scales, max=FP8_E4M3_DATA.max, min=FP8_E4M3_DATA.min
+                )
+                scales = scales.to(FP8_E4M3_DATA.dtype)
+            else:
+                # Without a global_scale, map the max magnitude onto the uint8
+                # range and snap the result to the nearest power of two
+                scales = torch.iinfo(torch.uint8).max * max_val_pos  # / FP4_E2M1_DATA.max)
+                scales = torch.clamp(
+                    scales,
+                    max=torch.iinfo(torch.uint8).max,
+                    min=torch.iinfo(torch.uint8).min,
+                )
+                scales = scales.to(torch.uint8)
+                scales = get_power_of_two(scales)
 
         else:
             scales = max_val_pos / (float(bit_range) / 2)
 
         # TODO: in the case of MoEs, the global_scale may also be 0/need to be clamped
-        if scales.dtype == FP8_E4M3_DATA.dtype:
-            # torch.clamp not supported for FP8
-            # use the next largest fp8 value from 0
-            scales = torch.where(
-                scales == 0,
-                torch.tensor(0.125, dtype=FP8_E4M3_DATA.dtype, device=device),
-                scales,
-            )
-        else:
-            scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
+        # if scales.dtype == FP8_E4M3_DATA.dtype:
+        #     # torch.clamp not supported for FP8
+        #     # use the next largest fp8 value from 0
+        #     scales = torch.where(
+        #         scales == 0,
+        #         torch.tensor(0.125, dtype=FP8_E4M3_DATA.dtype, device=device),
+        #         scales,
+        #     )
+        # else:
+        #     scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
 
         zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)
     else:
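
To see the new FP4 path end to end, here is a self-contained sketch of the symmetric, no-`global_scale` branch with made-up per-group max magnitudes (`get_power_of_two` is reproduced from the first hunk; real inputs come from the observed min/max values):

```python
import torch

def get_power_of_two(x):
    powers = torch.tensor(
        [0, 1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8, device=x.device
    )
    diff = (x.unsqueeze(-1).to(torch.int16) - powers.to(torch.int16)).abs()
    return powers[diff.argmin(dim=-1)]

# Illustrative per-group max magnitudes
max_val_pos = torch.tensor([0.002, 0.1, 0.45, 0.9])

scales = torch.iinfo(torch.uint8).max * max_val_pos  # map onto [0, 255]
scales = torch.clamp(
    scales, max=torch.iinfo(torch.uint8).max, min=torch.iinfo(torch.uint8).min
)
scales = scales.to(torch.uint8)    # cast truncates toward zero
scales = get_power_of_two(scales)  # snap to the nearest power of two
print(scales)  # tensor([  0,  32, 128, 128], dtype=torch.uint8)
```

Note that a small enough max (the first group above) collapses to a scale of 0, which the now-commented-out clamp previously guarded against; whether that guard needs reinstating for the uint8 path is worth checking.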