diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index 39f9131a9e..39738528cb 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -269,6 +269,7 @@ def from_hp_to_intx(
             device = input_float.device
             from torchao.dtypes import Int4CPULayout
             from torchao.dtypes.uintx import TensorCoreTiledLayout
+            from torchao.dtypes.uintx import Int4XPULayout
 
             data, scale, zero_point, _ = _choose_qparams_and_quantize_affine_hqq(
                 input_float,
@@ -279,7 +280,7 @@
                 device=device,
                 verbose=False,
                 raw_output=not isinstance(
-                    _layout, (TensorCoreTiledLayout, PlainLayout, Int4CPULayout)
+                    _layout, (TensorCoreTiledLayout, PlainLayout, Int4CPULayout, Int4XPULayout)
                 ),
                 # raw_output=False is basically the 'convert to TensorCoreTiledLayout zero_point version' option (add scale*midpoint)
                 # note in choose_qparams_affine, preserve_zero = False does this same thing while also controlling whether
diff --git a/torchao/dtypes/uintx/int4_xpu_layout.py b/torchao/dtypes/uintx/int4_xpu_layout.py
index 955a7a8610..1a84e0c002 100644
--- a/torchao/dtypes/uintx/int4_xpu_layout.py
+++ b/torchao/dtypes/uintx/int4_xpu_layout.py
@@ -266,7 +266,7 @@ def from_plain(
         if zero_point.dtype == scale.dtype:
             from torchao.quantization.utils import pack_tinygemm_scales_and_zeros
 
-            scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point)
+            scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point, scale.dtype)
             return cls(packed_weight, scale_and_zero, False, _layout, None, None)
         else:
             return cls(
diff --git a/torchao/prototype/awq/api.py b/torchao/prototype/awq/api.py
index 5806c29ce6..d8dd365b30 100644
--- a/torchao/prototype/awq/api.py
+++ b/torchao/prototype/awq/api.py
@@ -149,8 +149,12 @@ def _awq_uintx_transform(
     preserve_zero = False
     _layout = config.layout
     if isinstance(_layout, Int4XPULayout):
-        zero_point_dtype = torch.int8
-        zero_point_domain = ZeroPointDomain.INT
+        if use_hqq:
+            zero_point_dtype = module.weight.dtype
+            zero_point_domain = ZeroPointDomain.FLOAT
+        else:
+            zero_point_dtype = torch.int8
+            zero_point_domain = ZeroPointDomain.INT
     else:
         zero_point_dtype = torch.bfloat16
         zero_point_domain = ZeroPointDomain.FLOAT
diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py
index ec0dc6d236..01779d0f12 100644
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -2114,7 +2114,10 @@ def _choose_qparams_and_quantize_affine_hqq(
 
     # cleanup
     del W, _min, _max
-    torch.cuda.empty_cache()
+    if (hasattr(device, "type") and "cuda" in device.type) or (isinstance(device, str) and "cuda" in device):
+        torch.cuda.empty_cache()
+    if (hasattr(device, "type") and "xpu" in device.type) or (isinstance(device, str) and "xpu" in device):
+        torch.xpu.empty_cache()
 
     return W_q, scale, zero, shape