Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/kernels/test_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from vllm.utils import seed_everything

NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
HEAD_SIZES = [128, 256]
HEAD_SIZES = [4, 60, 80, 128, 256]
BLOCK_SIZES = [16, 32]
DTYPES = [torch.float16, torch.bfloat16]
# one value large enough to test overflow in index calculation.
Expand Down
10 changes: 5 additions & 5 deletions vllm/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,8 @@ def _(
class FlashAttentionBackend(AttentionBackend):

@staticmethod
def get_supported_head_sizes() -> List[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
def get_max_supported_head_size() -> int:
return 256

@staticmethod
def get_name() -> str:
Expand Down Expand Up @@ -641,11 +641,11 @@ def __init__(
raise ValueError(
"Sliding window is not supported in FlashAttention.")

support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
if head_size not in support_head_sizes:
max_head_size = FlashAttentionBackend.get_max_supported_head_size()
if head_size <= 0 or head_size > max_head_size:
raise ValueError(
f"Head size {head_size} is not supported by FlashAttention. "
f"Supported head sizes are: {support_head_sizes}.")
f"Maximum supported head size is: {max_head_size}.")

def forward(
self,
Expand Down
4 changes: 2 additions & 2 deletions vllm/attention/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,8 @@ def which_attn_to_use(
from vllm.attention.backends.flash_attn import ( # noqa: F401
FlashAttentionBackend)

supported_sizes = FlashAttentionBackend.get_supported_head_sizes()
if head_size not in supported_sizes:
maximum_size = FlashAttentionBackend.get_max_supported_head_size()
if head_size <= 0 or head_size > maximum_size:
logger.info(
"Cannot use FlashAttention-2 backend for head size %d.",
head_size)
Expand Down
Loading