diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index a76aadb73..2b4a1e246 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -1067,7 +1067,7 @@ def dequantize_fp4(
     quant_state: Optional[QuantState] = None,
     absmax: Optional[torch.Tensor] = None,
     out: Optional[torch.Tensor] = None,
-    blocksize: int = 64,
+    blocksize: Optional[int] = None,
 ) -> torch.Tensor:
     return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4")
 
@@ -1077,7 +1077,7 @@ def dequantize_nf4(
     quant_state: Optional[QuantState] = None,
     absmax: Optional[torch.Tensor] = None,
     out: Optional[torch.Tensor] = None,
-    blocksize: int = 64,
+    blocksize: Optional[int] = None,
 ) -> torch.Tensor:
     return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4")
 
@@ -1087,8 +1087,8 @@ def dequantize_4bit(
     quant_state: Optional[QuantState] = None,
     absmax: Optional[torch.Tensor] = None,
     out: Optional[torch.Tensor] = None,
-    blocksize: int = 64,
-    quant_type="fp4",
+    blocksize: Optional[int] = None,
+    quant_type: Optional[str] = "fp4",
 ) -> torch.Tensor:
     """Dequantizes a packed 4-bit quantized tensor.
 
@@ -1106,9 +1106,9 @@ def dequantize_4bit(
             Required if `quant_state` is not provided and ignored otherwise.
         out (`torch.Tensor`, *optional*): A tensor to use to store the result.
         blocksize (`int`, *optional*):
-            The size of the blocks. Defaults to 64.
+            The size of the blocks. Defaults to 64 if not HIP_ENVIRONMENT else 128.
             Valid values are 64, 128, 256, 512, 1024, 2048, and 4096.
-        quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to `fp4`.
+        quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to "fp4".
 
     Raises:
         ValueError: Raised when the input data type or blocksize is not supported.
@@ -1118,9 +1118,9 @@
     """
     ensure_backend_is_available(A.device.type)
     if quant_state is not None:
-        absmax = absmax or quant_state.absmax
-        quant_type = quant_type or quant_state.quant_type
-        blocksize = blocksize or quant_state.blocksize
+        absmax = quant_state.absmax
+        quant_type = quant_state.quant_type
+        blocksize = quant_state.blocksize
     if blocksize is None:
         # Some AMD GPUs have warpsize 64
         # Set default blocksize to 128 (~warpsize 64 in kernel) for HIP
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 0ea82575a..961f746ba 100755
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -487,6 +487,7 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
                 self.weight.data = reverse_4bit_compress_format(self.weight.data.reshape(1, -1))
             self.weight.quant_state.ipex = False
+            self.ipex_linear_is_set = False
 
         super()._save_to_state_dict(destination, prefix, keep_vars)  # saving weight and bias
 
@@ -496,15 +497,13 @@
 
     def set_ipex_linear(self, x: torch.Tensor):
         if (
-            (x.device.type in ("cpu", "xpu"))
-            and not getattr(self.weight.quant_state, "ipex", False)
+            not getattr(self.weight.quant_state, "ipex", False)
             and self.weight.data.dtype == torch.uint8
             and self.weight.quant_state.shape[1] % self.weight.quant_state.blocksize == 0
             and self.weight.quant_state.quant_type == "nf4"
-            and not self.training
-            and x.requires_grad == False
         ):
-            enable_ipex_fusion(self, x)
+            if x.device.type == "xpu" or (x.device.type == "cpu" and not self.training and x.requires_grad == False):
+                enable_ipex_fusion(self, x)
 
     def forward(self, x: torch.Tensor):
         # Check if ipex fusion can be used
@@ -695,7 +694,7 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ...
     def to(self, *args, **kwargs):
         device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
 
-        if device in ("cuda", "xpu", "cpu"):
+        if device is not None and device.type in ("cuda", "xpu", "cpu"):
             if device.type == "cuda" and self.data.device.type == "cpu":
                 return self.cuda(device)
             elif device.type == "cpu":
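
For context, a minimal sketch of what the changed blocksize defaults mean for callers. The tensor shape and the names A/packed/restored are illustrative only; it assumes a CUDA device and the existing quantize_nf4/dequantize_nf4 entry points.

    import torch
    import bitsandbytes.functional as F

    # Illustrative input tensor; any 2D fp16/bf16 weight works the same way.
    A = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")

    # quantize_nf4 returns the packed 4-bit tensor plus a QuantState that
    # records absmax, blocksize, and quant_type.
    packed, quant_state = F.quantize_nf4(A)

    # With this patch, blocksize can be omitted on the dequantize side: when
    # quant_state is supplied, its stored absmax/quant_type/blocksize are used
    # unconditionally; without a quant_state, a blocksize of None resolves to
    # 64, or 128 under HIP to match the 64-wide warps of some AMD GPUs.
    restored = F.dequantize_nf4(packed, quant_state)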