diff --git a/flash_attn/cute/block_sparse_utils.py b/flash_attn/cute/block_sparse_utils.py index b70a6beca3..c4aad2cd58 100644 --- a/flash_attn/cute/block_sparse_utils.py +++ b/flash_attn/cute/block_sparse_utils.py @@ -15,6 +15,8 @@ # Import data structures from block_sparsity from flash_attn.cute.block_sparsity import BlockSparseTensors from flash_attn.cute import utils +from flash_attn.cute import copy_utils +from flash_attn.cute.named_barrier import NamedBarrierBwd @cute.jit @@ -380,8 +382,8 @@ def consume_block_sparse_loads( mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1] kv_consumer_state = process_first_half_block( n_block=mask_n_block, - kv_consumer_state=kv_consumer_state, seqlen=seqlen, + kv_consumer_state=kv_consumer_state, mask_fn=partial( mask_fn, mask_mod=mask_mod, @@ -396,6 +398,7 @@ def consume_block_sparse_loads( kv_consumer_state = mma_one_n_block( kv_consumer_state, n_block=mask_n_block, + seqlen=seqlen, mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), mask_fn=partial(mask_fn, mask_mod=mask_mod, mask_seqlen=False), ) @@ -406,8 +409,8 @@ def consume_block_sparse_loads( if curr_mask_block_cnt == 0: kv_consumer_state = process_first_half_block( n_block=full_n_block, - kv_consumer_state=kv_consumer_state, seqlen=seqlen, + kv_consumer_state=kv_consumer_state, mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True), score_mod_fn=score_mod_fn, is_first_block=True, @@ -416,6 +419,7 @@ def consume_block_sparse_loads( kv_consumer_state = mma_one_n_block( kv_consumer_state, n_block=full_n_block, + seqlen=seqlen, mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True), ) @@ -425,6 +429,7 @@ def consume_block_sparse_loads( kv_consumer_state = mma_one_n_block( kv_consumer_state, n_block=full_n_block, + seqlen=seqlen, mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=False), ) @@ -1069,3 +1074,353 @@ def get_m_block_from_iter_bwd( sparse_m_block = curr_q_idx[sparse_iter_idx] return sparse_m_block * subtile_factor + subtile_offset, is_full_block + + +@cute.jit +def _load_q_do_block_sm90( + m_block, + producer_state_Q, + producer_state_dO, + pipeline_Q, + pipeline_dO, + load_K, + load_V, + load_Q, + load_dO, + load_LSE, + load_dPsum, + tma_copy_bytes_K, + tma_copy_bytes_V, + Q_stage_eq_dO_stage: cutlass.Constexpr, + load_kv: bool, +): + """Load one Q/dO block, optionally loading K/V on first iteration.""" + if load_kv: + pipeline_Q.producer_acquire(producer_state_Q, extra_tx_count=tma_copy_bytes_K) + load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q)) + else: + pipeline_Q.producer_acquire(producer_state_Q) + load_Q(m_block, producer_state=producer_state_Q) + with cute.arch.elect_one(): + load_LSE(m_block, producer_state=producer_state_Q) + + producer_state_dO_cur = ( + producer_state_dO if const_expr(not Q_stage_eq_dO_stage) else producer_state_Q + ) + if load_kv: + pipeline_dO.producer_acquire(producer_state_dO_cur, extra_tx_count=tma_copy_bytes_V) + load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_cur)) + else: + pipeline_dO.producer_acquire(producer_state_dO_cur) + load_dO(m_block, producer_state=producer_state_dO_cur) + with cute.arch.elect_one(): + load_dPsum(m_block, producer_state=producer_state_dO_cur) + + producer_state_Q.advance() + producer_state_dO.advance() + return producer_state_Q, producer_state_dO + + +@cute.jit +def produce_block_sparse_q_loads_bwd_sm90( + 
blocksparse_tensors: BlockSparseTensors, + batch_idx, + head_idx, + n_block, + producer_state_Q, + producer_state_dO, + pipeline_Q, + pipeline_dO, + load_K, + load_V, + load_Q, + load_dO, + load_LSE, + load_dPsum, + tma_copy_bytes_K, + tma_copy_bytes_V, + Q_stage_eq_dO_stage: cutlass.Constexpr, + subtile_factor: cutlass.Constexpr, + m_block_max: int, +): + """SM90 backward block sparse loading with separate partial/full loops. + + K/V are loaded with the first valid block. Iterates partial blocks first, + then full blocks, matching consumer order. + + Returns updated (producer_state_Q, producer_state_dO). + """ + q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors + curr_q_cnt = q_cnt[batch_idx, head_idx, n_block] + curr_q_idx = q_idx[batch_idx, head_idx, n_block, None] + + if const_expr(full_cnt is not None): + curr_full_cnt = full_cnt[batch_idx, head_idx, n_block] + curr_full_idx = full_idx[batch_idx, head_idx, n_block, None] + else: + curr_full_cnt = Int32(0) + curr_full_idx = None + + kv_loaded = False + + for iter_idx in cutlass.range(curr_q_cnt * subtile_factor, unroll=1): + sparse_idx = iter_idx // subtile_factor + subtile_offset = iter_idx % subtile_factor + m_block = curr_q_idx[sparse_idx] * subtile_factor + subtile_offset + + if m_block < m_block_max: + producer_state_Q, producer_state_dO = _load_q_do_block_sm90( + m_block, + producer_state_Q, + producer_state_dO, + pipeline_Q, + pipeline_dO, + load_K, + load_V, + load_Q, + load_dO, + load_LSE, + load_dPsum, + tma_copy_bytes_K, + tma_copy_bytes_V, + Q_stage_eq_dO_stage, + load_kv=not kv_loaded, + ) + kv_loaded = True + + if const_expr(full_cnt is not None): + for iter_idx in cutlass.range(curr_full_cnt * subtile_factor, unroll=1): + sparse_idx = iter_idx // subtile_factor + subtile_offset = iter_idx % subtile_factor + m_block = curr_full_idx[sparse_idx] * subtile_factor + subtile_offset + + if m_block < m_block_max: + producer_state_Q, producer_state_dO = _load_q_do_block_sm90( + m_block, + producer_state_Q, + producer_state_dO, + pipeline_Q, + pipeline_dO, + load_K, + load_V, + load_Q, + load_dO, + load_LSE, + load_dPsum, + tma_copy_bytes_K, + tma_copy_bytes_V, + Q_stage_eq_dO_stage, + load_kv=not kv_loaded, + ) + kv_loaded = True + + return producer_state_Q, producer_state_dO + + +@cute.jit +def consume_block_sparse_mma_bwd_sm90( + blocksparse_tensors: BlockSparseTensors, + batch_idx, + head_idx, + n_block, + consumer_state_Q, + consumer_state_dO, + mma_one_m_block_fn, + mask, + mask_mod, + is_causal: cutlass.Constexpr, + is_local: cutlass.Constexpr, + thr_mma_SdP, + softmax_scale, + seqlen, + subtile_factor: cutlass.Constexpr, + m_block_max: int, + aux_tensors=None, + fastdiv_mods=(None, None), +): + """SM90 backward block sparse MMA consumption with separate partial/full loops. + + Partial blocks are processed first (with mask_mod applied), then full blocks + (without mask_mod). This ensures mask_mod is only applied where needed. + + Returns updated (consumer_state_Q, consumer_state_dO). 
+ """ + q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors + curr_q_cnt = q_cnt[batch_idx, head_idx, n_block] + curr_q_idx = q_idx[batch_idx, head_idx, n_block, None] + + if const_expr(full_cnt is not None): + curr_full_cnt = full_cnt[batch_idx, head_idx, n_block] + curr_full_idx = full_idx[batch_idx, head_idx, n_block, None] + else: + curr_full_cnt = Int32(0) + curr_full_idx = None + + dKV_accumulate = False + + mask_fn_partial = partial( + mask.apply_mask, + batch_idx=batch_idx, + head_idx=head_idx, + n_block=n_block, + thr_mma=thr_mma_SdP, + mask_seqlen=True, + mask_causal=is_causal, + mask_local=is_local, + mask_mod=mask_mod, + aux_tensors=aux_tensors, + fastdiv_mods=fastdiv_mods, + ) + + mask_fn_full = partial( + mask.apply_mask, + batch_idx=batch_idx, + head_idx=head_idx, + n_block=n_block, + thr_mma=thr_mma_SdP, + mask_seqlen=True, + mask_causal=is_causal, + mask_local=is_local, + aux_tensors=aux_tensors, + fastdiv_mods=fastdiv_mods, + ) + + for iter_idx in cutlass.range(curr_q_cnt * subtile_factor, unroll=1): + sparse_idx = iter_idx // subtile_factor + subtile_offset = iter_idx % subtile_factor + m_block = curr_q_idx[sparse_idx] * subtile_factor + subtile_offset + + if m_block < m_block_max: + consumer_state_Q, consumer_state_dO = mma_one_m_block_fn( + m_block, + consumer_state_Q, + consumer_state_dO, + mask_fn=mask_fn_partial, + dKV_accumulate=dKV_accumulate, + thr_mma_SdP=thr_mma_SdP, + batch_idx=batch_idx, + head_idx=head_idx, + n_block=n_block, + softmax_scale=softmax_scale, + seqlen=seqlen, + aux_tensors=aux_tensors, + fastdiv_mods=fastdiv_mods, + ) + dKV_accumulate = True + + if const_expr(full_cnt is not None): + for iter_idx in cutlass.range(curr_full_cnt * subtile_factor, unroll=1): + sparse_idx = iter_idx // subtile_factor + subtile_offset = iter_idx % subtile_factor + m_block = curr_full_idx[sparse_idx] * subtile_factor + subtile_offset + + if m_block < m_block_max: + consumer_state_Q, consumer_state_dO = mma_one_m_block_fn( + m_block, + consumer_state_Q, + consumer_state_dO, + mask_fn=mask_fn_full, + dKV_accumulate=dKV_accumulate, + thr_mma_SdP=thr_mma_SdP, + batch_idx=batch_idx, + head_idx=head_idx, + n_block=n_block, + softmax_scale=softmax_scale, + seqlen=seqlen, + aux_tensors=aux_tensors, + fastdiv_mods=fastdiv_mods, + ) + dKV_accumulate = True + + return consumer_state_Q, consumer_state_dO + + +@cute.jit +def _store_one_dQaccum_sm90( + m_block, + sdQaccum: cute.Tensor, + gdQaccum: cute.Tensor, + num_mma_warp_groups: cutlass.Constexpr, + num_threads_per_warp_group: cutlass.Constexpr, + tma_copy_bytes_dQ, +): + """Store dQaccum for a single m_block.""" + for warp_group_idx in cutlass.range_constexpr(num_mma_warp_groups): + cute.arch.barrier( + barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx, + number_of_threads=num_threads_per_warp_group + cute.arch.WARP_SIZE, + ) + with cute.arch.elect_one(): + copy_utils.cpasync_reduce_bulk_add_f32( + sdQaccum[None, warp_group_idx].iterator, + gdQaccum[None, warp_group_idx, m_block].iterator, + tma_copy_bytes_dQ, + ) + cute.arch.cp_async_bulk_commit_group() + for warp_group_idx in cutlass.range_constexpr(num_mma_warp_groups): + cute.arch.cp_async_bulk_wait_group(num_mma_warp_groups - 1 - warp_group_idx, read=True) + cute.arch.barrier_arrive( + barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx, + number_of_threads=num_threads_per_warp_group + cute.arch.WARP_SIZE, + ) + + +@cute.jit +def dQaccum_store_block_sparse_bwd_sm90( + blocksparse_tensors: BlockSparseTensors, + batch_idx, + head_idx, + n_block, + 
sdQaccum: cute.Tensor, + gdQaccum: cute.Tensor, + subtile_factor: cutlass.Constexpr, + m_block_max: int, + num_mma_warp_groups: cutlass.Constexpr, + num_threads_per_warp_group: cutlass.Constexpr, + tma_copy_bytes_dQ, +): + """SM90 backward block sparse dQaccum store with separate partial/full loops. + + Iterates partial blocks first, then full blocks, matching producer/consumer order. + """ + q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors + curr_q_cnt = q_cnt[batch_idx, head_idx, n_block] + curr_q_idx = q_idx[batch_idx, head_idx, n_block, None] + + if const_expr(full_cnt is not None): + curr_full_cnt = full_cnt[batch_idx, head_idx, n_block] + curr_full_idx = full_idx[batch_idx, head_idx, n_block, None] + else: + curr_full_cnt = Int32(0) + curr_full_idx = None + + for iter_idx in cutlass.range(curr_q_cnt * subtile_factor, unroll=1): + sparse_idx = iter_idx // subtile_factor + subtile_offset = iter_idx % subtile_factor + m_block = curr_q_idx[sparse_idx] * subtile_factor + subtile_offset + + if m_block < m_block_max: + _store_one_dQaccum_sm90( + m_block, + sdQaccum, + gdQaccum, + num_mma_warp_groups, + num_threads_per_warp_group, + tma_copy_bytes_dQ, + ) + + if const_expr(full_cnt is not None): + for iter_idx in cutlass.range(curr_full_cnt * subtile_factor, unroll=1): + sparse_idx = iter_idx // subtile_factor + subtile_offset = iter_idx % subtile_factor + m_block = curr_full_idx[sparse_idx] * subtile_factor + subtile_offset + + if m_block < m_block_max: + _store_one_dQaccum_sm90( + m_block, + sdQaccum, + gdQaccum, + num_mma_warp_groups, + num_threads_per_warp_group, + tma_copy_bytes_dQ, + ) diff --git a/flash_attn/cute/flash_bwd_sm90.py b/flash_attn/cute/flash_bwd_sm90.py index fd999150bf..d9b504cee2 100644 --- a/flash_attn/cute/flash_bwd_sm90.py +++ b/flash_attn/cute/flash_bwd_sm90.py @@ -9,6 +9,7 @@ import cutlass.utils.hopper_helpers as sm90_utils_basic from cutlass.cute.nvgpu import cpasync, warpgroup from cutlass.cute.arch import ProxyKind, SharedSpace +from cutlass.cute import FastDivmodDivisor from cutlass import Float32, Int32, Boolean, const_expr from cutlass.utils import LayoutEnum @@ -22,6 +23,13 @@ from flash_attn.cute import pipeline from flash_attn.cute.tile_scheduler import TileSchedulerArguments, SingleTileScheduler, ParamsBase from flash_attn.cute.named_barrier import NamedBarrierFwd, NamedBarrierBwd +from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.block_sparse_utils import ( + get_total_q_block_count_bwd, + produce_block_sparse_q_loads_bwd_sm90, + consume_block_sparse_mma_bwd_sm90, + dQaccum_store_block_sparse_bwd_sm90, +) def mma_partition_fragment_AB( @@ -62,6 +70,9 @@ def __init__( AtomLayoutMdQ: int = 1, num_threads: int = 384, V_in_regs: bool = False, + mask_mod: cutlass.Constexpr | None = None, + has_aux_tensors: cutlass.Constexpr = False, + subtile_factor: cutlass.Constexpr[int] = 1, ): self.dtype = dtype # padding head_dim to a multiple of 16 as k_block_size @@ -107,6 +118,14 @@ def __init__( self.shuffle_LSE = self.SdP_swapAB and self.tile_hdim <= 64 self.shuffle_dPsum = self.SdP_swapAB and self.tile_hdim <= 64 + self.mask_mod = mask_mod + self.has_aux_tensors = has_aux_tensors + self.subtile_factor = subtile_factor + if cutlass.const_expr(has_aux_tensors): + self.vec_size: cutlass.Constexpr = 1 + else: + self.vec_size: cutlass.Constexpr = 4 + @staticmethod def can_implement( dtype, @@ -298,6 +317,8 @@ def __call__( mdQ_semaphore: Optional[cute.Tensor] = None, mdK_semaphore: Optional[cute.Tensor] = None, mdV_semaphore: 
Optional[cute.Tensor] = None, + aux_tensors: Optional[list] = None, + blocksparse_tensors: Optional[BlockSparseTensors] = None, ): assert mdQ_semaphore is None and mdK_semaphore is None and mdV_semaphore is None, ( "determinism not supported yet for Sm90" @@ -424,6 +445,16 @@ def __call__( LOG2_E = math.log2(math.e) softmax_scale_log2 = softmax_scale * LOG2_E + fastdiv_mods = None + if const_expr(aux_tensors is not None): + seqlen_q = cute.size(mQ.shape[0]) + seqlen_k = cute.size(mK.shape[0]) + seqlen_q_divmod = FastDivmodDivisor(seqlen_q) + seqlen_k_divmod = FastDivmodDivisor(seqlen_k) + fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod) + + self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None) + self.kernel( tma_tensor_Q, tma_tensor_K, @@ -456,6 +487,9 @@ def __call__( tile_sched_params, TileScheduler, SharedStorage, + aux_tensors, + fastdiv_mods, + blocksparse_tensors, ).launch( grid=grid_dim, block=[self.num_threads, 1, 1], @@ -498,6 +532,9 @@ def kernel( tile_sched_params: ParamsBase, TileScheduler: cutlass.Constexpr[Callable], SharedStorage: cutlass.Constexpr[Callable], + aux_tensors: Optional[list] = None, + fastdiv_mods=(None, None), + blocksparse_tensors: Optional[BlockSparseTensors] = None, ): warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) @@ -579,6 +616,7 @@ def kernel( self.tile_n, window_size_left=None, window_size_right=None, + swap_AB=self.SdP_swapAB, ) TileSchedulerCls = partial(TileScheduler.create, tile_sched_params) @@ -607,6 +645,7 @@ def kernel( block_info, SeqlenInfoCls, TileSchedulerCls, + blocksparse_tensors, ) if warp_idx == 1: for warp_group_idx in cutlass.range(self.num_mma_warp_groups): @@ -614,7 +653,14 @@ def kernel( barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx, number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE, ) - self.dQaccum_store(mdQaccum, sdQaccum, block_info, TileSchedulerCls, SeqlenInfoCls) + self.dQaccum_store( + mdQaccum, + sdQaccum, + block_info, + TileSchedulerCls, + SeqlenInfoCls, + blocksparse_tensors, + ) else: cute.arch.warpgroup_reg_alloc(self.num_mma_regs) tidx, _, _ = cute.arch.thread_idx() @@ -648,6 +694,9 @@ def kernel( SeqlenInfoCls, AttentionMaskCls, TileSchedulerCls, + aux_tensors, + fastdiv_mods, + blocksparse_tensors, ) @cute.jit @@ -674,6 +723,7 @@ def load( block_info: BlockInfo, SeqlenInfoCls: Callable, TileSchedulerCls: Callable, + blocksparse_tensors: Optional[BlockSparseTensors] = None, ): warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 @@ -723,48 +773,84 @@ def load( load_dPsum = copy_utils.tma_producer_copy_fn(load_dPsum, pipeline_dO) m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block) - # First iteration: load K together w Q & LSE, then V together w dO & dPsum - m_block = m_block_min - pipeline_Q.producer_acquire( - producer_state_Q, extra_tx_count=self.tma_copy_bytes["K"] - ) - load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q)) - load_Q(m_block, producer_state=producer_state_Q) - # cp.async.bulk is using ptx, so we need to elect one thread to do it - with cute.arch.elect_one(): - load_LSE(m_block, producer_state=producer_state_Q) - producer_state_dO_cur = ( - producer_state_dO - if const_expr(self.Q_stage != self.dO_stage) - else producer_state_Q - ) - pipeline_dO.producer_acquire( - producer_state_dO_cur, extra_tx_count=self.tma_copy_bytes["V"] - ) - load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_cur)) - load_dO(m_block, producer_state=producer_state_dO_cur) - with 
cute.arch.elect_one(): - load_dPsum(m_block, producer_state=producer_state_dO_cur) - producer_state_Q.advance() - producer_state_dO.advance() - # Subsequent iterations: load Q & LSE, then dO & dPsum - for m_block in cutlass.range(m_block_min + 1, m_block_max, unroll=1): - pipeline_Q.producer_acquire(producer_state_Q) - load_Q(m_block, producer_state=producer_state_Q) - # cp.async.bulk is using ptx, so we need to elect one thread to do it - with cute.arch.elect_one(): - load_LSE(m_block, producer_state=producer_state_Q) - producer_state_dO_cur = ( - producer_state_dO - if const_expr(self.Q_stage != self.dO_stage) - else producer_state_Q + + if const_expr(not self.use_block_sparsity): + total_m_block_cnt = m_block_max - m_block_min + process_tile = const_expr(not self.is_local) or m_block_min < m_block_max + else: + total_m_block_cnt = get_total_q_block_count_bwd( + blocksparse_tensors, + batch_idx, + head_idx, + n_block, + subtile_factor=self.subtile_factor, + m_block_max=m_block_max, ) - pipeline_dO.producer_acquire(producer_state_dO_cur) - load_dO(m_block, producer_state=producer_state_dO_cur) - with cute.arch.elect_one(): - load_dPsum(m_block, producer_state=producer_state_dO_cur) - producer_state_Q.advance() - producer_state_dO.advance() + process_tile = total_m_block_cnt > Int32(0) + + if process_tile: + if const_expr(not self.use_block_sparsity): + first_m_block = m_block_min + pipeline_Q.producer_acquire( + producer_state_Q, extra_tx_count=self.tma_copy_bytes["K"] + ) + load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q)) + load_Q(first_m_block, producer_state=producer_state_Q) + with cute.arch.elect_one(): + load_LSE(first_m_block, producer_state=producer_state_Q) + producer_state_dO_cur = ( + producer_state_dO + if const_expr(self.Q_stage != self.dO_stage) + else producer_state_Q + ) + pipeline_dO.producer_acquire( + producer_state_dO_cur, extra_tx_count=self.tma_copy_bytes["V"] + ) + load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_cur)) + load_dO(first_m_block, producer_state=producer_state_dO_cur) + with cute.arch.elect_one(): + load_dPsum(first_m_block, producer_state=producer_state_dO_cur) + producer_state_Q.advance() + producer_state_dO.advance() + + for m_block in cutlass.range(m_block_min + 1, m_block_max, unroll=1): + pipeline_Q.producer_acquire(producer_state_Q) + load_Q(m_block, producer_state=producer_state_Q) + with cute.arch.elect_one(): + load_LSE(m_block, producer_state=producer_state_Q) + producer_state_dO_cur = ( + producer_state_dO + if const_expr(self.Q_stage != self.dO_stage) + else producer_state_Q + ) + pipeline_dO.producer_acquire(producer_state_dO_cur) + load_dO(m_block, producer_state=producer_state_dO_cur) + with cute.arch.elect_one(): + load_dPsum(m_block, producer_state=producer_state_dO_cur) + producer_state_Q.advance() + producer_state_dO.advance() + else: + producer_state_Q, producer_state_dO = produce_block_sparse_q_loads_bwd_sm90( + blocksparse_tensors, + batch_idx, + head_idx, + n_block, + producer_state_Q, + producer_state_dO, + pipeline_Q, + pipeline_dO, + load_K, + load_V, + load_Q, + load_dO, + load_LSE, + load_dPsum, + self.tma_copy_bytes["K"], + self.tma_copy_bytes["V"], + Q_stage_eq_dO_stage=(self.Q_stage == self.dO_stage), + subtile_factor=self.subtile_factor, + m_block_max=m_block_max, + ) tile_scheduler.prefetch_next_work() tile_scheduler.advance_to_next_work() @@ -801,6 +887,9 @@ def mma( SeqlenInfoCls: Callable, AttentionMaskCls: Callable, TileSchedulerCls: Callable, + aux_tensors: Optional[list] = 
None, + fastdiv_mods=(None, None), + blocksparse_tensors: Optional[BlockSparseTensors] = None, ): warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group) warp_group_thread_layout = cute.make_layout( @@ -944,49 +1033,116 @@ def mma( n_block, head_idx, batch_idx, _ = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) mask = AttentionMaskCls(seqlen) - mask_fn = partial( - mask.apply_mask, - batch_idx=None, - head_idx=None, - n_block=n_block, - thr_mma=thr_mma_SdP, - mask_seqlen=True, - mask_causal=self.is_causal, - mask_local=self.is_local, - ) m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block) - # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("tidx = {}, m_block_min = {}, m_block_max = {}", cute.arch.thread_idx()[0], m_block_min, m_block_max) - dKV_accumulate = False - for m_block in cutlass.range(m_block_min, m_block_max, unroll=1): - consumer_state_Q, consumer_state_dO = mma_one_m_block_all( - m_block, - consumer_state_Q, - consumer_state_dO, - mask_fn=mask_fn, - dKV_accumulate=dKV_accumulate, + + if const_expr(not self.use_block_sparsity): + process_tile = const_expr(not self.is_local) or m_block_min < m_block_max + else: + total_m_block_cnt = get_total_q_block_count_bwd( + blocksparse_tensors, + batch_idx, + head_idx, + n_block, + subtile_factor=self.subtile_factor, + m_block_max=m_block_max, ) - dKV_accumulate = True + process_tile = total_m_block_cnt > Int32(0) + + if process_tile: + if const_expr(not self.use_block_sparsity): + mask_fn = partial( + mask.apply_mask, + batch_idx=batch_idx, + head_idx=head_idx, + n_block=n_block, + thr_mma=thr_mma_SdP, + mask_seqlen=True, + mask_causal=self.is_causal, + mask_local=self.is_local, + mask_mod=self.mask_mod, + aux_tensors=aux_tensors, + fastdiv_mods=fastdiv_mods, + ) + dKV_accumulate = False + for m_block in cutlass.range(m_block_min, m_block_max, unroll=1): + consumer_state_Q, consumer_state_dO = mma_one_m_block_all( + m_block, + consumer_state_Q, + consumer_state_dO, + mask_fn=mask_fn, + dKV_accumulate=dKV_accumulate, + thr_mma_SdP=thr_mma_SdP, + batch_idx=batch_idx, + head_idx=head_idx, + n_block=n_block, + softmax_scale=softmax_scale, + seqlen=seqlen, + aux_tensors=aux_tensors, + fastdiv_mods=fastdiv_mods, + ) + dKV_accumulate = True + else: + consumer_state_Q, consumer_state_dO = consume_block_sparse_mma_bwd_sm90( + blocksparse_tensors, + batch_idx, + head_idx, + n_block, + consumer_state_Q, + consumer_state_dO, + mma_one_m_block_all, + mask, + self.mask_mod, + is_causal=self.is_causal, + is_local=self.is_local, + thr_mma_SdP=thr_mma_SdP, + softmax_scale=softmax_scale, + seqlen=seqlen, + subtile_factor=self.subtile_factor, + m_block_max=m_block_max, + aux_tensors=aux_tensors, + fastdiv_mods=fastdiv_mods, + ) - # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dV) - # scale dK - acc_dK.store(acc_dK.load() * softmax_scale) - self.epilogue_dKV( - acc_dV, - mdV, - sV, - acc_dK, - mdK, - sK, - seqlen, - tma_atom_dK, - tma_atom_dV, - tiled_mma_dK, - tiled_mma_dV, - tidx, - n_block, - head_idx, - batch_idx, - ) + acc_dK.store(acc_dK.load() * softmax_scale) + self.epilogue_dKV( + acc_dV, + mdV, + sV, + acc_dK, + mdK, + sK, + seqlen, + tma_atom_dK, + tma_atom_dV, + tiled_mma_dK, + tiled_mma_dV, + tidx, + n_block, + head_idx, + batch_idx, + ) + else: + # Block sparsity: KV tile with zero Q blocks produces no dK/dV; write zeros. 
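# Rationale (sketch, inferred from the surrounding code): when no Q block contributes to this
# KV tile, P and dS are zero for every row that would touch it, so the tile's dK/dV
# contributions are exactly zero. The epilogue below still runs so the dK/dV tile for this
# (n_block, head, batch) is written (as zeros) rather than left untouched, and the same
# partial-then-full iteration order is kept in sync between the load, mma and dQaccum-store
# paths so the pipeline stages line up.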
+ if const_expr(self.use_block_sparsity): + acc_dK.fill(0.0) + acc_dV.fill(0.0) + self.epilogue_dKV( + acc_dV, + mdV, + sV, + acc_dK, + mdK, + sK, + seqlen, + tma_atom_dK, + tma_atom_dV, + tiled_mma_dK, + tiled_mma_dV, + tidx, + n_block, + head_idx, + batch_idx, + ) tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() @@ -1014,9 +1170,15 @@ def mma_one_m_block( smem_thr_copy_dQaccum: cute.TiledCopy, softmax_scale_log2: Float32, mask_fn: Optional[Callable] = None, - # acc_dV, - # acc_dK, dKV_accumulate: Boolean = True, + thr_mma_SdP: Optional[cute.core.ThrMma] = None, + batch_idx: Int32 = 0, + head_idx: Int32 = 0, + n_block: Int32 = 0, + softmax_scale: Float32 = 1.0, + seqlen: Optional[SeqlenInfoQK] = None, + aux_tensors: Optional[list] = None, + fastdiv_mods=(None, None), ): consumer_state_dO_cur = ( consumer_state_dO if const_expr(self.Q_stage == self.dO_stage) else consumer_state_Q @@ -1033,17 +1195,16 @@ def mma_one_m_block( consumer_state_dO_cur, pipeline_dO.consumer_try_wait(consumer_state_dO_cur) ) acc_dP = mma_dov_fn(A_idx=smem_idx_Q, wg_wait=1) + # (3) [Pointwise 1] P = exp(S - LSE) if cutlass.const_expr(mask_fn is not None): mask_fn(acc_S, m_block=m_block) acc_S_mn = utils.make_acc_tensor_mn_view(acc_S, transpose=self.SdP_swapAB) - # if cute.arch.thread_idx()[0] == 256: cute.print_tensor(acc_S_mn) for r in cutlass.range_constexpr(cute.size(acc_S_mn, mode=[0])): for c in cutlass.range(cute.size(acc_S_mn, mode=[1]), unroll_full=True): acc_S_mn[r, c] = cute.math.exp2( acc_S_mn[r, c] * softmax_scale_log2 - tLSErLSE[r], fastmath=True ) - # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_S_mn) tLSErdPsum = copy_utils.load_s2r(tLSEsdPsum[None, smem_idx_dO]) # Convert P from f32 -> f16 @@ -1061,11 +1222,10 @@ def mma_one_m_block( # (4) [Pointwise 2] dS = P*(dP-dPsum) warpgroup.wait_group(0) acc_dP_mn = utils.make_acc_tensor_mn_view(acc_dP, transpose=self.SdP_swapAB) - # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dP_mn) for r in cutlass.range_constexpr(cute.size(acc_dP_mn, mode=[0])): for c in cutlass.range(cute.size(acc_dP_mn, mode=[1]), unroll_full=True): acc_dP_mn[r, c] = acc_S_mn[r, c] * (acc_dP_mn[r, c] - tLSErdPsum[r]) - # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dP_mn) + # Convert dS from f32 -> f16 tdKrdS = utils.cvt_f16(utils.make_acc_tensor_frgA_view(acc_dP), self.dtype) @@ -1213,6 +1373,7 @@ def dQaccum_store( block_info: BlockInfo, TileSchedulerCls: cutlass.Constexpr[Callable], SeqlenInfoCls: cutlass.Constexpr[Callable], + blocksparse_tensors: Optional[BlockSparseTensors] = None, ): tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() @@ -1226,26 +1387,61 @@ def dQaccum_store( gdQaccum_, (self.tile_m * self.tile_hdim // self.num_mma_warp_groups,) ) m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block) - for m_block in cutlass.range(m_block_min, m_block_max, unroll=1): - for warp_group_idx in cutlass.range_constexpr(self.num_mma_warp_groups): - cute.arch.barrier( - barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx, - number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE, - ) - with cute.arch.elect_one(): - copy_utils.cpasync_reduce_bulk_add_f32( - sdQaccum[None, warp_group_idx].iterator, - gdQaccum[None, warp_group_idx, m_block].iterator, - self.tma_copy_bytes["dQ"], - ) - cute.arch.cp_async_bulk_commit_group() - for warp_group_idx in cutlass.range_constexpr(self.num_mma_warp_groups): - cute.arch.cp_async_bulk_wait_group( 
-                    self.num_mma_warp_groups - 1 - warp_group_idx, read=True
-                )
-                cute.arch.barrier_arrive(
-                    barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx,
-                    number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE,
+            if const_expr(not self.use_block_sparsity):
+                process_tile = const_expr(not self.is_local) or m_block_min < m_block_max
+                loop_count = m_block_max - m_block_min
+            else:
+                total_block_cnt = get_total_q_block_count_bwd(
+                    blocksparse_tensors,
+                    batch_idx,
+                    head_idx,
+                    n_block,
+                    subtile_factor=self.subtile_factor,
+                    m_block_max=m_block_max,
+                )
+                process_tile = total_block_cnt > Int32(0)
+
+            if process_tile:
+                if const_expr(not self.use_block_sparsity):
+                    for iter_idx in cutlass.range(loop_count, unroll=1):
+                        m_block = m_block_min + iter_idx
+                        m_block_safe = m_block
+
+                        for warp_group_idx in cutlass.range_constexpr(self.num_mma_warp_groups):
+                            cute.arch.barrier(
+                                barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx,
+                                number_of_threads=self.num_threads_per_warp_group
+                                + cute.arch.WARP_SIZE,
+                            )
+                            with cute.arch.elect_one():
+                                copy_utils.cpasync_reduce_bulk_add_f32(
+                                    sdQaccum[None, warp_group_idx].iterator,
+                                    gdQaccum[None, warp_group_idx, m_block_safe].iterator,
+                                    self.tma_copy_bytes["dQ"],
+                                )
+                        cute.arch.cp_async_bulk_commit_group()
+                        for warp_group_idx in cutlass.range_constexpr(self.num_mma_warp_groups):
+                            cute.arch.cp_async_bulk_wait_group(
+                                self.num_mma_warp_groups - 1 - warp_group_idx, read=True
+                            )
+                            cute.arch.barrier_arrive(
+                                barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx,
+                                number_of_threads=self.num_threads_per_warp_group
+                                + cute.arch.WARP_SIZE,
+                            )
+                else:
+                    dQaccum_store_block_sparse_bwd_sm90(
+                        blocksparse_tensors,
+                        batch_idx,
+                        head_idx,
+                        n_block,
+                        sdQaccum,
+                        gdQaccum,
+                        subtile_factor=self.subtile_factor,
+                        m_block_max=m_block_max,
+                        num_mma_warp_groups=self.num_mma_warp_groups,
+                        num_threads_per_warp_group=self.num_threads_per_warp_group,
+                        tma_copy_bytes_dQ=self.tma_copy_bytes["dQ"],
                     )
             tile_scheduler.advance_to_next_work()
             work_tile = tile_scheduler.get_current_work()
diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py
index fff327fc56..574413bbd0 100644
--- a/flash_attn/cute/interface.py
+++ b/flash_attn/cute/interface.py
@@ -30,11 +30,10 @@
 import cutlass
 import cutlass.cute as cute
-from cutlass.cute.runtime import from_dlpack
 
 from flash_attn.cute import utils
 from flash_attn.cute.cute_dsl_utils import to_cute_tensor
-from flash_attn.cute.flash_fwd import FlashAttentionForwardSm80, FlashAttentionForwardSm90
+from flash_attn.cute.flash_fwd import FlashAttentionForwardSm90
 from flash_attn.cute.flash_fwd_sm100 import FlashAttentionForwardSm100
 from flash_attn.cute.flash_bwd_preprocess import FlashAttentionBackwardPreprocess
 from flash_attn.cute.flash_bwd import FlashAttentionBackwardSm80
@@ -628,9 +627,6 @@ def _flash_attn_bwd(
     total_k = k.shape[0]
     seqlen_k = max_seqlen_k if max_seqlen_k is not None else total_k
 
-    seqlen_q_rounded = (seqlen_q + m_block_size - 1) // m_block_size * m_block_size
-    seqlen_k_rounded = (seqlen_k + n_block_size - 1) // n_block_size * n_block_size
-
     num_head_kv = k.shape[-2]
     head_dim_v = v.shape[-1]
@@ -646,6 +642,20 @@ def _flash_attn_bwd(
     use_block_sparsity = block_sparse_tensors is not None
+
+    # SM90 block-sparse backward: m_block_size=64 is the GCD of an m_block_size that fits
+    # this kernel, the forward pass's base block_m of 128, and the block-sparse subtiling size.
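# Illustrative numbers (assumed, not taken from this patch): with m_block_size=64 and
# subtile_factor=2 below, sparse_block_size_q = 2 * 64 = 128, i.e. one 128-row BlockMask
# block is subtiled into two 64-row m_blocks; an assumed seqlen_q of 1000 would round up to
# seqlen_q_rounded = 1024 (16 blocks of 64).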
+    if compute_capability == 9 and use_block_sparsity:
+        m_block_size = 64
+        # dQ_swapAB tuning: use False when m_block_size=64 (same as the causal case)
+        dQ_swapAB = False
+
+    # NB: this could be derived from block_sparse_tensors, but for now we hardcode it to 2
+    subtile_factor = 2
+    sparse_block_size_q = subtile_factor * m_block_size
+
+    seqlen_q_rounded = (seqlen_q + m_block_size - 1) // m_block_size * m_block_size
+    seqlen_k_rounded = (seqlen_k + n_block_size - 1) // n_block_size * n_block_size
+
     if cu_seqlens_k is None:
         assert k.shape == (batch_size, seqlen_k, num_head_kv, head_dim)
         assert v.shape == (batch_size, seqlen_k, num_head_kv, head_dim_v)
@@ -692,8 +702,8 @@ def _flash_attn_bwd(
     qhead_per_kvhead = num_head // num_head_kv
     if pack_gqa is None:
         pack_gqa = qhead_per_kvhead > 1
-    if compute_capability in [10, 11]:
-        pack_gqa = False  # override for now
+    # pack_gqa is not yet supported in the backward pass
+    pack_gqa = False
 
     if compute_capability not in [10, 11]:
         assert deterministic is False, "bwd deterministic only supported for sm100/sm110 for now"
@@ -708,9 +718,6 @@ def _flash_attn_bwd(
     device = q.device
     out_torch_dtype = q.dtype
 
-    # nb: this could be derived from the block_sparse_tensors but for now we hardcode it to 2
-    subtile_factor = 2
-
     if dq is None:
         dq = torch.empty_like(q)
     else:
@@ -863,6 +870,14 @@ def _flash_attn_bwd(
     )
 
     # Backward kernel: compute dk, dv, dq_accum.
+    score_mod_hash = utils.hash_callable(score_mod) if score_mod else False
+    score_mod_bwd_hash = utils.hash_callable(score_mod_bwd) if score_mod_bwd else False
+    mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod else False
+    num_aux_tensors = len(aux_tensors) if aux_tensors else 0
+    cute_aux_tensors = None
+    if aux_tensors is not None:
+        cute_aux_tensors = [to_cute_tensor(buf, assumed_align=None, fully_dynamic=True) for buf in aux_tensors]
+
     if compute_capability == 9:
         compile_key = (
             compute_capability,
@@ -889,18 +904,14 @@ def _flash_attn_bwd(
             cu_seqlens_k is None,
             seqused_q is None,
             seqused_k is None,
+            score_mod_hash,
+            score_mod_bwd_hash,
+            mask_mod_hash,
+            num_aux_tensors,
+            use_block_sparsity,
         )
         cute_aux_tensors = None
     else:
-        # Hash callables for compile key
-        score_mod_hash = utils.hash_callable(score_mod) if score_mod else False
-        score_mod_bwd_hash = utils.hash_callable(score_mod_bwd) if score_mod_bwd else False
-        mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod else False
-        num_aux_tensors = len(aux_tensors) if aux_tensors else 0
-        # Convert aux_tensors to cute tensors
-        cute_aux_tensors = None
-        if aux_tensors is not None:
-            cute_aux_tensors = [from_dlpack(buf).mark_layout_dynamic() for buf in aux_tensors]
         compile_key = (
             compute_capability,
             dtype,
@@ -988,6 +999,9 @@ def _flash_attn_bwd(
             AtomLayoutMdQ,
             num_threads,
             V_in_regs=V_in_regs,
+            mask_mod=mask_mod,
+            has_aux_tensors=aux_tensors is not None,
+            subtile_factor=subtile_factor,
         )
     else:
         fa_bwd_obj = FlashAttentionBackwardSm100(
@@ -1004,14 +1018,14 @@ def _flash_attn_bwd(
             score_mod=score_mod,
             score_mod_bwd=score_mod_bwd,
             mask_mod=mask_mod,
-            has_aux_tensors=aux_tensors is not None and len(aux_tensors) > 0,
+            has_aux_tensors=aux_tensors is not None,
             subtile_factor=subtile_factor,
         )
 
     # Block sparse tensors for backward use Q-direction indexing (transposed from forward).
-    # sparse_block_size_q = 2*tile_m matches forward's q_stage=2 pipelining.
+    # sparse_block_size_q = subtile_factor * tile_m matches BlockMask granularity.
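# Illustrative sketch (assumed names, not part of this patch): one way a caller could build
# the Q-direction count/index pair that the SM90 kernel reads as q_cnt[b, h, n_block] and
# q_idx[b, h, n_block, i], from a hypothetical dense boolean block mask `keep` of shape
# (batch, heads, num_q_sparse_blocks, num_n_blocks) at sparse_block_size_q x n_block_size
# granularity. The optional full-block count/index tensors are omitted for brevity.
import torch

def build_q_direction_indices(keep: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    batch, heads, num_q_blocks, num_n_blocks = keep.shape
    # Number of contributing Q blocks per (batch, head, n_block).
    q_cnt = keep.sum(dim=2).to(torch.int32)
    # Indices of those Q blocks, front-packed and zero-padded to num_q_blocks.
    q_idx = torch.zeros(batch, heads, num_n_blocks, num_q_blocks, dtype=torch.int32)
    for b in range(batch):
        for h in range(heads):
            for n in range(num_n_blocks):
                idx = keep[b, h, :, n].nonzero(as_tuple=False).flatten()
                q_idx[b, h, n, : idx.numel()] = idx.to(torch.int32)
    return q_cnt, q_idx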
     sparse_tensors_compile = None
-    if block_sparse_tensors is not None and compute_capability in [10, 11]:
+    if block_sparse_tensors is not None:
         expected_count_shape, expected_index_shape = get_block_sparse_expected_shapes_bwd(
             batch_size, num_head, seqlen_q, seqlen_k, m_block_size, n_block_size, subtile_factor,
@@ -1051,8 +1065,9 @@ def _flash_attn_bwd(
         sparse_tensors_compile,
         options="--enable-tvm-ffi",
     )
+    # Runtime normalization of block sparse tensors for both SM90 and SM100
     normalized_block_sparse_tensors = None
-    if block_sparse_tensors is not None and compute_capability in [10, 11]:
+    if block_sparse_tensors is not None:
         expected_count_shape, expected_index_shape = get_block_sparse_expected_shapes_bwd(
             batch_size, num_head, seqlen_q, seqlen_k, m_block_size, n_block_size, subtile_factor,
@@ -1090,6 +1105,7 @@ def _flash_attn_bwd(
     )
 
     num_threads = 256 if compute_capability == 9 else 128
+    arch = compute_capability * 10
     # Postprocess kernel: convert dq_accum from float32 to dq in bf16/fp16
     compile_key_post = (
         compute_capability,
@@ -1111,7 +1127,6 @@ def _flash_attn_bwd(
         to_cute_tensor(t, assumed_align=4) if t is not None else None
         for t in (cu_seqlens_q, seqused_q)
     ]
-    arch = compute_capability * 10
     fa_bwd_post = FlashAttentionBackwardPostprocess(
         dtype, head_dim, arch, m_block_size, num_threads, AtomLayoutMdQ, dQ_swapAB
     )
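# Illustrative sketch (assumed, host-side Python): the m_block visit order that
# produce_block_sparse_q_loads_bwd_sm90, consume_block_sparse_mma_bwd_sm90 and
# dQaccum_store_block_sparse_bwd_sm90 all follow for one (batch, head, n_block):
# partial (mask_mod) blocks first, then full blocks, each sparse block expanded into
# subtile_factor m_blocks and dropped once it reaches m_block_max.

def reference_m_block_order(q_cnt, q_idx, full_cnt, full_idx, subtile_factor, m_block_max):
    order = []
    for cnt, idx in ((q_cnt, q_idx), (full_cnt or 0, full_idx or [])):
        for iter_idx in range(cnt * subtile_factor):
            sparse_idx, subtile_offset = divmod(iter_idx, subtile_factor)
            m_block = idx[sparse_idx] * subtile_factor + subtile_offset
            if m_block < m_block_max:
                order.append(m_block)
    return order

# Example: partial blocks [0, 3], full blocks [1], subtile_factor=2, m_block_max=7
# -> reference_m_block_order(2, [0, 3], 1, [1], 2, 7) == [0, 1, 6, 2, 3]
# (the second subtile of sparse block 3 would be m_block 7 and is clipped).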