import math

import torch
import triton
import triton.language as tl

# from flag_gems import runtime
from flag_gems.runtime import torch_device_fn
from flag_gems.utils import triton_lang_extension as tle
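
# Kronecker product (torch.kron) implemented with a Triton kernel: the two
# operands are padded/reshaped to batched 2D views, a per-batch index map is
# built on the host, and a single kernel launch fills the output tile by tile.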


def prepare_tensor_for_kron(tensor_a, tensor_b):
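    """Reshape both operands to shapes of equal rank (at least 2D in the
    non-empty case) and compute the Kronecker output shape as the
    element-wise product of the two padded shapes."""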
    a_shape = list(tensor_a.shape)
    b_shape = list(tensor_b.shape)

    if tensor_a.numel() == 0 or tensor_b.numel() == 0:
        if not a_shape:
            a_shape = [0]
        if not b_shape:
            b_shape = [0]

        if len(a_shape) > len(b_shape):
            b_shape = [1] * (len(a_shape) - len(b_shape)) + b_shape
        elif len(b_shape) > len(a_shape):
            a_shape = [1] * (len(b_shape) - len(a_shape)) + a_shape

        out_shape = tuple(a * b for a, b in zip(a_shape, b_shape))
        return tensor_a.reshape(*a_shape), tensor_b.reshape(*b_shape), out_shape

    if len(a_shape) < 2:
        a_shape = [1] * (2 - len(a_shape)) + a_shape
    if len(b_shape) < 2:
        b_shape = [1] * (2 - len(b_shape)) + b_shape

    if len(a_shape) > len(b_shape):
        b_shape = [1] * (len(a_shape) - len(b_shape)) + b_shape
    elif len(b_shape) > len(a_shape):
        a_shape = [1] * (len(b_shape) - len(a_shape)) + a_shape

    out_shape = tuple(a * b for a, b in zip(a_shape, b_shape))
    return tensor_a.reshape(*a_shape), tensor_b.reshape(*b_shape), out_shape


def calculate_indices(batch_idx, shape_a, shape_b):
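    """Decompose a flat output batch index into the corresponding batch
    indices of A and B.

    In each leading dimension of size a_dim * b_dim, output index `o`
    corresponds to A index `o // b_dim` and B index `o % b_dim`; the
    per-dimension indices are then re-flattened in row-major order."""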
    a_batch_dims = shape_a[:-2] or (1,)
    b_batch_dims = shape_b[:-2] or (1,)
    out_batch_dims = tuple(a * b for a, b in zip(a_batch_dims, b_batch_dims))

    out_indices = []
    remaining = batch_idx
    for dim_size in out_batch_dims[::-1]:
        out_indices.insert(0, remaining % dim_size)
        remaining //= dim_size

    a_idx = b_idx = 0
    for out_idx, (a_dim, b_dim) in zip(out_indices, zip(a_batch_dims, b_batch_dims)):
        a_idx = a_idx * a_dim + (out_idx // b_dim)
        b_idx = b_idx * b_dim + (out_idx % b_dim)

    return a_idx, b_idx


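# Launch heuristics (an autotuned variant is kept commented out below):
# BLOCK_N covers up to 8192 columns, and BLOCK_M is the next power of two
# that splits M into roughly a dozen row blocks. builtins.min makes explicit
# that this is Python's host-side min, not a Triton primitive.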
def heur_block_n(args):
    import builtins

    return builtins.min(args["N"], 8192)


def heur_block_m(args):
    return triton.next_power_of_2(triton.cdiv(args["M"], 12))


# @triton.autotune(configs=runtime.get_tuned_config("kron"), key=["M", "N"])
@triton.heuristics(
    {
        "BLOCK_M": heur_block_m,
        "BLOCK_N": heur_block_n,
    }
)
@triton.jit
def kron_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    map_ptr,
    batch_size: tl.int64,
    M: tl.int64,
    N: tl.int64,
    M1: tl.int64,
    M2: tl.int64,
    N1: tl.int64,
    N2: tl.int64,
    a_stride_0: tl.int64,
    a_stride_1: tl.int64,
    b_stride_0: tl.int64,
    b_stride_1: tl.int64,
    c_stride_0: tl.int64,
    c_stride_1: tl.int64,
    a_batch_stride: tl.int64,
    b_batch_stride: tl.int64,
    c_batch_stride: tl.int64,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
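    """Compute one BLOCK_M x BLOCK_N tile of one output batch.

    map_ptr holds interleaved (a_batch_idx, b_batch_idx) pairs, one pair per
    output batch, so each program first looks up which batches of A and B it
    reads from."""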
    pid = tle.program_id(0)
    num_blocks_n = tl.cdiv(N, BLOCK_N)
    num_blocks_m = tl.cdiv(M, BLOCK_M)
    num_blocks_per_batch = num_blocks_m * num_blocks_n

    batch_id = pid // num_blocks_per_batch
    local_pid = pid % num_blocks_per_batch
    block_m = local_pid // num_blocks_n
    block_n = local_pid % num_blocks_n

    offs_m = block_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = block_n * BLOCK_N + tl.arange(0, BLOCK_N)

    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) & (batch_id < batch_size)

    offset = batch_id * 2
    is_valid = batch_id < batch_size
    a_batch_idx = tl.load(map_ptr + offset, mask=is_valid)
    b_batch_idx = tl.load(map_ptr + offset + 1, mask=is_valid)

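    # Each output element (m, n) of this batch is A[m // M2, n // N2] * B[m % M2, n % N2].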
    a_row = offs_m[:, None] // M2
    a_col = offs_n[None, :] // N2
    b_row = offs_m[:, None] % M2
    b_col = offs_n[None, :] % N2

    a_idx = a_batch_idx * a_batch_stride + a_row * a_stride_0 + a_col * a_stride_1
    b_idx = b_batch_idx * b_batch_stride + b_row * b_stride_0 + b_col * b_stride_1

    a = tl.load(a_ptr + a_idx, mask=mask)
    b = tl.load(b_ptr + b_idx, mask=mask)
    c = a * b

    c_idx = (
        batch_id * c_batch_stride
        + offs_m[:, None] * c_stride_0
        + offs_n[None, :] * c_stride_1
    )
    tl.store(c_ptr + c_idx, c, mask=mask)


def kron(A, B):
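    """Kronecker product of A and B, following the semantics of torch.kron."""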
    if A.dim() == 0 and B.dim() == 0:
        return A * B

    if A.numel() == 0 or B.numel() == 0:
        A_prepared, B_prepared, out_shape = prepare_tensor_for_kron(A, B)
        output_dtype = torch.promote_types(A.dtype, B.dtype)
        return torch.empty(out_shape, device=A.device, dtype=output_dtype)

    if A.dim() == 0:
        return A.unsqueeze(0) * B
    if B.dim() == 0:
        return A * B.unsqueeze(0)

    A_prepared, B_prepared, out_shape = prepare_tensor_for_kron(A, B)
    M1, N1 = A_prepared.shape[-2:]
    M2, N2 = B_prepared.shape[-2:]
    M, N = M1 * M2, N1 * N2

    batch_size = math.prod(out_shape[:-2]) if out_shape[:-2] else 1

    output_dtype = torch.promote_types(A.dtype, B.dtype)
    C = torch.empty(out_shape, device=A.device, dtype=output_dtype)

    C_reshaped = C.view(-1, M, N)
    A_view = A_prepared.reshape(-1, M1, N1)
    B_view = B_prepared.reshape(-1, M2, N2)

    if not A_view.is_contiguous():
        A_view = A_view.contiguous()
    if not B_view.is_contiguous():
        B_view = B_view.contiguous()

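    # Host-side lookup table consumed by the kernel: for output batch i,
    # entries (2*i, 2*i + 1) hold the batch indices into A_view and B_view.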
    batch_indices = torch.empty(batch_size * 2, device=A.device, dtype=torch.int64)
    for i in range(batch_size):
        a_idx, b_idx = calculate_indices(i, A_prepared.shape, B_prepared.shape)
        batch_indices[i * 2] = a_idx
        batch_indices[i * 2 + 1] = b_idx

    a_batch_stride = M1 * N1
    b_batch_stride = M2 * N2
    c_batch_stride = M * N
    with torch_device_fn.device(A.device):
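        # One program per (batch, BLOCK_M x BLOCK_N) tile of the output.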
        grid = lambda meta: (
            batch_size
            * triton.cdiv(M, meta["BLOCK_M"])
            * triton.cdiv(N, meta["BLOCK_N"]),
        )

        kron_kernel[grid](
            A_view,
            B_view,
            C_reshaped,
            batch_indices,
            batch_size,
            M,
            N,
            M1,
            M2,
            N1,
            N2,
            A_view.stride(1),
            A_view.stride(2),
            B_view.stride(1),
            B_view.stride(2),
            C_reshaped.stride(1),
            C_reshaped.stride(2),
            a_batch_stride,
            b_batch_stride,
            c_batch_stride,
        )

    if A.dim() <= 1 and B.dim() <= 1:
        return C.reshape(-1)

    return C
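

# Minimal usage sketch (not part of the op itself; assumes a device supported
# by flag_gems, labelled "cuda" here):
#
#   a = torch.randn(3, 4, device="cuda")
#   b = torch.randn(2, 5, device="cuda")
#   torch.testing.assert_close(kron(a, b), torch.kron(a, b))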