diff --git a/.github/workflows/test-spyre.yml b/.github/workflows/test-spyre.yml
index 065bd48..59fc319 100644
--- a/.github/workflows/test-spyre.yml
+++ b/.github/workflows/test-spyre.yml
@@ -13,16 +13,16 @@ jobs:
       run: |
         docker run -i --rm --entrypoint /bin/bash vllm-spyre -c '''
         pip install pytest sentence-transformers && \
-        python3.12 -c "from transformers import pipeline; pipeline(\"text-generation\", model=\"JackFram/llama-160m\")" && \
+        python -c "from transformers import pipeline; pipeline(\"text-generation\", model=\"JackFram/llama-160m\")" && \
         export VARIANT=$(ls /root/.cache/huggingface/hub/models--JackFram--llama-160m/snapshots/) && \
         mkdir -p /models && \
         ln -s /root/.cache/huggingface/hub/models--JackFram--llama-160m/snapshots/${VARIANT} /models/llama-194m && \
-        python3.12 -c "from sentence_transformers import SentenceTransformer; SentenceTransformer(\"sentence-transformers/all-roberta-large-v1\")" && \
+        python -c "from sentence_transformers import SentenceTransformer; SentenceTransformer(\"sentence-transformers/all-roberta-large-v1\")" && \
         export VARIANT=$(ls /root/.cache/huggingface/hub/models--sentence-transformers--all-roberta-large-v1/snapshots/) && \
         ln -s /root/.cache/huggingface/hub/models--sentence-transformers--all-roberta-large-v1/snapshots/${VARIANT} /models/all-roberta-large-v1 && \
         export MASTER_PORT=12355 && \
         export MASTER_ADDR=localhost && \
         export DISTRIBUTED_STRATEGY_IGNORE_MODULES=WordEmbedding && \
         cd vllm-spyre && \
-        python3.12 -m pytest tests -v
+        python -m pytest tests -v
         '''
\ No newline at end of file
diff --git a/Dockerfile.spyre b/Dockerfile.spyre
index 85b73ea..7d44b00 100644
--- a/Dockerfile.spyre
+++ b/Dockerfile.spyre
@@ -12,17 +12,14 @@ WORKDIR /workspace
 RUN microdnf update -y && microdnf install -y \
     python${PYTHON_VERSION}-devel python${PYTHON_VERSION}-pip python${PYTHON_VERSION}-wheel git vim gcc g++ kmod\
     && microdnf clean all
+RUN ln -sf $(which python${PYTHON_VERSION}) /usr/bin/python && \
+    ln -sf $(which pip${PYTHON_VERSION}) /usr/bin/pip

 # Download and install vllm ###########################################################
-ENV VLLM_TARGET_DEVICE=empty
-RUN git clone -b sop-plugin-refactoring --single-branch https://github.com/IBM/vllm.git \
-    && cd vllm \
-    && python3.12 -m pip install --upgrade pip \
-    && pip install -r requirements-build.txt \
-    && pip install --no-build-isolation -v -e . \
-    && mkdir /workspace/vllm-spyre
+RUN pip install vllm==0.7.3

 # Install vllm Spyre plugin ##################################################################
+RUN mkdir /workspace/vllm-spyre
 COPY . /workspace/vllm-spyre
 RUN cd /workspace/vllm-spyre && pip install --no-build-isolation -v -e .
 ENV VLLM_PLUGINS=spyre
diff --git a/README.md b/README.md
index 23d8f4e..b1530b0 100644
--- a/README.md
+++ b/README.md
@@ -24,11 +24,7 @@ docker run -it --rm vllm-spyre bash
 ```
 # Install vllm
-git clone https://github.com/vllm-project/vllm.git
-cd vllm
-pip install -r requirements-build.txt
-export VLLM_TARGET_DEVICE=empty
-pip install --no-build-isolation -v -e .
+pip install vllm==0.7.3

 # Install vllm-spyre
 cd ..
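Reviewer note (not part of the patch): with the source build of the IBM vLLM fork replaced by a pinned release, a quick way to sanity-check the image is to confirm that the pinned vLLM and the freshly installed plugin import together. This is a hypothetical smoke test, assuming only what the Dockerfile above establishes (the `vllm==0.7.3` pin, the `vllm_spyre` package, and the `python` symlink); it is not part of this PR.

```python
# Hypothetical smoke test, not part of this PR: run with `python` inside the
# built vllm-spyre image to verify the pinned install.
import importlib
import importlib.metadata

# Dockerfile.spyre pins vllm==0.7.3; fail fast if the image drifts from it.
assert importlib.metadata.version("vllm") == "0.7.3"

# The plugin is installed editable from /workspace/vllm-spyre and activated
# through the VLLM_PLUGINS=spyre environment variable.
importlib.import_module("vllm_spyre")

print("vllm", importlib.metadata.version("vllm"), "+ vllm_spyre import OK")
```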
diff --git a/vllm_spyre/core/scheduler.py b/vllm_spyre/core/scheduler.py
index 24c2cb6..4890ea0 100644
--- a/vllm_spyre/core/scheduler.py
+++ b/vllm_spyre/core/scheduler.py
@@ -1,3 +1,5 @@
+# SPDX-License-Identifier: Apache-2.0
+
 import random
 import time
 from collections import deque
@@ -6,10 +8,12 @@
 from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
 from vllm.core.interfaces import AllocStatus, BlockSpaceManager
+# SPYRE SPECIFIC CODE BLOCK START
 # yapf: disable
 from vllm.core.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT,
                                  ARTIFICIAL_PREEMPTION_PROB,
-                                 ENABLE_ARTIFICIAL_PREEMPT, PreemptionMode,
+                                 ENABLE_ARTIFICIAL_PREEMPT,
+                                 PartialPrefillMetadata, PreemptionMode,
                                  ScheduledSequenceGroup, SchedulerOutputs,
                                  SchedulerPrefillOutputs,
                                  SchedulerRunningOutputs,
@@ -19,6 +23,7 @@
                                  seq_group_metadata_builder)
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
+# SPYRE SPECIFIC CODE BLOCK END
 from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
                            SequenceGroupMetadata, SequenceGroupMetadataDelta,
                            SequenceStatus)
@@ -66,7 +71,8 @@ def __init__(
             num_gpu_blocks=num_gpu_blocks,
             num_cpu_blocks=num_cpu_blocks,
             sliding_window=self.cache_config.sliding_window,
-            enable_caching=self.cache_config.enable_prefix_caching)
+            enable_caching=self.cache_config.enable_prefix_caching,
+        )

         # Sequence groups in the WAITING state.
         # Contain new prefill or preempted requests.
@@ -127,6 +133,18 @@ def __init__(
         # for processing and deallocation by the free_finished_seq_groups()
         self._async_stopped: List[SequenceGroup] = []

+        # List with the chunk sizes to hand out to each sequence depending
+        # on how many partial prefills are running. This is slightly faster than
+        # running an integer division every time a prefill is scheduled.
+        # This splits the budget evenly among all prefills.
+        self.partial_prefill_budget_lookup_list = [0] * (
+            self.scheduler_config.max_num_partial_prefills + 1)
+        self.partial_prefill_budget_lookup_list[0] = (
+            scheduler_config.max_num_batched_tokens)
+        for i in range(1, self.scheduler_config.max_num_partial_prefills + 1):
+            self.partial_prefill_budget_lookup_list[i] = (
+                scheduler_config.max_num_batched_tokens // i)
+
     @property
     def next_cache_id(self):
         return (self.cache_id + 1) % self.num_cache_iters
@@ -206,12 +224,15 @@ def _free_seq_group_cross_attn_blocks(
             self.block_manager.free_cross(seq_group)

     def has_unfinished_seqs(self) -> bool:
-        return len(self.waiting) != 0 or len(self.running) != 0 or len(
-            self.swapped) != 0
+        return (len(self.waiting) != 0 or len(self.running) != 0
+                or len(self.swapped) != 0)

     def get_prefix_cache_hit_rate(self, device: Device) -> float:
         return self.block_manager.get_prefix_cache_hit_rate(device)

+    def reset_prefix_cache(self) -> bool:
+        return self.block_manager.reset_prefix_cache()
+
     def get_num_unfinished_seq_groups(self) -> int:
         return len(self.waiting) + len(self.running) + len(self.swapped)
@@ -226,6 +247,7 @@
         budget: SchedulingBudget,
         curr_loras: Optional[Set[int]],
         enable_chunking: bool = False,
+        partial_prefill_metadata: Optional[PartialPrefillMetadata] = None,
     ) -> SchedulerRunningOutputs:
         """Schedule sequence groups that are running.
@@ -240,12 +262,14 @@
                 chunked number of tokens are scheduled if
                 `budget.num_batched_tokens` has not enough capacity to schedule
                 all tokens.
-
+            partial_prefill_metadata: information about the partial prefills
+                that are currently running
+
         Returns:
             SchedulerRunningOutputs.
""" - ret: SchedulerRunningOutputs = \ - self._scheduler_running_outputs_cache[self.cache_id].get_object() + ret: SchedulerRunningOutputs = self._scheduler_running_outputs_cache[ + self.cache_id].get_object() ret.blocks_to_swap_out.clear() ret.blocks_to_copy.clear() ret.decode_seq_groups.clear() @@ -280,10 +304,14 @@ def _schedule_running( # 2. If a sequence is running with non-chunked prefill, then # there it's a decoding sequence, and the cached tokens info is # irrelevant. - num_uncached_new_tokens, _ = ( + num_uncached_new_tokens, _ = \ self._get_num_new_uncached_and_cached_tokens( - seq_group, SequenceStatus.RUNNING, enable_chunking, - budget)) + seq_group, + SequenceStatus.RUNNING, + enable_chunking, + budget, + partial_prefill_metadata, + ) num_running_tokens = num_uncached_new_tokens if num_running_tokens == 0: @@ -296,8 +324,8 @@ def _schedule_running( # to process the final tokens. The check below avoids this extra # decode run when the model max len is reached, in order to avoid # a memory overflow. - if self.use_async_output_proc and seq_group.seqs[0].get_len( - ) > self.scheduler_config.max_model_len: + if (self.use_async_output_proc and seq_group.seqs[0].get_len() + > self.scheduler_config.max_model_len): self._async_stopped.append(seq_group) continue @@ -356,8 +384,9 @@ def _schedule_running( self._append_slots(seq_group, blocks_to_copy, enable_chunking) is_prefill = seq_group.is_prefill() - scheduled_seq_group: ScheduledSequenceGroup = \ - self._scheduled_seq_group_cache[self.cache_id].get_object() + scheduled_seq_group: ScheduledSequenceGroup = ( + self._scheduled_seq_group_cache[ + self.cache_id].get_object()) scheduled_seq_group.seq_group = seq_group if is_prefill: scheduled_seq_group.token_chunk_size = num_running_tokens @@ -434,7 +463,8 @@ def _schedule_swapped( logger.warning( "Failing the request %s because there's not enough kv " "cache blocks to run the entire sequence.", - seq_group.request_id) + seq_group.request_id, + ) for seq in seq_group.get_seqs(): seq.status = SequenceStatus.FINISHED_IGNORED infeasible_seq_groups.append(seq_group) @@ -473,7 +503,6 @@ def _schedule_swapped( swapped_queue.popleft() self._swap_in(seq_group, blocks_to_swap_in) self._append_slots(seq_group, blocks_to_copy, enable_chunking) - is_prefill = seq_group.is_prefill() if is_prefill: prefill_seq_groups.append( ScheduledSequenceGroup( @@ -504,16 +533,17 @@ def _schedule_swapped( ) def _get_prompt_limit(self, seq_group: SequenceGroup) -> int: - if self.scheduler_config.chunked_prefill_enabled and \ - not self.scheduler_config.is_multi_step: + if (self.scheduler_config.chunked_prefill_enabled + and not self.scheduler_config.is_multi_step): prompt_limit = self.scheduler_config.max_model_len else: - prompt_limit = min(self.scheduler_config.max_model_len, - self.scheduler_config.max_num_batched_tokens) + prompt_limit = min( + self.scheduler_config.max_model_len, + self.scheduler_config.max_num_batched_tokens, + ) # Model is fine tuned with long context. Return the fine tuned max_len. - if (seq_group.lora_request - and seq_group.lora_request.long_lora_max_len): + if seq_group.lora_request and seq_group.lora_request.long_lora_max_len: assert prompt_limit <= seq_group.lora_request.long_lora_max_len return seq_group.lora_request.long_lora_max_len else: @@ -521,7 +551,7 @@ def _get_prompt_limit(self, seq_group: SequenceGroup) -> int: def _get_priority(self, seq_group: SequenceGroup) -> Tuple[Optional[int], float]: - """ Get the priority of the sequence group. 
+ """Get the priority of the sequence group. Highest preference to user-defined priority, followed by arrival time. Args: seq_group: The sequence group input. @@ -554,14 +584,14 @@ def _schedule_priority_preemption( if waiting_queue: seq_group = waiting_queue.popleft() num_new_seqs = seq_group.get_max_num_running_seqs() - num_new_tokens_uncached, _ = ( + num_new_tokens_uncached, _ = \ self._get_num_new_uncached_and_cached_tokens( - seq_group, SequenceStatus.WAITING, False, budget)) + seq_group, SequenceStatus.WAITING, False, budget) - #Only preempt if priority inversion exists + # Only preempt if priority inversion exists while running_queue and self._get_priority( running_queue[-1]) > self._get_priority(seq_group): - #Only preempt if waiting sequence cannot be allocated + # Only preempt if waiting sequence cannot be allocated can_allocate = self.block_manager.can_allocate(seq_group) if (num_new_tokens_uncached > 0 and can_allocate == AllocStatus.OK @@ -571,7 +601,7 @@ def _schedule_priority_preemption( )): break - #Adjust budget to remove the victim sequence group + # Adjust budget to remove the victim sequence group vseq_group = running_queue.pop() num_running_tokens_uncached, _ = ( self._get_num_new_uncached_and_cached_tokens( @@ -582,11 +612,11 @@ def _schedule_priority_preemption( budget.subtract_num_seqs(vseq_group.request_id, num_running_seqs) - #Preempt out the victim sequence group + # Preempt out the victim sequence group self._preempt(vseq_group, blocks_to_swap_out) waiting_queue.appendleft(vseq_group) force_preemption_count += 1 - #Put the sequence back into the waiting queue + # Put the sequence back into the waiting queue waiting_queue.appendleft(seq_group) waiting_queue = deque(sorted(waiting_queue, key=self._get_priority)) @@ -600,6 +630,7 @@ def _schedule_prefills( budget: SchedulingBudget, curr_loras: Optional[Set[int]], enable_chunking: bool = False, + partial_prefill_metadata: Optional[PartialPrefillMetadata] = None, ) -> SchedulerPrefillOutputs: """Schedule sequence groups that are in prefill stage. @@ -620,15 +651,27 @@ def _schedule_prefills( chunked number of tokens are scheduled if `budget.num_batched_tokens` has not enough capacity to schedule all tokens. + partial_prefill_metadata: information about the partial prefills + that are currently running Returns: SchedulerPrefillOutputs. 
""" + if budget.remaining_token_budget() == 0: + # Do nothing: Can't add any more prefill anyway + return SchedulerPrefillOutputs( + seq_groups=[], + ignored_seq_groups=[], + num_lookahead_slots=self._get_num_lookahead_slots( + is_prefill=True, enable_chunking=enable_chunking), + ) ignored_seq_groups: List[SequenceGroup] = [] seq_groups: List[ScheduledSequenceGroup] = [] + # SPYRE SPECIFIC CODE BLOCK START spyre_warmup_shapes = current_platform.get_warmup_shapes() applicable_spyre_warmup_shapes = list(spyre_warmup_shapes) + # SPYRE SPECIFIC CODE BLOCK END waiting_queue = self.waiting @@ -640,10 +683,19 @@ def _schedule_prefills( assert len(waiting_seqs) == 1, ( "Waiting sequence group should have only one prompt " "sequence.") + if (partial_prefill_metadata is not None + and not partial_prefill_metadata.can_schedule(seq_group)): + leftover_waiting_sequences.appendleft(seq_group) + waiting_queue.popleft() + continue num_new_tokens_uncached, num_new_tokens_cached = ( self._get_num_new_uncached_and_cached_tokens( - seq_group, SequenceStatus.WAITING, enable_chunking, - budget)) + seq_group, + SequenceStatus.WAITING, + enable_chunking, + budget, + partial_prefill_metadata=partial_prefill_metadata, + )) num_new_tokens = num_new_tokens_uncached + num_new_tokens_cached if not enable_chunking: @@ -654,7 +706,10 @@ def _schedule_prefills( if num_new_tokens > prompt_limit: logger.warning( "Input prompt (%d tokens) is too long" - " and exceeds limit of %d", num_new_tokens, prompt_limit) + " and exceeds limit of %d", + num_new_tokens, + prompt_limit, + ) for seq in waiting_seqs: seq.status = SequenceStatus.FINISHED_IGNORED ignored_seq_groups.append(seq_group) @@ -675,7 +730,9 @@ def _schedule_prefills( logger.warning( "Input prompt (%d tokens) + lookahead slots (%d) is " "too long and exceeds the capacity of block_manager", - num_new_tokens, num_lookahead_slots) + num_new_tokens, + num_lookahead_slots, + ) for seq in waiting_seqs: seq.status = SequenceStatus.FINISHED_IGNORED ignored_seq_groups.append(seq_group) @@ -696,8 +753,8 @@ def _schedule_prefills( waiting_queue.popleft() continue - if (budget.num_batched_tokens >= - self.scheduler_config.max_num_batched_tokens): + if (budget.num_batched_tokens + >= self.scheduler_config.max_num_batched_tokens): # We've reached the budget limit - since there might be # continuous prefills in the running queue, we should break # to avoid scheduling any new prefills. @@ -710,6 +767,7 @@ def _schedule_prefills( ): break + # SPYRE SPECIFIC CODE BLOCK START # check if current request can be scheduled based on the applicable # spyre warmup shapes max_tokens = 0 @@ -754,6 +812,7 @@ def _schedule_prefills( continue else: applicable_spyre_warmup_shapes = updated_spyre_warmup_shapes + # SPYRE SPECIFIC CODE BLOCK END # Can schedule this request. if curr_loras is not None and lora_int_id > 0: @@ -761,6 +820,10 @@ def _schedule_prefills( waiting_queue.popleft() self._allocate_and_set_running(seq_group) + if partial_prefill_metadata is not None: + partial_prefill_metadata.maybe_increment_partial_prefills( + seq_group) + if enable_chunking and self.scheduler_config.is_multi_step: blocks_to_copy: List[Tuple[int, int]] = [] # init_multi_step_from_lookahead_slots happens in append_slots @@ -776,7 +839,8 @@ def _schedule_prefills( num_scheduler_steps=self.scheduler_config. 
                     num_scheduler_steps,
                     is_multi_step=self.scheduler_config.is_multi_step,
-                    enable_chunking=enable_chunking)
+                    enable_chunking=enable_chunking,
+                )

             seq_groups.append(
                 ScheduledSequenceGroup(seq_group=seq_group,
@@ -788,6 +852,7 @@
             )
             budget.add_num_seqs(seq_group.request_id, num_new_seqs)

+            # SPYRE SPECIFIC CODE BLOCK START
             # Check if number of scheduled requests has reached the maximum
             # batch size of the applicable warmup shapes
             if len(seq_groups) >= max([
@@ -795,6 +860,7 @@
                     shape['batch_size']
                     for shape in applicable_spyre_warmup_shapes
             ]):
                 break
+            # SPYRE SPECIFIC CODE BLOCK END

         # Queue requests that couldn't be scheduled.
         waiting_queue.extendleft(leftover_waiting_sequences)
@@ -805,11 +871,12 @@
             seq_groups=seq_groups,
             ignored_seq_groups=ignored_seq_groups,
             num_lookahead_slots=self._get_num_lookahead_slots(
-                is_prefill=True, enable_chunking=enable_chunking))
+                is_prefill=True, enable_chunking=enable_chunking),
+        )

     def _schedule_default(self) -> SchedulerOutputs:
         """Schedule queued requests.
-
+
         The current policy is designed to optimize the throughput. First,
         it batches as many prefill requests as possible. And it schedules
         decodes. If there's a pressure on GPU memory, decode requests can
@@ -825,20 +892,22 @@
         for seq_group in self.running:
             budget.add_num_seqs(seq_group.request_id,
                                 seq_group.get_max_num_running_seqs())
-        curr_loras = set(
+        curr_loras = (set(
             seq_group.lora_int_id for seq_group in self.running
-            if seq_group.lora_int_id > 0) if self.lora_enabled else None
+            if seq_group.lora_int_id > 0) if self.lora_enabled else None)

         prefills = SchedulerPrefillOutputs.create_empty()
         running_scheduled = SchedulerRunningOutputs.create_empty()
         swapped_in = SchedulerSwappedInOutputs.create_empty()

+        # SPYRE SPECIFIC CODE BLOCK START
         # Schedule new prefills only when no requests have been swapped
-        # and all previous decodes have completed.
+        # and all previous decodes have completed (Spyre constraint).
         if not self.swapped and not self.running:
             prefills = self._schedule_prefills(budget,
                                                curr_loras,
                                                enable_chunking=False)
+        # SPYRE SPECIFIC CODE BLOCK END

         if len(prefills.seq_groups
                ) == 0 and self.scheduler_config.policy == "priority":
@@ -854,12 +923,13 @@

         # If any sequence group is preempted, do not swap in any sequence
         # group. because it means there's no slot for new running requests.
-        if len(running_scheduled.preempted) + len(
-                running_scheduled.swapped_out) == 0:
-            swapped_in = self._schedule_swapped(budget, curr_loras)
+        if (len(running_scheduled.preempted) +
+                len(running_scheduled.swapped_out) == 0):
+            swapped_in = \
+                self._schedule_swapped(budget, curr_loras)

-        assert (budget.num_batched_tokens <=
-                self.scheduler_config.max_num_batched_tokens)
+        assert (budget.num_batched_tokens
+                <= self.scheduler_config.max_num_batched_tokens)
         assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs

         # Update waiting requests.
@@ -876,8 +946,8 @@

         # Update swapped requests.
         self.swapped.extend(running_scheduled.swapped_out)
-        preempted = (len(running_scheduled.preempted) +
-                     len(running_scheduled.swapped_out))
+        preempted = len(running_scheduled.preempted) + len(
+            running_scheduled.swapped_out)

         # There should be no prefill from running queue because this policy
         # doesn't allow chunked prefills.
@@ -915,7 +985,7 @@

     def _schedule_chunked_prefill(self) -> SchedulerOutputs:
         """Schedule queued requests.
-
+
         Chunked prefill allows to chunk prefill requests, batch them together
         with decode requests. This policy 1. schedule as many decoding requests
         as possible. 2. schedule chunked prefill requests that are not
@@ -936,10 +1006,20 @@
         prefills = SchedulerPrefillOutputs.create_empty()
         swapped_in = SchedulerSwappedInOutputs.create_empty()

+        # Create partial prefill metadata
+        partial_prefill_metadata = PartialPrefillMetadata.from_queues(
+            running=self.running,
+            waiting=self.waiting,
+            scheduler_config=self.scheduler_config,
+        )
+
         # Decoding should be always scheduled first by fcfs.
-        running_scheduled = self._schedule_running(budget,
-                                                   curr_loras,
-                                                   enable_chunking=True)
+        running_scheduled = self._schedule_running(
+            budget,
+            curr_loras,
+            enable_chunking=True,
+            partial_prefill_metadata=partial_prefill_metadata,
+        )

         # Schedule swapped out requests.
         # If preemption happens, it means we don't have space for swap-in.
@@ -947,12 +1027,15 @@
                 running_scheduled.swapped_out) == 0:
             swapped_in = self._schedule_swapped(budget, curr_loras)

-        prefills = self._schedule_prefills(budget,
-                                           curr_loras,
-                                           enable_chunking=True)
+        prefills = self._schedule_prefills(
+            budget,
+            curr_loras,
+            enable_chunking=True,
+            partial_prefill_metadata=partial_prefill_metadata,
+        )

-        assert (budget.num_batched_tokens <=
-                self.scheduler_config.max_num_batched_tokens)
+        assert (budget.num_batched_tokens
+                <= self.scheduler_config.max_num_batched_tokens)
         assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs

         # Update waiting requests.
@@ -968,8 +1051,15 @@
             [s.seq_group for s in swapped_in.prefill_seq_groups])
         self.running.extend(
             [s.seq_group for s in running_scheduled.decode_seq_groups])
+        # Because multiple prefills may be running concurrently, we need to
+        # make sure that prefills which are scheduled to finish are listed
+        # before those that won't. This is so that on the next scheduling
+        # iteration when they have transitioned to the decode stage, they are
+        # properly prioritized over sequences that are still in the prefill
+        # stage.
         self.running.extend(
-            [s.seq_group for s in running_scheduled.prefill_seq_groups])
+            self._order_finishing_prefills_first(
+                running_scheduled.prefill_seq_groups))
         self.running.extend([s.seq_group for s in prefills.seq_groups])

         # Update swapped requests.
@@ -986,7 +1076,7 @@
         # If all prompts, then we set num_lookahead_slots to 0
         # this allows us to go through the `no_spec` path in
         # `spec_decode_worker.py`
-        all_prefills = (len(scheduled_seq_groups) == num_prefill_groups)
+        all_prefills = len(scheduled_seq_groups) == num_prefill_groups
         num_lookahead_slots = (0 if
                                (all_prefills
                                 and not self.scheduler_config.is_multi_step)
@@ -1008,6 +1098,21 @@
                 len(running_scheduled.swapped_out)),
         )

+    def _order_finishing_prefills_first(
+        self, scheduled_prefill_seqs: List[ScheduledSequenceGroup]
+    ) -> List[SequenceGroup]:
+        """Returns a list of prefilling SequenceGroups where sequences that are
+        scheduled to finish prefilling are listed first"""
+        finishing = [
+            s.seq_group for s in scheduled_prefill_seqs
+            if s.seq_group.get_num_uncomputed_tokens() == s.token_chunk_size
+        ]
+        not_finishing = [
+            s.seq_group for s in scheduled_prefill_seqs
+            if s.seq_group.get_num_uncomputed_tokens() != s.token_chunk_size
+        ]
+        return finishing + not_finishing
+
     def _schedule(self) -> SchedulerOutputs:
         """Schedule queued requests."""
         if self.scheduler_config.chunked_prefill_enabled:
@@ -1036,9 +1141,8 @@ def _can_append_slots(self, seq_group: SequenceGroup,
             # chunked-prefill are enabled together.
             assert self.scheduler_config.is_multi_step and enable_chunking

-            # heuristic below doesn't make sense when using very large
-            # blocks
-            return True
+            return self.block_manager.can_append_slots(
+                seq_group=seq_group, num_lookahead_slots=num_lookahead_slots)

     def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool:
         # async_output_proc is allowed only when we have a single sequence
@@ -1121,8 +1225,8 @@
             # NOTE: We use get_len instead of get_prompt_len because when
             # a sequence is preempted, prefill includes previous generated
             # output tokens.
-            if (token_chunk_size + num_computed_tokens <
-                    seqs[0].data.get_len()):
+            if (token_chunk_size + num_computed_tokens
+                    < seqs[0].data.get_len()):
                 do_sample = False

             # It assumes the scheduled_seq_groups is ordered by
@@ -1147,10 +1251,12 @@
                     # between engine and worker.
                     # the subsequent comms can still use delta, but
                     # `multi_modal_data` will be None.
-                    multi_modal_data=seq_group.multi_modal_data
-                    if scheduler_outputs.num_prefill_groups > 0 else None,
-                    multi_modal_placeholders=seq_group.multi_modal_placeholders
-                    if scheduler_outputs.num_prefill_groups > 0 else None,
+                    multi_modal_data=(seq_group.multi_modal_data
+                                      if scheduler_outputs.num_prefill_groups
+                                      > 0 else None),
+                    multi_modal_placeholders=(
+                        seq_group.multi_modal_placeholders
+                        if scheduler_outputs.num_prefill_groups > 0 else None),
                     mm_processor_kwargs=seq_group.mm_processor_kwargs,
                     prompt_adapter_request=seq_group.prompt_adapter_request,
                 )
@@ -1256,10 +1362,12 @@ def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None:
         for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
             seq.status = SequenceStatus.RUNNING

-    def _append_slots(self,
-                      seq_group: SequenceGroup,
-                      blocks_to_copy: List[Tuple[int, int]],
-                      enable_chunking: bool = False) -> None:
+    def _append_slots(
+        self,
+        seq_group: SequenceGroup,
+        blocks_to_copy: List[Tuple[int, int]],
+        enable_chunking: bool = False,
+    ) -> None:
         """Appends new slots to the sequences in the given sequence group.

         Args:
@@ -1280,7 +1388,8 @@
             num_lookahead_slots,
             num_scheduler_steps=self.scheduler_config.num_scheduler_steps,
             is_multi_step=self.scheduler_config.is_multi_step,
-            enable_chunking=enable_chunking)
+            enable_chunking=enable_chunking,
+        )

         seq_status: Optional[SequenceStatus] = SequenceStatus.RUNNING
         if self.scheduler_config.is_multi_step and enable_chunking:
@@ -1323,8 +1432,11 @@
                 "not enough KV cache space. This can affect the end-to-end "
                 "performance. Increase gpu_memory_utilization or "
                 "tensor_parallel_size to provide more KV cache memory. "
-                "total_num_cumulative_preemption=%d", seq_group.request_id,
-                preemption_mode, self.num_cumulative_preemption + 1)
+                "total_num_cumulative_preemption=%d",
+                seq_group.request_id,
+                preemption_mode,
+                self.num_cumulative_preemption + 1,
+            )
         self.num_cumulative_preemption += 1

         if preemption_mode == PreemptionMode.RECOMPUTE:
@@ -1388,10 +1500,9 @@ def _passed_delay(self, now: float) -> bool:
         if self.scheduler_config.delay_factor > 0 and self.waiting:
             earliest_arrival_time = min(
                 [e.metrics.arrival_time for e in self.waiting])
-            passed_delay = (
-                (now - earliest_arrival_time) >
-                (self.scheduler_config.delay_factor * self.last_prompt_latency)
-                or not self.running)
+            passed_delay = ((now - earliest_arrival_time)
+                            > (self.scheduler_config.delay_factor *
+                               self.last_prompt_latency) or not self.running)
         else:
             passed_delay = True
         return passed_delay
@@ -1431,6 +1542,7 @@
         status: SequenceStatus,
         enable_chunking: bool,
         budget: SchedulingBudget,
+        partial_prefill_metadata: Optional[PartialPrefillMetadata] = None,
     ) -> Tuple[int, int]:
         """
         Returns the number of new uncached and cached tokens to schedule for a
@@ -1454,6 +1566,8 @@
                 to schedule.
             enable_chunking: Whether to chunk the number of tokens to compute.
             budget: The budget to chunk the number of tokens to compute.
+            partial_prefill_metadata: information about the partial prefills
+                that are currently running

         Returns:
@@ -1531,6 +1645,8 @@
                 budget,
                 self._get_prompt_limit(seq_group),
                 num_uncached_new_tokens,
+                self.partial_prefill_budget_lookup_list,
+                partial_prefill_metadata,
             )

         return num_uncached_new_tokens, num_cached_new_tokens
@@ -1542,6 +1658,8 @@
         budget: SchedulingBudget,
         prompt_limit: int,
         num_new_tokens: int,
+        partial_prefill_budget_lookup_list: List[int],
+        partial_prefill_metadata: Optional[PartialPrefillMetadata] = None,
     ) -> int:
         """
         Chunks the number of new tokens to schedule based on the budget when
@@ -1574,29 +1692,31 @@
                 # the sequence.
                 return num_new_tokens

-            return (0 if num_new_tokens > remaining_token_budget else
-                    num_new_tokens)
+            return 0 if num_new_tokens > \
+                remaining_token_budget else num_new_tokens

-        if cache_config.enable_prefix_caching:
-            # Adjust the remaining token budget to be divisible by the block
-            # size when prefix caching is enabled.
+        # Get the number of tokens to allocate to this prefill slot
+        prefill_slot_budget = (
+            remaining_token_budget if partial_prefill_metadata is None else
+            partial_prefill_budget_lookup_list[
+                partial_prefill_metadata.schedulable_prefills])

-            # When prefix caching is enabled, we always allocate
-            # the number of new tokens that is dividable by the block
-            # size to avoid partial block matching.
+        if cache_config.enable_prefix_caching:
+            # When prefix caching is enabled and we're partially prefilling
+            # a sequence, we always allocate a number of new tokens that is
+            # divisible by the block size to avoid partial block matching.
             block_size = cache_config.block_size
-            remainder = budget.token_budget % block_size
-            if remainder != 0:
-                raise ValueError("When enabling chunked prefill and "
-                                 "prefix caching, max_num_batched_tokens "
-                                 "(chunk size) must be dividable by "
-                                 "block size, but got chunk_size "
-                                 f"({budget.token_budget}) % block_size "
-                                 f"({block_size}) = {remainder}")
-            # Round down to block size.
-            remaining_token_budget = (remaining_token_budget // block_size *
-                                      block_size)
-
-        num_new_tokens = min(num_new_tokens, remaining_token_budget)
+            # Don't exceed either the total budget or slot budget.
+            # Take min of those and get the next lowest multiple of the
+            # block size:
+            remaining_token_budget = (
+                min(remaining_token_budget, prefill_slot_budget) //
+                block_size) * block_size
+            # NB: In the case where num_new_tokens < budget, we are
+            # finishing prefill for this sequence, so we do not need to
+            # allocate a full block.
+
+        num_new_tokens = min(num_new_tokens, remaining_token_budget,
+                             prefill_slot_budget)

         return num_new_tokens
diff --git a/vllm_spyre/worker/spyre_model_runner.py b/vllm_spyre/worker/spyre_model_runner.py
index a071dd1..5e267ba 100644
--- a/vllm_spyre/worker/spyre_model_runner.py
+++ b/vllm_spyre/worker/spyre_model_runner.py
@@ -102,6 +102,9 @@ def __init__(
         # Lazy initialization: after load_model.
         self.model: nn.Module

+    def get_model(self) -> nn.Module:
+        return self.model
+
     def load_model(self, prompt_lens: Iterable[int],
                    num_decode_tokens: Iterable[int]) -> None:
         max_pad_length = max(prompt_lens)
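Reviewer note (not part of the patch): the scheduler change above hinges on the precomputed `partial_prefill_budget_lookup_list` plus the block-size rounding in `_chunk_new_tokens_to_schedule`. The standalone sketch below replays that arithmetic so the numbers are easy to eyeball. The config constants and the `chunk` helper are illustrative stand-ins, not the scheduler's real API or signatures.

```python
# Standalone sketch of the budget-splitting arithmetic (assumed config values;
# `chunk` is a simplified stand-in for _chunk_new_tokens_to_schedule).
MAX_NUM_PARTIAL_PREFILLS = 4
MAX_NUM_BATCHED_TOKENS = 2048
BLOCK_SIZE = 16

# Index i holds the per-sequence token budget when i partial prefills run;
# precomputed so the hot path avoids an integer division per scheduled prefill.
lookup = [0] * (MAX_NUM_PARTIAL_PREFILLS + 1)
lookup[0] = MAX_NUM_BATCHED_TOKENS
for i in range(1, MAX_NUM_PARTIAL_PREFILLS + 1):
    lookup[i] = MAX_NUM_BATCHED_TOKENS // i


def chunk(num_new_tokens: int, remaining_budget: int,
          schedulable_prefills: int) -> int:
    """Cap a prefill chunk at the per-slot budget, rounded down to a
    multiple of the block size (mirrors the prefix-caching branch)."""
    slot = lookup[schedulable_prefills]
    rounded = min(remaining_budget, slot) // BLOCK_SIZE * BLOCK_SIZE
    return min(num_new_tokens, rounded, slot)


print(lookup)                # [2048, 2048, 1024, 682, 512]
print(chunk(4096, 2048, 3))  # 672: the 682-token slot rounded down to a
                             # multiple of the 16-token block size
```

With three concurrent prefills, each slot gets 2048 // 3 = 682 tokens, and prefix caching trims that to 672 so no chunk ends on a partial block; a sequence whose remaining prompt is shorter than its slot simply finishes prefill early.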