diff --git a/flash_attn/cute/block_sparse_utils.py b/flash_attn/cute/block_sparse_utils.py index 706e3d6ad2..ee6867dfc5 100644 --- a/flash_attn/cute/block_sparse_utils.py +++ b/flash_attn/cute/block_sparse_utils.py @@ -287,6 +287,7 @@ def consume_block_sparse_loads( intra_wg_overlap: cutlass.Constexpr, warp_scheduler_barrier_sync: Callable, warp_scheduler_barrier_arrive: Callable, + seqlen=None, ): """Consume the mask and full block lists for a single tile on the consumer side. @@ -379,6 +380,7 @@ def consume_block_sparse_loads( mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1] kv_consumer_state = process_first_half_block( n_block=mask_n_block, + seqlen=seqlen, kv_consumer_state=kv_consumer_state, mask_fn=partial( mask_fn, @@ -394,6 +396,7 @@ def consume_block_sparse_loads( kv_consumer_state = mma_one_n_block( kv_consumer_state, n_block=mask_n_block, + seqlen=seqlen, mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), mask_fn=partial(mask_fn, mask_mod=mask_mod, mask_seqlen=False), ) @@ -404,6 +407,7 @@ def consume_block_sparse_loads( if curr_mask_block_cnt == 0: kv_consumer_state = process_first_half_block( n_block=full_n_block, + seqlen=seqlen, kv_consumer_state=kv_consumer_state, mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True), score_mod_fn=score_mod_fn, @@ -413,6 +417,7 @@ def consume_block_sparse_loads( kv_consumer_state = mma_one_n_block( kv_consumer_state, n_block=full_n_block, + seqlen=seqlen, mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True), ) @@ -422,6 +427,7 @@ def consume_block_sparse_loads( kv_consumer_state = mma_one_n_block( kv_consumer_state, n_block=full_n_block, + seqlen=seqlen, mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=False), ) diff --git a/flash_attn/cute/flash_bwd_sm90.py b/flash_attn/cute/flash_bwd_sm90.py index deb40f7939..a7eb1880f8 100644 --- a/flash_attn/cute/flash_bwd_sm90.py +++ b/flash_attn/cute/flash_bwd_sm90.py @@ -22,6 +22,12 @@ from flash_attn.cute import pipeline from flash_attn.cute.tile_scheduler import TileSchedulerArguments, SingleTileScheduler, ParamsBase from flash_attn.cute.named_barrier import NamedBarrierFwd, NamedBarrierBwd +from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.block_sparse_utils import ( + get_total_q_block_count_bwd, + get_block_sparse_iteration_info_bwd, + get_m_block_from_iter_bwd, +) def mma_partition_fragment_AB( @@ -62,6 +68,7 @@ def __init__( AtomLayoutMdQ: int = 1, num_threads: int = 384, V_in_regs: bool = False, + subtile_factor: cutlass.Constexpr[int] = 1, ): self.dtype = dtype # padding head_dim to a multiple of 16 as k_block_size @@ -106,6 +113,7 @@ def __init__( # TODO: impl these for hdim 64 self.shuffle_LSE = self.SdP_swapAB and self.tile_hdim <= 64 self.shuffle_dPsum = self.SdP_swapAB and self.tile_hdim <= 64 + self.subtile_factor = subtile_factor @staticmethod def can_implement( @@ -298,6 +306,8 @@ def __call__( mdQ_semaphore: Optional[cute.Tensor] = None, mdK_semaphore: Optional[cute.Tensor] = None, mdV_semaphore: Optional[cute.Tensor] = None, + aux_tensors: Optional[list] = None, + blocksparse_tensors: Optional[BlockSparseTensors] = None, ): assert mdQ_semaphore is None and mdK_semaphore is None and mdV_semaphore is None, ( "determinism not supported yet for Sm90" @@ -424,6 +434,8 @@ def __call__( LOG2_E = math.log2(math.e) softmax_scale_log2 = softmax_scale * LOG2_E + self.use_block_sparsity = 
cutlass.const_expr(blocksparse_tensors is not None) + self.kernel( tma_tensor_Q, tma_tensor_K, @@ -456,6 +468,7 @@ def __call__( tile_sched_params, TileScheduler, SharedStorage, + blocksparse_tensors, ).launch( grid=grid_dim, block=[self.num_threads, 1, 1], @@ -498,6 +511,7 @@ def kernel( tile_sched_params: ParamsBase, TileScheduler: cutlass.Constexpr[Callable], SharedStorage: cutlass.Constexpr[Callable], + blocksparse_tensors: Optional[BlockSparseTensors] = None, ): warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) @@ -579,6 +593,7 @@ def kernel( self.tile_n, window_size_left=None, window_size_right=None, + swap_AB=self.SdP_swapAB, ) TileSchedulerCls = partial(TileScheduler.create, tile_sched_params) @@ -607,6 +622,7 @@ def kernel( block_info, SeqlenInfoCls, TileSchedulerCls, + blocksparse_tensors, ) if warp_idx == 1: for warp_group_idx in cutlass.range(self.num_mma_warp_groups): @@ -648,6 +664,7 @@ def kernel( SeqlenInfoCls, AttentionMaskCls, TileSchedulerCls, + blocksparse_tensors, ) @cute.jit @@ -674,6 +691,7 @@ def load( block_info: BlockInfo, SeqlenInfoCls: Callable, TileSchedulerCls: Callable, + blocksparse_tensors: Optional[BlockSparseTensors] = None, ): warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 @@ -723,48 +741,132 @@ def load( load_dPsum = copy_utils.tma_producer_copy_fn(load_dPsum, pipeline_dO) m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block) - # First iteration: load K together w Q & LSE, then V together w dO & dPsum - m_block = m_block_min - pipeline_Q.producer_acquire( - producer_state_Q, extra_tx_count=self.tma_copy_bytes["K"] - ) - load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q)) - load_Q(m_block, producer_state=producer_state_Q) - # cp.async.bulk is using ptx, so we need to elect one thread to do it - with cute.arch.elect_one(): - load_LSE(m_block, producer_state=producer_state_Q) - producer_state_dO_cur = ( - producer_state_dO - if const_expr(self.Q_stage != self.dO_stage) - else producer_state_Q - ) - pipeline_dO.producer_acquire( - producer_state_dO_cur, extra_tx_count=self.tma_copy_bytes["V"] - ) - load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_cur)) - load_dO(m_block, producer_state=producer_state_dO_cur) - with cute.arch.elect_one(): - load_dPsum(m_block, producer_state=producer_state_dO_cur) - producer_state_Q.advance() - producer_state_dO.advance() - # Subsequent iterations: load Q & LSE, then dO & dPsum - for m_block in cutlass.range(m_block_min + 1, m_block_max, unroll=1): - pipeline_Q.producer_acquire(producer_state_Q) - load_Q(m_block, producer_state=producer_state_Q) - # cp.async.bulk is using ptx, so we need to elect one thread to do it - with cute.arch.elect_one(): - load_LSE(m_block, producer_state=producer_state_Q) - producer_state_dO_cur = ( - producer_state_dO - if const_expr(self.Q_stage != self.dO_stage) - else producer_state_Q + + if const_expr(self.use_block_sparsity): + total_m_block_cnt = get_total_q_block_count_bwd( + blocksparse_tensors, + batch_idx, + head_idx, + n_block, + subtile_factor=self.subtile_factor, + m_block_max=m_block_max, ) - pipeline_dO.producer_acquire(producer_state_dO_cur) - load_dO(m_block, producer_state=producer_state_dO_cur) - with cute.arch.elect_one(): - load_dPsum(m_block, producer_state=producer_state_dO_cur) - producer_state_Q.advance() - producer_state_dO.advance() + process_tile = total_m_block_cnt > Int32(0) + else: + total_m_block_cnt = m_block_max - m_block_min + process_tile = const_expr(not self.is_local) 
or m_block_min < m_block_max + + if process_tile: + if const_expr(self.use_block_sparsity): + ( + curr_q_cnt, + curr_q_idx, + curr_full_cnt, + curr_full_idx, + loop_count, + ) = get_block_sparse_iteration_info_bwd( + blocksparse_tensors, + batch_idx, + head_idx, + n_block, + self.subtile_factor, + m_block_max, + ) + + for iter_idx in cutlass.range(loop_count, unroll=1): + m_block, _ = get_m_block_from_iter_bwd( + iter_idx, + curr_q_cnt, + curr_q_idx, + curr_full_cnt, + curr_full_idx, + self.subtile_factor, + ) + + if iter_idx == 0: + pipeline_Q.producer_acquire( + producer_state_Q, extra_tx_count=self.tma_copy_bytes["K"] + ) + load_K( + tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q) + ) + load_Q(m_block, producer_state=producer_state_Q) + with cute.arch.elect_one(): + load_LSE(m_block, producer_state=producer_state_Q) + producer_state_dO_cur = ( + producer_state_dO + if const_expr(self.Q_stage != self.dO_stage) + else producer_state_Q + ) + pipeline_dO.producer_acquire( + producer_state_dO_cur, extra_tx_count=self.tma_copy_bytes["V"] + ) + load_V( + tma_bar_ptr=pipeline_dO.producer_get_barrier( + producer_state_dO_cur + ) + ) + load_dO(m_block, producer_state=producer_state_dO_cur) + with cute.arch.elect_one(): + load_dPsum(m_block, producer_state=producer_state_dO_cur) + producer_state_Q.advance() + producer_state_dO.advance() + else: + pipeline_Q.producer_acquire(producer_state_Q) + load_Q(m_block, producer_state=producer_state_Q) + with cute.arch.elect_one(): + load_LSE(m_block, producer_state=producer_state_Q) + producer_state_dO_cur = ( + producer_state_dO + if const_expr(self.Q_stage != self.dO_stage) + else producer_state_Q + ) + pipeline_dO.producer_acquire(producer_state_dO_cur) + load_dO(m_block, producer_state=producer_state_dO_cur) + with cute.arch.elect_one(): + load_dPsum(m_block, producer_state=producer_state_dO_cur) + producer_state_Q.advance() + producer_state_dO.advance() + else: + first_m_block = m_block_min + pipeline_Q.producer_acquire( + producer_state_Q, extra_tx_count=self.tma_copy_bytes["K"] + ) + load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q)) + load_Q(first_m_block, producer_state=producer_state_Q) + with cute.arch.elect_one(): + load_LSE(first_m_block, producer_state=producer_state_Q) + producer_state_dO_cur = ( + producer_state_dO + if const_expr(self.Q_stage != self.dO_stage) + else producer_state_Q + ) + pipeline_dO.producer_acquire( + producer_state_dO_cur, extra_tx_count=self.tma_copy_bytes["V"] + ) + load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_cur)) + load_dO(first_m_block, producer_state=producer_state_dO_cur) + with cute.arch.elect_one(): + load_dPsum(first_m_block, producer_state=producer_state_dO_cur) + producer_state_Q.advance() + producer_state_dO.advance() + + for m_block in cutlass.range(m_block_min + 1, m_block_max, unroll=1): + pipeline_Q.producer_acquire(producer_state_Q) + load_Q(m_block, producer_state=producer_state_Q) + with cute.arch.elect_one(): + load_LSE(m_block, producer_state=producer_state_Q) + producer_state_dO_cur = ( + producer_state_dO + if const_expr(self.Q_stage != self.dO_stage) + else producer_state_Q + ) + pipeline_dO.producer_acquire(producer_state_dO_cur) + load_dO(m_block, producer_state=producer_state_dO_cur) + with cute.arch.elect_one(): + load_dPsum(m_block, producer_state=producer_state_dO_cur) + producer_state_Q.advance() + producer_state_dO.advance() tile_scheduler.prefetch_next_work() tile_scheduler.advance_to_next_work() @@ -801,6 +903,7 @@ def mma( 
SeqlenInfoCls: Callable, AttentionMaskCls: Callable, TileSchedulerCls: Callable, + blocksparse_tensors: Optional[BlockSparseTensors] = None, ): warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group) warp_group_thread_layout = cute.make_layout( diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index 23fee1e185..3913af7b03 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -2168,6 +2168,7 @@ def mma( self.intra_wg_overlap, self.warp_scheduler_barrier_sync, self.warp_scheduler_barrier_arrive, + seqlen, ) # Handle empty case (when no blocks to process) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index f5c64f597a..27b88e686f 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -870,6 +870,7 @@ def _flash_attn_bwd( AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs, + use_block_sparsity, ) else: # Hash callables for compile key @@ -964,6 +965,7 @@ def _flash_attn_bwd( AtomLayoutMdQ, num_threads, V_in_regs=V_in_regs, + subtile_factor=1, ) else: fa_bwd_obj = FlashAttentionBackwardSm100( @@ -987,10 +989,11 @@ def _flash_attn_bwd( # Block sparse tensors for backward use Q-direction indexing (transposed from forward). # sparse_block_size_q = 2*tile_m matches forward's q_stage=2 pipelining. sparse_tensors_compile = None - if block_sparse_tensors is not None and compute_capability == 10: + if block_sparse_tensors is not None: + bwd_subtile_factor = subtile_factor if compute_capability == 10 else 1 expected_count_shape, expected_index_shape = get_block_sparse_expected_shapes_bwd( batch_size, num_head, seqlen_q, seqlen_k, - m_block_size, n_block_size, subtile_factor, + m_block_size, n_block_size, bwd_subtile_factor, ) compile_time_normalized = normalize_block_sparse_tensors( block_sparse_tensors, @@ -1028,10 +1031,11 @@ def _flash_attn_bwd( options="--enable-tvm-ffi", ) normalized_block_sparse_tensors = None - if block_sparse_tensors is not None and compute_capability == 10: + if block_sparse_tensors is not None: + bwd_subtile_factor = subtile_factor if compute_capability == 10 else 1 expected_count_shape, expected_index_shape = get_block_sparse_expected_shapes_bwd( batch_size, num_head, seqlen_q, seqlen_k, - m_block_size, n_block_size, subtile_factor, + m_block_size, n_block_size, bwd_subtile_factor, ) normalized_block_sparse_tensors = normalize_block_sparse_tensors( block_sparse_tensors, diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index 0a772fa425..78bb918ced 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -339,7 +339,7 @@ def apply_mask_sm100( mask_r2p(acc_S, seqlenk_col_limit, arch=100, rank1=True) elif const_expr(not mask_causal and not mask_local and mask_mod is not None): - # Block sparse w/ mask_mod + # Block sparse case w/ mask_mod has_fastdiv = const_expr( fastdiv_mods is not None and fastdiv_mods[0] is not None diff --git a/tests/cute/test_mask_mod.py b/tests/cute/test_mask_mod.py index f40304e6c5..9b066fbfd7 100644 --- a/tests/cute/test_mask_mod.py +++ b/tests/cute/test_mask_mod.py @@ -448,9 +448,6 @@ def test_q_boundary_masking_block_sparse_bwd(seqlen_q, seqlen_k, mask_name): - Block-sparse with mask_mod: exercises is_full_block=True path - Backward pass: where the bug manifested """ - if COMPUTE_CAPABILITY != 10: - pytest.skip("SM100-only backward test") - _run_mask_test( seqlen_q=seqlen_q, seqlen_k=seqlen_k, @@ -469,6 +466,7 @@ def test_q_boundary_masking_block_sparse_bwd(seqlen_q, seqlen_k, mask_name): ) 
+@pytest.mark.skipif(COMPUTE_CAPABILITY != 10, reason="Test uses SM100 block mask conventions (2*tile_m)") def test_single_doc_bwd_minimal(): """Minimal test to isolate single-document backward pass bug. @@ -479,9 +477,6 @@ def test_single_doc_bwd_minimal(): Run with: pytest tests/cute/test_mask_mod.py::test_single_doc_bwd_minimal -v -s """ - if COMPUTE_CAPABILITY != 10: - pytest.skip("SM100-only test") - import random random.seed(42) torch.manual_seed(42)
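
Note on the forward-side change: `seqlen` is now threaded through `consume_block_sparse_loads` into the consumer helpers (`process_first_half_block`, `mma_one_n_block`) so that full blocks sitting on the KV boundary can still be sequence-length masked (`mask_seqlen=True`). As a point of reference only, the column masking this corresponds to is sketched below in plain PyTorch; the tensor names and the dense-score view are illustrative, since the kernel applies the same limit on register fragments rather than a materialized score matrix.

```python
import torch

# Reference (PyTorch) of the seqlen_k column limit that mask_fn(..., mask_seqlen=True)
# enforces for one (m_block, n_block) tile. Illustrative only; not the kernel code path.
def mask_seqlen_k_columns(scores: torch.Tensor, n_block: int, tile_n: int, seqlen_k: int) -> torch.Tensor:
    """scores: (tile_m, tile_n) attention scores for a single tile."""
    # Global KV column index of each score column in this tile.
    col = n_block * tile_n + torch.arange(tile_n, device=scores.device)
    # Columns past the true sequence length contribute -inf before the softmax.
    return scores.masked_fill(col[None, :] >= seqlen_k, float("-inf"))
```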
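Note on the SM90 backward producer: when block-sparse tensors are present, the load warp no longer walks the dense `[m_block_min, m_block_max)` range. It first checks `get_total_q_block_count_bwd` to decide whether the tile has any work, then walks the per-`(batch, head, n_block)` lists of partially-masked and fully-unmasked Q blocks from `get_block_sparse_iteration_info_bwd`, mapping each iteration to an `m_block` via `get_m_block_from_iter_bwd`. A host-side sketch of that iteration order follows, under the assumptions that each sparse Q block expands to `subtile_factor` kernel m-tiles (1 for the SM90 kernel in this patch) and is clamped to `m_block_max`; the function name and the list-based layout are illustrative stand-ins for the real cute helpers.

```python
from typing import Iterator

# Host-side reference for the block-sparse backward iteration order; a sketch under
# the stated assumptions, not the device-side implementation in block_sparse_utils.py.
def reference_m_blocks_for_n_block(
    mask_q_cnt: int,          # number of partially-masked sparse Q blocks for this (b, h, n_block)
    mask_q_idx: list[int],    # their sparse-block indices
    full_q_cnt: int,          # number of fully-unmasked sparse Q blocks
    full_q_idx: list[int],
    subtile_factor: int,      # kernel m-tiles per sparse Q block (1 on SM90 here)
    m_block_max: int,         # exclusive bound derived from seqlen_q
) -> Iterator[tuple[int, bool]]:
    """Yield (m_block, is_full_block) in the order the producer loads Q/LSE and dO/dPsum."""
    for sparse_idx in mask_q_idx[:mask_q_cnt]:
        for sub in range(subtile_factor):
            m_block = sparse_idx * subtile_factor + sub
            if m_block < m_block_max:
                yield m_block, False   # partially masked: consumer applies mask_mod
    for sparse_idx in full_q_idx[:full_q_cnt]:
        for sub in range(subtile_factor):
            m_block = sparse_idx * subtile_factor + sub
            if m_block < m_block_max:
                yield m_block, True    # full block: only seqlen masking needed
```

The first iteration additionally folds the K and V loads into the same producer barriers (via `extra_tx_count`), matching the dense path's first-iteration behavior.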
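Note on the interface change: backward block-sparse tensors are now normalized for both SM90 and SM100, with the subtile factor chosen per architecture; SM100 keeps the coarser Q-direction block size (`sparse_block_size_q = 2*tile_m`, per the in-code comment), while the SM90 kernel added here is constructed with `subtile_factor=1`. A minimal sketch of that selection and the resulting Q-direction geometry, assuming the sparse Q-block count is a plain ceiling division; the authoritative shape check remains `get_block_sparse_expected_shapes_bwd` / `normalize_block_sparse_tensors`, and this helper is hypothetical.

```python
import math

# Illustrative only: why _flash_attn_bwd picks bwd_subtile_factor per architecture.
def bwd_sparse_q_geometry(seqlen_q: int, m_block_size: int,
                          compute_capability: int, subtile_factor: int):
    # SM100 (compute capability 10) uses the forward's coarser convention;
    # the new SM90 backward kernel iterates kernel-sized m-blocks directly.
    bwd_subtile_factor = subtile_factor if compute_capability == 10 else 1
    sparse_block_size_q = bwd_subtile_factor * m_block_size
    num_sparse_q_blocks = math.ceil(seqlen_q / sparse_block_size_q)  # assumed formula
    return bwd_subtile_factor, sparse_block_size_q, num_sparse_q_blocks
```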