6 changes: 6 additions & 0 deletions flash_attn/cute/block_sparse_utils.py
@@ -287,6 +287,7 @@ def consume_block_sparse_loads(
intra_wg_overlap: cutlass.Constexpr,
warp_scheduler_barrier_sync: Callable,
warp_scheduler_barrier_arrive: Callable,
seqlen=None,
):
"""Consume the mask and full block lists for a single tile on the consumer side.

@@ -379,6 +380,7 @@ def consume_block_sparse_loads(
mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1]
kv_consumer_state = process_first_half_block(
n_block=mask_n_block,
seqlen=seqlen,
kv_consumer_state=kv_consumer_state,
mask_fn=partial(
mask_fn,
@@ -394,6 +396,7 @@
kv_consumer_state = mma_one_n_block(
kv_consumer_state,
n_block=mask_n_block,
seqlen=seqlen,
mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
mask_fn=partial(mask_fn, mask_mod=mask_mod, mask_seqlen=False),
)
@@ -404,6 +407,7 @@
if curr_mask_block_cnt == 0:
kv_consumer_state = process_first_half_block(
n_block=full_n_block,
seqlen=seqlen,
kv_consumer_state=kv_consumer_state,
mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True),
score_mod_fn=score_mod_fn,
@@ -413,6 +417,7 @@
kv_consumer_state = mma_one_n_block(
kv_consumer_state,
n_block=full_n_block,
seqlen=seqlen,
mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True),
)
@@ -422,6 +427,7 @@
kv_consumer_state = mma_one_n_block(
kv_consumer_state,
n_block=full_n_block,
seqlen=seqlen,
mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=False),
)
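
For orientation: the change above threads an optional `seqlen` through `consume_block_sparse_loads` into the partially-applied per-block callbacks (`process_first_half_block`, `mma_one_n_block`). Below is a minimal sketch of that forwarding pattern with stand-in callables; all names are illustrative, not the actual CuTe-DSL kernel code.

from functools import partial

def mask_fn(n_block, *, mask_mod, mask_seqlen, seqlen=None):
    # Stand-in for the per-block mask callback: report which bounds would apply.
    applied = []
    if mask_mod is not None:
        applied.append("mask_mod")
    if mask_seqlen and seqlen is not None:
        applied.append(f"seqlen<{seqlen}")
    return f"block {n_block}: {applied or ['unmasked']}"

def mma_one_n_block(n_block, seqlen, mask_fn):
    # The consumer now receives seqlen alongside the pre-bound mask callback,
    # mirroring the new `seqlen=seqlen` arguments in the diff.
    return mask_fn(n_block)

# Mask-list blocks keep mask_mod but skip the seqlen bound; full-list blocks
# apply only the seqlen bound on the final, possibly partial, block.
print(mma_one_n_block(3, 128, partial(mask_fn, mask_mod="causal", mask_seqlen=False, seqlen=128)))
print(mma_one_n_block(7, 128, partial(mask_fn, mask_mod=None, mask_seqlen=True, seqlen=128)))
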
185 changes: 144 additions & 41 deletions flash_attn/cute/flash_bwd_sm90.py
@@ -22,6 +22,12 @@
from flash_attn.cute import pipeline
from flash_attn.cute.tile_scheduler import TileSchedulerArguments, SingleTileScheduler, ParamsBase
from flash_attn.cute.named_barrier import NamedBarrierFwd, NamedBarrierBwd
from flash_attn.cute.block_sparsity import BlockSparseTensors
from flash_attn.cute.block_sparse_utils import (
get_total_q_block_count_bwd,
get_block_sparse_iteration_info_bwd,
get_m_block_from_iter_bwd,
)


def mma_partition_fragment_AB(
@@ -62,6 +68,7 @@ def __init__(
AtomLayoutMdQ: int = 1,
num_threads: int = 384,
V_in_regs: bool = False,
subtile_factor: cutlass.Constexpr[int] = 1,
):
self.dtype = dtype
# padding head_dim to a multiple of 16 as k_block_size
@@ -106,6 +113,7 @@ def __init__(
# TODO: impl these for hdim 64
self.shuffle_LSE = self.SdP_swapAB and self.tile_hdim <= 64
self.shuffle_dPsum = self.SdP_swapAB and self.tile_hdim <= 64
self.subtile_factor = subtile_factor

@staticmethod
def can_implement(
@@ -298,6 +306,8 @@ def __call__(
mdQ_semaphore: Optional[cute.Tensor] = None,
mdK_semaphore: Optional[cute.Tensor] = None,
mdV_semaphore: Optional[cute.Tensor] = None,
aux_tensors: Optional[list] = None,
blocksparse_tensors: Optional[BlockSparseTensors] = None,
):
assert mdQ_semaphore is None and mdK_semaphore is None and mdV_semaphore is None, (
"determinism not supported yet for Sm90"
@@ -424,6 +434,8 @@ def __call__(
LOG2_E = math.log2(math.e)
softmax_scale_log2 = softmax_scale * LOG2_E

self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None)

self.kernel(
tma_tensor_Q,
tma_tensor_K,
@@ -456,6 +468,7 @@ def __call__(
tile_sched_params,
TileScheduler,
SharedStorage,
blocksparse_tensors,
).launch(
grid=grid_dim,
block=[self.num_threads, 1, 1],
@@ -498,6 +511,7 @@ def kernel(
tile_sched_params: ParamsBase,
TileScheduler: cutlass.Constexpr[Callable],
SharedStorage: cutlass.Constexpr[Callable],
blocksparse_tensors: Optional[BlockSparseTensors] = None,
):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())

@@ -579,6 +593,7 @@ def kernel(
self.tile_n,
window_size_left=None,
window_size_right=None,
swap_AB=self.SdP_swapAB,
)
TileSchedulerCls = partial(TileScheduler.create, tile_sched_params)

@@ -607,6 +622,7 @@ def kernel(
block_info,
SeqlenInfoCls,
TileSchedulerCls,
blocksparse_tensors,
)
if warp_idx == 1:
for warp_group_idx in cutlass.range(self.num_mma_warp_groups):
@@ -648,6 +664,7 @@ def kernel(
SeqlenInfoCls,
AttentionMaskCls,
TileSchedulerCls,
blocksparse_tensors,
)

@cute.jit
@@ -674,6 +691,7 @@ def load(
block_info: BlockInfo,
SeqlenInfoCls: Callable,
TileSchedulerCls: Callable,
blocksparse_tensors: Optional[BlockSparseTensors] = None,
):
warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4

@@ -723,48 +741,132 @@ def load(
load_dPsum = copy_utils.tma_producer_copy_fn(load_dPsum, pipeline_dO)

m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block)
# First iteration: load K together with Q & LSE, then V together with dO & dPsum
m_block = m_block_min
pipeline_Q.producer_acquire(
producer_state_Q, extra_tx_count=self.tma_copy_bytes["K"]
)
load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q))
load_Q(m_block, producer_state=producer_state_Q)
# cp.async.bulk uses PTX, so we need to elect one thread to do it
with cute.arch.elect_one():
load_LSE(m_block, producer_state=producer_state_Q)
producer_state_dO_cur = (
producer_state_dO
if const_expr(self.Q_stage != self.dO_stage)
else producer_state_Q
)
pipeline_dO.producer_acquire(
producer_state_dO_cur, extra_tx_count=self.tma_copy_bytes["V"]
)
load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_cur))
load_dO(m_block, producer_state=producer_state_dO_cur)
with cute.arch.elect_one():
load_dPsum(m_block, producer_state=producer_state_dO_cur)
producer_state_Q.advance()
producer_state_dO.advance()
# Subsequent iterations: load Q & LSE, then dO & dPsum
for m_block in cutlass.range(m_block_min + 1, m_block_max, unroll=1):
pipeline_Q.producer_acquire(producer_state_Q)
load_Q(m_block, producer_state=producer_state_Q)
# cp.async.bulk uses PTX, so we need to elect one thread to do it
with cute.arch.elect_one():
load_LSE(m_block, producer_state=producer_state_Q)
producer_state_dO_cur = (
producer_state_dO
if const_expr(self.Q_stage != self.dO_stage)
else producer_state_Q

if const_expr(self.use_block_sparsity):
total_m_block_cnt = get_total_q_block_count_bwd(
blocksparse_tensors,
batch_idx,
head_idx,
n_block,
subtile_factor=self.subtile_factor,
m_block_max=m_block_max,
)
pipeline_dO.producer_acquire(producer_state_dO_cur)
load_dO(m_block, producer_state=producer_state_dO_cur)
with cute.arch.elect_one():
load_dPsum(m_block, producer_state=producer_state_dO_cur)
producer_state_Q.advance()
producer_state_dO.advance()
process_tile = total_m_block_cnt > Int32(0)
else:
total_m_block_cnt = m_block_max - m_block_min
process_tile = const_expr(not self.is_local) or m_block_min < m_block_max

if process_tile:
if const_expr(self.use_block_sparsity):
(
curr_q_cnt,
curr_q_idx,
curr_full_cnt,
curr_full_idx,
loop_count,
) = get_block_sparse_iteration_info_bwd(
blocksparse_tensors,
batch_idx,
head_idx,
n_block,
self.subtile_factor,
m_block_max,
)

for iter_idx in cutlass.range(loop_count, unroll=1):
m_block, _ = get_m_block_from_iter_bwd(
iter_idx,
curr_q_cnt,
curr_q_idx,
curr_full_cnt,
curr_full_idx,
self.subtile_factor,
)

if iter_idx == 0:
pipeline_Q.producer_acquire(
producer_state_Q, extra_tx_count=self.tma_copy_bytes["K"]
)
load_K(
tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q)
)
load_Q(m_block, producer_state=producer_state_Q)
with cute.arch.elect_one():
load_LSE(m_block, producer_state=producer_state_Q)
producer_state_dO_cur = (
producer_state_dO
if const_expr(self.Q_stage != self.dO_stage)
else producer_state_Q
)
pipeline_dO.producer_acquire(
producer_state_dO_cur, extra_tx_count=self.tma_copy_bytes["V"]
)
load_V(
tma_bar_ptr=pipeline_dO.producer_get_barrier(
producer_state_dO_cur
)
)
load_dO(m_block, producer_state=producer_state_dO_cur)
with cute.arch.elect_one():
load_dPsum(m_block, producer_state=producer_state_dO_cur)
producer_state_Q.advance()
producer_state_dO.advance()
else:
pipeline_Q.producer_acquire(producer_state_Q)
load_Q(m_block, producer_state=producer_state_Q)
with cute.arch.elect_one():
load_LSE(m_block, producer_state=producer_state_Q)
producer_state_dO_cur = (
producer_state_dO
if const_expr(self.Q_stage != self.dO_stage)
else producer_state_Q
)
pipeline_dO.producer_acquire(producer_state_dO_cur)
load_dO(m_block, producer_state=producer_state_dO_cur)
with cute.arch.elect_one():
load_dPsum(m_block, producer_state=producer_state_dO_cur)
producer_state_Q.advance()
producer_state_dO.advance()
else:
first_m_block = m_block_min
pipeline_Q.producer_acquire(
producer_state_Q, extra_tx_count=self.tma_copy_bytes["K"]
)
load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q))
load_Q(first_m_block, producer_state=producer_state_Q)
with cute.arch.elect_one():
load_LSE(first_m_block, producer_state=producer_state_Q)
producer_state_dO_cur = (
producer_state_dO
if const_expr(self.Q_stage != self.dO_stage)
else producer_state_Q
)
pipeline_dO.producer_acquire(
producer_state_dO_cur, extra_tx_count=self.tma_copy_bytes["V"]
)
load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_cur))
load_dO(first_m_block, producer_state=producer_state_dO_cur)
with cute.arch.elect_one():
load_dPsum(first_m_block, producer_state=producer_state_dO_cur)
producer_state_Q.advance()
producer_state_dO.advance()

for m_block in cutlass.range(m_block_min + 1, m_block_max, unroll=1):
pipeline_Q.producer_acquire(producer_state_Q)
load_Q(m_block, producer_state=producer_state_Q)
with cute.arch.elect_one():
load_LSE(m_block, producer_state=producer_state_Q)
producer_state_dO_cur = (
producer_state_dO
if const_expr(self.Q_stage != self.dO_stage)
else producer_state_Q
)
pipeline_dO.producer_acquire(producer_state_dO_cur)
load_dO(m_block, producer_state=producer_state_dO_cur)
with cute.arch.elect_one():
load_dPsum(m_block, producer_state=producer_state_dO_cur)
producer_state_Q.advance()
producer_state_dO.advance()

tile_scheduler.prefetch_next_work()
tile_scheduler.advance_to_next_work()
@@ -801,6 +903,7 @@ def mma(
SeqlenInfoCls: Callable,
AttentionMaskCls: Callable,
TileSchedulerCls: Callable,
blocksparse_tensors: Optional[BlockSparseTensors] = None,
):
warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group)
warp_group_thread_layout = cute.make_layout(
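
The restructured `load` above splits into a block-sparse path and the original dense path. In the sparse path the producer first computes `total_m_block_cnt` for the current (batch, head, n_block) column and skips the tile when the count is zero; otherwise it walks the column's m-blocks, fusing the one-time K and V loads into the first iteration exactly as the dense first iteration does. A simplified host-side sketch of that control flow, with stand-ins for the CuTe-DSL pipeline and TMA calls (assumed semantics, not the kernel code):

def load_column(m_blocks, load_kv_once, load_q_lse, load_do_dpsum):
    # m_blocks plays the role of the iteration order recovered via
    # get_block_sparse_iteration_info_bwd / get_m_block_from_iter_bwd.
    if not m_blocks:  # process_tile == False: nothing to load for this tile
        return
    for iter_idx, m_block in enumerate(m_blocks):
        if iter_idx == 0:
            load_kv_once()  # first iteration fuses K (with Q/LSE) and V (with dO/dPsum)
        load_q_lse(m_block)
        load_do_dpsum(m_block)

load_column(
    [5, 9, 0],  # e.g. partially-masked blocks first, then full blocks
    lambda: print("TMA load K + V"),
    lambda m: print(f"TMA load Q/LSE, m_block={m}"),
    lambda m: print(f"TMA load dO/dPsum, m_block={m}"),
)
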
1 change: 1 addition & 0 deletions flash_attn/cute/flash_fwd.py
@@ -2168,6 +2168,7 @@ def mma(
self.intra_wg_overlap,
self.warp_scheduler_barrier_sync,
self.warp_scheduler_barrier_arrive,
seqlen,
)

# Handle empty case (when no blocks to process)
12 changes: 8 additions & 4 deletions flash_attn/cute/interface.py
@@ -870,6 +870,7 @@ def _flash_attn_bwd(
AtomLayoutNdKV,
AtomLayoutMdQ,
V_in_regs,
use_block_sparsity,
)
else:
# Hash callables for compile key
@@ -964,6 +965,7 @@ def _flash_attn_bwd(
AtomLayoutMdQ,
num_threads,
V_in_regs=V_in_regs,
subtile_factor=1,
)
else:
fa_bwd_obj = FlashAttentionBackwardSm100(
@@ -987,10 +989,11 @@
# Block sparse tensors for backward use Q-direction indexing (transposed from forward).
# sparse_block_size_q = 2*tile_m matches forward's q_stage=2 pipelining.
sparse_tensors_compile = None
if block_sparse_tensors is not None and compute_capability == 10:
if block_sparse_tensors is not None:
bwd_subtile_factor = subtile_factor if compute_capability == 10 else 1
expected_count_shape, expected_index_shape = get_block_sparse_expected_shapes_bwd(
batch_size, num_head, seqlen_q, seqlen_k,
m_block_size, n_block_size, subtile_factor,
m_block_size, n_block_size, bwd_subtile_factor,
)
compile_time_normalized = normalize_block_sparse_tensors(
block_sparse_tensors,
@@ -1028,10 +1031,11 @@
options="--enable-tvm-ffi",
)
normalized_block_sparse_tensors = None
if block_sparse_tensors is not None and compute_capability == 10:
if block_sparse_tensors is not None:
bwd_subtile_factor = subtile_factor if compute_capability == 10 else 1
expected_count_shape, expected_index_shape = get_block_sparse_expected_shapes_bwd(
batch_size, num_head, seqlen_q, seqlen_k,
m_block_size, n_block_size, subtile_factor,
m_block_size, n_block_size, bwd_subtile_factor,
)
normalized_block_sparse_tensors = normalize_block_sparse_tensors(
block_sparse_tensors,
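
The `interface.py` changes drop the SM100-only gate: block-sparse tensors are now normalized for SM90 as well, with the subtile factor collapsed to 1 off SM100 so the expected count/index shapes match the un-subtiled backward kernel. The gating in isolation (a trivial sketch of the expression used twice above):

def bwd_subtile_factor_for(compute_capability: int, subtile_factor: int) -> int:
    # SM100 keeps the requested subtiling; other archs (e.g. SM90) use 1.
    return subtile_factor if compute_capability == 10 else 1

assert bwd_subtile_factor_for(10, 2) == 2  # SM100: subtiled backward
assert bwd_subtile_factor_for(9, 2) == 1   # SM90: no subtiling
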
2 changes: 1 addition & 1 deletion flash_attn/cute/mask.py
@@ -339,7 +339,7 @@ def apply_mask_sm100(
mask_r2p(acc_S, seqlenk_col_limit, arch=100, rank1=True)

elif const_expr(not mask_causal and not mask_local and mask_mod is not None):
# Block sparse w/ mask_mod
# Block sparse case w/ mask_mod
has_fastdiv = const_expr(
fastdiv_mods is not None
and fastdiv_mods[0] is not None