Skip to content

[WIP]Enable HQQ on Intel GPU. #2593

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion torchao/dtypes/affine_quantized_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ def from_hp_to_intx(
device = input_float.device
from torchao.dtypes import Int4CPULayout
from torchao.dtypes.uintx import TensorCoreTiledLayout
from torchao.dtypes.uintx import Int4XPULayout

data, scale, zero_point, _ = _choose_qparams_and_quantize_affine_hqq(
input_float,
Expand All @@ -279,7 +280,7 @@ def from_hp_to_intx(
device=device,
verbose=False,
raw_output=not isinstance(
_layout, (TensorCoreTiledLayout, PlainLayout, Int4CPULayout)
_layout, (TensorCoreTiledLayout, PlainLayout, Int4CPULayout, Int4XPULayout)
),
# raw_output=False is basically the 'convert to TensorCoreTiledLayout zero_point version' option (add scale*midpoint)
# note in choose_qparams_affine, preserve_zero = False does this same thing while also controlling whether
Expand Down
2 changes: 1 addition & 1 deletion torchao/dtypes/uintx/int4_xpu_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def from_plain(
if zero_point.dtype == scale.dtype:
from torchao.quantization.utils import pack_tinygemm_scales_and_zeros

scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point)
scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point, scale.dtype)
return cls(packed_weight, scale_and_zero, False, _layout, None, None)
else:
return cls(
Expand Down
8 changes: 6 additions & 2 deletions torchao/prototype/awq/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,12 @@ def _awq_uintx_transform(
preserve_zero = False
_layout = config.layout
if isinstance(_layout, Int4XPULayout):
zero_point_dtype = torch.int8
zero_point_domain = ZeroPointDomain.INT
if use_hqq:
zero_point_dtype = module.weight.dtype
zero_point_domain = ZeroPointDomain.FLOAT
else:
zero_point_dtype = torch.int8
zero_point_domain = ZeroPointDomain.INT
else:
zero_point_dtype = torch.bfloat16
zero_point_domain = ZeroPointDomain.FLOAT
Expand Down
5 changes: 4 additions & 1 deletion torchao/quantization/quant_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -2114,7 +2114,10 @@ def _choose_qparams_and_quantize_affine_hqq(

# cleanup
del W, _min, _max
torch.cuda.empty_cache()
if (hasattr(device, "type") and "cuda" in device.type) or (isinstance(device, str) and "cuda" in device):
torch.cuda.empty_cache()
if (hasattr(device, "type") and "xpu" in device.type) or (isinstance(device, str) and "xpu" in device):
torch.xpu.empty_cache()

return W_q, scale, zero, shape

Expand Down