
Commit 3ae9205

[main] flashcomm_v1 before mlp weight prefetch in Qwen Dense Models
Signed-off-by: rjg-lyh <[email protected]>
1 parent 4c90fa7 commit 3ae9205

11 files changed, +286 -10 lines changed

tests/e2e/multicard/test_offline_inference_distributed.py

Lines changed: 24 additions & 0 deletions
@@ -23,13 +23,17 @@
 import os
 from unittest.mock import patch
 
+import pytest
+
 from modelscope import snapshot_download  # type: ignore
 from vllm import SamplingParams
 
 from tests.e2e.conftest import VllmRunner
 
 os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"
 
+QWEN_DENSE_MODELS = ["Qwen/QwQ-32B", "Qwen/Qwen-32B"]
+
 
 def test_models_distributed_QwQ():
     example_prompts = [
@@ -150,3 +154,23 @@ def test_sp_for_qwen3_moe() -> None:
             enable_expert_parallel=True,
             enforce_eager=True) as vllm_model:
         vllm_model.generate(example_prompts, sampling_params)
+
+
+@pytest.mark.parametrize("enforce_eager", [True, False])
+@pytest.mark.parametrize("model", QWEN_DENSE_MODELS)
+@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE": "1"})
+@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM": "1"})
+def test_models_distributed_Qwen_Dense_with_flashcomm_v1(model, enforce_eager):
+    example_prompts = [
+        "Hello, my name is",
+    ]
+    max_tokens = 5
+
+    with VllmRunner(
+            snapshot_download(model),
+            max_model_len=8192,
+            enforce_eager=enforce_eager,
+            dtype="auto",
+            tensor_parallel_size=4,
+    ) as vllm_model:
+        vllm_model.generate_greedy(example_prompts, max_tokens)
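
The new e2e case drives FlashComm v1 through VllmRunner. For orientation, a rough standalone equivalent is sketched below; this is an illustration only, assuming an Ascend node with vllm plus vllm-ascend installed, and it simply mirrors the test's model, context length, and tensor parallel size. The two environment switches are the ones introduced in vllm_ascend/envs.py further down.

import os

# Switches added in vllm_ascend/envs.py; both default to off.
os.environ["VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE"] = "1"
os.environ["VLLM_ASCEND_ENABLE_FLASHCOMM"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/QwQ-32B", tensor_parallel_size=4, max_model_len=8192)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=5))
print(outputs[0].outputs[0].text)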

tests/ut/ops/test_layernorm.py

Lines changed: 34 additions & 4 deletions
@@ -1,15 +1,24 @@
-from unittest.mock import patch
+import os
+from unittest.mock import patch, MagicMock
 
 import pytest
 import torch
 from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm_ascend.ascend_forward_context import set_ascend_forward_context
 
 
 @pytest.fixture
 def dummy_tensor():
     return torch.randn(4, 8, dtype=torch.float16)
 
 
+def mock_maybe_chunk_residual(x, residual):
+    if x.size(0) != residual.size(0):
+        return residual[:4]
+
+    return residual
+
+
 def mock_rms_norm(x, weight, eps):
     return x + 1, None
 
@@ -23,11 +32,12 @@ def mock_add_rms_norm(x, residual, weight, eps):
                          [None, torch.randn(4, 8, dtype=torch.float32)])
 @patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
 @patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm)
-def test_RMSNorm_forward(mock_add_rmsnorm, mock_rmsnorm, is_310p_return,
-                         residual, dummy_tensor):
+@patch("torch.ops.vllm.maybe_chunk_residual", side_effect=mock_maybe_chunk_residual)
+def test_RMSNorm_forward(mock_maybe_chunk_residual, mock_add_rmsnorm, mock_rmsnorm,
+                         is_310p_return, residual, dummy_tensor):
 
     with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return):
-        layer = RMSNorm(hidden_size=32, eps=1e-05)
+        layer = RMSNorm(hidden_size=8, eps=1e-05)
         if residual is not None:
             out_x, out_residual = layer.forward_oot(dummy_tensor, residual)
 
@@ -51,3 +61,23 @@ def test_RMSNorm_forward(mock_add_rmsnorm, mock_rmsnorm, is_310p_return,
 
             mock_rmsnorm.assert_called_once()
             assert torch.allclose(out_x, expected_out_x)
+
+
+@patch("vllm_ascend.utils.is_310p", return_value=False)
+@patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm)
+@patch("torch.ops.vllm.maybe_chunk_residual", side_effect=mock_maybe_chunk_residual)
+def test_RMSNorm_forward_with_flashcomm_v1(mock_maybe_chunk_residual, mock_add_rms_norm, mock_is310p):
+    x = torch.randn(4, 512, dtype=torch.bfloat16)
+    residual = torch.randn(16, 512, dtype=torch.bfloat16)
+    layer = RMSNorm(hidden_size=512, eps=1e-05)
+
+    out_x, out_residual = layer.forward_oot(x, residual)
+
+    expected_out_x = 2 * x
+    expected_out_residual = 2 * residual[:4]
+
+    mock_maybe_chunk_residual.assert_called_once()
+    mock_add_rms_norm.assert_called_once()
+    assert out_residual.size(0) == 4
+    assert torch.allclose(out_x, expected_out_x)
+    assert torch.allclose(out_residual, expected_out_residual)
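
The 16-row residual against a 4-row hidden state is the case this new unit test is about: under flashcomm_v1 each TP rank keeps only its slice of the residual along dim 0, which the mock reproduces by returning residual[:4]. A minimal single-process sketch of the slicing the real op performs, with illustrative tp_size/tp_rank values:

import torch

residual = torch.randn(16, 512)
tp_size, tp_rank = 4, 0                                   # illustrative values
local = torch.chunk(residual, tp_size, dim=0)[tp_rank]    # what _maybe_chunk_residual_impl does
assert local.shape == (4, 512)                            # matches the mock's residual[:4] for rank 0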

vllm_ascend/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -23,5 +23,10 @@ def register():
 
 
 def register_model():
+    import vllm.envs as envs
+
+    import vllm_ascend.envs as envs_ascend
+
     from .models import register_model
+
     register_model()

vllm_ascend/ascend_forward_context.py

Lines changed: 15 additions & 0 deletions
@@ -105,6 +105,21 @@ def set_ascend_forward_context(
         # due to multiple warmups before actual capturing
         forward_context.capturing = False
 
+        # set this for rope forward_oot using
+        forward_context.is_first_layer = True
+
+        # set for flashcomm_v1
+        flashcomm_v1_enabled = envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM and \
+            num_tokens is not None and num_tokens > 1000
+
+        if flashcomm_v1_enabled:
+            tp_world_size = get_tensor_model_parallel_world_size()
+            pad_size = (tp_world_size -
+                        (num_tokens % tp_world_size)) % tp_world_size
+            forward_context.pad_size = pad_size
+
+        forward_context.flashcomm_v1_enabled = flashcomm_v1_enabled
+
         if num_tokens is None and attn_metadata is not None:
             num_tokens = attn_metadata.num_actual_tokens
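
A worked example of the padding rule above, with illustrative numbers: flashcomm_v1 only activates for num_tokens > 1000, and pad_size tops the token count up to a multiple of the TP world size so the later reduce-scatter along dim 0 splits evenly.

tp_world_size = 4
num_tokens = 4097                 # > 1000, so flashcomm_v1 is eligible
pad_size = (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size
assert pad_size == 3
assert (num_tokens + pad_size) % tp_world_size == 0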

vllm_ascend/envs.py

Lines changed: 7 additions & 0 deletions
@@ -131,6 +131,13 @@
     # this feature is supported in A2, and eager mode will get better performance.
     "VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE":
     lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE", '0'))),
+    # Whether to enable FlashComm optimization when tensor parallel is enabled.
+    # this feature will get better performance in prefill phase.
+    "VLLM_ASCEND_ENABLE_FLASHCOMM":
+    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))),
+    # Whether to enable dense model and general optimizations for better performance.
+    "VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE":
+    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE", '0'))),
     # Whether to enable mlp optimize when tensor parallel is enabled.
     # this feature in eager mode will get better performance.
     "VLLM_ASCEND_ENABLE_MLP_OPTIMIZE":

vllm_ascend/ops/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@
 from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
 from vllm_ascend.ops.rotary_embedding import (
     AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding)
+import vllm_ascend.ops.flashcomm_gate_ops
 
 
 class dummyFusionOp:

vllm_ascend/ops/flashcomm_gate_ops.py

Lines changed: 71 additions & 0 deletions

@@ -0,0 +1,71 @@
+import torch
+import torch.nn.functional as F
+from vllm.utils import direct_register_custom_op
+from vllm.distributed import (tensor_model_parallel_all_gather,
+                              tensor_model_parallel_reduce_scatter,
+                              tensor_model_parallel_all_reduce,
+                              get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size)
+from vllm.forward_context import get_forward_context
+
+
+def _maybe_chunk_residual_impl(x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
+    if x.size(0) != residual.size(0):
+        flashcomm_v1_enabled = get_forward_context().flashcomm_v1_enabled
+        assert flashcomm_v1_enabled is True, (
+            "Currently, this situation only occurs "
+            "when flashcomm_v1 is enabled"
+        )
+        tp_size = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
+        residual = torch.chunk(residual, tp_size, dim=0)[tp_rank]
+
+    return residual
+
+
+def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor, label: bool) -> torch.Tensor:
+    flashcomm_v1_enabled = get_forward_context().flashcomm_v1_enabled
+    if flashcomm_v1_enabled and label:
+        x = tensor_model_parallel_all_gather(x, 0)
+        pad_size = get_forward_context().pad_size
+        if pad_size > 0:
+            x = x[:-pad_size, :]
+    return x
+
+
+def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
+    flashcomm_v1_enabled = get_forward_context().flashcomm_v1_enabled
+    if flashcomm_v1_enabled:
+        pad_size = get_forward_context().pad_size
+        if pad_size > 0:
+            x = F.pad(x, (0, 0, 0, pad_size))
+        return tensor_model_parallel_reduce_scatter(x, 0)
+    else:
+        return tensor_model_parallel_all_reduce(x)
+
+
+direct_register_custom_op(
+    op_name="maybe_chunk_residual",
+    op_func=_maybe_chunk_residual_impl,
+    fake_impl=lambda x, residual: residual,
+    mutates_args=[],
+    dispatch_key="PrivateUse1"
+)
+
+
+direct_register_custom_op(
+    op_name="maybe_all_gather_and_maybe_unpad",
+    op_func=_maybe_all_gather_and_maybe_unpad_impl,
+    fake_impl=lambda x, label: x,
+    mutates_args=[],
+    dispatch_key="PrivateUse1"
+)
+
+
+direct_register_custom_op(
+    op_name="maybe_pad_and_reduce",
+    op_func=_maybe_pad_and_reduce_impl,
+    fake_impl=lambda x: x,
+    mutates_args=[],
+    dispatch_key="PrivateUse1"
+)
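
These three custom ops implement the FlashComm v1 dataflow: the row-parallel output is padded and reduce-scattered along the token dimension, downstream layers chunk the residual to the local shard, and the next column-parallel layer all-gathers and strips the padding. The registered ops need a TP group and a live forward context, so the sketch below is a single-process, pure-torch simulation of that round trip (my own illustration, not part of the commit):

import torch
import torch.nn.functional as F

tp_size, num_tokens, hidden = 4, 10, 8
partials = [torch.randn(num_tokens, hidden) for _ in range(tp_size)]  # per-rank partial outputs

reduced = sum(partials)                              # what a plain all-reduce would produce
pad_size = (tp_size - num_tokens % tp_size) % tp_size
padded = F.pad(reduced, (0, 0, 0, pad_size))         # pad token dim, as _maybe_pad_and_reduce_impl does
shards = torch.chunk(padded, tp_size, dim=0)         # reduce-scatter: each rank keeps one shard
regathered = torch.cat(shards, dim=0)[:num_tokens]   # all-gather + unpad on the next linear layer
assert torch.allclose(regathered, reduced)           # same result, but intermediate work is sharded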

vllm_ascend/ops/layernorm.py

Lines changed: 4 additions & 0 deletions
@@ -44,6 +44,8 @@ def forward(
         import torch_npu
 
         if residual is not None:
+            residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
+            assert x.size(0) == residual.size(0)
             x, _, residual = torch_npu.npu_add_rms_norm_quant(
                 x,
                 residual,
@@ -69,6 +71,8 @@ def forward_oot(
 
         from vllm_ascend.utils import is_310p
         if residual is not None:
+            residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
+            assert x.size(0) == residual.size(0)
             if is_310p():
                 orig_dtype = residual.dtype
                 x = x + residual.to(x.dtype)

vllm_ascend/ops/linear.py

Lines changed: 109 additions & 5 deletions
@@ -24,18 +24,22 @@
                               split_tensor_along_last_dim,
                               tensor_model_parallel_all_gather,
                               tensor_model_parallel_all_reduce)
-from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED,
-                                               ColumnParallelLinear,
-                                               LinearBase,
-                                               MergedColumnParallelLinear,
-                                               RowParallelLinear)
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.quantization.base_config import \
     QuantizationConfig
 from vllm.model_executor.utils import set_weight_attrs
 
 from vllm_ascend.distributed.parallel_state import (
     get_mlp_tensor_model_parallel_rank,
     get_mlp_tensor_model_parallel_world_size, get_mlp_tp_group)
+from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
+from vllm_ascend.utils import (all_gather_and_maybe_unpad,
+                               maybe_pad_and_reduce_scatter)
+
+from vllm.model_executor.layers.linear import (  # isort: skip
+    WEIGHT_LOADER_V2_SUPPORTED, ColumnParallelLinear, LinearBase,
+    MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear,
+    UnquantizedLinearMethod)
 
 
 class AscendMlpColumnParallelLinear(ColumnParallelLinear):
@@ -307,3 +311,103 @@ def forward(
         if not self.return_bias:
             return output
         return output, output_bias
+
+
+class AscendDenseMergedColumnParallelLinear(MergedColumnParallelLinear):
+    """Linear layer with column parallelism.
+
+    Implemented multiple optimization projects for dense models, such as FlashComm and
+    communication-computation fusion.
+    """
+
+    def forward(
+        self, input_: torch.Tensor
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        bias = self.bias if not self.skip_bias_add else None
+
+        # Matrix multiply.
+        assert self.quant_method is not None
+
+        input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)
+        output_parallel = self.quant_method.apply(self, input_, bias)
+
+        if self.gather_output:
+            # All-gather across the partitions.
+            output = tensor_model_parallel_all_gather(output_parallel)
+        else:
+            output = output_parallel
+        output_bias = self.bias if self.skip_bias_add else None
+        if not self.return_bias:
+            return output
+        return output, output_bias
+
+
+class AscendDenseQKVParallelLinear(QKVParallelLinear):
+    """Linear layer with column parallelism.
+
+    Implemented multiple optimization projects for dense models, such as FlashComm and
+    communication-computation fusion.
+    """
+
+    def forward(
+        self, input_: torch.Tensor
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        bias = self.bias if not self.skip_bias_add else None
+
+        # Matrix multiply.
+        assert self.quant_method is not None
+
+        layer_num = self.prefix.split('.')[2]
+
+        input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+            input_, layer_num != '0')
+        output_parallel = self.quant_method.apply(self, input_, bias)
+
+        if self.gather_output:
+            # All-gather across the partitions.
+            output = tensor_model_parallel_all_gather(output_parallel)
+        else:
+            output = output_parallel
+        output_bias = self.bias if self.skip_bias_add else None
+        if not self.return_bias:
+            return output
+        return output, output_bias
+
+
+class AscendDenseRowParallelLinear(RowParallelLinear):
+    """Linear layer with row parallelism.
+
+    Implemented multiple optimization projects for dense models, such as FlashComm and
+    communication-computation fusion.
+    """
+
+    def forward(
+        self, input_: torch.Tensor
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        if self.input_is_parallel:
+            input_parallel = input_
+        else:
+            tp_rank = get_tensor_model_parallel_rank()
+            splitted_input = split_tensor_along_last_dim(
+                input_, num_partitions=self.tp_size)
+            input_parallel = splitted_input[tp_rank].contiguous()
+
+        # Matrix multiply.
+        assert self.quant_method is not None
+        # Only fuse bias add into GEMM for rank 0 (this ensures that
+        # bias will not get added more than once in TP>1 case)
+        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
+
+        if self.tp_size == 1 or not self.reduce_results:
+            output = self.quant_method.apply(self, input_parallel, bias=bias_)
+        else:
+            output_parallel = self.quant_method.apply(self,
                                                      input_parallel,
                                                      bias=bias_)
+            output = torch.ops.vllm.maybe_pad_and_reduce(output_parallel)
+
+        output_bias = self.bias if self.skip_bias_add else None
+
+        if not self.return_bias:
+            return output
+        return output, output_bias
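
One assumption worth flagging in AscendDenseQKVParallelLinear: layer_num = self.prefix.split('.')[2] presumes a module prefix of the usual "model.layers.<idx>.self_attn.qkv_proj" form, so index 2 is the decoder-layer number and only layer '0' skips the all-gather, since its input has not yet been reduce-scattered by a preceding row-parallel layer. A quick sketch of that parsing, using a hypothetical prefix string:

prefix = "model.layers.0.self_attn.qkv_proj"   # hypothetical, but the typical layout
layer_num = prefix.split('.')[2]
assert layer_num == '0'                        # first decoder layer: label is False, no all-gather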
