23 changes: 23 additions & 0 deletions tests/e2e/singlecard/test_ascend_scheduler.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm import SamplingParams

from tests.e2e.conftest import VllmRunner
from tests.e2e.model_utils import check_outputs_equal
@@ -86,3 +87,25 @@ def test_chunked_prefill_with_ascend_scheduler(
name_0="vllm_output",
name_1="chunked_prefill_output",
)


def test_async_scheduling() -> None:
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
] * 10
sampling_params = SamplingParams(temperature=0.2,
max_tokens=10,
stop_token_ids=None)

with VllmRunner(
"Qwen/Qwen2.5-0.5B-Instruct",
max_model_len=4096,
max_num_seqs=50,
dtype="bfloat16",
gpu_memory_utilization=0.9,
async_scheduling=True,
) as vllm_model:
vllm_model.generate(prompts, sampling_params=sampling_params)
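
For readers who want to try the feature outside the test harness, async scheduling is driven by the same flag the runner passes through above. A minimal offline-inference sketch, assuming async_scheduling is forwarded to the engine arguments the same way VllmRunner forwards it here (this is an illustration, not part of the PR):

from vllm import LLM, SamplingParams

# Sketch only: assumes async_scheduling is accepted as an engine argument,
# mirroring the VllmRunner kwargs used in the test above.
llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct",
          max_model_len=4096,
          max_num_seqs=50,
          dtype="bfloat16",
          gpu_memory_utilization=0.9,
          async_scheduling=True)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.2, max_tokens=10))
print(outputs[0].outputs[0].text)
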
218 changes: 194 additions & 24 deletions vllm_ascend/worker/model_runner_v1.py
@@ -63,8 +63,8 @@
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds,
LogprobsTensors, ModelRunnerOutput)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
DraftTokenIds, LogprobsTensors, ModelRunnerOutput)
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
@@ -156,6 +156,53 @@ def graph_capture(device: torch.device):
yield graph_capture_context


# Wrapper for ModelRunnerOutput to support overlapped execution.
class AsyncNPUModelRunnerOutput(AsyncModelRunnerOutput):

def __init__(
self,
model_runner_output: ModelRunnerOutput,
sampled_token_ids: torch.Tensor,
invalid_req_indices: list[int],
async_output_copy_stream: torch.npu.Stream,
):
self._model_runner_output = model_runner_output
self._invalid_req_indices = invalid_req_indices

# Event on the copy stream so we can synchronize the non-blocking copy.
self._async_copy_ready_event = torch.npu.Event()

# Keep a reference to the device tensor to avoid it being
# deallocated until we finish copying it to the host.
self._sampled_token_ids = sampled_token_ids

# Initiate the copy on a separate stream, but do not synchronize it.
default_stream = torch.npu.current_stream()
with torch.npu.stream(async_output_copy_stream):
async_output_copy_stream.wait_stream(default_stream)
self._sampled_token_ids_cpu = self._sampled_token_ids.to(
'cpu', non_blocking=True)
self._async_copy_ready_event.record()

def get_output(self) -> ModelRunnerOutput:
"""Copy the device tensors to the host and return a ModelRunnerOutput.

This function blocks until the copy is finished.
"""
self._async_copy_ready_event.synchronize()

# Release the device tensor once the copy has completed
del self._sampled_token_ids

valid_sampled_token_ids = self._sampled_token_ids_cpu.tolist()
for i in self._invalid_req_indices:
valid_sampled_token_ids[i].clear()

output = self._model_runner_output
output.sampled_token_ids = valid_sampled_token_ids
return output

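The wrapper above follows the standard overlapped device-to-host copy idiom: order the side stream after the producer stream, launch a non-blocking copy into host memory, record an event, and only synchronize when the result is actually consumed. A self-contained sketch of that idiom, written against torch.cuda purely for illustration (the torch.npu Stream/Event interface used in this file mirrors it) and using an explicit pinned staging buffer:

import torch

def start_async_d2h_copy(device_tensor: torch.Tensor,
                         copy_stream: "torch.cuda.Stream"):
    # Capture the stream that produced the tensor before switching streams.
    producer_stream = torch.cuda.current_stream()
    done = torch.cuda.Event()
    # Pinned host buffer so the copy can actually run asynchronously.
    host_buf = torch.empty(device_tensor.shape,
                           dtype=device_tensor.dtype,
                           device="cpu",
                           pin_memory=True)
    with torch.cuda.stream(copy_stream):
        # The copy must see all work already queued by the producer.
        copy_stream.wait_stream(producer_stream)
        host_buf.copy_(device_tensor, non_blocking=True)
        done.record()
    return host_buf, done

if torch.cuda.is_available():
    stream = torch.cuda.Stream()
    sampled = torch.randint(0, 32000, (8, 1), device="cuda")
    host, done = start_async_d2h_copy(sampled, stream)
    # ... the default stream keeps queueing the next step's kernels here ...
    done.synchronize()  # block only when the tokens are needed
    print(host.squeeze(-1).tolist())
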

class NPUModelRunner(LoRAModelRunnerMixin):

def __init__(self, vllm_config: VllmConfig, device: torch.device):
@@ -358,6 +405,10 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
device=self.device,
)

self.use_async_scheduling = self.scheduler_config.async_scheduling
self.async_output_copy_stream = torch.npu.Stream() if \
self.use_async_scheduling else None

def _use_aclgraph(self) -> bool:
return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and self.compilation_config.level == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager

@@ -845,6 +896,76 @@ def _get_cumsum_and_arange(

return cu_num_tokens, arange

def _prepare_input_ids(self, total_num_scheduled_tokens: int,
cu_num_tokens: np.ndarray) -> None:
"""Prepare the input IDs for the current batch.

Carefully handles the `prev_sampled_token_ids` which can be cached
from the previous engine iteration, in which case those tokens on the
NPU need to be copied into the corresponding slots into input_ids."""

if self.input_batch.prev_sampled_token_ids is None:
# Normal scheduling case
self.input_ids[:total_num_scheduled_tokens].copy_(
self.input_ids_cpu[:total_num_scheduled_tokens],
non_blocking=True)
return

# Async scheduling case, where some decode requests from the previous
# iteration won't have entries in input_ids_cpu and need to be copied
# on the NPU from prev_sampled_token_ids.
prev_req_id_to_index = self.input_batch.prev_req_id_to_index
assert prev_req_id_to_index is not None
flattened_indices = []
prev_common_req_indices = []
indices_match = True
max_flattened_index = -1
for req_id, cur_index in self.input_batch.req_id_to_index.items():
if (prev_index := prev_req_id_to_index.get(req_id)) is not None:
prev_common_req_indices.append(prev_index)
# We need to compute the flattened input_ids index of the
# last token in each common request.
flattened_index = cu_num_tokens[cur_index].item() - 1
flattened_indices.append(flattened_index)
indices_match &= (prev_index == flattened_index)
max_flattened_index = max(max_flattened_index, flattened_index)
num_commmon_tokens = len(flattened_indices)
if num_commmon_tokens < total_num_scheduled_tokens:
# If not all requests are decodes from the last iteration,
# We need to copy the input_ids_cpu to the NPU first.
self.input_ids[:total_num_scheduled_tokens].copy_(
self.input_ids_cpu[:total_num_scheduled_tokens],
non_blocking=True)
if num_commmon_tokens == 0:
# No requests in common with the previous iteration
# So input_ids_cpu will have all the input ids.
return
if indices_match and max_flattened_index == (num_commmon_tokens - 1):
# Common-case optimization: the batch is unchanged
# and no reordering happened.
# The indices are both the same permutation of 0..N-1 so
# we can copy directly using a single slice.
self.input_ids[:num_commmon_tokens].copy_(
self.input_batch.prev_sampled_token_ids[:num_commmon_tokens,
0],
non_blocking=True)
return
# Upload the index tensors asynchronously
# so the scatter can be non-blocking.
input_ids_index_tensor = torch.tensor(flattened_indices,
dtype=torch.int64,
pin_memory=self.pin_memory).to(
self.device,
non_blocking=True)
prev_common_req_indices_tensor = torch.tensor(
prev_common_req_indices,
dtype=torch.int64,
pin_memory=self.pin_memory).to(self.device, non_blocking=True)
self.input_ids.scatter_(dim=0,
index=input_ids_index_tensor,
src=self.input_batch.prev_sampled_token_ids[
prev_common_req_indices_tensor, 0])

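The final scatter_ writes each surviving request's previously sampled token into the flattened input_ids slot of its last scheduled token. A tiny CPU-only illustration of that indexing (shapes and values are invented for the example):

import torch

# Flattened input_ids for a step that schedules one token each for 3 slots.
input_ids = torch.zeros(3, dtype=torch.int64)

# Tokens sampled in the previous step, shape [num_prev_reqs, 1].
prev_sampled_token_ids = torch.tensor([[101], [202], [303]])

# Requests at previous indices 0 and 2 are still in the batch; their last
# scheduled tokens sit at flattened positions 0 and 1 in this step.
prev_common_req_indices = torch.tensor([0, 2])
flattened_indices = torch.tensor([0, 1])

input_ids.scatter_(dim=0,
                   index=flattened_indices,
                   src=prev_sampled_token_ids[prev_common_req_indices, 0])
print(input_ids.tolist())  # [101, 303, 0]
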
def _prepare_inputs(
self,
scheduler_output: "SchedulerOutput",
@@ -1033,6 +1154,16 @@ def _prepare_inputs(
if self.vllm_config.model_config.use_mla:
attn_metadata.num_input_tokens = num_input_tokens

# Prepare input_ids
token_indices = (positions_np +
req_indices * self.input_batch.token_ids_cpu.shape[1])
torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
0,
torch.from_numpy(token_indices),
out=self.input_ids_cpu[:total_num_scheduled_tokens])
# Copy the tensors to the NPU.
self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens)

# _prepare_inputs may reorder the batch, so we must gather
# multi-modal outputs after that to ensure the correct order
if self.is_multimodal_model:
@@ -1382,11 +1513,11 @@ def _select_moe_comm_method(self, num_tokens: int) -> str:
2. If expert parallel is enabled, we need to consider the soc version and the
number of tokens. This is based on the observation that all-gather is more
efficient than all-to-all when running on A2.

a. For A2, we choose from MC2 and all-gather.

b. For A3, we choose from MC2 and all-to-all.

In both cases, we use MC2 when the number of tokens is smaller than
its capacity threshold.

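The rule in the docstring above reduces to a small decision function. The sketch below only covers the expert-parallel branch visible here, with assumed argument names and an assumed capacity threshold:

def select_moe_comm_method_sketch(soc_version: str,
                                  num_tokens: int,
                                  mc2_capacity: int) -> str:
    # Small batches use MC2 on both A2 and A3; beyond the threshold, A2 falls
    # back to all-gather and A3 to all-to-all, per the docstring above.
    if num_tokens < mc2_capacity:
        return "mc2"
    return "allgather" if soc_version == "A2" else "alltoall"

print(select_moe_comm_method_sketch("A2", 128, 512))   # mc2
print(select_moe_comm_method_sketch("A2", 4096, 512))  # allgather
print(select_moe_comm_method_sketch("A3", 4096, 512))  # alltoall
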
@@ -1424,7 +1555,7 @@ def execute_model(
self,
scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[ModelRunnerOutput, torch.Tensor]:
) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]:
Review comment on lines 1555 to +1558 (critical):

To correctly handle asynchronous scheduling, the worker's CPU-side state must be updated with the actual token IDs from the previous step. This should happen at the beginning of the current step, before preparing inputs. Without this, features like repetition penalty will use stale or incorrect token history.

Please add state update logic at the start of execute_model to synchronize prev_sampled_token_ids and update the CPU-side token history. Here is a code snippet to illustrate the required logic:

if self.use_async_scheduling and self.input_batch.prev_sampled_token_ids is not None:
    # Sync and update state from previous async step
    prev_sampled_token_ids_cpu = self.input_batch.prev_sampled_token_ids.tolist()
    prev_req_id_to_index = self.input_batch.prev_req_id_to_index
    assert prev_req_id_to_index is not None

    for req_id, prev_req_idx in prev_req_id_to_index.items():
        if req_id not in self.requests:
            continue
        req_state = self.requests[req_id]
        req_idx = self.input_batch.req_id_to_index.get(req_id)
        if req_idx is None:
            continue
        
        sampled_ids = prev_sampled_token_ids_cpu[prev_req_idx]
        if not sampled_ids:
            continue

        req_state.output_token_ids.extend(sampled_ids)

        start_idx = self.input_batch.num_tokens_no_spec[req_idx]
        end_idx = start_idx + len(sampled_ids)
        assert end_idx <= self.model_config.max_model_len
        self.input_batch.token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids
        self.input_batch.num_tokens_no_spec[req_idx] = end_idx
        self.input_batch.num_tokens[req_idx] = end_idx
    
    # Clear the prev step's data
    self.input_batch.prev_sampled_token_ids = None
    self.input_batch.prev_req_id_to_index = None

with ProfileExecuteDuration().capture_async("prepare input"):
self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens:
@@ -1580,6 +1711,12 @@ def execute_model(
generator.set_offset(generator.get_offset() - 4)
discard_sampled_tokens_req_indices.append(i)

# Copy some objects so they don't get modified after returning.
# This is important when using async scheduling.
req_ids_output_copy = self.input_batch.req_ids.copy()
req_id_to_index_output_copy = \
self.input_batch.req_id_to_index.copy()

# NOTE: NPU -> CPU Sync happens here.
# Move as many CPU operations as possible before this sync point.
logprobs_tensors = sampler_output.logprobs_tensors
@@ -1592,27 +1729,52 @@ def execute_model(
scheduler_output,
)

# Get the valid generated tokens.
num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
sampled_token_ids = sampler_output.sampled_token_ids
max_gen_len = sampled_token_ids.shape[-1]
if max_gen_len == 1:
# No spec decode tokens.
valid_sampled_token_ids = sampled_token_ids.tolist()
if not self.use_async_scheduling:
# Get the valid generated tokens.
max_gen_len = sampled_token_ids.shape[-1]
if max_gen_len == 1:
# No spec decode tokens.
valid_sampled_token_ids = sampled_token_ids.tolist()
else:
# Includes spec decode tokens.
valid_sampled_token_ids = self.rejection_sampler.parse_output(
sampled_token_ids,
self.input_batch.vocab_size,
)
# Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[i].clear()
else:
# Includes spec decode tokens.
valid_sampled_token_ids = self.rejection_sampler.parse_output(
sampled_token_ids,
self.input_batch.vocab_size,
)

for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[i].clear()
# Cache the sampled tokens in the model runner, so that the scheduler
valid_sampled_token_ids = []
invalid_req_indices = list(discard_sampled_tokens_req_indices)
invalid_req_indices_set = set(invalid_req_indices)
assert sampled_token_ids.shape[-1] == 1

# Cache the sampled tokens on the NPU and avoid CPU sync.
# These will be copied into input_ids in the next step
# when preparing inputs.
self.input_batch.prev_sampled_token_ids = \
sampled_token_ids
self.input_batch.prev_sampled_token_ids_invalid_indices = \
invalid_req_indices_set
Review comment on lines +1760 to +1761 (high):

The attribute self.input_batch.prev_sampled_token_ids_invalid_indices is assigned here but is never read or used anywhere. This appears to be dead code and should be removed along with its definition in InputBatch to improve clarity and reduce maintenance overhead.

self.input_batch.prev_req_id_to_index = {
req_id: i
for i, req_id in enumerate(self.input_batch.req_ids)
if i not in invalid_req_indices_set
}
# Cache the sampled tokens in the model runner, so that the scheduler
# doesn't need to send them back.
# NOTE(woosuk): As an exception, when using PP, the scheduler sends
# the sampled tokens back, because there's no direct communication
# between the first-stage worker and the last-stage worker.
for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
for req_idx in range(num_sampled_tokens):
if self.use_async_scheduling:
sampled_ids = [-1] * 1 if \
req_idx not in invalid_req_indices_set else None
else:
sampled_ids = valid_sampled_token_ids[req_idx]
if not sampled_ids:
continue

@@ -1650,8 +1812,8 @@ def execute_model(
extra_args = ({"kv_connector_output": kv_connector_output})

model_runner_output = ModelRunnerOutput(
req_ids=self.input_batch.req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
req_ids=req_ids_output_copy,
req_id_to_index=req_id_to_index_output_copy,
sampled_token_ids=valid_sampled_token_ids,
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
@@ -1669,7 +1831,15 @@ def execute_model(
logger.info("Profile execute duration [%s]:%s", captured_name,
" ".join(dr_str))

return model_runner_output
if not self.use_async_scheduling:
return model_runner_output

return AsyncNPUModelRunnerOutput(
model_runner_output=model_runner_output,
sampled_token_ids=sampled_token_ids,
invalid_req_indices=invalid_req_indices,
async_output_copy_stream=self.async_output_copy_stream,
)

def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
if self._draft_token_ids is None:
5 changes: 5 additions & 0 deletions vllm_ascend/worker/npu_input_batch.py
@@ -263,6 +263,11 @@ def __init__(

self.pooling_params: dict[str, PoolingParams] = {}

# Cached reference to the GPU tensor of previously sampled tokens
self.prev_sampled_token_ids: Optional[torch.Tensor] = None
self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None
Review comment (high):

The attribute prev_sampled_token_ids_invalid_indices is defined here but is never read or used anywhere in the codebase. This appears to be dead code and should be removed to avoid confusion and reduce maintenance overhead.

self.prev_req_id_to_index: Optional[dict[str, int]] = None

@property
def req_ids(self) -> list[str]:
# None elements should only be present transiently
10 changes: 5 additions & 5 deletions vllm_ascend/worker/worker_v1.py
@@ -18,7 +18,7 @@
#

import copy
from typing import Optional
from typing import Optional, Union

import torch
import torch.nn as nn
@@ -38,8 +38,8 @@
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, GiB_bytes
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds,
ModelRunnerOutput)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
DraftTokenIds, ModelRunnerOutput)
from vllm.v1.worker.worker_base import WorkerBase

from vllm_ascend.ascend_config import init_ascend_config
@@ -191,7 +191,7 @@ def determine_available_memory(self) -> int:
def execute_model(
self,
scheduler_output: "SchedulerOutput",
) -> Optional[ModelRunnerOutput]:
) -> Optional[Union[ModelRunnerOutput, AsyncModelRunnerOutput]]:
intermediate_tensors = None
if not get_pp_group().is_first_rank:
intermediate_tensors = IntermediateTensors(
@@ -220,7 +220,7 @@ def execute_model(
new_output.kv_connector_output = kv_connector_output
return new_output

assert isinstance(output, ModelRunnerOutput)
assert isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput))
return output
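
With async scheduling on, the worker can now hand back an AsyncModelRunnerOutput whose get_output() blocks on the copy event only when the caller actually needs the tokens. A stand-in sketch of that consumer contract (everything except the get_output name is invented for illustration):

from typing import Any, Union

class FakeAsyncOutput:
    # Stand-in with the same get_output() contract as AsyncModelRunnerOutput.
    def __init__(self, payload: Any):
        self._payload = payload

    def get_output(self) -> Any:
        # In the real wrapper this is where the copy event is synchronized.
        return self._payload

def materialize(output: Union[dict, FakeAsyncOutput]) -> dict:
    # The caller needs exactly one extra branch: unwrap the async result
    # when the worker deferred the device-to-host copy.
    return output.get_output() if isinstance(output, FakeAsyncOutput) else output

print(materialize({"sampled_token_ids": [[1], [2]]}))
print(materialize(FakeAsyncOutput({"sampled_token_ids": [[3]]})))
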

def load_model(self) -> None: