[Perf][V1] Fully overlap model execution #2783
Merged · +227 −29
@@ -63,8 +63,8 @@
 from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
-from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds,
-                             LogprobsTensors, ModelRunnerOutput)
+from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
+                             DraftTokenIds, LogprobsTensors, ModelRunnerOutput)
 from vllm.v1.pool.metadata import PoolingMetadata
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
@@ -156,6 +156,53 @@ def graph_capture(device: torch.device):
         yield graph_capture_context


+# Wrapper for ModelRunnerOutput to support overlapped execution.
+class AsyncNPUModelRunnerOutput(AsyncModelRunnerOutput):
+
+    def __init__(
+        self,
+        model_runner_output: ModelRunnerOutput,
+        sampled_token_ids: torch.Tensor,
+        invalid_req_indices: list[int],
+        async_output_copy_stream: torch.npu.Stream,
+    ):
+        self._model_runner_output = model_runner_output
+        self._invalid_req_indices = invalid_req_indices
+
+        # Event on the copy stream so we can synchronize the non-blocking copy.
+        self._async_copy_ready_event = torch.npu.Event()
+
+        # Keep a reference to the device tensor to avoid it being
+        # deallocated until we finish copying it to the host.
+        self._sampled_token_ids = sampled_token_ids
+
+        # Initiate the copy on a separate stream, but do not synchronize it.
+        default_stream = torch.npu.current_stream()
+        with torch.npu.stream(async_output_copy_stream):
+            async_output_copy_stream.wait_stream(default_stream)
+            self._sampled_token_ids_cpu = self._sampled_token_ids.to(
+                'cpu', non_blocking=True)
+            self._async_copy_ready_event.record()
+
+    def get_output(self) -> ModelRunnerOutput:
+        """Copy the device tensors to the host and return a ModelRunnerOutput.
+
+        This function blocks until the copy is finished.
+        """
+        self._async_copy_ready_event.synchronize()
+
+        # Release the device tensor once the copy has completed
+        del self._sampled_token_ids
+
+        valid_sampled_token_ids = self._sampled_token_ids_cpu.tolist()
+        for i in self._invalid_req_indices:
+            valid_sampled_token_ids[i].clear()
+
+        output = self._model_runner_output
+        output.sampled_token_ids = valid_sampled_token_ids
+        return output
+
+
 class NPUModelRunner(LoRAModelRunnerMixin):

     def __init__(self, vllm_config: VllmConfig, device: torch.device):
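The wrapper above is the standard stream/event recipe for overlapping a device-to-host copy with the next step's work: launch a non-blocking copy on a side stream that first waits on the default stream, record an event, and synchronize only when the host values are actually consumed. Below is a minimal, self-contained sketch of that recipe; it uses torch.cuda streams in place of torch.npu so it runs on stock PyTorch, and the helper name `async_copy_to_host` is illustrative, not part of this PR.

```python
import torch


def async_copy_to_host(device_tensor: torch.Tensor,
                       copy_stream: torch.cuda.Stream):
    """Start a non-blocking device-to-host copy; return (cpu_tensor, event)."""
    ready_event = torch.cuda.Event()
    # Capture the main stream *before* switching, so the copy waits on work
    # already queued there (e.g. sampling) rather than on itself.
    main_stream = torch.cuda.current_stream()
    with torch.cuda.stream(copy_stream):
        copy_stream.wait_stream(main_stream)
        cpu_tensor = device_tensor.to("cpu", non_blocking=True)
        ready_event.record()
    return cpu_tensor, ready_event


if torch.cuda.is_available():
    copy_stream = torch.cuda.Stream()
    sampled = torch.randint(0, 32000, (8, 1), device="cuda")
    cpu_copy, event = async_copy_to_host(sampled, copy_stream)
    # ... the caller is free to launch the next step's kernels here ...
    event.synchronize()  # block only when the host values are finally needed
    print(cpu_copy.squeeze(1).tolist())
```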
@@ -358,6 +405,10 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
             device=self.device,
         )

+        self.use_async_scheduling = self.scheduler_config.async_scheduling
+        self.async_output_copy_stream = torch.npu.Stream() if \
+            self.use_async_scheduling else None
+
     def _use_aclgraph(self) -> bool:
         return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and self.compilation_config.level == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager

@@ -845,6 +896,76 @@ def _get_cumsum_and_arange(

         return cu_num_tokens, arange

+    def _prepare_input_ids(self, total_num_scheduled_tokens: int,
+                           cu_num_tokens: np.ndarray) -> None:
+        """Prepare the input IDs for the current batch.
+
+        Carefully handles the `prev_sampled_token_ids` which can be cached
+        from the previous engine iteration, in which case those tokens on the
+        NPU need to be copied into the corresponding slots into input_ids."""
+
+        if self.input_batch.prev_sampled_token_ids is None:
+            # Normal scheduling case
+            self.input_ids[:total_num_scheduled_tokens].copy_(
+                self.input_ids_cpu[:total_num_scheduled_tokens],
+                non_blocking=True)
+            return
+
+        # Async scheduling case, where some decode requests from the previous
+        # iteration won't have entries in input_ids_cpu and need to be copied
+        # on the NPU from prev_sampled_token_ids.
+        prev_req_id_to_index = self.input_batch.prev_req_id_to_index
+        assert prev_req_id_to_index is not None
+        flattened_indices = []
+        prev_common_req_indices = []
+        indices_match = True
+        max_flattened_index = -1
+        for req_id, cur_index in self.input_batch.req_id_to_index.items():
+            if (prev_index := prev_req_id_to_index.get(req_id)) is not None:
+                prev_common_req_indices.append(prev_index)
+                # We need to compute the flattened input_ids index of the
+                # last token in each common request.
+                flattened_index = cu_num_tokens[cur_index].item() - 1
+                flattened_indices.append(flattened_index)
+                indices_match &= (prev_index == flattened_index)
+                max_flattened_index = max(max_flattened_index, flattened_index)
+        num_commmon_tokens = len(flattened_indices)
+        if num_commmon_tokens < total_num_scheduled_tokens:
+            # If not all requests are decodes from the last iteration,
+            # We need to copy the input_ids_cpu to the NPU first.
+            self.input_ids[:total_num_scheduled_tokens].copy_(
+                self.input_ids_cpu[:total_num_scheduled_tokens],
+                non_blocking=True)
+        if num_commmon_tokens == 0:
+            # No requests in common with the previous iteration
+            # So input_ids_cpu will have all the input ids.
+            return
+        if indices_match and max_flattened_index == (num_commmon_tokens - 1):
+            # Common-case optimization: the batch is unchanged
+            # and no reordering happened.
+            # The indices are both the same permutation of 0..N-1 so
+            # we can copy directly using a single slice.
+            self.input_ids[:num_commmon_tokens].copy_(
+                self.input_batch.prev_sampled_token_ids[:num_commmon_tokens,
+                                                        0],
+                non_blocking=True)
+            return
+        # Upload the index tensors asynchronously
+        # so the scatter can be non-blocking.
+        input_ids_index_tensor = torch.tensor(flattened_indices,
+                                              dtype=torch.int64,
+                                              pin_memory=self.pin_memory).to(
+                                                  self.device,
+                                                  non_blocking=True)
+        prev_common_req_indices_tensor = torch.tensor(
+            prev_common_req_indices,
+            dtype=torch.int64,
+            pin_memory=self.pin_memory).to(self.device, non_blocking=True)
+        self.input_ids.scatter_(dim=0,
+                                index=input_ids_index_tensor,
+                                src=self.input_batch.prev_sampled_token_ids[
+                                    prev_common_req_indices_tensor, 0])
+
     def _prepare_inputs(
         self,
         scheduler_output: "SchedulerOutput",
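The scatter at the end of `_prepare_input_ids` writes each common request's previously sampled token, still resident on the NPU, into the flattened `input_ids` slot of its last scheduled token. A CPU-only toy illustration of the same indexing, with made-up shapes and values rather than the runner's real buffers:

```python
import torch

# Toy shapes: 3 scheduled decode tokens this step, flattened input_ids of
# length 3; requests 0 and 2 also ran in the previous step.
input_ids = torch.zeros(3, dtype=torch.int64)
prev_sampled_token_ids = torch.tensor([[101], [202], [303]])  # [prev_batch, 1]

flattened_indices = torch.tensor([0, 2])        # last-token slot of each common request
prev_common_req_indices = torch.tensor([0, 2])  # its row in the previous batch

input_ids.scatter_(dim=0,
                   index=flattened_indices,
                   src=prev_sampled_token_ids[prev_common_req_indices, 0])
print(input_ids)  # tensor([101,   0, 303]); slot 1 is filled from input_ids_cpu instead
```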
@@ -1033,6 +1154,16 @@ def _prepare_inputs(
         if self.vllm_config.model_config.use_mla:
             attn_metadata.num_input_tokens = num_input_tokens

+        # Prepare input_ids
+        token_indices = (positions_np +
+                         req_indices * self.input_batch.token_ids_cpu.shape[1])
+        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
+                           0,
+                           torch.from_numpy(token_indices),
+                           out=self.input_ids_cpu[:total_num_scheduled_tokens])
+        # Copy the tensors to the NPU.
+        self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens)
+
         # _prepare_inputs may reorder the batch, so we must gather
         # multi-modal outputs after that to ensure the correct order
         if self.is_multimodal_model:
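The `token_indices` arithmetic flattens the per-request token table of shape [num_reqs, max_model_len] row-major, so position p of request r lands at `r * max_model_len + p` before `index_select` gathers it into `input_ids_cpu`. A tiny numeric check with toy sizes:

```python
import numpy as np
import torch

max_model_len = 8
# Per-request token table, row-major: [num_reqs, max_model_len].
token_ids_cpu = torch.arange(2 * max_model_len).reshape(2, max_model_len)

# Two scheduled tokens: position 5 of request 0 and position 3 of request 1.
req_indices = np.array([0, 1])
positions_np = np.array([5, 3])

token_indices = positions_np + req_indices * token_ids_cpu.shape[1]
gathered = torch.empty(2, dtype=token_ids_cpu.dtype)
torch.index_select(token_ids_cpu.flatten(), 0,
                   torch.from_numpy(token_indices), out=gathered)
print(token_indices, gathered)  # [ 5 11] tensor([ 5, 11])
```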
@@ -1382,11 +1513,11 @@ def _select_moe_comm_method(self, num_tokens: int) -> str:
        2. If expert parallel is enabled, we need to consider the soc version and the
           number of tokens. This is based on the observation that all-gather is more
           efficient than all-to-all when running on A2.

           a. For A2, we choose from MC2 and all-gather.

           b. For A3, we choose from MC2 and all-to-all.

        In both cases, we use MC2 when the number of tokens is smaller than
        its capacity threshold.

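The docstring describes a pure dispatch rule, so it can be restated as a standalone toy function. This is only a sketch of rule 2 (expert parallel enabled); the function name and the return strings are placeholders rather than the method's actual values, and rule 1 for the non-EP case is not shown in this hunk.

```python
def select_moe_comm_method_ep(num_tokens: int, is_a2_soc: bool,
                              mc2_capacity: int) -> str:
    """Toy dispatch for rule 2 only (expert parallel enabled)."""
    if num_tokens < mc2_capacity:
        # MC2 while the token count stays under its capacity threshold.
        return "mc2"
    # Above the threshold: all-gather on A2, all-to-all on A3.
    return "allgather" if is_a2_soc else "alltoall"


assert select_moe_comm_method_ep(128, is_a2_soc=True, mc2_capacity=512) == "mc2"
assert select_moe_comm_method_ep(4096, is_a2_soc=False, mc2_capacity=512) == "alltoall"
```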
@@ -1424,7 +1555,7 @@ def execute_model(
         self,
         scheduler_output: "SchedulerOutput",
         intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> Union[ModelRunnerOutput, torch.Tensor]:
+    ) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]:
         with ProfileExecuteDuration().capture_async("prepare input"):
             self._update_states(scheduler_output)
             if not scheduler_output.total_num_scheduled_tokens:
@@ -1580,6 +1711,12 @@ def execute_model(
                         generator.set_offset(generator.get_offset() - 4)
                     discard_sampled_tokens_req_indices.append(i)

+            # Copy some objects so they don't get modified after returning.
+            # This is important when using async scheduling.
+            req_ids_output_copy = self.input_batch.req_ids.copy()
+            req_id_to_index_output_copy = \
+                self.input_batch.req_id_to_index.copy()
+
             # NOTE: NPU -> CPU Sync happens here.
             # Move as many CPU operations as possible before this sync point.
             logprobs_tensors = sampler_output.logprobs_tensors
@@ -1592,27 +1729,52 @@ def execute_model(
                 scheduler_output,
             )

-            # Get the valid generated tokens.
+            num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
             sampled_token_ids = sampler_output.sampled_token_ids
-            max_gen_len = sampled_token_ids.shape[-1]
-            if max_gen_len == 1:
-                # No spec decode tokens.
-                valid_sampled_token_ids = sampled_token_ids.tolist()
+            if not self.use_async_scheduling:
+                # Get the valid generated tokens.
+                max_gen_len = sampled_token_ids.shape[-1]
+                if max_gen_len == 1:
+                    # No spec decode tokens.
+                    valid_sampled_token_ids = sampled_token_ids.tolist()
+                else:
+                    # Includes spec decode tokens.
+                    valid_sampled_token_ids = self.rejection_sampler.parse_output(
+                        sampled_token_ids,
+                        self.input_batch.vocab_size,
+                    )
+                # Mask out the sampled tokens that should not be sampled.
+                for i in discard_sampled_tokens_req_indices:
+                    valid_sampled_token_ids[i].clear()
             else:
-                # Includes spec decode tokens.
-                valid_sampled_token_ids = self.rejection_sampler.parse_output(
-                    sampled_token_ids,
-                    self.input_batch.vocab_size,
-                )
-
-            for i in discard_sampled_tokens_req_indices:
-                valid_sampled_token_ids[i].clear()
+                valid_sampled_token_ids = []
+                invalid_req_indices = list(discard_sampled_tokens_req_indices)
+                invalid_req_indices_set = set(invalid_req_indices)
+                assert sampled_token_ids.shape[-1] == 1
+
+                # Cache the sampled tokens on the NPU and avoid CPU sync.
+                # These will be copied into input_ids in the next step
+                # when preparing inputs.
+                self.input_batch.prev_sampled_token_ids = \
+                    sampled_token_ids
+                self.input_batch.prev_sampled_token_ids_invalid_indices = \
+                    invalid_req_indices_set
+                self.input_batch.prev_req_id_to_index = {
+                    req_id: i
+                    for i, req_id in enumerate(self.input_batch.req_ids)
+                    if i not in invalid_req_indices_set
+                }

             # Cache the sampled tokens in the model runner, so that the scheduler
             # doesn't need to send them back.
             # NOTE(woosuk): As an exception, when using PP, the scheduler sends
             # the sampled tokens back, because there's no direct communication
             # between the first-stage worker and the last-stage worker.
-            for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
+            for req_idx in range(num_sampled_tokens):
+                if self.use_async_scheduling:
+                    sampled_ids = [-1] * 1 if \
+                        req_idx not in invalid_req_indices_set else None
+                else:
+                    sampled_ids = valid_sampled_token_ids[req_idx]
                 if not sampled_ids:
                     continue
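In the async branch above, the runner skips the NPU to CPU copy entirely: it reports per-request placeholder ids to the scheduler and remembers, via `prev_req_id_to_index` and the invalid-index set, which device rows hold the real tokens for the next step. A dependency-free toy illustration of that bookkeeping, with made-up request ids:

```python
# Plain-Python illustration of the async branch's bookkeeping (no runner state).
req_ids = ["req-a", "req-b", "req-c"]
discard_sampled_tokens_req_indices = [1]  # e.g. a request still in partial prefill
invalid_req_indices_set = set(discard_sampled_tokens_req_indices)

prev_req_id_to_index = {
    req_id: i
    for i, req_id in enumerate(req_ids)
    if i not in invalid_req_indices_set
}
# Placeholder ids stand in for tokens that are still on the NPU; discarded
# rows report None so downstream code skips them.
placeholders = [[-1] if i not in invalid_req_indices_set else None
                for i in range(len(req_ids))]

print(prev_req_id_to_index)  # {'req-a': 0, 'req-c': 2}
print(placeholders)          # [[-1], None, [-1]]
```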
@@ -1650,8 +1812,8 @@ def execute_model(
                 extra_args = ({"kv_connector_output": kv_connector_output})

             model_runner_output = ModelRunnerOutput(
-                req_ids=self.input_batch.req_ids,
-                req_id_to_index=self.input_batch.req_id_to_index,
+                req_ids=req_ids_output_copy,
+                req_id_to_index=req_id_to_index_output_copy,
                 sampled_token_ids=valid_sampled_token_ids,
                 logprobs=logprobs_lists,
                 prompt_logprobs_dict=prompt_logprobs_dict,
@@ -1669,7 +1831,15 @@ def execute_model(
                 logger.info("Profile execute duration [%s]:%s", captured_name,
                             " ".join(dr_str))

-        return model_runner_output
+        if not self.use_async_scheduling:
+            return model_runner_output
+
+        return AsyncNPUModelRunnerOutput(
+            model_runner_output=model_runner_output,
+            sampled_token_ids=sampled_token_ids,
+            invalid_req_indices=invalid_req_indices,
+            async_output_copy_stream=self.async_output_copy_stream,
+        )

     def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
         if self._draft_token_ids is None:
@@ -263,6 +263,11 @@

         self.pooling_params: dict[str, PoolingParams] = {}

+        # Cached reference to the GPU tensor of previously sampled tokens
+        self.prev_sampled_token_ids: Optional[torch.Tensor] = None
+        self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None
+        self.prev_req_id_to_index: Optional[dict[str, int]] = None
+
     @property
     def req_ids(self) -> list[str]:
         # None elements should only be present transiently
Review comment: To correctly handle asynchronous scheduling, the worker's CPU-side state must be updated with the actual token IDs from the previous step. This should happen at the beginning of the current step, before preparing inputs. Without this, features like repetition penalty will use stale or incorrect token history. Please add state update logic at the start of `execute_model` to synchronize `prev_sampled_token_ids` and update the CPU-side token history. Here is a code snippet to illustrate the required logic:
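The reviewer's snippet itself is not reproduced above. What follows is a rough, hypothetical sketch of the kind of update being requested, based only on the fields this PR introduces (`prev_sampled_token_ids`, `prev_sampled_token_ids_invalid_indices`, `prev_req_id_to_index`); buffer names such as `token_ids_cpu` and `num_tokens` are assumptions borrowed from the upstream GPU runner and may differ here.

```python
# Hypothetical sketch only -- not the reviewer's original snippet.
# Assumes InputBatch exposes token_ids_cpu (np.ndarray [num_reqs, max_model_len])
# and num_tokens (np.ndarray [num_reqs]) as in the upstream GPU model runner.
def _sync_prev_sampled_into_cpu_state(self) -> None:
    prev_sampled = self.input_batch.prev_sampled_token_ids
    if prev_sampled is None:
        return
    # One NPU -> CPU transfer per step, done up front so that CPU-side
    # features (repetition penalty, stop checks, ...) see the real token ids
    # from the previous step instead of the -1 placeholders.
    prev_sampled_cpu = prev_sampled.cpu()
    for req_id, prev_index in self.input_batch.prev_req_id_to_index.items():
        cur_index = self.input_batch.req_id_to_index.get(req_id)
        if cur_index is None:
            continue  # request finished or was evicted between steps
        token_id = int(prev_sampled_cpu[prev_index, 0])
        # Overwrite the placeholder written at the end of the previous step.
        last_pos = self.input_batch.num_tokens[cur_index] - 1
        self.input_batch.token_ids_cpu[cur_index, last_pos] = token_id
```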