2 changes: 1 addition & 1 deletion hopper/tile_scheduler.hpp
@@ -388,7 +388,7 @@ class VarlenDynamicPersistentTileScheduler {
// If Split, for the purpose of scheduling, we pretend that instead there are
// (args.num_splits * args.num_head) number of heads.
assert(args.tile_count_semaphore != nullptr);
- assert(num_head < (1 << 16)); // We use the top 16 bits to store num_splits & split_idx
+ assert(args.num_head < (1 << 16)); // We use the top 16 bits to store num_splits & split_idx

how did we compile this before?

Author:

I guess we don't use the problematic version yet?

https://github.com/vllm-project/vllm/blob/main/cmake/external_projects/vllm_flash_attn.cmake#L41

I will update the tag once this PR gets approved and merged.

assert(!Split || args.num_splits < (1 << 8)); // We use the top 8 bits to store num_splits
return {args.num_head, args.num_batch,
args.qhead_per_khead, args.seqlen,
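For context on the 16-bit limit discussed above, here is a minimal Python sketch (illustrative only, not the scheduler's actual C++ code) of the packing layout the two asserts imply: num_splits in the top 8 bits, split_idx in the next 8, and the head index in the low 16 bits. The helper names are hypothetical.

```python
# Hypothetical helpers illustrating the layout implied by the asserts:
# top 8 bits = num_splits, next 8 bits = split_idx, low 16 bits = head index.
def pack_head_split(head_idx: int, split_idx: int, num_splits: int) -> int:
    assert head_idx < (1 << 16), "head index must fit in the low 16 bits"
    assert num_splits < (1 << 8) and split_idx < num_splits
    return (num_splits << 24) | (split_idx << 16) | head_idx


def unpack_head_split(word: int) -> tuple[int, int, int]:
    # Returns (head_idx, split_idx, num_splits) from a packed 32-bit word.
    return word & 0xFFFF, (word >> 16) & 0xFF, word >> 24


assert unpack_head_split(pack_head_split(5, 3, 4)) == (5, 3, 4)
```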
243 changes: 243 additions & 0 deletions tests/test_vllm_flash_attn.py
@@ -17,6 +17,12 @@
is_fa_version_supported,
)

from test_util import (
construct_local_mask,
generate_qkv,
generate_random_padding_mask,
)

NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
HEAD_SIZES = [128, 256]
BLOCK_SIZES = [16, 32]
@@ -29,6 +35,117 @@
([3] if is_fa_version_supported(3) else [])


# This function is copied from hopper/test_utils.py
def attention_ref(
Collaborator:

Can this replace ref_attn below, or vice versa? It would be nice to have less duplication here.

Author:

Good question. I tried it last Friday and it didn't work: ref_attn didn't work with the new tests, and attention_ref failed with some of the existing tests. I didn't spend effort digging into the issue, though.

Author:

BTW, if the overall change looks good, would it be fine to merge the PR while keeping the two versions for now? We have some dependency on this change.

I will make sure to have a follow-up PR to unify them into one. Sorry, I don't have the bandwidth to investigate the cross-failures at the moment. Thanks!

q,
k,
v,
query_padding_mask=None,
key_padding_mask=None,
key_leftpad=None,
attn_bias=None,
dropout_p=0.0,
dropout_mask=None,
causal=False,
qv=None,
q_descale=None, k_descale=None, v_descale=None,
window_size=(-1, -1), # -1 means infinite window size
sink_token_length=0,
softcap=0.0,
upcast=True,
reorder_ops=False,
intermediate_dtype=None,
):
"""
Arguments:
q: (batch_size, seqlen_q, nheads, head_dim)
k: (batch_size, seqlen_k, nheads, head_dim)
v: (batch_size, seqlen_k, nheads, head_dim_v)
qv: (batch_size, seqlen_q, nheads, head_dim_v)
query_padding_mask: (batch_size, seqlen_q)
key_padding_mask: (batch_size, seqlen_k)
attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
dropout_p: float
dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
causal: whether to apply causal masking
upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
output back to fp16/bf16.
reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
without changing the math. This is to estimate the numerical error from operation
reordering.
Output:
output: (batch_size, seqlen_q, nheads, head_dim_v)
attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
"""
if causal:
window_size = (window_size[0], 0)
dtype_og = q.dtype
if upcast:
q, k, v = q.float(), k.float(), v.float()
qv = qv.float() if qv is not None else None
if q_descale is not None:
q_descale = repeat(q_descale, "b h -> b 1 (h g) 1", g=q.shape[2] // k.shape[2])
q = (q.float() * q_descale).to(q.dtype)
qv = (qv.float() * q_descale).to(qv.dtype) if qv is not None else None
if k_descale is not None:
k = (k.float() * rearrange(k_descale, "b h -> b 1 h 1")).to(dtype=k.dtype)
if v_descale is not None:
v = (v.float() * rearrange(v_descale, "b h -> b 1 h 1")).to(dtype=v.dtype)
seqlen_q, seqlen_k = q.shape[1], k.shape[1]
k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
d = q.shape[-1]
dv = v.shape[-1]
softmax_scale = 1.0 / math.sqrt(d if qv is None else d + dv)
if not reorder_ops:
scores = torch.einsum("bthd,bshd->bhts", q * softmax_scale, k)
else:
scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
if qv is not None:
scores = scores + torch.einsum("bthd,bshd->bhts", qv * softmax_scale, v)
if softcap > 0:
scores = torch.tanh(scores / softcap) * softcap
if key_padding_mask is not None:
scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
if window_size[0] >= 0 or window_size[1] >= 0:
local_mask = construct_local_mask(
seqlen_q,
seqlen_k,
window_size,
query_padding_mask,
key_padding_mask,
device=q.device,
key_leftpad=key_leftpad,
)
scores.masked_fill_(local_mask, float("-inf"))
if attn_bias is not None:
scores = scores + attn_bias
attention = torch.softmax(scores, dim=-1).to(v.dtype)
# We want to mask here so that the attention matrix doesn't have any NaNs
# Otherwise we'll get NaN in dV
if query_padding_mask is not None:
attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
# Without this we might get NaN in dv
if key_padding_mask is not None:
attention = attention.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0)
# Some rows might be completely masked out so we fill them with zero instead of NaN
if window_size[0] >= 0 or window_size[1] >= 0:
attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)
dropout_scaling = 1.0 / (1 - dropout_p)
# attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
# output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
if dropout_mask is not None:
attention_drop = attention.masked_fill(~dropout_mask, 0.0)
else:
attention_drop = attention
if intermediate_dtype is not None:
attention_drop = attention_drop.to(intermediate_dtype).to(attention_drop.dtype)
output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
if query_padding_mask is not None:
output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)


def ref_attn(
q,
k,
@@ -522,3 +639,129 @@ def test_sparse_attention_varlen(
f"{torch.max(torch.abs(out - ref_out))}"
torch.testing.assert_close(lse, ref_lse, atol=2e-2, rtol=1e-2), \
f"{torch.max(torch.abs(lse - ref_lse))}"

# simplified version of hopper/test_flash_attn.py
# for testing seqused_k and cu_seqlens_k
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("num_heads", [32, 64])
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("batch_size", [1, 2])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize(
"seqlen_q,seqlen_k",
[
(1, 1),
(64, 128),
],
)
@pytest.mark.parametrize("fa_version", VERSIONS)
def test_flash_attn_varlen_output(
Collaborator:

Can we make the test name more specific?

dtype: torch.dtype,
num_heads: int,
head_size: int,
batch_size: int,
causal: bool,
seqlen_q: int,
seqlen_k: int,
fa_version: int,
):
device = "cuda"
torch.random.manual_seed(123)

q = torch.randn(batch_size, seqlen_q, num_heads, head_size,
device=device, dtype=dtype)
k = torch.randn(batch_size, seqlen_k, num_heads, head_size,
device=device, dtype=dtype)
v = torch.randn(batch_size, seqlen_k, num_heads, head_size,
device=device, dtype=dtype)
query_padding_mask = generate_random_padding_mask(
seqlen_q, batch_size, device, mode="random", zero_lengths=False
)
key_padding_mask = generate_random_padding_mask(
seqlen_k, batch_size, device, mode="random", zero_lengths=True
)

def _gen_unused_masks(padding_mask, max_seq_len, bs, device):
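# attn_mask: tokens present in both the input padding mask and a fresh random mask;
# unused_mask: positions covered by exactly one of the two, consumed below by
# generate_qkv as query_unused_mask / key_unused_mask.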
another_mask = generate_random_padding_mask(max_seq_len, bs, device)
attn_mask = torch.logical_and(padding_mask, another_mask)
unused_mask = torch.logical_xor(
torch.logical_or(padding_mask, another_mask), attn_mask
)
return attn_mask, unused_mask

query_padding_mask, query_unused_mask = _gen_unused_masks(
query_padding_mask, seqlen_q, batch_size, q.device
)
key_padding_mask, key_unused_mask = _gen_unused_masks(
key_padding_mask, seqlen_k, batch_size, k.device
)

(
q_unpad,
k_unpad,
v_unpad,
cu_seqlens_q,
cu_seqlens_k,
_, # seqused_q
seqused_k,
max_seqlen_q,
max_seqlen_k,
q,
k,
v,
output_pad_fn,
_, # dq_pad_fn
_, # dk_pad_fn
) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask,
query_unused_mask=query_unused_mask,
key_unused_mask=key_unused_mask)
out_ref, attn_ref = attention_ref(
q,
k,
v,
query_padding_mask,
key_padding_mask,
causal=causal,
)
out_pt, attn_pt = attention_ref(
q,
k,
v,
query_padding_mask,
key_padding_mask,
causal=causal,
upcast=False,
reorder_ops=True,
)

print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1")

# Numerical error if we just do any arithmetic on out_ref
fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()
rtol = 2

is rtol too large? should it be 1e-2?

Collaborator:

Maybe worth citing in the comment?


out_unpad = flash_attn_varlen_func(
q_unpad,
k_unpad,
v_unpad,
max_seqlen_q,
cu_seqlens_q,
max_seqlen_k,
cu_seqlens_k,
seqused_k=seqused_k,
causal=causal,
fa_version=fa_version,
)
out = output_pad_fn(out_unpad)
out.masked_fill_(q_zero_masking, 0.0)
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")

# Check that FlashAttention's numerical error is at most rtol (= 2) times
# the numerical error of a Pytorch implementation, plus fwd_atol.
assert (

should we have two checks, one for rtol and one for atol?

Author:

It's from FA's implementation:

https://github.com/Dao-AILab/flash-attention/blob/main/hopper/test_flash_attn.py#L477-L479

I think it might be fine to keep their setup?

(out - out_ref).abs().max().item() <=
rtol * (out_pt - out_ref).abs().max().item() + fwd_atol)
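To make the two checking styles from the thread above concrete, here is a hedged sketch; the helper names and the fixed tolerances are illustrative. The `out_ref + 0.3 - 0.3 - out_ref` expression is the test's own trick for estimating the rounding-noise floor of the output dtype, which is what fwd_atol absorbs.

```python
import torch


def check_fa_style(out, out_ref, out_pt, rtol=2.0):
    # FA-style check: the kernel's max error may be at most `rtol` times the error
    # of a plain PyTorch implementation in the same dtype, plus a rounding-noise floor.
    fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()
    assert (out - out_ref).abs().max().item() <= \
        rtol * (out_pt - out_ref).abs().max().item() + fwd_atol


def check_fixed_tolerances(out, out_ref, atol=2e-2, rtol=1e-2):
    # Alternative raised in review: separate absolute/relative tolerances,
    # mirroring the torch.testing.assert_close call used earlier in this file.
    torch.testing.assert_close(out, out_ref, atol=atol, rtol=rtol)
```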
2 changes: 0 additions & 2 deletions vllm_flash_attn/flash_attn_interface.py
@@ -199,8 +199,6 @@ def flash_attn_varlen_func(
"""
assert cu_seqlens_k is not None or seqused_k is not None, \
"cu_seqlens_k or seqused_k must be provided"
- assert cu_seqlens_k is None or seqused_k is None, \
-     "cu_seqlens_k and seqused_k cannot be provided at the same time"
@LucasWilkinson (Collaborator, Mar 31, 2025):

Can you please explain to me the use case for using cu_seqlens_k and seqused_k simultaneously? Right now we use cu_seqlens_k for initial prefills (no context) and seqused_k for chunked prefill and decode (i.e. anytime there is a page table).

Author:

We have a special use case for this, but I can't share the details at the moment. Sorry about that!

assert block_table is None or seqused_k is not None, \
"seqused_k must be provided if block_table is provided"

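To illustrate the two calling patterns discussed in the thread above, here is a hedged sketch. It is not code from this PR: the import path and wrapper function names are assumptions, while the positional order and keyword arguments mirror the flash_attn_varlen_func call in the test diff earlier on this page.

```python
from vllm_flash_attn import flash_attn_varlen_func  # import path assumed


def prefill_no_context(q_unpad, k_unpad, v_unpad,
                       cu_seqlens_q, cu_seqlens_k,
                       max_seqlen_q, max_seqlen_k):
    # Initial prefill: K/V are packed varlen tensors, described entirely by cu_seqlens_k.
    return flash_attn_varlen_func(
        q_unpad, k_unpad, v_unpad,
        max_seqlen_q, cu_seqlens_q,
        max_seqlen_k, cu_seqlens_k,
        causal=True,
    )


def chunked_prefill_or_decode(q_unpad, k_cache, v_cache,
                              cu_seqlens_q, seqused_k, block_table,
                              max_seqlen_q, max_seqlen_k):
    # Paged KV cache: the number of valid tokens per sequence comes from seqused_k
    # and the page layout from block_table, so cu_seqlens_k is omitted.
    return flash_attn_varlen_func(
        q_unpad, k_cache, v_cache,
        max_seqlen_q, cu_seqlens_q,
        max_seqlen_k,
        seqused_k=seqused_k,
        block_table=block_table,
        causal=True,
    )
```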