Make flash attn compatible with flash_attn v2 api. WIP. #473

Open · wants to merge 23 commits into base: master
Changes from 1 commit
Causal results are stable now. Consistent with aten._flash_attention_forward.
tongxin committed Mar 20, 2025
commit a33ff44275784588d446b3a81a7c715188c0ed2a
82 changes: 39 additions & 43 deletions src/flag_gems/ops/attention.py
@@ -546,18 +546,16 @@ def apply_mask(
need_mask: tl.constexpr = is_causal | has_alibi | is_local | (not is_even_mn)
if need_mask:
col_lb = max(0, row_idx + max_seqlen_k - max_seqlen_q - ws_left)
col_rb = min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + ws_right)
col_rb = min(max_seqlen_k, row_idx + max_seqlen_k - max_seqlen_q + ws_right)

if not has_alibi:
alibi_slope = .0

S -= alibi_slope * tl.abs(col_idx[None, :] - row_idx[:, None])
if has_alibi:
S -= alibi_slope * tl.abs(col_idx[None, :] - row_idx[:, None])

if is_causal:
S = tl.where(col_idx[None, :] >= col_rb[:, None], float('-inf'), S)
S = tl.where(col_idx[None, :] > col_rb[:, None], float('-inf'), S)

if is_local:
S = tl.where(col_idx[None, :] >= col_rb[:, None] | col_idx[None, :] < col_lb[:, None], float('-inf'), S)
S = tl.where((col_idx[None, :] > col_rb[:, None]) | (col_idx[None, :] < col_lb[:, None]), float('-inf'), S)

if (not is_local) & (not is_causal) & (not is_even_mn):
S = tl.where(col_idx[None, :] >= max_seqlen_k, float('-inf'), S)
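For reference, the masking rule this hunk adjusts can be written densely in NumPy. This is an illustrative sketch, not the kernel's code: it assumes the bottom-right-aligned causal convention that aten._flash_attention_forward uses, and the function name and arguments are chosen here for illustration only.

```python
import numpy as np

def reference_keep_mask(seqlen_q, seqlen_k, ws_left, ws_right, is_causal, is_local):
    # Boolean mask of score positions that are kept; masked positions get -inf.
    row = np.arange(seqlen_q)[:, None]
    col = np.arange(seqlen_k)[None, :]
    # Bottom-right aligned diagonal: query i may see keys up to i + seqlen_k - seqlen_q.
    shift = row + seqlen_k - seqlen_q
    keep = np.ones((seqlen_q, seqlen_k), dtype=bool)
    if is_causal:
        keep &= col <= shift          # strictly-later columns are masked (col > col_rb)
    if is_local:
        keep &= (col >= shift - ws_left) & (col <= shift + ws_right)
    return keep
```

Under this convention, a square causal case (seqlen_q == seqlen_k) degenerates to the usual lower-triangular mask.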
@@ -578,34 +576,36 @@ def softmax_rescale(
prev_max = row_max
row_max = tl.maximum(row_max, tl.max(S, 1))

# if not is_init:
# if is_border:
# cur_max = tl.where(row_max == float('-inf'), 0, row_max)
# else:
# cur_max = row_max
# p_scale = tl.math.exp2((prev_max - cur_max) * softmax_scale_log2e)
# row_sum *= p_scale
# O_acc *= p_scale[:, None]

if is_border:
cur_max = tl.where(row_max == float('-inf'), 0, row_max)
else:
cur_max = row_max

p_scale = tl.math.exp2((prev_max - cur_max) * softmax_scale_log2e)
row_sum *= p_scale
O_acc *= p_scale[:, None]

max_scaled = tl.where(row_max == float('-inf'), 0, row_max * softmax_scale_log2e)

P = tl.math.exp2(S * softmax_scale_log2e - max_scaled[:, None])
row_sum = row_sum + tl.sum(P, 1)
return O_acc, P, row_max, row_sum
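As a sanity reference for the block above, here is the online-softmax recurrence in plain NumPy. It is a simplified sketch rather than the kernel: it uses base-e exponentials instead of exp2 with softmax_scale_log2e, and it omits the is_border / -inf guard, so it assumes every row has already seen at least one unmasked score.

```python
import numpy as np

def softmax_rescale_reference(O_acc, S, row_max, row_sum, softmax_scale):
    # One online-softmax step over a new score block S of shape (BLOCK_M, BLOCK_N).
    prev_max = row_max
    row_max = np.maximum(row_max, S.max(axis=1))
    # Rescale the running output and row sums to the new running maximum.
    p_scale = np.exp((prev_max - row_max) * softmax_scale)
    row_sum = row_sum * p_scale
    O_acc = O_acc * p_scale[:, None]
    # Unnormalized probabilities of the current block, then accumulate their sum.
    P = np.exp((S - row_max[:, None]) * softmax_scale)
    row_sum = row_sum + P.sum(axis=1)
    return O_acc, P, row_max, row_sum
```

After the last block, dividing the accumulated output by row_sum yields the softmax-normalized result.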


def block_m_heuristic(headdim, is_dropout):
return 128 if headdim <= 128 else 64
block_m = 128 if headdim <= 128 else 64
print('block_m:', block_m)
return block_m

def block_n_heuristic(headdim, is_dropout):
return 64 if headdim <= 64 else 32
block_n = 64 if headdim <= 64 else 32
print('block_n:', block_n)
return block_n

def is_even_mn(args):
even_mn = (args["seqlen_q"] % args["BLOCK_M"] == 0) and (args["seqlen_k"] % args["BLOCK_N"] == 0)
print('is_even_mn:', even_mn)
return even_mn

def block_m_splitkv_heuristic(headdim):
return 128 if headdim <= 128 else 64
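For quick intuition, the tile shapes these heuristics pick for common head sizes (the is_dropout argument is unused in the bodies shown) work out as below. This is only the arithmetic above evaluated, not extra logic from the PR.

```python
# Tile shapes implied by block_m_heuristic / block_n_heuristic:
for headdim in (64, 96, 128, 256):
    block_m = 128 if headdim <= 128 else 64
    block_n = 64 if headdim <= 64 else 32
    print(f"head_dim={headdim}: BLOCK_M={block_m}, BLOCK_N={block_n}")
# head_dim=64:  BLOCK_M=128, BLOCK_N=64
# head_dim=96:  BLOCK_M=128, BLOCK_N=32
# head_dim=128: BLOCK_M=128, BLOCK_N=32
# head_dim=256: BLOCK_M=64,  BLOCK_N=32
```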
@@ -633,8 +633,8 @@ def block_n_splitkv_heuristic(headdim):
'BLOCK_N': lambda args: block_n_heuristic(args["HEAD_DIM"], args["is_dropout"]),
'num_warps': lambda args: 4,
'num_stages': lambda args: 3 if args["HEAD_DIM"] <= 128 else 2,
'PRE_LOAD_V': lambda args: True,
'IS_EVEN_MN': lambda args: (args["seqlen_q"] % args["BLOCK_M"] == 0) and (args["seqlen_k"] % args["BLOCK_N"] == 0),
'PRE_LOAD_V': lambda args: False,
'IS_EVEN_MN': lambda args: is_even_mn(args),
}
)
@triton.jit(do_not_specialize=["seqlen_q", "seqlen_k", "philox_seed", "philox_offset"])
@@ -762,7 +762,8 @@ def flash_fwd_kernel(
if PRE_LOAD_V:
V = tl.load(p_bv0 + off, cache_modifier=".cg")
else:
kvmask = col < seqlen_k
col_idx = col_start + tl.arange(0, BLOCK_N)
kvmask = col_idx < seqlen_k
K = tl.load(p_bk0 + off, mask=kvmask[None, :], cache_modifier=".cg")
if PRE_LOAD_V:
V = tl.load(p_bv0 + off, mask=kvmask[:, None], cache_modifier=".cg")
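The point of this change is that the trailing K/V block needs a per-column bound check rather than a single scalar test. A minimal NumPy illustration, with names chosen for this sketch rather than taken from the kernel:

```python
import numpy as np

def trailing_block_mask(col_start, BLOCK_N, seqlen_k):
    # Per-column validity for the last (possibly ragged) key/value tile.
    col_idx = col_start + np.arange(BLOCK_N)
    return col_idx < seqlen_k

# Example: seqlen_k = 100 with BLOCK_N = 64 leaves a ragged tile at col_start = 64,
# where only the first 36 columns are real keys.
print(trailing_block_mask(64, 64, 100).sum())  # 36
```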
@@ -845,7 +846,7 @@ def flash_fwd_kernel(
off = col_start * k_s_stride
K = tl.load(p_bk0 + off, cache_modifier=".cg")
if PRE_LOAD_V:
V = tl.load(V_ptr + V_offset, cache_modifier=".cg")
V = tl.load(p_bv0 + off, cache_modifier=".cg")
S = tl.dot(Q, K)

col_idx = col_start + tl.arange(0, BLOCK_N)
@@ -919,12 +920,13 @@ def flash_fwd_kernel(

O_ = tl.dot(P, V, O_)


# LSE
# Note: rowsum = exp(-rowmax * softmax_scale) * exp(lse), so rowmax * softmax_scale + log(rowsum) recovers lse.
lse = tl.where((rowsum_ == 0) | (rowsum_ != rowsum_), float('inf'), rowmax_ * softmax_scale + tl.log(rowsum_))
inv_sum = tl.where((rowsum_ == 0) | (rowsum_ != rowsum_), 1.0, 1.0 / rowsum_)


# Rescale output
inv_sum = tl.where((rowsum_ == 0) | (rowsum_ != rowsum_), 1.0, 1.0 / rowsum_)
if is_dropout:
O_ *= inv_sum[:, None] * rpdrop
else:
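To see why this produces the log-sum-exp, here is a dense PyTorch check. It is a reference sketch under simplifying assumptions: every row has at least one finite score, dropout is off, and the tensor names are illustrative rather than the kernel's.

```python
import torch

def lse_and_rescale_reference(S, O_unnormalized, softmax_scale):
    # S: (M, N) masked scores (-inf where disallowed); O_unnormalized: (M, D),
    # the accumulated P @ V with P = exp(softmax_scale * (S - rowmax)).
    rowmax = S.max(dim=1).values
    rowsum = torch.exp(softmax_scale * (S - rowmax[:, None])).sum(dim=1)
    lse = softmax_scale * rowmax + torch.log(rowsum)
    assert torch.allclose(lse, torch.logsumexp(softmax_scale * S, dim=1))
    inv_sum = torch.where(rowsum == 0, torch.ones_like(rowsum), 1.0 / rowsum)
    return O_unnormalized * inv_sum[:, None], lse
```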
@@ -1101,7 +1103,6 @@ def flash_fwd_splitkv_kernel(
is_local=is_local,
has_alibi=has_alibi
)
# col_idx -= BLOCK_N

O_, P, rowmax_, rowsum_ = softmax_rescale(
O_,
@@ -1119,21 +1120,16 @@ def flash_fwd_splitkv_kernel(
else:
V = tl.load(V_ptr + V_offset, mask=kvmask[:, None], cache_modifier=".cg")
O_ = tl.dot(P, V, O_, allow_tf32=False)
# if n_masking_blocks > 1 and n_block <= n_block_min:
# break


for n_block in tl.range(n_block_max - n_masking_blocks - 1, n_block_min - 1, step=-1, num_stages=num_stages):
col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
K_offset = col_idx[None, :] * k_s_stride + tl.arange(0, HEAD_DIM)[:, None]
K = tl.load(K_ptr + K_offset, cache_modifier=".cg")
# if PRE_LOAD_V:
# V_offset = col_idx[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :]
# V = tl.load(V_ptr + V_offset, cache_modifier=".cg")
S = tl.dot(Q, K)

if PRE_LOAD_V:
V_offset = col_idx[:, None] * k_s_stride + tl.arange(0, HEAD_DIM)[None, :]
V = tl.load(V_ptr + V_offset, cache_modifier=".cg")
S = tl.dot(Q, K)

S = apply_mask(
S,
@@ -1149,7 +1145,6 @@ def flash_fwd_splitkv_kernel(
is_local=is_local,
has_alibi=has_alibi
)
# col_idx -= BLOCK_N

O_, P, rowmax_, rowsum_ = softmax_rescale(
O_,
@@ -1246,8 +1241,7 @@ def flash_fwd_splitkv_combine_kernel(
out_splits = tl.load(out_splits_ptr + out_split_offset, mask=out_split_mask, other=0)
out = tl.sum(Zi_Z[:, :, None] * out_splits, 1)
out = out.to(out_ptr.type.element_ty)

# tl.device_print('O', out)

# Write back output
out_offset = tl.arange(0, BLOCK_M)[:, None] * out_s_stride + tl.arange(0, head_size)
tl.store(out_ptr + out_offset, out, mask=out_mask[:, None])
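The combine step above is the standard split-kv reduction: each split contributes its partial output weighted by exp(lse_i - lse). A dense PyTorch reference of that reduction, assuming each split's output is already normalized by its own row sum, as is typical for split-kv decoding (tensor names are illustrative):

```python
import torch

def combine_splits_reference(out_splits, lse_splits):
    # out_splits: (n_splits, M, D) per-split outputs, lse_splits: (n_splits, M).
    lse = torch.logsumexp(lse_splits, dim=0)            # combined log-sum-exp, (M,)
    zi_z = torch.exp(lse_splits - lse[None, :])         # per-split weight Zi / Z
    out = (zi_z[:, :, None] * out_splits).sum(dim=0)    # (M, D)
    return out, lse
```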
@@ -1364,16 +1358,18 @@ def splits_heuristics(num_tasks, num_sms, n_blocks):
rpdrop = 1. / p_dropout

# Check splitkv
if not is_dropout:
n_tasks = batch_size * num_heads * triton.cdiv(seqlen_q, block_m_splitkv_heuristic(head_size))
def try_split_kv():
block_m = block_m_splitkv_heuristic(head_size)
n_tasks = batch_size * num_heads * triton.cdiv(seqlen_q, block_m)
num_sms = torch_device_fn.get_device_properties("cuda").multi_processor_count
n_blocks = triton.cdiv(seqlen_k, block_n_splitkv_heuristic(head_size))
block_n = block_n_splitkv_heuristic(head_size)
n_blocks = triton.cdiv(seqlen_k, block_n)
n_splits = splits_heuristics(n_tasks, num_sms, n_blocks)
print('n_blocks:', n_blocks)
print('n_splits:', n_splits)
else:
n_splits = 1
return n_splits

n_splits = try_split_kv() if not is_dropout else 1
print('n_splits:', n_splits)
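splits_heuristics itself is defined elsewhere in the file; as a rough, hypothetical sketch of the kind of occupancy rule such a helper typically encodes (this is not the repo's implementation, and max_splits is an invented parameter):

```python
def splits_heuristics_sketch(num_tasks, num_sms, n_blocks, max_splits=128):
    # Hypothetical: split along the key dimension only when plain tiling would
    # leave most SMs idle, and never use more splits than there are key blocks.
    if num_tasks >= num_sms:
        return 1
    return min(max_splits, n_blocks, -(-num_sms // num_tasks))  # ceil division
```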

if n_splits > 1:
lse_splits = torch.empty(
(n_splits, batch_size, num_heads, seqlen_q),