From e1b1ceffb394ffe591beb3f6831cd7334aeb9aa5 Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Mon, 17 Mar 2025 13:07:36 +0000
Subject: [PATCH 1/5] fix 4bit XPU dequant 4bit

Signed-off-by: jiqing-feng
---
 bitsandbytes/functional.py | 6 +++---
 bitsandbytes/nn/modules.py | 8 +++-----
 2 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index a76aadb73..78a088629 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -1118,9 +1118,9 @@ def dequantize_4bit(
     """
     ensure_backend_is_available(A.device.type)
     if quant_state is not None:
-        absmax = absmax or quant_state.absmax
-        quant_type = quant_type or quant_state.quant_type
-        blocksize = blocksize or quant_state.blocksize
+        absmax = quant_state.absmax or absmax
+        quant_type = quant_state.quant_type or quant_type
+        blocksize = quant_state.blocksize or blocksize
     if blocksize is None:
         # Some AMD GPUs have warpsize 64
         # Set default blocksize to 128 (~warpsize 64 in kernel) for HIP
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 0ea82575a..78214ff74 100755
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -496,15 +496,13 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
 
     def set_ipex_linear(self, x: torch.Tensor):
         if (
-            (x.device.type in ("cpu", "xpu"))
-            and not getattr(self.weight.quant_state, "ipex", False)
+            not getattr(self.weight.quant_state, "ipex", False)
             and self.weight.data.dtype == torch.uint8
             and self.weight.quant_state.shape[1] % self.weight.quant_state.blocksize == 0
             and self.weight.quant_state.quant_type == "nf4"
-            and not self.training
-            and x.requires_grad == False
         ):
-            enable_ipex_fusion(self, x)
+            if x.device.type == "xpu" or (x.device.type == "cpu" and not self.training and x.requires_grad == False):
+                enable_ipex_fusion(self, x)
 
     def forward(self, x: torch.Tensor):
         # Check if ipex fusion can be used

From f54ac9e43ff3f5a3d8917067fd8d7f86fc02f830 Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Mon, 17 Mar 2025 13:12:52 +0000
Subject: [PATCH 2/5] fix default value

Signed-off-by: jiqing-feng
---
 bitsandbytes/functional.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 78a088629..2b4a1e246 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -1067,7 +1067,7 @@ def dequantize_fp4(
     quant_state: Optional[QuantState] = None,
     absmax: Optional[torch.Tensor] = None,
     out: Optional[torch.Tensor] = None,
-    blocksize: int = 64,
+    blocksize: Optional[int] = None,
 ) -> torch.Tensor:
     return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4")
 
@@ -1077,7 +1077,7 @@ def dequantize_nf4(
     quant_state: Optional[QuantState] = None,
     absmax: Optional[torch.Tensor] = None,
     out: Optional[torch.Tensor] = None,
-    blocksize: int = 64,
+    blocksize: Optional[int] = None,
 ) -> torch.Tensor:
     return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4")
 
@@ -1087,8 +1087,8 @@ def dequantize_4bit(
     quant_state: Optional[QuantState] = None,
     absmax: Optional[torch.Tensor] = None,
     out: Optional[torch.Tensor] = None,
-    blocksize: int = 64,
-    quant_type="fp4",
+    blocksize: Optional[int] = None,
+    quant_type: Optional[str] = "fp4",
 ) -> torch.Tensor:
     """Dequantizes a packed 4-bit quantized tensor.
 
@@ -1106,9 +1106,9 @@ def dequantize_4bit(
             Required if `quant_state` is not provided and ignored otherwise.
         out (`torch.Tensor`, *optional*): A tensor to use to store the result.
         blocksize (`int`, *optional*):
-            The size of the blocks. Defaults to 64.
+            The size of the blocks. Defaults to 64 if not HIP_ENVIRONMENT else 128.
             Valid values are 64, 128, 256, 512, 1024, 2048, and 4096.
-        quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to `fp4`.
+        quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to "fp4".
 
     Raises:
         ValueError: Raised when the input data type or blocksize is not supported.
@@ -1118,9 +1118,9 @@ def dequantize_4bit(
     """
     ensure_backend_is_available(A.device.type)
     if quant_state is not None:
-        absmax = quant_state.absmax or absmax
-        quant_type = quant_state.quant_type or quant_type
-        blocksize = quant_state.blocksize or blocksize
+        absmax = quant_state.absmax
+        quant_type = quant_state.quant_type
+        blocksize = quant_state.blocksize
     if blocksize is None:
         # Some AMD GPUs have warpsize 64
         # Set default blocksize to 128 (~warpsize 64 in kernel) for HIP

From 66723959aa3b08d1498c44be85de0679caea537f Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Mon, 17 Mar 2025 15:15:05 +0000
Subject: [PATCH 3/5] fix ipex linear set

Signed-off-by: jiqing-feng
---
 bitsandbytes/nn/modules.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 78214ff74..6faffe057 100755
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -487,6 +487,7 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
                 self.weight.data = reverse_4bit_compress_format(self.weight.data.reshape(1, -1))
 
             self.weight.quant_state.ipex = False
+            self.ipex_linear_is_set = True
 
         super()._save_to_state_dict(destination, prefix, keep_vars)  # saving weight and bias
 

From 9ae91af188e0a80ccf2b21e8d24320557faa7464 Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Mon, 17 Mar 2025 15:26:21 +0000
Subject: [PATCH 4/5] fix ipex linear set to false when calling state dict

Signed-off-by: jiqing-feng
---
 bitsandbytes/nn/modules.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 6faffe057..ab80769b9 100755
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -487,7 +487,7 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
                 self.weight.data = reverse_4bit_compress_format(self.weight.data.reshape(1, -1))
 
             self.weight.quant_state.ipex = False
-            self.ipex_linear_is_set = True
+            self.ipex_linear_is_set = False
 
         super()._save_to_state_dict(destination, prefix, keep_vars)  # saving weight and bias
 

From 30c0b4ea6d7ef90cca8a66d78ebf5ea5b3203f5f Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Tue, 18 Mar 2025 16:37:10 +0000
Subject: [PATCH 5/5] fix Int8Param device patch

Signed-off-by: jiqing-feng
---
 bitsandbytes/nn/modules.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index ab80769b9..961f746ba 100755
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -694,7 +694,7 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ...
 
     def to(self, *args, **kwargs):
         device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
-        if device in ("cuda", "xpu", "cpu"):
+        if device is not None and device.type in ("cuda", "xpu", "cpu"):
             if device.type == "cuda" and self.data.device.type == "cpu":
                 return self.cuda(device)
             elif device.type == "cpu":
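
Taken together, patches 1 and 2 change how `dequantize_4bit` resolves its arguments: when a `quant_state` is passed, its stored `absmax`, `quant_type`, and `blocksize` take precedence over explicitly passed values, and the top-level `blocksize` default becomes `None` so the backend can choose 64 (or 128 under HIP). A minimal sketch of the resulting call pattern against the patched `bitsandbytes.functional` API, assuming a CUDA device and an arbitrary 256x256 half-precision weight:

import torch
import bitsandbytes.functional as F

# Pack a weight tensor to NF4; quantize_4bit returns the packed uint8 tensor
# plus a QuantState holding absmax, blocksize, quant_type, shape, and dtype.
weight = torch.randn(256, 256, dtype=torch.float16, device="cuda")
packed, quant_state = F.quantize_4bit(weight, blocksize=64, quant_type="nf4")

# With the patches applied, the values stored on quant_state win: the explicit
# blocksize/quant_type below are overwritten from quant_state rather than the
# other way around, so the tensor is still decoded as NF4 with blocksize 64.
restored = F.dequantize_4bit(packed, quant_state, blocksize=128, quant_type="fp4")
assert restored.shape == weight.shape and restored.dtype == weight.dtype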