@@ -181,6 +181,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
181
181
self .dp_size = vllm_config .parallel_config .data_parallel_size
182
182
self .dp_rank = vllm_config .parallel_config .data_parallel_rank
183
183
self .device = device
184
+ self .prefetch_stream = torch .npu .Stream (device = device )
184
185
self .dtype = self .model_config .dtype
185
186
if envs_ascend .VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION :
186
187
# TODO: drop the env config to use ascend sampler by default
@@ -1497,7 +1498,9 @@ def execute_model(
1497
1498
aclgraph_runtime_mode = aclgraph_runtime_mode ,
1498
1499
batch_descriptor = batch_descriptor ,
1499
1500
num_actual_tokens = scheduler_output .
1500
- total_num_scheduled_tokens ):
1501
+ total_num_scheduled_tokens ,
1502
+ prefetch_stream = self .prefetch_stream ,
1503
+ prefetch_model = self .model ):
1501
1504
self .maybe_setup_kv_connector (scheduler_output )
1502
1505
1503
1506
hidden_states = self ._generate_process_reqs_hidden_states (
@@ -1944,7 +1947,9 @@ def dummy_compute_logits(hidden_states):
1944
1947
moe_comm_method = moe_comm_method ,
1945
1948
num_actual_tokens = 0 ,
1946
1949
aclgraph_runtime_mode = aclgraph_runtime_mode ,
1947
- batch_descriptor = batch_descriptor ):
1950
+ batch_descriptor = batch_descriptor ,
1951
+ prefetch_stream = self .prefetch_stream ,
1952
+ prefetch_model = self .model ):
1948
1953
hidden_states = self ._generate_dummy_run_hidden_states (
1949
1954
with_prefill , is_torchair_compile , input_ids , positions ,
1950
1955
attn_metadata , num_tokens , intermediate_tensors ,
0 commit comments