diff --git a/src/flag_gems/runtime/backend/_amd/__init__.py b/src/flag_gems/runtime/backend/_amd/__init__.py new file mode 100644 index 000000000..4a3494291 --- /dev/null +++ b/src/flag_gems/runtime/backend/_amd/__init__.py @@ -0,0 +1,10 @@ +from backend_utils import VendorInfoBase # noqa: E402 + +vendor_info = VendorInfoBase( + vendor_name="amd", device_name="cuda", device_query_cmd="amd-smi" +) + +CUSTOMIZED_UNUSED_OPS = () + + +__all__ = ["*"] diff --git a/src/flag_gems/runtime/backend/_amd/fused/__init__.py b/src/flag_gems/runtime/backend/_amd/fused/__init__.py new file mode 100644 index 000000000..cda76669f --- /dev/null +++ b/src/flag_gems/runtime/backend/_amd/fused/__init__.py @@ -0,0 +1,5 @@ +from .concat_and_cache_mla import concat_and_cache_mla + +__all__ = [ + "concat_and_cache_mla", +] diff --git a/src/flag_gems/runtime/backend/_amd/fused/concat_and_cache_mla.py b/src/flag_gems/runtime/backend/_amd/fused/concat_and_cache_mla.py new file mode 100644 index 000000000..2551d97c2 --- /dev/null +++ b/src/flag_gems/runtime/backend/_amd/fused/concat_and_cache_mla.py @@ -0,0 +1,175 @@ +import logging + +import torch +import triton +import triton.language as tl + +from flag_gems.runtime import torch_device_fn +from flag_gems.utils import libentry + +logger = logging.getLogger(__name__) + +# enum Fp8KVCacheDataType +FP8_KV_CACHE_DATA_TYPE_AUTO = tl.constexpr(0) +FP8_KV_CACHE_DATA_TYPE_FP8E4M3 = tl.constexpr(1) +FP8_KV_CACHE_DATA_TYPE_FP8E5M2 = tl.constexpr(2) + + +@libentry() +@triton.jit +def concat_and_cache_mla_kernel( + # pointers + kv_c_ptr, # in, [num_tokens, kv_lora_rank] + k_pe_ptr, # in, [num_tokens, pe_dim] + kv_cache_ptr, # out, [num_blocks, block_size, kv_lora_rank + pe_dim] + slot_mapping_ptr, # in, [num_tokens] + # strides + block_stride, + entry_stride, + kv_c_stride, + k_pe_stride, + # dims + kv_lora_rank, + pe_dim, + block_size, # kv cache block size + scale_ptr, + # data type + kv_dtype: tl.constexpr, # one of Fp8KVCacheDataType + BLOCK_SIZE: tl.constexpr, +): + token_idx = tl.program_id(0) + slot_idx = tl.load(slot_mapping_ptr + token_idx) + + # Skip padded tokens + if slot_idx < 0: + return + + # Calculate cache position + block_id = slot_idx // block_size + block_offset = slot_idx % block_size + cache_base = block_id * block_stride + block_offset * entry_stride + + # Preload scale if needed + if kv_dtype != FP8_KV_CACHE_DATA_TYPE_AUTO: + scale_val = tl.load(scale_ptr) + + # Process kv_c section + for i in range(0, kv_lora_rank, BLOCK_SIZE): + idx = i + tl.arange(0, BLOCK_SIZE) + mask = idx < kv_lora_rank + + src_ptr = kv_c_ptr + token_idx * kv_c_stride + idx + dst_ptr = kv_cache_ptr + cache_base + idx + + val = tl.load(src_ptr, mask=mask, other=0) + + if kv_dtype != FP8_KV_CACHE_DATA_TYPE_AUTO: + if kv_dtype == FP8_KV_CACHE_DATA_TYPE_FP8E4M3: + val = (val / scale_val).to(tl.float8e4b8) + elif kv_dtype == FP8_KV_CACHE_DATA_TYPE_FP8E5M2: + val = (val / scale_val).to(tl.float8e5b16) + val = val.to(tl.uint8, bitcast=True) + tl.store(dst_ptr, val, mask=mask) + + # Process k_pe section + for j in range(0, pe_dim, BLOCK_SIZE): + idx = j + tl.arange(0, BLOCK_SIZE) + mask = idx < pe_dim + + src_ptr = k_pe_ptr + token_idx * k_pe_stride + idx + dst_ptr = kv_cache_ptr + cache_base + kv_lora_rank + idx + + val = tl.load(src_ptr, mask=mask, other=0) + + if kv_dtype != FP8_KV_CACHE_DATA_TYPE_AUTO: + if kv_dtype == FP8_KV_CACHE_DATA_TYPE_FP8E4M3: + val = (val / scale_val).to(tl.float8e4b8) + elif kv_dtype == FP8_KV_CACHE_DATA_TYPE_FP8E5M2: + val = (val / scale_val).to(tl.float8e5b16) 
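+            # The fp8 value is reinterpreted (bitcast) as uint8 so it can be
+            # stored into the uint8-typed kv_cache buffer.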
+ val = val.to(tl.uint8, bitcast=True) + tl.store(dst_ptr, val, mask=mask) + + +class ConcatAndCacheMla(torch.autograd.Function): + @staticmethod + def forward( + ctx, + kv_c: torch.Tensor, + k_pe: torch.Tensor, + kv_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + scale: torch.Tensor, + ): + if kv_cache_dtype != "auto" and kv_cache.dtype != torch.uint8: + raise ValueError("For FP8 kv_cache must be uint8 dtype") + if kv_cache_dtype == "auto" and kv_cache.dtype != kv_c.dtype: + raise ValueError("For auto mode kv_cache must match input dtype") + + # Map string dtype to internal constants + kv_dtype_map = { + "auto": FP8_KV_CACHE_DATA_TYPE_AUTO, + "fp8": FP8_KV_CACHE_DATA_TYPE_FP8E4M3, + "fp8e4m3": FP8_KV_CACHE_DATA_TYPE_FP8E4M3, + "fp8e5m2": FP8_KV_CACHE_DATA_TYPE_FP8E5M2, + } + kv_dtype = kv_dtype_map.get(kv_cache_dtype) + if kv_dtype is None: + raise ValueError(f"Unsupported kv_cache_dtype: {kv_cache_dtype}") + kv_dtype = int(kv_dtype) # tl.constexpr->int + + kv_lora_rank = kv_c.size(1) + pe_dim = k_pe.size(1) + num_tokens = slot_mapping.size(0) + + # make sure `scale` is a scalar tensor + if scale.numel() != 1: + scale = scale.view(1) + + # make sure all tensors are on the same device + device = kv_c.device + k_pe = k_pe.to(device) + kv_cache = kv_cache.to(device) + slot_mapping = slot_mapping.to(device) + scale = scale.to(device) + + # configure kernel launch + grid = (num_tokens,) + BLOCK_SIZE = min(kv_lora_rank, 512) + + assert kv_cache.dim() == 3, "kv_cache must be a 3D tensor" + assert ( + kv_cache.size(2) == kv_lora_rank + pe_dim + ), "kv_cache's last dimension must match kv_lora_rank + pe_dim" + with torch_device_fn.device(device): + concat_and_cache_mla_kernel[grid]( + kv_c, + k_pe, + kv_cache, + slot_mapping, + kv_cache.stride(0), # block_stride + kv_cache.stride(1), # entry_stride + kv_c.stride(0), # kv_c_stride + k_pe.stride(0), # k_pe_stride + kv_lora_rank, + pe_dim, + kv_cache.size(1), # kv cache block_size + scale, + kv_dtype=kv_dtype, + BLOCK_SIZE=BLOCK_SIZE, + ) + return kv_cache + + +def concat_and_cache_mla( + kv_c: torch.Tensor, + k_pe: torch.Tensor, + kv_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + scale: torch.Tensor, +) -> None: + logger.debug("GEMS CONCAT_AND_CACHE_MLA") + return ConcatAndCacheMla.apply( + kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale + ) diff --git a/src/flag_gems/runtime/backend/_amd/heuristics_config_utils.py b/src/flag_gems/runtime/backend/_amd/heuristics_config_utils.py new file mode 100644 index 000000000..326c5d3d2 --- /dev/null +++ b/src/flag_gems/runtime/backend/_amd/heuristics_config_utils.py @@ -0,0 +1,330 @@ +import torch +import triton + + +def simple_elementwise_blocksize_heur(args): + return 1024 + + +def argmax_heur_block_m(args): + return 4 if args["M"] < 4096 else 8 + + +def argmax_heur_block_n(args): + return min(4096, triton.next_power_of_2(args["N"])) + + +def argmin_heur_block_m(args): + return 4 if args["M"] < 4096 else 8 + + +def argmin_heur_block_n(args): + return min(4096, triton.next_power_of_2(args["N"])) + + +def bmm_heur_divisible_m(args): + return args["M"] % args["TILE_M"] == 0 + + +def bmm_heur_divisible_n(args): + return args["N"] % args["TILE_N"] == 0 + + +def bmm_heur_divisible_k(args): + return args["K"] % args["TILE_K"] == 0 + + +def dropout_heur_block(args): + if args["N"] <= 512: + return 512 + else: + return 1024 + + +def dropout_heur_num_warps(args): + if args["N"] <= 512: + return 4 + elif args["N"] <= 1024: + return 8 + else: + return 16 
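+
+# Note: each heuristic takes the kernel's argument dict and returns a single
+# meta-parameter value, e.g. dropout_heur_block({"N": 300}) -> 512 while
+# dropout_heur_block({"N": 4096}) -> 1024. The kernel-name -> heuristics
+# mapping is collected in HEURISTICS_CONFIGS at the bottom of this file and
+# is picked up by the flag_gems runtime at launch time.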
+ + +def exponential_heur_block(args): + if args["N"] <= 512: + return 512 + else: + return 1024 + + +def exponential_heur_num_warps(args): + if args["N"] <= 512: + return 4 + elif args["N"] <= 1024: + return 8 + else: + return 16 + + +def gather_heur_block_m(args): + return min(4, triton.next_power_of_2(triton.cdiv(args["N"], 2048))) + + +def gather_heur_block_n(args): + return min(2048, triton.next_power_of_2(args["N"])) + + +def index_select_heur_block_m(args): + return min(4, triton.next_power_of_2(triton.cdiv(256, args["N"]))) + + +def index_select_heur_block_n(args): + m = min(triton.next_power_of_2(triton.cdiv(args["N"], 16)), 512) + return max(m, 16) + + +def mm_heur_even_k(args): + return args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0 + + +def rand_heur_block(args): + if args["N"] <= 512: + return 512 + else: + return 1024 + + +def rand_heur_num_warps(args): + if args["N"] <= 512: + return 4 + elif args["N"] <= 1024: + return 8 + else: + return 16 + + +def randn_heur_block(args): + if args["N"] <= 512: + return 512 + else: + return 1024 + + +def randn_heur_num_warps(args): + if args["N"] <= 512: + return 4 + elif args["N"] <= 1024: + return 8 + else: + return 16 + + +def softmax_heur_tile_k(args): + MAX_TILE_K = 8192 + NUM_SMS = torch.cuda.get_device_properties( + torch.cuda.current_device() + ).multi_processor_count + tile_k = 1 + upper_bound = min(args["K"], MAX_TILE_K) + while tile_k <= upper_bound: + num_blocks = args["M"] * triton.cdiv(args["K"], tile_k) + num_waves = num_blocks / NUM_SMS + if (num_waves > 1) and (tile_k * 2 <= upper_bound): + tile_k *= 2 + else: + break + return tile_k + + +def softmax_heur_tile_n_non_inner(args): + return triton.cdiv(8192, args["TILE_K"]) + + +def softmax_heur_one_tile_per_cta(args): + return args["TILE_N"] >= args["N"] + + +def softmax_heur_num_warps_non_inner(args): + tile_size = args["TILE_N"] * args["TILE_K"] + if tile_size < 2048: + return 4 + elif tile_size < 4096: + return 8 + else: + return 16 + + +def softmax_heur_tile_n_inner(args): + if args["N"] <= (32 * 1024): + return triton.next_power_of_2(args["N"]) + else: + return 4096 + + +def softmax_heur_num_warps_inner(args): + tile_size = args["TILE_N"] + if tile_size < 2048: + return 4 + elif tile_size < 4096: + return 8 + else: + return 16 + + +def softmax_heur_tile_n_bwd_non_inner(args): + return max(1, 1024 // args["TILE_K"]) + + +def softmax_heru_tile_m(args): + return max(1, 1024 // args["TILE_N"]) + + +def uniform_heur_block(args): + if args["N"] <= 512: + return 512 + else: + return 1024 + + +def uniform_heur_num_warps(args): + if args["N"] <= 512: + return 4 + elif args["N"] <= 1024: + return 8 + else: + return 16 + + +def var_mean_heur_block_n(args): + return triton.next_power_of_2(args["BLOCK_NUM"]) + + +def upsample_nearest2d_SAME_H(args): + return args["OH"] == args["IH"] + + +def upsample_nearest2d_SAME_W(args): + return args["OW"] == args["IW"] + + +def upsample_nearest2d_USE_INT32_IDX(args): + return args["N"] * args["C"] * args["OH"] * args["OW"] <= (2**31 - 1) # INT32 MAX + + +def batch_norm_heur_block_m(args): + return min(2048, triton.next_power_of_2(args["batch_dim"])) + + +def batch_norm_heur_block_n(args): + # A maximum of 16384 elements are loaded at once. 
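+    # i.e. BLOCK_M * BLOCK_N is capped at 2**14 (= 16384) elements per program.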
+ BLOCK_M = batch_norm_heur_block_m(args) + BLOCK_N = triton.next_power_of_2(args["spatial_dim"]) + return min(BLOCK_N, max(1, 2**14 // BLOCK_M)) + + +def vdot_heur_block_size(args): + n = args["n_elements"] + if n < 1024: + return 32 + elif n < 8192: + return 256 + else: + return 1024 + + +HEURISTICS_CONFIGS = { + "argmax": { + "BLOCK_M": argmax_heur_block_m, + "BLOCK_N": argmax_heur_block_n, + }, + "argmin": { + "BLOCK_M": argmin_heur_block_m, + "BLOCK_N": argmin_heur_block_n, + }, + "bmm": { + "DIVISIBLE_M": bmm_heur_divisible_m, + "DIVISIBLE_N": bmm_heur_divisible_n, + "DIVISIBLE_K": bmm_heur_divisible_k, + }, + "dropout": { + "BLOCK": dropout_heur_block, + "num_warps": dropout_heur_num_warps, + }, + "exponential_": { + "BLOCK": exponential_heur_block, + "num_warps": exponential_heur_num_warps, + }, + "gather": { + "BLOCK_M": gather_heur_block_m, + "BLOCK_N": gather_heur_block_n, + }, + "index_select": { + "BLOCK_M": index_select_heur_block_m, + "BLOCK_N": index_select_heur_block_n, + }, + "mm": { + "EVEN_K": mm_heur_even_k, + }, + "rand": { + "BLOCK": rand_heur_block, + "num_warps": rand_heur_num_warps, + }, + "randn": { + "BLOCK": randn_heur_block, + "num_warps": randn_heur_num_warps, + }, + "softmax_non_inner": { + "TILE_K": softmax_heur_tile_k, + "TILE_N": softmax_heur_tile_n_non_inner, + "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta, + "num_warps": softmax_heur_num_warps_non_inner, + }, + "softmax_inner": { + "TILE_N": softmax_heur_tile_n_inner, + "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta, + "num_warps": softmax_heur_num_warps_inner, + }, + "softmax_backward_non_inner": { + "TILE_N": softmax_heur_tile_n_bwd_non_inner, + "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta, + }, + "softmax_backward_inner": { + "TILE_M": softmax_heru_tile_m, + "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta, + }, + "uniform": { + "BLOCK": uniform_heur_block, + "num_warps": uniform_heur_num_warps, + }, + "upsample_nearest2d": { + "SAME_H": upsample_nearest2d_SAME_H, + "SAME_W": upsample_nearest2d_SAME_W, + "USE_INT32_IDX": upsample_nearest2d_USE_INT32_IDX, + }, + "var_mean": { + "BLOCK_N": var_mean_heur_block_n, + }, + "batch_norm": { + "BLOCK_M": batch_norm_heur_block_m, + "BLOCK_N": batch_norm_heur_block_n, + }, + "vdot": { + "BLOCK_SIZE": vdot_heur_block_size, + }, + "mha_varlen_prefill": { + "BLOCK_M": lambda args: 128, + "BLOCK_N": lambda args: 32, + "num_warps": lambda args: 4, + "num_stages": lambda args: 3, + }, + "mha_varlen_decode": { + "BLOCK_M": lambda args: 16, + "BLOCK_N": lambda args: 64, + "num_warps": lambda args: 4, + "num_stages": lambda args: 3, + }, + "elementwise_generic": { + "BLOCK_SIZE": simple_elementwise_blocksize_heur, + "num_warps": lambda args: 8, + }, +} diff --git a/src/flag_gems/runtime/backend/_amd/ops/__init__.py b/src/flag_gems/runtime/backend/_amd/ops/__init__.py new file mode 100644 index 000000000..5a49919a9 --- /dev/null +++ b/src/flag_gems/runtime/backend/_amd/ops/__init__.py @@ -0,0 +1,4 @@ +from .add import add +from .gelu import gelu + +__all__ = ["add", "gelu"] diff --git a/src/flag_gems/runtime/backend/_amd/ops/add.py b/src/flag_gems/runtime/backend/_amd/ops/add.py new file mode 100644 index 000000000..0eaa93864 --- /dev/null +++ b/src/flag_gems/runtime/backend/_amd/ops/add.py @@ -0,0 +1,51 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def add_kernel( + x_ptr, # *Pointer* to first input vector. + y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. 
+    n_elements,  # Size of the vector.
+    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
+    # NOTE: `constexpr` so it can be used as a shape value.
+):
+    # There are multiple 'programs' processing different data. We identify which program
+    # we are here:
+    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
+    # This program will process inputs that are offset from the initial data.
+    # For instance, if you had a vector of length 256 and block_size of 64, the programs
+    # would each access the elements [0:64, 64:128, 128:192, 192:256].
+    # Note that offsets is a list of pointers:
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    # Create a mask to guard memory operations against out-of-bounds accesses.
+    mask = offsets < n_elements
+    # Load x and y from DRAM, masking out any extra elements in case the input is not a
+    # multiple of the block size.
+    x = tl.load(x_ptr + offsets, mask=mask)
+    y = tl.load(y_ptr + offsets, mask=mask)
+    output = x + y
+    # Write x + y back to DRAM.
+    tl.store(output_ptr + offsets, output, mask=mask)
+
+
+def add(x: torch.Tensor, y: torch.Tensor):
+    # We need to preallocate the output.
+    print("\n.......test for multibackend specific add........\n")
+    output = torch.empty_like(x)
+    n_elements = output.numel()
+    # The SPMD launch grid denotes the number of kernel instances that run in parallel.
+    # It is analogous to CUDA launch grids. It can be either Tuple[int] or Callable(metaparameters) -> Tuple[int].
+    # In this case, we use a 1D grid where the size is the number of blocks:
+    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
+    # NOTE:
+    #  - Each torch.tensor object is implicitly converted into a pointer to its first element.
+    #  - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel.
+    #  - Don't forget to pass meta-parameters as keyword arguments.
+    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
+    # We return a handle to `output` but, since `torch_device_fn.synchronize()` hasn't been called,
+    # the kernel may still be running asynchronously at this point.
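+    # (The hard-coded BLOCK_SIZE=1024 happens to match the value returned by
+    # simple_elementwise_blocksize_heur in this backend's heuristics config.)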
+ return output diff --git a/src/flag_gems/runtime/backend/_amd/ops/gelu.py b/src/flag_gems/runtime/backend/_amd/ops/gelu.py new file mode 100644 index 000000000..904581a93 --- /dev/null +++ b/src/flag_gems/runtime/backend/_amd/ops/gelu.py @@ -0,0 +1,87 @@ +import logging + +import torch +import triton +import triton.language as tl + +from flag_gems.utils import pointwise_dynamic, tl_extra_shim + +logger = logging.getLogger(__name__) +erf = tl_extra_shim.erf +exp = tl_extra_shim.exp +pow = tl_extra_shim.pow +tanh = tl_extra_shim.tanh + + +@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")]) +@triton.jit +def gelu_none(x): + scale: tl.constexpr = 0.7071067811 # 1 / math.sqrt(2) + output = 0.5 * x * (1 + erf(x * scale)) + return output + + +@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")]) +@triton.jit +def gelu_tanh(x): + output = ( + 0.5 * x * (1 + tanh(x * 0.79788456 * (1 + 0.044715 * pow(x.to(tl.float32), 2)))) + ) + return output + + +@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")]) +@triton.jit +def gelu_backward_none(x, dy): + scale1: tl.constexpr = 0.7071067811 # 1 / math.sqrt(2) + scale2: tl.constexpr = 0.3989422803 # 1 / math.sqrt(2 * math.pi) + x_fp32 = x.to(tl.float32) + dydx = ( + scale2 * x_fp32 * exp(-pow(scale1 * x_fp32, 2)) + + 0.5 * erf(scale1 * x_fp32) + + 0.5 + ) + dx = dydx * dy + return dx + + +@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")]) +@triton.jit +def gelu_backward_tanh(x, dy): + x_fp32 = x.to(tl.float32) + # 0.79788456 = math.sqrt(2 / math.pi) + tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * pow(x_fp32, 2))) + dydx = 0.5 * x * ( + (1 - pow(tanh_out, 2)) * (0.79788456 + 0.1070322243 * pow(x_fp32, 2)) + ) + 0.5 * (1 + tanh_out) + dx = dydx * dy + return dx + + +class Gelu(torch.autograd.Function): + @staticmethod + def forward(ctx, A, approximate): + logger.debug("GEMS GELU FORWARD") + if approximate == "tanh": + out = gelu_tanh(A) + else: + out = gelu_none(A) + ctx.save_for_backward(A) + ctx.approximate = approximate + return out + + @staticmethod + def backward(ctx, out_grad): + logger.debug("GEMS GELU BACKWARD") + (inp,) = ctx.saved_tensors + approximate = ctx.approximate + if approximate == "tanh": + in_grad = gelu_backward_tanh(inp, out_grad) + else: + in_grad = gelu_backward_none(inp, out_grad) + return in_grad, None + + +def gelu(A, *, approximate="none"): + print("\n.......test for mutibackend specific gelu........\n") + return Gelu.apply(A, approximate) diff --git a/src/flag_gems/runtime/backend/_amd/tune_configs.yaml b/src/flag_gems/runtime/backend/_amd/tune_configs.yaml new file mode 100644 index 000000000..1ce27e102 --- /dev/null +++ b/src/flag_gems/runtime/backend/_amd/tune_configs.yaml @@ -0,0 +1,906 @@ +attention: + - gen: true + param_map: + META: + BLOCK_M: block_m + BLOCK_N: block_n + PRE_LOAD_V: pre_load_v + num_warps: warps + num_stages: stages + block_m: + - 32 + - 64 + - 128 + block_n: + - 32 + - 64 + - 128 + pre_load_v: + - true + - false + warps: + - 4 + - 8 + stages: + - 2 + - 3 + - 4 +bmm: + - META: + TILE_M: 32 + TILE_N: 32 + TILE_K: 32 + GROUP_M: 1 + num_warps: 4 + num_stages: 2 + - META: + TILE_M: 64 + TILE_N: 32 + TILE_K: 32 + GROUP_M: 2 + num_warps: 4 + num_stages: 2 + - META: + TILE_M: 64 + TILE_N: 64 + TILE_K: 32 + GROUP_M: 2 + num_warps: 4 + num_stages: 2 + - META: + TILE_M: 128 + TILE_N: 32 + TILE_K: 32 + GROUP_M: 2 + num_warps: 4 + num_stages: 2 + - META: + TILE_M: 128 + TILE_N: 64 + TILE_K: 32 + GROUP_M: 2 + num_warps: 4 + num_stages: 2 + - META: + TILE_M: 128 + TILE_N: 128 + TILE_K: 32 + GROUP_M: 2 
+ num_warps: 4 + num_stages: 2 +log_softmax: + - META: + BLOCK_M: 8 + BLOCK_N: 256 + num_warps: 8 + - META: + BLOCK_M: 16 + BLOCK_N: 512 + num_warps: 8 + - META: + BLOCK_M: 32 + BLOCK_N: 512 + num_warps: 8 +mm: + - META: + BLOCK_M: 64 + BLOCK_N: 64 + BLOCK_K: 64 + num_stages: 2 + num_warps: 4 + - META: + BLOCK_M: 64 + BLOCK_N: 128 + BLOCK_K: 64 + num_stages: 2 + num_warps: 4 + - META: + BLOCK_M: 64 + BLOCK_N: 64 + BLOCK_K: 128 + num_stages: 2 + num_warps: 8 + - META: + BLOCK_M: 64 + BLOCK_N: 128 + BLOCK_K: 128 + num_stages: 2 + num_warps: 8 + - META: + BLOCK_M: 64 + BLOCK_N: 64 + BLOCK_K: 64 + num_stages: 2 + num_warps: 4 +softmax_non_inner: + - META: + TILE_K: 32 + - META: + TILE_K: 64 + - META: + TILE_K: 128 + - META: + TILE_K: 256 + - META: + TILE_K: 1024 +softmax_inner: + - META: + TILE_N: 32 + - META: + TILE_N: 64 + - META: + TILE_N: 128 + - META: + TILE_N: 256 + - META: + TILE_N: 1024 +addmm: + - META: + BLOCK_SIZE_M: 128 + BLOCK_SIZE_N: 256 + BLOCK_SIZE_K: 64 + num_stages: 2 + num_warps: 8 + - META: + BLOCK_SIZE_M: 64 + BLOCK_SIZE_N: 256 + BLOCK_SIZE_K: 32 + num_stages: 2 + num_warps: 4 + - META: + BLOCK_SIZE_M: 128 + BLOCK_SIZE_N: 128 + BLOCK_SIZE_K: 32 + num_stages: 2 + num_warps: 4 + - META: + BLOCK_SIZE_M: 128 + BLOCK_SIZE_N: 64 + BLOCK_SIZE_K: 32 + num_stages: 2 + num_warps: 4 + - META: + BLOCK_SIZE_M: 64 + BLOCK_SIZE_N: 128 + BLOCK_SIZE_K: 32 + num_stages: 2 + num_warps: 4 + - META: + BLOCK_SIZE_M: 128 + BLOCK_SIZE_N: 32 + BLOCK_SIZE_K: 32 + num_stages: 2 + num_warps: 4 + - META: + BLOCK_SIZE_M: 64 + BLOCK_SIZE_N: 32 + BLOCK_SIZE_K: 32 + num_stages: 2 + num_warps: 2 + - META: + BLOCK_SIZE_M: 32 + BLOCK_SIZE_N: 64 + BLOCK_SIZE_K: 32 + num_stages: 2 + num_warps: 2 +cross_entropy_loss: + - META: + BLOCK_C: 256 + BLOCK_D: 1 + num_warps: 4 + - META: + BLOCK_C: 256 + BLOCK_D: 4 + num_warps: 4 + - META: + BLOCK_C: 256 + BLOCK_D: 16 + num_warps: 4 + - META: + BLOCK_C: 512 + BLOCK_D: 1 + num_warps: 4 + - META: + BLOCK_C: 512 + BLOCK_D: 4 + num_warps: 4 + - META: + BLOCK_C: 512 + BLOCK_D: 16 + num_warps: 4 + - META: + BLOCK_C: 1024 + BLOCK_D: 1 + num_warps: 4 + - META: + BLOCK_C: 1024 + BLOCK_D: 4 + num_warps: 4 + - META: + BLOCK_C: 1024 + BLOCK_D: 16 + num_warps: 4 +masked_fill: + - gen: true + param_map: + META: + BLOCK_SIZE: block_n + num_warps: warps + block_n: + - 1024 + - 2048 + - 4096 + warps: + - 4 + - 8 + - 16 +naive_reduction: + - gen: true + param_map: + META: + BLOCK_M: block_m + BLOCK_N: block_n + num_warps: 4 + block_m: + - 1 + - 2 + - 4 + - 8 + block_n: + - 1024 +layer_norm_persistent: + - gen: true + param_map: + META: {} + num_warps: warps + warps: + - 4 + - 8 + - 16 +layer_norm_loop: + - gen: true + param_map: + META: + TILE_N: tile_n + num_warps: warps + warps: + - 4 + - 8 + - 16 + tile_n: + - 1024 + - 2048 + - 4096 + - 8192 +layer_norm_backward: + - gen: true + param_map: + META: + BLOCK_ROW_SIZE: block_r + BLOCK_COL_SIZE: 2048 + num_warps: warps + warps: + - 4 + - 8 + - 16 + block_r: + - 1 + - 2 + - 4 + - 8 +weight_bias_backward: + - gen: true + param_map: + META: + BLOCK_ROW_SIZE: 128 + BLOCK_COL_SIZE: block_c + num_warps: warps + warps: + - 4 + - 8 + - 16 + block_c: + - 1 + - 2 + - 4 + - 8 +masked_select: + - gen: true + param_map: + META: + BLOCK_SIZE: blocks + num_warps: warps + warps: + - 4 + - 8 + - 16 + - 32 + blocks: + - 256 + - 512 + - 1024 + - 2048 + - 4096 + +instancenorm: + - gen: true + param_map: + META: {} + num_warps: warps + warps: + - 4 + - 8 + - 16 + +instance_norm_loop: + - gen: true + param_map: + META: + TILE_N: tile_n + num_warps: warps + 
warps: + - 4 + - 8 + - 16 + tile_n: + - 1024 + - 2048 + - 4096 + - 8192 +instance_norm_backward: + - gen: true + param_map: + META: + BLOCK_ROW_SIZE: block_m + BLOCK_COL_SIZE: 2048 + num_warps: warps + warps: + - 4 + - 8 + - 16 + block_m: + - 1 + - 2 + - 4 + - 8 + +instance_norm_weight_bias_backward: + - gen: true + param_map: + META: + BLOCK_BATCH_SIZE: block_m + BLOCK_COL_SIZE: 2048 + num_warps: warps + warps: + - 4 + - 8 + - 16 + block_m: + - 1 + - 2 + - 4 + - 8 + +count_nonzero: + - gen: true + param_map: + META: + BLOCK_SIZE: block_m + num_warps: 4 + block_m: + - 1024 + - 2048 + - 4096 +cross_entropy_loss_sum_and_scale: + - gen: true + param_map: + META: + BLOCK_N: block_n + block_n: + - 64 + - 256 + - 1024 + +upsample_nearest2d: + - gen: true + param_map: + META: + BLOCK_SIZE: block_n + num_warps: warps + block_n: [1024, 2048] + warps: [4, 8] + +upsample_bicubic2d_aa: + - gen: true + param_map: + META: + BLOCK_X: block_x + BLOCK_Y: block_y + num_warps: warps + block_x: [512, 256, 128, 64] + block_y: [2, 1] + warps: [4, 8] + +mv: + - gen: true + param_map: + META: + BLOCK_M: block_m + BLOCK_N: block_n + num_warps: warps + num_stages: stages + warps: + - 4 + - 8 + block_m: + - 32 + - 64 + - 128 + block_n: + - 1 + - 2 + - 4 + - 8 + stages: + - 3 + - 4 +nonzero: + - gen: true + param_map: + META: + BLOCK_SIZE: blocks + num_warps: warps + num_stages: 2 + warps: + - 4 + - 8 + - 16 + - 32 + blocks: + - 256 + - 512 + - 1024 + - 2048 + - 4096 + - 8192 +randperm: + - gen: true + param_map: + META: {} + num_warps: warps + warps: + - 4 + - 8 + - 16 +vstack: + - gen: true + param_map: + META: + BLOCK_SIZE: blocks + num_warps: warps + warps: + - 4 + - 8 + - 16 + - 32 + blocks: + - 512 + - 1024 + - 2048 + - 4096 +triu: + - gen: true + param_map: + META: + M_BLOCK_SIZE: 1 + N_BLOCK_SIZE: 2048 + num_warps: warps + warps: + - 1 + - 2 + - 4 + - 8 + - 16 + - 32 +triu_batch: + - gen: true + param_map: + META: + BATCH_BLOCK_SIZE: 1 + MN_BLOCK_SIZE: 512 + num_warps: warps + warps: + - 1 + - 2 + - 4 + - 8 + - 16 + - 32 +weight_norm_kernel_last: + - gen: true + param_map: + META: + BLOCK_ROW_SIZE: block_m + BLOCK_COL_SIZE: block_n + num_warps: warps + warps: + - 4 + - 8 + - 16 + block_m: + - 512 + - 1024 + - 2048 + block_n: + - 1 + - 2 + - 4 + - 8 + - 32 +weight_norm_kernel_first: + - gen: true + param_map: + META: + BLOCK_ROW_SIZE: block_m + BLOCK_COL_SIZE: block_n + num_warps: warps + warps: + - 4 + - 8 + - 16 + block_m: + - 1 + - 2 + - 4 + - 8 + - 32 + block_n: + - 512 + - 1024 + - 2048 +weight_norm_kernel: + - gen: true + param_map: + META: + BLOCK_ROW_SIZE: block_m + BLOCK_COL_SIZE: block_n + num_warps: warps + warps: + - 4 + - 8 + - 16 + block_m: + - 1 + - 2 + - 4 + - 8 + - 32 + block_n: + - 256 + - 512 + - 1024 + - 2048 +vector_norm: + - gen: true + param_map: + META: + BLOCK_M: block_m + BLOCK_N: 1024 + num_warps: 4 + block_m: + - 1 + - 2 + - 4 + - 8 +var_mean: + - gen: true + param_map: + META: + BLOCK_M: block_m + BLOCK_N: block_n + num_warps: warps + block_m: + - 1 + - 2 + - 4 + - 8 + block_n: + - 1024 + - 2048 + warps: + - 4 + - 8 + - 16 + +conv2d_forward: + - META: + BLOCK_NI_HO_WO: 32 + BLOCK_CO: 32 + BLOCK_CI: 32 + num_warps: 4 + num_stages: 2 + - META: + BLOCK_NI_HO_WO: 64 + BLOCK_CO: 32 + BLOCK_CI: 32 + num_warps: 4 + num_stages: 2 + - META: + BLOCK_NI_HO_WO: 64 + BLOCK_CO: 64 + BLOCK_CI: 32 + num_warps: 4 + num_stages: 2 + - META: + BLOCK_NI_HO_WO: 128 + BLOCK_CO: 32 + BLOCK_CI: 32 + num_warps: 4 + num_stages: 2 + - META: + BLOCK_NI_HO_WO: 128 + BLOCK_CO: 64 + BLOCK_CI: 32 + num_warps: 4 + 
num_stages: 2 + - META: + BLOCK_NI_HO_WO: 128 + BLOCK_CO: 128 + BLOCK_CI: 32 + num_warps: 4 + num_stages: 2 + - META: + BLOCK_NI_HO_WO: 32 + BLOCK_CO: 32 + BLOCK_CI: 32 + num_warps: 4 + num_stages: 2 + - META: + BLOCK_NI_HO_WO: 64 + BLOCK_CO: 32 + BLOCK_CI: 32 + num_warps: 4 + num_stages: 2 + - META: + BLOCK_NI_HO_WO: 64 + BLOCK_CO: 64 + BLOCK_CI: 32 + num_warps: 4 + num_stages: 2 + - META: + BLOCK_NI_HO_WO: 128 + BLOCK_CO: 32 + BLOCK_CI: 32 + num_warps: 4 + num_stages: 2 + - META: + BLOCK_NI_HO_WO: 128 + BLOCK_CO: 64 + BLOCK_CI: 32 + num_warps: 4 + num_stages: 2 + - META: + BLOCK_NI_HO_WO: 128 + BLOCK_CO: 128 + BLOCK_CI: 32 + num_warps: 4 + num_stages: 2 + - META: + BLOCK_NI_HO_WO: 128 + BLOCK_CO: 256 + BLOCK_CI: 32 + num_stages: 2 + num_warps: 8 + - META: + BLOCK_NI_HO_WO: 256 + BLOCK_CO: 128 + BLOCK_CI: 32 + num_stages: 2 + num_warps: 8 + - META: + BLOCK_NI_HO_WO: 256 + BLOCK_CO: 64 + BLOCK_CI: 32 + num_stages: 2 + num_warps: 4 + - META: + BLOCK_NI_HO_WO: 64 + BLOCK_CO: 256 + BLOCK_CI: 32 + num_stages: 2 + num_warps: 4 + - META: + BLOCK_NI_HO_WO: 128 + BLOCK_CO: 128 + BLOCK_CI: 32 + num_stages: 2 + num_warps: 4 + - META: + BLOCK_NI_HO_WO: 64 + BLOCK_CO: 128 + BLOCK_CI: 32 + num_stages: 2 + num_warps: 4 + - META: + BLOCK_NI_HO_WO: 128 + BLOCK_CO: 32 + BLOCK_CI: 32 + num_stages: 2 + num_warps: 4 + - META: + BLOCK_NI_HO_WO: 64 + BLOCK_CO: 32 + BLOCK_CI: 32 + num_stages: 2 + num_warps: 2 + +conv2d_backward_weight: + - META: + BLOCK_CI_HK_WK: 32 + BLOCK_CO: 32 + BLOCK_NO: 32 + num_warps: 4 + num_stages: 2 + - META: + BLOCK_CI_HK_WK: 64 + BLOCK_CO: 32 + BLOCK_NO: 32 + num_warps: 4 + num_stages: 2 + - META: + BLOCK_CI_HK_WK: 64 + BLOCK_CO: 64 + BLOCK_NO: 32 + num_warps: 4 + num_stages: 2 + - META: + BLOCK_CI_HK_WK: 128 + BLOCK_CO: 32 + BLOCK_NO: 32 + num_warps: 4 + num_stages: 2 + - META: + BLOCK_CI_HK_WK: 128 + BLOCK_CO: 64 + BLOCK_NO: 32 + num_warps: 4 + num_stages: 2 + - META: + BLOCK_CI_HK_WK: 128 + BLOCK_CO: 128 + BLOCK_NO: 32 + num_warps: 4 + num_stages: 2 + - META: + BLOCK_CI_HK_WK: 128 + BLOCK_CO: 256 + BLOCK_NO: 32 + num_stages: 2 + num_warps: 4 +batch_norm: + - gen: true + param_map: + META: {} + num_warps: warps + warps: + - 4 + - 8 + - 16 + +conv3d_forward: + - META: + BLOCK_NI_DO_HO_WO: 512 + BLOCK_CO: 16 + BLOCK_CI: 16 + num_warps: 4 + num_stages: 2 + - META: + BLOCK_NI_DO_HO_WO: 256 + BLOCK_CO: 32 + BLOCK_CI: 16 + num_warps: 2 + num_stages: 1 + - META: + BLOCK_NI_DO_HO_WO: 128 + BLOCK_CO: 32 + BLOCK_CI: 16 + num_warps: 4 + num_stages: 2 + - META: + BLOCK_NI_DO_HO_WO: 128 + BLOCK_CO: 16 + BLOCK_CI: 16 + num_warps: 2 + num_stages: 2 + - META: + BLOCK_NI_DO_HO_WO: 128 + BLOCK_CO: 16 + BLOCK_CI: 16 + num_warps: 4 + num_stages: 2 + - META: + BLOCK_NI_DO_HO_WO: 128 + BLOCK_CO: 16 + BLOCK_CI: 16 + num_warps: 4 + num_stages: 2 + - META: + BLOCK_NI_DO_HO_WO: 256 + BLOCK_CO: 16 + BLOCK_CI: 16 + num_warps: 4 + num_stages: 2 + - META: + BLOCK_NI_DO_HO_WO: 256 + BLOCK_CO: 16 + BLOCK_CI: 16 + num_warps: 4 + num_stages: 1 + - META: + BLOCK_NI_DO_HO_WO: 256 + BLOCK_CO: 32 + BLOCK_CI: 16 + num_warps: 4 + num_stages: 1 + - META: + BLOCK_NI_DO_HO_WO: 128 + BLOCK_CO: 32 + BLOCK_CI: 32 + num_warps: 4 + num_stages: 2 + - META: + BLOCK_NI_DO_HO_WO: 256 + BLOCK_CO: 16 + BLOCK_CI: 16 + num_warps: 2 + num_stages: 1 + +kron: + - gen: true + param_map: + META: + BLOCK_M: block_m + BLOCK_N: block_n + num_warps: warps + block_m: + - 1 + - 2 + - 4 + - 8 + block_n: + - 1024 + - 2048 + warps: + - 4 + - 8 + - 16 + +index_put: + - gen: true + param_map: + META: + BLOCK_SIZE0: block_size0 + BLOCK_SIZE1: 
block_size1 + block_size0: + - 1 + - 2 + - 4 + block_size1: + - 1024 + - 2048 + - 4096 + +index: + - gen: true + param_map: + META: + BLOCK_SIZE0: block_size0 + BLOCK_SIZE1: block_size1 + block_size0: + - 1 + - 2 + - 4 + block_size1: + - 1024 + - 2048 + - 4096 diff --git a/tests/test_quant.py b/tests/test_quant.py index 566d60cd8..c623c0322 100644 --- a/tests/test_quant.py +++ b/tests/test_quant.py @@ -60,7 +60,10 @@ def convert_fp8( dst: torch.Tensor, src: torch.Tensor, scale: float, kv_dtype: str ) -> None: if kv_dtype == "fp8": - dst_ = (src / scale).to(torch.float8_e4m3fn).view(dst.dtype) + if flag_gems.vendor_name == "amd": + dst_ = (src / scale).to(torch.float8_e4m3fnuz).view(dst.dtype) + else: + dst_ = (src / scale).to(torch.float8_e4m3fn).view(dst.dtype) dst.copy_(dst_) else: dst.copy_(src) @@ -146,7 +149,11 @@ def test_concat_and_cache_mla( convert_fp8( expected_temp, ref_kv_cache, scale.item(), kv_dtype=kv_cache_dtype ) - dtype = torch.float8_e4m3fn + dtype = ( + torch.float8_e4m3fn + if flag_gems.vendor_name != "amd" + else torch.float8_e4m3fnuz + ) if flag_gems.vendor_name == "mthreads": result_temp = to_reference(result_temp) # TODO: RuntimeError: Comparing