|
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from functools import lru_cache
-
+from loguru import logger as rich_logger
 import torch
 from auto_round.utils import logger
 from auto_round.config import global_config
@@ -67,7 +67,6 @@ def float8_e4m3fn_hpu_ste(x: torch.Tensor):
     return fp8


-
 @register_dtype("fp8_dynamic_per_token_sym")
 def fp8_dynamic_per_token_sym(tensor, max_scale=1.0, **kwargs):
     """Dynamic per-token symmetric quantization using float8.
@@ -200,7 +199,6 @@ def progressive_quant_fp8_int4_bas(tensor, bits=4, group_size=-1, v=0, min_scale
     return qdq_tensor, scale_fp8_to_int4 * scale_bf16_to_fp8, None


-
 ##ugly code, need to refine later

 @register_dtype("fp8_gaudi2_sym")
@@ -293,3 +291,60 @@ def progressive_quant_fp8_int4(tensor, bits=4, group_size=-1, v=0, min_scale=1.0
     qdq_tensor = qdq_int4_tensor * scale_bf16_to_fp8

     return qdq_tensor, (scale_fp8_to_int4 * scale_bf16_to_fp8, scale_bf16_to_fp8), zp_fp8_to_int4
+
+@register_dtype("fp8_gaudi2_to_int_sym_v2")
+def progressive_quant_fp8_int4_v2(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_scale=1.0, q_scale_thresh=1e-5,
+                                  weight_fp8_max_scale=1.0, **kwargs):
+    """Two-stage quantization: quantize the tensor to grouped int4 (e.g. w4g128), then to per-tensor fp8.
+
+    This method first quantizes the input tensor to int4 with grouping and then performs
+    a secondary per-tensor quantization of the result to float8.
+
+    Args:
+        tensor (torch.Tensor): Input tensor to quantize.
+        bits (int, optional): Bit precision for the int4 quantization. Defaults to 4.
+        group_size (int, optional): Group size for int4 quantization. Defaults to -1 (no grouping).
+        v (float, optional): Optional parameter for variance tuning. Defaults to 0.
+        min_scale (float, optional): Minimum scaling factor for int4 quantization. Defaults to 1.0.
+        max_scale (float, optional): Maximum scaling factor for int4 quantization. Defaults to 1.0.
+        q_scale_thresh (float, optional): Threshold for scaling. Defaults to 1e-5.
+        weight_fp8_max_scale (float, optional): Maximum scaling factor for float8 quantization. Defaults to 1.0.
+        **kwargs: Additional arguments for compatibility.
+
+    Returns:
+        tuple:
+            - Quantized and dequantized tensor (torch.Tensor).
+            - Tuple of scaling factors: (int4 scale, bf16-to-fp8 scale).
+            - Zero point from the int4 quantization.
+    """
+    # quantize to int4 first
+    from auto_round.data_type.int import quant_tensor_sym
+    qdq_int4_tensor, scale_bf16_to_int4, zp_fp8_to_int4 = quant_tensor_sym(
+        tensor,
+        bits=bits,
+        group_size=group_size,
+        v=v,
+        min_scale=min_scale,
+        max_scale=max_scale,
+        scale_dtype=torch.bfloat16,
+        q_scale_thresh=q_scale_thresh,
+    )
+    # FIXME(Yi): some fuse error here
+    torch._dynamo.graph_break()
+    fp8_max = STANDARD_FP8E4M3FN_MAX * global_config.FP8_WEIGHT_BACKOFF
+    tensor_max = torch.max(torch.abs(qdq_int4_tensor)).to(torch.float32) * weight_fp8_max_scale  ## better to train a ratio
+    scale = tensor_max.to(torch.float32) / fp8_max
+    min_scaling_factor = 1.0 / (fp8_max * 512.0)  ## copied from vLLM
+    scale_bf16_to_fp8 = torch.clip(scale, min=min_scaling_factor)
+    fp8_res = qdq_int4_tensor / scale_bf16_to_fp8
+    fp8_res = torch.clip(fp8_res, -fp8_max, fp8_max)
+    float8_e4m3fn_ste_gaudi2 = get_gaudi2_fp8_ste_func()
+    fp8_res = float8_e4m3fn_ste_gaudi2(fp8_res)
+
+    ## convert back to bf16
+    fp8_res_using_16bit = fp8_res.to(tensor.dtype)
+
+    qdq_tensor = fp8_res_using_16bit * scale_bf16_to_fp8
+
+    # return qdq_tensor, (scale_fp8_to_int4 * scale_bf16_to_fp8, scale_bf16_to_fp8), zp_fp8_to_int4
+    return qdq_tensor, (scale_bf16_to_int4, scale_bf16_to_fp8), zp_fp8_to_int4
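
A minimal usage sketch (reviewer note, not part of the diff): the module path `auto_round.data_type.fp8` and the availability of a Gaudi/HPU stack (so `get_gaudi2_fp8_ste_func()` resolves) are assumptions; the tensor shape and group size are illustrative only.

```python
# Hypothetical usage of the new "fp8_gaudi2_to_int_sym_v2" dtype; module path assumed.
import torch

from auto_round.data_type.fp8 import progressive_quant_fp8_int4_v2  # assumed location of this file

# A bf16 weight-like tensor; group_size=128 gives the w4g128 layout named in the docstring.
weight = torch.randn(4096, 4096, dtype=torch.bfloat16)

qdq, (scale_int4, scale_fp8), zp = progressive_quant_fp8_int4_v2(weight, bits=4, group_size=128)

print(qdq.dtype)          # same dtype as the input (bfloat16)
print(scale_int4.shape)   # one scale per 128-element group
print(float(scale_fp8))   # single per-tensor bf16 -> fp8 scale
```

Note the returned scales: the commented-out return keeps the v1-style fused scale (`scale_fp8_to_int4 * scale_bf16_to_fp8`), while v2 returns the int4 and fp8 scales separately so a consumer can apply each stage's dequantization on its own.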