From 4a63a7e7498b73bd557f4bfd4da9934c40606aaa Mon Sep 17 00:00:00 2001 From: Tongxin Bai <waffle.bai@gmail.com> Date: Sat, 1 Mar 2025 09:55:37 +0000 Subject: [PATCH 01/25] Make flash attn compatible with flash_attn v2 api. WIP. --- OperatorList.md | 2 +- src/flag_gems/__init__.py | 394 ++++++++++++++-------------- src/flag_gems/ops/__init__.py | 4 +- src/flag_gems/ops/attention.py | 280 ++++++++++++++++++++ src/flag_gems/ops/dropout.py | 7 +- src/flag_gems/utils/random_utils.py | 2 +- 6 files changed, 483 insertions(+), 206 deletions(-) diff --git a/OperatorList.md b/OperatorList.md index 1f6702a5c..5471e7aaf 100644 --- a/OperatorList.md +++ b/OperatorList.md @@ -86,7 +86,7 @@ - nll_loss - nll_loss_forward - nll_loss_nd -- scaled_dot_product_attention +- _flash_attention_forward - upsample_nearest2d - _fft_c2r - _fft_r2c diff --git a/src/flag_gems/__init__.py b/src/flag_gems/__init__.py index c50d66bea..8cf202e3e 100644 --- a/src/flag_gems/__init__.py +++ b/src/flag_gems/__init__.py @@ -19,206 +19,206 @@ def enable(lib=aten_lib, unused=None, registrar=registrar): global current_work_registrar current_work_registrar = registrar( ( - ("abs", abs, Autograd.disable), - ("add.Tensor", add, Autograd.disable), - ("addmm", addmm, Autograd.disable), - ("arange.start_step", arange_start, Autograd.disable), - ("arange.start", arange_start, Autograd.disable), - ("arange", arange, Autograd.disable), - ("batch_norm", batch_norm, Autograd.enable), - ("bitwise_and.Tensor", bitwise_and_tensor, Autograd.disable), - ("bitwise_and.Scalar", bitwise_and_scalar, Autograd.disable), - ("bitwise_and.Scalar_Tensor", bitwise_and_scalar_tensor, Autograd.disable), - ("bitwise_not", bitwise_not, Autograd.disable), - ("bitwise_or.Tensor", bitwise_or_tensor, Autograd.disable), - ("bitwise_or.Scalar", bitwise_or_scalar, Autograd.disable), - ("bitwise_or.Scalar_Tensor", bitwise_or_scalar_tensor, Autograd.disable), - ("bmm", bmm, Autograd.disable), - ("clamp", clamp, Autograd.disable), - ("clamp.Tensor", clamp_tensor, Autograd.disable), - ("cos", cos, Autograd.disable), - ("pad", pad, Autograd.disable), - ("constant_pad_nd", constant_pad_nd, Autograd.disable), - ("cumsum", cumsum, Autograd.disable), - ("cummin", cummin, Autograd.disable), - ("div.Tensor", true_divide, Autograd.disable), - ("div.Scalar", true_divide, Autograd.disable), - ("div.Tensor_mode", div_mode, Autograd.disable), - ("div.Scalar_mode", div_mode, Autograd.disable), + # ("abs", abs, Autograd.disable), + # ("add.Tensor", add, Autograd.disable), + # ("addmm", addmm, Autograd.disable), + # ("arange.start_step", arange_start, Autograd.disable), + # ("arange.start", arange_start, Autograd.disable), + # ("arange", arange, Autograd.disable), + # ("batch_norm", batch_norm, Autograd.enable), + # ("bitwise_and.Tensor", bitwise_and_tensor, Autograd.disable), + # ("bitwise_and.Scalar", bitwise_and_scalar, Autograd.disable), + # ("bitwise_and.Scalar_Tensor", bitwise_and_scalar_tensor, Autograd.disable), + # ("bitwise_not", bitwise_not, Autograd.disable), + # ("bitwise_or.Tensor", bitwise_or_tensor, Autograd.disable), + # ("bitwise_or.Scalar", bitwise_or_scalar, Autograd.disable), + # ("bitwise_or.Scalar_Tensor", bitwise_or_scalar_tensor, Autograd.disable), + # # ("bmm", bmm, Autograd.disable), + # ("clamp", clamp, Autograd.disable), + # ("clamp.Tensor", clamp_tensor, Autograd.disable), + # ("cos", cos, Autograd.disable), + # ("pad", pad, Autograd.disable), + # ("constant_pad_nd", constant_pad_nd, Autograd.disable), + # ("cumsum", cumsum, Autograd.disable), + # 
("cummin", cummin, Autograd.disable), + # ("div.Tensor", true_divide, Autograd.disable), + # ("div.Scalar", true_divide, Autograd.disable), + # ("div.Tensor_mode", div_mode, Autograd.disable), + # ("div.Scalar_mode", div_mode, Autograd.disable), + # ( + # "divide.Tensor", + # true_divide, + # Autograd.disable, + # ), # divide, an alias for div + # ("divide.Scalar", true_divide, Autograd.disable), + # ("divide.Tensor_mode", div_mode, Autograd.disable), + # ("divide.Scalar_mode", div_mode, Autograd.disable), + # ( + # "true_divide.Tensor", + # true_divide, + # Autograd.disable, + # ), # true_divide, an alias for div + # ("true_divide.Scalar", true_divide, Autograd.disable), + # ("floor_divide", floor_divide, Autograd.disable), + # ("floor_divide.Scalar", floor_divide, Autograd.disable), + # ("remainder.Tensor", remainder, Autograd.disable), + # ("native_dropout", native_dropout, Autograd.enable), + # ("erf", erf, Autograd.disable), + # ("embedding", embedding, Autograd.enable), + # ("eq.Tensor", eq, Autograd.disable), + # ("eq.Scalar", eq_scalar, Autograd.disable), + # ("exp", exp, Autograd.disable), + # ("exponential_", exponential_, Autograd.disable), + # ("ge.Tensor", ge, Autograd.disable), + # ("ge.Scalar", ge_scalar, Autograd.disable), + # ("gelu", gelu, Autograd.enable), + # ("native_group_norm", group_norm, Autograd.enable), + # ("_weight_norm_interface", weight_norm_interface, Autograd.enable), + # ("_weight_norm", weight_norm, Autograd.enable), + # ("gt.Tensor", gt, Autograd.disable), + # ("gt.Scalar", gt_scalar, Autograd.disable), + # ("instance_norm", instance_norm, Autograd.enable), + # ("isfinite", isfinite, Autograd.disable), + # ("isin.Tensor_Tensor", isin, Autograd.disable), + # ("isin.Scalar_Tensor", isin, Autograd.disable), + # ("isin.Tensor_Scalar", isin, Autograd.disable), + # ("isinf", isinf, Autograd.disable), + # ("isnan", isnan, Autograd.disable), + # ("minimum", minimum, Autograd.disable), + # ("maximum", maximum, Autograd.disable), + # ("native_layer_norm", layer_norm, Autograd.enable), + # ("le.Tensor", le, Autograd.disable), + # ("le.Scalar", le_scalar, Autograd.disable), + # ("lt.Tensor", lt, Autograd.disable), + # ("lt.Scalar", lt_scalar, Autograd.disable), + # ("rms_norm", rms_norm, Autograd.disable), + # ("rand", rand, Autograd.disable), + # ("randn", randn, Autograd.disable), + # ("rand_like", rand_like, Autograd.disable), + # ("randn_like", randn_like, Autograd.disable), + # ("zeros", zeros, Autograd.disable), + # ("ones", ones, Autograd.disable), + # ("full", full, Autograd.disable), + # ("zeros_like", zeros_like, Autograd.disable), + # ("ones_like", ones_like, Autograd.disable), + # ("full_like", full_like, Autograd.disable), + # ("resolve_neg", resolve_neg, Autograd.disable), + # ("resolve_conj", resolve_conj, Autograd.disable), + # ("normal.Tensor_float", normal_tensor_float, Autograd.disable), + # ("normal.float_Tensor", normal_float_tensor, Autograd.disable), + # ("normal.Tensor_Tensor", normal_tensor_tensor, Autograd.disable), + # ("uniform_", uniform_, Autograd.disable), + # ("mean", mean, Autograd.disable), + # ("mean.dim", mean_dim, Autograd.disable), + # ("mm", mm, Autograd.disable), + # ("mul.Tensor", mul, Autograd.disable), + # ("multinomial", multinomial, Autograd.disable), + # ("mv", mv, Autograd.disable), + # ("ne.Tensor", ne, Autograd.disable), + # ("ne.Scalar", ne_scalar, Autograd.disable), + # ("neg", neg, Autograd.disable), + # ("pow.Scalar", pow_scalar, Autograd.disable), + # ("pow.Tensor_Scalar", pow_tensor_scalar, Autograd.disable), + 
# ("pow.Tensor_Tensor", pow_tensor_tensor, Autograd.disable), + # ("reciprocal", reciprocal, Autograd.disable), + # ("relu", relu, Autograd.enable), + # ("rsqrt", rsqrt, Autograd.disable), + # ("sigmoid", sigmoid, Autograd.enable), + # ("silu", silu, Autograd.enable), + # ("sin", sin, Autograd.disable), + # ("softmax.int", softmax, Autograd.enable), + # ("sort", sort, Autograd.disable), + # ("sub.Tensor", sub, Autograd.disable), + # ("tanh", tanh, Autograd.enable), + # ("triu", triu, Autograd.disable), + # ("topk", topk, Autograd.disable), + # ("var_mean.correction", var_mean, Autograd.disable), + # ("linalg_vector_norm", vector_norm, Autograd.disable), + # ("where.self_out", where_self_out, Autograd.disable), + # ("where.self", where_self, Autograd.disable), + # ("where.ScalarSelf", where_scalar_self, Autograd.disable), + # ("where.ScalarOther", where_scalar_other, Autograd.disable), + # ("max", max, Autograd.disable), + # ("max.dim", max_dim, Autograd.disable), + # ("min", min, Autograd.disable), + # ("min.dim", min_dim, Autograd.disable), + # ("amax", amax, Autograd.disable), + # ("argmax", argmax, Autograd.disable), + # ("argmin", argmin, Autograd.disable), + # ("prod", prod, Autograd.disable), + # ("prod.dim_int", prod_dim, Autograd.disable), + # ("sum", sum, Autograd.disable), + # ("sum.dim_IntList", sum_dim, Autograd.disable), ( - "divide.Tensor", - true_divide, - Autograd.disable, - ), # divide, an alias for div - ("divide.Scalar", true_divide, Autograd.disable), - ("divide.Tensor_mode", div_mode, Autograd.disable), - ("divide.Scalar_mode", div_mode, Autograd.disable), - ( - "true_divide.Tensor", - true_divide, - Autograd.disable, - ), # true_divide, an alias for div - ("true_divide.Scalar", true_divide, Autograd.disable), - ("floor_divide", floor_divide, Autograd.disable), - ("floor_divide.Scalar", floor_divide, Autograd.disable), - ("remainder.Tensor", remainder, Autograd.disable), - ("native_dropout", native_dropout, Autograd.enable), - ("erf", erf, Autograd.disable), - ("embedding", embedding, Autograd.enable), - ("eq.Tensor", eq, Autograd.disable), - ("eq.Scalar", eq_scalar, Autograd.disable), - ("exp", exp, Autograd.disable), - ("exponential_", exponential_, Autograd.disable), - ("ge.Tensor", ge, Autograd.disable), - ("ge.Scalar", ge_scalar, Autograd.disable), - ("gelu", gelu, Autograd.enable), - ("native_group_norm", group_norm, Autograd.enable), - ("_weight_norm_interface", weight_norm_interface, Autograd.enable), - ("_weight_norm", weight_norm, Autograd.enable), - ("gt.Tensor", gt, Autograd.disable), - ("gt.Scalar", gt_scalar, Autograd.disable), - ("instance_norm", instance_norm, Autograd.enable), - ("isfinite", isfinite, Autograd.disable), - ("isin.Tensor_Tensor", isin, Autograd.disable), - ("isin.Scalar_Tensor", isin, Autograd.disable), - ("isin.Tensor_Scalar", isin, Autograd.disable), - ("isinf", isinf, Autograd.disable), - ("isnan", isnan, Autograd.disable), - ("minimum", minimum, Autograd.disable), - ("maximum", maximum, Autograd.disable), - ("native_layer_norm", layer_norm, Autograd.enable), - ("le.Tensor", le, Autograd.disable), - ("le.Scalar", le_scalar, Autograd.disable), - ("lt.Tensor", lt, Autograd.disable), - ("lt.Scalar", lt_scalar, Autograd.disable), - ("rms_norm", rms_norm, Autograd.disable), - ("rand", rand, Autograd.disable), - ("randn", randn, Autograd.disable), - ("rand_like", rand_like, Autograd.disable), - ("randn_like", randn_like, Autograd.disable), - ("zeros", zeros, Autograd.disable), - ("ones", ones, Autograd.disable), - ("full", full, 
Autograd.disable), - ("zeros_like", zeros_like, Autograd.disable), - ("ones_like", ones_like, Autograd.disable), - ("full_like", full_like, Autograd.disable), - ("resolve_neg", resolve_neg, Autograd.disable), - ("resolve_conj", resolve_conj, Autograd.disable), - ("normal.Tensor_float", normal_tensor_float, Autograd.disable), - ("normal.float_Tensor", normal_float_tensor, Autograd.disable), - ("normal.Tensor_Tensor", normal_tensor_tensor, Autograd.disable), - ("uniform_", uniform_, Autograd.disable), - ("mean", mean, Autograd.disable), - ("mean.dim", mean_dim, Autograd.disable), - ("mm", mm, Autograd.disable), - ("mul.Tensor", mul, Autograd.disable), - ("multinomial", multinomial, Autograd.disable), - ("mv", mv, Autograd.disable), - ("ne.Tensor", ne, Autograd.disable), - ("ne.Scalar", ne_scalar, Autograd.disable), - ("neg", neg, Autograd.disable), - ("pow.Scalar", pow_scalar, Autograd.disable), - ("pow.Tensor_Scalar", pow_tensor_scalar, Autograd.disable), - ("pow.Tensor_Tensor", pow_tensor_tensor, Autograd.disable), - ("reciprocal", reciprocal, Autograd.disable), - ("relu", relu, Autograd.enable), - ("rsqrt", rsqrt, Autograd.disable), - ("sigmoid", sigmoid, Autograd.enable), - ("silu", silu, Autograd.enable), - ("sin", sin, Autograd.disable), - ("softmax.int", softmax, Autograd.enable), - ("sort", sort, Autograd.disable), - ("sub.Tensor", sub, Autograd.disable), - ("tanh", tanh, Autograd.enable), - ("triu", triu, Autograd.disable), - ("topk", topk, Autograd.disable), - ("var_mean.correction", var_mean, Autograd.disable), - ("linalg_vector_norm", vector_norm, Autograd.disable), - ("where.self_out", where_self_out, Autograd.disable), - ("where.self", where_self, Autograd.disable), - ("where.ScalarSelf", where_scalar_self, Autograd.disable), - ("where.ScalarOther", where_scalar_other, Autograd.disable), - ("max", max, Autograd.disable), - ("max.dim", max_dim, Autograd.disable), - ("min", min, Autograd.disable), - ("min.dim", min_dim, Autograd.disable), - ("amax", amax, Autograd.disable), - ("argmax", argmax, Autograd.disable), - ("argmin", argmin, Autograd.disable), - ("prod", prod, Autograd.disable), - ("prod.dim_int", prod_dim, Autograd.disable), - ("sum", sum, Autograd.disable), - ("sum.dim_IntList", sum_dim, Autograd.disable), - ( - "scaled_dot_product_attention", - scaled_dot_product_attention, - Autograd.disable, - ), - ("all", all, Autograd.disable), - ("all.dim", all_dim, Autograd.disable), - ("all.dims", all_dims, Autograd.disable), - ("any", any, Autograd.disable), - ("any.dim", any_dim, Autograd.disable), - ("any.dims", any_dims, Autograd.disable), - ("quantile", quantile, Autograd.disable), - ("log_softmax.int", log_softmax, Autograd.enable), - ("outer", outer, Autograd.enable), - ("cross_entropy_loss", cross_entropy_loss, Autograd.enable), - ("nll_loss_forward", nll_loss_forward, Autograd.disable), - ("nll_loss_backward", nll_loss_backward, Autograd.disable), - ("nll_loss2d_forward", nll_loss2d_forward, Autograd.disable), - ("nll_loss2d_backward", nll_loss2d_backward, Autograd.disable), - ("scatter.src", scatter, Autograd.disable), - ("scatter.reduce", scatter, Autograd.disable), - ("gather", gather, Autograd.disable), - ("gather_backward", gather_backward, Autograd.disable), - ("isclose", isclose, Autograd.disable), - ("allclose", allclose, Autograd.disable), - ("fill.Scalar", fill_scalar, Autograd.disable), - ("fill.Tensor", fill_tensor, Autograd.disable), - ("flip", flip, Autograd.disable), - ("slice_scatter", slice_scatter, Autograd.disable), - ("select_scatter", 
select_scatter, Autograd.disable), - ("index_select", index_select, Autograd.disable), - ("tile", tile, Autograd.disable), - ("masked_fill.Tensor", masked_fill, Autograd.disable), - ("masked_fill.Scalar", masked_fill, Autograd.disable), - ("masked_fill_.Tensor", masked_fill_, Autograd.disable), - ("masked_fill_.Scalar", masked_fill_, Autograd.disable), - ("_unique2", _unique2, Autograd.disable), - ("_upsample_bicubic2d_aa", _upsample_bicubic2d_aa, Autograd.disable), - ("upsample_nearest2d", upsample_nearest2d, Autograd.disable), - ("nonzero", nonzero, Autograd.disable), - ("repeat", repeat, Autograd.disable), - ("masked_select", masked_select, Autograd.disable), - ("stack", stack, Autograd.disable), - ("hstack", hstack, Autograd.disable), - ("cat", cat, Autograd.disable), - ( - "repeat_interleave.self_int", - repeat_interleave_self_int, - Autograd.disable, - ), - ("vstack", vstack, Autograd.disable), - ("repeat_interleave.Tensor", repeat_interleave_tensor, Autograd.disable), - ( - "repeat_interleave.self_Tensor", - repeat_interleave_self_tensor, + "_flash_attention_forward", + flash_attention_forward, Autograd.disable, ), - ("randperm", randperm, Autograd.disable), - ("diag", diag, Autograd.disable), - ("diag_embed", diag_embed, Autograd.disable), - ("diagonal_backward", diagonal_backward, Autograd.disable), - ("index_add", index_add, Autograd.disable), - ("count_nonzero", count_nonzero, Autograd.disable), - ("logical_or", logical_or, Autograd.disable), - ("logical_and", logical_and, Autograd.disable), - ("logical_xor", logical_xor, Autograd.disable), - ("logical_not", logical_not, Autograd.disable), - ("log_sigmoid", log_sigmoid, Autograd.disable), - ("vdot", vdot, Autograd.disable), - ("mse_loss", mse_loss, Autograd.disable), + # ("all", all, Autograd.disable), + # ("all.dim", all_dim, Autograd.disable), + # ("all.dims", all_dims, Autograd.disable), + # ("any", any, Autograd.disable), + # ("any.dim", any_dim, Autograd.disable), + # ("any.dims", any_dims, Autograd.disable), + # ("quantile", quantile, Autograd.disable), + # ("log_softmax.int", log_softmax, Autograd.enable), + # ("outer", outer, Autograd.enable), + # ("cross_entropy_loss", cross_entropy_loss, Autograd.enable), + # ("nll_loss_forward", nll_loss_forward, Autograd.disable), + # ("nll_loss_backward", nll_loss_backward, Autograd.disable), + # ("nll_loss2d_forward", nll_loss2d_forward, Autograd.disable), + # ("nll_loss2d_backward", nll_loss2d_backward, Autograd.disable), + # ("scatter.src", scatter, Autograd.disable), + # ("scatter.reduce", scatter, Autograd.disable), + # ("gather", gather, Autograd.disable), + # ("gather_backward", gather_backward, Autograd.disable), + # ("isclose", isclose, Autograd.disable), + # ("allclose", allclose, Autograd.disable), + # ("fill.Scalar", fill_scalar, Autograd.disable), + # ("fill.Tensor", fill_tensor, Autograd.disable), + # ("flip", flip, Autograd.disable), + # ("slice_scatter", slice_scatter, Autograd.disable), + # ("select_scatter", select_scatter, Autograd.disable), + # ("index_select", index_select, Autograd.disable), + # ("tile", tile, Autograd.disable), + # ("masked_fill.Tensor", masked_fill, Autograd.disable), + # ("masked_fill.Scalar", masked_fill, Autograd.disable), + # ("masked_fill_.Tensor", masked_fill_, Autograd.disable), + # ("masked_fill_.Scalar", masked_fill_, Autograd.disable), + # ("_unique2", _unique2, Autograd.disable), + # ("_upsample_bicubic2d_aa", _upsample_bicubic2d_aa, Autograd.disable), + # ("upsample_nearest2d", upsample_nearest2d, Autograd.disable), + # 
("nonzero", nonzero, Autograd.disable), + # ("repeat", repeat, Autograd.disable), + # ("masked_select", masked_select, Autograd.disable), + # ("stack", stack, Autograd.disable), + # ("hstack", hstack, Autograd.disable), + # ("cat", cat, Autograd.disable), + # ( + # "repeat_interleave.self_int", + # repeat_interleave_self_int, + # Autograd.disable, + # ), + # ("vstack", vstack, Autograd.disable), + # ("repeat_interleave.Tensor", repeat_interleave_tensor, Autograd.disable), + # ( + # "repeat_interleave.self_Tensor", + # repeat_interleave_self_tensor, + # Autograd.disable, + # ), + # ("randperm", randperm, Autograd.disable), + # ("diag", diag, Autograd.disable), + # ("diag_embed", diag_embed, Autograd.disable), + # ("diagonal_backward", diagonal_backward, Autograd.disable), + # ("index_add", index_add, Autograd.disable), + # ("count_nonzero", count_nonzero, Autograd.disable), + # ("logical_or", logical_or, Autograd.disable), + # ("logical_and", logical_and, Autograd.disable), + # ("logical_xor", logical_xor, Autograd.disable), + # ("logical_not", logical_not, Autograd.disable), + # ("log_sigmoid", log_sigmoid, Autograd.disable), + # ("vdot", vdot, Autograd.disable), + # ("mse_loss", mse_loss, Autograd.disable), ), user_unused_ops_list=[] if unused is None else unused, lib=lib, diff --git a/src/flag_gems/ops/__init__.py b/src/flag_gems/ops/__init__.py index f0a1cc9fe..c4051d6d8 100755 --- a/src/flag_gems/ops/__init__.py +++ b/src/flag_gems/ops/__init__.py @@ -7,7 +7,7 @@ from .arange import arange, arange_start from .argmax import argmax from .argmin import argmin -from .attention import scaled_dot_product_attention +from .attention import flash_attention_forward from .batch_norm import batch_norm from .bitwise_and import ( bitwise_and_scalar, @@ -292,7 +292,7 @@ "repeat_interleave_self_int", "vstack", "repeat_interleave_tensor", - "scaled_dot_product_attention", + "flash_attention_forward", "conv2d", "conv1d", "_conv_depthwise2d", diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py index 7e16189e8..3c114fcf1 100644 --- a/src/flag_gems/ops/attention.py +++ b/src/flag_gems/ops/attention.py @@ -5,6 +5,7 @@ import triton.language as tl from flag_gems.runtime import torch_device_fn +from flag_gems.utils.random_utils import update_philox_state from .. 
import runtime @@ -305,6 +306,285 @@ def _attn_fwd( tl.store(O_block_ptr, acc.to(Out.type.element_ty), mask=q_load_mask[:, None]) +@triton.jit +def philox_offset_one_warp(b, h, nh: tl.constexpr): + # To align with TriDao's implementation, philox_offset linearly determined by + # a 3d dense tensor (batch_id, head_id, thread_id) with shape (batch_size, num_heads, 32) + # and stride ( num_heads * 32, 32, 1 ) + return (b * nh + h) * 32 + tl.arange(0, 32) + + +@triton.jit +def to_lohi(x): + return (x >> 32).to(tl.uint32), (x & 0xFFFFFFFF).to(tl.uint32) + + +@triton.jit +def from_lohi(lo, hi): + return hi.to(tl.uint64) << 32 + lo.to(tl.uint64) + + +@triton.jit +def philox_(seed, subsequence, offset): + kPhilox10A: tl.constexpr = 0x9E3779B9 + kPhilox10B: tl.constexpr = 0xBB67AE85 + k0, k1 = to_lohi(seed.to(tl.uint64)) + c0, c1 = to_lohi(offset.to(tl.uint64)) + c2, c3 = to_lohi(subsequence(tl.uint64)) + + # pragma unroll + kPhiloxSA: tl.constexpr = 0xD2511F53 + kPhiloxSB: tl.constexpr = 0xCD9E8D57 + for _ in range(6): + res0 = kPhiloxSA.to(tl.uint64) * c0.to(tl.uint64) + res1 = kPhiloxSB.to(tl.uint64) * c2.to(tl.uint64) + res0_x, res0_y = to_lohi(res0) + res1_x, res1_y = to_lohi(res1) + c0, c1, c2, c3 = res1_y ^ c1 ^ k0, res1_x, res0_y ^ c3 ^ k1, res0_x + k0 += kPhilox10A + k1 += kPhilox10B + + res0 = kPhiloxSA.to(tl.uint64) * c0.to(tl.uint64) + res1 = kPhiloxSB.to(tl.uint64) * c2.to(tl.uint64) + res0_x.res0_y = to_lohi(res0) + res1_x.res1_y = to_lohi(res1) + c0, c1, c2, c3 = res1_y ^ c1 ^ k0, res1_x, res0_y ^ c3 ^ k1, res0_x + + return c0, c1, c2, c3 + + +@triton.jit +def apply_mask( + P, + mask, + encode_dropout_in_sign_bit: tl.constexpr, +): + if encode_dropout_in_sign_bit: + P = tl.where(mask, -P, P) + else: + P = tl.where(mask, 0, P) + return P + + +@triton.jit +def make_4x_dropout_mask(uint32_r, uint8_p, M: tl.constexpr, N: tl.constexpr): + r = uint32_r + p = uint8_p + m0 = tl.where(r & 0xFF < p, 0, 1) + r >>= 8 + m1 = tl.where(r & 0xFF < p, 0, 1) + m0 = tl.join(m0, m1).trans(2, 0, 1).reshape(2 * M, N) + + r >>= 8 + m0 = tl.where(r & 0xFF < p, 0, 1) + r >>= 8 + m1 = tl.where(r & 0xFF < p, 0, 1) + m1 = tl.join(m0, m1).trans(2, 0, 1).reshape(2 * M, N) + + m = tl.join(m0, m1).trans(2, 0, 1).reshape(4 * M, N) + return m + + +@triton.jit( + do_not_specialize=[ + "b", + "h", + "row_start", + "col_start", + "philox_seed", + "philox_offset", + ] +) +def apply_dropout( + P, + row_start, + col_start, + philox_seed, + philox_offset, + p_dropout_uint8: tl.constexpr, + encode_dropout_in_sign_bit: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + # P is of size (BLOCK_M, BLOCK_N) and its scalar bitsize is 32 + # BLOCK_M is ensured to be a multiple of 16, BLOCK_N a multiple of 32 + M: tl.constexpr = BLOCK_M // 16 + N: tl.constexpr = BLOCK_N // 32 + row = row_start + tl.arange(0, M)[:, None] + col = col_start + tl.arange(0, BLOCK_N)[None, :] // 32 + + philox_offset = philox_offset + tl.arange(0, BLOCK_N)[None, :] % 32 + + subsequence = from_lohi(row * 32, col) + r0, r1, r2, r3 = philox_(philox_seed, subsequence, philox_offset) + + # Fully unrolled due to triton's inability to concat 2d tensor + m0 = make_4x_dropout_mask(r0, p_dropout_uint8, M, N) + m1 = make_4x_dropout_mask(r1, p_dropout_uint8, M, N) + m0 = tl.join(m0, m1).trans(2, 0, 1).reshape(8 * M, N) + + m0 = make_4x_dropout_mask(r0, p_dropout_uint8, M, N) + m1 = make_4x_dropout_mask(r1, p_dropout_uint8, M, N) + m1 = tl.join(m0, m1).trans(2, 0, 1).reshape(8 * M, N) + + m = tl.join(m0, m1).trans(2, 0, 1).reshape(16 * M, N) + P = 
apply_mask(P, m) + return P + + +@triton.jit +def flash_fwd_kernel( + pQ, + pK, + pV, + pP, + pO, + pSlopes, + philox_seed, + philox_offset, + pdrop_int8, + drop: tl.constexpr, + causal: tl.constexpr, + scale: tl.constexp, + wl: tl.constexpr, + wr: tl.constexpr, + return_P: tl.constexpr, +): + pass + + +def flash_attention_forward( + query, + key, + value, + cum_seq_q, + cum_seq_k, + max_q, + max_k, + dropout_p, + is_causal, + return_debug_mask, + *, + scale=None, + window_size_left=None, + window_size_right=None, + seqused_k=None, + alibi_slopes=None +): + logging.debug("GEMS FLASH_ATTENTION") + assert cum_seq_q is None and cum_seq_k is None, "varlen is not supported yet." + + HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1] + HEAD_DIM_V = value.shape[-1] + assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V + assert HEAD_DIM_K in {16, 32, 64, 128, 256} + + softmax_scale = scale or 1.0 / (HEAD_DIM_K**0.5) + non_null_window_left = window_size_left or -1 + non_null_window_right = window_size_right or -1 + out = torch.empty_like(query, dtype=value.dtype) + + mha_out = mha_fwd( + query, + key, + value, + out, + alibi_slopes, + dropout_p, + softmax_scale, + is_causal, + non_null_window_left, + non_null_window_right, + return_debug_mask, + ) + ( + output, + q_padded, + k_padded, + v_padded, + logsumexp, + philox_seed, + philox_offset, + debug_attn_mask, + ) = mha_out + + +def mha_fwd( + q, + k, + v, + out, + alibi_slopes, + p_dropout, + softmax_scale, + is_causal, + window_size_left, + window_size_right, + return_softmax, +): + q_dtype = q.dtype + q_device = q.device + assert q_dtype in ( + torch.float16, + torch.bfloat16, + ), "FlashAttention only support fp16 and bf16 data type" + assert q_dtype == k.dtype + assert q_dtype == v.dtype + assert q.stride(-1) == 1, "Input tensor must have contiguous last dimension" + assert k.stride(-1) == 1, "Input tensor must have contiguous last dimension" + assert v.stride(-1) == 1, "Input tensor must have contiguous last dimension" + batch_size, seqlen_q, num_heads, head_size = q.size() + _, seqlen_k, num_heads_k, _ = k.size() + assert ( + head_size % 8 == 0 + ), "head_size must be a multiple of 8, this is ensured by padding!" 
+ assert ( + num_heads % num_heads_k == 0 + ), "Number of heads in key/value must divide number of heads in query" + if window_size_left >= seqlen_k: + window_size_left = -1 + if window_size_right >= seqlen_k: + window_size_right = -1 + if seqlen_q == 1 and alibi_slopes is None: + is_causal = False + if is_causal: + window_size_right = 0 + + if out: + assert out.stride(-1) == 1 + assert out.dtype == q.dtype + assert out.size() == (batch_size, seqlen_q, num_heads, head_size) + else: + out = torch.empty_like(q) + + round_multiple = lambda x, m: (x + m - 1) // m * m + head_size_rounded = round_multiple(head_size, 32) + seqlen_q_rounded = round_multiple(seqlen_q, 128) + seqlen_k_rounded = round_multiple(seqlen_k, 128) + + lse = torch.empty( + (batch_size, num_heads, seqlen_q), dtype=torch.float, device=q_device + ) + if return_softmax: + assert p_dropout > 0, "return_softmax is only supported when p_dropout > 0.0" + p = torch.empty( + (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded), + dtype=q_dtype, + device=q_device, + ) + + if p_dropout > 0: + increment = triton.cdiv(batch_size * num_heads * 32) + philox_seed, philox_offset = update_philox_state(increment) + + with torch_device_fn.device(q.device): + grid = lambda args: ( + triton.cdiv(seqlen_q, args["BLOCK_M"]), + batch_size * num_heads, + 1, + ) + + def scaled_dot_product_attention( query, key, diff --git a/src/flag_gems/ops/dropout.py b/src/flag_gems/ops/dropout.py index 2bcb31a81..7625c76ba 100644 --- a/src/flag_gems/ops/dropout.py +++ b/src/flag_gems/ops/dropout.py @@ -4,10 +4,7 @@ import triton import triton.language as tl -from flag_gems.utils.random_utils import ( - philox_backend_seed_offset, - uint_to_uniform_float, -) +from flag_gems.utils.random_utils import uint_to_uniform_float, update_philox_state from .. import runtime from ..runtime import torch_device_fn @@ -133,7 +130,7 @@ def forward(ctx, x, p, train): # hence we cannot obtain the per thread offset as in Pytorch. increment = triton.cdiv(N, UNROLL) with torch_device_fn.device(device): - philox_seed, philox_offset = philox_backend_seed_offset(increment) + philox_seed, philox_offset = update_philox_state(increment) dropout_forward_kernel[grid_fn](x, out, N, p, philox_seed, philox_offset) ctx.p = p ctx.philox_seed = philox_seed diff --git a/src/flag_gems/utils/random_utils.py b/src/flag_gems/utils/random_utils.py index 22f7155ce..5ca61a8f4 100644 --- a/src/flag_gems/utils/random_utils.py +++ b/src/flag_gems/utils/random_utils.py @@ -34,7 +34,7 @@ def uint_to_uniform_float(x): # https://github.com/pytorch/pytorch/blob/8a4597980c2692b73f35fb3c7145eaeaf2273e77/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp#L452 # It returns the current state of the default Philox RNG in seed and offset and # updates the next offset by adding `increment`. -def philox_backend_seed_offset(increment, device=None): +def update_philox_state(increment, device=None): device = device or torch_device_fn.current_device() gen = torch_device_fn.default_generators[device] state_copy = gen.get_state() From f43a2e04ea7a5ba5f3a391a8176ac7b1b7855107 Mon Sep 17 00:00:00 2001 From: Tongxin Bai <waffle.bai@gmail.com> Date: Mon, 3 Mar 2025 04:00:38 +0000 Subject: [PATCH 02/25] update kernel wrapper. 
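For reference, a plain-Python sketch of the Philox4x32 counter-based RNG that the philox_ helper mirrors (the same construction PyTorch and the flash-attn CUDA kernels use for dropout). The key-schedule and multiplier constants match kPhilox10A/B and kPhiloxSA/SB in the kernel; which 32-bit half of seed/offset/subsequence comes first, and the round count (flash-attn truncates to 6 + 1, the reference algorithm uses 10), differ between implementations, so both are parameters of this sketch rather than a statement about the kernel.

    # Reference Philox4x32 in plain Python, for sanity-checking the Triton
    # helper. Sketch only: word ordering follows the PyTorch convention and
    # the round count is a parameter, not the kernel's exact layout.
    PHILOX_W32_0 = 0x9E3779B9   # key-schedule increments (kPhilox10A/B)
    PHILOX_W32_1 = 0xBB67AE85
    PHILOX_M0 = 0xD2511F53      # round multipliers (kPhiloxSA/SB)
    PHILOX_M1 = 0xCD9E8D57
    MASK32 = 0xFFFFFFFF

    def split_u64(x):
        # low 32 bits first, high 32 bits second (PyTorch's convention)
        return x & MASK32, (x >> 32) & MASK32

    def philox4x32(seed, subsequence, offset, rounds=10):
        k0, k1 = split_u64(seed)
        c0, c1 = split_u64(offset)
        c2, c3 = split_u64(subsequence)
        for _ in range(rounds):
            hi0, lo0 = (PHILOX_M0 * c0) >> 32, (PHILOX_M0 * c0) & MASK32
            hi1, lo1 = (PHILOX_M1 * c2) >> 32, (PHILOX_M1 * c2) & MASK32
            c0, c1, c2, c3 = hi1 ^ c1 ^ k0, lo1, hi0 ^ c3 ^ k1, lo0
            k0 = (k0 + PHILOX_W32_0) & MASK32
            k1 = (k1 + PHILOX_W32_1) & MASK32
        return c0, c1, c2, c3   # four independent 32-bit words per counter

    # One warp covers 32 offsets per (batch, head), cf. philox_offset_one_warp.
    print(philox4x32(seed=0x5DEECE66D, subsequence=42, offset=7))
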
--- src/flag_gems/ops/attention.py | 395 +++++++++++++++++++-------------- 1 file changed, 224 insertions(+), 171 deletions(-) diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py index 3c114fcf1..1804ddcf2 100644 --- a/src/flag_gems/ops/attention.py +++ b/src/flag_gems/ops/attention.py @@ -306,6 +306,96 @@ def _attn_fwd( tl.store(O_block_ptr, acc.to(Out.type.element_ty), mask=q_load_mask[:, None]) +def scaled_dot_product_attention( + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + scale=None, + enable_gqa=False, +): + logging.debug("GEMS SCALED DOT PRODUCT ATTENTION") + # shape constraints + HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1] + # when v is in float8_e5m2 it is transposed. + HEAD_DIM_V = value.shape[-1] + assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V + assert HEAD_DIM_K in {16, 32, 64, 128, 256} + assert dropout_p == 0.0, "Currenty only support dropout_p=0.0" + + o = torch.empty_like(query, dtype=value.dtype) + + stage = 3 if is_causal else 1 + + if scale is None: + sm_scale = 1.0 / (HEAD_DIM_K**0.5) + else: + sm_scale = scale + + kv_head_num = key.shape[1] + + grid = lambda args: ( + triton.cdiv(query.shape[2], args["BLOCK_M"]), + query.shape[0] * query.shape[1], + 1, + ) + + if attn_mask is not None: + HAS_ATTN_MASK = True + stride_attn_mask_batch = attn_mask.stride(0) + stride_attn_mask_head = attn_mask.stride(1) + stride_attn_mask_q_seqlen = attn_mask.stride(2) + stride_attn_mask_kv_seqlen = attn_mask.stride(3) + else: + HAS_ATTN_MASK = False + stride_attn_mask_batch = 1 + stride_attn_mask_head = 1 + stride_attn_mask_q_seqlen = 1 + stride_attn_mask_kv_seqlen = 1 + + with torch_device_fn.device(query.device): + _attn_fwd[grid]( + query, + key, + value, + attn_mask, + sm_scale, + o, # + query.stride(0), + query.stride(1), + query.stride(2), + query.stride(3), # + key.stride(0), + key.stride(1), + key.stride(2), + key.stride(3), # + value.stride(0), + value.stride(1), + value.stride(2), + value.stride(3), # + stride_attn_mask_batch, + stride_attn_mask_head, + stride_attn_mask_q_seqlen, + stride_attn_mask_kv_seqlen, # + o.stride(0), + o.stride(1), + o.stride(2), + o.stride(3), # + query.shape[0], + query.shape[1], + kv_head_num, # + query.shape[2], # + key.shape[2], # + HEAD_DIM_K, # + STAGE=stage, # + HAS_ATTN_MASK=HAS_ATTN_MASK, # + ) + return o + + + @triton.jit def philox_offset_one_warp(b, h, nh: tl.constexpr): # To align with TriDao's implementation, philox_offset linearly determined by @@ -315,12 +405,12 @@ def philox_offset_one_warp(b, h, nh: tl.constexpr): @triton.jit -def to_lohi(x): +def u64_to_lohi(x): return (x >> 32).to(tl.uint32), (x & 0xFFFFFFFF).to(tl.uint32) @triton.jit -def from_lohi(lo, hi): +def u64_from_lohi(lo, hi): return hi.to(tl.uint64) << 32 + lo.to(tl.uint64) @@ -328,9 +418,9 @@ def from_lohi(lo, hi): def philox_(seed, subsequence, offset): kPhilox10A: tl.constexpr = 0x9E3779B9 kPhilox10B: tl.constexpr = 0xBB67AE85 - k0, k1 = to_lohi(seed.to(tl.uint64)) - c0, c1 = to_lohi(offset.to(tl.uint64)) - c2, c3 = to_lohi(subsequence(tl.uint64)) + k0, k1 = u64_to_lohi(seed.to(tl.uint64)) + c0, c1 = u64_to_lohi(offset.to(tl.uint64)) + c2, c3 = u64_to_lohi(subsequence(tl.uint64)) # pragma unroll kPhiloxSA: tl.constexpr = 0xD2511F53 @@ -338,16 +428,16 @@ def philox_(seed, subsequence, offset): for _ in range(6): res0 = kPhiloxSA.to(tl.uint64) * c0.to(tl.uint64) res1 = kPhiloxSB.to(tl.uint64) * c2.to(tl.uint64) - res0_x, res0_y = to_lohi(res0) - res1_x, res1_y = to_lohi(res1) + res0_x, 
res0_y = u64_to_lohi(res0) + res1_x, res1_y = u64_to_lohi(res1) c0, c1, c2, c3 = res1_y ^ c1 ^ k0, res1_x, res0_y ^ c3 ^ k1, res0_x k0 += kPhilox10A k1 += kPhilox10B res0 = kPhiloxSA.to(tl.uint64) * c0.to(tl.uint64) res1 = kPhiloxSB.to(tl.uint64) * c2.to(tl.uint64) - res0_x.res0_y = to_lohi(res0) - res1_x.res1_y = to_lohi(res1) + res0_x.res0_y = u64_to_lohi(res0) + res1_x.res1_y = u64_to_lohi(res1) c0, c1, c2, c3 = res1_y ^ c1 ^ k0, res1_x, res0_y ^ c3 ^ k1, res0_x return c0, c1, c2, c3 @@ -367,9 +457,9 @@ def apply_mask( @triton.jit -def make_4x_dropout_mask(uint32_r, uint8_p, M: tl.constexpr, N: tl.constexpr): - r = uint32_r - p = uint8_p +def make_4x_dropout_mask(r_u32, p_u8, M: tl.constexpr, N: tl.constexpr): + r = r_u32 + p = p_u8 m0 = tl.where(r & 0xFF < p, 0, 1) r >>= 8 m1 = tl.where(r & 0xFF < p, 0, 1) @@ -397,12 +487,15 @@ def make_4x_dropout_mask(uint32_r, uint8_p, M: tl.constexpr, N: tl.constexpr): ) def apply_dropout( P, - row_start, - col_start, + sor, + soc, philox_seed, philox_offset, p_dropout_uint8: tl.constexpr, encode_dropout_in_sign_bit: tl.constexpr, + bid: tl.constexpr, + hid: tl.constexpr, + nheads: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): @@ -410,12 +503,13 @@ def apply_dropout( # BLOCK_M is ensured to be a multiple of 16, BLOCK_N a multiple of 32 M: tl.constexpr = BLOCK_M // 16 N: tl.constexpr = BLOCK_N // 32 - row = row_start + tl.arange(0, M)[:, None] - col = col_start + tl.arange(0, BLOCK_N)[None, :] // 32 + row = sor + tl.arange(0, M)[:, None] + col = soc + tl.arange(0, BLOCK_N)[None, :] // 32 - philox_offset = philox_offset + tl.arange(0, BLOCK_N)[None, :] % 32 + tid = tl.arange(0, BLOCK_N)[None, :] % 32 + philox_offset += (bid * nheads + hid) * 32 + tid - subsequence = from_lohi(row * 32, col) + subsequence = u64_from_lohi(row * 32, col) r0, r1, r2, r3 = philox_(philox_seed, subsequence, philox_offset) # Fully unrolled due to triton's inability to concat 2d tensor @@ -432,7 +526,7 @@ def apply_dropout( return P -@triton.jit +@triton.jit(do_not_specialize=[]) def flash_fwd_kernel( pQ, pK, @@ -443,70 +537,17 @@ def flash_fwd_kernel( philox_seed, philox_offset, pdrop_int8, - drop: tl.constexpr, - causal: tl.constexpr, + is_dropout: tl.constexpr, + is_causal: tl.constexpr, scale: tl.constexp, - wl: tl.constexpr, - wr: tl.constexpr, + ws_left: tl.constexpr, + ws_right: tl.constexpr, return_P: tl.constexpr, + is_local: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + num_warps: tl.constexpr, ): - pass - - -def flash_attention_forward( - query, - key, - value, - cum_seq_q, - cum_seq_k, - max_q, - max_k, - dropout_p, - is_causal, - return_debug_mask, - *, - scale=None, - window_size_left=None, - window_size_right=None, - seqused_k=None, - alibi_slopes=None -): - logging.debug("GEMS FLASH_ATTENTION") - assert cum_seq_q is None and cum_seq_k is None, "varlen is not supported yet." 
- - HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1] - HEAD_DIM_V = value.shape[-1] - assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V - assert HEAD_DIM_K in {16, 32, 64, 128, 256} - - softmax_scale = scale or 1.0 / (HEAD_DIM_K**0.5) - non_null_window_left = window_size_left or -1 - non_null_window_right = window_size_right or -1 - out = torch.empty_like(query, dtype=value.dtype) - - mha_out = mha_fwd( - query, - key, - value, - out, - alibi_slopes, - dropout_p, - softmax_scale, - is_causal, - non_null_window_left, - non_null_window_right, - return_debug_mask, - ) - ( - output, - q_padded, - k_padded, - v_padded, - logsumexp, - philox_seed, - philox_offset, - debug_attn_mask, - ) = mha_out def mha_fwd( @@ -562,113 +603,125 @@ def mha_fwd( seqlen_q_rounded = round_multiple(seqlen_q, 128) seqlen_k_rounded = round_multiple(seqlen_k, 128) - lse = torch.empty( - (batch_size, num_heads, seqlen_q), dtype=torch.float, device=q_device - ) - if return_softmax: - assert p_dropout > 0, "return_softmax is only supported when p_dropout > 0.0" - p = torch.empty( - (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded), - dtype=q_dtype, - device=q_device, - ) - - if p_dropout > 0: - increment = triton.cdiv(batch_size * num_heads * 32) - philox_seed, philox_offset = update_philox_state(increment) + with torch_device_fn.device(q_device): + # Set softmax params + lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float) + if return_softmax: + assert p_dropout > 0, "return_softmax is only supported when p_dropout > 0.0" + p = torch.empty( + (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded), + dtype=q_dtype, + ) - with torch_device_fn.device(q.device): + # Set dropout params + if p_dropout > 0: + increment = triton.cdiv(batch_size * num_heads * 32) + philox_seed, philox_offset = update_philox_state(increment) + is_dropout = True + else: + is_dropout = False + + p_dropout = 1 - p_dropout + pdrop_u8 = math.floor(p_dropout * 255.0) + + # Set alibi params + if alibi_slopes is not None: + assert alibi_slopes.device == q_device + assert alibi_slopes.dtype in (torch.float, ) + assert alibi_slopes.stride(-1) == 1 + assert alibi_slopes.shape == (num_heads,) or alibi_slopes.shape == (batch_size, num_heads) + alibi_slopes_batch_stride = alibi_slopes.stride(0) if alibi_slopes.ndim == 2 else 0 + + # Set SWA params + is_local = (window_size_left >= 0 or window_size_right >= 0) and not is_causal + + # ONLY EVEN_K IS SUPPORTED + assert head_size == head_size_rounded + grid = lambda args: ( - triton.cdiv(seqlen_q, args["BLOCK_M"]), - batch_size * num_heads, - 1, + triton.cdiv(seqlen_q, args["BLOCK_M"]), # num_m_blocks + batch_size, + num_heads, ) + flash_fwd_kernel[grid]( + q, + k, + v, + p, + out, + alibi_slopes, + philox_seed, + philox_offset, + pdrop_u8, + is_dropout, + is_causal, + softmax_scale, + window_size_left, + window_size_right, + return_softmax, + is_local, + BLOCK_M, + BLOCK_N, + num_warps + ) + -def scaled_dot_product_attention( + +def flash_attention_forward( query, key, value, - attn_mask=None, - dropout_p=0.0, - is_causal=False, + cum_seq_q, + cum_seq_k, + max_q, + max_k, + dropout_p, + is_causal, + return_debug_mask, + *, scale=None, - enable_gqa=False, + window_size_left=None, + window_size_right=None, + seqused_k=None, + alibi_slopes=None ): - logging.debug("GEMS SCALED DOT PRODUCT ATTENTION") - # shape constraints + logging.debug("GEMS FLASH_ATTENTION") + assert cum_seq_q is None and cum_seq_k is None, "varlen is not supported yet." 
+ HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1] - # when v is in float8_e5m2 it is transposed. HEAD_DIM_V = value.shape[-1] assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V assert HEAD_DIM_K in {16, 32, 64, 128, 256} - assert dropout_p == 0.0, "Currenty only support dropout_p=0.0" - - o = torch.empty_like(query, dtype=value.dtype) - - stage = 3 if is_causal else 1 - - if scale is None: - sm_scale = 1.0 / (HEAD_DIM_K**0.5) - else: - sm_scale = scale - kv_head_num = key.shape[1] + softmax_scale = scale or 1.0 / (HEAD_DIM_K**0.5) + non_null_window_left = window_size_left or -1 + non_null_window_right = window_size_right or -1 + out = torch.empty_like(query, dtype=value.dtype) - grid = lambda args: ( - triton.cdiv(query.shape[2], args["BLOCK_M"]), - query.shape[0] * query.shape[1], - 1, + mha_out = mha_fwd( + query, + key, + value, + out, + alibi_slopes, + dropout_p, + softmax_scale, + is_causal, + non_null_window_left, + non_null_window_right, + return_debug_mask, ) + ( + output, + q_padded, + k_padded, + v_padded, + logsumexp, + philox_seed, + philox_offset, + debug_attn_mask, + ) = mha_out + - if attn_mask is not None: - HAS_ATTN_MASK = True - stride_attn_mask_batch = attn_mask.stride(0) - stride_attn_mask_head = attn_mask.stride(1) - stride_attn_mask_q_seqlen = attn_mask.stride(2) - stride_attn_mask_kv_seqlen = attn_mask.stride(3) - else: - HAS_ATTN_MASK = False - stride_attn_mask_batch = 1 - stride_attn_mask_head = 1 - stride_attn_mask_q_seqlen = 1 - stride_attn_mask_kv_seqlen = 1 - with torch_device_fn.device(query.device): - _attn_fwd[grid]( - query, - key, - value, - attn_mask, - sm_scale, - o, # - query.stride(0), - query.stride(1), - query.stride(2), - query.stride(3), # - key.stride(0), - key.stride(1), - key.stride(2), - key.stride(3), # - value.stride(0), - value.stride(1), - value.stride(2), - value.stride(3), # - stride_attn_mask_batch, - stride_attn_mask_head, - stride_attn_mask_q_seqlen, - stride_attn_mask_kv_seqlen, # - o.stride(0), - o.stride(1), - o.stride(2), - o.stride(3), # - query.shape[0], - query.shape[1], - kv_head_num, # - query.shape[2], # - key.shape[2], # - HEAD_DIM_K, # - STAGE=stage, # - HAS_ATTN_MASK=HAS_ATTN_MASK, # - ) - return o From b041f6b0edb0264f8a230abcca8ad6b880cbed6f Mon Sep 17 00:00:00 2001 From: Tongxin Bai <waffle.bai@gmail.com> Date: Mon, 3 Mar 2025 09:08:53 +0000 Subject: [PATCH 03/25] update masking. 
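The window arithmetic in apply_mask is easier to read against a dense reference. The sketch below builds the full (seqlen_q, seqlen_k) keep-mask on the host with the same col_lb/col_rb rule, with queries right-aligned against the keys when the lengths differ. Treating a negative window size as "unbounded" is an assumption of the sketch only; the kernel expresses the same thing by not taking the local branch unless both window sizes are non-negative.

    # Dense host-side reference for the causal / sliding-window rule that
    # apply_mask applies blockwise. Sketch only.
    import torch

    def window_keep_mask(seqlen_q, seqlen_k, ws_left, ws_right, causal):
        row = torch.arange(seqlen_q)[:, None]   # query index
        col = torch.arange(seqlen_k)[None, :]   # key index
        shift = seqlen_k - seqlen_q             # right-align queries with keys
        if causal:
            ws_right = 0
        lb = row + shift - (ws_left if ws_left >= 0 else seqlen_k)
        rb = row + shift + (ws_right if ws_right >= 0 else seqlen_k)
        return (col >= lb) & (col <= rb)        # True = attend, False = -inf

    # Causal attention, 4 queries against 6 keys: each query attends to keys
    # up to its right-aligned diagonal.
    print(window_keep_mask(4, 6, -1, -1, causal=True).int())
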
--- src/flag_gems/ops/attention.py | 88 +++++++++++++++++++++++++++++----- 1 file changed, 77 insertions(+), 11 deletions(-) diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py index 1804ddcf2..364caa832 100644 --- a/src/flag_gems/ops/attention.py +++ b/src/flag_gems/ops/attention.py @@ -444,7 +444,7 @@ def philox_(seed, subsequence, offset): @triton.jit -def apply_mask( +def apply_dropout_mask( P, mask, encode_dropout_in_sign_bit: tl.constexpr, @@ -522,10 +522,43 @@ def apply_dropout( m1 = tl.join(m0, m1).trans(2, 0, 1).reshape(8 * M, N) m = tl.join(m0, m1).trans(2, 0, 1).reshape(16 * M, N) - P = apply_mask(P, m) + P = apply_dropout_mask(P, m) return P +@triton.jit +def apply_mask( + P, + col_idx, + row_idx, + warp_row_stride, + max_seqlen_q, + max_seqlen_k, + ws_left, + ws_right, + alibi_slope, + is_causal: tl.constexpr, + is_local: tl.constexpr, + has_alibi: tl.constexpr, +): + if has_alibi or is_causal or is_local: + col_lb = max(0, row_idx + max_seqlen_k - max_seqlen_q - ws_left) + col_rb = min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + ws_right) + + if not has_alibi: + alibi_slope = .0 + + P -= alibi_slope * tl.abs(col_idx - row_idx) + + if is_causal: + P = tl.where(col_idx >= col_rb, float('-inf'), P) + + if is_local: + P = tl.where(col_idx >= col_rb | col_idx < col_lb, float('-inf'), P) + + return P + + @triton.jit(do_not_specialize=[]) def flash_fwd_kernel( pQ, @@ -533,21 +566,46 @@ def flash_fwd_kernel( pV, pP, pO, + seqlen_q, + seqlen_k, pSlopes, philox_seed, philox_offset, pdrop_int8, + slopes_batch_stride, is_dropout: tl.constexpr, is_causal: tl.constexpr, + is_local: tl.constexpr, + has_alibi: tl.constexpr, scale: tl.constexp, ws_left: tl.constexpr, ws_right: tl.constexpr, return_P: tl.constexpr, - is_local: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, num_warps: tl.constexpr, ): + m_block = tl.program_id(0) + bid = tl.program_id(1) + hid = tl.program_id(2) + + if is_local: + n_block_min: tl.constexpr = max(0, (m_block * BLOCK_M + seqlen_k - seqlen_q - ws_left) / BLOCK_N) + else: + n_block_min: tl.constexpr = 0 + + n_block_max = tl.cdiv(seqlen_k, BLOCK_N) + + if is_causal or is_local: + n_block_max = min(n_block_max, + tl.cdiv((m_block + 1) * BLOCK_M + seqlen_k - seqlen_q + window_size_right, BLOCK_N)) + + if has_alibi: + alibi_offset = bid * slopes_batch_stride + hid + alibi_slope = tl.load(pSlopes + alibi_offset) + alibi_slope /= scale + + def mha_fwd( @@ -631,7 +689,11 @@ def mha_fwd( assert alibi_slopes.stride(-1) == 1 assert alibi_slopes.shape == (num_heads,) or alibi_slopes.shape == (batch_size, num_heads) alibi_slopes_batch_stride = alibi_slopes.stride(0) if alibi_slopes.ndim == 2 else 0 - + has_alibi = True + else: + alibi_slopes_batch_stride = 0 + has_alibi = False + # Set SWA params is_local = (window_size_left >= 0 or window_size_right >= 0) and not is_causal @@ -650,17 +712,21 @@ def mha_fwd( v, p, out, + seqlen_q, + seqlen_k, alibi_slopes, philox_seed, philox_offset, pdrop_u8, - is_dropout, - is_causal, - softmax_scale, - window_size_left, - window_size_right, - return_softmax, - is_local, + alibi_slopes_batch_stride, + is_dropout=is_dropout, + is_causal=is_causal, + is_local=is_local, + has_alibi=has_alibi, + scale=softmax_scale, + ws_left=window_size_left, + ws_right=window_size_right, + return_P=return_softmax, BLOCK_M, BLOCK_N, num_warps From 7ab8dab042203660393ce41ca1c542ce269bebff Mon Sep 17 00:00:00 2001 From: Tongxin Bai <waffle.bai@gmail.com> Date: Wed, 5 Mar 2025 01:31:14 +0000 Subject: [PATCH 
04/25] done all masking steps. --- src/flag_gems/ops/attention.py | 192 +++++++++++++++++++++++++++++---- 1 file changed, 173 insertions(+), 19 deletions(-) diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py index 364caa832..28b8bad26 100644 --- a/src/flag_gems/ops/attention.py +++ b/src/flag_gems/ops/attention.py @@ -395,6 +395,7 @@ def scaled_dot_product_attention( return o +# The following implementation is a fundamentally a triton rewrite of TriDao's Flash Attention in Cuda. @triton.jit def philox_offset_one_warp(b, h, nh: tl.constexpr): @@ -489,13 +490,13 @@ def apply_dropout( P, sor, soc, + bid, + hid, philox_seed, philox_offset, p_dropout_uint8: tl.constexpr, encode_dropout_in_sign_bit: tl.constexpr, - bid: tl.constexpr, - hid: tl.constexpr, - nheads: tl.constexpr, + NUM_HEADS: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): @@ -507,7 +508,7 @@ def apply_dropout( col = soc + tl.arange(0, BLOCK_N)[None, :] // 32 tid = tl.arange(0, BLOCK_N)[None, :] % 32 - philox_offset += (bid * nheads + hid) * 32 + tid + philox_offset += (bid * NUM_HEADS + hid) * 32 + tid subsequence = u64_from_lohi(row * 32, col) r0, r1, r2, r3 = philox_(philox_seed, subsequence, philox_offset) @@ -526,12 +527,11 @@ def apply_dropout( return P -@triton.jit +@triton.jit(do_not_specialize=['max_seqlen_q', 'max_seqlen_k']) def apply_mask( - P, + S, col_idx, row_idx, - warp_row_stride, max_seqlen_q, max_seqlen_k, ws_left, @@ -547,17 +547,36 @@ def apply_mask( if not has_alibi: alibi_slope = .0 - - P -= alibi_slope * tl.abs(col_idx - row_idx) - + + S -= alibi_slope * tl.abs(col_idx - row_idx) + if is_causal: - P = tl.where(col_idx >= col_rb, float('-inf'), P) - + S = tl.where(col_idx >= col_rb, float('-inf'), S) + if is_local: - P = tl.where(col_idx >= col_rb | col_idx < col_lb, float('-inf'), P) + S = tl.where(col_idx >= col_rb | col_idx < col_lb, float('-inf'), S) + + return S + + +@triton.jit +def softmax_rescale( + O_acc, + S, + row_max, + row_sum, + softmax_scale_log2: tl.constexpr +): + prev_row_max = row_max + row_max = tl.maximum(row_max, tl.max(S, 1)) + row_sum_scale = tl.math.exp2(row_max - prev_row_max) * softmax_scale_log2 + row_sum *= row_sum_scale + O_acc *= row_sum_scale[:, None] + max_scaled = tl.where(rowmax == float('-inf'), 0, rowmax * softmax_scale_log2) + P = tl.math.exp2(S * softmax_scale_log2 - max_scaled[:, None]) + row_sum = row_sum + tl.sum(exp_S, 1) + return O_acc, P, row_max, row_sum - return P - @triton.jit(do_not_specialize=[]) def flash_fwd_kernel( @@ -568,19 +587,35 @@ def flash_fwd_kernel( pO, seqlen_q, seqlen_k, + seqlen_q_rounded, + seqlen_k_rounded, + q_b_stride, + q_s_stride, + q_h_stride, + k_b_stride, + k_s_stride, + k_h_stride, + o_b_stride, + o_s_stride, + o_h_stride, pSlopes, philox_seed, philox_offset, - pdrop_int8, + pdrop_u8, slopes_batch_stride, is_dropout: tl.constexpr, is_causal: tl.constexpr, is_local: tl.constexpr, has_alibi: tl.constexpr, - scale: tl.constexp, + softmax_scale: tl.constexp, + softmax_scale_log2: tl.constexpr, ws_left: tl.constexpr, ws_right: tl.constexpr, return_P: tl.constexpr, + BATCH_SIZE: tl.constexpr, + NUM_HEADS: tl.constexpr, + NUM_HEADS_K: tl.constexpr, + HEADSIZE: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, num_warps: tl.constexpr, @@ -605,7 +640,107 @@ def flash_fwd_kernel( alibi_slope = tl.load(pSlopes + alibi_offset) alibi_slope /= scale - + if (not is_causal) and (not is_local): + n_masking_steps = 1 + elif is_causal: + n_masking_steps = tl.cdiv(BLOCK_M, BLOCK_N) + else: + # local and 
not causal, + n_masking_steps = tl.cdiv(BLOCK_M, BLOCK_N) + 1 + + + Q_ptr += bid * q_b_stride + Q_ptr += hid * q_h_stride + row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M) + Q_off = row_idx[:, None] * q_s_stride + tl.arange(0, HEADSIZE)[None, :] + qmask = row_idx < seqlen_q + Q = tl.load(pQ + Q_off, mask=qmask) + + # Start from the right most block + n_block = n_block_max - 1 + + h_hk_ratio = h // hk + K_ptr += bid * k_b_stride + K_ptr += (hid // h_hk_ratio) * k_h_stride + V_ptr += bid * k_b_stride + V_ptr += (hid // h_hk_ratio) * k_h_stride + col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N) + KV_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEADSIZE)[:, None] + + P_ptr += ((bid * NUM_HEADS + hid) * seqlen_q_rounded + m_block * BLOCK_M) * seqlen_k_rounded + P_ptr += n_block * BLOCK_N + P_offset = tl.arange(0, BLOCK_M)[:, None] * seqlen_k_rounded + tl.arange(0, BLOCK_N) + + O_ = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32) + for _ in range(n_masking_steps): + kvmask = col_idx < seqlen_k + K = tl.load(pK + KV_offset, mask=kvmask, cache_modifier="cg") + V = tl.load(pV + KV_offset, mask=kvmask, cache_modifier="cg") + S = tl.dot(Q_block, K_block) + KV_offset += BLOCK_N * k_s_stride + S = apply_mask( + S, + col_idx, + row_idx, + seqlen_q, + seqlen_k, + ws_left, + ws_right, + alibi_slope, + is_causal, + is_local, + has_alibi + ) + + O_acc, P, rowmax_, rowsum_ = softmax_rescale(O_, S, rowmax_, rowsum_, softmax_scale_log2) + P = P.to(pO.type.element_ty) + + row_start = m_block * (BLOCK_M // 16) + col_start = n_block * (BLOCK_N // 32) + if return_P: + P_drop = P + P_drop = apply_dropout( + P_drop, + row_start, + col_start, + bid, + hid, + philox_seed, + philox_offset, + pdrop_u8, + encode_dropout_in_sign_bit=True, + NUM_HEADS=NUM_HEADS, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + tl.store(P_ptr + P_offset, P_drop, mask=qmask[:, None] & kvmask[None, :]) + P_offset += BLOCK_N + + if is_dropout: + P = apply_dropout( + P, + row_start, + col_start, + bid, + hid, + philox_seed, + philox_offset, + pdrop_u8, + encode_dropout_in_sign_bit=False, + NUM_HEADS=NUM_HEADS, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + + tl.dot(P, V, acc=O_) + + if n_masking_steps > 1 and n_block <= n_block_min: + n_block -= 1 + break + + def mha_fwd( @@ -700,6 +835,9 @@ def mha_fwd( # ONLY EVEN_K IS SUPPORTED assert head_size == head_size_rounded + M_LOG2E = 1.4426950408889634074 + scale_softmax_log2 = softmax_scale * M_LOG2E + grid = lambda args: ( triton.cdiv(seqlen_q, args["BLOCK_M"]), # num_m_blocks batch_size, @@ -714,6 +852,17 @@ def mha_fwd( out, seqlen_q, seqlen_k, + seqlen_q_rounded, + seqlen_k_rounded, + q.stride(0), + q.stride(-3), + q.stride(-2), + k.stride(0), + k.stride(-3), + k.stride(-2), + out.stride(0), + out.stride(-3), + out.stride(-2), alibi_slopes, philox_seed, philox_offset, @@ -723,10 +872,15 @@ def mha_fwd( is_causal=is_causal, is_local=is_local, has_alibi=has_alibi, - scale=softmax_scale, + softmax_scale=softmax_scale, + softmax_scale_log2=scale_softmax_log2 ws_left=window_size_left, ws_right=window_size_right, return_P=return_softmax, + BATCH_SIZE=batch_size, + NUM_HEADS=num_heads, + NUM_HEADS_K=num_heads_k, + HEADSIZE=head_size_rounded, BLOCK_M, BLOCK_N, num_warps From 4909d10e058657a639d25c6c98f5c1bbb60d2a77 Mon Sep 17 00:00:00 2001 From: Tongxin Bai <waffle.bai@gmail.com> Date: Wed, 5 Mar 2025 10:12:25 +0000 Subject: [PATCH 05/25] fwd kernel almost done. 
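The per-block bookkeeping in softmax_rescale is the standard online-softmax recurrence: keep a running row max and row sum, rescale the partial output whenever the max moves, and normalize once at the end. A small host-side sketch of that recurrence follows (natural exp for clarity; the kernel works in base 2 via softmax_scale_log2 = softmax_scale * log2(e)).

    # Online-softmax reference, one key block at a time. Sketch for checking
    # the recurrence only, not the kernel's log2-domain arithmetic.
    import torch

    def online_softmax_attention(q, k, v, scale, block_n=32):
        acc = torch.zeros(q.shape[0], v.shape[1])
        row_max = torch.full((q.shape[0],), float("-inf"))
        row_sum = torch.zeros(q.shape[0])
        for start in range(0, k.shape[0], block_n):
            kb, vb = k[start:start + block_n], v[start:start + block_n]
            s = (q @ kb.T) * scale                     # scores for this block
            new_max = torch.maximum(row_max, s.max(dim=1).values)
            correction = torch.exp(row_max - new_max)  # rescale earlier blocks
            p = torch.exp(s - new_max[:, None])
            row_sum = row_sum * correction + p.sum(dim=1)
            acc = acc * correction[:, None] + p @ vb
            row_max = new_max
        lse = row_max + torch.log(row_sum)             # per-row log-sum-exp
        return acc / row_sum[:, None], lse

    q, k, v = torch.randn(8, 64), torch.randn(128, 64), torch.randn(128, 64)
    out, lse = online_softmax_attention(q, k, v, scale=64 ** -0.5)
    ref = torch.softmax((q @ k.T) * 64 ** -0.5, dim=-1) @ v
    assert torch.allclose(out, ref, atol=1e-4)
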
--- src/flag_gems/ops/attention.py | 135 ++++++++++++++++++++++++++++----- 1 file changed, 117 insertions(+), 18 deletions(-) diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py index 28b8bad26..780ff12de 100644 --- a/src/flag_gems/ops/attention.py +++ b/src/flag_gems/ops/attention.py @@ -537,11 +537,12 @@ def apply_mask( ws_left, ws_right, alibi_slope, + is_even_mn: tl.constexpr, is_causal: tl.constexpr, is_local: tl.constexpr, has_alibi: tl.constexpr, ): - if has_alibi or is_causal or is_local: + if has_alibi or is_causal or is_local or not is_even_mn: col_lb = max(0, row_idx + max_seqlen_k - max_seqlen_q - ws_left) col_rb = min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + ws_right) @@ -551,10 +552,13 @@ def apply_mask( S -= alibi_slope * tl.abs(col_idx - row_idx) if is_causal: - S = tl.where(col_idx >= col_rb, float('-inf'), S) + S = tl.where(col_idx[None, :] >= col_rb[None, :], float('-inf'), S) if is_local: - S = tl.where(col_idx >= col_rb | col_idx < col_lb, float('-inf'), S) + S = tl.where(col_idx[None, :] >= col_rb[None, :] | col_idx[None, :] < col_lb[None, :], float('-inf'), S) + + if (not local) and (not is_causal) and (not is_even_mn): + S = tl.where(col_idx[None, :] >= max_seqlen_k, float('-inf'), S) return S @@ -565,7 +569,8 @@ def softmax_rescale( S, row_max, row_sum, - softmax_scale_log2: tl.constexpr + softmax_scale_log2: tl.constexpr, + is_border: tl.constexpr ): prev_row_max = row_max row_max = tl.maximum(row_max, tl.max(S, 1)) @@ -578,6 +583,11 @@ def softmax_rescale( return O_acc, P, row_max, row_sum +@triton.heuristics( + values={ + 'IS_EVEN_MN': lambda args: (args["seqlen_q"] % args["BLOCK_M"] == 0) and (args["seqlen_k"] % args["BLOCK_N"] == 0) + } +) @triton.jit(do_not_specialize=[]) def flash_fwd_kernel( pQ, @@ -616,9 +626,11 @@ def flash_fwd_kernel( NUM_HEADS: tl.constexpr, NUM_HEADS_K: tl.constexpr, HEADSIZE: tl.constexpr, + IS_EVEN_MN: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, num_warps: tl.constexpr, + num_stages: tl.constexpr ): m_block = tl.program_id(0) bid = tl.program_id(1) @@ -641,8 +653,11 @@ def flash_fwd_kernel( alibi_slope /= scale if (not is_causal) and (not is_local): - n_masking_steps = 1 - elif is_causal: + if IS_EVEN_MN: + n_masking_steps = 0 + else: + n_masking_steps = 1 + elif is_causal and IS_EVEN_MN: # causal implies window_size_right is zero n_masking_steps = tl.cdiv(BLOCK_M, BLOCK_N) else: # local and not causal, @@ -665,7 +680,6 @@ def flash_fwd_kernel( V_ptr += bid * k_b_stride V_ptr += (hid // h_hk_ratio) * k_h_stride col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N) - KV_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEADSIZE)[:, None] P_ptr += ((bid * NUM_HEADS + hid) * seqlen_q_rounded + m_block * BLOCK_M) * seqlen_k_rounded P_ptr += n_block * BLOCK_N @@ -674,12 +688,17 @@ def flash_fwd_kernel( O_ = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32) + for _ in range(n_masking_steps): - kvmask = col_idx < seqlen_k - K = tl.load(pK + KV_offset, mask=kvmask, cache_modifier="cg") - V = tl.load(pV + KV_offset, mask=kvmask, cache_modifier="cg") + KV_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEADSIZE)[:, None] + if IS_EVEN_MN: + K = tl.load(pK + KV_offset, cache_modifier="cg") + V = tl.load(pV + KV_offset, cache_modifier="cg") + else: + kvmask = col_idx < seqlen_k + K = tl.load(pK + KV_offset, mask=kvmask, cache_modifier="cg") + V = tl.load(pV + KV_offset, 
mask=kvmask, cache_modifier="cg") S = tl.dot(Q_block, K_block) - KV_offset += BLOCK_N * k_s_stride S = apply_mask( S, col_idx, @@ -689,12 +708,21 @@ def flash_fwd_kernel( ws_left, ws_right, alibi_slope, - is_causal, - is_local, - has_alibi + is_even_mn=IS_EVEN_MN, + is_causal=is_causal, + is_local=is_local, + has_alibi=has_alibi ) + col_idx -= BLOCK_N - O_acc, P, rowmax_, rowsum_ = softmax_rescale(O_, S, rowmax_, rowsum_, softmax_scale_log2) + O_acc, P, rowmax_, rowsum_ = softmax_rescale( + O_, + S, + rowmax_, + rowsum_, + softmax_scale_log2=softmax_scale_log2, + is_border=(is_causal or is_local) + ) P = P.to(pO.type.element_ty) row_start = m_block * (BLOCK_M // 16) @@ -736,11 +764,81 @@ def flash_fwd_kernel( tl.dot(P, V, acc=O_) + n_block -= 1 if n_masking_steps > 1 and n_block <= n_block_min: - n_block -= 1 break + for _ in range(n_block_min, n_block + 1): + KV_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEADSIZE)[:, None] + K = tl.load(pK + KV_offset, cache_modifier="cg") + V = tl.load(pV + KV_offset, cache_modifier="cg") + S = tl.dot(Q_block, K_block) + S = apply_mask( + S, + col_idx, + row_idx, + seqlen_q, + seqlen_k, + ws_left, + ws_right, + alibi_slope, + is_even_mn=True, + is_causal=False, + is_local=is_local, + has_alibi=has_alibi + ) + col_idx -= BLOCK_N + O_acc, P, rowmax_, rowsum_ = softmax_rescale( + O_, + S, + rowmax_, + rowsum_, + softmax_scale_log2=softmax_scale_log2, + is_border=is_local + ) + P = P.to(pO.type.element_ty) + + row_start = m_block * (BLOCK_M // 16) + col_start = n_block * (BLOCK_N // 32) + if return_P: + P_drop = P + P_drop = apply_dropout( + P_drop, + row_start, + col_start, + bid, + hid, + philox_seed, + philox_offset, + pdrop_u8, + encode_dropout_in_sign_bit=True, + NUM_HEADS=NUM_HEADS, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + tl.store(P_ptr + P_offset, P_drop, mask=qmask[:, None] & kvmask[None, :]) + P_offset += BLOCK_N + + if is_dropout: + P = apply_dropout( + P, + row_start, + col_start, + bid, + hid, + philox_seed, + philox_offset, + pdrop_u8, + encode_dropout_in_sign_bit=False, + NUM_HEADS=NUM_HEADS, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + + tl.dot(P, V, acc=O_) + + n_block -=1 def mha_fwd( @@ -881,8 +979,9 @@ def mha_fwd( NUM_HEADS=num_heads, NUM_HEADS_K=num_heads_k, HEADSIZE=head_size_rounded, - BLOCK_M, - BLOCK_N, + IS_EVEN_MN=is_even_mn, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, num_warps ) From 14e5d4c0eab25ed126d2ca44b5613bb0aa92b57d Mon Sep 17 00:00:00 2001 From: Tongxin Bai <waffle.bai@gmail.com> Date: Thu, 6 Mar 2025 04:46:51 +0000 Subject: [PATCH 06/25] fwd kernel done. 
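A minimal smoke-test sketch for the wrapper, assuming the layout that mha_fwd checks: (batch, seqlen, num_heads, head_size), fp16 or bf16, contiguous last dimension. What the wrapper returns is still being wired up in this series, so the call below only exercises the launch path; the eventual return tuple is assumed to follow aten._flash_attention_forward (output, logsumexp, philox_seed, philox_offset, debug_attn_mask).

    # Smoke-test sketch only; shapes and the eventual return tuple are
    # assumptions based on the aten._flash_attention_forward signature.
    import torch
    from flag_gems.ops.attention import flash_attention_forward

    q = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
    k = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
    v = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)

    flash_attention_forward(
        q, k, v,
        cum_seq_q=None, cum_seq_k=None,   # varlen is not supported yet
        max_q=1024, max_k=1024,
        dropout_p=0.0,
        is_causal=True,
        return_debug_mask=False,
        scale=None,                       # defaults to 1 / sqrt(head_size)
    )
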
--- src/flag_gems/ops/attention.py | 134 +++++++++++++++++++++++---------- 1 file changed, 96 insertions(+), 38 deletions(-) diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py index 780ff12de..edd3cc31b 100644 --- a/src/flag_gems/ops/attention.py +++ b/src/flag_gems/ops/attention.py @@ -583,6 +583,19 @@ def softmax_rescale( return O_acc, P, row_max, row_sum +@triton.jit +def + + +@triton.autotune( + configs=runtime.get_tuned_config("attention"), + key=["HEAD_DIM"], + prune_configs_by={ + "early_config_prune": early_config_prune, + "perf_model": None, + "top_k": 1.0, + }, +) @triton.heuristics( values={ 'IS_EVEN_MN': lambda args: (args["seqlen_q"] % args["BLOCK_M"] == 0) and (args["seqlen_k"] % args["BLOCK_N"] == 0) @@ -590,11 +603,12 @@ def softmax_rescale( ) @triton.jit(do_not_specialize=[]) def flash_fwd_kernel( - pQ, - pK, - pV, - pP, - pO, + Q_ptr, + K_ptr, + V_ptr, + P_ptr, + O_ptr, + lse_ptr, seqlen_q, seqlen_k, seqlen_q_rounded, @@ -612,6 +626,7 @@ def flash_fwd_kernel( philox_seed, philox_offset, pdrop_u8, + rpdrop, slopes_batch_stride, is_dropout: tl.constexpr, is_causal: tl.constexpr, @@ -622,10 +637,11 @@ def flash_fwd_kernel( ws_left: tl.constexpr, ws_right: tl.constexpr, return_P: tl.constexpr, + PRE_LOAD_V: tl.constexpr, BATCH_SIZE: tl.constexpr, NUM_HEADS: tl.constexpr, NUM_HEADS_K: tl.constexpr, - HEADSIZE: tl.constexpr, + HEAD_DIM: tl.constexpr, IS_EVEN_MN: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, @@ -690,14 +706,16 @@ def flash_fwd_kernel( rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32) for _ in range(n_masking_steps): - KV_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEADSIZE)[:, None] + KV_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None] if IS_EVEN_MN: K = tl.load(pK + KV_offset, cache_modifier="cg") - V = tl.load(pV + KV_offset, cache_modifier="cg") + if PRE_LOAD_V: + V = tl.load(pV + KV_offset, cache_modifier="cg") else: kvmask = col_idx < seqlen_k K = tl.load(pK + KV_offset, mask=kvmask, cache_modifier="cg") - V = tl.load(pV + KV_offset, mask=kvmask, cache_modifier="cg") + if PRE_LOAD_V: + V = tl.load(pV + KV_offset, mask=kvmask, cache_modifier="cg") S = tl.dot(Q_block, K_block) S = apply_mask( S, @@ -762,14 +780,19 @@ def flash_fwd_kernel( BLOCK_N=BLOCK_N, ) + if not PRE_LOAD_V: + if IS_EVEN_MN: + V = tl.load(pV + KV_offset, cache_modifier="cg") + else: + V = tl.load(pV + KV_offset, mask=kvmask, cache_modifier="cg") tl.dot(P, V, acc=O_) n_block -= 1 if n_masking_steps > 1 and n_block <= n_block_min: break - for _ in range(n_block_min, n_block + 1): - KV_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEADSIZE)[:, None] + for n_block in tl.range(n_block, n_block_min - 1, num_stages=num_stages): + KV_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None] K = tl.load(pK + KV_offset, cache_modifier="cg") V = tl.load(pV + KV_offset, cache_modifier="cg") S = tl.dot(Q_block, K_block) @@ -837,8 +860,34 @@ def flash_fwd_kernel( ) tl.dot(P, V, acc=O_) - - n_block -=1 + + # Final LSE + lse = tl.where(rowsum_ == 0 | rowsum_ != rowsum_, float('inf'), rowmax_ * softmax_scale +tl.log(rowsum_)) + inv_sum = tl.where(rowsum_ == 0 | rowsum_ != rowsum_, 1.0, 1.0 / rowsum_) + + # Rescale output + if is_dropout + O_ *= inv_sum * rpdrop + else: + O_ *= inv_sum + + O = O_.to(pO.type.element_ty) + + # Write back output + O_ptr += bid * o_b_stride + O_ptr += hid * o_h_stride + O_offset = row_idx[:, None] * o_s_stride + tl.arange(0, HEAD_DIM) + if IS_EVEN_MN: + tl.store(O_ptr + O_offset, O) + 
else: + tl.store(O_ptr + O_offset, O, mask=qmask) + + # Write back lse + lse_ptr += bid * hid * seqlen_q + if IS_EVEN_MN: + tl.store(lse_ptr + row_idx, lse) + else: + tl.store(lse_ptr + row_idx, lse, mask=qmask) def mha_fwd( @@ -882,6 +931,17 @@ def mha_fwd( if is_causal: window_size_right = 0 + if seqlen_q == 1 and num_heads > num_heads_k and window_size_left < 0 and window_size_right < 0 and p_dropout == 0 and not alibi_slopes + swap_seq_and_group = True + else: + swap_seq_and_group = False + + ngroups = num_heads // num_heads_k + if swap_seq_and_group: + q = q.reshape((batch_size, num_heads_k, ngroups, head_size)).transpose(1, 2) + seqlen_q = ngroups + num_heads = num_heads_k + if out: assert out.stride(-1) == 1 assert out.dtype == q.dtype @@ -903,6 +963,8 @@ def mha_fwd( (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded), dtype=q_dtype, ) + else: + p = torch.empty(()) # Set dropout params if p_dropout > 0: @@ -914,6 +976,11 @@ def mha_fwd( p_dropout = 1 - p_dropout pdrop_u8 = math.floor(p_dropout * 255.0) + rpdrop = 1. / p_dropout + + M_LOG2E = 1.4426950408889634074 + scale_softmax_log2 = softmax_scale * M_LOG2E + scale_softmax_rp_dropout = rpdrop * softmax_scale # Set alibi params if alibi_slopes is not None: @@ -932,9 +999,7 @@ def mha_fwd( # ONLY EVEN_K IS SUPPORTED assert head_size == head_size_rounded - - M_LOG2E = 1.4426950408889634074 - scale_softmax_log2 = softmax_scale * M_LOG2E + grid = lambda args: ( triton.cdiv(seqlen_q, args["BLOCK_M"]), # num_m_blocks @@ -948,6 +1013,7 @@ def mha_fwd( v, p, out, + lse, seqlen_q, seqlen_k, seqlen_q_rounded, @@ -965,26 +1031,30 @@ def mha_fwd( philox_seed, philox_offset, pdrop_u8, + rpdrop, alibi_slopes_batch_stride, is_dropout=is_dropout, is_causal=is_causal, is_local=is_local, has_alibi=has_alibi, softmax_scale=softmax_scale, - softmax_scale_log2=scale_softmax_log2 + softmax_scale_log2=scale_softmax_log2, ws_left=window_size_left, ws_right=window_size_right, return_P=return_softmax, BATCH_SIZE=batch_size, NUM_HEADS=num_heads, NUM_HEADS_K=num_heads_k, - HEADSIZE=head_size_rounded, + HEAD_DIM=head_size_rounded, IS_EVEN_MN=is_even_mn, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - num_warps ) - + + if swap_seq_and_group: + out = out.transpose(1, 2).reshape((batch_size, 1, num_heads_k * seqlen_q, head_size)) + q = q.transpose(1, 2).reshape((batch_size, 1, num_heads_k * seqlen_q, head_size)) + lse = lse.reshape((batch_size, num_heads_k * seqlen_q, 1)) + + return out, q, k, v, lse, philox_seed, philox_offset, p def flash_attention_forward( @@ -1016,13 +1086,12 @@ def flash_attention_forward( softmax_scale = scale or 1.0 / (HEAD_DIM_K**0.5) non_null_window_left = window_size_left or -1 non_null_window_right = window_size_right or -1 - out = torch.empty_like(query, dtype=value.dtype) - mha_out = mha_fwd( + out, q, k, v, lse, philox_seed, philox_offset, p = mha_fwd( query, key, value, - out, + None, alibi_slopes, dropout_p, softmax_scale, @@ -1031,16 +1100,5 @@ def flash_attention_forward( non_null_window_right, return_debug_mask, ) - ( - output, - q_padded, - k_padded, - v_padded, - logsumexp, - philox_seed, - philox_offset, - debug_attn_mask, - ) = mha_out - - - + + return (out, lse, philox_seed, philox_offset, p) From 62c92f2faa3a4bb2c6f4ce156acd019b84400956 Mon Sep 17 00:00:00 2001 From: Tongxin Bai <waffle.bai@gmail.com> Date: Thu, 6 Mar 2025 05:05:45 +0000 Subject: [PATCH 07/25] fix syntax errors. 
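Fixes the syntax slips from the previous commit (the tl.constexpr typo, missing colons on the is_dropout and swap_seq_and_group branches) and renames philox_backend_seed_offset to update_philox_state across the random ops. The rename is mechanical; every caller keeps the same pattern, e.g. (as it appears in the diffs below):

    increment = triton.cdiv(N, UNROLL)
    philox_seed, philox_offset = update_philox_state(increment)

i.e. reserve `increment` philox offsets for the launch and hand the seed plus starting offset to the kernel, exactly as before.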
--- src/flag_gems/ops/attention.py | 10 +++------- src/flag_gems/ops/exponential_.py | 4 ++-- src/flag_gems/ops/multinomial.py | 4 ++-- src/flag_gems/ops/normal.py | 4 ++-- src/flag_gems/ops/rand.py | 4 ++-- src/flag_gems/ops/rand_like.py | 4 ++-- src/flag_gems/ops/randn.py | 4 ++-- src/flag_gems/ops/randn_like.py | 4 ++-- src/flag_gems/ops/randperm.py | 4 ++-- src/flag_gems/ops/uniform.py | 4 ++-- .../runtime/backend/_metax/ops/exponential_.py | 4 ++-- 11 files changed, 23 insertions(+), 27 deletions(-) diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py index edd3cc31b..9c13485d6 100644 --- a/src/flag_gems/ops/attention.py +++ b/src/flag_gems/ops/attention.py @@ -583,10 +583,6 @@ def softmax_rescale( return O_acc, P, row_max, row_sum -@triton.jit -def - - @triton.autotune( configs=runtime.get_tuned_config("attention"), key=["HEAD_DIM"], @@ -632,7 +628,7 @@ def flash_fwd_kernel( is_causal: tl.constexpr, is_local: tl.constexpr, has_alibi: tl.constexpr, - softmax_scale: tl.constexp, + softmax_scale: tl.constexpr, softmax_scale_log2: tl.constexpr, ws_left: tl.constexpr, ws_right: tl.constexpr, @@ -866,7 +862,7 @@ def flash_fwd_kernel( inv_sum = tl.where(rowsum_ == 0 | rowsum_ != rowsum_, 1.0, 1.0 / rowsum_) # Rescale output - if is_dropout + if is_dropout: O_ *= inv_sum * rpdrop else: O_ *= inv_sum @@ -931,7 +927,7 @@ def mha_fwd( if is_causal: window_size_right = 0 - if seqlen_q == 1 and num_heads > num_heads_k and window_size_left < 0 and window_size_right < 0 and p_dropout == 0 and not alibi_slopes + if seqlen_q == 1 and num_heads > num_heads_k and window_size_left < 0 and window_size_right < 0 and p_dropout == 0 and not alibi_slopes: swap_seq_and_group = True else: swap_seq_and_group = False diff --git a/src/flag_gems/ops/exponential_.py b/src/flag_gems/ops/exponential_.py index 8d36ce7eb..0ab5796ee 100644 --- a/src/flag_gems/ops/exponential_.py +++ b/src/flag_gems/ops/exponential_.py @@ -5,7 +5,7 @@ import triton.language as tl from flag_gems.utils.random_utils import ( - philox_backend_seed_offset, + update_philox_state, uint_to_uniform_float, ) @@ -94,7 +94,7 @@ def exponential_(x, lambd: float = 1.0, *, gen=None): # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller, # hence we cannot obtain the per thread offset as in Pytorch. 
increment = triton.cdiv(N, UNROLL) - philox_seed, philox_offset = philox_backend_seed_offset(increment) + philox_seed, philox_offset = update_philox_state(increment) eps = torch.finfo(dtype).eps x_ = x if inplace else torch.empty(x.size(), dtype=dtype, device=device) with torch_device_fn.device(device): diff --git a/src/flag_gems/ops/multinomial.py b/src/flag_gems/ops/multinomial.py index d1f061cda..8824426cd 100644 --- a/src/flag_gems/ops/multinomial.py +++ b/src/flag_gems/ops/multinomial.py @@ -5,7 +5,7 @@ import triton.language as tl from flag_gems.utils import libentry -from flag_gems.utils.random_utils import philox_backend_seed_offset, uniform +from flag_gems.utils.random_utils import update_philox_state, uniform @libentry() @@ -84,7 +84,7 @@ def multinomial(prob, n_samples, with_replacement=False, *, gen=None): # The CTA level parallelism is framed in a 2d grid of blocks with grid.y # indexing into distributions and grid.x output sample batches increment = n_dist * n_samples - philox_seed, philox_offset = philox_backend_seed_offset(increment) + philox_seed, philox_offset = update_philox_state(increment) grid = lambda META: (triton.cdiv(n_samples, META["NBLOCK"]), n_dist) multinomial_with_replacement[grid]( cum_prob, out, n_categories, n_samples, philox_seed, philox_offset diff --git a/src/flag_gems/ops/normal.py b/src/flag_gems/ops/normal.py index b24b4398e..d35bbba8d 100644 --- a/src/flag_gems/ops/normal.py +++ b/src/flag_gems/ops/normal.py @@ -5,7 +5,7 @@ from ..runtime import torch_device_fn from ..utils import pointwise_dynamic -from ..utils.random_utils import philox_backend_seed_offset +from ..utils.random_utils import update_philox_state from ..utils.shape_utils import broadcast_shapes, volume from .randn import randn_kernel @@ -50,7 +50,7 @@ def normal_distribution(shape, device, *, generator=None): grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK"] * UNROLL),) increment = triton.cdiv(N, UNROLL) - philox_seed, philox_offset = philox_backend_seed_offset(increment) + philox_seed, philox_offset = update_philox_state(increment) with torch_device_fn.device(device): randn_kernel[grid_fn](out, N, philox_seed, philox_offset) return out diff --git a/src/flag_gems/ops/rand.py b/src/flag_gems/ops/rand.py index 9cab927ac..3dc2a51de 100644 --- a/src/flag_gems/ops/rand.py +++ b/src/flag_gems/ops/rand.py @@ -5,7 +5,7 @@ import triton.language as tl from flag_gems.utils.random_utils import ( - philox_backend_seed_offset, + update_philox_state, uint_to_uniform_float, ) from flag_gems.utils.shape_utils import volume @@ -63,7 +63,7 @@ def rand(size, *, dtype=None, layout=None, device=None, pin_memory=None): # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller, # hence we cannot obtain the per thread offset as in Pytorch. 
increment = triton.cdiv(N, UNROLL) - philox_seed, philox_offset = philox_backend_seed_offset(increment) + philox_seed, philox_offset = update_philox_state(increment) with torch_device_fn.device(device): rand_kernel[grid_fn](out, N, philox_seed, philox_offset) return out diff --git a/src/flag_gems/ops/rand_like.py b/src/flag_gems/ops/rand_like.py index 85338f595..6b5b9b706 100644 --- a/src/flag_gems/ops/rand_like.py +++ b/src/flag_gems/ops/rand_like.py @@ -4,7 +4,7 @@ import triton from flag_gems.ops.rand import rand_kernel -from flag_gems.utils.random_utils import philox_backend_seed_offset +from flag_gems.utils.random_utils import update_philox_state from ..runtime import torch_device_fn @@ -25,7 +25,7 @@ def rand_like( # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller, # hence we cannot obtain the per thread offset as in Pytorch. increment = triton.cdiv(N, UNROLL) - philox_seed, philox_offset = philox_backend_seed_offset(increment) + philox_seed, philox_offset = update_philox_state(increment) with torch_device_fn.device(x.device): rand_kernel[grid_fn](out, N, philox_seed, philox_offset) return out diff --git a/src/flag_gems/ops/randn.py b/src/flag_gems/ops/randn.py index 3ee2932e5..581247a73 100644 --- a/src/flag_gems/ops/randn.py +++ b/src/flag_gems/ops/randn.py @@ -5,7 +5,7 @@ import triton.language as tl from flag_gems.utils.random_utils import ( - philox_backend_seed_offset, + update_philox_state, uint_to_uniform_float, ) from flag_gems.utils.shape_utils import volume @@ -77,7 +77,7 @@ def randn(size, *, dtype=None, layout=None, device=None, pin_memory=None): # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller, # hence we cannot obtain the per thread offset as in Pytorch. increment = triton.cdiv(N, UNROLL) - philox_seed, philox_offset = philox_backend_seed_offset(increment) + philox_seed, philox_offset = update_philox_state(increment) with torch_device_fn.device(device): randn_kernel[grid_fn](out, N, philox_seed, philox_offset) return out diff --git a/src/flag_gems/ops/randn_like.py b/src/flag_gems/ops/randn_like.py index 0458328dc..b8c87b7df 100644 --- a/src/flag_gems/ops/randn_like.py +++ b/src/flag_gems/ops/randn_like.py @@ -4,7 +4,7 @@ import triton from flag_gems.ops.randn import randn_kernel -from flag_gems.utils.random_utils import philox_backend_seed_offset +from flag_gems.utils.random_utils import update_philox_state from ..runtime import torch_device_fn @@ -25,7 +25,7 @@ def randn_like( # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller, # hence we cannot obtain the per thread offset as in Pytorch. increment = triton.cdiv(N, UNROLL) - philox_seed, philox_offset = philox_backend_seed_offset(increment) + philox_seed, philox_offset = update_philox_state(increment) with torch_device_fn.device(x.device): randn_kernel[grid_fn](out, N, philox_seed, philox_offset) return out diff --git a/src/flag_gems/ops/randperm.py b/src/flag_gems/ops/randperm.py index 9757a1adc..1e7a07bba 100644 --- a/src/flag_gems/ops/randperm.py +++ b/src/flag_gems/ops/randperm.py @@ -4,7 +4,7 @@ import triton import triton.language as tl -from flag_gems.utils.random_utils import philox_backend_seed_offset +from flag_gems.utils.random_utils import update_philox_state from .. 
import runtime from ..runtime import device, torch_device_fn @@ -384,7 +384,7 @@ def sort_by_key(key, value, valid_bits): # last step, shuffle inner-block data BLOCK_SIZE_SHUFFLE = 512 grid_shuffle = (triton.cdiv(n_elements, BLOCK_SIZE_SHUFFLE),) - philox_seed, philox_offset = philox_backend_seed_offset(n_elements) + philox_seed, philox_offset = update_philox_state(n_elements) with torch_device_fn.device(key.device): duplicate_keys_shuffle_kernel[grid_shuffle]( v_out, diff --git a/src/flag_gems/ops/uniform.py b/src/flag_gems/ops/uniform.py index 114a552bf..8b2d1f8f5 100644 --- a/src/flag_gems/ops/uniform.py +++ b/src/flag_gems/ops/uniform.py @@ -4,7 +4,7 @@ import triton.language as tl from flag_gems.utils.random_utils import ( - philox_backend_seed_offset, + update_philox_state, uint_to_uniform_float, ) from flag_gems.utils.shape_utils import volume @@ -55,7 +55,7 @@ def uniform_(self, from_=0.0, to=1.0, *, generator=None): grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK"] * UNROLL),) increment = triton.cdiv(N, UNROLL) - philox_seed, philox_offset = philox_backend_seed_offset(increment) + philox_seed, philox_offset = update_philox_state(increment) with torch_device_fn.device(self.device): uniform_kernel[grid_fn](self, N, philox_seed, philox_offset, from_, to) return self diff --git a/src/flag_gems/runtime/backend/_metax/ops/exponential_.py b/src/flag_gems/runtime/backend/_metax/ops/exponential_.py index 76be32499..899b3e49e 100644 --- a/src/flag_gems/runtime/backend/_metax/ops/exponential_.py +++ b/src/flag_gems/runtime/backend/_metax/ops/exponential_.py @@ -6,7 +6,7 @@ from flag_gems.runtime import torch_device_fn from flag_gems.utils.random_utils import ( - philox_backend_seed_offset, + update_philox_state, uint_to_uniform_float, ) @@ -238,7 +238,7 @@ def exponential_(x, lambd: float = 1.0, *, gen=None): # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller, # hence we cannot obtain the per thread offset as in Pytorch. increment = triton.cdiv(N, UNROLL) - philox_seed, philox_offset = philox_backend_seed_offset(increment) + philox_seed, philox_offset = update_philox_state(increment) eps = torch.finfo(dtype).eps x_ = x if inplace else torch.empty(x.size(), dtype=dtype, device=device) type_index = lst.index(dtype) From af83c2f5e4e2550f40d2299955c0cff493648cb2 Mon Sep 17 00:00:00 2001 From: Tongxin Bai <waffle.bai@gmail.com> Date: Thu, 6 Mar 2025 11:55:59 +0000 Subject: [PATCH 08/25] rowmax inf needs to be handled. 
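Broadcasting and typing cleanup around the mask and rescale helpers: apply_mask now indexes col_idx/row_idx with explicit [None, :] / [:, None] broadcasts, alibi_slope defaults to 0.0 when alibi is off, the leftover HEADSIZE references become HEAD_DIM, the runtime arguments are listed in do_not_specialize, and the accumulator is sized (BLOCK_M, HEAD_DIM) instead of (BLOCK_M, BLOCK_N).

The subject names the remaining issue: if every key a row can see in a block is masked, the running row max stays -inf and the rescale produces NaN. A small demonstration of the failure and of the guard already used for max_scaled in softmax_rescale (illustrative, plain NumPy):

    import numpy as np
    m = np.float32(-np.inf)     # running max of a fully masked row
    s = np.float32(-np.inf)     # score of a masked key
    print(np.exp2(s - m))       # nan, since (-inf) - (-inf) is nan
    print(np.exp2(s - (0.0 if np.isneginf(m) else m)))   # 0.0 with the guard

The same -inf check still has to be applied to the scale used on the accumulator and row sum, which is what the next commit does via is_border.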
--- src/flag_gems/ops/attention.py | 268 ++++++++++++++++++--------------- 1 file changed, 148 insertions(+), 120 deletions(-) diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py index 9c13485d6..78789162a 100644 --- a/src/flag_gems/ops/attention.py +++ b/src/flag_gems/ops/attention.py @@ -1,4 +1,5 @@ import logging +import math import torch import triton @@ -542,22 +543,27 @@ def apply_mask( is_local: tl.constexpr, has_alibi: tl.constexpr, ): - if has_alibi or is_causal or is_local or not is_even_mn: + # need_mask = has_alibi or is_causal + # need_mask |= is_local + # need_mask |= not is_even_mn + need_mask: tl.constexpr = has_alibi | is_local | (not is_even_mn) + if need_mask: col_lb = max(0, row_idx + max_seqlen_k - max_seqlen_q - ws_left) col_rb = min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + ws_right) if not has_alibi: alibi_slope = .0 - S -= alibi_slope * tl.abs(col_idx - row_idx) + S -= alibi_slope * tl.abs(col_idx[None, :] - row_idx[:, None]) if is_causal: - S = tl.where(col_idx[None, :] >= col_rb[None, :], float('-inf'), S) + S = tl.where(col_idx[None, :] >= col_rb[:, None], float('-inf'), S) if is_local: - S = tl.where(col_idx[None, :] >= col_rb[None, :] | col_idx[None, :] < col_lb[None, :], float('-inf'), S) + S = tl.where(col_idx[None, :] >= col_rb[:, None] | col_idx[None, :] < col_lb[:, None], float('-inf'), S) - if (not local) and (not is_causal) and (not is_even_mn): + + if (not local) & (not is_causal) & (not is_even_mn): S = tl.where(col_idx[None, :] >= max_seqlen_k, float('-inf'), S) return S @@ -575,11 +581,12 @@ def softmax_rescale( prev_row_max = row_max row_max = tl.maximum(row_max, tl.max(S, 1)) row_sum_scale = tl.math.exp2(row_max - prev_row_max) * softmax_scale_log2 + tl.device_print('row_sum_scale', row_max - prev_row_max) row_sum *= row_sum_scale O_acc *= row_sum_scale[:, None] - max_scaled = tl.where(rowmax == float('-inf'), 0, rowmax * softmax_scale_log2) + max_scaled = tl.where(row_max == float('-inf'), 0, row_max * softmax_scale_log2) P = tl.math.exp2(S * softmax_scale_log2 - max_scaled[:, None]) - row_sum = row_sum + tl.sum(exp_S, 1) + row_sum = row_sum + tl.sum(S, 1) return O_acc, P, row_max, row_sum @@ -597,7 +604,7 @@ def softmax_rescale( 'IS_EVEN_MN': lambda args: (args["seqlen_q"] % args["BLOCK_M"] == 0) and (args["seqlen_k"] % args["BLOCK_N"] == 0) } ) -@triton.jit(do_not_specialize=[]) +@triton.jit(do_not_specialize=["seqlen_q", "seqlen_k", "q_b_stride", "q_s_stride", "q_h_stride", "k_b_stride", "k_s_stride", "k_h_stride", "o_b_stride", "o_h_stride", "o_s_stride", "philox_seed", "philox_offset", "pdrop_u8", "slopes_batch_stride"]) def flash_fwd_kernel( Q_ptr, K_ptr, @@ -618,12 +625,15 @@ def flash_fwd_kernel( o_b_stride, o_s_stride, o_h_stride, + h, + hk, pSlopes, philox_seed, philox_offset, pdrop_u8, rpdrop, slopes_batch_stride, + HEAD_DIM: tl.constexpr, is_dropout: tl.constexpr, is_causal: tl.constexpr, is_local: tl.constexpr, @@ -637,7 +647,6 @@ def flash_fwd_kernel( BATCH_SIZE: tl.constexpr, NUM_HEADS: tl.constexpr, NUM_HEADS_K: tl.constexpr, - HEAD_DIM: tl.constexpr, IS_EVEN_MN: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, @@ -663,6 +672,8 @@ def flash_fwd_kernel( alibi_offset = bid * slopes_batch_stride + hid alibi_slope = tl.load(pSlopes + alibi_offset) alibi_slope /= scale + else: + alibi_slope = 0.0 if (not is_causal) and (not is_local): if IS_EVEN_MN: @@ -679,9 +690,12 @@ def flash_fwd_kernel( Q_ptr += bid * q_b_stride Q_ptr += hid * q_h_stride row_idx = m_block * BLOCK_M + tl.arange(0, 
BLOCK_M) - Q_off = row_idx[:, None] * q_s_stride + tl.arange(0, HEADSIZE)[None, :] - qmask = row_idx < seqlen_q - Q = tl.load(pQ + Q_off, mask=qmask) + Q_off = row_idx[:, None] * q_s_stride + tl.arange(0, HEAD_DIM)[None, :] + qmask = row_idx[:, None] < seqlen_q + if IS_EVEN_MN: + Q = tl.load(Q_ptr + Q_off) + else: + Q = tl.load(Q_ptr + Q_off, mask=qmask) # Start from the right most block n_block = n_block_max - 1 @@ -697,101 +711,106 @@ def flash_fwd_kernel( P_ptr += n_block * BLOCK_N P_offset = tl.arange(0, BLOCK_M)[:, None] * seqlen_k_rounded + tl.arange(0, BLOCK_N) - O_ = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + O_ = tl.zeros((BLOCK_M, HEAD_DIM), dtype=tl.float32) rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32) - for _ in range(n_masking_steps): - KV_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None] - if IS_EVEN_MN: - K = tl.load(pK + KV_offset, cache_modifier="cg") - if PRE_LOAD_V: - V = tl.load(pV + KV_offset, cache_modifier="cg") - else: - kvmask = col_idx < seqlen_k - K = tl.load(pK + KV_offset, mask=kvmask, cache_modifier="cg") - if PRE_LOAD_V: - V = tl.load(pV + KV_offset, mask=kvmask, cache_modifier="cg") - S = tl.dot(Q_block, K_block) - S = apply_mask( - S, - col_idx, - row_idx, - seqlen_q, - seqlen_k, - ws_left, - ws_right, - alibi_slope, - is_even_mn=IS_EVEN_MN, - is_causal=is_causal, - is_local=is_local, - has_alibi=has_alibi - ) - col_idx -= BLOCK_N - - O_acc, P, rowmax_, rowsum_ = softmax_rescale( - O_, - S, - rowmax_, - rowsum_, - softmax_scale_log2=softmax_scale_log2, - is_border=(is_causal or is_local) - ) - P = P.to(pO.type.element_ty) - - row_start = m_block * (BLOCK_M // 16) - col_start = n_block * (BLOCK_N // 32) - if return_P: - P_drop = P - P_drop = apply_dropout( - P_drop, - row_start, - col_start, - bid, - hid, - philox_seed, - philox_offset, - pdrop_u8, - encode_dropout_in_sign_bit=True, - NUM_HEADS=NUM_HEADS, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - ) - tl.store(P_ptr + P_offset, P_drop, mask=qmask[:, None] & kvmask[None, :]) - P_offset += BLOCK_N - - if is_dropout: - P = apply_dropout( - P, - row_start, - col_start, - bid, - hid, - philox_seed, - philox_offset, - pdrop_u8, - encode_dropout_in_sign_bit=False, - NUM_HEADS=NUM_HEADS, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - ) - - if not PRE_LOAD_V: - if IS_EVEN_MN: - V = tl.load(pV + KV_offset, cache_modifier="cg") - else: - V = tl.load(pV + KV_offset, mask=kvmask, cache_modifier="cg") - tl.dot(P, V, acc=O_) - - n_block -= 1 - if n_masking_steps > 1 and n_block <= n_block_min: - break - - for n_block in tl.range(n_block, n_block_min - 1, num_stages=num_stages): - KV_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None] - K = tl.load(pK + KV_offset, cache_modifier="cg") - V = tl.load(pV + KV_offset, cache_modifier="cg") - S = tl.dot(Q_block, K_block) + # for _ in range(n_masking_steps): + # K_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None] + # V_offset = col_idx[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :] + # if IS_EVEN_MN: + # K = tl.load(K_ptr + K_offset, cache_modifier=".cg") + # if PRE_LOAD_V: + # V = tl.load(V_ptr + V_offset, cache_modifier=".cg") + # else: + # kvmask = col_idx[None, :] < seqlen_k + # K = tl.load(K_ptr + K_offset, mask=kvmask, cache_modifier=".cg") + # if PRE_LOAD_V: + # V = tl.load(V_ptr + V_offset, mask=kvmask, cache_modifier=".cg") + # S = tl.dot(Q, K, allow_tf32=False) + # S = apply_mask( + # S, + # col_idx, + # row_idx, + # seqlen_q, + # 
seqlen_k, + # ws_left, + # ws_right, + # alibi_slope, + # is_even_mn=IS_EVEN_MN, + # is_causal=is_causal, + # is_local=is_local, + # has_alibi=has_alibi + # ) + # col_idx -= BLOCK_N + + # O_acc, P, rowmax_, rowsum_ = softmax_rescale( + # O_, + # S, + # rowmax_, + # rowsum_, + # softmax_scale_log2=softmax_scale_log2, + # is_border=(is_causal or is_local) + # ) + # P = P.to(O_ptr.type.element_ty) + + # row_start = m_block * (BLOCK_M // 16) + # col_start = n_block * (BLOCK_N // 32) + # if return_P: + # P_drop = P + # P_drop = apply_dropout( + # P_drop, + # row_start, + # col_start, + # bid, + # hid, + # philox_seed, + # philox_offset, + # pdrop_u8, + # encode_dropout_in_sign_bit=True, + # NUM_HEADS=NUM_HEADS, + # BLOCK_M=BLOCK_M, + # BLOCK_N=BLOCK_N, + # ) + # tl.store(P_ptr + P_offset, P_drop, mask=qmask & kvmask) + # P_offset += BLOCK_N + + # if is_dropout: + # P = apply_dropout( + # P, + # row_start, + # col_start, + # bid, + # hid, + # philox_seed, + # philox_offset, + # pdrop_u8, + # encode_dropout_in_sign_bit=False, + # NUM_HEADS=NUM_HEADS, + # BLOCK_M=BLOCK_M, + # BLOCK_N=BLOCK_N, + # ) + + # if not PRE_LOAD_V: + # if IS_EVEN_MN: + # V = tl.load(V_ptr + V_offset, cache_modifier=".cg") + # else: + # V = tl.load(V_ptr + V_offset, mask=kvmask, cache_modifier=".cg") + # tl.dot(P, V, acc=O_, allow_tf32=False) + + # n_block -= 1 + # # if n_masking_steps > 1 and n_block <= n_block_min: + # # break + + + # for n_block in tl.range(n_block, n_block_min - 1, num_stages=num_stages): + if True: + K_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None] + K = tl.load(K_ptr + K_offset, cache_modifier=".cg") + if PRE_LOAD_V: + V_offset = col_idx[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :] + V = tl.load(V_ptr + V_offset, cache_modifier=".cg") + S = tl.dot(Q, K) S = apply_mask( S, col_idx, @@ -808,7 +827,7 @@ def flash_fwd_kernel( ) col_idx -= BLOCK_N - O_acc, P, rowmax_, rowsum_ = softmax_rescale( + O_, P, rowmax_, rowsum_ = softmax_rescale( O_, S, rowmax_, @@ -816,7 +835,8 @@ def flash_fwd_kernel( softmax_scale_log2=softmax_scale_log2, is_border=is_local ) - P = P.to(pO.type.element_ty) + + P = P.to(O_ptr.type.element_ty) row_start = m_block * (BLOCK_M // 16) col_start = n_block * (BLOCK_N // 32) @@ -836,7 +856,7 @@ def flash_fwd_kernel( BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, ) - tl.store(P_ptr + P_offset, P_drop, mask=qmask[:, None] & kvmask[None, :]) + tl.store(P_ptr + P_offset, P_drop, mask=qmask & kvmask) P_offset += BLOCK_N if is_dropout: @@ -855,24 +875,29 @@ def flash_fwd_kernel( BLOCK_N=BLOCK_N, ) + if not PRE_LOAD_V: + V_offset = col_idx[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :] + V = tl.load(V_ptr + V_offset, cache_modifier=".cg") + tl.dot(P, V, acc=O_) # Final LSE - lse = tl.where(rowsum_ == 0 | rowsum_ != rowsum_, float('inf'), rowmax_ * softmax_scale +tl.log(rowsum_)) - inv_sum = tl.where(rowsum_ == 0 | rowsum_ != rowsum_, 1.0, 1.0 / rowsum_) + lse = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), float('inf'), rowmax_ * softmax_scale + tl.log(rowsum_)) + inv_sum = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), 1.0, 1.0 / rowsum_) - # Rescale output - if is_dropout: - O_ *= inv_sum * rpdrop - else: - O_ *= inv_sum + # # Rescale output + # if is_dropout: + # O_ *= inv_sum[:, None] * rpdrop + # else: + # O_ *= inv_sum[:, None] - O = O_.to(pO.type.element_ty) + O = O_.to(O_ptr.type.element_ty) # Write back output O_ptr += bid * o_b_stride O_ptr += hid * o_h_stride O_offset = row_idx[:, None] * o_s_stride + tl.arange(0, HEAD_DIM) + if IS_EVEN_MN: 
tl.store(O_ptr + O_offset, O) else: @@ -943,7 +968,7 @@ def mha_fwd( assert out.dtype == q.dtype assert out.size() == (batch_size, seqlen_q, num_heads, head_size) else: - out = torch.empty_like(q) + out = torch.empty_like(q, dtype=v.dtype) round_multiple = lambda x, m: (x + m - 1) // m * m head_size_rounded = round_multiple(head_size, 32) @@ -952,15 +977,16 @@ def mha_fwd( with torch_device_fn.device(q_device): # Set softmax params - lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float) + lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float, device=q_device) if return_softmax: assert p_dropout > 0, "return_softmax is only supported when p_dropout > 0.0" p = torch.empty( (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded), dtype=q_dtype, + device=q_device ) else: - p = torch.empty(()) + p = torch.empty((), device=q_device) # Set dropout params if p_dropout > 0: @@ -968,6 +994,7 @@ def mha_fwd( philox_seed, philox_offset = update_philox_state(increment) is_dropout = True else: + philox_seed, philox_offset = None, None is_dropout = False p_dropout = 1 - p_dropout @@ -1023,12 +1050,15 @@ def mha_fwd( out.stride(0), out.stride(-3), out.stride(-2), + num_heads, + num_heads_k, alibi_slopes, philox_seed, philox_offset, pdrop_u8, rpdrop, alibi_slopes_batch_stride, + head_size, is_dropout=is_dropout, is_causal=is_causal, is_local=is_local, @@ -1041,8 +1071,6 @@ def mha_fwd( BATCH_SIZE=batch_size, NUM_HEADS=num_heads, NUM_HEADS_K=num_heads_k, - HEAD_DIM=head_size_rounded, - IS_EVEN_MN=is_even_mn, ) if swap_seq_and_group: From 1962a965372e161fb41c337cd6290a56521e3df2 Mon Sep 17 00:00:00 2001 From: Tongxin Bai <waffle.bai@gmail.com> Date: Fri, 7 Mar 2025 12:10:52 +0000 Subject: [PATCH 09/25] passed noncausal, nonlocal and no bias. 
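First configuration that passes: non-causal, non-local, no alibi. softmax_rescale gains is_init and is_border flags (the first block skips the rescale, border blocks clamp a -inf running max to 0 before scaling), the masking steps become a tl.range loop over n_block instead of mutating col_idx in place, and block sizes, warps and stages are pinned through @triton.heuristics while the autotuner stays commented out.

For reference, the update softmax_rescale implements is the standard online-softmax step; a plain NumPy restatement (names illustrative, the kernel does the same thing with exp2 and a log2(e)-folded scale):

    import numpy as np

    def online_softmax_step(acc, m, l, S, V, scale):
        # acc: (M, D), m, l: (M,), S: (M, N) raw scores of this key block, V: (N, D)
        m_new = np.maximum(m, S.max(axis=1))
        alpha = np.exp(scale * (m - m_new))       # first block: m = -inf, alpha = 0
        P = np.exp(scale * (S - m_new[:, None]))
        acc = acc * alpha[:, None] + P @ V
        l = l * alpha + P.sum(axis=1)
        return acc, m_new, l

Fully masked rows still need the -inf guards from the previous commit. After the last block, out = acc / l and lse = scale * m + log(l), which matches the epilogue added earlier in the series.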
--- src/flag_gems/ops/attention.py | 255 +++++++++++++++++---------------- 1 file changed, 134 insertions(+), 121 deletions(-) diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py index 78789162a..26efbf69f 100644 --- a/src/flag_gems/ops/attention.py +++ b/src/flag_gems/ops/attention.py @@ -562,8 +562,7 @@ def apply_mask( if is_local: S = tl.where(col_idx[None, :] >= col_rb[:, None] | col_idx[None, :] < col_lb[:, None], float('-inf'), S) - - if (not local) & (not is_causal) & (not is_even_mn): + if (not is_local) & (not is_causal) & (not is_even_mn): S = tl.where(col_idx[None, :] >= max_seqlen_k, float('-inf'), S) return S @@ -576,32 +575,44 @@ def softmax_rescale( row_max, row_sum, softmax_scale_log2: tl.constexpr, - is_border: tl.constexpr + is_border: tl.constexpr, + is_init: tl.constexpr ): - prev_row_max = row_max + prev_max = row_max row_max = tl.maximum(row_max, tl.max(S, 1)) - row_sum_scale = tl.math.exp2(row_max - prev_row_max) * softmax_scale_log2 - tl.device_print('row_sum_scale', row_max - prev_row_max) - row_sum *= row_sum_scale - O_acc *= row_sum_scale[:, None] + + if not is_init: + if is_border: + cur_max = tl.where(row_max == float('-inf'), 0, row_max) + else: + cur_max = row_max + p_scale = tl.math.exp2((prev_max - cur_max) * softmax_scale_log2) + row_sum *= p_scale + O_acc *= p_scale[:, None] + max_scaled = tl.where(row_max == float('-inf'), 0, row_max * softmax_scale_log2) P = tl.math.exp2(S * softmax_scale_log2 - max_scaled[:, None]) - row_sum = row_sum + tl.sum(S, 1) + row_sum = row_sum + tl.sum(P, 1) return O_acc, P, row_max, row_sum -@triton.autotune( - configs=runtime.get_tuned_config("attention"), - key=["HEAD_DIM"], - prune_configs_by={ - "early_config_prune": early_config_prune, - "perf_model": None, - "top_k": 1.0, - }, -) +# @triton.autotune( +# configs=runtime.get_tuned_config("attention"), +# key=["HEAD_DIM"], +# prune_configs_by={ +# "early_config_prune": early_config_prune, +# "perf_model": None, +# "top_k": 1.0, +# }, +# ) @triton.heuristics( values={ - 'IS_EVEN_MN': lambda args: (args["seqlen_q"] % args["BLOCK_M"] == 0) and (args["seqlen_k"] % args["BLOCK_N"] == 0) + 'BLOCK_M': lambda args: 64, + 'BLOCK_N': lambda args: 64, + 'num_warps': lambda args: 4, + 'num_stages': lambda args: 2, + 'PRE_LOAD_V': lambda args: False, + 'IS_EVEN_MN': lambda args: (args["seqlen_q"] % args["BLOCK_M"] == 0) and (args["seqlen_k"] % args["BLOCK_N"] == 0), } ) @triton.jit(do_not_specialize=["seqlen_q", "seqlen_k", "q_b_stride", "q_s_stride", "q_h_stride", "k_b_stride", "k_s_stride", "k_h_stride", "o_b_stride", "o_h_stride", "o_s_stride", "philox_seed", "philox_offset", "pdrop_u8", "slopes_batch_stride"]) @@ -686,7 +697,6 @@ def flash_fwd_kernel( # local and not causal, n_masking_steps = tl.cdiv(BLOCK_M, BLOCK_N) + 1 - Q_ptr += bid * q_b_stride Q_ptr += hid * q_h_stride row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M) @@ -705,8 +715,7 @@ def flash_fwd_kernel( K_ptr += (hid // h_hk_ratio) * k_h_stride V_ptr += bid * k_b_stride V_ptr += (hid // h_hk_ratio) * k_h_stride - col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N) - + P_ptr += ((bid * NUM_HEADS + hid) * seqlen_q_rounded + m_block * BLOCK_M) * seqlen_k_rounded P_ptr += n_block * BLOCK_N P_offset = tl.arange(0, BLOCK_M)[:, None] * seqlen_k_rounded + tl.arange(0, BLOCK_N) @@ -715,96 +724,98 @@ def flash_fwd_kernel( rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32) - # for _ in range(n_masking_steps): - # K_offset = col_idx[None, :] * 
k_s_stride + tl.arange(0, HEAD_DIM)[:, None] - # V_offset = col_idx[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :] - # if IS_EVEN_MN: - # K = tl.load(K_ptr + K_offset, cache_modifier=".cg") - # if PRE_LOAD_V: - # V = tl.load(V_ptr + V_offset, cache_modifier=".cg") - # else: - # kvmask = col_idx[None, :] < seqlen_k - # K = tl.load(K_ptr + K_offset, mask=kvmask, cache_modifier=".cg") - # if PRE_LOAD_V: - # V = tl.load(V_ptr + V_offset, mask=kvmask, cache_modifier=".cg") - # S = tl.dot(Q, K, allow_tf32=False) - # S = apply_mask( - # S, - # col_idx, - # row_idx, - # seqlen_q, - # seqlen_k, - # ws_left, - # ws_right, - # alibi_slope, - # is_even_mn=IS_EVEN_MN, - # is_causal=is_causal, - # is_local=is_local, - # has_alibi=has_alibi - # ) - # col_idx -= BLOCK_N - - # O_acc, P, rowmax_, rowsum_ = softmax_rescale( - # O_, - # S, - # rowmax_, - # rowsum_, - # softmax_scale_log2=softmax_scale_log2, - # is_border=(is_causal or is_local) - # ) - # P = P.to(O_ptr.type.element_ty) - - # row_start = m_block * (BLOCK_M // 16) - # col_start = n_block * (BLOCK_N // 32) - # if return_P: - # P_drop = P - # P_drop = apply_dropout( - # P_drop, - # row_start, - # col_start, - # bid, - # hid, - # philox_seed, - # philox_offset, - # pdrop_u8, - # encode_dropout_in_sign_bit=True, - # NUM_HEADS=NUM_HEADS, - # BLOCK_M=BLOCK_M, - # BLOCK_N=BLOCK_N, - # ) - # tl.store(P_ptr + P_offset, P_drop, mask=qmask & kvmask) - # P_offset += BLOCK_N - - # if is_dropout: - # P = apply_dropout( - # P, - # row_start, - # col_start, - # bid, - # hid, - # philox_seed, - # philox_offset, - # pdrop_u8, - # encode_dropout_in_sign_bit=False, - # NUM_HEADS=NUM_HEADS, - # BLOCK_M=BLOCK_M, - # BLOCK_N=BLOCK_N, - # ) - - # if not PRE_LOAD_V: - # if IS_EVEN_MN: - # V = tl.load(V_ptr + V_offset, cache_modifier=".cg") - # else: - # V = tl.load(V_ptr + V_offset, mask=kvmask, cache_modifier=".cg") - # tl.dot(P, V, acc=O_, allow_tf32=False) - - # n_block -= 1 - # # if n_masking_steps > 1 and n_block <= n_block_min: - # # break - - - # for n_block in tl.range(n_block, n_block_min - 1, num_stages=num_stages): - if True: + for n_block in tl.range(n_block_max - 1, n_block_max - n_masking_steps - 1, step=-1): + col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N) + K_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None] + V_offset = col_idx[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :] + if IS_EVEN_MN: + K = tl.load(K_ptr + K_offset, cache_modifier=".cg") + if PRE_LOAD_V: + V = tl.load(V_ptr + V_offset, cache_modifier=".cg") + else: + kvmask = col_idx < seqlen_k + K = tl.load(K_ptr + K_offset, mask=kvmask[None, :], cache_modifier=".cg") + if PRE_LOAD_V: + V = tl.load(V_ptr + V_offset, mask=kvmask[:, None], cache_modifier=".cg") + S = tl.dot(Q, K, allow_tf32=False) + S = apply_mask( + S, + col_idx, + row_idx, + seqlen_q, + seqlen_k, + ws_left, + ws_right, + alibi_slope, + is_even_mn=IS_EVEN_MN, + is_causal=is_causal, + is_local=is_local, + has_alibi=has_alibi + ) + # col_idx -= BLOCK_N + + is_init = (n_block == n_block_max - 1).to(tl.int1) + O_, P, rowmax_, rowsum_ = softmax_rescale( + O_, + S, + rowmax_, + rowsum_, + softmax_scale_log2=softmax_scale_log2, + is_border=(is_causal or is_local), + is_init=is_init + ) + P = P.to(O_ptr.type.element_ty) + + row_start = m_block * (BLOCK_M // 16) + col_start = n_block * (BLOCK_N // 32) + if return_P: + P_drop = P + P_drop = apply_dropout( + P_drop, + row_start, + col_start, + bid, + hid, + philox_seed, + philox_offset, + pdrop_u8, + encode_dropout_in_sign_bit=True, + 
NUM_HEADS=NUM_HEADS, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + tl.store(P_ptr + P_offset, P_drop, mask=qmask & kvmask[None, :]) + P_offset += BLOCK_N + + if is_dropout: + P = apply_dropout( + P, + row_start, + col_start, + bid, + hid, + philox_seed, + philox_offset, + pdrop_u8, + encode_dropout_in_sign_bit=False, + NUM_HEADS=NUM_HEADS, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + + if not PRE_LOAD_V: + if IS_EVEN_MN: + V = tl.load(V_ptr + V_offset, cache_modifier=".cg") + else: + V = tl.load(V_ptr + V_offset, mask=kvmask[:, None], cache_modifier=".cg") + O_ = tl.dot(P, V, O_, allow_tf32=False) + + # if n_masking_steps > 1 and n_block <= n_block_min: + # break + + + for n_block in tl.range(n_block_max - n_masking_steps - 1, n_block_min - 1, step=-1, num_stages=num_stages): + col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N) K_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None] K = tl.load(K_ptr + K_offset, cache_modifier=".cg") if PRE_LOAD_V: @@ -825,15 +836,17 @@ def flash_fwd_kernel( is_local=is_local, has_alibi=has_alibi ) - col_idx -= BLOCK_N + # col_idx -= BLOCK_N + is_init = (n_block == n_block_max - 1).to(tl.int1) O_, P, rowmax_, rowsum_ = softmax_rescale( O_, S, rowmax_, rowsum_, softmax_scale_log2=softmax_scale_log2, - is_border=is_local + is_border=is_local, + is_init=is_init ) P = P.to(O_ptr.type.element_ty) @@ -879,17 +892,17 @@ def flash_fwd_kernel( V_offset = col_idx[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :] V = tl.load(V_ptr + V_offset, cache_modifier=".cg") - tl.dot(P, V, acc=O_) + O_ = tl.dot(P, V, O_) # Final LSE lse = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), float('inf'), rowmax_ * softmax_scale + tl.log(rowsum_)) inv_sum = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), 1.0, 1.0 / rowsum_) - # # Rescale output - # if is_dropout: - # O_ *= inv_sum[:, None] * rpdrop - # else: - # O_ *= inv_sum[:, None] + # Rescale output + if is_dropout: + O_ *= inv_sum[:, None] * rpdrop + else: + O_ *= inv_sum[:, None] O = O_.to(O_ptr.type.element_ty) @@ -908,7 +921,7 @@ def flash_fwd_kernel( if IS_EVEN_MN: tl.store(lse_ptr + row_idx, lse) else: - tl.store(lse_ptr + row_idx, lse, mask=qmask) + tl.store(lse_ptr + row_idx, lse, mask=row_idx < seqlen_q) def mha_fwd( From afc6108b599ea8d39a830638e2c1328a6632ba97 Mon Sep 17 00:00:00 2001 From: Tongxin Bai <waffle.bai@gmail.com> Date: Tue, 11 Mar 2025 11:32:33 +0000 Subject: [PATCH 10/25] added splitkv, perf still lags. 
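Adds a split-KV path for shapes where the regular grid leaves the SMs underused: the key blocks are partitioned across a second grid axis, each split accumulates a partial output and partial lse into (n_splits, batch, heads, seqlen_q, ...) scratch tensors, and flash_fwd_splitkv_combine_kernel reduces them. splits_heuristics picks the split count from the wave count and occupancy of the plain launch and falls back to 1 when splitting would not help; dropout keeps the single-kernel path. Perf still lags the non-split kernel, hence the subject.

The combine is a log-sum-exp merge over splits. A NumPy reference of what the combine kernel computes per query row (illustrative; assumes each split's output is already normalized by its own row sum, which is what the split kernel stores):

    import numpy as np

    def combine_splits(o_splits, lse_splits):
        # o_splits: (n_splits, M, D), lse_splits: (n_splits, M)
        max_lse = lse_splits.max(axis=0)
        z = np.exp(lse_splits - max_lse)          # (n_splits, M)
        lse = np.log(z.sum(axis=0)) + max_lse
        w = z / z.sum(axis=0)                     # per-split softmax weights
        out = (w[:, :, None] * o_splits).sum(axis=0)
        return out, lse

Weighting split i by exp(lse_i - lse) recovers the softmax over the full key range, because each partial output is off from the true output only by the ratio of its local normalizer to the global one.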
--- src/flag_gems/__init__.py | 5 + src/flag_gems/ops/__init__.py | 3 +- src/flag_gems/ops/attention.py | 468 ++++++++++++++++++++++++++++++--- 3 files changed, 444 insertions(+), 32 deletions(-) diff --git a/src/flag_gems/__init__.py b/src/flag_gems/__init__.py index 8cf202e3e..85cc04941 100644 --- a/src/flag_gems/__init__.py +++ b/src/flag_gems/__init__.py @@ -149,6 +149,11 @@ def enable(lib=aten_lib, unused=None, registrar=registrar): # ("prod.dim_int", prod_dim, Autograd.disable), # ("sum", sum, Autograd.disable), # ("sum.dim_IntList", sum_dim, Autograd.disable), + # ( + # "scaled_dot_product_attention", + # scaled_dot_product_attention, + # Autograd.disable, + # ), ( "_flash_attention_forward", flash_attention_forward, diff --git a/src/flag_gems/ops/__init__.py b/src/flag_gems/ops/__init__.py index c4051d6d8..2802e11f6 100755 --- a/src/flag_gems/ops/__init__.py +++ b/src/flag_gems/ops/__init__.py @@ -7,7 +7,7 @@ from .arange import arange, arange_start from .argmax import argmax from .argmin import argmin -from .attention import flash_attention_forward +from .attention import flash_attention_forward, scaled_dot_product_attention from .batch_norm import batch_norm from .bitwise_and import ( bitwise_and_scalar, @@ -293,6 +293,7 @@ "vstack", "repeat_interleave_tensor", "flash_attention_forward", + "scaled_dot_product_attention", "conv2d", "conv1d", "_conv_depthwise2d", diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py index 26efbf69f..6a917b104 100644 --- a/src/flag_gems/ops/attention.py +++ b/src/flag_gems/ops/attention.py @@ -543,10 +543,7 @@ def apply_mask( is_local: tl.constexpr, has_alibi: tl.constexpr, ): - # need_mask = has_alibi or is_causal - # need_mask |= is_local - # need_mask |= not is_even_mn - need_mask: tl.constexpr = has_alibi | is_local | (not is_even_mn) + need_mask: tl.constexpr = is_causal | has_alibi | is_local | (not is_even_mn) if need_mask: col_lb = max(0, row_idx + max_seqlen_k - max_seqlen_q - ws_left) col_rb = min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + ws_right) @@ -574,7 +571,7 @@ def softmax_rescale( S, row_max, row_sum, - softmax_scale_log2: tl.constexpr, + softmax_scale_log2e: tl.constexpr, is_border: tl.constexpr, is_init: tl.constexpr ): @@ -586,16 +583,33 @@ def softmax_rescale( cur_max = tl.where(row_max == float('-inf'), 0, row_max) else: cur_max = row_max - p_scale = tl.math.exp2((prev_max - cur_max) * softmax_scale_log2) + p_scale = tl.math.exp2((prev_max - cur_max) * softmax_scale_log2e) row_sum *= p_scale O_acc *= p_scale[:, None] - max_scaled = tl.where(row_max == float('-inf'), 0, row_max * softmax_scale_log2) - P = tl.math.exp2(S * softmax_scale_log2 - max_scaled[:, None]) + max_scaled = tl.where(row_max == float('-inf'), 0, row_max * softmax_scale_log2e) + P = tl.math.exp2(S * softmax_scale_log2e - max_scaled[:, None]) row_sum = row_sum + tl.sum(P, 1) return O_acc, P, row_max, row_sum +def block_m_heuristic(headdim, is_dropout): + return 64 if headdim <= 128 else 32 + +def block_n_heuristic(headdim, is_dropout): + return 64 if headdim <= 128 else 32 + +def block_m_splitkv_heuristic(headdim): + return 64 if headdim <= 128 else 32 + +def block_n_splitkv_heuristic(headdim): + if headdim <= 64: + return 128 + elif headdim <= 128: + return 64 + else: + return 32 + # @triton.autotune( # configs=runtime.get_tuned_config("attention"), # key=["HEAD_DIM"], @@ -607,11 +621,11 @@ def softmax_rescale( # ) @triton.heuristics( values={ - 'BLOCK_M': lambda args: 64, - 'BLOCK_N': lambda args: 64, + 'BLOCK_M': 
lambda args: block_m_heuristic(args["HEAD_DIM"], args["is_dropout"]), + 'BLOCK_N': lambda args: block_n_heuristic(args["HEAD_DIM"], args["is_dropout"]), 'num_warps': lambda args: 4, - 'num_stages': lambda args: 2, - 'PRE_LOAD_V': lambda args: False, + 'num_stages': lambda args: 3, + 'PRE_LOAD_V': lambda args: True, 'IS_EVEN_MN': lambda args: (args["seqlen_q"] % args["BLOCK_M"] == 0) and (args["seqlen_k"] % args["BLOCK_N"] == 0), } ) @@ -650,7 +664,7 @@ def flash_fwd_kernel( is_local: tl.constexpr, has_alibi: tl.constexpr, softmax_scale: tl.constexpr, - softmax_scale_log2: tl.constexpr, + softmax_scale_log2e: tl.constexpr, ws_left: tl.constexpr, ws_right: tl.constexpr, return_P: tl.constexpr, @@ -661,6 +675,7 @@ def flash_fwd_kernel( IS_EVEN_MN: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, + num_splits: tl.constexpr, num_warps: tl.constexpr, num_stages: tl.constexpr ): @@ -677,7 +692,7 @@ def flash_fwd_kernel( if is_causal or is_local: n_block_max = min(n_block_max, - tl.cdiv((m_block + 1) * BLOCK_M + seqlen_k - seqlen_q + window_size_right, BLOCK_N)) + tl.cdiv((m_block + 1) * BLOCK_M + seqlen_k - seqlen_q + ws_right, BLOCK_N)) if has_alibi: alibi_offset = bid * slopes_batch_stride + hid @@ -691,7 +706,7 @@ def flash_fwd_kernel( n_masking_steps = 0 else: n_masking_steps = 1 - elif is_causal and IS_EVEN_MN: # causal implies window_size_right is zero + elif is_causal and IS_EVEN_MN: # causal implies ws_right is zero n_masking_steps = tl.cdiv(BLOCK_M, BLOCK_N) else: # local and not causal, @@ -760,7 +775,7 @@ def flash_fwd_kernel( S, rowmax_, rowsum_, - softmax_scale_log2=softmax_scale_log2, + softmax_scale_log2e=softmax_scale_log2e, is_border=(is_causal or is_local), is_init=is_init ) @@ -818,10 +833,15 @@ def flash_fwd_kernel( col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N) K_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None] K = tl.load(K_ptr + K_offset, cache_modifier=".cg") + # if PRE_LOAD_V: + # V_offset = col_idx[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :] + # V = tl.load(V_ptr + V_offset, cache_modifier=".cg") + S = tl.dot(Q, K) + if PRE_LOAD_V: V_offset = col_idx[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :] V = tl.load(V_ptr + V_offset, cache_modifier=".cg") - S = tl.dot(Q, K) + S = apply_mask( S, col_idx, @@ -844,7 +864,7 @@ def flash_fwd_kernel( S, rowmax_, rowsum_, - softmax_scale_log2=softmax_scale_log2, + softmax_scale_log2e=softmax_scale_log2e, is_border=is_local, is_init=is_init ) @@ -894,7 +914,8 @@ def flash_fwd_kernel( O_ = tl.dot(P, V, O_) - # Final LSE + # LSE + # Note, rowsum = exp(-rowmax) * lse, therefore rowmax + log(rowsum) cancels the effect of rowmax and outputs lse only. 
lse = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), float('inf'), rowmax_ * softmax_scale + tl.log(rowsum_)) inv_sum = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), 1.0, 1.0 / rowsum_) @@ -924,6 +945,312 @@ def flash_fwd_kernel( tl.store(lse_ptr + row_idx, lse, mask=row_idx < seqlen_q) +@triton.heuristics( + values={ + 'BLOCK_M': lambda args: block_m_splitkv_heuristic(args["HEAD_DIM"]), + 'BLOCK_N': lambda args: block_m_splitkv_heuristic(args["HEAD_DIM"]), + 'num_warps': lambda args: 4, + 'num_stages': lambda args: 3, + 'PRE_LOAD_V': lambda args: True, + 'IS_EVEN_MN': lambda args: (args["seqlen_q"] % args["BLOCK_M"] == 0) and (args["seqlen_k"] % args["BLOCK_N"] == 0), + } +) +@triton.jit(do_not_specialize=["seqlen_q", "seqlen_k", "q_b_stride", "q_s_stride", "q_h_stride", "k_b_stride", "k_s_stride", "k_h_stride", "o_b_stride", "o_h_stride", "o_s_stride", "philox_seed", "philox_offset", "pdrop_u8", "slopes_batch_stride"]) +def flash_fwd_splitkv_kernel( + Q_ptr, + K_ptr, + V_ptr, + P_ptr, + O_ptr, + lse_ptr, + seqlen_q, + seqlen_k, + seqlen_q_rounded, + seqlen_k_rounded, + q_b_stride, + q_s_stride, + q_h_stride, + k_b_stride, + k_s_stride, + k_h_stride, + o_b_stride, + o_s_stride, + o_h_stride, + h, + hk, + pSlopes, + philox_seed, + philox_offset, + pdrop_u8, + rpdrop, + slopes_batch_stride, + HEAD_DIM: tl.constexpr, + is_dropout: tl.constexpr, + is_causal: tl.constexpr, + is_local: tl.constexpr, + has_alibi: tl.constexpr, + softmax_scale: tl.constexpr, + softmax_scale_log2e: tl.constexpr, + ws_left: tl.constexpr, + ws_right: tl.constexpr, + return_P: tl.constexpr, + PRE_LOAD_V: tl.constexpr, + BATCH_SIZE: tl.constexpr, + NUM_HEADS: tl.constexpr, + NUM_HEADS_K: tl.constexpr, + IS_EVEN_MN: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + num_splits: tl.constexpr, + num_warps: tl.constexpr, + num_stages: tl.constexpr +): + m_block = tl.program_id(0) + split_id = tl.program_id(1) + bid = tl.program_id(2) // NUM_HEADS + hid = tl.program_id(2) % NUM_HEADS + + blocks_per_split = tl.cdiv(tl.cdiv(seqlen_k, BLOCK_N), num_splits) + + if is_local: + n_block_min = max(split_id * blocks_per_split, (m_block * BLOCK_M + seqlen_k - seqlen_q - ws_left) / BLOCK_N) + else: + n_block_min = split_id * blocks_per_split + + n_block_max = min((split_id + 1) * blocks_per_split, tl.cdiv(seqlen_k, BLOCK_N)) + if is_causal or is_local: + n_block_max = min(n_block_max, + tl.cdiv((m_block + 1) * BLOCK_M + seqlen_k - seqlen_q + ws_right, BLOCK_N)) + + if has_alibi: + alibi_offset = bid * slopes_batch_stride + hid + alibi_slope = tl.load(pSlopes + alibi_offset) + alibi_slope /= scale + else: + alibi_slope = 0.0 + + if (not is_causal) and (not is_local): + if IS_EVEN_MN: + n_masking_steps = 0 + else: + n_masking_steps = 1 + elif is_causal and IS_EVEN_MN: # causal implies ws_right is zero + n_masking_steps = tl.cdiv(BLOCK_M, BLOCK_N) + else: + # local and not causal, + n_masking_steps = tl.cdiv(BLOCK_M, BLOCK_N) + 1 + + Q_ptr += bid * q_b_stride + Q_ptr += hid * q_h_stride + row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M) + Q_off = row_idx[:, None] * q_s_stride + tl.arange(0, HEAD_DIM)[None, :] + qmask = row_idx[:, None] < seqlen_q + if IS_EVEN_MN: + Q = tl.load(Q_ptr + Q_off) + else: + Q = tl.load(Q_ptr + Q_off, mask=qmask) + + # Start from the right most block + n_block = n_block_max - 1 + + h_hk_ratio = h // hk + K_ptr += bid * k_b_stride + K_ptr += (hid // h_hk_ratio) * k_h_stride + V_ptr += bid * k_b_stride + V_ptr += (hid // h_hk_ratio) * k_h_stride + + P_ptr += ((bid * NUM_HEADS + hid) * 
seqlen_q_rounded + m_block * BLOCK_M) * seqlen_k_rounded + P_ptr += n_block * BLOCK_N + P_offset = tl.arange(0, BLOCK_M)[:, None] * seqlen_k_rounded + tl.arange(0, BLOCK_N) + + O_ = tl.zeros((BLOCK_M, HEAD_DIM), dtype=tl.float32) + rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32) + + for n_block in tl.range(n_block_max - 1, n_block_max - n_masking_steps - 1, step=-1): + col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N) + K_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None] + V_offset = col_idx[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :] + if IS_EVEN_MN: + K = tl.load(K_ptr + K_offset, cache_modifier=".cg") + if PRE_LOAD_V: + V = tl.load(V_ptr + V_offset, cache_modifier=".cg") + else: + kvmask = col_idx < seqlen_k + K = tl.load(K_ptr + K_offset, mask=kvmask[None, :], cache_modifier=".cg") + if PRE_LOAD_V: + V = tl.load(V_ptr + V_offset, mask=kvmask[:, None], cache_modifier=".cg") + S = tl.dot(Q, K, allow_tf32=False) + S = apply_mask( + S, + col_idx, + row_idx, + seqlen_q, + seqlen_k, + ws_left, + ws_right, + alibi_slope, + is_even_mn=IS_EVEN_MN, + is_causal=is_causal, + is_local=is_local, + has_alibi=has_alibi + ) + # col_idx -= BLOCK_N + + is_init = (n_block == n_block_max - 1).to(tl.int1) + O_, P, rowmax_, rowsum_ = softmax_rescale( + O_, + S, + rowmax_, + rowsum_, + softmax_scale_log2e=softmax_scale_log2e, + is_border=(is_causal or is_local), + is_init=is_init + ) + P = P.to(Q_ptr.type.element_ty) + + if not PRE_LOAD_V: + if IS_EVEN_MN: + V = tl.load(V_ptr + V_offset, cache_modifier=".cg") + else: + V = tl.load(V_ptr + V_offset, mask=kvmask[:, None], cache_modifier=".cg") + O_ = tl.dot(P, V, O_, allow_tf32=False) + # if n_masking_steps > 1 and n_block <= n_block_min: + # break + + for n_block in tl.range(n_block_max - n_masking_steps - 1, n_block_min - 1, step=-1, num_stages=num_stages): + col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N) + K_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None] + K = tl.load(K_ptr + K_offset, cache_modifier=".cg") + # if PRE_LOAD_V: + # V_offset = col_idx[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :] + # V = tl.load(V_ptr + V_offset, cache_modifier=".cg") + S = tl.dot(Q, K) + + if PRE_LOAD_V: + V_offset = col_idx[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :] + V = tl.load(V_ptr + V_offset, cache_modifier=".cg") + + S = apply_mask( + S, + col_idx, + row_idx, + seqlen_q, + seqlen_k, + ws_left, + ws_right, + alibi_slope, + is_even_mn=True, + is_causal=False, + is_local=is_local, + has_alibi=has_alibi + ) + # col_idx -= BLOCK_N + + is_init = (n_block == n_block_max - 1).to(tl.int1) + O_, P, rowmax_, rowsum_ = softmax_rescale( + O_, + S, + rowmax_, + rowsum_, + softmax_scale_log2e=softmax_scale_log2e, + is_border=is_local, + is_init=is_init + ) + + P = P.to(Q_ptr.type.element_ty) + + if not PRE_LOAD_V: + V_offset = col_idx[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :] + V = tl.load(V_ptr + V_offset, cache_modifier=".cg") + + O_ = tl.dot(P, V, O_) + + # LSE + lse = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), float('inf'), rowmax_ * softmax_scale + tl.log(rowsum_)) + inv_sum = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), 1.0, 1.0 / rowsum_) + + # Rescale output + O_ *= inv_sum[:, None] + + # Write back output + O_split_ptr = O_ptr + # (n_splits, batch_size, num_heads, seqlen_q, head_size) + # grid = (seq_block, split, batch * head) + O_split_ptr += (split_id * tl.num_programs(2) + tl.program_id(2)) * seqlen_q 
* HEAD_DIM + O_split_offset = row_idx[:, None] * HEAD_DIM + tl.arange(0, HEAD_DIM) + + if IS_EVEN_MN: + tl.store(O_split_ptr + O_split_offset, O_) + else: + tl.store(O_split_ptr + O_split_offset, O_, mask=qmask) + + # Write back lse + lse_split_ptr = lse_ptr + lse_split_ptr += (split_id * tl.num_programs(2) + tl.program_id(2)) * seqlen_q + lse_split_ptr += m_block * BLOCK_M + + if IS_EVEN_MN: + tl.store(lse_split_ptr + tl.arange(0, BLOCK_M), lse) + else: + tl.store(lse_split_ptr + tl.arange(0, BLOCK_M), lse, mask=row_idx < seqlen_q) + + +@triton.jit +def flash_fwd_splitkv_combine_kernel( + out_ptr, + lse_ptr, + out_splits_ptr, + lse_splits_ptr, + head_size: tl.constexpr, + out_b_stride, + out_s_stride, + out_h_stride, + n_splits, + BLOCK_M: tl.constexpr, + q_total, + MAX_N_SPLITS: tl.constexpr, +): + pid = tl.program_id(0) + lse_splits_ptr += pid * BLOCK_M + lse_ptr += pid * BLOCK_M + out_splits_ptr += pid * BLOCK_M * head_size + out_ptr += pid * BLOCK_M * head_size + lse_split_stride = tl.num_programs(0) * BLOCK_M + out_split_stride = tl.num_programs(0) * BLOCK_M * head_size + + # Subtracting maximum from each of the split lse's for better numerical stability + lse_split_offset = tl.arange(0, BLOCK_M)[:, None] + tl.arange(0, MAX_N_SPLITS)[None, :] * lse_split_stride + lse_split_mask = (pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None] < q_total) & (tl.arange(0, MAX_N_SPLITS)[None, :] < n_splits) + lse_splits = tl.load(lse_splits_ptr + lse_split_offset, mask=lse_split_mask, other=float('-inf')) + max_lse = tl.max(lse_splits, 1) + + # Sum exp(lse(i) - max_lse) over all split i to obtain Z=sumexp(QK) up to a scaled factor exp(-max_lse) + Zi_scaled = tl.exp(lse_splits - max_lse[:, None]) + Z_scaled = tl.sum(Zi_scaled, 1) + Zi_Z = Zi_scaled / Z_scaled[:, None] + + # Write back LSE + lse = tl.log(Z_scaled) + max_lse + out_mask = pid * BLOCK_M + tl.arange(0, BLOCK_M) < q_total + tl.store(lse_ptr + tl.arange(0, BLOCK_M), lse, mask=out_mask) + + out_split_offset = ( + tl.arange(0, BLOCK_M)[:, None, None] * head_size + + tl.arange(0, MAX_N_SPLITS)[None, :, None] * out_split_stride + + tl.arange(0, head_size)[None, None, :] + ) + out_split_mask = (pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None, None] < q_total) & (tl.arange(0, MAX_N_SPLITS)[None, :, None] < n_splits) + out_splits = tl.load(out_splits_ptr + out_split_offset, mask=out_split_mask, other=0) + out = tl.sum(Zi_Z[:, :, None] * out_splits, 1) + out = out.to(out_ptr.type.element_ty) + + # tl.device_print('O', out) + # Write back output + out_offset = tl.arange(0, BLOCK_M)[:, None] * out_s_stride + tl.arange(0, head_size) + tl.store(out_ptr + out_offset, out, mask=out_mask[:, None]) + + def mha_fwd( q, k, @@ -988,6 +1315,26 @@ def mha_fwd( seqlen_q_rounded = round_multiple(seqlen_q, 128) seqlen_k_rounded = round_multiple(seqlen_k, 128) + def splits_heuristics(num_tasks, num_sms, n_blocks): + # splits only number of waves and wave efficiency are both low + n_waves = triton.cdiv(num_tasks, num_sms) + eff = (num_tasks / num_sms) / n_waves + if eff > 0.85 or n_waves > 10: + return 1 + max_eff = eff + best_splits = 1 + for w in range(n_waves, 10): + n_splits = min(num_sms, n_blocks, w * num_sms // num_tasks) + blocks_per_split = triton.cdiv(n_blocks, n_splits) + if blocks_per_split < 4: + continue + n_splits = triton.cdiv(n_blocks, blocks_per_split) + eff = (n_splits * num_tasks / num_sms) / w + if eff > max_eff: + max_eff = eff + best_splits = n_splits + return best_splits + with torch_device_fn.device(q_device): # Set softmax params lse = 
torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float, device=q_device) @@ -1014,9 +1361,31 @@ def mha_fwd( pdrop_u8 = math.floor(p_dropout * 255.0) rpdrop = 1. / p_dropout + # Check splitkv + if not is_dropout: + n_tasks = batch_size * num_heads * triton.cdiv(seqlen_q, block_m_splitkv_heuristic(head_size)) + num_sms = torch_device_fn.get_device_properties("cuda").multi_processor_count + n_blocks = triton.cdiv(seqlen_k, block_n_splitkv_heuristic(head_size)) + n_splits = splits_heuristics(n_tasks, num_sms, n_blocks) + print('n_blocks:', n_blocks) + print('n_splits:', n_splits) + else: + n_splits = 1 + + if n_splits > 1: + lse_splits = torch.empty( + (n_splits, batch_size, num_heads, seqlen_q), + dtype=torch.float, + device=q_device + ) + out_splits = torch.empty( + (n_splits, batch_size, num_heads, seqlen_q, head_size), + dtype=torch.float, + device=q_device + ) + M_LOG2E = 1.4426950408889634074 - scale_softmax_log2 = softmax_scale * M_LOG2E - scale_softmax_rp_dropout = rpdrop * softmax_scale + softmax_scale_log2e = softmax_scale * M_LOG2E # Set alibi params if alibi_slopes is not None: @@ -1036,20 +1405,33 @@ def mha_fwd( # ONLY EVEN_K IS SUPPORTED assert head_size == head_size_rounded + # Launch kernel + if n_splits > 1: + grid = lambda args: ( + triton.cdiv(seqlen_q, args["BLOCK_M"]), + n_splits, + batch_size * num_heads + ) + kernel = flash_fwd_splitkv_kernel[grid] + tmp_lse = lse_splits + tmp_out = out_splits + else: + grid = lambda args: ( + triton.cdiv(seqlen_q, args["BLOCK_M"]), # num_m_blocks + batch_size, + num_heads, + ) + kernel = flash_fwd_kernel[grid] + tmp_lse = lse + tmp_out = out - grid = lambda args: ( - triton.cdiv(seqlen_q, args["BLOCK_M"]), # num_m_blocks - batch_size, - num_heads, - ) - - flash_fwd_kernel[grid]( + kernel( q, k, v, p, - out, - lse, + tmp_out, + tmp_lse, seqlen_q, seqlen_k, seqlen_q_rounded, @@ -1077,15 +1459,39 @@ def mha_fwd( is_local=is_local, has_alibi=has_alibi, softmax_scale=softmax_scale, - softmax_scale_log2=scale_softmax_log2, + softmax_scale_log2e=softmax_scale_log2e, ws_left=window_size_left, ws_right=window_size_right, return_P=return_softmax, BATCH_SIZE=batch_size, + num_splits=n_splits, NUM_HEADS=num_heads, NUM_HEADS_K=num_heads_k, ) + if n_splits > 1: + if head_size % 128 == 0: + BLOCK_M = 4 + elif head_size % 64 == 0: + BLOCK_M = 8 + else: + BLOCK_M = 16 + grid = lambda args: (triton.cdiv(batch_size * num_heads * seqlen_q, BLOCK_M), ) + flash_fwd_splitkv_combine_kernel[grid]( + out, + lse, + tmp_out, + tmp_lse, + head_size, + out.stride(0), + out.stride(-3), + out.stride(-1), + n_splits, + BLOCK_M, + q_total=batch_size * num_heads * seqlen_q, + MAX_N_SPLITS=triton.next_power_of_2(n_splits), + ) + if swap_seq_and_group: out = out.transpose(1, 2).reshape((batch_size, 1, num_heads_k * seqlen_q, head_size)) q = q.transpose(1, 2).reshape((batch_size, 1, num_heads_k * seqlen_q, head_size)) From 88f33d4138a29d3ad01ce55da27a8e0303c4799f Mon Sep 17 00:00:00 2001 From: Tongxin Bai <waffle.bai@gmail.com> Date: Tue, 18 Mar 2025 03:51:45 +0000 Subject: [PATCH 11/25] nuked dynamic cf in loop, but still failed pipelining. 
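Attempt at getting the main K/V loop to software-pipeline: the remaining data-dependent control flow in the loop body is removed so the trip count and the per-iteration path are static, but Triton still refuses to pipeline the loop. The mass commenting-out of src/flag_gems/ops/__init__.py keeps only the attention ops imported for now, to shorten the dev loop while iterating on this kernel.

The kind of restructuring involved, in sketch form (illustrative fragments drawn from earlier revisions, not this commit's literal diff):

    # pipelining-hostile: the exit depends on values updated in the loop
    for _ in range(n_masking_steps):
        ...
        n_block -= 1
        if n_masking_steps > 1 and n_block <= n_block_min:
            break

    # pipelining-friendly: static bounds, no early exit
    for n_block in tl.range(n_block_max - 1, n_block_min - 1, step=-1, num_stages=num_stages):
        ...

Triton's pipeliner generally gives up on loops with data-dependent breaks or branches, so num_stages only takes effect once the loop body is straight-line.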
--- src/flag_gems/ops/__init__.py | 614 ++++++++++++++++----------------- src/flag_gems/ops/attention.py | 267 +++++++------- 2 files changed, 455 insertions(+), 426 deletions(-) diff --git a/src/flag_gems/ops/__init__.py b/src/flag_gems/ops/__init__.py index 2802e11f6..01a31e633 100755 --- a/src/flag_gems/ops/__init__.py +++ b/src/flag_gems/ops/__init__.py @@ -1,312 +1,312 @@ -from .abs import abs -from .add import add -from .addmm import addmm -from .all import all, all_dim, all_dims -from .amax import amax -from .any import any, any_dim, any_dims -from .arange import arange, arange_start -from .argmax import argmax -from .argmin import argmin +# from .abs import abs +# from .add import add +# from .addmm import addmm +# from .all import all, all_dim, all_dims +# from .amax import amax +# from .any import any, any_dim, any_dims +# from .arange import arange, arange_start +# from .argmax import argmax +# from .argmin import argmin from .attention import flash_attention_forward, scaled_dot_product_attention -from .batch_norm import batch_norm -from .bitwise_and import ( - bitwise_and_scalar, - bitwise_and_scalar_tensor, - bitwise_and_tensor, -) -from .bitwise_not import bitwise_not -from .bitwise_or import bitwise_or_scalar, bitwise_or_scalar_tensor, bitwise_or_tensor -from .bmm import bmm -from .cat import cat -from .clamp import clamp, clamp_tensor -from .conv1d import conv1d -from .conv2d import conv2d -from .conv_depthwise2d import _conv_depthwise2d -from .cos import cos -from .count_nonzero import count_nonzero -from .cross_entropy_loss import cross_entropy_loss -from .cummin import cummin -from .cumsum import cumsum, normed_cumsum -from .diag import diag -from .diag_embed import diag_embed -from .diagonal import diagonal_backward -from .div import div_mode, floor_divide, remainder, true_divide -from .dropout import native_dropout -from .embedding import embedding -from .eq import eq, eq_scalar -from .erf import erf -from .exp import exp -from .exponential_ import exponential_ -from .fill import fill_scalar, fill_tensor -from .flip import flip -from .full import full -from .full_like import full_like -from .gather import gather, gather_backward -from .ge import ge, ge_scalar -from .gelu import gelu -from .groupnorm import group_norm -from .gt import gt, gt_scalar -from .hstack import hstack -from .index_add import index_add -from .index_select import index_select -from .instancenorm import instance_norm -from .isclose import allclose, isclose -from .isfinite import isfinite -from .isin import isin -from .isinf import isinf -from .isnan import isnan -from .layernorm import layer_norm -from .le import le, le_scalar -from .log_sigmoid import log_sigmoid -from .log_softmax import log_softmax -from .logical_and import logical_and -from .logical_not import logical_not -from .logical_or import logical_or -from .logical_xor import logical_xor -from .lt import lt, lt_scalar -from .masked_fill import masked_fill, masked_fill_ -from .masked_select import masked_select -from .max import max, max_dim -from .maximum import maximum -from .mean import mean, mean_dim -from .min import min, min_dim -from .minimum import minimum -from .mm import mm -from .mse_loss import mse_loss -from .mul import mul -from .multinomial import multinomial -from .mv import mv -from .ne import ne, ne_scalar -from .neg import neg -from .nllloss import ( - nll_loss2d_backward, - nll_loss2d_forward, - nll_loss_backward, - nll_loss_forward, -) -from .nonzero import nonzero -from .normal import normal_float_tensor, 
normal_tensor_float, normal_tensor_tensor -from .ones import ones -from .ones_like import ones_like -from .outer import outer -from .pad import constant_pad_nd, pad -from .pow import pow_scalar, pow_tensor_scalar, pow_tensor_tensor -from .prod import prod, prod_dim -from .quantile import quantile -from .rand import rand -from .rand_like import rand_like -from .randn import randn -from .randn_like import randn_like -from .randperm import randperm -from .reciprocal import reciprocal -from .relu import relu -from .repeat import repeat -from .repeat_interleave import ( - repeat_interleave_self_int, - repeat_interleave_self_tensor, - repeat_interleave_tensor, -) -from .resolve_conj import resolve_conj -from .resolve_neg import resolve_neg -from .rms_norm import rms_norm -from .rsqrt import rsqrt -from .scatter import scatter -from .select_scatter import select_scatter -from .sigmoid import sigmoid -from .silu import silu -from .sin import sin -from .slice_scatter import slice_scatter -from .softmax import softmax -from .sort import sort -from .stack import stack -from .sub import sub -from .sum import sum, sum_dim -from .tanh import tanh -from .tile import tile -from .topk import topk -from .triu import triu -from .uniform import uniform_ -from .unique import _unique2 -from .upsample_bicubic2d_aa import _upsample_bicubic2d_aa -from .upsample_nearest2d import upsample_nearest2d -from .var_mean import var_mean -from .vdot import vdot -from .vector_norm import vector_norm -from .vstack import vstack -from .weightnorm import weight_norm, weight_norm_interface -from .where import where_scalar_other, where_scalar_self, where_self, where_self_out -from .zeros import zeros -from .zeros_like import zeros_like +# from .batch_norm import batch_norm +# from .bitwise_and import ( +# bitwise_and_scalar, +# bitwise_and_scalar_tensor, +# bitwise_and_tensor, +# ) +# from .bitwise_not import bitwise_not +# from .bitwise_or import bitwise_or_scalar, bitwise_or_scalar_tensor, bitwise_or_tensor +# from .bmm import bmm +# from .cat import cat +# from .clamp import clamp, clamp_tensor +# from .conv1d import conv1d +# from .conv2d import conv2d +# from .conv_depthwise2d import _conv_depthwise2d +# from .cos import cos +# from .count_nonzero import count_nonzero +# from .cross_entropy_loss import cross_entropy_loss +# from .cummin import cummin +# from .cumsum import cumsum, normed_cumsum +# from .diag import diag +# from .diag_embed import diag_embed +# from .diagonal import diagonal_backward +# from .div import div_mode, floor_divide, remainder, true_divide +# from .dropout import native_dropout +# from .embedding import embedding +# from .eq import eq, eq_scalar +# from .erf import erf +# from .exp import exp +# from .exponential_ import exponential_ +# from .fill import fill_scalar, fill_tensor +# from .flip import flip +# from .full import full +# from .full_like import full_like +# from .gather import gather, gather_backward +# from .ge import ge, ge_scalar +# from .gelu import gelu +# from .groupnorm import group_norm +# from .gt import gt, gt_scalar +# from .hstack import hstack +# from .index_add import index_add +# from .index_select import index_select +# from .instancenorm import instance_norm +# from .isclose import allclose, isclose +# from .isfinite import isfinite +# from .isin import isin +# from .isinf import isinf +# from .isnan import isnan +# from .layernorm import layer_norm +# from .le import le, le_scalar +# from .log_sigmoid import log_sigmoid +# from .log_softmax import log_softmax +# from 
.logical_and import logical_and +# from .logical_not import logical_not +# from .logical_or import logical_or +# from .logical_xor import logical_xor +# from .lt import lt, lt_scalar +# from .masked_fill import masked_fill, masked_fill_ +# from .masked_select import masked_select +# from .max import max, max_dim +# from .maximum import maximum +# from .mean import mean, mean_dim +# from .min import min, min_dim +# from .minimum import minimum +# from .mm import mm +# from .mse_loss import mse_loss +# from .mul import mul +# from .multinomial import multinomial +# from .mv import mv +# from .ne import ne, ne_scalar +# from .neg import neg +# from .nllloss import ( +# nll_loss2d_backward, +# nll_loss2d_forward, +# nll_loss_backward, +# nll_loss_forward, +# ) +# from .nonzero import nonzero +# from .normal import normal_float_tensor, normal_tensor_float, normal_tensor_tensor +# from .ones import ones +# from .ones_like import ones_like +# from .outer import outer +# from .pad import constant_pad_nd, pad +# from .pow import pow_scalar, pow_tensor_scalar, pow_tensor_tensor +# from .prod import prod, prod_dim +# from .quantile import quantile +# from .rand import rand +# from .rand_like import rand_like +# from .randn import randn +# from .randn_like import randn_like +# from .randperm import randperm +# from .reciprocal import reciprocal +# from .relu import relu +# from .repeat import repeat +# from .repeat_interleave import ( +# repeat_interleave_self_int, +# repeat_interleave_self_tensor, +# repeat_interleave_tensor, +# ) +# from .resolve_conj import resolve_conj +# from .resolve_neg import resolve_neg +# from .rms_norm import rms_norm +# from .rsqrt import rsqrt +# from .scatter import scatter +# from .select_scatter import select_scatter +# from .sigmoid import sigmoid +# from .silu import silu +# from .sin import sin +# from .slice_scatter import slice_scatter +# from .softmax import softmax +# from .sort import sort +# from .stack import stack +# from .sub import sub +# from .sum import sum, sum_dim +# from .tanh import tanh +# from .tile import tile +# from .topk import topk +# from .triu import triu +# from .uniform import uniform_ +# from .unique import _unique2 +# from .upsample_bicubic2d_aa import _upsample_bicubic2d_aa +# from .upsample_nearest2d import upsample_nearest2d +# from .var_mean import var_mean +# from .vdot import vdot +# from .vector_norm import vector_norm +# from .vstack import vstack +# from .weightnorm import weight_norm, weight_norm_interface +# from .where import where_scalar_other, where_scalar_self, where_self, where_self_out +# from .zeros import zeros +# from .zeros_like import zeros_like __all__ = [ - "log_sigmoid", - "all", - "all_dim", - "all_dims", - "allclose", - "any", - "any_dim", - "any_dims", - "add", - "abs", - "addmm", - "arange", - "arange_start", - "batch_norm", - "bitwise_and_tensor", - "bitwise_and_scalar", - "bitwise_and_scalar_tensor", - "bitwise_not", - "bitwise_or_tensor", - "bitwise_or_scalar", - "bitwise_or_scalar_tensor", - "bmm", - "clamp", - "clamp_tensor", - "cos", - "count_nonzero", - "diag", - "diag_embed", - "diagonal_backward", - "pad", - "constant_pad_nd", - "cummin", - "cumsum", - "normed_cumsum", - "true_divide", - "div_mode", - "floor_divide", - "remainder", - "zeros", - "ones", - "full", - "native_dropout", - "erf", - "embedding", - "eq", - "eq_scalar", - "exp", - "fill_scalar", - "fill_tensor", - "exponential_", - "gather", - "gather_backward", - "flip", - "ones_like", - "full_like", - "zeros_like", - "ge", - "ge_scalar", - 
"gelu", - "group_norm", - "gt", - "gt_scalar", - "index_select", - "instance_norm", - "isclose", - "isfinite", - "isin", - "isinf", - "isnan", - "layer_norm", - "weight_norm_interface", - "weight_norm", - "le", - "le_scalar", - "lt", - "lt_scalar", - "rms_norm", - "mean", - "mean_dim", - "mm", - "mul", - "multinomial", - "maximum", - "minimum", - "rand", - "randn", - "randperm", - "rand_like", - "randn_like", - "resolve_neg", - "resolve_conj", - "normal_tensor_float", - "normal_float_tensor", - "normal_tensor_tensor", - "uniform_", - "mv", - "ne", - "ne_scalar", - "neg", - "pow_scalar", - "pow_tensor_scalar", - "pow_tensor_tensor", - "reciprocal", - "relu", - "rsqrt", - "scatter", - "sigmoid", - "silu", - "sin", - "softmax", - "sub", - "tanh", - "tile", - "triu", - "topk", - "max", - "max_dim", - "min", - "min_dim", - "sum", - "sum_dim", - "amax", - "argmax", - "argmin", - "prod", - "prod_dim", - "quantile", - "var_mean", - "vector_norm", - "log_softmax", - "outer", - "cross_entropy_loss", - "where_self_out", - "where_self", - "where_scalar_self", - "where_scalar_other", - "index_add", - "select_scatter", - "slice_scatter", - "masked_fill", - "masked_fill_", - "_unique2", - "_upsample_bicubic2d_aa", - "upsample_nearest2d", - "nonzero", - "repeat", - "masked_select", - "stack", - "hstack", - "cat", - "repeat_interleave_self_int", - "vstack", - "repeat_interleave_tensor", +# "log_sigmoid", +# "all", +# "all_dim", +# "all_dims", +# "allclose", +# "any", +# "any_dim", +# "any_dims", +# "add", +# "abs", +# "addmm", +# "arange", +# "arange_start", +# "batch_norm", +# "bitwise_and_tensor", +# "bitwise_and_scalar", +# "bitwise_and_scalar_tensor", +# "bitwise_not", +# "bitwise_or_tensor", +# "bitwise_or_scalar", +# "bitwise_or_scalar_tensor", +# "bmm", +# "clamp", +# "clamp_tensor", +# "cos", +# "count_nonzero", +# "diag", +# "diag_embed", +# "diagonal_backward", +# "pad", +# "constant_pad_nd", +# "cummin", +# "cumsum", +# "normed_cumsum", +# "true_divide", +# "div_mode", +# "floor_divide", +# "remainder", +# "zeros", +# "ones", +# "full", +# "native_dropout", +# "erf", +# "embedding", +# "eq", +# "eq_scalar", +# "exp", +# "fill_scalar", +# "fill_tensor", +# "exponential_", +# "gather", +# "gather_backward", +# "flip", +# "ones_like", +# "full_like", +# "zeros_like", +# "ge", +# "ge_scalar", +# "gelu", +# "group_norm", +# "gt", +# "gt_scalar", +# "index_select", +# "instance_norm", +# "isclose", +# "isfinite", +# "isin", +# "isinf", +# "isnan", +# "layer_norm", +# "weight_norm_interface", +# "weight_norm", +# "le", +# "le_scalar", +# "lt", +# "lt_scalar", +# "rms_norm", +# "mean", +# "mean_dim", +# "mm", +# "mul", +# "multinomial", +# "maximum", +# "minimum", +# "rand", +# "randn", +# "randperm", +# "rand_like", +# "randn_like", +# "resolve_neg", +# "resolve_conj", +# "normal_tensor_float", +# "normal_float_tensor", +# "normal_tensor_tensor", +# "uniform_", +# "mv", +# "ne", +# "ne_scalar", +# "neg", +# "pow_scalar", +# "pow_tensor_scalar", +# "pow_tensor_tensor", +# "reciprocal", +# "relu", +# "rsqrt", +# "scatter", +# "sigmoid", +# "silu", +# "sin", +# "softmax", +# "sub", +# "tanh", +# "tile", +# "triu", +# "topk", +# "max", +# "max_dim", +# "min", +# "min_dim", +# "sum", +# "sum_dim", +# "amax", +# "argmax", +# "argmin", +# "prod", +# "prod_dim", +# "quantile", +# "var_mean", +# "vector_norm", +# "log_softmax", +# "outer", +# "cross_entropy_loss", +# "where_self_out", +# "where_self", +# "where_scalar_self", +# "where_scalar_other", +# "index_add", +# "select_scatter", +# "slice_scatter", +# 
"masked_fill", +# "masked_fill_", +# "_unique2", +# "_upsample_bicubic2d_aa", +# "upsample_nearest2d", +# "nonzero", +# "repeat", +# "masked_select", +# "stack", +# "hstack", +# "cat", +# "repeat_interleave_self_int", +# "vstack", +# "repeat_interleave_tensor", "flash_attention_forward", - "scaled_dot_product_attention", - "conv2d", - "conv1d", - "_conv_depthwise2d", - "repeat_interleave_self_tensor", - "logical_or", - "logical_and", - "logical_xor", - "logical_not", - "sort", - "nll_loss_forward", - "nll_loss_backward", - "nll_loss2d_forward", - "nll_loss2d_backward", - "vdot", - "mse_loss", +# "scaled_dot_product_attention", +# "conv2d", +# "conv1d", +# "_conv_depthwise2d", +# "repeat_interleave_self_tensor", +# "logical_or", +# "logical_and", +# "logical_xor", +# "logical_not", +# "sort", +# "nll_loss_forward", +# "nll_loss_backward", +# "nll_loss2d_forward", +# "nll_loss2d_backward", +# "vdot", +# "mse_loss", ] diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py index 6a917b104..756ca2571 100644 --- a/src/flag_gems/ops/attention.py +++ b/src/flag_gems/ops/attention.py @@ -573,19 +573,27 @@ def softmax_rescale( row_sum, softmax_scale_log2e: tl.constexpr, is_border: tl.constexpr, - is_init: tl.constexpr + # is_init: tl.constexpr ): prev_max = row_max row_max = tl.maximum(row_max, tl.max(S, 1)) - if not is_init: - if is_border: - cur_max = tl.where(row_max == float('-inf'), 0, row_max) - else: - cur_max = row_max - p_scale = tl.math.exp2((prev_max - cur_max) * softmax_scale_log2e) - row_sum *= p_scale - O_acc *= p_scale[:, None] + # if not is_init: + # if is_border: + # cur_max = tl.where(row_max == float('-inf'), 0, row_max) + # else: + # cur_max = row_max + # p_scale = tl.math.exp2((prev_max - cur_max) * softmax_scale_log2e) + # row_sum *= p_scale + # O_acc *= p_scale[:, None] + + if is_border: + cur_max = tl.where(row_max == float('-inf'), 0, row_max) + else: + cur_max = row_max + p_scale = tl.math.exp2((prev_max - cur_max) * softmax_scale_log2e) + row_sum *= p_scale + O_acc *= p_scale[:, None] max_scaled = tl.where(row_max == float('-inf'), 0, row_max * softmax_scale_log2e) P = tl.math.exp2(S * softmax_scale_log2e - max_scaled[:, None]) @@ -594,21 +602,23 @@ def softmax_rescale( def block_m_heuristic(headdim, is_dropout): - return 64 if headdim <= 128 else 32 + # return 128 if headdim <= 128 else 64 + return 64 def block_n_heuristic(headdim, is_dropout): - return 64 if headdim <= 128 else 32 + # return 128 if headdim <= 128 else 64 + return 64 def block_m_splitkv_heuristic(headdim): - return 64 if headdim <= 128 else 32 + return 128 if headdim <= 128 else 64 def block_n_splitkv_heuristic(headdim): if headdim <= 64: - return 128 + return 256 elif headdim <= 128: - return 64 + return 128 else: - return 32 + return 64 # @triton.autotune( # configs=runtime.get_tuned_config("attention"), @@ -629,7 +639,7 @@ def block_n_splitkv_heuristic(headdim): 'IS_EVEN_MN': lambda args: (args["seqlen_q"] % args["BLOCK_M"] == 0) and (args["seqlen_k"] % args["BLOCK_N"] == 0), } ) -@triton.jit(do_not_specialize=["seqlen_q", "seqlen_k", "q_b_stride", "q_s_stride", "q_h_stride", "k_b_stride", "k_s_stride", "k_h_stride", "o_b_stride", "o_h_stride", "o_s_stride", "philox_seed", "philox_offset", "pdrop_u8", "slopes_batch_stride"]) +@triton.jit(do_not_specialize=["seqlen_q", "seqlen_k", "q_b_stride", "k_b_stride", "o_b_stride", "philox_seed", "philox_offset", "pdrop_u8", "slopes_batch_stride"]) def flash_fwd_kernel( Q_ptr, K_ptr, @@ -684,15 +694,25 @@ def flash_fwd_kernel( hid = 
tl.program_id(2) if is_local: - n_block_min: tl.constexpr = max(0, (m_block * BLOCK_M + seqlen_k - seqlen_q - ws_left) / BLOCK_N) + col_min = m_block * BLOCK_M + seqlen_k - seqlen_q - ws_left + col_min = max(col_min, 0) else: - n_block_min: tl.constexpr = 0 - - n_block_max = tl.cdiv(seqlen_k, BLOCK_N) + col_min = 0 + col_max = tl.cdiv(seqlen_k, BLOCK_N) * BLOCK_N if is_causal or is_local: - n_block_max = min(n_block_max, - tl.cdiv((m_block + 1) * BLOCK_M + seqlen_k - seqlen_q + ws_right, BLOCK_N)) + col_max = min(col_max, (m_block + 1) * BLOCK_M + seqlen_k - seqlen_q + ws_right) + + # if is_local: + # n_block_min = max(0, (m_block * BLOCK_M + seqlen_k - seqlen_q - ws_left) / BLOCK_N) + # else: + # n_block_min = 0 + + # n_block_max = tl.cdiv(seqlen_k, BLOCK_N) + + # if is_causal or is_local: + # n_block_max = min(n_block_max, + # tl.cdiv((m_block + 1) * BLOCK_M + seqlen_k - seqlen_q + ws_right, BLOCK_N)) if has_alibi: alibi_offset = bid * slopes_batch_stride + hid @@ -703,24 +723,36 @@ def flash_fwd_kernel( if (not is_causal) and (not is_local): if IS_EVEN_MN: - n_masking_steps = 0 + n_masking_blocks: tl.constexpr = 0 else: - n_masking_steps = 1 + n_masking_blocks: tl.constexpr = 1 elif is_causal and IS_EVEN_MN: # causal implies ws_right is zero - n_masking_steps = tl.cdiv(BLOCK_M, BLOCK_N) + n_masking_blocks: tl.constexpr = tl.cdiv(BLOCK_M, BLOCK_N) else: # local and not causal, - n_masking_steps = tl.cdiv(BLOCK_M, BLOCK_N) + 1 + n_masking_blocks: tl.constexpr = tl.cdiv(BLOCK_M, BLOCK_N) + 1 + + masking_cols = n_masking_blocks * BLOCK_N Q_ptr += bid * q_b_stride Q_ptr += hid * q_h_stride - row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M) + row_start = m_block * BLOCK_M + row_idx = row_start + tl.arange(0, BLOCK_M) Q_off = row_idx[:, None] * q_s_stride + tl.arange(0, HEAD_DIM)[None, :] qmask = row_idx[:, None] < seqlen_q if IS_EVEN_MN: - Q = tl.load(Q_ptr + Q_off) + Q = tl.load(Q_ptr + Q_off, cache_modifier='.cg') else: - Q = tl.load(Q_ptr + Q_off, mask=qmask) + Q = tl.load(Q_ptr + Q_off, mask=qmask, cache_modifier='.cg') + + if return_P: + P_ptr += ((bid * NUM_HEADS + hid) * seqlen_q_rounded + m_block * BLOCK_M) * seqlen_k_rounded + P_ptr += (n_block_max - 1) * BLOCK_N + P_offset = tl.arange(0, BLOCK_M)[:, None] * seqlen_k_rounded + tl.arange(0, BLOCK_N) + + O_ = tl.zeros((BLOCK_M, HEAD_DIM), dtype=tl.float32) + rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32) # Start from the right most block n_block = n_block_max - 1 @@ -730,29 +762,28 @@ def flash_fwd_kernel( K_ptr += (hid // h_hk_ratio) * k_h_stride V_ptr += bid * k_b_stride V_ptr += (hid // h_hk_ratio) * k_h_stride - - P_ptr += ((bid * NUM_HEADS + hid) * seqlen_q_rounded + m_block * BLOCK_M) * seqlen_k_rounded - P_ptr += n_block * BLOCK_N - P_offset = tl.arange(0, BLOCK_M)[:, None] * seqlen_k_rounded + tl.arange(0, BLOCK_N) - O_ = tl.zeros((BLOCK_M, HEAD_DIM), dtype=tl.float32) - rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) - rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32) - - for n_block in tl.range(n_block_max - 1, n_block_max - n_masking_steps - 1, step=-1): - col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N) - K_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None] - V_offset = col_idx[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :] + K_offset = tl.arange(0, BLOCK_N)[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None] + V_offset = tl.arange(0, BLOCK_N)[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :] + + p_bk0 
= K_ptr + K_offset + p_bv0 = V_ptr + V_offset + + for col_start in tl.range(max(col_min, col_max - masking_cols), col_max, step=BLOCK_N): + # for r_blk_idx in tl.range(0, min(n_masking_blocks, n_blocks_max - n_blocks_min)): + off = col_start * k_s_stride if IS_EVEN_MN: - K = tl.load(K_ptr + K_offset, cache_modifier=".cg") + K = tl.load(p_bk0 + off, cache_modifier=".cg") if PRE_LOAD_V: - V = tl.load(V_ptr + V_offset, cache_modifier=".cg") + V = tl.load(p_bv0 + off, cache_modifier=".cg") else: - kvmask = col_idx < seqlen_k - K = tl.load(K_ptr + K_offset, mask=kvmask[None, :], cache_modifier=".cg") + kvmask = col < seqlen_k + K = tl.load(p_bk0 + off, mask=kvmask[None, :], cache_modifier=".cg") if PRE_LOAD_V: - V = tl.load(V_ptr + V_offset, mask=kvmask[:, None], cache_modifier=".cg") + V = tl.load(p_bv0 + off, mask=kvmask[:, None], cache_modifier=".cg") S = tl.dot(Q, K, allow_tf32=False) + col_idx = col_start + tl.arange(0, BLOCK_N) + row_idx = row_start + tl.arange(0, BLOCK_M) S = apply_mask( S, col_idx, @@ -767,9 +798,8 @@ def flash_fwd_kernel( is_local=is_local, has_alibi=has_alibi ) - # col_idx -= BLOCK_N - is_init = (n_block == n_block_max - 1).to(tl.int1) + # is_init = (n_block == n_block_max - 1).to(tl.int1) O_, P, rowmax_, rowsum_ = softmax_rescale( O_, S, @@ -777,32 +807,33 @@ def flash_fwd_kernel( rowsum_, softmax_scale_log2e=softmax_scale_log2e, is_border=(is_causal or is_local), - is_init=is_init + # is_init=is_init ) P = P.to(O_ptr.type.element_ty) - row_start = m_block * (BLOCK_M // 16) - col_start = n_block * (BLOCK_N // 32) - if return_P: - P_drop = P - P_drop = apply_dropout( - P_drop, - row_start, - col_start, - bid, - hid, - philox_seed, - philox_offset, - pdrop_u8, - encode_dropout_in_sign_bit=True, - NUM_HEADS=NUM_HEADS, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - ) - tl.store(P_ptr + P_offset, P_drop, mask=qmask & kvmask[None, :]) - P_offset += BLOCK_N - if is_dropout: + row_start = m_block * (BLOCK_M // 16) + col_start = n_block * (BLOCK_N // 32) + + if return_P: + P_drop = P + P_drop = apply_dropout( + P_drop, + row_start, + col_start, + bid, + hid, + philox_seed, + philox_offset, + pdrop_u8, + encode_dropout_in_sign_bit=True, + NUM_HEADS=NUM_HEADS, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + tl.store(P_ptr + P_offset, P_drop, mask=qmask & kvmask[None, :]) + P_offset += BLOCK_N + P = apply_dropout( P, row_start, @@ -819,29 +850,26 @@ def flash_fwd_kernel( ) if not PRE_LOAD_V: + off = col_start * k_s_stride if IS_EVEN_MN: - V = tl.load(V_ptr + V_offset, cache_modifier=".cg") + V = tl.load(p_bv0 + off, cache_modifier=".cg") else: - V = tl.load(V_ptr + V_offset, mask=kvmask[:, None], cache_modifier=".cg") + V = tl.load(p_bv0 + off, mask=kvmask[:, None], cache_modifier=".cg") O_ = tl.dot(P, V, O_, allow_tf32=False) - # if n_masking_steps > 1 and n_block <= n_block_min: - # break - - - for n_block in tl.range(n_block_max - n_masking_steps - 1, n_block_min - 1, step=-1, num_stages=num_stages): - col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N) - K_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None] - K = tl.load(K_ptr + K_offset, cache_modifier=".cg") + for col_start in tl.range(min_col, max_col - masking_cols, step=BLOCK_N, num_stages=num_stages): + # for r_blk_idx in tl.range(n_block_max - n_masking_blocks - 1, n_block_min - 1, step=-1, num_stages=num_stages): + off = col_start * k_s_stride + K = tl.load(p_bk0 + off, cache_modifier=".cg") # if PRE_LOAD_V: - # V_offset = col_idx[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :] # V = tl.load(V_ptr + 
V_offset, cache_modifier=".cg") S = tl.dot(Q, K) if PRE_LOAD_V: - V_offset = col_idx[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :] - V = tl.load(V_ptr + V_offset, cache_modifier=".cg") + V = tl.load(p_bv0 + off, cache_modifier=".cg") + col_idx = col_start + tl.arange(0, BLOCK_N) + row_idx = row_start + tl.arange(0, BLOCK_M) S = apply_mask( S, col_idx, @@ -856,9 +884,8 @@ def flash_fwd_kernel( is_local=is_local, has_alibi=has_alibi ) - # col_idx -= BLOCK_N - is_init = (n_block == n_block_max - 1).to(tl.int1) + # is_init = (n_block == n_block_max - 1).to(tl.int1) O_, P, rowmax_, rowsum_ = softmax_rescale( O_, S, @@ -866,33 +893,32 @@ def flash_fwd_kernel( rowsum_, softmax_scale_log2e=softmax_scale_log2e, is_border=is_local, - is_init=is_init ) - P = P.to(O_ptr.type.element_ty) - row_start = m_block * (BLOCK_M // 16) - col_start = n_block * (BLOCK_N // 32) - if return_P: - P_drop = P - P_drop = apply_dropout( - P_drop, - row_start, - col_start, - bid, - hid, - philox_seed, - philox_offset, - pdrop_u8, - encode_dropout_in_sign_bit=True, - NUM_HEADS=NUM_HEADS, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - ) - tl.store(P_ptr + P_offset, P_drop, mask=qmask & kvmask) - P_offset += BLOCK_N - if is_dropout: + row_start = m_block * (BLOCK_M // 16) + col_start = n_block * (BLOCK_N // 32) + + if return_P: + P_drop = P + P_drop = apply_dropout( + P_drop, + row_start, + col_start, + bid, + hid, + philox_seed, + philox_offset, + pdrop_u8, + encode_dropout_in_sign_bit=True, + NUM_HEADS=NUM_HEADS, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + tl.store(P_ptr + P_offset, P_drop, mask=qmask & kvmask) + P_offset += BLOCK_N + P = apply_dropout( P, row_start, @@ -909,8 +935,8 @@ def flash_fwd_kernel( ) if not PRE_LOAD_V: - V_offset = col_idx[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :] - V = tl.load(V_ptr + V_offset, cache_modifier=".cg") + off = col_start * k_s_stride + V = tl.load(p_bv0 + off, cache_modifier=".cg") O_ = tl.dot(P, V, O_) @@ -1031,14 +1057,14 @@ def flash_fwd_splitkv_kernel( if (not is_causal) and (not is_local): if IS_EVEN_MN: - n_masking_steps = 0 + n_masking_blocks = 0 else: - n_masking_steps = 1 + n_masking_blocks = 1 elif is_causal and IS_EVEN_MN: # causal implies ws_right is zero - n_masking_steps = tl.cdiv(BLOCK_M, BLOCK_N) + n_masking_blocks = tl.cdiv(BLOCK_M, BLOCK_N) else: # local and not causal, - n_masking_steps = tl.cdiv(BLOCK_M, BLOCK_N) + 1 + n_masking_blocks = tl.cdiv(BLOCK_M, BLOCK_N) + 1 Q_ptr += bid * q_b_stride Q_ptr += hid * q_h_stride @@ -1067,7 +1093,7 @@ def flash_fwd_splitkv_kernel( rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32) - for n_block in tl.range(n_block_max - 1, n_block_max - n_masking_steps - 1, step=-1): + for n_block in tl.range(n_block_max - 1, n_block_max - n_masking_blocks - 1, step=-1): col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N) K_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None] V_offset = col_idx[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :] @@ -1115,10 +1141,10 @@ def flash_fwd_splitkv_kernel( else: V = tl.load(V_ptr + V_offset, mask=kvmask[:, None], cache_modifier=".cg") O_ = tl.dot(P, V, O_, allow_tf32=False) - # if n_masking_steps > 1 and n_block <= n_block_min: + # if n_masking_blocks > 1 and n_block <= n_block_min: # break - for n_block in tl.range(n_block_max - n_masking_steps - 1, n_block_min - 1, step=-1, num_stages=num_stages): + for n_block in tl.range(n_block_max - n_masking_blocks - 1, n_block_min - 1, step=-1, 
num_stages=num_stages): col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N) K_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None] K = tl.load(K_ptr + K_offset, cache_modifier=".cg") @@ -1425,7 +1451,7 @@ def splits_heuristics(num_tasks, num_sms, n_blocks): tmp_lse = lse tmp_out = out - kernel( + kernel = kernel( q, k, v, @@ -1468,6 +1494,8 @@ def splits_heuristics(num_tasks, num_sms, n_blocks): NUM_HEADS=num_heads, NUM_HEADS_K=num_heads_k, ) + print(f'{kernel.name} shared memory:', kernel.metadata.shared) + # print(kernel.asm['ttgir']) if n_splits > 1: if head_size % 128 == 0: @@ -1477,7 +1505,7 @@ def splits_heuristics(num_tasks, num_sms, n_blocks): else: BLOCK_M = 16 grid = lambda args: (triton.cdiv(batch_size * num_heads * seqlen_q, BLOCK_M), ) - flash_fwd_splitkv_combine_kernel[grid]( + kernel = flash_fwd_splitkv_combine_kernel[grid]( out, lse, tmp_out, @@ -1491,6 +1519,7 @@ def splits_heuristics(num_tasks, num_sms, n_blocks): q_total=batch_size * num_heads * seqlen_q, MAX_N_SPLITS=triton.next_power_of_2(n_splits), ) + print(f'{kernel.name} shared memory:', kernel.metadata.shared) if swap_seq_and_group: out = out.transpose(1, 2).reshape((batch_size, 1, num_heads_k * seqlen_q, head_size)) From af24ad2baaee60f483d090bc0a3c2e79377fa81f Mon Sep 17 00:00:00 2001 From: Tongxin Bai <waffle.bai@gmail.com> Date: Tue, 18 Mar 2025 04:08:19 +0000 Subject: [PATCH 12/25] bug fix. --- src/flag_gems/ops/attention.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py index 756ca2571..c8bfdc09a 100644 --- a/src/flag_gems/ops/attention.py +++ b/src/flag_gems/ops/attention.py @@ -744,7 +744,7 @@ def flash_fwd_kernel( Q = tl.load(Q_ptr + Q_off, cache_modifier='.cg') else: Q = tl.load(Q_ptr + Q_off, mask=qmask, cache_modifier='.cg') - + if return_P: P_ptr += ((bid * NUM_HEADS + hid) * seqlen_q_rounded + m_block * BLOCK_M) * seqlen_k_rounded P_ptr += (n_block_max - 1) * BLOCK_N @@ -754,9 +754,6 @@ def flash_fwd_kernel( rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32) - # Start from the right most block - n_block = n_block_max - 1 - h_hk_ratio = h // hk K_ptr += bid * k_b_stride K_ptr += (hid // h_hk_ratio) * k_h_stride @@ -857,7 +854,7 @@ def flash_fwd_kernel( V = tl.load(p_bv0 + off, mask=kvmask[:, None], cache_modifier=".cg") O_ = tl.dot(P, V, O_, allow_tf32=False) - for col_start in tl.range(min_col, max_col - masking_cols, step=BLOCK_N, num_stages=num_stages): + for col_start in tl.range(col_min, col_max - masking_cols, step=BLOCK_N, num_stages=num_stages): # for r_blk_idx in tl.range(n_block_max - n_masking_blocks - 1, n_block_min - 1, step=-1, num_stages=num_stages): off = col_start * k_s_stride K = tl.load(p_bk0 + off, cache_modifier=".cg") From 1025a3d768e5806504faa5a10de606e39d6a3adc Mon Sep 17 00:00:00 2001 From: Tongxin Bai <waffle.bai@gmail.com> Date: Wed, 19 Mar 2025 12:34:35 +0000 Subject: [PATCH 13/25] Pipeline works, causal passes. 
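With the loop structure fixed, the size of the masked tail can also be made a compile-time constant: masking_cols depends only on BLOCK_M, BLOCK_N and the causal/local/IS_EVEN_MN flags, never on runtime sequence lengths, so the tail loop gets a constexpr trip count. The reasoning, roughly: a BLOCK_M-row query tile can straddle the causal diagonal across at most cdiv(BLOCK_M, BLOCK_N) key blocks, the local-window (and uneven causal) case needs one extra block of slack, and the plain uneven case only needs the final partial block. A small illustrative calculation (the tile sizes here are just example values):

    import math

    BLOCK_M, BLOCK_N = 128, 32
    causal_even_cols = math.ceil(BLOCK_M / BLOCK_N) * BLOCK_N              # 4 * 32 = 128
    local_or_uneven_causal_cols = (math.ceil(BLOCK_M / BLOCK_N) + 1) * BLOCK_N  # 160
    noncausal_uneven_cols = BLOCK_N                                        # one tail block
    noncausal_even_cols = 0                                                # no masking at all
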
--- src/flag_gems/ops/attention.py | 67 +++++++++---------- .../runtime/backend/_nvidia/tune_configs.yaml | 3 +- 2 files changed, 33 insertions(+), 37 deletions(-) diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py index c8bfdc09a..07c48b87e 100644 --- a/src/flag_gems/ops/attention.py +++ b/src/flag_gems/ops/attention.py @@ -602,12 +602,10 @@ def softmax_rescale( def block_m_heuristic(headdim, is_dropout): - # return 128 if headdim <= 128 else 64 - return 64 + return 128 if headdim <= 128 else 64 def block_n_heuristic(headdim, is_dropout): - # return 128 if headdim <= 128 else 64 - return 64 + return 64 if headdim <= 64 else 32 def block_m_splitkv_heuristic(headdim): return 128 if headdim <= 128 else 64 @@ -634,12 +632,12 @@ def block_n_splitkv_heuristic(headdim): 'BLOCK_M': lambda args: block_m_heuristic(args["HEAD_DIM"], args["is_dropout"]), 'BLOCK_N': lambda args: block_n_heuristic(args["HEAD_DIM"], args["is_dropout"]), 'num_warps': lambda args: 4, - 'num_stages': lambda args: 3, + 'num_stages': lambda args: 3 if args["HEAD_DIM"] <= 128 else 2, 'PRE_LOAD_V': lambda args: True, 'IS_EVEN_MN': lambda args: (args["seqlen_q"] % args["BLOCK_M"] == 0) and (args["seqlen_k"] % args["BLOCK_N"] == 0), } ) -@triton.jit(do_not_specialize=["seqlen_q", "seqlen_k", "q_b_stride", "k_b_stride", "o_b_stride", "philox_seed", "philox_offset", "pdrop_u8", "slopes_batch_stride"]) +@triton.jit(do_not_specialize=["seqlen_q", "seqlen_k", "philox_seed", "philox_offset"]) def flash_fwd_kernel( Q_ptr, K_ptr, @@ -660,8 +658,8 @@ def flash_fwd_kernel( o_b_stride, o_s_stride, o_h_stride, - h, - hk, + h: tl.constexpr, + hk: tl.constexpr, pSlopes, philox_seed, philox_offset, @@ -703,37 +701,16 @@ def flash_fwd_kernel( if is_causal or is_local: col_max = min(col_max, (m_block + 1) * BLOCK_M + seqlen_k - seqlen_q + ws_right) - # if is_local: - # n_block_min = max(0, (m_block * BLOCK_M + seqlen_k - seqlen_q - ws_left) / BLOCK_N) - # else: - # n_block_min = 0 - - # n_block_max = tl.cdiv(seqlen_k, BLOCK_N) - - # if is_causal or is_local: - # n_block_max = min(n_block_max, - # tl.cdiv((m_block + 1) * BLOCK_M + seqlen_k - seqlen_q + ws_right, BLOCK_N)) - if has_alibi: alibi_offset = bid * slopes_batch_stride + hid alibi_slope = tl.load(pSlopes + alibi_offset) alibi_slope /= scale else: alibi_slope = 0.0 - - if (not is_causal) and (not is_local): - if IS_EVEN_MN: - n_masking_blocks: tl.constexpr = 0 - else: - n_masking_blocks: tl.constexpr = 1 - elif is_causal and IS_EVEN_MN: # causal implies ws_right is zero - n_masking_blocks: tl.constexpr = tl.cdiv(BLOCK_M, BLOCK_N) - else: - # local and not causal, - n_masking_blocks: tl.constexpr = tl.cdiv(BLOCK_M, BLOCK_N) + 1 - masking_cols = n_masking_blocks * BLOCK_N + # masking_cols: tl.constexpr = n_masking_blocks * BLOCK_N + q_b_stride = tl.multiple_of(q_b_stride, HEAD_DIM * h) Q_ptr += bid * q_b_stride Q_ptr += hid * q_h_stride row_start = m_block * BLOCK_M @@ -754,6 +731,7 @@ def flash_fwd_kernel( rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32) + k_b_stride = tl.multiple_of(k_b_stride, HEAD_DIM * hk) h_hk_ratio = h // hk K_ptr += bid * k_b_stride K_ptr += (hid // h_hk_ratio) * k_h_stride @@ -766,8 +744,24 @@ def flash_fwd_kernel( p_bk0 = K_ptr + K_offset p_bv0 = V_ptr + V_offset - for col_start in tl.range(max(col_min, col_max - masking_cols), col_max, step=BLOCK_N): - # for r_blk_idx in tl.range(0, min(n_masking_blocks, n_blocks_max - n_blocks_min)): + if (not is_causal) and (not is_local): + 
if IS_EVEN_MN: + # n_masking_blocks: tl.constexpr = 0 + masking_cols: tl.constexpr = 0 + else: + # n_masking_blocks: tl.constexpr = 1 + masking_cols: tl.constexpr = BLOCK_N + elif is_causal and IS_EVEN_MN: # causal implies ws_right is zero + # n_masking_blocks: tl.constexpr = tl.cdiv(BLOCK_M, BLOCK_N) + masking_cols: tl.constexpr = tl.cdiv(BLOCK_M, BLOCK_N) * BLOCK_N + else: + # local and not causal, + # n_masking_blocks: tl.constexpr = tl.cdiv(BLOCK_M, BLOCK_N) + 1 + masking_cols: tl.constexpr = (tl.cdiv(BLOCK_M, BLOCK_N) + 1) * BLOCK_N + + for col_shift in tl.range(0, masking_cols, step=BLOCK_N): + col_start = col_max - col_shift - BLOCK_N + col_start = tl.multiple_of(col_start, BLOCK_N) off = col_start * k_s_stride if IS_EVEN_MN: K = tl.load(p_bk0 + off, cache_modifier=".cg") @@ -796,7 +790,6 @@ def flash_fwd_kernel( has_alibi=has_alibi ) - # is_init = (n_block == n_block_max - 1).to(tl.int1) O_, P, rowmax_, rowsum_ = softmax_rescale( O_, S, @@ -804,7 +797,6 @@ def flash_fwd_kernel( rowsum_, softmax_scale_log2e=softmax_scale_log2e, is_border=(is_causal or is_local), - # is_init=is_init ) P = P.to(O_ptr.type.element_ty) @@ -856,6 +848,7 @@ def flash_fwd_kernel( for col_start in tl.range(col_min, col_max - masking_cols, step=BLOCK_N, num_stages=num_stages): # for r_blk_idx in tl.range(n_block_max - n_masking_blocks - 1, n_block_min - 1, step=-1, num_stages=num_stages): + # col_start = tl.multiple_of(col_start, BLOCK_N) off = col_start * k_s_stride K = tl.load(p_bk0 + off, cache_modifier=".cg") # if PRE_LOAD_V: @@ -882,7 +875,6 @@ def flash_fwd_kernel( has_alibi=has_alibi ) - # is_init = (n_block == n_block_max - 1).to(tl.int1) O_, P, rowmax_, rowsum_ = softmax_rescale( O_, S, @@ -951,6 +943,7 @@ def flash_fwd_kernel( O = O_.to(O_ptr.type.element_ty) # Write back output + o_b_stride = tl.multiple_of(o_b_stride, HEAD_DIM * h) O_ptr += bid * o_b_stride O_ptr += hid * o_h_stride O_offset = row_idx[:, None] * o_s_stride + tl.arange(0, HEAD_DIM) @@ -1492,6 +1485,8 @@ def splits_heuristics(num_tasks, num_sms, n_blocks): NUM_HEADS_K=num_heads_k, ) print(f'{kernel.name} shared memory:', kernel.metadata.shared) + print(f'{kernel.name} num_warps:', kernel.metadata.num_warps) + print(f'{kernel.name} num_stages:', kernel.metadata.num_stages) # print(kernel.asm['ttgir']) if n_splits > 1: diff --git a/src/flag_gems/runtime/backend/_nvidia/tune_configs.yaml b/src/flag_gems/runtime/backend/_nvidia/tune_configs.yaml index b4865b3c1..3753f67d2 100644 --- a/src/flag_gems/runtime/backend/_nvidia/tune_configs.yaml +++ b/src/flag_gems/runtime/backend/_nvidia/tune_configs.yaml @@ -8,12 +8,13 @@ attention: num_warps: warps num_stages: stages block_m: + - 32 - 64 - 128 block_n: - 32 - 64 - - 128 + # - 128 pre_load_v: - true - false From 12f105dbd5e91e1c707da6b21527981276b879a2 Mon Sep 17 00:00:00 2001 From: Tongxin Bai <waffle.bai@gmail.com> Date: Thu, 20 Mar 2025 04:34:50 +0000 Subject: [PATCH 14/25] a couple of fixes, nonequal q k seqlens still broken. 
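One of the fixes drops the leftover is_init argument from the split-kv kernel's softmax_rescale calls, mirroring what the main kernel already does. The special case is unnecessary because rowmax_ starts at -inf while rowsum_ and O_ start at zero, so on the first block the rescale factor exp2((-inf - new_max) * scale) evaluates to 0 and merely re-zeros accumulators that are already zero. A rough NumPy rendering of the recurrence (simplified: it omits the guards the kernel keeps for rows whose max stays -inf and for the is_border case):

    import numpy as np

    def softmax_rescale_step(O_acc, S, row_max, row_sum, scale_log2e):
        # one online-softmax update for a (BLOCK_M, BLOCK_N) score tile S
        new_max = np.maximum(row_max, S.max(axis=1))
        p_scale = np.exp2((row_max - new_max) * scale_log2e)  # 0.0 on the first block
        row_sum = row_sum * p_scale
        O_acc = O_acc * p_scale[:, None]
        P = np.exp2(S * scale_log2e - (new_max * scale_log2e)[:, None])
        row_sum = row_sum + P.sum(axis=1)
        return O_acc, P, new_max, row_sum
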
--- src/flag_gems/ops/attention.py | 18 ++---------------- 1 file changed, 2 insertions(+), 16 deletions(-) diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py index 07c48b87e..b3363ddcf 100644 --- a/src/flag_gems/ops/attention.py +++ b/src/flag_gems/ops/attention.py @@ -707,8 +707,6 @@ def flash_fwd_kernel( alibi_slope /= scale else: alibi_slope = 0.0 - - # masking_cols: tl.constexpr = n_masking_blocks * BLOCK_N q_b_stride = tl.multiple_of(q_b_stride, HEAD_DIM * h) Q_ptr += bid * q_b_stride @@ -746,17 +744,13 @@ def flash_fwd_kernel( if (not is_causal) and (not is_local): if IS_EVEN_MN: - # n_masking_blocks: tl.constexpr = 0 masking_cols: tl.constexpr = 0 else: - # n_masking_blocks: tl.constexpr = 1 masking_cols: tl.constexpr = BLOCK_N elif is_causal and IS_EVEN_MN: # causal implies ws_right is zero - # n_masking_blocks: tl.constexpr = tl.cdiv(BLOCK_M, BLOCK_N) masking_cols: tl.constexpr = tl.cdiv(BLOCK_M, BLOCK_N) * BLOCK_N else: # local and not causal, - # n_masking_blocks: tl.constexpr = tl.cdiv(BLOCK_M, BLOCK_N) + 1 masking_cols: tl.constexpr = (tl.cdiv(BLOCK_M, BLOCK_N) + 1) * BLOCK_N for col_shift in tl.range(0, masking_cols, step=BLOCK_N): @@ -847,16 +841,12 @@ def flash_fwd_kernel( O_ = tl.dot(P, V, O_, allow_tf32=False) for col_start in tl.range(col_min, col_max - masking_cols, step=BLOCK_N, num_stages=num_stages): - # for r_blk_idx in tl.range(n_block_max - n_masking_blocks - 1, n_block_min - 1, step=-1, num_stages=num_stages): # col_start = tl.multiple_of(col_start, BLOCK_N) off = col_start * k_s_stride K = tl.load(p_bk0 + off, cache_modifier=".cg") - # if PRE_LOAD_V: - # V = tl.load(V_ptr + V_offset, cache_modifier=".cg") - S = tl.dot(Q, K) - if PRE_LOAD_V: - V = tl.load(p_bv0 + off, cache_modifier=".cg") + V = tl.load(V_ptr + V_offset, cache_modifier=".cg") + S = tl.dot(Q, K) col_idx = col_start + tl.arange(0, BLOCK_N) row_idx = row_start + tl.arange(0, BLOCK_M) @@ -1113,7 +1103,6 @@ def flash_fwd_splitkv_kernel( ) # col_idx -= BLOCK_N - is_init = (n_block == n_block_max - 1).to(tl.int1) O_, P, rowmax_, rowsum_ = softmax_rescale( O_, S, @@ -1121,7 +1110,6 @@ def flash_fwd_splitkv_kernel( rowsum_, softmax_scale_log2e=softmax_scale_log2e, is_border=(is_causal or is_local), - is_init=is_init ) P = P.to(Q_ptr.type.element_ty) @@ -1163,7 +1151,6 @@ def flash_fwd_splitkv_kernel( ) # col_idx -= BLOCK_N - is_init = (n_block == n_block_max - 1).to(tl.int1) O_, P, rowmax_, rowsum_ = softmax_rescale( O_, S, @@ -1171,7 +1158,6 @@ def flash_fwd_splitkv_kernel( rowsum_, softmax_scale_log2e=softmax_scale_log2e, is_border=is_local, - is_init=is_init ) P = P.to(Q_ptr.type.element_ty) From a33ff44275784588d446b3a81a7c715188c0ed2a Mon Sep 17 00:00:00 2001 From: Tongxin Bai <waffle.bai@gmail.com> Date: Thu, 20 Mar 2025 15:57:47 +0000 Subject: [PATCH 15/25] Causal results are stable now. Consistent with aten._flash_attention_forward. 
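For reference, the masking convention being matched here is the bottom-right-aligned causal masking used by flash_attn and aten._flash_attention_forward: a query at row i may attend key columns j with j <= i + seqlen_k - seqlen_q (ws_right is zero in the causal case, and a local window further clamps j to [i + seqlen_k - seqlen_q - ws_left, i + seqlen_k - seqlen_q + ws_right]). A tiny stand-alone check with made-up lengths:

    # illustrative check of the bottom-right-aligned causal bound, not kernel code
    seqlen_q, seqlen_k = 3, 5
    for i in range(seqlen_q):
        allowed = [j for j in range(seqlen_k) if j <= i + seqlen_k - seqlen_q]
        print(i, allowed)  # 0 -> [0, 1, 2]; 1 -> [0, 1, 2, 3]; 2 -> [0, 1, 2, 3, 4]
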
--- src/flag_gems/ops/attention.py | 82 ++++++++++++++++------------------ 1 file changed, 39 insertions(+), 43 deletions(-) diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py index b3363ddcf..54248c4a1 100644 --- a/src/flag_gems/ops/attention.py +++ b/src/flag_gems/ops/attention.py @@ -546,18 +546,16 @@ def apply_mask( need_mask: tl.constexpr = is_causal | has_alibi | is_local | (not is_even_mn) if need_mask: col_lb = max(0, row_idx + max_seqlen_k - max_seqlen_q - ws_left) - col_rb = min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + ws_right) + col_rb = min(max_seqlen_k, row_idx + max_seqlen_k - max_seqlen_q + ws_right) - if not has_alibi: - alibi_slope = .0 - - S -= alibi_slope * tl.abs(col_idx[None, :] - row_idx[:, None]) + if has_alibi: + S -= alibi_slope * tl.abs(col_idx[None, :] - row_idx[:, None]) if is_causal: - S = tl.where(col_idx[None, :] >= col_rb[:, None], float('-inf'), S) + S = tl.where(col_idx[None, :] > col_rb[:, None], float('-inf'), S) if is_local: - S = tl.where(col_idx[None, :] >= col_rb[:, None] | col_idx[None, :] < col_lb[:, None], float('-inf'), S) + S = tl.where(col_idx[None, :] > col_rb[:, None] | col_idx[None, :] < col_lb[:, None], float('-inf'), S) if (not is_local) & (not is_causal) & (not is_even_mn): S = tl.where(col_idx[None, :] >= max_seqlen_k, float('-inf'), S) @@ -578,34 +576,36 @@ def softmax_rescale( prev_max = row_max row_max = tl.maximum(row_max, tl.max(S, 1)) - # if not is_init: - # if is_border: - # cur_max = tl.where(row_max == float('-inf'), 0, row_max) - # else: - # cur_max = row_max - # p_scale = tl.math.exp2((prev_max - cur_max) * softmax_scale_log2e) - # row_sum *= p_scale - # O_acc *= p_scale[:, None] - if is_border: cur_max = tl.where(row_max == float('-inf'), 0, row_max) else: cur_max = row_max + p_scale = tl.math.exp2((prev_max - cur_max) * softmax_scale_log2e) row_sum *= p_scale O_acc *= p_scale[:, None] max_scaled = tl.where(row_max == float('-inf'), 0, row_max * softmax_scale_log2e) + P = tl.math.exp2(S * softmax_scale_log2e - max_scaled[:, None]) row_sum = row_sum + tl.sum(P, 1) return O_acc, P, row_max, row_sum def block_m_heuristic(headdim, is_dropout): - return 128 if headdim <= 128 else 64 + block_m = 128 if headdim <= 128 else 64 + print('block_m:', block_m) + return block_m def block_n_heuristic(headdim, is_dropout): - return 64 if headdim <= 64 else 32 + block_n = 64 if headdim <= 64 else 32 + print('block_n:', block_n) + return block_n + +def is_even_mn(args): + even_mn = (args["seqlen_q"] % args["BLOCK_M"] == 0) and (args["seqlen_k"] % args["BLOCK_N"] == 0) + print('is_even_mn:', even_mn) + return even_mn def block_m_splitkv_heuristic(headdim): return 128 if headdim <= 128 else 64 @@ -633,8 +633,8 @@ def block_n_splitkv_heuristic(headdim): 'BLOCK_N': lambda args: block_n_heuristic(args["HEAD_DIM"], args["is_dropout"]), 'num_warps': lambda args: 4, 'num_stages': lambda args: 3 if args["HEAD_DIM"] <= 128 else 2, - 'PRE_LOAD_V': lambda args: True, - 'IS_EVEN_MN': lambda args: (args["seqlen_q"] % args["BLOCK_M"] == 0) and (args["seqlen_k"] % args["BLOCK_N"] == 0), + 'PRE_LOAD_V': lambda args: False, + 'IS_EVEN_MN': lambda args: is_even_mn(args), } ) @triton.jit(do_not_specialize=["seqlen_q", "seqlen_k", "philox_seed", "philox_offset"]) @@ -762,7 +762,8 @@ def flash_fwd_kernel( if PRE_LOAD_V: V = tl.load(p_bv0 + off, cache_modifier=".cg") else: - kvmask = col < seqlen_k + col_idx = col_start + tl.arange(0, BLOCK_N) + kvmask = col_idx < seqlen_k K = tl.load(p_bk0 + off, mask=kvmask[None, :], 
cache_modifier=".cg") if PRE_LOAD_V: V = tl.load(p_bv0 + off, mask=kvmask[:, None], cache_modifier=".cg") @@ -845,7 +846,7 @@ def flash_fwd_kernel( off = col_start * k_s_stride K = tl.load(p_bk0 + off, cache_modifier=".cg") if PRE_LOAD_V: - V = tl.load(V_ptr + V_offset, cache_modifier=".cg") + V = tl.load(p_bv0 + off, cache_modifier=".cg") S = tl.dot(Q, K) col_idx = col_start + tl.arange(0, BLOCK_N) @@ -919,12 +920,13 @@ def flash_fwd_kernel( O_ = tl.dot(P, V, O_) + # LSE # Note, rowsum = exp(-rowmax) * lse, therefore rowmax + log(rowsum) cancels the effect of rowmax and outputs lse only. lse = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), float('inf'), rowmax_ * softmax_scale + tl.log(rowsum_)) - inv_sum = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), 1.0, 1.0 / rowsum_) - + # Rescale output + inv_sum = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), 1.0, 1.0 / rowsum_) if is_dropout: O_ *= inv_sum[:, None] * rpdrop else: @@ -1101,7 +1103,6 @@ def flash_fwd_splitkv_kernel( is_local=is_local, has_alibi=has_alibi ) - # col_idx -= BLOCK_N O_, P, rowmax_, rowsum_ = softmax_rescale( O_, @@ -1119,21 +1120,16 @@ def flash_fwd_splitkv_kernel( else: V = tl.load(V_ptr + V_offset, mask=kvmask[:, None], cache_modifier=".cg") O_ = tl.dot(P, V, O_, allow_tf32=False) - # if n_masking_blocks > 1 and n_block <= n_block_min: - # break + for n_block in tl.range(n_block_max - n_masking_blocks - 1, n_block_min - 1, step=-1, num_stages=num_stages): col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N) K_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None] K = tl.load(K_ptr + K_offset, cache_modifier=".cg") - # if PRE_LOAD_V: - # V_offset = col_idx[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :] - # V = tl.load(V_ptr + V_offset, cache_modifier=".cg") - S = tl.dot(Q, K) - if PRE_LOAD_V: V_offset = col_idx[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :] V = tl.load(V_ptr + V_offset, cache_modifier=".cg") + S = tl.dot(Q, K) S = apply_mask( S, @@ -1149,7 +1145,6 @@ def flash_fwd_splitkv_kernel( is_local=is_local, has_alibi=has_alibi ) - # col_idx -= BLOCK_N O_, P, rowmax_, rowsum_ = softmax_rescale( O_, @@ -1246,8 +1241,7 @@ def flash_fwd_splitkv_combine_kernel( out_splits = tl.load(out_splits_ptr + out_split_offset, mask=out_split_mask, other=0) out = tl.sum(Zi_Z[:, :, None] * out_splits, 1) out = out.to(out_ptr.type.element_ty) - - # tl.device_print('O', out) + # Write back output out_offset = tl.arange(0, BLOCK_M)[:, None] * out_s_stride + tl.arange(0, head_size) tl.store(out_ptr + out_offset, out, mask=out_mask[:, None]) @@ -1364,16 +1358,18 @@ def splits_heuristics(num_tasks, num_sms, n_blocks): rpdrop = 1. 
/ p_dropout # Check splitkv - if not is_dropout: - n_tasks = batch_size * num_heads * triton.cdiv(seqlen_q, block_m_splitkv_heuristic(head_size)) + def try_split_kv(): + block_m = block_m_splitkv_heuristic(head_size) + n_tasks = batch_size * num_heads * triton.cdiv(seqlen_q, block_m) num_sms = torch_device_fn.get_device_properties("cuda").multi_processor_count - n_blocks = triton.cdiv(seqlen_k, block_n_splitkv_heuristic(head_size)) + block_n = block_n_splitkv_heuristic(head_size) + n_blocks = triton.cdiv(seqlen_k, block_n) n_splits = splits_heuristics(n_tasks, num_sms, n_blocks) - print('n_blocks:', n_blocks) - print('n_splits:', n_splits) - else: - n_splits = 1 - + return n_splits + + n_splits = try_split_kv() if is_dropout else 1 + print('n_splits:', n_splits) + if n_splits > 1: lse_splits = torch.empty( (n_splits, batch_size, num_heads, seqlen_q), From 04e49e00b458ed16fd1179187101e5ffb9a5e188 Mon Sep 17 00:00:00 2001 From: Tongxin Bai <waffle.bai@gmail.com> Date: Fri, 21 Mar 2025 16:41:07 +0000 Subject: [PATCH 16/25] Dropout passes. --- src/flag_gems/ops/attention.py | 93 +++++++++++++---------------- src/flag_gems/utils/random_utils.py | 7 +++ 2 files changed, 49 insertions(+), 51 deletions(-) diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py index 54248c4a1..b0f866ca1 100644 --- a/src/flag_gems/ops/attention.py +++ b/src/flag_gems/ops/attention.py @@ -422,24 +422,24 @@ def philox_(seed, subsequence, offset): kPhilox10B: tl.constexpr = 0xBB67AE85 k0, k1 = u64_to_lohi(seed.to(tl.uint64)) c0, c1 = u64_to_lohi(offset.to(tl.uint64)) - c2, c3 = u64_to_lohi(subsequence(tl.uint64)) + c2, c3 = u64_to_lohi(subsequence.to(tl.uint64)) # pragma unroll kPhiloxSA: tl.constexpr = 0xD2511F53 kPhiloxSB: tl.constexpr = 0xCD9E8D57 for _ in range(6): - res0 = kPhiloxSA.to(tl.uint64) * c0.to(tl.uint64) - res1 = kPhiloxSB.to(tl.uint64) * c2.to(tl.uint64) + res0 = kPhiloxSA * c0.to(tl.uint64) + res1 = kPhiloxSB * c2.to(tl.uint64) res0_x, res0_y = u64_to_lohi(res0) res1_x, res1_y = u64_to_lohi(res1) c0, c1, c2, c3 = res1_y ^ c1 ^ k0, res1_x, res0_y ^ c3 ^ k1, res0_x k0 += kPhilox10A k1 += kPhilox10B - res0 = kPhiloxSA.to(tl.uint64) * c0.to(tl.uint64) - res1 = kPhiloxSB.to(tl.uint64) * c2.to(tl.uint64) - res0_x.res0_y = u64_to_lohi(res0) - res1_x.res1_y = u64_to_lohi(res1) + res0 = kPhiloxSA * c0.to(tl.uint64) + res1 = kPhiloxSB * c2.to(tl.uint64) + res0_x, res0_y = u64_to_lohi(res0) + res1_x, res1_y = u64_to_lohi(res1) c0, c1, c2, c3 = res1_y ^ c1 ^ k0, res1_x, res0_y ^ c3 ^ k1, res0_x return c0, c1, c2, c3 @@ -462,35 +462,30 @@ def apply_dropout_mask( def make_4x_dropout_mask(r_u32, p_u8, M: tl.constexpr, N: tl.constexpr): r = r_u32 p = p_u8 - m0 = tl.where(r & 0xFF < p, 0, 1) + # m0 = tl.where(r & 0xFF < p, 0, 1) + m0 = ~(r & 0xFF < p) r >>= 8 - m1 = tl.where(r & 0xFF < p, 0, 1) - m0 = tl.join(m0, m1).trans(2, 0, 1).reshape(2 * M, N) + # m1 = tl.where(r & 0xFF < p, 0, 1) + m1 = ~(r & 0xFF < p) + m = tl.join(m0, m1).trans(2, 0, 1).reshape(2 * M, N) r >>= 8 - m0 = tl.where(r & 0xFF < p, 0, 1) + # n0 = tl.where(r & 0xFF < p, 0, 1) + n0 = ~(r & 0xFF < p) r >>= 8 - m1 = tl.where(r & 0xFF < p, 0, 1) - m1 = tl.join(m0, m1).trans(2, 0, 1).reshape(2 * M, N) - - m = tl.join(m0, m1).trans(2, 0, 1).reshape(4 * M, N) - return m - - -@triton.jit( - do_not_specialize=[ - "b", - "h", - "row_start", - "col_start", - "philox_seed", - "philox_offset", - ] -) + # n1 = tl.where(r & 0xFF < p, 0, 1) + n1 = ~(r & 0xFF < p) + n = tl.join(n0, n1).trans(2, 0, 1).reshape(2 * M, N) + + mn = tl.join(m, 
n).trans(2, 0, 1).reshape(4 * M, N) + return mn + + +@triton.jit def apply_dropout( P, - sor, - soc, + row_start, + col_start, bid, hid, philox_seed, @@ -501,30 +496,31 @@ def apply_dropout( BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): - # P is of size (BLOCK_M, BLOCK_N) and its scalar bitsize is 32 - # BLOCK_M is ensured to be a multiple of 16, BLOCK_N a multiple of 32 + # We only need one philox call for every 16 rows because a single philox call + # generates 4 random uints, which are casted for 16 random draws in uint8's. M: tl.constexpr = BLOCK_M // 16 N: tl.constexpr = BLOCK_N // 32 - row = sor + tl.arange(0, M)[:, None] - col = soc + tl.arange(0, BLOCK_N)[None, :] // 32 + row = row_start // 16 + tl.arange(0, M)[:, None] + col = col_start + tl.arange(0, BLOCK_N)[None, :] + + subsequence = u64_from_lohi(row, col // 32) tid = tl.arange(0, BLOCK_N)[None, :] % 32 philox_offset += (bid * NUM_HEADS + hid) * 32 + tid - - subsequence = u64_from_lohi(row * 32, col) + philox_offset += subsequence * 0 r0, r1, r2, r3 = philox_(philox_seed, subsequence, philox_offset) # Fully unrolled due to triton's inability to concat 2d tensor - m0 = make_4x_dropout_mask(r0, p_dropout_uint8, M, N) - m1 = make_4x_dropout_mask(r1, p_dropout_uint8, M, N) - m0 = tl.join(m0, m1).trans(2, 0, 1).reshape(8 * M, N) + m0 = make_4x_dropout_mask(r0, p_dropout_uint8, M, BLOCK_N) + m1 = make_4x_dropout_mask(r1, p_dropout_uint8, M, BLOCK_N) + m = tl.join(m0, m1).trans(2, 0, 1).reshape(8 * M, BLOCK_N) - m0 = make_4x_dropout_mask(r0, p_dropout_uint8, M, N) - m1 = make_4x_dropout_mask(r1, p_dropout_uint8, M, N) - m1 = tl.join(m0, m1).trans(2, 0, 1).reshape(8 * M, N) + n0 = make_4x_dropout_mask(r0, p_dropout_uint8, M, BLOCK_N) + n1 = make_4x_dropout_mask(r1, p_dropout_uint8, M, BLOCK_N) + n = tl.join(n0, n1).trans(2, 0, 1).reshape(8 * M, BLOCK_N) - m = tl.join(m0, m1).trans(2, 0, 1).reshape(16 * M, N) - P = apply_dropout_mask(P, m) + mn = tl.join(m, n).trans(2, 0, 1).reshape(16 * M, BLOCK_N) + P = apply_dropout_mask(P, mn, encode_dropout_in_sign_bit=encode_dropout_in_sign_bit) return P @@ -796,11 +792,9 @@ def flash_fwd_kernel( P = P.to(O_ptr.type.element_ty) if is_dropout: - row_start = m_block * (BLOCK_M // 16) - col_start = n_block * (BLOCK_N // 32) - if return_P: P_drop = P + P_drop = apply_dropout( P_drop, row_start, @@ -877,9 +871,6 @@ def flash_fwd_kernel( P = P.to(O_ptr.type.element_ty) if is_dropout: - row_start = m_block * (BLOCK_M // 16) - col_start = n_block * (BLOCK_N // 32) - if return_P: P_drop = P P_drop = apply_dropout( @@ -1346,7 +1337,7 @@ def splits_heuristics(num_tasks, num_sms, n_blocks): # Set dropout params if p_dropout > 0: - increment = triton.cdiv(batch_size * num_heads * 32) + increment = batch_size * num_heads * 32 philox_seed, philox_offset = update_philox_state(increment) is_dropout = True else: diff --git a/src/flag_gems/utils/random_utils.py b/src/flag_gems/utils/random_utils.py index 5ca61a8f4..ec0f771b5 100644 --- a/src/flag_gems/utils/random_utils.py +++ b/src/flag_gems/utils/random_utils.py @@ -46,6 +46,13 @@ def update_philox_state(increment, device=None): gen.set_state(state_copy) return seed, offset +def set_philox_state(seed, offset, device=None): + device = device or torch_device_fn.current_device() + gen = torch_device_fn.default_generators[device] + assert offset % 4 == 0 + new_state = torch.tensor((seed, offset), dtype=torch.int64) + gen = get.set_state(new_state.view(torch.uint8)) + return def per_thread_offset(N, num_blocks, num_warps, warp_threads=32): block_threads = num_warps 
* warp_threads From 2ee91ea76cf6bd3c6634b7cae76be1c0dbbe5e55 Mon Sep 17 00:00:00 2001 From: Tongxin Bai <waffle.bai@gmail.com> Date: Fri, 21 Mar 2025 16:45:36 +0000 Subject: [PATCH 17/25] dropout disables splitkv. --- src/flag_gems/ops/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py index b0f866ca1..c686d6230 100644 --- a/src/flag_gems/ops/attention.py +++ b/src/flag_gems/ops/attention.py @@ -1358,7 +1358,7 @@ def try_split_kv(): n_splits = splits_heuristics(n_tasks, num_sms, n_blocks) return n_splits - n_splits = try_split_kv() if is_dropout else 1 + n_splits = try_split_kv() if not is_dropout else 1 print('n_splits:', n_splits) if n_splits > 1: From 1b7e6c0e2577da13d43363775c509f3a6c0ef977 Mon Sep 17 00:00:00 2001 From: Tongxin Bai <waffle.bai@gmail.com> Date: Sun, 23 Mar 2025 05:34:31 +0000 Subject: [PATCH 18/25] Working on splitkv.. --- src/flag_gems/ops/attention.py | 217 +++++++++++++++++---------------- 1 file changed, 114 insertions(+), 103 deletions(-) diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py index c686d6230..8601e0bd7 100644 --- a/src/flag_gems/ops/attention.py +++ b/src/flag_gems/ops/attention.py @@ -427,7 +427,7 @@ def philox_(seed, subsequence, offset): # pragma unroll kPhiloxSA: tl.constexpr = 0xD2511F53 kPhiloxSB: tl.constexpr = 0xCD9E8D57 - for _ in range(6): + for _ in tl.static_range(6): res0 = kPhiloxSA * c0.to(tl.uint64) res1 = kPhiloxSB * c2.to(tl.uint64) res0_x, res0_y = u64_to_lohi(res0) @@ -462,22 +462,17 @@ def apply_dropout_mask( def make_4x_dropout_mask(r_u32, p_u8, M: tl.constexpr, N: tl.constexpr): r = r_u32 p = p_u8 - # m0 = tl.where(r & 0xFF < p, 0, 1) m0 = ~(r & 0xFF < p) r >>= 8 - # m1 = tl.where(r & 0xFF < p, 0, 1) m1 = ~(r & 0xFF < p) - m = tl.join(m0, m1).trans(2, 0, 1).reshape(2 * M, N) + m = tl.join(m0, m1) r >>= 8 - # n0 = tl.where(r & 0xFF < p, 0, 1) n0 = ~(r & 0xFF < p) r >>= 8 - # n1 = tl.where(r & 0xFF < p, 0, 1) n1 = ~(r & 0xFF < p) - n = tl.join(n0, n1).trans(2, 0, 1).reshape(2 * M, N) - - mn = tl.join(m, n).trans(2, 0, 1).reshape(4 * M, N) + n = tl.join(n0, n1) + mn = tl.join(m, n) return mn @@ -500,26 +495,28 @@ def apply_dropout( # generates 4 random uints, which are casted for 16 random draws in uint8's. 
M: tl.constexpr = BLOCK_M // 16 N: tl.constexpr = BLOCK_N // 32 + row_start = tl.multiple_of(row_start, BLOCK_M) + col_start = tl.multiple_of(col_start, BLOCK_N) row = row_start // 16 + tl.arange(0, M)[:, None] col = col_start + tl.arange(0, BLOCK_N)[None, :] subsequence = u64_from_lohi(row, col // 32) tid = tl.arange(0, BLOCK_N)[None, :] % 32 - philox_offset += (bid * NUM_HEADS + hid) * 32 + tid - philox_offset += subsequence * 0 - r0, r1, r2, r3 = philox_(philox_seed, subsequence, philox_offset) + offset = philox_offset + (bid * NUM_HEADS + hid) * 32 + tid + offset += subsequence * 0 + r0, r1, r2, r3 = philox_(philox_seed, subsequence, offset) # Fully unrolled due to triton's inability to concat 2d tensor m0 = make_4x_dropout_mask(r0, p_dropout_uint8, M, BLOCK_N) m1 = make_4x_dropout_mask(r1, p_dropout_uint8, M, BLOCK_N) - m = tl.join(m0, m1).trans(2, 0, 1).reshape(8 * M, BLOCK_N) + m = tl.join(m0, m1) n0 = make_4x_dropout_mask(r0, p_dropout_uint8, M, BLOCK_N) n1 = make_4x_dropout_mask(r1, p_dropout_uint8, M, BLOCK_N) - n = tl.join(n0, n1).trans(2, 0, 1).reshape(8 * M, BLOCK_N) + n = tl.join(n0, n1) - mn = tl.join(m, n).trans(2, 0, 1).reshape(16 * M, BLOCK_N) + mn = tl.join(m, n).reshape(16 * M, BLOCK_N) P = apply_dropout_mask(P, mn, encode_dropout_in_sign_bit=encode_dropout_in_sign_bit) return P @@ -608,11 +605,9 @@ def block_m_splitkv_heuristic(headdim): def block_n_splitkv_heuristic(headdim): if headdim <= 64: - return 256 - elif headdim <= 128: - return 128 - else: return 64 + else: + return 32 # @triton.autotune( # configs=runtime.get_tuned_config("attention"), @@ -718,8 +713,8 @@ def flash_fwd_kernel( if return_P: P_ptr += ((bid * NUM_HEADS + hid) * seqlen_q_rounded + m_block * BLOCK_M) * seqlen_k_rounded - P_ptr += (n_block_max - 1) * BLOCK_N P_offset = tl.arange(0, BLOCK_M)[:, None] * seqlen_k_rounded + tl.arange(0, BLOCK_N) + p_bp0 = P_ptr + P_offset O_ = tl.zeros((BLOCK_M, HEAD_DIM), dtype=tl.float32) rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) @@ -809,8 +804,10 @@ def flash_fwd_kernel( BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, ) - tl.store(P_ptr + P_offset, P_drop, mask=qmask & kvmask[None, :]) - P_offset += BLOCK_N + if IS_EVEN_MN: + tl.store(p_bp0 + col_start, P_drop) + else: + tl.store(p_bp0 + col_start, P_drop, mask=qmask & kvmask[None, :]) P = apply_dropout( P, @@ -887,8 +884,10 @@ def flash_fwd_kernel( BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, ) - tl.store(P_ptr + P_offset, P_drop, mask=qmask & kvmask) - P_offset += BLOCK_N + if IS_EVEN_MN: + tl.store(p_bp0 + col_start, P_drop) + else: + tl.store(p_bp0 + col_start, P_drop, mask=qmask & kvmask[None, :]) P = apply_dropout( P, @@ -954,7 +953,7 @@ def flash_fwd_kernel( 'IS_EVEN_MN': lambda args: (args["seqlen_q"] % args["BLOCK_M"] == 0) and (args["seqlen_k"] % args["BLOCK_N"] == 0), } ) -@triton.jit(do_not_specialize=["seqlen_q", "seqlen_k", "q_b_stride", "q_s_stride", "q_h_stride", "k_b_stride", "k_s_stride", "k_h_stride", "o_b_stride", "o_h_stride", "o_s_stride", "philox_seed", "philox_offset", "pdrop_u8", "slopes_batch_stride"]) +@triton.jit(do_not_specialize=["seqlen_q", "seqlen_k", "philox_seed", "philox_offset"]) def flash_fwd_splitkv_kernel( Q_ptr, K_ptr, @@ -1000,7 +999,7 @@ def flash_fwd_splitkv_kernel( IS_EVEN_MN: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, - num_splits: tl.constexpr, + blocks_per_split: tl.constexpr, num_warps: tl.constexpr, num_stages: tl.constexpr ): @@ -1008,15 +1007,16 @@ def flash_fwd_splitkv_kernel( split_id = tl.program_id(1) bid = tl.program_id(2) // NUM_HEADS hid = 
tl.program_id(2) % NUM_HEADS - - blocks_per_split = tl.cdiv(tl.cdiv(seqlen_k, BLOCK_N), num_splits) + + split_block_min = split_id * blocks_per_split + split_block_max = min((split_id + 1) * blocks_per_split, tl.cdiv(seqlen_k, BLOCK_N)) if is_local: - n_block_min = max(split_id * blocks_per_split, (m_block * BLOCK_M + seqlen_k - seqlen_q - ws_left) / BLOCK_N) + n_block_min = max(0, (m_block * BLOCK_M + seqlen_k - seqlen_q - ws_left) / BLOCK_N) else: - n_block_min = split_id * blocks_per_split + n_block_min = 0 - n_block_max = min((split_id + 1) * blocks_per_split, tl.cdiv(seqlen_k, BLOCK_N)) + n_block_max = tl.cdiv(seqlen_k, BLOCK_N) if is_causal or is_local: n_block_max = min(n_block_max, tl.cdiv((m_block + 1) * BLOCK_M + seqlen_k - seqlen_q + ws_right, BLOCK_N)) @@ -1030,56 +1030,59 @@ def flash_fwd_splitkv_kernel( if (not is_causal) and (not is_local): if IS_EVEN_MN: - n_masking_blocks = 0 + masking_block_min = n_block_max else: - n_masking_blocks = 1 + masking_block_min = n_block_max - 1 elif is_causal and IS_EVEN_MN: # causal implies ws_right is zero - n_masking_blocks = tl.cdiv(BLOCK_M, BLOCK_N) + masking_block_min = n_block_max - tl.cdiv(BLOCK_M, BLOCK_N) else: # local and not causal, - n_masking_blocks = tl.cdiv(BLOCK_M, BLOCK_N) + 1 + masking_block_min = n_block_max - tl.cdiv(BLOCK_M, BLOCK_N) - 1 Q_ptr += bid * q_b_stride Q_ptr += hid * q_h_stride row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M) Q_off = row_idx[:, None] * q_s_stride + tl.arange(0, HEAD_DIM)[None, :] + p_qm = Q_ptr + Q_off qmask = row_idx[:, None] < seqlen_q if IS_EVEN_MN: - Q = tl.load(Q_ptr + Q_off) + Q = tl.load(p_qm) else: - Q = tl.load(Q_ptr + Q_off, mask=qmask) - - # Start from the right most block - n_block = n_block_max - 1 + Q = tl.load(p_qm, mask=qmask) h_hk_ratio = h // hk K_ptr += bid * k_b_stride K_ptr += (hid // h_hk_ratio) * k_h_stride V_ptr += bid * k_b_stride V_ptr += (hid // h_hk_ratio) * k_h_stride + + K_offset = tl.arange(0, BLOCK_N)[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None] + p_k0 = K_ptr + K_offset - P_ptr += ((bid * NUM_HEADS + hid) * seqlen_q_rounded + m_block * BLOCK_M) * seqlen_k_rounded - P_ptr += n_block * BLOCK_N - P_offset = tl.arange(0, BLOCK_M)[:, None] * seqlen_k_rounded + tl.arange(0, BLOCK_N) + V_offset = tl.arange(0, BLOCK_N)[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :] + p_v0 = V_ptr + V_offset O_ = tl.zeros((BLOCK_M, HEAD_DIM), dtype=tl.float32) rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32) - - for n_block in tl.range(n_block_max - 1, n_block_max - n_masking_blocks - 1, step=-1): + + split_masking_block = max(masking_block_min, split_block_min) + for n_block in tl.range(split_block_max - 1, split_masking_block - 1, step=-1): + kv_off = n_block * BLOCK_N * k_s_stride col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N) - K_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None] - V_offset = col_idx[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :] + row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M) if IS_EVEN_MN: - K = tl.load(K_ptr + K_offset, cache_modifier=".cg") + K = tl.load(p_k0 + kv_off, cache_modifier=".cg") if PRE_LOAD_V: - V = tl.load(V_ptr + V_offset, cache_modifier=".cg") + V = tl.load(p_v0 + kv_off, cache_modifier=".cg") else: kvmask = col_idx < seqlen_k - K = tl.load(K_ptr + K_offset, mask=kvmask[None, :], cache_modifier=".cg") + K = tl.load(p_k0 + kv_off, mask=kvmask[None, :], cache_modifier=".cg") if PRE_LOAD_V: - V = tl.load(V_ptr + V_offset, 
mask=kvmask[:, None], cache_modifier=".cg") + V = tl.load(p_v0 + kv_off, mask=kvmask[:, None], cache_modifier=".cg") + S = tl.dot(Q, K, allow_tf32=False) + S = apply_mask( S, col_idx, @@ -1112,16 +1115,15 @@ def flash_fwd_splitkv_kernel( V = tl.load(V_ptr + V_offset, mask=kvmask[:, None], cache_modifier=".cg") O_ = tl.dot(P, V, O_, allow_tf32=False) - - for n_block in tl.range(n_block_max - n_masking_blocks - 1, n_block_min - 1, step=-1, num_stages=num_stages): - col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N) - K_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None] - K = tl.load(K_ptr + K_offset, cache_modifier=".cg") + nomasking_max = min(split_block_max, masking_block_min) + for n_block in tl.range(split_block_min, nomasking_max, num_stages=num_stages): + kv_off = n_block * BLOCK_N * k_s_stride + K = tl.load(p_k0 + kv_off, cache_modifier=".cg") if PRE_LOAD_V: - V_offset = col_idx[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :] - V = tl.load(V_ptr + V_offset, cache_modifier=".cg") + V = tl.load(p_v0 + kv_off, cache_modifier=".cg") S = tl.dot(Q, K) + col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N) S = apply_mask( S, col_idx, @@ -1149,8 +1151,7 @@ def flash_fwd_splitkv_kernel( P = P.to(Q_ptr.type.element_ty) if not PRE_LOAD_V: - V_offset = col_idx[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :] - V = tl.load(V_ptr + V_offset, cache_modifier=".cg") + V = tl.load(p_v0 + kv_off, cache_modifier=".cg") O_ = tl.dot(P, V, O_) @@ -1162,21 +1163,24 @@ def flash_fwd_splitkv_kernel( O_ *= inv_sum[:, None] # Write back output - O_split_ptr = O_ptr - # (n_splits, batch_size, num_heads, seqlen_q, head_size) + # O_splits layout = (n_splits, batch_size, num_heads, seqlen_q, head_size) # grid = (seq_block, split, batch * head) + O_split_ptr = O_ptr + # + split, batch, head offsets, seq_block offsets are already added in row_idx O_split_ptr += (split_id * tl.num_programs(2) + tl.program_id(2)) * seqlen_q * HEAD_DIM O_split_offset = row_idx[:, None] * HEAD_DIM + tl.arange(0, HEAD_DIM) + p_om = O_split_ptr + O_split_offset if IS_EVEN_MN: - tl.store(O_split_ptr + O_split_offset, O_) + tl.store(p_om, O_) else: - tl.store(O_split_ptr + O_split_offset, O_, mask=qmask) + tl.store(p_om, O_, mask=qmask) # Write back lse + # lse_splits layout = (n_splits, batch_size, num_heads, seqlen_q) lse_split_ptr = lse_ptr - lse_split_ptr += (split_id * tl.num_programs(2) + tl.program_id(2)) * seqlen_q - lse_split_ptr += m_block * BLOCK_M + # + split, batch, head, seq_block offsets + lse_split_ptr += (split_id * tl.num_programs(2) + tl.program_id(2)) * seqlen_q + m_block * BLOCK_M if IS_EVEN_MN: tl.store(lse_split_ptr + tl.arange(0, BLOCK_M), lse) @@ -1303,22 +1307,25 @@ def mha_fwd( seqlen_k_rounded = round_multiple(seqlen_k, 128) def splits_heuristics(num_tasks, num_sms, n_blocks): - # splits only number of waves and wave efficiency are both low + # splits when wave efficiency is low n_waves = triton.cdiv(num_tasks, num_sms) eff = (num_tasks / num_sms) / n_waves - if eff > 0.85 or n_waves > 10: + if eff > 0.85: return 1 - max_eff = eff + best_splits = 1 - for w in range(n_waves, 10): - n_splits = min(num_sms, n_blocks, w * num_sms // num_tasks) - blocks_per_split = triton.cdiv(n_blocks, n_splits) - if blocks_per_split < 4: - continue + best_eff = eff + min_blocks_per_split = 1 + max_blocks_per_split = triton.cdiv(n_blocks, 2) + for blocks_per_split in range(min_blocks_per_split, max_blocks_per_split + 1)[::-1]: n_splits = triton.cdiv(n_blocks, blocks_per_split) - eff = (n_splits * 
num_tasks / num_sms) / w - if eff > max_eff: - max_eff = eff + n_waves = triton.cdiv(n_splits * num_tasks, num_sms) + eff = (n_splits * num_tasks / num_sms) / n_waves + if eff > 0.85: + best_splits = n_splits + break + if eff > best_eff: + best_eff = eff best_splits = n_splits return best_splits @@ -1348,6 +1355,27 @@ def splits_heuristics(num_tasks, num_sms, n_blocks): pdrop_u8 = math.floor(p_dropout * 255.0) rpdrop = 1. / p_dropout + M_LOG2E = 1.4426950408889634074 + softmax_scale_log2e = softmax_scale * M_LOG2E + + # Set alibi params + if alibi_slopes is not None: + assert alibi_slopes.device == q_device + assert alibi_slopes.dtype in (torch.float, ) + assert alibi_slopes.stride(-1) == 1 + assert alibi_slopes.shape == (num_heads,) or alibi_slopes.shape == (batch_size, num_heads) + alibi_slopes_batch_stride = alibi_slopes.stride(0) if alibi_slopes.ndim == 2 else 0 + has_alibi = True + else: + alibi_slopes_batch_stride = 0 + has_alibi = False + + # Set SWA params + is_local = (window_size_left >= 0 or window_size_right >= 0) and not is_causal + + # ONLY EVEN_K IS SUPPORTED + assert head_size == head_size_rounded + # Check splitkv def try_split_kv(): block_m = block_m_splitkv_heuristic(head_size) @@ -1356,10 +1384,15 @@ def try_split_kv(): block_n = block_n_splitkv_heuristic(head_size) n_blocks = triton.cdiv(seqlen_k, block_n) n_splits = splits_heuristics(n_tasks, num_sms, n_blocks) - return n_splits - - n_splits = try_split_kv() if not is_dropout else 1 + blocks_per_split = triton.cdiv(n_blocks, n_splits) + return n_splits, blocks_per_split + + if not is_dropout: + n_splits, blocks_per_split = try_split_kv() + else: + n_splits, blocks_per_split = 1, None print('n_splits:', n_splits) + print('blocks_per_split', blocks_per_split) if n_splits > 1: lse_splits = torch.empty( @@ -1373,27 +1406,6 @@ def try_split_kv(): device=q_device ) - M_LOG2E = 1.4426950408889634074 - softmax_scale_log2e = softmax_scale * M_LOG2E - - # Set alibi params - if alibi_slopes is not None: - assert alibi_slopes.device == q_device - assert alibi_slopes.dtype in (torch.float, ) - assert alibi_slopes.stride(-1) == 1 - assert alibi_slopes.shape == (num_heads,) or alibi_slopes.shape == (batch_size, num_heads) - alibi_slopes_batch_stride = alibi_slopes.stride(0) if alibi_slopes.ndim == 2 else 0 - has_alibi = True - else: - alibi_slopes_batch_stride = 0 - has_alibi = False - - # Set SWA params - is_local = (window_size_left >= 0 or window_size_right >= 0) and not is_causal - - # ONLY EVEN_K IS SUPPORTED - assert head_size == head_size_rounded - # Launch kernel if n_splits > 1: grid = lambda args: ( @@ -1453,7 +1465,7 @@ def try_split_kv(): ws_right=window_size_right, return_P=return_softmax, BATCH_SIZE=batch_size, - num_splits=n_splits, + blocks_per_split=blocks_per_split, NUM_HEADS=num_heads, NUM_HEADS_K=num_heads_k, ) @@ -1484,7 +1496,6 @@ def try_split_kv(): q_total=batch_size * num_heads * seqlen_q, MAX_N_SPLITS=triton.next_power_of_2(n_splits), ) - print(f'{kernel.name} shared memory:', kernel.metadata.shared) if swap_seq_and_group: out = out.transpose(1, 2).reshape((batch_size, 1, num_heads_k * seqlen_q, head_size)) From a1c93be491cb403a823ba6190dbaf6fd3a83d3be Mon Sep 17 00:00:00 2001 From: Tongxin Bai <waffle.bai@gmail.com> Date: Sun, 23 Mar 2025 06:05:42 +0000 Subject: [PATCH 19/25] Working on splitkv.. 
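
Editorial note on this patch: it flips the LSE sentinel for empty rows from +inf to -inf. The sign matters once the per-split partials are reduced: a split that saw no keys for a query row must contribute zero weight, and only lse = -inf gives exp(lse_i - lse) = 0 in the log-sum-exp combine. A rough NumPy sketch of that reduction (illustrative only; hypothetical names, not the Triton combine kernel in this series), assuming every query row is covered by at least one non-empty split:

    import numpy as np

    def combine_splits(out_splits, lse_splits):
        # Illustrative sketch of the split reduction, not the combine kernel itself.
        # out_splits: (n_splits, seqlen_q, head_dim); lse_splits: (n_splits, seqlen_q)
        lse_max = lse_splits.max(axis=0)
        scale = np.exp(lse_splits - lse_max)   # exp(-inf - m) == 0: empty splits drop out
        sumexp = scale.sum(axis=0)
        lse = lse_max + np.log(sumexp)         # combined log-sum-exp over all splits
        out = (scale[..., None] * out_splits).sum(axis=0) / sumexp[..., None]
        return out, lse
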
--- src/flag_gems/ops/attention.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py index 8601e0bd7..2e34de23e 100644 --- a/src/flag_gems/ops/attention.py +++ b/src/flag_gems/ops/attention.py @@ -1156,7 +1156,9 @@ def flash_fwd_splitkv_kernel( O_ = tl.dot(P, V, O_) # LSE - lse = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), float('inf'), rowmax_ * softmax_scale + tl.log(rowsum_)) + # if (split_block_max <= n_block_min) or (split_block_min >= n_block_max): + + lse = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), float('-inf'), rowmax_ * softmax_scale + tl.log(rowsum_)) inv_sum = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), 1.0, 1.0 / rowsum_) # Rescale output From 24cb70e15ad75b940f221b2c4b42cc0ae58bfd84 Mon Sep 17 00:00:00 2001 From: Tongxin Bai <waffle.bai@gmail.com> Date: Sun, 23 Mar 2025 06:33:19 +0000 Subject: [PATCH 20/25] Splitkv passes but requires solid perf opt. --- src/flag_gems/ops/attention.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py index 2e34de23e..4d07179dd 100644 --- a/src/flag_gems/ops/attention.py +++ b/src/flag_gems/ops/attention.py @@ -674,7 +674,7 @@ def flash_fwd_kernel( IS_EVEN_MN: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, - num_splits: tl.constexpr, + blocks_per_split: tl.constexpr, num_warps: tl.constexpr, num_stages: tl.constexpr ): @@ -913,7 +913,7 @@ def flash_fwd_kernel( # LSE # Note, rowsum = exp(-rowmax) * lse, therefore rowmax + log(rowsum) cancels the effect of rowmax and outputs lse only. - lse = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), float('inf'), rowmax_ * softmax_scale + tl.log(rowsum_)) + lse = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), float('-inf'), rowmax_ * softmax_scale + tl.log(rowsum_)) # Rescale output inv_sum = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), 1.0, 1.0 / rowsum_) @@ -1156,8 +1156,6 @@ def flash_fwd_splitkv_kernel( O_ = tl.dot(P, V, O_) # LSE - # if (split_block_max <= n_block_min) or (split_block_min >= n_block_max): - lse = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), float('-inf'), rowmax_ * softmax_scale + tl.log(rowsum_)) inv_sum = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), 1.0, 1.0 / rowsum_) @@ -1312,7 +1310,7 @@ def splits_heuristics(num_tasks, num_sms, n_blocks): # splits when wave efficiency is low n_waves = triton.cdiv(num_tasks, num_sms) eff = (num_tasks / num_sms) / n_waves - if eff > 0.85: + if eff > 0.8 or n_waves > 1: return 1 best_splits = 1 @@ -1326,9 +1324,6 @@ def splits_heuristics(num_tasks, num_sms, n_blocks): if eff > 0.85: best_splits = n_splits break - if eff > best_eff: - best_eff = eff - best_splits = n_splits return best_splits with torch_device_fn.device(q_device): From 901cfdc626734a9aa84217ea559b392c72cbd0dd Mon Sep 17 00:00:00 2001 From: Tongxin Bai <waffle.bai@gmail.com> Date: Sun, 23 Mar 2025 15:08:58 +0000 Subject: [PATCH 21/25] Dirty hacking for debugging splitkv. 
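
Editorial note on this patch: the rewritten kernel (flash_fwd_splitkv_kernel_v2) walks one contiguous column range per split, starting at split_col_min = split_id * blocks_per_split * BLOCK_N and clamped to the block-rounded key length (and, when causal, to the causal bound, which is omitted below). A small sketch of how the split axis of the grid partitions the key sequence (hypothetical helper, for illustration only):

    import math

    def split_col_ranges(seqlen_k, BLOCK_N, blocks_per_split, n_splits):
        # Column range [lo, hi) visited by each split_id; the trailing partial
        # block is handled by the kvmask inside the kernel, not here.
        n_blocks = math.ceil(seqlen_k / BLOCK_N)
        cols_per_split = blocks_per_split * BLOCK_N
        return [(s * cols_per_split, min((s + 1) * cols_per_split, n_blocks * BLOCK_N))
                for s in range(n_splits)]

    # e.g. seqlen_k=1000, BLOCK_N=64 (16 key blocks), 4 splits of 4 blocks each:
    # [(0, 256), (256, 512), (512, 768), (768, 1024)]
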
--- src/flag_gems/ops/attention.py | 351 +++++++++++++++++++++++++-------- 1 file changed, 272 insertions(+), 79 deletions(-) diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py index 4d07179dd..025270edd 100644 --- a/src/flag_gems/ops/attention.py +++ b/src/flag_gems/ops/attention.py @@ -784,7 +784,7 @@ def flash_fwd_kernel( softmax_scale_log2e=softmax_scale_log2e, is_border=(is_causal or is_local), ) - P = P.to(O_ptr.type.element_ty) + P = P.to(V_ptr.type.element_ty) if is_dropout: if return_P: @@ -865,7 +865,7 @@ def flash_fwd_kernel( softmax_scale_log2e=softmax_scale_log2e, is_border=is_local, ) - P = P.to(O_ptr.type.element_ty) + P = P.to(V_ptr.type.element_ty) if is_dropout: if return_P: @@ -948,13 +948,13 @@ def flash_fwd_kernel( 'BLOCK_M': lambda args: block_m_splitkv_heuristic(args["HEAD_DIM"]), 'BLOCK_N': lambda args: block_m_splitkv_heuristic(args["HEAD_DIM"]), 'num_warps': lambda args: 4, - 'num_stages': lambda args: 3, + 'num_stages': lambda args: 3 if args["HEAD_DIM"] <= 128 else 2, 'PRE_LOAD_V': lambda args: True, 'IS_EVEN_MN': lambda args: (args["seqlen_q"] % args["BLOCK_M"] == 0) and (args["seqlen_k"] % args["BLOCK_N"] == 0), } ) @triton.jit(do_not_specialize=["seqlen_q", "seqlen_k", "philox_seed", "philox_offset"]) -def flash_fwd_splitkv_kernel( +def flash_fwd_splitkv_kernel_v2( Q_ptr, K_ptr, V_ptr, @@ -1008,18 +1008,16 @@ def flash_fwd_splitkv_kernel( bid = tl.program_id(2) // NUM_HEADS hid = tl.program_id(2) % NUM_HEADS - split_block_min = split_id * blocks_per_split - split_block_max = min((split_id + 1) * blocks_per_split, tl.cdiv(seqlen_k, BLOCK_N)) + split_col_min = split_id * blocks_per_split * BLOCK_N + split_col_max = split_col_min + blocks_per_split * BLOCK_N - if is_local: - n_block_min = max(0, (m_block * BLOCK_M + seqlen_k - seqlen_q - ws_left) / BLOCK_N) - else: - n_block_min = 0 + col_min = 0 + + col_max = tl.cdiv(seqlen_k, BLOCK_N) * BLOCK_N + if is_causal: + col_max = min(col_max, (m_block + 1) * BLOCK_M + seqlen_k - seqlen_q + ws_right) - n_block_max = tl.cdiv(seqlen_k, BLOCK_N) - if is_causal or is_local: - n_block_max = min(n_block_max, - tl.cdiv((m_block + 1) * BLOCK_M + seqlen_k - seqlen_q + ws_right, BLOCK_N)) + split_col_max = min(split_col_max, col_max) if has_alibi: alibi_offset = bid * slopes_batch_stride + hid @@ -1028,27 +1026,28 @@ def flash_fwd_splitkv_kernel( else: alibi_slope = 0.0 - if (not is_causal) and (not is_local): + if not is_causal: if IS_EVEN_MN: - masking_block_min = n_block_max + masking_cols: tl.constexpr = 0 else: - masking_block_min = n_block_max - 1 + masking_cols: tl.constexpr = BLOCK_N elif is_causal and IS_EVEN_MN: # causal implies ws_right is zero - masking_block_min = n_block_max - tl.cdiv(BLOCK_M, BLOCK_N) + masking_cols: tl.constexpr = tl.cdiv(BLOCK_M, BLOCK_N) * BLOCK_N else: # local and not causal, - masking_block_min = n_block_max - tl.cdiv(BLOCK_M, BLOCK_N) - 1 + masking_cols: tl.constexpr = (tl.cdiv(BLOCK_M, BLOCK_N) + 1) * BLOCK_N Q_ptr += bid * q_b_stride Q_ptr += hid * q_h_stride - row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M) + row_start = m_block * BLOCK_M + row_idx = row_start + tl.arange(0, BLOCK_M) Q_off = row_idx[:, None] * q_s_stride + tl.arange(0, HEAD_DIM)[None, :] p_qm = Q_ptr + Q_off qmask = row_idx[:, None] < seqlen_q if IS_EVEN_MN: - Q = tl.load(p_qm) + Q = tl.load(p_qm, cache_modifier=".cg") else: - Q = tl.load(p_qm, mask=qmask) + Q = tl.load(p_qm, mask=qmask, cache_modifier=".cg") h_hk_ratio = h // hk K_ptr += bid * k_b_stride @@ -1057,32 +1056,31 @@ def 
flash_fwd_splitkv_kernel( V_ptr += (hid // h_hk_ratio) * k_h_stride K_offset = tl.arange(0, BLOCK_N)[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None] - p_k0 = K_ptr + K_offset + p_bk0 = K_ptr + K_offset V_offset = tl.arange(0, BLOCK_N)[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :] - p_v0 = V_ptr + V_offset + p_bv0 = V_ptr + V_offset O_ = tl.zeros((BLOCK_M, HEAD_DIM), dtype=tl.float32) rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32) - split_masking_block = max(masking_block_min, split_block_min) - for n_block in tl.range(split_block_max - 1, split_masking_block - 1, step=-1): - kv_off = n_block * BLOCK_N * k_s_stride - col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N) - row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M) + for col_start in tl.range(split_col_min, split_col_max, step=BLOCK_N): + col_start = tl.multiple_of(col_start, BLOCK_N) + off = col_start * k_s_stride if IS_EVEN_MN: - K = tl.load(p_k0 + kv_off, cache_modifier=".cg") + K = tl.load(p_bk0 + off, cache_modifier=".cg") if PRE_LOAD_V: - V = tl.load(p_v0 + kv_off, cache_modifier=".cg") + V = tl.load(p_bv0 + off, cache_modifier=".cg") else: + col_idx = col_start + tl.arange(0, BLOCK_N) kvmask = col_idx < seqlen_k - K = tl.load(p_k0 + kv_off, mask=kvmask[None, :], cache_modifier=".cg") + K = tl.load(p_bk0 + off, mask=kvmask[None, :], cache_modifier=".cg") if PRE_LOAD_V: - V = tl.load(p_v0 + kv_off, mask=kvmask[:, None], cache_modifier=".cg") - + V = tl.load(p_bv0 + off, mask=kvmask[:, None], cache_modifier=".cg") S = tl.dot(Q, K, allow_tf32=False) - + col_idx = col_start + tl.arange(0, BLOCK_N) + row_idx = row_start + tl.arange(0, BLOCK_M) S = apply_mask( S, col_idx, @@ -1094,7 +1092,7 @@ def flash_fwd_splitkv_kernel( alibi_slope, is_even_mn=IS_EVEN_MN, is_causal=is_causal, - is_local=is_local, + is_local=False, has_alibi=has_alibi ) @@ -1106,54 +1104,242 @@ def flash_fwd_splitkv_kernel( softmax_scale_log2e=softmax_scale_log2e, is_border=(is_causal or is_local), ) - P = P.to(Q_ptr.type.element_ty) + P = P.to(V_ptr.type.element_ty) if not PRE_LOAD_V: + off = col_start * k_s_stride if IS_EVEN_MN: - V = tl.load(V_ptr + V_offset, cache_modifier=".cg") + V = tl.load(p_bv0 + off, cache_modifier=".cg") else: - V = tl.load(V_ptr + V_offset, mask=kvmask[:, None], cache_modifier=".cg") + V = tl.load(p_bv0 + off, mask=kvmask[:, None], cache_modifier=".cg") O_ = tl.dot(P, V, O_, allow_tf32=False) - nomasking_max = min(split_block_max, masking_block_min) - for n_block in tl.range(split_block_min, nomasking_max, num_stages=num_stages): - kv_off = n_block * BLOCK_N * k_s_stride - K = tl.load(p_k0 + kv_off, cache_modifier=".cg") - if PRE_LOAD_V: - V = tl.load(p_v0 + kv_off, cache_modifier=".cg") - S = tl.dot(Q, K) + # LSE + lse = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), float('-inf'), rowmax_ * softmax_scale + tl.log(rowsum_)) + inv_sum = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), 1.0, 1.0 / rowsum_) + + # Rescale output + O_ *= inv_sum[:, None] - col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N) - S = apply_mask( - S, - col_idx, - row_idx, - seqlen_q, - seqlen_k, - ws_left, - ws_right, - alibi_slope, - is_even_mn=True, - is_causal=False, - is_local=is_local, - has_alibi=has_alibi - ) + # Write back output + # O_splits layout = (n_splits, batch_size, num_heads, seqlen_q, head_size) + # grid = (seq_block, split, batch * head) + O_split_ptr = O_ptr + # + split, batch, head offsets, seq_block offsets are already added in row_idx + O_split_ptr += (split_id * 
tl.num_programs(2) + tl.program_id(2)) * seqlen_q * HEAD_DIM + O_split_offset = row_idx[:, None] * HEAD_DIM + tl.arange(0, HEAD_DIM) + O_split_ptr = tl.multiple_of(O_split_ptr, HEAD_DIM) + p_om = O_split_ptr + O_split_offset - O_, P, rowmax_, rowsum_ = softmax_rescale( - O_, - S, - rowmax_, - rowsum_, - softmax_scale_log2e=softmax_scale_log2e, - is_border=is_local, - ) + if IS_EVEN_MN: + tl.store(p_om, O_, cache_modifier=".cg") + else: + tl.store(p_om, O_, mask=qmask, cache_modifier=".cg") + + # Write back lse + # lse_splits layout = (n_splits, batch_size, num_heads, seqlen_q) + lse_split_ptr = lse_ptr + # + split, batch, head, seq_block offsets + lse_split_ptr += (split_id * tl.num_programs(2) + tl.program_id(2)) * seqlen_q + m_block * BLOCK_M - P = P.to(Q_ptr.type.element_ty) + if IS_EVEN_MN: + tl.store(lse_split_ptr + tl.arange(0, BLOCK_M), lse, cache_modifier=".cg") + else: + tl.store(lse_split_ptr + tl.arange(0, BLOCK_M), lse, mask=row_idx < seqlen_q, cache_modifier=".cg") - if not PRE_LOAD_V: - V = tl.load(p_v0 + kv_off, cache_modifier=".cg") - O_ = tl.dot(P, V, O_) +@triton.heuristics( + values={ + 'BLOCK_M': lambda args: block_m_splitkv_heuristic(args["HEAD_DIM"]), + 'BLOCK_N': lambda args: block_m_splitkv_heuristic(args["HEAD_DIM"]), + 'num_warps': lambda args: 4, + 'num_stages': lambda args: 3 if args["HEAD_DIM"] <= 128 else 2, + 'PRE_LOAD_V': lambda args: True, + 'IS_EVEN_MN': lambda args: (args["seqlen_q"] % args["BLOCK_M"] == 0) and (args["seqlen_k"] % args["BLOCK_N"] == 0), + } +) +@triton.jit(do_not_specialize=["seqlen_q", "seqlen_k", "philox_seed", "philox_offset"]) +def flash_fwd_splitkv_kernel( + Q_ptr, + K_ptr, + V_ptr, + P_ptr, + O_ptr, + lse_ptr, + seqlen_q, + seqlen_k, + seqlen_q_rounded, + seqlen_k_rounded, + q_b_stride, + q_s_stride, + q_h_stride, + k_b_stride, + k_s_stride, + k_h_stride, + o_b_stride, + o_s_stride, + o_h_stride, + h, + hk, + pSlopes, + philox_seed, + philox_offset, + pdrop_u8, + rpdrop, + slopes_batch_stride, + HEAD_DIM: tl.constexpr, + is_dropout: tl.constexpr, + is_causal: tl.constexpr, + is_local: tl.constexpr, + has_alibi: tl.constexpr, + softmax_scale: tl.constexpr, + softmax_scale_log2e: tl.constexpr, + ws_left: tl.constexpr, + ws_right: tl.constexpr, + return_P: tl.constexpr, + PRE_LOAD_V: tl.constexpr, + BATCH_SIZE: tl.constexpr, + NUM_HEADS: tl.constexpr, + NUM_HEADS_K: tl.constexpr, + IS_EVEN_MN: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + blocks_per_split: tl.constexpr, + num_warps: tl.constexpr, + num_stages: tl.constexpr +): + m_block = tl.program_id(0) + split_id = tl.program_id(1) + bid = tl.program_id(2) // NUM_HEADS + hid = tl.program_id(2) % NUM_HEADS + + split_block_min = split_id * blocks_per_split + split_block_max = split_block_min + blocks_per_split + + n_block_max = tl.cdiv(seqlen_k, BLOCK_N) + if is_causal: + n_block_max = min(n_block_max, + tl.cdiv((m_block + 1) * BLOCK_M + seqlen_k - seqlen_q + ws_right, BLOCK_N)) + + if has_alibi: + alibi_offset = bid * slopes_batch_stride + hid + alibi_slope = tl.load(pSlopes + alibi_offset) + alibi_slope /= scale + else: + alibi_slope = 0.0 + + if not is_causal: + if IS_EVEN_MN: + masking_block_min = n_block_max + else: + masking_block_min = n_block_max - 1 + elif is_causal and IS_EVEN_MN: # causal implies ws_right is zero + masking_block_min = n_block_max - tl.cdiv(BLOCK_M, BLOCK_N) + else: + masking_block_min = n_block_max - tl.cdiv(BLOCK_M, BLOCK_N) - 1 + + Q_ptr += bid * q_b_stride + Q_ptr += hid * q_h_stride + row_idx = m_block * BLOCK_M + tl.arange(0, 
BLOCK_M) + Q_off = row_idx[:, None] * q_s_stride + tl.arange(0, HEAD_DIM)[None, :] + p_qm = Q_ptr + Q_off + qmask = row_idx[:, None] < seqlen_q + if IS_EVEN_MN: + Q = tl.load(p_qm, cache_modifier=".cg") + else: + Q = tl.load(p_qm, mask=qmask, cache_modifier=".cg") + + h_hk_ratio = h // hk + K_ptr += bid * k_b_stride + K_ptr += (hid // h_hk_ratio) * k_h_stride + V_ptr += bid * k_b_stride + V_ptr += (hid // h_hk_ratio) * k_h_stride + + K_offset = tl.arange(0, BLOCK_N)[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None] + p_k0 = K_ptr + K_offset + + V_offset = tl.arange(0, BLOCK_N)[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :] + p_v0 = V_ptr + V_offset + + O_ = tl.zeros((BLOCK_M, HEAD_DIM), dtype=tl.float32) + rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32) + + if split_block_max <= masking_block_min: + # no masking needed + for n_block in tl.range(split_block_min, split_block_max, num_stages=num_stages): + kv_off = n_block * BLOCK_N * k_s_stride + K = tl.load(p_k0 + kv_off, cache_modifier=".cg") + if PRE_LOAD_V: + V = tl.load(p_v0 + kv_off, cache_modifier=".cg") + S = tl.dot(Q, K) + + col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N) + + if has_alibi: + S -= alibi_slope * tl.abs(col_idx[None, :] - row_idx[:, None]) + + O_, P, rowmax_, rowsum_ = softmax_rescale( + O_, + S, + rowmax_, + rowsum_, + softmax_scale_log2e=softmax_scale_log2e, + is_border=False, + ) + + if not PRE_LOAD_V: + V = tl.load(p_v0 + kv_off, cache_modifier=".cg") + P = P.to(Q_ptr.type.element_ty) + O_ = tl.dot(P, V, O_) + else: + for n_block in tl.range(split_block_min, min(split_block_max, n_block_max)): + kv_off = n_block * BLOCK_N * k_s_stride + col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N) + row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M) + if IS_EVEN_MN: + K = tl.load(p_k0 + kv_off, cache_modifier=".cg") + if PRE_LOAD_V: + V = tl.load(p_v0 + kv_off, cache_modifier=".cg") + else: + kvmask = col_idx < seqlen_k + K = tl.load(p_k0 + kv_off, mask=kvmask[None, :], cache_modifier=".cg") + if PRE_LOAD_V: + V = tl.load(p_v0 + kv_off, mask=kvmask[:, None], cache_modifier=".cg") + + S = tl.dot(Q, K, allow_tf32=False) + + S = apply_mask( + S, + col_idx, + row_idx, + seqlen_q, + seqlen_k, + ws_left, + ws_right, + alibi_slope, + is_even_mn=IS_EVEN_MN, + is_causal=is_causal, + is_local=False, + has_alibi=has_alibi + ) + + O_, P, rowmax_, rowsum_ = softmax_rescale( + O_, + S, + rowmax_, + rowsum_, + softmax_scale_log2e=softmax_scale_log2e, + is_border=(is_causal or is_local), + ) + + if not PRE_LOAD_V: + if IS_EVEN_MN: + V = tl.load(p_v0 + kv_off, cache_modifier=".cg") + else: + V = tl.load(p_v0 + kv_off, mask=kvmask[:, None], cache_modifier=".cg") + P = P.to(Q_ptr.type.element_ty) + O_ = tl.dot(P, V, O_, allow_tf32=False) # LSE lse = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), float('-inf'), rowmax_ * softmax_scale + tl.log(rowsum_)) @@ -1169,12 +1355,13 @@ def flash_fwd_splitkv_kernel( # + split, batch, head offsets, seq_block offsets are already added in row_idx O_split_ptr += (split_id * tl.num_programs(2) + tl.program_id(2)) * seqlen_q * HEAD_DIM O_split_offset = row_idx[:, None] * HEAD_DIM + tl.arange(0, HEAD_DIM) + O_split_ptr = tl.multiple_of(O_split_ptr, HEAD_DIM) p_om = O_split_ptr + O_split_offset if IS_EVEN_MN: - tl.store(p_om, O_) + tl.store(p_om, O_, cache_modifier=".cg") else: - tl.store(p_om, O_, mask=qmask) + tl.store(p_om, O_, mask=qmask, cache_modifier=".cg") # Write back lse # lse_splits layout = (n_splits, batch_size, 
num_heads, seqlen_q) @@ -1183,9 +1370,9 @@ def flash_fwd_splitkv_kernel( lse_split_ptr += (split_id * tl.num_programs(2) + tl.program_id(2)) * seqlen_q + m_block * BLOCK_M if IS_EVEN_MN: - tl.store(lse_split_ptr + tl.arange(0, BLOCK_M), lse) + tl.store(lse_split_ptr + tl.arange(0, BLOCK_M), lse, cache_modifier=".cg") else: - tl.store(lse_split_ptr + tl.arange(0, BLOCK_M), lse, mask=row_idx < seqlen_q) + tl.store(lse_split_ptr + tl.arange(0, BLOCK_M), lse, mask=row_idx < seqlen_q, cache_modifier=".cg") @triton.jit @@ -1384,10 +1571,16 @@ def try_split_kv(): blocks_per_split = triton.cdiv(n_blocks, n_splits) return n_splits, blocks_per_split - if not is_dropout: + if not is_dropout and not is_local: n_splits, blocks_per_split = try_split_kv() else: n_splits, blocks_per_split = 1, None + +# n_splits = 1 + block_n = block_n_splitkv_heuristic(head_size) + n_blocks = triton.cdiv(seqlen_k, block_n) + blocks_per_split = triton.cdiv(n_blocks, n_splits) + print('block_n:', block_n) print('n_splits:', n_splits) print('blocks_per_split', blocks_per_split) @@ -1410,7 +1603,7 @@ def try_split_kv(): n_splits, batch_size * num_heads ) - kernel = flash_fwd_splitkv_kernel[grid] + kernel = flash_fwd_splitkv_kernel_v2[grid] tmp_lse = lse_splits tmp_out = out_splits else: From 8e47b1d4b0f4a6d57d310f0ad9e7d45b3bbcf854 Mon Sep 17 00:00:00 2001 From: Tongxin Bai <waffle.bai@gmail.com> Date: Mon, 24 Mar 2025 04:27:19 +0000 Subject: [PATCH 22/25] fixed an error with splitkv block_n heuristics. --- src/flag_gems/__init__.py | 395 +++++++++++---------- src/flag_gems/ops/__init__.py | 614 ++++++++++++++++----------------- src/flag_gems/ops/attention.py | 97 +++--- 3 files changed, 545 insertions(+), 561 deletions(-) diff --git a/src/flag_gems/__init__.py b/src/flag_gems/__init__.py index 85cc04941..0bb7dc486 100644 --- a/src/flag_gems/__init__.py +++ b/src/flag_gems/__init__.py @@ -19,211 +19,206 @@ def enable(lib=aten_lib, unused=None, registrar=registrar): global current_work_registrar current_work_registrar = registrar( ( - # ("abs", abs, Autograd.disable), - # ("add.Tensor", add, Autograd.disable), - # ("addmm", addmm, Autograd.disable), - # ("arange.start_step", arange_start, Autograd.disable), - # ("arange.start", arange_start, Autograd.disable), - # ("arange", arange, Autograd.disable), - # ("batch_norm", batch_norm, Autograd.enable), - # ("bitwise_and.Tensor", bitwise_and_tensor, Autograd.disable), - # ("bitwise_and.Scalar", bitwise_and_scalar, Autograd.disable), - # ("bitwise_and.Scalar_Tensor", bitwise_and_scalar_tensor, Autograd.disable), - # ("bitwise_not", bitwise_not, Autograd.disable), - # ("bitwise_or.Tensor", bitwise_or_tensor, Autograd.disable), - # ("bitwise_or.Scalar", bitwise_or_scalar, Autograd.disable), - # ("bitwise_or.Scalar_Tensor", bitwise_or_scalar_tensor, Autograd.disable), - # # ("bmm", bmm, Autograd.disable), - # ("clamp", clamp, Autograd.disable), - # ("clamp.Tensor", clamp_tensor, Autograd.disable), - # ("cos", cos, Autograd.disable), - # ("pad", pad, Autograd.disable), - # ("constant_pad_nd", constant_pad_nd, Autograd.disable), - # ("cumsum", cumsum, Autograd.disable), - # ("cummin", cummin, Autograd.disable), - # ("div.Tensor", true_divide, Autograd.disable), - # ("div.Scalar", true_divide, Autograd.disable), - # ("div.Tensor_mode", div_mode, Autograd.disable), - # ("div.Scalar_mode", div_mode, Autograd.disable), - # ( - # "divide.Tensor", - # true_divide, - # Autograd.disable, - # ), # divide, an alias for div - # ("divide.Scalar", true_divide, Autograd.disable), - # 
("divide.Tensor_mode", div_mode, Autograd.disable), - # ("divide.Scalar_mode", div_mode, Autograd.disable), - # ( - # "true_divide.Tensor", - # true_divide, - # Autograd.disable, - # ), # true_divide, an alias for div - # ("true_divide.Scalar", true_divide, Autograd.disable), - # ("floor_divide", floor_divide, Autograd.disable), - # ("floor_divide.Scalar", floor_divide, Autograd.disable), - # ("remainder.Tensor", remainder, Autograd.disable), - # ("native_dropout", native_dropout, Autograd.enable), - # ("erf", erf, Autograd.disable), - # ("embedding", embedding, Autograd.enable), - # ("eq.Tensor", eq, Autograd.disable), - # ("eq.Scalar", eq_scalar, Autograd.disable), - # ("exp", exp, Autograd.disable), - # ("exponential_", exponential_, Autograd.disable), - # ("ge.Tensor", ge, Autograd.disable), - # ("ge.Scalar", ge_scalar, Autograd.disable), - # ("gelu", gelu, Autograd.enable), - # ("native_group_norm", group_norm, Autograd.enable), - # ("_weight_norm_interface", weight_norm_interface, Autograd.enable), - # ("_weight_norm", weight_norm, Autograd.enable), - # ("gt.Tensor", gt, Autograd.disable), - # ("gt.Scalar", gt_scalar, Autograd.disable), - # ("instance_norm", instance_norm, Autograd.enable), - # ("isfinite", isfinite, Autograd.disable), - # ("isin.Tensor_Tensor", isin, Autograd.disable), - # ("isin.Scalar_Tensor", isin, Autograd.disable), - # ("isin.Tensor_Scalar", isin, Autograd.disable), - # ("isinf", isinf, Autograd.disable), - # ("isnan", isnan, Autograd.disable), - # ("minimum", minimum, Autograd.disable), - # ("maximum", maximum, Autograd.disable), - # ("native_layer_norm", layer_norm, Autograd.enable), - # ("le.Tensor", le, Autograd.disable), - # ("le.Scalar", le_scalar, Autograd.disable), - # ("lt.Tensor", lt, Autograd.disable), - # ("lt.Scalar", lt_scalar, Autograd.disable), - # ("rms_norm", rms_norm, Autograd.disable), - # ("rand", rand, Autograd.disable), - # ("randn", randn, Autograd.disable), - # ("rand_like", rand_like, Autograd.disable), - # ("randn_like", randn_like, Autograd.disable), - # ("zeros", zeros, Autograd.disable), - # ("ones", ones, Autograd.disable), - # ("full", full, Autograd.disable), - # ("zeros_like", zeros_like, Autograd.disable), - # ("ones_like", ones_like, Autograd.disable), - # ("full_like", full_like, Autograd.disable), - # ("resolve_neg", resolve_neg, Autograd.disable), - # ("resolve_conj", resolve_conj, Autograd.disable), - # ("normal.Tensor_float", normal_tensor_float, Autograd.disable), - # ("normal.float_Tensor", normal_float_tensor, Autograd.disable), - # ("normal.Tensor_Tensor", normal_tensor_tensor, Autograd.disable), - # ("uniform_", uniform_, Autograd.disable), - # ("mean", mean, Autograd.disable), - # ("mean.dim", mean_dim, Autograd.disable), - # ("mm", mm, Autograd.disable), - # ("mul.Tensor", mul, Autograd.disable), - # ("multinomial", multinomial, Autograd.disable), - # ("mv", mv, Autograd.disable), - # ("ne.Tensor", ne, Autograd.disable), - # ("ne.Scalar", ne_scalar, Autograd.disable), - # ("neg", neg, Autograd.disable), - # ("pow.Scalar", pow_scalar, Autograd.disable), - # ("pow.Tensor_Scalar", pow_tensor_scalar, Autograd.disable), - # ("pow.Tensor_Tensor", pow_tensor_tensor, Autograd.disable), - # ("reciprocal", reciprocal, Autograd.disable), - # ("relu", relu, Autograd.enable), - # ("rsqrt", rsqrt, Autograd.disable), - # ("sigmoid", sigmoid, Autograd.enable), - # ("silu", silu, Autograd.enable), - # ("sin", sin, Autograd.disable), - # ("softmax.int", softmax, Autograd.enable), - # ("sort", sort, Autograd.disable), - # 
("sub.Tensor", sub, Autograd.disable), - # ("tanh", tanh, Autograd.enable), - # ("triu", triu, Autograd.disable), - # ("topk", topk, Autograd.disable), - # ("var_mean.correction", var_mean, Autograd.disable), - # ("linalg_vector_norm", vector_norm, Autograd.disable), - # ("where.self_out", where_self_out, Autograd.disable), - # ("where.self", where_self, Autograd.disable), - # ("where.ScalarSelf", where_scalar_self, Autograd.disable), - # ("where.ScalarOther", where_scalar_other, Autograd.disable), - # ("max", max, Autograd.disable), - # ("max.dim", max_dim, Autograd.disable), - # ("min", min, Autograd.disable), - # ("min.dim", min_dim, Autograd.disable), - # ("amax", amax, Autograd.disable), - # ("argmax", argmax, Autograd.disable), - # ("argmin", argmin, Autograd.disable), - # ("prod", prod, Autograd.disable), - # ("prod.dim_int", prod_dim, Autograd.disable), - # ("sum", sum, Autograd.disable), - # ("sum.dim_IntList", sum_dim, Autograd.disable), - # ( - # "scaled_dot_product_attention", - # scaled_dot_product_attention, - # Autograd.disable, - # ), + ("abs", abs, Autograd.disable), + ("add.Tensor", add, Autograd.disable), + ("addmm", addmm, Autograd.disable), + ("arange.start_step", arange_start, Autograd.disable), + ("arange.start", arange_start, Autograd.disable), + ("arange", arange, Autograd.disable), + ("batch_norm", batch_norm, Autograd.enable), + ("bitwise_and.Tensor", bitwise_and_tensor, Autograd.disable), + ("bitwise_and.Scalar", bitwise_and_scalar, Autograd.disable), + ("bitwise_and.Scalar_Tensor", bitwise_and_scalar_tensor, Autograd.disable), + ("bitwise_not", bitwise_not, Autograd.disable), + ("bitwise_or.Tensor", bitwise_or_tensor, Autograd.disable), + ("bitwise_or.Scalar", bitwise_or_scalar, Autograd.disable), + ("bitwise_or.Scalar_Tensor", bitwise_or_scalar_tensor, Autograd.disable), + # ("bmm", bmm, Autograd.disable), + ("clamp", clamp, Autograd.disable), + ("clamp.Tensor", clamp_tensor, Autograd.disable), + ("cos", cos, Autograd.disable), + ("pad", pad, Autograd.disable), + ("constant_pad_nd", constant_pad_nd, Autograd.disable), + ("cumsum", cumsum, Autograd.disable), + ("cummin", cummin, Autograd.disable), + ("div.Tensor", true_divide, Autograd.disable), + ("div.Scalar", true_divide, Autograd.disable), + ("div.Tensor_mode", div_mode, Autograd.disable), + ("div.Scalar_mode", div_mode, Autograd.disable), + ( + "divide.Tensor", + true_divide, + Autograd.disable, + ), # divide, an alias for div + ("divide.Scalar", true_divide, Autograd.disable), + ("divide.Tensor_mode", div_mode, Autograd.disable), + ("divide.Scalar_mode", div_mode, Autograd.disable), + ( + "true_divide.Tensor", + true_divide, + Autograd.disable, + ), # true_divide, an alias for div + ("true_divide.Scalar", true_divide, Autograd.disable), + ("floor_divide", floor_divide, Autograd.disable), + ("floor_divide.Scalar", floor_divide, Autograd.disable), + ("remainder.Tensor", remainder, Autograd.disable), + ("native_dropout", native_dropout, Autograd.enable), + ("erf", erf, Autograd.disable), + ("embedding", embedding, Autograd.enable), + ("eq.Tensor", eq, Autograd.disable), + ("eq.Scalar", eq_scalar, Autograd.disable), + ("exp", exp, Autograd.disable), + ("exponential_", exponential_, Autograd.disable), + ("ge.Tensor", ge, Autograd.disable), + ("ge.Scalar", ge_scalar, Autograd.disable), + ("gelu", gelu, Autograd.enable), + ("native_group_norm", group_norm, Autograd.enable), + ("_weight_norm_interface", weight_norm_interface, Autograd.enable), + ("_weight_norm", weight_norm, Autograd.enable), + ("gt.Tensor", gt, 
Autograd.disable), + ("gt.Scalar", gt_scalar, Autograd.disable), + ("instance_norm", instance_norm, Autograd.enable), + ("isfinite", isfinite, Autograd.disable), + ("isin.Tensor_Tensor", isin, Autograd.disable), + ("isin.Scalar_Tensor", isin, Autograd.disable), + ("isin.Tensor_Scalar", isin, Autograd.disable), + ("isinf", isinf, Autograd.disable), + ("isnan", isnan, Autograd.disable), + ("minimum", minimum, Autograd.disable), + ("maximum", maximum, Autograd.disable), + ("native_layer_norm", layer_norm, Autograd.enable), + ("le.Tensor", le, Autograd.disable), + ("le.Scalar", le_scalar, Autograd.disable), + ("lt.Tensor", lt, Autograd.disable), + ("lt.Scalar", lt_scalar, Autograd.disable), + ("rms_norm", rms_norm, Autograd.disable), + ("rand", rand, Autograd.disable), + ("randn", randn, Autograd.disable), + ("rand_like", rand_like, Autograd.disable), + ("randn_like", randn_like, Autograd.disable), + ("zeros", zeros, Autograd.disable), + ("ones", ones, Autograd.disable), + ("full", full, Autograd.disable), + ("zeros_like", zeros_like, Autograd.disable), + ("ones_like", ones_like, Autograd.disable), + ("full_like", full_like, Autograd.disable), + ("resolve_neg", resolve_neg, Autograd.disable), + ("resolve_conj", resolve_conj, Autograd.disable), + ("normal.Tensor_float", normal_tensor_float, Autograd.disable), + ("normal.float_Tensor", normal_float_tensor, Autograd.disable), + ("normal.Tensor_Tensor", normal_tensor_tensor, Autograd.disable), + ("uniform_", uniform_, Autograd.disable), + ("mean", mean, Autograd.disable), + ("mean.dim", mean_dim, Autograd.disable), + ("mm", mm, Autograd.disable), + ("mul.Tensor", mul, Autograd.disable), + ("multinomial", multinomial, Autograd.disable), + ("mv", mv, Autograd.disable), + ("ne.Tensor", ne, Autograd.disable), + ("ne.Scalar", ne_scalar, Autograd.disable), + ("neg", neg, Autograd.disable), + ("pow.Scalar", pow_scalar, Autograd.disable), + ("pow.Tensor_Scalar", pow_tensor_scalar, Autograd.disable), + ("pow.Tensor_Tensor", pow_tensor_tensor, Autograd.disable), + ("reciprocal", reciprocal, Autograd.disable), + ("relu", relu, Autograd.enable), + ("rsqrt", rsqrt, Autograd.disable), + ("sigmoid", sigmoid, Autograd.enable), + ("silu", silu, Autograd.enable), + ("sin", sin, Autograd.disable), + ("softmax.int", softmax, Autograd.enable), + ("sort", sort, Autograd.disable), + ("sub.Tensor", sub, Autograd.disable), + ("tanh", tanh, Autograd.enable), + ("triu", triu, Autograd.disable), + ("topk", topk, Autograd.disable), + ("var_mean.correction", var_mean, Autograd.disable), + ("linalg_vector_norm", vector_norm, Autograd.disable), + ("where.self_out", where_self_out, Autograd.disable), + ("where.self", where_self, Autograd.disable), + ("where.ScalarSelf", where_scalar_self, Autograd.disable), + ("where.ScalarOther", where_scalar_other, Autograd.disable), + ("max", max, Autograd.disable), + ("max.dim", max_dim, Autograd.disable), + ("min", min, Autograd.disable), + ("min.dim", min_dim, Autograd.disable), + ("amax", amax, Autograd.disable), + ("argmax", argmax, Autograd.disable), + ("argmin", argmin, Autograd.disable), + ("prod", prod, Autograd.disable), + ("prod.dim_int", prod_dim, Autograd.disable), + ("sum", sum, Autograd.disable), + ("sum.dim_IntList", sum_dim, Autograd.disable), ( "_flash_attention_forward", flash_attention_forward, Autograd.disable, ), - # ("all", all, Autograd.disable), - # ("all.dim", all_dim, Autograd.disable), - # ("all.dims", all_dims, Autograd.disable), - # ("any", any, Autograd.disable), - # ("any.dim", any_dim, Autograd.disable), - # 
("any.dims", any_dims, Autograd.disable), - # ("quantile", quantile, Autograd.disable), - # ("log_softmax.int", log_softmax, Autograd.enable), - # ("outer", outer, Autograd.enable), - # ("cross_entropy_loss", cross_entropy_loss, Autograd.enable), - # ("nll_loss_forward", nll_loss_forward, Autograd.disable), - # ("nll_loss_backward", nll_loss_backward, Autograd.disable), - # ("nll_loss2d_forward", nll_loss2d_forward, Autograd.disable), - # ("nll_loss2d_backward", nll_loss2d_backward, Autograd.disable), - # ("scatter.src", scatter, Autograd.disable), - # ("scatter.reduce", scatter, Autograd.disable), - # ("gather", gather, Autograd.disable), - # ("gather_backward", gather_backward, Autograd.disable), - # ("isclose", isclose, Autograd.disable), - # ("allclose", allclose, Autograd.disable), - # ("fill.Scalar", fill_scalar, Autograd.disable), - # ("fill.Tensor", fill_tensor, Autograd.disable), - # ("flip", flip, Autograd.disable), - # ("slice_scatter", slice_scatter, Autograd.disable), - # ("select_scatter", select_scatter, Autograd.disable), - # ("index_select", index_select, Autograd.disable), - # ("tile", tile, Autograd.disable), - # ("masked_fill.Tensor", masked_fill, Autograd.disable), - # ("masked_fill.Scalar", masked_fill, Autograd.disable), - # ("masked_fill_.Tensor", masked_fill_, Autograd.disable), - # ("masked_fill_.Scalar", masked_fill_, Autograd.disable), - # ("_unique2", _unique2, Autograd.disable), - # ("_upsample_bicubic2d_aa", _upsample_bicubic2d_aa, Autograd.disable), - # ("upsample_nearest2d", upsample_nearest2d, Autograd.disable), - # ("nonzero", nonzero, Autograd.disable), - # ("repeat", repeat, Autograd.disable), - # ("masked_select", masked_select, Autograd.disable), - # ("stack", stack, Autograd.disable), - # ("hstack", hstack, Autograd.disable), - # ("cat", cat, Autograd.disable), - # ( - # "repeat_interleave.self_int", - # repeat_interleave_self_int, - # Autograd.disable, - # ), - # ("vstack", vstack, Autograd.disable), - # ("repeat_interleave.Tensor", repeat_interleave_tensor, Autograd.disable), - # ( - # "repeat_interleave.self_Tensor", - # repeat_interleave_self_tensor, - # Autograd.disable, - # ), - # ("randperm", randperm, Autograd.disable), - # ("diag", diag, Autograd.disable), - # ("diag_embed", diag_embed, Autograd.disable), - # ("diagonal_backward", diagonal_backward, Autograd.disable), - # ("index_add", index_add, Autograd.disable), - # ("count_nonzero", count_nonzero, Autograd.disable), - # ("logical_or", logical_or, Autograd.disable), - # ("logical_and", logical_and, Autograd.disable), - # ("logical_xor", logical_xor, Autograd.disable), - # ("logical_not", logical_not, Autograd.disable), - # ("log_sigmoid", log_sigmoid, Autograd.disable), - # ("vdot", vdot, Autograd.disable), - # ("mse_loss", mse_loss, Autograd.disable), + ("all", all, Autograd.disable), + ("all.dim", all_dim, Autograd.disable), + ("all.dims", all_dims, Autograd.disable), + ("any", any, Autograd.disable), + ("any.dim", any_dim, Autograd.disable), + ("any.dims", any_dims, Autograd.disable), + ("quantile", quantile, Autograd.disable), + ("log_softmax.int", log_softmax, Autograd.enable), + ("outer", outer, Autograd.enable), + ("cross_entropy_loss", cross_entropy_loss, Autograd.enable), + ("nll_loss_forward", nll_loss_forward, Autograd.disable), + ("nll_loss_backward", nll_loss_backward, Autograd.disable), + ("nll_loss2d_forward", nll_loss2d_forward, Autograd.disable), + ("nll_loss2d_backward", nll_loss2d_backward, Autograd.disable), + ("scatter.src", scatter, Autograd.disable), + 
("scatter.reduce", scatter, Autograd.disable), + ("gather", gather, Autograd.disable), + ("gather_backward", gather_backward, Autograd.disable), + ("isclose", isclose, Autograd.disable), + ("allclose", allclose, Autograd.disable), + ("fill.Scalar", fill_scalar, Autograd.disable), + ("fill.Tensor", fill_tensor, Autograd.disable), + ("flip", flip, Autograd.disable), + ("slice_scatter", slice_scatter, Autograd.disable), + ("select_scatter", select_scatter, Autograd.disable), + ("index_select", index_select, Autograd.disable), + ("tile", tile, Autograd.disable), + ("masked_fill.Tensor", masked_fill, Autograd.disable), + ("masked_fill.Scalar", masked_fill, Autograd.disable), + ("masked_fill_.Tensor", masked_fill_, Autograd.disable), + ("masked_fill_.Scalar", masked_fill_, Autograd.disable), + ("_unique2", _unique2, Autograd.disable), + ("_upsample_bicubic2d_aa", _upsample_bicubic2d_aa, Autograd.disable), + ("upsample_nearest2d", upsample_nearest2d, Autograd.disable), + ("nonzero", nonzero, Autograd.disable), + ("repeat", repeat, Autograd.disable), + ("masked_select", masked_select, Autograd.disable), + ("stack", stack, Autograd.disable), + ("hstack", hstack, Autograd.disable), + ("cat", cat, Autograd.disable), + ( + "repeat_interleave.self_int", + repeat_interleave_self_int, + Autograd.disable, + ), + ("vstack", vstack, Autograd.disable), + ("repeat_interleave.Tensor", repeat_interleave_tensor, Autograd.disable), + ( + "repeat_interleave.self_Tensor", + repeat_interleave_self_tensor, + Autograd.disable, + ), + ("randperm", randperm, Autograd.disable), + ("diag", diag, Autograd.disable), + ("diag_embed", diag_embed, Autograd.disable), + ("diagonal_backward", diagonal_backward, Autograd.disable), + ("index_add", index_add, Autograd.disable), + ("count_nonzero", count_nonzero, Autograd.disable), + ("logical_or", logical_or, Autograd.disable), + ("logical_and", logical_and, Autograd.disable), + ("logical_xor", logical_xor, Autograd.disable), + ("logical_not", logical_not, Autograd.disable), + ("log_sigmoid", log_sigmoid, Autograd.disable), + ("vdot", vdot, Autograd.disable), + ("mse_loss", mse_loss, Autograd.disable), ), user_unused_ops_list=[] if unused is None else unused, lib=lib, diff --git a/src/flag_gems/ops/__init__.py b/src/flag_gems/ops/__init__.py index 01a31e633..2802e11f6 100755 --- a/src/flag_gems/ops/__init__.py +++ b/src/flag_gems/ops/__init__.py @@ -1,312 +1,312 @@ -# from .abs import abs -# from .add import add -# from .addmm import addmm -# from .all import all, all_dim, all_dims -# from .amax import amax -# from .any import any, any_dim, any_dims -# from .arange import arange, arange_start -# from .argmax import argmax -# from .argmin import argmin +from .abs import abs +from .add import add +from .addmm import addmm +from .all import all, all_dim, all_dims +from .amax import amax +from .any import any, any_dim, any_dims +from .arange import arange, arange_start +from .argmax import argmax +from .argmin import argmin from .attention import flash_attention_forward, scaled_dot_product_attention -# from .batch_norm import batch_norm -# from .bitwise_and import ( -# bitwise_and_scalar, -# bitwise_and_scalar_tensor, -# bitwise_and_tensor, -# ) -# from .bitwise_not import bitwise_not -# from .bitwise_or import bitwise_or_scalar, bitwise_or_scalar_tensor, bitwise_or_tensor -# from .bmm import bmm -# from .cat import cat -# from .clamp import clamp, clamp_tensor -# from .conv1d import conv1d -# from .conv2d import conv2d -# from .conv_depthwise2d import _conv_depthwise2d -# from .cos 
import cos -# from .count_nonzero import count_nonzero -# from .cross_entropy_loss import cross_entropy_loss -# from .cummin import cummin -# from .cumsum import cumsum, normed_cumsum -# from .diag import diag -# from .diag_embed import diag_embed -# from .diagonal import diagonal_backward -# from .div import div_mode, floor_divide, remainder, true_divide -# from .dropout import native_dropout -# from .embedding import embedding -# from .eq import eq, eq_scalar -# from .erf import erf -# from .exp import exp -# from .exponential_ import exponential_ -# from .fill import fill_scalar, fill_tensor -# from .flip import flip -# from .full import full -# from .full_like import full_like -# from .gather import gather, gather_backward -# from .ge import ge, ge_scalar -# from .gelu import gelu -# from .groupnorm import group_norm -# from .gt import gt, gt_scalar -# from .hstack import hstack -# from .index_add import index_add -# from .index_select import index_select -# from .instancenorm import instance_norm -# from .isclose import allclose, isclose -# from .isfinite import isfinite -# from .isin import isin -# from .isinf import isinf -# from .isnan import isnan -# from .layernorm import layer_norm -# from .le import le, le_scalar -# from .log_sigmoid import log_sigmoid -# from .log_softmax import log_softmax -# from .logical_and import logical_and -# from .logical_not import logical_not -# from .logical_or import logical_or -# from .logical_xor import logical_xor -# from .lt import lt, lt_scalar -# from .masked_fill import masked_fill, masked_fill_ -# from .masked_select import masked_select -# from .max import max, max_dim -# from .maximum import maximum -# from .mean import mean, mean_dim -# from .min import min, min_dim -# from .minimum import minimum -# from .mm import mm -# from .mse_loss import mse_loss -# from .mul import mul -# from .multinomial import multinomial -# from .mv import mv -# from .ne import ne, ne_scalar -# from .neg import neg -# from .nllloss import ( -# nll_loss2d_backward, -# nll_loss2d_forward, -# nll_loss_backward, -# nll_loss_forward, -# ) -# from .nonzero import nonzero -# from .normal import normal_float_tensor, normal_tensor_float, normal_tensor_tensor -# from .ones import ones -# from .ones_like import ones_like -# from .outer import outer -# from .pad import constant_pad_nd, pad -# from .pow import pow_scalar, pow_tensor_scalar, pow_tensor_tensor -# from .prod import prod, prod_dim -# from .quantile import quantile -# from .rand import rand -# from .rand_like import rand_like -# from .randn import randn -# from .randn_like import randn_like -# from .randperm import randperm -# from .reciprocal import reciprocal -# from .relu import relu -# from .repeat import repeat -# from .repeat_interleave import ( -# repeat_interleave_self_int, -# repeat_interleave_self_tensor, -# repeat_interleave_tensor, -# ) -# from .resolve_conj import resolve_conj -# from .resolve_neg import resolve_neg -# from .rms_norm import rms_norm -# from .rsqrt import rsqrt -# from .scatter import scatter -# from .select_scatter import select_scatter -# from .sigmoid import sigmoid -# from .silu import silu -# from .sin import sin -# from .slice_scatter import slice_scatter -# from .softmax import softmax -# from .sort import sort -# from .stack import stack -# from .sub import sub -# from .sum import sum, sum_dim -# from .tanh import tanh -# from .tile import tile -# from .topk import topk -# from .triu import triu -# from .uniform import uniform_ -# from .unique import _unique2 -# from 
.upsample_bicubic2d_aa import _upsample_bicubic2d_aa -# from .upsample_nearest2d import upsample_nearest2d -# from .var_mean import var_mean -# from .vdot import vdot -# from .vector_norm import vector_norm -# from .vstack import vstack -# from .weightnorm import weight_norm, weight_norm_interface -# from .where import where_scalar_other, where_scalar_self, where_self, where_self_out -# from .zeros import zeros -# from .zeros_like import zeros_like +from .batch_norm import batch_norm +from .bitwise_and import ( + bitwise_and_scalar, + bitwise_and_scalar_tensor, + bitwise_and_tensor, +) +from .bitwise_not import bitwise_not +from .bitwise_or import bitwise_or_scalar, bitwise_or_scalar_tensor, bitwise_or_tensor +from .bmm import bmm +from .cat import cat +from .clamp import clamp, clamp_tensor +from .conv1d import conv1d +from .conv2d import conv2d +from .conv_depthwise2d import _conv_depthwise2d +from .cos import cos +from .count_nonzero import count_nonzero +from .cross_entropy_loss import cross_entropy_loss +from .cummin import cummin +from .cumsum import cumsum, normed_cumsum +from .diag import diag +from .diag_embed import diag_embed +from .diagonal import diagonal_backward +from .div import div_mode, floor_divide, remainder, true_divide +from .dropout import native_dropout +from .embedding import embedding +from .eq import eq, eq_scalar +from .erf import erf +from .exp import exp +from .exponential_ import exponential_ +from .fill import fill_scalar, fill_tensor +from .flip import flip +from .full import full +from .full_like import full_like +from .gather import gather, gather_backward +from .ge import ge, ge_scalar +from .gelu import gelu +from .groupnorm import group_norm +from .gt import gt, gt_scalar +from .hstack import hstack +from .index_add import index_add +from .index_select import index_select +from .instancenorm import instance_norm +from .isclose import allclose, isclose +from .isfinite import isfinite +from .isin import isin +from .isinf import isinf +from .isnan import isnan +from .layernorm import layer_norm +from .le import le, le_scalar +from .log_sigmoid import log_sigmoid +from .log_softmax import log_softmax +from .logical_and import logical_and +from .logical_not import logical_not +from .logical_or import logical_or +from .logical_xor import logical_xor +from .lt import lt, lt_scalar +from .masked_fill import masked_fill, masked_fill_ +from .masked_select import masked_select +from .max import max, max_dim +from .maximum import maximum +from .mean import mean, mean_dim +from .min import min, min_dim +from .minimum import minimum +from .mm import mm +from .mse_loss import mse_loss +from .mul import mul +from .multinomial import multinomial +from .mv import mv +from .ne import ne, ne_scalar +from .neg import neg +from .nllloss import ( + nll_loss2d_backward, + nll_loss2d_forward, + nll_loss_backward, + nll_loss_forward, +) +from .nonzero import nonzero +from .normal import normal_float_tensor, normal_tensor_float, normal_tensor_tensor +from .ones import ones +from .ones_like import ones_like +from .outer import outer +from .pad import constant_pad_nd, pad +from .pow import pow_scalar, pow_tensor_scalar, pow_tensor_tensor +from .prod import prod, prod_dim +from .quantile import quantile +from .rand import rand +from .rand_like import rand_like +from .randn import randn +from .randn_like import randn_like +from .randperm import randperm +from .reciprocal import reciprocal +from .relu import relu +from .repeat import repeat +from .repeat_interleave import ( + 
repeat_interleave_self_int, + repeat_interleave_self_tensor, + repeat_interleave_tensor, +) +from .resolve_conj import resolve_conj +from .resolve_neg import resolve_neg +from .rms_norm import rms_norm +from .rsqrt import rsqrt +from .scatter import scatter +from .select_scatter import select_scatter +from .sigmoid import sigmoid +from .silu import silu +from .sin import sin +from .slice_scatter import slice_scatter +from .softmax import softmax +from .sort import sort +from .stack import stack +from .sub import sub +from .sum import sum, sum_dim +from .tanh import tanh +from .tile import tile +from .topk import topk +from .triu import triu +from .uniform import uniform_ +from .unique import _unique2 +from .upsample_bicubic2d_aa import _upsample_bicubic2d_aa +from .upsample_nearest2d import upsample_nearest2d +from .var_mean import var_mean +from .vdot import vdot +from .vector_norm import vector_norm +from .vstack import vstack +from .weightnorm import weight_norm, weight_norm_interface +from .where import where_scalar_other, where_scalar_self, where_self, where_self_out +from .zeros import zeros +from .zeros_like import zeros_like __all__ = [ -# "log_sigmoid", -# "all", -# "all_dim", -# "all_dims", -# "allclose", -# "any", -# "any_dim", -# "any_dims", -# "add", -# "abs", -# "addmm", -# "arange", -# "arange_start", -# "batch_norm", -# "bitwise_and_tensor", -# "bitwise_and_scalar", -# "bitwise_and_scalar_tensor", -# "bitwise_not", -# "bitwise_or_tensor", -# "bitwise_or_scalar", -# "bitwise_or_scalar_tensor", -# "bmm", -# "clamp", -# "clamp_tensor", -# "cos", -# "count_nonzero", -# "diag", -# "diag_embed", -# "diagonal_backward", -# "pad", -# "constant_pad_nd", -# "cummin", -# "cumsum", -# "normed_cumsum", -# "true_divide", -# "div_mode", -# "floor_divide", -# "remainder", -# "zeros", -# "ones", -# "full", -# "native_dropout", -# "erf", -# "embedding", -# "eq", -# "eq_scalar", -# "exp", -# "fill_scalar", -# "fill_tensor", -# "exponential_", -# "gather", -# "gather_backward", -# "flip", -# "ones_like", -# "full_like", -# "zeros_like", -# "ge", -# "ge_scalar", -# "gelu", -# "group_norm", -# "gt", -# "gt_scalar", -# "index_select", -# "instance_norm", -# "isclose", -# "isfinite", -# "isin", -# "isinf", -# "isnan", -# "layer_norm", -# "weight_norm_interface", -# "weight_norm", -# "le", -# "le_scalar", -# "lt", -# "lt_scalar", -# "rms_norm", -# "mean", -# "mean_dim", -# "mm", -# "mul", -# "multinomial", -# "maximum", -# "minimum", -# "rand", -# "randn", -# "randperm", -# "rand_like", -# "randn_like", -# "resolve_neg", -# "resolve_conj", -# "normal_tensor_float", -# "normal_float_tensor", -# "normal_tensor_tensor", -# "uniform_", -# "mv", -# "ne", -# "ne_scalar", -# "neg", -# "pow_scalar", -# "pow_tensor_scalar", -# "pow_tensor_tensor", -# "reciprocal", -# "relu", -# "rsqrt", -# "scatter", -# "sigmoid", -# "silu", -# "sin", -# "softmax", -# "sub", -# "tanh", -# "tile", -# "triu", -# "topk", -# "max", -# "max_dim", -# "min", -# "min_dim", -# "sum", -# "sum_dim", -# "amax", -# "argmax", -# "argmin", -# "prod", -# "prod_dim", -# "quantile", -# "var_mean", -# "vector_norm", -# "log_softmax", -# "outer", -# "cross_entropy_loss", -# "where_self_out", -# "where_self", -# "where_scalar_self", -# "where_scalar_other", -# "index_add", -# "select_scatter", -# "slice_scatter", -# "masked_fill", -# "masked_fill_", -# "_unique2", -# "_upsample_bicubic2d_aa", -# "upsample_nearest2d", -# "nonzero", -# "repeat", -# "masked_select", -# "stack", -# "hstack", -# "cat", -# "repeat_interleave_self_int", -# "vstack", 
-# "repeat_interleave_tensor", + "log_sigmoid", + "all", + "all_dim", + "all_dims", + "allclose", + "any", + "any_dim", + "any_dims", + "add", + "abs", + "addmm", + "arange", + "arange_start", + "batch_norm", + "bitwise_and_tensor", + "bitwise_and_scalar", + "bitwise_and_scalar_tensor", + "bitwise_not", + "bitwise_or_tensor", + "bitwise_or_scalar", + "bitwise_or_scalar_tensor", + "bmm", + "clamp", + "clamp_tensor", + "cos", + "count_nonzero", + "diag", + "diag_embed", + "diagonal_backward", + "pad", + "constant_pad_nd", + "cummin", + "cumsum", + "normed_cumsum", + "true_divide", + "div_mode", + "floor_divide", + "remainder", + "zeros", + "ones", + "full", + "native_dropout", + "erf", + "embedding", + "eq", + "eq_scalar", + "exp", + "fill_scalar", + "fill_tensor", + "exponential_", + "gather", + "gather_backward", + "flip", + "ones_like", + "full_like", + "zeros_like", + "ge", + "ge_scalar", + "gelu", + "group_norm", + "gt", + "gt_scalar", + "index_select", + "instance_norm", + "isclose", + "isfinite", + "isin", + "isinf", + "isnan", + "layer_norm", + "weight_norm_interface", + "weight_norm", + "le", + "le_scalar", + "lt", + "lt_scalar", + "rms_norm", + "mean", + "mean_dim", + "mm", + "mul", + "multinomial", + "maximum", + "minimum", + "rand", + "randn", + "randperm", + "rand_like", + "randn_like", + "resolve_neg", + "resolve_conj", + "normal_tensor_float", + "normal_float_tensor", + "normal_tensor_tensor", + "uniform_", + "mv", + "ne", + "ne_scalar", + "neg", + "pow_scalar", + "pow_tensor_scalar", + "pow_tensor_tensor", + "reciprocal", + "relu", + "rsqrt", + "scatter", + "sigmoid", + "silu", + "sin", + "softmax", + "sub", + "tanh", + "tile", + "triu", + "topk", + "max", + "max_dim", + "min", + "min_dim", + "sum", + "sum_dim", + "amax", + "argmax", + "argmin", + "prod", + "prod_dim", + "quantile", + "var_mean", + "vector_norm", + "log_softmax", + "outer", + "cross_entropy_loss", + "where_self_out", + "where_self", + "where_scalar_self", + "where_scalar_other", + "index_add", + "select_scatter", + "slice_scatter", + "masked_fill", + "masked_fill_", + "_unique2", + "_upsample_bicubic2d_aa", + "upsample_nearest2d", + "nonzero", + "repeat", + "masked_select", + "stack", + "hstack", + "cat", + "repeat_interleave_self_int", + "vstack", + "repeat_interleave_tensor", "flash_attention_forward", -# "scaled_dot_product_attention", -# "conv2d", -# "conv1d", -# "_conv_depthwise2d", -# "repeat_interleave_self_tensor", -# "logical_or", -# "logical_and", -# "logical_xor", -# "logical_not", -# "sort", -# "nll_loss_forward", -# "nll_loss_backward", -# "nll_loss2d_forward", -# "nll_loss2d_backward", -# "vdot", -# "mse_loss", + "scaled_dot_product_attention", + "conv2d", + "conv1d", + "_conv_depthwise2d", + "repeat_interleave_self_tensor", + "logical_or", + "logical_and", + "logical_xor", + "logical_not", + "sort", + "nll_loss_forward", + "nll_loss_backward", + "nll_loss2d_forward", + "nll_loss2d_backward", + "vdot", + "mse_loss", ] diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py index 025270edd..1c6b98460 100644 --- a/src/flag_gems/ops/attention.py +++ b/src/flag_gems/ops/attention.py @@ -521,7 +521,7 @@ def apply_dropout( return P -@triton.jit(do_not_specialize=['max_seqlen_q', 'max_seqlen_k']) +@triton.jit def apply_mask( S, col_idx, @@ -587,37 +587,23 @@ def softmax_rescale( def block_m_heuristic(headdim, is_dropout): block_m = 128 if headdim <= 128 else 64 - print('block_m:', block_m) return block_m def block_n_heuristic(headdim, is_dropout): block_n = 64 if headdim <= 64 else 
32 - print('block_n:', block_n) return block_n +def block_m_splitkv_heuristic(headdim): + return 128 if headdim <= 128 else 64 + +def block_n_splitkv_heuristic(headdim): + block_n = 64 if headdim <= 64 else 32 + def is_even_mn(args): even_mn = (args["seqlen_q"] % args["BLOCK_M"] == 0) and (args["seqlen_k"] % args["BLOCK_N"] == 0) - print('is_even_mn:', even_mn) return even_mn -def block_m_splitkv_heuristic(headdim): - return 128 if headdim <= 128 else 64 -def block_n_splitkv_heuristic(headdim): - if headdim <= 64: - return 64 - else: - return 32 - -# @triton.autotune( -# configs=runtime.get_tuned_config("attention"), -# key=["HEAD_DIM"], -# prune_configs_by={ -# "early_config_prune": early_config_prune, -# "perf_model": None, -# "top_k": 1.0, -# }, -# ) @triton.heuristics( values={ 'BLOCK_M': lambda args: block_m_heuristic(args["HEAD_DIM"], args["is_dropout"]), @@ -833,7 +819,7 @@ def flash_fwd_kernel( O_ = tl.dot(P, V, O_, allow_tf32=False) for col_start in tl.range(col_min, col_max - masking_cols, step=BLOCK_N, num_stages=num_stages): - # col_start = tl.multiple_of(col_start, BLOCK_N) + col_start = tl.multiple_of(col_start, BLOCK_N) off = col_start * k_s_stride K = tl.load(p_bk0 + off, cache_modifier=".cg") if PRE_LOAD_V: @@ -946,7 +932,7 @@ def flash_fwd_kernel( @triton.heuristics( values={ 'BLOCK_M': lambda args: block_m_splitkv_heuristic(args["HEAD_DIM"]), - 'BLOCK_N': lambda args: block_m_splitkv_heuristic(args["HEAD_DIM"]), + 'BLOCK_N': lambda args: block_n_splitkv_heuristic(args["HEAD_DIM"]), 'num_warps': lambda args: 4, 'num_stages': lambda args: 3 if args["HEAD_DIM"] <= 128 else 2, 'PRE_LOAD_V': lambda args: True, @@ -1023,8 +1009,6 @@ def flash_fwd_splitkv_kernel_v2( alibi_offset = bid * slopes_batch_stride + hid alibi_slope = tl.load(pSlopes + alibi_offset) alibi_slope /= scale - else: - alibi_slope = 0.0 if not is_causal: if IS_EVEN_MN: @@ -1151,7 +1135,7 @@ def flash_fwd_splitkv_kernel_v2( @triton.heuristics( values={ 'BLOCK_M': lambda args: block_m_splitkv_heuristic(args["HEAD_DIM"]), - 'BLOCK_N': lambda args: block_m_splitkv_heuristic(args["HEAD_DIM"]), + 'BLOCK_N': lambda args: block_n_splitkv_heuristic(args["HEAD_DIM"]), 'num_warps': lambda args: 4, 'num_stages': lambda args: 3 if args["HEAD_DIM"] <= 128 else 2, 'PRE_LOAD_V': lambda args: True, @@ -1225,8 +1209,6 @@ def flash_fwd_splitkv_kernel( alibi_offset = bid * slopes_batch_stride + hid alibi_slope = tl.load(pSlopes + alibi_offset) alibi_slope /= scale - else: - alibi_slope = 0.0 if not is_causal: if IS_EVEN_MN: @@ -1307,7 +1289,7 @@ def flash_fwd_splitkv_kernel( if PRE_LOAD_V: V = tl.load(p_v0 + kv_off, mask=kvmask[:, None], cache_modifier=".cg") - S = tl.dot(Q, K, allow_tf32=False) + S = tl.dot(Q, K) S = apply_mask( S, @@ -1339,7 +1321,7 @@ def flash_fwd_splitkv_kernel( else: V = tl.load(p_v0 + kv_off, mask=kvmask[:, None], cache_modifier=".cg") P = P.to(Q_ptr.type.element_ty) - O_ = tl.dot(P, V, O_, allow_tf32=False) + O_ = tl.dot(P, V, O_) # LSE lse = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), float('-inf'), rowmax_ * softmax_scale + tl.log(rowsum_)) @@ -1427,7 +1409,9 @@ def flash_fwd_splitkv_combine_kernel( # Write back output out_offset = tl.arange(0, BLOCK_M)[:, None] * out_s_stride + tl.arange(0, head_size) tl.store(out_ptr + out_offset, out, mask=out_mask[:, None]) - + + +_debug = False def mha_fwd( q, @@ -1500,17 +1484,20 @@ def splits_heuristics(num_tasks, num_sms, n_blocks): if eff > 0.8 or n_waves > 1: return 1 - best_splits = 1 - best_eff = eff - min_blocks_per_split = 1 - max_blocks_per_split = 
triton.cdiv(n_blocks, 2) - for blocks_per_split in range(min_blocks_per_split, max_blocks_per_split + 1)[::-1]: - n_splits = triton.cdiv(n_blocks, blocks_per_split) - n_waves = triton.cdiv(n_splits * num_tasks, num_sms) - eff = (n_splits * num_tasks / num_sms) / n_waves - if eff > 0.85: - best_splits = n_splits - break + min_blocks_per_split = 2 + best_splits = min(triton.cdiv(n_blocks, min_blocks_per_split), int(math.floor(1. / eff)), num_sms) + + # best_splits = 1 + # best_eff = eff + # min_blocks_per_split = 1 + # max_blocks_per_split = triton.cdiv(n_blocks, 2) + # for blocks_per_split in range(min_blocks_per_split, max_blocks_per_split + 1)[::-1]: + # n_splits = triton.cdiv(n_blocks, blocks_per_split) + # n_waves = triton.cdiv(n_splits * num_tasks, num_sms) + # eff = (n_splits * num_tasks / num_sms) / n_waves + # if eff > 0.85: + # best_splits = n_splits + # break return best_splits with torch_device_fn.device(q_device): @@ -1576,13 +1563,14 @@ def try_split_kv(): else: n_splits, blocks_per_split = 1, None -# n_splits = 1 - block_n = block_n_splitkv_heuristic(head_size) - n_blocks = triton.cdiv(seqlen_k, block_n) - blocks_per_split = triton.cdiv(n_blocks, n_splits) - print('block_n:', block_n) - print('n_splits:', n_splits) - print('blocks_per_split', blocks_per_split) + if _debug: + n_splits = 32 + block_n = block_n_splitkv_heuristic(head_size) + n_blocks = triton.cdiv(seqlen_k, block_n) + blocks_per_split = triton.cdiv(n_blocks, n_splits) + print('block_n:', block_n) + print('n_splits:', n_splits) + print('blocks_per_split', blocks_per_split) if n_splits > 1: lse_splits = torch.empty( @@ -1659,10 +1647,11 @@ def try_split_kv(): NUM_HEADS=num_heads, NUM_HEADS_K=num_heads_k, ) - print(f'{kernel.name} shared memory:', kernel.metadata.shared) - print(f'{kernel.name} num_warps:', kernel.metadata.num_warps) - print(f'{kernel.name} num_stages:', kernel.metadata.num_stages) - # print(kernel.asm['ttgir']) + if debug: + print(f'{kernel.name} shared memory:', kernel.metadata.shared) + print(f'{kernel.name} num_warps:', kernel.metadata.num_warps) + print(f'{kernel.name} num_stages:', kernel.metadata.num_stages) + print(kernel.asm['ttgir']) if n_splits > 1: if head_size % 128 == 0: @@ -1713,7 +1702,7 @@ def flash_attention_forward( seqused_k=None, alibi_slopes=None ): - logging.debug("GEMS FLASH_ATTENTION") + logging.debug("GEMS FLASH_ATTENTION_FORWARD") assert cum_seq_q is None and cum_seq_k is None, "varlen is not supported yet." HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1] From 298d5e8782b355b6257575b33ac363bb408aaf56 Mon Sep 17 00:00:00 2001 From: Tongxin Bai <waffle.bai@gmail.com> Date: Mon, 24 Mar 2025 10:04:11 +0000 Subject: [PATCH 23/25] Polish code. 
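
This round moves kernel selection out of mha_fwd into a single dispatch() helper: it returns the forward kernel plus an optional split-KV combine kernel, each paired with a dict of prefilled launch arguments, instead of branching inline around the launch. A rough stand-alone sketch of that shape (wave_efficiency, pick_kernels and the returned kernel names are illustrative placeholders, not the real FlagGems entry points; the 0.8 occupancy threshold and the two-blocks-per-split floor follow splits_heuristics):

    import math

    def wave_efficiency(n_tasks, num_sms):
        # Occupancy of a launch: 1.0 means every SM has work in every wave.
        n_waves = math.ceil(n_tasks / num_sms)
        return (n_tasks / num_sms) / n_waves

    def pick_kernels(batch, heads, seqlen_q, seqlen_k, block_m, block_n, num_sms):
        # Illustrative sketch, not the FlagGems implementation.
        # Stage 1: if the plain forward grid already fills the GPU, use it alone.
        n_tasks = batch * heads * math.ceil(seqlen_q / block_m)
        eff = wave_efficiency(n_tasks, num_sms)
        if eff > 0.8 or n_tasks > num_sms:
            return ("flash_fwd", {}), None
        # Stage 2: split the key/value blocks across extra CTAs (at least two
        # blocks per split) and add a combine pass over the partial results.
        n_blocks = math.ceil(seqlen_k / block_n)
        n_splits = max(1, min(math.ceil(n_blocks / 2), int(1.0 / eff), num_sms))
        if n_splits == 1:
            return ("flash_fwd", {}), None
        fwd = ("flash_fwd_splitkv", {"blocks_per_split": math.ceil(n_blocks / n_splits)})
        combine = ("splitkv_combine", {"n_splits": n_splits})
        return fwd, combine

    if __name__ == "__main__":
        print(pick_kernels(1, 2, 128, 8192, block_m=128, block_n=64, num_sms=108))

Keeping the decision in one place also leaves room for the batch-head parallel kernel (stubbed out below) to be slotted in later without touching the launch code.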
--- src/flag_gems/ops/attention.py | 420 ++++++++++++--------------------- 1 file changed, 150 insertions(+), 270 deletions(-) diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py index 1c6b98460..fcff69394 100644 --- a/src/flag_gems/ops/attention.py +++ b/src/flag_gems/ops/attention.py @@ -530,11 +530,11 @@ def apply_mask( max_seqlen_k, ws_left, ws_right, - alibi_slope, is_even_mn: tl.constexpr, is_causal: tl.constexpr, is_local: tl.constexpr, has_alibi: tl.constexpr, + alibi_slope: tl.constexpr=None, ): need_mask: tl.constexpr = is_causal | has_alibi | is_local | (not is_even_mn) if need_mask: @@ -597,7 +597,7 @@ def block_m_splitkv_heuristic(headdim): return 128 if headdim <= 128 else 64 def block_n_splitkv_heuristic(headdim): - block_n = 64 if headdim <= 64 else 32 + return 64 if headdim <= 64 else 32 def is_even_mn(args): even_mn = (args["seqlen_q"] % args["BLOCK_M"] == 0) and (args["seqlen_k"] % args["BLOCK_N"] == 0) @@ -755,11 +755,11 @@ def flash_fwd_kernel( seqlen_k, ws_left, ws_right, - alibi_slope, is_even_mn=IS_EVEN_MN, is_causal=is_causal, is_local=is_local, - has_alibi=has_alibi + has_alibi=has_alibi, + alibi_slope=alibi_slope, ) O_, P, rowmax_, rowsum_ = softmax_rescale( @@ -836,11 +836,11 @@ def flash_fwd_kernel( seqlen_k, ws_left, ws_right, - alibi_slope, is_even_mn=True, is_causal=False, is_local=is_local, - has_alibi=has_alibi + has_alibi=has_alibi, + alibi_slope=alibi_slope, ) O_, P, rowmax_, rowsum_ = softmax_rescale( @@ -940,7 +940,7 @@ def flash_fwd_kernel( } ) @triton.jit(do_not_specialize=["seqlen_q", "seqlen_k", "philox_seed", "philox_offset"]) -def flash_fwd_splitkv_kernel_v2( +def flash_fwd_bh_parallel_kernel( Q_ptr, K_ptr, V_ptr, @@ -989,148 +989,9 @@ def flash_fwd_splitkv_kernel_v2( num_warps: tl.constexpr, num_stages: tl.constexpr ): - m_block = tl.program_id(0) - split_id = tl.program_id(1) - bid = tl.program_id(2) // NUM_HEADS - hid = tl.program_id(2) % NUM_HEADS - - split_col_min = split_id * blocks_per_split * BLOCK_N - split_col_max = split_col_min + blocks_per_split * BLOCK_N - - col_min = 0 - - col_max = tl.cdiv(seqlen_k, BLOCK_N) * BLOCK_N - if is_causal: - col_max = min(col_max, (m_block + 1) * BLOCK_M + seqlen_k - seqlen_q + ws_right) - - split_col_max = min(split_col_max, col_max) - - if has_alibi: - alibi_offset = bid * slopes_batch_stride + hid - alibi_slope = tl.load(pSlopes + alibi_offset) - alibi_slope /= scale - - if not is_causal: - if IS_EVEN_MN: - masking_cols: tl.constexpr = 0 - else: - masking_cols: tl.constexpr = BLOCK_N - elif is_causal and IS_EVEN_MN: # causal implies ws_right is zero - masking_cols: tl.constexpr = tl.cdiv(BLOCK_M, BLOCK_N) * BLOCK_N - else: - # local and not causal, - masking_cols: tl.constexpr = (tl.cdiv(BLOCK_M, BLOCK_N) + 1) * BLOCK_N - - Q_ptr += bid * q_b_stride - Q_ptr += hid * q_h_stride - row_start = m_block * BLOCK_M - row_idx = row_start + tl.arange(0, BLOCK_M) - Q_off = row_idx[:, None] * q_s_stride + tl.arange(0, HEAD_DIM)[None, :] - p_qm = Q_ptr + Q_off - qmask = row_idx[:, None] < seqlen_q - if IS_EVEN_MN: - Q = tl.load(p_qm, cache_modifier=".cg") - else: - Q = tl.load(p_qm, mask=qmask, cache_modifier=".cg") - - h_hk_ratio = h // hk - K_ptr += bid * k_b_stride - K_ptr += (hid // h_hk_ratio) * k_h_stride - V_ptr += bid * k_b_stride - V_ptr += (hid // h_hk_ratio) * k_h_stride - - K_offset = tl.arange(0, BLOCK_N)[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None] - p_bk0 = K_ptr + K_offset - - V_offset = tl.arange(0, BLOCK_N)[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, 
:] - p_bv0 = V_ptr + V_offset - - O_ = tl.zeros((BLOCK_M, HEAD_DIM), dtype=tl.float32) - rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) - rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32) - - for col_start in tl.range(split_col_min, split_col_max, step=BLOCK_N): - col_start = tl.multiple_of(col_start, BLOCK_N) - off = col_start * k_s_stride - if IS_EVEN_MN: - K = tl.load(p_bk0 + off, cache_modifier=".cg") - if PRE_LOAD_V: - V = tl.load(p_bv0 + off, cache_modifier=".cg") - else: - col_idx = col_start + tl.arange(0, BLOCK_N) - kvmask = col_idx < seqlen_k - K = tl.load(p_bk0 + off, mask=kvmask[None, :], cache_modifier=".cg") - if PRE_LOAD_V: - V = tl.load(p_bv0 + off, mask=kvmask[:, None], cache_modifier=".cg") - S = tl.dot(Q, K, allow_tf32=False) - col_idx = col_start + tl.arange(0, BLOCK_N) - row_idx = row_start + tl.arange(0, BLOCK_M) - S = apply_mask( - S, - col_idx, - row_idx, - seqlen_q, - seqlen_k, - ws_left, - ws_right, - alibi_slope, - is_even_mn=IS_EVEN_MN, - is_causal=is_causal, - is_local=False, - has_alibi=has_alibi - ) - - O_, P, rowmax_, rowsum_ = softmax_rescale( - O_, - S, - rowmax_, - rowsum_, - softmax_scale_log2e=softmax_scale_log2e, - is_border=(is_causal or is_local), - ) - P = P.to(V_ptr.type.element_ty) - - if not PRE_LOAD_V: - off = col_start * k_s_stride - if IS_EVEN_MN: - V = tl.load(p_bv0 + off, cache_modifier=".cg") - else: - V = tl.load(p_bv0 + off, mask=kvmask[:, None], cache_modifier=".cg") - O_ = tl.dot(P, V, O_, allow_tf32=False) - - # LSE - lse = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), float('-inf'), rowmax_ * softmax_scale + tl.log(rowsum_)) - inv_sum = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), 1.0, 1.0 / rowsum_) - - # Rescale output - O_ *= inv_sum[:, None] - - # Write back output - # O_splits layout = (n_splits, batch_size, num_heads, seqlen_q, head_size) - # grid = (seq_block, split, batch * head) - O_split_ptr = O_ptr - # + split, batch, head offsets, seq_block offsets are already added in row_idx - O_split_ptr += (split_id * tl.num_programs(2) + tl.program_id(2)) * seqlen_q * HEAD_DIM - O_split_offset = row_idx[:, None] * HEAD_DIM + tl.arange(0, HEAD_DIM) - O_split_ptr = tl.multiple_of(O_split_ptr, HEAD_DIM) - p_om = O_split_ptr + O_split_offset - - if IS_EVEN_MN: - tl.store(p_om, O_, cache_modifier=".cg") - else: - tl.store(p_om, O_, mask=qmask, cache_modifier=".cg") + # (TODO) + pass - # Write back lse - # lse_splits layout = (n_splits, batch_size, num_heads, seqlen_q) - lse_split_ptr = lse_ptr - # + split, batch, head, seq_block offsets - lse_split_ptr += (split_id * tl.num_programs(2) + tl.program_id(2)) * seqlen_q + m_block * BLOCK_M - - if IS_EVEN_MN: - tl.store(lse_split_ptr + tl.arange(0, BLOCK_M), lse, cache_modifier=".cg") - else: - tl.store(lse_split_ptr + tl.arange(0, BLOCK_M), lse, mask=row_idx < seqlen_q, cache_modifier=".cg") - @triton.heuristics( values={ @@ -1209,6 +1070,8 @@ def flash_fwd_splitkv_kernel( alibi_offset = bid * slopes_batch_stride + hid alibi_slope = tl.load(pSlopes + alibi_offset) alibi_slope /= scale + else: + alibi_slope = 0 if not is_causal: if IS_EVEN_MN: @@ -1299,11 +1162,11 @@ def flash_fwd_splitkv_kernel( seqlen_k, ws_left, ws_right, - alibi_slope, is_even_mn=IS_EVEN_MN, is_causal=is_causal, is_local=False, - has_alibi=has_alibi + has_alibi=has_alibi, + alibi_slope=alibi_slope, ) O_, P, rowmax_, rowsum_ = softmax_rescale( @@ -1547,134 +1410,151 @@ def splits_heuristics(num_tasks, num_sms, n_blocks): # ONLY EVEN_K IS SUPPORTED assert head_size == head_size_rounded - # Check splitkv - def 
try_split_kv(): - block_m = block_m_splitkv_heuristic(head_size) - n_tasks = batch_size * num_heads * triton.cdiv(seqlen_q, block_m) + # Do kernel dispatching + def dispatch(B, H, Q, K, D): num_sms = torch_device_fn.get_device_properties("cuda").multi_processor_count - block_n = block_n_splitkv_heuristic(head_size) - n_blocks = triton.cdiv(seqlen_k, block_n) - n_splits = splits_heuristics(n_tasks, num_sms, n_blocks) - blocks_per_split = triton.cdiv(n_blocks, n_splits) - return n_splits, blocks_per_split - - if not is_dropout and not is_local: - n_splits, blocks_per_split = try_split_kv() - else: - n_splits, blocks_per_split = 1, None - if _debug: - n_splits = 32 - block_n = block_n_splitkv_heuristic(head_size) - n_blocks = triton.cdiv(seqlen_k, block_n) - blocks_per_split = triton.cdiv(n_blocks, n_splits) - print('block_n:', block_n) - print('n_splits:', n_splits) - print('blocks_per_split', blocks_per_split) - - if n_splits > 1: - lse_splits = torch.empty( - (n_splits, batch_size, num_heads, seqlen_q), - dtype=torch.float, - device=q_device - ) - out_splits = torch.empty( - (n_splits, batch_size, num_heads, seqlen_q, head_size), - dtype=torch.float, - device=q_device - ) - - # Launch kernel - if n_splits > 1: - grid = lambda args: ( - triton.cdiv(seqlen_q, args["BLOCK_M"]), - n_splits, - batch_size * num_heads - ) - kernel = flash_fwd_splitkv_kernel_v2[grid] - tmp_lse = lse_splits - tmp_out = out_splits - else: - grid = lambda args: ( - triton.cdiv(seqlen_q, args["BLOCK_M"]), # num_m_blocks - batch_size, - num_heads, - ) + default_args = {} + + # Try bh parallel + # if B * H > 0.8 * num_sms: + # kernel = flash_fwd_bh_parallel_kernel[(H, B)] + # # Yield kernel and prefilled args + # return kernel, default_args, None, None + + # Try splitkv + if not is_dropout and not is_local: + BM = block_m_splitkv_heuristic(D) + n_tasks = B * H * triton.cdiv(seqlen_q, BM) + BN = block_n_splitkv_heuristic(D) + n_blocks = triton.cdiv(seqlen_k, BN) + n_splits = splits_heuristics(n_tasks, num_sms, n_blocks) + + if _debug: + n_splits = 32 + n_blocks = triton.cdiv(K, BN) + blocks_per_split = triton.cdiv(n_blocks, n_splits) + print('block_n:', block_n) + print('n_splits:', n_splits) + print('blocks_per_split', blocks_per_split) + + if n_splits > 1: + lse_splits = torch.empty( + (n_splits, B, H, Q), + dtype=torch.float, + device=q_device + ) + out_splits = torch.empty( + (n_splits, B, H, Q, D), + dtype=torch.float, + device=q_device + ) + grid = lambda args: ( + triton.cdiv(Q, args["BLOCK_M"]), + n_splits, + B * H + ) + splitkv_kernel = flash_fwd_splitkv_kernel[grid] + blocks_per_split = triton.cdiv(n_blocks, n_splits) + splitkv_args = default_args.copy() + splitkv_args['blocks_per_split'] = blocks_per_split + splitkv_args['O_ptr'] = out_splits + splitkv_args['lse_ptr'] = lse_splits + # kernel = yield kernel, args + + if D % 128 == 0: + BLOCK_M = 4 + elif D % 64 == 0: + BLOCK_M = 8 + else: + BLOCK_M = 16 + grid = lambda args: (triton.cdiv(B * H * Q, BLOCK_M), ) + combine_kernel = flash_fwd_splitkv_combine_kernel[grid] + combine_args = { + 'out_splits_ptr': out_splits, + 'lse_splits_ptr': lse_splits, + 'n_splits': n_splits, + 'BLOCK_M': BLOCK_M, + 'q_total': B * H * Q, + 'MAX_N_SPLITS': triton.next_power_of_2(n_splits) + } + return splitkv_kernel, splitkv_args, combine_kernel, combine_args + + # Last option: flash_fwd + grid = lambda args: (triton.cdiv(Q, args["BLOCK_M"]), B, H, ) kernel = flash_fwd_kernel[grid] - tmp_lse = lse - tmp_out = out + return kernel, default_args, None, None + + kernel1, kernel1_args, 
kernel2, kernel2_args = dispatch(batch_size, num_heads, seqlen_q, seqlen_k, head_size) + + prefilled_args = { + "Q_ptr": q, + "K_ptr": k, + "V_ptr": v, + "P_ptr": p, + "O_ptr": out, + "lse_ptr": lse, + "seqlen_q": seqlen_q, + "seqlen_k": seqlen_k, + "seqlen_q_rounded": seqlen_q_rounded, + "seqlen_k_rounded": seqlen_k_rounded, + "q_b_stride": q.stride(0), + "q_s_stride": q.stride(-3), + "q_h_stride": q.stride(-2), + "k_b_stride": k.stride(0), + "k_s_stride": k.stride(-3), + "k_h_stride": k.stride(-2), + "o_b_stride": out.stride(0), + "o_s_stride": out.stride(-3), + "o_h_stride": out.stride(-2), + "h": num_heads, + "hk": num_heads_k, + "pSlopes": alibi_slopes, + "philox_seed": philox_seed, + "philox_offset": philox_offset, + "pdrop_u8": pdrop_u8, + "rpdrop": rpdrop, + "slopes_batch_stride": alibi_slopes_batch_stride, + "HEAD_DIM": head_size, + "is_dropout": is_dropout, + "is_causal": is_causal, + "is_local": is_local, + "has_alibi": has_alibi, + "softmax_scale": softmax_scale, + "softmax_scale_log2e": softmax_scale_log2e, + "ws_left": window_size_left, + "ws_right": window_size_right, + "return_P": return_softmax, + "BATCH_SIZE": batch_size, + "blocks_per_split": None, + "NUM_HEADS": num_heads, + "NUM_HEADS_K": num_heads_k, + } + + args_copy = prefilled_args.copy() + args_copy.update(kernel1_args) - kernel = kernel( - q, - k, - v, - p, - tmp_out, - tmp_lse, - seqlen_q, - seqlen_k, - seqlen_q_rounded, - seqlen_k_rounded, - q.stride(0), - q.stride(-3), - q.stride(-2), - k.stride(0), - k.stride(-3), - k.stride(-2), - out.stride(0), - out.stride(-3), - out.stride(-2), - num_heads, - num_heads_k, - alibi_slopes, - philox_seed, - philox_offset, - pdrop_u8, - rpdrop, - alibi_slopes_batch_stride, - head_size, - is_dropout=is_dropout, - is_causal=is_causal, - is_local=is_local, - has_alibi=has_alibi, - softmax_scale=softmax_scale, - softmax_scale_log2e=softmax_scale_log2e, - ws_left=window_size_left, - ws_right=window_size_right, - return_P=return_softmax, - BATCH_SIZE=batch_size, - blocks_per_split=blocks_per_split, - NUM_HEADS=num_heads, - NUM_HEADS_K=num_heads_k, - ) - if debug: + kernel = kernel1(**args_copy) + if _debug: print(f'{kernel.name} shared memory:', kernel.metadata.shared) print(f'{kernel.name} num_warps:', kernel.metadata.num_warps) print(f'{kernel.name} num_stages:', kernel.metadata.num_stages) print(kernel.asm['ttgir']) - - if n_splits > 1: - if head_size % 128 == 0: - BLOCK_M = 4 - elif head_size % 64 == 0: - BLOCK_M = 8 - else: - BLOCK_M = 16 - grid = lambda args: (triton.cdiv(batch_size * num_heads * seqlen_q, BLOCK_M), ) - kernel = flash_fwd_splitkv_combine_kernel[grid]( - out, - lse, - tmp_out, - tmp_lse, - head_size, - out.stride(0), - out.stride(-3), - out.stride(-1), - n_splits, - BLOCK_M, - q_total=batch_size * num_heads * seqlen_q, - MAX_N_SPLITS=triton.next_power_of_2(n_splits), - ) + + # Combine + if kernel2 is not None: + prefilled_args = { + "out_ptr": out, + "lse_ptr": lse, + "head_size": head_size, + "out_b_stride": out.stride(0), + "out_s_stride": out.stride(-3), + "out_h_stride": out.stride(-1), + } + args_copy = prefilled_args.copy() + args_copy.update(kernel2_args) + kernel2(**args_copy) + if swap_seq_and_group: out = out.transpose(1, 2).reshape((batch_size, 1, num_heads_k * seqlen_q, head_size)) From 6a741dd3470354246e41372ac34a718b0b4da4de Mon Sep 17 00:00:00 2001 From: Tongxin Bai <waffle.bai@gmail.com> Date: Thu, 27 Mar 2025 04:30:10 +0000 Subject: [PATCH 24/25] fixed numerous bugs. 
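
Main fixes: the sliding-window mask OR-ed two comparisons without parentheses (in Python and Triton, `|` binds tighter than `<` and `>`, so the window test did not compute the intended union), is_causal and is_local are now derived from the window sizes, masking_cols is clamped to the covering frame, and the lse store now offsets by (bid * h + hid) instead of bid * hid. For reference, a NumPy sketch of the intended sliding-window mask (swa_mask is a made-up name for illustration only; both bounds are treated as inclusive here):

    import numpy as np

    def swa_mask(seqlen_q, seqlen_k, ws_left, ws_right):
        # Illustrative reference only. True marks positions that get -inf.
        row_idx = np.arange(seqlen_q)[:, None]
        col_idx = np.arange(seqlen_k)[None, :]
        col_lb = np.maximum(0, row_idx + seqlen_k - seqlen_q - ws_left)
        col_rb = np.minimum(seqlen_k - 1, row_idx + seqlen_k - seqlen_q + ws_right)
        # The comparisons must be parenthesized before being OR-ed together.
        return (col_idx > col_rb) | (col_idx < col_lb)

    if __name__ == "__main__":
        # 4 queries against 6 keys, window of 1 to the left and 0 to the right.
        print(swa_mask(4, 6, ws_left=1, ws_right=0).astype(int))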
--- src/flag_gems/ops/attention.py | 216 +++++++++++++++++++-------------- tests/test_attention_ops.py | 177 ++++++++++++++++++++++++--- 2 files changed, 280 insertions(+), 113 deletions(-) diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py index fcff69394..6bc3d1b99 100644 --- a/src/flag_gems/ops/attention.py +++ b/src/flag_gems/ops/attention.py @@ -548,7 +548,7 @@ def apply_mask( S = tl.where(col_idx[None, :] > col_rb[:, None], float('-inf'), S) if is_local: - S = tl.where(col_idx[None, :] > col_rb[:, None] | col_idx[None, :] < col_lb[:, None], float('-inf'), S) + S = tl.where((col_idx[None, :] > col_rb[:, None]) | (col_idx[None, :] < col_lb[:, None]), float('-inf'), S) if (not is_local) & (not is_causal) & (not is_even_mn): S = tl.where(col_idx[None, :] >= max_seqlen_k, float('-inf'), S) @@ -676,7 +676,20 @@ def flash_fwd_kernel( col_max = tl.cdiv(seqlen_k, BLOCK_N) * BLOCK_N if is_causal or is_local: - col_max = min(col_max, (m_block + 1) * BLOCK_M + seqlen_k - seqlen_q + ws_right) + rounded_kv_end = max(0, tl.cdiv(seqlen_k - seqlen_q + ws_right, BLOCK_N) * BLOCK_N + (m_block + 1) * BLOCK_M) + col_max = min(col_max, rounded_kv_end) + + if (not is_causal) and (not is_local): + if IS_EVEN_MN: + masking_cols: tl.constexpr = 0 + else: + masking_cols: tl.constexpr = BLOCK_N + elif (is_causal | is_local) and IS_EVEN_MN: # causal implies ws_right is zero + masking_cols: tl.constexpr = tl.cdiv(BLOCK_M, BLOCK_N) * BLOCK_N + else: + # local + masking_cols: tl.constexpr = (tl.cdiv(BLOCK_M, BLOCK_N) + 1) * BLOCK_N + if has_alibi: alibi_offset = bid * slopes_batch_stride + hid @@ -719,65 +732,78 @@ def flash_fwd_kernel( p_bk0 = K_ptr + K_offset p_bv0 = V_ptr + V_offset - if (not is_causal) and (not is_local): - if IS_EVEN_MN: - masking_cols: tl.constexpr = 0 - else: - masking_cols: tl.constexpr = BLOCK_N - elif is_causal and IS_EVEN_MN: # causal implies ws_right is zero - masking_cols: tl.constexpr = tl.cdiv(BLOCK_M, BLOCK_N) * BLOCK_N - else: - # local and not causal, - masking_cols: tl.constexpr = (tl.cdiv(BLOCK_M, BLOCK_N) + 1) * BLOCK_N - - for col_shift in tl.range(0, masking_cols, step=BLOCK_N): - col_start = col_max - col_shift - BLOCK_N - col_start = tl.multiple_of(col_start, BLOCK_N) - off = col_start * k_s_stride - if IS_EVEN_MN: - K = tl.load(p_bk0 + off, cache_modifier=".cg") - if PRE_LOAD_V: - V = tl.load(p_bv0 + off, cache_modifier=".cg") - else: + if is_causal | is_local | (not IS_EVEN_MN): + # Cut short masking cols if there's not enough cols out there + masking_cols = min(col_max, masking_cols) + for col_shift in tl.range(0, masking_cols, step=BLOCK_N): + col_start = col_max - col_shift - BLOCK_N + col_start = tl.multiple_of(col_start, BLOCK_N) + off = col_start * k_s_stride + if IS_EVEN_MN: + K = tl.load(p_bk0 + off, cache_modifier=".cg") + if PRE_LOAD_V: + V = tl.load(p_bv0 + off, cache_modifier=".cg") + else: + col_idx = col_start + tl.arange(0, BLOCK_N) + kvmask = col_idx < seqlen_k + K = tl.load(p_bk0 + off, mask=kvmask[None, :], cache_modifier=".cg") + if PRE_LOAD_V: + V = tl.load(p_bv0 + off, mask=kvmask[:, None], cache_modifier=".cg") + S = tl.dot(Q, K, allow_tf32=False) col_idx = col_start + tl.arange(0, BLOCK_N) - kvmask = col_idx < seqlen_k - K = tl.load(p_bk0 + off, mask=kvmask[None, :], cache_modifier=".cg") - if PRE_LOAD_V: - V = tl.load(p_bv0 + off, mask=kvmask[:, None], cache_modifier=".cg") - S = tl.dot(Q, K, allow_tf32=False) - col_idx = col_start + tl.arange(0, BLOCK_N) - row_idx = row_start + tl.arange(0, BLOCK_M) - S = apply_mask( - S, 
- col_idx, - row_idx, - seqlen_q, - seqlen_k, - ws_left, - ws_right, - is_even_mn=IS_EVEN_MN, - is_causal=is_causal, - is_local=is_local, - has_alibi=has_alibi, - alibi_slope=alibi_slope, - ) + row_idx = row_start + tl.arange(0, BLOCK_M) - O_, P, rowmax_, rowsum_ = softmax_rescale( - O_, - S, - rowmax_, - rowsum_, - softmax_scale_log2e=softmax_scale_log2e, - is_border=(is_causal or is_local), - ) - P = P.to(V_ptr.type.element_ty) + # tl.store(p_bp0 + col_start, S) + S = apply_mask( + S, + col_idx, + row_idx, + seqlen_q, + seqlen_k, + ws_left, + ws_right, + is_even_mn=IS_EVEN_MN, + is_causal=is_causal, + is_local=is_local, + has_alibi=has_alibi, + alibi_slope=alibi_slope, + ) - if is_dropout: - if return_P: - P_drop = P + O_, P, rowmax_, rowsum_ = softmax_rescale( + O_, + S, + rowmax_, + rowsum_, + softmax_scale_log2e=softmax_scale_log2e, + is_border=(is_causal or is_local), + ) + P = P.to(V_ptr.type.element_ty) + + if is_dropout: + if return_P: + P_drop = P + + P_drop = apply_dropout( + P_drop, + row_start, + col_start, + bid, + hid, + philox_seed, + philox_offset, + pdrop_u8, + encode_dropout_in_sign_bit=True, + NUM_HEADS=NUM_HEADS, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + if IS_EVEN_MN: + tl.store(p_bp0 + col_start, P_drop) + else: + tl.store(p_bp0 + col_start, P_drop, mask=qmask & kvmask[None, :]) - P_drop = apply_dropout( - P_drop, + P = apply_dropout( + P, row_start, col_start, bid, @@ -785,38 +811,19 @@ def flash_fwd_kernel( philox_seed, philox_offset, pdrop_u8, - encode_dropout_in_sign_bit=True, + encode_dropout_in_sign_bit=False, NUM_HEADS=NUM_HEADS, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, ) + + if not PRE_LOAD_V: + off = col_start * k_s_stride if IS_EVEN_MN: - tl.store(p_bp0 + col_start, P_drop) + V = tl.load(p_bv0 + off, cache_modifier=".cg") else: - tl.store(p_bp0 + col_start, P_drop, mask=qmask & kvmask[None, :]) - - P = apply_dropout( - P, - row_start, - col_start, - bid, - hid, - philox_seed, - philox_offset, - pdrop_u8, - encode_dropout_in_sign_bit=False, - NUM_HEADS=NUM_HEADS, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - ) - - if not PRE_LOAD_V: - off = col_start * k_s_stride - if IS_EVEN_MN: - V = tl.load(p_bv0 + off, cache_modifier=".cg") - else: - V = tl.load(p_bv0 + off, mask=kvmask[:, None], cache_modifier=".cg") - O_ = tl.dot(P, V, O_, allow_tf32=False) + V = tl.load(p_bv0 + off, mask=kvmask[:, None], cache_modifier=".cg") + O_ = tl.dot(P, V, O_, allow_tf32=False) for col_start in tl.range(col_min, col_max - masking_cols, step=BLOCK_N, num_stages=num_stages): col_start = tl.multiple_of(col_start, BLOCK_N) @@ -899,7 +906,7 @@ def flash_fwd_kernel( # LSE # Note, rowsum = exp(-rowmax) * lse, therefore rowmax + log(rowsum) cancels the effect of rowmax and outputs lse only. 
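     # In exact arithmetic: rowsum_ = sum_j exp(softmax_scale * (S_ij - rowmax_)),
     # hence rowmax_ * softmax_scale + log(rowsum_) = logsumexp_j(softmax_scale * S_ij).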
- lse = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), float('-inf'), rowmax_ * softmax_scale + tl.log(rowsum_)) + lse = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), float('inf'), rowmax_ * softmax_scale + tl.log(rowsum_)) # Rescale output inv_sum = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), 1.0, 1.0 / rowsum_) @@ -922,11 +929,13 @@ def flash_fwd_kernel( tl.store(O_ptr + O_offset, O, mask=qmask) # Write back lse - lse_ptr += bid * hid * seqlen_q + p_lse = lse_ptr + (bid * h + hid) * seqlen_q + row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M) + if IS_EVEN_MN: - tl.store(lse_ptr + row_idx, lse) + tl.store(p_lse + row_idx, lse) else: - tl.store(lse_ptr + row_idx, lse, mask=row_idx < seqlen_q) + tl.store(p_lse + row_idx, lse, mask=row_idx < seqlen_q) @triton.heuristics( @@ -1288,6 +1297,7 @@ def mha_fwd( window_size_left, window_size_right, return_softmax, + disable_splitkv=False ): q_dtype = q.dtype q_device = q.device @@ -1405,7 +1415,8 @@ def splits_heuristics(num_tasks, num_sms, n_blocks): has_alibi = False # Set SWA params - is_local = (window_size_left >= 0 or window_size_right >= 0) and not is_causal + is_causal = window_size_left < 0 and window_size_right == 0 + is_local = (window_size_left >= 0 and window_size_right >= 0) # ONLY EVEN_K IS SUPPORTED assert head_size == head_size_rounded @@ -1423,7 +1434,7 @@ def dispatch(B, H, Q, K, D): # return kernel, default_args, None, None # Try splitkv - if not is_dropout and not is_local: + if not is_dropout and not is_local and not disable_splitkv: BM = block_m_splitkv_heuristic(D) n_tasks = B * H * triton.cdiv(seqlen_q, BM) BN = block_n_splitkv_heuristic(D) @@ -1434,7 +1445,7 @@ def dispatch(B, H, Q, K, D): n_splits = 32 n_blocks = triton.cdiv(K, BN) blocks_per_split = triton.cdiv(n_blocks, n_splits) - print('block_n:', block_n) + print('block_n:', BN) print('n_splits:', n_splits) print('blocks_per_split', blocks_per_split) @@ -1487,6 +1498,14 @@ def dispatch(B, H, Q, K, D): kernel1, kernel1_args, kernel2, kernel2_args = dispatch(batch_size, num_heads, seqlen_q, seqlen_k, head_size) + if _debug: + p = torch.empty( + (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded), + dtype=torch.float32, + device=q_device + ) + return_softmax = True + prefilled_args = { "Q_ptr": q, "K_ptr": k, @@ -1539,7 +1558,8 @@ def dispatch(B, H, Q, K, D): print(f'{kernel.name} shared memory:', kernel.metadata.shared) print(f'{kernel.name} num_warps:', kernel.metadata.num_warps) print(f'{kernel.name} num_stages:', kernel.metadata.num_stages) - print(kernel.asm['ttgir']) + # print(kernel.asm['ttgir']) + print('p:', p) # Combine if kernel2 is not None: @@ -1580,7 +1600,8 @@ def flash_attention_forward( window_size_left=None, window_size_right=None, seqused_k=None, - alibi_slopes=None + alibi_slopes=None, + disable_splitkv=False ): logging.debug("GEMS FLASH_ATTENTION_FORWARD") assert cum_seq_q is None and cum_seq_k is None, "varlen is not supported yet." 
@@ -1591,8 +1612,14 @@ def flash_attention_forward( assert HEAD_DIM_K in {16, 32, 64, 128, 256} softmax_scale = scale or 1.0 / (HEAD_DIM_K**0.5) - non_null_window_left = window_size_left or -1 - non_null_window_right = window_size_right or -1 + if window_size_left is not None: + non_null_window_left = window_size_left + else: + non_null_window_left = -1 + if window_size_right is not None: + non_null_window_right = window_size_right + else: + non_null_window_right = -1 out, q, k, v, lse, philox_seed, philox_offset, p = mha_fwd( query, @@ -1606,6 +1633,7 @@ def flash_attention_forward( non_null_window_left, non_null_window_right, return_debug_mask, + disable_splitkv=disable_splitkv ) return (out, lse, philox_seed, philox_offset, p) diff --git a/tests/test_attention_ops.py b/tests/test_attention_ops.py index 5af90fe85..ca5e07350 100644 --- a/tests/test_attention_ops.py +++ b/tests/test_attention_ops.py @@ -8,13 +8,16 @@ @pytest.mark.scaled_dot_product_attention -@pytest.mark.parametrize("batch", [8, 16]) -@pytest.mark.parametrize("num_head", [1, 8]) -@pytest.mark.parametrize("q_seq_len", [17, 64, 128]) -@pytest.mark.parametrize("kv_seq_len", [128, 2048]) -@pytest.mark.parametrize("head_size", [64, 128]) -@pytest.mark.parametrize("add_bias", [True, False]) -@pytest.mark.parametrize("is_causal", [True, False]) +@pytest.mark.parametrize(["batch", "num_head", "q_seq_len", "kv_seq_len"], + [ + (1, 1, 128, 2048), + (4, 8, 1024, 1024), + (4, 8, 1024, 128), + (4, 8, 17, 1030) + ]) +@pytest.mark.parametrize("head_size", [128]) +@pytest.mark.parametrize("add_bias", [False]) +@pytest.mark.parametrize("is_causal", [False, True]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) def test_scaled_dot_product_attention( batch, num_head, q_seq_len, kv_seq_len, head_size, add_bias, is_causal, dtype @@ -48,15 +51,57 @@ def test_scaled_dot_product_attention( scale = float(1.0 / np.sqrt(head_size)) - if is_causal: - torch_result = torch.nn.functional.scaled_dot_product_attention( - ref_query, - ref_key, - ref_value, + if is_causal and q_seq_len != kv_seq_len: + # Pytorch treats non-square causal as a special case where the sdp attention + # does not route to flash attn and instead uses mem-eff attn. 
+ # In this case, we directly compare on the lower level _flash_attention_forward + q = ref_query.transpose(1, 2) + k = ref_key.transpose(1, 2) + v = ref_value.transpose(1, 2) + out, *_ = torch.ops.aten._flash_attention_forward( + q, + k, + v, + None, + None, + q.shape[-3], + k.shape[-3], + 0, + is_causal, + False, scale=scale, - is_causal=is_causal, + window_size_left=None, + window_size_right=None, + seqused_k=None, + alibi_slopes=None ) + torch_result = out.transpose(1, 2) + + with flag_gems.use_gems(): + q = query.transpose(1, 2) + k = key.transpose(1, 2) + v = value.transpose(1, 2) + out, *_ = torch.ops.aten._flash_attention_forward( + q, + k, + v, + None, + None, + q.shape[-3], + k.shape[-3], + 0, + is_causal, + False, + scale=scale, + window_size_left=None, + window_size_right=None, + seqused_k=None, + alibi_slopes=None + ) + flaggem_result = out.transpose(1, 2) else: + attn_mask = None if is_causal else ref_attn_bias + torch_result = torch.nn.functional.scaled_dot_product_attention( ref_query, ref_key, @@ -66,13 +111,107 @@ def test_scaled_dot_product_attention( is_causal=is_causal, ) - with flag_gems.use_gems(): - if is_causal: - flaggem_result = torch.nn.functional.scaled_dot_product_attention( - query, key, value, scale=scale, is_causal=is_causal - ) - else: + with flag_gems.use_gems(): + attn_mask = None if is_causal else attn_bias flaggem_result = torch.nn.functional.scaled_dot_product_attention( query, key, value, attn_mask=attn_bias, scale=scale, is_causal=is_causal ) + + gems_assert_close(flaggem_result, torch_result, dtype) + + +@pytest.mark.flash_attention_forward +@pytest.mark.parametrize(["batch", "num_head", "q_seq_len", "kv_seq_len"], + [ + (1, 1, 128, 2048), + (8, 32, 1024, 1024), + (8, 32, 1024, 128), + (8, 32, 17, 1030) + ]) +@pytest.mark.parametrize("head_size", [128]) +@pytest.mark.parametrize(["window_size_left", "window_size_right"], + [ + (None, None), + (256, 0), + (128, 128) + ]) +@pytest.mark.parametrize("is_causal", [False]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_flash_attention_forward( + batch, num_head, q_seq_len, kv_seq_len, head_size, is_causal, window_size_left, window_size_right, dtype +): + np.random.seed(0) + np_query = np.random.uniform( + -0.05, 0.05, (batch, num_head, q_seq_len, head_size) + ).astype(np.float32) + np_key = np.random.uniform( + -0.05, 0.05, (batch, num_head, kv_seq_len, head_size) + ).astype(np.float32) + np_value = np.random.uniform( + -0.05, 0.05, (batch, num_head, kv_seq_len, head_size) + ).astype(np.float32) + np_attn_bias = np.random.uniform( + -0.05, 0.05, (batch, num_head, q_seq_len, kv_seq_len) + ).astype(np.float32) + + query = torch.tensor(np_query, device="cuda", dtype=dtype) + key = torch.tensor(np_key, device="cuda", dtype=dtype) + value = torch.tensor(np_value, device="cuda", dtype=dtype) + + ref_query = to_reference(query, False) + ref_key = to_reference(key, False) + ref_value = to_reference(value, False) + + scale = float(1.0 / np.sqrt(head_size)) + dropout_p = 0 + return_debug_mask = False + + q = ref_query.transpose(1, 2) + k = ref_key.transpose(1, 2) + v = ref_value.transpose(1, 2) + out, lse, _, _, _ = torch.ops.aten._flash_attention_forward( + q, + k, + v, + None, + None, + q.shape[-3], + k.shape[-3], + dropout_p, + is_causal, + return_debug_mask, + scale=scale, + window_size_left=window_size_left, + window_size_right=window_size_right, + seqused_k=None, + alibi_slopes=None + ) + torch_result = out.transpose(1, 2) + torch_lse = lse.transpose(1, 2) + + with 
flag_gems.use_gems(): + q = query.transpose(1, 2) + k = key.transpose(1, 2) + v = value.transpose(1, 2) + out, lse, _, _, _ = torch.ops.aten._flash_attention_forward( + q, + k, + v, + None, + None, + q.shape[-3], + k.shape[-3], + dropout_p, + is_causal, + return_debug_mask, + scale=scale, + window_size_left=window_size_left, + window_size_right=window_size_right, + seqused_k=None, + alibi_slopes=None + ) + flaggem_result = out.transpose(1, 2) + flaggem_lse = lse.transpose(1, 2) + gems_assert_close(flaggem_result, torch_result, dtype) + gems_assert_close(flaggem_lse, torch_lse, torch.float) \ No newline at end of file From 2aad10d6414a14aa5edb910c73cb1f25363a5e7d Mon Sep 17 00:00:00 2001 From: Tongxin Bai <waffle.bai@gmail.com> Date: Fri, 28 Mar 2025 08:24:44 +0000 Subject: [PATCH 25/25] Non-square, causal, swa, all pass. --- src/flag_gems/ops/attention.py | 52 +++++---- tests/test_attention_ops.py | 204 +++++++++++++++++++-------------- 2 files changed, 149 insertions(+), 107 deletions(-) diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py index 6bc3d1b99..c6e0f5b4a 100644 --- a/src/flag_gems/ops/attention.py +++ b/src/flag_gems/ops/attention.py @@ -538,8 +538,9 @@ def apply_mask( ): need_mask: tl.constexpr = is_causal | has_alibi | is_local | (not is_even_mn) if need_mask: + # Extra care should be taken to void one-off errors: both col_lb and col_rb are inclusive! col_lb = max(0, row_idx + max_seqlen_k - max_seqlen_q - ws_left) - col_rb = min(max_seqlen_k, row_idx + max_seqlen_k - max_seqlen_q + ws_right) + col_rb = min(max_seqlen_k - 1, row_idx + max_seqlen_k - max_seqlen_q + ws_right) if has_alibi: S -= alibi_slope * tl.abs(col_idx[None, :] - row_idx[:, None]) @@ -579,8 +580,8 @@ def softmax_rescale( O_acc *= p_scale[:, None] max_scaled = tl.where(row_max == float('-inf'), 0, row_max * softmax_scale_log2e) - P = tl.math.exp2(S * softmax_scale_log2e - max_scaled[:, None]) + cur_rowsum = tl.sum(P, 1) row_sum = row_sum + tl.sum(P, 1) return O_acc, P, row_max, row_sum @@ -599,9 +600,12 @@ def block_m_splitkv_heuristic(headdim): def block_n_splitkv_heuristic(headdim): return 64 if headdim <= 64 else 32 -def is_even_mn(args): - even_mn = (args["seqlen_q"] % args["BLOCK_M"] == 0) and (args["seqlen_k"] % args["BLOCK_N"] == 0) - return even_mn +def is_even_mn(M, N, BM, BN, WL, WR): + if M % BM == 0 and N % BN == 0: + if M % N == 0 or N % M == 0: + if (WL == -1 or WL % BN == 0) and (WR == -1 or WR % BN == 0): + return True + return False @triton.heuristics( @@ -611,7 +615,7 @@ def is_even_mn(args): 'num_warps': lambda args: 4, 'num_stages': lambda args: 3 if args["HEAD_DIM"] <= 128 else 2, 'PRE_LOAD_V': lambda args: False, - 'IS_EVEN_MN': lambda args: is_even_mn(args), + 'IS_EVEN_MN': lambda args: is_even_mn(args["seqlen_q"], args["seqlen_k"], args["BLOCK_M"], args["BLOCK_N"], args["ws_left"], args["ws_right"]), } ) @triton.jit(do_not_specialize=["seqlen_q", "seqlen_k", "philox_seed", "philox_offset"]) @@ -667,17 +671,28 @@ def flash_fwd_kernel( m_block = tl.program_id(0) bid = tl.program_id(1) hid = tl.program_id(2) + num_m_blocks = tl.cdiv(seqlen_q, BLOCK_M) + + # We draw a minimum covering frame on the attention map that this CTA is assigned to process. + # The frame edges are rounded to multiples of BLOCK_M and BLOCK_N for rows and columns respectively. 
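+    # Rounding the frame outward is safe: apply_mask still filters the extra
+    # columns element-wise, so no out-of-window key/value block can contribute.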
+ col_min = 0 if is_local: - col_min = m_block * BLOCK_M + seqlen_k - seqlen_q - ws_left - col_min = max(col_min, 0) - else: - col_min = 0 + col_min = max(0, m_block * BLOCK_M + seqlen_k - seqlen_q - ws_left) + if not IS_EVEN_MN: + # round left + col_min = (col_min // BLOCK_N) * BLOCK_N - col_max = tl.cdiv(seqlen_k, BLOCK_N) * BLOCK_N + col_max = seqlen_k if is_causal or is_local: - rounded_kv_end = max(0, tl.cdiv(seqlen_k - seqlen_q + ws_right, BLOCK_N) * BLOCK_N + (m_block + 1) * BLOCK_M) - col_max = min(col_max, rounded_kv_end) + col_max += (m_block - num_m_blocks + 1) * BLOCK_M + if is_local: + col_max += ws_right + col_max = min(seqlen_k, col_max) + + if not IS_EVEN_MN: + # round right + col_max = tl.cdiv(col_max, BLOCK_N) * BLOCK_N if (not is_causal) and (not is_local): if IS_EVEN_MN: @@ -690,7 +705,6 @@ def flash_fwd_kernel( # local masking_cols: tl.constexpr = (tl.cdiv(BLOCK_M, BLOCK_N) + 1) * BLOCK_N - if has_alibi: alibi_offset = bid * slopes_batch_stride + hid alibi_slope = tl.load(pSlopes + alibi_offset) @@ -734,7 +748,7 @@ def flash_fwd_kernel( if is_causal | is_local | (not IS_EVEN_MN): # Cut short masking cols if there's not enough cols out there - masking_cols = min(col_max, masking_cols) + masking_cols = min(col_max - col_min, masking_cols) for col_shift in tl.range(0, masking_cols, step=BLOCK_N): col_start = col_max - col_shift - BLOCK_N col_start = tl.multiple_of(col_start, BLOCK_N) @@ -907,9 +921,8 @@ def flash_fwd_kernel( # LSE # Note, rowsum = exp(-rowmax) * lse, therefore rowmax + log(rowsum) cancels the effect of rowmax and outputs lse only. lse = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), float('inf'), rowmax_ * softmax_scale + tl.log(rowsum_)) - - # Rescale output inv_sum = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), 1.0, 1.0 / rowsum_) + if is_dropout: O_ *= inv_sum[:, None] * rpdrop else: @@ -945,7 +958,7 @@ def flash_fwd_kernel( 'num_warps': lambda args: 4, 'num_stages': lambda args: 3 if args["HEAD_DIM"] <= 128 else 2, 'PRE_LOAD_V': lambda args: True, - 'IS_EVEN_MN': lambda args: (args["seqlen_q"] % args["BLOCK_M"] == 0) and (args["seqlen_k"] % args["BLOCK_N"] == 0), + 'IS_EVEN_MN': lambda args: is_even_mn(args["seqlen_q"], args["seqlen_k"], args["BLOCK_M"], args["BLOCK_N"], args["ws_left"], args["ws_right"]), } ) @triton.jit(do_not_specialize=["seqlen_q", "seqlen_k", "philox_seed", "philox_offset"]) @@ -1009,7 +1022,7 @@ def flash_fwd_bh_parallel_kernel( 'num_warps': lambda args: 4, 'num_stages': lambda args: 3 if args["HEAD_DIM"] <= 128 else 2, 'PRE_LOAD_V': lambda args: True, - 'IS_EVEN_MN': lambda args: (args["seqlen_q"] % args["BLOCK_M"] == 0) and (args["seqlen_k"] % args["BLOCK_N"] == 0), + 'IS_EVEN_MN': lambda args: is_even_mn(args["seqlen_q"], args["seqlen_k"], args["BLOCK_M"], args["BLOCK_N"], args["ws_left"], args["ws_right"]), } ) @triton.jit(do_not_specialize=["seqlen_q", "seqlen_k", "philox_seed", "philox_offset"]) @@ -1575,7 +1588,6 @@ def dispatch(B, H, Q, K, D): args_copy.update(kernel2_args) kernel2(**args_copy) - if swap_seq_and_group: out = out.transpose(1, 2).reshape((batch_size, 1, num_heads_k * seqlen_q, head_size)) q = q.transpose(1, 2).reshape((batch_size, 1, num_heads_k * seqlen_q, head_size)) diff --git a/tests/test_attention_ops.py b/tests/test_attention_ops.py index ca5e07350..458fc8d31 100644 --- a/tests/test_attention_ops.py +++ b/tests/test_attention_ops.py @@ -7,21 +7,7 @@ from .accuracy_utils import gems_assert_close, to_reference -@pytest.mark.scaled_dot_product_attention -@pytest.mark.parametrize(["batch", "num_head", 
"q_seq_len", "kv_seq_len"], - [ - (1, 1, 128, 2048), - (4, 8, 1024, 1024), - (4, 8, 1024, 128), - (4, 8, 17, 1030) - ]) -@pytest.mark.parametrize("head_size", [128]) -@pytest.mark.parametrize("add_bias", [False]) -@pytest.mark.parametrize("is_causal", [False, True]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -def test_scaled_dot_product_attention( - batch, num_head, q_seq_len, kv_seq_len, head_size, add_bias, is_causal, dtype -): +def make_input(batch, num_head, q_seq_len, kv_seq_len, head_size, dtype): np.random.seed(0) np_query = np.random.uniform( -0.05, 0.05, (batch, num_head, q_seq_len, head_size) @@ -39,25 +25,126 @@ def test_scaled_dot_product_attention( query = torch.tensor(np_query, device="cuda", dtype=dtype) key = torch.tensor(np_key, device="cuda", dtype=dtype) value = torch.tensor(np_value, device="cuda", dtype=dtype) - if add_bias: - attn_bias = torch.tensor(np_attn_bias, device="cuda", dtype=dtype) - else: - attn_bias = None + + return query, key, value + + +@pytest.mark.scaled_dot_product_attention +@pytest.mark.parametrize(["batch", "num_head", "q_seq_len", "kv_seq_len"], + [ + (4, 8, 1024, 1024), + ]) +@pytest.mark.parametrize("head_size", [64, 128, 256]) +@pytest.mark.parametrize("is_causal", [False, True]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_sdpa_square_qk_even_mn( + batch, num_head, q_seq_len, kv_seq_len, head_size, is_causal, dtype +): + query, key, value = make_input(batch, num_head, q_seq_len, kv_seq_len, head_size, dtype) + ref_query = to_reference(query, False) + ref_key = to_reference(key, False) + ref_value = to_reference(value, False) + + scale = float(1.0 / np.sqrt(head_size)) + torch_result = torch.nn.functional.scaled_dot_product_attention( + ref_query, + ref_key, + ref_value, + attn_mask=None, + scale=scale, + is_causal=is_causal, + ) + + with flag_gems.use_gems(): + flaggem_result = torch.nn.functional.scaled_dot_product_attention( + query, key, value, attn_mask=None, scale=scale, is_causal=is_causal + ) + + gems_assert_close(flaggem_result, torch_result, dtype) + +@pytest.mark.scaled_dot_product_attention +@pytest.mark.parametrize(["batch", "num_head", "q_seq_len", "kv_seq_len"], + [ + (1, 1, 128, 2048), + (4, 8, 1024, 128), + (4, 8, 17, 1030) + ]) +@pytest.mark.parametrize("head_size", [64, 128, 256]) +@pytest.mark.parametrize("is_causal", [False]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_sdpa_nonsquare_qk( + batch, num_head, q_seq_len, kv_seq_len, head_size, is_causal, dtype +): + query, key, value = make_input(batch, num_head, q_seq_len, kv_seq_len, head_size, dtype) ref_query = to_reference(query, False) ref_key = to_reference(key, False) ref_value = to_reference(value, False) - ref_attn_bias = to_reference(attn_bias, False) if add_bias else None scale = float(1.0 / np.sqrt(head_size)) + torch_result = torch.nn.functional.scaled_dot_product_attention( + ref_query, + ref_key, + ref_value, + attn_mask=None, + scale=scale, + is_causal=is_causal, + ) + + with flag_gems.use_gems(): + flaggem_result = torch.nn.functional.scaled_dot_product_attention( + query, key, value, attn_mask=None, scale=scale, is_causal=is_causal + ) + + gems_assert_close(flaggem_result, torch_result, dtype) - if is_causal and q_seq_len != kv_seq_len: - # Pytorch treats non-square causal as a special case where the sdp attention - # does not route to flash attn and instead uses mem-eff attn. 
- # In this case, we directly compare on the lower level _flash_attention_forward - q = ref_query.transpose(1, 2) - k = ref_key.transpose(1, 2) - v = ref_value.transpose(1, 2) + +@pytest.mark.scaled_dot_product_attention +@pytest.mark.parametrize(["batch", "num_head", "q_seq_len", "kv_seq_len"], + [ + (1, 1, 128, 2048), + (4, 8, 1024, 128), + (4, 8, 17, 1030) + ]) +@pytest.mark.parametrize("head_size", [64, 128, 256]) +@pytest.mark.parametrize("is_causal", [True]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_sdpa_nonsquare_qk_causal( + batch, num_head, q_seq_len, kv_seq_len, head_size, is_causal, dtype +): + query, key, value = make_input(batch, num_head, q_seq_len, kv_seq_len, head_size, dtype) + ref_query = to_reference(query, False) + ref_key = to_reference(key, False) + ref_value = to_reference(value, False) + + scale = float(1.0 / np.sqrt(head_size)) + + q = ref_query.transpose(1, 2) + k = ref_key.transpose(1, 2) + v = ref_value.transpose(1, 2) + out, *_ = torch.ops.aten._flash_attention_forward( + q, + k, + v, + None, + None, + q.shape[-3], + k.shape[-3], + 0, + is_causal, + False, + scale=scale, + window_size_left=None, + window_size_right=None, + seqused_k=None, + alibi_slopes=None + ) + torch_result = out.transpose(1, 2) + + with flag_gems.use_gems(): + q = query.transpose(1, 2) + k = key.transpose(1, 2) + v = value.transpose(1, 2) out, *_ = torch.ops.aten._flash_attention_forward( q, k, @@ -75,47 +162,7 @@ def test_scaled_dot_product_attention( seqused_k=None, alibi_slopes=None ) - torch_result = out.transpose(1, 2) - - with flag_gems.use_gems(): - q = query.transpose(1, 2) - k = key.transpose(1, 2) - v = value.transpose(1, 2) - out, *_ = torch.ops.aten._flash_attention_forward( - q, - k, - v, - None, - None, - q.shape[-3], - k.shape[-3], - 0, - is_causal, - False, - scale=scale, - window_size_left=None, - window_size_right=None, - seqused_k=None, - alibi_slopes=None - ) - flaggem_result = out.transpose(1, 2) - else: - attn_mask = None if is_causal else ref_attn_bias - - torch_result = torch.nn.functional.scaled_dot_product_attention( - ref_query, - ref_key, - ref_value, - attn_mask=ref_attn_bias, - scale=scale, - is_causal=is_causal, - ) - - with flag_gems.use_gems(): - attn_mask = None if is_causal else attn_bias - flaggem_result = torch.nn.functional.scaled_dot_product_attention( - query, key, value, attn_mask=attn_bias, scale=scale, is_causal=is_causal - ) + flaggem_result = out.transpose(1, 2) gems_assert_close(flaggem_result, torch_result, dtype) @@ -131,32 +178,15 @@ def test_scaled_dot_product_attention( @pytest.mark.parametrize("head_size", [128]) @pytest.mark.parametrize(["window_size_left", "window_size_right"], [ - (None, None), (256, 0), (128, 128) ]) @pytest.mark.parametrize("is_causal", [False]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -def test_flash_attention_forward( +def test_flash_fwd_swa( batch, num_head, q_seq_len, kv_seq_len, head_size, is_causal, window_size_left, window_size_right, dtype ): - np.random.seed(0) - np_query = np.random.uniform( - -0.05, 0.05, (batch, num_head, q_seq_len, head_size) - ).astype(np.float32) - np_key = np.random.uniform( - -0.05, 0.05, (batch, num_head, kv_seq_len, head_size) - ).astype(np.float32) - np_value = np.random.uniform( - -0.05, 0.05, (batch, num_head, kv_seq_len, head_size) - ).astype(np.float32) - np_attn_bias = np.random.uniform( - -0.05, 0.05, (batch, num_head, q_seq_len, kv_seq_len) - ).astype(np.float32) - - query = torch.tensor(np_query, device="cuda", 
dtype=dtype) - key = torch.tensor(np_key, device="cuda", dtype=dtype) - value = torch.tensor(np_value, device="cuda", dtype=dtype) + query, key, value = make_input(batch, num_head, q_seq_len, kv_seq_len, head_size, dtype) ref_query = to_reference(query, False) ref_key = to_reference(key, False)