
Commit 617af62

[main] refactor and support in aclgraph
Signed-off-by: rjg-lyh <[email protected]>
1 parent 5ffd8db commit 617af62

File tree: 7 files changed, +135 -87 lines

vllm_ascend/ascend_forward_context.py

Lines changed: 12 additions & 6 deletions

@@ -85,10 +85,6 @@ def set_ascend_forward_context(
 ):
     forward_context = get_forward_context()

-    forward_context.prefetch_stream = prefetch_stream
-    forward_context.prefetch_model = prefetch_model
-    forward_context.prefetch_mlp_up = False
-
     forward_context.moe_comm_method_name = moe_comm_method + "commimpl"
     forward_context.with_prefill = with_prefill
     ep_size = (get_ep_group().world_size if
@@ -112,8 +108,18 @@ def set_ascend_forward_context(
     # due to multiple warmups before actual capturing
     forward_context.capturing = False

-    # set this for rope forward_oot using
-    forward_context.is_first_layer = True
+    # set this for layer index
+    forward_context.layer_idx = 0
+
+    # set for mlp weight prefetch
+    prefetch_mlp_enabled = envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP and \
+        num_tokens is not None and num_tokens < 500
+    if prefetch_mlp_enabled:
+        forward_context.prefetch_stream = prefetch_stream
+        forward_context.prefetch_model = prefetch_model
+        forward_context.prefetch_mlp_gate_up_proj = False
+        forward_context.prefetch_mlp_down_proj = False
+    forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled

     # set for flashcomm_v1
     flashcomm_v1_enabled = envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM and \

vllm_ascend/envs.py

Lines changed: 9 additions & 0 deletions

@@ -142,6 +142,15 @@
     # this feature in eager mode will get better performance.
     "VLLM_ASCEND_ENABLE_MLP_OPTIMIZE":
     lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MLP_OPTIMIZE", '0'))),
+    # Whether to enable MLP weight prefetch, only used in decode.
+    "VLLM_ASCEND_ENABLE_PREFETCH_MLP":
+    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", '0'))),
+    # buffer size for gate up prefetch
+    "MLP_GATE_UP_PREFETCH_SIZE":
+    lambda: int(os.getenv("MLP_GATE_UP_PREFETCH_SIZE", 18 * 1024 * 1024)),
+    # buffer size for down proj prefetch
+    "MLP_DOWN_PREFETCH_SIZE":
+    lambda: int(os.getenv("MLP_DOWN_PREFETCH_SIZE", 18 * 1024 * 1024)),
     # Determine the number of physical devices in a non-full-use scenario
     # caused by the initialization of the Mooncake connector.
     "PHYSICAL_DEVICES":

vllm_ascend/ops/activation.py

Lines changed: 2 additions & 26 deletions

@@ -17,7 +17,6 @@

 import torch
 from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul
-from vllm.forward_context import get_forward_context


 class AscendQuickGELU(QuickGELU):
@@ -30,26 +29,6 @@ def forward_oot(self, x: torch.tensor) -> torch.Tensor:


 class AscendSiluAndMul(SiluAndMul):
-    def prefetch_down_proj(self,
-                           dependency: torch.Tensor):
-        import torch_npu
-        forward_context = get_forward_context()
-        prefetch_model = forward_context.prefetch_model
-        prefetch_stream = forward_context.prefetch_stream
-        layer_idx = forward_context.layer_idx
-
-        prefetch_stream.wait_stream(torch.npu.current_stream())
-
-        with torch.npu.stream(prefetch_stream):
-            MLP_DOWN_PREFETCH_SIZE = 6 * 1024 * 1024
-            torch_npu.npu_prefetch(prefetch_model.model.layers[layer_idx].mlp.down_proj.weight, \
-                                   dependency, MLP_DOWN_PREFETCH_SIZE)
-        forward_context.layer_idx += 1
-
-    def wait_prefetch_done(self):
-        forward_context = get_forward_context()
-        prefetch_stream = forward_context.prefetch_stream
-        torch.npu.current_stream().wait_stream(prefetch_stream)

     def forward_oot(self, x: torch.Tensor) -> torch.Tensor:
         import torch_npu
@@ -59,10 +38,7 @@ def forward_oot(self, x: torch.Tensor) -> torch.Tensor:
         if is_310p():
             out = torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16)
         else:
-            dependency = x
-            self.prefetch_down_proj(dependency)
-
+            torch.ops.vllm.maybe_prefetch_mlp_down_proj(x)
             out = torch_npu.npu_swiglu(x)
-
-            self.wait_prefetch_done()
+            torch.ops.vllm.maybe_wait_prefetch_done(out)
         return out

vllm_ascend/ops/flashcomm_gate_ops.py

Lines changed: 99 additions & 0 deletions

@@ -1,12 +1,14 @@
 import torch
 import torch.nn.functional as F
+import torch_npu
 from vllm.utils import direct_register_custom_op
 from vllm.distributed import (tensor_model_parallel_all_gather,
                               tensor_model_parallel_reduce_scatter,
                               tensor_model_parallel_all_reduce,
                               get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.forward_context import get_forward_context
+import vllm_ascend.envs as envs_ascend


 def _maybe_chunk_residual_impl(x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
@@ -16,6 +18,9 @@ def _maybe_chunk_residual_impl(x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
             "Currently, this situation only occurs "
            "when flashcomm_v1 is enabled"
         )
+        pad_size = get_forward_context().pad_size
+        if pad_size > 0:
+            residual = F.pad(residual, (0, 0, 0, pad_size))
         tp_size = get_tensor_model_parallel_world_size()
         tp_rank = get_tensor_model_parallel_rank()
         residual = torch.chunk(residual, tp_size, dim=0)[tp_rank]
@@ -44,6 +49,73 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
     return tensor_model_parallel_all_reduce(x)


+def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor, prefix: str) -> None:
+    forward_context = get_forward_context()
+    if not forward_context.prefetch_mlp_enabled:
+        return
+    prefetch_model = forward_context.prefetch_model
+    prefetch_stream = forward_context.prefetch_stream
+    layer_idx = int(prefix.split('.')[2])
+
+    # start point of gate_up_proj weight prefetch
+    if prefix.split('.')[-2] == "self_attn":
+        forward_context.prefetch_mlp_gate_up_proj = True
+    if forward_context.prefetch_mlp_gate_up_proj:
+        prefetch_stream.wait_stream(torch.npu.current_stream())
+
+        with torch.npu.stream(prefetch_stream):
+            MLP_GATE_UP_PREFETCH_SIZE = envs_ascend.MLP_GATE_UP_PREFETCH_SIZE
+            torch_npu.npu_prefetch(prefetch_model.model.layers[layer_idx].mlp.gate_up_proj.weight, \
+                                   x_dependency, MLP_GATE_UP_PREFETCH_SIZE)
+    return
+
+
+def _maybe_prefetch_mlp_gate_up_proj_impl_fake(x_dependency: torch.Tensor, prefix: str) -> None:
+    return
+
+
+def _maybe_prefetch_mlp_down_proj_impl(x_dependency: torch.Tensor) -> None:
+    forward_context = get_forward_context()
+    if not forward_context.prefetch_mlp_enabled:
+        return
+    forward_context.prefetch_mlp_down_proj = True
+    prefetch_model = forward_context.prefetch_model
+    prefetch_stream = forward_context.prefetch_stream
+    layer_idx = forward_context.layer_idx
+
+    # start point of down_proj weight prefetch
+    prefetch_stream.wait_stream(torch.npu.current_stream())
+
+    with torch.npu.stream(prefetch_stream):
+        MLP_DOWN_PREFETCH_SIZE = envs_ascend.MLP_DOWN_PREFETCH_SIZE
+        torch_npu.npu_prefetch(prefetch_model.model.layers[layer_idx].mlp.down_proj.weight, \
+                               x_dependency, MLP_DOWN_PREFETCH_SIZE)
+    forward_context.layer_idx += 1
+    return
+
+
+def _maybe_prefetch_mlp_down_proj_impl_fake(x_dependency: torch.Tensor) -> None:
+    return
+
+
+def _maybe_wait_prefetch_done_impl(x: torch.Tensor) -> None:
+    forward_context = get_forward_context()
+    if not forward_context.prefetch_mlp_enabled:
+        return
+    if forward_context.prefetch_mlp_gate_up_proj or \
+       forward_context.prefetch_mlp_down_proj:
+        prefetch_stream = get_forward_context().prefetch_stream
+        # wait until prefetch done
+        torch.npu.current_stream().wait_stream(prefetch_stream)
+        forward_context.prefetch_mlp_gate_up_proj = False
+        forward_context.prefetch_mlp_down_proj = False
+    return
+
+
+def _maybe_wait_prefetch_done_impl_fake(x: torch.Tensor) -> None:
+    return
+
+
 direct_register_custom_op(
     op_name="maybe_chunk_residual",
     op_func=_maybe_chunk_residual_impl,
@@ -69,3 +141,30 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
     mutates_args=[],
     dispatch_key="PrivateUse1"
 )
+
+
+direct_register_custom_op(
+    op_name="maybe_prefetch_mlp_gate_up_proj",
+    op_func=_maybe_prefetch_mlp_gate_up_proj_impl,
+    fake_impl=_maybe_prefetch_mlp_gate_up_proj_impl_fake,
+    mutates_args=[],
+    dispatch_key="PrivateUse1"
+)
+
+
+direct_register_custom_op(
+    op_name="maybe_prefetch_mlp_down_proj",
+    op_func=_maybe_prefetch_mlp_down_proj_impl,
+    fake_impl=_maybe_prefetch_mlp_down_proj_impl_fake,
+    mutates_args=[],
+    dispatch_key="PrivateUse1"
+)
+
+
+direct_register_custom_op(
+    op_name="maybe_wait_prefetch_done",
+    op_func=_maybe_wait_prefetch_done_impl,
+    fake_impl=_maybe_wait_prefetch_done_impl_fake,
+    mutates_args=[],
+    dispatch_key="PrivateUse1"
+)
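
Note: the prefetch hooks are registered as vLLM custom ops so they stay traceable under aclgraph capture, and each one is a no-op unless prefetch_mlp_enabled is set on the forward context. A rough sketch of the intended call pattern (the layer callables below are placeholders for illustration, not the real vLLM modules):

    import torch

    def mlp_block_sketch(attn_out: torch.Tensor, gate_up_proj, act_fn, down_proj, prefix: str):
        # issued from the attention output projection: start fetching gate_up_proj weights
        torch.ops.vllm.maybe_prefetch_mlp_gate_up_proj(attn_out, prefix)
        h = gate_up_proj(attn_out)
        # before the activation: start fetching down_proj weights
        torch.ops.vllm.maybe_prefetch_mlp_down_proj(h)
        h = act_fn(h)
        # block the compute stream until the prefetch stream is done
        torch.ops.vllm.maybe_wait_prefetch_done(h)
        return down_proj(h)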

vllm_ascend/ops/layernorm.py

Lines changed: 2 additions & 14 deletions

@@ -36,12 +36,6 @@ def __init__(
         super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
         self.layer = layer

-    def wait_prefetch_done(self):
-        forward_context = get_forward_context()
-        prefetch_stream = forward_context.prefetch_stream
-        # wait until
-        torch.npu.current_stream().wait_stream(prefetch_stream)
-
     def forward(
         self,
         x: torch.Tensor,
@@ -59,18 +53,11 @@ def forward(
                 self.layer.aclnn_input_scale,
                 self.layer.aclnn_input_offset,
                 epsilon=self.variance_epsilon)
-
-            if forward_context.prefetch_mlp_up:
-                self.wait_prefetch_done()
-
+            torch.ops.vllm.maybe_wait_prefetch_done(x)
             return x, residual

         x, residual = torch_npu.npu_rms_norm(x, self.weight,
                                              self.variance_epsilon)
-
-        forward_context = get_forward_context()
-        if forward_context.prefetch_mlp_up:
-            self.wait_prefetch_done()
         return x

@@ -96,6 +83,7 @@ def forward_oot(
         else:
             x, _, residual = torch_npu.npu_add_rms_norm(
                 x, residual, self.weight, self.variance_epsilon)
+            torch.ops.vllm.maybe_wait_prefetch_done(x)
             return x, residual

         x, residual = torch_npu.npu_rms_norm(x, self.weight,

vllm_ascend/ops/linear.py

Lines changed: 2 additions & 37 deletions

@@ -24,22 +24,17 @@
                               split_tensor_along_last_dim,
                               tensor_model_parallel_all_gather,
                               tensor_model_parallel_all_reduce)
-from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.quantization.base_config import \
     QuantizationConfig
 from vllm.model_executor.utils import set_weight_attrs

 from vllm_ascend.distributed.parallel_state import (
     get_mlp_tensor_model_parallel_rank,
     get_mlp_tensor_model_parallel_world_size, get_mlp_tp_group)
-from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
-from vllm_ascend.utils import (all_gather_and_maybe_unpad,
-                               maybe_pad_and_reduce_scatter)

 from vllm.model_executor.layers.linear import (  # isort: skip
     WEIGHT_LOADER_V2_SUPPORTED, ColumnParallelLinear, LinearBase,
-    MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear,
-    UnquantizedLinearMethod)
+    MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear)


 class AscendMlpColumnParallelLinear(ColumnParallelLinear):
@@ -381,33 +376,6 @@ class AscendDenseRowParallelLinear(RowParallelLinear):
     communication-computation fusion.
     """

-    def prefetch_gate_up_proj(self,
-                              dependency: torch.Tensor):
-        # get prefetch model
-        forward_context = get_forward_context()
-        layer_num = int(self.prefix.split('.')[2])
-        prefetch_model = forward_context.prefetch_model
-        prefetch_stream = forward_context.prefetch_stream
-
-        # start point of weight prefetch
-        forward_context.prefetch_mlp_up = True if self.prefix.split('.')[-2] == 'self_attn' else False
-        if forward_context.prefetch_mlp_up:
-            prefetch_stream.wait_stream(torch.npu.current_stream())
-
-            with torch.npu.stream(prefetch_stream):
-                # For Qwen3-32B
-                MLP_GATE_UP_PREFETCH_SIZE = 50 * 1024 * 1024
-                torch_npu.npu_prefetch(prefetch_model.model.layers[layer_num].mlp.gate_up_proj.weight, \
-                                       dependency, MLP_GATE_UP_PREFETCH_SIZE)
-
-
-    def wait_prefetch_done(self):
-        forward_context = get_forward_context()
-        if forward_context.prefetch_mlp_up:
-            prefetch_stream = forward_context.prefetch_stream
-            # wait until reduce-scatter is done
-            torch.npu.current_stream().wait_stream(prefetch_stream)
-
     def forward(
         self, input_: torch.Tensor
     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
@@ -431,11 +399,8 @@ def forward(
         output_parallel = self.quant_method.apply(self,
                                                   input_parallel,
                                                   bias=bias_)
-        dependency = output_parallel
-
-        self.prefetch_gate_up_proj(dependency)
-
         output = torch.ops.vllm.maybe_pad_and_reduce(output_parallel)
+        torch.ops.vllm.maybe_prefetch_mlp_gate_up_proj(output, self.prefix)

         output_bias = self.bias if self.skip_bias_add else None

vllm_ascend/worker/model_runner_v1.py

Lines changed: 9 additions & 4 deletions

@@ -37,6 +37,7 @@
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.monitor import set_cudagraph_capturing_enabled
 from vllm.config import CompilationLevel, CUDAGraphMode, VllmConfig
+from vllm.distributed import tensor_model_parallel_all_gather
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
 from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
@@ -181,7 +182,10 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.dp_size = vllm_config.parallel_config.data_parallel_size
         self.dp_rank = vllm_config.parallel_config.data_parallel_rank
         self.device = device
-        self.prefetch_stream = torch.npu.Stream(device=device)
+        if envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP:
+            self.prefetch_stream = torch.npu.Stream(device=device)
+        else:
+            self.prefetch_stream = None
         self.dtype = self.model_config.dtype
         if envs_ascend.VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION:
             # TODO: drop the env config to use ascend sampler by default
@@ -1145,9 +1149,10 @@ def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,
             inputs_embeds=inputs_embeds,
         )
         if get_forward_context().flashcomm_v1_enabled:
-            from vllm_ascend.utils import all_gather_and_maybe_unpad
-            hidden_states = all_gather_and_maybe_unpad(
-                hidden_states, get_forward_context().pad_size, dim=0)
+            hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
+            pad_size = get_forward_context().pad_size
+            if pad_size > 0:
+                hidden_states = hidden_states[:-pad_size, :]
         return hidden_states

     def _build_attn_state(self, num_reqs, num_scheduled_tokens,
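
Note: a toy shape check (made-up sizes, standalone) of the flashcomm_v1 pad/unpad round-trip this change relies on: the residual is padded to a multiple of the TP world size before per-rank chunking (see _maybe_chunk_residual_impl above), and the gathered hidden states are trimmed by the same pad_size here:

    import torch
    import torch.nn.functional as F

    tp_size, num_tokens, hidden = 4, 10, 8
    pad_size = (-num_tokens) % tp_size                 # 2 extra rows to reach a multiple of tp_size
    residual = torch.randn(num_tokens, hidden)
    padded = F.pad(residual, (0, 0, 0, pad_size))      # pad rows, as _maybe_chunk_residual_impl does
    per_rank = torch.chunk(padded, tp_size, dim=0)[0]  # each rank keeps an equal chunk

    gathered = torch.cat([per_rank] * tp_size, dim=0)  # stand-in for tensor_model_parallel_all_gather
    unpadded = gathered[:-pad_size, :] if pad_size > 0 else gathered
    assert unpadded.shape == residual.shape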
