
Commit 14a59a2

add v2

Signed-off-by: Yi Liu <[email protected]>

1 parent 2e27ca3 commit 14a59a2

File tree

2 files changed: +59 -3 lines


auto_round/autoround.py

Lines changed: 1 addition & 0 deletions
@@ -1075,6 +1075,7 @@ def quant_block(self, block, input_ids, input_others, q_input=None, device=torch
         total_loss = 0

         for i in range(self.iters):
+            logger.info(f"iter {i} / {self.iters}")
             total_loss = 0
             if self.sampler == "rand":
                 whole_indices = torch.randperm(nsamples)[:pick_samples]

auto_round/data_type/fp8.py

Lines changed: 58 additions & 3 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from functools import lru_cache
-
+from loguru import logger as rich_logger
 import torch
 from auto_round.utils import logger
 from auto_round.config import global_config
@@ -67,7 +67,6 @@ def float8_e4m3fn_hpu_ste(x: torch.Tensor):
     return fp8


-
 @register_dtype("fp8_dynamic_per_token_sym")
 def fp8_dynamic_per_token_sym(tensor, max_scale=1.0, **kwargs):
     """Dynamic per-token symmetric quantization using float8.
@@ -200,7 +199,6 @@ def progressive_quant_fp8_int4_bas(tensor, bits=4, group_size=-1, v=0, min_scale
     return qdq_tensor, scale_fp8_to_int4 * scale_bf16_to_fp8, None


-
 ##ugly code, need to refine later

 @register_dtype("fp8_gaudi2_sym")
@@ -293,3 +291,60 @@ def progressive_quant_fp8_int4(tensor, bits=4, group_size=-1, v=0, min_scale=1.0
     qdq_tensor = qdq_int4_tensor * scale_bf16_to_fp8

     return qdq_tensor, (scale_fp8_to_int4 * scale_bf16_to_fp8, scale_bf16_to_fp8), zp_fp8_to_int4
+
+@register_dtype("fp8_gaudi2_to_int_sym_v2")
+def progressive_quant_fp8_int4_v2(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_scale=1.0, q_scale_thresh=1e-5,
+                                  weight_fp8_max_scale=1.0, **kwargs):
+    """Two-stage quantization: quantize tensor to fp8 by per tensor, then quantize fp8 to w4g128
+
+    This method first quantizes the input tensor into float8 format and then performs
+    a secondary quantization to int4 with grouping.
+
+    Args:
+        tensor (torch.Tensor): Input tensor to quantize.
+        bits (int, optional): Bit precision for secondary quantization. Defaults to 4.
+        group_size (int, optional): Group size for int4 quantization. Defaults to -1 (no grouping).
+        v (float, optional): Optional parameter for variance tuning. Defaults to 0.
+        min_scale (float, optional): Minimum scaling factor for int4 quantization. Defaults to 1.0.
+        max_scale (float, optional): Maximum scaling factor for int4 quantization. Defaults to 1.0.
+        q_scale_thresh (float, optional): Threshold for scaling. Defaults to 1e-5.
+        weight_fp8_max_scale (float, optional): Maximum scaling factor for float8 quantization. Defaults to 1.0.
+        **kwargs: Additional arguments for compatibility.
+
+    Returns:
+        tuple:
+            - Quantized and dequantized tensor (torch.Tensor).
+            - Combined scaling factor (torch.Tensor).
+            - Placeholder for zp (None).
+    """
+    # convert to int4
+    from auto_round.data_type.int import quant_tensor_sym
+    qdq_int4_tensor, scale_bf16_to_int4, zp_fp8_to_int4 = quant_tensor_sym(
+        tensor,
+        bits=bits,
+        group_size=group_size,
+        v=v,
+        min_scale=min_scale,
+        max_scale=max_scale,
+        scale_dtype=torch.bfloat16,
+        q_scale_thresh=q_scale_thresh,
+    )
+    # FIXME(Yi): some fuse error here
+    torch._dynamo.graph_break()
+    fp8_max = STANDARD_FP8E4M3FN_MAX * global_config.FP8_WEIGHT_BACKOFF
+    tensor_max = torch.max(torch.abs(qdq_int4_tensor)).to(torch.float32) * weight_fp8_max_scale  ## better train a ratio
+    scale = tensor_max.to(torch.float32) / fp8_max
+    min_scaling_factor = 1.0 / (fp8_max * 512.0)  ## copy from vllm
+    scale_bf16_to_fp8 = torch.clip(scale, min=min_scaling_factor)
+    fp8_res = qdq_int4_tensor / scale_bf16_to_fp8
+    fp8_res = torch.clip(fp8_res, -fp8_max, fp8_max)
+    float8_e4m3fn_ste_gaudi2 = get_gaudi2_fp8_ste_func()
+    fp8_res = float8_e4m3fn_ste_gaudi2(fp8_res)
+
+    ## convert to bf16
+    fp8_res_using_16bit = fp8_res.to(tensor.dtype)
+
+    qdq_tensor = fp8_res_using_16bit * scale_bf16_to_fp8
+
+    # return qdq_tensor, (scale_fp8_to_int4 * scale_bf16_to_fp8, scale_bf16_to_fp8), zp_fp8_to_int4
+    return qdq_tensor, (scale_bf16_to_int4, scale_bf16_to_fp8), zp_fp8_to_int4
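Note on the v2 flow: although the docstring is carried over from the original fp8_gaudi2_to_int_sym, the added code quantizes to grouped int4 first and then re-quantizes the int4 qdq result to per-tensor fp8, returning the two scales separately rather than their product. The snippet below is a minimal usage sketch, not part of this commit; it assumes auto_round is installed, that progressive_quant_fp8_int4_v2 is importable from auto_round.data_type.fp8, and that the Gaudi2 fp8 STE path used inside it is available in the running environment. The tensor shape and group_size=128 are illustrative choices.

# Minimal usage sketch (not part of the commit); shapes and group_size are
# illustrative assumptions.
import torch

from auto_round.data_type.fp8 import progressive_quant_fp8_int4_v2

weight = torch.randn(256, 512, dtype=torch.bfloat16)  # hypothetical weight tensor

qdq_weight, (scale_bf16_to_int4, scale_bf16_to_fp8), zp = progressive_quant_fp8_int4_v2(
    weight,
    bits=4,
    group_size=128,  # w4g128, as described in the docstring
)

# qdq_weight keeps the input's shape/dtype; scale_bf16_to_int4 holds the grouped
# int4 scales and scale_bf16_to_fp8 the single per-tensor fp8 scale.
print(qdq_weight.shape, qdq_weight.dtype, scale_bf16_to_fp8.item())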
