42 changes: 30 additions & 12 deletions flash_attn/cute/block_sparsity.py
@@ -2,7 +2,7 @@
Block-sparsity utilities for FlexAttention
"""

from typing import NamedTuple, Optional, Tuple
from typing import Callable, NamedTuple, Tuple

import cutlass.cute as cute
import torch
@@ -17,8 +17,8 @@ def ceildiv(a: int, b: int) -> int:
class BlockSparseTensors(NamedTuple):
mask_block_cnt: cute.Tensor
mask_block_idx: cute.Tensor
full_block_cnt: Optional[cute.Tensor]
full_block_idx: Optional[cute.Tensor]
full_block_cnt: cute.Tensor | None
full_block_idx: cute.Tensor | None

def __new_from_mlir_values__(self, values):
if len(values) == 2:
@@ -29,34 +29,42 @@ def __new_from_mlir_values__(self, values):
class BlockSparseTensorsTorch(NamedTuple):
mask_block_cnt: torch.Tensor
mask_block_idx: torch.Tensor
full_block_cnt: Optional[torch.Tensor] = None
full_block_idx: Optional[torch.Tensor] = None
full_block_cnt: torch.Tensor | None = None
full_block_idx: torch.Tensor | None = None


def _expand_sparsity_tensor(
tensor: torch.Tensor,
expected_shape: Tuple[int, ...],
tensor_name: str,
context: str | None,
hint: str | Callable[[], str] | None,
) -> torch.Tensor:
"""Check if we need to expand the tensor to expected shape, and do so if possible."""
needs_expand = tensor.shape != expected_shape
if not needs_expand:
return tensor
can_expand = all(map(lambda cur, tgt: cur == tgt or cur == 1, tensor.shape, expected_shape))
if not can_expand:
context_clause = f" ({context})" if context else ""
resolved_hint = hint() if callable(hint) else hint
hint_clause = f" Hint: {resolved_hint}" if resolved_hint else ""
raise ValueError(
f"{tensor_name} with shape {tensor.shape} cannot be expanded to expected shape {expected_shape}."
f"{tensor_name}{context_clause} with shape {tensor.shape} cannot be expanded to expected shape {expected_shape}."
f"{hint_clause}"
)
return tensor.expand(*expected_shape).contiguous()


def _check_and_expand_block(
name: str,
cnt: Optional[torch.Tensor],
idx: Optional[torch.Tensor],
cnt: torch.Tensor | None,
idx: torch.Tensor | None,
expected_count_shape: Tuple[int, int, int],
expected_index_shape: Tuple[int, int, int, int],
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
context: str | None,
hint: str | Callable[[], str] | None,
) -> Tuple[torch.Tensor | None, torch.Tensor | None]:
if (cnt is None) != (idx is None):
raise ValueError(
f"{name}_block_cnt and {name}_block_idx must both be provided or both be None"
@@ -69,8 +77,12 @@ def _check_and_expand_block(
raise ValueError(f"{name}_block_cnt and {name}_block_idx must be on the same device")
if not cnt.is_cuda or not idx.is_cuda:
raise ValueError(f"{name}_block tensors must live on CUDA")
expanded_cnt = _expand_sparsity_tensor(cnt, expected_count_shape, f"{name}_block_cnt")
expanded_idx = _expand_sparsity_tensor(idx, expected_index_shape, f"{name}_block_idx")
expanded_cnt = _expand_sparsity_tensor(
cnt, expected_count_shape, f"{name}_block_cnt", context, hint
)
expanded_idx = _expand_sparsity_tensor(
idx, expected_index_shape, f"{name}_block_idx", context, hint
)
return expanded_cnt, expanded_idx


@@ -120,6 +132,8 @@ def normalize_block_sparse_tensors(
*,
expected_count_shape: Tuple[int, int, int],
expected_index_shape: Tuple[int, int, int, int],
context: str | None = None,
hint: str | Callable[[], str] | None = None,
) -> BlockSparseTensorsTorch:
if tensors.mask_block_cnt is None or tensors.mask_block_idx is None:
raise ValueError("mask_block_cnt and mask_block_idx must be provided for block sparsity.")
@@ -130,6 +144,8 @@
tensors.mask_block_idx,
expected_count_shape,
expected_index_shape,
context,
hint,
)
if mask_cnt is None or mask_idx is None:
raise ValueError("mask_block_cnt and mask_block_idx must be provided for block sparsity.")
@@ -140,6 +156,8 @@
tensors.full_block_idx,
expected_count_shape,
expected_index_shape,
context,
hint,
)
if full_cnt is not None and mask_cnt.device != full_cnt.device:
raise ValueError("All block sparse tensors must be on the same device")
@@ -158,7 +176,7 @@ def is_block_sparsity_enabled(tensors: BlockSparseTensorsTorch) -> bool:

def to_cute_block_sparse_tensors(
tensors: BlockSparseTensorsTorch, enable_tvm_ffi: bool = True
) -> Optional[BlockSparseTensors]:
) -> BlockSparseTensors | None:
"""Convert torch block sparsity tensors to CuTe tensors, optionally for tvm ffi"""
if not is_block_sparsity_enabled(tensors):
return None
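
The new `context`/`hint` plumbing is easiest to see from the caller's side. Below is a minimal usage sketch; the shapes, dtypes, and hint text are illustrative assumptions, and it presumes a CUDA device. Passing a zero-arg callable as `hint` means the message is only built if expansion actually fails.

```python
import torch

from flash_attn.cute.block_sparsity import (
    BlockSparseTensorsTorch,
    normalize_block_sparse_tensors,
)

batch, heads, m_blocks, n_blocks = 2, 8, 16, 16

# Batch/head dims of 1 are broadcast ("expanded") to the expected shapes.
tensors = BlockSparseTensorsTorch(
    mask_block_cnt=torch.randint(
        0, n_blocks + 1, (1, 1, m_blocks), device="cuda", dtype=torch.int32
    ),
    mask_block_idx=torch.zeros(
        (1, 1, m_blocks, n_blocks), device="cuda", dtype=torch.int32
    ),
)

normalized = normalize_block_sparse_tensors(
    tensors,
    expected_count_shape=(batch, heads, m_blocks),
    expected_index_shape=(batch, heads, m_blocks, n_blocks),
    context="my_caller",  # on failure: "mask_block_cnt (my_caller) with shape ..."
    hint=lambda: "Regenerate the BlockMask with the block size this kernel expects.",
)
```

Supplying the hint as a lambda, as `_flash_attn_bwd` now does, keeps the happy path free of string formatting.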
131 changes: 130 additions & 1 deletion flash_attn/cute/flash_bwd_sm90.py
@@ -23,6 +23,7 @@
from flash_attn.cute import pipeline
from flash_attn.cute.tile_scheduler import TileSchedulerArguments, SingleTileScheduler, ParamsBase
from flash_attn.cute.named_barrier import NamedBarrierFwd, NamedBarrierBwd
from flash_attn.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner
from flash_attn.cute.block_sparsity import BlockSparseTensors
from flash_attn.cute.block_sparse_utils import (
get_total_q_block_count_bwd,
@@ -70,6 +71,8 @@ def __init__(
AtomLayoutMdQ: int = 1,
num_threads: int = 384,
V_in_regs: bool = False,
score_mod: cutlass.Constexpr | None = None,
score_mod_bwd: cutlass.Constexpr | None = None,
mask_mod: cutlass.Constexpr | None = None,
has_aux_tensors: cutlass.Constexpr = False,
subtile_factor: cutlass.Constexpr[int] = 1,
@@ -118,13 +121,16 @@ def __init__(
self.shuffle_LSE = self.SdP_swapAB and self.tile_hdim <= 64
self.shuffle_dPsum = self.SdP_swapAB and self.tile_hdim <= 64

self.score_mod = score_mod
self.score_mod_bwd = score_mod_bwd
self.mask_mod = mask_mod
self.has_aux_tensors = has_aux_tensors
self.subtile_factor = subtile_factor
if cutlass.const_expr(has_aux_tensors):
self.vec_size: cutlass.Constexpr = 1
else:
self.vec_size: cutlass.Constexpr = 4
self.qk_acc_dtype = Float32

@staticmethod
def can_implement(
@@ -443,7 +449,10 @@ def __call__(
grid_dim = TileScheduler.get_grid_shape(tile_sched_params)

LOG2_E = math.log2(math.e)
softmax_scale_log2 = softmax_scale * LOG2_E
if const_expr(self.score_mod is None):
softmax_scale_log2 = softmax_scale * LOG2_E
else:
softmax_scale_log2 = LOG2_E
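
Why only `LOG2_E` is folded when a `score_mod` is present: the softmax scale must be applied to the raw score before the mod runs (it is passed into `apply_score_mod_inner` for that), so it can no longer be pre-multiplied into the base-2 exponent. A minimal sketch of the relationship as I read the change, not verbatim kernel code:

```python
import math

LOG2_E = math.log2(math.e)

def log2_scale(softmax_scale: float, has_score_mod: bool) -> float:
    """Factor folded into the exp2 conversion: exp(x) == exp2(x * LOG2_E).

    Without a score_mod: exp(scale * s - lse) == exp2((scale * LOG2_E) * s - LOG2_E * lse),
    so the scale can be baked in here.
    With a score_mod: s is scaled (and modified) first, so only LOG2_E remains to fold.
    """
    return LOG2_E if has_score_mod else softmax_scale * LOG2_E
```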

fastdiv_mods = None
if const_expr(aux_tensors is not None):
@@ -856,6 +865,93 @@ def load(
tile_scheduler.advance_to_next_work()
work_tile = tile_scheduler.get_current_work()

@cute.jit
def apply_score_mod(
self,
acc_S: cute.Tensor,
thr_mma_SdP: cute.core.ThrMma,
batch_idx,
head_idx,
m_block,
n_block,
softmax_scale,
seqlen_info: SeqlenInfoQK,
aux_tensors=None,
fastdiv_mods=(None, None),
):
# [NOTE] SdP_swapAB: swapAB transposes the tile, so use (n, m) indexing
cS = cute.make_identity_tensor(
(self.tile_n, self.tile_m) if self.SdP_swapAB else (self.tile_m, self.tile_n)
)
cS = cute.domain_offset(
(n_block * self.tile_n, m_block * self.tile_m)
if self.SdP_swapAB
else (m_block * self.tile_m, n_block * self.tile_n),
cS,
)
tScS = thr_mma_SdP.partition_C(cS)

apply_score_mod_inner(
acc_S,
tScS,
self.score_mod,
batch_idx,
head_idx,
softmax_scale,
self.vec_size,
self.qk_acc_dtype,
aux_tensors,
fastdiv_mods,
seqlen_info,
constant_q_idx=None,
qhead_per_kvhead=self.qhead_per_kvhead,
transpose_indices=self.SdP_swapAB,
)

@cute.jit
def apply_score_mod_bwd(
self,
grad_tensor: cute.Tensor,
score_tensor: cute.Tensor,
thr_mma_SdP: cute.core.ThrMma,
batch_idx,
head_idx,
m_block,
n_block,
softmax_scale,
seqlen_info: SeqlenInfoQK,
aux_tensors=None,
fastdiv_mods=(None, None),
):
cS = cute.make_identity_tensor(
(self.tile_n, self.tile_m) if self.SdP_swapAB else (self.tile_m, self.tile_n)
)
cS = cute.domain_offset(
(n_block * self.tile_n, m_block * self.tile_m)
if self.SdP_swapAB
else (m_block * self.tile_m, n_block * self.tile_n),
cS,
)
tScS = thr_mma_SdP.partition_C(cS)

apply_score_mod_bwd_inner(
grad_tensor,
score_tensor,
tScS,
self.score_mod_bwd,
batch_idx,
head_idx,
softmax_scale,
self.vec_size,
self.qk_acc_dtype,
aux_tensors,
fastdiv_mods,
seqlen_info,
constant_q_idx=None,
qhead_per_kvhead=self.qhead_per_kvhead,
transpose_indices=self.SdP_swapAB,
)

@cute.jit
def mma(
self,
@@ -1196,6 +1292,24 @@ def mma_one_m_block(
)
acc_dP = mma_dov_fn(A_idx=smem_idx_Q, wg_wait=1)

if const_expr(self.score_mod_bwd is not None):
acc_S_pre = cute.make_fragment_like(acc_S)
cute.autovec_copy(acc_S, acc_S_pre)

if const_expr(self.score_mod is not None):
self.apply_score_mod(
acc_S,
thr_mma_SdP,
batch_idx,
head_idx,
m_block,
n_block,
softmax_scale,
seqlen,
aux_tensors,
fastdiv_mods,
)

# (3) [Pointwise 1] P = exp(S - LSE)
if cutlass.const_expr(mask_fn is not None):
mask_fn(acc_S, m_block=m_block)
@@ -1226,6 +1340,21 @@ def mma_one_m_block(
for c in cutlass.range(cute.size(acc_dP_mn, mode=[1]), unroll_full=True):
acc_dP_mn[r, c] = acc_S_mn[r, c] * (acc_dP_mn[r, c] - tLSErdPsum[r])

if const_expr(self.score_mod_bwd is not None):
self.apply_score_mod_bwd(
acc_dP,
acc_S_pre,
thr_mma_SdP,
batch_idx,
head_idx,
m_block,
n_block,
softmax_scale,
seqlen,
aux_tensors,
fastdiv_mods,
)

# Convert dS from f32 -> f16
tdKrdS = utils.cvt_f16(utils.make_acc_tensor_frgA_view(acc_dP), self.dtype)
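
For reference, a hypothetical `score_mod` / `score_mod_bwd` pair of the kind this kernel consumes through `apply_score_mod_inner` / `apply_score_mod_bwd_inner`. The element-wise signatures below are an assumption (FlexAttention-style `(score, b, h, q_idx, kv_idx)`, with the backward receiving the upstream gradient and the pre-modification score); the example is a tanh soft-cap and its analytic derivative.

```python
import math

SOFTCAP = 50.0  # illustrative cap value

def softcap_score_mod(score, b, h, q_idx, kv_idx):
    # Forward: bound the (already scaled) attention score to (-SOFTCAP, SOFTCAP).
    return SOFTCAP * math.tanh(score / SOFTCAP)

def softcap_score_mod_bwd(grad, score, b, h, q_idx, kv_idx):
    # Backward: d/dscore [SOFTCAP * tanh(score / SOFTCAP)] = 1 - tanh(score / SOFTCAP)**2,
    # applied to the gradient flowing back into the raw score.
    t = math.tanh(score / SOFTCAP)
    return grad * (1.0 - t * t)
```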

16 changes: 14 additions & 2 deletions flash_attn/cute/interface.py
@@ -713,7 +713,6 @@ def _flash_attn_bwd(
assert cu_seqlens_q is None and cu_seqlens_k is None, (
"varlen + score_mod not supported in bwd yet"
)
assert compute_capability in [10, 11], "score_mod in bwd only supported on SM100/SM110 for now"
Collaborator comment on the removed assert: keep but make it [9,10,11]?
device = q.device
out_torch_dtype = q.dtype
@@ -910,7 +909,6 @@ def _flash_attn_bwd(
num_aux_tensors,
use_block_sparsity,
)
cute_aux_tensors = None
else:
compile_key = (
compute_capability,
@@ -999,6 +997,8 @@
AtomLayoutMdQ,
num_threads,
V_in_regs=V_in_regs,
score_mod=score_mod,
score_mod_bwd=score_mod_bwd,
mask_mod=mask_mod,
has_aux_tensors=aux_tensors is not None,
subtile_factor=subtile_factor,
@@ -1034,6 +1034,12 @@
block_sparse_tensors,
expected_count_shape=expected_count_shape,
expected_index_shape=expected_index_shape,
context="_flash_attn_bwd",
hint=lambda: (
f"Backward expects Q-direction block-sparse tensors (q_mask_cnt/q_mask_idx, and optionally full_q_cnt/full_q_idx). "
f"Regenerate the backward BlockMask with BLOCK_SIZE=({sparse_block_size_q}, {n_block_size}) "
f"(sparse_block_size_q={sparse_block_size_q})."
),
)
sparse_tensors_compile = to_cute_block_sparse_tensors(compile_time_normalized)

@@ -1076,6 +1082,12 @@
block_sparse_tensors,
expected_count_shape=expected_count_shape,
expected_index_shape=expected_index_shape,
context="_flash_attn_bwd",
hint=lambda: (
f"Backward expects Q-direction block-sparse tensors (q_mask_cnt/q_mask_idx, and optionally full_q_cnt/full_q_idx). "
f"Regenerate the backward BlockMask with BLOCK_SIZE=({sparse_block_size_q}, {n_block_size}) "
f"(sparse_block_size_q={sparse_block_size_q})."
),
)

_flash_attn_bwd.compile_cache[compile_key](
9 changes: 5 additions & 4 deletions flash_attn/cute/mask.py
@@ -139,7 +139,6 @@ def apply_mask(
): # FlexAttention mask mod
nrow = const_expr(cute.size(tScS_mn.shape[0]))
ncol = const_expr(cute.size(tScS_mn.shape[1]))
thr_col_offset = tScS_mn[0, 0][1]
has_fastdiv = const_expr(
fastdiv_mods is not None
and fastdiv_mods[0] is not None
@@ -150,7 +149,9 @@
)

for r in cutlass.range_constexpr(nrow):
global_row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m
# Respect swap_AB: ROW/COL determine which coordinate component corresponds to Q/KV.
local_row = tScS_mn[r, 0][ROW]
global_row_idx = local_row + m_block * self.tile_m
row_for_mod = global_row_idx
head_idx_for_mod = head_idx
if const_expr(self.qhead_per_kvhead_packgqa != 1):
@@ -162,7 +163,7 @@
_, row_for_mod = divmod(row_for_mod, fastdiv_mods[0])

for col in cutlass.range_constexpr(ncol):
col_idx_local = t0ScS_mn[0, col][1]
col_idx_local = t0ScS_mn[0, col][COL]
# Convert to absolute column index
global_col_idx = thr_col_offset + col_idx_local + n_block * self.tile_n
col_for_mod = global_col_idx
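
The `[ROW]` / `[COL]` indexing replaces the hard-coded `[0]` / `[1]` so the mask-mod path picks the Q and KV coordinate components correctly under `swap_AB`. A minimal sketch of the idea; the actual `ROW` / `COL` constants live elsewhere in the kernel, and this only reflects my reading of their intent:

```python
def row_col_components(swap_ab: bool) -> tuple[int, int]:
    # Without swap_AB the identity tile is laid out (q, kv); with swap_AB it is
    # (kv, q), so the coordinate component holding the Q ("row") index flips.
    return (1, 0) if swap_ab else (0, 1)

ROW, COL = row_col_components(swap_ab=False)  # tScS_mn[r, 0][0] holds the Q index
ROW, COL = row_col_components(swap_ab=True)   # tScS_mn[r, 0][1] holds the Q index
```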
@@ -354,7 +355,7 @@ def apply_mask_sm100(
mask_r2p(acc_S, seqlenk_col_limit, arch=100, rank1=True)

elif const_expr(not mask_causal and not mask_local and mask_mod is not None):
# Block sparse w/ mask_mod
# Block sparse case w/ mask_mod
has_fastdiv = const_expr(
fastdiv_mods is not None
and fastdiv_mods[0] is not None