Removed the assertion imposed on cu_seqlens_k and seqused_k #59
base: main
@@ -17,6 +17,12 @@
    is_fa_version_supported,
)

from test_util import (
    construct_local_mask,
    generate_qkv,
    generate_random_padding_mask,
)

NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
HEAD_SIZES = [128, 256]
BLOCK_SIZES = [16, 32]

@@ -29,6 +35,117 @@
    ([3] if is_fa_version_supported(3) else [])

# This function is copied from hopper/test_utils.py
def attention_ref(

Review comment: Can this replace the existing ref_attn?
Reply: Good question. I tried it last Friday and it didn't work.
Reply: BTW, if the overall change looks good, I am wondering if it would be fine to merge the PR while keeping the two versions for the moment. We have some dependency on this change. I will make sure to open a follow-up PR to unify them into one. I am sorry, I am running out of bandwidth to investigate the failures between the two at the moment. Thanks!

    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    key_leftpad=None,
    attn_bias=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    qv=None,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),  # -1 means infinite window size
    sink_token_length=0,
    softcap=0.0,
    upcast=True,
    reorder_ops=False,
    intermediate_dtype=None,
):
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, head_dim)
        k: (batch_size, seqlen_k, nheads, head_dim)
        v: (batch_size, seqlen_k, nheads, head_dim_v)
        qv: (batch_size, seqlen_q, nheads, head_dim_v)
        query_padding_mask: (batch_size, seqlen_q)
        key_padding_mask: (batch_size, seqlen_k)
        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
        dropout_p: float
        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
        causal: whether to apply causal masking
        upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
            output back to fp16/bf16.
        reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
            without changing the math. This is to estimate the numerical error from operation
            reordering.
    Output:
        output: (batch_size, seqlen_q, nheads, head_dim_v)
        attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
    """
    if causal:
        window_size = (window_size[0], 0)
    dtype_og = q.dtype
    if upcast:
        q, k, v = q.float(), k.float(), v.float()
        qv = qv.float() if qv is not None else None
    if q_descale is not None:
        q_descale = repeat(q_descale, "b h -> b 1 (h g) 1", g=q.shape[2] // k.shape[2])
        q = (q.float() * q_descale).to(q.dtype)
        qv = (qv.float() * q_descale).to(qv.dtype) if qv is not None else None
    if k_descale is not None:
        k = (k.float() * rearrange(k_descale, "b h -> b 1 h 1")).to(dtype=k.dtype)
    if v_descale is not None:
        v = (v.float() * rearrange(v_descale, "b h -> b 1 h 1")).to(dtype=v.dtype)
    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
    d = q.shape[-1]
    dv = v.shape[-1]
    softmax_scale = 1.0 / math.sqrt(d if qv is None else d + dv)
    if not reorder_ops:
        scores = torch.einsum("bthd,bshd->bhts", q * softmax_scale, k)
    else:
        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
    if qv is not None:
        scores = scores + torch.einsum("bthd,bshd->bhts", qv * softmax_scale, v)
    if softcap > 0:
        scores = torch.tanh(scores / softcap) * softcap
    if key_padding_mask is not None:
        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
    if window_size[0] >= 0 or window_size[1] >= 0:
        local_mask = construct_local_mask(
            seqlen_q,
            seqlen_k,
            window_size,
            query_padding_mask,
            key_padding_mask,
            device=q.device,
            key_leftpad=key_leftpad,
        )
        scores.masked_fill_(local_mask, float("-inf"))
    if attn_bias is not None:
        scores = scores + attn_bias
    attention = torch.softmax(scores, dim=-1).to(v.dtype)
    # We want to mask here so that the attention matrix doesn't have any NaNs
    # Otherwise we'll get NaN in dV
    if query_padding_mask is not None:
        attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
    # Without this we might get NaN in dv
    if key_padding_mask is not None:
        attention = attention.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0)
    # Some rows might be completely masked out so we fill them with zero instead of NaN
    if window_size[0] >= 0 or window_size[1] >= 0:
        attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)
    dropout_scaling = 1.0 / (1 - dropout_p)
    # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
    # output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
    if dropout_mask is not None:
        attention_drop = attention.masked_fill(~dropout_mask, 0.0)
    else:
        attention_drop = attention
    if intermediate_dtype is not None:
        attention_drop = attention_drop.to(intermediate_dtype).to(attention_drop.dtype)
    output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
    if query_padding_mask is not None:
        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
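
As a quick illustration of what this reference computes, the sketch below cross-checks attention_ref against PyTorch's built-in scaled_dot_product_attention for the default no-mask, non-causal case. It is not part of the PR; it assumes attention_ref (and the math/einops imports it relies on) is available in scope, and the tolerances are loose placeholders.

# Standalone sanity check (illustrative, not part of the PR): with no padding
# masks, no bias, and causal=False, attention_ref reduces to plain
# softmax(q @ k^T / sqrt(d)) @ v, which PyTorch's SDPA also computes.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, sq, sk, h, d = 2, 64, 128, 4, 128
q = torch.randn(b, sq, h, d)
k = torch.randn(b, sk, h, d)
v = torch.randn(b, sk, h, d)

out_ref, _ = attention_ref(q, k, v)  # layout: (batch, seqlen, heads, dim)
out_sdpa = F.scaled_dot_product_attention(  # SDPA expects (batch, heads, seqlen, dim)
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2)

torch.testing.assert_close(out_ref, out_sdpa, atol=1e-3, rtol=1e-3)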

def ref_attn(
    q,
    k,

@@ -522,3 +639,129 @@ def test_sparse_attention_varlen(
        f"{torch.max(torch.abs(out - ref_out))}"
    torch.testing.assert_close(lse, ref_lse, atol=2e-2, rtol=1e-2), \
        f"{torch.max(torch.abs(lse - ref_lse))}"

# simplified version of hopper/test_flash_attn.py
# for testing seqused_k and cu_seqlens_k
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("num_heads", [32, 64])
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("batch_size", [1, 2])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 1),
        (64, 128),
    ],
)
@pytest.mark.parametrize("fa_version", VERSIONS)
def test_flash_attn_varlen_output(

Review comment: Can we make the test name more specific?

    dtype: torch.dtype,
    num_heads: int,
    head_size: int,
    batch_size: int,
    causal: bool,
    seqlen_q: int,
    seqlen_k: int,
    fa_version: int,
):
    device = "cuda"
    torch.random.manual_seed(123)

    q = torch.randn(batch_size, seqlen_q, num_heads, head_size,
                    device=device, dtype=dtype)
    k = torch.randn(batch_size, seqlen_k, num_heads, head_size,
                    device=device, dtype=dtype)
    v = torch.randn(batch_size, seqlen_k, num_heads, head_size,
                    device=device, dtype=dtype)
    query_padding_mask = generate_random_padding_mask(
        seqlen_q, batch_size, device, mode="random", zero_lengths=False
    )
    key_padding_mask = generate_random_padding_mask(
        seqlen_k, batch_size, device, mode="random", zero_lengths=True
    )

    def _gen_unused_masks(padding_mask, max_seq_len, bs, device):
        # Intersect the padding mask with another random mask to get the tokens
        # that are actually attended (attn_mask); tokens covered by exactly one
        # of the two masks become the "unused" tokens (unused_mask).
        another_mask = generate_random_padding_mask(max_seq_len, bs, device)
        attn_mask = torch.logical_and(padding_mask, another_mask)
        unused_mask = torch.logical_xor(
            torch.logical_or(padding_mask, another_mask), attn_mask
        )
        return attn_mask, unused_mask

    query_padding_mask, query_unused_mask = _gen_unused_masks(
        query_padding_mask, seqlen_q, batch_size, q.device
    )
    key_padding_mask, key_unused_mask = _gen_unused_masks(
        key_padding_mask, seqlen_k, batch_size, k.device
    )

    (
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        _,  # seqused_q
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        q,
        k,
        v,
        output_pad_fn,
        _,  # dq_pad_fn
        _,  # dk_pad_fn
    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask,
                     query_unused_mask=query_unused_mask,
                     key_unused_mask=key_unused_mask)
    out_ref, attn_ref = attention_ref(
        q,
        k,
        v,
        query_padding_mask,
        key_padding_mask,
        causal=causal,
    )
    out_pt, attn_pt = attention_ref(
        q,
        k,
        v,
        query_padding_mask,
        key_padding_mask,
        causal=causal,
        upcast=False,
        reorder_ops=True,
    )

    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

    q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1")

    # Numerical error if we just do any arithmetic on out_ref
    fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()
    rtol = 2

Review comment: Is rtol too large? Should it be 1e-2?
Reply: This is copied from flash attention's value: https://github.com/Dao-AILab/flash-attention/blob/main/hopper/test_flash_attn.py#L185
Review comment: Maybe worth citing in the comment?

    out_unpad = flash_attn_varlen_func(
        q_unpad,
        k_unpad,
        v_unpad,
        max_seqlen_q,
        cu_seqlens_q,
        max_seqlen_k,
        cu_seqlens_k,
        seqused_k=seqused_k,
        causal=causal,
        fa_version=fa_version,
    )
    out = output_pad_fn(out_unpad)
    out.masked_fill_(q_zero_masking, 0.0)
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most rtol (= 2) times
    # the numerical error of a Pytorch implementation.
    assert (
        (out - out_ref).abs().max().item() <=
        rtol * (out_pt - out_ref).abs().max().item() + fwd_atol)

Review comment: Should we have two checks, one checking rtol and one checking atol?
Reply: It's from FA's implementation: https://github.com/Dao-AILab/flash-attention/blob/main/hopper/test_flash_attn.py#L477-L479. I think it might be fine to keep their setup?
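
To make the reviewer's suggestion concrete, here is a sketch of what a separate fixed-tolerance check could look like, in the style of the assert_close calls in test_sparse_attention_varlen above. It is illustrative only: the tolerance values are copied from that test and have not been tuned for this case, and the PR keeps upstream FlashAttention's relative bound instead.

# Hypothetical alternative (not part of the PR): compare against the fp32
# reference with fixed tolerances, as test_sparse_attention_varlen does.
# The atol/rtol values are borrowed from that test for illustration only.
torch.testing.assert_close(out, out_ref, atol=2e-2, rtol=1e-2)

One reason to prefer the upstream bound is that it is scaled by the error of an equivalent low-precision PyTorch run (out_pt), so it adapts to the dtype and head size being tested, whereas fixed tolerances would need retuning per configuration.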
@@ -199,8 +199,6 @@ def flash_attn_varlen_func(
    """
    assert cu_seqlens_k is not None or seqused_k is not None, \
        "cu_seqlens_k or seqused_k must be provided"
-    assert cu_seqlens_k is None or seqused_k is None, \
-        "cu_seqlens_k and seqused_k cannot be provided at the same time"

Review comment: Can you please explain to me the use case for using cu_seqlens_k and seqused_k at the same time?
Reply: We have a special use case for this but I can't expose the details at the moment. Sorry about that!

    assert block_table is None or seqused_k is not None, \
        "seqused_k must be provided if block_table is provided"
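
For context on the removed assertion, below is a minimal sketch of the call pattern that is now accepted, passing both cu_seqlens_k and seqused_k. It follows the argument order used by the new test above; the import path, shapes, and length values are illustrative assumptions, and the actual semantics of combining the two arguments depend on the caller's (undisclosed) use case.

# Illustrative only: pass both cu_seqlens_k and seqused_k, which the removed
# assertion used to reject. Argument order follows the test above.
import torch
from vllm_flash_attn import flash_attn_varlen_func  # import path assumed

batch, max_q, max_k, heads, dim = 2, 64, 128, 8, 128
device, dtype = "cuda", torch.float16

q = torch.randn(batch * max_q, heads, dim, device=device, dtype=dtype)
k = torch.randn(batch * max_k, heads, dim, device=device, dtype=dtype)
v = torch.randn(batch * max_k, heads, dim, device=device, dtype=dtype)

# cu_seqlens_* describe the packed (unpadded) layout; seqused_k says how much
# of each allocated key segment is actually attended (here: fewer than allocated).
cu_seqlens_q = torch.tensor([0, 64, 128], device=device, dtype=torch.int32)
cu_seqlens_k = torch.tensor([0, 128, 256], device=device, dtype=torch.int32)
seqused_k = torch.tensor([100, 96], device=device, dtype=torch.int32)

out = flash_attn_varlen_func(
    q, k, v,
    max_q, cu_seqlens_q,
    max_k, cu_seqlens_k,
    seqused_k=seqused_k,
    causal=True,
)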

Review comment: How did we compile this before?
Reply: I guess we don't use the problematic version yet?
https://github.com/vllm-project/vllm/blob/main/cmake/external_projects/vllm_flash_attn.cmake#L41
I will update the tag once this PR gets approved and merged.