22 changes: 20 additions & 2 deletions vllm/model_executor/models/qwen2_5_vl.py
@@ -249,7 +249,7 @@
# Detect attention implementation.
self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
if self.attn_backend not in {
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.FLASH_ATTN_VLLM_V1,
}:
raise RuntimeError(
f"Qwen2.5-VL does not support {self.attn_backend} backend now."
@@ -307,6 +307,24 @@

q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

output = flash_attn_varlen_func(q,
k,
v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
dropout_p=0,
causal=False)

context_layer = rearrange(output,
"(b s) ... -> b s ...",
b=batch_size)
elif self.attn_backend == _Backend.FLASH_ATTN_VLLM_V1:
from vllm.vllm_flash_attn.flash_attn_interface import flash_attn_varlen_func

q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

Check failure (GitHub Actions / pre-commit): Ruff E501, vllm/model_executor/models/qwen2_5_vl.py:327:81: Line too long (88 > 80)
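A likely fix for the flagged line, not part of this diff, is to wrap the import in parentheses so it stays under 80 characters:
from vllm.vllm_flash_attn.flash_attn_interface import (
    flash_attn_varlen_func)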
output = flash_attn_varlen_func(q,
k,
v,
@@ -638,7 +656,7 @@
cu_seqlens: torch.Tensor,
) -> tuple[Optional[int], Optional[list[int]]]:
max_seqlen, seqlens = None, None
if self.attn_backend == _Backend.FLASH_ATTN:
if self.attn_backend in [_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1]:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
elif self.attn_backend == _Backend.XFORMERS:
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
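As a standalone illustration of what these two branches compute (hypothetical cu_seqlens values, runnable outside vLLM):

import torch

# cu_seqlens holds cumulative sequence lengths for packed sequences,
# e.g. three sequences of lengths 4, 2 and 6.
cu_seqlens = torch.tensor([0, 4, 6, 12])

# FLASH_ATTN / FLASH_ATTN_VLLM_V1 branch: length of the longest sequence.
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()  # 6

# XFORMERS branch: per-sequence lengths as a Python list.
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()  # [4, 2, 6]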
20 changes: 19 additions & 1 deletion vllm/model_executor/models/qwen2_vl.py
@@ -329,6 +329,24 @@

q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

output = flash_attn_varlen_func(q,
Review comment (Member): If flash_attn_varlen_func from vllm-flash-attn works well, we can remove the original FA implementation as well.

Review comment (Contributor): Should we keep the original FA implementation for other platforms, such as ROCm, which do not support FA through vllm-flash-attn? (One possible dispatch is sketched after this file's diff.)

k,
v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
dropout_p=0,
causal=False)

context_layer = rearrange(output,
"(b s) ... -> b s ...",
b=batch_size)
elif self.attn_backend == _Backend.FLASH_ATTN_VLLM_V1:
Review comment (Member), suggested change:
-    elif self.attn_backend == _Backend.FLASH_ATTN_VLLM_V1:
+    elif self.attn_backend in (_Backend.FLASH_ATTN_VLLM_V1, _Backend.FLASH_ATTN):

from vllm.vllm_flash_attn.flash_attn_interface import flash_attn_varlen_func

Check failure (GitHub Actions / pre-commit): Ruff E501, vllm/model_executor/models/qwen2_vl.py:346:81: Line too long (88 > 80)

q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

output = flash_attn_varlen_func(q,
k,
v,
@@ -618,7 +636,7 @@
self, cu_seqlens: torch.Tensor
) -> tuple[Optional[int], Optional[list[int]]]:
max_seqlen, seqlens = None, None
if self.attn_backend == _Backend.FLASH_ATTN:
if self.attn_backend in [_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1]:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
elif self.attn_backend == _Backend.XFORMERS:
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
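Regarding the review discussion above about dropping the duplicated branch while keeping ROCm working: a minimal sketch of one possible import-time dispatch, assuming upstream flash-attn and vllm-flash-attn expose a compatible flash_attn_varlen_func (not part of this PR):

def _get_flash_attn_varlen_func(backend):
    # Hedged sketch: backend is the _Backend enum value already used in this
    # module. Pick the varlen flash-attention kernel matching the selected
    # backend, so CUDA can use vllm-flash-attn while ROCm keeps the upstream
    # flash-attn package.
    if backend == _Backend.FLASH_ATTN_VLLM_V1:
        from vllm.vllm_flash_attn.flash_attn_interface import (
            flash_attn_varlen_func)
    else:  # _Backend.FLASH_ATTN, e.g. ROCm via upstream flash-attn
        from flash_attn import flash_attn_varlen_func
    return flash_attn_varlen_func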
11 changes: 1 addition & 10 deletions vllm/model_executor/models/vision.py
@@ -83,16 +83,7 @@ def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
if current_platform.is_cuda():
device_available = current_platform.has_device_capability(80)
if device_available and support_fa:
from transformers.utils import is_flash_attn_2_available
if is_flash_attn_2_available():
selected_backend = _Backend.FLASH_ATTN
else:
logger.warning_once(
"Current `vllm-flash-attn` has a bug inside vision "
"module, so we use xformers backend instead. You can "
"run `pip install flash-attn` to use flash-attention "
"backend.")
selected_backend = _Backend.XFORMERS
selected_backend = _Backend.FLASH_ATTN
else:
# For Volta and Turing GPUs, use xformers instead.
selected_backend = _Backend.XFORMERS
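For context, a minimal standalone sketch of the CUDA selection logic this hunk leaves in place (simplified; the function name here is an illustrative stand-in, and the non-CUDA branches are omitted):

def pick_vit_attn_backend(device_capability: int, support_fa: bool) -> str:
    # On Ampere or newer (compute capability >= 8.0) with FA requested,
    # select flash-attention directly; otherwise fall back to xformers,
    # as is still done for Volta and Turing GPUs.
    if device_capability >= 80 and support_fa:
        return "FLASH_ATTN"
    return "XFORMERS"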