From 94a0c5c5244f3e81e447377061fd9e72381b22e4 Mon Sep 17 00:00:00 2001 From: drisspg Date: Mon, 5 Jan 2026 00:38:22 +0000 Subject: [PATCH 1/2] score-mod backward SM90 Adds score_mod and mask_mod support to SM90 backward pass: - score_mod, score_mod_bwd, mask_mod, has_aux_tensors parameters - apply_score_mod() and apply_score_mod_bwd() methods - fastdiv_mods and aux_tensors plumbing through kernel/mma - mask_mod application in mask_fn for both block-sparse and dense paths - Score modification in mma_one_m_block before softmax stack-info: PR: https://github.com/Dao-AILab/flash-attention/pull/2137, branch: drisspg/stack/8 --- flash_attn/cute/flash_bwd_sm90.py | 195 +++++++++++++++++++++++++++++- flash_attn/cute/interface.py | 5 +- tests/cute/test_score_mod.py | 17 ++- 3 files changed, 213 insertions(+), 4 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm90.py b/flash_attn/cute/flash_bwd_sm90.py index 66f3e8233f..08075c2779 100644 --- a/flash_attn/cute/flash_bwd_sm90.py +++ b/flash_attn/cute/flash_bwd_sm90.py @@ -9,6 +9,7 @@ import cutlass.utils.hopper_helpers as sm90_utils_basic from cutlass.cute.nvgpu import cpasync, warpgroup from cutlass.cute.arch import ProxyKind, SharedSpace +from cutlass.cute import FastDivmodDivisor from cutlass import Float32, Int32, Boolean, const_expr from cutlass.utils import LayoutEnum @@ -22,6 +23,7 @@ from flash_attn.cute import pipeline from flash_attn.cute.tile_scheduler import TileSchedulerArguments, SingleTileScheduler, ParamsBase from flash_attn.cute.named_barrier import NamedBarrierFwd, NamedBarrierBwd +from flash_attn.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner from flash_attn.cute.block_sparsity import BlockSparseTensors from flash_attn.cute.block_sparse_utils import ( get_total_q_block_count_bwd, @@ -68,6 +70,10 @@ def __init__( AtomLayoutMdQ: int = 1, num_threads: int = 384, V_in_regs: bool = False, + score_mod: cutlass.Constexpr | None = None, + score_mod_bwd: cutlass.Constexpr | None = None, + mask_mod: cutlass.Constexpr | None = None, + has_aux_tensors: cutlass.Constexpr = False, subtile_factor: cutlass.Constexpr[int] = 1, ): self.dtype = dtype @@ -113,7 +119,17 @@ def __init__( # TODO: impl these for hdim 64 self.shuffle_LSE = self.SdP_swapAB and self.tile_hdim <= 64 self.shuffle_dPsum = self.SdP_swapAB and self.tile_hdim <= 64 + + self.score_mod = score_mod + self.score_mod_bwd = score_mod_bwd + self.mask_mod = mask_mod + self.has_aux_tensors = has_aux_tensors self.subtile_factor = subtile_factor + if cutlass.const_expr(has_aux_tensors): + self.vec_size: cutlass.Constexpr = 1 + else: + self.vec_size: cutlass.Constexpr = 4 + self.qk_acc_dtype = Float32 @staticmethod def can_implement( @@ -432,7 +448,18 @@ def __call__( grid_dim = TileScheduler.get_grid_shape(tile_sched_params) LOG2_E = math.log2(math.e) - softmax_scale_log2 = softmax_scale * LOG2_E + if const_expr(self.score_mod is None): + softmax_scale_log2 = softmax_scale * LOG2_E + else: + softmax_scale_log2 = LOG2_E + + fastdiv_mods = None + if const_expr(aux_tensors is not None): + seqlen_q = cute.size(mQ.shape[0]) + seqlen_k = cute.size(mK.shape[0]) + seqlen_q_divmod = FastDivmodDivisor(seqlen_q) + seqlen_k_divmod = FastDivmodDivisor(seqlen_k) + fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod) self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None) @@ -468,6 +495,8 @@ def __call__( tile_sched_params, TileScheduler, SharedStorage, + aux_tensors, + fastdiv_mods, blocksparse_tensors, ).launch( grid=grid_dim, @@ -511,6 +540,8 @@ def kernel( 
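+        # aux_tensors carries the user-provided score_mod buffers; fastdiv_mods is the
+        # (seqlen_q_divmod, seqlen_k_divmod) FastDivmodDivisor pair built in __call__,
+        # used by apply_score_mod_inner for index computations when aux tensors are present.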
tile_sched_params: ParamsBase, TileScheduler: cutlass.Constexpr[Callable], SharedStorage: cutlass.Constexpr[Callable], + aux_tensors: Optional[list] = None, + fastdiv_mods=(None, None), blocksparse_tensors: Optional[BlockSparseTensors] = None, ): warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) @@ -671,6 +702,8 @@ def kernel( SeqlenInfoCls, AttentionMaskCls, TileSchedulerCls, + aux_tensors, + fastdiv_mods, blocksparse_tensors, ) @@ -883,6 +916,100 @@ def load( tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() + @cute.jit + def apply_score_mod( + self, + acc_S: cute.Tensor, + thr_mma_SdP: cute.core.ThrMma, + batch_idx, + head_idx, + m_block, + n_block, + softmax_scale, + seqlen_info: SeqlenInfoQK, + aux_tensors=None, + fastdiv_mods=(None, None), + ): + # Create index tensor matching SM100's approach - no view transformation + # Index tensor shape and offset are (n, m) when SdP_swapAB=True to match + # the transposed score matrix layout + cS = cute.make_identity_tensor( + (self.tile_n, self.tile_m) if self.SdP_swapAB else (self.tile_m, self.tile_n) + ) + cS = cute.domain_offset( + (n_block * self.tile_n, m_block * self.tile_m) + if self.SdP_swapAB + else (m_block * self.tile_m, n_block * self.tile_n), + cS, + ) + tScS = thr_mma_SdP.partition_C(cS) + + # Pass acc_S directly without view transformation to keep alignment with index tensor + # (matching SM100's approach which doesn't use make_acc_tensor_mn_view) + apply_score_mod_inner( + acc_S, + tScS, + self.score_mod, + batch_idx, + head_idx, + softmax_scale, + self.vec_size, + self.qk_acc_dtype, + aux_tensors, + fastdiv_mods, + seqlen_info, + constant_q_idx=None, + qhead_per_kvhead=self.qhead_per_kvhead, + transpose_indices=self.SdP_swapAB, + ) + + @cute.jit + def apply_score_mod_bwd( + self, + grad_tensor: cute.Tensor, + score_tensor: cute.Tensor, + thr_mma_SdP: cute.core.ThrMma, + batch_idx, + head_idx, + m_block, + n_block, + softmax_scale, + seqlen_info: SeqlenInfoQK, + aux_tensors=None, + fastdiv_mods=(None, None), + ): + # Create index tensor matching SM100's approach - no view transformation + cS = cute.make_identity_tensor( + (self.tile_n, self.tile_m) if self.SdP_swapAB else (self.tile_m, self.tile_n) + ) + cS = cute.domain_offset( + (n_block * self.tile_n, m_block * self.tile_m) + if self.SdP_swapAB + else (m_block * self.tile_m, n_block * self.tile_n), + cS, + ) + tScS = thr_mma_SdP.partition_C(cS) + + # Pass tensors directly without view transformation to keep alignment with index tensor + # (matching SM100's approach which doesn't use make_acc_tensor_mn_view) + apply_score_mod_bwd_inner( + grad_tensor, + score_tensor, + tScS, + self.score_mod_bwd, + batch_idx, + head_idx, + softmax_scale, + self.vec_size, + self.qk_acc_dtype, + aux_tensors, + fastdiv_mods, + seqlen_info, + constant_q_idx=None, + qhead_per_kvhead=self.qhead_per_kvhead, + transpose_indices=self.SdP_swapAB, + ) + @cute.jit def mma( self, @@ -914,6 +1041,8 @@ def mma( SeqlenInfoCls: Callable, AttentionMaskCls: Callable, TileSchedulerCls: Callable, + aux_tensors: Optional[list] = None, + fastdiv_mods=(None, None), blocksparse_tensors: Optional[BlockSparseTensors] = None, ): warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group) @@ -1090,6 +1219,7 @@ def mma( m_block_max, ) + # With subtile_factor>1, is_full_block is dynamic; always apply mask_mod. 
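+            # mask_mod is evaluated inside mask.apply_mask together with the
+            # seqlen / causal / local masks, on both this block-sparse path and
+            # the dense path below; note it runs after score_mod in mma_one_m_block.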
mask_fn = partial( mask.apply_mask, batch_idx=batch_idx, @@ -1099,6 +1229,9 @@ def mma( mask_seqlen=True, mask_causal=self.is_causal, mask_local=self.is_local, + mask_mod=self.mask_mod, + aux_tensors=aux_tensors, + fastdiv_mods=fastdiv_mods, ) dKV_accumulate = False @@ -1117,6 +1250,14 @@ def mma( consumer_state_dO, mask_fn=mask_fn, dKV_accumulate=dKV_accumulate, + thr_mma_SdP=thr_mma_SdP, + batch_idx=batch_idx, + head_idx=head_idx, + n_block=n_block, + softmax_scale=softmax_scale, + seqlen=seqlen, + aux_tensors=aux_tensors, + fastdiv_mods=fastdiv_mods, ) dKV_accumulate = True else: @@ -1129,6 +1270,9 @@ def mma( mask_seqlen=True, mask_causal=self.is_causal, mask_local=self.is_local, + mask_mod=self.mask_mod, + aux_tensors=aux_tensors, + fastdiv_mods=fastdiv_mods, ) dKV_accumulate = False for m_block in cutlass.range(m_block_min, m_block_max, unroll=1): @@ -1138,6 +1282,14 @@ def mma( consumer_state_dO, mask_fn=mask_fn, dKV_accumulate=dKV_accumulate, + thr_mma_SdP=thr_mma_SdP, + batch_idx=batch_idx, + head_idx=head_idx, + n_block=n_block, + softmax_scale=softmax_scale, + seqlen=seqlen, + aux_tensors=aux_tensors, + fastdiv_mods=fastdiv_mods, ) dKV_accumulate = True @@ -1209,6 +1361,14 @@ def mma_one_m_block( softmax_scale_log2: Float32, mask_fn: Optional[Callable] = None, dKV_accumulate: Boolean = True, + thr_mma_SdP: Optional[cute.core.ThrMma] = None, + batch_idx: Int32 = 0, + head_idx: Int32 = 0, + n_block: Int32 = 0, + softmax_scale: Float32 = 1.0, + seqlen: Optional[SeqlenInfoQK] = None, + aux_tensors: Optional[list] = None, + fastdiv_mods=(None, None), ): consumer_state_dO_cur = ( consumer_state_dO if const_expr(self.Q_stage == self.dO_stage) else consumer_state_Q @@ -1226,6 +1386,24 @@ def mma_one_m_block( ) acc_dP = mma_dov_fn(A_idx=smem_idx_Q, wg_wait=1) + if const_expr(self.score_mod_bwd is not None): + acc_S_pre = cute.make_fragment_like(acc_S) + cute.autovec_copy(acc_S, acc_S_pre) + + if const_expr(self.score_mod is not None): + self.apply_score_mod( + acc_S, + thr_mma_SdP, + batch_idx, + head_idx, + m_block, + n_block, + softmax_scale, + seqlen, + aux_tensors, + fastdiv_mods, + ) + # (3) [Pointwise 1] P = exp(S - LSE) if cutlass.const_expr(mask_fn is not None): mask_fn(acc_S, m_block=m_block) @@ -1256,6 +1434,21 @@ def mma_one_m_block( for c in cutlass.range(cute.size(acc_dP_mn, mode=[1]), unroll_full=True): acc_dP_mn[r, c] = acc_S_mn[r, c] * (acc_dP_mn[r, c] - tLSErdPsum[r]) + if const_expr(self.score_mod_bwd is not None): + self.apply_score_mod_bwd( + acc_dP, + acc_S_pre, + thr_mma_SdP, + batch_idx, + head_idx, + m_block, + n_block, + softmax_scale, + seqlen, + aux_tensors, + fastdiv_mods, + ) + # Convert dS from f32 -> f16 tdKrdS = utils.cvt_f16(utils.make_acc_tensor_frgA_view(acc_dP), self.dtype) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 70043cc984..b3a2bb9f06 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -710,7 +710,6 @@ def _flash_attn_bwd( assert cu_seqlens_q is None and cu_seqlens_k is None, ( "varlen + score_mod not supported in bwd yet" ) - assert compute_capability == 10, "score_mod in bwd only supported on SM100 for now" device = q.device out_torch_dtype = q.dtype @@ -981,6 +980,10 @@ def _flash_attn_bwd( AtomLayoutMdQ, num_threads, V_in_regs=V_in_regs, + score_mod=score_mod, + score_mod_bwd=score_mod_bwd, + mask_mod=mask_mod, + has_aux_tensors=aux_tensors is not None and len(aux_tensors) > 0, subtile_factor=bwd_subtile_factor, ) else: diff --git a/tests/cute/test_score_mod.py 
b/tests/cute/test_score_mod.py index 82c135a8ee..fad47fe372 100644 --- a/tests/cute/test_score_mod.py +++ b/tests/cute/test_score_mod.py @@ -6,6 +6,9 @@ import operator from torch.nn.attention.flex_attention import flex_attention from flash_attn.cute.interface import _flash_attn_fwd, _flash_attn_bwd + +COMPUTE_CAPABILITY = torch.cuda.get_device_capability()[0] + from score_mod_definitions import ( # TensorSSA-based score mods score_mod_identity as score_mod_1, @@ -289,6 +292,7 @@ def _generate_block_kvcache( ], ) @pytest.mark.parametrize("score_mod_pair", TEST_PAIRS) +@pytest.mark.skipif(COMPUTE_CAPABILITY != 10, reason="Paged KV cache only supported on SM100") def test_score_mod_with_paged_kvcache( seqlen_q, seqlen_kv, @@ -445,6 +449,7 @@ def masked_score_mod(score, b, h, q_idx, kv_idx): ], ) @pytest.mark.parametrize("score_mod_pair", TEST_PAIRS_WITH_AUX_TENSORS) +@pytest.mark.skipif(COMPUTE_CAPABILITY != 10, reason="Paged KV cache only supported on SM100") def test_score_mod_with_paged_kvcache_aux_tensors( seqlen_q, seqlen_kv, @@ -738,6 +743,9 @@ def run_flex_reference_bwd(q, k, v, eager_score_mod, grad_out, dtype=None): @pytest.mark.parametrize("score_mod_triple", BWD_TEST_PAIRS) def test_cute_vs_flex_attention_backward(seqlen_q, seqlen_kv, dim, dtype, score_mod_triple): """Test backward pass with score_mod against flex_attention reference.""" + if COMPUTE_CAPABILITY == 9 and dim == 64: + pytest.skip("head_dim=64 not supported on SM90 for backward") + torch.random.manual_seed(42) cute_fwd, cute_bwd, eager_ref = score_mod_triple @@ -809,6 +817,9 @@ def make_aux_tensors_for_bwd(cute_score_mod, eager_factory, seqlen_q, num_heads, def test_cute_vs_flex_attention_backward_with_aux( seqlen_q, seqlen_kv, dim, dtype, score_mod_triple ): + if COMPUTE_CAPABILITY == 9 and dim == 64: + pytest.skip("head_dim=64 not supported on SM90 for backward") + torch.random.manual_seed(42) cute_fwd, cute_bwd, eager_factory = score_mod_triple @@ -862,14 +873,16 @@ def test_cute_vs_flex_attention_backward_with_aux( @pytest.mark.parametrize("seqlen_q,seqlen_kv", [(128, 128), (128, 256)]) -@pytest.mark.parametrize("dim", [64]) +@pytest.mark.parametrize("dim", [64, 128]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(4, 2)]) @pytest.mark.parametrize("score_mod_triple", BWD_TEST_PAIRS_PACK_GQA) def test_cute_vs_flex_attention_backward_pack_gqa( seqlen_q, seqlen_kv, dim, dtype, qhead_per_kvhead, num_kv_heads, score_mod_triple ): - pytest.skip("pack_gqa backward not yet implemented") + if COMPUTE_CAPABILITY == 9: + pytest.xfail("pack_gqa backward not yet implemented on SM90") + torch.random.manual_seed(42) cute_fwd, cute_bwd, eager_ref = score_mod_triple From 62cc64e1c4923a6cff55d49ef1dfe0dbee701fa2 Mon Sep 17 00:00:00 2001 From: drisspg Date: Wed, 31 Dec 2025 19:30:23 +0000 Subject: [PATCH 2/2] score-mod backward SM100 cleanup stack-info: PR: https://github.com/Dao-AILab/flash-attention/pull/2130, branch: drisspg/stack/4 --- flash_attn/cute/flash_bwd_sm100.py | 33 ++++++++++++++++++------------ 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index e7019382b7..8f11c6e8f1 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -1844,15 +1844,25 @@ def apply_score_mod_bwd( self, grad_tensor, score_tensor, - index_tensor, + thr_copy_t2r, + thr_mma_S, batch_idx, head_idx, + m_block, + n_block, + stage, softmax_scale, 
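+        # m_block / n_block / stage replace the precomputed index_tensor argument:
+        # the identity tensor is rebuilt here (cS_bwd below) instead of at the
+        # call site in compute_loop.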
seqlen_info, aux_tensors=None, fastdiv_mods=(None, None), ): """Apply backward score modification (joint graph) for SM100.""" + cS_bwd = cute.make_identity_tensor((self.tile_n, self.tile_m)) + cS_bwd = cute.domain_offset((n_block * self.tile_n, m_block * self.tile_m), cS_bwd) + tScS_bwd = thr_mma_S.partition_C(cS_bwd) + tScS_idx_bwd = thr_copy_t2r.partition_D(tScS_bwd) + index_tensor = tScS_idx_bwd[None, stage, 0, 0] + apply_score_mod_bwd_inner( grad_tensor, score_tensor, @@ -1871,6 +1881,10 @@ def apply_score_mod_bwd( transpose_indices=True, ) + for i in cutlass.range(cute.size(grad_tensor), unroll_full=True): + kv_idx = index_tensor[i][0] + grad_tensor[i] = 0.0 if kv_idx >= seqlen_info.seqlen_k else grad_tensor[i] + @cute.jit def compute_loop( self, @@ -2213,28 +2227,21 @@ def compute_loop( if const_expr(self.score_mod_bwd is not None): tSrS_pre_cur = tSrS_pre[None, stage, 0, 0] - cS_bwd = cute.make_identity_tensor((self.tile_n, self.tile_m)) - cS_bwd = cute.domain_offset( - (n_block * self.tile_n, m_block * self.tile_m), cS_bwd - ) - tScS_bwd = thr_mma_S.partition_C(cS_bwd) - tScS_idx_bwd = thr_copy_t2r.partition_D(tScS_bwd) - tScS_idx_cur = tScS_idx_bwd[None, stage, 0, 0] self.apply_score_mod_bwd( tdPrdP_cur, tSrS_pre_cur, - tScS_idx_cur, + thr_copy_t2r, + thr_mma_S, batch_idx, head_idx, + m_block, + n_block, + stage, softmax_scale, seqlen, aux_tensors, fastdiv_mods, ) - # Zero out OOB positions (kv_idx >= seqlen_k) after score_mod_bwd - for i in cutlass.range(cute.size(tdPrdP_cur), unroll_full=True): - kv_idx = tScS_idx_cur[i][0] - tdPrdP_cur[i] = 0.0 if kv_idx >= seqlen.seqlen_k else tdPrdP_cur[i] tdPrdS_cvt = cute.make_fragment_like(tdPrdP_cur, self.ds_dtype) utils.cvt_f16(tdPrdP_cur, tdPrdS_cvt)
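
For reference, the score_mod / score_mod_bwd ("joint graph") contract that both the SM90 and SM100 kernels implement can be reproduced in plain PyTorch. The sketch below is illustrative only: rel_bias / rel_bias_bwd are hypothetical mods, and the per-element backward signature (grad, score, b, h, q_idx, kv_idx) is assumed from how apply_score_mod_bwd_inner is fed the pre-modification scores (acc_S_pre on SM90, tSrS_pre on SM100).

import torch

def rel_bias(score, b, h, q_idx, kv_idx):
    # Forward score_mod: S' = S + (q_idx - kv_idx).
    return score + (q_idx - kv_idx)

def rel_bias_bwd(grad, score, b, h, q_idx, kv_idx):
    # Backward score_mod: given dL/dS' and the pre-mod score S, return
    # dL/dS = dL/dS' * d(S')/dS. For an additive bias, d(S')/dS == 1.
    return grad

S = torch.randn(4, 4, requires_grad=True)
q_idx = torch.arange(4).view(4, 1)
kv_idx = torch.arange(4).view(1, 4)

S_mod = rel_bias(S, 0, 0, q_idx, kv_idx)
dS_mod = torch.randn(4, 4)  # stands in for dS = P * (dP - dPsum) in the kernel

# Joint-graph path (what the kernel computes in registers):
dS = rel_bias_bwd(dS_mod, S.detach(), 0, 0, q_idx, kv_idx)

# Autograd reference:
S_mod.backward(dS_mod)
assert torch.allclose(dS, S.grad)

This is also why mma_one_m_block snapshots acc_S before apply_score_mod runs: the backward mod must be handed the same pre-modification scores that the forward mod consumed.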