From 3ae9205e2919f4b1fac575bc2df1d3196fdfb8a1 Mon Sep 17 00:00:00 2001 From: rjg-lyh <1318825571@qq.com> Date: Thu, 4 Sep 2025 23:26:52 +0800 Subject: [PATCH 1/3] [main] flashcomm_v1 before mlp weight prefetch in Qwen Dense Models Signed-off-by: rjg-lyh <1318825571@qq.com> --- .../test_offline_inference_distributed.py | 24 ++++ tests/ut/ops/test_layernorm.py | 38 +++++- vllm_ascend/__init__.py | 5 + vllm_ascend/ascend_forward_context.py | 15 +++ vllm_ascend/envs.py | 7 ++ vllm_ascend/ops/__init__.py | 1 + vllm_ascend/ops/flashcomm_gate_ops.py | 71 +++++++++++ vllm_ascend/ops/layernorm.py | 4 + vllm_ascend/ops/linear.py | 114 +++++++++++++++++- vllm_ascend/utils.py | 13 +- vllm_ascend/worker/model_runner_v1.py | 4 + 11 files changed, 286 insertions(+), 10 deletions(-) create mode 100644 vllm_ascend/ops/flashcomm_gate_ops.py diff --git a/tests/e2e/multicard/test_offline_inference_distributed.py b/tests/e2e/multicard/test_offline_inference_distributed.py index a90c8643de..5f6588ea92 100644 --- a/tests/e2e/multicard/test_offline_inference_distributed.py +++ b/tests/e2e/multicard/test_offline_inference_distributed.py @@ -23,6 +23,8 @@ import os from unittest.mock import patch +import pytest + from modelscope import snapshot_download # type: ignore from vllm import SamplingParams @@ -30,6 +32,8 @@ os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256" +QWEN_DENSE_MODELS = ["Qwen/QwQ-32B", "Qwen/Qwen-32B"] + def test_models_distributed_QwQ(): example_prompts = [ @@ -150,3 +154,23 @@ def test_sp_for_qwen3_moe() -> None: enable_expert_parallel=True, enforce_eager=True) as vllm_model: vllm_model.generate(example_prompts, sampling_params) + + +@pytest.mark.parametrize("enforce_eager", [True, False]) +@pytest.mark.parametrize("model", QWEN_DENSE_MODELS) +@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE": "1"}) +@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM": "1"}) +def test_models_distributed_Qwen_Dense_with_flashcomm_v1(model, enforce_eager): + example_prompts = [ + "Hello, my name is", + ] + max_tokens = 5 + + with VllmRunner( + snapshot_download(model), + max_model_len=8192, + enforce_eager=enforce_eager, + dtype="auto", + tensor_parallel_size=4, + ) as vllm_model: + vllm_model.generate_greedy(example_prompts, max_tokens) diff --git a/tests/ut/ops/test_layernorm.py b/tests/ut/ops/test_layernorm.py index c7bc657f10..c13dea5310 100644 --- a/tests/ut/ops/test_layernorm.py +++ b/tests/ut/ops/test_layernorm.py @@ -1,8 +1,10 @@ -from unittest.mock import patch +import os +from unittest.mock import patch, MagicMock import pytest import torch from vllm.model_executor.layers.layernorm import RMSNorm +from vllm_ascend.ascend_forward_context import set_ascend_forward_context @pytest.fixture @@ -10,6 +12,13 @@ def dummy_tensor(): return torch.randn(4, 8, dtype=torch.float16) +def mock_maybe_chunk_residual(x, residual): + if x.size(0) != residual.size(0): + return residual[:4] + + return residual + + def mock_rms_norm(x, weight, eps): return x + 1, None @@ -23,11 +32,12 @@ def mock_add_rms_norm(x, residual, weight, eps): [None, torch.randn(4, 8, dtype=torch.float32)]) @patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm) @patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm) -def test_RMSNorm_forward(mock_add_rmsnorm, mock_rmsnorm, is_310p_return, - residual, dummy_tensor): +@patch("torch.ops.vllm.maybe_chunk_residual", side_effect=mock_maybe_chunk_residual) +def test_RMSNorm_forward(mock_maybe_chunk_residual, mock_add_rmsnorm, mock_rmsnorm, + 
is_310p_return,residual, dummy_tensor): with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return): - layer = RMSNorm(hidden_size=32, eps=1e-05) + layer = RMSNorm(hidden_size=8, eps=1e-05) if residual is not None: out_x, out_residual = layer.forward_oot(dummy_tensor, residual) @@ -51,3 +61,23 @@ def test_RMSNorm_forward(mock_add_rmsnorm, mock_rmsnorm, is_310p_return, mock_rmsnorm.assert_called_once() assert torch.allclose(out_x, expected_out_x) + + +@patch("vllm_ascend.utils.is_310p", return_value=False) +@patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm) +@patch("torch.ops.vllm.maybe_chunk_residual", side_effect=mock_maybe_chunk_residual) +def test_RMSNorm_forward_with_flashcomm_v1(mock_maybe_chunk_residual, mock_add_rms_norm, mock_is310p): + x = torch.randn(4, 512, dtype=torch.bfloat16) + residual = torch.randn(16, 512, dtype=torch.bfloat16) + layer = RMSNorm(hidden_size=512, eps=1e-05) + + out_x, out_residual = layer.forward_oot(x, residual) + + expected_out_x = 2 * x + expected_out_residual = 2 * residual[:4] + + mock_maybe_chunk_residual.assert_called_once() + mock_add_rms_norm.assert_called_once() + assert out_residual.size(0) == 4 + assert torch.allclose(out_x, expected_out_x) + assert torch.allclose(out_residual, expected_out_residual) diff --git a/vllm_ascend/__init__.py b/vllm_ascend/__init__.py index 7588e70ed9..789bfc07ae 100644 --- a/vllm_ascend/__init__.py +++ b/vllm_ascend/__init__.py @@ -23,5 +23,10 @@ def register(): def register_model(): + import vllm.envs as envs + + import vllm_ascend.envs as envs_ascend + from .models import register_model + register_model() diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index 601f33a500..4439ad1795 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -105,6 +105,21 @@ def set_ascend_forward_context( # due to multiple warmups before actual capturing forward_context.capturing = False + # set this for rope forward_oot using + forward_context.is_first_layer = True + + # set for flashcomm_v1 + flashcomm_v1_enabled = envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM and \ + num_tokens is not None and num_tokens > 1000 + + if flashcomm_v1_enabled: + tp_world_size = get_tensor_model_parallel_world_size() + pad_size = (tp_world_size - + (num_tokens % tp_world_size)) % tp_world_size + forward_context.pad_size = pad_size + + forward_context.flashcomm_v1_enabled = flashcomm_v1_enabled + if num_tokens is None and attn_metadata is not None: num_tokens = attn_metadata.num_actual_tokens diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index 78f8c50f8e..c5c5fcdfca 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -131,6 +131,13 @@ # this feature is supported in A2, and eager mode will get better performance. "VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE": lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE", '0'))), + # Whether to enable FlashComm optimization when tensor parallel is enabled. + # this feature will get better performance in prefill phase. + "VLLM_ASCEND_ENABLE_FLASHCOMM": + lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))), + # Whether to enable dense model and general optimizations for better performance. + "VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE": + lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE", '0'))), # Whether to enable mlp optimize when tensor parallel is enabled. # this feature in eager mode will get better performance. 
"VLLM_ASCEND_ENABLE_MLP_OPTIMIZE": diff --git a/vllm_ascend/ops/__init__.py b/vllm_ascend/ops/__init__.py index a1e7417b07..ffb9b68bac 100644 --- a/vllm_ascend/ops/__init__.py +++ b/vllm_ascend/ops/__init__.py @@ -24,6 +24,7 @@ from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul from vllm_ascend.ops.rotary_embedding import ( AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding) +import vllm_ascend.ops.flashcomm_gate_ops class dummyFusionOp: diff --git a/vllm_ascend/ops/flashcomm_gate_ops.py b/vllm_ascend/ops/flashcomm_gate_ops.py new file mode 100644 index 0000000000..e2a77239fc --- /dev/null +++ b/vllm_ascend/ops/flashcomm_gate_ops.py @@ -0,0 +1,71 @@ +import torch +import torch.nn.functional as F +from vllm.utils import direct_register_custom_op +from vllm.distributed import (tensor_model_parallel_all_gather, + tensor_model_parallel_reduce_scatter, + tensor_model_parallel_all_reduce, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.forward_context import get_forward_context + + +def _maybe_chunk_residual_impl(x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor: + if x.size(0) != residual.size(0): + flashcomm_v1_enabled = get_forward_context().flashcomm_v1_enabled + assert flashcomm_v1_enabled is True, ( + "Currently, this situation only occurs " + "when flashcomm_v1 is enabled" + ) + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + residual = torch.chunk(residual, tp_size, dim=0)[tp_rank] + + return residual + + +def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor, label: bool) -> torch.Tensor: + flashcomm_v1_enabled = get_forward_context().flashcomm_v1_enabled + if flashcomm_v1_enabled and label: + x = tensor_model_parallel_all_gather(x, 0) + pad_size = get_forward_context().pad_size + if pad_size > 0: + x = x[:-pad_size, :] + return x + + +def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor: + flashcomm_v1_enabled = get_forward_context().flashcomm_v1_enabled + if flashcomm_v1_enabled: + pad_size = get_forward_context().pad_size + if pad_size > 0: + x = F.pad(x, (0, 0, 0, pad_size)) + return tensor_model_parallel_reduce_scatter(x, 0) + else: + return tensor_model_parallel_all_reduce(x) + + +direct_register_custom_op( + op_name="maybe_chunk_residual", + op_func=_maybe_chunk_residual_impl, + fake_impl=lambda x, residual: residual, + mutates_args=[], + dispatch_key="PrivateUse1" +) + + +direct_register_custom_op( + op_name="maybe_all_gather_and_maybe_unpad", + op_func=_maybe_all_gather_and_maybe_unpad_impl, + fake_impl=lambda x, label: x, + mutates_args=[], + dispatch_key="PrivateUse1" +) + + +direct_register_custom_op( + op_name="maybe_pad_and_reduce", + op_func=_maybe_pad_and_reduce_impl, + fake_impl=lambda x: x, + mutates_args=[], + dispatch_key="PrivateUse1" +) diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py index 4f0b550e98..01c9dcbbe8 100644 --- a/vllm_ascend/ops/layernorm.py +++ b/vllm_ascend/ops/layernorm.py @@ -44,6 +44,8 @@ def forward( import torch_npu if residual is not None: + residual = torch.ops.vllm.maybe_chunk_residual(x, residual) + assert x.size(0) == residual.size(0) x, _, residual = torch_npu.npu_add_rms_norm_quant( x, residual, @@ -69,6 +71,8 @@ def forward_oot( from vllm_ascend.utils import is_310p if residual is not None: + residual = torch.ops.vllm.maybe_chunk_residual(x, residual) + assert x.size(0) == residual.size(0) if is_310p(): orig_dtype = residual.dtype x = x + residual.to(x.dtype) diff --git 
a/vllm_ascend/ops/linear.py b/vllm_ascend/ops/linear.py index e2f427e8b5..266392e7ab 100644 --- a/vllm_ascend/ops/linear.py +++ b/vllm_ascend/ops/linear.py @@ -24,11 +24,7 @@ split_tensor_along_last_dim, tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) -from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED, - ColumnParallelLinear, - LinearBase, - MergedColumnParallelLinear, - RowParallelLinear) +from vllm.forward_context import get_forward_context from vllm.model_executor.layers.quantization.base_config import \ QuantizationConfig from vllm.model_executor.utils import set_weight_attrs @@ -36,6 +32,14 @@ from vllm_ascend.distributed.parallel_state import ( get_mlp_tensor_model_parallel_rank, get_mlp_tensor_model_parallel_world_size, get_mlp_tp_group) +from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod +from vllm_ascend.utils import (all_gather_and_maybe_unpad, + maybe_pad_and_reduce_scatter) + +from vllm.model_executor.layers.linear import ( # isort: skip + WEIGHT_LOADER_V2_SUPPORTED, ColumnParallelLinear, LinearBase, + MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear, + UnquantizedLinearMethod) class AscendMlpColumnParallelLinear(ColumnParallelLinear): @@ -307,3 +311,103 @@ def forward( if not self.return_bias: return output return output, output_bias + + +class AscendDenseMergedColumnParallelLinear(MergedColumnParallelLinear): + """Linear layer with column parallelism. + + Implemented multiple optimization projects for dense models, such as FlashComm and + communication-computation fusion. + """ + + def forward( + self, input_: torch.Tensor + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: + bias = self.bias if not self.skip_bias_add else None + + # Matrix multiply. + assert self.quant_method is not None + + input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True) + output_parallel = self.quant_method.apply(self, input_, bias) + + if self.gather_output: + # All-gather across the partitions. + output = tensor_model_parallel_all_gather(output_parallel) + else: + output = output_parallel + output_bias = self.bias if self.skip_bias_add else None + if not self.return_bias: + return output + return output, output_bias + + +class AscendDenseQKVParallelLinear(QKVParallelLinear): + """Linear layer with column parallelism. + + Implemented multiple optimization projects for dense models, such as FlashComm and + communication-computation fusion. + """ + + def forward( + self, input_: torch.Tensor + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: + bias = self.bias if not self.skip_bias_add else None + + # Matrix multiply. + assert self.quant_method is not None + + layer_num = self.prefix.split('.')[2] + + input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( + input_, layer_num != '0') + output_parallel = self.quant_method.apply(self, input_, bias) + + if self.gather_output: + # All-gather across the partitions. + output = tensor_model_parallel_all_gather(output_parallel) + else: + output = output_parallel + output_bias = self.bias if self.skip_bias_add else None + if not self.return_bias: + return output + return output, output_bias + + +class AscendDenseRowParallelLinear(RowParallelLinear): + """Linear layer with row parallelism. + + Implemented multiple optimization projects for dense models, such as FlashComm and + communication-computation fusion. 
+ """ + + def forward( + self, input_: torch.Tensor + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: + if self.input_is_parallel: + input_parallel = input_ + else: + tp_rank = get_tensor_model_parallel_rank() + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.tp_size) + input_parallel = splitted_input[tp_rank].contiguous() + + # Matrix multiply. + assert self.quant_method is not None + # Only fuse bias add into GEMM for rank 0 (this ensures that + # bias will not get added more than once in TP>1 case) + bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias + + if self.tp_size == 1 or not self.reduce_results: + output = self.quant_method.apply(self, input_parallel, bias=bias_) + else: + output_parallel = self.quant_method.apply(self, + input_parallel, + bias=bias_) + output = torch.ops.vllm.maybe_pad_and_reduce(output_parallel) + + output_bias = self.bias if self.skip_bias_add else None + + if not self.return_bias: + return output + return output, output_bias diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index f3c1aef09e..86ad507858 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -491,7 +491,10 @@ def register_ascend_customop(): from vllm.model_executor.custom_op import CustomOp from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul - from vllm_ascend.ops.linear import (AscendMlpColumnParallelLinear, + from vllm_ascend.ops.linear import (AscendDenseMergedColumnParallelLinear, + AscendDenseQKVParallelLinear, + AscendDenseRowParallelLinear, + AscendMlpColumnParallelLinear, AscendMlpMergedColumnParallelLinear, AscendMlpRowParallelLinear) from vllm_ascend.ops.rotary_embedding import ( @@ -521,6 +524,14 @@ def register_ascend_customop(): CustomOp.register_oot( _decorated_op_cls=AscendMlpMergedColumnParallelLinear, name="MergedColumnParallelLinear") + if envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE: + CustomOp.register_oot( + _decorated_op_cls=AscendDenseMergedColumnParallelLinear, + name="MergedColumnParallelLinear") + CustomOp.register_oot(_decorated_op_cls=AscendDenseQKVParallelLinear, + name="QKVParallelLinear") + CustomOp.register_oot(_decorated_op_cls=AscendDenseRowParallelLinear, + name="RowParallelLinear") from vllm_ascend.ops.layernorm import AscendRMSNorm CustomOp.register_oot(_decorated_op_cls=AscendRMSNorm, name="RMSNorm") diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 3068d36d0d..3701a09819 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1143,6 +1143,10 @@ def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) + if get_forward_context().flashcomm_v1_enabled: + from vllm_ascend.utils import all_gather_and_maybe_unpad + hidden_states = all_gather_and_maybe_unpad( + hidden_states, get_forward_context().pad_size, dim=0) return hidden_states def _build_attn_state(self, num_reqs, num_scheduled_tokens, From 5ffd8db8355ca7347e1e32a08a88f382d8dcaf0b Mon Sep 17 00:00:00 2001 From: Shuming19 <313093131@qq.com> Date: Tue, 2 Sep 2025 16:43:40 +0800 Subject: [PATCH 2/3] add mlp weight prefetch --- vllm_ascend/ascend_forward_context.py | 9 +++++++- vllm_ascend/ops/activation.py | 26 ++++++++++++++++++++++ vllm_ascend/ops/layernorm.py | 14 ++++++++++++ vllm_ascend/ops/linear.py | 31 +++++++++++++++++++++++++++ vllm_ascend/worker/model_runner_v1.py | 9 ++++++-- vllm_ascend/worker/worker_v1.py | 6 
++++++ 6 files changed, 92 insertions(+), 3 deletions(-) diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index 4439ad1795..c25fe5553d 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -67,7 +67,9 @@ def set_ascend_forward_context( moe_comm_method: str = "", num_actual_tokens: Optional[int] = None, aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, - batch_descriptor: Optional[BatchDescriptor] = None): + batch_descriptor: Optional[BatchDescriptor] = None, + prefetch_stream: torch.npu.Stream = None, + prefetch_model: torch.nn.Module = None): """A context manager that stores the current forward context, can be attention metadata, etc. We add some additional param into forward_context. @@ -82,6 +84,11 @@ def set_ascend_forward_context( batch_descriptor=batch_descriptor, ): forward_context = get_forward_context() + + forward_context.prefetch_stream = prefetch_stream + forward_context.prefetch_model = prefetch_model + forward_context.prefetch_mlp_up = False + forward_context.moe_comm_method_name = moe_comm_method + "commimpl" forward_context.with_prefill = with_prefill ep_size = (get_ep_group().world_size if diff --git a/vllm_ascend/ops/activation.py b/vllm_ascend/ops/activation.py index 26082fea28..0ad2b69c95 100644 --- a/vllm_ascend/ops/activation.py +++ b/vllm_ascend/ops/activation.py @@ -17,6 +17,7 @@ import torch from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul +from vllm.forward_context import get_forward_context class AscendQuickGELU(QuickGELU): @@ -29,6 +30,26 @@ def forward_oot(self, x: torch.tensor) -> torch.Tensor: class AscendSiluAndMul(SiluAndMul): + def prefetch_down_proj(self, + dependency: torch.Tensor): + import torch_npu + forward_context = get_forward_context() + prefetch_model = forward_context.prefetch_model + prefetch_stream = forward_context.prefetch_stream + layer_idx = forward_context.layer_idx + + prefetch_stream.wait_stream(torch.npu.current_stream()) + + with torch.npu.stream(prefetch_stream): + MLP_DOWN_PREFETCH_SIZE = 6 * 1024 * 1024 + torch_npu.npu_prefetch(prefetch_model.model.layers[layer_idx].mlp.down_proj.weight, \ + dependency, MLP_DOWN_PREFETCH_SIZE) + forward_context.layer_idx += 1 + + def wait_prefetch_done(self): + forward_context = get_forward_context() + prefetch_stream = forward_context.prefetch_stream + torch.npu.current_stream().wait_stream(prefetch_stream) def forward_oot(self, x: torch.Tensor) -> torch.Tensor: import torch_npu @@ -38,5 +59,10 @@ def forward_oot(self, x: torch.Tensor) -> torch.Tensor: if is_310p(): out = torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16) else: + dependency = x + self.prefetch_down_proj(dependency) + out = torch_npu.npu_swiglu(x) + + self.wait_prefetch_done() return out diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py index 01c9dcbbe8..17117922f0 100644 --- a/vllm_ascend/ops/layernorm.py +++ b/vllm_ascend/ops/layernorm.py @@ -36,6 +36,12 @@ def __init__( super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype) self.layer = layer + def wait_prefetch_done(self): + forward_context = get_forward_context() + prefetch_stream = forward_context.prefetch_stream + # wait until + torch.npu.current_stream().wait_stream(prefetch_stream) + def forward( self, x: torch.Tensor, @@ -53,10 +59,18 @@ def forward( self.layer.aclnn_input_scale, self.layer.aclnn_input_offset, epsilon=self.variance_epsilon) + + if forward_context.prefetch_mlp_up: + self.wait_prefetch_done() + 
return x, residual x, residual = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon) + + forward_context = get_forward_context() + if forward_context.prefetch_mlp_up: + self.wait_prefetch_done() return x diff --git a/vllm_ascend/ops/linear.py b/vllm_ascend/ops/linear.py index 266392e7ab..5997ff5cec 100644 --- a/vllm_ascend/ops/linear.py +++ b/vllm_ascend/ops/linear.py @@ -381,6 +381,33 @@ class AscendDenseRowParallelLinear(RowParallelLinear): communication-computation fusion. """ + def prefetch_gate_up_proj(self, + dependency: torch.Tensor): + # get prefetch model + forward_context = get_forward_context() + layer_num = int(self.prefix.split('.')[2]) + prefetch_model = forward_context.prefetch_model + prefetch_stream = forward_context.prefetch_stream + + # start point of weight prefetch + forward_context.prefetch_mlp_up = True if self.prefix.split('.')[-2] == 'self_attn' else False + if forward_context.prefetch_mlp_up: + prefetch_stream.wait_stream(torch.npu.current_stream()) + + with torch.npu.stream(prefetch_stream): + # For Qwen3-32B + MLP_GATE_UP_PREFETCH_SIZE = 50 * 1024 * 1024 + torch_npu.npu_prefetch(prefetch_model.model.layers[layer_num].mlp.gate_up_proj.weight, \ + dependency, MLP_GATE_UP_PREFETCH_SIZE) + + + def wait_prefetch_done(self): + forward_context = get_forward_context() + if forward_context.prefetch_mlp_up: + prefetch_stream = forward_context.prefetch_stream + # wait until reduce-scatter is done + torch.npu.current_stream().wait_stream(prefetch_stream) + def forward( self, input_: torch.Tensor ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: @@ -404,6 +431,10 @@ def forward( output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_) + dependency = output_parallel + + self.prefetch_gate_up_proj(dependency) + output = torch.ops.vllm.maybe_pad_and_reduce(output_parallel) output_bias = self.bias if self.skip_bias_add else None diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 3701a09819..deab471ae0 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -181,6 +181,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.dp_size = vllm_config.parallel_config.data_parallel_size self.dp_rank = vllm_config.parallel_config.data_parallel_rank self.device = device + self.prefetch_stream = torch.npu.Stream(device=device) self.dtype = self.model_config.dtype if envs_ascend.VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION: # TODO: drop the env config to use ascend sampler by default @@ -1497,7 +1498,9 @@ def execute_model( aclgraph_runtime_mode=aclgraph_runtime_mode, batch_descriptor=batch_descriptor, num_actual_tokens=scheduler_output. 
- total_num_scheduled_tokens): + total_num_scheduled_tokens, + prefetch_stream=self.prefetch_stream, + prefetch_model=self.model): self.maybe_setup_kv_connector(scheduler_output) hidden_states = self._generate_process_reqs_hidden_states( @@ -1944,7 +1947,9 @@ def dummy_compute_logits(hidden_states): moe_comm_method=moe_comm_method, num_actual_tokens=0, aclgraph_runtime_mode=aclgraph_runtime_mode, - batch_descriptor=batch_descriptor): + batch_descriptor=batch_descriptor, + prefetch_stream=self.prefetch_stream, + prefetch_model=self.model): hidden_states = self._generate_dummy_run_hidden_states( with_prefill, is_torchair_compile, input_ids, positions, attn_metadata, num_tokens, intermediate_tensors, diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py index 1062d47965..6803d73b71 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -55,6 +55,12 @@ else: DraftTokenIds = None +torch._dynamo.trace_rules.clear_lru_cache() +from torch._dynamo.variables import TorchInGraphFunctionVariable +torch_non_c_binding_in_graph_functions_npu = dict.fromkeys(["torch.npu.current_stream"], TorchInGraphFunctionVariable,) +torch_non_c_binding_in_graph_functions_npu["torch.npu.stream"] = TorchInGraphFunctionVariable +torch._dynamo.trace_rules.torch_name_rule_map.append(torch_non_c_binding_in_graph_functions_npu) + class NPUWorker(WorkerBase): From 617af6213705a735ed083e8cf6e67d4d152dc647 Mon Sep 17 00:00:00 2001 From: rjg-lyh <1318825571@qq.com> Date: Thu, 4 Sep 2025 17:34:47 +0800 Subject: [PATCH 3/3] [main] refactor and support in aclgraph Signed-off-by: rjg-lyh <1318825571@qq.com> --- vllm_ascend/ascend_forward_context.py | 18 +++-- vllm_ascend/envs.py | 9 +++ vllm_ascend/ops/activation.py | 28 +------- vllm_ascend/ops/flashcomm_gate_ops.py | 99 +++++++++++++++++++++++++++ vllm_ascend/ops/layernorm.py | 16 +---- vllm_ascend/ops/linear.py | 39 +---------- vllm_ascend/worker/model_runner_v1.py | 13 ++-- 7 files changed, 135 insertions(+), 87 deletions(-) diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index c25fe5553d..87b64bc1c3 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -85,10 +85,6 @@ def set_ascend_forward_context( ): forward_context = get_forward_context() - forward_context.prefetch_stream = prefetch_stream - forward_context.prefetch_model = prefetch_model - forward_context.prefetch_mlp_up = False - forward_context.moe_comm_method_name = moe_comm_method + "commimpl" forward_context.with_prefill = with_prefill ep_size = (get_ep_group().world_size if @@ -112,8 +108,18 @@ def set_ascend_forward_context( # due to multiple warmups before actual capturing forward_context.capturing = False - # set this for rope forward_oot using - forward_context.is_first_layer = True + # set this for layer index + forward_context.layer_idx = 0 + + # set for mlp weight prefetch + prefetch_mlp_enabled = envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP and \ + num_tokens is not None and num_tokens < 500 + if prefetch_mlp_enabled: + forward_context.prefetch_stream = prefetch_stream + forward_context.prefetch_model = prefetch_model + forward_context.prefetch_mlp_gate_up_proj = False + forward_context.prefetch_mlp_down_proj = False + forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled # set for flashcomm_v1 flashcomm_v1_enabled = envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM and \ diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index c5c5fcdfca..cef722eb77 100644 --- 
a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -142,6 +142,15 @@ # this feature in eager mode will get better performance. "VLLM_ASCEND_ENABLE_MLP_OPTIMIZE": lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MLP_OPTIMIZE", '0'))), + # Whether to enable MLP weight prefetch, only used in decode. + "VLLM_ASCEND_ENABLE_PREFETCH_MLP": + lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", '0'))), + # buffer size for gate up prefetch + "MLP_GATE_UP_PREFETCH_SIZE": + lambda: int(os.getenv("MLP_GATE_UP_PREFETCH_SIZE", 18 * 1024 * 1024)), + # buffer size for down proj prefetch + "MLP_DOWN_PREFETCH_SIZE": + lambda: int(os.getenv("MLP_DOWN_PREFETCH_SIZE", 18 * 1024 * 1024)), # Determine the number of physical devices in a non-full-use scenario # caused by the initialization of the Mooncake connector. "PHYSICAL_DEVICES": diff --git a/vllm_ascend/ops/activation.py b/vllm_ascend/ops/activation.py index 0ad2b69c95..021b43cf5a 100644 --- a/vllm_ascend/ops/activation.py +++ b/vllm_ascend/ops/activation.py @@ -17,7 +17,6 @@ import torch from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul -from vllm.forward_context import get_forward_context class AscendQuickGELU(QuickGELU): @@ -30,26 +29,6 @@ def forward_oot(self, x: torch.tensor) -> torch.Tensor: class AscendSiluAndMul(SiluAndMul): - def prefetch_down_proj(self, - dependency: torch.Tensor): - import torch_npu - forward_context = get_forward_context() - prefetch_model = forward_context.prefetch_model - prefetch_stream = forward_context.prefetch_stream - layer_idx = forward_context.layer_idx - - prefetch_stream.wait_stream(torch.npu.current_stream()) - - with torch.npu.stream(prefetch_stream): - MLP_DOWN_PREFETCH_SIZE = 6 * 1024 * 1024 - torch_npu.npu_prefetch(prefetch_model.model.layers[layer_idx].mlp.down_proj.weight, \ - dependency, MLP_DOWN_PREFETCH_SIZE) - forward_context.layer_idx += 1 - - def wait_prefetch_done(self): - forward_context = get_forward_context() - prefetch_stream = forward_context.prefetch_stream - torch.npu.current_stream().wait_stream(prefetch_stream) def forward_oot(self, x: torch.Tensor) -> torch.Tensor: import torch_npu @@ -59,10 +38,7 @@ def forward_oot(self, x: torch.Tensor) -> torch.Tensor: if is_310p(): out = torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16) else: - dependency = x - self.prefetch_down_proj(dependency) - + torch.ops.vllm.maybe_prefetch_mlp_down_proj(x) out = torch_npu.npu_swiglu(x) - - self.wait_prefetch_done() + torch.ops.vllm.maybe_wait_prefetch_done(out) return out diff --git a/vllm_ascend/ops/flashcomm_gate_ops.py b/vllm_ascend/ops/flashcomm_gate_ops.py index e2a77239fc..75ddd28ad8 100644 --- a/vllm_ascend/ops/flashcomm_gate_ops.py +++ b/vllm_ascend/ops/flashcomm_gate_ops.py @@ -1,5 +1,6 @@ import torch import torch.nn.functional as F +import torch_npu from vllm.utils import direct_register_custom_op from vllm.distributed import (tensor_model_parallel_all_gather, tensor_model_parallel_reduce_scatter, @@ -7,6 +8,7 @@ get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.forward_context import get_forward_context +import vllm_ascend.envs as envs_ascend def _maybe_chunk_residual_impl(x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor: @@ -16,6 +18,9 @@ def _maybe_chunk_residual_impl(x: torch.Tensor, residual: torch.Tensor) -> torch "Currently, this situation only occurs " "when flashcomm_v1 is enabled" ) + pad_size = get_forward_context().pad_size + if pad_size > 0: + residual = F.pad(residual, (0, 0, 0, pad_size)) tp_size = 
get_tensor_model_parallel_world_size() tp_rank = get_tensor_model_parallel_rank() residual = torch.chunk(residual, tp_size, dim=0)[tp_rank] @@ -44,6 +49,73 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor: return tensor_model_parallel_all_reduce(x) +def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor, prefix: str) -> None: + forward_context = get_forward_context() + if not forward_context.prefetch_mlp_enabled: + return + prefetch_model = forward_context.prefetch_model + prefetch_stream = forward_context.prefetch_stream + layer_idx = int(prefix.split('.')[2]) + + # start point of gate_up_proj weight prefetch + if prefix.split('.')[-2] == "self_attn": + forward_context.prefetch_mlp_gate_up_proj = True + if forward_context.prefetch_mlp_gate_up_proj: + prefetch_stream.wait_stream(torch.npu.current_stream()) + + with torch.npu.stream(prefetch_stream): + MLP_GATE_UP_PREFETCH_SIZE = envs_ascend.MLP_GATE_UP_PREFETCH_SIZE + torch_npu.npu_prefetch(prefetch_model.model.layers[layer_idx].mlp.gate_up_proj.weight, \ + x_dependency, MLP_GATE_UP_PREFETCH_SIZE) + return + + +def _maybe_prefetch_mlp_gate_up_proj_impl_fake(x_dependency: torch.Tensor, prefix: str) -> None: + return + + +def _maybe_prefetch_mlp_down_proj_impl(x_dependency: torch.Tensor) -> None: + forward_context = get_forward_context() + if not forward_context.prefetch_mlp_enabled: + return + forward_context.prefetch_mlp_down_proj = True + prefetch_model = forward_context.prefetch_model + prefetch_stream = forward_context.prefetch_stream + layer_idx = forward_context.layer_idx + + # start point of down_proj weight prefetch + prefetch_stream.wait_stream(torch.npu.current_stream()) + + with torch.npu.stream(prefetch_stream): + MLP_DOWN_PREFETCH_SIZE = envs_ascend.MLP_DOWN_PREFETCH_SIZE + torch_npu.npu_prefetch(prefetch_model.model.layers[layer_idx].mlp.down_proj.weight, \ + x_dependency, MLP_DOWN_PREFETCH_SIZE) + forward_context.layer_idx += 1 + return + + +def _maybe_prefetch_mlp_down_proj_impl_fake(x_dependency: torch.Tensor) -> None: + return + + +def _maybe_wait_prefetch_done_impl(x: torch.Tensor) -> None: + forward_context = get_forward_context() + if not forward_context.prefetch_mlp_enabled: + return + if forward_context.prefetch_mlp_gate_up_proj or \ + forward_context.prefetch_mlp_down_proj: + prefetch_stream = get_forward_context().prefetch_stream + # wait until prefetch done + torch.npu.current_stream().wait_stream(prefetch_stream) + forward_context.prefetch_mlp_gate_up_proj = False + forward_context.prefetch_mlp_down_proj = False + return + + +def _maybe_wait_prefetch_done_impl_fake(x: torch.Tensor) -> None: + return + + direct_register_custom_op( op_name="maybe_chunk_residual", op_func=_maybe_chunk_residual_impl, @@ -69,3 +141,30 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor: mutates_args=[], dispatch_key="PrivateUse1" ) + + +direct_register_custom_op( + op_name="maybe_prefetch_mlp_gate_up_proj", + op_func=_maybe_prefetch_mlp_gate_up_proj_impl, + fake_impl=_maybe_prefetch_mlp_gate_up_proj_impl_fake, + mutates_args=[], + dispatch_key="PrivateUse1" +) + + +direct_register_custom_op( + op_name="maybe_prefetch_mlp_down_proj", + op_func=_maybe_prefetch_mlp_down_proj_impl, + fake_impl=_maybe_prefetch_mlp_down_proj_impl_fake, + mutates_args=[], + dispatch_key="PrivateUse1" +) + + +direct_register_custom_op( + op_name="maybe_wait_prefetch_done", + op_func=_maybe_wait_prefetch_done_impl, + fake_impl=_maybe_wait_prefetch_done_impl_fake, + mutates_args=[], + dispatch_key="PrivateUse1" +) \ 
No newline at end of file diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py index 17117922f0..4612251815 100644 --- a/vllm_ascend/ops/layernorm.py +++ b/vllm_ascend/ops/layernorm.py @@ -36,12 +36,6 @@ def __init__( super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype) self.layer = layer - def wait_prefetch_done(self): - forward_context = get_forward_context() - prefetch_stream = forward_context.prefetch_stream - # wait until - torch.npu.current_stream().wait_stream(prefetch_stream) - def forward( self, x: torch.Tensor, @@ -59,18 +53,11 @@ def forward( self.layer.aclnn_input_scale, self.layer.aclnn_input_offset, epsilon=self.variance_epsilon) - - if forward_context.prefetch_mlp_up: - self.wait_prefetch_done() - + torch.ops.vllm.maybe_wait_prefetch_done(x) return x, residual x, residual = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon) - - forward_context = get_forward_context() - if forward_context.prefetch_mlp_up: - self.wait_prefetch_done() return x @@ -96,6 +83,7 @@ def forward_oot( else: x, _, residual = torch_npu.npu_add_rms_norm( x, residual, self.weight, self.variance_epsilon) + torch.ops.vllm.maybe_wait_prefetch_done(x) return x, residual x, residual = torch_npu.npu_rms_norm(x, self.weight, diff --git a/vllm_ascend/ops/linear.py b/vllm_ascend/ops/linear.py index 5997ff5cec..8a3c49ae98 100644 --- a/vllm_ascend/ops/linear.py +++ b/vllm_ascend/ops/linear.py @@ -24,7 +24,6 @@ split_tensor_along_last_dim, tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) -from vllm.forward_context import get_forward_context from vllm.model_executor.layers.quantization.base_config import \ QuantizationConfig from vllm.model_executor.utils import set_weight_attrs @@ -32,14 +31,10 @@ from vllm_ascend.distributed.parallel_state import ( get_mlp_tensor_model_parallel_rank, get_mlp_tensor_model_parallel_world_size, get_mlp_tp_group) -from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod -from vllm_ascend.utils import (all_gather_and_maybe_unpad, - maybe_pad_and_reduce_scatter) from vllm.model_executor.layers.linear import ( # isort: skip WEIGHT_LOADER_V2_SUPPORTED, ColumnParallelLinear, LinearBase, - MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear, - UnquantizedLinearMethod) + MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) class AscendMlpColumnParallelLinear(ColumnParallelLinear): @@ -381,33 +376,6 @@ class AscendDenseRowParallelLinear(RowParallelLinear): communication-computation fusion. 
""" - def prefetch_gate_up_proj(self, - dependency: torch.Tensor): - # get prefetch model - forward_context = get_forward_context() - layer_num = int(self.prefix.split('.')[2]) - prefetch_model = forward_context.prefetch_model - prefetch_stream = forward_context.prefetch_stream - - # start point of weight prefetch - forward_context.prefetch_mlp_up = True if self.prefix.split('.')[-2] == 'self_attn' else False - if forward_context.prefetch_mlp_up: - prefetch_stream.wait_stream(torch.npu.current_stream()) - - with torch.npu.stream(prefetch_stream): - # For Qwen3-32B - MLP_GATE_UP_PREFETCH_SIZE = 50 * 1024 * 1024 - torch_npu.npu_prefetch(prefetch_model.model.layers[layer_num].mlp.gate_up_proj.weight, \ - dependency, MLP_GATE_UP_PREFETCH_SIZE) - - - def wait_prefetch_done(self): - forward_context = get_forward_context() - if forward_context.prefetch_mlp_up: - prefetch_stream = forward_context.prefetch_stream - # wait until reduce-scatter is done - torch.npu.current_stream().wait_stream(prefetch_stream) - def forward( self, input_: torch.Tensor ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: @@ -431,11 +399,8 @@ def forward( output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_) - dependency = output_parallel - - self.prefetch_gate_up_proj(dependency) - output = torch.ops.vllm.maybe_pad_and_reduce(output_parallel) + torch.ops.vllm.maybe_prefetch_mlp_gate_up_proj(output, self.prefix) output_bias = self.bias if self.skip_bias_add else None diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index deab471ae0..440e0d416d 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -37,6 +37,7 @@ from vllm.compilation.counter import compilation_counter from vllm.compilation.monitor import set_cudagraph_capturing_enabled from vllm.config import CompilationLevel, CUDAGraphMode, VllmConfig +from vllm.distributed import tensor_model_parallel_all_gather from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group) from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 @@ -181,7 +182,10 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.dp_size = vllm_config.parallel_config.data_parallel_size self.dp_rank = vllm_config.parallel_config.data_parallel_rank self.device = device - self.prefetch_stream = torch.npu.Stream(device=device) + if envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP: + self.prefetch_stream = torch.npu.Stream(device=device) + else: + self.prefetch_stream = None self.dtype = self.model_config.dtype if envs_ascend.VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION: # TODO: drop the env config to use ascend sampler by default @@ -1145,9 +1149,10 @@ def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill, inputs_embeds=inputs_embeds, ) if get_forward_context().flashcomm_v1_enabled: - from vllm_ascend.utils import all_gather_and_maybe_unpad - hidden_states = all_gather_and_maybe_unpad( - hidden_states, get_forward_context().pad_size, dim=0) + hidden_states = tensor_model_parallel_all_gather(hidden_states, 0) + pad_size = get_forward_context().pad_size + if pad_size > 0: + hidden_states = hidden_states[:-pad_size, :] return hidden_states def _build_attn_state(self, num_reqs, num_scheduled_tokens,