Merged
Changes from all commits
11 changes: 7 additions & 4 deletions benchmarks/benchmark_attn.py
@@ -325,9 +325,9 @@ def run(*args, **kwargs):
else:
page_table = None

# for causal in [False, True]:
for causal in [True]:
print(f"\n### {headdim = }, {causal = }, {seqlen = } ###")
for causal in [False, True]:
# for causal in [True]:
print(f"\n### {headdim = }, {causal = }, {seqlen = }, {batch_size = }, {nheads = }, {nheads_kv = }, {varlen = }, {deterministic = } ###")
nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen, headdim if not has_qv else headdim + headdim_v, headdim_v, causal=causal, window_size=window_size)
if cudnn is not None:
# if False:
@@ -395,7 +395,10 @@ def run(*args, **kwargs):
# pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, deterministic=deterministic, backward=True)
# benchmark_forward(torch.clone, k, repeats=repeats, verbose=verbose, desc='Memcpy')
if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func_python is not None and has_backward:
_, m1b_py = benchmark_backward(flash_attn_func_python, q, k, v, causal=causal, softcap=softcap, repeats=repeats, verbose=False, desc='Fav2 python')
if not varlen:
_, m1b_py = benchmark_backward(flash_attn_func_python, q, k, v, causal=causal, softcap=softcap, deterministic=deterministic, repeats=repeats, verbose=False, desc='Fav4 python')
else:
_, m1b_py = benchmark_backward(flash_attn_varlen_func_python, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, causal=causal, softcap=softcap, deterministic=deterministic, repeats=repeats, verbose=False, desc='Fav4 python')

if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func is not None:
# if False:
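A minimal sketch of how the dense (q) and varlen (q_unpad, cu_seqlens_q) inputs benchmarked above relate, assuming equal-length sequences for simplicity; the tensor names mirror the script, but the snippet itself is illustrative only.

import torch

batch_size, seqlen, nheads, headdim = 2, 4, 3, 8
q = torch.randn(batch_size, seqlen, nheads, headdim)

# Packed ("varlen") layout: all sequences concatenated along a single token axis.
q_unpad = q.reshape(batch_size * seqlen, nheads, headdim)
# cu_seqlens_q holds each sequence's start offset plus the total token count at the end,
# which is what the varlen entry points above consume instead of a batch dimension.
cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen, seqlen, dtype=torch.int32)
assert cu_seqlens_q.tolist() == [0, 4, 8]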
1 change: 1 addition & 0 deletions flash_attn/cute/cute_dsl_utils.py
@@ -123,6 +123,7 @@ def cute_compile_patched(*args, **kwargs):
pathlib.Path(cubin_path).with_suffix(".annotated.sass").write_text(sass)
return output


def to_cute_tensor(t, assumed_align=16, leading_dim=-1, fully_dynamic=False, enable_tvm_ffi=True):
"""Convert torch tensor to cute tensor for TVM FFI. leading_dim=-1 defaults to t.ndim-1."""
tensor = from_dlpack(t.detach(), assumed_align=assumed_align, enable_tvm_ffi=enable_tvm_ffi)
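A hedged usage sketch for to_cute_tensor as shown above; it assumes a CUDA device and the CuTe DSL dependencies this module already imports, and uses only the keyword arguments visible in the signature.

import torch
from flash_attn.cute.cute_dsl_utils import to_cute_tensor

x = torch.randn(4, 8, 128, device="cuda", dtype=torch.bfloat16)
# leading_dim=-1 (the default) resolves to the last dimension, t.ndim - 1.
x_cute = to_cute_tensor(x, assumed_align=16)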
290 changes: 12 additions & 278 deletions flash_attn/cute/flash_bwd_postprocess.py
@@ -233,13 +233,15 @@ def __call__(
TileScheduler = SingleTileVarlenScheduler
num_head = mdQ.shape[1]
num_batch = mCuSeqlensQ.shape[0] - 1
num_block = cute.ceil_div(mdQ.shape[0], self.tile_m)
else:
TileScheduler = SingleTileScheduler
num_head = mdQ.shape[2]
num_batch = mdQ.shape[0]
num_block = cute.ceil_div(mdQ.shape[1], self.tile_m)

tile_sched_args = TileSchedulerArguments(
num_block=cute.ceil_div(mdQ.shape[1], self.tile_m),
num_block=num_block,
num_head=num_head,
num_batch=num_batch,
num_splits=1,
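A small worked example of the num_block change above (plain Python, illustrative shapes): in the varlen path dQ is packed as (total_q, nheads, headdim), so the block count comes from shape[0], while the dense path tiles the per-batch sequence length in shape[1].

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

tile_m = 128
total_q = 1000            # packed token count across the whole batch (varlen path)
seqlen = 4096             # per-batch sequence length (dense path)
num_block_varlen = ceil_div(total_q, tile_m)   # 8
num_block_dense = ceil_div(seqlen, tile_m)     # 32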
@@ -318,15 +320,15 @@ def kernel(
tile_scheduler = TileScheduler.create(tile_sched_params)
work_tile = tile_scheduler.initial_work_tile_info()

m_block, num_head, batch_size, _ = work_tile.tile_idx
m_block, head_idx, batch_idx, _ = work_tile.tile_idx

if work_tile.is_valid_tile:
# ///////////////////////////////////////////////////////////////////////////////
# Get the appropriate tiles for this thread block.
# ///////////////////////////////////////////////////////////////////////////////

seqlen = SeqlenInfoQK.create(
batch_size,
batch_idx,
mdQ.shape[1],
0,
mCuSeqlensQ=mCuSeqlensQ,
@@ -335,14 +337,16 @@
mSeqUsedK=None,
)
if const_expr(not seqlen.has_cu_seqlens_q):
mdQ_cur = mdQ[batch_size, None, num_head, None]
mdQaccum_cur = mdQaccum[batch_size, num_head, None]
mdQ_cur = mdQ[batch_idx, None, head_idx, None]
mdQaccum_cur = mdQaccum[batch_idx, head_idx, None]
head_dim = mdQ.shape[3]
else:
padded_offset_q = seqlen.offset_q + batch_size * self.tile_m
mdQ_cur = cute.domain_offset((seqlen.offset_q, 0), mdQ[None, num_head, None])
padded_offset_q = seqlen.offset_q + batch_idx * self.tile_m
if cutlass.const_expr(self.arch >= 90):
padded_offset_q = padded_offset_q // self.tile_m * self.tile_m
mdQ_cur = cute.domain_offset((seqlen.offset_q, 0), mdQ[None, head_idx, None])
mdQaccum_cur = cute.domain_offset(
(padded_offset_q * self.tile_hdim,), mdQaccum[num_head, None]
(padded_offset_q * self.tile_hdim,), mdQaccum[head_idx, None]
)
head_dim = mdQ.shape[2]

@@ -457,273 +461,3 @@ def kernel(
tdQgdQ[None, rest_m, None],
pred=tdQpdQ[None, rest_m, None],
)


class FlashAttentionBackwardPostprocess_sm100(FlashAttentionBackwardPostprocess):
def __init__(
self,
dtype: Type[cutlass.Numeric],
head_dim: int,
tile_m: int = 128,
num_threads: int = 256,
AtomLayoutMdQ: int = 1,
dQ_swapAB: bool = False,
):
super().__init__(
dtype=dtype,
head_dim=head_dim,
arch=90,  # temporary dummy placeholder for now
tile_m=tile_m,
num_threads=num_threads,
AtomLayoutMdQ=AtomLayoutMdQ,
dQ_swapAB=dQ_swapAB,
)

def _setup_attributes(self):
self.num_stages = self.tile_hdim // 32 # 2 for D=64, 4 for D=128

self.sdQaccum_layout = cute.make_layout(
shape=(self.tile_m * 32, 2), stride=(1, self.tile_m * 32)
)
self.epi_tile_q = (self.tile_m, self.tile_hdim)
self.sdQ_layout = sm100_utils_basic.make_smem_layout_epi(
self.dtype,
LayoutEnum.ROW_MAJOR,
self.epi_tile_q,
1,
)
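# Hedged arithmetic behind the attribute setup above (illustrative D=64 values; an aside,
# not part of the original file): each stage covers a 32-wide slab of the fp32 accumulator,
# and the two shared-memory buffers total roughly 48 KB before alignment padding.
tile_m_example, tile_hdim_example = 128, 64
num_stages_example = tile_hdim_example // 32           # 2 for D=64, 4 for D=128
sdqaccum_bytes = tile_m_example * 32 * 2 * 4           # 32768 bytes of fp32 staging
sdq_bytes = tile_m_example * tile_hdim_example * 2     # 16384 bytes of bf16 output tile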

@cute.jit
def __call__(
self,
mdQaccum: cute.Tensor,
mdQ: cute.Tensor,
scale: cutlass.Float32,
stream: cuda.CUstream,
):
# Assume all strides are divisible by 128 bits except the last stride
new_stride = lambda t: (
*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]),
t.stride[-1],
)
mdQaccum, mdQ = [
cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
for t in (mdQaccum, mdQ)
]
# (b, h, s*d) -> (s*d, h, b)
mdQaccum = cute.make_tensor(mdQaccum.iterator, cute.select(mdQaccum.layout, mode=[2, 1, 0]))
# (b, s, h, d) -> (s, d, h, b)
mdQ = cute.make_tensor(mdQ.iterator, cute.select(mdQ.layout, mode=[1, 3, 2, 0]))
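# Hedged PyTorch analogy for the mode reordering above (illustrative shapes; an aside, not
# part of the original file): cute.select on the layout acts like a stride-only permute.
import torch
dqaccum_example = torch.empty(2, 3, 128 * 64)            # (b, h, s*d)
dq_example = torch.empty(2, 128, 3, 64)                  # (b, s, h, d)
dqaccum_reordered = dqaccum_example.permute(2, 1, 0)     # (s*d, h, b), no data movement
dq_reordered = dq_example.permute(1, 3, 2, 0)            # (s, d, h, b), no data movement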

self._setup_attributes()

grid_dim = [
cute.ceil_div(mdQ.shape[0], self.tile_m),
cute.size(mdQ.shape[2]),
cute.size(mdQ.shape[3]),
]

cta_group = tcgen05.CtaGroup.ONE
self.mma_tiler_dsk = (self.tile_m, self.tile_hdim)

dS_major_mode = tcgen05.OperandMajorMode.MN
kt_major_mode_dsq = tcgen05.OperandMajorMode.MN

tiled_mma_dsk = sm100_utils_basic.make_trivial_tiled_mma(
cutlass.BFloat16,
dS_major_mode,
kt_major_mode_dsq,
cutlass.Float32,
cta_group,
self.mma_tiler_dsk,
)

dQ_cta_v_layout = cute.composition(cute.make_identity_layout(mdQ.shape), self.mma_tiler_dsk)
tma_store_op = cpasync.CopyBulkTensorTileS2GOp()
tma_atom_dQ, tma_tensor_dQ = cute.nvgpu.cpasync.make_tiled_tma_atom(
tma_store_op,
mdQ,
cute.select(self.sdQ_layout, mode=[0, 1]),
dQ_cta_v_layout,
)

buffer_align_bytes = 1024

@cute.struct
class SharedStorage:
sdQaccum: cute.struct.Align[
cute.struct.MemRange[cutlass.Float32, cute.cosize(self.sdQaccum_layout)],
128,
]

sdQ: cute.struct.Align[
cute.struct.MemRange[self.dtype, cute.cosize(self.sdQ_layout)],
buffer_align_bytes,
]

self.shared_storage = SharedStorage

self.kernel(
mdQaccum,
tma_tensor_dQ,
tma_atom_dQ,
self.sdQaccum_layout,
self.sdQ_layout,
tiled_mma_dsk,
scale,
).launch(
grid=grid_dim,
block=[self.num_threads, 1, 1],
smem=self.shared_storage.size_in_bytes(),
stream=stream,
)

@cute.kernel
def kernel(
self,
mdQaccum: cute.Tensor,
mdQ: cute.Tensor,
tma_atom_dQ: cute.CopyAtom,
sdQaccum_layout: cute.Layout,
sdQ_layout: cute.ComposedLayout,
tiled_mma_dsk: cute.TiledMma,
scale: cutlass.Float32,
):
tidx = cute.arch.thread_idx()[0]
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
m_block, head_idx, batch_idx = cute.arch.block_idx()

# SMEM
smem = cutlass.utils.SmemAllocator()
storage = smem.allocate(self.shared_storage)
swz128 = cute.make_swizzle(3, 4, 3)
sdQaccum = storage.sdQaccum.get_tensor(sdQaccum_layout, swizzle=swz128)

sdQ = storage.sdQ.get_tensor(sdQ_layout.outer, swizzle=sdQ_layout.inner)

mdQaccum_cur = mdQaccum[None, head_idx, batch_idx]
mdQ_cur = mdQ[None, None, head_idx, batch_idx]

thr_mma_dsk = tiled_mma_dsk.get_slice(tidx)
dQacc_shape = thr_mma_dsk.partition_shape_C(self.mma_tiler_dsk[:2])
tdQtdQ = thr_mma_dsk.make_fragment_C(dQacc_shape)
tdQtdQ = cute.make_tensor(tdQtdQ.iterator, tdQtdQ.layout)

tmem_ld_atom = cute.make_copy_atom(
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), cutlass.Float32
)
tiled_tmem_ld = tcgen05.make_tmem_copy(tmem_ld_atom, tdQtdQ)
thr_tmem_ld = tiled_tmem_ld.get_slice(tidx)

cdQ = cute.make_identity_tensor((self.mma_tiler_dsk[0], self.mma_tiler_dsk[1]))
tdQcdQ = thr_mma_dsk.partition_C(cdQ)
tdQcdQ_tensor = cute.make_tensor(tdQcdQ.iterator, tdQcdQ.layout)
tdQrdQ = thr_tmem_ld.partition_D(tdQcdQ_tensor)

gdQaccum = cute.local_tile(mdQaccum_cur, (self.tile_m * self.tile_hdim,), (m_block,))

num_reduce_warps = 4
num_reduce_threads = cute.arch.WARP_SIZE * num_reduce_warps

atom_universal_copy = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(), cutlass.Float32, num_bits_per_copy=128
)
tiler_mn, layout_tv = cute.make_layout_tv(
thr_layout=cute.make_layout(shape=num_reduce_threads, stride=1),
val_layout=cute.make_layout(shape=4, stride=1),
)
G2S_tiled_copy_dQaccum = cute.make_tiled_copy(
atom_universal_copy, layout_tv=layout_tv, tiler_mn=tiler_mn
)

smem_thr_copy_g2s = G2S_tiled_copy_dQaccum.get_slice(tidx)

# S->R
tdQrdQ_t2r = cute.make_fragment(tdQrdQ.shape, cutlass.Float32)
tiled_smem_store_s2r = cute.make_tiled_copy(
atom_universal_copy, layout_tv=layout_tv, tiler_mn=tiler_mn
)

s2r_thr_copy_dQaccum = tiled_smem_store_s2r.get_slice(tidx)
tdQsdQ_s2r = s2r_thr_copy_dQaccum.partition_S(sdQaccum)
tdQrdQ_s2r = cute.make_tensor(tdQrdQ_t2r.iterator, tdQrdQ_t2r.shape)

# R->S
smem_copy_atom = sm100_utils_basic.get_smem_store_op(
LayoutEnum.ROW_MAJOR, self.dtype, cutlass.Float32, tiled_tmem_ld
)
tiled_smem_store_r2s = cute.make_tiled_copy(
smem_copy_atom,
layout_tv=tiled_tmem_ld.layout_dst_tv_tiled,
tiler_mn=tiled_tmem_ld.tiler_mn,
)
tdQsdQ_r2s = thr_tmem_ld.partition_D(thr_mma_dsk.partition_C(sdQ))
tdQrdQ_r2s = cute.make_fragment(tdQsdQ_r2s.shape, self.dtype)

num_stages = cute.size(tdQrdQ_t2r, mode=[1])
for stage in cutlass.range_constexpr(num_stages):
# G->S
gdQaccum_stage = cute.local_tile(
gdQaccum,
(self.tile_m * 32,),
(stage,),
)

gdQaccum_layout_g2s = cute.make_layout(shape=(self.tile_m * 32, 1), stride=(1, 0))
gdQaccum_stage_g2s = cute.make_tensor(
cute.recast_ptr(gdQaccum_stage.iterator, swizzle_=swz128), gdQaccum_layout_g2s
)

tdQgdQ = smem_thr_copy_g2s.partition_S(gdQaccum_stage_g2s)
tdQsdQ = smem_thr_copy_g2s.partition_D(sdQaccum)

cute.copy(smem_thr_copy_g2s, tdQgdQ[None, None, 0], tdQsdQ[None, None, 0])

cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
)
cute.arch.barrier(barrier_id=6, number_of_threads=num_reduce_threads)

# S -> R
tdQrdQ_s2r_cpy = tdQrdQ_s2r[None, stage, None, None]
tdQsdQ_s2r_p = tdQsdQ_s2r[None, None, 0]
tdQrdQ_r2s_cpy = cute.make_tensor(
tdQrdQ_s2r_cpy.iterator, cute.make_layout(tdQsdQ_s2r_p.shape)
)

cute.copy(s2r_thr_copy_dQaccum, tdQsdQ_s2r_p, tdQrdQ_r2s_cpy)

cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
)
cute.arch.barrier(barrier_id=7, number_of_threads=num_reduce_threads)

# R->S
tdQrdQ_r2s_cpy = cute.make_tensor(
cute.recast_ptr(tdQrdQ_r2s_cpy.iterator),
tdQrdQ_r2s[((None, 0), stage, 0, 0, 0)].shape,
)
dQ_vec = tdQrdQ_r2s_cpy.load() * scale
tdQrdQ_r2s[((None, 0), stage, 0, 0, 0)].store(dQ_vec.to(self.dtype))

cute.copy(
tiled_smem_store_r2s,
tdQrdQ_r2s[None, None, None, None, 0],
tdQsdQ_r2s[None, None, None, None, 0],
)
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
)
cute.arch.barrier(barrier_id=8, number_of_threads=num_reduce_threads)

# S-> G
gdQ = cute.local_tile(mdQ_cur, (self.tile_m, self.tile_hdim), (None, 0))
tdQsdQ, tdQgdQ = cpasync.tma_partition(
tma_atom_dQ,
0,
cute.make_layout(1),
cute.group_modes(sdQ, 0, 2),
cute.group_modes(gdQ, 0, 2),
)

cute.copy(tma_atom_dQ, tdQsdQ[None, 0], tdQgdQ[None, m_block])