From fed1ab603d440f4c9e3a6e9786f5ba7ab19906ec Mon Sep 17 00:00:00 2001
From: Vinay Damodaran
Date: Wed, 21 May 2025 17:57:22 -0700
Subject: [PATCH 1/6] Remove warning and use vllm flash attn

Signed-off-by: Vinay Damodaran
---
 vllm/model_executor/models/vision.py | 11 +----------
 1 file changed, 1 insertion(+), 10 deletions(-)

diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py
index 901d83ec5b9e..1fae2f1a4a57 100644
--- a/vllm/model_executor/models/vision.py
+++ b/vllm/model_executor/models/vision.py
@@ -83,16 +83,7 @@ def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
     if current_platform.is_cuda():
         device_available = current_platform.has_device_capability(80)
         if device_available and support_fa:
-            from transformers.utils import is_flash_attn_2_available
-            if is_flash_attn_2_available():
-                selected_backend = _Backend.FLASH_ATTN
-            else:
-                logger.warning_once(
-                    "Current `vllm-flash-attn` has a bug inside vision "
-                    "module, so we use xformers backend instead. You can "
-                    "run `pip install flash-attn` to use flash-attention "
-                    "backend.")
-                selected_backend = _Backend.XFORMERS
+            selected_backend = _Backend.FLASH_ATTN_VLLM_V1
         else:
             # For Volta and Turing GPUs, use xformers instead.
             selected_backend = _Backend.XFORMERS

From ccbbe757a87712743b5d6a398f77919673db3b63 Mon Sep 17 00:00:00 2001
From: Vinay Damodaran
Date: Wed, 21 May 2025 19:46:01 -0700
Subject: [PATCH 2/6] Fix backend check

