diff --git a/examples/offline_inference/spans/spans_benchmark.py b/examples/offline_inference/spans/spans_benchmark.py new file mode 100644 index 0000000000..8b8eb7a561 --- /dev/null +++ b/examples/offline_inference/spans/spans_benchmark.py @@ -0,0 +1,178 @@ +# SPDX-License-Identifier: Apache-2.0 +import os +import time +import random + +# necessary for spans to work +os.environ["VLLM_USE_V1"] = "1" +# to ensure deterministic behaviour +os.environ["TOKENIZERS_PARALLELISM"] = "False" + +# in case you need it +os.environ['VLLM_ATTENTION_BACKEND'] = "TRITON_ATTN_VLLM_V1" +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +os.environ["CUDA_VISIBLE_DEVICES"] = "0" +os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = '0' + +# standard imports +from vllm import LLM, SamplingParams +from vllm.inputs import TokensPrompt + + +# helper functions +def pad(toklist): + padtok = int(os.environ.get("VLLM_V1_SPANS_TOKEN_PAD", None)) + return toklist[:-1] + [padtok] * ((16 - len(toklist)) % 16) + toklist[-1:] + + +def avg(list_of_numbers): + return sum(list_of_numbers) / max(len(list_of_numbers), 1) + + +def wrap(prompt): + if isinstance(prompt[0], list): + return [TokensPrompt(prompt_token_ids=p) for p in prompt] + return TokensPrompt(prompt_token_ids=prompt) + +def initialize_vllm(model, + temp=0.6, + logprobs=None, + max_toks=131072, + max_generated_toks=1): + # boot up vLLM + samp_params_preload = SamplingParams(temperature=temp, max_tokens=1) + samp_params_generate = SamplingParams(temperature=temp, + max_tokens=max_generated_toks, + logprobs=logprobs) + llm = LLM( + model=model, + gpu_memory_utilization=0.9, + enforce_eager=True, # <- so it boots faster + block_size=16, + max_model_len=max_toks, + max_num_seqs=4, + ) + tok = llm.get_tokenizer() + tok_fun = lambda x: tok.convert_tokens_to_ids(tok.tokenize(x)) + return samp_params_preload, samp_params_generate, tok_fun, llm + + +def main(): + model_names = [ + "ldsjmdy/Tulu3-Block-FT", # <- finetuned to handle block-attention + "ldsjmdy/Tulu3-RAG", # <- baseline + ] + model_name = model_names[0] + + # tokens that need to be set to perform block-attention + PAD_TOK = 27 # <- "<" + SPAN_TOK = 10 # <- "+" + SPAN_RECOMP_TOK = 31 # <- "@" + + # vLLM-specific env vars + + # enables block attention + # -> when this line is not commented, we expect a speedup + # in the execution of the last two .generate calls + os.environ['VLLM_V1_SPANS_ENABLED'] = 'True' + + # the token that tells vLLM "this is the beginning of a span" + os.environ['VLLM_V1_SPANS_TOKEN_PLUS'] = str(SPAN_TOK) + + # token that tells vLLM: + # "from here on, recompute KV vectors if any previous tokens differ" + os.environ['VLLM_V1_SPANS_TOKEN_CROSS'] = str(SPAN_RECOMP_TOK) + + # will print every step of the span process if set to true + # os.environ['VLLM_V1_SPANS_DEBUG'] = 'True' + + # will disable the adjustment of positional encodings when a KV cache + # block is loaded to a different position than it was stored + # -> when this line is not commented, + # spans overlap in their positional encodings + os.environ['VLLM_V1_SPANS_DISABLE_REPOSITION'] = 'True' + + # general env vars + + # our helper function uses this token to pad spans + os.environ['VLLM_V1_SPANS_TOKEN_PAD'] = str(PAD_TOK) + + # now we instantiate the model + samp_params_preload, samp_params_generate, tok, llm = initialize_vllm( + model_name, max_generated_toks=1) + # model_name, max_generated_toks=1, max_toks=2048) + + # components of the prompt template + prefix = pad( + [SPAN_RECOMP_TOK] + tok("<|system|>\nYou are an intelligent AI 
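# A minimal standalone sketch of the padding arithmetic used by the pad() helper
# above, assuming the 16-token block size the script configures via block_size=16.
# pad_to_block, pad_tok=27 and the sample lengths are illustrative stand-ins, not
# part of the patch: the point is that (block_size - len(toklist)) % block_size pad
# tokens are inserted before the final token so every span ends on a block boundary
# and its KV cache blocks stay reusable.
def pad_to_block(toklist: list[int], pad_tok: int = 27, block_size: int = 16) -> list[int]:
    n_pad = (block_size - len(toklist)) % block_size
    return toklist[:-1] + [pad_tok] * n_pad + toklist[-1:]

assert len(pad_to_block(list(range(10)))) == 16   # 6 pad tokens inserted
assert len(pad_to_block(list(range(16)))) == 16   # already aligned, nothing added
assert len(pad_to_block(list(range(17)))) == 32   # rounded up to the next block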
assistant. " \ + "Please answer questions based on the user's instructions. " \ + "Below are some reference documents that may help you in " \ + "answering the user's question." + )) + midfx = [SPAN_RECOMP_TOK] + tok( + "<|user|>\nPlease write a high-quality answer for the " \ + "given question using only the provided search documents " \ + "(some of which might be irrelevant).\nQuestion: " + ) + postfx = tok('''\n<|assistant|>\n''') + + print("---->", postfx) + + times = [] + for ndocs in [1, 2, 4, 8]: + for dlen in [512, 1024, 2048, 4096, 8192]: + print(f" DOCLENGTH {dlen} NUMDOCS {ndocs}") + + doc_toks = tok( + "Sequence Transduction Models and Template-Assisted Selective Epitaxy") + docs = [pad([SPAN_TOK] + + random.choices(doc_toks, k=dlen)) + for _ in range(ndocs)] + + # user query + query = midfx + tok( + "Tell me which one concerns deep learning. " \ + "Indicate your answer with a number in brackets." + ) + postfx + + for i in range(3): + print(f" ITERATION {i}") + + # preload documents + ts_pre = time.time() + llm.generate( + [wrap(d) for d in docs] + [wrap(prefix)], + sampling_params=samp_params_preload, use_tqdm=False) + te_pre = time.time() - ts_pre + + ts_gen = time.time() + + # this now will load prefix, doc_a, doc_b, + # from the KV cache regardless of the order + random.shuffle(docs) + llm.generate(wrap(prefix + \ + sum(docs, []) + \ + query), + sampling_params=samp_params_generate, use_tqdm=False) + + # this should also run faster: + random.shuffle(docs) + llm.generate(wrap(prefix + \ + sum(docs, []) + \ + query), + sampling_params=samp_params_generate, use_tqdm=False) + + te_gen = time.time() - ts_gen + + print(f"doc preload time / TTFT : {te_pre:.4f} / {te_gen:.4f} (s)") + times.append(dict( + preload_time=te_pre, + gen_time=te_gen, + it=i, + doc_len=dlen, + num_docs=ndocs, + )) + + +if __name__ == '__main__': + main() diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index 9069a364db..fc753d55c6 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -73,9 +73,17 @@ def __init__( self.enable_kv_cache_events = enable_kv_cache_events self.kv_event_queue: list[KVCacheEvent] = [] + def _closest_cache_hit( + self, cached_blocks: dict[int, KVCacheBlock], + position: int, + ) -> dict[int, KVCacheBlock]: + return min(list(cached_blocks.values()), + key=lambda x: abs(x.position - position)) + def get_cached_block( self, block_hash: BlockHash, - kv_cache_group_ids: list[int]) -> Optional[list[KVCacheBlock]]: + kv_cache_group_ids: list[int], + position: Optional[int] = None) -> Optional[list[KVCacheBlock]]: """Get the cached block by the block hash for each group in `kv_cache_group_ids`, or None if cache miss for any group. If there are duplicated blocks, we return the first block in the cache. @@ -95,7 +103,11 @@ def get_cached_block( block_hash_with_group_id) if not cached_blocks_one_group: return None - first_block = next(iter(cached_blocks_one_group.values())) + if position is not None and len(cached_blocks_one_group) > 1: + first_block = self._closest_cache_hit(cached_blocks_one_group, + position) + else: + first_block = next(iter(cached_blocks_one_group.values())) cached_blocks.append(first_block) return cached_blocks @@ -193,17 +205,19 @@ def _set_block_positions(self, new_full_blocks: list[KVCacheBlock], debug logging that prints each block's tokens, to help debug span-related workflows. 
""" + dbg = envs.VLLM_V1_SPANS_DEBUG pos = 0 + nfb_ids = {b.block_id for b in new_full_blocks} for blk in blocks: - if blk in new_full_blocks: + if blk.block_id in nfb_ids: blk.position = pos - if envs.VLLM_V1_SPANS_DEBUG: + if dbg: # this prints the tokens assigned to a new block # in the KV cache blk_tks = request.all_token_ids[pos:pos + 16] assert blk.block_hash is not None - bhash = str(abs(blk.block_hash.block_hash.hash_value) - )[:4] if blk.block_hash.block_hash else None + bhash = str(blk.block_hash + )[:4] if blk.block_hash else None print('[SPANS -> block_pool] assigning to pos', pos, 'with hash', bhash, 'block: ', blk_tks) pos += 16 diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index a6628cfc55..5c902ae1b5 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -18,9 +18,11 @@ @dataclass class BlockRepositionRequest: - block_id: int - kvc_pos: int prompt_pos: int + cached_pos: int + cached_blockid: int + prompt_blockpos: int + prompt_reqid: str @dataclass @@ -190,13 +192,48 @@ def get_computed_blocks(self, computed_blocks, num_new_computed_tokens = ( self.coordinator.find_longest_cache_hit(request.block_hashes, max_cache_hit_length)) + + repo_reqs = [] + if envs.VLLM_V1_SPANS_ENABLED: + # now we check how many of those computed blocks have incorrect or are + # after an incorrect position match + # our own positions are clear, now we need to compare that to cached + # positions + non_match_idx = -1 + non_match_found = False + for i, block in enumerate(computed_blocks[0]): + if block.is_null: # null blocks don't have meaningful position + continue + prompt_pos = self.block_size * i + cached_pos = block.position + # find first block id where pos didn't match + if prompt_pos != cached_pos and not non_match_found: + non_match_found = True + non_match_idx = i + # record from then on and after, repo requests + if non_match_found: + repo_reqs.append( + BlockRepositionRequest( + prompt_pos, + cached_pos, + block.block_id, + i, + request.request_id)) + # if any repo is needed, we need to exclude that from the + # computed blocks and num_new_computed_tokens, so that + # new blocks get allocated that we can copy kv values to + if non_match_found: + computed_blocks = (computed_blocks[0][:non_match_idx],) + num_new_computed_tokens = len(computed_blocks[0]) * self.block_size + + if envs.VLLM_V1_SPANS_DEBUG: print( "[SPANS -> kv_cache_manager] here's the blocks hashed in " \ "this request:", - [str(abs(b.hash_value))[:4] for b in request.block_hashes]) + [str(b)[-4:] for b in request.block_hashes]) kvcache_contents = [ - str(abs(b.block_hash.block_hash.hash_value))[:4] + str(b.block_hash)[-4:] if b.block_hash else None for b in self.block_pool.blocks if b._block_hash ] @@ -212,35 +249,17 @@ def get_computed_blocks(self, "[SPANS -> kv_cache_manager] here's the number of blocks " \ "that hit the cache:", [ - str(abs(b.block_hash.block_hash.hash_value))[:4] + str(b.block_hash)[-4:] if b.block_hash else None for b in computed_blocks[0] ]) - - blocks_to_reposition = [] - if envs.VLLM_V1_SPANS_ENABLED: - # Spans does yet not support hybrid models - assert len(computed_blocks) == 1 - for i, b in enumerate(computed_blocks[0]): - prompt_pos = i * 16 - kvc_pos = b.position - if envs.VLLM_V1_SPANS_DEBUG: - print( - f"[SPANS -> kv_cache_manager] checking block " \ - f"{b.block_id} with prompot pos {prompt_pos} " \ - f"and kv pos {kvc_pos}" - ) - assert isinstance(kvc_pos, int) - if kvc_pos != prompt_pos: - if envs.VLLM_V1_SPANS_DEBUG: - 
print( - f"[SPANS -> kv_cache_manager] from pos: {kvc_pos} "\ - f"to prompt pos: {prompt_pos} repositioning needed" - ) - - blocks_to_reposition.append( - BlockRepositionRequest(b.block_id, kvc_pos, - prompt_pos)) - b.position = int(prompt_pos) + # for block duplication + num_repo = len([r for r in repo_reqs + if r.prompt_pos != r.cached_pos]) + num_copy = len(repo_reqs) - num_repo + print( + "[SPANS -> kv_cache_manager] here's the number of blocks", + f"total: {len(repo_reqs)} to reposition: {num_repo},", + f"to copy: {num_copy}") if self.log_stats: assert self.prefix_cache_stats is not None @@ -248,7 +267,7 @@ def get_computed_blocks(self, self.prefix_cache_stats.queries += request.num_tokens self.prefix_cache_stats.hits += num_new_computed_tokens - return KVCacheBlocks(computed_blocks, blocks_to_reposition),\ + return KVCacheBlocks(computed_blocks, repo_reqs),\ num_new_computed_tokens def allocate_slots( diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 077b09f4ab..a9b7dabe5a 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -372,6 +372,7 @@ def append_n(self, blocks: list[KVCacheBlock]) -> None: """ if len(blocks) == 0: return + blocks = list({b.block_id: b for b in blocks}.values()) self.num_free_blocks += len(blocks) last_block = self.fake_free_list_tail.prev_free_block diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index dbcb7ed39f..591cc4e843 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -389,6 +389,11 @@ def schedule(self) -> SchedulerOutput: len(new_computed_blocks.blocks_to_reposition) > 0: blocks_to_reposition.extend( new_computed_blocks.blocks_to_reposition) + + # TODO (Nathan) find something smarter to do than this + token_budget += \ + len(new_computed_blocks.blocks_to_reposition) \ + * self.block_size # Get externally-cached tokens if using a KVConnector. if self.connector is not None: @@ -545,8 +550,10 @@ def schedule(self) -> SchedulerOutput: self.waiting.prepend_requests(skipped_waiting_requests) # Check if the scheduling constraints are satisfied. + # TODO make this smarter for spans total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) - assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens + assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens + \ + len(blocks_to_reposition) * self.block_size assert token_budget >= 0 assert len(self.running) <= self.max_num_running_reqs # Since some requests in the RUNNING queue may not be scheduled in diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 8159349e46..6d14619e55 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -269,12 +269,14 @@ def find_longest_cache_hit( if dcp_world_size > 1: block_size *= dcp_world_size max_num_blocks = max_length // block_size - for block_hash in itertools.islice(block_hashes, max_num_blocks): + for pidx, block_hash in enumerate( + itertools.islice(block_hashes, max_num_blocks)): # block_hashes is a chain of block hashes. If a block hash is not # in the cached_block_hash_to_id, the following block hashes are # not computed yet for sure. 
if cached_block := block_pool.get_cached_block( - block_hash, kv_cache_group_ids): + block_hash, kv_cache_group_ids, + position=pidx*kv_cache_spec.block_size): for computed, cached in zip(computed_blocks, cached_block): computed.append(cached) else: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 5e9b6930b2..703100057a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -85,6 +85,8 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import ( KVConnectorModelRunnerMixin, KVConnectorOutput) from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin +from vllm.v1.core.kv_cache_manager import BlockRepositionRequest +from vllm.v1.core.sched.output import NewRequestData from .utils import (AttentionGroup, MultiModalBudget, add_kv_sharing_layers_to_kv_cache_groups, bind_kv_cache, @@ -378,6 +380,10 @@ def __init__( device="cpu", pin_memory=self.pin_memory) + # self.reposition_request_cache: dict[str, BlockRepositionRequest] = {} + self.reposition_request_cache: \ + defaultdict[str, list[BlockRepositionRequest]] = defaultdict(list) + def _make_buffer(self, *size: Union[int, torch.SymInt], dtype: torch.dtype, @@ -1587,9 +1593,100 @@ def _pool( pooler_output=pooler_output, kv_connector_output=kv_connector_output, ) + def _copy_blocks(self, + blocks_to_copy: list[BlockRepositionRequest], + newreqs_by_id: dict[str, NewRequestData]) -> None: + from_ids = [] + to_ids = [] + for req in blocks_to_copy: + from_ids.append(req.cached_blockid) + # find out the block to copy to + # 1. find the relevant new request + # 2. then find the block at position prompt_blockpos + to_ids.append(newreqs_by_id[req.prompt_reqid]\ + .block_ids[0][req.prompt_blockpos]) + # perform copies + args = dict(dtype=torch.long, device=self.kv_caches[0].device) + for i in range(len(self.kv_caches)): + self.kv_caches[i][:, torch.tensor(to_ids, **args), ...] = \ + self.kv_caches[i][:, torch.tensor(from_ids, **args), ...] + + def _custom_cache_manipulations(self, + scheduler_output: "SchedulerOutput") \ + -> None: + # only allow as many reposition requests + # as a request has tokens scheduled + for req in scheduler_output.blocks_to_reposition: + self.reposition_request_cache[req.prompt_reqid].append(req) + # 1. find out how many repo requests + # we are scheduled to make + scheduled_reposition_reqs = [] + for rid, ntoks in scheduler_output.num_scheduled_tokens.items(): + cached_rreqs = self.reposition_request_cache[rid] + n_cached_rreqs = len(cached_rreqs) + if n_cached_rreqs > 0: + # take as many as can be scheduled + nsched_rreqs = min( + ntoks // self.cache_config.block_size, + n_cached_rreqs) + scheduled_reposition_reqs.extend( + cached_rreqs[:nsched_rreqs]) + if nsched_rreqs < n_cached_rreqs: + self.reposition_request_cache[rid] =\ + cached_rreqs[nsched_rreqs:] + else: + self.reposition_request_cache[rid] = [] + # and then we adjust the rest of this function so it only uses + # the scheduled repo requests + + # 0. sort requests + blocks_to_copy = [] + blocks_to_repo = [] + [(blocks_to_copy if req.cached_pos == req.prompt_pos + else blocks_to_repo).append(req) + for req in scheduled_reposition_reqs] + if envs.VLLM_V1_SPANS_DISABLE_REPOSITION: + blocks_to_copy.extend(blocks_to_repo) + blocks_to_repo = [] + newreqs_by_id = {r.req_id: r for r + in + scheduler_output.scheduled_new_reqs + \ + [self.requests[rid] for rid in \ + scheduler_output.scheduled_cached_reqs.req_ids]} + # 1. 
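# A minimal standalone sketch of the batched copy performed by _copy_blocks above,
# assuming the per-layer cache layout [2 (K/V), num_blocks, block_size, num_heads,
# head_size] described by the surrounding comments; the sizes and block ids here are
# illustrative. Indexing the block dimension with LongTensors copies all source
# blocks into their destination slots in one vectorised assignment per layer.
import torch

num_blocks, block_size, num_heads, head_size = 8, 16, 2, 4
layer_cache = torch.randn(2, num_blocks, block_size, num_heads, head_size)

from_ids = torch.tensor([1, 3], dtype=torch.long)
to_ids = torch.tensor([6, 7], dtype=torch.long)
layer_cache[:, to_ids, ...] = layer_cache[:, from_ids, ...]

assert torch.equal(layer_cache[:, 6], layer_cache[:, 1])
assert torch.equal(layer_cache[:, 7], layer_cache[:, 3])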
perform copies + self._copy_blocks(blocks_to_copy, newreqs_by_id) + # 2. do repositioning + self._perform_repositioning(blocks_to_repo, newreqs_by_id) + # 3. adjust relevant counters + # 3.1 num_scheduled_tokens + req_ntokens_to_skip = defaultdict(lambda: 0) + for rreq in scheduled_reposition_reqs: + req_ntokens_to_skip[rreq.prompt_reqid] += \ + self.cache_config.block_size + # 16 + for reqid, ntoks in req_ntokens_to_skip.items(): + scheduler_output.num_scheduled_tokens[reqid] -= ntoks + # 3.2 total_num_scheduled_tokens + scheduler_output.total_num_scheduled_tokens -= \ + len(scheduled_reposition_reqs) \ + * self.cache_config.block_size + # * 16 + # 3.3 scheduled_new_reqs (num_computed_tokens) + for i in range(len(scheduler_output.scheduled_new_reqs)): + sr = scheduler_output.scheduled_new_reqs[i] + sr.num_computed_tokens += req_ntokens_to_skip[sr.req_id] + scc = scheduler_output.scheduled_cached_reqs + for i in range(len(scc.req_ids)): + rid = scc.req_ids[i] + scc.num_computed_tokens[i] += req_ntokens_to_skip[rid] + # NOTE (nathan) maybe PP is broken here because + # we don't manipulate new_token_ids + # in the cached request data + def _perform_repositioning(self, - scheduler_output: "SchedulerOutput") -> None: + blocks_to_reposition: list[BlockRepositionRequest], + newreqs_by_id: dict[str, NewRequestData]) -> None: """ Repositions KV cache blocks based on the scheduler's instructions. @@ -1602,99 +1699,121 @@ def _perform_repositioning(self, scheduler_output: The output from the scheduler containing blocks to reposition. """ - blocks_to_reposition = scheduler_output.blocks_to_reposition if envs.VLLM_V1_SPANS_DEBUG: ts_repo = time.time() repo_count = len(blocks_to_reposition) - if len(blocks_to_reposition) > 0: + # figure out destination block IDs + dest_ids = [] + valid_blocks_to_reposition = [] + for req in blocks_to_reposition: + try: + dest_ids.append(newreqs_by_id[req.prompt_reqid]\ + .block_ids[0][req.prompt_blockpos]) + valid_blocks_to_reposition.append(req) + except IndexError as e: + # breakpoint() + print('INDEX_ERROR could not run reposition request:', req, e) + + if len(valid_blocks_to_reposition) > 0: bs = 512 - for i in range(0, len(blocks_to_reposition), bs): - repo_batch = blocks_to_reposition[i:i+bs] - self._repositionings_handler(repo_batch) + for i in range(0, len(valid_blocks_to_reposition), bs): + j = i+bs + repo_batch = valid_blocks_to_reposition[i:j] + dest_batch = dest_ids[i:j] + self._repositionings_handler(repo_batch, + dest_batch) if envs.VLLM_V1_SPANS_DEBUG and repo_count > 0: torch.cuda.synchronize() t_repo = time.time() - ts_repo print(f'[SPANS -> gpu_model_runner] repositioning' \ - f' speed: {repo_count/t_repo:.2f} (blocks/s)'\ - f' (total {repo_count})') + f' speed: {repo_count/t_repo:.2f} (blocks/s)') @torch.inference_mode() - def _repositionings_handler(self, blocks_to_reposition): + def _repositionings_handler(self, blocks_to_reposition, + destination_block_ids): num_repos = len(blocks_to_reposition) - if envs.VLLM_V1_SPANS_DEBUG and num_repos > 0: + if envs.VLLM_V1_SPANS_DEBUG: print( f'[SPANS -> gpu_model_runner] ' \ f'reposition block count: {num_repos}' ) - if not envs.VLLM_V1_SPANS_DISABLE_REPOSITION: - kvc_positions = torch.tensor( - [d.kvc_pos for d in blocks_to_reposition], - dtype=torch.long, - device=self.kv_caches[0].device).unsqueeze(-1) - prt_positions = torch.tensor( - [d.prompt_pos for d in blocks_to_reposition], - dtype=torch.long, - device=self.kv_caches[0].device).unsqueeze(-1) - block_ids = torch.tensor( - [d.block_id for d in 
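# A minimal standalone sketch of the throttling done near the start of
# _custom_cache_manipulations above, assuming block_size=16: a request may only
# execute as many queued reposition requests this step as whole blocks fit into the
# tokens the scheduler granted it, and the remainder is carried over to a later
# step. take_scheduled and the string placeholders are illustrative stand-ins for
# the per-request reposition_request_cache bookkeeping.
def take_scheduled(cached_reqs: list, num_scheduled_tokens: int, block_size: int = 16):
    n_take = min(num_scheduled_tokens // block_size, len(cached_reqs))
    return cached_reqs[:n_take], cached_reqs[n_take:]

pending = ["repo0", "repo1", "repo2", "repo3", "repo4"]
now, later = take_scheduled(pending, num_scheduled_tokens=48)
assert now == ["repo0", "repo1", "repo2"] and later == ["repo3", "repo4"]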
blocks_to_reposition], - dtype=torch.long, - device=self.kv_caches[0].device) - - # (self.kv_caches shape): - # [nlay, kv, maxblocks, blocksize, headcount, headsize] - concerned_vectors = [ - x[0, block_ids, :, :, :] for x in self.kv_caches - ] # -> [nlay, blockids, blocksize, headcount, headsize] - bids, bsize, hcount, hsize = concerned_vectors[0].shape - - template_tensor = torch.arange( - bsize, dtype=torch.long, - device=self.kv_caches[0].device).unsqueeze(0) - pos_depos = kvc_positions + template_tensor - pos_repos = prt_positions + template_tensor - - # precision highly affects the outputs - PRECISION = torch.float32 - DEF_PRECISION = self.kv_caches[0].dtype - - # do the rotation - # note: PPMissingLayer is for pipeline parallel support - if not hasattr(self, 'rotate'): - if not isinstance(self.model.model.layers[0], PPMissingLayer): - self.rotate = self.model.model.layers[ - 0].self_attn.rotary_emb - else: - for lay in self.model.model.layers: - if not isinstance(lay, PPMissingLayer): - self.rotate = lay.self_attn.rotary_emb - break - assert pos_depos.shape[0] == concerned_vectors[0].shape[0] - - if num_repos > 100: - for i, k_vectors in enumerate(concerned_vectors): - k_vectors_tmp, _ = self.rotate.forward_native( - pos_depos, - k_vectors.to(PRECISION), - invert_rotation_angle=True) - k_vectors_tmp, _ = self.rotate.forward_native( - pos_repos, k_vectors_tmp) - self.kv_caches[i][0, block_ids, ...] = \ - k_vectors_tmp.to(DEF_PRECISION) + kvc_positions = torch.tensor( + [d.cached_pos for d in blocks_to_reposition], + dtype=torch.long, + device=self.kv_caches[0].device).unsqueeze(-1) + prt_positions = torch.tensor( + [d.prompt_pos for d in blocks_to_reposition], + dtype=torch.long, + device=self.kv_caches[0].device).unsqueeze(-1) + block_ids = torch.tensor( + [d.cached_blockid for d in blocks_to_reposition], + dtype=torch.long, + device=self.kv_caches[0].device) + dest_block_ids = torch.tensor( + destination_block_ids, + dtype=torch.long, + device=self.kv_caches[0].device) + + # (self.kv_caches shape): + # [nlay, kv, maxblocks, blocksize, headcount, headsize] + concerned_vectors = [ + x[0, block_ids, :, :, :] for x in self.kv_caches + ] # -> [nlay, blockids, blocksize, headcount, headsize] + bids, bsize, hcount, hsize = concerned_vectors[0].shape + + template_tensor = torch.arange( + bsize, dtype=torch.long, + device=self.kv_caches[0].device).unsqueeze(0) + pos_depos = kvc_positions + template_tensor + pos_repos = prt_positions + template_tensor + + # precision highly affects the outputs + PRECISION = torch.float32 + DEF_PRECISION = self.kv_caches[0].dtype + + # do the rotation + # note: PPMissingLayer is for pipeline parallel support + if not hasattr(self, 'rotate'): + if not isinstance(self.model.model.layers[0], PPMissingLayer): + self.rotate = self.model.model.layers[ + 0].self_attn.rotary_emb else: - nlays = len(concerned_vectors) - kvecs = torch.cat(concerned_vectors, dim=0).to(PRECISION) + for lay in self.model.model.layers: + if not isinstance(lay, PPMissingLayer): + self.rotate = lay.self_attn.rotary_emb + break + assert pos_depos.shape[0] == concerned_vectors[0].shape[0] + + + if num_repos > 100: + for i, k_vectors in enumerate(concerned_vectors): k_vectors_tmp, _ = self.rotate.forward_native( - pos_depos.repeat(nlays, 1), - kvecs, + pos_depos, + k_vectors.to(PRECISION), invert_rotation_angle=True) k_vectors_tmp, _ = self.rotate.forward_native( - pos_repos.repeat(nlays, 1), - k_vectors_tmp) - k_vectors_tmp = k_vectors_tmp.reshape(nlays, - *concerned_vectors[0].shape) - for i 
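# A minimal standalone sketch of the "un-rotate, then re-rotate" idea behind
# _repositionings_handler above, reduced to a single 2-d rotary pair with a fixed
# frequency theta; rot(), theta and the positions are illustrative stand-ins. A key
# that was rotated for cached position p_old can be moved to prompt position p_new
# by applying the inverse rotation for p_old and then the forward rotation for
# p_new, since R(-theta*p_old) composed with R(theta*p_new) equals
# R(theta*(p_new - p_old)). The patch does this per layer through
# rotary_emb.forward_native(..., invert_rotation_angle=True), in float32 because
# the precision of the round trip matters.
import torch

def rot(x: torch.Tensor, angle: float) -> torch.Tensor:
    c, s = torch.cos(torch.tensor(angle)), torch.sin(torch.tensor(angle))
    return torch.stack([torch.stack([c, -s]), torch.stack([s, c])]) @ x

theta, p_old, p_new = 0.01, 96.0, 32.0
key = torch.tensor([0.3, -1.2])

cached = rot(key, theta * p_old)                        # what sits in the KV cache
moved = rot(rot(cached, -theta * p_old), theta * p_new)
assert torch.allclose(moved, rot(key, theta * p_new), atol=1e-5)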
in range(len(self.kv_caches)): - self.kv_caches[i][0, block_ids, ...] = \ - k_vectors_tmp[i].to(DEF_PRECISION) + pos_repos, k_vectors_tmp) + self.kv_caches[i][0, dest_block_ids, ...] = \ + k_vectors_tmp.to(DEF_PRECISION) + self.kv_caches[i][1, dest_block_ids, ...] = \ + self.kv_caches[i][1, block_ids] + else: + nlays = len(concerned_vectors) + kvecs = torch.cat(concerned_vectors, dim=0).to(PRECISION) + k_vectors_tmp, _ = self.rotate.forward_native( + pos_depos.repeat(nlays, 1), + kvecs, + invert_rotation_angle=True) + k_vectors_tmp, _ = self.rotate.forward_native( + pos_repos.repeat(nlays, 1), + k_vectors_tmp) + k_vectors_tmp = k_vectors_tmp.reshape(nlays, + *concerned_vectors[0].shape) + for i in range(len(self.kv_caches)): + self.kv_caches[i][0, dest_block_ids, ...] = \ + k_vectors_tmp[i].to(DEF_PRECISION) + self.kv_caches[i][1, dest_block_ids, ...] = \ + self.kv_caches[i][1, block_ids] def _preprocess( self, @@ -1961,7 +2080,7 @@ def execute_model( with record_function_or_nullcontext("Preprocess"): # handle repositioning requests - self._perform_repositioning(scheduler_output) + self._custom_cache_manipulations(scheduler_output) self._update_states(scheduler_output) if not scheduler_output.total_num_scheduled_tokens:
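# A minimal standalone sketch of the cross-layer batching used in the else-branch
# above, assuming per-layer key tensors of identical shape: concatenating every
# layer along the block dimension lets one rotation call cover all layers, and a
# reshape splits the result back into per-layer chunks. apply_positions() is a
# hypothetical identity placeholder standing in for the pair of
# rotary_emb.forward_native calls in the patch, so the check below only exercises
# the cat/repeat/reshape bookkeeping, not the rotation itself.
import torch

nlays, nblocks, bsize, hdim = 4, 3, 16, 8
per_layer_keys = [torch.randn(nblocks, bsize, hdim) for _ in range(nlays)]
positions = torch.arange(bsize).repeat(nblocks, 1)            # [nblocks, bsize]

def apply_positions(pos: torch.Tensor, keys: torch.Tensor) -> torch.Tensor:
    return keys                                               # placeholder, no rotation applied

batched = torch.cat(per_layer_keys, dim=0)                    # [nlays*nblocks, bsize, hdim]
rotated = apply_positions(positions.repeat(nlays, 1), batched)
per_layer_out = rotated.reshape(nlays, nblocks, bsize, hdim)  # back to one chunk per layer

for i in range(nlays):
    assert torch.equal(per_layer_out[i], per_layer_keys[i])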