diff --git a/OperatorList.md b/OperatorList.md index 1f6702a5c..5471e7aaf 100644 --- a/OperatorList.md +++ b/OperatorList.md @@ -86,7 +86,7 @@ - nll_loss - nll_loss_forward - nll_loss_nd -- scaled_dot_product_attention +- _flash_attention_forward - upsample_nearest2d - _fft_c2r - _fft_r2c diff --git a/src/flag_gems/__init__.py b/src/flag_gems/__init__.py index 8d6341283..fd62d3999 100644 --- a/src/flag_gems/__init__.py +++ b/src/flag_gems/__init__.py @@ -41,7 +41,7 @@ def enable(lib=aten_lib, unused=None, registrar=registrar): ("bitwise_or.Scalar", bitwise_or_scalar, Autograd.disable), ("bitwise_or_.Scalar", bitwise_or_scalar_, Autograd.disable), ("bitwise_or.Scalar_Tensor", bitwise_or_scalar_tensor, Autograd.disable), - ("bmm", bmm, Autograd.disable), + # ("bmm", bmm, Autograd.disable), ("clamp", clamp, Autograd.disable), ("clamp_", clamp_, Autograd.disable), ("clamp.Tensor", clamp_tensor, Autograd.disable), @@ -200,8 +200,8 @@ def enable(lib=aten_lib, unused=None, registrar=registrar): ("sum", sum, Autograd.disable), ("sum.dim_IntList", sum_dim, Autograd.disable), ( - "scaled_dot_product_attention", - scaled_dot_product_attention, + "_flash_attention_forward", + flash_attention_forward, Autograd.disable, ), ("all", all, Autograd.disable), diff --git a/src/flag_gems/ops/__init__.py b/src/flag_gems/ops/__init__.py index 76c0bf78b..fe2c581a6 100755 --- a/src/flag_gems/ops/__init__.py +++ b/src/flag_gems/ops/__init__.py @@ -7,7 +7,7 @@ from .arange import arange, arange_start from .argmax import argmax from .argmin import argmin -from .attention import scaled_dot_product_attention +from .attention import flash_attention_forward, scaled_dot_product_attention from .batch_norm import batch_norm from .bitwise_and import ( bitwise_and_scalar, @@ -348,6 +348,7 @@ "repeat_interleave_self_int", "vstack", "repeat_interleave_tensor", + "flash_attention_forward", "scaled_dot_product_attention", "conv2d", "conv1d", diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py index b7c0f6ffe..27065f8cf 100644 --- a/src/flag_gems/ops/attention.py +++ b/src/flag_gems/ops/attention.py @@ -1,10 +1,12 @@ import logging +import math import torch import triton import triton.language as tl from flag_gems.runtime import torch_device_fn +from flag_gems.utils.random_utils import update_philox_state from .. import runtime @@ -398,3 +400,1354 @@ def scaled_dot_product_attention( HAS_ATTN_MASK=HAS_ATTN_MASK, # ) return o + + +# Following implementation is largely a porting of TriDao's Flash Attention to Triton. +# Major difference can be found in dropout where the input to RNG is determined only +# by the element index in the attention score matrix. In contrast, the CUDA flash-attn +# employs a dropout that assumes an implementation specific threadblock data layout. 
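+#
+# A rough sketch of the dropout addressing used below (names here are illustrative,
+# not the kernel variables): each attention-score element is hashed independently of
+# any thread/block layout, so the mask is reproducible from (seed, offset) alone:
+#
+#     subsequence = row * seqlen_k + col // 4      # one Philox call covers 4 adjacent columns
+#     counter = philox_offset + batch * num_heads + head
+#     r0, r1, r2, r3 = philox(seed, subsequence, counter)
+#     byte = (lane selected by col % 4) & 0xFF     # one byte decides one column
+#     dropped = byte >= keep_prob_u8               # keep_prob_u8 = floor((1 - p_dropout) * 255)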
+ + +@triton.jit +def u64_to_lohi(x): + return (x >> 32).to(tl.uint32), (x & 0xFFFFFFFF).to(tl.uint32) + + +@triton.jit +def u64_from_lohi(lo, hi): + return hi.to(tl.uint64) << 32 + lo.to(tl.uint64) + + +@triton.jit +def philox_(seed, subsequence, offset): + kPhilox10A: tl.constexpr = 0x9E3779B9 + kPhilox10B: tl.constexpr = 0xBB67AE85 + k0, k1 = u64_to_lohi(seed.to(tl.uint64)) + c0, c1 = u64_to_lohi(offset.to(tl.uint64)) + c2, c3 = u64_to_lohi(subsequence.to(tl.uint64)) + + # pragma unroll + kPhiloxSA: tl.constexpr = 0xD2511F53 + kPhiloxSB: tl.constexpr = 0xCD9E8D57 + for _ in tl.static_range(6): + res0 = kPhiloxSA * c0.to(tl.uint64) + res1 = kPhiloxSB * c2.to(tl.uint64) + res0_x, res0_y = u64_to_lohi(res0) + res1_x, res1_y = u64_to_lohi(res1) + c0, c1, c2, c3 = res1_y ^ c1 ^ k0, res1_x, res0_y ^ c3 ^ k1, res0_x + k0 += kPhilox10A + k1 += kPhilox10B + + res0 = kPhiloxSA * c0.to(tl.uint64) + res1 = kPhiloxSB * c2.to(tl.uint64) + res0_x, res0_y = u64_to_lohi(res0) + res1_x, res1_y = u64_to_lohi(res1) + c0, c1, c2, c3 = res1_y ^ c1 ^ k0, res1_x, res0_y ^ c3 ^ k1, res0_x + + return c0, c1, c2, c3 + + +@triton.jit +def apply_dropout_mask( + P, + mask, + encode_dropout_in_sign_bit: tl.constexpr, +): + if encode_dropout_in_sign_bit: + P = tl.where(mask, -P, P) + else: + P = tl.where(mask, P * 0, P) + return P + + +@triton.jit +def apply_dropout( + P, + row_start, + col_start, + n_cols, + bid, + hid, + philox_seed, + philox_offset, + p_dropout_uint8: tl.constexpr, + encode_dropout_in_sign_bit: tl.constexpr, + NUM_HEADS: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + row_start = tl.multiple_of(row_start, BLOCK_M) + col_start = tl.multiple_of(col_start, BLOCK_N) + row = row_start + tl.arange(0, BLOCK_M)[:, None] + # Down scale col_idx by 4 + col = col_start // 4 + tl.arange(0, BLOCK_N // 4)[None, :] + + subsequence = row.to(tl.uint64) * n_cols + col.to(tl.uint64) + + offset = philox_offset + bid * NUM_HEADS + hid + offset += subsequence * 0 + r0, r1, r2, r3 = philox_(philox_seed, subsequence, offset) + + r = tl.join(tl.join(r0, r1), tl.join(r2, r3)).reshape(BLOCK_M, BLOCK_N) + + mask = (r & 0xFF) >= p_dropout_uint8 + + P = apply_dropout_mask( + P, mask, encode_dropout_in_sign_bit=encode_dropout_in_sign_bit + ) + return P + + +@triton.jit +def apply_mask( + S, + col_idx, + row_idx, + max_seqlen_q, + max_seqlen_k, + ws_left, + ws_right, + is_even_mn: tl.constexpr, + is_causal: tl.constexpr, + is_local: tl.constexpr, + has_alibi: tl.constexpr, + alibi_slope: tl.constexpr = None, +): + need_mask: tl.constexpr = is_causal | has_alibi | is_local | (not is_even_mn) + if need_mask: + # Extra care should be taken to void one-off errors: both col_lb and col_rb are inclusive! 
+ col_lb = max(0, row_idx + max_seqlen_k - max_seqlen_q - ws_left) + col_rb = min(max_seqlen_k - 1, row_idx + max_seqlen_k - max_seqlen_q + ws_right) + + if has_alibi: + S -= alibi_slope * tl.abs(col_idx[None, :] - row_idx[:, None]) + + if is_causal: + S = tl.where(col_idx[None, :] > col_rb[:, None], float("-inf"), S) + + if is_local: + S = tl.where( + (col_idx[None, :] > col_rb[:, None]) + | (col_idx[None, :] < col_lb[:, None]), + float("-inf"), + S, + ) + + if (not is_local) & (not is_causal) & (not is_even_mn): + S = tl.where(col_idx[None, :] >= max_seqlen_k, float("-inf"), S) + + return S + + +@triton.jit +def softmax_rescale( + O_acc, + S, + row_max, + row_sum, + softmax_scale_log2e: tl.constexpr, + is_border: tl.constexpr, + # is_init: tl.constexpr +): + prev_max = row_max + row_max = tl.maximum(row_max, tl.max(S, 1)) + + if is_border: + cur_max = tl.where(row_max == float("-inf"), 0, row_max) + else: + cur_max = row_max + + p_scale = tl.math.exp2((prev_max - cur_max) * softmax_scale_log2e) + row_sum *= p_scale + O_acc *= p_scale[:, None] + + max_scaled = tl.where(row_max == float("-inf"), 0, row_max * softmax_scale_log2e) + P = tl.math.exp2(S * softmax_scale_log2e - max_scaled[:, None]) + row_sum = row_sum + tl.sum(P, 1) + return O_acc, P, row_max, row_sum + + +def block_m_heuristic(headdim, is_dropout): + block_m = 128 if headdim <= 128 else 64 + return block_m + + +def block_n_heuristic(headdim, is_dropout): + block_n = 64 if headdim <= 64 else 32 + return block_n + + +def block_m_splitkv_heuristic(headdim): + return 128 if headdim <= 128 else 64 + + +def block_n_splitkv_heuristic(headdim): + return 64 if headdim <= 64 else 32 + + +def is_even_mn(M, N, BM, BN, WL, WR): + if M % BM == 0 and N % BN == 0: + if M % N == 0 or N % M == 0: + if (WL == -1 or WL % BN == 0) and (WR == -1 or WR % BN == 0): + return True + return False + + +@triton.heuristics( + values={ + "BLOCK_M": lambda args: block_m_heuristic(args["HEAD_DIM"], args["is_dropout"]), + "BLOCK_N": lambda args: block_n_heuristic(args["HEAD_DIM"], args["is_dropout"]), + "num_warps": lambda args: 4, + "num_stages": lambda args: 3 if args["HEAD_DIM"] <= 128 else 2, + "PRE_LOAD_V": lambda args: False, + "IS_EVEN_MN": lambda args: is_even_mn( + args["seqlen_q"], + args["seqlen_k"], + args["BLOCK_M"], + args["BLOCK_N"], + args["ws_left"], + args["ws_right"], + ), + } +) +@triton.jit(do_not_specialize=["seqlen_q", "seqlen_k", "philox_seed", "philox_offset"]) +def flash_fwd_kernel( + Q_ptr, + K_ptr, + V_ptr, + P_ptr, + O_ptr, + lse_ptr, + seqlen_q, + seqlen_k, + seqlen_q_rounded, + seqlen_k_rounded, + q_b_stride, + q_s_stride, + q_h_stride, + k_b_stride, + k_s_stride, + k_h_stride, + o_b_stride, + o_s_stride, + o_h_stride, + h: tl.constexpr, + hk: tl.constexpr, + pSlopes, + philox_seed, + philox_offset, + pdrop_u8, + rpdrop, + slopes_batch_stride, + HEAD_DIM: tl.constexpr, + is_dropout: tl.constexpr, + is_causal: tl.constexpr, + is_local: tl.constexpr, + has_alibi: tl.constexpr, + softmax_scale: tl.constexpr, + softmax_scale_log2e: tl.constexpr, + ws_left: tl.constexpr, + ws_right: tl.constexpr, + return_P: tl.constexpr, + PRE_LOAD_V: tl.constexpr, + BATCH_SIZE: tl.constexpr, + NUM_HEADS: tl.constexpr, + NUM_HEADS_K: tl.constexpr, + IS_EVEN_MN: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + blocks_per_split: tl.constexpr, + num_warps: tl.constexpr, + num_stages: tl.constexpr, +): + m_block = tl.program_id(0) + bid = tl.program_id(1) + hid = tl.program_id(2) + num_m_blocks = tl.cdiv(seqlen_q, BLOCK_M) + + # We draw 
a minimum covering frame on the attention map that this CTA is assigned to process. + # The frame edges are rounded to multiples of BLOCK_M and BLOCK_N for rows and columns respectively. + + col_min = 0 + if is_local: + col_min = max(0, m_block * BLOCK_M + seqlen_k - seqlen_q - ws_left) + if not IS_EVEN_MN: + # round left + col_min = (col_min // BLOCK_N) * BLOCK_N + + col_max = seqlen_k + if is_causal or is_local: + col_max += (m_block - num_m_blocks + 1) * BLOCK_M + if is_local: + col_max += ws_right + col_max = min(seqlen_k, col_max) + + if not IS_EVEN_MN: + # round right + col_max = tl.cdiv(col_max, BLOCK_N) * BLOCK_N + + if (not is_causal) and (not is_local): + if IS_EVEN_MN: + masking_cols: tl.constexpr = 0 + else: + masking_cols: tl.constexpr = BLOCK_N + elif (is_causal | is_local) and IS_EVEN_MN: # causal implies ws_right is zero + masking_cols: tl.constexpr = tl.cdiv(BLOCK_M, BLOCK_N) * BLOCK_N + else: + # local + masking_cols: tl.constexpr = (tl.cdiv(BLOCK_M, BLOCK_N) + 1) * BLOCK_N + + if is_dropout: + philox_seed = tl.load(philox_seed).to(tl.uint64) + philox_offset = tl.load(philox_offset).to(tl.uint64) + + if has_alibi: + alibi_offset = bid * slopes_batch_stride + hid + alibi_slope = tl.load(pSlopes + alibi_offset) + alibi_slope /= softmax_scale + else: + alibi_slope = 0.0 + + q_b_stride = tl.multiple_of(q_b_stride, HEAD_DIM * h) + Q_ptr += bid * q_b_stride + Q_ptr += hid * q_h_stride + row_start = m_block * BLOCK_M + row_idx = row_start + tl.arange(0, BLOCK_M) + Q_off = row_idx[:, None] * q_s_stride + tl.arange(0, HEAD_DIM)[None, :] + qmask = row_idx[:, None] < seqlen_q + if IS_EVEN_MN: + Q = tl.load(Q_ptr + Q_off, cache_modifier=".cg") + else: + Q = tl.load(Q_ptr + Q_off, mask=qmask, cache_modifier=".cg") + + if return_P: + P_ptr += ( + (bid * NUM_HEADS + hid) * seqlen_q_rounded + m_block * BLOCK_M + ) * seqlen_k_rounded + P_offset = tl.arange(0, BLOCK_M)[:, None] * seqlen_k_rounded + tl.arange( + 0, BLOCK_N + ) + p_bp0 = P_ptr + P_offset + + O_ = tl.zeros((BLOCK_M, HEAD_DIM), dtype=tl.float32) + rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32) + + k_b_stride = tl.multiple_of(k_b_stride, HEAD_DIM * hk) + h_hk_ratio = h // hk + K_ptr += bid * k_b_stride + K_ptr += (hid // h_hk_ratio) * k_h_stride + V_ptr += bid * k_b_stride + V_ptr += (hid // h_hk_ratio) * k_h_stride + + K_offset = ( + tl.arange(0, BLOCK_N)[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None] + ) + V_offset = ( + tl.arange(0, BLOCK_N)[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :] + ) + + p_bk0 = K_ptr + K_offset + p_bv0 = V_ptr + V_offset + + if is_causal | is_local | (not IS_EVEN_MN): + # Cut short masking cols if there's not enough cols out there + masking_cols = min(col_max - col_min, masking_cols) + for col_shift in tl.range(0, masking_cols, step=BLOCK_N): + col_start = col_max - col_shift - BLOCK_N + col_start = tl.multiple_of(col_start, BLOCK_N) + off = col_start * k_s_stride + if IS_EVEN_MN: + K = tl.load(p_bk0 + off, cache_modifier=".cg") + if PRE_LOAD_V: + V = tl.load(p_bv0 + off, cache_modifier=".cg") + else: + col_idx = col_start + tl.arange(0, BLOCK_N) + kvmask = col_idx < seqlen_k + K = tl.load(p_bk0 + off, mask=kvmask[None, :], cache_modifier=".cg") + if PRE_LOAD_V: + V = tl.load(p_bv0 + off, mask=kvmask[:, None], cache_modifier=".cg") + S = tl.dot(Q, K, allow_tf32=False) + col_idx = col_start + tl.arange(0, BLOCK_N) + row_idx = row_start + tl.arange(0, BLOCK_M) + + # tl.store(p_bp0 + col_start, S) + S = apply_mask( + S, 
+ col_idx, + row_idx, + seqlen_q, + seqlen_k, + ws_left, + ws_right, + is_even_mn=IS_EVEN_MN, + is_causal=is_causal, + is_local=is_local, + has_alibi=has_alibi, + alibi_slope=alibi_slope, + ) + + O_, P, rowmax_, rowsum_ = softmax_rescale( + O_, + S, + rowmax_, + rowsum_, + softmax_scale_log2e=softmax_scale_log2e, + is_border=(is_causal or is_local), + ) + P = P.to(V_ptr.type.element_ty) + + if is_dropout: + if return_P: + P_drop = P + + P_drop = apply_dropout( + P_drop, + row_start, + col_start, + seqlen_k, + bid, + hid, + philox_seed, + philox_offset, + pdrop_u8, + encode_dropout_in_sign_bit=True, + NUM_HEADS=NUM_HEADS, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + if IS_EVEN_MN: + tl.store(p_bp0 + col_start, P_drop) + else: + kvmask = col_idx < seqlen_k + tl.store( + p_bp0 + col_start, P_drop, mask=qmask & kvmask[None, :] + ) + + P = apply_dropout( + P, + row_start, + col_start, + seqlen_k, + bid, + hid, + philox_seed, + philox_offset, + pdrop_u8, + encode_dropout_in_sign_bit=False, + NUM_HEADS=NUM_HEADS, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + + if not PRE_LOAD_V: + off = col_start * k_s_stride + if IS_EVEN_MN: + V = tl.load(p_bv0 + off, cache_modifier=".cg") + else: + kvmask = col_idx < seqlen_k + V = tl.load(p_bv0 + off, mask=kvmask[:, None], cache_modifier=".cg") + O_ = tl.dot(P, V, O_, allow_tf32=False) + + for col_start in tl.range( + col_min, col_max - masking_cols, step=BLOCK_N, num_stages=num_stages + ): + col_start = tl.multiple_of(col_start, BLOCK_N) + off = col_start * k_s_stride + K = tl.load(p_bk0 + off, cache_modifier=".cg") + if PRE_LOAD_V: + V = tl.load(p_bv0 + off, cache_modifier=".cg") + S = tl.dot(Q, K) + + col_idx = col_start + tl.arange(0, BLOCK_N) + row_idx = row_start + tl.arange(0, BLOCK_M) + S = apply_mask( + S, + col_idx, + row_idx, + seqlen_q, + seqlen_k, + ws_left, + ws_right, + is_even_mn=True, + is_causal=False, + is_local=is_local, + has_alibi=has_alibi, + alibi_slope=alibi_slope, + ) + + O_, P, rowmax_, rowsum_ = softmax_rescale( + O_, + S, + rowmax_, + rowsum_, + softmax_scale_log2e=softmax_scale_log2e, + is_border=is_local, + ) + P = P.to(V_ptr.type.element_ty) + + if is_dropout: + if return_P: + P_drop = P + P_drop = apply_dropout( + P_drop, + row_start, + col_start, + seqlen_k, + bid, + hid, + philox_seed, + philox_offset, + pdrop_u8, + encode_dropout_in_sign_bit=True, + NUM_HEADS=NUM_HEADS, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + if IS_EVEN_MN: + tl.store(p_bp0 + col_start, P_drop) + else: + kvmask = col_idx < seqlen_k + tl.store(p_bp0 + col_start, P_drop, mask=qmask & kvmask[None, :]) + + P = apply_dropout( + P, + row_start, + col_start, + seqlen_k, + bid, + hid, + philox_seed, + philox_offset, + pdrop_u8, + encode_dropout_in_sign_bit=False, + NUM_HEADS=NUM_HEADS, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + + if not PRE_LOAD_V: + off = col_start * k_s_stride + V = tl.load(p_bv0 + off, cache_modifier=".cg") + + O_ = tl.dot(P, V, O_) + + # LSE + # Note, rowsum = exp(-rowmax) * exp(lse), therefore rowmax + log(rowsum) cancels + # the effect of rowmax and outputs lse only. 
+ lse = tl.where( + rowsum_ == 0 | (rowsum_ != rowsum_), + float("inf"), + rowmax_ * softmax_scale + tl.log(rowsum_), + ) + inv_sum = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), 1.0, 1.0 / rowsum_) + + if is_dropout: + O_ *= inv_sum[:, None] * rpdrop + else: + O_ *= inv_sum[:, None] + + O = O_.to(O_ptr.type.element_ty) # noqa + + # Write back output + o_b_stride = tl.multiple_of(o_b_stride, HEAD_DIM * h) + O_ptr += bid * o_b_stride + O_ptr += hid * o_h_stride + O_offset = row_idx[:, None] * o_s_stride + tl.arange(0, HEAD_DIM) + + if IS_EVEN_MN: + tl.store(O_ptr + O_offset, O) + else: + tl.store(O_ptr + O_offset, O, mask=qmask) + + # Write back lse + p_lse = lse_ptr + (bid * h + hid) * seqlen_q + row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M) + + if IS_EVEN_MN: + tl.store(p_lse + row_idx, lse) + else: + tl.store(p_lse + row_idx, lse, mask=row_idx < seqlen_q) + + +@triton.heuristics( + values={ + "BLOCK_M": lambda args: block_m_splitkv_heuristic(args["HEAD_DIM"]), + "BLOCK_N": lambda args: block_n_splitkv_heuristic(args["HEAD_DIM"]), + "num_warps": lambda args: 4, + "num_stages": lambda args: 3 if args["HEAD_DIM"] <= 128 else 2, + "PRE_LOAD_V": lambda args: True, + "IS_EVEN_MN": lambda args: is_even_mn( + args["seqlen_q"], + args["seqlen_k"], + args["BLOCK_M"], + args["BLOCK_N"], + args["ws_left"], + args["ws_right"], + ), + } +) +@triton.jit(do_not_specialize=["seqlen_q", "seqlen_k", "philox_seed", "philox_offset"]) +def flash_fwd_bh_parallel_kernel( + Q_ptr, + K_ptr, + V_ptr, + P_ptr, + O_ptr, + lse_ptr, + seqlen_q, + seqlen_k, + seqlen_q_rounded, + seqlen_k_rounded, + q_b_stride, + q_s_stride, + q_h_stride, + k_b_stride, + k_s_stride, + k_h_stride, + o_b_stride, + o_s_stride, + o_h_stride, + h, + hk, + pSlopes, + philox_seed, + philox_offset, + pdrop_u8, + rpdrop, + slopes_batch_stride, + HEAD_DIM: tl.constexpr, + is_dropout: tl.constexpr, + is_causal: tl.constexpr, + is_local: tl.constexpr, + has_alibi: tl.constexpr, + softmax_scale: tl.constexpr, + softmax_scale_log2e: tl.constexpr, + ws_left: tl.constexpr, + ws_right: tl.constexpr, + return_P: tl.constexpr, + PRE_LOAD_V: tl.constexpr, + BATCH_SIZE: tl.constexpr, + NUM_HEADS: tl.constexpr, + NUM_HEADS_K: tl.constexpr, + IS_EVEN_MN: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + blocks_per_split: tl.constexpr, + num_warps: tl.constexpr, + num_stages: tl.constexpr, +): + # (TODO) + pass + + +@triton.heuristics( + values={ + "BLOCK_M": lambda args: block_m_splitkv_heuristic(args["HEAD_DIM"]), + "BLOCK_N": lambda args: block_n_splitkv_heuristic(args["HEAD_DIM"]), + "num_warps": lambda args: 4, + "num_stages": lambda args: 3 if args["HEAD_DIM"] <= 128 else 2, + "PRE_LOAD_V": lambda args: True, + "IS_EVEN_MN": lambda args: is_even_mn( + args["seqlen_q"], + args["seqlen_k"], + args["BLOCK_M"], + args["BLOCK_N"], + args["ws_left"], + args["ws_right"], + ), + } +) +@triton.jit(do_not_specialize=["seqlen_q", "seqlen_k", "philox_seed", "philox_offset"]) +def flash_fwd_splitkv_kernel( + Q_ptr, + K_ptr, + V_ptr, + P_ptr, + O_ptr, + lse_ptr, + seqlen_q, + seqlen_k, + seqlen_q_rounded, + seqlen_k_rounded, + q_b_stride, + q_s_stride, + q_h_stride, + k_b_stride, + k_s_stride, + k_h_stride, + o_b_stride, + o_s_stride, + o_h_stride, + h, + hk, + pSlopes, + philox_seed, + philox_offset, + pdrop_u8, + rpdrop, + slopes_batch_stride, + HEAD_DIM: tl.constexpr, + is_dropout: tl.constexpr, + is_causal: tl.constexpr, + is_local: tl.constexpr, + has_alibi: tl.constexpr, + softmax_scale: tl.constexpr, + softmax_scale_log2e: tl.constexpr, + 
ws_left: tl.constexpr, + ws_right: tl.constexpr, + return_P: tl.constexpr, + PRE_LOAD_V: tl.constexpr, + BATCH_SIZE: tl.constexpr, + NUM_HEADS: tl.constexpr, + NUM_HEADS_K: tl.constexpr, + IS_EVEN_MN: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + blocks_per_split: tl.constexpr, + num_warps: tl.constexpr, + num_stages: tl.constexpr, +): + m_block = tl.program_id(0) + split_id = tl.program_id(1) + bid = tl.program_id(2) // NUM_HEADS + hid = tl.program_id(2) % NUM_HEADS + + split_block_min = split_id * blocks_per_split + split_block_max = split_block_min + blocks_per_split + + n_block_max = tl.cdiv(seqlen_k, BLOCK_N) + if is_causal: + n_block_max = min( + n_block_max, + tl.cdiv((m_block + 1) * BLOCK_M + seqlen_k - seqlen_q + ws_right, BLOCK_N), + ) + + if has_alibi: + alibi_offset = bid * slopes_batch_stride + hid + alibi_slope = tl.load(pSlopes + alibi_offset) + alibi_slope /= softmax_scale + else: + alibi_slope = 0 + + if not is_causal: + if IS_EVEN_MN: + masking_block_min = n_block_max + else: + masking_block_min = n_block_max - 1 + elif is_causal and IS_EVEN_MN: # causal implies ws_right is zero + masking_block_min = n_block_max - tl.cdiv(BLOCK_M, BLOCK_N) + else: + masking_block_min = n_block_max - tl.cdiv(BLOCK_M, BLOCK_N) - 1 + + Q_ptr += bid * q_b_stride + Q_ptr += hid * q_h_stride + row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M) + Q_off = row_idx[:, None] * q_s_stride + tl.arange(0, HEAD_DIM)[None, :] + p_qm = Q_ptr + Q_off + qmask = row_idx[:, None] < seqlen_q + if IS_EVEN_MN: + Q = tl.load(p_qm, cache_modifier=".cg") + else: + Q = tl.load(p_qm, mask=qmask, cache_modifier=".cg") + + h_hk_ratio = h // hk + K_ptr += bid * k_b_stride + K_ptr += (hid // h_hk_ratio) * k_h_stride + V_ptr += bid * k_b_stride + V_ptr += (hid // h_hk_ratio) * k_h_stride + + K_offset = ( + tl.arange(0, BLOCK_N)[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None] + ) + p_k0 = K_ptr + K_offset + + V_offset = ( + tl.arange(0, BLOCK_N)[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :] + ) + p_v0 = V_ptr + V_offset + + O_ = tl.zeros((BLOCK_M, HEAD_DIM), dtype=tl.float32) + rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32) + + if split_block_max <= masking_block_min: + # no masking needed + for n_block in tl.range( + split_block_min, split_block_max, num_stages=num_stages + ): + kv_off = n_block * BLOCK_N * k_s_stride + K = tl.load(p_k0 + kv_off, cache_modifier=".cg") + if PRE_LOAD_V: + V = tl.load(p_v0 + kv_off, cache_modifier=".cg") + S = tl.dot(Q, K) + + col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N) + + if has_alibi: + S -= alibi_slope * tl.abs(col_idx[None, :] - row_idx[:, None]) + + O_, P, rowmax_, rowsum_ = softmax_rescale( + O_, + S, + rowmax_, + rowsum_, + softmax_scale_log2e=softmax_scale_log2e, + is_border=False, + ) + + if not PRE_LOAD_V: + V = tl.load(p_v0 + kv_off, cache_modifier=".cg") + P = P.to(Q_ptr.type.element_ty) + O_ = tl.dot(P, V, O_) + else: + for n_block in tl.range(split_block_min, min(split_block_max, n_block_max)): + kv_off = n_block * BLOCK_N * k_s_stride + col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N) + row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M) + if IS_EVEN_MN: + K = tl.load(p_k0 + kv_off, cache_modifier=".cg") + if PRE_LOAD_V: + V = tl.load(p_v0 + kv_off, cache_modifier=".cg") + else: + kvmask = col_idx < seqlen_k + K = tl.load(p_k0 + kv_off, mask=kvmask[None, :], cache_modifier=".cg") + if PRE_LOAD_V: + V = tl.load( + p_v0 + kv_off, mask=kvmask[:, None], 
cache_modifier=".cg" + ) + + S = tl.dot(Q, K) + + S = apply_mask( + S, + col_idx, + row_idx, + seqlen_q, + seqlen_k, + ws_left, + ws_right, + is_even_mn=IS_EVEN_MN, + is_causal=is_causal, + is_local=False, + has_alibi=has_alibi, + alibi_slope=alibi_slope, + ) + + O_, P, rowmax_, rowsum_ = softmax_rescale( + O_, + S, + rowmax_, + rowsum_, + softmax_scale_log2e=softmax_scale_log2e, + is_border=(is_causal or is_local), + ) + + if not PRE_LOAD_V: + if IS_EVEN_MN: + V = tl.load(p_v0 + kv_off, cache_modifier=".cg") + else: + V = tl.load( + p_v0 + kv_off, mask=kvmask[:, None], cache_modifier=".cg" + ) + P = P.to(Q_ptr.type.element_ty) + O_ = tl.dot(P, V, O_) + + # LSE + lse = tl.where( + rowsum_ == 0 | (rowsum_ != rowsum_), + float("-inf"), + rowmax_ * softmax_scale + tl.log(rowsum_), + ) + inv_sum = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), 1.0, 1.0 / rowsum_) + + # Rescale output + O_ *= inv_sum[:, None] + + # Write back output + # O_splits layout = (n_splits, batch_size, num_heads, seqlen_q, head_size) + # grid = (seq_block, split, batch * head) + O_split_ptr = O_ptr + # + split, batch, head offsets, seq_block offsets are already added in row_idx + O_split_ptr += ( + (split_id * tl.num_programs(2) + tl.program_id(2)) * seqlen_q * HEAD_DIM + ) + O_split_offset = row_idx[:, None] * HEAD_DIM + tl.arange(0, HEAD_DIM) + O_split_ptr = tl.multiple_of(O_split_ptr, HEAD_DIM) + p_om = O_split_ptr + O_split_offset + + if IS_EVEN_MN: + tl.store(p_om, O_, cache_modifier=".cg") + else: + tl.store(p_om, O_, mask=qmask, cache_modifier=".cg") + + # Write back lse + # lse_splits layout = (n_splits, batch_size, num_heads, seqlen_q) + lse_split_ptr = lse_ptr + # + split, batch, head, seq_block offsets + lse_split_ptr += ( + split_id * tl.num_programs(2) + tl.program_id(2) + ) * seqlen_q + m_block * BLOCK_M + + if IS_EVEN_MN: + tl.store(lse_split_ptr + tl.arange(0, BLOCK_M), lse, cache_modifier=".cg") + else: + tl.store( + lse_split_ptr + tl.arange(0, BLOCK_M), + lse, + mask=row_idx < seqlen_q, + cache_modifier=".cg", + ) + + +@triton.jit +def flash_fwd_splitkv_combine_kernel( + out_ptr, + lse_ptr, + out_splits_ptr, + lse_splits_ptr, + head_size: tl.constexpr, + out_b_stride, + out_s_stride, + out_h_stride, + n_splits, + BLOCK_M: tl.constexpr, + q_total, + MAX_N_SPLITS: tl.constexpr, +): + pid = tl.program_id(0) + lse_splits_ptr += pid * BLOCK_M + lse_ptr += pid * BLOCK_M + out_splits_ptr += pid * BLOCK_M * head_size + out_ptr += pid * BLOCK_M * head_size + lse_split_stride = tl.num_programs(0) * BLOCK_M + out_split_stride = tl.num_programs(0) * BLOCK_M * head_size + + # Subtracting maximum from each of the split lse's for better numerical stability + lse_split_offset = ( + tl.arange(0, BLOCK_M)[:, None] + + tl.arange(0, MAX_N_SPLITS)[None, :] * lse_split_stride + ) + lse_split_mask = (pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None] < q_total) & ( + tl.arange(0, MAX_N_SPLITS)[None, :] < n_splits + ) + lse_splits = tl.load( + lse_splits_ptr + lse_split_offset, mask=lse_split_mask, other=float("-inf") + ) + max_lse = tl.max(lse_splits, 1) + + # Sum exp(lse(i) - max_lse) over all split i to obtain Z=sumexp(QK) up to a scaled factor exp(-max_lse) + Zi_scaled = tl.exp(lse_splits - max_lse[:, None]) + Z_scaled = tl.sum(Zi_scaled, 1) + Zi_Z = Zi_scaled / Z_scaled[:, None] + + # Write back LSE + lse = tl.log(Z_scaled) + max_lse + out_mask = pid * BLOCK_M + tl.arange(0, BLOCK_M) < q_total + tl.store(lse_ptr + tl.arange(0, BLOCK_M), lse, mask=out_mask) + + out_split_offset = ( + tl.arange(0, BLOCK_M)[:, None, None] * 
head_size + + tl.arange(0, MAX_N_SPLITS)[None, :, None] * out_split_stride + + tl.arange(0, head_size)[None, None, :] + ) + out_split_mask = ( + pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None, None] < q_total + ) & (tl.arange(0, MAX_N_SPLITS)[None, :, None] < n_splits) + out_splits = tl.load( + out_splits_ptr + out_split_offset, mask=out_split_mask, other=0 + ) + out = tl.sum(Zi_Z[:, :, None] * out_splits, 1) + out = out.to(out_ptr.type.element_ty) + + # Write back output + out_offset = tl.arange(0, BLOCK_M)[:, None] * out_s_stride + tl.arange(0, head_size) + tl.store(out_ptr + out_offset, out, mask=out_mask[:, None]) + + +_debug = False + + +def mha_fwd( + q, + k, + v, + out, + alibi_slopes, + p_dropout, + softmax_scale, + is_causal, + window_size_left, + window_size_right, + return_softmax, + disable_splitkv=False, +): + q_dtype = q.dtype + q_device = q.device + assert q_dtype in ( + torch.float16, + torch.bfloat16, + ), "FlashAttention only support fp16 and bf16 data type" + assert q_dtype == k.dtype + assert q_dtype == v.dtype + assert q.stride(-1) == 1, "Input tensor must have contiguous last dimension" + assert k.stride(-1) == 1, "Input tensor must have contiguous last dimension" + assert v.stride(-1) == 1, "Input tensor must have contiguous last dimension" + batch_size, seqlen_q, num_heads, head_size = q.size() + _, seqlen_k, num_heads_k, _ = k.size() + assert ( + head_size % 8 == 0 + ), "head_size must be a multiple of 8, this is ensured by padding!" + assert ( + num_heads % num_heads_k == 0 + ), "Number of heads in key/value must divide number of heads in query" + if window_size_left >= seqlen_k: + window_size_left = -1 + if window_size_right >= seqlen_k: + window_size_right = -1 + if seqlen_q == 1 and alibi_slopes is None: + is_causal = False + if is_causal: + window_size_right = 0 + + if ( + seqlen_q == 1 + and num_heads > num_heads_k + and window_size_left < 0 + and window_size_right < 0 + and p_dropout == 0 + and not alibi_slopes + ): + swap_seq_and_group = True + else: + swap_seq_and_group = False + + ngroups = num_heads // num_heads_k + if swap_seq_and_group: + q = q.reshape((batch_size, num_heads_k, ngroups, head_size)).transpose(1, 2) + seqlen_q = ngroups + num_heads = num_heads_k + + if out: + assert out.stride(-1) == 1 + assert out.dtype == q.dtype + assert out.size() == (batch_size, seqlen_q, num_heads, head_size) + else: + out = torch.empty_like(q, dtype=v.dtype) + + round_multiple = lambda x, m: (x + m - 1) // m * m + head_size_rounded = round_multiple(head_size, 32) + seqlen_q_rounded = round_multiple(seqlen_q, 128) + seqlen_k_rounded = round_multiple(seqlen_k, 32) + + def splits_heuristics(num_tasks, num_sms, n_blocks): + # splits when wave efficiency is low + n_waves = triton.cdiv(num_tasks, num_sms) + eff = (num_tasks / num_sms) / n_waves + if eff > 0.8 or n_waves > 1: + return 1 + + min_blocks_per_split = 2 + best_splits = min( + triton.cdiv(n_blocks, min_blocks_per_split), + int(math.floor(1.0 / eff)), + num_sms, + ) + + # best_splits = 1 + # best_eff = eff + # min_blocks_per_split = 1 + # max_blocks_per_split = triton.cdiv(n_blocks, 2) + # for blocks_per_split in range(min_blocks_per_split, max_blocks_per_split + 1)[::-1]: + # n_splits = triton.cdiv(n_blocks, blocks_per_split) + # n_waves = triton.cdiv(n_splits * num_tasks, num_sms) + # eff = (n_splits * num_tasks / num_sms) / n_waves + # if eff > 0.85: + # best_splits = n_splits + # break + return best_splits + + with torch_device_fn.device(q_device): + # Set softmax params + lse = torch.empty( + (batch_size, 
num_heads, seqlen_q), dtype=torch.float, device=q_device + ) + if return_softmax: + assert ( + p_dropout > 0 + ), "return_softmax is only supported when p_dropout > 0.0" + p = torch.zeros( + (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded), + dtype=q_dtype, + device=q_device, + ) + else: + p = torch.empty((), device=q_device) + + # Set dropout params + if p_dropout > 0: + increment = batch_size * num_heads * 32 + philox_seed, philox_offset = update_philox_state(increment) + philox_seed = torch.tensor(philox_seed, dtype=torch.int64, device=q_device) + philox_offset = torch.tensor( + philox_offset, dtype=torch.int64, device=q_device + ) + is_dropout = True + else: + philox_seed, philox_offset = None, None + is_dropout = False + + p_dropout = 1 - p_dropout + pdrop_u8 = math.floor(p_dropout * 255.0) + rpdrop = 1.0 / p_dropout + + M_LOG2E = 1.4426950408889634074 + softmax_scale_log2e = softmax_scale * M_LOG2E + + # Set alibi params + if alibi_slopes is not None: + assert alibi_slopes.device == q_device + assert alibi_slopes.dtype in (torch.float,) + assert alibi_slopes.stride(-1) == 1 + assert alibi_slopes.shape == (num_heads,) or alibi_slopes.shape == ( + batch_size, + num_heads, + ) + alibi_slopes_batch_stride = ( + alibi_slopes.stride(0) if alibi_slopes.ndim == 2 else 0 + ) + has_alibi = True + else: + alibi_slopes_batch_stride = 0 + has_alibi = False + + # Set SWA params + is_causal = window_size_left < 0 and window_size_right == 0 + is_local = window_size_left >= 0 and window_size_right >= 0 + + # ONLY EVEN_K IS SUPPORTED + assert head_size == head_size_rounded + + # Do kernel dispatching + def dispatch(B, H, Q, K, D): + num_sms = torch_device_fn.get_device_properties( + "cuda" + ).multi_processor_count + + default_args = {} + + # Try bh parallel + # if B * H > 0.8 * num_sms: + # kernel = flash_fwd_bh_parallel_kernel[(H, B)] + # # Yield kernel and prefilled args + # return kernel, default_args, None, None + + # Try splitkv + if not is_dropout and not is_local and not disable_splitkv: + BM = block_m_splitkv_heuristic(D) + n_tasks = B * H * triton.cdiv(seqlen_q, BM) + BN = block_n_splitkv_heuristic(D) + n_blocks = triton.cdiv(seqlen_k, BN) + n_splits = splits_heuristics(n_tasks, num_sms, n_blocks) + + if _debug: + n_splits = 32 + n_blocks = triton.cdiv(K, BN) + blocks_per_split = triton.cdiv(n_blocks, n_splits) + print("block_n:", BN) + print("n_splits:", n_splits) + print("blocks_per_split", blocks_per_split) + + if n_splits > 1: + lse_splits = torch.empty( + (n_splits, B, H, Q), dtype=torch.float, device=q_device + ) + out_splits = torch.empty( + (n_splits, B, H, Q, D), dtype=torch.float, device=q_device + ) + grid = lambda args: ( + triton.cdiv(Q, args["BLOCK_M"]), + n_splits, + B * H, + ) + splitkv_kernel = flash_fwd_splitkv_kernel[grid] + blocks_per_split = triton.cdiv(n_blocks, n_splits) + splitkv_args = default_args.copy() + splitkv_args["blocks_per_split"] = blocks_per_split + splitkv_args["O_ptr"] = out_splits + splitkv_args["lse_ptr"] = lse_splits + # kernel = yield kernel, args + + if D % 128 == 0: + BLOCK_M = 4 + elif D % 64 == 0: + BLOCK_M = 8 + else: + BLOCK_M = 16 + grid = lambda args: (triton.cdiv(B * H * Q, BLOCK_M),) + combine_kernel = flash_fwd_splitkv_combine_kernel[grid] + combine_args = { + "out_splits_ptr": out_splits, + "lse_splits_ptr": lse_splits, + "n_splits": n_splits, + "BLOCK_M": BLOCK_M, + "q_total": B * H * Q, + "MAX_N_SPLITS": triton.next_power_of_2(n_splits), + } + return splitkv_kernel, splitkv_args, combine_kernel, combine_args + + # Last 
option: flash_fwd + grid = lambda args: ( + triton.cdiv(Q, args["BLOCK_M"]), + B, + H, + ) + kernel = flash_fwd_kernel[grid] + return kernel, default_args, None, None + + kernel1, kernel1_args, kernel2, kernel2_args = dispatch( + batch_size, num_heads, seqlen_q, seqlen_k, head_size + ) + + if _debug: + p = torch.empty( + (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded), + dtype=torch.float32, + device=q_device, + ) + return_softmax = True + + prefilled_args = { + "Q_ptr": q, + "K_ptr": k, + "V_ptr": v, + "P_ptr": p, + "O_ptr": out, + "lse_ptr": lse, + "seqlen_q": seqlen_q, + "seqlen_k": seqlen_k, + "seqlen_q_rounded": seqlen_q_rounded, + "seqlen_k_rounded": seqlen_k_rounded, + "q_b_stride": q.stride(0), + "q_s_stride": q.stride(-3), + "q_h_stride": q.stride(-2), + "k_b_stride": k.stride(0), + "k_s_stride": k.stride(-3), + "k_h_stride": k.stride(-2), + "o_b_stride": out.stride(0), + "o_s_stride": out.stride(-3), + "o_h_stride": out.stride(-2), + "h": num_heads, + "hk": num_heads_k, + "pSlopes": alibi_slopes, + "philox_seed": philox_seed, + "philox_offset": philox_offset, + "pdrop_u8": pdrop_u8, + "rpdrop": rpdrop, + "slopes_batch_stride": alibi_slopes_batch_stride, + "HEAD_DIM": head_size, + "is_dropout": is_dropout, + "is_causal": is_causal, + "is_local": is_local, + "has_alibi": has_alibi, + "softmax_scale": softmax_scale, + "softmax_scale_log2e": softmax_scale_log2e, + "ws_left": window_size_left, + "ws_right": window_size_right, + "return_P": return_softmax, + "BATCH_SIZE": batch_size, + "blocks_per_split": None, + "NUM_HEADS": num_heads, + "NUM_HEADS_K": num_heads_k, + } + + args_copy = prefilled_args.copy() + args_copy.update(kernel1_args) + + kernel = kernel1(**args_copy) + if _debug: + print(f"{kernel.name} shared memory:", kernel.metadata.shared) + print(f"{kernel.name} num_warps:", kernel.metadata.num_warps) + print(f"{kernel.name} num_stages:", kernel.metadata.num_stages) + # print(kernel.asm['ttgir']) + print("p:", p) + + # Combine + if kernel2 is not None: + prefilled_args = { + "out_ptr": out, + "lse_ptr": lse, + "head_size": head_size, + "out_b_stride": out.stride(0), + "out_s_stride": out.stride(-3), + "out_h_stride": out.stride(-1), + } + args_copy = prefilled_args.copy() + args_copy.update(kernel2_args) + kernel2(**args_copy) + + if swap_seq_and_group: + out = out.transpose(1, 2).reshape( + (batch_size, 1, num_heads_k * seqlen_q, head_size) + ) + q = q.transpose(1, 2).reshape( + (batch_size, 1, num_heads_k * seqlen_q, head_size) + ) + lse = lse.reshape((batch_size, num_heads_k * seqlen_q, 1)) + + return out, q, k, v, lse, philox_seed, philox_offset, p + + +def flash_attention_forward( + query, + key, + value, + cum_seq_q, + cum_seq_k, + max_q, + max_k, + dropout_p, + is_causal, + return_debug_mask, + *, + scale=None, + window_size_left=None, + window_size_right=None, + seqused_k=None, + alibi_slopes=None, + disable_splitkv=False, +): + logging.debug("GEMS FLASH_ATTENTION_FORWARD") + assert cum_seq_q is None and cum_seq_k is None, "varlen is not supported yet." 
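+    # Note: unlike scaled_dot_product_attention, this entry point follows the
+    # aten._flash_attention_forward convention: q/k/v are laid out as
+    # (batch, seqlen, num_heads, head_size), and lse is returned in float32.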
+ + HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1] + HEAD_DIM_V = value.shape[-1] + assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V + assert HEAD_DIM_K in {16, 32, 64, 128, 256} + + softmax_scale = scale or 1.0 / (HEAD_DIM_K**0.5) + if window_size_left is not None: + non_null_window_left = window_size_left + else: + non_null_window_left = -1 + if window_size_right is not None: + non_null_window_right = window_size_right + else: + non_null_window_right = -1 + + out, q, k, v, lse, philox_seed, philox_offset, p = mha_fwd( + query, + key, + value, + None, + alibi_slopes, + dropout_p, + softmax_scale, + is_causal, + non_null_window_left, + non_null_window_right, + return_debug_mask, + disable_splitkv=disable_splitkv, + ) + + return (out, lse, philox_seed, philox_offset, p) diff --git a/src/flag_gems/ops/dropout.py b/src/flag_gems/ops/dropout.py index 2bcb31a81..7625c76ba 100644 --- a/src/flag_gems/ops/dropout.py +++ b/src/flag_gems/ops/dropout.py @@ -4,10 +4,7 @@ import triton import triton.language as tl -from flag_gems.utils.random_utils import ( - philox_backend_seed_offset, - uint_to_uniform_float, -) +from flag_gems.utils.random_utils import uint_to_uniform_float, update_philox_state from .. import runtime from ..runtime import torch_device_fn @@ -133,7 +130,7 @@ def forward(ctx, x, p, train): # hence we cannot obtain the per thread offset as in Pytorch. increment = triton.cdiv(N, UNROLL) with torch_device_fn.device(device): - philox_seed, philox_offset = philox_backend_seed_offset(increment) + philox_seed, philox_offset = update_philox_state(increment) dropout_forward_kernel[grid_fn](x, out, N, p, philox_seed, philox_offset) ctx.p = p ctx.philox_seed = philox_seed diff --git a/src/flag_gems/ops/exponential_.py b/src/flag_gems/ops/exponential_.py index 8d36ce7eb..b94482d8f 100644 --- a/src/flag_gems/ops/exponential_.py +++ b/src/flag_gems/ops/exponential_.py @@ -4,10 +4,7 @@ import triton import triton.language as tl -from flag_gems.utils.random_utils import ( - philox_backend_seed_offset, - uint_to_uniform_float, -) +from flag_gems.utils.random_utils import uint_to_uniform_float, update_philox_state from .. import runtime from ..runtime import torch_device_fn @@ -94,7 +91,7 @@ def exponential_(x, lambd: float = 1.0, *, gen=None): # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller, # hence we cannot obtain the per thread offset as in Pytorch. 
increment = triton.cdiv(N, UNROLL) - philox_seed, philox_offset = philox_backend_seed_offset(increment) + philox_seed, philox_offset = update_philox_state(increment) eps = torch.finfo(dtype).eps x_ = x if inplace else torch.empty(x.size(), dtype=dtype, device=device) with torch_device_fn.device(device): diff --git a/src/flag_gems/ops/multinomial.py b/src/flag_gems/ops/multinomial.py index d1f061cda..3d46676b5 100644 --- a/src/flag_gems/ops/multinomial.py +++ b/src/flag_gems/ops/multinomial.py @@ -5,7 +5,7 @@ import triton.language as tl from flag_gems.utils import libentry -from flag_gems.utils.random_utils import philox_backend_seed_offset, uniform +from flag_gems.utils.random_utils import uniform, update_philox_state @libentry() @@ -84,7 +84,7 @@ def multinomial(prob, n_samples, with_replacement=False, *, gen=None): # The CTA level parallelism is framed in a 2d grid of blocks with grid.y # indexing into distributions and grid.x output sample batches increment = n_dist * n_samples - philox_seed, philox_offset = philox_backend_seed_offset(increment) + philox_seed, philox_offset = update_philox_state(increment) grid = lambda META: (triton.cdiv(n_samples, META["NBLOCK"]), n_dist) multinomial_with_replacement[grid]( cum_prob, out, n_categories, n_samples, philox_seed, philox_offset diff --git a/src/flag_gems/ops/normal.py b/src/flag_gems/ops/normal.py index b24b4398e..d35bbba8d 100644 --- a/src/flag_gems/ops/normal.py +++ b/src/flag_gems/ops/normal.py @@ -5,7 +5,7 @@ from ..runtime import torch_device_fn from ..utils import pointwise_dynamic -from ..utils.random_utils import philox_backend_seed_offset +from ..utils.random_utils import update_philox_state from ..utils.shape_utils import broadcast_shapes, volume from .randn import randn_kernel @@ -50,7 +50,7 @@ def normal_distribution(shape, device, *, generator=None): grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK"] * UNROLL),) increment = triton.cdiv(N, UNROLL) - philox_seed, philox_offset = philox_backend_seed_offset(increment) + philox_seed, philox_offset = update_philox_state(increment) with torch_device_fn.device(device): randn_kernel[grid_fn](out, N, philox_seed, philox_offset) return out diff --git a/src/flag_gems/ops/rand.py b/src/flag_gems/ops/rand.py index 9cab927ac..9924f1bef 100644 --- a/src/flag_gems/ops/rand.py +++ b/src/flag_gems/ops/rand.py @@ -4,10 +4,7 @@ import triton import triton.language as tl -from flag_gems.utils.random_utils import ( - philox_backend_seed_offset, - uint_to_uniform_float, -) +from flag_gems.utils.random_utils import uint_to_uniform_float, update_philox_state from flag_gems.utils.shape_utils import volume from .. import runtime @@ -63,7 +60,7 @@ def rand(size, *, dtype=None, layout=None, device=None, pin_memory=None): # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller, # hence we cannot obtain the per thread offset as in Pytorch. 
increment = triton.cdiv(N, UNROLL) - philox_seed, philox_offset = philox_backend_seed_offset(increment) + philox_seed, philox_offset = update_philox_state(increment) with torch_device_fn.device(device): rand_kernel[grid_fn](out, N, philox_seed, philox_offset) return out diff --git a/src/flag_gems/ops/rand_like.py b/src/flag_gems/ops/rand_like.py index 85338f595..6b5b9b706 100644 --- a/src/flag_gems/ops/rand_like.py +++ b/src/flag_gems/ops/rand_like.py @@ -4,7 +4,7 @@ import triton from flag_gems.ops.rand import rand_kernel -from flag_gems.utils.random_utils import philox_backend_seed_offset +from flag_gems.utils.random_utils import update_philox_state from ..runtime import torch_device_fn @@ -25,7 +25,7 @@ def rand_like( # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller, # hence we cannot obtain the per thread offset as in Pytorch. increment = triton.cdiv(N, UNROLL) - philox_seed, philox_offset = philox_backend_seed_offset(increment) + philox_seed, philox_offset = update_philox_state(increment) with torch_device_fn.device(x.device): rand_kernel[grid_fn](out, N, philox_seed, philox_offset) return out diff --git a/src/flag_gems/ops/randn.py b/src/flag_gems/ops/randn.py index 3ee2932e5..b75c97394 100644 --- a/src/flag_gems/ops/randn.py +++ b/src/flag_gems/ops/randn.py @@ -4,10 +4,7 @@ import triton import triton.language as tl -from flag_gems.utils.random_utils import ( - philox_backend_seed_offset, - uint_to_uniform_float, -) +from flag_gems.utils.random_utils import uint_to_uniform_float, update_philox_state from flag_gems.utils.shape_utils import volume from .. import runtime @@ -77,7 +74,7 @@ def randn(size, *, dtype=None, layout=None, device=None, pin_memory=None): # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller, # hence we cannot obtain the per thread offset as in Pytorch. increment = triton.cdiv(N, UNROLL) - philox_seed, philox_offset = philox_backend_seed_offset(increment) + philox_seed, philox_offset = update_philox_state(increment) with torch_device_fn.device(device): randn_kernel[grid_fn](out, N, philox_seed, philox_offset) return out diff --git a/src/flag_gems/ops/randn_like.py b/src/flag_gems/ops/randn_like.py index 0458328dc..b8c87b7df 100644 --- a/src/flag_gems/ops/randn_like.py +++ b/src/flag_gems/ops/randn_like.py @@ -4,7 +4,7 @@ import triton from flag_gems.ops.randn import randn_kernel -from flag_gems.utils.random_utils import philox_backend_seed_offset +from flag_gems.utils.random_utils import update_philox_state from ..runtime import torch_device_fn @@ -25,7 +25,7 @@ def randn_like( # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller, # hence we cannot obtain the per thread offset as in Pytorch. increment = triton.cdiv(N, UNROLL) - philox_seed, philox_offset = philox_backend_seed_offset(increment) + philox_seed, philox_offset = update_philox_state(increment) with torch_device_fn.device(x.device): randn_kernel[grid_fn](out, N, philox_seed, philox_offset) return out diff --git a/src/flag_gems/ops/randperm.py b/src/flag_gems/ops/randperm.py index 985f527b2..64d6ea30d 100644 --- a/src/flag_gems/ops/randperm.py +++ b/src/flag_gems/ops/randperm.py @@ -4,7 +4,7 @@ import triton import triton.language as tl -from flag_gems.utils.random_utils import philox_backend_seed_offset +from flag_gems.utils.random_utils import update_philox_state from .. 
import runtime from ..runtime import device, torch_device_fn @@ -384,7 +384,7 @@ def sort_by_key(key, value, valid_bits): # last step, shuffle inner-block data BLOCK_SIZE_SHUFFLE = 512 grid_shuffle = (triton.cdiv(n_elements, BLOCK_SIZE_SHUFFLE),) - philox_seed, philox_offset = philox_backend_seed_offset(n_elements) + philox_seed, philox_offset = update_philox_state(n_elements) with torch_device_fn.device(key.device): duplicate_keys_shuffle_kernel[grid_shuffle]( v_out, diff --git a/src/flag_gems/ops/uniform.py b/src/flag_gems/ops/uniform.py index 114a552bf..2c0846b7b 100644 --- a/src/flag_gems/ops/uniform.py +++ b/src/flag_gems/ops/uniform.py @@ -3,10 +3,7 @@ import triton import triton.language as tl -from flag_gems.utils.random_utils import ( - philox_backend_seed_offset, - uint_to_uniform_float, -) +from flag_gems.utils.random_utils import uint_to_uniform_float, update_philox_state from flag_gems.utils.shape_utils import volume from .. import runtime @@ -55,7 +52,7 @@ def uniform_(self, from_=0.0, to=1.0, *, generator=None): grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK"] * UNROLL),) increment = triton.cdiv(N, UNROLL) - philox_seed, philox_offset = philox_backend_seed_offset(increment) + philox_seed, philox_offset = update_philox_state(increment) with torch_device_fn.device(self.device): uniform_kernel[grid_fn](self, N, philox_seed, philox_offset, from_, to) return self diff --git a/src/flag_gems/runtime/backend/_metax/ops/exponential_.py b/src/flag_gems/runtime/backend/_metax/ops/exponential_.py index 76be32499..83c1fa35e 100644 --- a/src/flag_gems/runtime/backend/_metax/ops/exponential_.py +++ b/src/flag_gems/runtime/backend/_metax/ops/exponential_.py @@ -5,10 +5,7 @@ import triton.language as tl from flag_gems.runtime import torch_device_fn -from flag_gems.utils.random_utils import ( - philox_backend_seed_offset, - uint_to_uniform_float, -) +from flag_gems.utils.random_utils import uint_to_uniform_float, update_philox_state eps: tl.constexpr = [ 2.220446049250313e-16, @@ -238,7 +235,7 @@ def exponential_(x, lambd: float = 1.0, *, gen=None): # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller, # hence we cannot obtain the per thread offset as in Pytorch. increment = triton.cdiv(N, UNROLL) - philox_seed, philox_offset = philox_backend_seed_offset(increment) + philox_seed, philox_offset = update_philox_state(increment) eps = torch.finfo(dtype).eps x_ = x if inplace else torch.empty(x.size(), dtype=dtype, device=device) type_index = lst.index(dtype) diff --git a/src/flag_gems/runtime/backend/_nvidia/tune_configs.yaml b/src/flag_gems/runtime/backend/_nvidia/tune_configs.yaml index 230d6c0b9..f5a6a3a11 100644 --- a/src/flag_gems/runtime/backend/_nvidia/tune_configs.yaml +++ b/src/flag_gems/runtime/backend/_nvidia/tune_configs.yaml @@ -8,12 +8,13 @@ attention: num_warps: warps num_stages: stages block_m: + - 32 - 64 - 128 block_n: - 32 - 64 - - 128 + # - 128 pre_load_v: - true - false diff --git a/src/flag_gems/utils/random_utils.py b/src/flag_gems/utils/random_utils.py index 1d29b629d..876c81bed 100644 --- a/src/flag_gems/utils/random_utils.py +++ b/src/flag_gems/utils/random_utils.py @@ -36,7 +36,7 @@ def uint_to_uniform_float(x): # https://github.com/pytorch/pytorch/blob/8a4597980c2692b73f35fb3c7145eaeaf2273e77/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp#L452 # It returns the current state of the default Philox RNG in seed and offset and # updates the next offset by adding `increment`. 
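+# (Renamed from `philox_backend_seed_offset`; the new name reflects that the call
+# advances the default generator's Philox state.) Rough usage sketch, argument
+# values illustrative:
+#     seed, offset = update_philox_state(increment)  # reserve `increment` offsets
+#     set_philox_state(seed, offset)                 # rewind to replay the same stream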
-def philox_backend_seed_offset(increment, device=None): +def update_philox_state(increment, device=None): device = device or torch_device_fn.current_device() gen = torch_device_fn.default_generators[device] state_copy = gen.get_state() @@ -54,6 +54,15 @@ def philox_backend_seed_offset(increment, device=None): return seed, offset +def set_philox_state(seed, offset, device=None): + device = device or torch_device_fn.current_device() + gen = torch_device_fn.default_generators[device] + assert offset % 4 == 0 + new_state = torch.tensor((seed, offset), dtype=torch.int64) + gen.set_state(new_state.view(torch.uint8)) + return + + def per_thread_offset(N, num_blocks, num_warps, warp_threads=32): block_threads = num_warps * warp_threads max_threads = num_blocks * block_threads diff --git a/tests/test_attention_ops.py b/tests/test_attention_ops.py index 30861be56..f24c25f3c 100644 --- a/tests/test_attention_ops.py +++ b/tests/test_attention_ops.py @@ -3,27 +3,15 @@ import torch import flag_gems +from flag_gems.runtime import torch_device_fn from .accuracy_utils import gems_assert_close, to_reference +from .conftest import TO_CPU device = flag_gems.device -@pytest.mark.skipif(flag_gems.vendor_name == "hygon", reason="RuntimeError") -@pytest.mark.skipif(flag_gems.device == "musa", reason="RuntimeError") -@pytest.mark.skipif(flag_gems.vendor_name == "kunlunxin", reason="RESULT TODOFIX") -@pytest.mark.scaled_dot_product_attention -@pytest.mark.parametrize("batch", [8, 16]) -@pytest.mark.parametrize("num_head", [1, 8]) -@pytest.mark.parametrize("q_seq_len", [17, 64, 128]) -@pytest.mark.parametrize("kv_seq_len", [7, 87, 128, 577, 2048]) -@pytest.mark.parametrize("head_size", [64, 128]) -@pytest.mark.parametrize("add_bias", [True, False]) -@pytest.mark.parametrize("is_causal", [True, False]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -def test_scaled_dot_product_attention( - batch, num_head, q_seq_len, kv_seq_len, head_size, add_bias, is_causal, dtype -): +def make_input(batch, num_head, q_seq_len, kv_seq_len, head_size, dtype): np.random.seed(0) np_query = np.random.uniform( -0.05, 0.05, (batch, num_head, q_seq_len, head_size) @@ -34,50 +22,355 @@ def test_scaled_dot_product_attention( np_value = np.random.uniform( -0.05, 0.05, (batch, num_head, kv_seq_len, head_size) ).astype(np.float32) - np_attn_bias = np.random.uniform( - -0.05, 0.05, (batch, num_head, q_seq_len, kv_seq_len) - ).astype(np.float32) - query = torch.tensor(np_query, device=device, dtype=dtype) - key = torch.tensor(np_key, device=device, dtype=dtype) - value = torch.tensor(np_value, device=device, dtype=dtype) - if add_bias: - attn_bias = torch.tensor(np_attn_bias, device=device, dtype=dtype) - else: - attn_bias = None + query = torch.tensor(np_query, device="cuda", dtype=dtype) + key = torch.tensor(np_key, device="cuda", dtype=dtype) + value = torch.tensor(np_value, device="cuda", dtype=dtype) + + return query, key, value + + +@pytest.mark.skipif(flag_gems.vendor_name == "hygon", reason="RuntimeError") +@pytest.mark.skipif(flag_gems.device == "musa", reason="RuntimeError") +@pytest.mark.skipif(flag_gems.vendor_name == "kunlunxin", reason="RESULT TODOFIX") +@pytest.mark.scaled_dot_product_attention +@pytest.mark.parametrize( + ["batch", "num_head", "q_seq_len", "kv_seq_len"], + [(4, 8, 1024, 1024), (4, 8, 2048, 256), (4, 8, 17, 1030)], +) +@pytest.mark.parametrize("head_size", [64, 128]) +@pytest.mark.parametrize("is_causal", [False, True]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def 
test_sdpa_legacy( + batch, num_head, q_seq_len, kv_seq_len, head_size, is_causal, dtype +): + query, key, value = make_input( + batch, num_head, q_seq_len, kv_seq_len, head_size, dtype + ) + ref_query = to_reference(query, False) + ref_key = to_reference(key, False) + ref_value = to_reference(value, False) + + scale = float(1.0 / np.sqrt(head_size)) + torch_result = torch.nn.functional.scaled_dot_product_attention( + ref_query, + ref_key, + ref_value, + attn_mask=None, + scale=scale, + is_causal=is_causal, + ) + + flaggem_result = flag_gems.ops.scaled_dot_product_attention( + query, key, value, attn_mask=None, scale=scale, is_causal=is_causal + ) + + gems_assert_close(flaggem_result, torch_result, dtype) + +@pytest.mark.skipif(flag_gems.vendor_name == "hygon", reason="RuntimeError") +@pytest.mark.skipif(flag_gems.device == "musa", reason="RuntimeError") +@pytest.mark.skipif(flag_gems.vendor_name == "kunlunxin", reason="RESULT TODOFIX") +@pytest.mark.scaled_dot_product_attention +@pytest.mark.parametrize( + ["batch", "num_head", "q_seq_len", "kv_seq_len"], + [ + (4, 8, 1024, 1024), + ], +) +@pytest.mark.parametrize("head_size", [64, 128, 256]) +@pytest.mark.parametrize("is_causal", [False, True]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_sdpa_square_qk_even_mn( + batch, num_head, q_seq_len, kv_seq_len, head_size, is_causal, dtype +): + query, key, value = make_input( + batch, num_head, q_seq_len, kv_seq_len, head_size, dtype + ) + ref_query = to_reference(query, False) + ref_key = to_reference(key, False) + ref_value = to_reference(value, False) + + scale = float(1.0 / np.sqrt(head_size)) + torch_result = torch.nn.functional.scaled_dot_product_attention( + ref_query, + ref_key, + ref_value, + attn_mask=None, + scale=scale, + is_causal=is_causal, + ) + + with flag_gems.use_gems(): + flaggem_result = torch.nn.functional.scaled_dot_product_attention( + query, key, value, attn_mask=None, scale=scale, is_causal=is_causal + ) + + gems_assert_close(flaggem_result, torch_result, dtype) + + +@pytest.mark.skipif(flag_gems.vendor_name == "hygon", reason="RuntimeError") +@pytest.mark.skipif(flag_gems.device == "musa", reason="RuntimeError") +@pytest.mark.skipif(flag_gems.vendor_name == "kunlunxin", reason="RESULT TODOFIX") +@pytest.mark.scaled_dot_product_attention +@pytest.mark.parametrize( + ["batch", "num_head", "q_seq_len", "kv_seq_len"], + [(1, 1, 128, 2048), (4, 8, 1024, 128), (4, 8, 17, 1030)], +) +@pytest.mark.parametrize("head_size", [64, 128, 256]) +@pytest.mark.parametrize("is_causal", [False]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_sdpa_nonsquare_qk( + batch, num_head, q_seq_len, kv_seq_len, head_size, is_causal, dtype +): + query, key, value = make_input( + batch, num_head, q_seq_len, kv_seq_len, head_size, dtype + ) ref_query = to_reference(query, False) ref_key = to_reference(key, False) ref_value = to_reference(value, False) - ref_attn_bias = to_reference(attn_bias, False) if add_bias else None scale = float(1.0 / np.sqrt(head_size)) + torch_result = torch.nn.functional.scaled_dot_product_attention( + ref_query, + ref_key, + ref_value, + attn_mask=None, + scale=scale, + is_causal=is_causal, + ) - if is_causal: - torch_result = torch.nn.functional.scaled_dot_product_attention( - ref_query, - ref_key, - ref_value, + with flag_gems.use_gems(): + flaggem_result = torch.nn.functional.scaled_dot_product_attention( + query, key, value, attn_mask=None, scale=scale, is_causal=is_causal + ) + + 
gems_assert_close(flaggem_result, torch_result, dtype) + + +@pytest.mark.skipif(TO_CPU, reason="Unsupported in CPU mode") +@pytest.mark.skipif(flag_gems.vendor_name == "hygon", reason="RuntimeError") +@pytest.mark.skipif(flag_gems.device == "musa", reason="RuntimeError") +@pytest.mark.skipif(flag_gems.vendor_name == "kunlunxin", reason="RESULT TODOFIX") +@pytest.mark.flash_attention_forward +@pytest.mark.parametrize( + ["batch", "num_head", "q_seq_len", "kv_seq_len"], + [(1, 1, 128, 2048), (4, 8, 1024, 128), (4, 8, 17, 1030)], +) +@pytest.mark.parametrize("head_size", [64, 128, 256]) +@pytest.mark.parametrize("is_causal", [True]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_flash_fwd_nonsquare_qk_causal( + batch, num_head, q_seq_len, kv_seq_len, head_size, is_causal, dtype +): + query, key, value = make_input( + batch, num_head, q_seq_len, kv_seq_len, head_size, dtype + ) + ref_query = to_reference(query, False) + ref_key = to_reference(key, False) + ref_value = to_reference(value, False) + + scale = float(1.0 / np.sqrt(head_size)) + + q = ref_query.transpose(1, 2) + k = ref_key.transpose(1, 2) + v = ref_value.transpose(1, 2) + out, *_ = torch.ops.aten._flash_attention_forward( + q, + k, + v, + None, + None, + q.shape[-3], + k.shape[-3], + 0, + is_causal, + False, + scale=scale, + ) + torch_result = out.transpose(1, 2) + + with flag_gems.use_gems(): + q = query.transpose(1, 2) + k = key.transpose(1, 2) + v = value.transpose(1, 2) + out, *_ = torch.ops.aten._flash_attention_forward( + q, + k, + v, + None, + None, + q.shape[-3], + k.shape[-3], + 0, + is_causal, + False, + scale=scale, + ) + flaggem_result = out.transpose(1, 2) + + gems_assert_close(flaggem_result, torch_result, dtype) + + +@pytest.mark.skipif(TO_CPU, reason="Unsupported in CPU mode") +@pytest.mark.skipif(flag_gems.vendor_name == "hygon", reason="RuntimeError") +@pytest.mark.skipif(flag_gems.device == "musa", reason="RuntimeError") +@pytest.mark.skipif(flag_gems.vendor_name == "kunlunxin", reason="RESULT TODOFIX") +@pytest.mark.flash_attention_forward +@pytest.mark.parametrize( + ["batch", "num_head", "q_seq_len", "kv_seq_len"], + [ + (1, 1, 1024, 1024), + ], +) +@pytest.mark.parametrize("head_size", [128]) +@pytest.mark.parametrize("is_causal", [False, True]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_flash_fwd_dropout( + batch, num_head, q_seq_len, kv_seq_len, head_size, is_causal, dtype +): + query, key, value = make_input( + batch, num_head, q_seq_len, kv_seq_len, head_size, dtype + ) + scale = float(1.0 / np.sqrt(head_size)) + + with flag_gems.use_gems(): + q = query.transpose(1, 2) + k = key.transpose(1, 2) + v = value.transpose(1, 2) + return_debug_mask = True + dropout_p = 0.2 + device = torch_device_fn.current_device() + gen = torch_device_fn.default_generators[device] + old_state = gen.get_state() + new_state = torch.tensor((12345, 0), dtype=torch.int64) + gen.set_state(new_state.view(torch.uint8)) + ( + out, + lse, + seed_0, + offset_0, + debug_softmax_0, + ) = torch.ops.aten._flash_attention_forward( + q, + k, + v, + None, + None, + q.shape[-3], + k.shape[-3], + dropout_p, + is_causal, + return_debug_mask, scale=scale, - is_causal=is_causal, ) - else: - torch_result = torch.nn.functional.scaled_dot_product_attention( - ref_query, - ref_key, - ref_value, - attn_mask=ref_attn_bias, + gen.set_state(new_state.view(torch.uint8)) + ( + out, + lse, + seed_1, + offset_1, + debug_softmax_1, + ) = torch.ops.aten._flash_attention_forward( + q, + k, + v, + 
None, + None, + q.shape[-3], + k.shape[-3], + dropout_p, + is_causal, + return_debug_mask, scale=scale, - is_causal=is_causal, ) + gen.set_state(old_state) + gems_assert_close(debug_softmax_0, debug_softmax_1, dtype) + dropout_ratio = torch.sum(debug_softmax_0 < 0) / torch.sum(debug_softmax_0 != 0) + np.testing.assert_allclose(dropout_ratio.to("cpu"), dropout_p, rtol=5e-2) + + +@pytest.mark.skipif(TO_CPU, reason="Unsupported in CPU mode") +@pytest.mark.skipif(torch.__version__ < "2.4", reason="Low Pytorch Version") +@pytest.mark.skipif(flag_gems.vendor_name == "hygon", reason="RuntimeError") +@pytest.mark.skipif(flag_gems.device == "musa", reason="RuntimeError") +@pytest.mark.skipif(flag_gems.vendor_name == "kunlunxin", reason="RESULT TODOFIX") +@pytest.mark.flash_attention_forward +@pytest.mark.parametrize( + ["batch", "num_head", "q_seq_len", "kv_seq_len"], + [(1, 1, 128, 2048), (8, 32, 1024, 1024), (8, 32, 1024, 128), (8, 32, 17, 1030)], +) +@pytest.mark.parametrize("head_size", [128]) +@pytest.mark.parametrize( + ["window_size_left", "window_size_right"], [(256, 0), (128, 128)] +) +@pytest.mark.parametrize("is_causal", [False]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_flash_fwd_swa( + batch, + num_head, + q_seq_len, + kv_seq_len, + head_size, + is_causal, + window_size_left, + window_size_right, + dtype, +): + query, key, value = make_input( + batch, num_head, q_seq_len, kv_seq_len, head_size, dtype + ) + + ref_query = to_reference(query, False) + ref_key = to_reference(key, False) + ref_value = to_reference(value, False) + + scale = float(1.0 / np.sqrt(head_size)) + dropout_p = 0 + return_debug_mask = False + + q = ref_query.transpose(1, 2) + k = ref_key.transpose(1, 2) + v = ref_value.transpose(1, 2) + out, lse, _, _, _ = torch.ops.aten._flash_attention_forward( + q, + k, + v, + None, + None, + q.shape[-3], + k.shape[-3], + dropout_p, + is_causal, + return_debug_mask, + scale=scale, + window_size_left=window_size_left, + window_size_right=window_size_right, + ) + torch_result = out.transpose(1, 2) + torch_lse = lse.transpose(1, 2) with flag_gems.use_gems(): - if is_causal: - flaggem_result = torch.nn.functional.scaled_dot_product_attention( - query, key, value, scale=scale, is_causal=is_causal - ) - else: - flaggem_result = torch.nn.functional.scaled_dot_product_attention( - query, key, value, attn_mask=attn_bias, scale=scale, is_causal=is_causal - ) - gems_assert_close(flaggem_result, torch_result, dtype, reduce_dim=head_size) + q = query.transpose(1, 2) + k = key.transpose(1, 2) + v = value.transpose(1, 2) + out, lse, _, _, _ = torch.ops.aten._flash_attention_forward( + q, + k, + v, + None, + None, + q.shape[-3], + k.shape[-3], + dropout_p, + is_causal, + return_debug_mask, + scale=scale, + window_size_left=window_size_left, + window_size_right=window_size_right, + ) + flaggem_result = out.transpose(1, 2) + flaggem_lse = lse.transpose(1, 2) + + gems_assert_close(flaggem_result, torch_result, dtype) + gems_assert_close(flaggem_lse, torch_lse, torch.float)