205 changes: 150 additions & 55 deletions benchmark/test_attention_perf.py
@@ -179,74 +179,172 @@ class FlashAttnVarlenBenchmark(Benchmark):

def set_shapes(self, shape_file_path: Optional[List[Any]] = None):
# Collected from qwen/Qwen3-1.7B --random-input 512 --random-output 2048 --num-prompts 200 --request-rate inf
# Format: (seq_lens, num_heads, head_size, block_size, num_blocks, alibi, soft_cap)
# ([(1, 1), (1, 1), (1, 1)], (16, 8), 128, 32, 18208, False, None),
# The performance is very poor, which may be related to prefill
flash_attn_configs = [
([(1, 1), (1, 1), (1, 1)], (16, 8), 128, 32, 18208, False, None),
([(1, 1), (1, 1), (23, 23)], (16, 8), 128, 32, 18208, False, None),
([(1, 1), (1, 1), (7, 7)], (16, 8), 128, 32, 18208, False, None),
([(1, 1), (1, 1), (39, 39)], (16, 8), 128, 32, 18208, False, None),
([(1, 1), (1, 1), (55, 55)], (16, 8), 128, 32, 18208, False, None),
([(1, 1), (1, 1), (70, 70)], (16, 8), 128, 32, 18208, False, None),
# Format: (cu_seq_lens_q, seqused_k, num_heads, num_heads_k, head_dim, block_size, num_blocks, alibi, soft_cap)

all_cu_seq_lens_q = [
(
0,
512,
),
(
0,
1,
2,
72,
),
tuple(range(0, 45))
+ (
105,
121,
137,
153,
169,
185,
201,
217,
233,
249,
265,
),
tuple(range(0, 196))
+ (
211,
226,
240,
253,
265,
),
]
all_seqused_k = [
(512,),
(
1,
1,
70,
),
(515,) + (514,) * 20 + (513,) * 20 + (512,) * 14,
(2333,)
+ (2331,) * 20
+ (2330,) * 20
+ (2329,) * 14
+ (2328,) * 18
+ (2327,) * 15
+ (2326,) * 17
+ (2325,) * 18
+ (2324,) * 21
+ (2323,) * 22
+ (2322,) * 24
+ (2321,) * 5
+ (
2320,
2319,
2318,
2317,
2316,
),
]

num_heads = 16
num_heads_k = 8
head_dim = 128
block_size = 16
num_blocks = 2000
alibi = False
soft_cap = None

all_configs = [
(
cu_seq_lens_q,
seqused_k,
num_heads,
num_heads_k,
head_dim,
block_size,
num_blocks,
alibi,
soft_cap,
)
for cu_seq_lens_q, seqused_k in zip(all_cu_seq_lens_q, all_seqused_k)
]
self.shapes = flash_attn_configs

self.shapes = all_configs

def get_input_iter(self, cur_dtype):
for config in self.shapes:
yield from self.flash_attn_varlen_input_fn(config, cur_dtype, self.device)
yield self.flash_attn_varlen_input_fn(config, cur_dtype, self.device)

def flash_attn_varlen_input_fn(self, config, dtype, device):
"""Input function for flash attention varlen benchmark"""
seq_lens, num_heads, head_size, block_size, num_blocks, alibi, soft_cap = config
(
cu_query_lens,
seqused_k,
num_query_heads,
num_kv_heads,
head_size,
block_size,
num_blocks,
alibi,
soft_cap,
) = config

if alibi is True and soft_cap is not None:
return

num_seqs = len(seq_lens)
query_lens = [x[0] for x in seq_lens]
kv_lens = [x[1] for x in seq_lens]
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
max_query_len = max(query_lens)
max_kv_len = max(kv_lens)
num_seqs = len(cu_query_lens) - 1
max_query_len = max(
map(lambda x, y: x - y, cu_query_lens[1:], cu_query_lens[:-1])
)
max_kv_len = max(seqused_k)
window_size = (-1, -1)
scale = head_size**-0.5

query = torch.randn(
sum(query_lens), num_query_heads, head_size, dtype=dtype, device=device
)
key_cache = torch.randn(
num_blocks, block_size, num_kv_heads, head_size, dtype=dtype, device=device
)
value_cache = torch.randn_like(key_cache)
cu_query_lens = torch.tensor(
[0] + query_lens, dtype=torch.int32, device=device
).cumsum(dim=0, dtype=torch.int32)
seqused_k = torch.tensor(kv_lens, dtype=torch.int32, device=device)

max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(
0,
num_blocks,
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32,
device=device,
)
assert num_seqs == len(seqused_k)

causal = True
with torch.device(device):
query = torch.randn(
cu_query_lens[-1],
num_query_heads,
head_size,
dtype=dtype,
device=device,
)
out = torch.empty_like(query)
key_cache = torch.randn(
num_blocks,
block_size,
num_kv_heads,
head_size,
dtype=dtype,
device=device,
)
value_cache = torch.randn_like(key_cache)
cu_query_lens = torch.tensor(
cu_query_lens, dtype=torch.int32, device=device
)
seqused_k = torch.tensor(seqused_k, dtype=torch.int32, device=device)

max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(
0,
num_blocks,
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32,
device=device,
)

causal = True

if alibi:
alibi_slopes = (
torch.ones(
num_seqs, num_query_heads, device=device, dtype=torch.float32
if alibi:
alibi_slopes = (
torch.ones(
num_seqs, num_query_heads, device=device, dtype=torch.float32
)
* 0.3
)
* 0.3
)
else:
alibi_slopes = None
else:
alibi_slopes = None

yield (
return (
query,
key_cache,
value_cache,
@@ -266,7 +364,7 @@ def flash_attn_varlen_input_fn(self, config, dtype, device):
False,
block_tables,
False,
None,
out,
None,
None,
None,
@@ -297,10 +395,7 @@ def test_perf_flash_attn_varlen_func():
bench = FlashAttnVarlenBenchmark(
op_name="flash_attn_varlen_func",
torch_op=flash_attn_varlen_func,
dtypes=[
torch.float16,
torch.bfloat16,
],
dtypes=[torch.float16, torch.bfloat16],
)
bench.set_gems(flag_gems.ops.flash_attn_varlen_func)
bench.run()
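For reference, the reworked configs above encode per-sequence lengths implicitly: cu_seq_lens_q is the cumulative (prefix-sum) form of the query lengths, and seqused_k holds the absolute KV length of each sequence. A minimal sketch of how the benchmark's derived quantities fall out of one of the configs above (plain Python; values are taken from the second config and the variable names are illustrative):

# Sketch: recovering per-sequence lengths from the cumulative form used above.
cu_seq_lens_q = (0, 1, 2, 72)   # one more entry than the number of sequences
seqused_k = (1, 1, 70)          # KV tokens attended to by each sequence

num_seqs = len(cu_seq_lens_q) - 1                                        # 3
query_lens = [b - a for a, b in zip(cu_seq_lens_q, cu_seq_lens_q[1:])]   # [1, 1, 70]
max_query_len = max(query_lens)                                          # 70
max_kv_len = max(seqused_k)                                              # 70
total_q = cu_seq_lens_q[-1]                                              # 72 packed query rows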
22 changes: 15 additions & 7 deletions src/flag_gems/ops/flash_api.py
@@ -521,13 +521,22 @@ def mha_varlan_fwd(
args = tuple(getattr(params, k) for k in params.__slots__)

# We assess which phase the requests are likely to be in and set the config accordingly.
# prefill_config: BLOCK_M=128, BLOCK_N=32, num_warps=4, num_stages=3
# decode_config: BLOCK_M=32, BLOCK_N=32, num_warps=4, num_stages=3
avg_seqlen_q = total_q / batch_size
if avg_seqlen_q >= 256:
varlen_fwd_config_str = "mha_varlen_prefill"
total_rows = total_q * num_heads
num_sms = torch_device_fn.get_device_properties("cuda").multi_processor_count
avg_rows_per_sm = total_rows / num_sms
avg_rows_per_batch = total_q / batch_size
avg_rows_per_cta = min(avg_rows_per_batch, avg_rows_per_sm)
# Heuristic: pick the kernel block config from the average number of query rows
# each CTA is expected to process (the smaller of the per-batch and per-SM averages).
# This is a rough heuristic and may not be accurate for all scenarios.
if avg_rows_per_cta > 64:
varlen_fwd_config_str = "mha_block_128"
elif avg_rows_per_cta > 32:
varlen_fwd_config_str = "mha_block_64"
elif avg_rows_per_cta > 16:
varlen_fwd_config_str = "mha_block_32"
else:
varlen_fwd_config_str = "mha_varlen_decode"
varlen_fwd_config_str = "mha_block_16"

cfg = runtime.get_heuristic_config(varlen_fwd_config_str)
cfg_params = {
"BLOCK_M": cfg["BLOCK_M"](args),
@@ -537,7 +546,6 @@
"num_stages": cfg["num_stages"](args),
}

logger.debug("Average query sequence length: %d", avg_seqlen_q)
logger.debug("Running flash_varlen_fwd_kernel with config: %s", cfg_params)
kernel(*args, **cfg_params)

46 changes: 35 additions & 11 deletions src/flag_gems/ops/flash_kernel.py
@@ -1151,36 +1151,57 @@ def flash_fwd_splitkv_combine_kernel(


@triton.jit
def virtual_to_cache(virtual_index, page_table_ptr, block_size):
def virtual_to_cache(
virtual_index,
max_virtual_index,
page_table_ptr,
block_size,
boundary_check: tl.constexpr = False,
):
# virtual_index is the kv sequence index within the current batch element
# max_virtual_index is the number of valid kv positions; when boundary_check is set,
#   indices at or beyond it are masked out of the page-table load
# page_table_ptr already points at the current batch element's block table entry
# block_size is the size of each block in the page table
virtual_page_index = virtual_index // block_size
page_offset = virtual_index % block_size
page_block_index = tl.load(page_table_ptr + virtual_page_index).to(tl.int32)
if boundary_check:
page_block_index = tl.load(
page_table_ptr + virtual_page_index,
mask=virtual_index < max_virtual_index,
other=0,
).to(tl.int32)
else:
page_block_index = tl.load(page_table_ptr + virtual_page_index).to(tl.int32)
return page_block_index * block_size + page_offset


@triton.jit
def load_from_kvcache(
i,
virtual_index,
max_virtual_index,
page_table_ptr,
k_ptr_base,
v_ptr_base,
block_size,
d,
d: tl.constexpr,
k_row_stride,
BLOCK_K: tl.constexpr,
boundary_check: tl.constexpr = False,
):
kvcache_idx = virtual_to_cache(i, page_table_ptr, block_size)
kvcache_idx = virtual_to_cache(
virtual_index, max_virtual_index, page_table_ptr, block_size, boundary_check
)
k_offset = tl.arange(0, BLOCK_K)[:, None] + kvcache_idx[None, :] * k_row_stride
v_offset = tl.arange(0, BLOCK_K)[None, :] + kvcache_idx[:, None] * k_row_stride
bK = tl.load(
k_ptr_base + k_offset, mask=tl.arange(0, BLOCK_K)[:, None] < d, other=0.0
)
bV = tl.load(
v_ptr_base + v_offset, mask=tl.arange(0, BLOCK_K)[None, :] < d, other=0.0
)
if d == BLOCK_K:
bK = tl.load(k_ptr_base + k_offset)
bV = tl.load(v_ptr_base + v_offset)
else:
bK = tl.load(
k_ptr_base + k_offset, mask=tl.arange(0, BLOCK_K)[:, None] < d, other=0.0
)
bV = tl.load(
v_ptr_base + v_offset, mask=tl.arange(0, BLOCK_K)[None, :] < d, other=0.0
)
return bK, bV
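For context, virtual_to_cache performs the usual paged-KV-cache translation: a logical token position is split into a page index and an in-page offset, and the block table maps the logical page to a physical block. A plain-Python sketch of the same index math (the block-table values here are made up for illustration):

# Host-side equivalent of the index math in virtual_to_cache (illustrative values).
block_size = 16
block_table = [7, 3, 42]   # physical block ids for logical pages 0, 1, 2

def virtual_to_cache_py(virtual_index):
    virtual_page_index = virtual_index // block_size    # which logical page
    page_offset = virtual_index % block_size            # position inside that page
    page_block_index = block_table[virtual_page_index]  # physical block from the table
    return page_block_index * block_size + page_offset  # row index into the KV cache

virtual_to_cache_py(0)   # 7 * 16 + 0  = 112
virtual_to_cache_py(17)  # 3 * 16 + 1  = 49
virtual_to_cache_py(47)  # 42 * 16 + 15 = 687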


@@ -1377,13 +1398,15 @@ def flash_varlen_fwd_kernel(
col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
bK, bV = load_from_kvcache(
col_idx,
k_len,
page_table_ptr,
k_ptr_base,
v_ptr_base,
block_size,
d,
k_row_stride,
BLOCK_K=BLOCK_K,
boundary_check=True,
)
S = tl.dot(bQ, bK, out_dtype=tl.float32)
S = apply_softcap(S, softcap, is_softcap)
@@ -1447,6 +1470,7 @@ def flash_varlen_fwd_kernel(
col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
bK, bV = load_from_kvcache(
col_idx,
k_len,
page_table_ptr,
k_ptr_base,
v_ptr_base,
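The max_virtual_index / boundary_check arguments added at these call sites guard the page-table load for the last, possibly partial KV tile: without the mask, column indices past k_len can read block ids beyond the sequence's block-table row. A small sketch of the mask being applied (BLOCK_N, k_len and n_block are illustrative values):

# Illustrative mask for a partial final tile (not the kernel itself).
BLOCK_N = 32
k_len = 70                                   # seqused_k for this sequence
n_block = 2                                  # last tile covers columns 64..95
col_idx = [n_block * BLOCK_N + i for i in range(BLOCK_N)]
valid = [c < k_len for c in col_idx]         # only columns 64..69 are in range
# With boundary_check=True, out-of-range lanes fall back to block 0 (other=0), so the
# gather stays in bounds; those columns are presumably masked out of the scores later.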