From 4a63a7e7498b73bd557f4bfd4da9934c40606aaa Mon Sep 17 00:00:00 2001
From: Tongxin Bai <waffle.bai@gmail.com>
Date: Sat, 1 Mar 2025 09:55:37 +0000
Subject: [PATCH 01/25] Make flash attn compatible with flash_attn v2 api. WIP.

---
 OperatorList.md                     |   2 +-
 src/flag_gems/__init__.py           | 394 ++++++++++++++--------------
 src/flag_gems/ops/__init__.py       |   4 +-
 src/flag_gems/ops/attention.py      | 280 ++++++++++++++++++++
 src/flag_gems/ops/dropout.py        |   7 +-
 src/flag_gems/utils/random_utils.py |   2 +-
 6 files changed, 483 insertions(+), 206 deletions(-)

diff --git a/OperatorList.md b/OperatorList.md
index 1f6702a5c..5471e7aaf 100644
--- a/OperatorList.md
+++ b/OperatorList.md
@@ -86,7 +86,7 @@
 - nll_loss
 - nll_loss_forward
 - nll_loss_nd
-- scaled_dot_product_attention
+- _flash_attention_forward
 - upsample_nearest2d
 - _fft_c2r
 - _fft_r2c
diff --git a/src/flag_gems/__init__.py b/src/flag_gems/__init__.py
index c50d66bea..8cf202e3e 100644
--- a/src/flag_gems/__init__.py
+++ b/src/flag_gems/__init__.py
@@ -19,206 +19,206 @@ def enable(lib=aten_lib, unused=None, registrar=registrar):
     global current_work_registrar
     current_work_registrar = registrar(
         (
-            ("abs", abs, Autograd.disable),
-            ("add.Tensor", add, Autograd.disable),
-            ("addmm", addmm, Autograd.disable),
-            ("arange.start_step", arange_start, Autograd.disable),
-            ("arange.start", arange_start, Autograd.disable),
-            ("arange", arange, Autograd.disable),
-            ("batch_norm", batch_norm, Autograd.enable),
-            ("bitwise_and.Tensor", bitwise_and_tensor, Autograd.disable),
-            ("bitwise_and.Scalar", bitwise_and_scalar, Autograd.disable),
-            ("bitwise_and.Scalar_Tensor", bitwise_and_scalar_tensor, Autograd.disable),
-            ("bitwise_not", bitwise_not, Autograd.disable),
-            ("bitwise_or.Tensor", bitwise_or_tensor, Autograd.disable),
-            ("bitwise_or.Scalar", bitwise_or_scalar, Autograd.disable),
-            ("bitwise_or.Scalar_Tensor", bitwise_or_scalar_tensor, Autograd.disable),
-            ("bmm", bmm, Autograd.disable),
-            ("clamp", clamp, Autograd.disable),
-            ("clamp.Tensor", clamp_tensor, Autograd.disable),
-            ("cos", cos, Autograd.disable),
-            ("pad", pad, Autograd.disable),
-            ("constant_pad_nd", constant_pad_nd, Autograd.disable),
-            ("cumsum", cumsum, Autograd.disable),
-            ("cummin", cummin, Autograd.disable),
-            ("div.Tensor", true_divide, Autograd.disable),
-            ("div.Scalar", true_divide, Autograd.disable),
-            ("div.Tensor_mode", div_mode, Autograd.disable),
-            ("div.Scalar_mode", div_mode, Autograd.disable),
+            # ("abs", abs, Autograd.disable),
+            # ("add.Tensor", add, Autograd.disable),
+            # ("addmm", addmm, Autograd.disable),
+            # ("arange.start_step", arange_start, Autograd.disable),
+            # ("arange.start", arange_start, Autograd.disable),
+            # ("arange", arange, Autograd.disable),
+            # ("batch_norm", batch_norm, Autograd.enable),
+            # ("bitwise_and.Tensor", bitwise_and_tensor, Autograd.disable),
+            # ("bitwise_and.Scalar", bitwise_and_scalar, Autograd.disable),
+            # ("bitwise_and.Scalar_Tensor", bitwise_and_scalar_tensor, Autograd.disable),
+            # ("bitwise_not", bitwise_not, Autograd.disable),
+            # ("bitwise_or.Tensor", bitwise_or_tensor, Autograd.disable),
+            # ("bitwise_or.Scalar", bitwise_or_scalar, Autograd.disable),
+            # ("bitwise_or.Scalar_Tensor", bitwise_or_scalar_tensor, Autograd.disable),
+            # # ("bmm", bmm, Autograd.disable),
+            # ("clamp", clamp, Autograd.disable),
+            # ("clamp.Tensor", clamp_tensor, Autograd.disable),
+            # ("cos", cos, Autograd.disable),
+            # ("pad", pad, Autograd.disable),
+            # ("constant_pad_nd", constant_pad_nd, Autograd.disable),
+            # ("cumsum", cumsum, Autograd.disable),
+            # ("cummin", cummin, Autograd.disable),
+            # ("div.Tensor", true_divide, Autograd.disable),
+            # ("div.Scalar", true_divide, Autograd.disable),
+            # ("div.Tensor_mode", div_mode, Autograd.disable),
+            # ("div.Scalar_mode", div_mode, Autograd.disable),
+            # (
+            #     "divide.Tensor",
+            #     true_divide,
+            #     Autograd.disable,
+            # ),  # divide, an alias for div
+            # ("divide.Scalar", true_divide, Autograd.disable),
+            # ("divide.Tensor_mode", div_mode, Autograd.disable),
+            # ("divide.Scalar_mode", div_mode, Autograd.disable),
+            # (
+            #     "true_divide.Tensor",
+            #     true_divide,
+            #     Autograd.disable,
+            # ),  # true_divide, an alias for div
+            # ("true_divide.Scalar", true_divide, Autograd.disable),
+            # ("floor_divide", floor_divide, Autograd.disable),
+            # ("floor_divide.Scalar", floor_divide, Autograd.disable),
+            # ("remainder.Tensor", remainder, Autograd.disable),
+            # ("native_dropout", native_dropout, Autograd.enable),
+            # ("erf", erf, Autograd.disable),
+            # ("embedding", embedding, Autograd.enable),
+            # ("eq.Tensor", eq, Autograd.disable),
+            # ("eq.Scalar", eq_scalar, Autograd.disable),
+            # ("exp", exp, Autograd.disable),
+            # ("exponential_", exponential_, Autograd.disable),
+            # ("ge.Tensor", ge, Autograd.disable),
+            # ("ge.Scalar", ge_scalar, Autograd.disable),
+            # ("gelu", gelu, Autograd.enable),
+            # ("native_group_norm", group_norm, Autograd.enable),
+            # ("_weight_norm_interface", weight_norm_interface, Autograd.enable),
+            # ("_weight_norm", weight_norm, Autograd.enable),
+            # ("gt.Tensor", gt, Autograd.disable),
+            # ("gt.Scalar", gt_scalar, Autograd.disable),
+            # ("instance_norm", instance_norm, Autograd.enable),
+            # ("isfinite", isfinite, Autograd.disable),
+            # ("isin.Tensor_Tensor", isin, Autograd.disable),
+            # ("isin.Scalar_Tensor", isin, Autograd.disable),
+            # ("isin.Tensor_Scalar", isin, Autograd.disable),
+            # ("isinf", isinf, Autograd.disable),
+            # ("isnan", isnan, Autograd.disable),
+            # ("minimum", minimum, Autograd.disable),
+            # ("maximum", maximum, Autograd.disable),
+            # ("native_layer_norm", layer_norm, Autograd.enable),
+            # ("le.Tensor", le, Autograd.disable),
+            # ("le.Scalar", le_scalar, Autograd.disable),
+            # ("lt.Tensor", lt, Autograd.disable),
+            # ("lt.Scalar", lt_scalar, Autograd.disable),
+            # ("rms_norm", rms_norm, Autograd.disable),
+            # ("rand", rand, Autograd.disable),
+            # ("randn", randn, Autograd.disable),
+            # ("rand_like", rand_like, Autograd.disable),
+            # ("randn_like", randn_like, Autograd.disable),
+            # ("zeros", zeros, Autograd.disable),
+            # ("ones", ones, Autograd.disable),
+            # ("full", full, Autograd.disable),
+            # ("zeros_like", zeros_like, Autograd.disable),
+            # ("ones_like", ones_like, Autograd.disable),
+            # ("full_like", full_like, Autograd.disable),
+            # ("resolve_neg", resolve_neg, Autograd.disable),
+            # ("resolve_conj", resolve_conj, Autograd.disable),
+            # ("normal.Tensor_float", normal_tensor_float, Autograd.disable),
+            # ("normal.float_Tensor", normal_float_tensor, Autograd.disable),
+            # ("normal.Tensor_Tensor", normal_tensor_tensor, Autograd.disable),
+            # ("uniform_", uniform_, Autograd.disable),
+            # ("mean", mean, Autograd.disable),
+            # ("mean.dim", mean_dim, Autograd.disable),
+            # ("mm", mm, Autograd.disable),
+            # ("mul.Tensor", mul, Autograd.disable),
+            # ("multinomial", multinomial, Autograd.disable),
+            # ("mv", mv, Autograd.disable),
+            # ("ne.Tensor", ne, Autograd.disable),
+            # ("ne.Scalar", ne_scalar, Autograd.disable),
+            # ("neg", neg, Autograd.disable),
+            # ("pow.Scalar", pow_scalar, Autograd.disable),
+            # ("pow.Tensor_Scalar", pow_tensor_scalar, Autograd.disable),
+            # ("pow.Tensor_Tensor", pow_tensor_tensor, Autograd.disable),
+            # ("reciprocal", reciprocal, Autograd.disable),
+            # ("relu", relu, Autograd.enable),
+            # ("rsqrt", rsqrt, Autograd.disable),
+            # ("sigmoid", sigmoid, Autograd.enable),
+            # ("silu", silu, Autograd.enable),
+            # ("sin", sin, Autograd.disable),
+            # ("softmax.int", softmax, Autograd.enable),
+            # ("sort", sort, Autograd.disable),
+            # ("sub.Tensor", sub, Autograd.disable),
+            # ("tanh", tanh, Autograd.enable),
+            # ("triu", triu, Autograd.disable),
+            # ("topk", topk, Autograd.disable),
+            # ("var_mean.correction", var_mean, Autograd.disable),
+            # ("linalg_vector_norm", vector_norm, Autograd.disable),
+            # ("where.self_out", where_self_out, Autograd.disable),
+            # ("where.self", where_self, Autograd.disable),
+            # ("where.ScalarSelf", where_scalar_self, Autograd.disable),
+            # ("where.ScalarOther", where_scalar_other, Autograd.disable),
+            # ("max", max, Autograd.disable),
+            # ("max.dim", max_dim, Autograd.disable),
+            # ("min", min, Autograd.disable),
+            # ("min.dim", min_dim, Autograd.disable),
+            # ("amax", amax, Autograd.disable),
+            # ("argmax", argmax, Autograd.disable),
+            # ("argmin", argmin, Autograd.disable),
+            # ("prod", prod, Autograd.disable),
+            # ("prod.dim_int", prod_dim, Autograd.disable),
+            # ("sum", sum, Autograd.disable),
+            # ("sum.dim_IntList", sum_dim, Autograd.disable),
             (
-                "divide.Tensor",
-                true_divide,
-                Autograd.disable,
-            ),  # divide, an alias for div
-            ("divide.Scalar", true_divide, Autograd.disable),
-            ("divide.Tensor_mode", div_mode, Autograd.disable),
-            ("divide.Scalar_mode", div_mode, Autograd.disable),
-            (
-                "true_divide.Tensor",
-                true_divide,
-                Autograd.disable,
-            ),  # true_divide, an alias for div
-            ("true_divide.Scalar", true_divide, Autograd.disable),
-            ("floor_divide", floor_divide, Autograd.disable),
-            ("floor_divide.Scalar", floor_divide, Autograd.disable),
-            ("remainder.Tensor", remainder, Autograd.disable),
-            ("native_dropout", native_dropout, Autograd.enable),
-            ("erf", erf, Autograd.disable),
-            ("embedding", embedding, Autograd.enable),
-            ("eq.Tensor", eq, Autograd.disable),
-            ("eq.Scalar", eq_scalar, Autograd.disable),
-            ("exp", exp, Autograd.disable),
-            ("exponential_", exponential_, Autograd.disable),
-            ("ge.Tensor", ge, Autograd.disable),
-            ("ge.Scalar", ge_scalar, Autograd.disable),
-            ("gelu", gelu, Autograd.enable),
-            ("native_group_norm", group_norm, Autograd.enable),
-            ("_weight_norm_interface", weight_norm_interface, Autograd.enable),
-            ("_weight_norm", weight_norm, Autograd.enable),
-            ("gt.Tensor", gt, Autograd.disable),
-            ("gt.Scalar", gt_scalar, Autograd.disable),
-            ("instance_norm", instance_norm, Autograd.enable),
-            ("isfinite", isfinite, Autograd.disable),
-            ("isin.Tensor_Tensor", isin, Autograd.disable),
-            ("isin.Scalar_Tensor", isin, Autograd.disable),
-            ("isin.Tensor_Scalar", isin, Autograd.disable),
-            ("isinf", isinf, Autograd.disable),
-            ("isnan", isnan, Autograd.disable),
-            ("minimum", minimum, Autograd.disable),
-            ("maximum", maximum, Autograd.disable),
-            ("native_layer_norm", layer_norm, Autograd.enable),
-            ("le.Tensor", le, Autograd.disable),
-            ("le.Scalar", le_scalar, Autograd.disable),
-            ("lt.Tensor", lt, Autograd.disable),
-            ("lt.Scalar", lt_scalar, Autograd.disable),
-            ("rms_norm", rms_norm, Autograd.disable),
-            ("rand", rand, Autograd.disable),
-            ("randn", randn, Autograd.disable),
-            ("rand_like", rand_like, Autograd.disable),
-            ("randn_like", randn_like, Autograd.disable),
-            ("zeros", zeros, Autograd.disable),
-            ("ones", ones, Autograd.disable),
-            ("full", full, Autograd.disable),
-            ("zeros_like", zeros_like, Autograd.disable),
-            ("ones_like", ones_like, Autograd.disable),
-            ("full_like", full_like, Autograd.disable),
-            ("resolve_neg", resolve_neg, Autograd.disable),
-            ("resolve_conj", resolve_conj, Autograd.disable),
-            ("normal.Tensor_float", normal_tensor_float, Autograd.disable),
-            ("normal.float_Tensor", normal_float_tensor, Autograd.disable),
-            ("normal.Tensor_Tensor", normal_tensor_tensor, Autograd.disable),
-            ("uniform_", uniform_, Autograd.disable),
-            ("mean", mean, Autograd.disable),
-            ("mean.dim", mean_dim, Autograd.disable),
-            ("mm", mm, Autograd.disable),
-            ("mul.Tensor", mul, Autograd.disable),
-            ("multinomial", multinomial, Autograd.disable),
-            ("mv", mv, Autograd.disable),
-            ("ne.Tensor", ne, Autograd.disable),
-            ("ne.Scalar", ne_scalar, Autograd.disable),
-            ("neg", neg, Autograd.disable),
-            ("pow.Scalar", pow_scalar, Autograd.disable),
-            ("pow.Tensor_Scalar", pow_tensor_scalar, Autograd.disable),
-            ("pow.Tensor_Tensor", pow_tensor_tensor, Autograd.disable),
-            ("reciprocal", reciprocal, Autograd.disable),
-            ("relu", relu, Autograd.enable),
-            ("rsqrt", rsqrt, Autograd.disable),
-            ("sigmoid", sigmoid, Autograd.enable),
-            ("silu", silu, Autograd.enable),
-            ("sin", sin, Autograd.disable),
-            ("softmax.int", softmax, Autograd.enable),
-            ("sort", sort, Autograd.disable),
-            ("sub.Tensor", sub, Autograd.disable),
-            ("tanh", tanh, Autograd.enable),
-            ("triu", triu, Autograd.disable),
-            ("topk", topk, Autograd.disable),
-            ("var_mean.correction", var_mean, Autograd.disable),
-            ("linalg_vector_norm", vector_norm, Autograd.disable),
-            ("where.self_out", where_self_out, Autograd.disable),
-            ("where.self", where_self, Autograd.disable),
-            ("where.ScalarSelf", where_scalar_self, Autograd.disable),
-            ("where.ScalarOther", where_scalar_other, Autograd.disable),
-            ("max", max, Autograd.disable),
-            ("max.dim", max_dim, Autograd.disable),
-            ("min", min, Autograd.disable),
-            ("min.dim", min_dim, Autograd.disable),
-            ("amax", amax, Autograd.disable),
-            ("argmax", argmax, Autograd.disable),
-            ("argmin", argmin, Autograd.disable),
-            ("prod", prod, Autograd.disable),
-            ("prod.dim_int", prod_dim, Autograd.disable),
-            ("sum", sum, Autograd.disable),
-            ("sum.dim_IntList", sum_dim, Autograd.disable),
-            (
-                "scaled_dot_product_attention",
-                scaled_dot_product_attention,
-                Autograd.disable,
-            ),
-            ("all", all, Autograd.disable),
-            ("all.dim", all_dim, Autograd.disable),
-            ("all.dims", all_dims, Autograd.disable),
-            ("any", any, Autograd.disable),
-            ("any.dim", any_dim, Autograd.disable),
-            ("any.dims", any_dims, Autograd.disable),
-            ("quantile", quantile, Autograd.disable),
-            ("log_softmax.int", log_softmax, Autograd.enable),
-            ("outer", outer, Autograd.enable),
-            ("cross_entropy_loss", cross_entropy_loss, Autograd.enable),
-            ("nll_loss_forward", nll_loss_forward, Autograd.disable),
-            ("nll_loss_backward", nll_loss_backward, Autograd.disable),
-            ("nll_loss2d_forward", nll_loss2d_forward, Autograd.disable),
-            ("nll_loss2d_backward", nll_loss2d_backward, Autograd.disable),
-            ("scatter.src", scatter, Autograd.disable),
-            ("scatter.reduce", scatter, Autograd.disable),
-            ("gather", gather, Autograd.disable),
-            ("gather_backward", gather_backward, Autograd.disable),
-            ("isclose", isclose, Autograd.disable),
-            ("allclose", allclose, Autograd.disable),
-            ("fill.Scalar", fill_scalar, Autograd.disable),
-            ("fill.Tensor", fill_tensor, Autograd.disable),
-            ("flip", flip, Autograd.disable),
-            ("slice_scatter", slice_scatter, Autograd.disable),
-            ("select_scatter", select_scatter, Autograd.disable),
-            ("index_select", index_select, Autograd.disable),
-            ("tile", tile, Autograd.disable),
-            ("masked_fill.Tensor", masked_fill, Autograd.disable),
-            ("masked_fill.Scalar", masked_fill, Autograd.disable),
-            ("masked_fill_.Tensor", masked_fill_, Autograd.disable),
-            ("masked_fill_.Scalar", masked_fill_, Autograd.disable),
-            ("_unique2", _unique2, Autograd.disable),
-            ("_upsample_bicubic2d_aa", _upsample_bicubic2d_aa, Autograd.disable),
-            ("upsample_nearest2d", upsample_nearest2d, Autograd.disable),
-            ("nonzero", nonzero, Autograd.disable),
-            ("repeat", repeat, Autograd.disable),
-            ("masked_select", masked_select, Autograd.disable),
-            ("stack", stack, Autograd.disable),
-            ("hstack", hstack, Autograd.disable),
-            ("cat", cat, Autograd.disable),
-            (
-                "repeat_interleave.self_int",
-                repeat_interleave_self_int,
-                Autograd.disable,
-            ),
-            ("vstack", vstack, Autograd.disable),
-            ("repeat_interleave.Tensor", repeat_interleave_tensor, Autograd.disable),
-            (
-                "repeat_interleave.self_Tensor",
-                repeat_interleave_self_tensor,
+                "_flash_attention_forward",
+                flash_attention_forward,
                 Autograd.disable,
             ),
-            ("randperm", randperm, Autograd.disable),
-            ("diag", diag, Autograd.disable),
-            ("diag_embed", diag_embed, Autograd.disable),
-            ("diagonal_backward", diagonal_backward, Autograd.disable),
-            ("index_add", index_add, Autograd.disable),
-            ("count_nonzero", count_nonzero, Autograd.disable),
-            ("logical_or", logical_or, Autograd.disable),
-            ("logical_and", logical_and, Autograd.disable),
-            ("logical_xor", logical_xor, Autograd.disable),
-            ("logical_not", logical_not, Autograd.disable),
-            ("log_sigmoid", log_sigmoid, Autograd.disable),
-            ("vdot", vdot, Autograd.disable),
-            ("mse_loss", mse_loss, Autograd.disable),
+            # ("all", all, Autograd.disable),
+            # ("all.dim", all_dim, Autograd.disable),
+            # ("all.dims", all_dims, Autograd.disable),
+            # ("any", any, Autograd.disable),
+            # ("any.dim", any_dim, Autograd.disable),
+            # ("any.dims", any_dims, Autograd.disable),
+            # ("quantile", quantile, Autograd.disable),
+            # ("log_softmax.int", log_softmax, Autograd.enable),
+            # ("outer", outer, Autograd.enable),
+            # ("cross_entropy_loss", cross_entropy_loss, Autograd.enable),
+            # ("nll_loss_forward", nll_loss_forward, Autograd.disable),
+            # ("nll_loss_backward", nll_loss_backward, Autograd.disable),
+            # ("nll_loss2d_forward", nll_loss2d_forward, Autograd.disable),
+            # ("nll_loss2d_backward", nll_loss2d_backward, Autograd.disable),
+            # ("scatter.src", scatter, Autograd.disable),
+            # ("scatter.reduce", scatter, Autograd.disable),
+            # ("gather", gather, Autograd.disable),
+            # ("gather_backward", gather_backward, Autograd.disable),
+            # ("isclose", isclose, Autograd.disable),
+            # ("allclose", allclose, Autograd.disable),
+            # ("fill.Scalar", fill_scalar, Autograd.disable),
+            # ("fill.Tensor", fill_tensor, Autograd.disable),
+            # ("flip", flip, Autograd.disable),
+            # ("slice_scatter", slice_scatter, Autograd.disable),
+            # ("select_scatter", select_scatter, Autograd.disable),
+            # ("index_select", index_select, Autograd.disable),
+            # ("tile", tile, Autograd.disable),
+            # ("masked_fill.Tensor", masked_fill, Autograd.disable),
+            # ("masked_fill.Scalar", masked_fill, Autograd.disable),
+            # ("masked_fill_.Tensor", masked_fill_, Autograd.disable),
+            # ("masked_fill_.Scalar", masked_fill_, Autograd.disable),
+            # ("_unique2", _unique2, Autograd.disable),
+            # ("_upsample_bicubic2d_aa", _upsample_bicubic2d_aa, Autograd.disable),
+            # ("upsample_nearest2d", upsample_nearest2d, Autograd.disable),
+            # ("nonzero", nonzero, Autograd.disable),
+            # ("repeat", repeat, Autograd.disable),
+            # ("masked_select", masked_select, Autograd.disable),
+            # ("stack", stack, Autograd.disable),
+            # ("hstack", hstack, Autograd.disable),
+            # ("cat", cat, Autograd.disable),
+            # (
+            #     "repeat_interleave.self_int",
+            #     repeat_interleave_self_int,
+            #     Autograd.disable,
+            # ),
+            # ("vstack", vstack, Autograd.disable),
+            # ("repeat_interleave.Tensor", repeat_interleave_tensor, Autograd.disable),
+            # (
+            #     "repeat_interleave.self_Tensor",
+            #     repeat_interleave_self_tensor,
+            #     Autograd.disable,
+            # ),
+            # ("randperm", randperm, Autograd.disable),
+            # ("diag", diag, Autograd.disable),
+            # ("diag_embed", diag_embed, Autograd.disable),
+            # ("diagonal_backward", diagonal_backward, Autograd.disable),
+            # ("index_add", index_add, Autograd.disable),
+            # ("count_nonzero", count_nonzero, Autograd.disable),
+            # ("logical_or", logical_or, Autograd.disable),
+            # ("logical_and", logical_and, Autograd.disable),
+            # ("logical_xor", logical_xor, Autograd.disable),
+            # ("logical_not", logical_not, Autograd.disable),
+            # ("log_sigmoid", log_sigmoid, Autograd.disable),
+            # ("vdot", vdot, Autograd.disable),
+            # ("mse_loss", mse_loss, Autograd.disable),
         ),
         user_unused_ops_list=[] if unused is None else unused,
         lib=lib,
diff --git a/src/flag_gems/ops/__init__.py b/src/flag_gems/ops/__init__.py
index f0a1cc9fe..c4051d6d8 100755
--- a/src/flag_gems/ops/__init__.py
+++ b/src/flag_gems/ops/__init__.py
@@ -7,7 +7,7 @@
 from .arange import arange, arange_start
 from .argmax import argmax
 from .argmin import argmin
-from .attention import scaled_dot_product_attention
+from .attention import flash_attention_forward
 from .batch_norm import batch_norm
 from .bitwise_and import (
     bitwise_and_scalar,
@@ -292,7 +292,7 @@
     "repeat_interleave_self_int",
     "vstack",
     "repeat_interleave_tensor",
-    "scaled_dot_product_attention",
+    "flash_attention_forward",
     "conv2d",
     "conv1d",
     "_conv_depthwise2d",
diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py
index 7e16189e8..3c114fcf1 100644
--- a/src/flag_gems/ops/attention.py
+++ b/src/flag_gems/ops/attention.py
@@ -5,6 +5,7 @@
 import triton.language as tl
 
 from flag_gems.runtime import torch_device_fn
+from flag_gems.utils.random_utils import update_philox_state
 
 from .. import runtime
 
@@ -305,6 +306,285 @@ def _attn_fwd(
     tl.store(O_block_ptr, acc.to(Out.type.element_ty), mask=q_load_mask[:, None])
 
 
+@triton.jit
+def philox_offset_one_warp(b, h, nh: tl.constexpr):
+    # To align with TriDao's implementation, philox_offset linearly determined by
+    # a 3d dense tensor (batch_id, head_id, thread_id) with shape (batch_size, num_heads, 32)
+    # and stride ( num_heads * 32, 32, 1 )
+    return (b * nh + h) * 32 + tl.arange(0, 32)
+
+
+@triton.jit
+def to_lohi(x):
+    return (x >> 32).to(tl.uint32), (x & 0xFFFFFFFF).to(tl.uint32)
+
+
+@triton.jit
+def from_lohi(lo, hi):
+    return hi.to(tl.uint64) << 32 + lo.to(tl.uint64)
+
+
+@triton.jit
+def philox_(seed, subsequence, offset):
+    kPhilox10A: tl.constexpr = 0x9E3779B9
+    kPhilox10B: tl.constexpr = 0xBB67AE85
+    k0, k1 = to_lohi(seed.to(tl.uint64))
+    c0, c1 = to_lohi(offset.to(tl.uint64))
+    c2, c3 = to_lohi(subsequence(tl.uint64))
+
+    # pragma unroll
+    kPhiloxSA: tl.constexpr = 0xD2511F53
+    kPhiloxSB: tl.constexpr = 0xCD9E8D57
+    for _ in range(6):
+        res0 = kPhiloxSA.to(tl.uint64) * c0.to(tl.uint64)
+        res1 = kPhiloxSB.to(tl.uint64) * c2.to(tl.uint64)
+        res0_x, res0_y = to_lohi(res0)
+        res1_x, res1_y = to_lohi(res1)
+        c0, c1, c2, c3 = res1_y ^ c1 ^ k0, res1_x, res0_y ^ c3 ^ k1, res0_x
+        k0 += kPhilox10A
+        k1 += kPhilox10B
+
+    res0 = kPhiloxSA.to(tl.uint64) * c0.to(tl.uint64)
+    res1 = kPhiloxSB.to(tl.uint64) * c2.to(tl.uint64)
+    res0_x.res0_y = to_lohi(res0)
+    res1_x.res1_y = to_lohi(res1)
+    c0, c1, c2, c3 = res1_y ^ c1 ^ k0, res1_x, res0_y ^ c3 ^ k1, res0_x
+
+    return c0, c1, c2, c3
+
+
+@triton.jit
+def apply_mask(
+    P,
+    mask,
+    encode_dropout_in_sign_bit: tl.constexpr,
+):
+    if encode_dropout_in_sign_bit:
+        P = tl.where(mask, -P, P)
+    else:
+        P = tl.where(mask, 0, P)
+    return P
+
+
+@triton.jit
+def make_4x_dropout_mask(uint32_r, uint8_p, M: tl.constexpr, N: tl.constexpr):
+    r = uint32_r
+    p = uint8_p
+    m0 = tl.where(r & 0xFF < p, 0, 1)
+    r >>= 8
+    m1 = tl.where(r & 0xFF < p, 0, 1)
+    m0 = tl.join(m0, m1).trans(2, 0, 1).reshape(2 * M, N)
+
+    r >>= 8
+    m0 = tl.where(r & 0xFF < p, 0, 1)
+    r >>= 8
+    m1 = tl.where(r & 0xFF < p, 0, 1)
+    m1 = tl.join(m0, m1).trans(2, 0, 1).reshape(2 * M, N)
+
+    m = tl.join(m0, m1).trans(2, 0, 1).reshape(4 * M, N)
+    return m
+
+
+@triton.jit(
+    do_not_specialize=[
+        "b",
+        "h",
+        "row_start",
+        "col_start",
+        "philox_seed",
+        "philox_offset",
+    ]
+)
+def apply_dropout(
+    P,
+    row_start,
+    col_start,
+    philox_seed,
+    philox_offset,
+    p_dropout_uint8: tl.constexpr,
+    encode_dropout_in_sign_bit: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    # P is of size (BLOCK_M, BLOCK_N) and its scalar bitsize is 32
+    # BLOCK_M is ensured to be a multiple of 16, BLOCK_N a multiple of 32
+    M: tl.constexpr = BLOCK_M // 16
+    N: tl.constexpr = BLOCK_N // 32
+    row = row_start + tl.arange(0, M)[:, None]
+    col = col_start + tl.arange(0, BLOCK_N)[None, :] // 32
+
+    philox_offset = philox_offset + tl.arange(0, BLOCK_N)[None, :] % 32
+
+    subsequence = from_lohi(row * 32, col)
+    r0, r1, r2, r3 = philox_(philox_seed, subsequence, philox_offset)
+
+    # Fully unrolled due to triton's inability to concat 2d tensor
+    m0 = make_4x_dropout_mask(r0, p_dropout_uint8, M, N)
+    m1 = make_4x_dropout_mask(r1, p_dropout_uint8, M, N)
+    m0 = tl.join(m0, m1).trans(2, 0, 1).reshape(8 * M, N)
+
+    m0 = make_4x_dropout_mask(r0, p_dropout_uint8, M, N)
+    m1 = make_4x_dropout_mask(r1, p_dropout_uint8, M, N)
+    m1 = tl.join(m0, m1).trans(2, 0, 1).reshape(8 * M, N)
+
+    m = tl.join(m0, m1).trans(2, 0, 1).reshape(16 * M, N)
+    P = apply_mask(P, m)
+    return P
+
+
+@triton.jit
+def flash_fwd_kernel(
+    pQ,
+    pK,
+    pV,
+    pP,
+    pO,
+    pSlopes,
+    philox_seed,
+    philox_offset,
+    pdrop_int8,
+    drop: tl.constexpr,
+    causal: tl.constexpr,
+    scale: tl.constexp,
+    wl: tl.constexpr,
+    wr: tl.constexpr,
+    return_P: tl.constexpr,
+):
+    pass
+
+
+def flash_attention_forward(
+    query,
+    key,
+    value,
+    cum_seq_q,
+    cum_seq_k,
+    max_q,
+    max_k,
+    dropout_p,
+    is_causal,
+    return_debug_mask,
+    *,
+    scale=None,
+    window_size_left=None,
+    window_size_right=None,
+    seqused_k=None,
+    alibi_slopes=None
+):
+    logging.debug("GEMS FLASH_ATTENTION")
+    assert cum_seq_q is None and cum_seq_k is None, "varlen is not supported yet."
+
+    HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1]
+    HEAD_DIM_V = value.shape[-1]
+    assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
+    assert HEAD_DIM_K in {16, 32, 64, 128, 256}
+
+    softmax_scale = scale or 1.0 / (HEAD_DIM_K**0.5)
+    non_null_window_left = window_size_left or -1
+    non_null_window_right = window_size_right or -1
+    out = torch.empty_like(query, dtype=value.dtype)
+
+    mha_out = mha_fwd(
+        query,
+        key,
+        value,
+        out,
+        alibi_slopes,
+        dropout_p,
+        softmax_scale,
+        is_causal,
+        non_null_window_left,
+        non_null_window_right,
+        return_debug_mask,
+    )
+    (
+        output,
+        q_padded,
+        k_padded,
+        v_padded,
+        logsumexp,
+        philox_seed,
+        philox_offset,
+        debug_attn_mask,
+    ) = mha_out
+
+
+def mha_fwd(
+    q,
+    k,
+    v,
+    out,
+    alibi_slopes,
+    p_dropout,
+    softmax_scale,
+    is_causal,
+    window_size_left,
+    window_size_right,
+    return_softmax,
+):
+    q_dtype = q.dtype
+    q_device = q.device
+    assert q_dtype in (
+        torch.float16,
+        torch.bfloat16,
+    ), "FlashAttention only support fp16 and bf16 data type"
+    assert q_dtype == k.dtype
+    assert q_dtype == v.dtype
+    assert q.stride(-1) == 1, "Input tensor must have contiguous last dimension"
+    assert k.stride(-1) == 1, "Input tensor must have contiguous last dimension"
+    assert v.stride(-1) == 1, "Input tensor must have contiguous last dimension"
+    batch_size, seqlen_q, num_heads, head_size = q.size()
+    _, seqlen_k, num_heads_k, _ = k.size()
+    assert (
+        head_size % 8 == 0
+    ), "head_size must be a multiple of 8, this is ensured by padding!"
+    assert (
+        num_heads % num_heads_k == 0
+    ), "Number of heads in key/value must divide number of heads in query"
+    if window_size_left >= seqlen_k:
+        window_size_left = -1
+    if window_size_right >= seqlen_k:
+        window_size_right = -1
+    if seqlen_q == 1 and alibi_slopes is None:
+        is_causal = False
+    if is_causal:
+        window_size_right = 0
+
+    if out:
+        assert out.stride(-1) == 1
+        assert out.dtype == q.dtype
+        assert out.size() == (batch_size, seqlen_q, num_heads, head_size)
+    else:
+        out = torch.empty_like(q)
+
+    round_multiple = lambda x, m: (x + m - 1) // m * m
+    head_size_rounded = round_multiple(head_size, 32)
+    seqlen_q_rounded = round_multiple(seqlen_q, 128)
+    seqlen_k_rounded = round_multiple(seqlen_k, 128)
+
+    lse = torch.empty(
+        (batch_size, num_heads, seqlen_q), dtype=torch.float, device=q_device
+    )
+    if return_softmax:
+        assert p_dropout > 0, "return_softmax is only supported when p_dropout > 0.0"
+        p = torch.empty(
+            (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded),
+            dtype=q_dtype,
+            device=q_device,
+        )
+
+    if p_dropout > 0:
+        increment = triton.cdiv(batch_size * num_heads * 32)
+        philox_seed, philox_offset = update_philox_state(increment)
+
+    with torch_device_fn.device(q.device):
+        grid = lambda args: (
+            triton.cdiv(seqlen_q, args["BLOCK_M"]),
+            batch_size * num_heads,
+            1,
+        )
+
+
 def scaled_dot_product_attention(
     query,
     key,
diff --git a/src/flag_gems/ops/dropout.py b/src/flag_gems/ops/dropout.py
index 2bcb31a81..7625c76ba 100644
--- a/src/flag_gems/ops/dropout.py
+++ b/src/flag_gems/ops/dropout.py
@@ -4,10 +4,7 @@
 import triton
 import triton.language as tl
 
-from flag_gems.utils.random_utils import (
-    philox_backend_seed_offset,
-    uint_to_uniform_float,
-)
+from flag_gems.utils.random_utils import uint_to_uniform_float, update_philox_state
 
 from .. import runtime
 from ..runtime import torch_device_fn
@@ -133,7 +130,7 @@ def forward(ctx, x, p, train):
         # hence we cannot obtain the per thread offset as in Pytorch.
         increment = triton.cdiv(N, UNROLL)
         with torch_device_fn.device(device):
-            philox_seed, philox_offset = philox_backend_seed_offset(increment)
+            philox_seed, philox_offset = update_philox_state(increment)
             dropout_forward_kernel[grid_fn](x, out, N, p, philox_seed, philox_offset)
         ctx.p = p
         ctx.philox_seed = philox_seed
diff --git a/src/flag_gems/utils/random_utils.py b/src/flag_gems/utils/random_utils.py
index 22f7155ce..5ca61a8f4 100644
--- a/src/flag_gems/utils/random_utils.py
+++ b/src/flag_gems/utils/random_utils.py
@@ -34,7 +34,7 @@ def uint_to_uniform_float(x):
 # https://github.com/pytorch/pytorch/blob/8a4597980c2692b73f35fb3c7145eaeaf2273e77/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp#L452
 # It returns the current state of the default Philox RNG in seed and offset and
 # updates the next offset by adding `increment`.
-def philox_backend_seed_offset(increment, device=None):
+def update_philox_state(increment, device=None):
     device = device or torch_device_fn.current_device()
     gen = torch_device_fn.default_generators[device]
     state_copy = gen.get_state()

From f43a2e04ea7a5ba5f3a391a8176ac7b1b7855107 Mon Sep 17 00:00:00 2001
From: Tongxin Bai <waffle.bai@gmail.com>
Date: Mon, 3 Mar 2025 04:00:38 +0000
Subject: [PATCH 02/25] update kernel wrapper.

---
 src/flag_gems/ops/attention.py | 395 +++++++++++++++++++--------------
 1 file changed, 224 insertions(+), 171 deletions(-)

diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py
index 3c114fcf1..1804ddcf2 100644
--- a/src/flag_gems/ops/attention.py
+++ b/src/flag_gems/ops/attention.py
@@ -306,6 +306,96 @@ def _attn_fwd(
     tl.store(O_block_ptr, acc.to(Out.type.element_ty), mask=q_load_mask[:, None])
 
 
+def scaled_dot_product_attention(
+    query,
+    key,
+    value,
+    attn_mask=None,
+    dropout_p=0.0,
+    is_causal=False,
+    scale=None,
+    enable_gqa=False,
+):
+    logging.debug("GEMS SCALED DOT PRODUCT ATTENTION")
+    # shape constraints
+    HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1]
+    # when v is in float8_e5m2 it is transposed.
+    HEAD_DIM_V = value.shape[-1]
+    assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
+    assert HEAD_DIM_K in {16, 32, 64, 128, 256}
+    assert dropout_p == 0.0, "Currenty only support dropout_p=0.0"
+
+    o = torch.empty_like(query, dtype=value.dtype)
+
+    stage = 3 if is_causal else 1
+
+    if scale is None:
+        sm_scale = 1.0 / (HEAD_DIM_K**0.5)
+    else:
+        sm_scale = scale
+
+    kv_head_num = key.shape[1]
+
+    grid = lambda args: (
+        triton.cdiv(query.shape[2], args["BLOCK_M"]),
+        query.shape[0] * query.shape[1],
+        1,
+    )
+
+    if attn_mask is not None:
+        HAS_ATTN_MASK = True
+        stride_attn_mask_batch = attn_mask.stride(0)
+        stride_attn_mask_head = attn_mask.stride(1)
+        stride_attn_mask_q_seqlen = attn_mask.stride(2)
+        stride_attn_mask_kv_seqlen = attn_mask.stride(3)
+    else:
+        HAS_ATTN_MASK = False
+        stride_attn_mask_batch = 1
+        stride_attn_mask_head = 1
+        stride_attn_mask_q_seqlen = 1
+        stride_attn_mask_kv_seqlen = 1
+
+    with torch_device_fn.device(query.device):
+        _attn_fwd[grid](
+            query,
+            key,
+            value,
+            attn_mask,
+            sm_scale,
+            o,  #
+            query.stride(0),
+            query.stride(1),
+            query.stride(2),
+            query.stride(3),  #
+            key.stride(0),
+            key.stride(1),
+            key.stride(2),
+            key.stride(3),  #
+            value.stride(0),
+            value.stride(1),
+            value.stride(2),
+            value.stride(3),  #
+            stride_attn_mask_batch,
+            stride_attn_mask_head,
+            stride_attn_mask_q_seqlen,
+            stride_attn_mask_kv_seqlen,  #
+            o.stride(0),
+            o.stride(1),
+            o.stride(2),
+            o.stride(3),  #
+            query.shape[0],
+            query.shape[1],
+            kv_head_num,  #
+            query.shape[2],  #
+            key.shape[2],  #
+            HEAD_DIM_K,  #
+            STAGE=stage,  #
+            HAS_ATTN_MASK=HAS_ATTN_MASK,  #
+        )
+        return o
+
+
+
 @triton.jit
 def philox_offset_one_warp(b, h, nh: tl.constexpr):
     # To align with TriDao's implementation, philox_offset linearly determined by
@@ -315,12 +405,12 @@ def philox_offset_one_warp(b, h, nh: tl.constexpr):
 
 
 @triton.jit
-def to_lohi(x):
+def u64_to_lohi(x):
     return (x >> 32).to(tl.uint32), (x & 0xFFFFFFFF).to(tl.uint32)
 
 
 @triton.jit
-def from_lohi(lo, hi):
+def u64_from_lohi(lo, hi):
     return hi.to(tl.uint64) << 32 + lo.to(tl.uint64)
 
 
@@ -328,9 +418,9 @@ def from_lohi(lo, hi):
 def philox_(seed, subsequence, offset):
     kPhilox10A: tl.constexpr = 0x9E3779B9
     kPhilox10B: tl.constexpr = 0xBB67AE85
-    k0, k1 = to_lohi(seed.to(tl.uint64))
-    c0, c1 = to_lohi(offset.to(tl.uint64))
-    c2, c3 = to_lohi(subsequence(tl.uint64))
+    k0, k1 = u64_to_lohi(seed.to(tl.uint64))
+    c0, c1 = u64_to_lohi(offset.to(tl.uint64))
+    c2, c3 = u64_to_lohi(subsequence(tl.uint64))
 
     # pragma unroll
     kPhiloxSA: tl.constexpr = 0xD2511F53
@@ -338,16 +428,16 @@ def philox_(seed, subsequence, offset):
     for _ in range(6):
         res0 = kPhiloxSA.to(tl.uint64) * c0.to(tl.uint64)
         res1 = kPhiloxSB.to(tl.uint64) * c2.to(tl.uint64)
-        res0_x, res0_y = to_lohi(res0)
-        res1_x, res1_y = to_lohi(res1)
+        res0_x, res0_y = u64_to_lohi(res0)
+        res1_x, res1_y = u64_to_lohi(res1)
         c0, c1, c2, c3 = res1_y ^ c1 ^ k0, res1_x, res0_y ^ c3 ^ k1, res0_x
         k0 += kPhilox10A
         k1 += kPhilox10B
 
     res0 = kPhiloxSA.to(tl.uint64) * c0.to(tl.uint64)
     res1 = kPhiloxSB.to(tl.uint64) * c2.to(tl.uint64)
-    res0_x.res0_y = to_lohi(res0)
-    res1_x.res1_y = to_lohi(res1)
+    res0_x.res0_y = u64_to_lohi(res0)
+    res1_x.res1_y = u64_to_lohi(res1)
     c0, c1, c2, c3 = res1_y ^ c1 ^ k0, res1_x, res0_y ^ c3 ^ k1, res0_x
 
     return c0, c1, c2, c3
@@ -367,9 +457,9 @@ def apply_mask(
 
 
 @triton.jit
-def make_4x_dropout_mask(uint32_r, uint8_p, M: tl.constexpr, N: tl.constexpr):
-    r = uint32_r
-    p = uint8_p
+def make_4x_dropout_mask(r_u32, p_u8, M: tl.constexpr, N: tl.constexpr):
+    r = r_u32
+    p = p_u8
     m0 = tl.where(r & 0xFF < p, 0, 1)
     r >>= 8
     m1 = tl.where(r & 0xFF < p, 0, 1)
@@ -397,12 +487,15 @@ def make_4x_dropout_mask(uint32_r, uint8_p, M: tl.constexpr, N: tl.constexpr):
 )
 def apply_dropout(
     P,
-    row_start,
-    col_start,
+    sor,
+    soc,
     philox_seed,
     philox_offset,
     p_dropout_uint8: tl.constexpr,
     encode_dropout_in_sign_bit: tl.constexpr,
+    bid: tl.constexpr,
+    hid: tl.constexpr,
+    nheads: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
 ):
@@ -410,12 +503,13 @@ def apply_dropout(
     # BLOCK_M is ensured to be a multiple of 16, BLOCK_N a multiple of 32
     M: tl.constexpr = BLOCK_M // 16
     N: tl.constexpr = BLOCK_N // 32
-    row = row_start + tl.arange(0, M)[:, None]
-    col = col_start + tl.arange(0, BLOCK_N)[None, :] // 32
+    row = sor + tl.arange(0, M)[:, None]
+    col = soc + tl.arange(0, BLOCK_N)[None, :] // 32
 
-    philox_offset = philox_offset + tl.arange(0, BLOCK_N)[None, :] % 32
+    tid = tl.arange(0, BLOCK_N)[None, :] % 32
+    philox_offset += (bid * nheads + hid) * 32 + tid
 
-    subsequence = from_lohi(row * 32, col)
+    subsequence = u64_from_lohi(row * 32, col)
     r0, r1, r2, r3 = philox_(philox_seed, subsequence, philox_offset)
 
     # Fully unrolled due to triton's inability to concat 2d tensor
@@ -432,7 +526,7 @@ def apply_dropout(
     return P
 
 
-@triton.jit
+@triton.jit(do_not_specialize=[])
 def flash_fwd_kernel(
     pQ,
     pK,
@@ -443,70 +537,17 @@ def flash_fwd_kernel(
     philox_seed,
     philox_offset,
     pdrop_int8,
-    drop: tl.constexpr,
-    causal: tl.constexpr,
+    is_dropout: tl.constexpr,
+    is_causal: tl.constexpr,
     scale: tl.constexp,
-    wl: tl.constexpr,
-    wr: tl.constexpr,
+    ws_left: tl.constexpr,
+    ws_right: tl.constexpr,
     return_P: tl.constexpr,
+    is_local: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    num_warps: tl.constexpr,
 ):
-    pass
-
-
-def flash_attention_forward(
-    query,
-    key,
-    value,
-    cum_seq_q,
-    cum_seq_k,
-    max_q,
-    max_k,
-    dropout_p,
-    is_causal,
-    return_debug_mask,
-    *,
-    scale=None,
-    window_size_left=None,
-    window_size_right=None,
-    seqused_k=None,
-    alibi_slopes=None
-):
-    logging.debug("GEMS FLASH_ATTENTION")
-    assert cum_seq_q is None and cum_seq_k is None, "varlen is not supported yet."
-
-    HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1]
-    HEAD_DIM_V = value.shape[-1]
-    assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
-    assert HEAD_DIM_K in {16, 32, 64, 128, 256}
-
-    softmax_scale = scale or 1.0 / (HEAD_DIM_K**0.5)
-    non_null_window_left = window_size_left or -1
-    non_null_window_right = window_size_right or -1
-    out = torch.empty_like(query, dtype=value.dtype)
-
-    mha_out = mha_fwd(
-        query,
-        key,
-        value,
-        out,
-        alibi_slopes,
-        dropout_p,
-        softmax_scale,
-        is_causal,
-        non_null_window_left,
-        non_null_window_right,
-        return_debug_mask,
-    )
-    (
-        output,
-        q_padded,
-        k_padded,
-        v_padded,
-        logsumexp,
-        philox_seed,
-        philox_offset,
-        debug_attn_mask,
-    ) = mha_out
 
 
 def mha_fwd(
@@ -562,113 +603,125 @@ def mha_fwd(
     seqlen_q_rounded = round_multiple(seqlen_q, 128)
     seqlen_k_rounded = round_multiple(seqlen_k, 128)
 
-    lse = torch.empty(
-        (batch_size, num_heads, seqlen_q), dtype=torch.float, device=q_device
-    )
-    if return_softmax:
-        assert p_dropout > 0, "return_softmax is only supported when p_dropout > 0.0"
-        p = torch.empty(
-            (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded),
-            dtype=q_dtype,
-            device=q_device,
-        )
-
-    if p_dropout > 0:
-        increment = triton.cdiv(batch_size * num_heads * 32)
-        philox_seed, philox_offset = update_philox_state(increment)
+    with torch_device_fn.device(q_device):
+        # Set softmax params
+        lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float)
+        if return_softmax:
+            assert p_dropout > 0, "return_softmax is only supported when p_dropout > 0.0"
+            p = torch.empty(
+                (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded),
+                dtype=q_dtype,
+            )
 
-    with torch_device_fn.device(q.device):
+        # Set dropout params
+        if p_dropout > 0:
+            increment = triton.cdiv(batch_size * num_heads * 32)
+            philox_seed, philox_offset = update_philox_state(increment)
+            is_dropout = True
+        else:
+            is_dropout = False
+
+        p_dropout = 1 - p_dropout
+        pdrop_u8 = math.floor(p_dropout * 255.0)
+
+        # Set alibi params
+        if alibi_slopes is not None:
+            assert alibi_slopes.device == q_device
+            assert alibi_slopes.dtype in (torch.float, )
+            assert alibi_slopes.stride(-1) == 1
+            assert alibi_slopes.shape == (num_heads,) or alibi_slopes.shape == (batch_size, num_heads)
+            alibi_slopes_batch_stride = alibi_slopes.stride(0) if alibi_slopes.ndim == 2 else 0
+        
+        # Set SWA params
+        is_local = (window_size_left >= 0 or window_size_right >= 0) and not is_causal
+
+        # ONLY EVEN_K IS SUPPORTED
+        assert head_size == head_size_rounded
+        
         grid = lambda args: (
-            triton.cdiv(seqlen_q, args["BLOCK_M"]),
-            batch_size * num_heads,
-            1,
+            triton.cdiv(seqlen_q, args["BLOCK_M"]), # num_m_blocks
+            batch_size,
+            num_heads,
         )
 
+        flash_fwd_kernel[grid](
+            q,
+            k,
+            v,
+            p,
+            out,
+            alibi_slopes,
+            philox_seed,
+            philox_offset,
+            pdrop_u8,
+            is_dropout,
+            is_causal,
+            softmax_scale,
+            window_size_left,
+            window_size_right,
+            return_softmax,
+            is_local,
+            BLOCK_M,
+            BLOCK_N,
+            num_warps
+        )
+        
 
-def scaled_dot_product_attention(
+
+def flash_attention_forward(
     query,
     key,
     value,
-    attn_mask=None,
-    dropout_p=0.0,
-    is_causal=False,
+    cum_seq_q,
+    cum_seq_k,
+    max_q,
+    max_k,
+    dropout_p,
+    is_causal,
+    return_debug_mask,
+    *,
     scale=None,
-    enable_gqa=False,
+    window_size_left=None,
+    window_size_right=None,
+    seqused_k=None,
+    alibi_slopes=None
 ):
-    logging.debug("GEMS SCALED DOT PRODUCT ATTENTION")
-    # shape constraints
+    logging.debug("GEMS FLASH_ATTENTION")
+    assert cum_seq_q is None and cum_seq_k is None, "varlen is not supported yet."
+
     HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1]
-    # when v is in float8_e5m2 it is transposed.
     HEAD_DIM_V = value.shape[-1]
     assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
     assert HEAD_DIM_K in {16, 32, 64, 128, 256}
-    assert dropout_p == 0.0, "Currenty only support dropout_p=0.0"
-
-    o = torch.empty_like(query, dtype=value.dtype)
-
-    stage = 3 if is_causal else 1
-
-    if scale is None:
-        sm_scale = 1.0 / (HEAD_DIM_K**0.5)
-    else:
-        sm_scale = scale
 
-    kv_head_num = key.shape[1]
+    softmax_scale = scale or 1.0 / (HEAD_DIM_K**0.5)
+    non_null_window_left = window_size_left or -1
+    non_null_window_right = window_size_right or -1
+    out = torch.empty_like(query, dtype=value.dtype)
 
-    grid = lambda args: (
-        triton.cdiv(query.shape[2], args["BLOCK_M"]),
-        query.shape[0] * query.shape[1],
-        1,
+    mha_out = mha_fwd(
+        query,
+        key,
+        value,
+        out,
+        alibi_slopes,
+        dropout_p,
+        softmax_scale,
+        is_causal,
+        non_null_window_left,
+        non_null_window_right,
+        return_debug_mask,
     )
+    (
+        output,
+        q_padded,
+        k_padded,
+        v_padded,
+        logsumexp,
+        philox_seed,
+        philox_offset,
+        debug_attn_mask,
+    ) = mha_out
+
 
-    if attn_mask is not None:
-        HAS_ATTN_MASK = True
-        stride_attn_mask_batch = attn_mask.stride(0)
-        stride_attn_mask_head = attn_mask.stride(1)
-        stride_attn_mask_q_seqlen = attn_mask.stride(2)
-        stride_attn_mask_kv_seqlen = attn_mask.stride(3)
-    else:
-        HAS_ATTN_MASK = False
-        stride_attn_mask_batch = 1
-        stride_attn_mask_head = 1
-        stride_attn_mask_q_seqlen = 1
-        stride_attn_mask_kv_seqlen = 1
 
-    with torch_device_fn.device(query.device):
-        _attn_fwd[grid](
-            query,
-            key,
-            value,
-            attn_mask,
-            sm_scale,
-            o,  #
-            query.stride(0),
-            query.stride(1),
-            query.stride(2),
-            query.stride(3),  #
-            key.stride(0),
-            key.stride(1),
-            key.stride(2),
-            key.stride(3),  #
-            value.stride(0),
-            value.stride(1),
-            value.stride(2),
-            value.stride(3),  #
-            stride_attn_mask_batch,
-            stride_attn_mask_head,
-            stride_attn_mask_q_seqlen,
-            stride_attn_mask_kv_seqlen,  #
-            o.stride(0),
-            o.stride(1),
-            o.stride(2),
-            o.stride(3),  #
-            query.shape[0],
-            query.shape[1],
-            kv_head_num,  #
-            query.shape[2],  #
-            key.shape[2],  #
-            HEAD_DIM_K,  #
-            STAGE=stage,  #
-            HAS_ATTN_MASK=HAS_ATTN_MASK,  #
-        )
-        return o

From b041f6b0edb0264f8a230abcca8ad6b880cbed6f Mon Sep 17 00:00:00 2001
From: Tongxin Bai <waffle.bai@gmail.com>
Date: Mon, 3 Mar 2025 09:08:53 +0000
Subject: [PATCH 03/25] update masking.

---
 src/flag_gems/ops/attention.py | 88 +++++++++++++++++++++++++++++-----
 1 file changed, 77 insertions(+), 11 deletions(-)

diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py
index 1804ddcf2..364caa832 100644
--- a/src/flag_gems/ops/attention.py
+++ b/src/flag_gems/ops/attention.py
@@ -444,7 +444,7 @@ def philox_(seed, subsequence, offset):
 
 
 @triton.jit
-def apply_mask(
+def apply_dropout_mask(
     P,
     mask,
     encode_dropout_in_sign_bit: tl.constexpr,
@@ -522,10 +522,43 @@ def apply_dropout(
     m1 = tl.join(m0, m1).trans(2, 0, 1).reshape(8 * M, N)
 
     m = tl.join(m0, m1).trans(2, 0, 1).reshape(16 * M, N)
-    P = apply_mask(P, m)
+    P = apply_dropout_mask(P, m)
     return P
 
 
+@triton.jit
+def apply_mask(
+    P,
+    col_idx,
+    row_idx,
+    warp_row_stride,
+    max_seqlen_q,
+    max_seqlen_k,
+    ws_left,
+    ws_right,
+    alibi_slope,
+    is_causal: tl.constexpr,
+    is_local: tl.constexpr,
+    has_alibi: tl.constexpr,
+):
+    if has_alibi or is_causal or is_local:
+        col_lb = max(0, row_idx + max_seqlen_k - max_seqlen_q - ws_left)
+        col_rb = min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + ws_right)
+
+        if not has_alibi:
+            alibi_slope = .0
+            
+        P -= alibi_slope * tl.abs(col_idx - row_idx)
+        
+        if is_causal:
+            P = tl.where(col_idx >= col_rb, float('-inf'), P)
+        
+        if is_local:
+            P = tl.where(col_idx >= col_rb | col_idx < col_lb, float('-inf'), P)
+
+    return P
+    
+
 @triton.jit(do_not_specialize=[])
 def flash_fwd_kernel(
     pQ,
@@ -533,21 +566,46 @@ def flash_fwd_kernel(
     pV,
     pP,
     pO,
+    seqlen_q,
+    seqlen_k,
     pSlopes,
     philox_seed,
     philox_offset,
     pdrop_int8,
+    slopes_batch_stride,
     is_dropout: tl.constexpr,
     is_causal: tl.constexpr,
+    is_local: tl.constexpr,
+    has_alibi: tl.constexpr,
     scale: tl.constexp,
     ws_left: tl.constexpr,
     ws_right: tl.constexpr,
     return_P: tl.constexpr,
-    is_local: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
     num_warps: tl.constexpr,
 ):
+    m_block = tl.program_id(0)
+    bid = tl.program_id(1)
+    hid = tl.program_id(2)
+
+    if is_local:
+        n_block_min: tl.constexpr = max(0, (m_block * BLOCK_M + seqlen_k - seqlen_q - ws_left) / BLOCK_N)
+    else:
+        n_block_min: tl.constexpr = 0 
+
+    n_block_max = tl.cdiv(seqlen_k, BLOCK_N)
+    
+    if is_causal or is_local:
+        n_block_max = min(n_block_max,
+                          tl.cdiv((m_block + 1) * BLOCK_M + seqlen_k - seqlen_q + window_size_right, BLOCK_N))
+
+    if has_alibi:
+        alibi_offset = bid * slopes_batch_stride + hid
+        alibi_slope = tl.load(pSlopes + alibi_offset)
+        alibi_slope /= scale
+
+    
 
 
 def mha_fwd(
@@ -631,7 +689,11 @@ def mha_fwd(
             assert alibi_slopes.stride(-1) == 1
             assert alibi_slopes.shape == (num_heads,) or alibi_slopes.shape == (batch_size, num_heads)
             alibi_slopes_batch_stride = alibi_slopes.stride(0) if alibi_slopes.ndim == 2 else 0
-        
+            has_alibi = True
+        else:
+            alibi_slopes_batch_stride = 0
+            has_alibi = False
+
         # Set SWA params
         is_local = (window_size_left >= 0 or window_size_right >= 0) and not is_causal
 
@@ -650,17 +712,21 @@ def mha_fwd(
             v,
             p,
             out,
+            seqlen_q,
+            seqlen_k,
             alibi_slopes,
             philox_seed,
             philox_offset,
             pdrop_u8,
-            is_dropout,
-            is_causal,
-            softmax_scale,
-            window_size_left,
-            window_size_right,
-            return_softmax,
-            is_local,
+            alibi_slopes_batch_stride,
+            is_dropout=is_dropout,
+            is_causal=is_causal,
+            is_local=is_local,
+            has_alibi=has_alibi,
+            scale=softmax_scale,
+            ws_left=window_size_left,
+            ws_right=window_size_right,
+            return_P=return_softmax,
             BLOCK_M,
             BLOCK_N,
             num_warps

From 7ab8dab042203660393ce41ca1c542ce269bebff Mon Sep 17 00:00:00 2001
From: Tongxin Bai <waffle.bai@gmail.com>
Date: Wed, 5 Mar 2025 01:31:14 +0000
Subject: [PATCH 04/25] done all masking steps.

---
 src/flag_gems/ops/attention.py | 192 +++++++++++++++++++++++++++++----
 1 file changed, 173 insertions(+), 19 deletions(-)

diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py
index 364caa832..28b8bad26 100644
--- a/src/flag_gems/ops/attention.py
+++ b/src/flag_gems/ops/attention.py
@@ -395,6 +395,7 @@ def scaled_dot_product_attention(
         return o
 
 
+# The following implementation is a fundamentally a triton rewrite of TriDao's Flash Attention in Cuda.
 
 @triton.jit
 def philox_offset_one_warp(b, h, nh: tl.constexpr):
@@ -489,13 +490,13 @@ def apply_dropout(
     P,
     sor,
     soc,
+    bid,
+    hid,
     philox_seed,
     philox_offset,
     p_dropout_uint8: tl.constexpr,
     encode_dropout_in_sign_bit: tl.constexpr,
-    bid: tl.constexpr,
-    hid: tl.constexpr,
-    nheads: tl.constexpr,
+    NUM_HEADS: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
 ):
@@ -507,7 +508,7 @@ def apply_dropout(
     col = soc + tl.arange(0, BLOCK_N)[None, :] // 32
 
     tid = tl.arange(0, BLOCK_N)[None, :] % 32
-    philox_offset += (bid * nheads + hid) * 32 + tid
+    philox_offset += (bid * NUM_HEADS + hid) * 32 + tid
 
     subsequence = u64_from_lohi(row * 32, col)
     r0, r1, r2, r3 = philox_(philox_seed, subsequence, philox_offset)
@@ -526,12 +527,11 @@ def apply_dropout(
     return P
 
 
-@triton.jit
+@triton.jit(do_not_specialize=['max_seqlen_q', 'max_seqlen_k'])
 def apply_mask(
-    P,
+    S,
     col_idx,
     row_idx,
-    warp_row_stride,
     max_seqlen_q,
     max_seqlen_k,
     ws_left,
@@ -547,17 +547,36 @@ def apply_mask(
 
         if not has_alibi:
             alibi_slope = .0
-            
-        P -= alibi_slope * tl.abs(col_idx - row_idx)
-        
+
+        S -= alibi_slope * tl.abs(col_idx - row_idx)
+
         if is_causal:
-            P = tl.where(col_idx >= col_rb, float('-inf'), P)
-        
+            S = tl.where(col_idx >= col_rb, float('-inf'), S)
+
         if is_local:
-            P = tl.where(col_idx >= col_rb | col_idx < col_lb, float('-inf'), P)
+            S = tl.where(col_idx >= col_rb | col_idx < col_lb, float('-inf'), S)
+
+    return S
+
+
+@triton.jit
+def softmax_rescale(
+    O_acc,
+    S,
+    row_max,
+    row_sum,
+    softmax_scale_log2: tl.constexpr
+):
+    prev_row_max = row_max
+    row_max = tl.maximum(row_max, tl.max(S, 1))
+    row_sum_scale = tl.math.exp2(row_max - prev_row_max) * softmax_scale_log2
+    row_sum *= row_sum_scale
+    O_acc *= row_sum_scale[:, None]
+    max_scaled = tl.where(rowmax == float('-inf'), 0, rowmax * softmax_scale_log2)
+    P = tl.math.exp2(S * softmax_scale_log2 - max_scaled[:, None])
+    row_sum = row_sum + tl.sum(exp_S, 1)
+    return O_acc, P, row_max, row_sum
 
-    return P
-    
 
 @triton.jit(do_not_specialize=[])
 def flash_fwd_kernel(
@@ -568,19 +587,35 @@ def flash_fwd_kernel(
     pO,
     seqlen_q,
     seqlen_k,
+    seqlen_q_rounded,
+    seqlen_k_rounded,
+    q_b_stride,
+    q_s_stride,
+    q_h_stride,
+    k_b_stride,
+    k_s_stride,
+    k_h_stride,
+    o_b_stride,
+    o_s_stride,
+    o_h_stride,
     pSlopes,
     philox_seed,
     philox_offset,
-    pdrop_int8,
+    pdrop_u8,
     slopes_batch_stride,
     is_dropout: tl.constexpr,
     is_causal: tl.constexpr,
     is_local: tl.constexpr,
     has_alibi: tl.constexpr,
-    scale: tl.constexp,
+    softmax_scale: tl.constexp,
+    softmax_scale_log2: tl.constexpr,
     ws_left: tl.constexpr,
     ws_right: tl.constexpr,
     return_P: tl.constexpr,
+    BATCH_SIZE: tl.constexpr,
+    NUM_HEADS: tl.constexpr,
+    NUM_HEADS_K: tl.constexpr,
+    HEADSIZE: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
     num_warps: tl.constexpr,
@@ -605,7 +640,107 @@ def flash_fwd_kernel(
         alibi_slope = tl.load(pSlopes + alibi_offset)
         alibi_slope /= scale
 
-    
+    if (not is_causal) and (not is_local):
+        n_masking_steps = 1 
+    elif is_causal:
+        n_masking_steps = tl.cdiv(BLOCK_M, BLOCK_N)
+    else:
+        # local and not causal, 
+        n_masking_steps = tl.cdiv(BLOCK_M, BLOCK_N) + 1
+
+
+    Q_ptr += bid * q_b_stride
+    Q_ptr += hid * q_h_stride
+    row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
+    Q_off = row_idx[:, None] * q_s_stride + tl.arange(0, HEADSIZE)[None, :]
+    qmask = row_idx < seqlen_q
+    Q = tl.load(pQ + Q_off, mask=qmask)
+
+    # Start from the right most block
+    n_block = n_block_max - 1
+
+    h_hk_ratio = h // hk
+    K_ptr += bid * k_b_stride
+    K_ptr += (hid // h_hk_ratio) * k_h_stride
+    V_ptr += bid * k_b_stride
+    V_ptr += (hid // h_hk_ratio) * k_h_stride
+    col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
+    KV_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEADSIZE)[:, None]
+
+    P_ptr += ((bid * NUM_HEADS + hid) * seqlen_q_rounded + m_block * BLOCK_M) * seqlen_k_rounded
+    P_ptr += n_block * BLOCK_N
+    P_offset = tl.arange(0, BLOCK_M)[:, None] * seqlen_k_rounded + tl.arange(0, BLOCK_N)
+
+    O_ = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+    rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+    rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32)
+    for _ in range(n_masking_steps):
+        kvmask = col_idx < seqlen_k
+        K = tl.load(pK + KV_offset, mask=kvmask, cache_modifier="cg")        
+        V = tl.load(pV + KV_offset, mask=kvmask, cache_modifier="cg")
+        S = tl.dot(Q_block, K_block)
+        KV_offset += BLOCK_N * k_s_stride
+        S = apply_mask(
+            S,
+            col_idx,
+            row_idx,
+            seqlen_q,
+            seqlen_k,
+            ws_left,
+            ws_right,
+            alibi_slope,
+            is_causal,
+            is_local,
+            has_alibi
+        )
+
+        O_acc, P, rowmax_, rowsum_ = softmax_rescale(O_, S, rowmax_, rowsum_, softmax_scale_log2)
+        P = P.to(pO.type.element_ty)
+
+        row_start = m_block * (BLOCK_M // 16)
+        col_start = n_block * (BLOCK_N // 32)
+        if return_P:
+            P_drop = P
+            P_drop = apply_dropout(
+                P_drop,
+                row_start,
+                col_start,
+                bid,
+                hid,
+                philox_seed,
+                philox_offset,
+                pdrop_u8,
+                encode_dropout_in_sign_bit=True,
+                NUM_HEADS=NUM_HEADS,
+                BLOCK_M=BLOCK_M,
+                BLOCK_N=BLOCK_N,
+            )
+            tl.store(P_ptr + P_offset, P_drop, mask=qmask[:, None] & kvmask[None, :])
+            P_offset += BLOCK_N
+
+        if is_dropout:
+            P = apply_dropout(
+                P,
+                row_start,
+                col_start,
+                bid,
+                hid,
+                philox_seed,
+                philox_offset,
+                pdrop_u8,
+                encode_dropout_in_sign_bit=False,
+                NUM_HEADS=NUM_HEADS,
+                BLOCK_M=BLOCK_M,
+                BLOCK_N=BLOCK_N,
+            )
+
+        tl.dot(P, V, acc=O_)
+
+        if n_masking_steps > 1 and n_block <= n_block_min:
+            n_block -= 1
+            break
+
+
 
 
 def mha_fwd(
@@ -700,6 +835,9 @@ def mha_fwd(
         # ONLY EVEN_K IS SUPPORTED
         assert head_size == head_size_rounded
         
+        M_LOG2E	= 1.4426950408889634074
+        scale_softmax_log2 = softmax_scale * M_LOG2E
+
         grid = lambda args: (
             triton.cdiv(seqlen_q, args["BLOCK_M"]), # num_m_blocks
             batch_size,
@@ -714,6 +852,17 @@ def mha_fwd(
             out,
             seqlen_q,
             seqlen_k,
+            seqlen_q_rounded,
+            seqlen_k_rounded,
+            q.stride(0),
+            q.stride(-3),
+            q.stride(-2),
+            k.stride(0),
+            k.stride(-3),
+            k.stride(-2),
+            out.stride(0),
+            out.stride(-3),
+            out.stride(-2),
             alibi_slopes,
             philox_seed,
             philox_offset,
@@ -723,10 +872,15 @@ def mha_fwd(
             is_causal=is_causal,
             is_local=is_local,
             has_alibi=has_alibi,
-            scale=softmax_scale,
+            softmax_scale=softmax_scale,
+            softmax_scale_log2=scale_softmax_log2
             ws_left=window_size_left,
             ws_right=window_size_right,
             return_P=return_softmax,
+            BATCH_SIZE=batch_size,
+            NUM_HEADS=num_heads,
+            NUM_HEADS_K=num_heads_k,
+            HEADSIZE=head_size_rounded,
             BLOCK_M,
             BLOCK_N,
             num_warps

From 4909d10e058657a639d25c6c98f5c1bbb60d2a77 Mon Sep 17 00:00:00 2001
From: Tongxin Bai <waffle.bai@gmail.com>
Date: Wed, 5 Mar 2025 10:12:25 +0000
Subject: [PATCH 05/25] fwd kernel almost done.

---
 src/flag_gems/ops/attention.py | 135 ++++++++++++++++++++++++++++-----
 1 file changed, 117 insertions(+), 18 deletions(-)

diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py
index 28b8bad26..780ff12de 100644
--- a/src/flag_gems/ops/attention.py
+++ b/src/flag_gems/ops/attention.py
@@ -537,11 +537,12 @@ def apply_mask(
     ws_left,
     ws_right,
     alibi_slope,
+    is_even_mn: tl.constexpr,
     is_causal: tl.constexpr,
     is_local: tl.constexpr,
     has_alibi: tl.constexpr,
 ):
-    if has_alibi or is_causal or is_local:
+    if has_alibi or is_causal or is_local or not is_even_mn:
         col_lb = max(0, row_idx + max_seqlen_k - max_seqlen_q - ws_left)
         col_rb = min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + ws_right)
 
@@ -551,10 +552,13 @@ def apply_mask(
         S -= alibi_slope * tl.abs(col_idx - row_idx)
 
         if is_causal:
-            S = tl.where(col_idx >= col_rb, float('-inf'), S)
+            S = tl.where(col_idx[None, :] >= col_rb[None, :], float('-inf'), S)
 
         if is_local:
-            S = tl.where(col_idx >= col_rb | col_idx < col_lb, float('-inf'), S)
+            S = tl.where(col_idx[None, :] >= col_rb[None, :] | col_idx[None, :] < col_lb[None, :], float('-inf'), S)
+        
+        if (not local) and (not is_causal) and (not is_even_mn):
+            S = tl.where(col_idx[None, :] >= max_seqlen_k, float('-inf'), S)
 
     return S
 
@@ -565,7 +569,8 @@ def softmax_rescale(
     S,
     row_max,
     row_sum,
-    softmax_scale_log2: tl.constexpr
+    softmax_scale_log2: tl.constexpr,
+    is_border: tl.constexpr
 ):
     prev_row_max = row_max
     row_max = tl.maximum(row_max, tl.max(S, 1))
@@ -578,6 +583,11 @@ def softmax_rescale(
     return O_acc, P, row_max, row_sum
 
 
+@triton.heuristics(
+    values={
+        'IS_EVEN_MN': lambda args: (args["seqlen_q"] % args["BLOCK_M"] == 0) and (args["seqlen_k"] % args["BLOCK_N"] == 0)
+    }
+)
 @triton.jit(do_not_specialize=[])
 def flash_fwd_kernel(
     pQ,
@@ -616,9 +626,11 @@ def flash_fwd_kernel(
     NUM_HEADS: tl.constexpr,
     NUM_HEADS_K: tl.constexpr,
     HEADSIZE: tl.constexpr,
+    IS_EVEN_MN: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
     num_warps: tl.constexpr,
+    num_stages: tl.constexpr
 ):
     m_block = tl.program_id(0)
     bid = tl.program_id(1)
@@ -641,8 +653,11 @@ def flash_fwd_kernel(
         alibi_slope /= scale
 
     if (not is_causal) and (not is_local):
-        n_masking_steps = 1 
-    elif is_causal:
+        if IS_EVEN_MN:
+            n_masking_steps = 0
+        else:
+            n_masking_steps = 1
+    elif is_causal and IS_EVEN_MN: # causal implies window_size_right is zero
         n_masking_steps = tl.cdiv(BLOCK_M, BLOCK_N)
     else:
         # local and not causal, 
@@ -665,7 +680,6 @@ def flash_fwd_kernel(
     V_ptr += bid * k_b_stride
     V_ptr += (hid // h_hk_ratio) * k_h_stride
     col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
-    KV_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEADSIZE)[:, None]
 
     P_ptr += ((bid * NUM_HEADS + hid) * seqlen_q_rounded + m_block * BLOCK_M) * seqlen_k_rounded
     P_ptr += n_block * BLOCK_N
@@ -674,12 +688,17 @@ def flash_fwd_kernel(
     O_ = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
     rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
     rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32)
+    
     for _ in range(n_masking_steps):
-        kvmask = col_idx < seqlen_k
-        K = tl.load(pK + KV_offset, mask=kvmask, cache_modifier="cg")        
-        V = tl.load(pV + KV_offset, mask=kvmask, cache_modifier="cg")
+        KV_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEADSIZE)[:, None]
+        if IS_EVEN_MN:
+            K = tl.load(pK + KV_offset, cache_modifier="cg")
+            V = tl.load(pV + KV_offset, cache_modifier="cg")
+        else:
+            kvmask = col_idx < seqlen_k
+            K = tl.load(pK + KV_offset, mask=kvmask, cache_modifier="cg")
+            V = tl.load(pV + KV_offset, mask=kvmask, cache_modifier="cg")
         S = tl.dot(Q_block, K_block)
-        KV_offset += BLOCK_N * k_s_stride
         S = apply_mask(
             S,
             col_idx,
@@ -689,12 +708,21 @@ def flash_fwd_kernel(
             ws_left,
             ws_right,
             alibi_slope,
-            is_causal,
-            is_local,
-            has_alibi
+            is_even_mn=IS_EVEN_MN,
+            is_causal=is_causal,
+            is_local=is_local,
+            has_alibi=has_alibi
         )
+        col_idx -= BLOCK_N
 
-        O_acc, P, rowmax_, rowsum_ = softmax_rescale(O_, S, rowmax_, rowsum_, softmax_scale_log2)
+        O_acc, P, rowmax_, rowsum_ = softmax_rescale(
+            O_,
+            S,
+            rowmax_,
+            rowsum_,
+            softmax_scale_log2=softmax_scale_log2,
+            is_border=(is_causal or is_local)
+        )
         P = P.to(pO.type.element_ty)
 
         row_start = m_block * (BLOCK_M // 16)
@@ -736,11 +764,81 @@ def flash_fwd_kernel(
 
         tl.dot(P, V, acc=O_)
 
+        n_block -= 1
         if n_masking_steps > 1 and n_block <= n_block_min:
-            n_block -= 1
             break
 
+    for _ in range(n_block_min, n_block + 1):
+        KV_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEADSIZE)[:, None]
+        K = tl.load(pK + KV_offset, cache_modifier="cg")        
+        V = tl.load(pV + KV_offset, cache_modifier="cg")
+        S = tl.dot(Q_block, K_block)
+        S = apply_mask(
+            S,
+            col_idx,
+            row_idx,
+            seqlen_q,
+            seqlen_k,
+            ws_left,
+            ws_right,
+            alibi_slope,
+            is_even_mn=True,
+            is_causal=False,
+            is_local=is_local,
+            has_alibi=has_alibi
+        )
+        col_idx -= BLOCK_N
 
+        O_acc, P, rowmax_, rowsum_ = softmax_rescale(
+            O_,
+            S,
+            rowmax_,
+            rowsum_,
+            softmax_scale_log2=softmax_scale_log2,
+            is_border=is_local
+        )
+        P = P.to(pO.type.element_ty)
+
+        row_start = m_block * (BLOCK_M // 16)
+        col_start = n_block * (BLOCK_N // 32)
+        if return_P:
+            P_drop = P
+            P_drop = apply_dropout(
+                P_drop,
+                row_start,
+                col_start,
+                bid,
+                hid,
+                philox_seed,
+                philox_offset,
+                pdrop_u8,
+                encode_dropout_in_sign_bit=True,
+                NUM_HEADS=NUM_HEADS,
+                BLOCK_M=BLOCK_M,
+                BLOCK_N=BLOCK_N,
+            )
+            tl.store(P_ptr + P_offset, P_drop, mask=qmask[:, None] & kvmask[None, :])
+            P_offset += BLOCK_N
+
+        if is_dropout:
+            P = apply_dropout(
+                P,
+                row_start,
+                col_start,
+                bid,
+                hid,
+                philox_seed,
+                philox_offset,
+                pdrop_u8,
+                encode_dropout_in_sign_bit=False,
+                NUM_HEADS=NUM_HEADS,
+                BLOCK_M=BLOCK_M,
+                BLOCK_N=BLOCK_N,
+            )
+
+        tl.dot(P, V, acc=O_)
+        
+        n_block -=1
 
 
 def mha_fwd(
@@ -881,8 +979,9 @@ def mha_fwd(
             NUM_HEADS=num_heads,
             NUM_HEADS_K=num_heads_k,
             HEADSIZE=head_size_rounded,
-            BLOCK_M,
-            BLOCK_N,
+            IS_EVEN_MN=is_even_mn,
+            BLOCK_M=BLOCK_M,
+            BLOCK_N=BLOCK_N,
             num_warps
         )
         

From 14e5d4c0eab25ed126d2ca44b5613bb0aa92b57d Mon Sep 17 00:00:00 2001
From: Tongxin Bai <waffle.bai@gmail.com>
Date: Thu, 6 Mar 2025 04:46:51 +0000
Subject: [PATCH 06/25] fwd kernel done.

---
 src/flag_gems/ops/attention.py | 134 +++++++++++++++++++++++----------
 1 file changed, 96 insertions(+), 38 deletions(-)

diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py
index 780ff12de..edd3cc31b 100644
--- a/src/flag_gems/ops/attention.py
+++ b/src/flag_gems/ops/attention.py
@@ -583,6 +583,19 @@ def softmax_rescale(
     return O_acc, P, row_max, row_sum
 
 
+@triton.jit
+def 
+
+
+@triton.autotune(
+    configs=runtime.get_tuned_config("attention"),
+    key=["HEAD_DIM"],
+    prune_configs_by={
+        "early_config_prune": early_config_prune,
+        "perf_model": None,
+        "top_k": 1.0,
+    },
+)
 @triton.heuristics(
     values={
         'IS_EVEN_MN': lambda args: (args["seqlen_q"] % args["BLOCK_M"] == 0) and (args["seqlen_k"] % args["BLOCK_N"] == 0)
@@ -590,11 +603,12 @@ def softmax_rescale(
 )
 @triton.jit(do_not_specialize=[])
 def flash_fwd_kernel(
-    pQ,
-    pK,
-    pV,
-    pP,
-    pO,
+    Q_ptr,
+    K_ptr,
+    V_ptr,
+    P_ptr,
+    O_ptr,
+    lse_ptr,
     seqlen_q,
     seqlen_k,
     seqlen_q_rounded,
@@ -612,6 +626,7 @@ def flash_fwd_kernel(
     philox_seed,
     philox_offset,
     pdrop_u8,
+    rpdrop,
     slopes_batch_stride,
     is_dropout: tl.constexpr,
     is_causal: tl.constexpr,
@@ -622,10 +637,11 @@ def flash_fwd_kernel(
     ws_left: tl.constexpr,
     ws_right: tl.constexpr,
     return_P: tl.constexpr,
+    PRE_LOAD_V: tl.constexpr,
     BATCH_SIZE: tl.constexpr,
     NUM_HEADS: tl.constexpr,
     NUM_HEADS_K: tl.constexpr,
-    HEADSIZE: tl.constexpr,
+    HEAD_DIM: tl.constexpr,
     IS_EVEN_MN: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
@@ -690,14 +706,16 @@ def flash_fwd_kernel(
     rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32)
     
     for _ in range(n_masking_steps):
-        KV_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEADSIZE)[:, None]
+        KV_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None]
         if IS_EVEN_MN:
             K = tl.load(pK + KV_offset, cache_modifier="cg")
-            V = tl.load(pV + KV_offset, cache_modifier="cg")
+            if PRE_LOAD_V:
+                V = tl.load(pV + KV_offset, cache_modifier="cg")
         else:
             kvmask = col_idx < seqlen_k
             K = tl.load(pK + KV_offset, mask=kvmask, cache_modifier="cg")
-            V = tl.load(pV + KV_offset, mask=kvmask, cache_modifier="cg")
+            if PRE_LOAD_V:
+                V = tl.load(pV + KV_offset, mask=kvmask, cache_modifier="cg")
         S = tl.dot(Q_block, K_block)
         S = apply_mask(
             S,
@@ -762,14 +780,19 @@ def flash_fwd_kernel(
                 BLOCK_N=BLOCK_N,
             )
 
+        if not PRE_LOAD_V:
+            if IS_EVEN_MN:
+                V = tl.load(pV + KV_offset, cache_modifier="cg")
+            else:
+                V = tl.load(pV + KV_offset, mask=kvmask, cache_modifier="cg")
         tl.dot(P, V, acc=O_)
 
         n_block -= 1
         if n_masking_steps > 1 and n_block <= n_block_min:
             break
 
-    for _ in range(n_block_min, n_block + 1):
-        KV_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEADSIZE)[:, None]
+    for n_block in tl.range(n_block, n_block_min - 1, num_stages=num_stages):
+        KV_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None]
         K = tl.load(pK + KV_offset, cache_modifier="cg")        
         V = tl.load(pV + KV_offset, cache_modifier="cg")
         S = tl.dot(Q_block, K_block)
@@ -837,8 +860,34 @@ def flash_fwd_kernel(
             )
 
         tl.dot(P, V, acc=O_)
-        
-        n_block -=1
+
+    # Final LSE
+    lse = tl.where(rowsum_ == 0 | rowsum_ != rowsum_, float('inf'), rowmax_ * softmax_scale +tl.log(rowsum_))
+    inv_sum = tl.where(rowsum_ == 0 | rowsum_ != rowsum_, 1.0, 1.0 / rowsum_)
+    
+    # Rescale output
+    if is_dropout
+        O_ *= inv_sum * rpdrop
+    else:
+        O_ *= inv_sum
+    
+    O = O_.to(pO.type.element_ty)
+
+    # Write back output
+    O_ptr += bid * o_b_stride
+    O_ptr += hid * o_h_stride
+    O_offset = row_idx[:, None] * o_s_stride + tl.arange(0, HEAD_DIM)
+    if IS_EVEN_MN:
+        tl.store(O_ptr + O_offset, O)
+    else:
+        tl.store(O_ptr + O_offset, O, mask=qmask)
+    
+    # Write back lse
+    lse_ptr += bid * hid * seqlen_q
+    if IS_EVEN_MN:
+        tl.store(lse_ptr + row_idx, lse)
+    else:
+        tl.store(lse_ptr + row_idx, lse, mask=qmask)
 
 
 def mha_fwd(
@@ -882,6 +931,17 @@ def mha_fwd(
     if is_causal:
         window_size_right = 0
 
+    if seqlen_q == 1 and num_heads > num_heads_k and window_size_left < 0 and window_size_right < 0 and p_dropout == 0 and not alibi_slopes
+        swap_seq_and_group = True
+    else:
+        swap_seq_and_group = False
+
+    ngroups = num_heads // num_heads_k
+    if swap_seq_and_group:
+        q = q.reshape((batch_size, num_heads_k, ngroups, head_size)).transpose(1, 2)
+        seqlen_q = ngroups
+        num_heads = num_heads_k
+
     if out:
         assert out.stride(-1) == 1
         assert out.dtype == q.dtype
@@ -903,6 +963,8 @@ def mha_fwd(
                 (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded),
                 dtype=q_dtype,
             )
+        else:
+            p = torch.empty(())
 
         # Set dropout params
         if p_dropout > 0:
@@ -914,6 +976,11 @@ def mha_fwd(
 
         p_dropout = 1 - p_dropout
         pdrop_u8 = math.floor(p_dropout * 255.0)
+        rpdrop = 1. / p_dropout
+
+        M_LOG2E	= 1.4426950408889634074
+        scale_softmax_log2 = softmax_scale * M_LOG2E
+        scale_softmax_rp_dropout = rpdrop * softmax_scale
 
         # Set alibi params
         if alibi_slopes is not None:
@@ -932,9 +999,7 @@ def mha_fwd(
 
         # ONLY EVEN_K IS SUPPORTED
         assert head_size == head_size_rounded
-        
-        M_LOG2E	= 1.4426950408889634074
-        scale_softmax_log2 = softmax_scale * M_LOG2E
+
 
         grid = lambda args: (
             triton.cdiv(seqlen_q, args["BLOCK_M"]), # num_m_blocks
@@ -948,6 +1013,7 @@ def mha_fwd(
             v,
             p,
             out,
+            lse,
             seqlen_q,
             seqlen_k,
             seqlen_q_rounded,
@@ -965,26 +1031,30 @@ def mha_fwd(
             philox_seed,
             philox_offset,
             pdrop_u8,
+            rpdrop,
             alibi_slopes_batch_stride,
             is_dropout=is_dropout,
             is_causal=is_causal,
             is_local=is_local,
             has_alibi=has_alibi,
             softmax_scale=softmax_scale,
-            softmax_scale_log2=scale_softmax_log2
+            softmax_scale_log2=scale_softmax_log2,
             ws_left=window_size_left,
             ws_right=window_size_right,
             return_P=return_softmax,
             BATCH_SIZE=batch_size,
             NUM_HEADS=num_heads,
             NUM_HEADS_K=num_heads_k,
-            HEADSIZE=head_size_rounded,
+            HEAD_DIM=head_size_rounded,
             IS_EVEN_MN=is_even_mn,
-            BLOCK_M=BLOCK_M,
-            BLOCK_N=BLOCK_N,
-            num_warps
         )
-        
+    
+    if swap_seq_and_group:
+        out = out.transpose(1, 2).reshape((batch_size, 1, num_heads_k * seqlen_q, head_size))
+        q = q.transpose(1, 2).reshape((batch_size, 1, num_heads_k * seqlen_q, head_size))
+        lse = lse.reshape((batch_size, num_heads_k * seqlen_q, 1))
+
+    return out, q, k, v, lse, philox_seed, philox_offset, p
 
 
 def flash_attention_forward(
@@ -1016,13 +1086,12 @@ def flash_attention_forward(
     softmax_scale = scale or 1.0 / (HEAD_DIM_K**0.5)
     non_null_window_left = window_size_left or -1
     non_null_window_right = window_size_right or -1
-    out = torch.empty_like(query, dtype=value.dtype)
 
-    mha_out = mha_fwd(
+    out, q, k, v, lse, philox_seed, philox_offset, p = mha_fwd(
         query,
         key,
         value,
-        out,
+        None,
         alibi_slopes,
         dropout_p,
         softmax_scale,
@@ -1031,16 +1100,5 @@ def flash_attention_forward(
         non_null_window_right,
         return_debug_mask,
     )
-    (
-        output,
-        q_padded,
-        k_padded,
-        v_padded,
-        logsumexp,
-        philox_seed,
-        philox_offset,
-        debug_attn_mask,
-    ) = mha_out
-
-
-
+    
+    return (out, lse, philox_seed, philox_offset, p)

From 62c92f2faa3a4bb2c6f4ce156acd019b84400956 Mon Sep 17 00:00:00 2001
From: Tongxin Bai <waffle.bai@gmail.com>
Date: Thu, 6 Mar 2025 05:05:45 +0000
Subject: [PATCH 07/25] fix syntax errors.

---
 src/flag_gems/ops/attention.py                         | 10 +++-------
 src/flag_gems/ops/exponential_.py                      |  4 ++--
 src/flag_gems/ops/multinomial.py                       |  4 ++--
 src/flag_gems/ops/normal.py                            |  4 ++--
 src/flag_gems/ops/rand.py                              |  4 ++--
 src/flag_gems/ops/rand_like.py                         |  4 ++--
 src/flag_gems/ops/randn.py                             |  4 ++--
 src/flag_gems/ops/randn_like.py                        |  4 ++--
 src/flag_gems/ops/randperm.py                          |  4 ++--
 src/flag_gems/ops/uniform.py                           |  4 ++--
 .../runtime/backend/_metax/ops/exponential_.py         |  4 ++--
 11 files changed, 23 insertions(+), 27 deletions(-)

diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py
index edd3cc31b..9c13485d6 100644
--- a/src/flag_gems/ops/attention.py
+++ b/src/flag_gems/ops/attention.py
@@ -583,10 +583,6 @@ def softmax_rescale(
     return O_acc, P, row_max, row_sum
 
 
-@triton.jit
-def 
-
-
 @triton.autotune(
     configs=runtime.get_tuned_config("attention"),
     key=["HEAD_DIM"],
@@ -632,7 +628,7 @@ def flash_fwd_kernel(
     is_causal: tl.constexpr,
     is_local: tl.constexpr,
     has_alibi: tl.constexpr,
-    softmax_scale: tl.constexp,
+    softmax_scale: tl.constexpr,
     softmax_scale_log2: tl.constexpr,
     ws_left: tl.constexpr,
     ws_right: tl.constexpr,
@@ -866,7 +862,7 @@ def flash_fwd_kernel(
     inv_sum = tl.where(rowsum_ == 0 | rowsum_ != rowsum_, 1.0, 1.0 / rowsum_)
     
     # Rescale output
-    if is_dropout
+    if is_dropout:
         O_ *= inv_sum * rpdrop
     else:
         O_ *= inv_sum
@@ -931,7 +927,7 @@ def mha_fwd(
     if is_causal:
         window_size_right = 0
 
-    if seqlen_q == 1 and num_heads > num_heads_k and window_size_left < 0 and window_size_right < 0 and p_dropout == 0 and not alibi_slopes
+    if seqlen_q == 1 and num_heads > num_heads_k and window_size_left < 0 and window_size_right < 0 and p_dropout == 0 and not alibi_slopes:
         swap_seq_and_group = True
     else:
         swap_seq_and_group = False
diff --git a/src/flag_gems/ops/exponential_.py b/src/flag_gems/ops/exponential_.py
index 8d36ce7eb..0ab5796ee 100644
--- a/src/flag_gems/ops/exponential_.py
+++ b/src/flag_gems/ops/exponential_.py
@@ -5,7 +5,7 @@
 import triton.language as tl
 
 from flag_gems.utils.random_utils import (
-    philox_backend_seed_offset,
+    update_philox_state,
     uint_to_uniform_float,
 )
 
@@ -94,7 +94,7 @@ def exponential_(x, lambd: float = 1.0, *, gen=None):
     # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller,
     # hence we cannot obtain the per thread offset as in Pytorch.
     increment = triton.cdiv(N, UNROLL)
-    philox_seed, philox_offset = philox_backend_seed_offset(increment)
+    philox_seed, philox_offset = update_philox_state(increment)
     eps = torch.finfo(dtype).eps
     x_ = x if inplace else torch.empty(x.size(), dtype=dtype, device=device)
     with torch_device_fn.device(device):
diff --git a/src/flag_gems/ops/multinomial.py b/src/flag_gems/ops/multinomial.py
index d1f061cda..8824426cd 100644
--- a/src/flag_gems/ops/multinomial.py
+++ b/src/flag_gems/ops/multinomial.py
@@ -5,7 +5,7 @@
 import triton.language as tl
 
 from flag_gems.utils import libentry
-from flag_gems.utils.random_utils import philox_backend_seed_offset, uniform
+from flag_gems.utils.random_utils import update_philox_state, uniform
 
 
 @libentry()
@@ -84,7 +84,7 @@ def multinomial(prob, n_samples, with_replacement=False, *, gen=None):
     # The CTA level parallelism is framed in a 2d grid of blocks with grid.y
     # indexing into distributions and grid.x output sample batches
     increment = n_dist * n_samples
-    philox_seed, philox_offset = philox_backend_seed_offset(increment)
+    philox_seed, philox_offset = update_philox_state(increment)
     grid = lambda META: (triton.cdiv(n_samples, META["NBLOCK"]), n_dist)
     multinomial_with_replacement[grid](
         cum_prob, out, n_categories, n_samples, philox_seed, philox_offset
diff --git a/src/flag_gems/ops/normal.py b/src/flag_gems/ops/normal.py
index b24b4398e..d35bbba8d 100644
--- a/src/flag_gems/ops/normal.py
+++ b/src/flag_gems/ops/normal.py
@@ -5,7 +5,7 @@
 
 from ..runtime import torch_device_fn
 from ..utils import pointwise_dynamic
-from ..utils.random_utils import philox_backend_seed_offset
+from ..utils.random_utils import update_philox_state
 from ..utils.shape_utils import broadcast_shapes, volume
 from .randn import randn_kernel
 
@@ -50,7 +50,7 @@ def normal_distribution(shape, device, *, generator=None):
     grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK"] * UNROLL),)
 
     increment = triton.cdiv(N, UNROLL)
-    philox_seed, philox_offset = philox_backend_seed_offset(increment)
+    philox_seed, philox_offset = update_philox_state(increment)
     with torch_device_fn.device(device):
         randn_kernel[grid_fn](out, N, philox_seed, philox_offset)
     return out
diff --git a/src/flag_gems/ops/rand.py b/src/flag_gems/ops/rand.py
index 9cab927ac..3dc2a51de 100644
--- a/src/flag_gems/ops/rand.py
+++ b/src/flag_gems/ops/rand.py
@@ -5,7 +5,7 @@
 import triton.language as tl
 
 from flag_gems.utils.random_utils import (
-    philox_backend_seed_offset,
+    update_philox_state,
     uint_to_uniform_float,
 )
 from flag_gems.utils.shape_utils import volume
@@ -63,7 +63,7 @@ def rand(size, *, dtype=None, layout=None, device=None, pin_memory=None):
     # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller,
     # hence we cannot obtain the per thread offset as in Pytorch.
     increment = triton.cdiv(N, UNROLL)
-    philox_seed, philox_offset = philox_backend_seed_offset(increment)
+    philox_seed, philox_offset = update_philox_state(increment)
     with torch_device_fn.device(device):
         rand_kernel[grid_fn](out, N, philox_seed, philox_offset)
     return out
diff --git a/src/flag_gems/ops/rand_like.py b/src/flag_gems/ops/rand_like.py
index 85338f595..6b5b9b706 100644
--- a/src/flag_gems/ops/rand_like.py
+++ b/src/flag_gems/ops/rand_like.py
@@ -4,7 +4,7 @@
 import triton
 
 from flag_gems.ops.rand import rand_kernel
-from flag_gems.utils.random_utils import philox_backend_seed_offset
+from flag_gems.utils.random_utils import update_philox_state
 
 from ..runtime import torch_device_fn
 
@@ -25,7 +25,7 @@ def rand_like(
     # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller,
     # hence we cannot obtain the per thread offset as in Pytorch.
     increment = triton.cdiv(N, UNROLL)
-    philox_seed, philox_offset = philox_backend_seed_offset(increment)
+    philox_seed, philox_offset = update_philox_state(increment)
     with torch_device_fn.device(x.device):
         rand_kernel[grid_fn](out, N, philox_seed, philox_offset)
     return out
diff --git a/src/flag_gems/ops/randn.py b/src/flag_gems/ops/randn.py
index 3ee2932e5..581247a73 100644
--- a/src/flag_gems/ops/randn.py
+++ b/src/flag_gems/ops/randn.py
@@ -5,7 +5,7 @@
 import triton.language as tl
 
 from flag_gems.utils.random_utils import (
-    philox_backend_seed_offset,
+    update_philox_state,
     uint_to_uniform_float,
 )
 from flag_gems.utils.shape_utils import volume
@@ -77,7 +77,7 @@ def randn(size, *, dtype=None, layout=None, device=None, pin_memory=None):
     # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller,
     # hence we cannot obtain the per thread offset as in Pytorch.
     increment = triton.cdiv(N, UNROLL)
-    philox_seed, philox_offset = philox_backend_seed_offset(increment)
+    philox_seed, philox_offset = update_philox_state(increment)
     with torch_device_fn.device(device):
         randn_kernel[grid_fn](out, N, philox_seed, philox_offset)
     return out
diff --git a/src/flag_gems/ops/randn_like.py b/src/flag_gems/ops/randn_like.py
index 0458328dc..b8c87b7df 100644
--- a/src/flag_gems/ops/randn_like.py
+++ b/src/flag_gems/ops/randn_like.py
@@ -4,7 +4,7 @@
 import triton
 
 from flag_gems.ops.randn import randn_kernel
-from flag_gems.utils.random_utils import philox_backend_seed_offset
+from flag_gems.utils.random_utils import update_philox_state
 
 from ..runtime import torch_device_fn
 
@@ -25,7 +25,7 @@ def randn_like(
     # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller,
     # hence we cannot obtain the per thread offset as in Pytorch.
     increment = triton.cdiv(N, UNROLL)
-    philox_seed, philox_offset = philox_backend_seed_offset(increment)
+    philox_seed, philox_offset = update_philox_state(increment)
     with torch_device_fn.device(x.device):
         randn_kernel[grid_fn](out, N, philox_seed, philox_offset)
     return out
diff --git a/src/flag_gems/ops/randperm.py b/src/flag_gems/ops/randperm.py
index 9757a1adc..1e7a07bba 100644
--- a/src/flag_gems/ops/randperm.py
+++ b/src/flag_gems/ops/randperm.py
@@ -4,7 +4,7 @@
 import triton
 import triton.language as tl
 
-from flag_gems.utils.random_utils import philox_backend_seed_offset
+from flag_gems.utils.random_utils import update_philox_state
 
 from .. import runtime
 from ..runtime import device, torch_device_fn
@@ -384,7 +384,7 @@ def sort_by_key(key, value, valid_bits):
         # last step, shuffle inner-block data
         BLOCK_SIZE_SHUFFLE = 512
         grid_shuffle = (triton.cdiv(n_elements, BLOCK_SIZE_SHUFFLE),)
-        philox_seed, philox_offset = philox_backend_seed_offset(n_elements)
+        philox_seed, philox_offset = update_philox_state(n_elements)
         with torch_device_fn.device(key.device):
             duplicate_keys_shuffle_kernel[grid_shuffle](
                 v_out,
diff --git a/src/flag_gems/ops/uniform.py b/src/flag_gems/ops/uniform.py
index 114a552bf..8b2d1f8f5 100644
--- a/src/flag_gems/ops/uniform.py
+++ b/src/flag_gems/ops/uniform.py
@@ -4,7 +4,7 @@
 import triton.language as tl
 
 from flag_gems.utils.random_utils import (
-    philox_backend_seed_offset,
+    update_philox_state,
     uint_to_uniform_float,
 )
 from flag_gems.utils.shape_utils import volume
@@ -55,7 +55,7 @@ def uniform_(self, from_=0.0, to=1.0, *, generator=None):
     grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK"] * UNROLL),)
 
     increment = triton.cdiv(N, UNROLL)
-    philox_seed, philox_offset = philox_backend_seed_offset(increment)
+    philox_seed, philox_offset = update_philox_state(increment)
     with torch_device_fn.device(self.device):
         uniform_kernel[grid_fn](self, N, philox_seed, philox_offset, from_, to)
     return self
diff --git a/src/flag_gems/runtime/backend/_metax/ops/exponential_.py b/src/flag_gems/runtime/backend/_metax/ops/exponential_.py
index 76be32499..899b3e49e 100644
--- a/src/flag_gems/runtime/backend/_metax/ops/exponential_.py
+++ b/src/flag_gems/runtime/backend/_metax/ops/exponential_.py
@@ -6,7 +6,7 @@
 
 from flag_gems.runtime import torch_device_fn
 from flag_gems.utils.random_utils import (
-    philox_backend_seed_offset,
+    update_philox_state,
     uint_to_uniform_float,
 )
 
@@ -238,7 +238,7 @@ def exponential_(x, lambd: float = 1.0, *, gen=None):
     # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller,
     # hence we cannot obtain the per thread offset as in Pytorch.
     increment = triton.cdiv(N, UNROLL)
-    philox_seed, philox_offset = philox_backend_seed_offset(increment)
+    philox_seed, philox_offset = update_philox_state(increment)
     eps = torch.finfo(dtype).eps
     x_ = x if inplace else torch.empty(x.size(), dtype=dtype, device=device)
     type_index = lst.index(dtype)

From af83c2f5e4e2550f40d2299955c0cff493648cb2 Mon Sep 17 00:00:00 2001
From: Tongxin Bai <waffle.bai@gmail.com>
Date: Thu, 6 Mar 2025 11:55:59 +0000
Subject: [PATCH 08/25] rowmax inf needs to be handled.

---
 src/flag_gems/ops/attention.py | 268 ++++++++++++++++++---------------
 1 file changed, 148 insertions(+), 120 deletions(-)

diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py
index 9c13485d6..78789162a 100644
--- a/src/flag_gems/ops/attention.py
+++ b/src/flag_gems/ops/attention.py
@@ -1,4 +1,5 @@
 import logging
+import math
 
 import torch
 import triton
@@ -542,22 +543,27 @@ def apply_mask(
     is_local: tl.constexpr,
     has_alibi: tl.constexpr,
 ):
-    if has_alibi or is_causal or is_local or not is_even_mn:
+    # need_mask = has_alibi or is_causal
+    # need_mask |= is_local
+    # need_mask |= not is_even_mn
+    need_mask: tl.constexpr = has_alibi | is_local | (not is_even_mn)
+    if need_mask:
         col_lb = max(0, row_idx + max_seqlen_k - max_seqlen_q - ws_left)
         col_rb = min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + ws_right)
 
         if not has_alibi:
             alibi_slope = .0
 
-        S -= alibi_slope * tl.abs(col_idx - row_idx)
+        S -= alibi_slope * tl.abs(col_idx[None, :] - row_idx[:, None])
 
         if is_causal:
-            S = tl.where(col_idx[None, :] >= col_rb[None, :], float('-inf'), S)
+            S = tl.where(col_idx[None, :] >= col_rb[:, None], float('-inf'), S)
 
         if is_local:
-            S = tl.where(col_idx[None, :] >= col_rb[None, :] | col_idx[None, :] < col_lb[None, :], float('-inf'), S)
+            S = tl.where(col_idx[None, :] >= col_rb[:, None] | col_idx[None, :] < col_lb[:, None], float('-inf'), S)
         
-        if (not local) and (not is_causal) and (not is_even_mn):
+        
+        if (not local) & (not is_causal) & (not is_even_mn):
             S = tl.where(col_idx[None, :] >= max_seqlen_k, float('-inf'), S)
 
     return S
@@ -575,11 +581,12 @@ def softmax_rescale(
     prev_row_max = row_max
     row_max = tl.maximum(row_max, tl.max(S, 1))
     row_sum_scale = tl.math.exp2(row_max - prev_row_max) * softmax_scale_log2
+    tl.device_print('row_sum_scale', row_max - prev_row_max)
     row_sum *= row_sum_scale
     O_acc *= row_sum_scale[:, None]
-    max_scaled = tl.where(rowmax == float('-inf'), 0, rowmax * softmax_scale_log2)
+    max_scaled = tl.where(row_max == float('-inf'), 0, row_max * softmax_scale_log2)
     P = tl.math.exp2(S * softmax_scale_log2 - max_scaled[:, None])
-    row_sum = row_sum + tl.sum(exp_S, 1)
+    row_sum = row_sum + tl.sum(S, 1)
     return O_acc, P, row_max, row_sum
 
 
@@ -597,7 +604,7 @@ def softmax_rescale(
         'IS_EVEN_MN': lambda args: (args["seqlen_q"] % args["BLOCK_M"] == 0) and (args["seqlen_k"] % args["BLOCK_N"] == 0)
     }
 )
-@triton.jit(do_not_specialize=[])
+@triton.jit(do_not_specialize=["seqlen_q", "seqlen_k", "q_b_stride", "q_s_stride", "q_h_stride", "k_b_stride", "k_s_stride", "k_h_stride", "o_b_stride", "o_h_stride", "o_s_stride", "philox_seed", "philox_offset", "pdrop_u8", "slopes_batch_stride"])
 def flash_fwd_kernel(
     Q_ptr,
     K_ptr,
@@ -618,12 +625,15 @@ def flash_fwd_kernel(
     o_b_stride,
     o_s_stride,
     o_h_stride,
+    h,
+    hk,
     pSlopes,
     philox_seed,
     philox_offset,
     pdrop_u8,
     rpdrop,
     slopes_batch_stride,
+    HEAD_DIM: tl.constexpr,
     is_dropout: tl.constexpr,
     is_causal: tl.constexpr,
     is_local: tl.constexpr,
@@ -637,7 +647,6 @@ def flash_fwd_kernel(
     BATCH_SIZE: tl.constexpr,
     NUM_HEADS: tl.constexpr,
     NUM_HEADS_K: tl.constexpr,
-    HEAD_DIM: tl.constexpr,
     IS_EVEN_MN: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
@@ -663,6 +672,8 @@ def flash_fwd_kernel(
         alibi_offset = bid * slopes_batch_stride + hid
         alibi_slope = tl.load(pSlopes + alibi_offset)
         alibi_slope /= scale
+    else:
+        alibi_slope = 0.0
 
     if (not is_causal) and (not is_local):
         if IS_EVEN_MN:
@@ -679,9 +690,12 @@ def flash_fwd_kernel(
     Q_ptr += bid * q_b_stride
     Q_ptr += hid * q_h_stride
     row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
-    Q_off = row_idx[:, None] * q_s_stride + tl.arange(0, HEADSIZE)[None, :]
-    qmask = row_idx < seqlen_q
-    Q = tl.load(pQ + Q_off, mask=qmask)
+    Q_off = row_idx[:, None] * q_s_stride + tl.arange(0, HEAD_DIM)[None, :]
+    qmask = row_idx[:, None] < seqlen_q
+    if IS_EVEN_MN:
+        Q = tl.load(Q_ptr + Q_off)
+    else:
+        Q = tl.load(Q_ptr + Q_off, mask=qmask)
 
     # Start from the right most block
     n_block = n_block_max - 1
@@ -697,101 +711,106 @@ def flash_fwd_kernel(
     P_ptr += n_block * BLOCK_N
     P_offset = tl.arange(0, BLOCK_M)[:, None] * seqlen_k_rounded + tl.arange(0, BLOCK_N)
 
-    O_ = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+    O_ = tl.zeros((BLOCK_M, HEAD_DIM), dtype=tl.float32)
     rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
     rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32)
     
-    for _ in range(n_masking_steps):
-        KV_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None]
-        if IS_EVEN_MN:
-            K = tl.load(pK + KV_offset, cache_modifier="cg")
-            if PRE_LOAD_V:
-                V = tl.load(pV + KV_offset, cache_modifier="cg")
-        else:
-            kvmask = col_idx < seqlen_k
-            K = tl.load(pK + KV_offset, mask=kvmask, cache_modifier="cg")
-            if PRE_LOAD_V:
-                V = tl.load(pV + KV_offset, mask=kvmask, cache_modifier="cg")
-        S = tl.dot(Q_block, K_block)
-        S = apply_mask(
-            S,
-            col_idx,
-            row_idx,
-            seqlen_q,
-            seqlen_k,
-            ws_left,
-            ws_right,
-            alibi_slope,
-            is_even_mn=IS_EVEN_MN,
-            is_causal=is_causal,
-            is_local=is_local,
-            has_alibi=has_alibi
-        )
-        col_idx -= BLOCK_N
-
-        O_acc, P, rowmax_, rowsum_ = softmax_rescale(
-            O_,
-            S,
-            rowmax_,
-            rowsum_,
-            softmax_scale_log2=softmax_scale_log2,
-            is_border=(is_causal or is_local)
-        )
-        P = P.to(pO.type.element_ty)
-
-        row_start = m_block * (BLOCK_M // 16)
-        col_start = n_block * (BLOCK_N // 32)
-        if return_P:
-            P_drop = P
-            P_drop = apply_dropout(
-                P_drop,
-                row_start,
-                col_start,
-                bid,
-                hid,
-                philox_seed,
-                philox_offset,
-                pdrop_u8,
-                encode_dropout_in_sign_bit=True,
-                NUM_HEADS=NUM_HEADS,
-                BLOCK_M=BLOCK_M,
-                BLOCK_N=BLOCK_N,
-            )
-            tl.store(P_ptr + P_offset, P_drop, mask=qmask[:, None] & kvmask[None, :])
-            P_offset += BLOCK_N
-
-        if is_dropout:
-            P = apply_dropout(
-                P,
-                row_start,
-                col_start,
-                bid,
-                hid,
-                philox_seed,
-                philox_offset,
-                pdrop_u8,
-                encode_dropout_in_sign_bit=False,
-                NUM_HEADS=NUM_HEADS,
-                BLOCK_M=BLOCK_M,
-                BLOCK_N=BLOCK_N,
-            )
-
-        if not PRE_LOAD_V:
-            if IS_EVEN_MN:
-                V = tl.load(pV + KV_offset, cache_modifier="cg")
-            else:
-                V = tl.load(pV + KV_offset, mask=kvmask, cache_modifier="cg")
-        tl.dot(P, V, acc=O_)
-
-        n_block -= 1
-        if n_masking_steps > 1 and n_block <= n_block_min:
-            break
-
-    for n_block in tl.range(n_block, n_block_min - 1, num_stages=num_stages):
-        KV_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None]
-        K = tl.load(pK + KV_offset, cache_modifier="cg")        
-        V = tl.load(pV + KV_offset, cache_modifier="cg")
-        S = tl.dot(Q_block, K_block)
+    # for _ in range(n_masking_steps):
+    #     K_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None]
+    #     V_offset = col_idx[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :]
+    #     if IS_EVEN_MN:
+    #         K = tl.load(K_ptr + K_offset, cache_modifier=".cg")
+    #         if PRE_LOAD_V:
+    #             V = tl.load(V_ptr + V_offset, cache_modifier=".cg")
+    #     else:
+    #         kvmask = col_idx[None, :] < seqlen_k
+    #         K = tl.load(K_ptr + K_offset, mask=kvmask, cache_modifier=".cg")
+    #         if PRE_LOAD_V:
+    #             V = tl.load(V_ptr + V_offset, mask=kvmask, cache_modifier=".cg")
+    #     S = tl.dot(Q, K, allow_tf32=False)
+    #     S = apply_mask(
+    #         S,
+    #         col_idx,
+    #         row_idx,
+    #         seqlen_q,
+    #         seqlen_k,
+    #         ws_left,
+    #         ws_right,
+    #         alibi_slope,
+    #         is_even_mn=IS_EVEN_MN,
+    #         is_causal=is_causal,
+    #         is_local=is_local,
+    #         has_alibi=has_alibi
+    #     )
+    #     col_idx -= BLOCK_N
+
+    #     O_acc, P, rowmax_, rowsum_ = softmax_rescale(
+    #         O_,
+    #         S,
+    #         rowmax_,
+    #         rowsum_,
+    #         softmax_scale_log2=softmax_scale_log2,
+    #         is_border=(is_causal or is_local)
+    #     )
+    #     P = P.to(O_ptr.type.element_ty)
+
+    #     row_start = m_block * (BLOCK_M // 16)
+    #     col_start = n_block * (BLOCK_N // 32)
+    #     if return_P:
+    #         P_drop = P
+    #         P_drop = apply_dropout(
+    #             P_drop,
+    #             row_start,
+    #             col_start,
+    #             bid,
+    #             hid,
+    #             philox_seed,
+    #             philox_offset,
+    #             pdrop_u8,
+    #             encode_dropout_in_sign_bit=True,
+    #             NUM_HEADS=NUM_HEADS,
+    #             BLOCK_M=BLOCK_M,
+    #             BLOCK_N=BLOCK_N,
+    #         )
+    #         tl.store(P_ptr + P_offset, P_drop, mask=qmask & kvmask)
+    #         P_offset += BLOCK_N
+
+    #     if is_dropout:
+    #         P = apply_dropout(
+    #             P,
+    #             row_start,
+    #             col_start,
+    #             bid,
+    #             hid,
+    #             philox_seed,
+    #             philox_offset,
+    #             pdrop_u8,
+    #             encode_dropout_in_sign_bit=False,
+    #             NUM_HEADS=NUM_HEADS,
+    #             BLOCK_M=BLOCK_M,
+    #             BLOCK_N=BLOCK_N,
+    #         )
+
+    #     if not PRE_LOAD_V:
+    #         if IS_EVEN_MN:
+    #             V = tl.load(V_ptr + V_offset, cache_modifier=".cg")
+    #         else:
+    #             V = tl.load(V_ptr + V_offset, mask=kvmask, cache_modifier=".cg")
+    #     tl.dot(P, V, acc=O_, allow_tf32=False)
+
+    #     n_block -= 1
+    #     # if n_masking_steps > 1 and n_block <= n_block_min:
+    #     #     break
+
+
+    # for n_block in tl.range(n_block, n_block_min - 1, num_stages=num_stages):
+    if True:
+        K_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None]
+        K = tl.load(K_ptr + K_offset, cache_modifier=".cg")
+        if PRE_LOAD_V:
+            V_offset = col_idx[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :]
+            V = tl.load(V_ptr + V_offset, cache_modifier=".cg")
+        S = tl.dot(Q, K)
         S = apply_mask(
             S,
             col_idx,
@@ -808,7 +827,7 @@ def flash_fwd_kernel(
         )
         col_idx -= BLOCK_N
 
-        O_acc, P, rowmax_, rowsum_ = softmax_rescale(
+        O_, P, rowmax_, rowsum_ = softmax_rescale(
             O_,
             S,
             rowmax_,
@@ -816,7 +835,8 @@ def flash_fwd_kernel(
             softmax_scale_log2=softmax_scale_log2,
             is_border=is_local
         )
-        P = P.to(pO.type.element_ty)
+
+        P = P.to(O_ptr.type.element_ty)
 
         row_start = m_block * (BLOCK_M // 16)
         col_start = n_block * (BLOCK_N // 32)
@@ -836,7 +856,7 @@ def flash_fwd_kernel(
                 BLOCK_M=BLOCK_M,
                 BLOCK_N=BLOCK_N,
             )
-            tl.store(P_ptr + P_offset, P_drop, mask=qmask[:, None] & kvmask[None, :])
+            tl.store(P_ptr + P_offset, P_drop, mask=qmask & kvmask)
             P_offset += BLOCK_N
 
         if is_dropout:
@@ -855,24 +875,29 @@ def flash_fwd_kernel(
                 BLOCK_N=BLOCK_N,
             )
 
+        if not PRE_LOAD_V:
+            V_offset = col_idx[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :]
+            V = tl.load(V_ptr + V_offset, cache_modifier=".cg")
+
         tl.dot(P, V, acc=O_)
 
     # Final LSE
-    lse = tl.where(rowsum_ == 0 | rowsum_ != rowsum_, float('inf'), rowmax_ * softmax_scale +tl.log(rowsum_))
-    inv_sum = tl.where(rowsum_ == 0 | rowsum_ != rowsum_, 1.0, 1.0 / rowsum_)
+    lse = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), float('inf'), rowmax_ * softmax_scale + tl.log(rowsum_))
+    inv_sum = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), 1.0, 1.0 / rowsum_)
     
-    # Rescale output
-    if is_dropout:
-        O_ *= inv_sum * rpdrop
-    else:
-        O_ *= inv_sum
+    # # Rescale output
+    # if is_dropout:
+    #     O_ *= inv_sum[:, None] * rpdrop
+    # else:
+    #     O_ *= inv_sum[:, None]
     
-    O = O_.to(pO.type.element_ty)
+    O = O_.to(O_ptr.type.element_ty)
 
     # Write back output
     O_ptr += bid * o_b_stride
     O_ptr += hid * o_h_stride
     O_offset = row_idx[:, None] * o_s_stride + tl.arange(0, HEAD_DIM)
+
     if IS_EVEN_MN:
         tl.store(O_ptr + O_offset, O)
     else:
@@ -943,7 +968,7 @@ def mha_fwd(
         assert out.dtype == q.dtype
         assert out.size() == (batch_size, seqlen_q, num_heads, head_size)
     else:
-        out = torch.empty_like(q)
+        out = torch.empty_like(q, dtype=v.dtype)
 
     round_multiple = lambda x, m: (x + m - 1) // m * m
     head_size_rounded = round_multiple(head_size, 32)
@@ -952,15 +977,16 @@ def mha_fwd(
 
     with torch_device_fn.device(q_device):
         # Set softmax params
-        lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float)
+        lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float, device=q_device)
         if return_softmax:
             assert p_dropout > 0, "return_softmax is only supported when p_dropout > 0.0"
             p = torch.empty(
                 (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded),
                 dtype=q_dtype,
+                device=q_device
             )
         else:
-            p = torch.empty(())
+            p = torch.empty((), device=q_device)
 
         # Set dropout params
         if p_dropout > 0:
@@ -968,6 +994,7 @@ def mha_fwd(
             philox_seed, philox_offset = update_philox_state(increment)
             is_dropout = True
         else:
+            philox_seed, philox_offset = None, None
             is_dropout = False
 
         p_dropout = 1 - p_dropout
@@ -1023,12 +1050,15 @@ def mha_fwd(
             out.stride(0),
             out.stride(-3),
             out.stride(-2),
+            num_heads,
+            num_heads_k,
             alibi_slopes,
             philox_seed,
             philox_offset,
             pdrop_u8,
             rpdrop,
             alibi_slopes_batch_stride,
+            head_size,
             is_dropout=is_dropout,
             is_causal=is_causal,
             is_local=is_local,
@@ -1041,8 +1071,6 @@ def mha_fwd(
             BATCH_SIZE=batch_size,
             NUM_HEADS=num_heads,
             NUM_HEADS_K=num_heads_k,
-            HEAD_DIM=head_size_rounded,
-            IS_EVEN_MN=is_even_mn,
         )
     
     if swap_seq_and_group:

From 1962a965372e161fb41c337cd6290a56521e3df2 Mon Sep 17 00:00:00 2001
From: Tongxin Bai <waffle.bai@gmail.com>
Date: Fri, 7 Mar 2025 12:10:52 +0000
Subject: [PATCH 09/25] passed noncausal, nonlocal and no bias.

---
 src/flag_gems/ops/attention.py | 255 +++++++++++++++++----------------
 1 file changed, 134 insertions(+), 121 deletions(-)

diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py
index 78789162a..26efbf69f 100644
--- a/src/flag_gems/ops/attention.py
+++ b/src/flag_gems/ops/attention.py
@@ -562,8 +562,7 @@ def apply_mask(
         if is_local:
             S = tl.where(col_idx[None, :] >= col_rb[:, None] | col_idx[None, :] < col_lb[:, None], float('-inf'), S)
         
-        
-        if (not local) & (not is_causal) & (not is_even_mn):
+        if (not is_local) & (not is_causal) & (not is_even_mn):
             S = tl.where(col_idx[None, :] >= max_seqlen_k, float('-inf'), S)
 
     return S
@@ -576,32 +575,44 @@ def softmax_rescale(
     row_max,
     row_sum,
     softmax_scale_log2: tl.constexpr,
-    is_border: tl.constexpr
+    is_border: tl.constexpr,
+    is_init: tl.constexpr
 ):
-    prev_row_max = row_max
+    prev_max = row_max
     row_max = tl.maximum(row_max, tl.max(S, 1))
-    row_sum_scale = tl.math.exp2(row_max - prev_row_max) * softmax_scale_log2
-    tl.device_print('row_sum_scale', row_max - prev_row_max)
-    row_sum *= row_sum_scale
-    O_acc *= row_sum_scale[:, None]
+
+    if not is_init:
+        if is_border:
+            cur_max = tl.where(row_max == float('-inf'), 0, row_max)
+        else:
+            cur_max = row_max
+        p_scale = tl.math.exp2((prev_max - cur_max) * softmax_scale_log2)
+        row_sum *= p_scale
+        O_acc *= p_scale[:, None]
+
     max_scaled = tl.where(row_max == float('-inf'), 0, row_max * softmax_scale_log2)
     P = tl.math.exp2(S * softmax_scale_log2 - max_scaled[:, None])
-    row_sum = row_sum + tl.sum(S, 1)
+    row_sum = row_sum + tl.sum(P, 1)
     return O_acc, P, row_max, row_sum
 
 
-@triton.autotune(
-    configs=runtime.get_tuned_config("attention"),
-    key=["HEAD_DIM"],
-    prune_configs_by={
-        "early_config_prune": early_config_prune,
-        "perf_model": None,
-        "top_k": 1.0,
-    },
-)
+# @triton.autotune(
+#     configs=runtime.get_tuned_config("attention"),
+#     key=["HEAD_DIM"],
+#     prune_configs_by={
+#         "early_config_prune": early_config_prune,
+#         "perf_model": None,
+#         "top_k": 1.0,
+#     },
+# )
 @triton.heuristics(
     values={
-        'IS_EVEN_MN': lambda args: (args["seqlen_q"] % args["BLOCK_M"] == 0) and (args["seqlen_k"] % args["BLOCK_N"] == 0)
+        'BLOCK_M': lambda args: 64,
+        'BLOCK_N': lambda args: 64,
+        'num_warps': lambda args: 4,
+        'num_stages': lambda args: 2,
+        'PRE_LOAD_V': lambda args: False,
+        'IS_EVEN_MN': lambda args: (args["seqlen_q"] % args["BLOCK_M"] == 0) and (args["seqlen_k"] % args["BLOCK_N"] == 0),
     }
 )
 @triton.jit(do_not_specialize=["seqlen_q", "seqlen_k", "q_b_stride", "q_s_stride", "q_h_stride", "k_b_stride", "k_s_stride", "k_h_stride", "o_b_stride", "o_h_stride", "o_s_stride", "philox_seed", "philox_offset", "pdrop_u8", "slopes_batch_stride"])
@@ -686,7 +697,6 @@ def flash_fwd_kernel(
         # local and not causal, 
         n_masking_steps = tl.cdiv(BLOCK_M, BLOCK_N) + 1
 
-
     Q_ptr += bid * q_b_stride
     Q_ptr += hid * q_h_stride
     row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
@@ -705,8 +715,7 @@ def flash_fwd_kernel(
     K_ptr += (hid // h_hk_ratio) * k_h_stride
     V_ptr += bid * k_b_stride
     V_ptr += (hid // h_hk_ratio) * k_h_stride
-    col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
-
+    
     P_ptr += ((bid * NUM_HEADS + hid) * seqlen_q_rounded + m_block * BLOCK_M) * seqlen_k_rounded
     P_ptr += n_block * BLOCK_N
     P_offset = tl.arange(0, BLOCK_M)[:, None] * seqlen_k_rounded + tl.arange(0, BLOCK_N)
@@ -715,96 +724,98 @@ def flash_fwd_kernel(
     rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
     rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32)
     
-    # for _ in range(n_masking_steps):
-    #     K_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None]
-    #     V_offset = col_idx[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :]
-    #     if IS_EVEN_MN:
-    #         K = tl.load(K_ptr + K_offset, cache_modifier=".cg")
-    #         if PRE_LOAD_V:
-    #             V = tl.load(V_ptr + V_offset, cache_modifier=".cg")
-    #     else:
-    #         kvmask = col_idx[None, :] < seqlen_k
-    #         K = tl.load(K_ptr + K_offset, mask=kvmask, cache_modifier=".cg")
-    #         if PRE_LOAD_V:
-    #             V = tl.load(V_ptr + V_offset, mask=kvmask, cache_modifier=".cg")
-    #     S = tl.dot(Q, K, allow_tf32=False)
-    #     S = apply_mask(
-    #         S,
-    #         col_idx,
-    #         row_idx,
-    #         seqlen_q,
-    #         seqlen_k,
-    #         ws_left,
-    #         ws_right,
-    #         alibi_slope,
-    #         is_even_mn=IS_EVEN_MN,
-    #         is_causal=is_causal,
-    #         is_local=is_local,
-    #         has_alibi=has_alibi
-    #     )
-    #     col_idx -= BLOCK_N
-
-    #     O_acc, P, rowmax_, rowsum_ = softmax_rescale(
-    #         O_,
-    #         S,
-    #         rowmax_,
-    #         rowsum_,
-    #         softmax_scale_log2=softmax_scale_log2,
-    #         is_border=(is_causal or is_local)
-    #     )
-    #     P = P.to(O_ptr.type.element_ty)
-
-    #     row_start = m_block * (BLOCK_M // 16)
-    #     col_start = n_block * (BLOCK_N // 32)
-    #     if return_P:
-    #         P_drop = P
-    #         P_drop = apply_dropout(
-    #             P_drop,
-    #             row_start,
-    #             col_start,
-    #             bid,
-    #             hid,
-    #             philox_seed,
-    #             philox_offset,
-    #             pdrop_u8,
-    #             encode_dropout_in_sign_bit=True,
-    #             NUM_HEADS=NUM_HEADS,
-    #             BLOCK_M=BLOCK_M,
-    #             BLOCK_N=BLOCK_N,
-    #         )
-    #         tl.store(P_ptr + P_offset, P_drop, mask=qmask & kvmask)
-    #         P_offset += BLOCK_N
-
-    #     if is_dropout:
-    #         P = apply_dropout(
-    #             P,
-    #             row_start,
-    #             col_start,
-    #             bid,
-    #             hid,
-    #             philox_seed,
-    #             philox_offset,
-    #             pdrop_u8,
-    #             encode_dropout_in_sign_bit=False,
-    #             NUM_HEADS=NUM_HEADS,
-    #             BLOCK_M=BLOCK_M,
-    #             BLOCK_N=BLOCK_N,
-    #         )
-
-    #     if not PRE_LOAD_V:
-    #         if IS_EVEN_MN:
-    #             V = tl.load(V_ptr + V_offset, cache_modifier=".cg")
-    #         else:
-    #             V = tl.load(V_ptr + V_offset, mask=kvmask, cache_modifier=".cg")
-    #     tl.dot(P, V, acc=O_, allow_tf32=False)
-
-    #     n_block -= 1
-    #     # if n_masking_steps > 1 and n_block <= n_block_min:
-    #     #     break
-
-
-    # for n_block in tl.range(n_block, n_block_min - 1, num_stages=num_stages):
-    if True:
+    for n_block in tl.range(n_block_max - 1, n_block_max - n_masking_steps - 1, step=-1):
+        col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
+        K_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None]
+        V_offset = col_idx[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :]
+        if IS_EVEN_MN:
+            K = tl.load(K_ptr + K_offset, cache_modifier=".cg")
+            if PRE_LOAD_V:
+                V = tl.load(V_ptr + V_offset, cache_modifier=".cg")
+        else:
+            kvmask = col_idx < seqlen_k
+            K = tl.load(K_ptr + K_offset, mask=kvmask[None, :], cache_modifier=".cg")
+            if PRE_LOAD_V:
+                V = tl.load(V_ptr + V_offset, mask=kvmask[:, None], cache_modifier=".cg")
+        S = tl.dot(Q, K, allow_tf32=False)
+        S = apply_mask(
+            S,
+            col_idx,
+            row_idx,
+            seqlen_q,
+            seqlen_k,
+            ws_left,
+            ws_right,
+            alibi_slope,
+            is_even_mn=IS_EVEN_MN,
+            is_causal=is_causal,
+            is_local=is_local,
+            has_alibi=has_alibi
+        )
+        # col_idx -= BLOCK_N
+
+        is_init = (n_block == n_block_max - 1).to(tl.int1)
+        O_, P, rowmax_, rowsum_ = softmax_rescale(
+            O_,
+            S,
+            rowmax_,
+            rowsum_,
+            softmax_scale_log2=softmax_scale_log2,
+            is_border=(is_causal or is_local),
+            is_init=is_init
+        )
+        P = P.to(O_ptr.type.element_ty)
+
+        row_start = m_block * (BLOCK_M // 16)
+        col_start = n_block * (BLOCK_N // 32)
+        if return_P:
+            P_drop = P
+            P_drop = apply_dropout(
+                P_drop,
+                row_start,
+                col_start,
+                bid,
+                hid,
+                philox_seed,
+                philox_offset,
+                pdrop_u8,
+                encode_dropout_in_sign_bit=True,
+                NUM_HEADS=NUM_HEADS,
+                BLOCK_M=BLOCK_M,
+                BLOCK_N=BLOCK_N,
+            )
+            tl.store(P_ptr + P_offset, P_drop, mask=qmask & kvmask[None, :])
+            P_offset += BLOCK_N
+
+        if is_dropout:
+            P = apply_dropout(
+                P,
+                row_start,
+                col_start,
+                bid,
+                hid,
+                philox_seed,
+                philox_offset,
+                pdrop_u8,
+                encode_dropout_in_sign_bit=False,
+                NUM_HEADS=NUM_HEADS,
+                BLOCK_M=BLOCK_M,
+                BLOCK_N=BLOCK_N,
+            )
+
+        if not PRE_LOAD_V:
+            if IS_EVEN_MN:
+                V = tl.load(V_ptr + V_offset, cache_modifier=".cg")
+            else:
+                V = tl.load(V_ptr + V_offset, mask=kvmask[:, None], cache_modifier=".cg")
+        O_ = tl.dot(P, V, O_, allow_tf32=False)
+
+        # if n_masking_steps > 1 and n_block <= n_block_min:
+        #     break
+
+
+    for n_block in tl.range(n_block_max - n_masking_steps - 1, n_block_min - 1, step=-1, num_stages=num_stages):
+        col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
         K_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None]
         K = tl.load(K_ptr + K_offset, cache_modifier=".cg")
         if PRE_LOAD_V:
@@ -825,15 +836,17 @@ def flash_fwd_kernel(
             is_local=is_local,
             has_alibi=has_alibi
         )
-        col_idx -= BLOCK_N
+        # col_idx -= BLOCK_N
 
+        is_init = (n_block == n_block_max - 1).to(tl.int1)
         O_, P, rowmax_, rowsum_ = softmax_rescale(
             O_,
             S,
             rowmax_,
             rowsum_,
             softmax_scale_log2=softmax_scale_log2,
-            is_border=is_local
+            is_border=is_local,
+            is_init=is_init
         )
 
         P = P.to(O_ptr.type.element_ty)
@@ -879,17 +892,17 @@ def flash_fwd_kernel(
             V_offset = col_idx[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :]
             V = tl.load(V_ptr + V_offset, cache_modifier=".cg")
 
-        tl.dot(P, V, acc=O_)
+        O_ = tl.dot(P, V, O_)
 
     # Final LSE
     lse = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), float('inf'), rowmax_ * softmax_scale + tl.log(rowsum_))
     inv_sum = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), 1.0, 1.0 / rowsum_)
     
-    # # Rescale output
-    # if is_dropout:
-    #     O_ *= inv_sum[:, None] * rpdrop
-    # else:
-    #     O_ *= inv_sum[:, None]
+    # Rescale output
+    if is_dropout:
+        O_ *= inv_sum[:, None] * rpdrop
+    else:
+        O_ *= inv_sum[:, None]
     
     O = O_.to(O_ptr.type.element_ty)
 
@@ -908,7 +921,7 @@ def flash_fwd_kernel(
     if IS_EVEN_MN:
         tl.store(lse_ptr + row_idx, lse)
     else:
-        tl.store(lse_ptr + row_idx, lse, mask=qmask)
+        tl.store(lse_ptr + row_idx, lse, mask=row_idx < seqlen_q)
 
 
 def mha_fwd(

From afc6108b599ea8d39a830638e2c1328a6632ba97 Mon Sep 17 00:00:00 2001
From: Tongxin Bai <waffle.bai@gmail.com>
Date: Tue, 11 Mar 2025 11:32:33 +0000
Subject: [PATCH 10/25] added splitkv, perf still lags.

---
 src/flag_gems/__init__.py      |   5 +
 src/flag_gems/ops/__init__.py  |   3 +-
 src/flag_gems/ops/attention.py | 468 ++++++++++++++++++++++++++++++---
 3 files changed, 444 insertions(+), 32 deletions(-)

diff --git a/src/flag_gems/__init__.py b/src/flag_gems/__init__.py
index 8cf202e3e..85cc04941 100644
--- a/src/flag_gems/__init__.py
+++ b/src/flag_gems/__init__.py
@@ -149,6 +149,11 @@ def enable(lib=aten_lib, unused=None, registrar=registrar):
             # ("prod.dim_int", prod_dim, Autograd.disable),
             # ("sum", sum, Autograd.disable),
             # ("sum.dim_IntList", sum_dim, Autograd.disable),
+            # (
+            #     "scaled_dot_product_attention",
+            #     scaled_dot_product_attention,
+            #     Autograd.disable,
+            # ),
             (
                 "_flash_attention_forward",
                 flash_attention_forward,
diff --git a/src/flag_gems/ops/__init__.py b/src/flag_gems/ops/__init__.py
index c4051d6d8..2802e11f6 100755
--- a/src/flag_gems/ops/__init__.py
+++ b/src/flag_gems/ops/__init__.py
@@ -7,7 +7,7 @@
 from .arange import arange, arange_start
 from .argmax import argmax
 from .argmin import argmin
-from .attention import flash_attention_forward
+from .attention import flash_attention_forward, scaled_dot_product_attention
 from .batch_norm import batch_norm
 from .bitwise_and import (
     bitwise_and_scalar,
@@ -293,6 +293,7 @@
     "vstack",
     "repeat_interleave_tensor",
     "flash_attention_forward",
+    "scaled_dot_product_attention",
     "conv2d",
     "conv1d",
     "_conv_depthwise2d",
diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py
index 26efbf69f..6a917b104 100644
--- a/src/flag_gems/ops/attention.py
+++ b/src/flag_gems/ops/attention.py
@@ -543,10 +543,7 @@ def apply_mask(
     is_local: tl.constexpr,
     has_alibi: tl.constexpr,
 ):
-    # need_mask = has_alibi or is_causal
-    # need_mask |= is_local
-    # need_mask |= not is_even_mn
-    need_mask: tl.constexpr = has_alibi | is_local | (not is_even_mn)
+    need_mask: tl.constexpr = is_causal | has_alibi | is_local | (not is_even_mn)
     if need_mask:
         col_lb = max(0, row_idx + max_seqlen_k - max_seqlen_q - ws_left)
         col_rb = min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + ws_right)
@@ -574,7 +571,7 @@ def softmax_rescale(
     S,
     row_max,
     row_sum,
-    softmax_scale_log2: tl.constexpr,
+    softmax_scale_log2e: tl.constexpr,
     is_border: tl.constexpr,
     is_init: tl.constexpr
 ):
@@ -586,16 +583,33 @@ def softmax_rescale(
             cur_max = tl.where(row_max == float('-inf'), 0, row_max)
         else:
             cur_max = row_max
-        p_scale = tl.math.exp2((prev_max - cur_max) * softmax_scale_log2)
+        p_scale = tl.math.exp2((prev_max - cur_max) * softmax_scale_log2e)
         row_sum *= p_scale
         O_acc *= p_scale[:, None]
 
-    max_scaled = tl.where(row_max == float('-inf'), 0, row_max * softmax_scale_log2)
-    P = tl.math.exp2(S * softmax_scale_log2 - max_scaled[:, None])
+    max_scaled = tl.where(row_max == float('-inf'), 0, row_max * softmax_scale_log2e)
+    P = tl.math.exp2(S * softmax_scale_log2e - max_scaled[:, None])
     row_sum = row_sum + tl.sum(P, 1)
     return O_acc, P, row_max, row_sum
 
 
+def block_m_heuristic(headdim, is_dropout):
+    return 64 if headdim <= 128 else 32
+
+def block_n_heuristic(headdim, is_dropout):
+    return 64 if headdim <= 128 else 32
+
+def block_m_splitkv_heuristic(headdim):
+    return 64 if headdim <= 128 else 32
+
+def block_n_splitkv_heuristic(headdim):
+    if headdim <= 64:
+        return 128
+    elif headdim <= 128:
+        return 64
+    else:
+        return 32
+
 # @triton.autotune(
 #     configs=runtime.get_tuned_config("attention"),
 #     key=["HEAD_DIM"],
@@ -607,11 +621,11 @@ def softmax_rescale(
 # )
 @triton.heuristics(
     values={
-        'BLOCK_M': lambda args: 64,
-        'BLOCK_N': lambda args: 64,
+        'BLOCK_M': lambda args: block_m_heuristic(args["HEAD_DIM"], args["is_dropout"]),
+        'BLOCK_N': lambda args: block_n_heuristic(args["HEAD_DIM"], args["is_dropout"]),
         'num_warps': lambda args: 4,
-        'num_stages': lambda args: 2,
-        'PRE_LOAD_V': lambda args: False,
+        'num_stages': lambda args: 3,
+        'PRE_LOAD_V': lambda args: True,
         'IS_EVEN_MN': lambda args: (args["seqlen_q"] % args["BLOCK_M"] == 0) and (args["seqlen_k"] % args["BLOCK_N"] == 0),
     }
 )
@@ -650,7 +664,7 @@ def flash_fwd_kernel(
     is_local: tl.constexpr,
     has_alibi: tl.constexpr,
     softmax_scale: tl.constexpr,
-    softmax_scale_log2: tl.constexpr,
+    softmax_scale_log2e: tl.constexpr,
     ws_left: tl.constexpr,
     ws_right: tl.constexpr,
     return_P: tl.constexpr,
@@ -661,6 +675,7 @@ def flash_fwd_kernel(
     IS_EVEN_MN: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
+    num_splits: tl.constexpr,
     num_warps: tl.constexpr,
     num_stages: tl.constexpr
 ):
@@ -677,7 +692,7 @@ def flash_fwd_kernel(
     
     if is_causal or is_local:
         n_block_max = min(n_block_max,
-                          tl.cdiv((m_block + 1) * BLOCK_M + seqlen_k - seqlen_q + window_size_right, BLOCK_N))
+                          tl.cdiv((m_block + 1) * BLOCK_M + seqlen_k - seqlen_q + ws_right, BLOCK_N))
 
     if has_alibi:
         alibi_offset = bid * slopes_batch_stride + hid
@@ -691,7 +706,7 @@ def flash_fwd_kernel(
             n_masking_steps = 0
         else:
             n_masking_steps = 1
-    elif is_causal and IS_EVEN_MN: # causal implies window_size_right is zero
+    elif is_causal and IS_EVEN_MN: # causal implies ws_right is zero
         n_masking_steps = tl.cdiv(BLOCK_M, BLOCK_N)
     else:
         # local and not causal, 
@@ -760,7 +775,7 @@ def flash_fwd_kernel(
             S,
             rowmax_,
             rowsum_,
-            softmax_scale_log2=softmax_scale_log2,
+            softmax_scale_log2e=softmax_scale_log2e,
             is_border=(is_causal or is_local),
             is_init=is_init
         )
@@ -818,10 +833,15 @@ def flash_fwd_kernel(
         col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
         K_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None]
         K = tl.load(K_ptr + K_offset, cache_modifier=".cg")
+        # if PRE_LOAD_V:
+        #     V_offset = col_idx[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :]
+        #     V = tl.load(V_ptr + V_offset, cache_modifier=".cg")
+        S = tl.dot(Q, K)
+
         if PRE_LOAD_V:
             V_offset = col_idx[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :]
             V = tl.load(V_ptr + V_offset, cache_modifier=".cg")
-        S = tl.dot(Q, K)
+
         S = apply_mask(
             S,
             col_idx,
@@ -844,7 +864,7 @@ def flash_fwd_kernel(
             S,
             rowmax_,
             rowsum_,
-            softmax_scale_log2=softmax_scale_log2,
+            softmax_scale_log2e=softmax_scale_log2e,
             is_border=is_local,
             is_init=is_init
         )
@@ -894,7 +914,8 @@ def flash_fwd_kernel(
 
         O_ = tl.dot(P, V, O_)
 
-    # Final LSE
+    # LSE
+    # Note, rowsum = exp(-rowmax) * lse, therefore rowmax + log(rowsum) cancels the effect of rowmax and outputs lse only.
     lse = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), float('inf'), rowmax_ * softmax_scale + tl.log(rowsum_))
     inv_sum = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), 1.0, 1.0 / rowsum_)
     
@@ -924,6 +945,312 @@ def flash_fwd_kernel(
         tl.store(lse_ptr + row_idx, lse, mask=row_idx < seqlen_q)
 
 
+@triton.heuristics(
+    values={
+        'BLOCK_M': lambda args: block_m_splitkv_heuristic(args["HEAD_DIM"]),
+        'BLOCK_N': lambda args: block_m_splitkv_heuristic(args["HEAD_DIM"]),
+        'num_warps': lambda args: 4,
+        'num_stages': lambda args: 3,
+        'PRE_LOAD_V': lambda args: True,
+        'IS_EVEN_MN': lambda args: (args["seqlen_q"] % args["BLOCK_M"] == 0) and (args["seqlen_k"] % args["BLOCK_N"] == 0),
+    }
+)
+@triton.jit(do_not_specialize=["seqlen_q", "seqlen_k", "q_b_stride", "q_s_stride", "q_h_stride", "k_b_stride", "k_s_stride", "k_h_stride", "o_b_stride", "o_h_stride", "o_s_stride", "philox_seed", "philox_offset", "pdrop_u8", "slopes_batch_stride"])
+def flash_fwd_splitkv_kernel(
+    Q_ptr,
+    K_ptr,
+    V_ptr,
+    P_ptr,
+    O_ptr,
+    lse_ptr,
+    seqlen_q,
+    seqlen_k,
+    seqlen_q_rounded,
+    seqlen_k_rounded,
+    q_b_stride,
+    q_s_stride,
+    q_h_stride,
+    k_b_stride,
+    k_s_stride,
+    k_h_stride,
+    o_b_stride,
+    o_s_stride,
+    o_h_stride,
+    h,
+    hk,
+    pSlopes,
+    philox_seed,
+    philox_offset,
+    pdrop_u8,
+    rpdrop,
+    slopes_batch_stride,
+    HEAD_DIM: tl.constexpr,
+    is_dropout: tl.constexpr,
+    is_causal: tl.constexpr,
+    is_local: tl.constexpr,
+    has_alibi: tl.constexpr,
+    softmax_scale: tl.constexpr,
+    softmax_scale_log2e: tl.constexpr,
+    ws_left: tl.constexpr,
+    ws_right: tl.constexpr,
+    return_P: tl.constexpr,
+    PRE_LOAD_V: tl.constexpr,
+    BATCH_SIZE: tl.constexpr,
+    NUM_HEADS: tl.constexpr,
+    NUM_HEADS_K: tl.constexpr,
+    IS_EVEN_MN: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    num_splits: tl.constexpr,
+    num_warps: tl.constexpr,
+    num_stages: tl.constexpr
+):
+    m_block = tl.program_id(0)
+    split_id = tl.program_id(1)
+    bid = tl.program_id(2) // NUM_HEADS
+    hid = tl.program_id(2) % NUM_HEADS
+    
+    blocks_per_split = tl.cdiv(tl.cdiv(seqlen_k, BLOCK_N), num_splits)
+
+    if is_local:
+        n_block_min = max(split_id * blocks_per_split, (m_block * BLOCK_M + seqlen_k - seqlen_q - ws_left) / BLOCK_N)
+    else:
+        n_block_min = split_id * blocks_per_split
+
+    n_block_max = min((split_id + 1) * blocks_per_split, tl.cdiv(seqlen_k, BLOCK_N))
+    if is_causal or is_local:
+        n_block_max = min(n_block_max,
+                          tl.cdiv((m_block + 1) * BLOCK_M + seqlen_k - seqlen_q + ws_right, BLOCK_N))
+
+    if has_alibi:
+        alibi_offset = bid * slopes_batch_stride + hid
+        alibi_slope = tl.load(pSlopes + alibi_offset)
+        alibi_slope /= scale
+    else:
+        alibi_slope = 0.0
+
+    if (not is_causal) and (not is_local):
+        if IS_EVEN_MN:
+            n_masking_steps = 0
+        else:
+            n_masking_steps = 1
+    elif is_causal and IS_EVEN_MN: # causal implies ws_right is zero
+        n_masking_steps = tl.cdiv(BLOCK_M, BLOCK_N)
+    else:
+        # local and not causal, 
+        n_masking_steps = tl.cdiv(BLOCK_M, BLOCK_N) + 1
+
+    Q_ptr += bid * q_b_stride
+    Q_ptr += hid * q_h_stride
+    row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
+    Q_off = row_idx[:, None] * q_s_stride + tl.arange(0, HEAD_DIM)[None, :]
+    qmask = row_idx[:, None] < seqlen_q
+    if IS_EVEN_MN:
+        Q = tl.load(Q_ptr + Q_off)
+    else:
+        Q = tl.load(Q_ptr + Q_off, mask=qmask)
+
+    # Start from the right most block
+    n_block = n_block_max - 1
+
+    h_hk_ratio = h // hk
+    K_ptr += bid * k_b_stride
+    K_ptr += (hid // h_hk_ratio) * k_h_stride
+    V_ptr += bid * k_b_stride
+    V_ptr += (hid // h_hk_ratio) * k_h_stride
+    
+    P_ptr += ((bid * NUM_HEADS + hid) * seqlen_q_rounded + m_block * BLOCK_M) * seqlen_k_rounded
+    P_ptr += n_block * BLOCK_N
+    P_offset = tl.arange(0, BLOCK_M)[:, None] * seqlen_k_rounded + tl.arange(0, BLOCK_N)
+
+    O_ = tl.zeros((BLOCK_M, HEAD_DIM), dtype=tl.float32)
+    rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+    rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32)
+    
+    for n_block in tl.range(n_block_max - 1, n_block_max - n_masking_steps - 1, step=-1):
+        col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
+        K_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None]
+        V_offset = col_idx[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :]
+        if IS_EVEN_MN:
+            K = tl.load(K_ptr + K_offset, cache_modifier=".cg")
+            if PRE_LOAD_V:
+                V = tl.load(V_ptr + V_offset, cache_modifier=".cg")
+        else:
+            kvmask = col_idx < seqlen_k
+            K = tl.load(K_ptr + K_offset, mask=kvmask[None, :], cache_modifier=".cg")
+            if PRE_LOAD_V:
+                V = tl.load(V_ptr + V_offset, mask=kvmask[:, None], cache_modifier=".cg")
+        S = tl.dot(Q, K, allow_tf32=False)
+        S = apply_mask(
+            S,
+            col_idx,
+            row_idx,
+            seqlen_q,
+            seqlen_k,
+            ws_left,
+            ws_right,
+            alibi_slope,
+            is_even_mn=IS_EVEN_MN,
+            is_causal=is_causal,
+            is_local=is_local,
+            has_alibi=has_alibi
+        )
+        # col_idx -= BLOCK_N
+
+        is_init = (n_block == n_block_max - 1).to(tl.int1)
+        O_, P, rowmax_, rowsum_ = softmax_rescale(
+            O_,
+            S,
+            rowmax_,
+            rowsum_,
+            softmax_scale_log2e=softmax_scale_log2e,
+            is_border=(is_causal or is_local),
+            is_init=is_init
+        )
+        P = P.to(Q_ptr.type.element_ty)
+
+        if not PRE_LOAD_V:
+            if IS_EVEN_MN:
+                V = tl.load(V_ptr + V_offset, cache_modifier=".cg")
+            else:
+                V = tl.load(V_ptr + V_offset, mask=kvmask[:, None], cache_modifier=".cg")
+        O_ = tl.dot(P, V, O_, allow_tf32=False)
+        # if n_masking_steps > 1 and n_block <= n_block_min:
+        #     break
+
+    for n_block in tl.range(n_block_max - n_masking_steps - 1, n_block_min - 1, step=-1, num_stages=num_stages):
+        col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
+        K_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None]
+        K = tl.load(K_ptr + K_offset, cache_modifier=".cg")
+        # if PRE_LOAD_V:
+        #     V_offset = col_idx[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :]
+        #     V = tl.load(V_ptr + V_offset, cache_modifier=".cg")
+        S = tl.dot(Q, K)
+
+        if PRE_LOAD_V:
+            V_offset = col_idx[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :]
+            V = tl.load(V_ptr + V_offset, cache_modifier=".cg")
+
+        S = apply_mask(
+            S,
+            col_idx,
+            row_idx,
+            seqlen_q,
+            seqlen_k,
+            ws_left,
+            ws_right,
+            alibi_slope,
+            is_even_mn=True,
+            is_causal=False,
+            is_local=is_local,
+            has_alibi=has_alibi
+        )
+        # col_idx -= BLOCK_N
+
+        is_init = (n_block == n_block_max - 1).to(tl.int1)
+        O_, P, rowmax_, rowsum_ = softmax_rescale(
+            O_,
+            S,
+            rowmax_,
+            rowsum_,
+            softmax_scale_log2e=softmax_scale_log2e,
+            is_border=is_local,
+            is_init=is_init
+        )
+
+        P = P.to(Q_ptr.type.element_ty)
+
+        if not PRE_LOAD_V:
+            V_offset = col_idx[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :]
+            V = tl.load(V_ptr + V_offset, cache_modifier=".cg")
+
+        O_ = tl.dot(P, V, O_)
+
+    # LSE
+    lse = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), float('inf'), rowmax_ * softmax_scale + tl.log(rowsum_))
+    inv_sum = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), 1.0, 1.0 / rowsum_)
+    
+    # Rescale output
+    O_ *= inv_sum[:, None]
+
+    # Write back output
+    O_split_ptr = O_ptr
+    # (n_splits, batch_size, num_heads, seqlen_q, head_size)
+    # grid = (seq_block, split, batch * head)
+    O_split_ptr += (split_id * tl.num_programs(2) + tl.program_id(2)) * seqlen_q * HEAD_DIM
+    O_split_offset = row_idx[:, None] * HEAD_DIM + tl.arange(0, HEAD_DIM)
+
+    if IS_EVEN_MN:
+        tl.store(O_split_ptr + O_split_offset, O_)
+    else:
+        tl.store(O_split_ptr + O_split_offset, O_, mask=qmask)
+    
+    # Write back lse
+    lse_split_ptr = lse_ptr
+    lse_split_ptr += (split_id * tl.num_programs(2) + tl.program_id(2)) * seqlen_q
+    lse_split_ptr += m_block * BLOCK_M
+
+    if IS_EVEN_MN:
+        tl.store(lse_split_ptr + tl.arange(0, BLOCK_M), lse)
+    else:
+        tl.store(lse_split_ptr + tl.arange(0, BLOCK_M), lse, mask=row_idx < seqlen_q)
+
+
+@triton.jit
+def flash_fwd_splitkv_combine_kernel(
+    out_ptr,
+    lse_ptr,
+    out_splits_ptr,
+    lse_splits_ptr,
+    head_size: tl.constexpr,
+    out_b_stride,
+    out_s_stride,
+    out_h_stride,
+    n_splits,
+    BLOCK_M: tl.constexpr,
+    q_total,
+    MAX_N_SPLITS: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    lse_splits_ptr += pid * BLOCK_M
+    lse_ptr += pid * BLOCK_M
+    out_splits_ptr += pid * BLOCK_M * head_size
+    out_ptr += pid * BLOCK_M * head_size
+    lse_split_stride = tl.num_programs(0) * BLOCK_M
+    out_split_stride = tl.num_programs(0) * BLOCK_M * head_size
+
+    # Subtracting maximum from each of the split lse's for better numerical stability
+    lse_split_offset = tl.arange(0, BLOCK_M)[:, None] + tl.arange(0, MAX_N_SPLITS)[None, :] * lse_split_stride
+    lse_split_mask = (pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None] < q_total) & (tl.arange(0, MAX_N_SPLITS)[None, :] < n_splits)
+    lse_splits = tl.load(lse_splits_ptr + lse_split_offset, mask=lse_split_mask, other=float('-inf'))
+    max_lse = tl.max(lse_splits, 1)
+    
+    # Sum exp(lse(i) - max_lse) over all split i to obtain Z=sumexp(QK) up to a scaled factor exp(-max_lse)
+    Zi_scaled = tl.exp(lse_splits - max_lse[:, None])
+    Z_scaled = tl.sum(Zi_scaled, 1)
+    Zi_Z = Zi_scaled / Z_scaled[:, None]
+
+    # Write back LSE
+    lse = tl.log(Z_scaled) + max_lse
+    out_mask = pid * BLOCK_M + tl.arange(0, BLOCK_M) < q_total
+    tl.store(lse_ptr + tl.arange(0, BLOCK_M), lse, mask=out_mask)
+
+    out_split_offset = (
+        tl.arange(0, BLOCK_M)[:, None, None] * head_size
+        + tl.arange(0, MAX_N_SPLITS)[None, :, None] * out_split_stride
+        + tl.arange(0, head_size)[None, None, :]
+    )
+    out_split_mask = (pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None, None] < q_total) & (tl.arange(0, MAX_N_SPLITS)[None, :, None] < n_splits)
+    out_splits = tl.load(out_splits_ptr + out_split_offset, mask=out_split_mask, other=0)
+    out = tl.sum(Zi_Z[:, :, None] * out_splits, 1)
+    out = out.to(out_ptr.type.element_ty)
+    
+    # tl.device_print('O', out)
+    # Write back output
+    out_offset = tl.arange(0, BLOCK_M)[:, None] * out_s_stride + tl.arange(0, head_size)
+    tl.store(out_ptr + out_offset, out, mask=out_mask[:, None])
+    
+
 def mha_fwd(
     q,
     k,
@@ -988,6 +1315,26 @@ def mha_fwd(
     seqlen_q_rounded = round_multiple(seqlen_q, 128)
     seqlen_k_rounded = round_multiple(seqlen_k, 128)
 
+    def splits_heuristics(num_tasks, num_sms, n_blocks):
+        # splits only number of waves and wave efficiency are both low
+        n_waves = triton.cdiv(num_tasks, num_sms)
+        eff = (num_tasks / num_sms) / n_waves
+        if eff > 0.85 or n_waves > 10:
+            return 1
+        max_eff = eff
+        best_splits = 1
+        for w in range(n_waves, 10):
+            n_splits = min(num_sms, n_blocks, w * num_sms // num_tasks)
+            blocks_per_split = triton.cdiv(n_blocks, n_splits)
+            if blocks_per_split < 4:
+                continue
+            n_splits = triton.cdiv(n_blocks, blocks_per_split)
+            eff = (n_splits * num_tasks / num_sms) / w
+            if eff > max_eff:
+                max_eff = eff
+                best_splits = n_splits
+        return best_splits
+
     with torch_device_fn.device(q_device):
         # Set softmax params
         lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float, device=q_device)
@@ -1014,9 +1361,31 @@ def mha_fwd(
         pdrop_u8 = math.floor(p_dropout * 255.0)
         rpdrop = 1. / p_dropout
 
+        # Check splitkv
+        if not is_dropout:
+            n_tasks = batch_size * num_heads * triton.cdiv(seqlen_q, block_m_splitkv_heuristic(head_size))
+            num_sms = torch_device_fn.get_device_properties("cuda").multi_processor_count
+            n_blocks = triton.cdiv(seqlen_k, block_n_splitkv_heuristic(head_size))
+            n_splits = splits_heuristics(n_tasks, num_sms, n_blocks)
+            print('n_blocks:', n_blocks)
+            print('n_splits:', n_splits)
+        else:
+            n_splits = 1
+        
+        if n_splits > 1:
+            lse_splits = torch.empty(
+                (n_splits, batch_size, num_heads, seqlen_q),
+                dtype=torch.float,
+                device=q_device
+            )
+            out_splits = torch.empty(
+                (n_splits, batch_size, num_heads, seqlen_q, head_size),
+                dtype=torch.float,
+                device=q_device
+            )
+
         M_LOG2E	= 1.4426950408889634074
-        scale_softmax_log2 = softmax_scale * M_LOG2E
-        scale_softmax_rp_dropout = rpdrop * softmax_scale
+        softmax_scale_log2e = softmax_scale * M_LOG2E
 
         # Set alibi params
         if alibi_slopes is not None:
@@ -1036,20 +1405,33 @@ def mha_fwd(
         # ONLY EVEN_K IS SUPPORTED
         assert head_size == head_size_rounded
 
+        # Launch kernel
+        if n_splits > 1:
+            grid = lambda args: (
+                triton.cdiv(seqlen_q, args["BLOCK_M"]),
+                n_splits,
+                batch_size * num_heads
+            )
+            kernel = flash_fwd_splitkv_kernel[grid]
+            tmp_lse = lse_splits
+            tmp_out = out_splits
+        else:
+            grid = lambda args: (
+                triton.cdiv(seqlen_q, args["BLOCK_M"]), # num_m_blocks
+                batch_size,
+                num_heads,
+            )
+            kernel = flash_fwd_kernel[grid]
+            tmp_lse = lse
+            tmp_out = out
 
-        grid = lambda args: (
-            triton.cdiv(seqlen_q, args["BLOCK_M"]), # num_m_blocks
-            batch_size,
-            num_heads,
-        )
-
-        flash_fwd_kernel[grid](
+        kernel(
             q,
             k,
             v,
             p,
-            out,
-            lse,
+            tmp_out,
+            tmp_lse,
             seqlen_q,
             seqlen_k,
             seqlen_q_rounded,
@@ -1077,15 +1459,39 @@ def mha_fwd(
             is_local=is_local,
             has_alibi=has_alibi,
             softmax_scale=softmax_scale,
-            softmax_scale_log2=scale_softmax_log2,
+            softmax_scale_log2e=softmax_scale_log2e,
             ws_left=window_size_left,
             ws_right=window_size_right,
             return_P=return_softmax,
             BATCH_SIZE=batch_size,
+            num_splits=n_splits,
             NUM_HEADS=num_heads,
             NUM_HEADS_K=num_heads_k,
         )
     
+    if n_splits > 1:
+        if head_size % 128 == 0:
+            BLOCK_M = 4
+        elif head_size % 64 == 0:
+            BLOCK_M = 8
+        else:
+            BLOCK_M = 16
+        grid = lambda args: (triton.cdiv(batch_size * num_heads * seqlen_q, BLOCK_M), )
+        flash_fwd_splitkv_combine_kernel[grid](
+            out,
+            lse,
+            tmp_out,
+            tmp_lse,
+            head_size,
+            out.stride(0),
+            out.stride(-3),
+            out.stride(-1),
+            n_splits,
+            BLOCK_M,
+            q_total=batch_size * num_heads * seqlen_q,
+            MAX_N_SPLITS=triton.next_power_of_2(n_splits),
+        )
+
     if swap_seq_and_group:
         out = out.transpose(1, 2).reshape((batch_size, 1, num_heads_k * seqlen_q, head_size))
         q = q.transpose(1, 2).reshape((batch_size, 1, num_heads_k * seqlen_q, head_size))

From 88f33d4138a29d3ad01ce55da27a8e0303c4799f Mon Sep 17 00:00:00 2001
From: Tongxin Bai <waffle.bai@gmail.com>
Date: Tue, 18 Mar 2025 03:51:45 +0000
Subject: [PATCH 11/25] nuked dynamic cf in loop, but still failed pipelining.

---
 src/flag_gems/ops/__init__.py  | 614 ++++++++++++++++-----------------
 src/flag_gems/ops/attention.py | 267 +++++++-------
 2 files changed, 455 insertions(+), 426 deletions(-)

diff --git a/src/flag_gems/ops/__init__.py b/src/flag_gems/ops/__init__.py
index 2802e11f6..01a31e633 100755
--- a/src/flag_gems/ops/__init__.py
+++ b/src/flag_gems/ops/__init__.py
@@ -1,312 +1,312 @@
-from .abs import abs
-from .add import add
-from .addmm import addmm
-from .all import all, all_dim, all_dims
-from .amax import amax
-from .any import any, any_dim, any_dims
-from .arange import arange, arange_start
-from .argmax import argmax
-from .argmin import argmin
+# from .abs import abs
+# from .add import add
+# from .addmm import addmm
+# from .all import all, all_dim, all_dims
+# from .amax import amax
+# from .any import any, any_dim, any_dims
+# from .arange import arange, arange_start
+# from .argmax import argmax
+# from .argmin import argmin
 from .attention import flash_attention_forward, scaled_dot_product_attention
-from .batch_norm import batch_norm
-from .bitwise_and import (
-    bitwise_and_scalar,
-    bitwise_and_scalar_tensor,
-    bitwise_and_tensor,
-)
-from .bitwise_not import bitwise_not
-from .bitwise_or import bitwise_or_scalar, bitwise_or_scalar_tensor, bitwise_or_tensor
-from .bmm import bmm
-from .cat import cat
-from .clamp import clamp, clamp_tensor
-from .conv1d import conv1d
-from .conv2d import conv2d
-from .conv_depthwise2d import _conv_depthwise2d
-from .cos import cos
-from .count_nonzero import count_nonzero
-from .cross_entropy_loss import cross_entropy_loss
-from .cummin import cummin
-from .cumsum import cumsum, normed_cumsum
-from .diag import diag
-from .diag_embed import diag_embed
-from .diagonal import diagonal_backward
-from .div import div_mode, floor_divide, remainder, true_divide
-from .dropout import native_dropout
-from .embedding import embedding
-from .eq import eq, eq_scalar
-from .erf import erf
-from .exp import exp
-from .exponential_ import exponential_
-from .fill import fill_scalar, fill_tensor
-from .flip import flip
-from .full import full
-from .full_like import full_like
-from .gather import gather, gather_backward
-from .ge import ge, ge_scalar
-from .gelu import gelu
-from .groupnorm import group_norm
-from .gt import gt, gt_scalar
-from .hstack import hstack
-from .index_add import index_add
-from .index_select import index_select
-from .instancenorm import instance_norm
-from .isclose import allclose, isclose
-from .isfinite import isfinite
-from .isin import isin
-from .isinf import isinf
-from .isnan import isnan
-from .layernorm import layer_norm
-from .le import le, le_scalar
-from .log_sigmoid import log_sigmoid
-from .log_softmax import log_softmax
-from .logical_and import logical_and
-from .logical_not import logical_not
-from .logical_or import logical_or
-from .logical_xor import logical_xor
-from .lt import lt, lt_scalar
-from .masked_fill import masked_fill, masked_fill_
-from .masked_select import masked_select
-from .max import max, max_dim
-from .maximum import maximum
-from .mean import mean, mean_dim
-from .min import min, min_dim
-from .minimum import minimum
-from .mm import mm
-from .mse_loss import mse_loss
-from .mul import mul
-from .multinomial import multinomial
-from .mv import mv
-from .ne import ne, ne_scalar
-from .neg import neg
-from .nllloss import (
-    nll_loss2d_backward,
-    nll_loss2d_forward,
-    nll_loss_backward,
-    nll_loss_forward,
-)
-from .nonzero import nonzero
-from .normal import normal_float_tensor, normal_tensor_float, normal_tensor_tensor
-from .ones import ones
-from .ones_like import ones_like
-from .outer import outer
-from .pad import constant_pad_nd, pad
-from .pow import pow_scalar, pow_tensor_scalar, pow_tensor_tensor
-from .prod import prod, prod_dim
-from .quantile import quantile
-from .rand import rand
-from .rand_like import rand_like
-from .randn import randn
-from .randn_like import randn_like
-from .randperm import randperm
-from .reciprocal import reciprocal
-from .relu import relu
-from .repeat import repeat
-from .repeat_interleave import (
-    repeat_interleave_self_int,
-    repeat_interleave_self_tensor,
-    repeat_interleave_tensor,
-)
-from .resolve_conj import resolve_conj
-from .resolve_neg import resolve_neg
-from .rms_norm import rms_norm
-from .rsqrt import rsqrt
-from .scatter import scatter
-from .select_scatter import select_scatter
-from .sigmoid import sigmoid
-from .silu import silu
-from .sin import sin
-from .slice_scatter import slice_scatter
-from .softmax import softmax
-from .sort import sort
-from .stack import stack
-from .sub import sub
-from .sum import sum, sum_dim
-from .tanh import tanh
-from .tile import tile
-from .topk import topk
-from .triu import triu
-from .uniform import uniform_
-from .unique import _unique2
-from .upsample_bicubic2d_aa import _upsample_bicubic2d_aa
-from .upsample_nearest2d import upsample_nearest2d
-from .var_mean import var_mean
-from .vdot import vdot
-from .vector_norm import vector_norm
-from .vstack import vstack
-from .weightnorm import weight_norm, weight_norm_interface
-from .where import where_scalar_other, where_scalar_self, where_self, where_self_out
-from .zeros import zeros
-from .zeros_like import zeros_like
+# from .batch_norm import batch_norm
+# from .bitwise_and import (
+#     bitwise_and_scalar,
+#     bitwise_and_scalar_tensor,
+#     bitwise_and_tensor,
+# )
+# from .bitwise_not import bitwise_not
+# from .bitwise_or import bitwise_or_scalar, bitwise_or_scalar_tensor, bitwise_or_tensor
+# from .bmm import bmm
+# from .cat import cat
+# from .clamp import clamp, clamp_tensor
+# from .conv1d import conv1d
+# from .conv2d import conv2d
+# from .conv_depthwise2d import _conv_depthwise2d
+# from .cos import cos
+# from .count_nonzero import count_nonzero
+# from .cross_entropy_loss import cross_entropy_loss
+# from .cummin import cummin
+# from .cumsum import cumsum, normed_cumsum
+# from .diag import diag
+# from .diag_embed import diag_embed
+# from .diagonal import diagonal_backward
+# from .div import div_mode, floor_divide, remainder, true_divide
+# from .dropout import native_dropout
+# from .embedding import embedding
+# from .eq import eq, eq_scalar
+# from .erf import erf
+# from .exp import exp
+# from .exponential_ import exponential_
+# from .fill import fill_scalar, fill_tensor
+# from .flip import flip
+# from .full import full
+# from .full_like import full_like
+# from .gather import gather, gather_backward
+# from .ge import ge, ge_scalar
+# from .gelu import gelu
+# from .groupnorm import group_norm
+# from .gt import gt, gt_scalar
+# from .hstack import hstack
+# from .index_add import index_add
+# from .index_select import index_select
+# from .instancenorm import instance_norm
+# from .isclose import allclose, isclose
+# from .isfinite import isfinite
+# from .isin import isin
+# from .isinf import isinf
+# from .isnan import isnan
+# from .layernorm import layer_norm
+# from .le import le, le_scalar
+# from .log_sigmoid import log_sigmoid
+# from .log_softmax import log_softmax
+# from .logical_and import logical_and
+# from .logical_not import logical_not
+# from .logical_or import logical_or
+# from .logical_xor import logical_xor
+# from .lt import lt, lt_scalar
+# from .masked_fill import masked_fill, masked_fill_
+# from .masked_select import masked_select
+# from .max import max, max_dim
+# from .maximum import maximum
+# from .mean import mean, mean_dim
+# from .min import min, min_dim
+# from .minimum import minimum
+# from .mm import mm
+# from .mse_loss import mse_loss
+# from .mul import mul
+# from .multinomial import multinomial
+# from .mv import mv
+# from .ne import ne, ne_scalar
+# from .neg import neg
+# from .nllloss import (
+#     nll_loss2d_backward,
+#     nll_loss2d_forward,
+#     nll_loss_backward,
+#     nll_loss_forward,
+# )
+# from .nonzero import nonzero
+# from .normal import normal_float_tensor, normal_tensor_float, normal_tensor_tensor
+# from .ones import ones
+# from .ones_like import ones_like
+# from .outer import outer
+# from .pad import constant_pad_nd, pad
+# from .pow import pow_scalar, pow_tensor_scalar, pow_tensor_tensor
+# from .prod import prod, prod_dim
+# from .quantile import quantile
+# from .rand import rand
+# from .rand_like import rand_like
+# from .randn import randn
+# from .randn_like import randn_like
+# from .randperm import randperm
+# from .reciprocal import reciprocal
+# from .relu import relu
+# from .repeat import repeat
+# from .repeat_interleave import (
+#     repeat_interleave_self_int,
+#     repeat_interleave_self_tensor,
+#     repeat_interleave_tensor,
+# )
+# from .resolve_conj import resolve_conj
+# from .resolve_neg import resolve_neg
+# from .rms_norm import rms_norm
+# from .rsqrt import rsqrt
+# from .scatter import scatter
+# from .select_scatter import select_scatter
+# from .sigmoid import sigmoid
+# from .silu import silu
+# from .sin import sin
+# from .slice_scatter import slice_scatter
+# from .softmax import softmax
+# from .sort import sort
+# from .stack import stack
+# from .sub import sub
+# from .sum import sum, sum_dim
+# from .tanh import tanh
+# from .tile import tile
+# from .topk import topk
+# from .triu import triu
+# from .uniform import uniform_
+# from .unique import _unique2
+# from .upsample_bicubic2d_aa import _upsample_bicubic2d_aa
+# from .upsample_nearest2d import upsample_nearest2d
+# from .var_mean import var_mean
+# from .vdot import vdot
+# from .vector_norm import vector_norm
+# from .vstack import vstack
+# from .weightnorm import weight_norm, weight_norm_interface
+# from .where import where_scalar_other, where_scalar_self, where_self, where_self_out
+# from .zeros import zeros
+# from .zeros_like import zeros_like
 
 __all__ = [
-    "log_sigmoid",
-    "all",
-    "all_dim",
-    "all_dims",
-    "allclose",
-    "any",
-    "any_dim",
-    "any_dims",
-    "add",
-    "abs",
-    "addmm",
-    "arange",
-    "arange_start",
-    "batch_norm",
-    "bitwise_and_tensor",
-    "bitwise_and_scalar",
-    "bitwise_and_scalar_tensor",
-    "bitwise_not",
-    "bitwise_or_tensor",
-    "bitwise_or_scalar",
-    "bitwise_or_scalar_tensor",
-    "bmm",
-    "clamp",
-    "clamp_tensor",
-    "cos",
-    "count_nonzero",
-    "diag",
-    "diag_embed",
-    "diagonal_backward",
-    "pad",
-    "constant_pad_nd",
-    "cummin",
-    "cumsum",
-    "normed_cumsum",
-    "true_divide",
-    "div_mode",
-    "floor_divide",
-    "remainder",
-    "zeros",
-    "ones",
-    "full",
-    "native_dropout",
-    "erf",
-    "embedding",
-    "eq",
-    "eq_scalar",
-    "exp",
-    "fill_scalar",
-    "fill_tensor",
-    "exponential_",
-    "gather",
-    "gather_backward",
-    "flip",
-    "ones_like",
-    "full_like",
-    "zeros_like",
-    "ge",
-    "ge_scalar",
-    "gelu",
-    "group_norm",
-    "gt",
-    "gt_scalar",
-    "index_select",
-    "instance_norm",
-    "isclose",
-    "isfinite",
-    "isin",
-    "isinf",
-    "isnan",
-    "layer_norm",
-    "weight_norm_interface",
-    "weight_norm",
-    "le",
-    "le_scalar",
-    "lt",
-    "lt_scalar",
-    "rms_norm",
-    "mean",
-    "mean_dim",
-    "mm",
-    "mul",
-    "multinomial",
-    "maximum",
-    "minimum",
-    "rand",
-    "randn",
-    "randperm",
-    "rand_like",
-    "randn_like",
-    "resolve_neg",
-    "resolve_conj",
-    "normal_tensor_float",
-    "normal_float_tensor",
-    "normal_tensor_tensor",
-    "uniform_",
-    "mv",
-    "ne",
-    "ne_scalar",
-    "neg",
-    "pow_scalar",
-    "pow_tensor_scalar",
-    "pow_tensor_tensor",
-    "reciprocal",
-    "relu",
-    "rsqrt",
-    "scatter",
-    "sigmoid",
-    "silu",
-    "sin",
-    "softmax",
-    "sub",
-    "tanh",
-    "tile",
-    "triu",
-    "topk",
-    "max",
-    "max_dim",
-    "min",
-    "min_dim",
-    "sum",
-    "sum_dim",
-    "amax",
-    "argmax",
-    "argmin",
-    "prod",
-    "prod_dim",
-    "quantile",
-    "var_mean",
-    "vector_norm",
-    "log_softmax",
-    "outer",
-    "cross_entropy_loss",
-    "where_self_out",
-    "where_self",
-    "where_scalar_self",
-    "where_scalar_other",
-    "index_add",
-    "select_scatter",
-    "slice_scatter",
-    "masked_fill",
-    "masked_fill_",
-    "_unique2",
-    "_upsample_bicubic2d_aa",
-    "upsample_nearest2d",
-    "nonzero",
-    "repeat",
-    "masked_select",
-    "stack",
-    "hstack",
-    "cat",
-    "repeat_interleave_self_int",
-    "vstack",
-    "repeat_interleave_tensor",
+#     "log_sigmoid",
+#     "all",
+#     "all_dim",
+#     "all_dims",
+#     "allclose",
+#     "any",
+#     "any_dim",
+#     "any_dims",
+#     "add",
+#     "abs",
+#     "addmm",
+#     "arange",
+#     "arange_start",
+#     "batch_norm",
+#     "bitwise_and_tensor",
+#     "bitwise_and_scalar",
+#     "bitwise_and_scalar_tensor",
+#     "bitwise_not",
+#     "bitwise_or_tensor",
+#     "bitwise_or_scalar",
+#     "bitwise_or_scalar_tensor",
+#     "bmm",
+#     "clamp",
+#     "clamp_tensor",
+#     "cos",
+#     "count_nonzero",
+#     "diag",
+#     "diag_embed",
+#     "diagonal_backward",
+#     "pad",
+#     "constant_pad_nd",
+#     "cummin",
+#     "cumsum",
+#     "normed_cumsum",
+#     "true_divide",
+#     "div_mode",
+#     "floor_divide",
+#     "remainder",
+#     "zeros",
+#     "ones",
+#     "full",
+#     "native_dropout",
+#     "erf",
+#     "embedding",
+#     "eq",
+#     "eq_scalar",
+#     "exp",
+#     "fill_scalar",
+#     "fill_tensor",
+#     "exponential_",
+#     "gather",
+#     "gather_backward",
+#     "flip",
+#     "ones_like",
+#     "full_like",
+#     "zeros_like",
+#     "ge",
+#     "ge_scalar",
+#     "gelu",
+#     "group_norm",
+#     "gt",
+#     "gt_scalar",
+#     "index_select",
+#     "instance_norm",
+#     "isclose",
+#     "isfinite",
+#     "isin",
+#     "isinf",
+#     "isnan",
+#     "layer_norm",
+#     "weight_norm_interface",
+#     "weight_norm",
+#     "le",
+#     "le_scalar",
+#     "lt",
+#     "lt_scalar",
+#     "rms_norm",
+#     "mean",
+#     "mean_dim",
+#     "mm",
+#     "mul",
+#     "multinomial",
+#     "maximum",
+#     "minimum",
+#     "rand",
+#     "randn",
+#     "randperm",
+#     "rand_like",
+#     "randn_like",
+#     "resolve_neg",
+#     "resolve_conj",
+#     "normal_tensor_float",
+#     "normal_float_tensor",
+#     "normal_tensor_tensor",
+#     "uniform_",
+#     "mv",
+#     "ne",
+#     "ne_scalar",
+#     "neg",
+#     "pow_scalar",
+#     "pow_tensor_scalar",
+#     "pow_tensor_tensor",
+#     "reciprocal",
+#     "relu",
+#     "rsqrt",
+#     "scatter",
+#     "sigmoid",
+#     "silu",
+#     "sin",
+#     "softmax",
+#     "sub",
+#     "tanh",
+#     "tile",
+#     "triu",
+#     "topk",
+#     "max",
+#     "max_dim",
+#     "min",
+#     "min_dim",
+#     "sum",
+#     "sum_dim",
+#     "amax",
+#     "argmax",
+#     "argmin",
+#     "prod",
+#     "prod_dim",
+#     "quantile",
+#     "var_mean",
+#     "vector_norm",
+#     "log_softmax",
+#     "outer",
+#     "cross_entropy_loss",
+#     "where_self_out",
+#     "where_self",
+#     "where_scalar_self",
+#     "where_scalar_other",
+#     "index_add",
+#     "select_scatter",
+#     "slice_scatter",
+#     "masked_fill",
+#     "masked_fill_",
+#     "_unique2",
+#     "_upsample_bicubic2d_aa",
+#     "upsample_nearest2d",
+#     "nonzero",
+#     "repeat",
+#     "masked_select",
+#     "stack",
+#     "hstack",
+#     "cat",
+#     "repeat_interleave_self_int",
+#     "vstack",
+#     "repeat_interleave_tensor",
     "flash_attention_forward",
-    "scaled_dot_product_attention",
-    "conv2d",
-    "conv1d",
-    "_conv_depthwise2d",
-    "repeat_interleave_self_tensor",
-    "logical_or",
-    "logical_and",
-    "logical_xor",
-    "logical_not",
-    "sort",
-    "nll_loss_forward",
-    "nll_loss_backward",
-    "nll_loss2d_forward",
-    "nll_loss2d_backward",
-    "vdot",
-    "mse_loss",
+#     "scaled_dot_product_attention",
+#     "conv2d",
+#     "conv1d",
+#     "_conv_depthwise2d",
+#     "repeat_interleave_self_tensor",
+#     "logical_or",
+#     "logical_and",
+#     "logical_xor",
+#     "logical_not",
+#     "sort",
+#     "nll_loss_forward",
+#     "nll_loss_backward",
+#     "nll_loss2d_forward",
+#     "nll_loss2d_backward",
+#     "vdot",
+#     "mse_loss",
 ]
diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py
index 6a917b104..756ca2571 100644
--- a/src/flag_gems/ops/attention.py
+++ b/src/flag_gems/ops/attention.py
@@ -573,19 +573,27 @@ def softmax_rescale(
     row_sum,
     softmax_scale_log2e: tl.constexpr,
     is_border: tl.constexpr,
-    is_init: tl.constexpr
+    # is_init: tl.constexpr
 ):
     prev_max = row_max
     row_max = tl.maximum(row_max, tl.max(S, 1))
 
-    if not is_init:
-        if is_border:
-            cur_max = tl.where(row_max == float('-inf'), 0, row_max)
-        else:
-            cur_max = row_max
-        p_scale = tl.math.exp2((prev_max - cur_max) * softmax_scale_log2e)
-        row_sum *= p_scale
-        O_acc *= p_scale[:, None]
+    # if not is_init:
+    #     if is_border:
+    #         cur_max = tl.where(row_max == float('-inf'), 0, row_max)
+    #     else:
+    #         cur_max = row_max
+    #     p_scale = tl.math.exp2((prev_max - cur_max) * softmax_scale_log2e)
+    #     row_sum *= p_scale
+    #     O_acc *= p_scale[:, None]
+
+    if is_border:
+        cur_max = tl.where(row_max == float('-inf'), 0, row_max)
+    else:
+        cur_max = row_max
+    p_scale = tl.math.exp2((prev_max - cur_max) * softmax_scale_log2e)
+    row_sum *= p_scale
+    O_acc *= p_scale[:, None]
 
     max_scaled = tl.where(row_max == float('-inf'), 0, row_max * softmax_scale_log2e)
     P = tl.math.exp2(S * softmax_scale_log2e - max_scaled[:, None])
@@ -594,21 +602,23 @@ def softmax_rescale(
 
 
 def block_m_heuristic(headdim, is_dropout):
-    return 64 if headdim <= 128 else 32
+    # return 128 if headdim <= 128 else 64
+    return 64
 
 def block_n_heuristic(headdim, is_dropout):
-    return 64 if headdim <= 128 else 32
+    # return 128 if headdim <= 128 else 64
+    return 64
 
 def block_m_splitkv_heuristic(headdim):
-    return 64 if headdim <= 128 else 32
+    return 128 if headdim <= 128 else 64
 
 def block_n_splitkv_heuristic(headdim):
     if headdim <= 64:
-        return 128
+        return 256
     elif headdim <= 128:
-        return 64
+        return 128
     else:
-        return 32
+        return 64
 
 # @triton.autotune(
 #     configs=runtime.get_tuned_config("attention"),
@@ -629,7 +639,7 @@ def block_n_splitkv_heuristic(headdim):
         'IS_EVEN_MN': lambda args: (args["seqlen_q"] % args["BLOCK_M"] == 0) and (args["seqlen_k"] % args["BLOCK_N"] == 0),
     }
 )
-@triton.jit(do_not_specialize=["seqlen_q", "seqlen_k", "q_b_stride", "q_s_stride", "q_h_stride", "k_b_stride", "k_s_stride", "k_h_stride", "o_b_stride", "o_h_stride", "o_s_stride", "philox_seed", "philox_offset", "pdrop_u8", "slopes_batch_stride"])
+@triton.jit(do_not_specialize=["seqlen_q", "seqlen_k", "q_b_stride", "k_b_stride", "o_b_stride", "philox_seed", "philox_offset", "pdrop_u8", "slopes_batch_stride"])
 def flash_fwd_kernel(
     Q_ptr,
     K_ptr,
@@ -684,15 +694,25 @@ def flash_fwd_kernel(
     hid = tl.program_id(2)
 
     if is_local:
-        n_block_min: tl.constexpr = max(0, (m_block * BLOCK_M + seqlen_k - seqlen_q - ws_left) / BLOCK_N)
+        col_min = m_block * BLOCK_M + seqlen_k - seqlen_q - ws_left
+        col_min = max(col_min, 0)
     else:
-        n_block_min: tl.constexpr = 0 
-
-    n_block_max = tl.cdiv(seqlen_k, BLOCK_N)
+        col_min = 0
     
+    col_max = tl.cdiv(seqlen_k, BLOCK_N) * BLOCK_N
     if is_causal or is_local:
-        n_block_max = min(n_block_max,
-                          tl.cdiv((m_block + 1) * BLOCK_M + seqlen_k - seqlen_q + ws_right, BLOCK_N))
+        col_max = min(col_max, (m_block + 1) * BLOCK_M + seqlen_k - seqlen_q + ws_right)
+
+    # if is_local:
+    #     n_block_min = max(0, (m_block * BLOCK_M + seqlen_k - seqlen_q - ws_left) / BLOCK_N)
+    # else:
+    #     n_block_min = 0 
+
+    # n_block_max = tl.cdiv(seqlen_k, BLOCK_N)
+    
+    # if is_causal or is_local:
+    #     n_block_max = min(n_block_max,
+    #                       tl.cdiv((m_block + 1) * BLOCK_M + seqlen_k - seqlen_q + ws_right, BLOCK_N))
 
     if has_alibi:
         alibi_offset = bid * slopes_batch_stride + hid
@@ -703,24 +723,36 @@ def flash_fwd_kernel(
 
     if (not is_causal) and (not is_local):
         if IS_EVEN_MN:
-            n_masking_steps = 0
+            n_masking_blocks: tl.constexpr = 0
         else:
-            n_masking_steps = 1
+            n_masking_blocks: tl.constexpr = 1
     elif is_causal and IS_EVEN_MN: # causal implies ws_right is zero
-        n_masking_steps = tl.cdiv(BLOCK_M, BLOCK_N)
+        n_masking_blocks: tl.constexpr = tl.cdiv(BLOCK_M, BLOCK_N)
     else:
         # local and not causal, 
-        n_masking_steps = tl.cdiv(BLOCK_M, BLOCK_N) + 1
+        n_masking_blocks: tl.constexpr = tl.cdiv(BLOCK_M, BLOCK_N) + 1
+    
+    masking_cols = n_masking_blocks * BLOCK_N
 
     Q_ptr += bid * q_b_stride
     Q_ptr += hid * q_h_stride
-    row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
+    row_start = m_block * BLOCK_M
+    row_idx = row_start + tl.arange(0, BLOCK_M)
     Q_off = row_idx[:, None] * q_s_stride + tl.arange(0, HEAD_DIM)[None, :]
     qmask = row_idx[:, None] < seqlen_q
     if IS_EVEN_MN:
-        Q = tl.load(Q_ptr + Q_off)
+        Q = tl.load(Q_ptr + Q_off, cache_modifier='.cg')
     else:
-        Q = tl.load(Q_ptr + Q_off, mask=qmask)
+        Q = tl.load(Q_ptr + Q_off, mask=qmask, cache_modifier='.cg')
+    
+    if return_P:
+        P_ptr += ((bid * NUM_HEADS + hid) * seqlen_q_rounded + m_block * BLOCK_M) * seqlen_k_rounded
+        P_ptr += (n_block_max - 1) * BLOCK_N
+        P_offset = tl.arange(0, BLOCK_M)[:, None] * seqlen_k_rounded + tl.arange(0, BLOCK_N)
+
+    O_ = tl.zeros((BLOCK_M, HEAD_DIM), dtype=tl.float32)
+    rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+    rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32)
 
     # Start from the right most block
     n_block = n_block_max - 1
@@ -730,29 +762,28 @@ def flash_fwd_kernel(
     K_ptr += (hid // h_hk_ratio) * k_h_stride
     V_ptr += bid * k_b_stride
     V_ptr += (hid // h_hk_ratio) * k_h_stride
-    
-    P_ptr += ((bid * NUM_HEADS + hid) * seqlen_q_rounded + m_block * BLOCK_M) * seqlen_k_rounded
-    P_ptr += n_block * BLOCK_N
-    P_offset = tl.arange(0, BLOCK_M)[:, None] * seqlen_k_rounded + tl.arange(0, BLOCK_N)
 
-    O_ = tl.zeros((BLOCK_M, HEAD_DIM), dtype=tl.float32)
-    rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
-    rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32)
-    
-    for n_block in tl.range(n_block_max - 1, n_block_max - n_masking_steps - 1, step=-1):
-        col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
-        K_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None]
-        V_offset = col_idx[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :]
+    K_offset = tl.arange(0, BLOCK_N)[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None]
+    V_offset = tl.arange(0, BLOCK_N)[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :]
+
+    p_bk0 = K_ptr + K_offset
+    p_bv0 = V_ptr + V_offset
+
+    for col_start in tl.range(max(col_min, col_max - masking_cols), col_max, step=BLOCK_N):
+    # for r_blk_idx in tl.range(0, min(n_masking_blocks, n_blocks_max - n_blocks_min)):
+        off = col_start * k_s_stride
         if IS_EVEN_MN:
-            K = tl.load(K_ptr + K_offset, cache_modifier=".cg")
+            K = tl.load(p_bk0 + off, cache_modifier=".cg")
             if PRE_LOAD_V:
-                V = tl.load(V_ptr + V_offset, cache_modifier=".cg")
+                V = tl.load(p_bv0 + off, cache_modifier=".cg")
         else:
-            kvmask = col_idx < seqlen_k
-            K = tl.load(K_ptr + K_offset, mask=kvmask[None, :], cache_modifier=".cg")
+            kvmask = col < seqlen_k
+            K = tl.load(p_bk0 + off, mask=kvmask[None, :], cache_modifier=".cg")
             if PRE_LOAD_V:
-                V = tl.load(V_ptr + V_offset, mask=kvmask[:, None], cache_modifier=".cg")
+                V = tl.load(p_bv0 + off, mask=kvmask[:, None], cache_modifier=".cg")
         S = tl.dot(Q, K, allow_tf32=False)
+        col_idx = col_start + tl.arange(0, BLOCK_N)
+        row_idx = row_start + tl.arange(0, BLOCK_M)
         S = apply_mask(
             S,
             col_idx,
@@ -767,9 +798,8 @@ def flash_fwd_kernel(
             is_local=is_local,
             has_alibi=has_alibi
         )
-        # col_idx -= BLOCK_N
 
-        is_init = (n_block == n_block_max - 1).to(tl.int1)
+        # is_init = (n_block == n_block_max - 1).to(tl.int1)
         O_, P, rowmax_, rowsum_ = softmax_rescale(
             O_,
             S,
@@ -777,32 +807,33 @@ def flash_fwd_kernel(
             rowsum_,
             softmax_scale_log2e=softmax_scale_log2e,
             is_border=(is_causal or is_local),
-            is_init=is_init
+            # is_init=is_init
         )
         P = P.to(O_ptr.type.element_ty)
 
-        row_start = m_block * (BLOCK_M // 16)
-        col_start = n_block * (BLOCK_N // 32)
-        if return_P:
-            P_drop = P
-            P_drop = apply_dropout(
-                P_drop,
-                row_start,
-                col_start,
-                bid,
-                hid,
-                philox_seed,
-                philox_offset,
-                pdrop_u8,
-                encode_dropout_in_sign_bit=True,
-                NUM_HEADS=NUM_HEADS,
-                BLOCK_M=BLOCK_M,
-                BLOCK_N=BLOCK_N,
-            )
-            tl.store(P_ptr + P_offset, P_drop, mask=qmask & kvmask[None, :])
-            P_offset += BLOCK_N
-
         if is_dropout:
+            row_start = m_block * (BLOCK_M // 16)
+            col_start = n_block * (BLOCK_N // 32)
+
+            if return_P:
+                P_drop = P
+                P_drop = apply_dropout(
+                    P_drop,
+                    row_start,
+                    col_start,
+                    bid,
+                    hid,
+                    philox_seed,
+                    philox_offset,
+                    pdrop_u8,
+                    encode_dropout_in_sign_bit=True,
+                    NUM_HEADS=NUM_HEADS,
+                    BLOCK_M=BLOCK_M,
+                    BLOCK_N=BLOCK_N,
+                )
+                tl.store(P_ptr + P_offset, P_drop, mask=qmask & kvmask[None, :])
+                P_offset += BLOCK_N
+
             P = apply_dropout(
                 P,
                 row_start,
@@ -819,29 +850,26 @@ def flash_fwd_kernel(
             )
 
         if not PRE_LOAD_V:
+            off = col_start * k_s_stride
             if IS_EVEN_MN:
-                V = tl.load(V_ptr + V_offset, cache_modifier=".cg")
+                V = tl.load(p_bv0 + off, cache_modifier=".cg")
             else:
-                V = tl.load(V_ptr + V_offset, mask=kvmask[:, None], cache_modifier=".cg")
+                V = tl.load(p_bv0 + off, mask=kvmask[:, None], cache_modifier=".cg")
         O_ = tl.dot(P, V, O_, allow_tf32=False)
 
-        # if n_masking_steps > 1 and n_block <= n_block_min:
-        #     break
-
-
-    for n_block in tl.range(n_block_max - n_masking_steps - 1, n_block_min - 1, step=-1, num_stages=num_stages):
-        col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
-        K_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None]
-        K = tl.load(K_ptr + K_offset, cache_modifier=".cg")
+    for col_start in tl.range(min_col, max_col - masking_cols, step=BLOCK_N, num_stages=num_stages):
+    # for r_blk_idx in tl.range(n_block_max - n_masking_blocks - 1, n_block_min - 1, step=-1, num_stages=num_stages):
+        off = col_start * k_s_stride
+        K = tl.load(p_bk0 + off, cache_modifier=".cg")
         # if PRE_LOAD_V:
-        #     V_offset = col_idx[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :]
         #     V = tl.load(V_ptr + V_offset, cache_modifier=".cg")
         S = tl.dot(Q, K)
 
         if PRE_LOAD_V:
-            V_offset = col_idx[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :]
-            V = tl.load(V_ptr + V_offset, cache_modifier=".cg")
+            V = tl.load(p_bv0 + off, cache_modifier=".cg")
 
+        col_idx = col_start + tl.arange(0, BLOCK_N)
+        row_idx = row_start + tl.arange(0, BLOCK_M)
         S = apply_mask(
             S,
             col_idx,
@@ -856,9 +884,8 @@ def flash_fwd_kernel(
             is_local=is_local,
             has_alibi=has_alibi
         )
-        # col_idx -= BLOCK_N
 
-        is_init = (n_block == n_block_max - 1).to(tl.int1)
+        # is_init = (n_block == n_block_max - 1).to(tl.int1)
         O_, P, rowmax_, rowsum_ = softmax_rescale(
             O_,
             S,
@@ -866,33 +893,32 @@ def flash_fwd_kernel(
             rowsum_,
             softmax_scale_log2e=softmax_scale_log2e,
             is_border=is_local,
-            is_init=is_init
         )
-
         P = P.to(O_ptr.type.element_ty)
 
-        row_start = m_block * (BLOCK_M // 16)
-        col_start = n_block * (BLOCK_N // 32)
-        if return_P:
-            P_drop = P
-            P_drop = apply_dropout(
-                P_drop,
-                row_start,
-                col_start,
-                bid,
-                hid,
-                philox_seed,
-                philox_offset,
-                pdrop_u8,
-                encode_dropout_in_sign_bit=True,
-                NUM_HEADS=NUM_HEADS,
-                BLOCK_M=BLOCK_M,
-                BLOCK_N=BLOCK_N,
-            )
-            tl.store(P_ptr + P_offset, P_drop, mask=qmask & kvmask)
-            P_offset += BLOCK_N
-
         if is_dropout:
+            row_start = m_block * (BLOCK_M // 16)
+            col_start = n_block * (BLOCK_N // 32)
+        
+            if return_P:
+                P_drop = P
+                P_drop = apply_dropout(
+                    P_drop,
+                    row_start,
+                    col_start,
+                    bid,
+                    hid,
+                    philox_seed,
+                    philox_offset,
+                    pdrop_u8,
+                    encode_dropout_in_sign_bit=True,
+                    NUM_HEADS=NUM_HEADS,
+                    BLOCK_M=BLOCK_M,
+                    BLOCK_N=BLOCK_N,
+                )
+                tl.store(P_ptr + P_offset, P_drop, mask=qmask & kvmask)
+                P_offset += BLOCK_N
+
             P = apply_dropout(
                 P,
                 row_start,
@@ -909,8 +935,8 @@ def flash_fwd_kernel(
             )
 
         if not PRE_LOAD_V:
-            V_offset = col_idx[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :]
-            V = tl.load(V_ptr + V_offset, cache_modifier=".cg")
+            off = col_start * k_s_stride
+            V = tl.load(p_bv0 + off, cache_modifier=".cg")
 
         O_ = tl.dot(P, V, O_)
 
@@ -1031,14 +1057,14 @@ def flash_fwd_splitkv_kernel(
 
     if (not is_causal) and (not is_local):
         if IS_EVEN_MN:
-            n_masking_steps = 0
+            n_masking_blocks = 0
         else:
-            n_masking_steps = 1
+            n_masking_blocks = 1
     elif is_causal and IS_EVEN_MN: # causal implies ws_right is zero
-        n_masking_steps = tl.cdiv(BLOCK_M, BLOCK_N)
+        n_masking_blocks = tl.cdiv(BLOCK_M, BLOCK_N)
     else:
         # local and not causal, 
-        n_masking_steps = tl.cdiv(BLOCK_M, BLOCK_N) + 1
+        n_masking_blocks = tl.cdiv(BLOCK_M, BLOCK_N) + 1
 
     Q_ptr += bid * q_b_stride
     Q_ptr += hid * q_h_stride
@@ -1067,7 +1093,7 @@ def flash_fwd_splitkv_kernel(
     rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
     rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32)
     
-    for n_block in tl.range(n_block_max - 1, n_block_max - n_masking_steps - 1, step=-1):
+    for n_block in tl.range(n_block_max - 1, n_block_max - n_masking_blocks - 1, step=-1):
         col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
         K_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None]
         V_offset = col_idx[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :]
@@ -1115,10 +1141,10 @@ def flash_fwd_splitkv_kernel(
             else:
                 V = tl.load(V_ptr + V_offset, mask=kvmask[:, None], cache_modifier=".cg")
         O_ = tl.dot(P, V, O_, allow_tf32=False)
-        # if n_masking_steps > 1 and n_block <= n_block_min:
+        # if n_masking_blocks > 1 and n_block <= n_block_min:
         #     break
 
-    for n_block in tl.range(n_block_max - n_masking_steps - 1, n_block_min - 1, step=-1, num_stages=num_stages):
+    for n_block in tl.range(n_block_max - n_masking_blocks - 1, n_block_min - 1, step=-1, num_stages=num_stages):
         col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
         K_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None]
         K = tl.load(K_ptr + K_offset, cache_modifier=".cg")
@@ -1425,7 +1451,7 @@ def splits_heuristics(num_tasks, num_sms, n_blocks):
             tmp_lse = lse
             tmp_out = out
 
-        kernel(
+        kernel = kernel(
             q,
             k,
             v,
@@ -1468,6 +1494,8 @@ def splits_heuristics(num_tasks, num_sms, n_blocks):
             NUM_HEADS=num_heads,
             NUM_HEADS_K=num_heads_k,
         )
+        print(f'{kernel.name} shared memory:', kernel.metadata.shared)
+        # print(kernel.asm['ttgir'])
     
     if n_splits > 1:
         if head_size % 128 == 0:
@@ -1477,7 +1505,7 @@ def splits_heuristics(num_tasks, num_sms, n_blocks):
         else:
             BLOCK_M = 16
         grid = lambda args: (triton.cdiv(batch_size * num_heads * seqlen_q, BLOCK_M), )
-        flash_fwd_splitkv_combine_kernel[grid](
+        kernel = flash_fwd_splitkv_combine_kernel[grid](
             out,
             lse,
             tmp_out,
@@ -1491,6 +1519,7 @@ def splits_heuristics(num_tasks, num_sms, n_blocks):
             q_total=batch_size * num_heads * seqlen_q,
             MAX_N_SPLITS=triton.next_power_of_2(n_splits),
         )
+        print(f'{kernel.name} shared memory:', kernel.metadata.shared)
 
     if swap_seq_and_group:
         out = out.transpose(1, 2).reshape((batch_size, 1, num_heads_k * seqlen_q, head_size))

From af24ad2baaee60f483d090bc0a3c2e79377fa81f Mon Sep 17 00:00:00 2001
From: Tongxin Bai <waffle.bai@gmail.com>
Date: Tue, 18 Mar 2025 04:08:19 +0000
Subject: [PATCH 12/25] bug fix.

---
 src/flag_gems/ops/attention.py | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py
index 756ca2571..c8bfdc09a 100644
--- a/src/flag_gems/ops/attention.py
+++ b/src/flag_gems/ops/attention.py
@@ -744,7 +744,7 @@ def flash_fwd_kernel(
         Q = tl.load(Q_ptr + Q_off, cache_modifier='.cg')
     else:
         Q = tl.load(Q_ptr + Q_off, mask=qmask, cache_modifier='.cg')
-    
+
     if return_P:
         P_ptr += ((bid * NUM_HEADS + hid) * seqlen_q_rounded + m_block * BLOCK_M) * seqlen_k_rounded
         P_ptr += (n_block_max - 1) * BLOCK_N
@@ -754,9 +754,6 @@ def flash_fwd_kernel(
     rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
     rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32)
 
-    # Start from the right most block
-    n_block = n_block_max - 1
-
     h_hk_ratio = h // hk
     K_ptr += bid * k_b_stride
     K_ptr += (hid // h_hk_ratio) * k_h_stride
@@ -857,7 +854,7 @@ def flash_fwd_kernel(
                 V = tl.load(p_bv0 + off, mask=kvmask[:, None], cache_modifier=".cg")
         O_ = tl.dot(P, V, O_, allow_tf32=False)
 
-    for col_start in tl.range(min_col, max_col - masking_cols, step=BLOCK_N, num_stages=num_stages):
+    for col_start in tl.range(col_min, col_max - masking_cols, step=BLOCK_N, num_stages=num_stages):
     # for r_blk_idx in tl.range(n_block_max - n_masking_blocks - 1, n_block_min - 1, step=-1, num_stages=num_stages):
         off = col_start * k_s_stride
         K = tl.load(p_bk0 + off, cache_modifier=".cg")

From 1025a3d768e5806504faa5a10de606e39d6a3adc Mon Sep 17 00:00:00 2001
From: Tongxin Bai <waffle.bai@gmail.com>
Date: Wed, 19 Mar 2025 12:34:35 +0000
Subject: [PATCH 13/25] Pipeline works, causal passes.

---
 src/flag_gems/ops/attention.py                | 67 +++++++++----------
 .../runtime/backend/_nvidia/tune_configs.yaml |  3 +-
 2 files changed, 33 insertions(+), 37 deletions(-)

diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py
index c8bfdc09a..07c48b87e 100644
--- a/src/flag_gems/ops/attention.py
+++ b/src/flag_gems/ops/attention.py
@@ -602,12 +602,10 @@ def softmax_rescale(
 
 
 def block_m_heuristic(headdim, is_dropout):
-    # return 128 if headdim <= 128 else 64
-    return 64
+    return 128 if headdim <= 128 else 64
 
 def block_n_heuristic(headdim, is_dropout):
-    # return 128 if headdim <= 128 else 64
-    return 64
+    return 64 if headdim <= 64 else 32 
 
 def block_m_splitkv_heuristic(headdim):
     return 128 if headdim <= 128 else 64
@@ -634,12 +632,12 @@ def block_n_splitkv_heuristic(headdim):
         'BLOCK_M': lambda args: block_m_heuristic(args["HEAD_DIM"], args["is_dropout"]),
         'BLOCK_N': lambda args: block_n_heuristic(args["HEAD_DIM"], args["is_dropout"]),
         'num_warps': lambda args: 4,
-        'num_stages': lambda args: 3,
+        'num_stages': lambda args: 3 if args["HEAD_DIM"] <= 128 else 2,
         'PRE_LOAD_V': lambda args: True,
         'IS_EVEN_MN': lambda args: (args["seqlen_q"] % args["BLOCK_M"] == 0) and (args["seqlen_k"] % args["BLOCK_N"] == 0),
     }
 )
-@triton.jit(do_not_specialize=["seqlen_q", "seqlen_k", "q_b_stride", "k_b_stride", "o_b_stride", "philox_seed", "philox_offset", "pdrop_u8", "slopes_batch_stride"])
+@triton.jit(do_not_specialize=["seqlen_q", "seqlen_k", "philox_seed", "philox_offset"])
 def flash_fwd_kernel(
     Q_ptr,
     K_ptr,
@@ -660,8 +658,8 @@ def flash_fwd_kernel(
     o_b_stride,
     o_s_stride,
     o_h_stride,
-    h,
-    hk,
+    h: tl.constexpr,
+    hk: tl.constexpr,
     pSlopes,
     philox_seed,
     philox_offset,
@@ -703,37 +701,16 @@ def flash_fwd_kernel(
     if is_causal or is_local:
         col_max = min(col_max, (m_block + 1) * BLOCK_M + seqlen_k - seqlen_q + ws_right)
 
-    # if is_local:
-    #     n_block_min = max(0, (m_block * BLOCK_M + seqlen_k - seqlen_q - ws_left) / BLOCK_N)
-    # else:
-    #     n_block_min = 0 
-
-    # n_block_max = tl.cdiv(seqlen_k, BLOCK_N)
-    
-    # if is_causal or is_local:
-    #     n_block_max = min(n_block_max,
-    #                       tl.cdiv((m_block + 1) * BLOCK_M + seqlen_k - seqlen_q + ws_right, BLOCK_N))
-
     if has_alibi:
         alibi_offset = bid * slopes_batch_stride + hid
         alibi_slope = tl.load(pSlopes + alibi_offset)
         alibi_slope /= scale
     else:
         alibi_slope = 0.0
-
-    if (not is_causal) and (not is_local):
-        if IS_EVEN_MN:
-            n_masking_blocks: tl.constexpr = 0
-        else:
-            n_masking_blocks: tl.constexpr = 1
-    elif is_causal and IS_EVEN_MN: # causal implies ws_right is zero
-        n_masking_blocks: tl.constexpr = tl.cdiv(BLOCK_M, BLOCK_N)
-    else:
-        # local and not causal, 
-        n_masking_blocks: tl.constexpr = tl.cdiv(BLOCK_M, BLOCK_N) + 1
     
-    masking_cols = n_masking_blocks * BLOCK_N
+    # masking_cols: tl.constexpr = n_masking_blocks * BLOCK_N
 
+    q_b_stride = tl.multiple_of(q_b_stride, HEAD_DIM * h)
     Q_ptr += bid * q_b_stride
     Q_ptr += hid * q_h_stride
     row_start = m_block * BLOCK_M
@@ -754,6 +731,7 @@ def flash_fwd_kernel(
     rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
     rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32)
 
+    k_b_stride = tl.multiple_of(k_b_stride, HEAD_DIM * hk)
     h_hk_ratio = h // hk
     K_ptr += bid * k_b_stride
     K_ptr += (hid // h_hk_ratio) * k_h_stride
@@ -766,8 +744,24 @@ def flash_fwd_kernel(
     p_bk0 = K_ptr + K_offset
     p_bv0 = V_ptr + V_offset
 
-    for col_start in tl.range(max(col_min, col_max - masking_cols), col_max, step=BLOCK_N):
-    # for r_blk_idx in tl.range(0, min(n_masking_blocks, n_blocks_max - n_blocks_min)):
+    if (not is_causal) and (not is_local):
+        if IS_EVEN_MN:
+            # n_masking_blocks: tl.constexpr = 0
+            masking_cols: tl.constexpr = 0
+        else:
+            # n_masking_blocks: tl.constexpr = 1
+            masking_cols: tl.constexpr = BLOCK_N
+    elif is_causal and IS_EVEN_MN: # causal implies ws_right is zero
+        # n_masking_blocks: tl.constexpr = tl.cdiv(BLOCK_M, BLOCK_N)
+        masking_cols: tl.constexpr = tl.cdiv(BLOCK_M, BLOCK_N) * BLOCK_N
+    else:
+        # local and not causal, 
+        # n_masking_blocks: tl.constexpr = tl.cdiv(BLOCK_M, BLOCK_N) + 1
+        masking_cols: tl.constexpr = (tl.cdiv(BLOCK_M, BLOCK_N) + 1) * BLOCK_N
+
+    for col_shift in tl.range(0, masking_cols, step=BLOCK_N):
+        col_start = col_max - col_shift - BLOCK_N
+        col_start = tl.multiple_of(col_start, BLOCK_N)
         off = col_start * k_s_stride
         if IS_EVEN_MN:
             K = tl.load(p_bk0 + off, cache_modifier=".cg")
@@ -796,7 +790,6 @@ def flash_fwd_kernel(
             has_alibi=has_alibi
         )
 
-        # is_init = (n_block == n_block_max - 1).to(tl.int1)
         O_, P, rowmax_, rowsum_ = softmax_rescale(
             O_,
             S,
@@ -804,7 +797,6 @@ def flash_fwd_kernel(
             rowsum_,
             softmax_scale_log2e=softmax_scale_log2e,
             is_border=(is_causal or is_local),
-            # is_init=is_init
         )
         P = P.to(O_ptr.type.element_ty)
 
@@ -856,6 +848,7 @@ def flash_fwd_kernel(
 
     for col_start in tl.range(col_min, col_max - masking_cols, step=BLOCK_N, num_stages=num_stages):
     # for r_blk_idx in tl.range(n_block_max - n_masking_blocks - 1, n_block_min - 1, step=-1, num_stages=num_stages):
+        # col_start = tl.multiple_of(col_start, BLOCK_N)
         off = col_start * k_s_stride
         K = tl.load(p_bk0 + off, cache_modifier=".cg")
         # if PRE_LOAD_V:
@@ -882,7 +875,6 @@ def flash_fwd_kernel(
             has_alibi=has_alibi
         )
 
-        # is_init = (n_block == n_block_max - 1).to(tl.int1)
         O_, P, rowmax_, rowsum_ = softmax_rescale(
             O_,
             S,
@@ -951,6 +943,7 @@ def flash_fwd_kernel(
     O = O_.to(O_ptr.type.element_ty)
 
     # Write back output
+    o_b_stride = tl.multiple_of(o_b_stride, HEAD_DIM * h)
     O_ptr += bid * o_b_stride
     O_ptr += hid * o_h_stride
     O_offset = row_idx[:, None] * o_s_stride + tl.arange(0, HEAD_DIM)
@@ -1492,6 +1485,8 @@ def splits_heuristics(num_tasks, num_sms, n_blocks):
             NUM_HEADS_K=num_heads_k,
         )
         print(f'{kernel.name} shared memory:', kernel.metadata.shared)
+        print(f'{kernel.name} num_warps:', kernel.metadata.num_warps)
+        print(f'{kernel.name} num_stages:', kernel.metadata.num_stages)
         # print(kernel.asm['ttgir'])
     
     if n_splits > 1:
diff --git a/src/flag_gems/runtime/backend/_nvidia/tune_configs.yaml b/src/flag_gems/runtime/backend/_nvidia/tune_configs.yaml
index b4865b3c1..3753f67d2 100644
--- a/src/flag_gems/runtime/backend/_nvidia/tune_configs.yaml
+++ b/src/flag_gems/runtime/backend/_nvidia/tune_configs.yaml
@@ -8,12 +8,13 @@ attention:
     num_warps: warps
     num_stages: stages
   block_m:
+  - 32
   - 64
   - 128
   block_n:
   - 32
   - 64
-  - 128
+  # - 128
   pre_load_v:
   - true
   - false

From 12f105dbd5e91e1c707da6b21527981276b879a2 Mon Sep 17 00:00:00 2001
From: Tongxin Bai <waffle.bai@gmail.com>
Date: Thu, 20 Mar 2025 04:34:50 +0000
Subject: [PATCH 14/25] a couple of fixes, nonequal q k seqlens still broken.

---
 src/flag_gems/ops/attention.py | 18 ++----------------
 1 file changed, 2 insertions(+), 16 deletions(-)

diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py
index 07c48b87e..b3363ddcf 100644
--- a/src/flag_gems/ops/attention.py
+++ b/src/flag_gems/ops/attention.py
@@ -707,8 +707,6 @@ def flash_fwd_kernel(
         alibi_slope /= scale
     else:
         alibi_slope = 0.0
-    
-    # masking_cols: tl.constexpr = n_masking_blocks * BLOCK_N
 
     q_b_stride = tl.multiple_of(q_b_stride, HEAD_DIM * h)
     Q_ptr += bid * q_b_stride
@@ -746,17 +744,13 @@ def flash_fwd_kernel(
 
     if (not is_causal) and (not is_local):
         if IS_EVEN_MN:
-            # n_masking_blocks: tl.constexpr = 0
             masking_cols: tl.constexpr = 0
         else:
-            # n_masking_blocks: tl.constexpr = 1
             masking_cols: tl.constexpr = BLOCK_N
     elif is_causal and IS_EVEN_MN: # causal implies ws_right is zero
-        # n_masking_blocks: tl.constexpr = tl.cdiv(BLOCK_M, BLOCK_N)
         masking_cols: tl.constexpr = tl.cdiv(BLOCK_M, BLOCK_N) * BLOCK_N
     else:
         # local and not causal, 
-        # n_masking_blocks: tl.constexpr = tl.cdiv(BLOCK_M, BLOCK_N) + 1
         masking_cols: tl.constexpr = (tl.cdiv(BLOCK_M, BLOCK_N) + 1) * BLOCK_N
 
     for col_shift in tl.range(0, masking_cols, step=BLOCK_N):
@@ -847,16 +841,12 @@ def flash_fwd_kernel(
         O_ = tl.dot(P, V, O_, allow_tf32=False)
 
     for col_start in tl.range(col_min, col_max - masking_cols, step=BLOCK_N, num_stages=num_stages):
-    # for r_blk_idx in tl.range(n_block_max - n_masking_blocks - 1, n_block_min - 1, step=-1, num_stages=num_stages):
         # col_start = tl.multiple_of(col_start, BLOCK_N)
         off = col_start * k_s_stride
         K = tl.load(p_bk0 + off, cache_modifier=".cg")
-        # if PRE_LOAD_V:
-        #     V = tl.load(V_ptr + V_offset, cache_modifier=".cg")
-        S = tl.dot(Q, K)
-
         if PRE_LOAD_V:
-            V = tl.load(p_bv0 + off, cache_modifier=".cg")
+            V = tl.load(V_ptr + V_offset, cache_modifier=".cg")
+        S = tl.dot(Q, K)
 
         col_idx = col_start + tl.arange(0, BLOCK_N)
         row_idx = row_start + tl.arange(0, BLOCK_M)
@@ -1113,7 +1103,6 @@ def flash_fwd_splitkv_kernel(
         )
         # col_idx -= BLOCK_N
 
-        is_init = (n_block == n_block_max - 1).to(tl.int1)
         O_, P, rowmax_, rowsum_ = softmax_rescale(
             O_,
             S,
@@ -1121,7 +1110,6 @@ def flash_fwd_splitkv_kernel(
             rowsum_,
             softmax_scale_log2e=softmax_scale_log2e,
             is_border=(is_causal or is_local),
-            is_init=is_init
         )
         P = P.to(Q_ptr.type.element_ty)
 
@@ -1163,7 +1151,6 @@ def flash_fwd_splitkv_kernel(
         )
         # col_idx -= BLOCK_N
 
-        is_init = (n_block == n_block_max - 1).to(tl.int1)
         O_, P, rowmax_, rowsum_ = softmax_rescale(
             O_,
             S,
@@ -1171,7 +1158,6 @@ def flash_fwd_splitkv_kernel(
             rowsum_,
             softmax_scale_log2e=softmax_scale_log2e,
             is_border=is_local,
-            is_init=is_init
         )
 
         P = P.to(Q_ptr.type.element_ty)

From a33ff44275784588d446b3a81a7c715188c0ed2a Mon Sep 17 00:00:00 2001
From: Tongxin Bai <waffle.bai@gmail.com>
Date: Thu, 20 Mar 2025 15:57:47 +0000
Subject: [PATCH 15/25] Causal results are stable now. Consistent with
 aten._flash_attention_forward.

---
 src/flag_gems/ops/attention.py | 82 ++++++++++++++++------------------
 1 file changed, 39 insertions(+), 43 deletions(-)

diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py
index b3363ddcf..54248c4a1 100644
--- a/src/flag_gems/ops/attention.py
+++ b/src/flag_gems/ops/attention.py
@@ -546,18 +546,16 @@ def apply_mask(
     need_mask: tl.constexpr = is_causal | has_alibi | is_local | (not is_even_mn)
     if need_mask:
         col_lb = max(0, row_idx + max_seqlen_k - max_seqlen_q - ws_left)
-        col_rb = min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + ws_right)
+        col_rb = min(max_seqlen_k, row_idx + max_seqlen_k - max_seqlen_q + ws_right)
 
-        if not has_alibi:
-            alibi_slope = .0
-
-        S -= alibi_slope * tl.abs(col_idx[None, :] - row_idx[:, None])
+        if has_alibi:
+            S -= alibi_slope * tl.abs(col_idx[None, :] - row_idx[:, None])
 
         if is_causal:
-            S = tl.where(col_idx[None, :] >= col_rb[:, None], float('-inf'), S)
+            S = tl.where(col_idx[None, :] > col_rb[:, None], float('-inf'), S)
 
         if is_local:
-            S = tl.where(col_idx[None, :] >= col_rb[:, None] | col_idx[None, :] < col_lb[:, None], float('-inf'), S)
+            S = tl.where(col_idx[None, :] > col_rb[:, None] | col_idx[None, :] < col_lb[:, None], float('-inf'), S)
         
         if (not is_local) & (not is_causal) & (not is_even_mn):
             S = tl.where(col_idx[None, :] >= max_seqlen_k, float('-inf'), S)
@@ -578,34 +576,36 @@ def softmax_rescale(
     prev_max = row_max
     row_max = tl.maximum(row_max, tl.max(S, 1))
 
-    # if not is_init:
-    #     if is_border:
-    #         cur_max = tl.where(row_max == float('-inf'), 0, row_max)
-    #     else:
-    #         cur_max = row_max
-    #     p_scale = tl.math.exp2((prev_max - cur_max) * softmax_scale_log2e)
-    #     row_sum *= p_scale
-    #     O_acc *= p_scale[:, None]
-
     if is_border:
         cur_max = tl.where(row_max == float('-inf'), 0, row_max)
     else:
         cur_max = row_max
+
     p_scale = tl.math.exp2((prev_max - cur_max) * softmax_scale_log2e)
     row_sum *= p_scale
     O_acc *= p_scale[:, None]
 
     max_scaled = tl.where(row_max == float('-inf'), 0, row_max * softmax_scale_log2e)
+
     P = tl.math.exp2(S * softmax_scale_log2e - max_scaled[:, None])
     row_sum = row_sum + tl.sum(P, 1)
     return O_acc, P, row_max, row_sum
 
 
 def block_m_heuristic(headdim, is_dropout):
-    return 128 if headdim <= 128 else 64
+    block_m = 128 if headdim <= 128 else 64
+    print('block_m:', block_m)
+    return block_m
 
 def block_n_heuristic(headdim, is_dropout):
-    return 64 if headdim <= 64 else 32 
+    block_n = 64 if headdim <= 64 else 32
+    print('block_n:', block_n)
+    return block_n
+
+def is_even_mn(args):
+    even_mn = (args["seqlen_q"] % args["BLOCK_M"] == 0) and (args["seqlen_k"] % args["BLOCK_N"] == 0)
+    print('is_even_mn:', even_mn)
+    return even_mn
 
 def block_m_splitkv_heuristic(headdim):
     return 128 if headdim <= 128 else 64
@@ -633,8 +633,8 @@ def block_n_splitkv_heuristic(headdim):
         'BLOCK_N': lambda args: block_n_heuristic(args["HEAD_DIM"], args["is_dropout"]),
         'num_warps': lambda args: 4,
         'num_stages': lambda args: 3 if args["HEAD_DIM"] <= 128 else 2,
-        'PRE_LOAD_V': lambda args: True,
-        'IS_EVEN_MN': lambda args: (args["seqlen_q"] % args["BLOCK_M"] == 0) and (args["seqlen_k"] % args["BLOCK_N"] == 0),
+        'PRE_LOAD_V': lambda args: False,
+        'IS_EVEN_MN': lambda args: is_even_mn(args),
     }
 )
 @triton.jit(do_not_specialize=["seqlen_q", "seqlen_k", "philox_seed", "philox_offset"])
@@ -762,7 +762,8 @@ def flash_fwd_kernel(
             if PRE_LOAD_V:
                 V = tl.load(p_bv0 + off, cache_modifier=".cg")
         else:
-            kvmask = col < seqlen_k
+            col_idx = col_start + tl.arange(0, BLOCK_N)
+            kvmask = col_idx < seqlen_k
             K = tl.load(p_bk0 + off, mask=kvmask[None, :], cache_modifier=".cg")
             if PRE_LOAD_V:
                 V = tl.load(p_bv0 + off, mask=kvmask[:, None], cache_modifier=".cg")
@@ -845,7 +846,7 @@ def flash_fwd_kernel(
         off = col_start * k_s_stride
         K = tl.load(p_bk0 + off, cache_modifier=".cg")
         if PRE_LOAD_V:
-            V = tl.load(V_ptr + V_offset, cache_modifier=".cg")
+            V = tl.load(p_bv0 + off, cache_modifier=".cg")
         S = tl.dot(Q, K)
 
         col_idx = col_start + tl.arange(0, BLOCK_N)
@@ -919,12 +920,13 @@ def flash_fwd_kernel(
 
         O_ = tl.dot(P, V, O_)
 
+
     # LSE
     # Note, rowsum = exp(-rowmax) * lse, therefore rowmax + log(rowsum) cancels the effect of rowmax and outputs lse only.
     lse = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), float('inf'), rowmax_ * softmax_scale + tl.log(rowsum_))
-    inv_sum = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), 1.0, 1.0 / rowsum_)
-    
+
     # Rescale output
+    inv_sum = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), 1.0, 1.0 / rowsum_)
     if is_dropout:
         O_ *= inv_sum[:, None] * rpdrop
     else:
@@ -1101,7 +1103,6 @@ def flash_fwd_splitkv_kernel(
             is_local=is_local,
             has_alibi=has_alibi
         )
-        # col_idx -= BLOCK_N
 
         O_, P, rowmax_, rowsum_ = softmax_rescale(
             O_,
@@ -1119,21 +1120,16 @@ def flash_fwd_splitkv_kernel(
             else:
                 V = tl.load(V_ptr + V_offset, mask=kvmask[:, None], cache_modifier=".cg")
         O_ = tl.dot(P, V, O_, allow_tf32=False)
-        # if n_masking_blocks > 1 and n_block <= n_block_min:
-        #     break
+
 
     for n_block in tl.range(n_block_max - n_masking_blocks - 1, n_block_min - 1, step=-1, num_stages=num_stages):
         col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
         K_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None]
         K = tl.load(K_ptr + K_offset, cache_modifier=".cg")
-        # if PRE_LOAD_V:
-        #     V_offset = col_idx[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :]
-        #     V = tl.load(V_ptr + V_offset, cache_modifier=".cg")
-        S = tl.dot(Q, K)
-
         if PRE_LOAD_V:
             V_offset = col_idx[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :]
             V = tl.load(V_ptr + V_offset, cache_modifier=".cg")
+        S = tl.dot(Q, K)
 
         S = apply_mask(
             S,
@@ -1149,7 +1145,6 @@ def flash_fwd_splitkv_kernel(
             is_local=is_local,
             has_alibi=has_alibi
         )
-        # col_idx -= BLOCK_N
 
         O_, P, rowmax_, rowsum_ = softmax_rescale(
             O_,
@@ -1246,8 +1241,7 @@ def flash_fwd_splitkv_combine_kernel(
     out_splits = tl.load(out_splits_ptr + out_split_offset, mask=out_split_mask, other=0)
     out = tl.sum(Zi_Z[:, :, None] * out_splits, 1)
     out = out.to(out_ptr.type.element_ty)
-    
-    # tl.device_print('O', out)
+
     # Write back output
     out_offset = tl.arange(0, BLOCK_M)[:, None] * out_s_stride + tl.arange(0, head_size)
     tl.store(out_ptr + out_offset, out, mask=out_mask[:, None])
@@ -1364,16 +1358,18 @@ def splits_heuristics(num_tasks, num_sms, n_blocks):
         rpdrop = 1. / p_dropout
 
         # Check splitkv
-        if not is_dropout:
-            n_tasks = batch_size * num_heads * triton.cdiv(seqlen_q, block_m_splitkv_heuristic(head_size))
+        def try_split_kv():
+            block_m = block_m_splitkv_heuristic(head_size)
+            n_tasks = batch_size * num_heads * triton.cdiv(seqlen_q, block_m)
             num_sms = torch_device_fn.get_device_properties("cuda").multi_processor_count
-            n_blocks = triton.cdiv(seqlen_k, block_n_splitkv_heuristic(head_size))
+            block_n = block_n_splitkv_heuristic(head_size)
+            n_blocks = triton.cdiv(seqlen_k, block_n)
             n_splits = splits_heuristics(n_tasks, num_sms, n_blocks)
-            print('n_blocks:', n_blocks)
-            print('n_splits:', n_splits)
-        else:
-            n_splits = 1
-        
+            return n_splits
+            
+        n_splits = try_split_kv() if is_dropout else 1
+        print('n_splits:', n_splits)
+
         if n_splits > 1:
             lse_splits = torch.empty(
                 (n_splits, batch_size, num_heads, seqlen_q),

From 04e49e00b458ed16fd1179187101e5ffb9a5e188 Mon Sep 17 00:00:00 2001
From: Tongxin Bai <waffle.bai@gmail.com>
Date: Fri, 21 Mar 2025 16:41:07 +0000
Subject: [PATCH 16/25] Dropout passes.

---
 src/flag_gems/ops/attention.py      | 93 +++++++++++++----------------
 src/flag_gems/utils/random_utils.py |  7 +++
 2 files changed, 49 insertions(+), 51 deletions(-)

diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py
index 54248c4a1..b0f866ca1 100644
--- a/src/flag_gems/ops/attention.py
+++ b/src/flag_gems/ops/attention.py
@@ -422,24 +422,24 @@ def philox_(seed, subsequence, offset):
     kPhilox10B: tl.constexpr = 0xBB67AE85
     k0, k1 = u64_to_lohi(seed.to(tl.uint64))
     c0, c1 = u64_to_lohi(offset.to(tl.uint64))
-    c2, c3 = u64_to_lohi(subsequence(tl.uint64))
+    c2, c3 = u64_to_lohi(subsequence.to(tl.uint64))
 
     # pragma unroll
     kPhiloxSA: tl.constexpr = 0xD2511F53
     kPhiloxSB: tl.constexpr = 0xCD9E8D57
     for _ in range(6):
-        res0 = kPhiloxSA.to(tl.uint64) * c0.to(tl.uint64)
-        res1 = kPhiloxSB.to(tl.uint64) * c2.to(tl.uint64)
+        res0 = kPhiloxSA * c0.to(tl.uint64)
+        res1 = kPhiloxSB * c2.to(tl.uint64)
         res0_x, res0_y = u64_to_lohi(res0)
         res1_x, res1_y = u64_to_lohi(res1)
         c0, c1, c2, c3 = res1_y ^ c1 ^ k0, res1_x, res0_y ^ c3 ^ k1, res0_x
         k0 += kPhilox10A
         k1 += kPhilox10B
 
-    res0 = kPhiloxSA.to(tl.uint64) * c0.to(tl.uint64)
-    res1 = kPhiloxSB.to(tl.uint64) * c2.to(tl.uint64)
-    res0_x.res0_y = u64_to_lohi(res0)
-    res1_x.res1_y = u64_to_lohi(res1)
+    res0 = kPhiloxSA * c0.to(tl.uint64)
+    res1 = kPhiloxSB * c2.to(tl.uint64)
+    res0_x, res0_y = u64_to_lohi(res0)
+    res1_x, res1_y = u64_to_lohi(res1)
     c0, c1, c2, c3 = res1_y ^ c1 ^ k0, res1_x, res0_y ^ c3 ^ k1, res0_x
 
     return c0, c1, c2, c3
@@ -462,35 +462,30 @@ def apply_dropout_mask(
 def make_4x_dropout_mask(r_u32, p_u8, M: tl.constexpr, N: tl.constexpr):
     r = r_u32
     p = p_u8
-    m0 = tl.where(r & 0xFF < p, 0, 1)
+    # m0 = tl.where(r & 0xFF < p, 0, 1)
+    m0 = ~(r & 0xFF < p)
     r >>= 8
-    m1 = tl.where(r & 0xFF < p, 0, 1)
-    m0 = tl.join(m0, m1).trans(2, 0, 1).reshape(2 * M, N)
+    # m1 = tl.where(r & 0xFF < p, 0, 1)
+    m1 = ~(r & 0xFF < p)
+    m = tl.join(m0, m1).trans(2, 0, 1).reshape(2 * M, N)
 
     r >>= 8
-    m0 = tl.where(r & 0xFF < p, 0, 1)
+    # n0 = tl.where(r & 0xFF < p, 0, 1)
+    n0 = ~(r & 0xFF < p)
     r >>= 8
-    m1 = tl.where(r & 0xFF < p, 0, 1)
-    m1 = tl.join(m0, m1).trans(2, 0, 1).reshape(2 * M, N)
-
-    m = tl.join(m0, m1).trans(2, 0, 1).reshape(4 * M, N)
-    return m
-
-
-@triton.jit(
-    do_not_specialize=[
-        "b",
-        "h",
-        "row_start",
-        "col_start",
-        "philox_seed",
-        "philox_offset",
-    ]
-)
+    # n1 = tl.where(r & 0xFF < p, 0, 1)
+    n1 = ~(r & 0xFF < p)
+    n = tl.join(n0, n1).trans(2, 0, 1).reshape(2 * M, N)
+
+    mn = tl.join(m, n).trans(2, 0, 1).reshape(4 * M, N)
+    return mn
+
+
+@triton.jit
 def apply_dropout(
     P,
-    sor,
-    soc,
+    row_start,
+    col_start,
     bid,
     hid,
     philox_seed,
@@ -501,30 +496,31 @@ def apply_dropout(
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
 ):
-    # P is of size (BLOCK_M, BLOCK_N) and its scalar bitsize is 32
-    # BLOCK_M is ensured to be a multiple of 16, BLOCK_N a multiple of 32
+    # We only need one philox call for every 16 rows because a single philox call
+    # generates 4 random uints, which are casted for 16 random draws in uint8's.
     M: tl.constexpr = BLOCK_M // 16
     N: tl.constexpr = BLOCK_N // 32
-    row = sor + tl.arange(0, M)[:, None]
-    col = soc + tl.arange(0, BLOCK_N)[None, :] // 32
+    row = row_start // 16 + tl.arange(0, M)[:, None]
+    col = col_start + tl.arange(0, BLOCK_N)[None, :]
+
+    subsequence = u64_from_lohi(row, col // 32)
 
     tid = tl.arange(0, BLOCK_N)[None, :] % 32
     philox_offset += (bid * NUM_HEADS + hid) * 32 + tid
-
-    subsequence = u64_from_lohi(row * 32, col)
+    philox_offset += subsequence * 0 
     r0, r1, r2, r3 = philox_(philox_seed, subsequence, philox_offset)
 
     # Fully unrolled due to triton's inability to concat 2d tensor
-    m0 = make_4x_dropout_mask(r0, p_dropout_uint8, M, N)
-    m1 = make_4x_dropout_mask(r1, p_dropout_uint8, M, N)
-    m0 = tl.join(m0, m1).trans(2, 0, 1).reshape(8 * M, N)
+    m0 = make_4x_dropout_mask(r0, p_dropout_uint8, M, BLOCK_N)
+    m1 = make_4x_dropout_mask(r1, p_dropout_uint8, M, BLOCK_N)
+    m = tl.join(m0, m1).trans(2, 0, 1).reshape(8 * M, BLOCK_N)
 
-    m0 = make_4x_dropout_mask(r0, p_dropout_uint8, M, N)
-    m1 = make_4x_dropout_mask(r1, p_dropout_uint8, M, N)
-    m1 = tl.join(m0, m1).trans(2, 0, 1).reshape(8 * M, N)
+    n0 = make_4x_dropout_mask(r0, p_dropout_uint8, M, BLOCK_N)
+    n1 = make_4x_dropout_mask(r1, p_dropout_uint8, M, BLOCK_N)
+    n = tl.join(n0, n1).trans(2, 0, 1).reshape(8 * M, BLOCK_N)
 
-    m = tl.join(m0, m1).trans(2, 0, 1).reshape(16 * M, N)
-    P = apply_dropout_mask(P, m)
+    mn = tl.join(m, n).trans(2, 0, 1).reshape(16 * M, BLOCK_N)
+    P = apply_dropout_mask(P, mn, encode_dropout_in_sign_bit=encode_dropout_in_sign_bit)
     return P
 
 
@@ -796,11 +792,9 @@ def flash_fwd_kernel(
         P = P.to(O_ptr.type.element_ty)
 
         if is_dropout:
-            row_start = m_block * (BLOCK_M // 16)
-            col_start = n_block * (BLOCK_N // 32)
-
             if return_P:
                 P_drop = P
+
                 P_drop = apply_dropout(
                     P_drop,
                     row_start,
@@ -877,9 +871,6 @@ def flash_fwd_kernel(
         P = P.to(O_ptr.type.element_ty)
 
         if is_dropout:
-            row_start = m_block * (BLOCK_M // 16)
-            col_start = n_block * (BLOCK_N // 32)
-        
             if return_P:
                 P_drop = P
                 P_drop = apply_dropout(
@@ -1346,7 +1337,7 @@ def splits_heuristics(num_tasks, num_sms, n_blocks):
 
         # Set dropout params
         if p_dropout > 0:
-            increment = triton.cdiv(batch_size * num_heads * 32)
+            increment = batch_size * num_heads * 32
             philox_seed, philox_offset = update_philox_state(increment)
             is_dropout = True
         else:
diff --git a/src/flag_gems/utils/random_utils.py b/src/flag_gems/utils/random_utils.py
index 5ca61a8f4..ec0f771b5 100644
--- a/src/flag_gems/utils/random_utils.py
+++ b/src/flag_gems/utils/random_utils.py
@@ -46,6 +46,13 @@ def update_philox_state(increment, device=None):
     gen.set_state(state_copy)
     return seed, offset
 
+def set_philox_state(seed, offset, device=None):
+    device = device or torch_device_fn.current_device()
+    gen = torch_device_fn.default_generators[device]
+    assert offset % 4 == 0
+    new_state = torch.tensor((seed, offset), dtype=torch.int64)
+    gen = get.set_state(new_state.view(torch.uint8))
+    return
 
 def per_thread_offset(N, num_blocks, num_warps, warp_threads=32):
     block_threads = num_warps * warp_threads

From 2ee91ea76cf6bd3c6634b7cae76be1c0dbbe5e55 Mon Sep 17 00:00:00 2001
From: Tongxin Bai <waffle.bai@gmail.com>
Date: Fri, 21 Mar 2025 16:45:36 +0000
Subject: [PATCH 17/25] dropout disables  splitkv.

---
 src/flag_gems/ops/attention.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py
index b0f866ca1..c686d6230 100644
--- a/src/flag_gems/ops/attention.py
+++ b/src/flag_gems/ops/attention.py
@@ -1358,7 +1358,7 @@ def try_split_kv():
             n_splits = splits_heuristics(n_tasks, num_sms, n_blocks)
             return n_splits
             
-        n_splits = try_split_kv() if is_dropout else 1
+        n_splits = try_split_kv() if not is_dropout else 1
         print('n_splits:', n_splits)
 
         if n_splits > 1:

From 1b7e6c0e2577da13d43363775c509f3a6c0ef977 Mon Sep 17 00:00:00 2001
From: Tongxin Bai <waffle.bai@gmail.com>
Date: Sun, 23 Mar 2025 05:34:31 +0000
Subject: [PATCH 18/25] Working on splitkv..

---
 src/flag_gems/ops/attention.py | 217 +++++++++++++++++----------------
 1 file changed, 114 insertions(+), 103 deletions(-)

diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py
index c686d6230..8601e0bd7 100644
--- a/src/flag_gems/ops/attention.py
+++ b/src/flag_gems/ops/attention.py
@@ -427,7 +427,7 @@ def philox_(seed, subsequence, offset):
     # pragma unroll
     kPhiloxSA: tl.constexpr = 0xD2511F53
     kPhiloxSB: tl.constexpr = 0xCD9E8D57
-    for _ in range(6):
+    for _ in tl.static_range(6):
         res0 = kPhiloxSA * c0.to(tl.uint64)
         res1 = kPhiloxSB * c2.to(tl.uint64)
         res0_x, res0_y = u64_to_lohi(res0)
@@ -462,22 +462,17 @@ def apply_dropout_mask(
 def make_4x_dropout_mask(r_u32, p_u8, M: tl.constexpr, N: tl.constexpr):
     r = r_u32
     p = p_u8
-    # m0 = tl.where(r & 0xFF < p, 0, 1)
     m0 = ~(r & 0xFF < p)
     r >>= 8
-    # m1 = tl.where(r & 0xFF < p, 0, 1)
     m1 = ~(r & 0xFF < p)
-    m = tl.join(m0, m1).trans(2, 0, 1).reshape(2 * M, N)
+    m = tl.join(m0, m1)
 
     r >>= 8
-    # n0 = tl.where(r & 0xFF < p, 0, 1)
     n0 = ~(r & 0xFF < p)
     r >>= 8
-    # n1 = tl.where(r & 0xFF < p, 0, 1)
     n1 = ~(r & 0xFF < p)
-    n = tl.join(n0, n1).trans(2, 0, 1).reshape(2 * M, N)
-
-    mn = tl.join(m, n).trans(2, 0, 1).reshape(4 * M, N)
+    n = tl.join(n0, n1)
+    mn = tl.join(m, n)
     return mn
 
 
@@ -500,26 +495,28 @@ def apply_dropout(
     # generates 4 random uints, which are casted for 16 random draws in uint8's.
     M: tl.constexpr = BLOCK_M // 16
     N: tl.constexpr = BLOCK_N // 32
+    row_start = tl.multiple_of(row_start, BLOCK_M)
+    col_start = tl.multiple_of(col_start, BLOCK_N)
     row = row_start // 16 + tl.arange(0, M)[:, None]
     col = col_start + tl.arange(0, BLOCK_N)[None, :]
 
     subsequence = u64_from_lohi(row, col // 32)
 
     tid = tl.arange(0, BLOCK_N)[None, :] % 32
-    philox_offset += (bid * NUM_HEADS + hid) * 32 + tid
-    philox_offset += subsequence * 0 
-    r0, r1, r2, r3 = philox_(philox_seed, subsequence, philox_offset)
+    offset = philox_offset + (bid * NUM_HEADS + hid) * 32 + tid
+    offset += subsequence * 0
+    r0, r1, r2, r3 = philox_(philox_seed, subsequence, offset)
 
     # Fully unrolled due to triton's inability to concat 2d tensor
     m0 = make_4x_dropout_mask(r0, p_dropout_uint8, M, BLOCK_N)
     m1 = make_4x_dropout_mask(r1, p_dropout_uint8, M, BLOCK_N)
-    m = tl.join(m0, m1).trans(2, 0, 1).reshape(8 * M, BLOCK_N)
+    m = tl.join(m0, m1)
 
     n0 = make_4x_dropout_mask(r0, p_dropout_uint8, M, BLOCK_N)
     n1 = make_4x_dropout_mask(r1, p_dropout_uint8, M, BLOCK_N)
-    n = tl.join(n0, n1).trans(2, 0, 1).reshape(8 * M, BLOCK_N)
+    n = tl.join(n0, n1)
 
-    mn = tl.join(m, n).trans(2, 0, 1).reshape(16 * M, BLOCK_N)
+    mn = tl.join(m, n).reshape(16 * M, BLOCK_N)
     P = apply_dropout_mask(P, mn, encode_dropout_in_sign_bit=encode_dropout_in_sign_bit)
     return P
 
@@ -608,11 +605,9 @@ def block_m_splitkv_heuristic(headdim):
 
 def block_n_splitkv_heuristic(headdim):
     if headdim <= 64:
-        return 256
-    elif headdim <= 128:
-        return 128
-    else:
         return 64
+    else:
+        return 32
 
 # @triton.autotune(
 #     configs=runtime.get_tuned_config("attention"),
@@ -718,8 +713,8 @@ def flash_fwd_kernel(
 
     if return_P:
         P_ptr += ((bid * NUM_HEADS + hid) * seqlen_q_rounded + m_block * BLOCK_M) * seqlen_k_rounded
-        P_ptr += (n_block_max - 1) * BLOCK_N
         P_offset = tl.arange(0, BLOCK_M)[:, None] * seqlen_k_rounded + tl.arange(0, BLOCK_N)
+        p_bp0 = P_ptr + P_offset
 
     O_ = tl.zeros((BLOCK_M, HEAD_DIM), dtype=tl.float32)
     rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
@@ -809,8 +804,10 @@ def flash_fwd_kernel(
                     BLOCK_M=BLOCK_M,
                     BLOCK_N=BLOCK_N,
                 )
-                tl.store(P_ptr + P_offset, P_drop, mask=qmask & kvmask[None, :])
-                P_offset += BLOCK_N
+                if IS_EVEN_MN:
+                    tl.store(p_bp0 + col_start, P_drop)
+                else:
+                    tl.store(p_bp0 + col_start, P_drop, mask=qmask & kvmask[None, :])
 
             P = apply_dropout(
                 P,
@@ -887,8 +884,10 @@ def flash_fwd_kernel(
                     BLOCK_M=BLOCK_M,
                     BLOCK_N=BLOCK_N,
                 )
-                tl.store(P_ptr + P_offset, P_drop, mask=qmask & kvmask)
-                P_offset += BLOCK_N
+                if IS_EVEN_MN:
+                    tl.store(p_bp0 + col_start, P_drop)
+                else:
+                    tl.store(p_bp0 + col_start, P_drop, mask=qmask & kvmask[None, :])
 
             P = apply_dropout(
                 P,
@@ -954,7 +953,7 @@ def flash_fwd_kernel(
         'IS_EVEN_MN': lambda args: (args["seqlen_q"] % args["BLOCK_M"] == 0) and (args["seqlen_k"] % args["BLOCK_N"] == 0),
     }
 )
-@triton.jit(do_not_specialize=["seqlen_q", "seqlen_k", "q_b_stride", "q_s_stride", "q_h_stride", "k_b_stride", "k_s_stride", "k_h_stride", "o_b_stride", "o_h_stride", "o_s_stride", "philox_seed", "philox_offset", "pdrop_u8", "slopes_batch_stride"])
+@triton.jit(do_not_specialize=["seqlen_q", "seqlen_k", "philox_seed", "philox_offset"])
 def flash_fwd_splitkv_kernel(
     Q_ptr,
     K_ptr,
@@ -1000,7 +999,7 @@ def flash_fwd_splitkv_kernel(
     IS_EVEN_MN: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
-    num_splits: tl.constexpr,
+    blocks_per_split: tl.constexpr,
     num_warps: tl.constexpr,
     num_stages: tl.constexpr
 ):
@@ -1008,15 +1007,16 @@ def flash_fwd_splitkv_kernel(
     split_id = tl.program_id(1)
     bid = tl.program_id(2) // NUM_HEADS
     hid = tl.program_id(2) % NUM_HEADS
-    
-    blocks_per_split = tl.cdiv(tl.cdiv(seqlen_k, BLOCK_N), num_splits)
+
+    split_block_min = split_id * blocks_per_split
+    split_block_max = min((split_id + 1) * blocks_per_split, tl.cdiv(seqlen_k, BLOCK_N))
 
     if is_local:
-        n_block_min = max(split_id * blocks_per_split, (m_block * BLOCK_M + seqlen_k - seqlen_q - ws_left) / BLOCK_N)
+        n_block_min = max(0, (m_block * BLOCK_M + seqlen_k - seqlen_q - ws_left) / BLOCK_N)
     else:
-        n_block_min = split_id * blocks_per_split
+        n_block_min = 0
 
-    n_block_max = min((split_id + 1) * blocks_per_split, tl.cdiv(seqlen_k, BLOCK_N))
+    n_block_max = tl.cdiv(seqlen_k, BLOCK_N)
     if is_causal or is_local:
         n_block_max = min(n_block_max,
                           tl.cdiv((m_block + 1) * BLOCK_M + seqlen_k - seqlen_q + ws_right, BLOCK_N))
@@ -1030,56 +1030,59 @@ def flash_fwd_splitkv_kernel(
 
     if (not is_causal) and (not is_local):
         if IS_EVEN_MN:
-            n_masking_blocks = 0
+            masking_block_min = n_block_max
         else:
-            n_masking_blocks = 1
+            masking_block_min = n_block_max - 1
     elif is_causal and IS_EVEN_MN: # causal implies ws_right is zero
-        n_masking_blocks = tl.cdiv(BLOCK_M, BLOCK_N)
+        masking_block_min = n_block_max - tl.cdiv(BLOCK_M, BLOCK_N)
     else:
         # local and not causal, 
-        n_masking_blocks = tl.cdiv(BLOCK_M, BLOCK_N) + 1
+        masking_block_min = n_block_max - tl.cdiv(BLOCK_M, BLOCK_N) - 1
 
     Q_ptr += bid * q_b_stride
     Q_ptr += hid * q_h_stride
     row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
     Q_off = row_idx[:, None] * q_s_stride + tl.arange(0, HEAD_DIM)[None, :]
+    p_qm = Q_ptr + Q_off
     qmask = row_idx[:, None] < seqlen_q
     if IS_EVEN_MN:
-        Q = tl.load(Q_ptr + Q_off)
+        Q = tl.load(p_qm)
     else:
-        Q = tl.load(Q_ptr + Q_off, mask=qmask)
-
-    # Start from the right most block
-    n_block = n_block_max - 1
+        Q = tl.load(p_qm, mask=qmask)
 
     h_hk_ratio = h // hk
     K_ptr += bid * k_b_stride
     K_ptr += (hid // h_hk_ratio) * k_h_stride
     V_ptr += bid * k_b_stride
     V_ptr += (hid // h_hk_ratio) * k_h_stride
+
+    K_offset = tl.arange(0, BLOCK_N)[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None]
+    p_k0 = K_ptr + K_offset
     
-    P_ptr += ((bid * NUM_HEADS + hid) * seqlen_q_rounded + m_block * BLOCK_M) * seqlen_k_rounded
-    P_ptr += n_block * BLOCK_N
-    P_offset = tl.arange(0, BLOCK_M)[:, None] * seqlen_k_rounded + tl.arange(0, BLOCK_N)
+    V_offset = tl.arange(0, BLOCK_N)[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :]
+    p_v0 = V_ptr + V_offset
 
     O_ = tl.zeros((BLOCK_M, HEAD_DIM), dtype=tl.float32)
     rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
     rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32)
-    
-    for n_block in tl.range(n_block_max - 1, n_block_max - n_masking_blocks - 1, step=-1):
+
+    split_masking_block = max(masking_block_min, split_block_min)
+    for n_block in tl.range(split_block_max - 1, split_masking_block - 1, step=-1):
+        kv_off = n_block * BLOCK_N * k_s_stride
         col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
-        K_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None]
-        V_offset = col_idx[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :]
+        row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
         if IS_EVEN_MN:
-            K = tl.load(K_ptr + K_offset, cache_modifier=".cg")
+            K = tl.load(p_k0 + kv_off, cache_modifier=".cg")
             if PRE_LOAD_V:
-                V = tl.load(V_ptr + V_offset, cache_modifier=".cg")
+                V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
         else:
             kvmask = col_idx < seqlen_k
-            K = tl.load(K_ptr + K_offset, mask=kvmask[None, :], cache_modifier=".cg")
+            K = tl.load(p_k0 + kv_off, mask=kvmask[None, :], cache_modifier=".cg")
             if PRE_LOAD_V:
-                V = tl.load(V_ptr + V_offset, mask=kvmask[:, None], cache_modifier=".cg")
+                V = tl.load(p_v0 + kv_off, mask=kvmask[:, None], cache_modifier=".cg")
+
         S = tl.dot(Q, K, allow_tf32=False)
+
         S = apply_mask(
             S,
             col_idx,
@@ -1112,16 +1115,15 @@ def flash_fwd_splitkv_kernel(
                 V = tl.load(V_ptr + V_offset, mask=kvmask[:, None], cache_modifier=".cg")
         O_ = tl.dot(P, V, O_, allow_tf32=False)
 
-
-    for n_block in tl.range(n_block_max - n_masking_blocks - 1, n_block_min - 1, step=-1, num_stages=num_stages):
-        col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
-        K_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None]
-        K = tl.load(K_ptr + K_offset, cache_modifier=".cg")
+    nomasking_max = min(split_block_max, masking_block_min)
+    for n_block in tl.range(split_block_min, nomasking_max, num_stages=num_stages):
+        kv_off = n_block * BLOCK_N * k_s_stride
+        K = tl.load(p_k0 + kv_off, cache_modifier=".cg")
         if PRE_LOAD_V:
-            V_offset = col_idx[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :]
-            V = tl.load(V_ptr + V_offset, cache_modifier=".cg")
+            V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
         S = tl.dot(Q, K)
 
+        col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
         S = apply_mask(
             S,
             col_idx,
@@ -1149,8 +1151,7 @@ def flash_fwd_splitkv_kernel(
         P = P.to(Q_ptr.type.element_ty)
 
         if not PRE_LOAD_V:
-            V_offset = col_idx[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :]
-            V = tl.load(V_ptr + V_offset, cache_modifier=".cg")
+            V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
 
         O_ = tl.dot(P, V, O_)
 
@@ -1162,21 +1163,24 @@ def flash_fwd_splitkv_kernel(
     O_ *= inv_sum[:, None]
 
     # Write back output
-    O_split_ptr = O_ptr
-    # (n_splits, batch_size, num_heads, seqlen_q, head_size)
+    # O_splits layout = (n_splits, batch_size, num_heads, seqlen_q, head_size)
     # grid = (seq_block, split, batch * head)
+    O_split_ptr = O_ptr
+    # + split, batch, head offsets, seq_block offsets are already added in row_idx
     O_split_ptr += (split_id * tl.num_programs(2) + tl.program_id(2)) * seqlen_q * HEAD_DIM
     O_split_offset = row_idx[:, None] * HEAD_DIM + tl.arange(0, HEAD_DIM)
+    p_om = O_split_ptr + O_split_offset
 
     if IS_EVEN_MN:
-        tl.store(O_split_ptr + O_split_offset, O_)
+        tl.store(p_om, O_)
     else:
-        tl.store(O_split_ptr + O_split_offset, O_, mask=qmask)
+        tl.store(p_om, O_, mask=qmask)
     
     # Write back lse
+    # lse_splits layout = (n_splits, batch_size, num_heads, seqlen_q)
     lse_split_ptr = lse_ptr
-    lse_split_ptr += (split_id * tl.num_programs(2) + tl.program_id(2)) * seqlen_q
-    lse_split_ptr += m_block * BLOCK_M
+    # + split, batch, head, seq_block offsets
+    lse_split_ptr += (split_id * tl.num_programs(2) + tl.program_id(2)) * seqlen_q + m_block * BLOCK_M
 
     if IS_EVEN_MN:
         tl.store(lse_split_ptr + tl.arange(0, BLOCK_M), lse)
@@ -1303,22 +1307,25 @@ def mha_fwd(
     seqlen_k_rounded = round_multiple(seqlen_k, 128)
 
     def splits_heuristics(num_tasks, num_sms, n_blocks):
-        # splits only number of waves and wave efficiency are both low
+        # splits when wave efficiency is low
         n_waves = triton.cdiv(num_tasks, num_sms)
         eff = (num_tasks / num_sms) / n_waves
-        if eff > 0.85 or n_waves > 10:
+        if eff > 0.85:
             return 1
-        max_eff = eff
+
         best_splits = 1
-        for w in range(n_waves, 10):
-            n_splits = min(num_sms, n_blocks, w * num_sms // num_tasks)
-            blocks_per_split = triton.cdiv(n_blocks, n_splits)
-            if blocks_per_split < 4:
-                continue
+        best_eff = eff
+        min_blocks_per_split = 1
+        max_blocks_per_split = triton.cdiv(n_blocks, 2)
+        for blocks_per_split in range(min_blocks_per_split, max_blocks_per_split + 1)[::-1]:
             n_splits = triton.cdiv(n_blocks, blocks_per_split)
-            eff = (n_splits * num_tasks / num_sms) / w
-            if eff > max_eff:
-                max_eff = eff
+            n_waves = triton.cdiv(n_splits * num_tasks, num_sms)
+            eff = (n_splits * num_tasks / num_sms) / n_waves
+            if eff > 0.85:
+                best_splits = n_splits
+                break
+            if eff > best_eff:
+                best_eff = eff
                 best_splits = n_splits
         return best_splits
 
@@ -1348,6 +1355,27 @@ def splits_heuristics(num_tasks, num_sms, n_blocks):
         pdrop_u8 = math.floor(p_dropout * 255.0)
         rpdrop = 1. / p_dropout
 
+        M_LOG2E	= 1.4426950408889634074
+        softmax_scale_log2e = softmax_scale * M_LOG2E
+
+        # Set alibi params
+        if alibi_slopes is not None:
+            assert alibi_slopes.device == q_device
+            assert alibi_slopes.dtype in (torch.float, )
+            assert alibi_slopes.stride(-1) == 1
+            assert alibi_slopes.shape == (num_heads,) or alibi_slopes.shape == (batch_size, num_heads)
+            alibi_slopes_batch_stride = alibi_slopes.stride(0) if alibi_slopes.ndim == 2 else 0
+            has_alibi = True
+        else:
+            alibi_slopes_batch_stride = 0
+            has_alibi = False
+
+        # Set SWA params
+        is_local = (window_size_left >= 0 or window_size_right >= 0) and not is_causal
+
+        # ONLY EVEN_K IS SUPPORTED
+        assert head_size == head_size_rounded
+
         # Check splitkv
         def try_split_kv():
             block_m = block_m_splitkv_heuristic(head_size)
@@ -1356,10 +1384,15 @@ def try_split_kv():
             block_n = block_n_splitkv_heuristic(head_size)
             n_blocks = triton.cdiv(seqlen_k, block_n)
             n_splits = splits_heuristics(n_tasks, num_sms, n_blocks)
-            return n_splits
-            
-        n_splits = try_split_kv() if not is_dropout else 1
+            blocks_per_split = triton.cdiv(n_blocks, n_splits)
+            return n_splits, blocks_per_split
+        
+        if not is_dropout:
+            n_splits, blocks_per_split = try_split_kv()
+        else:
+            n_splits, blocks_per_split = 1, None
         print('n_splits:', n_splits)
+        print('blocks_per_split', blocks_per_split)
 
         if n_splits > 1:
             lse_splits = torch.empty(
@@ -1373,27 +1406,6 @@ def try_split_kv():
                 device=q_device
             )
 
-        M_LOG2E	= 1.4426950408889634074
-        softmax_scale_log2e = softmax_scale * M_LOG2E
-
-        # Set alibi params
-        if alibi_slopes is not None:
-            assert alibi_slopes.device == q_device
-            assert alibi_slopes.dtype in (torch.float, )
-            assert alibi_slopes.stride(-1) == 1
-            assert alibi_slopes.shape == (num_heads,) or alibi_slopes.shape == (batch_size, num_heads)
-            alibi_slopes_batch_stride = alibi_slopes.stride(0) if alibi_slopes.ndim == 2 else 0
-            has_alibi = True
-        else:
-            alibi_slopes_batch_stride = 0
-            has_alibi = False
-
-        # Set SWA params
-        is_local = (window_size_left >= 0 or window_size_right >= 0) and not is_causal
-
-        # ONLY EVEN_K IS SUPPORTED
-        assert head_size == head_size_rounded
-
         # Launch kernel
         if n_splits > 1:
             grid = lambda args: (
@@ -1453,7 +1465,7 @@ def try_split_kv():
             ws_right=window_size_right,
             return_P=return_softmax,
             BATCH_SIZE=batch_size,
-            num_splits=n_splits,
+            blocks_per_split=blocks_per_split,
             NUM_HEADS=num_heads,
             NUM_HEADS_K=num_heads_k,
         )
@@ -1484,7 +1496,6 @@ def try_split_kv():
             q_total=batch_size * num_heads * seqlen_q,
             MAX_N_SPLITS=triton.next_power_of_2(n_splits),
         )
-        print(f'{kernel.name} shared memory:', kernel.metadata.shared)
 
     if swap_seq_and_group:
         out = out.transpose(1, 2).reshape((batch_size, 1, num_heads_k * seqlen_q, head_size))

From a1c93be491cb403a823ba6190dbaf6fd3a83d3be Mon Sep 17 00:00:00 2001
From: Tongxin Bai <waffle.bai@gmail.com>
Date: Sun, 23 Mar 2025 06:05:42 +0000
Subject: [PATCH 19/25] Working on splitkv..

---
 src/flag_gems/ops/attention.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py
index 8601e0bd7..2e34de23e 100644
--- a/src/flag_gems/ops/attention.py
+++ b/src/flag_gems/ops/attention.py
@@ -1156,7 +1156,9 @@ def flash_fwd_splitkv_kernel(
         O_ = tl.dot(P, V, O_)
 
     # LSE
-    lse = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), float('inf'), rowmax_ * softmax_scale + tl.log(rowsum_))
+    # if (split_block_max <= n_block_min) or (split_block_min >= n_block_max):
+
+    lse = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), float('-inf'), rowmax_ * softmax_scale + tl.log(rowsum_))
     inv_sum = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), 1.0, 1.0 / rowsum_)
     
     # Rescale output

From 24cb70e15ad75b940f221b2c4b42cc0ae58bfd84 Mon Sep 17 00:00:00 2001
From: Tongxin Bai <waffle.bai@gmail.com>
Date: Sun, 23 Mar 2025 06:33:19 +0000
Subject: [PATCH 20/25] Splitkv passes but requires solid perf opt.

---
 src/flag_gems/ops/attention.py | 11 +++--------
 1 file changed, 3 insertions(+), 8 deletions(-)

diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py
index 2e34de23e..4d07179dd 100644
--- a/src/flag_gems/ops/attention.py
+++ b/src/flag_gems/ops/attention.py
@@ -674,7 +674,7 @@ def flash_fwd_kernel(
     IS_EVEN_MN: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
-    num_splits: tl.constexpr,
+    blocks_per_split: tl.constexpr,
     num_warps: tl.constexpr,
     num_stages: tl.constexpr
 ):
@@ -913,7 +913,7 @@ def flash_fwd_kernel(
 
     # LSE
     # Note, rowsum = exp(-rowmax) * lse, therefore rowmax + log(rowsum) cancels the effect of rowmax and outputs lse only.
-    lse = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), float('inf'), rowmax_ * softmax_scale + tl.log(rowsum_))
+    lse = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), float('-inf'), rowmax_ * softmax_scale + tl.log(rowsum_))
 
     # Rescale output
     inv_sum = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), 1.0, 1.0 / rowsum_)
@@ -1156,8 +1156,6 @@ def flash_fwd_splitkv_kernel(
         O_ = tl.dot(P, V, O_)
 
     # LSE
-    # if (split_block_max <= n_block_min) or (split_block_min >= n_block_max):
-
     lse = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), float('-inf'), rowmax_ * softmax_scale + tl.log(rowsum_))
     inv_sum = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), 1.0, 1.0 / rowsum_)
     
@@ -1312,7 +1310,7 @@ def splits_heuristics(num_tasks, num_sms, n_blocks):
         # splits when wave efficiency is low
         n_waves = triton.cdiv(num_tasks, num_sms)
         eff = (num_tasks / num_sms) / n_waves
-        if eff > 0.85:
+        if eff > 0.8 or n_waves > 1:
             return 1
 
         best_splits = 1
@@ -1326,9 +1324,6 @@ def splits_heuristics(num_tasks, num_sms, n_blocks):
             if eff > 0.85:
                 best_splits = n_splits
                 break
-            if eff > best_eff:
-                best_eff = eff
-                best_splits = n_splits
         return best_splits
 
     with torch_device_fn.device(q_device):

From 901cfdc626734a9aa84217ea559b392c72cbd0dd Mon Sep 17 00:00:00 2001
From: Tongxin Bai <waffle.bai@gmail.com>
Date: Sun, 23 Mar 2025 15:08:58 +0000
Subject: [PATCH 21/25] Dirty hacking for debugging splitkv.

---
 src/flag_gems/ops/attention.py | 351 +++++++++++++++++++++++++--------
 1 file changed, 272 insertions(+), 79 deletions(-)

diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py
index 4d07179dd..025270edd 100644
--- a/src/flag_gems/ops/attention.py
+++ b/src/flag_gems/ops/attention.py
@@ -784,7 +784,7 @@ def flash_fwd_kernel(
             softmax_scale_log2e=softmax_scale_log2e,
             is_border=(is_causal or is_local),
         )
-        P = P.to(O_ptr.type.element_ty)
+        P = P.to(V_ptr.type.element_ty)
 
         if is_dropout:
             if return_P:
@@ -865,7 +865,7 @@ def flash_fwd_kernel(
             softmax_scale_log2e=softmax_scale_log2e,
             is_border=is_local,
         )
-        P = P.to(O_ptr.type.element_ty)
+        P = P.to(V_ptr.type.element_ty)
 
         if is_dropout:
             if return_P:
@@ -948,13 +948,13 @@ def flash_fwd_kernel(
         'BLOCK_M': lambda args: block_m_splitkv_heuristic(args["HEAD_DIM"]),
         'BLOCK_N': lambda args: block_m_splitkv_heuristic(args["HEAD_DIM"]),
         'num_warps': lambda args: 4,
-        'num_stages': lambda args: 3,
+        'num_stages': lambda args: 3 if args["HEAD_DIM"] <= 128 else 2,
         'PRE_LOAD_V': lambda args: True,
         'IS_EVEN_MN': lambda args: (args["seqlen_q"] % args["BLOCK_M"] == 0) and (args["seqlen_k"] % args["BLOCK_N"] == 0),
     }
 )
 @triton.jit(do_not_specialize=["seqlen_q", "seqlen_k", "philox_seed", "philox_offset"])
-def flash_fwd_splitkv_kernel(
+def flash_fwd_splitkv_kernel_v2(
     Q_ptr,
     K_ptr,
     V_ptr,
@@ -1008,18 +1008,16 @@ def flash_fwd_splitkv_kernel(
     bid = tl.program_id(2) // NUM_HEADS
     hid = tl.program_id(2) % NUM_HEADS
 
-    split_block_min = split_id * blocks_per_split
-    split_block_max = min((split_id + 1) * blocks_per_split, tl.cdiv(seqlen_k, BLOCK_N))
+    split_col_min = split_id * blocks_per_split * BLOCK_N
+    split_col_max = split_col_min + blocks_per_split * BLOCK_N
 
-    if is_local:
-        n_block_min = max(0, (m_block * BLOCK_M + seqlen_k - seqlen_q - ws_left) / BLOCK_N)
-    else:
-        n_block_min = 0
+    col_min = 0
+    
+    col_max = tl.cdiv(seqlen_k, BLOCK_N) * BLOCK_N
+    if is_causal:
+        col_max = min(col_max, (m_block + 1) * BLOCK_M + seqlen_k - seqlen_q + ws_right)
 
-    n_block_max = tl.cdiv(seqlen_k, BLOCK_N)
-    if is_causal or is_local:
-        n_block_max = min(n_block_max,
-                          tl.cdiv((m_block + 1) * BLOCK_M + seqlen_k - seqlen_q + ws_right, BLOCK_N))
+    split_col_max = min(split_col_max, col_max)
 
     if has_alibi:
         alibi_offset = bid * slopes_batch_stride + hid
@@ -1028,27 +1026,28 @@ def flash_fwd_splitkv_kernel(
     else:
         alibi_slope = 0.0
 
-    if (not is_causal) and (not is_local):
+    if not is_causal:
         if IS_EVEN_MN:
-            masking_block_min = n_block_max
+            masking_cols: tl.constexpr = 0
         else:
-            masking_block_min = n_block_max - 1
+            masking_cols: tl.constexpr = BLOCK_N
     elif is_causal and IS_EVEN_MN: # causal implies ws_right is zero
-        masking_block_min = n_block_max - tl.cdiv(BLOCK_M, BLOCK_N)
+        masking_cols: tl.constexpr = tl.cdiv(BLOCK_M, BLOCK_N) * BLOCK_N
     else:
         # local and not causal, 
-        masking_block_min = n_block_max - tl.cdiv(BLOCK_M, BLOCK_N) - 1
+        masking_cols: tl.constexpr = (tl.cdiv(BLOCK_M, BLOCK_N) + 1) * BLOCK_N
 
     Q_ptr += bid * q_b_stride
     Q_ptr += hid * q_h_stride
-    row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
+    row_start = m_block * BLOCK_M
+    row_idx = row_start + tl.arange(0, BLOCK_M)
     Q_off = row_idx[:, None] * q_s_stride + tl.arange(0, HEAD_DIM)[None, :]
     p_qm = Q_ptr + Q_off
     qmask = row_idx[:, None] < seqlen_q
     if IS_EVEN_MN:
-        Q = tl.load(p_qm)
+        Q = tl.load(p_qm, cache_modifier=".cg")
     else:
-        Q = tl.load(p_qm, mask=qmask)
+        Q = tl.load(p_qm, mask=qmask, cache_modifier=".cg")
 
     h_hk_ratio = h // hk
     K_ptr += bid * k_b_stride
@@ -1057,32 +1056,31 @@ def flash_fwd_splitkv_kernel(
     V_ptr += (hid // h_hk_ratio) * k_h_stride
 
     K_offset = tl.arange(0, BLOCK_N)[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None]
-    p_k0 = K_ptr + K_offset
+    p_bk0 = K_ptr + K_offset
     
     V_offset = tl.arange(0, BLOCK_N)[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :]
-    p_v0 = V_ptr + V_offset
+    p_bv0 = V_ptr + V_offset
 
     O_ = tl.zeros((BLOCK_M, HEAD_DIM), dtype=tl.float32)
     rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
     rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32)
 
-    split_masking_block = max(masking_block_min, split_block_min)
-    for n_block in tl.range(split_block_max - 1, split_masking_block - 1, step=-1):
-        kv_off = n_block * BLOCK_N * k_s_stride
-        col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
-        row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
+    for col_start in tl.range(split_col_min, split_col_max, step=BLOCK_N):
+        col_start = tl.multiple_of(col_start, BLOCK_N)
+        off = col_start * k_s_stride
         if IS_EVEN_MN:
-            K = tl.load(p_k0 + kv_off, cache_modifier=".cg")
+            K = tl.load(p_bk0 + off, cache_modifier=".cg")
             if PRE_LOAD_V:
-                V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
+                V = tl.load(p_bv0 + off, cache_modifier=".cg")
         else:
+            col_idx = col_start + tl.arange(0, BLOCK_N)
             kvmask = col_idx < seqlen_k
-            K = tl.load(p_k0 + kv_off, mask=kvmask[None, :], cache_modifier=".cg")
+            K = tl.load(p_bk0 + off, mask=kvmask[None, :], cache_modifier=".cg")
             if PRE_LOAD_V:
-                V = tl.load(p_v0 + kv_off, mask=kvmask[:, None], cache_modifier=".cg")
-
+                V = tl.load(p_bv0 + off, mask=kvmask[:, None], cache_modifier=".cg")
         S = tl.dot(Q, K, allow_tf32=False)
-
+        col_idx = col_start + tl.arange(0, BLOCK_N)
+        row_idx = row_start + tl.arange(0, BLOCK_M)
         S = apply_mask(
             S,
             col_idx,
@@ -1094,7 +1092,7 @@ def flash_fwd_splitkv_kernel(
             alibi_slope,
             is_even_mn=IS_EVEN_MN,
             is_causal=is_causal,
-            is_local=is_local,
+            is_local=False,
             has_alibi=has_alibi
         )
 
@@ -1106,54 +1104,242 @@ def flash_fwd_splitkv_kernel(
             softmax_scale_log2e=softmax_scale_log2e,
             is_border=(is_causal or is_local),
         )
-        P = P.to(Q_ptr.type.element_ty)
+        P = P.to(V_ptr.type.element_ty)
 
         if not PRE_LOAD_V:
+            off = col_start * k_s_stride
             if IS_EVEN_MN:
-                V = tl.load(V_ptr + V_offset, cache_modifier=".cg")
+                V = tl.load(p_bv0 + off, cache_modifier=".cg")
             else:
-                V = tl.load(V_ptr + V_offset, mask=kvmask[:, None], cache_modifier=".cg")
+                V = tl.load(p_bv0 + off, mask=kvmask[:, None], cache_modifier=".cg")
         O_ = tl.dot(P, V, O_, allow_tf32=False)
 
-    nomasking_max = min(split_block_max, masking_block_min)
-    for n_block in tl.range(split_block_min, nomasking_max, num_stages=num_stages):
-        kv_off = n_block * BLOCK_N * k_s_stride
-        K = tl.load(p_k0 + kv_off, cache_modifier=".cg")
-        if PRE_LOAD_V:
-            V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
-        S = tl.dot(Q, K)
+    # LSE
+    lse = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), float('-inf'), rowmax_ * softmax_scale + tl.log(rowsum_))
+    inv_sum = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), 1.0, 1.0 / rowsum_)
+    
+    # Rescale output
+    O_ *= inv_sum[:, None]
 
-        col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
-        S = apply_mask(
-            S,
-            col_idx,
-            row_idx,
-            seqlen_q,
-            seqlen_k,
-            ws_left,
-            ws_right,
-            alibi_slope,
-            is_even_mn=True,
-            is_causal=False,
-            is_local=is_local,
-            has_alibi=has_alibi
-        )
+    # Write back output
+    # O_splits layout = (n_splits, batch_size, num_heads, seqlen_q, head_size)
+    # grid = (seq_block, split, batch * head)
+    O_split_ptr = O_ptr
+    # + split, batch, head offsets, seq_block offsets are already added in row_idx
+    O_split_ptr += (split_id * tl.num_programs(2) + tl.program_id(2)) * seqlen_q * HEAD_DIM
+    O_split_offset = row_idx[:, None] * HEAD_DIM + tl.arange(0, HEAD_DIM)
+    O_split_ptr = tl.multiple_of(O_split_ptr, HEAD_DIM)
+    p_om = O_split_ptr + O_split_offset
 
-        O_, P, rowmax_, rowsum_ = softmax_rescale(
-            O_,
-            S,
-            rowmax_,
-            rowsum_,
-            softmax_scale_log2e=softmax_scale_log2e,
-            is_border=is_local,
-        )
+    if IS_EVEN_MN:
+        tl.store(p_om, O_, cache_modifier=".cg")
+    else:
+        tl.store(p_om, O_, mask=qmask, cache_modifier=".cg")
+    
+    # Write back lse
+    # lse_splits layout = (n_splits, batch_size, num_heads, seqlen_q)
+    lse_split_ptr = lse_ptr
+    # + split, batch, head, seq_block offsets
+    lse_split_ptr += (split_id * tl.num_programs(2) + tl.program_id(2)) * seqlen_q + m_block * BLOCK_M
 
-        P = P.to(Q_ptr.type.element_ty)
+    if IS_EVEN_MN:
+        tl.store(lse_split_ptr + tl.arange(0, BLOCK_M), lse, cache_modifier=".cg")
+    else:
+        tl.store(lse_split_ptr + tl.arange(0, BLOCK_M), lse, mask=row_idx < seqlen_q, cache_modifier=".cg")
 
-        if not PRE_LOAD_V:
-            V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
 
-        O_ = tl.dot(P, V, O_)
+@triton.heuristics(
+    values={
+        'BLOCK_M': lambda args: block_m_splitkv_heuristic(args["HEAD_DIM"]),
+        'BLOCK_N': lambda args: block_m_splitkv_heuristic(args["HEAD_DIM"]),
+        'num_warps': lambda args: 4,
+        'num_stages': lambda args: 3 if args["HEAD_DIM"] <= 128 else 2,
+        'PRE_LOAD_V': lambda args: True,
+        'IS_EVEN_MN': lambda args: (args["seqlen_q"] % args["BLOCK_M"] == 0) and (args["seqlen_k"] % args["BLOCK_N"] == 0),
+    }
+)
+@triton.jit(do_not_specialize=["seqlen_q", "seqlen_k", "philox_seed", "philox_offset"])
+def flash_fwd_splitkv_kernel(
+    Q_ptr,
+    K_ptr,
+    V_ptr,
+    P_ptr,
+    O_ptr,
+    lse_ptr,
+    seqlen_q,
+    seqlen_k,
+    seqlen_q_rounded,
+    seqlen_k_rounded,
+    q_b_stride,
+    q_s_stride,
+    q_h_stride,
+    k_b_stride,
+    k_s_stride,
+    k_h_stride,
+    o_b_stride,
+    o_s_stride,
+    o_h_stride,
+    h,
+    hk,
+    pSlopes,
+    philox_seed,
+    philox_offset,
+    pdrop_u8,
+    rpdrop,
+    slopes_batch_stride,
+    HEAD_DIM: tl.constexpr,
+    is_dropout: tl.constexpr,
+    is_causal: tl.constexpr,
+    is_local: tl.constexpr,
+    has_alibi: tl.constexpr,
+    softmax_scale: tl.constexpr,
+    softmax_scale_log2e: tl.constexpr,
+    ws_left: tl.constexpr,
+    ws_right: tl.constexpr,
+    return_P: tl.constexpr,
+    PRE_LOAD_V: tl.constexpr,
+    BATCH_SIZE: tl.constexpr,
+    NUM_HEADS: tl.constexpr,
+    NUM_HEADS_K: tl.constexpr,
+    IS_EVEN_MN: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    blocks_per_split: tl.constexpr,
+    num_warps: tl.constexpr,
+    num_stages: tl.constexpr
+):
+    m_block = tl.program_id(0)
+    split_id = tl.program_id(1)
+    bid = tl.program_id(2) // NUM_HEADS
+    hid = tl.program_id(2) % NUM_HEADS
+
+    split_block_min = split_id * blocks_per_split
+    split_block_max = split_block_min + blocks_per_split
+
+    n_block_max = tl.cdiv(seqlen_k, BLOCK_N)
+    if is_causal:
+        n_block_max = min(n_block_max,
+                          tl.cdiv((m_block + 1) * BLOCK_M + seqlen_k - seqlen_q + ws_right, BLOCK_N))
+
+    if has_alibi:
+        alibi_offset = bid * slopes_batch_stride + hid
+        alibi_slope = tl.load(pSlopes + alibi_offset)
+        alibi_slope /= scale
+    else:
+        alibi_slope = 0.0
+
+    if not is_causal:
+        if IS_EVEN_MN:
+            masking_block_min = n_block_max
+        else:
+            masking_block_min = n_block_max - 1
+    elif is_causal and IS_EVEN_MN: # causal implies ws_right is zero
+        masking_block_min = n_block_max - tl.cdiv(BLOCK_M, BLOCK_N)
+    else: 
+        masking_block_min = n_block_max - tl.cdiv(BLOCK_M, BLOCK_N) - 1
+
+    Q_ptr += bid * q_b_stride
+    Q_ptr += hid * q_h_stride
+    row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
+    Q_off = row_idx[:, None] * q_s_stride + tl.arange(0, HEAD_DIM)[None, :]
+    p_qm = Q_ptr + Q_off
+    qmask = row_idx[:, None] < seqlen_q
+    if IS_EVEN_MN:
+        Q = tl.load(p_qm, cache_modifier=".cg")
+    else:
+        Q = tl.load(p_qm, mask=qmask, cache_modifier=".cg")
+
+    h_hk_ratio = h // hk
+    K_ptr += bid * k_b_stride
+    K_ptr += (hid // h_hk_ratio) * k_h_stride
+    V_ptr += bid * k_b_stride
+    V_ptr += (hid // h_hk_ratio) * k_h_stride
+
+    K_offset = tl.arange(0, BLOCK_N)[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None]
+    p_k0 = K_ptr + K_offset
+    
+    V_offset = tl.arange(0, BLOCK_N)[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :]
+    p_v0 = V_ptr + V_offset
+
+    O_ = tl.zeros((BLOCK_M, HEAD_DIM), dtype=tl.float32)
+    rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+    rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32)
+
+    if split_block_max <= masking_block_min:
+        # no masking needed
+        for n_block in tl.range(split_block_min, split_block_max, num_stages=num_stages):
+            kv_off = n_block * BLOCK_N * k_s_stride
+            K = tl.load(p_k0 + kv_off, cache_modifier=".cg")
+            if PRE_LOAD_V:
+                V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
+            S = tl.dot(Q, K)
+
+            col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
+
+            if has_alibi:
+                S -= alibi_slope * tl.abs(col_idx[None, :] - row_idx[:, None])
+
+            O_, P, rowmax_, rowsum_ = softmax_rescale(
+                O_,
+                S,
+                rowmax_,
+                rowsum_,
+                softmax_scale_log2e=softmax_scale_log2e,
+                is_border=False,
+            )
+
+            if not PRE_LOAD_V:
+                V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
+            P = P.to(Q_ptr.type.element_ty)
+            O_ = tl.dot(P, V, O_)        
+    else:
+        for n_block in tl.range(split_block_min, min(split_block_max, n_block_max)):
+            kv_off = n_block * BLOCK_N * k_s_stride
+            col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
+            row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
+            if IS_EVEN_MN:
+                K = tl.load(p_k0 + kv_off, cache_modifier=".cg")
+                if PRE_LOAD_V:
+                    V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
+            else:
+                kvmask = col_idx < seqlen_k
+                K = tl.load(p_k0 + kv_off, mask=kvmask[None, :], cache_modifier=".cg")
+                if PRE_LOAD_V:
+                    V = tl.load(p_v0 + kv_off, mask=kvmask[:, None], cache_modifier=".cg")
+
+            S = tl.dot(Q, K, allow_tf32=False)
+
+            S = apply_mask(
+                S,
+                col_idx,
+                row_idx,
+                seqlen_q,
+                seqlen_k,
+                ws_left,
+                ws_right,
+                alibi_slope,
+                is_even_mn=IS_EVEN_MN,
+                is_causal=is_causal,
+                is_local=False,
+                has_alibi=has_alibi
+            )
+
+            O_, P, rowmax_, rowsum_ = softmax_rescale(
+                O_,
+                S,
+                rowmax_,
+                rowsum_,
+                softmax_scale_log2e=softmax_scale_log2e,
+                is_border=(is_causal or is_local),
+            )
+
+            if not PRE_LOAD_V:
+                if IS_EVEN_MN:
+                    V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
+                else:
+                    V = tl.load(p_v0 + kv_off, mask=kvmask[:, None], cache_modifier=".cg")
+            P = P.to(Q_ptr.type.element_ty)
+            O_ = tl.dot(P, V, O_, allow_tf32=False)
 
     # LSE
     lse = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), float('-inf'), rowmax_ * softmax_scale + tl.log(rowsum_))
@@ -1169,12 +1355,13 @@ def flash_fwd_splitkv_kernel(
     # + split, batch, head offsets, seq_block offsets are already added in row_idx
     O_split_ptr += (split_id * tl.num_programs(2) + tl.program_id(2)) * seqlen_q * HEAD_DIM
     O_split_offset = row_idx[:, None] * HEAD_DIM + tl.arange(0, HEAD_DIM)
+    O_split_ptr = tl.multiple_of(O_split_ptr, HEAD_DIM)
     p_om = O_split_ptr + O_split_offset
 
     if IS_EVEN_MN:
-        tl.store(p_om, O_)
+        tl.store(p_om, O_, cache_modifier=".cg")
     else:
-        tl.store(p_om, O_, mask=qmask)
+        tl.store(p_om, O_, mask=qmask, cache_modifier=".cg")
     
     # Write back lse
     # lse_splits layout = (n_splits, batch_size, num_heads, seqlen_q)
@@ -1183,9 +1370,9 @@ def flash_fwd_splitkv_kernel(
     lse_split_ptr += (split_id * tl.num_programs(2) + tl.program_id(2)) * seqlen_q + m_block * BLOCK_M
 
     if IS_EVEN_MN:
-        tl.store(lse_split_ptr + tl.arange(0, BLOCK_M), lse)
+        tl.store(lse_split_ptr + tl.arange(0, BLOCK_M), lse, cache_modifier=".cg")
     else:
-        tl.store(lse_split_ptr + tl.arange(0, BLOCK_M), lse, mask=row_idx < seqlen_q)
+        tl.store(lse_split_ptr + tl.arange(0, BLOCK_M), lse, mask=row_idx < seqlen_q, cache_modifier=".cg")
 
 
 @triton.jit
@@ -1384,10 +1571,16 @@ def try_split_kv():
             blocks_per_split = triton.cdiv(n_blocks, n_splits)
             return n_splits, blocks_per_split
         
-        if not is_dropout:
+        if not is_dropout and not is_local:
             n_splits, blocks_per_split = try_split_kv()
         else:
             n_splits, blocks_per_split = 1, None
+
+#        n_splits = 1
+        block_n = block_n_splitkv_heuristic(head_size)
+        n_blocks = triton.cdiv(seqlen_k, block_n)
+        blocks_per_split = triton.cdiv(n_blocks, n_splits)
+        print('block_n:', block_n)
         print('n_splits:', n_splits)
         print('blocks_per_split', blocks_per_split)
 
@@ -1410,7 +1603,7 @@ def try_split_kv():
                 n_splits,
                 batch_size * num_heads
             )
-            kernel = flash_fwd_splitkv_kernel[grid]
+            kernel = flash_fwd_splitkv_kernel_v2[grid]
             tmp_lse = lse_splits
             tmp_out = out_splits
         else:

From 8e47b1d4b0f4a6d57d310f0ad9e7d45b3bbcf854 Mon Sep 17 00:00:00 2001
From: Tongxin Bai <waffle.bai@gmail.com>
Date: Mon, 24 Mar 2025 04:27:19 +0000
Subject: [PATCH 22/25] fixed an error with splitkv block_n heuristics.

---
 src/flag_gems/__init__.py      | 395 +++++++++++----------
 src/flag_gems/ops/__init__.py  | 614 ++++++++++++++++-----------------
 src/flag_gems/ops/attention.py |  97 +++---
 3 files changed, 545 insertions(+), 561 deletions(-)

diff --git a/src/flag_gems/__init__.py b/src/flag_gems/__init__.py
index 85cc04941..0bb7dc486 100644
--- a/src/flag_gems/__init__.py
+++ b/src/flag_gems/__init__.py
@@ -19,211 +19,206 @@ def enable(lib=aten_lib, unused=None, registrar=registrar):
     global current_work_registrar
     current_work_registrar = registrar(
         (
-            # ("abs", abs, Autograd.disable),
-            # ("add.Tensor", add, Autograd.disable),
-            # ("addmm", addmm, Autograd.disable),
-            # ("arange.start_step", arange_start, Autograd.disable),
-            # ("arange.start", arange_start, Autograd.disable),
-            # ("arange", arange, Autograd.disable),
-            # ("batch_norm", batch_norm, Autograd.enable),
-            # ("bitwise_and.Tensor", bitwise_and_tensor, Autograd.disable),
-            # ("bitwise_and.Scalar", bitwise_and_scalar, Autograd.disable),
-            # ("bitwise_and.Scalar_Tensor", bitwise_and_scalar_tensor, Autograd.disable),
-            # ("bitwise_not", bitwise_not, Autograd.disable),
-            # ("bitwise_or.Tensor", bitwise_or_tensor, Autograd.disable),
-            # ("bitwise_or.Scalar", bitwise_or_scalar, Autograd.disable),
-            # ("bitwise_or.Scalar_Tensor", bitwise_or_scalar_tensor, Autograd.disable),
-            # # ("bmm", bmm, Autograd.disable),
-            # ("clamp", clamp, Autograd.disable),
-            # ("clamp.Tensor", clamp_tensor, Autograd.disable),
-            # ("cos", cos, Autograd.disable),
-            # ("pad", pad, Autograd.disable),
-            # ("constant_pad_nd", constant_pad_nd, Autograd.disable),
-            # ("cumsum", cumsum, Autograd.disable),
-            # ("cummin", cummin, Autograd.disable),
-            # ("div.Tensor", true_divide, Autograd.disable),
-            # ("div.Scalar", true_divide, Autograd.disable),
-            # ("div.Tensor_mode", div_mode, Autograd.disable),
-            # ("div.Scalar_mode", div_mode, Autograd.disable),
-            # (
-            #     "divide.Tensor",
-            #     true_divide,
-            #     Autograd.disable,
-            # ),  # divide, an alias for div
-            # ("divide.Scalar", true_divide, Autograd.disable),
-            # ("divide.Tensor_mode", div_mode, Autograd.disable),
-            # ("divide.Scalar_mode", div_mode, Autograd.disable),
-            # (
-            #     "true_divide.Tensor",
-            #     true_divide,
-            #     Autograd.disable,
-            # ),  # true_divide, an alias for div
-            # ("true_divide.Scalar", true_divide, Autograd.disable),
-            # ("floor_divide", floor_divide, Autograd.disable),
-            # ("floor_divide.Scalar", floor_divide, Autograd.disable),
-            # ("remainder.Tensor", remainder, Autograd.disable),
-            # ("native_dropout", native_dropout, Autograd.enable),
-            # ("erf", erf, Autograd.disable),
-            # ("embedding", embedding, Autograd.enable),
-            # ("eq.Tensor", eq, Autograd.disable),
-            # ("eq.Scalar", eq_scalar, Autograd.disable),
-            # ("exp", exp, Autograd.disable),
-            # ("exponential_", exponential_, Autograd.disable),
-            # ("ge.Tensor", ge, Autograd.disable),
-            # ("ge.Scalar", ge_scalar, Autograd.disable),
-            # ("gelu", gelu, Autograd.enable),
-            # ("native_group_norm", group_norm, Autograd.enable),
-            # ("_weight_norm_interface", weight_norm_interface, Autograd.enable),
-            # ("_weight_norm", weight_norm, Autograd.enable),
-            # ("gt.Tensor", gt, Autograd.disable),
-            # ("gt.Scalar", gt_scalar, Autograd.disable),
-            # ("instance_norm", instance_norm, Autograd.enable),
-            # ("isfinite", isfinite, Autograd.disable),
-            # ("isin.Tensor_Tensor", isin, Autograd.disable),
-            # ("isin.Scalar_Tensor", isin, Autograd.disable),
-            # ("isin.Tensor_Scalar", isin, Autograd.disable),
-            # ("isinf", isinf, Autograd.disable),
-            # ("isnan", isnan, Autograd.disable),
-            # ("minimum", minimum, Autograd.disable),
-            # ("maximum", maximum, Autograd.disable),
-            # ("native_layer_norm", layer_norm, Autograd.enable),
-            # ("le.Tensor", le, Autograd.disable),
-            # ("le.Scalar", le_scalar, Autograd.disable),
-            # ("lt.Tensor", lt, Autograd.disable),
-            # ("lt.Scalar", lt_scalar, Autograd.disable),
-            # ("rms_norm", rms_norm, Autograd.disable),
-            # ("rand", rand, Autograd.disable),
-            # ("randn", randn, Autograd.disable),
-            # ("rand_like", rand_like, Autograd.disable),
-            # ("randn_like", randn_like, Autograd.disable),
-            # ("zeros", zeros, Autograd.disable),
-            # ("ones", ones, Autograd.disable),
-            # ("full", full, Autograd.disable),
-            # ("zeros_like", zeros_like, Autograd.disable),
-            # ("ones_like", ones_like, Autograd.disable),
-            # ("full_like", full_like, Autograd.disable),
-            # ("resolve_neg", resolve_neg, Autograd.disable),
-            # ("resolve_conj", resolve_conj, Autograd.disable),
-            # ("normal.Tensor_float", normal_tensor_float, Autograd.disable),
-            # ("normal.float_Tensor", normal_float_tensor, Autograd.disable),
-            # ("normal.Tensor_Tensor", normal_tensor_tensor, Autograd.disable),
-            # ("uniform_", uniform_, Autograd.disable),
-            # ("mean", mean, Autograd.disable),
-            # ("mean.dim", mean_dim, Autograd.disable),
-            # ("mm", mm, Autograd.disable),
-            # ("mul.Tensor", mul, Autograd.disable),
-            # ("multinomial", multinomial, Autograd.disable),
-            # ("mv", mv, Autograd.disable),
-            # ("ne.Tensor", ne, Autograd.disable),
-            # ("ne.Scalar", ne_scalar, Autograd.disable),
-            # ("neg", neg, Autograd.disable),
-            # ("pow.Scalar", pow_scalar, Autograd.disable),
-            # ("pow.Tensor_Scalar", pow_tensor_scalar, Autograd.disable),
-            # ("pow.Tensor_Tensor", pow_tensor_tensor, Autograd.disable),
-            # ("reciprocal", reciprocal, Autograd.disable),
-            # ("relu", relu, Autograd.enable),
-            # ("rsqrt", rsqrt, Autograd.disable),
-            # ("sigmoid", sigmoid, Autograd.enable),
-            # ("silu", silu, Autograd.enable),
-            # ("sin", sin, Autograd.disable),
-            # ("softmax.int", softmax, Autograd.enable),
-            # ("sort", sort, Autograd.disable),
-            # ("sub.Tensor", sub, Autograd.disable),
-            # ("tanh", tanh, Autograd.enable),
-            # ("triu", triu, Autograd.disable),
-            # ("topk", topk, Autograd.disable),
-            # ("var_mean.correction", var_mean, Autograd.disable),
-            # ("linalg_vector_norm", vector_norm, Autograd.disable),
-            # ("where.self_out", where_self_out, Autograd.disable),
-            # ("where.self", where_self, Autograd.disable),
-            # ("where.ScalarSelf", where_scalar_self, Autograd.disable),
-            # ("where.ScalarOther", where_scalar_other, Autograd.disable),
-            # ("max", max, Autograd.disable),
-            # ("max.dim", max_dim, Autograd.disable),
-            # ("min", min, Autograd.disable),
-            # ("min.dim", min_dim, Autograd.disable),
-            # ("amax", amax, Autograd.disable),
-            # ("argmax", argmax, Autograd.disable),
-            # ("argmin", argmin, Autograd.disable),
-            # ("prod", prod, Autograd.disable),
-            # ("prod.dim_int", prod_dim, Autograd.disable),
-            # ("sum", sum, Autograd.disable),
-            # ("sum.dim_IntList", sum_dim, Autograd.disable),
-            # (
-            #     "scaled_dot_product_attention",
-            #     scaled_dot_product_attention,
-            #     Autograd.disable,
-            # ),
+            ("abs", abs, Autograd.disable),
+            ("add.Tensor", add, Autograd.disable),
+            ("addmm", addmm, Autograd.disable),
+            ("arange.start_step", arange_start, Autograd.disable),
+            ("arange.start", arange_start, Autograd.disable),
+            ("arange", arange, Autograd.disable),
+            ("batch_norm", batch_norm, Autograd.enable),
+            ("bitwise_and.Tensor", bitwise_and_tensor, Autograd.disable),
+            ("bitwise_and.Scalar", bitwise_and_scalar, Autograd.disable),
+            ("bitwise_and.Scalar_Tensor", bitwise_and_scalar_tensor, Autograd.disable),
+            ("bitwise_not", bitwise_not, Autograd.disable),
+            ("bitwise_or.Tensor", bitwise_or_tensor, Autograd.disable),
+            ("bitwise_or.Scalar", bitwise_or_scalar, Autograd.disable),
+            ("bitwise_or.Scalar_Tensor", bitwise_or_scalar_tensor, Autograd.disable),
+            # ("bmm", bmm, Autograd.disable),
+            ("clamp", clamp, Autograd.disable),
+            ("clamp.Tensor", clamp_tensor, Autograd.disable),
+            ("cos", cos, Autograd.disable),
+            ("pad", pad, Autograd.disable),
+            ("constant_pad_nd", constant_pad_nd, Autograd.disable),
+            ("cumsum", cumsum, Autograd.disable),
+            ("cummin", cummin, Autograd.disable),
+            ("div.Tensor", true_divide, Autograd.disable),
+            ("div.Scalar", true_divide, Autograd.disable),
+            ("div.Tensor_mode", div_mode, Autograd.disable),
+            ("div.Scalar_mode", div_mode, Autograd.disable),
+            (
+                "divide.Tensor",
+                true_divide,
+                Autograd.disable,
+            ),  # divide, an alias for div
+            ("divide.Scalar", true_divide, Autograd.disable),
+            ("divide.Tensor_mode", div_mode, Autograd.disable),
+            ("divide.Scalar_mode", div_mode, Autograd.disable),
+            (
+                "true_divide.Tensor",
+                true_divide,
+                Autograd.disable,
+            ),  # true_divide, an alias for div
+            ("true_divide.Scalar", true_divide, Autograd.disable),
+            ("floor_divide", floor_divide, Autograd.disable),
+            ("floor_divide.Scalar", floor_divide, Autograd.disable),
+            ("remainder.Tensor", remainder, Autograd.disable),
+            ("native_dropout", native_dropout, Autograd.enable),
+            ("erf", erf, Autograd.disable),
+            ("embedding", embedding, Autograd.enable),
+            ("eq.Tensor", eq, Autograd.disable),
+            ("eq.Scalar", eq_scalar, Autograd.disable),
+            ("exp", exp, Autograd.disable),
+            ("exponential_", exponential_, Autograd.disable),
+            ("ge.Tensor", ge, Autograd.disable),
+            ("ge.Scalar", ge_scalar, Autograd.disable),
+            ("gelu", gelu, Autograd.enable),
+            ("native_group_norm", group_norm, Autograd.enable),
+            ("_weight_norm_interface", weight_norm_interface, Autograd.enable),
+            ("_weight_norm", weight_norm, Autograd.enable),
+            ("gt.Tensor", gt, Autograd.disable),
+            ("gt.Scalar", gt_scalar, Autograd.disable),
+            ("instance_norm", instance_norm, Autograd.enable),
+            ("isfinite", isfinite, Autograd.disable),
+            ("isin.Tensor_Tensor", isin, Autograd.disable),
+            ("isin.Scalar_Tensor", isin, Autograd.disable),
+            ("isin.Tensor_Scalar", isin, Autograd.disable),
+            ("isinf", isinf, Autograd.disable),
+            ("isnan", isnan, Autograd.disable),
+            ("minimum", minimum, Autograd.disable),
+            ("maximum", maximum, Autograd.disable),
+            ("native_layer_norm", layer_norm, Autograd.enable),
+            ("le.Tensor", le, Autograd.disable),
+            ("le.Scalar", le_scalar, Autograd.disable),
+            ("lt.Tensor", lt, Autograd.disable),
+            ("lt.Scalar", lt_scalar, Autograd.disable),
+            ("rms_norm", rms_norm, Autograd.disable),
+            ("rand", rand, Autograd.disable),
+            ("randn", randn, Autograd.disable),
+            ("rand_like", rand_like, Autograd.disable),
+            ("randn_like", randn_like, Autograd.disable),
+            ("zeros", zeros, Autograd.disable),
+            ("ones", ones, Autograd.disable),
+            ("full", full, Autograd.disable),
+            ("zeros_like", zeros_like, Autograd.disable),
+            ("ones_like", ones_like, Autograd.disable),
+            ("full_like", full_like, Autograd.disable),
+            ("resolve_neg", resolve_neg, Autograd.disable),
+            ("resolve_conj", resolve_conj, Autograd.disable),
+            ("normal.Tensor_float", normal_tensor_float, Autograd.disable),
+            ("normal.float_Tensor", normal_float_tensor, Autograd.disable),
+            ("normal.Tensor_Tensor", normal_tensor_tensor, Autograd.disable),
+            ("uniform_", uniform_, Autograd.disable),
+            ("mean", mean, Autograd.disable),
+            ("mean.dim", mean_dim, Autograd.disable),
+            ("mm", mm, Autograd.disable),
+            ("mul.Tensor", mul, Autograd.disable),
+            ("multinomial", multinomial, Autograd.disable),
+            ("mv", mv, Autograd.disable),
+            ("ne.Tensor", ne, Autograd.disable),
+            ("ne.Scalar", ne_scalar, Autograd.disable),
+            ("neg", neg, Autograd.disable),
+            ("pow.Scalar", pow_scalar, Autograd.disable),
+            ("pow.Tensor_Scalar", pow_tensor_scalar, Autograd.disable),
+            ("pow.Tensor_Tensor", pow_tensor_tensor, Autograd.disable),
+            ("reciprocal", reciprocal, Autograd.disable),
+            ("relu", relu, Autograd.enable),
+            ("rsqrt", rsqrt, Autograd.disable),
+            ("sigmoid", sigmoid, Autograd.enable),
+            ("silu", silu, Autograd.enable),
+            ("sin", sin, Autograd.disable),
+            ("softmax.int", softmax, Autograd.enable),
+            ("sort", sort, Autograd.disable),
+            ("sub.Tensor", sub, Autograd.disable),
+            ("tanh", tanh, Autograd.enable),
+            ("triu", triu, Autograd.disable),
+            ("topk", topk, Autograd.disable),
+            ("var_mean.correction", var_mean, Autograd.disable),
+            ("linalg_vector_norm", vector_norm, Autograd.disable),
+            ("where.self_out", where_self_out, Autograd.disable),
+            ("where.self", where_self, Autograd.disable),
+            ("where.ScalarSelf", where_scalar_self, Autograd.disable),
+            ("where.ScalarOther", where_scalar_other, Autograd.disable),
+            ("max", max, Autograd.disable),
+            ("max.dim", max_dim, Autograd.disable),
+            ("min", min, Autograd.disable),
+            ("min.dim", min_dim, Autograd.disable),
+            ("amax", amax, Autograd.disable),
+            ("argmax", argmax, Autograd.disable),
+            ("argmin", argmin, Autograd.disable),
+            ("prod", prod, Autograd.disable),
+            ("prod.dim_int", prod_dim, Autograd.disable),
+            ("sum", sum, Autograd.disable),
+            ("sum.dim_IntList", sum_dim, Autograd.disable),
             (
                 "_flash_attention_forward",
                 flash_attention_forward,
                 Autograd.disable,
             ),
-            # ("all", all, Autograd.disable),
-            # ("all.dim", all_dim, Autograd.disable),
-            # ("all.dims", all_dims, Autograd.disable),
-            # ("any", any, Autograd.disable),
-            # ("any.dim", any_dim, Autograd.disable),
-            # ("any.dims", any_dims, Autograd.disable),
-            # ("quantile", quantile, Autograd.disable),
-            # ("log_softmax.int", log_softmax, Autograd.enable),
-            # ("outer", outer, Autograd.enable),
-            # ("cross_entropy_loss", cross_entropy_loss, Autograd.enable),
-            # ("nll_loss_forward", nll_loss_forward, Autograd.disable),
-            # ("nll_loss_backward", nll_loss_backward, Autograd.disable),
-            # ("nll_loss2d_forward", nll_loss2d_forward, Autograd.disable),
-            # ("nll_loss2d_backward", nll_loss2d_backward, Autograd.disable),
-            # ("scatter.src", scatter, Autograd.disable),
-            # ("scatter.reduce", scatter, Autograd.disable),
-            # ("gather", gather, Autograd.disable),
-            # ("gather_backward", gather_backward, Autograd.disable),
-            # ("isclose", isclose, Autograd.disable),
-            # ("allclose", allclose, Autograd.disable),
-            # ("fill.Scalar", fill_scalar, Autograd.disable),
-            # ("fill.Tensor", fill_tensor, Autograd.disable),
-            # ("flip", flip, Autograd.disable),
-            # ("slice_scatter", slice_scatter, Autograd.disable),
-            # ("select_scatter", select_scatter, Autograd.disable),
-            # ("index_select", index_select, Autograd.disable),
-            # ("tile", tile, Autograd.disable),
-            # ("masked_fill.Tensor", masked_fill, Autograd.disable),
-            # ("masked_fill.Scalar", masked_fill, Autograd.disable),
-            # ("masked_fill_.Tensor", masked_fill_, Autograd.disable),
-            # ("masked_fill_.Scalar", masked_fill_, Autograd.disable),
-            # ("_unique2", _unique2, Autograd.disable),
-            # ("_upsample_bicubic2d_aa", _upsample_bicubic2d_aa, Autograd.disable),
-            # ("upsample_nearest2d", upsample_nearest2d, Autograd.disable),
-            # ("nonzero", nonzero, Autograd.disable),
-            # ("repeat", repeat, Autograd.disable),
-            # ("masked_select", masked_select, Autograd.disable),
-            # ("stack", stack, Autograd.disable),
-            # ("hstack", hstack, Autograd.disable),
-            # ("cat", cat, Autograd.disable),
-            # (
-            #     "repeat_interleave.self_int",
-            #     repeat_interleave_self_int,
-            #     Autograd.disable,
-            # ),
-            # ("vstack", vstack, Autograd.disable),
-            # ("repeat_interleave.Tensor", repeat_interleave_tensor, Autograd.disable),
-            # (
-            #     "repeat_interleave.self_Tensor",
-            #     repeat_interleave_self_tensor,
-            #     Autograd.disable,
-            # ),
-            # ("randperm", randperm, Autograd.disable),
-            # ("diag", diag, Autograd.disable),
-            # ("diag_embed", diag_embed, Autograd.disable),
-            # ("diagonal_backward", diagonal_backward, Autograd.disable),
-            # ("index_add", index_add, Autograd.disable),
-            # ("count_nonzero", count_nonzero, Autograd.disable),
-            # ("logical_or", logical_or, Autograd.disable),
-            # ("logical_and", logical_and, Autograd.disable),
-            # ("logical_xor", logical_xor, Autograd.disable),
-            # ("logical_not", logical_not, Autograd.disable),
-            # ("log_sigmoid", log_sigmoid, Autograd.disable),
-            # ("vdot", vdot, Autograd.disable),
-            # ("mse_loss", mse_loss, Autograd.disable),
+            ("all", all, Autograd.disable),
+            ("all.dim", all_dim, Autograd.disable),
+            ("all.dims", all_dims, Autograd.disable),
+            ("any", any, Autograd.disable),
+            ("any.dim", any_dim, Autograd.disable),
+            ("any.dims", any_dims, Autograd.disable),
+            ("quantile", quantile, Autograd.disable),
+            ("log_softmax.int", log_softmax, Autograd.enable),
+            ("outer", outer, Autograd.enable),
+            ("cross_entropy_loss", cross_entropy_loss, Autograd.enable),
+            ("nll_loss_forward", nll_loss_forward, Autograd.disable),
+            ("nll_loss_backward", nll_loss_backward, Autograd.disable),
+            ("nll_loss2d_forward", nll_loss2d_forward, Autograd.disable),
+            ("nll_loss2d_backward", nll_loss2d_backward, Autograd.disable),
+            ("scatter.src", scatter, Autograd.disable),
+            ("scatter.reduce", scatter, Autograd.disable),
+            ("gather", gather, Autograd.disable),
+            ("gather_backward", gather_backward, Autograd.disable),
+            ("isclose", isclose, Autograd.disable),
+            ("allclose", allclose, Autograd.disable),
+            ("fill.Scalar", fill_scalar, Autograd.disable),
+            ("fill.Tensor", fill_tensor, Autograd.disable),
+            ("flip", flip, Autograd.disable),
+            ("slice_scatter", slice_scatter, Autograd.disable),
+            ("select_scatter", select_scatter, Autograd.disable),
+            ("index_select", index_select, Autograd.disable),
+            ("tile", tile, Autograd.disable),
+            ("masked_fill.Tensor", masked_fill, Autograd.disable),
+            ("masked_fill.Scalar", masked_fill, Autograd.disable),
+            ("masked_fill_.Tensor", masked_fill_, Autograd.disable),
+            ("masked_fill_.Scalar", masked_fill_, Autograd.disable),
+            ("_unique2", _unique2, Autograd.disable),
+            ("_upsample_bicubic2d_aa", _upsample_bicubic2d_aa, Autograd.disable),
+            ("upsample_nearest2d", upsample_nearest2d, Autograd.disable),
+            ("nonzero", nonzero, Autograd.disable),
+            ("repeat", repeat, Autograd.disable),
+            ("masked_select", masked_select, Autograd.disable),
+            ("stack", stack, Autograd.disable),
+            ("hstack", hstack, Autograd.disable),
+            ("cat", cat, Autograd.disable),
+            (
+                "repeat_interleave.self_int",
+                repeat_interleave_self_int,
+                Autograd.disable,
+            ),
+            ("vstack", vstack, Autograd.disable),
+            ("repeat_interleave.Tensor", repeat_interleave_tensor, Autograd.disable),
+            (
+                "repeat_interleave.self_Tensor",
+                repeat_interleave_self_tensor,
+                Autograd.disable,
+            ),
+            ("randperm", randperm, Autograd.disable),
+            ("diag", diag, Autograd.disable),
+            ("diag_embed", diag_embed, Autograd.disable),
+            ("diagonal_backward", diagonal_backward, Autograd.disable),
+            ("index_add", index_add, Autograd.disable),
+            ("count_nonzero", count_nonzero, Autograd.disable),
+            ("logical_or", logical_or, Autograd.disable),
+            ("logical_and", logical_and, Autograd.disable),
+            ("logical_xor", logical_xor, Autograd.disable),
+            ("logical_not", logical_not, Autograd.disable),
+            ("log_sigmoid", log_sigmoid, Autograd.disable),
+            ("vdot", vdot, Autograd.disable),
+            ("mse_loss", mse_loss, Autograd.disable),
         ),
         user_unused_ops_list=[] if unused is None else unused,
         lib=lib,
diff --git a/src/flag_gems/ops/__init__.py b/src/flag_gems/ops/__init__.py
index 01a31e633..2802e11f6 100755
--- a/src/flag_gems/ops/__init__.py
+++ b/src/flag_gems/ops/__init__.py
@@ -1,312 +1,312 @@
-# from .abs import abs
-# from .add import add
-# from .addmm import addmm
-# from .all import all, all_dim, all_dims
-# from .amax import amax
-# from .any import any, any_dim, any_dims
-# from .arange import arange, arange_start
-# from .argmax import argmax
-# from .argmin import argmin
+from .abs import abs
+from .add import add
+from .addmm import addmm
+from .all import all, all_dim, all_dims
+from .amax import amax
+from .any import any, any_dim, any_dims
+from .arange import arange, arange_start
+from .argmax import argmax
+from .argmin import argmin
 from .attention import flash_attention_forward, scaled_dot_product_attention
-# from .batch_norm import batch_norm
-# from .bitwise_and import (
-#     bitwise_and_scalar,
-#     bitwise_and_scalar_tensor,
-#     bitwise_and_tensor,
-# )
-# from .bitwise_not import bitwise_not
-# from .bitwise_or import bitwise_or_scalar, bitwise_or_scalar_tensor, bitwise_or_tensor
-# from .bmm import bmm
-# from .cat import cat
-# from .clamp import clamp, clamp_tensor
-# from .conv1d import conv1d
-# from .conv2d import conv2d
-# from .conv_depthwise2d import _conv_depthwise2d
-# from .cos import cos
-# from .count_nonzero import count_nonzero
-# from .cross_entropy_loss import cross_entropy_loss
-# from .cummin import cummin
-# from .cumsum import cumsum, normed_cumsum
-# from .diag import diag
-# from .diag_embed import diag_embed
-# from .diagonal import diagonal_backward
-# from .div import div_mode, floor_divide, remainder, true_divide
-# from .dropout import native_dropout
-# from .embedding import embedding
-# from .eq import eq, eq_scalar
-# from .erf import erf
-# from .exp import exp
-# from .exponential_ import exponential_
-# from .fill import fill_scalar, fill_tensor
-# from .flip import flip
-# from .full import full
-# from .full_like import full_like
-# from .gather import gather, gather_backward
-# from .ge import ge, ge_scalar
-# from .gelu import gelu
-# from .groupnorm import group_norm
-# from .gt import gt, gt_scalar
-# from .hstack import hstack
-# from .index_add import index_add
-# from .index_select import index_select
-# from .instancenorm import instance_norm
-# from .isclose import allclose, isclose
-# from .isfinite import isfinite
-# from .isin import isin
-# from .isinf import isinf
-# from .isnan import isnan
-# from .layernorm import layer_norm
-# from .le import le, le_scalar
-# from .log_sigmoid import log_sigmoid
-# from .log_softmax import log_softmax
-# from .logical_and import logical_and
-# from .logical_not import logical_not
-# from .logical_or import logical_or
-# from .logical_xor import logical_xor
-# from .lt import lt, lt_scalar
-# from .masked_fill import masked_fill, masked_fill_
-# from .masked_select import masked_select
-# from .max import max, max_dim
-# from .maximum import maximum
-# from .mean import mean, mean_dim
-# from .min import min, min_dim
-# from .minimum import minimum
-# from .mm import mm
-# from .mse_loss import mse_loss
-# from .mul import mul
-# from .multinomial import multinomial
-# from .mv import mv
-# from .ne import ne, ne_scalar
-# from .neg import neg
-# from .nllloss import (
-#     nll_loss2d_backward,
-#     nll_loss2d_forward,
-#     nll_loss_backward,
-#     nll_loss_forward,
-# )
-# from .nonzero import nonzero
-# from .normal import normal_float_tensor, normal_tensor_float, normal_tensor_tensor
-# from .ones import ones
-# from .ones_like import ones_like
-# from .outer import outer
-# from .pad import constant_pad_nd, pad
-# from .pow import pow_scalar, pow_tensor_scalar, pow_tensor_tensor
-# from .prod import prod, prod_dim
-# from .quantile import quantile
-# from .rand import rand
-# from .rand_like import rand_like
-# from .randn import randn
-# from .randn_like import randn_like
-# from .randperm import randperm
-# from .reciprocal import reciprocal
-# from .relu import relu
-# from .repeat import repeat
-# from .repeat_interleave import (
-#     repeat_interleave_self_int,
-#     repeat_interleave_self_tensor,
-#     repeat_interleave_tensor,
-# )
-# from .resolve_conj import resolve_conj
-# from .resolve_neg import resolve_neg
-# from .rms_norm import rms_norm
-# from .rsqrt import rsqrt
-# from .scatter import scatter
-# from .select_scatter import select_scatter
-# from .sigmoid import sigmoid
-# from .silu import silu
-# from .sin import sin
-# from .slice_scatter import slice_scatter
-# from .softmax import softmax
-# from .sort import sort
-# from .stack import stack
-# from .sub import sub
-# from .sum import sum, sum_dim
-# from .tanh import tanh
-# from .tile import tile
-# from .topk import topk
-# from .triu import triu
-# from .uniform import uniform_
-# from .unique import _unique2
-# from .upsample_bicubic2d_aa import _upsample_bicubic2d_aa
-# from .upsample_nearest2d import upsample_nearest2d
-# from .var_mean import var_mean
-# from .vdot import vdot
-# from .vector_norm import vector_norm
-# from .vstack import vstack
-# from .weightnorm import weight_norm, weight_norm_interface
-# from .where import where_scalar_other, where_scalar_self, where_self, where_self_out
-# from .zeros import zeros
-# from .zeros_like import zeros_like
+from .batch_norm import batch_norm
+from .bitwise_and import (
+    bitwise_and_scalar,
+    bitwise_and_scalar_tensor,
+    bitwise_and_tensor,
+)
+from .bitwise_not import bitwise_not
+from .bitwise_or import bitwise_or_scalar, bitwise_or_scalar_tensor, bitwise_or_tensor
+from .bmm import bmm
+from .cat import cat
+from .clamp import clamp, clamp_tensor
+from .conv1d import conv1d
+from .conv2d import conv2d
+from .conv_depthwise2d import _conv_depthwise2d
+from .cos import cos
+from .count_nonzero import count_nonzero
+from .cross_entropy_loss import cross_entropy_loss
+from .cummin import cummin
+from .cumsum import cumsum, normed_cumsum
+from .diag import diag
+from .diag_embed import diag_embed
+from .diagonal import diagonal_backward
+from .div import div_mode, floor_divide, remainder, true_divide
+from .dropout import native_dropout
+from .embedding import embedding
+from .eq import eq, eq_scalar
+from .erf import erf
+from .exp import exp
+from .exponential_ import exponential_
+from .fill import fill_scalar, fill_tensor
+from .flip import flip
+from .full import full
+from .full_like import full_like
+from .gather import gather, gather_backward
+from .ge import ge, ge_scalar
+from .gelu import gelu
+from .groupnorm import group_norm
+from .gt import gt, gt_scalar
+from .hstack import hstack
+from .index_add import index_add
+from .index_select import index_select
+from .instancenorm import instance_norm
+from .isclose import allclose, isclose
+from .isfinite import isfinite
+from .isin import isin
+from .isinf import isinf
+from .isnan import isnan
+from .layernorm import layer_norm
+from .le import le, le_scalar
+from .log_sigmoid import log_sigmoid
+from .log_softmax import log_softmax
+from .logical_and import logical_and
+from .logical_not import logical_not
+from .logical_or import logical_or
+from .logical_xor import logical_xor
+from .lt import lt, lt_scalar
+from .masked_fill import masked_fill, masked_fill_
+from .masked_select import masked_select
+from .max import max, max_dim
+from .maximum import maximum
+from .mean import mean, mean_dim
+from .min import min, min_dim
+from .minimum import minimum
+from .mm import mm
+from .mse_loss import mse_loss
+from .mul import mul
+from .multinomial import multinomial
+from .mv import mv
+from .ne import ne, ne_scalar
+from .neg import neg
+from .nllloss import (
+    nll_loss2d_backward,
+    nll_loss2d_forward,
+    nll_loss_backward,
+    nll_loss_forward,
+)
+from .nonzero import nonzero
+from .normal import normal_float_tensor, normal_tensor_float, normal_tensor_tensor
+from .ones import ones
+from .ones_like import ones_like
+from .outer import outer
+from .pad import constant_pad_nd, pad
+from .pow import pow_scalar, pow_tensor_scalar, pow_tensor_tensor
+from .prod import prod, prod_dim
+from .quantile import quantile
+from .rand import rand
+from .rand_like import rand_like
+from .randn import randn
+from .randn_like import randn_like
+from .randperm import randperm
+from .reciprocal import reciprocal
+from .relu import relu
+from .repeat import repeat
+from .repeat_interleave import (
+    repeat_interleave_self_int,
+    repeat_interleave_self_tensor,
+    repeat_interleave_tensor,
+)
+from .resolve_conj import resolve_conj
+from .resolve_neg import resolve_neg
+from .rms_norm import rms_norm
+from .rsqrt import rsqrt
+from .scatter import scatter
+from .select_scatter import select_scatter
+from .sigmoid import sigmoid
+from .silu import silu
+from .sin import sin
+from .slice_scatter import slice_scatter
+from .softmax import softmax
+from .sort import sort
+from .stack import stack
+from .sub import sub
+from .sum import sum, sum_dim
+from .tanh import tanh
+from .tile import tile
+from .topk import topk
+from .triu import triu
+from .uniform import uniform_
+from .unique import _unique2
+from .upsample_bicubic2d_aa import _upsample_bicubic2d_aa
+from .upsample_nearest2d import upsample_nearest2d
+from .var_mean import var_mean
+from .vdot import vdot
+from .vector_norm import vector_norm
+from .vstack import vstack
+from .weightnorm import weight_norm, weight_norm_interface
+from .where import where_scalar_other, where_scalar_self, where_self, where_self_out
+from .zeros import zeros
+from .zeros_like import zeros_like
 
 __all__ = [
-#     "log_sigmoid",
-#     "all",
-#     "all_dim",
-#     "all_dims",
-#     "allclose",
-#     "any",
-#     "any_dim",
-#     "any_dims",
-#     "add",
-#     "abs",
-#     "addmm",
-#     "arange",
-#     "arange_start",
-#     "batch_norm",
-#     "bitwise_and_tensor",
-#     "bitwise_and_scalar",
-#     "bitwise_and_scalar_tensor",
-#     "bitwise_not",
-#     "bitwise_or_tensor",
-#     "bitwise_or_scalar",
-#     "bitwise_or_scalar_tensor",
-#     "bmm",
-#     "clamp",
-#     "clamp_tensor",
-#     "cos",
-#     "count_nonzero",
-#     "diag",
-#     "diag_embed",
-#     "diagonal_backward",
-#     "pad",
-#     "constant_pad_nd",
-#     "cummin",
-#     "cumsum",
-#     "normed_cumsum",
-#     "true_divide",
-#     "div_mode",
-#     "floor_divide",
-#     "remainder",
-#     "zeros",
-#     "ones",
-#     "full",
-#     "native_dropout",
-#     "erf",
-#     "embedding",
-#     "eq",
-#     "eq_scalar",
-#     "exp",
-#     "fill_scalar",
-#     "fill_tensor",
-#     "exponential_",
-#     "gather",
-#     "gather_backward",
-#     "flip",
-#     "ones_like",
-#     "full_like",
-#     "zeros_like",
-#     "ge",
-#     "ge_scalar",
-#     "gelu",
-#     "group_norm",
-#     "gt",
-#     "gt_scalar",
-#     "index_select",
-#     "instance_norm",
-#     "isclose",
-#     "isfinite",
-#     "isin",
-#     "isinf",
-#     "isnan",
-#     "layer_norm",
-#     "weight_norm_interface",
-#     "weight_norm",
-#     "le",
-#     "le_scalar",
-#     "lt",
-#     "lt_scalar",
-#     "rms_norm",
-#     "mean",
-#     "mean_dim",
-#     "mm",
-#     "mul",
-#     "multinomial",
-#     "maximum",
-#     "minimum",
-#     "rand",
-#     "randn",
-#     "randperm",
-#     "rand_like",
-#     "randn_like",
-#     "resolve_neg",
-#     "resolve_conj",
-#     "normal_tensor_float",
-#     "normal_float_tensor",
-#     "normal_tensor_tensor",
-#     "uniform_",
-#     "mv",
-#     "ne",
-#     "ne_scalar",
-#     "neg",
-#     "pow_scalar",
-#     "pow_tensor_scalar",
-#     "pow_tensor_tensor",
-#     "reciprocal",
-#     "relu",
-#     "rsqrt",
-#     "scatter",
-#     "sigmoid",
-#     "silu",
-#     "sin",
-#     "softmax",
-#     "sub",
-#     "tanh",
-#     "tile",
-#     "triu",
-#     "topk",
-#     "max",
-#     "max_dim",
-#     "min",
-#     "min_dim",
-#     "sum",
-#     "sum_dim",
-#     "amax",
-#     "argmax",
-#     "argmin",
-#     "prod",
-#     "prod_dim",
-#     "quantile",
-#     "var_mean",
-#     "vector_norm",
-#     "log_softmax",
-#     "outer",
-#     "cross_entropy_loss",
-#     "where_self_out",
-#     "where_self",
-#     "where_scalar_self",
-#     "where_scalar_other",
-#     "index_add",
-#     "select_scatter",
-#     "slice_scatter",
-#     "masked_fill",
-#     "masked_fill_",
-#     "_unique2",
-#     "_upsample_bicubic2d_aa",
-#     "upsample_nearest2d",
-#     "nonzero",
-#     "repeat",
-#     "masked_select",
-#     "stack",
-#     "hstack",
-#     "cat",
-#     "repeat_interleave_self_int",
-#     "vstack",
-#     "repeat_interleave_tensor",
+    "log_sigmoid",
+    "all",
+    "all_dim",
+    "all_dims",
+    "allclose",
+    "any",
+    "any_dim",
+    "any_dims",
+    "add",
+    "abs",
+    "addmm",
+    "arange",
+    "arange_start",
+    "batch_norm",
+    "bitwise_and_tensor",
+    "bitwise_and_scalar",
+    "bitwise_and_scalar_tensor",
+    "bitwise_not",
+    "bitwise_or_tensor",
+    "bitwise_or_scalar",
+    "bitwise_or_scalar_tensor",
+    "bmm",
+    "clamp",
+    "clamp_tensor",
+    "cos",
+    "count_nonzero",
+    "diag",
+    "diag_embed",
+    "diagonal_backward",
+    "pad",
+    "constant_pad_nd",
+    "cummin",
+    "cumsum",
+    "normed_cumsum",
+    "true_divide",
+    "div_mode",
+    "floor_divide",
+    "remainder",
+    "zeros",
+    "ones",
+    "full",
+    "native_dropout",
+    "erf",
+    "embedding",
+    "eq",
+    "eq_scalar",
+    "exp",
+    "fill_scalar",
+    "fill_tensor",
+    "exponential_",
+    "gather",
+    "gather_backward",
+    "flip",
+    "ones_like",
+    "full_like",
+    "zeros_like",
+    "ge",
+    "ge_scalar",
+    "gelu",
+    "group_norm",
+    "gt",
+    "gt_scalar",
+    "index_select",
+    "instance_norm",
+    "isclose",
+    "isfinite",
+    "isin",
+    "isinf",
+    "isnan",
+    "layer_norm",
+    "weight_norm_interface",
+    "weight_norm",
+    "le",
+    "le_scalar",
+    "lt",
+    "lt_scalar",
+    "rms_norm",
+    "mean",
+    "mean_dim",
+    "mm",
+    "mul",
+    "multinomial",
+    "maximum",
+    "minimum",
+    "rand",
+    "randn",
+    "randperm",
+    "rand_like",
+    "randn_like",
+    "resolve_neg",
+    "resolve_conj",
+    "normal_tensor_float",
+    "normal_float_tensor",
+    "normal_tensor_tensor",
+    "uniform_",
+    "mv",
+    "ne",
+    "ne_scalar",
+    "neg",
+    "pow_scalar",
+    "pow_tensor_scalar",
+    "pow_tensor_tensor",
+    "reciprocal",
+    "relu",
+    "rsqrt",
+    "scatter",
+    "sigmoid",
+    "silu",
+    "sin",
+    "softmax",
+    "sub",
+    "tanh",
+    "tile",
+    "triu",
+    "topk",
+    "max",
+    "max_dim",
+    "min",
+    "min_dim",
+    "sum",
+    "sum_dim",
+    "amax",
+    "argmax",
+    "argmin",
+    "prod",
+    "prod_dim",
+    "quantile",
+    "var_mean",
+    "vector_norm",
+    "log_softmax",
+    "outer",
+    "cross_entropy_loss",
+    "where_self_out",
+    "where_self",
+    "where_scalar_self",
+    "where_scalar_other",
+    "index_add",
+    "select_scatter",
+    "slice_scatter",
+    "masked_fill",
+    "masked_fill_",
+    "_unique2",
+    "_upsample_bicubic2d_aa",
+    "upsample_nearest2d",
+    "nonzero",
+    "repeat",
+    "masked_select",
+    "stack",
+    "hstack",
+    "cat",
+    "repeat_interleave_self_int",
+    "vstack",
+    "repeat_interleave_tensor",
     "flash_attention_forward",
-#     "scaled_dot_product_attention",
-#     "conv2d",
-#     "conv1d",
-#     "_conv_depthwise2d",
-#     "repeat_interleave_self_tensor",
-#     "logical_or",
-#     "logical_and",
-#     "logical_xor",
-#     "logical_not",
-#     "sort",
-#     "nll_loss_forward",
-#     "nll_loss_backward",
-#     "nll_loss2d_forward",
-#     "nll_loss2d_backward",
-#     "vdot",
-#     "mse_loss",
+    "scaled_dot_product_attention",
+    "conv2d",
+    "conv1d",
+    "_conv_depthwise2d",
+    "repeat_interleave_self_tensor",
+    "logical_or",
+    "logical_and",
+    "logical_xor",
+    "logical_not",
+    "sort",
+    "nll_loss_forward",
+    "nll_loss_backward",
+    "nll_loss2d_forward",
+    "nll_loss2d_backward",
+    "vdot",
+    "mse_loss",
 ]
diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py
index 025270edd..1c6b98460 100644
--- a/src/flag_gems/ops/attention.py
+++ b/src/flag_gems/ops/attention.py
@@ -521,7 +521,7 @@ def apply_dropout(
     return P
 
 
-@triton.jit(do_not_specialize=['max_seqlen_q', 'max_seqlen_k'])
+@triton.jit
 def apply_mask(
     S,
     col_idx,
@@ -587,37 +587,23 @@ def softmax_rescale(
 
 def block_m_heuristic(headdim, is_dropout):
     block_m = 128 if headdim <= 128 else 64
-    print('block_m:', block_m)
     return block_m
 
 def block_n_heuristic(headdim, is_dropout):
     block_n = 64 if headdim <= 64 else 32
-    print('block_n:', block_n)
     return block_n
 
+def block_m_splitkv_heuristic(headdim):
+    return 128 if headdim <= 128 else 64
+
+def block_n_splitkv_heuristic(headdim):
+    block_n = 64 if headdim <= 64 else 32
+
 def is_even_mn(args):
     even_mn = (args["seqlen_q"] % args["BLOCK_M"] == 0) and (args["seqlen_k"] % args["BLOCK_N"] == 0)
-    print('is_even_mn:', even_mn)
     return even_mn
 
-def block_m_splitkv_heuristic(headdim):
-    return 128 if headdim <= 128 else 64
 
-def block_n_splitkv_heuristic(headdim):
-    if headdim <= 64:
-        return 64
-    else:
-        return 32
-
-# @triton.autotune(
-#     configs=runtime.get_tuned_config("attention"),
-#     key=["HEAD_DIM"],
-#     prune_configs_by={
-#         "early_config_prune": early_config_prune,
-#         "perf_model": None,
-#         "top_k": 1.0,
-#     },
-# )
 @triton.heuristics(
     values={
         'BLOCK_M': lambda args: block_m_heuristic(args["HEAD_DIM"], args["is_dropout"]),
@@ -833,7 +819,7 @@ def flash_fwd_kernel(
         O_ = tl.dot(P, V, O_, allow_tf32=False)
 
     for col_start in tl.range(col_min, col_max - masking_cols, step=BLOCK_N, num_stages=num_stages):
-        # col_start = tl.multiple_of(col_start, BLOCK_N)
+        col_start = tl.multiple_of(col_start, BLOCK_N)
         off = col_start * k_s_stride
         K = tl.load(p_bk0 + off, cache_modifier=".cg")
         if PRE_LOAD_V:
@@ -946,7 +932,7 @@ def flash_fwd_kernel(
 @triton.heuristics(
     values={
         'BLOCK_M': lambda args: block_m_splitkv_heuristic(args["HEAD_DIM"]),
-        'BLOCK_N': lambda args: block_m_splitkv_heuristic(args["HEAD_DIM"]),
+        'BLOCK_N': lambda args: block_n_splitkv_heuristic(args["HEAD_DIM"]),
         'num_warps': lambda args: 4,
         'num_stages': lambda args: 3 if args["HEAD_DIM"] <= 128 else 2,
         'PRE_LOAD_V': lambda args: True,
@@ -1023,8 +1009,6 @@ def flash_fwd_splitkv_kernel_v2(
         alibi_offset = bid * slopes_batch_stride + hid
         alibi_slope = tl.load(pSlopes + alibi_offset)
         alibi_slope /= scale
-    else:
-        alibi_slope = 0.0
 
     if not is_causal:
         if IS_EVEN_MN:
@@ -1151,7 +1135,7 @@ def flash_fwd_splitkv_kernel_v2(
 @triton.heuristics(
     values={
         'BLOCK_M': lambda args: block_m_splitkv_heuristic(args["HEAD_DIM"]),
-        'BLOCK_N': lambda args: block_m_splitkv_heuristic(args["HEAD_DIM"]),
+        'BLOCK_N': lambda args: block_n_splitkv_heuristic(args["HEAD_DIM"]),
         'num_warps': lambda args: 4,
         'num_stages': lambda args: 3 if args["HEAD_DIM"] <= 128 else 2,
         'PRE_LOAD_V': lambda args: True,
@@ -1225,8 +1209,6 @@ def flash_fwd_splitkv_kernel(
         alibi_offset = bid * slopes_batch_stride + hid
         alibi_slope = tl.load(pSlopes + alibi_offset)
         alibi_slope /= scale
-    else:
-        alibi_slope = 0.0
 
     if not is_causal:
         if IS_EVEN_MN:
@@ -1307,7 +1289,7 @@ def flash_fwd_splitkv_kernel(
                 if PRE_LOAD_V:
                     V = tl.load(p_v0 + kv_off, mask=kvmask[:, None], cache_modifier=".cg")
 
-            S = tl.dot(Q, K, allow_tf32=False)
+            S = tl.dot(Q, K)
 
             S = apply_mask(
                 S,
@@ -1339,7 +1321,7 @@ def flash_fwd_splitkv_kernel(
                 else:
                     V = tl.load(p_v0 + kv_off, mask=kvmask[:, None], cache_modifier=".cg")
             P = P.to(Q_ptr.type.element_ty)
-            O_ = tl.dot(P, V, O_, allow_tf32=False)
+            O_ = tl.dot(P, V, O_)
 
     # LSE
     lse = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), float('-inf'), rowmax_ * softmax_scale + tl.log(rowsum_))
@@ -1427,7 +1409,9 @@ def flash_fwd_splitkv_combine_kernel(
     # Write back output
     out_offset = tl.arange(0, BLOCK_M)[:, None] * out_s_stride + tl.arange(0, head_size)
     tl.store(out_ptr + out_offset, out, mask=out_mask[:, None])
-    
+
+
+_debug = False
 
 def mha_fwd(
     q,
@@ -1500,17 +1484,20 @@ def splits_heuristics(num_tasks, num_sms, n_blocks):
         if eff > 0.8 or n_waves > 1:
             return 1
 
-        best_splits = 1
-        best_eff = eff
-        min_blocks_per_split = 1
-        max_blocks_per_split = triton.cdiv(n_blocks, 2)
-        for blocks_per_split in range(min_blocks_per_split, max_blocks_per_split + 1)[::-1]:
-            n_splits = triton.cdiv(n_blocks, blocks_per_split)
-            n_waves = triton.cdiv(n_splits * num_tasks, num_sms)
-            eff = (n_splits * num_tasks / num_sms) / n_waves
-            if eff > 0.85:
-                best_splits = n_splits
-                break
+        min_blocks_per_split = 2
+        best_splits = min(triton.cdiv(n_blocks, min_blocks_per_split), int(math.floor(1. / eff)), num_sms)
+
+        # best_splits = 1
+        # best_eff = eff
+        # min_blocks_per_split = 1
+        # max_blocks_per_split = triton.cdiv(n_blocks, 2)
+        # for blocks_per_split in range(min_blocks_per_split, max_blocks_per_split + 1)[::-1]:
+        #     n_splits = triton.cdiv(n_blocks, blocks_per_split)
+        #     n_waves = triton.cdiv(n_splits * num_tasks, num_sms)
+        #     eff = (n_splits * num_tasks / num_sms) / n_waves
+        #     if eff > 0.85:
+        #         best_splits = n_splits
+        #         break
         return best_splits
 
     with torch_device_fn.device(q_device):
@@ -1576,13 +1563,14 @@ def try_split_kv():
         else:
             n_splits, blocks_per_split = 1, None
 
-#        n_splits = 1
-        block_n = block_n_splitkv_heuristic(head_size)
-        n_blocks = triton.cdiv(seqlen_k, block_n)
-        blocks_per_split = triton.cdiv(n_blocks, n_splits)
-        print('block_n:', block_n)
-        print('n_splits:', n_splits)
-        print('blocks_per_split', blocks_per_split)
+        if _debug:
+            n_splits = 32
+            block_n = block_n_splitkv_heuristic(head_size)
+            n_blocks = triton.cdiv(seqlen_k, block_n)
+            blocks_per_split = triton.cdiv(n_blocks, n_splits)
+            print('block_n:', block_n)
+            print('n_splits:', n_splits)
+            print('blocks_per_split', blocks_per_split)
 
         if n_splits > 1:
             lse_splits = torch.empty(
@@ -1659,10 +1647,11 @@ def try_split_kv():
             NUM_HEADS=num_heads,
             NUM_HEADS_K=num_heads_k,
         )
-        print(f'{kernel.name} shared memory:', kernel.metadata.shared)
-        print(f'{kernel.name} num_warps:', kernel.metadata.num_warps)
-        print(f'{kernel.name} num_stages:', kernel.metadata.num_stages)
-        # print(kernel.asm['ttgir'])
+        if debug:
+            print(f'{kernel.name} shared memory:', kernel.metadata.shared)
+            print(f'{kernel.name} num_warps:', kernel.metadata.num_warps)
+            print(f'{kernel.name} num_stages:', kernel.metadata.num_stages)
+            print(kernel.asm['ttgir'])
     
     if n_splits > 1:
         if head_size % 128 == 0:
@@ -1713,7 +1702,7 @@ def flash_attention_forward(
     seqused_k=None,
     alibi_slopes=None
 ):
-    logging.debug("GEMS FLASH_ATTENTION")
+    logging.debug("GEMS FLASH_ATTENTION_FORWARD")
     assert cum_seq_q is None and cum_seq_k is None, "varlen is not supported yet."
 
     HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1]

From 298d5e8782b355b6257575b33ac363bb408aaf56 Mon Sep 17 00:00:00 2001
From: Tongxin Bai <waffle.bai@gmail.com>
Date: Mon, 24 Mar 2025 10:04:11 +0000
Subject: [PATCH 23/25] Polish code.

---
 src/flag_gems/ops/attention.py | 420 ++++++++++++---------------------
 1 file changed, 150 insertions(+), 270 deletions(-)

diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py
index 1c6b98460..fcff69394 100644
--- a/src/flag_gems/ops/attention.py
+++ b/src/flag_gems/ops/attention.py
@@ -530,11 +530,11 @@ def apply_mask(
     max_seqlen_k,
     ws_left,
     ws_right,
-    alibi_slope,
     is_even_mn: tl.constexpr,
     is_causal: tl.constexpr,
     is_local: tl.constexpr,
     has_alibi: tl.constexpr,
+    alibi_slope: tl.constexpr=None,
 ):
     need_mask: tl.constexpr = is_causal | has_alibi | is_local | (not is_even_mn)
     if need_mask:
@@ -597,7 +597,7 @@ def block_m_splitkv_heuristic(headdim):
     return 128 if headdim <= 128 else 64
 
 def block_n_splitkv_heuristic(headdim):
-    block_n = 64 if headdim <= 64 else 32
+    return 64 if headdim <= 64 else 32
 
 def is_even_mn(args):
     even_mn = (args["seqlen_q"] % args["BLOCK_M"] == 0) and (args["seqlen_k"] % args["BLOCK_N"] == 0)
@@ -755,11 +755,11 @@ def flash_fwd_kernel(
             seqlen_k,
             ws_left,
             ws_right,
-            alibi_slope,
             is_even_mn=IS_EVEN_MN,
             is_causal=is_causal,
             is_local=is_local,
-            has_alibi=has_alibi
+            has_alibi=has_alibi,
+            alibi_slope=alibi_slope,
         )
 
         O_, P, rowmax_, rowsum_ = softmax_rescale(
@@ -836,11 +836,11 @@ def flash_fwd_kernel(
             seqlen_k,
             ws_left,
             ws_right,
-            alibi_slope,
             is_even_mn=True,
             is_causal=False,
             is_local=is_local,
-            has_alibi=has_alibi
+            has_alibi=has_alibi,
+            alibi_slope=alibi_slope,
         )
 
         O_, P, rowmax_, rowsum_ = softmax_rescale(
@@ -940,7 +940,7 @@ def flash_fwd_kernel(
     }
 )
 @triton.jit(do_not_specialize=["seqlen_q", "seqlen_k", "philox_seed", "philox_offset"])
-def flash_fwd_splitkv_kernel_v2(
+def flash_fwd_bh_parallel_kernel(
     Q_ptr,
     K_ptr,
     V_ptr,
@@ -989,148 +989,9 @@ def flash_fwd_splitkv_kernel_v2(
     num_warps: tl.constexpr,
     num_stages: tl.constexpr
 ):
-    m_block = tl.program_id(0)
-    split_id = tl.program_id(1)
-    bid = tl.program_id(2) // NUM_HEADS
-    hid = tl.program_id(2) % NUM_HEADS
-
-    split_col_min = split_id * blocks_per_split * BLOCK_N
-    split_col_max = split_col_min + blocks_per_split * BLOCK_N
-
-    col_min = 0
-    
-    col_max = tl.cdiv(seqlen_k, BLOCK_N) * BLOCK_N
-    if is_causal:
-        col_max = min(col_max, (m_block + 1) * BLOCK_M + seqlen_k - seqlen_q + ws_right)
-
-    split_col_max = min(split_col_max, col_max)
-
-    if has_alibi:
-        alibi_offset = bid * slopes_batch_stride + hid
-        alibi_slope = tl.load(pSlopes + alibi_offset)
-        alibi_slope /= scale
-
-    if not is_causal:
-        if IS_EVEN_MN:
-            masking_cols: tl.constexpr = 0
-        else:
-            masking_cols: tl.constexpr = BLOCK_N
-    elif is_causal and IS_EVEN_MN: # causal implies ws_right is zero
-        masking_cols: tl.constexpr = tl.cdiv(BLOCK_M, BLOCK_N) * BLOCK_N
-    else:
-        # local and not causal, 
-        masking_cols: tl.constexpr = (tl.cdiv(BLOCK_M, BLOCK_N) + 1) * BLOCK_N
-
-    Q_ptr += bid * q_b_stride
-    Q_ptr += hid * q_h_stride
-    row_start = m_block * BLOCK_M
-    row_idx = row_start + tl.arange(0, BLOCK_M)
-    Q_off = row_idx[:, None] * q_s_stride + tl.arange(0, HEAD_DIM)[None, :]
-    p_qm = Q_ptr + Q_off
-    qmask = row_idx[:, None] < seqlen_q
-    if IS_EVEN_MN:
-        Q = tl.load(p_qm, cache_modifier=".cg")
-    else:
-        Q = tl.load(p_qm, mask=qmask, cache_modifier=".cg")
-
-    h_hk_ratio = h // hk
-    K_ptr += bid * k_b_stride
-    K_ptr += (hid // h_hk_ratio) * k_h_stride
-    V_ptr += bid * k_b_stride
-    V_ptr += (hid // h_hk_ratio) * k_h_stride
-
-    K_offset = tl.arange(0, BLOCK_N)[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None]
-    p_bk0 = K_ptr + K_offset
-    
-    V_offset = tl.arange(0, BLOCK_N)[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :]
-    p_bv0 = V_ptr + V_offset
-
-    O_ = tl.zeros((BLOCK_M, HEAD_DIM), dtype=tl.float32)
-    rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
-    rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32)
-
-    for col_start in tl.range(split_col_min, split_col_max, step=BLOCK_N):
-        col_start = tl.multiple_of(col_start, BLOCK_N)
-        off = col_start * k_s_stride
-        if IS_EVEN_MN:
-            K = tl.load(p_bk0 + off, cache_modifier=".cg")
-            if PRE_LOAD_V:
-                V = tl.load(p_bv0 + off, cache_modifier=".cg")
-        else:
-            col_idx = col_start + tl.arange(0, BLOCK_N)
-            kvmask = col_idx < seqlen_k
-            K = tl.load(p_bk0 + off, mask=kvmask[None, :], cache_modifier=".cg")
-            if PRE_LOAD_V:
-                V = tl.load(p_bv0 + off, mask=kvmask[:, None], cache_modifier=".cg")
-        S = tl.dot(Q, K, allow_tf32=False)
-        col_idx = col_start + tl.arange(0, BLOCK_N)
-        row_idx = row_start + tl.arange(0, BLOCK_M)
-        S = apply_mask(
-            S,
-            col_idx,
-            row_idx,
-            seqlen_q,
-            seqlen_k,
-            ws_left,
-            ws_right,
-            alibi_slope,
-            is_even_mn=IS_EVEN_MN,
-            is_causal=is_causal,
-            is_local=False,
-            has_alibi=has_alibi
-        )
-
-        O_, P, rowmax_, rowsum_ = softmax_rescale(
-            O_,
-            S,
-            rowmax_,
-            rowsum_,
-            softmax_scale_log2e=softmax_scale_log2e,
-            is_border=(is_causal or is_local),
-        )
-        P = P.to(V_ptr.type.element_ty)
-
-        if not PRE_LOAD_V:
-            off = col_start * k_s_stride
-            if IS_EVEN_MN:
-                V = tl.load(p_bv0 + off, cache_modifier=".cg")
-            else:
-                V = tl.load(p_bv0 + off, mask=kvmask[:, None], cache_modifier=".cg")
-        O_ = tl.dot(P, V, O_, allow_tf32=False)
-
-    # LSE
-    lse = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), float('-inf'), rowmax_ * softmax_scale + tl.log(rowsum_))
-    inv_sum = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), 1.0, 1.0 / rowsum_)
-    
-    # Rescale output
-    O_ *= inv_sum[:, None]
-
-    # Write back output
-    # O_splits layout = (n_splits, batch_size, num_heads, seqlen_q, head_size)
-    # grid = (seq_block, split, batch * head)
-    O_split_ptr = O_ptr
-    # + split, batch, head offsets, seq_block offsets are already added in row_idx
-    O_split_ptr += (split_id * tl.num_programs(2) + tl.program_id(2)) * seqlen_q * HEAD_DIM
-    O_split_offset = row_idx[:, None] * HEAD_DIM + tl.arange(0, HEAD_DIM)
-    O_split_ptr = tl.multiple_of(O_split_ptr, HEAD_DIM)
-    p_om = O_split_ptr + O_split_offset
-
-    if IS_EVEN_MN:
-        tl.store(p_om, O_, cache_modifier=".cg")
-    else:
-        tl.store(p_om, O_, mask=qmask, cache_modifier=".cg")
+    # (TODO)
+    pass
     
-    # Write back lse
-    # lse_splits layout = (n_splits, batch_size, num_heads, seqlen_q)
-    lse_split_ptr = lse_ptr
-    # + split, batch, head, seq_block offsets
-    lse_split_ptr += (split_id * tl.num_programs(2) + tl.program_id(2)) * seqlen_q + m_block * BLOCK_M
-
-    if IS_EVEN_MN:
-        tl.store(lse_split_ptr + tl.arange(0, BLOCK_M), lse, cache_modifier=".cg")
-    else:
-        tl.store(lse_split_ptr + tl.arange(0, BLOCK_M), lse, mask=row_idx < seqlen_q, cache_modifier=".cg")
-
 
 @triton.heuristics(
     values={
@@ -1209,6 +1070,8 @@ def flash_fwd_splitkv_kernel(
         alibi_offset = bid * slopes_batch_stride + hid
         alibi_slope = tl.load(pSlopes + alibi_offset)
         alibi_slope /= scale
+    else:
+        alibi_slope = 0
 
     if not is_causal:
         if IS_EVEN_MN:
@@ -1299,11 +1162,11 @@ def flash_fwd_splitkv_kernel(
                 seqlen_k,
                 ws_left,
                 ws_right,
-                alibi_slope,
                 is_even_mn=IS_EVEN_MN,
                 is_causal=is_causal,
                 is_local=False,
-                has_alibi=has_alibi
+                has_alibi=has_alibi,
+                alibi_slope=alibi_slope,
             )
 
             O_, P, rowmax_, rowsum_ = softmax_rescale(
@@ -1547,134 +1410,151 @@ def splits_heuristics(num_tasks, num_sms, n_blocks):
         # ONLY EVEN_K IS SUPPORTED
         assert head_size == head_size_rounded
 
-        # Check splitkv
-        def try_split_kv():
-            block_m = block_m_splitkv_heuristic(head_size)
-            n_tasks = batch_size * num_heads * triton.cdiv(seqlen_q, block_m)
+        # Do kernel dispatching
+        def dispatch(B, H, Q, K, D):
             num_sms = torch_device_fn.get_device_properties("cuda").multi_processor_count
-            block_n = block_n_splitkv_heuristic(head_size)
-            n_blocks = triton.cdiv(seqlen_k, block_n)
-            n_splits = splits_heuristics(n_tasks, num_sms, n_blocks)
-            blocks_per_split = triton.cdiv(n_blocks, n_splits)
-            return n_splits, blocks_per_split
-        
-        if not is_dropout and not is_local:
-            n_splits, blocks_per_split = try_split_kv()
-        else:
-            n_splits, blocks_per_split = 1, None
 
-        if _debug:
-            n_splits = 32
-            block_n = block_n_splitkv_heuristic(head_size)
-            n_blocks = triton.cdiv(seqlen_k, block_n)
-            blocks_per_split = triton.cdiv(n_blocks, n_splits)
-            print('block_n:', block_n)
-            print('n_splits:', n_splits)
-            print('blocks_per_split', blocks_per_split)
-
-        if n_splits > 1:
-            lse_splits = torch.empty(
-                (n_splits, batch_size, num_heads, seqlen_q),
-                dtype=torch.float,
-                device=q_device
-            )
-            out_splits = torch.empty(
-                (n_splits, batch_size, num_heads, seqlen_q, head_size),
-                dtype=torch.float,
-                device=q_device
-            )
-
-        # Launch kernel
-        if n_splits > 1:
-            grid = lambda args: (
-                triton.cdiv(seqlen_q, args["BLOCK_M"]),
-                n_splits,
-                batch_size * num_heads
-            )
-            kernel = flash_fwd_splitkv_kernel_v2[grid]
-            tmp_lse = lse_splits
-            tmp_out = out_splits
-        else:
-            grid = lambda args: (
-                triton.cdiv(seqlen_q, args["BLOCK_M"]), # num_m_blocks
-                batch_size,
-                num_heads,
-            )
+            default_args = {}
+
+            # Try bh parallel
+            # if B * H > 0.8 * num_sms:
+            #     kernel = flash_fwd_bh_parallel_kernel[(H, B)]
+            #     # Yield kernel and prefilled args
+            #     return kernel, default_args, None, None
+
+            # Try splitkv
+            if not is_dropout and not is_local:
+                BM = block_m_splitkv_heuristic(D)
+                n_tasks = B * H * triton.cdiv(seqlen_q, BM)
+                BN = block_n_splitkv_heuristic(D)
+                n_blocks = triton.cdiv(seqlen_k, BN)
+                n_splits = splits_heuristics(n_tasks, num_sms, n_blocks)
+                
+                if _debug:
+                    n_splits = 32
+                    n_blocks = triton.cdiv(K, BN)
+                    blocks_per_split = triton.cdiv(n_blocks, n_splits)
+                    print('block_n:', block_n)
+                    print('n_splits:', n_splits)
+                    print('blocks_per_split', blocks_per_split)
+
+                if n_splits > 1:
+                    lse_splits = torch.empty(
+                        (n_splits, B, H, Q),
+                        dtype=torch.float,
+                        device=q_device
+                    )
+                    out_splits = torch.empty(
+                        (n_splits, B, H, Q, D),
+                        dtype=torch.float,
+                        device=q_device
+                    )
+                    grid = lambda args: (
+                        triton.cdiv(Q, args["BLOCK_M"]),
+                        n_splits,
+                        B * H
+                    )
+                    splitkv_kernel = flash_fwd_splitkv_kernel[grid]
+                    blocks_per_split = triton.cdiv(n_blocks, n_splits)
+                    splitkv_args = default_args.copy()
+                    splitkv_args['blocks_per_split'] = blocks_per_split
+                    splitkv_args['O_ptr'] = out_splits
+                    splitkv_args['lse_ptr'] = lse_splits
+                    # kernel = yield kernel, args
+                
+                    if D % 128 == 0:
+                        BLOCK_M = 4
+                    elif D % 64 == 0:
+                        BLOCK_M = 8
+                    else:
+                        BLOCK_M = 16
+                    grid = lambda args: (triton.cdiv(B * H * Q, BLOCK_M), )
+                    combine_kernel = flash_fwd_splitkv_combine_kernel[grid]
+                    combine_args = {
+                        'out_splits_ptr': out_splits,
+                        'lse_splits_ptr': lse_splits,
+                        'n_splits': n_splits,
+                        'BLOCK_M': BLOCK_M,
+                        'q_total': B * H * Q,
+                        'MAX_N_SPLITS': triton.next_power_of_2(n_splits)
+                    }
+                    return splitkv_kernel, splitkv_args, combine_kernel, combine_args
+            
+            # Last option: flash_fwd
+            grid = lambda args: (triton.cdiv(Q, args["BLOCK_M"]), B, H, )
             kernel = flash_fwd_kernel[grid]
-            tmp_lse = lse
-            tmp_out = out
+            return kernel, default_args, None, None
+
+        kernel1, kernel1_args, kernel2, kernel2_args = dispatch(batch_size, num_heads, seqlen_q, seqlen_k, head_size)
+
+        prefilled_args = {
+            "Q_ptr": q,
+            "K_ptr": k,
+            "V_ptr": v,
+            "P_ptr": p,
+            "O_ptr": out,
+            "lse_ptr": lse,
+            "seqlen_q": seqlen_q,
+            "seqlen_k": seqlen_k,
+            "seqlen_q_rounded": seqlen_q_rounded,
+            "seqlen_k_rounded": seqlen_k_rounded,
+            "q_b_stride": q.stride(0),
+            "q_s_stride": q.stride(-3),
+            "q_h_stride": q.stride(-2),
+            "k_b_stride": k.stride(0),
+            "k_s_stride": k.stride(-3),
+            "k_h_stride": k.stride(-2),
+            "o_b_stride": out.stride(0),
+            "o_s_stride": out.stride(-3),
+            "o_h_stride": out.stride(-2),
+            "h": num_heads,
+            "hk": num_heads_k,
+            "pSlopes": alibi_slopes,
+            "philox_seed": philox_seed,
+            "philox_offset": philox_offset,
+            "pdrop_u8": pdrop_u8,
+            "rpdrop": rpdrop,
+            "slopes_batch_stride": alibi_slopes_batch_stride,
+            "HEAD_DIM": head_size,
+            "is_dropout": is_dropout,
+            "is_causal": is_causal,
+            "is_local": is_local,
+            "has_alibi": has_alibi,
+            "softmax_scale": softmax_scale,
+            "softmax_scale_log2e": softmax_scale_log2e,
+            "ws_left": window_size_left,
+            "ws_right": window_size_right,
+            "return_P": return_softmax,
+            "BATCH_SIZE": batch_size,
+            "blocks_per_split": None,
+            "NUM_HEADS": num_heads,
+            "NUM_HEADS_K": num_heads_k,
+        }
+        
+        args_copy = prefilled_args.copy()
+        args_copy.update(kernel1_args)
 
-        kernel = kernel(
-            q,
-            k,
-            v,
-            p,
-            tmp_out,
-            tmp_lse,
-            seqlen_q,
-            seqlen_k,
-            seqlen_q_rounded,
-            seqlen_k_rounded,
-            q.stride(0),
-            q.stride(-3),
-            q.stride(-2),
-            k.stride(0),
-            k.stride(-3),
-            k.stride(-2),
-            out.stride(0),
-            out.stride(-3),
-            out.stride(-2),
-            num_heads,
-            num_heads_k,
-            alibi_slopes,
-            philox_seed,
-            philox_offset,
-            pdrop_u8,
-            rpdrop,
-            alibi_slopes_batch_stride,
-            head_size,
-            is_dropout=is_dropout,
-            is_causal=is_causal,
-            is_local=is_local,
-            has_alibi=has_alibi,
-            softmax_scale=softmax_scale,
-            softmax_scale_log2e=softmax_scale_log2e,
-            ws_left=window_size_left,
-            ws_right=window_size_right,
-            return_P=return_softmax,
-            BATCH_SIZE=batch_size,
-            blocks_per_split=blocks_per_split,
-            NUM_HEADS=num_heads,
-            NUM_HEADS_K=num_heads_k,
-        )
-        if debug:
+        kernel = kernel1(**args_copy)
+        if _debug:
             print(f'{kernel.name} shared memory:', kernel.metadata.shared)
             print(f'{kernel.name} num_warps:', kernel.metadata.num_warps)
             print(f'{kernel.name} num_stages:', kernel.metadata.num_stages)
             print(kernel.asm['ttgir'])
-    
-    if n_splits > 1:
-        if head_size % 128 == 0:
-            BLOCK_M = 4
-        elif head_size % 64 == 0:
-            BLOCK_M = 8
-        else:
-            BLOCK_M = 16
-        grid = lambda args: (triton.cdiv(batch_size * num_heads * seqlen_q, BLOCK_M), )
-        kernel = flash_fwd_splitkv_combine_kernel[grid](
-            out,
-            lse,
-            tmp_out,
-            tmp_lse,
-            head_size,
-            out.stride(0),
-            out.stride(-3),
-            out.stride(-1),
-            n_splits,
-            BLOCK_M,
-            q_total=batch_size * num_heads * seqlen_q,
-            MAX_N_SPLITS=triton.next_power_of_2(n_splits),
-        )
+
+        # Combine
+        if kernel2 is not None:
+            prefilled_args = {
+                "out_ptr": out,
+                "lse_ptr": lse,
+                "head_size": head_size,
+                "out_b_stride": out.stride(0),
+                "out_s_stride": out.stride(-3),
+                "out_h_stride": out.stride(-1),
+            }
+            args_copy = prefilled_args.copy()
+            args_copy.update(kernel2_args)
+            kernel2(**args_copy)
+
 
     if swap_seq_and_group:
         out = out.transpose(1, 2).reshape((batch_size, 1, num_heads_k * seqlen_q, head_size))

From 6a741dd3470354246e41372ac34a718b0b4da4de Mon Sep 17 00:00:00 2001
From: Tongxin Bai <waffle.bai@gmail.com>
Date: Thu, 27 Mar 2025 04:30:10 +0000
Subject: [PATCH 24/25] fixed numerous bugs.

---
 src/flag_gems/ops/attention.py | 216 +++++++++++++++++++--------------
 tests/test_attention_ops.py    | 177 ++++++++++++++++++++++++---
 2 files changed, 280 insertions(+), 113 deletions(-)

diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py
index fcff69394..6bc3d1b99 100644
--- a/src/flag_gems/ops/attention.py
+++ b/src/flag_gems/ops/attention.py
@@ -548,7 +548,7 @@ def apply_mask(
             S = tl.where(col_idx[None, :] > col_rb[:, None], float('-inf'), S)
 
         if is_local:
-            S = tl.where(col_idx[None, :] > col_rb[:, None] | col_idx[None, :] < col_lb[:, None], float('-inf'), S)
+            S = tl.where((col_idx[None, :] > col_rb[:, None]) | (col_idx[None, :] < col_lb[:, None]), float('-inf'), S)
         
         if (not is_local) & (not is_causal) & (not is_even_mn):
             S = tl.where(col_idx[None, :] >= max_seqlen_k, float('-inf'), S)
@@ -676,7 +676,20 @@ def flash_fwd_kernel(
     
     col_max = tl.cdiv(seqlen_k, BLOCK_N) * BLOCK_N
     if is_causal or is_local:
-        col_max = min(col_max, (m_block + 1) * BLOCK_M + seqlen_k - seqlen_q + ws_right)
+        rounded_kv_end = max(0, tl.cdiv(seqlen_k - seqlen_q + ws_right, BLOCK_N) * BLOCK_N + (m_block + 1) * BLOCK_M)
+        col_max = min(col_max, rounded_kv_end)
+
+    if (not is_causal) and (not is_local):
+        if IS_EVEN_MN:
+            masking_cols: tl.constexpr = 0
+        else:
+            masking_cols: tl.constexpr = BLOCK_N
+    elif (is_causal | is_local) and IS_EVEN_MN: # causal implies ws_right is zero
+        masking_cols: tl.constexpr = tl.cdiv(BLOCK_M, BLOCK_N) * BLOCK_N
+    else:
+        # local  
+        masking_cols: tl.constexpr = (tl.cdiv(BLOCK_M, BLOCK_N) + 1) * BLOCK_N
+
 
     if has_alibi:
         alibi_offset = bid * slopes_batch_stride + hid
@@ -719,65 +732,78 @@ def flash_fwd_kernel(
     p_bk0 = K_ptr + K_offset
     p_bv0 = V_ptr + V_offset
 
-    if (not is_causal) and (not is_local):
-        if IS_EVEN_MN:
-            masking_cols: tl.constexpr = 0
-        else:
-            masking_cols: tl.constexpr = BLOCK_N
-    elif is_causal and IS_EVEN_MN: # causal implies ws_right is zero
-        masking_cols: tl.constexpr = tl.cdiv(BLOCK_M, BLOCK_N) * BLOCK_N
-    else:
-        # local and not causal, 
-        masking_cols: tl.constexpr = (tl.cdiv(BLOCK_M, BLOCK_N) + 1) * BLOCK_N
-
-    for col_shift in tl.range(0, masking_cols, step=BLOCK_N):
-        col_start = col_max - col_shift - BLOCK_N
-        col_start = tl.multiple_of(col_start, BLOCK_N)
-        off = col_start * k_s_stride
-        if IS_EVEN_MN:
-            K = tl.load(p_bk0 + off, cache_modifier=".cg")
-            if PRE_LOAD_V:
-                V = tl.load(p_bv0 + off, cache_modifier=".cg")
-        else:
+    if is_causal | is_local | (not IS_EVEN_MN):
+        # Cut short masking cols if there's not enough cols out there
+        masking_cols = min(col_max, masking_cols)
+        for col_shift in tl.range(0, masking_cols, step=BLOCK_N):
+            col_start = col_max - col_shift - BLOCK_N
+            col_start = tl.multiple_of(col_start, BLOCK_N)
+            off = col_start * k_s_stride
+            if IS_EVEN_MN:
+                K = tl.load(p_bk0 + off, cache_modifier=".cg")
+                if PRE_LOAD_V:
+                    V = tl.load(p_bv0 + off, cache_modifier=".cg")
+            else:
+                col_idx = col_start + tl.arange(0, BLOCK_N)
+                kvmask = col_idx < seqlen_k
+                K = tl.load(p_bk0 + off, mask=kvmask[None, :], cache_modifier=".cg")
+                if PRE_LOAD_V:
+                    V = tl.load(p_bv0 + off, mask=kvmask[:, None], cache_modifier=".cg")
+            S = tl.dot(Q, K, allow_tf32=False)
             col_idx = col_start + tl.arange(0, BLOCK_N)
-            kvmask = col_idx < seqlen_k
-            K = tl.load(p_bk0 + off, mask=kvmask[None, :], cache_modifier=".cg")
-            if PRE_LOAD_V:
-                V = tl.load(p_bv0 + off, mask=kvmask[:, None], cache_modifier=".cg")
-        S = tl.dot(Q, K, allow_tf32=False)
-        col_idx = col_start + tl.arange(0, BLOCK_N)
-        row_idx = row_start + tl.arange(0, BLOCK_M)
-        S = apply_mask(
-            S,
-            col_idx,
-            row_idx,
-            seqlen_q,
-            seqlen_k,
-            ws_left,
-            ws_right,
-            is_even_mn=IS_EVEN_MN,
-            is_causal=is_causal,
-            is_local=is_local,
-            has_alibi=has_alibi,
-            alibi_slope=alibi_slope,
-        )
+            row_idx = row_start + tl.arange(0, BLOCK_M)
 
-        O_, P, rowmax_, rowsum_ = softmax_rescale(
-            O_,
-            S,
-            rowmax_,
-            rowsum_,
-            softmax_scale_log2e=softmax_scale_log2e,
-            is_border=(is_causal or is_local),
-        )
-        P = P.to(V_ptr.type.element_ty)
+            # tl.store(p_bp0 + col_start, S)
+            S = apply_mask(
+                S,
+                col_idx,
+                row_idx,
+                seqlen_q,
+                seqlen_k,
+                ws_left,
+                ws_right,
+                is_even_mn=IS_EVEN_MN,
+                is_causal=is_causal,
+                is_local=is_local,
+                has_alibi=has_alibi,
+                alibi_slope=alibi_slope,
+            )
 
-        if is_dropout:
-            if return_P:
-                P_drop = P
+            O_, P, rowmax_, rowsum_ = softmax_rescale(
+                O_,
+                S,
+                rowmax_,
+                rowsum_,
+                softmax_scale_log2e=softmax_scale_log2e,
+                is_border=(is_causal or is_local),
+            )
+            P = P.to(V_ptr.type.element_ty)
+
+            if is_dropout:
+                if return_P:
+                    P_drop = P
+
+                    P_drop = apply_dropout(
+                        P_drop,
+                        row_start,
+                        col_start,
+                        bid,
+                        hid,
+                        philox_seed,
+                        philox_offset,
+                        pdrop_u8,
+                        encode_dropout_in_sign_bit=True,
+                        NUM_HEADS=NUM_HEADS,
+                        BLOCK_M=BLOCK_M,
+                        BLOCK_N=BLOCK_N,
+                    )
+                    if IS_EVEN_MN:
+                        tl.store(p_bp0 + col_start, P_drop)
+                    else:
+                        tl.store(p_bp0 + col_start, P_drop, mask=qmask & kvmask[None, :])
 
-                P_drop = apply_dropout(
-                    P_drop,
+                P = apply_dropout(
+                    P,
                     row_start,
                     col_start,
                     bid,
@@ -785,38 +811,19 @@ def flash_fwd_kernel(
                     philox_seed,
                     philox_offset,
                     pdrop_u8,
-                    encode_dropout_in_sign_bit=True,
+                    encode_dropout_in_sign_bit=False,
                     NUM_HEADS=NUM_HEADS,
                     BLOCK_M=BLOCK_M,
                     BLOCK_N=BLOCK_N,
                 )
+
+            if not PRE_LOAD_V:
+                off = col_start * k_s_stride
                 if IS_EVEN_MN:
-                    tl.store(p_bp0 + col_start, P_drop)
+                    V = tl.load(p_bv0 + off, cache_modifier=".cg")
                 else:
-                    tl.store(p_bp0 + col_start, P_drop, mask=qmask & kvmask[None, :])
-
-            P = apply_dropout(
-                P,
-                row_start,
-                col_start,
-                bid,
-                hid,
-                philox_seed,
-                philox_offset,
-                pdrop_u8,
-                encode_dropout_in_sign_bit=False,
-                NUM_HEADS=NUM_HEADS,
-                BLOCK_M=BLOCK_M,
-                BLOCK_N=BLOCK_N,
-            )
-
-        if not PRE_LOAD_V:
-            off = col_start * k_s_stride
-            if IS_EVEN_MN:
-                V = tl.load(p_bv0 + off, cache_modifier=".cg")
-            else:
-                V = tl.load(p_bv0 + off, mask=kvmask[:, None], cache_modifier=".cg")
-        O_ = tl.dot(P, V, O_, allow_tf32=False)
+                    V = tl.load(p_bv0 + off, mask=kvmask[:, None], cache_modifier=".cg")
+            O_ = tl.dot(P, V, O_, allow_tf32=False)
 
     for col_start in tl.range(col_min, col_max - masking_cols, step=BLOCK_N, num_stages=num_stages):
         col_start = tl.multiple_of(col_start, BLOCK_N)
@@ -899,7 +906,7 @@ def flash_fwd_kernel(
 
     # LSE
     # Note, rowsum = exp(-rowmax) * lse, therefore rowmax + log(rowsum) cancels the effect of rowmax and outputs lse only.
-    lse = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), float('-inf'), rowmax_ * softmax_scale + tl.log(rowsum_))
+    lse = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), float('inf'), rowmax_ * softmax_scale + tl.log(rowsum_))
 
     # Rescale output
     inv_sum = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), 1.0, 1.0 / rowsum_)
@@ -922,11 +929,13 @@ def flash_fwd_kernel(
         tl.store(O_ptr + O_offset, O, mask=qmask)
     
     # Write back lse
-    lse_ptr += bid * hid * seqlen_q
+    p_lse = lse_ptr + (bid * h + hid) * seqlen_q
+    row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
+
     if IS_EVEN_MN:
-        tl.store(lse_ptr + row_idx, lse)
+        tl.store(p_lse + row_idx, lse)
     else:
-        tl.store(lse_ptr + row_idx, lse, mask=row_idx < seqlen_q)
+        tl.store(p_lse + row_idx, lse, mask=row_idx < seqlen_q)
 
 
 @triton.heuristics(
@@ -1288,6 +1297,7 @@ def mha_fwd(
     window_size_left,
     window_size_right,
     return_softmax,
+    disable_splitkv=False
 ):
     q_dtype = q.dtype
     q_device = q.device
@@ -1405,7 +1415,8 @@ def splits_heuristics(num_tasks, num_sms, n_blocks):
             has_alibi = False
 
         # Set SWA params
-        is_local = (window_size_left >= 0 or window_size_right >= 0) and not is_causal
+        is_causal = window_size_left < 0 and window_size_right == 0
+        is_local = (window_size_left >= 0 and window_size_right >= 0)
 
         # ONLY EVEN_K IS SUPPORTED
         assert head_size == head_size_rounded
@@ -1423,7 +1434,7 @@ def dispatch(B, H, Q, K, D):
             #     return kernel, default_args, None, None
 
             # Try splitkv
-            if not is_dropout and not is_local:
+            if not is_dropout and not is_local and not disable_splitkv:
                 BM = block_m_splitkv_heuristic(D)
                 n_tasks = B * H * triton.cdiv(seqlen_q, BM)
                 BN = block_n_splitkv_heuristic(D)
@@ -1434,7 +1445,7 @@ def dispatch(B, H, Q, K, D):
                     n_splits = 32
                     n_blocks = triton.cdiv(K, BN)
                     blocks_per_split = triton.cdiv(n_blocks, n_splits)
-                    print('block_n:', block_n)
+                    print('block_n:', BN)
                     print('n_splits:', n_splits)
                     print('blocks_per_split', blocks_per_split)
 
@@ -1487,6 +1498,14 @@ def dispatch(B, H, Q, K, D):
 
         kernel1, kernel1_args, kernel2, kernel2_args = dispatch(batch_size, num_heads, seqlen_q, seqlen_k, head_size)
 
+        if _debug:
+            p = torch.empty(
+                (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded),
+                dtype=torch.float32,
+                device=q_device
+            )
+            return_softmax = True
+
         prefilled_args = {
             "Q_ptr": q,
             "K_ptr": k,
@@ -1539,7 +1558,8 @@ def dispatch(B, H, Q, K, D):
             print(f'{kernel.name} shared memory:', kernel.metadata.shared)
             print(f'{kernel.name} num_warps:', kernel.metadata.num_warps)
             print(f'{kernel.name} num_stages:', kernel.metadata.num_stages)
-            print(kernel.asm['ttgir'])
+            # print(kernel.asm['ttgir'])
+            print('p:', p)
 
         # Combine
         if kernel2 is not None:
@@ -1580,7 +1600,8 @@ def flash_attention_forward(
     window_size_left=None,
     window_size_right=None,
     seqused_k=None,
-    alibi_slopes=None
+    alibi_slopes=None,
+    disable_splitkv=False
 ):
     logging.debug("GEMS FLASH_ATTENTION_FORWARD")
     assert cum_seq_q is None and cum_seq_k is None, "varlen is not supported yet."
@@ -1591,8 +1612,14 @@ def flash_attention_forward(
     assert HEAD_DIM_K in {16, 32, 64, 128, 256}
 
     softmax_scale = scale or 1.0 / (HEAD_DIM_K**0.5)
-    non_null_window_left = window_size_left or -1
-    non_null_window_right = window_size_right or -1
+    if window_size_left is not None:
+        non_null_window_left = window_size_left
+    else:
+        non_null_window_left = -1
+    if window_size_right is not None:
+        non_null_window_right = window_size_right
+    else:
+        non_null_window_right = -1
 
     out, q, k, v, lse, philox_seed, philox_offset, p = mha_fwd(
         query,
@@ -1606,6 +1633,7 @@ def flash_attention_forward(
         non_null_window_left,
         non_null_window_right,
         return_debug_mask,
+        disable_splitkv=disable_splitkv
     )
     
     return (out, lse, philox_seed, philox_offset, p)
diff --git a/tests/test_attention_ops.py b/tests/test_attention_ops.py
index 5af90fe85..ca5e07350 100644
--- a/tests/test_attention_ops.py
+++ b/tests/test_attention_ops.py
@@ -8,13 +8,16 @@
 
 
 @pytest.mark.scaled_dot_product_attention
-@pytest.mark.parametrize("batch", [8, 16])
-@pytest.mark.parametrize("num_head", [1, 8])
-@pytest.mark.parametrize("q_seq_len", [17, 64, 128])
-@pytest.mark.parametrize("kv_seq_len", [128, 2048])
-@pytest.mark.parametrize("head_size", [64, 128])
-@pytest.mark.parametrize("add_bias", [True, False])
-@pytest.mark.parametrize("is_causal", [True, False])
+@pytest.mark.parametrize(["batch", "num_head", "q_seq_len", "kv_seq_len"],
+                         [
+                          (1, 1, 128, 2048),
+                          (4, 8, 1024, 1024),
+                          (4, 8, 1024, 128),
+                          (4, 8, 17, 1030)
+                          ])
+@pytest.mark.parametrize("head_size", [128])
+@pytest.mark.parametrize("add_bias", [False])
+@pytest.mark.parametrize("is_causal", [False, True])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 def test_scaled_dot_product_attention(
     batch, num_head, q_seq_len, kv_seq_len, head_size, add_bias, is_causal, dtype
@@ -48,15 +51,57 @@ def test_scaled_dot_product_attention(
 
     scale = float(1.0 / np.sqrt(head_size))
 
-    if is_causal:
-        torch_result = torch.nn.functional.scaled_dot_product_attention(
-            ref_query,
-            ref_key,
-            ref_value,
+    if is_causal and q_seq_len != kv_seq_len:
+        # Pytorch treats non-square causal as a special case where the sdp attention
+        # does not route to flash attn and instead uses mem-eff attn.
+        # In this case, we directly compare on the lower level _flash_attention_forward
+        q = ref_query.transpose(1, 2)
+        k = ref_key.transpose(1, 2)
+        v = ref_value.transpose(1, 2)
+        out, *_ = torch.ops.aten._flash_attention_forward(
+            q,
+            k,
+            v,
+            None,
+            None,
+            q.shape[-3],
+            k.shape[-3],
+            0,
+            is_causal,
+            False,
             scale=scale,
-            is_causal=is_causal,
+            window_size_left=None,
+            window_size_right=None,
+            seqused_k=None,
+            alibi_slopes=None
         )
+        torch_result = out.transpose(1, 2)
+        
+        with flag_gems.use_gems():
+            q = query.transpose(1, 2)
+            k = key.transpose(1, 2)
+            v = value.transpose(1, 2)
+            out, *_ = torch.ops.aten._flash_attention_forward(
+                q,
+                k,
+                v,
+                None,
+                None,
+                q.shape[-3],
+                k.shape[-3],
+                0,
+                is_causal,
+                False,
+                scale=scale,
+                window_size_left=None,
+                window_size_right=None,
+                seqused_k=None,
+                alibi_slopes=None
+            )
+            flaggem_result = out.transpose(1, 2)
     else:
+        attn_mask = None if is_causal else ref_attn_bias
+
         torch_result = torch.nn.functional.scaled_dot_product_attention(
             ref_query,
             ref_key,
@@ -66,13 +111,107 @@ def test_scaled_dot_product_attention(
             is_causal=is_causal,
         )
 
-    with flag_gems.use_gems():
-        if is_causal:
-            flaggem_result = torch.nn.functional.scaled_dot_product_attention(
-                query, key, value, scale=scale, is_causal=is_causal
-            )
-        else:
+        with flag_gems.use_gems():
+            attn_mask = None if is_causal else attn_bias
             flaggem_result = torch.nn.functional.scaled_dot_product_attention(
                 query, key, value, attn_mask=attn_bias, scale=scale, is_causal=is_causal
             )
+
+    gems_assert_close(flaggem_result, torch_result, dtype)
+
+
+@pytest.mark.flash_attention_forward
+@pytest.mark.parametrize(["batch", "num_head", "q_seq_len", "kv_seq_len"],
+                         [
+                          (1, 1, 128, 2048),
+                          (8, 32, 1024, 1024),
+                          (8, 32, 1024, 128),
+                          (8, 32, 17, 1030)
+                          ])
+@pytest.mark.parametrize("head_size", [128])
+@pytest.mark.parametrize(["window_size_left", "window_size_right"],
+                         [
+                            (None, None),
+                            (256, 0),
+                            (128, 128)
+                          ])
+@pytest.mark.parametrize("is_causal", [False])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+def test_flash_attention_forward(
+    batch, num_head, q_seq_len, kv_seq_len, head_size, is_causal, window_size_left, window_size_right, dtype
+):
+    np.random.seed(0)
+    np_query = np.random.uniform(
+        -0.05, 0.05, (batch, num_head, q_seq_len, head_size)
+    ).astype(np.float32)
+    np_key = np.random.uniform(
+        -0.05, 0.05, (batch, num_head, kv_seq_len, head_size)
+    ).astype(np.float32)
+    np_value = np.random.uniform(
+        -0.05, 0.05, (batch, num_head, kv_seq_len, head_size)
+    ).astype(np.float32)
+    np_attn_bias = np.random.uniform(
+        -0.05, 0.05, (batch, num_head, q_seq_len, kv_seq_len)
+    ).astype(np.float32)
+
+    query = torch.tensor(np_query, device="cuda", dtype=dtype)
+    key = torch.tensor(np_key, device="cuda", dtype=dtype)
+    value = torch.tensor(np_value, device="cuda", dtype=dtype)
+
+    ref_query = to_reference(query, False)
+    ref_key = to_reference(key, False)
+    ref_value = to_reference(value, False)
+
+    scale = float(1.0 / np.sqrt(head_size))
+    dropout_p = 0
+    return_debug_mask = False
+
+    q = ref_query.transpose(1, 2)
+    k = ref_key.transpose(1, 2)
+    v = ref_value.transpose(1, 2)
+    out, lse, _, _, _  = torch.ops.aten._flash_attention_forward(
+        q,
+        k,
+        v,
+        None,
+        None,
+        q.shape[-3],
+        k.shape[-3],
+        dropout_p,
+        is_causal,
+        return_debug_mask,
+        scale=scale,
+        window_size_left=window_size_left,
+        window_size_right=window_size_right,
+        seqused_k=None,
+        alibi_slopes=None
+    )
+    torch_result = out.transpose(1, 2)
+    torch_lse = lse.transpose(1, 2)
+    
+    with flag_gems.use_gems():
+        q = query.transpose(1, 2)
+        k = key.transpose(1, 2)
+        v = value.transpose(1, 2)
+        out, lse, _, _, _  = torch.ops.aten._flash_attention_forward(
+            q,
+            k,
+            v,
+            None,
+            None,
+            q.shape[-3],
+            k.shape[-3],
+            dropout_p,
+            is_causal,
+            return_debug_mask,
+            scale=scale,
+            window_size_left=window_size_left,
+            window_size_right=window_size_right,
+            seqused_k=None,
+            alibi_slopes=None
+        )
+        flaggem_result = out.transpose(1, 2)
+        flaggem_lse = lse.transpose(1, 2)
+
     gems_assert_close(flaggem_result, torch_result, dtype)
+    gems_assert_close(flaggem_lse, torch_lse, torch.float)
\ No newline at end of file

From 2aad10d6414a14aa5edb910c73cb1f25363a5e7d Mon Sep 17 00:00:00 2001
From: Tongxin Bai <waffle.bai@gmail.com>
Date: Fri, 28 Mar 2025 08:24:44 +0000
Subject: [PATCH 25/25] Non-square, causal, swa, all pass.

---
 src/flag_gems/ops/attention.py |  52 +++++----
 tests/test_attention_ops.py    | 204 +++++++++++++++++++--------------
 2 files changed, 149 insertions(+), 107 deletions(-)

diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py
index 6bc3d1b99..c6e0f5b4a 100644
--- a/src/flag_gems/ops/attention.py
+++ b/src/flag_gems/ops/attention.py
@@ -538,8 +538,9 @@ def apply_mask(
 ):
     need_mask: tl.constexpr = is_causal | has_alibi | is_local | (not is_even_mn)
     if need_mask:
+        # Extra care should be taken to void one-off errors: both col_lb and col_rb are inclusive!
         col_lb = max(0, row_idx + max_seqlen_k - max_seqlen_q - ws_left)
-        col_rb = min(max_seqlen_k, row_idx + max_seqlen_k - max_seqlen_q + ws_right)
+        col_rb = min(max_seqlen_k - 1, row_idx + max_seqlen_k - max_seqlen_q + ws_right)
 
         if has_alibi:
             S -= alibi_slope * tl.abs(col_idx[None, :] - row_idx[:, None])
@@ -579,8 +580,8 @@ def softmax_rescale(
     O_acc *= p_scale[:, None]
 
     max_scaled = tl.where(row_max == float('-inf'), 0, row_max * softmax_scale_log2e)
-
     P = tl.math.exp2(S * softmax_scale_log2e - max_scaled[:, None])
+    cur_rowsum = tl.sum(P, 1)
     row_sum = row_sum + tl.sum(P, 1)
     return O_acc, P, row_max, row_sum
 
@@ -599,9 +600,12 @@ def block_m_splitkv_heuristic(headdim):
 def block_n_splitkv_heuristic(headdim):
     return 64 if headdim <= 64 else 32
 
-def is_even_mn(args):
-    even_mn = (args["seqlen_q"] % args["BLOCK_M"] == 0) and (args["seqlen_k"] % args["BLOCK_N"] == 0)
-    return even_mn
+def is_even_mn(M, N, BM, BN, WL, WR):
+    if M % BM == 0 and N % BN == 0:
+        if M % N == 0 or N % M == 0:
+            if (WL == -1 or WL % BN == 0) and (WR == -1 or WR % BN == 0):
+                return True
+    return False
 
 
 @triton.heuristics(
@@ -611,7 +615,7 @@ def is_even_mn(args):
         'num_warps': lambda args: 4,
         'num_stages': lambda args: 3 if args["HEAD_DIM"] <= 128 else 2,
         'PRE_LOAD_V': lambda args: False,
-        'IS_EVEN_MN': lambda args: is_even_mn(args),
+        'IS_EVEN_MN': lambda args: is_even_mn(args["seqlen_q"], args["seqlen_k"], args["BLOCK_M"], args["BLOCK_N"], args["ws_left"], args["ws_right"]),
     }
 )
 @triton.jit(do_not_specialize=["seqlen_q", "seqlen_k", "philox_seed", "philox_offset"])
@@ -667,17 +671,28 @@ def flash_fwd_kernel(
     m_block = tl.program_id(0)
     bid = tl.program_id(1)
     hid = tl.program_id(2)
+    num_m_blocks = tl.cdiv(seqlen_q, BLOCK_M)
+
+    # We draw a minimum covering frame on the attention map that this CTA is assigned to process.
+    # The frame edges are rounded to multiples of BLOCK_M and BLOCK_N for rows and columns respectively. 
 
+    col_min = 0
     if is_local:
-        col_min = m_block * BLOCK_M + seqlen_k - seqlen_q - ws_left
-        col_min = max(col_min, 0)
-    else:
-        col_min = 0
+        col_min = max(0, m_block * BLOCK_M + seqlen_k - seqlen_q - ws_left)
+        if not IS_EVEN_MN:
+            # round left
+            col_min = (col_min // BLOCK_N) * BLOCK_N
     
-    col_max = tl.cdiv(seqlen_k, BLOCK_N) * BLOCK_N
+    col_max = seqlen_k
     if is_causal or is_local:
-        rounded_kv_end = max(0, tl.cdiv(seqlen_k - seqlen_q + ws_right, BLOCK_N) * BLOCK_N + (m_block + 1) * BLOCK_M)
-        col_max = min(col_max, rounded_kv_end)
+        col_max += (m_block - num_m_blocks + 1) * BLOCK_M
+        if is_local:
+            col_max += ws_right
+        col_max = min(seqlen_k, col_max)
+
+    if not IS_EVEN_MN:
+        # round right
+        col_max =  tl.cdiv(col_max, BLOCK_N) * BLOCK_N
 
     if (not is_causal) and (not is_local):
         if IS_EVEN_MN:
@@ -690,7 +705,6 @@ def flash_fwd_kernel(
         # local  
         masking_cols: tl.constexpr = (tl.cdiv(BLOCK_M, BLOCK_N) + 1) * BLOCK_N
 
-
     if has_alibi:
         alibi_offset = bid * slopes_batch_stride + hid
         alibi_slope = tl.load(pSlopes + alibi_offset)
@@ -734,7 +748,7 @@ def flash_fwd_kernel(
 
     if is_causal | is_local | (not IS_EVEN_MN):
         # Cut short masking cols if there's not enough cols out there
-        masking_cols = min(col_max, masking_cols)
+        masking_cols = min(col_max - col_min, masking_cols)
         for col_shift in tl.range(0, masking_cols, step=BLOCK_N):
             col_start = col_max - col_shift - BLOCK_N
             col_start = tl.multiple_of(col_start, BLOCK_N)
@@ -907,9 +921,8 @@ def flash_fwd_kernel(
     # LSE
     # Note, rowsum = exp(-rowmax) * lse, therefore rowmax + log(rowsum) cancels the effect of rowmax and outputs lse only.
     lse = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), float('inf'), rowmax_ * softmax_scale + tl.log(rowsum_))
-
-    # Rescale output
     inv_sum = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), 1.0, 1.0 / rowsum_)
+
     if is_dropout:
         O_ *= inv_sum[:, None] * rpdrop
     else:
@@ -945,7 +958,7 @@ def flash_fwd_kernel(
         'num_warps': lambda args: 4,
         'num_stages': lambda args: 3 if args["HEAD_DIM"] <= 128 else 2,
         'PRE_LOAD_V': lambda args: True,
-        'IS_EVEN_MN': lambda args: (args["seqlen_q"] % args["BLOCK_M"] == 0) and (args["seqlen_k"] % args["BLOCK_N"] == 0),
+        'IS_EVEN_MN': lambda args: is_even_mn(args["seqlen_q"], args["seqlen_k"], args["BLOCK_M"], args["BLOCK_N"], args["ws_left"], args["ws_right"]),
     }
 )
 @triton.jit(do_not_specialize=["seqlen_q", "seqlen_k", "philox_seed", "philox_offset"])
@@ -1009,7 +1022,7 @@ def flash_fwd_bh_parallel_kernel(
         'num_warps': lambda args: 4,
         'num_stages': lambda args: 3 if args["HEAD_DIM"] <= 128 else 2,
         'PRE_LOAD_V': lambda args: True,
-        'IS_EVEN_MN': lambda args: (args["seqlen_q"] % args["BLOCK_M"] == 0) and (args["seqlen_k"] % args["BLOCK_N"] == 0),
+        'IS_EVEN_MN': lambda args: is_even_mn(args["seqlen_q"], args["seqlen_k"], args["BLOCK_M"], args["BLOCK_N"], args["ws_left"], args["ws_right"]),
     }
 )
 @triton.jit(do_not_specialize=["seqlen_q", "seqlen_k", "philox_seed", "philox_offset"])
@@ -1575,7 +1588,6 @@ def dispatch(B, H, Q, K, D):
             args_copy.update(kernel2_args)
             kernel2(**args_copy)
 
-
     if swap_seq_and_group:
         out = out.transpose(1, 2).reshape((batch_size, 1, num_heads_k * seqlen_q, head_size))
         q = q.transpose(1, 2).reshape((batch_size, 1, num_heads_k * seqlen_q, head_size))
diff --git a/tests/test_attention_ops.py b/tests/test_attention_ops.py
index ca5e07350..458fc8d31 100644
--- a/tests/test_attention_ops.py
+++ b/tests/test_attention_ops.py
@@ -7,21 +7,7 @@
 from .accuracy_utils import gems_assert_close, to_reference
 
 
-@pytest.mark.scaled_dot_product_attention
-@pytest.mark.parametrize(["batch", "num_head", "q_seq_len", "kv_seq_len"],
-                         [
-                          (1, 1, 128, 2048),
-                          (4, 8, 1024, 1024),
-                          (4, 8, 1024, 128),
-                          (4, 8, 17, 1030)
-                          ])
-@pytest.mark.parametrize("head_size", [128])
-@pytest.mark.parametrize("add_bias", [False])
-@pytest.mark.parametrize("is_causal", [False, True])
-@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
-def test_scaled_dot_product_attention(
-    batch, num_head, q_seq_len, kv_seq_len, head_size, add_bias, is_causal, dtype
-):
+def make_input(batch, num_head, q_seq_len, kv_seq_len, head_size, dtype):
     np.random.seed(0)
     np_query = np.random.uniform(
         -0.05, 0.05, (batch, num_head, q_seq_len, head_size)
@@ -39,25 +25,126 @@ def test_scaled_dot_product_attention(
     query = torch.tensor(np_query, device="cuda", dtype=dtype)
     key = torch.tensor(np_key, device="cuda", dtype=dtype)
     value = torch.tensor(np_value, device="cuda", dtype=dtype)
-    if add_bias:
-        attn_bias = torch.tensor(np_attn_bias, device="cuda", dtype=dtype)
-    else:
-        attn_bias = None
+    
+    return query, key, value
+
+
+@pytest.mark.scaled_dot_product_attention
+@pytest.mark.parametrize(["batch", "num_head", "q_seq_len", "kv_seq_len"],
+                         [
+                          (4, 8, 1024, 1024),
+                          ])
+@pytest.mark.parametrize("head_size", [64, 128, 256])
+@pytest.mark.parametrize("is_causal", [False, True])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+def test_sdpa_square_qk_even_mn(
+    batch, num_head, q_seq_len, kv_seq_len, head_size, is_causal, dtype
+):
+    query, key, value = make_input(batch, num_head, q_seq_len, kv_seq_len, head_size, dtype)
+    ref_query = to_reference(query, False)
+    ref_key = to_reference(key, False)
+    ref_value = to_reference(value, False)
+
+    scale = float(1.0 / np.sqrt(head_size))
+    torch_result = torch.nn.functional.scaled_dot_product_attention(
+        ref_query,
+        ref_key,
+        ref_value,
+        attn_mask=None,
+        scale=scale,
+        is_causal=is_causal,
+    )
+
+    with flag_gems.use_gems():
+        flaggem_result = torch.nn.functional.scaled_dot_product_attention(
+            query, key, value, attn_mask=None, scale=scale, is_causal=is_causal
+        )
+    
+    gems_assert_close(flaggem_result, torch_result, dtype)
 
+
+@pytest.mark.scaled_dot_product_attention
+@pytest.mark.parametrize(["batch", "num_head", "q_seq_len", "kv_seq_len"],
+                         [
+                          (1, 1, 128, 2048),
+                          (4, 8, 1024, 128),
+                          (4, 8, 17, 1030)
+                          ])
+@pytest.mark.parametrize("head_size", [64, 128, 256])
+@pytest.mark.parametrize("is_causal", [False])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+def test_sdpa_nonsquare_qk(
+    batch, num_head, q_seq_len, kv_seq_len, head_size, is_causal, dtype
+):
+    query, key, value = make_input(batch, num_head, q_seq_len, kv_seq_len, head_size, dtype)
     ref_query = to_reference(query, False)
     ref_key = to_reference(key, False)
     ref_value = to_reference(value, False)
-    ref_attn_bias = to_reference(attn_bias, False) if add_bias else None
 
     scale = float(1.0 / np.sqrt(head_size))
+    torch_result = torch.nn.functional.scaled_dot_product_attention(
+        ref_query,
+        ref_key,
+        ref_value,
+        attn_mask=None,
+        scale=scale,
+        is_causal=is_causal,
+    )
+
+    with flag_gems.use_gems():
+        flaggem_result = torch.nn.functional.scaled_dot_product_attention(
+            query, key, value, attn_mask=None, scale=scale, is_causal=is_causal
+        )
+    
+    gems_assert_close(flaggem_result, torch_result, dtype)
 
-    if is_causal and q_seq_len != kv_seq_len:
-        # Pytorch treats non-square causal as a special case where the sdp attention
-        # does not route to flash attn and instead uses mem-eff attn.
-        # In this case, we directly compare on the lower level _flash_attention_forward
-        q = ref_query.transpose(1, 2)
-        k = ref_key.transpose(1, 2)
-        v = ref_value.transpose(1, 2)
+
+@pytest.mark.scaled_dot_product_attention
+@pytest.mark.parametrize(["batch", "num_head", "q_seq_len", "kv_seq_len"],
+                         [
+                          (1, 1, 128, 2048),
+                          (4, 8, 1024, 128),
+                          (4, 8, 17, 1030)
+                          ])
+@pytest.mark.parametrize("head_size", [64, 128, 256])
+@pytest.mark.parametrize("is_causal", [True])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+def test_sdpa_nonsquare_qk_causal(
+    batch, num_head, q_seq_len, kv_seq_len, head_size, is_causal, dtype
+):
+    query, key, value = make_input(batch, num_head, q_seq_len, kv_seq_len, head_size, dtype)
+    ref_query = to_reference(query, False)
+    ref_key = to_reference(key, False)
+    ref_value = to_reference(value, False)
+
+    scale = float(1.0 / np.sqrt(head_size))
+
+    q = ref_query.transpose(1, 2)
+    k = ref_key.transpose(1, 2)
+    v = ref_value.transpose(1, 2)
+    out, *_ = torch.ops.aten._flash_attention_forward(
+        q,
+        k,
+        v,
+        None,
+        None,
+        q.shape[-3],
+        k.shape[-3],
+        0,
+        is_causal,
+        False,
+        scale=scale,
+        window_size_left=None,
+        window_size_right=None,
+        seqused_k=None,
+        alibi_slopes=None
+    )
+    torch_result = out.transpose(1, 2)
+    
+    with flag_gems.use_gems():
+        q = query.transpose(1, 2)
+        k = key.transpose(1, 2)
+        v = value.transpose(1, 2)
         out, *_ = torch.ops.aten._flash_attention_forward(
             q,
             k,
@@ -75,47 +162,7 @@ def test_scaled_dot_product_attention(
             seqused_k=None,
             alibi_slopes=None
         )
-        torch_result = out.transpose(1, 2)
-        
-        with flag_gems.use_gems():
-            q = query.transpose(1, 2)
-            k = key.transpose(1, 2)
-            v = value.transpose(1, 2)
-            out, *_ = torch.ops.aten._flash_attention_forward(
-                q,
-                k,
-                v,
-                None,
-                None,
-                q.shape[-3],
-                k.shape[-3],
-                0,
-                is_causal,
-                False,
-                scale=scale,
-                window_size_left=None,
-                window_size_right=None,
-                seqused_k=None,
-                alibi_slopes=None
-            )
-            flaggem_result = out.transpose(1, 2)
-    else:
-        attn_mask = None if is_causal else ref_attn_bias
-
-        torch_result = torch.nn.functional.scaled_dot_product_attention(
-            ref_query,
-            ref_key,
-            ref_value,
-            attn_mask=ref_attn_bias,
-            scale=scale,
-            is_causal=is_causal,
-        )
-
-        with flag_gems.use_gems():
-            attn_mask = None if is_causal else attn_bias
-            flaggem_result = torch.nn.functional.scaled_dot_product_attention(
-                query, key, value, attn_mask=attn_bias, scale=scale, is_causal=is_causal
-            )
+        flaggem_result = out.transpose(1, 2)
 
     gems_assert_close(flaggem_result, torch_result, dtype)
 
@@ -131,32 +178,15 @@ def test_scaled_dot_product_attention(
 @pytest.mark.parametrize("head_size", [128])
 @pytest.mark.parametrize(["window_size_left", "window_size_right"],
                          [
-                            (None, None),
                             (256, 0),
                             (128, 128)
                           ])
 @pytest.mark.parametrize("is_causal", [False])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
-def test_flash_attention_forward(
+def test_flash_fwd_swa(
     batch, num_head, q_seq_len, kv_seq_len, head_size, is_causal, window_size_left, window_size_right, dtype
 ):
-    np.random.seed(0)
-    np_query = np.random.uniform(
-        -0.05, 0.05, (batch, num_head, q_seq_len, head_size)
-    ).astype(np.float32)
-    np_key = np.random.uniform(
-        -0.05, 0.05, (batch, num_head, kv_seq_len, head_size)
-    ).astype(np.float32)
-    np_value = np.random.uniform(
-        -0.05, 0.05, (batch, num_head, kv_seq_len, head_size)
-    ).astype(np.float32)
-    np_attn_bias = np.random.uniform(
-        -0.05, 0.05, (batch, num_head, q_seq_len, kv_seq_len)
-    ).astype(np.float32)
-
-    query = torch.tensor(np_query, device="cuda", dtype=dtype)
-    key = torch.tensor(np_key, device="cuda", dtype=dtype)
-    value = torch.tensor(np_value, device="cuda", dtype=dtype)
+    query, key, value = make_input(batch, num_head, q_seq_len, kv_seq_len, head_size, dtype)
 
     ref_query = to_reference(query, False)
     ref_key = to_reference(key, False)