
Commit 5ffd8db

Shuming19 authored and rjg-lyh committed
add mlp weight prefetch
1 parent 3ae9205 commit 5ffd8db

File tree: 6 files changed (+92, -3 lines changed)


vllm_ascend/ascend_forward_context.py

Lines changed: 8 additions & 1 deletion

@@ -67,7 +67,9 @@ def set_ascend_forward_context(
         moe_comm_method: str = "",
         num_actual_tokens: Optional[int] = None,
         aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
-        batch_descriptor: Optional[BatchDescriptor] = None):
+        batch_descriptor: Optional[BatchDescriptor] = None,
+        prefetch_stream: torch.npu.Stream = None,
+        prefetch_model: torch.nn.Module = None):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
     We add some additional param into forward_context.
@@ -82,6 +84,11 @@ def set_ascend_forward_context(
             batch_descriptor=batch_descriptor,
     ):
         forward_context = get_forward_context()
+
+        forward_context.prefetch_stream = prefetch_stream
+        forward_context.prefetch_model = prefetch_model
+        forward_context.prefetch_mlp_up = False
+
         forward_context.moe_comm_method_name = moe_comm_method + "commimpl"
         forward_context.with_prefill = with_prefill
         ep_size = (get_ep_group().world_size if
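The two new arguments are simply stashed on the forward context; the ops touched below (activation.py, linear.py, layernorm.py) read them back via get_forward_context(). A minimal sketch of that read-back, not part of the commit (the attribute names match the assignments above; the helper name is illustrative):

```python
# Illustrative sketch: how an op can retrieve the prefetch handles stored above.
from vllm.forward_context import get_forward_context


def get_prefetch_handles():
    ctx = get_forward_context()
    # Both default to None when set_ascend_forward_context() was entered
    # without prefetch_stream / prefetch_model.
    stream = getattr(ctx, "prefetch_stream", None)
    model = getattr(ctx, "prefetch_model", None)
    return stream, model
```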

vllm_ascend/ops/activation.py

Lines changed: 26 additions & 0 deletions

@@ -17,6 +17,7 @@

 import torch
 from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul
+from vllm.forward_context import get_forward_context


 class AscendQuickGELU(QuickGELU):
@@ -29,6 +30,26 @@ def forward_oot(self, x: torch.tensor) -> torch.Tensor:


 class AscendSiluAndMul(SiluAndMul):
+    def prefetch_down_proj(self, dependency: torch.Tensor):
+        import torch_npu
+        forward_context = get_forward_context()
+        prefetch_model = forward_context.prefetch_model
+        prefetch_stream = forward_context.prefetch_stream
+        layer_idx = forward_context.layer_idx
+
+        prefetch_stream.wait_stream(torch.npu.current_stream())
+
+        with torch.npu.stream(prefetch_stream):
+            MLP_DOWN_PREFETCH_SIZE = 6 * 1024 * 1024
+            torch_npu.npu_prefetch(prefetch_model.model.layers[layer_idx].mlp.down_proj.weight,
+                                   dependency, MLP_DOWN_PREFETCH_SIZE)
+        forward_context.layer_idx += 1
+
+    def wait_prefetch_done(self):
+        forward_context = get_forward_context()
+        prefetch_stream = forward_context.prefetch_stream
+        torch.npu.current_stream().wait_stream(prefetch_stream)

     def forward_oot(self, x: torch.Tensor) -> torch.Tensor:
         import torch_npu
@@ -38,5 +59,10 @@ def forward_oot(self, x: torch.Tensor) -> torch.Tensor:
         if is_310p():
             out = torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16)
         else:
+            dependency = x
+            self.prefetch_down_proj(dependency)
+
             out = torch_npu.npu_swiglu(x)
+
+            self.wait_prefetch_done()
         return out
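The pattern in AscendSiluAndMul is: gate a side stream on the activation input, issue the down_proj weight prefetch there, run npu_swiglu on the main stream so the two overlap, then fence before down_proj consumes the weight. A condensed sketch of that pattern, with the npu_prefetch call signature and the 6 MiB budget taken from the diff above (the function and tensor names here are illustrative, not part of the commit):

```python
# Condensed sketch of the overlap pattern above (illustrative, not part of the commit).
import torch
import torch_npu  # Ascend extension, as imported in the diff above

prefetch_stream = torch.npu.Stream()


def swiglu_with_down_proj_prefetch(x: torch.Tensor,
                                   down_proj_weight: torch.Tensor) -> torch.Tensor:
    # Do not let the prefetch start before `x` (the dependency tensor) is ready.
    prefetch_stream.wait_stream(torch.npu.current_stream())

    with torch.npu.stream(prefetch_stream):
        # Prefetch up to 6 MiB of the down_proj weight while the main stream computes.
        torch_npu.npu_prefetch(down_proj_weight, x, 6 * 1024 * 1024)

    out = torch_npu.npu_swiglu(x)  # main-stream compute overlaps with the prefetch

    # Fence: kernels launched after this point (the down_proj matmul) run after the prefetch.
    torch.npu.current_stream().wait_stream(prefetch_stream)
    return out
```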

vllm_ascend/ops/layernorm.py

Lines changed: 14 additions & 0 deletions

@@ -36,6 +36,12 @@ def __init__(
         super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
         self.layer = layer

+    def wait_prefetch_done(self):
+        forward_context = get_forward_context()
+        prefetch_stream = forward_context.prefetch_stream
+        # wait until the mlp weight prefetch on the side stream is done
+        torch.npu.current_stream().wait_stream(prefetch_stream)
+
     def forward(
         self,
         x: torch.Tensor,
@@ -53,10 +59,18 @@ def forward(
                 self.layer.aclnn_input_scale,
                 self.layer.aclnn_input_offset,
                 epsilon=self.variance_epsilon)
+
+            if forward_context.prefetch_mlp_up:
+                self.wait_prefetch_done()
+
             return x, residual

         x, residual = torch_npu.npu_rms_norm(x, self.weight,
                                              self.variance_epsilon)
+
+        forward_context = get_forward_context()
+        if forward_context.prefetch_mlp_up:
+            self.wait_prefetch_done()
         return x
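Note that wait_prefetch_done is a device-side fence, not a host block: wait_stream only orders kernels launched afterwards on the current stream behind the work already queued on the prefetch stream. Assuming torch.npu mirrors the torch.cuda Stream/Event API (an assumption, not something this diff shows), it is equivalent to the usual record/wait-event pair:

```python
# Illustrative equivalence, assuming a torch.cuda-style Event API exists on torch.npu.
import torch

prefetch_stream = torch.npu.Stream()
main_stream = torch.npu.current_stream()

# ... prefetch kernels are enqueued on prefetch_stream ...

event = torch.npu.Event()
event.record(prefetch_stream)  # mark the end of the prefetch work
main_stream.wait_event(event)  # same ordering effect as main_stream.wait_stream(prefetch_stream)
```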

vllm_ascend/ops/linear.py

Lines changed: 31 additions & 0 deletions

@@ -381,6 +381,33 @@ class AscendDenseRowParallelLinear(RowParallelLinear):
     communication-computation fusion.
     """

+    def prefetch_gate_up_proj(self, dependency: torch.Tensor):
+        # get prefetch model
+        forward_context = get_forward_context()
+        layer_num = int(self.prefix.split('.')[2])
+        prefetch_model = forward_context.prefetch_model
+        prefetch_stream = forward_context.prefetch_stream
+
+        # start point of weight prefetch
+        forward_context.prefetch_mlp_up = self.prefix.split('.')[-2] == 'self_attn'
+        if forward_context.prefetch_mlp_up:
+            prefetch_stream.wait_stream(torch.npu.current_stream())
+
+            with torch.npu.stream(prefetch_stream):
+                # For Qwen3-32B
+                MLP_GATE_UP_PREFETCH_SIZE = 50 * 1024 * 1024
+                torch_npu.npu_prefetch(prefetch_model.model.layers[layer_num].mlp.gate_up_proj.weight,
+                                       dependency, MLP_GATE_UP_PREFETCH_SIZE)
+
+    def wait_prefetch_done(self):
+        forward_context = get_forward_context()
+        if forward_context.prefetch_mlp_up:
+            prefetch_stream = forward_context.prefetch_stream
+            # wait until the gate_up weight prefetch on the side stream is done
+            torch.npu.current_stream().wait_stream(prefetch_stream)
+
     def forward(
         self, input_: torch.Tensor
     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
@@ -404,6 +431,10 @@ def forward(
         output_parallel = self.quant_method.apply(self,
                                                   input_parallel,
                                                   bias=bias_)
+        dependency = output_parallel
+
+        self.prefetch_gate_up_proj(dependency)
+
         output = torch.ops.vllm.maybe_pad_and_reduce(output_parallel)

         output_bias = self.bias if self.skip_bias_add else None
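prefetch_gate_up_proj keys entirely off self.prefix, so which RowParallelLinear actually fires the prefetch depends on where it sits in the module tree. Assuming the usual vLLM prefixes for a dense decoder (these example strings are illustrative, not taken from the diff):

```python
# Illustrative: how the prefix parsing above behaves for typical module prefixes.
attn_o_proj = "model.layers.12.self_attn.o_proj"
mlp_down_proj = "model.layers.12.mlp.down_proj"

int(attn_o_proj.split('.')[2])   # 12 -> index used as prefetch_model.model.layers[12]
attn_o_proj.split('.')[-2]       # 'self_attn' -> prefetch_mlp_up = True, gate_up prefetch fires
mlp_down_proj.split('.')[-2]     # 'mlp'       -> prefetch_mlp_up = False, nothing is prefetched
```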

vllm_ascend/worker/model_runner_v1.py

Lines changed: 7 additions & 2 deletions

@@ -181,6 +181,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.dp_size = vllm_config.parallel_config.data_parallel_size
         self.dp_rank = vllm_config.parallel_config.data_parallel_rank
         self.device = device
+        self.prefetch_stream = torch.npu.Stream(device=device)
         self.dtype = self.model_config.dtype
         if envs_ascend.VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION:
             # TODO: drop the env config to use ascend sampler by default
@@ -1497,7 +1498,9 @@ def execute_model(
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
                 batch_descriptor=batch_descriptor,
                 num_actual_tokens=scheduler_output.
-                total_num_scheduled_tokens):
+                total_num_scheduled_tokens,
+                prefetch_stream=self.prefetch_stream,
+                prefetch_model=self.model):
             self.maybe_setup_kv_connector(scheduler_output)

             hidden_states = self._generate_process_reqs_hidden_states(
@@ -1944,7 +1947,9 @@ def dummy_compute_logits(hidden_states):
                 moe_comm_method=moe_comm_method,
                 num_actual_tokens=0,
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
-                batch_descriptor=batch_descriptor):
+                batch_descriptor=batch_descriptor,
+                prefetch_stream=self.prefetch_stream,
+                prefetch_model=self.model):
             hidden_states = self._generate_dummy_run_hidden_states(
                 with_prefill, is_torchair_compile, input_ids, positions,
                 attn_metadata, num_tokens, intermediate_tensors,
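Passing prefetch_model=self.model only works because the helpers above index prefetch_model.model.layers[i].mlp.gate_up_proj / .down_proj, i.e. they assume a dense Qwen/Llama-style decoder layout. A hedged sketch of a guard one could use before wiring the prefetch up (the function is illustrative, not part of the commit):

```python
# Illustrative guard: check that the model exposes the attribute paths the
# prefetch helpers rely on before passing it as prefetch_model.
def supports_mlp_weight_prefetch(model) -> bool:
    layers = getattr(getattr(model, "model", None), "layers", None)
    if not layers:
        return False
    mlp = getattr(layers[0], "mlp", None)
    return hasattr(mlp, "gate_up_proj") and hasattr(mlp, "down_proj")
```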

vllm_ascend/worker/worker_v1.py

Lines changed: 6 additions & 0 deletions

@@ -55,6 +55,12 @@
 else:
     DraftTokenIds = None

+torch._dynamo.trace_rules.clear_lru_cache()
+from torch._dynamo.variables import TorchInGraphFunctionVariable
+torch_non_c_binding_in_graph_functions_npu = dict.fromkeys(["torch.npu.current_stream"], TorchInGraphFunctionVariable)
+torch_non_c_binding_in_graph_functions_npu["torch.npu.stream"] = TorchInGraphFunctionVariable
+torch._dynamo.trace_rules.torch_name_rule_map.append(torch_non_c_binding_in_graph_functions_npu)
+


 class NPUWorker(WorkerBase):
