Conversation
@vrdn-23 vrdn-23 commented May 22, 2025

FIX #18324


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small but essential subset of tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@Isotr0py Isotr0py (Member) left a comment


BTW, it would be best for this PR to have some benchmark to see the performance improvement as well.

@@ -329,6 +329,24 @@ def forward(

q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

output = flash_attn_varlen_func(q,
Member:

If flash_attn_varlen_func from vllm-flash-attn works well, we can remove the original FA implementation as well.

Contributor:

Should we keep the original FA implementation for other platforms like ROCm, which do not support FA through vllm-flash-attn?
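A minimal sketch of how such a fallback could be probed at import time (the module names and the selection logic here are illustrative assumptions, not vLLM's actual dispatch code):

```python
import importlib.util

def pick_flash_attn_source() -> str:
    """Return which flash-attention implementation is importable.

    Prefer vLLM's bundled kernels; fall back to the upstream flash_attn
    package (e.g. for ROCm); return "none" if neither is installed.
    """
    for name in ("vllm.vllm_flash_attn", "flash_attn"):
        try:
            if importlib.util.find_spec(name) is not None:
                return name
        except ModuleNotFoundError:
            # Parent package (e.g. vllm itself) is not installed at all.
            continue
    return "none"
```

Keeping the original implementation as the last branch of a probe like this would let ROCm users fall through cleanly without a hard import error.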

context_layer = rearrange(output,
"(b s) ... -> b s ...",
b=batch_size)
elif self.attn_backend == _Backend.FLASH_ATTN_VLLM_V1:
Member:
Suggested change
elif self.attn_backend == _Backend.FLASH_ATTN_VLLM_V1:
elif self.attn_backend in (_Backend.FLASH_ATTN_VLLM_V1, _Backend.FLASH_ATTN):

vrdn-23 added 5 commits May 22, 2025 13:00
Signed-off-by: Vinay Damodaran <[email protected]>
Signed-off-by: Vinay Damodaran <[email protected]>
Signed-off-by: Vinay Damodaran <[email protected]>
Signed-off-by: Vinay Damodaran <[email protected]>
@vrdn-23 vrdn-23 force-pushed the vrdn/remove-vllm-flash-attn-warning branch from 0486de7 to 37e7446 Compare May 22, 2025 20:00
@vrdn-23 (Contributor, Author) commented May 22, 2025

INFO 05-22 19:59:28 [default_loader.py:279] Loading weights took 4.66 seconds
INFO 05-22 19:59:29 [gpu_model_runner.py:1532] Model loading took 15.6271 GiB and 5.303416 seconds
INFO 05-22 19:59:32 [gpu_model_runner.py:1828] Encoder cache will be initialized with a budget of 49152 tokens, and profiled with 1 video items of the maximum feature size.
Unused or unrecognized kwargs: return_tensors, fps.
ERROR 05-22 19:59:34 [core.py:493] EngineCore failed to start.
ERROR 05-22 19:59:34 [core.py:493] Traceback (most recent call last):
ERROR 05-22 19:59:34 [core.py:493]   File "/home/vidamoda/vllm/vllm/v1/engine/core.py", line 484, in run_engine_core
ERROR 05-22 19:59:34 [core.py:493]     engine_core = EngineCoreProc(*args, **kwargs)
ERROR 05-22 19:59:34 [core.py:493]   File "/home/vidamoda/vllm/vllm/v1/engine/core.py", line 383, in __init__
ERROR 05-22 19:59:34 [core.py:493]     super().__init__(vllm_config, executor_class, log_stats,
ERROR 05-22 19:59:34 [core.py:493]   File "/home/vidamoda/vllm/vllm/v1/engine/core.py", line 78, in __init__
ERROR 05-22 19:59:34 [core.py:493]     self._initialize_kv_caches(vllm_config)
ERROR 05-22 19:59:34 [core.py:493]   File "/home/vidamoda/vllm/vllm/v1/engine/core.py", line 137, in _initialize_kv_caches
ERROR 05-22 19:59:34 [core.py:493]     available_gpu_memory = self.model_executor.determine_available_memory()
ERROR 05-22 19:59:34 [core.py:493]   File "/home/vidamoda/vllm/vllm/v1/executor/abstract.py", line 75, in determine_available_memory
ERROR 05-22 19:59:34 [core.py:493]     output = self.collective_rpc("determine_available_memory")
ERROR 05-22 19:59:34 [core.py:493]   File "/home/vidamoda/vllm/vllm/executor/uniproc_executor.py", line 56, in collective_rpc
ERROR 05-22 19:59:34 [core.py:493]     answer = run_method(self.driver_worker, method, args, kwargs)
ERROR 05-22 19:59:34 [core.py:493]   File "/home/vidamoda/vllm/vllm/utils.py", line 2598, in run_method
ERROR 05-22 19:59:34 [core.py:493]     return func(*args, **kwargs)
ERROR 05-22 19:59:34 [core.py:493]   File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 05-22 19:59:34 [core.py:493]     return func(*args, **kwargs)
ERROR 05-22 19:59:34 [core.py:493]   File "/home/vidamoda/vllm/vllm/v1/worker/gpu_worker.py", line 185, in determine_available_memory
ERROR 05-22 19:59:34 [core.py:493]     self.model_runner.profile_run()
ERROR 05-22 19:59:34 [core.py:493]   File "/home/vidamoda/vllm/vllm/v1/worker/gpu_model_runner.py", line 1848, in profile_run
ERROR 05-22 19:59:34 [core.py:493]     dummy_encoder_outputs = self.model.get_multimodal_embeddings(
ERROR 05-22 19:59:34 [core.py:493]   File "/home/vidamoda/vllm/vllm/model_executor/models/qwen2_5_vl.py", line 1063, in get_multimodal_embeddings
ERROR 05-22 19:59:34 [core.py:493]     video_embeddings = self._process_video_input(multimodal_input)
ERROR 05-22 19:59:34 [core.py:493]   File "/home/vidamoda/vllm/vllm/model_executor/models/qwen2_5_vl.py", line 1015, in _process_video_input
ERROR 05-22 19:59:34 [core.py:493]     video_embeds = self.visual(pixel_values_videos,
ERROR 05-22 19:59:34 [core.py:493]   File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
ERROR 05-22 19:59:34 [core.py:493]     return self._call_impl(*args, **kwargs)
ERROR 05-22 19:59:34 [core.py:493]   File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
ERROR 05-22 19:59:34 [core.py:493]     return forward_call(*args, **kwargs)
ERROR 05-22 19:59:34 [core.py:493]   File "/home/vidamoda/vllm/vllm/model_executor/models/qwen2_5_vl.py", line 746, in forward
ERROR 05-22 19:59:34 [core.py:493]     hidden_states = blk(
ERROR 05-22 19:59:34 [core.py:493]   File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
ERROR 05-22 19:59:34 [core.py:493]     return self._call_impl(*args, **kwargs)
ERROR 05-22 19:59:34 [core.py:493]   File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
ERROR 05-22 19:59:34 [core.py:493]     return forward_call(*args, **kwargs)
ERROR 05-22 19:59:34 [core.py:493]   File "/home/vidamoda/vllm/vllm/model_executor/models/qwen2_5_vl.py", line 413, in forward
ERROR 05-22 19:59:34 [core.py:493]     x = x + self.attn(self.norm1(x),
ERROR 05-22 19:59:34 [core.py:493]   File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
ERROR 05-22 19:59:34 [core.py:493]     return self._call_impl(*args, **kwargs)
ERROR 05-22 19:59:34 [core.py:493]   File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
ERROR 05-22 19:59:34 [core.py:493]     return forward_call(*args, **kwargs)
ERROR 05-22 19:59:34 [core.py:493]   File "/home/vidamoda/vllm/vllm/model_executor/models/qwen2_5_vl.py", line 328, in forward
ERROR 05-22 19:59:34 [core.py:493]     output = flash_attn_varlen_func(q,
ERROR 05-22 19:59:34 [core.py:493]   File "/home/vidamoda/vllm/vllm/vllm_flash_attn/flash_attn_interface.py", line 227, in flash_attn_varlen_func
ERROR 05-22 19:59:34 [core.py:493]     out, softmax_lse = torch.ops._vllm_fa2_C.varlen_fwd(
ERROR 05-22 19:59:34 [core.py:493]   File "/opt/conda/lib/python3.10/site-packages/torch/_ops.py", line 1158, in __call__
ERROR 05-22 19:59:34 [core.py:493]     return self._op(*args, **(kwargs or {}))
ERROR 05-22 19:59:34 [core.py:493] RuntimeError: This flash attention build does not support headdim not being a multiple of 32.

It looks like vllm-flash-attn hasn't picked up the upstream changes that support all head sizes. Is this something that would be fixed if we merged the latest main from FlashAttention?
cc @Isotr0py @LucasWilkinson @fyabc
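For illustration, the constraint reported by the RuntimeError above can be written as a predicate (the 256 upper bound is an assumption based on typical FA2 limits; a vision-tower head dim such as 80 fails the multiple-of-32 check):

```python
def fa2_build_supports_head_dim(head_dim: int) -> bool:
    # Per the RuntimeError above, this vllm-flash-attn FA2 build only ships
    # kernels for head dims that are multiples of 32. The <= 256 bound is an
    # assumption, not taken from the error message.
    return 0 < head_dim <= 256 and head_dim % 32 == 0
```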

@vrdn-23 (Contributor, Author) commented May 22, 2025

The check seems to be happening here

Signed-off-by: Vinay Damodaran <[email protected]>
@vrdn-23 (Contributor, Author) commented May 23, 2025

Unused or unrecognized kwargs: return_tensors, fps.
WARNING 05-23 00:35:41 [topk_topp_sampler.py:58] FlashInfer is not available. Falling back to the PyTorch-native implementation of top-p & top-k sampling. For the best performance, please install FlashInfer.
INFO 05-23 00:35:41 [gpu_model_runner.py:1514] Starting to load model reducto/RolmOCR...
INFO 05-23 00:35:42 [backends.py:37] Using InductorAdaptor
INFO 05-23 00:35:42 [weight_utils.py:291] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:01<00:04,  1.34s/it]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:01<00:01,  1.16it/s]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:03<00:01,  1.07s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:04<00:00,  1.19s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:04<00:00,  1.14s/it]

INFO 05-23 00:35:46 [default_loader.py:279] Loading weights took 4.67 seconds
INFO 05-23 00:35:47 [gpu_model_runner.py:1532] Model loading took 15.6271 GiB and 5.155356 seconds
INFO 05-23 00:35:50 [gpu_model_runner.py:1828] Encoder cache will be initialized with a budget of 49152 tokens, and profiled with 1 video items of the maximum feature size.
Unused or unrecognized kwargs: return_tensors, fps.
INFO 05-23 00:36:02 [backends.py:459] Using cache directory: /home/vidamoda/.cache/vllm/torch_compile_cache/4ce3341006/rank_0_0 for vLLM's torch.compile
INFO 05-23 00:36:02 [backends.py:469] Dynamo bytecode transform time: 7.22 s
INFO 05-23 00:36:05 [backends.py:160] Cache the graph of shape None for later use
INFO 05-23 00:36:35 [backends.py:172] Compiling a graph for general shape takes 32.12 s
INFO 05-23 00:36:44 [monitor.py:33] torch.compile takes 39.34 s in total
INFO 05-23 00:36:44 [kv_cache_utils.py:639] GPU KV cache size: 306,208 tokens
INFO 05-23 00:36:44 [kv_cache_utils.py:642] Maximum concurrency for 80,000 tokens per request: 3.83x
INFO 05-23 00:37:14 [gpu_model_runner.py:1895] Graph capturing finished in 30 secs, took 0.50 GiB
INFO 05-23 00:37:14 [core.py:167] init engine (profile, create kv cache, warmup model) took 87.59 seconds
INFO 05-23 00:37:18 [loggers.py:134] vllm cache_config_info with initialization after num_gpu_blocks is: 19138
WARNING 05-23 00:37:19 [config.py:1333] Default sampling parameters have been overridden by the model's Hugging Face generation config recommended from the model creator. If this is not intended, please relaunch vLLM instance with `--generation-config vllm`.
INFO 05-23 00:37:19 [serving_chat.py:117] Using default chat sampling params from model: {'repetition_penalty': 1.05, 'temperature': 1e-06}
INFO 05-23 00:37:19 [serving_completion.py:65] Using default completion sampling params from model: {'repetition_penalty': 1.05, 'temperature': 1e-06}
INFO 05-23 00:37:19 [api_server.py:1336] Starting vLLM API server on http://0.0.0.0:8000
INFO 05-23 00:37:19 [launcher.py:28] Available routes are:
INFO 05-23 00:37:19 [launcher.py:36] Route: /health, Methods: GET
INFO 05-23 00:37:19 [launcher.py:36] Route: /load, Methods: GET
INFO 05-23 00:37:19 [launcher.py:36] Route: /ping, Methods: POST
INFO 05-23 00:37:19 [launcher.py:36] Route: /ping, Methods: GET
INFO 05-23 00:37:19 [launcher.py:36] Route: /tokenize, Methods: POST
INFO 05-23 00:37:19 [launcher.py:36] Route: /detokenize, Methods: POST
INFO 05-23 00:37:19 [launcher.py:36] Route: /v1/models, Methods: GET
INFO 05-23 00:37:19 [launcher.py:36] Route: /version, Methods: GET
INFO 05-23 00:37:19 [launcher.py:36] Route: /v1/chat/completions, Methods: POST
INFO 05-23 00:37:19 [launcher.py:36] Route: /v1/completions, Methods: POST
INFO 05-23 00:37:19 [launcher.py:36] Route: /v1/embeddings, Methods: POST
INFO 05-23 00:37:19 [launcher.py:36] Route: /pooling, Methods: POST
INFO 05-23 00:37:19 [launcher.py:36] Route: /classify, Methods: POST
INFO 05-23 00:37:19 [launcher.py:36] Route: /score, Methods: POST
INFO 05-23 00:37:19 [launcher.py:36] Route: /v1/score, Methods: POST
INFO 05-23 00:37:19 [launcher.py:36] Route: /v1/audio/transcriptions, Methods: POST
INFO 05-23 00:37:19 [launcher.py:36] Route: /rerank, Methods: POST
INFO 05-23 00:37:19 [launcher.py:36] Route: /v1/rerank, Methods: POST
INFO 05-23 00:37:19 [launcher.py:36] Route: /v2/rerank, Methods: POST
INFO 05-23 00:37:19 [launcher.py:36] Route: /invocations, Methods: POST
INFO 05-23 00:37:19 [launcher.py:36] Route: /metrics, Methods: GET
INFO:     Started server process [146810]
INFO:     Waiting for application startup.
INFO:     Application startup complete.

It loads fine with the normal flash-attn, so the commit I highlighted can't be the issue. I'll try to dig a little deeper into this.

@vrdn-23 (Contributor, Author) commented May 23, 2025

Okay, I did a little digging: PR #8910 was opened to fix that issue, and it relied on vllm-project/flash-attention#21 to actually disable that flag and build with the new head sizes.

It seems the conversation went stale after @njhill and @WoosukKwon discussed it for a bit. Is there any reason we can't go ahead with this and build for all head sizes? That seems more pragmatic than requiring users who rely on a single Docker image to install both vllm-flash-attn and flash-attn to run all models seamlessly.

cc @DarkLight1337 @Isotr0py @LucasWilkinson

@vrdn-23 (Contributor, Author) commented May 27, 2025

Just wanted to bring this back up in case anyone can offer direction on the path forward. If enabling this in vllm-flash-attn is too much overhead, I can close the PR.
cc @DarkLight1337 @Isotr0py @LucasWilkinson @njhill @WoosukKwon

@LucasWilkinson (Collaborator)

Unfortunately we don't have much wheel size to spare with the recent Blackwell addition; I am going OOO for the next 2 weeks but I can investigate the wheel size impacts when I get back.

It looks like this is mostly an FA2 problem, so we can potentially use vllm-flash-attn on Hopper. For Ampere users, I think installing FA is not a huge deal; we could potentially make it a dependency of vLLM so it auto-installs, or add it to the released Docker images. (Thoughts, @mgoin @khluu?)
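A hypothetical sketch of the per-architecture dispatch described above (the function name, backend strings, and the SM-9.x threshold for Hopper are assumptions, not vLLM code):

```python
def pick_vision_attn_backend(compute_capability: tuple,
                             upstream_fa_installed: bool) -> str:
    # Prefer the bundled vllm-flash-attn on Hopper (SM 9.x); on older GPUs
    # such as Ampere (SM 8.x), fall back to the upstream flash-attn package
    # if it is installed, otherwise to a non-FA backend.
    major, _minor = compute_capability
    if major >= 9:
        return "vllm-flash-attn"
    if upstream_fa_installed:
        return "flash-attn"
    return "torch-sdpa"
```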

@effortprogrammer
So, if I understand this issue correctly, I can use the flash_attn backend with a multimodal model via the original FlashAttention 2, correct?

@mergify mergify bot added the qwen Related to Qwen models label Jun 19, 2025
Labels: qwen (Related to Qwen models)
Projects: None yet
Development: Successfully merging this pull request may close these issues:
[Bug]: Clarification regarding bug inside vllm-flash-attn vision module
5 participants