
Commit 868655f: add mxfp4 calibration support
1 parent d1b315f

File tree: 3 files changed, +53 -18 lines

src/compressed_tensors/quantization/lifecycle/forward.py
Lines changed: 3 additions & 1 deletion

@@ -468,6 +468,7 @@ def _quantize(
     if global_scale is not None:
         scale = scale.to(global_scale.dtype) / global_scale

+    scale = scale.to(x.dtype) / torch.iinfo(torch.uint8).max
     scaled = x / scale

     if zero_point is not None:
@@ -501,6 +502,8 @@ def _dequantize(
     if global_scale is not None:
         scale = scale.to(global_scale.dtype) / global_scale

+    scale = scale.to(torch.float16) / torch.iinfo(torch.uint8).max
+
     dequant_value = x_q.to(scale.dtype)

     if zero_point is not None:
@@ -510,5 +513,4 @@ def _dequantize(

     if dtype is not None:
         dequant_value = dequant_value.to(dtype)
-
     return dequant_value
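Note: below is a minimal sketch of the scale round trip these hunks appear to implement for the MXFP4 (group_size 32) path. The helper names are hypothetical; only the arithmetic (multiply the calibrated amax by 255, snap to a power of two, store as uint8, then divide by torch.iinfo(torch.uint8).max at quantize/dequantize time) is taken from this diff.

```python
import torch

def encode_scale(amax: torch.Tensor) -> torch.Tensor:
    # Calibration side (mirrors the helpers.py change below): scale the per-group
    # amax by 255, clamp to the uint8 range, then snap to the nearest power of two.
    powers = torch.tensor([0, 1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8)
    code = torch.clamp(255 * amax, min=0, max=255).to(torch.uint8)
    idx = (code.unsqueeze(-1).to(torch.int16) - powers.to(torch.int16)).abs().argmin(dim=-1)
    return powers[idx]

def decode_scale(code: torch.Tensor, dtype: torch.dtype = torch.float16) -> torch.Tensor:
    # Quantize/dequantize side (this file): divide the stored uint8 code by
    # torch.iinfo(torch.uint8).max to recover an approximate per-group scale.
    return code.to(dtype) / torch.iinfo(torch.uint8).max

amax = torch.tensor([0.01, 0.1, 0.5])  # per-group |x| maxima from calibration
code = encode_scale(amax)              # tensor([  2,  32, 128], dtype=torch.uint8)
print(decode_scale(code))              # roughly [0.0078, 0.1255, 0.5020]
```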

src/compressed_tensors/quantization/lifecycle/initialize.py
Lines changed: 5 additions & 1 deletion

@@ -248,7 +248,11 @@ def initialize_qparams(
         scale_dtype = observed_dtype

     if is_fp4(quantization_args=quantization_args):
-        scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype
+        if quantization_args.group_size == 16:
+            scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype
+        else:
+            # group_size 32
+            scale_dtype = zp_dtype = torch.uint8
     else:
         # TODO: consider erroring out in the future as if the dtype if not one of these,
         # there is likely bug
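Note: group_size appears to select between the two FP4 schemes here: group_size 16 keeps the existing FP8 E4M3 scale/zero-point dtype (NVFP4-style), while group_size 32 (the MXFP4 case this commit adds) stores scales as uint8 power-of-two codes. A minimal restatement, assuming FP8_E4M3_DATA.dtype resolves to torch.float8_e4m3fn:

```python
import torch

# Illustrative only; the real code reads FP8_E4M3_DATA.dtype from compressed_tensors
# rather than naming torch.float8_e4m3fn directly.
def fp4_scale_dtype(group_size: int) -> torch.dtype:
    if group_size == 16:
        # NVFP4-style: per-group scales/zero-points stored as FP8 E4M3
        return torch.float8_e4m3fn
    # group_size 32 (MXFP4-style): scales stored as uint8 power-of-two codes
    return torch.uint8

print(fp4_scale_dtype(16), fp4_scale_dtype(32))  # torch.float8_e4m3fn torch.uint8
```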

src/compressed_tensors/quantization/utils/helpers.py
Lines changed: 45 additions & 16 deletions

@@ -64,6 +64,18 @@ def is_fp4(quantization_args: QuantizationArgs):
         and quantization_args.type == QuantizationType.FLOAT
     )

+def get_power_of_two(x):
+    powers = torch.tensor([0, 1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8).to(x.device)
+
+    # Expand and compute distances
+    diff = (x.unsqueeze(-1).to(torch.int16) - powers.to(torch.int16)).abs()
+
+    # Find nearest index
+    nearest_idx = diff.argmin(dim=-1)
+
+    return powers[nearest_idx]
+
+

 def calculate_qparams(
     min_vals: Tensor,
@@ -94,33 +106,50 @@ def calculate_qparams(
     bit_range = bit_max - bit_min

     if is_fp4(quantization_args=quantization_args):
-        zp_dtype = FP8_E4M3_DATA.dtype
+        if quantization_args.group_size == 16:
+            zp_dtype = FP8_E4M3_DATA.dtype
+        else:
+            # group_size 32
+            zp_dtype = torch.uint8
     else:
         zp_dtype = quantization_args.pytorch_dtype()

     if quantization_args.symmetric:
         max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))

-        if is_fp4(quantization_args=quantization_args) and global_scale is not None:
-            # Conditionally scale the generated local scale by a global_scale
-            scales = global_scale * (max_val_pos / FP4_E2M1_DATA.max)
-            scales = torch.clamp(scales, max=FP8_E4M3_DATA.max, min=FP8_E4M3_DATA.min)
-            scales = scales.to(FP8_E4M3_DATA.dtype)
+        if is_fp4(quantization_args=quantization_args):
+            if global_scale is not None:
+                # Conditionally scale the generated local scale by a global_scale
+                scales = global_scale * (max_val_pos / FP4_E2M1_DATA.max)
+                scales = torch.clamp(
+                    scales, max=FP8_E4M3_DATA.max, min=FP8_E4M3_DATA.min
+                )
+                scales = scales.to(FP8_E4M3_DATA.dtype)
+            else:
+
+                scales = torch.iinfo(torch.uint8).max * (max_val_pos) # / FP4_E2M1_DATA.max)
+                scales = torch.clamp(
+                    scales,
+                    max=torch.iinfo(torch.uint8).max,
+                    min=torch.iinfo(torch.uint8).min,
+                )
+                scales = scales.to(torch.uint8)
+                scales = get_power_of_two(scales)

         else:
             scales = max_val_pos / (float(bit_range) / 2)

         # TODO: in the case of MoEs, the global_scale may also be 0/need to be clamped
-        if scales.dtype == FP8_E4M3_DATA.dtype:
-            # torch.clamp not supported for FP8
-            # use the next largest fp8 value from 0
-            scales = torch.where(
-                scales == 0,
-                torch.tensor(0.125, dtype=FP8_E4M3_DATA.dtype, device=device),
-                scales,
-            )
-        else:
-            scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
+        # if scales.dtype == FP8_E4M3_DATA.dtype:
+        # torch.clamp not supported for FP8
+        # use the next largest fp8 value from 0
+        # scales = torch.where(
+        #     scales == 0,
+        #     torch.tensor(0.125, dtype=FP8_E4M3_DATA.dtype, device=device),
+        #     scales,
+        # )
+        # else:
+        #     scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)

         zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)
     else:
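Note: a small usage sketch of the new calibration branch, reusing get_power_of_two exactly as added above; the sample tensor values are made up. Snapping to a power of two presumably approximates MX-style shared scales, which are restricted to powers of two.

```python
import torch

def get_power_of_two(x):
    # Copied from the hunk above: round each uint8 value to the nearest power of two.
    powers = torch.tensor([0, 1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8).to(x.device)
    diff = (x.unsqueeze(-1).to(torch.int16) - powers.to(torch.int16)).abs()
    return powers[diff.argmin(dim=-1)]

max_val_pos = torch.tensor([0.02, 0.3, 0.9])         # per-group amax from the observer
scales = torch.iinfo(torch.uint8).max * max_val_pos  # 255 * amax, as in the new branch
scales = torch.clamp(scales, max=255, min=0).to(torch.uint8)
print(get_power_of_two(scales))                      # tensor([  4,  64, 128], dtype=torch.uint8)
```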
