1
1
import torch
2
2
import torch .nn .functional as F
3
+ import torch_npu
3
4
from vllm .utils import direct_register_custom_op
4
5
from vllm .distributed import (tensor_model_parallel_all_gather ,
5
6
tensor_model_parallel_reduce_scatter ,
6
7
tensor_model_parallel_all_reduce ,
7
8
get_tensor_model_parallel_rank ,
8
9
get_tensor_model_parallel_world_size )
9
10
from vllm .forward_context import get_forward_context
11
+ import vllm_ascend .envs as envs_ascend
10
12
11
13
12
14
def _maybe_chunk_residual_impl (x : torch .Tensor , residual : torch .Tensor ) -> torch .Tensor :
@@ -44,6 +46,73 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
44
46
return tensor_model_parallel_all_reduce (x )
45
47
46
48
49
def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,
                                          prefix: str) -> None:
    """Issue an async NPU prefetch of this layer's MLP gate_up_proj weight.

    No-op unless MLP prefetching is enabled in the forward context. The
    prefetch runs on a dedicated side stream, which first waits on the main
    stream so it does not overtake the work producing ``x_dependency``.
    """
    ctx = get_forward_context()
    if not ctx.prefetch_mlp_enabled:
        return

    # Reaching self_attn marks the start point of gate_up_proj prefetching;
    # the flag may also already be set from an earlier call.
    if prefix.split('.')[-2] == "self_attn":
        ctx.prefetch_mlp_gate_up_proj = True
    if not ctx.prefetch_mlp_gate_up_proj:
        return

    # prefix looks like "model.layers.<idx>..." — extract the layer index.
    target_layer = int(prefix.split('.')[2])
    side_stream = ctx.prefetch_stream

    side_stream.wait_stream(torch.npu.current_stream())
    with torch.npu.stream(side_stream):
        weight = ctx.prefetch_model.model.layers[target_layer] \
            .mlp.gate_up_proj.weight
        torch_npu.npu_prefetch(weight, x_dependency,
                               envs_ascend.MLP_GATE_UP_PREFETCH_SIZE)
    return
68
+
69
+
70
+ def _maybe_prefetch_mlp_gate_up_proj_impl_fake (x_dependency : torch .Tensor , prefix : str ) -> None :
71
+ return
72
+
73
+
74
def _maybe_prefetch_mlp_down_proj_impl(x_dependency: torch.Tensor) -> None:
    """Issue an async NPU prefetch of the current layer's MLP down_proj weight.

    No-op unless MLP prefetching is enabled in the forward context. Uses and
    then advances ``forward_context.layer_idx`` as the running layer cursor.
    """
    ctx = get_forward_context()
    if not ctx.prefetch_mlp_enabled:
        return

    ctx.prefetch_mlp_down_proj = True
    side_stream = ctx.prefetch_stream
    current_layer = ctx.layer_idx

    # Start point of down_proj weight prefetch: side stream must not run
    # ahead of work already queued on the main stream.
    side_stream.wait_stream(torch.npu.current_stream())
    with torch.npu.stream(side_stream):
        weight = ctx.prefetch_model.model.layers[current_layer] \
            .mlp.down_proj.weight
        torch_npu.npu_prefetch(weight, x_dependency,
                               envs_ascend.MLP_DOWN_PREFETCH_SIZE)
    ctx.layer_idx += 1
    return
92
+
93
+
94
+ def _maybe_prefetch_mlp_down_proj_impl_fake (x_dependency : torch .Tensor ) -> None :
95
+ return
96
+
97
+
98
def _maybe_wait_prefetch_done_impl(x: torch.Tensor) -> None:
    """Make the main stream wait for any in-flight MLP weight prefetch.

    No-op unless MLP prefetching is enabled. If a gate_up_proj or down_proj
    prefetch was issued on the side stream, the current (main) stream waits
    on that stream, then both in-flight flags are cleared.

    Args:
        x: tensor argument that is not read here; presumably it exists to
           give the custom op a data dependency — confirm against callers.
    """
    forward_context = get_forward_context()
    if not forward_context.prefetch_mlp_enabled:
        return
    if forward_context.prefetch_mlp_gate_up_proj or \
            forward_context.prefetch_mlp_down_proj:
        # Fix: reuse the already-fetched context instead of calling
        # get_forward_context() a second time (redundant lookup).
        prefetch_stream = forward_context.prefetch_stream
        # Wait until prefetch done.
        torch.npu.current_stream().wait_stream(prefetch_stream)
        forward_context.prefetch_mlp_gate_up_proj = False
        forward_context.prefetch_mlp_down_proj = False
    return
110
+
111
+
112
+ def _maybe_wait_prefetch_done_impl_fake (x : torch .Tensor ) -> None :
113
+ return
114
+
115
+
47
116
direct_register_custom_op (
48
117
op_name = "maybe_chunk_residual" ,
49
118
op_func = _maybe_chunk_residual_impl ,
@@ -69,3 +138,30 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
69
138
mutates_args = [],
70
139
dispatch_key = "PrivateUse1"
71
140
)
141
+
142
+
143
# Register each prefetch helper as a custom op on the NPU (PrivateUse1)
# dispatch key, pairing the real implementation with its no-op fake impl.
for _op_name, _op_func, _fake_impl in (
    ("maybe_prefetch_mlp_gate_up_proj",
     _maybe_prefetch_mlp_gate_up_proj_impl,
     _maybe_prefetch_mlp_gate_up_proj_impl_fake),
    ("maybe_prefetch_mlp_down_proj",
     _maybe_prefetch_mlp_down_proj_impl,
     _maybe_prefetch_mlp_down_proj_impl_fake),
    ("maybe_wait_prefetch_done",
     _maybe_wait_prefetch_done_impl,
     _maybe_wait_prefetch_done_impl_fake),
):
    direct_register_custom_op(
        op_name=_op_name,
        op_func=_op_func,
        fake_impl=_fake_impl,
        mutates_args=[],
        dispatch_key="PrivateUse1",
    )
0 commit comments