Signed-off-by: Vinay Damodaran
---
 vllm/model_executor/models/qwen2_5_vl.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index 68dd07820189..a368ddc8a350 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -249,7 +249,7 @@ def __init__(
         # Detect attention implementation.
         self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
         if self.attn_backend not in {
-                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS
+                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.FLASH_ATTN_VLLM_V1,
         }:
             raise RuntimeError(
                 f"Qwen2.5-VL does not support {self.attn_backend} backend now."

From 80342532ec94f1100820d65e22cd84efdf678980 Mon Sep 17 00:00:00 2001
From: Vinay Damodaran
Date: Wed, 21 May 2025 19:55:37 -0700
Subject: [PATCH 3/6] Fix forward flash attn call

Signed-off-by: Vinay Damodaran
---
 vllm/model_executor/models/qwen2_5_vl.py | 18 ++++++++++++++++++
 vllm/model_executor/models/qwen2_vl.py   | 18 ++++++++++++++++++
 2 files changed, 36 insertions(+)

diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index a368ddc8a350..a52b24c6aece 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -307,6 +307,24 @@ def forward(
 
             q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
 
+            output = flash_attn_varlen_func(q,
+                                            k,
+                                            v,
+                                            cu_seqlens_q=cu_seqlens,
+                                            cu_seqlens_k=cu_seqlens,
+                                            max_seqlen_q=max_seqlen,
+                                            max_seqlen_k=max_seqlen,
+                                            dropout_p=0,
+                                            causal=False)
+
+            context_layer = rearrange(output,
+                                      "(b s) ... -> b s ...",
+                                      b=batch_size)
+        elif self.attn_backend == _Backend.FLASH_ATTN_VLLM_V1:
+            from vllm_flash_attn.flash_attn_interface import flash_attn_varlen_func
+
+            q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
-> (b s) ...") for x in [q, k, v]) + output = flash_attn_varlen_func(q, k, v, diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 0ff0836b0897..aa66be8b69c1 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -329,6 +329,24 @@ def forward( q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) + output = flash_attn_varlen_func(q, + k, + v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=0, + causal=False) + + context_layer = rearrange(output, + "(b s) ... -> b s ...", + b=batch_size) + elif self.attn_backend == _Backend.FLASH_ATTN_VLLM_V1: + from vllm_flash_attn.flash_attn_interface import flash_attn_varlen_func + + q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) + output = flash_attn_varlen_func(q, k, v, From 396643de3a5f339d3e9fcd877a19ef26737ca3de Mon Sep 17 00:00:00 2001 From: Vinay Damodaran Date: Thu, 22 May 2025 12:20:24 -0700 Subject: [PATCH 4/6] Fix import for vllm_flash_attn Signed-off-by: Vinay Damodaran --- vllm/model_executor/models/qwen2_5_vl.py | 2 +- vllm/model_executor/models/qwen2_vl.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index a52b24c6aece..5d596cadc9f6 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -321,7 +321,7 @@ def forward( "(b s) ... -> b s ...", b=batch_size) elif self.attn_backend == _Backend.FLASH_ATTN_VLLM_V1: - from vllm_flash_attn.flash_attn_interface import flash_attn_varlen_func + from vllm.vllm_flash_attn.flash_attn_interface import flash_attn_varlen_func q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index aa66be8b69c1..0e04479fcee1 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -343,7 +343,7 @@ def forward( "(b s) ... -> b s ...", b=batch_size) elif self.attn_backend == _Backend.FLASH_ATTN_VLLM_V1: - from vllm_flash_attn.flash_attn_interface import flash_attn_varlen_func + from vllm.vllm_flash_attn.flash_attn_interface import flash_attn_varlen_func q, k, v = (rearrange(x, "b s ... 
-> (b s) ...") for x in [q, k, v]) From 37e74464eeaea8d668b7485193e6b3a674a9cfeb Mon Sep 17 00:00:00 2001 From: Vinay Damodaran Date: Thu, 22 May 2025 12:57:51 -0700 Subject: [PATCH 5/6] Fix how max seq len is calculate Signed-off-by: Vinay Damodaran --- vllm/model_executor/models/qwen2_5_vl.py | 2 +- vllm/model_executor/models/qwen2_vl.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 5d596cadc9f6..5a67261d827a 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -656,7 +656,7 @@ def compute_attn_mask_seqlen( cu_seqlens: torch.Tensor, ) -> tuple[Optional[int], Optional[list[int]]]: max_seqlen, seqlens = None, None - if self.attn_backend == _Backend.FLASH_ATTN: + if self.attn_backend in [_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1]: max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() elif self.attn_backend == _Backend.XFORMERS: seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 0e04479fcee1..4db924b09975 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -636,7 +636,7 @@ def compute_attn_mask_seqlen( self, cu_seqlens: torch.Tensor ) -> tuple[Optional[int], Optional[list[int]]]: max_seqlen, seqlens = None, None - if self.attn_backend == _Backend.FLASH_ATTN: + if self.attn_backend == [_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1]: max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() elif self.attn_backend == _Backend.XFORMERS: seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() From 9465afa2b29b328d2dc6be4338708d681d064cf3 Mon Sep 17 00:00:00 2001 From: Vinay Damodaran Date: Thu, 22 May 2025 17:28:15 -0700 Subject: [PATCH 6/6] Try with flash attn Signed-off-by: Vinay Damodaran --- vllm/model_executor/models/vision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py index 1fae2f1a4a57..edba12276627 100644 --- a/vllm/model_executor/models/vision.py +++ b/vllm/model_executor/models/vision.py @@ -83,7 +83,7 @@ def get_vit_attn_backend(support_fa: bool = False) -> _Backend: if current_platform.is_cuda(): device_available = current_platform.has_device_capability(80) if device_available and support_fa: - selected_backend = _Backend.FLASH_ATTN_VLLM_V1 + selected_backend = _Backend.FLASH_ATTN else: # For Volta and Turing GPUs, use xformers instead. selected_backend = _Backend.XFORMERS