33 changes: 20 additions & 13 deletions flash_attn/cute/flash_bwd_sm100.py
@@ -1844,15 +1844,25 @@ def apply_score_mod_bwd(
self,
grad_tensor,
score_tensor,
index_tensor,
thr_copy_t2r,
thr_mma_S,
batch_idx,
head_idx,
m_block,
n_block,
stage,
softmax_scale,
seqlen_info,
aux_tensors=None,
fastdiv_mods=(None, None),
):
"""Apply backward score modification (joint graph) for SM100."""
cS_bwd = cute.make_identity_tensor((self.tile_n, self.tile_m))
cS_bwd = cute.domain_offset((n_block * self.tile_n, m_block * self.tile_m), cS_bwd)
tScS_bwd = thr_mma_S.partition_C(cS_bwd)
tScS_idx_bwd = thr_copy_t2r.partition_D(tScS_bwd)
index_tensor = tScS_idx_bwd[None, stage, 0, 0]

apply_score_mod_bwd_inner(
grad_tensor,
score_tensor,
@@ -1871,6 +1881,10 @@ def apply_score_mod_bwd(
transpose_indices=True,
)

# Zero out OOB positions (kv_idx >= seqlen_k) after score_mod_bwd so padded
# KV columns cannot leak into dK/dV.
for i in cutlass.range(cute.size(grad_tensor), unroll_full=True):
kv_idx = index_tensor[i][0]
grad_tensor[i] = 0.0 if kv_idx >= seqlen_info.seqlen_k else grad_tensor[i]

@cute.jit
def compute_loop(
self,
@@ -2213,28 +2227,21 @@ def compute_loop(

if const_expr(self.score_mod_bwd is not None):
tSrS_pre_cur = tSrS_pre[None, stage, 0, 0]
cS_bwd = cute.make_identity_tensor((self.tile_n, self.tile_m))
cS_bwd = cute.domain_offset(
(n_block * self.tile_n, m_block * self.tile_m), cS_bwd
)
tScS_bwd = thr_mma_S.partition_C(cS_bwd)
tScS_idx_bwd = thr_copy_t2r.partition_D(tScS_bwd)
tScS_idx_cur = tScS_idx_bwd[None, stage, 0, 0]
self.apply_score_mod_bwd(
tdPrdP_cur,
tSrS_pre_cur,
tScS_idx_cur,
thr_copy_t2r,
thr_mma_S,
batch_idx,
head_idx,
m_block,
n_block,
stage,
softmax_scale,
seqlen,
aux_tensors,
fastdiv_mods,
)
# Zero out OOB positions (kv_idx >= seqlen_k) after score_mod_bwd
for i in cutlass.range(cute.size(tdPrdP_cur), unroll_full=True):
kv_idx = tScS_idx_cur[i][0]
tdPrdP_cur[i] = 0.0 if kv_idx >= seqlen.seqlen_k else tdPrdP_cur[i]

tdPrdS_cvt = cute.make_fragment_like(tdPrdP_cur, self.ds_dtype)
utils.cvt_f16(tdPrdP_cur, tdPrdS_cvt)
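For intuition, the OOB zeroing now living inside apply_score_mod_bwd is a masked select over the gradient fragment: any element whose global KV column falls at or beyond seqlen_k is padding and must not feed into dK/dV. A minimal standalone sketch of the same semantics (plain Python with hypothetical names, not the kernel's actual fragment layout):

import numpy as np

def zero_oob_grads(grad: np.ndarray, kv_idx: np.ndarray, seqlen_k: int) -> np.ndarray:
    # grad: flat fragment of dS values; kv_idx: global KV column per element.
    return np.where(kv_idx >= seqlen_k, 0.0, grad)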
195 changes: 194 additions & 1 deletion flash_attn/cute/flash_bwd_sm90.py
@@ -9,6 +9,7 @@
import cutlass.utils.hopper_helpers as sm90_utils_basic
from cutlass.cute.nvgpu import cpasync, warpgroup
from cutlass.cute.arch import ProxyKind, SharedSpace
from cutlass.cute import FastDivmodDivisor
from cutlass import Float32, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum

@@ -22,6 +23,7 @@
from flash_attn.cute import pipeline
from flash_attn.cute.tile_scheduler import TileSchedulerArguments, SingleTileScheduler, ParamsBase
from flash_attn.cute.named_barrier import NamedBarrierFwd, NamedBarrierBwd
from flash_attn.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner
from flash_attn.cute.block_sparsity import BlockSparseTensors
from flash_attn.cute.block_sparse_utils import (
get_total_q_block_count_bwd,
@@ -68,6 +70,10 @@ def __init__(
AtomLayoutMdQ: int = 1,
num_threads: int = 384,
V_in_regs: bool = False,
score_mod: cutlass.Constexpr | None = None,
score_mod_bwd: cutlass.Constexpr | None = None,
mask_mod: cutlass.Constexpr | None = None,
has_aux_tensors: cutlass.Constexpr = False,
subtile_factor: cutlass.Constexpr[int] = 1,
):
self.dtype = dtype
@@ -113,7 +119,17 @@ def __init__(
# TODO: impl these for hdim 64
self.shuffle_LSE = self.SdP_swapAB and self.tile_hdim <= 64
self.shuffle_dPsum = self.SdP_swapAB and self.tile_hdim <= 64

self.score_mod = score_mod
self.score_mod_bwd = score_mod_bwd
self.mask_mod = mask_mod
self.has_aux_tensors = has_aux_tensors
self.subtile_factor = subtile_factor
# Aux tensors are indexed per element, so vectorized score_mod application
# is disabled when they are present.
if cutlass.const_expr(has_aux_tensors):
self.vec_size: cutlass.Constexpr = 1
else:
self.vec_size: cutlass.Constexpr = 4
self.qk_acc_dtype = Float32

@staticmethod
def can_implement(
@@ -432,7 +448,18 @@ def __call__(
grid_dim = TileScheduler.get_grid_shape(tile_sched_params)

LOG2_E = math.log2(math.e)
softmax_scale_log2 = softmax_scale * LOG2_E
if const_expr(self.score_mod is None):
softmax_scale_log2 = softmax_scale * LOG2_E
else:
# With a score_mod, softmax_scale is passed into apply_score_mod instead,
# so only log2(e) is folded into the exponential scale here.
softmax_scale_log2 = LOG2_E

fastdiv_mods = None
if const_expr(aux_tensors is not None):
seqlen_q = cute.size(mQ.shape[0])
seqlen_k = cute.size(mK.shape[0])
seqlen_q_divmod = FastDivmodDivisor(seqlen_q)
seqlen_k_divmod = FastDivmodDivisor(seqlen_k)
fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod)
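# FastDivmodDivisor precomputes magic-number constants so the per-element
# index math used by score_mod with aux tensors (recovering positions via
# // and %) lowers to multiply/shift sequences instead of hardware division.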

self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None)

@@ -468,6 +495,8 @@ def __call__(
tile_sched_params,
TileScheduler,
SharedStorage,
aux_tensors,
fastdiv_mods,
blocksparse_tensors,
).launch(
grid=grid_dim,
@@ -511,6 +540,8 @@ def kernel(
tile_sched_params: ParamsBase,
TileScheduler: cutlass.Constexpr[Callable],
SharedStorage: cutlass.Constexpr[Callable],
aux_tensors: Optional[list] = None,
fastdiv_mods=(None, None),
blocksparse_tensors: Optional[BlockSparseTensors] = None,
):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
@@ -671,6 +702,8 @@ def kernel(
SeqlenInfoCls,
AttentionMaskCls,
TileSchedulerCls,
aux_tensors,
fastdiv_mods,
blocksparse_tensors,
)

@@ -883,6 +916,100 @@ def load(
tile_scheduler.advance_to_next_work()
work_tile = tile_scheduler.get_current_work()

@cute.jit
def apply_score_mod(
self,
acc_S: cute.Tensor,
thr_mma_SdP: cute.core.ThrMma,
batch_idx,
head_idx,
m_block,
n_block,
softmax_scale,
seqlen_info: SeqlenInfoQK,
aux_tensors=None,
fastdiv_mods=(None, None),
):
# Create index tensor matching SM100's approach - no view transformation
# Index tensor shape and offset are (n, m) when SdP_swapAB=True to match
# the transposed score matrix layout
cS = cute.make_identity_tensor(
(self.tile_n, self.tile_m) if self.SdP_swapAB else (self.tile_m, self.tile_n)
)
cS = cute.domain_offset(
(n_block * self.tile_n, m_block * self.tile_m)
if self.SdP_swapAB
else (m_block * self.tile_m, n_block * self.tile_n),
cS,
)
tScS = thr_mma_SdP.partition_C(cS)

# Pass acc_S directly without view transformation to keep alignment with index tensor
# (matching SM100's approach which doesn't use make_acc_tensor_mn_view)
apply_score_mod_inner(
acc_S,
tScS,
self.score_mod,
batch_idx,
head_idx,
softmax_scale,
self.vec_size,
self.qk_acc_dtype,
aux_tensors,
fastdiv_mods,
seqlen_info,
constant_q_idx=None,
qhead_per_kvhead=self.qhead_per_kvhead,
transpose_indices=self.SdP_swapAB,
)
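# Elementwise this is roughly the following (a sketch assuming flex-attention
# style semantics; apply_score_mod_inner's actual loop is vectorized and
# handles the coordinate transpose itself):
#   for i in range(cute.size(acc_S)):
#       coord = tScS[i]  # (kv, q) when SdP_swapAB, else (q, kv)
#       acc_S[i] = score_mod(acc_S[i] * softmax_scale, batch_idx, head_idx, q, kv)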

@cute.jit
def apply_score_mod_bwd(
self,
grad_tensor: cute.Tensor,
score_tensor: cute.Tensor,
thr_mma_SdP: cute.core.ThrMma,
batch_idx,
head_idx,
m_block,
n_block,
softmax_scale,
seqlen_info: SeqlenInfoQK,
aux_tensors=None,
fastdiv_mods=(None, None),
):
# Create index tensor matching SM100's approach - no view transformation
cS = cute.make_identity_tensor(
(self.tile_n, self.tile_m) if self.SdP_swapAB else (self.tile_m, self.tile_n)
)
cS = cute.domain_offset(
(n_block * self.tile_n, m_block * self.tile_m)
if self.SdP_swapAB
else (m_block * self.tile_m, n_block * self.tile_n),
cS,
)
tScS = thr_mma_SdP.partition_C(cS)

# Pass tensors directly without view transformation to keep alignment with index tensor
# (matching SM100's approach which doesn't use make_acc_tensor_mn_view)
apply_score_mod_bwd_inner(
grad_tensor,
score_tensor,
tScS,
self.score_mod_bwd,
batch_idx,
head_idx,
softmax_scale,
self.vec_size,
self.qk_acc_dtype,
aux_tensors,
fastdiv_mods,
seqlen_info,
constant_q_idx=None,
qhead_per_kvhead=self.qhead_per_kvhead,
transpose_indices=self.SdP_swapAB,
)
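# The backward hook mirrors the forward one: apply_score_mod_bwd_inner walks
# the same global (q, kv) coordinates and chains the derivative of score_mod
# through grad_tensor, using score_tensor (the saved pre-modification scores)
# to re-evaluate the joint forward/backward graph at each element.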

@cute.jit
def mma(
self,
Expand Down Expand Up @@ -914,6 +1041,8 @@ def mma(
SeqlenInfoCls: Callable,
AttentionMaskCls: Callable,
TileSchedulerCls: Callable,
aux_tensors: Optional[list] = None,
fastdiv_mods=(None, None),
blocksparse_tensors: Optional[BlockSparseTensors] = None,
):
warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group)
@@ -1090,6 +1219,7 @@ def mma(
m_block_max,
)

# With subtile_factor>1, is_full_block is dynamic; always apply mask_mod.
mask_fn = partial(
mask.apply_mask,
batch_idx=batch_idx,
@@ -1099,6 +1229,9 @@ def mma(
mask_seqlen=True,
mask_causal=self.is_causal,
mask_local=self.is_local,
mask_mod=self.mask_mod,
aux_tensors=aux_tensors,
fastdiv_mods=fastdiv_mods,
)

dKV_accumulate = False
@@ -1117,6 +1250,14 @@ def mma(
consumer_state_dO,
mask_fn=mask_fn,
dKV_accumulate=dKV_accumulate,
thr_mma_SdP=thr_mma_SdP,
batch_idx=batch_idx,
head_idx=head_idx,
n_block=n_block,
softmax_scale=softmax_scale,
seqlen=seqlen,
aux_tensors=aux_tensors,
fastdiv_mods=fastdiv_mods,
)
dKV_accumulate = True
else:
@@ -1129,6 +1270,9 @@ def mma(
mask_seqlen=True,
mask_causal=self.is_causal,
mask_local=self.is_local,
mask_mod=self.mask_mod,
aux_tensors=aux_tensors,
fastdiv_mods=fastdiv_mods,
)
dKV_accumulate = False
for m_block in cutlass.range(m_block_min, m_block_max, unroll=1):
@@ -1138,6 +1282,14 @@ def mma(
consumer_state_dO,
mask_fn=mask_fn,
dKV_accumulate=dKV_accumulate,
thr_mma_SdP=thr_mma_SdP,
batch_idx=batch_idx,
head_idx=head_idx,
n_block=n_block,
softmax_scale=softmax_scale,
seqlen=seqlen,
aux_tensors=aux_tensors,
fastdiv_mods=fastdiv_mods,
)
dKV_accumulate = True

@@ -1209,6 +1361,14 @@ def mma_one_m_block(
softmax_scale_log2: Float32,
mask_fn: Optional[Callable] = None,
dKV_accumulate: Boolean = True,
thr_mma_SdP: Optional[cute.core.ThrMma] = None,
batch_idx: Int32 = 0,
head_idx: Int32 = 0,
n_block: Int32 = 0,
softmax_scale: Float32 = 1.0,
seqlen: Optional[SeqlenInfoQK] = None,
aux_tensors: Optional[list] = None,
fastdiv_mods=(None, None),
):
consumer_state_dO_cur = (
consumer_state_dO if const_expr(self.Q_stage == self.dO_stage) else consumer_state_Q
@@ -1226,6 +1386,24 @@ def mma_one_m_block(
)
acc_dP = mma_dov_fn(A_idx=smem_idx_Q, wg_wait=1)

if const_expr(self.score_mod_bwd is not None):
# Snapshot the unmodified scores; score_mod_bwd re-evaluates the joint
# forward/backward graph on them after the softmax backward below.
acc_S_pre = cute.make_fragment_like(acc_S)
cute.autovec_copy(acc_S, acc_S_pre)

if const_expr(self.score_mod is not None):
self.apply_score_mod(
acc_S,
thr_mma_SdP,
batch_idx,
head_idx,
m_block,
n_block,
softmax_scale,
seqlen,
aux_tensors,
fastdiv_mods,
)

# (3) [Pointwise 1] P = exp(S - LSE)
if cutlass.const_expr(mask_fn is not None):
mask_fn(acc_S, m_block=m_block)
@@ -1256,6 +1434,21 @@ def mma_one_m_block(
for c in cutlass.range(cute.size(acc_dP_mn, mode=[1]), unroll_full=True):
acc_dP_mn[r, c] = acc_S_mn[r, c] * (acc_dP_mn[r, c] - tLSErdPsum[r])

if const_expr(self.score_mod_bwd is not None):
self.apply_score_mod_bwd(
acc_dP,
acc_S_pre,
thr_mma_SdP,
batch_idx,
head_idx,
m_block,
n_block,
softmax_scale,
seqlen,
aux_tensors,
fastdiv_mods,
)
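# Ordering note: acc_dP already holds dS = P * (dP - dPsum) at this point, so
# score_mod_bwd chains the score-mod derivative through dS before the f16
# conversion below; acc_S_pre supplies the raw scores the forward mod consumed.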

# Convert dS from f32 -> f16
tdKrdS = utils.cvt_f16(utils.make_acc_tensor_frgA_view(acc_dP), self.dtype)
