diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py
index 71f61c19dd95..3a0723c17a69 100644
--- a/tests/kernels/test_flash_attn.py
+++ b/tests/kernels/test_flash_attn.py
@@ -8,7 +8,7 @@ from vllm.utils import seed_everything
 
 NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
-HEAD_SIZES = [128, 256]
+HEAD_SIZES = [4, 60, 80, 128, 256]
 BLOCK_SIZES = [16, 32]
 DTYPES = [torch.float16, torch.bfloat16]
 # one value large enough to test overflow in index calculation.
diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py
index 22d07c0a4f68..2d7f5bb605e0 100644
--- a/vllm/attention/backends/flash_attn.py
+++ b/vllm/attention/backends/flash_attn.py
@@ -164,8 +164,8 @@ def _(
 class FlashAttentionBackend(AttentionBackend):
 
     @staticmethod
-    def get_supported_head_sizes() -> List[int]:
-        return [32, 64, 96, 128, 160, 192, 224, 256]
+    def get_max_supported_head_size() -> int:
+        return 256
 
     @staticmethod
     def get_name() -> str:
@@ -641,11 +641,11 @@ def __init__(
             raise ValueError(
                 "Sliding window is not supported in FlashAttention.")
 
-        support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
-        if head_size not in support_head_sizes:
+        max_head_size = FlashAttentionBackend.get_max_supported_head_size()
+        if head_size <= 0 or head_size > max_head_size:
             raise ValueError(
                 f"Head size {head_size} is not supported by FlashAttention. "
-                f"Supported head sizes are: {support_head_sizes}.")
+                f"Maximum supported head size is: {max_head_size}.")
 
     def forward(
         self,
diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py
index 30aa7cb311af..94946a68a3d0 100644
--- a/vllm/attention/selector.py
+++ b/vllm/attention/selector.py
@@ -248,8 +248,8 @@ def which_attn_to_use(
         from vllm.attention.backends.flash_attn import (  # noqa: F401
             FlashAttentionBackend)
 
-        supported_sizes = FlashAttentionBackend.get_supported_head_sizes()
-        if head_size not in supported_sizes:
+        maximum_size = FlashAttentionBackend.get_max_supported_head_size()
+        if head_size <= 0 or head_size > maximum_size:
             logger.info(
                 "Cannot use FlashAttention-2 backend for head size %d.",
                 head_size)
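
Note (not part of the diff): a minimal sketch of the behavioral change, contrasting the old allowlist check with the new bounds check. The two helper functions are hypothetical stand-ins for illustration, not vLLM APIs.

    # Sketch only: is_supported_before/is_supported_after are hypothetical
    # helpers mirroring the validation logic before and after this change.
    def is_supported_before(head_size: int) -> bool:
        # Old behavior: only head sizes on an explicit allowlist passed.
        return head_size in [32, 64, 96, 128, 160, 192, 224, 256]

    def is_supported_after(head_size: int) -> bool:
        # New behavior: any positive head size up to the maximum (256)
        # passes, which is why sizes like 4, 60, and 80 can now be
        # exercised in HEAD_SIZES in the test file.
        return 0 < head_size <= 256

    assert not is_supported_before(80) and is_supported_after(80)
    assert not is_supported_after(0) and not is_supported_after(257)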