24 changes: 24 additions & 0 deletions tests/e2e/multicard/test_offline_inference_distributed.py
@@ -23,13 +23,17 @@
import os
from unittest.mock import patch

import pytest

from modelscope import snapshot_download # type: ignore
from vllm import SamplingParams

from tests.e2e.conftest import VllmRunner

os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"

QWEN_DENSE_MODELS = ["Qwen/QwQ-32B", "Qwen/Qwen-32B"]


def test_models_distributed_QwQ():
example_prompts = [
@@ -150,3 +154,23 @@ def test_sp_for_qwen3_moe() -> None:
enable_expert_parallel=True,
enforce_eager=True) as vllm_model:
vllm_model.generate(example_prompts, sampling_params)


@pytest.mark.parametrize("enforce_eager", [True, False])
@pytest.mark.parametrize("model", QWEN_DENSE_MODELS)
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE": "1"})
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM": "1"})
def test_models_distributed_Qwen_Dense_with_flashcomm_v1(model, enforce_eager):
example_prompts = [
"Hello, my name is",
]
max_tokens = 5

with VllmRunner(
snapshot_download(model),
max_model_len=8192,
enforce_eager=enforce_eager,
dtype="auto",
tensor_parallel_size=4,
) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)
38 changes: 34 additions & 4 deletions tests/ut/ops/test_layernorm.py
@@ -1,15 +1,24 @@
from unittest.mock import patch
import os
from unittest.mock import patch, MagicMock

import pytest
import torch
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm_ascend.ascend_forward_context import set_ascend_forward_context


@pytest.fixture
def dummy_tensor():
return torch.randn(4, 8, dtype=torch.float16)


def mock_maybe_chunk_residual(x, residual):
if x.size(0) != residual.size(0):
return residual[:4]

return residual


def mock_rms_norm(x, weight, eps):
return x + 1, None

@@ -23,11 +32,12 @@ def mock_add_rms_norm(x, residual, weight, eps):
[None, torch.randn(4, 8, dtype=torch.float32)])
@patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
@patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm)
def test_RMSNorm_forward(mock_add_rmsnorm, mock_rmsnorm, is_310p_return,
residual, dummy_tensor):
@patch("torch.ops.vllm.maybe_chunk_residual", side_effect=mock_maybe_chunk_residual)
def test_RMSNorm_forward(mock_maybe_chunk_residual, mock_add_rmsnorm, mock_rmsnorm,
                         is_310p_return, residual, dummy_tensor):

with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return):
layer = RMSNorm(hidden_size=32, eps=1e-05)
layer = RMSNorm(hidden_size=8, eps=1e-05)
if residual is not None:
out_x, out_residual = layer.forward_oot(dummy_tensor, residual)

@@ -51,3 +61,23 @@ def test_RMSNorm_forward(mock_add_rmsnorm, mock_rmsnorm, is_310p_return,

mock_rmsnorm.assert_called_once()
assert torch.allclose(out_x, expected_out_x)


@patch("vllm_ascend.utils.is_310p", return_value=False)
@patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm)
@patch("torch.ops.vllm.maybe_chunk_residual", side_effect=mock_maybe_chunk_residual)
def test_RMSNorm_forward_with_flashcomm_v1(mock_maybe_chunk_residual, mock_add_rms_norm, mock_is310p):
x = torch.randn(4, 512, dtype=torch.bfloat16)
residual = torch.randn(16, 512, dtype=torch.bfloat16)
layer = RMSNorm(hidden_size=512, eps=1e-05)

out_x, out_residual = layer.forward_oot(x, residual)

expected_out_x = 2 * x
expected_out_residual = 2 * residual[:4]

mock_maybe_chunk_residual.assert_called_once()
mock_add_rms_norm.assert_called_once()
assert out_residual.size(0) == 4
assert torch.allclose(out_x, expected_out_x)
assert torch.allclose(out_residual, expected_out_residual)
5 changes: 5 additions & 0 deletions vllm_ascend/__init__.py
@@ -23,5 +23,10 @@


def register_model():
import vllm.envs as envs

Check failure (GitHub Actions / lint / pre-commit, Ruff F401): vllm_ascend/__init__.py:26:25: `vllm.envs` imported but unused
import vllm_ascend.envs as envs_ascend

Check failure (GitHub Actions / lint / pre-commit, Ruff F401): vllm_ascend/__init__.py:28:32: `vllm_ascend.envs` imported but unused

from .models import register_model

register_model()
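Both Ruff F401 failures above flag imports inside register_model() that are never referenced in the function body. A minimal sketch of one way to resolve them, assuming the imports are kept only for their side effects; the # noqa markers are an assumption and not part of this PR, and if the imports are not needed at all, simply deleting the two lines is the cleaner fix:

def register_model():
    # Imported for side effects only; suppress Ruff F401 explicitly.
    import vllm.envs as envs  # noqa: F401
    import vllm_ascend.envs as envs_ascend  # noqa: F401

    from .models import register_model

    register_model()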
30 changes: 29 additions & 1 deletion vllm_ascend/ascend_forward_context.py
@@ -67,7 +67,9 @@ def set_ascend_forward_context(
moe_comm_method: str = "",
num_actual_tokens: Optional[int] = None,
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor: Optional[BatchDescriptor] = None):
batch_descriptor: Optional[BatchDescriptor] = None,
prefetch_stream: torch.npu.Stream = None,
prefetch_model: torch.nn.Module = None):
"""A context manager that stores the current forward context,
can be attention metadata, etc.
We add some additional param into forward_context.
@@ -82,6 +84,7 @@
batch_descriptor=batch_descriptor,
):
forward_context = get_forward_context()

forward_context.moe_comm_method_name = moe_comm_method + "commimpl"
forward_context.with_prefill = with_prefill
ep_size = (get_ep_group().world_size if
@@ -105,6 +108,31 @@
# due to multiple warmups before actual capturing
forward_context.capturing = False

# set this for layer index
forward_context.layer_idx = 0

# set for mlp weight prefetch
prefetch_mlp_enabled = envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP and \
num_tokens is not None and num_tokens < 500
if prefetch_mlp_enabled:
forward_context.prefetch_stream = prefetch_stream
forward_context.prefetch_model = prefetch_model
forward_context.prefetch_mlp_gate_up_proj = False
forward_context.prefetch_mlp_down_proj = False
forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled

# set for flashcomm_v1
flashcomm_v1_enabled = envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM and \
num_tokens is not None and num_tokens > 1000
Review comment on lines +125 to +126 (Contributor, severity: high):

The value 1000 is a magic number. It's used as a threshold to enable the flashcomm_v1 optimization. This makes the code harder to understand and maintain. It should be defined as a named constant with a comment explaining its purpose and how this value was determined. This would improve readability and make it easier to tune this threshold in the future.

Suggested change:
- flashcomm_v1_enabled = envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM and \
-     num_tokens is not None and num_tokens > 1000
+ # e.g. FLASHCOMM_V1_TOKEN_THRESHOLD = 1000 at the top of the file
+ flashcomm_v1_enabled = envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM and \
+     num_tokens is not None and num_tokens > FLASHCOMM_V1_TOKEN_THRESHOLD


if flashcomm_v1_enabled:
tp_world_size = get_tensor_model_parallel_world_size()
pad_size = (tp_world_size -
(num_tokens % tp_world_size)) % tp_world_size
forward_context.pad_size = pad_size

forward_context.flashcomm_v1_enabled = flashcomm_v1_enabled

if num_tokens is None and attn_metadata is not None:
num_tokens = attn_metadata.num_actual_tokens

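For reference, the flashcomm_v1 branch pads the token count up to the next multiple of the tensor-parallel world size so the padded tensor can be chunked evenly across ranks. A tiny worked example of the pad_size formula; the numbers are illustrative and not taken from this PR:

# Illustrative values only.
num_tokens = 1001      # above the current 1000-token threshold
tp_world_size = 4
pad_size = (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size
print(pad_size)        # 3 -> 1001 + 3 = 1004 rows, i.e. 251 rows per rank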
16 changes: 16 additions & 0 deletions vllm_ascend/envs.py
@@ -131,10 +131,26 @@
# this feature is supported in A2, and eager mode will get better performance.
"VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE", '0'))),
# Whether to enable the FlashComm optimization when tensor parallel is enabled.
# This feature gives better performance in the prefill phase.
"VLLM_ASCEND_ENABLE_FLASHCOMM":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))),
# Whether to enable dense model and general optimizations for better performance.
"VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE", '0'))),
# Whether to enable the MLP optimization when tensor parallel is enabled.
# This feature gives better performance in eager mode.
"VLLM_ASCEND_ENABLE_MLP_OPTIMIZE":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MLP_OPTIMIZE", '0'))),
# Whether to enable MLP weight prefetch, only used in decode.
"VLLM_ASCEND_ENABLE_PREFETCH_MLP":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", '0'))),
# Buffer size (in bytes) for gate_up_proj weight prefetch.
"MLP_GATE_UP_PREFETCH_SIZE":
lambda: int(os.getenv("MLP_GATE_UP_PREFETCH_SIZE", 18 * 1024 * 1024)),
# Buffer size (in bytes) for down_proj weight prefetch.
"MLP_DOWN_PREFETCH_SIZE":
lambda: int(os.getenv("MLP_DOWN_PREFETCH_SIZE", 18 * 1024 * 1024)),
# Determine the number of physical devices in a non-full-use scenario
# caused by the initialization of the Mooncake connector.
"PHYSICAL_DEVICES":
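All of the new switches are parsed with bool(int(...)), so they expect "0"/"1" strings, while the two prefetch sizes are byte counts. A minimal sketch of turning the new features on from Python before engine start-up; only the variable names and defaults come from this diff, the surrounding launch flow is assumed:

import os

# Enable the dense-model path plus the FlashComm (prefill) optimization.
os.environ["VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE"] = "1"
os.environ["VLLM_ASCEND_ENABLE_FLASHCOMM"] = "1"

# Enable MLP weight prefetch for decode and override the 18 MiB default
# prefetch buffers with illustrative 8 MiB values.
os.environ["VLLM_ASCEND_ENABLE_PREFETCH_MLP"] = "1"
os.environ["MLP_GATE_UP_PREFETCH_SIZE"] = str(8 * 1024 * 1024)
os.environ["MLP_DOWN_PREFETCH_SIZE"] = str(8 * 1024 * 1024)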
1 change: 1 addition & 0 deletions vllm_ascend/ops/__init__.py
@@ -24,6 +24,7 @@
from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
from vllm_ascend.ops.rotary_embedding import (
AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding)
import vllm_ascend.ops.flashcomm_gate_ops

Check failure (GitHub Actions / lint / pre-commit, Ruff F401): vllm_ascend/ops/__init__.py:27:8: `vllm_ascend.ops.flashcomm_gate_ops` imported but unused; consider removing, adding to `__all__`, or using a redundant alias


class dummyFusionOp:
2 changes: 2 additions & 0 deletions vllm_ascend/ops/activation.py
@@ -38,5 +38,7 @@ def forward_oot(self, x: torch.Tensor) -> torch.Tensor:
if is_310p():
out = torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16)
else:
torch.ops.vllm.maybe_prefetch_mlp_down_proj(x)
out = torch_npu.npu_swiglu(x)
torch.ops.vllm.maybe_wait_prefetch_done(out)
return out
170 changes: 170 additions & 0 deletions vllm_ascend/ops/flashcomm_gate_ops.py
@@ -0,0 +1,170 @@
import torch
import torch.nn.functional as F
import torch_npu
from vllm.utils import direct_register_custom_op
from vllm.distributed import (tensor_model_parallel_all_gather,
tensor_model_parallel_reduce_scatter,
tensor_model_parallel_all_reduce,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.forward_context import get_forward_context
import vllm_ascend.envs as envs_ascend


def _maybe_chunk_residual_impl(x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
if x.size(0) != residual.size(0):
flashcomm_v1_enabled = get_forward_context().flashcomm_v1_enabled
assert flashcomm_v1_enabled is True, (
"Currently, this situation only occurs "
"when flashcomm_v1 is enabled"
)
pad_size = get_forward_context().pad_size
if pad_size > 0:
residual = F.pad(residual, (0, 0, 0, pad_size))
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
residual = torch.chunk(residual, tp_size, dim=0)[tp_rank]

return residual


def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor, label: bool) -> torch.Tensor:
flashcomm_v1_enabled = get_forward_context().flashcomm_v1_enabled
if flashcomm_v1_enabled and label:
x = tensor_model_parallel_all_gather(x, 0)
pad_size = get_forward_context().pad_size
if pad_size > 0:
x = x[:-pad_size, :]
return x


def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
flashcomm_v1_enabled = get_forward_context().flashcomm_v1_enabled
if flashcomm_v1_enabled:
pad_size = get_forward_context().pad_size
if pad_size > 0:
x = F.pad(x, (0, 0, 0, pad_size))
return tensor_model_parallel_reduce_scatter(x, 0)
else:
return tensor_model_parallel_all_reduce(x)


def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor, prefix: str) -> None:
forward_context = get_forward_context()
if not forward_context.prefetch_mlp_enabled:
return
prefetch_model = forward_context.prefetch_model
prefetch_stream = forward_context.prefetch_stream
layer_idx = int(prefix.split('.')[2])

# start point of gate_up_proj weight prefetch
if prefix.split('.')[-2] == "self_attn":
forward_context.prefetch_mlp_gate_up_proj = True
if forward_context.prefetch_mlp_gate_up_proj:
prefetch_stream.wait_stream(torch.npu.current_stream())

with torch.npu.stream(prefetch_stream):
MLP_GATE_UP_PREFETCH_SIZE = envs_ascend.MLP_GATE_UP_PREFETCH_SIZE
torch_npu.npu_prefetch(prefetch_model.model.layers[layer_idx].mlp.gate_up_proj.weight, \
x_dependency, MLP_GATE_UP_PREFETCH_SIZE)
return


def _maybe_prefetch_mlp_gate_up_proj_impl_fake(x_dependency: torch.Tensor, prefix: str) -> None:
return


def _maybe_prefetch_mlp_down_proj_impl(x_dependency: torch.Tensor) -> None:
forward_context = get_forward_context()
if not forward_context.prefetch_mlp_enabled:
return
forward_context.prefetch_mlp_down_proj = True
prefetch_model = forward_context.prefetch_model
prefetch_stream = forward_context.prefetch_stream
layer_idx = forward_context.layer_idx

# start point of down_proj weight prefetch
prefetch_stream.wait_stream(torch.npu.current_stream())

with torch.npu.stream(prefetch_stream):
MLP_DOWN_PREFETCH_SIZE = envs_ascend.MLP_DOWN_PREFETCH_SIZE
torch_npu.npu_prefetch(prefetch_model.model.layers[layer_idx].mlp.down_proj.weight, \
x_dependency, MLP_DOWN_PREFETCH_SIZE)
forward_context.layer_idx += 1
return


def _maybe_prefetch_mlp_down_proj_impl_fake(x_dependency: torch.Tensor) -> None:
return


def _maybe_wait_prefetch_done_impl(x: torch.Tensor) -> None:
forward_context = get_forward_context()
if not forward_context.prefetch_mlp_enabled:
return
if forward_context.prefetch_mlp_gate_up_proj or \
forward_context.prefetch_mlp_down_proj:
prefetch_stream = get_forward_context().prefetch_stream
# wait until prefetch done
torch.npu.current_stream().wait_stream(prefetch_stream)
forward_context.prefetch_mlp_gate_up_proj = False
forward_context.prefetch_mlp_down_proj = False
return


def _maybe_wait_prefetch_done_impl_fake(x: torch.Tensor) -> None:
return


direct_register_custom_op(
op_name="maybe_chunk_residual",
op_func=_maybe_chunk_residual_impl,
fake_impl=lambda x, residual: residual,
mutates_args=[],
dispatch_key="PrivateUse1"
)


direct_register_custom_op(
op_name="maybe_all_gather_and_maybe_unpad",
op_func=_maybe_all_gather_and_maybe_unpad_impl,
fake_impl=lambda x, label: x,
mutates_args=[],
dispatch_key="PrivateUse1"
)


direct_register_custom_op(
op_name="maybe_pad_and_reduce",
op_func=_maybe_pad_and_reduce_impl,
fake_impl=lambda x: x,
mutates_args=[],
dispatch_key="PrivateUse1"
)


direct_register_custom_op(
op_name="maybe_prefetch_mlp_gate_up_proj",
op_func=_maybe_prefetch_mlp_gate_up_proj_impl,
fake_impl=_maybe_prefetch_mlp_gate_up_proj_impl_fake,
mutates_args=[],
dispatch_key="PrivateUse1"
)


direct_register_custom_op(
op_name="maybe_prefetch_mlp_down_proj",
op_func=_maybe_prefetch_mlp_down_proj_impl,
fake_impl=_maybe_prefetch_mlp_down_proj_impl_fake,
mutates_args=[],
dispatch_key="PrivateUse1"
)


direct_register_custom_op(
op_name="maybe_wait_prefetch_done",
op_func=_maybe_wait_prefetch_done_impl,
fake_impl=_maybe_wait_prefetch_done_impl_fake,
mutates_args=[],
dispatch_key="PrivateUse1"
)
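Every op above is registered under torch.ops.vllm.* and falls back to a no-op (or a plain all-reduce, in the case of maybe_pad_and_reduce) when its feature flag is off, so call sites need no branching of their own. A rough sketch of how a tensor-parallel dense layer might wire these ops together; the function, argument names, and call order are illustrative assumptions, not code from this PR:

import torch

def dense_layer_post_attn(hidden_states, residual, prefix):
    # If flashcomm_v1 left hidden_states chunked per TP rank while the
    # residual is still full-length, chunk (and pad) the residual to match.
    residual = torch.ops.vllm.maybe_chunk_residual(hidden_states, residual)

    # Start prefetching the gate_up_proj weights on the side stream.
    torch.ops.vllm.maybe_prefetch_mlp_gate_up_proj(hidden_states, prefix)

    # Row-parallel output: pad + reduce-scatter when flashcomm_v1 is enabled,
    # otherwise an ordinary all-reduce.
    hidden_states = torch.ops.vllm.maybe_pad_and_reduce(hidden_states)

    # Gather the full sequence back (and strip padding) where a later op
    # needs the unchunked layout.
    hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
        hidden_states, True)

    # Block until any outstanding weight prefetch has completed.
    torch.ops.vllm.maybe_wait_prefetch_done(hidden_states)
    return hidden_states, residual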