From 2cfc932e11cdd33a3ec5fcb5a542b1ab47e01d27 Mon Sep 17 00:00:00 2001 From: Egor Krivov Date: Tue, 1 Jul 2025 12:46:43 +0000 Subject: [PATCH 1/6] implemented 8bit optimizers --- .../{triton_kernels.py => kernels_4bit.py} | 165 +--- .../backends/triton/kernels_8bit_quant.py | 238 ++++++ bitsandbytes/backends/triton/kernels_optim.py | 722 ++++++++++++++++++ bitsandbytes/backends/triton/ops.py | 67 +- bitsandbytes/functional.py | 113 ++- bitsandbytes/optim/optimizer.py | 5 +- bitsandbytes/utils.py | 7 + tests/test_optim.py | 48 +- 8 files changed, 1115 insertions(+), 250 deletions(-) rename bitsandbytes/backends/triton/{triton_kernels.py => kernels_4bit.py} (78%) create mode 100644 bitsandbytes/backends/triton/kernels_8bit_quant.py create mode 100644 bitsandbytes/backends/triton/kernels_optim.py diff --git a/bitsandbytes/backends/triton/triton_kernels.py b/bitsandbytes/backends/triton/kernels_4bit.py similarity index 78% rename from bitsandbytes/backends/triton/triton_kernels.py rename to bitsandbytes/backends/triton/kernels_4bit.py index 03ffa187d..0e94f49e8 100644 --- a/bitsandbytes/backends/triton/triton_kernels.py +++ b/bitsandbytes/backends/triton/kernels_4bit.py @@ -4,167 +4,6 @@ import triton.language as tl -# @triton.autotune( -# configs=[ -# # triton.Config({'SPLIT_SIZE': 64}), -# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=2, num_warps=32), -# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=2, num_warps=32), -# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=4, num_warps=32), -# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=4, num_warps=32), -# # triton.Config({'SPLIT_SIZE': 128}), -# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=2, num_warps=32), -# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=2, num_warps=32), -# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=4, num_warps=32), -# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=4, num_warps=32), -# triton.Config({"SPLIT_SIZE": 256}), -# # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'large'}, num_stages=2, num_warps=32), -# # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'auto'}, num_stages=2, num_warps=32), -# triton.Config({"SPLIT_SIZE": 512}), -# # triton.Config({'SPLIT_SIZE': 1024}), -# ], -# key=["num_paired_elements", "QUANT_BLOCK"], -# ) -@triton.jit -def dequant_8bit_kernel( - a_ptr, - c_ptr, - quant_ptr, - absmax_ptr, - num_paired_elements, - QUANT_BLOCK: tl.constexpr, - SPLIT_SIZE: tl.constexpr, -): - pid = tl.program_id(axis=0) - block_start = pid * SPLIT_SIZE - offsets = block_start + tl.arange(0, SPLIT_SIZE) - mask = offsets < num_paired_elements - - a = tl.load(a_ptr + offsets, mask) - a = a.to(tl.uint8) - - # apply conversion - scaled_int8 = tl.load(quant_ptr + a, mask) - - abs_blocks_lim = (num_paired_elements // QUANT_BLOCK) * QUANT_BLOCK + num_paired_elements % QUANT_BLOCK - abs_offsets = offsets // QUANT_BLOCK - mask_blocked = offsets < abs_blocks_lim - - absmax = tl.load(absmax_ptr + abs_offsets, mask_blocked) - # apply scales - out_dq = scaled_int8 * absmax - - offs = block_start + tl.arange(0, SPLIT_SIZE) - mask = offs < num_paired_elements - tl.store(c_ptr + offs, out_dq, mask) - - -def dequant_int8_blockwise( - A_nf4: torch.Tensor, - quant_state_code: torch.Tensor, - absmax: torch.Tensor, - out: torch.Tensor, - quant_blocksize: int = 64, -): - number_of_paired_elements = A_nf4.numel() - - SPLIT_SIZE = 256 - # grid = lambda META: (triton.cdiv(number_of_paired_elements, META["SPLIT_SIZE"]),) - grid = (triton.cdiv(number_of_paired_elements, SPLIT_SIZE),) - dequant_8bit_kernel[grid]( - A_nf4, - out, - quant_state_code, - absmax, - number_of_paired_elements, - quant_blocksize, - SPLIT_SIZE, - ) - return out - - -# @triton.autotune( -# configs=[ -# triton.Config({"SPLIT_NUM_BLOCKS": 1, "grf_mode": "auto"}, num_stages=4, num_warps=32), -# triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=4, num_warps=32), -# triton.Config({"SPLIT_NUM_BLOCKS": 1}), -# triton.Config({"SPLIT_NUM_BLOCKS": 2}), -# ], -# key=["n_elements"], -# ) -@triton.jit -def quantize_blockwise_kernel( - A_ptr, - code_ptr, - absmax_ptr, - out_ptr, - n_elements, - BLOCK_SIZE: tl.constexpr, - CODE_SIZE: tl.constexpr, - SPLIT_NUM_BLOCKS: tl.constexpr, -): - block_start_idx = tl.program_id(0) * SPLIT_NUM_BLOCKS - thread_idx = tl.arange(0, SPLIT_NUM_BLOCKS * BLOCK_SIZE) - - offsets = block_start_idx * BLOCK_SIZE + thread_idx - mask = offsets < n_elements - - A = tl.load(A_ptr + offsets, mask=mask, other=0.0) - - # To be able process several blocks -> (BLOCK_SIZE, SPLIT_NUM_BLOCKS) - A_reshaped = tl.reshape(A, (SPLIT_NUM_BLOCKS, BLOCK_SIZE)) - - # Calculating absamax for each block - absmax = tl.max(tl.abs(A_reshaped), axis=1) - tl.store(absmax_ptr + block_start_idx + tl.arange(0, SPLIT_NUM_BLOCKS), absmax) - - A_normalized = A_reshaped / absmax[:, None] - A_normalized = tl.clamp(A_normalized, -1.0, 1.0) - - lower_pivot = tl.zeros((SPLIT_NUM_BLOCKS, BLOCK_SIZE), dtype=tl.int32) - upper_pivot = tl.full((SPLIT_NUM_BLOCKS, BLOCK_SIZE), CODE_SIZE - 1, dtype=tl.int32) - - for _ in range(8): # ceil(log2(code_size)) = 8, actually, in general case should be input parameter - pivot = (lower_pivot + upper_pivot) // 2 - val = tl.load(code_ptr + pivot) - is_higher = A_normalized > val # code[pivot] - lower_pivot = tl.where(is_higher, pivot, lower_pivot) - upper_pivot = tl.where(is_higher, upper_pivot, pivot) - - # Choose closest level - lower_val = tl.load(code_ptr + lower_pivot) - upper_val = tl.load(code_ptr + upper_pivot) - lower_dist = tl.abs(A_normalized - lower_val) - upper_dist = tl.abs(A_normalized - upper_val) - quantized = tl.where(lower_dist <= upper_dist, lower_pivot, upper_pivot).to(tl.uint8) - - # too slow approach - # diff = tl.abs(A_normalized[:, :, None] - code[None, None, :]) - # quantized = tl.argmin(diff, axis=2).to(tl.uint8) - - quantized_flat = tl.reshape(quantized, (BLOCK_SIZE * SPLIT_NUM_BLOCKS,)) - tl.store(out_ptr + offsets, quantized_flat, mask=mask) - - -def quantize_blockwise_triton(A, blocksize, code, blocks, absmax, quantized_out): - n = A.numel() - - split_num_blocks = 1 - grid = (triton.cdiv(blocks, split_num_blocks),) - # grid = lambda META: (triton.cdiv(blocks, META["SPLIT_NUM_BLOCKS"]),) - quantize_blockwise_kernel[grid]( - A_ptr=A, - code_ptr=code, - absmax_ptr=absmax, - out_ptr=quantized_out, - n_elements=n, - BLOCK_SIZE=blocksize, - CODE_SIZE=code.numel(), - SPLIT_NUM_BLOCKS=split_num_blocks, - ) - - return quantized_out, absmax - - # Triton implementation of similar CUDA kernel to avoid loading code from csrc/kernels.cu::dQuantizeFP4 # @triton.autotune( # configs=[ @@ -587,7 +426,7 @@ def dequant_nf4_kernel( tl.store(c_ptr + offs, out_dq, mask) -def _dequantize_4bit_impl( +def dequantize_4bit_impl( A: torch.Tensor, absmax: torch.Tensor, blocksize: int, @@ -611,7 +450,7 @@ def _dequantize_4bit_impl( dequant_nf4_kernel[grid](A, out, absmax, number_of_paired_elements, blocksize, SPLIT_SIZE) -def _dequantize_4bit_impl_passing_code( +def dequantize_4bit_impl_passing_code( A: torch.Tensor, absmax: torch.Tensor, blocksize: int, diff --git a/bitsandbytes/backends/triton/kernels_8bit_quant.py b/bitsandbytes/backends/triton/kernels_8bit_quant.py new file mode 100644 index 000000000..42f97b83c --- /dev/null +++ b/bitsandbytes/backends/triton/kernels_8bit_quant.py @@ -0,0 +1,238 @@ +import torch + +import triton +import triton.language as tl + + +# @triton.autotune( +# configs=[ +# # triton.Config({'SPLIT_SIZE': 64}), +# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=2, num_warps=32), +# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=2, num_warps=32), +# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=4, num_warps=32), +# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=4, num_warps=32), +# # triton.Config({'SPLIT_SIZE': 128}), +# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=2, num_warps=32), +# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=2, num_warps=32), +# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=4, num_warps=32), +# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=4, num_warps=32), +# triton.Config({"SPLIT_SIZE": 256}), +# # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'large'}, num_stages=2, num_warps=32), +# # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'auto'}, num_stages=2, num_warps=32), +# triton.Config({"SPLIT_SIZE": 512}), +# # triton.Config({'SPLIT_SIZE': 1024}), +# ], +# key=["num_paired_elements", "QUANT_BLOCK"], +# ) +@triton.jit +def dequant_8bit_kernel( + a_ptr, + c_ptr, + quant_ptr, + absmax_ptr, + num_paired_elements, + QUANT_BLOCK: tl.constexpr, + SPLIT_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * SPLIT_SIZE + offsets = block_start + tl.arange(0, SPLIT_SIZE) + mask = offsets < num_paired_elements + + a = tl.load(a_ptr + offsets, mask) + a = a.to(tl.uint8) + + # apply conversion + scaled_int8 = tl.load(quant_ptr + a, mask) + + abs_blocks_lim = (num_paired_elements // QUANT_BLOCK) * QUANT_BLOCK + num_paired_elements % QUANT_BLOCK + abs_offsets = offsets // QUANT_BLOCK + mask_blocked = offsets < abs_blocks_lim + + absmax = tl.load(absmax_ptr + abs_offsets, mask_blocked) + # apply scales + out_dq = scaled_int8 * absmax + + offs = block_start + tl.arange(0, SPLIT_SIZE) + mask = offs < num_paired_elements + tl.store(c_ptr + offs, out_dq, mask) + + +def dequant_8bit_blockwise( + a: torch.Tensor, + absmax: torch.Tensor, + quant_state_code: torch.Tensor, + quant_blocksize: int = 64, + dtype: torch.dtype = None, + out: torch.Tensor = None, +): + number_of_paired_elements = a.numel() + if out is None: + if dtype is None: + raise ValueError("If out is None, dtype must be specified") + out = torch.empty_like(a, dtype=dtype, device=a.device) + + SPLIT_SIZE = 256 + # grid = lambda META: (triton.cdiv(number_of_paired_elements, META["SPLIT_SIZE"]),) + grid = (triton.cdiv(number_of_paired_elements, SPLIT_SIZE),) + dequant_8bit_kernel[grid]( + a, + out, + quant_state_code, + absmax, + number_of_paired_elements, + quant_blocksize, + SPLIT_SIZE, + ) + return out + + +# @triton.autotune( +# configs=[ +# triton.Config({"SPLIT_NUM_BLOCKS": 1, "grf_mode": "auto"}, num_stages=4, num_warps=32), +# triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=4, num_warps=32), +# triton.Config({"SPLIT_NUM_BLOCKS": 1}), +# triton.Config({"SPLIT_NUM_BLOCKS": 2}), +# ], +# key=["n_elements"], +# ) +@triton.jit +def quantize_8bit_blockwise_kernel( + A_ptr, + code_ptr, + absmax_ptr, + out_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, + CODE_SIZE: tl.constexpr, + SPLIT_NUM_BLOCKS: tl.constexpr, +): + block_start_idx = tl.program_id(0) * SPLIT_NUM_BLOCKS + thread_idx = tl.arange(0, SPLIT_NUM_BLOCKS * BLOCK_SIZE) + + offsets = block_start_idx * BLOCK_SIZE + thread_idx + mask = offsets < n_elements + + A = tl.load(A_ptr + offsets, mask=mask, other=0.0) + + # To be able process several blocks -> (BLOCK_SIZE, SPLIT_NUM_BLOCKS) + A_reshaped = tl.reshape(A, (SPLIT_NUM_BLOCKS, BLOCK_SIZE)) + + # Calculating absamax for each block + absmax = tl.max(tl.abs(A_reshaped), axis=1) + tl.store(absmax_ptr + block_start_idx + tl.arange(0, SPLIT_NUM_BLOCKS), absmax) + + A_normalized = A_reshaped / absmax[:, None] + A_normalized = tl.clamp(A_normalized, -1.0, 1.0) + + lower_pivot = tl.zeros((SPLIT_NUM_BLOCKS, BLOCK_SIZE), dtype=tl.int32) + upper_pivot = tl.full((SPLIT_NUM_BLOCKS, BLOCK_SIZE), CODE_SIZE - 1, dtype=tl.int32) + + for _ in range(8): # ceil(log2(code_size)) = 8, actually, in general case should be input parameter + pivot = (lower_pivot + upper_pivot) // 2 + val = tl.load(code_ptr + pivot) + is_higher = A_normalized > val # code[pivot] + lower_pivot = tl.where(is_higher, pivot, lower_pivot) + upper_pivot = tl.where(is_higher, upper_pivot, pivot) + + # Choose closest level + lower_val = tl.load(code_ptr + lower_pivot) + upper_val = tl.load(code_ptr + upper_pivot) + lower_dist = tl.abs(A_normalized - lower_val) + upper_dist = tl.abs(A_normalized - upper_val) + quantized = tl.where(lower_dist <= upper_dist, lower_pivot, upper_pivot).to(tl.uint8) + + # too slow approach + # diff = tl.abs(A_normalized[:, :, None] - code[None, None, :]) + # quantized = tl.argmin(diff, axis=2).to(tl.uint8) + + quantized_flat = tl.reshape(quantized, (BLOCK_SIZE * SPLIT_NUM_BLOCKS,)) + tl.store(out_ptr + offsets, quantized_flat, mask=mask) + + +def quantize_blockwise_triton(A, code, blocksize, absmax=None, out=None): + n = A.numel() + blocks = -(n // -blocksize) + + if absmax is None: + absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype) + if out is None: + out = torch.empty_like(A.flatten(), dtype=torch.uint8) + + split_num_blocks = 1 + grid = (triton.cdiv(blocks, split_num_blocks),) + # grid = lambda META: (triton.cdiv(blocks, META["SPLIT_NUM_BLOCKS"]),) + quantize_8bit_blockwise_kernel[grid]( + A_ptr=A, + code_ptr=code, + absmax_ptr=absmax, + out_ptr=out, + n_elements=n, + BLOCK_SIZE=blocksize, + CODE_SIZE=code.numel(), + SPLIT_NUM_BLOCKS=split_num_blocks, + # num_warps=1, + # num_stages=2, + ) + out = out.reshape(A.shape) + + return out, absmax + + +@triton.jit +def quantize_8bit_blockwise_core( + a, + qmap_ptr, + CODE_SIZE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + N_PER_TH: tl.constexpr, +): + # To be able process several blocks -> (BLOCK_SIZE, SPLIT_NUM_BLOCKS) + a_reshaped = tl.reshape(a, (N_PER_TH, BLOCK_SIZE)) + + # Calculating absamax for each block + absmax = tl.max(tl.abs(a_reshaped), axis=1) + + a_normalized = a_reshaped / absmax[:, None] + a_normalized = tl.clamp(a_normalized, -1.0, 1.0) + + lower_pivot = tl.zeros((N_PER_TH, BLOCK_SIZE), dtype=tl.int32) + upper_pivot = tl.full((N_PER_TH, BLOCK_SIZE), CODE_SIZE - 1, dtype=tl.int32) + + # ceil(log2(code_size)) = 8, actually, in general case should be input parameter + for _ in range(8): + pivot = (lower_pivot + upper_pivot) // 2 + val = tl.load(qmap_ptr + pivot) + is_higher = a_normalized > val # code[pivot] + lower_pivot = tl.where(is_higher, pivot, lower_pivot) + upper_pivot = tl.where(is_higher, upper_pivot, pivot) + + # Choose closest level + lower_val = tl.load(qmap_ptr + lower_pivot) + upper_val = tl.load(qmap_ptr + upper_pivot) + lower_dist = tl.abs(a_normalized - lower_val) + upper_dist = tl.abs(a_normalized - upper_val) + quantized = tl.where(lower_dist <= upper_dist, lower_pivot, upper_pivot).to(tl.uint8) + + quantized_flat = tl.reshape(quantized, (BLOCK_SIZE * N_PER_TH,)) + return quantized_flat, absmax + + +@triton.jit +def dequant_8bit_kernel_util( + codes_ptr, + offsets, + qmap_ptr, + absmax_ptr, + mask, + BLOCK_SIZE: tl.constexpr, +): + codes = tl.load(codes_ptr + offsets, mask, other=0).to(tl.uint8) + abs_offsets = offsets // BLOCK_SIZE + absmax = tl.load(absmax_ptr + abs_offsets, mask=mask, other=0.0, eviction_policy="evict_last") + + # apply conversion + scaled_int8 = tl.load(qmap_ptr + codes, mask) + # apply scales + out_dq = scaled_int8 * absmax + return out_dq diff --git a/bitsandbytes/backends/triton/kernels_optim.py b/bitsandbytes/backends/triton/kernels_optim.py new file mode 100644 index 000000000..530ef472d --- /dev/null +++ b/bitsandbytes/backends/triton/kernels_optim.py @@ -0,0 +1,722 @@ +import math +from typing import Optional + +import torch + +import triton +import triton.language as tl +from triton.language.extra import libdevice + +from .kernels_8bit_quant import ( + dequant_8bit_blockwise, + dequant_8bit_kernel_util, + quantize_8bit_blockwise_core, + quantize_blockwise_triton, +) + +########################################### +# Pure torch implementation for reference # +########################################### + + +@torch.compile +def _dequantize_blockwise_pytorch( + A: torch.Tensor, + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + dtype: torch.dtype, +) -> torch.Tensor: + """ + Pure PyTorch reference implementation for block-wise dequantization. + """ + if A.numel() == 0: + return torch.empty_like(A, dtype=dtype) + + A_flat = A.flatten() + num_elements = A_flat.numel() + + dequantized_flat = code.to(A.device)[A_flat.long()].to(dtype) + + num_blocks = math.ceil(num_elements / blocksize) + pad_len = num_blocks * blocksize - num_elements + if pad_len > 0: + dequantized_flat = torch.nn.functional.pad(dequantized_flat, (0, pad_len)) + + dequantized_blocks = dequantized_flat.reshape(num_blocks, blocksize) + + rescaled_blocks = dequantized_blocks * absmax.unsqueeze(1).to(dtype) + + rescaled_flat = rescaled_blocks.flatten() + if pad_len > 0: + rescaled_flat = rescaled_flat[:-pad_len] + + return rescaled_flat.reshape(A.shape) + + +@torch.compile +def _quantize_blockwise_pytorch( + A: torch.Tensor, + code: torch.Tensor, + blocksize: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Pure PyTorch reference implementation for block-wise quantization. + """ + if A.numel() == 0: + return torch.empty_like(A, dtype=torch.uint8), torch.empty(0, dtype=torch.float32, device=A.device) + + A_flat = A.flatten() + num_elements = A_flat.numel() + + num_blocks = math.ceil(num_elements / blocksize) + + pad_len = num_blocks * blocksize - num_elements + if pad_len > 0: + A_flat = torch.nn.functional.pad(A_flat, (0, pad_len)) + + A_blocks = A_flat.reshape(num_blocks, blocksize) + + absmax = torch.max(torch.abs(A_blocks), dim=1, keepdim=True)[0] + absmax[absmax == 0] = 1.0 + + scaled_blocks = A_blocks / absmax + + # Inefficient but straightforward quantization, takes a lot of memory + diff = torch.abs(scaled_blocks.unsqueeze(2) - code.to(A.device)) + quantized_indices = torch.argmin(diff, dim=2).to(torch.uint8) + + quantized_flat = quantized_indices.flatten() + if pad_len > 0: + quantized_flat = quantized_flat[:-pad_len] + + return quantized_flat.reshape(A.shape), absmax.flatten() + + +# Main updated function +def optimizer_update_8bit_blockwise_pytorch( + p: torch.Tensor, + g: torch.Tensor, + state1: torch.Tensor, + state2: Optional[torch.Tensor], + beta1: float, + beta2: float, + beta3: float, # ADEMIX + alpha: float, # ADEMIX + eps: float, + step: int, + lr: float, + qmap1: torch.Tensor, + qmap2: Optional[torch.Tensor], + absmax1: torch.Tensor, + absmax2: Optional[torch.Tensor], + weight_decay: float, + gnorm_scale: float, + skip_zeros: bool, + # ADEMIX + n: int, + *, + optimizer_name: str, +) -> None: + """ + Pure PyTorch implementation of the 8-bit block-wise optimizer update step. + This version ensures high-precision updates for float16 parameters. + """ + if skip_zeros: + raise ValueError("skip_zeros is not supported on XPU yet.") + + blocksize = 256 + + with torch.no_grad(): + # Dequantize states to perform updates in 32-bit precision + if optimizer_name == "ademamix" and absmax1.ndim == 2: + # For AdEMAMix, state1 holds two EMAs, so absmax1 is stacked. + s1_1_fp32 = _dequantize_blockwise_pytorch(state1[0], absmax1[0], qmap1, blocksize, torch.float32) + s1_2_fp32 = _dequantize_blockwise_pytorch(state1[1], absmax1[1], qmap1, blocksize, torch.float32) + state1_fp32 = torch.stack([s1_1_fp32, s1_2_fp32]) + else: + state1_fp32 = _dequantize_blockwise_pytorch(state1, absmax1, qmap1, blocksize, torch.float32) + + state2_fp32 = None + if state2 is not None: + state2_fp32 = _dequantize_blockwise_pytorch(state2, absmax2, qmap2, blocksize, torch.float32) + + grad = g.float() * gnorm_scale + + # Create a 32-bit copy of the parameter for high-precision updates + p_fp32 = p.data.float() + + if optimizer_name == "adam": + state1_fp32.mul_(beta1).add_(grad, alpha=1.0 - beta1) + state2_fp32.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) + + bias_correction1 = 1.0 - beta1**step + bias_correction2 = 1.0 - beta2**step + + denom = (state2_fp32.sqrt() / math.sqrt(bias_correction2)).add_(eps) + + if weight_decay > 0.0: + p_fp32.mul_(1.0 - lr * weight_decay) + p_fp32.addcdiv_(state1_fp32, denom, value=-lr / bias_correction1) + + elif optimizer_name == "ademamix": + m1_fp32, m2_fp32 = state1_fp32[0], state1_fp32[1] + nu_fp32 = state2_fp32 + + m1_fp32.mul_(beta1).add_(grad, alpha=1.0 - beta1) + m2_fp32.mul_(beta3).add_(grad, alpha=1.0 - beta3) + nu_fp32.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) + + bias_correction1 = 1.0 - beta1**step + bias_correction2 = math.sqrt(1.0 - beta2**step) + + update = (m1_fp32 / bias_correction1 + alpha * m2_fp32) / (nu_fp32.sqrt() / bias_correction2 + eps) + + if weight_decay > 0.0: + p_fp32.mul_(1.0 - lr * weight_decay) + + p_fp32.add_(update, alpha=-lr) + state1_fp32 = torch.stack([m1_fp32, m2_fp32]) + + elif optimizer_name == "momentum": + grad.add_(p_fp32, alpha=weight_decay) + if step == 1: + state1_fp32.copy_(grad) + else: + state1_fp32.mul_(beta1).add_(grad) + p_fp32.add_(state1_fp32, alpha=-lr) + + elif optimizer_name == "rmsprop": + grad.add_(p_fp32, alpha=weight_decay) + state1_fp32.mul_(beta1).addcmul_(grad, grad, value=1.0 - beta1) + p_fp32.addcdiv_(grad, state1_fp32.sqrt().add_(eps), value=-lr) + + elif optimizer_name == "lion": + if weight_decay > 0.0: + p_fp32.mul_(1.0 - lr * weight_decay) + + update_dir = torch.sign(state1_fp32.mul(beta1) + grad.mul(1.0 - beta1)) + p_fp32.add_(update_dir, alpha=-lr) + + state1_fp32.mul_(beta2).add_(grad, alpha=1.0 - beta2) + + elif optimizer_name == "adagrad": + grad.add_(p_fp32, alpha=weight_decay) + state1_fp32.addcmul_(grad, grad, value=1.0) + p_fp32.addcdiv_(grad, state1_fp32.sqrt().add_(eps), value=-lr) + + else: + raise NotImplementedError( + f"Pure PyTorch implementation for optimizer '{optimizer_name}' is not available." + ) + + # Copy the updated 32-bit parameter back to the original tensor + p.data.copy_(p_fp32) + + # Re-quantize states and update state tensors in-place + if optimizer_name == "ademamix": + new_m1_8bit, new_absmax_m1 = _quantize_blockwise_pytorch(state1_fp32[0], qmap1, blocksize) + new_m2_8bit, new_absmax_m2 = _quantize_blockwise_pytorch(state1_fp32[1], qmap1, blocksize) + state1[0].copy_(new_m1_8bit) + state1[1].copy_(new_m2_8bit) + absmax1[0].copy_(new_absmax_m1) + absmax1[1].copy_(new_absmax_m2) + + new_state2_8bit, new_absmax2 = _quantize_blockwise_pytorch(state2_fp32, qmap2, blocksize) + state2.copy_(new_state2_8bit) + absmax2.copy_(new_absmax2) + else: + new_state1_8bit, new_absmax1 = _quantize_blockwise_pytorch(state1_fp32, qmap1, blocksize) + state1.copy_(new_state1_8bit) + absmax1.copy_(new_absmax1) + + if state2_fp32 is not None: + new_state2_8bit, new_absmax2 = _quantize_blockwise_pytorch(state2_fp32, qmap2, blocksize) + state2.copy_(new_state2_8bit) + absmax2.copy_(new_absmax2) + + +####################################### +# Mixed torch + triton implementation # +####################################### + + +# Much more memory efficient due to using triton for quantization/dequantization +def optimizer_update_8bit_blockwise_triton_quant( + p: torch.Tensor, + g: torch.Tensor, + state1: torch.Tensor, + state2: Optional[torch.Tensor], + beta1: float, + beta2: float, + beta3: float, # ADEMIX + alpha: float, # ADEMIX + eps: float, + step: int, + lr: float, + qmap1: torch.Tensor, + qmap2: Optional[torch.Tensor], + absmax1: torch.Tensor, + absmax2: Optional[torch.Tensor], + weight_decay: float, + gnorm_scale: float, + skip_zeros: bool, + # ADEMIX + n: int, + *, + optimizer_name: str, +) -> None: + """ + Pure PyTorch implementation of the 8-bit block-wise optimizer update step. + This version ensures high-precision updates for float16 parameters. + """ + if skip_zeros and not torch.any(g): + return + + blocksize = 256 + grad = g.float() * gnorm_scale + + with torch.no_grad(): + # Create a 32-bit copy of the parameter for high-precision updates + p_fp32 = p.data.float() + + # Dequantize states to perform updates in 32-bit precision + if optimizer_name == "ademamix" and absmax1.ndim == 2: + # For AdEMAMix, state1 holds two EMAs, so absmax1 is stacked. + s1_1_fp32 = dequant_8bit_blockwise(state1[0], absmax1[0], qmap1, blocksize, dtype=torch.float32) + s1_2_fp32 = dequant_8bit_blockwise(state1[1], absmax1[1], qmap1, blocksize, dtype=torch.float32) + state1_fp32 = torch.stack([s1_1_fp32, s1_2_fp32]) + else: + state1_fp32 = dequant_8bit_blockwise(state1, absmax1, qmap1, blocksize, dtype=torch.float32) + + state2_fp32 = None + if state2 is not None: + state2_fp32 = dequant_8bit_blockwise(state2, absmax2, qmap2, blocksize, dtype=torch.float32) + + # Apply optimizer-specific update logic + if optimizer_name == "adam": + if weight_decay > 0.0: + p_fp32.mul_(1.0 - lr * weight_decay) + + state1_fp32.mul_(beta1).add_(grad, alpha=1.0 - beta1) + state2_fp32.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) + + bias_correction1 = 1.0 - beta1**step + bias_correction2 = 1.0 - beta2**step + + denom = (state2_fp32.sqrt() / math.sqrt(bias_correction2)).add_(eps) + p_fp32.addcdiv_(state1_fp32, denom, value=-lr / bias_correction1) + + elif optimizer_name == "ademamix": + m1_fp32, m2_fp32 = state1_fp32[0], state1_fp32[1] + nu_fp32 = state2_fp32 + + m1_fp32.mul_(beta1).add_(grad, alpha=1.0 - beta1) + m2_fp32.mul_(beta3).add_(grad, alpha=1.0 - beta3) + nu_fp32.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) + + bias_correction1 = 1.0 - beta1**step + bias_correction2 = math.sqrt(1.0 - beta2**step) + + update = (m1_fp32 / bias_correction1 + alpha * m2_fp32) / (nu_fp32.sqrt() / bias_correction2 + eps) + + if weight_decay > 0.0: + p_fp32.mul_(1.0 - lr * weight_decay) + + p_fp32.add_(update, alpha=-lr) + state1_fp32 = torch.stack([m1_fp32, m2_fp32]) + + elif optimizer_name == "momentum": + grad.add_(p_fp32, alpha=weight_decay) + if step == 1: + state1_fp32.copy_(grad) + else: + state1_fp32.mul_(beta1).add_(grad) + p_fp32.add_(state1_fp32, alpha=-lr) + + elif optimizer_name == "rmsprop": + grad.add_(p_fp32, alpha=weight_decay) + state1_fp32.mul_(beta1).addcmul_(grad, grad, value=1.0 - beta1) + p_fp32.addcdiv_(grad, state1_fp32.sqrt().add_(eps), value=-lr) + + elif optimizer_name == "lion": + if weight_decay > 0.0: + p_fp32.mul_(1.0 - lr * weight_decay) + + update_dir = torch.sign(state1_fp32.mul(beta1) + grad.mul(1.0 - beta1)) + p_fp32.add_(update_dir, alpha=-lr) + + state1_fp32.mul_(beta2).add_(grad, alpha=1.0 - beta2) + + elif optimizer_name == "adagrad": + grad.add_(p_fp32, alpha=weight_decay) + state1_fp32.addcmul_(grad, grad, value=1.0) + p_fp32.addcdiv_(grad, state1_fp32.sqrt().add_(eps), value=-lr) + + else: + raise NotImplementedError( + f"Pure PyTorch implementation for optimizer '{optimizer_name}' is not available." + ) + + # Copy the updated 32-bit parameter back to the original tensor + p.data.copy_(p_fp32) + + # Re-quantize states and update state tensors in-place + if optimizer_name == "ademamix": + new_m1_8bit, new_absmax_m1 = quantize_blockwise_triton(state1_fp32[0], qmap1, blocksize) + new_m2_8bit, new_absmax_m2 = quantize_blockwise_triton(state1_fp32[1], qmap1, blocksize) + state1[0].copy_(new_m1_8bit) + state1[1].copy_(new_m2_8bit) + absmax1[0].copy_(new_absmax_m1) + absmax1[1].copy_(new_absmax_m2) + + new_state2_8bit, new_absmax2 = quantize_blockwise_triton(state2_fp32, qmap2, blocksize) + state2.copy_(new_state2_8bit) + absmax2.copy_(new_absmax2) + else: + new_state1_8bit, new_absmax1 = quantize_blockwise_triton(state1_fp32, qmap1, blocksize) + state1.copy_(new_state1_8bit) + absmax1.copy_(new_absmax1) + + if state2_fp32 is not None: + new_state2_8bit, new_absmax2 = quantize_blockwise_triton(state2_fp32, qmap2, blocksize) + state2.copy_(new_state2_8bit) + absmax2.copy_(new_absmax2) + + +######################### +# Triton implementation # +######################### + +MOMENTUM = 0 +RMSPROP = 1 +ADAGRAD = 2 +ADAM = 3 +# LION should be larger than MOMENTUM, RMSPROP, ADAGRAD due to comparison in kernels +LION = 4 +ADEMAMIX = 5 + +name2optimizer_id = { + "momentum": MOMENTUM, + "rmsprop": RMSPROP, + "adagrad": ADAGRAD, + "adam": ADAM, + "lion": LION, + "ademamix": ADEMAMIX, +} + + +@triton.jit +def _optimizer_update_1state_8bit_blockwise_triton_kernel( + # Tensors + p_ptr, + g_ptr, + state1_ptr, + state2_ptr, + beta1: tl.constexpr, + beta2: tl.constexpr, + beta3, + alpha, + eps: tl.constexpr, + step, + lr, + qmap1_ptr, + qmap2_ptr, + absmax1_ptr, + absmax2_ptr, + weight_decay, + gnorm_scale, + # Meta-parameters + n_elements, + BLOCK_SIZE_N: tl.constexpr, + N_PER_TH: tl.constexpr, + OPTIMIZER_ID: tl.constexpr, +): + """ + Triton kernel for 8-bit optimizers that use one momentum state. + Supports: Momentum, RMSprop, Adagrad, Lion. + """ + # 1. Boilerplate: pid, offsets, mask + pid = tl.program_id(axis=0) + block_start_idx = pid * N_PER_TH + offsets = block_start_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N * N_PER_TH) + mask = offsets < n_elements + + # 2. Load and dequantize tensors + g = tl.load(g_ptr + offsets, mask=mask, other=0.0).to(tl.float32) * gnorm_scale + p = tl.load(p_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + s1 = dequant_8bit_kernel_util(state1_ptr, offsets, qmap1_ptr, absmax1_ptr, mask, BLOCK_SIZE_N) + + # 3. Optimizer-specific updates + # LION + if weight_decay > 0.0 and OPTIMIZER_ID == 2: + p *= 1.0 - lr * weight_decay + # Apply weight decay for momentum, rmsprop, adagrad + elif weight_decay > 0.0: + g += p * weight_decay + + # Momentum update + if OPTIMIZER_ID == 0: # MOMENTUM + if step == 1: + s1 = g + else: + s1 = s1 * beta1 + g + p -= lr * s1 + + # RMSprop update + elif OPTIMIZER_ID == 1: # RMSPROP + s1 = s1 * beta1 + (1.0 - beta1) * g * g + p -= lr * (g / (tl.sqrt(s1) + eps)) + + # Adagrad update + elif OPTIMIZER_ID == 2: # ADAGRAD + s1 += g * g + p -= lr * (g / (tl.sqrt(s1) + eps)) + + # Lion update + elif OPTIMIZER_ID == 4: # LION + val = s1 * beta1 + (1.0 - beta1) * g + update = tl.where(val > 0.0, 1.0, tl.where(val < 0.0, -1.0, 0.0)) + p -= lr * update + s1 = s1 * beta2 + (1.0 - beta2) * g + + # 4. Store updated parameter and requantized state + tl.store(p_ptr + offsets, p.to(p_ptr.dtype.element_ty), mask=mask) + s1_codes, new_absmax1 = quantize_8bit_blockwise_core(s1, qmap1_ptr, 256, BLOCK_SIZE_N, N_PER_TH) + tl.store(state1_ptr + offsets, s1_codes, mask=mask) + tl.store(absmax1_ptr + block_start_idx + tl.arange(0, N_PER_TH), new_absmax1) + + +@triton.jit +def _optimizer_update_2state_8bit_blockwise_triton_kernel( + # Tensors + p_ptr, + g_ptr, + state1_ptr, + state2_ptr, + beta1: tl.constexpr, + beta2: tl.constexpr, + # ademamix changes alpha and beta3 + beta3, + # ademamix changes alpha and beta3 + alpha, + eps: tl.constexpr, + step, + lr, + qmap1_ptr, + qmap2_ptr, + absmax1_ptr, + absmax2_ptr, + weight_decay: tl.constexpr, + gnorm_scale: tl.constexpr, + # Meta-parameters + n_elements, + BLOCK_SIZE_N: tl.constexpr, + N_PER_TH: tl.constexpr, + OPTIMIZER_ID: tl.constexpr, +): + """ + Triton kernel for 8-bit optimizers that use two momentum states. + Supports: Adam, AdEMAMix. + """ + # 1. Boilerplate: pid, offsets, mask + pid = tl.program_id(axis=0) + block_start_idx = pid * N_PER_TH + offsets = block_start_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N * N_PER_TH) + mask = offsets < n_elements + + # 2. Load and dequantize tensors + g = tl.load(g_ptr + offsets, mask=mask, other=0.0).to(tl.float32) * gnorm_scale + p = tl.load(p_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + + # 3. Optimizer-specific updates + if OPTIMIZER_ID == 3: # ADAM + s1 = dequant_8bit_kernel_util(state1_ptr, offsets, qmap1_ptr, absmax1_ptr, mask, BLOCK_SIZE_N) + s2 = dequant_8bit_kernel_util(state2_ptr, offsets, qmap2_ptr, absmax2_ptr, mask, BLOCK_SIZE_N) + + s1 = s1 * beta1 + (1.0 - beta1) * g + s2 = s2 * beta2 + (1.0 - beta2) * g * g + + bias_correction1 = 1.0 - libdevice.pow(beta1, step) + bias_correction2 = 1.0 - libdevice.pow(beta2, step) + + if weight_decay > 0.0: + p *= 1.0 - lr * weight_decay + + denom = tl.sqrt(s2) / tl.sqrt(bias_correction2) + eps + p -= (lr / bias_correction1) * (s1 / denom) + + # Store updated parameter + tl.store(p_ptr + offsets, p.to(p_ptr.dtype.element_ty), mask=mask) + + # Requantize and store states + s1_codes, new_absmax1 = quantize_8bit_blockwise_core(s1, qmap1_ptr, 256, BLOCK_SIZE_N, N_PER_TH) + tl.store(state1_ptr + offsets, s1_codes, mask=mask) + tl.store(absmax1_ptr + block_start_idx + tl.arange(0, N_PER_TH), new_absmax1) + + s2_codes, new_absmax2 = quantize_8bit_blockwise_core(s2, qmap2_ptr, 256, BLOCK_SIZE_N, N_PER_TH) + tl.store(state2_ptr + offsets, s2_codes, mask=mask) + tl.store(absmax2_ptr + block_start_idx + tl.arange(0, N_PER_TH), new_absmax2) + + elif OPTIMIZER_ID == 5: # ADEMAMIX + # AdEMAMix has a stacked state1 (m1, m2) and state2 (nu) + m1 = dequant_8bit_kernel_util(state1_ptr, offsets, qmap1_ptr, absmax1_ptr, mask, BLOCK_SIZE_N) + m2 = dequant_8bit_kernel_util( + state1_ptr + n_elements, offsets, qmap1_ptr, absmax1_ptr + n_elements // BLOCK_SIZE_N, mask, BLOCK_SIZE_N + ) + nu = dequant_8bit_kernel_util(state2_ptr, offsets, qmap2_ptr, absmax2_ptr, mask, BLOCK_SIZE_N) + + m1 = m1 * beta1 + (1.0 - beta1) * g + m2 = m2 * beta3 + (1.0 - beta3) * g + nu = nu * beta2 + (1.0 - beta2) * g * g + + bias_correction1 = 1.0 - libdevice.pow(beta1, step) + bias_correction2 = tl.sqrt(1.0 - libdevice.pow(beta2, step)) + + update = (m1 / bias_correction1 + alpha * m2) / (tl.sqrt(nu) / bias_correction2 + eps) + + if weight_decay > 0.0: + p *= 1.0 - lr * weight_decay + + p -= lr * update + + # Store updated parameter + tl.store(p_ptr + offsets, p.to(p_ptr.dtype.element_ty), mask=mask) + + # Requantize and store all three states + m1_codes, new_absmax_m1 = quantize_8bit_blockwise_core(m1, qmap1_ptr, 256, BLOCK_SIZE_N, N_PER_TH) + tl.store(state1_ptr + offsets, m1_codes, mask=mask) + tl.store(absmax1_ptr + block_start_idx + tl.arange(0, N_PER_TH), new_absmax_m1) + + m2_codes, new_absmax_m2 = quantize_8bit_blockwise_core(m2, qmap1_ptr, 256, BLOCK_SIZE_N, N_PER_TH) + tl.store(state1_ptr + n_elements + offsets, m2_codes, mask=mask) + tl.store(absmax1_ptr + block_start_idx + tl.arange(0, N_PER_TH) + n_elements // BLOCK_SIZE_N, new_absmax_m2) + + nu_codes, new_absmax_nu = quantize_8bit_blockwise_core(nu, qmap2_ptr, 256, BLOCK_SIZE_N, N_PER_TH) + tl.store(state2_ptr + offsets, nu_codes, mask=mask) + tl.store(absmax2_ptr + block_start_idx + tl.arange(0, N_PER_TH), new_absmax_nu) + + +def optimizer_update_1state_8bit_blockwise( + p: torch.Tensor, + g: torch.Tensor, + state1: torch.Tensor, + state2: torch.Tensor, + beta1: float, + beta2: float, + beta3: float, + alpha: float, + eps: float, + step: int, + lr: float, + qmap1: torch.Tensor, + qmap2: torch.Tensor, + absmax1: torch.Tensor, + absmax2: torch.Tensor, + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + skip_zeros=False, + n: int = 0, # Deprecated: n is ignored, p.numel() is used instead. + *, + optimizer_id: int, +): + if skip_zeros: + raise NotImplementedError("skip_zeros is not supported on XPU yet") + + BLOCK_SIZE = 256 + if n % BLOCK_SIZE != 0: + raise ValueError(f"Matrix size ({n}) must be a multiple of BLOCK_SIZE ({BLOCK_SIZE}) for block-wise updates.") + N_PER_TH = 1 # Number of blocks processed per thread. + grid = (triton.cdiv(p.numel(), BLOCK_SIZE * N_PER_TH),) + + _optimizer_update_1state_8bit_blockwise_triton_kernel[grid]( + p, + g, + state1, + state2, + beta1, + beta2, + beta3, + alpha, + eps, + step, + lr, + qmap1, + qmap2, + absmax1, + absmax2, + weight_decay, + gnorm_scale, + p.numel(), + BLOCK_SIZE_N=BLOCK_SIZE, + N_PER_TH=N_PER_TH, + OPTIMIZER_ID=optimizer_id, + num_warps=2, + ) + + +def optimizer_update_2state_8bit_blockwise( + p: torch.Tensor, + g: torch.Tensor, + state1: torch.Tensor, + state2: torch.Tensor, + beta1: float, + beta2: float, + beta3: float, + alpha: float, + eps: float, + step: int, + lr: float, + qmap1: torch.Tensor, + qmap2: torch.Tensor, + absmax1: torch.Tensor, + absmax2: torch.Tensor, + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + skip_zeros=False, + n: int = 0, # Deprecated: n is ignored, p.numel() is used instead. + *, + optimizer_id: int, +): + if skip_zeros: + raise NotImplementedError("skip_zeros is not supported on XPU yet") + + if optimizer_id == ADEMAMIX: + # Handle AdEMAMix's stacked state tensors + if state1.dim() < 2 or state1.shape[0] != 2: + raise ValueError( + f"For ademamix, state1 must be a stacked tensor of shape (2, ...), but got {state1.shape}" + ) + if absmax1.dim() < 2 or absmax1.shape[0] != 2: + raise ValueError( + f"For ademamix, absmax1 must be a stacked tensor of shape (2, ...), but got {absmax1.shape}" + ) + + BLOCK_SIZE = 256 + N_PER_TH = 1 # Number of blocks processed per thread. + grid = (triton.cdiv(p.numel(), BLOCK_SIZE * N_PER_TH),) + + _optimizer_update_2state_8bit_blockwise_triton_kernel[grid]( + p, + g, + state1, + state2, + beta1, + beta2, + beta3, + alpha, + eps, + step, + lr, + qmap1, + qmap2, + absmax1, + absmax2, + weight_decay, + gnorm_scale, + p.numel(), + BLOCK_SIZE_N=BLOCK_SIZE, + N_PER_TH=N_PER_TH, + OPTIMIZER_ID=optimizer_id, + num_warps=2, + ) diff --git a/bitsandbytes/backends/triton/ops.py b/bitsandbytes/backends/triton/ops.py index 1e2802ab5..fe6694761 100644 --- a/bitsandbytes/backends/triton/ops.py +++ b/bitsandbytes/backends/triton/ops.py @@ -1,8 +1,9 @@ from collections.abc import Sequence +from functools import partial import torch -from . import triton_kernels +from . import kernels_4bit, kernels_8bit_quant, kernels_optim # currently codes unused, kept for reference # Should be the same for quant/dequant @@ -14,16 +15,7 @@ def quantize_blockwise(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: torch._check_is_size(blocksize) # torch._check(A.dtype == torch.float32, lambda: f"A must be float32 on xpu, got {A.dtype}") - - n = A.numel() - blocks = -(n // -blocksize) - - absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype) - out = torch.empty_like(A.flatten(), dtype=torch.uint8) - - triton_kernels.quantize_blockwise_triton(A, blocksize, code, blocks, absmax, out) - out = out.reshape(A.shape) - + out, absmax = kernels_8bit_quant.quantize_blockwise_triton(A, code, blocksize) return out, absmax.float() @@ -34,20 +26,24 @@ def dequantize_blockwise( torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") # torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on xpu, got {dtype}") - out = torch.empty_like(A, dtype=dtype, device=A.device) - triton_kernels.dequant_int8_blockwise( + out = kernels_8bit_quant.dequant_8bit_blockwise( A, - code, absmax, - out, + code, blocksize, + dtype=dtype, ) return out def dequantize_blockwise_inplace( - A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor + A: torch.Tensor, + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + dtype: torch.dtype, + out: torch.Tensor, ) -> None: torch._check_is_size(blocksize) torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") @@ -55,12 +51,13 @@ def dequantize_blockwise_inplace( torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}") torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") - triton_kernels.dequant_int8_blockwise( + kernels_8bit_quant.dequant_8bit_blockwise( A, - code, absmax, - out, + code, blocksize, + dtype=dtype, + out=out, ) @@ -84,7 +81,7 @@ def quantize_4bit( absmax = torch.empty((blocks * 2,), device=A.device, dtype=A.dtype) out = torch.empty((n // 2, 1), device=A.device, dtype=torch.uint8) - triton_kernels.quantize_4bit_blockwise_triton( + kernels_4bit.quantize_4bit_blockwise_triton( A, blocksize, quant_type, blocks, absmax, num_elements=n, quantized_out=out ) packed = out @@ -119,7 +116,7 @@ def dequantize_4bit( out = torch.empty(shape, dtype=dtype, device=A.device) - triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) + kernels_4bit.dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) return out @@ -134,7 +131,7 @@ def dequantize_4bit_inplace( ) -> None: torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}") torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") - triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) + kernels_4bit.dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) def gemv_4bit( @@ -150,7 +147,7 @@ def gemv_4bit( B_dq_triton = torch.empty(shapeB, dtype=A.dtype, device=A.device) - triton_kernels._dequantize_4bit_impl_passing_code( + kernels_4bit.dequantize_4bit_impl_passing_code( B, absmax, blocksize, @@ -164,3 +161,27 @@ def gemv_4bit( B_dq_triton, bias=None, ) + + +# optimizer_update_8bit_blockwise = kernels_optim.optimizer_update_8bit_blockwise_pytorch +# optimizer_update_8bit_blockwise = torch.compile(kernels_optim.optimizer_update_8bit_blockwise_pytorch) # 60ms +# optimizer_update_8bit_blockwise = kernels_optim.optimizer_update_8bit_blockwise_triton_quant #2.8ms +# optimizer_update_8bit_blockwise = torch.compile(kernels_optim.optimizer_update_8bit_blockwise_triton_quant) # 2.3ms + +# adam_8bit_blockwise_grad = partial(optimizer_update_8bit_blockwise, optimizer_name="adam") +# momentum_8bit_blockwise_grad = partial(optimizer_update_8bit_blockwise, optimizer_name="momentum") +# rmsprop_8bit_blockwise_grad = partial(optimizer_update_8bit_blockwise, optimizer_name="rmsprop") +# lion_8bit_blockwise_grad = partial(optimizer_update_8bit_blockwise, optimizer_name="lion") +# adagrad_8bit_blockwise_grad = partial(optimizer_update_8bit_blockwise, optimizer_name="adagrad") +# ademamix_8bit_blockwise_grad = partial(optimizer_update_8bit_blockwise, optimizer_name="ademamix") + +# ~0.95ms for adam +update_1state = kernels_optim.optimizer_update_1state_8bit_blockwise +update_2state = kernels_optim.optimizer_update_2state_8bit_blockwise +momentum_8bit_blockwise_grad = partial(update_1state, optimizer_id=kernels_optim.name2optimizer_id["momentum"]) +rmsprop_8bit_blockwise_grad = partial(update_1state, optimizer_id=kernels_optim.name2optimizer_id["rmsprop"]) +lion_8bit_blockwise_grad = partial(update_1state, optimizer_id=kernels_optim.name2optimizer_id["lion"]) +adagrad_8bit_blockwise_grad = partial(update_1state, optimizer_id=kernels_optim.name2optimizer_id["adagrad"]) + +ademamix_8bit_blockwise_grad = partial(update_2state, optimizer_id=kernels_optim.name2optimizer_id["ademamix"]) +adam_8bit_blockwise_grad = partial(update_2state, optimizer_id=kernels_optim.name2optimizer_id["adam"]) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 9b446a2de..503d32002 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -13,6 +13,14 @@ from torch import Tensor from typing_extensions import deprecated +from bitsandbytes.backends.triton.ops import ( + adagrad_8bit_blockwise_grad, + adam_8bit_blockwise_grad, + ademamix_8bit_blockwise_grad, + lion_8bit_blockwise_grad, + momentum_8bit_blockwise_grad, + rmsprop_8bit_blockwise_grad, +) from bitsandbytes.utils import _reverse_4bit_compress_format, pack_dict_to_tensor, unpack_tensor_to_dict from .cextension import HIP_ENVIRONMENT, ipex_cpu, ipex_xpu, lib @@ -84,34 +92,34 @@ str2optimizer8bit_blockwise = { "adam": ( - lib.cadam_8bit_blockwise_grad_fp32, - lib.cadam_8bit_blockwise_grad_fp16, - lib.cadam_8bit_blockwise_grad_bf16, + adam_8bit_blockwise_grad, # lib.cadam_8bit_blockwise_grad_fp32, + adam_8bit_blockwise_grad, # lib.cadam_8bit_blockwise_grad_fp16, + adam_8bit_blockwise_grad, # lib.cadam_8bit_blockwise_grad_fp16, ), "momentum": ( - lib.cmomentum_8bit_blockwise_grad_fp32, - lib.cmomentum_8bit_blockwise_grad_fp16, - lib.cmomentum_8bit_blockwise_grad_bf16, + momentum_8bit_blockwise_grad, + momentum_8bit_blockwise_grad, + momentum_8bit_blockwise_grad, ), "rmsprop": ( - lib.crmsprop_8bit_blockwise_grad_fp32, - lib.crmsprop_8bit_blockwise_grad_fp16, - lib.crmsprop_8bit_blockwise_grad_bf16, + rmsprop_8bit_blockwise_grad, + rmsprop_8bit_blockwise_grad, + rmsprop_8bit_blockwise_grad, ), "lion": ( - lib.clion_8bit_blockwise_grad_fp32, - lib.clion_8bit_blockwise_grad_fp16, - lib.clion_8bit_blockwise_grad_bf16, + lion_8bit_blockwise_grad, + lion_8bit_blockwise_grad, + lion_8bit_blockwise_grad, ), "adagrad": ( - lib.cadagrad_8bit_blockwise_grad_fp32, - lib.cadagrad_8bit_blockwise_grad_fp16, - lib.cadagrad_8bit_blockwise_grad_bf16, + adagrad_8bit_blockwise_grad, + adagrad_8bit_blockwise_grad, + adagrad_8bit_blockwise_grad, ), "ademamix": ( - lib.cademamix_8bit_blockwise_grad_fp32, - lib.cademamix_8bit_blockwise_grad_fp16, - lib.cademamix_8bit_blockwise_grad_bf16, + ademamix_8bit_blockwise_grad, + ademamix_8bit_blockwise_grad, + ademamix_8bit_blockwise_grad, ), } @@ -422,8 +430,8 @@ def is_on_gpu(tensors: Iterable[Optional[torch.Tensor]]): for t in tensors: # NULL pointers and paged tensors are OK. if t is not None and not getattr(t, "is_paged", False): - on_gpu &= t.is_cuda - gpu_ids.add(t.device.index) + on_gpu &= t.device.type != "cpu" + gpu_ids.add((t.device.type, t.device.index)) if not on_gpu: raise RuntimeError( @@ -1466,28 +1474,53 @@ def optimizer_update_8bit_blockwise( is_on_gpu([p, g, state1, state2, qmap1, qmap2, absmax1, absmax2]) - with _cuda_device_of(g): + # print("p device: ", p.device, " g device: ", g.device) + # print("p device type: ", p.device, " g device type: ", g.device) + if p.device.type == "xpu": optim_func( - get_ptr(p), - get_ptr(g), - get_ptr(state1), - get_ptr(state2), - ct.c_float(beta1), - ct.c_float(beta2), - ct.c_float(beta3), - ct.c_float(alpha), - ct.c_float(eps), - ct.c_int32(step), - ct.c_float(lr), - get_ptr(qmap1), - get_ptr(qmap2), - get_ptr(absmax1), - get_ptr(absmax2), - ct.c_float(weight_decay), - ct.c_float(gnorm_scale), - ct.c_bool(skip_zeros), - ct.c_int32(g.numel()), + p, + g, + state1, + state2, + float(beta1), + float(beta2), + float(beta3), + float(alpha), + float(eps), + int(step), + float(lr), + qmap1, + qmap2, + absmax1, + absmax2, + float(weight_decay), + float(gnorm_scale), + bool(skip_zeros), + int(g.numel()), ) + else: + with _cuda_device_of(g): + optim_func( + get_ptr(p), + get_ptr(g), + get_ptr(state1), + get_ptr(state2), + ct.c_float(beta1), + ct.c_float(beta2), + ct.c_float(beta3), + ct.c_float(alpha), + ct.c_float(eps), + ct.c_int32(step), + ct.c_float(lr), + get_ptr(qmap1), + get_ptr(qmap2), + get_ptr(absmax1), + get_ptr(absmax2), + ct.c_float(weight_decay), + ct.c_float(gnorm_scale), + ct.c_bool(skip_zeros), + ct.c_int32(g.numel()), + ) @deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py index ee1781a8b..36537be04 100644 --- a/bitsandbytes/optim/optimizer.py +++ b/bitsandbytes/optim/optimizer.py @@ -10,6 +10,7 @@ import torch import bitsandbytes.functional as F +from bitsandbytes.utils import sync_gpu class MockArgs: @@ -289,11 +290,11 @@ def step(self, closure=None): self.prefetch_state(p) self.update_step(group, p, gindex, pindex) - torch.cuda.synchronize() + sync_gpu(p) if self.is_paged: # all paged operations are asynchronous, we need # to sync to make sure all tensors are in the right state - torch.cuda.synchronize() + sync_gpu(loss) return loss diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py index 7920e2188..a3b043ba0 100644 --- a/bitsandbytes/utils.py +++ b/bitsandbytes/utils.py @@ -209,3 +209,10 @@ def unpack_tensor_to_dict(tensor_data): LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING = {"row": 0, "col32": 1, "col_turing": 2, "col_ampere": 3} INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING = {val: name for (name, val) in LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING.items()} + + +def sync_gpu(t: torch.Tensor): + if t.device.type == "cuda": + torch.cuda.synchronize() + elif t.device.type == "xpu": + torch.xpu.synchronize() diff --git a/tests/test_optim.py b/tests/test_optim.py index 75e5a1714..0a998ba3e 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -11,7 +11,8 @@ import bitsandbytes as bnb import bitsandbytes.functional as F -from tests.helpers import describe_dtype, id_formatter +from bitsandbytes.utils import sync_gpu +from tests.helpers import describe_dtype, get_available_devices, id_formatter # import apex @@ -168,7 +169,8 @@ def rm_path(path): @pytest.mark.parametrize("gtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [32, 1024, 4097, 1], ids=id_formatter("dim2")) -def test_optimizer32bit(requires_cuda, dim1, dim2, gtype, optim_name): +@pytest.mark.parametrize("device", get_available_devices(), ids=id_formatter("device")) +def test_optimizer32bit(dim1, dim2, gtype, optim_name, device): if optim_name.startswith("paged_") and sys.platform == "win32": pytest.skip("Paged optimizers can have issues on Windows.") @@ -176,7 +178,7 @@ def test_optimizer32bit(requires_cuda, dim1, dim2, gtype, optim_name): pytest.skip() if dim1 == 1 and dim2 == 1: return - p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + p1 = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.1 p2 = p1.clone() p1 = p1.float() @@ -191,7 +193,7 @@ def test_optimizer32bit(requires_cuda, dim1, dim2, gtype, optim_name): atol, rtol = 1e-4, 1e-3 for i in range(k): - g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01 + g = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.01 p1.grad = g.clone().float() p2.grad = g.clone() @@ -201,7 +203,7 @@ def test_optimizer32bit(requires_cuda, dim1, dim2, gtype, optim_name): for name1, name2 in str2statenames[optim_name]: torch.testing.assert_close( torch_optimizer.state[p1][name1], - bnb_optimizer.state[p2][name2].cuda(), + bnb_optimizer.state[p2][name2].to(device), atol=atol, rtol=rtol, ) @@ -247,7 +249,8 @@ def test_optimizer32bit(requires_cuda, dim1, dim2, gtype, optim_name): @pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2")) @pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=describe_dtype) -def test_global_config(requires_cuda, dim1, dim2, gtype): +@pytest.mark.parametrize("device", get_available_devices()) +def test_global_config(dim1, dim2, gtype, device): if dim1 == 1 and dim2 == 1: return p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1 @@ -263,9 +266,9 @@ def test_global_config(requires_cuda, dim1, dim2, gtype): bnb.optim.GlobalOptimManager.get_instance().override_config(p3, "optim_bits", 8) bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3]) - p1 = p1.cuda() - p2 = p2.cuda() - p3 = p3.cuda() + p1 = p1.to(device) + p2 = p2.to(device) + p3 = p3.to(device) adam2 = bnb.optim.Adam([p1, p2, p3], lr, (beta1, beta2), eps) @@ -275,9 +278,9 @@ def test_global_config(requires_cuda, dim1, dim2, gtype): atol, rtol = 1e-4, 1e-3 for i in range(50): - g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001 - g2 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001 - g3 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001 + g1 = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.1 + 0.001 + g2 = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.1 + 0.001 + g3 = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.1 + 0.001 p1.grad = g1 p2.grad = g2 p3.grad = g3 @@ -302,13 +305,14 @@ def test_global_config(requires_cuda, dim1, dim2, gtype): @pytest.mark.parametrize("gtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1")) -def test_optimizer8bit(requires_cuda, dim1, dim2, gtype, optim_name): +@pytest.mark.parametrize("device", get_available_devices()) +def test_optimizer8bit(dim1, dim2, gtype, optim_name, device): torch.set_printoptions(precision=6) if dim1 == 1 and dim2 == 1: return - p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + p1 = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.1 p2 = p1.clone() p1 = p1.float() blocksize = 256 @@ -330,12 +334,12 @@ def test_optimizer8bit(requires_cuda, dim1, dim2, gtype, optim_name): relerrors = [] for i in range(50): - g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01 + g = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.01 p1.grad = g.clone().float() p2.grad = g.clone() - bnb_optimizer.step() torch_optimizer.step() + bnb_optimizer.step() # since Lion can have pretty noisy updates where things lie at the boundary assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=0) @@ -368,7 +372,7 @@ def test_optimizer8bit(requires_cuda, dim1, dim2, gtype, optim_name): ) num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0 - # assert num_not_close.sum().item() < 20 + assert num_not_close.sum().item() < 20 dequant_states.append(s1.clone()) err = torch.abs(p1 - p2) @@ -549,25 +553,25 @@ def test_adam_percentile_clipping(requires_cuda, dim1, dim2, gtype, optim_bits): @pytest.mark.parametrize("gtype", [torch.float32, torch.bfloat16, torch.float16], ids=describe_dtype) @pytest.mark.parametrize("optim_name", optimizer_names_benchmark, ids=id_formatter("opt")) @pytest.mark.benchmark -def test_benchmark_blockwise(dim1, dim2, gtype, optim_name): +def test_benchmark_blockwise(dim1, dim2, gtype, optim_name, device): if dim1 == 1 and dim2 == 1: return - p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + p1 = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.1 bnb_optimizer = str2optimizers[optim_name][1]([p1]) - g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01 + g = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.01 p1.grad = g total_steps = 500 for i in range(total_steps): if i == total_steps // 5: # 100 iterations for burn-in - torch.cuda.synchronize() + sync_gpu(p1) t0 = time.time() bnb_optimizer.step() - torch.cuda.synchronize() + sync_gpu(p1) s = time.time() - t0 print("") params = (total_steps - total_steps // 5) * dim1 * dim2 From c74cd39eca6474cf36a8ff822b2c24ba8eb6b940 Mon Sep 17 00:00:00 2001 From: Egor Krivov Date: Fri, 11 Jul 2025 13:44:10 +0000 Subject: [PATCH 2/6] Add interface --- bitsandbytes/_ops.py | 61 ++++++++ bitsandbytes/backends/cuda/ops.py | 130 ++++++++++++++++++ bitsandbytes/backends/triton/kernels_optim.py | 129 ++++++++--------- bitsandbytes/backends/triton/ops.py | 106 ++++++++++---- bitsandbytes/backends/xpu/ops.py | 1 + bitsandbytes/functional.py | 124 +++-------------- 6 files changed, 350 insertions(+), 201 deletions(-) diff --git a/bitsandbytes/_ops.py b/bitsandbytes/_ops.py index a260852f5..9d5882525 100644 --- a/bitsandbytes/_ops.py +++ b/bitsandbytes/_ops.py @@ -348,3 +348,64 @@ def _( ) -> torch.Tensor: torch._check_is_size(blocksize) return torch.empty(shape, dtype=dtype, device=A.device) + + +torch.library.define( + "bitsandbytes::optimizer_update_8bit_blockwise", + "(str optimizer_name, Tensor g, Tensor p, Tensor state1, Tensor! state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, Tensor qmap1, Tensor! qmap2, Tensor absmax1, Tensor! absmax2, float weight_decay, float gnorm_scale, bool skip_zeros) -> ()", +) + + +@register_fake("bitsandbytes::optimizer_update_8bit_blockwise") +def _( + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + state2: Optional[torch.Tensor], + beta1: float, + beta2: float, + beta3: float, + alpha: float, + eps: float, + step: int, + lr: float, + qmap1: torch.Tensor, + qmap2: Optional[torch.Tensor], + absmax1: torch.Tensor, + absmax2: Optional[torch.Tensor], + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + skip_zeros=False, +) -> None: + torch._check( + g.numel() == p.numel(), + lambda: f"g and p must have the same number of elements, got {g.numel()} and {p.numel()}", + ) + compute_dtypes = [torch.float16, torch.bfloat16, torch.float32] + + torch._check( + g.dtype in compute_dtypes, + lambda: f"g must be bfloat16, float16, or float32, got {g.dtype}", + ) + torch._check( + g.dtype == p.dtype, + lambda: f"Expected all tensors to have the same dtype, got g.dtype={g.dtype}, p.dtype={p.dtype}", + ) + torch._check( + state1.dtype == torch.uint8, + lambda: f"state1 must be uint8, got {state1.dtype}", + ) + torch._check( + qmap1.dtype == absmax1.dtype == torch.float32, + lambda: f"Expected qmap1 and absmax1 to be float32, got qmap1.dtype={qmap1.dtype}, absmax1.dtype={absmax1.dtype}", + ) + if state2 is not None: + torch._check( + state2.dtype == torch.uint8, + lambda: f"state2 must be uint8, got {state2.dtype}", + ) + torch._check( + qmap2.dtype == absmax2.dtype == torch.float32, + lambda: f"Expected qmap2 and absmax2 to be float32, got qmap2.dtype={qmap2.dtype}, absmax2.dtype={absmax2.dtype}", + ) diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index 13359bbd8..8e6c6fedf 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -538,3 +538,133 @@ def _gemv_4bit_impl( ct.c_int32(blocksize), stream, ) + + +str2optimizer8bit_blockwise = { + "adam": ( + lib.cadam_8bit_blockwise_grad_fp32, + lib.cadam_8bit_blockwise_grad_fp16, + lib.cadam_8bit_blockwise_grad_bf16, + ), + "momentum": ( + lib.cmomentum_8bit_blockwise_grad_fp32, + lib.cmomentum_8bit_blockwise_grad_fp16, + lib.cmomentum_8bit_blockwise_grad_bf16, + ), + "rmsprop": ( + lib.crmsprop_8bit_blockwise_grad_fp32, + lib.crmsprop_8bit_blockwise_grad_fp16, + lib.crmsprop_8bit_blockwise_grad_bf16, + ), + "lion": ( + lib.clion_8bit_blockwise_grad_fp32, + lib.clion_8bit_blockwise_grad_fp16, + lib.clion_8bit_blockwise_grad_bf16, + ), + "adagrad": ( + lib.cadagrad_8bit_blockwise_grad_fp32, + lib.cadagrad_8bit_blockwise_grad_fp16, + lib.cadagrad_8bit_blockwise_grad_bf16, + ), + "ademamix": ( + lib.cademamix_8bit_blockwise_grad_fp32, + lib.cademamix_8bit_blockwise_grad_fp16, + lib.cademamix_8bit_blockwise_grad_bf16, + ), +} + + +def _optimizer_update_8bit_blockwise_impl( + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + state2: Optional[torch.nsor], + beta1: float, + beta2: float, + beta3: float, + alpha: float, + eps: float, + step: int, + lr: float, + qmap1: torch.Tensor, + qmap2: Optional[torch.Tensor], + absmax1: torch.Tensor, + absmax2: Optional[torch.Tensor], + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + skip_zeros=False, +) -> None: + # torch._check( + # g.numel() == p.numel(), + # lambda: f"g and p must have the same number of elements, got {g.numel()} and {p.numel()}", + # ) + # compute_dtypes = [torch.float16, torch.bfloat16, torch.float32] + + # torch._check( + # g.dtype in compute_dtypes, + # lambda: f"g must be bfloat16, float16, or float32, got {g.dtype}", + # ) + # torch._check( + # g.dtype == p.dtype, + # lambda: f"Expected all tensors to have the same dtype, got g.dtype={g.dtype}, p.dtype={p.dtype}", + # ) + # torch._check( + # state1.dtype == torch.uint8, + # lambda: f"state1 must be uint8, got {state1.dtype}", + # ) + # torch._check( + # qmap1.dtype == absmax1.dtype == torch.float32, + # lambda: f"Expected qmap1 and absmax1 to be float32, got qmap1.dtype={qmap1.dtype}, absmax1.dtype={absmax1.dtype}", + # ) + # if state2 is not None: + # torch._check( + # state2.dtype == torch.uint8, + # lambda: f"state2 must be uint8, got {state2.dtype}", + # ) + # torch._check( + # qmap2.dtype == absmax2.dtype == torch.float32, + # lambda: f"Expected qmap2 and absmax2 to be float32, got qmap2.dtype={qmap2.dtype}, absmax2.dtype={absmax2.dtype}", + # ) + optimizer_fns = str2optimizer8bit_blockwise.get(optimizer_name) + if optimizer_fns is None: + raise ValueError( + f"Unsupported optimizer name: {optimizer_name}. Supported optimizers: {list(str2optimizer8bit_blockwise.keys())}" + ) + + if g.dtype == torch.float32: + optimizer_fn = optimizer_fns[0] + elif g.dtype == torch.float16: + optimizer_fn = optimizer_fns[1] + elif g.dtype == torch.bfloat16: + optimizer_fn = optimizer_fns[2] + else: + raise ValueError( + f"Unsupported gradient dtype: {g.dtype}. Supported dtypes: torch.float32, torch.float16, torch.bfloat16" + ) + + with _cuda_device_of(g): + optimizer_fn( + get_ptr(p), + get_ptr(g), + get_ptr(state1), + get_ptr(state2), + ct.c_float(beta1), + ct.c_float(beta2), + ct.c_float(beta3), + ct.c_float(alpha), + ct.c_float(eps), + ct.c_int32(step), + ct.c_float(lr), + get_ptr(qmap1), + get_ptr(qmap2), + get_ptr(absmax1), + get_ptr(absmax2), + ct.c_float(weight_decay), + ct.c_float(gnorm_scale), + ct.c_bool(skip_zeros), + ct.c_int32(g.numel()), + ) + + +register_kernel("bitsandbytes::optimizer_update_8bit_blockwise", "cuda")(_optimizer_update_8bit_blockwise_impl) diff --git a/bitsandbytes/backends/triton/kernels_optim.py b/bitsandbytes/backends/triton/kernels_optim.py index 530ef472d..3b91115a0 100644 --- a/bitsandbytes/backends/triton/kernels_optim.py +++ b/bitsandbytes/backends/triton/kernels_optim.py @@ -5,8 +5,8 @@ import triton import triton.language as tl -from triton.language.extra import libdevice +# from triton.language.extra import libdevice from .kernels_8bit_quant import ( dequant_8bit_blockwise, dequant_8bit_kernel_util, @@ -419,6 +419,8 @@ def _optimizer_update_1state_8bit_blockwise_triton_kernel( alpha, eps: tl.constexpr, step, + beta1_step, + beta2_step, lr, qmap1_ptr, qmap2_ptr, @@ -502,6 +504,8 @@ def _optimizer_update_2state_8bit_blockwise_triton_kernel( alpha, eps: tl.constexpr, step, + beta1_step, + beta2_step, lr, qmap1_ptr, qmap2_ptr, @@ -537,8 +541,12 @@ def _optimizer_update_2state_8bit_blockwise_triton_kernel( s1 = s1 * beta1 + (1.0 - beta1) * g s2 = s2 * beta2 + (1.0 - beta2) * g * g - bias_correction1 = 1.0 - libdevice.pow(beta1, step) - bias_correction2 = 1.0 - libdevice.pow(beta2, step) + # In torch=2.7 on XPU there is an issue with libdevice.pow, leading to an error. + # For backwards compatibility we precompute the bias correction factors. + # bias_correction1 = 1.0 - libdevice.pow(beta1, step) + # bias_correction2 = 1.0 - libdevice.pow(beta2, step) + bias_correction1 = 1.0 - beta1_step + bias_correction2 = 1.0 - beta2_step if weight_decay > 0.0: p *= 1.0 - lr * weight_decay @@ -562,7 +570,12 @@ def _optimizer_update_2state_8bit_blockwise_triton_kernel( # AdEMAMix has a stacked state1 (m1, m2) and state2 (nu) m1 = dequant_8bit_kernel_util(state1_ptr, offsets, qmap1_ptr, absmax1_ptr, mask, BLOCK_SIZE_N) m2 = dequant_8bit_kernel_util( - state1_ptr + n_elements, offsets, qmap1_ptr, absmax1_ptr + n_elements // BLOCK_SIZE_N, mask, BLOCK_SIZE_N + state1_ptr + n_elements, + offsets, + qmap1_ptr, + absmax1_ptr + n_elements // BLOCK_SIZE_N, + mask, + BLOCK_SIZE_N, ) nu = dequant_8bit_kernel_util(state2_ptr, offsets, qmap2_ptr, absmax2_ptr, mask, BLOCK_SIZE_N) @@ -570,8 +583,12 @@ def _optimizer_update_2state_8bit_blockwise_triton_kernel( m2 = m2 * beta3 + (1.0 - beta3) * g nu = nu * beta2 + (1.0 - beta2) * g * g - bias_correction1 = 1.0 - libdevice.pow(beta1, step) - bias_correction2 = tl.sqrt(1.0 - libdevice.pow(beta2, step)) + # In torch=2.7 on XPU there is an issue with libdevice.pow, leading to an error. + # For backwards compatibility we precompute the bias correction factors. + # bias_correction1 = 1.0 - libdevice.pow(beta1, step) + # bias_correction2 = tl.sqrt(1.0 - libdevice.pow(beta2, step)) + bias_correction1 = 1.0 - beta1_step + bias_correction2 = tl.sqrt(1.0 - beta2_step) update = (m1 / bias_correction1 + alpha * m2) / (tl.sqrt(nu) / bias_correction2 + eps) @@ -590,76 +607,32 @@ def _optimizer_update_2state_8bit_blockwise_triton_kernel( m2_codes, new_absmax_m2 = quantize_8bit_blockwise_core(m2, qmap1_ptr, 256, BLOCK_SIZE_N, N_PER_TH) tl.store(state1_ptr + n_elements + offsets, m2_codes, mask=mask) - tl.store(absmax1_ptr + block_start_idx + tl.arange(0, N_PER_TH) + n_elements // BLOCK_SIZE_N, new_absmax_m2) + tl.store( + absmax1_ptr + block_start_idx + tl.arange(0, N_PER_TH) + n_elements // BLOCK_SIZE_N, + new_absmax_m2, + ) nu_codes, new_absmax_nu = quantize_8bit_blockwise_core(nu, qmap2_ptr, 256, BLOCK_SIZE_N, N_PER_TH) tl.store(state2_ptr + offsets, nu_codes, mask=mask) tl.store(absmax2_ptr + block_start_idx + tl.arange(0, N_PER_TH), new_absmax_nu) -def optimizer_update_1state_8bit_blockwise( - p: torch.Tensor, - g: torch.Tensor, - state1: torch.Tensor, - state2: torch.Tensor, - beta1: float, - beta2: float, - beta3: float, - alpha: float, - eps: float, - step: int, - lr: float, - qmap1: torch.Tensor, - qmap2: torch.Tensor, - absmax1: torch.Tensor, - absmax2: torch.Tensor, - weight_decay: float = 0.0, - gnorm_scale: float = 1.0, - skip_zeros=False, - n: int = 0, # Deprecated: n is ignored, p.numel() is used instead. - *, - optimizer_id: int, -): - if skip_zeros: - raise NotImplementedError("skip_zeros is not supported on XPU yet") - - BLOCK_SIZE = 256 - if n % BLOCK_SIZE != 0: - raise ValueError(f"Matrix size ({n}) must be a multiple of BLOCK_SIZE ({BLOCK_SIZE}) for block-wise updates.") - N_PER_TH = 1 # Number of blocks processed per thread. - grid = (triton.cdiv(p.numel(), BLOCK_SIZE * N_PER_TH),) - - _optimizer_update_1state_8bit_blockwise_triton_kernel[grid]( - p, - g, - state1, - state2, - beta1, - beta2, - beta3, - alpha, - eps, - step, - lr, - qmap1, - qmap2, - absmax1, - absmax2, - weight_decay, - gnorm_scale, - p.numel(), - BLOCK_SIZE_N=BLOCK_SIZE, - N_PER_TH=N_PER_TH, - OPTIMIZER_ID=optimizer_id, - num_warps=2, - ) +name2optimizer_fn = { + "momentum": _optimizer_update_1state_8bit_blockwise_triton_kernel, + "rmsprop": _optimizer_update_1state_8bit_blockwise_triton_kernel, + "adagrad": _optimizer_update_1state_8bit_blockwise_triton_kernel, + "adam": _optimizer_update_2state_8bit_blockwise_triton_kernel, + "lion": _optimizer_update_1state_8bit_blockwise_triton_kernel, + "ademamix": _optimizer_update_2state_8bit_blockwise_triton_kernel, +} -def optimizer_update_2state_8bit_blockwise( - p: torch.Tensor, +def optimizer_update_8bit_blockwise_impl( + optimizer_name: str, g: torch.Tensor, + p: torch.Tensor, state1: torch.Tensor, - state2: torch.Tensor, + state2: Optional[torch.Tensor], beta1: float, beta2: float, beta3: float, @@ -668,21 +641,18 @@ def optimizer_update_2state_8bit_blockwise( step: int, lr: float, qmap1: torch.Tensor, - qmap2: torch.Tensor, + qmap2: Optional[torch.Tensor], absmax1: torch.Tensor, - absmax2: torch.Tensor, + absmax2: Optional[torch.Tensor], weight_decay: float = 0.0, gnorm_scale: float = 1.0, skip_zeros=False, - n: int = 0, # Deprecated: n is ignored, p.numel() is used instead. - *, - optimizer_id: int, -): +) -> None: if skip_zeros: raise NotImplementedError("skip_zeros is not supported on XPU yet") - if optimizer_id == ADEMAMIX: - # Handle AdEMAMix's stacked state tensors + if optimizer_name == "ademamix": + # Handle AdEMAMIX's stacked state tensors if state1.dim() < 2 or state1.shape[0] != 2: raise ValueError( f"For ademamix, state1 must be a stacked tensor of shape (2, ...), but got {state1.shape}" @@ -695,8 +665,15 @@ def optimizer_update_2state_8bit_blockwise( BLOCK_SIZE = 256 N_PER_TH = 1 # Number of blocks processed per thread. grid = (triton.cdiv(p.numel(), BLOCK_SIZE * N_PER_TH),) + fn = name2optimizer_fn[optimizer_name] + optimizer_id = name2optimizer_id[optimizer_name] + + # In torch=2.7 on XPU there is an issue with libdevice.pow, leading to an error. + # For backwards compatibility we precompute the bias correction factors. + beta1_step = beta1**step + beta2_step = beta2**step - _optimizer_update_2state_8bit_blockwise_triton_kernel[grid]( + fn[grid]( p, g, state1, @@ -707,6 +684,8 @@ def optimizer_update_2state_8bit_blockwise( alpha, eps, step, + beta1_step, + beta2_step, lr, qmap1, qmap2, diff --git a/bitsandbytes/backends/triton/ops.py b/bitsandbytes/backends/triton/ops.py index d3cd9136a..6cb34fdbc 100644 --- a/bitsandbytes/backends/triton/ops.py +++ b/bitsandbytes/backends/triton/ops.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from functools import partial +from typing import Optional import torch @@ -170,25 +170,85 @@ def gemv_4bit( ) -# optimizer_update_8bit_blockwise = kernels_optim.optimizer_update_8bit_blockwise_pytorch -# optimizer_update_8bit_blockwise = torch.compile(kernels_optim.optimizer_update_8bit_blockwise_pytorch) # 60ms -# optimizer_update_8bit_blockwise = kernels_optim.optimizer_update_8bit_blockwise_triton_quant #2.8ms -# optimizer_update_8bit_blockwise = torch.compile(kernels_optim.optimizer_update_8bit_blockwise_triton_quant) # 2.3ms - -# adam_8bit_blockwise_grad = partial(optimizer_update_8bit_blockwise, optimizer_name="adam") -# momentum_8bit_blockwise_grad = partial(optimizer_update_8bit_blockwise, optimizer_name="momentum") -# rmsprop_8bit_blockwise_grad = partial(optimizer_update_8bit_blockwise, optimizer_name="rmsprop") -# lion_8bit_blockwise_grad = partial(optimizer_update_8bit_blockwise, optimizer_name="lion") -# adagrad_8bit_blockwise_grad = partial(optimizer_update_8bit_blockwise, optimizer_name="adagrad") -# ademamix_8bit_blockwise_grad = partial(optimizer_update_8bit_blockwise, optimizer_name="ademamix") - -# ~0.95ms for adam -update_1state = kernels_optim.optimizer_update_1state_8bit_blockwise -update_2state = kernels_optim.optimizer_update_2state_8bit_blockwise -momentum_8bit_blockwise_grad = partial(update_1state, optimizer_id=kernels_optim.name2optimizer_id["momentum"]) -rmsprop_8bit_blockwise_grad = partial(update_1state, optimizer_id=kernels_optim.name2optimizer_id["rmsprop"]) -lion_8bit_blockwise_grad = partial(update_1state, optimizer_id=kernels_optim.name2optimizer_id["lion"]) -adagrad_8bit_blockwise_grad = partial(update_1state, optimizer_id=kernels_optim.name2optimizer_id["adagrad"]) - -ademamix_8bit_blockwise_grad = partial(update_2state, optimizer_id=kernels_optim.name2optimizer_id["ademamix"]) -adam_8bit_blockwise_grad = partial(update_2state, optimizer_id=kernels_optim.name2optimizer_id["adam"]) +# optimizer_update_8bit_blockwise_impl = kernels_optim.optimizer_update_8bit_blockwise_pytorch +# optimizer_update_8bit_blockwise_impl = torch.compile(kernels_optim.optimizer_update_8bit_blockwise_pytorch) # 60ms +# optimizer_update_8bit_blockwise_impl = kernels_optim.optimizer_update_8bit_blockwise_triton_quant #2.8ms +# optimizer_update_8bit_blockwise_impl = torch.compile(kernels_optim.optimizer_update_8bit_blockwise_triton_quant) # 2.3ms +optimizer_update_8bit_blockwise_impl = kernels_optim.optimizer_update_8bit_blockwise_impl # ~0.95ms for adam + + +def optimizer_update_8bit_blockwise( + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + state2: Optional[torch.Tensor], + beta1: float, + beta2: float, + beta3: float, + alpha: float, + eps: float, + step: int, + lr: float, + qmap1: torch.Tensor, + qmap2: Optional[torch.Tensor], + absmax1: torch.Tensor, + absmax2: Optional[torch.Tensor], + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + skip_zeros=False, +) -> None: + torch._check( + g.numel() == p.numel(), + lambda: f"g and p must have the same number of elements, got {g.numel()} and {p.numel()}", + ) + compute_dtypes = [torch.float16, torch.bfloat16, torch.float32] + + torch._check( + g.dtype in compute_dtypes, + lambda: f"g must be bfloat16, float16, or float32, got {g.dtype}", + ) + torch._check( + g.dtype == p.dtype, + lambda: f"Expected all tensors to have the same dtype, got g.dtype={g.dtype}, p.dtype={p.dtype}", + ) + torch._check( + state1.dtype == torch.uint8, + lambda: f"state1 must be uint8, got {state1.dtype}", + ) + torch._check( + qmap1.dtype == absmax1.dtype == torch.float32, + lambda: f"Expected qmap1 and absmax1 to be float32, got qmap1.dtype={qmap1.dtype}, absmax1.dtype={absmax1.dtype}", + ) + if state2 is not None: + torch._check( + state2.dtype == torch.uint8, + lambda: f"state2 must be uint8, got {state2.dtype}", + ) + torch._check( + qmap2.dtype == absmax2.dtype == torch.float32, + lambda: f"Expected qmap2 and absmax2 to be float32, got qmap2.dtype={qmap2.dtype}, absmax2.dtype={absmax2.dtype}", + ) + + with torch_accelerator_module.device(state1.device): + optimizer_update_8bit_blockwise_impl( + optimizer_name=optimizer_name, + g=g, + p=p, + state1=state1, + state2=state2, + beta1=beta1, + beta2=beta2, + beta3=beta3, + alpha=alpha, + eps=eps, + step=step, + lr=lr, + qmap1=qmap1, + qmap2=qmap2, + absmax1=absmax1, + absmax2=absmax2, + weight_decay=weight_decay, + gnorm_scale=gnorm_scale, + skip_zeros=skip_zeros, + ) diff --git a/bitsandbytes/backends/xpu/ops.py b/bitsandbytes/backends/xpu/ops.py index 999116c97..59a69db90 100755 --- a/bitsandbytes/backends/xpu/ops.py +++ b/bitsandbytes/backends/xpu/ops.py @@ -63,5 +63,6 @@ def _( register_kernel("bitsandbytes::dequantize_4bit.out", "xpu")(triton_ops.dequantize_4bit_inplace) register_kernel("bitsandbytes::dequantize_4bit", "xpu")(triton_ops.dequantize_4bit) register_kernel("bitsandbytes::gemv_4bit", "xpu")(triton_ops.gemv_4bit) + register_kernel("bitsandbytes::optimizer_update_8bit_blockwise", "xpu")(triton_ops.optimizer_update_8bit_blockwise) else: warnings.warn("XPU available but no ipex or triton packages found.") diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 503d32002..243fda781 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -13,14 +13,6 @@ from torch import Tensor from typing_extensions import deprecated -from bitsandbytes.backends.triton.ops import ( - adagrad_8bit_blockwise_grad, - adam_8bit_blockwise_grad, - ademamix_8bit_blockwise_grad, - lion_8bit_blockwise_grad, - momentum_8bit_blockwise_grad, - rmsprop_8bit_blockwise_grad, -) from bitsandbytes.utils import _reverse_4bit_compress_format, pack_dict_to_tensor, unpack_tensor_to_dict from .cextension import HIP_ENVIRONMENT, ipex_cpu, ipex_xpu, lib @@ -90,39 +82,6 @@ ), } -str2optimizer8bit_blockwise = { - "adam": ( - adam_8bit_blockwise_grad, # lib.cadam_8bit_blockwise_grad_fp32, - adam_8bit_blockwise_grad, # lib.cadam_8bit_blockwise_grad_fp16, - adam_8bit_blockwise_grad, # lib.cadam_8bit_blockwise_grad_fp16, - ), - "momentum": ( - momentum_8bit_blockwise_grad, - momentum_8bit_blockwise_grad, - momentum_8bit_blockwise_grad, - ), - "rmsprop": ( - rmsprop_8bit_blockwise_grad, - rmsprop_8bit_blockwise_grad, - rmsprop_8bit_blockwise_grad, - ), - "lion": ( - lion_8bit_blockwise_grad, - lion_8bit_blockwise_grad, - lion_8bit_blockwise_grad, - ), - "adagrad": ( - adagrad_8bit_blockwise_grad, - adagrad_8bit_blockwise_grad, - adagrad_8bit_blockwise_grad, - ), - "ademamix": ( - ademamix_8bit_blockwise_grad, - ademamix_8bit_blockwise_grad, - ademamix_8bit_blockwise_grad, - ), -} - class GlobalPageManager: _instance = None @@ -1457,70 +1416,29 @@ def optimizer_update_8bit_blockwise( ) -> None: optim_func = None - if g.dtype == torch.float32 and state1.dtype == torch.uint8: - optim_func = str2optimizer8bit_blockwise[optimizer_name][0] - elif g.dtype == torch.float16 and state1.dtype == torch.uint8: - optim_func = str2optimizer8bit_blockwise[optimizer_name][1] - elif ( - g.dtype == torch.bfloat16 - and state1.dtype == torch.uint8 - and len(str2optimizer8bit_blockwise[optimizer_name]) == 3 - ): - optim_func = str2optimizer8bit_blockwise[optimizer_name][2] - else: - raise ValueError( - f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}", - ) - is_on_gpu([p, g, state1, state2, qmap1, qmap2, absmax1, absmax2]) - # print("p device: ", p.device, " g device: ", g.device) - # print("p device type: ", p.device, " g device type: ", g.device) - if p.device.type == "xpu": - optim_func( - p, - g, - state1, - state2, - float(beta1), - float(beta2), - float(beta3), - float(alpha), - float(eps), - int(step), - float(lr), - qmap1, - qmap2, - absmax1, - absmax2, - float(weight_decay), - float(gnorm_scale), - bool(skip_zeros), - int(g.numel()), - ) - else: - with _cuda_device_of(g): - optim_func( - get_ptr(p), - get_ptr(g), - get_ptr(state1), - get_ptr(state2), - ct.c_float(beta1), - ct.c_float(beta2), - ct.c_float(beta3), - ct.c_float(alpha), - ct.c_float(eps), - ct.c_int32(step), - ct.c_float(lr), - get_ptr(qmap1), - get_ptr(qmap2), - get_ptr(absmax1), - get_ptr(absmax2), - ct.c_float(weight_decay), - ct.c_float(gnorm_scale), - ct.c_bool(skip_zeros), - ct.c_int32(g.numel()), - ) + torch.ops.bitsandbytes.optimizer_update_8bit_blockwise( + optimizer_name, + g, + p, + state1, + state2, + beta1, + beta2, + beta3, + alpha, + eps, + step, + lr, + qmap1, + qmap2, + absmax1, + absmax2, + weight_decay, + gnorm_scale, + skip_zeros, + ) @deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) From aef6678a5c67ee579fa3b10b5259fa01b4ac5d47 Mon Sep 17 00:00:00 2001 From: Egor Krivov Date: Fri, 11 Jul 2025 14:16:51 +0000 Subject: [PATCH 3/6] Commented out torch checks --- bitsandbytes/backends/triton/ops.py | 60 ++++++++++++++--------------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/bitsandbytes/backends/triton/ops.py b/bitsandbytes/backends/triton/ops.py index 6cb34fdbc..365600aa3 100644 --- a/bitsandbytes/backends/triton/ops.py +++ b/bitsandbytes/backends/triton/ops.py @@ -198,37 +198,37 @@ def optimizer_update_8bit_blockwise( gnorm_scale: float = 1.0, skip_zeros=False, ) -> None: - torch._check( - g.numel() == p.numel(), - lambda: f"g and p must have the same number of elements, got {g.numel()} and {p.numel()}", - ) - compute_dtypes = [torch.float16, torch.bfloat16, torch.float32] + # torch._check( + # g.numel() == p.numel(), + # lambda: f"g and p must have the same number of elements, got {g.numel()} and {p.numel()}", + # ) + # compute_dtypes = [torch.float16, torch.bfloat16, torch.float32] - torch._check( - g.dtype in compute_dtypes, - lambda: f"g must be bfloat16, float16, or float32, got {g.dtype}", - ) - torch._check( - g.dtype == p.dtype, - lambda: f"Expected all tensors to have the same dtype, got g.dtype={g.dtype}, p.dtype={p.dtype}", - ) - torch._check( - state1.dtype == torch.uint8, - lambda: f"state1 must be uint8, got {state1.dtype}", - ) - torch._check( - qmap1.dtype == absmax1.dtype == torch.float32, - lambda: f"Expected qmap1 and absmax1 to be float32, got qmap1.dtype={qmap1.dtype}, absmax1.dtype={absmax1.dtype}", - ) - if state2 is not None: - torch._check( - state2.dtype == torch.uint8, - lambda: f"state2 must be uint8, got {state2.dtype}", - ) - torch._check( - qmap2.dtype == absmax2.dtype == torch.float32, - lambda: f"Expected qmap2 and absmax2 to be float32, got qmap2.dtype={qmap2.dtype}, absmax2.dtype={absmax2.dtype}", - ) + # torch._check( + # g.dtype in compute_dtypes, + # lambda: f"g must be bfloat16, float16, or float32, got {g.dtype}", + # ) + # torch._check( + # g.dtype == p.dtype, + # lambda: f"Expected all tensors to have the same dtype, got g.dtype={g.dtype}, p.dtype={p.dtype}", + # ) + # torch._check( + # state1.dtype == torch.uint8, + # lambda: f"state1 must be uint8, got {state1.dtype}", + # ) + # torch._check( + # qmap1.dtype == absmax1.dtype == torch.float32, + # lambda: f"Expected qmap1 and absmax1 to be float32, got qmap1.dtype={qmap1.dtype}, absmax1.dtype={absmax1.dtype}", + # ) + # if state2 is not None: + # torch._check( + # state2.dtype == torch.uint8, + # lambda: f"state2 must be uint8, got {state2.dtype}", + # ) + # torch._check( + # qmap2.dtype == absmax2.dtype == torch.float32, + # lambda: f"Expected qmap2 and absmax2 to be float32, got qmap2.dtype={qmap2.dtype}, absmax2.dtype={absmax2.dtype}", + # ) with torch_accelerator_module.device(state1.device): optimizer_update_8bit_blockwise_impl( From a4a49bbf841402f725a6c5d6654e1ba549af241f Mon Sep 17 00:00:00 2001 From: Egor Krivov Date: Fri, 11 Jul 2025 15:14:58 +0000 Subject: [PATCH 4/6] Merged --- bitsandbytes/backends/cuda/ops.py | 2 +- bitsandbytes/optim/optimizer.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index 8e6c6fedf..268123f13 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -579,7 +579,7 @@ def _optimizer_update_8bit_blockwise_impl( g: torch.Tensor, p: torch.Tensor, state1: torch.Tensor, - state2: Optional[torch.nsor], + state2: Optional[torch.Tensor], beta1: float, beta2: float, beta3: float, diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py index 36537be04..7a40f1b75 100644 --- a/bitsandbytes/optim/optimizer.py +++ b/bitsandbytes/optim/optimizer.py @@ -280,6 +280,7 @@ def step(self, closure=None): self.initialized = True # if self.is_paged: self.page_mng.prefetch_all() + p = None for gindex, group in enumerate(self.param_groups): for pindex, p in enumerate(group["params"]): if p.grad is None: @@ -291,10 +292,10 @@ def step(self, closure=None): self.prefetch_state(p) self.update_step(group, p, gindex, pindex) sync_gpu(p) - if self.is_paged: + if self.is_paged and p is not None: # all paged operations are asynchronous, we need # to sync to make sure all tensors are in the right state - sync_gpu(loss) + sync_gpu(p) return loss From cc68b22f34656248b7adc6549719397b2ab3c77c Mon Sep 17 00:00:00 2001 From: Egor Krivov Date: Mon, 14 Jul 2025 15:33:04 +0000 Subject: [PATCH 5/6] Updated kernels --- bitsandbytes/backends/triton/kernels_optim.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/bitsandbytes/backends/triton/kernels_optim.py b/bitsandbytes/backends/triton/kernels_optim.py index 3b91115a0..1ee7a90ec 100644 --- a/bitsandbytes/backends/triton/kernels_optim.py +++ b/bitsandbytes/backends/triton/kernels_optim.py @@ -114,7 +114,6 @@ def optimizer_update_8bit_blockwise_pytorch( gnorm_scale: float, skip_zeros: bool, # ADEMIX - n: int, *, optimizer_name: str, ) -> None: @@ -262,7 +261,6 @@ def optimizer_update_8bit_blockwise_triton_quant( gnorm_scale: float, skip_zeros: bool, # ADEMIX - n: int, *, optimizer_name: str, ) -> None: @@ -627,7 +625,7 @@ def _optimizer_update_2state_8bit_blockwise_triton_kernel( } -def optimizer_update_8bit_blockwise_impl( +def optimizer_update_8bit_blockwise_triton_impl( optimizer_name: str, g: torch.Tensor, p: torch.Tensor, @@ -699,3 +697,10 @@ def optimizer_update_8bit_blockwise_impl( OPTIMIZER_ID=optimizer_id, num_warps=2, ) + + +# optimizer_update_8bit_blockwise_impl = optimizer_update_8bit_blockwise_pytorch +# optimizer_update_8bit_blockwise_impl = torch.compile(optimizer_update_8bit_blockwise_pytorch_impl) +# optimizer_update_8bit_blockwise_impl = optimizer_update_8bit_blockwise_triton_quant +# optimizer_update_8bit_blockwise_impl = torch.compile(optimizer_update_8bit_blockwise_triton_quant) +optimizer_update_8bit_blockwise_impl = optimizer_update_8bit_blockwise_triton_impl From a13736acf95897196af0e92f97de2198b58d911c Mon Sep 17 00:00:00 2001 From: Egor Krivov Date: Tue, 15 Jul 2025 12:06:11 +0000 Subject: [PATCH 6/6] Reused code for quant/dequant --- .../backends/triton/kernels_8bit_quant.py | 103 +++++------------- bitsandbytes/backends/triton/kernels_optim.py | 32 +++--- 2 files changed, 46 insertions(+), 89 deletions(-) diff --git a/bitsandbytes/backends/triton/kernels_8bit_quant.py b/bitsandbytes/backends/triton/kernels_8bit_quant.py index 42f97b83c..c0a5a21ef 100644 --- a/bitsandbytes/backends/triton/kernels_8bit_quant.py +++ b/bitsandbytes/backends/triton/kernels_8bit_quant.py @@ -27,35 +27,19 @@ @triton.jit def dequant_8bit_kernel( a_ptr, - c_ptr, - quant_ptr, + out_ptr, + code_ptr, absmax_ptr, - num_paired_elements, + n, QUANT_BLOCK: tl.constexpr, SPLIT_SIZE: tl.constexpr, ): pid = tl.program_id(axis=0) block_start = pid * SPLIT_SIZE offsets = block_start + tl.arange(0, SPLIT_SIZE) - mask = offsets < num_paired_elements - - a = tl.load(a_ptr + offsets, mask) - a = a.to(tl.uint8) - - # apply conversion - scaled_int8 = tl.load(quant_ptr + a, mask) - - abs_blocks_lim = (num_paired_elements // QUANT_BLOCK) * QUANT_BLOCK + num_paired_elements % QUANT_BLOCK - abs_offsets = offsets // QUANT_BLOCK - mask_blocked = offsets < abs_blocks_lim - - absmax = tl.load(absmax_ptr + abs_offsets, mask_blocked) - # apply scales - out_dq = scaled_int8 * absmax - - offs = block_start + tl.arange(0, SPLIT_SIZE) - mask = offs < num_paired_elements - tl.store(c_ptr + offs, out_dq, mask) + mask = offsets < n + out_dq = dequant_8bit_blockwise_kernel_util(a_ptr, offsets, code_ptr, absmax_ptr, mask, QUANT_BLOCK) + tl.store(out_ptr + offsets, out_dq, mask) def dequant_8bit_blockwise( @@ -66,7 +50,7 @@ def dequant_8bit_blockwise( dtype: torch.dtype = None, out: torch.Tensor = None, ): - number_of_paired_elements = a.numel() + n = a.numel() if out is None: if dtype is None: raise ValueError("If out is None, dtype must be specified") @@ -74,13 +58,13 @@ def dequant_8bit_blockwise( SPLIT_SIZE = 256 # grid = lambda META: (triton.cdiv(number_of_paired_elements, META["SPLIT_SIZE"]),) - grid = (triton.cdiv(number_of_paired_elements, SPLIT_SIZE),) + grid = (triton.cdiv(n, SPLIT_SIZE),) dequant_8bit_kernel[grid]( a, out, quant_state_code, absmax, - number_of_paired_elements, + n, quant_blocksize, SPLIT_SIZE, ) @@ -115,39 +99,9 @@ def quantize_8bit_blockwise_kernel( A = tl.load(A_ptr + offsets, mask=mask, other=0.0) - # To be able process several blocks -> (BLOCK_SIZE, SPLIT_NUM_BLOCKS) - A_reshaped = tl.reshape(A, (SPLIT_NUM_BLOCKS, BLOCK_SIZE)) - - # Calculating absamax for each block - absmax = tl.max(tl.abs(A_reshaped), axis=1) + quantized, absmax = quantize_8bit_blockwise_kernel_util(A, code_ptr, CODE_SIZE, BLOCK_SIZE, SPLIT_NUM_BLOCKS) tl.store(absmax_ptr + block_start_idx + tl.arange(0, SPLIT_NUM_BLOCKS), absmax) - - A_normalized = A_reshaped / absmax[:, None] - A_normalized = tl.clamp(A_normalized, -1.0, 1.0) - - lower_pivot = tl.zeros((SPLIT_NUM_BLOCKS, BLOCK_SIZE), dtype=tl.int32) - upper_pivot = tl.full((SPLIT_NUM_BLOCKS, BLOCK_SIZE), CODE_SIZE - 1, dtype=tl.int32) - - for _ in range(8): # ceil(log2(code_size)) = 8, actually, in general case should be input parameter - pivot = (lower_pivot + upper_pivot) // 2 - val = tl.load(code_ptr + pivot) - is_higher = A_normalized > val # code[pivot] - lower_pivot = tl.where(is_higher, pivot, lower_pivot) - upper_pivot = tl.where(is_higher, upper_pivot, pivot) - - # Choose closest level - lower_val = tl.load(code_ptr + lower_pivot) - upper_val = tl.load(code_ptr + upper_pivot) - lower_dist = tl.abs(A_normalized - lower_val) - upper_dist = tl.abs(A_normalized - upper_val) - quantized = tl.where(lower_dist <= upper_dist, lower_pivot, upper_pivot).to(tl.uint8) - - # too slow approach - # diff = tl.abs(A_normalized[:, :, None] - code[None, None, :]) - # quantized = tl.argmin(diff, axis=2).to(tl.uint8) - - quantized_flat = tl.reshape(quantized, (BLOCK_SIZE * SPLIT_NUM_BLOCKS,)) - tl.store(out_ptr + offsets, quantized_flat, mask=mask) + tl.store(out_ptr + offsets, quantized, mask=mask) def quantize_blockwise_triton(A, code, blocksize, absmax=None, out=None): @@ -180,9 +134,9 @@ def quantize_blockwise_triton(A, code, blocksize, absmax=None, out=None): @triton.jit -def quantize_8bit_blockwise_core( +def quantize_8bit_blockwise_kernel_util( a, - qmap_ptr, + code_ptr, CODE_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr, N_PER_TH: tl.constexpr, @@ -190,7 +144,7 @@ def quantize_8bit_blockwise_core( # To be able process several blocks -> (BLOCK_SIZE, SPLIT_NUM_BLOCKS) a_reshaped = tl.reshape(a, (N_PER_TH, BLOCK_SIZE)) - # Calculating absamax for each block + # Calculating absmax for each block absmax = tl.max(tl.abs(a_reshaped), axis=1) a_normalized = a_reshaped / absmax[:, None] @@ -202,37 +156,40 @@ def quantize_8bit_blockwise_core( # ceil(log2(code_size)) = 8, actually, in general case should be input parameter for _ in range(8): pivot = (lower_pivot + upper_pivot) // 2 - val = tl.load(qmap_ptr + pivot) + val = tl.load(code_ptr + pivot) is_higher = a_normalized > val # code[pivot] lower_pivot = tl.where(is_higher, pivot, lower_pivot) upper_pivot = tl.where(is_higher, upper_pivot, pivot) # Choose closest level - lower_val = tl.load(qmap_ptr + lower_pivot) - upper_val = tl.load(qmap_ptr + upper_pivot) + lower_val = tl.load(code_ptr + lower_pivot) + upper_val = tl.load(code_ptr + upper_pivot) lower_dist = tl.abs(a_normalized - lower_val) upper_dist = tl.abs(a_normalized - upper_val) quantized = tl.where(lower_dist <= upper_dist, lower_pivot, upper_pivot).to(tl.uint8) + # too slow approach + # diff = tl.abs(A_normalized[:, :, None] - code[None, None, :]) + # quantized = tl.argmin(diff, axis=2).to(tl.uint8) + quantized_flat = tl.reshape(quantized, (BLOCK_SIZE * N_PER_TH,)) return quantized_flat, absmax @triton.jit -def dequant_8bit_kernel_util( - codes_ptr, +def dequant_8bit_blockwise_kernel_util( + a_ptr, offsets, - qmap_ptr, + code_ptr, absmax_ptr, mask, BLOCK_SIZE: tl.constexpr, ): - codes = tl.load(codes_ptr + offsets, mask, other=0).to(tl.uint8) - abs_offsets = offsets // BLOCK_SIZE - absmax = tl.load(absmax_ptr + abs_offsets, mask=mask, other=0.0, eviction_policy="evict_last") - - # apply conversion - scaled_int8 = tl.load(qmap_ptr + codes, mask) - # apply scales + a = tl.load(a_ptr + offsets, mask, other=0).to(tl.uint8) + scaled_int8 = tl.load(code_ptr + a, mask) + # Load scales + absmax_offsets = offsets // BLOCK_SIZE + absmax = tl.load(absmax_ptr + absmax_offsets, mask=mask, other=0.0, eviction_policy="evict_last") + # Apply scales out_dq = scaled_int8 * absmax return out_dq diff --git a/bitsandbytes/backends/triton/kernels_optim.py b/bitsandbytes/backends/triton/kernels_optim.py index 1ee7a90ec..ef6c965f8 100644 --- a/bitsandbytes/backends/triton/kernels_optim.py +++ b/bitsandbytes/backends/triton/kernels_optim.py @@ -9,8 +9,8 @@ # from triton.language.extra import libdevice from .kernels_8bit_quant import ( dequant_8bit_blockwise, - dequant_8bit_kernel_util, - quantize_8bit_blockwise_core, + dequant_8bit_blockwise_kernel_util, + quantize_8bit_blockwise_kernel_util, quantize_blockwise_triton, ) @@ -445,7 +445,7 @@ def _optimizer_update_1state_8bit_blockwise_triton_kernel( # 2. Load and dequantize tensors g = tl.load(g_ptr + offsets, mask=mask, other=0.0).to(tl.float32) * gnorm_scale p = tl.load(p_ptr + offsets, mask=mask, other=0.0).to(tl.float32) - s1 = dequant_8bit_kernel_util(state1_ptr, offsets, qmap1_ptr, absmax1_ptr, mask, BLOCK_SIZE_N) + s1 = dequant_8bit_blockwise_kernel_util(state1_ptr, offsets, qmap1_ptr, absmax1_ptr, mask, BLOCK_SIZE_N) # 3. Optimizer-specific updates # LION @@ -482,7 +482,7 @@ def _optimizer_update_1state_8bit_blockwise_triton_kernel( # 4. Store updated parameter and requantized state tl.store(p_ptr + offsets, p.to(p_ptr.dtype.element_ty), mask=mask) - s1_codes, new_absmax1 = quantize_8bit_blockwise_core(s1, qmap1_ptr, 256, BLOCK_SIZE_N, N_PER_TH) + s1_codes, new_absmax1 = quantize_8bit_blockwise_kernel_util(s1, qmap1_ptr, 256, BLOCK_SIZE_N, N_PER_TH) tl.store(state1_ptr + offsets, s1_codes, mask=mask) tl.store(absmax1_ptr + block_start_idx + tl.arange(0, N_PER_TH), new_absmax1) @@ -533,8 +533,8 @@ def _optimizer_update_2state_8bit_blockwise_triton_kernel( # 3. Optimizer-specific updates if OPTIMIZER_ID == 3: # ADAM - s1 = dequant_8bit_kernel_util(state1_ptr, offsets, qmap1_ptr, absmax1_ptr, mask, BLOCK_SIZE_N) - s2 = dequant_8bit_kernel_util(state2_ptr, offsets, qmap2_ptr, absmax2_ptr, mask, BLOCK_SIZE_N) + s1 = dequant_8bit_blockwise_kernel_util(state1_ptr, offsets, qmap1_ptr, absmax1_ptr, mask, BLOCK_SIZE_N) + s2 = dequant_8bit_blockwise_kernel_util(state2_ptr, offsets, qmap2_ptr, absmax2_ptr, mask, BLOCK_SIZE_N) s1 = s1 * beta1 + (1.0 - beta1) * g s2 = s2 * beta2 + (1.0 - beta2) * g * g @@ -556,18 +556,18 @@ def _optimizer_update_2state_8bit_blockwise_triton_kernel( tl.store(p_ptr + offsets, p.to(p_ptr.dtype.element_ty), mask=mask) # Requantize and store states - s1_codes, new_absmax1 = quantize_8bit_blockwise_core(s1, qmap1_ptr, 256, BLOCK_SIZE_N, N_PER_TH) + s1_codes, new_absmax1 = quantize_8bit_blockwise_kernel_util(s1, qmap1_ptr, 256, BLOCK_SIZE_N, N_PER_TH) tl.store(state1_ptr + offsets, s1_codes, mask=mask) tl.store(absmax1_ptr + block_start_idx + tl.arange(0, N_PER_TH), new_absmax1) - s2_codes, new_absmax2 = quantize_8bit_blockwise_core(s2, qmap2_ptr, 256, BLOCK_SIZE_N, N_PER_TH) + s2_codes, new_absmax2 = quantize_8bit_blockwise_kernel_util(s2, qmap2_ptr, 256, BLOCK_SIZE_N, N_PER_TH) tl.store(state2_ptr + offsets, s2_codes, mask=mask) tl.store(absmax2_ptr + block_start_idx + tl.arange(0, N_PER_TH), new_absmax2) elif OPTIMIZER_ID == 5: # ADEMAMIX # AdEMAMix has a stacked state1 (m1, m2) and state2 (nu) - m1 = dequant_8bit_kernel_util(state1_ptr, offsets, qmap1_ptr, absmax1_ptr, mask, BLOCK_SIZE_N) - m2 = dequant_8bit_kernel_util( + m1 = dequant_8bit_blockwise_kernel_util(state1_ptr, offsets, qmap1_ptr, absmax1_ptr, mask, BLOCK_SIZE_N) + m2 = dequant_8bit_blockwise_kernel_util( state1_ptr + n_elements, offsets, qmap1_ptr, @@ -575,7 +575,7 @@ def _optimizer_update_2state_8bit_blockwise_triton_kernel( mask, BLOCK_SIZE_N, ) - nu = dequant_8bit_kernel_util(state2_ptr, offsets, qmap2_ptr, absmax2_ptr, mask, BLOCK_SIZE_N) + nu = dequant_8bit_blockwise_kernel_util(state2_ptr, offsets, qmap2_ptr, absmax2_ptr, mask, BLOCK_SIZE_N) m1 = m1 * beta1 + (1.0 - beta1) * g m2 = m2 * beta3 + (1.0 - beta3) * g @@ -599,18 +599,18 @@ def _optimizer_update_2state_8bit_blockwise_triton_kernel( tl.store(p_ptr + offsets, p.to(p_ptr.dtype.element_ty), mask=mask) # Requantize and store all three states - m1_codes, new_absmax_m1 = quantize_8bit_blockwise_core(m1, qmap1_ptr, 256, BLOCK_SIZE_N, N_PER_TH) + m1_codes, new_absmax_m1 = quantize_8bit_blockwise_kernel_util(m1, qmap1_ptr, 256, BLOCK_SIZE_N, N_PER_TH) tl.store(state1_ptr + offsets, m1_codes, mask=mask) tl.store(absmax1_ptr + block_start_idx + tl.arange(0, N_PER_TH), new_absmax_m1) - m2_codes, new_absmax_m2 = quantize_8bit_blockwise_core(m2, qmap1_ptr, 256, BLOCK_SIZE_N, N_PER_TH) + m2_codes, new_absmax_m2 = quantize_8bit_blockwise_kernel_util(m2, qmap1_ptr, 256, BLOCK_SIZE_N, N_PER_TH) tl.store(state1_ptr + n_elements + offsets, m2_codes, mask=mask) tl.store( absmax1_ptr + block_start_idx + tl.arange(0, N_PER_TH) + n_elements // BLOCK_SIZE_N, new_absmax_m2, ) - nu_codes, new_absmax_nu = quantize_8bit_blockwise_core(nu, qmap2_ptr, 256, BLOCK_SIZE_N, N_PER_TH) + nu_codes, new_absmax_nu = quantize_8bit_blockwise_kernel_util(nu, qmap2_ptr, 256, BLOCK_SIZE_N, N_PER_TH) tl.store(state2_ptr + offsets, nu_codes, mask=mask) tl.store(absmax2_ptr + block_start_idx + tl.arange(0, N_PER_TH), new_absmax_nu) @@ -625,7 +625,7 @@ def _optimizer_update_2state_8bit_blockwise_triton_kernel( } -def optimizer_update_8bit_blockwise_triton_impl( +def optimizer_update_8bit_blockwise_impl( optimizer_name: str, g: torch.Tensor, p: torch.Tensor, @@ -703,4 +703,4 @@ def optimizer_update_8bit_blockwise_triton_impl( # optimizer_update_8bit_blockwise_impl = torch.compile(optimizer_update_8bit_blockwise_pytorch_impl) # optimizer_update_8bit_blockwise_impl = optimizer_update_8bit_blockwise_triton_quant # optimizer_update_8bit_blockwise_impl = torch.compile(optimizer_update_8bit_blockwise_triton_quant) -optimizer_update_8bit_blockwise_impl = optimizer_update_8bit_blockwise_triton_impl +optimizer_update_8bit_blockwise_impl = optimizer_update_8bit_blockwise_impl