|
| 1 | +# Written by Dr. Hicham Badri @Mobius Labs GmbH - 2024 |
| 2 | +#******************************************************** |
| 3 | +import torch, math |
| 4 | +from torch import Tensor |
| 5 | +import triton |
| 6 | +import triton.language as tl |
| 7 | + |
| 8 | +def init_to_zero(name): |
| 9 | + return lambda nargs: nargs[name].zero_() |
| 10 | + |
| 11 | +def is_divisible(dividend, divisor): |
| 12 | + return dividend % divisor == 0 |
| 13 | + |
| 14 | +def kernel_config_pruner(configs, nargs, **kwargs): |
| 15 | + m = nargs['M'] |
| 16 | + n = nargs['N'] |
| 17 | + k = nargs['K'] |
| 18 | + g = nargs['group_size'] |
| 19 | + |
| 20 | + used = set() |
| 21 | + for config in configs: |
| 22 | + group_size_m = config.kwargs['GROUP_SIZE_M'] |
| 23 | + block_size_m = config.kwargs['BLOCK_SIZE_M'] #min(m, config.kwargs['BLOCK_SIZE_M']) |
| 24 | + block_size_n = config.kwargs['BLOCK_SIZE_N'] #min(n, config.kwargs['BLOCK_SIZE_N']) |
| 25 | + block_size_k = config.kwargs['BLOCK_SIZE_K'] #min(k, config.kwargs['BLOCK_SIZE_K']) |
| 26 | + split_k = config.kwargs['SPLIT_K'] |
| 27 | + |
| 28 | + #Constraints |
| 29 | + #BLOCK_SIZE_K >= group_size |
| 30 | + block_size_k = min(block_size_k, g) |
| 31 | + #K needs to be devisible by BLOCK_SIZE_K * SPLIT_K |
| 32 | + if(not is_divisible(k, block_size_k * split_k)): |
| 33 | + continue |
| 34 | + |
| 35 | + A_load_order = config.kwargs['A_load_order'] |
| 36 | + meta_evict_policy = config.kwargs['meta_evict_policy'] |
| 37 | + atomic_mode = config.kwargs['atomic_mode'] |
| 38 | + |
| 39 | + _key = (block_size_m, block_size_n, block_size_k, group_size_m, split_k, |
| 40 | + A_load_order, meta_evict_policy, atomic_mode, |
| 41 | + config.num_stages, config.num_warps, |
| 42 | + ) |
| 43 | + |
| 44 | + if _key in used: |
| 45 | + continue |
| 46 | + |
| 47 | + used.add(_key) |
| 48 | + yield triton.Config( |
| 49 | + { |
| 50 | + 'BLOCK_SIZE_M': block_size_m, |
| 51 | + 'BLOCK_SIZE_N': block_size_n, |
| 52 | + 'BLOCK_SIZE_K': block_size_k, |
| 53 | + 'GROUP_SIZE_M': group_size_m, |
| 54 | + 'SPLIT_K' : split_k, |
| 55 | + |
| 56 | + 'A_load_order' : A_load_order, |
| 57 | + 'meta_evict_policy' : meta_evict_policy, |
| 58 | + 'atomic_mode' : atomic_mode, |
| 59 | + }, |
| 60 | + num_stages=config.num_stages, |
| 61 | + num_warps=config.num_warps, |
| 62 | + pre_hook=config.pre_hook, |
| 63 | + ) |
| 64 | + |
| 65 | + |
| 66 | +def get_gemm_config(): |
| 67 | + #Tuned on 4090 RTX |
| 68 | + _configs = [] |
| 69 | + for _M in [16]: #This is fixed to 16 for skinny matrices |
| 70 | + for _N in [32, 64]: |
| 71 | + for _K in [32, 64, 128]: #[128], group_size >= 128 |
| 72 | + for _w in [4]: #[4] |
| 73 | + for _s in [2, 3]: #[2, 3] # |
| 74 | + for _sK in [2, 4, 8]: #[2, 4, 8] |
| 75 | + for _a_load_order in [1, 2, 3]: #[1, 2, 3] - [1]: default 4090 |
| 76 | + for _meta_evict_policy in ['']: #[', 'evict_last'] - ['']: default 4090 |
| 77 | + for _atomic_mode in ['release', 'relaxed']: #['release', 'relaxed']: |
| 78 | + _configs.append( |
| 79 | + triton.Config( |
| 80 | + {'BLOCK_SIZE_M': _M, 'BLOCK_SIZE_N': _N, 'BLOCK_SIZE_K': _K, |
| 81 | + 'GROUP_SIZE_M': 8, 'SPLIT_K': _sK, |
| 82 | + 'A_load_order': _a_load_order, 'meta_evict_policy': _meta_evict_policy, 'atomic_mode': _atomic_mode, |
| 83 | + }, |
| 84 | + num_stages=_s, num_warps=_w, |
| 85 | + pre_hook=init_to_zero("c_ptr"), |
| 86 | + ) |
| 87 | + ) |
| 88 | + return _configs |
| 89 | + |
| 90 | + |
| 91 | + |
| 92 | +#@triton.heuristics(values={'CLOSEST_M': lambda args: 2 ** int(math.ceil(math.log2(args['M'])))}) |
| 93 | +@triton.autotune( |
| 94 | + configs = get_gemm_config(), |
| 95 | + key=['M', 'N', 'K', 'group_size', 'elements_per_sample'], |
| 96 | + prune_configs_by={ |
| 97 | + 'early_config_prune': kernel_config_pruner, |
| 98 | + }, |
| 99 | + warmup=200, |
| 100 | + rep=50, #20 for faster tuning |
| 101 | +) |
| 102 | + |
| 103 | +@triton.jit |
| 104 | +def gemm_splitK_A16fWnO16f_int32packing_kernel( |
| 105 | + a_ptr, b_ptr, c_ptr, |
| 106 | + scales_ptr, zeros_ptr, |
| 107 | + M, N, K, |
| 108 | + W_nbits: tl.constexpr, group_size: tl.constexpr, unpack_mask: tl.constexpr, elements_per_sample: tl.constexpr, |
| 109 | + stride_am, stride_ak, |
| 110 | + stride_bk, stride_bn, |
| 111 | + stride_cm, stride_cn, |
| 112 | + stride_meta_g, stride_meta_n, |
| 113 | + acc_dtype: tl.constexpr, |
| 114 | + ######### tuning params ######### |
| 115 | + #CLOSEST_M: tl.constexpr, |
| 116 | + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, |
| 117 | + GROUP_SIZE_M: tl.constexpr, SPLIT_K: tl.constexpr, |
| 118 | + A_load_order: tl.constexpr, meta_evict_policy: tl.constexpr, atomic_mode: tl.constexpr, |
| 119 | +): |
| 120 | + """ |
| 121 | + Based on https://github.com/foundation-model-stack/foundation-model-stack/blob/triton/triton/kernels/gptq/splitk_dequant_gemm.py |
| 122 | + GEMM for C = matmul(A, dequantize(B, scales, zeros)) |
| 123 | + A is of shape (M, K): float16 or bfloat16 |
| 124 | + B is of shape (K//elements_per_sample, N): int32 as a packed matrix |
| 125 | + C is of shape (M, N): float16 or bfloat16 depending on the input A |
| 126 | + scales and zeros is of shape (group_size, N): float16 or bfloat16 |
| 127 | +
|
| 128 | + BLOCK_SIZE_M >=16 |
| 129 | + BLOCK_SIZE_K * SPLIT_K <= group_size for imp1 |
| 130 | + BLOCK_SIZE_K == SPLIT_K for imp2 (similar to original) |
| 131 | + """ |
| 132 | + |
| 133 | + pid = tl.program_id(axis=0) |
| 134 | + pid_k = tl.program_id(axis=1) |
| 135 | + |
| 136 | + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) |
| 137 | + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) |
| 138 | + num_pid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K) |
| 139 | + |
| 140 | + #Swizzle |
| 141 | + num_pid_in_group = GROUP_SIZE_M * num_pid_n |
| 142 | + group_id = pid // num_pid_in_group |
| 143 | + first_pid_m = group_id * GROUP_SIZE_M |
| 144 | + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) |
| 145 | + pid_m = first_pid_m + (pid % group_size_m) |
| 146 | + pid_n = (pid % num_pid_in_group) // group_size_m |
| 147 | + |
| 148 | + #Offsets |
| 149 | + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) |
| 150 | + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) |
| 151 | + offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) |
| 152 | + |
| 153 | + #Vectorized coalesced load |
| 154 | + offs_am = tl.max_contiguous(tl.multiple_of(offs_m, BLOCK_SIZE_M), BLOCK_SIZE_M) |
| 155 | + offs_bn = tl.max_contiguous(tl.multiple_of(offs_n, BLOCK_SIZE_N), BLOCK_SIZE_N) |
| 156 | + |
| 157 | + #Inputs |
| 158 | + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) |
| 159 | + a_mask = offs_am[:, None] < M |
| 160 | + b_ptrs = b_ptr + ((offs_k[:, None] // elements_per_sample) * stride_bk + offs_bn[None, :] * stride_bn) |
| 161 | + |
| 162 | + #Meta data stuff |
| 163 | + q_shifts = ((offs_k % elements_per_sample) * W_nbits).to(tl.int32)[:, None] |
| 164 | + |
| 165 | + scales_ptrs = scales_ptr + offs_bn[None, :] * stride_meta_n |
| 166 | + zeros_ptrs = zeros_ptr + offs_bn[None, :] * stride_meta_n |
| 167 | + |
| 168 | + stride_mul: tl.constexpr = BLOCK_SIZE_K / group_size |
| 169 | + BLOCK_SIZE_K_P: tl.constexpr = BLOCK_SIZE_K // elements_per_sample |
| 170 | + #################################################################################### |
| 171 | + |
| 172 | + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) |
| 173 | + |
| 174 | + for k in tl.range(0, num_pid_k, 1, num_stages=1): |
| 175 | + |
| 176 | + b = tl.load(b_ptrs, eviction_policy='evict_first') |
| 177 | + |
| 178 | + if(A_load_order == 1): #Early load |
| 179 | + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy='evict_last') |
| 180 | + |
| 181 | + #Meta-data loading policy |
| 182 | + k_m = ((k * SPLIT_K + pid_k) * stride_mul).to(tl.int32) |
| 183 | + scales = tl.load(scales_ptrs + k_m * stride_meta_g, eviction_policy=meta_evict_policy) |
| 184 | + zeros = tl.load(zeros_ptrs + k_m * stride_meta_g, eviction_policy=meta_evict_policy) |
| 185 | + |
| 186 | + if(A_load_order == 2): #Mid load |
| 187 | + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy='evict_last') |
| 188 | + |
| 189 | + # Unpack and dequantize |
| 190 | + b = (b >> q_shifts) & unpack_mask |
| 191 | + b = (b.to(scales.dtype) - zeros) * scales |
| 192 | + |
| 193 | + if(A_load_order == 3): #Late load |
| 194 | + a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy='evict_last') |
| 195 | + |
| 196 | + #Dot |
| 197 | + acc = tl.dot(a, b.to(a.dtype), acc=acc, out_dtype=acc_dtype, input_precision="ieee") |
| 198 | + |
| 199 | + #Advance |
| 200 | + a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak |
| 201 | + b_ptrs += BLOCK_SIZE_K_P * SPLIT_K * stride_bk |
| 202 | + |
| 203 | + #Output |
| 204 | + #acc = acc.to(tl.float16) |
| 205 | + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) |
| 206 | + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) |
| 207 | + c_ptrs = c_ptr + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn) |
| 208 | + tl.atomic_add(c_ptrs, acc, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N), sem=atomic_mode) #release / relaxed |
| 209 | + |
| 210 | + |
| 211 | +@torch.library.custom_op("gemlite::gemm_splitK_A16fWnO16f_int32packing_forward", mutates_args=()) |
| 212 | +def gemm_splitK_A16fWnO16f_int32packing_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, |
| 213 | + W_nbits: int, group_size: int, unpack_mask: int, elements_per_sample: int, |
| 214 | + acc_dtype: int, |
| 215 | + ) -> Tensor: |
| 216 | + |
| 217 | + M, K, N = x.shape[0], x.shape[1], W_q.shape[1] |
| 218 | + |
| 219 | + #assert K == W_q.shape[0] * elements_per_sample, "Invalid Input Shapes" |
| 220 | + #assert group_size >= 128, "Only group_size >= 128 is currently supported" |
| 221 | + |
| 222 | + output = torch.empty((M, N), device=W_q.device, dtype=scales.dtype) |
| 223 | + |
| 224 | + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), META['SPLIT_K']) |
| 225 | + |
| 226 | + gemm_splitK_A16fWnO16f_int32packing_kernel[grid]( |
| 227 | + x, W_q, output, |
| 228 | + scales, zeros, |
| 229 | + M, N, K, |
| 230 | + W_nbits, group_size, unpack_mask, elements_per_sample, |
| 231 | + x.stride(0), x.stride(1), |
| 232 | + W_q.stride(0), W_q.stride(1), |
| 233 | + output.stride(0), output.stride(1), |
| 234 | + scales.stride(0), scales.stride(1), |
| 235 | + tl.float16 if (acc_dtype == 1) else tl.float32, |
| 236 | + ) |
| 237 | + |
| 238 | + return output |
| 239 | + |
| 240 | +@torch.library.register_fake("gemlite::gemm_splitK_A16fWnO16f_int32packing_forward") |
| 241 | +def gemm_splitK_A16fWnO16f_int32packing_forward_fake(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor, |
| 242 | + W_nbits: int, group_size: int, unpack_mask: int, elements_per_sample: int, |
| 243 | + acc_dtype: int, |
| 244 | + ) -> Tensor: |
| 245 | + |
| 246 | + M, K, N = x.shape[0], x.shape[1], W_q.shape[1] |
| 247 | + return torch.empty((M, N), device=W_q.device, dtype=scales.dtype) |
| 248 | + |
| 249 | + |
| 250 | +class gemm_splitK_A16fWnO16f_int32packing: |
| 251 | + kernel = gemm_splitK_A16fWnO16f_int32packing_kernel |
| 252 | + forward = gemm_splitK_A16fWnO16f_int32packing_forward |
| 253 | + matmul_type = "GEMM_SPLITK" |
| 254 | + |
| 255 | +__all__ = ["gemm_splitK_A16fWnO16f_int32packing"] |
0 commit comments