diff --git a/benchmark/attri_util.py b/benchmark/attri_util.py index dad081575..aa04fd015 100644 --- a/benchmark/attri_util.py +++ b/benchmark/attri_util.py @@ -28,26 +28,30 @@ ] -# This function is adapted from: https://github.com/pytorch-labs/tritonbench/blob/main/tritonbench/utils/triton_op.py -def llama_shapes(): +def model_shapes(): # batch sizes * seq lengths - BS = [2**i for i in range(0, 17)] + BS = [2**i for i in range(0, 9, 2)] # attn: wqkv, wo; ffn: w13, w2 - KN = [ - (4096, 12288), + NK = [ + # extract from llama3-8b + (1024, 4096), + (128256, 4096), + (14336, 4096), + (4096, 14336), (4096, 4096), - (4096, 22016), - (11008, 4096), - (8192, 1280), - (1024, 8192), - (8192, 7168), - (3584, 8192), - (16384, 2304), - (2048, 16384), - (16384, 13312), - (6656, 16384), + (6144, 4096), + (28672, 4096), + # extract from qwen2.5-7b + (3584, 3584), + (18944, 3584), + (3584, 18944), + (152064, 3584), + (37888, 3584), + (512, 3584), + (4608, 3584), ] - return [(bs, n, k, None) for bs, (k, n) in itertools.product(BS, KN)] + + return [(4, bs, n, k) for bs, (n, k) in itertools.product(BS, NK)] @dataclass diff --git a/benchmark/core_shapes.yaml b/benchmark/core_shapes.yaml index 616eed0d5..5ea08dce9 100644 --- a/benchmark/core_shapes.yaml +++ b/benchmark/core_shapes.yaml @@ -1,11 +1,3 @@ -outer: - shapes: - - [384, 384] - - [1024, 1024] - - [4096, 4096] - - [8192, 8192] - - [10240, 10240] #from perf - randperm: shapes: - [64] @@ -41,13 +33,21 @@ diag: BlasBenchmark: shapes: + - [2, 384, 384, 384] - [2, 4096, 4096, 4096] - - [16, 384, 384, 384] - [16, 1024, 1024, 1024] - [16, 2048, 2048, 2048] - [16, 4096, 4096, 4096] shape_desc: "B, M, N, K" # shapes are defined as (B, M, N, K) +MvAndOuterBenchmark: + shapes: + - [384, 384] + - [1024, 1024] + - [4096, 4096] + - [8192, 8192] + - [10240, 10240] #from perf + # NORM shapes can be either 3D or 4D: # - 3D shapes are represented as [batch_size, channels, hidden_size] # - 4D shapes are represented as [batch_size, channels, height, width] diff --git a/benchmark/test_blas_perf.py b/benchmark/test_blas_perf.py index e57894f96..dd547b406 100644 --- a/benchmark/test_blas_perf.py +++ b/benchmark/test_blas_perf.py @@ -1,12 +1,11 @@ -import itertools from typing import Generator import pytest import torch -from .attri_util import DEFAULT_METRICS, FLOAT_DTYPES, BenchLevel, llama_shapes +from .attri_util import DEFAULT_METRICS, FLOAT_DTYPES, BenchLevel, model_shapes from .conftest import Config -from .performance_utils import Benchmark +from .performance_utils import Benchmark, GenericBenchmark2DOnly class BlasBenchmark(Benchmark): @@ -22,31 +21,23 @@ def __init__(self, *args, input_fn, **kwargs): def get_input_iter(self, cur_dtype) -> Generator: for b, m, n, k in self.shapes: - yield from self.input_fn(b, m, n, k, cur_dtype, self.device) - # llama shapes + yield from self.input_fn(b, m, n, k, cur_dtype, self.device, False) + if Config.bench_level == BenchLevel.COMPREHENSIVE: - for m, n, k, _ in llama_shapes(): - yield from self.input_fn(1, m, n, k, cur_dtype, self.device) + for b, m, n, k in self.shapes: + yield from self.input_fn(b, m, n, k, cur_dtype, self.device, True) def set_more_shapes(self): - split_k_shapes = [ - (1, m, m, k) - for m in [16 * i for i in range(1, 5)] - for k in [4096 * i for i in range(1, 9)] + large_k_shapes = [ + (8, 1848, 1536, 151936), + (8, 1848, 1536, 128256), + (8, 1848, 1536, 152064), ] - # 'mv' operations only involve M and N dimensions. - # Shapes with large K values are not suitable for these two operations. 
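Note on the new `model_shapes()` helper above: it expands the batch-size/seq-length list against the per-model `(N, K)` weight shapes with `itertools.product`, emitting one `(B, M, N, K)` tuple per combination with a fixed batch of 4. A minimal standalone sketch of that expansion (the `(N, K)` list is abbreviated from the hunk above; nothing outside the patch is assumed):

```python
import itertools

# batch_size * seq_len values: 1, 4, 16, 64, 256
BS = [2**i for i in range(0, 9, 2)]
# (N, K) pairs extracted from llama3-8b / qwen2.5-7b weights (abbreviated)
NK = [(1024, 4096), (14336, 4096), (3584, 18944)]

# every (bs, (n, k)) combination becomes a (B, M, N, K) = (4, bs, n, k) shape
shapes = [(4, bs, n, k) for bs, (n, k) in itertools.product(BS, NK)]
assert len(shapes) == len(BS) * len(NK)   # 5 * 3 = 15 here; 5 * 14 = 70 for the full list
assert shapes[0] == (4, 1, 1024, 4096)
```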
- if self.op_name not in ["mv"]: - # B=1 or 4, M= 13, N= 2 , K=2^6..2^15 - large_k_shapes = list( - itertools.product([1, 4], [13], [2], [2**i for i in range(6, 15)]) - ) - return large_k_shapes + split_k_shapes - return split_k_shapes + + model_shaps = model_shapes() + return large_k_shapes + model_shaps def get_tflops(self, op, *args, **kwargs): - """This method is currently not really implemented and serves as a placeholder. - A proper implementation will be developed in the future.""" total_flops = 0 # shape(m,k)(k,n) # total_flops mxnx2k @@ -54,13 +45,12 @@ def get_tflops(self, op, *args, **kwargs): total_flops = args[0].shape[0] * args[0].shape[1] * args[1].shape[1] * 2 # shape(m,n)(n,p) # total_flops mxpx(2n+1) - if self.op_name == "addmm": + elif self.op_name == "addmm": total_flops = ( args[0].shape[0] * args[1].shape[1] * (args[1].shape[0] * 2 + 1) ) - # shape(b,n,m), (b,m,p) # total_flops bxnxpx2m - if self.op_name == "bmm": + elif self.op_name == "bmm": total_flops = ( args[0].shape[0] * args[0].shape[1] @@ -68,37 +58,38 @@ def get_tflops(self, op, *args, **kwargs): * 2 * args[0].shape[2] ) - # shape(n,m)(m,) - # total_flops n*2m - if self.op_name == "mv": - total_flops = args[0].shape[0] * 2 * args[0].shape[1] - return total_flops -def addmm_input_fn(b, m, n, k, cur_dtype, device): +def addmm_input_fn(b, m, n, k, cur_dtype, device, b_column_major): inp1 = torch.randn([m, k], dtype=cur_dtype, device=device) - inp2 = torch.randn([k, n], dtype=cur_dtype, device=device) bias = torch.randn([m, n], dtype=cur_dtype, device=device) - yield bias, inp1, inp2, + if b_column_major: + inp2 = torch.randn([n, k], dtype=cur_dtype, device=device) + yield bias, inp1, inp2.t(), + else: + inp2 = torch.randn([k, n], dtype=cur_dtype, device=device) + yield bias, inp1, inp2, -def bmm_input_fn(b, m, n, k, cur_dtype, device): +def bmm_input_fn(b, m, n, k, cur_dtype, device, b_column_major): inp1 = torch.randn([b, m, k], dtype=cur_dtype, device=device) - inp2 = torch.randn([b, k, n], dtype=cur_dtype, device=device) - yield inp1, inp2 + if b_column_major: + inp2 = torch.randn([b, n, k], dtype=cur_dtype, device=device) + yield inp1, inp2.transpose(1, 2) + else: + inp2 = torch.randn([b, k, n], dtype=cur_dtype, device=device) + yield inp1, inp2 -def mm_input_fn(b, m, n, k, cur_dtype, device): +def mm_input_fn(b, m, n, k, cur_dtype, device, b_column_major): inp1 = torch.randn([m, k], dtype=cur_dtype, device=device) - inp2 = torch.randn([k, n], dtype=cur_dtype, device=device) - yield inp1, inp2 - - -def mv_input_fn(b, m, n, k, cur_dtype, device): - inp1 = torch.randn([m, n], dtype=cur_dtype, device=device) - inp2 = torch.randn([n], dtype=cur_dtype, device=device) - yield inp1, inp2 + if b_column_major: + inp2 = torch.randn([n, k], dtype=cur_dtype, device=device) + yield inp1, inp2.t() + else: + inp2 = torch.randn([k, n], dtype=cur_dtype, device=device) + yield inp1, inp2 @pytest.mark.parametrize( @@ -122,12 +113,6 @@ def mv_input_fn(b, m, n, k, cur_dtype, device): mm_input_fn, marks=pytest.mark.mm, ), - pytest.param( - "mv", - torch.Tensor.mv, - mv_input_fn, - marks=pytest.mark.mv, - ), ], ) def test_blas_benchmark(op_name, torch_op, input_fn): @@ -137,9 +122,9 @@ def test_blas_benchmark(op_name, torch_op, input_fn): bench.run() -class OuterBenchmark(BlasBenchmark): +class MvAndOuterBenchmark(GenericBenchmark2DOnly): """ - benchmark for outer + Benchmark for MV and Outer operations """ def set_more_shapes(self): @@ -150,17 +135,40 @@ def get_input_iter(self, cur_dtype) -> Generator: yield from 
self.input_fn(m, n, cur_dtype, self.device) -@pytest.mark.outer -def test_outer_benchmark(): - def outer_input_fn(m, n, cur_dtype, device): - inp1 = torch.randn([m], dtype=cur_dtype, device=device) - inp2 = torch.randn([n], dtype=cur_dtype, device=device) - yield inp1, inp2 +def mv_input_fn(m, n, cur_dtype, device): + inp1 = torch.randn([m, n], dtype=cur_dtype, device=device) + inp2 = torch.randn([n], dtype=cur_dtype, device=device) + yield inp1, inp2 + + +def outer_input_fn(m, n, cur_dtype, device): + inp1 = torch.randn([m], dtype=cur_dtype, device=device) + inp2 = torch.randn([n], dtype=cur_dtype, device=device) + yield inp1, inp2 + - bench = OuterBenchmark( - input_fn=outer_input_fn, - op_name="outer", - torch_op=torch.Tensor.outer, +@pytest.mark.parametrize( + "op_name, torch_op, input_fn", + [ + pytest.param( + "mv", + torch.Tensor.mv, + mv_input_fn, + marks=pytest.mark.mv, + ), + pytest.param( + "outer", + torch.Tensor.outer, + outer_input_fn, + marks=pytest.mark.outer, + ), + ], +) +def test_mv_and_outer_benchmark(op_name, torch_op, input_fn): + bench = MvAndOuterBenchmark( + input_fn=input_fn, + op_name=op_name, + torch_op=torch_op, dtypes=FLOAT_DTYPES, ) bench.run() diff --git a/src/flag_gems/ops/mm.py b/src/flag_gems/ops/mm.py index 40a3a66a4..3251bebc6 100644 --- a/src/flag_gems/ops/mm.py +++ b/src/flag_gems/ops/mm.py @@ -9,71 +9,408 @@ from ..utils import libentry, libtuner from ..utils import triton_lang_extension as tle +try: + device_id = torch_device_fn.current_device() +except AttributeError: + device_id = 0 + +try: + L2_CACHE_SIZE = torch_device_fn.get_device_properties(device_id).L2_cache_size + SM_COUNT = torch_device_fn.get_device_properties(device_id).multi_processor_count +except AttributeError: + L2_CACHE_SIZE = 40 * 1024 * 1024 # 40MB in bytes + SM_COUNT = 82 # nvidia 3090 +CACHE_USAGE_THRESHOLD = 0.7 + +# TODO: make ALLOW_TF32 as an input params. + + +@triton.jit() +def swizzle_tile( + tile_id, + M, + N, + K, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, +): + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = tile_id // width + group_size = tl.minimum(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (tile_id % group_size) + pid_n = (tile_id % width) // group_size + return pid_m, pid_n + @libentry() -@libtuner( - configs=runtime.get_tuned_config("mm"), - key=["M", "N", "K"], +@triton.heuristics( + { + "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"]) == 0, + } ) -@triton.heuristics(runtime.get_heuristic_config("mm")) @triton.jit -def mm_kernel( +def first_wave( A, B, C, M, N, K, + locks, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, - dot_out_dtype: tl.constexpr, + total_full_tiles_streamk, + total_partial_tiles_streamk, + iters_per_tile, + ACC_TYPE: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr, - SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, ): + pid = tl.program_id(0) # pid range from 0 to sm_count + start_iter = pid * total_full_tiles_streamk + tl.minimum( + pid, total_partial_tiles_streamk + ) + last_iter = (pid + 1) * total_full_tiles_streamk + tl.minimum( + pid + 1, total_partial_tiles_streamk + ) + while start_iter < last_iter: + remain_iters = start_iter % iters_per_tile + # Iterate over the K axis. Recalculate end_iter as M/N may change during the iteration. 
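For context on the work split computed just above: `first_wave` gives each of the `sm_count` programs an equal share of the K-loop iterations belonging to the leftover (Stream-K) tiles, with the first `total_partial_tiles_streamk` programs taking one extra iteration. A hedged host-side sketch of the same `start_iter`/`last_iter` arithmetic (plain Python, names abbreviated from the kernel arguments, not part of the patch):

```python
def streamk_iter_ranges(sm_count, total_tiles_streamk, iters_per_tile):
    """Half-open [start_iter, last_iter) slice of K-iterations owned by each program id."""
    total_iters = total_tiles_streamk * iters_per_tile
    full = total_iters // sm_count     # iterations every program receives
    partial = total_iters % sm_count   # the first `partial` programs receive one more
    return [
        (pid * full + min(pid, partial), (pid + 1) * full + min(pid + 1, partial))
        for pid in range(sm_count)
    ]

# 6 leftover tiles x 8 K-iterations spread over 4 programs: 12 iterations each
assert streamk_iter_ranges(4, 6, 8) == [(0, 12), (12, 24), (24, 36), (36, 48)]
# 5 leftover tiles (40 iterations) over 3 programs: the remainder goes to the first program
assert streamk_iter_ranges(3, 5, 8) == [(0, 14), (14, 27), (27, 40)]
```

These slices generally cut across tile boundaries, which is why the loop below recomputes `end_iter` per tile.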
+ end_iter = tl.minimum(start_iter + (iters_per_tile - remain_iters), last_iter) + + tile_id = start_iter // iters_per_tile + + pid_m, pid_n = swizzle_tile( + tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M + ) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + rk = tl.arange(0, BLOCK_K) + + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + + # pointers + A_ptr = ( + A + + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + + BLOCK_K * stride_ak * remain_iters + ) + B_ptr = ( + B + + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + + BLOCK_K * stride_bk * remain_iters + ) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + + for current_iter in range(start_iter, end_iter): + if EVEN_K: + a = tl.load(A_ptr) + b = tl.load(B_ptr) + else: + k_mask = (current_iter % iters_per_tile) * BLOCK_K + rk < K + _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty) + a = tl.load(A_ptr, mask=k_mask[None, :], other=_0) + b = tl.load(B_ptr, mask=k_mask[:, None], other=_0) + acc += tl.dot(a, b, out_dtype=ACC_TYPE, allow_tf32=False) + A_ptr += BLOCK_K * stride_ak + B_ptr += BLOCK_K * stride_bk + # last iteration of the tile always happens before its start on another SM + if end_iter % iters_per_tile == 0: + C_ptr = C + ( + rm[:, None] * stride_cm + rn[None, :] * stride_cn + ) # compute inside the if/else to avoid spilling! + mask = (rm < M)[:, None] & (rn < N)[None, :] + tl.store(C_ptr, acc, mask=mask) + if remain_iters != 0: # only if tile has been partially processed + tl.atomic_xchg(locks + tile_id, 1) + else: + while tl.atomic_cas(locks + tile_id, 1, 1) != 1: + pass + C_ptr = C + ( + rm[:, None] * stride_cm + rn[None, :] * stride_cn + ) # compute inside the if/else to avoid spilling! 
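A brief note on the per-tile synchronization here: the program that computes a tile's final K-slice writes its accumulator with a plain `tl.store` and, if another program contributed the head of that tile, raises the tile's lock; any program holding only a head or middle slice spins on that lock (the `tl.atomic_cas` loop) and then adds its partial sum on top with the relaxed `tl.atomic_add` that follows. An illustrative single-tile restatement in plain Python (not part of the patch; `locks[tile_id]` starts at 0):

```python
def finish_slice(owns_tail, shared_head, tile_id, locks, C_tile, acc):
    # owns_tail: end_iter landed on a tile boundary (this program computed the last K-slice)
    # shared_head: remain_iters != 0 (another program computed earlier slices of this tile)
    if owns_tail:
        C_tile[:] = acc               # tl.store of the tail contribution
        if shared_head:
            locks[tile_id] = 1        # tl.atomic_xchg: publish the store
    else:
        while locks[tile_id] != 1:    # tl.atomic_cas spin: wait for the tail store
            pass
        C_tile[:] += acc              # tl.atomic_add(..., sem="relaxed") of the partial sum
```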
+ mask = (rm < M)[:, None] & (rn < N)[None, :] + tl.atomic_add(C_ptr, acc, mask=mask, sem="relaxed") + start_iter = end_iter + + +@libentry() +@triton.heuristics( + { + "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"]) == 0, + } +) +@triton.jit +def classic_tiles_mm( + A, + B, + C, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + total_tiles_streamk, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + ACC_TYPE: tl.constexpr, + GROUP_M: tl.constexpr, + EVEN_K: tl.constexpr, +): + # first wave has done more tiles than there are SMs, we adjust pid + tile_id = tl.program_id(0) + total_tiles_streamk + pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) + + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + rk = tl.arange(0, BLOCK_K) + # pointers + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for k in range(0, tl.cdiv(K, BLOCK_K)): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + k_remaining = K - k * BLOCK_K + _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty) + a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0) + b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0) + acc += tl.dot(a, b, out_dtype=ACC_TYPE, allow_tf32=False) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + acc = acc.to(C.dtype.element_ty) + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + mask = (rm < M)[:, None] & (rn < N)[None, :] + tl.store(C, acc, mask=mask) + + +@libentry() +@libtuner( + configs=runtime.get_tuned_config("mm_iobound") + runtime.get_tuned_config("mm"), + key=["M", "N", "K", "stride_am", "stride_bk"], +) +@triton.heuristics( + { + "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"]) == 0, + } +) +@triton.jit +def mm_kernel_with_grouped_k( + A, + B, + C, # [Split_K, M, N] + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cb, + stride_cm, + stride_cn, + acc_type: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + SPLIT_K: tl.constexpr, # Number of split-K groups + GROUP_K_LENGTH: tl.constexpr, + EVEN_K: tl.constexpr, +): + pid = tl.program_id(0) + assert GROUP_K_LENGTH % BLOCK_K == 0, "GROUP_K_LENGTH must be divisible by BLOCK_K" + + num_blocks_m = tl.cdiv(M, BLOCK_M) + total_num_m = num_blocks_m * SPLIT_K + + pid_n = pid // total_num_m + odd_column = pid_n % 2 + pid_m_normal = pid % total_num_m + # this is a line-one implementation for the following code: + # if odd_column: + # pid_m_for_c = (total_num_m - 1) - pid_m_normal + # else: + # pid_m_for_c = pid_m_normal + pid_m_for_c = (1 - odd_column) * pid_m_normal + odd_column * ( + total_num_m - 1 - pid_m_normal + ) + + pid_m = pid_m_for_c % num_blocks_m + pid_k = pid_m_for_c // num_blocks_m + + # Calculate K_LENGTH based on pid_k + group_k_length = min(K - pid_k * GROUP_K_LENGTH, GROUP_K_LENGTH) + # matrix multiplication + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + k_start = pid_k * GROUP_K_LENGTH + offs_k = k_start + 
tl.arange(0, BLOCK_K) + + offs_am = tl.max_contiguous(tl.multiple_of(offs_m % M, BLOCK_M), BLOCK_M) + offs_bn = tl.max_contiguous(tl.multiple_of(offs_n % N, BLOCK_N), BLOCK_N) + + # pointers + A_ptr = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + B_ptr = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_type) + + for k in range(0, tl.cdiv(group_k_length, BLOCK_K)): + if EVEN_K: + a = tl.load(A_ptr) + b = tl.load(B_ptr) + else: + k_remaining = k_start + group_k_length - k * BLOCK_K + a = tl.load(A_ptr, mask=offs_k[None, :] < k_remaining, other=0.0) + b = tl.load(B_ptr, mask=offs_k[:, None] < k_remaining, other=0.0) + if a.dtype != b.dtype: + a = a.to(C.dtype.element_ty) + b = b.to(C.dtype.element_ty) + acc += tl.dot(a, b, out_dtype=acc_type, allow_tf32=False) + A_ptr += BLOCK_K * stride_ak + B_ptr += BLOCK_K * stride_bk + acc = acc.to(C.dtype.element_ty) + + # Store results + offs_cb = pid_k * stride_cb + offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + C_ptr = C + offs_cb + (offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn) + mask = (offs_cm < M)[:, None] & (offs_cn < N)[None, :] + + tl.store(C_ptr, acc, mask=mask) + + +@libentry() +@triton.autotune(configs=runtime.get_tuned_config("sum"), key=["M", "N"]) +@triton.jit +def group_merge_kernel( + SRC, # [SPLIT_K, M, N] 3D Tensor + DST, # [M, N] + SPLIT_K, + M, + N, + stride_k, + stride_m, + stride_n, + acc_type: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + offs_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) + + mask_m = offs_m < M + mask_n = offs_n < N + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_type) + + for k in range(SPLIT_K): + src_ptr = ( + SRC + k * stride_k + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + ) + sub_matrix = tl.load(src_ptr, mask=mask_m[:, None] & mask_n[None, :], other=0.0) + + acc += sub_matrix + acc = acc.to(DST.dtype.element_ty) + dst_ptr = DST + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + tl.store(dst_ptr, acc, mask=mask_m[:, None] & mask_n[None, :]) + + +@libentry() +@libtuner( + configs=runtime.get_tuned_config("mm_iobound"), + # Add 'stride_am' and 'stride_bk' to trigger autotune for tensors with the same shape but different strides. 
+ key=["M", "N", "K", "stride_am", "stride_bk"], +) +@triton.heuristics( + { + "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"]) == 0, + } +) +@triton.jit +def mm_kernel_iobound( + A, + B, + C, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + acc_type: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, +): + # column major tile pid = tle.program_id(0) - pid_z = tle.program_id(1) grid_m = tl.cdiv(M, BLOCK_M) - grid_n = tl.cdiv(N, BLOCK_N) - # re-order program ID for better L2 performance - width = GROUP_M * grid_n - group_id = pid // width - group_size = min(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (pid % group_size) - pid_n = (pid % width) // (group_size) + pid_m = pid % grid_m + pid_n = pid // grid_m + # do matrix multiplication rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + rk = tl.arange(0, BLOCK_K) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) - rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) - # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype) - for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_type) + for k in range(0, tl.cdiv(K, BLOCK_K)): if EVEN_K: a = tl.load(A) b = tl.load(B) else: - k_remaining = K - k * (BLOCK_K * SPLIT_K) + k_remaining = K - k * BLOCK_K _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty) a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0) b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0) if a.dtype != b.dtype: a = a.to(C.dtype.element_ty) b = b.to(C.dtype.element_ty) - acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=False) - A += BLOCK_K * SPLIT_K * stride_ak - B += BLOCK_K * SPLIT_K * stride_bk + acc += tl.dot(a, b, out_dtype=acc_type, allow_tf32=False) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk acc = acc.to(C.dtype.element_ty) # rematerialize rm and rn to save registers rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) @@ -81,19 +418,114 @@ def mm_kernel( C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) mask = (rm < M)[:, None] & (rn < N)[None, :] # handles write-back with reduction-splitting - if SPLIT_K == 1: - tl.store(C, acc, mask=mask) - else: - tl.atomic_add(C, acc, mask=mask) + tl.store(C, acc, mask=mask) -_ordered_datatypes = [torch.float16, torch.bfloat16, torch.float32] +@libentry() +@libtuner( + configs=runtime.get_tuned_config("mm"), + # Add 'stride_am' and 'stride_bk' to trigger autotune for tensors with the same shape but different strides. 
+ key=["M", "N", "K", "stride_am", "stride_bk"], +) +@triton.heuristics( + { + "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"]) == 0, + } +) +@triton.jit +def mm_kernel_general( + A, + B, + C, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + acc_type: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, + EVEN_K: tl.constexpr, +): + # swizzle pid to make better use of the L2 cache + pid_m_, pid_n_ = tle.program_id(0), tle.program_id(1) + num_pid_m, num_pid_n = tl.num_programs(0), tl.num_programs(1) + pid_m, pid_n = tl.swizzle2d(pid_m_, pid_n_, num_pid_m, num_pid_n, GROUP_M) + + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + rk = tl.arange(0, BLOCK_K) + + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_type) + for k in range(0, tl.cdiv(K, BLOCK_K)): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + k_remaining = K - k * BLOCK_K + _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty) + a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0) + b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0) + if a.dtype != b.dtype: + a = a.to(C.dtype.element_ty) + b = b.to(C.dtype.element_ty) + acc += tl.dot(a, b, out_dtype=acc_type, allow_tf32=False) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + acc = acc.to(C.dtype.element_ty) + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + mask = (rm < M)[:, None] & (rn < N)[None, :] + # handles write-back with reduction-splitting + tl.store(C, acc, mask=mask) + + +def mini_mm_scenario(a, b, l2_cache_size=40 * 1024 * 1024, cache_usage_threshold=0.8): + return ( + a.shape[0] <= 256 + and (a.numel() * a.element_size() + b.shape[0] * b.element_size()) + < l2_cache_size * cache_usage_threshold + ) + + +def streamk_scenario(a, b, M, N, K): + # TODO: this my change sometime according to the realbenchmark result + # Currently, the best configuration for streamk has only been tested on A100(capability[0] > 7). + # The optimal settings for other devices need to be determined through real testing. 
+ capability = torch_device_fn.get_device_capability(device_id) + return ( + capability[0] > 7 + and a.dtype in [torch.float16] + and b.dtype in [torch.float16] + and M > 1024 + and N > 1024 + and K > M * 10 + ) + + +def two_stages_splitk_mm_scenario(M, N, K): + return (M < 32 or N < 32) and (K > M * 10 or K > N * 10) def get_higher_dtype(a, b): if a is b: return a - + _ordered_datatypes = [torch.float16, torch.bfloat16, torch.float32] assert a in _ordered_datatypes assert b in _ordered_datatypes @@ -104,29 +536,169 @@ def get_higher_dtype(a, b): return a -def mm(a, b): - logging.debug("GEMS MM") - device = a.device - # handle non-contiguous inputs if necessary - if a.stride(0) > 1 and a.stride(1) > 1: - a = a.contiguous() - if b.stride(0) > 1 and b.stride(1) > 1: - b = b.contiguous() - # checks constraints - assert a.shape[1] == b.shape[0], "incompatible dimensions" - M, K = a.shape - _, N = b.shape - # allocates output - c_dtype = get_higher_dtype(a.dtype, b.dtype) - c = torch.empty((M, N), device=device, dtype=c_dtype) - dot_out_dtype = tl.float32 +def streamk_mm(a, b, c, M, N, K, c_dtype, acc_type, sm_count=108): + # TODO: profile to different settings for different chip + if b.stride(0) == 1: + BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 128 + num_stages = 3 + num_warps = 8 + else: + BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 64 + num_stages = 3 + num_warps = 16 + + GROUP_M = 8 + number_blocks_m = triton.cdiv(M, BLOCK_M) + number_blocks_n = triton.cdiv(N, BLOCK_N) + + total_tiles = number_blocks_m * number_blocks_n + iters_per_tile = triton.cdiv(K, BLOCK_K) + tiles_per_wave = sm_count + + # tiles that would executed in the last wave in general situation. + # and this is the tiles that we are going to adopt streamk) + total_tiles_streamk = total_tiles % tiles_per_wave + # mini wave + total_iters_streamk = total_tiles_streamk * iters_per_tile + total_full_tiles_streamk = total_iters_streamk // tiles_per_wave + total_partial_tiles_streamk = total_iters_streamk % tiles_per_wave + + locks = torch.zeros((total_tiles_streamk,), device=a.device, dtype=torch.int32) + + with torch_device_fn.device(a.device): + first_wave[(tiles_per_wave,)]( + a, + b, + c, + M, + N, + K, + locks, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + total_full_tiles_streamk=total_full_tiles_streamk, + total_partial_tiles_streamk=total_partial_tiles_streamk, + iters_per_tile=iters_per_tile, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_K=BLOCK_K, + ACC_TYPE=acc_type, + GROUP_M=GROUP_M, + num_stages=num_stages, + num_warps=num_warps, + ) + + classic_tiles_mm[(total_tiles - total_tiles_streamk,)]( + a, + b, + c, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + total_tiles_streamk=total_tiles_streamk, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_K=BLOCK_K, + ACC_TYPE=acc_type, + GROUP_M=GROUP_M, + num_stages=num_stages, + num_warps=num_warps, + ) + return c + + +def splitk_mm(a, b, c, M, N, K, c_dtype, acc_type): + logging.debug("GEMS MM (SPLITK)") + GROUP_K_LENGTH = 1024 + SPLIT_K = triton.cdiv(K, GROUP_K_LENGTH) + # TODO: float32 or c_dtype + multi_c = torch.empty((SPLIT_K, M, N), device=a.device, dtype=c_dtype) + # 1st kernel: compute partial results + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]) * SPLIT_K, + ) + grid2 = lambda META: ( + triton.cdiv(M, META["BLOCK_M"]), + triton.cdiv(N, META["BLOCK_N"]), + ) + with torch_device_fn.device(a.device): + mm_kernel_with_grouped_k[grid]( + 
a, + b, + multi_c, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + multi_c.stride(0), + multi_c.stride(1), + multi_c.stride(2), + acc_type=acc_type, + SPLIT_K=SPLIT_K, + GROUP_K_LENGTH=GROUP_K_LENGTH, + ) + # return torch.sum(multi_c, dim=0) + # 2nd kernel: merge partial results + group_merge_kernel[grid2]( + multi_c, + c, + SPLIT_K, + M, + N, + multi_c.stride(0), + multi_c.stride(1), + multi_c.stride(2), + acc_type=acc_type, + ) + return c + + +def iobound_mm(a, b, c, M, N, K, acc_type): + logging.debug("GEMS MM (iobound)") # launch kernel grid = lambda META: ( triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), - META["SPLIT_K"], ) with torch_device_fn.device(a.device): - mm_kernel[grid]( + mm_kernel_iobound[grid]( + a, + b, + c, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + acc_type=acc_type, + ) + return c + + +def general_mm(a, b, c, M, N, K, acc_type): + logging.debug("GEMS MM (general)") + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_M"]), + triton.cdiv(N, META["BLOCK_N"]), + ) + with torch_device_fn.device(a.device): + mm_kernel_general[grid]( a, b, c, @@ -139,7 +711,33 @@ def mm(a, b): b.stride(1), c.stride(0), c.stride(1), - dot_out_dtype=dot_out_dtype, + acc_type=acc_type, GROUP_M=8, ) return c + + +def mm(a, b): + device = a.device + # handle non-contiguous inputs if necessary + if a.stride(0) > 1 and a.stride(1) > 1: + a = a.contiguous() + if b.stride(0) > 1 and b.stride(1) > 1: + b = b.contiguous() + # checks constraints + assert a.shape[1] == b.shape[0], "incompatible dimensions" + M, K = a.shape + _, N = b.shape + # allocates output + c_dtype = get_higher_dtype(a.dtype, b.dtype) + c = torch.empty((M, N), device=device, dtype=c_dtype) + acc_type = tl.float32 + + if mini_mm_scenario(a, b, L2_CACHE_SIZE, CACHE_USAGE_THRESHOLD): + return iobound_mm(a, b, c, M, N, K, acc_type) + elif streamk_scenario(a, b, M, N, K): + return streamk_mm(a, b, c, M, N, K, c_dtype, acc_type, sm_count=SM_COUNT) + elif two_stages_splitk_mm_scenario(M, N, K): + return splitk_mm(a, b, c, M, N, K, c_dtype, acc_type) + else: + return general_mm(a, b, c, M, N, K, acc_type) diff --git a/src/flag_gems/ops/mm_experimental/mm_streamk_experimental.py b/src/flag_gems/ops/mm_experimental/mm_streamk_experimental.py new file mode 100644 index 000000000..6a5b09a57 --- /dev/null +++ b/src/flag_gems/ops/mm_experimental/mm_streamk_experimental.py @@ -0,0 +1,330 @@ +import logging + +import torch +import triton +import triton.language as tl + + +@triton.jit() +def swizzle_tile( + tile_id, + M, + N, + K, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, +): + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = tile_id // width + group_size = tl.minimum(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (tile_id % group_size) + pid_n = (tile_id % width) // group_size + return pid_m, pid_n + + +@triton.jit( + do_not_specialize=["full", "remaining", "iters_per_tile", "start_iter", "end_iter"] +) +def mac_loop( + A, + B, + C, + P, + M, + N, + K, + locks, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + full, + remaining, + iters_per_tile, + start_iter, + end_iter, + ACC_TYPE: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, +): + # where 
are we in the grid + pid = tl.program_id(0) + tile_id = start_iter // iters_per_tile + + pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + rk = tl.arange(0, BLOCK_K) + + if stride_am == 1: + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + else: + ram = rm % M + if stride_bk == 1: + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + else: + rbn = rn % N + + # pointers + A_ptr = ( + A + + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + + BLOCK_K * stride_ak * (start_iter % iters_per_tile) + ) + B_ptr = ( + B + + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + + BLOCK_K * stride_bk * (start_iter % iters_per_tile) + ) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + + for current_iter in range(start_iter, end_iter): + # TODO: when k is not even, we need to load the data from A and B with mask. + a = tl.load(A_ptr) + b = tl.load(B_ptr) + # acc += tl.dot(a, b) + acc += tl.dot(a, b, out_dtype=ACC_TYPE, allow_tf32=False) + A_ptr += BLOCK_K * stride_ak + B_ptr += BLOCK_K * stride_bk + + rm1 = tl.arange(0, BLOCK_M) + rn1 = tl.arange(0, BLOCK_N) + + # the first situation: not the starting parts. only need to store the data on P + if start_iter % iters_per_tile != 0: + P_ptr = P + pid * BLOCK_M * BLOCK_N + (rm1[:, None] * BLOCK_N + rn1[None, :]) + tl.store(P_ptr, acc, cache_modifier=".cg") + # tl.debug_barrier() + tl.atomic_xchg(locks + pid, 1) + else: # the first part of certain grids. shoud read datas and merge datas + next_pid = pid + 1 + stop_loading_iter = start_iter + iters_per_tile + end = end_iter + while end < stop_loading_iter: + while tl.atomic_cas(locks + next_pid, 1, 1) != 1: + pass + P_ptr = ( + P + + next_pid * BLOCK_M * BLOCK_N + + (rm1[:, None] * BLOCK_N + rn1[None, :]) + ) + acc += tl.load(P_ptr, cache_modifier=".cg") + end += full + (next_pid < remaining) + next_pid += 1 + + # acc = acc.to(C.dtype.element_ty) # + C_ = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + mask = (rm < M)[:, None] & (rn < N)[None, :] + tl.store(C_, acc, mask=mask) + + +@triton.jit() +def first_wave( + A, + B, + C, + P, + M, + N, + K, + locks, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + full, + remaining, + iters_per_tile, + ACC_TYPE: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, +): + pid = tl.program_id(0) # FROM 0 TO SM_COUNT + start_iter = pid * full + tl.minimum(pid, remaining) + last_iter = (pid + 1) * full + tl.minimum(pid + 1, remaining) + + while start_iter < last_iter: + remainder = start_iter % iters_per_tile + # iterate over K axis, M/N may change during iteration, so we need to re-calculate the end_iter + end_iter = tl.minimum(start_iter + (iters_per_tile - remainder), last_iter) + mac_loop( + A, + B, + C, + P, + M, + N, + K, + locks, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + full, + remaining, + iters_per_tile, + start_iter, + end_iter, + ACC_TYPE, + BLOCK_M, + BLOCK_N, + BLOCK_K, + GROUP_M, + ) + start_iter = end_iter + + +@triton.jit() +def classic_mm( + A, + B, + C, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + total_tiles_streamk, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + ACC_TYPE: tl.constexpr, + GROUP_M: tl.constexpr, +): + # first wave has done more tiles than there are 
SMs, we adjust pid + tile_id = tl.program_id(0) + total_tiles_streamk + + pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) + + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + rk = tl.arange(0, BLOCK_K) + + if stride_am == 1: + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + else: + ram = rm % M + if stride_bk == 1: + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + else: + rbn = rn % N + + # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for k in range(0, tl.cdiv(K, BLOCK_K)): + a = tl.load(A) + b = tl.load(B) + # acc += tl.dot(a, b) + acc += tl.dot(a, b, out_dtype=ACC_TYPE, allow_tf32=False) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + acc = acc.to(C.dtype.element_ty) + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + tl.store(C, acc) + + +def streamk_mm(a, b, c, M, N, K, c_dtype, acc_type, sm_count=108): + # TODO: change the hard code to tuning config + BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 128 + num_stages = 3 + num_warps = 8 + + GROUP_M = 8 + number_blocks_m = triton.cdiv(M, BLOCK_M) + number_blocks_n = triton.cdiv(N, BLOCK_N) + + total_tiles = number_blocks_m * number_blocks_n + iters_per_tile = triton.cdiv(K, BLOCK_K) + tiles_per_wave = sm_count + + number_cooperative_tiles = total_tiles % tiles_per_wave + # mini wave + total_iters_streamk = number_cooperative_tiles * iters_per_tile + tiles_per_pid = total_iters_streamk // tiles_per_wave + tile_remaining = total_iters_streamk % tiles_per_wave + + locks = torch.zeros((tiles_per_wave,), device=a.device, dtype=torch.int32) + P = torch.zeros( + (tiles_per_wave, BLOCK_M, BLOCK_N), device=a.device, dtype=torch.float32 + ) + # with torch_device_fn.device(a.device): + k1 = first_wave[(tiles_per_wave,)]( + a, + b, + c, + P, + M, + N, + K, + locks, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + full=tiles_per_pid, + remaining=tile_remaining, + iters_per_tile=iters_per_tile, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_K=BLOCK_K, + ACC_TYPE=acc_type, + GROUP_M=GROUP_M, + num_stages=num_stages, + num_warps=num_warps, + ) + logging.DEBUG(f"{k1.n_regs} registers used, {k1.n_spills} spills") + logging.DEBUG(f"shared memory: {k1.metadata.shared} bytes") + + k2 = classic_mm[(total_tiles - number_cooperative_tiles,)]( + a, + b, + c, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + total_tiles_streamk=number_cooperative_tiles, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_K=BLOCK_K, + ACC_TYPE=acc_type, + GROUP_M=GROUP_M, + num_stages=num_stages, + num_warps=num_warps, + ) + logging.DEBUG(f"{k2.n_regs} registers used, {k2.n_spills} spills") + logging.DEBUG(f"shared memory: {k2.metadata.shared} bytes") + return c diff --git a/src/flag_gems/runtime/backend/_nvidia/heuristics_config_utils.py b/src/flag_gems/runtime/backend/_nvidia/heuristics_config_utils.py index 4c3145c28..c8094d4d1 100644 --- a/src/flag_gems/runtime/backend/_nvidia/heuristics_config_utils.py +++ b/src/flag_gems/runtime/backend/_nvidia/heuristics_config_utils.py @@ -79,10 +79,6 @@ def index_select_heur_block_n(args): return max(m, 16) -def 
mm_heur_even_k(args): - return args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0 - - def rand_heur_block(args): if args["N"] <= 512: return 512 @@ -244,9 +240,6 @@ def batch_norm_heur_block_n(args): "BLOCK_M": index_select_heur_block_m, "BLOCK_N": index_select_heur_block_n, }, - "mm": { - "EVEN_K": mm_heur_even_k, - }, "rand": { "BLOCK": rand_heur_block, "num_warps": rand_heur_num_warps, diff --git a/src/flag_gems/runtime/backend/_nvidia/tune_configs.yaml b/src/flag_gems/runtime/backend/_nvidia/tune_configs.yaml index 879d9474a..beb8814e6 100644 --- a/src/flag_gems/runtime/backend/_nvidia/tune_configs.yaml +++ b/src/flag_gems/runtime/backend/_nvidia/tune_configs.yaml @@ -142,124 +142,354 @@ log_softmax: BLOCK_M: 32 BLOCK_N: 512 num_warps: 8 +mm_iobound: +- META: {BLOCK_M: 16, BLOCK_N: 32, BLOCK_K: 32} + num_stages: 2 + num_warps: 2 +- META: {BLOCK_M: 16, BLOCK_N: 32, BLOCK_K: 32} + num_stages: 3 + num_warps: 2 +- META: {BLOCK_M: 16, BLOCK_N: 32, BLOCK_K: 32} + num_stages: 4 + num_warps: 2 +- META: {BLOCK_M: 16, BLOCK_N: 32, BLOCK_K: 32} + num_stages: 5 + num_warps: 2 +- META: {BLOCK_M: 16, BLOCK_N: 32, BLOCK_K: 32} + num_stages: 6 + num_warps: 2 +- META: {BLOCK_M: 16, BLOCK_N: 32, BLOCK_K: 64} + num_stages: 2 + num_warps: 2 +- META: {BLOCK_M: 16, BLOCK_N: 32, BLOCK_K: 64} + num_stages: 3 + num_warps: 2 +- META: {BLOCK_M: 16, BLOCK_N: 32, BLOCK_K: 64} + num_stages: 4 + num_warps: 2 +- META: {BLOCK_M: 16, BLOCK_N: 32, BLOCK_K: 64} + num_stages: 5 + num_warps: 2 +- META: {BLOCK_M: 16, BLOCK_N: 32, BLOCK_K: 64} + num_stages: 6 + num_warps: 2 +- META: {BLOCK_M: 16, BLOCK_N: 64, BLOCK_K: 32} + num_stages: 2 + num_warps: 2 +- META: {BLOCK_M: 16, BLOCK_N: 64, BLOCK_K: 32} + num_stages: 3 + num_warps: 2 +- META: {BLOCK_M: 16, BLOCK_N: 64, BLOCK_K: 32} + num_stages: 4 + num_warps: 2 +- META: {BLOCK_M: 16, BLOCK_N: 64, BLOCK_K: 32} + num_stages: 5 + num_warps: 2 +- META: {BLOCK_M: 16, BLOCK_N: 64, BLOCK_K: 32} + num_stages: 6 + num_warps: 2 +- META: {BLOCK_M: 16, BLOCK_N: 64, BLOCK_K: 64} + num_stages: 2 + num_warps: 2 +- META: {BLOCK_M: 16, BLOCK_N: 64, BLOCK_K: 64} + num_stages: 3 + num_warps: 2 +- META: {BLOCK_M: 16, BLOCK_N: 64, BLOCK_K: 64} + num_stages: 4 + num_warps: 2 +- META: {BLOCK_M: 16, BLOCK_N: 64, BLOCK_K: 64} + num_stages: 5 + num_warps: 2 +- META: {BLOCK_M: 16, BLOCK_N: 64, BLOCK_K: 64} + num_stages: 6 + num_warps: 2 +- META: {BLOCK_M: 16, BLOCK_N: 128, BLOCK_K: 32} + num_stages: 2 + num_warps: 4 +- META: {BLOCK_M: 16, BLOCK_N: 128, BLOCK_K: 32} + num_stages: 3 + num_warps: 4 +- META: {BLOCK_M: 16, BLOCK_N: 128, BLOCK_K: 32} + num_stages: 4 + num_warps: 4 +- META: {BLOCK_M: 16, BLOCK_N: 128, BLOCK_K: 32} + num_stages: 5 + num_warps: 4 +- META: {BLOCK_M: 16, BLOCK_N: 128, BLOCK_K: 32} + num_stages: 6 + num_warps: 4 +- META: {BLOCK_M: 16, BLOCK_N: 128, BLOCK_K: 64} + num_stages: 2 + num_warps: 4 +- META: {BLOCK_M: 16, BLOCK_N: 128, BLOCK_K: 64} + num_stages: 3 + num_warps: 4 +- META: {BLOCK_M: 16, BLOCK_N: 128, BLOCK_K: 64} + num_stages: 4 + num_warps: 4 +- META: {BLOCK_M: 16, BLOCK_N: 128, BLOCK_K: 64} + num_stages: 5 + num_warps: 4 +- META: {BLOCK_M: 16, BLOCK_N: 128, BLOCK_K: 64} + num_stages: 6 + num_warps: 4 +- META: {BLOCK_M: 16, BLOCK_N: 256, BLOCK_K: 32} + num_stages: 2 + num_warps: 4 +- META: {BLOCK_M: 16, BLOCK_N: 256, BLOCK_K: 32} + num_stages: 3 + num_warps: 4 +- META: {BLOCK_M: 16, BLOCK_N: 256, BLOCK_K: 32} + num_stages: 4 + num_warps: 4 +- META: {BLOCK_M: 16, BLOCK_N: 256, BLOCK_K: 32} + num_stages: 5 + num_warps: 4 +- META: {BLOCK_M: 16, BLOCK_N: 256, BLOCK_K: 32} + 
num_stages: 6 + num_warps: 4 +- META: {BLOCK_M: 16, BLOCK_N: 256, BLOCK_K: 64} + num_stages: 2 + num_warps: 4 +- META: {BLOCK_M: 16, BLOCK_N: 256, BLOCK_K: 64} + num_stages: 3 + num_warps: 4 +- META: {BLOCK_M: 16, BLOCK_N: 256, BLOCK_K: 64} + num_stages: 4 + num_warps: 4 +- META: {BLOCK_M: 16, BLOCK_N: 256, BLOCK_K: 64} + num_stages: 5 + num_warps: 4 +- META: {BLOCK_M: 16, BLOCK_N: 256, BLOCK_K: 64} + num_stages: 6 + num_warps: 4 +- META: {BLOCK_M: 32, BLOCK_N: 32, BLOCK_K: 32} + num_stages: 2 + num_warps: 2 +- META: {BLOCK_M: 32, BLOCK_N: 32, BLOCK_K: 32} + num_stages: 3 + num_warps: 2 +- META: {BLOCK_M: 32, BLOCK_N: 32, BLOCK_K: 32} + num_stages: 4 + num_warps: 2 +- META: {BLOCK_M: 32, BLOCK_N: 32, BLOCK_K: 32} + num_stages: 5 + num_warps: 2 +- META: {BLOCK_M: 32, BLOCK_N: 32, BLOCK_K: 32} + num_stages: 6 + num_warps: 2 +- META: {BLOCK_M: 32, BLOCK_N: 32, BLOCK_K: 64} + num_stages: 2 + num_warps: 2 +- META: {BLOCK_M: 32, BLOCK_N: 32, BLOCK_K: 64} + num_stages: 3 + num_warps: 2 +- META: {BLOCK_M: 32, BLOCK_N: 32, BLOCK_K: 64} + num_stages: 4 + num_warps: 2 +- META: {BLOCK_M: 32, BLOCK_N: 32, BLOCK_K: 64} + num_stages: 5 + num_warps: 2 +- META: {BLOCK_M: 32, BLOCK_N: 32, BLOCK_K: 64} + num_stages: 6 + num_warps: 2 +- META: {BLOCK_M: 32, BLOCK_N: 64, BLOCK_K: 32} + num_stages: 2 + num_warps: 2 +- META: {BLOCK_M: 32, BLOCK_N: 64, BLOCK_K: 32} + num_stages: 3 + num_warps: 2 +- META: {BLOCK_M: 32, BLOCK_N: 64, BLOCK_K: 32} + num_stages: 4 + num_warps: 2 +- META: {BLOCK_M: 32, BLOCK_N: 64, BLOCK_K: 32} + num_stages: 5 + num_warps: 2 +- META: {BLOCK_M: 32, BLOCK_N: 64, BLOCK_K: 32} + num_stages: 6 + num_warps: 2 +- META: {BLOCK_M: 32, BLOCK_N: 64, BLOCK_K: 64} + num_stages: 2 + num_warps: 2 +- META: {BLOCK_M: 32, BLOCK_N: 64, BLOCK_K: 64} + num_stages: 3 + num_warps: 2 +- META: {BLOCK_M: 32, BLOCK_N: 64, BLOCK_K: 64} + num_stages: 4 + num_warps: 2 +- META: {BLOCK_M: 32, BLOCK_N: 64, BLOCK_K: 64} + num_stages: 5 + num_warps: 2 +- META: {BLOCK_M: 32, BLOCK_N: 64, BLOCK_K: 64} + num_stages: 6 + num_warps: 2 +- META: {BLOCK_M: 32, BLOCK_N: 128, BLOCK_K: 32} + num_stages: 2 + num_warps: 4 +- META: {BLOCK_M: 32, BLOCK_N: 128, BLOCK_K: 32} + num_stages: 3 + num_warps: 4 +- META: {BLOCK_M: 32, BLOCK_N: 128, BLOCK_K: 32} + num_stages: 4 + num_warps: 4 +- META: {BLOCK_M: 32, BLOCK_N: 128, BLOCK_K: 32} + num_stages: 5 + num_warps: 4 +- META: {BLOCK_M: 32, BLOCK_N: 128, BLOCK_K: 32} + num_stages: 6 + num_warps: 4 +- META: {BLOCK_M: 32, BLOCK_N: 128, BLOCK_K: 64} + num_stages: 2 + num_warps: 4 +- META: {BLOCK_M: 32, BLOCK_N: 128, BLOCK_K: 64} + num_stages: 3 + num_warps: 4 +- META: {BLOCK_M: 32, BLOCK_N: 128, BLOCK_K: 64} + num_stages: 4 + num_warps: 4 +- META: {BLOCK_M: 32, BLOCK_N: 128, BLOCK_K: 64} + num_stages: 5 + num_warps: 4 +- META: {BLOCK_M: 32, BLOCK_N: 128, BLOCK_K: 64} + num_stages: 6 + num_warps: 4 +- META: {BLOCK_M: 32, BLOCK_N: 256, BLOCK_K: 32} + num_stages: 2 + num_warps: 4 +- META: {BLOCK_M: 32, BLOCK_N: 256, BLOCK_K: 32} + num_stages: 3 + num_warps: 4 +- META: {BLOCK_M: 32, BLOCK_N: 256, BLOCK_K: 32} + num_stages: 4 + num_warps: 4 +- META: {BLOCK_M: 32, BLOCK_N: 256, BLOCK_K: 32} + num_stages: 5 + num_warps: 4 +- META: {BLOCK_M: 32, BLOCK_N: 256, BLOCK_K: 32} + num_stages: 6 + num_warps: 4 +- META: {BLOCK_M: 32, BLOCK_N: 256, BLOCK_K: 64} + num_stages: 2 + num_warps: 4 +- META: {BLOCK_M: 32, BLOCK_N: 256, BLOCK_K: 64} + num_stages: 3 + num_warps: 4 +- META: {BLOCK_M: 32, BLOCK_N: 256, BLOCK_K: 64} + num_stages: 4 + num_warps: 4 +- META: {BLOCK_M: 32, BLOCK_N: 256, BLOCK_K: 64} + 
num_stages: 5 + num_warps: 4 +- META: {BLOCK_M: 32, BLOCK_N: 256, BLOCK_K: 64} + num_stages: 6 + num_warps: 4 mm: +- META: # for largek + BLOCK_M: 128 + BLOCK_N: 128 + BLOCK_K: 128 + num_stages: 3 + num_warps: 8 - META: BLOCK_M: 128 BLOCK_N: 256 BLOCK_K: 32 - SPLIT_K: 1 num_stages: 3 num_warps: 8 - META: BLOCK_M: 256 BLOCK_N: 128 BLOCK_K: 32 - SPLIT_K: 1 num_stages: 3 num_warps: 8 - META: BLOCK_M: 256 BLOCK_N: 64 BLOCK_K: 32 - SPLIT_K: 1 num_stages: 4 num_warps: 4 - META: BLOCK_M: 64 BLOCK_N: 256 BLOCK_K: 32 - SPLIT_K: 1 num_stages: 4 num_warps: 4 - META: BLOCK_M: 128 BLOCK_N: 128 BLOCK_K: 32 - SPLIT_K: 1 num_stages: 4 num_warps: 4 - META: BLOCK_M: 64 BLOCK_N: 128 BLOCK_K: 32 - SPLIT_K: 1 num_stages: 4 num_warps: 4 - META: BLOCK_M: 128 BLOCK_N: 32 BLOCK_K: 32 - SPLIT_K: 1 num_stages: 4 num_warps: 4 - META: BLOCK_M: 64 BLOCK_N: 32 BLOCK_K: 32 - SPLIT_K: 1 num_stages: 5 num_warps: 2 - META: BLOCK_M: 128 BLOCK_N: 256 BLOCK_K: 32 - SPLIT_K: 1 num_stages: 3 num_warps: 8 - META: BLOCK_M: 256 BLOCK_N: 128 BLOCK_K: 128 - SPLIT_K: 1 num_stages: 3 num_warps: 8 - META: BLOCK_M: 256 BLOCK_N: 64 BLOCK_K: 128 - SPLIT_K: 1 num_stages: 4 num_warps: 4 - META: BLOCK_M: 64 BLOCK_N: 256 BLOCK_K: 128 - SPLIT_K: 1 num_stages: 4 num_warps: 4 - META: BLOCK_M: 128 BLOCK_N: 128 BLOCK_K: 128 - SPLIT_K: 1 num_stages: 4 num_warps: 4 - META: BLOCK_M: 128 BLOCK_N: 64 BLOCK_K: 64 - SPLIT_K: 1 num_stages: 4 num_warps: 4 - META: BLOCK_M: 64 BLOCK_N: 128 BLOCK_K: 64 - SPLIT_K: 1 num_stages: 4 num_warps: 4 - META: BLOCK_M: 128 BLOCK_N: 32 BLOCK_K: 64 - SPLIT_K: 1 num_stages: 4 num_warps: 4 - META: BLOCK_M: 64 BLOCK_N: 32 BLOCK_K: 64 - SPLIT_K: 1 num_stages: 5 num_warps: 2 softmax_non_inner: diff --git a/tests/test_blas_ops.py b/tests/test_blas_ops.py index a7773f011..3eea0d584 100644 --- a/tests/test_blas_ops.py +++ b/tests/test_blas_ops.py @@ -57,9 +57,13 @@ def test_accuracy_bmm(M, N, K, dtype): @pytest.mark.mm @pytest.mark.parametrize("M, N, K", MNK_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) -def test_accuracy_mm(M, N, K, dtype): +@pytest.mark.parametrize("b_column_major", [True, False]) +def test_accuracy_mm(M, N, K, dtype, b_column_major): mat1 = torch.randn((M, K), dtype=dtype, device=flag_gems.device) - mat2 = torch.randn((K, N), dtype=dtype, device=flag_gems.device) + if b_column_major: + mat2 = torch.randn((N, K), dtype=dtype, device=flag_gems.device).t() + else: + mat2 = torch.randn((K, N), dtype=dtype, device=flag_gems.device) ref_mat1 = to_reference(mat1, True) ref_mat2 = to_reference(mat2, True)
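As a final note on the test change above: materializing B as an `(N, K)` tensor and transposing it keeps the logical `(K, N)` shape but swaps the strides, which is exactly the case the new `stride_am`/`stride_bk` entries in the autotune keys are meant to distinguish. A small illustrative sketch of the shapes and strides involved (not part of the patch):

```python
import torch

K, N = 64, 32
b_row_major = torch.randn(K, N)       # shape (K, N), strides (N, 1)
b_col_major = torch.randn(N, K).t()   # shape (K, N), strides (1, K) -- a transposed view

assert b_row_major.shape == b_col_major.shape == (K, N)
assert b_row_major.stride() == (N, 1)
assert b_col_major.stride() == (1, K)
# Same (M, N, K) problem, different strides: with stride_am/stride_bk in the autotune
# key, the two layouts are tuned and cached separately instead of sharing one config.
```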