[main] mlp weight prefetch in Qwen Dense Models #2762
Conversation
Signed-off-by: rjg-lyh <[email protected]>
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
Code Review
This pull request introduces MLP weight prefetching for Qwen Dense Models to optimize performance, primarily in the decode phase. This is achieved by adding new flashcomm and dense_optimize features, controlled by environment variables. The changes include new custom operators for communication and specialized linear layers that use these operators.
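As an aside, an environment-variable feature gate like the ones described above is usually just a small parser over `os.environ`. The sketch below is illustrative only — the PR's actual parsing lives in `vllm_ascend/envs.py` and may differ; the helper name `env_flag` is hypothetical.

```python
import os

# Hypothetical helper: read a boolean feature flag such as
# VLLM_ASCEND_ENABLE_FLASHCOMM from the environment. The real
# project's envs module may parse values differently.
def env_flag(name: str, default: bool = False) -> bool:
    val = os.environ.get(name)
    if val is None:
        return default
    return val.strip().lower() in ("1", "true", "yes")
```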
While the overall approach seems promising for performance, I've found several critical issues that must be addressed. There are broken imports in vllm_ascend/ops/linear.py and vllm_ascend/worker/model_runner_v1.py that will prevent the code from running. Additionally, the logic in AscendDenseQKVParallelLinear for identifying the layer number is brittle. I've also pointed out a magic number that should be refactored into a named constant for better maintainability. Please review the comments for details.
vllm_ascend/ops/linear.py
from vllm_ascend.utils import (all_gather_and_maybe_unpad,
                               maybe_pad_and_reduce_scatter)
# Matrix multiply.
assert self.quant_method is not None

layer_num = self.prefix.split('.')[2]
The logic to determine the layer number by splitting the prefix string (self.prefix.split('.')[2]) is very brittle and assumes a fixed model-architecture naming scheme. This can easily break if a model with a different naming convention is used (e.g., model.decoder.layers.0...), leading to incorrect behavior or crashes. It would be more robust to pass the layer index or an is_first_layer flag explicitly during the layer's initialization.
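If the prefix must still be parsed, scanning it for the first numeric component is one way to avoid hard-coding a position. The helper below is a hypothetical sketch, not code from the PR:

```python
# Hypothetical helper: recover the layer index from a dotted module
# prefix without assuming it sits at a fixed position, so both
# "model.layers.0.self_attn.qkv_proj" and
# "model.decoder.layers.0.self_attn.qkv_proj" resolve correctly.
def extract_layer_index(prefix: str) -> int:
    for part in prefix.split('.'):
        if part.isdigit():
            return int(part)
    raise ValueError(f"no layer index found in prefix {prefix!r}")
```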
if get_forward_context().flashcomm_v1_enabled:
    from vllm_ascend.utils import all_gather_and_maybe_unpad
    hidden_states = all_gather_and_maybe_unpad(
        hidden_states, get_forward_context().pad_size, dim=0)
The function all_gather_and_maybe_unpad is imported from vllm_ascend.utils, but it is not defined there, which will cause an ImportError. Furthermore, the call all_gather_and_maybe_unpad(hidden_states, get_forward_context().pad_size, dim=0) does not match the signature of the related custom op maybe_all_gather_and_maybe_unpad(x: Tensor, label: bool). You should probably be calling torch.ops.vllm.maybe_all_gather_and_maybe_unpad with the correct arguments:
if get_forward_context().flashcomm_v1_enabled:
    hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
        hidden_states, True)
flashcomm_v1_enabled = envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM and \
    num_tokens is not None and num_tokens > 1000
The value 1000 is a magic number used as the threshold to enable the flashcomm_v1 optimization. This makes the code harder to understand and maintain. It should be defined as a named constant with a comment explaining its purpose and how the value was determined, which would improve readability and make the threshold easier to tune in the future.
# e.g. FLASHCOMM_V1_TOKEN_THRESHOLD = 1000 at the top of the file
flashcomm_v1_enabled = envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM and \
    num_tokens is not None and num_tokens > FLASHCOMM_V1_TOKEN_THRESHOLD
Force-pushed from cdedbf9 to c19031d.
Signed-off-by: rjg-lyh <[email protected]>
This pull request has conflicts, please resolve those before we can evaluate the pull request.
What this PR does / why we need it?
This PR prefetches the weights of MLP layers in Qwen Dense Models, mainly to optimize performance in the decode phase.
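The general pattern behind weight prefetching is to fetch the next layer's weights while the current layer computes, hiding the transfer latency. The sketch below is conceptual only and is not the PR's Ascend implementation: on an NPU the prefetch would be an asynchronous device copy on a side stream, whereas here a single worker thread stands in for it, and run_layers, fetch_weights, and compute are all illustrative names.

```python
from concurrent.futures import ThreadPoolExecutor

# Conceptual sketch of compute/prefetch overlap: while layer i runs,
# the weights for layer i + 1 are fetched in the background.
def run_layers(layers, fetch_weights, compute):
    if not layers:
        return []
    results = []
    with ThreadPoolExecutor(max_workers=1) as pool:
        future = pool.submit(fetch_weights, layers[0])
        for i, layer in enumerate(layers):
            weights = future.result()  # wait for this layer's weights
            if i + 1 < len(layers):
                # kick off the prefetch of the next layer's weights
                future = pool.submit(fetch_weights, layers[i + 1])
            results.append(compute(layer, weights))
    return results
```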
Does this PR introduce any user-facing change?
No.
How was this patch tested?
CI passed with newly added and existing tests.